diff --git a/CMakeLists.txt b/CMakeLists.txt index d83ee4f224f31f198b43d2b46608ecadc02d6f1e..bb8fcf0d11c2f3d4ac0ecb2189161ace66d74d08 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -120,44 +120,22 @@ endif() include(cmake/common_funcs.cmake) add_subdirectory(inc) -add_subdirectory(proto) -add_subdirectory(graph) -add_subdirectory(exe_graph) add_subdirectory(error_manager) -add_subdirectory(register) add_subdirectory(base) if (ENABLE_METADEF_UT OR ENABLE_METADEF_ST OR ENABLE_BENCHMARK) find_package(benchmark CONFIG REQUIRED) add_subdirectory(tests) endif() -install(TARGETS exe_graph lowering error_manager graph graph_base opp_registry metadef register rt2_registry_static metadef_headers - aihac_ir aihac_ir_register aihac_symbolizer +install(TARGETS exe_graph error_manager opp_registry rt2_registry_static metadef EXPORT metadef-targets LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR} OPTIONAL COMPONENT opensdk ARCHIVE DESTINATION ${INSTALL_LIBRARY_DIR} OPTIONAL COMPONENT opensdk RUNTIME DESTINATION ${INSTALL_RUNTIME_DIR} OPTIONAL COMPONENT opensdk ) -if (NOT ENABLE_OPEN_SRC) - install(TARGETS graph_base_static metadef_static - ARCHIVE DESTINATION ${INSTALL_LIBRARY_DIR} OPTIONAL - ) - install(TARGETS atc_stub_graph_static stub_exe_graph_static stub_register_static - ARCHIVE DESTINATION ${INSTALL_LIBRARY_DIR}/${CMAKE_SYSTEM_PROCESSOR}/stub OPTIONAL - ) -endif () - # 下列头文件发布是非法的,需要在后续整改中删掉 # --------------------start------------------------ -install(FILES ${METADEF_DIR}/third_party/transformer/inc/axis_util.h - ${METADEF_DIR}/third_party/transformer/inc/expand_dimension.h - ${METADEF_DIR}/third_party/transformer/inc/transfer_shape_utils.h - ${METADEF_DIR}/third_party/transformer/inc/transfer_range_according_to_format.h - ${METADEF_DIR}/third_party/transformer/inc/transfer_shape_according_to_format.h - ${METADEF_DIR}/third_party/transformer/inc/transfer_def.h - DESTINATION ${INSTALL_INCLUDE_DIR}/metadef/transformer COMPONENT opensdk EXCLUDE_FROM_ALL -) install(FILES ${METADEF_DIR}/register/op_tiling/op_tiling_constants.h ${METADEF_DIR}/register/op_tiling/op_compile_info_manager.h ${METADEF_DIR}/register/op_tiling/op_tiling_utils.h @@ -206,15 +184,8 @@ if (ENABLE_OPEN_SRC) install(TARGETS error_manager exe_graph - lowering - graph - graph_base - register opp_registry rt2_registry_static - aihac_ir - aihac_ir_register - aihac_symbolizer metadef LIBRARY DESTINATION ${ARCH_LINX_PATH}/lib64 OPTIONAL COMPONENT packages EXCLUDE_FROM_ALL ARCHIVE DESTINATION ${ARCH_LINX_PATH}/lib64 OPTIONAL COMPONENT packages EXCLUDE_FROM_ALL diff --git a/README.md b/README.md index 3652a077ada10fe055b6f340d8ddfab69ad2bdd1..d5f4173bd188b900c482f890b2fbb4d19c9973d7 100644 --- a/README.md +++ b/README.md @@ -9,18 +9,21 @@ ```angular2html metadef ├── error_manager # 相关错误码定义 -├── exe_graph -| ├── lowering # 执行图构图接口相关实现 -| ├── runtime # 执行图执行接口相关实现 -├── graph # 图相关接口实现,包括图缓存模块、序列化 +├── base +| ├── attr # 属性存储相关实现 +| ├── common # 算子包管理相关实现 +| ├── context_builder # 构建算子执行上下文相关实现 +| ├── registry # 算子注册相关实现 +| ├── runtime # 算子执行接口相关实现 +| ├── type # 基础数据类型相关实现 +| ├── utils # 工具类相关实现 ├── inc +| ├── base # 一些基础数据结构头文件 | ├── common # 一些公共头文件 | ├── exe_graph # 执行图头文件 | ├── external # 对外发布的头文件(保证兼容性) | ├── graph # 图接口相关头文件 | ├── register # 算子注册头文件 -├── proto # 图相关proto定义 -├── register # 算子注册实现 ├── tests # 开发者测试目录 ``` diff --git a/base/CMakeLists.txt b/base/CMakeLists.txt index cae15d0d4bb096efaeb9439ebee406f9a1a71a53..340dbaef097e955a0aed39a558674e6ce2537ced 100644 --- a/base/CMakeLists.txt +++ b/base/CMakeLists.txt @@ -65,6 +65,66 @@ target_link_libraries(opp_registry metadef_headers ) +############ librt2_registry.a ############ +add_library(rt2_registry_objects OBJECT + "${METADEF_DIR}/base/registry/op_impl_registry.cc" + "${METADEF_DIR}/base/registry/op_ct_impl_registry.cc" + "${METADEF_DIR}/base/registry/op_impl_functions.cc" + "${METADEF_DIR}/base/registry/op_bin_info.cc" + ) + +target_compile_options(rt2_registry_objects PRIVATE + $<$,$>: -fvisibility=hidden -fno-common -fPIC -O2 -Werror -Wextra -Wfloat-equal> + $<$:/utf-8> + $<$,$>:/MTd> + $<$,$>:/MT> + ) + +target_compile_definitions(rt2_registry_objects PRIVATE + $,OS_TYPE=WIN,OS_TYPE=0> + $<$:SECUREC_USING_STD_SECURE_LIB=0 NOMINMAX> + $<$:ONLY_COMPILE_OPEN_SRC> + ) + +target_include_directories(rt2_registry_objects PRIVATE + ${CMAKE_CURRENT_LIST_DIR} + ${METADEF_DIR}/inc + ${METADEF_DIR}/inc/external + ${TOP_DIR}/ace/npuruntime/runtime/platform/inc + ) + +set_target_properties(rt2_registry_objects PROPERTIES + WINDOWS_EXPORT_ALL_SYMBOLS TRUE + OUTPUT_NAME $,librt2_registry,rt2_registry> + ) + +target_link_libraries(rt2_registry_objects + PRIVATE + $ + c_sec + slog + PUBLIC + metadef_headers + ) + +############ librt2_registry.a ############ +add_library(rt2_registry_static STATIC + $ + ) + +set_target_properties(rt2_registry_static PROPERTIES + WINDOWS_EXPORT_ALL_SYMBOLS TRUE + OUTPUT_NAME $,librt2_registry,rt2_registry> + ) + +target_link_libraries(rt2_registry_static + PUBLIC + metadef_headers + ) + +install(TARGETS rt2_registry_static + ARCHIVE DESTINATION ${INSTALL_LIBRARY_DIR}/${CMAKE_SYSTEM_PROCESSOR} OPTIONAL + ) ############################################################## set(STUB_HEADER_LIST ${METADEF_DIR}/inc/external/base/registry/op_impl_space_registry_v2.h @@ -266,7 +326,7 @@ add_custom_target(metadef_stub DEPENDS ${STUB_SRC_LIST}) ############ stub/libmetadef.so ############ add_library(stub_metadef SHARED ${STUB_SRC_LIST}) -add_dependencies(stub_metadef metadef_protos metadef_stub) +add_dependencies(stub_metadef metadef_stub) target_include_directories(stub_metadef PRIVATE ${CMAKE_CURRENT_LIST_DIR} @@ -302,7 +362,7 @@ set_target_properties(stub_metadef PROPERTIES if (NOT ENABLE_OPEN_SRC) target_clone(stub_metadef stub_metadef_static STATIC) - add_dependencies(stub_metadef_static metadef_protos metadef_stub) + add_dependencies(stub_metadef_static metadef_stub) target_compile_options(stub_metadef_static PRIVATE $<$:-O2 -fPIC -Wextra -Wfloat-equal -Wno-array-bounds>) diff --git a/register/op_bin_info.cc b/base/registry/op_bin_info.cc similarity index 99% rename from register/op_bin_info.cc rename to base/registry/op_bin_info.cc index c1ebeefee50bf109d58512473c0bd07b977c77ad..cbfc0d97d981bf455f3401bda3531233aa5a4eb4 100644 --- a/register/op_bin_info.cc +++ b/base/registry/op_bin_info.cc @@ -21,8 +21,7 @@ #include #include #include -#include "graph/operator_reg.h" -#include "register/op_bin_info_utils.h" +#include "op_bin_info_utils.h" #include "toolchain/slog.h" #include "common/ge_common/debug/ge_log.h" diff --git a/register/op_bin_info_utils.h b/base/registry/op_bin_info_utils.h similarity index 100% rename from register/op_bin_info_utils.h rename to base/registry/op_bin_info_utils.h diff --git a/base/registry/op_impl_registry.cc b/base/registry/op_impl_registry.cc index ceaefaf1ff4cffc28ec5a8d93b7db3100600d587..15be8ccea6b2d9b846cc2a714810aeb2f399e8ce 100644 --- a/base/registry/op_impl_registry.cc +++ b/base/registry/op_impl_registry.cc @@ -17,55 +17,6 @@ namespace gert { namespace { -enum class kRegFuncType : int32_t { - kInferShape = 0, - kInferShapeRange, - kInferDataType, - kTiling, - kGenSimplifiedKey, - kInputsDependency, - kHostInputs, - kTilingDependency, - kTilingDependencyPlacement, - kOpExecuteFunc, - kOp2StageExecuteFuncs, - kTilingParse, - kPrivateAtrr, - kOutputShapeDependCompute, - kInferFormat, - kCalcOpParam, - kGenTask, - kCheckSupport, - kOpSelectFormat, - kInvalid -}; -const std::vector kRegFuncToString = {"InferShape", - "InferShapeRange", - "InferDataType", - "Tiling", - "GenSimplifiedKey", - "InputsDependency", - "HostInputs", - "TilingDependency", - "TilingDependencyPlacement", - "OpExecuteFunc", - "Op2StageExecuteFunc", - "TilingParse", - "IsPrivateAtrrReg", - "OutputShapeDependCompute", - "InferFormat", - "CalcOpParam", - "GenerateTask", - "CheckSupport", - "OpSelectFormat", - "Invalid"}; -std::string RegFuncTypeToString(kRegFuncType type) { - static const bool is_valid = ((kRegFuncToString.size() - 1UL) == static_cast(kRegFuncType::kInvalid)); - if (!is_valid) { - return ""; - } - return kRegFuncToString[static_cast(type)]; -} void RegisterOpImplToRegistry(const OpImplRegisterV2Impl *rd) { if (rd == nullptr) { GELOGW("The register data is invalid, the impl is nullptr"); @@ -74,84 +25,84 @@ void RegisterOpImplToRegistry(const OpImplRegisterV2Impl *rd) { auto &funcs = OpImplRegistry::GetInstance().CreateOrGetOpImpl(rd->op_type.GetString()); std::stringstream ss; if (rd->functions.infer_shape != nullptr) { - ss << "[" << RegFuncTypeToString(kRegFuncType::kInferShape) << "]"; + ss << "[InferShape]"; funcs.infer_shape = rd->functions.infer_shape; } if (rd->functions.infer_shape_range != nullptr) { - ss << "[" << RegFuncTypeToString(kRegFuncType::kInferShapeRange) << "]"; + ss << "[InferShapeRange]"; funcs.infer_shape_range = rd->functions.infer_shape_range; } if (rd->functions.infer_datatype != nullptr) { - ss << "[" << RegFuncTypeToString(kRegFuncType::kInferDataType) << "]"; + ss << "[InferDataType]"; funcs.infer_datatype = rd->functions.infer_datatype; } if (rd->functions.tiling != nullptr) { - ss << "[" << RegFuncTypeToString(kRegFuncType::kTiling) << "]"; + ss << "[Tiling]"; funcs.tiling = rd->functions.tiling; funcs.max_tiling_data_size = rd->functions.max_tiling_data_size; } if (rd->functions.gen_simplifiedkey != nullptr) { - ss << "[" << RegFuncTypeToString(kRegFuncType::kGenSimplifiedKey) << "]"; + ss << "[GenSimplifiedKey]"; funcs.gen_simplifiedkey = rd->functions.gen_simplifiedkey; } if (rd->functions.inputs_dependency != 0U) { - ss << "[" << RegFuncTypeToString(kRegFuncType::kInputsDependency) << "]"; + ss << "[InputsDependency]"; funcs.inputs_dependency = rd->functions.inputs_dependency; } if (rd->functions.host_inputs != 0U) { - ss << "[" << RegFuncTypeToString(kRegFuncType::kHostInputs) << "]"; + ss << "[HostInputs]"; funcs.host_inputs = rd->functions.host_inputs; } if (rd->functions.tiling_dependency != 0U) { - ss << "[" << RegFuncTypeToString(kRegFuncType::kTilingDependency) << "]"; + ss << "[TilingDependency]"; funcs.tiling_dependency = rd->functions.tiling_dependency; } if (rd->functions.tiling_dependency_placements != 0U) { - ss << "[" << RegFuncTypeToString(kRegFuncType::kTilingDependencyPlacement) << "]"; + ss << "[TilingDependencyPlacement]"; funcs.tiling_dependency_placements = rd->functions.tiling_dependency_placements; } if (rd->functions.op_execute_func != nullptr) { - ss << "[" << RegFuncTypeToString(kRegFuncType::kOpExecuteFunc) << "]"; + ss << "[OpExecuteFunc]"; funcs.op_execute_func = rd->functions.op_execute_func; } if ((rd->functions.op_execute_prepare_func != nullptr) && (rd->functions.op_execute_launch_func != nullptr)) { - ss << "[" << RegFuncTypeToString(kRegFuncType::kOp2StageExecuteFuncs) << "]"; + ss << "[Op2StageExecuteFuncs]"; funcs.op_execute_prepare_func = rd->functions.op_execute_prepare_func; funcs.op_execute_launch_func = rd->functions.op_execute_launch_func; } if (rd->functions.tiling_parse != nullptr) { - ss << "[" << RegFuncTypeToString(kRegFuncType::kTilingParse) << "]"; + ss << "[TilingParse]"; funcs.tiling_parse = rd->functions.tiling_parse; funcs.compile_info_creator = rd->functions.compile_info_creator; funcs.compile_info_deleter = rd->functions.compile_info_deleter; } if (rd->is_private_attr_registered) { - ss << "[" << RegFuncTypeToString(kRegFuncType::kPrivateAtrr) << "]"; + ss << "[PrivateAtrr]"; funcs.private_attrs = rd->functions.private_attrs; funcs.unique_private_attrs = rd->functions.unique_private_attrs; } if (rd->functions.output_shape_depend_compute != 0U) { - ss << "[" << RegFuncTypeToString(kRegFuncType::kOutputShapeDependCompute) << "]"; + ss << "[OutputShapeDependCompute]"; funcs.output_shape_depend_compute = rd->functions.output_shape_depend_compute; } if (rd->functions.infer_format_func != nullptr) { - ss << "[" << RegFuncTypeToString(kRegFuncType::kInferFormat) << "]"; + ss << "[InferFormat]"; funcs.infer_format_func = rd->functions.infer_format_func; } if (rd->functions.calc_op_param != 0U) { - ss << "[" << RegFuncTypeToString(kRegFuncType::kCalcOpParam) << "]"; + ss << "[CalcOpParam]"; funcs.calc_op_param = rd->functions.calc_op_param; } if (rd->functions.gen_task != nullptr) { - ss << "[" << RegFuncTypeToString(kRegFuncType::kGenTask) << "]"; + ss << "[GenTask]"; funcs.gen_task = rd->functions.gen_task; } if (rd->functions.check_support != nullptr) { - ss << "[" << RegFuncTypeToString(kRegFuncType::kCheckSupport) << "]"; + ss << "[CheckSupport]"; funcs.check_support = rd->functions.check_support; } if (rd->functions.op_select_format != nullptr) { - ss << "[" << RegFuncTypeToString(kRegFuncType::kOpSelectFormat) << "]"; + ss << "[OpSelectFormat]"; funcs.op_select_format = rd->functions.op_select_format; } GELOGI("Op type[%s] register OP_IMPL : %s", rd->op_type.GetString(), ss.str().c_str()); diff --git a/base/registry/op_impl_registry_holder_manager.cc b/base/registry/op_impl_registry_holder_manager.cc index b6d7d1916529edfb836f9982408e9be1a2f26926..aec2ea338c49528e5ff4b64f611e95e446085a09 100644 --- a/base/registry/op_impl_registry_holder_manager.cc +++ b/base/registry/op_impl_registry_holder_manager.cc @@ -188,17 +188,14 @@ ImplMenu kImplMenuVec[static_cast(ImplType::END_TYPE)] = { }; ge::graphStatus OpImplRegistryHolder::GetOpImplFunctionsByHandle(const void *handle, const std::string &so_path) { - if (handle == nullptr) { - GELOGE(ge::FAILED, "handle is nullptr"); - return ge::GRAPH_FAILED; - } for (size_t i = 0; i < static_cast(ImplType::END_TYPE); ++i) { const auto &impl_menu = kImplMenuVec[i]; // 兼容1.0注册方式注册自定义算子so,查找不到符号告警返回 - const auto get_impl_num = reinterpret_cast(mmDlsym(const_cast(handle), - impl_menu.get_reg_num_func.c_str())); + // mmDlsym暂不支持handle为nullptr的功能,先用系统的dlsym + const auto get_impl_num = + reinterpret_cast(dlsym(const_cast(handle), impl_menu.get_reg_num_func.c_str())); if (get_impl_num == nullptr) { - const ge::char_t *error = mmDlerror(); + const ge::char_t *error = dlerror(); error = (error == nullptr) ? "" : error; GELOGW("Get registered op num functions failed, path:%s, errmsg:%s", so_path.c_str(), error); return ge::GRAPH_FAILED; @@ -208,9 +205,9 @@ ge::graphStatus OpImplRegistryHolder::GetOpImplFunctionsByHandle(const void *han if ((impl_num == 0U) && !impl_menu.need_check_empty) { continue; } - const auto void_impl_func = mmDlsym(const_cast(handle), impl_menu.get_reg_impl_func.c_str()); + const auto void_impl_func = dlsym(const_cast(handle), impl_menu.get_reg_impl_func.c_str()); if (void_impl_func == nullptr) { - const ge::char_t *error = mmDlerror(); + const ge::char_t *error = dlerror(); error = (error == nullptr) ? "" : error; GELOGW("Get op impl functions failed, path:%s, errmsg:%s", so_path.c_str(), error); if (impl_menu.type == ImplType::RT_V2_TYPE) { diff --git a/base/registry/op_impl_space_registry_v2_impl.cc b/base/registry/op_impl_space_registry_v2_impl.cc index a1b791a3cf52b8a0b7a5291d5ee6b4c5a92887d0..3392ad4edef733da0a449976c4c1b3f5b1706f88 100644 --- a/base/registry/op_impl_space_registry_v2_impl.cc +++ b/base/registry/op_impl_space_registry_v2_impl.cc @@ -18,6 +18,7 @@ #include "graph/debug/ge_util.h" #include "graph/utils/file_utils.h" #include "op_impl_space_registry_v2_impl.h" +#include "register/op_impl_registry_base.h" #include "graph/any_value.h" #include "base/err_msg.h" @@ -44,7 +45,30 @@ void CloseHandle(void *const handle) { } } // namespace +ge::graphStatus OpImplSpaceRegistryImpl::AddMainExeToRegistry(const OppSoDesc &so_desc) { + auto types_to_impl_from_holder = std::map(); + for (const auto &so_path_ascend_string : so_desc.GetSoPaths()) { + auto so_path = so_path_ascend_string.GetString(); + GELOGI("Start to add main_exe op_impl to registry."); + const auto om_registry_holder = ge::MakeShared(); + GE_CHECK_NOTNULL(om_registry_holder); + if (om_registry_holder->GetOpImplFunctionsByHandle(RTLD_DEFAULT, so_path) != ge::GRAPH_SUCCESS) { + GELOGW("Failed to get funcs from so!"); + return ge::GRAPH_FAILED; + } + for (const auto &type : om_registry_holder->GetTypesToImpl()) { + types_to_impl_from_holder[type.first] = type.second; + } + GE_ASSERT_GRAPH_SUCCESS(AddRegistry(om_registry_holder)); + GELOGI("Save so symbol and handle in main_exe successfully!"); + } + return ge::GRAPH_SUCCESS; +} + ge::graphStatus OpImplSpaceRegistryImpl::AddSoToRegistry(const OppSoDesc &so_desc) { + if (so_desc.GetPackageName() == "main_exe") { + return AddMainExeToRegistry(so_desc); + } auto types_to_impl_from_holder = std::map(); for (const auto &so_path_ascend_string : so_desc.GetSoPaths()) { auto so_path = so_path_ascend_string.GetString(); @@ -105,7 +129,8 @@ ge::graphStatus OpImplSpaceRegistryImpl::AddSoToRegistry(const OppSoDesc &so_des const OpImplKernelRegistry::OpImplFunctionsV2 *OpImplSpaceRegistryImpl::GetOpImpl(const std::string &op_type) const { const auto iter = merged_types_to_impl_.find(op_type.c_str()); if (iter == merged_types_to_impl_.cend()) { - return nullptr; + GELOGW("Get %s's op_mpl from local registry.", op_type.c_str()); + return gert::OpImplRegistry::GetInstance().GetOpImpl(op_type.c_str()); } return &iter->second; } diff --git a/base/registry/op_impl_space_registry_v2_impl.h b/base/registry/op_impl_space_registry_v2_impl.h index e4da4c7eec15b57ed75058b72f84674305ebb29c..9fc86bf4d7af9d7ede24cfced1e5648f80bf6627 100644 --- a/base/registry/op_impl_space_registry_v2_impl.h +++ b/base/registry/op_impl_space_registry_v2_impl.h @@ -24,6 +24,8 @@ class OpImplSpaceRegistryImpl { ge::graphStatus AddSoToRegistry(const OppSoDesc &so_desc); + ge::graphStatus AddMainExeToRegistry(const OppSoDesc &so_desc); + const OpImplKernelRegistry::OpImplFunctionsV2 *GetOpImpl(const std::string &op_type) const; ge::graphStatus AddRegistry(const std::shared_ptr ®istry_holder); diff --git a/base/registry/opp_so_manager.cc b/base/registry/opp_so_manager.cc index de37ca8e4d5ef52075b58612d71874b13d710a53..c3c0b622715260d63974d7bc021340212fd1562b 100644 --- a/base/registry/opp_so_manager.cc +++ b/base/registry/opp_so_manager.cc @@ -105,6 +105,7 @@ void OppSoManager::LoadSoAndInitDefault(const std::vector &so_list space_registry_v2 = std::make_shared(); gert::DefaultOpImplSpaceRegistryV2::GetInstance().SetSpaceRegistry(space_registry_v2, opp_version_tag); } + if (space_registry_v2->AddSoToRegistry(opp_so_desc) != ge::SUCCESS) { GELOGW("AddSoToRegistry failed, package name is %s", package_name.c_str()); } @@ -148,6 +149,14 @@ void OppSoManager::LoadOppPackage() const { LoadSoAndInitDefault(so_list_opp.GetSoPaths(), version, package_name); } } + + // 静态库场景,算子注册到主程序中 + auto space_registry_v2 = + gert::DefaultOpImplSpaceRegistryV2::GetInstance().GetSpaceRegistry(); + gert::OppSoDesc opp_main_desc({ge::GetModelPath().c_str()}, ge::AscendString("main_exe")); + if ((space_registry_v2 != nullptr) && (space_registry_v2->AddSoToRegistry(opp_main_desc) != ge::SUCCESS)) { + GELOGW("Add main_exe Registry failed"); + } } void OppSoManager::LoadOpsProtoSo(gert::OppImplVersionTag version, std::vector> &package_to_opp_so_desc, bool is_split) const { diff --git a/build.sh b/build.sh index 7fee88ff1a47bc64e101d083addb49954b42744e..39a403aca149fc6dee043fe30b64e739566e5d36 100755 --- a/build.sh +++ b/build.sh @@ -142,7 +142,7 @@ build_metadef() { echo "CMAKE_ARGS is: $CMAKE_ARGS" cmake_generate_make "${BUILD_PATH}" "${CMAKE_ARGS}" - make graph graph_base exe_graph lowering register register_static rt2_registry_static error_manager error_manager_static ${VERBOSE} -j${THREAD_NUM} && \ + make exe_graph error_manager error_manager_static ${VERBOSE} -j${THREAD_NUM} && \ make install && make package if [ 0 -ne $? ]; then diff --git a/classify_rule.yaml b/classify_rule.yaml index eda8bce860ca27243499b4af839fcb900f866418..e5f509c01bfbd6d0cb77c74fb95dc2cbe452b375 100644 --- a/classify_rule.yaml +++ b/classify_rule.yaml @@ -13,7 +13,7 @@ metadef: - base unrelease: - register/graph_optimizer - - register/op_tiling + - base/registry/op_bin_info.cc - inc/register/graph_optimizer - inc/register/op_tiling.h - graph/attr/ge_attr_define.cc diff --git a/exe_graph/CMakeLists.txt b/exe_graph/CMakeLists.txt deleted file mode 100644 index 26743145186feef0e2bd536be6f1b8784c34c945..0000000000000000000000000000000000000000 --- a/exe_graph/CMakeLists.txt +++ /dev/null @@ -1,165 +0,0 @@ -# Copyright (c) 2024 Huawei Technologies Co., Ltd. -# This file is a part of the CANN Open Software. -# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. -# ====================================================================================================================== - -include(${METADEF_DIR}/cmake/build_type.cmake) - -set(LOWERING_SRCS - lowering/data_dependent_interpreter.cc - lowering/device_tiling_context_builder.cc - lowering/getcdim.cc - lowering/kernel_run_context_builder.cc - lowering/shape_utils.cc - lowering/tiling_context_builder.cc - lowering/tiling_parse_context_builder.cc - lowering/bg_ir_attrs.cc - lowering/bg_kernel_context_extend.cc - lowering/buffer_pool.cc - lowering/dev_mem_value_holder.cc - lowering/frame_selector.cc - lowering/generate_exe_graph.cc - lowering/lowering_global_data.cc - lowering/value_holder.cc - lowering/value_holder_utils.cc -) - -######### liblowering.so ############# -add_library(lowering SHARED ${LOWERING_SRCS}) - -target_include_directories(lowering PRIVATE - ${METADEF_DIR} -) - -target_include_directories(lowering PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/metadef_protos -) - -target_link_libraries(lowering PRIVATE - intf_pub - mmpa_headers - metadef_headers - c_sec - slog - $<$>:-lrt> - -ldl -) - -target_compile_definitions(lowering PRIVATE - google=ascend_private - $<$:ONLY_COMPILE_OPEN_SRC> - $,OS_TYPE=WIN,OS_TYPE=0> - $<$:SECUREC_USING_STD_SECURE_LIB=0 NOMINMAX> -) - -target_compile_options(lowering PRIVATE - $<$:-O2 -fPIC -Wextra -Wfloat-equal -fno-common> - $<$,$>:/MTd> - $<$,$>:/MT> -) - -target_link_libraries(lowering PRIVATE - ascend_protobuf - graph_base - register -) - -if (NOT ENABLE_OPEN_SRC) - ######### liblowering_static.a ############# - target_clone(lowering lowering_static STATIC) - target_link_libraries(lowering_static PRIVATE - ascend_protobuf_static - graph_static - register_static - ) - target_compile_options(lowering_static PRIVATE $<$:-O2 -fPIC -Wextra -Wfloat-equal>) - set_target_properties(lowering_static PROPERTIES OUTPUT_NAME lowering) - - ############ install ############ - install(TARGETS lowering_static OPTIONAL - ARCHIVE DESTINATION ${INSTALL_LIBRARY_DIR} - ) -endif() - -############################################################## -set(STUB_HEADER_LIST_LOWERING - ${METADEF_DIR}/inc/exe_graph/lowering/bg_kernel_context_extend.h - ${METADEF_DIR}/inc/exe_graph/lowering/device_tiling_context_builder.h -) - -list(TRANSFORM STUB_HEADER_LIST_LOWERING - REPLACE "^.*/([^/]+)\\.h$" "${CMAKE_CURRENT_BINARY_DIR}/stub_\\1.cc" - OUTPUT_VARIABLE STUB_SRC_LIST_LOWERING -) - -add_custom_command( - OUTPUT ${STUB_SRC_LIST_LOWERING} - COMMAND echo "Generating stub files." - && ${HI_PYTHON} ${METADEF_DIR}/tests/stub/gen_stubapi.py ${CMAKE_CURRENT_BINARY_DIR} ${STUB_HEADER_LIST_LOWERING} - && echo "Generating stub files end." -) - -add_custom_target(lowering_stub DEPENDS ${STUB_SRC_LIST_LOWERING}) - -############ stub/liblowering.so ############ -add_library(stub_lowering SHARED ${STUB_SRC_LIST_LOWERING}) - -add_dependencies(stub_lowering lowering_stub) - -target_include_directories(stub_lowering PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/metadef_protos -) - -target_compile_options(stub_lowering PRIVATE - -Wfloat-equal - -fno-common - -Os - -Werror=return-type -) - -target_link_libraries(stub_lowering - PRIVATE - intf_pub - c_sec_headers - PUBLIC - metadef_headers -) - -set_target_properties(stub_lowering PROPERTIES - OUTPUT_NAME lowering - LIBRARY_OUTPUT_DIRECTORY stub -) - -############ stub/liblowering.a ############ -if (NOT ENABLE_OPEN_SRC) - target_clone(stub_lowering stub_lowering_static STATIC) - - add_dependencies(stub_lowering_static lowering_stub) - - target_compile_options(stub_lowering_static PRIVATE - -ffunction-sections - -fdata-sections - ) - set_target_properties(stub_lowering_static PROPERTIES - OUTPUT_NAME lowering - ARCHIVE_OUTPUT_DIRECTORY stub - ) - - ############ install ############ - install(TARGETS stub_lowering_static OPTIONAL - ARCHIVE DESTINATION ${INSTALL_LIBRARY_DIR}/${CMAKE_SYSTEM_PROCESSOR}/stub - ) -endif () - -############ install ############ -install(TARGETS stub_lowering OPTIONAL - LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}/${CMAKE_SYSTEM_PROCESSOR}/stub -) diff --git a/exe_graph/lowering/bg_ir_attrs.cc b/exe_graph/lowering/bg_ir_attrs.cc deleted file mode 100644 index 0159a42f27ac3bfb4e1ca2dcf5e74b82bd22b90d..0000000000000000000000000000000000000000 --- a/exe_graph/lowering/bg_ir_attrs.cc +++ /dev/null @@ -1,133 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/lowering/bg_ir_attrs.h" - -#include -#include -#include "common/ge_common/debug/ge_log.h" -#include "graph/utils/math_util.h" -#include "graph/def_types.h" -#include "external/graph/types.h" -#include "common/checker.h" -#include "graph/debug/ge_util.h" - -#include "exe_graph/runtime/tensor.h" -#include "base/attr/attrs_to_buffer.h" - -namespace gert { -namespace bg { -namespace { -void GeShapeToGertShape(const ge::GeShape &ge_shape, gert::Shape &gert_shape) { - gert_shape.SetDimNum(ge_shape.GetDimNum()); - for (size_t i = 0; i < ge_shape.GetDimNum(); ++i) { - gert_shape.SetDim(i, ge_shape.GetDim(i)); - } -} - -size_t GetGeTensorSize(const ge::GeTensor &tensor) { - auto dt = tensor.GetTensorDesc().GetDataType(); - if (dt == ge::DT_STRING) { - return tensor.GetData().GetSize(); - } - auto shape_size = tensor.GetTensorDesc().GetShape().GetShapeSize(); - return static_cast(ge::GetSizeInBytes(shape_size, dt)); -} -bool AppendTensorAttr(const ge::AnyValue &attr, std::vector> &attrs) { - auto val = attr.Get(); - GE_ASSERT_NOTNULL(val); - auto &tensor_desc = val->GetTensorDesc(); - auto shape_size = tensor_desc.GetShape().GetShapeSize(); - if (shape_size < 0) { - GELOGE(ge::PARAM_INVALID, "Failed to append tensor attr, shape size less than 0"); - return false; - } - size_t total_size; - size_t tensor_size = GetGeTensorSize(*val); - auto tensor_holder = Tensor::CreateFollowing(val->GetTensorDesc().GetDataType(), tensor_size, total_size); - GE_ASSERT_NOTNULL(tensor_holder); - auto tensor = ge::PtrToPtr(tensor_holder.get()); - GeShapeToGertShape(tensor_desc.GetShape(), tensor->MutableStorageShape()); - GeShapeToGertShape(tensor_desc.GetOriginShape(), tensor->MutableOriginShape()); - tensor->SetOriginFormat(tensor_desc.GetOriginFormat()); - tensor->SetStorageFormat(tensor_desc.GetFormat()); - if (total_size < sizeof(Tensor)) { - GELOGE(ge::PARAM_INVALID, "total_size[%zu] < size of Tensor[%zu]", total_size, sizeof(Tensor)); - return false; - } - const auto copy_len = total_size - sizeof(Tensor); - if (copy_len != 0U) { - GE_CHECK_GE(val->GetData().size(), total_size - sizeof(Tensor)); - const auto ret_copy = ge::GeMemcpy(tensor->GetData(), total_size - sizeof(Tensor), - val->GetData().GetData(), total_size - sizeof(Tensor)); - GE_ASSERT_TRUE((ret_copy == ge::SUCCESS), "memcpy_s failed, copy size is %zu", (total_size - sizeof(Tensor))); - } - - std::vector buf(total_size); - const auto ret = ge::GeMemcpy(buf.data(), total_size, tensor_holder.get(), total_size); - GE_ASSERT_TRUE((ret == ge::SUCCESS), "memcpy_s failed, copy size is %zu", total_size); - attrs.emplace_back(std::move(buf)); - return true; -} - -bool AppendAttrTensor(const ge::AnyValue &attr, std::vector> &attrs) { - switch (attr.GetValueType()) { - case ge::AnyValue::VT_TENSOR: - return AppendTensorAttr(attr, attrs); - default: - return AppendAttr(attr, attrs); - } -} - -bool GetAllIrAttrs(const ge::NodePtr &node, std::vector> &runtime_attrs) { - const auto op_desc = node->GetOpDescBarePtr(); - GE_ASSERT_NOTNULL(op_desc); - const auto &all_attrs = ge::AttrUtils::GetAllAttrs(op_desc); - const auto &ir_attr_names = op_desc->GetIrAttrNames(); - for (const auto &attr_name : ir_attr_names) { - const std::map::const_iterator &iter = all_attrs.find(attr_name); - if (iter == all_attrs.cend()) { - runtime_attrs.clear(); - GELOGI("Can not find the IR attr %s from node %s(%s), clear all attrs", - attr_name.c_str(), node->GetNamePtr(), node->GetTypePtr()); - return true; - } - GE_ASSERT_TRUE(AppendAttrTensor(iter->second, runtime_attrs)); - } - return true; -} - -} // namespace -std::unique_ptr CreateAttrBuffer(const ge::NodePtr &node, size_t &size) { - return CreateAttrBuffer(node, {}, size); -} - -std::unique_ptr CreateAttrBuffer(const ge::NodePtr &node, - const std::vector &runtime_attrs_list, - size_t &size) { - std::vector> runtime_attrs; - GE_ASSERT_TRUE(GetAllIrAttrs(node, runtime_attrs)); - for (auto &runtime_attr : runtime_attrs_list) { - AppendAttrTensor(runtime_attr, runtime_attrs); - } - return CreateAttrBuffer(runtime_attrs, size); -} - -std::unique_ptr CreateAttrBufferWithoutIr(const ge::NodePtr &node, - const std::vector &runtime_attrs_list, - size_t &size) { - (void)node; - std::vector> runtime_attrs; - for (auto &runtime_attr : runtime_attrs_list) { - AppendAttrTensor(runtime_attr, runtime_attrs); - } - return CreateAttrBuffer(runtime_attrs, size); -} -} // namespace bg -} // namespace gert diff --git a/exe_graph/lowering/bg_kernel_context_extend.cc b/exe_graph/lowering/bg_kernel_context_extend.cc deleted file mode 100644 index cbe7775b9c0f406da2a61b905d7192763603cece..0000000000000000000000000000000000000000 --- a/exe_graph/lowering/bg_kernel_context_extend.cc +++ /dev/null @@ -1,228 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/lowering/bg_kernel_context_extend.h" - -#include "common/ge_common/debug/ge_log.h" -#include "common/checker.h" - -#include "exe_graph/lowering/bg_ir_attrs.h" -#include "exe_graph/runtime/context_extend.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/debug/ge_util.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/math_util.h" - -namespace gert { -namespace bg { -namespace { -ge::graphStatus InitIOInstanceInfo(const ge::NodePtr &node, ComputeNodeInfo &compute_node_info) { - const auto op_desc = node->GetOpDescBarePtr(); - GE_ASSERT_NOTNULL(op_desc); - auto in_ir_index_to_instance_index_pair_map - = ge::OpDescUtils::GetInputIrIndexes2InstanceIndexesPairMap(node->GetOpDesc()); - if (in_ir_index_to_instance_index_pair_map.empty()) { - GELOGI("node [%s(%s)] ir_index_to_instance_index_pair_map is empty", - node->GetNamePtr(), node->GetTypePtr()); - return ge::GRAPH_SUCCESS; - } - const auto &ir_inputs = op_desc->GetIrInputs(); - size_t input_index = 0; - for (size_t i = 0; i < ir_inputs.size(); ++i) { - auto ins_info = compute_node_info.MutableInputInstanceInfo(i); - GE_ASSERT_NOTNULL(ins_info); - size_t instance_num = in_ir_index_to_instance_index_pair_map[i].second; - compute_node_info.MutableInputInstanceInfo(i)->SetInstantiationNum(instance_num); - compute_node_info.MutableInputInstanceInfo(i)->SetInstanceStart(input_index); - input_index += instance_num; - } - - auto out_ir_index_to_instance_index_pair_map - = ge::OpDescUtils::GetOutputIrIndexes2InstanceIndexesPairMap(node->GetOpDesc()); - if (out_ir_index_to_instance_index_pair_map.empty()) { - GELOGI("node [%s(%s)] output ir_index_to_instance_index_pair_map is empty", - node->GetNamePtr(), node->GetTypePtr()); - return ge::GRAPH_SUCCESS; - } - const auto &ir_outputs = op_desc->GetIrOutputs(); - size_t output_index = 0; - for (size_t i = 0; i < ir_outputs.size(); ++i) { - auto ins_info = compute_node_info.MutableOutputInstanceInfo(i); - GE_ASSERT_NOTNULL(ins_info); - size_t instance_num = out_ir_index_to_instance_index_pair_map[i].second; - compute_node_info.MutableOutputInstanceInfo(i)->SetInstantiationNum(instance_num); - compute_node_info.MutableOutputInstanceInfo(i)->SetInstanceStart(output_index); - output_index += instance_num; - } - return ge::GRAPH_SUCCESS; -} - -void SetCompileTimeTd(const ge::ConstGeTensorDescPtr &desc, CompileTimeTensorDesc &td) { - td.SetDataType(desc->GetDataType()); - td.SetOriginFormat(desc->GetOriginFormat()); - td.SetStorageFormat(desc->GetFormat()); - int64_t reshape_type_mask = 0; - if (ge::AttrUtils::GetInt(desc, ge::ATTR_NAME_RESHAPE_TYPE_MASK, reshape_type_mask)) { - td.SetExpandDimsType(ExpandDimsType(reshape_type_mask)); - } -} - -ge::graphStatus GetConnectedEdgeIndexesToAnchorIndexMap(const ge::NodePtr &node, - std::map &connected_edge_indexes_to_anchor_index) { - size_t compute_node_index = 0U; - for (const auto anchor : node->GetAllInDataAnchorsPtr()) { - GE_ASSERT_NOTNULL(anchor); - if (anchor->GetPeerOutAnchor() == nullptr) { - continue; - } - GE_ASSERT_NOTNULL(anchor->GetPeerOutAnchor()->GetOwnerNodeBarePtr()); - connected_edge_indexes_to_anchor_index[compute_node_index] = static_cast(anchor->GetIdx()); - ++compute_node_index; - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus InitCompileTimeTD(const ge::NodePtr &node, ComputeNodeInfo &compute_node_info) { - std::map connected_edge_indexes_to_anchor_index; - const auto ret = GetConnectedEdgeIndexesToAnchorIndexMap(node, connected_edge_indexes_to_anchor_index); - if (ret != ge::GRAPH_SUCCESS) { - GELOGE(ret, "get connected edge indexes to anchor index map failed. node:%s(%s)", - node->GetName().c_str(), node->GetType().c_str()); - return ret; - } - const auto op_desc = node->GetOpDescBarePtr(); - GE_ASSERT_NOTNULL(op_desc); - GE_ASSERT_TRUE(connected_edge_indexes_to_anchor_index.size() == compute_node_info.GetInputsNum()); - for (size_t i = 0; i < compute_node_info.GetInputsNum(); ++i) { - GE_ASSERT_TRUE(i < op_desc->GetAllInputsSize()); - const auto &desc_need_check = op_desc->GetInputDesc(connected_edge_indexes_to_anchor_index[i]); - if (desc_need_check.IsValid() != ge::GRAPH_SUCCESS) { - continue; - } - const auto &desc = op_desc->GetInputDescPtr(connected_edge_indexes_to_anchor_index[i]); - GE_ASSERT_NOTNULL(desc); - auto td = compute_node_info.MutableInputTdInfo(i); - GE_ASSERT_NOTNULL(td); - SetCompileTimeTd(desc, *td); - } - - for (size_t i = 0; i < node->GetAllOutDataAnchorsSize(); ++i) { - const auto &desc = op_desc->GetOutputDescPtr(i); - GE_ASSERT_NOTNULL(desc); - auto td = compute_node_info.MutableOutputTdInfo(i); - GE_ASSERT_NOTNULL(td); - SetCompileTimeTd(desc, *td); - } - return ge::SUCCESS; -} -bool GetPrivateAttrsList(const ge::NodePtr &node, const gert::OpImplRegisterV2::PrivateAttrList &private_attrs, - std::vector &runtime_attrs_list) { - const auto op = node->GetOpDescBarePtr(); - GE_ASSERT_NOTNULL(op); - const auto &all_attrs = op->GetAllAttrs(); - for (auto &private_attr : private_attrs) { - auto &private_attr_name = private_attr.first; - auto iter = all_attrs.find(private_attr_name.GetString()); - if (iter == all_attrs.end()) { - if (!private_attr.second.IsEmpty()) { - runtime_attrs_list.push_back(private_attr.second); - continue; - } - GELOGE(ge::FAILED, "Can not find the private attr %s from node %s", - private_attr_name.GetString(), node->GetName().c_str()); - return false; - } - runtime_attrs_list.push_back(iter->second); - } - return true; -} -std::unique_ptr CreateComputeNodeInfoImpl(const std::unique_ptr &attr_buf, - const size_t attr_size, - const ge::NodePtr &node, - BufferPool &buffer_pool, - size_t &total_size) { - const auto op_desc = node->GetOpDescBarePtr(); - GE_ASSERT_NOTNULL(op_desc); - const size_t ir_input_num = op_desc->GetIrInputs().size(); - const size_t ir_output_num = op_desc->GetIrOutputs().size(); - const size_t input_num = node->GetInDataNodesAndAnchors().size(); - const uint32_t output_num = node->GetAllOutDataAnchorsSize(); - GELOGD("node: %s(%s), ir_input_num:%zu, ir_output_num:%zu, input_num:%zu, output_num:%u.", node->GetNamePtr(), - node->GetTypePtr(), ir_input_num, ir_output_num, input_num, output_num); - GE_ASSERT_SUCCESS(ComputeNodeInfo::CalcSize(ir_input_num, ir_output_num, input_num, output_num, total_size)); - GE_ASSERT_TRUE(!ge::AddOverflow(total_size, attr_size, total_size)); - auto compute_node_info_holder = ge::ComGraphMakeUnique(total_size); - GE_ASSERT_NOTNULL(compute_node_info_holder, "Create compute node info holder failed"); - - auto node_name = buffer_pool.AddStr(node->GetNamePtr()); - auto node_type = buffer_pool.AddStr(node->GetTypePtr()); - auto compute_node_info = ge::PtrToPtr(compute_node_info_holder.get()); - compute_node_info->Init(ir_input_num, ir_output_num, input_num, output_num, attr_size, - ge::PtrToPtr(ge::ValueToPtr(node_name)), - ge::PtrToPtr(ge::ValueToPtr(node_type))); - - auto ret = InitIOInstanceInfo(node, *compute_node_info); - GE_ASSERT_SUCCESS(ret, "Init input instance info for node:%s failed.", node->GetNamePtr()); - - ret = InitCompileTimeTD(node, *compute_node_info); - GE_ASSERT_SUCCESS(ret, "Init compile time tensor desc for node:%s failed.", node->GetNamePtr()); - - auto attr = compute_node_info->MutableAttrs(); - const auto offset = ge::PtrToPtr(attr) - compute_node_info_holder.get(); - if (static_cast(offset) > total_size) { - GELOGE( - ge::FAILED, - "Failed to create kernel context extend info, the offset of attr %zu beyond the total size of ExtendInfo %zu", - offset, attr_size); - return nullptr; - } - const auto outputs_ins_info_size = compute_node_info->GetIrOutputsNum() * sizeof(AnchorInstanceInfo); - ret = ge::GeMemcpy(ge::PtrToPtr(attr), (total_size - offset - outputs_ins_info_size), - attr_buf.get(), attr_size); - GE_ASSERT_SUCCESS(ret, "memcpy_s failed, copy size is %zu, dst size is %zu", attr_size, - (total_size - offset - outputs_ins_info_size)); - GELOGI("Node %s, compute_node_info attr_size %zu, outputs_ins_info_size:%zu, offset:%zu, total_size:%zu.", - node->GetNamePtr(), attr_size, outputs_ins_info_size, offset, total_size); - return compute_node_info_holder; -} -} // namespace - -std::unique_ptr CreateComputeNodeInfo(const ge::NodePtr &node, BufferPool &buffer_pool, size_t &total_size) { - size_t attr_size; - const auto attr_buf = CreateAttrBuffer(node, attr_size); - GE_ASSERT_NOTNULL(attr_buf, "Create attr buffer for node: %s failed", node->GetNamePtr()); - return CreateComputeNodeInfoImpl(attr_buf, attr_size, node, buffer_pool, total_size); -} - -std::unique_ptr CreateComputeNodeInfo(const ge::NodePtr &node, - BufferPool &buffer_pool, - const gert::OpImplRegisterV2::PrivateAttrList &private_attrs, - size_t &total_size) { - std::vector runtime_attrs_list; - GE_ASSERT_TRUE(GetPrivateAttrsList(node, private_attrs, runtime_attrs_list)); - size_t attr_size; - const auto attr_buf = CreateAttrBuffer(node, runtime_attrs_list, attr_size); - GE_ASSERT_NOTNULL(attr_buf, "Create attr buffer for node: %s failed", node->GetNamePtr()); - return CreateComputeNodeInfoImpl(attr_buf, attr_size, node, buffer_pool, total_size); -} -std::unique_ptr CreateComputeNodeInfoWithoutIrAttr(const ge::NodePtr &node, BufferPool &buffer_pool, - const gert::OpImplRegisterV2::PrivateAttrList &private_attrs, size_t &total_size) { - std::vector runtime_attrs_list; - GE_ASSERT_TRUE(GetPrivateAttrsList(node, private_attrs, runtime_attrs_list)); - size_t attr_size; - const auto attr_buf = CreateAttrBufferWithoutIr(node, runtime_attrs_list, attr_size); - GE_ASSERT_NOTNULL(attr_buf, "Create attr buffer without ir for node: %s failed", node->GetNamePtr()); - return CreateComputeNodeInfoImpl(attr_buf, attr_size, node, buffer_pool, total_size); -} -std::unique_ptr CreateComputeNodeInfo(const ge::NodePtr &node, BufferPool &buffer_pool) { - size_t total_size; - return CreateComputeNodeInfo(node, buffer_pool, total_size); -} -} // namespace bg -} // namespace gert diff --git a/exe_graph/lowering/buffer_pool.cc b/exe_graph/lowering/buffer_pool.cc deleted file mode 100644 index fa0a6cbff8023d25e6eaadf85fd5f3b4fb416f79..0000000000000000000000000000000000000000 --- a/exe_graph/lowering/buffer_pool.cc +++ /dev/null @@ -1,137 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/lowering/buffer_pool.h" - -#include -#include "common/ge_common/debug/ge_log.h" -#include "graph/utils/math_util.h" -#include "graph/debug/ge_log.h" -#include "graph/def_types.h" -#include "graph/debug/ge_util.h" -#include "common/checker.h" - -#include "exe_graph/runtime/continuous_buffer.h" - -namespace gert { -namespace bg { -namespace { -constexpr size_t kLargeBufSizeThreshold = 1024U * 1024U; // 1M -} -BufferPool::BufId BufferPool::AddBuf(const uint8_t *data, const size_t len) { - if (len >= kLargeBufSizeThreshold) { - return AddLargeBuf(std::string(ge::PtrToPtr(data), len)); - } - return AddBuf(std::string(ge::PtrToPtr(data), len)); -} -BufferPool::BufId BufferPool::AddStr(const char *data) { - size_t len = strlen(data) + 1; - if (len >= kLargeBufSizeThreshold) { - return AddLargeBuf(std::string(data, len)); - } - return AddBuf(std::string(data, len)); -} -BufferPool::BufId BufferPool::AddBuf(std::string &&str) { - auto res = bufs_to_id_.emplace(std::move(str), id_generator_); - if (res.second) { - ++id_generator_; - } - return res.first->second; -} -BufferPool::BufId BufferPool::AddLargeBuf(std::string &&str) { - auto id = id_generator_++; - large_bufs_to_id_.emplace_back(std::move(str), id); - return id; -} -std::unique_ptr BufferPool::Serialize() const { - size_t total_size; - return Serialize(total_size); -} -std::unique_ptr BufferPool::Serialize(size_t &total_size) const { - total_size = sizeof(ContinuousBuffer); - const size_t buf_count = id_generator_; - size_t offset_size; - size_t text_offset; - // 申请了n个,但是使用时会用n+1个,多的一个由ContinuousText自带 - if (ge::MulOverflow(sizeof(size_t), buf_count, offset_size)) { - GE_LOGE("Failed to serialize buffer pool, size overflow, buf num %zu", buf_count); - return nullptr; - } - if (ge::AddOverflow(total_size, offset_size, total_size)) { - GE_LOGE("Failed to serialize buffer pool, size overflow, buf size %zu", offset_size); - return nullptr; - } - text_offset = total_size; - - std::vector ids_to_buf(buf_count); - for (const auto &iter : bufs_to_id_) { - if (iter.second >= buf_count) { - return nullptr; - } - ids_to_buf[iter.second] = &iter.first; - - if (ge::AddOverflow(total_size, iter.first.size(), total_size)) { - GE_LOGE("Failed to serialize buffer pool, size overflow, buf size %zu, id %zu", iter.first.size(), iter.second); - return nullptr; - } - } - for (const auto &iter : large_bufs_to_id_) { - if (iter.second >= buf_count) { - return nullptr; - } - ids_to_buf[iter.second] = &iter.first; - - if (ge::AddOverflow(total_size, iter.first.size(), total_size)) { - GE_LOGE("Failed to serialize buffer pool, size overflow, buf size %zu, id %zu", iter.first.size(), iter.second); - return nullptr; - } - } - - auto text_holder = ge::ComGraphMakeUnique(total_size); - GE_ASSERT_NOTNULL(text_holder); - - auto text = ge::PtrToPtr(text_holder.get()); - text->num_ = buf_count; - text->reserved_ = 0; - size_t i = 0; - for (; i < buf_count; ++i) { - const auto buf = ids_to_buf[i]; - if (buf == nullptr) { - GELOGE(ge::FAILED, "Failed to serialize text pool, miss buf id %zu", i); - return nullptr; - } - const auto ret = ge::GeMemcpy(text_holder.get() + text_offset, total_size - text_offset, - ge::PtrToPtr(buf->data()), buf->size()); - GE_ASSERT_TRUE((ret == ge::SUCCESS), "memcpy_s failed, copy size is %zu, dst size is %zu", - buf->size(), total_size - text_offset); - text->offsets_[i] = text_offset; - text_offset += buf->size(); - } - text->offsets_[i] = text_offset; - - return text_holder; -} -const char *BufferPool::GetBufById(const BufId id) const { - for (const auto &buf_and_id : bufs_to_id_) { - if (buf_and_id.second == id) { - return buf_and_id.first.c_str(); - } - } - for (const auto &buf_and_id : large_bufs_to_id_) { - if (buf_and_id.second == id) { - return buf_and_id.first.c_str(); - } - } - return nullptr; -} -size_t BufferPool::GetSize() const { - return id_generator_; -} -} // namespace bg -} // namespace gert diff --git a/exe_graph/lowering/data_dependent_interpreter.cc b/exe_graph/lowering/data_dependent_interpreter.cc deleted file mode 100644 index 6c2c927a163e73a018db1fd639e6626a7232451b..0000000000000000000000000000000000000000 --- a/exe_graph/lowering/data_dependent_interpreter.cc +++ /dev/null @@ -1,276 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "lowering/data_dependent_interpreter.h" -#include "common/checker.h" -#include "graph/node.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/op_type_utils.h" -#include "graph/type/sym_dtype.h" - -namespace gert { -namespace { -constexpr const ge::char_t* kUbGraph = "_original_fusion_graph"; -bool IsUbFusedNode(const ge::OpDesc *const op_desc) { - return ge::AttrUtils::HasAttr(op_desc, kUbGraph); -} -ge::graphStatus IsDataDependentByAttr(const ge::OpDesc *op_desc, const int32_t input_index, bool &is_data_dependent) { - GE_ASSERT_NOTNULL(op_desc); - const auto &data_dependent_inputs = op_desc->GetOpInferDepends(); - if (data_dependent_inputs.empty()) { - is_data_dependent = false; - return ge::GRAPH_SUCCESS; - } - const auto &input_name = op_desc->GetValidInputNameByIndex(static_cast(input_index)); - is_data_dependent = std::find(data_dependent_inputs.cbegin(), data_dependent_inputs.cend(), input_name) != - data_dependent_inputs.cend(); - return ge::GRAPH_SUCCESS; -} -ge::NodePtr FindSubgraphDataNode(const ge::ComputeGraphPtr &graph, int32_t parent_node_index) { - for (const auto &node : graph->GetDirectNode()) { - if (node->GetType() != "Data") { - continue; - } - int32_t parent_index = 0; - if (!ge::AttrUtils::GetInt(node->GetOpDescBarePtr(), ge::ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { - GELOGE(ge::INTERNAL_ERROR, "[Get][Attr] failed, node:[%s(%s)] attr:[%s]", node->GetName().c_str(), - node->GetType().c_str(), ge::ATTR_NAME_PARENT_NODE_INDEX.c_str()); - REPORT_INNER_ERR_MSG("E19999", "invoke GetInt failed, node:[%s(%s)] attr:[%s]", node->GetName().c_str(), - node->GetType().c_str(), ge::ATTR_NAME_PARENT_NODE_INDEX.c_str()); - return nullptr; - } - if (parent_index == parent_node_index) { - return node; - } - } - return nullptr; -} - -ge::graphStatus GetInputIrIndexByInstanceIndex(const ge::OpDescPtr &op_desc, const size_t instance_index, - size_t &ir_index) { - GE_ASSERT_NOTNULL(op_desc); - std::map> ir_index_to_instance_index_pair_map; - GE_ASSERT_GRAPH_SUCCESS(ge::GetIrInputInstanceDescRange(op_desc, ir_index_to_instance_index_pair_map)); - GE_ASSERT_TRUE(!ir_index_to_instance_index_pair_map.empty()); - ir_index = std::numeric_limits::max(); - for (size_t i = 0U; i < op_desc->GetIrInputs().size(); ++i) { - const auto &index_pair = ir_index_to_instance_index_pair_map[i]; - size_t ir_index_end = 0U; - GE_ASSERT_TRUE(!ge::AddOverflow(index_pair.first, index_pair.second, ir_index_end)); - if ((instance_index >= index_pair.first) && (instance_index < ir_index_end)) { - ir_index = i; - GELOGD("node [%s(%s)] get ir index [%zu] successfully!", op_desc->GetName().c_str(), op_desc->GetType().c_str(), - ir_index); - return ge::GRAPH_SUCCESS; - } - } - GELOGW("node [%s(%s)] failed to get ir index by instance index[%zu], set ir_index to %zu", op_desc->GetName().c_str(), - op_desc->GetType().c_str(), instance_index, ir_index); - return ge::GRAPH_SUCCESS; -} -} // namespace - -DataDependentInterpreter::DataDependentInterpreter(const ge::NodePtr &node, - const gert::OpImplSpaceRegistryPtr &space_registry) - : node_(node) { - if (node_ != nullptr) { - op_desc_ = node_->GetOpDesc(); - } - if (space_registry != nullptr) { - space_registries_[static_cast(ge::OppImplVersion::kOpp)] = space_registry; - } - use_registry_v2_ = false; -} -DataDependentInterpreter::DataDependentInterpreter(const ge::OpDescPtr &op_desc, - const gert::OpImplSpaceRegistryArray &space_registry) - : op_desc_(op_desc), space_registries_(space_registry) { -} - -DataDependentInterpreter::DataDependentInterpreter(const ge::OpDescPtr &op_desc, - const gert::OpImplSpaceRegistryV2Array &space_registry) - : op_desc_(op_desc), space_registries_v2_(space_registry), use_registry_v2_(true) {} - - -ge::graphStatus DataDependentInterpreter::IsDataDependentByImplOp(const int32_t input_index, - bool &is_data_dependent) const { - const auto op_impl = GetOpImplFunctionsV2(); - if (op_impl == nullptr) { - GELOGW("The node %s type %s does not registered by `IMPL_OP`", op_desc_->GetNamePtr(), op_desc_->GetType().c_str()); - is_data_dependent = false; - // 这里产生了变更,原有实现中,如果impl找不到,并且1.0标记了任意一个输入为数据依赖,那么整个节点所有输入都会被认为是数据依赖。 - // 变更后,如果impl找不到,那么仅会返回1.0标记的输入为数据依赖。这个变更影响应该不大,验证过后,本注释可以被删除 - return ge::GRAPH_SUCCESS; - } - if (!op_impl->HasDataDependency()) { - is_data_dependent = false; - return ge::GRAPH_SUCCESS; - } - size_t ir_index; - const ge::graphStatus ret = GetInputIrIndexByInstanceIndex(op_desc_, static_cast(input_index), ir_index); - if (ret != ge::GRAPH_SUCCESS) { - GELOGE(ge::FAILED, "Failed to get ir index by input_index[%d] for node %s(%s).", input_index, - op_desc_->GetName().c_str(), op_desc_->GetType().c_str()); - return ge::FAILED; - } - is_data_dependent = op_impl->IsInputDataDependency(ir_index); - return ge::GRAPH_SUCCESS; -} - -// 此接口返回的结果表示算子是否只支持TilingDepend -// true: 算子只注册了tilingDepend,未注册DataDepend,表示算子在Infershape的时候不是dataDepend, -// 在tiling时是dataDepend -// false:其他情况 -ge::graphStatus DataDependentInterpreter::IsTilingInputDataDependent(const int32_t index, - bool &is_tiling_dependent) const { - const auto op_impl = GetOpImplFunctionsV2(); - if (op_impl == nullptr) { - GELOGW("The node %s type %s does not registered by `IMPL_OP`", op_desc_->GetNamePtr(), op_desc_->GetType().c_str()); - is_tiling_dependent = false; - return ge::GRAPH_SUCCESS; - } - if (!op_impl->HasTilingInputDataDependency()) { - is_tiling_dependent = false; - return ge::GRAPH_SUCCESS; - } - - size_t ir_index = 0UL; - const ge::graphStatus ret = GetInputIrIndexByInstanceIndex(op_desc_, static_cast(index), ir_index); - if (ret != ge::GRAPH_SUCCESS) { - GELOGE(ge::GRAPH_FAILED, "Failed to get ir index by input_index[%d] for node %s(%s).", index, - op_desc_->GetName().c_str(), op_desc_->GetType().c_str()); - return ge::GRAPH_FAILED; - } - is_tiling_dependent = op_impl->IsTilingInputDataDependency(ir_index); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus DataDependentInterpreter::IsSupportTilingDependPlacement(const uint32_t placement, - bool &is_support) const { - const auto op_impl = GetOpImplFunctionsV2(); - if (op_impl == nullptr) { - GELOGW("The node %s type %s does not registered by `IMPL_OP`", op_desc_->GetNamePtr(), op_desc_->GetType().c_str()); - is_support = false; - return ge::GRAPH_SUCCESS; - } - is_support = op_impl->IsSupportTilingDependencyPlacement(placement); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus DataDependentInterpreter::IsDataDependent(const int32_t index, bool &is_data_dependent) const { - bool by_ir = false; - GE_ASSERT_SUCCESS(IsDataDependentByIr(index, by_ir)); - GE_ASSERT_NOTNULL(op_desc_); - if (!IsUbFusedNode(op_desc_.get())) { - is_data_dependent = by_ir; - return ge::GRAPH_SUCCESS; - } - - bool by_ub = false; - GE_ASSERT_SUCCESS(IsDataDependentByUbGraph(index, by_ub)); - is_data_dependent = GetByIrAndUb(by_ir, by_ub, index); - return ge::GRAPH_SUCCESS; -} -ge::graphStatus DataDependentInterpreter::IsDataDependentByIr(int32_t index, bool &is_data_dependent) const { - bool by_1_0 = false; - bool by_2_0 = false; - GE_ASSERT_SUCCESS(IsDataDependentByImplOp(index, by_2_0)); - GE_ASSERT_SUCCESS(IsDataDependentByAttr(op_desc_.get(), index, by_1_0)); - - is_data_dependent = GetByIr(by_1_0, by_2_0, index); - return ge::GRAPH_SUCCESS; -} -bool DataDependentInterpreter::GetByIr(bool by_1_0, bool by_2_0, int32_t index_for_log) const { - if (by_1_0 == by_2_0) { - return by_2_0; - } - GE_ASSERT_NOTNULL(op_desc_); - if (by_1_0) { // by_2_0 is false - GELOGW( - "The node %s type %s input index %d is interpreted data-dependent, because there is data dependent attr on the " - "node. But the IMPL_OP does not registered as data-dependent", - op_desc_->GetNamePtr(), op_desc_->GetTypePtr(), index_for_log); - } - return true; -} -ge::graphStatus DataDependentInterpreter::IsDataDependentByUbGraph(int32_t index, bool &is_data_dependent) const { - GE_ASSERT_NOTNULL(op_desc_); - auto ub_graph = GetUbGraph(); - GE_ASSERT_NOTNULL(ub_graph); - - const auto data_node = FindSubgraphDataNode(ub_graph, index); - GE_ASSERT_NOTNULL(data_node, "Failed to find the data node from ub graph by index %d from node %s type %s.", - index, op_desc_->GetNamePtr(), op_desc_->GetTypePtr()); - - is_data_dependent = false; - for (const auto &node_and_anchor : data_node->GetOutDataNodesAndAnchors()) { - bool node_data_dependent; - GE_ASSERT_NOTNULL(node_and_anchor.first); - if (use_registry_v2_) { - GE_ASSERT_SUCCESS(DataDependentInterpreter(node_and_anchor.first->GetOpDesc(), space_registries_v2_) - .IsDataDependentByIr(node_and_anchor.second->GetIdx(), node_data_dependent)); - } else { - GE_ASSERT_SUCCESS(DataDependentInterpreter(node_and_anchor.first->GetOpDesc(), space_registries_) - .IsDataDependentByIr(node_and_anchor.second->GetIdx(), node_data_dependent)); - } - if (node_data_dependent) { - is_data_dependent = true; - break; - } - } - - return ge::GRAPH_SUCCESS; -} -bool DataDependentInterpreter::GetByIrAndUb(bool by_ir, bool by_ub, int32_t index_for_log) const { - if (by_ir == by_ub) { - return by_ir; - } - - if (by_ir) { // by_ub is false - GELOGW( - "The UB-fused node %s type %s input index %d is interpreted data-dependent. The data-dependent flag is marked " - "by IR, but not the UB graph", - op_desc_->GetNamePtr(), op_desc_->GetTypePtr(), index_for_log); - } - return true; -} -ge::ComputeGraphPtr DataDependentInterpreter::GetUbGraph() const { - GE_ASSERT_NOTNULL(op_desc_); - if (ub_graph_cache_ == nullptr) { - GE_ASSERT_TRUE(ge::AttrUtils::GetGraph(op_desc_.get(), kUbGraph, ub_graph_cache_)); - } - return ub_graph_cache_; -} - -const OpImplKernelRegistry::OpImplFunctionsV2 *DataDependentInterpreter::GetOpImplFunctionsV2() const { - GE_ASSERT_NOTNULL(op_desc_); - std::string type; - GE_ASSERT_SUCCESS(ge::OpTypeUtils::GetOriginalType(op_desc_, type), "Failed to get original type from %s(%s).", - op_desc_->GetNamePtr(), op_desc_->GetTypePtr()); - GELOGD("GetOpImplFunctionsV2, node name %s, type %s", op_desc_->GetNamePtr(), type.c_str()); - if (use_registry_v2_) { - GELOGD("GetOpImplFunctionsV2 use space_registry v2"); - GE_ASSERT_TRUE(space_registries_v2_.size() > static_cast(op_desc_->GetOppImplVersion())); - auto space_registry_v2 = space_registries_v2_[static_cast(op_desc_->GetOppImplVersion())]; - if (space_registry_v2 == nullptr) { - GELOGW("Attention: default registry does not exist. Tiling will be executed failed"); - return nullptr; - } - return space_registry_v2->GetOpImpl(type.c_str()); - } - GELOGD("GetOpImplFunctionsV2 use space_registry v1"); - GE_ASSERT_TRUE(space_registries_.size() > static_cast(op_desc_->GetOppImplVersion())); - auto space_registry = space_registries_[static_cast(op_desc_->GetOppImplVersion())]; - if (space_registry == nullptr) { - GELOGW("Attention: default registry does not exist. Tiling will be executed failed"); - return nullptr; - } - return space_registry->GetOpImpl(type); -} -} // namespace gert diff --git a/exe_graph/lowering/data_dependent_interpreter.h b/exe_graph/lowering/data_dependent_interpreter.h deleted file mode 100644 index 6d9555f7d2d3c0ffa5b1f4f35210971a721d6720..0000000000000000000000000000000000000000 --- a/exe_graph/lowering/data_dependent_interpreter.h +++ /dev/null @@ -1,79 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef AIR_CXX_RUNTIME_V2_LOWERING_DATA_DEPENDENT_INTERPRETER_H_ -#define AIR_CXX_RUNTIME_V2_LOWERING_DATA_DEPENDENT_INTERPRETER_H_ -#include -#include "graph/node.h" -#include "graph/compute_graph.h" -#include "base/registry/op_impl_space_registry_v2.h" -#include "register/op_impl_space_registry.h" -namespace gert { -class DataDependentInterpreter { - public: - DataDependentInterpreter(const ge::NodePtr &node, const gert::OpImplSpaceRegistryPtr &space_registry); - // 当前air仓未使用, 待上库后air仓适配都使用这个, 适配完成, 删除上面的构造函数 - DataDependentInterpreter(const ge::OpDescPtr &op_desc, const gert::OpImplSpaceRegistryArray &space_registry); - - // 兼容处理,先合入v2,然后删除V1构造函数 - DataDependentInterpreter(const ge::OpDescPtr &op_desc, const OpImplSpaceRegistryV2Array &space_registry); - /** - * 当前数据依赖的标记来源有三个: - * - * 1. runtime2.0方式,通过IMPL_OP标记了数据依赖 - * 2. runtime2.0方式,在InferShape时,通过OpDesc::SetOpInferDependes设置了数据依赖 - * 3. UB融合算子,在子图内的边界上出现了数据依赖 - * - * 本函数综合上述几种数据依赖的标记来源,给出一个最终的数据依赖判定结果。详细展开来说,本函数的策略为: - * - * |序号|IMPL_OP标记|RT1.0 Style标记|UB融合算子子图解析|最终结果|备注 | - * |--|-----|-----|------|------|-----------------------------------------| - * |1 |true |true | NA | true |数据依赖、编译时1.0、2.0标记正确 | - * |2 |true |false| NA | true |数据依赖、编译时2.0、2.0标记正确 | - * |3 |false|true | NA | true |数据依赖、编译时1.0、2.0标记错误,打印Warning | - * |4 |false|false| NA | true |非数据依赖、编译时1.0、2.0标记正确 | - * |5 |true |true | true | true |UB融合算子、数据依赖、恰好正确 | - * |6 |true |true | false| true |UB融合算子、编译时1.0、原型标记了数据依赖,但是UB融合子图不需要,打印warning | - * |7 |true |false| true | true |UB融合算子、编译时2.0、恰好正确 | - * |8 |true |false| false| true |UB融合算子、编译时2.0、原型标记了数据依赖,但是UB融合子图不需要,打印warning | - * |9 |false|true | true | true |UB融合算子、数据依赖、编译时1.0、2.0标记错误,打印Warning | - * |10|false|true | false| true |UB融合算子、编译时1.0、原型标记了数据依赖,但是UB融合子图不需要,1.0与2.0不一致,打印2 warning | - * |11|false|false| true | true |UB融合算子、编译时2.0、UB子图带来了数据依赖 | - * |12|false|false| false| false|UB融合算子、没有数据依赖 | - * - * @param is_data_dependent - * @return - */ - ge::graphStatus IsDataDependent(const int32_t index, bool &is_data_dependent) const; - ge::graphStatus IsTilingInputDataDependent(const int32_t index, bool &is_tiling_dependent) const; - ge::graphStatus IsSupportTilingDependPlacement(const uint32_t placement, bool &is_support) const; - private: - ge::graphStatus IsDataDependentByImplOp(const int32_t input_index, bool &is_data_dependent) const; - ge::graphStatus IsDataDependentByIr(int32_t index, bool &is_data_dependent) const; - bool GetByIr(bool by_1_0, bool by_2_0, int32_t index_for_log) const; - - ge::graphStatus IsDataDependentByUbGraph(int32_t index, bool &is_data_dependent) const; - bool GetByIrAndUb(bool by_ir, bool by_ub, int32_t index_for_log) const; - - ge::ComputeGraphPtr GetUbGraph() const; - - const OpImplKernelRegistry::OpImplFunctionsV2 *GetOpImplFunctionsV2() const; - - private: - ge::NodePtr node_{nullptr}; - ge::OpDescPtr op_desc_{nullptr}; - OpImplSpaceRegistryArray space_registries_; - mutable ge::ComputeGraphPtr ub_graph_cache_; - // 兼容上库处理,先合入v2,然后删除V1构造函数 - OpImplSpaceRegistryV2Array space_registries_v2_; - bool use_registry_v2_{false}; -}; -} - -#endif // AIR_CXX_RUNTIME_V2_LOWERING_DATA_DEPENDENT_INTERPRETER_H_ diff --git a/exe_graph/lowering/dev_mem_value_holder.cc b/exe_graph/lowering/dev_mem_value_holder.cc deleted file mode 100644 index 8fb4a8b827e2f4f052981ce5afdbfd9ecef36ff3..0000000000000000000000000000000000000000 --- a/exe_graph/lowering/dev_mem_value_holder.cc +++ /dev/null @@ -1,80 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/lowering/dev_mem_value_holder.h" - -#include "common/checker.h" -#include "common/util/mem_utils.h" -#include "exe_graph/lowering/builtin_node_types.h" -#include "exe_graph/lowering/exe_graph_attrs.h" - -namespace gert { -namespace bg { -DevMemValueHolderPtr DevMemValueHolder::CreateError(int64_t logic_stream_id, const char *fmt, va_list arg) { - auto value_holder = ge::MakeShared(logic_stream_id); - GE_ASSERT_NOTNULL(value_holder); - value_holder->SetErrorMsg(fmt, arg); - return value_holder; -} - -DevMemValueHolderPtr DevMemValueHolder::CreateError(int64_t logic_stream_id, const char *fmt, ...) { - va_list arg; - va_start(arg, fmt); - auto holder = DevMemValueHolder::CreateError(logic_stream_id, fmt, arg); - va_end(arg); - return holder; -} - -std::vector DevMemValueHolder::CreateDataOutput(const char *node_type, - const std::vector &inputs, - size_t out_count, int64_t logic_stream_id) { - auto node = CreateNode(node_type, inputs, out_count); - if (node == nullptr) { - return {out_count, nullptr}; - } - return ValueHolder::CreateFromNode(node, 0, out_count, logic_stream_id); -} - -DevMemValueHolderPtr DevMemValueHolder::CreateSingleDataOutput(const char *node_type, - const std::vector &inputs, - int64_t logic_stream_id) { - auto node = CreateNode(node_type, inputs, 1U); - if (node == nullptr) { - return nullptr; - } - return ValueHolder::CreateFromNode(node, 0, ValueHolderType::kOutput, logic_stream_id); -} - - -/** - * @param data const数据 - * @param size const数据的长度 - * @param is_string 此const是否是个字符串, todo: 当前对string支持的不好 - * @return - */ -DevMemValueHolderPtr DevMemValueHolder::CreateConst(const void *data, size_t size, int64_t logic_stream_id, - bool is_string) { - GE_ASSERT_NOTNULL(data); - auto node = CreateNode(kConst, {}, 1); - GE_ASSERT_NOTNULL(node); - const auto op_desc = node->GetOpDescBarePtr(); - GE_ASSERT_NOTNULL(op_desc); - GE_ASSERT_SUCCESS(op_desc->SetAttr("is_string", ge::AnyValue::CreateFrom(is_string))); - GE_ASSERT_TRUE(ge::AttrUtils::SetZeroCopyBytes(op_desc, kConstValue, - ge::Buffer::CopyFrom(ge::PtrToPtr(data), size))); - return CreateFromNode(node, 0, ValueHolderType::kConst, logic_stream_id); -} - -ValueHolderPtr DevMemValueHolder::CreateMateFromNode(ge::FastNode *node, int32_t index, ValueHolderType type) { - return ValueHolder::CreateFromNode(node, index, type, logic_stream_id_); -} - -int64_t DevMemValueHolder::GetLogicStream() const { return logic_stream_id_; } -} // namespace bg -} // namespace gert diff --git a/exe_graph/lowering/device_tiling_context_builder.cc b/exe_graph/lowering/device_tiling_context_builder.cc deleted file mode 100644 index 31f44883788474c62649f1733d56bc4a9e0da1a2..0000000000000000000000000000000000000000 --- a/exe_graph/lowering/device_tiling_context_builder.cc +++ /dev/null @@ -1,265 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/lowering/device_tiling_context_builder.h" - -#include "exe_graph/lowering/bg_kernel_context_extend.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/node_utils_ex.h" -#include "graph/def_types.h" -#include "common/checker.h" - -namespace gert { -namespace { -constexpr size_t kChainMemAlignedSize = 256UL; - -inline static size_t MemoryAligned(const size_t bytes, const size_t aligns = 128U) { - const size_t aligned_size = (aligns == 0UL) ? sizeof(uintptr_t) : aligns; - return ((bytes + aligned_size - 1UL) / aligned_size) * aligned_size; -} - -static void TiledPointerOffset(const size_t offset_size, uint8_t *&host_addr, uint64_t &dev_addr, - size_t &max_mem_size) { - max_mem_size -= offset_size; - host_addr += offset_size; - dev_addr += offset_size; -} - -void GetStorageShape(const ge::GeTensorDesc &tensor_desc, gert::StorageShape &storage_shape) { - const auto &storage_dims = tensor_desc.GetShape().GetDims(); - for (const auto &dim : storage_dims) { - (void) storage_shape.MutableStorageShape().AppendDim(dim); - } - const auto &origin_dims = tensor_desc.GetOriginShape().GetDims(); - for (const auto &dim : origin_dims) { - (void) storage_shape.MutableOriginShape().AppendDim(dim); - } -} -} // namespace - -size_t DeviceTilingContextBuilder::CalcTotalTiledSize(const ge::OpDescPtr &op_desc) { - // op infos - size_t total_size{op_desc->GetName().size() + 1UL}; // \0 - total_size += op_desc->GetType().size() + 1UL; // \0 - - // gert::tensor size - const size_t io_num = op_desc->GetInputsSize() + op_desc->GetOutputsSize(); - total_size += io_num * sizeof(gert::Tensor); - - // kernel context_size - const size_t chain_num = io_num + static_cast(TilingContext::kOutputNum) + 4UL; // default input ptr nums - const size_t context_size = sizeof(KernelRunContext) + sizeof(Chain *) * chain_num; - const size_t chain_size = (sizeof(Chain) + kChainMemAlignedSize) * chain_num; - total_size += context_size; - total_size += chain_size; - return total_size; -} - -DeviceTilingContextBuilder &DeviceTilingContextBuilder::CompileInfo(void *compile_info) { - compile_info_ = compile_info; - return *this; -} -DeviceTilingContextBuilder &DeviceTilingContextBuilder::PlatformInfo(void *platform_info) { - platform_info_ = platform_info; - return *this; -} -DeviceTilingContextBuilder &DeviceTilingContextBuilder::Deterministic(int32_t deterministic) { - deterministic_ = deterministic; - return *this; -} - -DeviceTilingContextBuilder &DeviceTilingContextBuilder::TilingData(void *tiling_data) { - outputs_[TilingContext::kOutputTilingData] = tiling_data; - return *this; -} - -DeviceTilingContextBuilder &DeviceTilingContextBuilder::AddrRefreshedInputTensor( - const std::map &index_to_tensor) { - index_to_tensor_ = index_to_tensor; - return *this; -} - -DeviceTilingContextBuilder &DeviceTilingContextBuilder::Workspace(void *workspace) { - outputs_[TilingContext::kOutputWorkspace] = workspace; - return *this; -} - -DeviceTilingContextBuilder &DeviceTilingContextBuilder::TiledHolder(uint8_t *host_addr, uint64_t dev_addr, - size_t max_mem_size) { - host_begin_ = host_addr; - dev_begin_ = dev_addr; - max_mem_size_ = max_mem_size; - return *this; -} - -ge::graphStatus DeviceTilingContextBuilder::BuildRtTensor(const ge::GeTensorDesc &tensor_desc, - ConstTensorAddressPtr address) { - gert::StorageShape storage_shape; - GetStorageShape(tensor_desc, storage_shape); - const size_t rt_tensor_size = sizeof(gert::Tensor); - GE_ASSERT(max_mem_size_ >= rt_tensor_size); - GE_ASSERT_NOTNULL(host_begin_); - auto rt_tensor = new (host_begin_)(gert::Tensor); - GE_ASSERT_NOTNULL(rt_tensor); - rt_tensor->SetDataType(tensor_desc.GetDataType()); - rt_tensor->MutableStorageShape() = storage_shape.GetStorageShape(); - rt_tensor->MutableOriginShape() = storage_shape.GetOriginShape(); - rt_tensor->MutableFormat().SetStorageFormat(tensor_desc.GetFormat()); - rt_tensor->MutableFormat().SetOriginFormat(tensor_desc.GetOriginFormat()); - (void) rt_tensor->MutableTensorData().SetAddr(address, nullptr); - rt_tensor->MutableTensorData().SetPlacement(gert::kOnDeviceHbm); - // dev_value - inputs_.push_back(ge::ValueToPtr(dev_begin_)); - dev_begin_ += rt_tensor_size; - max_mem_size_ -= rt_tensor_size; - host_begin_ += rt_tensor_size; - GELOGD("Build rt tensor from device addr %lx.", dev_begin_); - - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus DeviceTilingContextBuilder::BuildPlacementRtTensor(const ge::GeTensorDesc &tensor_desc, - Tensor *rt_tensor) const { - GE_ASSERT_NOTNULL(rt_tensor); - gert::StorageShape storage_shape; - GetStorageShape(tensor_desc, storage_shape); - rt_tensor->SetDataType(tensor_desc.GetDataType()); - rt_tensor->MutableStorageShape() = storage_shape.GetStorageShape(); - rt_tensor->MutableOriginShape() = storage_shape.GetOriginShape(); - rt_tensor->MutableFormat().SetStorageFormat(tensor_desc.GetFormat()); - rt_tensor->MutableFormat().SetOriginFormat(tensor_desc.GetOriginFormat()); - rt_tensor->MutableTensorData().SetPlacement(gert::kOnDeviceHbm); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus DeviceTilingContextBuilder::BuildIOTensors(const ge::OpDesc *const op_desc) { - GE_ASSERT_NOTNULL(op_desc); - size_t valid_inputs{0UL}; - for (size_t i = 0UL; i < op_desc->GetAllInputsSize(); ++i) { - const ge::GeTensorDesc &input_desc = op_desc->GetInputDesc(i); - if (input_desc.IsValid() != ge::GRAPH_SUCCESS) { - continue; - } - const auto iter = index_to_tensor_.find(valid_inputs); - if (iter != index_to_tensor_.end()) { - GE_ASSERT_GRAPH_SUCCESS(BuildPlacementRtTensor(input_desc, iter->second.host_addr)); - // dev_value - inputs_.push_back(ge::ValueToPtr(iter->second.device_addr)); - } else { - GE_ASSERT_GRAPH_SUCCESS(BuildRtTensor(input_desc, nullptr)); - } - ++valid_inputs; - } - - for (size_t i = 0UL; i < op_desc->GetOutputsSize(); ++i) { - GE_ASSERT_GRAPH_SUCCESS(BuildRtTensor(op_desc->GetOutputDesc(i), nullptr)); - } - return ge::GRAPH_SUCCESS; -} - -// 0-n input tensors -// n-m output shapes -// m + 1 compile info -// m + 2 tiling func -// 其中 n为输入个数总和,m为输入输出个数总和 -ge::graphStatus DeviceTilingContextBuilder::Build(const ge::NodePtr &node, TiledKernelContextHolder &holder) { - GE_ASSERT_NOTNULL(platform_info_, " Device platform info addr is nullptr."); - GE_ASSERT_EOK(memset_s(host_begin_, max_mem_size_, 0, max_mem_size_), "Failed to memset host context buffer."); - - inputs_.clear(); - GE_ASSERT_GRAPH_SUCCESS(BuildIOTensors(node->GetOpDescBarePtr())); - - inputs_.emplace_back(compile_info_); - inputs_.emplace_back(platform_info_); - inputs_.emplace_back(nullptr); - inputs_.emplace_back(reinterpret_cast(deterministic_)); - - return TiledBuild(node, holder); -} - -ge::graphStatus DeviceTilingContextBuilder::TiledBuild(const ge::NodePtr &node, TiledKernelContextHolder &holder) { - // op_type - const size_t op_type_len = node->GetType().length() + 1UL; // '\0' - GE_ASSERT_TRUE(max_mem_size_ >= op_type_len); - GE_ASSERT_EOK(memcpy_s(host_begin_, max_mem_size_, node->GetTypePtr(), op_type_len)); - holder.dev_op_type_addr_ = dev_begin_; - TiledPointerOffset(op_type_len, host_begin_, dev_begin_, max_mem_size_); - - // op_name - const size_t op_name_len = node->GetName().length() + 1UL; // '\0' - GE_ASSERT_TRUE(max_mem_size_ >= op_name_len); - GE_ASSERT_EOK(memcpy_s(host_begin_, max_mem_size_, node->GetNamePtr(), op_name_len)); - holder.dev_op_name_addr_ = dev_begin_; - TiledPointerOffset(op_name_len, host_begin_, dev_begin_, max_mem_size_); - - // compute node info - auto host_compute_node_info = ge::PtrToPtr(holder.host_compute_node_info_); - GE_ASSERT_NOTNULL(host_compute_node_info); - host_compute_node_info->SetNodeName(ge::PtrToPtr(ge::ValueToPtr(holder.dev_op_name_addr_))); - host_compute_node_info->SetNodeType(ge::PtrToPtr(ge::ValueToPtr(holder.dev_op_type_addr_))); - - GE_ASSERT_TRUE(max_mem_size_ >= holder.compute_node_info_size_); - const uint64_t dev_compute_node_info = dev_begin_; - GE_ASSERT_EOK(memcpy_s(host_begin_, max_mem_size_, holder.host_compute_node_info_, holder.compute_node_info_size_)); - TiledPointerOffset(holder.compute_node_info_size_, host_begin_, dev_begin_, max_mem_size_); - - size_t context_size = sizeof(KernelRunContext) + sizeof(Chain *) * (inputs_.size() + outputs_.size()); - GE_ASSERT_TRUE(max_mem_size_ >= context_size); - KernelContext *kernel_context = ge::PtrToPtr(host_begin_); - GE_ASSERT_NOTNULL(kernel_context); - holder.host_context_ = kernel_context; - holder.dev_context_addr_ = dev_begin_; - TiledPointerOffset(context_size, host_begin_, dev_begin_, max_mem_size_); - - // kernel run context - auto kernel_run_context = holder.host_context_->GetContext(); - kernel_run_context->input_size = inputs_.size(); - kernel_run_context->output_size = outputs_.size(); - kernel_run_context->compute_node_info = ge::ValueToPtr(dev_compute_node_info); - // set output_start with dev_begin_ - kernel_run_context->output_start = reinterpret_cast( - holder.dev_context_addr_ + ge::PtrToValue(&(kernel_run_context->values[kernel_run_context->input_size])) - - ge::PtrToValue(holder.host_context_)); - - // aligned dev for ts - const size_t aligned_dev_addr = MemoryAligned(dev_begin_); - const size_t aligned_offset = static_cast(aligned_dev_addr - dev_begin_); - TiledPointerOffset(aligned_offset, host_begin_, dev_begin_, max_mem_size_); - - // dev - const size_t aligned_chain_size = MemoryAligned(sizeof(Chain)); - const size_t total_chain_size = aligned_chain_size * (inputs_.size() + outputs_.size()); - GE_ASSERT_TRUE(max_mem_size_ >= total_chain_size); - - // input output chain - size_t chain_index{0UL}; - for (auto &input : inputs_) { - Chain *host_chain = ge::PtrToPtr(host_begin_); - GE_ASSERT_NOTNULL(host_chain); - host_chain->Set(input, nullptr); - kernel_run_context->values[chain_index] = ge::PtrToPtr(ge::ValueToPtr(dev_begin_)); - host_begin_ += aligned_chain_size; - dev_begin_ += aligned_chain_size; - ++chain_index; - } - for (auto &output : outputs_) { - Chain *host_chain = ge::PtrToPtr(host_begin_); - GE_ASSERT_NOTNULL(host_chain); - host_chain->Set(output, nullptr); - kernel_run_context->values[chain_index] = ge::PtrToPtr(ge::ValueToPtr(dev_begin_)); - holder.output_addrs_.push_back(dev_begin_); - host_begin_ += aligned_chain_size; - dev_begin_ += aligned_chain_size; - ++chain_index; - } - - return ge::GRAPH_SUCCESS; -} -} // namespace gert diff --git a/exe_graph/lowering/extend_exe_graph.h b/exe_graph/lowering/extend_exe_graph.h deleted file mode 100644 index c26f63a77c34219a365795382e421da1493e1cf7..0000000000000000000000000000000000000000 --- a/exe_graph/lowering/extend_exe_graph.h +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef AIR_CXX_RUNTIME_V2_METADEF_EXE_GRAPH_EXETEND_EXE_GRAPH_ATTRS_H_ -#define AIR_CXX_RUNTIME_V2_METADEF_EXE_GRAPH_EXETEND_EXE_GRAPH_ATTRS_H_ -#include "graph/compute_graph.h" -#include "graph/fast_graph/execute_graph.h" -namespace gert { -template -inline bool FindValFromMapExtAttr(const ge::ExecuteGraph *exe_graph, const char *attr_name, const K &key, V &val) { - auto ext_attr = exe_graph->GetExtAttr>(attr_name); - if (ext_attr == nullptr) { - return false; - } - const auto iter = ext_attr->find(key); - if (iter != ext_attr->cend()) { - val = iter->second; - return true; - } - return false; -} - -template -inline void AddKVToMapExtAttr(ge::ExecuteGraph *exe_graph, const char *attr_name, const K &key, const V &val) { - auto ext_attr = exe_graph->GetExtAttr>(attr_name); - if (ext_attr == nullptr) { - std::unordered_map temp_ext_attr {}; - temp_ext_attr[key] = val; - exe_graph->SetExtAttr(attr_name, temp_ext_attr); - } else { - (*ext_attr)[key] = val; - } -} -} // namespace gert -#endif // AIR_CXX_RUNTIME_V2_METADEF_EXE_GRAPH_EXETEND_EXE_GRAPH_ATTRS_H_ diff --git a/exe_graph/lowering/frame_selector.cc b/exe_graph/lowering/frame_selector.cc deleted file mode 100644 index f245868beb157e710bda30926bd295415e6c4645..0000000000000000000000000000000000000000 --- a/exe_graph/lowering/frame_selector.cc +++ /dev/null @@ -1,438 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/lowering/frame_selector.h" - -#include "common/checker.h" -#include "exe_graph/lowering/graph_frame.h" -#include "exe_graph/lowering/value_holder.h" -#include "exe_graph/runtime/execute_graph_types.h" -#include "graph/debug/ge_util.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/execute_graph_utils.h" -#include "scoped_current_frame.h" -#include "value_holder_inner.h" - -namespace gert { -namespace bg { -namespace { -const int64_t kControlAnchorIdx = -1; -ge::EdgeSrcEndpoint CreateInnerData(ge::ExecuteGraph *graph, const GraphFrame &graph_frame, - const size_t index) { - auto op_desc = ge::ComGraphMakeShared(ValueHolder::GenerateNodeName(kInnerData, graph_frame), kInnerData); - GE_ASSERT_NOTNULL(op_desc); - GE_ASSERT_SUCCESS(op_desc->AddOutputDesc(ge::GeTensorDesc())); - GE_ASSERT_TRUE(ge::AttrUtils::SetInt(op_desc, "index", index)); - auto node = graph->AddNode(op_desc); - GE_ASSERT_NOTNULL(node); - return {node, 0}; -} - -ge::FastNode *GetOrCreateInnerNetOutput(const GraphFrame &frame) { - auto netoutput = ge::ExecuteGraphUtils::FindFirstNodeMatchType(frame.GetExecuteGraph().get(), kInnerNetOutput); - if (netoutput != nullptr) { - return netoutput; - } - return ValueHolder::AddNode(kInnerNetOutput, 0U, 0U, frame); -} - -ge::graphStatus MoveGuardersToDeInit(ge::FastNode *init_node, const GraphFrame &root_frame, - const std::vector> &guarders_and_out_index) { - const auto de_init_node = - ge::ExecuteGraphUtils::FindFirstNodeMatchType(root_frame.GetExecuteGraph().get(), - GetExecuteGraphTypeStr(ExecuteGraphType::kDeInit)); - GE_ASSERT_NOTNULL(de_init_node); - auto de_init_graph = ge::FastNodeUtils::GetSubgraphFromNode(de_init_node, 0U); - GE_ASSERT_NOTNULL(de_init_graph); - - auto index = de_init_node->GetDataInNum(); - GE_ASSERT_SUCCESS(ge::FastNodeUtils::AppendInputEdgeInfo(de_init_node, index + guarders_and_out_index.size())); - - auto init_graph = init_node->GetExtendInfo()->GetOwnerGraphBarePtr(); - GE_ASSERT_NOTNULL(init_graph); - for (size_t i = 0U; i < guarders_and_out_index.size(); ++i) { - auto guarder_node = guarders_and_out_index[i].first->GetFastNode(); - GE_ASSERT_NOTNULL(guarder_node); - GE_ASSERT_SUCCESS( - ge::ExecuteGraphUtils::MoveNodeToGraph(guarder_node, de_init_graph)); - GE_ASSERT_NOTNULL(init_graph->AddEdge(init_node, - static_cast(guarders_and_out_index[i].second), de_init_node, - static_cast(index + i))); - auto src_end_point = CreateInnerData(de_init_graph, root_frame, index + i); - GE_ASSERT_NOTNULL(src_end_point.node); - GE_ASSERT_NOTNULL(de_init_graph->AddEdge(src_end_point.node, src_end_point.index, - guarders_and_out_index[i].first->GetFastNode(), 0)); - } - - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus GetNewOutIndexes(const std::vector &init_graph_outputs, const size_t start_index, - size_t &new_out_num, std::map &graph_outputs_to_out_index) { - new_out_num = 0U; - for (const auto &output : init_graph_outputs) { - GE_ASSERT_NOTNULL(output, "Failed to construct on init graph, the graph builder return nullptr"); - const auto node = output->GetFastNode(); - GE_ASSERT_NOTNULL(node); - const auto index = output->GetOutIndex(); - if (index < 0) { - graph_outputs_to_out_index[output] = kControlAnchorIdx; - continue; - } - int32_t out_index = -1; - for (const auto edge : node->GetOutEdgesRefByIndex(index)) { - if (edge == nullptr) { - continue; - } - GE_ASSERT_NOTNULL(edge->dst); - if (IsTypeInnerNetOutput(edge->dst->GetTypePtr())) { - out_index = edge->dst_input; - break; - } - } - if (out_index < 0) { - graph_outputs_to_out_index[output] = start_index + (new_out_num++); - } else { - graph_outputs_to_out_index[output] = static_cast(out_index); - } - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus ConnectSubGraphOut(ge::FastNode *parent_node, GraphFrame &sub_frame, - const std::vector &sub_graph_outputs, - std::vector> &guarders_and_out_index, - std::vector &parent_node_outputs) { - auto netoutput = GetOrCreateInnerNetOutput(sub_frame); - GE_ASSERT_NOTNULL(netoutput); - const auto index = netoutput->GetDataInNum(); - - size_t new_out_num; - std::map graph_outputs_to_out_index; - GE_ASSERT_SUCCESS(GetNewOutIndexes(sub_graph_outputs, index, new_out_num, graph_outputs_to_out_index)); - - GE_ASSERT_SUCCESS(ge::FastNodeUtils::AppendInputEdgeInfo(netoutput, index + new_out_num)); - guarders_and_out_index.reserve(sub_graph_outputs.size()); - parent_node_outputs.reserve(sub_graph_outputs.size()); - for (const auto &holder : sub_graph_outputs) { - GE_ASSERT_NOTNULL(holder); - const auto &out_index = graph_outputs_to_out_index.at(holder); - auto node = holder->GetFastNode(); - GE_ASSERT_NOTNULL(node); - auto graph = node->GetExtendInfo()->GetOwnerGraphBarePtr(); - GE_ASSERT_NOTNULL(graph); - if (out_index >= static_cast(index)) { - GE_ASSERT_NOTNULL(graph->AddEdge(node, holder->GetOutIndex(), netoutput, out_index)); - } else if (out_index == kControlAnchorIdx) { - GE_ASSERT_NOTNULL(graph->AddEdge(node, ge::kControlEdgeIndex, - netoutput, ge::kControlEdgeIndex)); - } - - auto guarder = holder->GetGuarder(); - if (guarder != nullptr) { - guarders_and_out_index.emplace_back(guarder, out_index); - } - - parent_node_outputs.emplace_back(holder->CreateMateFromNode(parent_node, static_cast(out_index), - ValueHolder::ValueHolderType::kOutput)); - } - - GE_ASSERT_SUCCESS(ge::FastNodeUtils::AppendOutputEdgeInfo(parent_node, index + new_out_num)); - - for (size_t i = 0U; i < sub_graph_outputs.size(); ++i) { - parent_node_outputs[i]->SetPlacement(sub_graph_outputs[i]->GetPlacement()); - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus ConnectOut(ge::FastNode *init_node, GraphFrame &root_frame, GraphFrame &init_frame, - const std::vector &init_graph_outputs, - std::vector &init_node_outputs) { - std::vector> guarders_and_out_index; - GE_ASSERT_SUCCESS( - ConnectSubGraphOut(init_node, init_frame, init_graph_outputs, guarders_and_out_index, init_node_outputs)); - if (!guarders_and_out_index.empty()) { - GE_ASSERT_SUCCESS(MoveGuardersToDeInit(init_node, root_frame, guarders_and_out_index)); - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus LoweringToSubgraph(const std::function()> &builder, - ValueHolderPtr partition_call_holder, GraphFrame *sub_graph_frame, - std::vector &parent_node_outputs) { - const ScopedCurrentFrame frame_guarder(sub_graph_frame); - auto outputs = builder(); - partition_call_holder->AppendOutputs(outputs.size()); - - std::vector> guarders_and_out_index; - auto partitioned_call = const_cast(sub_graph_frame->GetExecuteGraph()->GetParentNodeBarePtr()); - GE_ASSERT_GRAPH_SUCCESS( - ConnectSubGraphOut(partitioned_call, *sub_graph_frame, outputs, guarders_and_out_index, parent_node_outputs)); - return ge::GRAPH_SUCCESS; -} - -bool InitStageIdsToPartitionedCalls(const char *attr_name, std::vector &stage_ids_to_pcall) { - if (attr_name == kStageIdsToFirstPartitionedCall) { - stage_ids_to_pcall.resize(static_cast(OnMainRootFirstExecStage::kStageSize)); - } else if (attr_name == kStageIdsToLastPartitionedCall) { - stage_ids_to_pcall.resize(static_cast(OnMainRootLastExecStage::kStageSize)); - } else { - return false; - } - return true; -} - -std::vector GetOrCreateAllPartitionedCalls(const ge::ExecuteGraph *exe_graph, - const char *stage_attr_name) { - std::vector tmp_stage_ids_to_pcall; - auto stage_ids_to_partitioned_calls = exe_graph->GetExtAttr>(stage_attr_name); - if (stage_ids_to_partitioned_calls == nullptr) { - GE_ASSERT_TRUE(InitStageIdsToPartitionedCalls(stage_attr_name, tmp_stage_ids_to_pcall)); - return tmp_stage_ids_to_pcall; - } else { - return *stage_ids_to_partitioned_calls; - } -} - -ValueHolderPtr GetOrCreatePartitionedCallHolder(std::vector &stage_ids_to_pcalls, size_t stage_id) { - GE_ASSERT_TRUE(stage_ids_to_pcalls.size() > stage_id); - auto partition_call_holder = stage_ids_to_pcalls[stage_id]; - if (partition_call_holder != nullptr) { - return partition_call_holder; - } - return ValueHolder::CreateVoid("PartitionedCall", {}); -} - -GraphFrame *PushPartitionedCallSubFrame(bg::ValueHolderPtr &partition_call_holder) { - GE_ASSERT_NOTNULL(partition_call_holder->GetFastNode()); - auto sub_graph = ge::FastNodeUtils::GetSubgraphFromNode(partition_call_holder->GetFastNode(), 0U); - if (sub_graph == nullptr) { - return ValueHolder::PushGraphFrame(partition_call_holder, "exec_sub_graph"); - } - - GraphFrame *cur_frame = ValueHolder::GetCurrentFrame(); - GE_ASSERT_NOTNULL(cur_frame); - std::unique_ptr sub_graph_frame_holder = - ge::ComGraphMakeUnique(sub_graph->shared_from_this(), *cur_frame); - GE_ASSERT_NOTNULL(sub_graph_frame_holder); - sub_graph_frame_holder->SetCurrentComputeNode(cur_frame->GetCurrentComputeNode()); - return ValueHolder::PushGraphFrame(sub_graph_frame_holder.release()); -} - -std::vector OnMainRootPartitionedCall( - const std::function()> &partition_call_builder, const char *attr_name, - size_t stage_id) { - GE_ASSERT_NOTNULL(partition_call_builder); - GE_ASSERT_TRUE(GetGraphFrames().size() > 1U); - GraphFrame *current_frame = (GetGraphFrames().begin() + 1)->get(); - GE_ASSERT_EQ(current_frame->GetExecuteGraph()->GetParentNodeBarePtr()->GetType(), "Main"); - const ScopedCurrentFrame main_frame_guarder(current_frame); - - std::vector stage_ids_to_pcall = - GetOrCreateAllPartitionedCalls(current_frame->GetExecuteGraph().get(), attr_name); - GE_ASSERT_TRUE(stage_ids_to_pcall.size() > stage_id, "Stage_ids_2_partitioncall size %zu, stage_id is %zu", - stage_ids_to_pcall.size(), stage_id); - ValueHolderPtr partition_call_holder = GetOrCreatePartitionedCallHolder(stage_ids_to_pcall, stage_id); - GE_ASSERT_NOTNULL(partition_call_holder); - GraphFrame *sub_graph_frame = PushPartitionedCallSubFrame(partition_call_holder); - GE_ASSERT_NOTNULL(sub_graph_frame); - - std::vector parent_node_outputs; - GE_ASSERT_GRAPH_SUCCESS( - LoweringToSubgraph(partition_call_builder, partition_call_holder, sub_graph_frame, parent_node_outputs)); - - stage_ids_to_pcall[stage_id] = partition_call_holder; - current_frame->GetExecuteGraph()->SetExtAttr>(attr_name, stage_ids_to_pcall); - ValueHolder::PopGraphFrame(); - return parent_node_outputs; -} - -ge::graphStatus OnDeInitGraph(const std::function()> &builder, - std::vector &de_init_nodss) { - GE_ASSERT_NOTNULL(builder, "Failed to do frame selection, the builder is nullptr"); - GE_ASSERT_TRUE(!GetGraphFrames().empty(), "Failed to do frame selection, there is no root-frame exists"); - - const auto root_frame = GetGraphFrames().begin()->get(); - GE_ASSERT_NOTNULL(root_frame, "Failed to find the root frame"); - - // check if the main_frame is correct - const auto root_graph = root_frame->GetExecuteGraph(); - GE_ASSERT_NOTNULL(root_graph, "Failed to find the root graph"); - const auto de_init_node = - ge::ExecuteGraphUtils::FindFirstNodeMatchType(root_graph.get(), GetExecuteGraphTypeStr(ExecuteGraphType::kDeInit)); - GE_ASSERT_NOTNULL(de_init_node, "Failed to find the de_init node"); - const auto de_init_graph = ge::FastNodeUtils::GetSubgraphFromNode(de_init_node, 0U); - GE_ASSERT_NOTNULL(de_init_graph, "Failed to find the DeInit graph from de_init node %s", - de_init_node->GetName().c_str()); - - auto tmp_de_init_frame = ge::ComGraphMakeUnique(de_init_graph->shared_from_this(), *root_frame); - GE_ASSERT_NOTNULL(tmp_de_init_frame); - - tmp_de_init_frame->SetCurrentComputeNode(ValueHolder::GetCurrentFrame()->GetCurrentComputeNode()); - - const ScopedCurrentFrame frame_guarder(tmp_de_init_frame.get()); - - de_init_nodss = builder(); - - return ge::GRAPH_SUCCESS; -} -} // namespace -std::vector FrameSelector::OnMainRoot(const std::function()> &builder) { - if (builder == nullptr || GetGraphFrames().empty()) { - return {}; - } - std::vector outputs; - if (OnMainRoot(builder, outputs) != ge::GRAPH_SUCCESS) { - GELOGW("Compatible mode, the air code is not the newest."); - const ScopedCurrentFrame frame_guarder(GetGraphFrames().front().get()); - outputs = builder(); - } - return outputs; -} - -ge::graphStatus FrameSelector::OnMainRoot(const std::function()> &builder, - std::vector &outputs) { - GE_ASSERT_NOTNULL(builder, "Failed to do frame selection, the builder is nullptr"); - // 栈底是root-frame,向上是main-frame,因此栈中至少有两个元素 - GE_ASSERT_TRUE(GetGraphFrames().size() > 1U, "Failed to do frame selection, there is no main-frame exists"); - - const auto root_frame = GetGraphFrames().begin()->get(); - GE_ASSERT_NOTNULL(root_frame, "Failed to find the root frame"); - auto main_frame = (GetGraphFrames().begin() + 1)->get(); - GE_ASSERT_NOTNULL(main_frame, "Failed to find the main frame"); - - // check if the main_frame is correct - const auto root_graph = root_frame->GetExecuteGraph(); - GE_ASSERT_NOTNULL(root_graph, "Failed to find the root graph"); - const auto main_node = - ge::ExecuteGraphUtils::FindFirstNodeMatchType(root_graph.get(), GetExecuteGraphTypeStr(ExecuteGraphType::kMain)); - GE_ASSERT_NOTNULL(main_node, "Failed to find the main node"); - const auto main_graph = ge::FastNodeUtils::GetSubgraphFromNode(main_node, 0U); - GE_ASSERT_TRUE(main_graph == main_frame->GetExecuteGraph().get(), "Failed to find the main frame"); - - const ScopedCurrentFrame frame_guarder(main_frame); - outputs = builder(); - return ge::GRAPH_SUCCESS; -} - -std::vector FrameSelector::OnMainRootFirst( - const std::function()> &builder) { - return OnMainRootPartitionedCall(builder, kStageIdsToFirstPartitionedCall, - static_cast(OnMainRootFirstExecStage::kFirstEventSyncStage)); -} - -std::vector FrameSelector::OnDeInitRoot(const std::function()> &builder) { - if (builder == nullptr || GetGraphFrames().empty()) { - return {}; - } - std::vector de_init_nodes; - if (OnDeInitGraph(builder, de_init_nodes) != ge::GRAPH_SUCCESS) { - return {}; - } - return de_init_nodes; -} - -std::vector FrameSelector::OnInitRoot(const std::function()> &builder) { - std::vector init_graph_outputs; - std::vector init_node_outputs; - const auto ret = OnInitRoot(builder, init_graph_outputs, init_node_outputs); - if (ret != ge::GRAPH_SUCCESS) { - return {}; - } - return init_node_outputs; -} - -ge::graphStatus FrameSelector::OnInitRoot(const std::function()> &builder, - std::vector &init_graph_outputs, - std::vector &init_node_outputs) { - GE_ASSERT_NOTNULL(builder, "Failed to do frame selection, the builder is nullptr"); - GE_ASSERT_TRUE(!GetGraphFrames().empty(), "Failed to do frame selection, there is no root-frame exists"); - - const auto root_frame = GetGraphFrames().begin()->get(); - GE_ASSERT_NOTNULL(root_frame, "Failed to find the root frame"); - - const auto init_node = - ge::ExecuteGraphUtils::FindFirstNodeMatchType(root_frame->GetExecuteGraph().get(), - GetExecuteGraphTypeStr(ExecuteGraphType::kInit)); - GE_ASSERT_NOTNULL(init_node, "Failed to find the Init node from root graph"); - auto init_graph = ge::FastNodeUtils::GetSubgraphFromNode(init_node, 0U); - GE_ASSERT_NOTNULL(init_graph, "Failed to find the Init graph from init node %s", init_node->GetNamePtr()); - - auto tmp_init_frame = ge::ComGraphMakeUnique(init_graph->shared_from_this(), *root_frame); - GE_ASSERT_NOTNULL(tmp_init_frame); - - tmp_init_frame->SetCurrentComputeNode(ValueHolder::GetCurrentFrame()->GetCurrentComputeNode()); - const ScopedCurrentFrame frame_guarder(tmp_init_frame.get()); - init_graph_outputs = builder(); - if (!init_graph_outputs.empty()) { - return ConnectOut(init_node, *root_frame, *tmp_init_frame, init_graph_outputs, init_node_outputs); - } - return ge::GRAPH_SUCCESS; -} - -ValueHolderPtr FrameSelector::OnMainRootLast(const std::function &builder) { - if (builder == nullptr || GetGraphFrames().empty()) { - return nullptr; - } - GraphFrame *current_frame = nullptr; - if (GetGraphFrames().size() > 1U) { - current_frame = (GetGraphFrames().begin() + 1)->get(); - } else { - current_frame = GetGraphFrames().begin()->get(); - } - GE_ASSERT_NOTNULL(current_frame); - const ScopedCurrentFrame frame_guarder(current_frame); - auto output = builder(); - GetCurrentFrame()->SetLastExecNode(output); - return output; -} - -std::vector FrameSelector::OnMainRootLastEventSync( - const std::function()> &builder) { - return OnMainRootPartitionedCall(builder, kStageIdsToLastPartitionedCall, - static_cast(OnMainRootLastExecStage::kLastEventSyncStage)); -} - -std::vector FrameSelector::OnMainRootLastResourceClean( - const std::function()> &builder) { - return OnMainRootPartitionedCall(builder, kStageIdsToLastPartitionedCall, - static_cast(OnMainRootLastExecStage::kLastResourceClean)); -} - -ValueHolderPtr HolderOnInit(const ValueHolderPtr &holder) { - GE_ASSERT_NOTNULL(holder); - const auto index = holder->GetOutIndex(); - GE_ASSERT_TRUE(index >= 0, "The holder is a ctrl holder"); - - const auto holder_node = holder->GetFastNode(); - GE_ASSERT_NOTNULL(holder_node); - if (strcmp(holder_node->GetTypePtr(), GetExecuteGraphTypeStr(ExecuteGraphType::kInit)) == 0) { - const auto init_graph = ge::FastNodeUtils::GetSubgraphFromNode(holder_node, 0U); - GE_ASSERT_NOTNULL(init_graph); - const auto netoutput = ge::ExecuteGraphUtils::FindFirstNodeMatchType(init_graph, kInnerNetOutput); - GE_ASSERT_NOTNULL(netoutput, "Can not find the InnerNetOutput node on the Init graph"); - - auto edge = netoutput->GetInDataEdgeByIndex(index); - GE_ASSERT_NOTNULL(edge, "The InnerNetOutput does not have the in edge %d", index); - auto src_node = edge->src; - GE_ASSERT_NOTNULL(src_node); - return holder->CreateMateFromNode(src_node, edge->src_output, ValueHolder::ValueHolderType::kOutput); - } - - const auto holder_graph = holder_node->GetExtendInfo()->GetOwnerGraphBarePtr(); - GE_ASSERT_NOTNULL(holder_graph); - const auto parent_node = holder_graph->GetParentNodeBarePtr(); - GE_ASSERT_NOTNULL(parent_node, "The node %s is not on the Root graph", holder_node->GetNamePtr()); - if (strcmp(parent_node->GetTypePtr(), GetExecuteGraphTypeStr(ExecuteGraphType::kInit)) == 0) { - return holder; - } - return nullptr; -} -} // namespace bg -} // namespace gert diff --git a/exe_graph/lowering/generate_exe_graph.cc b/exe_graph/lowering/generate_exe_graph.cc deleted file mode 100644 index 293687d4e6a1f570cb70f84f1e13689a5c055803..0000000000000000000000000000000000000000 --- a/exe_graph/lowering/generate_exe_graph.cc +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/lowering/generate_exe_graph.h" -#include "exe_graph/lowering/lowering_global_data.h" -namespace gert { -namespace bg { -GenerateExeGraph::ExeGraphGenerator GenerateExeGraph::generator_ = {nullptr, nullptr, nullptr}; - -ValueHolderPtr GenerateExeGraph::MakeSureTensorAtHost(const ge::Node *node, LoweringGlobalData &global_data, - const ValueHolderPtr &addr, const ValueHolderPtr &size) { - std::vector copy_inputs{global_data.GetStream()}; - copy_inputs.emplace_back(global_data.GetOrCreateAllocator({kOnHost, AllocatorUsage::kAllocNodeWorkspace})); - copy_inputs.emplace_back(addr); - copy_inputs.emplace_back(size); - GE_ASSERT_NOTNULL(node); - const auto op_desc = node->GetOpDescBarePtr(); - GE_ASSERT_NOTNULL(op_desc); - return bg::DevMemValueHolder::CreateSingleDataOutput("MakeSureTensorAtHost", copy_inputs, op_desc->GetStreamId()); -} - -ValueHolderPtr GenerateExeGraph::CalcTensorSizeFromShape(ge::DataType dt, const ValueHolderPtr &shape) { - auto data_type = ValueHolder::CreateConst(&dt, sizeof(dt)); - return ValueHolder::CreateSingleDataOutput("CalcTensorSizeFromShape", {data_type, shape}); -} - -ValueHolderPtr GenerateExeGraph::FreeMemoryGuarder(const ValueHolderPtr &resource) { - return ValueHolder::CreateVoidGuarder("FreeMemory", resource, {}); -} -} // namespace bg -} // namespace gert diff --git a/exe_graph/lowering/getcdim.cc b/exe_graph/lowering/getcdim.cc deleted file mode 100644 index da87749e8d0f9c4699b3e188f2b8382ea03e40af..0000000000000000000000000000000000000000 --- a/exe_graph/lowering/getcdim.cc +++ /dev/null @@ -1,86 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include -#include -#include -#include "graph/types.h" -#include "graph/def_types.h" -#include "axis_constants.h" -#include "exe_graph/runtime/tiling_context.h" -#include "exe_graph/runtime/extended_kernel_context.h" -#include "exe_graph/runtime/shape.h" -#include "exe_graph/lowering/getcdim.h" - -namespace gert { - const std::map CDIM_INDEX_OF_FORMAT { - {ge::FORMAT_NCHW, transformer::AXIS_NCHW_DIM_C}, - {ge::FORMAT_HWCN, transformer::AXIS_HWCN_DIM_C}, - {ge::FORMAT_NHWC, transformer::AXIS_NHWC_DIM_C}, - {ge::FORMAT_CHWN, transformer::AXIS_CHWN_DIM_C}, - {ge::FORMAT_NDHWC, transformer::NDHWC_DIM_C}, - {ge::FORMAT_NCDHW, transformer::NCDHW_DIM_C}, - {ge::FORMAT_DHWCN, transformer::DHWCN_DIM_C}, - {ge::FORMAT_DHWNC, transformer::DHWNC_DIM_C} - }; - -namespace { - int64_t GetCDim(TilingContext *const context, const size_t index, const bool is_input) { - if (context == nullptr) { - return -1; - } - auto extend_context = ge::PtrToPtr(context); - auto compute_node_info = extend_context->GetComputeNodeInfo(); - if (compute_node_info == nullptr) { - return -1; - } - auto kernel_context = ge::PtrToPtr(context); - const CompileTimeTensorDesc *td = nullptr; - StorageShape *storage_shape = nullptr; - if (is_input) { - td = compute_node_info->GetInputTdInfo(index); - storage_shape = kernel_context->MutableInputPointer(index); - } else { - td = compute_node_info->GetOutputTdInfo(index); - storage_shape = kernel_context->GetOutputPointer(index); - } - if ((td == nullptr) || (storage_shape == nullptr)) { - return -1; - } - const auto original_format = td->GetOriginFormat(); - const auto iter = CDIM_INDEX_OF_FORMAT.find(original_format); - if (iter == CDIM_INDEX_OF_FORMAT.cend()) { - return -1; - } - Shape &origin_shape = storage_shape->MutableOriginShape(); - const auto expend_dims = td->GetExpandDimsType(); - Shape expand_shape; - (void) expend_dims.Expand(origin_shape, expand_shape); - - if (static_cast(iter->second) >= expand_shape.GetDimNum()) { - return -1; - } - if (expand_shape.GetDimNum() == origin_shape.GetDimNum()) { - return static_cast(origin_shape.GetDim(iter->second)); - } else { - return static_cast(expand_shape.GetDim(iter->second)); - } - } -} // namespace - - int64_t GetInputCDim(TilingContext *kernel_context, const size_t index) { - return GetCDim(kernel_context, index, true); - } - int64_t GetOutputCDim(TilingContext *kernel_context, const size_t index) { - return GetCDim(kernel_context, index, false); - } -} // namespace gert diff --git a/exe_graph/lowering/kernel_run_context_builder.cc b/exe_graph/lowering/kernel_run_context_builder.cc deleted file mode 100644 index af409a55c2c854c494d58b1c79de3c92eaf2121a..0000000000000000000000000000000000000000 --- a/exe_graph/lowering/kernel_run_context_builder.cc +++ /dev/null @@ -1,99 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/lowering/kernel_run_context_builder.h" -#include "exe_graph/lowering/bg_kernel_context_extend.h" -#include "graph/compute_graph.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/debug/ge_util.h" -#include "register/op_impl_space_registry.h" -#include "graph/def_types.h" - -namespace gert { -KernelContextHolder KernelRunContextBuilder::Build(const ge::OpDescPtr &op_desc) { - ge::Status ret = ge::GRAPH_FAILED; - return Build(op_desc, ret); -} - -KernelContextHolder KernelRunContextBuilder::Build(const ge::OpDescPtr &op_desc, ge::graphStatus &ret) { - ret = ge::GRAPH_FAILED; - KernelContextHolder holder; - size_t size = sizeof(KernelRunContext) + sizeof(Chain *) * (inputs_.size() + outputs_.size()); - holder.context_holder_ = ge::ComGraphMakeUnique(size); - if (holder.context_holder_ == nullptr) { - GELOGE(ge::GRAPH_FAILED, "Create context holder failed."); - return holder; - } - const auto &space_registry = gert::DefaultOpImplSpaceRegistry::GetInstance() - .GetDefaultSpaceRegistry(op_desc->GetOppImplVersion()); - OpImplRegisterV2::PrivateAttrList private_attrs; - if (space_registry != nullptr) { - private_attrs = space_registry->GetPrivateAttrs(op_desc->GetType()); - } - GELOGD("Default space registry is %s. Op:%s(%s) has %zu private attrs.", - space_registry == nullptr ? "nullptr" : "not nullptr", op_desc->GetNamePtr(), op_desc->GetTypePtr(), - private_attrs.size()); - size_t extend_info_size; - holder.compute_node_extend_holder_ = - bg::CreateComputeNodeInfo(MakeNode(op_desc), holder.buffer_pool_, private_attrs, extend_info_size); - - if (holder.compute_node_extend_holder_ == nullptr) { - GELOGE(ge::GRAPH_FAILED, - "Failed to create compute node info for node %s", op_desc->GetName().c_str()); - return holder; - } - auto compute_node_info = ge::PtrToPtr(holder.compute_node_extend_holder_.get()); - compute_node_info->SetNodeName( - holder.buffer_pool_.GetBufById(reinterpret_cast(compute_node_info->GetNodeName()))); - compute_node_info->SetNodeType( - holder.buffer_pool_.GetBufById(reinterpret_cast(compute_node_info->GetNodeType()))); - holder.context_ = ge::PtrToPtr(holder.context_holder_.get()); - auto kernel_run_context = holder.context_->GetContext(); - kernel_run_context->input_size = inputs_.size(); - kernel_run_context->output_size = outputs_.size(); - kernel_run_context->compute_node_info = compute_node_info; - kernel_run_context->output_start = &(kernel_run_context->values[kernel_run_context->input_size]); - holder.value_holder_.resize(inputs_.size() + outputs_.size()); - for (size_t i = 0UL; i < holder.value_holder_.size(); ++i) { - kernel_run_context->values[i] = ge::PtrToPtr(&holder.value_holder_[i]); - } - for (size_t i = 0UL; i < inputs_.size(); ++i) { - holder.value_holder_[i].Set(inputs_[i].first, inputs_[i].second); - } - for (size_t i = 0UL; i < outputs_.size(); ++i) { - holder.value_holder_[inputs_.size() + i].Set(outputs_[i].first, outputs_[i].second); - } - ret = ge::GRAPH_SUCCESS; - return holder; -} - -ge::NodePtr KernelRunContextBuilder::MakeNode(const ge::OpDescPtr &op_desc) { - const auto node_id = op_desc->GetId(); - graph_ = std::make_shared("tmp"); - auto fake_node = graph_->AddNode(op_desc); - GE_CHECK_NOTNULL_EXEC(fake_node, return nullptr); - for (size_t i = 0UL; i < op_desc->GetAllInputsSize(); ++i) { - const auto input_desc = op_desc->GetInputDesc(i); - if (input_desc.IsValid() != ge::GRAPH_SUCCESS) { - GELOGD("Node: %s, input: %zu, is invalid, skip add edge.", op_desc->GetNamePtr(), i); - continue; - } - auto op_data = ge::OpDescBuilder(std::to_string(i), "Data").AddInput("x").AddOutput("y").Build(); - auto data_node = graph_->AddNode(op_data); - if (data_node == nullptr) { - return nullptr; - } - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), fake_node->GetInDataAnchor(i)); - } - // AddNode operation may change node id to 0, which need to be recovered - op_desc->SetId(node_id); - return fake_node; -} -} // namespace gert diff --git a/exe_graph/lowering/kernel_run_context_builder.h b/exe_graph/lowering/kernel_run_context_builder.h deleted file mode 100644 index 29a75270ec35b0b4c87c97da64ec847e5bf0a910..0000000000000000000000000000000000000000 --- a/exe_graph/lowering/kernel_run_context_builder.h +++ /dev/null @@ -1,94 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_RUNTIME_KERNEL_CONTEXT_BUILDER_H_ -#define METADEF_CXX_RUNTIME_KERNEL_CONTEXT_BUILDER_H_ - -#include "graph/node.h" -#include "exe_graph/runtime/compute_node_info.h" -#include "exe_graph/runtime/kernel_context.h" -#include "exe_graph/lowering/buffer_pool.h" - -namespace gert { -class KernelContextHolder { -public: - KernelContextHolder() = default; - KernelContextHolder(KernelContextHolder &&holder) { - context_holder_ = std::move(holder.context_holder_); - value_holder_ = std::move(holder.value_holder_); - compute_node_extend_holder_ = std::move(holder.compute_node_extend_holder_); - buffer_pool_ = holder.buffer_pool_; - context_ = holder.context_; - } - - KernelContextHolder &operator=(KernelContextHolder &&holder) { - context_holder_ = std::move(holder.context_holder_); - value_holder_ = std::move(holder.value_holder_); - compute_node_extend_holder_ = std::move(holder.compute_node_extend_holder_); - buffer_pool_ = holder.buffer_pool_; - context_ = holder.context_; - return *this; - } - - ~KernelContextHolder() { - for (auto &value : value_holder_) { - value.Set(nullptr, nullptr); - } - } - - KernelContext *GetKernelContext() { - return context_; - } - - std::unique_ptr context_holder_; - std::vector value_holder_; - std::unique_ptr compute_node_extend_holder_; - bg::BufferPool buffer_pool_; - KernelContext *context_; -}; -class KernelRunContextBuilder { -public: - KernelRunContextBuilder() = default; - KernelRunContextBuilder &Inputs(std::vector> inputs) { - inputs_ = std::move(inputs); - return *this; - } - - KernelRunContextBuilder &Inputs(std::vector inputs) { - for (auto &input : inputs) { - inputs_.emplace_back(input, nullptr); - } - return *this; - } - - KernelRunContextBuilder &Outputs(std::vector outputs) { - for (auto &output : outputs) { - outputs_.emplace_back(output, nullptr); - } - return *this; - } - - KernelRunContextBuilder &Outputs(std::vector> outputs) { - outputs_ = std::move(outputs); - return *this; - } - - // deprecated when air use Build interface with ret_status - KernelContextHolder Build(const ge::OpDescPtr &op_desc); - KernelContextHolder Build(const ge::OpDescPtr &op_desc, ge::graphStatus &ret); - -private: - ge::NodePtr MakeNode(const ge::OpDescPtr &op_desc); -private: - ge::ComputeGraphPtr graph_; - std::vector> inputs_; - std::vector> outputs_; -}; -} // namespace gert -#endif diff --git a/exe_graph/lowering/lowering_global_data.cc b/exe_graph/lowering/lowering_global_data.cc deleted file mode 100644 index e3b1089f7d0199cc8025b71b279abaaf466be8f2..0000000000000000000000000000000000000000 --- a/exe_graph/lowering/lowering_global_data.cc +++ /dev/null @@ -1,578 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/lowering/lowering_global_data.h" -#include -#include "common/checker.h" -#include "exe_graph/lowering/frame_selector.h" -#include "runtime/allocator.h" -#include "graph/fast_graph/execute_graph.h" - -namespace gert { -namespace { -constexpr const ge::char_t *kGlobalDataL2AllocatorsPrefix = "L2_Allocators_"; -const std::set kCurrentAllocatorSupportPlacement = { - TensorPlacement::kOnDeviceHbm, TensorPlacement::kOnHost, TensorPlacement::kFollowing}; -const std::set kCurrentAllocatorSupportUsage = { - AllocatorUsage::kAllocNodeOutput, AllocatorUsage::kAllocNodeWorkspace, AllocatorUsage::kAllocNodeShapeBuffer}; - -inline bool IsPlacementSupportExternalAllocator(const TensorPlacement placement) { - return kCurrentAllocatorSupportPlacement.find(placement) != kCurrentAllocatorSupportPlacement.cend(); -} - -inline bool IsUsageSupportExternalAllocator(const AllocatorUsage &usage) { - return kCurrentAllocatorSupportUsage.find(usage) != kCurrentAllocatorSupportUsage.cend(); -} - -bool CurrentOnInitGraph() { - const ge::FastNode *subgraph_node = nullptr; - auto current_graph = bg::ValueHolder::GetCurrentExecuteGraph(); - while ((current_graph != nullptr) && (current_graph->GetParentNodeBarePtr() != nullptr)) { - subgraph_node = current_graph->GetParentNodeBarePtr(); - current_graph = subgraph_node->GetExtendInfo()->GetOwnerGraphBarePtr(); - } - - if (subgraph_node != nullptr) { - return strcmp(subgraph_node->GetTypePtr(), GetExecuteGraphTypeStr(ExecuteGraphType::kInit)) == 0; - } else { - return false; - } -} -/* - * 此处用于判断是否能使用init图里的allocator: - * 1.用户设置了always_external_allocator选项 - * 2.AllocatorDesc在我们当前支持的allocator范围内 - * - * 为了兼容性考虑,当前只能支持现有的allocator,否则后续我们新增placement/useage时则会出错,用户老的版本加上我们新的软件会出错 - * */ -bool CanUseInitAllocator(const bool always_external_allocator, const AllocatorDesc &desc) { - if (!always_external_allocator) { - return false; - } - if (IsPlacementSupportExternalAllocator(desc.placement) && IsUsageSupportExternalAllocator(desc.usage)) { - return true; - } else { - GELOGW("We don't support placement[%d] or usage[%d] current while always_external_allocator is true", - static_cast(desc.placement), static_cast(desc.usage)); - return false; - } -} - -std::string GenL2AllocatorKey(int64_t logic_stream_id, const AllocatorDesc desc) { - if (TensorPlacementUtils::IsOnDevice(desc.placement)) { - return kGlobalDataL2AllocatorsPrefix + std::to_string(desc.placement) + std::to_string(logic_stream_id); - } else { - return kGlobalDataL2AllocatorsPrefix + std::to_string(desc.placement); - } -} - -std::string GenAllL2AllocatorsKey(const AllocatorDesc desc) { - return kGlobalDataL2AllocatorsPrefix + std::to_string(desc.placement); -} - -std::string GenInitL2AllocatorsKey(const AllocatorDesc desc) { - return kGlobalDataL2AllocatorsPrefix + std::to_string(desc.placement) + "-Init"; -} - -bg::ValueHolderPtr GetOrCreateL2Allocators(const AllocatorDesc desc, LoweringGlobalData &global_data) { - auto all_l2_allocators_key = GenAllL2AllocatorsKey(desc); - return global_data.GetOrCreateUniqueValueHolder(all_l2_allocators_key, [&desc, &global_data]() -> bg::ValueHolderPtr { - auto init_out = bg::FrameSelector::OnInitRoot([&desc, &global_data]() -> std::vector { - auto placement_holder = bg::ValueHolder::CreateConst(&desc.placement, sizeof(desc.placement)); - auto init_l1_allocator = global_data.GetOrCreateL1Allocator(desc); - auto stream_num = global_data.GetUniqueValueHolder(kGlobalDataModelStreamNum); - GE_ASSERT_NOTNULL(stream_num); - // CreateL2Allocators第一个输出是AllL2Allocators,表示所有二级内存池; - // 第二个输出是AllL2MemPools,表示所有二级内存池里实际申请内存的allocator; - auto l2_allocators_on_init = bg::ValueHolder::CreateDataOutput("CreateL2Allocators", - {stream_num, placement_holder}, 2)[0]; - bg::ValueHolder::AddDependency(init_l1_allocator, l2_allocators_on_init); - return {l2_allocators_on_init}; - }); - GE_ASSERT_EQ(init_out.size(), 1U); - GE_ASSERT_NOTNULL(init_out[0]); - return init_out[0]; - }); -} - -bg::ValueHolderPtr GetOrCreateDeviceL2Allocator(int64_t logic_stream_id, AllocatorDesc desc, - const std::string &l2_allocator_key, LoweringGlobalData &global_data) { - auto builder = [&logic_stream_id, &desc, &global_data]() -> bg::ValueHolderPtr { - bg::ValueHolderPtr l2_allocators = GetOrCreateL2Allocators(desc, global_data); - GE_ASSERT_NOTNULL(l2_allocators); - auto logic_stream_id_holder = bg::ValueHolder::CreateConst(&logic_stream_id, sizeof(logic_stream_id)); - auto rt_stream_holder = global_data.GetStreamById(logic_stream_id); - GE_ASSERT_NOTNULL(rt_stream_holder); - auto l1_allocator = global_data.GetOrCreateL1Allocator(desc); - return bg::ValueHolder::CreateSingleDataOutput( - "SelectL2Allocator", {logic_stream_id_holder, rt_stream_holder, l1_allocator, l2_allocators}); - }; - return global_data.GetOrCreateUniqueValueHolder(l2_allocator_key, builder); -} - -bg::ValueHolderPtr GetOrCreateHostL2Allocator(AllocatorDesc desc, const std::string &l2_allocator_key, - LoweringGlobalData &global_data) { - auto builder = [&desc, &global_data]() -> bg::ValueHolderPtr { - return bg::ValueHolder::CreateSingleDataOutput("CreateHostL2Allocator", {global_data.GetOrCreateL1Allocator(desc)}); - }; - return global_data.GetOrCreateUniqueValueHolder(l2_allocator_key, builder); -} -} // namespace - -const LoweringGlobalData::NodeCompileResult *LoweringGlobalData::FindCompiledResult(const ge::NodePtr &node) const { - const auto iter = node_name_to_compile_result_holders_.find(node->GetName()); - if (iter == node_name_to_compile_result_holders_.cend()) { - return nullptr; - } - return &iter->second; -} -LoweringGlobalData &LoweringGlobalData::AddCompiledResult(const ge::NodePtr &node, - LoweringGlobalData::NodeCompileResult compile_result) { - node_name_to_compile_result_holders_[node->GetName()] = std::move(compile_result); - return *this; -} - -void *LoweringGlobalData::GetGraphStaticCompiledModel(const std::string &graph_name) const { - const auto iter = graph_to_static_models_.find(graph_name); - if (iter == graph_to_static_models_.cend()) { - return nullptr; - } - return iter->second; -} - -LoweringGlobalData &LoweringGlobalData::AddStaticCompiledGraphModel(const std::string &graph_name, void *const model) { - graph_to_static_models_[graph_name] = model; - return *this; -} - -bg::ValueHolderPtr LoweringGlobalData::GetL1Allocator(const AllocatorDesc &desc) const { - if (CurrentOnInitGraph()) { - return GetUniqueValueHolder(desc.GetKey() + "-Init"); - } else { - return GetUniqueValueHolder(desc.GetKey()); - } -} -LoweringGlobalData &LoweringGlobalData::SetExternalAllocator(bg::ValueHolderPtr &&allocator) { - return SetExternalAllocator(std::move(allocator), ExecuteGraphType::kMain); -} -LoweringGlobalData &LoweringGlobalData::SetExternalAllocator(bg::ValueHolderPtr &&allocator, - const ExecuteGraphType graph_type) { - if (graph_type >= ExecuteGraphType::kNum) { - return *this; - } - external_allocators_.holders[static_cast(graph_type)] = std::move(allocator); - return *this; -} - -bool LoweringGlobalData::CanUseExternalAllocator(const ExecuteGraphType &graph_type, - const TensorPlacement placement) const { - return IsPlacementSupportExternalAllocator(placement) - && (external_allocators_.holders[static_cast(graph_type)] != nullptr); -} - -bg::ValueHolderPtr LoweringGlobalData::GetExternalAllocator(const bool from_init, const string &key, - const AllocatorDesc &desc) { - bg::ValueHolderPtr init_selected_allocator = nullptr; - auto init_out = - bg::FrameSelector::OnInitRoot([&desc, &init_selected_allocator, this]() -> std::vector { - auto placement_holder = bg::ValueHolder::CreateConst(&desc.placement, sizeof(desc.placement)); - init_selected_allocator = nullptr; - if (CanUseExternalAllocator(ExecuteGraphType::kInit, desc.placement)) { - init_selected_allocator = bg::ValueHolder::CreateSingleDataOutput( - "GetExternalL1Allocator", - {placement_holder, - external_allocators_.holders[static_cast(ExecuteGraphType::kInit)]}); - } else { - GELOGE(ge::PARAM_INVALID, "always_external_allocator option is true but external_allocators is nullptr!"); - } - return {init_selected_allocator}; - }); - GE_ASSERT_EQ(init_out.size(), 1U); - GE_ASSERT_NOTNULL(init_out[0]); - - auto allocator = bg::FrameSelector::OnMainRoot([&desc, &init_out, this]() -> std::vector { - auto main_selected_allocator = init_out[0]; - auto placement_holder = bg::ValueHolder::CreateConst(&desc.placement, sizeof(desc.placement)); - if (CanUseExternalAllocator(ExecuteGraphType::kMain, desc.placement)) { - main_selected_allocator = bg::ValueHolder::CreateSingleDataOutput( - "SelectL1Allocator", - {placement_holder, external_allocators_.holders[static_cast(ExecuteGraphType::kMain)], init_out[0], - GetStreamById(kDefaultMainStreamId)}); - } - return {main_selected_allocator}; - }); - GE_ASSERT_EQ(allocator.size(), 1U); - - SetUniqueValueHolder(key + "-Init", init_selected_allocator); - SetUniqueValueHolder(key, allocator[0]); - if (from_init) { - return init_selected_allocator; - } else { - return allocator[0]; - } -} - -/* CanUseInitAllocator is true - * +------------------------------------------------------------------+ - * |Main Graph | - * | AllocMemory | - * | | | - * | (allocator) | - * | | | - * | InnerData | - * +------------------------------------------------------------------+ - * +------------------------------------------------------------------+ - * |Init Graph | - * | | - * | InnerNetOutput | - * | ^ | - * | | | - * | GetExternalL1Allocator | - * | / | \ | - * | Const(placement) Const(usage) Data(Allocator)(-2) | | - * +------------------------------------------------------------------+ - */ - -/* CanUseInitAllocator is false - * +------------------------------------------------------------------+ - * |Main Graph | - * | (allocator) | - * | | | - * | +------> SelectL1Allocator <-----+ | - * | | / \ | | - * | InnerData InnerData InnerData Data(-2) | - * +------------------------------------------------------------------+ - * +------------------------------------------------------------------+ - * |Init Graph | - * | | - * | +------+---> InnerNetOutput (allocator) | - * | | | ^ | | - * | | | | SelectL1Allocator | - * | | | | / ^ \ | - * | | | CreateL1Allocator | Data(Allocator)(-2) | - * | | | / \ | | - * | | Const(placement) Const(usage) | | - * | | | | | - * | +-------------------------+----------+ | - * +------------------------------------------------------------------+ - */ -bg::ValueHolderPtr LoweringGlobalData::GetOrCreateL1Allocator(const AllocatorDesc desc) { - const auto key = desc.GetKey(); - const auto init_key = key + "-Init"; - const auto from_init = CurrentOnInitGraph(); - - bg::ValueHolderPtr allocator_holder; - if (from_init) { - allocator_holder = GetUniqueValueHolder(init_key); - } else { - allocator_holder = GetUniqueValueHolder(key); - } - - if (allocator_holder != nullptr) { - return allocator_holder; - } - /* - * 用户设置always_external_allocator场景下,同时external_allocators_不为空的情况下,一定认为所有类型的allocator都创建好了,原因: - * 1.不能考虑外置的external_allocators_中存在某些类型的allocator没有创建,之前为了保证正确性,必须在构图时根据placement跟usage - * 创建一个CreateAllocator节点,在执行时创建兜底的allocator对象,但是allocator对象是需要浪费host内存资源,对单算子场景下, - * 频繁创建导致host内存上升,因此设置了always_external_allocator的场景下不考虑某些类型的allocator没有创建 - * - * 2.为什么这个地方不能判断满足当前placement+usage的allocator是否已经创建好了?这个地方还在构图,此时还是valueholder,还没有到初始化 - * 执行,因此无法感知用户是否完整创建了所有allocator,只有初始化图执行时才知道。 - * - * 3.因此对于此场景,考虑在初始化图执行时做一个校验,用户设置了always_external_allocator的场景下,确保所有类型的allocator都创建好了 - * 因此, 在单算子场景下,需要无脑校验 - * - * 4.为了兼容性考虑,当前只能支持现有的allocator,否则后续我们新增placement/useage时则会出错,用户老的版本加上我们新的软件会出错 - * - * 5.always_external_allocator可以后续整改为always_use_init_allocator - * */ - if (CanUseInitAllocator(GetLoweringOption().always_external_allocator, desc)) { - return GetExternalAllocator(from_init, key, desc); - } else { - bg::ValueHolderPtr init_selected_allocator = nullptr; - auto init_out = - bg::FrameSelector::OnInitRoot([&desc, &init_selected_allocator, this]() -> std::vector { - auto placement_holder = bg::ValueHolder::CreateConst(&desc.placement, sizeof(desc.placement)); - auto created_allocator = bg::ValueHolder::CreateSingleDataOutput("CreateL1Allocator", {placement_holder}); - if (CanUseExternalAllocator(ExecuteGraphType::kInit, desc.placement)) { - const auto init_external_allocator = - external_allocators_.holders[static_cast(ExecuteGraphType::kInit)]; - init_selected_allocator = bg::ValueHolder::CreateSingleDataOutput( - "SelectL1Allocator", - {placement_holder, init_external_allocator, created_allocator, GetStreamById(kDefaultMainStreamId)}); - // here init_selected_allocator to init_node_out just for deconstruct sequence. - // To make sure memblock alloced from init graph, which deconstruction relies on allocator alive. - return {created_allocator, placement_holder, init_selected_allocator}; - } else { - init_selected_allocator = created_allocator; - return {created_allocator, placement_holder}; - } - }); - GE_ASSERT_TRUE(init_out.size() >= 2U); - - auto allocator = bg::FrameSelector::OnMainRoot([&init_out, &desc, this]() -> std::vector { - auto main_selected_allocator = init_out[0]; - if (CanUseExternalAllocator(ExecuteGraphType::kMain, desc.placement)) { - const auto main_external_allocator = external_allocators_.holders[static_cast(ExecuteGraphType::kMain)]; - main_selected_allocator = bg::ValueHolder::CreateSingleDataOutput( - "SelectL1Allocator", {init_out[1], main_external_allocator, init_out[0], - GetStreamById(kDefaultMainStreamId)}); - } - return {main_selected_allocator}; - }); - GE_ASSERT_EQ(allocator.size(), 1U); - - SetUniqueValueHolder(key + "-Init", init_selected_allocator); - SetUniqueValueHolder(key, allocator[0]); - - if (from_init) { - return init_selected_allocator; - } else { - return allocator[0]; - } - } -} - -bg::ValueHolderPtr LoweringGlobalData::GetOrCreateUniqueValueHolder( - const std::string &name, const std::function &builder) { - return GetOrCreateUniqueValueHolder(name, [&builder]() -> std::vector { return {builder()}; })[0]; -} - -std::vector LoweringGlobalData::GetOrCreateUniqueValueHolder( - const std::string &name, const std::function()> &builder) { - const decltype(unique_name_to_value_holders_)::const_iterator &iter = unique_name_to_value_holders_.find(name); - if (iter == unique_name_to_value_holders_.cend()) { - auto holder = builder(); - return unique_name_to_value_holders_.emplace(name, holder).first->second; - } - return iter->second; -} -void LoweringGlobalData::SetUniqueValueHolder(const string &name, const bg::ValueHolderPtr &holder) { - if (!unique_name_to_value_holders_.emplace(name, std::vector{holder}).second) { - unique_name_to_value_holders_.erase(name); - unique_name_to_value_holders_.emplace(name, std::vector{holder}); - } -} -bg::ValueHolderPtr LoweringGlobalData::GetUniqueValueHolder(const string &name) const { - const auto &iter = unique_name_to_value_holders_.find(name); - if (iter == unique_name_to_value_holders_.cend()) { - return nullptr; - } - return iter->second[0]; -} - -void LoweringGlobalData::SetValueHolders(const string &name, const bg::ValueHolderPtr &holder) { - unique_name_to_value_holders_[name].emplace_back(holder); -} - -size_t LoweringGlobalData::GetValueHoldersSize(const string &name) { - const auto &iter = unique_name_to_value_holders_.find(name); - if (iter == unique_name_to_value_holders_.cend()) { - return 0U; - } - return iter->second.size(); -} - -void LoweringGlobalData::SetModelWeightSize(const size_t require_weight_size) { - model_weight_size_ = require_weight_size; -} -size_t LoweringGlobalData::GetModelWeightSize() const { - return model_weight_size_; -} - -const LoweringOption &LoweringGlobalData::GetLoweringOption() const { - return lowering_option_; -} -void LoweringGlobalData::SetLoweringOption(const LoweringOption &lowering_option) { - lowering_option_ = lowering_option; -} - -/* -* init_graph -* +--------------------------+ -* | | | -* Const(placement) Const(stream_num) Const(placement) Const(usage) | -* \ / \ / | Data(externel_allocator) -* CreateL2Allocators CreateL1Allocator ----------+-----+ | Data(external_stream) -* \ / | \ | / -* \ / | SelectAllocator -* InnerNetOutput <-------------------------+ / \ -* CreateHostL2Allocator CreateInitL2Allocator -* -* -* main_graph -* SelectL1Allocator -* Data(rt_streams) / \ -* | / InnerData CreateHostL2Allocator -* SplitRtStreams /(l2 allocators) -* \ / / -* SelectL2Allocator -* -* 有外置allocator场景下的L2 allocator -*/ -bg::ValueHolderPtr LoweringGlobalData::GetOrCreateL2Allocator(int64_t logic_stream_id, AllocatorDesc desc) { - if (CurrentOnInitGraph()) { - // 在init图也可能有申请内存的行为,也需要使用l2 allocator来申请。 - // init图中的l2 allocator要绑定到init图的主流上,因此使用新kernel,与main图中l2 allocator构图不同 - return GetOrCreateInitL2Allocator(desc); - } - - auto l2_allocator_key = GenL2AllocatorKey(logic_stream_id, desc); - auto l2_allocator = GetUniqueValueHolder(l2_allocator_key); - if (l2_allocator != nullptr) { - return l2_allocator; - } - // To make sure l2 allocator lowering at main graph. In case l2 allocator lowering at subgraph cause a link - // from inside subgraph to main root. - auto tmp_l2_allocators = bg::FrameSelector::OnMainRoot( - [&desc, &logic_stream_id, &l2_allocator_key, this]() -> std::vector { - if (TensorPlacementUtils::IsOnDevice(desc.placement)) { - return {GetOrCreateDeviceL2Allocator(logic_stream_id, desc, l2_allocator_key, *this)}; - } else { - return {GetOrCreateHostL2Allocator(desc, l2_allocator_key, *this)}; - } - }); - GE_ASSERT_TRUE(tmp_l2_allocators.size() == 1U); - - // main图中创建host allocator时确保init图中也创建相应的host allocator,为CEM做准备 - if (TensorPlacementUtils::IsOnHost(desc.placement)) { - auto init_l2_host_allocator = - bg::FrameSelector::OnInitRoot([&desc, this]() -> std::vector { - return {GetOrCreateInitL2Allocator(desc)}; - }); - GE_ASSERT_TRUE(init_l2_host_allocator.size() == 1U); - } - return tmp_l2_allocators[0]; -} - -bg::ValueHolderPtr LoweringGlobalData::GetInitL2Allocator(AllocatorDesc desc) const { - auto init_l2_allocator_key = GenInitL2AllocatorsKey(desc); - return GetUniqueValueHolder(init_l2_allocator_key); -} - -bg::ValueHolderPtr LoweringGlobalData::GetMainL2Allocator(int64_t logic_stream_id, AllocatorDesc desc) const { - auto main_l2_allocator_key = GenL2AllocatorKey(logic_stream_id, desc); - return GetUniqueValueHolder(main_l2_allocator_key); -} - -/** - * This interface can only use on init graph - * @param logic_stream_id - * @param desc - * @param global_data - * @return - */ -bg::ValueHolderPtr LoweringGlobalData::GetOrCreateInitL2Allocator(const AllocatorDesc desc) { - if (!CurrentOnInitGraph()) { - return nullptr; - } - auto init_l2_allocator_key = GenInitL2AllocatorsKey(desc); - bg::ValueHolderPtr init_l2_allocator = nullptr; - if (TensorPlacementUtils::IsOnHost(desc.placement)) { - auto builder = [&]() -> bg::ValueHolderPtr { - return bg::ValueHolder::CreateSingleDataOutput("CreateHostL2Allocator", - {GetOrCreateL1Allocator(desc)}); - }; - init_l2_allocator = GetOrCreateUniqueValueHolder(init_l2_allocator_key, builder); - } else if (TensorPlacementUtils::IsOnDevice(desc.placement)) { - bg::ValueHolderPtr l2_allocators = GetOrCreateL2Allocators(desc, *this); - GE_ASSERT_NOTNULL(l2_allocators); - - auto builder = [&]() -> bg::ValueHolderPtr { - return bg::ValueHolder::CreateDataOutput( - "CreateInitL2Allocator", - {GetOrCreateL1Allocator(desc), bg::HolderOnInit(l2_allocators), GetStreamById(kDefaultMainStreamId)}, 2)[0]; - }; - init_l2_allocator = GetOrCreateUniqueValueHolder(init_l2_allocator_key, builder); - } else { - GELOGE(ge::PARAM_INVALID, "Unsupported placement %s.", desc.GetKey().c_str()); - return nullptr; - } - GE_ASSERT_NOTNULL(init_l2_allocator); - - // 将InitL2Allocator放到init的输出上,保证其在根图析构时 析构 - bg::FrameSelector::OnInitRoot([&init_l2_allocator]()->std::vector { - return {init_l2_allocator}; - }); - return init_l2_allocator; -} - -bg::ValueHolderPtr LoweringGlobalData::GetStreamById(int64_t logic_stream_id) const { - ExecuteGraphType graph_type = ExecuteGraphType::kMain; - if (CurrentOnInitGraph()) { - graph_type = ExecuteGraphType::kInit; - GE_ASSERT_TRUE(logic_stream_id == kDefaultMainStreamId); - } - const auto split_rt_streams = streams_.holders[static_cast(graph_type)]; - GE_ASSERT_TRUE(static_cast(split_rt_streams.size()) > logic_stream_id); - return split_rt_streams[logic_stream_id]; -} - -bg::ValueHolderPtr LoweringGlobalData::GetNotifyById(int64_t logic_notify_id) const { - ExecuteGraphType graph_type = ExecuteGraphType::kMain; - if (CurrentOnInitGraph()) { - graph_type = ExecuteGraphType::kInit; - } - const auto rt_notifies = notifies_.holders[static_cast(graph_type)]; - GE_ASSERT_TRUE(static_cast(rt_notifies.size()) > logic_notify_id, - "notify id [%ld] is invalid, total usable notify nums:[%zu].", logic_notify_id, rt_notifies.size()); - return rt_notifies[logic_notify_id]; -} - -/* - * Init - * +-------------------------------------------------+ - * | Data(-1) Const(stream num = 1) | - * | \ / | - * | SplitRtStreams Const(stream_num) | - * | \ / | - * | InnerNetoutput | - * +-------------------------------------------------+ - * Main - * +-------------------------------------------------+ - * | Data(-1) InnerData(stream_num) | - * | \ / | - * | SplitRtStreams | - * | | | - * | | - * +-------------------------------------------------+ - */ -std::vector LoweringGlobalData::LoweringAndSplitRtStreams(int64_t stream_num) { - ExecuteGraphType graph_type = ExecuteGraphType::kMain; - if (CurrentOnInitGraph()) { - graph_type = ExecuteGraphType::kInit; - GE_ASSERT_TRUE(stream_num == 1); - } - - if (graph_type == ExecuteGraphType::kMain) { - is_single_stream_scene_ = (stream_num <= 1); - } - const auto stream_num_holder = bg::ValueHolder::CreateConst(&stream_num, sizeof(stream_num)); - GE_ASSERT_NOTNULL(stream_num_holder); - auto execute_arg_streams = bg::ValueHolder::CreateFeed(-1); - GE_ASSERT_NOTNULL(execute_arg_streams); - auto streams = - bg::ValueHolder::CreateDataOutput("SplitRtStreams", {execute_arg_streams, stream_num_holder}, stream_num); - streams_.holders[static_cast(graph_type)] = streams; - return streams_.holders[static_cast(graph_type)]; -} - -void LoweringGlobalData::SetRtNotifies(const std::vector ¬ify_holders) { - ExecuteGraphType graph_type = ExecuteGraphType::kMain; - if (CurrentOnInitGraph()) { - graph_type = ExecuteGraphType::kInit; - } - notifies_.holders[static_cast(graph_type)] = notify_holders; -} - -bg::ValueHolderPtr LoweringGlobalData::GetOrCreateAllL2Allocators() { - return GetOrCreateL2Allocators({kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}, *this); -} -} // namespace gert diff --git a/exe_graph/lowering/scoped_current_frame.h b/exe_graph/lowering/scoped_current_frame.h deleted file mode 100644 index bb3db186b0bb74e5086aa7251677f1b68a182461..0000000000000000000000000000000000000000 --- a/exe_graph/lowering/scoped_current_frame.h +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_EXE_GRAPH_LOWERING_SCOPED_CURRENT_FRAME_H_ -#define METADEF_CXX_EXE_GRAPH_LOWERING_SCOPED_CURRENT_FRAME_H_ -#include "exe_graph/lowering/graph_frame.h" -#include "value_holder_inner.h" -namespace gert { -namespace bg { -class ScopedCurrentFrame { - public: - explicit ScopedCurrentFrame(GraphFrame *frame) { - backup_graph_frame_ = GetCurrentFrame(); - SetCurrentFrame(frame); - } - ~ScopedCurrentFrame() { - SetCurrentFrame(backup_graph_frame_); - } - - private: - GraphFrame *backup_graph_frame_ = nullptr; -}; -} // namespace bg -} // namespace gert -#endif // METADEF_CXX_EXE_GRAPH_LOWERING_SCOPED_CURRENT_FRAME_H_ diff --git a/exe_graph/lowering/shape_utils.cc b/exe_graph/lowering/shape_utils.cc deleted file mode 100644 index f3d9d198b1353a63ce5d3d297f02ad0ff88799e3..0000000000000000000000000000000000000000 --- a/exe_graph/lowering/shape_utils.cc +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/lowering/shape_utils.h" -#include "graph/utils/type_utils.h" -#include "graph/utils/math_util.h" -#include "inc/common/checker.h" - -namespace gert { -const Shape g_vec_1_shape = {1}; - -ge::graphStatus CalcAlignedSizeByShape(const Shape &shape, ge::DataType data_type, uint64_t &ret_tensor_size) { - constexpr uint64_t kAlignBytes = 32U; - auto shape_size = shape.GetShapeSize(); - int64_t cal_size = 0; - if (data_type == ge::DT_STRING) { - uint32_t type_size = 0U; - GE_ASSERT_TRUE(ge::TypeUtils::GetDataTypeLength(data_type, type_size)); - if (ge::MulOverflow(shape_size, static_cast(type_size), cal_size)) { - GELOGE(ge::GRAPH_FAILED, "[Calc][TensorSizeByShape] shape_size[%ld] multiplied by type_size[%u] overflowed!", - shape_size, type_size); - return ge::GRAPH_FAILED; - } - } else { - cal_size = ge::GetSizeInBytes(shape_size, data_type); - } - if (cal_size < 0) { - GELOGE(ge::GRAPH_FAILED, "[Calc][TensorSizeByShape] shape_size[%" PRId64 "] data_type[%s] failed", shape_size, - ge::TypeUtils::DataTypeToSerialString(data_type).c_str()); - return ge::GRAPH_FAILED; - } - - // 不可能溢出,因为ret最大值也只有int64的最大值 - ret_tensor_size = ge::RoundUp(static_cast(cal_size), kAlignBytes) + kAlignBytes; - return ge::GRAPH_SUCCESS; -} -} // namespace gert diff --git a/exe_graph/lowering/tiling_context_builder.cc b/exe_graph/lowering/tiling_context_builder.cc deleted file mode 100644 index 828f8616f1ed444fe5c8bf09eef99e09f74269b2..0000000000000000000000000000000000000000 --- a/exe_graph/lowering/tiling_context_builder.cc +++ /dev/null @@ -1,261 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/lowering/tiling_context_builder.h" - -#include "exe_graph/lowering/bg_kernel_context_extend.h" -#include "lowering/data_dependent_interpreter.h" -#include "graph/compute_graph.h" -#include "graph/operator.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/node_utils_ex.h" -#include "graph/debug/ge_util.h" -#include "graph/def_types.h" -#include "common/checker.h" -#include "graph/debug/ge_util.h" - -namespace gert { -namespace { -void GetStorageShape(const ge::GeTensorDesc &tensor_desc, gert::StorageShape &storage_shape) { - const auto &storage_dims = tensor_desc.GetShape().GetDims(); - for (const auto &dim : storage_dims) { - (void)storage_shape.MutableStorageShape().AppendDim(dim); - } - const auto &origin_dims = tensor_desc.GetOriginShape().GetDims(); - for (const auto &dim : origin_dims) { - (void)storage_shape.MutableOriginShape().AppendDim(dim); - } -} -} // namespace - -TilingContextBuilder &TilingContextBuilder::CompileInfo(void *compile_info) { - compile_info_ = compile_info; - return *this; -} -TilingContextBuilder &TilingContextBuilder::PlatformInfo(void *platform_info) { - platform_info_ = platform_info; - return *this; -} -TilingContextBuilder &TilingContextBuilder::Deterministic(int32_t deterministic) { - deterministic_ = deterministic; - return *this; -} - -TilingContextBuilder &TilingContextBuilder::TilingData(void *tiling_data) { - outputs_[TilingContext::kOutputTilingData] = tiling_data; - return *this; -} -TilingContextBuilder &TilingContextBuilder::Workspace(ContinuousVector *workspace) { - outputs_[TilingContext::kOutputWorkspace] = workspace; - return *this; -} -TilingContextBuilder &TilingContextBuilder::SpaceRegistry(const OpImplSpaceRegistryPtr &space_registry) { - space_registries_[static_cast(ge::OppImplVersion::kOpp)] = space_registry; - use_registry_v2_ = false; - return *this; -} -TilingContextBuilder &TilingContextBuilder::SpaceRegistries(const OpImplSpaceRegistryArray &space_registries) { - space_registries_ = space_registries; - use_registry_v2_ = false; - return *this; -} - -TilingContextBuilder &TilingContextBuilder::SetSpaceRegistryV2(const OpImplSpaceRegistryV2Ptr &space_registry, OppImplVersionTag version_tag) { - if (version_tag >= OppImplVersionTag::kVersionEnd) { - GELOGE(ge::PARAM_INVALID, "version_tag %d is invalid", static_cast(version_tag)); - return *this; - } - space_registries_v2_[static_cast(version_tag)] = space_registry; - use_registry_v2_ = true; - return *this; -} - -ge::graphStatus TilingContextBuilder::GetDependInputTensorAddr(const ge::Operator &op, const size_t input_idx, - TensorAddress &address) { - auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(op); - GE_ASSERT_NOTNULL(op_desc); - auto depend_tensor = ge::ComGraphMakeUnique(); - depend_ge_tensor_holders_.emplace_back(std::move(depend_tensor)); - GE_ASSERT_NOTNULL(depend_ge_tensor_holders_.back()); - auto input_name = op_desc->GetValidInputNameByIndex(static_cast(input_idx)); - if (op.GetInputConstData(input_name.c_str(), *(depend_ge_tensor_holders_.back().get())) == ge::GRAPH_SUCCESS) { - address = depend_ge_tensor_holders_.back()->GetData(); - } else { - address = nullptr; - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus TilingContextBuilder::BuildRtTensor(const ge::GeTensorDesc &tensor_desc, - ConstTensorAddressPtr address, - std::unique_ptr &rt_tensor_holder) const { - gert::StorageShape storage_shape; - GetStorageShape(tensor_desc, storage_shape); - - rt_tensor_holder = ge::ComGraphMakeUnique(sizeof(gert::Tensor)); - GE_ASSERT_NOTNULL(rt_tensor_holder, "Create context holder inputs failed."); - auto rt_tensor = ge::PtrToPtr(rt_tensor_holder.get()); - rt_tensor->SetDataType(tensor_desc.GetDataType()); - rt_tensor->MutableStorageShape() = storage_shape.GetStorageShape(); - rt_tensor->MutableOriginShape() = storage_shape.GetOriginShape(); - rt_tensor->MutableFormat().SetStorageFormat(tensor_desc.GetFormat()); - rt_tensor->MutableFormat().SetOriginFormat(tensor_desc.GetOriginFormat()); - (void)rt_tensor->MutableTensorData().SetAddr(address, nullptr); - rt_tensor->MutableTensorData().SetPlacement(gert::kOnHost); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus TilingContextBuilder::BuildRTInputTensors(const ge::Operator &op) { - const auto node = ge::NodeUtilsEx::GetNodeFromOperator(op); - auto shared_node = const_cast(node.get())->shared_from_this(); - std::shared_ptr ddi = nullptr; - if (use_registry_v2_) { - ddi = ge::ComGraphMakeShared(shared_node->GetOpDesc(), space_registries_v2_); - } else { - ddi = ge::ComGraphMakeShared(shared_node->GetOpDesc(), space_registries_); - } - GE_ASSERT_NOTNULL(ddi); - const auto op_desc = node->GetOpDescBarePtr(); - GE_ASSERT_NOTNULL(op_desc); - - size_t valid_input_idx = 0U; - const auto &all_in_data_anchors = node->GetAllInDataAnchorsPtr(); - for (size_t i = 0U; i < all_in_data_anchors.size(); ++i) { - GE_ASSERT_NOTNULL(all_in_data_anchors.at(i)); - if (all_in_data_anchors.at(i)->GetPeerOutAnchor() == nullptr) { - continue; - } - TensorAddress address = nullptr; - bool is_data_dependent = false; - GE_ASSERT_SUCCESS(ddi->IsDataDependent(static_cast(valid_input_idx), is_data_dependent)); - bool is_tiling_dependent = false; - if (!is_data_dependent) { - GE_ASSERT_SUCCESS(ddi->IsTilingInputDataDependent(static_cast(valid_input_idx), is_tiling_dependent)); - } - GELOGD("Node: %s input: %zu data/tiling depend flag: %d/%d", node->GetNamePtr(), valid_input_idx, is_data_dependent, - is_tiling_dependent); - is_data_dependent = is_data_dependent || is_tiling_dependent; - if (is_data_dependent) { - GE_ASSERT_GRAPH_SUCCESS(GetDependInputTensorAddr(op, valid_input_idx, address)); - } - std::unique_ptr tensor_holder; - const auto &valid_op_desc = op_desc->GetInputDescPtr(i); - GE_ASSERT_NOTNULL(valid_op_desc); - GE_ASSERT_GRAPH_SUCCESS(BuildRtTensor(*valid_op_desc, address, tensor_holder)); - rt_tensor_holders_.emplace_back(std::move(tensor_holder)); - ++valid_input_idx; - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus TilingContextBuilder::BuildRTOutputShapes(const ge::Operator &op) { - auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(op); - GE_ASSERT_NOTNULL(op_desc); - for (size_t i = 0U; i < op_desc->GetOutputsSize(); ++i) { - gert::StorageShape storage_shape; - GetStorageShape(op_desc->GetOutputDesc(i), storage_shape); - std::unique_ptr tensor_holder; - GE_ASSERT_GRAPH_SUCCESS(BuildRtTensor(op_desc->GetOutputDesc(i), nullptr, tensor_holder)); - GE_ASSERT_NOTNULL(tensor_holder, "Create context holder outputs failed, op[%s]", op_desc->GetNamePtr()); - rt_tensor_holders_.emplace_back(std::move(tensor_holder)); - } - return ge::GRAPH_SUCCESS; -} -KernelContextHolder TilingContextBuilder::Build(const ge::Operator &op, ge::graphStatus &ret) { - ret = ge::GRAPH_FAILED; - KernelContextHolder holder; - if (compile_info_ == nullptr) { - GELOGE(ge::GRAPH_PARAM_INVALID, "Please give tiling context builder compile info."); - return holder; - } - if (platform_info_ == nullptr) { - GELOGE(ge::GRAPH_PARAM_INVALID, "Please give tiling context builder platform info."); - return holder; - } - auto node = ge::NodeUtilsEx::GetNodeFromOperator(op); - std::vector context_inputs; - auto build_ret = BuildRTInputTensors(op); - if (build_ret != ge::GRAPH_SUCCESS) { - GELOGE(ge::GRAPH_PARAM_INVALID, "Fail to BuildRTInputTensors."); - return holder; - } - build_ret = BuildRTOutputShapes(op); - if (build_ret != ge::GRAPH_SUCCESS) { - GELOGE(ge::GRAPH_PARAM_INVALID, "Fail to BuildRTOutputShapes."); - return holder; - } - for (const auto &input_holder : rt_tensor_holders_) { - context_inputs.emplace_back(input_holder.get()); - } - context_inputs.emplace_back(compile_info_); - context_inputs.emplace_back(platform_info_); - context_inputs.emplace_back(nullptr); - context_inputs.emplace_back(reinterpret_cast(deterministic_)); - return base_builder_.Inputs(context_inputs).Outputs(outputs_).Build(node->GetOpDesc(), ret); -} -// 0-n input tensors -// n-m output shapes -// m + 1 compile info -// m + 2 tiling func -// 其中 n为输入个数总和,m为输入输出个数总和 -KernelContextHolder TilingContextBuilder::Build(const ge::Operator &op) { - ge::Status ret = ge::GRAPH_FAILED; - return Build(op, ret); -} - -AtomicTilingContextBuilder &AtomicTilingContextBuilder::CompileInfo(void *compile_info) { - compile_info_ = compile_info; - return *this; -} - -AtomicTilingContextBuilder &AtomicTilingContextBuilder::CleanWorkspaceSizes(ContinuousVector *workspace_sizes) { - worksapce_sizes_ = reinterpret_cast(workspace_sizes); - return *this; -} - -AtomicTilingContextBuilder &AtomicTilingContextBuilder::CleanOutputSizes(const std::vector &output_sizes) { - clean_output_sizes_ = output_sizes; - return *this; -} - -AtomicTilingContextBuilder &AtomicTilingContextBuilder::TilingData(void *tiling_data) { - outputs_[TilingContext::kOutputTilingData] = tiling_data; - return *this; -} -AtomicTilingContextBuilder &AtomicTilingContextBuilder::Workspace(ContinuousVector *workspace) { - outputs_[TilingContext::kOutputWorkspace] = workspace; - return *this; -} -KernelContextHolder AtomicTilingContextBuilder::Build(const ge::Operator &op, ge::graphStatus &ret) { - ret = ge::GRAPH_FAILED; - KernelContextHolder holder; - if (compile_info_ == nullptr) { - GELOGE(ge::GRAPH_PARAM_INVALID, "Please give tiling context builder compile info."); - return holder; - } - std::vector context_inputs; - context_inputs.emplace_back(worksapce_sizes_); - for (const int64_t out_size : clean_output_sizes_) { - context_inputs.emplace_back(reinterpret_cast(out_size)); - } - context_inputs.emplace_back(compile_info_); - auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(op); - return base_builder_.Inputs(context_inputs).Outputs(outputs_).Build(op_desc, ret); -} -// 0 atomic op workspace -// 1~n 待清零的output size -// n+1 compile info -// n+2 atomic tiling func -// 其中 n 为待清零的输出个数, -KernelContextHolder AtomicTilingContextBuilder::Build(const ge::Operator &op) { - ge::graphStatus ret = ge::GRAPH_FAILED; - return Build(op, ret); -} -} // namespace gert diff --git a/exe_graph/lowering/tiling_parse_context_builder.cc b/exe_graph/lowering/tiling_parse_context_builder.cc deleted file mode 100644 index bfe481ffb50142ea640d0abbb7244fdfe36dcf2f..0000000000000000000000000000000000000000 --- a/exe_graph/lowering/tiling_parse_context_builder.cc +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/lowering/tiling_parse_context_builder.h" - -#include "graph/compute_graph.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/debug/ge_util.h" -#include "graph/def_types.h" -#include "common/checker.h" -#include "graph/debug/ge_util.h" - -namespace gert { -TilingParseContextBuilder &TilingParseContextBuilder::CompileJson(const ge::char_t *compile_json) { - compile_json_ = const_cast(compile_json); - return *this; -} - -TilingParseContextBuilder &TilingParseContextBuilder::PlatformInfo(void *platform_info) { - platform_info_ = platform_info; - return *this; -} - -TilingParseContextBuilder &TilingParseContextBuilder::CompileInfoCreatorFunc( - OpImplRegisterV2::CompileInfoCreatorFunc create_func) { - create_func_ = create_func; - return *this; -} - -TilingParseContextBuilder &TilingParseContextBuilder::CompileInfoDeleterFunc( - OpImplRegisterV2::CompileInfoDeleterFunc delete_func) { - delete_func_ = delete_func; - return *this; -} - -KernelContextHolder TilingParseContextBuilder::Build(const ge::Operator &op) { - KernelContextHolder holder; - if (compile_json_ == nullptr) { - GELOGE(ge::GRAPH_PARAM_INVALID, "Compile info is nullptr."); - return holder; - } - if (platform_info_ == nullptr) { - GELOGE(ge::GRAPH_PARAM_INVALID, "Platform info is nullptr."); - return holder; - } - const auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(op); - GE_CHECK_NOTNULL_EXEC(op_desc, return holder); - std::vector> tiling_parse_outputs(1, std::make_pair(nullptr, nullptr)); - if (create_func_ != nullptr && delete_func_ != nullptr) { - tiling_parse_outputs[0].first = create_func_(); - tiling_parse_outputs[0].second = delete_func_; - } - return gert::KernelRunContextBuilder() - .Inputs({compile_json_}) - .Inputs({platform_info_}) - .Inputs({const_cast(op_desc->GetTypePtr())}) - .Outputs(tiling_parse_outputs) - .Build(op_desc); -} -} // namespace gert diff --git a/exe_graph/lowering/value_holder.cc b/exe_graph/lowering/value_holder.cc deleted file mode 100644 index 491de4ac179f56d8ba4a4b53b6ea18083255320f..0000000000000000000000000000000000000000 --- a/exe_graph/lowering/value_holder.cc +++ /dev/null @@ -1,673 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/lowering/value_holder.h" -#include "value_holder_inner.h" - -#include -#include - -#include -#include "graph/debug/ge_log.h" -#include "graph/debug/ge_op_types.h" -#include "graph/utils/graph_utils.h" -#include "common/checker.h" - -#include "exe_graph/lowering/exe_graph_attrs.h" -#include "exe_graph/lowering/extend_exe_graph.h" -#include "graph/debug/ge_util.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/def_types.h" -#include "graph/fast_graph/execute_graph.h" -namespace gert { -namespace bg { -namespace { -constexpr const ge::char_t *kInnerDataNodes = "_inner_data_nodes"; -thread_local std::deque> graph_frames; -thread_local GraphFrame *current_frame; -bool IsGraphOutType(const char *node_type) { - return strcmp(kNetOutput, node_type) == 0 || strcmp(kInnerNetOutput, node_type) == 0; -} -ge::OpDescPtr CreateOpDesc(const std::string &node_name, const char *node_type, size_t in_count, size_t out_count) { - auto op_desc = ge::MakeShared(node_name, node_type); - GE_ASSERT_NOTNULL(op_desc); - for (size_t i = 0; i < in_count; ++i) { - if (op_desc->AddInputDesc(ge::GeTensorDesc()) != ge::GRAPH_SUCCESS) { - GE_LOGE("Failed to create OpDesc for node %s, io-count %zu/%zu, add input desc %zu failed ", node_name.c_str(), - in_count, out_count, i); - return nullptr; - } - } - for (size_t i = 0; i < out_count; ++i) { - if (op_desc->AddOutputDesc(ge::GeTensorDesc()) != ge::GRAPH_SUCCESS) { - GE_LOGE("Failed to create OpDesc for node %s, io-count %zu/%zu, add output desc %zu failed ", node_name.c_str(), - in_count, out_count, i); - return nullptr; - } - } - return op_desc; -} - -struct ConnectionPathPoint { - ge::FastNode *node; - ge::ExecuteGraph *graph; -}; - -ge::EdgeDstEndpoint EnsureHasDataEdge(ge::FastNode *src, int32_t src_index, ge::FastNode *cur_node) { - GE_ASSERT_NOTNULL(cur_node); - for (const auto edge : src->GetOutEdgesRefByIndex(src_index)) { - if (edge == nullptr) { - continue; - } - auto dst_node = edge->dst; - GE_ASSERT_NOTNULL(dst_node); - if (dst_node == cur_node) { - return {cur_node, edge->dst_input}; - } - } - - const auto index = cur_node->GetDataInNum(); - GE_ASSERT_SUCCESS(ge::FastNodeUtils::AppendInputEdgeInfo(cur_node, index + 1)); - - auto graph = cur_node->GetExtendInfo()->GetOwnerGraphBarePtr(); - GE_ASSERT_NOTNULL(graph); - - GE_ASSERT_NOTNULL(graph->AddEdge(src, src_index, cur_node, static_cast(index))); - return {cur_node, static_cast(index)}; -} - -ge::Status GetTempFrameByGraph(ge::ExecuteGraph *graph, std::unique_ptr &frame) { - GE_ASSERT_TRUE(!GetGraphFrames().empty(), "Failed to do frame selection, there is no root-frame exists"); - const auto root_frame = GetGraphFrames().begin()->get(); - GE_ASSERT_NOTNULL(root_frame, "Failed to find the root frame"); - - GE_ASSERT_NOTNULL(graph); - frame = ge::ComGraphMakeUnique(graph->shared_from_this(), *root_frame); - GE_ASSERT_NOTNULL(frame); - const auto current_graph_frame = ValueHolder::GetCurrentFrame(); - GE_ASSERT_NOTNULL(current_graph_frame); - frame->SetCurrentComputeNode(current_graph_frame->GetCurrentComputeNode()); - - return ge::SUCCESS; -} - -ge::FastNode *EnsureHasData(const ConnectionPathPoint &point, int32_t index, bool &new_created) { - std::unique_ptr frame; - GE_ASSERT_SUCCESS(GetTempFrameByGraph(point.graph, frame)); - ge::FastNode *data = nullptr; - if (!FindValFromMapExtAttr(frame->GetExecuteGraph().get(), kInnerDataNodes, index, data)) { - data = ValueHolder::AddNode(kInnerData, 0, 1, *frame); - GE_ASSERT_NOTNULL(data); - GE_ASSERT_TRUE(ge::AttrUtils::SetInt(data->GetOpDescBarePtr(), ge::ATTR_NAME_INDEX, index)); - AddKVToMapExtAttr(frame->GetExecuteGraph().get(), kInnerDataNodes, index, data); - new_created = true; - } - return data; -} - -ge::Status GetOutsideGuarderType(const ge::FastNode *node, const ge::FastNode *src_node_from_parent_graph, - std::string &guarder_type) { - const auto inside_guarder_type = ge::AttrUtils::GetStr(node->GetOpDescBarePtr(), kGuarderNodeType); - // 因为透传祖先图的valueholer而产生的InnerData都是ValueHolder类自行产生的,此时产生InnerData的时候不会追加guarder, - // 所以理论上InnerData是不会带有guarder的, 此处只校验只有outside guarder场景,子图内部Innerdata有guarder属于异常场景 - GE_ASSERT_TRUE(inside_guarder_type == nullptr); - - const auto op_desc = src_node_from_parent_graph->GetOpDescBarePtr(); - GE_ASSERT_NOTNULL(op_desc); - const auto outside_guarder_type = ge::AttrUtils::GetStr(op_desc, kNodeWithGuarderOutside); - if (outside_guarder_type != nullptr) { - guarder_type = *outside_guarder_type; - return ge::SUCCESS; - } - const auto guarder_node_type = ge::AttrUtils::GetStr(op_desc, kGuarderNodeType); - if (guarder_node_type != nullptr) { - guarder_type = *guarder_node_type; - } - return ge::SUCCESS; -} - -ge::EdgeSrcEndpoint ConnectFromParents(ge::FastNode *src, int32_t src_index, const ge::FastNode *dst) { - auto next_graph = dst->GetExtendInfo()->GetOwnerGraphBarePtr(); - const auto src_graph = src->GetExtendInfo()->GetOwnerGraphBarePtr(); - if (src_graph != next_graph) { - std::stack connect_path; - - bool full_path = false; - while (next_graph != nullptr) { - const auto parent_node = next_graph->GetParentNodeBarePtr(); - if (parent_node == nullptr) { - // log out of loop scope - break; - } - connect_path.push({parent_node, next_graph}); - next_graph = parent_node->GetExtendInfo()->GetOwnerGraphBarePtr(); - if (next_graph == src_graph) { - full_path = true; - break; - } - } - - if (!full_path) { - GE_LOGE( - "Failed to connect from %s index %d to node %s, the src node does not on the graph or on its parent graphs", - src->GetNamePtr(), src_index, dst->GetNamePtr()); - return {nullptr, ge::kInvalidEdgeIndex}; - } - - while (!connect_path.empty()) { - auto point = std::move(connect_path.top()); - connect_path.pop(); - - const auto dst_endpoint = EnsureHasDataEdge(src, src_index, point.node); - GE_ASSERT_NOTNULL(dst_endpoint.node); - - bool new_created = false; - const auto data_node = EnsureHasData(point, dst_endpoint.index, new_created); - GE_ASSERT_NOTNULL(data_node); - - std::string guarder_type; - GE_ASSERT_SUCCESS(GetOutsideGuarderType(data_node, src, guarder_type)); - if ((new_created) && (!guarder_type.empty())) { - (void)ge::AttrUtils::SetStr(data_node->GetOpDescBarePtr(), kNodeWithGuarderOutside, guarder_type); - } - - src = data_node; - src_index = 0; - } - } - return {src, src_index}; -} - -ge::graphStatus AddDataEdge(ge::FastNode *src, int32_t src_index, ge::FastNode *dst, int32_t dst_index) { - auto src_endpoint = ConnectFromParents(src, src_index, dst); - if (src_endpoint.node == nullptr) { - GE_LOGE("Failed to connect from %s(%d) to %s(%d), connect from parents failed", src->GetNamePtr(), src_index, - dst->GetNamePtr(), dst_index); - return ge::GRAPH_FAILED; - } - auto graph = dst->GetExtendInfo()->GetOwnerGraphBarePtr(); - GE_ASSERT_NOTNULL(graph); - auto edge = graph->AddEdge(src_endpoint.node, src_endpoint.index, dst, dst_index); - if (edge == nullptr) { - GE_LOGE("Failed to connect edge from %s:%d to %s:%d", src->GetNamePtr(), src_index, - dst->GetNamePtr(), dst_index); - return ge::GRAPH_FAILED; - } - return ge::GRAPH_SUCCESS; -} - -HyperStatus AddDependencyBetweenNodes(ge::FastNode *src, ge::FastNode *dst) { - auto src_graph = src->GetExtendInfo()->GetOwnerGraphBarePtr(); - auto dst_graph = dst->GetExtendInfo()->GetOwnerGraphBarePtr(); - if (src_graph != dst_graph) { - return HyperStatus::ErrorStatus("The source node %s(%s) and dst node %s(%s) does not on the same graph", - src->GetNamePtr(), src->GetTypePtr(), dst->GetNamePtr(), - dst->GetTypePtr()); - } - if (src_graph == nullptr) { - return HyperStatus::ErrorStatus("The source node %s(%s) and dst node %s(%s) does not on the graph", - src->GetNamePtr(), src->GetTypePtr(), dst->GetNamePtr(), - dst->GetTypePtr()); - } - if (src_graph->AddEdge(src, ge::kControlEdgeIndex, - dst, ge::kControlEdgeIndex) == nullptr) { - return HyperStatus::ErrorStatus("Failed to add control edge from %s to %s", src->GetNamePtr(), - dst->GetNamePtr()); - } - return HyperStatus::Success(); -} - -ge::graphStatus AddDependencyToGuarder(ge::FastNode *src, ge::FastNode *guarder) { - auto guarder_graph = guarder->GetExtendInfo()->GetOwnerGraphBarePtr(); - GE_ASSERT_NOTNULL(guarder_graph); - ge::FastNode *current_node = src; - while (current_node->GetExtendInfo()->GetOwnerGraphBarePtr() != guarder_graph) { - auto owner_graph = current_node->GetExtendInfo()->GetOwnerGraphBarePtr(); - GE_ASSERT_NOTNULL(owner_graph); - auto parent_node = owner_graph->GetParentNodeBarePtr(); - GE_ASSERT_NOTNULL(parent_node, - "Failed to add dependency from node %s(%s) to guarder %s(%s), the guarder node does not on the " - "same graph or the parent graphs of the source node", - src->GetNamePtr(), src->GetTypePtr(), guarder->GetNamePtr(), guarder->GetTypePtr()); - current_node = const_cast(parent_node); - } - GE_ASSERT_HYPER_SUCCESS(AddDependencyBetweenNodes(current_node, guarder)); - return ge::GRAPH_SUCCESS; -} - -ge::NodePtr GetComputeNodeByIndex(const GraphFrame &frame, size_t index) { - auto &indexes_to_node = frame.GetIndexesToNode(); - GE_ASSERT_TRUE(indexes_to_node.size() > index, "The current compute node index %zu out of range", index); - return indexes_to_node[index]; -} -} // namespace -std::atomic ValueHolder::id_generator_{0}; -ValueHolder::~ValueHolder() = default; - -ValueHolder::ValueHolder() - : id_(id_generator_++), type_(ValueHolderType::kValueHolderTypeEnd), - fast_node_(nullptr), index_(0), placement_(0) {} - -bool ValueHolder::IsOk() const noexcept { - return error_msg_ == nullptr; -} -ValueHolder::ValueHolderType ValueHolder::GetType() const noexcept { - return type_; -} - -ge::FastNode *ValueHolder::GetFastNode() const noexcept { - return fast_node_; -} - -int32_t ValueHolder::GetOutIndex() const noexcept { - return index_; -} -int64_t ValueHolder::GetId() const noexcept { - return id_; -} - -ge::ExecuteGraph *ValueHolder::GetExecuteGraph() const noexcept { - return fast_node_->GetExtendInfo()->GetOwnerGraphBarePtr(); -} - -ValueHolderPtr ValueHolder::CreateError(const char *fmt, va_list arg) { - auto value_holder = std::shared_ptr(new (std::nothrow) ValueHolder()); - GE_ASSERT_NOTNULL(value_holder); - value_holder->error_msg_ = std::unique_ptr(CreateMessage(fmt, arg)); - return value_holder; -} -ValueHolderPtr ValueHolder::CreateError(const char *fmt, ...) { - va_list arg; - va_start(arg, fmt); - auto holder = CreateError(fmt, arg); - va_end(arg); - return holder; -} -std::string ValueHolder::GenerateNodeName(const char *node_type, const GraphFrame &frame) { - std::string node_name(node_type); - const auto ¤t_compute_node = frame.GetCurrentComputeNode(); - if (current_compute_node != nullptr) { - node_name.append("_").append(current_compute_node->GetName()); - } - node_name.append("_").append(std::to_string(id_generator_)); - ++id_generator_; - return node_name; -} - -ge::FastNode *ValueHolder::AddNode(const char *node_type, size_t input_count, size_t output_count, - const GraphFrame &frame) { - auto graph = frame.GetExecuteGraph(); - GE_ASSERT_NOTNULL(graph); - - auto node = graph->AddNode(CreateOpDesc(GenerateNodeName(node_type, frame), node_type, input_count, output_count)); - GE_ASSERT_NOTNULL(node); - - // add compute node info index - size_t index; - if (frame.GetCurrentNodeIndex(index)) { - if (!ge::AttrUtils::SetInt(node->GetOpDescBarePtr(), kComputeNodeIndex, static_cast(index))) { - GE_LOGE("Failed to add node %s, add ComputeNodeIndex failed", node_type); - return nullptr; - } - } - - return node; -} - -ge::FastNode *ValueHolder::CreateNode(const char *node_type, const std::vector &inputs, - size_t out_count) { - auto frame = GetCurrentFrame(); - if (frame == nullptr) { - GE_LOGE("The current frame does not exist, " - "the function ValueHolder::PushGraphFrame should be called before construct the graph"); - return nullptr; - } - auto node = ValueHolder::AddNode(node_type, inputs.size(), out_count, *frame); - - /* - * todo 检查是否有子图向父图连接的场景,这种场景需要报错 - * 父图向子图连接的场景,为父图节点创建一个InnerData - */ - for (size_t i = 0U; i < inputs.size(); ++i) { - GE_ASSERT_NOTNULL(inputs[i]); - GE_ASSERT_NOTNULL(inputs[i]->fast_node_); - GE_ASSERT_SUCCESS(AddDataEdge(inputs[i]->fast_node_, inputs[i]->index_, node, static_cast(i))); - if (inputs[i]->guarder_ != nullptr && !IsGraphOutType(node_type)) { - GE_ASSERT_SUCCESS(AddDependencyToGuarder(node, inputs[i]->guarder_->GetFastNode())); - } - } - return node; -} - -ValueHolderPtr ValueHolder::CreateMateFromNode(ge::FastNode *node, int32_t index, ValueHolderType type) { - auto holder = std::shared_ptr(new (std::nothrow) ValueHolder()); - GE_ASSERT_NOTNULL(holder); - - holder->type_ = type; - holder->fast_node_ = node; - holder->index_ = index; - holder->op_desc_ = holder->fast_node_->GetOpDescPtr(); - return holder; -} - -void ValueHolder::SetErrorMsg(const char *fmt, va_list arg) { - error_msg_ = std::unique_ptr(CreateMessage(fmt, arg)); -} - -std::vector ValueHolder::CreateDataOutput(const char *node_type, - const std::vector &inputs, - size_t out_count) { - auto node = CreateNode(node_type, inputs, out_count); - if (node == nullptr) { - return {out_count, nullptr}; - } - return CreateFromNodeStart(node, out_count); -} - -/** - * @param data const数据 - * @param size const数据的长度 - * @param is_string 此const是否是个字符串, todo: 当前对string支持的不好 - * @return - */ -ValueHolderPtr ValueHolder::CreateConst(const void *data, size_t size, bool is_string) { - GE_ASSERT_NOTNULL(data); - auto node = ValueHolder::CreateNode(kConst, {}, 1U); - GE_ASSERT_NOTNULL(node); - const auto op_desc = node->GetOpDescBarePtr(); - GE_ASSERT_NOTNULL(op_desc); - GE_ASSERT_SUCCESS(op_desc->SetAttr("is_string", ge::AnyValue::CreateFrom(is_string))); - GE_ASSERT_TRUE(ge::AttrUtils::SetZeroCopyBytes(op_desc, kConstValue, - ge::Buffer::CopyFrom(ge::PtrToPtr(data), size))); - return CreateFromNode(node, 0, ValueHolderType::kConst); -} - -ValueHolderPtr ValueHolder::CreateFeed(int64_t index) { - auto node = ValueHolder::CreateNode(kData, {}, 1U); - GE_ASSERT_NOTNULL(node); - GE_ASSERT_TRUE(ge::AttrUtils::SetInt(node->GetOpDescBarePtr(), kFeedIndex, index)); - return CreateFromNode(node, 0, ValueHolderType::kFeed); -} - -ValueHolderPtr ValueHolder::CreateConstData(int64_t index) { - auto node = ValueHolder::CreateNode(kConstData, {}, 1U); - GE_ASSERT_NOTNULL(node); - GE_ASSERT_TRUE(ge::AttrUtils::SetInt(node->GetOpDescBarePtr(), kFeedIndex, index)); - return CreateFromNode(node, 0, ValueHolderType::kConstData); -} - -HyperStatus ValueHolder::AddInnerDataToKVMap(int32_t index) const noexcept { - if (fast_node_->GetType() != kInnerData) { - return HyperStatus::ErrorStatus("Failed to add node to KVMap, because node type is not InnerData."); - } - GELOGI("Set inner data: %s to kv map, index: %d", fast_node_->GetNamePtr(), index); - AddKVToMapExtAttr(fast_node_->GetExtendInfo()->GetOwnerGraphBarePtr(), - kInnerDataNodes, index, fast_node_); - return HyperStatus::Success(); -} - -ValueHolderPtr ValueHolder::CreateSingleDataOutput(const char *node_type, const std::vector &inputs) { - auto holders = CreateDataOutput(node_type, inputs, 1U); - if (holders.empty()) { - return nullptr; - } - return holders[0]; -} - -HyperStatus ValueHolder::AddDependency(const ValueHolderPtr &src, const ValueHolderPtr &dst) { - if (src == nullptr || src->GetFastNode() == nullptr) { - return HyperStatus::ErrorStatus("Failed to add control ege, because the src does not have a node."); - } - if (dst == nullptr || dst->GetFastNode() == nullptr) { - return HyperStatus::ErrorStatus("Failed to add control ege, because the dst does not have a node."); - } - if (src->GetFastNode() == dst->GetFastNode()) { - GELOGW("Add dependency between the same node %s, skip", src->GetFastNode()->GetNamePtr()); - return HyperStatus::Success(); - } - return AddDependencyBetweenNodes(src->GetFastNode(), dst->GetFastNode()); -} - -ge::graphStatus ValueHolder::AppendInputs(const std::vector &src) { - const uint32_t start_index = fast_node_->GetDataInNum(); - GE_ASSERT_SUCCESS( - ge::FastNodeUtils::AppendInputEdgeInfo(fast_node_, start_index + static_cast(src.size()))); - const auto &dst = CreateFromNode(fast_node_, static_cast(start_index), src.size()); - GE_ASSERT_TRUE(dst.size() == src.size()); - - for (size_t i = 0U; i < src.size(); ++i) { - GE_ASSERT_NOTNULL(src[i]); - GE_ASSERT_NOTNULL(dst[i]); - GE_ASSERT_SUCCESS( - AddDataEdge(src[i]->fast_node_, src[i]->GetOutIndex(), dst[i]->fast_node_, dst[i]->GetOutIndex())); - } - - return ge::SUCCESS; -} - -GraphFrame *ValueHolder::PushGraphFrame() { - if (!graph_frames.empty()) { - GELOGE(ge::INTERNAL_ERROR, - "Failed to push root graph frame, if you want to push a non-root graph frame, specify which ValueHolder the " - "graph frame belongs and the ir name."); - return nullptr; - } - auto graph = ge::MakeShared("ROOT"); - GE_ASSERT_NOTNULL(graph); - auto frame = new (std::nothrow) GraphFrame(graph); - GE_ASSERT_NOTNULL(frame); - return ValueHolder::PushGraphFrame(frame); -} - -GraphFrame *ValueHolder::PushGraphFrame(const ValueHolderPtr &belongs, const char *graph_name) { - GE_ASSERT_NOTNULL(belongs); - GE_ASSERT_NOTNULL(belongs->GetFastNode()); - GE_ASSERT_NOTNULL(graph_name); - if (graph_frames.empty()) { - GELOGE(ge::INTERNAL_ERROR, "Failed to push a non-root graph frame, there is no root graph frames exists"); - return nullptr; - } - auto &parent_frame = *graph_frames.back(); - auto instance_name = GenerateNodeName(graph_name, parent_frame); - auto graph = ge::MakeShared(instance_name); - GE_ASSERT_NOTNULL(graph); - - auto frame_holder = ge::ComGraphMakeUnique(graph, parent_frame); - GE_ASSERT_NOTNULL(frame_holder); - - int64_t compute_node_index; - if (ge::AttrUtils::GetInt(belongs->GetFastNode()->GetOpDescBarePtr(), kComputeNodeIndex, compute_node_index)) { - auto compute_node = GetComputeNodeByIndex(*frame_holder.get(), static_cast(compute_node_index)); - if (compute_node != nullptr) { - frame_holder->SetCurrentComputeNode(compute_node); - } - } - - GE_ASSERT_SUCCESS(ge::FastNodeUtils::AppendSubgraphToNode(belongs->GetFastNode(), graph_name, graph)); - return ValueHolder::PushGraphFrame(frame_holder.release()); -} - -GraphFrame *ValueHolder::PushGraphFrame(GraphFrame *graph_frame) { - GE_ASSERT_NOTNULL(graph_frame); - if (!graph_frames.empty()) { - GE_ASSERT_TRUE((graph_frames.back()->GetExecuteGraph().get() == - graph_frame->GetExecuteGraph()->GetParentGraphBarePtr()), - "Last graph frame in stack %s is not parent graph frame of %s.", - graph_frames.back()->GetExecuteGraph()->GetName().c_str(), - graph_frame->GetExecuteGraph()->GetName().c_str()); - } - graph_frames.emplace_back(graph_frame); - return graph_frames.back().get(); -} - -std::unique_ptr ValueHolder::PopGraphFrame() { - if (graph_frames.empty()) { - return nullptr; - } - auto ret = std::move(graph_frames.back()); - graph_frames.pop_back(); - return ret; -} -GraphFrame *ValueHolder::GetCurrentFrame() { - if (current_frame != nullptr) { - return current_frame; - } - if (graph_frames.empty()) { - return nullptr; - } - return graph_frames.back().get(); -} -void ValueHolder::ClearGraphFrameResource() { - graph_frames.clear(); - current_frame = nullptr; -} -void ValueHolder::SetCurrentComputeNode(const ge::NodePtr &node) { - auto frame = GetCurrentFrame(); - if (frame == nullptr) { - GELOGW("Ignore to add current compute node, the current frame is nullptr"); - return; - } - frame->SetCurrentComputeNode(node); -} -void ValueHolder::AddRelevantInputNode(const ge::NodePtr &node) { - auto frame = GetCurrentFrame(); - if (frame == nullptr) { - GELOGW("Ignore to add relevant input node, the current frame is nullptr"); - } else { - frame->AddRelevantInputNode(node); - } -} -std::unique_ptr ValueHolder::SetScopedCurrentComputeNode( - const ge::NodePtr &node) { - auto frame = GetCurrentFrame(); - GE_ASSERT_NOTNULL(frame); - - auto guarder = ge::ComGraphMakeUnique(frame->GetCurrentComputeNode()); - GE_ASSERT_NOTNULL(guarder); - frame->SetCurrentComputeNode(node); - return guarder; -} - -ge::ExecuteGraph *ValueHolder::GetCurrentExecuteGraph() { - auto frame = GetCurrentFrame(); - GE_ASSERT_NOTNULL(frame); - return frame->GetExecuteGraph().get(); -} - -ge::graphStatus ValueHolder::RefFrom(const ValueHolderPtr &other) { - GE_ASSERT_NOTNULL(fast_node_); - GE_ASSERT_NOTNULL(other); - GE_ASSERT_NOTNULL(other->fast_node_); - - if (index_ < 0 || other->index_ < 0) { - GELOGE(ge::PARAM_INVALID, "Invalid index to ref %d -> %d", index_, other->index_); - return ge::PARAM_INVALID; - } - - const auto op_desc = fast_node_->GetOpDescBarePtr(); - GE_ASSERT_NOTNULL(op_desc); - const auto &td = op_desc->MutableOutputDesc(index_); - GE_ASSERT_NOTNULL(td); - - GE_ASSERT_TRUE(ge::AttrUtils::SetStr(td, kRefFromNode, other->GetFastNode()->GetName())); - GE_ASSERT_TRUE(ge::AttrUtils::SetInt(td, kRefFromIndex, other->index_)); - return ge::GRAPH_SUCCESS; -} - -ValueHolderPtr ValueHolder::CreateVoidGuarder(const char *node_type, const ValueHolderPtr &resource, - const std::vector &args) { - GE_ASSERT_NOTNULL(resource); - std::vector inputs; - inputs.reserve(args.size() + 1); - inputs.emplace_back(resource); - inputs.insert(inputs.cend(), args.cbegin(), args.cend()); - auto ret = CreateVoid(node_type, inputs); - GE_ASSERT_NOTNULL(ret); - GE_ASSERT_NOTNULL(ret->GetFastNode()); - GE_ASSERT_TRUE(ge::AttrUtils::SetInt(ret->GetFastNode()->GetOpDescBarePtr(), kReleaseResourceIndex, 0)); - const auto resource_node = resource->GetFastNode(); - GE_ASSERT_NOTNULL(resource_node); - GE_ASSERT_TRUE(ge::AttrUtils::SetStr(resource_node->GetOpDescBarePtr(), kGuarderNodeType, node_type)); - resource->SetGuarder(ret); - return ret; -} - -const int32_t &ValueHolder::GetPlacement() const { - return placement_; -} -void ValueHolder::SetPlacement(const int32_t &placement) { - placement_ = placement; -} -void ValueHolder::ReleaseAfter(const ValueHolderPtr &other) { - if (guarder_ == nullptr) { - GELOGW("Current holder from node %s index %d does not has a guarder", fast_node_->GetNamePtr(), index_); - return; - } - AddDependency(other, guarder_); -} -std::unique_ptr ValueHolder::PopGraphFrame(const std::vector &outputs, - const std::vector &targets) { - const char *node_type = kNetOutput; - if (graph_frames.size() > 1U) { - // The NetOutput type means "Network outputs", subgraph use InnerNetOutput as output type - node_type = kInnerNetOutput; - } - return PopGraphFrame(outputs, targets, node_type); -} - -std::unique_ptr ValueHolder::PopGraphFrame(const std::vector &outputs, - const std::vector &targets, - const char *out_node_type) { - GE_ASSERT_NOTNULL(out_node_type); - auto out_holder = CreateVoid(out_node_type, outputs); - GE_ASSERT_NOTNULL(out_holder); - if (strcmp(ge::NETOUTPUT, out_node_type) == 0) { - // the name of NetOutput node must be `NetOutput` - GE_ASSERT_NOTNULL(out_holder->GetFastNode()); - const auto op_desc = out_holder->GetFastNode()->GetOpDescBarePtr(); - GE_ASSERT_NOTNULL(op_desc); - op_desc->SetName(out_node_type); - } - - for (const auto &target : targets) { - AddDependency(target, out_holder); - } - return PopGraphFrame(); -} - -ValueHolderPtr ValueHolder::GetGuarder() const noexcept { - return guarder_; -} - -void ValueHolder::SetGuarder(const bg::ValueHolderPtr &guarder) noexcept { - guarder_ = guarder; -} - -void SetCurrentFrame(GraphFrame *frame) { - current_frame = frame; -} -GraphFrame *GetCurrentFrame() { - return current_frame; -} - -std::vector ValueHolder::GetLastExecNodes() { - if (graph_frames.empty()) { - return {}; - } - auto frame = graph_frames.cbegin()->get(); - if (graph_frames.size() > 1U) { - frame = (graph_frames.begin() + 1)->get(); - } - return frame->GetLastExecNodes(); -} -std::deque> &GetGraphFrames() { - return graph_frames; -} -} // namespace bg -} // namespace gert diff --git a/exe_graph/lowering/value_holder_inner.h b/exe_graph/lowering/value_holder_inner.h deleted file mode 100644 index 02ddad6935cadc29713b659a4545c99cef5c9d8a..0000000000000000000000000000000000000000 --- a/exe_graph/lowering/value_holder_inner.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_EXE_GRAPH_LOWERING_VALUE_HOLDER_INNER_H_ -#define METADEF_CXX_EXE_GRAPH_LOWERING_VALUE_HOLDER_INNER_H_ -#include -#include "exe_graph/lowering/builtin_node_types.h" -#include "exe_graph/lowering/graph_frame.h" -namespace gert { -namespace bg { -void SetCurrentFrame(GraphFrame *frame); -GraphFrame *GetCurrentFrame(); -std::deque> &GetGraphFrames(); -} // namespace bg -} // namespace gert -#endif // METADEF_CXX_EXE_GRAPH_LOWERING_VALUE_HOLDER_INNER_H_ diff --git a/exe_graph/lowering/value_holder_utils.cc b/exe_graph/lowering/value_holder_utils.cc deleted file mode 100644 index 044942fcd73d5f28a13ea7d12bdf0a2041fd1cdc..0000000000000000000000000000000000000000 --- a/exe_graph/lowering/value_holder_utils.cc +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/lowering/value_holder_utils.h" - -namespace gert { -namespace bg { -bool ValueHolderUtils::IsNodeValid(const ValueHolderPtr &holder) { - if (holder == nullptr) { - return false; - } - return (holder->fast_node_ != nullptr); -} - -bool ValueHolderUtils::IsNodeEqual(const ValueHolderPtr &src, const ValueHolderPtr &dst) { - if (src == dst) { - return true; - } - if ((src == nullptr) || (dst == nullptr)) { - return false; - } - return (src->fast_node_ == dst->fast_node_); -} - -std::string ValueHolderUtils::GetNodeName(const ValueHolderPtr &holder) { - if (holder == nullptr) { - return ""; - } - return holder->op_desc_->GetName(); -} -const char *ValueHolderUtils::GetNodeNameBarePtr(const ValueHolderPtr &holder) { - if (holder == nullptr) { - return ""; - } - return holder->op_desc_->GetNamePtr(); -} -std::string ValueHolderUtils::GetNodeType(const ValueHolderPtr &holder) { - if (holder == nullptr) { - return ""; - } - return holder->op_desc_->GetType(); -} -const char *ValueHolderUtils::GetNodeTypeBarePtr(const ValueHolderPtr &holder) { - if (holder == nullptr) { - return ""; - } - return holder->op_desc_->GetTypePtr(); -} - -ge::OpDescPtr ValueHolderUtils::GetNodeOpDesc(const ValueHolderPtr &holder) { - if (holder == nullptr) { - return nullptr; - } - return holder->op_desc_; -} -ge::OpDesc *ValueHolderUtils::GetNodeOpDescBarePtr(const ValueHolderPtr &holder) { - if (holder == nullptr) { - return nullptr; - } - return holder->op_desc_.get(); -} - -bool ValueHolderUtils::IsDirectlyControlled(const bg::ValueHolderPtr &src, const bg::ValueHolderPtr &dst) { - if (src == nullptr || dst == nullptr) { - return false; - } - return dst->fast_node_->IsDirectlyControlledByNode(src->fast_node_); -} -} // bg -} // gert diff --git a/exe_graph/stub/Makefile b/exe_graph/stub/Makefile deleted file mode 100755 index 0987161d80fec4eab652c1b674d42bded6b6f664..0000000000000000000000000000000000000000 --- a/exe_graph/stub/Makefile +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) 2024 Huawei Technologies Co., Ltd. -# This file is a part of the CANN Open Software. -# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. -# ====================================================================================================================== - -inc_path := $(shell pwd)/metadef/inc/external/ -out_path := $(shell pwd)/out/exe_graph/lib64/stub/ -stub_path := $(shell pwd)/metadef/exe_graph/stub/ - -mkdir_stub := $(shell mkdir -p $(out_path)) -exe_graph_local_stub := $(shell $(HI_PYTHON) $(stub_path)/gen_stubapi.py $(inc_path) $(out_path)) \ No newline at end of file diff --git a/graph/CMakeLists.txt b/graph/CMakeLists.txt deleted file mode 100644 index 11caf72ba2439507d15511ad7cbf01b25fc7d5fc..0000000000000000000000000000000000000000 --- a/graph/CMakeLists.txt +++ /dev/null @@ -1,473 +0,0 @@ -# Copyright (c) 2024 Huawei Technologies Co., Ltd. -# This file is a part of the CANN Open Software. -# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. -# ====================================================================================================================== - -include(${METADEF_DIR}/cmake/build_type.cmake) - -######### for base ############# -set(GRAPH_BASE_SOURCE_LIST - "type/types.cc" - "normal_graph/anchor.cc" - "attr/ge_attr_value.cc" - "utils/args_format_desc.cc" - "../base/any_value.cc" - "normal_graph/operator.cc" - "normal_graph/operator_impl.cc" - "../base/attr/attr_store.cc" - "attr/attr_group_serialize.cc" - "attr/attr_group_base.cc" - "attr/attr_group_serializer_registry.cc" - "buffer/buffer.cc" - "../base/utils/aligned_ptr.cc" - "normal_graph/compute_graph.cc" - "normal_graph/model.cc" - "serialization/model_serialize.cc" - "normal_graph/node.cc" - "normal_graph/op_desc.cc" - "ir/ir_meta.cc" - "ir/ir_data_type_symbol_store.cc" - "type/sym_dtype.cc" - "attr/ge_attr_define.cc" - "normal_graph/ge_tensor.cc" - "buffer/graph_buffer.cc" - "hcom/hcom_topo_info.cc" - "common/large_bm.cc" - "../base/common/plugin/plugin_manager.cc" - "common/hyper_status.cc" - "detail/attributes_holder.cc" - "utils/anchor_utils.cc" - "utils/enum_attr_utils.cc" - "utils/graph_utils.cc" - "utils/dumper/ge_graph_dumper.cc" - "utils/trace/trace_manager.cc" - "utils/ge_ir_utils.cc" - "utils/node_utils.cc" - "utils/type_utils.cc" - "utils/tensor_utils.cc" - "utils/constant_utils.cc" - "utils/connection_matrix.cc" - "utils/cycle_detector.cc" - "utils/op_type_utils.cc" - "utils/fast_node_utils.cc" - "utils/execute_graph_utils.cc" - "utils/execute_graph_adapter.cc" - "utils/args_format_desc_utils.cc" - "utils/inference_rule.cc" - "option/ge_context.cc" - "option/ge_local_context.cc" - "option/optimization_option.cc" - "option/optimization_option_info.cc" - "../base/utils/file_utils.cc" - "serialization/attr_serializer.cc" - "serialization/string_serializer.cc" - "serialization/data_type_serializer.cc" - "serialization/named_attrs_serializer.cc" - "serialization/bool_serializer.cc" - "serialization/buffer_serializer.cc" - "serialization/float_serializer.cc" - "serialization/int_serializer.cc" - "serialization/tensor_serializer.cc" - "serialization/tensor_desc_serializer.cc" - "serialization/graph_serializer.cc" - "serialization/list_value_serializer.cc" - "serialization/list_list_int_serializer.cc" - "serialization/list_list_float_serializer.cc" - "serialization/attr_serializer_registry.cc" - "serialization/utils/serialization_util.cc" - "cache_policy/cache_policy.cc" - "cache_policy/compile_cache_desc.cc" - "cache_policy/policy_register.cc" - "cache_policy/cache_state.cc" - "cache_policy/policy_management/match_policy/match_policy_exact_only.cc" - "cache_policy/policy_management/match_policy/match_policy_for_exactly_the_same.cc" - "cache_policy/policy_management/aging_policy/aging_policy_lru.cc" - "cache_policy/policy_management/aging_policy/aging_policy_lru_k.cc" - "utils/profiler.cc" - "normal_graph/tensor.cc" -) - -SET(GRAPH_SOURCE_LIST - "type/ascend_string.cc" - "attr/attr_value.cc" - "type/axis_type_info.cc" - "normal_graph/operator_factory.cc" - "normal_graph/operator_factory_impl.cc" - "normal_graph/graph.cc" - "normal_graph/gnode.cc" - "args_format/args_format_serializer.cc" - "args_format/arg_desc_info.cc" - "args_format/arg_desc_info_impl.cc" - "refiner/format_refiner.cc" - "context/inference_context.cc" - "refiner/ref_relation.cc" - "context/resource_context_mgr.cc" - "context/runtime_inference_context.cc" - "refiner/shape_refiner.cc" - "ir/ir_definitions_recover.cc" - "opsproto/opsproto_manager.cc" - "utils/op_desc_utils.cc" - "utils/tuning_utils.cc" - "utils/ffts_graph_utils.cc" - "utils/transformer_utils.cc" - "utils/graph_utils_ex.cc" - "utils/node_utils_ex.cc" - "utils/op_desc_utils_ex.cc" - "utils/graph_thread_pool.cc" - "utils/multi_thread_graph_builder.cc" - "utils/type_utils_ex.cc" - "parallelism/tensor_parallel_attrs.cc" - "utils/screen_printer.cc" - "${METADEF_DIR}/third_party/transformer/src/axis_util.cc" - "${METADEF_DIR}/third_party/transformer/src/transfer_shape_according_to_format.cc" - "${METADEF_DIR}/third_party/transformer/src/expand_dimension.cc" - "${METADEF_DIR}/third_party/transformer/src/transfer_range_according_to_format.cc" - "${METADEF_DIR}/third_party/transformer/src/transfer_shape_utils.cc" - "${METADEF_DIR}/ops/op_imp.cpp" -) - -SET(FAST_GRAPH_SOURCE_LIST - "fast_graph/fast_node.cc" - "fast_graph/execute_graph.cc" -) - -######### libgraph_base.so ############# -add_library(graph_base SHARED - ${GRAPH_BASE_SOURCE_LIST} - ${FAST_GRAPH_SOURCE_LIST} - $ -) - -target_compile_options(graph_base PRIVATE - $<$,$>:-fexceptions> - $<$,$>: -fno-common -Wextra -Wfloat-equal -Wno-array-bounds> - $<$,$>:/MTd> - $<$,$>:/MT>) - -target_compile_definitions(graph_base PRIVATE - $<$,$>:FMK_SUPPORT_DUMP> - $<$:ONLY_COMPILE_OPEN_SRC> - google=ascend_private - $,OS_TYPE=WIN,OS_TYPE=0> - $<$:SECUREC_USING_STD_SECURE_LIB=0 NOMINMAX> -) - -target_include_directories(graph_base PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/metadef_protos - ${METADEF_DIR} -) - -target_link_options(graph_base PRIVATE - -Wl,-Bsymbolic -) - -target_link_libraries(graph_base - PRIVATE - intf_pub - static_mmpa - -Wl,--no-as-needed - c_sec - slog - json - metadef - -Wl,--as-needed - $<$>:-lrt> - -ldl - PUBLIC - metadef_headers -) - -target_link_libraries(graph_base PRIVATE ascend_protobuf error_manager) -target_compile_options(graph_base PRIVATE -O2 -Werror -DNO_METADEF_ABI_COMPATIABLE) - -######### libgraph_base.a ############# -if (NOT ENABLE_OPEN_SRC) - target_clone(graph_base graph_base_static STATIC) - - target_compile_options(graph_base_static PRIVATE - -Os - -fvisibility=hidden - -fvisibility-inlines-hidden - -ffunction-sections - -fdata-sections - ) - - set_target_properties(graph_base_static PROPERTIES - OUTPUT_NAME graph_base - ) -endif () - -######### libgraph.so ############# -add_library(graph SHARED - ${GRAPH_SOURCE_LIST} - $ -) -target_compile_options(graph PRIVATE -DNO_METADEF_ABI_COMPATIABLE) - -target_compile_options(graph PRIVATE - $<$,$>:-fexceptions> - $<$,$>: -fno-common -Wextra -Wfloat-equal -Wno-array-bounds> - $<$,$>:/MTd> - $<$,$>:/MT>) - -target_compile_definitions(graph PRIVATE - $<$,$>:FMK_SUPPORT_DUMP> - $<$:ONLY_COMPILE_OPEN_SRC> - google=ascend_private - $,OS_TYPE=WIN,OS_TYPE=0> - $<$:SECUREC_USING_STD_SECURE_LIB=0 NOMINMAX> -) - -target_include_directories(graph PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${CMAKE_BINARY_DIR} - ${METADEF_DIR} - ${CMAKE_BINARY_DIR}/proto/metadef_protos -) - -target_link_options(graph PRIVATE - -Wl,-Bsymbolic -) - -target_link_libraries(graph - PRIVATE - intf_pub - static_mmpa - -Wl,--no-as-needed - c_sec - slog - json - platform - metadef - -Wl,--as-needed - ascend_protobuf_shared_headers - $<$>:-lrt> - -ldl - PUBLIC - metadef_headers -) - -target_link_libraries(graph PRIVATE graph_base error_manager) -target_compile_options(graph PRIVATE -O2 -Werror) - -if (${ENABLE_OPEN_SRC} STREQUAL "True") -else() - ######### libgraph.a ############# - add_library(graph_share SHARED - ${GRAPH_BASE_SOURCE_LIST} - ${GRAPH_SOURCE_LIST} - ${FAST_GRAPH_SOURCE_LIST} - $ - $ - ) - - target_compile_options(graph_share PRIVATE - $<$,$>:-fexceptions> - $<$,$>: -fno-common -Wextra -Wfloat-equal -Wno-array-bounds> - $<$,$>:/MTd> - $<$,$>:/MT>) - - target_compile_definitions(graph_share PRIVATE - $<$,$>:FMK_SUPPORT_DUMP> - $<$:ONLY_COMPILE_OPEN_SRC> - google=ascend_private - $,OS_TYPE=WIN,OS_TYPE=0> - $<$:SECUREC_USING_STD_SECURE_LIB=0 NOMINMAX> - ) - - target_include_directories(graph_share PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/metadef_protos - ${METADEF_DIR} - ) - - target_link_options(graph_share PRIVATE - -Wl,-Bsymbolic - ) - - target_link_libraries(graph_share - PRIVATE - intf_pub - static_mmpa - -Wl,--no-as-needed - c_sec - slog - json - platform - metadef - -Wl,--as-needed - $<$>:-lrt> - -ldl - PUBLIC - metadef_headers - ) - - target_clone(graph_share graph_static STATIC) - - target_link_libraries(graph_share PRIVATE ascend_protobuf error_manager) - target_link_libraries(graph_static PRIVATE ascend_protobuf_static) - - target_compile_options(graph_share PRIVATE -O2) - target_compile_options(graph_static PRIVATE $<$:-O2 -fPIC -Wextra -Wfloat-equal -Wno-array-bounds>) - - set_target_properties(graph_static PROPERTIES - WINDOWS_EXPORT_ALL_SYMBOLS TRUE - OUTPUT_NAME $,libgraph_share,graph_share> - ) -endif() - -############################################################## -set(STUB_HEADER_LIST - ${METADEF_DIR}/inc/external/graph/ascend_string.h - ${METADEF_DIR}/inc/external/graph/attr_value.h - ${METADEF_DIR}/inc/external/graph/gnode.h - ${METADEF_DIR}/inc/external/graph/graph.h - ${METADEF_DIR}/inc/external/graph/inference_context.h - ${METADEF_DIR}/inc/external/graph/operator_factory.h - ${METADEF_DIR}/inc/external/graph/types.h - ${METADEF_DIR}/inc/external/hcom/hcom_topo_info.h - ${METADEF_DIR}/inc/graph/utils/op_desc_utils.h - ${METADEF_DIR}/inc/graph/utils/op_desc_utils_ex.h - ${METADEF_DIR}/inc/graph/utils/node_utils_ex.h - ${METADEF_DIR}/inc/graph/utils/graph_utils_ex.h - ${METADEF_DIR}/inc/graph/utils/tensor_adapter.h - ${METADEF_DIR}/inc/graph/shape_refiner.h - ${METADEF_DIR}/inc/graph/opsproto_manager.h - ${METADEF_DIR}/inc/graph/runtime_inference_context.h - ${METADEF_DIR}/inc/graph/args_format_desc.h - ${METADEF_DIR}/inc/graph/ir_definitions_recover.h - ${METADEF_DIR}/third_party/transformer/src/axis_constants.h - ${METADEF_DIR}/inc/common/screen_printer.h -) - -# 当前开放的sample样例中仅链接libgraph桩包,考虑兼容性,将libgraph_base里的内容也打到libgraph桩包内。 -# ascend031下会同时使用libgraph桩包和libgraph_base编译和执行样例,而这些符号会影响对libgraph_base的链接, -# 故目前先做隔离,在ascend031下不将这些打到libgraph桩包内。 -if (NOT "${PRODUCT}" STREQUAL "ascend031") - list(APPEND STUB_HEADER_LIST - ${METADEF_DIR}/inc/external/graph/operator.h - ${METADEF_DIR}/inc/external/graph/tensor.h - ${METADEF_DIR}/inc/graph/ge_tensor.h - ) -endif () - -list(TRANSFORM STUB_HEADER_LIST - REPLACE "^.*/([^/]+)\\.h$" "${CMAKE_CURRENT_BINARY_DIR}/stub_\\1.cc" - OUTPUT_VARIABLE STUB_SRC_LIST -) - -add_custom_command( - OUTPUT ${STUB_SRC_LIST} - COMMAND echo "Generating stub files." - && ${HI_PYTHON} ${METADEF_DIR}/tests/stub/gen_stubapi.py ${CMAKE_CURRENT_BINARY_DIR} ${STUB_HEADER_LIST} - && echo "Generating stub files end." -) - -add_custom_target(graph_stub DEPENDS ${STUB_SRC_LIST}) - -############################################################# - -############ stub/libgraph.so ############ -add_library(atc_stub_graph SHARED ${STUB_SRC_LIST}) - -add_dependencies(atc_stub_graph graph_stub) - -target_include_directories(atc_stub_graph PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${CMAKE_BINARY_DIR} - ${METADEF_DIR} -) - -target_compile_options(atc_stub_graph PRIVATE - -Wfloat-equal - -fno-common - -Os - -Werror=return-type -) - -target_link_libraries(atc_stub_graph - PRIVATE - intf_pub - mmpa_headers - slog_headers - $<$>:-lrt> - -ldl - PUBLIC - metadef_headers -) - -set_target_properties(atc_stub_graph PROPERTIES - OUTPUT_NAME graph - LIBRARY_OUTPUT_DIRECTORY atc_stub -) - -############ stub/libgraph.a ############ -if (NOT ENABLE_OPEN_SRC) - target_clone(atc_stub_graph atc_stub_graph_static STATIC) - - add_dependencies(atc_stub_graph_static graph_stub) - - target_compile_options(atc_stub_graph_static PRIVATE - -ffunction-sections - -fdata-sections - ) - - set_target_properties(atc_stub_graph_static PROPERTIES - OUTPUT_NAME graph - ARCHIVE_OUTPUT_DIRECTORY atc_stub - ) -endif () - - -############ fwk_stub/libgraph.so ############ -add_library(fwk_stub_graph SHARED ${STUB_SRC_LIST}) - -add_dependencies(fwk_stub_graph graph_stub) - -target_compile_options(fwk_stub_graph PRIVATE - -Wfloat-equal - -fno-common -) - -target_include_directories(fwk_stub_graph PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${CMAKE_BINARY_DIR} - ${METADEF_DIR} -) - -target_link_libraries(fwk_stub_graph - PRIVATE - intf_pub - static_mmpa - -Wl,--no-as-needed - slog - json - -Wl,--as-needed - $<$>:-lrt> - -ldl - PUBLIC - metadef_headers -) - -set_target_properties(fwk_stub_graph PROPERTIES - OUTPUT_NAME graph - LIBRARY_OUTPUT_DIRECTORY fwk_stub -) -add_subdirectory(ascendc_ir) -add_subdirectory(expression) - -############ install ############ -install(TARGETS atc_stub_graph OPTIONAL - LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}/${CMAKE_SYSTEM_PROCESSOR}/stub -) - -install(TARGETS fwk_stub_graph OPTIONAL - LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}/${CMAKE_SYSTEM_PROCESSOR}/fwk_stub -) diff --git a/graph/args_format/arg_desc_info.cc b/graph/args_format/arg_desc_info.cc deleted file mode 100644 index 4ab4ba7c8dcb6c6467393ab5651cfd7acca3a12c..0000000000000000000000000000000000000000 --- a/graph/args_format/arg_desc_info.cc +++ /dev/null @@ -1,104 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ - -#include "graph/arg_desc_info.h" -#include "arg_desc_info_impl.h" -#include "graph/debug/ge_util.h" -#include "common/checker.h" - -namespace ge { -ArgDescInfo::~ArgDescInfo() {} -ArgDescInfo::ArgDescInfo(ArgDescType arg_type, int32_t ir_index, bool is_folded) { - impl_ = ComGraphMakeUnique(arg_type, ir_index, is_folded); -} - -ArgDescInfo::ArgDescInfo(ArgDescInfoImplPtr &&impl) : impl_(std::move(impl)) {} - -ArgDescInfo::ArgDescInfo(const ArgDescInfo &other) { - impl_ = ComGraphMakeUnique(); - if ((other.impl_ != nullptr) && (impl_ != nullptr)) { - *impl_ = *other.impl_; - } -} -ArgDescInfo::ArgDescInfo(ArgDescInfo &&other) noexcept { - impl_ = std::move(other.impl_); -} -ArgDescInfo &ArgDescInfo::operator=(const ArgDescInfo &other) { - if (&other != this) { - impl_ = ComGraphMakeUnique(); - if ((other.impl_ != nullptr) && (impl_ != nullptr)) { - *impl_ = *other.impl_; - } - } - return *this; -} -ArgDescInfo &ArgDescInfo::operator=(ArgDescInfo &&other) noexcept { - if (&other != this) { - impl_ = std::move(other.impl_); - } - return *this; -} -ArgDescInfo ArgDescInfo::CreateCustomValue(uint64_t custom_value) { - return ArgDescInfo(ArgDescInfoImpl::CreateCustomValue(custom_value)); -} -ArgDescInfo ArgDescInfo::CreateHiddenInput(HiddenInputSubType hidden_type) { - return ArgDescInfo(ArgDescInfoImpl::CreateHiddenInput(hidden_type)); -} -ArgDescType ArgDescInfo::GetType() const { - if (impl_ != nullptr) { - return impl_->GetType(); - } - return ArgDescType::kEnd; -} -uint64_t ArgDescInfo::GetCustomValue() const { - if (impl_ != nullptr) { - return impl_->GetCustomValue(); - } - return std::numeric_limits::max(); -} -graphStatus ArgDescInfo::SetCustomValue(uint64_t custom_value) { - GE_ASSERT_NOTNULL(impl_); - return impl_->SetCustomValue(custom_value); -} -HiddenInputSubType ArgDescInfo::GetHiddenInputSubType() const { - if (impl_ != nullptr) { - return impl_->GetHiddenInputSubType(); - } - return HiddenInputSubType::kEnd; -} -graphStatus ArgDescInfo::SetHiddenInputSubType(HiddenInputSubType hidden_type) { - GE_ASSERT_NOTNULL(impl_); - return impl_->SetHiddenInputSubType(hidden_type); -} - -int32_t ArgDescInfo::GetIrIndex() const { - if (impl_ != nullptr) { - return impl_->GetIrIndex(); - } - return -1; -} - -void ArgDescInfo::SetIrIndex(int32_t ir_index) { - if (impl_ != nullptr) { - impl_->SetIrIndex(ir_index); - } -} - -bool ArgDescInfo::IsFolded() const { - if (impl_ != nullptr) { - return impl_->IsFolded(); - } - return false; -} -void ArgDescInfo::SetFolded(bool is_folded) { - if (impl_ != nullptr) { - impl_->SetFolded(is_folded); - } -} -} \ No newline at end of file diff --git a/graph/args_format/arg_desc_info_impl.cc b/graph/args_format/arg_desc_info_impl.cc deleted file mode 100644 index 7f52edd92d519b55d90b1dce95cb8a8bac4bd04a..0000000000000000000000000000000000000000 --- a/graph/args_format/arg_desc_info_impl.cc +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ - -#include "arg_desc_info_impl.h" -#include "graph/debug/ge_util.h" -#include "common/checker.h" - -namespace ge { -ArgDescInfoImpl::ArgDescInfoImpl(ArgDescType arg_type, int32_t ir_index, bool is_folded) - : arg_type_(arg_type), ir_index_(ir_index), is_folded_(is_folded) {} - -ArgDescInfoImplPtr ArgDescInfoImpl::CreateCustomValue(uint64_t custom_value) { - auto impl_ptr = ComGraphMakeUnique(); - GE_ASSERT_NOTNULL(impl_ptr); - impl_ptr->arg_type_ = ArgDescType::kCustomValue; - impl_ptr->custom_value_ = custom_value; - return impl_ptr; -} - -ArgDescInfoImplPtr ArgDescInfoImpl::CreateHiddenInput(HiddenInputSubType hidden_type) { - auto impl_ptr = ComGraphMakeUnique(); - GE_ASSERT_NOTNULL(impl_ptr); - impl_ptr->arg_type_ = ArgDescType::kHiddenInput; - impl_ptr->hidden_type_ = hidden_type; - return impl_ptr; -} - -ArgDescType ArgDescInfoImpl::GetType() const { - return arg_type_; -} -uint64_t ArgDescInfoImpl::GetCustomValue() const { - return custom_value_; -} -graphStatus ArgDescInfoImpl::SetCustomValue(uint64_t custom_value) { - GE_ASSERT_TRUE(arg_type_ == ArgDescType::kCustomValue, - "Only ArgDescType::kCustomValue arg desc info can set custom value"); - custom_value_ = custom_value; - return SUCCESS; -} -HiddenInputSubType ArgDescInfoImpl::GetHiddenInputSubType() const { - return hidden_type_; -} -graphStatus ArgDescInfoImpl::SetHiddenInputSubType(HiddenInputSubType hidden_type) { - GE_ASSERT_TRUE(arg_type_ == ArgDescType::kHiddenInput, - "Only ArgDescType::kHiddenInput arg desc info can set hidden input sub type"); - hidden_type_ = hidden_type; - return SUCCESS; -} -int32_t ArgDescInfoImpl::GetIrIndex() const { - return ir_index_; -} - -void ArgDescInfoImpl::SetIrIndex(int32_t ir_index) { - ir_index_ = ir_index; -} - -bool ArgDescInfoImpl::IsFolded() const { - return is_folded_; -} -void ArgDescInfoImpl::SetFolded(bool is_folded) { - is_folded_ = is_folded; -} -void ArgDescInfoImpl::SetInnerArgType(AddrType inner_arg_type) { - inner_arg_type_ = inner_arg_type; -} -AddrType ArgDescInfoImpl::GetInnerArgType() const { - return inner_arg_type_; -} -} \ No newline at end of file diff --git a/graph/args_format/arg_desc_info_impl.h b/graph/args_format/arg_desc_info_impl.h deleted file mode 100644 index 09b874d101fa3173e66e2f58e0570354b6bdc8db..0000000000000000000000000000000000000000 --- a/graph/args_format/arg_desc_info_impl.h +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ - -#ifndef METADEF_GRAPH_ARGS_FORMAT_ARG_DESC_INFO_IMPL_H -#define METADEF_GRAPH_ARGS_FORMAT_ARG_DESC_INFO_IMPL_H - -#include "graph/arg_desc_info.h" -#include "graph/utils/args_format_desc_utils.h" - -namespace ge { -class ArgDescInfoImpl { - public: - explicit ArgDescInfoImpl(ArgDescType arg_type, - int32_t ir_index = -1, bool is_folded = false); - ~ArgDescInfoImpl() = default; - ArgDescInfoImpl() = default; - static ArgDescInfoImplPtr CreateCustomValue(uint64_t custom_value); - static ArgDescInfoImplPtr CreateHiddenInput(HiddenInputSubType hidden_type); - ArgDescType GetType() const; - uint64_t GetCustomValue() const; - graphStatus SetCustomValue(uint64_t custom_value); - HiddenInputSubType GetHiddenInputSubType() const; - graphStatus SetHiddenInputSubType(HiddenInputSubType hidden_type); - void SetIrIndex(int32_t ir_index); - int32_t GetIrIndex() const; - bool IsFolded() const; - void SetFolded(bool is_folded); - void SetInnerArgType(AddrType inner_arg_type); - AddrType GetInnerArgType() const; - private: - ArgDescType arg_type_{ArgDescType::kEnd}; - AddrType inner_arg_type_{AddrType::MAX}; - int32_t ir_index_{-1}; - HiddenInputSubType hidden_type_{HiddenInputSubType::kEnd}; - uint64_t custom_value_{0}; - bool is_folded_{false}; -}; -} - -#endif // METADEF_GRAPH_ARGS_FORMAT_ARG_DESC_INFO_IMPL_H \ No newline at end of file diff --git a/graph/args_format/args_format_serializer.cc b/graph/args_format/args_format_serializer.cc deleted file mode 100644 index bc789cdf54b1f1ae933ad56a0814642313a687b7..0000000000000000000000000000000000000000 --- a/graph/args_format/args_format_serializer.cc +++ /dev/null @@ -1,148 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ - -#include "graph/args_format/arg_desc_info_impl.h" -#include -#include "graph/utils/args_format_desc_utils.h" -#include "common/checker.h" - -namespace ge { -namespace { -bool HasIrIndex(AddrType type) { - return (type == AddrType::INPUT) || (type == AddrType::OUTPUT) || - (type == AddrType::INPUT_DESC) || (type == AddrType::OUTPUT_DESC); -} -AddrType TransToAddrType(ArgDescType args_type) { - static const std::unordered_map arg_desc_to_addr_type = { - {ArgDescType::kIrInput, AddrType::INPUT}, {ArgDescType::kIrOutput, AddrType::OUTPUT}, - {ArgDescType::kWorkspace, AddrType::WORKSPACE}, {ArgDescType::kTiling, AddrType::TILING}, - {ArgDescType::kIrInput, AddrType::INPUT_DESC}, {ArgDescType::kIrOutput, AddrType::OUTPUT_DESC}, - {ArgDescType::kHiddenInput, AddrType::HIDDEN_INPUT}, - {ArgDescType::kCustomValue, AddrType::CUSTOM_VALUE}, - {ArgDescType::kIrInputDesc, AddrType::INPUT_DESC}, - {ArgDescType::kIrOutputDesc, AddrType::OUTPUT_DESC}, - {ArgDescType::kInputInstance, AddrType::INPUT_INSTANCE}, - {ArgDescType::kOutputInstance, AddrType::OUTPUT_INSTANCE}}; - auto iter = arg_desc_to_addr_type.find(args_type); - if (iter != arg_desc_to_addr_type.end()) { - return iter->second; - } - return AddrType::MAX; -} - -ArgDescType TransToArgDescType(AddrType addr_type) { - static const std::unordered_map addr_to_arg_desc_type = { - {AddrType::INPUT, ArgDescType::kIrInput}, {AddrType::OUTPUT, ArgDescType::kIrOutput}, - {AddrType::WORKSPACE, ArgDescType::kWorkspace}, {AddrType::TILING, ArgDescType::kTiling}, - {AddrType::HIDDEN_INPUT, ArgDescType::kHiddenInput}, - {AddrType::CUSTOM_VALUE, ArgDescType::kCustomValue}, - {AddrType::INPUT_DESC, ArgDescType::kIrInputDesc}, - {AddrType::INPUT_INSTANCE, ArgDescType::kInputInstance}, - {AddrType::OUTPUT_DESC, ArgDescType::kIrOutputDesc}, - {AddrType::OUTPUT_INSTANCE, ArgDescType::kOutputInstance} - }; - auto iter = addr_to_arg_desc_type.find(addr_type); - if (iter != addr_to_arg_desc_type.end()) { - return iter->second; - } - return ArgDescType::kEnd; -} - -HiddenInputsType TransToHiddenInputType(HiddenInputSubType hidden_sub_type) { - static const std::unordered_map hidden_sub_types = { - {HiddenInputSubType::kHcom, HiddenInputsType::HCOM}, - {HiddenInputSubType::kEnd, HiddenInputsType::MAX} - }; - auto iter = hidden_sub_types.find(hidden_sub_type); - if (iter != hidden_sub_types.end()) { - return iter->second; - } - return HiddenInputsType::MAX; -} -HiddenInputSubType TransToHiddenInputSubType(HiddenInputsType hidden_type) { - static const std::unordered_map hidden_input_types = { - {HiddenInputsType::HCOM, HiddenInputSubType::kHcom}, - {HiddenInputsType::MAX, HiddenInputSubType::kEnd} - }; - auto iter = hidden_input_types.find(hidden_type); - if (iter != hidden_input_types.end()) { - return iter->second; - } - return HiddenInputSubType::kEnd; -} -} -AscendString ArgsFormatSerializer::Serialize(const std::vector &args_format) { - // 将args_desc_info转成arg_desc - std::vector arg_descs; - int32_t hidden_input_index = 0; - int32_t input_instance_index = 0; - int32_t output_instance_index = 0; - for (const auto &arg_desc_info : args_format) { - ArgDesc desc; - desc.addr_type = TransToAddrType(arg_desc_info.GetType()); - // 当内部类型无法被解析出来时,表示这个argDescInfo可能是内部框架生成,尝试使用inner_arg_type做序列化 - if (desc.addr_type == AddrType::MAX) { - GE_ASSERT_NOTNULL(arg_desc_info.impl_); - desc.addr_type = arg_desc_info.impl_->GetInnerArgType(); - } - // kHiddenInput,kInputInstance和kOutputInstance的索引需要单独排序 - if (arg_desc_info.GetType() == ArgDescType::kHiddenInput) { - desc.ir_idx = hidden_input_index; - hidden_input_index++; - } else if (arg_desc_info.GetType() == ArgDescType::kInputInstance) { - desc.ir_idx = input_instance_index; - input_instance_index++; - } else if (arg_desc_info.GetType() == ArgDescType::kOutputInstance) { - desc.ir_idx = output_instance_index; - output_instance_index++; - } else { - desc.ir_idx = arg_desc_info.GetIrIndex(); - } - - desc.folded = arg_desc_info.IsFolded(); - if (arg_desc_info.GetType() == ArgDescType::kCustomValue) { - *reinterpret_cast(desc.reserved) = arg_desc_info.GetCustomValue(); - } else if (arg_desc_info.GetType() == ArgDescType::kHiddenInput) { - *reinterpret_cast(desc.reserved) = - static_cast(TransToHiddenInputType(arg_desc_info.GetHiddenInputSubType())); - } else { - // static check - } - arg_descs.emplace_back(desc); - } - return AscendString(ArgsFormatDescUtils::Serialize(arg_descs).c_str()); -} - -std::vector ArgsFormatSerializer::Deserialize(const AscendString &args_str) { - std::vector arg_descs; - GE_ASSERT_SUCCESS(ArgsFormatDescUtils::Parse(std::string(args_str.GetString()), arg_descs)); - // 将args_desc转成arg_desc_info - std::vector args_format; - for (const auto &desc : arg_descs) { - auto arg_desc_type = TransToArgDescType(desc.addr_type); - ArgDescInfo arg_desc(arg_desc_type, -1, desc.folded); - if (HasIrIndex(desc.addr_type)) { - // 除了kIrInput,kIrOutput,kIrInputDesc,kIrOutputDesc,其他type没有ir_index - arg_desc.SetIrIndex(desc.ir_idx); - } - GE_ASSERT_NOTNULL(arg_desc.impl_); - arg_desc.impl_->SetInnerArgType(desc.addr_type); - if (desc.addr_type == AddrType::CUSTOM_VALUE) { - arg_desc.SetCustomValue(*reinterpret_cast(desc.reserved)); - } else if (desc.addr_type == AddrType::HIDDEN_INPUT) { - arg_desc.SetHiddenInputSubType(TransToHiddenInputSubType( - static_cast(*reinterpret_cast(desc.reserved)))); - } else { - // static check - } - args_format.emplace_back(arg_desc); - } - return args_format; -} -} \ No newline at end of file diff --git a/graph/ascendc_ir/CMakeLists.txt b/graph/ascendc_ir/CMakeLists.txt deleted file mode 100644 index 30f3bd57aef6470823b37df76c985b4be3a04285..0000000000000000000000000000000000000000 --- a/graph/ascendc_ir/CMakeLists.txt +++ /dev/null @@ -1,54 +0,0 @@ -add_subdirectory(generator) -######### libaihac_ir.so ############# -file(GLOB_RECURSE UTILS_SRCS CONFIGURE_DEPENDS "utils/*.cc") -add_library(aihac_ir SHARED - core/ascendc_ir.cc - ${UTILS_SRCS} -) -target_compile_options(aihac_ir PRIVATE -DNO_METADEF_ABI_COMPATIABLE -O2 -Werror) - -target_compile_options(aihac_ir PRIVATE - $<$,$>:-fexceptions> - $<$,$>: -fno-common -Wextra -Wfloat-equal> - $<$,$>:/MTd> - $<$,$>:/MT>) - -target_compile_definitions(aihac_ir PRIVATE - $<$,$>:FMK_SUPPORT_DUMP> - $<$:ONLY_COMPILE_OPEN_SRC> - google=ascend_private - $,OS_TYPE=WIN,OS_TYPE=0> - $<$:SECUREC_USING_STD_SECURE_LIB=0 NOMINMAX> -) - -target_include_directories(aihac_ir PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${CMAKE_BINARY_DIR} - ${METADEF_DIR} - ${CMAKE_BINARY_DIR}/proto/metadef_protos -) - -target_link_options(aihac_ir PRIVATE - -Wl,-Bsymbolic -) - -target_link_libraries(aihac_ir - PRIVATE - intf_pub - static_mmpa - -Wl,--no-as-needed - c_sec - slog - json - platform - -Wl,--as-needed - ascend_protobuf_shared_headers - $<$>:-lrt> - -ldl - graph - graph_base - error_manager - aihac_symbolizer - PUBLIC - metadef_headers -) \ No newline at end of file diff --git a/graph/ascendc_ir/core/ascendc_ir.cc b/graph/ascendc_ir/core/ascendc_ir.cc deleted file mode 100644 index c88aa95d791ed583542d6c0e37b389b205a16812..0000000000000000000000000000000000000000 --- a/graph/ascendc_ir/core/ascendc_ir.cc +++ /dev/null @@ -1,1764 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - - -#include "inc/graph/ascendc_ir/ascendc_ir_core/ascendc_ir.h" -#include "ascendc_ir_impl.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/graph_utils_ex.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/node_utils_ex.h" -#include "graph/utils/cg_utils.h" -#include "graph/debug/ge_log.h" -#include "graph/debug/ge_op_types.h" -#include "graph/ascendc_ir/utils/asc_tensor_utils.h" -#include "graph/ascendc_ir/utils/asc_graph_utils.h" -#include "expression/const_values.h" -#include "graph/attribute_group/attr_group_serializer_registry.h" -#include "inc/graph/attribute_group/attr_group_shape_env.h" -#include "graph/symbolizer/symbolic_utils.h" -#include "inc/graph/attribute_group/attr_group_symbolic_desc.h" - -namespace ge { -namespace { -constexpr int32_t kDefaultAlignVal = 1; -constexpr uint32_t kMinMergeAxisFromSize = 2U; -const char *const kAscData = ge::DATA; -const char *const kAscOutput = "Output"; -} - -// TODO ascend attr will be split into asc_attr_group -std::unique_ptr AscGraphAttr::Clone() { - auto ptr = ComGraphMakeUnique(*this); - GE_ASSERT_NOTNULL(ptr); - return ptr; -} - -std::unique_ptr AscNodeAttr::Clone() { - auto ptr = ComGraphMakeUnique(*this); - GE_ASSERT_NOTNULL(ptr); - return ptr; -} - -AscNodeAttr *AscNodeAttr::CreateImpl(ge::Operator &op) { - auto opdesc = ge::OpDescUtils::GetOpDescFromOperator(op).get(); - GE_ASSERT_NOTNULL(opdesc); - GE_ASSERT_TRUE(opdesc->GetAttrsGroup() == nullptr); - auto attr_group = opdesc->GetOrCreateAttrsGroup(); - GE_ASSERT_NOTNULL(attr_group); - return attr_group; -} - -AscNodeAttr &AscNodeAttr::operator=(const AscNodeAttr &other) { - if (this == &other) { - return *this; - } - name = other.name; - type = other.type; - sched = other.sched; - api = other.api; - tmp_buffers = other.tmp_buffers; - - if (other.ir_attr) { - ir_attr = other.ir_attr->Clone(); - } else { - ir_attr.reset(); - } - return *this; -} - -AscNodeAttr *AscNodeAttr::Create(Operator &op) { - return CreateImpl(op); -} - -AscTensorAttr &AscTensorAttr::GetTensorAttr(ge::Operator *op, const uint32_t index) { - try { - auto attr_group = GetTensorAttrPtr(op, index); - CHECK_NOTNULL_WITH_THROW_EXCEPTION(attr_group); - return *attr_group; - } catch (const AscIRException &exception) { - GELOGE(FAILED, "Create failed, reason is %s", exception.GetInfo().error_msg.c_str()); - static AscTensorAttr asc_tensor_attr; - return asc_tensor_attr; - } -} - -AscTensorAttr *AscTensorAttr::GetTensorAttrPtr(ge::Operator *op, const uint32_t index) { - GE_ASSERT_NOTNULL(op); - const auto desc = ge::OpDescUtils::GetOpDescFromOperator(*op); - GE_ASSERT_NOTNULL(desc); - auto tensor = desc->MutableOutputDesc(index); - if (tensor == nullptr) { - return nullptr; - } - const auto attr_group = tensor->GetOrCreateAttrsGroup(); - GE_ASSERT_NOTNULL(attr_group); - attr_group->dtype.tensor_desc_ = tensor.get(); - return attr_group; -} - -AscTensorAttr &AscTensorAttr::GetTensorAttr(const OutDataAnchor &output) { - try { - auto attr_group = GetTensorAttrPtr(output); - CHECK_NOTNULL_WITH_THROW_EXCEPTION(attr_group); - return *attr_group; - } - catch (const AscIRException &exception) { - GELOGE(FAILED, "Create failed, reason is %s", exception.GetInfo().error_msg.c_str()); - static AscTensorAttr asc_tensor_attr; - return asc_tensor_attr; - } -} - -AscTensorAttr *AscTensorAttr::GetTensorAttrPtr(const OutDataAnchor &output) { - const auto node = output.GetOwnerNodeBarePtr(); - GE_ASSERT_NOTNULL(node); - const auto op_desc = node->GetOpDescBarePtr(); - const auto tensor_desc = op_desc->MutableOutputDesc(output.GetIdx()); - GE_ASSERT_NOTNULL(tensor_desc); - const auto attr_group = tensor_desc->GetOrCreateAttrsGroup(); - GE_ASSERT_NOTNULL(attr_group); - attr_group->dtype.tensor_desc_ = tensor_desc.get(); - return attr_group; -} - -std::unique_ptr AscTensorAttr::Clone() { - auto ptr = ComGraphMakeUnique(*this); - GE_ASSERT_NOTNULL(ptr); - return ptr; -} - -void AscNodeOutputs::Init() { - // node_和out_data_anchor代码逻辑可以保证非空 - for (const auto &output : node_->GetAllOutDataAnchorsPtr()) { - tensors_.emplace_back(AscTensor(*output)); - } -} - -AscTensor &AscNodeOutputs::operator[](uint32_t index) { - if (tensors_.empty()) { - Init(); - } - CHECK_BOOL_WITH_THROW_EXCEPTION(index < tensors_.size(), - "index = %u but tensors_.size() = %zu", - index, - tensors_.size()); - return tensors_[index]; -} - -std::vector AscNodeOutputs::operator()() { - if (tensors_.empty()) { - Init(); - } - if (tensors_.empty()) { - return {}; - } - std::vector tensors; - for (auto &tensor : tensors_) { - tensors.push_back(&tensor); - } - return tensors; -} - -void AscNodeInputs::Init() { - std::vector tmp_tensors; - // node_和in_data_anchor代码逻辑可以保证非空 - for (const auto &in_anchor : node_->GetAllInDataAnchorsPtr()) { - const auto &peer_out_anchor = in_anchor->GetPeerOutAnchor(); - if (peer_out_anchor == nullptr) { - GELOGD("node[%s, %s] link [%d] are not ready", node_->GetNamePtr(), node_->GetTypePtr(), in_anchor->GetIdx()); - continue; - } - tmp_tensors.emplace_back(AscTensor(*peer_out_anchor)); - } - tensors_ = std::move(tmp_tensors); -} - -// make sure ascend graph is fixed, if index 1 is first linked, tensors_ 0 means index 1, that may cause bug -AscTensor &AscNodeInputs::operator[](uint32_t index) { - // as not all input is ready at the same time, must call Init on every function call - Init(); - CHECK_BOOL_WITH_THROW_EXCEPTION(index < tensors_.size()); - return tensors_[index]; -} - -std::vector AscNodeInputs::operator()() { - // as not all input is ready at the same time, must call Init on every function call - Init(); - if (tensors_.empty()) { - return {}; - } - const auto node = ascir::AscTensorUtils::GetOwner(tensors_[0]); - GE_ASSERT_NOTNULL(node); - auto op_desc = node->GetOpDesc(); - GE_ASSERT_NOTNULL(op_desc); - std::vector tensors; - for (auto &tensor : tensors_) { - tensors.emplace_back(&tensor); - } - return tensors; -} - -uint32_t AscNodeInputs::Size() { - // as not all input is ready at the same time, must call Init on every function call - Init(); - return tensors_.size(); -} - -// 此处op_desc和GetOrCreateAttrsGroup的返回值未判空,内部构造AscNode前已判空 -// 资料需注明不允许外部用户构造AscNode -AscNode::AscNode(const OpDescPtr &op_desc, const ComputeGraphPtr &compute_graph) : - Node(op_desc, compute_graph), inputs(this), outputs(this), - attr(*(op_desc->GetOrCreateAttrsGroup())) { - if (op_desc != nullptr) { - attr.name = op_desc->GetName(); - attr.type = op_desc->GetType(); - } -} - -AscNodeIter::AscNodeIter(ge::ComputeGraph::Vistor::Iterator &&iter) : impl_(iter) {} - -AscNodeIter &AscNodeIter::operator++() { - impl_++; - return *this; -} - -AscNodePtr AscNodeIter::operator*() { - auto ptr = *impl_; - return std::dynamic_pointer_cast(ptr); -} - -bool AscNodeIter::operator!=(const AscNodeIter &other) const { - return impl_ != other.impl_; -} - -AscNodeVisitor::AscNodeVisitor(ge::ComputeGraph::Vistor &&visitor) - : impl_(visitor) {} - -AscNodeIter AscNodeVisitor::begin() { - return AscNodeIter(impl_.begin()); -} - -AscNodeIter AscNodeVisitor::end() { - return AscNodeIter(impl_.end()); -} - -AscGraphImpl::AscGraphImpl(const char *name) : - compute_graph_(ComGraphMakeSharedAndThrow(name)) {} - -std::string AscGraphImpl::GetName() const { - return compute_graph_->GetName(); -} - - -void AscGraphImpl::SetTilingKey(const uint32_t tiling_key) { - const auto graph_attr_group_ptr = GetOrCreateGraphAttrsGroup(); - graph_attr_group_ptr->tiling_key = static_cast(tiling_key); -} - -int64_t AscGraphImpl::GetTilingKey() const { - const auto graph_attr_group_ptr = GetGraphAttrsGroup(); - GE_WARN_ASSERT(graph_attr_group_ptr); - return graph_attr_group_ptr->tiling_key; -} - -void AscGraphImpl::SetGraphType(const AscGraphType type) { - const auto graph_attr_group_ptr = GetOrCreateGraphAttrsGroup(); - graph_attr_group_ptr->type = type; -} - -AscGraphType AscGraphImpl::GetGraphType() const { - const auto graph_attr_group_ptr = GetGraphAttrsGroup(); - GE_WARN_ASSERT(graph_attr_group_ptr); - return graph_attr_group_ptr->type; -} - -AscNodePtr AscGraphImpl::AddNode(ge::Operator &op) { - const auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(op); - GE_ASSERT_NOTNULL(op_desc); - AscNodePtr asc_node = std::make_shared(op_desc, compute_graph_); - GE_ASSERT_NOTNULL(asc_node); - GE_ASSERT_GRAPH_SUCCESS(asc_node->Init()); - ConstNodePtr const_node = asc_node; - GE_ASSERT_GRAPH_SUCCESS(ge::NodeUtilsEx::SetNodeToOperator(op, const_node)); - auto node = compute_graph_->AddNode(asc_node); - auto new_node = std::dynamic_pointer_cast(node); - // update - (void) new_node->inputs(); - (void) new_node->outputs(); - return new_node; -} - -AscNodePtr AscGraphImpl::FindNode(const char *name) const{ - auto node = compute_graph_->FindNode(name); - auto dst_node = std::dynamic_pointer_cast(node); - return dst_node; -} - -AscNodeVisitor AscGraphImpl::GetAllNodes() const{ - return AscNodeVisitor(compute_graph_->GetAllNodes()); -} - -AscNodeVisitor AscGraphImpl::GetInputNodes() const{ - return AscNodeVisitor(compute_graph_->GetInputNodes()); -} - -ge::Expression AscGraphImpl::CreateSizeVar(const int64_t value) { - const auto graph_attr_group_ptr = GetOrCreateGraphAttrsGroup(); - GE_ASSERT_NOTNULL(graph_attr_group_ptr); - const auto expr = Symbol(value); - const auto size_var = ComGraphMakeShared(expr); - GE_ASSERT_NOTNULL(size_var); - graph_attr_group_ptr->size_vars.push_back(size_var); - return graph_attr_group_ptr->size_vars.back()->expr; -} - -ge::Expression AscGraphImpl::CreateSizeVar(const std::string &name) { - const auto graph_attr_group_ptr = GetOrCreateGraphAttrsGroup(); - GE_ASSERT_NOTNULL(graph_attr_group_ptr); - const auto expr = Symbol(name.c_str()); - const auto size_var = ComGraphMakeShared(expr); - GE_ASSERT_NOTNULL(size_var); - graph_attr_group_ptr->size_vars.push_back(size_var); - return graph_attr_group_ptr->size_vars.back()->expr; -} - -AxisPtr AscGraphImpl::CreateAxis(const std::string &name, Axis::Type type, - const Expression &size, const std::vector &from, const int64_t split_peer) { - const auto graph_attr_group_ptr = GetOrCreateGraphAttrsGroup(); - GE_ASSERT_NOTNULL(graph_attr_group_ptr); - auto axis = ComGraphMakeShared(); - GE_ASSERT_NOTNULL(axis); - axis->type = type; - axis->name = name; - axis->size = size; - axis->from = from; - axis->align = kDefaultAlignVal; - axis->split_pair_other_id = split_peer; - axis->allow_oversize_axis = false; - axis->allow_unaligned_tail = true; - axis->id = static_cast(graph_attr_group_ptr->axis.size()); - - graph_attr_group_ptr->axis.push_back(std::move(axis)); - - return graph_attr_group_ptr->axis.back(); -} - -std::vector AscGraphImpl::GetAllAxis() const { - const auto graph_attr_group_ptr = GetGraphAttrsGroup(); - GE_WARN_ASSERT(graph_attr_group_ptr); - return graph_attr_group_ptr->axis; -} - -std::vector AscGraphImpl::GetAllSizeVar() const { - const auto graph_attr_group_ptr = GetGraphAttrsGroup(); - GE_WARN_ASSERT(graph_attr_group_ptr); - return graph_attr_group_ptr->size_vars; -} - -TransInfoRoadOfGraph AscGraphImpl::GetAllAxisTransInfo() const { - const auto graph_attr_group_ptr = GetGraphAttrsGroup(); - GE_WARN_ASSERT(graph_attr_group_ptr); - return graph_attr_group_ptr->trans_info_road; -} - -Axis *AscGraphImpl::FindAxis(const int64_t axis_id) const { - const auto graph_attr_group_ptr = GetGraphAttrsGroup(); - GE_WARN_ASSERT(graph_attr_group_ptr); - if (axis_id < 0 || axis_id > static_cast(graph_attr_group_ptr->axis.size())) { - return nullptr; - } - return graph_attr_group_ptr->axis[axis_id].get(); -} - -std::pair AscGraphImpl::DoSplit(const int64_t axis_id, const std::string &outer_axis_name, - const std::string &inner_axis_name, const bool is_tile_split) { - const auto graph_attr_group_ptr = GetOrCreateGraphAttrsGroup(); - GE_ASSERT_NOTNULL(graph_attr_group_ptr); - const auto &axis = graph_attr_group_ptr->axis; - GE_ASSERT_TRUE((axis_id >= 0) && (static_cast(axis_id) < axis.size())); - - const auto &single_axis = *axis[axis_id]; - const std::string inner_suffix = is_tile_split ? "t" : "b"; - const std::string outer_suffix = is_tile_split ? "T" : "B"; - std::string actual_inner_axis_name = inner_axis_name; - if (actual_inner_axis_name.empty()) { - actual_inner_axis_name = single_axis.name + inner_suffix; - } - std::string actual_outer_axis_name = outer_axis_name; - if (actual_outer_axis_name.empty()) { - actual_outer_axis_name = single_axis.name + outer_suffix; - } - ge::Expression inner_size; - ge::Expression outer_size; - if (single_axis.size == sym::kSymbolOne) { - inner_size = sym::kSymbolOne; - outer_size = sym::kSymbolOne; - } else { - inner_size = CreateSizeVar(actual_inner_axis_name + "_size"); - outer_size = ge::sym::Ceiling(single_axis.size / inner_size); - } - - Axis::Type inner_type = is_tile_split ? Axis::kAxisTypeTileInner : Axis::kAxisTypeBlockInner; - Axis::Type outer_type = is_tile_split ? Axis::kAxisTypeTileOuter : Axis::kAxisTypeBlockOuter; - auto outter_id = static_cast(graph_attr_group_ptr->axis.size()); - int64_t inner_id = outter_id + 1; - AxisPtr outer = CreateAxis(actual_outer_axis_name, outer_type, outer_size, {axis_id}, inner_id); - AxisPtr inner = CreateAxis(actual_inner_axis_name, inner_type, inner_size, {axis_id}, outter_id); - graph_attr_group_ptr->trans_info_road.push_back({TransType::kSplit, {axis[axis_id]}, {outer, inner}}); - return {outer, inner}; -} - -std::pair AscGraphImpl::BlockSplit(const int64_t axis_id, const std::string &outer_axis_name, - const std::string &inner_axis_name) { - return DoSplit(axis_id, outer_axis_name, inner_axis_name, false); -} - -std::pair AscGraphImpl::TileSplit(const int64_t axis_id, const std::string &outer_axis_name, - const std::string &inner_axis_name) { - return DoSplit(axis_id, outer_axis_name, inner_axis_name, true); -} - -AxisPtr AscGraphImpl::MergeAxis(const std::vector &axis_ids, const std::string &merge_axis_name) { - const auto graph_attr_group_ptr = GetOrCreateGraphAttrsGroup(); - const auto &axis = graph_attr_group_ptr->axis; - std::string name; - Expression size = sym::kSymbolOne; - std::vector from_axis_ids; - std::vector from_axis; - for (const auto &axis_id : axis_ids) { - GE_ASSERT_TRUE((axis_id >= 0) && (static_cast(axis_id) < axis.size())); - from_axis.push_back(axis[axis_id]); - name += axis[axis_id]->name; - size = size * axis[axis_id]->size; - from_axis_ids.push_back(axis_id); - } - name = merge_axis_name.empty() ? name : merge_axis_name; - AxisPtr merge_axis = CreateAxis(name, Axis::kAxisTypeMerged, size, from_axis_ids); - graph_attr_group_ptr->trans_info_road.push_back({TransType::kMerge, from_axis, {merge_axis}}); - return merge_axis; -} - -bool AscGraphImpl::BindBlock(const int64_t outter_id, const int64_t inner_id) { - const auto graph_attr_group_ptr = GetOrCreateGraphAttrsGroup(); - GE_ASSERT_NOTNULL(graph_attr_group_ptr); - const auto &axis = graph_attr_group_ptr->axis; - GE_ASSERT_TRUE((outter_id >= 0) && (static_cast(outter_id) < axis.size())); - GE_ASSERT_TRUE((inner_id >= 0) && (static_cast(inner_id) < axis.size())); - - auto outter_axis = axis[outter_id]; - GE_ASSERT_NOTNULL(outter_axis); - outter_axis->type = Axis::kAxisTypeBlockOuter; - outter_axis->name.append("B"); - - auto inner_axis = axis[inner_id]; - GE_ASSERT_NOTNULL(inner_axis); - inner_axis->type = Axis::kAxisTypeBlockInner; - inner_axis->name.append("b"); - return true; -} - -bool AscGraphImpl::DoApplySplit(const AscNodePtr &node, const int64_t outter_id, - const int64_t inner_id, const int64_t original_id) { - GE_ASSERT_NOTNULL(node); - GE_ASSERT_TRUE(DoApplySchedAxisSplit(node, outter_id, inner_id, original_id)); - GE_ASSERT_TRUE(DoApplyTensorAxisSplit(node, outter_id, inner_id, original_id)); - return true; -} - -bool AscGraphImpl::DoApplyTensorAxisSplit(const AscNodePtr &node, const int64_t outter_id, - const int64_t inner_id, const int64_t original_id) { - const auto graph_attr_group_ptr = GetOrCreateGraphAttrsGroup(); - GE_ASSERT_NOTNULL(graph_attr_group_ptr); - const auto &all_axis = graph_attr_group_ptr->axis; - // check inner_axis before - const Expression &split_size = all_axis[inner_id]->size; - for (uint32_t i = 0; i < node->GetAllOutDataAnchorsSize(); i++) { - const auto &result = - AxisUtils::SplitView({node->outputs[i].attr.axis, - node->outputs[i].attr.repeats, node->outputs[i].attr.strides}, - split_size, - outter_id, - inner_id, - original_id); - GE_ASSERT_TRUE(!result.axis_ids.empty(), - "Split out view failed for node %s %s, index %u", - node->GetNamePtr(), - node->GetTypePtr(), - i); - node->outputs[i].attr.axis = result.axis_ids; - node->outputs[i].attr.repeats = result.repeats; - node->outputs[i].attr.strides = result.strides; - } - return true; -} - -bool AscGraphImpl::DoApplySchedAxisSplit(const AscNodePtr &node, const int64_t outter_id, - const int64_t inner_id, const int64_t original_id) { - std::vector new_node_attr_axis; - const auto &node_axis = node->attr.sched.axis; - for (auto &node_axis_id : node_axis) { - if (node_axis_id == original_id) { - new_node_attr_axis.push_back(outter_id); - new_node_attr_axis.push_back(inner_id); - } else { - new_node_attr_axis.push_back(node_axis_id); - } - } - node->attr.sched.axis = new_node_attr_axis; - return true; -} - -bool AscGraphImpl::ApplySplit(const AscNodePtr &node, const int64_t outter_id, const int64_t inner_id) { - GE_ASSERT_NOTNULL(node); - const auto graph_attr_group_ptr = GetOrCreateGraphAttrsGroup(); - GE_ASSERT_NOTNULL(graph_attr_group_ptr); - const auto &all_axis = graph_attr_group_ptr->axis; - GE_ASSERT_TRUE( - (outter_id >= 0) && (outter_id < static_cast(all_axis.size())) && - (inner_id >= 0) && (inner_id < static_cast(all_axis.size()))); - const auto &out_axis = *all_axis[outter_id]; - const auto &in_axis = *all_axis[inner_id]; - GE_ASSERT_TRUE((out_axis.type == Axis::kAxisTypeBlockOuter && - in_axis.type == Axis::kAxisTypeBlockInner) || - (out_axis.type == Axis::kAxisTypeTileOuter && in_axis.type == Axis::kAxisTypeTileInner)); - GE_ASSERT_TRUE( - (out_axis.from.size() == 1U) && (in_axis.from.size() == 1U) && (out_axis.from[0] == in_axis.from[0])); - return DoApplySplit(node, outter_id, inner_id, out_axis.from[0]); -} - -bool AscGraphImpl::DoApplyMerge(const AscNodePtr &node, - const int64_t merged_axis_id, - const std::vector &original) { - GE_ASSERT_NOTNULL(node); - GE_ASSERT_TRUE(DoApplySchedAxisMerge(node, merged_axis_id, original)); - GE_ASSERT_TRUE(DoApplyTensorAxisMerge(node, merged_axis_id, original)); - return true; -} - -bool AscGraphImpl::DoApplySchedAxisMerge(const AscNodePtr &node, const int64_t merged_axis_id, - const std::vector &original) { - std::vector new_node_attr_axis; - std::set original_set(original.begin(), original.end()); - GE_ASSERT_TRUE(original_set.size() == original.size(), "merge axis redundant"); - std::set merge_axis_set; - size_t first_merge_axis_index = SIZE_MAX; - for (size_t axis_index = 0; axis_index < node->attr.sched.axis.size(); ++axis_index) { - if (original_set.find(node->attr.sched.axis[axis_index]) != original_set.end()) { - if (first_merge_axis_index == SIZE_MAX) { - first_merge_axis_index = axis_index; // 记录首个待合并轴的位置 - } - merge_axis_set.emplace(node->attr.sched.axis[axis_index]); - if (merge_axis_set.size() == original.size()) { - new_node_attr_axis.insert(new_node_attr_axis.begin() + first_merge_axis_index, merged_axis_id); // 合并轴放入首个待合并轴位置 - } - } else { - new_node_attr_axis.push_back(node->attr.sched.axis[axis_index]); - } - } - GE_ASSERT_TRUE( - merge_axis_set.size() == original.size() || merge_axis_set.empty(), - "node {%s} has sched.axis %s but origin is %s", - node->GetNamePtr(), - ViewMemberToString(node->attr.sched.axis).c_str(), - ViewMemberToString(original).c_str()); - node->attr.sched.axis = new_node_attr_axis; - return true; -} - -bool AscGraphImpl::DoApplySchedAxisReorder(const AscNodePtr &node, const std::vector &reordered_axis) { - const auto &node_axis = node->attr.sched.axis; - for (const auto axis_id : reordered_axis) { - const auto it = std::find(node_axis.begin(), node_axis.end(), axis_id); - GE_ASSERT_TRUE(it != node_axis.end(), - "can not find axis_id[%ld] of reordered_axis, node[%s,%s]", axis_id, - node->GetNamePtr(), node->GetTypePtr()); - } - node->attr.sched.axis = reordered_axis; - return true; -} - -bool AscGraphImpl::DoApplyTensorAxisReorder(const AscNodePtr &node, const std::vector &reordered_axis) { - const auto &node_axis = node->attr.sched.axis; - for (const auto axis_id : reordered_axis) { - const auto it = std::find(node_axis.begin(), node_axis.end(), axis_id); - GE_ASSERT_TRUE(it != node_axis.end(), "can not find axis_id[%ld] of reordered_axis, node[%s,%s]", axis_id, - node->GetNamePtr(), node->GetTypePtr()); - } - for (const auto output_ptr : node->outputs()) { - auto &output = *output_ptr; - std::vector new_axis; - std::vector new_repeat; - std::vector new_strides; - auto output_axis = output.attr.axis; - for (const auto axis_id : reordered_axis) { - const auto it = std::find(output_axis.begin(), output_axis.end(), axis_id); - if (it == output_axis.end()) { - continue; - } - const auto pos = std::distance(output_axis.begin(), it); - new_axis.push_back(output_axis[pos]); - new_repeat.push_back(output.attr.repeats[pos]); - new_strides.push_back(output.attr.strides[pos]); - } - output.attr.axis = new_axis; - output.attr.repeats = new_repeat; - output.attr.strides = new_strides; - } - return true; -} - -bool AscGraphImpl::DoCopyAscGraphAttr(const AscGraph &src_asc_graph, AscGraph &dst_asc_graph) { - return DoCopyAscGraphAttrImpl(AscGraphUtils::GetComputeGraph(src_asc_graph), - AscGraphUtils::GetComputeGraph(dst_asc_graph)); -} - -bool AscGraphImpl::DoCopyAscGraphAttrImpl(const ComputeGraphPtr &src_compute_graph, - const ComputeGraphPtr &dst_compute_graph) { - GE_ASSERT_NOTNULL(src_compute_graph); - GE_ASSERT_NOTNULL(dst_compute_graph); - const auto dst_graph_attr = dst_compute_graph->GetOrCreateAttrsGroup(); - const auto src_graph_attr = src_compute_graph->GetOrCreateAttrsGroup(); - GE_ASSERT_NOTNULL(dst_graph_attr); - GE_ASSERT_NOTNULL(src_graph_attr); - - ascendc_ir::proto::AscGraphAttrGroupsDef asc_graph_group; - GE_ASSERT_GRAPH_SUCCESS(src_graph_attr->SerializeAttr(asc_graph_group)); - GE_ASSERT_GRAPH_SUCCESS(dst_graph_attr->DeserializeAttr(asc_graph_group)); - - return true; -} - -bool AscGraphImpl::DoCopyAscNodeAndRelink(const AscGraph &src_asc_graph, AscGraph &dst_asc_graph) { - const auto src_compute_graph = AscGraphUtils::GetComputeGraph(src_asc_graph); - auto dst_compute_graph = AscGraphUtils::GetComputeGraph(dst_asc_graph); - GE_ASSERT_NOTNULL(src_compute_graph); - GE_ASSERT_NOTNULL(dst_compute_graph); - std::unordered_map all_new_nodes; - for (const auto &src_node : src_asc_graph.GetAllNodes()) { - const auto &op_desc = GraphUtils::CopyOpDesc(src_node->GetOpDesc(), nullptr); - GE_ASSERT_NOTNULL(op_desc); - op_desc->SetName(src_node->GetName()); - ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc); - auto dst_new_node = dst_asc_graph.AddNode(op); - all_new_nodes[dst_new_node->GetName()] = std::dynamic_pointer_cast(dst_new_node); - DoCopyAscNodeTensorAttr(src_node, dst_new_node); - } - - for (const auto &src_node : src_compute_graph->GetAllNodes()) { - GE_ASSERT_GRAPH_SUCCESS(GraphUtils::RelinkGraphEdges(src_node, "", all_new_nodes)); - } - return true; -} - -bool AscGraphImpl::DoCopyAscNodeTensorAttr(const AscNodePtr &src_node, AscNodePtr &dst_node) { - // op_desc保证非空 - auto op_desc = dst_node->GetOpDesc(); - auto dst_asc_node_attr = op_desc->GetOrCreateAttrsGroup(); - auto src_asc_node_attr = src_node->GetOpDesc()->GetOrCreateAttrsGroup(); - if (src_asc_node_attr != nullptr && dst_asc_node_attr != nullptr) { - *dst_asc_node_attr = *src_asc_node_attr; - } - for (uint32_t i = 0; i < src_node->outputs().size(); i++) { - GE_ASSERT_NOTNULL(op_desc->MutableOutputDesc(i)); - auto tensor_attr_group = op_desc->MutableOutputDesc(i)->GetAttrsGroup(); - GE_ASSERT_NOTNULL(tensor_attr_group); - *tensor_attr_group = src_node->outputs[i].attr; - } - return true; -} - -// original中的轴不连续时没法做合轴 -// 判断轴是否连续 stride_i == repeat_{i+1} * stride_{i+1} -bool AscGraphImpl::CheckContinuous(const AscNodePtr &node, - const uint32_t tensor_index, - const std::vector &original) { - std::vector repeats; - std::vector strides; - std::set original_set(original.begin(), original.end()); - std::set merge_axis_set; - auto axis = node->outputs[tensor_index].attr.axis; - for (uint32_t axis_index = 0U; axis_index < axis.size(); axis_index++) { - if (original_set.find(axis[axis_index]) != original_set.end()) { - repeats.emplace_back(node->outputs[tensor_index].attr.repeats[axis_index]); - strides.emplace_back(node->outputs[tensor_index].attr.strides[axis_index]); - merge_axis_set.emplace(axis[axis_index]); - } - } - GE_ASSERT_TRUE( - merge_axis_set.size() == original_set.size() || merge_axis_set.empty(), - "node {%s}'s output[%u] has axis %s but origin is %s", - node->GetNamePtr(), tensor_index, - ViewMemberToString(axis).c_str(), - ViewMemberToString(original).c_str()); - if (repeats.size() <= 1U) { - return true; - } - for (uint32_t i = 0U; i < repeats.size() - 1; i++) { - auto post_stride = repeats[i + 1] * strides[i + 1]; - if (ge::SymbolicUtils::StaticCheckEq(strides[i], post_stride) != ge::TriBool::kTrue) { - GELOGD("strides of %u is %s but {repeats * strides} of %u is %s", i, strides[i].Str().get(), i + 1, - post_stride.Str().get()); - return false; - } - } - return true; -} - -bool AscGraphImpl::DoApplyTensorAxisMerge(const AscNodePtr &node, const int64_t merged_axis_id, - const std::vector &original) { - for (uint32_t i = 0; i < node->GetAllOutDataAnchorsSize(); i++) { - if (!CheckContinuous(node, i, original)) { - GELOGW("%s's [%u]th output's view is not continuous.", node->GetNamePtr(), i); - continue; - } - const auto &view = - AxisUtils::MergeView({node->outputs[i].attr.axis, node->outputs[i].attr.repeats, node->outputs[i].attr.strides}, - merged_axis_id, original); - node->outputs[i].attr.axis = view.axis_ids; - node->outputs[i].attr.repeats = view.repeats; - node->outputs[i].attr.strides = view.strides; - } - return true; -} - -bool AscGraphImpl::ApplyMerge(const AscNodePtr &node, const int64_t merged_axis_id) { - GE_ASSERT_NOTNULL(node); - const auto graph_attr_group_ptr = GetOrCreateGraphAttrsGroup(); - GE_ASSERT_NOTNULL(graph_attr_group_ptr); - const auto &all_axis = graph_attr_group_ptr->axis; - GE_ASSERT_TRUE( - (merged_axis_id >= 0) && (merged_axis_id < static_cast(all_axis.size()))); - const auto &axis = *all_axis[merged_axis_id]; - GE_ASSERT_TRUE((axis.type == Axis::kAxisTypeMerged) && - axis.from.size() >= kMinMergeAxisFromSize); - return DoApplyMerge(node, merged_axis_id, axis.from); -} - -bool AscGraphImpl::ApplyTensorAxisMerge(const AscNodePtr &node, const int64_t merged_axis_id) { - GE_ASSERT_NOTNULL(node); - const auto graph_attr_group_ptr = GetOrCreateGraphAttrsGroup(); - GE_ASSERT_NOTNULL(graph_attr_group_ptr); - const auto &all_axis = graph_attr_group_ptr->axis; - GE_ASSERT_TRUE( - (merged_axis_id >= 0) && (merged_axis_id < static_cast(all_axis.size()))); - const auto &axis = *all_axis[merged_axis_id]; - GE_ASSERT_TRUE( - (axis.type == Axis::kAxisTypeMerged) && axis.from.size() >= kMinMergeAxisFromSize); - return DoApplyTensorAxisMerge(node, merged_axis_id, axis.from); -} - -bool AscGraphImpl::ApplySchedAxisMerge(const AscNodePtr &node, const int64_t merged_axis_id) { - GE_ASSERT_NOTNULL(node); - const auto graph_attr_group_ptr = GetOrCreateGraphAttrsGroup(); - GE_ASSERT_NOTNULL(graph_attr_group_ptr); - const auto &all_axis = graph_attr_group_ptr->axis; - GE_ASSERT_TRUE( - (merged_axis_id >= 0) && (merged_axis_id < static_cast(all_axis.size()))); - const auto &axis = *all_axis[merged_axis_id]; - GE_ASSERT_TRUE( - (axis.type == Axis::kAxisTypeMerged) && axis.from.size() >= kMinMergeAxisFromSize); - return DoApplySchedAxisMerge(node, merged_axis_id, axis.from); -} - -bool AscGraphImpl::ApplyTensorAxisMerge(const AscNodePtr &node, - const int64_t merged_axis_id, - const std::vector &original) { - GE_ASSERT_NOTNULL(node); - return DoApplyTensorAxisMerge(node, merged_axis_id, original); -} - -bool AscGraphImpl::ApplySchedAxisMerge(const AscNodePtr &node, - const int64_t merged_axis_id, - const std::vector &original) { - GE_ASSERT_NOTNULL(node); - return DoApplySchedAxisMerge(node, merged_axis_id, original); -} - -bool AscGraphImpl::ApplyReorder(const AscNodePtr &node, const std::vector &reordered_axis) { - GE_ASSERT_NOTNULL(node); - GE_ASSERT_TRUE(DoApplySchedAxisReorder(node, reordered_axis)); - return DoApplyTensorAxisReorder(node, reordered_axis); -} - -bool AscGraphImpl::ApplySchedAxisReorder(const AscNodePtr &node, const std::vector &reordered_axis) { - GE_ASSERT_NOTNULL(node); - const auto &node_axis = node->attr.sched.axis; - GE_ASSERT_EQ(node_axis.size(), reordered_axis.size()); - return DoApplySchedAxisReorder(node, reordered_axis); -} - -bool AscGraphImpl::ApplyTensorAxisReorder(const AscNodePtr &node, const std::vector &reordered_axis) { - GE_ASSERT_NOTNULL(node); - const auto &node_axis = node->attr.sched.axis; - GE_ASSERT_EQ(node_axis.size(), reordered_axis.size()); - return DoApplyTensorAxisReorder(node, reordered_axis); -} - -bool AscGraphImpl::TryApplyAxisReplace(const AscNodePtr &node, const Axis &src, const Axis &dst) { - GE_ASSERT_NOTNULL(node); - std::vector new_axes = node->attr.sched.axis; - bool found{false}; - for (int64_t &id : new_axes) { - if (id == src.id) { - id = dst.id; - found = true; - } - } - node->attr.sched.axis = new_axes; - for (auto outputs : node->outputs()) { - auto new_output_axes = outputs->attr.axis; - for (auto &id : new_output_axes) { - if (id == src.id) { - id = dst.id; - found = true; - } - } - outputs->attr.axis = new_output_axes; - } - return found; -} - -AscGraphAttr *AscGraphImpl::GetOrCreateGraphAttrsGroup() { - return compute_graph_->GetOrCreateAttrsGroup(); -} - -AscGraphAttr *AscGraphImpl::GetGraphAttrsGroup() const{ - return compute_graph_->GetAttrsGroup(); -} - -AscOpOutput AscGraphImpl::CreateContiguousData(const char *name, - const ge::DataType &dt, - const vector &axes, - const Format &format) { - auto data_op_desc = OpDescBuilder(name, kAscData).AddOutput("y").Build(); - GE_ASSERT_NOTNULL(data_op_desc); - // Add output and attr - data_op_desc->AppendIrAttrName("index"); - data_op_desc->AppendIrOutput("y", kIrOutputRequired); - auto data_op = std::make_shared(OpDescUtils::CreateOperatorFromOpDesc(data_op_desc)); - GE_ASSERT_NOTNULL(data_op); - auto data_attr = data_op_desc->GetOrCreateAttrsGroup(); - GE_ASSERT_NOTNULL(data_attr); - AddNode(*data_op); - data_op_desc->SetExtAttr(ascir::cg::RELATED_OP, data_op); - data_attr->sched.exec_order = ascir::cg::CodeGenUtils::GenNextExecId(*data_op); - auto data_ir_attr = ComGraphMakeUnique(); - GE_ASSERT_NOTNULL(data_ir_attr); - GE_ASSERT_GRAPH_SUCCESS(data_ir_attr->SetIndex(data_attr->sched.exec_order)); - data_attr->ir_attr = std::move(data_ir_attr); - - AscOpOutput asc_op_output(data_op.get(), 0U); // data只有一个输出 - asc_op_output.dtype = dt; - asc_op_output.format = format; // tensor上的format - GE_ASSERT_TRUE(asc_op_output.SetContiguousView(axes)); - *asc_op_output.vectorized_axis = AxisUtils::GetDefaultVectorizedAxis(*asc_op_output.axis, -1); - return asc_op_output; -} - -AscOpOutput AscGraphImpl::CreateContiguousOut(const char *name, - const DataType &dt, - const vector &axes, - const Format &format) { - auto out_op_desc = OpDescBuilder(name, kAscOutput).AddInput("x").AddOutput("y").Build(); - GE_ASSERT_NOTNULL(out_op_desc); - auto out_op = std::make_shared(OpDescUtils::CreateOperatorFromOpDesc(out_op_desc)); - GE_ASSERT_NOTNULL(out_op); - AddNode(*out_op); - out_op_desc->SetExtAttr(ascir::cg::RELATED_OP, out_op); - AscOpOutput asc_op_output(out_op.get(), 0U); // output只有一个输出 - asc_op_output.dtype = dt; - asc_op_output.format = format; - GE_ASSERT_TRUE(asc_op_output.SetContiguousView(axes)); - *asc_op_output.vectorized_axis = AxisUtils::GetDefaultVectorizedAxis(*asc_op_output.axis, -1); - return asc_op_output; -} - -void AscGraphImpl::SortByExecOrder() { - compute_graph_->TopologicalSorting([](const ge::NodePtr &a, const ge::NodePtr &b) { - auto node_a = std::dynamic_pointer_cast(a); - auto node_b = std::dynamic_pointer_cast(b); - return node_a->attr.sched.exec_order < node_b->attr.sched.exec_order; - }); -} - -const ComputeGraphPtr AscGraphImpl::GetComputeGraph() const { - return compute_graph_; -} - -bool AscGraphImpl::CopyFrom(const ge::AscGraph &src_graph, ge::AscGraph &dst_graph) { - GE_ASSERT_TRUE(DoCopyAscGraphAttr(src_graph, dst_graph)); - GE_ASSERT_TRUE(DoCopyAscNodeAndRelink(src_graph, dst_graph)); - return true; -} - -graphStatus AscGraphImpl::CreateSizeVar(const Expression &expression) { - const auto graph_attr_group_ptr = GetOrCreateGraphAttrsGroup(); - GE_ASSERT_NOTNULL(graph_attr_group_ptr); - const auto size_var = ComGraphMakeShared(expression); - GE_ASSERT_NOTNULL(size_var); - graph_attr_group_ptr->size_vars.push_back(size_var); - return GRAPH_SUCCESS; -} - -Status AscGraphImpl::AddSubGraph(const ComputeGraphPtr &sub_graph) const { - GE_ASSERT_NOTNULL(sub_graph); - auto root_graph = GraphUtils::FindRootGraph(compute_graph_); - GE_ASSERT_NOTNULL(root_graph); - root_graph->AddSubGraph(sub_graph); - return ge::SUCCESS; -} - -Status AscGraphImpl::FindSubGraph(const std::string &name, std::shared_ptr &graph_impl) const { - auto root_graph = GraphUtils::FindRootGraph(compute_graph_); - GE_ASSERT_NOTNULL(root_graph); - auto sub_graph = root_graph->GetSubgraph(name); - GE_ASSERT_NOTNULL(sub_graph, "Failed to get subgraph named [%s] from [%s].", name.c_str(), - compute_graph_->GetName().c_str()); - - graph_impl = ComGraphMakeShared(name.c_str()); - GE_ASSERT_NOTNULL(graph_impl); - graph_impl->compute_graph_ = sub_graph; - return ge::SUCCESS; -} - -AscGraph::AscGraph(const char *name) : - impl_(ComGraphMakeSharedAndThrow(name)) {} - -std::string AscGraph::GetName() const { - return impl_->GetName(); -} - -void AscGraph::SortByExecOrder() { - impl_->SortByExecOrder(); -} - -bool AscGraph::CopyFrom(const ge::AscGraph &graph) { - GE_ASSERT_TRUE(impl_->CopyFrom(graph, *this)); - std::vector sub_graphs; - GE_ASSERT_SUCCESS(graph.GetAllSubGraphs(sub_graphs)); - for (const auto &sub_graph : sub_graphs) { - AscGraph new_sub(sub_graph.GetName().c_str()); - GE_ASSERT_TRUE(new_sub.impl_->CopyFrom(sub_graph, new_sub)); - GE_ASSERT_SUCCESS(AddSubGraph(new_sub)); - } - return true; -} - -bool AscGraph::CopyAttrFrom(const AscGraph &src_asc_graph) { - GE_ASSERT_TRUE(AscGraphImpl::DoCopyAscGraphAttr(src_asc_graph, *this)); - return true; -} - -bool AscGraph::CopyAscNodeTensorAttr(const AscNodePtr &src_node, AscNodePtr &dst_node) { - GE_ASSERT_TRUE(AscGraphImpl::DoCopyAscNodeTensorAttr(src_node, dst_node)); - return true; -} - -void AscGraph::SetTilingKey(const uint32_t tiling_key) { - impl_->SetTilingKey(tiling_key); -} - -void AscGraph::SetGraphType(const AscGraphType type) { - impl_->SetGraphType(type); -} - -Status AscGraph::AddSubGraph(const ge::AscGraph &graph) const { - return impl_->AddSubGraph(graph.impl_->compute_graph_); -} - -Status AscGraph::GetAllSubGraphs(std::vector &graphs) const { - auto root_graph = GraphUtils::FindRootGraph(impl_->compute_graph_); - GE_ASSERT_NOTNULL(root_graph); - auto subgraphs = root_graph->GetAllSubgraphs(); - graphs.reserve(subgraphs.size()); - for (const auto &iter : subgraphs) { - AscGraph graph(iter->GetName().c_str()); - graph.impl_->compute_graph_ = iter; - graphs.emplace_back(std::move(graph)); - } - return ge::SUCCESS; -} - -Status AscGraph::FindSubGraph(const std::string &name, ge::AscGraph &graph) const { - return impl_->FindSubGraph(name, graph.impl_); -} - -int64_t AscGraph::GetTilingKey() const { - return impl_->GetTilingKey(); -} - -AscGraphType AscGraph::GetGraphType() const { - return impl_->GetGraphType(); -} - -Expression AscGraph::CreateSizeVar(const int64_t value) { - return impl_->CreateSizeVar(value); -} - -Expression AscGraph::CreateSizeVar(const std::string &name) { - return impl_->CreateSizeVar(name); -} - -graphStatus AscGraph::CreateSizeVar(const Expression &expression) { - return impl_->CreateSizeVar(expression); -} - -Axis &AscGraph::CreateAxis(const std::string &name, const Expression &size) { - return *(impl_->CreateAxis(name, Axis::kAxisTypeOriginal, size, {})); -} - -Axis &AscGraph::CreateAxis(const std::string &name, Axis::Type type, const Expression &size, - const std::vector &from, AxisId split_peer) { - return *(impl_->CreateAxis(name, type, size, from, split_peer)); -} - -Axis *AscGraph::FindAxis(const int64_t axis_id) { - return impl_->FindAxis(axis_id); -} - -AscNodePtr AscGraph::AddNode(ge::Operator &op) { - return impl_->AddNode(op); -} - -AscNodePtr AscGraph::FindNode(const char *name) const { - return impl_->FindNode(name); -} - -AscNodeVisitor AscGraph::GetAllNodes() const { - return impl_->GetAllNodes(); -} - -AscNodeVisitor AscGraph::GetInputNodes() const { - return impl_->GetInputNodes(); -} - -std::pair AscGraph::BlockSplit(const int64_t axis_id, const std::string &outer_axis_name, - const std::string &inner_axis_name) { - GE_ASSERT_TRUE(IsVarNameValidAllowEmpty(inner_axis_name)); - GE_ASSERT_TRUE(IsVarNameValidAllowEmpty(outer_axis_name)); - return impl_->BlockSplit(axis_id, outer_axis_name, inner_axis_name); -} - -std::pair AscGraph::TileSplit(const int64_t axis_id, const std::string &outer_axis_name, - const std::string &inner_axis_name) { - return impl_->TileSplit(axis_id, outer_axis_name, inner_axis_name); -} - -AxisPtr AscGraph::MergeAxis(const std::vector &axis_ids, const std::string &merge_axis_name) { - return impl_->MergeAxis(axis_ids, merge_axis_name); -} - -bool AscGraph::BindBlock(const int64_t outter_id, const int64_t inner_id) { - return impl_->BindBlock(outter_id, inner_id); -} - -bool AscGraph::ApplySplit(const AscNodePtr &node, const int64_t outter_id, const int64_t inner_id) { - return impl_->ApplySplit(node, outter_id, inner_id); -} - -bool AscGraph::ApplyMerge(const AscNodePtr &node, const int64_t merged_axis_id) { - return impl_->ApplyMerge(node, merged_axis_id); -} - -bool AscGraph::ApplySchedAxisMerge(const AscNodePtr &node, const int64_t merged_axis_id) { - return impl_->ApplySchedAxisMerge(node, merged_axis_id); -} - -bool AscGraph::ApplyTensorAxisMerge(const AscNodePtr &node, const int64_t merged_axis_id) { - return impl_->ApplyTensorAxisMerge(node, merged_axis_id); -} - -bool AscGraph::ApplySchedAxisMerge(const AscNodePtr &node, - const int64_t merged_axis_id, - const std::vector &original) { - return impl_->ApplySchedAxisMerge(node, merged_axis_id, original); -} - -bool AscGraph::ApplyTensorAxisMerge(const AscNodePtr &node, - const int64_t merged_axis_id, - const std::vector &original) { - return impl_->ApplyTensorAxisMerge(node, merged_axis_id, original); -} - -bool AscGraph::ApplyReorder(const AscNodePtr &node, const std::vector &reordered_axis) { - return impl_->ApplyReorder(node, reordered_axis); -} - -bool AscGraph::ApplySchedAxisReorder(const AscNodePtr &node, const std::vector &reordered_axis) { - return impl_->ApplySchedAxisReorder(node, reordered_axis); -} - -bool AscGraph::ApplyTensorAxisReorder(const AscNodePtr &node, const std::vector &reordered_axis) { - return impl_->ApplyTensorAxisReorder(node, reordered_axis); -} - -bool AscGraph::TryApplyAxisReplace(const AscNodePtr &node, const Axis &src, const Axis &dst) { - return impl_->TryApplyAxisReplace(node, src, dst); -} - -std::vector AscGraph::GetAllAxis() const{ - return impl_->GetAllAxis(); -} - -std::vector AscGraph::GetAllSizeVar() const{ - return impl_->GetAllSizeVar(); -} - -AscGraph::~AscGraph() { - for (const auto &node: impl_->GetAllNodes()) { - if (node == nullptr) { - continue; - } - const auto &op_desc = node->GetOpDesc(); - if (op_desc != nullptr) { - // 打破shared ptr的循环引用 - op_desc->DelExtAttr(ascir::cg::RELATED_OP); - } - } -} - -bool AscGraph::CheckExprValid() const { - int32_t node_index = -1; - for (const auto &node : GetAllNodes()) { - node_index++; - GE_ASSERT_NOTNULL(node, "Node ptr is null, index[%d].", node_index); - int32_t output_index = -1; - for (const auto &tensor : node->outputs()) { - output_index++; - GE_ASSERT_NOTNULL(tensor, "Tensor ptr is null, index[%d], node name[%s].", output_index, - node->GetName().c_str()); - } - } - return true; -} - -bool AscGraph::CheckAxisValid() const { - int64_t id_index = 0; - const auto axes = GetAllAxis(); - for (const auto &axis : axes) { - GE_ASSERT_NOTNULL(axis, "Axis ptr is null, index[%ld].", id_index); - GE_ASSERT_TRUE(axis->id == id_index, "Axis index[%ld] is not equal to id[%ld].", id_index, axis->id); - id_index++; - } - int32_t node_index = -1; - for (const auto &node : GetAllNodes()) { - node_index++; - GE_ASSERT_NOTNULL(node, "Node ptr is null, index[%d].", node_index); - std::set sched_axis_set; - int32_t sched_axis_index = -1; - for (const auto &sched_axis : node->attr.sched.axis) { - sched_axis_index++; - GE_ASSERT_TRUE(sched_axis >= 0L, "Invalid sched axis[%ld], node_name[%s], index[%d].", sched_axis, - node->GetName().c_str(), sched_axis_index); - GE_ASSERT_TRUE(sched_axis < static_cast(axes.size()), - "Invalid sched axis[%ld], node_name[%s], index[%d].", sched_axis, - node->GetName().c_str(), sched_axis_index); - const auto iter = sched_axis_set.find(sched_axis); - GE_ASSERT_TRUE(iter == sched_axis_set.cend(), "Redundant sched axis[%ld], node_name[%s].", sched_axis, - node->GetName().c_str()); - sched_axis_set.insert(sched_axis); - } - int32_t output_index = -1; - for (const auto &tensor : node->outputs()) { - output_index++; - GE_ASSERT_TRUE(tensor != nullptr, "Tensor ptr is null, index[%d], node name[%s].", output_index, - node->GetName().c_str()); - GE_ASSERT_TRUE(tensor->attr.axis.size() == tensor->attr.repeats.size(), - "Tensor axis size[%zu] is not equal to repeat size[%zu], index[%d], node name[%s].", - tensor->attr.axis.size(), tensor->attr.repeats.size(), output_index, - node->GetName().c_str()); - GE_ASSERT_TRUE(tensor->attr.axis.size() == tensor->attr.strides.size(), - "Tensor axis size[%zu] is not equal to stride size[%zu], index[%d], node name[%s].", - tensor->attr.axis.size(), tensor->attr.strides.size(), output_index, - node->GetName().c_str()); - for (const auto &axis : tensor->attr.axis) { - GE_ASSERT_TRUE(axis >= 0, "Invalid tensor axis[%ld].", axis); - GE_ASSERT_TRUE(axis < static_cast(axes.size()), "Invalid tensor axis[%ld].", axis); - } - for (const auto &vectorized_axis : tensor->attr.vectorized_axis) { - GE_ASSERT_TRUE(vectorized_axis >= 0, "Invalid tensor vectorized_axis[%ld].", vectorized_axis); - GE_ASSERT_TRUE(vectorized_axis < static_cast(axes.size()), - "Invalid tensor vectorized_axis[%ld].", vectorized_axis); - } - } - } - return true; -} - -bool AscGraph::CheckExecOrderValid() const { - std::set exec_order_set; - for (const auto &node : GetAllNodes()) { - const auto exec_order = node->attr.sched.exec_order; - const auto iter = exec_order_set.find(exec_order); - GE_ASSERT_TRUE(iter == exec_order_set.end(), "Redundant exec_order[%ld].", exec_order); - exec_order_set.insert(exec_order); - } - return true; -} - -bool AscGraph::CheckTensorValid() const { - for (const auto &node : GetAllNodes()) { - int32_t output_index = -1; - for (const auto &tensor : node->outputs()) { - output_index++; - if (tensor->attr.mem.alloc_type == AllocType::kAllocTypeGlobal) { - continue; - } - if ((tensor->attr.buf.id != kIdNone) && (tensor->attr.que.id == kIdNone)) { - continue; - } - if ((tensor->attr.buf.id == kIdNone) && (tensor->attr.que.id != kIdNone)) { - GE_ASSERT_TRUE(tensor->attr.que.depth > 0, "Invalid que depth[%ld], tensor index[%d], node[%s].", - tensor->attr.que.depth, output_index, node->GetName().c_str()); - GE_ASSERT_TRUE(tensor->attr.que.buf_num > 0, "Invalid que buf_num[%ld], tensor index[%d], node[%s].", - tensor->attr.que.buf_num, output_index, node->GetName().c_str()); - continue; - } - GE_LOGE("Invalid mem, alloc type[%d], que id[%ld], buf id[%ld], tensor index[%d], node[%s].", - static_cast(tensor->attr.mem.alloc_type), tensor->attr.que.id, tensor->attr.buf.id, - output_index, node->GetName().c_str()); - return false; - } - } - return true; -} - -bool AscGraph::CheckNodeConnectionValid() const { - for (const auto &node : GetAllNodes()) { - for (uint32_t index = 0U; index < node->inputs.Size(); index++) { - GE_ASSERT_TRUE(node->GetInDataAnchor(index) != nullptr, "Input is not connected, index[%u], node[%s].", - index, node->GetName().c_str()); - GE_ASSERT_TRUE(node->GetInDataAnchor(index)->GetPeerOutAnchor() != nullptr, - "Input is not connected, index[%u], node[%s].", index, node->GetName().c_str()); - } - } - return true; -} - -bool AscGraph::CheckValid() const { - if (!CheckExprValid()) { - return false; - } - if (!CheckAxisValid()) { - return false; - } - if (!CheckTensorValid()) { - return false; - } - if (!CheckNodeConnectionValid()) { - return false; - } - return true; -} - -TransInfoRoadOfGraph AscGraph::GetAllAxisTransInfo() const { - return impl_->GetAllAxisTransInfo(); -} - -AscOpOutput AscGraph::CreateContiguousData(const char *name, - const ge::DataType &dt, - const std::vector &axes, - const ge::Format &format) { - return impl_->CreateContiguousData(name, dt, axes, format); -} - -AscOpOutput AscGraph::CreateContiguousOut(const char *name, - const ge::DataType &dt, - const std::vector &axes, - const ge::Format &format) { - return impl_->CreateContiguousOut(name, dt, axes, format); -} - -graphStatus AddEdgeForNode(const ge::Operator &src_op, int32_t src_index, ge::Operator &dst_op, int32_t dst_index) { - auto src_node = ge::NodeUtilsEx::GetNodeFromOperator(src_op); - auto dst_node = ge::NodeUtilsEx::GetNodeFromOperator(dst_op); - GE_ASSERT_NOTNULL(src_node); - if (dst_node == nullptr) { - auto com_graph = src_node->GetOwnerComputeGraph(); - GE_ASSERT_NOTNULL(com_graph); - auto dst_op_desc = ge::OpDescUtils::GetOpDescFromOperator(dst_op); - auto dst_asc_node = std::make_shared(dst_op_desc, com_graph); - GE_ASSERT_NOTNULL(dst_asc_node); - (void) dst_asc_node->Init(); - ConstNodePtr const_dst_node = dst_asc_node; - GE_ASSERT_GRAPH_SUCCESS( - ge::NodeUtilsEx::SetNodeToOperator(dst_op, const_dst_node)); - dst_node = com_graph->AddNode(dst_asc_node); - GE_ASSERT_NOTNULL(dst_node); - GE_ASSERT_GRAPH_SUCCESS( - ge::GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_index), dst_node->GetInDataAnchor(dst_index))); - // update tensors - (void) dst_asc_node->inputs(); - (void) dst_asc_node->outputs(); - } else { - GE_ASSERT_GRAPH_SUCCESS( - ge::GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_index), dst_node->GetInDataAnchor(dst_index))); - } - return GRAPH_SUCCESS; -} - -int64_t AscOpOutput::GenContainerId() { - GE_ASSERT_NOTNULL(op_); - return ascir::cg::CodeGenUtils::GenNextContainerId(*op_); -} - -int64_t AscOpOutput::GenNextReuseId() { - GE_ASSERT_NOTNULL(op_); - return ascir::cg::CodeGenUtils::GenNextReuseId(*op_); -} - -bool AscOpOutput::UseTQue(const Position pos, const int64_t depth, const int64_t buf_num, const int64_t id) { - GE_ASSERT_TRUE(!HasBindToContainer(), - " this tensor has been bound to a que, can not use any other que."); - GE_ASSERT_TRUE(buf_num > 0, "input buf_num should be greater than 0."); - GE_ASSERT_TRUE(buf_num < static_cast(INT32_MAX), - "input buf_num should be less than INT32_MAX."); - GE_ASSERT_TRUE(depth > 0, "input depth should be greater than 0."); - GE_ASSERT_TRUE(depth < static_cast(INT32_MAX), - "input depth should be less than INT32_MAX."); - mem->position = pos; - mem->alloc_type = AllocType::kAllocTypeQueue; - buf->id = kIdNone; - que->depth = depth; - que->buf_num = buf_num; - if (id == kIdNone) { - que->id = GenContainerId(); - } else { - que->id = id; - } - return true; -} - -bool AscOpOutput::UseTBuf(const Position pos, const int64_t id) { - GE_ASSERT_TRUE(!HasBindToContainer(), - " this tensor has been bound to a buf, can not use any other buf."); - mem->position = pos; - mem->alloc_type = AllocType::kAllocTypeBuffer; - que->id = kIdNone; - if (id == kIdNone) { - buf->id = GenContainerId(); - } else { - buf->id = id; - } - return true; -} - -bool AscOpOutput::HasBindToContainer() const { - bool has_bind_que = (que->id != kIdNone); - bool has_bind_buf = (buf->id != kIdNone); - // 1.if alloc type has set to que or buffer means has binding to a container - // 2.if que/buf is valid, means also means has binding to a container - return ((mem->alloc_type == AllocType::kAllocTypeQueue) || (mem->alloc_type == AllocType::kAllocTypeBuffer)) && - (has_bind_que || has_bind_buf); -} - -// 既有动态输出,也有普通的输出,返回错误 -Status GetAndCheckDynamicOutput(const std::vector> &ir_outputs, - bool &only_has_one_dynamic_output) { - bool has_dynamic_output = false; - bool has_com_output = false; - for (auto &ir_output : ir_outputs) { - if (ir_output.second == ge::IrOutputType::kIrOutputDynamic) { - has_dynamic_output = true; - } else { - has_com_output = true; - } - } - only_has_one_dynamic_output = (has_dynamic_output) && (ir_outputs.size() == 1U); - - return (has_dynamic_output && has_com_output) ? ge::FAILED : ge::SUCCESS; -} - -graphStatus LinkByIrIndex(const ge::Operator &src_op, - uint32_t src_ir_index, - ge::Operator &dst_op, - uint32_t dst_ir_index, - uint32_t dynamic_index) { - auto dst_op_desc = ge::OpDescUtils::GetOpDescFromOperator(dst_op); - auto src_op_desc = ge::OpDescUtils::GetOpDescFromOperator(src_op); - GE_ASSERT_NOTNULL(src_op_desc); - GE_ASSERT_NOTNULL(dst_op_desc); - const std::vector> &ir_inputs = dst_op_desc->GetIrInputs(); - const std::vector> &ir_outputs = src_op_desc->GetIrOutputs(); - bool only_has_one_dynamic_output = false; - GE_ASSERT_SUCCESS(GetAndCheckDynamicOutput(ir_outputs, only_has_one_dynamic_output), - "Not supporting both dynamic and non dynamic outputs"); - - GE_ASSERT_TRUE(dst_ir_index < ir_inputs.size(), - "dst_ir_index = %u, ir_inputs size = %zu", dst_ir_index, ir_inputs.size()); - auto &name_to_input_idx = dst_op_desc->MutableAllInputName(); - auto &name_to_output_idx = src_op_desc->MutableAllOutputName(); - uint32_t src_index; - uint32_t dst_index; - if (ir_inputs[dst_ir_index].second == ge::IrInputType::kIrInputDynamic) { - std::map> ir_input_2_range; - (void) ge::OpDescUtils::GetIrInputInstanceDescRange(dst_op_desc, ir_input_2_range); - dst_index = ir_input_2_range[dst_ir_index].first + dynamic_index; - } else { - dst_index = name_to_input_idx[ir_inputs[dst_ir_index].first]; - } - - if (only_has_one_dynamic_output) { - src_index = src_ir_index; - } else { - src_index = name_to_output_idx[ir_outputs[src_ir_index].first]; - } - - dst_op.SetInput(dst_index, src_op, src_index); - AddEdgeForNode(src_op, static_cast(src_index), dst_op, static_cast(dst_index)); - - return GRAPH_SUCCESS; -} - -graphStatus SetDynamicInputNumByIrIndex(ge::Operator &op, uint32_t ir_index, uint32_t dynamic_num) { - auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(op); - const std::vector> &ir_inputs = op_desc->GetIrInputs(); - GE_ASSERT_TRUE(ir_index < ir_inputs.size()); - GE_ASSERT_TRUE(ir_inputs[ir_index].second == ge::IrInputType::kIrInputDynamic); - std::map> ir_input_2_range; - (void) ge::OpDescUtils::GetIrInputInstanceDescRange(op_desc, ir_input_2_range); - - GE_ASSERT_TRUE(ir_input_2_range[ir_index].second < dynamic_num, - "Dynamic index [%u] is invalid.", dynamic_num); - op_desc->AddDynamicInputDescByIndex(ir_inputs[ir_index].first, dynamic_num, ir_input_2_range[ir_index].first); - GELOGD("Add DynamicInputDescByIndex for op_desc[%s], ir_index[%u], dynamic_num[%u]", op_desc->GetNamePtr(), ir_index, - dynamic_num); - return GRAPH_SUCCESS; -} - -graphStatus AscGraphAttr::SerializeAttr(ascendc_ir::proto::AscGraphAttrGroupsDef &asc_graph_group) { - asc_graph_group.set_tiling_key(tiling_key); - // axis serialize - auto axis_defs = asc_graph_group.axis(); - for (const auto &ax : axis) { - auto ax_def = asc_graph_group.add_axis(); - ax_def->set_id(ax->id); - ax_def->set_name(ax->name); - ax_def->set_axis_type(ax->type); - ax_def->set_bind_block(ax->bind_block); - ax_def->set_size(SymbolicUtils::ToString(ax->size)); - ax_def->set_align(ax->align); - for (const auto fm : ax->from) { - ax_def->add_from(fm); - } - ax_def->set_split_pair_other_id(ax->split_pair_other_id); - ax_def->set_allow_oversize_axis(ax->allow_oversize_axis); - ax_def->set_allow_unaligned_tail(ax->allow_unaligned_tail); - } - for (const auto &var : size_vars) { - asc_graph_group.add_size_var(SymbolicUtils::ToString(var->expr)); - } - asc_graph_group.set_type(static_cast(type)); - GELOGD("Graph serialization successful, tiling_key[%ld] type[%ld]", tiling_key, static_cast(type)); - return GRAPH_SUCCESS; -} -graphStatus AscGraphAttr::Serialize(proto::AttrGroupDef &attr_group_def) { - auto asc_graph_attr_group = attr_group_def.mutable_asc_graph_attr_group(); - GE_ASSERT_NOTNULL(asc_graph_attr_group); - return SerializeAttr(*asc_graph_attr_group); -} - -graphStatus AscGraphAttr::DeserializeAttr(const ascendc_ir::proto::AscGraphAttrGroupsDef &asc_graph_group) { - tiling_key = asc_graph_group.tiling_key(); - type = static_cast(asc_graph_group.type()); - for (const auto &ax : asc_graph_group.axis()) { - auto new_axis = std::make_shared(); - GE_ASSERT_NOTNULL(new_axis); - new_axis->id = ax.id(); - new_axis->name = ax.name(); - new_axis->type = static_cast(ax.axis_type()); - new_axis->bind_block = ax.bind_block(); - new_axis->size = Expression::Deserialize(ax.size().c_str()); - new_axis->align = ax.align(); - for (const auto &fm : ax.from()) { - new_axis->from.emplace_back(fm); - } - new_axis->split_pair_other_id = ax.split_pair_other_id(); - new_axis->allow_oversize_axis = ax.allow_oversize_axis(); - new_axis->allow_unaligned_tail = ax.allow_unaligned_tail(); - axis.emplace_back(new_axis); - } - for (const auto &var : asc_graph_group.size_var()) { - auto new_size_var = std::make_shared(Expression::Deserialize(var.c_str())); - size_vars.emplace_back(new_size_var); - } - type = static_cast(asc_graph_group.type()); - GELOGD("Graph deserialization successful, tiling_key[%ld], type[%ld]", tiling_key, asc_graph_group.type()); - return GRAPH_SUCCESS; -} - -graphStatus AscGraphAttr::Deserialize(const proto::AttrGroupDef &attr_group_def, AttrHolder *attr_holder) { - (void) attr_holder; - const auto &asc_graph_attr_group_def = attr_group_def.asc_graph_attr_group(); - return DeserializeAttr(asc_graph_attr_group_def); -} - - -graphStatus AscNodeAttr::SerializeAttr(ascendc_ir::proto::AscNodeAttrGroupsDef &asc_node_group) const{ - asc_node_group.set_name(name); - asc_node_group.set_type(type); - auto sched_def = asc_node_group.mutable_sched(); - sched_def->set_exec_order(sched.exec_order); - for (const int64_t axis_id : sched.axis) { - sched_def->add_axis(axis_id); - } - sched_def->set_loop_axis(sched.loop_axis); - sched_def->set_exec_order(sched.exec_order); - sched_def->set_exec_condition(static_cast(sched.exec_condition)); - auto api_def = asc_node_group.mutable_api(); - api_def->set_type(static_cast(api.type)); - api_def->set_compute_type(static_cast(api.compute_type)); - api_def->set_unit(static_cast(api.unit)); - if (ir_attr != nullptr) { - ir_attr->Serialize(*(asc_node_group.mutable_ir_attr_def())); - } - for (const auto &tmp_buffer : tmp_buffers) { - auto tmp_buffer_def = asc_node_group.add_tmp_buffers(); - auto buf_desc_def = tmp_buffer_def->mutable_buf_desc(); - buf_desc_def->set_size(SymbolicUtils::ToString(tmp_buffer.buf_desc.size)); - buf_desc_def->set_life_time_axis_id(tmp_buffer.buf_desc.life_time_axis_id); - auto mem_def = tmp_buffer_def->mutable_mem(); - mem_def->set_tensor_id(tmp_buffer.mem.tensor_id); - mem_def->set_alloc_type(static_cast(tmp_buffer.mem.alloc_type)); - mem_def->set_position(static_cast(tmp_buffer.mem.position)); - mem_def->set_hardware(static_cast(tmp_buffer.mem.hardware)); - mem_def->set_reuse_id(static_cast(tmp_buffer.mem.reuse_id)); - for (const int64_t buf_id : tmp_buffer.mem.buf_ids) { - mem_def->add_buf_ids(buf_id); - } - mem_def->set_name(tmp_buffer.mem.name); - } - GELOGD("Serialize node[%s:%s] success.", name.c_str(), type.c_str()); - return GRAPH_SUCCESS; -} - -graphStatus AscNodeAttr::DeserializeAttr(const ascendc_ir::proto::AscNodeAttrGroupsDef &asc_node_group) { - name = asc_node_group.name(); - type = asc_node_group.type(); - const auto &sched_def = asc_node_group.sched(); - for (const auto &ax : sched_def.axis()) { - sched.axis.emplace_back(ax); - } - sched.loop_axis = sched_def.loop_axis(); - sched.exec_order = sched_def.exec_order(); - sched.exec_condition = static_cast(sched_def.exec_condition()); - const auto &api_def = asc_node_group.api(); - api.type = static_cast((api_def.type())); - api.compute_type = static_cast(api_def.compute_type()); - api.unit = static_cast(api_def.unit()); - if (asc_node_group.has_ir_attr_def()) { - if (ir_attr == nullptr) { - ir_attr = ComGraphMakeUnique(); - } - GE_ASSERT_NOTNULL(ir_attr); - ir_attr->Deserialize(asc_node_group.ir_attr_def()); - } - for (const auto &tmp_buffer_def : asc_node_group.tmp_buffers()) { - TmpBufDesc new_tmp_buffer_desc; - new_tmp_buffer_desc.size = Expression::Deserialize(tmp_buffer_def.buf_desc().size().c_str()); - new_tmp_buffer_desc.life_time_axis_id = tmp_buffer_def.buf_desc().life_time_axis_id(); - MemAttr new_mem_attr; - new_mem_attr.name = tmp_buffer_def.mem().name(); - new_mem_attr.tensor_id = tmp_buffer_def.mem().tensor_id(); - new_mem_attr.alloc_type = static_cast(tmp_buffer_def.mem().alloc_type()); - new_mem_attr.position = static_cast(tmp_buffer_def.mem().position()); - new_mem_attr.hardware = static_cast(tmp_buffer_def.mem().hardware()); - new_mem_attr.reuse_id = tmp_buffer_def.mem().reuse_id(); - for (const int64_t buf_id : tmp_buffer_def.mem().buf_ids()) { - new_mem_attr.buf_ids.emplace_back(buf_id); - } - TmpBuffer new_tmp_buffer; - new_tmp_buffer.buf_desc = new_tmp_buffer_desc; - new_tmp_buffer.mem = new_mem_attr; - tmp_buffers.emplace_back(new_tmp_buffer); - } - return GRAPH_SUCCESS; -} - -graphStatus AscNodeAttr::Serialize(proto::AttrGroupDef &attr_group_def) { - auto asc_node_attr = attr_group_def.mutable_asc_node_attr_group(); - GE_ASSERT_NOTNULL(asc_node_attr); - return SerializeAttr(*asc_node_attr); -} - -graphStatus AscNodeAttr::Deserialize(const proto::AttrGroupDef &attr_group_def, AttrHolder *attr_holder) { - (void) attr_holder; - const auto &asc_node_attr_def = attr_group_def.asc_node_attr_group(); - return DeserializeAttr(asc_node_attr_def); -} - -graphStatus AscTensorAttr::SerializeAttr(ascendc_ir::proto::AscTensorAttrGroupsDef &asc_tensor_group) { - if (dtype.tensor_desc_ != nullptr) { - asc_tensor_group.set_dtype(static_cast(dtype)); - } - for (const int64_t axis_id : axis) { - asc_tensor_group.add_axis_ids(axis_id); - } - for (const auto &repeat : repeats) { - asc_tensor_group.add_repeats(SymbolicUtils::ToString(repeat)); - } - for (const auto &stride : strides) { - asc_tensor_group.add_strides(SymbolicUtils::ToString(stride)); - } - for (const auto &vectorized_axis_id : vectorized_axis) { - asc_tensor_group.add_vectorized_axis(vectorized_axis_id); - } - for (const auto &vectorized_stride : vectorized_strides) { - asc_tensor_group.add_vectorized_strides(SymbolicUtils::ToString(vectorized_stride)); - } - auto mem_def = asc_tensor_group.mutable_mem(); - mem_def->set_tensor_id(mem.tensor_id); - mem_def->set_alloc_type(static_cast(mem.alloc_type)); - mem_def->set_position(static_cast(mem.position)); - mem_def->set_hardware(static_cast(mem.hardware)); - for (const int64_t buf_id : mem.buf_ids) { - mem_def->add_buf_ids(buf_id); - } - mem_def->set_name(mem.name); - auto que_def = asc_tensor_group.mutable_que(); - que_def->set_id(que.id); - que_def->set_depth(que.depth); - que_def->set_buf_num(que.buf_num); - que_def->set_name(que.name); - auto buf_def = asc_tensor_group.mutable_buf(); - buf_def->set_id(buf.id); - buf_def->set_name(buf.name); - auto opt_def = asc_tensor_group.mutable_opt(); - opt_def->set_reuse_id(opt.reuse_id); - opt_def->set_ref_tensor(opt.ref_tensor); - opt_def->set_merge_scope(opt.merge_scope); - return GRAPH_SUCCESS; -} - -graphStatus AscTensorAttr::DeserializeAttr(const ascendc_ir::proto::AscTensorAttrGroupsDef &asc_tensor_group, - GeTensorDesc *tensor_desc) { - if ((tensor_desc != nullptr) && (dtype.tensor_desc_ == nullptr)) { - dtype.tensor_desc_ = tensor_desc; - } - dtype.tensor_desc_->SetDataType(static_cast(asc_tensor_group.dtype())); - for (const auto &axis_id : asc_tensor_group.axis_ids()) { - axis.emplace_back(axis_id); - } - const auto &repeat_defs = asc_tensor_group.repeats(); - for (const auto &repeat : repeat_defs) { - repeats.emplace_back(Expression::Deserialize(repeat.c_str())); - } - const auto &strides_defs = asc_tensor_group.strides(); - for (const auto &stride : strides_defs) { - strides.emplace_back(Expression::Deserialize(stride.c_str())); - } - const auto &vectorized_axis_ids = asc_tensor_group.vectorized_axis(); - for (const auto &vectorized_axis_id : vectorized_axis_ids) { - vectorized_axis.emplace_back(vectorized_axis_id); - } - const auto &vectorized_strides_def = asc_tensor_group.vectorized_strides(); - for (const auto &vectorized_stride : vectorized_strides_def) { - vectorized_strides.emplace_back(Expression::Deserialize(vectorized_stride.c_str())); - } - const auto &mem_def = asc_tensor_group.mem(); - mem.name = mem_def.name(); - mem.tensor_id = mem_def.tensor_id(); - mem.alloc_type = static_cast(mem_def.alloc_type()); - mem.position = static_cast(mem_def.position()); - mem.hardware = static_cast(mem_def.hardware()); - for (const int64_t buf_id : mem_def.buf_ids()) { - mem.buf_ids.emplace_back(buf_id); - } - mem.name = mem_def.name(); - const auto &que_def = asc_tensor_group.que(); - que.id = que_def.id(); - que.name = que_def.name(); - que.depth = que_def.depth(); - que.buf_num = que_def.buf_num(); - const auto &buf_def = asc_tensor_group.buf(); - buf.id = buf_def.id(); - buf.name = buf_def.name(); - const auto &opt_def = asc_tensor_group.opt(); - opt.merge_scope = opt_def.merge_scope(); - opt.ref_tensor = opt_def.ref_tensor(); - opt.reuse_id = opt_def.reuse_id(); - return GRAPH_SUCCESS; -} - -graphStatus AscTensorAttr::Serialize(proto::AttrGroupDef &attr_group_def) { - auto asc_tensor_attr_group = attr_group_def.mutable_asc_tensor_attr_group(); - GE_ASSERT_NOTNULL(asc_tensor_attr_group); - return SerializeAttr(*asc_tensor_attr_group); -} - -graphStatus AscTensorAttr::Deserialize(const proto::AttrGroupDef &attr_group_def, AttrHolder *attr_holder) { - const auto &asc_tensor_attr_group_def = attr_group_def.asc_tensor_attr_group(); - return DeserializeAttr(asc_tensor_attr_group_def, dynamic_cast(attr_holder)); -} - -graphStatus AscIrAttrDefBase::Serialize(ascendc_ir::proto::AscIrAttrDef &asc_ir_attr_def) { - std::map names_to_attr; - attr_store_.GetAllAttrs(names_to_attr); - auto &attr_map = *asc_ir_attr_def.mutable_attr(); - for (const auto &pair:names_to_attr) { - const auto serializer = AttrSerializerRegistry::GetInstance().GetSerializer( - pair.second.GetValueTypeId()); - GE_ASSERT_NOTNULL(serializer); - proto::AttrDef attr_def; - GE_ASSERT_GRAPH_SUCCESS(serializer->Serialize(pair.second, attr_def)); - attr_map[pair.first] = attr_def; - } - return GRAPH_SUCCESS; -} - -graphStatus AscIrAttrDefBase::Deserialize(const ascendc_ir::proto::AscIrAttrDef &asc_ir_attr_def) { - const auto &attr_map = asc_ir_attr_def.attr(); - for (const auto &pair:attr_map) { - const auto deserializer = AttrSerializerRegistry::GetInstance() - .GetDeserializer(pair.second.value_case()); - GE_ASSERT_NOTNULL(deserializer); - auto attr_value = attr_store_.GetOrCreateAnyValue(pair.first); - GE_ASSERT_NOTNULL(attr_value); - GE_ASSERT_GRAPH_SUCCESS(deserializer->Deserialize(pair.second, *attr_value)); - } - return GRAPH_SUCCESS; -} - -std::unique_ptr AscIrAttrDefBase::Clone() { - auto ptr = ComGraphMakeUnique(); - GE_ASSERT_NOTNULL(ptr); - ptr->attr_store_ = this->attr_store_; - return ptr; -} - -graphStatus AscDataIrAttrDef::GetIndex(int64_t &index) const { - auto value = attr_store_.GetAnyValue(kDataIndex); - GE_WARN_ASSERT(value != nullptr); - return value->GetValue(index); -} - -graphStatus AscDataIrAttrDef::SetIndex(int64_t index) { - auto value = attr_store_.GetOrCreateAnyValue(kDataIndex); - GE_ASSERT_NOTNULL(value); - return value->SetValue(index); -} -REG_ATTR_GROUP_SERIALIZER(AscNodeAttr, AscNodeAttr, GetTypeId(), proto::AttrGroupDef::kAscNodeAttrGroup); -REG_ATTR_GROUP_SERIALIZER(AscGraphAttr, - AscGraphAttr, - GetTypeId(), - proto::AttrGroupDef::kAscGraphAttrGroup); -REG_ATTR_GROUP_SERIALIZER(AscTensorAttr, - AscTensorAttr, - GetTypeId(), - proto::AttrGroupDef::kAscTensorAttrGroup); -REG_ATTR_GROUP_SERIALIZER(ShapeEnvAttr, - ShapeEnvAttr, - GetTypeId(), - proto::AttrGroupDef::kShapeEnvAttrGroup); -REG_ATTR_GROUP_SERIALIZER(SymbolicDescAttr, - SymbolicDescAttr, - GetTypeId(), - proto::AttrGroupDef::kTensorAttrGroup); -} // namespace ge diff --git a/graph/ascendc_ir/core/ascendc_ir_impl.h b/graph/ascendc_ir/core/ascendc_ir_impl.h deleted file mode 100644 index 03fb07176432d6d8153579c1afa52e427feedeef..0000000000000000000000000000000000000000 --- a/graph/ascendc_ir/core/ascendc_ir_impl.h +++ /dev/null @@ -1,164 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_ASCENDC_IR_IMPL_H -#define GRAPH_ASCENDC_IR_IMPL_H - -#include -#include "attr_store.h" -#include "graph/compute_graph.h" -#include "graph/node.h" -#include "graph/anchor.h" -#include "graph/utils/op_desc_utils.h" -#include "common/ge_common/debug/ge_log.h" -#include "external/graph/operator.h" - -namespace ge { -namespace ascir { -namespace cg { -class CodeGenUtils; -} -} -class AscGraphImpl { - friend class AscGraph; - friend class AscGraphUtils; - friend class ascir::cg::CodeGenUtils; - public: - explicit AscGraphImpl(const char *name); - - Axis *FindAxis(const int64_t axis_id) const; - - void SetTilingKey(const uint32_t tiling_key); - - int64_t GetTilingKey() const; - - void SetGraphType(const AscGraphType type); - - AscGraphType GetGraphType() const; - - AscNodePtr AddNode(ge::Operator &op); - - Expression CreateSizeVar(const int64_t value); - - AxisPtr CreateAxis(const std::string &name, Axis::Type type, const ge::Expression &size, - const std::vector &from, const int64_t split_peer = 0UL); - - graphStatus CreateSizeVar(const Expression &expression); - - Expression CreateSizeVar(const std::string &name); - - std::pair BlockSplit(const int64_t axis_id, const std::string &outer_axis_name, - const std::string &inner_axis_name); - - std::pair TileSplit(const int64_t axis_id, const std::string &outer_axis_name, - const std::string &inner_axis_name); - - AxisPtr MergeAxis(const std::vector &axis_ids, const std::string &merge_axis_name); - - bool BindBlock(const int64_t outter_id, const int64_t inner_id); - - bool ApplySplit(const AscNodePtr &node, const int64_t outter_id, const int64_t inner_id); - - bool ApplyMerge(const AscNodePtr &node, const int64_t merged_axis_id); - - static bool ApplyTensorAxisMerge(const AscNodePtr &node, - const int64_t merged_axis_id, - const std::vector &original); - - bool ApplyTensorAxisMerge(const AscNodePtr &node, const int64_t merged_axis_id); - - static bool ApplySchedAxisMerge(const AscNodePtr &node, - const int64_t merged_axis_id, - const std::vector &original); - - bool ApplySchedAxisMerge(const AscNodePtr &node, const int64_t merged_axis_id); - - static bool ApplyReorder(const AscNodePtr &node, const std::vector &reordered_axis); - - static bool ApplySchedAxisReorder(const AscNodePtr &node, const std::vector &reordered_axis); - - static bool ApplyTensorAxisReorder(const AscNodePtr &node, const std::vector &reordered_axis); - - static bool TryApplyAxisReplace(const AscNodePtr &node, const Axis &src, const Axis &dst); - - AscNodePtr FindNode(const char *name) const; - - std::vector GetAllAxis() const; - - std::vector GetAllSizeVar() const; - - TransInfoRoadOfGraph GetAllAxisTransInfo() const; - - AscNodeVisitor GetAllNodes() const; - - AscNodeVisitor GetInputNodes() const; - - std::string GetName() const; - - AscOpOutput CreateContiguousData(const char *name, - const ge::DataType &dt, - const std::vector &axes, - const ge::Format &format); - - AscOpOutput CreateContiguousOut(const char *name, - const ge::DataType &dt, - const std::vector &axes, - const ge::Format &format); - - void SortByExecOrder(); - - const ComputeGraphPtr GetComputeGraph() const; - - static bool CopyFrom(const ge::AscGraph &src_graph, ge::AscGraph &dst_graph); - private: - std::pair DoSplit(const int64_t axis_id, const std::string &outer_axis_name, - const std::string &inner_axis_name, const bool is_tile_split); - - bool DoApplySplit(const AscNodePtr &node, const int64_t outter_id, const int64_t inner_id, const int64_t original_id); - - static bool DoApplyMerge(const AscNodePtr &node, const int64_t merged_axis_id, const std::vector &original); - - static bool DoApplyTensorAxisMerge(const AscNodePtr &node, const int64_t merged_axis_id, - const std::vector &original); - - bool DoApplyTensorAxisSplit(const AscNodePtr &node, const int64_t outter_id, const int64_t inner_id, - const int64_t original_id); - - static bool DoApplySchedAxisMerge(const AscNodePtr &node, const int64_t merged_axis_id, - const std::vector &original); - - static bool DoApplySchedAxisSplit(const AscNodePtr &node, const int64_t outter_id, const int64_t inner_id, - const int64_t original_id); - - static bool DoApplySchedAxisReorder(const AscNodePtr &node, const std::vector &reordered_axis); - - static bool DoApplyTensorAxisReorder(const AscNodePtr &node, const std::vector &reordered_axis); - - static bool DoCopyAscGraphAttr(const AscGraph &src_asc_graph, AscGraph &dst_asc_graph); - static bool DoCopyAscGraphAttrImpl(const ComputeGraphPtr &src_compute_graph, - const ComputeGraphPtr &dst_compute_graph); - - static bool DoCopyAscNodeAndRelink(const AscGraph &src_asc_graph, AscGraph &dst_asc_graph); - - static bool DoCopyAscNodeTensorAttr(const AscNodePtr &src_node, AscNodePtr &dst_node); - - AscGraphAttr *GetOrCreateGraphAttrsGroup(); - AscGraphAttr *GetGraphAttrsGroup() const; - static bool CheckContinuous(const AscNodePtr &node, - const uint32_t tensor_index, - const std::vector &original); - Status AddSubGraph(const ComputeGraphPtr &sub_graph) const; - Status FindSubGraph(const std::string &name, std::shared_ptr &graph_impl) const; - private: - ComputeGraphPtr compute_graph_; -}; -using AscGraphImplPtr = std::shared_ptr; -} // namespace ge - -#endif // GRAPH_ASCENDC_IR_IMPL_H diff --git a/graph/ascendc_ir/generator/CMakeLists.txt b/graph/ascendc_ir/generator/CMakeLists.txt deleted file mode 100644 index c6437a6d9e9e367c401442107e392bc96ce026f5..0000000000000000000000000000000000000000 --- a/graph/ascendc_ir/generator/CMakeLists.txt +++ /dev/null @@ -1,62 +0,0 @@ -include(${METADEF_DIR}/cmake/build_type.cmake) -add_library(aihac_ir_register SHARED - ascir_register.cc - ascir_registry.cc - ) -target_clone_compile_and_link_options(graph_base aihac_ir_register) -target_include_directories(aihac_ir_register PRIVATE - ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/graph - ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/graph/utils - ${CMAKE_BINARY_DIR}/proto/metadef_protos) -target_link_libraries(aihac_ir_register PRIVATE - intf_pub - -Wl,--no-as-needed - $<$>:-lrt> - -ldl - graph - PUBLIC - aihac_ir - ascend_protobuf_shared_headers - slog - c_sec - metadef_headers) - -add_library(ascir_generate SHARED - generator.cc - ) -target_clone_compile_and_link_options(aihac_ir_register ascir_generate) - -target_include_directories(ascir_generate PRIVATE - ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/graph - ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/graph/utils - ${CMAKE_BINARY_DIR}/proto/metadef_protos) - -target_link_libraries(ascir_generate PRIVATE - intf_pub - -Wl,--no-as-needed - $<$>:-lrt> - -ldl - aihac_ir_register - aihac_ir - ascend_protobuf_shared_headers - PUBLIC - metadef_headers) - -add_executable(ascir_ops_header_generator ascir_ops_generator_main.cc) -target_compile_definitions(ascir_ops_header_generator PRIVATE - $<$:ONLY_COMPILE_OPEN_SRC> -) -target_link_libraries(ascir_ops_header_generator PRIVATE ascir_generate - intf_pub - -Wl,--no-as-needed - $<$>:-lrt> - -ldl - c_sec - static_mmpa) -include(generator.cmake) - - diff --git a/graph/ascendc_ir/generator/ascir_ops_generator_main.cc b/graph/ascendc_ir/generator/ascir_ops_generator_main.cc deleted file mode 100644 index 8ef7a32a9d5871eedccb5b84742915d285696347..0000000000000000000000000000000000000000 --- a/graph/ascendc_ir/generator/ascir_ops_generator_main.cc +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ -#include -#include "mmpa/mmpa_api.h" -#include "generator.h" -#include "inc/common/util/sanitizer_options.h" - -int main(int argc, char *argv[]) { - constexpr int kExpectArgNum = 3; - if (argc != kExpectArgNum) { - std::cerr << "Arg format: ascir_ops_header_generator " << std::endl; - return 1; - } - void *const handle = mmDlopen( - argv[1], static_cast(static_cast(MMPA_RTLD_NOW) | static_cast(MMPA_RTLD_GLOBAL))); - if (handle == nullptr) { - const auto *error = mmDlerror(); - error = (error == nullptr) ? "" : error; - std::cerr << "dlopen failed, so name:" << argv[1] << ", error info:" << error << std::endl; - return 1; - } - const auto ret = ge::ascir::GenHeaderFile(argv[kExpectArgNum - 1]); - (void) mmDlclose(handle); - DT_DO_DETECT_LEAKS(); - return ret; -} \ No newline at end of file diff --git a/graph/ascendc_ir/generator/ascir_register.cc b/graph/ascendc_ir/generator/ascir_register.cc deleted file mode 100644 index 25e421832e8836fb8158cf217fcf2da6d1f5b85f..0000000000000000000000000000000000000000 --- a/graph/ascendc_ir/generator/ascir_register.cc +++ /dev/null @@ -1,182 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ -#include "graph/ascendc_ir/ascir_register.h" -#include "graph/ascendc_ir/ascir_registry.h" -#include "graph/types.h" -namespace ge { -namespace ascir { -AscirRegister::AscirRegister(const char *type, const char *def_file_path, int64_t line) : ir_def_{} { - ir_def_.Init(type, def_file_path, line); -} - -AscirRegister &AscirRegister::Inputs(std::vector &&input_names) { - for (const auto &input_name : input_names) { - ir_def_.AppendInput(input_name.GetString(), ge::IrInputType::kIrInputRequired); - } - return *this; -} - -AscirRegister &AscirRegister::DynamicInput(const std::string &input_name) { - ir_def_.AppendInput(input_name, ge::IrInputType::kIrInputDynamic); - return *this; -} - -AscirRegister &AscirRegister::OptionalInput(const std::string &input_name) { - ir_def_.AppendInput(input_name, ge::IrInputType::kIrInputOptional); - return *this; -} - -AscirRegister &AscirRegister::Outputs(std::vector &&output_names) { - for (const auto &output_name : output_names) { - ir_def_.AppendOutput(output_name.GetString(), ge::IrOutputType::kIrOutputRequired); - } - return *this; -} - -AscirRegister &AscirRegister::DynamicOutput(const std::string &output_name) { - ir_def_.AppendOutput(output_name, ge::IrOutputType::kIrOutputDynamic); - return *this; -} - -AscirRegister::AscirRegister(const AscirRegister &other) { - AscirRegistry::GetInstance().RegisterAscIr(other.ir_def_.GetType(), other.ir_def_); -} -AscirRegister &AscirRegister::Attr(std::string name, std::string asc_type, std::string ge_type) { - if (ir_def_.IsAttrExisted(name)) { - return *this; - } - ir_def_.SetAttr(name, asc_type, ge_type); - return *this; -} -AscirRegister &AscirRegister::StartNode() { - ir_def_.StartNode(); - return *this; -} -AscirRegister &AscirRegister::InferDataType(AscIrDef::CodeGenerator infer_data_type_generator) { - ir_def_.infer_data_type_generator = std::move(infer_data_type_generator); - return *this; -} -AscirRegister &AscirRegister::InferView(AscIrDef::CodeGenerator infer_view_generator) { - ir_def_.infer_view_generator = std::move(infer_view_generator); - return *this; -} - -AscirRegister &AscirRegister::Views(const std::vector &views_policy) { - ir_def_.SetViewPolicy(views_policy); - return InferView(InferViewByPolicy); -} -AscirRegister &AscirRegister::DataTypes(const std::vector &data_types_policy) { - ir_def_.SetDtypePolicy(data_types_policy); - return InferDataType(InferDtypeByPolicy); -} - -AscirRegister &AscirRegister::Input(const char_t *input_name, const char_t *datatype_symbol) { - ir_def_.AppendInput(input_name, ge::IrInputType::kIrInputRequired); - ir_def_.StoreInputIrSymName(input_name, datatype_symbol); - ir_def_.MutableDataTypeSymbolStore().SetInputSymbol(input_name, ge::kIrInputRequired, datatype_symbol); - return *this; -} -AscirRegister &AscirRegister::Output(const char_t *output_name, const char_t *datatype_symbol) { - ir_def_.AppendOutput(output_name, ge::IrOutputType::kIrOutputRequired); - ir_def_.StoreOutputIrSymName(output_name, datatype_symbol); - ir_def_.MutableDataTypeSymbolStore().SetOutputSymbol(output_name, ge::kIrOutputRequired, datatype_symbol); - return *this; -} -AscirRegister &AscirRegister::DataType(const char_t *datatype_symbol, const TensorType &type_range) { - ir_def_.MutableDataTypeSymbolStore().DeclareSymbol(datatype_symbol, type_range); - return *this; -} - -AscirRegister &AscirRegister::DynamicInput(const char_t *input_name, const char_t *datatype_symbol) { - ir_def_.AppendInput(input_name, ge::IrInputType::kIrInputDynamic); - ir_def_.StoreInputIrSymName(input_name, datatype_symbol); - ir_def_.MutableDataTypeSymbolStore().SetInputSymbol(input_name, ge::kIrInputDynamic, datatype_symbol); - return *this; -} - -AscirRegister &AscirRegister::DynamicOutput(const char_t *output_name, const char_t *datatype_symbol) { - ir_def_.AppendOutput(output_name, ge::IrOutputType::kIrOutputDynamic); - ir_def_.StoreOutputIrSymName(output_name, datatype_symbol); - ir_def_.MutableDataTypeSymbolStore().SetOutputSymbol(output_name, ge::kIrOutputDynamic, datatype_symbol); - return *this; -} - -AscirRegister &AscirRegister::DataType(const char_t *datatype_symbol, const OrderedTensorTypeList &type_range) { - ir_def_.MutableDataTypeSymbolStore().DeclareSymbol(datatype_symbol, type_range); - return *this; -} - -AscirRegister &AscirRegister::CalcTmpBufSize(const std::string &calc_tmp_buf_size_func) { - ir_def_.SetCalcTmpBufSizeFunc(calc_tmp_buf_size_func, CalcTmpBufSizeFuncType::CustomizeType); - return *this; -} -AscirRegister &AscirRegister::SameTmpBufSizeFromFirstInput() { - ir_def_.SetCalcTmpBufSizeFunc("SameTmpBufSizeWithFirstInput", CalcTmpBufSizeFuncType::CommonType); - return *this; -} - -AscirRegister &AscirRegister::ApiTilingDataType(const std::string &tiling_data_name) { - ir_def_.SetApiTilingDataName(tiling_data_name); - return *this; -} - -AscirRegister &AscirRegister::Impl(const std::vector &soc_versions, const AscIrImpl &impl) { - ir_def_.AddSocImpl(soc_versions, impl); - return *this; -} - -AscirRegister &AscirRegister::Impl(const std::vector &soc_versions, const AscIrImplV2 &impl) { - ir_def_.AddSocImplV2(soc_versions, impl); - return *this; -} - -size_t AscirRegister::GetSocImplSize() const { - return ir_def_.GetSocImplSize(); -} - -template<> -AscirRegister &AscirRegister::Attr(ge::AscendString &&name) { - return Attr(name.GetString(), "float", "Float"); -} - -template<> -AscirRegister &AscirRegister::Attr(ge::AscendString &&name) { - return Attr(name.GetString(), "ge::DataType", "Int"); -} -template<> -AscirRegister &AscirRegister::Attr(ge::AscendString &&name) { - return Attr(name.GetString(), "ge::Tensor", "Tensor"); -} -template<> -AscirRegister &AscirRegister::Attr(ge::AscendString &&name) { - return Attr(name.GetString(), "std::string", "String"); -} -template<> -AscirRegister &AscirRegister::Attr(ge::AscendString &&name) { - return Attr(name.GetString(), "int64_t", "Int"); -} -template<> -AscirRegister &AscirRegister::Attr>>(ge::AscendString &&name) { - return Attr(name.GetString(), "std::vector>", "ListListInt"); -} -template<> -AscirRegister &AscirRegister::Attr(ge::AscendString &&name) { - return Attr(name.GetString(), "ge::Format", "Int"); -} -template<> -AscirRegister &AscirRegister::Attr(ge::AscendString &&name) { - return Attr(name.GetString(), "ge::Expression", "ge::Expression"); -} - -AscirRegister &AscirRegister::Comment(const string &comment) { - ir_def_.SetComment(comment); - return *this; -} -} // namespace ascir -} // namespace ge diff --git a/graph/ascendc_ir/generator/ascir_registry.cc b/graph/ascendc_ir/generator/ascir_registry.cc deleted file mode 100644 index 2ca0a459c161a47d6413b06b8ef1183a286317de..0000000000000000000000000000000000000000 --- a/graph/ascendc_ir/generator/ascir_registry.cc +++ /dev/null @@ -1,289 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ -#include -#include "graph/ascendc_ir/ascir_registry.h" -namespace ge { -namespace ascir { -struct AscIrDefImpl { - public: - std::unique_ptr GetAscIrAttImpl(const std::string &soc_version) { - auto impl = soc_2_impl_.find(soc_version); - if (impl != soc_2_impl_.end()) { - return (impl->second.att == nullptr) ? nullptr : impl->second.att(); - } - - auto impl_v2 = soc_2_impl_v2_.find(soc_version); - if (impl_v2 != soc_2_impl_v2_.end()) { - return (impl_v2->second.att == nullptr) ? nullptr : impl_v2->second.att(); - } - return nullptr; - } - std::unique_ptr GetAscIrCodegenImpl(const std::string &soc_version) { - auto impl = soc_2_impl_.find(soc_version); - if (impl != soc_2_impl_.end()) { - return (impl->second.codegen == nullptr) ? nullptr : impl->second.codegen(); - } - - auto impl_v2 = soc_2_impl_v2_.find(soc_version); - if (impl_v2 != soc_2_impl_v2_.end()) { - return (impl_v2->second.codegen == nullptr) ? nullptr : impl_v2->second.codegen(); - } - - return nullptr; - } - - std::string file_path; - int64_t line{}; - std::string type; - std::vector> input_defs; - std::vector> output_defs; - std::unordered_map input_name_to_sym_name; - std::unordered_map output_name_to_sym_name; - std::vector attr_defs; - - std::vector output_views_policy; - std::vector output_dtypes_policy; - - bool start_node{false}; - // TODO 整改后删除 - IRDataTypeSymbolStore dtype_symbol_store; - std::string comment; - CalcTmpBufSizeFunc calc_tmp_buf_size_func; - std::string tiling_data_name; - std::map soc_2_impl_; - std::map soc_2_impl_v2_; - std::map soc_2_dtype_sym_store_; -}; - -AscIrDef::AscIrDef() { - impl_ = std::make_shared(); -} - -bool AscIrDef::IsAttrExisted(const std::string &attr_name) const { - return std::find_if(impl_->attr_defs.begin(), impl_->attr_defs.end(), - [&attr_name](const AscIrAttrDef &asc_ir_attr_def) { - return asc_ir_attr_def.name == attr_name; - }) != impl_->attr_defs.end(); -} - -void AscIrDef::Init(const char *type, const char *def_file_path, int64_t line) const { - impl_->type = type; - impl_->file_path = def_file_path; - impl_->line = line; - impl_->start_node = false; -} - -const std::vector> &AscIrDef::GetInputDefs() const { - return impl_->input_defs; -} -const std::vector> &AscIrDef::GetOutputDefs() const { - return impl_->output_defs; -} - -bool AscIrDef::HasDynamicOutput() const { - for (auto &def : impl_->output_defs) { - if (def.second == ge::IrOutputType::kIrOutputDynamic) { - return true; - } - } - - return false; -} - -void AscIrDef::AppendInput(const std::string &name, ge::IrInputType type) const { - impl_->input_defs.emplace_back(name, type); -} -void AscIrDef::AppendOutput(const std::string &name, ge::IrOutputType type) const { - impl_->output_defs.emplace_back(name, type); -} - -void AscIrDef::StoreInputIrSymName(const std::string &ir_name, const std::string &sym_name) const { - impl_->input_name_to_sym_name[ir_name] = sym_name; -} -void AscIrDef::StoreOutputIrSymName(const std::string &ir_name, const std::string &sym_name) const { - impl_->output_name_to_sym_name[ir_name] = sym_name; -} - -const std::string &AscIrDef::GetType() const { - return impl_->type; -} -void AscIrDef::StartNode() const { - impl_->start_node = true; -} - -bool AscIrDef::IsStartNode() const { - return impl_->start_node; -} - -void AscIrDef::SetAttr(const std::string &name, const std::string &asc_type, const std::string &ge_type) const { - impl_->attr_defs.emplace_back(AscIrAttrDef{name, asc_type, ge_type}); -} - -void AscIrDef::SetDtypePolicy(const std::vector &output_dtypes_policy) const { - impl_->output_dtypes_policy = output_dtypes_policy; -} - -const std::vector &AscIrDef::GetOutputDtypePolicy() const { - return impl_->output_dtypes_policy; -} - -void AscIrDef::SetViewPolicy(const std::vector &view_policy) const { - impl_->output_views_policy = view_policy; -} - -const std::vector &AscIrDef::GetViewPolicy() const { - return impl_->output_views_policy; -} - -void AscIrDef::SetApiTilingDataName(const std::string &tiling_data_name) const { - if (!impl_->tiling_data_name.empty()) { - GELOGE(ge::FAILED, "%s has registered tiling data: %s", impl_->type.c_str(), impl_->tiling_data_name.c_str()); - return; - } - impl_->tiling_data_name = tiling_data_name; -} - -const string &AscIrDef::GetApiTilingDataName() const { - return impl_->tiling_data_name; -} - -void AscIrDef::SetCalcTmpBufSizeFunc(const std::string &calc_tmp_buf_size_func, CalcTmpBufSizeFuncType type) const { - if (!impl_->calc_tmp_buf_size_func.func_name.empty()) { - GELOGE(ge::FAILED, "has registered calc_tmp_buf_size_func: %s", impl_->calc_tmp_buf_size_func.func_name.c_str()); - return; - } - impl_->calc_tmp_buf_size_func = CalcTmpBufSizeFunc{calc_tmp_buf_size_func, type}; -} - -const CalcTmpBufSizeFunc &AscIrDef::GetCalcTmpBufSizeFunc() const { - return impl_->calc_tmp_buf_size_func; -} - -const std::vector &AscIrDef::GetAttrDefs() const { - return impl_->attr_defs; -} - -std::vector &AscIrDef::MutableAttrDefs() const { - return impl_->attr_defs; -} - -void AscIrDef::SetComment(const string &comment) const { - impl_->comment = comment; -} - -const std::string &AscIrDef::GetComment() const { - return impl_->comment; -} -const std::string &AscIrDef::GetFilePath() const { - return impl_->file_path; -} - -int64_t AscIrDef::GetLine() const { - return impl_->line; -} - -IRDataTypeSymbolStore &AscIrDef::MutableDataTypeSymbolStore() const { - return impl_->dtype_symbol_store; -} - -const IRDataTypeSymbolStore &AscIrDef::GetDataTypeSymbolStore() const { - return impl_->dtype_symbol_store; -} - -const std::map &AscIrDef::GetSocToDataTypeSymbolStore() const { - return impl_->soc_2_dtype_sym_store_; -} - -void AscIrDef::AddSocImpl(const std::vector &soc_versions, const AscIrImpl &impl) const { - for (const auto &soc : soc_versions) { - impl_->soc_2_impl_[soc] = impl; - auto &dtype_sym_store = impl_->soc_2_dtype_sym_store_[soc]; - // reg_sym - for (const auto &input_def : impl_->input_defs) { - dtype_sym_store.SetInputSymbol(input_def.first, input_def.second, impl_->input_name_to_sym_name[input_def.first]); - } - for (const auto &output_def : impl_->output_defs) { - dtype_sym_store.SetOutputSymbol(output_def.first, output_def.second, - impl_->output_name_to_sym_name[output_def.first]); - } - // bind symbol to dtypes - for (const auto &iter : impl.support_dtypes) { - (void) dtype_sym_store.DeclareSymbol(iter.first, iter.second); - } - } -} - -void AscIrDef::AddSocImplV2(const std::vector &soc_versions, const AscIrImplV2 &impl) const { - for (const auto &soc : soc_versions) { - impl_->soc_2_impl_v2_[soc] = impl; - auto &dtype_sym_store = impl_->soc_2_dtype_sym_store_[soc]; - // reg_sym - for (const auto &input_def : impl_->input_defs) { - dtype_sym_store.SetInputSymbol(input_def.first, input_def.second, impl_->input_name_to_sym_name[input_def.first]); - } - for (const auto &output_def : impl_->output_defs) { - dtype_sym_store.SetOutputSymbol(output_def.first, output_def.second, - impl_->output_name_to_sym_name[output_def.first]); - } - // bind symbol to dtypes - for (const auto &iter : impl.support_dtypes) { - (void) dtype_sym_store.DeclareSymbol(iter.first, iter.second); - } - } -} - -void AscIrDef::AppendSocImpl(const AscIrDef &ir_def) const { - impl_->soc_2_impl_.insert(ir_def.impl_->soc_2_impl_.begin(), ir_def.impl_->soc_2_impl_.end()); - impl_->soc_2_impl_v2_.insert(ir_def.impl_->soc_2_impl_v2_.begin(), ir_def.impl_->soc_2_impl_v2_.end()); - impl_->soc_2_dtype_sym_store_.insert(ir_def.impl_->soc_2_dtype_sym_store_.begin(), - ir_def.impl_->soc_2_dtype_sym_store_.end()); -} - -size_t AscIrDef::GetSocImplSize() const { - return impl_->soc_2_impl_.size() + impl_->soc_2_impl_v2_.size(); -} - -std::unique_ptr AscIrDef::GetAscIrAttImpl(const std::string &soc_version) { - return impl_->GetAscIrAttImpl(soc_version); -} -std::unique_ptr AscIrDef::GetAscIrCodegenImpl(const std::string &soc_version) { - return impl_->GetAscIrCodegenImpl(soc_version); -} - -AscirRegistry &AscirRegistry::GetInstance() { - static AscirRegistry registry; - return registry; -} -void AscirRegistry::RegisterAscIr(const std::string &type, const AscIrDef &def) { - auto iter = types_to_ascir_.find(type); - if (iter == types_to_ascir_.end()) { - types_to_ascir_[type] = def; - } else { - iter->second.AppendSocImpl(def); - } -} -const std::unordered_map &AscirRegistry::GetAll() const { - return types_to_ascir_; -} - -std::unique_ptr AscirRegistry::GetIrAttImpl(const std::string &soc_version, const std::string &type) { - auto iter = types_to_ascir_.find(type); - return (iter == types_to_ascir_.end()) ? nullptr : types_to_ascir_[type].GetAscIrAttImpl(soc_version); -} -std::unique_ptr AscirRegistry::GetIrCodegenImpl(const std::string &soc_version, const std::string &type) { - auto iter = types_to_ascir_.find(type); - return (iter == types_to_ascir_.end()) ? nullptr : types_to_ascir_[type].GetAscIrCodegenImpl(soc_version); -} - -void AscirRegistry::ClearAll() { - types_to_ascir_.clear(); -}; - -} // namespace ascir -} // namespace ge diff --git a/graph/ascendc_ir/generator/external_generator.cmake b/graph/ascendc_ir/generator/external_generator.cmake deleted file mode 100644 index faa557d1713fc8635c0d5642d7bb503690b1d927..0000000000000000000000000000000000000000 --- a/graph/ascendc_ir/generator/external_generator.cmake +++ /dev/null @@ -1,7 +0,0 @@ -function(ascir_generate depend_so_target out_dir so_var h_var) - add_custom_command( - OUTPUT ${h_var} - DEPENDS ${depend_so_target} - COMMAND ${out_dir}/ascir_ops_header_generator ${so_var} ${h_var} - ) -endfunction() diff --git a/graph/ascendc_ir/generator/generator.cc b/graph/ascendc_ir/generator/generator.cc deleted file mode 100644 index b41b41d876f073845aca340d09c75530e63837fa..0000000000000000000000000000000000000000 --- a/graph/ascendc_ir/generator/generator.cc +++ /dev/null @@ -1,1853 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ -#include -#include -#include -#include -#include -#include -#include - -#include "graph/debug/ge_op_types.h" -#include "graph/debug/ge_util.h" -#include "graph/ascendc_ir/ascir_registry.h" -#include "graph/utils/type_utils.h" -namespace ge { -namespace ascir { -namespace { -const char *GetPureFileName(const char *path) { - const char *name = std::strrchr(path, '/'); - if (name == nullptr) { - name = path; - } else { - ++name; - } - return name; -} -std::string CapitalizeFirstLetter(const std::string &input) { - if (input.empty()) { - return input; - } - - std::string result = input; - if (std::islower(result[0])) { - result[0] = std::toupper(result[0]); - } - return result; -} - -void GenIrAttrMemberFuncs(const std::vector &attr_defs, std::stringstream &ss) { - if (attr_defs.empty()) { - return; - } - // 对每个属性生成对应的Set, Get函数 - for (const auto &attr_def : attr_defs) { - ss << " graphStatus Get" << CapitalizeFirstLetter(attr_def.name) << "(" << attr_def.asc_ir_type << "&" - << " " << attr_def.name << ") const {" << std::endl; - ss << " auto attr_value = attr_store_.GetAnyValue(\"" << attr_def.name << "\");" << std::endl; - ss << " GE_WARN_ASSERT(attr_value != nullptr);" << std::endl; - ss << " return attr_value->GetValue(" << attr_def.name << ");" << std::endl << " }" << std::endl; - - ss << " graphStatus Set" << CapitalizeFirstLetter(attr_def.name) << "(" << attr_def.asc_ir_type << " " - << attr_def.name << ") {" << std::endl; - ss << " auto attr_value = attr_store_.GetOrCreateAnyValue(\"" << attr_def.name << "\");" << std::endl; - ss << " GE_ASSERT_NOTNULL(attr_value);" << std::endl; - ss << " return attr_value->SetValue(" << attr_def.name << ");" << std::endl << " }" << std::endl; - } -} - -std::string TryGenIrAttrClass(const AscIrDef &def, std::stringstream &ss) { - const auto &attr_defs = def.GetAttrDefs(); - if (attr_defs.empty()) { - return (""); - } - const std::string &ir_type = def.GetType(); - std::string derived_class_name = std::string("Asc").append(ir_type).append("IrAttrDef"); - // 暂时没啥好的办法,data的类需要先定义好,gen出来的话有点晚了 - if (ir_type == ge::DATA) { - ss << " using " << derived_class_name << " = ge::" << derived_class_name << ";" << std::endl; - ss << " " << derived_class_name << " &ir_attr;" << std::endl; - return derived_class_name; - } - // 生成子类的定义 - ss << " struct " << derived_class_name << ": public AscIrAttrDefBase {" << std::endl; - ss << " ~" << derived_class_name << "() override = default;" << std::endl; - GenIrAttrMemberFuncs(attr_defs, ss); - ss << " };" << std::endl; - // 添加引用成员到上一级类 - ss << " " << derived_class_name << " &ir_attr;" << std::endl; - return derived_class_name; -}; -void GenIrInputAndOutputDef(const AscIrDef &def, std::stringstream &ss) { - const auto &input_defs = def.GetInputDefs(); - for (const auto &input_def : input_defs) { - if (input_def.second == ge::IrInputType::kIrInputDynamic) { - ss << " this->DynamicInputRegister(\"" << input_def.first << "\", 0U, true);" << std::endl; - } else if (input_def.second == ge::IrInputType::kIrInputOptional) { - ss << " this->OptionalInputRegister(\"" << input_def.first << "\");" << std::endl; - } else { - ss << " this->InputRegister(\"" << input_def.first << "\");" << std::endl; - } - } - - const auto &output_defs = def.GetOutputDefs(); - for (const auto &output_def : output_defs) { - if (output_def.second == ge::IrOutputType::kIrOutputDynamic) { - ss << " this->DynamicOutputRegister(\"" << output_def.first << "\", 0U, true);" << std::endl; - } else { - ss << " this->OutputRegister(\"" << output_def.first << "\");" << std::endl; - } - } -} -// 初始化列表中初始化的成员依赖构造函数中的ir信息确定之后再次进行赋值 -void GenOutTensorInitDef(const AscIrDef &def, std::stringstream &ss) { - const auto &output_defs = def.GetOutputDefs(); - for (const auto &output_def : output_defs) { - if (output_def.second != ge::IrOutputType::kIrOutputDynamic) { - ss << " " << output_def.first << ".TryInitTensorAttr();\n"; - } - } -} -void GenConstructorDef(const AscIrDef &def, const std::string &attr_class, std::stringstream &ss, - const bool need_graph = false) { - const std::string &ir_type = def.GetType(); - if (need_graph) { - ss << " inline " << ir_type << "(const char *name, AscGraph &graph) : ge::Operator"; - } else { - ss << " inline " << ir_type << "(const char *name) : ge::Operator"; - } - if (attr_class.empty()) { - ss << "(name, Type), attr(" << "*AscNodeAttr::Create(*this))"; - } else { - ss << "(name, Type), attr(" << "*AscNodeAttr::Create<" << attr_class << ">(*this))"; - ss << ", ir_attr(dynamic_cast<" << attr_class << "&>(*(attr.ir_attr)))"; - } - const auto &input_defs = def.GetInputDefs(); - for (const auto &input_def : input_defs) { - ss << "," << std::endl << " " << input_def.first << "(this)"; - } - const auto &output_defs = def.GetOutputDefs(); - for (size_t i = 0UL; i < output_defs.size(); ++i) { - if (output_defs[i].second != ge::IrOutputType::kIrOutputDynamic) { - ss << "," << std::endl << " " << output_defs[i].first << "(this, " << i << ")"; - } - } - ss << " {" << std::endl; - GenIrInputAndOutputDef(def, ss); - GenOutTensorInitDef(def, ss); - if (need_graph) { - ss << " graph.AddNode(*this) ;" << std::endl << " }" << std::endl; - } else { - ss << " }" << std::endl << std::endl; - } -} - -std::string DataTypeToSerialString(const DataType data_type) { - auto res = TypeUtils::DataTypeToSerialString(data_type); - if (res == "DT_BFLOAT16") { // 历史原因,DT_BF16的string表达是DT_BFLOAT16,所以我们需要特殊处理一下 - return "DT_BF16"; - } - return res; -} - -std::string TensorTypeToCode(const TensorType &tensor_type) { - std::string s = "{"; - size_t index = 0U; - for (const auto &dtype : tensor_type.tensor_type_impl_->GetMutableDateTypeSet()) { - s += DataTypeToSerialString(dtype); - if (index++ < (tensor_type.tensor_type_impl_->GetMutableDateTypeSet().size() - 1U)) { - s += ", "; - } - } - s += "}"; - return s; -} - -// TODO 兼容新老接口 -const IRDataTypeSymbolStore &GetFirstDataTypeSymbolStore(const AscIrDef &ir_def) { - const auto &soc_to_sym_store = ir_def.GetSocToDataTypeSymbolStore(); - if (!soc_to_sym_store.empty()) { - return soc_to_sym_store.begin()->second; - } - return ir_def.GetDataTypeSymbolStore(); -} - -bool IsSocOrderedTensorListStore(const AscIrDef &ir_def) { - const auto &soc_to_sym_store = ir_def.GetSocToDataTypeSymbolStore(); - if (!soc_to_sym_store.empty()) { - return soc_to_sym_store.begin()->second.IsSupportOrderedSymbolicInferDtype(); - } - return false; -} - -class SymbolProcessor { - public: - explicit SymbolProcessor(const AscIrDef &def) : def_(def) {} - Status ProcessSymbol(const std::pair &sym, std::stringstream &ss) { - // 外部保证非空,不允许注册一个不带dtype的sym - GE_ASSERT_TRUE(!(sym.second->GetTensorType().tensor_type_impl_->GetMutableDateTypeSet().empty())); - // 只有输出持有的sym不在这里处理 - if (sym.second->GetIrInputIndexes().empty()) { - return SUCCESS; - } - - GenerateInputDtypeUniquenessCheck(sym.second, ss); - GenerateTypeDefinition(sym, ss); - GenerateDtypeValidation(sym, ss); - return SUCCESS; - } - - Status GenSocVersionCallFunc(std::stringstream &ss) { - for (const auto &iter : def_.GetSocToDataTypeSymbolStore()) { - for (const auto &sym : iter.second.GetNamedSymbols()) - name_to_soc_to_sym_dtype_[sym.first].emplace(iter.first, iter.second.GetNamedSymbols()[sym.first]); - } - if (!name_to_soc_to_sym_dtype_.empty()) { - // get soc version. - ss << " char soc_version[128] = {};" << std::endl; - ss << " auto res = rtGetSocVersion(soc_version, 128U);" << std::endl; - ss << R"( GE_ASSERT_TRUE(res == RT_ERROR_NONE, "Failed to get soc version str.");)" << std::endl; - ss << " auto soc_str = std::string(soc_version);" << std::endl; - } - return SUCCESS; - } - - Status ProcessSymbolWithNoCheck(const std::pair &sym, std::stringstream &ss) { - // 外部保证非空,不允许注册一个不带dtype的sym - GE_ASSERT_TRUE(!(sym.second->GetTensorType().tensor_type_impl_->GetMutableDateTypeSet().empty())); - // 只有输出持有的sym不在这里处理 - if (sym.second->GetIrInputIndexes().empty()) { - return SUCCESS; - } - GenerateInputDtypeUniquenessCheck(sym.second, ss); - return SUCCESS; - } - - protected: - void GenerateTypeDefinition(const std::pair &sym, std::stringstream &ss) { - // TODO 老代码删除后优化for循环 - const std::string tensor_type_obj = "support_dtypes_of_sym_" + sym.first; - auto iter = name_to_soc_to_sym_dtype_.find(sym.first); - if (iter != name_to_soc_to_sym_dtype_.end()) { - ss << " std::set " << tensor_type_obj << ";" << std::endl; - bool is_first = true; - for (const auto &soc_to_sym : iter->second) { - if (is_first) { - ss << " if (soc_str == \"" << soc_to_sym.first << "\") {\n"; - is_first = false; - } else { - ss << " } else if (soc_str == \"" << soc_to_sym.first << "\") {\n"; - } - ss << " " << tensor_type_obj << " = " << TensorTypeToCode(soc_to_sym.second->GetTensorType()) << ";\n"; - } - ss << " } else {\n"; - ss << R"( GELOGE(ge::FAILED, "Failed to get soc version, res:%s", soc_str.c_str());)" << std::endl; - ss << " return ge::FAILED;\n"; - ss << " }\n"; - return; - } - - ss << " const static std::set " << tensor_type_obj << " = " - << TensorTypeToCode(sym.second->GetTensorType()) << ";\n"; - } - - static void GenerateDtypeValidation(const std::pair &sym, std::stringstream &ss) { - // 共sym的已经都校验过一致了,所以这里可以用第一个来校验是否在支持范围内 - auto check_index = sym.second->GetIrInputIndexes().front(); - ss << " GE_WARN_ASSERT(" - << "support_dtypes_of_sym_" << sym.first << ".find(input_dtypes[" << check_index - << "]) != support_dtypes_of_sym_" << sym.first << ".end());" << std::endl; - } - - void GenerateInputDtypeUniquenessCheck(const SymDtype *sym, std::stringstream &ss) { - auto input_ir_indexes = sym->GetIrInputIndexes(); - // 小于2个,不需要比较 - if (input_ir_indexes.size() < 2U) { - return; - } - for (size_t index = 0U; index < input_ir_indexes.size() - 1U; ++index) { - ss << " GE_WARN_ASSERT(" << input_arg_name_ << "[" << input_ir_indexes[index] << "] == " << input_arg_name_ - << "[" << input_ir_indexes[index + 1] << "]);" << std::endl; - } - } - const AscIrDef &def_; - std::string input_arg_name_{"input_dtypes"}; - std::map> name_to_soc_to_sym_dtype_; -}; - -class OutputHandler { - public: - explicit OutputHandler(const AscIrDef &def) : def_(def) {} - - void GenerateOutputInference(std::stringstream &ss, int space_count = 6) { - bool could_infer = true; - std::stringstream warning_code; - size_t out_index = 0U; - for (const auto out_sym : GetFirstDataTypeSymbolStore(def_).GetOutSymbols()) { - sym_2_ir_indexs_[out_sym].push_back(out_index); - if (!(out_sym->GetIrInputIndexes().empty())) { - // 有输入对应,一定是可以推导的 - ++out_index; - continue; - } - (void) syms_only_of_output_.insert(out_sym); - auto support_types = out_sym->GetTensorType().tensor_type_impl_->GetMutableDateTypeSet(); - if (support_types.size() > 1U) { - could_infer = false; - warning_code << std::string(space_count, ' ') << "GELOGW(\"Output ir_index [" << out_index - << "] has multi result " << TensorTypeToCode(out_sym->GetTensorType()) << ", can not infer.\");\n"; - } - ++out_index; - } - if (!could_infer) { - ss << warning_code.str(); - ss << std::string(space_count, ' ') << "return FAILED;\n"; - return; - } - for (const auto out_sym : GetFirstDataTypeSymbolStore(def_).GetOutSymbols()) { - if (syms_only_of_output_.find(out_sym) == syms_only_of_output_.end()) { - // 走到这里说明out使用了input的sym,因为前面校验过共sym的输入dtypes是一样的,所以我们用此sym的第一个输入ir的dtype作为输出的dtype" - ss << std::string(space_count, ' ') << "expect_output_dtypes.push_back(input_dtypes[" - << out_sym->GetIrInputIndexes().front() << "]);" << std::endl; - } else { - GenerateCustomTypeInference(out_sym, ss, space_count); - } - } - ss << std::string(space_count, ' ') << "return SUCCESS;" << std::endl; - } - - void GenerateOutputValidation(std::stringstream &ss) { - size_t out_index = 0; - for (const auto out_sym : GetFirstDataTypeSymbolStore(def_).GetOutSymbols()) { - if (syms_only_of_output_.find(out_sym) == syms_only_of_output_.end()) { - // 走到这里说明out使用了input的sym,因为前面校验过共sym的输入dtypes是一样的,所以我们用此sym的第一个输入ir的dtype跟out的dtype比对 - ss << " GE_WARN_ASSERT(input_dtypes[" << out_sym->GetIrInputIndexes().front() << "] == " - << "expect_output_dtypes[" << out_index << "]);" << std::endl; - } else { - GenerateCustomTypeValidation(out_sym, ss); - } - ++out_index; - } - } - - protected: - static void GenerateCustomTypeInference(const SymDtype *sym, std::stringstream &ss, int space_count) { - // 走到这里说明out使用了跟所有input不一样的sym, 并且此输出只有一个支持的类型 - auto support_types = sym->GetTensorType().tensor_type_impl_->GetMutableDateTypeSet(); - ss << std::string(space_count, ' ') << "expect_output_dtypes.push_back(" - << DataTypeToSerialString(*support_types.begin()) << ");\n"; - } - - Status GenerateCustomTypeValidation(SymDtype *sym, std::stringstream &ss) { - if (!syms_checked_.insert(sym).second) { - return SUCCESS; - } - auto iter = sym_2_ir_indexs_.find(sym); - GE_ASSERT_TRUE(iter != sym_2_ir_indexs_.end()); - auto indexes_of_this_sym = iter->second; - // 小于2个,不需要比较 - if (indexes_of_this_sym.size() >= 2U) { - for (size_t index = 0U; index < indexes_of_this_sym.size() - 1U; ++index) { - ss << " GE_WARN_ASSERT(expect_output_dtypes[" << indexes_of_this_sym[index] << "] == expect_output_dtypes[" - << indexes_of_this_sym[index + 1] << "]);" << std::endl; - } - } - const std::string tensor_type_obj = "support_dtypes_of_sym_" + sym->Id(); - ss << " static std::set " << tensor_type_obj << " = " << TensorTypeToCode(sym->GetTensorType()) - << ";\n"; - // 共sym的已经都校验过一致了,所以这里可以用第一个来校验是否在支持范围内 - auto check_index = indexes_of_this_sym.front(); - ss << " GE_WARN_ASSERT(" - << "support_dtypes_of_sym_" << sym->Id() << ".find(expect_output_dtypes[" << check_index - << "]) != support_dtypes_of_sym_" << sym->Id() << ".end());" << std::endl; - return SUCCESS; - } - - private: - const AscIrDef &def_; - std::set syms_only_of_output_{}; - std::map> sym_2_ir_indexs_{}; - std::set syms_checked_{}; // 多个输出可能sym一样 -}; - -class OrderedSymbolProcessor : public SymbolProcessor { - public: - explicit OrderedSymbolProcessor(const AscIrDef &def) : SymbolProcessor(def), valid_dtype_nums_of_sym_(0U) {} - - Status PreProcessSymbol(std::stringstream &ss) { - const auto &symbols = def_.GetDataTypeSymbolStore().GetSymbols(); - GE_ASSERT_TRUE(!symbols.empty()); - GE_ASSERT_SUCCESS(InitializeSymbolAttributes(symbols)); - GE_ASSERT_SUCCESS(ClassifySymbols(symbols, ss)); - return SUCCESS; - } - - Status InstanceSymbol(std::stringstream &ss) { - BuildResultMapping(); - ss << GenerateResultContainer(); - return SUCCESS; - } - - void CheckSymbol(std::stringstream &ss) { - ss << "\n"; - if (input_syms_.size() > 1U) { - ss << " auto iter = results.find(std::vector{"; - size_t index{0U}; - for (auto input_sym : input_syms_) { - ss << "input_dtypes[" << input_sym->GetIrInputIndexes().front() << "]"; - if (index++ < input_syms_.size() - 1U) { - ss << ", "; - } - } - ss << "});\n"; - } else { - ss << " auto iter = results.find(input_dtypes[" << input_syms_.front()->GetIrInputIndexes().front() << "]);\n"; - } - ss << " GE_WARN_ASSERT(iter != results.end());\n"; - } - - Status HandleOutput(std::stringstream &ss) { - ss << " // 输出外部不指定的时候,生成推导的代码" << std::endl; - ss << " if (expect_output_dtypes.empty()) {" << std::endl; - GenerateOutputInference(ss); - ss << " }" << std::endl; - ss << " // 输出外部指定,生成校验的代码" << std::endl; - GenerateOutputValidation(ss); - return SUCCESS; - } - - private: - void GenerateOutputInference(std::stringstream &ss) { - size_t only_output_index{0U}; - for (const auto out_sym : def_.GetDataTypeSymbolStore().GetOutSymbols()) { - if (std::find(only_out_syms_.begin(), only_out_syms_.end(), out_sym) == only_out_syms_.end()) { - // 走到这里说明out使用了input的sym,因为前面校验过共sym的输入dtypes是一样的,所以我们用此sym的第一个输入ir的dtype作为输出的dtype" - ss << " expect_output_dtypes.push_back(input_dtypes[" << out_sym->GetIrInputIndexes().front() << "]);" - << std::endl; - } else { - // 走到这里说明out使用了跟所有input不一样的sym, 我们使用输入的type推导输出 - if (container_meta_.output_count == 1U) { - // std::set - if (container_meta_.has_multiple_solutions) { - ss << " GE_WARN_ASSERT(iter->second.size() == 1U);" << std::endl; - ss << " expect_output_dtypes.push_back(*(iter->second.begin()));" << std::endl; - // ge::DataType - } else { - ss << " expect_output_dtypes.push_back(iter->second);" << std::endl; - } - } else { - // std::vector> - if (container_meta_.has_multiple_solutions) { - ss << " GE_WARN_ASSERT(iter->second[" << only_output_index << "].size() == 1U);" << std::endl; - ss << " expect_output_dtypes.push_back(*(iter->second[" << only_output_index << "].begin()));" - << std::endl; - // std::vector> - } else { - ss << " expect_output_dtypes.push_back(iter->second[" << only_output_index << "]));" << std::endl; - } - only_output_index++; - } - } - } - ss << " return SUCCESS;" << std::endl; - } - - void GenerateOutputValidation(std::stringstream &ss) { - size_t only_output_index{0U}; - size_t output_index{0U}; - for (const auto out_sym : def_.GetDataTypeSymbolStore().GetOutSymbols()) { - if (std::find(only_out_syms_.begin(), only_out_syms_.end(), out_sym) == only_out_syms_.end()) { - // 走到这里说明out使用了input的sym,因为前面校验过共sym的输入dtypes是一样的,所以我们用此sym的第一个输入ir的dtype跟输出的dtype做校验" - ss << " GE_WARN_ASSERT(input_dtypes[" << out_sym->GetIrInputIndexes().front() << "] == " - << "expect_output_dtypes[" << output_index << "]);" << std::endl; - } else { - // 走到这里说明out使用了跟所有input不一样的sym, 我们使用输入的type推导输出 - if (container_meta_.output_count == 1U) { - // std::set - if (container_meta_.has_multiple_solutions) { - ss << " GE_WARN_ASSERT(iter->second.find(expect_output_dtypes[" << output_index - << "]) != iter->second.end());" << std::endl; - // ge::DataType - } else { - ss << " GE_WARN_ASSERT(iter->second == " - << "expect_output_dtypes[" << output_index << "]);" << std::endl; - } - } else { - // std::vector> - if (container_meta_.has_multiple_solutions) { - ss << " GE_WARN_ASSERT(iter->second[" << only_output_index << "].find(expect_output_dtypes[" - << output_index << "]) != iter->second[" << only_output_index << "].end());" << std::endl; - // std::vector> - } else { - ss << " GE_WARN_ASSERT(iter->second[" << only_output_index << "] == " - << "expect_output_dtypes[" << output_index << "]);" << std::endl; - } - only_output_index++; - } - } - output_index++; - } - } - Status InitializeSymbolAttributes(const std::list> &symbols) { - const auto &first_sym = symbols.front(); - GE_ASSERT_NOTNULL(first_sym); - const auto &type_list = first_sym->GetOrderedTensorTypeList(); - valid_dtype_nums_of_sym_ = type_list.GetOrderedDtypes().size(); - return SUCCESS; - } - - Status ClassifySymbols(const std::list> &symbols, std::stringstream &ss) { - for (const auto &sym : symbols) { - GE_ASSERT_NOTNULL(sym); - GE_ASSERT_TRUE(sym->IsOrderedList()); - GE_ASSERT_SUCCESS(ValidateDtypeConsistency(sym)); - if (IsOutputOnlySymbol(sym)) { - only_out_syms_.emplace_back(sym.get()); - } else { - ProcessInputSymbol(sym.get(), ss); - } - } - return SUCCESS; - } - - Status ValidateDtypeConsistency(const std::shared_ptr &sym) const { - const auto ¤t_types = sym->GetOrderedTensorTypeList().GetOrderedDtypes(); - GE_ASSERT_TRUE(current_types.size() == valid_dtype_nums_of_sym_); - return SUCCESS; - } - - static bool IsOutputOnlySymbol(const std::shared_ptr &sym) { - return sym->GetIrInputIndexes().empty(); - } - - void ProcessInputSymbol(SymDtype *sym, std::stringstream &ss) { - GenerateInputDtypeUniquenessCheck(sym, ss); - input_syms_.emplace_back(sym); - } - - // 结果生成相关 - void BuildResultMapping() { - results2_.clear(); - for (size_t idx = 0U; idx < valid_dtype_nums_of_sym_; ++idx) { - auto inputs = CollectInputDtypes(idx); - auto outputs = CollectOutputDtypes(idx); - results2_.emplace(std::move(inputs), std::move(outputs)); - } - } - - std::vector CollectInputDtypes(size_t index) { - std::vector inputs; - inputs.reserve(input_syms_.size()); - - for (auto sym : input_syms_) { - inputs.push_back(GetDtypeByIndex(sym, index)); - } - return inputs; - } - - std::vector CollectOutputDtypes(size_t index) { - std::vector outputs; - outputs.reserve(only_out_syms_.size()); - - for (auto sym : only_out_syms_) { - outputs.push_back(GetDtypeByIndex(sym, index)); - } - return outputs; - } - - static ge::DataType GetDtypeByIndex(const SymDtype *sym, size_t index) { - const auto &types = sym->GetOrderedTensorTypeList().GetOrderedDtypes(); - GE_ASSERT_TRUE(index < types.size()); - return types[index]; - } - - // 容器生成相关 - std::string GenerateResultContainer() { - const auto solution_map = BuildSolutionMap(); - const bool has_multiple = CheckMultipleSolutions(solution_map); - - container_meta_ = {.input_count = input_syms_.size(), - .output_count = only_out_syms_.size(), - .has_multiple_solutions = has_multiple}; - - return BuildContainerString(solution_map, container_meta_); - } - - using SolutionMap = std::map, std::set>>; - - SolutionMap BuildSolutionMap() { - SolutionMap mapping; - for (const auto &pair : results2_) { - mapping[pair.first].insert(pair.second); - } - return mapping; - } - - static bool CheckMultipleSolutions(const SolutionMap &mapping) { - return std::any_of(mapping.begin(), mapping.end(), - [](const std::pair, std::set>> &pair) { - return pair.second.size() > 1; - }); - } - - struct ContainerMeta { - size_t input_count; - size_t output_count; - bool has_multiple_solutions; - }; - - std::string BuildContainerString(const SolutionMap &mapping, const ContainerMeta &meta) { - std::ostringstream oss; - container_type_ = GetContainerType(meta); - oss << " const static " << container_type_ << " results = {\n"; - AppendContainerEntries(oss, mapping, meta); - oss << "\n };"; - - return oss.str(); - } - static std::string GetContainerType(const ContainerMeta &meta) { - std::ostringstream oss; - oss << "std::map<" << (meta.input_count > 1 ? "std::vector" : "ge::DataType") << ", "; - - if (meta.output_count > 1) { - oss << (meta.has_multiple_solutions ? "std::vector>" : "std::vector"); - } else { - oss << (meta.has_multiple_solutions ? "std::set" : "ge::DataType"); - } - - oss << ">"; - return oss.str(); - } - - static void AppendContainerEntries(std::ostream &os, const SolutionMap &mapping, const ContainerMeta &meta) { - std::vector entries; - entries.reserve(mapping.size()); - - for (const auto &inputs_2_outputs : mapping) { - entries.push_back(BuildEntryString(inputs_2_outputs.first, inputs_2_outputs.second, meta)); - } - os << " " << JoinEntries(entries, ",\n "); - } - - static std::string BuildEntryString(const std::vector &input, - const std::set> &outputs, const ContainerMeta &meta) { - return "{" + SerializeVector(input) + ", " + SerializeOutputs(outputs, meta) + "}"; - } - - static std::string SerializeVector(const std::vector &vec) { - if (vec.size() == 1U) { - return DataTypeToSerialString(vec[0U]); - } - - std::ostringstream oss; - oss << "{"; - for (size_t i = 0U; i < vec.size(); ++i) { - oss << DataTypeToSerialString(vec[i]); - if (i < vec.size() - 1U) { - oss << ", "; - } - } - oss << "}"; - return oss.str(); - } - - static std::string SerializeOutputs(const std::set> &outputs, const ContainerMeta &meta) { - if (meta.output_count == 1U) { - return SerializeSingleOutput(outputs, meta.has_multiple_solutions); - } - return SerializeMultiOutputs(outputs, meta); - } - - static std::string SerializeSingleOutput(const std::set> &outputs, bool multiple) { - if (!multiple) { - return DataTypeToSerialString(outputs.begin()->front()); - } - - std::vector unique_outputs; - for (const auto &out : outputs) { - unique_outputs.push_back(out.front()); - } - return SerializeSet(unique_outputs); - } - - static std::string SerializeMultiOutputs(const std::set> &outputs, - const ContainerMeta &meta) { - if (!meta.has_multiple_solutions) { - return SerializeVector(*outputs.begin()); - } - - std::vector> output_sets; - for (size_t i = 0; i < meta.output_count; ++i) { - output_sets.emplace_back(); - for (const auto &out : outputs) { - output_sets.back().insert(out[i]); - } - } - return SerializeSetVector(output_sets); - } - - static std::string SerializeSet(const std::vector &types) { - std::ostringstream oss; - oss << "{"; - for (size_t i = 0U; i < types.size(); ++i) { - oss << DataTypeToSerialString(types[i]); - if (i < types.size() - 1U) { - oss << ", "; - } - } - oss << "}"; - return oss.str(); - } - - static std::string SerializeSetVector(const std::vector> &sets) { - std::ostringstream oss; - oss << "{"; - for (size_t i = 0U; i < sets.size(); ++i) { - oss << SerializeSet(std::vector{sets[i].begin(), sets[i].end()}); - if (i < sets.size() - 1U) { - oss << ", "; - } - } - oss << "}"; - return oss.str(); - } - - static std::string JoinEntries(const std::vector &entries, const std::string &delimiter) { - std::ostringstream oss; - for (size_t i = 0U; i < entries.size(); ++i) { - oss << entries[i]; - if (i < entries.size() - 1U) { - oss << delimiter; - } - } - return oss.str(); - } - - std::vector input_syms_{}; - std::vector only_out_syms_{}; - // 每一个对key,value代表一个合法解, key是多个输入的实参dtype,value是多个输出的推导处理的dtype,之所以是multimap因为某个输出可能 - // 有多个解 - std::multimap, std::vector> results2_{}; - size_t valid_dtype_nums_of_sym_; - // 根据输入输出的个数和输出是否有唯一解需要实例化不同类型的容器 - std::string container_type_{}; - ContainerMeta container_meta_{}; -}; - -class SocOrderedSymbolProcessor : public SymbolProcessor { - public: - explicit SocOrderedSymbolProcessor(const AscIrDef &def) : SymbolProcessor(def) {} - - Status PreProcessSymbol(std::stringstream &ss) { - GE_ASSERT_SUCCESS(InitializeSymbolAttributes()); - const auto &store = GetFirstDataTypeSymbolStore(def_); - GE_ASSERT_TRUE(!store.GetSymbols().empty()); - GE_ASSERT_SUCCESS(ClassifySymbols(store, ss)); - return SUCCESS; - } - - Status InstanceSymbol(std::stringstream &ss) { - BuildResultMapping(); - ss << GenerateResultContainer(); - return SUCCESS; - } - - void CheckSymbol(std::stringstream &ss) { - ss << "\n"; - if (input_sym_id_to_ir_idx_.size() > 1U) { - ss << " auto iter = results.find(std::vector{"; - size_t index{0U}; - for (const auto &input_sym : input_sym_id_to_ir_idx_) { - ss << "input_dtypes[" << input_sym.second << "]"; - if (index++ < input_sym_id_to_ir_idx_.size() - 1U) { - ss << ", "; - } - } - ss << "});\n"; - } else { - ss << " auto iter = results.find(input_dtypes[" << input_sym_id_to_ir_idx_.begin()->second << "]);\n"; - } - ss << " GE_WARN_ASSERT(iter != results.end());\n"; - } - - Status HandleOutput(std::stringstream &ss) { - ss << " // 输出外部不指定的时候,生成推导的代码" << std::endl; - ss << " if (expect_output_dtypes.empty()) {" << std::endl; - GenerateOutputInference(ss); - ss << " }" << std::endl; - ss << " // 输出外部指定,生成校验的代码" << std::endl; - GenerateOutputValidation(ss); - return SUCCESS; - } - - private: - struct ContainerMeta { - size_t input_count; - size_t output_count; - }; - - void GenerateOutputInference(std::stringstream &ss) { - size_t only_output_index{0U}; - for (const auto &out_sym : GetFirstDataTypeSymbolStore(def_).GetOutSymbols()) { - auto iter = input_sym_id_to_ir_idx_.find(out_sym->Id()); - if (iter != input_sym_id_to_ir_idx_.end()) { - // 走到这里说明out使用了input的sym,因为前面校验过共sym的输入dtypes是一样的,所以我们用此sym的第一个输入ir的dtype作为输出的dtype" - ss << " expect_output_dtypes.push_back(input_dtypes[" << iter->second << "]);" << std::endl; - } else { - // 走到这里说明out使用了跟所有input不一样的sym, 我们使用输入的type推导输出 - if (container_meta_.output_count == 1U) { - // std::set - ss << " GE_WARN_ASSERT(iter->second.size() == 1U);" << std::endl; - ss << " expect_output_dtypes.push_back(*(iter->second.begin()));" << std::endl; - } else { - // std::vector> - ss << " GE_WARN_ASSERT(iter->second[" << only_output_index << "].size() == 1U);" << std::endl; - ss << " expect_output_dtypes.push_back(*(iter->second[" << only_output_index << "].begin()));" - << std::endl; - only_output_index++; - } - } - } - ss << " return ge::SUCCESS;" << std::endl; - } - - void GenerateOutputValidation(std::stringstream &ss) { - size_t only_output_index{0U}; - size_t output_index{0U}; - for (const auto &out_sym : GetFirstDataTypeSymbolStore(def_).GetOutSymbols()) { - auto iter = input_sym_id_to_ir_idx_.find(out_sym->Id()); - if (iter != input_sym_id_to_ir_idx_.end()) { - // 走到这里说明out使用了input的sym,因为前面校验过共sym的输入dtypes是一样的,所以我们用此sym的第一个输入ir的dtype跟输出的dtype做校验" - ss << " GE_WARN_ASSERT(input_dtypes[" << iter->second << "] == " - << "expect_output_dtypes[" << output_index << "]);" << std::endl; - } else { - // 走到这里说明out使用了跟所有input不一样的sym, 我们使用输入的type推导输出 - if (container_meta_.output_count == 1U) { - // std::set - ss << " GE_WARN_ASSERT(iter->second.find(expect_output_dtypes[" << output_index - << "]) != iter->second.end());" << std::endl; - } else { - // std::vector> - ss << " GE_WARN_ASSERT(iter->second[" << only_output_index << "].find(expect_output_dtypes[" - << output_index << "]) != iter->second[" << only_output_index << "].end());" << std::endl; - only_output_index++; - } - } - output_index++; - } - } - - Status InitializeSymbolAttributes() { - for (const auto &soc_store : def_.GetSocToDataTypeSymbolStore()) { - bool is_first{true}; - for (const auto &name_to_sym : soc_store.second.GetNamedSymbols()) { - GE_ASSERT_TRUE(name_to_sym.second->IsOrderedList()); - const auto ¤t_types = name_to_sym.second->GetOrderedTensorTypeList().GetOrderedDtypes(); - if (is_first) { - soc_name_to_meta_[soc_store.first].valid_dtype_nums_of_sym = current_types.size(); - is_first = false; - } else { - GE_ASSERT_TRUE(current_types.size() == soc_name_to_meta_[soc_store.first].valid_dtype_nums_of_sym); - } - } - } - return SUCCESS; - } - - Status ClassifySymbols(const IRDataTypeSymbolStore &store, std::stringstream &ss) { - for (const auto &sym : store.GetSymbols()) { - GE_ASSERT_NOTNULL(sym); - if (!sym->GetIrInputIndexes().empty()) { - GenerateInputDtypeUniquenessCheck(sym.get(), ss); - input_sym_id_to_ir_idx_.emplace(sym->Id(), sym->GetIrInputIndexes().front()); - } else { - only_output_sym_ids_.push_back(sym->Id()); - } - } - - container_meta_ = {.input_count = input_sym_id_to_ir_idx_.size(), .output_count = only_output_sym_ids_.size()}; - return SUCCESS; - } - - // 结果生成相关 - Status BuildResultMapping() { - for (const auto &iter : def_.GetSocToDataTypeSymbolStore()) { - const auto &soc_name = iter.first; - auto &soc_meta = soc_name_to_meta_[soc_name]; - for (size_t idx = 0U; idx < soc_meta.valid_dtype_nums_of_sym; ++idx) { - std::vector inputs; - std::vector outputs; - GE_ASSERT_GRAPH_SUCCESS(CollectInputDtypes(iter.second.GetNamedSymbols(), idx, inputs)); - GE_ASSERT_GRAPH_SUCCESS(CollectOutputDtypes(iter.second.GetNamedSymbols(), idx, outputs)); - soc_meta.results2.emplace(std::move(inputs), std::move(outputs)); - } - } - return ge::GRAPH_SUCCESS; - } - - Status CollectInputDtypes(const std::map &name_to_syms, size_t index, - std::vector &inputs) { - inputs.reserve(input_sym_id_to_ir_idx_.size()); - for (const auto &sym_id : input_sym_id_to_ir_idx_) { - auto iter = name_to_syms.find(sym_id.first); - GE_ASSERT_TRUE((iter != name_to_syms.end()) && (iter->second != nullptr)); - inputs.push_back(GetDtypeByIndex(iter->second, index)); - } - return ge::SUCCESS; - } - - Status CollectOutputDtypes(const std::map &name_to_syms, size_t index, - std::vector &outputs) { - outputs.reserve(only_output_sym_ids_.size()); - for (const auto &sym_id : only_output_sym_ids_) { - auto iter = name_to_syms.find(sym_id); - GE_ASSERT_TRUE((iter != name_to_syms.end()) && (iter->second != nullptr)); - outputs.push_back(GetDtypeByIndex(iter->second, index)); - } - return ge::SUCCESS; - } - - static ge::DataType GetDtypeByIndex(const SymDtype *sym, size_t index) { - const auto &types = sym->GetOrderedTensorTypeList().GetOrderedDtypes(); - GE_ASSERT_TRUE(index < types.size()); - return types[index]; - } - - // 容器生成相关 - std::string GenerateResultContainer() { - std::map soc_to_solution_map; - BuildSolutionMap(soc_to_solution_map); - ContainerMeta container_meta = {.input_count = input_sym_id_to_ir_idx_.size(), - .output_count = only_output_sym_ids_.size()}; - - return BuildContainerString(soc_to_solution_map, container_meta); - } - - using SolutionMap = std::map, std::set>>; - - void BuildSolutionMap(std::map &soc_to_solution_map) { - for (auto &iter : soc_name_to_meta_) { - SolutionMap mapping; - for (const auto &pair : iter.second.results2) { - mapping[pair.first].insert(pair.second); - } - soc_to_solution_map.emplace(iter.first, std::move(mapping)); - } - } - - std::string BuildContainerString(const std::map &soc_to_solution_map, - const ContainerMeta &meta) { - std::ostringstream oss; - // get soc version. - oss << " char soc_version[128] = {};" << std::endl; - oss << " auto res = rtGetSocVersion(soc_version, 128U);" << std::endl; - oss << R"( GE_ASSERT_TRUE(res == RT_ERROR_NONE, "Failed to get soc version str.");)" << std::endl; - oss << " auto soc_str = std::string(soc_version);" << std::endl; - - container_type_ = GetContainerType(meta); - oss << " " << container_type_ << " results;" << std::endl; - bool is_first = true; - for (const auto &iter : soc_to_solution_map) { - if (is_first) { - oss << " if (soc_str == \"" << iter.first << "\") {\n"; - is_first = false; - } else { - oss << " } else if (soc_str == \"" << iter.first << "\") {\n"; - } - oss << GenContainerEntries(iter.second, meta) << std::endl; - } - oss << " } else {\n"; - oss << R"( GELOGE(ge::FAILED, "Failed to get soc version, res:%s", soc_str.c_str());)" << std::endl; - oss << " return ge::FAILED;\n"; - oss << " }\n"; - return oss.str(); - } - - static std::string GetContainerType(const ContainerMeta &meta) { - std::ostringstream oss; - oss << "std::map<" << (meta.input_count > 1 ? "std::vector" : "ge::DataType") << ", "; - - if (meta.output_count > 1) { - oss << "std::vector>"; - } else { - oss << "std::set"; - } - - oss << ">"; - return oss.str(); - } - - static std::string GenContainerEntries(const SolutionMap &mapping, const ContainerMeta &meta) { - std::stringstream ss; - ss << " results = {\n"; - - std::vector entries; - entries.reserve(mapping.size()); - for (const auto &inputs_2_outputs : mapping) { - entries.push_back(BuildEntryString(inputs_2_outputs.first, inputs_2_outputs.second, meta)); - } - ss << " " << JoinEntries(entries, ",\n "); - ss << "\n };"; - return ss.str(); - } - - static std::string BuildEntryString(const std::vector &input, - const std::set> &outputs, const ContainerMeta &meta) { - return "{" + SerializeVector(input) + ", " + SerializeOutputs(outputs, meta) + "}"; - } - - static std::string SerializeVector(const std::vector &vec) { - if (vec.size() == 1U) { - return DataTypeToSerialString(vec[0U]); - } - - std::ostringstream oss; - oss << "{"; - for (size_t i = 0U; i < vec.size(); ++i) { - oss << DataTypeToSerialString(vec[i]); - if (i < vec.size() - 1U) { - oss << ", "; - } - } - oss << "}"; - return oss.str(); - } - - static std::string SerializeOutputs(const std::set> &outputs, const ContainerMeta &meta) { - if (meta.output_count == 1U) { - return SerializeSingleOutput(outputs); - } - return SerializeMultiOutputs(outputs, meta); - } - - static std::string SerializeSingleOutput(const std::set> &outputs) { - std::vector unique_outputs; - for (const auto &out : outputs) { - unique_outputs.push_back(out.front()); - } - return SerializeSet(unique_outputs); - } - - static std::string SerializeMultiOutputs(const std::set> &outputs, - const ContainerMeta &meta) { - std::vector> output_sets; - for (size_t i = 0; i < meta.output_count; ++i) { - output_sets.emplace_back(); - for (const auto &out : outputs) { - output_sets.back().insert(out[i]); - } - } - return SerializeSetVector(output_sets); - } - - static std::string SerializeSet(const std::vector &types) { - std::ostringstream oss; - oss << "{"; - for (size_t i = 0U; i < types.size(); ++i) { - oss << DataTypeToSerialString(types[i]); - if (i < types.size() - 1U) { - oss << ", "; - } - } - oss << "}"; - return oss.str(); - } - - static std::string SerializeSetVector(const std::vector> &sets) { - std::ostringstream oss; - oss << "{"; - for (size_t i = 0U; i < sets.size(); ++i) { - oss << SerializeSet(std::vector{sets[i].begin(), sets[i].end()}); - if (i < sets.size() - 1U) { - oss << ", "; - } - } - oss << "}"; - return oss.str(); - } - - static std::string JoinEntries(const std::vector &entries, const std::string &delimiter) { - std::ostringstream oss; - for (size_t i = 0U; i < entries.size(); ++i) { - oss << entries[i]; - if (i < entries.size() - 1U) { - oss << delimiter; - } - } - return oss.str(); - } - - struct SocMeta { - // 每一个对key,value代表一个合法解, key是多个输入的实参dtype,value是多个输出的推导处理的dtype,之所以是multimap因为某个输出可能 - // 有多个解 - std::multimap, std::vector> results2{}; - size_t valid_dtype_nums_of_sym{0UL}; - }; - - std::map soc_name_to_meta_; - // 根据输入输出的个数和输出是否有唯一解需要实例化不同类型的容器 - std::string container_type_{}; - std::map input_sym_id_to_ir_idx_; - std::vector only_output_sym_ids_; - ContainerMeta container_meta_{}; -}; - -class InferDtypeCodeGenerator { - public: - explicit InferDtypeCodeGenerator(const AscIrDef &def) : def_(def) { - is_soc_ordered_dtype_infer_ = IsSocOrderedTensorListStore(def_); - is_ordered_dtype_infer_ = def_.GetDataTypeSymbolStore().IsSupportOrderedSymbolicInferDtype(); - } - Status Generate(std::stringstream &ss) { - GenerateFunctionSignature(ss); - GenerateArgsSizeAssertion(ss); - GE_ASSERT_SUCCESS(GenerateSymbolProcessing(ss)); - GenerateOutputHandling(ss); - GenerateReturnStatement(ss); - return SUCCESS; - } - - private: - static void GenerateFunctionSignature(std::stringstream &ss) { - ss << R"( inline static Status InferDataType(const std::vector& input_dtypes, - std::vector& expect_output_dtypes) {)" - << std::endl; - } - - void GenerateArgsSizeAssertion(std::stringstream &ss) { - ss << " // 校验入参容器的元素个数是否合法" << std::endl; - ss << " GE_ASSERT_EQ(input_dtypes.size(), " << def_.GetInputDefs().size() << "U);" << std::endl; - ss << " GE_ASSERT_TRUE(expect_output_dtypes.empty() || expect_output_dtypes.size() == " - << def_.GetOutputDefs().size() << "U);" << std::endl; - ss << std::endl; - } - - Status GenerateSymbolProcessing(std::stringstream &ss) { - if (is_soc_ordered_dtype_infer_) { - SocOrderedSymbolProcessor ordered_symbol_processor(def_); - GE_ASSERT_SUCCESS(ordered_symbol_processor.PreProcessSymbol(ss)); - GE_ASSERT_SUCCESS(ordered_symbol_processor.InstanceSymbol(ss)); - ordered_symbol_processor.CheckSymbol(ss); - // output跟sym处理紧密相关,所以放在这里处理 - GE_ASSERT_SUCCESS(ordered_symbol_processor.HandleOutput(ss)); - ss << std::endl; - return SUCCESS; - } - if (is_ordered_dtype_infer_) { - OrderedSymbolProcessor ordered_symbol_processor(def_); - GE_ASSERT_SUCCESS(ordered_symbol_processor.PreProcessSymbol(ss)); - GE_ASSERT_SUCCESS(ordered_symbol_processor.InstanceSymbol(ss)); - ordered_symbol_processor.CheckSymbol(ss); - // output跟sym处理紧密相关,所以放在这里处理 - GE_ASSERT_SUCCESS(ordered_symbol_processor.HandleOutput(ss)); - ss << std::endl; - return SUCCESS; - } - ss << " // 校验同sym的输入的dtype是否在注册范围内并且一致" << std::endl; - SymbolProcessor symbol_processor(def_); - symbol_processor.GenSocVersionCallFunc(ss); - for (const auto &sym : GetFirstDataTypeSymbolStore(def_).GetNamedSymbols()) { - symbol_processor.ProcessSymbol(sym, ss); - } - ss << std::endl; - return SUCCESS; - } - - void GenerateOutputHandling(std::stringstream &ss) { - if (is_ordered_dtype_infer_ || is_soc_ordered_dtype_infer_) { - return; - } - OutputHandler handler(def_); - ss << " // 输出外部不指定的时候,生成推导的代码" << std::endl; - ss << " if (expect_output_dtypes.empty()) {" << std::endl; - handler.GenerateOutputInference(ss); - ss << " }" << std::endl; - ss << " // 输出外部指定,生成校验的代码" << std::endl; - handler.GenerateOutputValidation(ss); - } - - static void GenerateReturnStatement(std::stringstream &ss) { - ss << " return SUCCESS;" << std::endl; - ss << " };" << std::endl; - }; - const AscIrDef &def_; - bool is_ordered_dtype_infer_{false}; - bool is_soc_ordered_dtype_infer_{false}; -}; - -class InferDtypeWithNoCheckCodeGenerator { - public: - explicit InferDtypeWithNoCheckCodeGenerator(const AscIrDef &def) : def_(def) { - is_soc_ordered_dtype_infer_ = IsSocOrderedTensorListStore(def_); - is_ordered_dtype_infer_ = def_.GetDataTypeSymbolStore().IsSupportOrderedSymbolicInferDtype(); - } - Status Generate(std::stringstream &ss) const { - GenerateFunctionSignature(ss); - if (is_ordered_dtype_infer_ || is_soc_ordered_dtype_infer_) { - ss << " // 输入输出存在关联, 无法进行推导" << std::endl; - ss << " GELOGW(\"Node type %s is not supported to infernocheck for dtype.\", Type);" << std::endl; - ss << " return ge::FAILED;" << std::endl; - ss << " };" << std::endl << std::endl; - return SUCCESS; - } - GenerateArgsSizeAssertion(ss); - GE_ASSERT_SUCCESS(GenerateSymbolProcessing(ss)); - GenerateOutputHandling(ss); - GenerateReturnStatement(ss); - return SUCCESS; - } - - private: - static void GenerateFunctionSignature(std::stringstream &ss) { - ss << R"( inline static Status InferDataTypeWithNoCheck(const std::vector& input_dtypes, - std::vector& expect_output_dtypes) {)" - << std::endl; - } - - void GenerateArgsSizeAssertion(std::stringstream &ss) const { - ss << " // 校验入参容器的元素个数是否合法" << std::endl; - ss << " GE_ASSERT_EQ(input_dtypes.size(), " << def_.GetInputDefs().size() << "U);" << std::endl; - ss << " GE_ASSERT_TRUE(expect_output_dtypes.empty());" << std::endl; - - ss << std::endl; - } - - Status GenerateSymbolProcessing(std::stringstream &ss) const { - ss << " // 校验同sym的输入的dtype是否一致" << std::endl; - SymbolProcessor symbol_processor(def_); - for (const auto &sym : GetFirstDataTypeSymbolStore(def_).GetNamedSymbols()) { - symbol_processor.ProcessSymbolWithNoCheck(sym, ss); - } - ss << std::endl; - return SUCCESS; - } - - void GenerateOutputHandling(std::stringstream &ss) const { - constexpr int space_count = 4; - OutputHandler handler(def_); - - handler.GenerateOutputInference(ss, space_count); - } - - static void GenerateReturnStatement(std::stringstream &ss) { - ss << " };" << std::endl << std::endl; - }; - const AscIrDef &def_; - bool is_ordered_dtype_infer_{false}; - bool is_soc_ordered_dtype_infer_{false}; -}; - -Status GenInferDtypeFuncDef(const AscIrDef &def, std::stringstream &ss) { - InferDtypeCodeGenerator generator(def); - return generator.Generate(ss); -} - -Status GenInferDtypeWithNoCheckFuncDef(const AscIrDef &def, std::stringstream &ss) { - InferDtypeWithNoCheckCodeGenerator generator(def); - return generator.Generate(ss); -} - -void GenCopyConstructor(const AscIrDef &def, std::stringstream &ss) { - const std::string &ir_type = def.GetType(); - ss << " inline " << ir_type << "& operator=(const " << ir_type << "&) = delete;" << std::endl; - ss << " inline " << ir_type << "(" << ir_type << " &&) = delete;" << std::endl; - ss << " inline " << ir_type << "(const " << ir_type << " &other)"; - ss << " : ge::Operator(other),"; - ss << " attr(other.attr)"; - if (!def.GetAttrDefs().empty()) { - ss << ", ir_attr(dynamic_cast(*(attr.ir_attr)))"; - } - const auto &input_defs = def.GetInputDefs(); - for (const auto &input_def : input_defs) { - ss << "," << std::endl << " " << input_def.first << "(this)"; - } - const auto &output_defs = def.GetOutputDefs(); - for (size_t i = 0UL; i < output_defs.size(); ++i) { - if (output_defs[i].second != ge::IrOutputType::kIrOutputDynamic) { - ss << "," << std::endl << " " << output_defs[i].first << "(this, " << i << ")"; - } - } - ss << " {" << std::endl; - for (const auto &output_def : output_defs) { - if (output_def.second != ge::IrOutputType::kIrOutputDynamic) { - ss << " " << output_def.first << ".TryInitTensorAttr();" << std::endl; - } - } - - if ((output_defs.size() == 1U) && (output_defs[0].second == ge::IrOutputType::kIrOutputDynamic)) { - ss << " InstanceOutputy(other.y.size());" << std::endl; - } - - ss << " }" << std::endl; -} - -void GenInstanceOutputy(const AscIrDef &def, std::stringstream &ss) { - const auto &output_defs = def.GetOutputDefs(); - if (output_defs.size() != 1U) { - return; - } - if (output_defs[0].second != ge::IrOutputType::kIrOutputDynamic) { - return; - } - - ss << " void InstanceOutputy(uint32_t num) {" << std::endl; - ss << " this->DynamicOutputRegister(\"y\", num);" << std::endl; - ss << " for (size_t i = 0U; i < num; i++) {" << std::endl; - ss << " this->y.push_back(AscOpOutput(this, i));" << std::endl; - ss << " }" << std::endl; - ss << " }" << std::endl; - - return; -} - -void GenAscIr(const AscIrDef &def, std::stringstream &ss) { - const std::string &ir_type = def.GetType(); - ss << "namespace ascir_op {" << std::endl; - ss << "struct " << ir_type << " : public ge::Operator {" << std::endl; - - const auto &out_defs = def.GetOutputDefs(); - for (const auto &output_def : out_defs) { - if (output_def.second == ge::IrOutputType::kIrOutputDynamic) { - ss << " using Operator::DynamicOutputRegister;" << std::endl; - break; - } - } - - ss << " static constexpr const char *Type = \"" << ir_type << "\";" << std::endl; - ss << " AscNodeAttr &attr;" << std::endl; - const auto &ir_attr_class_name = TryGenIrAttrClass(def, ss); - // generate input output definitions - const auto &input_defs = def.GetInputDefs(); - for (size_t i = 0UL; i < input_defs.size(); ++i) { - const auto &input_def = input_defs[i]; - if (input_def.second == ge::IrInputType::kIrInputDynamic) { - ss << " AscOpDynamicInput<" << i << "> " << input_def.first << ";" << std::endl; - } else { - ss << " AscOpInput<" << i << "> " << input_def.first << ";" << std::endl; - } - } - - const auto &output_defs = def.GetOutputDefs(); - for (const auto &output_def : output_defs) { - if (output_def.second == ge::IrOutputType::kIrOutputDynamic) { - ss << " std::vector " << output_def.first << ";" << std::endl; - } else { - ss << " AscOpOutput " << output_def.first << ";" << std::endl; - } - } - - // generate constructor func definitions - if (def.IsStartNode()) { - GenConstructorDef(def, ir_attr_class_name, ss, true); - } - GenConstructorDef(def, ir_attr_class_name, ss); - (void) GenInferDtypeFuncDef(def, ss); - (void) GenInferDtypeWithNoCheckFuncDef(def, ss); - GenCopyConstructor(def, ss); - GenInstanceOutputy(def, ss); - ss << "};" << std::endl; - ss << "}" << std::endl; -} - -void GenIrComment(const AscIrDef &def, std::stringstream &ss) { - const auto &comment = def.GetComment(); - if (!comment.empty()) { - ss << "/* \n"; - ss << comment << "\n"; - ss << "*/ \n"; - } -} - -class FunctionGenerator { - public: - explicit FunctionGenerator(const AscIrDef &def) : def_(def) {} - virtual ~FunctionGenerator() = default; - - virtual void Gen(std::stringstream &ss, const bool has_optional_input) const { - GenDefinition(ss, has_optional_input); - - GenInstantiation(ss); - ss << std::endl; - - if (GenConnectInputs(ss, has_optional_input)) { - ss << std::endl; - } - - if (GenAttrAssignment(ss)) { - ss << std::endl; - } - - GenSchedInfo(ss); - ss << std::endl; - - if (GenOutputsAssignment(ss)) { - ss << std::endl; - } - GenOutputMemInfo(ss); - GenPaddingAxis(ss); - - // 计算向量化轴,向量化轴的计算顺序:输出View在当前API的Loop轴内侧(不包括Loop轴)的所有轴 - TryGenOutputsVectorizedAxis(ss); - - ss << std::endl; - - GenReturn(ss); - } - - virtual void GenDefinition(std::stringstream &ss, const bool has_optional_input) const; - virtual void GenInstantiation(std::stringstream &ss) const; - virtual bool GenConnectInputs(std::stringstream &ss, const bool has_optional_input) const; - virtual bool GenAttrAssignment(std::stringstream &ss) const; - virtual void GenSchedInfo(std::stringstream &ss) const { - ss << " op.attr.sched.exec_order = CodeGenUtils::GenNextExecId(op);" << std::endl; - ss << " SET_SCHED_AXIS_IF_IN_CONTEXT(op);" << std::endl; - } - virtual void TryGenOutputsVectorizedAxis(std::stringstream &ss) const { - const auto &output_defs = def_.GetOutputDefs(); - for (const auto &name : output_defs) { - ss << " *op." << name.first << ".vectorized_axis = AxisUtils::GetDefaultVectorizedAxis(*op." << name.first - << ".axis, op.attr.sched.loop_axis);" << std::endl; - } - } - virtual void GenOutputMemInfo(std::stringstream &ss) const { - const auto &output_defs = def_.GetOutputDefs(); - for (const auto &name : output_defs) { - ss << " op." << name.first << ".mem->tensor_id = " - << "CodeGenUtils::GenNextTensorId(op);" << std::endl; - } - } - virtual bool GenOutputsAssignment(std::stringstream &ss) const { - bool generated = false; - - // generate infer data type code - if (def_.infer_data_type_generator != nullptr) { - generated = true; - def_.infer_data_type_generator(def_, ss); - } - if (def_.infer_view_generator != nullptr) { - generated = true; - def_.infer_view_generator(def_, ss); - } - return generated; - } - virtual void GenPaddingAxis(std::stringstream &ss) const { - const auto &output_defs = def_.GetOutputDefs(); - for (const auto &name : output_defs) { - ss << " THROW(PadOutputViewToSched(op." << name.first << "));" << std::endl; - } - } - virtual void GenReturn(std::stringstream &ss) const { - const auto &output_defs = def_.GetOutputDefs(); - if (output_defs.empty()) { - ss << " return op;" << std::endl; - } else if (output_defs.size() == 1U) { - ss << " return op." << output_defs[0U].first << ";" << std::endl; - } else { - ss << " return std::make_tuple("; - for (size_t i = 0; i < output_defs.size(); ++i) { - if (i == 0) { - ss << "op." << output_defs[i].first; - } else { - ss << " ,op." << output_defs[i].first; - } - } - ss << ");" << std::endl; - } - ss << "}" << std::endl; - } - - protected: - const AscIrDef &def_; - - private: - static bool NeedConnectByInputArgs(const bool has_optional_input, - const std::pair &input_def) { - return ((has_optional_input || (input_def.second != ge::IrInputType::kIrInputOptional)) && - (input_def.second != ge::IrInputType::kIrInputDynamic)); - } -}; - -void ascir::FunctionGenerator::GenDefinition(std::stringstream &ss, const bool has_optional_input) const { - const std::vector> *input_defs; - std::vector> empty_input_defs; - - if (def_.IsStartNode()) { - // TODO 由于历史原因,IsStartNode()(例如Data)仍然带有输入定义,但是这种输入实际是不连边的。 - // 但是为了最小化修改,当前先不修改Data的定义,后续需要做调整,对与StartNode类型,不定义输入, - // 或者认为没有输入的op就是start node,在定义IR时不需要再显式指定start node标记 - input_defs = &empty_input_defs; - } else { - input_defs = &def_.GetInputDefs(); - } - auto append_output_types = [&ss](size_t count) { - for (size_t i = 0; i < count; ++i) { - if (i != 0) { - ss << ", "; - } - ss << "AscOpOutput"; - } - }; - const auto &output_defs = def_.GetOutputDefs(); - ss << "inline "; - if (output_defs.size() > 1U) { - ss << "std::tuple<"; - append_output_types(output_defs.size()); - ss << "> " << def_.GetType() << "(const char* name"; - } else { - ss << "AscOpOutput " << def_.GetType() << "(const char* name"; - } - if (!input_defs->empty()) { - for (const auto &input_def : *input_defs) { - if (NeedConnectByInputArgs(has_optional_input, input_def)) { - ss << ", const ge::AscOpOutput &" << input_def.first << "_in"; - } - } - } else { - ss << ", ge::AscGraph &graph"; - } - const auto &attr_defs = def_.GetAttrDefs(); - for (const auto &attr_def : attr_defs) { - ss << ", const " << attr_def.asc_ir_type << " &" << attr_def.name; - } - ss << ") {" << std::endl; -} -void ascir::FunctionGenerator::GenInstantiation(std::stringstream &ss) const { - if (def_.IsStartNode()) { - ss << " const auto &op_ptr = std::make_shared(name, graph);" << std::endl; - } else { - ss << " const auto &op_ptr = std::make_shared(name);" << std::endl; - } - ss << " auto &op = *op_ptr;" << std::endl; - ss << " const auto &desc = ge::OpDescUtils::GetOpDescFromOperator(op);" << std::endl; - ss << " desc->SetExtAttr(RELATED_OP, op_ptr);" << std::endl; -} -bool ascir::FunctionGenerator::GenConnectInputs(std::stringstream &ss, const bool has_optional_input) const { - // TODO 这里与GenFunctionDefinition同理,后续删除 - if (def_.IsStartNode()) { - return false; - } - const auto &input_defs = def_.GetInputDefs(); - if (!input_defs.empty()) { - for (const auto &input_def : input_defs) { - if (NeedConnectByInputArgs(has_optional_input, input_def)) { - ss << " op." << input_def.first << " = " << input_def.first << "_in;" << std::endl; - } - } - } - return !input_defs.empty(); -} -bool ascir::FunctionGenerator::GenAttrAssignment(std::stringstream &ss) const { - const auto &attr_defs = def_.GetAttrDefs(); - if (!attr_defs.empty()) { - for (const auto &attr_def : attr_defs) { - // 函数命名大驼峰,所以把属性名第一个字符转换成大写字母 - ss << " op.ir_attr.Set" << CapitalizeFirstLetter(attr_def.name) << "(" << attr_def.name << ");" << std::endl; - } - } - return !attr_defs.empty(); -} - -class StartNodeFuncGenerator : public FunctionGenerator { - public: - explicit StartNodeFuncGenerator(const AscIrDef &def) : FunctionGenerator(def) {} - void Gen(std::stringstream &ss, const bool has_optional_input) const override { - if (!def_.IsStartNode() || def_.GetOutputDefs().size() != 1UL) { - return; - } - FunctionGenerator::Gen(ss, has_optional_input); - } - void GenDefinition(std::stringstream &ss, const bool has_optional_input) const override { - (void) has_optional_input; - // inline ascir::ops::OpType OpType - ss << "inline " - << "AscOpOutput " << ' ' << def_.GetType() << "(const char *name, ge::AscGraph &graph, ge::DataType dt" - << ", const std::vector &axis_ids" - << ", const std::vector &repeats" - << ", const std::vector &strides"; - - const auto &attr_defs = def_.GetAttrDefs(); - for (const auto &attr_def : attr_defs) { - ss << ", const " << attr_def.asc_ir_type << " &" << attr_def.name; - } - ss << ") {" << std::endl; - } - bool GenOutputsAssignment(std::stringstream &ss) const override { - const auto &output_defs = def_.GetOutputDefs(); - const auto &output_name = output_defs[0].first; - ss << " op." << output_name << ".dtype = dt;" << std::endl; - ss << " *op." << output_name << ".axis = axis_ids;" << std::endl; - ss << " *op." << output_name << ".repeats = repeats;" << std::endl; - ss << " *op." << output_name << ".strides = strides;" << std::endl; - return true; - } -}; - -class StoreNodeFuncGenerator : public FunctionGenerator { - public: - explicit StoreNodeFuncGenerator(const AscIrDef &def) : FunctionGenerator(def) {} - void Gen(std::stringstream &ss, const bool has_optional_input) const override { - (void) has_optional_input; - if (def_.GetType() != "Store") { - return; - } - ss << "inline " - << "void" << ' ' << def_.GetType() << "(const char *name"; - ss << ", const ge::AscOpOutput &" << "ub_in"; - ss << ", ge::AscOpOutput &gm_output"; - const auto &attr_defs = def_.GetAttrDefs(); - for (const auto &attr_def : attr_defs) { - ss << ", const " << attr_def.asc_ir_type << " &" << attr_def.name; - } - ss << ") {" << std::endl; - ss << " auto store_out = Store(name, ub_in"; - for (const auto &attr_def : attr_defs) { - ss << ", " << attr_def.name; - } - ss << ");" << std::endl; - ss << " auto &gm_producer = const_cast(gm_output.GetOwnerOp());" << std::endl; - ss << " auto &store_op = const_cast(store_out.GetOwnerOp());" << std::endl; - ss << " gm_producer.SetInput(0U, store_op, 0U);" << std::endl; - ss << " AddEdgeForNode(store_op, 0U, gm_producer, 0U);" << std::endl; - ss << " auto *gm_producer_attr = CodeGenUtils::GetOwnerOpAscAttr(gm_producer);" << std::endl; - ss << " gm_producer_attr->sched.exec_order = CodeGenUtils::GenNextExecId(store_op);" << std::endl; - ss << "}" << std::endl; - } -}; - -class ContiguousStartNodeFuncGenerator : FunctionGenerator { - public: - explicit ContiguousStartNodeFuncGenerator(const AscIrDef &def) : FunctionGenerator(def) {} - void Gen(std::stringstream &ss, const bool has_optional_input) const override { - if (!def_.IsStartNode() || def_.GetOutputDefs().size() != 1UL) { - return; - } - FunctionGenerator::Gen(ss, has_optional_input); - } - void GenDefinition(std::stringstream &ss, const bool has_optional_input) const override { - (void) has_optional_input; - ss << "inline " - << "AscOpOutput" << " Contiguous" << def_.GetType() << "(const char *name, ge::AscGraph &graph, ge::DataType dt" - << ", const std::vector &axes"; - const auto &attr_defs = def_.GetAttrDefs(); - for (const auto &attr_def : attr_defs) { - ss << ", const " << attr_def.asc_ir_type << " &" << attr_def.name; - } - ss << ") {" << std::endl; - } - bool GenOutputsAssignment(std::stringstream &ss) const override { - const auto &output_name = def_.GetOutputDefs()[0].first; - ss << " op." << output_name << ".dtype = dt;" << std::endl; - ss << " op." << output_name << ".SetContiguousView(axes);" << std::endl; - return true; - } -}; - -void GetHeaderGuarderFromPath(const char *path, std::stringstream &ss) { - auto name = GetPureFileName(path); - - ss << "ASCIR_OPS_"; - - while (*name != '\0') { - auto c = toupper(*name++); - if (c < 'A' || c > 'Z') { - ss << '_'; - } else { - ss << static_cast(c); - } - } - - ss << '_'; -} -} // namespace - -void GenFunc(const AscIrDef &def, std::stringstream &ss) { - FunctionGenerator(def).Gen(ss, false); - StartNodeFuncGenerator(def).Gen(ss, false); - ContiguousStartNodeFuncGenerator(def).Gen(ss, false); - StoreNodeFuncGenerator(def).Gen(ss, false); - bool has_optional_input = false; - const auto &input_defs = def.GetInputDefs(); - for (const auto &input_def : input_defs) { - if (input_def.second == ge::IrInputType::kIrInputOptional) { - has_optional_input = true; - break; - } - } - if (has_optional_input) { - FunctionGenerator(def).Gen(ss, true); - StartNodeFuncGenerator(def).Gen(ss, true); - ContiguousStartNodeFuncGenerator(def).Gen(ss, true); - StoreNodeFuncGenerator(def).Gen(ss, true); - } -} - -void GenCalcBufFunc(std::stringstream &ss, - const std::map, AscIrDef> &ordered_keys_to_def) { - std::stringstream ss_calc_tmp_buff_map; - std::stringstream ss_calc_tmp_buff; - for (auto &key_and_def : ordered_keys_to_def) { - const auto &calc_func = key_and_def.second.GetCalcTmpBufSizeFunc(); - if (calc_func.func_name.empty()) { - continue; - } - if (calc_func.func_type == CalcTmpBufSizeFuncType::CustomizeType) { - ss_calc_tmp_buff << "extern std::vector> "; - ss_calc_tmp_buff << calc_func.func_name << "(const ge::AscNode &Node);" << std::endl; - } - ss_calc_tmp_buff_map << " {\"" << key_and_def.second.GetType() << "\", &"; - ss_calc_tmp_buff_map << calc_func.func_name << "}," << std::endl; - } - // 没有API注册时不生成CalcBuf函数 - if (ss_calc_tmp_buff_map.str().empty()) { - return; - } - ss << ss_calc_tmp_buff.str(); - ss << "inline std::vector> CalcAscNodeTmpSize(const ge::AscNode &node) {" - << std::endl; - ss << " typedef std::vector> (*calc_func_ptr) (const AscNode &node);" << std::endl; - ss << " static const std::unordered_map node_calc_tmp_buff_map = {" << std::endl; - ss << ss_calc_tmp_buff_map.str(); - ss << " };" << std::endl; - ss << " ge::AscNodeAttr attr = node.attr;" << std::endl; - ss << " if (node_calc_tmp_buff_map.find(attr.type) != node_calc_tmp_buff_map.end()) {" << std::endl; - ss << " return node_calc_tmp_buff_map.at(node.attr.type)(node);" << std::endl; - ss << " }" << std::endl; - ss << " return std::vector>();" << std::endl; - ss << "}" << std::endl; -} - -void GenCommonInferDtypeBaseFunc(std::stringstream &ss, - const std::map, AscIrDef> &ordered_keys_to_def, - const string &extra_str = "") { - std::stringstream func_table; - for (const auto &key_and_def : ordered_keys_to_def) { - const auto &node_type = key_and_def.second.GetType(); - func_table << " {\"" << node_type << "\", "; - func_table << "::ge::ascir_op::" << node_type << "::InferDataType" << extra_str << "}," << std::endl; - } - if (func_table.str().empty()) { - return; - } - ss << "inline ge::Status CommonInferDtype" << extra_str - << "(const std::string &type, const std::vector &input_dtypes,\n" - " std::vector &expect_output_dtypes) {" - << std::endl; - ss << " using func = ge::Status (*)(const std::vector &input_dtypes, \n" - " std::vector &expect_output_dtypes);" - << std::endl; - ss << " static const std::unordered_map func_table = {" << std::endl; - ss << func_table.str(); - ss << " };" << std::endl; - ss << " const auto &iter = func_table.find(type);" << std::endl; - ss << " if (iter != func_table.end()) {" << std::endl; - ss << " return iter->second(input_dtypes, expect_output_dtypes);" << std::endl; - ss << " }" << std::endl; - ss << " GELOGW(\"Node type %s is not supported to infer for now!\", type.c_str());" << std::endl; - ss << " return ge::FAILED;" << std::endl; - ss << "}" << std::endl; -} - -void GenCommonInferDtypeFunc(std::stringstream &ss, - const std::map, AscIrDef> &ordered_keys_to_def) { - GenCommonInferDtypeBaseFunc(ss, ordered_keys_to_def); -} - -void GenCommonInferDtypeWithNoCheckFunc( - std::stringstream &ss, const std::map, AscIrDef> &ordered_keys_to_def) { - GenCommonInferDtypeBaseFunc(ss, ordered_keys_to_def, "WithNoCheck"); -} - -void GenAll(std::stringstream &ss) { - std::stringstream ss_asc_ir; - std::stringstream ss_ge_ir; - - ss << R"(#include "ascendc_ir/utils/cg_calc_tmp_buff_common_funcs.h")" << std::endl << std::endl; - ss << R"(#include "ascendc_ir/ascendc_ir_core/ascendc_ir.h")" << std::endl << std::endl; - ss << R"(#include "ascendc_ir/ascend_reg_ops.h")" << std::endl << std::endl; - ss << R"(#include "utils/cg_utils.h")" << std::endl << std::endl; - ss << R"(#include "graph/type/tensor_type_impl.h")" << std::endl << std::endl; - ss << R"(#include "graph/type/sym_dtype.h")" << std::endl << std::endl; - ss << R"(#include "runtime/dev.h")" << std::endl << std::endl; - ss << "#include " << std::endl; - ss << "#include " << std::endl; - ss << "#include " << std::endl << std::endl; - - std::map, AscIrDef> ordered_keys_to_def; - for (const auto &type_and_def : AscirRegistry::GetInstance().GetAll()) { - ordered_keys_to_def[std::make_pair(type_and_def.second.GetFilePath(), type_and_def.second.GetLine())] = - type_and_def.second; - } - - for (const auto &key_and_def : ordered_keys_to_def) { - ss << "// Defined at " << GetPureFileName(key_and_def.second.GetFilePath().c_str()) << ':' - << key_and_def.second.GetLine() << std::endl; - GenIrComment(key_and_def.second, ss); - ss << "namespace ge {" << std::endl; - GenAscIr(key_and_def.second, ss); - ss << "}" << std::endl << std::endl; // namespace ascir - } - - ss << "namespace ge {" << std::endl; - ss << "namespace ascir {" << std::endl; - ss << "namespace cg {" << std::endl; - for (auto &key_and_def : ordered_keys_to_def) { - const auto &ascir_def = key_and_def.second; - // 增加判断, 如果是动态输出, continue; - if (ascir_def.HasDynamicOutput()) { - continue; - } - - GenFunc(key_and_def.second, ss); - // 如果有node属性配置,重载一个不设置属性的构造函数,把属性变成可选 - if (!ascir_def.GetAttrDefs().empty()) { - ascir_def.MutableAttrDefs().clear(); - GenFunc(ascir_def, ss); - } - } - - ss << "}" << std::endl; // namespace cg - GenCalcBufFunc(ss, ordered_keys_to_def); - GenCommonInferDtypeFunc(ss, ordered_keys_to_def); - GenCommonInferDtypeWithNoCheckFunc(ss, ordered_keys_to_def); - ss << "}" << std::endl; // namespace ascir - ss << "}" << std::endl << std::endl; // namespace ge - AscirRegistry::GetInstance().ClearAll(); -} - -void GenHeaderFileToStream(const char *path, std::stringstream &ss) { - std::stringstream ss_header_guarder; - GetHeaderGuarderFromPath(path, ss_header_guarder); - auto guarder = ss_header_guarder.str(); - - ss << "// Generated from asc-ir definition files, " - "any modification made to this file may be overwritten after compile." - << std::endl; - ss << "// If you want to add self-defined asc-ir, please create a seperated header file." << std::endl; - ss << "#ifndef " << guarder << std::endl; - ss << "#define " << guarder << std::endl << std::endl; - - GenAll(ss); - - ss << "#endif // " << guarder << std::endl; -} - -int GenHeaderFile(const char *path) { - std::stringstream ss; - GenHeaderFileToStream(path, ss); - AscirRegistry::GetInstance().ClearAll(); - std::ofstream fs(path); - if (!fs) { - return -1; - } - fs << ss.str(); - fs.close(); - return 0; -} -} // namespace ascir -} // namespace ge diff --git a/graph/ascendc_ir/generator/generator.cmake b/graph/ascendc_ir/generator/generator.cmake deleted file mode 100644 index 4bb9f570a58d1e4dda3786e4b406368c75165290..0000000000000000000000000000000000000000 --- a/graph/ascendc_ir/generator/generator.cmake +++ /dev/null @@ -1,69 +0,0 @@ -# 递归收集目标的所有动态库依赖路径 -function(get_all_dynamic_dirs target result_var) - set(dirs "") - # 获取目标的直接依赖库 - get_target_property(libs ${target} LINK_LIBRARIES) - foreach (lib IN LISTS libs) - # 仅处理 CMake 目标(排除系统库和绝对路径) - if (TARGET ${lib}) - get_target_property(type ${lib} TYPE) - # 处理动态库(SHARED_LIBRARY) - if (type STREQUAL "SHARED_LIBRARY") - # 获取动态库的输出目录 - get_target_property(lib_output_dir ${lib} LIBRARY_OUTPUT_DIRECTORY) - if (NOT lib_output_dir) - set(lib_output_dir $) - endif () - list(APPEND dirs ${lib_output_dir}) - # 递归处理依赖 - get_all_dynamic_dirs(${lib} child_dirs) - list(APPEND dirs ${child_dirs}) - endif () - endif () - endforeach () - list(REMOVE_DUPLICATES dirs) - set(${result_var} ${dirs} PARENT_SCOPE) -endfunction() - -function(ascir_generate depend_so_target bin_dir so_var h_var) - # 1. 收集所有动态库依赖的生成器表达式 - get_all_dynamic_dirs(ascir_ops_header_generator lib_dirs_genex) - - # 2. 添加自定义命令,ascend-toolkit的LD_PATH放在最后, 保证优先使用编译路径下的so - add_custom_command( - OUTPUT ${h_var} - DEPENDS ${depend_so_target} ascir_ops_header_generator - COMMAND ${CMAKE_COMMAND} -E echo "Raw Library Paths: $" - COMMAND bash -c " \ - IFS=: read -ra paths <<< '$'; \ - non_ascend=(); \ - ascend=(); \ - seen=(); \ - for p in \"\${paths[@]}\"; do \ - duplicate=0; \ - for s in \"\${seen[@]}\"; do \ - if [ \"\${p}\" = \"\${s}\" ]; then \ - duplicate=1; \ - break; \ - fi; \ - done; \ - if [ \$duplicate -eq 1 ]; then \ - continue; \ - fi; \ - seen+=(\"\${p}\"); \ - if [[ \"\${p}\" == */latest/* ]]; then \ - ascend+=(\"\${p}\"); \ - else \ - non_ascend+=(\"\${p}\"); \ - fi; \ - done; \ - final_paths=(\"\${non_ascend[@]}\" \"\${ascend[@]}\"); \ - lib_path_str=\$(IFS=:; echo \"\${final_paths[*]}\"); \ - echo \"Adjusted LD_LIBRARY_PATH: \${lib_path_str}:\$LD_LIBRARY_PATH\"; \ - export LD_LIBRARY_PATH=\"\${lib_path_str}:\$LD_LIBRARY_PATH\"; \ - '${bin_dir}/ascir_ops_header_generator' '${so_var}' '${h_var}' \ - " - VERBATIM - COMMENT "Generating header ${h_var} with fresh dependencies" - ) -endfunction() \ No newline at end of file diff --git a/graph/ascendc_ir/generator/generator.h b/graph/ascendc_ir/generator/generator.h deleted file mode 100644 index fc4deb1b9f33fa93414abb6ea8ece124e6f05076..0000000000000000000000000000000000000000 --- a/graph/ascendc_ir/generator/generator.h +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef AUTOFUSE_GENERATOR_H -#define AUTOFUSE_GENERATOR_H -namespace ge { -namespace ascir { -__attribute__((used)) int GenHeaderFile(const char *path); -} -} -#endif // AUTOFUSE_GENERATOR_H diff --git a/graph/ascendc_ir/utils/asc_graph_utils.cc b/graph/ascendc_ir/utils/asc_graph_utils.cc deleted file mode 100644 index cc31423b37d01045ac69f001ef62ee87f2ac0c13..0000000000000000000000000000000000000000 --- a/graph/ascendc_ir/utils/asc_graph_utils.cc +++ /dev/null @@ -1,384 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/ascendc_ir/utils/asc_graph_utils.h" -#include -#include "graph/utils/graph_utils.h" -#include "graph/utils/node_utils_ex.h" -#include "graph/ascendc_ir/core/ascendc_ir_impl.h" -#include "proto/ascendc_ir.pb.h" - -namespace ge { -namespace { -graphStatus EstablishAscNodeAndEdges(const ascendc_ir::proto::AscGraphDef &asc_graph_def, AscGraph &out_asc_graph) { - auto &asc_nodes = asc_graph_def.asc_node(); - // 1. Add AscNodes to AscGraph - for (const auto &asc_node : asc_nodes) { - auto &node_attr = asc_node.attr(); - auto &ir_attr = asc_node.ir_def(); - auto &op_name = node_attr.name(); - auto &op_type = ir_attr.type(); - OpDescBuilder op_builder(op_name, op_type); - const auto &inputs_nums = ir_attr.input_nums(); - const auto &inputs_names = ir_attr.input_names(); - const auto &input_types = ir_attr.input_ir_type(); - GE_ASSERT_TRUE(inputs_nums.size() == input_types.size(), - "[Build][Op] for %s failed, inputs_nums[%zu] " - "input_types[%zu] not equal.", - op_name.c_str(), inputs_nums.size(), input_types.size()); - GE_ASSERT_TRUE(inputs_nums.size() == inputs_names.size(), - "[Build][Op] for %s failed, inputs_nums[%zu] " - "inputs_names[%zu] not equal.", - op_name.c_str(), inputs_nums.size(), inputs_names.size()); - for (int32_t ir_id = 0; ir_id < inputs_nums.size(); ir_id++) { - if (input_types[ir_id] == IrInputType::kIrInputDynamic) { - op_builder.AddDynamicInput(ir_attr.input_names(ir_id), inputs_nums[ir_id]); - GELOGD("Add dynamic input[%s] of node[%s] success, input num[%ld].", inputs_names[ir_id].c_str(), - op_name.c_str(), inputs_nums[ir_id]); - } else { - op_builder.AddInput(inputs_names[ir_id]); - GELOGD("Add input[%s] of node[%s] success.", inputs_names[ir_id].c_str(), op_name.c_str()); - } - } - for (auto &ir_output_name : ir_attr.output_names()) { - op_builder.AddOutput(ir_output_name); - } - auto op_desc = op_builder.Build(); - GE_ASSERT_NOTNULL(op_desc, "[Build][Op] for %s failed.", op_name.c_str()); - auto op = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - out_asc_graph.AddNode(op); - GELOGD("Add node[%s:%s] inputs size[%d] output size[%d] success.", op_name.c_str(), op_type.c_str(), - ir_attr.input_names().size(), ir_attr.output_names().size()); - } - const auto &compute_graph = AscGraphUtils::GetComputeGraph(out_asc_graph); - // 2. Add edge between all AscNodes - for (const auto &asc_node : asc_nodes) { - auto &node_attr = asc_node.attr(); - auto &dst_node_name = node_attr.name(); - int32_t dst_in_index = 0; - for (const auto &input : asc_node.input_src()) { - const auto &src_node_name = input.src_node_name(); - const auto &src_out_index = input.src_out_index(); - if (src_node_name.empty()) { - GELOGW("[Get][SrcNodeName] failed of node [%s:%d]", node_attr.name().c_str(), dst_in_index); - continue; - } - const auto &src_node = compute_graph->FindNode(src_node_name); - GE_ASSERT_NOTNULL(src_node, "[Find][SrcNode] %s failed, dst_node[%s].", src_node_name.c_str(), - node_attr.name().c_str()); - const auto &dst_node = compute_graph->FindNode(dst_node_name); - GE_ASSERT_NOTNULL(dst_node, "[Find][DstNode] %s failed, dst_node[%s].", src_node_name.c_str(), - node_attr.name().c_str()); - GE_ASSERT_GRAPH_SUCCESS( - GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_out_index), dst_node->GetInDataAnchor(dst_in_index)), - "[Add][Edge] failed to add edge from node[%s:%d] to node[%s:%d] failed.Possible duplicate link.", - src_node_name.c_str(), src_out_index, dst_node_name.c_str(), dst_in_index); - GELOGD("[Add][Edge] from node[%s:%d] to node[%s:%d].", src_node_name.c_str(), src_out_index, - dst_node_name.c_str(), dst_in_index); - dst_in_index++; - } - } - GELOGD("Deserialize graph[%s] success, graph node size[%zu]", out_asc_graph.GetName().c_str(), - compute_graph->GetDirectNodesSize()); - return GRAPH_SUCCESS; -} -} -ComputeGraphPtr AscGraphUtils::GetComputeGraph(const AscGraph &asc_graph) { - return asc_graph.impl_->GetComputeGraph(); -} - -Status AscGraphUtils::FromComputeGraph(const ge::ComputeGraphPtr &compute_graph, ge::AscGraph &graph) { - GE_ASSERT_NOTNULL(compute_graph); - GE_ASSERT_NOTNULL(graph.impl_); - graph.impl_->compute_graph_ = compute_graph; - return ge::SUCCESS; -} - -graphStatus AscGraphUtils::SerializeToProto(const ge::AscGraph &asc_graph, - ascendc_ir::proto::AscGraphDef &asc_graph_def) { - const auto &ge_graph = AscGraphUtils::GetComputeGraph(asc_graph); - GE_ASSERT_NOTNULL(ge_graph); - asc_graph_def.set_graph_name(asc_graph.GetName()); - // serialize asc graph attr - auto asc_graph_attr_def = asc_graph_def.mutable_asc_graph_attr(); - GE_ASSERT_NOTNULL(asc_graph_attr_def); - auto asc_graph_attr = ge_graph->GetOrCreateAttrsGroup(); - GE_ASSERT_NOTNULL(asc_graph_attr, "[GetOrCreate][AscGraphAttr] failed, graph[%s].", asc_graph.GetName().c_str()); - asc_graph_attr->SerializeAttr(*asc_graph_attr_def); - // serialize asc nodes - for (const auto &node : asc_graph.GetAllNodes()) { - GE_ASSERT_NOTNULL(node, "[Get][Node] failed, graph[%s]", asc_graph.GetName().c_str()); - // serialize asc node def - auto node_def = asc_graph_def.add_asc_node(); - GE_ASSERT_NOTNULL(node_def, "[Add][AscNode] proto failed, graph[%s]", asc_graph.GetName().c_str()); - // serialize asc node attr - const auto attr_def = node_def->mutable_attr(); - GE_ASSERT_NOTNULL(attr_def, "[Get][Attr] failed, graph[%s]", asc_graph.GetName().c_str()); - GE_ASSERT_GRAPH_SUCCESS(AscNodeSerializeUtils::SerializeAttrGroupsDef(*node, *attr_def), - "[Serialize][Attr] failed, graph[%s]", asc_graph.GetName().c_str()); - // serialize asc node ir def - auto ir_def = node_def->mutable_ir_def(); - GE_ASSERT_NOTNULL(ir_def, "[Get][IrAttr] failed, graph[%s]", asc_graph.GetName().c_str()); - GE_ASSERT_GRAPH_SUCCESS(AscNodeSerializeUtils::SerializeIrDef(*node, *ir_def), - "[Serialize][IrAttr] failed, graph[%s]", asc_graph.GetName().c_str()); - GELOGD("[AscGraphUtils]Serialize ir attr node[%s:%s] success.", node->GetNamePtr(), ir_def->type().c_str()); - const auto &op_desc = node->GetOpDesc(); - GE_ASSERT_NOTNULL(op_desc); - // serialize asc in tensor - const auto &ir_index_map = OpDescUtils::GetInputIrIndexes2InstanceIndexesPairMap(op_desc); - size_t ir_id = 0UL; - size_t cur_input_id = 0UL; - for (const auto &ir_input : op_desc->GetIrInputs()) { - size_t start = cur_input_id; - size_t end = cur_input_id; - if (ir_input.second == IrInputType::kIrInputDynamic) { - const auto &range_iter = ir_index_map.find(ir_id); - if (range_iter != ir_index_map.cend()) { - start = range_iter->second.first; - end = range_iter->second.second; - } - } else { - end = start + 1; - } - GE_ASSERT_TRUE(start <= end); - GE_ASSERT_TRUE(end <= op_desc->GetAllInputsSize()); - int64_t input_nums = static_cast(end - start); - ir_def->add_input_nums(input_nums); - GELOGD("Add input nums[%d, end:%zu, start:%zu] for node[%s:%s, %d]", input_nums, end, start, - op_desc->GetNamePtr(), op_desc->GetTypePtr(), ir_id); - for (auto id = static_cast(start); id < static_cast(end); id++) { - const auto input_src = node_def->add_input_src(); - const auto &in_anchor = node->GetInDataAnchor(id); - GE_ASSERT_NOTNULL(in_anchor, "[Get][InDataAnchor] failed, graph[%s]", asc_graph.GetName().c_str()); - const auto &peer_out = in_anchor->GetPeerOutAnchor(); - if (peer_out == nullptr) { - GELOGW("[Get][PeerOut] failed of node [%s:%d]", node->GetNamePtr(), in_anchor->GetIdx()); - continue; - } - const auto &src_node = peer_out->GetOwnerNodeBarePtr(); - GE_ASSERT_NOTNULL(src_node, "[Get][SrcNode] failed, graph[%s]", asc_graph.GetName().c_str()); - input_src->set_src_node_name(src_node->GetName()); - input_src->set_src_out_index(peer_out->GetIdx()); - GELOGD("Set src node name[%s:%d] for node[%s:%s, %d]", src_node->GetName().c_str(), peer_out->GetIdx(), - op_desc->GetNamePtr(), op_desc->GetTypePtr(), id); - } - ir_id++; - cur_input_id = end; - } - // serialize asc out tensor - for (const auto &tensor : op_desc->GetAllOutputsDescPtr()) { - const auto tensor_attr = tensor->GetOrCreateAttrsGroup(); - GE_ASSERT_NOTNULL(tensor_attr, "[GetOrCreate][AscTensorAttr] failed, graph[%s]", asc_graph.GetName().c_str()); - const auto tensor_def = node_def->add_outputs(); - const auto tensor_attr_def = tensor_def->mutable_attr(); - GE_ASSERT_NOTNULL(tensor_attr_def, "[Get][TensorAttr] failed, graph[%s]", asc_graph.GetName().c_str()); - tensor_attr->SerializeAttr(*tensor_attr_def); - } - GELOGD("Serialize node[%s:%s] success, ir_id[%zu].", node->GetNamePtr(), node->GetTypePtr(), ir_id); - } - return GRAPH_SUCCESS; -} - -graphStatus AscGraphUtils::SerializeToBinary(const AscGraph &asc_graph, std::string &output) { - ascendc_ir::proto::AscGraphDef asc_graph_def; - GE_ASSERT_GRAPH_SUCCESS(SerializeToProto(asc_graph, asc_graph_def), "SerializeToProto failed."); - GE_ASSERT_TRUE(asc_graph_def.SerializeToString(&output)); - return GRAPH_SUCCESS; -} - -graphStatus AscGraphUtils::DeserializeFromBinary(const std::string &to_be_deserialized, AscGraph &out_asc_graph) { - ascendc_ir::proto::AscGraphDef asc_graph_def; - GE_ASSERT_TRUE(asc_graph_def.ParseFromString(to_be_deserialized)); - GE_ASSERT_GRAPH_SUCCESS(DeserializeFromProto(asc_graph_def, out_asc_graph)); - return GRAPH_SUCCESS; -} - -graphStatus AscGraphUtils::SerializeToReadable(const AscGraph &asc_graph, std::string &output) { - ascendc_ir::proto::AscGraphDef asc_graph_def; - GE_ASSERT_GRAPH_SUCCESS(SerializeToProto(asc_graph, asc_graph_def), "SerializeToProto failed."); - GE_ASSERT_TRUE(google::protobuf::TextFormat::PrintToString(asc_graph_def, &output), "SerializeToReadable failed."); - return GRAPH_SUCCESS; -} - -graphStatus AscGraphUtils::DeserializeFromReadable(const std::string &to_be_deserialized, AscGraph &out_asc_graph) { - ascendc_ir::proto::AscGraphDef asc_graph_def; - GE_ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(to_be_deserialized, &asc_graph_def)); - GE_ASSERT_GRAPH_SUCCESS(DeserializeFromProto(asc_graph_def, out_asc_graph)); - return GRAPH_SUCCESS; -} - -graphStatus AscGraphUtils::DeserializeFromProto(const ascendc_ir::proto::AscGraphDef &asc_graph_def, - AscGraph &asc_graph) { - auto &graph_name = asc_graph_def.graph_name(); - // 1. Add AscGraph - asc_graph.impl_->compute_graph_->SetName(graph_name); - GE_ASSERT_GRAPH_SUCCESS(EstablishAscNodeAndEdges(asc_graph_def, asc_graph), - "EstablishAscNodeAndEdges of graph[%s] failed.", graph_name.c_str()); - // deserialize graph attr - const auto &ge_graph = GetComputeGraph(asc_graph); - const auto asc_graph_attr = ge_graph->GetOrCreateAttrsGroup(); - GE_ASSERT_NOTNULL(asc_graph_attr, "[GetOrCreate][AscGraphAttr] failed, graph[%s].", graph_name.c_str()); - GE_ASSERT_GRAPH_SUCCESS(asc_graph_attr->DeserializeAttr(asc_graph_def.asc_graph_attr())); - // deserialize node attr, protobuf make sure the order as serializing - const auto &asc_nodes = asc_graph_def.asc_node(); - int64_t node_index = 0L; - for (const auto &asc_node : asc_graph.GetAllNodes()) { - GE_ASSERT_NOTNULL(asc_node, "[Get][Node] failed, graph[%s]", graph_name.c_str()); - // update asc node AscTensors - asc_node->outputs(); - asc_node->inputs(); - GE_ASSERT_TRUE(node_index < asc_nodes.size(), - "[Deserialize][Node] failed, node_index[%ld] should less than nodes size[%zu].", node_index, - asc_nodes.size()); - const auto &asc_node_def = asc_nodes[static_cast(node_index)]; - const auto op_desc = asc_node->GetOpDesc(); - GE_ASSERT_NOTNULL(op_desc, "[Get][OpDesc] failed, graph[%s]", graph_name.c_str()); - GE_ASSERT_GRAPH_SUCCESS(AscNodeDeserializeUtils::DeserializeAttrGroupsDef(asc_node_def.attr(), *asc_node)); - GE_ASSERT_GRAPH_SUCCESS(AscNodeDeserializeUtils::DeserializeIrDef(asc_node_def.ir_def(), *asc_node)); - // deserialize output tensor attr - int32_t index = 0; - for (const auto &output_def : asc_node_def.outputs()) { - const auto &output_attr = output_def.attr(); - const auto output_desc = op_desc->MutableOutputDesc(index); - GE_ASSERT_NOTNULL(output_desc, "[Get][OutputDesc] failed of node[%s:%s, %d].", op_desc->GetName().c_str(), - op_desc->GetType().c_str(), index); - const auto output_tensor_attr = output_desc->GetOrCreateAttrsGroup(); - GE_ASSERT_NOTNULL(output_tensor_attr, "[GetOrCreate][OutputTensorAttr] failed of node[%s:%s, %d].", - op_desc->GetName().c_str(), op_desc->GetType().c_str(), index); - output_tensor_attr->DeserializeAttr(output_attr, output_desc.get()); - index++; - } - // deserialize input tensor attr - index = 0; - for (const auto &input_def : asc_node_def.input_src()) { - const auto &src_node_name = input_def.src_node_name(); - const auto &src_node = asc_graph.FindNode(src_node_name.c_str()); - if (src_node == nullptr) { - GELOGW("[Get][SrcNodeName] %s failed of node [%s:%d]", src_node_name.c_str(), asc_node->GetNamePtr(), index); - continue; - } - const auto src_out_index = input_def.src_out_index(); - const auto &src_op = src_node->GetOpDesc(); - GE_ASSERT_NOTNULL(src_op); - const auto &out_tensor_desc = src_op->MutableOutputDesc(src_out_index); - GE_ASSERT_NOTNULL(out_tensor_desc); - const auto &out_tensor_attr = out_tensor_desc->GetOrCreateAttrsGroup(); - GE_ASSERT_NOTNULL(out_tensor_attr, "[GetOrCreate][InputTensorAttr] failed of node[%s:%s, %d].", - src_op->GetName().c_str(), src_op->GetType().c_str(), src_out_index); - const auto input_desc = op_desc->MutableInputDesc(index); - GE_ASSERT_NOTNULL(input_desc, "[Get][InputDesc] failed of node[%s:%s, %d].", op_desc->GetName().c_str(), - op_desc->GetType().c_str(), index); - const auto input_tensor_attr = input_desc->GetOrCreateAttrsGroup(); - GE_ASSERT_NOTNULL(input_tensor_attr, "[GetOrCreate][InputTensorAttr] failed of node[%s:%s, %d].", - op_desc->GetName().c_str(), op_desc->GetType().c_str(), index); - input_tensor_attr->dtype.tensor_desc_ = input_desc.get(); - *input_tensor_attr = *out_tensor_attr; - index++; - } - node_index++; - } - return GRAPH_SUCCESS; -} - -graphStatus AscGraphUtils::ConvertComputeGraphToAscGraph(const ComputeGraphPtr &compute_graph, AscGraph &asc_graph) { - GE_ASSERT_NOTNULL(compute_graph); - // 1. 如果外部没有指定名字, 共享名字 - if (asc_graph.GetName().empty()) { - auto asc_compute_graph = GetComputeGraph(asc_graph); - GE_ASSERT_NOTNULL(asc_compute_graph); - asc_compute_graph->SetName(compute_graph->GetName()); - } - // 2. 转换Node到AscNode, 共享OpDesc - std::unordered_map all_new_nodes; - for (const auto &node : compute_graph->GetDirectNode()) { - GE_ASSERT_NOTNULL(node); - const auto &op_desc = node->GetOpDesc(); - GE_ASSERT_NOTNULL(op_desc); - auto op = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - auto dst_new_node = asc_graph.AddNode(op); - GE_ASSERT_NOTNULL(dst_new_node); - all_new_nodes[dst_new_node->GetName()] = dst_new_node; - } - // 3. 转换连边关系 - for (const auto &src_node : compute_graph->GetDirectNode()) { - GE_ASSERT_GRAPH_SUCCESS(GraphUtils::RelinkGraphEdges(src_node, "", all_new_nodes)); - } - // 4. 转换graph上的属性组属性 - AscGraphImpl::DoCopyAscGraphAttrImpl(compute_graph, AscGraphUtils::GetComputeGraph(asc_graph)); - return GRAPH_SUCCESS; -} - -graphStatus AscNodeSerializeUtils::SerializeIrDef(const AscNode &node, ascendc_ir::proto::IrDef &ir_def) { - ir_def.set_type(node.GetType()); - const auto &op_desc = node.GetOpDesc(); - GE_ASSERT_NOTNULL(op_desc); - for (const auto &ir_input : op_desc->GetIrInputs()) { - ir_def.add_input_names(ir_input.first); - ir_def.add_input_ir_type(ir_input.second); - } - for (const auto &ir_output : op_desc->GetIrOutputs()) { - ir_def.add_output_names(ir_output.first); - ir_def.add_output_ir_type(ir_output.second); - } - GELOGD("Serialize ir def node[%s:%s] success.", node.GetNamePtr(), ir_def.type().c_str()); - return GRAPH_SUCCESS; -} - -graphStatus AscNodeSerializeUtils::SerializeAttrGroupsDef(const AscNode &node, - ascendc_ir::proto::AscNodeAttrGroupsDef &asc_node_attr_groups_def) { - const auto op_desc = node.GetOpDesc(); - GE_ASSERT_NOTNULL(op_desc); - const auto asc_node_attr = op_desc->GetOrCreateAttrsGroup(); - GE_ASSERT_NOTNULL(asc_node_attr); - return asc_node_attr->SerializeAttr(asc_node_attr_groups_def); -} - -graphStatus AscNodeDeserializeUtils::DeserializeIrDef(const ascendc_ir::proto::IrDef &ir_def, AscNode &node) { - const auto &type = ir_def.type(); - const auto &op_desc = node.GetOpDesc(); - GE_ASSERT_NOTNULL(op_desc); - op_desc->SetType(type); - GE_ASSERT_EQ(ir_def.input_names_size(), ir_def.input_ir_type_size()); - for (int64_t index = 0; index < ir_def.input_names_size(); ++index) { - op_desc->AppendIrInput(ir_def.input_names(index), static_cast(ir_def.input_ir_type(index))); - } - GE_ASSERT_EQ(ir_def.output_names_size(), ir_def.output_ir_type_size()); - for (int64_t index = 0; index < ir_def.output_names_size(); ++index) { - op_desc->AppendIrOutput(ir_def.output_names(index), static_cast(ir_def.output_ir_type(index))); - } - GELOGD("Deserialize ir def node[%s:%s] success.", node.GetNamePtr(), ir_def.type().c_str()); - return GRAPH_SUCCESS; -} -graphStatus AscNodeDeserializeUtils::DeserializeAttrGroupsDef(const ascendc_ir::proto::AscNodeAttrGroupsDef &asc_node_attr_groups_def, - AscNode &node) { - const auto op_desc = node.GetOpDesc(); - GE_ASSERT_NOTNULL(op_desc); - const auto asc_node_attr = op_desc->GetOrCreateAttrsGroup(); - GE_ASSERT_NOTNULL(asc_node_attr); - GE_ASSERT_GRAPH_SUCCESS(asc_node_attr->DeserializeAttr(asc_node_attr_groups_def), - "[Deserialize][Attr] failed of node[%s:%s].", op_desc->GetName().c_str(), - op_desc->GetType().c_str()); - return GRAPH_SUCCESS; -} -graphStatus ExpressionSerializer::Serialize(const AnyValue &av, proto::AttrDef &def) { - Expression expression; - GE_ASSERT_GRAPH_SUCCESS(av.GetValue(expression)); - def.set_expression(expression.Serialize().get()); - return GRAPH_SUCCESS; -} - -graphStatus ExpressionSerializer::Deserialize(const proto::AttrDef &def, AnyValue &av) { - return av.SetValue(Expression::Deserialize(def.expression().c_str())); -} - -REG_GEIR_SERIALIZER(expression_serializer, - ExpressionSerializer, - GetTypeId(), - proto::AttrDef::kExpression); -} // namespace ge diff --git a/graph/ascendc_ir/utils/asc_tensor_utils.cc b/graph/ascendc_ir/utils/asc_tensor_utils.cc deleted file mode 100644 index 20d2c91655e600f8a85091610be65c0501e10b5a..0000000000000000000000000000000000000000 --- a/graph/ascendc_ir/utils/asc_tensor_utils.cc +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/ascendc_ir/utils/asc_tensor_utils.h" - -namespace ge { -namespace ascir { -bool AscTensorUtils::IsConstTensor(const AscTensor &t) { - const auto node = t.anchor.GetOwnerNodeBarePtr(); - GE_ASSERT_NOTNULL(node); - return node->GetType() == "Constant" || node->GetType() == "IndexExpr" || node->GetType() == "Scalar"; -} -Node *AscTensorUtils::GetOwner(const AscTensor &t) { - return t.anchor.GetOwnerNodeBarePtr(); -} -int32_t AscTensorUtils::Index(const AscTensor &t) { - return t.anchor.GetIdx(); -} -} // namespace ascir -} // namespace ge diff --git a/graph/ascendc_ir/utils/ascendc_ir_check.cc b/graph/ascendc_ir/utils/ascendc_ir_check.cc deleted file mode 100644 index 4755639fd7d7bb6ade91125bbb539fb5c46d439f..0000000000000000000000000000000000000000 --- a/graph/ascendc_ir/utils/ascendc_ir_check.cc +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ -#include "graph/ascendc_ir/ascendc_ir_check.h" -#include -namespace ge { -AscIRException::AscIRException(const AscIRException::Info &info) : std::exception(), info_(info) {} - -const AscIRException::Info &AscIRException::GetInfo() const { - return info_; -} -} diff --git a/graph/ascendc_ir/utils/ascendc_ir_dump_utils.cc b/graph/ascendc_ir/utils/ascendc_ir_dump_utils.cc deleted file mode 100644 index 9c2fc0bf35004c6cf18d312020f8bb6b711ca251..0000000000000000000000000000000000000000 --- a/graph/ascendc_ir/utils/ascendc_ir_dump_utils.cc +++ /dev/null @@ -1,302 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "inc/graph/ascendc_ir/utils/ascendc_ir_dump_utils.h" -namespace ge { -std::stringstream &DumpAscirGraph::TilingKeyStr(std::stringstream &ss, AscGraph &graph) { - std::string Tilingkey = std::to_string(graph.GetTilingKey()); - ss << "TilingKey: " << Tilingkey << std::endl; - return ss; -} - -std::stringstream &DumpAscirGraph::NameStr(std::stringstream &ss, AscGraph &graph) { - ss << "Graph Name: " << graph.GetName() << std::endl; - return ss; -} - -std::stringstream &DumpAscirGraph::AllAxisStr(std::stringstream &ss, AscGraph &graph) { - ss << "Axis:" << std::endl; - if (graph.GetAllAxis().empty()) { - return ss; - } - static const char *axis_type_str[] = { - "ORIGINAL", - "BLOCK_OUTER", - "BLOCK_INNER", - "TILE_OUTER", - "TILE_INNER", - "MERGED", - "INVALID", - }; - int32_t i = 1; - for (const auto &axis : graph.GetAllAxis()) { - ss << " axis"<< i << ": " << std::endl; - ss << " name: " << axis->name << std::endl; - ss << " id: " << axis->id << std::endl; - if (axis->type >= ge::Axis::kAxisTypeOriginal && axis->type <= ge::Axis::kAxisTypeMerged) { - ss << " type: " << axis_type_str[axis->type] << std::endl; - } - std::string bind_block = axis->bind_block? "true" : "false"; - ss << " bind_block: " << bind_block << std::endl; - ss << " size: " << axis->size.Str().get() << std::endl; - ss << " align: " << axis->align << std::endl; - if (!axis->from.empty()) { - ss << " from: {"; - for (auto from_axis : axis->from) { - ss << from_axis << ", "; - } - ss << "}" << std::endl; - } - if (axis->split_pair_other_id == kIdNone) { - ss << " split_pair_other_id: -1" << std::endl; - } - else { - ss << " split_pair_other_id: " << axis->split_pair_other_id << std::endl; - } - ss << " allow_oversize_axis: " << axis->allow_oversize_axis << std::endl; - ss << " allow_unaligned_tail: " << axis->allow_oversize_axis << std::endl; - i++; - } - return ss; -} - -std::string DumpAscirGraph::ApiTypeToString(ge::ApiType type) { - static const std::map api_type_to_string_map = { - {ge::ApiType::kAPITypeBuffer, "BUFFER"}, - {ge::ApiType::kAPITypeCompute, "COMPUTE"}, - {ge::ApiType::kAPITypeInvalid, "INVALID"}, - }; - const auto it = api_type_to_string_map.find(type); - if (it != api_type_to_string_map.end()) { - return it->second; - } - return "UNDEFINED"; -} - -std::string DumpAscirGraph::ComputUnitToString(ge::ComputeUnit unit) { - static const std::map comput_unit_to_string_map = { - {ge::ComputeUnit::kUnitNone, "NONE"}, - {ge::ComputeUnit::kUnitMTE1, "MTE1"}, - {ge::ComputeUnit::kUnitMTE2, "MTE2"}, - {ge::ComputeUnit::kUnitMTE3, "MTE3"}, - {ge::ComputeUnit::kUnitScalar, "SCALAR"}, - {ge::ComputeUnit::kUnitVector, "VECTOR"}, - {ge::ComputeUnit::kUnitCube, "CUBE"}, - {ge::ComputeUnit::kUnitInvalid, "INVALID"}, - }; - const auto it = comput_unit_to_string_map.find(unit); - if (it != comput_unit_to_string_map.end()) { - return it->second; - } - return "UNDEFINED"; -} - -std::string DumpAscirGraph::ComputeTypeToString(ge::ComputeType type) { - static const std::map comput_type_to_string_map = { - {ge::ComputeType::kComputeLoad, "LOAD"}, - {ge::ComputeType::kComputeStore, "STORE"}, - {ge::ComputeType::kComputeReduceStore, "REDUCE_STORE"}, - {ge::ComputeType::kComputeElewise, "ELEWISE"}, - {ge::ComputeType::kComputeBroadcast, "BROADCAST"}, - {ge::ComputeType::kComputeReduce, "REDUCE"}, - {ge::ComputeType::kComputeTranspose, "TRANPOSE"}, - {ge::ComputeType::kComputeGather, "GATHER"}, - {ge::ComputeType::kComputeInvalid, "INVALID"}, - }; - const auto it = comput_type_to_string_map.find(type); - if (it != comput_type_to_string_map.end()) { - return it->second; - } - return "UNDEFINED"; -} - -std::stringstream &DumpAscirGraph::AscNodeAttrStr(std::stringstream &ss, AscNodeAttr &attr) { - ss << " AscNode: " << std::endl; - ss << " sched: " << std::endl; - ss << " exec_order: " << attr.sched.exec_order << std::endl; - ss << " axis: "; - for (auto axis : attr.sched.axis) { - ss << axis << ", "; - } - ss << std::endl; - ss << " loop_axis: " << attr.sched.loop_axis << std::endl; - ss << std::endl; - ss << " Api: " << std::endl; - ss << " Api type: " << ApiTypeToString(attr.api.type) << std::endl; - ss << " Compute unit: " << ComputUnitToString(attr.api.unit) << std::endl; - ss << " Compute type: " << ComputeTypeToString(attr.api.compute_type) << std::endl; - return ss; -} - -std::stringstream &DumpAscirGraph::AscTensorAttrStr(std::stringstream &ss, AscTensorAttr *attr) { - if (attr == nullptr) { - return ss; - } - ss << " AscTensor: " << std::endl; - ss << " DataType: " << TypeUtils::DataTypeToSerialString(attr->dtype.operator ge::DataType()) << std::endl; - ss << " axis: "; - for (auto axis : attr->axis) { - ss << axis << ", "; - } - ss << std::endl; - ss << " repeats: "; - for (const auto &repeat : attr->repeats) { - ss << repeat.Str().get() << ", "; - } - ss << std::endl; - ss << " strides: "; - for (const auto &stride : attr->strides) { - ss << stride.Str().get() << ", "; - } - ss << std::endl; - ss << " vectorized_axis: "; - for (auto axis : attr->vectorized_axis) { - ss << axis << ", "; - } - ss << std::endl; - ss << " vectorized_strides: "; - for (const auto &stride : attr->vectorized_strides) { - ss << stride.Str().get() << ","; - } - ss << std::endl; - MemAttrStr(ss, attr); - MemQueueAttrStr(ss, attr); - MemBufAttrStr(ss, attr); - MemOptAttrStr(ss, attr); - return ss; -} - -std::string DumpAscirGraph::AllocTypeToString(ge::AllocType type) { - static const std::map alloc_type_to_string_map = { - {ge::AllocType::kAllocTypeGlobal, "GLOBAL"}, - {ge::AllocType::kAllocTypeL1, "L1"}, - {ge::AllocType::kAllocTypeL2, "L2"}, - {ge::AllocType::kAllocTypeBuffer, "BUFFER"}, - {ge::AllocType::kAllocTypeQueue, "QUEUE"} - }; - const auto it = alloc_type_to_string_map.find(type); - if (it != alloc_type_to_string_map.end()) { - return it->second; - } - return "UNDEFINED"; -} - -std::string DumpAscirGraph::PositionToString(ge::Position position) { - static const std::map position_to_string_map = { - {ge::Position::kPositionGM, "GM"}, - {ge::Position::kPositionVecIn, "VECIN"}, - {ge::Position::kPositionVecOut, "VECOUT"} - }; - const auto it = position_to_string_map.find(position); - if (it != position_to_string_map.end()) { - return it->second; - } - return "UNDEFINED"; -} - -std::string DumpAscirGraph::HardwareToString(ge::MemHardware hardware) { - static const std::map hard_ware_to_string_map = { - {ge::MemHardware::kMemHardwareGM, "GM"}, - {ge::MemHardware::kMemHardwareUB, "UB"} - }; - const auto it = hard_ware_to_string_map.find(hardware); - if (it != hard_ware_to_string_map.end()) { - return it->second; - } - return "UNDEFINED"; -} - -std::stringstream &DumpAscirGraph::MemAttrStr(std::stringstream &ss, AscTensorAttr *attr) { - ss << " MemAttr: " << std::endl; - ss << " tensor_id: " << attr->mem.tensor_id << std::endl; - ss << " alloc_type: " << AllocTypeToString(attr->mem.alloc_type) << std::endl; - ss << " position: " << PositionToString(attr->mem.position) << std::endl; - ss << " hardware: " << HardwareToString(attr->mem.hardware) << std::endl; - ss << " buf_ids: "; - for (auto buf_id : attr->mem.buf_ids) { - ss << buf_id << ", "; - } - ss << std::endl; - ss << " name: " << attr->mem.name << std::endl; - return ss; -} - -std::stringstream &DumpAscirGraph::MemQueueAttrStr(std::stringstream &ss, AscTensorAttr *attr) { - ss << " MemQueAttr: " << std::endl; - ss << " id: " << attr->que.id << std::endl; - ss << " depth: " << attr->que.depth << std::endl; - ss << " buf_num: " << attr->que.buf_num << std::endl; - ss << " name: " << attr->que.name << std::endl; - return ss; -} - -std::stringstream &DumpAscirGraph::MemBufAttrStr(std::stringstream &ss, AscTensorAttr *attr) { - ss << " MemBufAttr: " << std::endl; - ss << " id: " << attr->buf.id << std::endl; - ss << " name: " << attr->buf.name << std::endl; - return ss; -} - -std::stringstream &DumpAscirGraph::MemOptAttrStr(std::stringstream &ss, AscTensorAttr *attr) { - ss << " MemOptAttr: " << std::endl; - ss << " reuse_id: " << attr->opt.reuse_id << std::endl; - ss << " ref_tensor: " << attr->opt.ref_tensor << std::endl; - ss << " merge_scope: " << attr->opt.merge_scope << std::endl; - return ss; -} - -std::stringstream &DumpAscirGraph::NodesStr(std::stringstream &ss, ge::AscNodeVisitor &nodes) { - ss << "nodes:" << std::endl; - int32_t i = 1; - for (auto node = nodes.begin(); node != nodes.end(); ++node) { - ss << " node" << i << " info: " << std::endl; - ss << " node name: " << node.operator*()->GetName() << std::endl; - uint32_t input_size = node.operator*()->inputs.Size(); - ss << " inputs: " << std::endl; - for (uint32_t j = 0; j < input_size; j++) { - if ((node.operator*()->GetInDataAnchor(j) != nullptr) && (node.operator*()->GetInDataAnchor(j)->GetPeerOutAnchor() != nullptr)) { - AscTensorAttr &temp = node.operator*()->inputs[j].attr; - AscTensorAttr *tempPtr = &temp; - AscTensorAttrStr(ss, tempPtr); - } - } - ss << " outputs: " << std::endl; - for (auto outputs : node.operator*()->outputs()) { - AscTensorAttrStr(ss, &outputs->attr); - } - ss << " attr: " << std::endl; - AscNodeAttrStr(ss, node.operator*()->attr); - ss << std::endl; - i++; - } - return ss; -} - -std::string DumpAscirGraph::DumpGraph(AscGraph &graph) { - std::stringstream ss; - TilingKeyStr(ss, graph); - NameStr(ss, graph); - AllAxisStr(ss, graph); - AscNodeVisitor all_nodes = graph.GetAllNodes(); - NodesStr(ss, all_nodes); - return ss.str(); -} - -void DumpAscirGraph::WriteOutToFile(const std::string &filename, AscGraph &graph) { - const auto &content = DumpGraph(graph); - std::ofstream outFile(filename); - if (!outFile) { - std::cerr << "Cannot open the file: " << filename << std::endl; - return; - } - outFile << content; - outFile.close(); -} - -} // namespace ge \ No newline at end of file diff --git a/graph/ascendc_ir/utils/axis_utils.cc b/graph/ascendc_ir/utils/axis_utils.cc deleted file mode 100644 index 3cf910119cdfeedd2162a66978ee88647a765c27..0000000000000000000000000000000000000000 --- a/graph/ascendc_ir/utils/axis_utils.cc +++ /dev/null @@ -1,383 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ -#include "ascendc_ir/ascendc_ir_core/ascendc_ir.h" -#include "ascendc_ir/ascendc_ir_check.h" -#include "graph/expression/const_values.h" -#include "utils/axis_utils.h" -namespace ge { -namespace { -using AstAxisIdToTransInfo = std::map; -std::vector ToAxisIds(const std::vector &axes) { - std::vector axes_ids; - axes_ids.reserve(axes.size()); - for (const auto &axis : axes) { - axes_ids.emplace_back(axis->id); - } - return axes_ids; -} - -AstAxisIdToTransInfo ToTransInfoMap(const TransInfoRoadOfGraph &trans_info_road_of_graph) { - AstAxisIdToTransInfo res; - for (const auto &trans_info : trans_info_road_of_graph) { - for (const auto &dst_axis : trans_info.dst_axis) { - res[dst_axis->id] = &trans_info; - } - } - return res; -} -const OneTransInfo *GetTransInfo(const AstAxisIdToTransInfo &trans_infos, const AxisId dst_axis_id) { - const auto iter = trans_infos.find(dst_axis_id); - if (iter == trans_infos.cend()) { - return nullptr; - } - return iter->second; -} - -std::vector GetTransInfos(const AstAxisIdToTransInfo &trans_infos, - const std::vector &dst_axis_ids, bool revert = false, - std::unordered_set uniq_set = {}) { - std::vector res; - for (const auto &dst_axis_id : dst_axis_ids) { - const auto info = GetTransInfo(trans_infos, dst_axis_id); - if ((info != nullptr) && (uniq_set.insert(info).second)) { - res.emplace_back(*info); - // if set revert, need find trans info from src_axis - auto got_trans_infos = GetTransInfos( - trans_infos, (revert ? ToAxisIds(info->src_axis) : ToAxisIds(info->dst_axis)), revert, uniq_set); - res.insert(res.cend(), got_trans_infos.cbegin(), got_trans_infos.cend()); - } - } - return res; -} - -bool AreAllSrcReady(const View &tensor_view_to_update, const OneTransInfo &trans_info) { - // check if all src of trans_info(check) is all in axes(source) - return std::all_of(trans_info.src_axis.begin(), trans_info.src_axis.end(), - [&tensor_view_to_update](const AxisPtr &axis) { - auto &axis_ids = tensor_view_to_update.axis_ids; - return std::find(axis_ids.begin(), axis_ids.end(), axis->id) != axis_ids.end(); - }); -} - -std::vector ToRevertTransInfo(const std::vector &trans_infos) { - std::vector revert_trans_infos; - for (const auto &trans_info : trans_infos) { - if (trans_info.trans_type == TransType::kSplit) { - revert_trans_infos.emplace_back(OneTransInfo{TransType::kMerge, trans_info.dst_axis, trans_info.src_axis}); - } else if (trans_info.trans_type == TransType::kMerge) { - revert_trans_infos.emplace_back(OneTransInfo{TransType::kSplit, trans_info.dst_axis, trans_info.src_axis}); - } - } - return revert_trans_infos; -} - -std::pair, std::vector> UpdateReadyTransInfos( - const View &tensor_view_to_update, const std::vector ¬_ready_trans_infos) { - std::vector to_apply_trans_infos; - std::vector not_read_trans_infos_updated; - for (auto iter = not_ready_trans_infos.begin(); iter != not_ready_trans_infos.end();) { - if (AreAllSrcReady(tensor_view_to_update, *iter)) { - to_apply_trans_infos.emplace_back(*iter); - } else { - not_read_trans_infos_updated.emplace_back(*iter); - } - ++iter; - } - return {to_apply_trans_infos, not_read_trans_infos_updated}; -} - -bool CheckAxisValid(const OneTransInfo &one_trans_info, const bool revert) { - const auto trans_type = one_trans_info.trans_type; - bool need_check_dst_merged = - (!revert && trans_type == TransType::kMerge) || (revert && (trans_type == TransType::kSplit)); - const auto merged_axis = revert ? one_trans_info.src_axis.front() : one_trans_info.dst_axis.front(); - if (need_check_dst_merged) { - GE_ASSERT_TRUE(merged_axis->type == Axis::kAxisTypeMerged, "[Check][Axis] failed, axis id[%d], type[%d].", - merged_axis->id, merged_axis->type); - } - bool need_check_dst_split = - (!revert && trans_type == TransType::kSplit) || (revert && (trans_type == TransType::kMerge)); - const auto outer_axis = revert ? one_trans_info.src_axis.front() : one_trans_info.dst_axis.front(); - const auto inner_axis = revert ? one_trans_info.src_axis.back() : one_trans_info.dst_axis.back(); - if (need_check_dst_split) { - GE_ASSERT_TRUE(outer_axis->type == Axis::kAxisTypeBlockOuter || outer_axis->type == Axis::kAxisTypeTileOuter, - "[Check][Axis] failed, axis id[%d], type[%d]", outer_axis->id, outer_axis->type); - GE_ASSERT_TRUE(inner_axis->type == Axis::kAxisTypeBlockInner || inner_axis->type == Axis::kAxisTypeTileInner, - "[Check][Axis] failed, axis id[%d], type[%d]", outer_axis->id, outer_axis->type); - } - return true; -} - -DiffAxesInfo GetDiffAxesInfo(const std::vector &input_api_sched_axes, - const std::vector &my_api_sched_axes) { - DiffAxesInfo diff_axes_info; - diff_axes_info.add_axes = my_api_sched_axes; - for (const auto &input_api_sched_axis : input_api_sched_axes) { - const auto iter = std::find(diff_axes_info.add_axes.cbegin(), diff_axes_info.add_axes.cend(), input_api_sched_axis); - if (iter != diff_axes_info.add_axes.cend()) { - diff_axes_info.add_axes.erase(iter); - } else { - diff_axes_info.del_axes.emplace_back(input_api_sched_axis); - } - } - return diff_axes_info; -} -// `pair.first == false` means apply failed -std::pair ApplyViewTrans(const TransInfoRoadOfGraph &trans_info_road_of_graph, const bool revert, - View &tensor_view_to_update) { - for (const auto &one_trans_info : trans_info_road_of_graph) { - GE_ASSERT_TRUE(CheckAxisValid(one_trans_info, revert)); - switch (one_trans_info.trans_type) { - case TransType::kSplit:GE_ASSERT_TRUE(one_trans_info.src_axis.size() == 1U, "[Check][Axis], size[%zu]", - one_trans_info.src_axis.size()); - GE_ASSERT_TRUE(one_trans_info.dst_axis.size() == 2U, "[Check][Axis], size[%zu]", - one_trans_info.dst_axis.size()); - GE_ASSERT_NOTNULL(one_trans_info.src_axis.front()); - GE_ASSERT_NOTNULL(one_trans_info.dst_axis.front()); - GE_ASSERT_NOTNULL(one_trans_info.dst_axis.back()); - tensor_view_to_update = AxisUtils::SplitView( - tensor_view_to_update, one_trans_info.dst_axis.back()->size, one_trans_info.dst_axis.front()->id, - one_trans_info.dst_axis.back()->id, one_trans_info.src_axis.front()->id); - - break; - case TransType::kMerge: { - GE_ASSERT_TRUE(one_trans_info.src_axis.size() >= 2U, "[Check][Axis], size[%zu]", - one_trans_info.src_axis.size()); - GE_ASSERT_TRUE(one_trans_info.dst_axis.size() == 1U, "[Check][Axis], size[%zu]", - one_trans_info.dst_axis.size()); - GE_ASSERT_NOTNULL(one_trans_info.dst_axis.front()); - std::vector src_axis_ids; - for (const auto &src_axis : one_trans_info.src_axis) { - src_axis_ids.push_back(src_axis->id); - } - tensor_view_to_update = - AxisUtils::MergeView(tensor_view_to_update, one_trans_info.dst_axis.back()->id, src_axis_ids); - - break; - } - case TransType::kValid: - break; - default: - GELOGW("Unsupported trans type %ld", one_trans_info.trans_type); - return {false, tensor_view_to_update}; - } - } - GELOGD("Update view to [%s].", ToString(tensor_view_to_update.axis_ids).c_str()); - return {true, tensor_view_to_update}; -} - -std::pair ApplyReadyTransInfos(const std::vector &my_api_sched_axes, View &tensor_view_to_update, - std::vector ¬_ready_trans, const bool revert) { - std::vector to_apply_trans; - GELOGI("Before apply trans info, view is %s, my api schedule axes is %s, not ready trans size is %zu, revert is %d", - ViewToString(tensor_view_to_update).c_str(), ToString(my_api_sched_axes).c_str(), not_ready_trans.size(), - revert); - std::tie(to_apply_trans, not_ready_trans) = UpdateReadyTransInfos(tensor_view_to_update, not_ready_trans); - std::pair pair0{true, tensor_view_to_update}; - // break loop condition: - // current axes can not find any transform info to apply - while (!to_apply_trans.empty()) { - pair0 = ApplyViewTrans(to_apply_trans, revert, tensor_view_to_update); - if (pair0.first) { - std::tie(tensor_view_to_update) = pair0.second; - } else { - return {false, tensor_view_to_update}; - } - std::tie(to_apply_trans, not_ready_trans) = UpdateReadyTransInfos(tensor_view_to_update, not_ready_trans); - } - // 根据当前API的调度轴顺序调整输出View的轴顺序,保证越外侧的调度轴越靠前 - GELOGI("After apply trans info, view is %s", ViewToString(tensor_view_to_update).c_str()); - return {true, AxisUtils::ReorderView(tensor_view_to_update, my_api_sched_axes)}; -} -} // namespace -View AxisUtils::ReduceView(const View &src_view, int64_t reduce_axis) { - View new_view(src_view); - auto &axis_ids = new_view.axis_ids; - auto &repeats = new_view.repeats; - auto &strides = new_view.strides; - GELOGI("Before reduce, view is %s", ViewToString(src_view).c_str()); - GE_ASSERT_EQ(axis_ids.size(), repeats.size()); - GE_ASSERT_EQ(axis_ids.size(), strides.size()); - const size_t axis_size = axis_ids.size(); - size_t reduce_index_in_axis = 0U; - ge::Expression repeat_size; - - for (size_t index = 0; index < axis_size; ++index) { - if (axis_ids[index] == reduce_axis) { - repeat_size = repeats[index]; - strides[index] = ge::Symbol(0); - reduce_index_in_axis = index; - break; - } - } - // reduce之前的轴的stride应该除去reduce轴的repeat - for (size_t index = 0; index < reduce_index_in_axis; ++index) { - strides[index] = strides[index] / repeat_size; - } - GELOGI("After reduce, view is %s", ViewToString(new_view).c_str()); - return new_view; -} - -std::vector AxisUtils::GetDefaultVectorizedAxis(const std::vector &tensor_axis, int64_t loop_axis) { - auto iter = std::find(tensor_axis.begin(), tensor_axis.end(), loop_axis); - if (iter == tensor_axis.end()) { - return tensor_axis; - } else { - return {std::next(iter), tensor_axis.end()}; - } -} - -View AxisUtils::SplitView(const View &src_view, const ge::Expression &split_size, - const int64_t outter_id, const int64_t inner_id, const int64_t original_id) { - View new_view; - auto &new_axes = new_view.axis_ids; - auto &new_repeat = new_view.repeats; - auto &new_strides = new_view.strides; - const auto &axis_ids = src_view.axis_ids; - const auto &repeats = src_view.repeats; - const auto &strides = src_view.strides; - - GELOGI("Before split, view is %s", ViewToString(src_view).c_str()); - GE_ASSERT_EQ(axis_ids.size(), repeats.size()); - GE_ASSERT_EQ(axis_ids.size(), strides.size()); - for (uint32_t axis_index = 0U; axis_index < axis_ids.size(); axis_index++) { - if (axis_ids[axis_index] != original_id) { - new_axes.push_back(axis_ids[axis_index]); - new_repeat.push_back(repeats[axis_index]); - new_strides.push_back(strides[axis_index]); - } else { - new_axes.push_back(outter_id); - new_axes.push_back(inner_id); - if (repeats[axis_index] == 1) { - // keep stride when repeat=1 - new_repeat.push_back(sym::kSymbolOne); - new_strides.push_back(sym::kSymbolZero); - new_repeat.push_back(sym::kSymbolOne); - new_strides.push_back(strides[axis_index]); - } else { - new_repeat.push_back(repeats[axis_index] / split_size); - new_strides.push_back(strides[axis_index] * split_size); - new_repeat.push_back(split_size); - new_strides.push_back(strides[axis_index]); - } - } - } - - GELOGI("After split, view is %s", ViewToString(new_view).c_str()); - return new_view; -} - -View AxisUtils::MergeView(const View &src_view, const int64_t merged_axis_id, const std::vector &original) { - View new_view; - std::set original_set(original.begin(), original.end()); - std::set merge_axis_set; - auto &new_axis_ids = new_view.axis_ids; - auto &new_repeat = new_view.repeats; - auto &new_strides = new_view.strides; - const auto &axis_ids = src_view.axis_ids; - const auto &repeats = src_view.repeats; - const auto &strides = src_view.strides; - GELOGI("Before merge, view is %s", ViewToString(src_view).c_str()); - GE_ASSERT_EQ(axis_ids.size(), repeats.size()); - GE_ASSERT_EQ(axis_ids.size(), strides.size()); - ge::Expression merge_repeat = sym::kSymbolOne; - for (uint32_t axis_index = 0U; axis_index < axis_ids.size(); axis_index++) { - if (original_set.find(axis_ids[axis_index]) != original_set.end()) { - merge_repeat = merge_repeat * repeats[axis_index]; - merge_axis_set.emplace(axis_ids[axis_index]); - if (merge_axis_set.size() == original_set.size()) { - new_axis_ids.push_back(merged_axis_id); - new_repeat.push_back(merge_repeat); - new_strides.push_back(strides[axis_index]); - } - } else { - new_axis_ids.push_back(axis_ids[axis_index]); - new_repeat.push_back(repeats[axis_index]); - new_strides.push_back(strides[axis_index]); - } - } - // 不支持merge不全的情况 - GE_ASSERT_TRUE(merge_axis_set.size() == original_set.size() || merge_axis_set.empty(), - "tensor has view %s but origin is %s", - ViewToString(src_view).c_str(), - ViewMemberToString(original).c_str()); - GELOGI("After merge, view is %s", ViewToString(new_view).c_str()); - return new_view; -} - -std::pair AxisUtils::UpdateViewIfCrossLoop(const TransInfoRoadOfGraph &trans_info_road_of_graph, - const vector &input_api_sched_axes, - const vector &my_api_sched_axes, - View &&tensor_view_to_update) { - // 计算当前API与输入View对应API的差异化轴 - if (my_api_sched_axes == input_api_sched_axes) { - return {false, tensor_view_to_update}; - } - // 对输入View应用差异化轴的差异化变换,得到输出View - // step1 计算调度轴的映射,查找差异化轴的变换信息 - const auto trans_info_map = ToTransInfoMap(trans_info_road_of_graph); - const auto diff_axes_info = GetDiffAxesInfo(input_api_sched_axes, my_api_sched_axes); - // step2 应用差异化轴的逆变换 - auto revert_not_ready_trans = ToRevertTransInfo(GetTransInfos(trans_info_map, diff_axes_info.del_axes)); - if (!revert_not_ready_trans.empty()) { - ApplyReadyTransInfos(my_api_sched_axes, tensor_view_to_update, revert_not_ready_trans, true); - } - // step3 应用差异化轴的变换 - auto not_ready_trans = GetTransInfos(trans_info_map, diff_axes_info.add_axes); - return ApplyReadyTransInfos(my_api_sched_axes, tensor_view_to_update, not_ready_trans, false); -} - -View AxisUtils::ReorderView(const View &src_view, const std::vector &my_api_sched_axes) { - GELOGI("Before reorder, view is %s, my api sched axes is %s", ViewToString(src_view).c_str(), - ToString(my_api_sched_axes).c_str()); - const auto &src_axes = src_view.axis_ids; - const auto &repeats = src_view.repeats; - const auto &strides = src_view.strides; - GE_ASSERT_EQ(src_axes.size(), repeats.size()); - GE_ASSERT_EQ(src_axes.size(), strides.size()); - using ReorderingView = std::pair>; - std::vector reordering_view; - reordering_view.reserve(src_axes.size()); - for (size_t id = 0UL; id < src_axes.size(); ++id) { - reordering_view.emplace_back(std::make_pair(src_axes[id], std::make_pair(&repeats[id], &strides[id]))); - } - std::sort(reordering_view.begin(), reordering_view.end(), - [&my_api_sched_axes](const ReorderingView &left, const ReorderingView &right) -> bool { - const auto left_iter = std::find(my_api_sched_axes.cbegin(), my_api_sched_axes.cend(), left.first); - const auto right_iter = std::find(my_api_sched_axes.cbegin(), my_api_sched_axes.cend(), right.first); - // the condition for reorder - if ((left_iter == my_api_sched_axes.cend()) && - (right_iter == my_api_sched_axes.cend())) { - return false; - } - if ((left_iter == my_api_sched_axes.cend()) && - (right_iter != my_api_sched_axes.cend())) { - return false; - } - if ((left_iter != my_api_sched_axes.cend()) && - (right_iter == my_api_sched_axes.cend())) { - // for example, my schedule = [a,b], axes = [c,a,d], a is in my schedule, should reorder [c,a] in axes - return true; - } - // for example, my schedule = [a,b], axes = [b,a,d], in my schedule b > a, should reorder [b,a] in axes - return left_iter < right_iter; - }); - std::vector ordered_axes; - std::vector ordered_repeats; - std::vector ordered_strides; - for (const auto &reordering_axis : reordering_view) { - ordered_axes.emplace_back(reordering_axis.first); - ordered_repeats.emplace_back(*reordering_axis.second.first); - ordered_strides.emplace_back(*reordering_axis.second.second); - } - View output_view{ordered_axes, ordered_repeats, ordered_strides}; - GELOGI("After reorder, view is %s", ViewToString(output_view).c_str()); - return output_view; -} -} // namespace ge \ No newline at end of file diff --git a/graph/ascendc_ir/utils/cg_utils.cc b/graph/ascendc_ir/utils/cg_utils.cc deleted file mode 100644 index 4fe4f42f6485551fb770772bb28406d743246079..0000000000000000000000000000000000000000 --- a/graph/ascendc_ir/utils/cg_utils.cc +++ /dev/null @@ -1,153 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/utils/cg_utils.h" -#include "graph/utils/node_utils_ex.h" -#include "graph/utils/graph_utils_ex.h" -#include "ascendc_ir/core/ascendc_ir_impl.h" -#include "graph/debug/ge_log.h" -namespace ge { -namespace ascir { -namespace cg { -namespace { -thread_local std::weak_ptr t_context; -int64_t GenNextId(const ge::ComputeGraphPtr &graph, const std::string &key) { - if (graph == nullptr) { - throw std::invalid_argument("Invalid graph"); - } - auto id = graph->GetExtAttr(key); - if (id == nullptr) { - graph->SetExtAttr(key, static_cast(1)); - return 0; - } - return (*id)++; -} - -int64_t GenNextId(const ge::Operator &op, const std::string &key) { - auto node = ge::NodeUtilsEx::GetNodeFromOperator(op); - GE_ASSERT_NOTNULL(node); - return GenNextId(node->GetOwnerComputeGraph(), key); -} -} // namespace - -CgContext *CgContext::GetThreadLocalContext() { - return GetSharedThreadLocalContext().get(); -} -std::shared_ptr CgContext::GetSharedThreadLocalContext() { - return t_context.lock(); -} -void CgContext::SetThreadLocalContext(const std::shared_ptr &context) { - t_context = context; -} -const std::vector &CgContext::GetLoopAxes() const { - return loop_axes_; -} -void CgContext::SetLoopAxes(std::vector axes) { - loop_axes_ = std::move(axes); - loop_axis_ids_cache_.clear(); - loop_axis_ids_cache_.reserve(loop_axes_.size()); - for (const auto &axis : loop_axes_) { - loop_axis_ids_cache_.emplace_back(axis.id); - } -} -void CgContext::PushLoopAxis(const Axis &axis) { - loop_axes_.emplace_back(axis); - loop_axis_ids_cache_.emplace_back(axis.id); -} -void CgContext::PopBackLoopAxis(const Axis &axis) { - if (loop_axis_ids_cache_.empty()) { - GELOGE(FAILED, "Axes stack is empty", ""); - return; - } - auto last_id = *(loop_axis_ids_cache_.rbegin()); - if (last_id != axis.id) { - GELOGE(FAILED, "Pop Axis order unmatch", ""); - return; - } - loop_axis_ids_cache_.pop_back(); - loop_axes_.pop_back(); -} -const std::vector &CgContext::GetLoopAxisIds() const { - return loop_axis_ids_cache_; -} -void CgContext::SetBlockLoopEnd(AxisId id) { - block_loop_end_ = id; -} -AxisId CgContext::GetBlockLoopEnd() const { - return block_loop_end_; -} -void CgContext::SetVectorizedLoopEnd(AxisId id) { - vectorized_loop_end_ = id; -} -AxisId CgContext::GetVectorizedLoopEnd() const { - return vectorized_loop_end_; -} -void CgContext::SetLoopEnd(AxisId id) { - loop_end_ = id; -} -AxisId CgContext::GetLoopEnd() const { - return loop_end_; -} -void CgContext::SetOption(const LoopOption &option) { - option_ = option; -} -const LoopOption &CgContext::GetOption() const { - return option_; -} -LoopGuard::~LoopGuard() { - context_->PopBackLoopAxis(axis_); -} -LoopGuard::LoopGuard(const Axis &axis) { - context_ = CgContext::GetSharedThreadLocalContext(); - if (context_ == nullptr) { - context_ = std::make_shared(); - CgContext::SetThreadLocalContext(context_); - } - - axis_ = axis; - context_->PushLoopAxis(axis_); -} -std::unique_ptr LoopGuard::Create(const Axis &axis, const LoopOption &option) { - auto loop_guard = ComGraphMakeUnique(axis); - loop_guard->context_->SetOption(option); - return loop_guard; -} - -int64_t CodeGenUtils::GenNextExecId(const ge::Operator &op) { - static const std::string kExecIdKey = "cg.ExecId"; - return GenNextId(op, kExecIdKey); -} - -int64_t CodeGenUtils::GenNextContainerId(const ge::Operator &op) { - static const std::string kContainerIdKey = "cg.ContainerId"; - return GenNextId(op, kContainerIdKey); -} - -int64_t CodeGenUtils::GenNextReuseId(const ge::Operator &op) { - static const std::string kReuseIdKey = "cg.ReuseId"; - return GenNextId(op, kReuseIdKey); -} - -int64_t CodeGenUtils::GenNextTensorId(const ge::Operator &op) { - static const std::string kTensorIdKey = "cg.TensorId"; - return GenNextId(op, kTensorIdKey); -} - -int64_t CodeGenUtils::GenNextExecId(const ge::AscGraph &graph) { - // impl_ is always valid - return GenNextExecId(graph.impl_->compute_graph_); -} - -int64_t CodeGenUtils::GenNextExecId(const ge::ComputeGraphPtr &graph) { - static const std::string kExecIdKey = "cg.ExecId"; - return GenNextId(graph, kExecIdKey); -} -} // namespace cg -} // namespace ascir -} diff --git a/graph/ascendc_ir/utils/dtype_transform_utils.cc b/graph/ascendc_ir/utils/dtype_transform_utils.cc deleted file mode 100644 index 1df2c40ff0a71293e4ece77a6f0357096db3d5b2..0000000000000000000000000000000000000000 --- a/graph/ascendc_ir/utils/dtype_transform_utils.cc +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "utils/dtype_transform_utils.h" -#include - -ge::DataType DtypeTransformUtils::Prompt(ge::DataType src_type) { - static std::unordered_map prompt_dtype_map = { - {ge::DT_FLOAT16, ge::DT_FLOAT}, - {ge::DT_FLOAT, ge::DT_FLOAT}, - {ge::DT_BF16, ge::DT_FLOAT}, - - }; - const auto &iter = prompt_dtype_map.find(src_type); - return iter == prompt_dtype_map.end() ? ge::DT_UNDEFINED : iter->second; -} - diff --git a/graph/ascendc_ir/utils/mem_utils.cc b/graph/ascendc_ir/utils/mem_utils.cc deleted file mode 100644 index 70708ac97b4756ef0774cdc279d199136589baa1..0000000000000000000000000000000000000000 --- a/graph/ascendc_ir/utils/mem_utils.cc +++ /dev/null @@ -1,29 +0,0 @@ -/** - * 版权所有 (c) 华为技术有限公司 2024 - * - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ -#include "graph/utils/mem_utils.h" -namespace ge { -std::atomic MemUtils::gen_container_id_(0L); -std::atomic MemUtils::scope_id_(0L); -TQueConfig::TQueConfig(const int64_t id, const ge::Position pos, const int64_t depth, const int64_t buf_num) - : queue_attr_({id, depth, buf_num, ""}), pos_(pos) {} - -TBufConfig::TBufConfig(const int64_t id, const ge::Position pos) : buf_attr_({id, ""}), pos_(pos) {} - -TQueConfig MemUtils::CreateTQueConfig(const ge::Position pos, const int64_t depth, const int64_t buf_num) { - GE_ASSERT_TRUE(pos == Position::kPositionVecIn || (pos == Position::kPositionVecOut)); - return TQueConfig(gen_container_id_++, pos, depth, buf_num); -} - -TBufConfig MemUtils::CreateTBufConfig(const ge::Position pos) { - GE_ASSERT_TRUE(pos == Position::kPositionVecIn || (pos == Position::kPositionVecOut)); - return TBufConfig(gen_container_id_++, pos); -} -} \ No newline at end of file diff --git a/graph/attr/attr_group_base.cc b/graph/attr/attr_group_base.cc deleted file mode 100644 index 764aabaed25eecc5ab6b3a5b51f35a229bef285a..0000000000000000000000000000000000000000 --- a/graph/attr/attr_group_base.cc +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "attribute_group/attr_group_base.h" - -namespace ge { -template<> -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId() { - return reinterpret_cast(10); -} - -template<> -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId() { - return reinterpret_cast(11); -} - -template<> -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId() { - return reinterpret_cast(12); -} - -template<> -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId() { - return reinterpret_cast(13); -} - -template<> -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId() { - return reinterpret_cast(14); -} - -template<> -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId() { - return reinterpret_cast(15); -} - -template <> -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId() { - return reinterpret_cast(16); // 16表示唯一ID -} - -template <> -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId() { - return reinterpret_cast(17); // 17表示唯一ID -} -} // namespace ge diff --git a/graph/attr/attr_group_serialize.cc b/graph/attr/attr_group_serialize.cc deleted file mode 100644 index c488836cd2ad16e6eb6dfbda0710e65713f91df7..0000000000000000000000000000000000000000 --- a/graph/attr/attr_group_serialize.cc +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "common/ge_common/debug/ge_log.h" -#include "attribute_group/attr_group_serialize.h" -#include "attribute_group/attr_group_serializer_registry.h" -#include "checker.h" - -namespace ge { -graphStatus AttrGroupSerialize::SerializeAllAttr(proto::AttrGroups &attr_groups, const AttrStore &attr_store) { - GE_ASSERT_GRAPH_SUCCESS(OtherGroupSerialize(attr_groups, attr_store)); - - auto& id_2_ptr = attr_store.GetAttrsGroupPtr(); - for (const auto& ptr : id_2_ptr) { - if (ptr.second != nullptr) { - GE_ASSERT_GRAPH_SUCCESS(ptr.second->Serialize(*attr_groups.add_attr_group_def())); - } - } - return GRAPH_SUCCESS; -} - -graphStatus AttrGroupSerialize::DeserializeAllAttr(const proto::AttrGroups &attr_group, AttrHolder *attr_holder) { - GE_ASSERT_NOTNULL(attr_holder); - auto &attr_store = attr_holder->MutableAttrMap(); - GE_ASSERT_GRAPH_SUCCESS(OtherGroupDeserialize(attr_group, attr_store)); - for (const auto &attr_group_def : attr_group.attr_group_def()) { - auto deserializer = AttrGroupSerializerRegistry::GetInstance() - .GetDeserializer(attr_group_def.attr_group_case()); - if (deserializer.impl == nullptr) { - continue; - } - GE_ASSERT_GRAPH_SUCCESS(deserializer.impl->Deserialize(attr_group_def, attr_holder)); - attr_store.MutableAttrsGroupPtr()[deserializer.id] = std::move(deserializer.impl); - } - return ge::GRAPH_SUCCESS; -} -// todo: other group计划是需要再主线替换掉当前Ge IR上的的map attr = 5字段 -// 这个在分支上暂时不需要切换,等主线切换后再做替换,当前先做属性组的序列化和反序列化 -graphStatus AttrGroupSerialize::OtherGroupSerialize(proto::AttrGroups &attr_groups, const AttrStore &attr_store) { - (void)attr_store; - (void)attr_groups; - return GRAPH_SUCCESS; -} - -graphStatus AttrGroupSerialize::OtherGroupDeserialize(const proto::AttrGroups &attr_groups, AttrStore &attr_store) { - (void)attr_store; - (void)attr_groups; - return GRAPH_SUCCESS; -} -} \ No newline at end of file diff --git a/graph/attr/attr_group_serializer_registry.cc b/graph/attr/attr_group_serializer_registry.cc deleted file mode 100644 index 20a211b4d7a60e91ab22541e4e6f1a715fae6c97..0000000000000000000000000000000000000000 --- a/graph/attr/attr_group_serializer_registry.cc +++ /dev/null @@ -1,68 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., Ltd. 2025. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ - -#include "graph/attribute_group/attr_group_serializer_registry.h" -#include "common/ge_common/debug/ge_log.h" -namespace ge { - -AttrGroupSerializerRegistry &AttrGroupSerializerRegistry::GetInstance() { - static AttrGroupSerializerRegistry instance; - return instance; -} - -void AttrGroupSerializerRegistry::RegisterAttrGroupSerialize(const AttrGroupSerializeBuilder &builder, - const TypeId obj_type, - const proto::AttrGroupDef::AttrGroupCase proto_type) { - const std::lock_guard lck_guard(mutex_); - std::unique_ptr serializer = builder(); - const auto ptr = serializer.get(); - if (ptr == nullptr) { - GELOGE(FAILED, "SerializerBuilder is invalid."); - return; - } - if (serializer_builder_map_.count(obj_type) > 0U) { - GELOGW("Serializer %s for type %s has been registered", - typeid(*ptr).name(), HashedPointer(obj_type).ToString().c_str()); - return; - } - GELOGD("Serializer %s for type %s register successfully", - typeid(*ptr).name(), HashedPointer(obj_type).ToString().c_str()); - serializer_builder_map_[obj_type] = builder; - deserializer_builder_map_[proto_type] = std::make_pair(builder, obj_type); -} - -std::unique_ptr AttrGroupSerializerRegistry::GetSerializer(const TypeId obj_type) { - const auto iter = serializer_builder_map_.find(obj_type); - if (iter == serializer_builder_map_.cend()) { - GELOGW("Serializer for type %s has not been registered", HashedPointer(obj_type).ToString().c_str()); - return nullptr; - } - return iter->second(); -} - -AttrGroupDeserializer AttrGroupSerializerRegistry::GetDeserializer(const proto::AttrGroupDef::AttrGroupCase proto_type) { - const auto iter = - deserializer_builder_map_.find(proto_type); - if (iter == deserializer_builder_map_.cend()) { - GELOGW("Deserializer for type [%d] has not been registered", static_cast(proto_type)); - return AttrGroupDeserializer(nullptr, nullptr); - } - return AttrGroupDeserializer(iter->second.first(), iter->second.second); -} - -AttrGroupSerializerRegister::AttrGroupSerializerRegister(const AttrGroupSerializeBuilder builder, - TypeId const obj_type, - const proto::AttrGroupDef::AttrGroupCase proto_type) noexcept { - if (builder == nullptr) { - GELOGE(FAILED, "SerializerBuilder is nullptr."); - return; - } - AttrGroupSerializerRegistry::GetInstance().RegisterAttrGroupSerialize(builder, obj_type, proto_type); -} -} \ No newline at end of file diff --git a/graph/attr/attr_value.cc b/graph/attr/attr_value.cc deleted file mode 100644 index eff14684b166ac53853f2b571bb6dc95284d0778..0000000000000000000000000000000000000000 --- a/graph/attr/attr_value.cc +++ /dev/null @@ -1,150 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "external/graph/attr_value.h" -#include "debug/ge_util.h" -#include "graph/ge_attr_value.h" -#include "graph/ge_tensor.h" -#include "graph/type_utils.h" -#include "graph/utils/tensor_adapter.h" -#include "common/checker.h" - -#define ATTR_VALUE_SET_GET_IMP(type) \ - graphStatus AttrValue::GetValue(type &val) const { \ - if (impl != nullptr) { \ - return impl->geAttrValue_.GetValue(val); \ - } \ - return GRAPH_FAILED; \ - } - -#define ATTR_VALUE_SET_ATTR_IMP(type) \ - graphStatus AttrValue::SetAttrValue(const type &attr_value) const { \ - GE_ASSERT_NOTNULL(impl); \ - return impl->geAttrValue_.SetValue(attr_value); \ - } - -#define ATTR_VALUE_GET_ATTR_IMP(type) \ - graphStatus AttrValue::GetAttrValue(type &attr_value) const { \ - GE_ASSERT_NOTNULL(impl); \ - return impl->geAttrValue_.GetValue(attr_value); \ - } - -namespace ge { -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrValue::AttrValue() { - impl = ComGraphMakeShared(); -} - -ATTR_VALUE_SET_GET_IMP(AttrValue::STR) -ATTR_VALUE_SET_GET_IMP(AttrValue::INT) -ATTR_VALUE_SET_GET_IMP(AttrValue::FLOAT) - -// 使用宏生成基本类型的 SetAttrValue 和 GetAttrValue 函数 -ATTR_VALUE_SET_ATTR_IMP(int64_t) -ATTR_VALUE_GET_ATTR_IMP(int64_t) -ATTR_VALUE_SET_ATTR_IMP(float32_t) -ATTR_VALUE_GET_ATTR_IMP(float32_t) -ATTR_VALUE_SET_ATTR_IMP(bool) -ATTR_VALUE_GET_ATTR_IMP(bool) -ATTR_VALUE_SET_ATTR_IMP(ge::DataType) -ATTR_VALUE_GET_ATTR_IMP(ge::DataType) - -// 使用宏生成容器类型的 SetAttrValue 和 GetAttrValue 函数 -ATTR_VALUE_SET_ATTR_IMP(std::vector) -ATTR_VALUE_GET_ATTR_IMP(std::vector) -ATTR_VALUE_SET_ATTR_IMP(std::vector) -ATTR_VALUE_GET_ATTR_IMP(std::vector) -ATTR_VALUE_SET_ATTR_IMP(std::vector) -ATTR_VALUE_GET_ATTR_IMP(std::vector) -ATTR_VALUE_SET_ATTR_IMP(std::vector>) -ATTR_VALUE_GET_ATTR_IMP(std::vector>) -ATTR_VALUE_SET_ATTR_IMP(std::vector) -ATTR_VALUE_GET_ATTR_IMP(std::vector) - -graphStatus AttrValue::GetValue(AscendString &val) { - std::string val_get; - const auto status = GetValue(val_get); - if (status != GRAPH_SUCCESS) { - return status; - } - val = AscendString(val_get.c_str()); - return GRAPH_SUCCESS; -} - -// 特殊处理 AscendString 类型 -graphStatus AttrValue::SetAttrValue(const AscendString &attr_value) const { - GE_ASSERT_NOTNULL(impl); - return impl->geAttrValue_.SetValue(std::string(attr_value.GetString())); -} - -graphStatus AttrValue::GetAttrValue(AscendString &attr_value) const { - GE_ASSERT_NOTNULL(impl); - std::string str_value; - GE_ASSERT_GRAPH_SUCCESS(impl->geAttrValue_.GetValue(str_value)); - - attr_value = AscendString(str_value.c_str()); - return GRAPH_SUCCESS; -} - -// 特殊处理 std::vector 类型 -graphStatus AttrValue::SetAttrValue(const std::vector &attr_values) const { - GE_ASSERT_NOTNULL(impl); - - std::vector str_values; - for (const auto &value : attr_values) { - str_values.emplace_back(value.GetString()); - } - return impl->geAttrValue_.SetValue(str_values); -} - -graphStatus AttrValue::GetAttrValue(std::vector &attr_values) const { - GE_ASSERT_NOTNULL(impl); - std::vector str_values; - GE_ASSERT_GRAPH_SUCCESS(impl->geAttrValue_.GetValue(str_values)); - attr_values.clear(); - for (const auto &value : str_values) { - attr_values.emplace_back(value.c_str()); - } - return GRAPH_SUCCESS; -} - -// 特殊处理 Tensor 类型 -graphStatus AttrValue::SetAttrValue(const Tensor &attr_value) const { - GE_ASSERT_NOTNULL(impl); - return impl->geAttrValue_.SetValue(TensorAdapter::AsGeTensor(attr_value)); -} - -graphStatus AttrValue::GetAttrValue(Tensor &attr_value) const { - GE_ASSERT_NOTNULL(impl); - GeTensor ge_tensor; - GE_ASSERT_GRAPH_SUCCESS(impl->geAttrValue_.GetValue(ge_tensor)); - attr_value = TensorAdapter::AsTensor(ge_tensor); - return GRAPH_SUCCESS; -} - -// 特殊处理 std::vector 类型 -graphStatus AttrValue::SetAttrValue(const std::vector &attr_value) const { - GE_ASSERT_NOTNULL(impl); - std::vector ge_tensors; - for (const auto &tensor : attr_value) { - ge_tensors.emplace_back(TensorAdapter::AsGeTensor(tensor)); - } - return impl->geAttrValue_.SetValue(ge_tensors); -} - -graphStatus AttrValue::GetAttrValue(std::vector &attr_value) const { - GE_ASSERT_NOTNULL(impl); - std::vector ge_tensors; - GE_ASSERT_GRAPH_SUCCESS(impl->geAttrValue_.GetValue(ge_tensors)); - attr_value.clear(); - for (const auto &ge_tensor : ge_tensors) { - attr_value.emplace_back(TensorAdapter::AsTensor(ge_tensor)); - } - return GRAPH_SUCCESS; -} -} // namespace ge diff --git a/graph/attr/ge_attr_define.cc b/graph/attr/ge_attr_define.cc deleted file mode 100644 index 6cc001cd8a9d923b72278692a298eca51d5d6f04..0000000000000000000000000000000000000000 --- a/graph/attr/ge_attr_define.cc +++ /dev/null @@ -1,1572 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/debug/ge_attr_define.h" - -namespace ge { -// Public attribute -const std::string ATTR_NAME_OP_FILE_PATH = "_op_file_path"; - -const std::string ATTR_NAME_FORCE_UNKNOWN_SHAPE = "_force_unknown_shape"; - -const std::string ATTR_NAME_IS_UNKNOWN_SHAPE = "_is_unknown_shape"; - -const std::string ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED = "_dynamic_shape_partitioned"; - -const std::string ATTR_NAME_GRAPH_UNKNOWN_FLAG = "_graph_unknown_flag"; - -const std::string ATTR_NAME_UNKNOWN_SHAPE_TYPE = "_unknown_shape_type"; - -const std::string ATTR_NAME_NAME = "name"; - -const std::string ATTR_NAME_TYPE = "type"; - -const std::string ATTR_NAME_WEIGHT_NAME = "weight_name"; - -const std::string ATTR_NAME_IS_QUANTIZE_FACTOR = "quantize_factor"; - -const std::string ATTR_NAME_ALPHA = "alpha"; - -const std::string ATTR_NAME_BETA = "beta"; - -const std::string ATTR_NAME_PADMODE = "pad_mode"; - -const std::string ATTR_NAME_PADMODES = "padding"; - -const std::string ATTR_NAME_MODE = "mode"; - -const std::string ATTR_NAME_FILTER = "filter"; - -const std::string ATTR_NAME_BIAS = "bias"; - -const std::string ATTR_NAME_BIAS_TERM = "bias_term"; - -const std::string ATTR_NAME_HAS_BIAS_VALUE = "has_bias_value"; - -const std::string ATTR_NAME_PAD = "pad"; - -const std::string ATTR_NAME_PADS = "pad"; - -const std::string ATTR_NAME_PAD_SIZE = "pad size"; - -const std::string ATTR_NAME_PAD_MODE = "pad mode"; - -const std::string ATTR_NAME_SCALE = "scale"; - -const std::string ATTR_NAME_WINDOWS = "windows"; - -const std::string ATTR_NAME_GLOBAL_POOLING = "global_pooling"; - -const std::string ATTR_NAME_CEIL_MODE = "ceil_mode"; - -const std::string ATTR_NAME_RELUMODE = "relu_mode"; - -const std::string ATTR_NAME_STRIDE_SIZE = "stride size"; - -const std::string ATTR_NAME_RELU_FLAG = "relu_flag"; - -const std::string ATTR_NAME_ALGO = "algo"; - -const std::string ATTR_NAME_FORMAT = "format"; - -const std::string ATTR_NAME_STORAGE_FORMAT = "storage_format"; - -const std::string ATTR_NAME_ORIGIN_FORMAT_IS_SET = "origin_format_is_set"; - -const std::string ATTR_NAME_STORAGE_SHAPE = "storage_shape"; - -const std::string ATTR_NAME_FILTER_FORMAT = "filter_format"; - -const std::string ATTR_NAME_LRN_K = "lrn_k"; - -const std::string ATTR_NAME_LRN_NORM_REGION = "lrn_normregion"; - -const std::string ATTR_NAME_LRN_LOCAL_SIZE = "lrn_localsize"; - -const std::string ATTR_NAME_LRN_ALPHA = "lrn_alpha"; - -const std::string ATTR_NAME_LRN_BETA = "lrn_beta"; - -const std::string ATTR_NAME_AXIS = "axis"; -const std::string ATTR_NAME_BROADCAST = "broadcast"; - -const std::string ATTR_NAME_OUTPUT = "output"; -const std::string ATTR_NAME_OUTPUT_NUM = "output_num"; -const std::string ATTR_NAME_TIDX = "t_idx"; - -const std::string ATTR_NAME_TPADDINGS = "t_paddings"; -const std::string ATTR_IMG_H = "img_h"; -const std::string ATTR_IMG_W = "img_w"; -const std::string ATTR_NET_H = "net_h"; -const std::string ATTR_NET_W = "net_w"; - -const std::string ATTR_NAME_TMULTIPLES = "t_multiples"; - -const std::string ATTR_NAME_MULTIPLES = "multiples"; - -const std::string ATTR_NAME_T = "T"; -const std::string ATTR_NAME_N = "N"; - -const std::string ATTR_NAME_TSHAPE = "Tshape"; -const std::string ATTR_NAME_NAN_OPT = "nan_opt"; - -const std::string ATTR_NAME_AIPP = "aipp"; -const std::string NEW_AIPP_CONV_OP = "new_conv_op_for_aipp"; - -const std::string ATTR_NAME_AIPP_INPUTS = "_aipp_inputs"; -const std::string ATTR_NAME_AIPP_OUTPUTS = "_aipp_outputs"; - -const std::string ATTR_NAME_INPUT_DIMS = "input_dims"; -const std::string ATTR_DYNAMIC_AIPP_INPUT_DIMS = "_dynamic_aipp_input_dims"; -const std::string ATTR_DATA_RELATED_AIPP_MODE = "_data_related_aipp_mode"; -const std::string ATTR_DATA_AIPP_DATA_NAME_MAP = "_data_aipp_data_name_map"; - -const std::string ATTR_NAME_GRAPH_HAS_BEEN_ADDED = "_graph_has_been_added"; - -const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; -const std::string ATTR_NAME_PARENT_GRAPH_NAME = "_parent_graph_name"; - -const std::string ATTR_NAME_MULTISHAPE_BATCHLIST = "multi_shape_batchlist"; -const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE = "multi_shape_batchlist_size"; -const std::string ATTR_MODEL_BATCH_NUM = "batch_num"; - -const std::string ATTR_NAME_INPUT_FORMAT = "input_format"; -const std::string ATTR_NAME_OUTPUT_FORMAT = "output_format"; - -const std::string ATTR_NAME_FRAMEWORK_NODE_DEF = "node_def"; -const std::string ATTR_NAME_FRAMEWORK_OP_DEF = "op_def"; -const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE = "framework_type"; -const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF = "func_def"; -const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type"; - -const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc"; -const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc"; - -const std::string ATTR_NAME_INFERRED_FORMAT = "inferred_format"; -const std::string ATTR_NAME_PRED_PERMUTE_DELETED = "pred_permute_deleted"; -const std::string ATTR_NAME_IGNORE_PRED_FORMAT = "ignore_pred_format"; -const std::string ATTR_NAME_WEIGHTS = "value"; -const std::string ATTR_NAME_WEIGHT_SHA256 = "_value_sha256"; -const std::string ATTR_NAME_IS_REUSE_EXTERNAL_WEIGHT = "_is_reuse_external_weight"; -const std::string ATTR_NAME_WEIGHTS_DATA = "weights_data"; -const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT = "broacast_real_dim_cnt"; -const std::string ATTR_NAME_DIM_ALIGN = "dim_align"; -const std::string ATTR_NAME_STREAM_LABEL = "_stream_label"; -const std::string ATTR_NAME_RTS_LABEL_NODE = "_rts_label_node"; -const std::string ATTR_NAME_CONTINUOUS_STREAM_LABEL = "_continuous_stream_label"; -const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG = "need_stream_cycle_event"; -const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID = "rtswitch_event_id"; -const std::string ATTR_NAME_AUTOMIC_ADD_START = "automic_add_addr_start"; -const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE = "automic_add_mem_size"; -const std::string ATTR_NAME_ATOMIC_MEMSET_SIZES = "sizes"; -const std::string ATTR_NAME_ATOMIC_MEMSET_DTYPES = "dtypes"; -const std::string ATTR_NAME_ATOMIC_MEMSET_VALUES_INT = "values_int"; -const std::string ATTR_NAME_ATOMIC_MEMSET_VALUES_FLOAT = "values_float"; -const std::string ATTR_NAME_DYNAMIC_OUTPUT_DIMS = "_dynamic_output_dims"; -const std::string ATTR_NAME_INPUT_ORIGIN_SIZE = "input_origin_size"; -const std::string ATTR_NAME_SEND_EVENT_IDS = "_send_event_ids"; -const std::string ATTR_NAME_RECV_EVENT_IDS = "_recv_event_ids"; -const std::string ATTR_NAME_INIT_VALUE = "_init_value"; - -const std::string ATTR_NAME_ROOT_GRAPH_ID = "_root_graph_id"; -const std::string ATTR_NAME_ROOT_GRAPH_NAME = "_root_graph_name"; -const std::string ATTR_NAME_IS_ROOT_GRAPH = "_is_root_graph"; -// Identify node connecting to input and output -const std::string ATTR_NAME_NODE_CONNECT_INPUT = "_is_connected_to_data"; -const std::string ATTR_NAME_NODE_CONNECT_OUTPUT = "_is_connected_to_netoutput"; - -// Need Map rank id when hccl task init for NPU -const std::string ATTR_NAME_NEED_MAP_RANK_ID = "_need_map_rank_id"; - -// To be deleted -const std::string ATTR_TO_BE_DELETED = "to_be_deleted"; -const std::string PERMUTE_RESHAPE_FUSION = "permute_reshape_fusion"; -const std::string PERMUTE_RESHAPE_FUSION_CONV_PROPOSAL = "fusion_conv_proposal"; -const std::string PERMUTE_RESHAPE_FUSION_CONV_DECODEBBOX = "fusion_conv_decodebbox"; -const std::string PERMUTE_RESHAPE_FUSION_BOX_TYPE_NUM = "box_type_num"; -const std::string SSD_MBOX_LOC_FUSION = "permute_flatten_fusion"; -const std::string SSD_MBOX_CONF_FUSION = "permute_flatten_reshape_flatten_fusion"; -const std::string SSD_MBOX_OCR_FUSION = "permute_flatten_ocr_fusion"; -const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; -const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; - -// Refinedet -const std::string REFINEDET_MBOX_LOC_FUSION = "permute_flatten_fusion"; - -const std::string REFINEDET_MBOX_CONF_FUSION = "permute_flatten_reshape_flatten_fusion"; -const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; -const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; -const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag"; - - -// _Arg -const std::string ATTR_NAME_INDEX = "index"; -// _RetVal -const std::string RETVAL_ATTR_NAME_INDEX = "retval_index"; -// Data -const std::string DATA_ATTR_NAME_DATA_TYPE = "data_type"; - -// Send -const std::string SEND_ATTR_EVENT_ID = "event_id"; - -// SendNotify -const std::string SEND_ATTR_NOTIFY_ID = "notify_id"; - -// Recv -const std::string RECV_ATTR_EVENT_ID = "event_id"; - -// RecvNotify -const std::string RECV_ATTR_NOTIFY_ID = "notify_id"; - -// convolution -const std::string ATTR_NAME_COEF = "coef"; - -const std::string ATTR_NAME_STRIDE = "stride"; - -const std::string ATTR_NAME_STRIDES = "stride"; - -const std::string ATTR_NAME_DILATION = "dilation"; - -const std::string ATTR_NAME_DILATIONS = "dilation"; - -const std::string CONV_ATTR_NAME_MODE = "mode"; - -const std::string CONV_ATTR_NAME_ALGO = "algo"; - -const std::string CONV_ATTR_NAME_GROUP = "group"; - -const std::string CONV_ATTR_NAME_PAD_MODE = "pad_mode"; - -const std::string CONV_ATTR_NAME_PAD = "pad"; - -const std::string CONV_ATTR_NAME_STRIDE = "stride"; - -const std::string CONV_ATTR_NAME_DILATION = "dilation"; - -const std::string CONV_ATTR_NAME_NUM_OUTPUT = "num_output"; - -const std::string CONV_ATTR_NAME_KERNEL = "kernel"; - -const std::string CONV_ATTR_NAME_FILTER = "filter"; - -const std::string CONV_ATTR_NAME_BIAS = "bias"; - -const std::string CONV_ATTR_NAME_RELU_FLAG = "relu_flag"; - -const std::string CONV_ATTR_NAME_ADJ = "adj"; - -const std::string CONV_ATTR_NAME_TARGET_SHAPE = "target_shape"; - -const std::string CONV_ATTR_NAME_BEFORE_PAD = "before_pad"; - -const std::string CONV_ATTR_NAME_HAS_BIAS = "has_bias"; - -const std::string NEED_INFER = "isNeedInfer"; - -// Pooling -const std::string POOLING_ATTR_MODE = "mode"; -const std::string POOLING_ATTR_NAN_OPT = "nan_opt"; -const std::string POOLING_ATTR_PAD_MODE = "pad_mode"; -const std::string POOLING_ATTR_GLOBAL_POOLING = "global_pooling"; -const std::string POOLING_ATTR_WINDOW = "window"; -const std::string POOLING_ATTR_PAD = "pad"; -const std::string POOLING_ATTR_STRIDE = "stride"; -const std::string POOLING_ATTR_CEIL_MODE = "ceil_mode"; -const std::string POOLING_ATTR_DATA_MODE = "data_mode"; -const std::string POOLING_ATTR_BEFORE_PAD = "before_pad"; -const std::string POOLING_ATTR_NAME_ALGO = "algo"; - -// Eltwise -const std::string ELTWISE_ATTR_MODE = "mode"; -const std::string ELTWISE_ATTR_COEFF = "coeff"; -const std::string ELTWISE_ATTR_WEIGHT = "weight"; -const std::string ELTWISE_ATTR_RELU_FLAG = "relu_flag"; -const std::string ELTWISE_ATTR_ALPHA = "alpha"; -const std::string ELTWISE_ATTR_BETA = "beta"; - -// BatchNorm -const std::string BATCHNORM_ATTR_MODE = "mode"; -const std::string BATCHNORM_ATTR_EPSILON = "epsilon"; -const std::string BATCHNORM_ATTR_USE_GLOBAL_STATS = "use_global_stats"; -const std::string BATCHNORM_ATTR_MOVING_AVERAGE_FRACTION = "moving_average_fraction"; -const std::string BATCHNORM_ATTR_ESTIMATED_MEAN = "estimated_mean"; -const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE = "estimated_variance"; -const std::string BATCHNORM_ATTR_SCALE = "scale"; -const std::string BATCHNORM_ATTR_BIAS = "bias"; -const std::string BATCHNORM_ATTR_DATA_FORMAT = "data_format"; -const std::string BATCHNORM_ATTR_IS_TRAINING = "is_training"; -const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION = "is_training_fusion"; - -// huberloss -const std::string HUBER_LOSS_ATTR_DELTA = "delta"; - -// SSDRealDivTileMul -const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA = "tilepara"; - -// SSDSumMulRealDivMean -const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES = "reduction_indices"; -const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS = "axis"; -const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA = "mean_para"; -const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM = "has_sum"; - -// ConcatFive2Four -// ConcatFour2Five -const std::string SSD_BOX_TYPE_NUM = "box_type_num"; -const std::string SSD_CLASS_NUM = "class_num"; -const std::string TRANS_FOR_LOSS_MODE = "trans_for_loss_mode"; -const std::string SSD_FEATURE_MAP_SIZE = "feature_map_size"; -const std::string SSD_FEATURE_MAP_HIGH = "feature_map_high"; -const std::string SSD_FEATURE_MAP_WIDTH = "feature_map_width"; - -// Scale -const std::string SCALE_ATTR_SCALE = "scale"; -const std::string SCALE_ATTR_BIAS = "bias"; - -// FullConnection -const std::string FULL_CONNECTION_ATTR_FILTER = "filter"; -const std::string FULL_CONNECTION_ATTR_BIAS = "bias"; -const std::string FULL_CONNECTION_ATTR_NUM_OUTPUT = "num_output"; -const std::string FULL_CONNECTION_ATTR_RELU_FLAG = "relu_flag"; -const std::string FULL_ATTR_NAME_ALGO = "algo"; - -// SoftmaxOpParams -const std::string SOFTMAX_ATTR_ALGO = "algo"; -const std::string SOFTMAX_ATTR_MODE = "mode"; - -// SparseSoftmaxCrossEntropy -const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_ATTR_MODE = "cross_entropy_mode"; -const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_IS_GRAD = "cross_entropy_is_grad"; -// Attr labelSmoothing -const std::string SOFTMAX_CROSS_ENTROPY_LABELSMOOTHING = "labelSmoothing"; - -// ApplyMomentum -const std::string APPLYMENTUM_ATTR_IS_GRAPH_FUSION = "applymomentum_is_graph_fusion"; - -// Activation -const std::string ACTIVATION_ATTR_MODE = "mode"; -const std::string ACTIVATION_ATTR_COEF = "coef"; - -// Concat -const std::string CONCAT_ATTR_NAME_AXIS = "axis"; - -// Const -const std::string CONST_ATTR_NAME_DATA_TRANSTYPE = "data_transtype"; -const std::string CONST_ATTR_NAME_OUTPUT_FORMAT = "output_format"; -const std::string CONST_ATTR_NAME_OUTPUT_TYPE = "output_type"; -const std::string CONST_ATTR_NAME_INPUT = "is_const"; - -// Roipooling -const std::string ROIPOOLING_ATTR_NAME_POOLED_H = "pooled_h"; -const std::string ROIPOOLING_ATTR_NAME_POOLED_W = "pooled_w"; -const std::string ROIPOOLING_ATTR_NAME_SPATIAL_SCALE = "spatial_scale"; -const std::string ROIPOOLING_ATTR_NAME_RIO_POOLING_MODE = "rio_pooling_mode"; -const std::string ROIPOOLING_ATTR_NAME_POOLING_MODE = "pooling_mode"; -const std::string ROIPOOLING_ATTR_NAME_SAMPLING_RATIO = "sampling_ratio"; - -// DetectionOutput -const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES = "num_classes"; -const std::string DETECTIONOUTPUT_ATTR_OCR_NUM_CLASSES = "ocr_num_classes"; -const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD = "nms_threshold"; -const std::string DETECTIONOUTPUT_ATTR_TOP_K = "top_k"; -const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD = "confidence_threshold"; -const std::string DETECTIONOUTPUT_ATTR_IMG_H = "img_h"; -const std::string DETECTIONOUTPUT_ATTR_IMG_W = "img_w"; -const std::string DETECTIONOUTPUT_ATTR_BATCH_SIZE = "batch_size"; -// Ssd DetectionOutput -const std::string DETECTIONOUTPUT_ATTR_ETA = "eta"; -const std::string DETECTIONOUTPUT_ATTR_SHARED_LOCATION = "shared_location"; -const std::string DETECTIONOUTPUT_ATTR_BACKGROUND_LABEL_ID = "background_label_id"; -const std::string DETECTIONOUTPUT_ATTR_CODE_TYPE = "code_type"; -const std::string DETECTIONOUTPUT_ATTR_VARIANCE_ENCODED_IN_TARGET = "variance_encoded_in_target"; -const std::string DETECTIONOUTPUT_ATTR_KEEP_TOP_K = "keep_top_k"; -// Refinedet DetectionOutput -const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_SCORE = "objectness_score"; -// yolo DetectionOutput -const std::string DETECTIONOUTPUT_ATTR_ClASSES = "classes"; -const std::string DETECTIONOUTPUT_ATTR_BIASES = "biases"; -const std::string DETECTIONOUTPUT_ATTR_RELATIVE = "relative"; -const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_THRESHOLD = "objectness_threshold"; -const std::string DETECTIONOUTPUT_ATTR_CLASS_THRESHOLD = "class_threshold"; -const std::string DETECTIONOUTPUT_ATTR_POST_TOP_K = "post_top_k"; -const std::string DETECTIONOUTPUT_ATTR_IOU_THRESHOLD_DECAY = "iou_threshold_decay"; -const std::string DETECTIONOUTPUT_ATTR_COOR_SCALE_FACTOR = "coor_scale_factor"; -const std::string DETECTIONOUTPUT_ATTR_YOLO_VERSION = "yolo_version"; - -// DetectionPostprocess -const std::string POSTPROCESS_ATTR_NAME_CLS_NUM = "cls_num"; -const std::string POSTPROCESS_ATTR_NAME_CONF_THRESH = "conf_thresh"; -const std::string POSTPROCESS_ATTR_NAME_NMS_THRESH = "nms_thresh"; -const std::string POSTPROCESS_ATTR_POST_NMS_TOPN = "post_nms_topn"; -const std::string POSTPROCESS_ATTR_NAME_BBOX_REG_WEIGHT = "bbox_reg_weights"; - -// Spatialtransfrom -const std::string SPTIALTF_ATTR_NAME_OUTPUT_H = "output_h"; -const std::string SPTIALTF_ATTR_NAME_OUTPUT_W = "output_w"; -const std::string SPTIALTF_ATTR_NAME_BORDER_VALUE = "border_value"; -const std::string SPTIALTF_ATTR_NAME_AFFINE_TRANSFORM = "affine_transform"; - -// Proposa -const std::string PROPOSAL_ATTR_NAME_FEAT_STRIDE = "feat_stride"; -const std::string PROPOSAL_ATTR_NAME_BASE_SIZE = "base_size"; -const std::string PROPOSAL_ATTR_NAME_MIN_SIZE = "min_size"; -const std::string PROPOSAL_ATTR_NAME_RATIO = "ratio"; -const std::string PROPOSAL_ATTR_NAME_SCALE = "scale"; -const std::string PROPOSAL_ATTR_NAME_PRE_NMS_TOPN = "pre_nms_topn"; -const std::string PROPOSAL_ATTR_NAME_POST_NMS_TOPN = "post_nms_topn"; -const std::string PROPOSAL_ATTR_NAME_NMS_THRESH = "nms_thresh"; -const std::string PROPOSAL_ATTR_NAME_TOP_SIZE = "top_size"; -const std::string PROPOSAL_ATTR_IMG_H = "img_h"; -const std::string PROPOSAL_ATTR_IMG_W = "img_w"; -// Softmax -const std::string SOFTMAX_ATTR_AXIS = "axis"; - -// Permute -const std::string PERMUTE_ATTR_ORDER = "order"; -const std::string PERMUTE_ATTR_PERM = "perm"; - -// SSD Normalize -const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL = "across_spatial"; -const std::string SSDNORMALIZE_ATTR_CHANNEL_SHARED = "channel_shared"; -const std::string SSDNORMALIZE_ATTR_EPS = "eps"; - -// Flatten -const std::string FLATTEN_ATTR_AXIS = "axis"; -const std::string FLATTEN_ATTR_END_AXIS = "end_axis"; - -// SsdPRIORBOX -const std::string SSD_PRIOR_BOX_ATTR_FLIP = "flip"; -const std::string SSD_PRIOR_BOX_ATTR_CLIP = "clip"; -const std::string SSD_PRIOR_BOX_ATTR_IMG_H = "img_h"; -const std::string SSD_PRIOR_BOX_ATTR_IMG_W = "img_w"; -const std::string SSD_PRIOR_BOX_ATTR_STEP_H = "step_h"; -const std::string SSD_PRIOR_BOX_ATTR_STEP_W = "step_w"; -const std::string SSD_PRIOR_BOX_ATTR_OFFSET = "offset"; -const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE = "min_size"; -const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE = "max_size"; -const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE_NUM = "min_size_num"; -const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE_NUM = "max_size_num"; -const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO = "aspect_ratio"; -const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO_NUM = "aspect_ratio_num"; -const std::string SSD_PRIOR_BOX_ATTR_VARIANCE = "variance"; -const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; - -// RefinedetDetectionOutput -const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; -const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance"; - -// PRelu -const std::string PRELU_ATTR_CHANNEL_SHARED = "channel_shared"; - -// Psroi pooling -const std::string PSROIPOOLING_ATTR_SPATIAL_SCALE = "spatial_scale"; -const std::string PSROIPOOLING_ATTR_OUTPUT_DIM = "output_dim"; -const std::string PSROIPOOLING_ATTR_GROUP_SIZE = "group_size"; - -// Power -const std::string POWER_ATTR_NAME_POWER = "power"; -const std::string POWER_ATTR_NAME_SCALE = "scale"; -const std::string POWER_ATTR_NAME_SHIFT = "shift"; - -// log -const std::string LOG_ATTR_NAME_SCALE = "scale"; -const std::string LOG_ATTR_NAME_SHIFT = "shift"; -const std::string LOG_ATTR_NAME_BASE = "base"; -// Pack -const std::string PACK_ATTR_NAME_NUM = "N"; - -// Unpack -const std::string UNPACK_ATTR_NAME_NUM = "num"; -const std::string DYNAMIC_STITCH_ATTR_NAME_NUM = "DynamicStitchN_"; -// Gathernd -const std::string GATHERND_ATTR_NAME_TINDICES = "Tindices"; -const std::string GATHERND_ATTR_NAME_TPARAMS = "Tparams"; - -// Argmax -const std::string ARGMAX_ATTR_NAME_TOPK = "topk"; -const std::string ARGMAX_ATTR_NAME_REDUCESIZE = "reduce_size"; -const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE = "reduce_stride"; -const std::string ARGMAX_ATTR_NAME_OUTMAX = "outmaxval"; -const std::string ARGMAX_ATTR_NAME_AXIS = "axis"; -const std::string ARGMAX_ATTR_NAME_AXISTYPE = "axis_type"; -const std::string ARGMAX_ATTR_NAME_KEEPDIMS = "keep_dims"; - -// upsample -const std::string UPSAMPLE_ATTR_NAME_SCALE_H = "scale_h"; -const std::string UPSAMPLE_ATTR_NAME_SCALE_W = "scale_w"; - -// Relu -const std::string ATTR_NAME_NEGATIVE_SLOPE = "negative_slope"; - -// FreeSpaceExtract -const std::string FREESPACEEXTRACT_ATTR_NAME_ORG_HEIGHT = "org_height"; - -// Split -const std::string SPLIT_ATTR_NAME_SLICE_POINT = "slice_point"; -const std::string SPLIT_ATTR_NAME_SIZE_SPLIT = "size_split"; -const std::string SPLIT_ATTR_NAME_NUM_SPLIT = "num_split"; - -// Tvm -const std::string TVM_ATTR_NAME_MAGIC = "tvm_magic"; -const std::string TVM_ATTR_NAME_BLOCKDIM = "tvm_blockdim"; -const std::string TVM_ATTR_NAME_METADATA = "tvm_metadata"; -const std::string TVM_ATTR_NAME_WORKSPACE_TYPE = "tvm_workspace_type"; - -// Ffts Tvm -const std::string TVM_ATTR_NAME_THREAD_MAGIC = "_thread_tvm_magic"; -const std::string TVM_ATTR_NAME_THREAD_BLOCKDIM = "_thread_tvm_blockdim"; -const std::string TVM_ATTR_NAME_THREAD_METADATA = "_thread_tvm_metadata"; -const std::string TVM_ATTR_NAME_THREAD_WORKSPACE_TYPE = "_thread_tvm_workspace_type"; -const std::string TVM_ATTR_NAME_THREAD_N_BATCH_SPLIT = "_thread_is_n_batch_split"; - -const std::string ATTR_NAME_THREAD_TBE_KERNEL_BUFFER = "_thread_tbe_kernel_buffer"; -const std::string ATTR_NAME_THREAD_TBE_KERNEL_NAME = "_thread_tbe_kernel_name"; - -// Squeeze -const std::string SQUEEZE_ATTR_AXIS = "axis"; -const std::string SQUEEZE_ATTR_DIMS = "squeeze_dims"; -const std::string SQUEEZE_OP_NAME = "Squeeze"; - -// Stride slice -const std::string STRIDE_SLICE_ATTR_BEGIN_MASK = "begin_mask"; -const std::string STRIDE_SLICE_ATTR_END_MASK = "end_mask"; -const std::string STRIDE_SLICE_ATTR_ELLIPSIS_MASK = "ellipsis_mask"; -const std::string STRIDE_SLICE_ATTR_NEW_AXIS_MASK = "new_axis_mask"; -const std::string STRIDE_SLICE_ATTR_SHRINK_AXIS_MASK = "shrink_axis_mask"; - -// Slice -const std::string SLICE_ATTR_NAME_BEGINS = "begins"; -const std::string SLICE_ATTR_NAME_SIZES = "sizes"; - -// Roialign -const std::string ROIALIGN_ATTR_SPATIAL_SCALE = "spatial_scale"; -const std::string ROIALIGN_ATTR_SAMPLING_RATIO = "sampling_ratio"; -const std::string ROIALIGN_ATTR_NAME_POOLED_H = "pooled_h"; -const std::string ROIALIGN_ATTR_NAME_POOLED_W = "pooled_w"; - -// Generate_rpn_proposal -const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK = "pre_nms_topk"; -const std::string GENERATE_RPN_PROPOSAL_ATTR_POST_NMS_TOPK = "post_nms_topk"; -const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_MINI_SIZE = "rpn_mini_size"; -const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_NMS_THRESH = "rpn_proposal_nms_thresh"; -const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_FILTER_THRESH = "rpn_proposal_filter_thresh"; -// Decode_bbox -const std::string DECODE_BBOX_ATTR_DECODECLIP = "decodeClip"; - -// Cast -const std::string CAST_ATTR_DSTT = "DstT"; -const std::string CAST_ATTR_SRCT = "SrcT"; -const std::string CAST_ATTR_DST_TYPE = "dst_type"; -const std::string CAST_ATTR_TRUNCATE = "truncate"; - -// Fastrcnnn predications -const std::string FASTRCNN_PREDICTIONS_ATTR_TOPK = "fsr_topk"; -const std::string FASTRCNN_PREDICTIONS_ATTR_SCORE_THRESHOLD = "fsr_score_thres"; -const std::string FASTRCNN_PREDICTIONS_ATTR_NMS_THRESHOLD = "fsr_nms_thres"; -const std::string FASTRCNN_PREDICTIONS_ATTR_NUM_CLASSES = "fsr_num_classes"; - -// REORG -const std::string REORG_ATTR_STRIDE = "stride"; -const std::string REORG_ATTR_REVERSE = "reverse"; - -// MERGE -const std::string MERGE_DEAD_INDEX = "merge_dead_index"; -const std::string MERGE_PRENODE_FLAG = "merge_prenode_flag"; -const std::string TO_BE_OUTPUT = "to_be_output"; - -// ENTER -const std::string ENTER_ATTR_FRAME_NAME = "frame_name"; -const std::string ENTER_ATTR_CONSTANT_FLAG = "is_constant"; - -// Concatv2 -const std::string CONCAT_V2_ATTR_TIDX = "Tidx"; -const std::string CONCAT_V2_ATTR_N = "N"; -// SUM -const std::string SUM_ATTR_TIDX = "Tidx"; -const std::string SUM_ATTR_AXIS = "axis"; -const std::string SUM_ATTR_KEEP_DIMS = "keep_dims"; - -// ResizeBilinear -const std::string RESIZE_BILINEAR_ATTR_MODE = "mode"; -const std::string RESIZE_BILINEAR_ATTR_ALIGN_CORNERS = "align_corners"; -const std::string RESIZE_BILINEAR_ATTR_HEIGHT = "height"; -const std::string RESIZE_BILINEAR_ATTR_WIDTH = "width"; -const std::string RESIZE_BILINEAR_ATTR_ZOOM_FACTOR = "zoom_factor"; -const std::string RESIZE_BILINEAR_ATTR_SHRINK_FACTOR = "shrink_factor"; -const std::string RESIZE_BILINEAR_ATTR_PAD_BEGIN = "pad_begin"; -const std::string RESIZE_BILINEAR_ATTR_PAD_END = "pad_end"; -const std::string RESIZE_BILINEAR_ATTR_ALPHA = "alpha"; -const std::string RESIZE_BILINEAR_ATTR_BETA = "beta"; - -// RetinaNet -const std::string RETINANET_FILTER_BACKGROUND_TRUE = "retina_conv_filter_background"; -const std::string RETINANET_ANCHOR_FUSION = "retina_anchor_fusion"; - -// MatMul -const std::string MATMUL_TRANSPOSE_X = "transposeX"; -const std::string MATMUL_TRANSPOSE_W = "transposeW"; -const std::string MATMUL_HAS_BIAS = "has_bias"; -const std::string MATMUL_ATTR_IS_TRAINING = "matmul_is_training"; - -// Flatten -const std::string FLATTEN_START_AXIS = "start_axis"; -const std::string FLATTEN_END_AXIS = "end_axis"; - -// Reshape -const std::string RESHAPE_ATTR_AXIS = "axis"; -const std::string RESHAPE_ATTR_NUM_AXES = "num_axes"; -const std::string RESHAPE_ATTR_FORMAT = "format"; -const std::string RESHAPE_ATTR_SHAPE = "shape"; -const std::string RESHAPE_ATTR_ALPHA = "alpha"; -const std::string RESHAPE_ATTR_BETA = "beta"; - -// Frameoworkop -const std::string T_IN_DATATYPE = "t_in_datatype"; -const std::string T_OUT_DATATYPE = "t_out_datatype"; -const std::string ATTR_NAME_OUT_N = "out_n"; -const std::string ATTR_NAME_OUT_C = "out_c"; -const std::string ATTR_NAME_OUT_H = "out_h"; -const std::string ATTR_NAME_OUT_W = "out_w"; -const std::string ATTR_PAD_DEPTH_CONV = "pad_depth_conv"; -const std::string ATTR_PAD_CONV = "pad_conv"; - -const std::string ATTR_NAME_BEFORE_PAD = "before_pad"; -const std::string ANN_MEAN_KEEPDIMS = "AnnMeanKeepDims"; -const std::string PAD_ATTR_PADDINGDS = "paddings"; -const std::string PAD_ATTR_CONSTANT_VALUE = "padvalue"; - -// ConvGradFilter -const std::string CONV_GRAD_FILTER_OUTPUT_SHAPE = "conv_grad_filter_output_shape"; -// ConvGradInput -const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE = "conv_grad_input_output_shape"; - -// Rnn -const std::string RNN_MODE_STATIC = "rnn_static"; -const std::string MUTI_RNN = "multi_rnn"; -const std::string CNN_RNN = "cnn_rnn"; -const std::string RNN_MODE_ = "rnn_"; - - -const std::string CELL_MODE = "mode"; -const std::string LSTM_CELL = "lstm_cell"; -const std::string GRU_CELL = "gru_cell"; -const std::string RNN_HT = "ht"; -const std::string RNN_XT_HT = "xt_ht"; -const std::string RNN_BATCH_SIZE = "batch_size"; -const std::string LSTM_CELL_CLIP = "lstm_cell_clip"; -const std::string LSTM_PROJ_CLIP = "lstm_proj_clip"; -const std::string LSTM_ACTIVATE = "lstm_activate"; -const std::string LSTM_OUT_MAP = "lstm_out_map"; -const std::string LSTM_OUT_MODE = "lstm_out_mode"; -const std::string LSTM_STATE_OUT_MODE = "lstm_state_out_mode"; -const std::string LSTM_TIME_MAJOR = "lstm_time_major"; -const std::string LSTM_IS_INPUT_PRE_PROCESS = "lstm_is_input_pre_process"; - -// Upsample -const std::string UPSAMPLE_ATTR_NAME_SCALE = "scale"; - -// PadV2 -const std::string PADV2_ATTR_NAME_MODE = "mode"; -const std::string PADV2_ATTR_NAME_PADS = "paddings"; -const std::string PADV2_ATTR_NAME_T = "T"; -const std::string PADV2_ATTR_NAME_PAD_FORMAT = "pad_format"; -const std::string PADV2_ATTR_NAME_CONST_VALUE = "const_value"; - -// MirrorPad -const std::string MIRRORPAD_ATTR_NAME_MODE = "mode"; -const std::string MIRRORPAD_ATTR_NAME_PADS = "paddings"; -const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT = "pad_format"; -const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE = "const_value"; -// Filler -const std::string FILLER_TYPE = "filler_type"; -const std::string FILLER_VALUE = "filler_value"; - -// Shufflechannel -const std::string SHUFFLE_CHANNEL_GROUP = "group"; - -// TopKV2 -const std::string TOPKV2_ATTR_K = "k"; - -// Calibaration -const std::string STRIDE_H_INDEX = "STRIDE_H_INDEX"; -const std::string STRIDE_W_INDEX = "STRIDE_W_INDEX"; -const std::string PAD_TOP_INDEX = "PAD_TOP_INDEX"; -const std::string PAD_BOTTOM_INDEX = "PAD_BOTTOM_INDEX"; -const std::string PAD_RIGHT_INDEX = "PAD_RIGHT_INDEX"; -const std::string PAD_LEFT_INDEX = "PAD_LEFT_INDEX"; -const std::string QUANTIZE_ALGO_ATTR = "quantize_algo"; -const std::string SCALE_TYPE_ATTR = "scale_type"; - -const std::string QUANTIZE_SCALE_MODE = "quantize_scale_mode"; -const std::string QUANTIZE_SCALE_VALUE = "quantize_scale_value"; -const std::string QUANTIZE_SCALE_OFFSET = "quantize_scale_offset"; -const std::string QUANTIZE_OFFSET_DATA_VALUE = "quantize_offset_data_value"; -const std::string QUANTIZE_OFFSET_DATA_OFFSET = "quantize_offset_data_offset"; -const std::string QUANTIZE_OFFSET_WEIGHT_VALUE = "quantize_offset_weight_value"; -const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET = "quantize_offset_weight_offset"; -const std::string QUANTIZE_OFFSET_PAD_VALUE = "quantize_offset_pad_value"; -const std::string QUANTIZE_OFFSET_PAD_OFFSET = "quantize_offset_pad_offset"; - -const std::string DEQUANTIZE_SCALE_MODE = "dequantize_scale_mode"; -const std::string DEQUANTIZE_SCALE_VALUE = "dequantize_scale_value"; -const std::string DEQUANTIZE_SCALE_OFFSET = "dequantize_scale_offset"; -const std::string DEQUANTIZE_OFFSET_DATA_TYPE = "dequantize_offset_data_value"; -const std::string DEQUANTIZE_OFFSET_DATA_OFFSET = "dequantize_offset_data_offset"; -const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE = "dequantize_offset_weight_value"; -const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET = "dequantize_offset_weight_offset"; -const std::string DEQUANTIZE_OFFSET_PAD_VALUE = "dequantize_offset_pad_value"; -const std::string DEQUANTIZE_OFFSET_PAD_OFFSET = "dequantize_offset_pad_offset"; - -const std::string REQUANTIZE_SCALE_MODE = "requantize_scale_mode"; -const std::string REQUANTIZE_SCALE_VALUE = "requantize_scale_value"; -const std::string REQUANTIZE_SCALE_OFFSET = "requantize_scale_offset"; -const std::string REQUANTIZE_OFFSET_DATA_VALUE = "requantize_offset_data_value"; -const std::string REQUANTIZE_OFFSET_DATA_OFFSET = "requantize_offset_data_offset"; -const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE = "requantize_offset_weight_value"; -const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET = "requantize_offset_weight_offset"; -const std::string REQUANTIZE_OFFSET_PAD_VALUE = "requantize_offset_pad_value"; -const std::string REQUANTIZE_OFFSET_PAD_OFFSET = "requantize_offset_pad_offset"; - -const std::string ATTR_NAME_IS_CONST = "attr_name_is_const"; - -const std::string ATTR_NAME_GROUP = "group"; -const std::string ATTR_NAME_DILATION_SIZE = "dilation_size"; -const std::string ATTR_NAME_EPSILON = "epsilon"; -const std::string ATTR_NAME_POOLING_MODE = "mode"; -const std::string ATTR_NAME_CLASS_NUM = "class_num"; -// model -const std::string ATTR_MODEL_TARGET_TYPE = "target_type"; - -const std::string ATTR_MODEL_STREAM_NUM = "stream_num"; - -const std::string ATTR_MODEL_EVENT_NUM = "event_num"; - -const std::string ATTR_MODEL_NOTIFY_NUM = "notify_num"; - -const std::string ATTR_MODEL_NOTIFY_TYPES = "notify_types"; - -const std::string ATTR_MODEL_HUGE_STREAM_LIST = "huge_stream_list"; - -const std::string ATTR_MODEL_LABEL_NUM = "label_num"; - -const std::string ATTR_MODEL_MEMORY_SIZE = "memory_size"; - -const std::string ATTR_MODEL_ZERO_COPY_MEMORY_SIZE = "zero_copy_memory_size"; - -const std::string ATTR_MODEL_P2P_MEMORY_SIZE = "p2p_memory_size"; - -const std::string ATTR_MODEL_OUT_NODES_NAME = "attr_model_out_nodes_name"; - -const std::string ATTR_MODEL_WEIGHT_SIZE = "weight_size"; - -const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR = "task_gen_base_addr"; - -const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR = "task_gen_weight_addr"; - -const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR = "task_gen_variable_addr"; - -const std::string ATTR_MODEL_VAR_SIZE = "variable_size"; - -const std::string ATTR_MODEL_TASK_INDEX_OP_NAME = "task_index_op_name"; - -const std::string ATTR_MODEL_CORE_TYPE = "core_type"; - -const std::string ATTR_MODEL_ATC_VERSION = "atc_version"; - -const std::string ATTR_MODEL_ATC_CMDLINE = "atc_cmdline"; - -const std::string ATTR_MODEL_OPP_VERSION = "opp_version"; - -const std::string ATTR_MODEL_COMPILER_VERSION = "compiler_version"; - -const std::string ATTR_MODEL_SESSION_SCOPE_MEMORY_SIZE = "session_scope_memory_size"; - -const std::string ATTR_MODEL_SUB_MEMORY_INFO = "sub_memory_info"; - -// Used for om compress -const std::string ATTR_MODEL_OM_COMPRESS_VERSION = "om_compress_version"; - -const std::string ATTR_MODEL_ATTR_NAME_ENUM = "attr_name_enum"; - -const std::string ATTR_MODEL_ATTR_VALUE_ENUM = "attr_value_enum"; - -const std::string ATTR_MODEL_ATTRS_USE_STRING_VALUE = "attrs_use_string_value"; - -// Public attribute -const std::string ATTR_NAME_IMPLY_TYPE = "imply_type"; - -const std::string ATTR_NAME_BYTE_SIZE = "op_byte_size"; - -const std::string ATTR_NAME_FUSION_INFERENCE_ID = "fusion_inference_id"; - -const std::string ATTR_NAME_FUSION_OPDEF = "fusion_opdef"; - -const std::string ATTR_NAME_IO_OP = "io_op"; - -const std::string ATTR_NAME_FUSION_SCOPE = "fusion_scope"; - -const std::string ATTR_NAME_OPATTR = "opattr"; - -const std::string ATTR_NAME_SEQLEN_INDEX = "seqlen_index"; - -const std::string ATTR_NAME_X_INDEX = "x_index"; - -const std::string ATTR_NAME_CONT_INDEX = "cont_index"; - -const std::string ATTR_NAME_XSTATIC_INDEX = "xstatic_index"; - -const std::string TARGET_TYPE_MINI = "MINI"; - -const std::string TARGET_TYPE_TINY = "TINY"; - -const std::string TARGET_TYPE_LITE = "LITE"; - -// l2_normalize -const std::string L2_NORMALIZE_ATTR_AXIS = "axis"; -const std::string L2_NORMALIZE_ATTR_EPS = "eps"; - -const std::string POOL_PARAMA_ATTR_WINDOW = "window"; -const std::string POOL_PARAMA_ATTR_CEIL_MODE = "ceil_mode"; -const std::string POOL_PARAMA_ATTR_DATA_MODE = "data_mode"; -const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING = "global_pooling"; -const std::string POOL_PARAMA_ATTR_NAN_OP = "nan_opt"; -const std::string POOL_PARAMA_ATTR_PAD_MOD = "pad_mode"; - -// HCOM -const std::string HCOM_ATTR_ROOT_RANK = "root_rank"; -const std::string HCOM_ATTR_RANK_SIZE = "rank_size"; - -const std::string HCOM_ATTR_REDUCE_TYPE = "reduction"; -const std::string HCOM_ATTR_GROUP = "group"; -const std::string HCOM_ATTR_SR_TAG = "sr_tag"; -const std::string HCOM_ATTR_SRC_RANK = "src_rank"; -const std::string HCOM_ATTR_DEST_RANK = "dest_rank"; -const std::string HCOM_ATTR_FUSION = "fusion"; -const std::string HCOM_ATTR_SHAPE = "shape"; -const std::string HCOM_ATTR_DATA_TYPE = "dtype"; - -// SpaceToDepth/DepthToSpace -const std::string ATTR_NAME_BLOCK_SIZE = "block_size"; - -// SparseSoftmaxCrossEntropyWithLogits -const std::string SPARSE_SOFT_MAX_ATTR_TLABLES = "Tlabels"; - -// MaxPoolGradWithArgmax -const std::string MAX_POOL_GRAD_OUTPUT_SHAPE = "max_pool_grad_output_shape"; - -// AvgPoolGrad -const std::string AVG_POOL_GRAD_OUTPUT_SHAPE = "avg_pool_grad_output_shape"; - -// Pad -const std::string ATTR_PAD_FORMAT = "attr_pad_format"; - -// Varible -const std::string VAR_ATTR_FORMAT = "_var_format"; -const std::string VAR_ATTR_NAME = "var_name"; -const std::string VAR_ATTR_FRACTALZ_FORMAT = "FZ"; -const std::string VAR_ATTR_4D_FORMAT = "4D"; -const std::string VAR_ATTR_5D_FORMAT = "5D"; -const std::string VAR_ATTR_DATA_TYPE = "data_format"; -const std::string VAR_ATTR_VAR_IN_NAME = "var_in_name"; -const std::string VAR_ATTR_VAR_IN_INDEX = "var_in_index"; -const std::string VAR_ATTR_VAR_OUT_INDEX = "var_out_index"; -const std::string VAR_ATTR_SHAPE = "shape"; -const std::string HALF_VAR_NAME_END = "_fp16"; -const std::string VAR_ATTR_INITED = "var_is_inited"; - -const std::string VAR_ATTR_CONTAINER = "container"; -const std::string VAR_ATTR_SHARED_NAME = "shared_name"; -const std::string VAR_ATTR_DTYPE = "dtype"; - -const std::string VAR_ATTR_SRC_VAR_NAME = "_src_var_name"; -const std::string VAR_ATTR_VAR_IS_SAVE = "_var_is_save"; -const std::string VAR_ATTR_VAR_IS_RESTORE = "_var_is_restore"; -const std::string VAR_ATTR_VAR_IS_BROADCAST = "_var_is_broadcast"; -const std::string REF_VAR_SRC_VAR_NAME = "ref_var_src_var_name"; -const std::string REF_VAR_PRE_PEER_OUT_INDEX = "ref_var_pre_peer_out_index"; - -// Assign -const std::string ASSIGN_VALIDATE_SHAPE = "validate_shape"; -const std::string ASSIGN_VAR_NAME = "_assign_var_name"; - -// Inplace support -const std::string INPLACE_SUPPORT_INPUT_INDEX = "_inplace_support_input_index"; - -// space2bacth batch2space -const std::string BATCH_SPACE_ATTR_BLOCK = "block"; -const std::string BATCH_SPACE_ATTR_PADDING = "padding"; - -// depth_to_space space_to_depth -const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size"; - -// FakeQuantWithMinMaxVars -const std::string FakeQuantWithMinMaxVars_ATTR_MAX = "max"; -const std::string FakeQuantWithMinMaxVars_ATTR_MIN = "min"; - -// mobilenet_ssd_conv_fusion -const std::string SSD_BOXPREDICTOR_BOXES_FUSION = "ssd_boxpredictor_boxes_fusion"; -const std::string SSD_BOXPREDICTOR_SCORES_FUSION = "ssd_boxpredictor_scores_fusion"; -const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM = "ssd_boxpredictor_fusion_box_type_num"; - -// lsh project -const std::string LSH_PROJ_TYPE = "lsh_project_type"; - -// log time stamp -const std::string LOG_TIME_STAMP_LOGID = "logid"; -const std::string LOG_TIME_STAMP_NOTIFY = "notify"; - -// ShapeN -const std::string SHAPEN_ATTR_N = "N"; -const std::string SHAPEN_ATTR_IN_TYPE = "in_type"; -const std::string SHAPEN_ATTR_OUT_TYPE = "dtype"; -const std::string ATTR_NAME_SPLIT_SHAPEN_ORIGIN_NAME = "_split_shapen_origin_name"; - -// GatherV2 attr def -const std::string GATHERV2_ATTR_NAME_TAXIS = "Taxis"; -const std::string GATHERV2_ATTR_NAME_TINDICES = "Tindices"; -const std::string GATHERV2_ATTR_NAME_TPARAMS = "Tparams"; - -// Reshape attr def -const std::string RESHAPE_ATTR_NAME_INPUT_DESC = "input_desc_reshape"; -const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC = "output_desc_reshape"; - -// axis attr def -const std::string ATTR_NAME_AXIS_ORG_OP = "axis_org_op"; - -const std::string ATTR_NAME_LINK_WITH_SPARE = "link_with_sparse"; - -const std::string ATTR_NAME_NET_OUTPUT_FORMAT = "net_output_format"; -const std::string ATTR_NAME_NET_OUTPUT_DATATYPE = "net_output_datatype"; - -// For constant folding -const std::string ATTR_NO_NEED_CONSTANT_FOLDING = "no_need_constant_folding"; - -const std::string ATTR_NAME_IS_INSERTED_BY_CANN = "_is_inserted_by_cann"; - -const std::string ATTR_NAME_CONTINUOUS_INPUT = "continuous_input"; - -const std::string ATTR_NAME_CONTINUOUS_INPUT_ALLOC = "continuous_input_alloc"; - -const std::string ATTR_NAME_CONTINUOUS_OUTPUT = "continuous_output"; - -// For AscendWeightQuant+Enter -const std::string ATTR_NAME_FINAL_CONST_NODE = "_final_const_node"; - -// attr _input_mutable = true means node will modify its input in runtime -const std::string ATTR_NAME_MODIFY_INPUT = "_input_mutable"; - -const std::string ATTR_NAME_REFERENCE = "reference"; - -const std::string ATTR_NAME_NOTASK = "_no_task"; - -const std::string ATTR_NAME_OUTPUT_REUSE_INPUT = "_output_reuse_input"; - -const std::string ATTR_NAME_GRAPH_OUTPUT_MAX_SIZE = "_graph_output_max_size"; - -const std::string ATTR_NAME_REUSE_INPUT_ON_DIM_INDEX = "_reuse_input_on_dim_index"; - -const std::string ATTR_NAME_NOPADDING_CONTINUOUS_INPUT = "_no_padding_continuous_input"; - -const std::string ATTR_NAME_NOPADDING_CONTINUOUS_OUTPUT = "_no_padding_continuous_output"; - -const std::string ATTR_NAME_ATOMIC_INDEX = "atomic_index"; - -// Used for mark the active label list stream of activated node -const std::string ATTR_NAME_ACTIVE_LABEL_LIST = "_active_label_list"; - -// Used for l2cache, true: the memory of all inputs is used for the last time. -const std::string ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE = "is_end_of_inputmem_lifecycle"; - -const std::string ATTR_NAME_DATA_VISIT_DISTANCE = "_data_visit_distance"; - -// Multi batch -const std::string ATTR_NAME_PRED_VALUE = "_pred_value"; -const std::string ATTR_NAME_BATCH_NUM = "_batch_num"; -const std::string ATTR_NAME_BATCH_LABEL = "_batch_label"; -const std::string ATTR_NAME_COMBINED_BATCH = "_combined_batch"; - -// Control flow -const std::string ATTR_NAME_STREAM_SWITCH_COND = "switch_condition"; -const std::string ATTR_NAME_TRUE_BRANCH_STREAM = "true_branch_stream"; -const std::string ATTR_NAME_ACTIVE_STREAM_LIST = "active_stream_list"; -const std::string ATTR_NAME_SWITCHN_PRED_VALUE = "switch_pred_value"; -const std::string ATTR_NAME_ITERATORS_PER_LOOP = "iterations_per_loop"; -const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG = "is_flow_ctrl_node"; -const std::string ATTR_NAME_SUBGRAPH_FIRST_ACTIVE = "subgraph_first_active"; -const std::string ATTR_NAME_COMBINED_DYNAMIC_DIMS = "combined_dynamic_dims"; - -const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL = "_switch_branch_node_label"; -const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG = "_switch_true_branch_flag"; -const std::string ATTR_NAME_SWITCH_DATA_TYPE = "_switch_data_type"; -const std::string ATTR_NAME_ORIG_NODE_NAME = "_original_node_name"; -const std::string ATTR_NAME_CYCLIC_DEPENDENCE_FLAG = "_cyclic_dependence_flag"; -const std::string ATTR_NAME_STREAM_SWITCH_TYPE = "_stream_switch_type"; - -const std::string ATTR_NAME_NEXT_ITERATION = "_next_iteration_node"; - -// Function Op -const std::string ATTR_NAME_PARENT_NODE_INDEX = "_parent_node_index"; - -const std::string ATTR_NAME_NEED_INFER_AGAIN = "_need_infer_again"; - -const std::string ATTR_NAME_MERGE_INPUT_INDEX = "_merge_input_index"; -const std::string ATTR_NAME_CONTROL_FLOW_GROUP = "_control_flow_group"; - -// Used for mark the active node is for loop, type:bool -const std::string ATTR_NAME_IS_LOOP_ACTIVE = "is_loop_active"; - -const std::string ATTR_NAME_MEMORY_TYPE_INPUT = "memory_type_input"; - -const std::string ATTR_NAME_MEMORY_TYPE_OUTPUT = "memory_type_output"; - -const std::string ATTR_NAME_MEMORY_TYPE_WORKSPACE = "memory_type_workspace"; - -const std::string ATTR_NAME_MEMORY_TYPE_RANGE = "_memory_type_range"; - -const std::string MODEL_ATTR_SESSION_ID = "session_id"; - -// lx fusion -const std::string ATTR_NAME_L1_FUSION_GROUP_ID = "_l1_fusion_group_id"; -const std::string ATTR_NAME_FUSION_GROUP_KEY = "_fusion_group_key"; -const std::string ATTR_NAME_L1_FUSION_GROUP_KEY = "_l1_fusion_group_key"; -const std::string ATTR_NAME_FUSION_VIRTUAL_OP = "_fusion_virtual_op"; -const std::string ATTR_NAME_FUSION_GROUP_TYPE = "_fusion_group_type"; -const std::string ATTR_NAME_L1_FUSION_EXTEND_PTR = "_l1_fusion_extend_content"; -const std::string ATTR_NAME_GET_TENSOR_ACTUAL_SIZE = "_tensor_actual_size"; -const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION = "_output_offset_for_l1_fuison"; -const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION = "_enable_l1_fusion"; -const std::string ATTR_N_BATCH_SPILT = "_is_n_batch_split"; -const std::string ATTR_NO_TASK_AND_DUMP_NEEDED = "_no_task_and_dump_needed"; -const std::string ATTR_DATA_DUMP_REF = "_datadump_ref"; -const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION = "_output_offset_for_buffer_fusion"; -const std::string ATTR_NAME_L2_FUSION_GROUP_ID = "_l2_fusion_group_id"; -const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION = "_enable_l2_fusion"; -const std::string ATTR_NAME_OP_INPUT_L1_FLAG = "_op_input_l1_flag"; -const std::string ATTR_NAME_OP_INPUT_L1_ADDR = "_op_input_l1_addr"; -const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE = "_op_input_l1_valid_size"; -const std::string ATTR_NAME_ENGINE_NAME_FOR_LX = "_lxfusion_engine_name"; -const std::string ATTR_NAME_KKERNEL_LIB_NAME_FOR_LX = "_lxfusion_op_kernel_lib_name"; -const std::string ATTR_NAME_NEED_LX_FUSION = "_lx_fusion"; -const std::string ATTR_NAME_OPTIMIZE_GROUP = "_optimize_group"; -const std::string ATTR_NAME_OP_COMPILE_STRATEGY = "_op_compile_strategy"; -const std::string ATTR_NAME_TBE_KERNEL_NAME = "_tbe_kernel_name"; -const std::string ATTR_NAME_TBE_KERNEL_NAME_FOR_LOAD = "_tbe_kernel_name_for_load"; -const std::string ATTR_NAME_TBE_KERNEL_BUFFER = "_tbe_kernel_buffer"; -const std::string ATTR_NAME_DATA_SLICE = "_data_slice"; -const std::string ATTR_NAME_NEED_RECOVER_ATTR = "_need_recover_attr"; -const std::string ATTR_NAME_OFF_SUPERKERNEL_ATTR = "_off_superkernel"; -const std::string ATTR_NAME_SRC_CONST_NAME = "_src_const_name"; -const std::string ATTR_NAME_NEED_KEEP_ORIGIN_INPUT_AND_OUTPUT = "_need_keep_origin_input_and_output"; - -// merge subgraph with output anchor map -const std::string ATTR_NAME_FUSION_ORIGIN_NAME = "_fusion_origin_name"; -const std::string ATTR_NAME_FUSION_ORIGIN_OUTPUT_INDEX = "_fusion_origin_output_index"; - -// read var offset -const std::string ATTR_NAME_INNER_OFFSET = "_inner_offset"; - -// used for memory allocate -const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST = "_input_memory_type"; -const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST = "_output_memory_type"; -const std::string ATTR_NAME_WORKSPACE_TYPE_LIST = "_workspace_type"; -const std::string ATTR_NAME_TENSOR_MEM_TYPE = "_tensor_memory_type"; -const std::string ATTR_NAME_SUB_STREAM_ID = "_sub_stream_id"; - -// Op debug attrs -const std::string ATTR_OP_DEBUG_FLAG = "_op_debug_flag"; -const std::string ATTR_OP_DEBUG_MODE = "_op_debug_mode"; - -// Atomic addr clean attrs -const std::string ATOMIC_ATTR_INPUT_INDEX = "atomic_input_index"; -const std::string ATOMIC_ATTR_OUTPUT_INDEX = "atomic_output_index"; -const std::string ATOMIC_ATTR_IS_FUSION_NODE = "is_fusion_node"; -const std::string EXT_ATTR_ATOMIC_WORKSPACE_INFO = "sub_node_workspace_info"; -const std::string EXT_ATTR_ATOMIC_WORKSPACE_OFFSET = "sub_node_workspace_offset"; -const std::string ATOMIC_ATTR_IS_ATOMIC_NODE = "is_atomic_node"; -const std::string ATOMIC_ATTR_TVM_MAGIC = "_atomic_tvm_magic"; -const std::string ATOMIC_ATTR_TVM_METADATA = "_atomic_tvm_metadata"; -const std::string ATOMIC_ATTR_TBE_KERNEL_NAME = "_atomic_tbe_kernel_name"; -const std::string EXT_ATTR_ATOMIC_TBE_KERNEL = "_atomic_tbe_kernel"; - -// Source/dst format for Op FormatTransfer -const std::string FORMAT_TRANSFER_SRC_FORMAT = "src_format"; -const std::string FORMAT_TRANSFER_DST_FORMAT = "dst_format"; -const std::string FORMAT_TRANSFER_SRC_SUBFORMAT = "src_subformat"; -const std::string FORMAT_TRANSFER_DST_SUBFORMAT = "dst_subformat"; - -// For compile op by ge call -const std::string ATTR_NEED_COMPILE = "_node_need_compile"; - -const std::string ATTR_INSERT_BY_MBATCH = "mbatch-inserted-node"; - -const std::string ATTR_MBATCH_ORIGIN_INPUT_DIMS = "_mbatch_origin_input_dims"; - -const std::string ATTR_DYNAMIC_TYPE = "mbatch_dynamic_type"; - -const std::string ATTR_USER_DESIGNEATE_SHAPE_ORDER = "user_designate_shape_order"; - -// For inserted op -const std::string ATTR_INSERTED_BY_GE = "_inserted_by_ge"; - -// For compress weight -const std::string ATTR_NAME_COMPRESS_WEIGHT = "_is_compress_weight"; - -// For data dump -const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES = "_datadump_original_op_names"; -const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OP_TYPES = "_datadump_original_op_types"; -const std::string ATTR_NAME_DATA_DUMP_IS_MULTIOP = "_datadump_is_multiop"; -const std::string ATTR_NAME_DATA_DUMP_SUB_SPLITER_INDEX = "_datadump_sub_spliter_index"; -const std::string ATTR_NAME_DATA_DUMP_GROUP_OP_NAME = "_datadump_group_op_name"; -const std::string ATTR_NAME_DATA_DUMP_ORIGIN_NAME = "_datadump_origin_name"; -const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX = "_datadump_origin_output_index"; -const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT = "_datadump_origin_format"; -const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE = "_datadump_origin_data_type"; -const std::string ATTR_NAME_ORIGIN_OP_ATTRS_IN_FUSION_PROCESS = "_original_op_attrs_in_fusion_process"; -const std::string ATTR_NAME_ORIGIN_OP_ATTRS_MAP = "_original_op_attrs_map"; - -// functional ops attr -const std::string ATTR_NAME_IF_THEN_BRANCH = "then_branch"; -const std::string ATTR_NAME_IF_ELSE_BRANCH = "else_branch"; -const std::string ATTR_NAME_WHILE_COND = "cond"; -const std::string ATTR_NAME_WHILE_BODY = "body"; - -// used for label switch -const std::string ATTR_NAME_LABEL_SWITCH_INDEX = "_label_switch_index"; -const std::string ATTR_NAME_LABEL_SWITCH_LIST = "_label_switch_list"; -const std::string ATTR_NAME_SUBGRAPH_END_NODE = "_subgraph_end_node"; - -const std::string ATTR_NAME_INPUT_DATATYPE = "input_datatype"; -const std::string ATTR_NAME_OUTPUT_DATATYPE = "output_datatype"; - -// used for LX tiling -const std::string ATTR_NAME_OP_L1_SPACE = "_l1_space"; -const std::string ATTR_NAME_FUSION_TYPE_LIST = "_fusion_type_list"; -const std::string ATTR_NAME_VALID_INPUT_SHAPE_LIST_LIST = "_valid_input_shape_list_list"; -const std::string ATTR_NAME_VALID_OUTPUT_SHAPE_LIST_LIST = "_valid_output_shape_list_list"; -const std::string ATTR_NAME_SLICE_INPUT_OFFSET_LIST_LIST = "_input_offset_list_list"; -const std::string ATTR_NAME_SLICE_OUTPUT_OFFSET_LIST_LIST = "_output_offset_list_list"; - -// for unregistered op -const std::string ATTR_NAME_UNREGST_OPPATH = "_unregst_oppath"; -const std::string ATTR_NAME_UNREGST_ATTRLIST = "_unregst_attrlist"; - -// used for Horovod -const std::string ATTR_INTER_EVENT_IDENTIFY = "event_id"; -const std::string ATTR_HOROVOD_ATTR_REDUCE_TYPE = "reduce_op"; -// used for allreduce tailing optimization -const std::string ATTR_NAME_HCCL_FUSED_GROUP = "_hccl_fused_group"; -const std::string ATTR_NAME_HCCL_FUSED_FLAG = "_hccl_fused_node"; -// used for recording the number of tasks to be issued for each operator -const std::string ATTR_NAME_HCCL_TASK_NUM = "_hccl_task_num"; -// used for recording task num of RTS nodes such as MemcpyAsync -const std::string ATTR_NAME_NODE_SQE_NUM = "_node_sqe_num"; -// used for parallel group -const std::string ATTR_NAME_PARALLEL_GROUP = "_parallel_group"; - -const std::string ATTR_NAME_IS_SUPPORT_ADDR_REFRESH = "_is_support_addr_refresh"; - -const std::string ATTR_NAME_HCCL_GROUP_ID_LIST = "_hccl_group_id_list"; - -// dynamic shape attr -const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR = "_alloc_fixed_addr"; -const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX = "_alloc_fixed_addr_index"; -const std::string ATTR_DYNAMIC_SHAPE_SINGLE_AICPU = "_single_aicpu_dynamic"; - -// op dynamic input -const std::string ATTR_NAME_DYNAMIC_INPUT_START = "_dynamic_input_index_start"; -const std::string ATTR_NAME_DYNAMIC_INPUT_END = "_dynamic_input_index_end"; - -// atc user def dtype&format -const std::string ATTR_ATC_USER_DEFINE_DATATYPE = "_user_defined_data_type"; -const std::string ATTR_ATC_USER_DEFINE_FORMAT = "_user_defined_format"; - -// atc user def dtype&format -const std::string ATTR_ATC_USER_DEFINE_OUTPUT_NODES = "_user_defined_output_nodes"; - -// for fusion op plugin -const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE = "_fusionop_original_type"; - -// graph partition for aicpu -const std::string ATTR_NAME_PLD_FRONT_NODE_ENGINE_NAME = "pld_front_node_engine_name"; -const std::string ATTR_NAME_END_REAR_NODE_ENGINE_NAME = "end_rear_node_engine_name"; - -// aicpu workspace type -const std::string ATTR_NAME_AICPU_WORKSPACE_TYPE = "_aicpu_workspace_type"; - -// input and output memory type -const std::string ATTR_VARIABLE_PLACEMENT = "_variable_placement"; -const std::string ATTR_INPUT_MEMORY_TYPE = "_input_memory_type"; -const std::string ATTR_OUTPUT_MEMORY_TYPE = "_output_memory_type"; -const std::string ATTR_NAME_SPECIAL_OUTPUT_SIZE = "_special_output_size"; -const std::string ATTR_NAME_SPECIAL_INPUT_SIZE = "_special_input_size"; -const std::string ATTR_NAME_INPUT_MEMORY_SCOPE = "_input_memory_scope"; -const std::string ATTR_NAME_OUTPUT_MEMORY_SCOPE = "_output_memory_scope"; -const std::string ATTR_NAME_TENSOR_MEMORY_SCOPE = "_tensor_memory_scope"; - -// stage -const std::string ATTR_STAGE_LEVEL = "_stage_level"; - -// input_output_offset -const std::string ATTR_ZERO_COPY_BASIC_OFFSET = "_zero_copy_basic_offset"; -const std::string ATTR_ZERO_COPY_RELATIVE_OFFSET = "_zero_copy_relative_offset"; -const std::string ATTR_IS_ZERO_COPY_BLOCK = "_is_zero_copy_block"; -const std::string ATTR_NAME_INPUT_OFFSET_LIST_FOR_CONTINUOUS = "_input_offset_list_for_continuous"; -const std::string ATTR_NAME_OUTPUT_OFFSET_LIST_FOR_CONTINUOUS = "_output_offset_list_for_continuous"; - -// mark node cannot be deleted -const std::string ATTR_NAME_CANNOT_BE_DELETED = "_cannot_be_deleted"; - -// The processing mode of INF and NAN during floating-point number calculation. -const std::string ATTR_FP_CEILING_MODE = "_fp_ceiling_mode"; -// count of data from getnext_sink -const std::string ATTR_GETNEXT_SINK_DATA_COUNT = "N"; -const std::string ATTR_GETNEXT_SINK_SHAPE_INFO = "shape_info"; - -// getnext_sink marked on NetOutput -const std::string ATTR_GETNEXT_SINK_DYNMAIC = "getnext_sink_dynamic"; -const std::string ATTR_ALL_GEARS_INFO = "all_gears_info"; - -// Calculate the operator output memory -const std::string ATTR_NAME_MEMORY_SIZE_CALC_TYPE = "_memory_size_calc_type"; -// Indicates which operators keep the precision unchanged -const std::string ATTR_NAME_KEEP_DTYPE = "_keep_dtype"; - -// profiling task mark on fp bp -const std::string ATTR_NAME_INSERT_FP_PROFILILNG_TASK = "_fp_profiling_task"; -const std::string ATTR_NAME_INSERT_BP_PROFILILNG_TASK = "_bp_profiling_task"; -const std::string ATTR_NAME_INSERT_END_PROFILILNG_TASK = "_end_profiling_task"; -const std::string ATTR_NAME_INSERT_PROFILILNG_TASK_LOG_ID = "_profiling_log_id"; -// padding dimension type (FE set and ge get) -const std::string ATTR_NAME_RESHAPE_INFER_TYPE = "_infer_reshape_type"; -const std::string ATTR_NAME_RESHAPE_TYPE_MASK = "_reshape_type_mask"; - -// mark single op scene -const std::string ATTR_SINGLE_OP_SCENE = "_single_op_scene"; - -// for fe judge whether trans/cast op is inserted -const std::string ATTR_NAME_FORMAT_CONTINUOUS = "_format_continuous"; -const std::string ATTR_NAME_REFRESH_CONTINUOUS_FLAG = "_refresh_continuous_flag"; -const std::string ATTR_NAME_FORMAT_AGNOSTIC = "_format_agnostic"; -const std::string ATTR_NAME_FORMAT_AGNOSTIC_EXCEPT_OUTPUT = "_format_agnostic_except_output"; -const std::string ATTR_NAME_FORMAT_AGNOSTIC_EXCEPT_INPUT = "_format_agnostic_except_input"; - -// for ffts/ffts_plus -const std::string ATTR_NAME_FFTS_SUB_GRAPH = "_ffts_sub_graph"; -const std::string ATTR_NAME_THREAD_SCOPE_ID = "_thread_scope_id"; -const std::string ATTR_NAME_THREAD_MODE = "_thread_mode"; -const std::string ATTR_NAME_FFTS_PLUS_SUB_GRAPH = "_ffts_plus_sub_graph"; -const std::string ATTR_NAME_COMPOSITE_ENGINE_NAME = "_composite_engine_name"; -const std::string ATTR_NAME_COMPOSITE_ENGINE_KERNEL_LIB_NAME = "_composite_engine_kernel_lib_name"; -const std::string ATTR_NAME_CUBE_VECTOR_CORE_TYPE = "_cube_vector_core_type"; -const std::string ATTR_NAME_CACHE_PERSIST = "_cache_persist"; -const std::string ATTR_NAME_ALIAS_ENGINE_NAME = "_alias_engine_name"; -const std::string ATTR_NAME_KERNEL_NAMES_PREFIX = "_kernel_names_prefix"; -const std::string ATTR_NAME_FFTS_SUB_TASK_TENSOR_SIZE = "_ffts_sub_task_tensor_size"; -const std::string ATTR_NAME_FFTS_SUB_TASK_TENSOR_OFFSETS = "_ffts_sub_task_tensor_offsets"; -const std::string ATTR_NAME_IS_FFTS_UNSUPPORTED = "_is_ffts_unsupported"; - -// mark fuzz build scene -const std::string ATTR_NAME_FUZZ_BUILD = "_fuzz_build"; -const std::string ATTR_NAME_PLACEMENT = "_mem_type"; -const std::string ATTR_NAME_VALUE = "_value"; -const std::string ATTR_NAME_VALUE_RANGE = "_value_range"; -const std::string ATTR_NAME_BUILD_MODE = "_build_mode"; -const std::string ATTR_NAME_FUZZ_BUILD_RES_ATTRS = "_fuzz_build_res"; -const std::string ATTR_NAME_FUZZ_INPUTS_SUPPORTED_ATTRS = "_inputs_support_info"; -const std::string ATTR_NAME_FUZZ_OUTPUTS_SUPPORTED_ATTRS = "_outputs_support_info"; -const std::string ATTR_NAME_FUZZ_IS_HIGH_PERFORMANCE_ATTRS = "_is_high_performance"; -const std::string ATTR_NAME_IS_ORIGINAL_INPUT = "_is_original_input"; -const std::string ATTR_NAME_IS_OP_GENERALIZED = "_is_op_generalized"; -const std::string ATTR_NAME_IS_DYNAMIC_MODEL = "_is_dynamic_model"; - -// buffer pool allocator -const std::string ATTR_NAME_BUFFER_POOL_ID = "_buffer_pool_id"; -const std::string ATTR_NAME_BUFFER_POOL_SIZE = "_buffer_pool_size"; -const std::string ATTR_NAME_EVENT_MULTIPLEXING = "_event_multiplexing"; -const std::string ATTR_NAME_BUFFER_POOL_NODE_SIZE_AND_OFFSET = "_buffer_pool_node_size_and_offset"; - -// session scope memory -const std::string ATTR_NAME_WORKSPACE_MEMORY_NO_REUSE_SCOPE = "_workspace_memory_no_reuse_scope"; - -// for blocking op -const std::string ATTR_NAME_IS_BLOCKING_OP = "_is_blocking_op"; -const std::string ATTR_NAME_BLOCKING_OP_TIMEOUT = "_blocking_op_timeout"; - -// for op specified engine -const std::string ATTR_NAME_OP_SPECIFIED_ENGINE_NAME = "_specified_engine_name"; -const std::string ATTR_NAME_OP_SPECIFIED_KERNEL_LIB_NAME = "_specified_kernel_lib_name"; - -// for pipeline partition -const std::string ATTR_NAME_PIPELINE_PARTITIONED = "_pipeline_partitioned"; -const std::string ATTR_NAME_OUTPUT_PIPELINE = "_output_pipeline"; - -// for partition -const std::string ATTR_NAME_NO_NEED_PARTITION_AND_MERGE = "_no_need_partition_and_merge"; -const std::string ATTR_NAME_NO_NEED_PARTITION = "_no_need_partition"; -const std::string ATTR_NAME_NO_NEED_MERGE = "_no_need_merge"; -const std::string ATTR_NAME_NO_NEED_DYNAMIC_SHAPE_PARTITION = "_no_need_dynamic_shape_partition"; - -// model deploy scheduler(mds) -const std::string ATTR_NAME_GRADIENT_NODE = "_gradient_node"; -const std::string ATTR_NAME_TRAINABLE_VAR = "_trainable_var"; -const std::string ATTR_NAME_FISSION_FACTOR = "_fission_factor"; -const std::string ATTR_NAME_DEPLOY_INFO = "_deploy_info"; -const std::string ATTR_NAME_CUT_INFO = "_cut_info"; -const std::string ATTR_NAME_DEPLOY_DEVICE_TYPE = "_device_type"; -const std::string ATTR_NAME_DEPLOY_DEVICE_ID = "_device_id"; -const std::string ATTR_NAME_REDUNDANT_DEPLOY_DEVICE_ID = "_redundant_device_id"; -const std::string ATTR_NAME_DEPLOY_GRAPH_INPUTS = "_graph_inputs"; -const std::string ATTR_NAME_DEPLOY_NEED_RETURN_RESULT = "_need_return_result"; -const std::string ATTR_NAME_FORCE_ATTACH_STREAM = "_force_attach_stream"; -const std::string ATTR_NAME_RT_MEMCPY_KIND = "_rt_memcpy_kind"; - -// for qos -const std::string ATTR_NAME_QOS_SERVICE_LABEL = "_qos_service_label"; -const std::string ATTR_NAME_PARALLEL_GROUP_ID = "_parallel_group_id"; - -// for constant folding, mark potential const -const std::string ATTR_NAME_POTENTIAL_CONST = "_is_potential_const"; -const std::string ATTR_NAME_POTENTIAL_WEIGHT = "_potential_weight"; -const std::string ATTR_NAME_POTENTIAL_WEIGHT_INDICES = "_potential_weight_indices"; - -// name of network output tensor -const std::string ATTR_NAME_ORIGIN_OUTPUT_TENSOR_NAME = "_origin_output_tensor_name"; - -// for scope op to record the input and output information of the original graph node -const std::string ATTR_NAME_ORIGIN_GRAPH_NODE_INPUTS = "_origin_graph_node_inputs"; -const std::string ATTR_NAME_ORIGIN_GRAPH_NODE_OUTPUTS = "_origin_graph_node_outputs"; - -// for operator resource list(e.g. queues, channels) -const std::string ATTR_NAME_RESOURCE_LIST = "_resource_list"; - -// for no tiling -const std::string ATTR_NAME_OP_TILING_INLINE_ENGINE = "_op_tiling_inline_engine"; -const std::string ATTR_NAME_OP_EXPORT_SHAPE_ENGINE = "_op_export_shape_engine"; -const std::string ATTR_NAME_OP_MAX_SHAPE = "_op_max_shape"; -const std::string ATTR_NAME_OP_MAX_SIZE = "_op_max_size"; -const std::string ATTR_NAME_TENSOR_MAX_SHAPE = "_tensor_max_shape"; -const std::string ATTR_NAME_OP_NO_TILING = "_op_no_tiling"; -const std::string ATTR_NAME_TENSOR_DESC_MEM_OFFSET = "_tensor_desc_mem_offset"; -const std::string ATTR_NAME_TENSOR_NO_TILING_MEM_TYPE = "_tensor_no_tiling_mem_type"; - -// for soft sync op -const std::string ATTR_NAME_STATIC_TO_DYNAMIC_SOFT_SYNC_OP = "_static_to_dynamic_softsync_op"; -const std::string ATTR_NAME_SGT_CUBE_VECTOR_CORE_TYPE = "_sgt_cube_vector_core_type"; -const std::string ATTR_NAME_MAX_TILING_SIZE = "op_para_size"; - -// for subgraph multi dims -const std::string ATTR_NAME_SUBGRAPH_MULTI_DIMS_INDEX = "_subgraph_multi_dims_index"; -const std::string ATTR_NAME_SUBGRAPH_MULTI_DIMS_INPUT_SHAPE = "_subgraph_multi_dims_input_shape"; -const std::string ATTR_NAME_SUBGRAPH_MULTI_DIMS_INPUT_DIMS = "_subgraph_multi_dims_input_dims"; -const std::string ATTR_NAME_SUBGRAPH_IS_MULTI_DIMS = "_subgraph_is_multi_dims"; -const std::string ATTR_NAME_OP_MULTI_DIMS_INPUT_DIMS = "_op_multi_dims_input_dims"; - -// for support BlockDim -const std::string ATTR_NAME_SUPPORT_BLOCKDIM_FLAG = "_support_blockdim_flag"; -const std::string ATTR_NAME_BLOCKDIM_INDEX = "_blockdim_index"; - -// for support dynamic data flow ops -const std::string ATTR_NAME_DATA_FLOW_HANDLE = "_data_flow_handle"; -const std::string ATTR_NAME_DATA_FLOW_MAX_SIZE = "_data_flow_max_size"; - -// mark node inserted by ge -const std::string ATTR_NAME_IS_INSERTED_BY_GE = "_is_inserted_by_ge"; - -// for cmo feature -const std::string ATTR_NAME_MEMORY_REUSE_INFO = "_mem_reuse_info"; -const std::string ATTR_NAME_OP_READ_WRITE_INDEX = "_op_read_write_index"; -const std::string ATTR_NAME_MEM_RELEASE_FIRST_REUSE_FIRST = "_mem_release_first_reuse_first"; -const std::string ATTR_NAME_NODE_NEED_MULTI_TASK = "_op_need_multi_task"; - -// for support overflow detection -const std::string GLOBALWORKSPACE_SPEC_WORKSPACE = "globalworkspace_spec_workspace"; -const std::string GLOBALWORKSPACE_SPEC_WORKSPACE_BYTES = "globalworkspace_spec_workspace_bytes"; -const std::string GLOBALWORKSPACE_TYPE = "globalworkspace_type"; - -// for value depend -const std::string ATTR_NAME_VALUE_DEPEND = "_is_value_depend"; - -// for process node engine -const std::string ATTR_NAME_PROCESS_NODE_ENGINE_ID = "_process_node_engine_id"; - -// for dynamic graph memory discontiguous -const std::string ATTR_NAME_MEMORY_DISCONTIGUOUS_ALLOCATION = "_memory_discontiguous_allocation"; - -// for flow attribute -const std::string ATTR_NAME_FLOW_ATTR = "_flow_attr"; -const std::string ATTR_NAME_FLOW_ATTR_DEPTH = "_flow_attr_depth"; -const std::string ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY = "_flow_attr_enqueue_policy"; -const std::string ATTR_NAME_FLOW_ATTR_IO_PLACEMENT = "_flow_attr_io_placement"; - -// for aligned attribute -const std::string ATTR_NAME_INPUTS_ALIGN_ATTR = "_inputs_align_attr"; -const std::string ATTR_NAME_INPUTS_ALIGN_INTERVAL = "_inputs_align_interval"; -const std::string ATTR_NAME_INPUTS_ALIGN_OFFSET = "_inputs_align_offset"; - -// for binnary om feature -const std::string ATTR_NAME_OUT_SHAPE_LOCKED = "_out_shape_locked"; -const std::string ATTR_NAME_FORMAT_LOCKED = "_format_locked"; -const std::string ATTR_NAME_SKIP_GEN_TASK = "_skip_gen_task"; -const std::string ATTR_NAME_OM_BINARY_PATH = "_om_binary_path"; - -const std::string ATTR_NAME_NEED_GENTASK_ATOMIC = "need_gentask_atomic"; - -// for pipeline partition attribute -const std::string ATTR_NAME_PARALLEL_SHARDED = "_parallel_sharded"; -const std::string ATTR_NAME_PIPELINE_STAGE = "_pipeline_stage"; -const std::string ATTR_NAME_VIRTUAL_STAGE = "_virtual_stage"; -const std::string ATTR_NAME_LOGIC_DEV_ID = "_logic_device_id"; -const std::string ATTR_NAME_REDUNDANT_LOGIC_DEV_ID = "_redundant_logic_device_id"; -const std::string ATTR_NAME_STAGE_ORDER_ID = "_stage_order_id"; - -// for graph memory optimize -const std::string ATTR_NAME_RECOMPUTE = "_recompute"; -const std::string ATTR_NAME_BACKWARD = "_backward"; -const std::string ATTR_NAME_OPTIMIZER = "_optimizer"; -const std::string ATTR_NAME_GRAPH_SLICE_SCOPE = "_graph_slicing_scope"; -const std::string ATTR_NAME_GRAPH_SLICE_NUM = "_graph_slice_num"; -const std::string ATTR_NAME_IS_FIXED_ADDR_PRIOR = "_is_fixed_addr_prior"; - -// for model eshced priority -const std::string ATTR_NAME_ESCHED_PROCESS_PRIORITY = "_eschedProcessPriority"; -const std::string ATTR_NAME_ESCHED_EVENT_PRIORITY = "_eschedEventPriority"; - -// for graph deployment -const std::string ATTR_NAME_DEVICE_INDEX = "_device_index"; -const std::string ATTR_NAME_MODEL_INDEX = "_model_index"; -const std::string ATTR_NAME_DEVICE_INDEX_TO_LOGIC_DEVICE_ID = "_device_index_to_logic_device_id"; -const std::string ATTR_NAME_NODE_DEPLOYMENT = "_node_deployment"; -const std::string ATTR_NAME_NODE_DEPLOYMENTS = "_node_deployments"; -const std::string ATTR_NAME_TENSOR_DEPLOYMENT = "_tensor_deployment"; -const std::string ATTR_NAME_TENSOR_DEPLOYMENTS = "_tensor_deployments"; -const std::string ATTR_NAME_MODEL_INDEX_TO_LOGIC_DEVICE_ID = "_model_index_to_logic_device_id"; -const std::string ATTR_NAME_ORIGINAL_CONST_NAME = "_original_const_name"; -const std::string ATTR_NAME_DEPLOY_DEVICE_LIST = "_deploy_device_list"; -const std::string ATTR_NAME_RECOMPUTE_MODE = "_recompute_mode"; -const std::string ATTR_NAME_MODEL_EVENTS = "_model_events"; -const std::string ATTR_NAME_HCOM_GROUPS = "_hcom_groups"; -const std::string ATTR_NAME_SHARD_GRAPH_EXT_ATTRS = "_shard_graph_ext_attrs"; -const std::string ATTR_NAME_IS_SHARD_GRAPH_FOR_LOAD = "_is_shard_graph_for_load"; -const std::string ATTR_NAME_GRAPH_MODEL_DEPLOY_MODE = "_graphModelDeployMode"; - -// for tensor parallelism -const std::string ATTR_NAME_TP_RESHARD_ATTR = "_reshard_attr"; - -// for lowering -const std::string ATTR_NAME_GRAPH_FLATTEN_OFFSET = "graph_flatten_offset"; - -// for fileconstant -const std::string ATTR_NAME_FILE_CONSTANT_ID = "file_id"; -const std::string ATTR_NAME_FILE_PATH = "file_path"; -const std::string ATTR_NAME_FILE_CONSTANT_PATH = "_file_constant_path"; -const std::string ATTR_NAME_LOCATION = "location"; -const std::string ATTR_NAME_OFFSET = "offset"; -const std::string ATTR_NAME_LENGTH = "length"; - -// for embedding service -const std::string ATTR_NAME_EXECUTE_TIMES = "_execute_times"; -const std::string ATTR_NAME_MAX_KEY_NUM = "_max_key_num"; -const std::string ATTR_NAME_EMBEDDING_DIM = "_embedding_dim"; -const std::string ATTR_NAME_TAG_ID = "_tag_id"; -const std::string ATTR_NAME_OPTIMIZER_GRAPH_FLAG = "_optimizer_graph_flag"; -const std::string ATTR_NAME_EMBEDDING_GRAPH_FLAG = "_embedding_graph_flag"; -const std::string ATTR_NAME_DATA_TRANSFER_TYPE = "_data_transfer_type"; -const std::string ATTR_NAME_COMM_GROUP_NAMES = "_comm_group_names"; -const std::string ATTR_NAME_USE_COUNTER_FILTER = "_use_counter_filter"; - -const std::string ATTR_MODEL_HOST_ENV_OS = "host_env_os"; -const std::string ATTR_MODEL_HOST_ENV_CPU = "host_env_cpu"; - -// for mc2 stream assign, named_attr类型 -const std::string ATTR_NAME_DISABLE_ATTACHED_RESOURCE = "_disable_attached_resource"; -const std::string ATTR_NAME_ATTACHED_STREAM_INFO = "_attached_stream_info"; -// 下面的是name_attr上设置的属性名称,不要直接设置到op_desc上 -const std::string ATTR_NAME_ATTACHED_STREAM_KEY = "_attached_stream_key"; -const std::string ATTR_NAME_ATTACHED_STREAM_POLICY = "_attached_stream_policy"; - -const std::string ATTR_NAME_ATTACHED_STREAM_ID = "_attached_stream_id"; -const std::string ATTR_NAME_ATTACHED_NOTIFY_INFO = "_attached_notify_info"; -// 下面的是name_attr上设置的属性名称,不要直接设置到op_desc上 -const std::string ATTR_NAME_ATTACHED_NOTIFY_KEY = "_attached_notify_key"; -const std::string ATTR_NAME_ATTACHED_NOTIFY_TYPE = "_attached_notify_type"; -const std::string ATTR_NAME_ATTACHED_NOTIFY_POLICY = "_attached_notify_policy"; -const std::string ATTR_NAME_ATTACHED_NOTIFY_NUM = "_attached_notify_num"; - -// for tiling sink -const std::string ATTR_NAME_ATTACHED_SYNC_RES_INFO = "_attached_sync_res_info"; -// 下面的是name_attr上设置的属性名称,不要直接设置到op_desc上 -const std::string ATTR_NAME_ATTACHED_SYNC_RES_TYPE = "_attached_sync_res_type"; -const std::string ATTR_NAME_ATTACHED_SYNC_RES_KEY = "_attached_sync_res_key"; -const std::string ATTR_NAME_ATTACHED_STREAM_DEPEND_VALUE_LIST = "_attached_stream_depend_value_list"; - -// 由引擎归一打到op_desc上,类型为list_name_attr -const std::string ATTR_NAME_ATTACHED_STREAM_INFO_LIST = "_attached_stream_info_list"; -const std::string ATTR_NAME_ATTACHED_SYNC_RES_INFO_LIST = "_attached_sync_res_info_list"; -// 下面的属性是在ATTR_NAME_ATTACHED_STREAM_INFO_LIST跟ATTR_NAME_ATTACHED_SYNC_RES_INFO_LIST的属性内部,为name_attr类型,不要直接打在op_desc上 -const std::string ATTR_NAME_ATTACHED_RESOURCE_NAME = "_attached_resource_name"; -const std::string ATTR_NAME_ATTACHED_RESOURCE_REUSE_KEY = "_attached_resource_reuse_key"; -const std::string ATTR_NAME_ATTACHED_RESOURCE_DEPEND_VALUE_LIST_INT = "_attached_resource_depend_value_list_int"; -const std::string ATTR_NAME_ATTACHED_RESOURCE_REQUIRED_FLAG = "_attached_resource_required_flag"; -const std::string ATTR_NAME_ATTACHED_RESOURCE_ID = "_attached_resource_id"; -const std::string ATTR_NAME_ATTACHED_RESOURCE_IS_VALID = "_attached_resource_is_valid"; -// 仅针对ATTR_NAME_ATTACHED_SYNC_RES_INFO_LIST属性内部,标识event/notify类型,默认为event类型 -const std::string ATTR_NAME_ATTACHED_RESOURCE_TYPE = "_attached_resource_type"; - -// skip prune -const std::string ATTR_NAME_SKIP_PRUNE_OPTIMIZE = "_force_skip_prune"; - -// for static shape reuse operator binary -const std::string ATTR_NAME_OP_RUN_INFO = "op_run_info"; -const std::string ATTR_NAME_ATOMIC_OP_RUN_INFO = "atomic_op_run_info"; -const std::string ATTR_NAME_MEMSET_NODE = "memset_node_ptr"; -const std::string ATTR_NAME_KERNEL_BIN_ID = "_kernel_bin_id"; -const std::string ATTR_NAME_TILING_POLICY = "_tiling_policy"; - -// for local reduce -const std::string REDUCE_OP_ATTR_KEEP_DIMS = "keep_dims"; - -const std::string LOGIC_STREAM_ID = "_logic_stream_id"; - -// storage format -const std::string ATTR_NAME_IS_HEAVY_OP = "_is_heavy_op"; -// tiling depend -const std::string ATTR_NAME_DYNAMIC_TILING_DEPEND_OP = "_dynamic_tiling_depend_op"; - -// for infer shape -const std::string ATTR_NAME_PRESET_OUTPUT_SHAPES = "_preset_output_shapes"; - -// session options -const std::string ATTR_NAME_SESSION_OPTIONS = "SessionOptions"; - -// graph options -const std::string ATTR_NAME_GRAPH_OPTIONS = "GraphOptions"; - -// host tensor -const std::string ATTR_NAME_HOST_TENSOR = "_host_tensor"; -// 是否支持分档属性 -const std::string ATTR_NAME_ENABLE_DYNAMIC_BATCH = "_enable_dynamic_batch"; -// 需要分档输入索引 -const std::string ATTR_NAME_DYNAMIC_BATCH_UNKNOWN_DATA_INDEX = "_dynamic_batch_unknown_data_index"; -// 分档中shape是-1的dim_index -const std::string ATTR_NAME_DYNAMIC_BATCH_UNKNOWN_DIM_INDEX = "_dynamic_batch_unknown_dim_index"; - -// for inplace ability -const std::string ATTR_NAME_OUTPUT_INPLACE_ABILITY = "_output_inplace_ability"; - -// for opp_kernel -const std::string ATTR_NAME_BINARY_SOURCE = "_opp_path"; - -const std::string ATTR_NAME_DO_NOT_CONSTANT_FOLDING = "_do_not_constant_folding"; - -// for super kernel -const std::string ATTR_NAME_SUPER_KERNEL_SCOPE = "_super_kernel_scope"; -const std::string ATTR_NAME_SUPER_KERNEL_OPTIONS = "_super_kernel_options"; - -// inference rule for torch or other framework with symbols -const std::string ATTR_NAME_INFER_RULE = "_inference_rule"; -} // namespace ge diff --git a/graph/attr/ge_attr_value.cc b/graph/attr/ge_attr_value.cc deleted file mode 100644 index 74c7bd8ec6e8aa83f31bf77de2a04dafd7cc50b8..0000000000000000000000000000000000000000 --- a/graph/attr/ge_attr_value.cc +++ /dev/null @@ -1,1003 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/ge_attr_value.h" - -#include -#include "external/graph/graph.h" -#include "graph/utils/attr_utils.h" -#include "common/ge_common/debug/ge_log.h" -#include "graph/model_serialize.h" -#include "graph/normal_graph/ge_tensor_impl.h" -#include "graph/buffer/buffer_impl.h" -#include "graph/normal_graph/op_desc_impl.h" -#include "graph/detail/model_serialize_imp.h" -#include "graph/debug/ge_attr_define.h" -#include "debug/ge_log.h" -#include "debug/ge_util.h" -#include "graph/utils/tensor_utils.h" -#include "graph/serialization/attr_serializer_registry.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/math_util.h" - -static std::string GetOverflowDescribeOfListint(const std::string &name, const size_t &index, const int64_t &value) { - std::string reason = "When obtaining the attribute of " + name + ", the list_value[" + std::to_string(index) + - "] is " + std::to_string(value) + ", which exceeds the maximum value of integer type and causes value overflow."; - return reason; -} - -static std::string GetOverflowDescribeOfInt(const std::string &name, const int64_t &value) { - std::string reason = "When obtaining the attribute of " + name + ", the value is " + std::to_string(value) + - ", which exceeds the maximum value of integer type and causes value overflow."; - return reason; -} - -namespace ge { -namespace { -const std::map kAttrTypesMap = { - {AnyValue::VT_NONE, "VT_NONE"}, - {AnyValue::VT_STRING, "VT_STRING"}, - {AnyValue::VT_FLOAT, "VT_FLOAT"}, - {AnyValue::VT_BOOL, "VT_BOOL"}, - {AnyValue::VT_INT, "VT_INT"}, - {AnyValue::VT_TENSOR_DESC, "VT_TENSOR_DESC"}, - {AnyValue::VT_TENSOR, "VT_TENSOR"}, - {AnyValue::VT_BYTES, "VT_BYTES"}, - {AnyValue::VT_GRAPH, "VT_GRAPH"}, - {AnyValue::VT_NAMED_ATTRS, "VT_NAMED_ATTRS"}, - {AnyValue::VT_LIST_LIST_INT, "VT_LIST_LIST_INT"}, - {AnyValue::VT_DATA_TYPE, "VT_DATA_TYPE"}, - {AnyValue::VT_LIST_STRING, "VT_LIST_STRING"}, - {AnyValue::VT_LIST_FLOAT, "VT_LIST_FLOAT"}, - {AnyValue::VT_LIST_BOOL, "VT_LIST_BOOL"}, - {AnyValue::VT_LIST_INT, "VT_LIST_INT"}, - {AnyValue::VT_LIST_TENSOR_DESC, "VT_LIST_TENSOR_DESC"}, - {AnyValue::VT_LIST_TENSOR, "VT_LIST_TENSOR"}, - {AnyValue::VT_LIST_BYTES, "VT_LIST_BYTES"}, - {AnyValue::VT_GRAPH, "VT_GRAPH"}, - {AnyValue::VT_LIST_NAMED_ATTRS, "VT_LIST_NAMED_ATTRS"}, - {AnyValue::VT_LIST_DATA_TYPE, "VT_LIST_DATA_TYPE"}, -}; - -const std::map kAttrStrTypesMap = { - {"VT_NONE", AnyValue::VT_NONE}, - {"VT_STRING", AnyValue::VT_STRING}, - {"VT_FLOAT", AnyValue::VT_FLOAT}, - {"VT_BOOL", AnyValue::VT_BOOL}, - {"VT_INT", AnyValue::VT_INT}, - {"VT_TENSOR_DESC", AnyValue::VT_TENSOR_DESC}, - {"VT_TENSOR", AnyValue::VT_TENSOR}, - {"VT_BYTES", AnyValue::VT_BYTES}, - {"VT_GRAPH", AnyValue::VT_GRAPH}, - {"VT_NAMED_ATTRS", AnyValue::VT_NAMED_ATTRS}, - {"VT_LIST_LIST_INT", AnyValue::VT_LIST_LIST_INT}, - {"VT_DATA_TYPE", AnyValue::VT_DATA_TYPE}, - {"VT_LIST_STRING", AnyValue::VT_LIST_STRING}, - {"VT_LIST_FLOAT", AnyValue::VT_LIST_FLOAT}, - {"VT_LIST_BOOL", AnyValue::VT_LIST_BOOL}, - {"VT_LIST_INT", AnyValue::VT_LIST_INT}, - {"VT_LIST_TENSOR_DESC", AnyValue::VT_LIST_TENSOR_DESC}, - {"VT_LIST_TENSOR", AnyValue::VT_LIST_TENSOR}, - {"VT_LIST_BYTES", AnyValue::VT_LIST_BYTES}, - {"VT_GRAPH", AnyValue::VT_GRAPH}, - {"VT_LIST_NAMED_ATTRS", AnyValue::VT_LIST_NAMED_ATTRS}, - {"VT_LIST_DATA_TYPE", AnyValue::VT_LIST_DATA_TYPE}, -}; -} // namespace -void NamedAttrs::SetName(const std::string &name) { - name_ = name; -} - -std::string NamedAttrs::GetName() const { - return name_; -} - -AnyValue NamedAttrs::GetItem(const std::string &key) const { - AnyValue value; - (void) GetAttr(key, value); - return value; -} - -ProtoAttrMap &NamedAttrs::MutableAttrMap() { - return attrs_; -} - -ConstProtoAttrMap &NamedAttrs::GetAttrMap() const { - return attrs_; -} - -bool AttrUtils::HasAttr(ConstAttrHolderAdapter &&obj, const std::string &name) { - if (!obj) { - return false; - } - return obj->HasAttr(name); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::GetInt(ConstAttrHolderAdapter &&obj, const std::string &name, int32_t &value) { - int64_t int64_val = 0; - if (!AttrUtils::GetInt(std::move(obj), name, int64_val)) { - return false; - } - if (!IntegerChecker::Compat(int64_val)) { - const std::string reason = GetOverflowDescribeOfInt(name, int64_val); - REPORT_INNER_ERR_MSG("E18888", "%s", reason.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] %s", reason.c_str()); - return false; - } - value = static_cast(int64_val); - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::GetInt(ConstAttrHolderAdapter &&obj, const std::string &name, uint32_t &value) { - int64_t int64_val = 0; - if (!AttrUtils::GetInt(std::move(obj), name, int64_val)) { - return false; - } - if (!IntegerChecker::Compat(int64_val)) { - const std::string reason = GetOverflowDescribeOfInt(name, int64_val); - REPORT_INNER_ERR_MSG("E18888", "%s", reason.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] %s", reason.c_str()); - return false; - } - value = static_cast(int64_val); - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CloneOpDesc(const ConstOpDescPtr &org_op_desc) { - return GraphUtils::CloneOpDesc(org_op_desc); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CopyOpDesc(const ConstOpDescPtr &org_op_desc) { - return GraphUtils::CopyOpDesc(org_op_desc); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetListInt(AttrHolderAdapter &&obj, const std::string &name, const std::vector &value) { - if (!obj) { - return false; - } - return SetAttrValue(obj->MutableAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::GetListInt(ConstAttrHolderAdapter &&obj, const std::string &name, std::vector &value) { - if (!obj) { - return false; - } - return GetAttrValue(obj->GetAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetInt(AttrHolderAdapter &&obj, const std::string &name, const int64_t &value) { - if (!obj) { - return false; - } - return SetAttrValue(obj->MutableAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::GetInt(ConstAttrHolderAdapter &&obj, const std::string &name, int64_t &value) { - if (!obj) { - return false; - } - return GetAttrValue(obj->GetAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::GetInt(ConstAttrHolderAdapter &&obj, const std::string &name, uint64_t &value) { - if (!obj) { - return false; - } - int64_t int64_val = 0; - const bool ret = GetAttrValue(obj->GetAttrMap(), name, int64_val); - if (ret) { - value = static_cast(int64_val); - } - return ret; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetFloat(AttrHolderAdapter &&obj, const std::string &name, const float32_t &value) { - if (!obj) { - return false; - } - return SetAttrValue(obj->MutableAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetFloat(ConstAttrHolderAdapter &&obj, - const std::string &name, float32_t &value) { - if (!obj) { - return false; - } - return GetAttrValue(obj->GetAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetListFloat(AttrHolderAdapter &&obj, const std::string &name, const std::vector &value) { - if (!obj) { - return false; - } - return SetAttrValue(obj->MutableAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::GetListFloat(ConstAttrHolderAdapter &&obj, const std::string &name, std::vector &value) { - if (!obj) { - return false; - } - return GetAttrValue(obj->GetAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetBool(AttrHolderAdapter &&obj, const std::string &name, - const bool &value) { - if (!obj) { - return false; - } - return SetAttrValue(obj->MutableAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetBool(ConstAttrHolderAdapter &&obj, - const std::string &name, bool &value) { - if (!obj) { - return false; - } - return GetAttrValue(obj->GetAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetListBool(AttrHolderAdapter &&obj, const std::string &name, const std::vector &value) { - if (!obj) { - return false; - } - return SetAttrValue(obj->MutableAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::GetListBool(ConstAttrHolderAdapter &&obj, const std::string &name, std::vector &value) { - if (!obj) { - return false; - } - return GetAttrValue(obj->GetAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::SetStr(AttrHolderAdapter &&obj, const std::string &name, - const std::string &value) { - if (!obj) { - return false; - } - return SetAttrValue(obj->MutableAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetStr(ConstAttrHolderAdapter &&obj, - const std::string &name, std::string &value) { - if (!obj) { - return false; - } - return GetAttrValue(obj->GetAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetStr(ConstAttrHolderAdapter &&obj, - const std::string &name1, - const std::string &name2, std::string &value) { - if (!obj) { - return false; - } - if (!GetAttrValue(obj->GetAttrMap(), name1, value)) { - return GetAttrValue(obj->GetAttrMap(), name2, value); - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string *AttrUtils::GetStr(ConstAttrHolderAdapter &&obj, - const std::string &name) { - if (!obj) { - return nullptr; - } - return GetAttrValue(obj->GetAttrMap(), name); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetListStr(AttrHolderAdapter &&obj, const std::string &name, const std::vector &value) { - if (!obj) { - return false; - } - return SetAttrValue(obj->MutableAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::GetListStr(ConstAttrHolderAdapter &&obj, const std::string &name, std::vector &value) { - if (!obj) { - return false; - } - return GetAttrValue(obj->GetAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetTensorDesc(AttrHolderAdapter &&obj, const std::string &name, const GeTensorDesc &value) { - if (!obj) { - return false; - } - return SetAttrValue(obj->MutableAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::GetTensorDesc(ConstAttrHolderAdapter &&obj, const std::string &name, GeTensorDesc &value) { - if (!obj) { - return false; - } - return GetAttrValue(obj->GetAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetListTensorDesc(AttrHolderAdapter &&obj, const std::string &name, - const std::vector &value) { - if (!obj) { - return false; - } - return SetAttrValue(obj->MutableAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::GetListTensorDesc(ConstAttrHolderAdapter &&obj, - const std::string &name, std::vector &value) { - if (!obj) { - return false; - } - return GetAttrValue(obj->GetAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetNamedAttrs(AttrHolderAdapter &&obj, const std::string &name, const NamedAttrs &value) { - if (!obj) { - return false; - } - return SetAttrValue(obj->MutableAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::GetNamedAttrs(ConstAttrHolderAdapter &&obj, const std::string &name, NamedAttrs &value) { - if (!obj) { - return false; - } - return GetAttrValue(obj->GetAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetListNamedAttrs(AttrHolderAdapter &&obj, const std::string &name, - const std::vector &value) { - if (!obj) { - return false; - } - return SetAttrValue(obj->MutableAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::GetListNamedAttrs(ConstAttrHolderAdapter &&obj, - const std::string &name, std::vector &value) { - if (!obj) { - return false; - } - return GetAttrValue(obj->GetAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetDataType(AttrHolderAdapter &&obj, const std::string &name, const DataType &value) { - if (!obj) { - return false; - } - return SetAttrValue(obj->MutableAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetDataType(ConstAttrHolderAdapter &&obj, - const std::string &name, DataType &value) { - if (!obj) { - return false; - } - return GetAttrValue(obj->GetAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetListDataType(AttrHolderAdapter &&obj, const std::string &name, const std::vector &value) { - if (!obj) { - return false; - } - return SetAttrValue(obj->MutableAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::GetListDataType(ConstAttrHolderAdapter &&obj, const std::string &name, std::vector &value) { - if (!obj) { - return false; - } - return GetAttrValue(obj->GetAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetListListInt(AttrHolderAdapter &&obj, const std::string &name, - const std::vector> &value) { - if (!obj) { - return false; - } - return SetAttrValue(obj->MutableAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::GetListListInt(ConstAttrHolderAdapter &&obj, const std::string &name, - std::vector> &value) { - if (!obj) { - return false; - } - return GetAttrValue(obj->GetAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetListListFloat(AttrHolderAdapter &&obj, const std::string &name, - const std::vector> &value) { - if (!obj) { - return false; - } - return SetAttrValue(obj->MutableAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::GetListListFloat(ConstAttrHolderAdapter &&obj, const std::string &name, - std::vector> &value) { - if (!obj) { - return false; - } - return GetAttrValue(obj->GetAttrMap(), name, value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetListInt(AttrHolderAdapter &&obj, const std::string &name, const std::vector &value) { - if (!obj) { - return false; - } - return SetListInt(std::move(obj), name, std::vector(value.begin(), value.end())); -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetListInt(AttrUtils::AttrHolderAdapter &&obj, const std::string &name, - const std::vector &value) { - if (!obj) { - return false; - } - return SetListInt(std::move(obj), name, std::vector(value.begin(), value.end())); -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetListInt(AttrHolderAdapter &&obj, const std::string &name, std::initializer_list &&value) { - if (!obj) { - return false; - } - return SetListInt(std::move(obj), name, std::vector(value)); -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::GetListInt(ConstAttrHolderAdapter &&obj, const std::string &name, std::vector &value) { - value.clear(); - std::vector int64_list; - if (!GetListInt(std::move(obj), name, int64_list)) { - return false; - } - - for (size_t i = 0UL; i < int64_list.size(); ++i) { - if (!IntegerChecker::Compat(int64_list[i])) { - const std::string reason = GetOverflowDescribeOfListint(name, i, int64_list[i]); - REPORT_INNER_ERR_MSG("E18888", "%s", reason.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] %s", reason.c_str()); - return false; - } - } - (void) value.insert(value.cbegin(), int64_list.cbegin(), int64_list.cend()); - return true; -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::GetListInt(ConstAttrHolderAdapter &&obj, const std::string &name, std::vector &value) { - value.clear(); - std::vector int64_list; - if (!GetListInt(std::move(obj), name, int64_list)) { - return false; - } - - for (size_t i = 0UL; i < int64_list.size(); ++i) { - if (!IntegerChecker::Compat(int64_list[i])) { - const std::string reason = GetOverflowDescribeOfListint(name, i, int64_list[i]); - REPORT_INNER_ERR_MSG("E18888", "%s", reason.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] %s", reason.c_str()); - return false; - } - } - (void) value.insert(value.cbegin(), int64_list.cbegin(), int64_list.cend()); - return true; -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetTensor(AttrUtils::AttrHolderAdapter &&obj, const std::string &name, const GeTensor &value) { - if (!obj) { - return false; - } - // 当前GeTensor的拷贝赋值、拷贝构造函数均不是深拷贝,因此无法使用默认的方法SetAttr - if (!obj->MutableAttrMap().SetByName(name, GeTensor())) { - return false; - } - const auto tensor = obj->MutableAttrMap().MutableGetByName(name); - if (tensor == nullptr) { - return false; - } - TensorUtils::CopyTensor(value, *tensor); - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetShareTensor(AttrUtils::AttrHolderAdapter &&obj, const std::string &name, const GeTensor &value) { - if (!obj) { - return false; - } - // 当前GeTensor的拷贝赋值、拷贝构造函数均不是深拷贝,因此无法使用默认的方法SetAttr - if (!obj->MutableAttrMap().SetByName(name, GeTensor())) { - return false; - } - const auto tensor = obj->MutableAttrMap().MutableGetByName(name); - if (tensor == nullptr) { - return false; - } - TensorUtils::ShareTensor(value, *tensor); - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetTensor(AttrHolderAdapter &&obj, const std::string &name, const GeTensorPtr &value) { - if (!obj) { - return false; - } - return SetTensor(std::move(obj), name, *value); -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetTensor(AttrHolderAdapter &&obj, const std::string &name, const ConstGeTensorPtr &value) { - if (!obj) { - return false; - } - return SetTensor(std::move(obj), name, *value); -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetListTensor(AttrUtils::AttrHolderAdapter &&obj, const std::string &name, - const std::vector &value) { - if (!obj) { - return false; - } - std::vector tensors(value.size()); - if (!obj->MutableAttrMap().SetByName(name, tensors)) { - return false; - } - const auto attr_tensors = obj->MutableAttrMap().MutableGetByName>(name); - if (attr_tensors == nullptr) { - return false; - } - for (size_t i = 0UL; i < value.size(); ++i) { - TensorUtils::CopyTensor(value[i], (*attr_tensors)[i]); - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetListTensor(AttrHolderAdapter &&obj, const std::string &name, - const std::vector &value) { - if (!obj) { - return false; - } - std::vector tensors(value.size()); - (void) std::copy(value.begin(), value.end(), tensors.begin()); - return SetListTensor(std::move(obj), name, tensors); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetListTensor(AttrHolderAdapter &&obj, const std::string &name, - const std::vector &value) { - if (!obj) { - return false; - } - std::vector tensors(value.size()); - if (!obj->MutableAttrMap().SetByName(name, tensors)) { - return false; - } - const auto attr_tensors = obj->MutableAttrMap().MutableGetByName>(name); - if (attr_tensors == nullptr) { - return false; - } - for (size_t i = 0UL; i < value.size(); ++i) { - TensorUtils::CopyTensor(*(value[i]), (*attr_tensors)[i]); - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetListTensor(AttrHolderAdapter &&obj, const std::string &name, - std::initializer_list &&value) { - if (!obj) { - return false; - } - return SetListTensor(std::move(obj), name, std::vector(value)); -} - -// 所有权UT测试,不能把属性上的GeTensor给错误释放了 -// 而且这里的行为与老版本是不一样的,老版本中,即使属性的owner生命周期结束析构了,通过本接口获取的value仍然是可用的 -// 但是新接口中,owner没有转移,owner析构后,value指向的内存就被释放了,这里需要排查 -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::MutableTensor(AttrHolderAdapter &&obj, const std::string &name, GeTensorPtr &value) { - if (!obj) { - return false; - } - const auto tensor = obj->MutableAttrMap().MutableGetByName(name); - if (tensor == nullptr) { - return false; - } - value = std::shared_ptr(tensor, [](const GeTensor *const ptr) { (void) ptr; }); - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::GetTensor(ConstAttrHolderAdapter &&obj, const std::string &name, ConstGeTensorPtr &value) { - if (!obj) { - return false; - } - const auto tensor = obj->GetAttrMap().GetByName(name); - if (tensor == nullptr) { - return false; - } - value = std::shared_ptr(tensor, [](const GeTensor *const ptr) { (void) ptr; }); - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::GetListTensor(ConstAttrHolderAdapter &&obj, const std::string &name, - std::vector &value) { - if (!obj) { - return false; - } - const auto tensors = obj->GetAttrMap().GetByName>(name); - if (tensors == nullptr) { - return false; - } - value.resize(tensors->size()); - for (size_t i = 0UL; i < tensors->size(); ++i) { - value[i] = std::shared_ptr(&(*tensors)[i], [](const GeTensor *const ptr) { (void) ptr; }); - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::MutableListTensor(AttrHolderAdapter &&obj, const std::string &name, std::vector &value) { - if (!obj) { - return false; - } - const auto tensors = obj->MutableAttrMap().MutableGetByName>(name); - if (tensors == nullptr) { - return false; - } - value.resize(tensors->size()); - for (size_t i = 0UL; i < tensors->size(); ++i) { - value[i] = std::shared_ptr(&(*tensors)[i], [](const GeTensor *const ptr) { (void) ptr; }); - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetGraph(AttrUtils::AttrHolderAdapter &&obj, const std::string &name, const ComputeGraphPtr &value) { - if (!obj) { - return false; - } - proto::GraphDef *const graph_def = SetAndGetAttrValue(obj->MutableAttrMap(), name, proto::GraphDef()); - if (graph_def == nullptr) { - return false; - } - const ModelSerializeImp imp; - if (!imp.SerializeGraph(value, graph_def)) { - REPORT_INNER_ERR_MSG("E18888", "SerializeGraph failed when add ComputeGraph to attr %s", name.c_str()); - GELOGE(GRAPH_FAILED, "[Serialize][Graph] Failed when add ComputeGraph to attr %s", name.c_str()); - (void) obj->MutableAttrMap().Delete(name); - return false; - } - return true; -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetListGraph(AttrUtils::AttrHolderAdapter &&obj, const std::string &name, - const std::vector &value) { - if (!obj) { - return false; - } - std::vector graphs(value.size()); - if (!obj->MutableAttrMap().SetByName(name, graphs)) { - return false; - } - const auto attr_graphs = obj->MutableAttrMap().MutableGetByName>(name); - if (attr_graphs == nullptr) { - return false; - } - for (size_t i = 0UL; i < value.size(); ++i) { - const ModelSerializeImp imp; - if (!imp.SerializeGraph(value[i], &attr_graphs->at(i))) { - REPORT_INNER_ERR_MSG("E18888", "SerializeGraph failed when add ComputeGraph to attr %s", name.c_str()); - GELOGE(GRAPH_FAILED, "[Serialize][Graph] Failed when add ComputeGraph to attr %s", name.c_str()); - (void) obj->MutableAttrMap().Delete(name); - return false; - } - } - return true; -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::GetGraph(AttrUtils::ConstAttrHolderAdapter &&obj, const std::string &name, ComputeGraphPtr &value) { - if (!obj) { - return false; - } - const auto attr_graph_def = obj->GetAttrMap().GetByName(name); - if (attr_graph_def == nullptr) { - return false; - } - // 这里延续了老代码实现,先拷贝构造一个ComputeGraph,然后做反序列化,感觉直接把attr_graph_def传进去应该就可以了? - // 下一步对这里做整改,直接传入attr_graph_def,避免这一次拷贝 - const auto graph_def = ComGraphMakeShared(*attr_graph_def); - if (graph_def == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "create proto::GraphDef failed."); - GELOGE(GRAPH_FAILED, "[Create][GraphDef] proto::GraphDef make shared failed"); - return false; - } - - ModelSerializeImp imp; - imp.SetProtobufOwner(graph_def); - if (!imp.UnserializeGraph(value, *graph_def)) { - REPORT_INNER_ERR_MSG("E18888", "UnserializeGraph failed when get attr ComputeGraph by name %s", name.c_str()); - GELOGE(GRAPH_FAILED, "[Unserialize][Graph] Failed when get attr ComputeGraph by name %s", name.c_str()); - return false; - } - - return true; -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::GetListGraph(AttrUtils::ConstAttrHolderAdapter &&obj, const std::string &name, - std::vector &value) { - if (!obj) { - return false; - } - const auto graph_defs = obj->GetAttrMap().GetByName>(name); - if (graph_defs == nullptr) { - return false; - } - - value.resize(graph_defs->size()); - for (size_t i = 0UL; i < graph_defs->size(); ++i) { - std::shared_ptr graph_def; - graph_def = ComGraphMakeShared(graph_defs->at(i)); - if (graph_def == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "create proto::GraphDef failed."); - GELOGE(GRAPH_FAILED, "[Create][GraphDef] proto::GraphDef make shared failed"); - graph_def = nullptr; - return false; - } else { - ComputeGraphPtr graph = nullptr; - ModelSerializeImp imp; - imp.SetProtobufOwner(static_cast(graph_def)); - if (!imp.UnserializeGraph(graph, *graph_def)) { - REPORT_INNER_ERR_MSG("E18888", "UnserializeGraph failed."); - GELOGE(GRAPH_FAILED, "[Unserialize][Graph] Failed"); - return false; - } - value[i] = graph; - } - } - return true; -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetBytes(AttrUtils::AttrHolderAdapter &&obj, const std::string &name, const Buffer &value) { - if (!obj) { - return false; - } - const auto buffer = SetAndGetAttrValue(obj->MutableAttrMap(), name, Buffer()); - if (buffer == nullptr) { - return false; - } - BufferUtils::CopyFrom(value, *buffer); - return true; -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::GetBytes(ConstAttrHolderAdapter &&obj, const std::string &name, Buffer &value) { - if (!obj) { - return false; - } - const auto buffer = obj->GetAttrMap().GetByName(name); - if (buffer == nullptr) { - return false; - } - BufferUtils::CopyFrom(*buffer, value); - return true; -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetListBytes(AttrUtils::AttrHolderAdapter &&obj, const std::string &name, - const std::vector &value) { - if (!obj) { - return false; - } - std::vector buffers(value.size()); - const auto attr_buffers = SetAndGetAttrValue(obj->MutableAttrMap(), name, buffers); - if (attr_buffers == nullptr) { - return false; - } - - for (size_t i = 0UL; i < value.size(); ++i) { - BufferUtils::CopyFrom(value[i], (*attr_buffers)[i]); - } - - return true; -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::GetListBytes(AttrUtils::ConstAttrHolderAdapter &&obj, const std::string &name, - std::vector &value) { - if (!obj) { - return false; - } - const auto buffers = obj->GetAttrMap().GetByName>(name); - if (buffers == nullptr) { - return false; - } - value.resize(buffers->size()); - for (size_t i = 0UL; i < buffers->size(); ++i) { - BufferUtils::CopyFrom(buffers->at(i), value[i]); - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetZeroCopyBytes(AttrHolderAdapter &&obj, const std::string &name, Buffer &&buffer) { - if (!obj) { - return false; - } - // Value will be shared - return SetAttrValue(obj->MutableAttrMap(), name, std::move(buffer)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::GetZeroCopyBytes(ConstAttrHolderAdapter &&obj, const std::string &name, Buffer &buffer) { - if (!obj) { - return false; - } - // Value will be shared - return GetAttrValue(obj->GetAttrMap(), name, buffer); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::SetZeroCopyListBytes(AttrHolderAdapter &&obj, const std::string &name, - std::vector &list_buffer) { - if (!obj) { - return false; - } - // Value will be shared - return SetAttrValue(obj->MutableAttrMap(), name, list_buffer); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::GetZeroCopyListBytes(ConstAttrHolderAdapter &&obj, const std::string &name, - std::vector &list_buffer) { - if (!obj) { - return false; - } - // Value will be shared - return GetAttrValue>(obj->GetAttrMap(), name, list_buffer); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -std::map AttrUtils::GetAllAttrs(ConstAttrHolderAdapter &&obj) { - return GetAllAttrsWithFilter(std::move(obj), nullptr); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::map AttrUtils::GetAllAttrsWithFilter( - ConstAttrHolderAdapter &&obj, const AttrNameFilter &attr_filter) { - const auto holder = obj.get(); - if (holder == nullptr) { - const std::map empty; - return empty; - } - return holder->GetAllAttrsWithFilter(attr_filter); -} - -std::string AttrUtils::GetAttrsStrAfterRid(ConstAttrHolderAdapter &&obj, - const std::set &un_compute_attrs) { - const std::map attr_map = GetAllAttrs(std::move(obj)); - if (attr_map.empty()) { - return ""; - } - std::map ordered_attrs; - for (auto &attr : attr_map) { - proto::AttrDef attr_def; - auto *const value_serializer = AttrSerializerRegistry::GetInstance().GetSerializer(attr.second.GetValueTypeId()); - if ((value_serializer == nullptr) || (value_serializer->Serialize(attr.second, attr_def) != GRAPH_SUCCESS)) { - ordered_attrs[attr.first] = ""; - continue; - } - - ordered_attrs[attr.first] = attr_def.SerializeAsString(); - } - - std::stringstream str_stream; - for (auto &attr : ordered_attrs) { - if (un_compute_attrs.find(attr.first) != un_compute_attrs.end()) { - continue; - } - str_stream << attr.first << ":" << attr.second << ";"; - } - return str_stream.str(); -} -std::string AttrUtils::GetAllAttrsStr(ConstAttrHolderAdapter &&obj) { - const auto attr_map = GetAllAttrs(std::move(obj)); - if (attr_map.empty()) { - return ""; - } - return GetAllAttrsStr(attr_map); -} - -std::string AttrUtils::GetAllAttrsStr(const std::map &attr_map) { - std::map ordered_attrs; - for (auto &attr : attr_map) { - proto::AttrDef attr_def; - auto *const value_serializer = AttrSerializerRegistry::GetInstance().GetSerializer(attr.second.GetValueTypeId()); - if ((value_serializer == nullptr) || (value_serializer->Serialize(attr.second, attr_def) != GRAPH_SUCCESS)) { - ordered_attrs[attr.first] = ""; - continue; - } - - if (attr_def.has_t()) { - // print tensor desc message as an ordered string. - std::string ordered_tensor_desc; - (void) google::protobuf::TextFormat::PrintToString(attr_def.t().desc(), &ordered_tensor_desc); - ordered_attrs[attr.first] = ordered_tensor_desc + attr_def.t().data(); - } else if (attr_def.has_td()) { - // print tensor desc message as an ordered string. - std::string ordered_attr; - (void) google::protobuf::TextFormat::PrintToString(attr_def.td(), &ordered_attr); - ordered_attrs[attr.first] = ordered_attr; - } else { - ordered_attrs[attr.first] = attr_def.SerializeAsString(); - } - } - - std::stringstream ss; - for (auto &attr : ordered_attrs) { - ss << attr.first << ":" << attr.second << ";"; - } - return ss.str(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool AttrUtils::ClearAllAttrs(AttrHolderAdapter &&obj) { - if (!obj) { - return false; - } - obj->MutableAttrMap().Clear(); - return true; -} - -std::string AttrUtils::ValueTypeToSerialString(const AnyValue::ValueType value_type) { - const auto it = kAttrTypesMap.find(value_type); - if (it != kAttrTypesMap.end()) { - return it->second; - } else { - REPORT_INNER_ERR_MSG("E18888", "value_type not support %d", value_type); - GELOGE(GRAPH_FAILED, "[Check][Param] value_type not support %d", value_type); - return ""; - } -} - -AnyValue::ValueType AttrUtils::SerialStringToValueType(const string &value_type_string) { - const auto it = kAttrStrTypesMap.find(value_type_string); - if (it != kAttrStrTypesMap.end()) { - return it->second; - } else { - REPORT_INNER_ERR_MSG("E18888", "value_type_string not support %s", value_type_string.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] value_type_string not support %s", value_type_string.c_str()); - return AnyValue::VT_NONE; - } -} -} // namespace ge diff --git a/graph/buffer/buffer.cc b/graph/buffer/buffer.cc deleted file mode 100644 index dbbb7b4e39cb3e9e1cf002bf8069e779c1d8dff9..0000000000000000000000000000000000000000 --- a/graph/buffer/buffer.cc +++ /dev/null @@ -1,198 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/buffer.h" -#include "proto/ge_ir.pb.h" -#include "common/ge_common/debug/ge_log.h" -#include "graph/buffer/buffer_impl.h" -#include "graph/debug/ge_util.h" -#include "common/util/mem_utils.h" - -namespace ge { -BufferImpl::BufferImpl() { - data_.InitDefault(); - if (data_.GetProtoMsg() != nullptr) { - buffer_ = data_.GetProtoMsg()->mutable_bt(); - } -} - -BufferImpl::BufferImpl(const BufferImpl &other) { - data_ = other.data_; - buffer_ = other.buffer_; -} - -BufferImpl::~BufferImpl() {} - -BufferImpl::BufferImpl(const std::size_t buffer_size, const std::uint8_t default_val) : BufferImpl() { // default - auto const proto_msg = data_.GetProtoMsg(); - if (proto_msg != nullptr) { - try { - proto_msg->set_bt(std::string(buffer_size, static_cast(default_val))); - buffer_ = proto_msg->mutable_bt(); - } catch (std::bad_alloc &) { - REPORT_INNER_ERR_MSG("E18888", "failed to alloc buffer memory, buffer size %zu", buffer_size); - GELOGE(MEMALLOC_FAILED, "[New][Memory] failed to alloc buffer memory, buffer size %zu", buffer_size); - buffer_ = nullptr; - } - } -} - -void BufferImpl::CopyFrom(const std::uint8_t * const data, const std::size_t buffer_size) { - auto const proto_msg = data_.GetProtoMsg(); - if ((proto_msg != nullptr) && (data != nullptr)) { - try { - proto_msg->set_bt(data, buffer_size); - buffer_ = proto_msg->mutable_bt(); - } catch (std::bad_alloc &) { - REPORT_INNER_ERR_MSG("E18888", "Failed to alloc buffer memory, buffer size %zu", buffer_size); - GELOGE(MEMALLOC_FAILED, "[New][Memory] Failed to alloc buffer memory, buffer size %zu", buffer_size); - buffer_ = nullptr; - } - } -} - -BufferImpl::BufferImpl(const std::shared_ptr &proto_owner, - proto::AttrDef * const buffer) - : data_(proto_owner, buffer) { - if (data_.GetProtoMsg() != nullptr) { - buffer_ = data_.GetProtoMsg()->mutable_bt(); - } -} - -BufferImpl::BufferImpl(const std::shared_ptr &proto_owner, std::string * const buffer) - : data_(proto_owner, nullptr) { - buffer_ = buffer; -} - -BufferImpl &BufferImpl::operator=(const BufferImpl &other) { - if (&other != this) { - // Share data - data_ = other.data_; - buffer_ = other.buffer_; - } - return *this; -} - -const std::uint8_t *BufferImpl::GetData() const { - if (buffer_ != nullptr) { - return PtrToPtr(buffer_->data()); - } - return nullptr; -} - -std::uint8_t *BufferImpl::GetData() { - if ((buffer_ != nullptr) && (!buffer_->empty())) { - // Avoid copy on write - (void)(*buffer_)[0UL]; - return PtrToPtr(const_cast(buffer_->data())); - } - return nullptr; -} - -std::size_t BufferImpl::GetSize() const { - if (buffer_ != nullptr) { - return buffer_->size(); - } - return 0UL; -} - -void BufferImpl::ClearBuffer() { - if (buffer_ != nullptr) { - buffer_->clear(); - } -} - -uint8_t BufferImpl::operator[](const size_t index) const { - if ((buffer_ != nullptr) && (index < buffer_->size())) { - return static_cast((*buffer_)[index]); - } - return 0xffU; -} - -Buffer::Buffer() : impl_(MakeShared()) {} - -Buffer::Buffer(const Buffer &other) - : impl_(MakeShared(*(other.impl_))) {} - -Buffer::Buffer(const std::size_t buffer_size, const std::uint8_t default_val) - : impl_(MakeShared(buffer_size, default_val)) {} - -Buffer::~Buffer() {} - -Buffer Buffer::CopyFrom(const std::uint8_t * const data, const std::size_t buffer_size) { - Buffer buffer; - if (buffer.impl_ != nullptr) { - buffer.impl_->CopyFrom(data, buffer_size); - } - return buffer; -} - -Buffer::Buffer(const ProtoMsgOwner &proto_owner, proto::AttrDef * const buffer) - : impl_(MakeShared(proto_owner, buffer)) {} - -Buffer::Buffer(const ProtoMsgOwner &proto_owner, std::string * const buffer) - : impl_(MakeShared(proto_owner, buffer)) {} - -Buffer &Buffer::operator=(const Buffer &other) { - if (&other != this) { - if (impl_ != nullptr) { - *impl_ = *(other.impl_); - } - } - return *this; -} - -const std::uint8_t *Buffer::GetData() const { - return impl_->GetData(); -} - -std::uint8_t *Buffer::GetData() { - return impl_->GetData(); -} - -std::size_t Buffer::GetSize() const { - return impl_->GetSize(); -} - -void Buffer::ClearBuffer() { - impl_->ClearBuffer(); -} - -const std::uint8_t *Buffer::data() const { return GetData(); } - -std::uint8_t *Buffer::data() { return GetData(); } - -std::size_t Buffer::size() const { return GetSize(); } - -void Buffer::clear() { return ClearBuffer(); } - -uint8_t Buffer::operator[](const size_t index) const { - return (*impl_)[index]; -} - -Buffer BufferUtils::CreateShareFrom(const Buffer &other) { - return other; -} - -Buffer BufferUtils::CreateCopyFrom(const Buffer &other) { - return BufferUtils::CreateCopyFrom(other.GetData(), other.GetSize()); -} - -Buffer BufferUtils::CreateCopyFrom(const std::uint8_t * const data, const std::size_t buffer_size) { - return Buffer::CopyFrom(data, buffer_size); -} - -void BufferUtils::ShareFrom(const Buffer &from, Buffer &to) { - to = from; -} - -void BufferUtils::CopyFrom(const Buffer &from, Buffer &to) { - to = BufferUtils::CreateCopyFrom(from); -} -} // namespace ge diff --git a/graph/buffer/buffer_impl.h b/graph/buffer/buffer_impl.h deleted file mode 100644 index 5ec9b82162e24ff13c955845fdf95e2ac6955f06..0000000000000000000000000000000000000000 --- a/graph/buffer/buffer_impl.h +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_BUFFER_IMPL_H_ -#define GRAPH_BUFFER_IMPL_H_ - -#include -#include "proto/ge_ir.pb.h" -#include "graph/detail/attributes_holder.h" - -namespace ge { -class BufferImpl { - public: - BufferImpl(); - ~BufferImpl(); - BufferImpl(const BufferImpl &other); - BufferImpl(const std::size_t buffer_size, const std::uint8_t default_val); - - void CopyFrom(const std::uint8_t *const data, const std::size_t buffer_size); - BufferImpl(const std::shared_ptr &proto_owner, proto::AttrDef *const buffer); - BufferImpl(const std::shared_ptr &proto_owner, std::string *const buffer); - - BufferImpl &operator=(const BufferImpl &other); - const std::uint8_t *GetData() const; - std::uint8_t *GetData(); - std::size_t GetSize() const; - void ClearBuffer(); - uint8_t operator[](const size_t index) const; - - private: - friend class GeAttrValueImp; - GeIrProtoHelper data_; - std::string *buffer_ = nullptr; -}; -} // namespace ge -#endif // GRAPH_BUFFER_IMPL_H_ diff --git a/graph/buffer/graph_buffer.cc b/graph/buffer/graph_buffer.cc deleted file mode 100644 index eb77d4f514fbe097f55ba8ca2c0319fb795ee83e..0000000000000000000000000000000000000000 --- a/graph/buffer/graph_buffer.cc +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "external/graph/graph_buffer.h" -#include "graph/buffer.h" -#include "common/util/mem_utils.h" - -namespace ge { -GraphBuffer::GraphBuffer() : buffer_(MakeShared()) {} - -GraphBuffer::~GraphBuffer() {} - -const std::uint8_t *GraphBuffer::GetData() const { - return buffer_->GetData(); -} - -std::size_t GraphBuffer::GetSize() const { - return buffer_->GetSize(); -} -} // namespace ge diff --git a/graph/cache_policy/cache_policy.cc b/graph/cache_policy/cache_policy.cc deleted file mode 100644 index cd800fbe36edcb2d9055e010e3e34bef9a5a17cc..0000000000000000000000000000000000000000 --- a/graph/cache_policy/cache_policy.cc +++ /dev/null @@ -1,102 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/cache_policy/cache_policy.h" -#include "graph/debug/ge_util.h" - -namespace ge { -std::unique_ptr CachePolicy::Create(const MatchPolicyPtr &mp, const AgingPolicyPtr &ap) { - if (mp == nullptr) { - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] param match policy must not be null."); - return nullptr; - } - if (ap == nullptr) { - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] param aging policy must not be null."); - return nullptr; - } - auto ccp = ComGraphMakeUnique(); - if (ccp == nullptr) { - GELOGE(GRAPH_FAILED, "Create CachePolicy failed."); - return nullptr; - } - (void)ccp->SetAgingPolicy(ap); - (void)ccp->SetMatchPolicy(mp); - - GELOGI("[CachePolicy] Create CachePolicy success;"); - return ccp; -} - -std::unique_ptr CachePolicy::Create(const MatchPolicyType mp_type, - const AgingPolicyType ap_type, size_t cached_aging_depth) { - const auto mp = PolicyRegister::GetInstance().GetMatchPolicy(mp_type); - GE_ASSERT_NOTNULL(mp); - const auto ap = PolicyRegister::GetInstance().GetAgingPolicy(ap_type); - GE_ASSERT_NOTNULL(ap); - ap->SetCachedAgingDepth(cached_aging_depth); - auto ccp = ComGraphMakeUnique(); - GE_ASSERT_NOTNULL(ccp); - (void)ccp->SetAgingPolicy(ap); - (void)ccp->SetMatchPolicy(mp); - GELOGI("[CachePolicy] Create CachePolicy with match_policy: %d, aging_policy: %d success;", - static_cast(mp_type), static_cast(ap_type)); - return ccp; -} - -graphStatus CachePolicy::SetMatchPolicy(const MatchPolicyPtr mp) { - GE_CHECK_NOTNULL(mp); - mp_ = mp; - return GRAPH_SUCCESS; -} - -graphStatus CachePolicy::SetAgingPolicy(const AgingPolicyPtr ap) { - GE_CHECK_NOTNULL(ap); - ap_ = ap; - return GRAPH_SUCCESS; -} - -CacheItemId CachePolicy::AddCache(const CacheDescPtr &cache_desc) { - const CacheHashKey main_hash_key = cache_desc->GetCacheDescHash(); - if (!ap_->IsReadyToAddCache(main_hash_key, cache_desc)) { - GELOGI("Not met the add cache condition with has key:%lu.", main_hash_key); - return KInvalidCacheItemId; - } - const auto cache_item = compile_cache_state_.AddCache(main_hash_key, cache_desc); - if (cache_item == KInvalidCacheItemId) { - GELOGE(GRAPH_FAILED, "[Check][Param] AddCache failed: please check the compile cache description."); - return KInvalidCacheItemId; - } - return cache_item; -} - -CacheItemId CachePolicy::FindCache(const CacheDescPtr &cache_desc) const { - if (mp_ == nullptr) { - GELOGW("match policy is nullptr"); - return KInvalidCacheItemId; - } - return mp_->GetCacheItemId(compile_cache_state_.GetState(), cache_desc); -} - -std::vector CachePolicy::DeleteCache(const DelCacheFunc &func) { - const auto delete_items = compile_cache_state_.DelCache(func); - GELOGI("[CachePolicy] [DeleteCache] Delete %zu CacheInfos.", delete_items.size()); - return delete_items; -} - -std::vector CachePolicy::DeleteCache(const std::vector &delete_item) { - const auto delete_items = compile_cache_state_.DelCache(delete_item); - GELOGI("[CachePolicy] [DeleteCache] Delete %zu CompileCacheInfo", delete_items.size()); - return delete_items; -} - -std::vector CachePolicy::DoAging() { - const auto delete_item = ap_->DoAging(compile_cache_state_); - (void)compile_cache_state_.DelCache(delete_item); - return delete_item; -} -} // namespace ge diff --git a/graph/cache_policy/cache_state.cc b/graph/cache_policy/cache_state.cc deleted file mode 100644 index baa8962167f07605b4a39d33be06f1bfae6b9113..0000000000000000000000000000000000000000 --- a/graph/cache_policy/cache_state.cc +++ /dev/null @@ -1,95 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/cache_policy/cache_state.h" -#include "common/ge_common/debug/ge_log.h" -namespace ge { -CacheItemId CacheState::GetNextCacheItemId() { - const std::lock_guard lock(cache_item_mu_); - if (cache_item_queue_.empty()) { - return cache_item_counter_++; - } else { - const CacheItemId next_item_id = cache_item_queue_.front(); - cache_item_queue_.pop(); - return next_item_id; - } -} - -void CacheState::RecoveryCacheItemId(const std::vector &cache_items) { - const std::lock_guard lock(cache_item_mu_); - for (auto &item_id : cache_items) { - cache_item_queue_.push(item_id); - } -} - -CacheItemId CacheState::AddCache(const CacheHashKey main_hash_key, const CacheDescPtr &cache_desc) { - const std::lock_guard lock(cache_info_queue_mu_); - const auto iter = cache_info_queue.cc_state_.find(main_hash_key); - if (iter == cache_info_queue.cc_state_.end()) { - const CacheItemId next_item_id = GetNextCacheItemId(); - const CacheInfo cache_info = CacheInfo(GetNextTimerCount(), next_item_id, cache_desc); - std::vector info = {cache_info}; - cache_info_queue.Insert(main_hash_key, info); - return next_item_id; - } - auto &cache_infos = iter->second; - for (auto &cache_info : cache_infos) { - if (cache_desc->IsEqual(cache_info.desc_)) { - cache_info.RefreshTimerCount(GetNextTimerCount()); - GELOGW("[AddCache] Same CacheDesc has already been added, whose cache_item is %" PRIu64, cache_info.item_id_); - return cache_info.item_id_; - } - } - // hash collision may happened - const CacheItemId next_item_id = GetNextCacheItemId(); - CacheInfo cache_info = CacheInfo(GetNextTimerCount(), next_item_id, cache_desc); - cache_info_queue.EmplaceBack(main_hash_key, cache_info); - return next_item_id; -} - -std::vector CacheState::DelCache(const DelCacheFunc &func) { - std::vector delete_item; - const std::lock_guard lock(cache_info_queue_mu_); - cache_info_queue.Erase(delete_item, func); - - RecoveryCacheItemId(delete_item); - return delete_item; -} - -std::vector CacheState::DelCache(const std::vector &delete_item) { - const DelCacheFunc lamb = [&delete_item] (const CacheInfo &info) -> bool { - const auto iter = std::find(delete_item.begin(), delete_item.end(), info.GetItemId()); - return iter != delete_item.end(); - }; - return DelCache(lamb); -} - -void CacheInfoQueue::Insert(const CacheHashKey main_hash_key, std::vector &cache_info) { - (void) cc_state_.insert({main_hash_key, std::move(cache_info)}); - ++cache_info_num_; -} -void CacheInfoQueue::EmplaceBack(const CacheHashKey main_hash_key, CacheInfo &cache_info) { - cc_state_[main_hash_key].emplace_back(std::move(cache_info)); - ++cache_info_num_; -} -void CacheInfoQueue::Erase(std::vector &delete_ids, const DelCacheFunc &is_need_delete_func) { - for (auto &item : cc_state_) { - std::vector &cache_vec = item.second; - for (auto iter = cache_vec.begin(); iter != cache_vec.end();) { - if (is_need_delete_func(*iter)) { - delete_ids.emplace_back((*iter).GetItemId()); - iter = cache_vec.erase(iter); - --cache_info_num_; - } else { - iter++; - } - } - } -} -} // namespace ge diff --git a/graph/cache_policy/compile_cache_desc.cc b/graph/cache_policy/compile_cache_desc.cc deleted file mode 100644 index e07a36c0bad56d58af8070daad2c7c092875dc32..0000000000000000000000000000000000000000 --- a/graph/cache_policy/compile_cache_desc.cc +++ /dev/null @@ -1,344 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/cache_policy/compile_cache_desc.h" -#include -#include "common/checker.h" -#include "debug/ge_util.h" - -namespace ge { -constexpr int32_t kAllShape = -2; - -BinaryHolder::BinaryHolder(const BinaryHolder &other) { - if ((other.GetDataPtr() != nullptr) && (other.GetDataLen() != 0UL)) { - data_len_ = other.GetDataLen(); - holder_ = ComGraphMakeUnique(data_len_); - GE_CHECK_NOTNULL_JUST_RETURN(holder_); - const auto mem_ret = memcpy_s(holder_.get(), data_len_, - ge::PtrToPtr(other.GetDataPtr()), data_len_); - if (mem_ret != EOK) { - data_len_ = 0U; - holder_ = nullptr; - GELOGE(ge::GRAPH_FAILED, "[BinaryHolder] memcpy failed."); - } - } -} - -BinaryHolder::BinaryHolder(const uint8_t *const data, const size_t data_len) { - if ((data != nullptr) && (data_len != 0UL)) { - data_len_ = data_len; - holder_ = ComGraphMakeUnique(data_len_); - GE_CHECK_NOTNULL_JUST_RETURN(holder_); - const auto mem_ret = memcpy_s(holder_.get(), data_len_, - ge::PtrToPtr(data), data_len_); - if (mem_ret != EOK) { - data_len_ = 0U; - holder_ = nullptr; - GELOGE(ge::GRAPH_FAILED, "[BinaryHolder] memcpy failed."); - } - } -} - -BinaryHolder &BinaryHolder::operator=(const BinaryHolder &other) { - if ((other.GetDataPtr() != nullptr) && (other.GetDataLen() != 0UL)) { - data_len_ = other.GetDataLen(); - holder_ = ComGraphMakeUnique(data_len_); - if (holder_ == nullptr) { - GELOGE(ge::GRAPH_FAILED, "[BinaryHolder] make unique failed."); - return *this; - } - const auto mem_ret = memcpy_s(holder_.get(), data_len_, - ge::PtrToPtr(other.GetDataPtr()), data_len_); - if (mem_ret != EOK) { - data_len_ = 0U; - holder_ = nullptr; - GELOGE(ge::GRAPH_FAILED, "[BinaryHolder] memcpy failed."); - } - } - return *this; -} - -BinaryHolder::BinaryHolder(BinaryHolder &&other) { - data_len_ = other.data_len_; - holder_ = std::move(other.holder_); - other.data_len_ = 0; -} - -BinaryHolder &BinaryHolder::operator=(BinaryHolder &&other) { - data_len_ = other.data_len_; - holder_ = std::move(other.holder_); - other.data_len_ = 0; - return *this; -} - -std::unique_ptr BinaryHolder::createFrom(std::unique_ptr &&ptr, size_t length) { - auto holder = ComGraphMakeUnique(); - if ((ptr != nullptr) && (holder != nullptr) && (length != 0UL)) { - holder->data_len_ = length; - holder->holder_ = std::move(ptr); - } - return holder; -} - -const uint8_t *BinaryHolder::GetDataPtr() const noexcept { - if (holder_.get() != nullptr) { - return holder_.get(); - } - return nullptr; -} - -const size_t &BinaryHolder::GetDataLen() const noexcept { - return data_len_; -} - -bool BinaryHolder::operator!=(const BinaryHolder &second) const { - if (this->GetDataLen() != second.GetDataLen()) { - return true; - } - const auto this_data = this->GetDataPtr(); - const auto second_data = second.GetDataPtr(); - if (((this_data == nullptr) && (second_data != nullptr)) || - ((this_data != nullptr) && (second_data == nullptr))) { - return true; - } - if ((this_data == nullptr) && (second_data == nullptr)) { - return false; - } - if (memcmp(this_data, second_data, this->GetDataLen()) != 0) { - return true; - } - return false; -} - -Format TensorInfoArgs::GetFormat() const { - return format_; -} - -Format TensorInfoArgs::GetOriginFormat() const { - return origin_format_; -} - -DataType TensorInfoArgs::GetDataType() const { - return data_type_; -} - -void TensorInfoArgs::SetShape(const std::vector &shape) { - shape_.clear(); - for (const auto dim : shape) { - shape_.emplace_back(dim); - } -} - -void TensorInfoArgs::SetShape(const SmallVector &shape) { - shape_.clear(); - shape_ = shape; -} - -void TensorInfoArgs::SetOriginShape(const std::vector &origin_shape) { - origin_shape_.clear(); - for (const auto dim : origin_shape) { - origin_shape_.emplace_back(dim); - } -} - -void TensorInfoArgs::SetOriginShape(const SmallVector &origin_shape) { - origin_shape_.clear(); - origin_shape_ = origin_shape; -} - -void TensorInfoArgs::SetShapeRange(const std::vector> &ranges) { - shape_range_.clear(); - for (const auto &range : ranges) { - shape_range_.emplace_back(range); - } -} - -bool TensorInfoArgs::IsUnknownShape() const { - return std::any_of(shape_.begin(), shape_.end(), [](const int64_t &dim) { - return (dim == UNKNOWN_DIM) || (dim == UNKNOWN_DIM_NUM); - }); -} - -bool TensorInfoArgs::operator!=(const TensorInfoArgs &second) const { - const bool ret = (this->format_ != second.format_) || (this->origin_format_ != second.origin_format_) || - (this->data_type_ != second.data_type_) || (this->shape_ != second.shape_) || - (this->origin_shape_ != second.origin_shape_) || (this->shape_range_ != second.shape_range_); - return ret; -} - -bool TensorInfoArgs::IsTensorInfoMatch(const TensorInfoArgs &other) const { - const bool is_same = (this->format_ == other.format_) && (this->origin_format_ == other.origin_format_) && - (this->data_type_ == other.data_type_); - if (!is_same) { - GELOGD("format or origin format or datatype is not matched"); - return false; - } - return IsShapeInRange(other); -} - -bool TensorInfoArgs::IsShapeInRange(const TensorInfoArgs &other) const { - if ((this->shape_.size() == 1U) && (this->shape_[0] == kAllShape)) { - // -2 is all shape, need to judge first - GELOGD("current shape is -2"); - return true; - } - // check rank - const bool is_same_rank = (this->shape_.size() == other.shape_.size()) && - (this->origin_shape_.size() == other.origin_shape_.size()); - if (!is_same_rank) { - GELOGD("shape or origin shape is not same rank"); - return false; - } - // check shape range when shape is dynamic - if (this->IsUnknownShape()) { - if (this->shape_.size() != this->shape_range_.size()) { - GELOGD("shape size %zu is not match shape rang size %zu", this->shape_.size(), this->shape_range_.size()); - return false; - } - for (size_t i = 0U; i < this->shape_range_.size(); ++i) { - if (this->shape_range_[i].first > other.shape_[i]) { - GELOGD("shape range is not match, first is %" PRId64 ", other is %" PRId64 ", index is %zu", - this->shape_range_[i].first, other.shape_[i], i); - return false; - } - // -1 means infinity great - if (this->shape_range_[i].second == UNKNOWN_DIM) { - GELOGD("shape second is -1, index is %zu", i); - continue; - } - if (this->shape_range_[i].second < other.shape_[i]) { - GELOGD("shape range is not match, second is %" PRId64 ", other is %" PRId64 ", index is %zu", - this->shape_range_[i].second, other.shape_[i], i); - return false; - } - } - } else { - GELOGD("this is exact shape"); - if ((this->shape_ != other.shape_) || (this->origin_shape_ != other.origin_shape_)) { - GELOGD("exact shape or origin shape is not matched"); - return false; - } - } - return true; -} - -size_t CompileCacheDesc::GetTensorInfoSize() { - return tensor_info_args_vec_.size(); -} - -TensorInfoArgs *CompileCacheDesc::MutableTensorInfo(size_t index) { - if (index >= tensor_info_args_vec_.size()) { - return nullptr; - } - return &tensor_info_args_vec_[index]; -} - -void CompileCacheDesc::AddBinary(const BinaryHolder &holder) { - other_desc_.emplace_back(holder); -} - -void CompileCacheDesc::AddBinary(BinaryHolder &&holder) { - other_desc_.emplace_back(holder); -} - -void CompileCacheDesc::SetOpType(const std::string &op_type) { - op_type_ = op_type; - return; -} - -void CompileCacheDesc::AddTensorInfo(const TensorInfoArgs &tensor_info) { - tensor_info_args_vec_.emplace_back(tensor_info); - return; -} - -void CompileCacheDesc::SetScopeId(const std::initializer_list scope_id) { - scope_id_= scope_id; - return; -} - -bool CompileCacheDesc::CheckWithoutTensorInfo(const CompileCacheDesc *first, const CompileCacheDesc *second) const { - if ((first->op_type_ != second->op_type_) || - (first->tensor_info_args_vec_.size() != second->tensor_info_args_vec_.size())) { - GELOGD("op_type_ %s, %s is not match or size %zu %zu is not match", - first->op_type_.c_str(), second->op_type_.c_str(), - first->tensor_info_args_vec_.size(), second->tensor_info_args_vec_.size()); - return false; - } - if (first->scope_id_ != second->scope_id_) { - GELOGD("scope id is not match"); - return false; - } - if (first->other_desc_.size() != second->other_desc_.size()) { - GELOGD("other_desc_ size %zu, %zu is not match ", first->other_desc_.size(), second->other_desc_.size()); - return false; - } - for (size_t i = 0U; i < first->other_desc_.size(); ++i) { - if (first->other_desc_[i].GetDataLen() != second->other_desc_[i].GetDataLen()) { - GELOGD("other_desc_ mem size %zu, %zu is not match ", - first->other_desc_[i].GetDataLen(), second->other_desc_[i].GetDataLen()); - return false; - } - if ((first->other_desc_[i].GetDataPtr() == nullptr) || (second->other_desc_[i].GetDataPtr() == nullptr)) { - return false; - } - const auto cmp_ret = memcmp(first->other_desc_[i].GetDataPtr(), - second->other_desc_[i].GetDataPtr(), second->other_desc_[i].GetDataLen()); - if (cmp_ret != 0) { - GELOGD("mem compare fail"); - return false; - } - } - return true; -} - -bool CompileCacheDesc::IsMatch(const CacheDescPtr &other) const { - const auto *second = dynamic_cast(other.get()); - GE_ASSERT_NOTNULL(second, "dynamic cast failed"); - if (!CheckWithoutTensorInfo(this, second)) { - return false; - } - - for (size_t i = 0U; i < this->tensor_info_args_vec_.size(); ++i) { - const auto &first_args = this->tensor_info_args_vec_[i]; - const auto &second_args = second->tensor_info_args_vec_[i]; - if (!first_args.IsTensorInfoMatch(second_args)) { - GELOGD("shape is not matched"); - return false; - } - } - return true; -} - -bool CompileCacheDesc::IsEqual(const CacheDescPtr &other) const { - const auto *second = dynamic_cast(other.get()); - GE_ASSERT_NOTNULL(second, "dynamic cast failed"); - if (!CheckWithoutTensorInfo(this, second)) { - return false; - } - - for (size_t i = 0U; i < this->tensor_info_args_vec_.size(); ++i) { - const auto &first_args = this->tensor_info_args_vec_[i]; - const auto &second_args = second->tensor_info_args_vec_[i]; - if (first_args != second_args) { - GELOGD("tensor info is not matched"); - return false; - } - } - return true; -} - -CacheHashKey CompileCacheDesc::GetCacheDescHash() const { - CacheHashKey hash_key = 0U; - for (const auto &arg : tensor_info_args_vec_) { - hash_key = HashUtils::MultiHash(hash_key, arg.GetFormat(), arg.GetOriginFormat(), arg.GetDataType()); - } - hash_key = HashUtils::MultiHash(op_type_, hash_key); - return hash_key; -} -} // namespace ge diff --git a/graph/cache_policy/policy_management/aging_policy/aging_policy_lru.cc b/graph/cache_policy/policy_management/aging_policy/aging_policy_lru.cc deleted file mode 100644 index dca2fa283c756f0db5df9a869eb5418b55408e13..0000000000000000000000000000000000000000 --- a/graph/cache_policy/policy_management/aging_policy/aging_policy_lru.cc +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/cache_policy/aging_policy_lru.h" -namespace ge { -std::vector AgingPolicyLru::DoAging(const CacheState &cache_state) const { - const auto &cc_state = cache_state.GetState(); - if (cache_state.GetCurTimerCount() <= delete_interval_) { - GELOGE(ge::PARAM_INVALID, "[Aging][Lru]Delete interval param is invalid. Delete interval is %lu, expect[0, %lu].", - delete_interval_, cache_state.GetCurTimerCount()); - return {}; - } - const uint64_t timer_count_lower_bound = cache_state.GetCurTimerCount() - delete_interval_; - std::vector delete_item; - for (const auto &cache_item : cc_state) { - const std::vector &cache_vec = cache_item.second; - for (auto iter = cache_vec.begin(); iter != cache_vec.end(); iter++) { - if ((*iter).GetTimerCount() < timer_count_lower_bound) { - delete_item.emplace_back((*iter).GetItemId()); - } - } - } - return delete_item; -} -} diff --git a/graph/cache_policy/policy_management/aging_policy/aging_policy_lru_k.cc b/graph/cache_policy/policy_management/aging_policy/aging_policy_lru_k.cc deleted file mode 100644 index eb7bb1ffe3f75315d32bf1a18bff0f897ddfb9a9..0000000000000000000000000000000000000000 --- a/graph/cache_policy/policy_management/aging_policy/aging_policy_lru_k.cc +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/cache_policy/aging_policy_lru_k.h" -namespace ge { -std::vector AgingPolicyLruK::DoAging(const CacheState &cache_state) const { - size_t cur_depth = cache_state.GetCacheInfoNum(); - const auto &cc_state = cache_state.GetState(); - GELOGD("[CACHE][AGING] current depth[%zu] cache queue capacity[%zu].", cur_depth, depth_); - if (cur_depth <= depth_) { - return {}; - } - std::pair delete_item({KInvalidCacheItemId, UINT64_MAX}); - for (const auto &each_cc_state : cc_state) { - for (const auto &cache_info : each_cc_state.second) { - if (cache_info.GetTimerCount() <= delete_item.second) { - delete_item = {cache_info.GetItemId(), cache_info.GetTimerCount()}; - } - } - } - if (delete_item.first == KInvalidCacheItemId) { - return {}; - } - return {delete_item.first}; -} - -bool AgingPolicyLruK::IsCacheDescAppearKTimes(const CacheHashKey hash_key, const CacheDescPtr &cache_desc) { - const std::lock_guard lock(hash_2_cache_descs_and_count_mu_); - if (hash_2_cache_descs_and_count_.count(hash_key) > 0U) { - auto &cache_descs_and_count = hash_2_cache_descs_and_count_[hash_key]; - for (auto &cache_desc_and_count : cache_descs_and_count) { - if (cache_desc->IsEqual(cache_desc_and_count.first)) { - ++cache_desc_and_count.second; - return cache_desc_and_count.second >= k_times_; - } - } - } - hash_2_cache_descs_and_count_[hash_key].emplace_back(std::make_pair(cache_desc, 1U)); - return false; -} -} // namespace ge diff --git a/graph/cache_policy/policy_management/match_policy/match_policy_exact_only.cc b/graph/cache_policy/policy_management/match_policy/match_policy_exact_only.cc deleted file mode 100644 index 51504cbdbdacc64951fb077499a00f6c7ddf1c4b..0000000000000000000000000000000000000000 --- a/graph/cache_policy/policy_management/match_policy/match_policy_exact_only.cc +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/cache_policy/match_policy_exact_only.h" -namespace ge { -CacheItemId MatchPolicyExactOnly::GetCacheItemId(const CCStatType &cc_state, const CacheDescPtr &desc) const { - const CacheHashKey hash_key = desc->GetCacheDescHash(); - const auto &iter = cc_state.find(hash_key); - if (iter == cc_state.end()) { - GELOGD("can not find without shape hash %lu", hash_key); - return KInvalidCacheItemId; - } - const auto &info_vec = iter->second; - const auto cached_info = std::find_if(info_vec.begin(), info_vec.end(), [&desc] (const CacheInfo &cached) { - return (cached.GetCacheDesc()->IsMatch(desc)); - }); - if (cached_info != info_vec.end()) { - return cached_info->GetItemId(); - } else { - return KInvalidCacheItemId; - } -} -} diff --git a/graph/cache_policy/policy_management/match_policy/match_policy_for_exactly_the_same.cc b/graph/cache_policy/policy_management/match_policy/match_policy_for_exactly_the_same.cc deleted file mode 100644 index 1b1687e0f262c8b4eab680ffa32a4e3d65215b41..0000000000000000000000000000000000000000 --- a/graph/cache_policy/policy_management/match_policy/match_policy_for_exactly_the_same.cc +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/cache_policy/match_policy_for_exactly_the_same.h" - -namespace ge { -CacheItemId MatchPolicyForExactlyTheSame::GetCacheItemId(const CCStatType &cc_state, - const CacheDescPtr &cache_desc) const { - const CacheHashKey hash_key = cache_desc->GetCacheDescHash(); - const auto &iter = cc_state.find(hash_key); - if (iter == cc_state.end() || iter->second.empty()) { - GELOGD("[CACHE] hash [%lu] does not exist.", hash_key); - return KInvalidCacheItemId; - } - const auto &info_vec = iter->second; - const auto cached_info = std::find_if(info_vec.begin(), info_vec.end(), [&cache_desc](const CacheInfo &cached) { - return (cache_desc->IsEqual(cached.GetCacheDesc())); - }); - if (cached_info != info_vec.cend()) { - return cached_info->GetItemId(); - } else { - GELOGD("[CACHE] hash [%lu] collision occurred, the same cached desc not found.", hash_key); - return KInvalidCacheItemId; - } -} -} // namespace ge diff --git a/graph/cache_policy/policy_register.cc b/graph/cache_policy/policy_register.cc deleted file mode 100644 index 6d0e9a55b116eb6baf5b3b3c3fa34e5c5254978c..0000000000000000000000000000000000000000 --- a/graph/cache_policy/policy_register.cc +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/cache_policy/policy_register.h" -namespace ge { -PolicyRegister &PolicyRegister::GetInstance() { - static PolicyRegister instance; - return instance; -} - -MatchPolicyRegister::MatchPolicyRegister(const MatchPolicyType match_policy_type, const MatchPolicyCreator &creator) { - PolicyRegister::GetInstance().RegisterMatchPolicy(match_policy_type, creator); -} - -AgingPolicyRegister::AgingPolicyRegister(const AgingPolicyType aging_policy_type, const AgingPolicyCreator &creator) { - PolicyRegister::GetInstance().RegisterAgingPolicy(aging_policy_type, creator); -} -} // namespace ge diff --git a/graph/common/hyper_status.cc b/graph/common/hyper_status.cc deleted file mode 100644 index 1efcd8569d84679f5f8a6108bd3542d2672a8071..0000000000000000000000000000000000000000 --- a/graph/common/hyper_status.cc +++ /dev/null @@ -1,96 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "common/hyper_status.h" - -#include -#include -#include - -namespace gert { -ge::char_t *CreateMessage(const ge::char_t *format, va_list arg) { - if (format == nullptr) { - return nullptr; - } - - va_list arg_copy; - va_copy(arg_copy, arg); - int len = vsnprintf(nullptr, 0, format, arg_copy); - va_end(arg_copy); - - if (len < 0) { - return nullptr; - } - - auto msg = std::unique_ptr(new (std::nothrow) ge::char_t[len + 1]); - if (msg == nullptr) { - return nullptr; - } - - auto ret = vsnprintf_s(msg.get(), len + 1, len, format, arg); - if (ret < 0) { - return nullptr; - } - - return msg.release(); -} -HyperStatus HyperStatus::Success() { - return {}; -} -HyperStatus::HyperStatus(const HyperStatus &other) : status_{nullptr} { - *this = other; -} -HyperStatus &HyperStatus::operator=(const HyperStatus &other) { - if (this == &other) { - return *this; - } - if (status_ != nullptr) { - delete [] status_; - status_ = nullptr; - } - if (other.status_ == nullptr) { - status_ = nullptr; - } else { - size_t status_len = strlen(other.status_) + 1; - status_ = new (std::nothrow) ge::char_t[status_len]; - if (status_ != nullptr) { - auto ret = strcpy_s(status_, status_len, other.status_); - if (ret != EOK) { - status_[0] = '\0'; - } - } - } - return *this; -} -HyperStatus::HyperStatus(HyperStatus &&other) noexcept { - status_ = other.status_; - other.status_ = nullptr; -} -HyperStatus &HyperStatus::operator=(HyperStatus &&other) noexcept { - if (this != &other) { - delete [] status_; - status_ = other.status_; - other.status_ = nullptr; - } - return *this; -} -HyperStatus HyperStatus::ErrorStatus(const ge::char_t *message, ...) { - HyperStatus status; - va_list arg; - va_start(arg, message); - status.status_ = CreateMessage(message, arg); - va_end(arg); - return status; -} -HyperStatus HyperStatus::ErrorStatus(std::unique_ptr message) { - HyperStatus status; - status.status_ = message.release(); - return status; -} -} diff --git a/graph/common/large_bm.cc b/graph/common/large_bm.cc deleted file mode 100644 index de9b9e9aadc4ab66b2a0dea49fcdd5b92091e2a2..0000000000000000000000000000000000000000 --- a/graph/common/large_bm.cc +++ /dev/null @@ -1,104 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "common/large_bm.h" -#include "graph/debug/ge_log.h" - -namespace ge { -constexpr size_t kBitsEachValue = 64UL; - -constexpr size_t AlignBitSize(size_t bit_size) { - return bit_size + kBitsEachValue - 1; -} - -constexpr size_t AlignArraySize(size_t bit_size) { - return AlignBitSize(bit_size) >> 6; -} - -void LargeBitmap::ResizeBits(size_t new_size) { - if (new_size < size_) { - return; - } - - size_t new_byte_size = AlignArraySize(new_size); - if (new_byte_size == AlignArraySize(size_)) { - size_ = new_size; - return; - } - - this->bits_.resize(new_byte_size, 0); - for (size_t i = size_; i < AlignBitSize(size_); ++i) { - ClearBit(i); - } - - size_ = new_size; -} - -// Shifting right by 6 bits is equivalent to dividing by 64 -void LargeBitmap::ClearBit(size_t bit_idx) { - bits_[bit_idx >> 6] &= ~(1UL << (bit_idx % kBitsEachValue)); -} - -LargeBitmap::LargeBitmap(const size_t &size) - : size_(size), bits_(AlignArraySize(size), 0UL) {} - -bool LargeBitmap::operator==(const LargeBitmap &another_bm) const { - return bits_ == another_bm.bits_; -} - -bool LargeBitmap::operator!=(const LargeBitmap &another_bm) const { - return bits_ != another_bm.bits_; -} - -void LargeBitmap::SetValues(const uint64_t &value) { - std::fill(bits_.begin(), bits_.end(), value); -} - -void LargeBitmap::SetBit(const size_t &index) { - if (index < size_) { - bits_[index / kBitsEachValue] |= 1UL << (index % kBitsEachValue); - } else { - GE_LOGE("index %zu is not valid. Total size is %zu", index, size_); - return; - } -} - -bool LargeBitmap::GetBit(const size_t &index) const { - if (index < size_) { - return static_cast(bits_[index / kBitsEachValue] & (1UL << (index % kBitsEachValue))); - } else { - GE_LOGE("index %zu is not valid. Total size is %zu", index, size_); - return false; - } -} - -void LargeBitmap::Or(const LargeBitmap &another_bm) { - size_t index = 0UL; - const size_t another_size = another_bm.bits_.size(); - for (auto &bit : bits_) { - if (index >= another_size) { - return; - } - bit |= another_bm.bits_[index]; - ++index; - } -} - -void LargeBitmap::And(const LargeBitmap &another_bm) { - size_t index = 0UL; - const size_t another_size = another_bm.bits_.size(); - for (auto &bit : bits_) { - if (index >= another_size) { - return; - } - bit &= another_bm.bits_[index]; - ++index; - } -} -} diff --git a/graph/context/inference_context.cc b/graph/context/inference_context.cc deleted file mode 100644 index 75a1529ebce9aa7cae9f3c3df308ad050b7fc397..0000000000000000000000000000000000000000 --- a/graph/context/inference_context.cc +++ /dev/null @@ -1,182 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "external/graph/inference_context.h" -#include "debug/ge_util.h" -#include "debug/ge_log.h" -#include "graph/ge_context.h" -#include "graph/resource_context_mgr.h" - -namespace ge { -class ShapeAndTypeImpl { - public: - ShapeAndTypeImpl() = default; - ~ShapeAndTypeImpl() = default; - - ShapeAndTypeImpl(const Shape &shape, const DataType data_type) : shape_(shape), data_type_(data_type) {} - - private: - Shape shape_; - DataType data_type_ = DT_UNDEFINED; - - friend class ShapeAndType; -}; - -struct InnerInferenceContext { -private: - // For deliver to op in pair, help to support dynamic shape - std::vector marks; - std::vector> input_handle_shapes_and_types; - std::vector> output_handle_shapes_and_types; - // For write op , if reousce changed, push to here - std::set changed_resource_keys; - // For read op, register relied resource - std::set relied_resource_keys; - ResourceContextMgr *resource_context_mgr = nullptr; - - friend class InferenceContext; -}; - -ShapeAndType::ShapeAndType() { shape_and_type_impl_ = ComGraphMakeShared(); } - -ShapeAndType::ShapeAndType(const Shape &shape, DataType data_type) { - shape_and_type_impl_ = ComGraphMakeShared(shape, data_type); -} - -void ShapeAndType::SetShape(const Shape &shape) { - if (shape_and_type_impl_ != nullptr) { - shape_and_type_impl_->shape_ = shape; - } -} - -void ShapeAndType::SetType(DataType data_type) { - if (shape_and_type_impl_ != nullptr) { - shape_and_type_impl_->data_type_ = data_type; - } -} - -Shape ShapeAndType::GetShape() const { - if (shape_and_type_impl_ != nullptr) { - return shape_and_type_impl_->shape_; - } - return Shape(); -} - -DataType ShapeAndType::GetDataType() const { - if (shape_and_type_impl_ != nullptr) { - return shape_and_type_impl_->data_type_; - } - return DT_UNDEFINED; -} - -InferenceContext::InferenceContext(std::unique_ptr &inner_context) { - inner_inference_context_ = std::move(inner_context); -} - -std::unique_ptr InferenceContext::Create(void *resource_context_mgr) { - std::unique_ptr inner_context = ComGraphMakeUnique(); - if (inner_context == nullptr) { - return nullptr; - } - inner_context->resource_context_mgr = PtrToPtr(resource_context_mgr); - - return std::unique_ptr(new (std::nothrow) InferenceContext(inner_context)); -} - -void InferenceContext::SetInputHandleShapesAndTypes(std::vector> &&shapes_and_types) { - inner_inference_context_->input_handle_shapes_and_types.swap(shapes_and_types); -} - -const std::vector> &InferenceContext::GetInputHandleShapesAndTypes() const { - return inner_inference_context_->input_handle_shapes_and_types; -} - -const std::vector> &InferenceContext::GetOutputHandleShapesAndTypes() const { - return inner_inference_context_->output_handle_shapes_and_types; -} - -void InferenceContext::SetOutputHandleShapesAndTypes(const std::vector> &shapes_and_types) { - inner_inference_context_->output_handle_shapes_and_types = shapes_and_types; -} - -void InferenceContext::SetOutputHandleShapesAndTypes(std::vector> &&shapes_and_types) { - inner_inference_context_->output_handle_shapes_and_types.swap(shapes_and_types); -} - -void InferenceContext::SetMarks(const std::vector &marks) { inner_inference_context_->marks = marks; } - -void InferenceContext::SetMarks(const std::vector &marks) { - std::vector impl_marks; - for (const auto &mark : marks) { - if (mark.GetString() != nullptr) { - impl_marks.emplace_back(mark.GetString()); - } - } - inner_inference_context_->marks = impl_marks; -} - -const std::vector &InferenceContext::GetMarks() const { return inner_inference_context_->marks; } - -void InferenceContext::GetMarks(std::vector &marks) const { - for (auto &str_mark : inner_inference_context_->marks) { - marks.emplace_back(str_mark.c_str()); - } -} - -ResourceContext *InferenceContext::GetResourceContext(const ge::AscendString &key) { - if (inner_inference_context_->resource_context_mgr == nullptr) { - return nullptr; - } - return inner_inference_context_->resource_context_mgr->GetResourceContext(key.GetString()); -} - -graphStatus InferenceContext::SetResourceContext(const ge::AscendString &key, ResourceContext *resource_context) { - if (std::string(key.GetString()).empty()) { - GELOGE(GRAPH_PARAM_INVALID, "Resource key is null, invalid param."); - return GRAPH_PARAM_INVALID; - } - if (inner_inference_context_->resource_context_mgr == nullptr) { - GELOGE(GRAPH_FAILED, "No resource context mgr exist, resource context can not deliver in graph." - "Please check session initialized success or not."); - return GRAPH_FAILED; - } - (void)inner_inference_context_->resource_context_mgr->SetResourceContext(key.GetString(), resource_context); - return GRAPH_SUCCESS; -} - -graphStatus InferenceContext::AddChangedResourceKey(const ge::AscendString &key) { - if (std::string(key.GetString()).empty()) { - GELOGE(GRAPH_PARAM_INVALID, "Resource key is null, invalid param."); - return GRAPH_PARAM_INVALID; - } - (void)inner_inference_context_->changed_resource_keys.insert(key.GetString()); - return GRAPH_SUCCESS; -} - -void InferenceContext::ClearChangedResourceKeys() { - inner_inference_context_->changed_resource_keys.clear(); -} - -const std::set &InferenceContext::GetChangedResourceKeys() const { - return inner_inference_context_->changed_resource_keys; -} - -graphStatus InferenceContext::RegisterReliedOnResourceKey(const ge::AscendString &key) { - if (std::string(key.GetString()).empty()) { - GELOGE(GRAPH_PARAM_INVALID, "Resource key is null, invalid param."); - return GRAPH_PARAM_INVALID; - } - (void)inner_inference_context_->relied_resource_keys.insert(key.GetString()); - return GRAPH_SUCCESS; -} - -const std::set &InferenceContext::GetReliedOnResourceKeys() const { - return inner_inference_context_->relied_resource_keys; -} -} // namespace ge diff --git a/graph/context/resource_context_mgr.cc b/graph/context/resource_context_mgr.cc deleted file mode 100644 index 1a3e510aad7ed64db73b1bb6ffdc85b3261a5466..0000000000000000000000000000000000000000 --- a/graph/context/resource_context_mgr.cc +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/resource_context_mgr.h" - -namespace ge { -ResourceContext *ResourceContextMgr::GetResourceContext(const std::string &resource_key) { - const std::lock_guard lk(ctx_mu_); - const std::map>::const_iterator - iter = resource_keys_to_contexts_.find(resource_key); - if (iter == resource_keys_to_contexts_.cend()) { - return nullptr; - } - return resource_keys_to_contexts_[resource_key].get(); -} - -graphStatus ResourceContextMgr::SetResourceContext(const std::string &resource_key, ResourceContext *const context) { - const std::lock_guard lk(ctx_mu_); - resource_keys_to_contexts_[resource_key] = std::unique_ptr(context); - return GRAPH_SUCCESS; -} - -graphStatus ResourceContextMgr::RegisterNodeReliedOnResource(const std::string &resource_key, NodePtr &node) { - const std::lock_guard lk(ctx_mu_); - (void)resource_keys_to_read_nodes_[resource_key].emplace(node); - return GRAPH_SUCCESS; -} - -OrderedNodeSet &ResourceContextMgr::MutableNodesReliedOnResource(const std::string &resource_key) { - const std::lock_guard lk(ctx_mu_); - return resource_keys_to_read_nodes_[resource_key]; -} - -graphStatus ResourceContextMgr::ClearContext() { - const std::lock_guard lk_resource(ctx_mu_); - resource_keys_to_contexts_.clear(); - resource_keys_to_read_nodes_.clear(); - return GRAPH_SUCCESS; -} -} // namespace ge diff --git a/graph/context/runtime_inference_context.cc b/graph/context/runtime_inference_context.cc deleted file mode 100644 index c9e20b2f1150aefdd75e3f0454a304228de4c025..0000000000000000000000000000000000000000 --- a/graph/context/runtime_inference_context.cc +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include - -#include "graph/utils/tensor_adapter.h" -#include "common/ge_common/debug/ge_log.h" -#include "graph/runtime_inference_context.h" - -namespace ge { -void RuntimeInferenceContext::Release() { - const std::lock_guard lk(mu_); - ge_tensors_.clear(); -} - -graphStatus RuntimeInferenceContext::SetTensor(int64_t node_id, int32_t output_id, GeTensorPtr tensor) { - const std::lock_guard lk(mu_); - auto &output_ge_tensors = ge_tensors_[node_id]; - if (static_cast(output_id) >= output_ge_tensors.size()) { - const size_t output_tensor_size = static_cast(output_id) + 1U; - output_ge_tensors.resize(output_tensor_size); - } - - GELOGD("Set tensor for node_id = %" PRId64 ", output_id = %" PRId32, node_id, output_id); - output_ge_tensors[static_cast(output_id)] = std::move(tensor); - - return GRAPH_SUCCESS; -} - -graphStatus RuntimeInferenceContext::GetTensor(const int64_t node_id, int32_t output_id, GeTensorPtr &tensor) const { - if (output_id < 0) { - REPORT_INNER_ERR_MSG("E18888", "Invalid output index: %d", output_id); - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] Invalid output index: %d", output_id); - return GRAPH_PARAM_INVALID; - } - - const std::lock_guard lk(mu_); - const auto iter = ge_tensors_.find(node_id); - if (iter == ge_tensors_.end()) { - GELOGW("Node not register. Id = %" PRId64, node_id); - return INTERNAL_ERROR; - } - - auto &output_tensors = iter->second; - if (static_cast(output_id) >= output_tensors.size()) { - GELOGW("The %" PRId32 " th output tensor for node id [%" PRId64 "] has not been registered.", output_id, node_id); - return GRAPH_FAILED; - } - - GELOGD("Get ge tensor for node_id = %" PRId64 ", output_id = %" PRId32, node_id, output_id); - tensor = output_tensors[static_cast(output_id)]; - if (tensor == nullptr) { - GELOGW("The %" PRId32 " th output tensor registered for node id [%" PRId64 "] is nullptr.", output_id, node_id); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} -} // namespace ge diff --git a/graph/detail/attributes_holder.cc b/graph/detail/attributes_holder.cc deleted file mode 100644 index fb28d06aa1050da2a2924510423240eddcd80bf0..0000000000000000000000000000000000000000 --- a/graph/detail/attributes_holder.cc +++ /dev/null @@ -1,199 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/detail/attributes_holder.h" - -#include "graph/debug/ge_log.h" -#include "graph/debug/ge_util.h" -#include "graph/ge_attr_value.h" -#include "proto/ge_ir.pb.h" - -namespace ge { -void AttrHolder::CopyAttrsFrom(const AttrHolder &holder) { - MutableAttrMap() = holder.GetAttrMap(); -} -void AttrHolder::CopyFrom(const AttrHolder &holder) { - required_attrs_and_type_ = holder.required_attrs_and_type_; - ext_attrs_ = holder.ext_attrs_; -} - -graphStatus AttrHolder::SetAttr(const std::string &name, const AnyValue &value) { - if (value.IsEmpty()) { - REPORT_INNER_ERR_MSG("E18888", "param value is empty, check invalid, key of the attr:%s", name.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] value is empty, key of the attr is %s", name.c_str()); - return GRAPH_FAILED; - } - if (!MutableAttrMap().SetAnyValueByName(name, value)) { - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} -graphStatus AttrHolder::TrySetAttr(const std::string &name, const AnyValue &value) { - if (value.IsEmpty()) { - REPORT_INNER_ERR_MSG("E18888", "param value is empty, check invalid, key of the attr:%s", name.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] value is empty, key of the attr is %s", name.c_str()); - return GRAPH_FAILED; - } - if (MutableAttrMap().Exists(name)) { - GELOGW("attr %s already existed, skip update", name.c_str()); - } else { - if (!MutableAttrMap().SetAnyValueByName(name, value)) { - return GRAPH_FAILED; - } - } - return GRAPH_SUCCESS; -} -graphStatus AttrHolder::AddRequiredAttr(const std::string &name) { - return AddRequiredAttrWithType(name, ""); -} - -graphStatus AttrHolder::AddRequiredAttrWithType(const std::string &name, const std::string &type) { - if (HasAttr(name)) { - return GRAPH_FAILED; - } - required_attrs_and_type_.emplace(name, type); - return GRAPH_SUCCESS; -} - -graphStatus AttrHolder::GetAttr(const std::string &name, AnyValue &value) const { - const auto av = GetAttrMap().GetAnyValue(name); - if (av == nullptr) { - return GRAPH_FAILED; - } - value = *av; - return GRAPH_SUCCESS; -} - -bool AttrHolder::HasAttr(const std::string &name) const { - if (GetAttrMap().Exists(name)) { - return true; - } - return required_attrs_and_type_.find(name) != required_attrs_and_type_.end(); -} - -graphStatus AttrHolder::DelAttr(const std::string &name) { - return MutableAttrMap().Delete(name) ? GRAPH_SUCCESS : GRAPH_FAILED; -} - -const std::map AttrHolder::GetAllAttrs() const { - return GetAttrMap().GetAllAttrs(); -} - -const std::map AttrHolder::GetAllAttrsWithFilter(const AttrNameFilter &attr_filter) const { - return GetAttrMap().GetAllAttrsWithFilter(attr_filter); -} - -const std::set AttrHolder::GetAllAttrNames() const { - return GetAttrMap().GetAllAttrNames(); -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "create AttrDef failed."); - GELOGE(GRAPH_FAILED, "[Create][AttrDef] proto::AttrDef make shared failed"); - return; - } - protoMsg_ = proto_owner.get(); - protoOwner_ = proto_owner; -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "create TensorDef failed."); - GELOGE(GRAPH_FAILED, "[Create][TensorDef] proto::TensorDef make shared failed"); - return; - } - protoMsg_ = proto_owner.get(); - protoOwner_ = proto_owner; -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "create TensorDescriptor failed."); - GELOGE(GRAPH_FAILED, "[Create][TensorDescriptor] proto::TensorDescriptor make shared failed"); - return; - } - protoMsg_ = proto_owner.get(); - protoOwner_ = proto_owner; -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "create ShapeDef failed."); - GELOGE(GRAPH_FAILED, "[Create][ShapeDef] proto::ShapeDef make shared failed"); - return; - } - protoMsg_ = proto_owner.get(); - protoOwner_ = proto_owner; -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "create NamedAttrs failed."); - GELOGE(GRAPH_FAILED, "[Create][NamedAttrs] proto::NamedAttrs make shared failed"); - return; - } - protoMsg_ = proto_owner.get(); - protoOwner_ = proto_owner; -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "create ModelDef failed."); - GELOGE(GRAPH_FAILED, "[Create][ModelDef] proto::ModelDef make shared failed"); - return; - } - protoMsg_ = proto_owner.get(); - protoOwner_ = proto_owner; -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "create OpDef failed."); - GELOGE(GRAPH_FAILED, "[Create][OpDef] proto::OpDef make shared failed"); - return; - } - protoMsg_ = proto_owner.get(); - protoOwner_ = proto_owner; -} - -template <> -void GeIrProtoHelper::InitDefault() { - std::shared_ptr proto_owner; - proto_owner = ComGraphMakeShared(); - if (proto_owner == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "create GraphDef failed."); - GELOGE(GRAPH_FAILED, "[Create][GraphDef] proto::GraphDef make shared failed"); - return; - } - protoMsg_ = proto_owner.get(); - protoOwner_ = proto_owner; -} -} // namespace ge diff --git a/graph/expression/CMakeLists.txt b/graph/expression/CMakeLists.txt deleted file mode 100644 index ab3b4b719f8855b3b2edaee47b01035361c63574..0000000000000000000000000000000000000000 --- a/graph/expression/CMakeLists.txt +++ /dev/null @@ -1,65 +0,0 @@ -add_library(aihac_symbolizer SHARED - expression.cc - expression_impl.cc - expr_print_manager.cc - symbol_operator.cc - symbolic_utils.cc - symbol_checker.cc - expr_parser.cc - scanner.cc - attr_group_shape_env_attr.cc - attr_group_symbolic_desc_attr.cc - guard_dfx_context.cc -) -target_compile_options(aihac_symbolizer PRIVATE -DNO_METADEF_ABI_COMPATIABLE -O2 -Werror) - -target_compile_options(aihac_symbolizer PRIVATE - $<$,$>:-fexceptions> - $<$,$>: -fno-common -Wextra -Wfloat-equal> - $<$,$>:/MTd> - $<$,$>:/MT>) - -target_compile_definitions(aihac_symbolizer PRIVATE - $<$,$>:FMK_SUPPORT_DUMP> - $<$:ONLY_COMPILE_OPEN_SRC> - google=ascend_private - $,OS_TYPE=WIN,OS_TYPE=0> - $<$:SECUREC_USING_STD_SECURE_LIB=0 NOMINMAX> -) - -target_include_directories(aihac_symbolizer PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${CMAKE_BINARY_DIR} - ${METADEF_DIR} - ${CMAKE_BINARY_DIR}/proto/metadef_protos -) - -target_link_options(aihac_symbolizer PRIVATE - -Wl,-Bsymbolic -) - -target_link_libraries(aihac_symbolizer - PRIVATE - intf_pub - static_mmpa - -Wl,--no-as-needed - c_sec - slog - json - platform - symengine - error_manager - graph_base - graph - slog - Boost::boost - -Wl,--as-needed - ascend_protobuf_shared_headers - ascend_protobuf - $<$>:-lrt> - -ldl - PUBLIC - metadef_headers -) -set_target_properties(aihac_symbolizer PROPERTIES - CXX_EXTENSIONS NO) diff --git a/graph/expression/attr_group_shape_env_attr.cc b/graph/expression/attr_group_shape_env_attr.cc deleted file mode 100644 index e955f2ff90fe0fe60c5086bbd2ce5c6dbdf011b3..0000000000000000000000000000000000000000 --- a/graph/expression/attr_group_shape_env_attr.cc +++ /dev/null @@ -1,472 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "attribute_group/attr_group_shape_env.h" -#include "graph/debug/ge_util.h" -#include "graph/detail/attributes_holder.h" -#include "proto/ge_ir.pb.h" -#include "graph/expression/expression_impl.h" -#include "graph/utils/type_utils.h" - -namespace ge { -namespace { -static thread_local ShapeEnvAttr *shape_env_context{nullptr}; -static std::map kGeDType2CppDtype = { - {ge::DT_INT32, "int32_t"}, - {ge::DT_INT64, "int64_t"}, - {ge::DT_UINT32, "uint32_t"}, - {ge::DT_UINT64, "uint64_t"}, -}; -} - -thread_local std::string ShapeEnvAttr::guard_dfx_info_ = ""; -ShapeEnvAttr *GetCurShapeEnvContext() { - return shape_env_context; -} - -void SetCurShapeEnvContext(ShapeEnvAttr *shape_env) { - shape_env_context = shape_env; -} - -std::string Source::GetGlobalIndexStr() const { - return "context->GetInputPointer(" + std::to_string(global_index_) + ")"; -} - -graphStatus ShapeEnvAttr::SerializeSymbolInfo(proto::ShapeEnvAttrGroupsDef *shape_env_group) { - GE_ASSERT_NOTNULL(shape_env_group); - shape_env_group->clear_symbol_to_value(); - auto symbol_to_value_def = shape_env_group->mutable_symbol_to_value(); - GE_ASSERT_NOTNULL(symbol_to_value_def); - GELOGI("symbol_to_value size: %zu", symbol_to_value_.size()); - for (const auto &iter : symbol_to_value_) { - GE_ASSERT_TRUE(!iter.first.IsConstExpr(), - "Symbol in symbol_to_value of shape env attr should be a variable, but get: %s", - iter.first.Serialize().get()); - symbol_to_value_def->insert({iter.first.Serialize().get(), iter.second}); - } - auto value_to_symbol_def = shape_env_group->mutable_value_to_symbol(); - GE_ASSERT_NOTNULL(value_to_symbol_def); - for (const auto &iter : value_to_symbol_) { - GE_ASSERT_TRUE(!iter.second.empty()); - proto::SymbolInfoDef symbol_infos; - for (const auto &sym_iter : iter.second) { - GE_ASSERT_TRUE(!sym_iter.IsConstExpr(), - "Symbol in value_to_symbol of shape env attr should be a variable, but get: %s", - sym_iter.Serialize().get()); - symbol_infos.add_symbols(sym_iter.Serialize().get()); - } - value_to_symbol_def->insert({iter.first, symbol_infos}); - } - - auto symbol_to_source_def = shape_env_group->mutable_symbol_to_source(); - GE_ASSERT_NOTNULL(symbol_to_source_def); - // todoo: symbol_to_source_实现序列化 - return GRAPH_SUCCESS; -} - -graphStatus ShapeEnvAttr::SerializeSymbolCheckInfos(proto::ShapeEnvAttrGroupsDef *shape_env_group) { - GE_ASSERT_NOTNULL(shape_env_group); - auto replacements_def = shape_env_group->mutable_replacements(); - for (const auto &iter : replacements_) { - proto::ReplacementDef rep_def; - rep_def.set_replace_expr(iter.second.replace_expr.Serialize().get()); - rep_def.set_rank(iter.second.rank); - replacements_def->insert({iter.first.Serialize().get(), rep_def}); - } - shape_env_group->clear_symbol_check_infos(); - for (const auto &iter : symbol_check_infos_) { - proto::SymbolCheckInfoDef *symbol_check_info_def = shape_env_group->add_symbol_check_infos(); - symbol_check_info_def->set_expr(iter.expr.Serialize().get()); - symbol_check_info_def->set_file(iter.file); - symbol_check_info_def->set_line(iter.line); - symbol_check_info_def->set_dfx(iter.dfx_info); - } - shape_env_group->clear_symbol_assert_infos(); - for (const auto &iter : symbol_assert_infos_) { - proto::SymbolCheckInfoDef *symbol_assert_info_def = shape_env_group->add_symbol_assert_infos(); - symbol_assert_info_def->set_expr(iter.expr.Serialize().get()); - symbol_assert_info_def->set_file(iter.file); - symbol_assert_info_def->set_line(iter.line); - symbol_assert_info_def->set_dfx(iter.dfx_info); - } - return GRAPH_SUCCESS; -} - -ShapeEnvAttr::ShapeEnvAttr(const ShapeEnvAttr& other) { - shape_env_setting_ = other.shape_env_setting_; - symbol_to_value_ = other.symbol_to_value_; - value_to_symbol_ = other.value_to_symbol_; - symbol_to_source_ = other.symbol_to_source_; - replacements_ = other.replacements_; - symbol_check_infos_ = other.symbol_check_infos_; - symbol_assert_infos_ = other.symbol_assert_infos_; - unique_sym_id_ = other.unique_sym_id_; -} - -ShapeEnvAttr& ShapeEnvAttr::operator=(const ShapeEnvAttr& other) { - if (this != &other) { - shape_env_setting_ = other.shape_env_setting_; - symbol_to_value_ = other.symbol_to_value_; - value_to_symbol_ = other.value_to_symbol_; - symbol_to_source_ = other.symbol_to_source_; - replacements_ = other.replacements_; - symbol_check_infos_ = other.symbol_check_infos_; - symbol_assert_infos_ = other.symbol_assert_infos_; - unique_sym_id_ = other.unique_sym_id_; - } - return *this; -} - -graphStatus ShapeEnvAttr::Serialize(proto::AttrGroupDef &attr_group_def) { - auto shape_env_group = attr_group_def.mutable_shape_env_attr_group(); - GE_ASSERT_SUCCESS(SerializeSymbolInfo(shape_env_group)); - GE_ASSERT_SUCCESS(SerializeSymbolCheckInfos(shape_env_group)); - proto::ShapeEnvSettingDef *shape_env_setting_def = shape_env_group->mutable_shape_setting(); - shape_env_setting_def->set_specialize_zero_one(shape_env_setting_.specialize_zero_one); - shape_env_setting_def->set_dynamic_mode(static_cast(shape_env_setting_.dynamic_mode)); - shape_env_group->set_unique_sym_id(unique_sym_id_); - return GRAPH_SUCCESS; -} - -graphStatus ShapeEnvAttr::DeserializeSymbolInfo(const proto::ShapeEnvAttrGroupsDef &shape_env_group) { - symbol_to_value_.clear(); - GELOGI("symbol_to_value size: %zu", shape_env_group.symbol_to_value_size()); - for (const auto &iter : shape_env_group.symbol_to_value()) { - Expression sym = Expression::Deserialize(iter.first.c_str()); - GE_ASSERT_TRUE(!sym.IsConstExpr(), - "Symbol in symbol_to_value of shape env attr should be a variable, but get: %s", - iter.first.c_str()); - symbol_to_value_.emplace(std::make_pair(sym, iter.second)); - } - value_to_symbol_.clear(); - for (const auto &iter : shape_env_group.value_to_symbol()) { - std::vector symbol_infos; - for (const auto &sym_iter : iter.second.symbols()) { - Expression sym = Expression::Deserialize(sym_iter.c_str()); - GE_ASSERT_TRUE(!sym.IsConstExpr(), - "Symbol in value_to_symbol of shape env attr should be a variable, but get: %s", - sym_iter.c_str()); - symbol_infos.emplace_back(sym); - } - value_to_symbol_.emplace(std::make_pair(iter.first, symbol_infos)); - } - symbol_to_source_.clear(); - // todoo: symbol_to_source_实现反序列化 - return GRAPH_SUCCESS; -} - -graphStatus ShapeEnvAttr::DeserializeSymbolCheckInfos(const proto::ShapeEnvAttrGroupsDef &shape_env_group) { - replacements_.clear(); - for (const auto &iter : shape_env_group.replacements()) { - Expression expr = Expression::Deserialize(iter.first.c_str()); - Expression replace_expr = Expression::Deserialize(iter.second.replace_expr().c_str()); - replacements_.emplace(std::make_pair(expr, Replacement(replace_expr, iter.second.rank()))); - } - symbol_check_infos_.clear(); - for (const auto &iter : shape_env_group.symbol_check_infos()) { - Expression expr = Expression::Deserialize(iter.expr().c_str()); - symbol_check_infos_.emplace(SymbolCheckInfo(expr, iter.file(), iter.line(), iter.dfx())); - } - symbol_assert_infos_.clear(); - for (const auto &iter : shape_env_group.symbol_assert_infos()) { - Expression expr = Expression::Deserialize(iter.expr().c_str()); - symbol_assert_infos_.emplace(SymbolCheckInfo(expr, iter.file(), iter.line(), iter.dfx())); - } - return GRAPH_SUCCESS; -} - -graphStatus ShapeEnvAttr::Deserialize(const proto::AttrGroupDef &attr_group_def, AttrHolder *attr_holder) { - (void) attr_holder; - const auto& shape_env_group = attr_group_def.shape_env_attr_group(); - DeserializeSymbolInfo(shape_env_group); - DeserializeSymbolCheckInfos(shape_env_group); - shape_env_setting_ = - ShapeEnvSetting(shape_env_group.shape_setting().specialize_zero_one(), - static_cast(shape_env_group.shape_setting().dynamic_mode())); - unique_sym_id_ = shape_env_group.unique_sym_id(); - return GRAPH_SUCCESS; -} - -std::unique_ptr ShapeEnvAttr::Clone() { - std::unique_ptr new_attr = ComGraphMakeUnique(*this); - GE_ASSERT_NOTNULL(new_attr); - return new_attr; -} - -bool ShapeEnvAttr::HasSymbolCheckInfo(const ge::Expression &e) const { - auto expr = e.CanonicalizeBoolExpr(); - if (symbol_check_infos_.find(SymbolCheckInfo(expr)) != symbol_check_infos_.end()) { - return true; - } - return false; -} - -bool ShapeEnvAttr::HasSymbolAssertInfo(const ge::Expression &e) const { - auto expr = e.CanonicalizeBoolExpr(); - if (symbol_assert_infos_.find(SymbolCheckInfo(expr)) != symbol_assert_infos_.end()) { - return true; - } - return false; -} - -ge::Expression ShapeEnvAttr::FindReplacements(const ge::Expression &expr) { - auto iter = replacements_.find(expr); - if (iter == replacements_.end()) { - return expr; - } - if (iter->second.has_replace) { - GELOGD("Find replace expr: %s of expr: %s has replace", - iter->second.replace_expr.Str().get(), expr.Str().get()); - return expr; - } - auto replace_expr = iter->second.replace_expr; - GELOGD("Find replace expr: %s of expr: %s", - replace_expr.Str().get(), expr.Str().get()); - if (replace_expr == expr) { - return expr; - } - std::vector> var_replacements; - iter->second.has_replace = true; - for (auto &sym : replace_expr.FreeSymbols()) { - auto replace_sym = FindReplacements(sym); - var_replacements.emplace_back(std::make_pair(sym, replace_sym)); - } - iter->second.has_replace = false; - return replace_expr.Replace(var_replacements); -} - -const std::vector ShapeEnvAttr::GetAllSymbolCheckInfos() const { - std::vector results; - for (const auto &iter : symbol_check_infos_) { - results.emplace_back(iter); - } - return results; -} - -const std::vector ShapeEnvAttr::GetAllSymbolAssertInfos() const { - std::vector results; - for (const auto &iter : symbol_assert_infos_) { - results.emplace_back(iter); - } - return results; -}; - -void ShapeEnvAttr::SimplifySymbolCheckInfo( - std::set &symbol_check_infos) { - std::vector simplify_symbol_check_info; - for (auto &iter : symbol_check_infos) { - const auto simple_expr = iter.expr.Simplify().CanonicalizeBoolExpr(); - if (simple_expr.IsConstExpr()) { - continue; - } - (void)simplify_symbol_check_info.emplace_back(SymbolCheckInfo(simple_expr)); - } - (void)symbol_check_infos.insert(simplify_symbol_check_info.begin(), - simplify_symbol_check_info.end()); -} - -void ShapeEnvAttr::SimplifySymbolCheckInfo() { - GELOGD("Start simplify guard"); - SimplifySymbolCheckInfo(symbol_check_infos_); - SimplifySymbolCheckInfo(symbol_assert_infos_); -} - -ge::Expression ShapeEnvAttr::Simplify(const ge::Expression &expr) { - std::vector> var_replacements; - // 初始化replacements遍历状态 - for (auto &iter : replacements_) { - iter.second.has_replace = false; - } - for (const auto &sym : expr.FreeSymbols()) { - auto replace_expr = FindReplacements(sym); - if ((!replace_expr.IsVariableExpr()) || (replace_expr != sym)) { - var_replacements.emplace_back(std::make_pair(sym, replace_expr)); - } - } - if (!var_replacements.empty()) { - auto result_expr = expr.Replace(var_replacements); - GELOGI("Simplify expr: %s to expr: %s", - expr.Serialize().get(), result_expr.Serialize().get()); - GE_ASSERT_NOTNULL(result_expr.impl_); - return Expression(result_expr.impl_->Simplify()); - } - return Expression(expr.impl_->Simplify()); -} - -ge::Expression ShapeEnvAttr::EvaluateExpr(const ge::Expression &expr) { - std::vector> var_to_val; - auto free_symbols = expr.FreeSymbols(); - for (const auto &free_sym : free_symbols) { - const auto &iter = symbol_to_value_.find(free_sym); - if (iter != symbol_to_value_.end()) { - var_to_val.emplace_back(std::make_pair(free_sym, Symbol(iter->second))); - } - } - return expr.Subs(var_to_val); -} - -TriBool ShapeEnvAttr::HasSymbolInfo(const Expression &expr) const { - Expression e = expr.CanonicalizeBoolExpr(); - if (HasSymbolCheckInfo(e) || HasSymbolAssertInfo(e)) { - return TriBool::kTrue; - } - return TriBool::kUnknown; -} - -void ShapeEnvAttr::AppendInitReplacement(const ge::Expression &expr) { - if (replacements_.find(expr) == replacements_.end()) { - (void)replacements_.emplace(std::make_pair(expr, Replacement(expr, 1))); - } -} - -graphStatus ShapeEnvAttr::FindRootExpr(const ge::Expression &expr, ge::Expression &root_expr) { - const auto &iter = replacements_.find(expr); - GE_ASSERT_TRUE(iter != replacements_.end(), "Can not find replacement of expr: %s", expr.Serialize().get()); - if (iter->second.replace_expr == expr) { - root_expr = expr; - return GRAPH_SUCCESS; - } - GE_ASSERT_SUCCESS(FindRootExpr(iter->second.replace_expr, root_expr)); - return GRAPH_SUCCESS; -} - -std::vector> ShapeEnvAttr::GetAllSym2Src() { - std::vector> result; - for (const auto &iter : symbol_to_source_) { - result.emplace_back(iter.first, iter.second); - } - return result; -} - -bool Replacement::operator<=(const Replacement &other) { - // 并查集的根节点优先级: 常数 > 表达式 > 变量 - if (replace_expr.IsConstExpr()) { - if (other.replace_expr.IsConstExpr()) { - return rank <= other.rank; - } - return false; - } - if (replace_expr.IsVariableExpr()) { - if (other.replace_expr.IsVariableExpr()) { - return rank <= other.rank; - } - return true; - } - if (other.replace_expr.IsConstExpr()) { - return true; - } - if (other.replace_expr.IsVariableExpr()) { - return false; - } - return rank <= other.rank; -} - -graphStatus ShapeEnvAttr::MergeReplacement(const ge::Expression &expr1, - const ge::Expression &expr2) { - ge::Expression father_expr1; - GE_ASSERT_SUCCESS(FindRootExpr(expr1, father_expr1)); - ge::Expression father_expr2; - GE_ASSERT_SUCCESS(FindRootExpr(expr2, father_expr2)); - auto &replacement_1 = replacements_[father_expr1]; - auto &replacement_2 = replacements_[father_expr2]; - if (replacement_1 <= replacement_2) { - replacement_1.replace_expr = father_expr2; - if (replacement_2.rank <= replacement_1.rank) { - replacement_2.rank = replacement_1.rank + 1; - } - } else { - replacement_2.replace_expr = father_expr1; - if (replacement_1.rank <= replacement_2.rank) { - replacement_1.rank = replacement_2.rank + 1; - } - } - return GRAPH_SUCCESS; -} - -graphStatus ShapeEnvAttr::MergePath() { - for (auto &iter : replacements_) { - ge::Expression root_expr; - GE_ASSERT_SUCCESS(FindRootExpr(iter.first, root_expr)); - iter.second.replace_expr = root_expr; - iter.second.rank = 1; - } - return GRAPH_SUCCESS; -} - -graphStatus ShapeEnvAttr::AppendReplacement(const ge::Expression &target, const ge::Expression &replacement) { - if (target == replacement) { - return GRAPH_SUCCESS; - } - ge::Expression expr1 = target; - ge::Expression expr2 = replacement; - auto expr = sym::Eq(target, replacement).CanonicalizeBoolExpr(); - vector args = expr.GetArgs(); - if (args.size() == kSizeTwo) { - expr1 = args[0]; - expr2 = args[1]; - GELOGD("expr1 %s->%s, expr2 %s->%s", target.Serialize().get(), expr1.Serialize().get(), - replacement.Serialize().get(), expr2.Serialize().get()); - } - - // 仅支持 符号->常量,符号->表达式,符号->符号 映射 - if (expr1.IsConstExpr()) { - if (!expr2.IsVariableExpr()) { - GELOGW("Unsupport append replacement %s to %s", - expr1.Serialize().get(), expr2.Serialize().get()); - return GRAPH_SUCCESS; - } - } else if (!expr1.IsVariableExpr()) { - if (!expr2.IsVariableExpr()) { - GELOGW("Unsupport append replacement %s to %s", - expr1.Serialize().get(), expr2.Serialize().get()); - return GRAPH_SUCCESS; - } - } - AppendInitReplacement(expr1); - AppendInitReplacement(expr2); - GE_ASSERT_SUCCESS(MergeReplacement(expr1, expr2)); - // 路径压缩 - GE_ASSERT_SUCCESS(MergePath()); - // replace插入后全量化简已有的guard - SimplifySymbolCheckInfo(); - return GRAPH_SUCCESS; -} - -graphStatus ShapeEnvAttr::AppendSymbolAssertInfo(const ge::Expression &expr, - const std::string &file, const int64_t line) { - GE_ASSERT_TRUE(expr.IsBooleanExpr(), - "Assert expr: %s should be boolean", expr.Serialize().get()); - if (!expr.IsConstExpr()) { - (void)symbol_assert_infos_.emplace(SymbolCheckInfo(expr.CanonicalizeBoolExpr(), file, line, GetGuardDfxContextInfo())); - } - return GRAPH_SUCCESS; -} - -graphStatus ShapeEnvAttr::AppendSymbolCheckInfo(const ge::Expression &expr, - const std::string &file, const int64_t line) { - GE_ASSERT_TRUE(expr.IsBooleanExpr(), - "Check expr: %s should be boolean", expr.Serialize().get()); - if (!expr.IsConstExpr()) { - (void)symbol_check_infos_.emplace(SymbolCheckInfo(expr.CanonicalizeBoolExpr(), file, line, GetGuardDfxContextInfo())); - } - return GRAPH_SUCCESS; -} - -void ShapeEnvAttr::SetGuardDfxContextInfo(const std::string &guard_dfx_info) { - guard_dfx_info_ = guard_dfx_info; -} - -void ShapeEnvAttr::ClearGuardDfxContextInfo() { - guard_dfx_info_.clear(); -} - -std::string ShapeEnvAttr::GetGuardDfxContextInfo() const { - return guard_dfx_info_; -} -} // namespace ge diff --git a/graph/expression/attr_group_symbolic_desc_attr.cc b/graph/expression/attr_group_symbolic_desc_attr.cc deleted file mode 100644 index a5fce0d5e852e7a5205de85f10129523a78eb094..0000000000000000000000000000000000000000 --- a/graph/expression/attr_group_symbolic_desc_attr.cc +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "attribute_group/attr_group_symbolic_desc.h" -#include "common/checker.h" -#include "proto/ge_ir.pb.h" - -namespace ge { - -graphStatus SymbolicDescAttr::Serialize(proto::AttrGroupDef &attr_group_def) { - auto tensor_attr_group = attr_group_def.mutable_tensor_attr_group(); - GE_ASSERT_NOTNULL(tensor_attr_group); - tensor_attr_group->clear_origin_symbol_shape(); - for (const auto &ori_shape : symbolic_tensor.GetOriginSymbolShape().GetDims()) { - tensor_attr_group->add_origin_symbol_shape(ori_shape.Str().get()); - } - tensor_attr_group->clear_symbolic_value(); - if (symbolic_tensor.GetSymbolicValue() != nullptr) { - for (const auto &symbol_value : *symbolic_tensor.GetSymbolicValue()) { - tensor_attr_group->add_symbolic_value(symbol_value.Str().get()); - } - } - return GRAPH_SUCCESS; -} - -graphStatus SymbolicDescAttr::Deserialize(const proto::AttrGroupDef &attr_group_def, AttrHolder *attr_holder) { - (void) attr_holder; - const auto& tensor_attr_group = attr_group_def.tensor_attr_group(); - this->symbolic_tensor.MutableOriginSymbolShape().Clear(); - for (const auto &ori_sh : tensor_attr_group.origin_symbol_shape()) { - this->symbolic_tensor.MutableOriginSymbolShape().AppendDim(ge::Expression::Parse(ori_sh.c_str())); - } - if (!tensor_attr_group.symbolic_value().empty()) { - auto symbol_value_ptr = ComGraphMakeUnique>(); - if (symbol_value_ptr != nullptr) { - for (const auto &symbol_value : tensor_attr_group.symbolic_value()) { - symbol_value_ptr->push_back(ge::Expression::Parse(symbol_value.c_str())); - } - this->symbolic_tensor.SetSymbolicValue(std::move(symbol_value_ptr)); - } - } - return GRAPH_SUCCESS; -} - -std::unique_ptr SymbolicDescAttr::Clone() { - std::unique_ptr attr = ComGraphMakeUnique(*this); - GE_ASSERT_NOTNULL(attr); - return attr; -} -} // namespace ge diff --git a/graph/expression/const_values.h b/graph/expression/const_values.h deleted file mode 100644 index 8b63a81729a61eb7c9b9c21d70f6b6c9430d80c1..0000000000000000000000000000000000000000 --- a/graph/expression/const_values.h +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_EXPRESSION_CONST_VALUE_H_ -#define GRAPH_EXPRESSION_CONST_VALUE_H_ - -#include -#include -#include "graph/symbolizer/symbolic.h" - -namespace ge { -namespace sym { -const uint32_t kNumOne = 1u; -const uint32_t kNumTwo = 2u; -const size_t kBinaryOpFirstArgIdx = 0u; -const size_t kBinaryOpSecondArgIdx = 1u; -const std::string kNullStr = ""; -const int32_t kMinusOne = -1; -const int32_t kConstOne = 1; -const int32_t kConstTwo = 2; -const size_t kSizeNumTwo = 2u; -const size_t kIndexOne = 1u; -const size_t kIndexTwo = 2u; -const int32_t kBaseTwo = 2; -const int32_t kMinDimLength = 1; -const Symbol kSymbolOne = ge::Symbol(1, "sym_one"); -const Symbol kSymbolZero = ge::Symbol(0, "sym_zero"); - -// options -const std::string kOutputFilePath = "output_file_path"; -const std::string kTilingDataTypeName = "tiling_data_type_name"; -const std::string kGenExtraInfo = "gen_extra_info"; -const std::string kDumpDebugInfo = "dump_debug_info"; -const std::string kGenTilingDataDef = "gen_tiling_data_def"; -const std::string kWithTilingContext = "with_tiling_context"; -const std::string kDefaultFilePath = "./"; -const std::string kDefaultTilingDataTypeName = "TilingData"; -const std::string kIsTrue = "1"; -const std::string kIsFalse = "0"; -const std::string kTilingFuncIdentify = "TilingFunc"; -const std::string kDefaultTilingDataFileName = "tiling_data.h"; -const std::string kDefaultTilingFuncFileName = "tiling_func.cpp"; -} // namespace sym -} // namespace ge - -#endif // GRAPH_EXPRESSION_CONST_VALUE_H_ \ No newline at end of file diff --git a/graph/expression/expr_parser.cc b/graph/expression/expr_parser.cc deleted file mode 100644 index 50333041371c39de456197932d400504ce50fddf..0000000000000000000000000000000000000000 --- a/graph/expression/expr_parser.cc +++ /dev/null @@ -1,327 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ -#include "expr_parser.h" -#include "common/checker.h" -#include "symengine/real_double.h" -namespace ge { -ExpressionImplPtr ExprParser::ParserExpression() { - auto ret = ParserAddSubtract(); - GE_ASSERT_NOTNULL(ret); - return ret; -} - -graphStatus ExprParser::Init() { - GE_ASSERT_SUCCESS(scanner_.GetNextToken(currentToken_)); - return ge::GRAPH_SUCCESS; -} -graphStatus ExprParser::Eat(TokenType type) { - GE_ASSERT(currentToken_.type == type); - GE_ASSERT_SUCCESS(scanner_.GetNextToken(currentToken_)); - return ge::GRAPH_SUCCESS; -} - -ExpressionImplPtr ExprParser::ParserFactor() { - switch (currentToken_.type) { - case TokenType::kIdentifier: - return ParserIdentifier(); - case TokenType::kLparen: - return ParserLParen(); - case TokenType::kMax: - return ParserMaxFunction(); - case TokenType::kMin: - return ParserMinFunction(); - case TokenType::kPow: - return ParserPowFunction(); - case TokenType::kMod: - return ParserModFunction(); - case TokenType::kLog: - return ParserLogFunction(); - case TokenType::kCeil: - return ParserCeilFunction(); - case TokenType::kFloor: - return ParserFloorFunction(); - case TokenType::kAbs: - return ParserAbsFunction(); - case TokenType::kRational: - return ParserRationalFunction(); - case TokenType::kNumber: - return ParserNumber(); - case TokenType::kMinus: - return ParserMinus(); - case TokenType::kEq: - return ParserEqual(); - case TokenType::kNe: - return ParserUnequal(); - case TokenType::kLe: - return ParserLessEqual(); - case TokenType::kLt: - return ParserLessThan(); - case TokenType::kTrue: - case TokenType::kFalse: - return ParseConstBoolen(); - case TokenType::kLogicalAnd: - return ParserLogicalAnd(); - case TokenType::kLogicalOr: - return ParserLogicalOr(); - default: - GELOGE(ge::PARAM_INVALID, "Unsupported operator %d when Parser factor.", currentToken_.type); - return nullptr; - } -} - -ExpressionImplPtr ExprParser::ParserAddSubtract() { - auto node = ParserMulDivide(); - GE_ASSERT_NOTNULL(node); - while (currentToken_.type == TokenType::kPlus || currentToken_.type == TokenType::kMinus) { - TokenType op = currentToken_.type; - GE_ASSERT_SUCCESS(Eat(op)); - auto right = ParserMulDivide(); - GE_ASSERT_NOTNULL(right); - switch (op) { - case TokenType::kPlus: - node = Add(node, right); - break; - case TokenType::kMinus: - node = Sub(node, right); - break; - default: - GELOGE(ge::PARAM_INVALID, "unsupported operator %d when parsing add and sub.", currentToken_.type); - return nullptr; - } - } - return node; -} - -ExpressionImplPtr ExprParser::ParserMulDivide() { - auto node = ParserFactor(); - GE_ASSERT_NOTNULL(node); - - while (currentToken_.type == TokenType::kMultiply || currentToken_.type == TokenType::kDivide) { - TokenType op = currentToken_.type; - GE_ASSERT_SUCCESS(Eat(op)); - auto right = ParserFactor(); - GE_ASSERT_NOTNULL(right); - switch (op) { - case TokenType::kMultiply: - node = Mul(node, right); - GE_ASSERT_NOTNULL(node); - break; - case TokenType::kDivide: - node = Div(node, right); - GE_ASSERT_NOTNULL(node); - break; - default: - GELOGE(ge::PARAM_INVALID, "unsupported operator %d when parsing mul and divide.", currentToken_.type); - return nullptr; - } - } - return node; -} - -ExpressionImplPtr ExprParser::ParserMaxFunction() { - GE_ASSERT_SUCCESS(Eat(TokenType::kMax)); - GE_ASSERT_SUCCESS(Eat(TokenType::kLparen)); - auto arg1 = ParserExpression(); - GE_ASSERT_SUCCESS(Eat(TokenType::kComma)); - auto arg2 = ParserExpression(); - GE_ASSERT_SUCCESS(Eat(TokenType::kRparen)); - return Max(arg1, arg2); -} - -ExpressionImplPtr ExprParser::ParserMinFunction() { - GE_ASSERT_SUCCESS(Eat(TokenType::kMin)); - GE_ASSERT_SUCCESS(Eat(TokenType::kLparen)); - auto arg1 = ParserExpression(); - GE_ASSERT_SUCCESS(Eat(TokenType::kComma)); - auto arg2 = ParserExpression(); - GE_ASSERT_SUCCESS(Eat(TokenType::kRparen)); - return Min(arg1, arg2); -} - -ExpressionImplPtr ExprParser::ParserEqual() { - GE_ASSERT_SUCCESS(Eat(TokenType::kEq)); - GE_ASSERT_SUCCESS(Eat(TokenType::kLparen)); - auto arg1 = ParserExpression(); - GE_ASSERT_SUCCESS(Eat(TokenType::kComma)); - auto arg2 = ParserExpression(); - GE_ASSERT_SUCCESS(Eat(TokenType::kRparen)); - return Eq(arg1, arg2); -} - -ExpressionImplPtr ExprParser::ParserUnequal() { - GE_ASSERT_SUCCESS(Eat(TokenType::kNe)); - GE_ASSERT_SUCCESS(Eat(TokenType::kLparen)); - auto arg1 = ParserExpression(); - GE_ASSERT_SUCCESS(Eat(TokenType::kComma)); - auto arg2 = ParserExpression(); - GE_ASSERT_SUCCESS(Eat(TokenType::kRparen)); - return Ne(arg1, arg2); -} - -ExpressionImplPtr ExprParser::ParserLessEqual() { - GE_ASSERT_SUCCESS(Eat(TokenType::kLe)); - GE_ASSERT_SUCCESS(Eat(TokenType::kLparen)); - auto arg1 = ParserExpression(); - GE_ASSERT_SUCCESS(Eat(TokenType::kComma)); - auto arg2 = ParserExpression(); - GE_ASSERT_SUCCESS(Eat(TokenType::kRparen)); - return Le(arg1, arg2); -} - -ExpressionImplPtr ExprParser::ParserLessThan() { - GE_ASSERT_SUCCESS(Eat(TokenType::kLt)); - GE_ASSERT_SUCCESS(Eat(TokenType::kLparen)); - auto arg1 = ParserExpression(); - GE_ASSERT_SUCCESS(Eat(TokenType::kComma)); - auto arg2 = ParserExpression(); - GE_ASSERT_SUCCESS(Eat(TokenType::kRparen)); - return Lt(arg1, arg2); -} - -ExpressionImplPtr ExprParser::ParserLogicalAnd() { - GE_ASSERT_SUCCESS(Eat(TokenType::kLogicalAnd)); - GE_ASSERT_SUCCESS(Eat(TokenType::kLparen)); - auto arg1 = ParserExpression(); - GE_ASSERT_SUCCESS(Eat(TokenType::kComma)); - auto arg2 = ParserExpression(); - GE_ASSERT_SUCCESS(Eat(TokenType::kRparen)); - std::vector args; - args.push_back(std::move(arg1)); - args.push_back(std::move(arg2)); - return LogicalAnd(args); -} - -ExpressionImplPtr ExprParser::ParserLogicalOr() { - GE_ASSERT_SUCCESS(Eat(TokenType::kLogicalOr)); - GE_ASSERT_SUCCESS(Eat(TokenType::kLparen)); - auto arg1 = ParserExpression(); - GE_ASSERT_SUCCESS(Eat(TokenType::kComma)); - auto arg2 = ParserExpression(); - GE_ASSERT_SUCCESS(Eat(TokenType::kRparen)); - std::vector args; - args.push_back(std::move(arg1)); - args.push_back(std::move(arg2)); - return LogicalOr(args); -} - -ExpressionImplPtr ExprParser::ParserPowFunction() { - GE_ASSERT_SUCCESS(Eat(TokenType::kPow)); - GE_ASSERT_SUCCESS(Eat(TokenType::kLparen)); - auto arg1 = ParserExpression(); - GE_ASSERT_SUCCESS(Eat(TokenType::kComma)); - auto arg2 = ParserExpression(); - GE_ASSERT_SUCCESS(Eat(TokenType::kRparen)); - return Pow(arg1, arg2); -} - -ExpressionImplPtr ExprParser::ParserModFunction() { - GE_ASSERT_SUCCESS(Eat(TokenType::kMod)); - GE_ASSERT_SUCCESS(Eat(TokenType::kLparen)); - auto arg1 = ParserAddSubtract(); - GE_ASSERT_SUCCESS(Eat(TokenType::kComma)); - auto arg2 = ParserAddSubtract(); - GE_ASSERT_SUCCESS(Eat(TokenType::kRparen)); - return Mod(arg1, arg2); -} - -ExpressionImplPtr ExprParser::ParserLogFunction() { - GE_ASSERT_SUCCESS(Eat(TokenType::kLog)); - GE_ASSERT_SUCCESS(Eat(TokenType::kLparen)); - auto arg1 = ParserExpression(); - GE_ASSERT_SUCCESS(Eat(TokenType::kRparen)); - return Log(arg1); -} - -ExpressionImplPtr ExprParser::ParserCeilFunction() { - GE_ASSERT_SUCCESS(Eat(TokenType::kCeil)); - GE_ASSERT_SUCCESS(Eat(TokenType::kLparen)); - auto arg1 = ParserExpression(); - GE_ASSERT_SUCCESS(Eat(TokenType::kRparen)); - return Ceiling(arg1); -} - -ExpressionImplPtr ExprParser::ParserFloorFunction() { - GE_ASSERT_SUCCESS(Eat(TokenType::kFloor)); - GE_ASSERT_SUCCESS(Eat(TokenType::kLparen)); - auto arg1 = ParserExpression(); - GE_ASSERT_SUCCESS(Eat(TokenType::kRparen)); - return Floor(arg1); -} - -ExpressionImplPtr ExprParser::ParserAbsFunction() { - GE_ASSERT_SUCCESS(Eat(TokenType::kAbs)); - GE_ASSERT_SUCCESS(Eat(TokenType::kLparen)); - auto arg1 = ParserExpression(); - GE_ASSERT_SUCCESS(Eat(TokenType::kRparen)); - return Abs(arg1); -} - -ExpressionImplPtr ExprParser::ParserRationalFunction() { - GE_ASSERT_SUCCESS(Eat(TokenType::kRational)); - GE_ASSERT_SUCCESS(Eat(TokenType::kLparen)); - auto arg1 = ParserExpression(); - GE_ASSERT_SUCCESS(Eat(TokenType::kComma)); - auto arg2 = ParserExpression(); - GE_ASSERT_SUCCESS(Eat(TokenType::kRparen)); - return Rational(arg1, arg2); -} - -ExpressionImplPtr ExprParser::ParserNumber() { - const std::string &numberStr = currentToken_.value; - try { - if (numberStr.find('.') != std::string::npos) { - double value = std::stod(numberStr); - GE_ASSERT_SUCCESS(Eat(TokenType::kNumber)); - return ExpressionImpl::CreateExpressionImpl(value); // 返回浮点数节点 - } else { - int64_t value = std::stoll(numberStr); - GE_ASSERT_SUCCESS(Eat(TokenType::kNumber)); - return ExpressionImpl::CreateExpressionImpl(value); // 返回整数节点 - } - } catch (std::invalid_argument &) { - GELOGW("number str:%s is invalid", numberStr.c_str()); - return nullptr; - } catch (std::out_of_range &) { - GELOGW("number str:%s is out_of_range", numberStr.c_str()); - return nullptr; - } -} - -ExpressionImplPtr ExprParser::ParserIdentifier() { - const std::string name{currentToken_.value}; - GE_ASSERT_SUCCESS(Eat(TokenType::kIdentifier)); - return ExpressionImpl::CreateExpressionImpl(name); -} - -ExpressionImplPtr ExprParser::ParseConstBoolen() { - bool sym_value = currentToken_.value == "True" ? true : false; - GE_ASSERT_SUCCESS(Eat(currentToken_.type)); - return ExpressionImpl::CreateExpressionImpl(sym_value); -} - -ExpressionImplPtr ExprParser::ParserMinus() { - GE_ASSERT_SUCCESS(Eat(TokenType::kMinus)); - ExpressionImplPtr node; - if (currentToken_.type == TokenType::kLparen) { - node = ParserLParen(); - } else { - node = ParserFactor(); - } - GE_ASSERT_NOTNULL(node); - return Neg(node); -} - -ExpressionImplPtr ExprParser::ParserLParen() { - GE_ASSERT_SUCCESS(Eat(TokenType::kLparen)); - auto node = ParserExpression(); - GE_ASSERT_NOTNULL(node); - GE_ASSERT_SUCCESS(Eat(TokenType::kRparen)); - return node; -} -} // namespace ge diff --git a/graph/expression/expr_parser.h b/graph/expression/expr_parser.h deleted file mode 100644 index 1a5666c588acc39d66fe94e0791d4f5869b2d6ad..0000000000000000000000000000000000000000 --- a/graph/expression/expr_parser.h +++ /dev/null @@ -1,60 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. -* This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_EXPRESSION_EXPR_PARSER_H_ -#define GRAPH_EXPRESSION_EXPR_PARSER_H_ -#include "scanner.h" -#include "graph/debug/ge_util.h" -#include "expression_impl.h" - -#include -#include -#include -namespace ge { - -class ExprParser { -public: - explicit ExprParser(Scanner scanner) : scanner_(scanner) { - Init(); - } - ExpressionImplPtr ParserExpression(); - -private: - graphStatus Init(); - graphStatus Eat(TokenType type); - ExpressionImplPtr ParserFactor(); - ExpressionImplPtr ParserAddSubtract(); - ExpressionImplPtr ParserMulDivide(); - ExpressionImplPtr ParserMaxFunction(); - ExpressionImplPtr ParserMinFunction(); - ExpressionImplPtr ParserPowFunction(); - ExpressionImplPtr ParserModFunction(); - ExpressionImplPtr ParserLogFunction(); - ExpressionImplPtr ParserCeilFunction(); - ExpressionImplPtr ParserFloorFunction(); - ExpressionImplPtr ParserAbsFunction(); - ExpressionImplPtr ParserRationalFunction(); - ExpressionImplPtr ParserNumber(); - ExpressionImplPtr ParserIdentifier(); - ExpressionImplPtr ParserLParen(); - ExpressionImplPtr ParseConstBoolen(); - ExpressionImplPtr ParserMinus(); - ExpressionImplPtr ParserEqual(); - ExpressionImplPtr ParserUnequal(); - ExpressionImplPtr ParserLessEqual(); - ExpressionImplPtr ParserLessThan(); - ExpressionImplPtr ParserLogicalAnd(); - ExpressionImplPtr ParserLogicalOr(); - - Scanner scanner_; - Token currentToken_; -}; - -} -#endif // GRAPH_EXPRESSION_EXPR_PARSER_H_ \ No newline at end of file diff --git a/graph/expression/expr_print_manager.cc b/graph/expression/expr_print_manager.cc deleted file mode 100644 index c537b2c409ccadf486236254bfeac7167be0069f..0000000000000000000000000000000000000000 --- a/graph/expression/expr_print_manager.cc +++ /dev/null @@ -1,274 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "expr_print_manager.h" - -#include -#include -#include -#include -#include -#include - -#include "const_values.h" -#include "common/checker.h" - -namespace ge { -namespace { -const std::string kPrintAdd = " + "; -const std::string kPrintSub = " - "; -const std::string kPrintMul = " * "; -const std::string kPrintDiv = " / "; -const std::string kPrintMod = "Mod"; -const std::string kPrintEq = "ExpectEq"; -const std::string kPrintNe = "ExpectNe"; -const std::string kPrintLe = "ExpectLe"; -const std::string kPrintLt = "ExpectLt"; -const std::string kPrintPow = "Pow"; -const std::string kPrintLog = "Log"; -const std::string kPrintMax = "Max"; -const std::string kPrintMin = "Min"; -const std::string kPrintExp = "Exp"; -const std::string kPrintSqrt = "Sqrt"; -const std::string kPrintCeil = "Ceiling"; -const std::string kPrintFloor = "Floor"; -const std::string kPrintAbs = "Abs"; -const std::string kPrintLogicalAnd = "LogicAnd"; -const std::string kPrintLogicalOr = "LogicOr"; -const std::string kPrintDelim = ", "; -const std::string kPrintBracket_L = "("; -const std::string kPrintBracket_R = ")"; -const size_t kRelationArgsNum = 2UL; -} - -std::string PrintArgs(const std::vector &args, - const std::string &delim, StrType type) { - std::string res; - std::vector args_str; - for (size_t i = 0u; i < args.size(); ++i) { - args_str.emplace_back(ExpressionImpl::SymExprToExpressionImplRef(args[i]).Str(type)); - } - // 保证序列化反序列化后的顺序 - std::sort(args_str.begin(), args_str.end()); - for (size_t i = 0u; i < args_str.size(); ++i) { - if (i > 0u) { - res += delim + args_str[i]; - continue; - } - res = args_str[i]; - } - return res; -} - -std::string DefaultCeilPrinter(const std::vector &args, StrType type) { - return kPrintCeil + kPrintBracket_L + PrintArgs(args, kPrintDelim, type) + kPrintBracket_R; -} -REGISTER_EXPR_DEFAULT_PRINTER(kOpCeil, DefaultCeilPrinter); - -std::string DefaultFloorPrinter(const std::vector &args, StrType type) { - return kPrintFloor + kPrintBracket_L + PrintArgs(args, kPrintDelim, type) + kPrintBracket_R; -} -REGISTER_EXPR_DEFAULT_PRINTER(kOpFloor, DefaultFloorPrinter); - -std::string DefaultAbsPrinter(const std::vector &args, StrType type) { - return kPrintAbs + kPrintBracket_L + PrintArgs(args, kPrintDelim, type) + kPrintBracket_R; -} -REGISTER_EXPR_DEFAULT_PRINTER(kOpAbs, DefaultAbsPrinter); - -std::string DefaultLogicalAndPrinter(const std::vector &args, StrType type) { - return kPrintLogicalAnd + kPrintBracket_L + PrintArgs(args, kPrintDelim, type) + kPrintBracket_R; -} -REGISTER_EXPR_DEFAULT_PRINTER(kOpLogicalAnd, DefaultLogicalAndPrinter); - -std::string DefaultLogicalOrPrinter(const std::vector &args, StrType type) { - return kPrintLogicalOr + kPrintBracket_L + PrintArgs(args, kPrintDelim, type) + kPrintBracket_R; -} -REGISTER_EXPR_DEFAULT_PRINTER(kOpLogicalOr, DefaultLogicalOrPrinter); - -std::string DefaultAddPrinter(const std::vector &args, StrType type) { - std::vector positive_args; - std::vector negative_args; - for (const auto &arg : args) { - if (SymEngine::is_a(*arg) && - (SymEngine::down_cast(*arg)).get_coef()->is_negative()) { - negative_args.push_back(SymEngine::mul(arg, SymEngine::minus_one)); - continue; - } - positive_args.push_back(arg); - } - std::string res_str = kPrintBracket_L; - if (!positive_args.empty()) { - res_str += PrintArgs(positive_args, kPrintAdd, type); - } - if (!negative_args.empty()) { - res_str += kPrintSub; - res_str += PrintArgs(negative_args, kPrintSub, type); - } - res_str += kPrintBracket_R; - return res_str; -} -REGISTER_EXPR_DEFAULT_PRINTER(kOpAdd, DefaultAddPrinter); - -std::string DefaultMulPrinter(const std::vector &args, StrType type) { - // split mul to num and dens - std::vector positive_args; - std::vector negative_args; - for (const auto &arg : args) { - if (SymEngine::is_a(*arg)) { - const auto exp = SymEngine::down_cast(*arg).get_exp(); - if (SymEngine::is_a_Number(*exp) && - SymEngine::down_cast(*exp).is_negative()) { - negative_args.push_back(SymEngine::div(SymEngine::one, arg)); - continue; - } - } - positive_args.push_back(arg); - } - std::string res_str = kPrintBracket_L; - if (!positive_args.empty()) { - res_str += PrintArgs(positive_args, kPrintMul, type); - } else { - res_str += std::to_string(sym::kConstOne); - } - if (!negative_args.empty()) { - res_str += kPrintDiv; - res_str += kPrintBracket_L + PrintArgs(negative_args, kPrintMul, type) + kPrintBracket_R; - } - res_str += kPrintBracket_R; - return res_str; -} -REGISTER_EXPR_DEFAULT_PRINTER(kOpMul, DefaultMulPrinter); - -std::string DefaultMaxPrinter(const std::vector &args, StrType type) { - std::string res_str; - if (args.size() >= kSizeTwo) { - res_str = kPrintMax + kPrintBracket_L + - ExpressionImpl::SymExprToExpressionImplRef(args[0]).Str(type) + kPrintDelim + - ExpressionImpl::SymExprToExpressionImplRef(args[1]).Str(type) + kPrintBracket_R; - } - for (size_t i = kSizeTwo; i < args.size(); ++i) { - res_str = kPrintMax + kPrintBracket_L + - res_str + kPrintDelim + ExpressionImpl::SymExprToExpressionImplRef(args[i]).Str(type) + - kPrintBracket_R; - } - return res_str; -} -REGISTER_EXPR_DEFAULT_PRINTER(kOpMax, DefaultMaxPrinter); - -std::string DefaultMinPrinter(const std::vector &args, StrType type) { - std::string res_str; - if (args.size() >= kSizeTwo) { - res_str = kPrintMin + kPrintBracket_L - + ExpressionImpl::SymExprToExpressionImplRef(args[0]).Str(type) + kPrintDelim + - ExpressionImpl::SymExprToExpressionImplRef(args[1]).Str(type) + kPrintBracket_R; - } - for (size_t i = kSizeTwo; i < args.size(); ++i) { - res_str = kPrintMin + kPrintBracket_L + - res_str + kPrintDelim + ExpressionImpl::SymExprToExpressionImplRef(args[i]).Str(type) + - kPrintBracket_R; - } - return res_str; -} -REGISTER_EXPR_DEFAULT_PRINTER(kOpMin, DefaultMinPrinter); - -std::string PrintIntExpPow(const SymEngineExprPtr &base, const uint32_t exp, StrType type) { - std::string res_str = "("; - for (uint32_t i = 0u; i < exp; ++i) { - if (i > 0u) { - res_str += " * " + ExpressionImpl::SymExprToExpressionImplRef(base).Str(type); - continue; - } - res_str += ExpressionImpl::SymExprToExpressionImplRef(base).Str(type); - } - return res_str + ")"; -} - -std::string GetDefaultPowPrint(const std::vector &base_args, StrType type) { - const size_t base_idx = 0u; - const size_t exp_idx = 1u; - return kPrintPow + "(" + - ExpressionImpl::SymExprToExpressionImplRef(base_args[base_idx]).Str(type) + ", " + - ExpressionImpl::SymExprToExpressionImplRef(base_args[exp_idx]).Str(type) + ")"; -} - - -std::string DefaultPowPrinter(const std::vector &args, StrType type) { - constexpr const size_t pow_args_num = 2UL; - GE_ASSERT_TRUE(args.size() == pow_args_num, - "Symbol operator Pow args num should be 2 but get: %zu", args.size()); - const size_t base_idx = 0u; - const size_t exp_idx = 1u; - if (args[base_idx]->__eq__(*(SymEngine::E))) { - return kPrintExp + "(" + ExpressionImpl::SymExprToExpressionImplRef(args[exp_idx]).Str(type) + ")"; - } - if (args[exp_idx]->__eq__(*SymEngine::rational(sym::kNumOne, sym::kNumTwo))) { - return kPrintSqrt + "(" + ExpressionImpl::SymExprToExpressionImplRef(args[base_idx]).Str(type) + ")"; - } - if (args[exp_idx]->__eq__(*SymEngine::integer(sym::kNumOne))) { - return "(" + ExpressionImpl::SymExprToExpressionImplRef(args[base_idx]).Str(type) + ")"; - } - if (SymEngine::is_a(*(args[exp_idx]))) { - const SymEngine::Integer &exp_arg = SymEngine::down_cast(*(args[exp_idx])); - if (exp_arg.is_positive()) { - return PrintIntExpPow(args[base_idx], exp_arg.as_uint(), type); - } - } - return GetDefaultPowPrint(args, type); -} -REGISTER_EXPR_DEFAULT_PRINTER(kOpPow, DefaultPowPrinter); - -std::string GetDefaultModPrint(const std::vector &base_args, StrType type) { - constexpr const size_t mod_args_num = 2UL; - GE_ASSERT_TRUE(base_args.size() == mod_args_num, - "Symbol operator Mod args num should be 2 but get: %zu", base_args.size()); - const size_t dividend_idx = 0u; - const size_t divisor_idx = 1u; - return kPrintMod + "(" + - ExpressionImpl::SymExprToExpressionImplRef(base_args[dividend_idx]).Str(type) + ", " + - ExpressionImpl::SymExprToExpressionImplRef(base_args[divisor_idx]).Str(type) + ")"; -} -REGISTER_EXPR_DEFAULT_PRINTER(kOpMod, GetDefaultModPrint); - -std::string DefaultLogPrinter(const std::vector &args, StrType type) { - return kPrintLog + kPrintBracket_L + PrintArgs(args, kPrintDelim, type) + kPrintBracket_R; -} -REGISTER_EXPR_DEFAULT_PRINTER(kOpLog, DefaultLogPrinter); - -std::string DefaultEqualPrinter(const std::vector &args, StrType type) { - GE_ASSERT_TRUE(args.size() == kRelationArgsNum, - "Equal operator args size should be 2, but get %zu", args.size()); - return kPrintEq + kPrintBracket_L + ExpressionImpl::SymExprToExpressionImplRef(args[0]).Str(type) + - kPrintDelim + ExpressionImpl::SymExprToExpressionImplRef(args[1]).Str(type) + kPrintBracket_R; -} -REGISTER_EXPR_DEFAULT_PRINTER(kOpEq, DefaultEqualPrinter); - -std::string DefaultUnEqualPrinter(const std::vector &args, StrType type) { - GE_ASSERT_TRUE(args.size() == kRelationArgsNum, - "Unequal operator args size should be 2, but get %zu", args.size()); - return kPrintNe + kPrintBracket_L + ExpressionImpl::SymExprToExpressionImplRef(args[0]).Str(type) + - kPrintDelim + ExpressionImpl::SymExprToExpressionImplRef(args[1]).Str(type) + kPrintBracket_R; -} -REGISTER_EXPR_DEFAULT_PRINTER(kOpNe, DefaultUnEqualPrinter); - -std::string DefaultStrictLessThanPrinter(const std::vector &args, StrType type) { - GE_ASSERT_TRUE(args.size() == kRelationArgsNum, - "StrictLessThan operator args size should be 2, but get %zu", args.size()); - return kPrintLt + kPrintBracket_L + ExpressionImpl::SymExprToExpressionImplRef(args[0]).Str(type) + - kPrintDelim + ExpressionImpl::SymExprToExpressionImplRef(args[1]).Str(type) + kPrintBracket_R; -} -REGISTER_EXPR_DEFAULT_PRINTER(kOpLt, DefaultStrictLessThanPrinter); - -std::string DefaultLessThanPrinter(const std::vector &args, StrType type) { - GE_ASSERT_TRUE(args.size() == kRelationArgsNum, - "LessThan operator args size should be 2, but get %zu", args.size()); - return kPrintLe + kPrintBracket_L + ExpressionImpl::SymExprToExpressionImplRef(args[0]).Str(type) + - kPrintDelim + ExpressionImpl::SymExprToExpressionImplRef(args[1]).Str(type) + kPrintBracket_R; -} -REGISTER_EXPR_DEFAULT_PRINTER(kOpLe, DefaultLessThanPrinter); -} // namespace ge \ No newline at end of file diff --git a/graph/expression/expr_print_manager.h b/graph/expression/expr_print_manager.h deleted file mode 100644 index f809dc96dd2e4145d727d04102a3945a93429310..0000000000000000000000000000000000000000 --- a/graph/expression/expr_print_manager.h +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_EXPRESSION_EXPR_PRINT_MANAGER_H_ -#define GRAPH_EXPRESSION_EXPR_PRINT_MANAGER_H_ -#include -#include -#include - -#include "expression_impl.h" - -namespace ge { -using OpPrinter = std::string (*)(const std::vector &args, StrType type); - -class ExprManager { - public: - static ExprManager &GetInstance() { - static ExprManager instance; - return instance; - } - - void RegisterDefaultOpPrinter(const OperationType &opType, const OpPrinter &operation) { - defaultOpPrinters_[opType] = operation; - } - - OpPrinter GetPrinter(const OperationType &type) - { - return GetPrinterHelper(type, defaultOpPrinters_); - } - - private: - ExprManager() = default; - ~ExprManager() = default; - ExprManager(const ExprManager&) = delete; - ExprManager &operator=(const ExprManager&) = delete; - OpPrinter GetPrinterHelper(const OperationType op, const std::map &printerMap) const { - const auto iter = printerMap.find(op); - if (iter == printerMap.end()) { - return nullptr; - } - return iter->second; - } - - std::map defaultOpPrinters_; -}; - -class ExprManagerRegister -{ - public: - ExprManagerRegister(const OperationType op, const OpPrinter &printer) { - ExprManager::GetInstance().RegisterDefaultOpPrinter(op, printer); - } - ~ExprManagerRegister() = default; -}; -} // namespace ge - -#define REGISTER_EXPR_DEFAULT_PRINTER(opType, funcName) \ - ExprManagerRegister register_##opType_default_##funcName(opType, funcName) -#endif // GRAPH_EXPRESSION_EXPR_PRINT_MANAGER_H_ \ No newline at end of file diff --git a/graph/expression/expression.cc b/graph/expression/expression.cc deleted file mode 100644 index 077e500f448a7a91a1baf5465bb7daad8ce70709..0000000000000000000000000000000000000000 --- a/graph/expression/expression.cc +++ /dev/null @@ -1,345 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include -#include "graph/symbolizer/symbolic.h" -#include "attribute_group/attr_group_shape_env.h" -#include "expression_impl.h" -#include "graph/debug/ge_util.h" -#include "graph/utils/math_util.h" -#include "common/checker.h" -#include "const_values.h" - -namespace ge { -Expression::~Expression() {} - -Expression::Expression(Expression &&other) noexcept { - impl_ = std::move(other.impl_); -} - -Expression &Expression::operator=(const Expression &other) { - if (&other != this) { - impl_ = ComGraphMakeUnique(); - if ((other.impl_ != nullptr) && (impl_ != nullptr)) { - *impl_ = *other.impl_; - } - } - return *this; -} - -Expression::Expression(const Expression &other) { - // Copy - impl_ = ComGraphMakeUnique(); - if ((other.impl_ != nullptr) && (impl_ != nullptr)) { - *impl_ = *other.impl_; - } -} - -Expression &Expression::operator=(Expression &&other) noexcept { - if (&other != this) { - impl_ = std::move(other.impl_); - } - return *this; -} - -std::unique_ptr Expression::Str(const StrType type) const { - if (impl_ != nullptr) { - auto str = impl_->Str(type); - if (str.empty()) { - return nullptr; - } - auto uni_ptr = ComGraphMakeUnique(str.size() + 1); - IF_NULL_RETURN_NULL(uni_ptr); - // 当src size < dst size时,strncpy_s会在末尾str.size()位置添加'\0' - GE_ASSERT_EOK(strncpy_s(uni_ptr.get(), str.size() + 1, str.c_str(), str.size())); - return uni_ptr; - } - return nullptr; -} - -Expression Expression::Parse(const char_t *str) { - if (str == nullptr) { - GELOGE(FAILED, "Parse expression str is nullptr"); - return Expression(nullptr); - } - return Expression(ExpressionImpl::Parse(str)); -} - -std::unique_ptr Expression::Serialize() const { - return Str(StrType::kStrCpp); -} - -Expression Expression::Deserialize(const ge::char_t *str) { - return Expression(ExpressionImpl::Deserialize(str)); -} - -std::vector Expression::GetArgs() { - std::vector args; - if (impl_ == nullptr) { - return args; - } - - for (ExpressionImplPtr &arg : impl_->GetArgs()) { - args.emplace_back(Expression(std::move(arg))); - } - return args; -} - -ExprType Expression::GetExprType() const { - if (impl_ != nullptr) { - return impl_->GetExprType(); - } - return ExprType::kExprNone; -} - -bool Expression::IsConstExpr() const { - if (impl_!= nullptr) { - return impl_->IsConstExpr(); - } - return false; -} - -bool Expression::IsVariableExpr() const { - if (impl_!= nullptr) { - return impl_->IsVariableExpr(); - } - return false; -} - -bool Expression::IsBooleanExpr() const { - if (impl_!= nullptr) { - return impl_->IsBooleanExpr(); - } - return false; -} - -Expression Expression::Replace(const std::vector> &replace_vars) const { - if (impl_ != nullptr) { - std::map impl_map; - for (auto &item : replace_vars) { - impl_map[item.first.impl_.get()] = item.second.impl_.get(); - } - return Expression(impl_->Replace(impl_map)); - } - return Expression(nullptr); -} - -Expression Expression::Subs(const std::vector> &subs_vars) const { - if (impl_ != nullptr) { - std::map impl_map; - for (auto &item : subs_vars) { - impl_map[item.first.impl_.get()] = item.second.impl_.get(); - } - return Expression(impl_->Subs(impl_map)); - } - return Expression(nullptr); -} - -std::vector Expression::FreeSymbols() const { - if (impl_!= nullptr) { - std::vector ret; - for (auto &free_symbol : impl_->FreeSymbols()) { - ret.emplace_back(Expression(std::move(free_symbol))); - } - return ret; - } - return {}; -} - -graphStatus Expression::GetResult(const std::vector> &vars_value, - double &result) const { - Expression replace_expr = Replace(vars_value); - if ((replace_expr.impl_ != nullptr) && (replace_expr.impl_->GetResult(result))) { - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -bool Expression::IsValid() const { - return impl_ != nullptr; -} - -uint64_t Expression::Hash() const { - if (impl_ != nullptr) { - return impl_->Hash(); - } - return std::numeric_limits::max(); -} - -int64_t Expression::Compare(const Expression &e) const { - if (impl_!= nullptr) { - return impl_->Compare(*e.impl_); - } - return std::numeric_limits::max(); -} - -// 模板函数的定义 -template -typename std::enable_if::value || std::is_floating_point::value, bool>::type -Expression::GetConstValue(T &value) const { - if (!IsConstExpr() || impl_== nullptr) { - return false; - } - return impl_->GetConstValue(value); -} - -// 显式实例化 -template bool Expression::GetConstValue(int32_t &) const; // 实例化 int32 类型 -template bool Expression::GetConstValue(uint32_t &) const; // 实例化 uint32 类型 -template bool Expression::GetConstValue(int64_t &) const; // 实例化 int64 类型 -template bool Expression::GetConstValue(uint64_t &) const; // 实例化 uint64 类型 -template bool Expression::GetConstValue(double &) const; // 实例化 double 类型 -template bool Expression::GetConstValue(float &) const; // 实例化 float 类型 -template bool Expression::GetConstValue(bool &) const; // 实例化 bool 类型 - -Expression Expression::operator+(const Expression &other) const { - return sym::Add(*this, other); -} - -Expression Expression::operator-(const Expression &other) const { - return sym::Sub(*this, other); -} - -Expression Expression::operator*(const Expression &other) const { - return sym::Mul(*this, other); -} - -Expression Expression::operator/(const Expression &other) const { - return sym::Div(*this, other); -} - -Expression Expression::Simplify() const { - if (GetCurShapeEnvContext() != nullptr) { - return GetCurShapeEnvContext()->Simplify(*this); - } - if (impl_ != nullptr) { - return Expression(impl_->Simplify()); - } - return Expression(nullptr); -} - -bool Expression::ContainVar(const Expression &e) const { - if (impl_ != nullptr) { - return impl_->ContainVar(e.impl_.get()); - } - return false; -} - -bool Expression::operator==(const Expression &e) const { - if (impl_ != nullptr && e.impl_ != nullptr) { - return (*impl_ == *e.impl_); - } - return false; -} - -bool Expression::operator!=(const Expression &e) const { - return !(*this == e); -} - -std::ostream &operator<<(std::ostream &os, const Expression &e) { - if (e.impl_ != nullptr) { - os << *e.impl_; - } - return os; -} - -Expression::Expression(ExpressionImplPtr &&e) - : impl_(std::move(e)) {} - -Expression::Expression() { - impl_ = ge::ComGraphMakeUnique(""); -} - -Expression Expression::CanonicalizeBoolExpr() const { - if (impl_ != nullptr) { - return Expression(impl_->CanonicalizeBoolExpr()); - } - return Expression(nullptr); -} - -Symbol::Symbol(ExpressionImplPtr &&e) : Expression(std::move(e)) {} - -Symbol::Symbol(int32_t value, const char_t *name) { - impl_ = ge::ComGraphMakeUnique(value, name); -} - -Symbol::Symbol(int64_t value, const char_t *name) { - impl_ = ge::ComGraphMakeUnique(value, name); -} -Symbol::Symbol(uint32_t value, const char_t *name) { - impl_ = ge::ComGraphMakeUnique(value, name); -} -Symbol::Symbol(uint64_t value, const char_t *name) { - impl_ = ge::ComGraphMakeUnique(value, name); -} -Symbol::Symbol(double value, const char_t *name) { - impl_ = ge::ComGraphMakeUnique(value, name); -} - -Symbol::Symbol(const char_t *name) { - impl_ = ge::ComGraphMakeUnique(name); -} - -std::unique_ptr Symbol::GetName() const { - if (impl_ != nullptr) { - auto str = impl_->GetName(); - if (str.empty()) { - return nullptr; - } - auto uni_ptr = ComGraphMakeUnique(str.size() + 1U); - IF_NULL_RETURN_NULL(uni_ptr); - // 当src size < dst size时,strncpy_s会在末尾str.size()位置添加'\0' - GE_ASSERT_EOK(strncpy_s(uni_ptr.get(), str.size() + 1, str.c_str(), str.size())); - return uni_ptr; - } - return nullptr; -} - -template -typename std::enable_if::value || std::is_floating_point::value, bool>::type -Expression::ComputeHint(T &hint) const { - if (IsConstExpr()) { - return GetConstValue(hint); - } - if (GetCurShapeEnvContext() == nullptr) { - GELOGW("Shape env is nullptr, cannot compute hint, expr: %s", Serialize().get()); - return false; - } - return GetCurShapeEnvContext()->EvaluateExpr(*this).GetConstValue(hint); -} - -template bool Expression::ComputeHint(int32_t &) const; // 实例化 int32 类型 -template bool Expression::ComputeHint(uint32_t &) const; // 实例化 uint32 类型 -template bool Expression::ComputeHint(int64_t &) const; // 实例化 int64 类型 -template bool Expression::ComputeHint(uint64_t &) const; // 实例化 uint64 类型 -template bool Expression::ComputeHint(double &) const; // 实例化 double 类型 -template bool Expression::ComputeHint(float &) const; // 实例化 float 类型 -template bool Expression::ComputeHint(bool &) const; // 实例化 bool 类型 - -namespace sym { -Expression operator+(const Expression &e1, const Expression &e2) { - return Add(e1, e2); -} - -Expression operator-(const Expression &e1, const Expression &e2) { - return Sub(e1, e2); -} - -Expression operator*(const Expression &e1, const Expression &e2) { - return Mul(e1, e2); -} - -Expression operator/(const Expression &e1, const Expression &e2) { - return Div(e1, e2); -} -} // namespace sym -} // namespace ge \ No newline at end of file diff --git a/graph/expression/expression_impl.cc b/graph/expression/expression_impl.cc deleted file mode 100644 index 4fda494281f0cdfb7b8fcdd438f6acd8767b1327..0000000000000000000000000000000000000000 --- a/graph/expression/expression_impl.cc +++ /dev/null @@ -1,827 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "expression_impl.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "expr_print_manager.h" -#include "const_values.h" -#include "expr_parser.h" -#include "common/checker.h" - -namespace ge { -namespace { -constexpr const char_t *kInvalidName = "INVALID_NAME"; - -inline std::string sym_str(const SymEngineExprPtr &expr) { - GE_ASSERT_TRUE(!expr.is_null()); - return expr->__str__(); -} - -bool IsNegative(const SymEngineExprPtr &expr) { - GE_ASSERT_TRUE(!expr.is_null()); - GELOGD("Negative check, %s type is %s", sym_str(expr).c_str(), type_code_name(expr->get_type_code()).c_str()); - if (SymEngine::is_a_Number(*expr.get())) { - const auto *s = SymEngine::down_cast(expr.get()); - return s->is_negative(); - } - - if (SymEngine::is_a(*expr.get())) { - const auto *s = SymEngine::down_cast(expr.get()); - return s->get_coef()->is_negative(); - } - return false; -} - -SymEngine::integer_class ExtractCoeff(const SymEngineExprPtr &expr) { - GE_ASSERT_TRUE(!expr.is_null()); - GELOGD("expr [%s] type is %s", sym_str(expr).c_str(), type_code_name(expr->get_type_code()).c_str()); - if (SymEngine::is_a(*expr)) { - return SymEngine::rcp_static_cast(expr)->as_integer_class(); - } else if (SymEngine::is_a(*expr)) { - for (const auto &arg : expr->get_args()) { - GELOGD("Mul expr [%s] type is %s", sym_str(arg).c_str(), type_code_name(arg->get_type_code()).c_str()); - if (SymEngine::is_a(*arg)) { - return SymEngine::rcp_static_cast(arg)->as_integer_class(); - } - } - } - return SymEngine::integer_class(1); -} - -SymEngine::integer_class ComputeGCD(const std::vector &coeff_vec) { - GE_ASSERT_TRUE(!coeff_vec.empty()); - SymEngine::integer_class gcd = coeff_vec[0]; - for (size_t i = 1; i < coeff_vec.size(); ++i) { - SymEngine::mp_gcd(gcd, gcd, coeff_vec[i]); - } - return gcd; -} - -SymEngine::integer_class ComputeExprGCD(const SymEngineExprPtr &expr) { - GE_ASSERT_TRUE(SymEngine::is_a(*expr)); - std::vector coeff_vec; - for (const auto &arg : expr->get_args()) { - auto coeff = ExtractCoeff(arg); - coeff_vec.push_back(coeff); - GELOGD("expr [%s], coeff = %lu", sym_str(arg).c_str(), SymEngine::mp_get_ui(coeff)); - } - return ComputeGCD(coeff_vec); -} - -SymEngine::map_basic_basic CreateMulDict(const SymEngine::vec_basic &args, - SymEngine::RCP &coeff) { - SymEngine::map_basic_basic dict; - for (const auto &arg : args) { - GELOGD("arg [%s] type %s", sym_str(arg).c_str(), type_code_name(arg->get_type_code()).c_str()); - auto it = dict.find(arg); - if (it != dict.end()) { - dict[arg] = add(it->second, SymEngine::one); - } else { - if (IsNegative(arg)) { - coeff = SymEngine::minus_one; - } - if (SymEngine::is_a(*arg)) { - auto in = SymEngine::rcp_static_cast(arg); - if (in->is_one() || in->is_minus_one()) { - continue; - } - } - dict[arg] = SymEngine::one; - } - } - return dict; -} - -std::vector TrDivision(const SymEngineExprPtr &lhs, const SymEngineExprPtr &rhs) { - std::vector vec; - if (SymEngine::is_a(*lhs) && SymEngine::is_a(*rhs)) { - SymEngineExprPtr ratio = SymEngine::div(lhs, rhs); - if (SymEngine::is_a(*ratio) || SymEngine::eq(*ratio, *SymEngine::one)) { - vec.push_back(ratio); - vec.push_back(SymEngine::one); - return vec; - } - - ratio = SymEngine::div(rhs, lhs); - if (SymEngine::is_a(*ratio) || SymEngine::eq(*ratio, *SymEngine::one)) { - vec.push_back(SymEngine::one); - vec.push_back(ratio); - return vec; - } - } - vec.push_back(lhs); - vec.push_back(rhs); - return vec; -} - -SymEngineExprPtr DivByFactor(const SymEngineExprPtr &x, const SymEngine::integer_class &factor) { - GE_ASSERT_TRUE(!x.is_null()); - GELOGD("DivByFactor [%s], factor = %lu", sym_str(x).c_str(), SymEngine::mp_get_ui(factor)); - if (SymEngine::is_a(*x)) { - SymEngine::integer_class x_val = SymEngine::rcp_static_cast(x)->as_integer_class(); - SymEngine::integer_class result; - SymEngine::mp_divexact(result, x_val, factor); - return SymEngine::integer(result); - } else if (SymEngine::is_a(*x)) { - const auto &args = x->get_args(); - GE_ASSERT_TRUE(!args.empty()); - SymEngine::integer_class total_coeff(1); - SymEngine::vec_basic remaining_args; - for (const auto &arg : args) { - if (SymEngine::is_a(*arg)) { - total_coeff *= SymEngine::rcp_static_cast(arg)->as_integer_class(); - } else { - remaining_args.push_back(arg); - } - } - - SymEngine::integer_class new_coeff; - SymEngine::mp_divexact(new_coeff, total_coeff, factor); - SymEngine::vec_basic new_args; - GELOGD("total_coeff=%lu, factor=%lu, new_coeff=%lu", SymEngine::mp_get_ui(total_coeff), - SymEngine::mp_get_ui(factor), SymEngine::mp_get_ui(new_coeff)); - if (new_coeff != 1) { - new_args.push_back(SymEngine::integer(new_coeff)); - } - new_args.insert(new_args.end(), remaining_args.begin(), remaining_args.end()); - SymEngine::RCP coeff = SymEngine::one; - auto dict = CreateMulDict(new_args, coeff); - return SymEngine::Mul::from_dict(coeff, std::move(dict)); - } else { - GELOGW("unsupported type %s", type_code_name(x->get_type_code()).c_str()); - return x; - } -} -} - -ExpressionImplPtr ExpressionImpl::CreateExpressionImpl(const std::string &name) { - return ge::ComGraphMakeUnique(name); -} - -ExpressionImpl::~ExpressionImpl() {} - -std::string ExpressionImpl::Str(const StrType type) const { - GE_ASSERT_TRUE(!sym_expr_.is_null()); - if (type == StrType::kStrCpp) { - if (SymEngine::is_a(*sym_expr_)) { - const auto &x = SymEngine::down_cast(*sym_expr_); - auto dens = x.get_den(); - auto nums = x.get_num(); - GE_ASSERT_TRUE(!dens.is_null()); - GE_ASSERT_TRUE(!nums.is_null()); - return "Rational(" + nums->__str__() + " , " + dens->__str__() + ")"; - } - } - if (((GetExprType() == ExprType::kExprOperation) || - (GetExprType() == ExprType::kExprOperationBoolean)) && - (GetOpType() != OperationType::kOpNone)) { - auto printer = ExprManager::GetInstance().GetPrinter(GetOpType()); - GE_ASSERT_NOTNULL(printer); - return printer(sym_expr_->get_args(), type); - } - return sym_expr_->__str__(); -} - -ExpressionImplPtr ExpressionImpl::Parse(const std::string &expr_str) { - Scanner scanner(expr_str); - ge::ExprParser expr_parser(scanner); - auto ret = expr_parser.ParserExpression(); - GE_WARN_ASSERT(ret != nullptr, "Parse expression %s abnormal.", expr_str.c_str()); - return ret; -} - -ExpressionImplPtr ExpressionImpl::Deserialize(const std::string &expr_str) { - auto ret = Parse(expr_str); - GE_WARN_ASSERT(ret != nullptr); - if (ret->Str() == expr_str) { - return ret; - } else { - GELOGW("Parse expression str %s abnormal, result is %s, please check the string is valid.", - expr_str.c_str(), ret->Str().c_str()); - return nullptr; - } -} - -ExpressionImplPtr ExpressionImpl::Replace(const std::map &replace_vars) const { - SymEngine::map_basic_basic sym_replace_vars; - for (const auto &sym_expr_impl_ptr_item : replace_vars) { - GE_ASSERT_NOTNULL(sym_expr_impl_ptr_item.first); - GE_ASSERT_NOTNULL(sym_expr_impl_ptr_item.second); - GE_ASSERT_TRUE(!sym_expr_impl_ptr_item.first->sym_expr_.is_null()); - GE_ASSERT_TRUE(!sym_expr_impl_ptr_item.second->sym_expr_.is_null()); - sym_replace_vars.emplace(sym_expr_impl_ptr_item.first->sym_expr_, sym_expr_impl_ptr_item.second->sym_expr_); - } - GE_ASSERT_TRUE(!sym_expr_.is_null()); - SymEngineExprPtr replaced_expr = sym_expr_->xreplace(sym_replace_vars); - return ExpressionImpl::CreateExpressionImpl(replaced_expr); -} - -ExpressionImplPtr ExpressionImpl::Subs(const std::map &subs_vars) const { - SymEngine::map_basic_basic sym_replace_vars; - for (const auto &sym_expr_impl_ptr_item : subs_vars) { - GE_ASSERT_NOTNULL(sym_expr_impl_ptr_item.first); - GE_ASSERT_NOTNULL(sym_expr_impl_ptr_item.second); - GE_ASSERT_TRUE(!sym_expr_impl_ptr_item.first->sym_expr_.is_null()); - GE_ASSERT_TRUE(!sym_expr_impl_ptr_item.second->sym_expr_.is_null()); - sym_replace_vars.emplace(sym_expr_impl_ptr_item.first->sym_expr_, sym_expr_impl_ptr_item.second->sym_expr_); - } - GE_ASSERT_TRUE(!sym_expr_.is_null()); - SymEngineExprPtr subs_expr = sym_expr_->subs(sym_replace_vars); - return ExpressionImpl::CreateExpressionImpl(subs_expr); -} - -ExpressionImplPtr ExpressionImpl::Simplify() const { - SymEngineExprPtr expanded_expr = SymEngine::expand(sym_expr_); - SymEngineExprPtr simplified_expr = SymEngine::simplify(expanded_expr); - return ExpressionImpl::CreateExpressionImpl(simplified_expr); -} - -bool ExpressionImpl::ContainVar(const ExpressionImpl *e) const { - GE_ASSERT_NOTNULL(e); - GE_ASSERT_TRUE(!e->sym_expr_.is_null()); - if (!(e->sym_expr_->get_args().empty())) { - return false; - } - for (const auto &arg : FreeSymbols()) { - GE_ASSERT_NOTNULL(arg); - GE_ASSERT_NOTNULL(e); - GE_ASSERT_TRUE(!arg->sym_expr_.is_null()); - GE_ASSERT_TRUE(!e->sym_expr_.is_null()); - if (SymEngine::eq(*arg->sym_expr_, *(e->sym_expr_))) { - return true; - } - } - return false; -} - -std::vector ExpressionImpl::FreeSymbols() const { - GE_ASSERT_TRUE(!sym_expr_.is_null()); - auto free_symbols = SymEngine::free_symbols(*sym_expr_); - std::vector ret; - for (const auto &sym_arg : free_symbols) { - auto expr = ExpressionImpl::CreateExpressionImpl(sym_arg); - ret.emplace_back(std::move(expr)); - } - return ret; -} - -bool ExpressionImpl::operator==(const ExpressionImpl &e) const { - GE_ASSERT_TRUE(!sym_expr_.is_null()); - GE_ASSERT_TRUE(!e.sym_expr_.is_null()); - return SymEngine::eq(*sym_expr_, *e.sym_expr_); -} - -ExpressionImplPtr ExpressionImpl::CanonicalizeBoolExpr() { - GE_ASSERT_NOTNULL(sym_expr_.get()); - GELOGI("EXPR(before) [%s]", sym_str(sym_expr_).c_str()); - OperationType type = GetOpType(); - std::unordered_map op_func_map = { - {OperationType::kOpEq, [](const auto &a, const auto &b) { return SymEngine::Eq(a, b); }}, - {OperationType::kOpNe, [](const auto &a, const auto &b) { return SymEngine::Ne(a, b); }}, - {OperationType::kOpLt, [](const auto &a, const auto &b) { return SymEngine::Lt(a, b); }}, - {OperationType::kOpLe, [](const auto &a, const auto &b) { return SymEngine::Le(a, b); }}}; - - if (op_func_map.find(type) == op_func_map.end()) { - GELOGI("EXPR(after) [%s]", sym_str(sym_expr_).c_str()); - return ExpressionImpl::CreateExpressionImpl(sym_expr_); - } - - SymEngine::vec_basic args = sym_expr_.get()->get_args(); - GE_ASSERT_TRUE(args.size() == kSizeTwo); - SymEngineExprPtr rhs = SymEngine::sub(args[1], args[0]); - SymEngineExprPtr lhs = SymEngine::zero; - GELOGI("step1 rhs = [%s], lhs = [%s]", sym_str(rhs).c_str(), sym_str(lhs).c_str()); - - rhs = SymEngine::expand(rhs); - if (SymEngine::is_a(*rhs)) { - SymEngine::integer_class gcd = ComputeExprGCD(rhs); - GELOGD("[%s], gcd = %lu", sym_str(rhs).c_str(), SymEngine::mp_get_ui(gcd)); - if (SymEngine::mp_get_ui(gcd) > 1) { - SymEngine::vec_basic div_gcd_vec; - for (const auto &arg : rhs->get_args()) { - auto expr_div = DivByFactor(arg, gcd); - div_gcd_vec.push_back(expr_div); - GELOGD("DivByFactor [%s] -> [%s]", sym_str(arg).c_str(), sym_str(expr_div).c_str()); - } - rhs = SymEngine::add(div_gcd_vec); - } - } - GELOGI("step2 rhs = [%s], lhs = [%s]", sym_str(rhs).c_str(), sym_str(lhs).c_str()); - - if (SymEngine::is_a(*rhs)) { - SymEngine::vec_basic pos_vec; - SymEngine::vec_basic neg_vec; - for (const auto &arg : rhs->get_args()) { - if (IsNegative(arg)) { - GELOGD("negative push [%s]", sym_str(arg).c_str()); - neg_vec.push_back(SymEngine::sub(SymEngine::zero, arg)); - } else { - pos_vec.push_back(arg); - } - } - rhs = SymEngine::add(pos_vec); - lhs = SymEngine::add(neg_vec); - } - if (IsNegative(rhs) && SymEngine::is_number_and_zero(*lhs)) { - lhs = SymEngine::sub(SymEngine::zero, rhs); - rhs = SymEngine::zero; - } - GELOGI("step3 rhs = [%s], lhs = [%s]", sym_str(rhs).c_str(), sym_str(lhs).c_str()); - - auto vec = TrDivision(lhs, rhs); - GELOGI("step4 rhs = [%s], lhs = [%s]", sym_str(vec[1]).c_str(), sym_str(vec[0]).c_str()); - - auto expr_new = op_func_map[type](vec[0], vec[1]); - GELOGI("EXPR(after) [%s] (canonicalized)", sym_str(expr_new).c_str()); - return ExpressionImpl::CreateExpressionImpl(expr_new); -} - -vector ExpressionImpl::GetArgs() { - vector args; - GE_ASSERT_TRUE(!sym_expr_.is_null()); - for (auto &arg : sym_expr_.get()->get_args()) { - args.emplace_back(ExpressionImpl::CreateExpressionImpl(arg)); - } - return args; -} - -double ExpressionImpl::GetIntegerResult(const SymEngine::Integer &integer_expr) const { - if (integer_expr.is_zero()) { - return 0; - } else if (integer_expr.is_positive()) { - return static_cast(integer_expr.as_uint()); - } - return static_cast(integer_expr.as_int()); -} - -bool ExpressionImpl::GetResult(double &result) const { - GE_ASSERT_TRUE(!sym_expr_.is_null()); - if (SymEngine::is_a(*sym_expr_)) { - const auto &integer_expr = SymEngine::down_cast(*sym_expr_); - result = GetIntegerResult(integer_expr); - return true; - } - if (SymEngine::is_a(*sym_expr_)) { - const auto &rational_expr = SymEngine::down_cast(*sym_expr_); - result = GetIntegerResult(*(rational_expr.get_num())) / GetIntegerResult(*(rational_expr.get_den())); - return true; - } - if (SymEngine::is_a(*sym_expr_)) { - const auto &real_double_expr = SymEngine::down_cast(*sym_expr_); - result = real_double_expr.as_double(); - return true; - } - return false; -} - -uint64_t ExpressionImpl::Hash() const { - GE_ASSERT_TRUE(!sym_expr_.is_null()); - return sym_expr_->hash(); -} - -int64_t ExpressionImpl::Compare(const ExpressionImpl &e) const { - GE_ASSERT_TRUE(!sym_expr_.is_null()); - GE_ASSERT_TRUE(!e.sym_expr_.is_null()); - return sym_expr_->__cmp__(*e.sym_expr_); -} - -bool ExpressionImpl::IsVariableExpr() const { - return GetExprType() == ExprType::kExprVariable; -} - -bool ExpressionImpl::IsBooleanExpr() const { - return (GetExprType() == ExprType::kExprOperationBoolean) || - (GetExprType() == ExprType::kExprConstantBoolean); -} - -bool ExpressionImpl::GetConstValue(uint32_t &value) const { - uint64_t result = 0UL; - GE_ASSERT_TRUE(GetConstValue(result)); - value = static_cast(result); - return true; -} - -bool ExpressionImpl::GetConstValue(uint64_t &value) const { - GE_ASSERT_TRUE(!sym_expr_.is_null()); - // 无符号整数类型 - GE_ASSERT_TRUE(SymEngine::is_a(*sym_expr_), - "Cannot get const uint value from a expression: %s not Integer.", Str().c_str()); - const auto &integer_expr = SymEngine::down_cast(*sym_expr_); - value = integer_expr.as_uint(); - return true; -} - -bool ExpressionImpl::GetConstValue(int32_t &value) const { - int64_t result = 0L; - GE_ASSERT_TRUE(GetConstValue(result)); - value = static_cast(result); - return true; -} - -bool ExpressionImpl::GetConstValue(int64_t &value) const { - GE_ASSERT_TRUE(!sym_expr_.is_null()); - // 整数类型 - GE_ASSERT_TRUE(SymEngine::is_a(*sym_expr_), - "Cannot get const int value from a expression: %s not Integer.", Str().c_str()); - const auto &integer_expr = SymEngine::down_cast(*sym_expr_); - value = integer_expr.as_int(); - return true; -} - -bool ExpressionImpl::GetConstValue(bool &value) const { - GE_ASSERT_TRUE(!sym_expr_.is_null()); - // bool类型 - GE_ASSERT_TRUE(SymEngine::is_a(*sym_expr_), - "Cannot get const bool value from a expression: %s not BooleanAtom.", Str().c_str()); - const auto &bool_expr = SymEngine::down_cast(*sym_expr_); - value = bool_expr.get_val(); - return true; -} - -bool ExpressionImpl::GetConstValue(float &value) const { - double result = 0L; - GE_ASSERT_TRUE(GetConstValue(result)); - value = static_cast(result); - return true; -} - -bool ExpressionImpl::GetConstValue(double &value) const { - GE_ASSERT_TRUE(!sym_expr_.is_null()); - GE_ASSERT_TRUE((SymEngine::is_a(*sym_expr_)) || - (SymEngine::is_a(*sym_expr_)), - "Cannot get const float value from a expression: %s not RealDouble or Rational.", - Str().c_str()); - if (SymEngine::is_a(*sym_expr_)) { - const auto &real_double_expr = SymEngine::down_cast(*sym_expr_); - value = real_double_expr.as_double(); - } else { - // 分数 - const auto &rational_expr = SymEngine::down_cast(*sym_expr_); - value = GetIntegerResult(*(rational_expr.get_num())) / GetIntegerResult(*(rational_expr.get_den())); - } - return true; -} - -OperationType ExpressionImpl::GetOpType() const { - if (sym_expr_.is_null()) { - GELOGE(ge::PARAM_INVALID, "sym_expr_ is null."); - return OperationType::kOpNone; - } - if (SymEngine::is_a(*sym_expr_)) { - return OperationType::kOpAdd; - } - if (SymEngine::is_a(*sym_expr_)) { - return OperationType::kOpMul; - } - if (SymEngine::is_a(*sym_expr_)) { - return OperationType::kOpMax; - } - if (SymEngine::is_a(*sym_expr_)) { - return OperationType::kOpMin; - } - if (SymEngine::is_a(*sym_expr_)) { - return OperationType::kOpPow; - } - if (SymEngine::is_a(*sym_expr_)) { - return OperationType::kOpMod; - } - if (SymEngine::is_a(*sym_expr_)) { - return OperationType::kOpLog; - } - if (SymEngine::is_a(*sym_expr_)) { - return OperationType::kOpAbs; - } - if (SymEngine::is_a(*sym_expr_)) { - return OperationType::kOpCeil; - } - if (SymEngine::is_a(*sym_expr_)) { - return OperationType::kOpFloor; - } - if (SymEngine::is_a(*sym_expr_)) { - return OperationType::kOpEq; - } - if (SymEngine::is_a(*sym_expr_)) { - return OperationType::kOpNe; - } - if (SymEngine::is_a(*sym_expr_)) { - return OperationType::kOpLe; - } - if (SymEngine::is_a(*sym_expr_)) { - return OperationType::kOpLt; - } - if (SymEngine::is_a(*sym_expr_)) { - return OperationType::kOpLogicalAnd; - } - if (SymEngine::is_a(*sym_expr_)) { - return OperationType::kOpLogicalOr; - } - return OperationType::kOpNone; -} - -std::string ExpressionImpl::GetName() const { - if (IsConstExpr() || GetExprType() == ExprType::kExprVariable) { - if (name_.empty()) { - static std::atomic unique_id(0); - // 此处不应该使用Str()拼接,比如对于1.0,会生成Const_1.0_1,如果codegen采用此名字定义c++变量会编译报错 - name_ = "Const_" + std::to_string(unique_id.fetch_add(1)); - } - return name_; - } else { - return kInvalidName; - } -} - -ExprType ExpressionImpl::GetExprType() const { - if (sym_expr_.is_null()) { - GELOGE(ge::PARAM_INVALID, "sym_expr_ is null."); - return ExprType::kExprNone; - } - if (SymEngine::is_a_Number(*sym_expr_)) { - if (SymEngine::is_a(*sym_expr_)) { - return ExprType::kExprConstantInteger; - } else if (SymEngine::is_a(*sym_expr_)) { - return ExprType::kExprConstantRealDouble; - } else if (SymEngine::is_a(*sym_expr_)) { - return ExprType::kExprConstantRation; - } else { - GELOGE(ge::PARAM_INVALID, "Unsupported type for expression %s", sym_expr_->__str__().c_str()); - return ExprType::kExprNone; - } - } - if (SymEngine::is_a(*sym_expr_)) { - return ExprType::kExprConstantBoolean; - } - if (SymEngine::is_a(*sym_expr_)) { - return ExprType::kExprVariable; - } - if (SymEngine::is_a_Boolean(*sym_expr_)) { - return ExprType::kExprOperationBoolean; - } - return ExprType::kExprOperation; -} - -bool ExpressionImpl::IsConstExpr() const { - return GetExprType() < ExprType::kExprVariable; -} - -ExpressionImpl::ExpressionImpl(int32_t const_value, const std::string &name) - : sym_expr_(SymEngine::integer(const_value)), name_(name) {} - -ExpressionImpl::ExpressionImpl(int64_t const_value, const std::string &name) - : sym_expr_(SymEngine::integer(const_value)), name_(name) {} - -ExpressionImpl::ExpressionImpl(uint32_t const_value, const std::string &name) - : sym_expr_(SymEngine::integer(const_value)), name_(name) {} - -ExpressionImpl::ExpressionImpl(uint64_t const_value, const std::string &name) - : sym_expr_(SymEngine::integer(const_value)), name_(name) {} - -ExpressionImpl::ExpressionImpl(double const_value, const std::string &name) - : sym_expr_(SymEngine::real_double(const_value)), name_(name) {} - -ExpressionImpl::ExpressionImpl(bool const_value, const std::string &name) - : sym_expr_(SymEngine::boolean(const_value)), name_(name) {} - -ExpressionImpl::ExpressionImpl(const std::string &name) : sym_expr_(SymEngine::symbol(name)), name_(name) {} - -ExpressionImpl::ExpressionImpl(const SymEngineExprPtr &sym_expr, const std::string &name) - : sym_expr_(sym_expr), name_(name) {} - -ExpressionImplPtr Add(const ExpressionImplPtr &a, const ExpressionImplPtr &b) { - GE_ASSERT_NOTNULL(a); - GE_ASSERT_NOTNULL(b); - GE_ASSERT_TRUE(!a->sym_expr_.is_null()); - GE_ASSERT_TRUE(!b->sym_expr_.is_null()); - SymEngineExprPtr sym_expr = SymEngine::add(a->sym_expr_, b->sym_expr_); - return ExpressionImpl::CreateExpressionImpl(sym_expr); -} - -ExpressionImplPtr Sub(const ExpressionImplPtr &a, const ExpressionImplPtr &b) { - GE_ASSERT_NOTNULL(a); - GE_ASSERT_NOTNULL(b); - GE_ASSERT_TRUE(!a->sym_expr_.is_null()); - GE_ASSERT_TRUE(!b->sym_expr_.is_null()); - SymEngineExprPtr sym_expr = SymEngine::sub(a->sym_expr_, b->sym_expr_); - return ExpressionImpl::CreateExpressionImpl(sym_expr); -} - -ExpressionImplPtr Mul(const ExpressionImplPtr &a, const ExpressionImplPtr &b) { - GE_ASSERT_NOTNULL(a); - GE_ASSERT_NOTNULL(b); - GE_ASSERT_TRUE(!a->sym_expr_.is_null()); - GE_ASSERT_TRUE(!b->sym_expr_.is_null()); - SymEngineExprPtr sym_expr = SymEngine::mul(a->sym_expr_, b->sym_expr_); - return ExpressionImpl::CreateExpressionImpl(sym_expr); -} - -ExpressionImplPtr Div(const ExpressionImplPtr &a, const ExpressionImplPtr &b) { - GE_ASSERT_NOTNULL(a); - GE_ASSERT_NOTNULL(b); - GE_ASSERT_TRUE(!a->sym_expr_.is_null()); - GE_ASSERT_TRUE(!b->sym_expr_.is_null()); - SymEngineExprPtr sym_expr = SymEngine::div(a->sym_expr_, b->sym_expr_); - return ExpressionImpl::CreateExpressionImpl(sym_expr); -} - -ExpressionImplPtr Max(const ExpressionImplPtr &a, const ExpressionImplPtr &b) { - GE_ASSERT_NOTNULL(a); - GE_ASSERT_NOTNULL(b); - GE_ASSERT_TRUE(!a->sym_expr_.is_null()); - GE_ASSERT_TRUE(!b->sym_expr_.is_null()); - SymEngineExprPtr sym_expr = SymEngine::max({a->sym_expr_, b->sym_expr_}); - return ExpressionImpl::CreateExpressionImpl(sym_expr); -} - -ExpressionImplPtr Min(const ExpressionImplPtr &a, const ExpressionImplPtr &b) { - GE_ASSERT_NOTNULL(a); - GE_ASSERT_NOTNULL(b); - GE_ASSERT_TRUE(!a->sym_expr_.is_null()); - GE_ASSERT_TRUE(!b->sym_expr_.is_null()); - SymEngineExprPtr sym_expr = SymEngine::min({a->sym_expr_, b->sym_expr_}); - return ExpressionImpl::CreateExpressionImpl(sym_expr); -} - -ExpressionImplPtr Abs(const ExpressionImplPtr &a) { - GE_ASSERT_NOTNULL(a); - GE_ASSERT_TRUE(!a->sym_expr_.is_null()); - SymEngineExprPtr sym_expr = SymEngine::abs(a->sym_expr_); - return ExpressionImpl::CreateExpressionImpl(sym_expr); -} - -ExpressionImplPtr Pow(const ExpressionImplPtr &a, const ExpressionImplPtr &b) { - GE_ASSERT_NOTNULL(a); - GE_ASSERT_NOTNULL(b); - GE_ASSERT_TRUE(!a->sym_expr_.is_null()); - GE_ASSERT_TRUE(!b->sym_expr_.is_null()); - SymEngineExprPtr sym_expr = SymEngine::pow(a->sym_expr_, b->sym_expr_); - return ExpressionImpl::CreateExpressionImpl(sym_expr); -} - - -ExpressionImplPtr Mod(const ExpressionImplPtr &a, const ExpressionImplPtr &b) { - GE_ASSERT_NOTNULL(a); - GE_ASSERT_NOTNULL(b); - GE_ASSERT_TRUE(!a->sym_expr_.is_null()); - GE_ASSERT_TRUE(!b->sym_expr_.is_null()); - SymEngineExprPtr sym_expr = SymEngine::mod(a->sym_expr_, b->sym_expr_); - return ExpressionImpl::CreateExpressionImpl(sym_expr); -} - -ExpressionImplPtr Log(const ExpressionImplPtr &a) { - GE_ASSERT_NOTNULL(a); - GE_ASSERT_TRUE(!a->sym_expr_.is_null()); - SymEngineExprPtr sym_expr = SymEngine::log(a->sym_expr_); - return ExpressionImpl::CreateExpressionImpl(sym_expr); -} - -ExpressionImplPtr Log(const ExpressionImplPtr &arg, const ExpressionImplPtr &base) { - GE_ASSERT_NOTNULL(arg); - GE_ASSERT_NOTNULL(base); - GE_ASSERT_TRUE(!arg->sym_expr_.is_null()); - GE_ASSERT_TRUE(!base->sym_expr_.is_null()); - SymEngineExprPtr sym_expr = SymEngine::log(arg->sym_expr_, base->sym_expr_); - return ExpressionImpl::CreateExpressionImpl(sym_expr); -} - -ExpressionImplPtr Coeff(const ExpressionImplPtr &b, const ExpressionImplPtr &x, const ExpressionImplPtr &n) { - GE_ASSERT_NOTNULL(b); - GE_ASSERT_NOTNULL(x); - GE_ASSERT_NOTNULL(n); - GE_ASSERT_TRUE(!b->sym_expr_.is_null()); - GE_ASSERT_TRUE(!x->sym_expr_.is_null()); - GE_ASSERT_TRUE(!n->sym_expr_.is_null()); - SymEngineExprPtr sym_expr = SymEngine::coeff(*b->sym_expr_.get(), *x->sym_expr_.get(), *n->sym_expr_.get()); - return ExpressionImpl::CreateExpressionImpl(sym_expr); -} - -ExpressionImplPtr Ceiling(const ExpressionImplPtr &a) { - GE_ASSERT_NOTNULL(a); - GE_ASSERT_TRUE(!a->sym_expr_.is_null()); - SymEngineExprPtr sym_expr = SymEngine::ceiling(a->sym_expr_); - return ExpressionImpl::CreateExpressionImpl(sym_expr); -} - -ExpressionImplPtr Floor(const ExpressionImplPtr &a) { - GE_ASSERT_NOTNULL(a); - GE_ASSERT_TRUE(!a->sym_expr_.is_null()); - SymEngineExprPtr sym_expr = SymEngine::floor(a->sym_expr_); - return ExpressionImpl::CreateExpressionImpl(sym_expr); -} - -ExpressionImplPtr Rational(const ExpressionImplPtr &a, const ExpressionImplPtr &b) { - GE_ASSERT_NOTNULL(a); - GE_ASSERT_NOTNULL(b); - GE_ASSERT_TRUE(!a->sym_expr_.is_null()); - GE_ASSERT_TRUE(!b->sym_expr_.is_null()); - if (SymEngine::is_a(*a->sym_expr_) && SymEngine::is_a(*b->sym_expr_)) { - const auto &integer_expr_a = SymEngine::down_cast(*a->sym_expr_); - const auto &integer_expr_b = SymEngine::down_cast(*b->sym_expr_); - SymEngineExprPtr sym_expr = SymEngine::Rational::from_two_ints(integer_expr_a, integer_expr_b); - auto impl = ExpressionImpl::CreateExpressionImpl(sym_expr); - return impl; - } else { - std::cerr << "unsupported rational expr" << std::endl; - return nullptr; - } -} - -std::ostream &operator<<(std::ostream &os, const ExpressionImpl &e) { - os << e.Str(); - return os; -} - -ExpressionImplPtr Eq(const ExpressionImplPtr &a, const ExpressionImplPtr &b) { - GE_ASSERT_NOTNULL(a); - GE_ASSERT_NOTNULL(b); - GE_ASSERT_TRUE(!a->sym_expr_.is_null()); - GE_ASSERT_TRUE(!b->sym_expr_.is_null()); - SymEngineExprPtr sym_expr = SymEngine::Eq(a->sym_expr_, b->sym_expr_); - return ExpressionImpl::CreateExpressionImpl(sym_expr); -} - -ExpressionImplPtr Ne(const ExpressionImplPtr &a, const ExpressionImplPtr &b) { - GE_ASSERT_NOTNULL(a); - GE_ASSERT_NOTNULL(b); - GE_ASSERT_TRUE(!a->sym_expr_.is_null()); - GE_ASSERT_TRUE(!b->sym_expr_.is_null()); - SymEngineExprPtr sym_expr = SymEngine::Ne(a->sym_expr_, b->sym_expr_); - return ExpressionImpl::CreateExpressionImpl(sym_expr); -} - -ExpressionImplPtr Lt(const ExpressionImplPtr &a, const ExpressionImplPtr &b) { - GE_ASSERT_NOTNULL(a); - GE_ASSERT_NOTNULL(b); - GE_ASSERT_TRUE(!a->sym_expr_.is_null()); - GE_ASSERT_TRUE(!b->sym_expr_.is_null()); - SymEngineExprPtr sym_expr = SymEngine::Lt(a->sym_expr_, b->sym_expr_); - return ExpressionImpl::CreateExpressionImpl(sym_expr); -} -ExpressionImplPtr Le(const ExpressionImplPtr &a, const ExpressionImplPtr &b) { - GE_ASSERT_NOTNULL(a); - GE_ASSERT_NOTNULL(b); - GE_ASSERT_TRUE(!a->sym_expr_.is_null()); - GE_ASSERT_TRUE(!b->sym_expr_.is_null()); - SymEngineExprPtr sym_expr = SymEngine::Le(a->sym_expr_, b->sym_expr_); - return ExpressionImpl::CreateExpressionImpl(sym_expr); -} -ExpressionImplPtr Not(const ExpressionImplPtr &a) { - GE_ASSERT_NOTNULL(a); - GE_ASSERT_TRUE(!a->sym_expr_.is_null()); - if (!SymEngine::is_a_Boolean(*a->sym_expr_)) { - GELOGE(ge::PARAM_INVALID, "Logic operator Not only can handle Boolean expression:%s", - a->Str().c_str()); - return nullptr; - } - SymEngineExprPtr sym_expr = - SymEngine::logical_not(SymEngine::rcp_static_cast(a->sym_expr_)); - return ExpressionImpl::CreateExpressionImpl(sym_expr); -} -ExpressionImplPtr Neg(const ExpressionImplPtr &a) { - GE_ASSERT_NOTNULL(a); - GE_ASSERT_TRUE(!a->sym_expr_.is_null()); - SymEngineExprPtr sym_expr = SymEngine::neg(a->sym_expr_); - return ExpressionImpl::CreateExpressionImpl(sym_expr); -} - -ExpressionImplPtr LogicalAnd(const vector &s) { - SymEngine::set_boolean set; - for (const auto &e : s) { - GE_ASSERT_NOTNULL(e); - GE_ASSERT_TRUE(!e->sym_expr_.is_null()); - GE_ASSERT_TRUE(SymEngine::is_a_Boolean(*e->sym_expr_), "Logic operator And only can handle Boolean expression: %s", - e->Str().c_str()); - set.insert(SymEngine::rcp_static_cast(e->sym_expr_)); - } - return ExpressionImpl::CreateExpressionImpl(SymEngine::logical_and(set)); -} - -ExpressionImplPtr LogicalOr(const vector &s) { - SymEngine::set_boolean set; - for (const auto &e : s) { - GE_ASSERT_NOTNULL(e); - GE_ASSERT_TRUE(!e->sym_expr_.is_null()); - GE_ASSERT_TRUE(SymEngine::is_a_Boolean(*e->sym_expr_), "Logic operator Or only can handle Boolean expression: %s", - e->Str().c_str()); - set.insert(SymEngine::rcp_static_cast(e->sym_expr_)); - } - return ExpressionImpl::CreateExpressionImpl(SymEngine::logical_or(set)); -} -} // namespace ge diff --git a/graph/expression/expression_impl.h b/graph/expression/expression_impl.h deleted file mode 100644 index 318d78514cb623b6fb7acc895c7d2c60f669fe02..0000000000000000000000000000000000000000 --- a/graph/expression/expression_impl.h +++ /dev/null @@ -1,183 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_EXPRESSION_EXPRESSION_IMPL_H_ -#define GRAPH_EXPRESSION_EXPRESSION_IMPL_H_ -#include -#include -#include -#include -#include -#include -#include "graph/symbolizer/symbolic.h" -#include "graph/debug/ge_util.h" - -#define IF_NULL_RETURN_NULL(x) \ - if ((x) == nullptr) \ - return nullptr - -namespace ge { -constexpr int32_t kSizeTwo = 2; - -class ExpressionImpl; -using ExpressionImplPtr = std::unique_ptr; -using RelationalFunc = std::function( - const SymEngine::RCP &, const SymEngine::RCP &)>; - -enum OperationType : size_t { - kOpAdd = 0, - kOpMax, - kOpMin, - kOpMul, - kOpAbs, - kOpPow, - kOpMod, - kOpLog, - kOpCeil, - kOpFloor, - kOpEq, - kOpNe, - kOpLt, - kOpLe, - kOpLogicalAnd, - kOpLogicalOr, - kOpNone = std::numeric_limits::max() -}; - -using SymEngineExprPtr = SymEngine::RCP; - -class ExpressionImpl { - public: - // 目前只支持int32_t,int64_t,uint32_t,uint64_t,const string&,const SymEngineExprPtr& - template - static std::unique_ptr CreateExpressionImpl(T value, const std::string &name = "") { - return ge::ComGraphMakeUnique(value, name); - } - ExpressionImpl() = default; - ExpressionImpl(int32_t const_value, const std::string &name); - ExpressionImpl(int64_t const_value, const std::string &name); - ExpressionImpl(uint32_t const_value, const std::string &name); - ExpressionImpl(uint64_t const_value, const std::string &name); - ExpressionImpl(double const_value, const std::string &name); - ExpressionImpl(bool const_value, const std::string &name); - explicit ExpressionImpl(const std::string &name); - ExpressionImpl(const SymEngineExprPtr &sym_expr, const std::string &name); - - static ExpressionImplPtr CreateExpressionImpl(const std::string &name); - ~ExpressionImpl(); - - std::string Str(const StrType type = StrType::kStrCpp) const; - static ExpressionImplPtr Parse(const std::string &expr_str); - static ExpressionImplPtr Deserialize(const std::string &expr_str); - ExprType GetExprType() const; - bool IsConstExpr() const; - bool IsVariableExpr() const; - bool IsBooleanExpr() const; - ExpressionImplPtr Replace(const std::map &replace_vars) const; - ExpressionImplPtr Subs(const std::map &subs_vars) const; - - ExpressionImplPtr Simplify() const; - bool ContainVar(const ExpressionImpl *e) const; - bool operator==(const ExpressionImpl &e) const; - std::vector FreeSymbols() const; - OperationType GetOpType() const; - std::string GetName() const; - bool GetResult(double &result) const; - uint64_t Hash() const; - int64_t Compare(const ExpressionImpl &e) const; - - bool GetConstValue(uint64_t &value) const; - bool GetConstValue(uint32_t &value) const; - bool GetConstValue(int32_t &value) const; - bool GetConstValue(int64_t &value) const; - bool GetConstValue(bool &value) const; - bool GetConstValue(double &value) const; - bool GetConstValue(float &value) const; - - // 该方法不需要new一个ExpressionImpl对象(带来大量的指针校验)就能调用ExpressionImpl的方法 - // ***使用时需注意:1.返回的引用使用时,sym_expr对象必须存在; - // ***使用时需注意:2.ExpressionImpl类只有一个SymEngineExprPtr类型的私有变量 - static const ExpressionImpl &SymExprToExpressionImplRef(const SymEngineExprPtr &sym_expr) { - return *reinterpret_cast(&sym_expr); - } - - /** - * @brief bool表达式进行标准化处理,处理逻辑 - expra为表达式1,exprb为表达式2,OP是比较关系,支持四种表达式Eq、Ne、Lt、Le。Gt和Ge表达式由Lt、Le来进行替换 - 1、原始表达式:expra OP exprb - 2、构造新的参数,右值:exprb - expra , 左值 = 0 - 3、参数处理 - 3.1、最大公约数化简,遍历右值的所有参数,计算最大公约数,如果gcd大于1则所有参数都除最大公约数 - 3.1、区分正数和负数,遍历右值的所有参数,正数的集合相加作为新的右值,负数的集合相加作为新的左值 - 4、通过新的左值和右值构造原始的布尔表达式类型 - */ - ExpressionImplPtr CanonicalizeBoolExpr(); - - vector GetArgs(); - private: - double GetIntegerResult(const SymEngine::Integer &integer_expr) const; - - friend ExpressionImplPtr Add(const ExpressionImplPtr &a, const ExpressionImplPtr &b); - friend ExpressionImplPtr Sub(const ExpressionImplPtr &a, const ExpressionImplPtr &b); - friend ExpressionImplPtr Mul(const ExpressionImplPtr &a, const ExpressionImplPtr &b); - friend ExpressionImplPtr Div(const ExpressionImplPtr &a, const ExpressionImplPtr &b); - friend ExpressionImplPtr Max(const ExpressionImplPtr &a, const ExpressionImplPtr &b); - friend ExpressionImplPtr Min(const ExpressionImplPtr &a, const ExpressionImplPtr &b); - friend ExpressionImplPtr Abs(const ExpressionImplPtr &a); - friend ExpressionImplPtr Pow(const ExpressionImplPtr &a, const ExpressionImplPtr &b); - friend ExpressionImplPtr Mod(const ExpressionImplPtr &a, const ExpressionImplPtr &b); - friend ExpressionImplPtr Log(const ExpressionImplPtr &a); // 默认以E为底 - friend ExpressionImplPtr Log(const ExpressionImplPtr &arg, const ExpressionImplPtr &base); - friend ExpressionImplPtr Coeff(const ExpressionImplPtr &b, const ExpressionImplPtr &x, const ExpressionImplPtr &n); - friend ExpressionImplPtr Ceiling(const ExpressionImplPtr &a); - friend ExpressionImplPtr Floor(const ExpressionImplPtr &a); - friend ExpressionImplPtr Rational(const ExpressionImplPtr &a, const ExpressionImplPtr &b); - friend ExpressionImplPtr Eq(const ExpressionImplPtr &a, const ExpressionImplPtr &b); - friend ExpressionImplPtr Ne(const ExpressionImplPtr &a, const ExpressionImplPtr &b); - friend ExpressionImplPtr Lt(const ExpressionImplPtr &a, const ExpressionImplPtr &b); - friend ExpressionImplPtr Le(const ExpressionImplPtr &a, const ExpressionImplPtr &b); - friend ExpressionImplPtr Not(const ExpressionImplPtr &a); - friend ExpressionImplPtr Neg(const ExpressionImplPtr &a); - friend ExpressionImplPtr LogicalAnd(const vector &s); - friend ExpressionImplPtr LogicalOr(const vector &s); - // friend std::string DefaultPowPrinter(const std::vector &args); - friend class Parser; - - private: - SymEngineExprPtr sym_expr_; // 非空,symengine在内存不够时会抛异常 - mutable std::string name_; -}; - -ExpressionImplPtr Add(const ExpressionImplPtr &a, const ExpressionImplPtr &b); -ExpressionImplPtr Sub(const ExpressionImplPtr &a, const ExpressionImplPtr &b); -ExpressionImplPtr Mul(const ExpressionImplPtr &a, const ExpressionImplPtr &b); -ExpressionImplPtr Div(const ExpressionImplPtr &a, const ExpressionImplPtr &b); -ExpressionImplPtr Max(const ExpressionImplPtr &a, const ExpressionImplPtr &b); -ExpressionImplPtr Min(const ExpressionImplPtr &a, const ExpressionImplPtr &b); -ExpressionImplPtr Abs(const ExpressionImplPtr &a); -ExpressionImplPtr Pow(const ExpressionImplPtr &a, const ExpressionImplPtr &b); -ExpressionImplPtr Mod(const ExpressionImplPtr &a, const ExpressionImplPtr &b); -ExpressionImplPtr Log(const ExpressionImplPtr &a); // 默认以E为底 -ExpressionImplPtr Log(const ExpressionImplPtr &arg, const ExpressionImplPtr &base); -ExpressionImplPtr Coeff(const ExpressionImplPtr &b, const ExpressionImplPtr &x, const ExpressionImplPtr &n); -ExpressionImplPtr Ceiling(const ExpressionImplPtr &a); -ExpressionImplPtr Floor(const ExpressionImplPtr &a); -ExpressionImplPtr Rational(const ExpressionImplPtr &a, const ExpressionImplPtr &b); -std::ostream &operator<<(std::ostream &os, const ExpressionImpl &e); -ExpressionImplPtr Eq(const ExpressionImplPtr &a, const ExpressionImplPtr &b); -ExpressionImplPtr Ne(const ExpressionImplPtr &a, const ExpressionImplPtr &b); -ExpressionImplPtr Lt(const ExpressionImplPtr &a, const ExpressionImplPtr &b); -ExpressionImplPtr Le(const ExpressionImplPtr &a, const ExpressionImplPtr &b); -ExpressionImplPtr Not(const ExpressionImplPtr &a); -ExpressionImplPtr Neg(const ExpressionImplPtr &a); -ExpressionImplPtr LogicalAnd(const vector &s); -ExpressionImplPtr LogicalOr(const vector &s); -} // namespace ge - -#endif // GRAPH_EXPRESSION_EXPRESSION_IMPL_H_ \ No newline at end of file diff --git a/graph/expression/guard_dfx_context.cc b/graph/expression/guard_dfx_context.cc deleted file mode 100644 index 4fc6a9ccaa80b15f297a389bb69a81e7f7c09abe..0000000000000000000000000000000000000000 --- a/graph/expression/guard_dfx_context.cc +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/symbolizer/guard_dfx_context.h" -#include "attribute_group/attr_group_shape_env.h" - -namespace ge { -GuardDfxContext::GuardDfxContext(const std::string &guard_dfx_info) { - if (GetCurShapeEnvContext() != nullptr) { - GetCurShapeEnvContext()->SetGuardDfxContextInfo(guard_dfx_info); - } -} -GuardDfxContext::~GuardDfxContext() { - if (GetCurShapeEnvContext() != nullptr) { - GetCurShapeEnvContext()->ClearGuardDfxContextInfo(); - } -} -} diff --git a/graph/expression/scanner.cc b/graph/expression/scanner.cc deleted file mode 100644 index 75d7e619b49e6bc7c546ce2671cab56577b56b80..0000000000000000000000000000000000000000 --- a/graph/expression/scanner.cc +++ /dev/null @@ -1,99 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ -#include "scanner.h" -#include -#include -#include "common/checker.h" - -namespace ge { -namespace { -std::unordered_map kTokenMap = { - {"Max", TokenType::kMax}, {"Min", TokenType::kMin}, {"Pow", TokenType::kPow}, {"Log", TokenType::kLog}, - {"Ceiling", TokenType::kCeil}, {"Rational", TokenType::kRational}, {"+", TokenType::kPlus}, - {"-", TokenType::kMinus}, {"*", TokenType::kMultiply}, {"/", TokenType::kDivide}, - {"(", TokenType::kLparen}, {")", TokenType::kRparen}, {",", TokenType::kComma}, - {"Abs", TokenType::kAbs}, {"ExpectEq", TokenType::kEq}, {"ExpectNe", TokenType::kNe}, - {"ExpectLe", TokenType::kLe}, {"ExpectLt", TokenType::kLt}, {"True", TokenType::kTrue}, - {"False", TokenType::kFalse}, {"Floor", TokenType::kFloor}, {"Mod", TokenType::kMod}, - {"LogicAnd", TokenType::kLogicalAnd}, {"LogicOr", TokenType::kLogicalOr}}; -} -Scanner::Scanner(const std::string &input) : input_(input), pos_(0) { - Advance(); -} - -graphStatus Scanner::GetNextToken(Token &token) { - SkipWhitespace(); - - if (currentChar_ == '\0') { - token = {TokenType::kEnd, ""}; - return GRAPH_SUCCESS; - } - - // 识别数字 - if (std::isdigit(currentChar_)) { - token = ReadNumber(); - return GRAPH_SUCCESS; - } - // 识别函数 - if (std::isalpha(currentChar_)) { - std::string identifier = ReadIdentifier(); - if (kTokenMap.find(identifier) != kTokenMap.end()) { - token = {kTokenMap[identifier], identifier}; - } else { - token = {TokenType::kIdentifier, identifier}; - } - return GRAPH_SUCCESS; - } - // 识别操作符 - const auto token_key = std::string(1, currentChar_); - if (kTokenMap.find(token_key) != kTokenMap.end()) { - Advance(); - token = {kTokenMap[token_key], token_key}; - return GRAPH_SUCCESS; - } - GELOGE(ge::PARAM_INVALID, "Unsupported operator: %s", token_key.c_str()); - token = {TokenType::kEnd, ""}; - return GRAPH_FAILED; -} - -void Scanner::Advance(size_t steps) { - while (steps-- > 0) { - if (pos_ < input_.length()) { - currentChar_ = input_[pos_++]; - } else { - currentChar_ = '\0'; - break; - } - } -} - -void Scanner::SkipWhitespace() { - while (std::isspace(currentChar_)) { - Advance(); - } -} - -std::string Scanner::ReadIdentifier() { - std::string result; - while (std::isalnum(currentChar_) || currentChar_ == ':' || currentChar_ == '_') { - result += currentChar_; - Advance(); - } - return result; -} - -Token Scanner::ReadNumber() { - std::string result; - while (std::isdigit(currentChar_) || currentChar_ == '.') { - result += currentChar_; - Advance(); - } - return {TokenType::kNumber, result}; -} -} // namespace ge \ No newline at end of file diff --git a/graph/expression/scanner.h b/graph/expression/scanner.h deleted file mode 100644 index ee1e1f9899ff24c92a7be3891db6df6e4f8ecf0f..0000000000000000000000000000000000000000 --- a/graph/expression/scanner.h +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ - -#ifndef GRAPH_EXPRESSION_SCANNER_H_ -#define GRAPH_EXPRESSION_SCANNER_H_ -#include -#include -#include -#include -#include -#include "graph/types.h" -#include "graph/ge_error_codes.h" - -namespace ge { -enum class TokenType { - kIdentifier, // variable names like a, b, c, s0, s1,... - kNumber, // numeric constants (if needed) - kPlus, // '+' - kMinus, // '-' - kMultiply, // '*' - kDivide, // '/' - kComma, // ',' - kLparen, // '(' - kRparen, // ')' - kMax, // 'std::max' - kMin, // 'std::min' - kPow, // pow - kMod, // mod - kLog, // log - kCeil, // ceil - kFloor, // floor - kAbs, // abs - kRational, // rational - kEq, // EXPECT_EQ - kNe, // EXPECT_NE - kLt, // EXPECT_LT - kLe, // EXPECT_LE - kTrue, // True - kFalse, // False - kLogicalAnd, // LogicalAnd - kLogicalOr, // LogicalOr - kEnd // End of input -}; -struct Token { - TokenType type; - std::string value; // For identifiers and numbers -}; -class Scanner { -public: - explicit Scanner(const std::string &input); - graphStatus GetNextToken(Token &token); - -private: - void Advance(size_t steps = 1); - void SkipWhitespace(); - std::string ReadIdentifier(); - Token ReadNumber(); - - const std::string &input_; - size_t pos_; - char_t currentChar_; -}; -} // namespace ge - -#endif // GRAPH_EXPRESSION_SCANNER_H_ \ No newline at end of file diff --git a/graph/expression/symbol_checker.cc b/graph/expression/symbol_checker.cc deleted file mode 100644 index 2e1a6880e59ac3aef2361d0186d969158a95819a..0000000000000000000000000000000000000000 --- a/graph/expression/symbol_checker.cc +++ /dev/null @@ -1,95 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "attribute_group/attr_group_shape_env.h" -#include "graph/symbolizer/symbolic.h" -#include "common/checker.h" - -namespace ge { -namespace sym { -bool ExpectSymbolEq(const Expression &e0, const Expression &e1, - const char_t *file, const int64_t line) { - GE_ASSERT_NOTNULL(file); - bool res = ExpectSymbolBool(sym::Eq(e0, e1), file, line); - if (res) { - GE_ASSERT_SUCCESS(GetCurShapeEnvContext()->AppendReplacement(e0, e1), - "[%s:%lld] Append symbol equivalence %s to %s failed", - file, line, e0.Serialize().get(), e1.Serialize().get()); - } - return res; -} - -bool ExpectSymbolNe(const Expression &e0, const Expression &e1, - const char_t *file, const int64_t line) { - GE_ASSERT_NOTNULL(file); - bool res = ExpectSymbolBool(sym::Ne(e0, e1), file, line); - if (!res) { - GE_ASSERT_SUCCESS(GetCurShapeEnvContext()->AppendReplacement(e0, e1), - "[%s:%lld] Append symbol equivalence %s to %s failed", - file, line, e0.Serialize().get(), e1.Serialize().get()); - } - return res; -} - -bool ExpectSymbolBool(const Expression &expr, const char_t *file, const int64_t line) { - GE_ASSERT_NOTNULL(file); - GE_ASSERT_TRUE(expr.IsBooleanExpr(), "Only boolean expr can be use to check symbol, expr: %s", - expr.Serialize().get()); - if (expr.IsConstExpr()) { - bool const_value = false; - GE_ASSERT_TRUE(expr.GetConstValue(const_value)); - return const_value; - } - if (GetCurShapeEnvContext() == nullptr) { - GELOGW("Shape env is nullptr, cannot check symbol, expr: %s", expr.Serialize().get()); - return false; - } - bool hint_value = false; - GE_ASSERT_TRUE(expr.GetHint(hint_value)); - if (hint_value) { - GE_ASSERT_SUCCESS(GetCurShapeEnvContext()->AppendSymbolCheckInfo(expr.Simplify(), file, line)); - } else { - GE_ASSERT_SUCCESS(GetCurShapeEnvContext()->AppendSymbolCheckInfo(sym::Not(expr.Simplify()), file, line)); - } - return hint_value; -} - -bool AssertSymbolEq(const Expression &e0, const Expression &e1, - const char_t *file, const int64_t line) { - GE_ASSERT_NOTNULL(file); - GE_ASSERT_TRUE(AssertSymbolBool(ge::sym::Eq(e0, e1), file, line)); - GE_ASSERT_SUCCESS(GetCurShapeEnvContext()->AppendReplacement(e0, e1), - "[%s:%lld] Append symbol equivalence %s to %s failed", - file, line, e0.Serialize().get(), e1.Serialize().get()); - return true; -} - -bool AssertSymbolBool(const Expression &expr, const char_t *file, const int64_t line) { - GE_ASSERT_NOTNULL(file); - GE_ASSERT_TRUE(expr.IsBooleanExpr(), "[%s:%lld] Only boolean expr can be used to assert, expr: %s", - file, line, expr.Serialize().get()); - if (expr.IsConstExpr()) { - bool const_value = false; - GE_ASSERT_TRUE(expr.GetConstValue(const_value)); - GE_ASSERT_TRUE(const_value, "[%s:%lld] Assert %s failed", - file, line, expr.Serialize().get()); - return const_value; - } - if (GetCurShapeEnvContext() == nullptr) { - GELOGW("Shape env is nullptr, cannot check symbol, expr: %s", expr.Serialize().get()); - return false; - } - bool hint_value = false; - GE_ASSERT_TRUE(expr.GetHint(hint_value)); - GE_ASSERT_TRUE(hint_value, "[%s:%lld] Assert %s failed", file, line, expr.Serialize().get()); - GE_ASSERT_SUCCESS(GetCurShapeEnvContext()->AppendSymbolAssertInfo(expr.Simplify(), file, line)); - return true; -} -} -} \ No newline at end of file diff --git a/graph/expression/symbol_operator.cc b/graph/expression/symbol_operator.cc deleted file mode 100644 index b0128539f071cccc5045248201c4d3e4dce6e2ee..0000000000000000000000000000000000000000 --- a/graph/expression/symbol_operator.cc +++ /dev/null @@ -1,153 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include "graph/symbolizer/symbolic.h" -#include "attribute_group/attr_group_shape_env.h" -#include "expression_impl.h" -#include "graph/debug/ge_util.h" -#include "graph/utils/math_util.h" -#include "common/checker.h" -#include "const_values.h" - -namespace ge { -namespace sym { -Expression Add(const Expression &a, const Expression &b) { - return Expression(Add(a.impl_, b.impl_)); -} - -Expression Sub(const Expression &a, const Expression &b) { - return Expression(Sub(a.impl_, b.impl_)); -} - -Expression Mul(const Expression &a, const Expression &b) { - return Expression(Mul(a.impl_, b.impl_)); -} - -Expression Div(const Expression &a, const Expression &b) { - return Expression(Div(a.impl_, b.impl_)); -} - -Expression Max(const Expression &a, const Expression &b) { - return Expression(Max(a.impl_, b.impl_)); -} - -Expression Min(const Expression &a, const Expression &b) { - return Expression(Min(a.impl_, b.impl_)); -} - -Expression Abs(const Expression &a) { - return Expression(Abs(a.impl_)); -} - -Expression Pow(const Expression &base, const Expression &exp) { - return Expression(Pow(base.impl_, exp.impl_)); -} - -Expression Mod(const Expression &base, const Expression &exp) { - return Expression(Mod(base.impl_, exp.impl_)); -} - -Expression Log(const Expression &a) { - return Expression(Log(a.impl_)); -} - -Expression Log(const Expression &arg, const Expression &base) { - return Expression(Log(arg.impl_, base.impl_)); -} - -Expression Ceiling(const Expression &a) { - return Expression(Ceiling(a.impl_)); -} - -Expression Floor(const Expression &arg) { - return Expression(Floor(arg.impl_)); -} - -Expression Coeff(const Expression &b, const Expression &x, const Expression &n) { - return Expression(Coeff(b.impl_, x.impl_, n.impl_)); -} - -Expression Rational(int32_t num, int32_t den) { - auto left = ExpressionImpl::CreateExpressionImpl(num); - auto right = ExpressionImpl::CreateExpressionImpl(den); - return Expression(Rational(left, right)); -} - -Expression Align(const Expression &arg, uint32_t alignment) { - if (alignment == 0U) { - GELOGE(FAILED, "Alignment should more than 0"); - return Expression(nullptr); - } - auto align = Symbol(alignment); - return Mul(Ceiling(Div(arg, align)), align); -} - -Expression AlignWithPositiveInteger(const Expression &arg, uint32_t alignment) { - if (alignment == 0U) { - GELOGE(FAILED, "Alignment should more than 0"); - return Expression(nullptr); - } - auto align = Symbol(alignment); - return Mul(Floor(Div(Add(arg, Sub(align, kSymbolOne)), align)), align); -} - -Expression Eq(const Expression &a, const Expression &b) { - return Expression(Eq(a.impl_, b.impl_)); -} - -Expression Ne(const Expression &a, const Expression &b) { - return Expression(Ne(a.impl_, b.impl_)); -} - -Expression Ge(const Expression &a, const Expression &b) { - return Expression(Le(b.impl_, a.impl_)); -} - -Expression Gt(const Expression &a, const Expression &b) { - return Expression(Lt(b.impl_, a.impl_)); -} - -Expression Le(const Expression &a, const Expression &b) { - return Expression(Le(a.impl_, b.impl_)); -} - -Expression Lt(const Expression &a, const Expression &b) { - return Expression(Lt(a.impl_, b.impl_)); -} - -Expression Not(const Expression &a) { - return Expression(Not(a.impl_)); -} - -Expression Neg(const Expression &a) { - return Expression(Neg(a.impl_)); -} - -Expression LogicalAnd(const std::vector &a) { - std::vector impl_vec; - for (auto s : a) { - GE_ASSERT_NOTNULL(s.impl_); - impl_vec.emplace_back(std::move(s.impl_)); - } - return Expression(LogicalAnd(impl_vec)); -} - -Expression LogicalOr(const std::vector &a) { - std::vector impl_vec; - for (auto s : a) { - GE_ASSERT_NOTNULL(s.impl_); - impl_vec.emplace_back(std::move(s.impl_)); - } - return Expression(LogicalOr(impl_vec)); -} -} // namespace sym -} // namespace ge \ No newline at end of file diff --git a/graph/expression/symbolic_utils.cc b/graph/expression/symbolic_utils.cc deleted file mode 100644 index 3fcabe91e16826bbd76f2373558cee6887410440..0000000000000000000000000000000000000000 --- a/graph/expression/symbolic_utils.cc +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/symbolizer/symbolic.h" -#include "graph/symbolizer/symbolic_utils.h" -#include "common/checker.h" -#include "attribute_group/attr_group_shape_env.h" - -namespace ge { -std::string SymbolicUtils::ToString(const Expression &e) { - auto ret = e.Str(StrType::kStrCpp); - return (ret != nullptr) ? ret.get() : "invalid expression"; -} - -TriBool SymbolicUtils::StaticCheckEq(const Expression &e1, const Expression &e2) { - return StaticCheckBool(sym::Eq(e1.Simplify(), e2.Simplify())); -} - -TriBool SymbolicUtils::StaticCheckNe(const Expression &e1, const Expression &e2) { - return StaticCheckBool(sym::Ne(e1.Simplify(), e2.Simplify())); -} - -TriBool SymbolicUtils::StaticCheckLt(const Expression &e1, const Expression &e2) { - return StaticCheckBool(sym::Lt(e1.Simplify(), e2.Simplify())); -} - -TriBool SymbolicUtils::StaticCheckLe(const Expression &e1, const Expression &e2) { - return StaticCheckBool(sym::Le(e1.Simplify(), e2.Simplify())); -} - -TriBool SymbolicUtils::StaticCheckGt(const Expression &e1, const Expression &e2) { - return StaticCheckBool(sym::Gt(e1.Simplify(), e2.Simplify())); -} - -TriBool SymbolicUtils::StaticCheckGe(const Expression &e1, const Expression &e2) { - return StaticCheckBool(sym::Ge(e1.Simplify(), e2.Simplify())); -} - -TriBool SymbolicUtils::StaticCheckBool(const Expression &expr) { - GE_ASSERT_TRUE(expr.IsBooleanExpr(), "Only boolean expr can do static check, expr: %s", - expr.Serialize().get()); - bool value = false; - if (expr.IsConstExpr()) { - GE_ASSERT_TRUE(expr.GetConstValue(value)); - return value ? TriBool::kTrue : TriBool::kFalse; - } - if (GetCurShapeEnvContext() == nullptr) { - GELOGW("Shape env is nullptr, cannot do static check, expr: %s", expr.Serialize().get()); - return TriBool::kUnknown; - } - if (GetCurShapeEnvContext()->HasSymbolInfo(expr) == TriBool::kTrue) { - GELOGI("Find check info of expr: %s, no need simplify guard", SymbolicUtils::ToString(expr).c_str()); - return TriBool::kTrue; - } - const auto simplify_expr = expr.Simplify(); - value = false; - // 化简后判断是否是常量 - if (simplify_expr.IsConstExpr()) { - GE_ASSERT_TRUE(simplify_expr.GetConstValue(value)); - return value ? TriBool::kTrue : TriBool::kFalse; - } - return GetCurShapeEnvContext()->HasSymbolInfo(simplify_expr); -} -} - diff --git a/graph/fast_graph/execute_graph.cc b/graph/fast_graph/execute_graph.cc deleted file mode 100644 index 244f12dd6733be83a30770c18c74a87b2c45082a..0000000000000000000000000000000000000000 --- a/graph/fast_graph/execute_graph.cc +++ /dev/null @@ -1,972 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/fast_graph/execute_graph.h" -#include "graph/debug/ge_log.h" -#include "common/ge_common/ge_types.h" -#include "fast_graph/fast_graph_impl.h" -#include "inc/graph/utils/fast_node_utils.h" - -namespace ge { -namespace { -enum class FastTopoSortingMode { kBFS = 0, kDFS, kRDFS }; -const std::string kMemoryPriority = "MemoryPriority"; -constexpr int32_t kTopoSortingBfs = 0; -constexpr int32_t kTopoSortingDfs = 1; -constexpr int32_t kTopoSortingReverseDfs = 2; - -FastTopoSortingMode GetTopoSortingStrategy() { - std::string topo_sorting_mode_str; - if ((ge::GetContext().GetOption(ge::OPTION_TOPOSORTING_MODE, topo_sorting_mode_str) == GRAPH_SUCCESS) && - (!topo_sorting_mode_str.empty())) { - const int32_t base = 10; - const auto topo_sorting_mode = static_cast(std::strtol(topo_sorting_mode_str.c_str(), nullptr, base)); - if (topo_sorting_mode == kTopoSortingBfs) { - return FastTopoSortingMode::kBFS; - } else if (topo_sorting_mode == kTopoSortingDfs) { - return FastTopoSortingMode::kDFS; - } else if (topo_sorting_mode == kTopoSortingReverseDfs) { - return FastTopoSortingMode::kRDFS; - } else { - GELOGW("OPTION_TOPOSORTING_MODE = %s is invalid", topo_sorting_mode_str.c_str()); - } - } - if (ge::GetContext().GetTrainGraphFlag()) { - GELOGI("train flag is 1, use BFS."); - return FastTopoSortingMode::kBFS; - } - - GELOGI("train flag is 0, use DFS."); - return FastTopoSortingMode::kDFS; -} - -bool IsMemoryPriority() { - std::string memory_optimization_policy; - (void)ge::GetContext().GetOption(MEMORY_OPTIMIZATION_POLICY, memory_optimization_policy); - return (memory_optimization_policy == ge::kMemoryPriority); -} - -void GetOutNodesFromEdge(std::map &map_in_edge_num, FastNode *node, - std::vector &out_nodes) { - const auto iter = map_in_edge_num.find(node); - if (iter != map_in_edge_num.end()) { - --iter->second; - if (iter->second == 0U) { - out_nodes.push_back(node); - } - } -} - -bool InputIsLongLifeTimeNode(const FastNode *node, const ExecuteGraph *execute_graph) { - bool match = false; - auto num = node->GetDataInNum(); - for (size_t i = 0LL; i < num; ++i) { - // the input parameter must be the id of data io - const auto &edge = node->GetInDataEdgeByIndex(i); - if (edge == nullptr) { - continue; - } - - auto &peer_node = edge->src; - if ((peer_node == nullptr) || (peer_node->GetExtendInfo() == nullptr)) { - continue; - } - - const auto type = peer_node->GetType(); - static std::unordered_set kDataSet = {DATA, REFDATA, AIPPDATA, ANN_DATA}; - static const std::unordered_set kConstPlaceHolderOpSet = {CONSTPLACEHOLDER}; - auto graph = peer_node->GetExtendInfo()->GetOwnerGraphBarePtr(); - const bool is_io_data = - (execute_graph == graph) && ((kDataSet.count(type) > 0U) || (kConstPlaceHolderOpSet.count(type) > 0U)); - if ((!FastNodeUtils::GetConstOpType(peer_node)) && (type != VARIABLE) && (type != VARIABLEV2) && (!is_io_data)) { - return false; - } else { - match = true; - } - GELOGD("Node:%s peer:%s type :%s", node->GetName().c_str(), peer_node->GetName().c_str(), - peer_node->GetType().c_str()); - } - - return match; -} - -/// variable const -/// \ / -/// first node -/// | -/// middle node -/// | -/// last node -/// / | -/// node1 node2 -graphStatus GetOutNodeIndex(std::vector &nodes, size_t &index, size_t &out_count, - const ExecuteGraph *execute_graph) { - if (nodes.empty()) { - return GRAPH_FAILED; - } - - // first node's inputs muse be long life time - if ((nodes.size() == 1UL) && (!InputIsLongLifeTimeNode(nodes.front(), execute_graph))) { - return GRAPH_FAILED; - } - - const auto &node = nodes.back(); - auto op_desc = node->GetOpDescBarePtr(); - GE_CHECK_NOTNULL(op_desc); - // middle node must be single input - if ((nodes.size() != 1UL) && (node->GetDataInNum() != 1UL)) { - return GRAPH_FAILED; - } - - int64_t min_index = 0LL; - FastNode *delay_node = nullptr; - for (const auto &out_node : node->GetAllOutNodes()) { - out_count++; - GE_CHECK_NOTNULL(out_node); - auto out_node_desc = out_node->GetOpDescBarePtr(); - GE_CHECK_NOTNULL(out_node_desc); - GELOGD("Node:%s id:%ld peer node:%s id:%ld", node->GetName().c_str(), op_desc->GetId(), - out_node_desc->GetName().c_str(), out_node_desc->GetId()); - if ((min_index == 0LL) || (out_node_desc->GetId() < min_index)) { - min_index = out_node_desc->GetId(); - delay_node = out_node; - } - } - - if (delay_node != nullptr) { - index = static_cast(min_index); - if (index > (static_cast(op_desc->GetId()) + 1UL)) { - GELOGD("Node:%s id:%ld delay to:%s id:%zu", node->GetName().c_str(), op_desc->GetId(), - delay_node->GetName().c_str(), index); - } - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -void DelayTopoSort(std::vector &nodes, const ExecuteGraph *execute_graph) { - // pair.first: this node can be delay or not - // pair.second: delayed nodes to this node - std::vector>> delay_nodes; - delay_nodes.resize(nodes.size()); - - // set init index - for (size_t i = 0UL; i < delay_nodes.size(); ++i) { - nodes[i]->GetOpDescBarePtr()->SetId(static_cast(i)); - delay_nodes[i].first = true; - delay_nodes[i].second.emplace_back(nodes[i]); - } - - // move delayed node to fit node - size_t delay_node_count = 0UL; - for (size_t i = 0UL; i < delay_nodes.size(); ++i) { - size_t delay_to_index = 0UL; - size_t out_count = 0UL; - if (delay_nodes[i].first && - (GetOutNodeIndex(delay_nodes[i].second, delay_to_index, out_count, execute_graph) == GRAPH_SUCCESS) && - (delay_to_index < delay_nodes.size()) && (delay_to_index > (i + 1UL))) { - delay_nodes[delay_to_index].second.insert(delay_nodes[delay_to_index].second.begin(), - delay_nodes[i].second.begin(), delay_nodes[i].second.end()); - if (out_count > 1UL) { - // last node can not be delay - delay_nodes[delay_to_index].first = false; - } - delay_nodes[i].second.clear(); - delay_node_count++; - } - } - if (delay_node_count > 0UL) { - nodes.clear(); - for (size_t i = 0UL; i < delay_nodes.size(); ++i) { - if (!delay_nodes[i].second.empty()) { - nodes.insert(nodes.end(), delay_nodes[i].second.begin(), delay_nodes[i].second.end()); - } - } - GELOGI("Delay %zu nodes.", delay_node_count); - } -} - -void InitNodeStatus(const ExecuteGraph *compute_graph, std::vector &reverse_dfs_nodes_info) { - reverse_dfs_nodes_info.clear(); - reverse_dfs_nodes_info.resize(compute_graph->GetDirectNodesSize()); - int64_t index = 0; - for (const auto &node : compute_graph->GetDirectNode()) { - reverse_dfs_nodes_info[index].size = 0U; - reverse_dfs_nodes_info[index].status = FastWalkStatus::kNotWalked; - node->GetOpDescBarePtr()->SetId(index); - index++; - } -} -} // namespace - -ExecuteGraph::ExecuteGraph(const std::string &name) { - graph_shared_ = std::make_shared>(name); - graph_shared_->SetOwnerGraph(this); -} - -ExecuteGraph &ExecuteGraph::operator=(ge::ExecuteGraph &exec_graph) { - if (&exec_graph == this) { - return *this; - } - - graph_shared_ = exec_graph.graph_shared_; - names_to_subgraph_ = exec_graph.names_to_subgraph_; - inputs_order_ = exec_graph.inputs_order_; - AttrHolder::SwapBase(exec_graph); - return *this; -} - -ExecuteGraph &ExecuteGraph::CompleteCopy(ge::ExecuteGraph &exec_graph) { - if (&exec_graph == this) { - return *this; - } - - graph_shared_->DeepCopy(*(exec_graph.graph_shared_)); - - const std::map &original_attrs = AttrUtils::GetAllAttrs(exec_graph); - for (auto const &attr_iter : original_attrs) { - if (this->TrySetAttr(attr_iter.first, attr_iter.second) != GRAPH_SUCCESS) { - GELOGW("Set inherit original attr[%s] failed, Please Check.", attr_iter.first.c_str()); - } - } - - inputs_order_.clear(); - for (auto &item : exec_graph.inputs_order_) { - inputs_order_.push_back(item); - } - return *this; -} - -FastNode *ExecuteGraph::AddNode(const OpDescPtr &op) { - return graph_shared_->AddNode(op); -} - -FastNode *ExecuteGraph::AddNode(const OpDescPtr &op, int64_t id) { - return graph_shared_->AddNode(op, id); -} - -void ExecuteGraph::RemoveNodeFromNodesFree(const FastNode *const fast_node) const { - auto quick_node = FastGraphUtils::GetListElementAddr(fast_node); - auto owner = quick_node->owner; - auto mode = quick_node->mode; - if ((owner != nullptr) && (mode == ListMode::kFreeMode)) { - owner->erase(quick_node); - } -} - -FastNode *ExecuteGraph::AddNode(FastNode *fast_node) { - if (fast_node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The node is nullptr."); - GE_LOGE("[Check][Param] The node is nullptr."); - return nullptr; - } - - RemoveNodeFromNodesFree(fast_node); - return graph_shared_->AddNode(fast_node); -} - -FastNode *ExecuteGraph::AddNodeFront(const OpDescPtr &op) { - return graph_shared_->AddNodeFront(op); -} - -FastNode *ExecuteGraph::AddNodeFront(FastNode *const fast_node) { - if (fast_node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The node is nullptr."); - GE_LOGE("[Check][Param] The node is nullptr."); - return nullptr; - } - - RemoveNodeFromNodesFree(fast_node); - return graph_shared_->AddNodeFront(fast_node); -} - -graphStatus ExecuteGraph::RemoveJustNode(const FastNode *const fast_node) { - if (fast_node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The node is nullptr."); - GE_LOGE("[Check][Param] The node is nullptr."); - return GRAPH_FAILED; - } - return graph_shared_->RemoveJustNode(FastGraphUtils::GetListElementAddr(fast_node)); -} - -FastEdge *ExecuteGraph::AddEdge(FastNode *const src, int32_t src_index, FastNode *const dst, int32_t dst_index) { - if ((src == nullptr) || (dst == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "The node is nullptr."); - GE_LOGE("[Check][Param] The node is nullptr."); - return nullptr; - } - - if (!CheckNodeIsInGraph(src) || !CheckNodeIsInGraph(dst)) { - GELOGW("The src %s or dst %s not belong to graph.", src->GetNamePtr(), dst->GetNamePtr()); - } - - return graph_shared_->AddEdge(src, src_index, dst, dst_index); -} - -graphStatus ExecuteGraph::RemoveEdge(const FastEdge *const edge) { - if (edge == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The edge is nullptr."); - GE_LOGE("[Check][Param] The edge is nullptr."); - return GRAPH_FAILED; - } - return graph_shared_->RemoveEdge(FastGraphUtils::GetListElementAddr(edge)); -} - -const FastNode *ExecuteGraph::GetParentNodeBarePtr() const { - return graph_shared_->GetParentNode(); -} - -FastNode *ExecuteGraph::GetParentNodeBarePtr() { - return graph_shared_->GetParentNode(); -} - -void ExecuteGraph::SetParentNode(FastNode *const node) { - graph_shared_->SetParentNode(node); -} - -ExecuteGraph *ExecuteGraph::AddSubGraph(const std::shared_ptr &sub_graph) { - if (sub_graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Try to add a null subgraph"); - GE_LOGE("[Check][Param] Try to add a null subgraph"); - return nullptr; - } - - auto ret = graph_shared_->AddSubGraph(sub_graph.get()); - if (ret == nullptr) { - return nullptr; - } - - names_to_subgraph_[sub_graph->GetName()] = {sub_graph, ret}; - return ret->data; -} - -graphStatus ExecuteGraph::RemoveSubGraph(const ExecuteGraph *const sub_graph) { - if (sub_graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Try to add a null subgraph"); - GE_LOGE("[Check][Param] Try to add a null subgraph"); - return GRAPH_PARAM_INVALID; - } - - return RemoveSubGraph(sub_graph->GetName()); -} - -ExecuteGraph *ExecuteGraph::AddSubGraph(const std::shared_ptr &sub_graph_ptr, const std::string &name) { - if (sub_graph_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Try to add a null subgraph, name %s", name.c_str()); - GE_LOGE("[Check][Param] Try to add a null subgraph, name %s", name.c_str()); - return nullptr; - } - - auto sub_graph = sub_graph_ptr.get(); - const auto parent_graph = sub_graph->GetParentGraphBarePtr(); - if (parent_graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Try to add subgraph without parent graph, name %s", name.c_str()); - GE_LOGE("[Get][Graph] Try to add subgraph without parent graph, name %s", name.c_str()); - return nullptr; - } - - const auto parent_node = sub_graph->GetParentNodeBarePtr(); - if ((parent_node == nullptr) || (parent_node->GetExtendInfo() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "Try to add a subgraph without parent node, name %s", name.c_str()); - GE_LOGE("[Get][Node] Try to add a subgraph without parent node, name %s", name.c_str()); - return nullptr; - } - - if (parent_node->GetExtendInfo()->GetOwnerGraphBarePtr() != parent_graph) { - REPORT_INNER_ERR_MSG("E18888", - "Try to add a subgraph which parent node's graph is not equal to " - "the subgraph's parent graph, subgraph name %s, parent node name %s", - sub_graph->GetName().c_str(), parent_graph->GetName().c_str()); - GE_LOGE( - "[Check][Param] Try to add a subgraph which parent node's graph is not equal to " - "the subgraph's parent graph, subgraph name %s, parent node name %s", - sub_graph->GetName().c_str(), parent_graph->GetName().c_str()); - return nullptr; - } - - if (name != sub_graph->GetName()) { - GELOGW("[Add][Subgraph] The subgraph name %s is different with input %s", sub_graph->GetName().c_str(), - name.c_str()); - } - - if (names_to_subgraph_.find(sub_graph->GetName()) != names_to_subgraph_.end()) { - REPORT_INNER_ERR_MSG("E18888", "The subgraph %s existed", GetName().c_str()); - GE_LOGE("[Check][Param] The subgraph %s existed", GetName().c_str()); - return nullptr; - } - - auto ret = graph_shared_->AddSubGraph(sub_graph); - if (ret == nullptr) { - return nullptr; - } - names_to_subgraph_[sub_graph->GetName()] = {sub_graph_ptr, ret}; - return ret->data; -} - -graphStatus ExecuteGraph::RemoveSubGraph(const std::string &name) { - auto iter = names_to_subgraph_.find(name); - if (iter != names_to_subgraph_.end()) { - auto quick_graph = iter->second.quick_graph; - graph_shared_->RemoveSubGraph(quick_graph); - names_to_subgraph_.erase(iter); - } - - return GRAPH_SUCCESS; -} - -ExecuteGraph *ExecuteGraph::GetSubGraph(const std::string &name) const { - const ExecuteGraph *exec_graph = graph_shared_->GetParentGraph(); - if (exec_graph == nullptr) { - const auto iter = names_to_subgraph_.find(name); - if (iter == names_to_subgraph_.end()) { - return nullptr; - } - // iter->second.quick_graph is not nullptr - auto quick_graph = iter->second.quick_graph; - return quick_graph->data; - } else { - return exec_graph->GetSubGraph(name); - } -} - -void ExecuteGraph::ClearAllSubGraph() { - names_to_subgraph_.clear(); - return graph_shared_->ClearAllSubGraph(); -} - -std::vector ExecuteGraph::GetDirectNode() const { - return graph_shared_->GetDirectNode(); -} - -size_t ExecuteGraph::GetDirectNodesSize() const { - return graph_shared_->GetDirectNodesSize(); -} - -std::vector ExecuteGraph::GetAllEdges() const { - return graph_shared_->GetAllEdges(); -} - -std::vector ExecuteGraph::GetAllSubgraphs() const { - return graph_shared_->GetAllSubgraphs(); -} - -FastNode *ExecuteGraph::AddInputNode(FastNode *fast_node) { - if (fast_node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The node is nullptr."); - GE_LOGE("[Check][Param] The node is nullptr."); - return nullptr; - } - - RemoveNodeFromNodesFree(fast_node); - return graph_shared_->AddInputNode(fast_node); -} - -graphStatus ExecuteGraph::RemoveInputNode(FastNode *const fast_node) { - if (fast_node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The node is nullptr."); - GE_LOGE("[Check][Param] The node is nullptr."); - return GRAPH_FAILED; - } - - return graph_shared_->RemoveInputNode(fast_node); -} - -FastNode *ExecuteGraph::AddOutputNodeByIndex(FastNode *const fast_node, int32_t index) { - if (fast_node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The node is nullptr."); - GE_LOGE("[Check][Param] The node is nullptr."); - return nullptr; - } - - RemoveNodeFromNodesFree(fast_node); - return graph_shared_->AddOutputNodeByIndex(fast_node, index); -} - -graphStatus ExecuteGraph::RemoveOutputNode(const FastNode *const fast_node) { - if (fast_node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The node is nullptr."); - GE_LOGE("[Check][Param] The node is nullptr."); - return GRAPH_FAILED; - } - - return graph_shared_->RemoveOutputNode(fast_node); -} - -const FastNode *ExecuteGraph::FindNode(size_t token) const { - auto quick_node = graph_shared_->FindNode(token); - return ((quick_node == nullptr) ? nullptr : &(quick_node->data)); -} - -graphStatus ExecuteGraph::SortNodes(std::vector &stack, - std::map &map_in_edge_num) const { - // Record the number of non data nodes but no input nodes - std::vector data_nodes_vec; - std::vector no_data_nodes_vec; - for (const auto &node : graph_shared_->GetDirectNodeToModify()) { - // The node is not nullptr. - auto fast_node = &FastGraphUtils::GetNode(node); - GE_IF_BOOL_EXEC(fast_node->GetOpDescBarePtr() == nullptr, continue); - map_in_edge_num[fast_node] = static_cast(fast_node->GetInEdgeSize()); - if (map_in_edge_num[fast_node] == 0U) { - if ((strcmp(fast_node->GetOpDescBarePtr()->GetTypePtr(), DATA) != 0)) { - no_data_nodes_vec.emplace_back(fast_node); - continue; - } - - // Need to insert the data nodes in reverse order - data_nodes_vec.emplace_back(fast_node); - } - } - (void)stack.insert(stack.end(), no_data_nodes_vec.rbegin(), no_data_nodes_vec.rend()); - (void)stack.insert(stack.end(), data_nodes_vec.rbegin(), data_nodes_vec.rend()); - - /// Make sure the inputs order matches with user-designated - /// 1. Get the index of two input nodes in the user-inputs-order(inputs_order_) - /// 2. Compare two indices, if not match, swap the positions of two inputs - /// *: Remind: stack is reverse-order - for (size_t i = 0UL; i < stack.size(); ++i) { - // If not found in 'inputs_order_', skip it - const auto it_i = std::find(inputs_order_.begin(), inputs_order_.end(), stack[i]->GetName()); - GE_IF_BOOL_EXEC(it_i == inputs_order_.end(), continue); - const auto inx_i = it_i - inputs_order_.begin(); - for (size_t j = i + 1UL; j < stack.size(); ++j) { - // If not found in 'inputs_order_', skip it - const auto it_j = std::find(inputs_order_.begin(), inputs_order_.end(), stack[j]->GetName()); - GE_IF_BOOL_EXEC(it_j == inputs_order_.end(), continue); - - // Compare index, swap them if it should be - const auto inx_j = it_j - inputs_order_.begin(); - GE_IF_BOOL_EXEC(inx_i < inx_j, std::swap(stack[i], stack[j])); - } - } - - return GRAPH_SUCCESS; -} - -void ExecuteGraph::GetOutNodesFromEdgesToMap(std::map &map_in_edge_num, FastNode *node, - std::map &breadth_node_map) const { - auto iter = map_in_edge_num.find(node); - if (iter != map_in_edge_num.end()) { - --iter->second; - if (iter->second == 0U) { - (void)breadth_node_map.emplace(node->GetName(), node); - } - } -} - -graphStatus ExecuteGraph::CollectBreadthOutNode(const FastNode *const node, - std::map &map_in_edge_num, - std::map &breadth_node_map) const { - auto &edges = node->GetAllOutDataEdgesRef(); - - for (size_t i = 0UL; i < edges.size(); ++i) { - std::for_each(edges[i].begin(), edges[i].end(), [&map_in_edge_num, &breadth_node_map, this](FastEdge *edge) { - if ((edge != nullptr) && (edge->dst_input != kControlEdgeIndex)) { - GetOutNodesFromEdgesToMap(map_in_edge_num, edge->dst, breadth_node_map); - } - }); - } - - auto &control_edges = node->GetAllOutControlEdgesRef(); - if (control_edges.empty()) { - return GRAPH_SUCCESS; - } - std::for_each(control_edges.begin(), control_edges.end(), - [&map_in_edge_num, &breadth_node_map, this](FastEdge *edge) { - if (edge != nullptr) { - GetOutNodesFromEdgesToMap(map_in_edge_num, edge->dst, breadth_node_map); - } - }); - - return GRAPH_SUCCESS; -} - -graphStatus ExecuteGraph::BFSTopologicalSorting(std::vector &node_vec, const bool reverse, - const ExecuteGraph *const compute_graph) const { - GELOGD("Runing_Bfs_Sort: %s", GetName().c_str()); - (void)reverse; - const bool is_mem_priority = IsMemoryPriority(); - std::vector reverse_dfs_nodes_info; - if (is_mem_priority) { - InitNodeStatus(compute_graph, reverse_dfs_nodes_info); - } - TopoSortStack topo_sort_stack(&reverse_dfs_nodes_info, is_mem_priority); - std::vector stack_input; - std::map breadth_node_map; - std::map map_in_edge_num; - // Record the number of non data nodes but no input nodes - GE_CHK_BOOL_EXEC(SortNodes(stack_input, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed"); - // Only data nodes here - while ((!stack_input.empty()) || (!topo_sort_stack.Empty())) { - FastNode *node = nullptr; - if (!topo_sort_stack.Empty()) { - node = topo_sort_stack.Pop(); - } else { - node = stack_input.back(); - stack_input.pop_back(); - } - - node_vec.push_back(node); - GE_CHECK_NOTNULL(node->GetOpDescBarePtr()); - GELOGD("node_vec.push_back %s", node->GetOpDescBarePtr()->GetName().c_str()); - (void)CollectBreadthOutNode(node, map_in_edge_num, breadth_node_map); - - for (const auto &name_node : breadth_node_map) { - (void)topo_sort_stack.Push(name_node.second); - } - breadth_node_map.clear(); - } - return GRAPH_SUCCESS; -} - -graphStatus ExecuteGraph::DFSTopologicalSorting(std::vector &node_vec, const bool reverse, - const ExecuteGraph *const compute_graph) const { - GELOGD("Runing_Dfs_Sort: %s", GetName().c_str()); - std::vector stack; - std::map map_in_edge_num; - // Record the number of non data nodes but no input nodes - GE_CHK_BOOL_EXEC(SortNodes(stack, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed"); - const bool is_mem_priority = IsMemoryPriority(); - std::vector reverse_dfs_nodes_info; - if (is_mem_priority) { - InitNodeStatus(compute_graph, reverse_dfs_nodes_info); - } - TopoSortStack topo_sort_stack(&reverse_dfs_nodes_info, is_mem_priority, true, reverse); - for (const auto &node : stack) { - topo_sort_stack.Push(node); - } - - std::vector out_nodes; - const auto stack_push = [&reverse, &topo_sort_stack](std::vector &tmp_out_nodes) { - if (reverse) { - std::reverse(tmp_out_nodes.begin(), tmp_out_nodes.end()); - } - for (const auto &node : tmp_out_nodes) { - topo_sort_stack.Push(node); - } - tmp_out_nodes.clear(); - }; - // Only data nodes here - while (!topo_sort_stack.Empty()) { - FastNode *node = topo_sort_stack.Pop(); - node_vec.push_back(node); - GE_CHECK_NOTNULL(node->GetOpDescBarePtr()); - auto &edges = node->GetAllOutDataEdgesRef(); - - for (size_t i = 0UL; i < edges.size(); ++i) { - std::for_each(edges[i].begin(), edges[i].end(), [&map_in_edge_num, &out_nodes](FastEdge *edge) { - if (edge != nullptr) { - GetOutNodesFromEdge(map_in_edge_num, edge->dst, out_nodes); - } - }); - - stack_push(out_nodes); - } - - auto control_edges = node->GetAllOutControlEdgesRef(); - std::for_each(control_edges.begin(), control_edges.end(), [&map_in_edge_num, &out_nodes](FastEdge *edge) { - if (edge != nullptr) { - GetOutNodesFromEdge(map_in_edge_num, edge->dst, out_nodes); - } - }); - stack_push(out_nodes); - } - - return GRAPH_SUCCESS; -} - -void ExecuteGraph::GetInNodes(const FastNode *const current, std::vector &input_nodes) const { - auto &in_data_edges = current->GetAllInDataEdgesRef(); - auto &ref = input_nodes; - for (size_t i = 0UL; i < in_data_edges.size(); i++) { - auto edge = in_data_edges[i]; - if (edge != nullptr) { - ref.push_back(edge->src); - } - } - - auto &in_control_edges = current->GetAllInControlEdgesRef(); - std::for_each(in_control_edges.begin(), in_control_edges.end(), [&ref](FastEdge *edge) { - if (edge != nullptr) { - ref.push_back(edge->src); - } - }); -} - -graphStatus ExecuteGraph::RDFSTopologicalSorting(std::vector &node_vec, const bool reverse, - const ExecuteGraph *const compute_graph) const { - (void)reverse; - GELOGD("Runing_Reverse_Dfs_Sort: %s", GetName().c_str()); - std::vector reverse_dfs_nodes_info; - InitNodeStatus(compute_graph, reverse_dfs_nodes_info); - - for (const auto quick_node : graph_shared_->GetDirectNodeToModify()) { - auto node = &FastGraphUtils::GetNode(quick_node); - if (!node->OutNodesIsEmpty()) { - continue; - } - std::vector stack = {node}; - while (!stack.empty()) { - const auto current = stack.back(); - NodeStatus &reverse_dfs_node_info = reverse_dfs_nodes_info[current->GetOpDescBarePtr()->GetId()]; - if (reverse_dfs_node_info.status == FastWalkStatus::kNotWalked) { - reverse_dfs_node_info.status = FastWalkStatus::kWalking; - - std::vector in_all_nodes; - GetInNodes(current, in_all_nodes); - - NodeCmp cmp(&reverse_dfs_nodes_info); - std::set> input_nodes{in_all_nodes.begin(), in_all_nodes.end(), cmp}; - stack.insert(stack.end(), input_nodes.cbegin(), input_nodes.cend()); - continue; - } - stack.pop_back(); - if (reverse_dfs_node_info.status == FastWalkStatus::kWalking) { - reverse_dfs_node_info.status = FastWalkStatus::kWalked; - node_vec.emplace_back(current); - } - } - } - return GRAPH_SUCCESS; -} - -graphStatus ExecuteGraph::TopologicalSortingGraph(const ExecuteGraph *const execute_graph, const bool dfs_reverse) { - using TopoSortingStrategy = std::function &, const bool, - const ExecuteGraph *const compute_graph)>; - static const std::map topo_sorting_strategy{ - {FastTopoSortingMode::kBFS, &ExecuteGraph::BFSTopologicalSorting}, - {FastTopoSortingMode::kDFS, &ExecuteGraph::DFSTopologicalSorting}, - {FastTopoSortingMode::kRDFS, &ExecuteGraph::RDFSTopologicalSorting}}; - - std::vector node_vec; - const auto use_topo_strategy = GetTopoSortingStrategy(); - const auto it = topo_sorting_strategy.find(use_topo_strategy); - if (it == topo_sorting_strategy.end()) { - GELOGE(GRAPH_FAILED, "Can not find topo sorting strategy of %d.", static_cast(use_topo_strategy)); - return GRAPH_FAILED; - } - - if (it->second(this, node_vec, dfs_reverse, execute_graph) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - - // If they are not equal, there is a closed loop - if (node_vec.size() != GetDirectNodesSize()) { - std::set itered_nodes_set; - for (auto &node : node_vec) { - (void)itered_nodes_set.insert(node); - } - REPORT_INNER_ERR_MSG("E18888", "Failed to do topo sorting total %zu, itered %zu, exist closed loop in graph:%s", - GetDirectNodesSize(), node_vec.size(), GetName().c_str()); - GELOGW("[Check][Param] Failed to do topo sorting total %zu, itered %zu, exist closed loop in graph.", - GetDirectNodesSize(), node_vec.size()); - for (auto node : graph_shared_->GetDirectNodeToModify()) { - if (itered_nodes_set.count(&FastGraphUtils::GetNode(node)) == 0UL) { - GELOGW("[Check][Param] The node %s does not itered when topological sorting", - FastGraphUtils::GetNode(node).GetName().c_str()); - } - } - return GRAPH_FAILED; - } - - if (IsMemoryPriority() || (use_topo_strategy == FastTopoSortingMode::kRDFS)) { - DelayTopoSort(node_vec, execute_graph); - } - - auto ret = graph_shared_->SetNodesAfterSorting(node_vec); - if (ret != GRAPH_SUCCESS) { - return ret; - } - graph_shared_->SetValidFlag(true); - return GRAPH_SUCCESS; -} - -void ExecuteGraph::GetAllNodesFromOpdesc(std::vector> &subgraphs, const OpDesc &op_desc, - std::deque &candidates) const { - const auto &subgraph_names = op_desc.GetSubgraphInstanceNames(); - auto name_iter = subgraph_names.rbegin(); - while (name_iter != subgraph_names.rend()) { - auto subgraph = GetSubGraph(*name_iter); - if (subgraph != nullptr) { - subgraphs.emplace_back(subgraph->shared_from_this()); - auto subgraph_nodes = subgraph->GetDirectNode(); - (void)candidates.insert(candidates.begin(), subgraph_nodes.begin(), subgraph_nodes.end()); - } - ++name_iter; - } -} - -std::vector ExecuteGraph::AllGraphNodes(std::vector> &subgraphs, - const FastNodeFilter &fast_node_filter) const { - std::vector all_nodes; - std::deque candidates; - - auto &ref = graph_shared_->GetDirectNodeToModify(); - for (auto iter = ref.begin(); iter != ref.end(); ++iter) { - QuickNode *node = *iter; - candidates.push_back(&(node->data)); - } - - while (!candidates.empty()) { - FastNode *node = candidates.front(); - candidates.pop_front(); - - if ((fast_node_filter == nullptr) || fast_node_filter(node)) { - all_nodes.emplace_back(node); - } - const auto op_desc = node->GetOpDescBarePtr(); - if (op_desc != nullptr) { - GetAllNodesFromOpdesc(subgraphs, *op_desc, candidates); - } - } - - return all_nodes; -} - -graphStatus ExecuteGraph::TopologicalSorting() { - auto ret = TopologicalSortingGraph(this, false); - if (ret != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Graph [%s] topological sort failed, saved to file black_box", GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Sort][Graph] Graph [%s] topological sort failed, saved to file black_box", - GetName().c_str()); - return ret; - } - - const auto &src_sub_graphs = graph_shared_->sub_graphs_; - if (src_sub_graphs.empty()) { - return GRAPH_SUCCESS; - } - - // partition sub graph - for (auto sub_graph : src_sub_graphs) { - GE_CHECK_NOTNULL(sub_graph); - GE_CHECK_NOTNULL(FastGraphUtils::GetGraph(sub_graph)); - ret = FastGraphUtils::GetGraph(sub_graph)->TopologicalSortingGraph(FastGraphUtils::GetGraph(sub_graph), false); - if (ret != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Sub graph[%s] topological sort failed, saved to file black_box", - FastGraphUtils::GetGraph(sub_graph)->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Sort][Graph] Sub graph[%s] topological sort failed, saved to file black_box", - FastGraphUtils::GetGraph(sub_graph)->GetName().c_str()); - return ret; - } - } - - std::vector> subgraphs; - auto nodes = AllGraphNodes(subgraphs, nullptr); - int64_t i = 0LL; - for (auto iter = nodes.begin(); iter != nodes.end(); ++iter) { - FastNode *node = *iter; // [node: should not be null] - node->GetOpDescBarePtr()->SetId(i); // [node->GetOpDescBarePtr(): should not be null] - ++i; - } - - if (src_sub_graphs.size() != subgraphs.size()) { // Graph Partition use subgraph, Keep original - GELOGW("[TopoSort][CheckNodeSize] Keep original subgraph for graph size %zu not equal %zu.", src_sub_graphs.size(), - subgraphs.size()); - return GRAPH_SUCCESS; - } - - graph_shared_->ClearAllSubGraph(); - names_to_subgraph_.clear(); - std::for_each(subgraphs.begin(), subgraphs.end(), - [this](std::shared_ptr &subgraph) { (void) AddSubGraph(subgraph); }); - return GRAPH_SUCCESS; -} - -void ExecuteGraph::SetName(const std::string &name) { - graph_shared_->SetName(name); -} - -std::string ExecuteGraph::GetName() const { - return graph_shared_->GetName(); -} - -void ExecuteGraph::SetParentGraph(ExecuteGraph *const parent_graph) { - graph_shared_->SetParentGraph(parent_graph); -} - -const ExecuteGraph *ExecuteGraph::GetParentGraphBarePtr(void) const { - return graph_shared_->GetParentGraph(); -} - -ExecuteGraph *ExecuteGraph::GetParentGraphBarePtr(void) { - return graph_shared_->GetParentGraph(); -} - -graphStatus ExecuteGraph::RecycleQuickEdge(const FastEdge *const fast_edge) { - if (fast_edge == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The node is nullptr."); - GE_LOGE("[Check][Param] The node is nullptr."); - return GRAPH_FAILED; - } - return graph_shared_->RecycleQuickEdge(FastGraphUtils::GetListElementAddr(fast_edge)); -} - -graphStatus ExecuteGraph::RecycleQuickNode(const FastNode *const fast_node) { - if (fast_node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The node is nullptr."); - GE_LOGE("[Check][Param] The node is nullptr."); - return GRAPH_FAILED; - } - return graph_shared_->RecycleQuickNode(FastGraphUtils::GetListElementAddr(fast_node)); -} - -std::vector ExecuteGraph::GetAllNodes() const { - std::vector> subgraphs; - return AllGraphNodes(subgraphs, nullptr); -} - -std::vector ExecuteGraph::GetAllNodes(const FastNodeFilter &fast_node_filter) const { - std::vector> subgraphs; - return AllGraphNodes(subgraphs, fast_node_filter); -} - -void ExecuteGraph::SetInputsOrder(const std::vector &inputs_order) { - inputs_order_ = inputs_order; -} - -void ExecuteGraph::ReorderByNodeId() { - graph_shared_->ReorderByNodeId(); -} - -void ExecuteGraph::SetGraphId(size_t graph_id) { - graph_shared_->SetGraphId(graph_id); -} - -size_t ExecuteGraph::GetGraphId() const { - return graph_shared_->GetGraphId(); -} - -ProtoAttrMap &ExecuteGraph::MutableAttrMap() { - return attrs_; -} - -ConstProtoAttrMap &ExecuteGraph::GetAttrMap() const { - return attrs_; -} - -bool ExecuteGraph::CheckNodeIsInGraph(const FastNode *const node) const { - return graph_shared_->CheckNodeIsInGraph(node); -} - -bool ExecuteGraph::CheckEdgeIsInGraph(const FastEdge *const edge) const { - return graph_shared_->CheckEdgeIsInGraph(edge); -} - -graphStatus ExecuteGraph::MoveEdgeToGraph(const FastEdge *const edge) { - if (edge == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The edge is nullptr."); - GE_LOGE("[Check][Param] The edge is nullptr."); - return GRAPH_FAILED; - } - - graph_shared_->MoveEdgeToGraph(edge); - return GRAPH_SUCCESS; -} -} // namespace ge diff --git a/graph/fast_graph/fast_graph_impl.h b/graph/fast_graph/fast_graph_impl.h deleted file mode 100644 index 3b9cd85773a24c0ec8a47364bfedb31a5d413632..0000000000000000000000000000000000000000 --- a/graph/fast_graph/fast_graph_impl.h +++ /dev/null @@ -1,997 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef FAST_GRAPH_FAST_GRAPH_IMPL_H -#define FAST_GRAPH_FAST_GRAPH_IMPL_H - -#include -#include "graph/fast_graph/fast_node.h" -#include "quick_list.h" -#include "graph/utils/op_type_utils.h" -#include "graph/debug/ge_op_types.h" -#include "graph/debug/ge_log.h" -#include "graph/ge_context.h" -#include "graph/utils/ge_ir_utils.h" -#include "common/ge_common/ge_types.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/debug/ge_attr_define.h" -#include "fast_graph_utils.h" - -namespace ge { -namespace { -constexpr int32_t kInvalidIndex = -1; -} // namespace -template -class FastGraphImpl { - public: - explicit FastGraphImpl(const std::string &name) - : name_(std::move(name)), parent_graph_(nullptr), parent_node_(nullptr) {} - FastGraphImpl() = default; - - ~FastGraphImpl() { - for (auto iter = nodes_.begin(); iter != nodes_.end();) { - auto quick_node = *iter; - iter = nodes_.erase(iter); - if (quick_node == nullptr) { - continue; - } - // for add extra ref count. - // if not the line, quick_node maybe release in SetNodePtr(nullptr). - if (FastGraphUtils::GetNode(quick_node).GetNodeBarePtr() != nullptr) { - auto node_ptr = FastGraphUtils::GetNode(quick_node).GetNodePtr(); - FastGraphUtils::GetNode(quick_node).ClearNodePtr(); - } else { - free_nodes_.push_back(quick_node, ListMode::kFreeMode); - ClearNodeRelateInfo(quick_node); - } - } - - for (auto iter = free_nodes_.begin(); iter != free_nodes_.end();) { - auto quick_node = *iter; - iter = free_nodes_.erase(iter); - if (quick_node != nullptr) { - delete quick_node; - } - } - - for (auto iter = free_edges_.begin(); iter != free_edges_.end();) { - auto edge = *iter; - iter = free_edges_.erase(iter); - if (edge != nullptr) { - delete edge; - } - } - - for (auto iter = sub_graphs_.begin(); iter != sub_graphs_.end();) { - auto item = *iter; - iter = sub_graphs_.erase(iter); - if (item != nullptr) { - delete item; - } - } - - for (auto iter = free_sub_graphs_.begin(); iter != free_sub_graphs_.end();) { - auto item = *iter; - iter = free_sub_graphs_.erase(iter); - if (item != nullptr) { - delete item; - } - } - } - - void SetOwnerGraph(GraphT *graph) { - owner_graph_ = graph; - } - - /** - * The function Set edge owner to the edges_. - */ - void MoveEdgeToGraph(const FastEdge *const edge) { - ListElement> *element = FastGraphUtils::GetListElementAddr(edge); - if ((FastGraphUtils::GetOwner(element) != &edges_) && (FastGraphUtils::GetOwner(element) != nullptr)) { - FastGraphUtils::GetOwner(element)->erase(element); - } - - if (FastGraphUtils::GetOwner(element) == &edges_) { - return; - } - - edges_.push_back(element, ListMode::kWorkMode); - } - - /** - * The function provide the deep copy of graph to other graph. - */ - graphStatus DeepCopy(const FastGraphImpl &graph) { - std::unordered_map origin_node_to_copy_node; - DeepCopyNodes(graph, origin_node_to_copy_node); - DeepCopyEdges(graph, origin_node_to_copy_node); - DeepCopySubGraphs(graph); - DeepCopyInputNodes(graph, origin_node_to_copy_node); - DeepCopyOutputNodes(graph, origin_node_to_copy_node); - - name_ = graph.name_; - graph_id_ = graph.graph_id_; - parent_graph_ = graph.parent_graph_; - parent_node_ = graph.parent_node_; - graph_netoutput_ = graph.graph_netoutput_; - extend_info_ = graph.extend_info_; - return GRAPH_SUCCESS; - } - - graphStatus SetNodesAfterSorting(const std::vector &nodes) { - nodes_.clear(); - for (size_t i = 0UL; i < nodes.size(); i++) { - auto node = nodes[i]; - if ((node == nullptr) || (node->GetOpDescBarePtr() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "The node ptr or op_desc should not be null."); - GELOGE(GRAPH_FAILED, "[Check][Param] The node ptr or op_desc should not be null."); - return PARAM_INVALID; - } - node->GetOpDescBarePtr()->SetId(static_cast(i)); - PushBackToNodeList(FastGraphUtils::GetListElementAddr(node)); - } - return GRAPH_SUCCESS; - } - - /** - * The function is provide: - * 1. set node to nodes_ - * 2. if it is data, set it to input_nodes_ - */ - graphStatus SetNodes(const std::vector &nodes) { - nodes_.clear(); - for (size_t i = 0UL; i < nodes.size(); i++) { - auto node = nodes[i]; - if ((node == nullptr) || (node->GetOpDescBarePtr() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "The node ptr or op_desc should not be null."); - GELOGE(GRAPH_FAILED, "[Check][Param] The node ptr or op_desc should not be null."); - return PARAM_INVALID; - } - RecordNodeAndInputDataToGraph(node); - } - return GRAPH_SUCCESS; - } - - NodeT *AddInputNode(NodeT *const node) { - if (node->GetExtendInfo() == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The node extend info should not be null."); - GELOGE(GRAPH_FAILED, "[Check][Param] The node extend info should not be null."); - return nullptr; - } - - auto input_idx = node->GetExtendInfo()->GetInputIndex(); - auto already_exist_flag = (input_idx >= 0) && (input_idx < static_cast(input_nodes_.size())) && - (node == input_nodes_[input_idx]); - if (!already_exist_flag) { - node->GetExtendInfo()->SetInputIndex(input_nodes_.size()); - input_nodes_.push_back(node); - } - - if (CheckNodeIsInGraph(node)) { - return node; - } - - return AddNode(node); - } - - graphStatus RemoveInputNode(NodeT *const node) { - if ((node == nullptr) || (node->GetExtendInfo() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "The node ptr should not be null, graph:%s.", name_.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] The node ptr should not be null."); - return GRAPH_FAILED; - } - - auto input_idx = node->GetExtendInfo()->GetInputIndex(); - if ((input_idx >= 0) && (input_idx < static_cast(input_nodes_.size())) && - (node == input_nodes_[input_idx])) { - input_nodes_[input_idx] = nullptr; - node->GetExtendInfo()->SetInputIndex(kInvalidIndex); - return SUCCESS; - } - - GELOGW("[Remove][Node] Failed to remove input node."); - return GRAPH_FAILED; - } - - NodeT *AddOutputNodeByIndex(NodeT *const node, const int32_t index) { - if ((node == nullptr) || (node->GetOpDescBarePtr() == nullptr) || node->GetExtendInfo() == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The node ptr or opdesc should not be null."); - GELOGE(GRAPH_FAILED, "[Check][Param] The node ptr or opdesc should not be null."); - return nullptr; - } - - bool already_have = false; - NodeT *result = node; - // [output_nodes_info_ : should not be null] - auto &output_idxs = node->GetExtendInfo()->GetOutputIndex(); - for (auto item : output_idxs) { - auto flag = (item >= 0) && (item < static_cast(output_nodes_.size())) && - (node == output_nodes_[item].first) && (output_nodes_[item].second == item); - if (flag) { - already_have = true; - result = output_nodes_[item].first; - break; - } - } - - if (!already_have) { - node->GetExtendInfo()->AddOneOutputIndex(output_nodes_.size()); - output_nodes_.emplace_back(std::make_pair(node, index)); - GELOGI("Push back node name:%s, index:%d, into output_nodes_info_.", node->GetName().c_str(), index); - } - - if (!CheckNodeIsInGraph(node)) { - GE_CHK_BOOL_EXEC(AddNode(node) != nullptr, return nullptr, "[Add][Node] failed"); - } - - return result; - } - - const std::vector &GetAllInputNodeInfo() const { - return input_nodes_; - } - - const std::vector> &GetAllOutNodeInfo() const { - return output_nodes_; - } - - std::vector GetInputNodes() const { - return input_nodes_; - } - - std::vector> GetAllOutNodes() const { - return output_nodes_; - } - - void SetGraphOutNodesInfo(const std::vector> &out_nodes_info) { - output_nodes_ = out_nodes_info; - } - - void AppendGraphOutNodesInfo(std::vector> &out_nodes_info) { - (void)output_nodes_.insert(output_nodes_.cend(), out_nodes_info.cbegin(), out_nodes_info.cend()); - } - - graphStatus RemoveOutputNode(const NodeT *const node) { - if ((node == nullptr) || (node->GetExtendInfo() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "The node ptr should not be null, graph:%s.", name_.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] The node ptr should not be null."); - return GRAPH_FAILED; - } - - bool find_node = false; - // [output_nodes_info_ : should not be null] - auto &output_idxs = node->GetExtendInfo()->GetOutputIndex(); - auto iter = output_idxs.begin(); - while (iter != output_idxs.end()) { - auto item = *iter; - auto flag = - (item >= 0) && (item < static_cast(output_nodes_.size())) && (node == output_nodes_[item].first); - if (flag) { - output_nodes_[item] = {nullptr, -1}; - find_node = true; - iter = output_idxs.erase(iter); - } else { - ++iter; - } - } - - GE_IF_BOOL_EXEC(!find_node, return GRAPH_FAILED); - return GRAPH_SUCCESS; - } - - NodeT *AddNode(NodeT *const fast_node) { - if ((fast_node == nullptr) || (fast_node->GetOpDescBarePtr() == nullptr) || - (fast_node->GetExtendInfo() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "the node ptr or op desc ptr should not be null."); - GELOGE(GRAPH_FAILED, "[Check][Param] The node ptr or op desc ptr should not be null."); - return nullptr; - } - - fast_node->GetExtendInfo()->SetHostNode(extend_info_.is_valid_flag_); - fast_node->GetOpDescBarePtr()->SetId(static_cast(GetDirectNodesSize())); - RecordNodeAndInputDataToGraph(fast_node); - return fast_node; - } - - NodeT *AddNode(const OpDescPtr &op) { - NodeT *node_ptr = CreateOneNode(op, static_cast(GetDirectNodesSize())); - return AddNode(node_ptr); - } - - NodeT *AddNode(const OpDescPtr &op, const int64_t id) { - NodeT *node_ptr = CreateOneNode(op, id); - if ((node_ptr == nullptr) || (node_ptr->GetExtendInfo() == nullptr)) { - return nullptr; - } - node_ptr->GetExtendInfo()->SetHostNode(extend_info_.is_valid_flag_); - RecordNodeAndInputDataToGraph(node_ptr); - return node_ptr; - } - - NodeT *AddNodeFront(NodeT *const fast_node) { - if ((fast_node == nullptr) || (fast_node->GetOpDescBarePtr() == nullptr) || - (fast_node->GetExtendInfo() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "The node ptr or op desc should not be null."); - GELOGE(GRAPH_FAILED, "[Check][Param] The node ptr or op desc should not be null."); - return nullptr; - } - - fast_node->GetExtendInfo()->SetHostNode(extend_info_.is_valid_flag_); - fast_node->GetOpDescBarePtr()->SetId(static_cast(GetDirectNodesSize())); - - auto quick_node = FastGraphUtils::GetListElementAddr(fast_node); - if (FastGraphUtils::GetOwner(quick_node) != nullptr) { - FastGraphUtils::GetOwner(quick_node)->erase(quick_node); - } - - auto pos = nodes_.begin(); - if (*pos == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The node begin ptr should not be null."); - GELOGE(GRAPH_FAILED, "[Check][Param] The node begin ptr should not be null."); - return nullptr; - } - - if ((GetDirectNodesSize() > 0UL) && ((*pos)->data.GetType() == DATA)) { - pos = std::next(nodes_.begin()); - } - - GELOGD("[insert][NodeT] node = %p.", quick_node); - nodes_.insert(pos, quick_node, ListMode::kWorkMode); - fast_node->GetExtendInfo()->SetOwnerGraph(owner_graph_, fast_node); - CheckAndRecordInputNode(fast_node); - return fast_node; - } - - NodeT *AddNodeFront(const OpDescPtr &op) { - NodeT *node_ptr = CreateOneNode(op, static_cast(GetDirectNodesSize())); - return AddNodeFront(node_ptr); - } - - graphStatus RemoveJustNode(ListElement *node_ptr) { - if ((node_ptr == nullptr) || (node_ptr->owner == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "The node ptr should not be null."); - GELOGE(GRAPH_FAILED, "[Check][Param] The node ptr should not be null."); - return GRAPH_FAILED; - } - - if (FastGraphUtils::GetOwner(node_ptr) != &nodes_) { - if ((FastGraphUtils::GetMode(node_ptr) == ListMode::kFreeMode) && - (FastGraphUtils::GetOwner(node_ptr) != nullptr)) { - return GRAPH_SUCCESS; - } - /* already add to other graph, so it can`t remove from current graph. */ - GELOGW("[Remove][Node] The node is not in the graph, please check the node."); - return GRAPH_NOT_CHANGED; - } - - (void)nodes_.erase(node_ptr); - if (FastGraphUtils::GetNode(node_ptr).GetNodeBarePtr() == nullptr) { - free_nodes_.push_back(node_ptr, ListMode::kFreeMode); - } - - return GRAPH_SUCCESS; - } - - graphStatus RecycleQuickNode(ListElement *const quick_node) { - if (quick_node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The node ptr should not be null."); - GELOGE(GRAPH_FAILED, "[Check][Param] The node ptr should not be null."); - return GRAPH_FAILED; - } - - if (FastGraphUtils::GetOwner(quick_node) != nullptr) { - FastGraphUtils::GetOwner(quick_node)->erase(quick_node); - } - - free_nodes_.push_back(quick_node, ListMode::kFreeMode); - return GRAPH_SUCCESS; - } - - graphStatus RecycleQuickEdge(ListElement> *const list_edge) { - if (list_edge == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The node ptr should not be null."); - GELOGE(GRAPH_FAILED, "[Check][Param] The node ptr should not be null."); - return GRAPH_FAILED; - } - - if (FastGraphUtils::GetOwner(list_edge) != nullptr) { - FastGraphUtils::GetOwner(list_edge)->erase(list_edge); - } - - free_edges_.push_back(list_edge, ListMode::kFreeMode); - return GRAPH_SUCCESS; - } - - graphStatus UpdateNodePos(ListElement *const need_move_node, ListElement *const dst_node, - bool before_insert) { - return nodes_.move(need_move_node, dst_node, before_insert); - } - - FastEdge *AddEdge(NodeT *const src, int32_t src_index, NodeT *const dst, int32_t dst_index) { - if ((src == nullptr) || (dst == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "The node ptr should not be null."); - GELOGE(GRAPH_FAILED, "[Check][Param] The node ptr should not be null."); - return nullptr; - } - - auto io_index_valid_flag = ((src_index == kControlEdgeIndex) && (dst_index != kControlEdgeIndex)) || - ((src_index != kControlEdgeIndex) && (dst_index == kControlEdgeIndex)); - if (io_index_valid_flag) { - REPORT_INNER_ERR_MSG("E18888", "Failed to check output index [%d] or input index [%d].", src_index, dst_index); - GELOGE(GRAPH_FAILED, "[check][index] Failed to check output index [%d] or input index [%d].", src_index, - dst_index); - return nullptr; - } - - ListElement> *edge = nullptr; - if (free_edges_.empty()) { - edge = new (std::nothrow) ListElement>; - } else { - auto iter = free_edges_.begin(); - edge = *iter; - free_edges_.erase(iter); - } - if (edge == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Failed to malloc memory for edge."); - GELOGE(GRAPH_FAILED, "[malloc][edge] Failed to malloc memory for edge."); - return nullptr; - } - - FastGraphUtils::GetEdgeSrc(edge) = src; - FastGraphUtils::GetEdgeDst(edge) = dst; - FastGraphUtils::GetEdgeSrcOutput(edge) = src_index; - FastGraphUtils::GetEdgeDstInput(edge) = dst_index; - - auto ret = src->RecordEdge(&FastGraphUtils::GetEdge(edge), DirectionType::kDirectionOutType); - if (ret != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Failed to record edge in the output node."); - GELOGE(GRAPH_FAILED, "[malloc][edge] Failed to record edge in the output node."); - free_edges_.push_back(edge, ListMode::kFreeMode); - return nullptr; - } - - ret = dst->RecordEdge(&FastGraphUtils::GetEdge(edge), DirectionType::kDirectionInType); - if (ret != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Failed to record edge in the input node."); - GELOGE(GRAPH_FAILED, "[malloc][edge] Failed to record edge in the input node."); - src->EraseEdge(&FastGraphUtils::GetEdge(edge), DirectionType::kDirectionOutType); - free_edges_.push_back(edge, ListMode::kFreeMode); - return nullptr; - } - - edges_.push_back(edge, ListMode::kWorkMode); - return &FastGraphUtils::GetEdge(edge); - } - - /** - * currently, The function only remove edge belongs to self graph. - */ - graphStatus RemoveEdge(ListElement> *const edge) { - if (edge == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The edge ptr should not be null."); - GELOGE(GRAPH_FAILED, "[Check][Param] The edge ptr should not be null."); - return GRAPH_FAILED; - } - - if (FastGraphUtils::GetOwner(edge) != &edges_) { - if ((FastGraphUtils::GetMode(edge) == ListMode::kFreeMode) && (FastGraphUtils::GetOwner(edge) != nullptr)) { - return GRAPH_SUCCESS; - } - - GELOGW("[Remove][Edge] The edge is not in the graph, please check the edge."); - return GRAPH_NOT_CHANGED; - } - - if (FastGraphUtils::GetEdgeSrc(edge) != nullptr) { - FastGraphUtils::GetEdgeSrc(edge)->EraseEdge(&FastGraphUtils::GetEdge(edge), DirectionType::kDirectionOutType); - FastGraphUtils::GetEdgeSrc(edge) = nullptr; - } - - if (FastGraphUtils::GetEdgeDst(edge) != nullptr) { - FastGraphUtils::GetEdgeDst(edge)->EraseEdge(&FastGraphUtils::GetEdge(edge), DirectionType::kDirectionInType); - FastGraphUtils::GetEdgeDst(edge) = nullptr; - } - - edges_.erase(edge); - free_edges_.push_back(edge, ListMode::kFreeMode); - return GRAPH_SUCCESS; - } - - ListElement *AddSubGraph(GraphT *const sub_graph) { - ListElement *quick_graph = nullptr; - if (free_sub_graphs_.empty()) { - quick_graph = new (std::nothrow) ListElement; - } else { - auto iter = free_sub_graphs_.begin(); - quick_graph = *iter; - free_sub_graphs_.erase(iter); - } - - if (quick_graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Failed to create a subgraph."); - GELOGE(GRAPH_FAILED, "[Create][SubGraph] Failed to create a subgraph."); - return nullptr; - } - - quick_graph->data = sub_graph; - sub_graphs_.push_back(quick_graph, ListMode::kWorkMode); - return quick_graph; - } - - graphStatus RemoveSubGraph(ListElement *const sub_graph) { - if (FastGraphUtils::GetOwner(sub_graph) != &sub_graphs_) { - if ((FastGraphUtils::GetMode(sub_graph) == ListMode::kFreeMode) && - (FastGraphUtils::GetOwner(sub_graph) != nullptr)) { - return GRAPH_SUCCESS; - } - - /* already add to other graph, so it can`t remove from current graph. */ - GELOGW("[Remove][SubGraph] The sub graph is not in the graph, please check the sub graph."); - return GRAPH_NOT_CHANGED; - } - - sub_graphs_.erase(sub_graph); - free_sub_graphs_.push_back(sub_graph, ListMode::kFreeMode); - return GRAPH_SUCCESS; - } - - void ClearAllSubGraph() { - ClearGraphs(); - } - - void SetGraphId(size_t graph_id) { - graph_id_ = graph_id; - } - - size_t GetGraphId(void) const { - return graph_id_; - } - - void SetParentNode(NodeT *const parent_node) { - parent_node_ = parent_node; - } - - NodeT *GetParentNode(void) const { - return parent_node_; - } - - void SetParentGraph(GraphT *const parent_graph) { - parent_graph_ = parent_graph; - } - - const GraphT *GetParentGraph(void) const { - return parent_graph_; - } - - GraphT *GetParentGraph(void) { - return parent_graph_; - } - - const QuickList &GetAllNodeInfo(void) const { - return nodes_; - } - - QuickList &GetAllNodeInfoForModify(void) { - return nodes_; - } - - void SetAllInputNodeInfo(const std::vector &inputs) { - input_nodes_.swap(inputs); - } - - size_t GetAllSubGraphSize(void) const { - return sub_graphs_.size(); - } - - void SetNetOutputNode(NodeT *const netoutput_node) { - graph_netoutput_ = netoutput_node; - } - - void SetName(const std::string &name) { - name_ = name; - } - std::string GetName() const { - return name_; - } - - size_t GetDirectNodesSize() const { - return nodes_.size(); - } - - const QuickList &GetRawDirectNode() const { - return nodes_; - } - - QuickList &GetDirectNodeToModify() const { - return nodes_; - } - - QuickList> &GetRawAllEdges() const { - return edges_; - } - - QuickList &GetRawAllSubgraphs() const { - return sub_graphs_; - } - - std::vector GetDirectNode() const { - return nodes_.CollectAllPtrItemToVector(); - } - - std::vector *> GetAllEdges() const { - return edges_.CollectAllPtrItemToVector(); - } - - std::vector GetAllSubgraphs() const { - return sub_graphs_.CollectAllItemToVector(); - } - - const ListElement *FindNode(size_t token) const { - for (const auto &node : nodes_) { - if (node == nullptr) { - continue; - } - if (FastGraphUtils::GetConstNode(node).GetNodeToken() == token) { - return node; - } - } - return nullptr; - } - - bool IsValid() const { - return extend_info_.is_valid_flag_; - } - - void SetValidFlag(bool flag) { - extend_info_.is_valid_flag_ = flag; - } - - void InValid() { - extend_info_.is_valid_flag_ = false; - } - - bool operator==(const FastGraphImpl &r_graph) const { - return (IsEqual(this->name_, r_graph.name_, "graph.name") && - IsEqual(this->graph_id_, r_graph.graph_id_, "graph.graph_id") && - IsEqual(this->GetDirectNodesSize(), r_graph.GetDirectNodesSize(), "graph.nodes.size()") && - IsEqual(this->edges_.size(), r_graph.edges_.size(), "graph.edge.size()") && - IsEqual(this->sub_graphs_.size(), r_graph.sub_graphs_.size(), "graph.sub_graph.size()") && - IsEqual(this->parent_graph_, r_graph.parent_graph_, "graph.parent_graph") && - IsEqual(this->parent_node_, r_graph.parent_node_, "graph.parent_node") && - IsEqual(this->graph_netoutput_, r_graph.graph_netoutput_, "graph.graph_netoutput") && - IsEqual(this->extend_info_.is_valid_flag_, r_graph.extend_info_.is_valid_flag_, "graph.is_valid_flag_")); - } - - void Swap(FastGraphImpl &graph) { - name_.swap(graph.name_); - std::swap(graph_id_, graph.graph_id_); - nodes_.swap(graph.nodes_); - edges_.swap(graph.edges_); - input_nodes_.swap(graph.input_nodes_); - output_nodes_.swap(graph.output_nodes_); - sub_graphs_.swap(graph.sub_graphs_); - - std::swap(parent_graph_, graph.parent_graph_); - std::swap(parent_node_, graph.parent_node_); - std::swap(graph_netoutput_, graph.graph_netoutput_); - std::swap(extend_info_.is_valid_flag_, graph.extend_info_.is_valid_flag_); - } - - void SetSubGraph(const std::vector &sub_graphs) { - ClearGraphs(); - std::for_each(sub_graphs.begin(), sub_graphs.end(), [this](GraphT *graph) { - if (graph != nullptr) { - AddSubGraph(graph); - } - }); - } - - bool CheckNodeIsInGraph(const FastNode *const node) { - ListElement *quick_node = FastGraphUtils::GetListElementAddr(node); - return FastGraphUtils::GetOwner(quick_node) == &nodes_; - } - - bool CheckEdgeIsInGraph(const FastEdge *const edge) { - ListElement> *element = FastGraphUtils::GetListElementAddr(edge); - return FastGraphUtils::GetOwner(element) == &edges_; - } - - graphStatus ClearNode(const std::function &clear_oper) { - auto iter = nodes_.begin(); - while (iter != nodes_.end()) { - QuickNode *quick_node = *iter; - ++iter; - auto ret = clear_oper(quick_node); - if (ret != GRAPH_SUCCESS) { - return ret; - } - } - return GRAPH_SUCCESS; - } - - void ReorderByNodeId() { - nodes_.sort([](const ListElement *lhs, const ListElement *rhs) { - return FastGraphUtils::GetConstNode(lhs).GetOpDescBarePtr()->GetId() < - FastGraphUtils::GetConstNode(rhs).GetOpDescBarePtr()->GetId(); - }); - } - - /** - * The Function is used to delete edge which is not in the graph. - * it can used for the following scenarios: - * 1. Remove node a in graph A; - * 2. Add node a in graph B; - * 3. remove edges of all node a in graph A. - */ - void ForceDeleteEdge(FastEdge *const e) { - if (e == nullptr) { - return; - } - auto quick_edge = FastGraphUtils::GetListElementAddr(e); - if (FastGraphUtils::GetEdgeDst(quick_edge) != nullptr) { - FastGraphUtils::GetEdgeDst(quick_edge)->EraseEdge(e, DirectionType::kDirectionInType); - FastGraphUtils::GetEdgeDst(quick_edge) = nullptr; - } - - if (FastGraphUtils::GetEdgeSrc(quick_edge) != nullptr) { - FastGraphUtils::GetEdgeSrc(quick_edge)->EraseEdge(e, DirectionType::kDirectionOutType); - FastGraphUtils::GetEdgeSrc(quick_edge) = nullptr; - } - - if (FastGraphUtils::GetOwner(quick_edge) != nullptr) { - FastGraphUtils::GetOwner(quick_edge)->erase(quick_edge); - free_edges_.push_back(quick_edge, ListMode::kFreeMode); - } else { - // The rt2 move edge of two nodes into same graph. - // it modify the owner of nodes, but not modify the owner of edges. - // it will be result in the inability to delete the edge. - // Therefore, if check the edge is nullptr, it push to the free edges - free_edges_.push_back(quick_edge, ListMode::kFreeMode); - } - } - - private: - graphStatus ClearNodeRelateInfo(ListElement *const node_ptr) { - FastGraphUtils::GetNode(node_ptr).RemoveAllEdge([this](FastEdge *e) { ForceDeleteEdge(e); }); - return GRAPH_SUCCESS; - } - - void CheckAndRecordInputNode(NodeT *const node) { - if ((node == nullptr) || (node->GetExtendInfo() == nullptr)) { - return; - } - - auto input_idx = node->GetExtendInfo()->GetInputIndex(); - auto already_exist_flag = (input_idx >= 0) && (input_idx < static_cast(input_nodes_.size())) && - (node == input_nodes_[input_idx]); - if (OpTypeUtils::IsDataNode(node->GetType()) && (!already_exist_flag)) { - node->GetExtendInfo()->SetInputIndex(input_nodes_.size()); - input_nodes_.push_back(node); - } - } - - /** - * the input paramter (node) can`t be empty. - * it need to check in upper-layer functions. - */ - void PushBackToNodeList(ListElement *const node) { - if (FastGraphUtils::GetOwner(node) != nullptr) { - node->owner->erase(node); - } - GELOGD("[Add][NodeT] node = %p.", node); - nodes_.push_back(node, ListMode::kWorkMode); - } - - void RecordNodeAndInputDataToGraph(NodeT *const node) { - PushBackToNodeList(FastGraphUtils::GetListElementAddr(node)); - node->GetExtendInfo()->SetOwnerGraph(owner_graph_, node); - CheckAndRecordInputNode(node); - } - - void DeepCopyNodes(const FastGraphImpl &other_graph, - std::unordered_map &origin_node_to_copy_node) { - ClearNodes(); - for (auto iter = other_graph.nodes_.begin(); iter != other_graph.nodes_.end(); ++iter) { - auto origin_node = *iter; - if (origin_node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The node is nullptr in src graph."); - GELOGE(GRAPH_FAILED, "[DeepCopyNodes] The node is nullptr in src graph."); - continue; - } - - OpDescPtr opdesc_ptr = std::make_shared(*(FastGraphUtils::GetConstNode(origin_node).GetOpDescPtr())); - auto copy_node = AddNode(opdesc_ptr); - if (copy_node == nullptr) { - continue; - } - origin_node_to_copy_node.insert(std::make_pair(&FastGraphUtils::GetConstNode(origin_node), copy_node)); - } - } - - void DeepCopyEdges(const FastGraphImpl &other_graph, - const std::unordered_map &origin_node_to_copy_node) { - ClearEdges(); - for (auto origin_edge : other_graph.edges_) { - if (origin_edge == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The edge is nullptr in src graph."); - GELOGE(GRAPH_FAILED, "[DeepCopyEdges] The edge is nullptr in src graph."); - continue; - } - - NodeT *new_src_node = nullptr; - NodeT *new_dst_node = nullptr; - auto map_iter = origin_node_to_copy_node.find(FastGraphUtils::GetConstEdgeSrc(origin_edge)); - if (map_iter != origin_node_to_copy_node.end()) { - new_src_node = map_iter->second; - } else { - // 这里不需要报错,是为了支持部分拷贝的情况 - GELOGI("[DeepCopyEdges] Can not find src node to add edge, skip it."); - continue; - } - - map_iter = origin_node_to_copy_node.find(FastGraphUtils::GetConstEdgeDst(origin_edge)); - if (map_iter != origin_node_to_copy_node.end()) { - new_dst_node = map_iter->second; - } else { - // 这里不需要报错,是为了支持部分拷贝的情况 - GELOGI("[DeepCopyEdges] Can not find dst node to add edge, skip it."); - continue; - } - - auto copy_edge = AddEdge(new_src_node, FastGraphUtils::GetConstEdgeSrcOutput(origin_edge), new_dst_node, - FastGraphUtils::GetConstEdgeDstInput(origin_edge)); - if (copy_edge == nullptr) { - continue; - } - } - } - - void DeepCopySubGraphs(const FastGraphImpl &other_graph) const { - owner_graph_->ClearAllSubGraph(); - for (auto iter = other_graph.sub_graphs_.begin(); iter != other_graph.sub_graphs_.end(); ++iter) { - const ListElement *origin_graph_listnode = *iter; - if ((origin_graph_listnode == nullptr) || (origin_graph_listnode->data == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "The sub graph is nullptr in src graph."); - GELOGE(GRAPH_FAILED, "[DeepCopySubGraphs] The sub graph is nullptr in src graph."); - continue; - } - - auto name = origin_graph_listnode->data->GetName(); - std::shared_ptr copy_sub_graph = std::make_shared(name); - if (copy_sub_graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Failed to new sub graph."); - GELOGE(GRAPH_FAILED, "[DeepCopySubGraphs] Failed to new sub graph."); - continue; - } - copy_sub_graph->CompleteCopy(*(origin_graph_listnode->data)); - owner_graph_->AddSubGraph(copy_sub_graph, name); - } - } - - void DeepCopyInputNodes(const FastGraphImpl &other_graph, - const std::unordered_map &origin_node_to_copy_node) { - input_nodes_.clear(); - if (other_graph.input_nodes_.empty()) { - return; - } - for (auto &item : other_graph.input_nodes_) { - auto origin_node = item; - auto old_iter = origin_node_to_copy_node.find(origin_node); - if (old_iter != origin_node_to_copy_node.end()) { - input_nodes_.push_back(old_iter->second); - } - } - } - - void DeepCopyOutputNodes(const FastGraphImpl &other_graph, - const std::unordered_map &origin_node_to_copy_node) { - output_nodes_.clear(); - if (other_graph.output_nodes_.empty()) { - return; - } - for (auto &item : other_graph.output_nodes_) { - auto origin_node = item.first; - - auto old_iter = origin_node_to_copy_node.find(origin_node); - if (old_iter != origin_node_to_copy_node.end()) { - output_nodes_.push_back(std::make_pair(old_iter->second, item.second)); - } - } - } - - void ClearNodes() { - auto iter = nodes_.begin(); - while (iter != nodes_.end()) { - auto quick_node = *iter; - iter = nodes_.erase(iter); - free_nodes_.push_back(quick_node, ListMode::kFreeMode); - } - } - - void ClearEdges() { - auto iter = edges_.begin(); - while (iter != edges_.end()) { - auto quick_edge = *iter; - iter = edges_.erase(iter); - free_edges_.push_back(quick_edge, ListMode::kFreeMode); - } - } - - void ClearGraphs() { - auto iter = sub_graphs_.begin(); - while (iter != sub_graphs_.end()) { - auto quick_graph = *iter; - iter = sub_graphs_.erase(iter); - free_sub_graphs_.push_back(quick_graph, ListMode::kFreeMode); - } - } - - NodeT *CreateOneNode(const OpDescPtr &op, int64_t id) { - if (op == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The OpDesc ptr should not be null."); - GELOGE(GRAPH_FAILED, "[Check][Param] The OpDesc ptr should not be null."); - return nullptr; - } - op->SetId(id); - - ListElement *node_ptr = nullptr; - if (free_nodes_.empty()) { - node_ptr = new (std::nothrow) ListElement(); - } else { - auto iter = free_nodes_.begin(); - node_ptr = *iter; - free_nodes_.erase(iter); - } - - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "create node failed."); - GELOGE(GRAPH_FAILED, "[Create][Node] node_ptr is NULL!!!"); - return nullptr; - } - - auto fast_node = &FastGraphUtils::GetNode(node_ptr); - auto ret = fast_node->Init(op); - if (ret != GRAPH_SUCCESS) { - return nullptr; - } - return fast_node; - } - - private: - friend GraphUtils; - friend ExecuteGraph; - friend class ExecuteGraphAdapter; - friend class ExecuteGraphUtils; - - GraphT *owner_graph_ = nullptr; - std::string name_; - size_t graph_id_ = 0UL; - // node - mutable QuickList nodes_; - QuickList free_nodes_; - // edge - mutable QuickList> edges_; - QuickList> free_edges_; - // io - std::vector input_nodes_; - std::vector> output_nodes_; - // subgraph - mutable QuickList sub_graphs_; - QuickList free_sub_graphs_; - - GraphT *parent_graph_ = nullptr; - NodeT *parent_node_ = nullptr; - NodeT *graph_netoutput_ = nullptr; - - GraphExtendInfo extend_info_; -}; - -} // namespace ge -#endif // FAST_GRAPH_NEW_GRAPH_IMPL_H diff --git a/graph/fast_graph/fast_graph_utils.h b/graph/fast_graph/fast_graph_utils.h deleted file mode 100644 index a74f715db329c9208ab8a96613a21799571e0c6e..0000000000000000000000000000000000000000 --- a/graph/fast_graph/fast_graph_utils.h +++ /dev/null @@ -1,237 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef FAST_GRAPH_FAST_GRAPH_UTILS_H -#define FAST_GRAPH_FAST_GRAPH_UTILS_H - -#include -#include "graph/anchor.h" -#include "quick_list.h" -#include "graph/fast_graph/execute_graph.h" -#include "graph/utils/tensor_utils.h" - -namespace ge { -enum class FastWalkStatus { kNotWalked, kWalking, kWalked }; -struct NodeStatus { - size_t size = 0U; - FastWalkStatus status; -}; - -struct GraphExtendInfo { - bool is_valid_flag_ = false; -}; - -using QuickNode = ListElement; -using QuickEdge = ListElement>; -using QuickGraph = ListElement; - -class FastGraphUtils { - public: - static inline Edge &GetEdge(QuickEdge *const quick_edge) { - return quick_edge->data; - } - - static inline FastNode *&GetEdgeSrc(QuickEdge *const quick_edge) { - return quick_edge->data.src; - } - - static inline FastNode *const &GetConstEdgeSrc(const QuickEdge *const quick_edge) { - return quick_edge->data.src; - } - - static inline FastNode *&GetEdgeDst(QuickEdge *const quick_edge) { - return quick_edge->data.dst; - } - - static inline FastNode *const &GetConstEdgeDst(const QuickEdge *const quick_edge) { - return quick_edge->data.dst; - } - - static inline int32_t &GetEdgeSrcOutput(QuickEdge *const quick_edge) { - return quick_edge->data.src_output; - } - - static inline int32_t GetConstEdgeSrcOutput(const QuickEdge *const quick_edge) { - return quick_edge->data.src_output; - } - - static inline int32_t &GetEdgeDstInput(QuickEdge *const quick_edge) { - return quick_edge->data.dst_input; - } - - static inline int32_t GetConstEdgeDstInput(const QuickEdge *const quick_edge) { - return quick_edge->data.dst_input; - } - - static inline int32_t &GetEdgeInEdgeIndex(QuickEdge *const quick_edge) { - return quick_edge->data.in_edge_index; - } - - static inline int32_t &GetEdgeOutEdgeIndex(QuickEdge *const quick_edge) { - return quick_edge->data.out_edge_index; - } - - static inline ExecuteGraph *GetGraph(const ListElement *const quick_graph) { - return quick_graph->data; - } - - static inline ComputeGraph *GetComputeGraph(const ListElement *const compute_graph) { - return compute_graph->data; - } - - static inline FastNode &GetNode(QuickNode *const quick_node) { - return quick_node->data; - } - - static inline const FastNode &GetConstNode(const QuickNode *const quick_node) { - return quick_node->data; - } - - template - static inline ListMode &GetMode(ListElement *const list_element) { - return list_element->mode; - } - - template - static inline QuickList *GetOwner(ListElement *const list_element) { - return list_element->owner; - } - - static inline QuickNode *GetListElementAddr(const FastNode *const fast_node) { - const auto offset = reinterpret_cast(&reinterpret_cast(0)->data); - return reinterpret_cast(reinterpret_cast(fast_node) - offset); - } - - static inline QuickEdge *GetListElementAddr(const FastEdge *const edge) { - return reinterpret_cast(reinterpret_cast(edge) - offsetof(QuickEdge, data)); - } -}; - -template -int64_t GetNodeOutputSize(NODE_T *node, std::vector &reverse_dfs_nodes_info) { - int64_t total_size = 0LL; - if ((node == nullptr) || (node->GetOpDescBarePtr() == nullptr)) { - return total_size; - } - - NodeStatus &reverse_dfs_node_info = reverse_dfs_nodes_info[static_cast(node->GetOpDescBarePtr()->GetId())]; - total_size = reverse_dfs_node_info.size; - if (total_size != 0) { - return total_size; - } - for (const auto &out_desc : node->GetOpDescBarePtr()->GetAllOutputsDescPtr()) { - if (out_desc == nullptr) { - continue; - } - int64_t output_size = 0LL; - (void)ge::TensorUtils::CalcTensorMemSize(out_desc->GetShape(), out_desc->GetFormat(), out_desc->GetDataType(), - output_size); - total_size += output_size; - } - if (total_size != 0) { - reverse_dfs_node_info.size = total_size; - } - return total_size; -} - -template -struct NodeCmp { - explicit NodeCmp(std::vector *reverse_dfs_nodes_info) : reverse_dfs_nodes_info_(reverse_dfs_nodes_info) {} - bool operator()(NODE_T *lhs, NODE_T *rhs) const { - const auto lhs_size = GetNodeOutputSize(lhs, *reverse_dfs_nodes_info_); - const auto rhs_size = GetNodeOutputSize(rhs, *reverse_dfs_nodes_info_); - if (lhs_size == rhs_size) { - return strcmp(lhs->GetNamePtr(), rhs->GetNamePtr()) > 0; - } - return lhs_size > rhs_size; - } - std::vector *reverse_dfs_nodes_info_; -}; - -template -struct NodeOutInfo { - NodeOutInfo(NODE_T *node, std::vector *reverse_dfs_nodes_info) - : num_out_data_nodes(node->GetAllOutEdgesSize()), - output_size(GetNodeOutputSize(node, *reverse_dfs_nodes_info)), - node_name(node->GetName()) {} - - bool operator<(const NodeOutInfo &rhs) const { - if (num_out_data_nodes < rhs.num_out_data_nodes) { - return true; - } - if (num_out_data_nodes > rhs.num_out_data_nodes) { - return false; - } - if (output_size < rhs.output_size) { - return true; - } - if (output_size > rhs.output_size) { - return false; - } - return node_name < rhs.node_name; - } - - int64_t num_out_data_nodes; - int64_t output_size; - std::string node_name; -}; - -template -class TopoSortStack { - public: - explicit TopoSortStack(std::vector *reverse_dfs_nodes_info, const bool is_mem_priority = false, - const bool is_dfs = false, const bool is_reverse_dfs = false) - : is_mem_priority_(is_mem_priority), - is_dfs_(is_dfs), - is_reverse_dfs_(is_reverse_dfs), - reverse_dfs_nodes_info_(reverse_dfs_nodes_info) {} - - NODE_T *Pop() { - if (is_mem_priority_ && (!is_reverse_dfs_)) { - const auto &it = mem_priority_stack_.cbegin(); - NODE_T *node = it->second; - (void)mem_priority_stack_.erase(it); - return node; - } - NODE_T *node = normal_stack_.back(); - normal_stack_.pop_back(); - return node; - } - - void Push(NODE_T *node) { - if (is_mem_priority_ && (!is_reverse_dfs_)) { - (void)mem_priority_stack_.emplace(NodeOutInfo(node, reverse_dfs_nodes_info_), node); - return; - } - - if (is_dfs_) { - (void)normal_stack_.insert(normal_stack_.end(), node); - } else { - (void)normal_stack_.insert(normal_stack_.begin(), node); - } - } - - bool Empty() { - if (is_mem_priority_ && (!is_reverse_dfs_)) { - return mem_priority_stack_.empty(); - } - return normal_stack_.empty(); - } - - private: - bool is_mem_priority_; - bool is_dfs_; - bool is_reverse_dfs_; - std::vector *reverse_dfs_nodes_info_; - std::list normal_stack_; - std::map, NODE_T *> mem_priority_stack_; -}; - -} // namespace ge -#endif // FAST_GRAPH_FAST_GRAPH_UTILS_H diff --git a/graph/fast_graph/fast_node.cc b/graph/fast_graph/fast_node.cc deleted file mode 100644 index ae763f5aec253818ad1d665c4f039d6cff635ebc..0000000000000000000000000000000000000000 --- a/graph/fast_graph/fast_node.cc +++ /dev/null @@ -1,876 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/fast_graph/fast_node.h" -#include -#include -#include "common/checker.h" -#include "utils/ge_ir_utils.h" -#include "fast_graph_utils.h" -#include "graph/debug/ge_op_types.h" - -namespace ge { -namespace { -const std::vector kEmpty; -} -FastNode::FastNode() {} - -FastNode::~FastNode() {} - -graphStatus FastNode::Init(const OpDescPtr &op) { - opdesc_ = op; - data_in_num_ = op->GetAllInputsSize(); - data_out_num_ = op->GetOutputsSize(); - node_token_ = reinterpret_cast(this); - - return Reset(); -} - -graphStatus FastNode::Reset() { - if (extend_info_ != nullptr) { - /* The node in cache. it need to init before using it. */ - in_data_edges_.clear(); - in_control_edges_.clear(); - out_data_edges_.clear(); - out_control_edges_.clear(); - out_data_edges_info_.per_edges_num.clear(); - - in_data_edges_count_ = 0UL; - in_control_edge_count_ = 0UL; - out_control_edges_count_ = 0UL; - out_data_edges_info_.total_num = 0UL; - extend_info_->Clear(); - } else { - extend_info_ = ComGraphMakeUnique(); - GE_CHK_BOOL_EXEC(extend_info_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "Failed to allocate memory for extend informations"); - return GRAPH_FAILED, "[Check][Param] Failed to allocate memory for extend informations"); - } - - extend_info_->UpdateInputSymbols(data_in_num_); - extend_info_->UpdateOutputSymbols(data_out_num_); - UpdateDataForIoNumChange(); - return GRAPH_SUCCESS; -} - -void FastNode::UpdateDataForIoNumChange() { - if ((out_data_edges_info_.per_edges_num.size() != data_out_num_) || (data_in_num_ != in_data_edges_.size()) || - (data_out_num_ != out_data_edges_.size())) { - out_data_edges_info_.per_edges_num.resize(data_out_num_, 0UL); - in_data_edges_.resize(data_in_num_, nullptr); - out_data_edges_.resize(data_out_num_); - } -} - -OpDescPtr FastNode::GetOpDescPtr() const { - return opdesc_; -} - -OpDesc *FastNode::GetOpDescBarePtr() const { - return opdesc_.get(); -} - -std::string FastNode::GetType() const { - GE_CHK_BOOL_EXEC(opdesc_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "original OpDesc is nullptr"); - return std::string(), "[Check][Param] original OpDesc is nullptr"); - return opdesc_->GetType(); -} - -std::string FastNode::GetName() const { - GE_CHK_BOOL_EXEC(opdesc_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "original OpDesc is nullptr"); - return std::string(), "[Check][Param] original OpDesc is nullptr"); - return opdesc_->GetName(); -} - -const char *FastNode::GetNamePtr() const { - GE_CHK_BOOL_EXEC(opdesc_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "original OpDesc is nullptr"); - return nullptr, "[Check][Param] original OpDesc is nullptr"); - return opdesc_->GetNamePtr(); -} - -const char *FastNode::GetTypePtr() const { - GE_CHK_BOOL_EXEC(opdesc_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "original OpDesc is nullptr"); - return nullptr, "[Check][Param] original OpDesc is nullptr"); - return opdesc_->GetTypePtr(); -} - -bool FastNode::operator==(const FastNode &r_node) const { - return (IsEqual(name_, r_node.name_, "node.name") && IsEqual(node_token_, r_node.node_token_, "node.token") && - IsEqual(opdesc_, r_node.opdesc_, "node.opdesc_") && IsEqual(opdesc_, r_node.opdesc_, "node.opdesc_") && - IsEqual(self_ptr_, r_node.self_ptr_, "node.self_ptr_") && - IsEqual(data_in_num_, r_node.data_in_num_, "node.data_in_num_") && - IsEqual(data_out_num_, r_node.data_out_num_, "node.data_out_num_") && - IsEqual(in_data_edges_, r_node.in_data_edges_, "node.in_data_edges_") && - IsEqual(out_data_edges_, r_node.out_data_edges_, "node.out_data_edges_") && - IsEqual(in_control_edges_, r_node.in_control_edges_, "node.in_control_edges_") && - IsEqual(out_control_edges_, r_node.out_control_edges_, "node.out_control_edges_") && - IsEqual(in_data_edges_count_, r_node.in_data_edges_count_, "node.in_data_edges_count_") && - IsEqual(in_control_edge_count_, in_control_edge_count_, "node.in_control_edge_count_") && - IsEqual(*extend_info_, *(r_node.extend_info_), "node.extend_info_") && - IsEqual(out_data_edges_info_.total_num, r_node.out_data_edges_info_.total_num, - "node.out_data_edges_info_.total_num")); -} - -graphStatus FastNode::RecordInControlEdge(FastEdge *const edge) { - edge->in_edge_index = in_control_edges_.size(); - in_control_edges_.push_back(edge); - in_control_edge_count_++; - return GRAPH_SUCCESS; -} - -graphStatus FastNode::RecordOutControlEdge(FastEdge *const edge) { - edge->out_edge_index = out_control_edges_.size(); - out_control_edges_.push_back(edge); - out_control_edges_count_++; - return GRAPH_SUCCESS; -} - -graphStatus FastNode::RecordInDataEdge(FastEdge *const edge, int32_t index) { - if (!CheckDataIndexIsValid(index, DirectionType::kDirectionInType)) { - REPORT_INNER_ERR_MSG("E18888", "The index [%d] exceeds the size [%zu] of in edge.", index, data_in_num_); - GELOGE(GRAPH_FAILED, "The index [%d] exceeds the size [%zu] of in edge.", index, data_in_num_); - return GRAPH_FAILED; - } - - if (in_data_edges_[index] != nullptr) { - // if index > 0 ,it is data edge. the input edges must be empty. - REPORT_INNER_ERR_MSG("E18888", "Failed to record edge in node [%s] for multiple input.", GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Record][Edge] Failed to record edge in node [%s] for multiple input.", GetName().c_str()); - return GRAPH_FAILED; - } - - in_data_edges_[index] = edge; - in_data_edges_count_++; - return GRAPH_SUCCESS; -} - -graphStatus FastNode::RecordOutDataEdge(FastEdge *const edge, int32_t index) { - if (!CheckDataIndexIsValid(index, DirectionType::kDirectionOutType)) { - REPORT_INNER_ERR_MSG("E18888", "The index [%d] exceeds the size [%zu] of out edge.", index, data_out_num_); - GELOGE(GRAPH_FAILED, "The index [%d] exceeds the size [%zu] of out edge.", index, data_out_num_); - return GRAPH_FAILED; - } - - out_data_edges_[index].push_back(edge); - out_data_edges_info_.total_num++; - out_data_edges_info_.per_edges_num[index]++; - return GRAPH_SUCCESS; -} - -graphStatus FastNode::RecordEdge(FastEdge *const edge, DirectionType type) { - if (type == DirectionType::kDirectionInType) { - int32_t index = edge->dst_input; - if (index == kControlEdgeIndex) { - return RecordInControlEdge(edge); - } - edge->in_edge_index = 0; - return RecordInDataEdge(edge, index); - } - - int32_t index = edge->src_output; - if (index == kControlEdgeIndex) { - return RecordOutControlEdge(edge); - } - GE_ASSERT_TRUE(static_cast(index) < out_data_edges_.size()); - edge->out_edge_index = out_data_edges_[index].size(); - return RecordOutDataEdge(edge, index); -} - -graphStatus FastNode::EraseInControlEdge(const FastEdge *const edge) { - GE_ASSERT_TRUE(static_cast(edge->in_edge_index) < in_control_edges_.size()); - GE_ASSERT_TRUE(in_control_edges_[edge->in_edge_index] == edge); - in_control_edges_[edge->in_edge_index] = nullptr; - in_control_edge_count_--; - return GRAPH_SUCCESS; -} - -graphStatus FastNode::EraseOutControlEdge(const FastEdge *const edge) { - GE_ASSERT_TRUE(static_cast(edge->out_edge_index) < out_control_edges_.size()); - GE_ASSERT_TRUE(out_control_edges_[edge->out_edge_index] == edge); - out_control_edges_[edge->out_edge_index] = nullptr; - out_control_edges_count_--; - return GRAPH_SUCCESS; -} - -graphStatus FastNode::EraseInDataEdge(const FastEdge *const edge) { - GE_ASSERT_NOTNULL(edge); - const int32_t index = edge->dst_input; - if (!CheckDataIndexIsValid(index, DirectionType::kDirectionInType)) { - REPORT_INNER_ERR_MSG("E18888", "The index [%d] exceeds the size [%zu] of out edge.", index, data_in_num_); - GELOGE(GRAPH_FAILED, "The index [%d] exceeds the size [%zu] of out edge.", index, data_in_num_); - return GRAPH_FAILED; - } - GE_ASSERT_TRUE(in_data_edges_[index] == edge); - in_data_edges_[index] = nullptr; - in_data_edges_count_--; - return GRAPH_SUCCESS; -} - -graphStatus FastNode::EraseOutDataEdge(const FastEdge *const edge, int32_t index) { - if (!CheckDataIndexIsValid(index, DirectionType::kDirectionOutType)) { - REPORT_INNER_ERR_MSG("E18888", "The index [%d] exceeds the size [%zu] of out edge.", index, data_out_num_); - GELOGE(GRAPH_FAILED, "The index [%d] exceeds the size [%zu] of out edge.", index, data_out_num_); - return GRAPH_FAILED; - } - GE_ASSERT_TRUE(static_cast(edge->out_edge_index) < out_data_edges_[index].size()); - GE_ASSERT_TRUE(out_data_edges_[index][edge->out_edge_index] == edge); - out_data_edges_[index][edge->out_edge_index] = nullptr; - out_data_edges_info_.total_num--; - out_data_edges_info_.per_edges_num[index]--; - return GRAPH_SUCCESS; -} - -graphStatus FastNode::EraseEdge(const FastEdge *const edge, DirectionType type) { - if (type == DirectionType::kDirectionOutType) { - int32_t index = edge->src_output; - if (index == kControlEdgeIndex) { - return EraseOutControlEdge(edge); - } - - return EraseOutDataEdge(edge, index); - } - - int32_t index = edge->dst_input; - if (index == kControlEdgeIndex) { - return EraseInControlEdge(edge); - } - - return EraseInDataEdge(edge); -} - -graphStatus FastNode::CheckAllInputParamter(DirectionType type, int32_t io_idx, int32_t cur_array_index, - int32_t replace_array_index) const { - if (io_idx < -1) { - REPORT_INNER_ERR_MSG("E18888", "The idx[%d] exceed the max capacity of in_edges.", io_idx); - GELOGE(GRAPH_FAILED, "[Check][Param] The idx[%d] exceed the max capacity of in_edges.", io_idx); - return GRAPH_FAILED; - } - - size_t io_size = 0UL; - size_t edge_size = 0UL; - - if (io_idx != kControlEdgeIndex) { - if (type == DirectionType::kDirectionInType) { - io_size = data_in_num_; - } else if (type == DirectionType::kDirectionOutType) { - io_size = data_out_num_; - } - - if (io_size <= static_cast(io_idx)) { - REPORT_INNER_ERR_MSG("E18888", "The idx [%d] exceed the max capacity [%zu] of in_edges.", io_idx, io_size); - GELOGE(GRAPH_FAILED, "[Check][Param] The idx [%d] exceed the max capacity [%zu] of in_edges.", io_idx, io_size); - return GRAPH_FAILED; - } - } - - if (io_idx == kControlEdgeIndex) { - if (type == DirectionType::kDirectionInType) { - edge_size = in_control_edges_.size(); - } else if (type == DirectionType::kDirectionOutType) { - edge_size = out_control_edges_.size(); - } - } else { - if (type == DirectionType::kDirectionInType) { - edge_size = 1; - } else if (type == DirectionType::kDirectionOutType) { - edge_size = out_data_edges_[io_idx].size(); - } - } - - if ((edge_size <= static_cast(replace_array_index)) || (edge_size <= static_cast(cur_array_index))) { - REPORT_INNER_ERR_MSG("E18888", - "The replace index [%d] or current index [%d] exceed the max capacity [%zu] of in_edges.", - replace_array_index, cur_array_index, edge_size); - GELOGE(GRAPH_FAILED, - "[Check][Param] The replace index [%d] or current index [%d] exceed the max capacity [%zu] of in_edges.", - replace_array_index, cur_array_index, edge_size); - return GRAPH_FAILED; - } - - return GRAPH_SUCCESS; -} - -graphStatus FastNode::MoveEdge(DirectionType type, int32_t io_idx, int32_t cur_array_index, - int32_t replace_array_index) { - auto ret = CheckAllInputParamter(type, io_idx, cur_array_index, replace_array_index); - if (ret != GRAPH_SUCCESS) { - return ret; - } - - if (type == DirectionType::kDirectionInType) { - FastEdge *edge = nullptr; - if (io_idx == kControlEdgeIndex) { - in_control_edges_[replace_array_index] = in_control_edges_[cur_array_index]; - in_control_edges_[cur_array_index] = nullptr; - edge = in_control_edges_[replace_array_index]; - } else { - // don`t need to do something. - } - - if (edge != nullptr) { - edge->in_edge_index = replace_array_index; - } - } else if (type == DirectionType::kDirectionOutType) { - FastEdge *edge = nullptr; - if (io_idx == kControlEdgeIndex) { - out_control_edges_[replace_array_index] = out_control_edges_[cur_array_index]; - out_control_edges_[cur_array_index] = nullptr; - edge = out_control_edges_[replace_array_index]; - } else { - out_data_edges_[io_idx][replace_array_index] = out_data_edges_[io_idx][cur_array_index]; - out_data_edges_[io_idx][cur_array_index] = nullptr; - edge = out_data_edges_[io_idx][replace_array_index]; - } - if (edge != nullptr) { - edge->out_edge_index = replace_array_index; - } - } - - return GRAPH_SUCCESS; -} - -size_t FastNode::GetAllInEdgeSize() const { - return in_control_edge_count_ + in_data_edges_count_; -} - -const std::vector *> &FastNode::GetAllInDataEdgesRef() const { - return in_data_edges_; -} - -std::vector *> &FastNode::MutableAllInDataEdges() { - return in_data_edges_; -} - -const std::vector *> &FastNode::GetAllInControlEdgesRef() const { - return in_control_edges_; -} - -const std::vector *> &FastNode::GetAllOutControlEdgesRef() const { - return out_control_edges_; -} - -const std::vector *>> &FastNode::GetAllOutDataEdgesRef() const { - return out_data_edges_; -} - -std::vector *> FastNode::GetAllInDataEdges() const { - std::vector tmp; - tmp.reserve(in_data_edges_count_); - - std::for_each(in_data_edges_.begin(), in_data_edges_.end(), [&tmp](FastEdge *edge) { - if (edge != nullptr) { - tmp.push_back(edge); - } - }); - return tmp; -} - -std::vector *> FastNode::GetAllInControlEdges() const { - std::vector tmp; - tmp.reserve(in_control_edge_count_); - - std::for_each(in_control_edges_.begin(), in_control_edges_.end(), [&tmp](FastEdge *edge) { - if (edge != nullptr) { - tmp.push_back(edge); - } - }); - return tmp; -} - -std::vector *> FastNode::GetAllOutControlEdges() const { - std::vector tmp; - tmp.reserve(out_control_edges_count_); - - std::for_each(out_control_edges_.begin(), out_control_edges_.end(), [&tmp](FastEdge *edge) { - if (edge != nullptr) { - tmp.push_back(edge); - } - }); - return tmp; -} - -std::vector *> FastNode::GetAllOutDataEdges() const { - std::vector tmp; - tmp.reserve(out_data_edges_info_.total_num); - - for (size_t i = 0UL; i < out_data_edges_.size(); i++) { - std::for_each(out_data_edges_[i].begin(), out_data_edges_[i].end(), [&tmp](FastEdge *edge) { - if (edge != nullptr) { - tmp.push_back(edge); - } - }); - } - return tmp; -} - -inline bool FastNode::CheckDataIndexIsValid(int32_t index, DirectionType type) const { - if (type == DirectionType::kDirectionOutType) { - return ((index >= 0) && (index < static_cast(data_out_num_))); - } else if (type == DirectionType::kDirectionInType) { - return ((index >= 0) && (index < static_cast(data_in_num_))); - } - return false; -} - -bool FastNode::OutNodesIsEmpty() const { - return (out_data_edges_info_.total_num + out_control_edges_count_ == 0); -} - -size_t FastNode::GetAllOutEdgesSize() const { - return out_control_edges_count_ + out_data_edges_info_.total_num; -} - -size_t FastNode::GetAllOutDataEdgesSize() const { - return out_data_edges_info_.total_num; -} - -size_t FastNode::GetAllOutControlEdgesSize() const { - return out_control_edges_count_; -} - -size_t FastNode::GetAllInDataEdgesSize() const { - return in_data_edges_count_; -} - -size_t FastNode::GetAllInControlEdgesSize() const { - return in_control_edge_count_; -} - -std::vector FastNode::GetAllOutNodes() const { - std::vector tmp; - tmp.reserve(out_control_edges_count_ + out_data_edges_info_.total_num); - for (size_t i = 0UL; i < out_data_edges_.size(); i++) { - std::for_each(out_data_edges_[i].begin(), out_data_edges_[i].end(), [&tmp](FastEdge *edge) { - if (edge != nullptr) { - tmp.push_back(edge->dst); - } - }); - } - - std::for_each(out_control_edges_.begin(), out_control_edges_.end(), [&tmp](FastEdge *edge) { - if (edge != nullptr) { - tmp.push_back(edge->dst); - } - }); - - return tmp; -} - -std::vector FastNode::GetAllInNodes() const { - std::vector tmp; - tmp.reserve(in_control_edge_count_ + in_data_edges_count_); - std::for_each(in_data_edges_.begin(), in_data_edges_.end(), [&tmp](FastEdge *edge) { - if (edge != nullptr) { - tmp.push_back(edge->src); - } - }); - - std::for_each(in_control_edges_.begin(), in_control_edges_.end(), [&tmp](FastEdge *edge) { - if (edge != nullptr) { - tmp.push_back(edge->src); - } - }); - - return tmp; -} - -std::vector FastNode::GetInDataNodes() const { - std::vector in_data_nodes; - in_data_nodes.reserve(in_data_edges_count_); - - auto &ref = GetAllInDataEdgesRef(); - std::for_each(ref.begin(), ref.end(), [&in_data_nodes](FastEdge *edge) { - if (edge != nullptr) { - in_data_nodes.push_back(edge->src); - } - }); - - return in_data_nodes; -} - -std::vector FastNode::GetOutDataNodesByIndex(int32_t index) const { - std::vector out_data_nodes; - - if (!CheckDataIndexIsValid(index, DirectionType::kDirectionOutType)) { - REPORT_INNER_ERR_MSG("E18888", "The index [%d] exceeds the size [%zu] of out edge.", index, data_out_num_); - GELOGE(GRAPH_FAILED, "The index [%d] exceeds the size [%zu] of out edge.", index, data_out_num_); - return out_data_nodes; - } - - out_data_nodes.reserve(out_data_edges_info_.per_edges_num[index]); - - auto &ref = GetOutEdgesRefByIndex(index); - std::for_each(ref.begin(), ref.end(), [&out_data_nodes](FastEdge *edge) { - if (edge != nullptr) { - out_data_nodes.push_back(edge->dst); - } - }); - - return out_data_nodes; -} - -std::vector FastNode::GetOutDataNodes() const { - std::vector out_nodes; - - out_nodes.reserve(out_data_edges_info_.total_num); - for (size_t i = 0UL; i < out_data_edges_.size(); i++) { - std::for_each(out_data_edges_[i].begin(), out_data_edges_[i].end(), [&out_nodes](FastEdge *edge) { - if (edge != nullptr) { - out_nodes.push_back(edge->dst); - } - }); - } - - return out_nodes; -} - -std::vector FastNode::GetOutControlNodes() const { - std::vector out_ctrl_nodes; - - out_ctrl_nodes.reserve(out_control_edges_count_); - for (const auto &edge : out_control_edges_) { - if (edge != nullptr) { - out_ctrl_nodes.push_back(edge->dst); - } - } - return out_ctrl_nodes; -} - -std::vector FastNode::GetInControlNodes() const { - std::vector in_ctrl_nodes; - - in_ctrl_nodes.reserve(in_control_edge_count_); - for (const auto &edge : in_control_edges_) { - if (edge != nullptr) { - in_ctrl_nodes.push_back(edge->src); - } - } - return in_ctrl_nodes; -} - -size_t FastNode::GetNodeToken() const { - return node_token_; -} - -size_t FastNode::GetInEdgesSizeByIndex(int32_t idx) const { - if (idx == kControlEdgeIndex) { - return in_control_edge_count_; - } - - if (!CheckDataIndexIsValid(idx, DirectionType::kDirectionInType)) { - REPORT_INNER_ERR_MSG("E18888", "The index [%d] exceeds the size [%zu] of out edge.", idx, data_in_num_); - GELOGE(GRAPH_FAILED, "The index [%d] exceeds the size [%zu] of out edge.", idx, data_in_num_); - return 0UL; - } - - if (in_data_edges_[idx] != nullptr) { - return 1UL; - } - - return 0UL; -} - -size_t FastNode::GetOutEdgesSizeByIndex(int32_t idx) const { - if (idx == kControlEdgeIndex) { - return out_control_edges_count_; - } - - if (!CheckDataIndexIsValid(idx, DirectionType::kDirectionOutType)) { - REPORT_INNER_ERR_MSG("E18888", "The index [%d] exceeds the size [%zu] of out edge.", idx + 1, data_in_num_); - GELOGE(GRAPH_FAILED, "The index [%d] exceeds the size [%zu] of out edge.", idx + 1, data_in_num_); - return 0UL; - } - - if (out_data_edges_info_.per_edges_num.size() <= static_cast(idx)) { - return 0UL; - } - - return out_data_edges_info_.per_edges_num[idx]; -} - -Edge *FastNode::GetInDataEdgeByIndex(int32_t idx) const { - if (!CheckDataIndexIsValid(idx, DirectionType::kDirectionInType)) { - REPORT_INNER_ERR_MSG("E18888", "The index [%d] exceeds the size [%zu] of in edge.", idx, data_in_num_); - GELOGE(GRAPH_FAILED, "The index [%d] exceeds the size [%zu] of in edge.", idx, data_in_num_); - return nullptr; - } - - return in_data_edges_[idx]; -} - -bool FastNode::IsDirectlyControlledByNode(FastNode const *node) const { - for (const auto in_ctrl_edge : in_control_edges_) { - if ((in_ctrl_edge != nullptr) && (in_ctrl_edge->src != nullptr) && (in_ctrl_edge->src == node)) { - return true; - } - } - return false; -} -std::vector *> FastNode::GetOutEdgesByIndex(int32_t idx) const { - if (idx == kControlEdgeIndex) { - return out_control_edges_; - } - - if (!CheckDataIndexIsValid(idx, DirectionType::kDirectionOutType)) { - REPORT_INNER_ERR_MSG("E18888", "The index [%d] exceeds the size [%zu] of out edge.", idx, data_out_num_); - GELOGE(GRAPH_FAILED, "The index [%d] exceeds the size [%zu] of out edge.", idx, data_out_num_); - return std::vector *>{}; - } - - std::vector tmp; - tmp.reserve(out_data_edges_info_.per_edges_num[idx]); - - for (size_t i = 0UL; i < out_data_edges_[idx].size(); i++) { - auto edge = out_data_edges_[idx][i]; - if (edge != nullptr) { - tmp.push_back(edge); - } - } - - return tmp; -} - -const std::vector *> &FastNode::GetOutEdgesRefByIndex(int32_t idx) const { - if (idx == kControlEdgeIndex) { - return out_control_edges_; - } - - if (!CheckDataIndexIsValid(idx, DirectionType::kDirectionOutType)) { - REPORT_INNER_ERR_MSG("E18888", "The index [%d] exceeds the size [%zu] of out edge.", idx, data_out_num_); - GELOGE(GRAPH_FAILED, "The index [%d] exceeds the size [%zu] of out edge.", idx, data_out_num_); - return kEmpty; - } - - return out_data_edges_[idx]; -} - -graphStatus FastNode::ModifySizeByNodeType(const FastEdge *const fast_edge, size_t &in_edge_size) const { - if ((fast_edge != nullptr) && (fast_edge->src != nullptr)) { - auto type = fast_edge->src->GetType(); - if ((strcmp(type.c_str(), NEXTITERATION) == 0) || (strcmp(type.c_str(), REFNEXTITERATION) == 0)) { - GE_IF_BOOL_EXEC(in_edge_size == 0UL, - GELOGE(GRAPH_FAILED, "[Check][Param] If [in_edge_size = 0], the result will be reversed"); - return GRAPH_FAILED); - in_edge_size--; - } - } - - return GRAPH_SUCCESS; -} - -size_t FastNode::GetInEdgeSize() const { - size_t in_edge_size = GetAllInEdgeSize(); - auto &edges = GetAllInDataEdgesRef(); - for (size_t i = 0UL; i < edges.size(); i++) { - auto edge = edges[i]; - if (edge == nullptr) { - continue; - } - auto ret = ModifySizeByNodeType(edge, in_edge_size); - if (ret != GRAPH_SUCCESS) { - return 0; - } - } - return in_edge_size; -} - -void FastNode::RemoveAllEdge(std::function *)> const &remove_edge_func) { - for (size_t i = 0UL; i < in_data_edges_.size(); ++i) { - auto edge = in_data_edges_[i]; - if (edge != nullptr) { - remove_edge_func(edge); - } - } - - for (size_t i = 0UL; i < in_control_edges_.size(); ++i) { - auto edge = in_control_edges_[i]; - if (edge != nullptr) { - remove_edge_func(edge); - } - } - - for (size_t i = 0UL; i < out_data_edges_.size(); ++i) { - for (size_t j = 0UL; j < out_data_edges_[i].size(); ++j) { - auto edge = out_data_edges_[i][j]; - if (edge != nullptr) { - remove_edge_func(edge); - } - } - } - - for (size_t i = 0UL; i < out_control_edges_.size(); ++i) { - auto edge = out_control_edges_[i]; - if (edge != nullptr) { - remove_edge_func(edge); - } - } - - return; -} - -size_t FastNode::GetDataInNum() const { - return data_in_num_; -} - -size_t FastNode::GetDataOutNum() const { - return data_out_num_; -} - -void FastNode::UpdateDataInNum(size_t new_num) { - data_in_num_ = new_num; - UpdateDataForIoNumChange(); - extend_info_->UpdateInputSymbols(data_in_num_); -} - -void FastNode::UpdateDataOutNum(size_t new_num) { - data_out_num_ = new_num; - UpdateDataForIoNumChange(); - extend_info_->UpdateOutputSymbols(data_out_num_); -} - -void FastNode::SetNodePtr(const std::shared_ptr &node) { - self_ptr_ = node; - node_bare_ptr_ = node.get(); -} - -void FastNode::ClearNodePtr() { - self_ptr_ = nullptr; -} - -void FastNode::ClearNodeBarePtr() { - node_bare_ptr_ = nullptr; -} - -std::shared_ptr FastNode::GetNodePtr() const { - if (self_ptr_ != nullptr) { - return self_ptr_; - } - - if (node_bare_ptr_ != nullptr) { - return node_bare_ptr_->shared_from_this(); - } - - return nullptr; -} - -Node *FastNode::GetNodeBarePtr() const { - return node_bare_ptr_; -} - -void FastNode::UpdateOpDesc(const OpDescPtr &new_opdesc) { - if (new_opdesc == nullptr) { - opdesc_.reset(); - return; - } - opdesc_ = new_opdesc; -} - -ExtendInfo *FastNode::GetExtendInfo() const { - return extend_info_.get(); -} - -void ExtendInfo::Clear() { - execute_graph_ = nullptr; - output_index_.clear(); - input_index_ = kControlEdgeIndex; - host_node_ = false; - input_symbols_.clear(); - output_symbols_.clear(); -} - -void ExtendInfo::SetInputIndex(int32_t idx) { - input_index_ = idx; -} - -int32_t ExtendInfo::GetInputIndex() const { - return input_index_; -} - -void ExtendInfo::AddOneOutputIndex(int32_t idx) { - output_index_.push_back(idx); -} - -std::vector &ExtendInfo::GetOutputIndex() { - return output_index_; -} - -ExecuteGraph *ExtendInfo::GetOwnerGraphBarePtr() const { - return execute_graph_; -} - -graphStatus ExtendInfo::SetOwnerGraph(ExecuteGraph *const graph, const FastNode *const fast_node) { - if ((execute_graph_ != nullptr) && (graph != execute_graph_)) { - auto quick_node = FastGraphUtils::GetListElementAddr(fast_node); - auto owner = quick_node->owner; - auto mode = quick_node->mode; - if ((owner != nullptr) && (mode == ListMode::kFreeMode)) { - owner->erase(quick_node); - } - } - execute_graph_ = graph; - return GRAPH_SUCCESS; -} - -bool ExtendInfo::operator==(const ExtendInfo &r_info) const { - return (IsEqual(execute_graph_, r_info.execute_graph_, "node.execute_graph_") && - IsEqual(input_index_, r_info.input_index_, "node.input_index_") && - IsEqual(output_index_, r_info.output_index_, "node.output_index_")); -} - -bool ExtendInfo::GetHostNode() const { - return host_node_; -} - -void ExtendInfo::SetHostNode(const bool is_host) { - host_node_ = is_host; -} - -void ExtendInfo::UpdateInputSymbols(size_t data_in_num) { - input_symbols_.resize(data_in_num, kInvalidSymbol); -} - -void ExtendInfo::UpdateOutputSymbols(size_t data_out_num) { - output_symbols_.resize(data_out_num, kInvalidSymbol); -} - -graphStatus ExtendInfo::SetInputSymbol(size_t idx, uint64_t symbol) { - if (!IsDataIndexValid(idx, input_symbols_)) { - return GRAPH_FAILED; - } - input_symbols_[idx] = symbol; - return GRAPH_SUCCESS; -} - -graphStatus ExtendInfo::SetOutputSymbol(size_t idx, uint64_t symbol) { - if (!IsDataIndexValid(idx, output_symbols_)) { - return GRAPH_FAILED; - } - output_symbols_[idx] = symbol; - return GRAPH_SUCCESS; -} - -uint64_t ExtendInfo::GetInputSymbol(size_t idx) { - if (!IsDataIndexValid(idx, input_symbols_)) { - return kInvalidSymbol; - } - return input_symbols_[idx]; -} - -uint64_t ExtendInfo::GetOutputSymbol(size_t idx) { - if (!IsDataIndexValid(idx, output_symbols_)) { - return kInvalidSymbol; - } - return output_symbols_[idx]; -} - -bool ExtendInfo::IsDataIndexValid(size_t idx, const std::vector &symbols) const { - GE_ASSERT_TRUE(idx < symbols.size(), "The index [%zu] exceeds the size [%zu] of symbols.", idx, symbols.size()); - return true; -} -} // namespace ge diff --git a/graph/fast_graph/quick_list.h b/graph/fast_graph/quick_list.h deleted file mode 100644 index 43ff05a2ae0fc030c38e888fd264342239b3e4ae..0000000000000000000000000000000000000000 --- a/graph/fast_graph/quick_list.h +++ /dev/null @@ -1,446 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -// This header file provides a C-style list, which is a data structure that manages ListElement object pointers. -// Note that it does not manage object ownership. when the list is released, the memory occupied by elements -// is not released. To ensure the normal running of the program, use the following method when destructing the list: -// 1. Erase each element form the list where the upper-layer application uses the list. -// 2. After the erase command is executed, the upper-layer application releases the memory occupied by the -// corresponding element. - -#ifndef D_INC_GRAPH_QUICK_LIST_H -#define D_INC_GRAPH_QUICK_LIST_H -#include -#include -#include -#include -#include -#include -#include -#include "graph/fast_graph/list_element.h" - -namespace ge { -template -struct QuickIterator { - using Self = QuickIterator; - using Element = ListElement; - using difference_type = ptrdiff_t; - using iterator_category = std::bidirectional_iterator_tag; - using value_type = T; - using pointer = T *; - using reference = T &; - - QuickIterator() noexcept : element_() {} - - explicit QuickIterator(Element *const x) noexcept : element_(x) {} - - pointer operator->() const noexcept { - return &(operator*()); - } - - // Must downcast from _List_element_base to _List_element to get to value. - Element *operator*() const noexcept { - return element_; - } - - Self &operator++() noexcept { - element_ = element_->next; - return *this; - } - - Self &operator--() noexcept { - element_ = element_->prev; - return *this; - } - - Self operator++(int) noexcept { - Self tmp = *this; - element_ = element_->next; - return tmp; - } - - Self operator--(int) noexcept { - Self tmp = *this; - element_ = element_->prev; - return tmp; - } - - friend bool operator!=(const Self &x, const Self &y) noexcept { - return x.element_ != y.element_; - } - - friend bool operator==(const Self &x, const Self &y) noexcept { - return x.element_ == y.element_; - } - - Element *element_; -}; - -template -struct ConstQuickIterator { - using Self = ConstQuickIterator; - using Element = const ListElement; - using iterator_category = std::bidirectional_iterator_tag; - using iterator = QuickIterator; - using difference_type = ptrdiff_t; - using value_type = T; - using pointer = const T *; - using reference = const T &; - - ConstQuickIterator() noexcept : element_() {} - - explicit ConstQuickIterator(Element *const x) noexcept : element_(x) {} - - explicit ConstQuickIterator(const iterator &x) noexcept : element_(x.element_) {} - - Element *operator*() const noexcept { - return element_; - } - - pointer operator->() const noexcept { - return &(operator*()); - } - - Self &operator++() noexcept { - element_ = element_->next; - return *this; - } - - Self operator++(int) noexcept { - Self tmp = *this; - element_ = element_->next; - return tmp; - } - - Self &operator--() noexcept { - element_ = element_->prev; - return *this; - } - - Self operator--(int) noexcept { - Self tmp = *this; - element_ = element_->prev; - return tmp; - } - - friend bool operator==(const Self &x, const Self &y) noexcept { - return x.element_ == y.element_; - } - - friend bool operator!=(const Self &x, const Self &y) noexcept { - return x.element_ != y.element_; - } - - Element *element_; -}; - -// in traversal, it is quickly to list -template -class QuickList { - using Element = ListElement; - - public: - using iterator = QuickIterator; - using const_iterator = ConstQuickIterator; - - /** - * Please note that this function does not release memory except head. - * it follow the principle: who applies for who release. - */ - ~QuickList() { - if (head_ != nullptr) { - clear(); - delete head_; - head_ = nullptr; - tail_ = nullptr; - } - } - - QuickList() { - init(); - } - - QuickList(const QuickList &list) = delete; - QuickList &operator=(const QuickList &list) = delete; - - QuickList(QuickList &&list) { - if (this == &list) { - return; - } - clear(); - iterator iter = list.begin(); - while (iter != list.end()) { - auto element = *iter; - iter = erase(iter); - push_back(element); - } - } - - QuickList &operator=(QuickList &&list) { - if (this == &list) { - return *this; - } - - clear(); - iterator iter = list.begin(); - while (iter != list.end()) { - auto element = *iter; - iter = erase(iter); - push_back(element); - } - return *this; - } - - void push_back(Element *const element, ListMode mode) { - tail_->next = element; - - element->prev = tail_; - element->next = head_; - element->owner = this; - element->mode = mode; - tail_ = element; - head_->prev = element; - - total_size_++; - } - - void insert(iterator pos, Element *element, ListMode mode) { - if (pos == begin()) { - Element *src_node = head_->next; - head_->next = element; - element->next = src_node; - element->prev = head_; - src_node->prev = element; - - if (tail_ == head_) { - tail_ = element; - } - } else if (pos == end()) { - element->next = tail_->next; - element->prev = tail_; - tail_->next = element; - tail_ = element; - head_->prev = tail_; - } else { - Element *cur = pos.element_; - Element *prev = cur->prev; - element->prev = prev; - element->next = cur; - - prev->next = element; - cur->prev = element; - } - - element->owner = this; - element->mode = mode; - total_size_++; - } - - int move(Element *const src_pos_value, Element *const dst_pos_value, bool before_flag = true) { - // action for src - Element *cur = src_pos_value; - Element *next = cur->next; - Element *prev = cur->prev; - prev->next = next; - next->prev = prev; - if (tail_ == cur) { - tail_ = prev; - } - - // action for dst - Element *dst = dst_pos_value; - if (before_flag) { - Element *dst_prev = dst->prev; - dst->prev = src_pos_value; - dst_prev->next = src_pos_value; - src_pos_value->prev = dst_prev; - src_pos_value->next = dst; - } else { - Element *dst_next = dst->next; - dst_next->prev = src_pos_value; - dst->next = src_pos_value; - src_pos_value->prev = dst; - src_pos_value->next = dst_next; - if (tail_ == dst) { - tail_ = src_pos_value; - } - } - - return 0; - } - - void push_front(Element *const x, ListMode mode) { - insert(begin(), x, mode); - } - - iterator erase(iterator &pos) { - auto element = pos.element_; - Element *next = element->next; - (void)erase(element); - return iterator(next); - } - - int erase(Element *const x) { - if (x->owner != this) { - return -1; - } - if (x->prev == nullptr) { - return 0; - } - if (tail_ == x) { - tail_ = x->prev; - } - Element *prev = x->prev; - Element *next = x->next; - - x->prev = nullptr; - x->next = nullptr; - x->owner = nullptr; - x->mode = ListMode::kFreeMode; - - prev->next = next; - next->prev = prev; - - total_size_--; - return 0; - } - - size_t size() const { - return total_size_; - } - - bool empty() const { - return head_->next == head_; - } - - void swap(QuickList &list) { - Element *tmp_head = this->head_; - Element *tmp_tail = this->tail_; - size_t tmp_total_size = this->total_size_; - - for (auto iter = list.begin(); iter != list.end(); ++iter) { - auto item = *iter; - item->owner = this; - } - - for (auto iter = begin(); iter != end(); ++iter) { - auto item = *iter; - item->owner = &list; - } - - this->head_ = list.head_; - this->tail_ = list.tail_; - this->head_->owner = this; - this->total_size_ = list.total_size_; - - list.head_ = tmp_head; - list.tail_ = tmp_tail; - list.total_size_ = tmp_total_size; - list.head_->owner = &list; - } - - void clear() { - iterator iter = begin(); - while (iter != end()) { - iter = erase(iter); - } - } - - iterator begin() { - return iterator(head_->next); - } - - iterator end() { - return iterator(head_); - } - - const_iterator begin() const { - return const_iterator(head_->next); - } - - const_iterator end() const { - return const_iterator(head_); - } - - void sort(const std::function *, ListElement *b)> &comp) { - if (total_size_ <= 1) { - return; - } - - std::list carry; - std::list tmp[64]; - std::list *fill = tmp; - std::list *counter = nullptr; - - do { - auto iter = begin(); - auto element = iter.element_; - erase(iter); - carry.insert(carry.begin(), element); - - for (counter = tmp; counter != fill && !counter->empty(); ++counter) { - counter->merge(carry, comp); - carry.swap(*counter); - } - carry.swap(*counter); - if (counter == fill) { - ++fill; - } - } while (!empty()); - - for (counter = tmp + 1; counter != fill; ++counter) { - counter->merge(*(counter - 1), comp); - } - clear(); - std::for_each((*(fill - 1)).begin(), (*(fill - 1)).end(), - [this](Element *element) { push_back(element, ListMode::kWorkMode); }); - } - - // for T is pointer - std::vector CollectAllItemToVector() const { - std::vector tmp; - tmp.reserve(size()); - for (auto iter = begin(); iter != end(); ++iter) { - auto item = *iter; - tmp.push_back(item->data); - } - - return tmp; - } - - // for T is obj - std::vector CollectAllPtrItemToVector() { - std::vector tmp; - tmp.reserve(size()); - for (auto iter = begin(); iter != end(); ++iter) { - auto item = *iter; - tmp.push_back(&(item->data)); - } - - return tmp; - } - - private: - void init() { - Element *mem = new Element; - head_ = mem; - tail_ = mem; - head_->next = head_; - head_->prev = head_; - head_->owner = this; - } - - private: - Element *head_ = nullptr; - Element *tail_ = nullptr; - size_t total_size_ = 0U; -}; -} // namespace ge - -#endif // D_INC_GRAPH_QUICK_LIST_H diff --git a/graph/hcom/hcom_topo_info.cc b/graph/hcom/hcom_topo_info.cc deleted file mode 100644 index c5f971812f85b728992fe332c85ca46375669ab7..0000000000000000000000000000000000000000 --- a/graph/hcom/hcom_topo_info.cc +++ /dev/null @@ -1,135 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "external/hcom/hcom_topo_info.h" - -#include "graph/debug/ge_log.h" -namespace ge { -Status HcomTopoInfo::SetGroupTopoInfo(const char_t *group, const HcomTopoInfo::TopoInfo &info) { - if (group == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Group key is nullptr,set failed."); - GELOGE(GRAPH_FAILED, "[Check][Param] Group key is nullptr,set failed."); - return GRAPH_FAILED; - } - { - const std::lock_guard lock(mutex_); - rank_info_[group] = info; - } - GELOGI("Add group %s successfully.", group); - return GRAPH_SUCCESS; -} - -Status HcomTopoInfo::GetGroupRankSize(const char_t *group, int64_t &rank_size) { - { - const std::lock_guard lock(mutex_); - const auto &iter_info = rank_info_.find(group); - if (iter_info == rank_info_.end()) { - REPORT_INNER_ERR_MSG("E18888", "Group key [%s] has not been added, get failed.", group); - GELOGE(GRAPH_FAILED, "[Check][Param] group key [%s] has not been added, get failed.", group); - return GRAPH_FAILED; - } - rank_size = iter_info->second.rank_size; - } - return GRAPH_SUCCESS; -} - -Status HcomTopoInfo::SetGroupOrderedStream(const int32_t device_id, const char_t *group, void *stream) { - if (group == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Group is nullptr,set failed."); - GELOGE(GRAPH_FAILED, "[Check][Param] group is nullptr,set failed."); - return GRAPH_FAILED; - } - { - const std::lock_guard lock(mutex_); - device_id_to_group_to_ordered_stream_[device_id][group] = stream; - } - GELOGI("Add device %d group %s stream %p successfully.", device_id, group, stream); - return GRAPH_SUCCESS; -} - -Status HcomTopoInfo::GetGroupOrderedStream(const int32_t device_id, const char_t *group, void *&stream) { - { - const std::lock_guard lock(mutex_); - const auto iter = device_id_to_group_to_ordered_stream_.find(device_id); - if (iter == device_id_to_group_to_ordered_stream_.end()) { - GELOGW("[Check][Param] device[%d] has not been added, get failed.", device_id); - return GRAPH_FAILED; - } - - const auto &iter_inner = iter->second.find(group); - if (iter_inner == iter->second.end()) { - GELOGW("[Check][Param] device[%d] group [%s] has not been added, get failed.", device_id, group); - return GRAPH_FAILED; - } - stream = iter_inner->second; - } - - return GRAPH_SUCCESS; -} - - void HcomTopoInfo::UnsetGroupOrderedStream(const int32_t device_id, const char_t *group) { - const std::lock_guard lock(mutex_); - auto iter = device_id_to_group_to_ordered_stream_.find(device_id); - if (iter != device_id_to_group_to_ordered_stream_.end()) { - (void) iter->second.erase(group); - if (iter->second.empty()) { - (void) device_id_to_group_to_ordered_stream_.erase(iter); - } - } - }; - -HcomTopoInfo::TopoDescs *HcomTopoInfo::GetGroupTopoDesc(const char_t *group) { - const std::lock_guard lock(mutex_); - const auto &iter_info = rank_info_.find(group); - if (iter_info == rank_info_.end()) { - REPORT_INNER_ERR_MSG("E18888", "Group key [%s] has not been added, get failed.", group); - GELOGE(GRAPH_FAILED, "[Check][Param] group key [%s] has not been added, get failed.", group); - return nullptr; - } - return &iter_info->second.topo_level_descs; -} - -Status HcomTopoInfo::GetGroupNotifyHandle(const char_t *group, void *¬ify_handle) { - { - const std::lock_guard lock(mutex_); - const auto &iter_info = rank_info_.find(group); - if (iter_info == rank_info_.end()) { - REPORT_INNER_ERR_MSG("E18888", "Group key [%s] has not been added, get failed.", group); - GELOGE(GRAPH_FAILED, "[Check][Param] group key [%s] has not been added, get failed.", group); - return GRAPH_FAILED; - } - notify_handle = iter_info->second.notify_handle; - } - return GRAPH_SUCCESS; -} - -HcomTopoInfo &HcomTopoInfo::Instance() { - static HcomTopoInfo hcom_topo_info; - return hcom_topo_info; -} - -bool HcomTopoInfo::TryGetGroupTopoInfo(const char_t *group, HcomTopoInfo::TopoInfo &info) { - { - const std::lock_guard lock(mutex_); - const auto &iter_info = rank_info_.find(group); - if (iter_info == rank_info_.end()) { - return false; - } - info = iter_info->second; - } - GELOGI("Get existed info of group %s successfully.", group); - return true; -} - -bool HcomTopoInfo::TopoInfoHasBeenSet(const char_t *group) { - const std::lock_guard lock(mutex_); - return rank_info_.find(group) != rank_info_.end(); -} - -} diff --git a/graph/ir/ir_data_type_symbol_store.cc b/graph/ir/ir_data_type_symbol_store.cc deleted file mode 100644 index a14b62f2c8f6a4c4c405b1191db17b7f17b6eb77..0000000000000000000000000000000000000000 --- a/graph/ir/ir_data_type_symbol_store.cc +++ /dev/null @@ -1,243 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/ir/ir_data_type_symbol_store.h" -#include "common/ge_common/debug/ge_log.h" -#include "common/ge_common/string_util.h" -#include "common/util/mem_utils.h" -#include "common/checker.h" -#include "graph/utils/type_utils.h" -#include "graph/utils/op_desc_utils.h" -namespace ge { -namespace { -graphStatus UpdateOpOuputListDtype(const OpDescPtr &op, const size_t &start, const size_t &end, - const std::vector &dtypes) { - GE_ASSERT((end - start) == dtypes.size(), "Size mismatch when update %s output [%zu, %zu) with %zu dtypes", - op->GetName().c_str(), start, end, dtypes.size()); - for (size_t i = start; i < end; i++) { - auto desc = op->MutableOutputDesc(i); - GE_ASSERT_NOTNULL(desc); - GELOGI("Update op %s output %s:%zu with dtype %s", op->GetType().c_str(), op->GetName().c_str(), i, - TypeUtils::DataTypeToSerialString(dtypes[i - start]).c_str()); - desc->SetDataType(dtypes[i - start]); - } - return GRAPH_SUCCESS; -} - -graphStatus UpdateOpOuputDtype(const OpDescPtr &op, const size_t &start, const size_t &end, const DataType &dtype) { - for (size_t i = start; i < end; i++) { - auto desc = op->MutableOutputDesc(i); - GE_ASSERT_NOTNULL(desc); - GELOGI("Update op %s output %s:%zu with dtype %s", op->GetType().c_str(), op->GetName().c_str(), i, - TypeUtils::DataTypeToSerialString(dtype).c_str()); - desc->SetDataType(dtype); - } - return GRAPH_SUCCESS; -} - -const char *ToString(const IrOutputType &type) { - if (type == kIrOutputRequired) { - return "Required"; - } - if (type == kIrOutputDynamic) { - return "Dynamic"; - } - return "Unknown"; -} - -std::string RemoveQuotes(const std::string &str) { - std::string result = str; - if (!result.empty() && result.front() == '\"') { - result.erase(result.begin()); - } - - if (!result.empty() && result.back() == '\"') { - result.pop_back(); - } - return result; -} -} // namespace - -bool IRDataTypeSymbolStore::IsSupportSymbolicInferDtype() const { - return !named_syms_.empty(); -} -graphStatus IRDataTypeSymbolStore::InferDtype(const OpDescPtr &op) const { - GE_ASSERT_NOTNULL(op); - GELOGD("Start infer output dtype for op %s by syms", op->GetName().c_str()); - std::map cached; // 缓存每个Sym的求值结果,避免重复求值 - for (auto &item : named_syms_) { - auto *sym = item.second; - GE_WARN_ASSERT_GRAPH_SUCCESS(sym->Eval(*op, cached[sym]), "Failed eval sym %s of op %s", sym->DebugString().c_str(), - op->GetType().c_str()); - GELOGD("Succeed eval and checking sym %s", sym->DebugString().c_str()); - } - - std::map> ir_output_2_range; - GE_WARN_ASSERT_GRAPH_SUCCESS(ge::GetIrOutputDescRange(op, ir_output_2_range)); - GE_WARN_ASSERT(ir_output_2_range.size() == op->GetIrOutputs().size(), "Failed get output instance info of %s %s", - op->GetName().c_str(), op->GetType().c_str()); - - GE_WARN_ASSERT(output_syms_.size() == op->GetIrOutputs().size(), "Op %s %s has %zu ir outputs, but %zu output syms", - op->GetName().c_str(), op->GetType().c_str(), op->GetIrOutputs().size(), output_syms_.size()); - // 对全部输出表达式进行求值,并更新到op上 - for (size_t i = 0U; i < output_syms_.size(); i++) { - auto *sym = output_syms_[i]; - GE_ASSERT_NOTNULL(sym); - if (sym->IsLegacy()) { - GELOGW("Trying infer legacy output %s(%s) of %s(%s)", sym->Id().c_str(), sym->DebugString().c_str(), - op->GetName().c_str(), op->GetType().c_str()); - return GRAPH_FAILED; - } - - TypeOrTypes type_or_types; - auto cached_iter = cached.find(sym); - if (cached_iter != cached.end()) { - type_or_types = cached_iter->second; - } else { - GE_WARN_ASSERT_GRAPH_SUCCESS(sym->Eval(*op, type_or_types)); - } - - auto &output_range = ir_output_2_range[i]; - size_t start = output_range.first; - size_t end = output_range.first + output_range.second; - - if (type_or_types.IsListType()) { // ListType表示输出为动态输出,并且每个输出的类型可以不同 - GE_WARN_ASSERT(output_name_and_types_[i].second == kIrOutputDynamic, - "Op %s %s output %s bind to list-type sym %s", op->GetType().c_str(), - ToString(output_name_and_types_[i].second), output_name_and_types_[i].first.c_str(), - sym->Id().c_str()); - GE_WARN_ASSERT_GRAPH_SUCCESS(UpdateOpOuputListDtype(op, start, end, type_or_types.UnsafeGetTypes())); - } else { - GE_WARN_ASSERT_GRAPH_SUCCESS(UpdateOpOuputDtype(op, start, end, type_or_types.UnsafeGetType())); - } - } - - return GRAPH_SUCCESS; -} - -// 创建一个命名Sym,用于符号出现早于IR输入的情况 -SymDtype *IRDataTypeSymbolStore::GetOrCreateSymbol(const std::string &origin_sym_id) { - std::string sym_id = RemoveQuotes(origin_sym_id); - for (auto &sym : syms_) { - if (sym->Id() == sym_id) { - return sym.get(); - } - } - auto sym = MakeShared(sym_id); - GE_ASSERT_NOTNULL(sym, "Failed create symbol %s", sym_id.c_str()); - syms_.emplace_back(sym); - return syms_.back().get(); -} - -SymDtype *IRDataTypeSymbolStore::SetInputSymbol(const std::string &ir_input, IrInputType input_type, - const std::string &sym_id) { - auto *sym = GetOrCreateSymbol(sym_id); - GE_ASSERT_NOTNULL(sym); - sym->BindIrInput(ir_input, input_type, num_ir_inputs++); - return sym; -} - -// 调用DATATYPE声明时,绑定Sym的取值范围 -SymDtype *IRDataTypeSymbolStore::DeclareSymbol(const std::string &sym_id, const TensorType &types) { - auto *sym = GetOrCreateSymbol(sym_id); - GE_ASSERT_NOTNULL(sym); - (void) named_syms_.emplace(sym_id, sym); - sym->BindAllowedDtypes(types); - return sym; -} - -SymDtype *IRDataTypeSymbolStore::DeclareSymbol(const std::string &sym_id, const ListTensorType &types) { - auto *sym = GetOrCreateSymbol(sym_id); - GE_ASSERT_NOTNULL(sym); - (void) named_syms_.emplace(sym_id, sym); - sym->BindAllowedDtypes(types); - return sym; -} - -SymDtype *IRDataTypeSymbolStore::DeclareSymbol(const std::string &sym_id, const Promote &types) { - std::vector syms; - for (auto &id : types.Syms()) { - GE_ASSERT(id != sym_id, "Trying promote symbol %s with itself", sym_id.c_str()); - auto *sym = GetOrCreateSymbol(id); - GE_ASSERT_NOTNULL(sym); - syms.emplace_back(sym); - } - - auto *sym = GetOrCreateSymbol(sym_id); - GE_ASSERT_NOTNULL(sym); - (void) named_syms_.emplace(sym_id, sym); - sym->BindExpression(MakeShared(syms)); - return sym; -} - -// 创建输出的Symbol表达式,用于支持类型推导 -SymDtype *IRDataTypeSymbolStore::SetOutputSymbol(const std::string &ir_output, IrOutputType output_type, - const std::string &sym_id) { - auto *sym = GetOrCreateSymbol(sym_id); - GE_ASSERT_NOTNULL(sym); - output_syms_.emplace_back(sym); - output_name_and_types_.emplace_back(ir_output, output_type); - return sym; -} - -graphStatus IRDataTypeSymbolStore::GetPromoteIrInputList(std::vector> &promote_index_list) { - for (const auto &named_sym : named_syms_) { - GE_ASSERT_NOTNULL(named_sym.second); - if (named_sym.second->Type() == ExpressionType::kPromote) { - auto ir_input_indexes = named_sym.second->GetIrInputIndexes(); - promote_index_list.push_back(ir_input_indexes); - } - } - return ge::GRAPH_SUCCESS; -} - -SymDtype *IRDataTypeSymbolStore::DeclareSymbol(const string &sym_id, const OrderedTensorTypeList &types) { - GELOGI("Bind symbol %s with ordered list-dtypes %s", sym_id.c_str(), types.ToString().c_str()); - auto *sym = GetOrCreateSymbol(sym_id); - GE_ASSERT_NOTNULL(sym); - GE_ASSERT_TRUE(named_syms_.emplace(sym_id, sym).second, "Symbol %s has been declared", sym_id.c_str()); - sym->BindAllowedOrderedDtypes(types); - return sym; -} -bool IRDataTypeSymbolStore::IsSupportOrderedSymbolicInferDtype() const { - if (syms_.empty()) { - return false; - } - - size_t base_size = 0; - bool found = false; - - for (const auto &sym_dtype : syms_) { - if (sym_dtype != nullptr && sym_dtype->IsOrderedList()) { - size_t current_size = sym_dtype->GetOrderedTensorTypeList().GetOrderedDtypes().size(); - if (current_size > 0) { - base_size = current_size; - found = true; - break; - } - } - } - - if (!found) { - return false; - } - - return std::all_of(syms_.begin(), syms_.end(), [&base_size](const std::shared_ptr &sym_dtype) { - if (sym_dtype == nullptr) { - return false; - } - if (!sym_dtype->IsOrderedList()) { - return false; - } - size_t current_size = sym_dtype->GetOrderedTensorTypeList().GetOrderedDtypes().size(); - return current_size == base_size; - }); -} - -} // namespace ge diff --git a/graph/ir/ir_data_type_symbol_store.h b/graph/ir/ir_data_type_symbol_store.h deleted file mode 100644 index d5fd68de0393e28dbd397e5b3fb9ba11c71af76f..0000000000000000000000000000000000000000 --- a/graph/ir/ir_data_type_symbol_store.h +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_GRAPH_IR_DATA_TYPE_SYMBOL_STORE_H_ -#define METADEF_CXX_GRAPH_IR_DATA_TYPE_SYMBOL_STORE_H_ - -#include -#include -#include -#include "graph/types.h" -#include "graph/ge_error_codes.h" -#include "graph/type/tensor_type_impl.h" -#include "graph/op_desc.h" -#include "graph/type/sym_dtype.h" - -namespace ge { -/** - * @brief 建立IR上输入输出和data type symbol的映射 - */ -class IRDataTypeSymbolStore { - public: - IRDataTypeSymbolStore() = default; - ~IRDataTypeSymbolStore() = default; - - // 类型推导入口函数,op的输入dtype和属性视作符号实际值的上下文,推导结果原地更新到op上 - graphStatus InferDtype(const OpDescPtr &op) const; - bool IsSupportSymbolicInferDtype() const; - bool IsSupportOrderedSymbolicInferDtype() const;; - - // 创建一个Sym,用于Sym出现早于定义的情况 - SymDtype *GetOrCreateSymbol(const std::string &origin_sym_id); - - // 设置输入对应的Sym - SymDtype *SetInputSymbol(const std::string &ir_input, IrInputType input_type, const std::string &sym_id); - - // 声明一个Sym的取值范围或计算方式,只有具有声明的Sym支持类型推导和范围校验 - SymDtype *DeclareSymbol(const std::string &sym_id, const TensorType &types); - SymDtype *DeclareSymbol(const std::string &sym_id, const ListTensorType &types); - SymDtype *DeclareSymbol(const std::string &sym_id, const Promote &types); - SymDtype *DeclareSymbol(const std::string &sym_id, const OrderedTensorTypeList &types); - - // 设置输出对应的Sym - SymDtype *SetOutputSymbol(const std::string &ir_output, IrOutputType output_type, const std::string &sym_id); - - graphStatus GetPromoteIrInputList(std::vector> &promote_index_list); - - std::list> GetSymbols() const { - return syms_; - } - - std::map GetNamedSymbols() const { - return named_syms_; - } - - std::vector GetOutSymbols() const { - return output_syms_; - } - private: - size_t num_ir_inputs = 0U; - // 全部的Sym,包含T方式声明的Sym,以及Legacy注册的输入输出无效Sym,在DATATYPE声明前,无法确定Sym是否有效 - std::list> syms_; // For copy constructor - // 通过DATATYPE声明的Sym为命名Sym,支持符号共享、计算和范围校验 - std::map named_syms_; - // 每个输出对应一个Sym - std::vector output_syms_; - std::vector> output_name_and_types_; -}; -} // namespace ge -#endif // METADEF_CXX_GRAPH_IR_DATA_TYPE_SYMBOL_STORE_H_ diff --git a/graph/ir/ir_definitions_recover.cc b/graph/ir/ir_definitions_recover.cc deleted file mode 100644 index 7f5428ef232ccf2f34f67ec629b2fca3f2800fe7..0000000000000000000000000000000000000000 --- a/graph/ir/ir_definitions_recover.cc +++ /dev/null @@ -1,295 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/ir_definitions_recover.h" -#include -#include -#include -#include "graph/operator_factory.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/tensor_utils.h" -#include "graph/utils/node_utils.h" -#include "graph/debug/ge_op_types.h" -#include "common/ge_common/debug/ge_log.h" -#include "common/checker.h" -#include "graph/utils/op_type_utils.h" -#include "graph/utils/recover_ir_utils.h" -using IrDefinition = ge::RecoverIrUtils::IrDefinition; -namespace { -std::string IrAttrNamesToString(const std::vector &attr_names) { - std::ostringstream oss; - for (const auto &attr : attr_names) { - oss << attr << ", "; - } - return oss.str(); -} - -template -std::string IrDefsToString(const IrDef &ir_defs) { - std::ostringstream oss; - for (const auto &pair : ir_defs) { - oss << "[" << pair.first << ", " << pair.second << "], "; - } - return oss.str(); -} - -template -ge::graphStatus AppendIrDefs(const ge::OpDescPtr &op_desc, const IrDef &ir_ins, const IrDef &ir_defs, - const ge::RecoverIrUtils::IrDefAppender appender) { - // 输入个数和顺序校验针对单算子离线流程当前未实现so进om时版本兼容性校验,实现后校验逻辑可去除 - // 当前运行版本中,算子输入个数减少了(相对于导出模型的版本) - if (ir_defs.size() < ir_ins.size()) { - GELOGE(ge::FAILED, - "In the current running version, the number of operator[%s][%s] inputs has been reduced, " - "ir_def.inputs size[%zu] is less than ir_inputs_in_node size[%zu], ir_def.inputs is [%s], " - "ir_inputs_in_node is [%s]", - op_desc->GetName().c_str(), op_desc->GetType().c_str(), ir_defs.size(), ir_ins.size(), - IrDefsToString(ir_defs).c_str(), IrDefsToString(ir_ins).c_str()); - return ge::FAILED; - } - // 算子输入顺序或者输入类型变化了 - for (size_t i = 0U; i < ir_ins.size(); ++i) { - if (ir_ins[i] != ir_defs[i]) { - GELOGE(ge::FAILED, "In the current running version, the order or type of operator[%s][%s] inputs may " - "have changed, ir_def.inputs[%zu] is [%s, %u], ir_inputs_in_node[%zu] is [%s, %u], " - "ir_def.inputs is [%s], ir_inputs_in_node is [%s]", op_desc->GetName().c_str(), - op_desc->GetType().c_str(), i, ir_defs[i].first.c_str(), ir_defs[i].second, - i, ir_ins[i].first.c_str(), ir_ins[i].second, - IrDefsToString(ir_defs).c_str(), IrDefsToString(ir_ins).c_str()); - return ge::FAILED; - } - } - // 当前运行版本中,算子输入个数在后面增加了,需要添加到node中,或者 ir_inputs_in_node 为空,全部拷贝到node中 - for (size_t i = ir_ins.size(); i < ir_defs.size(); ++i) { - appender(op_desc, ir_defs[i].first, ir_defs[i].second); - GELOGD("Append ir input:%s for node[%s(%s)]", - ir_defs[i].first.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str()); - } - return ge::GRAPH_SUCCESS; -} - -} -namespace ge { -graphStatus RecoverIrUtils::RecoverIrAttrNames(const ge::OpDescPtr &desc, IrDefinition &ir_def) { - const auto &ir_attr_names_in_node = desc->GetIrAttrNames(); - // 输入个数和顺序校验针对单算子离线流程当前未实现so进om时版本兼容性校验,实现后校验逻辑可去除 - // 当前运行版本中,算子属性个数减少了(相对于导出模型的版本) - if (ir_def.attr_names.size() < ir_attr_names_in_node.size()) { - GELOGE(ge::FAILED, - "In the current running version, the number of operator[%s][%s] attributes has been reduced, " - "ir_def.attr_names size[%zu] is less than ir_attr_names_in_node size[%zu], ir_def.attr_names is [%s], " - "ir_attr_names_in_node is [%s]", - desc->GetName().c_str(), desc->GetType().c_str(), ir_def.attr_names.size(), ir_attr_names_in_node.size(), - IrAttrNamesToString(ir_def.attr_names).c_str(), IrAttrNamesToString(ir_attr_names_in_node).c_str()); - return ge::FAILED; - } - // 算子属性顺序变化了 - for (size_t i = 0U; i < ir_attr_names_in_node.size(); ++i) { - if (ir_attr_names_in_node[i] != ir_def.attr_names[i]) { - GELOGE(ge::FAILED, - "In the current running version, the order of operator[%s][%s] attributes may have changed," - "ir_def.attr_names[%zu] is [%s], ir_attr_names_in_node[%zu] is [%s], ir_def.attr_names is [%s], " - "ir_attr_names_in_node is [%s]", - desc->GetName().c_str(), desc->GetType().c_str(), i, ir_def.attr_names[i].c_str(), i, - ir_attr_names_in_node[i].c_str(), IrAttrNamesToString(ir_def.attr_names).c_str(), - IrAttrNamesToString(ir_attr_names_in_node).c_str()); - return ge::FAILED; - } - } - // 当前运行版本中,算子属性在后面增加了,需要拷贝到node中,或者 ir_attr_names_in_node 为空,全部拷贝到node中 - for (size_t i = ir_attr_names_in_node.size(); i < ir_def.attr_names.size(); ++i) { - desc->AppendIrAttrName(ir_def.attr_names[i]); - GELOGD("Append ir attr name:%s for desc[%s(%s)]", ir_def.attr_names[i].c_str(), desc->GetName().c_str(), - desc->GetType().c_str()); - } - return ge::SUCCESS; -} - -void RecoverIrUtils::InitIrDefinitionsIfNeed(const string &op_type, IrDefinition &ir_def) { - if (!ir_def.inited) { - auto op = ge::OperatorFactory::CreateOperator("temp", op_type.c_str()); - op.BreakConnect(); - auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(op); - if (op_desc == nullptr) { - GELOGW("Failed to construct operator from type %s", op_type.c_str()); - ir_def.has_ir_definition = false; - ir_def.inited = true; - return; - } - ir_def.attr_names = op_desc->GetIrAttrNames(); - ir_def.inputs = op_desc->GetIrInputs(); - ir_def.outputs = op_desc->GetIrOutputs(); - ir_def.attr_value = ge::AttrUtils::GetAllAttrs(op_desc); - ir_def.has_ir_definition = true; - ir_def.inited = true; - ir_def.op_desc = op_desc; - } -} - -graphStatus RecoverIrUtils::RecoverOpDescIrDefinition(const ge::OpDescPtr &desc, - const string &op_type, - IrDefinition &ir_def) { - if ((desc->GetType() == ge::NETOUTPUT) || ge::OpTypeUtils::IsDataNode(desc->GetType())) { - return ge::GRAPH_SUCCESS; - } - InitIrDefinitionsIfNeed(op_type, ir_def); - - if (!ir_def.has_ir_definition) { - GELOGI("Op type:%s has no registered IR, maybe no need to recover.", op_type.c_str()); - return ge::GRAPH_SUCCESS; - } - - // ir_attr_names - GE_ASSERT_GRAPH_SUCCESS(RecoverIrAttrNames(desc, ir_def), - "%s %s recover ir attr names failed.", - desc->GetNamePtr(), - desc->GetTypePtr()); - // ir input and output - GE_ASSERT_GRAPH_SUCCESS(RecoverIrInputAndOutput(desc, ir_def), - "%s %s recover ir input and output failed.", - desc->GetNamePtr(), - desc->GetTypePtr()); - // sym store - desc->ShareDtypeSymbolsFrom(*ir_def.op_desc); - // attr - const auto node_all_attrs = ge::AttrUtils::GetAllAttrs(desc); - for (const auto &name : ir_def.attr_names) { - if (node_all_attrs.find(name) != node_all_attrs.cend()) { - continue; - } - const std::map::const_iterator iter = ir_def.attr_value.find(name); - if (iter == ir_def.attr_value.cend()) { - GELOGI("node[%s(%s)] missing attr name[%s], and can not find default value for the attr," - " it may be REQUIRED_ATTR.", - desc->GetName().c_str(), op_type.c_str(), name.c_str()); - continue; - } - GELOGD("node[%s(%s)] missing attr name[%s], set default value.", desc->GetName().c_str(), op_type.c_str(), - name.c_str()); - (void) desc->SetAttr(name, iter->second); - } - return ge::GRAPH_SUCCESS; -} - -graphStatus RecoverIrUtils::RecoverIrInputAndOutput(const OpDescPtr &desc, IrDefinition &ir_def) { - // ir_inputs - auto input_appender = [](const ge::OpDescPtr &op_desc, const std::string &ir_name, - const ge::IrInputType ir_type) -> void { op_desc->AppendIrInput(ir_name, ir_type); }; - if (AppendIrDefs(desc, desc->GetIrInputs(), ir_def.inputs, input_appender) != - ge::GRAPH_SUCCESS) { - GELOGE(ge::GRAPH_FAILED, "recover ir inputs failed."); - return ge::GRAPH_FAILED; - } - // ir_outputs - auto output_appender = [](const ge::OpDescPtr &op_desc, const std::string &ir_name, - const ge::IrOutputType ir_type) -> void { op_desc->AppendIrOutput(ir_name, ir_type); }; - if (AppendIrDefs(desc, desc->GetIrOutputs(), ir_def.outputs, output_appender) != - ge::GRAPH_SUCCESS) { - GELOGE(ge::GRAPH_FAILED, "recover ir outputs failed."); - return ge::GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -graphStatus RecoverNodeIrDefinitions(const ge::NodePtr &node, std::string &op_type, IrDefinition &ir_def) { - return RecoverIrUtils::RecoverOpDescIrDefinition(node->GetOpDesc(), op_type, ir_def); -} -graphStatus RecoverIrUtils::RecoverOpDescIrDefinition(const ge::OpDescPtr &desc, const std::string &op_type) { - std::string specified_type = op_type.empty() ? desc->GetType() : op_type; - IrDefinition ir_def; - ir_def.inited = false; - return RecoverIrUtils::RecoverOpDescIrDefinition(desc, specified_type, ir_def); -} - -ge::graphStatus RecoverIrUtils::RecoverIrDefinitions(const ge::ComputeGraphPtr &graph, - const vector &attr_names) { - GELOGD("Start to recover all ir definitions for graph:%s.", graph->GetName().c_str()); - std::map op_type_to_ir_def; - for (const auto &node : graph->GetAllNodes()) { - std::string op_type = ge::NodeUtils::GetNodeType(node); - auto &ir_def = op_type_to_ir_def[op_type]; - if (RecoverNodeIrDefinitions(node, op_type, ir_def) != ge::GRAPH_SUCCESS) { - GELOGE(ge::GRAPH_FAILED, "[Recover][NodeIrDefinitions] failed, node[%s], type[%s]", - node->GetName().c_str(), node->GetType().c_str()); - return ge::GRAPH_FAILED; - } - for (const auto &attr_name : attr_names) { - ge::ComputeGraphPtr graph_ptr = nullptr; - (void) ge::AttrUtils::GetGraph(node->GetOpDesc(), attr_name, graph_ptr); - if (graph_ptr == nullptr) { - continue; - } - if (RecoverIrDefinitions(graph_ptr) != ge::GRAPH_SUCCESS) { - GELOGE(ge::GRAPH_FAILED, "[Recover][IrDefinitions] failed, graph[%s]", graph_ptr->GetName().c_str()); - return ge::GRAPH_FAILED; - } - (void) ge::AttrUtils::SetGraph(node->GetOpDesc(), attr_name, graph_ptr); - GELOGD("Success to recover definitions for graph:%s with node:%s and attr:%s.", - graph->GetName().c_str(), node->GetName().c_str(), attr_name.c_str()); - } - } - GELOGD("Success to recover all ir definitions for graph:%s.", graph->GetName().c_str()); - return ge::GRAPH_SUCCESS; -} - -// TODO if all depended is replace, this 2 function will be deleted -ge::graphStatus RecoverIrDefinitions(const ge::ComputeGraphPtr &graph, const vector &attr_names) { - return RecoverIrUtils::RecoverIrDefinitions(graph, attr_names); -} - -ge::graphStatus RecoverOpDescIrDefinition(const ge::OpDescPtr &desc, const std::string &op_type) { - return RecoverIrUtils::RecoverOpDescIrDefinition(desc, op_type); -} - -bool CheckIrSpec(const ge::OpDescPtr &desc) { - std::string op_type = desc->GetType(); - IrDefinition ir_def; - ir_def.inited = false; - RecoverIrUtils::InitIrDefinitionsIfNeed(op_type, ir_def); - bool ir_input_include_dynamic = false; - bool ir_output_include_dynamic = false; - for (auto &ir_def_input : ir_def.inputs) { - if ((ir_def_input.second == kIrInputDynamic) || (ir_def_input.second == kIrInputOptional)) { - ir_input_include_dynamic = true; - break; - } - } - for (auto &ir_def_output : ir_def.outputs) { - if (ir_def_output.second == kIrOutputDynamic) { - ir_output_include_dynamic = true; - break; - } - } - size_t input_num = desc->GetInputsSize(); - size_t output_num = desc->GetOutputsSize(); - GELOGD("Node:%s check input num is %d and ir input num is %d, output num is %d and ir output num is %d", - desc->GetName().c_str(), input_num, ir_def.inputs.size(), output_num, ir_def.outputs.size()); - if (((input_num != ir_def.inputs.size()) && !ir_input_include_dynamic) || - ((output_num != ir_def.outputs.size()) && !ir_output_include_dynamic)) { - GELOGI("Node:%s inputs/outputs num has changed, compatibility check fail", desc->GetName().c_str()); - return false; - } - // attr - const auto node_all_attrs = ge::AttrUtils::GetAllAttrs(desc); - for (const auto &name : ir_def.attr_names) { - if (node_all_attrs.find(name) != node_all_attrs.cend()) { - continue; - } - const std::map::const_iterator iter = ir_def.attr_value.find(name); - if (iter == ir_def.attr_value.cend()) { - GELOGI("node[%s(%s)] missing attr name[%s], and can not find default value for the attr," - " it may be REQUIRED_ATTR.", - desc->GetName().c_str(), op_type.c_str(), name.c_str()); - return false; - } - } - return true; -} -} // namespace ge diff --git a/graph/ir/ir_meta.cc b/graph/ir/ir_meta.cc deleted file mode 100644 index 79ada8d3abdc0d588dc00a6cde6fa564718b88fe..0000000000000000000000000000000000000000 --- a/graph/ir/ir_meta.cc +++ /dev/null @@ -1,116 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/ir/ir_meta.h" -#include "inc/common/util/trace_manager/trace_manager.h" -#include "graph/utils/ge_ir_utils.h" - -namespace ge { -void IRMetaData::AppendIrAttrName(std::string name) { - ir_attr_names_.emplace_back(std::move(name)); -} -const std::vector &IRMetaData::GetIrAttrNames() const { - return ir_attr_names_; -} -void IRMetaData::AppendIrInput(std::string name, IrInputType input_type) { - ir_inputs_.AppendIrInput(std::move(name), input_type); -} -const std::vector> &IRMetaData::GetIrInputs() const { - return ir_inputs_.ir_inputs; -} -graphStatus IRMetaData::AddRegisterInputName(const std::string &name) { - if (register_unique_name_.insert(name).second) { - register_input_name_.emplace_back(name); - } - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "add", TraceManager::GetOutGraphName(), - op_name_, "register_input_name", "", "", name); - return GRAPH_SUCCESS; -} - -vector IRMetaData::GetRegisterInputName() const { - return register_input_name_; -} - -bool IRMetaData::IsOptionalInput(const std::string &name) const { - return optional_input_names_.find(name) != optional_input_names_.end(); -} - -graphStatus IRMetaData::AddRegisterOutputName(const std::string &name) { - if (register_unique_name_.insert(name).second) { - register_output_name_.emplace_back(name); - } - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "add", TraceManager::GetOutGraphName(), - op_name_, "register_output_name", "", "", name); - return GRAPH_SUCCESS; -} - -vector IRMetaData::GetRegisterOutputName() const { - return register_output_name_; -} - -void IRMetaData::RegisterSubgraphIrName(const std::string &name, const SubgraphType type) { - subgraph_ir_names_to_type_ordered_.emplace_back(name, type); - subgraph_ir_names_to_type_[name] = type; -} - -const std::map &IRMetaData::GetSubgraphIrNames() const { - return subgraph_ir_names_to_type_; -} -const std::vector> &IRMetaData::GetOrderedSubgraphIrNames() const { - return subgraph_ir_names_to_type_ordered_; -} - -SubgraphType IRMetaData::GetSubgraphTypeByIrName(const std::string &name) const { - const auto iter = subgraph_ir_names_to_type_.find(name); - if (iter == subgraph_ir_names_to_type_.end()) { - return kSubgraphTypeEnd; - } - return iter->second; -} - -IRDataTypeSymbolStore &IRMetaData::MutableIRDataTypeSymbolStore() { - return dtype_symbol_store_; -} - -const IRDataTypeSymbolStore &IRMetaData::GetIRDataTypeSymbolStore() const { - return dtype_symbol_store_; -} - -graphStatus IRMetaData::AddRegisterOptionalInputName(const string &name) { - optional_input_names_.insert(name); - return GRAPH_SUCCESS; -} - -bool IRMetaData::operator==(const IRMetaData &other) const { - return IsEqual(this->optional_input_names_, other.optional_input_names_, - "OpDesc.ir_meta.optional_input_names_"); -} - -std::set IRMetaData::GetOptionalInputName() const { - return optional_input_names_; -} - -IrInputType IRMetaData::GetIrInputType(const string &name) const { - for (const auto &name_2_type : ir_inputs_.ir_inputs) { - if (name == name_2_type.first) { - return name_2_type.second; - } - } - return kIrInputTypeEnd; -} - -void IRMetaData::AppendIrOutput(std::string name, IrOutputType output_type) { - ir_outputs_.AppendIrOutput(std::move(name), output_type); -} - -const std::vector> &IRMetaData::GetIrOutputs() const { - return ir_outputs_.ir_outputs; -} -} // namespace ge diff --git a/graph/ir/ir_meta.h b/graph/ir/ir_meta.h deleted file mode 100644 index c0a99efe147406e4386ff46b23d9a4c9b3d90a9e..0000000000000000000000000000000000000000 --- a/graph/ir/ir_meta.h +++ /dev/null @@ -1,148 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_GRAPH_IR_META_H_ -#define METADEF_CXX_GRAPH_IR_META_H_ - -#include -#include -#include "inc/graph/ascend_limits.h" -#include "inc/graph/small_vector.h" -#include "inc/graph/op_desc.h" -#include "graph/ir/ir_data_type_symbol_store.h" - -namespace ge { -/** - * IR信息 - */ -class IRMetaData { - struct IrInputs { - void AppendIrInput(const std::string &name, IrInputType input_type) { - if (ir_input_names.insert(name).second) { - ir_inputs.emplace_back(name, input_type); - } - } - std::unordered_set ir_input_names; - std::vector> ir_inputs; - }; - struct IrOutputs { - void AppendIrOutput(const std::string &name, IrOutputType output_type) { - if (ir_output_names.insert(name).second) { - ir_outputs.emplace_back(name, output_type); - } - } - std::unordered_set ir_output_names; - std::vector> ir_outputs; - }; - public: - explicit IRMetaData(const std::string &op_name) : op_name_(op_name) {}; - IRMetaData() = default; - void SetOpName(const std::string &op_name) { - op_name_ = op_name; - } - void AppendIrInput(std::string name, IrInputType input_type); - const std::vector> &GetIrInputs() const; - IrInputType GetIrInputType(const std::string &name) const; - - void AppendIrOutput(std::string name, IrOutputType output_type); - const std::vector> &GetIrOutputs() const; - - graphStatus AddRegisterInputName(const std::string &name); - std::vector GetRegisterInputName() const; - - graphStatus AddRegisterOptionalInputName(const std::string &name); - std::set GetOptionalInputName() const; - bool IsOptionalInput(const std::string &name) const; - - graphStatus AddRegisterOutputName(const std::string &name); - std::vector GetRegisterOutputName() const; - - void AppendIrAttrName(std::string name); - const std::vector &GetIrAttrNames() const; - - void RegisterSubgraphIrName(const std::string &name, const SubgraphType type); - const std::map &GetSubgraphIrNames() const; - /** - * @brief Get subgraph names in IR order - * @return subgraph ir names in IR order - */ - const std::vector> &GetOrderedSubgraphIrNames() const; - SubgraphType GetSubgraphTypeByIrName(const std::string &name) const; - - IRDataTypeSymbolStore &MutableIRDataTypeSymbolStore(); - const IRDataTypeSymbolStore &GetIRDataTypeSymbolStore() const; - - bool operator==(const IRMetaData &other) const; - - private: - std::string op_name_; - IrInputs ir_inputs_; - IrOutputs ir_outputs_; - std::vector register_input_name_; // todo need to deprecate - std::set optional_input_names_; // todo need to deprecate - std::vector register_output_name_; - std::vector ir_attr_names_; - // subgraph ir names to type, for a `if` operator: - // then_branch: static - // else_branch: static - // or for a `case` op: - // branches: dynamic - std::map subgraph_ir_names_to_type_; - IRDataTypeSymbolStore dtype_symbol_store_; - std::set register_unique_name_; - std::vector> subgraph_ir_names_to_type_ordered_; -}; - -class OpMetadata { - public: - using SmallIntVector = SmallVector(kDefaultMaxInputNum)>; - OpMetadata() = default; - ~OpMetadata() = default; - OpMetadata(std::string name, std::string type) : name_(std::move(name)), type_(std::move(type)), ir_meta_(name) {} - int64_t GetId() const {return id_;} - int64_t GetStreamId() const {return stream_id_;} - const std::vector &GetInputNames() const {return input_names_;} - const std::vector &GetSrcNames() const {return src_names_;} - const std::vector &GetSrcIndexes() const {return src_indexes_;} - const std::vector &GetDstNames() const {return dst_names_;} - const std::vector &GetDstIndexes() const {return dst_indexes_;} - const std::vector &GetInputOffsets() const {return input_offsets_;} - const std::vector &GetOutputOffsets() const {return output_offsets_;} - const std::vector &GetIsInputConsts() const {return is_input_consts_;} - const std::vector &GetSubgraphNames() const {return subgraph_names_;} - void AddSubGraphName(const std::string &name) {subgraph_names_.push_back(name);} - void ClearSubgraphNames() { subgraph_names_.clear(); } - void SetOpName(std::string name) { - name_ = std::move(name); - ir_meta_.SetOpName(name); - } - - private: - friend class OpDescImpl; - std::string name_; - std::string type_; - std::vector inputs_; - bool has_out_attr_{false}; - int64_t id_{0}; - int64_t stream_id_{0}; - std::vector input_names_; - std::vector src_names_; - std::vector src_indexes_; - std::vector dst_names_; - std::vector dst_indexes_; - std::vector input_offsets_; - std::vector output_offsets_; - SmallIntVector workspaces; - SmallIntVector workspace_bytes_list_; - std::vector is_input_consts_; - std::vector subgraph_names_; - IRMetaData ir_meta_; -}; -} // namespace ge -#endif // METADEF_CXX_GRAPH_IR_META_H_ diff --git a/graph/normal_graph/anchor.cc b/graph/normal_graph/anchor.cc deleted file mode 100644 index abcdfc98ff2b17da473748984bf1e0dcc169181a..0000000000000000000000000000000000000000 --- a/graph/normal_graph/anchor.cc +++ /dev/null @@ -1,838 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/anchor.h" -#include -#include -#include -#include "debug/ge_util.h" -#include "common/ge_common/debug/ge_log.h" -#include "graph/node.h" -#include "common/util/trace_manager/trace_manager.h" - -namespace { -constexpr size_t kAnchorTypeMaxLen = 1024U; -bool CanAddPeer(const ge::AnchorPtr &anchor) { - if (anchor->IsTypeIdOf() && (anchor->GetPeerAnchorsSize() != 0U)) { - REPORT_INNER_ERR_MSG("E18888", "anchor is type of InDataAnchor, it's peer is not empty."); - GELOGE(ge::GRAPH_FAILED, "[Check][Param] anchor is type of InDataAnchor, it's peer is not empty."); - return false; - } - return true; -} -bool IsSameType(const ge::Anchor::TYPE &lh, const ge::Anchor::TYPE &rh) { - if (lh == rh) { - return true; - } - - return (strncmp(lh, rh, kAnchorTypeMaxLen) == 0); -} -}; - -namespace ge { -class AnchorImpl { - public: - AnchorImpl(const NodePtr &owner_node, const int32_t idx); - ~AnchorImpl() = default; - size_t GetPeerAnchorsSize() const; - Anchor::Vistor GetPeerAnchors(const std::shared_ptr &anchor_ptr) const; - std::vector GetPeerAnchorsPtr() const; - AnchorPtr GetFirstPeerAnchor() const; - NodePtr GetOwnerNode() const; - Node *GetOwnerNodeBarePtr() const; - int32_t GetIdx() const; - void SetIdx(const int32_t index); - - private: - // All peer anchors connected to current anchor - std::vector> peer_anchors_; - // The owner node of anchor - std::weak_ptr owner_node_; - // The bare ptr of owner node, - Node *const owner_node_ptr_; - // The index of current anchor - int32_t idx_; - - friend class Anchor; - friend class OutControlAnchor; - friend class InControlAnchor; - friend class OutDataAnchor; - friend class InDataAnchor; -}; - -AnchorImpl::AnchorImpl(const NodePtr &owner_node, const int32_t idx) - : owner_node_(owner_node), owner_node_ptr_(owner_node_.lock().get()), idx_(idx) {} - -size_t AnchorImpl::GetPeerAnchorsSize() const { - return peer_anchors_.size(); -} - -Anchor::Vistor AnchorImpl::GetPeerAnchors( - const std::shared_ptr &anchor_ptr) const { - std::vector ret; - ret.resize(peer_anchors_.size()); - (void)std::transform(peer_anchors_.begin(), peer_anchors_.end(), ret.begin(), - [] (const std::weak_ptr& anchor) { - return anchor.lock(); - }); - return Anchor::Vistor(anchor_ptr, ret); -} - -std::vector AnchorImpl::GetPeerAnchorsPtr() const { - std::vector ret; - ret.resize(peer_anchors_.size()); - (void) std::transform(peer_anchors_.begin(), peer_anchors_.end(), ret.begin(), - [](const std::weak_ptr &anchor) { return anchor.lock().get(); }); - return ret; -} - -AnchorPtr AnchorImpl::GetFirstPeerAnchor() const { - if (peer_anchors_.empty()) { - return nullptr; - } else { - return Anchor::DynamicAnchorCast(peer_anchors_.begin()->lock()); - } -} - -NodePtr AnchorImpl::GetOwnerNode() const { - return owner_node_.lock(); -} -Node *AnchorImpl::GetOwnerNodeBarePtr() const { - return owner_node_ptr_; -} - -int32_t AnchorImpl::GetIdx() const { return idx_; } - -void AnchorImpl::SetIdx(const int32_t index) { idx_ = index; } - -Anchor::Anchor(const NodePtr &owner_node, const int32_t idx) - : enable_shared_from_this(), impl_(ComGraphMakeShared(owner_node, idx)) {} - -Anchor::~Anchor() = default; - -bool Anchor::IsTypeOf(const TYPE type) const { - return strncmp(Anchor::TypeOf(), type, kAnchorTypeMaxLen) == 0; -} - -bool Anchor::IsTypeIdOf(const TypeId& type) const { - return GetTypeId() == type; -} - -Anchor::TYPE Anchor::GetSelfType() const { - return TypeOf(); -} - -size_t Anchor::GetPeerAnchorsSize() const { - if (impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "[Check][Param] impl_ of anchor is nullptr."); - return 0UL; - } - return impl_->GetPeerAnchorsSize(); -} - -Anchor::Vistor Anchor::GetPeerAnchors() const { - if (impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "[Check][Param] impl_ of anchor is nullptr."); - const std::vector ret; - return Anchor::Vistor(shared_from_this(), ret); - } - return impl_->GetPeerAnchors(shared_from_this()); -} - -std::vector Anchor::GetPeerAnchorsPtr() const { - if (impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "[Check][Param] impl_ of anchor is nullptr."); - std::vector ret; - return ret; - } - return impl_->GetPeerAnchorsPtr(); -} - -AnchorPtr Anchor::GetFirstPeerAnchor() const { - if (impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "[Check][Param] impl_ of anchor is nullptr."); - return nullptr; - } - return impl_->GetFirstPeerAnchor(); -} - -NodePtr Anchor::GetOwnerNode() const { - if (impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "[Check][Param] impl_ of anchor is nullptr."); - return nullptr; - } - return impl_->GetOwnerNode(); -} - -Node *Anchor::GetOwnerNodeBarePtr() const { - if (impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "[Check][Param] impl_ of anchor is nullptr."); - return nullptr; - } - return impl_->GetOwnerNodeBarePtr(); -} - -void Anchor::UnlinkAll() noexcept { - if (impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "[Check][Param] impl_ of anchor is nullptr."); - return; - } - if (!impl_->peer_anchors_.empty()) { - do { - const auto peer_anchor_ptr = impl_->peer_anchors_.begin()->lock(); - (void)Unlink(peer_anchor_ptr); - } while (!impl_->peer_anchors_.empty()); - } -} - -graphStatus Anchor::Unlink(const AnchorPtr &peer) { - if (peer == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param peer is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] peer anchor is invalid."); - return GRAPH_FAILED; - } - GE_CHK_BOOL_RET_STATUS(impl_ != nullptr, GRAPH_FAILED, "[Check][Param] impl_ of anchor is nullptr"); - const auto it = std::find_if(impl_->peer_anchors_.begin(), impl_->peer_anchors_.end(), - [peer](const std::weak_ptr &an) { - const auto anchor = an.lock(); - return peer->Equal(anchor); - }); - - if (it == impl_->peer_anchors_.end()) { - GELOGW("[Check][Param] Unlink failed , as this anchor is not connected to peer."); - return GRAPH_FAILED; - } - - const auto it_peer = std::find_if(peer->impl_->peer_anchors_.begin(), peer->impl_->peer_anchors_.end(), - [this](const std::weak_ptr &an) { - const auto anchor = an.lock(); - return Equal(anchor); - }); - - GE_CHK_BOOL_RET_STATUS(it_peer != peer->impl_->peer_anchors_.end(), GRAPH_FAILED, - "[Check][Param] peer(%s, %d) is not connected to this anchor(%s, %d)", - peer->GetOwnerNode()->GetName().c_str(), peer->GetIdx(), - this->GetOwnerNode()->GetName().c_str(), this->GetIdx()); - if ((this->GetOwnerNode() != nullptr) && (peer->GetOwnerNode() != nullptr)) { - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "delete", TraceManager::GetOutGraphName(), - this->GetOwnerNode()->GetName(), "output:" << this->GetIdx(), "", "", - peer->GetOwnerNode()->GetName() << ":input:" << peer->GetIdx()); - } - (void)impl_->peer_anchors_.erase(it); - (void)peer->impl_->peer_anchors_.erase(it_peer); - return GRAPH_SUCCESS; -} - -graphStatus Anchor::Insert(const AnchorPtr &old_peer, const AnchorPtr &first_peer, const AnchorPtr &second_peer) { - GE_CHECK_NOTNULL(old_peer); - GE_CHECK_NOTNULL(first_peer); - GE_CHECK_NOTNULL(second_peer); - GE_CHECK_NOTNULL(impl_); - - if (!IsSameType(old_peer->GetSelfType(), first_peer->GetSelfType())) { - REPORT_INNER_ERR_MSG("E18888", "the type of old_peer[%s] and first_peer[%s] is not the same.", - old_peer->GetSelfType(), first_peer->GetSelfType()); - GELOGE(GRAPH_FAILED, "[Check][Param] the type of old_peer[%s] and first_peer[%s] is not the same.", - old_peer->GetSelfType(), first_peer->GetSelfType()); - return GRAPH_FAILED; - } - - if (!IsSameType(second_peer->GetSelfType(), this->GetSelfType())) { - REPORT_INNER_ERR_MSG("E18888", "the type of second_peer[%s] and current anchor[%s] is not the same.", - second_peer->GetSelfType(), this->GetSelfType()); - GELOGE(GRAPH_FAILED, "[Check][Param] the type of second_peer[%s] and current anchor[%s] is not the same.", - second_peer->GetSelfType(), this->GetSelfType()); - return GRAPH_FAILED; - } - - if ((!CanAddPeer(first_peer)) || (!CanAddPeer(second_peer))) { - REPORT_INNER_ERR_MSG("E18888", "first_peer[%s] or second_peer[%s] check failed", first_peer->GetSelfType(), - second_peer->GetSelfType()); - GELOGE(GRAPH_FAILED, "[Check][Param] first_peer[%s] or second_peer[%s] check failed", - first_peer->GetSelfType(), second_peer->GetSelfType()); - return GRAPH_FAILED; - } - - const auto this_it = std::find_if(impl_->peer_anchors_.begin(), impl_->peer_anchors_.end(), - [old_peer](const std::weak_ptr &an) { - const auto anchor = an.lock(); - return old_peer->Equal(anchor); - }); - - GE_CHK_BOOL_RET_STATUS(this_it != impl_->peer_anchors_.end(), GRAPH_FAILED, - "[Check][Param] this anchor(%s, %d) is not connected to old_peer(%s, %d)", - this->GetOwnerNode()->GetName().c_str(), this->GetIdx(), - old_peer->GetOwnerNode()->GetName().c_str(), old_peer->GetIdx()); - - const auto old_it = std::find_if(old_peer->impl_->peer_anchors_.begin(), old_peer->impl_->peer_anchors_.end(), - [this](const std::weak_ptr &an) { - const auto anchor = an.lock(); - return Equal(anchor); - }); - GE_CHK_BOOL_RET_STATUS(old_it != old_peer->impl_->peer_anchors_.end(), GRAPH_FAILED, - "[Check][Param] old_peer(%s, %d) is not connected to this anchor(%s, %d)", - old_peer->GetOwnerNode()->GetName().c_str(), old_peer->GetIdx(), - this->GetOwnerNode()->GetName().c_str(), this->GetIdx()); - *this_it = first_peer; - first_peer->impl_->peer_anchors_.push_back(shared_from_this()); - *old_it = second_peer; - second_peer->impl_->peer_anchors_.push_back(old_peer); - return GRAPH_SUCCESS; -} - -graphStatus Anchor::ReplacePeer(const AnchorPtr &old_peer, const AnchorPtr &new_peer) { - GE_CHECK_NOTNULL(old_peer); - GE_CHECK_NOTNULL(new_peer); - GE_CHECK_NOTNULL(impl_); - if (!IsSameType(old_peer->GetSelfType(), new_peer->GetSelfType())) { - REPORT_INNER_ERR_MSG("E18888", "the type of old_peer[%s] and new_peer[%s] is not the same.", - old_peer->GetSelfType(), new_peer->GetSelfType()); - GELOGE(GRAPH_FAILED, "[Check][Param] the type of old_peer[%s] and new_peer[%s] is not the same.", - old_peer->GetSelfType(), new_peer->GetSelfType()); - return GRAPH_FAILED; - } - - if (!CanAddPeer(new_peer)) { - REPORT_INNER_ERR_MSG("E18888", "new_peer[%s] check failed.", new_peer->GetSelfType()); - GELOGE(GRAPH_FAILED, "[Check][Param] new_peer[%s] check failed.", new_peer->GetSelfType()); - return GRAPH_FAILED; - } - - const auto this_it = std::find_if(this->impl_->peer_anchors_.begin(), this->impl_->peer_anchors_.end(), - [old_peer](const std::weak_ptr &an) { - const auto anchor = an.lock(); - return old_peer->Equal(anchor); - }); - if (this_it == this->impl_->peer_anchors_.end()) { - GELOGE(GRAPH_FAILED, "[Check][Param] this anchor(%s, %d) is not connected to old_peer(%s, %d)", - this->GetOwnerNode()->GetName().c_str(), this->GetIdx(), - old_peer->GetOwnerNode()->GetName().c_str(), old_peer->GetIdx()); - return GRAPH_FAILED; - } - - const auto old_it = std::find_if(old_peer->impl_->peer_anchors_.begin(), old_peer->impl_->peer_anchors_.end(), - [this](const std::weak_ptr &an) { - const auto anchor = an.lock(); - return this->Equal(anchor); - }); - *this_it = new_peer; - (void)old_peer->impl_->peer_anchors_.erase(old_it); - new_peer->impl_->peer_anchors_.push_back(shared_from_this()); - return GRAPH_SUCCESS; -} - -bool Anchor::IsLinkedWith(const AnchorPtr &peer) const { - if (impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "[Check][Param] impl_ of anchor is nullptr."); - return false; - } - const auto it = std::find_if(impl_->peer_anchors_.begin(), impl_->peer_anchors_.end(), - [peer](const std::weak_ptr &an) { - const auto anchor = an.lock(); - if (peer == nullptr) { - GELOGE(GRAPH_FAILED, "[Check][Param] this old peer anchor is nullptr"); - return false; - } - return peer->Equal(anchor); - }); - return (it != impl_->peer_anchors_.end()); -} - -int32_t Anchor::GetIdx() const { - if (impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "[Check][Param] impl_ of anchor is nullptr."); - return 0; - } - return impl_->GetIdx(); -} - -void Anchor::SetIdx(const int32_t index) { - if (impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "[Check][Param] impl_ of anchor is nullptr."); - return; - } - impl_->SetIdx(index); -} - -DataAnchor::DataAnchor(const NodePtr &owner_node, const int32_t idx) : Anchor(owner_node, idx) {} - -bool DataAnchor::IsTypeOf(const TYPE type) const { - if (strncmp(Anchor::TypeOf(), type, kAnchorTypeMaxLen) == 0) { - return true; - } - return Anchor::IsTypeOf(type); -} - -Anchor::TYPE DataAnchor::GetSelfType() const { - return Anchor::TypeOf(); -} - -bool DataAnchor::IsTypeIdOf(const TypeId &type) const { - if (GetTypeId() == type) { - return true; - } - return Anchor::IsTypeIdOf(type); -} - -InDataAnchor::InDataAnchor(const NodePtr &owner_node, const int32_t idx) : DataAnchor(owner_node, idx) {} - -OutDataAnchorPtr InDataAnchor::GetPeerOutAnchor() const { - if ((impl_ == nullptr) || impl_->peer_anchors_.empty()) { - return nullptr; - } else { - return Anchor::DynamicAnchorCast(impl_->peer_anchors_.begin()->lock()); - } -} - -graphStatus InDataAnchor::LinkFrom(const OutDataAnchorPtr &src) { - // InDataAnchor must be only linkfrom once - if ((src == nullptr) || (src->impl_ == nullptr) || - (impl_ == nullptr) || (!impl_->peer_anchors_.empty())) { - REPORT_INNER_ERR_MSG("E18888", "src anchor is invalid or the peerAnchors is not empty."); - GELOGE(GRAPH_FAILED, "[Check][Param] src anchor is invalid or the peerAnchors is not empty."); - return GRAPH_FAILED; - } - impl_->peer_anchors_.push_back(src); - src->impl_->peer_anchors_.push_back(shared_from_this()); - // src->impl_->GetOwnerNode() is null: peer->GetOwnerNode() is null: - if ((src->impl_->GetOwnerNodeBarePtr() == nullptr) || (impl_->GetOwnerNodeBarePtr() == nullptr)) { - GELOGW("[Check][Param] src->impl_->GetOwnerNode() or impl_->GetOwnerNode() is null."); - } else { - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "add", TraceManager::GetOutGraphName(), - src->impl_->GetOwnerNode()->GetName(), "output:" << src->impl_->GetIdx(), "", "", - impl_->GetOwnerNode()->GetName() << ":input:" << impl_->GetIdx()); - } - return GRAPH_SUCCESS; -} - -bool InDataAnchor::Equal(const AnchorPtr anchor) const { - const auto in_data_anchor = Anchor::DynamicAnchorCast(anchor); - if (in_data_anchor != nullptr) { - if ((GetOwnerNodeBarePtr() == in_data_anchor->GetOwnerNodeBarePtr()) && (GetIdx() == in_data_anchor->GetIdx())) { - return true; - } - } - return false; -} - -bool InDataAnchor::IsTypeOf(const TYPE type) const { - if (strncmp(Anchor::TypeOf(), type, kAnchorTypeMaxLen) == 0) { - return true; - } - return DataAnchor::IsTypeOf(type); -} - -Anchor::TYPE InDataAnchor::GetSelfType() const { - return Anchor::TypeOf(); -} - -bool InDataAnchor::IsTypeIdOf(const TypeId &type) const { - if (GetTypeId() == type) { - return true; - } - return DataAnchor::IsTypeIdOf(type); -} - -OutDataAnchor::OutDataAnchor(const NodePtr &owner_node, const int32_t idx) : DataAnchor(owner_node, idx) {} - -OutDataAnchor::Vistor OutDataAnchor::GetPeerInDataAnchors() const { - std::vector ret; - if (impl_ != nullptr) { - ret.reserve(impl_->peer_anchors_.size()); - for (const auto &anchor : impl_->peer_anchors_) { - const auto in_data_anchor = Anchor::DynamicAnchorCast(anchor.lock()); - if (in_data_anchor != nullptr) { - ret.push_back(in_data_anchor); - } - } - } - return OutDataAnchor::Vistor(shared_from_this(), ret); -} - -std::vector OutDataAnchor::GetPeerInDataAnchorsPtr() const { - std::vector ret; - if (impl_ != nullptr) { - ret.reserve(impl_->peer_anchors_.size()); - for (const auto &anchor : impl_->peer_anchors_) { - const auto in_data_anchor = Anchor::DynamicAnchorPtrCast(anchor.lock().get()); - if (in_data_anchor != nullptr) { - ret.push_back(in_data_anchor); - } - } - } - return ret; -} - -uint32_t OutDataAnchor::GetPeerInDataNodesSize() const { - uint32_t out_nums = 0U; - if (impl_ != nullptr) { - for (const auto &anchor : impl_->peer_anchors_) { - const auto in_data_anchor = Anchor::DynamicAnchorCast(anchor.lock()); - if ((in_data_anchor != nullptr) && (in_data_anchor->GetOwnerNodeBarePtr() != nullptr)) { - out_nums++; - } - } - } - return out_nums; -} - -OutDataAnchor::Vistor OutDataAnchor::GetPeerInControlAnchors() const { - std::vector ret; - if (impl_ != nullptr) { - ret.reserve(impl_->peer_anchors_.size()); - for (const auto &anchor : impl_->peer_anchors_) { - const auto in_control_anchor = Anchor::DynamicAnchorCast(anchor.lock()); - if (in_control_anchor != nullptr) { - ret.push_back(in_control_anchor); - } - } - } - return OutDataAnchor::Vistor(shared_from_this(), ret); -} - -graphStatus OutDataAnchor::LinkTo(const InDataAnchorPtr &dest) { - if ((dest == nullptr) || (dest->impl_ == nullptr) || - (!dest->impl_->peer_anchors_.empty())) { - REPORT_INNER_ERR_MSG("E18888", "dest anchor is nullptr or the peerAnchors is not empty."); - GELOGE(GRAPH_FAILED, "[Check][Param] dest anchor is nullptr or the peerAnchors is not empty."); - return GRAPH_FAILED; - } - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "anchor param is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] owner anchor is invalid."); - return GRAPH_FAILED; - } - impl_->peer_anchors_.push_back(dest); - dest->impl_->peer_anchors_.push_back(shared_from_this()); - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "add", TraceManager::GetOutGraphName(), - impl_->GetOwnerNode()->GetName(), "output:" << impl_->GetIdx(), "", "", - dest->impl_->GetOwnerNode()->GetName() << ":input:" << dest->impl_->GetIdx()); - return GRAPH_SUCCESS; -} - -graphStatus OutDataAnchor::LinkTo(const InControlAnchorPtr &dest) { - if ((dest == nullptr) || (dest->impl_ == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param dest is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] dest anchor is invalid."); - return GRAPH_FAILED; - } - if (impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "src anchor is invalid."); - return GRAPH_FAILED; - } - impl_->peer_anchors_.push_back(dest); - dest->impl_->peer_anchors_.push_back(shared_from_this()); - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "add", TraceManager::GetOutGraphName(), - impl_->GetOwnerNode()->GetName(), "output:" << impl_->GetIdx(), "", "", - dest->impl_->GetOwnerNode()->GetName() << ":input:" << dest->impl_->GetIdx()); - return GRAPH_SUCCESS; -} - -graphStatus OutControlAnchor::LinkTo(const InDataAnchorPtr &dest) { - if ((dest == nullptr) || (dest->impl_ == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param dest is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] dest anchor is invalid."); - return GRAPH_FAILED; - } - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "anchor param is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] owner anchor is invalid."); - return GRAPH_FAILED; - } - impl_->peer_anchors_.push_back(dest); - dest->impl_->peer_anchors_.push_back(shared_from_this()); - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "add", TraceManager::GetOutGraphName(), - impl_->GetOwnerNode()->GetName(), "output:" << impl_->GetIdx(), "", "", - dest->impl_->GetOwnerNode()->GetName() << ":input:" << dest->impl_->GetIdx()); - return GRAPH_SUCCESS; -} - -bool OutDataAnchor::Equal(const AnchorPtr anchor) const { - CHECK_FALSE_EXEC(anchor != nullptr, return false); - const auto out_data_anchor = Anchor::DynamicAnchorCast(anchor); - if (out_data_anchor != nullptr) { - if ((GetOwnerNodeBarePtr() == out_data_anchor->GetOwnerNodeBarePtr()) && (GetIdx() == out_data_anchor->GetIdx())) { - return true; - } - } - return false; -} - -bool OutDataAnchor::IsTypeOf(const TYPE type) const { - if (strncmp(Anchor::TypeOf(), type, kAnchorTypeMaxLen) == 0) { - return true; - } - return DataAnchor::IsTypeOf(type); -} - -Anchor::TYPE OutDataAnchor::GetSelfType() const { - return Anchor::TypeOf(); -} - -bool OutDataAnchor::IsTypeIdOf(const TypeId &type) const { - if (GetTypeId() == type) { - return true; - } - return DataAnchor::IsTypeIdOf(type); -} - -ControlAnchor::ControlAnchor(const NodePtr &owner_node) : Anchor(owner_node, -1) {} - -ControlAnchor::ControlAnchor(const NodePtr &owner_node, const int32_t idx) : Anchor(owner_node, idx) {} - -bool ControlAnchor::IsTypeOf(const TYPE type) const { - if (strncmp(Anchor::TypeOf(), type, kAnchorTypeMaxLen) == 0) { - return true; - } - return Anchor::IsTypeOf(type); -} - -Anchor::TYPE ControlAnchor::GetSelfType() const { - return Anchor::TypeOf(); -} - -bool ControlAnchor::IsTypeIdOf(const TypeId &type) const { - if (GetTypeId() == type) { - return true; - } - return Anchor::IsTypeIdOf(type); -} - -InControlAnchor::InControlAnchor(const NodePtr &owner_node) : ControlAnchor(owner_node) {} - -InControlAnchor::InControlAnchor(const NodePtr &owner_node, const int32_t idx) : ControlAnchor(owner_node, idx) {} - -InControlAnchor::Vistor InControlAnchor::GetPeerOutControlAnchors() const { - std::vector ret; - if (impl_ != nullptr) { - ret.reserve(impl_->peer_anchors_.size()); - for (const auto &anchor : impl_->peer_anchors_) { - const auto out_control_anchor = Anchor::DynamicAnchorCast(anchor.lock()); - if (out_control_anchor != nullptr) { - ret.push_back(out_control_anchor); - } - } - } - return InControlAnchor::Vistor(shared_from_this(), ret); -} - -std::vector InControlAnchor::GetPeerOutControlAnchorsPtr() const { - std::vector ret; - if (impl_ != nullptr) { - ret.reserve(impl_->peer_anchors_.size()); - for (const auto &anchor : impl_->peer_anchors_) { - const auto out_control_anchor = Anchor::DynamicAnchorPtrCast(anchor.lock().get()); - if (out_control_anchor != nullptr) { - ret.push_back(out_control_anchor); - } - } - } - return ret; -} - -bool InControlAnchor::IsPeerOutAnchorsEmpty() const { - if (impl_ == nullptr) { - return false; - } - return impl_->peer_anchors_.empty(); -} - -InControlAnchor::Vistor InControlAnchor::GetPeerOutDataAnchors() const { - std::vector ret; - if (impl_ != nullptr) { - for (const auto &anchor : impl_->peer_anchors_) { - const auto out_data_anchor = Anchor::DynamicAnchorCast(anchor.lock()); - if (out_data_anchor != nullptr) { - ret.push_back(out_data_anchor); - } - } - } - return InControlAnchor::Vistor(shared_from_this(), ret); -} - -graphStatus InControlAnchor::LinkFrom(const OutControlAnchorPtr &src) { - if ((src == nullptr) || (src->impl_ == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param src is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] src anchor is invalid."); - return GRAPH_FAILED; - } - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "anchor param is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] owner anchor is invalid."); - return GRAPH_FAILED; - } - impl_->peer_anchors_.push_back(src); - src->impl_->peer_anchors_.push_back(shared_from_this()); - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "add", TraceManager::GetOutGraphName(), - src->impl_->GetOwnerNode()->GetName(), "output:" << src->impl_->GetIdx(), "", "", - impl_->GetOwnerNode()->GetName() << ":input:" << impl_->GetIdx()); - return GRAPH_SUCCESS; -} - -bool InControlAnchor::Equal(const AnchorPtr anchor) const { - CHECK_FALSE_EXEC(anchor != nullptr, REPORT_INNER_ERR_MSG("E18888", "param anchor is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] anchor is invalid."); return false); - const auto in_control_anchor = Anchor::DynamicAnchorCast(anchor); - if (in_control_anchor != nullptr) { - if (GetOwnerNodeBarePtr() == in_control_anchor->GetOwnerNodeBarePtr()) { - return true; - } - } - return false; -} - -bool InControlAnchor::IsTypeOf(const TYPE type) const { - if (strncmp(Anchor::TypeOf(), type, kAnchorTypeMaxLen) == 0) { - return true; - } - return ControlAnchor::IsTypeOf(type); -} - -Anchor::TYPE InControlAnchor::GetSelfType() const { - return Anchor::TypeOf(); -} - -bool InControlAnchor::IsTypeIdOf(const TypeId &type) const { - if (GetTypeId() == type) { - return true; - } - return ControlAnchor::IsTypeIdOf(type); -} - -OutControlAnchor::OutControlAnchor(const NodePtr &owner_node) : ControlAnchor(owner_node) {} - -OutControlAnchor::OutControlAnchor(const NodePtr &owner_node, const int32_t idx) : ControlAnchor(owner_node, idx) {} - -OutControlAnchor::Vistor OutControlAnchor::GetPeerInControlAnchors() const { - std::vector ret; - if (impl_ != nullptr) { - ret.reserve(impl_->peer_anchors_.size()); - for (const auto &anchor : impl_->peer_anchors_) { - const auto in_control_anchor = Anchor::DynamicAnchorCast(anchor.lock()); - if (in_control_anchor != nullptr) { - ret.push_back(in_control_anchor); - } - } - } - return OutControlAnchor::Vistor(shared_from_this(), ret); -} - -std::vector OutControlAnchor::GetPeerInControlAnchorsPtr() const { - std::vector ret; - if (impl_ != nullptr) { - ret.reserve(impl_->peer_anchors_.size()); - for (const auto &anchor : impl_->peer_anchors_) { - const auto in_control_anchor = Anchor::DynamicAnchorPtrCast(anchor.lock().get()); - if (in_control_anchor != nullptr) { - ret.push_back(in_control_anchor); - } - } - } - return ret; -} - -OutControlAnchor::Vistor OutControlAnchor::GetPeerInDataAnchors() const { - std::vector ret; - if (impl_ != nullptr) { - for (const auto &anchor : impl_->peer_anchors_) { - const auto in_data_anchor = Anchor::DynamicAnchorCast(anchor.lock()); - if (in_data_anchor != nullptr) { - ret.push_back(in_data_anchor); - } - } - } - return OutControlAnchor::Vistor(shared_from_this(), ret); -} - -graphStatus OutControlAnchor::LinkTo(const InControlAnchorPtr &dest) { - if ((dest == nullptr) || (dest->impl_ == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param dest is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] dest anchor is invalid."); - return GRAPH_FAILED; - } - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "anchor param is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] owner anchor is invalid."); - return GRAPH_FAILED; - } - impl_->peer_anchors_.push_back(dest); - dest->impl_->peer_anchors_.push_back(shared_from_this()); - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "add", TraceManager::GetOutGraphName(), - impl_->GetOwnerNode()->GetName(), "output:" << impl_->GetIdx(), "", "", - dest->impl_->GetOwnerNode()->GetName() << ":input:" << dest->impl_->GetIdx()); - return GRAPH_SUCCESS; -} - -bool OutControlAnchor::Equal(const AnchorPtr anchor) const { - const auto out_control_anchor = Anchor::DynamicAnchorCast(anchor); - if (out_control_anchor != nullptr) { - if (GetOwnerNodeBarePtr() == out_control_anchor->GetOwnerNodeBarePtr()) { - return true; - } - } - return false; -} - -bool OutControlAnchor::IsTypeOf(const TYPE type) const { - if (strncmp(Anchor::TypeOf(), type, kAnchorTypeMaxLen) == 0) { - return true; - } - return ControlAnchor::IsTypeOf(type); -} - -Anchor::TYPE OutControlAnchor::GetSelfType() const { - return Anchor::TypeOf(); -} - -bool OutControlAnchor::IsTypeIdOf(const TypeId &type) const { - if (GetTypeId() == type) { - return true; - } - return ControlAnchor::IsTypeIdOf(type); -} - -template<> -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId() { - return reinterpret_cast(1); -} - -template<> -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId() { - return reinterpret_cast(2); -} - -template<> -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId() { - return reinterpret_cast(3); -} -template<> -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId() { - return reinterpret_cast(4); -} - -template<> -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId() { - return reinterpret_cast(5); -} - -template<> -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId() { - return reinterpret_cast(6); -}; - -template<> -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId() { - return reinterpret_cast(7); -} -} // namespace ge diff --git a/graph/normal_graph/compute_graph.cc b/graph/normal_graph/compute_graph.cc deleted file mode 100644 index d9418167bb825cb625a95f2f751145fa787cdaa9..0000000000000000000000000000000000000000 --- a/graph/normal_graph/compute_graph.cc +++ /dev/null @@ -1,2760 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/compute_graph.h" - -#include -#include "graph/ge_context.h" -#include "graph/debug/ge_attr_define.h" -#include "debug/ge_log.h" -#include "debug/ge_op_types.h" -#include "debug/ge_util.h" -#include "common/checker.h" -#include "common/ge_common/debug/ge_log.h" -#include "graph/normal_graph/compute_graph_impl.h" -#include "graph/utils/ge_ir_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/node_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/tensor_utils.h" -#include "common/ge_common/string_util.h" -#include "common/ge_common/ge_types.h" -#include "graph/utils/tensor_utils.h" -#include "common/util/mem_utils.h" -#include "graph/utils/op_type_utils.h" -#include - -namespace ge { -namespace { -const size_t OUTPUT_PARAM_SIZE = 2UL; -const std::string kMemoryPriority = "MemoryPriority"; - -TopoSortingMode GetTopoSortingStrategy() { - std::string topo_sorting_mode_str; - if ((ge::GetContext().GetOption(ge::OPTION_TOPOSORTING_MODE, topo_sorting_mode_str) == GRAPH_SUCCESS) && - (!topo_sorting_mode_str.empty())) { - const int32_t base = 10; - auto topo_sorting_mode = static_cast(std::strtol(topo_sorting_mode_str.c_str(), nullptr, base)); - if ((topo_sorting_mode >= TopoSortingMode::kBFS) && (topo_sorting_mode < TopoSortingMode::kInvalid)) { - GELOGI("topo_sorting_mode: %s", GetTopoSortingModeStr(topo_sorting_mode)); - return topo_sorting_mode; - } else { - GELOGW("OPTION_TOPOSORTING_MODE = %s is invalid", topo_sorting_mode_str.c_str()); - } - } - - if (ge::GetContext().GetTrainGraphFlag()) { - GELOGI("train flag is 1, use BFS."); - return TopoSortingMode::kBFS; - } - - GELOGI("train flag is 0, use DFS."); - return TopoSortingMode::kDFS; -} - -struct NodeStatus { - size_t size = 0U; - WalkStatus status; -}; - -void InitNodeStatus(const ConstComputeGraphPtr &compute_graph, std::vector &nodes_info) { - nodes_info.clear(); - nodes_info.resize(compute_graph->GetDirectNodesSize()); - int64_t index = 0; - for (const auto &node : compute_graph->GetDirectNode()) { - nodes_info[index].size = 0; - nodes_info[index].status = WalkStatus::kNotWalked; - node->GetOpDesc()->SetId(index); - index++; - } -} - -int64_t GetNodeOutputRealSize(const NodePtr &node, std::vector &nodes_info) { - int64_t total_size = 0; - if ((node == nullptr) || (node->GetOpDesc() == nullptr)) { - return total_size; - } - NodeStatus &reverse_dfs_node_info = nodes_info[static_cast(node->GetOpDesc()->GetId())]; - total_size = reverse_dfs_node_info.size; - if (total_size != 0) { - return total_size; - } - for (const auto &out_desc : node->GetOpDescBarePtr()->GetAllOutputsDescPtr()) { - if (out_desc == nullptr) { - continue; - } - int64_t output_size = 0; - (void) ge::TensorUtils::CalcTensorMemSize(out_desc->GetShape(), out_desc->GetFormat(), out_desc->GetDataType(), - output_size); - total_size += output_size; - } - if (total_size != 0) { - reverse_dfs_node_info.size = total_size; - } - return total_size; -} - -// 使用节点的输出空间占用大小来排序 -struct NodeCmp { - explicit NodeCmp(std::vector *nodes_info) : nodes_info_(nodes_info) {} - bool operator()(const NodePtr &lhs, const NodePtr &rhs) const { - const auto lhs_size = GetNodeOutputRealSize(lhs, *nodes_info_); - const auto rhs_size = GetNodeOutputRealSize(rhs, *nodes_info_); - if (lhs_size == rhs_size) { - return strcmp(lhs->GetNamePtr(), rhs->GetNamePtr()) > 0; - } - return lhs_size > rhs_size; - } - std::vector *nodes_info_; -}; - -struct NodeOutInfo { - NodeOutInfo(const NodePtr &node, std::vector *nodes_info) - : num_out_data_nodes(node->GetOutDataNodesSize()), - output_size(GetNodeOutputRealSize(node, *nodes_info)), node_name(node->GetName()) {} - - bool operator<(const NodeOutInfo &rhs) const { - if (num_out_data_nodes < rhs.num_out_data_nodes) { - return true; - } - if (num_out_data_nodes > rhs.num_out_data_nodes) { - return false; - } - if (output_size < rhs.output_size) { - return true; - } - if (output_size > rhs.output_size) { - return false; - } - return node_name < rhs.node_name; - } - - int64_t num_out_data_nodes; - int64_t output_size; - std::string node_name; -}; - -bool IsMemoryPriority() { - std::string memory_optimization_policy; - (void) ge::GetContext().GetOption(MEMORY_OPTIMIZATION_POLICY, memory_optimization_policy); - return (memory_optimization_policy == kMemoryPriority); -} - -bool InputIsLongLifeTimeNode(const NodePtr& node, const ConstComputeGraphPtr &graph) { - bool match = false; - for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { - if (in_data_anchor == nullptr) { - continue; - } - const auto &peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - if (peer_out_anchor == nullptr) { - continue; - } - const auto &peer_node = peer_out_anchor->GetOwnerNode(); - if (peer_node == nullptr) { - continue; - } - const bool is_io_data = - ((graph.get() == peer_node->GetOwnerComputeGraphBarePtr()) && - (OpTypeUtils::IsDataNode(peer_node->GetType()) || OpTypeUtils::IsConstPlaceHolderNode(peer_node->GetType()))); - std::string op_type; - if ((!NodeUtils::GetConstOpType(peer_node, op_type)) && (!OpTypeUtils::IsVariableNode(peer_node->GetType())) - && (!is_io_data)) { - return false; - } else { - match = true; - } - GELOGD("Node:%s peer:%s type :%s", node->GetName().c_str(), peer_node->GetName().c_str(), - peer_node->GetType().c_str()); - } - return match; -} - -/// variable const -/// \ / -/// first node -/// | -/// middle node -/// | -/// last node -/// / | -/// node1 node2 -graphStatus GetOutNodeIndex(std::vector &nodes, size_t &index, size_t &out_count, - const ConstComputeGraphPtr &graph) { - if (nodes.empty()) { - return GRAPH_FAILED; - } - - // first node's inputs muse be long life time - if ((nodes.size() == 1U) && (!InputIsLongLifeTimeNode(nodes.front(), graph))) { - return GRAPH_FAILED; - } - - const auto &node = nodes.back(); - auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - // middle node must be single input - if ((nodes.size() != 1U) && (node->GetAllInDataAnchorsSize() != 1U)) { - return GRAPH_FAILED; - } - - int64_t min_index = 0; - Node *delay_node = nullptr; - for (const auto &out_node : node->GetOutAllNodes()) { - out_count++; - GE_CHECK_NOTNULL(out_node); - auto out_node_desc = out_node->GetOpDescBarePtr(); - GE_CHECK_NOTNULL(out_node_desc); - GELOGD("Node:%s id:%ld peer node:%s id:%ld", node->GetName().c_str(), op_desc->GetId(), - out_node_desc->GetName().c_str(), out_node_desc->GetId()); - if ((min_index == 0) || (out_node_desc->GetId() < min_index)) { - min_index = out_node_desc->GetId(); - delay_node = out_node.get(); - } - } - - if (delay_node != nullptr) { - index = static_cast(min_index); - if (index > (static_cast(op_desc->GetId()) + 1U)) { - GELOGD("Node:%s id:%ld delay to:%s id:%zu", node->GetName().c_str(), op_desc->GetId(), - delay_node->GetName().c_str(), index); - } - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -void DelayTopoSort(std::vector &nodes, const ConstComputeGraphPtr &graph) { - // pair.first: this node can be delay or not - // pair.second: delayed nodes to this node - std::vector>> delay_nodes; - delay_nodes.resize(nodes.size()); - - // set init index - for (size_t i = 0U; i < delay_nodes.size(); ++i) { - nodes[i]->GetOpDescBarePtr()->SetId(static_cast(i)); - delay_nodes[i].first = true; - delay_nodes[i].second.emplace_back(nodes[i]); - } - - // move delayed node to fit node - size_t delay_node_count = 0U; - for (size_t i = 0U; i < delay_nodes.size(); ++i) { - size_t delay_to_index = 0U; - size_t out_count = 0U; - if (delay_nodes[i].first - && (GetOutNodeIndex(delay_nodes[i].second, delay_to_index, out_count, graph) == GRAPH_SUCCESS) - && (delay_to_index < delay_nodes.size()) && (delay_to_index > (i + 1U))) { - delay_nodes[delay_to_index].second.insert(delay_nodes[delay_to_index].second.begin(), - delay_nodes[i].second.begin(), - delay_nodes[i].second.end()); - if (out_count > 1U) { - // last node can not be delay - delay_nodes[delay_to_index].first = false; - } - delay_nodes[i].second.clear(); - delay_node_count++; - } - } - if (delay_node_count > 0U) { - nodes.clear(); - for (size_t i = 0U; i < delay_nodes.size(); ++i) { - if (!delay_nodes[i].second.empty()) { - nodes.insert(nodes.end(), delay_nodes[i].second.begin(), delay_nodes[i].second.end()); - } - } - GELOGI("Delay %zu nodes for %s.", delay_node_count, graph->GetName().c_str()); - } -} - -class TopoSortStack { - public: - explicit TopoSortStack(std::vector *nodes_info, const bool is_mem_priority = false, - const bool is_dfs = false, const bool is_reverse_dfs = false) - : is_mem_priority_(is_mem_priority), is_dfs_(is_dfs), is_reverse_dfs_(is_reverse_dfs), - nodes_info_(nodes_info) {} - - NodePtr Pop() { - if (is_mem_priority_ && (!is_reverse_dfs_)) { - const auto &it = mem_priority_stack_.cbegin(); - const NodePtr node = it->second; - (void) mem_priority_stack_.erase(it); - return node; - } - const NodePtr node = normal_stack_.back(); - normal_stack_.pop_back(); - return node; - } - - void Push(const NodePtr &node) { - if (is_mem_priority_ && (!is_reverse_dfs_)) { - (void) mem_priority_stack_.emplace(NodeOutInfo(node, nodes_info_), node); - return; - } - if (is_dfs_) { - (void) normal_stack_.insert(normal_stack_.end(), node); - } else { - (void) normal_stack_.insert(normal_stack_.begin(), node); - } - } - - bool Empty() { - if (is_mem_priority_ && (!is_reverse_dfs_)) { - return mem_priority_stack_.empty(); - } - return normal_stack_.empty(); - } - - private: - bool is_mem_priority_; - bool is_dfs_; - bool is_reverse_dfs_; - std::vector *nodes_info_; - std::vector normal_stack_; - std::map mem_priority_stack_; -}; - -void AssembleFuseFailReason(const std::vector &nodes, const std::unordered_set attr_set, - const std::string attr_key, std::string &reason_not_support) { - std::stringstream failed_reason; - failed_reason << "Fusion is not supported because there are multiple " << attr_key << " "; - for (const auto &ele : attr_set) { - failed_reason << "[" << ele << "]"; - } - failed_reason << " between nodes "; - for (const auto &node : nodes) { - failed_reason << "[" << node->GetName() << "]"; - } - reason_not_support += failed_reason.str(); - GELOGI("%s.", reason_not_support.c_str()); -} - -std::unordered_set GetAttrStringSet(const std::vector &nodes, const std::string attr_key) { - std::unordered_set attr_set; - for (const auto &node : nodes) { - const auto &op_desc = node->GetOpDesc(); - std::string attr_val; - if (AttrUtils::GetStr(op_desc, attr_key, attr_val)) { - attr_set.emplace(attr_val); - } - } - return attr_set; -} - -std::unordered_set GetUserStreamLabels(const std::vector &nodes) { - return GetAttrStringSet(nodes, public_attr::USER_STREAM_LABEL); -} -/** - * 临时方案: 通过开放public属性对用户开放流编排 - * 正式方案: 后续通过属性组方案实现,该方案将属性进行分组,改图的时候由属性组决定属性处理策略。 - * 预计25年H1落地 - * @param ori_nodes - * @param fusion_ops - * @return - */ -graphStatus InheritUserSteamLabelFromOriginNodes(const std::vector &ori_nodes, - const std::vector &fusion_ops) { - const std::unordered_set origin_stream_labels = GetUserStreamLabels(ori_nodes); - GE_WARN_ASSERT(origin_stream_labels.size() < 2U, - "Inherit user stream label failed, because origin nodes have multiple user stream label."); - if (origin_stream_labels.empty()) { - return GRAPH_SUCCESS; - } - for (const auto &op_desc : fusion_ops) { - GE_ASSERT_TRUE(AttrUtils::SetStr(op_desc, public_attr::USER_STREAM_LABEL, *origin_stream_labels.begin())); - } - return GRAPH_SUCCESS; -} - -graphStatus InheritSkFromOriginNodes(const std::vector &ori_nodes, - const std::vector &fusion_ops) { - const std::unordered_set scopes = GetAttrStringSet(ori_nodes, ATTR_NAME_SUPER_KERNEL_SCOPE); - const std::unordered_set kernel_options = GetAttrStringSet(ori_nodes, ATTR_NAME_SUPER_KERNEL_OPTIONS); - if (scopes.empty() && kernel_options.empty()) { - return GRAPH_SUCCESS; - } - for (const auto &op_desc : fusion_ops) { - if (!scopes.empty()) { - GE_ASSERT_TRUE(AttrUtils::SetStr(op_desc, ATTR_NAME_SUPER_KERNEL_SCOPE, *scopes.begin())); - GELOGD("set _super_kernel_scope %s for op %s", scopes.begin()->c_str(), op_desc->GetNamePtr()); - } - if (!kernel_options.empty()) { - GE_ASSERT_TRUE(AttrUtils::SetStr(op_desc, ATTR_NAME_SUPER_KERNEL_OPTIONS, *kernel_options.begin())); - GELOGD("set _super_kernel_options %s for op %s", kernel_options.begin()->c_str(), op_desc->GetNamePtr()); - } - } - return GRAPH_SUCCESS; -} - -graphStatus InheritCoreNumFromOriginNodes(const std::vector &ori_nodes, const std::vector &fusion_ops) { - std::unordered_set origin_ai_core_nums; - std::unordered_set origin_vector_core_nums; - - for (const auto &node : ori_nodes) { - const auto &op_desc = node->GetOpDesc(); - std::string user_ai_core_num_op; - if (AttrUtils::GetStr(op_desc, public_attr::OP_AI_CORE_NUM, user_ai_core_num_op)) { - origin_ai_core_nums.emplace(user_ai_core_num_op); - } - std::string user_vector_core_num_op; - if (AttrUtils::GetStr(op_desc, public_attr::OP_VECTOR_CORE_NUM, user_vector_core_num_op)) { - origin_vector_core_nums.emplace(user_vector_core_num_op); - } - } - - // 如果所有原始节点都没有设置核数,则不需要继承 - if (origin_ai_core_nums.empty() && origin_vector_core_nums.empty()) { - GELOGI("No need to inherit core num, because origin nodes have no core num."); - return GRAPH_SUCCESS; - } - - GELOGI("Begin to set core num for fusion ops."); - for (const auto &op_desc : fusion_ops) { - if (!origin_ai_core_nums.empty()) { - GE_ASSERT_TRUE(AttrUtils::SetStr(op_desc, public_attr::OP_AI_CORE_NUM, *origin_ai_core_nums.begin())); - GELOGD("set ai core num %s for op %s", origin_ai_core_nums.begin()->c_str(), op_desc->GetName().c_str()); - } - if (!origin_vector_core_nums.empty()) { - GE_ASSERT_TRUE(AttrUtils::SetStr(op_desc, public_attr::OP_VECTOR_CORE_NUM, *origin_vector_core_nums.begin())); - GELOGD("set vector core num %s for op %s", origin_vector_core_nums.begin()->c_str(), op_desc->GetName().c_str()); - } - } - return GRAPH_SUCCESS; -} - -} // namespace - -ComputeGraphImpl::ComputeGraphImpl(const std::string &name) - : name_(name), - nodes_(), - input_nodes_(), - sub_graph_(), - is_valid_flag_(false), - need_iteration_(false) { -} - -std::string ComputeGraphImpl::GetName() const { return name_; } - -void ComputeGraphImpl::SetName(const std::string &name) { name_ = name; } - -size_t ComputeGraphImpl::GetAllNodesSize(const ConstComputeGraphPtr &compute_graph) const { - return GetAllNodes(compute_graph).size(); -} - -ComputeGraphImpl::Vistor ComputeGraphImpl::GetAllNodes(const ConstComputeGraphPtr &compute_graph) const { - std::vector> subgraphs; - return AllGraphNodes(subgraphs, compute_graph); -} - -void ComputeGraphImpl::GetAllNodesFromOpdesc(const OpDesc &op_desc, const GraphFilter &graph_filter, - std::deque& candidates, const NodePtr node) const { - const auto &subgraph_names = op_desc.GetSubgraphInstanceNames(); - auto name_iter = subgraph_names.rbegin(); - while (name_iter != subgraph_names.rend()) { - const auto subgraph = GetSubgraph(*name_iter); - if (subgraph != nullptr) { - if ((graph_filter == nullptr) || graph_filter(*node, name_iter->c_str(), subgraph)) { - auto subgraph_nodes = subgraph->GetDirectNode(); - (void) (candidates.insert(candidates.begin(), subgraph_nodes.begin(), subgraph_nodes.end())); - } - } - ++name_iter; - } -} - -ComputeGraphImpl::Vistor ComputeGraphImpl::GetAllNodes(const NodeFilter &node_filter, - const GraphFilter &graph_filter, - const ConstComputeGraphPtr &compute_graph) const { - std::vector all_nodes; - std::deque candidates; - - (void)candidates.insert(candidates.begin(), nodes_.begin(), nodes_.end()); - while (!candidates.empty()) { - NodePtr node = candidates.front(); - candidates.pop_front(); - - if ((node_filter == nullptr) || node_filter(*node)) { - all_nodes.emplace_back(node); - } - - const auto op_desc = node->GetOpDescBarePtr(); - if (op_desc != nullptr) { - GetAllNodesFromOpdesc(*op_desc, graph_filter, candidates, node); - } - } - - return Vistor(compute_graph, all_nodes); -} - -void inline ComputeGraphImpl::GetAllNodesFromOpdesc(std::vector &subgraphs, const OpDesc &op_desc, - std::deque& candidates) const { - const auto &subgraph_names = op_desc.GetSubgraphInstanceNames(); - auto name_iter = subgraph_names.rbegin(); - while (name_iter != subgraph_names.rend()) { - auto subgraph = GetSubgraph(*name_iter); - if (subgraph != nullptr) { - subgraphs.emplace_back(subgraph); - auto subgraph_nodes = subgraph->GetDirectNode(); - (void) candidates.insert(candidates.begin(), subgraph_nodes.begin(), subgraph_nodes.end()); - } - ++name_iter; - } -} - -void inline ComputeGraphImpl::GetAllNodesPtrFromOpdesc(std::vector &subgraphs, const OpDesc &op_desc, - std::deque& candidates) const { - const auto &subgraph_names = op_desc.GetSubgraphInstanceNames(); - auto name_iter = subgraph_names.rbegin(); - while (name_iter != subgraph_names.rend()) { - auto subgraph = GetSubgraph(*name_iter); - if (subgraph != nullptr) { - subgraphs.emplace_back(subgraph); - auto subgraph_nodes = subgraph->GetDirectNodePtr(); - (void)candidates.insert(candidates.begin(), subgraph_nodes.begin(), subgraph_nodes.end()); - } - ++name_iter; - } -} - -std::vector ComputeGraphImpl::AllGraphNodesPtr(std::vector &subgraphs) const { - std::vector all_nodes; - std::deque candidates; - - for (const auto &node : nodes_) { - (void)candidates.emplace_back(node.get()); - } - while (!candidates.empty()) { - Node *node = candidates.front(); - all_nodes.emplace_back(node); - candidates.pop_front(); - - const auto op_desc = node->GetOpDescBarePtr(); - if (op_desc != nullptr) { - GetAllNodesPtrFromOpdesc(subgraphs, *op_desc, candidates); - } - } - - return all_nodes; -} - -ComputeGraphImpl::Vistor ComputeGraphImpl::AllGraphNodes(std::vector &subgraphs, - const ConstComputeGraphPtr &compute_graph) const { - std::vector all_nodes; - std::deque candidates; - - (void)candidates.insert(candidates.begin(), nodes_.begin(), nodes_.end()); - while (!candidates.empty()) { - NodePtr node = candidates.front(); - all_nodes.emplace_back(node); - candidates.pop_front(); - - const auto op_desc = node->GetOpDescBarePtr(); - if (op_desc != nullptr) { - GetAllNodesFromOpdesc(subgraphs, *op_desc, candidates); - } - } - - return Vistor(compute_graph, all_nodes); -} - -ComputeGraphImpl::Vistor ComputeGraphImpl::GetNodes(const bool is_unknown_shape, - const ConstComputeGraphPtr &compute_graph) const { - if (is_unknown_shape) { - return GetDirectNode(compute_graph); - } else { - return GetAllNodes(compute_graph); - } -} - -ComputeGraphImpl::Vistor ComputeGraphImpl::GetNodes(const bool is_unknown_shape, - const NodeFilter &node_filter, - const GraphFilter &graph_filter, - const ConstComputeGraphPtr &compute_graph) const { - return is_unknown_shape ? GetDirectNode(compute_graph) : GetAllNodes(node_filter, graph_filter, compute_graph); -} - -size_t ComputeGraphImpl::GetDirectNodesSize() const { return direct_nodes_size_; } - -ComputeGraphImpl::Vistor ComputeGraphImpl::GetDirectNode(const ConstComputeGraphPtr &compute_graph) const { - return Vistor(compute_graph, nodes_); -} - -std::vector ComputeGraphImpl::GetDirectNodePtr() const { - std::vector direct_nodes; - direct_nodes.reserve(nodes_.size()); - for (const auto &node : nodes_) { - (void)direct_nodes.emplace_back(node.get()); - } - return direct_nodes; -} - -ComputeGraphImpl::Vistor ComputeGraphImpl::GetInputNodes(const ConstComputeGraphPtr &compute_graph) const { - return Vistor(compute_graph, input_nodes_); -} - -ComputeGraphImpl::Vistor ComputeGraphImpl::GetOutputNodes(const ConstComputeGraphPtr &compute_graph) const { - std::vector result; - auto iter = output_nodes_info_.begin(); - while (iter != output_nodes_info_.end()) { - result.push_back(iter->first); - ++iter; - } - return Vistor(compute_graph, result); -} - -NodePtr ComputeGraphImpl::FindNode(const std::string &name) const { - for (const auto &node : nodes_) { - if (node == nullptr) { - continue; - } - if (NodeUtils::IsNameEqual(node, name.c_str())) { - return node; - } - } - return nullptr; -} - -NodePtr ComputeGraphImpl::FindFirstNodeMatchType(const std::string &type) const { - for (const auto &node : nodes_) { - if (node == nullptr) { - continue; - } - if (NodeUtils::IsTypeEqual(node, type.c_str())) { - return node; - } - } - return nullptr; -} - -bool ComputeGraphImpl::GraphAttrsAreEqual(const ComputeGraphImpl &r_graph) const { - // 整改前实现中,只比较了属性名字,没有比较属性内容,暂时维持这个玩法 - return attrs_.GetAllAttrNames() == r_graph.attrs_.GetAllAttrNames(); -} - -/// Since there may be different input nodes -/// chosen by user in the same graph, special judgment is needed -bool ComputeGraphImpl::VectorInputNodePtrIsEqual(const std::vector &left_nodes, - const std::vector &right_nodes) const { - const auto left_nodes_size = left_nodes.size(); - const auto right_nodes_size = right_nodes.size(); - if (left_nodes_size != right_nodes_size) { - REPORT_INNER_ERR_MSG("E18888", - "Check failed with graph input_nodes_: " - "left inputNodes size %zu is different with right inputNodes size %zu .", - left_nodes_size, right_nodes_size); - GELOGE(GRAPH_FAILED, "[Check][Param] failed with graph input_nodes_: " - "left inputNodes size %zu is different with right inputNodes size %zu .", - left_nodes_size, right_nodes_size); - return false; - } - for (size_t j = 0UL; j < left_nodes_size; j++) { - if ((left_nodes.at(j) == nullptr) || (right_nodes.at(j) == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "left_nodes.at(%zu) or right_nodes.at(%zu) is nullptr", j, j); - GELOGE(GRAPH_FAILED, "[Check][Param] left_nodes.at(%zu) or right_nodes.at(%zu) is nullptr", j, j); - return false; - } - const auto &left_input_name = left_nodes.at(j)->GetName(); - const auto &right_input_name = right_nodes.at(j)->GetName(); - if (left_input_name != right_input_name) { - REPORT_INNER_ERR_MSG("E18888", - "Check failed with graph input_nodes_: " - "left inputNode name %s is different with right inputNode name %s at inputNodes index %zu.", - left_input_name.c_str(), right_input_name.c_str(), j); - GELOGE(GRAPH_FAILED, "[Check][Param] failed with graph input_nodes_: " - "left inputNode name %s is different with right inputNode name %s at inputNodes index %zu.", - left_input_name.c_str(), right_input_name.c_str(), j); - return false; - } - } - return true; -} - -bool ComputeGraphImpl::GraphMembersAreEqual(const ComputeGraphImpl &r_graph) const { - return (IsEqual(this->sub_graph_.size(), r_graph.sub_graph_.size(), "graph.subgraphs_.size()") && - IsEqual(this->GetDirectNodesSize(), r_graph.GetDirectNodesSize(), "graph.nodes_.size()") && - VectorInputNodePtrIsEqual(this->input_nodes_, r_graph.input_nodes_) && - IsEqual(this->name_, r_graph.name_, "graph.name_") && - IsEqual(this->is_valid_flag_, r_graph.is_valid_flag_, "graph.is_valid_flag_") && - IsEqual(this->need_iteration_, r_graph.need_iteration_, "graph.need_iteration_") && - IsEqual(this->params_share_map_, r_graph.params_share_map_, "graph.params_share_map_") && - IsEqual(this->out_nodes_map_, r_graph.out_nodes_map_, "graph.out_nodes_map_") && - IsEqual(this->inputs_order_, r_graph.inputs_order_, "graph.inputs_order_") && - IsEqual(this->output_size_, r_graph.output_size_, "graph.output_size_") && - IsEqual(this->input_size_, r_graph.input_size_, "graph.input_size_") && - IsEqual(this->output_nodes_info_, r_graph.output_nodes_info_, "graph.output_nodes_info_")); -} - -bool ComputeGraphImpl::operator==(const ComputeGraphImpl &r_graph) const { - // Firstly: Graph's members equal - if ((!GraphMembersAreEqual(r_graph)) || (!GraphAttrsAreEqual(r_graph))) { - return false; - } - - // Secondly: Node equal means the link relationship between node and node itself equal - for (const auto &left_node : nodes_) { - if (left_node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "left_node is nullptr, graph:%s", this->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] left_node is nullptr"); - return false; - } - const auto &node_name = left_node->GetName(); - // After TopologicalSorting, node order can change, so find node by name - const auto &right_node = r_graph.FindNode(node_name); - GE_IF_BOOL_EXEC(right_node == nullptr, - REPORT_INNER_ERR_MSG("E18888", "left_node:%s not find in r_graph:%s", - node_name.c_str(), r_graph.GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] right_node is NULL!!!"); return false); - if (!((*right_node) == (*left_node))) { - REPORT_INNER_ERR_MSG("E18888", "Compare graph failed, node:%s not equal.", node_name.c_str()); - GELOGE(GRAPH_FAILED, "[Compare][Graph] failed, node:%s not equal.", node_name.c_str()); - return false; - } - } - - // Thirdly: Recursively determine whether the sub graphs are equal - for (size_t i = 0UL; i < this->sub_graph_.size(); i++) { - if (!((*((this->sub_graph_)[i])) == (*((r_graph.sub_graph_)[i])))) { - return false; - } - } - return true; -} - -NodePtr ComputeGraphImpl::AddNodeFront(const NodePtr node) { - if ((node == nullptr) || (node->GetOpDescBarePtr() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "The node ptr or op desc should not be null."); - GELOGE(GRAPH_FAILED, "[Check][Param] The node ptr or op desc should not be null."); - return nullptr; - } - node->SetHostNode(is_valid_flag_); - node->GetOpDescBarePtr()->SetId(static_cast(GetDirectNodesSize())); - if ((GetDirectNodesSize() > 0UL) && ((*(nodes_.begin()))->GetType() == DATA)) { - InsertToNodeList(next(nodes_.begin()), node); - } else { - InsertToNodeList(nodes_.begin(), node); - } - AddInputDataNode(node); - return node; -} - -NodePtr ComputeGraphImpl::AddNodeFront(const OpDescPtr &op, - const ComputeGraphPtr &compute_graph) { - if (op == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The OpDesc ptr should not be null."); - GELOGE(GRAPH_FAILED, "[Check][Param] The OpDesc ptr should not be null."); - return nullptr; - } - op->SetId(static_cast(GetDirectNodesSize())); - const NodePtr node_ptr = std::shared_ptr(new (std::nothrow) Node(op, compute_graph)); - GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "[Create][Node] node_ptr is NULL!!!"); - return nullptr); - GE_IF_BOOL_EXEC(node_ptr->Init() != GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "node %s init failed.", op->GetName().c_str()); - GELOGE(GRAPH_FAILED, "node init fail."); - return nullptr); - return AddNodeFront(node_ptr); -} - -ge::NodePtr ComputeGraphImpl::CreateNodeFromOpDesc(const OpDescPtr &op_desc, - const ComputeGraphPtr &compute_graph, - const int64_t topo_id) { - GE_ASSERT_NOTNULL(op_desc); - op_desc->SetId(topo_id); - const auto node = std::shared_ptr(new (std::nothrow) Node(op_desc, compute_graph)); - GE_ASSERT_NOTNULL(node); - GE_ASSERT_GRAPH_SUCCESS(node->Init()); - node->SetHostNode(is_valid_flag_); - AddInputDataNode(node); - return node; -} - -std::vector ComputeGraphImpl::InsertNodes(const NodePtr &node, - const std::vector &insert_ops, - const ComputeGraphPtr &compute_graph) { - GE_ASSERT_NOTNULL(node); - GE_ASSERT_NOTNULL(node->GetOpDesc()); - const auto topo_id = node->GetOpDesc()->GetId(); - auto iter = std::find(nodes_.begin(), nodes_.end(), node); - GE_ASSERT_TRUE(iter != nodes_.end(), "Cannot find before node: %s in graph: %s", - node->GetName().c_str(), compute_graph->GetName().c_str()); - iter = next(iter); - std::vector insert_nodes; - for (const auto &op_desc : insert_ops) { - auto insert_node = CreateNodeFromOpDesc(op_desc, compute_graph, topo_id); - GE_ASSERT_NOTNULL(insert_node); - InsertToNodeList(iter, insert_node); - insert_nodes.emplace_back(insert_node); - } - return insert_nodes; -} - -NodePtr ComputeGraphImpl::InsertNodeBefore(const NodePtr &node, - const OpDescPtr &insert_op, - const ComputeGraphPtr &compute_graph) { - GE_ASSERT_NOTNULL(node); - GE_ASSERT_NOTNULL(node->GetOpDesc()); - const auto topo_id = node->GetOpDesc()->GetId(); - auto iter = std::find(nodes_.begin(), nodes_.end(), node); - GE_ASSERT_TRUE(iter != nodes_.end(), "Cannot find node: %s in graph: %s", - node->GetNamePtr(), compute_graph->GetName().c_str()); - auto insert_node = CreateNodeFromOpDesc(insert_op, compute_graph, topo_id); - GE_ASSERT_NOTNULL(insert_node); - InsertToNodeList(iter, insert_node); - return insert_node; -} - -NodePtr ComputeGraphImpl::InsertNode(const NodePtr &node, - const OpDescPtr &insert_op, - const ComputeGraphPtr &compute_graph) { - std::vector ops_vec = {insert_op}; - const auto node_vec = InsertNodes(node, ops_vec, compute_graph); - GE_ASSERT_TRUE(!node_vec.empty()); - return node_vec.front(); -} - -bool ComputeGraphImpl::IsSupportFuse(const std::vector &nodes, std::string &reason_not_support) { - const std::unordered_set origin_stream_labels = GetUserStreamLabels(nodes); - if (origin_stream_labels.size() > 1U) { - AssembleFuseFailReason(nodes, origin_stream_labels, public_attr::USER_STREAM_LABEL, reason_not_support); - return false; - } - - const std::unordered_set sk_scopes = GetAttrStringSet(nodes, ATTR_NAME_SUPER_KERNEL_SCOPE); - if (sk_scopes.size() > 1U) { - AssembleFuseFailReason(nodes, sk_scopes, ATTR_NAME_SUPER_KERNEL_SCOPE, reason_not_support); - return false; - } - - const std::unordered_set sk_options = GetAttrStringSet(nodes, ATTR_NAME_SUPER_KERNEL_OPTIONS); - if (sk_options.size() > 1U) { - AssembleFuseFailReason(nodes, sk_options, ATTR_NAME_SUPER_KERNEL_OPTIONS, reason_not_support); - return false; - } - const std::unordered_set aicore_num_options = GetAttrStringSet(nodes, public_attr::OP_AI_CORE_NUM); - if (aicore_num_options.size() > 1U) { - AssembleFuseFailReason(nodes, aicore_num_options, public_attr::OP_AI_CORE_NUM, reason_not_support); - return false; - } - - const std::unordered_set vectorcore_num_options = GetAttrStringSet(nodes, public_attr::OP_VECTOR_CORE_NUM); - if (vectorcore_num_options.size() > 1U) { - AssembleFuseFailReason(nodes, vectorcore_num_options, public_attr::OP_VECTOR_CORE_NUM, reason_not_support); - return false; - } - - return true; -} - -std::vector ComputeGraphImpl::FuseNodeKeepTopo(const std::vector &ori_nodes, - const std::vector &fusion_ops, - const ComputeGraphPtr &compute_graph) { - std::string failed_reason; - GE_WARN_ASSERT(IsSupportFuse(ori_nodes, failed_reason), failed_reason.c_str()); - if (InheritUserSteamLabelFromOriginNodes(ori_nodes, fusion_ops) != GRAPH_SUCCESS) { - GELOGD("Abandoned to fuse nodes because inherit user stream label failed."); - return {}; - } - if (InheritSkFromOriginNodes(ori_nodes, fusion_ops) != GRAPH_SUCCESS) { - GELOGI("Abandoned to fuse nodes because inherit sk attrs failed."); - return {}; - } - if (InheritCoreNumFromOriginNodes(ori_nodes, fusion_ops) != GRAPH_SUCCESS) { - return {}; - } - - auto min_id_node = NodeUtils::GetNodeWithMinimalId(ori_nodes); - GE_ASSERT_NOTNULL(min_id_node); - return InsertNodes(min_id_node, fusion_ops, compute_graph); -} - -NodePtr ComputeGraphImpl::AddNode(const NodePtr node) { - if ((node == nullptr) || (node->GetOpDescBarePtr() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "the node ptr or op desc ptr should not be null."); - GELOGE(GRAPH_FAILED, "[Check][Param] The node ptr or op desc ptr should not be null."); - return nullptr; - } - node->SetHostNode(is_valid_flag_); - node->GetOpDescBarePtr()->SetId(static_cast(GetDirectNodesSize())); - PushBackToNodeList(node); - AddInputDataNode(node); - return node; -} - -NodePtr ComputeGraphImpl::AddNode(const OpDescPtr op, const ComputeGraphPtr &compute_graph) { - if (op == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The OpDesc ptr should not be null."); - GELOGE(GRAPH_FAILED, "[Check][Param] The OpDesc ptr should not be null."); - return nullptr; - } - op->SetId(static_cast(GetDirectNodesSize())); - const NodePtr node_ptr = std::shared_ptr(new (std::nothrow) Node(op, compute_graph)); - GE_IF_BOOL_EXEC(node_ptr == nullptr, - REPORT_INNER_ERR_MSG("E18888", "create node failed."); - GELOGE(GRAPH_FAILED, "[Create][Node] node_ptr is NULL!!!"); return nullptr); - GE_IF_BOOL_EXEC(node_ptr->Init() != GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "node:%s init failed.", op->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Init][Node] %s fail.", op->GetName().c_str()); - return nullptr); - return AddNode(node_ptr); -} - -NodePtr ComputeGraphImpl::AddNode(const OpDescPtr op, const int64_t id, const ComputeGraphPtr &compute_graph) { - if (op == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The OpDesc ptr should not be null."); - GELOGE(GRAPH_FAILED, "[Check][Param] The OpDesc ptr should not be null."); - return nullptr; - } - op->SetId(id); - const NodePtr node = std::shared_ptr(new (std::nothrow) Node(op, compute_graph)); - GE_IF_BOOL_EXEC(node == nullptr, - REPORT_INNER_ERR_MSG("E18888", "create node failed."); - GELOGE(GRAPH_FAILED, "[Create][Node] node_ptr is NULL!!!"); return nullptr); - GE_IF_BOOL_EXEC(node->Init() != GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "node init failed."); - GELOGE(GRAPH_FAILED, "[Init][Node] fail."); - return nullptr); - node->SetHostNode(is_valid_flag_); - PushBackToNodeList(node); - AddInputDataNode(node); - return node; -} - -void ComputeGraphImpl::AddInputDataNode(const NodePtr &node) { - if (OpTypeUtils::IsDataNode(node->GetType())) { - if (std::find(input_nodes_.begin(), input_nodes_.end(), node) == input_nodes_.end()) { - input_nodes_.push_back(node); - } - } -} - -NodePtr ComputeGraphImpl::AddInputNode(const NodePtr node) { - if (node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The node ptr should not be null."); - GELOGE(GRAPH_FAILED, "[Check][Param] The node ptr should not be null."); - return nullptr; - } - if (std::find(input_nodes_.begin(), input_nodes_.end(), node) == input_nodes_.end()) { - input_nodes_.push_back(node); - } - if (std::find(nodes_.begin(), nodes_.end(), node) == nodes_.end()) { - GE_CHK_BOOL_EXEC(AddNode(node) != nullptr, return nullptr, "[Add][Node] failed"); - } - return node; -} - -NodePtr ComputeGraphImpl::AddOutputNode(const NodePtr node) { - return AddOutputNodeByIndex(node, 0); -} - -NodePtr ComputeGraphImpl::AddOutputNodeByIndex(const NodePtr node, const int32_t index) { - if ((node == nullptr) || (node->GetOpDescBarePtr() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "The node ptr or opdesc should not be null."); - GELOGE(GRAPH_FAILED, "[Check][Param] The node ptr or opdesc should not be null."); - return nullptr; - } - - bool already_have = false; - NodePtr result = node; - // [output_nodes_info_ : should not be null] - for (const auto &item : output_nodes_info_) { - if ((item.first->GetName() == node->GetName()) && (item.second == index)) { - already_have = true; - result = item.first; - break; - } - } - - if (!already_have) { - output_nodes_info_.emplace_back(std::make_pair(node, index)); - GELOGI("Push back node name:%s, index:%d, into output_nodes_info_.", node->GetName().c_str(), index); - } - - if (std::find(nodes_.begin(), nodes_.end(), node) == nodes_.end()) { - GE_CHK_BOOL_EXEC(AddNode(node) != nullptr, return nullptr, "[Add][Node] failed"); - } - return result; -} - -graphStatus ComputeGraphImpl::RemoveConstInput(const NodePtr &node) { - GE_CHECK_NOTNULL(node); - - for (const auto &in_anchor : node->GetAllInDataAnchors()) { - const auto out_anchor = in_anchor->GetPeerOutAnchor(); - if ((out_anchor == nullptr) || (out_anchor->GetOwnerNodeBarePtr() == nullptr)) { - continue; - } - if ((out_anchor->GetOwnerNodeBarePtr()->GetType() == CONSTANT) || - (out_anchor->GetOwnerNodeBarePtr()->GetType() == CONSTANTOP)) { - GE_CHK_BOOL_RET_STATUS(GraphUtils::RemoveEdge(out_anchor, in_anchor) == GRAPH_SUCCESS, GRAPH_FAILED, - "[Remove][Edge] from const op %s failed.", out_anchor->GetOwnerNode()->GetName().c_str()); - if (out_anchor->GetOwnerNode()->GetOutNodes().empty()) { - GELOGI("Remove const op %s.", out_anchor->GetOwnerNode()->GetName().c_str()); - const auto iter = find(nodes_.begin(), nodes_.end(), out_anchor->GetOwnerNode()); - if (iter != nodes_.end()) { - EraseFromNodeList(iter); - } - } - } - } - return GRAPH_SUCCESS; -} - -graphStatus ComputeGraphImpl::RemoveNode(const NodePtr &node) { - if (node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The node ptr should not be null, graph:%s.", name_.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] The node ptr should not be null."); - return GRAPH_FAILED; - } - - // delete const op for this node - (void)RemoveConstInput(node); - - // if the node save as input node, delete it - (void)RemoveInputNode(node); - - // if the node save as input node, delete it - (void)RemoveOutputNode(node); - - if (IsolateNode(node) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Isolate][Node] failed, node name: %s, graph:%s.", node->GetName().c_str(), - name_.c_str()); - return GRAPH_FAILED; - } - - const auto iter = find(nodes_.begin(), nodes_.end(), node); - if (iter != nodes_.end()) { - EraseFromNodeList(iter); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -// Used in sub_graph scenes -graphStatus ComputeGraphImpl::RemoveInputNode(const NodePtr &node) { - if (node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The node ptr should not be null, graph:%s.", name_.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] The node ptr should not be null."); - return GRAPH_FAILED; - } - - const auto iter = find(input_nodes_.begin(), input_nodes_.end(), node); - if (iter != input_nodes_.end()) { - (void)input_nodes_.erase(iter); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -// Used in sub_graph scenes -graphStatus ComputeGraphImpl::RemoveOutputNode(const NodePtr &node) { - if (node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The node ptr should not be null, graph:%s.", name_.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] The node ptr should not be null."); - return GRAPH_FAILED; - } - - auto iter = output_nodes_info_.begin(); - bool find_node = false; - // [output_nodes_info_ : should not be null] - while (iter != output_nodes_info_.end()) { - if (node->GetName() == iter->first->GetName()) { - iter = output_nodes_info_.erase(iter); - find_node = true; - } else { - ++iter; - } - } - GE_IF_BOOL_EXEC(!find_node, return GRAPH_FAILED); - return GRAPH_SUCCESS; -} - -std::shared_ptr ComputeGraphImpl::AddSubGraph(const std::shared_ptr &sub_graph) { - if (sub_graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The graph ptr should not be null, graph:%s.", name_.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] The graph ptr should not be null."); - return nullptr; - } - sub_graph_.push_back(sub_graph); - names_to_subgraph_[sub_graph->GetName()] = sub_graph; - return sub_graph; -} - -graphStatus ComputeGraphImpl::RemoveSubGraph(const std::shared_ptr &sub_graph) { - if (sub_graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The graph ptr should not be null, graph:%s.", name_.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] The graph ptr should not be null."); - return GRAPH_FAILED; - } - - (void)names_to_subgraph_.erase(sub_graph->GetName()); - const auto iter = find(sub_graph_.begin(), sub_graph_.end(), sub_graph); - if (iter != sub_graph_.end()) { - (void)sub_graph_.erase(iter); - } else { - GELOGW("[Remove][Subgraph] find sub_graph failed"); - } - return GRAPH_SUCCESS; -} - -graphStatus ComputeGraphImpl::AddSubgraph(const std::string &name, - const std::shared_ptr &subgraph) { - if (subgraph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Try to add a null subgraph, name %s", name.c_str()); - GE_LOGE("[Check][Param] Try to add a null subgraph, name %s", name.c_str()); - return GRAPH_PARAM_INVALID; - } - const auto parent_graph = subgraph->GetParentGraph(); - if (parent_graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Try to add subgraph without parent graph, name %s", name.c_str()); - GE_LOGE("[Get][Graph] Try to add subgraph without parent graph, name %s", name.c_str()); - return GRAPH_PARAM_INVALID; - } - const auto parent_node = subgraph->GetParentNode(); - if (parent_node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Try to add a subgraph without parent node, name %s", name.c_str()); - GE_LOGE("[Get][Node] Try to add a subgraph without parent node, name %s", name.c_str()); - return GRAPH_PARAM_INVALID; - } - if (parent_node->GetOwnerComputeGraph() != parent_graph) { - REPORT_INNER_ERR_MSG("E18888", - "Try to add a subgraph which parent node's graph is not equal to " - "the subgraph's parent graph, subgraph name %s, parent node name %s", - subgraph->GetName().c_str(), parent_graph->GetName().c_str()); - GE_LOGE("[Check][Param] Try to add a subgraph which parent node's graph is not equal to " - "the subgraph's parent graph, subgraph name %s, parent node name %s", - subgraph->GetName().c_str(), parent_graph->GetName().c_str()); - return GRAPH_PARAM_INVALID; - } - if (!this->parent_graph_.expired()) { - GELOGW("[Add][Subgraph] The subgraphs should only be added to the root graph"); - } - if (name != subgraph->GetName()) { - GELOGW("[Add][Subgraph] The subgraph name %s is different with input %s", subgraph->GetName().c_str(), - name.c_str()); - } - if (names_to_subgraph_.find(name) != names_to_subgraph_.end()) { - REPORT_INNER_ERR_MSG("E18888", "The subgraph %s existed", name.c_str()); - GE_LOGE("[Check][Param] The subgraph %s existed", name.c_str()); - return GRAPH_PARAM_INVALID; - } - sub_graph_.push_back(subgraph); - names_to_subgraph_[name] = subgraph; - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraphImpl::RemoveSubgraph(const std::string &name) { - const std::map::const_iterator iter = names_to_subgraph_.find(name); - if (iter == names_to_subgraph_.cend()) { - return; - } - auto vec_iter = sub_graph_.begin(); - while (vec_iter != sub_graph_.end()) { - if ((*vec_iter) == iter->second) { - (void)sub_graph_.erase(vec_iter); - break; - } - ++vec_iter; - } - (void)names_to_subgraph_.erase(iter); -} - -std::shared_ptr ComputeGraphImpl::GetSubgraph(const std::string &name) const { - const std::shared_ptr parent = parent_graph_.lock(); - if (parent == nullptr) { - const auto iter = names_to_subgraph_.find(name); - return (iter == names_to_subgraph_.end()) ? nullptr : iter->second; - } else { - return parent->GetSubgraph(name); - } -} - -std::vector> ComputeGraphImpl::GetAllSubgraphs() const { - return sub_graph_; -} - -void ComputeGraphImpl::SetAllSubgraphs(const std::vector> &subgraphs) { - sub_graph_ = subgraphs; -} - -shared_ptr ComputeGraphImpl::GetParentGraph() const { - return parent_graph_.lock(); -} - -const ComputeGraph *ComputeGraphImpl::GetParentGraphBarePtr() const { - return parent_graph_bare_ptr_; -} - -void ComputeGraphImpl::SetParentGraph(const std::shared_ptr &parent) { - parent_graph_ = parent; - parent_graph_bare_ptr_ = parent_graph_.lock().get(); -} - -shared_ptr ComputeGraphImpl::GetParentNode() const { - return parent_node_.lock(); -} - -const Node *ComputeGraphImpl::GetParentNodeBarePtr() const { - return parent_node_bare_ptr_; -} - -void ComputeGraphImpl::SetParentNode(const std::shared_ptr &parent) { - parent_node_ = parent; - parent_node_bare_ptr_ = parent_node_.lock().get(); -} - -shared_ptr ComputeGraphImpl::GetOrUpdateNetOutputNode() { - auto graph_netoutput = graph_netoutput_.lock(); - if (graph_netoutput == nullptr || graph_netoutput->GetType() != NETOUTPUT) { - graph_netoutput = FindFirstNodeMatchType(NETOUTPUT); - SetNetOutputNode(graph_netoutput); - } - if (graph_netoutput == nullptr) { - GELOGW("Graph %s has no netoutput node", GetName().c_str()); - } - return graph_netoutput; -} - -void ComputeGraphImpl::SetNetOutputNode(const std::shared_ptr &netoutput_node) { - graph_netoutput_ = netoutput_node; -} - -/// @brief Update input-mapping -/// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input -/// @return graphStatus -graphStatus ComputeGraphImpl::UpdateInputMapping(const std::map &input_mapping) { - for (auto &input : nodes_) { - if (input->GetType() == DATA) { - uint32_t cur_index = 0U; - if (!ge::AttrUtils::GetInt(input->GetOpDescBarePtr(), ATTR_NAME_PARENT_NODE_INDEX, cur_index)) { - continue; - } - const auto iter = input_mapping.find(cur_index); - if (iter == input_mapping.end()) { - continue; - } - if (!ge::AttrUtils::SetInt(input->GetOpDescBarePtr(), ATTR_NAME_PARENT_NODE_INDEX, - static_cast(iter->second))) { - REPORT_INNER_ERR_MSG("E18888", "set attr ATTR_NAME_PARENT_NODE_INDEX failed, op:%s.", - input->GetOpDescBarePtr()->GetName().c_str()); - GE_LOGE("[Call][SetInt] UpdateInputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed, op:%s.", - input->GetOpDescBarePtr()->GetName().c_str()); - return GRAPH_FAILED; - } - } - } - - return GRAPH_SUCCESS; -} - -/// @brief Update output-mapping -/// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output -/// @return graphStatus -graphStatus ComputeGraphImpl::UpdateOutputMapping(const std::map &output_mapping) const { - const NodePtr net_output = FindFirstNodeMatchType(NETOUTPUT); - if (net_output == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "UpdateOutputMapping failed: node type %s does not exist in graph.", NETOUTPUT); - GE_LOGE("[Get][NodeType] UpdateOutputMapping failed: node type %s does not exist in graph.", NETOUTPUT); - return GRAPH_FAILED; - } - const auto op_desc = net_output->GetOpDescBarePtr(); - if (op_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "net output's op desc pr should not be null."); - GE_LOGE("[Get][OpDesc] UpdateOutputMapping failed: op_desc is NULL."); - return GRAPH_FAILED; - } - - const size_t num = op_desc->GetAllInputsSize(); - for (size_t i = 0UL; i < num; i++) { - GeTensorDesc tensor = op_desc->GetInputDesc(static_cast(i)); - uint32_t cur_index = 0U; - if (!ge::AttrUtils::GetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, cur_index)) { - continue; - } - const auto iter = output_mapping.find(cur_index); - if (iter == output_mapping.end()) { - continue; - } - if (!ge::AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, static_cast(iter->second))) { - REPORT_INNER_ERR_MSG("E18888", "op %s set %zu input tensor attr ATTR_NAME_PARENT_NODE_INDEX failed.", - op_desc->GetName().c_str(), i); - GE_LOGE("[Set][Int] op %s set %zu input tensor attr ATTR_NAME_PARENT_NODE_INDEX failed.", - op_desc->GetName().c_str(), i); - return GRAPH_FAILED; - } - if (op_desc->UpdateInputDesc(static_cast(i), tensor) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "op %s update %zu input_tensor failed.", op_desc->GetName().c_str(), i); - GE_LOGE("[Update][InputDesc] UpdateOutputMapping failed: update %zu input_tensor failed.", i); - return GRAPH_FAILED; - } - } - - return GRAPH_SUCCESS; -} - -graphStatus ComputeGraphImpl::ReorderEventNodes(const ConstComputeGraphPtr &compute_graph) { - std::list &node_list = nodes_; - for (const auto &node : GetDirectNode(compute_graph)) { - if ((strcmp(node->GetTypePtr(), RECV) == 0) || - (strcmp(node->GetTypePtr(), RECV_NOTIFY) == 0)) { - const auto iter = find(node_list.cbegin(), node_list.cend(), node); - if (iter != node_list.cend()) { - (void)node_list.erase(iter); - } - - const auto dst_iter = find(node_list.cbegin(), node_list.cend(), node->GetOutControlNodes().at(0UL)); - (void)node_list.insert(dst_iter, node); - } - if ((strcmp(node->GetTypePtr(), SEND) == 0) || - (strcmp(node->GetTypePtr(), SEND_NOTIFY) == 0)) { - const auto iter = find(node_list.cbegin(), node_list.cend(), node); - if (iter != node_list.cend()) { - (void)node_list.erase(iter); - } - - auto src_iter = find(node_list.cbegin(), node_list.cend(), node->GetInControlNodes().at(0UL)); - (void)node_list.insert(++src_iter, node); - } - } - - return GRAPH_SUCCESS; -} - -graphStatus ComputeGraphImpl::InsertGraphEvents(const ConstComputeGraphPtr &compute_graph) { - auto status = ReorderEventNodes(compute_graph); - if (status != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Graph [%s] record event nodes failed, status:%u", name_.c_str(), status); - GELOGE(status, "[Reorder][EventNodes] failed for Graph:%s, status:%u", name_.c_str(), status); - return status; - } - - // Partition subgraph - for (const auto &graph : sub_graph_) { - status = graph->ReorderEventNodes(); - if (status != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "ReorderEventNodes failed for SubGraph:%s, status:%u", graph->GetName().c_str(), - status); - GELOGE(status, "[Reorder][EventNodes] failed for SubGraph:%s, status:%u", graph->GetName().c_str(), status); - return status; - } - } - - std::vector subgraphs; - const auto nodes = AllGraphNodes(subgraphs, compute_graph); - for (size_t i = 0UL; i < nodes.size(); ++i) { - const NodePtr node = nodes.at(i); // [node: should not be null] - node->GetOpDescBarePtr()->SetId(static_cast(i)); // [node->GetOpDescBarePtr(): should not be null] - } - - return GRAPH_SUCCESS; -} - -graphStatus ComputeGraphImpl::DFSTopologicalSorting(std::vector &node_vec, const bool reverse, - const ConstComputeGraphPtr &compute_graph) const { - GELOGI("Runing_Dfs_Sort, reverse: %d, graph: %s", reverse, name_.c_str()); - std::vector stack; - std::map map_in_edge_num; - // Record the number of non data nodes but no input nodes - GE_CHK_BOOL_EXEC(SortNodes(stack, map_in_edge_num, compute_graph) == GRAPH_SUCCESS, - return GRAPH_FAILED, "sort nodes failed"); - const bool is_mem_priority = IsMemoryPriority(); - std::vector nodes_info; - if (is_mem_priority) { - InitNodeStatus(compute_graph, nodes_info); - } - TopoSortStack topo_sort_stack(&nodes_info, is_mem_priority, true, reverse); - for (const auto &node : stack) { - topo_sort_stack.Push(node); - } - std::vector out_nodes; - const auto stack_push = [&reverse, &topo_sort_stack](std::vector& tmp_out_nodes) { - if (reverse) { - std::reverse(tmp_out_nodes.begin(), tmp_out_nodes.end()); - } - for (const auto &node: tmp_out_nodes) { - topo_sort_stack.Push(node); - } - tmp_out_nodes.clear(); - }; - // Only data nodes here - while (!topo_sort_stack.Empty()) { - const NodePtr node = topo_sort_stack.Pop(); - node_vec.push_back(node); - GE_CHECK_NOTNULL(node->GetOpDescBarePtr()); - GELOGD("node_vec.push_back %s", node->GetOpDescBarePtr()->GetName().c_str()); - for (const auto &anchor : node->GetAllOutDataAnchors()) { - GE_CHECK_NOTNULL(anchor); - for (const auto &peer_in_anchor : anchor->GetPeerInDataAnchors()) { - GE_CHECK_NOTNULL(peer_in_anchor); - GetOutNodesFromAnchor(peer_in_anchor, map_in_edge_num, out_nodes); - } - stack_push(out_nodes); - for (const auto &peer_in_anchor : anchor->GetPeerInControlAnchors()) { - GE_CHECK_NOTNULL(peer_in_anchor); - GetOutNodesFromAnchor(peer_in_anchor, map_in_edge_num, out_nodes); - } - stack_push(out_nodes); - } - GE_IF_BOOL_EXEC(node->GetOutControlAnchor() != nullptr, - for (const AnchorPtr peer_in_anchor : node->GetOutControlAnchor()->GetPeerAnchors()) { - GE_CHECK_NOTNULL(peer_in_anchor); - GetOutNodesFromAnchor(peer_in_anchor, map_in_edge_num, out_nodes); - } - stack_push(out_nodes);) - } - - return GRAPH_SUCCESS; -} - -graphStatus ComputeGraphImpl::StableRDFSTopologicalSorting(std::vector &node_vec, const bool reverse, - const ConstComputeGraphPtr &compute_graph) const { - (void) reverse; - GELOGI("Runing_Stable_Reverse_Dfs_Sort: %s", name_.c_str()); - std::vector nodes_info; - InitNodeStatus(compute_graph, nodes_info); - - for (const auto &node : compute_graph->GetDirectNode()) { - GE_CHECK_NOTNULL(node); - GE_CHECK_NOTNULL(node->GetOpDesc()); - GE_CHECK_NOTNULL(node->GetOpDescBarePtr()); - // id作为索引前面的init初始化保证了一定是有效的, walked的节点不入栈 - if (nodes_info[node->GetOpDesc()->GetId()].status == WalkStatus::kWalked) { - continue; - } - // 按照原有的nodes的topo顺序来入栈 - std::vector stack = {node.get()}; - while (!stack.empty()) { - const auto current = stack.back(); - NodeStatus &reverse_dfs_node_info = nodes_info[current->GetOpDesc()->GetId()]; - if (reverse_dfs_node_info.status == WalkStatus::kNotWalked) { - reverse_dfs_node_info.status = WalkStatus::kWalking; - // 获取输入节点,反向遍历 - const auto in_all_nodes = current->GetInNodesPtr(); - if (in_all_nodes.empty()) { - continue; - } - std::vector in_nodes_has_not_been_walked; - in_nodes_has_not_been_walked.reserve(in_all_nodes.size()); - for (const auto in_node: in_all_nodes) { - if (nodes_info[in_node->GetOpDesc()->GetId()].status == WalkStatus::kNotWalked) { - in_nodes_has_not_been_walked.push_back(in_node); - } - } - - auto cmp = [](const Node *lhs, const Node *rhs) { - // not null - return lhs->GetOpDescBarePtr()->GetId() > rhs->GetOpDescBarePtr()->GetId(); - }; - // 输入节点的排序使用原始的顺序,可以保证原有topo如果满足当前图的遍历关系,最大程度的保留下来 - std::set - input_nodes{in_nodes_has_not_been_walked.begin(), in_nodes_has_not_been_walked.end(), cmp}; - stack.insert(stack.end(), input_nodes.cbegin(), input_nodes.cend()); - } else { - stack.pop_back(); - if (reverse_dfs_node_info.status != WalkStatus::kWalking) { - continue; - } - reverse_dfs_node_info.status = WalkStatus::kWalked; - node_vec.emplace_back(current->shared_from_this()); - GE_CHECK_NOTNULL(current->GetOpDescBarePtr()); - GELOGD("node_vec.push_back %s", current->GetOpDescBarePtr()->GetName().c_str()); - } - } - } - - return GRAPH_SUCCESS; -} - -graphStatus ComputeGraphImpl::RDFSTopologicalSorting(std::vector &node_vec, const bool reverse, - const ConstComputeGraphPtr &compute_graph) const { - (void) reverse; - GELOGI("Runing_Reverse_Dfs_Sort: %s", name_.c_str()); - std::vector nodes_info; - InitNodeStatus(compute_graph, nodes_info); - - for (const auto &node : compute_graph->GetDirectNode()) { - if (node->GetOutNodesSize() > 0U) { - continue; - } - - std::vector stack = {node}; - while (!stack.empty()) { - const auto current = stack.back(); - NodeStatus &reverse_dfs_node_info = nodes_info[current->GetOpDesc()->GetId()]; - if (reverse_dfs_node_info.status == WalkStatus::kNotWalked) { - reverse_dfs_node_info.status = WalkStatus::kWalking; - - const auto in_all_nodes = current->GetInAllNodes(); - NodeCmp cmp(&nodes_info); - std::set input_nodes{in_all_nodes.begin(), in_all_nodes.end(), cmp}; - stack.insert(stack.end(), input_nodes.cbegin(), input_nodes.cend()); - continue; - } - stack.pop_back(); - if (reverse_dfs_node_info.status == WalkStatus::kWalking) { - reverse_dfs_node_info.status = WalkStatus::kWalked; - node_vec.emplace_back(current); - GE_CHECK_NOTNULL(current->GetOpDescBarePtr()); - GELOGD("node_vec.push_back %s", current->GetOpDescBarePtr()->GetName().c_str()); - } - } - } - - return GRAPH_SUCCESS; -} - -graphStatus ComputeGraphImpl::BFSTopologicalSorting(std::vector &node_vec, const bool reverse, - const ConstComputeGraphPtr &compute_graph) const { - GELOGI("Runing_Bfs_Sort: %s", name_.c_str()); - (void) reverse; - const bool is_mem_priority = IsMemoryPriority(); - std::vector nodes_info; - if (is_mem_priority) { - InitNodeStatus(compute_graph, nodes_info); - } - TopoSortStack topo_sort_stack(&nodes_info, is_mem_priority); - std::vector stack_input; - std::map breadth_node_map; - std::map map_in_edge_num; - // Record the number of non data nodes but no input nodes - GE_CHK_BOOL_EXEC(SortNodes(stack_input, map_in_edge_num, compute_graph) == GRAPH_SUCCESS, - return GRAPH_FAILED, "sort nodes failed"); - - // Only data nodes here - while ((!stack_input.empty()) || (!topo_sort_stack.Empty())) { - NodePtr node = nullptr; - if (!topo_sort_stack.Empty()) { - node = topo_sort_stack.Pop(); - } else { - node = stack_input.back(); - stack_input.pop_back(); - } - - node_vec.push_back(node); - GE_CHECK_NOTNULL(node->GetOpDescBarePtr()); - GELOGD("node_vec.push_back %s", node->GetOpDescBarePtr()->GetName().c_str()); - (void)CollectBreadthOutNode(node, map_in_edge_num, breadth_node_map); - - for (const auto &name_node : breadth_node_map) { - (void) topo_sort_stack.Push(name_node.second); - } - breadth_node_map.clear(); - } - return GRAPH_SUCCESS; -} - -graphStatus ComputeGraphImpl::CollectBreadthOutNode(const NodePtr &node, std::map &map_in_edge_num, - std::map &breadth_node_map) const { - for (const auto &anchor : node->GetAllOutDataAnchors()) { - for (const auto &peer_in_anchor : anchor->GetPeerInDataAnchors()) { - const auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); - if (iter != map_in_edge_num.end()) { - --iter->second; - if (iter->second == 0U) { - (void) breadth_node_map.emplace(peer_in_anchor->GetOwnerNodeBarePtr()->GetName(), - peer_in_anchor->GetOwnerNode()); - } - } - } - - for (const auto &peer_in_anchor : anchor->GetPeerInControlAnchors()) { - const auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); - if (iter != map_in_edge_num.end()) { - --iter->second; - if (iter->second == 0U) { - (void) breadth_node_map.emplace(peer_in_anchor->GetOwnerNodeBarePtr()->GetName(), - peer_in_anchor->GetOwnerNode()); - } - } - } - } - if (node->GetOutControlAnchor() != nullptr) { - for (const auto peer_in_anchor : node->GetOutControlAnchor()->GetPeerAnchorsPtr()) { - const auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); - if (iter != map_in_edge_num.end()) { - --iter->second; - if (iter->second == 0U) { - (void) breadth_node_map.emplace(peer_in_anchor->GetOwnerNodeBarePtr()->GetName(), - peer_in_anchor->GetOwnerNode()); - } - } - } - } - return GRAPH_SUCCESS; -} - -void ComputeGraphImpl::TopologicalSorting(const std::function comp) { - nodes_.sort(std::move(comp)); - int64_t num = 0; - for (const NodePtr &node : nodes_) { - node->GetOpDescBarePtr()->SetId(num++); // node should not be null, node->GetOpDescBarePtr() should not be null] - } -} - -graphStatus ComputeGraphImpl::TopologicalSorting(const ComputeGraphPtr &const_graph_ptr, - const ConstComputeGraphPtr &const_compute_graph) { - auto ret = TopologicalSortingGraph(const_compute_graph); - if (ret != GRAPH_SUCCESS) { - GE_DUMP(const_graph_ptr, "black_box" + name_); - REPORT_INNER_ERR_MSG("E18888", "Graph [%s] topological sort failed, saved to file black_box", name_.c_str()); - GELOGW("[Sort][Graph] Graph [%s] topological sort failed, saved to file black_box", name_.c_str()); - return ret; - } - - if (sub_graph_.empty()) { - return GRAPH_SUCCESS; - } - - // partition sub graph - for (const auto &sub_graph : sub_graph_) { - ret = sub_graph->TopologicalSortingGraph(); - if (ret != GRAPH_SUCCESS) { - GE_DUMP(sub_graph, "black_box" + sub_graph->GetName()); - REPORT_INNER_ERR_MSG("E18888", "Sub graph[%s] topological sort failed, saved to file black_box", - sub_graph->GetName().c_str()); - GELOGW("[Sort][Graph] Sub graph[%s] topological sort failed, saved to file black_box", - sub_graph->GetName().c_str()); - return ret; - } - } - - std::vector> subgraphs; - auto nodes = AllGraphNodes(subgraphs, const_compute_graph); - for (size_t i = 0UL; i < nodes.size(); i++) { - const NodePtr node = nodes.at(i); // [node: should not be null] - node->GetOpDescBarePtr()->SetId(static_cast(i)); // [node->GetOpDescBarePtr(): should not be null] - } - if (sub_graph_.size() != subgraphs.size()) { // Graph Partition use subgraph, Keep original - GELOGW("[TopoSort][CheckNodeSize] Keep original subgraph for graph size %zu not equal %zu.", sub_graph_.size(), - subgraphs.size()); - return GRAPH_SUCCESS; - } - sub_graph_.swap(subgraphs); - return GRAPH_SUCCESS; -} - -graphStatus ComputeGraphImpl::DoTopologicalSorting(const ConstComputeGraphPtr &compute_graph, - TopoSortingMode sorting_mode, - bool dfs_reverse) { - using TopoSortingStrategy = - std::function &, const bool, const ConstComputeGraphPtr &)>; - static const std::map topo_sorting_strategy{ - {TopoSortingMode::kBFS, &ComputeGraphImpl::BFSTopologicalSorting}, - {TopoSortingMode::kDFS, &ComputeGraphImpl::DFSTopologicalSorting}, - {TopoSortingMode::kRDFS, &ComputeGraphImpl::RDFSTopologicalSorting}, - {TopoSortingMode::kStableRDFS, &ComputeGraphImpl::StableRDFSTopologicalSorting}}; - - std::vector node_vec; - const auto it = topo_sorting_strategy.find(sorting_mode); - if (it == topo_sorting_strategy.end()) { - GELOGE(GRAPH_FAILED, "Can not find topo sorting strategy of %d.", static_cast(sorting_mode)); - return GRAPH_FAILED; - } - if (it->second(this, node_vec, dfs_reverse, compute_graph) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - - // If they are not equal, there is a closed loop - if (node_vec.size() != GetDirectNodesSize()) { - std::set itered_nodes_set; - for (auto &node : node_vec) { - (void) itered_nodes_set.insert(node.get()); - } - REPORT_INNER_ERR_MSG("E18888", "Failed to do topo sorting total %zu, itered %zu, exist closed loop in graph:%s", - GetDirectNodesSize(), node_vec.size(), name_.c_str()); - GELOGW("[Check][Param] Failed to do topo sorting total %zu, itered %zu, exist closed loop in graph.", - GetDirectNodesSize(), node_vec.size()); - for (auto &node : nodes_) { - if (itered_nodes_set.count(node.get()) == 0UL) { - GELOGW("[Check][Param] The node %s does not itered when topological sorting", node->GetName().c_str()); - } - } - return GRAPH_FAILED; - } - - ClearNodeList(); - if ((IsMemoryPriority() && (sorting_mode != TopoSortingMode::kStableRDFS)) || - (sorting_mode == TopoSortingMode::kRDFS)) { - DelayTopoSort(node_vec, compute_graph); - } - for (size_t i = 0UL; i < node_vec.size(); i++) { - const NodePtr node = node_vec[i]; // [node: should not be null] - node->GetOpDescBarePtr()->SetId(static_cast(i)); // [node->GetOpDescBarePtr(): should not be null] - PushBackToNodeList(node); - } - - is_valid_flag_ = true; - return GRAPH_SUCCESS; -} - -graphStatus ComputeGraphImpl::TopologicalSortingGraph(const ConstComputeGraphPtr &compute_graph, - const bool dfs_reverse) { - return DoTopologicalSorting(compute_graph, GetTopoSortingStrategy(), dfs_reverse); -} - -graphStatus ComputeGraphImpl::TopologicalSortingGraph(const ConstComputeGraphPtr &compute_graph, - TopoSortingMode topo_sorting_mode) { - return DoTopologicalSorting(compute_graph, topo_sorting_mode, false); -} - -graphStatus ComputeGraphImpl::SortNodes(std::vector &stack, - std::map &map_in_edge_num, - const ConstComputeGraphPtr &compute_graph) const { - // Record the number of non data nodes but no input nodes - uint32_t spec_node_size = 0U; - for (const auto &node : GetDirectNode(compute_graph)) { - GE_IF_BOOL_EXEC(node->GetOpDescBarePtr() == nullptr, continue); - map_in_edge_num[node] = static_cast(GetInEdgeSize(node)); - if (map_in_edge_num[node] == 0U) { - if ((!OpTypeUtils::IsDataNode(node->GetOpDescBarePtr()->GetType())) && - (node->GetOpDescBarePtr()->GetType() != INPUT_TYPE) && (node->GetOpDescBarePtr()->GetType() != RECV) - && (node->GetOpDescBarePtr()->GetType() != SEND)) { - (void)stack.insert(stack.begin(), node); - spec_node_size++; - continue; - } - // Need to insert the data nodes in reverse order - (void)stack.insert(stack.begin() + static_cast(spec_node_size), node); - } - } - - /// Make sure the inputs order matches with user-designated - /// 1. Get the index of two input nodes in the user-inputs-order(inputs_order_) - /// 2. Compare two indices, if not match, swap the positions of two inputs - /// *: Remind: stack is reverse-order - for (size_t i = 0UL; i < stack.size(); ++i) { - // If not found in 'inputs_order_', skip it - const auto it_i = std::find(inputs_order_.begin(), inputs_order_.end(), stack[i]->GetName()); - GE_IF_BOOL_EXEC(it_i == inputs_order_.end(), continue); - const auto inx_i = it_i - inputs_order_.begin(); - for (size_t j = i + 1UL; j < stack.size(); ++j) { - // If not found in 'inputs_order_', skip it - const auto it_j = std::find(inputs_order_.begin(), inputs_order_.end(), stack[j]->GetName()); - GE_IF_BOOL_EXEC(it_j == inputs_order_.end(), continue); - - // Compare index, swap them if it should be - const auto inx_j = it_j - inputs_order_.begin(); - GE_IF_BOOL_EXEC(inx_i < inx_j, std::swap(stack[i], stack[j])); - } - } - - return GRAPH_SUCCESS; -} - -size_t ComputeGraphImpl::GetInEdgeSize(const NodePtr &node) const { - size_t in_edge_size = 0UL; - if (node == nullptr) { - return in_edge_size; - } - for (const auto &anchor : node->GetAllInDataAnchors()) { - in_edge_size = in_edge_size + anchor->GetPeerAnchorsSize(); - // Break flow control data loop. - const OutDataAnchorPtr out_anchor = anchor->GetPeerOutAnchor(); - if ((out_anchor != nullptr) && (out_anchor->GetOwnerNodeBarePtr() != nullptr)) { - const auto out_node = out_anchor->GetOwnerNodeBarePtr(); - if ((out_node->GetType() == NEXTITERATION) || (out_node->GetType() == REFNEXTITERATION)) { - GE_IF_BOOL_EXEC(in_edge_size == 0UL, - GELOGE(GRAPH_FAILED, "[Check][Param] If [in_edge_size = 0], the result will be reversed"); - return in_edge_size); - in_edge_size -= 1UL; - } - } - } - if (node->GetInControlAnchor() != nullptr) { - in_edge_size = in_edge_size + node->GetInControlAnchor()->GetPeerAnchorsSize(); - } - return in_edge_size; -} - -size_t ComputeGraphImpl::GetOutEdgeSize(const NodePtr &node) const { - size_t out_edge_size = 0UL; - if (node == nullptr) { - return out_edge_size; - } - - // Break flow control data loop. - if ((node->GetType() != NEXTITERATION) && (node->GetType() != REFNEXTITERATION)) { - for (const auto &anchor : node->GetAllOutDataAnchors()) { - if (anchor != nullptr) { - out_edge_size = out_edge_size + anchor->GetPeerAnchorsSize(); - } - } - } - if (node->GetOutControlAnchor() != nullptr) { - if (out_edge_size > (UINT64_MAX - node->GetOutControlAnchor()->GetPeerAnchorsSize())) { - return 0UL; - } - out_edge_size = out_edge_size + node->GetOutControlAnchor()->GetPeerAnchorsSize(); - } - return out_edge_size; -} - -bool ComputeGraphImpl::IsValid() const { return is_valid_flag_; } - -void ComputeGraphImpl::InValid() { is_valid_flag_ = false; } - -void ComputeGraphImpl::Dump(const ConstComputeGraphPtr &graph) const { - if (!IsLogEnable(GE_MODULE_NAME, DLOG_INFO)) { - return; - } - - GELOGI("graph name = %s.", GetName().c_str()); - for (const auto &node : GetAllNodes(graph)) { - GELOGD("node name = %s.", node->GetName().c_str()); - for (const auto &anchor : node->GetAllOutDataAnchors()) { - for (const auto &peer_in_anchor : anchor->GetPeerInDataAnchors()) { - GE_IF_BOOL_EXEC((peer_in_anchor != nullptr) && (peer_in_anchor->GetOwnerNode() != nullptr), - GELOGI("node name = %s, out data node name = %s.", node->GetName().c_str(), - peer_in_anchor->GetOwnerNode()->GetName().c_str())); - } - for (const auto &peer_in_anchor : anchor->GetPeerInControlAnchors()) { - GE_IF_BOOL_EXEC((peer_in_anchor != nullptr) && (peer_in_anchor->GetOwnerNode() != nullptr), - GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), - peer_in_anchor->GetOwnerNode()->GetName().c_str())); - } - } - const auto out_control_anchor = node->GetOutControlAnchor(); - if (out_control_anchor != nullptr) { - for (const auto &peer_in_anchor : out_control_anchor->GetPeerInControlAnchors()) { - GE_IF_BOOL_EXEC((peer_in_anchor != nullptr) && (peer_in_anchor->GetOwnerNode() != nullptr), - GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), - peer_in_anchor->GetOwnerNode()->GetName().c_str())); - } - for (const auto &peer_in_anchor : out_control_anchor->GetPeerInDataAnchors()) { - GE_IF_BOOL_EXEC((peer_in_anchor != nullptr) && (peer_in_anchor->GetOwnerNode() != nullptr), - GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), - peer_in_anchor->GetOwnerNode()->GetName().c_str())); - } - } - } -} - -void ComputeGraphImpl::Swap(ComputeGraphImpl &graph) { - origGraph_.swap(graph.origGraph_); - - name_.swap(graph.name_); - std::swap(graph_id_, graph.graph_id_); - attrs_.Swap(graph.attrs_); - nodes_.swap(graph.nodes_); - const auto tmp_size = direct_nodes_size_; - direct_nodes_size_ = graph.direct_nodes_size_; - graph.direct_nodes_size_ = tmp_size; - all_nodes_infos_.swap(graph.all_nodes_infos_); - target_nodes_info_.swap(graph.target_nodes_info_); - - input_nodes_.swap(graph.input_nodes_); - inputs_order_.swap(graph.inputs_order_); - std::swap(input_size_, graph.input_size_); - out_nodes_map_.swap(graph.out_nodes_map_); - std::swap(output_size_, graph.output_size_); - output_nodes_info_.swap(graph.output_nodes_info_); - - sub_graph_.swap(graph.sub_graph_); - names_to_subgraph_.swap(graph.names_to_subgraph_); - parent_graph_.swap(graph.parent_graph_); - parent_graph_bare_ptr_ = parent_graph_.lock().get(); - parent_node_.swap(graph.parent_node_); - parent_node_bare_ptr_ = parent_node_.lock().get(); - graph_netoutput_.swap(graph.graph_netoutput_); - - // the members followed should not in the ComputeGraphImpl class - std::swap(is_valid_flag_, graph.is_valid_flag_); - std::swap(is_summary_graph_, graph.is_summary_graph_); - std::swap(need_iteration_, graph.need_iteration_); - params_share_map_.swap(graph.params_share_map_); - op_name_map_.swap(graph.op_name_map_); - std::swap(session_id_, graph.session_id_); - std::swap(data_format_, graph.data_format_); -} - -void ComputeGraphImpl::SetNodesOwner(const ComputeGraphPtr &compute_graph) { - for (const auto &node : nodes_) { - if (node == nullptr) { - continue; - } - (void)node->SetOwnerComputeGraph(compute_graph); - } -} - -void ComputeGraphImpl::SetTopParentGraph(const ComputeGraphPtr &compute_graph) { - for (const auto &sub_graph : sub_graph_) { - if ((sub_graph == nullptr) || (sub_graph->GetParentGraph() == nullptr) || - (sub_graph->GetParentGraph()->GetParentGraph() != nullptr)) { - continue; - } - (void)sub_graph->SetParentGraph(compute_graph); - } -} - -graphStatus ComputeGraphImpl::IsolateNode(const NodePtr &node) const { - GE_CHECK_NOTNULL(node); - const auto next_nodes = node->GetOutAllNodes(); - // If there is input data side - for (size_t i = 0UL; i < node->GetAllInDataAnchors().size(); i++) { - const auto in_data_anchor = node->GetInDataAnchor(static_cast(i)); - GE_CHECK_NOTNULL(in_data_anchor); - const auto pre_out_data_anchor = in_data_anchor->GetPeerOutAnchor(); - if (pre_out_data_anchor != nullptr) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(pre_out_data_anchor, in_data_anchor) == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "remove edge from %s to %s failed", - pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), - in_data_anchor->GetOwnerNode()->GetName().c_str()); - return GRAPH_FAILED, "[Remove][Edge] from %s to %s failed", - pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), - in_data_anchor->GetOwnerNode()->GetName().c_str()); - GE_IF_BOOL_EXEC((pre_out_data_anchor->GetOwnerNode()->GetType() == CONSTANT) || - (pre_out_data_anchor->GetOwnerNode()->GetType() == CONSTANTOP), - continue); - for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { - for (const auto &next_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_data_anchor) == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "remove edge from %s to %s failed", - out_data_anchor->GetOwnerNode()->GetName().c_str(), - next_in_data_anchor->GetOwnerNode()->GetName().c_str()); - return GRAPH_FAILED, "[Remove][Edge] from %s to %s failed", - out_data_anchor->GetOwnerNode()->GetName().c_str(), - next_in_data_anchor->GetOwnerNode()->GetName().c_str()); - GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_data_anchor, next_in_data_anchor) == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "add edge from %s to %s failed", - pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), - next_in_data_anchor->GetOwnerNode()->GetName().c_str()); - return GRAPH_FAILED, "[Add][Edge] from %s to %s failed", - pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), - next_in_data_anchor->GetOwnerNode()->GetName().c_str()); - } - for (const auto &next_in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "remove edge from %s to %s failed", - out_data_anchor->GetOwnerNode()->GetName().c_str(), - next_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - return GRAPH_FAILED, "[Remove][Edge] from %s to %s failed", - out_data_anchor->GetOwnerNode()->GetName().c_str(), - next_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "add edge from %s to %s failed", - pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), - next_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - return GRAPH_FAILED, "[Add][Edge] from %s to %s failed", - pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), - next_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - } - } - const auto out_ctrl_anchor = node->GetOutControlAnchor(); - GE_CHECK_NOTNULL(out_ctrl_anchor); - const auto pre_out_ctrl_anchor = pre_out_data_anchor->GetOwnerNodeBarePtr()->GetOutControlAnchor(); - GE_CHECK_NOTNULL(pre_out_ctrl_anchor); - for (const auto &next_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "remove edge from %s to %s failed", - out_ctrl_anchor->GetOwnerNode()->GetName().c_str(), - next_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - return GRAPH_FAILED, "[Remove][Edge] from %s to %s failed", - out_ctrl_anchor->GetOwnerNode()->GetName().c_str(), - next_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "add edge from %s to %s failed", - pre_out_ctrl_anchor->GetOwnerNode()->GetName().c_str(), - next_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - return GRAPH_FAILED, "[Add][Edge] from %s to %s failed", - pre_out_ctrl_anchor->GetOwnerNode()->GetName().c_str(), - next_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - } - } - } - - // If there is an input control side - const auto in_ctrl_anchor = node->GetInControlAnchor(); - GE_CHECK_NOTNULL(in_ctrl_anchor); - for (const auto &pre_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(pre_out_ctrl_anchor, in_ctrl_anchor) == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "remove edge from %s to %s failed", - pre_out_ctrl_anchor->GetOwnerNode()->GetName().c_str(), - in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - return GRAPH_FAILED, "[Remove][Edge] from %s to %s failed", - pre_out_ctrl_anchor->GetOwnerNode()->GetName().c_str(), - in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { - for (const auto &next_in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "remove edge from %s to %s failed", - out_data_anchor->GetOwnerNode()->GetName().c_str(), - next_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - return GRAPH_FAILED, "[Remove][Edge] from %s to %s failed", - out_data_anchor->GetOwnerNode()->GetName().c_str(), - next_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "add edge from %s to %s failed", - pre_out_ctrl_anchor->GetOwnerNode()->GetName().c_str(), - next_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - return GRAPH_FAILED, "[Add][Edge] from %s to %s failed", - pre_out_ctrl_anchor->GetOwnerNode()->GetName().c_str(), - next_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - } - } - const auto out_ctrl_anchor = node->GetOutControlAnchor(); - if (out_ctrl_anchor != nullptr) { - for (const auto &next_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "remove edge from %s to %s failed", - out_ctrl_anchor->GetOwnerNode()->GetName().c_str(), - next_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - return GRAPH_FAILED, "[Remove][Edge] from %s to %s failed", - out_ctrl_anchor->GetOwnerNode()->GetName().c_str(), - next_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(pre_out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "add edge from %s to %s failed", - pre_out_ctrl_anchor->GetOwnerNode()->GetName().c_str(), - next_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - return GRAPH_FAILED, "[Add][Edge] from %s to %s failed", - pre_out_ctrl_anchor->GetOwnerNode()->GetName().c_str(), - next_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - } - } - } - - for (const auto &out_peer_data_anchor : in_ctrl_anchor->GetPeerOutDataAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_peer_data_anchor, in_ctrl_anchor) == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "remove edge from %s to %s failed", - out_peer_data_anchor->GetOwnerNode()->GetName().c_str(), - in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - return GRAPH_FAILED, "[Remove][Edge] from %s to %s failed", - out_peer_data_anchor->GetOwnerNode()->GetName().c_str(), - in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - for (const auto &next_node : next_nodes) { - const auto next_in_control_anchor = next_node->GetInControlAnchor(); - GE_CHK_BOOL_EXEC(GraphUtils::AddEdge(out_peer_data_anchor, next_in_control_anchor) == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "add edge from %s to %s failed", - out_peer_data_anchor->GetOwnerNode()->GetName().c_str(), - next_in_control_anchor->GetOwnerNode()->GetName().c_str()); - return GRAPH_FAILED, "[Add][Edge] from %s to %s failed", - out_peer_data_anchor->GetOwnerNode()->GetName().c_str(), - next_in_control_anchor->GetOwnerNode()->GetName().c_str()); - } - } - - return RemoveExtraOutEdge(node); -} - -graphStatus ComputeGraphImpl::RemoveExtraOutEdge(const NodePtr &node) const { - GE_CHECK_NOTNULL(node); - // Remove redundant output edges - for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { - for (const auto &next_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_data_anchor) == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "remove edge from %s to %s failed", - out_data_anchor->GetOwnerNode()->GetName().c_str(), - next_in_data_anchor->GetOwnerNode()->GetName().c_str()); - return GRAPH_FAILED, "[Remove][Edge] from %s to %s failed", - out_data_anchor->GetOwnerNode()->GetName().c_str(), - next_in_data_anchor->GetOwnerNode()->GetName().c_str()); - } - - for (const auto &next_in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_data_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "remove edge from %s to %s failed", - out_data_anchor->GetOwnerNode()->GetName().c_str(), - next_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - return GRAPH_FAILED, "[Remove][Edge] from %s to %s failed", - out_data_anchor->GetOwnerNode()->GetName().c_str(), - next_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - } - } - const auto out_ctrl_anchor = node->GetOutControlAnchor(); - if (out_ctrl_anchor != nullptr) { - for (const auto &next_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveEdge(out_ctrl_anchor, next_in_ctrl_anchor) == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "remove edge from %s to %s failed", - out_ctrl_anchor->GetOwnerNode()->GetName().c_str(), - next_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - return GRAPH_FAILED, "[Remove][Edge] from %s to %s failed", - out_ctrl_anchor->GetOwnerNode()->GetName().c_str(), - next_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - } - } - return GRAPH_SUCCESS; -} - -ProtoAttrMap &ComputeGraphImpl::MutableAttrMap() { - return attrs_; -} - -ConstProtoAttrMap &ComputeGraphImpl::GetAttrMap() const { - return attrs_; -} - -const std::map &ComputeGraphImpl::GetAllNodesInfo() const { return all_nodes_infos_; } - -void ComputeGraphImpl::SetUserDefOutput(const std::string &output_name) { - if (output_name.empty()) { - return; - } - - const std::vector nodes = StringUtils::Split(output_name, ';'); - for (const std::string &node : nodes) { - std::vector item = StringUtils::Split(node, ':'); - if (item.size() != OUTPUT_PARAM_SIZE) { - REPORT_INNER_ERR_MSG("W18888", "Check output param size failed, output_name:%s", output_name.c_str()); - GELOGW("[Check][Output] Check output param size failed, output_name:%s", output_name.c_str()); - continue; - } - - int32_t index; - try { - index = stoi(StringUtils::Trim(item[1UL])); - } catch (const std::out_of_range &) { - REPORT_INNER_ERR_MSG("W18888", "Catch out_of_range exception, output_name:%s", output_name.c_str()); - GELOGW("[Catch][Exception] Catch out_of_range exception, output_name:%s", output_name.c_str()); - continue; - } catch (const std::invalid_argument &) { - REPORT_INNER_ERR_MSG("W18888", "Catch invalid_argument exception, output_name:%s", output_name.c_str()); - GELOGW("[Catch][Exception] Catch invalid_argument exception, output_name:%s", output_name.c_str()); - continue; - } catch (...) { - REPORT_INNER_ERR_MSG("W18888", "Catch exception, output_name:%s", output_name.c_str()); - GELOGW("[Catch][Exception] Catch exception, output_name:%s", output_name.c_str()); - continue; - } - const auto iter = out_nodes_map_.find(item[0UL]); - if (iter == out_nodes_map_.end()) { - out_nodes_map_[item[0UL]] = std::vector(1UL, index); - } else { - const auto idx_iter = std::find(iter->second.begin(), iter->second.end(), index); - if (idx_iter == iter->second.end()) { - iter->second.push_back(index); - } - } - } -} - -const std::string ComputeGraphImpl::GetOutput() { - static const int32_t resultDefaultSize = 2048; - std::string result; - result.reserve(static_cast(resultDefaultSize)); - auto iter = out_nodes_map_.begin(); - while (iter != out_nodes_map_.end()) { - const auto idxes = iter->second; - for (const auto idx : idxes) { - (void)result.append(iter->first).append(":").append(std::to_string(idx)).append(";"); - } - ++iter; - } - - return result.substr(0UL, result.length() - 1UL); -} - - -void ComputeGraphImpl::EraseFromNodeList(const std::list::iterator &position) { - (void) nodes_.erase(position); - --direct_nodes_size_; -} - -void ComputeGraphImpl::InsertToNodeList(const std::list::iterator &position, const NodePtr &node) { - (void) nodes_.insert(position, node); - ++direct_nodes_size_; -} - -void ComputeGraphImpl::PushBackToNodeList(const NodePtr &node) { - (void) nodes_.push_back(node); - ++direct_nodes_size_; -} - -void ComputeGraphImpl::EmplaceBackToNodeList(const NodePtr &node) { - (void) nodes_.emplace_back(node); - ++direct_nodes_size_; -} - -void ComputeGraphImpl::ClearNodeList() { - (void) nodes_.clear(); - direct_nodes_size_ = 0UL; -} - -void ComputeGraphImpl::ReorderByNodeId() { - std::vector node_vec(nodes_.begin(), nodes_.end()); - std::sort(node_vec.begin(), node_vec.end(), [](const NodePtr &lhs, const NodePtr &rhs) { - return lhs->GetOpDesc()->GetId() < rhs->GetOpDesc()->GetId(); - }); - ClearNodeList(); - for (const auto &node : node_vec) { - PushBackToNodeList(node); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::ComputeGraph(const std::string &name) - : enable_shared_from_this(), - AttrHolder(), - impl_(ComGraphMakeSharedAndThrow(name)) {} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::ComputeGraph(const char_t *name) - : ComputeGraph(std::string((name == nullptr) ? "" : name)) {} - -ComputeGraph::~ComputeGraph() {} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::ComputeGraph(const ge::ComputeGraph& compute_graph) - : enable_shared_from_this(), - AttrHolder(compute_graph), - impl_(ComGraphMakeSharedAndThrow(*(compute_graph.impl_))) {} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::ComputeGraph(ge::ComputeGraph&& compute_graph) - : enable_shared_from_this(), - AttrHolder(std::move(compute_graph)), - impl_(ComGraphMakeSharedAndThrow(std::move(*(compute_graph.impl_)))) {} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string ComputeGraph::GetName() const { return impl_->GetName(); } - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetName(const std::string &name) { - impl_->SetName(name); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t ComputeGraph::GetAllNodesSize() const { - return GetAllNodesPtr().size(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor ComputeGraph::GetAllNodes() const { - std::vector> subgraphs; - return AllGraphNodes(subgraphs); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector ComputeGraph::GetAllNodesPtr() const { - std::vector> subgraphs; - return AllGraphNodesPtr(subgraphs); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor -ComputeGraph::GetAllNodes(const NodeFilter &node_filter, const GraphFilter &graph_filter) const { - return impl_->GetAllNodes(node_filter, graph_filter, shared_from_this()); -} - -ComputeGraph::Vistor ComputeGraph::AllGraphNodes(std::vector &subgraphs) const { - return impl_->AllGraphNodes(subgraphs, shared_from_this()); -} - -std::vector ComputeGraph::AllGraphNodesPtr(std::vector &subgraphs) const { - return impl_->AllGraphNodesPtr(subgraphs); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -ComputeGraph::Vistor ComputeGraph::GetNodes(const bool is_unknown_shape) const { - return impl_->GetNodes(is_unknown_shape, shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor -ComputeGraph::GetNodes(const bool is_unknown_shape, const NodeFilter &node_filter, - const GraphFilter &graph_filter) const { - return impl_->GetNodes(is_unknown_shape, node_filter, graph_filter, shared_from_this()); -} - -size_t ComputeGraph::GetDirectNodesSize() const { - return impl_->GetDirectNodesSize(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor ComputeGraph::GetDirectNode() const { - return impl_->GetDirectNode(shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector ComputeGraph::GetDirectNodePtr() const { - return impl_->GetDirectNodePtr(); -} - -ComputeGraph::Vistor ComputeGraph::GetInputNodes() const { - return impl_->GetInputNodes(shared_from_this()); -} - -ComputeGraph::Vistor ComputeGraph::GetOutputNodes() const { - return impl_->GetOutputNodes(shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::FindNode(const std::string &name) const { - return impl_->FindNode(name); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -NodePtr ComputeGraph::FindFirstNodeMatchType(const std::string &name) const { - return impl_->FindFirstNodeMatchType(name); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphAttrsAreEqual( - const ComputeGraph &r_graph) const { - return impl_->GraphAttrsAreEqual(*(r_graph.impl_)); -} - -/// Since there may be different input nodes -/// chosen by user in the same graph, special judgment is needed -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::VectorInputNodePtrIsEqual( - const std::vector &left_nodes, const std::vector &right_nodes) const { - return impl_->VectorInputNodePtrIsEqual(left_nodes, right_nodes); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphMembersAreEqual( - const ComputeGraph &r_graph) const { - return impl_->GraphMembersAreEqual(*(r_graph.impl_)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::operator==( - const ComputeGraph &r_compute_graph) const { - return (*impl_) == (*(r_compute_graph.impl_)); -} - -ComputeGraph& ComputeGraph::operator=(ge::ComputeGraph &compute_graph) { - if (&compute_graph == this) { - return *this; - } - AttrHolder::SwapBase(compute_graph); - *impl_ = *(compute_graph.impl_); - return *this; -} - -NodePtr ComputeGraph::AddNodeFront(const NodePtr node) { - return impl_->AddNodeFront(node); -} - -NodePtr ComputeGraph::AddNodeFront(const OpDescPtr &op) { - return impl_->AddNodeFront(op, shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(const NodePtr node) { - return impl_->AddNode(node); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector ComputeGraph::InsertNodes( - const NodePtr &node, const std::vector &insert_ops) { - return impl_->InsertNodes(node, insert_ops, shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::InsertNode( - const NodePtr &node, const OpDescPtr &insert_op) { - return impl_->InsertNode(node, insert_op, shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::InsertNodeBefore( - const NodePtr &node, const OpDescPtr &insert_op) { - return impl_->InsertNodeBefore(node, insert_op, shared_from_this()); -} - -bool ComputeGraph::IsSupportFuse(const std::vector &origin_nodes, std::string &reason_not_support) const { - return impl_->IsSupportFuse(origin_nodes, reason_not_support); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector ComputeGraph::FuseNodeKeepTopo( - const std::vector &ori_nodes, const std::vector &fusion_ops) { - return impl_->FuseNodeKeepTopo(ori_nodes, fusion_ops, shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(const OpDescPtr op) { - return impl_->AddNode(op, shared_from_this()); -} - -NodePtr ComputeGraph::AddNode(const OpDescPtr op, const int64_t id) { // for unserialize. - return impl_->AddNode(op, id, shared_from_this()); -} - -NodePtr ComputeGraph::AddInputNode(const NodePtr node) { - return impl_->AddInputNode(node); -} - -NodePtr ComputeGraph::AddOutputNode(const NodePtr node) { - return AddOutputNodeByIndex(node, 0); -} - -NodePtr ComputeGraph::AddOutputNodeByIndex(const NodePtr node, const int32_t index) { - return impl_->AddOutputNodeByIndex(node, index); -} - -graphStatus ComputeGraph::RemoveConstInput(const NodePtr &node) { - return impl_->RemoveConstInput(node); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::RemoveNode(const NodePtr &node) { - return impl_->RemoveNode(node); -} - -// Used in sub_graph scenes -graphStatus ComputeGraph::RemoveInputNode(const NodePtr &node) { - return impl_->RemoveInputNode(node); -} - -graphStatus ComputeGraph::RemoveOutputNode(const NodePtr &node) { - return impl_->RemoveOutputNode(node); -} - -std::shared_ptr ComputeGraph::AddSubGraph(const std::shared_ptr sub_graph) { - return impl_->AddSubGraph(sub_graph); -} - -graphStatus ComputeGraph::RemoveSubGraph(const std::shared_ptr &sub_graph) { - return impl_->RemoveSubGraph(sub_graph); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -ComputeGraph::AddSubgraph(const std::string &name, const std::shared_ptr &subgraph) { - return impl_->AddSubgraph(name, subgraph); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -ComputeGraph::AddSubgraph(const std::shared_ptr &subgraph) { - if (subgraph == nullptr) { - return GRAPH_PARAM_INVALID; - } - return AddSubgraph(subgraph->GetName(), subgraph); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph(const std::string &name) { - return impl_->RemoveSubgraph(name); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph( - const std::shared_ptr &subgraph) { - if (subgraph != nullptr) { - RemoveSubgraph(subgraph->GetName()); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::shared_ptr ComputeGraph::GetSubgraph( - const std::string &name) const { - return impl_->GetSubgraph(name); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector> -ComputeGraph::GetAllSubgraphs() const { - return impl_->GetAllSubgraphs(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetAllSubgraphs( - const std::vector> &subgraphs) { - return impl_->SetAllSubgraphs(subgraphs); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -const std::map, std::vector> &ComputeGraph::GetShareParamLayer() const { - return impl_->GetShareParamLayer(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetShareParamLayer( - const std::map, std::vector> params_share_map) { - impl_->SetShareParamLayer(params_share_map); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetInputsOrder( - const std::vector &inputs_order) { - impl_->SetInputsOrder(inputs_order); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetGraphOutNodes( - const std::map> out_nodes_map) { - impl_->SetGraphOutNodes(out_nodes_map); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::AppendGraphOutNodes( - const std::map> out_nodes_map) { - impl_->AppendGraphOutNodes(out_nodes_map); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::shared_ptr ComputeGraph::GetParentGraph() { - return impl_->GetParentGraph(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const ComputeGraph *ComputeGraph::GetParentGraphBarePtr() const { - return impl_->GetParentGraphBarePtr(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentGraph( - const std::shared_ptr &parent) { - impl_->SetParentGraph(parent); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::shared_ptr ComputeGraph::GetParentNode() { - return impl_->GetParentNode(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const Node *ComputeGraph::GetParentNodeBarePtr() const { - return impl_->GetParentNodeBarePtr(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetNetOutputNode( - const std::shared_ptr &netoutput_node) { - return impl_->SetNetOutputNode(netoutput_node); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::shared_ptr ComputeGraph::GetOrUpdateNetOutputNode() { - return impl_->GetOrUpdateNetOutputNode(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentNode(const std::shared_ptr &parent) { - return impl_->SetParentNode(parent); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -const std::map> &ComputeGraph::GetGraphOutNodes() const { - return impl_->GetGraphOutNodes(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetOrigGraph(const ComputeGraphPtr orig_graph) { - impl_->SetOrigGraph(orig_graph); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr ComputeGraph::GetOrigGraph(void) { - return impl_->GetOrigGraph(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetOutputSize(const uint32_t size) { - impl_->SetOutputSize(size); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t ComputeGraph::GetOutputSize() const { - return impl_->GetOutputSize(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetInputSize(const uint32_t size) { - impl_->SetInputSize(size); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t ComputeGraph::GetInputSize() const { - return impl_->GetInputSize(); -} - -// false: known shape true: unknow shape -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GetGraphUnknownFlag() const { - bool is_unknown = false; - (void)AttrUtils::GetBool(this, ATTR_NAME_GRAPH_UNKNOWN_FLAG, is_unknown); - return is_unknown; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetGraphUnknownFlag(const bool flag) { - (void)AttrUtils::SetBool(this, ATTR_NAME_GRAPH_UNKNOWN_FLAG, flag); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetNeedIteration(const bool need_iteration) { - impl_->SetNeedIteration(need_iteration); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GetNeedIteration() const { - return impl_->GetNeedIteration(); -} - -/// -/// @brief Update input-mapping -/// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input -/// @return graphStatus -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -ComputeGraph::UpdateInputMapping(const std::map &input_mapping) { - return impl_->UpdateInputMapping(input_mapping); -} - -/// -/// @brief Update output-mapping -/// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output -/// @return graphStatus -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -ComputeGraph::UpdateOutputMapping(const std::map &output_mapping) { - return impl_->UpdateOutputMapping(output_mapping); -} - -graphStatus ComputeGraph::ReorderEventNodes() { - return impl_->ReorderEventNodes(shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertGraphEvents() { - return impl_->InsertGraphEvents(shared_from_this()); -} - -graphStatus ComputeGraph::DFSTopologicalSorting(std::vector &node_vec, - std::map &map_in_edge_num, - std::vector &stack, const bool reverse) { - (void) map_in_edge_num; - (void) stack; - return impl_->DFSTopologicalSorting(node_vec, reverse, shared_from_this()); -} - -graphStatus ComputeGraph::BFSTopologicalSorting(std::vector &node_vec, - std::map &map_in_edge_num, - std::deque &stack) { - (void) map_in_edge_num; - (void) stack; - return impl_->BFSTopologicalSorting(node_vec, false, shared_from_this()); -} - -graphStatus ComputeGraph::CollectBreadthOutNode(const NodePtr &node, std::map &map_in_edge_num, - std::map &breadth_node_map) { - return impl_->CollectBreadthOutNode(node, map_in_edge_num, breadth_node_map); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::TopologicalSorting( - const std::function comp) { - return impl_->TopologicalSorting(comp); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSorting() { - return impl_->TopologicalSorting(shared_from_this(), shared_from_this()); -} - -graphStatus ComputeGraph::TopologicalSortingGraph(const bool dfs_reverse) { - return impl_->TopologicalSortingGraph(shared_from_this(), dfs_reverse); -} - -graphStatus ComputeGraph::SortNodes(std::vector &stack, std::map &map_in_edge_num) { - return impl_->SortNodes(stack, map_in_edge_num, shared_from_this()); -} - -size_t ComputeGraph::GetInEdgeSize(const NodePtr &node) const { - return impl_->GetInEdgeSize(node); -} - -size_t ComputeGraph::GetOutEdgeSize(const NodePtr &node) const { - return impl_->GetOutEdgeSize(node); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::IsValid() const { - return impl_->IsValid(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::InValid() { - impl_->InValid(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Dump() const { - return impl_->Dump(shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Swap(ComputeGraph &graph) { - this->AttrHolder::SwapBase(graph); - impl_->Swap(*(graph.impl_)); - - // Update Node owner. - SetNodesOwner(); - graph.SetNodesOwner(); - - // Update parent graph of 'TOP subgraph'. 'TOP subgraph' refers to the direct subgraph of the root graph. - SetTopParentGraph(); - graph.SetTopParentGraph(); -} - -void ComputeGraph::SetNodesOwner() { - return impl_->SetNodesOwner(shared_from_this()); -} - -void ComputeGraph::SetTopParentGraph() { - return impl_->SetTopParentGraph(shared_from_this()); -} - -void ComputeGraph::EraseFromNodeList(const std::list::iterator position) { - impl_->EraseFromNodeList(position); -} - -void ComputeGraph::ClearNodeList() { - impl_->ClearNodeList(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::IsolateNode(const NodePtr &node) { - return impl_->IsolateNode(node); -} - -graphStatus ComputeGraph::RemoveExtraOutEdge(const NodePtr &node) const { - return impl_->RemoveExtraOutEdge(node); -} - -ProtoAttrMap &ComputeGraph::MutableAttrMap() { - return impl_->MutableAttrMap(); -} - -ConstProtoAttrMap &ComputeGraph::GetAttrMap() const { - return impl_->GetAttrMap(); -} - -const std::map &ComputeGraph::GetAllNodesInfo() const { - return impl_->GetAllNodesInfo(); -} - -void ComputeGraph::SetUserDefOutput(const std::string &output_name) { - impl_->SetUserDefOutput(output_name); -} - -const std::string ComputeGraph::GetOutput() { - return impl_->GetOutput(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetGraphOpName( - const std::map &op_name_map) { - impl_->SetGraphOpName(op_name_map); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -const std::map &ComputeGraph::GetGraphOpName() const { - return impl_->GetGraphOpName(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetAllNodesInfo( - const std::map &nodes) { - impl_->SetAllNodesInfo(nodes); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetGraphOutNodesInfo( - std::vector> &out_nodes_info) { - impl_->SetGraphOutNodesInfo(out_nodes_info); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::AppendGraphOutNodesInfo( - std::vector> &out_nodes_info) { - impl_->AppendGraphOutNodesInfo(out_nodes_info); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -const std::vector> &ComputeGraph::GetGraphOutNodesInfo() const { - return impl_->GetGraphOutNodesInfo(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetGraphTargetNodesInfo( - const std::vector &target_nodes_info) { - impl_->SetGraphTargetNodesInfo(target_nodes_info); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -const std::vector &ComputeGraph::GetGraphTargetNodesInfo() const { - return impl_->GetGraphTargetNodesInfo(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetSessionID(const uint64_t session_id) { - impl_->SetSessionID(session_id); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint64_t ComputeGraph::GetSessionID() const { - return impl_->GetSessionID(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetGraphID(const uint32_t graph_id) { - impl_->SetGraphID(graph_id); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t ComputeGraph::GetGraphID() const { - return impl_->GetGraphID(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SaveDataFormat(const ge::Format data_format) { - impl_->SaveDataFormat(data_format); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ge::Format ComputeGraph::GetDataFormat() const { - return impl_->GetDataFormat(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::IsSummaryGraph() const { - return impl_->IsSummaryGraph(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetSummaryFlag(const bool is_summary_graph) { - impl_->SetSummaryFlag(is_summary_graph); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::ReorderByNodeId() { - impl_->ReorderByNodeId(); -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -graphStatus ComputeGraph::TopologicalSorting(TopoSortingMode topo_sorting_mode) { - return impl_->TopologicalSortingGraph(shared_from_this(), topo_sorting_mode); -} -} // namespace ge diff --git a/graph/normal_graph/compute_graph_impl.h b/graph/normal_graph/compute_graph_impl.h deleted file mode 100644 index c7c82f89f32ed21be2625c7e914ef68ef002e5c0..0000000000000000000000000000000000000000 --- a/graph/normal_graph/compute_graph_impl.h +++ /dev/null @@ -1,314 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_COMPUTE_GRAPH_IMPL_H_ -#define GRAPH_COMPUTE_GRAPH_IMPL_H_ - -#include "graph/compute_graph.h" - -namespace ge { -inline const ge::char_t *GetTopoSortingModeStr(const TopoSortingMode &mode) { - static const ge::char_t *topo_sorting_mode_strs[static_cast(TopoSortingMode::kInvalid) + 1U] - = {"BFS", "DFS", "RDFS", "StableRDFS", "Invalid"}; - if ((mode >= TopoSortingMode::kInvalid) || (mode < TopoSortingMode::kBFS)) { - return topo_sorting_mode_strs[static_cast(TopoSortingMode::kInvalid)]; - } - return topo_sorting_mode_strs[static_cast(mode)]; -} - -enum class WalkStatus { - kNotWalked, - kWalking, - kWalked -}; -class ComputeGraphImpl { - public: - using ConstComputeGraphPtr = std::shared_ptr; - template - using Vistor = RangeVistor>; - - explicit ComputeGraphImpl(const std::string &name); - - ~ComputeGraphImpl() = default; - - std::string GetName() const; - void SetName(const std::string &name); - - size_t GetAllNodesSize(const ConstComputeGraphPtr &compute_graph) const; - Vistor GetAllNodes(const ConstComputeGraphPtr &compute_graph) const; - Vistor GetAllNodes(const NodeFilter &node_filter, - const GraphFilter &graph_filter, - const ConstComputeGraphPtr &compute_graph) const; - Vistor AllGraphNodes(std::vector &subgraphs, - const ConstComputeGraphPtr &compute_graph) const; - std::vector AllGraphNodesPtr(std::vector &subgraphs) const; - Vistor GetNodes(const bool is_unknown_shape, - const ConstComputeGraphPtr &compute_graph) const; - Vistor GetNodes(const bool is_unknown_shape, - const NodeFilter &node_filter, - const GraphFilter &graph_filter, - const ConstComputeGraphPtr &compute_graph) const; - size_t GetDirectNodesSize() const; - Vistor GetDirectNode(const ConstComputeGraphPtr &compute_graph) const; - std::vector GetDirectNodePtr() const; - Vistor GetInputNodes(const ConstComputeGraphPtr &compute_graph) const; - Vistor GetOutputNodes(const ConstComputeGraphPtr &compute_graph) const; - NodePtr FindNode(const std::string &name) const; - NodePtr FindFirstNodeMatchType(const std::string &type) const; - - bool GraphAttrsAreEqual(const ComputeGraphImpl &r_graph) const; - bool VectorInputNodePtrIsEqual(const std::vector &left_nodes, const std::vector &right_nodes) const; - bool GraphMembersAreEqual(const ComputeGraphImpl &r_graph) const; - - bool operator==(const ComputeGraphImpl &r_graph) const; - - NodePtr AddNodeFront(const NodePtr node); - NodePtr AddNodeFront(const OpDescPtr &op, const ComputeGraphPtr &compute_graph); - NodePtr AddNode(const NodePtr node); - NodePtr AddNode(const OpDescPtr op, const ComputeGraphPtr &compute_graph); - NodePtr AddNode(const OpDescPtr op, const int64_t id, const ComputeGraphPtr &compute_graph); - - std::vector InsertNodes(const NodePtr &node, - const std::vector &insert_ops, - const ComputeGraphPtr &compute_graph); - - NodePtr InsertNode(const NodePtr &node, - const OpDescPtr &insert_op, - const ComputeGraphPtr &compute_graph); - - NodePtr InsertNodeBefore(const NodePtr &node, - const OpDescPtr &insert_op, - const ComputeGraphPtr &compute_graph); - - static bool IsSupportFuse(const std::vector &nodes, std::string &reason_not_support) ; - std::vector FuseNodeKeepTopo(const std::vector &ori_nodes, - const std::vector &fusion_ops, - const ComputeGraphPtr &compute_graph); - - NodePtr AddInputNode(const NodePtr node); - NodePtr AddOutputNode(const NodePtr node); - NodePtr AddOutputNodeByIndex(const NodePtr node, const int32_t index); - - graphStatus RemoveConstInput(const NodePtr &node); - graphStatus RemoveNode(const NodePtr &node); - graphStatus RemoveInputNode(const NodePtr &node); - graphStatus RemoveOutputNode(const NodePtr &node); - - std::shared_ptr AddSubGraph(const std::shared_ptr &sub_graph); - graphStatus RemoveSubGraph(const std::shared_ptr &sub_graph); - graphStatus AddSubgraph(const std::string &name, const std::shared_ptr &subgraph); - void RemoveSubgraph(const std::string &name); - - std::shared_ptr GetSubgraph(const std::string &name) const; - std::vector> GetAllSubgraphs() const; - void SetAllSubgraphs(const std::vector> &subgraphs); - - std::shared_ptr GetParentGraph() const; - const ComputeGraph *GetParentGraphBarePtr() const; - void SetParentGraph(const std::shared_ptr &parent); - std::shared_ptr GetParentNode() const; - const Node *GetParentNodeBarePtr() const; - void SetParentNode(const std::shared_ptr &parent); - std::shared_ptr GetOrUpdateNetOutputNode(); - void SetNetOutputNode(const std::shared_ptr &netoutput_node); - const std::map> &GetGraphOutNodes() const { return out_nodes_map_; } - - void SetOrigGraph(const ComputeGraphPtr &orig_graph) { origGraph_ = orig_graph; } - ComputeGraphPtr GetOrigGraph(void) { return origGraph_; } - void SetOutputSize(const uint32_t size) { output_size_ = size; } - uint32_t GetOutputSize() const { return output_size_; } - void SetInputSize(const uint32_t size) { input_size_ = size; } - uint32_t GetInputSize() const { return input_size_; } - - void SetNeedIteration(const bool need_iteration) { need_iteration_ = need_iteration; } - bool GetNeedIteration() const { return need_iteration_; } - - const std::map, std::vector> &GetShareParamLayer() const { - return params_share_map_; - } - void SetShareParamLayer(const std::map, std::vector> ¶ms_share_map) { - params_share_map_ = params_share_map; - } - - void SetInputsOrder(const std::vector &inputs_order) { inputs_order_ = inputs_order; } - void SetGraphOutNodes(const std::map> &out_nodes_map) { - out_nodes_map_ = out_nodes_map; - } - void AppendGraphOutNodes(const std::map> out_nodes_map) { - for (auto &item : out_nodes_map) { - (void)out_nodes_map_.emplace(item.first, item.second); - } - } - - void SetGraphOpName(const std::map &op_name_map) { op_name_map_ = op_name_map; } - const std::map &GetGraphOpName() const { return op_name_map_; } - void SetAllNodesInfo(const std::map &nodes) { all_nodes_infos_ = nodes; } - - void SetGraphOutNodesInfo(const std::vector> &out_nodes_info) { - output_nodes_info_ = out_nodes_info; - } - - void AppendGraphOutNodesInfo(std::vector> &out_nodes_info) { - (void)output_nodes_info_.insert(output_nodes_info_.cend(), out_nodes_info.cbegin(), out_nodes_info.cend()); - } - - const std::vector> &GetGraphOutNodesInfo() const { return output_nodes_info_; } - - void SetGraphTargetNodesInfo(const std::vector &target_nodes_info) { - target_nodes_info_ = target_nodes_info; - } - const std::vector &GetGraphTargetNodesInfo() const { return target_nodes_info_; } - - void SetSessionID(const uint64_t session_id) { session_id_ = session_id; } - uint64_t GetSessionID() const { return session_id_; } - - void SetGraphID(const uint32_t graph_id) { graph_id_ = graph_id; } - uint32_t GetGraphID() const { return graph_id_; } - - void SaveDataFormat(const ge::Format data_format) { data_format_ = data_format; } - ge::Format GetDataFormat() const { return data_format_; } - bool IsSummaryGraph() const { return is_summary_graph_; } - void SetSummaryFlag(const bool is_summary_graph) { is_summary_graph_ = is_summary_graph; } - - graphStatus UpdateInputMapping(const std::map &input_mapping); - graphStatus UpdateOutputMapping(const std::map &output_mapping) const; - graphStatus ReorderEventNodes(const ConstComputeGraphPtr &compute_graph); - graphStatus InsertGraphEvents(const ConstComputeGraphPtr &compute_graph); - - graphStatus DFSTopologicalSorting(std::vector &node_vec, const bool reverse, - const ConstComputeGraphPtr &compute_graph) const; - graphStatus BFSTopologicalSorting(std::vector &node_vec, const bool reverse, - const ConstComputeGraphPtr &compute_graph) const; - /** - * 从模型输出节点开始反向DFS遍历 - * @param node_vec - * @param reverse - * @param compute_graph - * @return - */ - graphStatus RDFSTopologicalSorting(std::vector &node_vec, const bool reverse, - const ConstComputeGraphPtr &compute_graph) const; - /** - * 基于调用此接口之前的原始topo顺序,仅对拓扑错误的节点做部分调整,部分调整的算法RDFS - * @param node_vec - * @param reverse - * @param compute_graph - * @return - */ - graphStatus StableRDFSTopologicalSorting(std::vector &node_vec, const bool reverse, - const ConstComputeGraphPtr &compute_graph) const; - graphStatus CollectBreadthOutNode(const NodePtr &node, std::map &map_in_edge_num, - std::map &breadth_node_map) const; - void TopologicalSorting(const std::function comp); - graphStatus TopologicalSorting(const ComputeGraphPtr &const_graph_ptr, - const ConstComputeGraphPtr &const_compute_graph); - graphStatus TopologicalSortingGraph(const ConstComputeGraphPtr &compute_graph, - const bool dfs_reverse = false); - graphStatus TopologicalSortingGraph(const ConstComputeGraphPtr &compute_graph, - TopoSortingMode topo_sorting_mode); - graphStatus SortNodes(std::vector &stack, std::map &map_in_edge_num, - const ConstComputeGraphPtr &compute_graph) const; - - size_t GetInEdgeSize(const NodePtr &node) const; - size_t GetOutEdgeSize(const NodePtr &node) const; - - bool IsValid() const; - void InValid(); - void Dump(const ConstComputeGraphPtr &graph) const; - void Swap(ComputeGraphImpl &graph); - - void SetNodesOwner(const ComputeGraphPtr &compute_graph); - void SetTopParentGraph(const ComputeGraphPtr &compute_graph); - graphStatus IsolateNode(const NodePtr &node) const; - graphStatus RemoveExtraOutEdge(const NodePtr &node) const; - - ProtoAttrMap &MutableAttrMap(); - ConstProtoAttrMap &GetAttrMap() const; - - const std::map &GetAllNodesInfo() const; - void SetUserDefOutput(const std::string &output_name); - const std::string GetOutput(); - - void EraseFromNodeList(const std::list::iterator &position); - void InsertToNodeList(const std::list::iterator &position, const NodePtr &node); - - void PushBackToNodeList(const NodePtr &node); - - void EmplaceBackToNodeList(const NodePtr &node); - void ClearNodeList(); - void ReorderByNodeId(); - - private: - void inline AddInputDataNode(const NodePtr &node); - ge::NodePtr CreateNodeFromOpDesc(const OpDescPtr &op_desc, - const ComputeGraphPtr &compute_graph, - const int64_t topo_id); - void inline GetAllNodesFromOpdesc(const OpDesc &op_desc, const GraphFilter &graph_filter, - std::deque& candidates, const NodePtr node) const; - void inline GetAllNodesFromOpdesc(std::vector &subgraphs, const OpDesc &op_desc, - std::deque& candidates) const; - void inline GetAllNodesPtrFromOpdesc(std::vector &subgraphs, const OpDesc &op_desc, - std::deque& candidates) const; - - template - void inline GetOutNodesFromAnchor(const AnchorPtr &peer_in_anchor, std::map &map_in_edge_num, - std::vector &out_nodes) const { - const auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); - if (iter != map_in_edge_num.end()) { - --iter->second; - if (iter->second == 0U) { - out_nodes.push_back(peer_in_anchor->GetOwnerNode()); - } - } - } - graphStatus DoTopologicalSorting(const ConstComputeGraphPtr &compute_graph, - TopoSortingMode sorting_mode, - bool dfs_reverse); - private: - friend class ModelSerializeImp; - friend class GraphUtils; - friend class ExecuteGraphAdapter; - std::string name_; - std::list nodes_; - uint32_t graph_id_ = 0U; - AttrStore attrs_; - size_t direct_nodes_size_ = 0UL; - std::map all_nodes_infos_; - std::vector target_nodes_info_; - - std::vector input_nodes_; - std::vector inputs_order_; - uint32_t input_size_ = 1U; - std::map> out_nodes_map_; - uint32_t output_size_ = 1U; - std::vector> output_nodes_info_; - - std::vector> sub_graph_; - std::map> names_to_subgraph_; - std::weak_ptr parent_graph_; - std::weak_ptr parent_node_; - - // the members followed should not in the ComputeGraph class - bool is_valid_flag_; - bool is_summary_graph_ = false; - // Indicates whether it is need iteration - bool need_iteration_ = false; - std::map, std::vector> params_share_map_; - // TaskIdx -> op_name Map - std::map op_name_map_; - uint64_t session_id_ = 0UL; - ge::Format data_format_ = ge::FORMAT_ND; - // Graph Before BFE - ComputeGraphPtr origGraph_; - std::weak_ptr graph_netoutput_; - Node *parent_node_bare_ptr_ = nullptr; - ComputeGraph *parent_graph_bare_ptr_ = nullptr; -}; -} // namespace ge -#endif // GRAPH_COMPUTE_GRAPH_IMPL_H_ diff --git a/graph/normal_graph/ge_tensor.cc b/graph/normal_graph/ge_tensor.cc deleted file mode 100644 index aa6c9c4216f509684cd811752927179ee346e053..0000000000000000000000000000000000000000 --- a/graph/normal_graph/ge_tensor.cc +++ /dev/null @@ -1,1894 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/ge_tensor.h" - -#include -#include -#include -#include "graph/debug/ge_attr_define.h" -#include "debug/ge_util.h" -#include "graph/normal_graph/ge_tensor_impl.h" -#include "graph/ge_attr_value.h" -#include "graph/model_serialize.h" -#include "graph/small_vector.h" -#include "graph/detail/model_serialize_imp.h" -#include "proto/ge_ir.pb.h" -#include "graph/utils/ge_ir_utils.h" -#include "common/util/mem_utils.h" -#include "graph/utils/tensor_utils.h" -#include "graph/utils/math_util.h" -#include "inc/common/checker.h" -#include "graph/attribute_group/attr_group_serialize.h" - -namespace ge { -namespace { -static const size_t PAIR_ELEMENT_SIZE = 2UL; -static const size_t PAIR_ELEMENT_KEY = 0UL; -static const size_t PAIR_ELEMENT_VALUE = 1UL; -const char_t *const kKeyDataTypeSelfDefined = "__tensor_desc_data_type__"; -const std::map kDataTypeMap = { - {DT_UNDEFINED, proto::DT_UNDEFINED}, - {DT_FLOAT, proto::DT_FLOAT}, - {DT_FLOAT16, proto::DT_FLOAT16}, - {DT_INT8, proto::DT_INT8}, - {DT_UINT8, proto::DT_UINT8}, - {DT_INT16, proto::DT_INT16}, - {DT_UINT16, proto::DT_UINT16}, - {DT_INT32, proto::DT_INT32}, - {DT_INT64, proto::DT_INT64}, - {DT_UINT32, proto::DT_UINT32}, - {DT_UINT64, proto::DT_UINT64}, - {DT_BOOL, proto::DT_BOOL}, - {DT_DOUBLE, proto::DT_DOUBLE}, - {DT_DUAL, proto::DT_DUAL}, - {DT_DUAL_SUB_INT8, proto::DT_DUAL_SUB_INT8}, - {DT_DUAL_SUB_UINT8, proto::DT_DUAL_SUB_UINT8}, - {DT_COMPLEX32, proto::DT_COMPLEX32}, - {DT_COMPLEX64, proto::DT_COMPLEX64}, - {DT_COMPLEX128, proto::DT_COMPLEX128}, - {DT_QINT8, proto::DT_QINT8}, - {DT_QINT16, proto::DT_QINT16}, - {DT_QINT32, proto::DT_QINT32}, - {DT_QUINT8, proto::DT_QUINT8}, - {DT_QUINT16, proto::DT_QUINT16}, - {DT_RESOURCE, proto::DT_RESOURCE}, - {DT_STRING_REF, proto::DT_STRING_REF}, - {DT_STRING, proto::DT_STRING}, - {DT_VARIANT, proto::DT_VARIANT}, - {DT_BF16, proto::DT_BF16}, - {DT_INT4, proto::DT_INT4}, - {DT_UINT1, proto::DT_UINT1}, - {DT_INT2, proto::DT_INT2}, - {DT_UINT2, proto::DT_UINT2}, - {DT_HIFLOAT8, proto::DT_HIFLOAT8}, - {DT_FLOAT8_E5M2, proto::DT_FLOAT8_E5M2}, - {DT_FLOAT8_E4M3FN, proto::DT_FLOAT8_E4M3FN}, - {DT_FLOAT8_E8M0, proto::DT_FLOAT8_E8M0}, - {DT_FLOAT6_E3M2, proto::DT_FLOAT6_E3M2}, - {DT_FLOAT6_E2M3, proto::DT_FLOAT6_E2M3}, - {DT_FLOAT4_E2M1, proto::DT_FLOAT4_E2M1}, - {DT_FLOAT4_E1M2, proto::DT_FLOAT4_E1M2}, -}; - -const std::map kDataTypeSelfDefinedMap = { - {DT_DUAL, 13}, {DT_DUAL_SUB_INT8, 14}, {DT_DUAL_SUB_UINT8, 15}, {DT_COMPLEX64, 16}, {DT_COMPLEX128, 17}, - {DT_QINT8, 18}, {DT_QINT16, 19}, {DT_QINT32, 20}, {DT_QUINT8, 21}, {DT_QUINT16, 22}, - {DT_COMPLEX32, 33}, -}; - -const std::map kDeviceToStrMap = { - {NPU, "NPU"}, {CPU, "CPU"}, -}; - -const std::map kStrToDeviceMap = { - {"NPU", NPU}, {"CPU", CPU} -}; - -const std::string TENSOR_UTILS_SIZE = "size"; -const std::string TENSOR_UTILS_WEIGHT_SIZE = "weight_size"; -const std::string TENSOR_UTILS_REUSE_INPUT = "reuse_input"; -const std::string TENSOR_UTILS_OUTPUT_TENSOR = "output_tensor"; -const std::string TENSOR_UTILS_DEVICE_TYPE = "device_type"; -const std::string TENSOR_UTILS_INPUT_TENSOR = "input_tensor"; -const std::string TENSOR_UTILS_REAL_DIM_CNT = "real_dim_cnt"; -const std::string TENSOR_UTILS_REUSE_INPUT_INDEX = "reuse_input_index"; -const std::string TENSOR_UTILS_DATA_OFFSET = "data_offset"; -const std::string TENSOR_UTILS_CMPS_SIZE = "cmps_size"; -const std::string TENSOR_UTILS_CMPS_TAB = "cmps_tab"; -const std::string TENSOR_UTILS_CMPS_TAB_OFFSET = "cmps_tab_offset"; -const std::string TENSOR_UTILS_CMPSINFO = "cmps_info"; -const std::string TENSOR_UTILS_ALLOFFSET_QUANTIZE_INFO = "alloffset_quantize_info"; -const std::string TENSOR_UTILS_RC = "rc"; -const std::string TENSOR_UTILS_ORIGIN_SHAPE = "origin_shape"; -const std::string TENSOR_UTILS_ORIGIN_SHAPE_INITIALIZED = "origin_shape_initialized"; -const std::string TENSOR_UTILS_ORIGIN_FORMAT = "origin_format"; -const std::string TENSOR_UTILS_ORIGIN_FORMAT_INT = "origin_format_for_int"; -const std::string TENSOR_UTILS_FORMAT = "format_for_int"; -const std::string TENSOR_UTILS_ORIGIN_DATA_TYPE = "origin_data_type"; -const std::string TENSOR_UTILS_SHAPE_RANGE = "shape_range"; -const std::string TENSOR_UTILS_ORIGIN_SHAPE_RANGE = "origin_shape_range"; -const std::string TENSOR_UTILS_VALUE_RANGE = "value_range"; -const std::string TENSOR_UTILS_REF_PORT_INDEX = "ref_port_index"; -const std::string TENSOR_UTILS_PLACEMENT = "placement"; -} - -void GeTensorSerializeUtils::GeShapeAsProto(const GeShape &shape, proto::ShapeDef *proto) { - if (proto != nullptr) { - proto->clear_dim(); - for (auto dim : shape.GetDims()) { - proto->add_dim(dim); - } - } -} -void GeTensorSerializeUtils::GeTensorDescAsProto(const GeTensorDescImpl &desc, proto::TensorDescriptor *proto) { - if (proto != nullptr) { - // serialize extension tensor meta data - proto->set_size(desc.ext_meta_.GetSize()); - proto->set_weight_size(desc.ext_meta_.GetWeightSize()); - proto->set_reuse_input(desc.ext_meta_.GetReuseInput()); - proto->set_output_tensor(desc.ext_meta_.GetOutputTensor()); - if (kDeviceToStrMap.find(desc.ext_meta_.GetDeviceType()) != kDeviceToStrMap.end()) { - proto->set_device_type(kDeviceToStrMap.at(desc.ext_meta_.GetDeviceType())); - } - proto->set_input_tensor(desc.ext_meta_.GetInputTensor()); - proto->set_real_dim_cnt(static_cast(desc.ext_meta_.GetRealDimCnt())); - proto->set_reuse_input_index(static_cast(desc.ext_meta_.GetReuseInputIndex())); - proto->set_data_offset(desc.ext_meta_.GetDataOffset()); - proto->set_cmps_size(desc.ext_meta_.GetCmpsSize()); - proto->set_cmps_tab(desc.ext_meta_.GetCmpsTab()); - proto->set_cmps_tab_offset(desc.ext_meta_.GetCmpsTabOffset()); - - // serialize attributes - if (!ModelSerializeImp::SerializeAllAttrsFromAnyMap(desc.attrs_.GetAllAttrs(), proto->mutable_attr())) { - GELOGE(GRAPH_FAILED, "GeTensorDesc attr serialize failed."); - return; - } - - if (!desc.attrs_.GetAttrsGroupPtr().empty()) { - // serialize Attr Group - if (AttrGroupSerialize::SerializeAllAttr(*(proto->mutable_attr_groups()), desc.attrs_) != ge::SUCCESS) { - GELOGE(GRAPH_FAILED, "GeTensorDesc attr serialize failed."); - return; - } - } - - // serialize member object - (*proto->mutable_attr())[TENSOR_UTILS_ORIGIN_FORMAT].set_s( - TypeUtils::FormatToSerialString(desc.GetOriginFormat())); - - (*proto->mutable_attr())[TENSOR_UTILS_ORIGIN_FORMAT_INT].set_i(desc.GetOriginFormat()); - - (*proto->mutable_attr())[TENSOR_UTILS_FORMAT].set_i(desc.GetFormat()); - - if (desc.GetOriginDataType() != DT_UNDEFINED) { - (*proto->mutable_attr())[TENSOR_UTILS_ORIGIN_DATA_TYPE].set_s( - TypeUtils::DataTypeToSerialString(desc.GetOriginDataType())); - } - - const bool is_origin_shape_init = desc.ext_meta_.IsOriginShapeInited(); - (*proto->mutable_attr())[TENSOR_UTILS_ORIGIN_SHAPE_INITIALIZED].set_b(is_origin_shape_init); - if (is_origin_shape_init) { - auto const origin_shape_proto_list = (*proto->mutable_attr())[TENSOR_UTILS_ORIGIN_SHAPE].mutable_list(); - origin_shape_proto_list->clear_i(); - for (auto const dim : desc.OriginShapeReference().GetDims()) { - origin_shape_proto_list->add_i(dim); - } - origin_shape_proto_list->set_val_type(proto::AttrDef::ListValue::VT_LIST_INT); - } - - auto const iter = kDataTypeMap.find(desc.GetDataType()); - if (iter != kDataTypeMap.end()) { - proto->set_dtype(iter->second); - } else { // maybe custom data type - proto->set_dtype(kDataTypeMap.at(DT_UNDEFINED)); - } - proto->set_layout(TypeUtils::FormatToSerialString(desc.GetFormat())); - GeTensorSerializeUtils::GeShapeAsProto(desc.ShapeReference(), proto->mutable_shape()); - } -} -void GeTensorSerializeUtils::GeTensorDescAsProto(const GeTensorDesc &desc, proto::TensorDescriptor *proto) { - GeTensorSerializeUtils::GeTensorDescAsProto(*desc.impl_, proto); -} -void GeTensorSerializeUtils::GeTensorAsProto(const GeTensorImpl &tensor, proto::TensorDef *proto) { - if (tensor.tensor_def_.protoOwner_ != nullptr) { - if (tensor.tensor_def_.protoMsg_ != nullptr) { - *proto = *tensor.tensor_def_.protoMsg_; - GeTensorDescAsProto(tensor.desc_, proto->mutable_desc()); - } - } else { - if ((tensor.tensor_data_.impl_ != nullptr) && (tensor.tensor_data_.impl_->tensor_descriptor_ != nullptr)) { - GeTensorDescAsProto(*tensor.tensor_data_.impl_->tensor_descriptor_, proto->mutable_desc()); - } - proto->set_data(tensor.tensor_data_.data(), tensor.tensor_data_.size()); - } -} -void GeTensorSerializeUtils::GeTensorAsProto(const GeTensor &tensor, proto::TensorDef *proto) { - GeTensorSerializeUtils::GeTensorAsProto(*tensor.impl_, proto); -} - -void GeTensorSerializeUtils::AssembleGeShapeFromProto(const proto::ShapeDef *proto, GeShape &shape) { - if (proto != nullptr) { - shape = GeShape(nullptr, const_cast(proto)); - } -} -void GeTensorSerializeUtils::AssembleGeTensorDescFromProto( - const proto::TensorDescriptor *const proto, GeTensorDesc &desc) { - if (proto != nullptr) { - desc = GeTensorDesc(const_cast(proto)); - } -} -void GeTensorSerializeUtils::AssembleGeTensorFromProto(const proto::TensorDef *proto, GeTensor &tensor) { - if (proto != nullptr) { - tensor = GeTensor(nullptr, const_cast(proto)); - } -} - -void GeTensorSerializeUtils::NormalizeGeTensorDescProto(proto::TensorDescriptor *proto) { - if (proto == nullptr) { - return; - } - auto &attr_map = *(proto->mutable_attr()); - auto iter = attr_map.find(TENSOR_UTILS_SIZE); - if (iter != attr_map.end()) { - proto->set_size(iter->second.i()); - } - iter = attr_map.find(TENSOR_UTILS_WEIGHT_SIZE); - if (attr_map.end() != iter) { - proto->set_weight_size(iter->second.i()); - } - iter = attr_map.find(TENSOR_UTILS_REUSE_INPUT); - if (attr_map.end() != iter) { - proto->set_reuse_input(iter->second.b()); - } - iter = attr_map.find(TENSOR_UTILS_OUTPUT_TENSOR); - if (attr_map.end() != iter) { - proto->set_output_tensor(iter->second.b()); - } - iter = attr_map.find(TENSOR_UTILS_DEVICE_TYPE); - if (attr_map.end() != iter) { - proto->set_device_type(iter->second.s()); - } - iter = attr_map.find(TENSOR_UTILS_INPUT_TENSOR); - if (attr_map.end() != iter) { - proto->set_input_tensor(iter->second.b()); - } - iter = attr_map.find(TENSOR_UTILS_REAL_DIM_CNT); - if (attr_map.end() != iter) { - proto->set_real_dim_cnt(iter->second.i()); - } - iter = attr_map.find(TENSOR_UTILS_REUSE_INPUT_INDEX); - if (attr_map.end() != iter) { - proto->set_reuse_input_index(iter->second.i()); - } - iter = attr_map.find(TENSOR_UTILS_DATA_OFFSET); - if (attr_map.end() != iter) { - proto->set_data_offset(iter->second.i()); - } - iter = attr_map.find(TENSOR_UTILS_CMPS_SIZE); - if (attr_map.end() != iter) { - proto->set_cmps_size(iter->second.i()); - } - iter = attr_map.find(TENSOR_UTILS_CMPS_TAB); - if (attr_map.end() != iter) { - proto->set_cmps_tab(iter->second.s()); - } - iter = attr_map.find(TENSOR_UTILS_CMPS_TAB_OFFSET); - if (attr_map.end() != iter) { - proto->set_cmps_tab_offset(iter->second.i()); - } -} - -void GeTensorSerializeUtils::GetShapeFromDescProto(const proto::TensorDescriptor *const proto, GeShape &shape) { - if (proto == nullptr) { - return; - } - shape.SetDimNum(static_cast(proto->shape().dim_size())); - size_t i = 0U; - for (auto const dim : proto->shape().dim()) { - (void)shape.SetDim(i++, dim); - } -} - -void GeTensorSerializeUtils::GetOriginShapeFromDescProto(const proto::TensorDescriptor *const proto, GeShape &shape) { - if (proto == nullptr) { - return; - } - auto &attrs = proto->attr(); - auto const iter = attrs.find(TENSOR_UTILS_ORIGIN_SHAPE); - if (iter != attrs.end()) { - shape.SetDimNum(static_cast(iter->second.list().i_size())); - size_t i = 0U; - for (auto const dim : iter->second.list().i()) { - (void)shape.SetDim(i++, dim); - } - } -} - -void GeTensorSerializeUtils::GetDtypeFromDescProto(const proto::TensorDescriptor *const proto, DataType &dtype) { - if (proto == nullptr) { - return; - } - dtype = DT_UNDEFINED; - auto &attrs = proto->attr(); - auto const iter = attrs.find(kKeyDataTypeSelfDefined); - if (iter == attrs.end()) { - auto const proto_dtype = proto->dtype(); - auto const founded = std::find_if( - kDataTypeMap.begin(), kDataTypeMap.end(), - [proto_dtype](const std::pair &item) { - return item.second == proto_dtype; - }); - if (founded != kDataTypeMap.end()) { - dtype = founded->first; - return; - } - } else { // Custom defined data type set - const int64_t data_type_proto = iter->second.i(); - auto const founded = std::find_if(kDataTypeSelfDefinedMap.begin(), kDataTypeSelfDefinedMap.end(), - [data_type_proto](const std::pair &item) { - return item.second == data_type_proto; - }); - if (founded != kDataTypeSelfDefinedMap.end()) { - dtype = founded->first; - return; - } - } -} - -void GeTensorSerializeUtils::GetOriginDtypeFromDescProto(const proto::TensorDescriptor *const proto, DataType &dtype) { - if (proto == nullptr) { - return; - } - auto &attrs = proto->attr(); - auto const iter = attrs.find(TENSOR_UTILS_ORIGIN_DATA_TYPE); - if (iter != attrs.end()) { - dtype = TypeUtils::SerialStringToDataType(iter->second.s()); - } -} - -void GeTensorSerializeUtils::GetFormatFromDescProto(const proto::TensorDescriptor *const proto, Format &format) { - if (proto == nullptr) { - return; - } - - auto &attrs = proto->attr(); - auto const attr_iter = attrs.find(TENSOR_UTILS_FORMAT); - if (attr_iter != attrs.end()) { - format = static_cast(attr_iter->second.i()); - return; - } - - format = TypeUtils::SerialStringToFormat(proto->layout()); -} - -void GeTensorSerializeUtils::GetOriginFormatFromDescProto(const proto::TensorDescriptor *const proto, Format &format) { - if (proto == nullptr) { - return; - } - auto &attrs = proto->attr(); - - auto const attr_iter = attrs.find(TENSOR_UTILS_ORIGIN_FORMAT_INT); - if (attr_iter != attrs.end()) { - format = static_cast(attr_iter->second.i()); - return; - } - - auto const iter = attrs.find(TENSOR_UTILS_ORIGIN_FORMAT); - if (iter != attrs.end()) { - format = TypeUtils::SerialStringToFormat(iter->second.s()); - } -} - -class GeShapeImpl { - using DimsType = SmallVector; - public: - GeShapeImpl() = default; - ~GeShapeImpl() = default; - explicit GeShapeImpl(const std::vector &dims); - explicit GeShapeImpl(proto::ShapeDef *const proto_msg); - - void SetDimNum(const size_t dim_num); - void AppendDim(const int64_t dim_size); - bool IsUnknownDimNum() const; - void SetIsUnknownDimNum(); - size_t GetDimNum() const; - int64_t GetDim(const size_t idx) const; - graphStatus SetDim(const size_t idx, const int64_t value); - std::vector ShapeImplGetDims() const; - const DimsType &ShapeImplGetMutableDims() const; - std::string ShapeImplToString() const; - int64_t GetShapeSize() const; - bool IsUnknownShape() const; - bool IsScalar() const; - bool IsEmptyTensor() const; - - bool operator==(const GeShapeImpl &other) const; - -private: - DimsType dims_; - friend class GeTensorDesc; -}; - -// Default -GeShapeImpl::GeShapeImpl(const std::vector &dims) { - dims_.resize(dims.size()); - (void)std::copy(dims.begin(), dims.end(), dims_.begin()); -} - -void GeShapeImpl::SetDimNum(const size_t dim_num) { - dims_.resize(dim_num, UNKNOWN_DIM); -} - -void GeShapeImpl::AppendDim(const int64_t dim_size) { - dims_.push_back(dim_size); -} - -bool GeShapeImpl::IsUnknownDimNum() const { - return (dims_.size() == 1UL) && (dims_[0UL] == UNKNOWN_DIM_NUM); -} - -void GeShapeImpl::SetIsUnknownDimNum() { - dims_.resize(1UL, UNKNOWN_DIM_NUM); - dims_[0UL] = UNKNOWN_DIM_NUM; -} - -size_t GeShapeImpl::GetDimNum() const { - if (IsUnknownDimNum()) { - GELOGI("Dim num is unknown, return 0U."); - return 0UL; - } - return dims_.size(); -} - -int64_t GeShapeImpl::GetDim(const size_t idx) const { - if (idx < dims_.size()) { - return dims_[idx]; - } else { - return 0; - } -} - -graphStatus GeShapeImpl::SetDim(const size_t idx, const int64_t value) { - if (idx < dims_.size()) { - dims_[idx] = value; - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -std::vector GeShapeImpl::ShapeImplGetDims() const { - std::vector dims; - dims.resize(dims_.size()); - (void)std::copy(dims_.begin(), dims_.end(), dims.begin()); - - return dims; -} - -const SmallVector &GeShapeImpl::ShapeImplGetMutableDims() const { - return dims_; -} - -std::string GeShapeImpl::ShapeImplToString() const { - if (dims_.empty()) { - return ""; - } - - std::stringstream ss; - ss << dims_[0UL]; - for (size_t i = 1UL; i < dims_.size(); i++) { - ss << "," << dims_[i]; - } - return ss.str(); -} - -int64_t GeShapeImpl::GetShapeSize() const { - if (dims_.empty()) { - return 0; - } - int64_t shape_size = 1; - for (auto const dim : dims_) { - if ((dim == UNKNOWN_DIM) || (dim == UNKNOWN_DIM_NUM) || (dim < 0)) { - return -1; - } else if (dim == 0) { - return 0; - } else { - if (shape_size > (INT64_MAX / dim)) { - return -1; - } - shape_size *= dim; - } - } - return shape_size; -} - -bool GeShapeImpl::IsUnknownShape() const { - return std::any_of(dims_.begin(), dims_.end(), [](const int64_t &dim) { - return (dim == UNKNOWN_DIM) || (dim == UNKNOWN_DIM_NUM) || (dim < 0); - }); -} - -bool GeShapeImpl::IsScalar() const { - return dims_.empty(); -} - -bool GeShapeImpl::IsEmptyTensor() const { - for (const auto &dim : dims_) { - if (dim == 0) { - return true; - } - } - return false; -} - -GeShapeImpl::GeShapeImpl(proto::ShapeDef *const proto_msg) { - if (proto_msg != nullptr) { - const auto &dims = *proto_msg->mutable_dim(); - - dims_.resize(static_cast(dims.size())); - (void)std::copy(dims.begin(), dims.end(), dims_.begin()); - } -} - -bool GeShapeImpl::operator==(const GeShapeImpl &other) const { - return this->ShapeImplGetDims() == other.ShapeImplGetDims(); -} - -GeShape::GeShape() : impl_(MakeShared()) {} -GeShape::GeShape(std::vector s) : impl_(MakeShared(std::move(s))) {} -GeShape::GeShape(const GeShape &other) : impl_(MakeShared(*(other.impl_))) {} -GeShape::GeShape(GeShape &&other) : impl_(MakeShared(std::move(*(other.impl_)))) {} -GeShape::GeShape(const ProtoMsgOwner &proto_owner, proto::ShapeDef *const proto_msg) - : impl_(MakeShared(proto_msg)) { - (void)proto_owner; -} -GeShape::~GeShape() = default; - -size_t GeShape::GetDimNum() const { - return impl_->GetDimNum(); -} - -void GeShape::SetDimNum(const size_t dim_num) { - impl_->SetDimNum(dim_num); -} - -void GeShape::AppendDim(const int64_t dim_size) { - impl_->AppendDim(dim_size); -} - -bool GeShape::IsUnknownDimNum() const { - return impl_->IsUnknownDimNum(); -} - -void GeShape::SetIsUnknownDimNum() { - impl_->SetIsUnknownDimNum(); -} - -int64_t GeShape::GetDim(const size_t idx) const { - return impl_->GetDim(idx); -} - -graphStatus GeShape::SetDim(const size_t idx, const int64_t value) { - return impl_->SetDim(idx, value); -} - -std::vector GeShape::GetDims() const { - return impl_->ShapeImplGetDims(); -} - -const SmallVector &GeShape::GetMutableDims() const { - return impl_->ShapeImplGetMutableDims(); -} - -std::string GeShape::ToString() const { - return impl_->ShapeImplToString(); -} - -int64_t GeShape::GetShapeSize() const { - return impl_->GetShapeSize(); -} - -bool GeShape::IsUnknownShape() const { - return impl_->IsUnknownShape(); -} - -bool GeShape::IsScalar() const { - return impl_->IsScalar(); -} - -bool GeShape::IsEmptyTensor() const { - return impl_->IsEmptyTensor(); -} - -GeShape &GeShape::operator=(const GeShape &other) { - if (&other != this) { - *impl_ = *(other.impl_); - } - return *this; -} - -GeShape &GeShape::operator=(GeShape &&other) { - if (&other != this) { - impl_ = other.impl_; - } - return *this; -} - -bool GeShape::operator==(const GeShape &other) const { - return (*impl_) == (*(other.impl_)); -} - -GeTensorDescImpl::GeTensorDescImpl(const GeShape &shape, const Format format, const DataType dt) : GeTensorDescImpl() { - SetFormat(format); - SetDataType(dt); - shape_ = shape; -} - -GeTensorDescImpl::GeTensorDescImpl(proto::TensorDescriptor *const proto_msg) - : GeTensorDescImpl() { - if (proto_msg == nullptr) { - GELOGE(INTERNAL_ERROR, "Try assemble ge tensor desc from nullptr proto"); - return; - } - // normalize the input TensorDescriptor,A metadata information maybe stored in different fields of TensorDescriptor, - // This function needs to prioritize and determine the final metadata information used. - // After standardization, the direct member field on TensorDescriptor is always valid - GeTensorSerializeUtils::NormalizeGeTensorDescProto(proto_msg); - - // store high frequency attributes to member field - GeTensorSerializeUtils::GetOriginFormatFromDescProto(proto_msg, origin_format_); - GeTensorSerializeUtils::GetOriginDtypeFromDescProto(proto_msg, origin_dtype_); - GeTensorSerializeUtils::GetOriginShapeFromDescProto(proto_msg, origin_shape_); - - GeTensorSerializeUtils::GetFormatFromDescProto(proto_msg, format_); - GeTensorSerializeUtils::GetDtypeFromDescProto(proto_msg, dtype_); - GeTensorSerializeUtils::GetShapeFromDescProto(proto_msg, shape_); - - // get extension tensor desc metadata - ext_meta_.SetSize(proto_msg->size()); - ext_meta_.SetWeightSize(proto_msg->weight_size()); - ext_meta_.SetReuseInput(proto_msg->reuse_input()); - ext_meta_.SetOutputTensor(proto_msg->output_tensor()); - if (kStrToDeviceMap.find(proto_msg->device_type()) != kStrToDeviceMap.end()) { - ext_meta_.SetDeviceType(kStrToDeviceMap.at(proto_msg->device_type())); - } - ext_meta_.SetInputTensor(proto_msg->input_tensor()); - if (IntegerChecker::Compat(proto_msg->real_dim_cnt())) { - ext_meta_.SetRealDimCnt(static_cast(proto_msg->real_dim_cnt())); - } - if (IntegerChecker::Compat(proto_msg->reuse_input_index())) { - ext_meta_.SetReuseInputIndex(static_cast(proto_msg->reuse_input_index())); - } - ext_meta_.SetDataOffset(proto_msg->data_offset()); - ext_meta_.SetCmpsSize(proto_msg->cmps_size()); - ext_meta_.SetCmpsTab(proto_msg->cmps_tab()); - ext_meta_.SetCmpsTabOffset(proto_msg->cmps_tab_offset()); - - auto &attr_map = *(proto_msg->mutable_attr()); - const auto iter = attr_map.find(TENSOR_UTILS_ORIGIN_SHAPE_INITIALIZED); - if (iter != attr_map.end()) { - ext_meta_.SetOriginShapeInited(iter->second.b()); - } - - // note that we deserialize attributes in implement of GeTensor constructor -} - -void GeTensorDescImpl::SetDataType(const DataType dtype) { - dtype_ = dtype; -} - -void GeTensorDescImpl::SetOriginDataType(const DataType dtype) { - origin_dtype_ = dtype; -} - -DataType GeTensorDescImpl::GetOriginDataType() const { - return origin_dtype_; -} - -void GeTensorDescImpl::SetFormat(const Format format) { - format_ = format; -} - -void GeTensorDescImpl::SetOriginFormat(const Format format) { - origin_format_ = format; -} - -Format GeTensorDescImpl::GetOriginFormat() const { - return origin_format_; -} - -GeShape &GeTensorDescImpl::ShapeReference() const { - return shape_; -} - -GeShape &GeTensorDescImpl::OriginShapeReference() const { - return origin_shape_; -} - -bool GeTensorDescImpl::GeTensorDescAttrsAreEqual(const GeTensorDescImpl &other) const { - // The definition of attribute equality remains unchanged - return ((shape_ == other.shape_) && - (dtype_ == other.dtype_) && - (format_ == other.format_) && - (ext_meta_ == other.ext_meta_)); -} - -bool GeTensorDescImpl::operator==(const GeTensorDescImpl &other) const { - // The definition of attribute equality remains unchanged - return (origin_shape_ == other.origin_shape_) && (origin_format_ == other.origin_format_) && - (origin_dtype_ == other.origin_dtype_) && (GeTensorDescAttrsAreEqual(other)); -} - -ProtoAttrMap &GeTensorDescImpl::MutableAttrMap() { - return attrs_; -} - -ConstProtoAttrMap &GeTensorDescImpl::GetAttrMap() const { - return attrs_; -} - -void GeTensorDescImpl::SetShape(GeShape &shape) const { - ShapeReference() = std::move(shape); -} - -Format GeTensorDescImpl::GetFormat() const { - return format_; -} - -void GeTensorDescImpl::SetName(const std::string &name) { - ext_meta_.SetName(name); -} - -const std::string GeTensorDescImpl::GetName() const { - return ext_meta_.GetName(); -} - -DataType GeTensorDescImpl::GetDataType() const { - return dtype_; -} - -std::string GeTensorDescImpl::ExtMeta::GetDeviceTypeStr() const { - auto const iter = kDeviceToStrMap.find(device_type); - if (iter != kDeviceToStrMap.end()) { - return iter->second; - } - const static std::string kDefaultTypeString{"NPU"}; - return kDefaultTypeString; -} - -GeTensorDesc::GeTensorDesc() : AttrHolder(), - impl_(ComGraphMakeSharedAndThrow()) {} - -// Default -GeTensorDesc::GeTensorDesc(const GeShape &shape, const Format format, const DataType dt) : AttrHolder(), - impl_(ComGraphMakeSharedAndThrow(shape, format, dt)) {} - -// Default -GeTensorDesc::GeTensorDesc(const GeTensorDesc &desc) : AttrHolder(desc), - impl_(ComGraphMakeSharedAndThrow(*(desc.impl_))) {} - -// Default -GeTensorDesc::GeTensorDesc(GeTensorDesc &&desc) : AttrHolder(desc), impl_(desc.impl_) {} - -GeTensorDesc::~GeTensorDesc() = default; - -GeTensorDesc::GeTensorDesc(proto::TensorDescriptor *const proto_msg) - : AttrHolder(), impl_(ComGraphMakeSharedAndThrow(proto_msg)) { - if (proto_msg != nullptr) { - if (!ModelSerializeImp::DeserializeAllAttrsToAttrHolder(proto_msg->attr(), this)) { - GELOGW("GeTensorDesc attr deserialize failed."); - } - if (AttrGroupSerialize::DeserializeAllAttr(proto_msg->attr_groups(), this) != ge::SUCCESS) { - GELOGW("GeTensorDesc attr group deserialize failed."); - } - } -} - -bool GeTensorDesc::GeTensorDescAttrsAreEqual(const GeTensorDesc &r_ge_tensor_desc) const { - return impl_->GeTensorDescAttrsAreEqual(*(r_ge_tensor_desc.impl_)); -} - -bool GeTensorDesc::operator==(const GeTensorDesc &r_ge_tensor_desc) const { - return (*impl_) == (*(r_ge_tensor_desc.impl_)); -} - -GeShape &GeTensorDesc::ShapeReference() const { - return impl_->ShapeReference(); -} - -GeShape &GeTensorDesc::OriginShapeReference() const { - return impl_->OriginShapeReference(); -} - -ProtoAttrMap &GeTensorDesc::MutableAttrMap() { - return impl_->MutableAttrMap(); -} - -ConstProtoAttrMap &GeTensorDesc::GetAttrMap() const { - return impl_->GetAttrMap(); -} - -void GeTensorDesc::Update(const GeShape &shape, const Format format, const DataType dt) { - ShapeReference() = shape; - SetFormat(format); - SetDataType(dt); -} -const GeShape &GeTensorDesc::GetShape() const { return ShapeReference(); } - -GeShape &GeTensorDesc::MutableShape() { return ShapeReference(); } - -void GeTensorDesc::SetShape(const GeShape &shape) { ShapeReference() = shape; } - -void GeTensorDesc::SetShape(GeShape &&shape) { ShapeReference() = std::move(shape); } - -// set shape with -2, it stand for unknown shape -void GeTensorDesc::SetUnknownDimNumShape() { SetShape(GeShape({UNKNOWN_DIM_NUM})); } - -// for unknown shape -graphStatus GeTensorDesc::SetValueRange(const std::vector> &range) { - std::vector> value_range; - for (const auto &ele : range) { - value_range.emplace_back(std::vector({ele.first, ele.second})); - } - auto const ret = AttrUtils::SetListListInt(this, TENSOR_UTILS_VALUE_RANGE, value_range); - return ret ? GRAPH_SUCCESS : GRAPH_FAILED; -} - -graphStatus GeTensorDesc::GetValueRange(std::vector> &range) const { - std::vector> value_range; - (void) AttrUtils::GetListListInt(this, TENSOR_UTILS_VALUE_RANGE, value_range); - - for (const auto &ele : value_range) { - // here must be only two elemenet because pair - if (ele.size() != PAIR_ELEMENT_SIZE) { - REPORT_INNER_ERR_MSG("E18888", "value_range must contain only 2 value but really is %zu", ele.size()); - GELOGE(GRAPH_FAILED, "[Check][Param] value_range must contain only 2 value but really is %zu", ele.size()); - return GRAPH_FAILED; - } - range.emplace_back(std::make_pair(ele[PAIR_ELEMENT_KEY], ele[PAIR_ELEMENT_VALUE])); - } - - return GRAPH_SUCCESS; -} - -graphStatus GeTensorDesc::SetShapeRange(const std::vector> &range) { - std::vector> shape_range; - for (const auto &ele : range) { - shape_range.emplace_back(std::vector({ele.first, ele.second})); - } - auto const ret = AttrUtils::SetListListInt(this, TENSOR_UTILS_SHAPE_RANGE, shape_range); - return ret ? GRAPH_SUCCESS : GRAPH_FAILED; -} - -graphStatus GeTensorDesc::SetOriginShapeRange(const std::vector> &range) { - std::vector> origin_shape_range; - for (const auto &ele : range) { - origin_shape_range.emplace_back(std::vector({ele.first, ele.second})); - } - auto const ret = AttrUtils::SetListListInt(this, TENSOR_UTILS_ORIGIN_SHAPE_RANGE, origin_shape_range); - return ret ? GRAPH_SUCCESS : GRAPH_FAILED; -} - -graphStatus GeTensorDesc::GetShapeRange(std::vector> &range) const { - std::vector> shape_range; - (void)AttrUtils::GetListListInt(this, TENSOR_UTILS_SHAPE_RANGE, shape_range); - - for (const auto &ele : shape_range) { - // here must be only two elemenet because pair - if (ele.size() != PAIR_ELEMENT_SIZE) { - REPORT_INNER_ERR_MSG("E18888", "shape_range must contain only 2 value but really is %zu", ele.size()); - GELOGE(GRAPH_FAILED, "[Check][Param] shape_range must contain only 2 value but really is %zu", ele.size()); - return GRAPH_FAILED; - } - std::pair pair({ele[PAIR_ELEMENT_KEY], ele[PAIR_ELEMENT_VALUE]}); - range.emplace_back(pair); - } - - return GRAPH_SUCCESS; -} - -graphStatus GeTensorDesc::GetOriginShapeRange(std::vector> &range) const { - std::vector> origin_shape_range; - (void)AttrUtils::GetListListInt(this, TENSOR_UTILS_ORIGIN_SHAPE_RANGE, origin_shape_range); - - for (const auto &ele : origin_shape_range) { - // here must be only two elemenet because pair - if (ele.size() != PAIR_ELEMENT_SIZE) { - REPORT_INNER_ERR_MSG("E18888", "origin_shape_range must contain only 2 value but really is %zu", ele.size()); - GELOGE(GRAPH_FAILED, "[Check][Param] origin_shape_range must contain only 2 value but really is %zu", ele.size()); - return GRAPH_FAILED; - } - std::pair pair({ele[PAIR_ELEMENT_KEY], ele[PAIR_ELEMENT_VALUE]}); - range.emplace_back(pair); - } - - return GRAPH_SUCCESS; -} - -const GeShape &GeTensorDesc::GetOriginShape() const { - return impl_->OriginShapeReference(); -} - -GeShape &GeTensorDesc::MutableOriginShape() const { - return OriginShapeReference(); -} - -void GeTensorDesc::SetOriginShape(const GeShape &origin_shape) { - impl_->OriginShapeReference() = origin_shape; - impl_->SetOriginShapeInited(true); -} - -bool GeTensorDesc::IsOriginShapeInitialized() const { - return impl_->IsOriginShapeInited(); -} - -Format GeTensorDesc::GetFormat() const { - return impl_->GetFormat(); -} - -void GeTensorDesc::SetFormat(const Format format) { - return impl_->SetFormat(format); -} - -void GeTensorDesc::SetName(const std::string &name) { - return impl_->SetName(name); -} - -const std::string GeTensorDesc::GetName() const { - return impl_->GetName(); -} - -Format GeTensorDesc::GetOriginFormat() const { - return impl_->GetOriginFormat(); -} - -void GeTensorDesc::SetOriginFormat(const Format origin_format) { - impl_->SetOriginFormat(origin_format); -} - -void GeTensorDesc::SetDataType(const DataType data_type) { - return impl_->SetDataType(data_type); -} - -DataType GeTensorDesc::GetDataType() const { - return impl_->GetDataType(); -} - -void GeTensorDesc::SetOriginDataType(const DataType origin_data_type) { - impl_->SetOriginDataType(origin_data_type); -} - -DataType GeTensorDesc::GetOriginDataType() const { - return impl_->GetOriginDataType(); -} - -std::vector GeTensorDesc::GetRefPortIndex() const { - std::vector ref_port_index; - (void)AttrUtils::GetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, ref_port_index); - return ref_port_index; -} - -void GeTensorDesc::SetRefPortByIndex(const std::vector &index) { - (void)AttrUtils::SetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, index); -} - -Placement GeTensorDesc::GetPlacement() const { - int64_t placement = 0; - (void)AttrUtils::GetInt(this, TENSOR_UTILS_PLACEMENT, placement); - return static_cast(placement); -} - -void GeTensorDesc::SetPlacement(const Placement placement) { - (void)AttrUtils::SetInt(this, TENSOR_UTILS_PLACEMENT, static_cast(placement)); -} - -graphStatus GeTensorDesc::IsValid() const { - if ((this->GetDataType() != DT_UNDEFINED) || (this->GetFormat() != FORMAT_RESERVED)) { - return GRAPH_SUCCESS; - } - return GRAPH_PARAM_INVALID; -} - -GeTensorDesc GeTensorDesc::Clone() const { return *this; } - -GeTensorDesc &GeTensorDesc::operator=(const GeTensorDesc &desc) { - if (&desc != this) { - AttrHolder::CopyFrom(desc); - *impl_ = *(desc.impl_); - } - return *this; -} - -GeTensorDesc &GeTensorDesc::operator=(GeTensorDesc &&desc) { - if (&desc != this) { - AttrHolder::CopyFrom(desc); - impl_ = desc.impl_; - } - return *this; -} - -const std::string GeTensorDesc::GetExpandDimsRule() const { - return impl_->GetExpandDimsRule(); -} -void GeTensorDesc::SetExpandDimsRule(const std::string &expand_dims_rule) { - impl_->SetExpandDimsRule(expand_dims_rule); -} - -uint32_t TensorDataImpl::invalid_data_ = 0x3A2D2900U; - -TensorDataImpl::TensorDataImpl(const TensorDataImpl &other) { - // Share data - tensor_descriptor_ = other.tensor_descriptor_; - aligned_ptr_ = other.aligned_ptr_; - length_ = other.length_; -} - -TensorDataImpl &TensorDataImpl::operator=(const TensorDataImpl &other) { - if (&other != this) { - // Share data - tensor_descriptor_ = other.tensor_descriptor_; - aligned_ptr_ = other.aligned_ptr_; - length_ = other.length_; - } - return *this; -} - -graphStatus TensorDataImpl::SetData(const uint8_t *const data, const size_t size) { - if (size == 0UL) { - GELOGD("size is 0"); - clear(); - return GRAPH_SUCCESS; - } - if (data == nullptr) { - GELOGD("data addr is empty"); - return GRAPH_SUCCESS; - } - - if (MallocAlignedPtr(size) == nullptr) { - GELOGE(MEMALLOC_FAILED, "[Malloc][Memory] failed, size=%zu", size); - return GRAPH_FAILED; - } - - size_t remain_size = size; - auto dst_addr = PtrToValue(aligned_ptr_->MutableGet()); - auto src_addr = PtrToValue(data); - while (remain_size > SECUREC_MEM_MAX_LEN) { - if (memcpy_s(ValueToPtr(dst_addr), SECUREC_MEM_MAX_LEN, - ValueToPtr(src_addr), SECUREC_MEM_MAX_LEN) != EOK) { - REPORT_INNER_ERR_MSG("E18888", "memcpy failed, size = %lu", SECUREC_MEM_MAX_LEN); - GELOGE(INTERNAL_ERROR, "[Memcpy][Data] failed, size = %lu", SECUREC_MEM_MAX_LEN); - return GRAPH_FAILED; - } - remain_size -= SECUREC_MEM_MAX_LEN; - dst_addr += SECUREC_MEM_MAX_LEN; - src_addr += SECUREC_MEM_MAX_LEN; - } - if (memcpy_s(ValueToPtr(dst_addr), remain_size, - ValueToPtr(src_addr), remain_size) != EOK) { - REPORT_INNER_ERR_MSG("E18888", "memcpy failed, size=%zu", remain_size); - GELOGE(INTERNAL_ERROR, "[Memcpy][Data] failed, size=%zu", remain_size); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -void TensorDataImpl::SetData(std::shared_ptr aligned_ptr, const size_t size) { - aligned_ptr_ = std::move(aligned_ptr); - length_ = size; -} - -graphStatus TensorDataImpl::SetData(uint8_t *const data, const size_t size, const AlignedPtr::Deleter &delete_fuc) { - if (size == 0UL) { - GELOGW("[Set][Data] Input size is 0"); - clear(); - return GRAPH_SUCCESS; - } - if (data == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "data is nullptr"); - GELOGE(GRAPH_FAILED, "[Check][Param] data is nullptr"); - return GRAPH_FAILED; - } - length_ = size; - aligned_ptr_ = AlignedPtr::BuildFromData(data, delete_fuc); - return GRAPH_SUCCESS; -} - -graphStatus TensorDataImpl::ResetData(uint8_t *const data, const size_t size, const AlignedPtr::Deleter &delete_fuc) { - if (size == 0UL) { - GELOGW("[Reset][Data] Input size is 0"); - clear(); - return GRAPH_SUCCESS; - } - length_ = size; - if (aligned_ptr_ != nullptr) { - aligned_ptr_->Reset(data, delete_fuc); - } else { - aligned_ptr_ = AlignedPtr::BuildFromData(data, delete_fuc); - } - return GRAPH_SUCCESS; -} - -const uint8_t *TensorDataImpl::MallocAlignedPtr(const size_t size) { - if (size == 0UL) { - GELOGW("[Check][Param] Input data size is 0"); - clear(); - return PtrToPtr(&invalid_data_); - } - if (length_ != size) { - aligned_ptr_.reset(); - } - length_ = size; - if (aligned_ptr_ == nullptr) { - aligned_ptr_ = MakeShared(length_); - if (aligned_ptr_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "create AlignedPtr failed."); - GELOGE(INTERNAL_ERROR, "[Create][AlignedPtr] failed."); - return nullptr; - } - } - - return aligned_ptr_->Get(); -} - -size_t TensorDataImpl::GetSize() const { return length_; } - -const uint8_t *TensorDataImpl::GetData() const { - if (length_ == 0UL) { - return PtrToPtr(&invalid_data_); - } - if (aligned_ptr_ == nullptr) { - return nullptr; - } - return aligned_ptr_->Get(); -} - -uint8_t *TensorDataImpl::GetData() { - if (length_ == 0UL) { - return PtrToPtr(&invalid_data_); - } - if (aligned_ptr_ == nullptr) { - return nullptr; - } - return aligned_ptr_->MutableGet(); -} - -bool TensorDataImpl::IsTensorDataValid() const { - return !((length_ == 0UL) && (GetData() == PtrToPtr(&invalid_data_))); -} - -void TensorDataImpl::clear() { - aligned_ptr_.reset(); - length_ = 0UL; -} - -uint8_t TensorDataImpl::operator[](const size_t index) const { - const uint8_t *const value_ptr = PtrAdd(aligned_ptr_->MutableGet(), length_, index); - if (value_ptr != nullptr) { - return *value_ptr; - } - return static_cast(0xffU); -} - -TensorData::TensorData() - : impl_(MakeShared()) {} - -TensorData::TensorData(const TensorData &other) - : impl_(MakeShared(*(other.impl_))) {} - -TensorData::~TensorData() = default; - -TensorData &TensorData::operator=(const TensorData &other) { - if (&other != this) { - *impl_ = *(other.impl_); - } - return *this; -} - -graphStatus TensorData::SetData(std::vector &&data) { return SetData(data.data(), data.size()); } -graphStatus TensorData::SetData(const std::vector &data) { return SetData(data.data(), data.size()); } -graphStatus TensorData::SetData(const Buffer &data) { return SetData(data.data(), data.size()); } -graphStatus TensorData::SetData(const TensorData &data) { return SetData(data.data(), data.size()); } - -graphStatus TensorData::SetData(const uint8_t *const data, const size_t size) { - return impl_->SetData(data, size); -} - -graphStatus TensorData::SetData(uint8_t *const data, const size_t size, const AlignedPtr::Deleter &delete_fuc) { - return impl_->SetData(data, size, delete_fuc); -} - -void TensorData::SetData(std::shared_ptr aligned_ptr, const size_t size) { - impl_->SetData(aligned_ptr, size); -} - -graphStatus TensorData::ResetData(uint8_t *const data, const size_t size, const AlignedPtr::Deleter &delete_fuc) { - return impl_->ResetData(data, size, delete_fuc); -} - -const uint8_t *TensorData::MallocAlignedPtr(const size_t size) { - return impl_->MallocAlignedPtr(size); -} - -size_t TensorData::GetSize() const { - return impl_->GetSize(); -} - -const uint8_t *TensorData::GetData() const { - return impl_->GetData(); -} - -uint8_t *TensorData::GetData() { - return impl_->GetData(); -} - -bool TensorData::IsTensorDataValid() const { - return impl_->IsTensorDataValid(); -} - -const std::uint8_t *TensorData::data() const { return GetData(); } -std::uint8_t *TensorData::data() { return GetData(); } -std::size_t TensorData::size() const { return GetSize(); } -void TensorData::clear() { - impl_->clear(); -} - -uint8_t TensorData::operator[](const size_t index) const { - return (*impl_)[index]; -} - -const std::shared_ptr &TensorData::GetAlignedPtr() { - return impl_->GetAlignedPtr(); -} - -GeTensorImpl::GeTensorImpl() : tensor_def_(nullptr, nullptr), desc_(), tensor_data_() { - if (desc_.impl_ != nullptr) { - if (tensor_data_.impl_ != nullptr) { - tensor_data_.impl_->tensor_descriptor_ = desc_.impl_; - } - } -} - -GeTensorImpl::GeTensorImpl(const GeTensorDesc &tensor_desc) : GeTensorImpl() { - DescReference() = tensor_desc; -} - -GeTensorImpl::GeTensorImpl(const GeTensorDesc &tensor_desc, const std::vector &data) : GeTensorImpl() { - DescReference() = tensor_desc; - if (tensor_data_.SetData(data) != GRAPH_SUCCESS) { - GELOGW("[Set][Data] Set data failed"); - } -} - -GeTensorImpl::GeTensorImpl(const GeTensorDesc &tensor_desc, const uint8_t *const data, const size_t size) - : GeTensorImpl() { - DescReference() = tensor_desc; - if (tensor_data_.SetData(data, size) != GRAPH_SUCCESS) { - GELOGW("[Set][Data] Set data failed"); - } -} - -GeTensorImpl::GeTensorImpl(GeTensorDesc &&tensor_desc, std::vector &&data) : GeTensorImpl() { - DescReference() = std::move(tensor_desc); - if (tensor_data_.SetData(data) != GRAPH_SUCCESS) { - GELOGW("[Set][Data] Set data failed"); - } -} - -GeTensorImpl::GeTensorImpl(const GeTensorDesc &tensor_desc, const Buffer &data) : GeTensorImpl() { - DescReference() = tensor_desc; - if (data.size() == 0UL) { - GELOGI("GetSize res is 0."); - } - if (data.data() == nullptr) { - GELOGI("data addr is null."); - } - if (tensor_data_.SetData(data) != GRAPH_SUCCESS) { - GELOGW("[Set][Data] Set data failed"); - } -} - -GeTensorImpl::GeTensorImpl(const GeTensorDesc &tensor_desc, std::shared_ptr aligned_ptr, const size_t size) - : GeTensorImpl() { - DescReference() = tensor_desc; - tensor_data_.SetData(std::move(aligned_ptr), size); -} - -GeTensorImpl::GeTensorImpl(const GeTensorDesc &tensor_desc, const size_t size) : GeTensorImpl() { - DescReference() = tensor_desc; - if (tensor_data_.MallocAlignedPtr(size) == nullptr) { - GELOGW("[Malloc][Memory] Malloc memory failed, size=%zu", size); - } -} - -GeTensorImpl::GeTensorImpl(const ProtoMsgOwner &proto_owner, proto::TensorDef *proto_msg) - : tensor_def_(proto_owner, proto_msg) { - // 这里后续改为反序列化接口调用,从proto恢复GeTensorDesc - desc_ = GeTensorDesc((proto_msg == nullptr) ? nullptr : proto_msg->mutable_desc()); - tensor_data_ = TensorData(); - if ((tensor_data_.impl_ != nullptr) && (desc_.impl_ != nullptr)) { - // 之前没有把TensorData上的proto变为GeTensorDesc,因为TensorData创建后不会修改,多个TensorData通过GeIrProto共享 - // 但是!原本的语义是TensorData上的proto::TensorDescriptor与Tensor上的GeTensorDesc是共享的,当GeTensorDesc改造完 - // 这种共享的能力就消失了,这会导致在GeTensor创建后,对GeTensorDesc的修改无法反应到TensorData上,看起来只能将TensorData - // 上的proto::TensorDescriptor修改为GeTensorDescImpl,并且需要与GeTensor的GeTensorDesc共享 - tensor_data_.impl_->tensor_descriptor_ = desc_.impl_; - } - - if (proto_msg != nullptr) { - if (proto_owner != nullptr) { - BuildAlignerPtrWithProtoData(); - } else { - (void)tensor_data_.SetData(PtrToPtr(proto_msg->data().data()), - proto_msg->data().size()); - } - } -} - -GeTensorDesc &GeTensorImpl::DescReference() const { - return desc_; -} - -void GeTensorImpl::BuildAlignerPtrWithProtoData() { - auto const proto_msg = tensor_def_.GetProtoMsg(); - if ((proto_msg == nullptr) || - (PtrToPtr(proto_msg->data().data()) == tensor_data_.data())) { - return; - } - if (tensor_data_.impl_ == nullptr) { - return; - } - - tensor_data_.impl_->length_ = proto_msg->data().size(); - tensor_data_.impl_->aligned_ptr_.reset(); - tensor_data_.impl_->aligned_ptr_ = AlignedPtr::BuildFromAllocFunc( - [&proto_msg](std::unique_ptr &ptr) { - ptr.reset(const_cast(PtrToPtr( - proto_msg->data().data()))); - }, - [](const uint8_t *const ptr) { - (void)ptr; - }); -} - -graphStatus GeTensorImpl::SetData(std::vector &&data) { - if (tensor_def_.GetProtoOwner() != nullptr) { - auto const proto_msg = tensor_def_.GetProtoMsg(); - GE_CHECK_NOTNULL(proto_msg); - proto_msg->set_data(data.data(), data.size()); - BuildAlignerPtrWithProtoData(); - return GRAPH_SUCCESS; - } - return tensor_data_.SetData(data); -} - -graphStatus GeTensorImpl::SetData(const std::vector &data) { - if (tensor_def_.GetProtoOwner() != nullptr) { - auto const proto_msg = tensor_def_.GetProtoMsg(); - GE_CHECK_NOTNULL(proto_msg); - proto_msg->set_data(data.data(), data.size()); - BuildAlignerPtrWithProtoData(); - return GRAPH_SUCCESS; - } - return tensor_data_.SetData(data); -} - -graphStatus GeTensorImpl::SetData(const uint8_t *const data, const size_t size) { - if (size > 0UL) { - GE_CHECK_NOTNULL(data); - } - if (tensor_def_.GetProtoOwner() != nullptr) { - auto const proto_msg = tensor_def_.GetProtoMsg(); - GE_CHECK_NOTNULL(proto_msg); - proto_msg->set_data(data, size); - BuildAlignerPtrWithProtoData(); - return GRAPH_SUCCESS; - } - return tensor_data_.SetData(data, size); -} - -graphStatus GeTensorImpl::SetData(const Buffer &data) { - if (tensor_def_.GetProtoOwner() != nullptr) { - auto const proto_msg = tensor_def_.GetProtoMsg(); - GE_CHECK_NOTNULL(proto_msg); - if (data.size() == 0UL) { - GELOGI("GetSize res is 0."); - } - if (data.data() == nullptr) { - GELOGI("data addr is null."); - } - proto_msg->set_data(data.data(), data.size()); - BuildAlignerPtrWithProtoData(); - return GRAPH_SUCCESS; - } - return tensor_data_.SetData(data); -} - -graphStatus GeTensorImpl::SetData(const TensorData &data) { - return SetData(data.data(), data.size()); -} - -graphStatus GeTensorImpl::SetData(uint8_t *const data, const size_t size, const AlignedPtr::Deleter &delete_fuc) { - return tensor_data_.SetData(data, size, delete_fuc); -} - -graphStatus GeTensorImpl::ResetData(uint8_t *const data, const size_t size, const AlignedPtr::Deleter &delete_fuc) { - return tensor_data_.ResetData(data, size, delete_fuc); -} - -void GeTensorImpl::ClearData() { - if (tensor_def_.GetProtoOwner() != nullptr) { - auto const proto_msg = tensor_def_.GetProtoMsg(); - if (proto_msg != nullptr) { - proto_msg->clear_data(); - } - } - tensor_data_.clear(); -} - -bool GeTensorImpl::IsTensorDataValid() const { - return tensor_data_.IsTensorDataValid(); -} - -void GeTensorImpl::Clone(GeTensorImpl &tensor) const { - if ((tensor.desc_.impl_ != nullptr) && (desc_.impl_ != nullptr)) { - *(tensor.desc_.impl_) = *(desc_.impl_); - } - if ((tensor.tensor_data_.impl_ != nullptr) && (tensor.desc_.impl_ != nullptr)) { - tensor.tensor_data_.impl_->tensor_descriptor_ = tensor.desc_.impl_; - } - (void)tensor.SetData(GetData()); -} - -std::shared_ptr GeTensorImpl::GetAlignedPtr() const { - if (tensor_data_.impl_ != nullptr) { - return tensor_data_.impl_->GetAlignedPtr(); - } - return nullptr; -} - -GeTensorImpl::GeTensorImpl(const GeTensorImpl &other) : GeTensorImpl() { - *this = other; -} - -GeTensorImpl &GeTensorImpl::operator=(const GeTensorImpl &other) { - if (&other != this) { - if (other.tensor_def_.GetProtoOwner() != nullptr) { - // Old scene, share tensor_def, tensor_desc, tensor_data with `other` - tensor_def_ = other.tensor_def_; - // 这里修改了 - desc_ = other.desc_; - if ((tensor_data_.impl_ != nullptr) && (desc_.impl_ != nullptr)) { - tensor_data_.impl_->tensor_descriptor_ = desc_.impl_; - } - BuildAlignerPtrWithProtoData(); - } else { - // share tensor_data, do not share tensor_desc, tensor_def is null - desc_ = other.desc_; - tensor_data_ = other.tensor_data_; - if ((tensor_data_.impl_ != nullptr) && (desc_.impl_ != nullptr)) { - tensor_data_.impl_->tensor_descriptor_ = desc_.impl_; - } - } - } - return *this; -} - -GeTensor::GeTensor() : impl_(MakeShared()) {} - -GeTensor::GeTensor(GeTensor &&other) noexcept : impl_(std::move(other.impl_)) {} - -GeTensor::GeTensor(GeTensorImplPtr impl) : impl_(std::move(impl)) {} - -GeTensor::GeTensor(const GeTensorDesc &tensor_desc) - : impl_(MakeShared(tensor_desc)) {} - -GeTensor::GeTensor(const GeTensorDesc &tensor_desc, const std::vector &data) - : impl_(MakeShared(tensor_desc, data)) {} - -GeTensor::GeTensor(const GeTensorDesc &tensor_desc, const uint8_t *const data, const size_t size) - : impl_(MakeShared(tensor_desc, data, size)) {} - -GeTensor::GeTensor(GeTensorDesc &&tensor_desc, std::vector &&data) - : impl_(MakeShared(std::move(tensor_desc), std::move(data))) {} - -GeTensor::GeTensor(const GeTensorDesc &tensor_desc, const Buffer &data) - : impl_(MakeShared(tensor_desc, data)) {} - -GeTensor::GeTensor(const GeTensorDesc &tensor_desc, std::shared_ptr aligned_ptr, const size_t size) - : impl_(MakeShared(tensor_desc, aligned_ptr, size)) {} - -GeTensor::GeTensor(const GeTensorDesc &tensor_desc, const size_t size) - : impl_(MakeShared(tensor_desc, size)) {} - -GeTensor::GeTensor(const ProtoMsgOwner &proto_owner, proto::TensorDef *proto_msg) - : impl_(MakeShared(proto_owner, proto_msg)) {} - -GeTensor::~GeTensor() = default; - -void GeTensor::BuildAlignerPtrWithProtoData() { - impl_->BuildAlignerPtrWithProtoData(); -} - -const GeTensorDesc &GeTensor::GetTensorDesc() const { return DescReference(); } - -GeTensorDesc &GeTensor::MutableTensorDesc() { return DescReference(); } - -GeTensorDesc &GeTensor::DescReference() const { - return impl_->DescReference(); -} - -void GeTensor::SetTensorDesc(const GeTensorDesc &tensor_desc) { DescReference() = tensor_desc; } - -graphStatus GeTensor::SetData(std::vector &&data) { - return impl_->SetData(data); -} - -graphStatus GeTensor::SetData(const std::vector &data) { - return impl_->SetData(data); -} - -graphStatus GeTensor::SetData(const uint8_t *const data, const size_t size) { - return impl_->SetData(data, size); -} - -graphStatus GeTensor::SetData(const Buffer &data) { - return impl_->SetData(data); -} - -graphStatus GeTensor::SetData(const TensorData &data) { - return SetData(data.data(), data.size()); -} - -graphStatus GeTensor::SetData(uint8_t *const data, const size_t size, const AlignedPtr::Deleter &delete_fuc) { - return impl_->SetData(data, size, delete_fuc); -} - -graphStatus GeTensor::ResetData(uint8_t *const data, const size_t size, const AlignedPtr::Deleter &delete_fuc) { - return impl_->ResetData(data, size, delete_fuc); -} - -void GeTensor::ClearData() { - impl_->ClearData(); -} - -GeTensor GeTensor::Clone() const { - const GeTensor tensor; - impl_->Clone(*(tensor.impl_)); - return tensor; -} - -GeTensor::GeTensor(const GeTensor &other) - : impl_(MakeShared(*(other.impl_))) {} - -GeTensor &GeTensor::operator=(const GeTensor &other) { - if (&other != this) { - *impl_ = *(other.impl_); - } - return *this; -} - -GeTensor &GeTensor::operator=(GeTensor &&other) { - if (&other != this) { - impl_ = other.impl_; - } - return *this; -} - -std::shared_ptr GeTensor::GetAlignedPtr() { - return impl_->GetAlignedPtr(); -} - -const TensorData &GeTensor::GetData() const { - return impl_->GetData(); -} -TensorData &GeTensor::MutableData() { - return impl_->MutableData(); -} - -bool GeTensor::IsTensorDataValid() const { - return impl_->IsTensorDataValid(); -} - -// zero copy SetData -void GeTensor::SetData(std::shared_ptr aligned_ptr, const size_t size) { - impl_->SetData(std::move(aligned_ptr), size); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetSize(const GeTensorDesc &tensor_desc, - int64_t &size) { - if (tensor_desc.impl_ != nullptr) { - size = tensor_desc.impl_->ext_meta_.GetSize(); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetSize( - GeTensorDesc &tensor_desc, const int64_t size) { - if (tensor_desc.impl_ != nullptr) { - tensor_desc.impl_->ext_meta_.SetSize(size); - } -} - -int64_t TensorUtils::GetWeightSize(const GeTensorDesc &tensor_desc) { - if (tensor_desc.impl_ != nullptr) { - return tensor_desc.impl_->ext_meta_.GetWeightSize(); - } - return 0; -} - -int64_t TensorUtils::GetWeightSize(const GeTensor &tensor) { - return GetWeightSize(tensor.GetTensorDesc()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY int64_t TensorUtils::GetWeightSize(const ConstGeTensorPtr &tensor_ptr) { - if (tensor_ptr == nullptr) { - return 0; - } - return GetWeightSize(*tensor_ptr); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint8_t *TensorUtils::GetWeightAddr(const ConstGeTensorPtr &tensor_ptr, - const uint8_t *const base) { - if (tensor_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param tensor_ptr is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] tensor_ptr is null."); - return nullptr; - } - return GetWeightAddr(*tensor_ptr, base); -} - -uint8_t *TensorUtils::GetWeightAddr(const GeTensor &tensor, const uint8_t *const base) { - if (base == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param base is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] base is null."); - return nullptr; - } - int64_t weight_data_offset = 0; - if (GetDataOffset(tensor.GetTensorDesc(), weight_data_offset) != GRAPH_SUCCESS) { - return nullptr; - } - - if (weight_data_offset == 0) { - // The weight of offset 0 is still in const op, still get from ATTR_NAME_WEIGHTS. - return const_cast(tensor.GetData().data()); - } - return PtrToPtr(ValueToPtr(PtrToValue(base) + - static_cast(weight_data_offset))); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetWeightSize(GeTensorDesc &tensor_desc, - const int64_t size) { - if (tensor_desc.impl_ != nullptr) { - tensor_desc.impl_->ext_meta_.SetWeightSize(size); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetReuseInput(const GeTensorDesc &tensor_desc, - bool &flag) { - if (tensor_desc.impl_ != nullptr) { - flag = tensor_desc.impl_->ext_meta_.GetReuseInput(); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetReuseInput( - GeTensorDesc &tensor_desc, const bool flag) { - if (tensor_desc.impl_ != nullptr) { - tensor_desc.impl_->ext_meta_.SetReuseInput(flag); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetOutputTensor(const GeTensorDesc &tensor_desc, - bool &flag) { - if (tensor_desc.impl_ != nullptr) { - flag = tensor_desc.impl_->ext_meta_.GetOutputTensor(); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetOutputTensor( - GeTensorDesc &tensor_desc, const bool flag) { - if (tensor_desc.impl_ != nullptr) { - tensor_desc.impl_->ext_meta_.SetOutputTensor(flag); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetDeviceType(const GeTensorDesc &tensor_desc, - DeviceType &type) { - if (tensor_desc.impl_ != nullptr) { - type = tensor_desc.impl_->ext_meta_.GetDeviceType(); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetDeviceType(GeTensorDesc &tensor_desc, - const DeviceType type) { - if (tensor_desc.impl_ != nullptr) { - tensor_desc.impl_->ext_meta_.SetDeviceType(type); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetInputTensor(const GeTensorDesc &tensor_desc, - bool &flag) { - if (tensor_desc.impl_ != nullptr) { - flag = tensor_desc.impl_->ext_meta_.GetInputTensor(); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetInputTensor( - GeTensorDesc &tensor_desc, const bool flag) { - if (tensor_desc.impl_ != nullptr) { - tensor_desc.impl_->ext_meta_.SetInputTensor(flag); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetRealDimCnt(const GeTensorDesc &tensor_desc, - uint32_t &cnt) { - if (tensor_desc.impl_ != nullptr) { - cnt = tensor_desc.impl_->ext_meta_.GetRealDimCnt(); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetRealDimCnt(GeTensorDesc &tensor_desc, - const uint32_t cnt) { - if (tensor_desc.impl_ != nullptr) { - tensor_desc.impl_->ext_meta_.SetRealDimCnt(cnt); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -TensorUtils::GetReuseInputIndex(const GeTensorDesc &tensor_desc, uint32_t &idx) { - if (tensor_desc.impl_ != nullptr) { - idx = tensor_desc.impl_->ext_meta_.GetReuseInputIndex(); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetReuseInputIndex(GeTensorDesc &tensor_desc, - const uint32_t idx) { - if (tensor_desc.impl_ != nullptr) { - tensor_desc.impl_->ext_meta_.SetReuseInputIndex(idx); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetDataOffset(const GeTensorDesc &tensor_desc, - int64_t &offset) { - if (tensor_desc.impl_ != nullptr) { - offset = tensor_desc.impl_->ext_meta_.GetDataOffset(); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetDataOffset(GeTensorDesc &tensor_desc, - const int64_t offset) { - if (tensor_desc.impl_ != nullptr) { - tensor_desc.impl_->ext_meta_.SetDataOffset(offset); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetRC(const GeTensorDesc &tensor_desc, - uint32_t &rc) { - return AttrUtils::GetInt(&tensor_desc, TENSOR_UTILS_RC, rc) ? GRAPH_SUCCESS : GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetRC(GeTensorDesc &tensor_desc, const uint32_t rc) { - (void)AttrUtils::SetInt(&tensor_desc, TENSOR_UTILS_RC, static_cast(rc)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool TensorUtils::IsOriginShapeInited(const GeTensorDesc &tensor_desc) { - return tensor_desc.impl_->IsOriginShapeInited(); -} - -GeTensor TensorUtils::CreateShareTensor(const GeTensor &other) { - GeTensor tensor; - ShareTensor(other, tensor); - return tensor; -} - -GeTensor TensorUtils::CreateShareTensor(const GeTensorDesc &tensor_desc, - std::shared_ptr aligned_ptr, - const size_t size) { - const GeTensor tensor(tensor_desc); - if (tensor.impl_ != nullptr) { - ShareAlignedPtr(std::move(aligned_ptr), size, tensor.impl_->tensor_data_); - } - return tensor; -} - -void TensorUtils::ShareTensor(const GeTensor &from, GeTensor &to) { - if (&from == &to) { - return; - } - if ((from.impl_ != nullptr) && (to.impl_ != nullptr)) { - if (from.impl_->tensor_def_.GetProtoOwner() != nullptr) { - // 这种场景下看原来的逻辑,已经没有什么是不是共享的了,所以直接改成了impl共享,幸好impl是shared ptr - // 但是之前似乎有个啥逻辑。是假定可以把shared ptr当成unique用的,得风暴下,记不得了 - to.impl_ = from.impl_; - } else { - // share tensor_data, do not share tensor_desc, tensor_def is null - to.impl_->desc_ = from.impl_->desc_; - to.impl_->tensor_data_ = from.impl_->tensor_data_; - to.impl_->tensor_data_.impl_->tensor_descriptor_ = to.impl_->desc_.impl_; - } - } -} -void TensorUtils::ShareTensorData(const TensorData &from, TensorData &to) { - if (&from == &to) { - return; - } - // Share data - if ((from.impl_ != nullptr) && (to.impl_ != nullptr)) { - to.impl_->tensor_descriptor_ = from.impl_->tensor_descriptor_; - to.impl_->aligned_ptr_ = from.impl_->aligned_ptr_; - to.impl_->length_ = from.impl_->length_; - } -} -TensorData TensorUtils::CreateShareTensorData(const TensorData &other) { - TensorData td; - ShareTensorData(other, td); - return td; -} -void TensorUtils::ShareAlignedPtr(std::shared_ptr ptr, const size_t size, TensorData &to) { - if (to.impl_ != nullptr) { - to.impl_->aligned_ptr_ = std::move(ptr); - to.impl_->length_ = size; - } -} -void TensorUtils::ShareAlignedPtr(std::shared_ptr ptr, const size_t size, GeTensor &to) { - if (to.impl_ != nullptr) { - ShareAlignedPtr(std::move(ptr), size, to.impl_->tensor_data_); - } -} -// UT -void TensorUtils::CopyTensor(const GeTensor &from, GeTensor &to) { - if (&from == &to) { - return; - } - if ((from.impl_ == nullptr) || (to.impl_ == nullptr)) { - return; - } - if (from.impl_->tensor_def_.GetProtoOwner() != nullptr) { - to.impl_->tensor_def_.CopyValueFrom(from.impl_->tensor_def_); - to.impl_->desc_.impl_ = GeTensorDesc(to.impl_->tensor_def_.GetProtoMsg()->mutable_desc()).impl_; - to.impl_->desc_.impl_->attrs_ = from.impl_->desc_.impl_->attrs_; - to.impl_->tensor_data_.impl_->tensor_descriptor_ = to.impl_->desc_.impl_; - to.BuildAlignerPtrWithProtoData(); - } else { - // tensor_def is null, copy tensor_data, tensor_desc - to.impl_->desc_ = from.impl_->desc_; - (void)to.impl_->tensor_data_.SetData(from.impl_->tensor_data_); - to.impl_->tensor_data_.impl_->tensor_descriptor_ = to.impl_->desc_.impl_; - } -} - -bool TensorUtils::IsShapeEqual(const GeShape &src, const GeShape &dst) { - GE_ASSERT_TRUE(src.GetDimNum() == dst.GetDimNum(), "src(%s) dims num is not equal to dst(%s) dim num.", - src.ToString().c_str(), dst.ToString().c_str()); - const auto src_dims = src.GetDims(); - const auto dst_dims = dst.GetDims(); - for (size_t i = 0UL; i < src_dims.size(); ++i) { - if ((src_dims[i] == -1) || (dst_dims[i] == -1)) { - GELOGW("src dim %d is %d, dst dim %d is %d, there has unknown shape.", i, src_dims[i], i, dst_dims[i]); - continue; - } - GE_ASSERT_TRUE(src_dims[i] == dst_dims[i], "src(%s) dim(%zu) = %ld is not equal to dst(%s) dim(%zu) = %ld.", - src.ToString().c_str(), i, src_dims[i], dst.ToString().c_str(), i, dst_dims[i]); - } - return true; -} -} // namespace ge diff --git a/graph/normal_graph/ge_tensor_impl.h b/graph/normal_graph/ge_tensor_impl.h deleted file mode 100644 index 0828cc49c31a1360ed595b2703dad9ab27aa9540..0000000000000000000000000000000000000000 --- a/graph/normal_graph/ge_tensor_impl.h +++ /dev/null @@ -1,343 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_GE_TENSOR_IMPL_H_ -#define GRAPH_GE_TENSOR_IMPL_H_ - - -#include -#include -#include -#include "graph/ge_tensor.h" - -namespace ge { -class GeTensorDescImpl { - public: - GeTensorDescImpl() = default; - GeTensorDescImpl(const GeShape &shape, const Format format, const DataType dt); - explicit GeTensorDescImpl(proto::TensorDescriptor *const proto_msg); - ~GeTensorDescImpl() = default; - - GeShape &ShapeReference() const; - GeShape &OriginShapeReference() const; - - bool GeTensorDescAttrsAreEqual(const GeTensorDescImpl &other) const; - bool operator==(const GeTensorDescImpl &other) const; - - ProtoAttrMap &MutableAttrMap(); - ConstProtoAttrMap &GetAttrMap() const; - void SetShape(GeShape &shape) const; - - void SetDataType(const DataType dtype); - DataType GetDataType() const; - void SetFormat(const Format format); - Format GetFormat() const; - void SetOriginFormat(const Format format); - Format GetOriginFormat() const; - void SetOriginDataType(const DataType dtype); - DataType GetOriginDataType() const; - void SetName(const std::string &name); - const std::string GetName() const; - bool IsOriginShapeInited() const { - return ext_meta_.IsOriginShapeInited(); - } - void SetOriginShapeInited(const bool origin_shape_inited) { - ext_meta_.SetOriginShapeInited(origin_shape_inited); - } - - void SetExpandDimsRule(const std::string &expand_dims_rule) { - ext_meta_.SetExpandDimsRule(expand_dims_rule); - } - std::string GetExpandDimsRule() const { - return ext_meta_.GetExpandDimsRule(); - } - - private: - friend class GeTensorImpl; - friend class TensorUtils; - friend class GeAttrValueImp; - friend class ModelSerializeImp; - friend class GeTensorSerializeUtils; - friend class OnnxUtils; - - class ExtMeta { - public: - bool operator==(const ExtMeta& other) const { - return (name == other.name) && (device_type == other.device_type) && (size == other.size) && - (weight_size == other.weight_size) && (cmps_tab_offset == other.cmps_tab_offset) && - (reuse_input_index == other.reuse_input_index) && (cmps_tab == other.cmps_tab) && - (data_offset == other.data_offset) && (cmps_size == other.cmps_size) && - (real_dim_cnt == other.real_dim_cnt) && - (other.reuse_input ? reuse_input : !reuse_input) && - (other.input_tensor ? input_tensor : !input_tensor) && - (other.output_tensor ? output_tensor : !output_tensor) && - (other.origin_shape_inited_ ? origin_shape_inited_ : !origin_shape_inited_); - } - // for name - std::string GetName() const { - return name; - } - - void SetName(const std::string &v) { - name = v; - } - - // for device_type - DeviceType GetDeviceType() const { - return device_type; - } - - std::string GetDeviceTypeStr() const; - - void SetDeviceType(const DeviceType v) { - device_type = v; - } - - // for size - int64_t GetSize() const { - return size; - } - - void SetSize(const int64_t v) { - size = v; - } - - // for weight_size - int64_t GetWeightSize() const { - return weight_size; - } - - void SetWeightSize(const int64_t v) { - weight_size = v; - } - - // for data_offset - int64_t GetDataOffset() const { - return data_offset; - } - - void SetDataOffset(const int64_t v) { - data_offset = v; - } - - // for real_dim_cnt - uint32_t GetRealDimCnt() const { - return real_dim_cnt; - } - - void SetRealDimCnt(const uint32_t v) { - real_dim_cnt = v; - } - - // for input_tensor - bool GetInputTensor() const { - return input_tensor; - } - - void SetInputTensor(const bool v) { - input_tensor = v; - } - - // for reuse_input - bool GetReuseInput() const { - return reuse_input; - } - - void SetReuseInput(const bool v) { - reuse_input = v; - } - - // for reuse_input_index - uint32_t GetReuseInputIndex() const { - return reuse_input_index; - } - - void SetReuseInputIndex(const uint32_t v) { - reuse_input_index = v; - } - - // for output_tensor - bool GetOutputTensor() const { - return output_tensor; - } - - void SetOutputTensor(const bool v) { - output_tensor = v; - } - - // for cmps_size - int64_t GetCmpsSize() const { - return cmps_size; - } - - void SetCmpsSize(const int64_t v) { - cmps_size = v; - } - - // for cmps_tab - std::string GetCmpsTab() const { - return cmps_tab; - } - - void SetCmpsTab(const std::string &v) { - cmps_tab = v; - } - - // for cmps_tab_offset - int64_t GetCmpsTabOffset() const { - return cmps_tab_offset; - } - - void SetCmpsTabOffset(const int64_t v) { - cmps_tab_offset = v; - } - - bool IsOriginShapeInited() const { - return origin_shape_inited_; - } - - void SetOriginShapeInited(const bool origin_shape_inited) { - origin_shape_inited_ = origin_shape_inited; - } - - void SetExpandDimsRule(const std::string &expand_dims_rule) { - expand_dims_rule_ = expand_dims_rule; - } - std::string GetExpandDimsRule() const { - return expand_dims_rule_; - } - - private: - int64_t size{0}; - int64_t data_offset{0}; - int64_t cmps_tab_offset{0}; - int64_t cmps_size{0}; - int64_t weight_size{0}; - - uint32_t real_dim_cnt{0U}; - uint32_t reuse_input_index{0U}; - - DeviceType device_type{NPU}; - bool input_tensor{false}; - bool reuse_input{false}; - bool output_tensor{false}; - bool origin_shape_inited_{false}; - - std::string cmps_tab; - std::string name; - - std::string expand_dims_rule_; - }; - - mutable GeShape shape_; - Format format_{FORMAT_ND}; - DataType dtype_{DT_FLOAT}; - - mutable GeShape origin_shape_; - Format origin_format_{FORMAT_ND}; - DataType origin_dtype_{DT_UNDEFINED}; - - ExtMeta ext_meta_; - AttrStore attrs_; -}; - -class TensorDataImpl { - public: - TensorDataImpl() = default; - - TensorDataImpl(const TensorDataImpl &other); - - ~TensorDataImpl() = default; - - TensorDataImpl &operator=(const TensorDataImpl &other); - - graphStatus SetData(const uint8_t * const data, const size_t size); - graphStatus SetData(uint8_t * const data, const size_t size, const AlignedPtr::Deleter &delete_fuc); - void SetData(std::shared_ptr aligned_ptr, const size_t size); - - graphStatus ResetData(uint8_t *const data, const size_t size, const AlignedPtr::Deleter &delete_fuc); - - const uint8_t *MallocAlignedPtr(const size_t size); - - size_t GetSize() const; - const uint8_t *GetData() const; - uint8_t *GetData(); - bool IsTensorDataValid() const; - - void clear(); - - uint8_t operator[](const size_t index) const; - - const std::shared_ptr &GetAlignedPtr() const { return aligned_ptr_; } - - private: - friend class GeTensorImpl; - friend class TensorUtils; - friend class GeAttrValueImp; - friend class ModelSerializeImp; - friend class GeTensorSerializeUtils; - // TensorDatat shared with a GeTensorDesc by holding the impl of GeTensorDesc - std::shared_ptr tensor_descriptor_; - std::shared_ptr aligned_ptr_ = nullptr; - size_t length_ = 0UL; - // functions data() & mutable_data() return address of invalid_data_ when length_ is 0 - // defined for coding convenience - static uint32_t invalid_data_; -}; - -class GeTensorImpl { - public: - GeTensorImpl(); - explicit GeTensorImpl(const GeTensorDesc &tensor_desc); - GeTensorImpl(const GeTensorDesc &tensor_desc, const std::vector &data); - GeTensorImpl(const GeTensorDesc &tensor_desc, const uint8_t * const data, const size_t size); - GeTensorImpl(GeTensorDesc &&tensor_desc, std::vector &&data); - GeTensorImpl(const GeTensorDesc &tensor_desc, const Buffer &data); - GeTensorImpl(const GeTensorDesc &tensor_desc, std::shared_ptr aligned_ptr, const size_t size); - GeTensorImpl(const GeTensorDesc &tensor_desc, const size_t size); - GeTensorImpl(const ProtoMsgOwner &proto_owner, proto::TensorDef *proto_msg); - GeTensorImpl(const GeTensorImpl &other); - - ~GeTensorImpl() = default; - - GeTensorImpl &operator=(const GeTensorImpl &other); - - GeTensorDesc &DescReference() const; - void BuildAlignerPtrWithProtoData(); - graphStatus SetData(std::vector &&data); - graphStatus SetData(const std::vector &data); - graphStatus SetData(const uint8_t * const data, size_t const size); - graphStatus SetData(const Buffer &data); - graphStatus SetData(const TensorData &data); - graphStatus SetData(uint8_t * const data, const size_t size, const AlignedPtr::Deleter &delete_fuc); - void ClearData(); - void Clone(GeTensorImpl &tensor) const; - - std::shared_ptr GetAlignedPtr() const; - const TensorData &GetData() const { return tensor_data_; } - TensorData &MutableData() { return tensor_data_; } - bool IsTensorDataValid() const; - // zero copy SetData - void SetData(std::shared_ptr aligned_ptr, const size_t size) { - tensor_data_.SetData(std::move(aligned_ptr), size); - } - graphStatus ResetData(uint8_t *const data, const size_t size, const AlignedPtr::Deleter &delete_fuc); - - private: - friend class TensorUtils; - friend class GeAttrValueImp; - friend class ModelSerializeImp; - friend class GeTensorSerializeUtils; - GeIrProtoHelper tensor_def_; - // Reference from tensor_data_, do not direct use - mutable GeTensorDesc desc_; - TensorData tensor_data_; -}; -} // namespace ge -#endif // GRAPH_GE_TENSOR_IMPL_H_ diff --git a/graph/normal_graph/gnode.cc b/graph/normal_graph/gnode.cc deleted file mode 100644 index 0196c69d14155327f202a60e72014fc7955a5402..0000000000000000000000000000000000000000 --- a/graph/normal_graph/gnode.cc +++ /dev/null @@ -1,1130 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/gnode.h" - -#include "debug/ge_util.h" -#include "graph/ge_attr_value.h" -#include "graph/ge_tensor.h" -#include "graph/anchor.h" -#include "graph/utils/node_adapter.h" -#include "graph/utils/tensor_adapter.h" -#include "graph/utils/graph_utils_ex.h" -#include "graph/utils/graph_utils.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/debug/ge_op_types.h" -#include "graph/utils/node_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "common/util/mem_utils.h" -#include "common/checker.h" - -#define NODE_ATTR_GET_IMP(ArgType) \ - graphStatus GNode::GetAttr(const AscendString &name, ArgType &attr_value) const { \ - const char_t *const ascend_name = name.GetString(); \ - if (std::string(ascend_name).empty()) { \ - REPORT_INNER_ERR_MSG("E18888", "ascend std::string error."); \ - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] GetAttr: ascend std::string error."); \ - return GRAPH_PARAM_INVALID; \ - } \ - \ - if (impl_ == nullptr) { \ - REPORT_INNER_ERR_MSG("E18888", "node impl is nullptr, check invalid."); \ - GELOGE(GRAPH_FAILED, "[Check][Param] GetAttr: node impl is nullptr."); \ - return GRAPH_FAILED; \ - } \ - \ - const std::shared_ptr node_ptr_share = impl_->node_ptr_.lock(); \ - if (node_ptr_share == nullptr) { \ - REPORT_INNER_ERR_MSG("E18888", "the node shared ptr is nullptr, check invalid."); \ - GELOGE(GRAPH_FAILED, "[Check][Param] GetAttr: the node shared ptr is not valid."); \ - return GRAPH_FAILED; \ - } \ - const Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr_share); \ - if (op.GetAttr(ascend_name, attr_value) != GRAPH_SUCCESS) { \ - GELOGW("[Get][Attr] of node[%s] failed.", node_ptr_share->GetName().c_str()); \ - return GRAPH_FAILED; \ - } \ - \ - return GRAPH_SUCCESS; \ - } - -#define NODE_ATTR_SET_IMP(ArgType) \ - graphStatus GNode::SetAttr(const AscendString &name, ArgType &attr_value) const { \ - const char_t *const ascend_name = name.GetString(); \ - if (std::string(ascend_name).empty()) { \ - REPORT_INNER_ERR_MSG("E18888", "ascend std::string error."); \ - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] SetAttr: ascend std::string error."); \ - return GRAPH_PARAM_INVALID; \ - } \ - \ - if (impl_ == nullptr) { \ - REPORT_INNER_ERR_MSG("E18888", "node impl is nullptr, check invalid."); \ - GELOGE(GRAPH_FAILED, "[Check][Param] SetAttr: node impl is nullptr."); \ - return GRAPH_FAILED; \ - } \ - \ - const std::shared_ptr node_ptr_share = impl_->node_ptr_.lock(); \ - if (node_ptr_share == nullptr) { \ - REPORT_INNER_ERR_MSG("E18888", "the node shared ptr is nullptr, check invalid."); \ - GELOGE(GRAPH_FAILED, "[Check][Param] SetAttr: the node shared ptr is not valid."); \ - return GRAPH_FAILED; \ - } \ - \ - Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr_share); \ - (void)op.SetAttr(ascend_name, attr_value); \ - return GRAPH_SUCCESS; \ - } - -namespace ge { -class NodeImpl { - public: - NodeImpl() = default; - ~NodeImpl() = default; - - NodeImpl(NodeImpl &) = delete; - NodeImpl &operator=(const NodeImpl &) = delete; - -private: - friend class NodeAdapter; - friend class GNode; - std::weak_ptr node_ptr_; -}; - -NodePtr NodeAdapter::GNode2Node(const ge::GNode &node) { - if (node.impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param graph_node.impl_ is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GNode2Node: gnode impl is nullptr."); - return nullptr; - } - - return node.impl_->node_ptr_.lock(); -} - -GNode NodeAdapter::Node2GNode(const ge::NodePtr &node) { - if (node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param node is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Node2GNode: node is nullptr"); - return GNode(); - } - - const GNode graph_node; - if (graph_node.impl_ == nullptr) { - GELOGW("[Check][Param] Gnode impl is nullptr, node:%s", node->GetName().c_str()); - return graph_node; - } - graph_node.impl_->node_ptr_ = node; - - return graph_node; -} - -GNodePtr NodeAdapter::Node2GNodePtr(const ge::NodePtr &node) { - if (node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param node is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Node2GNodePtr: node is nullptr"); - return nullptr; - } - - const GNodePtr gnode = MakeShared(); - if (gnode == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "create GNode failed."); - GELOGE(GRAPH_FAILED, "[Create][GNode] Node2GNodePtr: gnode is nullptr, node[%s].", node->GetName().c_str()); - return nullptr; - } - - if (gnode->impl_ == nullptr) { - GELOGW("[Check][Param] Gnode impl is nullptr, node:%s", node->GetName().c_str()); - return nullptr; - } - gnode->impl_->node_ptr_ = node; - - return gnode; -} - -GNode::GNode() { impl_ = ComGraphMakeShared(); } - - -graphStatus GNode::GetType(AscendString &type) const { - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "impl_ is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetType: node impl is nullptr."); - return GRAPH_FAILED; - } - - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node ptr is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetType: the node shared ptr is not valid."); - return GRAPH_FAILED; - } - const std::string node_type = node_ptr->GetType(); - const AscendString ascend_type(node_type.c_str()); - type = ascend_type; - - return GRAPH_SUCCESS; -} - -graphStatus GNode::GetName(AscendString &name) const { - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node impl is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetName: node impl is nullptr."); - return GRAPH_FAILED; - } - - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node ptr is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetName: the node shared ptr is not valid."); - return GRAPH_FAILED; - } - const std::string node_name = node_ptr->GetName(); - const AscendString ascend_name(node_name.c_str()); - name = ascend_name; - return GRAPH_SUCCESS; -} - -std::pair GNode::GetInDataNodesAndPortIndexs(const int32_t index) const { - const std::pair gnode_idx = {nullptr, 0xFF}; - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node impl is nullptr."); - GELOGE(GRAPH_FAILED, "[Check][Param] Gnode: node impl is nullptr."); - return gnode_idx; - } - - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "the node ptr is not valid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Gnode: the node shared ptr is not valid."); - return gnode_idx; - } - - const auto in_anchor = node_ptr->GetInDataAnchor(index); - if (in_anchor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Failed to get in data node of index[%d] from node[%s], " - "the anchor does not exist", index, node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][Anchor] Failed to get in data node of index[%d] from node[%s], " - "the anchor does not exist", index, node_ptr->GetName().c_str()); - return gnode_idx; - } - - const auto out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Failed to get in data node of index[%d] from node [%s], " - "the data input does not exist", index, node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][Anchor] Failed to get in data node of index[%d] from node [%s], " - "the data input does not exist", index, node_ptr->GetName().c_str()); - return gnode_idx; - } - - const NodePtr peer_node_ptr = out_anchor->GetOwnerNode(); - const GNodePtr gnode = NodeAdapter::Node2GNodePtr(peer_node_ptr); - if (gnode == nullptr) { - GELOGE(GRAPH_FAILED, "[Get][GNode] Peer node of node[%s] to gnode faild.", node_ptr->GetName().c_str()); - return gnode_idx; - } - - return {gnode, out_anchor->GetIdx()}; -} - -std::vector GNode::GetInControlNodes() const { - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node impl is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Gnode: node impl is nullptr."); - return {}; - } - - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Gnode: node ptr is not valid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Gnode: the node shared ptr is not valid."); - return {}; - } - - std::vector gnodes; - const auto in_control_nodes = node_ptr->GetInControlNodes(); - for (auto &in_control_node : in_control_nodes) { - GNodePtr gnode = NodeAdapter::Node2GNodePtr(in_control_node); - if (gnode == nullptr) { - GELOGE(GRAPH_FAILED, "[Get][GNode] In control_node of node[%s] to gnode faild.", node_ptr->GetName().c_str()); - return {}; - } - gnodes.emplace_back(gnode); - } - - return gnodes; -} - -std::vector> GNode::GetOutDataNodesAndPortIndexs(const int32_t index) const { - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node impl is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Gnode: node impl is nullptr."); - return {}; - } - - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node ptr is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Gnode: the node shared ptr is not valid."); - return {}; - } - - const auto out_anchor = node_ptr->GetOutDataAnchor(index); - if (out_anchor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Failed to get out data node of index %d from node %s, " - "the anchor does not exist", index, node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][Anchor] Failed to get out data node of index %d from node %s, " - "the anchor does not exist", index, node_ptr->GetName().c_str()); - return {}; - } - - std::vector> gnode_index; - const auto in_data_anchors = out_anchor->GetPeerInDataAnchors(); - for (auto &in_data_anchor : in_data_anchors) { - if (in_data_anchor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "In data anchor of node[%s] is nullptr.", node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] In data anchor of node[%s] is nullptr.", node_ptr->GetName().c_str()); - return {}; - } - const NodePtr peer_node_ptr = in_data_anchor->GetOwnerNode(); - const GNodePtr gnode = NodeAdapter::Node2GNodePtr(peer_node_ptr); - if (gnode == nullptr) { - GELOGE(GRAPH_FAILED, "[Get][GNode] Peer node of node[%s] to gnode faild.", node_ptr->GetName().c_str()); - return {}; - } - gnode_index.emplace_back(std::pair(gnode, in_data_anchor->GetIdx())); - } - - return gnode_index; -} - -std::vector GNode::GetOutControlNodes() const { - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node impl is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetOutControlNodes: node impl is nullptr."); - return {}; - } - - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "the node shared ptr is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetOutControlNodes: the node shared ptr is not valid."); - return {}; - } - - std::vector gnodes; - const auto out_control_nodes = node_ptr->GetOutControlNodes(); - for (auto &out_control_node : out_control_nodes) { - GNodePtr gnode = NodeAdapter::Node2GNodePtr(out_control_node); - if (gnode == nullptr) { - GELOGE(GRAPH_FAILED, "[Get][GNode] In control_node of node[%s] to gnode faild.", node_ptr->GetName().c_str()); - return {}; - } - gnodes.emplace_back(gnode); - } - - return gnodes; -} - -graphStatus GNode::GetInputConstData(const int32_t index, Tensor &data) const { - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node impl is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetInputConstData: node impl is nullptr."); - return GRAPH_FAILED; - } - - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "the node shared ptr is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetInputConstData: the node shared ptr is not valid."); - return GRAPH_FAILED; - } - - const NodePtr input_data_node = NodeUtils::GetInDataNodeByIndex(*node_ptr, index); - GE_CHECK_NOTNULL(input_data_node); - const std::string op_type = input_data_node->GetType(); - if ((op_type == CONSTANT) || (op_type == CONSTANTOP)) { - const Operator const_op = OpDescUtils::CreateOperatorFromNode(input_data_node); - if (const_op.GetAttr(ATTR_NAME_WEIGHTS.c_str(), data) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Input data node[%s] of node[%s] get data failed.", - input_data_node->GetName().c_str(), node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][Attr] Input data node[%s] of node[%s] get data failed.", - input_data_node->GetName().c_str(), node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - return SUCCESS; - } else if (op_type == DATA) { - auto parent_node = NodeUtils::GetParentInput(input_data_node); - while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) { - parent_node = NodeUtils::GetParentInput(parent_node); - } - if ((parent_node != nullptr) && - ((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) { - const Operator const_op = OpDescUtils::CreateOperatorFromNode(parent_node); - if (const_op.GetAttr(ATTR_NAME_WEIGHTS.c_str(), data) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Input data node[%s] of node[%s] get data failed.", - parent_node->GetName().c_str(), node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][Attr] Input data node[%s] of node[%s] get data failed.", - parent_node->GetName().c_str(), node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; - } - } else { - // empty - } - REPORT_INNER_ERR_MSG("E18888", "Node[%s] has no const input.", node_ptr->GetName().c_str()); - GELOGE(GRAPH_NODE_WITHOUT_CONST_INPUT, "[Check][Param] Node[%s] has no const input.", node_ptr->GetName().c_str()); - return GRAPH_NODE_WITHOUT_CONST_INPUT; -} - -graphStatus GNode::GetInputIndexByName(const AscendString &name, int32_t &index) { - const char_t *const ascend_name = name.GetString(); - if (std::string(ascend_name).empty()) { - REPORT_INNER_ERR_MSG("E18888", "ascend string error."); - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] GetInputIndexByName: ascend string error."); - return GRAPH_PARAM_INVALID; - } - - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node impl is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetInputIndexByName: node impl is nullptr."); - return GRAPH_FAILED; - } - - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "the node shared ptr is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetInputIndexByName: the node shared ptr is not valid."); - return GRAPH_FAILED; - } - - const auto op_desc = node_ptr->GetOpDescBarePtr(); - if (op_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "get Op desc of node[%s] failed.", node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][OpDesc] of node[%s] failed.", node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - - const std::string node_name = ascend_name; - index = op_desc->GetInputIndexByName(node_name); - - return GRAPH_SUCCESS; -} - -graphStatus GNode::GetOutputIndexByName(const AscendString &name, int32_t &index) { - const char_t *const ascend_name = name.GetString(); - if (std::string(ascend_name).empty()) { - REPORT_INNER_ERR_MSG("E18888", "ascend string error."); - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] GetOutputIndexByName: ascend string error."); - return GRAPH_PARAM_INVALID; - } - - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node impl is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetOutputIndexByName: node impl is nullptr."); - return GRAPH_FAILED; - } - - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "the node shared ptr is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetOutputIndexByName: the node shared ptr is not valid."); - return GRAPH_FAILED; - } - - const auto op_desc = node_ptr->GetOpDescBarePtr(); - if (op_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Get op desc of node[%s] failed.", node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][OpDesc] of node[%s] failed.", node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - - const std::string node_name = ascend_name; - index = op_desc->GetOutputIndexByName(node_name); - - return GRAPH_SUCCESS; -} - -graphStatus GNode::GetDynamicInputIndexesByName(const AscendString &name, std::vector &indexes) { - const std::string node_name = name.GetString(); - if (node_name.empty()) { - REPORT_INNER_ERR_MSG("E18888", "input name is empty."); - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] input name is empty."); - return GRAPH_PARAM_INVALID; - } - - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node impl is nullptr."); - GELOGE(GRAPH_FAILED, "[Check][Param] node impl is nullptr."); - return GRAPH_FAILED; - } - - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "the node shared ptr is nullptr"); - GELOGE(GRAPH_FAILED, "[Check][Param] the node shared ptr is not valid."); - return GRAPH_FAILED; - } - - const auto op_desc = node_ptr->GetOpDescBarePtr(); - if (op_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "get op desc of node[%s] failed.", node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][OpDesc] of node[%s] failed.", node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - - if (op_desc->GetDynamicInputIndexesByName(node_name, indexes) != GRAPH_SUCCESS || indexes.empty()) { - const std::string error_message = indexes.empty() ? - "the dynamic input indexes is empty, input name is " + node_name + "." : - "get dynamic input index by name failed, input name is " + node_name + "."; - REPORT_INNER_ERR_MSG("E18888", "%s", error_message.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] %s", error_message.c_str()); - return GRAPH_FAILED; - } - - return GRAPH_SUCCESS; -} - -graphStatus GNode::GetDynamicOutputIndexesByName(const AscendString &name, std::vector &indexes) { - const std::string node_name = name.GetString(); - if (node_name.empty()) { - REPORT_INNER_ERR_MSG("E18888", "output name is empty."); - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] output name is empty."); - return GRAPH_PARAM_INVALID; - } - - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node impl is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] node impl is nullptr."); - return GRAPH_FAILED; - } - - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "the node shared ptr is nullptr."); - GELOGE(GRAPH_FAILED, "[Check][Param] the node shared ptr is nullptr."); - return GRAPH_FAILED; - } - - const auto op_desc = node_ptr->GetOpDescBarePtr(); - if (op_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "get op desc of node[%s] failed.", node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][OpDesc] of node[%s] failed.", node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - - if (op_desc->GetDynamicOutputIndexesByName(node_name, indexes) != GRAPH_SUCCESS || indexes.empty()) { - const std::string error_message = indexes.empty() ? - "the dynamic output indexes is empty, output name is " + node_name + "." : - "get dynamic output index by name failed, output name is " + node_name + "."; - REPORT_INNER_ERR_MSG("E18888", "%s", error_message.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] %s", error_message.c_str()); - return GRAPH_FAILED; - } - - return GRAPH_SUCCESS; -} - -size_t GNode::GetInputsSize() const { - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node impl is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetInputsSize: node impl is nullptr."); - return GRAPH_FAILED; - } - - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "the node shared ptr is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetInputsSize: the node shared ptr is not valid."); - return GRAPH_FAILED; - } - - const OpDescPtr op_desc = node_ptr->GetOpDesc(); - if (op_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Get op desc of node[%s] failed.", node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][OpDesc] of node[%s] failed.", node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - - return op_desc->GetInputsSize(); -} - -size_t GNode::GetOutputsSize() const { - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node impl is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetOutputsSize: node impl is nullptr."); - return GRAPH_FAILED; - } - - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "the node shared ptr is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetOutputsSize: the node shared ptr is not valid."); - return GRAPH_FAILED; - } - - const OpDescPtr op_desc = node_ptr->GetOpDesc(); - if (op_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Get op desc of node[%s] failed.", node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][OpDesc] of node[%s] failed.", node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - - return op_desc->GetOutputsSize(); -} - -graphStatus GNode::GetInputDesc(const int32_t index, TensorDesc &tensor_desc) const { - if (index < 0) { - REPORT_INNER_ERR_MSG("E18888", "index:%d can not be less than zero, check invalid.", index); - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] GetInputDesc: index[%d] cannot be less than zero.", index); - return GRAPH_PARAM_INVALID; - } - - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node impl is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetInputDesc: node impl is nullptr."); - return GRAPH_FAILED; - } - - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "the node shared ptr is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetInputDesc: the node shared ptr is not valid."); - return GRAPH_FAILED; - } - - const OpDescPtr op_desc = node_ptr->GetOpDesc(); - if (op_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Get op desc of node[%s] failed.", node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][OpDesc] of node[%s] failed.", node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - - const ConstGeTensorDescPtr ge_tensor_desc = op_desc->GetInputDescPtr(static_cast(index)); - if (ge_tensor_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Get tensor desc of node[%s] failed.", node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][TensorDesc] of node[%s] failed.", node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - tensor_desc = TensorAdapter::GeTensorDesc2TensorDesc(*ge_tensor_desc); - - return GRAPH_SUCCESS; -} - -graphStatus GNode::UpdateInputDesc(const int32_t index, const TensorDesc &tensor_desc) { - if (index < 0) { - REPORT_INNER_ERR_MSG("E18888", "index:%d cannot be less than zero.", index); - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] UpdateInputDesc: index[%d] cannot be less than zero.", index); - return GRAPH_PARAM_INVALID; - } - - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node impl is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] UpdateInputDesc: node impl is nullptr."); - return GRAPH_FAILED; - } - - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "the node shared ptr is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] UpdateInputDesc: the node shared ptr is not valid."); - return GRAPH_FAILED; - } - - const OpDescPtr op_desc = node_ptr->GetOpDesc(); - if (op_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Get op desc of node[%s] failed.", node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][OpDesc] of node[%s] failed.", node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - - const GeTensorDesc ge_tensor_desc = TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc); - if (op_desc->UpdateInputDesc(static_cast(index), ge_tensor_desc) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Update input desc of node[%s] failed.", node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Update][InputDesc] of node[%s] failed.", node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - - return GRAPH_SUCCESS; -} - -graphStatus GNode::GetOutputDesc(const int32_t index, TensorDesc &tensor_desc) const { - if (index < 0) { - REPORT_INNER_ERR_MSG("E18888", "index:%d cannot be less than zero.", index); - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] GetOutputDesc: index[%d] cannot be less than zero.", index); - return GRAPH_PARAM_INVALID; - } - - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node impl is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetOutputDesc: node impl is nullptr."); - return GRAPH_FAILED; - } - - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "the node shared ptr is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetOutputDesc: the node shared ptr is not valid."); - return GRAPH_FAILED; - } - - const OpDescPtr op_desc = node_ptr->GetOpDesc(); - if (op_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Get op desc of node[%s] failed.", node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][OpDesc] of node[%s] failed.", node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - - const ConstGeTensorDescPtr ge_tensor_desc = op_desc->GetOutputDescPtr(static_cast(index)); - if (ge_tensor_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Get tensor desc of node[%s] failed.", node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][TensorDesc] of node[%s] failed.", node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - tensor_desc = TensorAdapter::GeTensorDesc2TensorDesc(*ge_tensor_desc); - - return GRAPH_SUCCESS; -} - -graphStatus GNode::UpdateOutputDesc(const int32_t index, const TensorDesc &tensor_desc) { - if (index < 0) { - REPORT_INNER_ERR_MSG("E18888", "index:%d cannot be less than zero.", index); - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] Gnode: index[%d] cannot be less than zero.", index); - return GRAPH_PARAM_INVALID; - } - - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node impl is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] UpdateOutputDesc: node impl is nullptr."); - return GRAPH_FAILED; - } - - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "the node shared ptr is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] UpdateOutputDesc: the node shared ptr is not valid."); - return GRAPH_FAILED; - } - - const OpDescPtr op_desc = node_ptr->GetOpDesc(); - if (op_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Get op desc of node[%s] failed.", node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][OpDesc] of node[%s] failed.", node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - - const GeTensorDesc ge_tensor_desc = TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc); - if (op_desc->UpdateOutputDesc(static_cast(index), ge_tensor_desc) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Update input desc of node[%s] failed.", node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Update][InputDesc] of node[%s] failed.", node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - - return GRAPH_SUCCESS; -} - -NODE_ATTR_GET_IMP(int64_t) -NODE_ATTR_GET_IMP(int32_t) -NODE_ATTR_GET_IMP(uint32_t) -NODE_ATTR_GET_IMP(float32_t) -NODE_ATTR_GET_IMP(bool) -NODE_ATTR_GET_IMP(Tensor) -NODE_ATTR_GET_IMP(std::vector) -NODE_ATTR_GET_IMP(std::vector) -NODE_ATTR_GET_IMP(std::vector) -NODE_ATTR_GET_IMP(std::vector) -NODE_ATTR_GET_IMP(std::vector) -NODE_ATTR_GET_IMP(std::vector) -NODE_ATTR_GET_IMP(OpBytes) -NODE_ATTR_GET_IMP(std::vector>) -NODE_ATTR_GET_IMP(std::vector) -NODE_ATTR_GET_IMP(ge::DataType) -NODE_ATTR_GET_IMP(AttrValue) - -NODE_ATTR_SET_IMP(int64_t) -NODE_ATTR_SET_IMP(int32_t) -NODE_ATTR_SET_IMP(uint32_t) -NODE_ATTR_SET_IMP(float32_t) -NODE_ATTR_SET_IMP(bool) -NODE_ATTR_SET_IMP(Tensor) -NODE_ATTR_SET_IMP(std::vector) -NODE_ATTR_SET_IMP(std::vector) -NODE_ATTR_SET_IMP(std::vector) -NODE_ATTR_SET_IMP(std::vector) -NODE_ATTR_SET_IMP(std::vector) -NODE_ATTR_SET_IMP(std::vector) -NODE_ATTR_SET_IMP(OpBytes) -NODE_ATTR_SET_IMP(std::vector>) -NODE_ATTR_SET_IMP(std::vector) -NODE_ATTR_SET_IMP(ge::DataType) - -graphStatus GNode::SetAttr(const AscendString &name, AttrValue &attr_value) const { - const char_t *const ascend_name = name.GetString(); - if (std::string(ascend_name).empty()) { - REPORT_INNER_ERR_MSG("E18888", "ascend string error."); - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] SetAttr: ascend string error."); - return GRAPH_PARAM_INVALID; - } - - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node impl is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] SetAttr: node impl is nullptr."); - return GRAPH_FAILED; - } - - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "the node shared ptr is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] SetAttr: the shared ptr is not valid."); - return GRAPH_FAILED; - } - - const std::string node_name = ascend_name; - Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr); - (void)op.SetAttr(node_name.c_str(), std::move(attr_value)); - return GRAPH_SUCCESS; -} - -graphStatus GNode::SetAttr(const AscendString &name, AscendString &attr_value) const { - const char_t *const ascend_name = name.GetString(); - if (std::string(ascend_name).empty()) { - REPORT_INNER_ERR_MSG("E18888", "name ascend string error"); - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] SetAttr: name ascend string error."); - return GRAPH_PARAM_INVALID; - } - - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node impl is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] SetAttr: node impl is nullptr."); - return GRAPH_FAILED; - } - - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "the shared ptr is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] SetAttr: the shared ptr is not valid."); - return GRAPH_FAILED; - } - const std::string node_name = ascend_name; - const std::string node_attr_value = attr_value.GetString(); - Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr); - (void)op.SetAttr(node_name, node_attr_value); - - return GRAPH_SUCCESS; -} - -graphStatus GNode::SetAttr(const AscendString &name, std::vector &attr_values) const { - const char_t *const ascend_name = name.GetString(); - if (std::string(ascend_name).empty()) { - REPORT_INNER_ERR_MSG("E18888", "name ascend string error."); - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] SetAttr: name ascend string error."); - return GRAPH_PARAM_INVALID; - } - - for (auto &attr_val : attr_values) { - const char_t *const ascend_attr_value = attr_val.GetString(); - if (std::string(ascend_attr_value).empty()) { - REPORT_INNER_ERR_MSG("E18888", "param attr values is invalid"); - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] SetAttr: attr val error."); - return GRAPH_PARAM_INVALID; - } - } - - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node impl is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] SetAttr: node impl is nullptr."); - return GRAPH_FAILED; - } - - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "the node shared ptr is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] SetAttr: the shared ptr is not valid."); - return GRAPH_FAILED; - } - std::vector node_attr_vals; - for (const auto &attr_val : attr_values) { - if (attr_val.GetString() != nullptr) { - std::string node_attr_val = attr_val.GetString(); - node_attr_vals.emplace_back(node_attr_val); - } - } - const std::string node_name = ascend_name; - Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr); - (void)op.SetAttr(node_name, node_attr_vals); - - return GRAPH_SUCCESS; -} - -graphStatus GNode::GetAttr(const AscendString &name, AscendString &attr_value) const { - const char_t *const ascend_name = name.GetString(); - if (std::string(ascend_name).empty()) { - REPORT_INNER_ERR_MSG("E18888", "name ascend string error."); - GELOGE(GRAPH_PARAM_INVALID, "GetAttr: name ascend string error."); - return GRAPH_PARAM_INVALID; - } - - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node impl is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetAttr: node impl is nullptr."); - return GRAPH_FAILED; - } - - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "the node shared ptr is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetAttr: the node shared ptr is not valid."); - return GRAPH_FAILED; - } - - const std::string node_name = ascend_name; - const Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr); - std::string op_name; - if (op.GetAttr(node_name, op_name) != GRAPH_SUCCESS) { - GELOGW("[Check][Param] Get attr of node[%s] failed.", node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - - const AscendString attr_value_get(op_name.c_str()); - attr_value = attr_value_get; - - return GRAPH_SUCCESS; -} - -graphStatus GNode::GetAttr(const AscendString &name, std::vector &attr_values) const { - const char_t *const ascend_name = name.GetString(); - if (std::string(ascend_name).empty()) { - REPORT_INNER_ERR_MSG("E18888", "name ascend string error."); - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] GetAttr: name ascend string error."); - return GRAPH_PARAM_INVALID; - } - - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node impl is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetAttr: node impl is nullptr."); - return GRAPH_FAILED; - } - - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "the node shared ptr is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetAttr: the node shared ptr is not valid."); - return GRAPH_FAILED; - } - - const std::string node_name = ascend_name; - const Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr); - std::vector attr_names; - if (op.GetAttr(node_name, attr_names) != GRAPH_SUCCESS) { - GELOGW("[Get][Attr] of node[%s] failed.", node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - - for (auto &attr_name : attr_names) { - const AscendString ascend_attr_name(attr_name.c_str()); - attr_values.push_back(ascend_attr_name); - } - - return GRAPH_SUCCESS; -} - -bool GNode::HasAttr(const AscendString &name) { - const char_t *const ascend_name = name.GetString(); - if (std::string(ascend_name).empty()) { - REPORT_INNER_ERR_MSG("E18888", "ascend string error."); - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] HasAttr: ascend string error."); - return false; - } - - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node impl is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] HasAttr: node impl is nullptr."); - return false; - } - - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "the node shared ptr is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] HasAttr: the node shared ptr is not valid."); - return false; - } - - const OpDescPtr op_desc = node_ptr->GetOpDesc(); - if (op_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Get op desc of node[%s] failed.", node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][OpDesc] of node[%s] failed.", node_ptr->GetName().c_str()); - return false; - } - const std::string attr_name = ascend_name; - if (!op_desc->HasAttr(attr_name)) { - GELOGW("[Call][HasAttr] Node[%s] has no attr name[%s]", node_ptr->GetName().c_str(), attr_name.c_str()); - return false; - } - - return true; -} - -graphStatus GNode::GetSubgraph(uint32_t index, GraphPtr &graph) const { - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node impl is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetSubgraph: node impl is nullptr."); - return GRAPH_FAILED; - } - - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "the node shared ptr is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetSubgraph: the node shared ptr is not valid."); - return GRAPH_FAILED; - } - - const ComputeGraphPtr compute_graph_ptr = NodeUtils::GetSubgraph(*node_ptr, index); - if (compute_graph_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "get subgraph[%u] failed from node[%s].", index, node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][SubGraph] subgraph[%u] from node[%s] is nullptr.", - index, node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - - graph = GraphUtilsEx::CreateGraphPtrFromComputeGraph(compute_graph_ptr); - if (graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "create compute graph failed from %s.", node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Create][Graph] failed from %s.", node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - - return GRAPH_SUCCESS; -} - -graphStatus GNode::GetALLSubgraphs(std::vector &graph_list) const { - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node impl is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetALLSubgraphs: node impl is nullptr."); - return GRAPH_FAILED; - } - - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "the node shared ptr is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetALLSubgraphs: the node shared ptr is not valid."); - return GRAPH_FAILED; - } - - const auto root_graph = GraphUtils::FindRootGraph(node_ptr->GetOwnerComputeGraph()); - if (root_graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Failed to find root graph from node %s ", node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][RootGraph] Failed to find root graph from node %s ", node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - const std::vector sub_graphs = root_graph->GetAllSubgraphs(); - if (sub_graphs.empty()) { - REPORT_INNER_ERR_MSG("E18888", "get all subgraphs failed from node[%s].", node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][ALLSubGraphs] failed from node[%s].", node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - - for (auto &sub_graph : sub_graphs) { - if (sub_graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "get subgraph failed from node[%s].", node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][SubGraph] failed from node[%s].", node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - GraphPtr graph = GraphUtilsEx::CreateGraphPtrFromComputeGraph(sub_graph); - if (graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "create compute graph failed from node[%s].", node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Create][ComputeGraph] failed from node[%s].", node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - graph_list.emplace_back(graph); - } - - if (graph_list.empty()) { - GELOGW("[Get][Subgraph] Node %s has no subgraph", node_ptr->GetName().c_str()); - } - - return GRAPH_SUCCESS; -} - -graphStatus GNode::SetSubgraph(const AscendString &subgraph_ir_name, const Graph &subgraph) { - GE_ASSERT_NOTNULL(impl_); - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - GE_ASSERT_NOTNULL(node_ptr, "Gnode has invalid node source, try to call `AddNodeByOp` firstly"); - - auto sub_compute_graph = GraphUtilsEx::GetComputeGraph(subgraph); - GE_ASSERT_NOTNULL(sub_compute_graph); - return NodeUtils::AddSubgraph(node_ptr, subgraph_ir_name.GetString(), sub_compute_graph); -} -graphStatus GNode::SetSubgraphs(const AscendString &subgraph_ir_name, const std::vector &subgraphs) { - GE_ASSERT_NOTNULL(impl_); - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - GE_ASSERT_NOTNULL(node_ptr, "Gnode has invalid node source, try to call `AddNodeByOp` firstly"); - - std::vector sub_compute_graphs{}; - for (const auto &subgraph : subgraphs) { - auto sub_compute_graph = GraphUtilsEx::GetComputeGraph(subgraph); - GE_ASSERT_NOTNULL(sub_compute_graph); - sub_compute_graphs.emplace_back(sub_compute_graph); - } - - return NodeUtils::AddSubgraphs(node_ptr, subgraph_ir_name.GetString(), sub_compute_graphs); -} - -namespace { -graphStatus GetAttrValue(const std::shared_ptr &node_ptr, const AscendString &name, - uint32_t index, AttrValue &attr_value, bool is_input) { - auto tensor_desc = is_input ? node_ptr->GetOpDesc()->MutableInputDesc(index) - : node_ptr->GetOpDesc()->MutableOutputDesc(index); - GE_ASSERT_NOTNULL(tensor_desc, "index: %u is invalid for node: %s %s", - index, node_ptr->GetNamePtr(), node_ptr->GetTypePtr()); - GE_WARN_ASSERT_GRAPH_SUCCESS(tensor_desc->GetAttr(name.GetString(), attr_value.impl->MutableAnyValue()), - "Attr: %s has not been set for node: %s %s", - name.GetString(), - node_ptr->GetNamePtr(), - node_ptr->GetTypePtr()); - return GRAPH_SUCCESS; -} - -graphStatus SetAttrValue(const std::shared_ptr &node_ptr, const AscendString &name, - uint32_t index, const AttrValue &attr_value, bool is_input) { - auto tensor_desc = is_input ? node_ptr->GetOpDesc()->MutableInputDesc(index) - : node_ptr->GetOpDesc()->MutableOutputDesc(index); - GE_ASSERT_NOTNULL(tensor_desc, "index: %u is invalid for node: %s %s", - index, node_ptr->GetNamePtr(), node_ptr->GetTypePtr()); - return tensor_desc->SetAttr(name.GetString(), attr_value.impl->MutableAnyValue()); -} -} - -// 实现新的AttrValue接口 -graphStatus GNode::GetOutputAttr(const AscendString &name, uint32_t output_index, AttrValue &attr_value) const { - GE_ASSERT_NOTNULL(impl_); - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - GE_ASSERT_NOTNULL(node_ptr, "Gnode has invalid node source, try to call `AddNodeByOp` firstly"); - return GetAttrValue(node_ptr, name, output_index, attr_value, false); -} - -graphStatus GNode::SetOutputAttr(const AscendString &name, uint32_t output_index, const AttrValue &attr_value) { - GE_ASSERT_NOTNULL(impl_); - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - GE_ASSERT_NOTNULL(node_ptr, "Gnode has invalid node source, try to call `AddNodeByOp` firstly"); - return SetAttrValue(node_ptr, name, output_index, attr_value, false); -} - -graphStatus GNode::GetInputAttr(const AscendString &name, uint32_t input_index, AttrValue &attr_value) const { - GE_ASSERT_NOTNULL(impl_); - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - GE_ASSERT_NOTNULL(node_ptr, "Gnode has invalid node source, try to call `AddNodeByOp` firstly"); - return GetAttrValue(node_ptr, name, input_index, attr_value, true); -} - -graphStatus GNode::SetInputAttr(const AscendString &name, uint32_t input_index, const AttrValue &attr_value) { - GE_ASSERT_NOTNULL(impl_); - const std::shared_ptr node_ptr = impl_->node_ptr_.lock(); - GE_ASSERT_NOTNULL(node_ptr, "Gnode has invalid node source, try to call `AddNodeByOp` firstly"); - return SetAttrValue(node_ptr, name, input_index, attr_value, true); -} -} // namespace ge diff --git a/graph/normal_graph/graph.cc b/graph/normal_graph/graph.cc deleted file mode 100644 index 30dc015ddfe372f6fa250173f42e81f4afc46c30..0000000000000000000000000000000000000000 --- a/graph/normal_graph/graph.cc +++ /dev/null @@ -1,1190 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "external/graph/graph.h" -#include "external/graph/graph_buffer.h" -#include -#include "debug/ge_util.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/debug/ge_op_types.h" -#include "graph/model.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/node_adapter.h" -#include "graph/utils/node_utils.h" -#include "graph/utils/graph_utils_ex.h" -#include "common/checker.h" -#include "graph/ge_attr_value.h" -#include "graph/ge_tensor.h" -#include "graph/type_utils.h" -#include "utils/tensor_adapter.h" -#include "proto/onnx/ge_onnx.pb.h" -#include "utils/ge_ir_utils.h" - -namespace ge { -class GraphImpl { - public: - friend class GraphUtils; - friend class GraphUtilsEx; - GraphImpl(const GraphImpl &) = delete; - GraphImpl &operator=(const GraphImpl &) = delete; - - explicit GraphImpl(const std::string &name) : name_(name) {} - - ~GraphImpl() { - if (IsValid()) { - if (compute_graph_ != nullptr) { - GraphUtilsEx::BreakConnect(compute_graph_->GetAllNodesInfo()); - } - } - for (const auto &it : op_list_) { - const Operator op = it.second; - op.BreakConnect(); - } - } - - graphStatus SetInputs(const std::vector &inputs) { - GE_ASSERT(!IsValid(), - "Inner graph has been inited, maybe you call `SetInputs` again or call `AddNodeByOp or SetValid` before `SetInputs`"); - compute_graph_ = GraphUtilsEx::CreateGraphFromOperator(name_, inputs); - GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, GRAPH_FAILED, "[Build][Graph] failed."); - GE_CHK_BOOL_RET_STATUS(inputs.size() != 0U, GRAPH_FAILED, "[Check][Param] set input NULL."); - compute_graph_->SetInputSize(static_cast(inputs.size())); - return GRAPH_SUCCESS; - } - - graphStatus SetOutputs(const std::vector &outputs) { - if (compute_graph_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "compute graph is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] set ComputeGraph failed."); - return GRAPH_FAILED; - } - if (outputs.empty()) { - GELOGI("Set outputs size is 0."); - return GRAPH_SUCCESS; - } - - // Construct special output node - std::vector>> output_indexs; - for (size_t i = 0U; i < outputs.size(); ++i) { - output_indexs.emplace_back(outputs[i], std::vector{}); - } - - const graphStatus ret = SetOutputs(output_indexs); - return ret; - } - - graphStatus SetOutputs(const std::vector>> &output_indexs) { - if (compute_graph_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "compute graph is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] set ComputeGraph failed."); - return GRAPH_FAILED; - } - if (output_indexs.empty()) { - GELOGW("[SetOutputs][CheckParam] Set outputs size is 0."); - return GRAPH_SUCCESS; - } - - // Construct special output node - std::vector> output_nodes; - for (const auto &item : output_indexs) { - const Operator &output = item.first; - const std::vector &indexs = item.second; - AscendString out_name; - (void) output.GetName(out_name); - ge::NodePtr node = compute_graph_->FindNode(out_name.GetString()); - if (node == nullptr) { - GELOGW("[SetOutputs][Check] User designated out_node %s does not exist in graph, skip it", - out_name.GetString()); - continue; - } - - const ge::OpDescPtr tmp_op_ptr = node->GetOpDesc(); - if (tmp_op_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "op_desc in node must not be null."); - continue; - } - const size_t out_size = tmp_op_ptr->GetOutputsSize(); - if (indexs.empty()) { - for (size_t i = 0U; i < out_size; ++i) { - output_name_ += std::string(out_name.GetString()) + ":" + std::to_string(i) + ";"; - output_nodes.emplace_back(node, i); - } - } else { - for (size_t i = 0U; i < indexs.size(); ++i) { - if (indexs[i] >= out_size) { - GELOGW("[SetOutputs][Check] User designated out_node %s has no output %zu, output_size=%zu, skip it", - out_name.GetString(), indexs[i], out_size); - } else { - output_name_ += std::string(out_name.GetString()) + ":" + std::to_string(i) + ";"; - output_nodes.emplace_back(node, indexs[i]); - } - } - } - } - - // Del last ";" - if (!output_name_.empty()) { - output_name_ = output_name_.substr(0U, output_name_.length() - 1U); - } - compute_graph_->SetUserDefOutput(output_name_); - compute_graph_->SetOutputSize(static_cast(output_indexs.size())); - compute_graph_->SetGraphOutNodesInfo(output_nodes); - return GRAPH_SUCCESS; - } - - graphStatus SetOutputs(const std::vector> &outputs) { - GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, GRAPH_FAILED, "[Check][Param] set ComputeGraph faild."); - if (outputs.empty()) { - GELOGI("set outputs size is 0."); - return GRAPH_SUCCESS; - } - - // Construct specified output - std::vector> output_nodes; - for (const auto &item : outputs) { - AscendString out_name; - (void) item.first.GetName(out_name); - ge::NodePtr node = compute_graph_->FindNode(out_name.GetString()); - if (node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "designated out_node (%s) does not exist in graph:%s, this out_node ignored!", - out_name.GetString(), compute_graph_->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Warning, user designated out_node (%s) does not exist in graph:%s, " - "this out_node ignored!", out_name.GetString(), compute_graph_->GetName().c_str()); - return GRAPH_FAILED; - } - const ge::OpDescPtr tmp_op_ptr = node->GetOpDesc(); - if (tmp_op_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "op_desc_ptr in node must not be null."); - continue; - } - const size_t out_size = tmp_op_ptr->GetOutputsSize(); - - if (item.second.empty()) { - for (size_t i = 0U; i < out_size; ++i) { - output_name_ += std::string(out_name.GetString()) + ":" + std::to_string(i) + ";"; - output_nodes.emplace_back(node, i); - } - } else { - int32_t index = tmp_op_ptr->GetOutputIndexByName(item.second); - if (index < 0) { - REPORT_INNER_ERR_MSG("E18888", - "user designated out_node (%s):(%s) does not exist in graph:%s, " - "this out_node ignored!", - out_name.GetString(), item.second.c_str(), compute_graph_->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Warning, user designated out_node (%s):(%s) does not exist in graph:%s, " - "this out_node ignored!", out_name.GetString(), item.second.c_str(), - compute_graph_->GetName().c_str()); - return GRAPH_FAILED; - } - output_name_ += std::string(out_name.GetString()) + ":" + std::to_string(index) + ";"; - output_nodes.emplace_back(node, index); - } - } - // Del last ";" - if (!output_name_.empty()) { - output_name_ = output_name_.substr(0U, output_name_.length() - 1U); - } - compute_graph_->SetOutputSize(static_cast(outputs.size())); - compute_graph_->SetGraphOutNodesInfo(output_nodes); - GELOGI("********************SetOutputs Success***********************"); - GE_IF_BOOL_EXEC(!output_name_.empty(), GELOGI(" NetOutputs: (%s)", output_name_.c_str())); - - return GRAPH_SUCCESS; - } - - graphStatus SetTargets(const std::vector &targets) { - GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, GRAPH_FAILED, "[Check][Param] set ComputeGraph faild."); - if (targets.empty()) { - GELOGI("set targets size is 0."); - return GRAPH_SUCCESS; - } - - std::vector target_nodes; - for (const auto &item : targets) { - AscendString name; - (void) item.GetName(name); - const ge::NodePtr node = compute_graph_->FindNode(name.GetString()); - if (node == nullptr) { - GELOGW("[SetTargets][Check] User designated target_node %s does not exist in graph, skip it", name.GetString()); - continue; - } - target_nodes.push_back(node); - } - compute_graph_->SetGraphTargetNodesInfo(target_nodes); - return GRAPH_SUCCESS; - } - bool IsValid() const { return (compute_graph_ != nullptr); } - - graphStatus AddOp(const ge::Operator &op) { - AscendString name; - (void) op.GetName(name); - const auto ret = op_list_.emplace(std::pair(name.GetString(), op)); - GE_CHK_BOOL_RET_STATUS(ret.second, GRAPH_FAILED, "[Check][Param] the op have added before, op name:%s.", - name.GetString()); - return GRAPH_SUCCESS; - } - - graphStatus GetAllOpName(std::vector &op_name) const { - for (const auto &it : op_list_) { - AscendString name; - it.second.GetName(name); - op_name.emplace_back(name.GetString()); - } - return GRAPH_SUCCESS; - } - - graphStatus FindOpByName(const std::string &name, ge::Operator &op) const { - const auto it = op_list_.find(name); - GE_CHK_BOOL_EXEC(it != op_list_.end(), - REPORT_INNER_ERR_MSG("E18888", "there is no op: %s.", name.c_str()); - return GRAPH_FAILED, "[Find][Op] there is no op: %s.", name.c_str()); - op = it->second; - return GRAPH_SUCCESS; - } - - graphStatus FindOpByType(const std::string &type, std::vector &ops) const { - for (auto &op : op_list_) { - AscendString op_type; - (void) op.second.GetOpType(op_type); - if (op_type.GetString() == type) { - ops.push_back(op.second); - continue; - } - if (op_type == ge::FRAMEWORKOP) { - (void) op.second.GetAttr(ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE.c_str(), op_type); - if (op_type.GetString() == type) { - ops.push_back(op.second); - } - } - } - return GRAPH_SUCCESS; - } - - void SetNeedIteration(bool need_iteration) { - if (compute_graph_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Set need iteration failed, as compute graph is null."); - GELOGE(GRAPH_FAILED, "[Check][Param] Set need iteration failed, as compute graph is null."); - return; - } - compute_graph_->SetNeedIteration(need_iteration); - } - - const std::string &GetName() const { - return name_; - } - - ComputeGraphPtr GetComputeGraph() const { - return compute_graph_; - } - - graphStatus InitComputeGraph() { - GE_WARN_ASSERT(compute_graph_ == nullptr, "No need to init again"); - compute_graph_ = ge::ComGraphMakeSharedAndThrow(name_); - return GRAPH_SUCCESS; - } - - graphStatus RemoveEdge(const NodePtr &src_node_ptr, const int32_t src_port_index, - const NodePtr &dst_node_ptr, const int32_t dst_port_index) { - GE_CHECK_NOTNULL(src_node_ptr); - GE_CHECK_NOTNULL(dst_node_ptr); - - graphStatus res = GRAPH_FAILED; - if ((src_port_index == -1) && (dst_port_index == -1)) { - if (src_node_ptr->GetOutControlAnchor() == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "src node:%s out control anchor is null.", src_node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][Anchor] src node:%s out control anchor is null.", src_node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - res = GraphUtils::RemoveEdge(src_node_ptr->GetOutControlAnchor(), dst_node_ptr->GetInControlAnchor()); - if (res != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "remove control edge between [%s] and [%s]failed.", - src_node_ptr->GetName().c_str(), dst_node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Remove][ControlEdge] between [%s] and [%s]failed.", - src_node_ptr->GetName().c_str(), dst_node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; - } - - if (src_node_ptr->GetOutDataAnchor(src_port_index) == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "src node[%s] out data anchor[%d] is null.", src_node_ptr->GetName().c_str(), - src_port_index); - GELOGE(GRAPH_FAILED, "[Get][Anchor] src node[%s] out data anchor[%d] is null.", - src_node_ptr->GetName().c_str(), src_port_index); - return GRAPH_FAILED; - } - - if ((src_port_index != -1) && (dst_port_index == -1)) { - res = GraphUtils::RemoveEdge(src_node_ptr->GetOutDataAnchor(src_port_index), dst_node_ptr->GetInControlAnchor()); - if (res != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "remove data-control edge between [%s] and [%s]failed.", - src_node_ptr->GetName().c_str(), dst_node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Remove][Edge] between [%s] and [%s]failed.", - src_node_ptr->GetName().c_str(), dst_node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; - } - - res = GraphUtils::RemoveEdge(src_node_ptr->GetOutDataAnchor(src_port_index), - dst_node_ptr->GetInDataAnchor(dst_port_index)); - if (res != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "remove data edge between [%s] and [%s] failed.", src_node_ptr->GetName().c_str(), - dst_node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Remove][Edge] between [%s] and [%s] failed.", - src_node_ptr->GetName().c_str(), dst_node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - - return GRAPH_SUCCESS; - } - - private: - std::string name_; - std::string output_name_; - std::map op_list_; - ComputeGraphPtr compute_graph_{nullptr}; -}; - -Graph::Graph(const std::string &name) { - impl_ = ComGraphMakeShared(name); - if (impl_ == nullptr) { - GELOGW("[Check][Impl] Make graph impl failed"); - } -} - -Graph::Graph(const char_t *name) { - if (name != nullptr) { - std::string graph_name = name; - impl_ = ComGraphMakeShared(graph_name); - if (impl_ == nullptr) { - GELOGW("[Check][Impl] Make graph impl failed"); - } - } else { - GELOGW("[Check][Param] Input graph name is nullptr."); - } -} - -graphStatus Graph::AddOp(const ge::Operator &op) { - GE_CHK_BOOL_EXEC(impl_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "graph can not be used, impl is nullptr."); - return GRAPH_FAILED, "[Check][Param] AddOp failed: graph can not be used, impl is nullptr."); - return impl_->AddOp(op); -} - -graphStatus Graph::GetAllOpName(std::vector &op_name) const { - GE_CHK_BOOL_EXEC(impl_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "graph can not be used, impl is nullptr."); - return GRAPH_FAILED, "[Check][Param] GetAllOpName failed: graph can not be used, impl is nullptr."); - return impl_->GetAllOpName(op_name); -} - -graphStatus Graph::GetAllOpName(std::vector &names) const { - GE_CHK_BOOL_EXEC(impl_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "graph can not be used, impl is nullptr."); - return GRAPH_FAILED, "[Check][Param] GetAllOpName failed: graph can not be used, impl is nullptr."); - std::vector op_names; - if (impl_->GetAllOpName(op_names) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Get][AllOpName] failed."); - return GRAPH_FAILED; - } - - for (auto &op_name : op_names) { - names.emplace_back(op_name.c_str()); - } - - return GRAPH_SUCCESS; -} - -graphStatus Graph::FindOpByName(const std::string &name, Operator &op) const { - const Operator op_find_op_def("NULL"); - op = op_find_op_def; - GE_CHK_BOOL_EXEC(impl_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "graph can not be used, impl is nullptr."); - return GRAPH_FAILED, "[Check][Param] FindOpByName failed: graph can not be used, impl is nullptr."); - return impl_->FindOpByName(name, op); -} - -graphStatus Graph::FindOpByName(const char_t *name, Operator &op) const { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] FindOpByName: name is nullptr."); - return GRAPH_FAILED; - } - const Operator op_find_op_def("NULL"); - op = op_find_op_def; - GE_CHK_BOOL_EXEC(impl_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "graph can not be used, impl is nullptr."); - return GRAPH_FAILED, "[Check][Param] FindOpByName failed: graph can not be used, impl is nullptr."); - const std::string op_name = name; - return impl_->FindOpByName(op_name, op); -} - -graphStatus Graph::FindOpByType(const std::string &type, std::vector &ops) const { - GE_CHECK_NOTNULL(impl_); - return impl_->FindOpByType(type, ops); -} - -graphStatus Graph::FindOpByType(const char_t *type, std::vector &ops) const { - if (type == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param type is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] FindOpByType: type is nullptr."); - return GRAPH_FAILED; - } - GE_CHECK_NOTNULL(impl_); - const std::string op_type = type; - return impl_->FindOpByType(op_type, ops); -} - -Graph &Graph::SetInputs(const std::vector &inputs) { - GE_CHK_BOOL_EXEC(impl_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "graph can not be used, impl is nullptr."); - return *this, "[Check][Param] SetInputs failed: graph can not be used, impl is nullptr."); - GE_CHK_BOOL_EXEC(!inputs.empty(), REPORT_INNER_ERR_MSG("E18888", "input operator size can not be 0"); - return *this, "[Check][Param] SetInputs failed: input operator size can not be 0, graph: %s", - impl_->GetName().c_str()); - (void)impl_->SetInputs(inputs); - return *this; -} - -Graph &Graph::SetOutputs(const std::vector &outputs) { - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "graph can not be used, impl is nullptr."); - GELOGE(GRAPH_FAILED, "[Check][Param] SetOutputs failed: graph can not be used, impl is nullptr."); - return *this; - } - (void)impl_->SetOutputs(outputs); - return *this; -} - -Graph &Graph::SetOutputs(const std::vector>> &output_indexs) { - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "graph can not be used, impl is nullptr."); - GELOGE(GRAPH_FAILED, "[Check][Param] SetOutputs failed: graph can not be used, impl is nullptr."); - return *this; - } - (void)impl_->SetOutputs(output_indexs); - return *this; -} - -Graph &Graph::SetOutputs(const std::vector> &outputs) { - GE_CHK_BOOL_EXEC(impl_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "graph can not be used, impl is nullptr."); - return *this, "[Check][Param] SetOutputs failed: graph can not be used, impl is nullptr."); - (void)impl_->SetOutputs(outputs); - return *this; -} - -Graph &Graph::SetOutputs(const std::vector> &outputs) { - GE_CHK_BOOL_EXEC(impl_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "graph can not be used, impl is nullptr."); - return *this, "[Check][Param] SetOutputs failed: graph can not be used, impl is nullptr."); - std::vector> graph_outputs; - for (auto &item : outputs) { - const char_t * const name = item.second.GetString(); - if (name != nullptr) { - graph_outputs.emplace_back((std::pair(item.first, name))); - } else { - GELOGW("[SetOutputs][CheckParam] Input output_op_name is nullptr."); - } - } - - (void)impl_->SetOutputs(graph_outputs); - return *this; -} - -Graph &Graph::SetTargets(const std::vector &targets) { - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "graph can not be used, impl is nullptr."); - GELOGE(GRAPH_FAILED, "[Check][Param] SetTargets failed: graph can not be used, impl is nullptr."); - return *this; - } - (void)impl_->SetTargets(targets); - return *this; -} - -bool Graph::IsValid() const { - if (impl_ == nullptr) { - return false; - } - return impl_->IsValid(); -} - -void Graph::SetNeedIteration(bool need_iteration) { - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "graph can not be used, impl is nullptr."); - GELOGE(GRAPH_FAILED, "[Check][Param] Set need iteration failed, as impl is null."); - return; - } - impl_->SetNeedIteration(need_iteration); -} - -std::vector Graph::GetAllNodes() const { - std::vector graph_nodes; - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "graph can not be used, impl is nullptr."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetAllNodes: graph can not be used, impl is nullptr."); - return graph_nodes; - } - - const ComputeGraphPtr compute_graph_ptr = impl_->GetComputeGraph(); - if (compute_graph_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "impl compute graph is nullptr."); - GELOGE(GRAPH_FAILED, "[Get][Graph] GetAllNodes: compute graph ptr is nullptr, graph %s", impl_->GetName().c_str()); - return graph_nodes; - } - - for (auto &node : compute_graph_ptr->GetAllNodes()) { - GNode gnode = NodeAdapter::Node2GNode(node); - graph_nodes.emplace_back(gnode); - } - - return graph_nodes; -} - -std::vector Graph::GetDirectNode() const { - std::vector graph_nodes; - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "graph can not be used, impl is nullptr."); - GELOGE(GRAPH_FAILED, "[Check][Param] GetDirectNode: graph can not be used, impl is nullptr."); - return graph_nodes; - } - const ComputeGraphPtr compute_graph_ptr = impl_->GetComputeGraph(); - if (compute_graph_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "impl compute graph is nullptr."); - GELOGE(GRAPH_FAILED, "[Get][Graph] GetDirectNode: compute graph ptr is nullptr, graph %s", - impl_->GetName().c_str()); - return graph_nodes; - } - - for (auto &node : compute_graph_ptr->GetDirectNode()) { - GNode gnode = NodeAdapter::Node2GNode(node); - graph_nodes.emplace_back(gnode); - } - - return graph_nodes; -} - -graphStatus Graph::RemoveNode(GNode &node) { - return RemoveNode(node, false); -} - -graphStatus Graph::RemoveNode(GNode &node, bool contain_subgraph) { - GE_CHECK_NOTNULL(impl_); - - const NodePtr node_ptr = NodeAdapter::GNode2Node(node); - GE_CHECK_NOTNULL(node_ptr); - - const ComputeGraphPtr owner_compute_graph = node_ptr->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(owner_compute_graph); - - ComputeGraphPtr compute_graph_ptr = impl_->GetComputeGraph(); - GE_CHECK_NOTNULL(compute_graph_ptr); - - if (contain_subgraph) { - if (!GraphUtils::IsNodeInGraphRecursively(compute_graph_ptr, *node_ptr)) { - REPORT_INNER_ERR_MSG("E18888", "node[%s] is not in the graph[%s] or not in subgraph.", - node_ptr->GetName().c_str(), compute_graph_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] is not in the graph[%s].", - node_ptr->GetName().c_str(), compute_graph_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - compute_graph_ptr = owner_compute_graph; - } else { - if (compute_graph_ptr != owner_compute_graph) { - REPORT_INNER_ERR_MSG("E18888", "node[%s] is not in the graph[%s].", node_ptr->GetName().c_str(), - compute_graph_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] is not in the graph[%s].", - node_ptr->GetName().c_str(), compute_graph_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - } - - ge::NodeUtils::UnlinkAll(*node_ptr); - if (GraphUtils::RemoveNodeWithoutRelink(compute_graph_ptr, node_ptr) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "graph:%s remove node:%s failed", compute_graph_ptr->GetName().c_str(), - node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Remove][Node] %s from graph:%s failed.", - node_ptr->GetName().c_str(), compute_graph_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - (void)node_ptr->ClearOwnerGraph(nullptr); - return GRAPH_SUCCESS; -} - -graphStatus Graph::RemoveEdge(GNode &src_node, const int32_t src_port_index, - GNode &dst_node, const int32_t dst_port_index) { - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "graph can not be used, impl is nullptr."); - GELOGE(GRAPH_FAILED, "[Check][Param] graph can not be used, impl is nullptr."); - return GRAPH_FAILED; - } - - if ((src_port_index == -1) && (dst_port_index != -1)) { - REPORT_INNER_ERR_MSG("E18888", "src_port_index == -1 and dst_port_index != -1, check invalid ."); - GELOGE(GRAPH_FAILED, "[Check][Param] src control anchor link to dst data anchor does not exist."); - return GRAPH_FAILED; - } - - const NodePtr src_node_ptr = NodeAdapter::GNode2Node(src_node); - if (src_node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "src gnode to node failed."); - GELOGE(GRAPH_FAILED, "[Get][Node] src gnode to node failed."); - return GRAPH_FAILED; - } - - const NodePtr dst_node_ptr = NodeAdapter::GNode2Node(dst_node); - if (dst_node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "dst gnode to node failed."); - GELOGE(GRAPH_FAILED, "[Get][Node] dst gnode to node failed."); - return GRAPH_FAILED; - } - - if (src_node_ptr->GetOwnerComputeGraph() == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "src node:%s compute graph is nullptr.", src_node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][Graph] src node:%s compute graph is nullptr.", src_node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - - if (dst_node_ptr->GetOwnerComputeGraph() == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "dst node:%s compute graph is nullptr", dst_node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][Graph] dst node:%s compute graph is nullptr.", dst_node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - - if (impl_->RemoveEdge(src_node_ptr, src_port_index, dst_node_ptr, dst_port_index) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "remove edge between %s(%d) and %s(%d) failed.", src_node_ptr->GetName().c_str(), - src_port_index, dst_node_ptr->GetName().c_str(), dst_port_index); - GELOGE(GRAPH_FAILED, "[Remove][Edge] between %s(%d) and %s(%d) failed.", - src_node_ptr->GetName().c_str(), src_port_index, dst_node_ptr->GetName().c_str(), dst_port_index); - return GRAPH_FAILED; - } - - return GRAPH_SUCCESS; -} - -GNode Graph::AddNodeByOp(const Operator &op) { - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "graph can not be used, impl is nullptr."); - GELOGE(GRAPH_FAILED, "[Check][Param] graph can not be used, impl is nullptr."); - return GNode(); - } - - const std::shared_ptr op_desc = ge::OpDescUtils::GetOpDescFromOperator(op); - if (op_desc == nullptr) { - AscendString name; - (void) op.GetName(name); - REPORT_INNER_ERR_MSG("E18888", "get op desc from op:%s failed", name.GetString()); - GELOGE(GRAPH_FAILED, "[Get][OpDesc] from op[%s] failed.", name.GetString()); - return GNode(); - } - - GE_ASSERT_GRAPH_SUCCESS(SetValid()); - const auto compute_graph_ptr = impl_->GetComputeGraph(); - GE_ASSERT_NOTNULL(compute_graph_ptr); - const NodePtr node_ptr = compute_graph_ptr->AddNode(op_desc); - const GNode gnode = NodeAdapter::Node2GNode(node_ptr); - - return gnode; -} - -graphStatus Graph::AddDataEdge(GNode &src_node, const int32_t src_port_index, - GNode &dst_node, const int32_t dst_port_index) { - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "graph can not be used, impl is nullptr."); - GELOGE(GRAPH_FAILED, "[Check][Param] graph can not be used, impl is nullptr."); - return GRAPH_FAILED; - } - - const NodePtr src_node_ptr = NodeAdapter::GNode2Node(src_node); - if (src_node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "src gnode to node failed."); - GELOGE(GRAPH_FAILED, "[Get][Node] src gnode to node failed."); - return GRAPH_FAILED; - } - - const NodePtr dst_node_ptr = NodeAdapter::GNode2Node(dst_node); - if (dst_node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "dst gnode to node failed."); - GELOGE(GRAPH_FAILED, "[Get][Node] dst gnode to node failed."); - return GRAPH_FAILED; - } - - if (src_node_ptr->GetOwnerComputeGraph() == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "src node[%s] owner compute graph is nullptr.", src_node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][Graph] src node[%s] owner compute graph is nullptr.", src_node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - - if (dst_node_ptr->GetOwnerComputeGraph() == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "dst node[%s] owner compute graph is nullptr.", dst_node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][Graph] dst node[%s] owner compute graph is nullptr.", dst_node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - - const graphStatus res = GraphUtils::AddEdge(src_node_ptr->GetOutDataAnchor(src_port_index), - dst_node_ptr->GetInDataAnchor(dst_port_index)); - if (res != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "add data edge from %s(%d) to %s(%d) failed.", src_node_ptr->GetName().c_str(), - src_port_index, dst_node_ptr->GetName().c_str(), dst_port_index); - GELOGE(GRAPH_FAILED, "[Add][DataEdge] from %s(%d) to %s(%d) failed.", src_node_ptr->GetName().c_str(), - src_port_index, dst_node_ptr->GetName().c_str(), dst_port_index); - return GRAPH_FAILED; - } - - return GRAPH_SUCCESS; -} - -graphStatus Graph::AddControlEdge(GNode &src_node, GNode &dst_node) { - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "graph can not be used, impl is nullptr."); - GELOGE(GRAPH_FAILED, "[Check][Param] graph can not be used, impl is nullptr."); - return GRAPH_FAILED; - } - - const NodePtr src_node_ptr = NodeAdapter::GNode2Node(src_node); - if (src_node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "src gnode to node failed."); - GELOGE(GRAPH_FAILED, "[Get][Node] src gnode to node failed."); - return GRAPH_FAILED; - } - - const NodePtr dst_node_ptr = NodeAdapter::GNode2Node(dst_node); - if (dst_node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "dst gnode to node failed."); - GELOGE(GRAPH_FAILED, "[Get][Node] dst gnode to node failed."); - return GRAPH_FAILED; - } - - if (src_node_ptr->GetOwnerComputeGraph() == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "src node[%s] owner compute graph is nullptr.", src_node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][Graph] src node[%s] owner compute graph is nullptr.", src_node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - - if (dst_node_ptr->GetOwnerComputeGraph() == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "dst node[%s] owner compute graph is nullptr.", dst_node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][Graph] dst node[%s] owner compute graph is nullptr.", dst_node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - - const graphStatus res = GraphUtils::AddEdge(src_node_ptr->GetOutControlAnchor(), dst_node_ptr->GetInControlAnchor()); - if (res != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "add control edge from %s to %s failed.", src_node_ptr->GetName().c_str(), - dst_node_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Add][ControlEdge] from %s to %s failed.", src_node_ptr->GetName().c_str(), - dst_node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - - return SUCCESS; -} - -GraphPtr Graph::ConstructFromInputs(const std::vector &inputs, const AscendString &name) { - const char_t *const ascend_name = name.GetString(); - if (ascend_name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "ascend string error"); - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] ascend string error."); - return nullptr; - } - - if (inputs.empty()) { - REPORT_INNER_ERR_MSG("E18888", "inputs size can not be 0."); - GELOGE(GRAPH_FAILED, "[Check][Param] inputs size can not be 0, graph: %s", ascend_name); - return nullptr; - } - - const std::string graph_name = ascend_name; - const ComputeGraphPtr compute_graph = GraphUtilsEx::CreateGraphFromOperator(graph_name, inputs); - if (compute_graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "create compute graph from op failed, name:%s", graph_name.c_str()); - GELOGE(GRAPH_FAILED, "[Create][ComputeGraph] failed, name:%s.", graph_name.c_str()); - return nullptr; - } - - compute_graph->SetInputSize(static_cast(inputs.size())); - const GraphPtr graph_ptr = GraphUtilsEx::CreateGraphPtrFromComputeGraph(compute_graph); - if (graph_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "create graph from compute graph:%s failed.", compute_graph->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Create][Graph] from compute graph:%s failed.", compute_graph->GetName().c_str()); - return nullptr; - } - - return graph_ptr; -} - -graphStatus Graph::SaveToFile(const std::string &file_name) const { - Model model = Model(); - model.SetGraph(GraphUtilsEx::GetComputeGraph(*this)); - return model.SaveToFile(file_name); -} - -graphStatus Graph::SaveToFile(const char_t *file_name) const { - if (file_name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "file name is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] file name is nullptr."); - return GRAPH_FAILED; - } - - Model model = Model(); - model.SetGraph(GraphUtilsEx::GetComputeGraph(*this)); - const std::string name = file_name; - return model.SaveToFile(name); -} - -graphStatus Graph::LoadFromFile(const std::string &file_name) { - Model model = Model(); - const graphStatus ret = model.LoadFromFile(file_name); - if (ret != GRAPH_SUCCESS) { - return ret; - } - *this = GraphUtilsEx::CreateGraphFromComputeGraph(model.GetGraph()); - return GRAPH_SUCCESS; -} - -graphStatus Graph::LoadFromFile(const char_t *file_name) { - if (file_name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param file name is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] file name is nullptr."); - return GRAPH_FAILED; - } - - Model model = Model(); - const std::string file = file_name; - const graphStatus ret = model.LoadFromFile(file); - if (ret != GRAPH_SUCCESS) { - return ret; - } - *this = GraphUtilsEx::CreateGraphFromComputeGraph(model.GetGraph()); - return GRAPH_SUCCESS; -} - -graphStatus Graph::LoadFromSerializedModelArray(const void *serialized_model, size_t size) { - GE_ASSERT_NOTNULL(serialized_model, "param serialized_model is nullptr"); - GE_ASSERT(size != 0U, "param size is 0"); - Model model; - GE_ASSERT_GRAPH_SUCCESS(Model::Load(static_cast(serialized_model), size, model), - "Failed to load model from serialized model def."); - GE_ASSERT_NOTNULL(model.GetGraph(), "Failed to get root graph from model."); - *this = GraphUtilsEx::CreateGraphFromComputeGraph(model.GetGraph()); - return GRAPH_SUCCESS; -} - -graphStatus Graph::SaveToMem(GraphBuffer &graph_buffer) const -{ - Model model = Model(); - model.SetGraph(GraphUtilsEx::GetComputeGraph(*this)); - GE_ASSERT_GRAPH_SUCCESS(model.Save(*(graph_buffer.buffer_)), "Failed to save graph to memory."); - return GRAPH_SUCCESS; -} - -graphStatus Graph::LoadFromMem(const GraphBuffer &graph_buffer) -{ - Model model = Model(); - GE_ASSERT_GRAPH_SUCCESS(Model::Load(graph_buffer.GetData(), graph_buffer.GetSize(), model), - "Failed to load graph from memory."); - - *this = GraphUtilsEx::CreateGraphFromComputeGraph(model.GetGraph()); - return GRAPH_SUCCESS; -} - -graphStatus Graph::LoadFromMem(const uint8_t *data, const size_t len) -{ - Model model = Model(); - GE_ASSERT_GRAPH_SUCCESS(Model::Load(data, len, model), "Failed to load graph from memory."); - - *this = GraphUtilsEx::CreateGraphFromComputeGraph(model.GetGraph()); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -const std::string &Graph::GetName() const { - return impl_->GetName(); -} - -graphStatus Graph::GetName(AscendString &name) const { - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "impl is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] impl is nullptr."); - return GRAPH_FAILED; - } - const std::string graph_name = impl_->GetName(); - name = AscendString(graph_name.c_str()); - return GRAPH_SUCCESS; -} - -graphStatus Graph::CopyFrom(const Graph &src_graph) { - const auto res = GraphUtilsEx::CopyGraph(src_graph, *this); - if (res != GRAPH_SUCCESS) { - AscendString name; - (void)src_graph.GetName(name); - REPORT_INNER_ERR_MSG("E18888", "copy graph from %s failed.", name.GetString()); - GELOGE(GRAPH_FAILED, "[Copy][Graph] from %s failed.", name.GetString()); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -// 添加AttrValue版本的SetAttr和GetAttr方法 -graphStatus Graph::SetAttr(const AscendString &name, const AttrValue &attr_value) { - GE_ASSERT_NOTNULL(impl_); - const auto compute_graph = impl_->GetComputeGraph(); - GE_ASSERT_NOTNULL(compute_graph, "Inner source is not ready, call `SetInputs` or `SetValid` at first"); - return compute_graph->SetAttr(name.GetString(), attr_value.impl->MutableAnyValue()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Graph::GetAttr(const AscendString &name, - AttrValue &attr_value) const { - GE_ASSERT_NOTNULL(impl_); - const auto compute_graph = impl_->GetComputeGraph(); - GE_ASSERT_NOTNULL(compute_graph, "Inner source is not ready, call `SetInputs` or `SetValid` at first"); - return compute_graph->GetAttr(name.GetString(), attr_value.impl->MutableAnyValue()); -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Graph::SetValid() { - GE_ASSERT_NOTNULL(impl_); - if (IsValid()) { - return GRAPH_SUCCESS; - } - GELOGD("Inner graph is invalid, start to init at first"); - return impl_->InitComputeGraph(); -} - -namespace { -graphStatus ConvertComputeGraphToProto(Graph::DumpFormat format, const ComputeGraphPtr &compute_graph, - google::protobuf::Message *const proto) { - ge::Model model("", ""); - model.SetGraph(compute_graph); - switch (format) { - case Graph::DumpFormat::kOnnx: { - model.SetName(compute_graph->GetName()); - auto *model_proto = reinterpret_cast(proto); - GE_ASSERT_NOTNULL(model_proto); - GE_ASSERT_TRUE(OnnxUtils::ConvertGeModelToModelProto(model, *model_proto, DumpLevel::DUMP_WITH_OUT_DESC), - "[Convert][GeModel] DumpGEGraphToOnnx failed."); - break; - } - case Graph::DumpFormat::kTxt: { - auto *model_proto = reinterpret_cast(proto); - GE_ASSERT_NOTNULL(model_proto); - GE_ASSERT_GRAPH_SUCCESS(model.Save(*model_proto, true)); - break; - } - } - return SUCCESS; -} -} - -graphStatus Graph::Dump(Graph::DumpFormat format, std::ostream &o_stream) const { - GE_ASSERT_TRUE(IsValid()); - const auto compute_graph = impl_->GetComputeGraph(); - GE_ASSERT_NOTNULL(compute_graph); - - onnx::ModelProto onnx_model_proto; - ge::proto::ModelDef txt_model_proto; - google::protobuf::Message *proto = nullptr; - switch (format) { - case DumpFormat::kOnnx: { - proto = &onnx_model_proto; - break; - } - case DumpFormat::kTxt: { - proto = &txt_model_proto; - break; - } - } - GE_ASSERT_NOTNULL(proto); - GE_ASSERT_GRAPH_SUCCESS(ConvertComputeGraphToProto(format, compute_graph, proto)); - GraphUtils::WriteProtoToOStream(*proto, o_stream); - return GRAPH_SUCCESS; -} - -graphStatus Graph::DumpToFile(Graph::DumpFormat format, const AscendString &suffix) const { -#ifdef FMK_SUPPORT_DUMP - GE_ASSERT_TRUE(IsValid()); - const auto compute_graph = impl_->GetComputeGraph(); - GE_ASSERT_NOTNULL(compute_graph); - - onnx::ModelProto onnx_model_proto; - ge::proto::ModelDef txt_model_proto; - google::protobuf::Message *proto = nullptr; - std::string file_absolut_name; - switch (format) { - case DumpFormat::kOnnx: { - proto = &onnx_model_proto; - GE_ASSERT_GRAPH_SUCCESS(GraphUtils::GenDumpOnnxFileName(compute_graph, suffix.GetString(), file_absolut_name)); - break; - } - case DumpFormat::kTxt: { - proto = &txt_model_proto; - GE_ASSERT_GRAPH_SUCCESS(GraphUtils::GenDumpTxtFileName(compute_graph, suffix.GetString(), "", file_absolut_name)); - break; - } - } - GE_ASSERT_NOTNULL(proto); - GE_ASSERT_GRAPH_SUCCESS(ConvertComputeGraphToProto(format, compute_graph, proto)); - GraphUtils::WriteProtoToTextFile(*proto, file_absolut_name.c_str()); - return SUCCESS; -#else - (void)format; - (void)suffix; - GELOGW("[Graph][DumpToFile] Need to define FMK_SUPPORT_DUMP for dump graph."); - return FAILED; -#endif -} - -GNodePtr Graph::FindNodeByName(const AscendString &node_name) const { - if (!IsValid()) { - REPORT_INNER_ERROR("E18888", "current graph is invalid."); - GELOGE(GRAPH_FAILED, "[Find][Node] impl is invalid."); - return nullptr; - } - auto node = impl_->GetComputeGraph()->FindNode(node_name.GetString()); - GE_ASSERT_NOTNULL(node, "Node name: %s was not found in the current graph:%s.", node_name.GetString(), impl_->GetName().c_str()); - return NodeAdapter::Node2GNodePtr(node); -} - -ConstGraphPtr Graph::GetParentGraph() const { - if (!IsValid()) { - REPORT_INNER_ERROR("E18888", "current graph is invalid."); - GELOGE(GRAPH_FAILED, "[Get][ParentGraph] current impl is invalid."); - return nullptr; - } - auto parent_compute_graph_ptr = impl_->GetComputeGraph()->GetParentGraph(); - GE_ASSERT_NOTNULL(parent_compute_graph_ptr); - return GraphUtilsEx::CreateGraphPtrFromComputeGraph(parent_compute_graph_ptr); -} - -GNodePtr Graph::GetParentNode() const { - if (!IsValid()) { - REPORT_INNER_ERROR("E18888", "current graph is invalid."); - GELOGE(GRAPH_FAILED, "[Get][ParentNode] current impl is invalid."); - return nullptr; - } - auto parent_graph_node_ptr = impl_->GetComputeGraph()->GetParentNode(); - GE_ASSERT_NOTNULL(parent_graph_node_ptr); - return NodeAdapter::Node2GNodePtr(parent_graph_node_ptr); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::CopyGraphImpl(const Graph &src_graph, Graph &dst_graph, - const std::map &node_old_2_new, - const std::map &op_desc_old_2_new) { - GE_CHECK_NOTNULL(dst_graph.impl_); - GE_CHECK_NOTNULL(src_graph.impl_); - - std::map &dst_op_list = dst_graph.impl_->op_list_; - const std::map &src_op_list = src_graph.impl_->op_list_; - auto &dst_compute_graph = dst_graph.impl_->compute_graph_; - - dst_graph.impl_->output_name_ = src_graph.impl_->output_name_; - - auto ret = OpDescUtils::CopyOperators(dst_compute_graph, - node_old_2_new, op_desc_old_2_new, - src_op_list, dst_op_list); - if (ret != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "copy operators to graph:%s failed.", dst_compute_graph->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Copy][Operators] to graph:%s failed.", dst_compute_graph->GetName().c_str()); - return GRAPH_FAILED; - } - - ret = OpDescUtils::CopyOperatorLinks(src_op_list, dst_op_list); - if (ret != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "copy operator links failed, ret:%u.", ret); - GELOGE(GRAPH_FAILED, "[Copy][OperatorLinks] failed, ret:%u.", ret); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr GraphUtilsEx::GetComputeGraph(const ge::Graph &graph) { - if (!graph.IsValid()) { - return nullptr; - } - return graph.impl_->compute_graph_; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtilsEx::CreateGraphFromOperatorWithStableTopo( - Graph &graph, - const std::vector &ops) { - AscendString graph_name; - GE_ASSERT_SUCCESS(graph.GetName(graph_name)); - GE_ASSERT_TRUE(graph.impl_->compute_graph_ == nullptr, "Compute graph of graph: %s has been created", - graph_name.GetString()); - graph.impl_->compute_graph_ = - GraphUtilsEx::CreateComputeGraphFromOperatorWithStableTopo(graph_name.GetString(), ops); - GE_ASSERT_NOTNULL(graph.impl_->compute_graph_); - return SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph -GraphUtilsEx::CreateGraphFromComputeGraph(const ge::ComputeGraphPtr compute_graph) { - if (compute_graph == nullptr) { - return Graph(""); - } - - const auto name = compute_graph->GetName(); - const auto graph = Graph(name.c_str()); - if (graph.impl_ == nullptr) { - return graph; - } - graph.impl_->compute_graph_ = compute_graph; - return graph; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -std::unique_ptr GraphUtilsEx::CreateGraphUniquePtrFromComputeGraph(const ComputeGraphPtr &compute_graph) { - GE_ASSERT_NOTNULL(compute_graph); - auto name = compute_graph->GetName(); - auto graph = ComGraphMakeUnique(name.c_str()); - GE_ASSERT_NOTNULL(graph); - GE_ASSERT_NOTNULL(graph->impl_); - graph->impl_->compute_graph_ = compute_graph; - return graph; -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GraphPtr -GraphUtilsEx::CreateGraphPtrFromComputeGraph(const ge::ComputeGraphPtr compute_graph) { - return CreateGraphUniquePtrFromComputeGraph(compute_graph); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -graphStatus GraphUtilsEx::RecoverGraphOperators(const Graph &graph) { - GE_CHECK_NOTNULL(graph.impl_); - GE_CHECK_NOTNULL(graph.impl_->compute_graph_); - - graph.impl_->op_list_.clear(); - for (const auto &node : graph.impl_->compute_graph_->GetDirectNode()) { - graph.impl_->op_list_[node->GetName()] = OpDescUtils::CreateOperatorFromNode(node); - } - return SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtilsEx::CopyGraphImpl(const Graph &src_graph, Graph &dst_graph, - const std::map &node_old_2_new, - const std::map &op_desc_old_2_new) { - GE_CHECK_NOTNULL(dst_graph.impl_); - GE_CHECK_NOTNULL(src_graph.impl_); - - std::map &dst_op_list = dst_graph.impl_->op_list_; - const std::map &src_op_list = src_graph.impl_->op_list_; - auto &dst_compute_graph = dst_graph.impl_->compute_graph_; - - dst_graph.impl_->output_name_ = src_graph.impl_->output_name_; - - auto ret = OpDescUtils::CopyOperators(dst_compute_graph, - node_old_2_new, op_desc_old_2_new, - src_op_list, dst_op_list); - if (ret != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "copy operators to graph:%s failed.", dst_compute_graph->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Copy][Operators] to graph:%s failed.", dst_compute_graph->GetName().c_str()); - return GRAPH_FAILED; - } - - ret = OpDescUtils::CopyOperatorLinks(src_op_list, dst_op_list); - if (ret != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "copy operator links failed, ret:%u.", ret); - GELOGE(GRAPH_FAILED, "[Copy][OperatorLinks] failed, ret:%u.", ret); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} -} // namespace ge diff --git a/graph/normal_graph/model.cc b/graph/normal_graph/model.cc deleted file mode 100644 index e1ef5770624ff78a9a623072cb8e934bb117093c..0000000000000000000000000000000000000000 --- a/graph/normal_graph/model.cc +++ /dev/null @@ -1,249 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/model.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "graph/debug/ge_attr_define.h" -#include "debug/ge_util.h" -#include "common/ge_common/debug/ge_log.h" -#include "graph/model_serialize.h" -#include "graph/utils/file_utils.h" -#include "mmpa/mmpa_api.h" -#include "graph/utils/attr_utils.h" -#include "graph/utils/ge_ir_utils.h" -#include "common/checker.h" -#include "proto/ge_ir.pb.h" - -namespace { -using google::protobuf::io::FileInputStream; -using google::protobuf::io::FileOutputStream; -using google::protobuf::io::ZeroCopyInputStream; - -const int32_t DEFAULT_VERSION = 1; -const int32_t ACCESS_PERMISSION_BITS = 256; // 0400; -static ge::ModelSerialize SERIALIZE; -} // namespace - -namespace ge { -static char_t *GetStrError() { - constexpr size_t kMaxErrLen = 128U; - char_t err_buf[kMaxErrLen + 1U] = {}; - const auto str_error = mmGetErrorFormatMessage(mmGetErrorCode(), &err_buf[0], kMaxErrLen); - return str_error; -} - -void Model::Init() { - (void)AttrUtils::SetInt(this, ATTR_MODEL_MEMORY_SIZE, 0); - (void)AttrUtils::SetInt(this, ATTR_MODEL_P2P_MEMORY_SIZE, 0); - (void)AttrUtils::SetInt(this, ATTR_MODEL_STREAM_NUM, 0); - (void)AttrUtils::SetInt(this, ATTR_MODEL_EVENT_NUM, 0); - (void)AttrUtils::SetInt(this, ATTR_MODEL_LABEL_NUM, 0); - (void)AttrUtils::SetInt(this, ATTR_MODEL_WEIGHT_SIZE, 0); - (void)AttrUtils::SetStr(this, ATTR_MODEL_TARGET_TYPE, TARGET_TYPE_MINI); - version_ = 0U; -} - -Model::Model() :AttrHolder() { - Init(); -} - -Model::Model(const std::string &name, const std::string &custom_version) - : AttrHolder(), name_(name), version_(static_cast(DEFAULT_VERSION)), platform_version_(custom_version) { - Init(); -} - -Model::Model(const char_t *name, const char_t *custom_version) - : Model(std::string(name == nullptr ? "" : name), - std::string(custom_version == nullptr ? "" : custom_version)) {} - -std::string Model::GetName() const { return name_; } - -void Model::SetName(const std::string &name) { name_ = name; } - -uint32_t Model::GetVersion() const { return version_; } - -std::string Model::GetPlatformVersion() const { return platform_version_; } - -void Model::SetGraph(const ComputeGraphPtr &graph) { graph_ = graph; } - -const ComputeGraphPtr Model::GetGraph() const { return graph_; } - -graphStatus Model::Save(Buffer &buffer, const bool is_dump) const { - buffer = SERIALIZE.SerializeModel(*this, is_dump); - return (buffer.GetSize() > 0U) ? GRAPH_SUCCESS : GRAPH_FAILED; -} - -graphStatus Model::SaveWithoutSeparate(Buffer &buffer, - const bool is_dump) const { - std::string path; - buffer = SERIALIZE.SerializeModel(*this, path, false, is_dump); - return (buffer.GetSize() > 0U) ? GRAPH_SUCCESS : GRAPH_FAILED; -} - -graphStatus Model::Save(Buffer &buffer, const std::string &path, const bool is_dump) const { - buffer = SERIALIZE.SerializeModel(*this, path, true, is_dump); - return (buffer.GetSize() > 0U) ? GRAPH_SUCCESS : GRAPH_FAILED; -} - -graphStatus Model::SaveSeparateModel(Buffer &buffer, const std::string &path, const bool is_dump) const { - buffer = SERIALIZE.SerializeSeparateModel(*this, path, is_dump); - return (buffer.GetSize() > 0U) ? GRAPH_SUCCESS : GRAPH_FAILED; -} - -graphStatus Model::Save(proto::ModelDef &model_def, const bool is_dump) const { - return SERIALIZE.SerializeModel(*this, is_dump, model_def); -} - -void Model::SetAttr(const ProtoAttrMap &attrs) { attrs_ = attrs; } - -graphStatus Model::Load(const uint8_t *data, size_t len, Model &model) { - return SERIALIZE.UnserializeModel(data, len, model) ? GRAPH_SUCCESS : GRAPH_FAILED; -} - -graphStatus Model::LoadWithMultiThread(const uint8_t *data, size_t len, Model &model) { - return SERIALIZE.UnserializeModel(data, len, model, true) ? GRAPH_SUCCESS : GRAPH_FAILED; -} - -graphStatus Model::Load(ge::proto::ModelDef &model_def, const std::string &path) { - return SERIALIZE.UnserializeModel(model_def, *this, path) ? GRAPH_SUCCESS : GRAPH_FAILED; -} - -graphStatus Model::Load(ge::proto::ModelDef &model_def) { - return SERIALIZE.UnserializeModel(model_def, *this) ? GRAPH_SUCCESS : GRAPH_FAILED; -} - -graphStatus Model::SaveToFile(const std::string &file_name, const bool force_separate) const { - Buffer buffer; - std::string dir_path; - std::string file; - SplitFilePath(file_name, dir_path, file); - if (!dir_path.empty()) { - GE_ASSERT_TRUE((CreateDir(dir_path) == 0), - "Create direct failed, path: %s.", file_name.c_str()); - } else { - GE_ASSERT_SUCCESS(GetAscendWorkPath(dir_path)); - if (dir_path.empty()) { - dir_path = "./"; - } - } - std::string real_path = RealPath(dir_path.c_str()); - GE_ASSERT_TRUE(!real_path.empty(), "Path: %s is empty", file_name.c_str()); - real_path = real_path + "/" + file; - - graphStatus ret = GRAPH_SUCCESS; - if (!force_separate) { - ret = (*this).Save(buffer, real_path); - } else { - ret = (*this).SaveSeparateModel(buffer, real_path); - } - if (ret != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "[Save][Data] to file:%s fail.", file_name.c_str()); - GELOGE(ret, "[Save][Data] to file:%s fail.", file_name.c_str()); - return ret; - } - // Write file - if (buffer.GetData() != nullptr) { - ge::proto::ModelDef ge_proto; - const std::string str(PtrToPtr(buffer.GetData()), buffer.GetSize()); - if (!ge_proto.ParseFromString(str)) { - return GRAPH_FAILED; - } - const int32_t fd = - mmOpen2(&real_path[0], static_cast(static_cast(M_WRONLY) | static_cast(M_CREAT) | - static_cast(O_TRUNC)), static_cast(ACCESS_PERMISSION_BITS)); - if (fd < 0) { - REPORT_INNER_ERR_MSG("E18888", "open file:%s failed, error:%s ", &real_path[0], GetStrError()); - GELOGE(GRAPH_FAILED, "[Open][File] %s failed, error:%s ", &real_path[0], GetStrError()); - return GRAPH_FAILED; - } - const bool result = ge_proto.SerializeToFileDescriptor(fd); - if (!result) { - REPORT_INNER_ERR_MSG("E18888", "SerializeToFileDescriptor failed, file:%s.", &real_path[0]); - GELOGE(GRAPH_FAILED, "[Call][SerializeToFileDescriptor] failed, file:%s.", &real_path[0]); - if (mmClose(fd) != 0) { - REPORT_INNER_ERR_MSG("E18888", "close file:%s fail, error:%s.", &real_path[0], GetStrError()); - GELOGE(GRAPH_FAILED, "[Close][File] %s fail, error:%s.", &real_path[0], GetStrError()); - return GRAPH_FAILED; - } - return GRAPH_FAILED; - } - if (mmClose(fd) != 0) { - REPORT_INNER_ERR_MSG("E18888", "close file:%s fail, error:%s.", &real_path[0], GetStrError()); - GELOGE(GRAPH_FAILED, "[Close][File] %s fail, error:%s.", &real_path[0], GetStrError()); - return GRAPH_FAILED; - } - if (!result) { - REPORT_INNER_ERR_MSG("E18888", "SerializeToFileDescriptor failed, file:%s.", &real_path[0]); - GELOGE(GRAPH_FAILED, "[Call][SerializeToFileDescriptor] failed, file:%s.", &real_path[0]); - return GRAPH_FAILED; - } - } - return GRAPH_SUCCESS; -} - -bool Model::IsValid() const { return graph_ != nullptr; } - -graphStatus Model::LoadFromFile(const std::string &file_name) { - char_t real_path[MMPA_MAX_PATH] = {}; - if (strnlen(file_name.c_str(), sizeof(real_path)) >= sizeof(real_path)) { - return GRAPH_FAILED; - } - const INT32 result = mmRealPath(file_name.c_str(), &real_path[0], MMPA_MAX_PATH); - if (result != EN_OK) { - REPORT_INNER_ERR_MSG("E18888", "get realpath failed for %s, error:%s.", file_name.c_str(), GetStrError()); - GELOGE(GRAPH_FAILED, "[Get][RealPath] failed for %s, error:%s.", file_name.c_str(), GetStrError()); - return GRAPH_FAILED; - } - const int32_t fd = mmOpen(&real_path[0], M_RDONLY); - if (fd < 0) { - REPORT_INNER_ERR_MSG("E18888", "open file:%s failed, error:%s", &real_path[0], GetStrError()); - GELOGE(GRAPH_FAILED, "[Open][File] %s failed, error:%s", &real_path[0], GetStrError()); - return GRAPH_FAILED; - } - - ge::proto::ModelDef model_def; - const bool ret = model_def.ParseFromFileDescriptor(fd); - if (!ret) { - REPORT_INNER_ERR_MSG("E18888", "ParseFromFileDescriptor failed, file:%s.", &real_path[0]); - GELOGE(GRAPH_FAILED, "[Call][ParseFromFileDescriptor] failed, file:%s.", &real_path[0]); - if (mmClose(fd) != 0) { - REPORT_INNER_ERR_MSG("E18888", "close file:%s fail, error:%s.", &real_path[0], GetStrError()); - GELOGE(GRAPH_FAILED, "[Close][File] %s fail. error:%s", &real_path[0], GetStrError()); - return GRAPH_FAILED; - } - return GRAPH_FAILED; - } - if (mmClose(fd) != 0) { - REPORT_INNER_ERR_MSG("E18888", "close file:%s fail, error:%s.", &real_path[0], GetStrError()); - GELOGE(GRAPH_FAILED, "[Close][File] %s fail. error:%s", &real_path[0], GetStrError()); - return GRAPH_FAILED; - } - if (!ret) { - REPORT_INNER_ERR_MSG("E18888", "ParseFromFileDescriptor failed, file:%s.", &real_path[0]); - GELOGE(GRAPH_FAILED, "[Call][ParseFromFileDescriptor] failed, file:%s.", &real_path[0]); - return GRAPH_FAILED; - } - std::string path(real_path); - return Load(model_def, file_name); -} - -ProtoAttrMap &Model::MutableAttrMap() { return attrs_; } - -ConstProtoAttrMap &Model::GetAttrMap() const { - return attrs_; -} -} // namespace ge diff --git a/graph/normal_graph/node.cc b/graph/normal_graph/node.cc deleted file mode 100644 index 90249ce7b77ef037da6fce74fc7cd4b3d10cc379..0000000000000000000000000000000000000000 --- a/graph/normal_graph/node.cc +++ /dev/null @@ -1,1196 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/node.h" -#include "debug/ge_op_types.h" -#include "debug/ge_util.h" -#include "external/graph/operator_factory.h" -#include "graph/normal_graph/node_impl.h" -#include "graph/operator_factory_impl.h" -#include "graph/shape_refiner.h" -#include "graph/utils/ge_ir_utils.h" -#include "graph/utils/node_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "common/util/trace_manager/trace_manager.h" - -namespace ge { -Node::NodeImpl::NodeImpl(const OpDescPtr &op, const ComputeGraphPtr &owner_graph) - : op_(op), - owner_graph_(owner_graph), - owner_graph_ptr_(owner_graph.get()), - in_data_anchors_(), - out_data_anchors_(), - in_control_anchor_(nullptr), - out_control_anchor_(nullptr), - has_init_(false), - host_node_(false), - anchor_status_updated_(false) {} - -Node::NodeImpl::~NodeImpl() { - for (const auto &in_data_anchor : in_data_anchors_) { - if (in_data_anchor != nullptr) { - in_data_anchor->UnlinkAll(); - } - } - for (const auto &out_data_anchor : out_data_anchors_) { - if (out_data_anchor != nullptr) { - out_data_anchor->UnlinkAll(); - } - } - if (in_control_anchor_ != nullptr) { - in_control_anchor_->UnlinkAll(); - } - if (out_control_anchor_ != nullptr) { - out_control_anchor_->UnlinkAll(); - } -} - -graphStatus Node::NodeImpl::Init(const NodePtr &node) { - if (has_init_) { - return GRAPH_SUCCESS; - } - GE_CHK_BOOL_EXEC(op_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "original OpDesc is nullptr"); - return GRAPH_FAILED, "[Check][Param] original OpDesc is nullptr"); - size_t size = op_->GetAllInputsSize(); - in_data_anchors_.reserve(size); - for (size_t i = 0UL; i < size; i++) { - const std::shared_ptr anchor = ComGraphMakeShared(node, i); - if (anchor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Current in_data_anchor is null, malloc shared_ptr failed."); - GELOGE(GRAPH_FAILED, "[Create][InDataAnchor] Current in_data_anchor is null, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - in_data_anchors_.push_back(anchor); - } - size = op_->GetOutputsSize(); - out_data_anchors_.reserve(size); - for (size_t i = 0UL; i < size; i++) { - const std::shared_ptr anchor = ComGraphMakeShared(node, i); - if (anchor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Current out_data_anchor is null, malloc shared_ptr failed."); - GELOGE(GRAPH_FAILED, "[Create][OutDataAnchor] Current out_data_anchor is null, malloc shared_ptr failed."); - return GRAPH_FAILED; - } - out_data_anchors_.push_back(anchor); - } - in_control_anchor_ = ComGraphMakeShared(node, -1); - out_control_anchor_ = ComGraphMakeShared(node, -1); - if ((in_control_anchor_ == nullptr) || (out_control_anchor_ == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "Current in_control_anchor or out_control_anchor is null, malloc shared_ptr failed."); - GELOGE(GRAPH_FAILED, "[Create][ControlAnchor] Current in_control_anchor or out_control_anchor is null, " - "malloc shared_ptr failed."); - return GRAPH_FAILED; - } - has_init_ = true; - return GRAPH_SUCCESS; -} - -std::string Node::NodeImpl::GetName() const { - GE_CHK_BOOL_EXEC(op_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "original OpDesc is nullptr"); - return std::string(), "[Check][Param] original OpDesc is nullptr"); - return op_->GetName(); -} - -std::string Node::NodeImpl::GetType() const { - GE_CHK_BOOL_EXEC(op_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "original OpDesc is nullptr"); - return std::string(), "[Check][Param] original OpDesc is nullptr"); - return op_->GetType(); -} - -const char *Node::NodeImpl::GetNamePtr() const { - GE_CHK_BOOL_EXEC(op_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "original OpDesc is nullptr"); - return nullptr, "[Check][Param] original OpDesc is nullptr"); - return op_->GetNamePtr(); -} - -const char *Node::NodeImpl::GetTypePtr() const { - GE_CHK_BOOL_EXEC(op_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "original OpDesc is nullptr"); - return nullptr, "[Check][Param] original OpDesc is nullptr"); - return op_->GetTypePtr(); -} - -bool Node::NodeImpl::NodeMembersAreEqual(const NodeImpl &r_node) const { - return ((((this->op_ != nullptr) && (r_node.op_ != nullptr) && (IsEqual(*(this->op_), *(r_node.op_), "node.op_"))) || - ((this->op_ == nullptr) && (r_node.op_ == nullptr))) && - IsEqual(this->has_init_, r_node.has_init_, "node.has_init_") && - IsEqual(this->anchor_status_updated_, r_node.anchor_status_updated_, "node.anchor_status_updated_") && - IsEqual(this->send_event_id_list_, r_node.send_event_id_list_, "node.send_event_id_list_") && - IsEqual(this->recv_event_id_list_, r_node.recv_event_id_list_, "node.recv_event_id_list_")); -} - -bool Node::NodeImpl::NodeAnchorIsEqual(const AnchorPtr &left_anchor, - const AnchorPtr &right_anchor, - const size_t i) const { - GE_IF_BOOL_EXEC(left_anchor == nullptr, REPORT_INNER_ERR_MSG("E18888", "left_anchor is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] left_anchor is null."); return false); - GE_IF_BOOL_EXEC(right_anchor == nullptr, REPORT_INNER_ERR_MSG("E18888", "right_anchor is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] right_anchor is null."); return false); - const auto anchor_peer_size = left_anchor->GetPeerAnchorsSize(); - const auto right_anchor_peer_size = right_anchor->GetPeerAnchorsSize(); - // Firstly, verify anchor's peer anchors size equal or not - if (anchor_peer_size != right_anchor_peer_size) { - REPORT_INNER_ERR_MSG("E18888", - "Size of anchor's peer anchors verify failed, node name: %s " - "anchor_peer_size [%zu] is different form [%zu] at index [%zu].", - this->GetName().c_str(), anchor_peer_size, right_anchor_peer_size, i); - GELOGE(GRAPH_FAILED, "[Check][Param] Size of anchor's peer anchors verify failed, node name: %s " - "anchor_peer_size [%zu] is different form [%zu] at index [%zu].", - this->GetName().c_str(), anchor_peer_size, right_anchor_peer_size, i); - return false; - } - // Secondly, verify anchor's peer anchor owner node equal or not - for (size_t j = 0UL; j < anchor_peer_size; j++) { - const auto peer_node = left_anchor->GetPeerAnchorsPtr().at(j)->GetOwnerNodeBarePtr(); - const auto r_peer_node = right_anchor->GetPeerAnchorsPtr().at(j)->GetOwnerNodeBarePtr(); - if ((peer_node == nullptr) || (r_peer_node == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "anchor's peer node is null, node name: %s index[%zu] peer node index[%zu].", - this->GetName().c_str(), i, j); - GELOGE(GRAPH_FAILED, "[Get][OwnerNode] anchor's peer node is null, node name: %s index[%zu] " - "peer node index[%zu].", this->GetName().c_str(), i, j); - return false; - } - // Determine the connection relationship by linking the node's name - if (peer_node->GetName() != r_peer_node->GetName()) { - REPORT_INNER_ERR_MSG("E18888", - "anchor's peer node name verify failed, node name: %s index[%zu]" - "peer node name %s is different from %s at index [%zu].", - this->GetName().c_str(), i, peer_node->GetName().c_str(), r_peer_node->GetName().c_str(), j); - GELOGE(GRAPH_FAILED, "[Check][Param] anchor's peer node name verify failed, node name: %s index[%zu]" - "peer node name %s is different from %s at index [%zu].", - this->GetName().c_str(), i, peer_node->GetName().c_str(), r_peer_node->GetName().c_str(), j); - return false; - } - } - return true; -} - -graphStatus Node::NodeImpl::AddLinkFrom(const NodePtr &input_node, const NodePtr &owner_node) { - // This function is deprecated, please use other two overloaded functions - GE_CHECK_NOTNULL(input_node); - // Input_node ---> this - auto out_anchors = input_node->GetAllOutDataAnchors(); - if (out_anchors.size() != 1UL) { - REPORT_INNER_ERR_MSG("E18888", "node:%s out_anchor size is:%zu, only support 1", input_node->GetName().c_str(), - out_anchors.size()); - GELOGE(GRAPH_FAILED, "[Check][Param] out_anchor size is:%zu, only support 1", out_anchors.size()); - return GRAPH_PARAM_INVALID; - } - GE_CHK_BOOL_EXEC(op_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "original OpDesc is nullptr"); - return GRAPH_FAILED, "[Check][Param] original OpDesc is nullptr"); - const auto op_desc = input_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - - if (op_->AddInputDesc(op_desc->GetOutputDesc(0U)) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "add input desc failed, op:%s.", op_->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Add][InputDesc] failed."); - return GRAPH_FAILED; - } - const std::shared_ptr anchor = ComGraphMakeShared(owner_node, in_data_anchors_.size()); - if (anchor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "out_anchor size is:%zu, malloc shared_ptr failed.", out_anchors.size()); - GELOGE(GRAPH_FAILED, "[Create][InDataAnchor] out_anchor size is:%zu, malloc shared_ptr failed.", - out_anchors.size()); - return GRAPH_FAILED; - } - in_data_anchors_.push_back(anchor); - (void) out_anchors.at(0U)->LinkTo(in_data_anchors_.back()); - - return GRAPH_SUCCESS; -} - -graphStatus Node::NodeImpl::AddLinkFrom(const uint32_t &index, - const NodePtr &input_node, - const NodePtr &owner_node) { - GE_CHECK_NOTNULL(input_node); - // Input_node ---> this - auto out_anchors = input_node->GetAllOutDataAnchors(); - if (out_anchors.size() != 1UL) { - REPORT_INNER_ERR_MSG("E18888", "node:%s out_anchor size is:%zu, only support 1", input_node->GetName().c_str(), - out_anchors.size()); - GELOGE(GRAPH_FAILED, "[Check][Param] out_anchor size is:%zu, only support 1", out_anchors.size()); - return GRAPH_PARAM_INVALID; - } - - GE_CHECK_NOTNULL(op_); - const auto op_desc = input_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - - if (op_->AddInputDesc(index, op_desc->GetOutputDesc(0U)) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "add input desc failed, index:%u.", index); - GELOGE(GRAPH_FAILED, "[Add][InputDesc] failed."); - return GRAPH_FAILED; - } - - if (index < GetAllInDataAnchors(owner_node).size()) { - (void) out_anchors.at(0U)->LinkTo(in_data_anchors_[static_cast(index)]); - } else { - const std::shared_ptr - anchor = ComGraphMakeShared(owner_node, in_data_anchors_.size()); - if (anchor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "out_anchor size is:%zu, malloc shared_ptr failed.", out_anchors.size()); - GELOGE(GRAPH_FAILED, "[Create][InDataAnchor] out_anchor size is:%zu, malloc shared_ptr failed.", - out_anchors.size()); - return GRAPH_FAILED; - } - in_data_anchors_.push_back(anchor); - (void) out_anchors.at(0U)->LinkTo(in_data_anchors_.back()); - } - - return GRAPH_SUCCESS; -} - -graphStatus Node::NodeImpl::AddLinkFromForParse(const NodePtr &input_node, const NodePtr &owner_node) { - // This function is used for ParseWeights. - GE_CHECK_NOTNULL(input_node); - // Input_node ---> this - auto out_anchors = input_node->GetAllOutDataAnchors(); - if (out_anchors.size() != 1UL) { - REPORT_INNER_ERR_MSG("E18888", "node:%s out_anchor size is:%zu, only support 1", input_node->GetName().c_str(), - out_anchors.size()); - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] out_anchor size is:%zu, only support 1", out_anchors.size()); - return GRAPH_PARAM_INVALID; - } - - const std::shared_ptr anchor = ComGraphMakeShared(owner_node, in_data_anchors_.size()); - if (anchor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "out_anchor size is:%zu, make anchor failed", out_anchors.size()); - GELOGE(GRAPH_FAILED, "[Create][InDataAnchor] out_anchor size is:%zu, make anchor failed", out_anchors.size()); - return GRAPH_FAILED; - } - in_data_anchors_.push_back(anchor); - (void)out_anchors.at(0U)->LinkTo(in_data_anchors_.back()); - - return GRAPH_SUCCESS; -} - -graphStatus Node::NodeImpl::AddLinkFrom(const std::string &name, const NodePtr &input_node, const NodePtr &owner_node) { - GE_CHECK_NOTNULL(input_node); - // Input_node ---> this - auto out_anchors = input_node->GetAllOutDataAnchors(); - if (out_anchors.size() != 1UL) { - REPORT_INNER_ERR_MSG("E18888", "node:%s out_anchor size is:%zu, only support 1", input_node->GetName().c_str(), - out_anchors.size()); - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] out_anchor size is:%zu, only support 1", out_anchors.size()); - return GRAPH_PARAM_INVALID; - } - - GE_CHECK_NOTNULL(op_); - const auto input_op_desc = input_node->GetOpDesc(); - GE_CHECK_NOTNULL(input_op_desc); - const auto index = op_->GetInputIndexByName(name); - if (index != -1) { - if (index >= static_cast(in_data_anchors_.size())) { - REPORT_INNER_ERR_MSG("E18888", - "op %s get input name %s 's index %d is illegal as which >= indataanchors size:%zu.", - op_->GetName().c_str(), name.c_str(), index, in_data_anchors_.size()); - GELOGE(GRAPH_FAILED, "[Check][Param] op %s get input name %s 's index %d is illegal.", - op_->GetName().c_str(), name.c_str(), index); - return GRAPH_FAILED; - } - (void) out_anchors.at(0U)->LinkTo(in_data_anchors_[static_cast(index)]); - } else { - const std::shared_ptr - anchor = ComGraphMakeShared(owner_node, in_data_anchors_.size()); - GE_CHECK_NOTNULL(anchor); - in_data_anchors_.push_back(anchor); - (void) out_anchors.at(0U)->LinkTo(in_data_anchors_.back()); - } - if (op_->AddInputDesc(name, input_op_desc->GetOutputDesc(0U)) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "add input desc failed, name:%s, op:%s", name.c_str(), op_->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Add][InputDesc] failed."); - return GRAPH_FAILED; - } - - return GRAPH_SUCCESS; -} - -ComputeGraphPtr Node::NodeImpl::GetOwnerComputeGraph() const { - return owner_graph_.lock(); -} - -ComputeGraph *Node::NodeImpl::GetOwnerComputeGraphBarePtr() const { - return owner_graph_ptr_; -} - -graphStatus Node::NodeImpl::SetOwnerComputeGraph(const ComputeGraphPtr &graph) { - if (graph == nullptr) { - return GRAPH_PARAM_INVALID; - } - owner_graph_ = graph; - owner_graph_ptr_ = graph.get(); - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), - this->GetName(), "owner_graph", "", "", graph->GetName()); - return GRAPH_SUCCESS; -} - -graphStatus Node::NodeImpl::ClearOwnerGraph(const ComputeGraphPtr &graph) { - owner_graph_ = graph; - owner_graph_ptr_ = graph.get(); - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "delete", TraceManager::GetOutGraphName(), - this->GetName(), "owner_graph", "", "", ((graph == nullptr) ? std::string("") : graph->GetName())); - return GRAPH_SUCCESS; -} - -Node::Vistor Node::NodeImpl::GetAllInDataAnchors(const ConstNodePtr &node_ptr) const { - return Node::Vistor(node_ptr, in_data_anchors_); -} - -Node::Vistor Node::NodeImpl::GetAllOutDataAnchors(const ConstNodePtr &node_ptr) const { - return Node::Vistor(node_ptr, out_data_anchors_); -} - -uint32_t Node::NodeImpl::GetAllInDataAnchorsSize() const { - return static_cast(in_data_anchors_.size()); -} - -uint32_t Node::NodeImpl::GetAllOutDataAnchorsSize() const { - return static_cast(out_data_anchors_.size()); -} - -Node::Vistor Node::NodeImpl::GetAllInAnchors(const ConstNodePtr &owner_node) const { - std::vector vec; - // Push back in_data_anchors_ - for (const auto &in_anchor_iter : Node::Vistor(owner_node, in_data_anchors_)) { - const auto in_anchor = Anchor::DynamicAnchorCast(in_anchor_iter); - if (in_anchor != nullptr) { - vec.push_back(in_anchor); - } - } - // Push back in_control_anchor_ - if ((!in_control_anchor_->GetPeerOutControlAnchorsPtr().empty()) || - (!in_control_anchor_->GetPeerOutDataAnchors().empty())) { - const auto in_anchor = Anchor::DynamicAnchorCast(in_control_anchor_); - if (in_anchor != nullptr) { - vec.push_back(in_anchor); - } - } - return Node::Vistor(owner_node, vec); -} - -std::vector Node::NodeImpl::GetAllInAnchorsPtr() const { - std::vector vec; - vec.reserve(in_data_anchors_.size() + 1U); // in_data_anchors_ + in_control_anchor_ - // Push back in_data_anchors_ - for (const auto &in_anchor : in_data_anchors_) { - if (in_anchor != nullptr) { - vec.emplace_back(in_anchor.get()); - } - } - // Push back in_control_anchor_ - vec.emplace_back(in_control_anchor_.get()); - return vec; -} - -Node::Vistor Node::NodeImpl::GetAllOutAnchors(const ConstNodePtr &owner_node) const { - std::vector vec; - // Push back out_data_anchors_ - for (const auto &out_anchor_iter : Node::Vistor(owner_node, out_data_anchors_)) { - const auto out_anchor = Anchor::DynamicAnchorCast(out_anchor_iter); - if (out_anchor != nullptr) { - vec.push_back(out_anchor); - } - } - // Push back out_control_anchor_ - if ((!out_control_anchor_->GetPeerInControlAnchorsPtr().empty()) || - (!out_control_anchor_->GetPeerInDataAnchors().empty())) { - const auto out_anchor = Anchor::DynamicAnchorCast(out_control_anchor_); - if (out_anchor != nullptr) { - vec.push_back(out_anchor); - } - } - return Node::Vistor(owner_node, vec); -} - -std::vector Node::NodeImpl::GetAllOutAnchorsPtr() const { - std::vector vec; - vec.reserve(out_data_anchors_.size() + 1U); // out_data_anchors_ + out_control_anchor_ - // Push back out_data_anchors_ - for (const auto &out_anchor : out_data_anchors_) { - if (out_anchor != nullptr) { - vec.emplace_back(out_anchor.get()); - } - } - // Push back out_control_anchor_ - vec.emplace_back(out_control_anchor_.get()); - return vec; -} - -InDataAnchorPtr Node::NodeImpl::GetInDataAnchor(const int32_t idx) const { - if ((idx < 0) || (idx >= static_cast(in_data_anchors_.size()))) { - GELOGW("[Check][Param] Op %s doesn't have data input %d, type = %s", GetName().c_str(), idx, GetType().c_str()); - return nullptr; - } else { - return in_data_anchors_[static_cast(idx)]; - } -} - -AnchorPtr Node::NodeImpl::GetInAnchor(const int32_t idx) const { - // Idx can't be less than -1 or >= in_data_anchors_.size(), -1 means index of control anchor_ - if ((idx < -1) || (idx >= static_cast(in_data_anchors_.size()))) { - GELOGW("[Check][Param] Op %s doesn't have input %d, type = %s", GetName().c_str(), idx, GetType().c_str()); - return nullptr; - } else { - // Return control anchor - if (idx == -1) { - return Anchor::DynamicAnchorCast(in_control_anchor_); - } - // Return data anchor - return in_data_anchors_[static_cast(idx)]; - } -} - -AnchorPtr Node::NodeImpl::GetOutAnchor(const int32_t idx) const { - // Idx can't be less than -1 or >= out_data_anchors_.size(), -1 means index of control anchor_ - if ((idx < -1) || (idx >= static_cast(out_data_anchors_.size()))) { - REPORT_INNER_ERR_MSG("E18888", "Op:%s(%s) doesn't have index:%d's anchorname", GetName().c_str(), GetType().c_str(), - idx); - GELOGE(GRAPH_FAILED, "[Check][Param] Op[%s] doesn't have index[%d]'s out_anchor which optype is %s.", - GetName().c_str(), idx, GetType().c_str()); - return nullptr; - } else { - // Return control anchor - if (idx == -1) { - return Anchor::DynamicAnchorCast(out_control_anchor_); - } - // Return data anchor - return out_data_anchors_[static_cast(idx)]; - } -} - -OutDataAnchorPtr Node::NodeImpl::GetOutDataAnchor(const int32_t idx) const { - if ((idx < 0) || (idx >= static_cast(out_data_anchors_.size()))) { - REPORT_INNER_ERR_MSG("E18888", "Op:%s(%s) doesn't have index:%d's anchorname", GetName().c_str(), GetType().c_str(), - idx); - GELOGE(GRAPH_FAILED, "[Check][Param] Op[%s] doesn't have index[%d]'s out_data_anchor which optype is %s.", - GetName().c_str(), idx, GetType().c_str()); - return nullptr; - } else { - return out_data_anchors_[static_cast(idx)]; - } -} - -InControlAnchorPtr Node::NodeImpl::GetInControlAnchor() const { - return in_control_anchor_; -} - -OutControlAnchorPtr Node::NodeImpl::GetOutControlAnchor() const { - return out_control_anchor_; -} - -Node::Vistor Node::NodeImpl::GetInNodes(const ge::ConstNodePtr &owner_node) const { - std::vector vec; - vec.reserve(GetInNodesSize()); - for (const auto &in_anchor : in_data_anchors_) { - if (in_anchor == nullptr) { - continue; - } - const auto out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr) { - continue; - } - const auto &node = out_anchor->GetOwnerNode(); - vec.push_back(node); - } - if (in_control_anchor_ != nullptr) { - for (const auto out_control_anchor : in_control_anchor_->GetPeerOutControlAnchorsPtr()) { - const auto &node = out_control_anchor->GetOwnerNode(); - vec.push_back(node); - } - } - return Node::Vistor(owner_node, vec); -} - -std::vector Node::NodeImpl::GetInNodesPtr() const { - std::vector in_nodes; - in_nodes.reserve(GetInNodesSize()); - for (const auto &in_anchor : in_data_anchors_) { - if (in_anchor == nullptr) { - continue; - } - const auto &out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr) { - continue; - } - in_nodes.push_back(out_anchor->GetOwnerNodeBarePtr()); - } - if (in_control_anchor_ != nullptr) { - for (const auto out_control_anchor : in_control_anchor_->GetPeerOutControlAnchorsPtr()) { - in_nodes.push_back(out_control_anchor->GetOwnerNodeBarePtr()); - } - } - return in_nodes; -} - -bool Node::NodeImpl::IsAllInNodesSeen(const std::unordered_set &nodes_seen) const { - for (const auto &in_anchor : in_data_anchors_) { - GE_CHK_BOOL_EXEC((in_anchor != nullptr), - continue, "[Check][Param] in_data_anchor is nullptr, node:%s", GetName().c_str()); - const auto out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr) { - continue; - } - const auto node = out_anchor->GetOwnerNodeBarePtr(); - if (node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "peer node is null, node name: %s input index[%d] peer node output index[%d].", - GetName().c_str(), in_anchor->GetIdx(), out_anchor->GetIdx()); - GELOGE(GRAPH_FAILED, - "[Get][OwnerNode] peer node is null, node name: %s input index[%d] peer node output index[%d].", - GetName().c_str(), in_anchor->GetIdx(), out_anchor->GetIdx()); - return false; - } - if ((node->GetType() == NEXTITERATION) || (node->GetType() == REFNEXTITERATION)) { - continue; - } - if (nodes_seen.count(node) == 0U) { - return false; - } - } - - if (in_control_anchor_ != nullptr) { - if (in_control_anchor_->IsPeerOutAnchorsEmpty()) { - return true; - } - const auto peer_out_control_anchors = in_control_anchor_->GetPeerOutControlAnchors(); - for (const auto &out_control_anchor : peer_out_control_anchors) { - const auto node = out_control_anchor->GetOwnerNodeBarePtr(); - if ((node->GetType() == NEXTITERATION) || (node->GetType() == REFNEXTITERATION)) { - continue; - } - if (nodes_seen.count(node) == 0U) { - return false; - } - } - } - - return true; -} - -Node::Vistor Node::NodeImpl::GetInDataNodes(const ge::ConstNodePtr &owner_node) const { - const auto &vec = NodeUtils::GetInDataNodes(*owner_node, nullptr); - return Node::Vistor(owner_node, vec); -} - -Node::Vistor Node::NodeImpl::GetInControlNodes(const ge::ConstNodePtr &owner_node) const { - const auto &vec = NodeUtils::GetInControlNodes(*owner_node, nullptr); - return Node::Vistor(owner_node, vec); -} - -Node::Vistor Node::NodeImpl::GetOutNodes(const ge::ConstNodePtr &owner_node) const { - std::vector vec; - vec.reserve(GetOutNodesSize()); - for (const auto &out_anchor : out_data_anchors_) { - if (out_anchor == nullptr) { - continue; - } - for (const auto peer_in_anchor : out_anchor->GetPeerInDataAnchorsPtr()) { - const auto &node = peer_in_anchor->GetOwnerNode(); - vec.push_back(node); - } - } - if (out_control_anchor_ != nullptr) { - for (const auto in_control_anchor : out_control_anchor_->GetPeerInControlAnchorsPtr()) { - const auto &node = in_control_anchor->GetOwnerNode(); - vec.push_back(node); - } - } - return Node::Vistor(owner_node, vec); -} - -std::vector Node::NodeImpl::GetOutNodesPtr() const { - std::vector vec; - vec.reserve(GetOutNodesSize()); - for (const auto &out_anchor : out_data_anchors_) { - if (out_anchor == nullptr) { - continue; - } - for (const auto peer_in_anchor : out_anchor->GetPeerInDataAnchorsPtr()) { - vec.push_back(peer_in_anchor->GetOwnerNodeBarePtr()); - } - } - if (out_control_anchor_ != nullptr) { - for (const auto in_control_anchor : out_control_anchor_->GetPeerInControlAnchorsPtr()) { - vec.push_back(in_control_anchor->GetOwnerNodeBarePtr()); - } - } - return vec; -} - -Node::Vistor Node::NodeImpl::GetInAllNodes(const ge::ConstNodePtr &owner_node) const { - return GetInNodes(owner_node); -} - -Node::Vistor Node::NodeImpl::GetOutDataNodes(const ConstNodePtr &owner_node) const { - const auto &vec = NodeUtils::GetOutDataNodes(*owner_node, nullptr); - return Node::Vistor(owner_node, vec); -} - -std::vector Node::NodeImpl::GetOutDataNodesPtr() const { - std::vector vec; - for (const auto &out_anchor : out_data_anchors_) { - if (out_anchor != nullptr) { - for (const auto in_anchor : out_anchor->GetPeerInDataAnchorsPtr()) { - if (in_anchor != nullptr) { - const auto node = in_anchor->GetOwnerNodeBarePtr(); - vec.emplace_back(node); - } - } - } - } - return vec; -} - -uint32_t Node::NodeImpl::GetOutDataNodesSize() const { - uint32_t out_nums = 0U; - for (const auto &out_anchor : out_data_anchors_) { - GE_CHK_BOOL_EXEC((out_anchor != nullptr), continue, "[Check][Param] out data anchor is nullptr, node:%s", - GetName().c_str()); - out_nums += out_anchor->GetPeerInDataNodesSize(); - } - return out_nums; -} - -uint32_t Node::NodeImpl::GetOutControlNodesSize() const { - uint32_t out_nums = 0U; - if (out_control_anchor_ != nullptr) { - out_nums += out_control_anchor_->GetPeerAnchorsSize(); - } - return out_nums; -} - -uint32_t Node::NodeImpl::GetOutNodesSize() const { - return GetOutDataNodesSize() + GetOutControlNodesSize(); -} - -Node::Vistor Node::NodeImpl::GetOutControlNodes(const ge::ConstNodePtr &owner_node) const { - const auto &vec = NodeUtils::GetOutControlNodes(*owner_node, nullptr); - return Node::Vistor(owner_node, vec); -} - -Node::Vistor Node::NodeImpl::GetOutAllNodes(const ge::ConstNodePtr &owner_node) const { - return GetOutNodes(owner_node); -} - -OpDescPtr Node::NodeImpl::GetOpDesc() const { - return op_; -} - -OpDesc *Node::NodeImpl::GetOpDescBarePtr() const { - return op_.get(); -} - -graphStatus Node::NodeImpl::UpdateOpDesc(const OpDescPtr &op_desc) { - GE_CHK_BOOL_EXEC(op_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "original OpDesc is nullptr"); - return GRAPH_FAILED, "[Check][Param] original OpDesc is nullptr"); - GE_CHK_BOOL_EXEC(op_desc != nullptr, REPORT_INNER_ERR_MSG("E18888", "param op_desc is nullptr, check invalid."); - return GRAPH_PARAM_INVALID, "[Check][Param] Param OpDesc is nullptr"); - GE_CHK_BOOL_EXEC(op_->GetInputsSize() == op_desc->GetInputsSize(), - REPORT_INNER_ERR_MSG("E18888", - "inputs count(%zu) of param op_desc not equal to " - "inputs count(%zu) of original opdesc:%s, check invalid", - op_desc->GetInputsSize(), op_->GetInputsSize(), op_->GetName().c_str()); - return GRAPH_PARAM_INVALID, - "[Check][Param] Inputs count expected to be same, original OpDesc %zu, Param OpDesc %zu", - op_->GetInputsSize(), op_desc->GetInputsSize()); - - GE_CHK_BOOL_EXEC(op_->GetOutputsSize() == op_desc->GetOutputsSize(), - REPORT_INNER_ERR_MSG("E18888", - "outputs count(%zu) of param op_desc not equal to " - "outputs count(%zu) of original opdesc:%s, check invalid", - op_desc->GetOutputsSize(), op_->GetOutputsSize(), op_->GetName().c_str()); - return GRAPH_PARAM_INVALID, - "[Check][Param] Outputs count expected to be same, original OpDesc %zu, Param OpDesc %zu", - op_->GetOutputsSize(), op_desc->GetOutputsSize()); - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), - this->GetName(), "op_desc", "", "", op_desc->GetName()); - op_ = op_desc; - return GRAPH_SUCCESS; -} - -Node::Vistor> Node::NodeImpl::GetInDataNodesAndAnchors( - const ConstNodePtr &owner_node) const { - std::vector> vec; - for (const auto &p : in_data_anchors_) { - if (p == nullptr) { - GELOGW("[Check][Param] In data anchor is nullptr, node=%s, type=%s", GetType().c_str(), GetName().c_str()); - continue; - } - auto anchor_ptr = p->GetPeerOutAnchor(); - if (anchor_ptr == nullptr) { - continue; - } - auto node = anchor_ptr->GetOwnerNode(); - if (node == nullptr) { - GELOGW("[Check][Param] Src node is nullptr, node=%s, type=%s", GetType().c_str(), GetName().c_str()); - continue; - } - vec.emplace_back(node, anchor_ptr); - } - return Node::Vistor>(owner_node, vec); -} - -Node::Vistor> Node::NodeImpl::GetOutDataNodesAndAnchors( - const ConstNodePtr &owner_node) const { - std::vector> vec; - for (const auto &p : out_data_anchors_) { - if (p == nullptr) { - GELOGW("[Check][Param] Out data anchor is nullptr, node=%s, type=%s", GetType().c_str(), GetName().c_str()); - continue; - } - for (const auto &in_anchor : p->GetPeerInDataAnchors()) { - if (in_anchor == nullptr) { - GELOGW("[Check][Param] Dst in data anchor is nullptr, node=%s, type=%s", GetType().c_str(), GetName().c_str()); - continue; - } - auto node = in_anchor->GetOwnerNode(); - if (node == nullptr) { - GELOGW("[Check][Param] Dst node is nullptr, node=%s, type=%s", GetType().c_str(), GetName().c_str()); - continue; - } - vec.emplace_back(node, in_anchor); - } - } - return Node::Vistor>(owner_node, vec); -} - -size_t Node::NodeImpl::GetInDataNodesSize() const { - size_t size = 0U; - for (const auto &in_anchor : in_data_anchors_) { - if (in_anchor == nullptr) { - continue; - } - size += in_anchor->GetPeerAnchorsSize(); - } - return size; -} - -size_t Node::NodeImpl::GetInControlNodesSize() const { - size_t size = 0U; - if (in_control_anchor_ != nullptr) { - size = in_control_anchor_->GetPeerAnchorsSize(); - } - return size; -} - -size_t Node::NodeImpl::GetInNodesSize() const { - return GetInDataNodesSize() + GetInControlNodesSize(); -} - -std::vector Node::NodeImpl::GetAllInDataAnchorsPtr() const { - std::vector in_data_anchors; - in_data_anchors.reserve(in_data_anchors_.size()); - for (const auto &in_data_anchor : in_data_anchors_) { - in_data_anchors.emplace_back(in_data_anchor.get()); - } - return in_data_anchors; -} - -std::vector Node::NodeImpl::GetAllOutDataAnchorsPtr() const { - std::vector out_data_anchors; - out_data_anchors.reserve(out_data_anchors_.size()); - for (const auto &out_data_anchor : out_data_anchors_) { - out_data_anchors.emplace_back(out_data_anchor.get()); - } - return out_data_anchors; -} - -Node::Node() : enable_shared_from_this(), impl_(ComGraphMakeSharedAndThrow()) {} - -Node::Node(const OpDescPtr &op, const ComputeGraphPtr &owner_graph) - : enable_shared_from_this(), impl_(ComGraphMakeSharedAndThrow(op, owner_graph)) {} - -Node::~Node() {} - -graphStatus Node::Init() { - return impl_->Init(shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string Node::GetName() const { - return impl_->GetName(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const char *Node::GetNamePtr() const { - return impl_->GetNamePtr(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string Node::GetType() const { - return impl_->GetType(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const char *Node::GetTypePtr() const { - return impl_->GetTypePtr(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeMembersAreEqual(const Node &r_node) const { - return impl_->NodeMembersAreEqual(*(r_node.impl_)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeAnchorIsEqual(const AnchorPtr &left_anchor, - const AnchorPtr &right_anchor, - const size_t i) const { - return impl_->NodeAnchorIsEqual(left_anchor, right_anchor, i); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeInConnectsAreEqual(const Node &r_node) const { - // 1.Verify all in data and control anchors size - const auto in_data_anchor_size = this->GetAllInDataAnchors().size(); - const auto r_in_data_anchor_size = r_node.GetAllInDataAnchors().size(); - if (in_data_anchor_size != r_in_data_anchor_size) { - REPORT_INNER_ERR_MSG("E18888", - "param node in data anchors count:%zu not equal to " - "this in data anchors count:%zu, verify failed, node name: %s.", - r_in_data_anchor_size, in_data_anchor_size, this->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Size of node's in data anchors verify failed, node name: %s.", - this->GetName().c_str()); - return false; - } - const auto l_in_anchors = this->GetAllInAnchors(); - const auto r_in_anchors = r_node.GetAllInAnchors(); - // Data anchors size equal, all anchors size not equal, means control anchor size not equal - const auto in_control_anchor_size = l_in_anchors.size() - in_data_anchor_size; - const auto r_in_control_anchor_size = r_in_anchors.size() - r_in_data_anchor_size; - if (in_control_anchor_size != r_in_control_anchor_size) { - REPORT_INNER_ERR_MSG("E18888", - "param node in control anchors count:%zu not equal to " - "this in control anchors count:%zu, verify failed, node name: %s.", - r_in_control_anchor_size, in_control_anchor_size, this->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Size of node's in control anchors verify failed, node name: %s.", - this->GetName().c_str()); - return false; - } - // 2.Verify all in data and control anchors connect info - for (size_t i = 0UL; i < this->GetAllInAnchors().size(); i++) { - // Verify data anchors - if (i < in_data_anchor_size) { - const auto &in_anchor = l_in_anchors.at(i); - const auto &r_in_anchor = r_in_anchors.at(i); - if (!(NodeAnchorIsEqual(in_anchor, r_in_anchor, i))) { - GELOGE(GRAPH_FAILED, "[Call][NodeAnchorIsEqual] Node's in data control anchor verify failed, node name: %s.", - this->GetName().c_str()); - return false; - } - } else { - // Verify control anchors - const auto &in_control_anchor = l_in_anchors.at(i); - const auto &r_in_control_anchor = r_in_anchors.at(i); - if (!(NodeAnchorIsEqual(in_control_anchor, r_in_control_anchor, i - in_data_anchor_size))) { - GELOGE(GRAPH_FAILED, "[Call][NodeAnchorIsEqual] Node's in control anchor verify failed, node name: %s.", - this->GetName().c_str()); - return false; - } - } - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeOutConnectsAreEqual(const Node &r_node) const { - // 1.Verify all out data and control anchors size - const auto l_out_data_anchors = this->GetAllOutDataAnchors(); - const auto r_out_data_anchors = r_node.GetAllOutDataAnchors(); - const auto out_data_anchor_size = l_out_data_anchors.size(); - const auto r_out_data_anchor_size = r_out_data_anchors.size(); - if (out_data_anchor_size != r_out_data_anchor_size) { - REPORT_INNER_ERR_MSG("E18888", - "param node out data anchors count:%zu not equal to " - "this out data anchors count:%zu, verify failed, node name: %s.", - r_out_data_anchor_size, out_data_anchor_size, this->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Size of node's out data anchors verify failed, node name: %s.", - this->GetName().c_str()); - return false; - } - const auto l_out_anchors = this->GetAllOutAnchors(); - const auto r_out_anchors = r_node.GetAllOutAnchors(); - // Data anchors size equal, all anchors size not equal, means control anchor size not equal - const auto out_control_anchor_size = l_out_anchors.size() - out_data_anchor_size; - const auto r_out_control_anchor_size = r_out_anchors.size() - r_out_data_anchor_size; - if (out_control_anchor_size != r_out_control_anchor_size) { - REPORT_INNER_ERR_MSG("E18888", - "param node out control anchors count:%zu not equal to " - "this out control anchors count:%zu, verify failed, node name: %s.", - r_out_control_anchor_size, out_control_anchor_size, this->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Size of node's out control anchors verify failed, node name: %s.", - this->GetName().c_str()); - return false; - } - - // 2.Verify all out data and control anchors connect info - for (size_t i = 0UL; i < this->GetAllOutAnchors().size(); i++) { - // Verify data anchors - if (i < out_data_anchor_size) { - const auto &out_anchor = l_out_data_anchors.at(i); - const auto &r_out_anchor = r_out_data_anchors.at(i); - if (!(NodeAnchorIsEqual(out_anchor, r_out_anchor, i))) { - GELOGE(GRAPH_FAILED, "[Call][NodeAnchorIsEqual] Node's out data control anchor verify failed, node name: %s.", - this->GetName().c_str()); - return false; - } - } else { - // Verify control anchors - const auto &out_control_anchor = l_out_anchors.at(i); - const auto &r_out_control_anchor = r_out_anchors.at(i); - if (!(NodeAnchorIsEqual(out_control_anchor, r_out_control_anchor, i - out_data_anchor_size))) { - GELOGE(GRAPH_FAILED, "[Call][NodeAnchorIsEqual] Node's out control anchor verify failed, node name: %s.", - this->GetName().c_str()); - return false; - } - } - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::operator==(const Node &r_node) const { - return (NodeMembersAreEqual(r_node) && NodeInConnectsAreEqual(r_node) && NodeOutConnectsAreEqual(r_node)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::AddLinkFrom(const NodePtr &input_node) { - return impl_->AddLinkFrom(input_node, shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::AddLinkFrom(const uint32_t &index, - const NodePtr input_node) { - return impl_->AddLinkFrom(index, input_node, shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::AddLinkFromForParse(const NodePtr &input_node) { - return impl_->AddLinkFromForParse(input_node, shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -graphStatus Node::AddLinkFrom(const std::string &name, const NodePtr input_node) { - return impl_->AddLinkFrom(name, input_node, shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr Node::GetOwnerComputeGraph() const { - return impl_->GetOwnerComputeGraph(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph *Node::GetOwnerComputeGraphBarePtr() const { - return impl_->GetOwnerComputeGraphBarePtr(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::SetOwnerComputeGraph(const ComputeGraphPtr &graph) { - return impl_->SetOwnerComputeGraph(graph); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::ClearOwnerGraph(const ComputeGraphPtr &graph) { - return impl_->ClearOwnerGraph(graph); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetAllInDataAnchors() const { - return impl_->GetAllInDataAnchors(shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetAllOutDataAnchors() const { - return impl_->GetAllOutDataAnchors(shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t Node::GetAllInDataAnchorsSize() const { - return impl_->GetAllInDataAnchorsSize(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t Node::GetAllOutDataAnchorsSize() const { - return impl_->GetAllOutDataAnchorsSize(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetAllInAnchors() const { - return impl_->GetAllInAnchors(shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector Node::GetAllInAnchorsPtr() const { - return impl_->GetAllInAnchorsPtr(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetAllOutAnchors() const { - return impl_->GetAllOutAnchors(shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector Node::GetAllOutAnchorsPtr() const { - return impl_->GetAllOutAnchorsPtr(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchorPtr Node::GetInDataAnchor(const int32_t idx) const { - return impl_->GetInDataAnchor(idx); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetInAnchor(const int32_t idx) const { - return impl_->GetInAnchor(idx); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetOutAnchor(const int32_t idx) const { - return impl_->GetOutAnchor(idx); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutDataAnchorPtr Node::GetOutDataAnchor(const int32_t idx) const { - return impl_->GetOutDataAnchor(idx); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InControlAnchorPtr Node::GetInControlAnchor() const { - return impl_->GetInControlAnchor(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutControlAnchorPtr Node::GetOutControlAnchor() const { - return impl_->GetOutControlAnchor(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetInNodes() const { - return impl_->GetInNodes(shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::IsAllInNodesSeen( - const std::unordered_set &nodes_seen) const { - return impl_->IsAllInNodesSeen(nodes_seen); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetInDataNodes() const { - return impl_->GetInDataNodes(shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetInControlNodes() const { - return impl_->GetInControlNodes(shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetOutNodes() const { - return impl_->GetOutNodes(shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetInAllNodes() const { - return impl_->GetInAllNodes(shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetOutDataNodes() const { - return impl_->GetOutDataNodes(shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector Node::GetOutDataNodesPtr() const { - return impl_->GetOutDataNodesPtr(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t Node::GetOutDataNodesSize() const { - return impl_->GetOutDataNodesSize(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t Node::GetOutControlNodesSize() const { - return impl_->GetOutControlNodesSize(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t Node::GetOutNodesSize() const { - return impl_->GetOutNodesSize(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetOutControlNodes() const { - return impl_->GetOutControlNodes(shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::GetOutAllNodes() const { - return impl_->GetOutAllNodes(shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr Node::GetOpDesc() const { - return impl_->GetOpDesc(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc *Node::GetOpDescBarePtr() const { - return impl_->GetOpDescBarePtr(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::UpdateOpDesc(const OpDescPtr &op_desc) { - return impl_->UpdateOpDesc(op_desc); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor> - Node::GetInDataNodesAndAnchors() const { - return impl_->GetInDataNodesAndAnchors(shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor> -Node::GetOutDataNodesAndAnchors() const { - return impl_->GetOutDataNodesAndAnchors(shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void Node::AddSendEventId(const uint32_t event_id) { - impl_->AddSendEventId(event_id); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void Node::AddRecvEventId(const uint32_t event_id) { - impl_->AddRecvEventId(event_id); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::vector &Node::GetSendEventIdList() const { - return impl_->GetSendEventIdList(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::vector &Node::GetRecvEventIdList() const { - return impl_->GetRecvEventIdList(); -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void Node::GetFusionInputFlowList( - kFusionDataFlowVec_t &fusion_input_list) { - impl_->GetFusionInputFlowList(fusion_input_list); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void Node::GetFusionOutputFlowList( - kFusionDataFlowVec_t &fusion_output_list) { - impl_->GetFusionOutputFlowList(fusion_output_list); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void Node::SetFusionInputFlowList( - const kFusionDataFlowVec_t &fusion_input_list) { - impl_->SetFusionInputFlowList(fusion_input_list); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void Node::SetFusionOutputFlowList( - const kFusionDataFlowVec_t &fusion_output_list) { - impl_->SetFusionOutputFlowList(fusion_output_list); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::GetHostNode() const { - return impl_->GetHostNode(); -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void Node::SetHostNode(const bool is_host) { - impl_->SetHostNode(is_host); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void Node::SetOrigNode(const NodePtr &orignode) { - impl_->SetOrigNode(orignode); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr Node::GetOrigNode() { - return impl_->GetOrigNode(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t Node::GetInDataNodesSize() const { - return impl_->GetInDataNodesSize(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t Node::GetInControlNodesSize() const { - return impl_->GetInControlNodesSize(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t Node::GetInNodesSize() const { - return impl_->GetInNodesSize(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector Node::GetInNodesPtr() const { - return impl_->GetInNodesPtr(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector Node::GetOutNodesPtr() const { - return impl_->GetOutNodesPtr(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector Node::GetAllInDataAnchorsPtr() const { - return impl_->GetAllInDataAnchorsPtr(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector Node::GetAllOutDataAnchorsPtr() const { - return impl_->GetAllOutDataAnchorsPtr(); -} - -} // namespace ge diff --git a/graph/normal_graph/node_impl.h b/graph/normal_graph/node_impl.h deleted file mode 100644 index 407a2bcbeafd0710c6bd6d6dc7fec9d816eaa604..0000000000000000000000000000000000000000 --- a/graph/normal_graph/node_impl.h +++ /dev/null @@ -1,135 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_NODE_IMPL_H_ -#define GRAPH_NODE_IMPL_H_ - -#include "graph/node.h" - -namespace ge { -class Node::NodeImpl { - public: - NodeImpl() = default; - NodeImpl(const OpDescPtr &op, const ComputeGraphPtr &owner_graph); - ~NodeImpl(); - graphStatus Init(const NodePtr &node); - std::string GetName() const; - const char *GetNamePtr() const; - std::string GetType() const; - const char *GetTypePtr() const; - bool NodeMembersAreEqual(const NodeImpl &r_node) const; - bool NodeAnchorIsEqual(const AnchorPtr &left_anchor, - const AnchorPtr &right_anchor, - const size_t i) const; - - graphStatus AddLinkFrom(const NodePtr &input_node, const NodePtr &owner_node); - graphStatus AddLinkFrom(const uint32_t &index, - const NodePtr &input_node, - const NodePtr &owner_node); - graphStatus AddLinkFromForParse(const NodePtr &input_node, const NodePtr &owner_node); - graphStatus AddLinkFrom(const std::string &name, const NodePtr &input_node, const NodePtr &owner_node); - // Get the node belong to which compute graph - // Normally, return value is not null - ComputeGraphPtr GetOwnerComputeGraph() const; - ComputeGraph *GetOwnerComputeGraphBarePtr() const; - graphStatus SetOwnerComputeGraph(const ComputeGraphPtr &graph); - graphStatus ClearOwnerGraph(const ComputeGraphPtr &graph); - - Node::Vistor GetAllInDataAnchors(const ConstNodePtr &node_ptr) const; - std::vector GetAllInDataAnchorsPtr() const; - Node::Vistor GetAllOutDataAnchors(const ConstNodePtr &node_ptr) const; - std::vector GetAllOutDataAnchorsPtr() const; - uint32_t GetAllInDataAnchorsSize() const; - uint32_t GetAllOutDataAnchorsSize() const; - Node::Vistor GetAllInAnchors(const ConstNodePtr &owner_node) const; - std::vector GetAllInAnchorsPtr() const; - Node::Vistor GetAllOutAnchors(const ConstNodePtr &owner_node) const; - std::vector GetAllOutAnchorsPtr() const; - InDataAnchorPtr GetInDataAnchor(const int32_t idx) const; - AnchorPtr GetInAnchor(const int32_t idx) const; - AnchorPtr GetOutAnchor(const int32_t idx) const; - OutDataAnchorPtr GetOutDataAnchor(const int32_t idx) const; - InControlAnchorPtr GetInControlAnchor() const; - OutControlAnchorPtr GetOutControlAnchor() const; - - Node::Vistor GetInAllNodes(const ConstNodePtr &owner_node) const; - Node::Vistor GetInNodes(const ConstNodePtr &owner_node) const; - std::vector GetInNodesPtr() const; - bool IsAllInNodesSeen(const std::unordered_set &nodes_seen) const; - Node::Vistor GetInDataNodes(const ConstNodePtr &owner_node) const; - Node::Vistor GetInControlNodes(const ConstNodePtr &owner_node) const; - Node::Vistor GetOutDataNodes(const ConstNodePtr &owner_node) const; - std::vector GetOutDataNodesPtr() const; - uint32_t GetOutDataNodesSize() const; - uint32_t GetOutControlNodesSize() const; - uint32_t GetOutNodesSize() const; - size_t GetInDataNodesSize() const; - size_t GetInControlNodesSize() const; - size_t GetInNodesSize() const; - Node::Vistor GetOutControlNodes(const ConstNodePtr &owner_node) const; - Node::Vistor GetOutNodes(const ConstNodePtr &owner_node) const; - Node::Vistor GetOutAllNodes(const ConstNodePtr &owner_node) const; - std::vector GetOutNodesPtr() const; - - OpDescPtr GetOpDesc() const; - OpDesc *GetOpDescBarePtr() const; - graphStatus UpdateOpDesc(const OpDescPtr &op_desc); - Node::Vistor> - GetInDataNodesAndAnchors(const ConstNodePtr &owner_node) const; - Node::Vistor> - GetOutDataNodesAndAnchors(const ConstNodePtr &owner_node) const; - - void AddSendEventId(const uint32_t event_id) { send_event_id_list_.push_back(event_id); } - void AddRecvEventId(const uint32_t event_id) { recv_event_id_list_.push_back(event_id); } - - const std::vector &GetSendEventIdList() const { return send_event_id_list_; } - const std::vector &GetRecvEventIdList() const { return recv_event_id_list_; } - - void GetFusionInputFlowList(kFusionDataFlowVec_t &fusion_input_list) const { - fusion_input_list = fusion_input_dataflow_list_; - } - void GetFusionOutputFlowList(kFusionDataFlowVec_t &fusion_output_list) const { - fusion_output_list = fusion_output_dataflow_list_; - } - void SetFusionInputFlowList(const kFusionDataFlowVec_t &fusion_input_list) { - fusion_input_dataflow_list_ = fusion_input_list; - } - void SetFusionOutputFlowList(const kFusionDataFlowVec_t &fusion_output_list) { - fusion_output_dataflow_list_ = fusion_output_list; - } - - bool GetHostNode() const { return host_node_; } - void SetHostNode(const bool is_host) { host_node_ = is_host; } - - void SetOrigNode(const NodePtr &orignode) { orig_node_ = orignode; } - NodePtr GetOrigNode() { return orig_node_; } - - private: - friend class NodeUtils; - friend class TuningUtils; - friend class OnnxUtils; - OpDescPtr op_; - std::weak_ptr owner_graph_; - ComputeGraph *owner_graph_ptr_ = nullptr; - std::vector in_data_anchors_; - std::vector out_data_anchors_; - InControlAnchorPtr in_control_anchor_; - OutControlAnchorPtr out_control_anchor_; - bool has_init_{false}; - bool host_node_{false}; - bool anchor_status_updated_{false}; - std::vector send_event_id_list_; - std::vector recv_event_id_list_; - - kFusionDataFlowVec_t fusion_input_dataflow_list_; - kFusionDataFlowVec_t fusion_output_dataflow_list_; - NodePtr orig_node_{nullptr}; -}; -} // namespace ge -#endif // GRAPH_BUFFER_IMPL_H_ diff --git a/graph/normal_graph/op_desc.cc b/graph/normal_graph/op_desc.cc deleted file mode 100644 index 5effbcd9bd9a21810e6d9ecd3e19d0a0644118f0..0000000000000000000000000000000000000000 --- a/graph/normal_graph/op_desc.cc +++ /dev/null @@ -1,2238 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/op_desc.h" - -#include "base/err_msg.h" -#include "graph/common_error_codes.h" -#include "graph/operator_factory_impl.h" -#include "graph/normal_graph/op_desc_impl.h" -#include "graph/ge_context.h" -#include "graph/utils/attr_utils.h" -#include "graph/utils/ge_ir_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/transformer_utils.h" -#include "graph/utils/node_utils.h" -#include "common/util/mem_utils.h" -#include "graph/debug/ge_attr_define.h" -#include "register/op_tiling/op_tiling_constants.h" -#include "common/util/trace_manager/trace_manager.h" -#include "common/checker.h" -#include "graph/utils/op_desc_utils.h" - -namespace { -using std::make_pair; -using std::shared_ptr; - -void AddDynamicNameIndex(const std::map &dynamic_names_indexes, size_t insert_index, - std::map &name_indexes) { - // Update index in input_name_idx - for (auto &name_index : name_indexes) { - if (name_index.second >= (insert_index)) { - name_index.second += dynamic_names_indexes.size(); - } - } - name_indexes.insert(dynamic_names_indexes.cbegin(), dynamic_names_indexes.cend()); -} -} - -namespace ge { -TensorType::TensorType(DataType dt) { - tensor_type_impl_ = ComGraphMakeShared(); - if (tensor_type_impl_ != nullptr) { - tensor_type_impl_->GetMutableDateTypeSet().emplace(dt); - } -} - -TensorType::TensorType(const std::initializer_list &initial_types) { - tensor_type_impl_ = ComGraphMakeShared(); - if (tensor_type_impl_ != nullptr) { - tensor_type_impl_->GetMutableDateTypeSet() = initial_types; - } -} - -static const GeTensorDesc& InvalidGeTensorDesc() { - const static GeTensorDesc kGlobalInvalidGeTensorDesc; - return kGlobalInvalidGeTensorDesc; -} -const std::string ATTR_NAME_ID = "id"; - -const std::string ATTR_NAME_STREAM_ID = "stream_id"; - -const std::string ATTR_NAME_INPUT_NAME = "input_name"; - -const std::string ATTR_NAME_SRC_NAME = "src_name"; - -const std::string ATTR_NAME_SRC_INDEX = "src_index"; - -const std::string ATTR_NAME_INPUT = "input"; - -const std::string ATTR_NAME_INPUT_DESC = "input_desc"; - -const std::string ATTR_NAME_OUTPUT_DESC = "output_desc"; - -const std::string ATTR_NAME_DST_NAME = "dst_name"; - -const std::string ATTR_NAME_DST_INDEX = "dst_index"; - -const std::string ATTR_NAME_WORKSPACE = "workspace"; - -const std::string ATTR_NAME_WORKSPACE_BYTES = "workspace_bytes"; - -const std::string ATTR_NAME_IS_INPUT_CONST = "is_input_const"; - -const std::string ATTR_NAME_OP_KERNEL_LIB_NAME = "_ge_attr_op_kernel_lib_name"; - -OpDescImpl::OpDescImpl() { - meta_data_.has_out_attr_ = true; -} - -OpDescImpl::OpDescImpl(const std::string &name, const std::string &type) : meta_data_(name, type) { - meta_data_.has_out_attr_ = true; -} - -OpDescImpl::OpDescImpl(const OpDescImpl &op_desc_impl) { - subgraph_instance_names_ = op_desc_impl.subgraph_instance_names_; - subgraph_names_to_index_ = op_desc_impl.subgraph_names_to_index_; - for (const auto &input_desc : op_desc_impl.inputs_desc_) { - inputs_desc_.emplace_back(MakeShared(*input_desc)); - } - input_name_idx_ = op_desc_impl.input_name_idx_; - for (const auto &output_desc : op_desc_impl.outputs_desc_) { - outputs_desc_.emplace_back(MakeShared(*output_desc)); - } - output_name_idx_ = op_desc_impl.output_name_idx_; - infer_func_ = op_desc_impl.infer_func_; - infer_format_func_ = op_desc_impl.infer_format_func_; - infer_value_range_func_ = op_desc_impl.infer_value_range_func_; - verifier_func_ = op_desc_impl.verifier_func_; - infer_data_slice_func_ = op_desc_impl.infer_data_slice_func_; - op_kernel_lib_name_ = op_desc_impl.op_kernel_lib_name_; - engine_name_ = op_desc_impl.engine_name_; - meta_data_ = op_desc_impl.meta_data_; - attrs_ = op_desc_impl.attrs_; - tiling_func_info_ = op_desc_impl.tiling_func_info_; - atomic_tiling_func_info_ = op_desc_impl.atomic_tiling_func_info_; -} - -OpDescImpl &OpDescImpl::operator=(const OpDescImpl &op_desc_impl) & { - if (this == &op_desc_impl) { - return *this; - } - subgraph_instance_names_ = op_desc_impl.subgraph_instance_names_; - subgraph_names_to_index_ = op_desc_impl.subgraph_names_to_index_; - for (const auto &input_desc : op_desc_impl.inputs_desc_) { - inputs_desc_.emplace_back(MakeShared(*input_desc)); - } - input_name_idx_ = op_desc_impl.input_name_idx_; - for (const auto &output_desc : op_desc_impl.outputs_desc_) { - outputs_desc_.emplace_back(MakeShared(*output_desc)); - } - output_name_idx_ = op_desc_impl.output_name_idx_; - infer_func_ = op_desc_impl.infer_func_; - infer_format_func_ = op_desc_impl.infer_format_func_; - infer_value_range_func_ = op_desc_impl.infer_value_range_func_; - verifier_func_ = op_desc_impl.verifier_func_; - infer_data_slice_func_ = op_desc_impl.infer_data_slice_func_; - op_kernel_lib_name_ = op_desc_impl.op_kernel_lib_name_; - engine_name_ = op_desc_impl.engine_name_; - meta_data_ = op_desc_impl.meta_data_; - attrs_ = op_desc_impl.attrs_; - tiling_func_info_ = op_desc_impl.tiling_func_info_; - atomic_tiling_func_info_ = op_desc_impl.atomic_tiling_func_info_; - return *this; -} - -OpDescImpl::OpDescImpl(const ge::proto::OpDef &op_def) : meta_data_(op_def.name(), op_def.type()) { - // proto deserialize meta_data - DeSerializeOpDefToMetaData(op_def); -} - -void OpDescImpl::DeSerializeOpDefToMetaData(const proto::OpDef &op_def) { - meta_data_.has_out_attr_ = op_def.has_out_attr(); - meta_data_.id_ = op_def.id(); - meta_data_.stream_id_ = op_def.stream_id(); - meta_data_.inputs_.clear(); - (void)meta_data_.inputs_.insert(meta_data_.inputs_.cend(), op_def.input().cbegin(), op_def.input().cend()); - meta_data_.input_names_.clear(); - (void)meta_data_.input_names_.insert(meta_data_.input_names_.cend(), - op_def.input_name().cbegin(), op_def.input_name().cend()); - meta_data_.src_names_.clear(); - (void)meta_data_.src_names_.insert(meta_data_.src_names_.cend(), - op_def.src_name().cbegin(), op_def.src_name().cend()); - meta_data_.src_indexes_.clear(); - (void)meta_data_.src_indexes_.insert(meta_data_.src_indexes_.cend(), - op_def.src_index().cbegin(), op_def.src_index().cend()); - meta_data_.dst_names_.clear(); - (void)meta_data_.dst_names_.insert(meta_data_.dst_names_.cend(), - op_def.dst_name().cbegin(), op_def.dst_name().cend()); - meta_data_.dst_indexes_.clear(); - (void)meta_data_.dst_indexes_.insert(meta_data_.dst_indexes_.cend(), - op_def.dst_index().cbegin(), op_def.dst_index().cend()); - meta_data_.input_offsets_.clear(); - (void)meta_data_.input_offsets_.insert(meta_data_.input_offsets_.cend(), - op_def.input_i().cbegin(), op_def.input_i().cend()); - meta_data_.output_offsets_.clear(); - (void)meta_data_.output_offsets_.insert(meta_data_.output_offsets_.cend(), - op_def.output_i().cbegin(), op_def.output_i().cend()); - meta_data_.workspaces.clear(); - (void)meta_data_.workspaces.insert(meta_data_.workspaces.cend(), - op_def.workspace().cbegin(), op_def.workspace().cend()); - meta_data_.workspace_bytes_list_.clear(); - (void)meta_data_.workspace_bytes_list_.insert(meta_data_.workspace_bytes_list_.cend(), - op_def.workspace_bytes().cbegin(), op_def.workspace_bytes().cend()); - meta_data_.is_input_consts_.clear(); - (void)meta_data_.is_input_consts_.insert(meta_data_.is_input_consts_.cend(), - op_def.is_input_const().cbegin(), op_def.is_input_const().cend()); - meta_data_.subgraph_names_.clear(); - (void)meta_data_.subgraph_names_.insert(meta_data_.subgraph_names_.cend(), - op_def.subgraph_name().cbegin(), op_def.subgraph_name().cend()); -} - -void OpDescImpl::SerializeMetaDataToOpDef(proto::OpDef * const op_def) { - op_def->set_name(meta_data_.name_); - op_def->set_type(meta_data_.type_); - op_def->set_has_out_attr(meta_data_.has_out_attr_); - op_def->set_id(meta_data_.id_); - op_def->set_stream_id(meta_data_.stream_id_); - op_def->clear_input(); - for (const auto &input : meta_data_.inputs_) {op_def->add_input(input);} - op_def->clear_input_name(); - for (const auto &input_name : meta_data_.input_names_) {op_def->add_input_name(input_name);} - op_def->clear_src_name(); - for (const auto &src_name : meta_data_.src_names_) {op_def->add_src_name(src_name);} - op_def->clear_src_index(); - for (const auto src_idx : meta_data_.src_indexes_) {op_def->add_src_index(src_idx);} - op_def->clear_dst_name(); - for (const auto &dst_name : meta_data_.dst_names_) {op_def->add_dst_name(dst_name);} - op_def->clear_dst_index(); - for (const auto dst_idx : meta_data_.dst_indexes_) {op_def->add_dst_index(dst_idx);} - op_def->clear_input_i(); - for (const auto input_i : meta_data_.input_offsets_) {op_def->add_input_i(input_i);} - op_def->clear_output_i(); - for (const auto output_i : meta_data_.output_offsets_) {op_def->add_output_i(output_i);} - op_def->clear_workspace(); - for (const auto workspace : meta_data_.workspaces) {op_def->add_workspace(workspace);} - op_def->clear_workspace_bytes(); - for (const auto workspace_bytes : meta_data_.workspace_bytes_list_) { - op_def->add_workspace_bytes(workspace_bytes); - } - op_def->clear_is_input_const(); - for (const auto is_input_const : meta_data_.is_input_consts_) { - op_def->add_is_input_const(is_input_const); - } - op_def->clear_subgraph_name(); - for (const auto &subgraph_name : meta_data_.subgraph_names_) { - op_def->add_subgraph_name(subgraph_name); - } -} - -string OpDescImpl::GetName() const { - return meta_data_.name_; -} - -const char *OpDescImpl::GetNamePtr() const { - return meta_data_.name_.c_str(); -} - -void OpDescImpl::SetName(const std::string &name) { - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), - this->GetName(), "name", "", "", name); - meta_data_.SetOpName(name); -} - -const char *OpDescImpl::GetTypePtr() const { - return meta_data_.type_.c_str(); -} - -std::string OpDescImpl::GetType() const { - return meta_data_.type_; -} - -void OpDescImpl::SetType(const std::string &type) { - if (meta_data_.type_ == type) { - return; - } - meta_data_.type_ = type; - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), - this->GetName(), "type", "", "", type); -} - -void OpDescImpl::SetIrRelated(const OpDescImpl *r_op_desc) { - if (r_op_desc != nullptr) { - this->meta_data_.ir_meta_ = r_op_desc->meta_data_.ir_meta_; - } else { - this->meta_data_.ir_meta_ = IRMetaData(""); - } -} - -graphStatus OpDescImpl::AddInputDesc(const ge::GeTensorDesc &input_desc) { - const int32_t index = static_cast(inputs_desc_.size()); - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "add", TraceManager::GetOutGraphName(), - this->GetName(), "input_desc", "", "", index); - return AddInputDesc("__input" + std::to_string(index), input_desc); -} - -graphStatus OpDescImpl::AddInputDesc(const uint32_t index, const ge::GeTensorDesc &input_desc) { - graphStatus ret = GRAPH_SUCCESS; - if (index < inputs_desc_.size()) { - // InputsDesc[index] is exist, then update it - ret = UpdateInputDesc(index, input_desc); - } else { - // InputDesc[index] does not exist, then add it - ret = AddInputDesc(input_desc); - } - return ret; -} - -graphStatus OpDescImpl::AddInputDesc(const std::string &name, const ge::GeTensorDesc &input_desc) { - if (input_name_idx_.find(name) != input_name_idx_.end()) { - GELOGI("input %s is exist, update it", name.c_str()); - const graphStatus ret = UpdateInputDesc(name, input_desc); - return ret; - } else { - int32_t index = static_cast(inputs_desc_.size()); - const std::shared_ptr in_desc = ComGraphMakeShared(input_desc); - if (in_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "AddInputDesc failed, as malloc shared_ptr failed."); - GELOGE(GRAPH_FAILED, "[Create][GeTensorDesc] AddInputDesc failed, as malloc shared_ptr failed."); - return GRAPH_FAILED; - } - inputs_desc_.push_back(in_desc); - (void)input_name_idx_.insert(make_pair(name, index)); - (void)meta_data_.ir_meta_.AddRegisterInputName(name); - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "add", TraceManager::GetOutGraphName(), - this->GetName(), "input_desc:" << index, "", "", "input_name:" << name); - return GRAPH_SUCCESS; - } -} - -graphStatus OpDescImpl::AddInputDescMiddle(const std::string &name, const uint32_t num, const size_t index) { - std::map dynamic_names_indexes; - for (uint32_t i = 0U; i < num; i++) { - std::string input_name = name + std::to_string(i); - GE_CHK_BOOL_EXEC((input_name_idx_.find(input_name) == input_name_idx_.end()), - REPORT_INNER_ERR_MSG("E18888", "Add input tensor_desc is existed. name[%s]", input_name.c_str()); - GELOGE(ge::FAILED, "[Check][Param] Add input tensor_desc is existed. name[%s]", - input_name.c_str()); - return GRAPH_FAILED); - - const std::shared_ptr in_desc = ComGraphMakeShared(GeTensorDesc()); - if (in_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "AddInputDescMiddle failed, as malloc shared_ptr failed."); - GELOGE(GRAPH_FAILED, "[Create][GeTensorDesc] AddInputDescMiddle failed, as malloc shared_ptr failed."); - return GRAPH_FAILED; - } - - if (index > inputs_desc_.size()) { - REPORT_INNER_ERR_MSG("E18888", "AddInputDescMiddle failed, as param index(%zu) " - "is bigger than inputs size(%zu).", index, inputs_desc_.size()); - GELOGE(GRAPH_FAILED, "[Check][Param] AddInputDescMiddle failed, as param index(%zu) " - "is bigger than inputs size(%zu).", index, inputs_desc_.size()); - return GRAPH_FAILED; - } - - auto pos = inputs_desc_.begin(); - std::advance(pos, index + i); - (void)inputs_desc_.insert(pos, in_desc); - - dynamic_names_indexes.insert(make_pair(input_name, i + index)); - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "add", TraceManager::GetOutGraphName(), - this->GetName(), "input_desc:" << (i + index), "", "", "input_name:" << input_name); - } - AddDynamicNameIndex(dynamic_names_indexes, index, input_name_idx_); - return GRAPH_SUCCESS; -} - -graphStatus OpDescImpl::AddOutputDescMiddle(const std::string &name, const uint32_t num, const size_t index) { - std::map dynamic_names_indexes; - for (uint32_t i = 0U; i < num; i++) { - std::string output_name = name + std::to_string(i); - GE_CHK_BOOL_EXEC((output_name_idx_.find(output_name) == output_name_idx_.end()), - REPORT_INNER_ERR_MSG("E18888", "Add output tensor_desc is existed. name[%s]", output_name.c_str()); - return GRAPH_FAILED, - "[Check][Param] Add output tensor_desc is existed. name[%s]", output_name.c_str()); - - const std::shared_ptr out_desc = ComGraphMakeShared(GeTensorDesc()); - if (out_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "AddOutputDescMiddle failed, as malloc shared_ptr failed."); - GELOGE(GRAPH_FAILED, "[Create][GeTensorDesc] AddOutputDescMiddle failed, as malloc shared_ptr failed."); - return GRAPH_FAILED; - } - - if (index > outputs_desc_.size()) { - REPORT_INNER_ERR_MSG("E18888", "AddOutputDescMiddle failed, as param index(%zu) " - "is bigger than outputs size(%zu).", index, outputs_desc_.size()); - GELOGE(GRAPH_FAILED, "[Check][Param] AddOutputDescMiddle failed, as param index(%zu) " - "is bigger than outputs size(%zu).", index, outputs_desc_.size()); - return GRAPH_FAILED; - } - - auto pos = outputs_desc_.begin(); - std::advance(pos, index + i); - (void) outputs_desc_.insert(pos, out_desc); - (void) dynamic_names_indexes.insert(make_pair(output_name, i + index)); - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "add", TraceManager::GetOutGraphName(), - this->GetName(), "output_desc:" << (i + index), "", "", output_name); - } - AddDynamicNameIndex(dynamic_names_indexes, index, output_name_idx_); - return GRAPH_SUCCESS; -} - -graphStatus OpDescImpl::AddInputDescForward(const std::string &name, const uint32_t num) { - std::map dynamic_input_name_indexes; - for (uint32_t i = 0U; i < num; i++) { - std::string input_name = name + std::to_string(i); - GE_CHK_BOOL_EXEC((input_name_idx_.find(input_name) == input_name_idx_.end()), - REPORT_INNER_ERR_MSG("E18888", "Add input tensor_desc is existed. name[%s]", input_name.c_str()); - return GRAPH_FAILED, - "[Check][Param] Add input tensor_desc is existed. name[%s]", input_name.c_str()); - - const std::shared_ptr in_desc = ComGraphMakeShared(GeTensorDesc()); - if (in_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "AddInputDescForward failed, as malloc shared_ptr failed."); - GELOGE(GRAPH_FAILED, "[Create][GeTensorDesc] AddInputDescForward failed, as malloc shared_ptr failed."); - return GRAPH_FAILED; - } - (void)inputs_desc_.insert(inputs_desc_.cbegin(), in_desc); - - dynamic_input_name_indexes.insert(make_pair(input_name, i)); - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "add", TraceManager::GetOutGraphName(), - this->GetName(), "input_desc:0", "", "", "input_name:" << input_name); - } - AddDynamicNameIndex(dynamic_input_name_indexes, 0U, input_name_idx_); - return GRAPH_SUCCESS; -} - -graphStatus OpDescImpl::AddOutputDescForward(const std::string &name, const uint32_t num) { - std::map output_name_indexes; - for (uint32_t i = 0U; i < num; i++) { - const auto index = i; - std::string output_name = name + std::to_string(index); - GE_CHK_BOOL_EXEC((output_name_idx_.find(output_name) == output_name_idx_.end()), - REPORT_INNER_ERR_MSG("E18888", "Add output tensor_desc is existed. name[%s]", output_name.c_str()); - return GRAPH_FAILED, - "[Check][Param] Add output tensor_desc is existed. name[%s]", output_name.c_str()); - - const std::shared_ptr in_desc = ComGraphMakeShared(GeTensorDesc()); - if (in_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "AddOutputDescForward failed, as malloc shared_ptr failed."); - GELOGE(GRAPH_FAILED, "[Create][GeTensorDesc] AddOutputDescForward failed, as malloc shared_ptr failed."); - return GRAPH_FAILED; - } - - (void)outputs_desc_.insert(outputs_desc_.cbegin(), in_desc); - output_name_indexes.insert(make_pair(output_name, index)); - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "add", TraceManager::GetOutGraphName(), - this->GetName(), "output_desc:0", "", "", "output_name:" << output_name); - } - AddDynamicNameIndex(output_name_indexes, 0U, output_name_idx_); - return GRAPH_SUCCESS; -} - -graphStatus OpDescImpl::AddOptionalInputDesc(const std::string &name, - const ge::GeTensorDesc &input_desc) { - if (OpDescImpl::AddInputDesc(name, input_desc) == GRAPH_FAILED) { - return GRAPH_FAILED; - } - (void)meta_data_.ir_meta_.AddRegisterOptionalInputName(name); - return GRAPH_SUCCESS; -} - -graphStatus OpDescImpl::UpdateInputDesc(const uint32_t index, const ge::GeTensorDesc &tensor_Desc) { - if (index >= inputs_desc_.size()) { - GELOGW("[UpdateInput][Check] Input index is invalid, index=%u, input_size=%zu", index, inputs_desc_.size()); - return GRAPH_FAILED; - } - - inputs_desc_[static_cast(index)] = ComGraphMakeShared(tensor_Desc); - if (inputs_desc_[static_cast(index)] == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "UpdateInputDesc failed, as malloc shared_ptr failed."); - GELOGE(GRAPH_FAILED, "[Create][GeTensorDesc] UpdateInputDesc failed, as malloc shared_ptr failed."); - return GRAPH_FAILED; - } - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), - this->GetName(), "input_desc:" << index, "", "", - tensor_Desc.GetName()); - return GRAPH_SUCCESS; -} - -bool OpDescImpl::OpDescMembersAreEqual(const OpDescImpl &r_op_desc) const { - return (IsEqual(this->input_name_idx_, r_op_desc.input_name_idx_, "OpDesc.input_name_idx_") && - IsEqual(this->output_name_idx_, r_op_desc.output_name_idx_, "OpDesc.output_name_idx_") && - IsEqual(this->meta_data_.ir_meta_, r_op_desc.meta_data_.ir_meta_, "OpDesc.ir_mata_") && - IsEqual(this->engine_name_, r_op_desc.engine_name_, "OpDesc.engine_name_") && - IsEqual(this->op_kernel_lib_name_, r_op_desc.op_kernel_lib_name_, "OpDesc.op_kernel_lib_name_")); -} - -bool OpDescImpl::OpDescAttrsAreEqual(const OpDescImpl &r_op_desc) const { - // 看起来当前的本判等函数没有考虑属性,补一下UT确认一下 - const auto &r_data = r_op_desc.meta_data_; - return (IsEqual(meta_data_.name_, r_data.name_, "meta_data_.name_") && - IsEqual(meta_data_.type_, r_data.type_, "meta_data_.type_") && - IsEqual(meta_data_.inputs_, r_data.inputs_, "meta_data_.inputs_") && - IsEqual(meta_data_.has_out_attr_, r_data.has_out_attr_, "meta_data_.has_out_attr_") && - IsEqual(meta_data_.stream_id_, r_data.stream_id_, "meta_data_.stream_id_") && - IsEqual(meta_data_.input_names_, r_data.input_names_, "meta_data_.input_names_") && - IsEqual(meta_data_.src_names_, r_data.src_names_, "meta_data_.src_names_") && - IsEqual(meta_data_.dst_names_, r_data.dst_names_, "meta_data_.dst_names_") && - IsEqual(meta_data_.src_indexes_, r_data.src_indexes_, "meta_data_.src_indexes_") && - IsEqual(meta_data_.dst_indexes_, r_data.dst_indexes_, "meta_data_.dst_indexes_") && - IsEqual(meta_data_.input_offsets_, r_data.input_offsets_, "meta_data_.input_offsets_") && - IsEqual(meta_data_.output_offsets_, r_data.output_offsets_, "meta_data_.output_offsets_") && - IsEqual(meta_data_.workspaces, r_data.workspaces, "meta_data_.workspaces") && - IsEqual(meta_data_.workspace_bytes_list_, r_data.workspace_bytes_list_, - "meta_data_.workspace_bytes_list_") && - IsEqual(meta_data_.is_input_consts_, r_data.is_input_consts_, "meta_data_.is_input_consts_")); -} - -bool OpDescImpl::OpDescGenTensorDescsAreEqual(const OpDescImpl &r_op_desc) -const { - // 1.Verify inputs and outputs desc size - const auto inputs_desc_size = this->inputs_desc_.size(); - const auto r_inputs_desc_size = r_op_desc.inputs_desc_.size(); - if (inputs_desc_size != r_inputs_desc_size) { - REPORT_INNER_ERR_MSG("E18888", "param r_op_desc inputs count(%zu) not equal to %s inputs count(%zu), " - "verify failed.", r_inputs_desc_size, this->GetName().c_str(), inputs_desc_size); - GELOGE(GRAPH_FAILED, "[Check][Param] Size of OpDesc's inputs desc verify failed, node name: %s.", - this->GetName().c_str()); - return false; - } - const auto outputs_desc_size = this->outputs_desc_.size(); - const auto r_outputs_desc_size = r_op_desc.outputs_desc_.size(); - if (outputs_desc_size != r_outputs_desc_size) { - REPORT_INNER_ERR_MSG("E18888", "param r_op_desc outputs count(%zu) not equal to %s outputs count(%zu), " - "verify failed.", r_inputs_desc_size, this->GetName().c_str(), inputs_desc_size); - GELOGE(GRAPH_FAILED, "[Check][Param] Size of OpDesc's outputs desc verify failed, node name: %s.", - this->GetName().c_str()); - return false; - } - // 2.Verify all inputs desc equal - for (uint32_t i = 0U; i < inputs_desc_size; i++) { - const auto &in_ge_tensor_desc = this->GetInputDesc(i); - const auto &r_in_ge_tensor_desc = r_op_desc.GetInputDesc(i); - // Determine the connection relationship by GeTensorDesc - if (!(in_ge_tensor_desc == r_in_ge_tensor_desc)) { - REPORT_INNER_ERR_MSG("E18888", "r_op_desc inputdesc(index:%u) not equal to %s inputdesc(index:%u), " - "verify failed.", i, this->GetName().c_str(), i); - GELOGE(GRAPH_FAILED, "[Check][Param] Link info of OpDesc's inputs desc verify failed, OpDesc name: %s.", - this->GetName().c_str()); - return false; - } - } - // 3.Verify all outputs desc equal - for (uint32_t i = 0U; i < outputs_desc_size; i++) { - const auto &out_ge_tensor_desc = this->GetOutputDesc(i); - const auto &r_out_ge_tensor_desc = r_op_desc.GetOutputDesc(i); - if (!(out_ge_tensor_desc == r_out_ge_tensor_desc)) { - REPORT_INNER_ERR_MSG("E18888", "r_op_desc outputdesc(index:%u) not equal to %s outputdesc(index:%u), " - "verify failed.", i, this->GetName().c_str(), i); - GELOGE(GRAPH_FAILED, "[Check][Param] Link info of OpDesc's outputs desc verify failed, OpDesc name: %s.", - this->GetName().c_str()); - return false; - } - } - return true; -} - -graphStatus OpDescImpl::UpdateInputDesc(const std::string &name, const ge::GeTensorDesc &tensor_Desc) { - const auto it = input_name_idx_.find(name); - if (it == input_name_idx_.end()) { - GELOGW("[UpdateInput][Check] Can not find input desc named %s", name.c_str()); - return GRAPH_FAILED; - } - if (it->second >= inputs_desc_.size()) { - REPORT_INNER_ERR_MSG("E18888", "%u is out of range(0, %zu), check invalid", it->second, inputs_desc_.size()); - GELOGE(GRAPH_FAILED, "[Check][Param] [%u] more than size:%zu of inputs_desc_", it->second, inputs_desc_.size()); - return GRAPH_FAILED; - } - - inputs_desc_[static_cast(it->second)] = ComGraphMakeShared(tensor_Desc); - if (inputs_desc_[static_cast(it->second)] == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "UpdateInputDesc failed, as malloc shared_ptr failed."); - GELOGE(GRAPH_FAILED, "[Create][GeTensorDesc] UpdateInputDesc failed, as malloc shared_ptr failed."); - return GRAPH_FAILED; - } - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), - this->GetName(), "input_desc:" << it->second, "", "", tensor_Desc.GetName()); - return GRAPH_SUCCESS; -} - -bool OpDescImpl::InputIsSet(const std::string &name) const { - const auto it = input_name_idx_.find(name); - if (it != input_name_idx_.end()) { - GE_IF_BOOL_EXEC(it->second >= inputs_desc_.size(), - REPORT_INNER_ERR_MSG("E18888", "input name(%s) id(%u) is out of range(0, %zu), check invalid", - name.c_str(), it->second, inputs_desc_.size()); - GELOGE(GRAPH_FAILED, "[Check][Param] it->second is invalid."); - return false); - const auto tensor_desc = inputs_desc_[static_cast(it->second)]; - GE_IF_BOOL_EXEC(tensor_desc == nullptr, - REPORT_INNER_ERR_MSG("E18888", "tensor_desc(index:%u) is null.", it->second); - GELOGE(GRAPH_FAILED, "[Check][Param] tensor_desc(index:%u) is null.", it->second); return false); - const auto dims = tensor_desc->GetShape().GetDims(); - if (dims.size() > 0U) { - return true; - } - } - return false; -} - -const GeTensorDesc &OpDescImpl::GetInputDesc(const uint32_t index) const { - GE_CHK_BOOL_RET_STATUS_NOLOG(index < inputs_desc_.size(), InvalidGeTensorDesc()); - return *(inputs_desc_[static_cast(index)].get()); -} - -const GeTensorDesc &OpDescImpl::GetInputDesc(const std::string &name) const { - const auto it = input_name_idx_.find(name); - GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), InvalidGeTensorDesc()); - GE_CHK_BOOL_RET_STATUS_NOLOG(it->second < inputs_desc_.size(), InvalidGeTensorDesc()); - return *(inputs_desc_[static_cast(it->second)].get()); -} - -GeTensorDescPtr OpDescImpl::MutableInputDesc(const uint32_t index) const { - if (index >= inputs_desc_.size()) { - GELOGW("[Get][InputDesc] Failed to get input desc [%u]", index); - return nullptr; - } - if (inputs_desc_[static_cast(index)] == nullptr) { - return nullptr; - } - if (inputs_desc_[static_cast(index)]->IsValid() != GRAPH_SUCCESS) { - GELOGD("[Get][InputDesc] Input desc is invalid [%u]", index); - return nullptr; - } - return inputs_desc_[static_cast(index)]; -} - -GeTensorDescPtr OpDescImpl::MutableInputDesc(const std::string &name) const { - auto input_name_idx = GetAllInputName(); - const std::map::const_iterator it = input_name_idx.find(name); - if (it == input_name_idx.cend()) { - GELOGW("[Get][InputDesc] Failed to get [%s] input desc", name.c_str()); - return nullptr; - } - return MutableInputDesc(it->second); -} - -OpDesc::Vistor OpDescImpl::GetAllInputNames(const ConstOpDescPtr &op_desc) const { - std::vector names; - if (input_name_idx_.empty()) { - return OpDesc::Vistor(op_desc, names); - } - for (const std::pair input : input_name_idx_) { - names.push_back(input.first); - } - return OpDesc::Vistor(op_desc, names); -} - -void OpDescImpl::SetOpKernelLibName(const std::string &name) { - op_kernel_lib_name_ = name; -} - -std::string OpDescImpl::GetOpKernelLibName() const { - if (!op_kernel_lib_name_.empty()) { - return op_kernel_lib_name_; - } - return ""; -} - -void OpDescImpl::SetOpEngineName(const std::string &name) { - engine_name_ = name; -} - -std::string OpDescImpl::GetOpEngineName() const { return engine_name_; } - -OpDesc::Vistor OpDescImpl::GetAllInputsDesc(const ConstOpDescPtr &op_desc) const { - std::vector temp{}; - for (const auto &it : inputs_desc_) { - if (it->IsValid() == GRAPH_SUCCESS) { - temp.push_back(*it); - } else { - GELOGW("[Get][InputDesc] This input_desc is invalid, it won't be return"); - continue; - } - } - return OpDesc::Vistor(op_desc, temp); -} - -OpDesc::Vistor OpDescImpl::GetAllInputsDescPtr(const ConstOpDescPtr &op_desc) const { - std::vector temp{}; - for (const auto &it : inputs_desc_) { - if (it->IsValid() == GRAPH_SUCCESS) { - temp.push_back(it); - } else { - GELOGD("[Get][InputDesc] This input_desc is invalid, it won't be return"); - continue; - } - } - return OpDesc::Vistor(op_desc, temp); -} - -size_t OpDescImpl::GetInputsSize() const { - // Just return valid inputs size.InValid desc is set in default OPTION_INPUT register. - size_t size = 0U; - for (const auto &it : inputs_desc_) { - if (it->IsValid() == GRAPH_SUCCESS) { - size++; - } - } - return size; -} - -size_t OpDescImpl::GetAllInputsSize() const { return inputs_desc_.size(); } - -graphStatus OpDescImpl::AddOutputDesc(const ge::GeTensorDesc &output_desc) { - const int32_t index = static_cast(outputs_desc_.size()); - return AddOutputDesc("__output" + std::to_string(index), output_desc); -} - -graphStatus OpDescImpl::AddOutputDesc(const std::string &name, const ge::GeTensorDesc &output_desc) { - GE_CHK_BOOL_EXEC((output_name_idx_.find(name) == output_name_idx_.end()), - REPORT_INNER_ERR_MSG("E18888", "Add output tensor_Desc is existed. name[%s]", name.c_str()); - return GRAPH_FAILED, - "[Check][Param] Add output tensor_Desc is existed. name[%s]", name.c_str()); - const int32_t index = static_cast(outputs_desc_.size()); - - const std::shared_ptr tensor = ComGraphMakeShared(output_desc); - if (tensor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "AddOutputDesc failed, as malloc shared_ptr failed."); - GELOGE(GRAPH_FAILED, "[Create][GeTensorDesc] AddOutputDesc failed, as malloc shared_ptr failed."); - return GRAPH_FAILED; - } - outputs_desc_.push_back(tensor); - (void)output_name_idx_.insert(make_pair(name, index)); - (void)meta_data_.ir_meta_.AddRegisterOutputName(name); - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "add", TraceManager::GetOutGraphName(), - this->GetName(), "output_desc:" << index, "", "", "output_name:" << name); - return GRAPH_SUCCESS; -} - -graphStatus OpDescImpl::UpdateOutputDesc(const uint32_t index, const ge::GeTensorDesc &tensor_Desc) { - GE_CHK_BOOL_EXEC((index < outputs_desc_.size()), - REPORT_INNER_ERR_MSG("E18888", "param index(%u) is out of range(0, %zu), check invalid", - index, outputs_desc_.size()); - return GRAPH_FAILED, - "[Check][Param] The index is invalid. index[%u]", index); - outputs_desc_[static_cast(index)] = ComGraphMakeShared(tensor_Desc); - if (outputs_desc_[static_cast(index)] == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "UpdateOutputDesc failed, as malloc shared_ptr failed."); - GELOGE(GRAPH_FAILED, "[Create][GeTensorDesc] UpdateOutputDesc failed, as malloc shared_ptr failed."); - return GRAPH_FAILED; - } - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), - this->GetName(), "output_desc:" << index, "", "", tensor_Desc.GetName()); - return GRAPH_SUCCESS; -} - -graphStatus OpDescImpl::UpdateOutputDesc(const std::string &name, const ge::GeTensorDesc &tensor_Desc) { - const auto it = output_name_idx_.find(name); - if (it == output_name_idx_.end()) { - GELOGW("[Update][OutputDesc] Can not find the output desc named %s", name.c_str()); - return GRAPH_FAILED; - } - GE_IF_BOOL_EXEC(it->second >= outputs_desc_.size(), - REPORT_INNER_ERR_MSG("E18888", "output name(%s) idx(%u) is out of range(0, %zu), check invalid", - name.c_str(), it->second, outputs_desc_.size()); - GELOGE(GRAPH_FAILED, "[Check][Param] it->second is invalid."); - return GRAPH_FAILED); - outputs_desc_[static_cast(it->second)] = ComGraphMakeShared(tensor_Desc); - if (outputs_desc_[static_cast(it->second)] == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "UpdateOutputDesc failed, as malloc shared_ptr failed."); - GELOGE(GRAPH_FAILED, "[Create][GeTensorDesc] UpdateOutputDesc failed, as malloc shared_ptr failed."); - return GRAPH_FAILED; - } - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), - this->GetName(), "output_desc:" << it->second, "", "", tensor_Desc.GetName()); - return GRAPH_SUCCESS; -} - -const GeTensorDesc &OpDescImpl::GetOutputDesc(const uint32_t index) const { - GE_CHK_BOOL_RET_STATUS_NOLOG(static_cast(index) < outputs_desc_.size(), InvalidGeTensorDesc()); - return *(outputs_desc_[static_cast(index)].get()); -} - -const GeTensorDesc &OpDescImpl::GetOutputDesc(const std::string &name) const { - const auto it = output_name_idx_.find(name); - GE_CHK_BOOL_RET_STATUS_NOLOG(it != output_name_idx_.end(), InvalidGeTensorDesc()); - GE_CHK_BOOL_RET_STATUS_NOLOG(it->second < outputs_desc_.size(), InvalidGeTensorDesc()); - return *(outputs_desc_[static_cast(it->second)].get()); -} - -GeTensorDescPtr OpDescImpl::MutableOutputDesc(const uint32_t index) const { - if (index < outputs_desc_.size()) { - return outputs_desc_[static_cast(index)]; - } - GELOGW("[Get][OutputDesc] Failed to get output desc [%u], output number [%zu]", index, outputs_desc_.size()); - return nullptr; -} - -GeTensorDescPtr OpDescImpl::MutableOutputDesc(const std::string &name) const { - const auto it = output_name_idx_.find(name); - if (it == output_name_idx_.end()) { - GELOGW("[Update][OutputDesc] Can not find the output desc named %s", name.c_str()); - return nullptr; - } - return MutableOutputDesc(it->second); -} - -uint32_t OpDescImpl::GetAllOutputsDescSize() const { - return static_cast(outputs_desc_.size()); -} - -OpDesc::Vistor OpDescImpl::GetAllOutputsDesc(const ConstOpDescPtr &op_desc) const { - std::vector temp{}; - for (const auto &it : outputs_desc_) { - temp.push_back(*it); - } - return OpDesc::Vistor(op_desc, temp); -} - -OpDesc::Vistor OpDescImpl::GetAllOutputsDescPtr(const ConstOpDescPtr &op_desc) const { - return OpDesc::Vistor(op_desc, outputs_desc_); -} - -size_t OpDescImpl::GetOutputsSize() const { return outputs_desc_.size(); } - -ConstGeTensorDescPtr OpDescImpl::GetOutputDescPtr(const uint32_t index) const { - GE_CHK_BOOL_RET_STATUS_NOLOG((index) < static_cast(outputs_desc_.size()), nullptr); - return outputs_desc_[static_cast(index)]; -} - -ConstGeTensorDescPtr OpDescImpl::GetInputDescPtr(const uint32_t index) const { - GE_CHK_BOOL_RET_STATUS_NOLOG((index) < static_cast(inputs_desc_.size()), nullptr); - if (inputs_desc_[static_cast(index)] == nullptr) { - return nullptr; - } - if (inputs_desc_[static_cast(index)]->IsValid() != GRAPH_SUCCESS) { - GELOGW("[Get][InputDesc] Input desc %u is invalid", index); - return nullptr; - } else { - return inputs_desc_[static_cast(index)]; - } -} - -ConstGeTensorDescPtr OpDescImpl::GetInputDescPtrDfault(const uint32_t index) const { - GE_CHK_BOOL_RET_STATUS_NOLOG((index) < static_cast(inputs_desc_.size()), nullptr); - return inputs_desc_[static_cast(index)]; -} - -ConstGeTensorDescPtr OpDescImpl::GetInputDescPtr(const std::string &name) const { - const auto it = input_name_idx_.find(name); - GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), shared_ptr()); - return inputs_desc_[static_cast(it->second)]; -} - -graphStatus OpDescImpl::AddDynamicInputDesc(const std::string &name, const uint32_t num, const bool is_push_back) { - if (is_push_back) { - for (uint32_t i = 0U; i < num; i++) { - if (AddInputDesc(name + std::to_string(i), GeTensorDesc()) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - } - } else { - if (AddInputDescForward(name, num) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - } - if (meta_data_.ir_meta_.AddRegisterInputName(name) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - - return GRAPH_SUCCESS; -} - -graphStatus OpDescImpl::AddDynamicInputDescByIndex(const std::string &name, const uint32_t num, const size_t index) { - if (AddInputDescMiddle(name, num, index) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -graphStatus OpDescImpl::AddDynamicOutputDesc(const std::string &name, const uint32_t num, const bool is_push_back) { - if (is_push_back) { - for (uint32_t i = 0U; i < num; i++) { - if (AddOutputDesc(name + std::to_string(i), GeTensorDesc()) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - } - } else { - if (AddOutputDescForward(name, num) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - } - - if (meta_data_.ir_meta_.AddRegisterOutputName(name) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -bool OpDescImpl::IsOptionalInput(const uint32_t index) const { - return meta_data_.ir_meta_.IsOptionalInput(GetInputNameByIndex(index)); -} - -std::map OpDescImpl::GetAllInputName() const { return input_name_idx_; } - -std::map OpDescImpl::GetAllOutputName() { return output_name_idx_; } - -std::map& OpDescImpl::MutableAllInputName() { return input_name_idx_; } - -std::map& OpDescImpl::MutableAllOutputName() { return output_name_idx_; } - -bool OpDescImpl::UpdateInputName(std::map input_name_idx) { - const auto &ir_inputs = GetIRMeta().GetIrInputs(); - size_t last_dyn = ir_inputs.size(); - for (size_t i = 0UL; i < ir_inputs.size(); ++i) { - if (ir_inputs[i].second == kIrInputDynamic) { - last_dyn = i; - } - } - - if (last_dyn + 1UL < ir_inputs.size()) { - GELOGW( - "[Update][InputName] Dynamic input in middle is unsupported to update from factory, last_dyn=%zu ir_size=%zu.", - last_dyn, ir_inputs.size()); - return false; - } - - // Use inputDesc_.size() to contain the InValid OptionInput.GetInputsSize() will remove default OptionInput name. - const auto input_map_size = inputs_desc_.size(); - const auto factory_map_size = input_name_idx.size(); - // It indicates that some inputs have no optional name. - // The redundant optional name of factory needs to be deleted and then assigned - if (input_map_size < factory_map_size) { - GELOGI("org_input_name_num=%zu, factory_input_name_num=%zu", input_map_size, factory_map_size); - for (auto it = input_name_idx.begin(); it != input_name_idx.end();) { - if (it->second >= input_map_size) { - it = input_name_idx.erase(it); - } else { - ++it; - } - } - if (input_name_idx.size() == input_map_size) { - GELOGI("UpdateInputName"); - input_name_idx_ = input_name_idx; - } else { - GELOGW("[Update][InputName] After update, org_input_name_num=%zu, factory_input_name_num=%zu", input_map_size, - input_name_idx.size()); - return false; - } - } else if (input_map_size == factory_map_size) { - input_name_idx_ = input_name_idx; - } else { - GELOGW("[Update][InputName] factory_input_name_num can not be less than org_input_name_num, exactly " - "org_input_name_num=%zu, factory_input_name_num=%zu", input_map_size, factory_map_size); - return false; - } - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), - this->GetName(), "input_name_idx", "", "", ""); - return true; -} - -bool OpDescImpl::UpdateOutputName(std::map output_name_idx) { - const size_t output_map_size = GetAllOutputsDescSize(); - const size_t factory_map_size = output_name_idx.size(); - if (output_map_size < factory_map_size) { - GELOGI("org_output_name_num=%zu, factory_output_name_num=%zu", output_map_size, factory_map_size); - for (auto it = output_name_idx.begin(); it != output_name_idx.end();) { - if (it->second >= output_map_size) { - it = output_name_idx.erase(it); - } else { - ++it; - } - } - if (output_name_idx.size() == output_map_size) { - GELOGI("UpdateOutputName"); - output_name_idx_ = output_name_idx; - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), - this->GetName(), "output_name_idx", "", "", ""); - return true; - } - } else if (output_map_size == factory_map_size) { - output_name_idx_ = output_name_idx; - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), - this->GetName(), "output_name_idx", "", "", ""); - return true; - } else { - GELOGW("[Update][OutputName] factory_output_name_num can not be less than org_output_name_num, exactly " - "org_output_name_num=%zu, factory_output_name_num=%zu", output_map_size, output_name_idx.size()); - return false; - } - GELOGW("[Update][OutputName] After update, org_output_name_num=%zu, factory_output_name_num=%zu", output_map_size, - factory_map_size); - return false; -} - -void *OpDescImpl::GetTilingFuncInfo() const { - return tiling_func_info_; -} - -void OpDescImpl::SetTilingFuncInfo(void *tiling_func_info) { - tiling_func_info_ = tiling_func_info; -} - -void *OpDescImpl::GetAtomicTilingFuncInfo() const { - return atomic_tiling_func_info_; -} - -void OpDescImpl::SetAtomicTilingFuncInfo(void *atomic_tiling_func_info) { - atomic_tiling_func_info_ = atomic_tiling_func_info; -} - -std::function OpDescImpl::GetInferFunc() const { return infer_func_; } - -std::function OpDescImpl::GetVerifyFunc() const { return verifier_func_; } - -std::function OpDescImpl::GetInferFormatFunc() const { return infer_format_func_; } - -std::function OpDescImpl::GetInferValueRangeFunc() const { return infer_value_range_func_; } - -std::function OpDescImpl::GetInferDataSliceFunc() const { return infer_data_slice_func_; } - -void OpDescImpl::AddInferFunc(const std::function &func) { infer_func_ = func; } - -void OpDescImpl::AddInferFormatFunc(const std::function &func) { infer_format_func_ = func; } - -void OpDescImpl::AddVerifierFunc(const std::function &func) { verifier_func_ = func; } - -void OpDescImpl::AddInferValueRangeFunc(const std::function &func) { - infer_value_range_func_ = func; -} - -void OpDescImpl::AddInferDataSliceFunc(const std::function &func) { - infer_data_slice_func_ = func; -} - -graphStatus OpDescImpl::DefaultInferFormat(const ConstOpDescPtr &op_desc) const { - ge::Format first_none_nd_format = FORMAT_ND; - const auto input_descs = GetAllInputsDescPtr(op_desc); - const auto output_descs = GetAllOutputsDescPtr(op_desc); - // Overall input and output,get the first non-nd format - for (const auto &input_desc : input_descs) { - const Format origin_format = input_desc->GetOriginFormat(); - if (origin_format != FORMAT_ND) { - first_none_nd_format = origin_format; - break; - } - } - for (const auto &output_desc : output_descs) { - const Format origin_format = output_desc->GetOriginFormat(); - if (origin_format != FORMAT_ND) { - first_none_nd_format = origin_format; - break; - } - } - // Refresh all input output format - GELOGD("Default infer format.node[%s], first none nod format is:%d", GetName().c_str(), first_none_nd_format); - - for (const auto &input_desc : input_descs) { - const Format origin_format = input_desc->GetOriginFormat(); - GELOGD("Default infer format[in].node[%s].origin format is:%d", GetName().c_str(), origin_format); - if (origin_format == FORMAT_ND) { - input_desc->SetOriginFormat(first_none_nd_format); - input_desc->SetFormat(first_none_nd_format); - } - } - for (const auto &output_desc : output_descs) { - const Format origin_format = output_desc->GetOriginFormat(); - GELOGD("Default infer format[out].node[%s].origin format is:%d", GetName().c_str(), origin_format); - if (origin_format == FORMAT_ND) { - output_desc->SetOriginFormat(first_none_nd_format); - output_desc->SetFormat(first_none_nd_format); - } - } - return GRAPH_SUCCESS; -} - -std::string OpDescImpl::GetInputNameByIndex(const uint32_t index) const { - auto it = input_name_idx_.begin(); - for (; it != input_name_idx_.end(); ++it) { - if (it->second == index) { - break; - } - } - GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), ""); - return it->first; -} - -int32_t OpDescImpl::GetInputIndexByName(const std::string &name) const { - const auto it_find = input_name_idx_.find(name); - GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != input_name_idx_.end(), -1); - return static_cast(it_find->second); -} - -graphStatus OpDescImpl::GetDynamicInputIndexesByName(const std::string &name, std::vector &indexes) const { - uint32_t previous_value = UINT32_MAX; // 初始化为一个不可能的值 - - for (size_t i = 0U; ; i++) { - const auto it_find = input_name_idx_.find(name + std::to_string(i)); - if (it_find != input_name_idx_.end()) { - if (previous_value != UINT32_MAX && it_find->second != previous_value + 1U) { - // 如果当前的 value 不是前一个 value 加一,返回 GRAPH_FAILED - GELOGE(GRAPH_FAILED, "The dynamic input index is not continuous."); - return GRAPH_FAILED; - } - indexes.emplace_back(it_find->second); - previous_value = it_find->second; - } else { - break; - } - } - return GRAPH_SUCCESS; -} - -std::string OpDescImpl::GetValidInputNameByIndex(const uint32_t index) const { - std::map valid_input_name_idx {}; - uint32_t j = 0U; - for (size_t i = 0U; i < GetAllInputsSize(); i++) { - if (MutableInputDesc(static_cast(i)) != nullptr) { - const auto valid_name = GetInputNameByIndex(static_cast(i)); - GE_CHK_BOOL_RET_STATUS_NOLOG(!valid_name.empty(), ""); - (void)valid_input_name_idx.insert({valid_name, j}); - j++; - } - } - auto it = valid_input_name_idx.begin(); - for (; it != valid_input_name_idx.end(); ++it) { - if (it->second == index) { - break; - } - } - GE_CHK_BOOL_RET_STATUS_NOLOG(it != valid_input_name_idx.end(), ""); - return it->first; -} - -std::string OpDescImpl::GetOutputNameByIndex(const uint32_t index) const { - auto it = output_name_idx_.begin(); - for (; it != output_name_idx_.end(); ++it) { - if (it->second == index) { - break; - } - } - GE_CHK_BOOL_RET_STATUS_NOLOG(it != output_name_idx_.end(), ""); - return it->first; -} - -int32_t OpDescImpl::GetOutputIndexByName(const std::string &name) const { - const auto it_find = output_name_idx_.find(name); - GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != output_name_idx_.end(), -1); - return static_cast(it_find->second); -} - -graphStatus OpDescImpl::GetDynamicOutputIndexesByName(const std::string &name, std::vector &indexes) const { - uint32_t previous_value = UINT32_MAX; - - for (size_t i = 0U; ; i++) { - const auto it_find = output_name_idx_.find(name + std::to_string(i)); - if (it_find != output_name_idx_.end()) { - if (previous_value != UINT32_MAX && it_find->second != previous_value + 1U) { - GELOGE(GRAPH_FAILED, "The dynamic output index is not continuous."); - return GRAPH_FAILED; - } - indexes.emplace_back(it_find->second); - previous_value = it_find->second; - } else { - break; - } - } - return GRAPH_SUCCESS; -} - -ProtoAttrMap &OpDescImpl::MutableAttrMap() { - return attrs_; -} - -ConstProtoAttrMap &OpDescImpl::GetAttrMap() const { - return attrs_; -} - -void OpDescImpl::SetId(const int64_t id) { - meta_data_.id_ = id; - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), - this->GetName(), "id", "", "", id); -} - -int64_t OpDescImpl::GetId() const { - return meta_data_.id_; -} - -void OpDescImpl::SetStreamId(const int64_t stream_id) { - meta_data_.stream_id_ = stream_id; - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), - this->GetName(), "stream_id", "", "", stream_id); -} - -int64_t OpDescImpl::GetStreamId() const { - return meta_data_.stream_id_; -} - -void OpDescImpl::SetInputName(const vector &input_name) { - meta_data_.input_names_ = input_name; - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), - this->GetName(), "input_name", "", "", ""); -} - -vector OpDescImpl::GetInputName() const { - return meta_data_.input_names_; -} - -void OpDescImpl::SetSrcName(const vector &src_name) { - meta_data_.src_names_ = src_name; - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), - this->GetName(), "src_name", "", "", ""); -} - -vector OpDescImpl::GetSrcName() const { - return meta_data_.src_names_; -} - -void OpDescImpl::SetSrcIndex(const vector &src_index) { - meta_data_.src_indexes_ = src_index; - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), - this->GetName(), "src_index", "", "", ""); -} - -vector OpDescImpl::GetSrcIndex() const { - return meta_data_.src_indexes_; -} - -void OpDescImpl::SetInputOffset(const vector &input) { - meta_data_.input_offsets_ = input; - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), - this->GetName(), "input_offset", "", "", ""); -} - -vector OpDescImpl::GetInputOffset() const { - return meta_data_.input_offsets_; -} - -void OpDescImpl::SetOutputOffset(const vector &output) { - meta_data_.output_offsets_ = output; - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), - this->GetName(), "out_offset", "", "", ""); -} - -vector OpDescImpl::GetOutputOffset() const { - return meta_data_.output_offsets_; -} - -void OpDescImpl::SetDstName(const vector &dst_name) { - meta_data_.dst_names_ = dst_name; - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), - this->GetName(), "dst_name", "", "", ""); -} - -vector OpDescImpl::GetDstName() const { - return meta_data_.dst_names_; -} - -void OpDescImpl::SetDstIndex(const vector &dst_index) { - meta_data_.dst_indexes_ = dst_index; - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), - this->GetName(), "dst_index", "", "", ""); -} - -void OpDescImpl::SetWorkspace(const vector &workspace) { - meta_data_.workspaces.assign(workspace.cbegin(), workspace.cend()); -} - -vector OpDescImpl::GetWorkspace() const { - vector res(meta_data_.workspaces.size()); - for (size_t i = 0UL; i < meta_data_.workspaces.size(); ++i) { - res[i] = meta_data_.workspaces[i]; - } - return res; -} - -void OpDescImpl::SetWorkspaceBytes(const vector &workspace_bytes) { - meta_data_.workspace_bytes_list_.assign(workspace_bytes.cbegin(), workspace_bytes.cend()); -} - -vector OpDescImpl::GetWorkspaceBytes() const { - vector res(meta_data_.workspace_bytes_list_.size()); - for (size_t i = 0UL; i < meta_data_.workspace_bytes_list_.size(); ++i) { - res[i] = meta_data_.workspace_bytes_list_[i]; - } - return res; -} - -void OpDescImpl::SetIsInputConst(const vector &is_input_const) { - meta_data_.is_input_consts_ = is_input_const; - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), - this->GetName(), "is_input_const", "", "", ""); -} - -vector OpDescImpl::GetIsInputConst() const { - return meta_data_.is_input_consts_; -} - -std::string OpDescImpl::GetSubgraphInstanceName(const size_t index) const { - if (index >= subgraph_instance_names_.size()) { - return ""; - } - return subgraph_instance_names_.at(index); -} - -const std::vector &OpDescImpl::GetSubgraphInstanceNames() const { - return subgraph_instance_names_; -} - -void OpDescImpl::RemoveSubgraphInstanceName(const std::string &name) { - for (auto iter = subgraph_instance_names_.begin(); iter != subgraph_instance_names_.end(); ++iter) { - if ((*iter) == name) { - *iter = ""; - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "delete", TraceManager::GetOutGraphName(), - this->GetName(), "subgraph_instance_name", "", "", name); - return; - } - } -} - -graphStatus OpDescImpl::AddSubgraphName(const std::string &name) { - GELOGI("Add subgraph name is %s", name.c_str()); - const std::map::const_iterator iter = subgraph_names_to_index_.find(name); - if (iter != subgraph_names_to_index_.cend()) { - GELOGW("[Add][Subgraph] Subgraph name %s exists, index %u", name.c_str(), iter->second); - return GRAPH_FAILED; - } - const auto size = subgraph_names_to_index_.size(); - subgraph_names_to_index_[name] = size; - subgraph_instance_names_.resize(size + 1U); - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "add", TraceManager::GetOutGraphName(), - this->GetName(), "subgraph_name", "", "", name); - return GRAPH_SUCCESS; -} - -const std::map &OpDescImpl::GetSubgraphNameIndexes() const { - return subgraph_names_to_index_; -} - -graphStatus OpDescImpl::SetSubgraphInstanceName(const size_t index, const std::string &name) { - GELOGI("Add sub graph instance name is %s, index is %zu", name.c_str(), index); - if (index >= subgraph_instance_names_.size()) { - REPORT_INNER_ERR_MSG("E18888", "Index %zu exceeds the max instance count %zu", index, - subgraph_instance_names_.size()); - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] Index %zu exceeds the max instance count %zu", index, - subgraph_instance_names_.size()); - return GRAPH_PARAM_INVALID; - } - subgraph_instance_names_[index] = name; - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "add", TraceManager::GetOutGraphName(), this->GetName(), - "subgraph_instance_index:" << index, "", "", "name:" << name); - return GRAPH_SUCCESS; -} - -graphStatus OpDescImpl::GetSubgraphNameByInstanceName(const std::string &instance_name, - std::string &subgraph_name) const { - for (size_t idx = 0U; idx < subgraph_instance_names_.size(); ++idx) { - if (subgraph_instance_names_[idx] != instance_name) { // find subgraph index. - continue; - } - - for (const auto &name_to_index : subgraph_names_to_index_) { - if (name_to_index.second != idx) { // find subgraph name. - continue; - } - - subgraph_name = name_to_index.first; - return GRAPH_SUCCESS; - } - } - - return GRAPH_PARAM_INVALID; -} - -IRMetaData &OpDescImpl::MutableIRMeta() { - return meta_data_.ir_meta_; -} -const IRMetaData &OpDescImpl::GetIRMeta() const { - return meta_data_.ir_meta_; -} - -bool OpDescImpl::IsSupportSymbolicInferDataType() const { - return GetIRMeta().GetIRDataTypeSymbolStore().IsSupportSymbolicInferDtype(); -} -graphStatus OpDescImpl::SymbolicInferDataType(const OpDescPtr &op_desc) const { - return GetIRMeta().GetIRDataTypeSymbolStore().InferDtype(op_desc); -} - -size_t OpDescImpl::GetIrInputsSize() const { - return meta_data_.ir_meta_.GetIrInputs().size(); -} - -std::map OpDescImpl::GetAllOutputIndexToName() { - std::map idx2name; - for (auto &item : output_name_idx_) { - GELOGD("[%s:%s] get output name %s with idx %zu", GetNamePtr(), GetTypePtr(), item.first.c_str(), item.second); - idx2name.emplace(item.second, item.first); - } - GE_ASSERT_EQ(idx2name.size(), output_name_idx_.size()); - return idx2name; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc() - : enable_shared_from_this(), AttrHolder(), impl_(ComGraphMakeSharedAndThrow()) { -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::~OpDesc() {} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc(const std::string &name, const std::string &type) - : enable_shared_from_this(), AttrHolder(), impl_(ComGraphMakeSharedAndThrow(name, type)) {} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc(const OpDesc &op_desc) - : enable_shared_from_this(), AttrHolder(op_desc), - impl_(ComGraphMakeSharedAndThrow(*(op_desc.impl_))) {} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc(OpDesc &&op_desc) - : enable_shared_from_this(), AttrHolder(std::move(op_desc)), - impl_(ComGraphMakeSharedAndThrow(std::move(*(op_desc.impl_)))) {} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc(const ge::proto::OpDef &op_def) - : enable_shared_from_this(), AttrHolder(), impl_(ComGraphMakeSharedAndThrow(op_def)) {} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetName() const { - return impl_->GetName(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const char *OpDesc::GetNamePtr() const { - return impl_->GetNamePtr(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetNamePtr(const char_t *name) { - std::string name_str = name == nullptr ? "" : std::string(name); - return impl_->SetName(name_str); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetName(const std::string &name) { - return impl_->SetName(name); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetType() const { - return impl_->GetType(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const char *OpDesc::GetTypePtr() const { - return impl_->GetTypePtr(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetType(const std::string &type) { - return impl_->SetType(type); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetIrRelated(const OpDescPtr &op_desc) { - impl_->SetIrRelated(op_desc != nullptr ? op_desc->impl_.get() : nullptr); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddInputDesc(const ge::GeTensorDesc &input_desc) { - return impl_->AddInputDesc(input_desc); -} - -graphStatus OpDesc::AddInputDesc(const uint32_t index, const ge::GeTensorDesc &input_desc) { - return impl_->AddInputDesc(index, input_desc); -} - -graphStatus OpDesc::AddInputDesc(const std::string &name, const ge::GeTensorDesc &input_desc) { - return impl_->AddInputDesc(name, input_desc); -} - -graphStatus OpDesc::AddInputDescMiddle(const std::string &name, const uint32_t num, const size_t index) { - return impl_->AddInputDescMiddle(name, num, index); -} - -graphStatus OpDesc::AddOutputDescMiddle(const std::string &name, const uint32_t num, const size_t index) { - return impl_->AddOutputDescMiddle(name, num, index); -} - -graphStatus OpDesc::AddOutputDescForward(const std::string &name, const uint32_t num) { - return impl_->AddOutputDescForward(name, num); -} - -graphStatus OpDesc::AddOptionalInputDesc(const std::string &name, const ge::GeTensorDesc &input_desc) { - return impl_->AddOptionalInputDesc(name, input_desc); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -OpDesc::UpdateInputDesc(const uint32_t index, const ge::GeTensorDesc &tensor_desc) { - return impl_->UpdateInputDesc(index, tensor_desc); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescMembersAreEqual(const OpDesc &r_op_desc) const { - return impl_->OpDescMembersAreEqual(*(r_op_desc.impl_)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescAttrsAreEqual(const OpDesc &r_op_desc) const { - return impl_->OpDescAttrsAreEqual(*(r_op_desc.impl_)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescGenTensorDescsAreEqual(const OpDesc &r_op_desc) - const { - return impl_->OpDescGenTensorDescsAreEqual(*(r_op_desc.impl_)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::operator==(const OpDesc &r_op_desc) const { - return (OpDescAttrsAreEqual(r_op_desc) && OpDescMembersAreEqual(r_op_desc) && - OpDescGenTensorDescsAreEqual(r_op_desc)); -} - -graphStatus OpDesc::UpdateInputDesc(const std::string &name, const ge::GeTensorDesc &tensor_desc) { - return impl_->UpdateInputDesc(name, tensor_desc); -} - -bool OpDesc::InputIsSet(const std::string &name) const { - return impl_->InputIsSet(name); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const GeTensorDesc &OpDesc::GetInputDesc(const uint32_t index) const { - return impl_->GetInputDesc(index); -} - -const GeTensorDesc &OpDesc::GetInputDesc(const std::string &name) const { - return impl_->GetInputDesc(name); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableInputDesc(const uint32_t index) const { - return impl_->MutableInputDesc(index); -} - -GeTensorDescPtr OpDesc::MutableInputDesc(const std::string &name) const { - return impl_->MutableInputDesc(name); -} - -bool OpDesc::IsOptionalInput(const uint32_t index) const { return IsOptionalInput(GetInputNameByIndex(index)); } - -GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllInputNames() const { - return impl_->GetAllInputNames(shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpKernelLibName(const std::string &name) { - impl_->SetOpKernelLibName(name); - const auto ret = AttrUtils::SetStr(this, ATTR_NAME_OP_KERNEL_LIB_NAME, name); - if (!ret) { - REPORT_INNER_ERR_MSG("E18888", "set %s to op failed.", ATTR_NAME_OP_KERNEL_LIB_NAME.c_str()); - GELOGE(GRAPH_FAILED, "[Set][Str] %s to op failed.", ATTR_NAME_OP_KERNEL_LIB_NAME.c_str()); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetOpKernelLibName() const { - std::string op_kernel_lib_name = impl_->GetOpKernelLibName(); - if (op_kernel_lib_name.empty()) { - (void)AttrUtils::GetStr(this, ATTR_NAME_OP_KERNEL_LIB_NAME, - op_kernel_lib_name); - } - return op_kernel_lib_name; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpEngineName(const std::string &name) { - impl_->SetOpEngineName(name); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetOpEngineName() const { - return impl_->GetOpEngineName(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OppImplVersion OpDesc::GetOppImplVersion() const { - int64_t opp_impl_version; - if (!ge::AttrUtils::GetInt(this, ge::ATTR_NAME_BINARY_SOURCE, opp_impl_version)) { - return OppImplVersion::kOpp; - } - return static_cast(opp_impl_version); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllInputsDesc() const { - return impl_->GetAllInputsDesc(shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllInputsDescPtr() const { - return impl_->GetAllInputsDescPtr(shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetInputsSize() const { - // Just return valid inputs size.InValid desc is set in default OPTION_INPUT register. - return impl_->GetInputsSize(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetAllInputsSize() const { - return impl_->GetAllInputsSize(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddOutputDesc(const ge::GeTensorDesc &output_desc) { - return impl_->AddOutputDesc(output_desc); -} - -graphStatus OpDesc::AddOutputDesc(const std::string &name, const ge::GeTensorDesc &output_desc) { - return impl_->AddOutputDesc(name, output_desc); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -OpDesc::UpdateOutputDesc(const uint32_t index, const ge::GeTensorDesc &tensor_desc) { - return impl_->UpdateOutputDesc(index, tensor_desc); -} - -graphStatus OpDesc::UpdateOutputDesc(const std::string &name, const ge::GeTensorDesc &tensor_desc) { - return impl_->UpdateOutputDesc(name, tensor_desc); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const GeTensorDesc &OpDesc::GetOutputDesc(const uint32_t index) const { - return impl_->GetOutputDesc(index); -} - -const GeTensorDesc &OpDesc::GetOutputDesc(const std::string &name) const { - return impl_->GetOutputDesc(name); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOutputDesc(const uint32_t index) const { - return impl_->MutableOutputDesc(index); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -GeTensorDescPtr OpDesc::MutableOutputDesc(const std::string &name) const { - return impl_->MutableOutputDesc(name); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t OpDesc::GetAllOutputsDescSize() const { - return impl_->GetAllOutputsDescSize(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllOutputsDesc() const { - return impl_->GetAllOutputsDesc(shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllOutputsDescPtr() const { - return impl_->GetAllOutputsDescPtr(shared_from_this()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetOutputsSize() const { - return impl_->GetOutputsSize(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -ConstGeTensorDescPtr OpDesc::GetOutputDescPtr(const uint32_t index) const { - return impl_->GetOutputDescPtr(index); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -ConstGeTensorDescPtr OpDesc::GetInputDescPtr(const uint32_t index) const { - return impl_->GetInputDescPtr(index); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr -OpDesc::GetInputDescPtrDfault(const uint32_t index) const { - return impl_->GetInputDescPtrDfault(index); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -ConstGeTensorDescPtr OpDesc::GetInputDescPtr(const std::string &name) const { - return impl_->GetInputDescPtr(name); -} - -graphStatus OpDesc::AddRegisterInputName(const std::string &name) { - return impl_->MutableIRMeta().AddRegisterInputName(name); -} - -vector OpDesc::GetRegisterInputName() const { - return impl_->MutableIRMeta().GetRegisterInputName(); -} - -graphStatus OpDesc::AddDynamicInputDesc(const std::string &name, const uint32_t num, const bool is_push_back) { - return impl_->AddDynamicInputDesc(name, num, is_push_back); -} - -graphStatus OpDesc::AddDynamicInputDescByIndex(const std::string &name, const uint32_t num, const size_t index) { - if (AddInputDescMiddle(name, num, index) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -graphStatus OpDesc::AddRegisterOutputName(const std::string &name) { - return impl_->MutableIRMeta().AddRegisterOutputName(name); -} - -vector OpDesc::GetRegisterOutputName() const { - return impl_->MutableIRMeta().GetRegisterOutputName(); -} - -graphStatus OpDesc::AddDynamicOutputDesc(const std::string &name, const uint32_t num, const bool is_push_back) { - if (is_push_back) { - for (uint32_t i = 0U; i < num; i++) { - if (AddOutputDesc(name + std::to_string(i), GeTensorDesc()) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - } - } else { - if (AddOutputDescForward(name, num) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - } - - if (AddRegisterOutputName(name) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -bool OpDesc::IsOptionalInput(const std::string &name) const { - return impl_->GetIRMeta().IsOptionalInput(name); -} - -std::map OpDesc::GetAllInputName() const { - return impl_->GetAllInputName(); -} - -std::map OpDesc::GetAllOutputName() { - return impl_->GetAllOutputName(); -} - -std::map& OpDesc::MutableAllInputName() { - return impl_->MutableAllInputName(); -} - -std::map& OpDesc::MutableAllOutputName() { - return impl_->MutableAllOutputName(); -} - -bool OpDesc::UpdateInputName(const std::map input_name_idx) { - return impl_->UpdateInputName(input_name_idx); -} - -bool OpDesc::UpdateOutputName(const std::map output_name_idx) { - return impl_->UpdateOutputName(output_name_idx); -} - -std::function OpDesc::GetInferFunc() const { - return impl_->GetInferFunc(); -} - -void *OpDesc::GetTilingFuncInfo() const { - return impl_->GetTilingFuncInfo(); -} - -void OpDesc::SetTilingFuncInfo(void *tiling_func_info) { - impl_->SetTilingFuncInfo(tiling_func_info); -} - -void *OpDesc::GetAtomicTilingFuncInfo() const { - return impl_->GetAtomicTilingFuncInfo(); -} - -void OpDesc::SetAtomicTilingFuncInfo(void *atomic_tiling_func_info) { - impl_->SetAtomicTilingFuncInfo(atomic_tiling_func_info); -} - -std::function OpDesc::GetVerifyFunc() const { - return impl_->GetVerifyFunc(); -} - -std::function OpDesc::GetInferFormatFunc() const { - return impl_->GetInferFormatFunc(); -} - -std::function OpDesc::GetInferDataSliceFunc() const { - return impl_->GetInferDataSliceFunc(); -} - -std::function OpDesc::GetInferValueRangeFunc() const { - return impl_->GetInferValueRangeFunc(); -} - -void OpDesc::AddInferFunc(const std::function &func) { - impl_->AddInferFunc(func); -} - -void OpDesc::AddInferFormatFunc(const std::function &func) { - impl_->AddInferFormatFunc(func); -} - -void OpDesc::AddVerifierFunc(const std::function &func) { - impl_->AddVerifierFunc(func); -} - -void OpDesc::AddInferValueRangeFunc(const std::function &func) { - impl_->AddInferValueRangeFunc(func); -} - -void OpDesc::AddInferDataSliceFunc(const std::function &func) { - impl_->AddInferDataSliceFunc(func); -} - -graphStatus OpDesc::DefaultInferFormat() { - return impl_->DefaultInferFormat(shared_from_this()); -} - -graphStatus OpDesc::CommonVerify() const { - for (const std::string &iname : GetAllInputNames()) { - // Checking shape of all inputs - GE_CHK_BOOL_RET_STATUS_NOLOG(GetInputDescPtr(iname) != nullptr, GRAPH_FAILED); - const std::vector ishape = GetInputDescPtr(iname)->GetShape().GetDims(); - if (ishape == DUMMY_SHAPE) { - continue; - } - for (const int64_t dim : ishape) { - if (dim < -2) { // -2 is all shape - REPORT_PREDEFINED_ERR_MSG("E19014", std::vector({"opname", "value", "reason"}), - std::vector({GetName().c_str(), ("input " + iname + " shape").c_str(), - "contains negative or zero dimension"})); - GELOGE(FAILED, "Op[%s]'s input %s shape contains negative or zero dimension", GetName().c_str(), iname.c_str()); - return GRAPH_FAILED; - } - } - } - // Check all attributes defined - const auto &all_attributes = GetAllAttrs(); - for (const auto &name : GetAllAttrNames()) { - if (all_attributes.find(name) == all_attributes.end()) { - REPORT_PREDEFINED_ERR_MSG( - "E19014", std::vector({"opname", "value", "reason"}), - std::vector({GetName().c_str(), ("attribute " + name).c_str(), "is empty"})); - GELOGE(FAILED, "operator attribute %s is empty.", name.c_str()); - return GRAPH_FAILED; - } - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetInputNameByIndex(const uint32_t index) const { - return impl_->GetInputNameByIndex(index); -} - -int32_t OpDesc::GetInputIndexByName(const std::string &name) const { - return impl_->GetInputIndexByName(name); -} - -graphStatus OpDesc::GetDynamicInputIndexesByName(const std::string &name, std::vector &indexes) const { - return impl_->GetDynamicInputIndexesByName(name, indexes); -} - -std::string OpDesc::GetValidInputNameByIndex(const uint32_t index) const { - return impl_->GetValidInputNameByIndex(index); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetOutputNameByIndex(const uint32_t index) const { - return impl_->GetOutputNameByIndex(index); -} - -int32_t OpDesc::GetOutputIndexByName(const std::string &name) const { - return impl_->GetOutputIndexByName(name); -} - -graphStatus OpDesc::GetDynamicOutputIndexesByName(const std::string &name, std::vector &indexes) const { - return impl_->GetDynamicOutputIndexesByName(name, indexes); -} - -ProtoAttrMap &OpDesc::MutableAttrMap() { - return impl_->MutableAttrMap(); -} - -ConstProtoAttrMap &OpDesc::GetAttrMap() const { - return impl_->GetAttrMap(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetId(const int64_t id) { - impl_->SetId(id); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY int64_t OpDesc::GetId() const { - return impl_->GetId(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetStreamId(const int64_t stream_id) { - impl_->SetStreamId(stream_id); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY int64_t OpDesc::GetStreamId() const { - return impl_->GetStreamId(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetAttachedStreamId(const int64_t stream_id) { - const auto ret = AttrUtils::SetInt(this, ATTR_NAME_ATTACHED_STREAM_ID, stream_id); - if (!ret) { - GELOGW("[Set][Attr] %s to op failed.", ATTR_NAME_ATTACHED_STREAM_ID.c_str()); - } - - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), this->GetName(), - ATTR_NAME_ATTACHED_STREAM_ID, "", "", stream_id); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY int64_t OpDesc::GetAttachedStreamId() const { - int64_t attached_stream_id = -1; // default invalid value - (void) AttrUtils::GetInt(this, ATTR_NAME_ATTACHED_STREAM_ID, attached_stream_id); - return attached_stream_id; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetAttachedStreamIds( - const std::vector &stream_ids) { - std::vector attached_stream_infos; - if (!AttrUtils::GetListNamedAttrs(this, ATTR_NAME_ATTACHED_STREAM_INFO_LIST, attached_stream_infos)) { - // 兼容老的方案,后续全部切换到ATTR_NAME_ATTACHED_STREAM_INFO_LIST属性设置上,不再提供Set/GetAttachedStreamId接口 - GELOGW("[Get][Attr] %s to op failed.", ATTR_NAME_ATTACHED_STREAM_INFO_LIST.c_str()); - if (stream_ids.size() == 1U) { - (void) SetAttachedStreamId(stream_ids[0U]); - } else { - GELOGW("SetAttachedStreamId only support stream size 1 but got %zu", stream_ids.size()); - } - } - if (stream_ids.size() != attached_stream_infos.size()) { - GELOGW("stream_ids size %zu is not equal to attr %s size %zu", stream_ids.size(), - ATTR_NAME_ATTACHED_STREAM_INFO_LIST.c_str(), attached_stream_infos.size()); - return; - } - for (size_t i = 0U; i < stream_ids.size(); ++i) { - if (!ge::AttrUtils::SetInt(attached_stream_infos[i], ATTR_NAME_ATTACHED_RESOURCE_ID, stream_ids[i])) { - GELOGW("[Set][Attr] %s failed, index = %zu.", ATTR_NAME_ATTACHED_STREAM_ID.c_str(), i); - } - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), this->GetName(), - ATTR_NAME_ATTACHED_RESOURCE_ID, "", "", stream_ids[i]); - } - if (!AttrUtils::SetListNamedAttrs(this, ATTR_NAME_ATTACHED_STREAM_INFO_LIST, attached_stream_infos)) { - GELOGW("[Set][Attr] %s failed, op %s", ATTR_NAME_ATTACHED_STREAM_INFO_LIST.c_str(), GetName().c_str()); - return; - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector OpDesc::GetAttachedStreamIds() const { - std::vector attached_stream_infos; - if (!AttrUtils::GetListNamedAttrs(this, ATTR_NAME_ATTACHED_STREAM_INFO_LIST, attached_stream_infos)) { - // 兼容老的方案 - GELOGD("[Get][Attr] %s to op %s failed.", ATTR_NAME_ATTACHED_STREAM_INFO_LIST.c_str(), GetNamePtr()); - if (GetAttachedStreamId() != -1) { - return {GetAttachedStreamId()}; - } - return {}; - } - std::vector stream_ids; - for (const auto &attached_stream_info : attached_stream_infos) { - int64_t stream_id = -1; // 默认值-1 - if (!ge::AttrUtils::GetInt(attached_stream_info, ATTR_NAME_ATTACHED_RESOURCE_ID, stream_id)) { - GELOGW("[Get][Attr] %s from op %s's attr map %s failed", - ATTR_NAME_ATTACHED_RESOURCE_ID.c_str(), - GetNamePtr(), - ATTR_NAME_ATTACHED_STREAM_INFO_LIST.c_str()); - } - stream_ids.emplace_back(stream_id); - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), this->GetName(), - ATTR_NAME_ATTACHED_RESOURCE_ID, "", "", stream_ids.back()); - } - return stream_ids; -} -// 需要兼容已有的GetAttachedStreamId接口,对于新的接口满足: -// 存在某个attached_streams_id不等于-1即可 -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::HasValidAttachedStreamId() const { - const auto &attached_streams_ids = GetAttachedStreamIds(); - return !attached_streams_ids.empty() && - std::any_of(attached_streams_ids.begin(), attached_streams_ids.end(), - [](const int64_t attached_stream_id) { return attached_stream_id != -1; }); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetInputName(const std::vector &input_name) { - impl_->SetInputName(input_name); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector OpDesc::GetInputName() const { - return impl_->GetInputName(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetSrcName(const std::vector &src_name) { - impl_->SetSrcName(src_name); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector OpDesc::GetSrcName() const { - return impl_->GetSrcName(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetSrcIndex(const std::vector &src_index) { - impl_->SetSrcIndex(src_index); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector OpDesc::GetSrcIndex() const { - return impl_->GetSrcIndex(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetInputOffset(const std::vector &input) { - impl_->SetInputOffset(input); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector OpDesc::GetInputOffset() const { - return impl_->GetInputOffset(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOutputOffset(const std::vector &output) { - impl_->SetOutputOffset(output); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector OpDesc::GetOutputOffset() const { - return impl_->GetOutputOffset(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetDstName(const std::vector &dst_name) { - impl_->SetDstName(dst_name); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector OpDesc::GetDstName() const { - return impl_->GetDstName(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -void OpDesc::SetOpInferDepends(const std::vector &depend_names) { - const auto ret = AttrUtils::SetListStr(this, optiling::ATTR_NAME_OP_INFER_DEPENDS, depend_names); - if (!ret) { - GELOGE(GRAPH_FAILED, "[Set][Attr] op_infer_depends fail."); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector OpDesc::GetOpInferDepends() const { - std::vector depend_names; - (void)AttrUtils::GetListStr(this, optiling::ATTR_NAME_OP_INFER_DEPENDS, depend_names); - return depend_names; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetDstIndex(const std::vector &dst_index) { - impl_->SetDstIndex(dst_index); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetWorkspace(const std::vector &workspace) { - impl_->SetWorkspace(workspace); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector OpDesc::GetWorkspace() const { - return impl_->GetWorkspace(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -void OpDesc::SetWorkspaceBytes(const std::vector &workspace_bytes) { - impl_->SetWorkspaceBytes(workspace_bytes); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector OpDesc::GetWorkspaceBytes() const { - return impl_->GetWorkspaceBytes(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetIsInputConst(const std::vector &is_input_const) { - impl_->SetIsInputConst(is_input_const); - // If comes from ME,which is_input_const exist as attrs, outside no need to check GE_TRAIN flag - const auto ret = AttrUtils::SetListBool(this, ATTR_NAME_IS_INPUT_CONST, is_input_const); - if (!ret) { - GELOGE(GRAPH_FAILED, "[Set][Attr] is_input_const fail."); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector OpDesc::GetIsInputConst() const { - return impl_->GetIsInputConst(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetSubgraphInstanceName(const uint32_t index) const { - return impl_->GetSubgraphInstanceName(static_cast(index)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::vector &OpDesc::GetSubgraphInstanceNames() - const { - return impl_->GetSubgraphInstanceNames(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::RemoveSubgraphInstanceName(const std::string &name) { - impl_->RemoveSubgraphInstanceName(name); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddSubgraphName(const std::string &name) { - return impl_->AddSubgraphName(name); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::map &OpDesc::GetSubgraphNameIndexes() - const { - return impl_->GetSubgraphNameIndexes(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -graphStatus OpDesc::SetSubgraphInstanceName(const uint32_t index, const std::string &name) { - return impl_->SetSubgraphInstanceName(static_cast(index), name); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -void OpDesc::RegisterSubgraphIrName(const std::string &name, const SubgraphType type) { - impl_->MutableIRMeta().RegisterSubgraphIrName(name, type); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -const std::map &OpDesc::GetSubgraphIrNames() const { - return impl_->GetIRMeta().GetSubgraphIrNames(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -const std::vector> &OpDesc::GetOrderedSubgraphIrNames() const { - return impl_->GetIRMeta().GetOrderedSubgraphIrNames(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -SubgraphType OpDesc::GetSubgraphTypeByIrName(const std::string &name) const { - return impl_->GetIRMeta().GetSubgraphTypeByIrName(name); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -graphStatus OpDesc::GetSubgraphNameByInstanceName(const std::string &instance_name, std::string &subgraph_name) const { - return impl_->GetSubgraphNameByInstanceName(instance_name, subgraph_name); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::AppendIrAttrName(const std::string &name) { - return impl_->MutableIRMeta().AppendIrAttrName(name); -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::vector &OpDesc::GetIrAttrNames() const { - return impl_->GetIRMeta().GetIrAttrNames(); -} -void OpDesc::AppendIrInput(std::string name, IrInputType input_type) { - impl_->MutableIRMeta().AppendIrInput(std::move(name), input_type); -} -const std::vector> &OpDesc::GetIrInputs() const { - return impl_->GetIRMeta().GetIrInputs(); -} - -void OpDesc::SetInputDtypeSymbol(const std::string &ir_input, IrInputType type, const std::string &sym_id) { - impl_->MutableIRMeta().MutableIRDataTypeSymbolStore().SetInputSymbol(ir_input, type, sym_id); -} -void OpDesc::SetOutputDtypeSymbol(const std::string &ir_output, IrOutputType type, const std::string &sym_id) { - impl_->MutableIRMeta().MutableIRDataTypeSymbolStore().SetOutputSymbol(ir_output, type, sym_id); -} -void OpDesc::DeclareDtypeSymbol(const std::string &sym_id, const TensorType &type) { - impl_->MutableIRMeta().MutableIRDataTypeSymbolStore().DeclareSymbol(sym_id, type); -} -void OpDesc::DeclareDtypeSymbol(const std::string &sym_id, const ListTensorType &type) { - impl_->MutableIRMeta().MutableIRDataTypeSymbolStore().DeclareSymbol(sym_id, type); -} -void OpDesc::DeclareDtypeSymbol(const std::string &sym_id, const Promote &type) { - impl_->MutableIRMeta().MutableIRDataTypeSymbolStore().DeclareSymbol(sym_id, type); -} -void OpDesc::ShareDtypeSymbolsFrom(const OpDesc &src) { - impl_->MutableIRMeta().MutableIRDataTypeSymbolStore() = src.impl_->GetIRMeta().GetIRDataTypeSymbolStore(); -} -bool OpDesc::IsSupportSymbolicInferDataType() const { - return impl_->IsSupportSymbolicInferDataType(); -} -graphStatus OpDesc::SymbolicInferDataType() { - return impl_->SymbolicInferDataType(shared_from_this()); -} - -void OpDesc::AppendIrOutput(std::string name, IrOutputType output_type) { - impl_->MutableIRMeta().AppendIrOutput(name, output_type); -} -const std::vector> &OpDesc::GetIrOutputs() const { - return impl_->GetIRMeta().GetIrOutputs(); -} -size_t OpDesc::GetIrInputsSize() const { - return impl_->GetIrInputsSize(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder& OpDescBuilder::AddInput(const std::string &name) { - inputs_.emplace_back(std::make_pair(name, GeTensorDesc())); - return *this; -} -graphStatus OpDesc::GetPromoteIrInputList(std::vector> &promote_index_list) { - return impl_->MutableIRMeta().MutableIRDataTypeSymbolStore().GetPromoteIrInputList(promote_index_list); -} - -std::map OpDesc::GetAllOutputIndexToName() { - return impl_->GetAllOutputIndexToName(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -OpDescBuilder& OpDescBuilder::AddInput(const std::string &name, const GeTensorDesc &tensor) { - inputs_.emplace_back(std::make_pair(name, tensor)); - return *this; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder& OpDescBuilder::AddDynamicInput(const std::string &name, - const uint32_t num) { - for (uint32_t i = 0U; i < num; i++) { - inputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc())); - } - return *this; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -OpDescBuilder& OpDescBuilder::AddDynamicInput(const std::string &name, const uint32_t num, const GeTensorDesc &tensor) { - for (uint32_t i = 0U; i < num; i++) { - inputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor)); - } - return *this; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder& OpDescBuilder::AddOutput(const std::string &name) { - outputs_.emplace_back(std::make_pair(name, GeTensorDesc())); - return *this; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -OpDescBuilder& OpDescBuilder::AddOutput(const std::string &name, const GeTensorDesc &tensor) { - outputs_.emplace_back(std::make_pair(name, tensor)); - return *this; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder& OpDescBuilder::AddDynamicOutput(const std::string &name, - const uint32_t num) { - for (uint32_t i = 0U; i < num; i++) { - outputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc())); - } - return *this; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -OpDescBuilder& OpDescBuilder::AddDynamicOutput(const std::string &name, const uint32_t num, - const GeTensorDesc &tensor) { - for (uint32_t i = 0U; i < num; i++) { - outputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor)); - } - return *this; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescBuilder::Build() { - const OpDescPtr op_desc = MakeShared(name_, type_); - if (op_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "create opdesc failed, name:%s, type:%s.", name_.c_str(), type_.c_str()); - GELOGE(GRAPH_FAILED, "[Create][OpDesc] failed, name:%s, type:%s.", name_.c_str(), type_.c_str()); - return nullptr; - } - for (auto &input : inputs_) { - if (op_desc->AddInputDesc(input.first, input.second) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "AddInputDesc failed, op:%s.", name_.c_str()); - GELOGE(GRAPH_FAILED, "[Add][InputDesc] failed, op:%s.", name_.c_str()); - return nullptr; - } - } - for (auto &output : outputs_) { - if (op_desc->AddOutputDesc(output.first, output.second) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "AddOutputDesc failed, op:%s", name_.c_str()); - GELOGE(GRAPH_FAILED, "[Add][OutputDesc] failed, op:%s.", name_.c_str()); - return nullptr; - } - } - return op_desc; -} -} // namespace ge diff --git a/graph/normal_graph/op_desc_impl.h b/graph/normal_graph/op_desc_impl.h deleted file mode 100644 index 94ca0aedb83dd31d7b98d31fc90e0b71c01661d8..0000000000000000000000000000000000000000 --- a/graph/normal_graph/op_desc_impl.h +++ /dev/null @@ -1,239 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_OP_DESC_IMPL_H_ -#define GRAPH_OP_DESC_IMPL_H_ - -#include -#include -#include -#include -#include "graph/types.h" -#include "graph/op_desc.h" -#include "graph/small_vector.h" -#include "graph/ascend_limits.h" -#include "graph/type/tensor_type_impl.h" -#include "graph/ir/ir_meta.h" - -namespace ge { -enum class DataTypeInferStrategy { - kInferFromAttr, - kInferFromInput, - kInferFromOutput, - kInvalidStrategy -}; - -class OpDescImpl { - public: - OpDescImpl(); - OpDescImpl(const std::string &name, const std::string &type); - OpDescImpl(const OpDescImpl &op_desc_impl); - OpDescImpl& operator=(const OpDescImpl &op_desc_impl) &; - explicit OpDescImpl(const ge::proto::OpDef &op_def); - - ~OpDescImpl() = default; - - const char *GetNamePtr() const; - std::string GetName() const; - void SetName(const std::string &name); - - const char *GetTypePtr() const; - std::string GetType() const; - void SetType(const std::string &type); - - void SetIrRelated(const OpDescImpl *r_op_desc); - - graphStatus AddInputDesc(const ge::GeTensorDesc &input_desc); - graphStatus AddInputDesc(const uint32_t index, const ge::GeTensorDesc &input_desc); - graphStatus AddInputDesc(const std::string &name, const ge::GeTensorDesc &input_desc); - graphStatus AddInputDescMiddle(const std::string &name, const uint32_t num, const size_t index); - graphStatus AddOutputDescMiddle(const std::string &name, const uint32_t num, const size_t index); - graphStatus AddInputDescForward(const std::string &name, const uint32_t num); - graphStatus AddOutputDescForward(const std::string &name, const uint32_t num); - graphStatus AddOptionalInputDesc(const std::string &name, const ge::GeTensorDesc &input_desc); - - graphStatus UpdateInputDesc(const uint32_t index, const ge::GeTensorDesc &tensor_Desc); - graphStatus UpdateInputDesc(const std::string &name, const ge::GeTensorDesc &tensor_Desc); - - bool OpDescMembersAreEqual(const OpDescImpl &r_op_desc) const; - bool OpDescAttrsAreEqual(const OpDescImpl &r_op_desc) const; - bool OpDescGenTensorDescsAreEqual(const OpDescImpl &r_op_desc) const; - - bool InputIsSet(const std::string &name) const; - - const GeTensorDesc &GetInputDesc(const uint32_t index) const; - const GeTensorDesc &GetInputDesc(const std::string &name) const; - GeTensorDescPtr MutableInputDesc(const uint32_t index) const; - GeTensorDescPtr MutableInputDesc(const std::string &name) const; - OpDesc::Vistor GetAllInputNames(const ConstOpDescPtr &op_desc) const; - - void SetOpKernelLibName(const std::string &name); - std::string GetOpKernelLibName() const; - void SetOpEngineName(const std::string &name); - std::string GetOpEngineName() const; - - OpDesc::Vistor GetAllInputsDesc(const ConstOpDescPtr &op_desc) const; - OpDesc::Vistor GetAllInputsDescPtr(const ConstOpDescPtr &op_desc) const; - - size_t GetInputsSize() const; - size_t GetIrInputsSize() const; - size_t GetAllInputsSize() const; - - graphStatus AddOutputDesc(const ge::GeTensorDesc &output_desc); - graphStatus AddOutputDesc(const std::string &name, const ge::GeTensorDesc &output_desc); - graphStatus UpdateOutputDesc(const uint32_t index, const ge::GeTensorDesc &tensor_Desc); - graphStatus UpdateOutputDesc(const std::string &name, const ge::GeTensorDesc &tensor_Desc); - const GeTensorDesc &GetOutputDesc(const uint32_t index) const; - const GeTensorDesc &GetOutputDesc(const std::string &name) const; - GeTensorDescPtr MutableOutputDesc(const uint32_t index) const; - GeTensorDescPtr MutableOutputDesc(const std::string &name) const; - - uint32_t GetAllOutputsDescSize() const; - OpDesc::Vistor GetAllOutputsDesc(const ConstOpDescPtr &op_desc) const; - OpDesc::Vistor GetAllOutputsDescPtr(const ConstOpDescPtr &op_desc) const; - ConstGeTensorDescPtr GetOutputDescPtr(const uint32_t index) const; - size_t GetOutputsSize() const; - - ConstGeTensorDescPtr GetInputDescPtr(const uint32_t index) const; - ConstGeTensorDescPtr GetInputDescPtrDfault(const uint32_t index) const; - ConstGeTensorDescPtr GetInputDescPtr(const std::string &name) const; - - graphStatus AddDynamicInputDesc(const std::string &name, const uint32_t num, const bool is_push_back); - graphStatus AddDynamicInputDescByIndex(const std::string &name, const uint32_t num, const size_t index); - - graphStatus AddDynamicOutputDesc(const std::string &name, const uint32_t num, const bool is_push_back); - bool IsOptionalInput(const uint32_t index) const; - std::map GetAllInputName() const; - std::map GetAllOutputName(); - std::map GetAllOutputIndexToName(); - std::map& MutableAllInputName(); - std::map& MutableAllOutputName(); - bool UpdateInputName(std::map input_name_idx); - bool UpdateOutputName(std::map output_name_idx); - - std::function GetInferFunc() const; - std::function GetVerifyFunc() const; - std::function GetInferFormatFunc() const; - std::function GetInferValueRangeFunc() const; - std::function GetInferDataSliceFunc() const; - - void AddInferFunc(const std::function &func); - void AddInferFormatFunc(const std::function &func); - void AddInferValueRangeFunc(const std::function &func); - void AddVerifierFunc(const std::function &func); - void AddInferDataSliceFunc(const std::function &func); - - bool IsSupportSymbolicInferDataType() const; - graphStatus SymbolicInferDataType(const OpDescPtr &op_desc) const; - graphStatus DefaultInferFormat(const ConstOpDescPtr &op_desc) const; - - std::string GetInputNameByIndex(const uint32_t index) const; - int32_t GetInputIndexByName(const std::string &name) const; - graphStatus GetDynamicInputIndexesByName(const std::string &name, std::vector &indexes) const; - std::string GetValidInputNameByIndex(const uint32_t index) const; - - std::string GetOutputNameByIndex(const uint32_t index) const; - int32_t GetOutputIndexByName(const std::string &name) const; - graphStatus GetDynamicOutputIndexesByName(const std::string &name, std::vector &indexes) const; - - ProtoAttrMap &MutableAttrMap(); - ConstProtoAttrMap &GetAttrMap() const; - - IRMetaData &MutableIRMeta(); - const IRMetaData &GetIRMeta() const; - - void SetId(const int64_t id); - int64_t GetId() const; - - void SetStreamId(const int64_t stream_id); - int64_t GetStreamId() const; - - void SetInputName(const std::vector &input_name); - std::vector GetInputName() const; - - void SetSrcName(const std::vector &src_name); - std::vector GetSrcName() const; - - void SetSrcIndex(const std::vector &src_index); - std::vector GetSrcIndex() const; - - void SetInputOffset(const std::vector &input); - std::vector GetInputOffset() const; - - void SetOutputOffset(const std::vector &output); - std::vector GetOutputOffset() const; - - void SetDstName(const std::vector &dst_name); - std::vector GetDstName() const; - - void SetDstIndex(const std::vector &dst_index); - - void SetWorkspace(const std::vector &workspace); - std::vector GetWorkspace() const; - - void SetWorkspaceBytes(const std::vector &workspace_bytes); - std::vector GetWorkspaceBytes() const; - - void SetIsInputConst(const std::vector &is_input_const); - std::vector GetIsInputConst() const; - - std::string GetSubgraphInstanceName(const size_t index) const; - const std::vector &GetSubgraphInstanceNames() const; - void RemoveSubgraphInstanceName(const std::string &name); - graphStatus AddSubgraphName(const std::string &name); - const std::map &GetSubgraphNameIndexes() const; - graphStatus SetSubgraphInstanceName(const size_t index, const std::string &name); - - graphStatus GetSubgraphNameByInstanceName(const std::string &instance_name, std::string &subgraph_name) const; - - void *GetTilingFuncInfo() const; - void SetTilingFuncInfo(void *tiling_func_info); - void *GetAtomicTilingFuncInfo() const; - void SetAtomicTilingFuncInfo(void *atomic_tiling_func_info); - - private: - void DeSerializeOpDefToMetaData(const proto::OpDef &op_def); - void SerializeMetaDataToOpDef(proto::OpDef * const op_def); - - friend class AttrUtils; - friend class OpDescUtils; - friend class ModelSerializeImp; - friend class OnnxUtils; - friend class GraphUtils; - friend class NodeUtils; - friend class FastNodeUtils; - friend class ExecuteGraphUtils; - std::vector subgraph_instance_names_; - - // subgraph names to index, for a `if` operator: - // then_branch: 0 - // else_branch: 1 - // or for a `case` node: - // branches0: 0 - // branches1: 1 - // branches2: 2 - std::map subgraph_names_to_index_; - std::vector inputs_desc_{}; - std::map input_name_idx_{}; - std::vector outputs_desc_{}; - std::map output_name_idx_{}; - std::function infer_func_ = nullptr; - std::function infer_format_func_ = nullptr; - std::function infer_value_range_func_ = nullptr; - std::function verifier_func_ = nullptr; - std::function infer_data_slice_func_ = nullptr; - std::string op_kernel_lib_name_; - std::string engine_name_; - OpMetadata meta_data_; - AttrStore attrs_; - void *tiling_func_info_ = nullptr; - void *atomic_tiling_func_info_ = nullptr; -}; -} // namespace ge -#endif // GRAPH_OP_DESC_IMPL_H_ diff --git a/graph/normal_graph/op_io.h b/graph/normal_graph/op_io.h deleted file mode 100644 index 5abf5d844f9c4853f033460377799abe4e800c58..0000000000000000000000000000000000000000 --- a/graph/normal_graph/op_io.h +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_OP_IO_H -#define METADEF_CXX_OP_IO_H -namespace ge { - -class OpIO { - public: - OpIO(const std::string &name, const int32_t index, const OperatorImplPtr &owner) - : name_(name), index_(index), owner_(owner) {} - - ~OpIO() = default; - - std::string GetName() const { return name_; } - - int32_t GetIndex() const { return index_; } - - OperatorImplPtr GetOwner() const { return owner_; } - - private: - std::string name_; - int32_t index_; - std::shared_ptr owner_; -}; -} -#endif // METADEF_CXX_OP_IO_H diff --git a/graph/normal_graph/operator.cc b/graph/normal_graph/operator.cc deleted file mode 100644 index 1656fa0b055a5957cef19fdc4e7490acdd9e1cbb..0000000000000000000000000000000000000000 --- a/graph/normal_graph/operator.cc +++ /dev/null @@ -1,3274 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "external/graph/operator.h" - -#include -#include -#include -#include -#include -#include "external/graph/operator_factory.h" -#include "debug/ge_log.h" -#include "debug/ge_op_types.h" -#include "debug/ge_util.h" -#include "external/graph/attr_value.h" -#include "graph/compute_graph.h" -#include "graph/ge_context.h" -#include "graph/runtime_inference_context.h" -#include "graph/utils/node_utils.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/tensor_adapter.h" -#include "graph/utils/tensor_utils.h" -#include "graph/utils/constant_utils.h" -#include "common/checker.h" -#include "graph/type/tensor_type_impl.h" -#include "graph/normal_graph/op_io.h" -#include "graph/normal_graph/operator_impl.h" -#include "graph/utils/graph_utils_ex.h" -#include "graph/utils/node_utils_ex.h" -#include "graph/utils/op_desc_utils_ex.h" -#include "graph/utils/op_type_utils.h" - -#define OP_ATTR_SET_IMP(ArgType, AttrUtilsFun) \ - Operator &ge::Operator::SetAttr(const std::string &name, ArgType attr_value) { \ - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { \ - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); \ - GELOGE(GRAPH_FAILED, "[Check][Param] operator impl is nullptr, name %s.", name.c_str()); \ - return *this; \ - } \ - if (!ge::AttrUtils::Set##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value)) { \ - GELOGW("[Set][Attr] Set attr name %s unsuccessful", name.c_str()); \ - } \ - return *this; \ - } \ - Operator &ge::Operator::SetAttr(const char_t *name, ArgType attr_value) { \ - if (name == nullptr) { \ - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid."); \ - GELOGE(GRAPH_FAILED, "[Check][Param] operator attr name is nullptr."); \ - return *this; \ - } \ - const std::string op_name = name; \ - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { \ - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid."); \ - GELOGE(GRAPH_FAILED, "[Check][Param] operator impl is nullptr, name %s.", op_name.c_str()); \ - return *this; \ - } \ - if (!ge::AttrUtils::Set##AttrUtilsFun(operator_impl_->GetOpDescImpl(), op_name, attr_value)) { \ - GELOGW("[Set][Attr] Set attr name %s unsuccessful", op_name.c_str()); \ - } \ - return *this; \ - } - -#define OP_ATTR_GET_IMP(ArgType, AttrUtilsFun) \ - graphStatus ge::Operator::GetAttr(const std::string &name, ArgType attr_value) const { \ - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { \ - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); \ - GELOGE(GRAPH_FAILED, "[Check][Param] operator impl is nullptr, name %s.", name.c_str()); \ - return GRAPH_FAILED; \ - } \ - if (!ge::AttrUtils::Get##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value)) { \ - GELOGW("[Get][Attr] Get attr name %s unsuccessful", name.c_str()); \ - return GRAPH_FAILED; \ - } \ - return GRAPH_SUCCESS; \ - } \ - graphStatus ge::Operator::GetAttr(const char_t *name, ArgType attr_value) const { \ - if (name == nullptr) { \ - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); \ - GELOGE(GRAPH_FAILED, "[Check][Param] operator attr name is nullptr."); \ - return GRAPH_FAILED; \ - } \ - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { \ - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); \ - GELOGE(GRAPH_FAILED, "[Check][Param] operator impl is nullptr, name %s.", name); \ - return GRAPH_FAILED; \ - } \ - const std::string op_name = name; \ - if (!ge::AttrUtils::Get##AttrUtilsFun(operator_impl_->GetOpDescImpl(), op_name, attr_value)) { \ - GELOGW("[Get][Attr] Get attr name %s unsuccessful", op_name.c_str()); \ - return GRAPH_FAILED; \ - } \ - return GRAPH_SUCCESS; \ - } - -#define OP_ATTR_REG_IMP(ArgType, AttrUtilsFun) \ - void ge::Operator::AttrRegister(const char_t *name, ArgType attr_value) { \ - GE_CHECK_NOTNULL_JUST_RETURN(name); \ - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { \ - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); \ - GELOGE(GRAPH_FAILED, "[Check][Param] operator impl is nullptr, name %s.", name); \ - return; \ - } \ - if (!ge::AttrUtils::Set##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value)) { \ - GELOGW("[Register][Attr] Reg attr name %s unsuccessful", name); \ - } \ - operator_impl_->GetOpDescImpl()->AppendIrAttrName(name); \ - } - -#define EDGE_ATTR_SET_BY_IDX_IMP(ArgType, AttrUtilsFun) \ - Operator &Operator::SetInputAttr(const int32_t index, const char_t *name, ArgType attr_value) { \ - if (name == nullptr) { \ - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid."); \ - GELOGE(GRAPH_FAILED, "[Check][Param] operator attr name is nullptr."); \ - return *this; \ - } \ - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { \ - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); \ - GELOGE(GRAPH_FAILED, "[Check][Param] operator impl is nullptr, name %s.", name); \ - return *this; \ - } \ - const std::string op_name = name; \ - auto tensor = operator_impl_->MutableInputDesc(index); \ - if (!ge::AttrUtils::Set##AttrUtilsFun(tensor, op_name, attr_value)) { \ - GELOGW("[Set][Attr] Set attr name %s to op %s of index[%d] unsuccessful", name, \ - operator_impl_->GetOpDescImpl()->GetName().c_str(), index); \ - } \ - return *this; \ - } \ - Operator &Operator::SetOutputAttr(const int32_t index, const char_t *name, ArgType attr_value) { \ - if (name == nullptr) { \ - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid."); \ - GELOGE(GRAPH_FAILED, "[Check][Param] operator attr name is nullptr."); \ - return *this; \ - } \ - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { \ - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); \ - GELOGE(GRAPH_FAILED, "[Check][Param] operator impl is nullptr, name %s.", name); \ - return *this; \ - } \ - const std::string op_name = name; \ - auto tensor = operator_impl_->MutableOutputDesc(index); \ - if (!ge::AttrUtils::Set##AttrUtilsFun(tensor, op_name, attr_value)) { \ - GELOGW("[Set][Attr] Set attr name %s to op %s of index[%d] unsuccessful", name, \ - operator_impl_->GetOpDescImpl()->GetName().c_str(), index); \ - } \ - return *this; \ - } -#define EDGE_ATTR_SET_BY_NAME_IMP(ArgType, AttrUtilsFun) \ - Operator &Operator::SetInputAttr(const char_t *dst_name, const char_t *name, ArgType attr_value) { \ - if ((dst_name == nullptr) || (name == nullptr)) { \ - REPORT_INNER_ERR_MSG("E18888", "dst_name or attr name is nullptr, check invalid."); \ - GELOGE(GRAPH_FAILED, "[Check][Param] operator input name or attr name is nullptr."); \ - return *this; \ - } \ - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { \ - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); \ - GELOGE(GRAPH_FAILED, "[Check][Param] operator impl is nullptr, name %s.", name); \ - return *this; \ - } \ - const std::string op_name = name; \ - const std::string dst_names = dst_name; \ - auto tensor = operator_impl_->MutableInputDesc(dst_names); \ - if (!ge::AttrUtils::Set##AttrUtilsFun(tensor, op_name, attr_value)) { \ - GELOGW("[Set][Attr] Set attr name %s to op %s of input_name[%s] unsuccessful", name, \ - operator_impl_->GetOpDescImpl()->GetName().c_str(), dst_name); \ - } \ - return *this; \ - } \ - Operator &Operator::SetOutputAttr(const char_t *dst_name, const char_t *name, ArgType attr_value) { \ - if ((dst_name == nullptr) || (name == nullptr)) { \ - REPORT_INNER_ERR_MSG("E18888", "dst_name or attr name is nullptr, check invalid."); \ - GELOGE(GRAPH_FAILED, "[Check][Param] operator output name or attr name is nullptr."); \ - return *this; \ - } \ - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { \ - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); \ - GELOGE(GRAPH_FAILED, "[Check][Param] operator impl is nullptr, name %s.", name); \ - return *this; \ - } \ - const std::string op_name = name; \ - const std::string dst_names = dst_name; \ - auto tensor = operator_impl_->MutableOutputDesc(dst_names); \ - if (!ge::AttrUtils::Set##AttrUtilsFun(tensor, op_name, attr_value)) { \ - GELOGW("[Set][Attr] Set attr name %s to op %s of output_name[%s] unsuccessful", name, \ - operator_impl_->GetOpDescImpl()->GetName().c_str(), dst_name); \ - } \ - return *this; \ - } - -#define EDGE_ATTR_GET_BY_IDX_IMP(ArgType, AttrUtilsFun) \ - graphStatus Operator::GetInputAttr(const int32_t index, const char_t *name, ArgType attr_value) const { \ - if (name == nullptr) { \ - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid."); \ - GELOGE(GRAPH_FAILED, "[Check][Param] operator attr name is nullptr."); \ - return GRAPH_FAILED; \ - } \ - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { \ - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); \ - GELOGE(GRAPH_FAILED, "[Check][Param] operator impl is nullptr, name %s.", name); \ - return GRAPH_FAILED; \ - } \ - const std::string op_name = name; \ - const auto &tensor = operator_impl_->GetInputDesc(index); \ - if (!ge::AttrUtils::Get##AttrUtilsFun(tensor, op_name, attr_value)) { \ - GELOGW("[Get][Attr] Get attr name %s to op %s of index[%d] unsuccessful", name, \ - operator_impl_->GetOpDescImpl()->GetName().c_str(), index); \ - } \ - return GRAPH_SUCCESS; \ - } \ - graphStatus Operator::GetOutputAttr(const int32_t index, const char_t *name, ArgType attr_value) const { \ - if (name == nullptr) { \ - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid."); \ - GELOGE(GRAPH_FAILED, "[Check][Param] operator attr name is nullptr."); \ - return GRAPH_FAILED; \ - } \ - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { \ - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); \ - GELOGE(GRAPH_FAILED, "[Check][Param] operator impl is nullptr, name %s.", name); \ - return GRAPH_FAILED; \ - } \ - const std::string op_name = name; \ - const auto &tensor = operator_impl_->GetOutputDesc(index); \ - if (!ge::AttrUtils::Get##AttrUtilsFun(tensor, op_name, attr_value)) { \ - GELOGW("[Get][Attr] Get attr name %s to op %s of index[%d] unsuccessful", name, \ - operator_impl_->GetOpDescImpl()->GetName().c_str(), index); \ - } \ - return GRAPH_SUCCESS; \ - } -#define EDGE_ATTR_GET_BY_NAME_IMP(ArgType, AttrUtilsFun) \ - graphStatus Operator::GetInputAttr(const char_t *dst_name, const char_t *name, ArgType attr_value) const { \ - if ((dst_name == nullptr) || (name == nullptr)) { \ - REPORT_INNER_ERR_MSG("E18888", "dst_name or attr name is nullptr, check invalid."); \ - GELOGE(GRAPH_FAILED, "[Check][Param] operator input name or attr name is nullptr."); \ - return GRAPH_FAILED; \ - } \ - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { \ - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); \ - GELOGE(GRAPH_FAILED, "[Check][Param] operator impl is nullptr, name %s.", name); \ - return GRAPH_FAILED; \ - } \ - const std::string op_name = name; \ - const std::string dst_names = dst_name; \ - const auto &tensor = operator_impl_->GetInputDesc(dst_names); \ - if (!ge::AttrUtils::Get##AttrUtilsFun(tensor, op_name, attr_value)) { \ - GELOGW("[Get][Attr] Get attr name %s to op %s of input_name[%s] unsuccessful", name, \ - operator_impl_->GetOpDescImpl()->GetName().c_str(), dst_name); \ - } \ - return GRAPH_SUCCESS; \ - } \ - graphStatus Operator::GetOutputAttr(const char_t *dst_name, const char_t *name, ArgType attr_value) const { \ - if ((dst_name == nullptr) || (name == nullptr)) { \ - REPORT_INNER_ERR_MSG("E18888", "dst_name or attr name is nullptr, check invalid."); \ - GELOGE(GRAPH_FAILED, "[Check][Param] operator output name or attr name is nullptr."); \ - return GRAPH_FAILED; \ - } \ - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { \ - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); \ - GELOGE(GRAPH_FAILED, "[Check][Param] operator impl is nullptr, name %s.", name); \ - return GRAPH_FAILED; \ - } \ - const std::string op_name = name; \ - const std::string dst_names = dst_name; \ - const auto &tensor = operator_impl_->GetOutputDesc(dst_names); \ - if (!ge::AttrUtils::Get##AttrUtilsFun(tensor, op_name, attr_value)) { \ - GELOGW("[Get][Attr] Get attr name %s to op %s of output_name[%s] unsuccessful", name, \ - operator_impl_->GetOpDescImpl()->GetName().c_str(), dst_name); \ - } \ - return GRAPH_SUCCESS; \ - } - -// 此宏因兼容问题需要保留,只在operator类内使用 -#define GE_RETURN_IF_NULL(v, ...) \ - do { \ - if ((v) == nullptr) { \ - auto msg = CreateErrorMsg(__VA_ARGS__); \ - if (msg.empty()) { \ - REPORT_INNER_ERR_MSG("E19999", "Assert %s not null failed", #v); \ - GELOGE(ge::FAILED, "Assert %s not null failed", #v); \ - } else { \ - REPORT_INNER_ERR_MSG("E19999", "%s", msg.data()); \ - GELOGE(ge::FAILED, "%s", msg.data()); \ - } \ - return; \ - } \ - } while (false) -#define EDGE_ATTR_SET_IMP(ArgType, AttrUtilsFunc) \ - EDGE_ATTR_SET_BY_IDX_IMP(ArgType, AttrUtilsFunc) \ - EDGE_ATTR_SET_BY_NAME_IMP(ArgType, AttrUtilsFunc) -#define EDGE_ATTR_GET_IMP(ArgType, AttrUtilsFunc) \ - EDGE_ATTR_GET_BY_IDX_IMP(ArgType, AttrUtilsFunc) \ - EDGE_ATTR_GET_BY_NAME_IMP(ArgType, AttrUtilsFunc) -namespace ge { -namespace { -graphStatus SetTensorAttr(const GeTensorDescPtr &tensor, const char_t *name, - const std::vector &attr_value) { - if ((tensor == nullptr) || (name == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "Operator input parameters name or tensor is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name or tensor is nullptr, check invalid."); - return GRAPH_PARAM_INVALID; - } - std::vector op_attr_values; - for (const auto &value : attr_value) { - op_attr_values.emplace_back(value.GetString()); - } - if (!AttrUtils::SetListStr(tensor, name, op_attr_values)) { - GELOGW("[Set][Attr] Set attr name %s unsuccessful", name); - } - return GRAPH_SUCCESS; -} - -graphStatus GetTensorAttr(const GeTensorDesc &tensor, const char_t *name, std::vector &attr_value) { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Operator input parameters name is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr, check invalid."); - return GRAPH_PARAM_INVALID; - } - std::vector op_attr_value; - GE_ASSERT_TRUE(AttrUtils::GetListStr(tensor, name, op_attr_value), - "[Get][Attr] GetListStr name %s on tensor failed.", name); - for (const auto &value : op_attr_value) { - attr_value.emplace_back(AscendString(value.c_str())); - } - return GRAPH_SUCCESS; -} -} -const int32_t kMaxDepth = 20; -OperatorKeeper &OperatorKeeper::GetInstance() { - static OperatorKeeper instance; - return instance; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator OpDescUtils::CreateOperatorFromNode(ge::ConstNodePtr node_ptr) { - const ge::OperatorImplPtr operator_impl_ptr = ComGraphMakeShared(node_ptr); - if (operator_impl_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "OperatorImpl make shared failed"); - GELOGE(GRAPH_FAILED, "[Call][ComGraphMakeShared] OperatorImpl make shared failed"); - return Operator("default"); - } - return operator_impl_ptr->ToOperator(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -OpDescUtils::CopyOperators(const ComputeGraphPtr &dst_compute_graph, - const std::map &node_old_2_new, - const std::map &op_desc_old_2_new, - const std::map &src_op_list, - std::map &dst_op_list) { - GE_CHECK_NOTNULL(dst_compute_graph); - - std::map all_node_info; - - for (const auto &itr : src_op_list) { - auto name = itr.first; - const ge::Operator &src_op = itr.second; - GE_CHECK_NOTNULL(src_op.operator_impl_); - const OperatorImplPtr scr_op_impl_ptr = src_op.operator_impl_; - GE_CHECK_NOTNULL(scr_op_impl_ptr->op_desc_); - ge::Operator dst_op; - OpDescPtr dst_op_desc = nullptr; - if (scr_op_impl_ptr->node_ == nullptr) { - // cannot find op_desc in compute graph, need creat new op_desc - // otherwise use existing op_desc - const auto it = op_desc_old_2_new.find(scr_op_impl_ptr->op_desc_); - if (it != op_desc_old_2_new.end()) { - dst_op_desc = it->second; - } else { - dst_op_desc = OpDescUtils::CopyOpDesc(scr_op_impl_ptr->op_desc_); - if (dst_op_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "CopyOpDesc from %s failed", scr_op_impl_ptr->op_desc_->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Copy][OpDesc] from %s failed", scr_op_impl_ptr->op_desc_->GetName().c_str()); - return GRAPH_FAILED; - } - dst_op_desc->CopyAttrsFrom(*scr_op_impl_ptr->op_desc_); - dst_op_desc->SetName(scr_op_impl_ptr->op_desc_->GetName()); - } - dst_op = CreateOperatorFromOpDesc(dst_op_desc); - } else { - const auto original_op_desc = scr_op_impl_ptr->node_->GetOpDesc(); - if (scr_op_impl_ptr->op_desc_ != original_op_desc) { - REPORT_INNER_ERR_MSG("E18888", "node and op_desc of operator are not equal."); - GELOGE(GRAPH_FAILED, "[Check][Param] node and op_desc of operator are not equal."); - return GRAPH_FAILED; - } - NodePtr dst_node = nullptr; - // cannot find node in compute graph, need creat new node - // otherwise use existing node and op_desc - const auto it = node_old_2_new.find(scr_op_impl_ptr->node_); - if (it != node_old_2_new.end()) { - dst_node = it->second; - } else { - dst_op_desc = OpDescUtils::CopyOpDesc(scr_op_impl_ptr->op_desc_); - if (dst_op_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "CopyOpDesc from %s failed", scr_op_impl_ptr->op_desc_->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Copy][OpDesc] from %s failed", scr_op_impl_ptr->op_desc_->GetName().c_str()); - return GRAPH_FAILED; - } - dst_op_desc->CopyAttrsFrom(*scr_op_impl_ptr->op_desc_); - dst_op_desc->SetName(scr_op_impl_ptr->op_desc_->GetName()); - dst_node = NodeUtils::CreatNodeWithoutGraph(dst_op_desc); - GE_CHECK_NOTNULL(dst_node); - // to do link egdes - } - dst_op = CreateOperatorFromNode(dst_node); - (void)(all_node_info.emplace(dst_op.GetOperatorImplPtr(), dst_node)); - } - dst_op.operator_impl_->subgraph_names_to_builders_ = src_op.operator_impl_->subgraph_names_to_builders_; - (void)(dst_op_list.emplace(name, dst_op)); - } - - dst_compute_graph->SetAllNodesInfo(all_node_info); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -OpDescUtils::CopyOperatorLinks(const std::map &src_op_list, - std::map &dst_op_list) { - for (const auto &it : src_op_list) { - auto &src_op = it.second; - const auto op_name = it.first; - auto &dst_op = dst_op_list[op_name]; - const OperatorImplPtr src_impl_ptr = src_op.GetOperatorImplPtr(); - GE_CHECK_NOTNULL(src_impl_ptr); - for (const auto &itr : src_impl_ptr->input_link_) { - const std::string dst_name = itr.first; - const OpIO &op_io = itr.second; - const OperatorImplPtr input_impl_ptr = op_io.GetOwner(); - GE_CHECK_NOTNULL(input_impl_ptr); - const auto iter = dst_op_list.find(input_impl_ptr->GetName()); - if (iter == dst_op_list.end()) { - REPORT_INNER_ERR_MSG("E18888", "Find dst operator:%s failed", input_impl_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Find dst operator:%s failed", input_impl_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - auto &input_op = iter->second; - (void)(dst_op.SetInput(dst_name.c_str(), input_op)); - } - - for (const auto &itr : src_impl_ptr->control_input_link_) { - const OperatorImplPtr input_ctrl_impl_ptr = itr.lock(); - GE_CHECK_NOTNULL(input_ctrl_impl_ptr); - const auto iter = dst_op_list.find(input_ctrl_impl_ptr->GetName()); - if (iter == dst_op_list.end()) { - REPORT_INNER_ERR_MSG("E18888", "Find dst ctrl operator:%s failed", input_ctrl_impl_ptr->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Find dst ctrl operator:%s failed", input_ctrl_impl_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - auto &ctrl_input_op = iter->second; - (void)(dst_op.AddControlInput(ctrl_input_op)); - } - } - return GRAPH_SUCCESS; -} - -Operator::Operator(const std::string &type) { - static std::atomic index = {0U}; - std::string name = type + "_" + std::to_string(index++); - operator_impl_ = ComGraphMakeShared(name, type); - if (operator_impl_ == nullptr) { - GELOGW("[Check][Param] Make OperatorImpl unsuccessful"); - } - OperatorKeeper::GetInstance().CheckInOperator(operator_impl_); -} - -Operator::Operator(const char_t *type) { - if (type != nullptr) { - std::string op_type = type; - static std::atomic index = {0U}; - std::string name = op_type + "_" + std::to_string(index++); - operator_impl_ = ComGraphMakeShared(name, op_type); - if (operator_impl_ == nullptr) { - GELOGW("[Check][Param] Make OperatorImpl unsuccessful"); - } - OperatorKeeper::GetInstance().CheckInOperator(operator_impl_); - } else { - GELOGW("[Check][Param] Operator type is nullptr"); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator OpDescUtils::CreateOperatorFromOpDesc(OpDescPtr op_desc) { - std::shared_ptr operator_impl_ptr; - operator_impl_ptr = ComGraphMakeShared(op_desc); - if (operator_impl_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "OperatorImpl make shared failed"); - GELOGE(GRAPH_FAILED, "[Call][ComGraphMakeShared] OperatorImpl make shared failed"); - return Operator("default"); - } - OperatorKeeper::GetInstance().CheckInOperator(operator_impl_ptr); - return operator_impl_ptr->ToOperator(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescUtils::GetOpDescFromOperator(const Operator &oprt) { - return OperatorImpl::GetOpDesc(oprt); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstNodePtr NodeUtilsEx::GetNodeFromOperator(const Operator &op) { - return op.GetNode(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtilsEx::SetNodeToOperator(Operator &op, - const ConstNodePtr &node) { - return op.operator_impl_->SetNode(node); -} - -GE_FUNC_HOST_VISIBILITY Operator::Operator(const std::string &name, const std::string &type) { - operator_impl_ = ComGraphMakeShared(name, type); - if (operator_impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "OperatorImpl make shared failed"); - GELOGE(GRAPH_FAILED, "[Call][ComGraphMakeShared] OperatorImpl make shared failed"); - return; - } - OperatorKeeper::GetInstance().CheckInOperator(operator_impl_); -} - -GE_FUNC_HOST_VISIBILITY Operator::Operator(const AscendString &name, const AscendString &type) { - if ((name.GetString() != nullptr) && (type.GetString() != nullptr)) { - std::string op_name = name.GetString(); - std::string op_type = type.GetString(); - operator_impl_ = ComGraphMakeShared(op_name, op_type); - if (operator_impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "OperatorImpl make shared failed"); - GELOGE(GRAPH_FAILED, "[Call][ComGraphMakeShared] OperatorImpl make shared failed"); - return; - } - OperatorKeeper::GetInstance().CheckInOperator(operator_impl_); - } else { - GELOGW("[Check][Param] Operator input parameter is nullptr"); - } -} - -GE_FUNC_HOST_VISIBILITY Operator::Operator(const char_t *name, const char_t *type) { - if ((name != nullptr) && (type != nullptr)) { - std::string op_name = name; - std::string op_type = type; - operator_impl_ = ComGraphMakeShared(op_name, op_type); - if (operator_impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "OperatorImpl make shared failed"); - GELOGE(GRAPH_FAILED, "[Call][ComGraphMakeShared] OperatorImpl make shared failed"); - return; - } - OperatorKeeper::GetInstance().CheckInOperator(operator_impl_); - } else { - GELOGW("[Check][Param] Operator input parameter is nullptr"); - } -} - -Operator::Operator(ge::OperatorImplPtr &&op_impl) { operator_impl_ = std::move(op_impl); } - -bool Operator::IsEmpty() const { - if (operator_impl_ == nullptr) { - return true; - } - return false; -} - -std::string Operator::GetName() const { - if (operator_impl_ != nullptr) { - return operator_impl_->GetName(); - } - return ""; -} - -graphStatus Operator::GetName(AscendString &name) const { - if (operator_impl_ != nullptr) { - const std::string op_name = operator_impl_->GetName(); - name = op_name.c_str(); - } - return GRAPH_SUCCESS; -} - -GE_FUNC_HOST_VISIBILITY Operator &Operator::SetInput(const std::string &dst_name, const ge::Operator &src_oprt) { - // Describe the connection relationship between operators, no create action - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator impl is nullptr, check invalid."); - return *this, "[Check][Param] operator impl is nullptr."); - operator_impl_->SetInputImpl(dst_name, src_oprt); - return *this; -} - -GE_FUNC_HOST_VISIBILITY Operator &Operator::SetInput(const char_t *dst_name, const ge::Operator &src_oprt) { - GE_CHK_BOOL_EXEC(dst_name != nullptr, REPORT_INNER_ERR_MSG("E18888", "param dst name is nullptr, check invalid"); - return *this, "[Check][Param] Operator dst name is nullptr."); - // Describe the connection relationship between operators, no create action - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid."); - return *this, "[Check][Param] Operator impl is nullptr."); - const std::string dst_op_name = dst_name; - operator_impl_->SetInputImpl(dst_op_name, src_oprt); - return *this; -} - -Operator &Operator::SetInput(const std::string &dst_name, const ge::OutHandler &out_handler) { - return SetInput(dst_name.c_str(), out_handler); -} - -Operator &Operator::SetInput(const char_t *dst_name, const ge::OutHandler &out_handler) { - if (dst_name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param dst_name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator dst_name is nullptr."); - return *this; - } - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid."); - return *this, "[Check][Param] operator impl is nullptr."); - operator_impl_->SetInputImpl(dst_name, out_handler); - return *this; -} - -Operator &Operator::SetInput(const std::string &dst_name, const ge::Operator &src_oprt, const std::string &name) { - return SetInput(dst_name.c_str(), src_oprt, name.c_str()); -} - -Operator &Operator::SetInput(const char_t *dst_name, const ge::Operator &src_oprt, const char_t *name) { - GE_CHK_BOOL_EXEC(dst_name != nullptr, REPORT_INNER_ERR_MSG("E18888", "param dst_name is nullptr, check invalid."); - return *this, "[Check][Param] Dst name is nullptr."); - GE_CHK_BOOL_EXEC(name != nullptr, REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid."); - return *this, "[Check][Param] Name is nullptr."); - const auto out_handler = src_oprt.GetOutput(name); - GE_CHK_BOOL_EXEC(out_handler != nullptr, - REPORT_INNER_ERR_MSG("E18888", "GetOutput by name:%s failed, out_handler is nullptr.", name); - return *this, "[Get][Output] by name:%s failed, out_handler is nullptr.", name); - return SetInput(dst_name, out_handler); -} - -Operator &Operator::SetInput(const std::string &dst_name, const ge::Operator &src_oprt, uint32_t index) { - return SetInput(dst_name.c_str(), src_oprt, index); -} - -Operator &Operator::SetInput(const char_t *dst_name, const ge::Operator &src_oprt, uint32_t index) { - GE_CHK_BOOL_EXEC(dst_name != nullptr, REPORT_INNER_ERR_MSG("E18888", "param dst_name is nullptr, check invalid"); - return *this, "[Check][Param] Dst name is nullptr."); - const auto out_handler = src_oprt.GetOutput(index); - GE_CHK_BOOL_EXEC(out_handler != nullptr, - REPORT_INNER_ERR_MSG("E18888", "GetOutput by index:%u failed, out_handler is nullptr.", index); - return *this, "[Get][Output] by index:%u failed, out_handler is nullptr.", index); - return SetInput(dst_name, out_handler); -} - -Operator &Operator::SetInput(uint32_t dst_index, const Operator &src_oprt, uint32_t src_index) { - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - const char_t *invalid_obj_name = ((operator_impl_ == nullptr) ? "operator" : "op desc"); - REPORT_INNER_ERR_MSG("E18888", "%s impl is nullptr, check invalid.", invalid_obj_name); - GELOGE(ge::FAILED, "[Check][Param] %s impl is nullptr.", invalid_obj_name); - return *this; - } - std::string dst_name = operator_impl_->GetOpDescImpl()->GetInputNameByIndex(dst_index); - if (dst_name.empty()) { - REPORT_INNER_ERR_MSG("E18888", "Set by dst_index:%u failed, dst_index is invalid.", dst_index); - GELOGE(ge::FAILED, "[GetInputNameByIndex] by index:%u failed, dst_index is invalid.", dst_index); - return *this; - } - return SetInput(dst_name.c_str(), src_oprt, src_index); -} - -Operator &Operator::AddControlInput(const Operator &src_oprt) { - if (operator_impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "operator impl is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] operator impl is nullptr."); - return *this; - } - operator_impl_->AddControlInputImp(src_oprt); - return *this; -} - -graphStatus Operator::GetInputConstData(const std::string &dst_name, Tensor &data) const { - GE_CHECK_NOTNULL(operator_impl_); - const graphStatus ret = operator_impl_->GetInputConstData(dst_name, data); - if (ret != GRAPH_SUCCESS) { - GELOGW("[Get][ConstInput] %s get input const data unsuccessful", dst_name.c_str()); - return ret; - } - return GRAPH_SUCCESS; -} - -graphStatus Operator::GetInputConstData(const char_t *dst_name, Tensor &data) const { - GE_CHECK_NOTNULL(dst_name); - GE_CHECK_NOTNULL(operator_impl_); - const std::string op_dst_name = dst_name; - const graphStatus ret = operator_impl_->GetInputConstData(op_dst_name, data); - if (ret != GRAPH_SUCCESS) { - GELOGW("[Get][ConstInput] %s get input const data unsuccessful", op_dst_name.c_str()); - return ret; - } - return GRAPH_SUCCESS; -} - -graphStatus Operator::GetInputConstDataOut(const std::string &dst_name, Tensor &data) const { - return GetInputConstDataOut(dst_name.c_str(), data); -} - -graphStatus Operator::GetInputConstDataOut(const char_t *dst_name, Tensor &data) const { - GE_CHECK_NOTNULL(operator_impl_); - if (dst_name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param dst_name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator dst_name is nullptr."); - return GRAPH_FAILED; - } - if (operator_impl_->GetInputConstDataOut(dst_name, data) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "%s get input const data out failed", dst_name); - GELOGE(GRAPH_FAILED, "[Get][Tensor] %s get input const data out failed", dst_name); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -std::shared_ptr Operator::GetNode() const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return nullptr, "[Check][Param] operator impl is nullptr."); - return operator_impl_->GetNode(); -} - -TensorDesc Operator::GetInputDesc(const std::string &name) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return TensorDesc(), "[Check][Param] operator impl is nullptr."); - return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name)); -} - -TensorDesc Operator::GetInputDescByName(const char_t *name) const { - GE_CHK_BOOL_EXEC(name != nullptr, REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - return TensorDesc(), "[Check][Param] Operator name is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return TensorDesc(), "[Check][Param] Operator impl is nullptr."); - const std::string op_name = name; - return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(op_name)); -} - -void Operator::SetInferenceContext(const InferenceContextPtr &inference_context) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return, "[Check][Param] operator impl is nullptr."); - operator_impl_->SetInferenceContext(inference_context); -} - -InferenceContextPtr Operator::GetInferenceContext() const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return nullptr, "[Check][Param] operator impl is nullptr."); - return operator_impl_->GetInferenceContext(); -} - -TensorDesc Operator::GetInputDesc(uint32_t index) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return TensorDesc(), "[Check][Param] operator impl is nullptr."); - return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(index)); -} - -graphStatus Operator::TryGetInputDesc(const std::string &name, TensorDesc &tensor_desc) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return GRAPH_FAILED, "[Check][Param] operator impl is nullptr."); - const auto check = operator_impl_->InputIsSet(name); - if (check) { - tensor_desc = TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name)); - } - return check ? GRAPH_SUCCESS : GRAPH_FAILED; -} - -graphStatus Operator::TryGetInputDesc(const char_t *name, TensorDesc &tensor_desc) const { - GE_CHK_BOOL_EXEC(name != nullptr, REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - return GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return GRAPH_FAILED, "[Check][Param] Operator impl is nullptr."); - const std::string op_name = name; - const auto check = operator_impl_->InputIsSet(op_name); - if (check) { - tensor_desc = TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(op_name)); - } - return check ? GRAPH_SUCCESS : GRAPH_FAILED; -} - -graphStatus Operator::UpdateInputDesc(const std::string &name, const ge::TensorDesc &tensor_desc) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return GRAPH_FAILED, "[Check][Param] operator impl is nullptr."); - return operator_impl_->UpdateInputDesc(name, TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); -} - -graphStatus Operator::UpdateInputDesc(const char_t *name, const ge::TensorDesc &tensor_desc) { - GE_CHK_BOOL_EXEC(name != nullptr, REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - return GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return GRAPH_FAILED, "[Check][Param] Operator impl is nullptr."); - const std::string op_name = name; - return operator_impl_->UpdateInputDesc(op_name, TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); -} - -OutHandler Operator::GetOutput(const std::string &name) const { - return GetOutput(name.c_str()); -} - -OutHandler Operator::GetOutput(const char_t *name) const { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return nullptr; - } - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return nullptr, "[Check][Param] operator impl is nullptr."); - return operator_impl_->GetOutput(name); -} - -OutHandler Operator::GetOutput(uint32_t index) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return nullptr, "[Check][Param] operator impl is nullptr."); - return operator_impl_->GetOutput(index); -} - -TensorDesc Operator::GetOutputDesc(const std::string &name) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return TensorDesc(), "[Check][Param] operator impl is nullptr."); - return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(name)); -} - -TensorDesc Operator::GetOutputDescByName(const char_t *name) const { - GE_CHK_BOOL_EXEC(name != nullptr, REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - return TensorDesc(), "[Check][Param] Operator name is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return TensorDesc(), "[Check][Param] Operator impl is nullptr."); - const std::string op_name = name; - return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(op_name)); -} - -TensorDesc Operator::GetOutputDesc(uint32_t index) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return TensorDesc(), "[Check][Param] operator impl is nullptr."); - return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(index)); -} - -graphStatus Operator::UpdateOutputDesc(const std::string &name, const ge::TensorDesc &tensor_desc) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return GRAPH_FAILED, "[Check][Param] operator impl is nullptr."); - return operator_impl_->UpdateOutputDesc(name, TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); -} - -graphStatus Operator::UpdateOutputDesc(const char_t *name, const ge::TensorDesc &tensor_desc) { - GE_CHK_BOOL_EXEC(name != nullptr, REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - return GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return GRAPH_FAILED, "[Check][Param] Operator impl is nullptr."); - const std::string op_name = name; - return operator_impl_->UpdateOutputDesc(op_name, TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); -} - -TensorDesc Operator::GetDynamicInputDesc(const std::string &name, uint32_t index) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return TensorDesc(), "[Check][Param] operator impl is nullptr."); - return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(name + std::to_string(index))); -} - -TensorDesc Operator::GetDynamicInputDesc(const char_t *name, uint32_t index) const { - GE_CHK_BOOL_EXEC(name != nullptr, REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - return TensorDesc(), "[Check][Param] Operator name is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return TensorDesc(), "[Check][Param] Operator impl is nullptr."); - const std::string op_name = name; - return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(op_name + std::to_string(index))); -} - -graphStatus Operator::UpdateDynamicInputDesc(const std::string &name, uint32_t index, const TensorDesc &tensor_desc) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return GRAPH_FAILED, "[Check][Param] operator impl is nullptr."); - return operator_impl_->UpdateInputDesc(name + std::to_string(index), - TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); -} - -graphStatus Operator::UpdateDynamicInputDesc(const char_t *name, uint32_t index, const TensorDesc &tensor_desc) { - GE_CHK_BOOL_EXEC(name != nullptr, REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - return GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - const std::string op_name = name; - return operator_impl_->UpdateInputDesc(op_name + std::to_string(index), - TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); -} - -TensorDesc Operator::GetDynamicOutputDesc(const std::string &name, uint32_t index) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return TensorDesc(), "[Check][Param] operator impl is nullptr."); - return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(name + std::to_string(index))); -} - -TensorDesc Operator::GetDynamicOutputDesc(const char_t *name, uint32_t index) const { - GE_CHK_BOOL_EXEC(name != nullptr, REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - return TensorDesc(), "[Check][Param] Operator name is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return TensorDesc(), "[Check][Param] Operator impl is nullptr."); - const std::string op_name = name; - return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetOutputDesc(op_name + std::to_string(index))); -} - -graphStatus Operator::UpdateDynamicOutputDesc(const std::string &name, uint32_t index, const TensorDesc &tensor_desc) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return GRAPH_FAILED, "[Check][Param] operator impl is nullptr."); - return operator_impl_->UpdateOutputDesc(name + std::to_string(index), - TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); -} - -graphStatus Operator::UpdateDynamicOutputDesc(const char_t *name, uint32_t index, const TensorDesc &tensor_desc) { - GE_CHK_BOOL_EXEC(name != nullptr, REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - return GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return GRAPH_FAILED, "[Check][Param] Operator impl is nullptr."); - const std::string op_name = name; - return operator_impl_->UpdateOutputDesc(op_name + std::to_string(index), - TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); -} - -graphStatus Operator::InferShapeAndType() { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return GRAPH_FAILED, "[Check][Param] operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, - REPORT_INNER_ERR_MSG("E18888", "GetOpDescImpl failed, as return nullptr."); - return GRAPH_FAILED, "[Get][OpDescImpl] is nullptr."); - - return OpDescUtilsEx::CallInferFunc(operator_impl_->GetOpDescImpl(), *this); -} - -graphStatus Operator::VerifyAllAttr(bool disable_common_verifier) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return GRAPH_FAILED, "[Check][Param] operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, - REPORT_INNER_ERR_MSG("E18888", "GetOpDescImpl failed, as return nullptr."); - return GRAPH_FAILED, "[Get][OpDescImpl] is nullptr."); - - if ((!disable_common_verifier) && (static_cast(Operator::VerifyAll()) == GRAPH_FAILED)) { - return GRAPH_FAILED; - } else { - return OpDescUtilsEx::OpVerify(operator_impl_->GetOpDescImpl()); - } -} - -GE_FUNC_HOST_VISIBILITY size_t Operator::GetInputsSize() const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return 0UL, "[Check][Param] OperatorImpl_ is nullptr"); - return static_cast(operator_impl_->GetInputsSize()); -} - -GE_FUNC_HOST_VISIBILITY size_t Operator::GetOutputsSize() const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return 0UL, "[Check][Param] OperatorImpl_ is nullptr"); - return static_cast(operator_impl_->GetOutputsSize()); -} - -// According to op get the attrs name and type -namespace { -const std::map kIrAttrTypesMap = { - {"Int", "VT_INT"}, - {"Float", "VT_FLOAT"}, - {"String", "VT_STRING"}, - {"Bool", "VT_BOOL"}, - {"Tensor", "VT_TENSOR"}, - {"NamedAttrs", "VT_NAMED_ATTRS"}, - {"ListInt", "VT_LIST_INT"}, - {"ListFloat", "VT_LIST_FLOAT"}, - {"ListString", "VT_LIST_STRING"}, - {"ListBool", "VT_LIST_BOOL"}, - {"ListTensor", "VT_LIST_TENSOR"}, - {"Bytes", "VT_BYTES"}, - {"ListListInt", "VT_LIST_LIST_INT"}, - {"ListNamedAttrs", "VT_LIST_NAMED_ATTRS"}, - {"Type", "VT_DATA_TYPE"}, - {"ListType", "VT_LIST_DATA_TYPE"}, -}; -} // namespace -const std::map Operator::GetAllAttrNamesAndTypes() const { - std::map attr_types; - - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return attr_types, "[Check][Param] operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, - REPORT_INNER_ERR_MSG("E18888", "GetOpDescImpl failed, as return nullptr."); - return attr_types, "[Get][OpDescImpl] is nullptr."); - const std::map attr_map = operator_impl_->GetOpDescImpl()->GetAllAttrs(); - for (const auto &iter : attr_map) { - const std::string name = iter.first; - const AnyValue::ValueType type = iter.second.GetValueType(); - attr_types[name] = AttrUtils::ValueTypeToSerialString(type); - } - - return attr_types; -} - -graphStatus Operator::GetAllAttrNamesAndTypes(std::map &attr_name_types) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return GRAPH_FAILED, "[Check][Param] Operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, - REPORT_INNER_ERR_MSG("E18888", "GetOpDescImpl failed, as return nullptr."); - return GRAPH_FAILED, "[Get][OpDescImpl] is nullptr."); - const std::map attr_map = operator_impl_->GetOpDescImpl()->GetAllAttrs(); - - for (const auto &iter : attr_map) { - const std::string name = iter.first; - const AnyValue::ValueType type = iter.second.GetValueType(); - const AscendString temp(name.c_str()); - attr_name_types[temp] = AscendString(AttrUtils::ValueTypeToSerialString(type).c_str()); - } - - return GRAPH_SUCCESS; -} - -graphStatus Operator::GetAllIrAttrNamesAndTypes(std::map &attr_name_types) const { - GE_ASSERT_NOTNULL(operator_impl_); - GE_ASSERT_NOTNULL(operator_impl_->GetOpDescImpl()); - const std::map attr_map = operator_impl_->GetOpDescImpl()->GetAllAttrs(); - const auto &ir_attrs_names = operator_impl_->GetOpDescImpl()->GetIrAttrNames(); - std::unordered_set ir_attrs_name_set(ir_attrs_names.begin(), ir_attrs_names.end()); - for (const auto &iter : attr_map) { - const std::string name = iter.first; - const AnyValue::ValueType type = iter.second.GetValueType(); - // save ir normal attr - if (ir_attrs_name_set.find(name) != ir_attrs_name_set.end()) { - const AscendString temp(name.c_str()); - attr_name_types[temp] = AscendString(AttrUtils::ValueTypeToSerialString(type).c_str()); - } - } - const auto &required_attrs = operator_impl_->GetOpDescImpl()->GetRequiredAttrWithType(); - // save ir required attr - for (const auto &iter : required_attrs) { - const AscendString name(iter.first.c_str()); - const AscendString type(iter.second.c_str()); - attr_name_types[name] = type; - } - return GRAPH_SUCCESS; -} - -void Operator::InputRegister(const std::string &name) { - InputRegister(name.c_str()); -} - -void Operator::InputRegister(const char_t *name) { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return; - } - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return, "[Check][Param] operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, - REPORT_INNER_ERR_MSG("E18888", "GetOpDescImpl failed, as return nullptr."); - return, "[Get][OpDescImpl] is nullptr."); - (void)operator_impl_->GetOpDescImpl()->AddInputDesc(name, GeTensorDesc()); - operator_impl_->GetOpDescImpl()->AppendIrInput(name, kIrInputRequired); -} - -void Operator::InputRegister(const char_t *name, const char_t *datatype_symbol) { - GE_RETURN_IF_NULL(name, "[Check][Param] Operator name is nullptr."); - GE_RETURN_IF_NULL(datatype_symbol, "[Check][Param] Operator datatype_symbol is nullptr."); - GE_RETURN_IF_NULL(operator_impl_, "[Check][Param] Operator impl is nullptr."); - GE_RETURN_IF_NULL(operator_impl_->GetOpDescImpl(), "[Get][OpDescImpl] is nullptr."); - InputRegister(name); - operator_impl_->GetOpDescImpl()->SetInputDtypeSymbol(name, kIrInputRequired, datatype_symbol); -} - -void Operator::OptionalInputRegister(const std::string &name) { - OptionalInputRegister(name.c_str()); -} - -void Operator::OptionalInputRegister(const char_t *name) { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return; - } - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return, "[Check][Param] operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, - REPORT_INNER_ERR_MSG("E18888", "GetOpDescImpl failed, as return nullptr."); - return, "[Get][OpDescImpl] is nullptr."); - // [No need to verify return value] - (void)operator_impl_->GetOpDescImpl()->AddOptionalInputDesc(name, - GeTensorDesc(GeShape(), FORMAT_RESERVED, DT_UNDEFINED)); - operator_impl_->GetOpDescImpl()->AppendIrInput(name, kIrInputOptional); -} - -void Operator::OptionalInputRegister(const char_t *name, const char_t *datatype_symbol) { - GE_RETURN_IF_NULL(name, "[Check][Param] Operator name is nullptr."); - GE_RETURN_IF_NULL(datatype_symbol, "[Check][Param] Operator datatype_symbol is nullptr."); - GE_RETURN_IF_NULL(operator_impl_, "[Check][Param] Operator impl is nullptr."); - GE_RETURN_IF_NULL(operator_impl_->GetOpDescImpl(), "[Get][OpDescImpl] is nullptr."); - OptionalInputRegister(name); - operator_impl_->GetOpDescImpl()->SetInputDtypeSymbol(name, kIrInputOptional, datatype_symbol); -} - -void Operator::InferFuncRegister(const std::function &func) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return, "[Check][Param] operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, - REPORT_INNER_ERR_MSG("E18888", "GetOpDescImpl failed, as return nullptr."); - return, "[Get][OpDescImpl] is nullptr."); - // [No need to verify return value] - (void)operator_impl_->GetOpDescImpl()->AddInferFunc(func); -} - -void Operator::InferFormatFuncRegister(const std::function &func) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return, "[Check][Param] operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, - REPORT_INNER_ERR_MSG("E18888", "GetOpDescImpl failed, as return nullptr."); - return, "[Get][OpDescImpl] is nullptr."); - // [No need to verify return value] - (void)operator_impl_->GetOpDescImpl()->AddInferFormatFunc(func); -} - -void Operator::VerifierFuncRegister(const std::function &func) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return, "[Check][Param] operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, - REPORT_INNER_ERR_MSG("E18888", "GetOpDescImpl failed, as return nullptr."); - return, "[Get][OpDescImpl] is nullptr."); - // [No need to verify return value] - (void)operator_impl_->GetOpDescImpl()->AddVerifierFunc(func); -} - -void Operator::OutputRegister(const std::string &name) { - OutputRegister(name.c_str()); -} - -void Operator::OutputRegister(const char_t *name) { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return; - } - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return, "[Check][Param] operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, - REPORT_INNER_ERR_MSG("E18888", "GetOpDescImpl failed, as return nullptr."); - return, "[Get][OpDescImpl] is nullptr."); - // [No need to verify return value] - (void)operator_impl_->GetOpDescImpl()->AddOutputDesc(name, GeTensorDesc()); - operator_impl_->GetOpDescImpl()->AppendIrOutput(name, kIrOutputRequired); -} - -void Operator::OutputRegister(const char_t *name, const char_t *datatype_symbol) { - GE_RETURN_IF_NULL(name, "[Check][Param] Operator name is nullptr."); - GE_RETURN_IF_NULL(datatype_symbol, "[Check][Param] Operator datatype_symbol is nullptr."); - GE_RETURN_IF_NULL(operator_impl_, "[Check][Param] Operator impl is nullptr."); - GE_RETURN_IF_NULL(operator_impl_->GetOpDescImpl(), "[Get][OpDescImpl] is nullptr."); - OutputRegister(name); - operator_impl_->GetOpDescImpl()->SetOutputDtypeSymbol(name, kIrOutputRequired, datatype_symbol); -} - -void Operator::DynamicInputRegister(const std::string &name, const uint32_t num, bool is_push_back) { - DynamicInputRegister(name.c_str(), num, is_push_back); -} - -void Operator::DynamicInputRegister(const char_t *name, const uint32_t num, bool is_push_back) { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return; - } - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return, "[Check][Param] operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, - REPORT_INNER_ERR_MSG("E18888", "GetOpDescImpl failed, as return nullptr."); - return, "[Get][OpDescImpl] is nullptr."); - GE_CHK_BOOL_EXEC(AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_INPUT_TD_NUM(std::string(name)), - static_cast(num)), - REPORT_INNER_ERR_MSG("E18888", "set attr %s to op:%s failed.", name, - operator_impl_->GetOpDescImpl()->GetName().c_str()); - return, "[Set][Int] %s to op:%s failed", name, operator_impl_->GetOpDescImpl()->GetName().c_str()); - (void)operator_impl_->GetOpDescImpl()->AddDynamicInputDesc(name, num, is_push_back); - if (num == 0U) { - operator_impl_->GetOpDescImpl()->AppendIrInput(name, kIrInputDynamic); - } -} - -void Operator::DynamicInputRegisterByIndex(const std::string &name, const uint32_t num, size_t index) { - DynamicInputRegisterByIndex(name.c_str(), num, index); -} - -void Operator::DynamicInputRegisterByIndex(const char_t *name, const uint32_t num, size_t index) { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return; - } - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return, "[Check][Param] operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, - REPORT_INNER_ERR_MSG("E18888", "GetOpDescImpl failed, as return nullptr."); - return, "[Get][OpDescImpl] is nullptr."); - (void)(operator_impl_->GetOpDescImpl()->AddDynamicInputDescByIndex(name, num, index)); -} - -int32_t Operator::GetDynamicInputNum(const std::string &name) const { - return GetDynamicInputNum(name.c_str()); -} - -int32_t Operator::GetDynamicInputNum(const char_t *name) const { - GE_CHK_BOOL_EXEC(name != nullptr, REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - return 0, "[Check][Param] Operator name is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return 0, "[Check][Param] operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, - REPORT_INNER_ERR_MSG("E18888", "GetOpDescImpl failed, as return nullptr."); - return 0, "[Get][OpDescImpl] is nullptr."); - const std::string op_name = name; - int32_t num = 0; - GE_CHK_BOOL_EXEC(AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_INPUT_TD_NUM(op_name), num), - REPORT_INNER_ERR_MSG("E18888", "get attr %s failed, op:%s.", op_name.c_str(), - operator_impl_->GetOpDescImpl()->GetName().c_str()); - return num, "[Get][Int] %s failed", op_name.c_str()); - return num; -} - -void Operator::DynamicOutputRegister(const std::string &name, const uint32_t num, bool is_push_back) { - DynamicOutputRegister(name.c_str(), num, is_push_back); -} - -void Operator::DynamicOutputRegister(const char_t *name, const uint32_t num, bool is_push_back) { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return; - } - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return, "[Check][Param] operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, - REPORT_INNER_ERR_MSG("E18888", "GetOpDescImpl failed, as return nullptr."); - return, "[Get][OpDescImpl] is nullptr."); - GE_CHK_BOOL_EXEC(AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(std::string(name)), - static_cast(num)), - REPORT_INNER_ERR_MSG("E18888", "set attr %s to op:%s failed.", name, - operator_impl_->GetOpDescImpl()->GetName().c_str()); - return, "[Set][Int] %s to op:%s failed", name, - operator_impl_->GetOpDescImpl()->GetName().c_str()); - (void)operator_impl_->GetOpDescImpl()->AddDynamicOutputDesc(name, num, is_push_back); - if (num == 0U) { - operator_impl_->GetOpDescImpl()->AppendIrOutput(name, kIrOutputDynamic); - } -} - -int32_t Operator::GetDynamicOutputNum(const std::string &name) const { - return GetDynamicOutputNum(name.c_str()); -} - -int32_t Operator::GetDynamicOutputNum(const char_t *name) const { - GE_CHK_BOOL_EXEC(name != nullptr, REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - return 0, "[Check][Param] Operator name is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return 0, "[Check][Param] operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, - REPORT_INNER_ERR_MSG("E18888", "GetOpDescImpl failed, as return nullptr."); - return 0, "[Get][OpDescImpl] is nullptr."); - const std::string op_name = name; - int32_t num = 0; - GE_CHK_BOOL_EXEC(AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(op_name), num), - REPORT_INNER_ERR_MSG("E18888", "get attr %s failed, op:%s.", op_name.c_str(), - operator_impl_->GetOpDescImpl()->GetName().c_str()); - return num, "[Get][Init] %s failed", op_name.c_str()); - return num; -} - -void Operator::RequiredAttrRegister(const std::string &name) { - RequiredAttrRegister(name.c_str()); -} - -void Operator::RequiredAttrRegister(const char_t *name) { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return; - } - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return, "[Check][Param] operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, - REPORT_INNER_ERR_MSG("E18888", "GetOpDescImpl failed, as return nullptr."); - return, "[Get][OpDescImpl] is nullptr."); - (void)(operator_impl_->GetOpDescImpl()->AddRequiredAttr(name)); - operator_impl_->GetOpDescImpl()->AppendIrAttrName(name); -} - -void Operator::RequiredAttrWithTypeRegister(const char_t *name, const char_t *type) { - if ((name == nullptr) || (operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] param is nullptr."); - return; - } - std::string ir_attr_type; - const auto iter = kIrAttrTypesMap.find(type); - if (iter != kIrAttrTypesMap.end()) { - ir_attr_type = iter->second; - } else { - GELOGW("[Check][Param] unsupport attr type %s from ir.", type); - } - (void) (operator_impl_->GetOpDescImpl()->AddRequiredAttrWithType(name, ir_attr_type)); - GELOGD("Save required attr name %s with type %s to op %s %s", name, ir_attr_type.c_str(), - operator_impl_->GetOpDescImpl()->GetNamePtr(), operator_impl_->GetOpDescImpl()->GetTypePtr()); - operator_impl_->GetOpDescImpl()->AppendIrAttrName(name); -} - -void Operator::DataTypeRegister(const char_t *datatype_symbol, const TensorType &type_range) { - GE_RETURN_IF_NULL(datatype_symbol, "[Check][Param] Operator datatype_symbol is nullptr."); - GE_RETURN_IF_NULL(operator_impl_, "[Check][Param] Operator impl is nullptr."); - GE_RETURN_IF_NULL(operator_impl_->GetOpDescImpl(), "[Get][OpDescImpl] is nullptr."); - operator_impl_->GetOpDescImpl()->DeclareDtypeSymbol(datatype_symbol, type_range); -} - -void Operator::DataTypeRegister(const char_t *datatype_symbol, const ListTensorType &list_type_range) { - GE_RETURN_IF_NULL(datatype_symbol, "[Check][Param] Operator datatype_symbol is nullptr."); - GE_RETURN_IF_NULL(operator_impl_, "[Check][Param] Operator impl is nullptr."); - GE_RETURN_IF_NULL(operator_impl_->GetOpDescImpl(), "[Get][OpDescImpl] is nullptr."); - operator_impl_->GetOpDescImpl()->DeclareDtypeSymbol(datatype_symbol, list_type_range); -} - -void Operator::DataTypeRegister(const char_t *datatype_symbol, const Promote &promote_rule) { - GE_RETURN_IF_NULL(datatype_symbol, "[Check][Param] Operator datatype_symbol is nullptr."); - GE_RETURN_IF_NULL(operator_impl_, "[Check][Param] Operator impl is nullptr."); - GE_RETURN_IF_NULL(operator_impl_->GetOpDescImpl(), "[Get][OpDescImpl] is nullptr."); - operator_impl_->GetOpDescImpl()->DeclareDtypeSymbol(datatype_symbol, promote_rule); -} - -graphStatus Operator::VerifyAll() { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return GRAPH_FAILED, "[Check][Param] operator impl is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, - REPORT_INNER_ERR_MSG("E18888", "GetOpDescImpl failed, as return nullptr."); - return GRAPH_FAILED, "[Get][OpDescImpl] is nullptr."); - - // Check all inputs defined - for (const std::string &iname : operator_impl_->GetOpDescImpl()->GetAllInputNames()) { - GE_CHK_BOOL_RET_STATUS(operator_impl_->GetOpDescImpl()->IsOptionalInput(iname) || operator_impl_->InputIsSet(iname), - GRAPH_FAILED, "[Check][Param] operator input %s is not linked.", iname.c_str()); - const std::vector& ishape = operator_impl_->GetOpDescImpl()->GetInputDesc(iname).GetShape().GetDims(); - for (const int64_t &dim : ishape) { - GE_CHK_BOOL_RET_STATUS(dim > 0, GRAPH_FAILED, - "[Check][Param] operator input %s shape contains negative or zero dimension, " - "node:%s, index:%d.", - iname.c_str(), operator_impl_->GetOpDescImpl()->GetName().c_str(), - operator_impl_->GetOpDescImpl()->GetInputIndexByName(iname)); - } - } - // Check all attributes defined - const auto all_attributes = operator_impl_->GetOpDescImpl()->GetAllAttrs(); - for (const auto &name : operator_impl_->GetOpDescImpl()->GetAllAttrNames()) { - GE_CHK_BOOL_RET_STATUS(all_attributes.find(name) != all_attributes.end(), GRAPH_FAILED, - "[Check][Param] operator attribute %s is empty.", name.c_str()); - } - - return GRAPH_SUCCESS; -} - -std::string Operator::GetOpType() const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return "Data", "[Check][Param] operator impl is nullptr."); - return OperatorImpl::GetOpDesc(*this)->GetType(); -} - -graphStatus Operator::GetOpType(AscendString &type) const { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid"); - return GRAPH_FAILED, "[Check][Param] Operator impl is nullptr."); - const std::string op_type = OperatorImpl::GetOpDesc(*this)->GetType(); - type = op_type.c_str(); - return GRAPH_SUCCESS; -} - -Operator &Operator::SetInput(const std::string &dst_name, uint32_t dst_index, const ge::Operator &src_oprt) { - return SetInput(dst_name.c_str(), dst_index, src_oprt); -} - -Operator &Operator::SetInput(const char_t *dst_name, uint32_t dst_index, const ge::Operator &src_oprt) { - if (dst_name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param dst_name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator dst_name is nullptr."); - return *this; - } - std::string dynamic_dst_name = dst_name + std::to_string(dst_index); - return SetInput(dynamic_dst_name.c_str(), src_oprt); -} - -Operator &Operator::SetInput(const std::string &dst_name, uint32_t dst_index, const ge::Operator &src_oprt, - const std::string &name) { - return SetInput(dst_name.c_str(), dst_index, src_oprt, name.c_str()); -} - -Operator &Operator::SetInput(const char_t *dst_name, uint32_t dst_index, const ge::Operator &src_oprt, - const char_t *name) { - if (dst_name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param dst_name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator dst_name is nullptr."); - return *this; - } - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return *this; - } - std::string dynamic_dst_name = dst_name + std::to_string(dst_index); - return SetInput(dynamic_dst_name.c_str(), src_oprt, name); -} - -OperatorImplPtr Operator::GetOperatorImplPtr() const { return operator_impl_; } - -void Operator::BreakConnect() const { - if (operator_impl_ == nullptr) { - GELOGW("[Check][Param] operator impl is nullptr"); - return; - } - operator_impl_->ClearInputLinks(); - operator_impl_->ClearOutputLinks(); - OperatorKeeper::GetInstance().CheckOutOperator(operator_impl_); -} - -OP_ATTR_SET_IMP(int64_t, Int) -OP_ATTR_SET_IMP(int32_t, Int) -OP_ATTR_SET_IMP(uint32_t, Int) -OP_ATTR_GET_IMP(int64_t &, Int) -OP_ATTR_GET_IMP(int32_t &, Int) -OP_ATTR_GET_IMP(uint32_t &, Int) -OP_ATTR_SET_IMP(const std::vector &, ListInt) -OP_ATTR_SET_IMP(const std::vector &, ListInt) -OP_ATTR_SET_IMP(const std::vector &, ListInt) -OP_ATTR_SET_IMP(std::initializer_list &&, ListInt) -OP_ATTR_GET_IMP(std::vector &, ListInt) -OP_ATTR_GET_IMP(std::vector &, ListInt) -OP_ATTR_GET_IMP(std::vector &, ListInt) -OP_ATTR_GET_IMP(std::vector> &, ListListInt) -OP_ATTR_SET_IMP(const std::vector> &, ListListInt) - -OP_ATTR_SET_IMP(float32_t, Float) -OP_ATTR_GET_IMP(float32_t &, Float) -OP_ATTR_SET_IMP(const std::vector &, ListFloat) -OP_ATTR_GET_IMP(std::vector &, ListFloat) - -OP_ATTR_SET_IMP(bool, Bool) -OP_ATTR_GET_IMP(bool &, Bool) -OP_ATTR_SET_IMP(const std::vector &, ListBool) -OP_ATTR_GET_IMP(std::vector &, ListBool) - -OP_ATTR_SET_IMP(const ge::NamedAttrs &, NamedAttrs) -OP_ATTR_GET_IMP(ge::NamedAttrs &, NamedAttrs) -OP_ATTR_SET_IMP(const std::vector &, ListNamedAttrs) -OP_ATTR_GET_IMP(std::vector &, ListNamedAttrs) - -OP_ATTR_REG_IMP(int64_t, Int) -OP_ATTR_REG_IMP(const std::vector &, ListInt) -OP_ATTR_REG_IMP(float32_t, Float) -OP_ATTR_REG_IMP(const std::vector &, ListFloat) -OP_ATTR_REG_IMP(bool, Bool) -OP_ATTR_REG_IMP(const std::vector &, ListBool) -OP_ATTR_REG_IMP(const std::vector> &, ListListInt) -OP_ATTR_REG_IMP(const ge::NamedAttrs &, NamedAttrs) -OP_ATTR_REG_IMP(const std::vector &, ListNamedAttrs) - -EDGE_ATTR_SET_IMP(int64_t, Int) -EDGE_ATTR_GET_IMP(int64_t &, Int) -EDGE_ATTR_SET_IMP(int32_t, Int) -EDGE_ATTR_GET_IMP(int32_t &, Int) -EDGE_ATTR_SET_IMP(uint32_t, Int) -EDGE_ATTR_GET_IMP(uint32_t &, Int) -EDGE_ATTR_SET_IMP(bool, Bool) -EDGE_ATTR_GET_IMP(bool &, Bool) -EDGE_ATTR_SET_IMP(float32_t, Float) -EDGE_ATTR_GET_IMP(float32_t &, Float) -EDGE_ATTR_SET_IMP(const std::vector &, ListInt) -EDGE_ATTR_GET_IMP(std::vector &, ListInt) -EDGE_ATTR_SET_IMP(const std::vector &, ListInt) -EDGE_ATTR_GET_IMP(std::vector &, ListInt) -EDGE_ATTR_SET_IMP(const std::vector &, ListInt) -EDGE_ATTR_GET_IMP(std::vector &, ListInt) -EDGE_ATTR_SET_IMP(const std::vector &, ListBool) -EDGE_ATTR_GET_IMP(std::vector &, ListBool) -EDGE_ATTR_SET_IMP(const std::vector &, ListFloat) -EDGE_ATTR_GET_IMP(std::vector &, ListFloat) - -void Operator::AttrRegister(const std::string &name, const std::string &attr_value) { - AttrRegister(name.c_str(), AscendString(attr_value.c_str())); -} - -void Operator::AttrRegister(const std::string &name, const std::vector &attr_value) { - std::vector attr_values; - (void)std::transform(attr_value.begin(), attr_value.end(), attr_values.begin(), - [](const std::string &val) { - return AscendString(val.c_str()); - }); - AttrRegister(name.c_str(), attr_values); -} - -void Operator::AttrRegister(const std::string &name, int64_t attr_value) { - return AttrRegister(name.c_str(), attr_value); -} - -void Operator::AttrRegister(const std::string &name, const vector &attr_value) { - return AttrRegister(name.c_str(), attr_value); -} - -void Operator::AttrRegister(const std::string &name, float32_t attr_value) { - return AttrRegister(name.c_str(), attr_value); -} - -void Operator::AttrRegister(const std::string &name, const vector &attr_value) { - return AttrRegister(name.c_str(), attr_value); -} - -void Operator::AttrRegister(const char_t *name, const char_t *attr_value) { - const AscendString op_attr_value = attr_value; - return AttrRegister(name, op_attr_value); -} - -void Operator::AttrRegister(const std::string &name, bool attr_value) { - return AttrRegister(name.c_str(), attr_value); -} - -void Operator::AttrRegister(const std::string &name, const vector &attr_value) { - return AttrRegister(name.c_str(), attr_value); -} - -void Operator::AttrRegister(const std::string &name, const vector> &attr_value) { - return AttrRegister(name.c_str(), attr_value); -} - -void Operator::AttrRegister(const std::string &name, const NamedAttrs &attr_value) { - return AttrRegister(name.c_str(), attr_value); -} - -void Operator::AttrRegister(const std::string &name, const vector &attr_value) { - return AttrRegister(name.c_str(), attr_value); -} - -void Operator::AttrRegister(const std::string &name, const AscendString &attr_value) { - if (attr_value.GetString() == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Attr %s register param is invalid.", name.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Attr %s register param is invalid.", name.c_str()); - return; - } - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name.c_str()); - return; - } - const std::string str_attr_value = attr_value.GetString(); - if (!AttrUtils::SetStr(operator_impl_->GetOpDescImpl(), name, str_attr_value)) { - GELOGW("[Register][Attr] Reg attr name %s unsuccessful", name.c_str()); - } -} - -void Operator::AttrRegister(const char_t *name, const AscendString &attr_value) { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return; - } - const std::string op_name = name; - if (attr_value.GetString() == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Attr %s register param is invalid.", name); - GELOGE(GRAPH_FAILED, "[Check][Param] Attr %s register param is invalid.", name); - return; - } - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] operator impl is nullptr, name %s.", op_name.c_str()); - return; - } - const std::string str_attr_value = attr_value.GetString(); - if (!AttrUtils::SetStr(operator_impl_->GetOpDescImpl(), op_name, str_attr_value)) { - GELOGW("[Register][Attr] Reg attr name %s unsuccessful", op_name.c_str()); - } - operator_impl_->GetOpDescImpl()->AppendIrAttrName(name); -} - -void Operator::AttrRegister(const std::string &name, const std::vector &attr_value) { - return AttrRegister(name.c_str(), attr_value); -} - -void Operator::AttrRegister(const char_t *name, const std::vector &attr_value) { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return; - } - std::vector str_attr_values; - for (auto &val : attr_value) { - if (val.GetString() == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Attr %s register value is invalid.", name); - GELOGE(GRAPH_FAILED, "Attr %s register value is invalid.", name); - return; - } - str_attr_values.emplace_back(val.GetString()); - } - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name); - return; - } - if (!AttrUtils::SetListStr(operator_impl_->GetOpDescImpl(), name, str_attr_values)) { - GELOGW("[Register][Attr] Reg attr name %s unsuccessful", name); - } - operator_impl_->GetOpDescImpl()->AppendIrAttrName(name); -} - -void Operator::AttrRegister(const char_t *name, const AttrValue &attr_value) { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return; - } - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name); - return; - } - (void) SetAttr(name, attr_value); - operator_impl_->GetOpDescImpl()->AppendIrAttrName(name); -} - -Operator &Operator::SetAttr(const std::string &name, const std::string &attr_value) { - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name.c_str()); - return *this; - } - if (!AttrUtils::SetStr(operator_impl_->GetOpDescImpl(), name, attr_value)) { - GELOGW("[Set][Attr] Set attr name %s unsuccessful", name.c_str()); - } - return *this; -} - -graphStatus Operator::GetAttr(const std::string &name, std::string &attr_value) const { - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name.c_str()); - return GRAPH_FAILED; - } - if (!AttrUtils::GetStr(operator_impl_->GetOpDescImpl(), name, attr_value)) { - GELOGW("[Get][Attr] Get attr name %s unsuccessful", name.c_str()); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -Operator &Operator::SetAttr(const std::string &name, const std::vector &attr_value) { - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name.c_str()); - return *this; - } - if (!AttrUtils::SetListStr(operator_impl_->GetOpDescImpl(), name, attr_value)) { - GELOGW("[Set][Attr] Set attr name %s unsuccessful", name.c_str()); - } - return *this; -} - -graphStatus Operator::GetAttr(const std::string &name, std::vector &attr_value) const { - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name.c_str()); - return GRAPH_FAILED; - } - if (!AttrUtils::GetListStr(operator_impl_->GetOpDescImpl(), name, attr_value)) { - GELOGW("[Get][Attr] Get attr name %s unsuccessful", name.c_str()); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -Operator &Operator::SetAttr(const char_t *name, const char_t *attr_value) { - if ((name == nullptr) || (attr_value == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr or attr_value is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator input parameters is nullptr."); - return *this; - } - - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name); - return *this; - } - const std::string op_name = name; - const std::string op_attr_value = attr_value; - if (!AttrUtils::SetStr(operator_impl_->GetOpDescImpl(), op_name, op_attr_value)) { - GELOGW("[Set][Attr] Set attr name %s unsuccessful", op_name.c_str()); - } - return *this; -} - -Operator &Operator::SetInputAttr(const int32_t index, const char_t *name, const char_t *attr_value) { - if ((name == nullptr) || (attr_value == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "Operator parameters is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator input parameters is nullptr."); - return *this; - } - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name); - return *this; - } - const std::string op_name = name; - const std::string op_attr_value = attr_value; - auto tensor = operator_impl_->MutableInputDesc(index); - if (tensor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Get index[%d] of op[%s] failed, check invalid", index, op_name.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Get index[%d] of op[%s] failed, check invalid", index, op_name.c_str()); - return *this; - } - if (!AttrUtils::SetStr(tensor, op_name, op_attr_value)) { - GELOGW("[Set][Attr] Set input[%d] attr name %s unsuccessful", index, op_name.c_str()); - } - return *this; -} - -Operator &Operator::SetInputAttr(const char_t *dst_name, const char_t *name, const char_t *attr_value) { - if ((dst_name == nullptr) || (name == nullptr) || (attr_value == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "Operator parameters is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator input parameters is nullptr."); - return *this; - } - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name); - return *this; - } - const std::string op_name = name; - const std::string op_attr_value = attr_value; - auto tensor = operator_impl_->MutableInputDesc(dst_name); - if (tensor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Get input[%s] of op[%s] failed, check invalid", dst_name, op_name.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Get index[%s] of op[%s] failed, check invalid", dst_name, op_name.c_str()); - return *this; - } - if (!AttrUtils::SetStr(tensor, op_name, op_attr_value)) { - GELOGW("[Set][Attr] Set input[%s] attr name %s unsuccessful", dst_name, op_name.c_str()); - } - return *this; -} - -Operator &Operator::SetOutputAttr(const int32_t index, const char_t *name, const char_t *attr_value) { - if ((name == nullptr) || (attr_value == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "Operator parameters is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator parameters is nullptr."); - return *this; - } - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name); - return *this; - } - const std::string op_name = name; - const std::string op_attr_value = attr_value; - auto tensor = operator_impl_->MutableOutputDesc(index); - if (tensor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Get index[%d] of op[%s] failed, check invalid", index, op_name.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Get index[%d] of op[%s] failed, check invalid", index, op_name.c_str()); - return *this; - } - if (!AttrUtils::SetStr(tensor, op_name, op_attr_value)) { - GELOGW("[Set][Attr] Set output[%d] attr name %s unsuccessful", index, op_name.c_str()); - } - return *this; -} - -Operator &Operator::SetOutputAttr(const char_t *dst_name, const char_t *name, const char_t *attr_value) { - if ((dst_name == nullptr) || (name == nullptr) || (attr_value == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "Operator parameters is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator output parameters is nullptr."); - return *this; - } - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name); - return *this; - } - const std::string op_name = name; - const std::string op_attr_value = attr_value; - auto tensor = operator_impl_->MutableOutputDesc(dst_name); - if (tensor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Get output[%s] of op[%s] failed, check invalid", dst_name, op_name.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Get output[%s] of op[%s] failed, check invalid", dst_name, op_name.c_str()); - return *this; - } - if (!AttrUtils::SetStr(tensor, op_name, op_attr_value)) { - GELOGW("[Set][Attr] Set output[%s] attr name %s unsuccessful", dst_name, op_name.c_str()); - } - return *this; -} - -Operator &Operator::SetInputAttr(const int32_t index, const char_t *name, const AscendString &attr_value) { - if ((name == nullptr) || (attr_value.GetString() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "Operator parameters is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator input parameters is nullptr."); - return *this; - } - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name); - return *this; - } - const std::string op_name = name; - const std::string op_attr_value = attr_value.GetString(); - auto tensor = operator_impl_->MutableInputDesc(index); - if (tensor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Get index[%d] of op[%s] failed, check invalid", index, op_name.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Get index[%d] of op[%s] failed, check invalid", index, op_name.c_str()); - return *this; - } - if (!AttrUtils::SetStr(tensor, op_name, op_attr_value)) { - GELOGW("[Set][Attr] Set input[%d] attr name %s unsuccessful", index, op_name.c_str()); - } - return *this; -} - -Operator &Operator::SetInputAttr(const char_t *dst_name, const char_t *name, const AscendString &attr_value) { - if ((dst_name == nullptr) || (name == nullptr) || (attr_value.GetString() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "Operator parameters is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator input parameters is nullptr."); - return *this; - } - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name); - return *this; - } - const std::string op_name = name; - const std::string op_attr_value = attr_value.GetString(); - auto tensor = operator_impl_->MutableInputDesc(dst_name); - if (tensor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Get input[%s] of op[%s] failed, check invalid", dst_name, op_name.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Get index[%s] of op[%s] failed, check invalid", dst_name, op_name.c_str()); - return *this; - } - if (!AttrUtils::SetStr(tensor, op_name, op_attr_value)) { - GELOGW("[Set][Attr] Set input[%s] attr name %s unsuccessful", dst_name, op_name.c_str()); - } - return *this; -} - -Operator &Operator::SetInputAttr(const int32_t index, const char_t *name, const std::vector &attr_value) { - if (operator_impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Operator impl is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, check invalid"); - return *this; - } - if (SetTensorAttr(operator_impl_->MutableInputDesc(index), name, attr_value) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Set input attr failed, input index[%d], please check.", index); - GELOGE(GRAPH_FAILED, "[Set][InputAttr] failed, input index[%d], please check.", index); - return *this; - } - return *this; -} - -Operator &Operator::SetOutputAttr(const int32_t index, const char_t *name, - const std::vector &attr_value) { - if (operator_impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Operator impl is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, check invalid"); - return *this; - } - if (SetTensorAttr(operator_impl_->MutableOutputDesc(index), name, attr_value) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Set output attr failed, output index[%d], please check.", index); - GELOGE(GRAPH_FAILED, "[Set][OutputAttr] failed, output index[%d], please check.", index); - return *this; - } - return *this; -} - -Operator &Operator::SetInputAttr(const char_t *dst_name, const char_t *name, - const std::vector &attr_value) { - if ((operator_impl_ == nullptr) || (dst_name == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "Operator impl or dst_name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl or dst_name is nullptr, check invalid"); - return *this; - } - if (SetTensorAttr(operator_impl_->MutableInputDesc(dst_name), name, attr_value) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Set input attr failed, input dst_name[%s], please check.", dst_name); - GELOGE(GRAPH_FAILED, "[Set][InputAttr] failed, input dst_name[%s], please check.", dst_name); - return *this; - } - return *this; -} - -Operator &Operator::SetOutputAttr(const char_t *dst_name, const char_t *name, - const std::vector &attr_value) { - if ((operator_impl_ == nullptr) || (dst_name == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "Operator impl or dst_name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl or dst_name is nullptr, check invalid"); - return *this; - } - if (SetTensorAttr(operator_impl_->MutableOutputDesc(dst_name), name, attr_value) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Set output attr failed, output dst_name[%s], please check.", dst_name); - GELOGE(GRAPH_FAILED, "[Set][OutputAttr] failed, output dst_name[%s], please check.", dst_name); - return *this; - } - return *this; -} - -graphStatus Operator::GetInputAttr(const char_t *dst_name, const char_t *name, - std::vector &attr_value) const { - if ((operator_impl_ == nullptr) || (dst_name == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "Operator impl or dst_name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl or dst_name is nullptr, check invalid"); - return GRAPH_FAILED; - }; - GE_ASSERT_GRAPH_SUCCESS(GetTensorAttr(operator_impl_->GetInputDesc(dst_name), name, attr_value), - "[Get][InputAttr] failed, dst_name[%s].", dst_name); - return GRAPH_SUCCESS; -} - -graphStatus Operator::GetOutputAttr(const char_t *dst_name, const char_t *name, - std::vector &attr_value) const { - if ((operator_impl_ == nullptr) || (dst_name == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "Operator impl or dst_name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl or dst_name is nullptr, check invalid"); - return GRAPH_FAILED; - } - GE_ASSERT_GRAPH_SUCCESS(GetTensorAttr(operator_impl_->GetOutputDesc(dst_name), name, attr_value), - "[Get][OutputAttr] failed, dst_name[%s].", dst_name); - return GRAPH_SUCCESS; -} - -graphStatus Operator::GetInputAttr(const int32_t index, const char_t *name, - std::vector &attr_value) const { - if (operator_impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Operator impl is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, check invalid"); - return GRAPH_FAILED; - } - GE_ASSERT_GRAPH_SUCCESS(GetTensorAttr(operator_impl_->GetInputDesc(index), name, attr_value), - "[Get][InputAttr] failed, index[%d].", index); - return GRAPH_SUCCESS; -} - -graphStatus Operator::GetOutputAttr(const int32_t index, const char_t *name, - std::vector &attr_value) const { - if (operator_impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Operator impl is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, check invalid"); - return GRAPH_FAILED; - } - GE_ASSERT_GRAPH_SUCCESS(GetTensorAttr(operator_impl_->GetOutputDesc(index), name, attr_value), - "[Get][OutputAttr] failed, index[%d].", index); - return GRAPH_SUCCESS; -} - -Operator &Operator::SetOutputAttr(const int32_t index, const char_t *name, const AscendString &attr_value) { - if ((name == nullptr) || (attr_value.GetString() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "Operator parameters is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator output parameters is nullptr."); - return *this; - } - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name); - return *this; - } - const std::string op_name = name; - const std::string op_attr_value = attr_value.GetString(); - auto tensor = operator_impl_->MutableOutputDesc(index); - if (tensor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Get index[%d] of op[%s] failed, check invalid", index, op_name.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Get index[%d] of op[%s] failed, check invalid", index, op_name.c_str()); - return *this; - } - if (!AttrUtils::SetStr(tensor, op_name, op_attr_value)) { - GELOGW("[Set][Attr] Set output[%d] attr name %s unsuccessful", index, op_name.c_str()); - } - return *this; -} - -Operator &Operator::SetOutputAttr(const char_t *dst_name, const char_t *name, const AscendString &attr_value) { - if ((dst_name == nullptr) || (name == nullptr) || (attr_value.GetString() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "Operator parameters is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator output parameters is nullptr."); - return *this; - } - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name); - return *this; - } - const std::string op_name = name; - const std::string op_attr_value = attr_value.GetString(); - auto tensor = operator_impl_->MutableOutputDesc(dst_name); - if (tensor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Get output[%s] of op[%s] failed, check invalid", dst_name, op_name.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Get index[%s] of op[%s] failed, check invalid", dst_name, op_name.c_str()); - return *this; - } - if (!AttrUtils::SetStr(tensor, op_name, op_attr_value)) { - GELOGW("[Set][Attr] Set output[%s] attr name %s unsuccessful", dst_name, op_name.c_str()); - } - return *this; -} - -graphStatus Operator::GetOutputAttr(const char_t *dst_name, const char_t *name, AscendString &attr_value) const { - if ((dst_name == nullptr) || (name == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "Operator name parameters is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator output parameters is nullptr."); - return GRAPH_FAILED; - } - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name); - return GRAPH_FAILED; - } - const std::string op_name = name; - std::string op_attr_value; - auto tensor = operator_impl_->MutableOutputDesc(dst_name); - if (tensor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Get output[%s] of op[%s] failed, check invalid", dst_name, op_name.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Get output[%s] of op[%s] failed, check invalid", dst_name, op_name.c_str()); - return GRAPH_FAILED; - } - if (!AttrUtils::GetStr(tensor, op_name, op_attr_value)) { - GELOGW("[Get][Attr] Get output[%s] attr name %s unsuccessful", dst_name, op_name.c_str()); - return GRAPH_FAILED; - } - attr_value = AscendString(op_attr_value.c_str()); - return GRAPH_SUCCESS; -} - -graphStatus Operator::GetInputAttr(const char_t *dst_name, const char_t *name, AscendString &attr_value) const { - if ((dst_name == nullptr) || (name == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "Operator name parameters is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator input parameters is nullptr."); - return GRAPH_FAILED; - } - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name); - return GRAPH_FAILED; - } - const std::string op_name = name; - std::string op_attr_value; - auto tensor = operator_impl_->MutableInputDesc(dst_name); - if (tensor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Get input[%s] of op[%s] failed, check invalid", dst_name, op_name.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Get input[%s] of op[%s] failed, check invalid", dst_name, op_name.c_str()); - return GRAPH_FAILED; - } - if (!AttrUtils::GetStr(tensor, op_name, op_attr_value)) { - GELOGW("[Get][Attr] Get input[%s] attr name %s unsuccessful", dst_name, op_name.c_str()); - return GRAPH_FAILED; - } - attr_value = AscendString(op_attr_value.c_str()); - return GRAPH_SUCCESS; -} - -graphStatus Operator::GetInputAttr(const int32_t index, const char_t *name, AscendString &attr_value) const { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Operator name parameters is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator input parameters is nullptr."); - return GRAPH_FAILED; - } - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name); - return GRAPH_FAILED; - } - const std::string op_name = name; - std::string op_attr_value; - auto tensor = operator_impl_->MutableInputDesc(index); - if (tensor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Get input[%d] of op[%s] failed, check invalid", index, op_name.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Get input[%d] of op[%s] failed, check invalid", index, op_name.c_str()); - return GRAPH_FAILED; - } - if (!AttrUtils::GetStr(tensor, op_name, op_attr_value)) { - GELOGW("[Get][Attr] Get input[%d] attr name %s unsuccessful", index, op_name.c_str()); - return GRAPH_FAILED; - } - attr_value = AscendString(op_attr_value.c_str()); - return GRAPH_SUCCESS; -} - -graphStatus Operator::GetOutputAttr(const int32_t index, const char_t *name, AscendString &attr_value) const { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Operator parameters is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator output parameters is nullptr."); - return GRAPH_FAILED; - } - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name); - return GRAPH_FAILED; - } - const std::string op_name = name; - std::string op_attr_value; - auto tensor = operator_impl_->MutableOutputDesc(index); - if (tensor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Get output[%d] of op[%s] failed, check invalid", index, op_name.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Get output[%d] of op[%s] failed, check invalid", index, op_name.c_str()); - return GRAPH_FAILED; - } - if (!AttrUtils::GetStr(tensor, op_name, op_attr_value)) { - GELOGW("[Get][Attr] Get output[%d] attr name %s unsuccessful", index, op_name.c_str()); - return GRAPH_FAILED; - } - attr_value = AscendString(op_attr_value.c_str()); - return GRAPH_SUCCESS; -} - -Operator &Operator::SetAttr(const char_t *name, const AscendString &attr_value) { - if ((name == nullptr) || (attr_value.GetString() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "Operator input parameters is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator input parameters is nullptr."); - return *this; - } - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name); - return *this; - } - const std::string op_name = name; - const std::string op_attr_value = attr_value.GetString(); - if (!AttrUtils::SetStr(operator_impl_->GetOpDescImpl(), op_name, op_attr_value)) { - GELOGW("[Set][Attr] Set attr name %s unsuccessful", op_name.c_str()); - } - return *this; -} - -graphStatus Operator::GetAttr(const char_t *name, AscendString &attr_value) const { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Operator input parameters name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator input parameters is nullptr."); - return GRAPH_FAILED; - } - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name); - return GRAPH_FAILED; - } - const std::string op_name = name; - std::string op_attr_value; - if (!AttrUtils::GetStr(operator_impl_->GetOpDescImpl(), op_name, op_attr_value)) { - GELOGW("[Get][Attr] Get attr name %s unsuccessful", op_name.c_str()); - return GRAPH_FAILED; - } - attr_value = AscendString(op_attr_value.c_str()); - return GRAPH_SUCCESS; -} - -Operator &Operator::SetAttr(const char_t *name, const std::vector &attr_values) { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Operator input parameters name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return *this; - } - std::vector op_attr_values; - for (auto &attr_value : attr_values) { - if (attr_value.GetString() == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Operator ascend std::string name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator ascend std::string name is nullptr."); - return *this; - } - op_attr_values.emplace_back(attr_value.GetString()); - } - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name); - return *this; - } - const std::string op_name = name; - if (!AttrUtils::SetListStr(operator_impl_->GetOpDescImpl(), op_name, op_attr_values)) { - GELOGW("[Set][Attr] Set attr name %s unsuccessful", op_name.c_str()); - } - return *this; -} - -graphStatus Operator::GetAttr(const char_t *name, std::vector &attr_values) const { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Operator input parameters name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return GRAPH_FAILED; - } - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name); - return GRAPH_FAILED; - } - const std::string op_name = name; - std::vector op_attr_values; - if (!AttrUtils::GetListStr(operator_impl_->GetOpDescImpl(), op_name, op_attr_values)) { - GELOGW("[Get][Attr] Get attr name %s unsuccessful", op_name.c_str()); - return GRAPH_FAILED; - } - for (auto &op_attr_value : op_attr_values) { - attr_values.emplace_back(AscendString(op_attr_value.c_str())); - } - return GRAPH_SUCCESS; -} - -Operator &Operator::SetAttr(const std::string &name, const Tensor &attr_value) { - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] operator impl is nullptr, name %s.", name.c_str()); - return *this; - } - const GeTensor tensor = TensorAdapter::AsGeTensor(attr_value); - if (!AttrUtils::SetTensor(operator_impl_->GetOpDescImpl(), name, tensor)) { - GELOGW("[Set][Attr] Set attr name %s unsuccessful", name.c_str()); - } - return *this; -} - -Operator &Operator::SetAttr(const char_t *name, const Tensor &attr_value) { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return *this; - } - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] operator impl is nullptr, name %s.", name); - return *this; - } - const std::string op_name = name; - const GeTensor tensor = TensorAdapter::AsGeTensor(attr_value); - if (!AttrUtils::SetTensor(operator_impl_->GetOpDescImpl(), op_name, tensor)) { - GELOGW("[Set][Attr] Set attr name %s unsuccessful", op_name.c_str()); - } - return *this; -} - -Operator &Operator::SetAttr(const std::string &name, const std::vector &attr_value) { - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] operator impl is nullptr, name %s.", name.c_str()); - return *this; - } - std::vector val_list; - for (const auto &item : attr_value) { - const auto tensor = TensorAdapter::AsGeTensor(item); - val_list.push_back(tensor); - } - if (!AttrUtils::SetListTensor(operator_impl_->GetOpDescImpl(), name, val_list)) { - GELOGW("[Set][Attr] Set attr name %s unsuccessful", name.c_str()); - } - return *this; -} - -Operator &Operator::SetAttr(const char_t *name, const std::vector &attr_value) { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return *this; - } - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name); - return *this; - } - const std::string op_name = name; - std::vector val_list; - for (const auto &item : attr_value) { - const auto tensor = TensorAdapter::AsGeTensor(item); - val_list.push_back(tensor); - } - if (!AttrUtils::SetListTensor(operator_impl_->GetOpDescImpl(), op_name, val_list)) { - GELOGW("[Set][Attr] Set attr name %s unsuccessful", op_name.c_str()); - } - return *this; -} - -graphStatus Operator::GetAttr(const std::string &name, Tensor &attr_value) const { - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] operator impl is nullptr, name %s.", name.c_str()); - return GRAPH_FAILED; - } - ConstGeTensorPtr tensor; - if (!AttrUtils::GetTensor(operator_impl_->GetOpDescImpl(), name, tensor)) { - GELOGW("[Get][Attr] Get attr name %s unsuccessful", name.c_str()); - return GRAPH_FAILED; - } - attr_value = TensorAdapter::GeTensor2Tensor(tensor); - return GRAPH_SUCCESS; -} - -graphStatus Operator::GetAttr(const char_t *name, Tensor &attr_value) const { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return GRAPH_FAILED; - } - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] operator impl is nullptr, name %s.", name); - return GRAPH_FAILED; - } - const std::string op_name = name; - ConstGeTensorPtr tensor; - if (!AttrUtils::GetTensor(operator_impl_->GetOpDescImpl(), op_name, tensor)) { - GELOGW("[Get][Attr] Get attr name %s unsuccessful", op_name.c_str()); - return GRAPH_FAILED; - } - attr_value = TensorAdapter::GeTensor2Tensor(tensor); - return GRAPH_SUCCESS; -} - -graphStatus Operator::GetAttr(const std::string &name, std::vector &attr_value) const { - attr_value.clear(); - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] operator impl is nullptr, name %s.", name.c_str()); - return GRAPH_FAILED; - } - std::vector val_list; - if (!AttrUtils::GetListTensor(operator_impl_->GetOpDescImpl(), name, val_list)) { - GELOGW("[Get][Attr] Get attr name %s unsuccessful", name.c_str()); - return GRAPH_FAILED; - } - for (auto &tensor : val_list) { - attr_value.push_back(TensorAdapter::GeTensor2Tensor(tensor)); - } - return GRAPH_SUCCESS; -} - -graphStatus Operator::GetAttr(const char_t *name, std::vector &attr_value) const { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return GRAPH_FAILED; - } - attr_value.clear(); - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name); - return GRAPH_FAILED; - } - const std::string op_name = name; - std::vector val_list; - if (!AttrUtils::GetListTensor(operator_impl_->GetOpDescImpl(), op_name, val_list)) { - GELOGW("[Get][Attr] Get attr name %s unsuccessful", op_name.c_str()); - return GRAPH_FAILED; - } - for (auto &tensor : val_list) { - attr_value.push_back(TensorAdapter::GeTensor2Tensor(tensor)); - } - return GRAPH_SUCCESS; -} - -Operator &Operator::SetAttr(const std::string &name, const OpBytes &attr_value) { - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] operator impl is nullptr, name %s.", name.c_str()); - return *this; - } - if (!AttrUtils::SetZeroCopyBytes(operator_impl_->GetOpDescImpl(), name, - Buffer::CopyFrom(attr_value.data(), attr_value.size()))) { - GELOGW("[Set][Attr] Set attr name %s unsuccessful", name.c_str()); - } - return *this; -} - -Operator &Operator::SetAttr(const char_t *name, const OpBytes &attr_value) { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return *this; - } - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name); - return *this; - } - const std::string op_name = name; - if (!AttrUtils::SetZeroCopyBytes(operator_impl_->GetOpDescImpl(), op_name, - Buffer::CopyFrom(attr_value.data(), attr_value.size()))) { - GELOGW("[Set][Attr] Set attr name %s unsuccessful", op_name.c_str()); - } - return *this; -} - -graphStatus Operator::GetAttr(const std::string &name, OpBytes &attr_value) const { - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] operator impl is nullptr, name %s.", name.c_str()); - return GRAPH_FAILED; - } - Buffer buffer; - if (!AttrUtils::GetZeroCopyBytes(operator_impl_->GetOpDescImpl(), name, buffer)) { - GELOGW("[Get][Attr] Get attr name %s unsuccessful", name.c_str()); - return GRAPH_FAILED; - } - attr_value.clear(); - if (buffer.data() == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "buffer data is null, op:%s", operator_impl_->GetOpDescImpl()->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] buffer data is null."); - return GRAPH_FAILED; - } - attr_value.assign(buffer.data(), buffer.data() + buffer.size()); - return GRAPH_SUCCESS; -} - -graphStatus Operator::GetAttr(const char_t *name, OpBytes &attr_value) const { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return GRAPH_FAILED; - } - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name); - return GRAPH_FAILED; - } - const std::string op_name = name; - Buffer buffer; - if (!AttrUtils::GetZeroCopyBytes(operator_impl_->GetOpDescImpl(), op_name, buffer)) { - GELOGW("[Get][Attr] Get attr name %s unsuccessful", op_name.c_str()); - return GRAPH_FAILED; - } - attr_value.clear(); - if (buffer.data() == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "buffer data is null, op:%s", operator_impl_->GetOpDescImpl()->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] buffer data is null."); - return GRAPH_FAILED; - } - attr_value.assign(buffer.data(), buffer.data() + buffer.size()); - return GRAPH_SUCCESS; -} - -Operator &Operator::SetAttr(const std::string &name, ge::AttrValue &&attr_value) { - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid."); - return *this, "[Check][Param] Operator impl is nullptr."); - (void)operator_impl_->SetAttr(name, std::move(attr_value.impl->geAttrValue_)); - return *this; -} - -Operator &Operator::SetAttr(const char_t *name, ge::AttrValue &&attr_value) { - GE_CHK_BOOL_EXEC(name != nullptr, REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid."); - return *this, "[Check][Param] Operator name is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid."); - return *this, "[Check][Param] Operator impl is nullptr."); - const std::string op_name = name; - (void)operator_impl_->SetAttr(op_name, std::move(attr_value.impl->geAttrValue_)); - return *this; -} - -Operator &Operator::SetAttr(const char_t *name, const ge::AttrValue &attr_value) { - GE_CHK_BOOL_EXEC(name != nullptr, REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid."); - return *this, "[Check][Param] Operator name is nullptr."); - GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid."); - return *this, "[Check][Param] Operator impl is nullptr."); - const std::string op_name = name; - (void)operator_impl_->SetAttr(op_name, attr_value.impl->geAttrValue_); - return *this; -} - -graphStatus Operator::GetAttr(const std::string &name, ge::AttrValue &attr_value) const { - return GetAttr(name.c_str(), attr_value); -} - -graphStatus Operator::GetAttr(const char_t *name, ge::AttrValue &attr_value) const { - GE_CHECK_NOTNULL(name); - GE_CHECK_NOTNULL(operator_impl_); - GE_CHECK_NOTNULL(attr_value.impl); - return operator_impl_->GetAttr(name, attr_value.impl->geAttrValue_); -} - -Operator &Operator::SetAttr(const std::string &name, const std::vector &attr_value) { - return SetAttr(name.c_str(), attr_value); -} - -Operator &Operator::SetAttr(const char_t *name, const std::vector &attr_value) { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return *this; - } - if ((operator_impl_ == nullptr) || (!operator_impl_->GetOpDescImpl())) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name); - return *this; - } - const std::string op_name = name; - if (!AttrUtils::SetListDataType(operator_impl_->GetOpDescImpl(), op_name, attr_value)) { - GELOGW("[Set][Attr] Set attr name %s unsuccessful", op_name.c_str()); - } - return *this; -} - -graphStatus Operator::GetAttr(const std::string &name, std::vector &attr_value) const { - return GetAttr(name.c_str(), attr_value); -} - -graphStatus Operator::GetAttr(const char_t *name, std::vector &attr_value) const { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return GRAPH_FAILED; - } - attr_value.clear(); - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name); - return GRAPH_FAILED; - } - const std::string op_name = name; - if (!AttrUtils::GetListDataType(operator_impl_->GetOpDescImpl(), op_name, attr_value)) { - GELOGW("[Get][Attr] Get attr name %s unsuccessful", op_name.c_str()); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -Operator &Operator::SetAttr(const std::string &name, const ge::DataType &attr_value) { - return SetAttr(name.c_str(), attr_value); -} - -Operator &Operator::SetAttr(const char_t *name, const ge::DataType &attr_value) { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return *this; - } - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name); - return *this; - } - const std::string op_name = name; - if (!AttrUtils::SetDataType(operator_impl_->GetOpDescImpl(), op_name, attr_value)) { - GELOGW("[Set][Attr] Set attr name %s unsuccessful", op_name.c_str()); - } - return *this; -} - -graphStatus Operator::GetAttr(const std::string &name, ge::DataType &attr_value) const { - return GetAttr(name.c_str(), attr_value); -} - -graphStatus Operator::GetAttr(const char_t *name, ge::DataType &attr_value) const { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return GRAPH_FAILED; - } - if ((operator_impl_ == nullptr) || (operator_impl_->GetOpDescImpl() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr or opdesc is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr, name %s.", name); - return GRAPH_FAILED; - } - const std::string op_name = name; - if (!AttrUtils::GetDataType(operator_impl_->GetOpDescImpl(), op_name, attr_value)) { - GELOGW("[Get][Attr] Get attr name %s unsuccessful", op_name.c_str()); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -void Operator::AttrRegister(const std::string &name, const std::vector &attr_value) { - return AttrRegister(name.c_str(), attr_value); -} - -void Operator::AttrRegister(const char_t *name, const std::vector &attr_value) { - GE_CHECK_NOTNULL_JUST_RETURN(name); - GE_CHECK_NOTNULL_JUST_RETURN(operator_impl_); - GE_CHECK_NOTNULL_JUST_RETURN(operator_impl_->GetOpDescImpl()); - if (!AttrUtils::SetListDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { - GELOGW("[Set][Attr] Set attr name %s unsuccessful", name); - } - operator_impl_->GetOpDescImpl()->AppendIrAttrName(name); -} - -void Operator::AttrRegister(const std::string &name, const ge::DataType &attr_value) { - return AttrRegister(name.c_str(), attr_value); -} - -void Operator::AttrRegister(const char_t *name, const ge::DataType &attr_value) { - GE_CHECK_NOTNULL_JUST_RETURN(name); - GE_CHECK_NOTNULL_JUST_RETURN(operator_impl_); - GE_CHECK_NOTNULL_JUST_RETURN(operator_impl_->GetOpDescImpl()); - if (!AttrUtils::SetDataType(operator_impl_->GetOpDescImpl(), name, attr_value)) { - GELOGW("[Set][Attr] Set attr name %s unsuccessful", name); - } - operator_impl_->GetOpDescImpl()->AppendIrAttrName(name); -} - -void Operator::AttrRegister(const std::string &name, const Tensor &attr_value) { - return AttrRegister(name.c_str(), attr_value); -} - -void Operator::AttrRegister(const char_t *name, const Tensor &attr_value) { - GE_CHECK_NOTNULL_JUST_RETURN(name); - GE_CHECK_NOTNULL_JUST_RETURN(operator_impl_); - GE_CHECK_NOTNULL_JUST_RETURN(operator_impl_->GetOpDescImpl()); - const GeTensor &tensor = TensorAdapter::AsGeTensor(attr_value); - if (!AttrUtils::SetTensor(operator_impl_->GetOpDescImpl(), name, tensor)) { - GELOGW("[Register][Attr] Reg attr name %s unsuccessful", name); - } - operator_impl_->GetOpDescImpl()->AppendIrAttrName(name); -} - -void Operator::AttrRegister(const std::string &name, const std::vector &attr_value) { - return AttrRegister(name.c_str(), attr_value); -} - -void Operator::AttrRegister(const char_t *name, const vector &attr_value) { - GE_CHECK_NOTNULL_JUST_RETURN(name); - GE_CHECK_NOTNULL_JUST_RETURN(operator_impl_); - GE_CHECK_NOTNULL_JUST_RETURN(operator_impl_->GetOpDescImpl()); - std::vector val_list; - for (const auto &item : attr_value) { - val_list.push_back(TensorAdapter::AsGeTensor(item)); - } - if (!AttrUtils::SetListTensor(operator_impl_->GetOpDescImpl(), name, val_list)) { - GELOGW("[Register][Attr] Reg attr name %s unsuccessful", name); - } - operator_impl_->GetOpDescImpl()->AppendIrAttrName(name); -} - -void Operator::AttrRegister(const std::string &name, const OpBytes &attr_value) { - return AttrRegister(name.c_str(), attr_value); -} - -void Operator::AttrRegister(const char_t *name, const OpBytes &attr_value) { - GE_CHECK_NOTNULL_JUST_RETURN(name); - GE_CHECK_NOTNULL_JUST_RETURN(operator_impl_); - GE_CHECK_NOTNULL_JUST_RETURN(operator_impl_->GetOpDescImpl()); - if (!AttrUtils::SetZeroCopyBytes(operator_impl_->GetOpDescImpl(), name, - Buffer::CopyFrom(attr_value.data(), attr_value.size()))) { - GELOGW("[Register][Attr] Reg attr name %s unsuccessful", name); - } - operator_impl_->GetOpDescImpl()->AppendIrAttrName(name); -} - -void Operator::SubgraphRegister(const std::string &ir_name, bool dynamic) { - return SubgraphRegister(ir_name.c_str(), dynamic); -} - -void Operator::SubgraphRegister(const char_t *ir_name, bool dynamic) { - if (ir_name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return; - } - if (operator_impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] operator impl is nullptr, name %s.", ir_name); - return; - } - operator_impl_->SubgraphRegister(ir_name, dynamic ? static_cast(kDynamic) : static_cast(kStatic)); -} - -void Operator::SubgraphCountRegister(const std::string &ir_name, uint32_t count) { - return SubgraphCountRegister(ir_name.c_str(), count); -} - -void Operator::SubgraphCountRegister(const char_t *ir_name, uint32_t count) { - if (ir_name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return; - } - if (operator_impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] operator impl is nullptr, name %s.", ir_name); - return; - } - operator_impl_->SubgraphCountRegister(ir_name, count); -} - -void Operator::SetSubgraphBuilder(const std::string &ir_name, uint32_t index, const SubgraphBuilder &builder) { - return SetSubgraphBuilder(ir_name.c_str(), index, builder); -} - -void Operator::SetSubgraphBuilder(const char_t *ir_name, uint32_t index, const SubgraphBuilder &builder) { - if (ir_name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param ir_name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator ir_name is nullptr."); - return; - } - if (operator_impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] operator impl is nullptr, name %s.", ir_name); - return; - } - operator_impl_->SetSubgraphBuilder(ir_name, index, builder); -} - -std::vector Operator::GetSubgraphNames() const { - return operator_impl_->GetSubgraphNames(); -} - -graphStatus Operator::GetSubgraphNames(std::vector &names) const { - const std::vector subgraph_names = operator_impl_->GetSubgraphNames(); - for (auto &subgraph_name : subgraph_names) { - names.emplace_back(subgraph_name.c_str()); - } - return GRAPH_SUCCESS; -} - -SubgraphBuilder Operator::GetDynamicSubgraphBuilder(const std::string &name, uint32_t index) const { - if (operator_impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] operator impl is nullptr."); - return nullptr; - } - return operator_impl_->GetSubgraphBuilder(name, index); -} - -SubgraphBuilder Operator::GetDynamicSubgraphBuilder(const char_t *name, uint32_t index) const { - if (operator_impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator impl is nullptr."); - return nullptr; - } - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param ir_name is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return nullptr; - } - const std::string op_ir_name = name; - return operator_impl_->GetSubgraphBuilder(op_ir_name, index); -} - -SubgraphBuilder Operator::GetSubgraphBuilder(const std::string &name) const { - return GetDynamicSubgraphBuilder(name.c_str(), 0U); -} - -SubgraphBuilder Operator::GetSubgraphBuilder(const char_t *name) const { - return GetDynamicSubgraphBuilder(name, 0U); -} - -Graph Operator::GetSubgraphImpl(const std::string &name) const { - return GetSubgraphImpl(name.c_str()); -} - -Graph Operator::GetSubgraphImpl(const char_t *name) const { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return Graph(""); - } - if (operator_impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid."); - GE_LOGE("[Check][Param] Failed to get subgraph %s, the operator impl is null", name); - return Graph(""); - } - const auto op_desc = OpDescUtils::GetOpDescFromOperator(*this); - if (op_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Failed to get subgraph %s, because the op_desc is nullptr.", name); - GE_LOGE("[Get][OpDesc] Failed to get subgraph %s, the op_desc is null", name); - return Graph(""); - } - const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes(); - const auto iter = subgraph_names_to_index.find(name); - if (iter == subgraph_names_to_index.end()) { - REPORT_INNER_ERR_MSG("E18888", "Failed to get subgraph %s, the name may be invalid", name); - GE_LOGE("[Check][Param] Failed to get subgraph %s, the name may be invalid", name); - return Graph(""); - } - const auto subgraph_instance_name = op_desc->GetSubgraphInstanceName(iter->second); - if (subgraph_instance_name.empty()) { - REPORT_INNER_ERR_MSG("E18888", "Failed to get subgraph %s index %u, the subgraph may not be added", name, - iter->second); - GE_LOGE("[Get][Subgraph] %s index %u failed, because the subgraph may not be added", name, iter->second); - return Graph(""); - } - - const auto node = operator_impl_->GetNode(); - if (node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Failed to get subgraph %s, because the node is null", name); - GE_LOGE("[Get][Node] Failed to get subgraph %s, because the node is null", name); - return Graph(""); - } - const auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); - if (root_graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Failed to get subgraph %s, because can not find the root graph,node:%s", name, - node->GetName().c_str()); - GE_LOGE("[Get][Subgraph] subgraph %s failed, because can not find the root graph", name); - return Graph(""); - } - const auto subgraph = root_graph->GetSubgraph(subgraph_instance_name); - if (subgraph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", - "Failed to get subgraph %s index %u, because can not find the instance %s " - "from the root graph", - name, iter->second, subgraph_instance_name.c_str()); - GE_LOGE("[Get][Subgraph] %s index %u failed, because can not find the instance %s from the root graph", name, - iter->second, subgraph_instance_name.c_str()); - return Graph(""); - } - return GraphUtilsEx::CreateGraphFromComputeGraph(subgraph); -} - -Graph Operator::GetSubgraph(const std::string &name) const { - return GetSubgraphImpl(name.c_str()); -} - -Graph Operator::GetSubgraph(const char_t *name) const { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Get subgraph failed, name is nullptr."); - return Graph(""); - } - return GetSubgraphImpl(name); -} - -Graph Operator::GetDynamicSubgraph(const std::string &name, uint32_t index) const { - return GetSubgraph((name + std::to_string(index)).c_str()); -} - -Graph Operator::GetDynamicSubgraph(const char_t *name, uint32_t index) const { - if (name == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param name is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator name is nullptr."); - return Graph(""); - } - const std::string op_name = name + std::to_string(index); - return GetSubgraph(op_name.c_str()); -} - -size_t Operator::GetSubgraphNamesCount() const { - if (operator_impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid."); - GE_LOGE("[Check][Param] Failed to get subgraph names count, the operator impl is null"); - return 0UL; - } - return operator_impl_->GetSubgraphNamesCount(); -} - -graphStatus Operator::UpdateInputDesc(const uint32_t index, const TensorDesc &tensor_desc) { - if (operator_impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Failed to update input desc, the operator impl is null"); - return GRAPH_FAILED; - } - return operator_impl_->UpdateInputDesc(index, TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); -} - -graphStatus Operator::UpdateOutputDesc(const uint32_t index, const TensorDesc &tensor_desc) { - if (operator_impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "operator_impl_ is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Failed to update output desc, the operator impl is null"); - return GRAPH_FAILED; - } - return operator_impl_->UpdateOutputDesc(index, TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); -} - -class GraphBuilderImpl { -public: - explicit GraphBuilderImpl(const std::string &name) : graph_(ComGraphMakeShared(name)) { - if (graph_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "ComputeGraph make shared failed"); - GELOGE(GRAPH_FAILED, "[Call][ComGraphMakeShared] ComputeGraph make shared failed"); - return; - } - } - - ~GraphBuilderImpl() {} - - ComputeGraphPtr BuildGraph(const std::vector &inputs) { - std::vector vec_inputs; - for (auto &it : inputs) { - const auto src_op_impl = it.operator_impl_; - GE_CHK_BOOL_EXEC(src_op_impl != nullptr, REPORT_INNER_ERR_MSG("E18888", "src_op_impl is nullptr, check invalid."); - return nullptr, "[Check][Param] Operator Impl is null."); - GE_CHK_BOOL_EXEC(src_op_impl->op_desc_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "impl's opdesc is nullptr, check invalid."); - return nullptr, "[Check][Param] Operator impl's opdesc is null."); - - const std::string type = src_op_impl->op_desc_->GetType(); - const auto node_op = ge::OperatorFactory::CreateOperator("node_op", type.c_str()); - const auto tensor_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); - node_op.BreakConnect(); - - GE_CHK_BOOL_EXEC(tensor_desc != nullptr, continue, "[Get][Opdesc] tensor_desc is null."); - if (((tensor_desc->GetInputsSize() == 0UL) && (tensor_desc->GetOutputsSize() > 0UL)) || - OpTypeUtils::IsDataNode(type) || (type == VARIABLE) || (type == INITDATA) || (type == GETNEXT)) { - vec_inputs.push_back(it.operator_impl_); - } else { - GELOGW("[BuildGraph][CheckInput] Input operator should be Data, Variable operator or operator that has output " - "but no input."); - } - } - GE_CHK_BOOL_EXEC(!vec_inputs.empty(), - REPORT_INNER_ERR_MSG("E18888", - "User Input do not include operator such as " - "Data, Variable operator or operator that has output but no input."); - return nullptr, - "[Check][Param] User Input do not include operator such as " - "Data, Variable operator or operator that has output but no input."); - auto ret = WalkAllOperators(vec_inputs); - GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "[Call][WalkAllOperators] failed, ret:%d.", ret); - - ret = AddEdge(); - GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "[Add][Edge] failed, ret:%d.", ret); - - return graph_; - } - - ComputeGraphPtr BuildGraphWithStableSort(const std::vector &ops) { - uint32_t input_size = 0UL; - for (auto &op : ops) { - AscendString op_type; - GE_ASSERT_SUCCESS(op.GetOpType(op_type)); - if (OpTypeUtils::IsDataNode(op_type.GetString())) { - input_size++; - } - GE_ASSERT_SUCCESS(GenerateNodeFromOperator(op)); - } - GE_ASSERT_SUCCESS(AddEdge()); - const int32_t start_recursion_depth = 0; - GE_ASSERT_SUCCESS(MoveSubgraphToRoot(start_recursion_depth, graph_)); - graph_->SetInputSize(input_size); - return graph_; - } - - const std::map &GetAllNodesInfo() const { return all_nodes_info_; } - -private: - graphStatus GenerateNodeFromOperator(const Operator &op) { - GE_ASSERT_NOTNULL(graph_); - const auto op_impl = op.operator_impl_; - GE_ASSERT_NOTNULL(op_impl); - auto node_ptr = graph_->AddNode(op_impl->op_desc_); - GE_ASSERT_NOTNULL(node_ptr); - (void)all_nodes_info_.insert(std::make_pair(op_impl, node_ptr)); - (void)stable_all_nodes_info_.insert(std::make_pair(op_impl, node_ptr)); - GE_ASSERT_SUCCESS(WalkAllSubgraphs(node_ptr, op_impl)); - return GRAPH_SUCCESS; - } - - graphStatus WalkAllOperators(const std::vector &vec_ops) { - GE_CHK_BOOL_EXEC(graph_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "graph_ is nullptr, check invalid."); - return GRAPH_FAILED, "[Check][Param] graph_ is null."); - std::queue> que; - que.push(vec_ops); - while (!que.empty()) { - const auto vec_tem = que.front(); - que.pop(); - for (const auto &op_impl : vec_tem) { - GE_CHK_BOOL_EXEC(op_impl != nullptr, - REPORT_INNER_ERR_MSG("E18888", "op_impl is nullptr, check invalid."); - return GRAPH_FAILED, "[Check][Param] Operator Impl is null."); - if (all_nodes_info_.find(op_impl) != all_nodes_info_.cend()) { - GELOGI("This node %s has created.", op_impl->GetName().c_str()); - continue; - } - auto node_ptr = graph_->AddNode(op_impl->op_desc_); - GE_CHK_BOOL_EXEC(node_ptr != nullptr, - REPORT_INNER_ERR_MSG("E18888", "add node failed."); - return GRAPH_FAILED, "[Add][Node] failed."); - (void)(all_nodes_info_.insert(std::make_pair(op_impl, node_ptr))); - (void)(stable_all_nodes_info_.insert(std::make_pair(op_impl, node_ptr))); - - auto &out_links = op_impl->output_links_; - std::vector vec_op_forward{}; - for (const auto &out_link : out_links) { - for (const auto &op_forward : out_link.second) { - vec_op_forward.push_back(op_forward.GetOwner()); - } - } - - auto &out_control_links = op_impl->control_output_link_; - for (const auto &out_link : out_control_links) { - vec_op_forward.push_back(out_link.lock()); - } - que.push(vec_op_forward); - - auto &in_links = op_impl->input_link_; - std::vector vec_op_back_forward{}; - for (const auto &in_link : in_links) { - vec_op_back_forward.push_back(in_link.second.GetOwner()); - } - - auto &in_control_links = op_impl->control_input_link_; - for (const auto &in_link : in_control_links) { - vec_op_back_forward.push_back(in_link.lock()); - } - que.push(vec_op_back_forward); - - if (WalkAllSubgraphs(node_ptr, op_impl) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - } - } - return MoveSubgraphToRoot(0, graph_); - } - - graphStatus WalkAllSubgraphs(const NodePtr &node, const OperatorImplPtr &op_impl) { - const std::string name = node->GetName(); - for (auto &name_idx : op_impl->op_desc_->GetSubgraphNameIndexes()) { - const SubgraphBuilder &builder = op_impl->GetSubgraphBuilder(name_idx.first); - if (builder == nullptr) { - GELOGW("[Check][Param] Node %s has no builder", name.c_str()); - continue; - } - - const Graph graph = builder(); // Build subgraph from user define builder. - const ComputeGraphPtr &subgraph = GraphUtilsEx::GetComputeGraph(graph); - GE_CHK_BOOL_EXEC(subgraph != nullptr, - REPORT_INNER_ERR_MSG("E18888", "Node: %s, Build graph failed.", name.c_str()); - return GRAPH_FAILED, "[Get][Graph] Node: %s, Build graph failed.", name.c_str()); - - subgraph->SetParentNode(node); - subgraph->SetParentGraph(graph_); - if (graph_->AddSubgraph(subgraph->GetName(), subgraph) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - - if (op_impl->op_desc_->SetSubgraphInstanceName(name_idx.second, subgraph->GetName()) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Failed to set subgraph %s index %u", subgraph->GetName().c_str(), name_idx.second); - GELOGE(GRAPH_FAILED, "[Set][SubGraph] %s index %u failed", subgraph->GetName().c_str(), name_idx.second); - return GRAPH_FAILED; - } - } - - return GRAPH_SUCCESS; - } - - graphStatus MoveSubgraphToRoot(const int32_t recursion_depth, const ComputeGraphPtr &graph) const { - if (recursion_depth > kMaxDepth) { - REPORT_INNER_ERR_MSG("E18888", "param recursion_depth:%d is bigger than kMaxRecursiveDepth:%d", recursion_depth, - kMaxDepth); - GELOGE(GRAPH_FAILED, "[Check][Param] DecodeGraph: recursion depth is too large, abort"); - return GRAPH_FAILED; - } - const ComputeGraphPtr &root_graph = GraphUtils::FindRootGraph(graph); - if (root_graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "failed to find root graph of %s", graph->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Find][RootGraph] failed for graph:%s.", graph->GetName().c_str()); - return GRAPH_FAILED; - } - - if (root_graph == graph) { - const auto subgraphs = graph->GetAllSubgraphs(); - - for (auto &subgraph : subgraphs) { - if (MoveSubgraphToRoot(recursion_depth + 1, subgraph) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - } - } else { - const auto subgraphs = graph->GetAllSubgraphs(); - for (auto &subgraph : subgraphs) { - if (root_graph->AddSubgraph(subgraph->GetName(), subgraph) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - graph->RemoveSubgraph(subgraph->GetName()); - if (MoveSubgraphToRoot(recursion_depth + 1, subgraph) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - } - } - - return GRAPH_SUCCESS; - } - - graphStatus AddEdge() { - for (const auto &node_info : stable_all_nodes_info_) { - const auto src_op_impl_ptr = node_info.first; - const auto src_node_ptr = node_info.second; - - GE_IF_BOOL_EXEC((src_op_impl_ptr == nullptr) || (src_node_ptr == nullptr), continue); - const auto out_links = src_op_impl_ptr->output_links_; - GE_CHK_BOOL_EXEC(src_op_impl_ptr->op_desc_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "Src operator impl's op_desc is nullptr, check invalid."); - return GRAPH_FAILED, "[Check][Param] Src operator impl's op_desc is null."); - auto &op_desc = src_op_impl_ptr->op_desc_; - GE_IF_BOOL_EXEC(op_desc == nullptr, continue); - for (const auto &out : out_links) { - const auto src_idx = op_desc->GetOutputIndexByName(out.first); - GE_CHK_BOOL_EXEC(src_idx >= 0, - REPORT_INNER_ERR_MSG("E18888", "Find output index by name:%s in op:%s failed", - out.first.c_str(), op_desc->GetName().c_str()); - return GRAPH_FAILED, "[Get][Index] Find output index by name:%s failed", out.first.c_str()); - - const auto src_anchor = src_node_ptr->GetOutDataAnchor(src_idx); - GE_CHK_BOOL_EXEC(src_anchor != nullptr, - REPORT_INNER_ERR_MSG("E18888", "GetOutDataAnchor failed, index:%d, op:%s.", src_idx, - op_desc->GetName().c_str()); - return GRAPH_FAILED, "[Get][OutDataAnchor] failed, index:%d.", src_idx); - - for (const auto &dst_opio : out.second) { - const std::map::const_iterator dst_node_info = - all_nodes_info_.find(dst_opio.GetOwner()); - GE_CHK_BOOL_EXEC(dst_node_info != all_nodes_info_.cend(), - REPORT_INNER_ERR_MSG("E18888", "Find Dst node failed, op:%s.", op_desc->GetName().c_str()); - return GRAPH_FAILED, "[Check][Param] Find Dst node failed."); - - GE_IF_BOOL_EXEC(dst_node_info->second == nullptr, continue); - - const auto dst_anchor = dst_node_info->second->GetInDataAnchor(dst_opio.GetIndex()); - GE_CHK_BOOL_EXEC(dst_anchor != nullptr, - REPORT_INNER_ERR_MSG("E18888", "GetInDataAnchor failed, index:%d, op:%s", - dst_opio.GetIndex(), op_desc->GetName().c_str()); - return GRAPH_FAILED, "GetInDataAnchor failed, index:%d", dst_opio.GetIndex()); - - const auto ret = GraphUtils::AddEdge(src_anchor, dst_anchor); - GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "add edge from node[%s][%d] to node[%s][%d] failed.", - src_node_ptr->GetName().c_str(), src_anchor->GetIdx(), - dst_node_info->second->GetName().c_str(), dst_anchor->GetIdx()); - return GRAPH_FAILED, "[Add][Edge] from node[%s][%d] to node[%s][%d] failed.", - src_node_ptr->GetName().c_str(), src_anchor->GetIdx(), - dst_node_info->second->GetName().c_str(), dst_anchor->GetIdx()); - } - } - const auto out_control_anchor = src_node_ptr->GetOutControlAnchor(); - for (const auto &control_out : src_op_impl_ptr->control_output_link_) { - const std::map::const_iterator dst_node_info = - all_nodes_info_.find(control_out.lock()); - if (dst_node_info == all_nodes_info_.cend()) { - REPORT_INNER_ERR_MSG("E18888", "Find Dst node failed."); - GELOGE(GRAPH_FAILED, "[Check][Param] Find Dst node failed."); - return GRAPH_FAILED; - } - GE_IF_BOOL_EXEC(dst_node_info->second == nullptr, continue); - const auto in_control_anchor = dst_node_info->second->GetInControlAnchor(); - const auto ret = GraphUtils::AddEdge(out_control_anchor, in_control_anchor); - if (ret != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "add edge failed. srcNode %s:%s, dstNode %s:%s", op_desc->GetName().c_str(), - op_desc->GetType().c_str(), dst_node_info->second->GetName().c_str(), - dst_node_info->second->GetType().c_str()); - GELOGE(ret, "[Add][Edge] failed. srcNode %s:%s, dstNode %s:%s", op_desc->GetName().c_str(), - op_desc->GetType().c_str(), dst_node_info->second->GetName().c_str(), - dst_node_info->second->GetType().c_str()); - return ret; - } - } - } - return GRAPH_SUCCESS; - } - struct OperatorCompareKey { - bool operator()(const OperatorImplPtr &n0, const OperatorImplPtr &n1) const { - if ((n0 == nullptr) || (n1 == nullptr)) { - return false; - } - return (n0->GetName() < n1->GetName()); - } - }; - ComputeGraphPtr graph_ = nullptr; - std::map all_nodes_info_{}; - // stable_all_nodes_info_中的键值对和all_nodes_info_中的一样,stable_all_nodes_info_ - // 用于保证每次遍历到的键值对的顺序是固定的 - std::map stable_all_nodes_info_{}; -}; - -void Operator::DynamicInputRegister(const char_t *name, - const uint32_t num, - const char_t *datatype_symbol, - bool is_push_back) { - GE_RETURN_IF_NULL(name, "[Check][Param] Operator name is nullptr."); - GE_RETURN_IF_NULL(datatype_symbol, "[Check][Param] Operator datatype_symbol is nullptr."); - GE_RETURN_IF_NULL(operator_impl_, "[Check][Param] Operator impl is nullptr."); - GE_RETURN_IF_NULL(operator_impl_->GetOpDescImpl(), "[Get][OpDescImpl] is nullptr."); - DynamicInputRegister(name, num, is_push_back); - operator_impl_->GetOpDescImpl()->SetInputDtypeSymbol(name, kIrInputDynamic, datatype_symbol); -} - -void Operator::DynamicOutputRegister(const char_t *name, - const uint32_t num, - const char_t *datatype_symbol, - bool is_push_back) { - GE_RETURN_IF_NULL(name, "[Check][Param] Operator name is nullptr."); - GE_RETURN_IF_NULL(datatype_symbol, "[Check][Param] Operator datatype_symbol is nullptr."); - GE_RETURN_IF_NULL(operator_impl_, "[Check][Param] Operator impl is nullptr."); - GE_RETURN_IF_NULL(operator_impl_->GetOpDescImpl(), "[Get][OpDescImpl] is nullptr."); - DynamicOutputRegister(name, num, is_push_back); - operator_impl_->GetOpDescImpl()->SetOutputDtypeSymbol(name, kIrOutputDynamic, datatype_symbol); -} -graphStatus Operator::SetSubgraphInstanceName(const uint32_t index, const char_t *name) { - GE_ASSERT_NOTNULL(operator_impl_); - GE_ASSERT_NOTNULL(operator_impl_->GetOpDescImpl()); - return operator_impl_->GetOpDescImpl()->SetSubgraphInstanceName(index, name); -} - -static inline bool HasSameNameNode(const ComputeGraphPtr &compute_graph) { - for (const auto &graph : compute_graph->GetAllSubgraphs()) { - std::set node_names; - for (auto const &node : graph->GetDirectNode()) { - const auto result = node_names.insert(node->GetName()); - if (!result.second) { - REPORT_INNER_ERR_MSG("E18888", "[Check][Param] graph %s has same name node %s", graph->GetName().c_str(), - node->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] graph %s has same name node %s", graph->GetName().c_str(), - node->GetName().c_str()); - return true; - } - } - } - - std::set node_names; - for (auto const &node : compute_graph->GetDirectNode()) { - const auto result = node_names.insert(node->GetName()); - if (!result.second) { - REPORT_INNER_ERR_MSG("E18888", "[Check][Param] graph %s has same name node %s", compute_graph->GetName().c_str(), - node->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] graph %s has same name node %s", compute_graph->GetName().c_str(), - node->GetName().c_str()); - return true; - } - } - return false; -} - -ComputeGraphPtr GraphUtilsEx::CreateGraphFromOperator(const std::string &name, - const std::vector &inputs) { - auto graph_builder_impl = GraphBuilderImpl(name); - ComputeGraphPtr compute_graph = graph_builder_impl.BuildGraph(inputs); - GE_CHK_BOOL_EXEC(compute_graph != nullptr, - REPORT_INNER_ERR_MSG("E18888", "BuildGraph failed, as return nullptr."); - return compute_graph, "[Build][Graph] Computer graph is nullptr"); - compute_graph->SetAllNodesInfo(graph_builder_impl.GetAllNodesInfo()); - if (HasSameNameNode(compute_graph)) { - GELOGW("[CreateGraph][Check] Nodes with same name exist in one compute graph is not allowed, graph_name: %s", - name.c_str()); - compute_graph = nullptr; - } - - return compute_graph; -} - -ComputeGraphPtr GraphUtilsEx::CreateComputeGraphFromOperatorWithStableTopo(const std::string &name, - const std::vector &ops) { - auto graph_builder_impl = GraphBuilderImpl(name); - ComputeGraphPtr compute_graph = graph_builder_impl.BuildGraphWithStableSort(ops); - GE_ASSERT_NOTNULL(compute_graph); - compute_graph->SetAllNodesInfo(graph_builder_impl.GetAllNodesInfo()); - GE_ASSERT_TRUE(!HasSameNameNode(compute_graph), - "[CreateGraph][Check] Nodes with same name exist in one compute graph is not allowed, graph_name: %s", - name.c_str()); - return compute_graph; -} - -void GraphUtilsEx::BreakConnect(const std::map &all_nodes_infos) { - for (const auto &it : all_nodes_infos) { - const OperatorImplPtr op_impl = it.first; - if (op_impl == nullptr) { - GELOGW("[BreakConnect][Check] Operator impl is null"); - continue; - } - op_impl->ClearOutputLinks(); - op_impl->ClearInputLinks(); - OperatorKeeper::GetInstance().CheckOutOperator(op_impl); - } -} -} // namespace ge diff --git a/graph/normal_graph/operator_factory.cc b/graph/normal_graph/operator_factory.cc deleted file mode 100644 index 2b6abdb3e77419a7a4c0fab54603d755b4639523..0000000000000000000000000000000000000000 --- a/graph/normal_graph/operator_factory.cc +++ /dev/null @@ -1,129 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/operator_factory_impl.h" -#include "debug/ge_log.h" - -namespace ge { -Operator OperatorFactory::CreateOperator(const std::string &operator_name, const std::string &operator_type) { - return OperatorFactoryImpl::CreateOperator(operator_name, operator_type); -} - -Operator OperatorFactory::CreateOperator(const char_t *const operator_name, const char_t *const operator_type) { - if ((operator_name == nullptr) || (operator_type == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "Create Operator input parameter is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Create Operator input parameter is nullptr."); - return Operator(); - } - const std::string op_name = operator_name; - const std::string op_type = operator_type; - return OperatorFactoryImpl::CreateOperator(op_name, op_type); -} - -graphStatus OperatorFactory::GetOpsTypeList(std::vector &all_ops) { - return OperatorFactoryImpl::GetOpsTypeList(all_ops); -} - -graphStatus OperatorFactory::GetOpsTypeList(std::vector &all_ops) { - std::vector all_op_types; - if (OperatorFactoryImpl::GetOpsTypeList(all_op_types) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Get ops type list failed."); - GELOGE(GRAPH_FAILED, "[Get][OpsTypeList] failed."); - return GRAPH_FAILED; - } - for (auto &op_type : all_op_types) { - all_ops.emplace_back(op_type.c_str()); - } - return GRAPH_SUCCESS; -} - -bool OperatorFactory::IsExistOp(const std::string &operator_type) { - return OperatorFactoryImpl::IsExistOp(operator_type); -} - -bool OperatorFactory::IsExistOp(const char_t *const operator_type) { - if (operator_type == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Operator type is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Operator type is nullptr."); - return false; - } - const std::string op_type = operator_type; - return OperatorFactoryImpl::IsExistOp(op_type); -} - -OperatorCreatorRegister::OperatorCreatorRegister(const std::string &operator_type, OpCreator const &op_creator) { - (void)OperatorFactoryImpl::RegisterOperatorCreator(operator_type, op_creator); -} - -OperatorCreatorRegister::OperatorCreatorRegister(const char_t *const operator_type, OpCreatorV2 const &op_creator) { - std::string op_type; - if (operator_type != nullptr) { - op_type = operator_type; - } - (void)OperatorFactoryImpl::RegisterOperatorCreator(op_type, op_creator); -} - -InferShapeFuncRegister::InferShapeFuncRegister(const std::string &operator_type, - const InferShapeFunc &infer_shape_func) { - (void)OperatorFactoryImpl::RegisterInferShapeFunc(operator_type, infer_shape_func); -} - -InferShapeFuncRegister::InferShapeFuncRegister(const char_t *const operator_type, - const InferShapeFunc &infer_shape_func) { - std::string op_type; - if (operator_type != nullptr) { - op_type = operator_type; - } - (void)OperatorFactoryImpl::RegisterInferShapeFunc(op_type, infer_shape_func); -} - -InferFormatFuncRegister::InferFormatFuncRegister(const std::string &operator_type, - const InferFormatFunc &infer_format_func) { - (void)OperatorFactoryImpl::RegisterInferFormatFunc(operator_type, infer_format_func); -} - -InferFormatFuncRegister::InferFormatFuncRegister(const char_t *const operator_type, - const InferFormatFunc &infer_format_func) { - std::string op_type; - if (operator_type != nullptr) { - op_type = operator_type; - } - (void)OperatorFactoryImpl::RegisterInferFormatFunc(op_type, infer_format_func); -} - -InferValueRangeFuncRegister::InferValueRangeFuncRegister(const char_t *const operator_type, - const WHEN_CALL when_call, - const InferValueRangeFunc &infer_value_range_func) { - std::string op_type; - if (operator_type != nullptr) { - op_type = operator_type; - } - (void)OperatorFactoryImpl::RegisterInferValueRangeFunc(op_type, when_call, false, infer_value_range_func); -} - -InferValueRangeFuncRegister::InferValueRangeFuncRegister(const char_t *const operator_type) { - std::string op_type; - if (operator_type != nullptr) { - op_type = operator_type; - } - (void)OperatorFactoryImpl::RegisterInferValueRangeFunc(op_type); -} - -VerifyFuncRegister::VerifyFuncRegister(const std::string &operator_type, const VerifyFunc &verify_func) { - (void)OperatorFactoryImpl::RegisterVerifyFunc(operator_type, verify_func); -} - -VerifyFuncRegister::VerifyFuncRegister(const char_t *const operator_type, const VerifyFunc &verify_func) { - std::string op_type; - if (operator_type != nullptr) { - op_type = operator_type; - } - (void)OperatorFactoryImpl::RegisterVerifyFunc(op_type, verify_func); -} -} // namespace ge diff --git a/graph/normal_graph/operator_factory_impl.cc b/graph/normal_graph/operator_factory_impl.cc deleted file mode 100644 index 4818300acd5d10a816241f4ef38ded022793dc43..0000000000000000000000000000000000000000 --- a/graph/normal_graph/operator_factory_impl.cc +++ /dev/null @@ -1,419 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/operator_factory_impl.h" - -#include - -#include "debug/ge_log.h" -#include "common/util/mem_utils.h" - -namespace ge { -namespace { - std::atomic is_register_overridable(false); -} -std::shared_ptr> OperatorFactoryImpl::operator_creators_; -std::shared_ptr> OperatorFactoryImpl::operator_creators_v2_; -std::shared_ptr> OperatorFactoryImpl::operator_infershape_funcs_; -std::shared_ptr> OperatorFactoryImpl::operator_inferformat_funcs_; -std::shared_ptr> OperatorFactoryImpl::operator_verify_funcs_; -std::shared_ptr> OperatorFactoryImpl::operator_infer_data_slice_funcs_; -std::shared_ptr> OperatorFactoryImpl::operator_infer_value_range_paras_; -std::shared_ptr> OperatorFactoryImpl::operator_infer_axis_slice_funcs_; -std::shared_ptr> OperatorFactoryImpl::operator_infer_axis_type_info_funcs_; -InferShapeV2Func OperatorFactoryImpl::operator_infer_shape_v2_func_ = nullptr; -InferDataTypeFunc OperatorFactoryImpl::operator_infer_datatype_func_ = nullptr; -InferShapeRangeFunc OperatorFactoryImpl::operator_infer_shape_range_func_ = nullptr; -InferFormatV2Func OperatorFactoryImpl::operator_infer_format_v2_func_ = nullptr; -IsInferFormatV2RegisteredFunc OperatorFactoryImpl::is_infer_format_v2_registered_func_ = nullptr; -Operator OperatorFactoryImpl::CreateOperator(const std::string &operator_name, const std::string &operator_type) { - if (operator_creators_v2_ != nullptr) { - const std::map::const_iterator - it_v2 = operator_creators_v2_->find(operator_type); - if (it_v2 != operator_creators_v2_->cend()) { - return it_v2->second(operator_name.c_str()); - } else { - GELOGW("[Create][Operator] No op_proto of [%s] registered by AscendString.", operator_type.c_str()); - } - } - if (operator_creators_ == nullptr) { - return Operator(); - } - const std::map::const_iterator it = operator_creators_->find(operator_type); - if (it == operator_creators_->cend()) { - GELOGW("[Create][Operator] No op_proto of [%s] registered by string.", operator_type.c_str()); - return Operator(); - } - return it->second(operator_name); -} - -graphStatus OperatorFactoryImpl::GetOpsTypeList(std::vector &all_ops) { - all_ops.clear(); - if (operator_creators_v2_ != nullptr) { - all_ops.resize(operator_creators_v2_->size()); - (void)std::transform( - operator_creators_v2_->begin(), operator_creators_v2_->end(), all_ops.begin(), - [](const std::pair &operator_creator_v2) { return operator_creator_v2.first; }); - return GRAPH_SUCCESS; - } else { - GELOGW("[Get][OpsTypeList] Ops not registered by AscendString."); - } - - if (operator_creators_ != nullptr) { - all_ops.resize(operator_creators_->size()); - (void)std::transform( - operator_creators_->begin(), operator_creators_->end(), all_ops.begin(), - [](const std::pair &operator_creator) { return operator_creator.first; }); - } else { - REPORT_INNER_ERR_MSG("E18888", "no operator creators found"); - GELOGE(GRAPH_FAILED, "[Check][Param] no operator creators found"); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -bool OperatorFactoryImpl::IsExistOp(const std::string &operator_type) { - if (operator_creators_v2_ != nullptr) { - const std::map::const_iterator it_v2 = operator_creators_v2_->find(operator_type); - if (it_v2 != operator_creators_v2_->cend()) { - return true; - } - } - - if (operator_creators_ == nullptr) { - return false; - } - const std::map::const_iterator it = operator_creators_->find(operator_type); - if (it == operator_creators_->cend()) { - return false; - } - return true; -} - -InferShapeFunc OperatorFactoryImpl::GetInferShapeFunc(const std::string &operator_type) { - if (operator_infershape_funcs_ == nullptr) { - return nullptr; - } - const std::map::const_iterator - it = operator_infershape_funcs_->find(operator_type); - if (it == operator_infershape_funcs_->cend()) { - return nullptr; - } - return it->second; -} - -InferShapeV2Func OperatorFactoryImpl::GetInferShapeV2Func() { - return operator_infer_shape_v2_func_; -} - -InferDataTypeFunc OperatorFactoryImpl::GetInferDataTypeFunc() { - return operator_infer_datatype_func_; -} - -InferShapeRangeFunc OperatorFactoryImpl::GetInferShapeRangeFunc() { - return operator_infer_shape_range_func_; -} - -InferFormatFunc OperatorFactoryImpl::GetInferFormatFunc(const std::string &operator_type) { - if (operator_inferformat_funcs_ == nullptr) { - GELOGI("operator_inferformat_funcs_ is null"); - return nullptr; - } - const std::map::const_iterator - it = operator_inferformat_funcs_->find(operator_type); - if (it == operator_inferformat_funcs_->cend()) { - return nullptr; - } - return it->second; -} - -InferValueRangePara OperatorFactoryImpl::GetInferValueRangePara(const std::string &operator_type) { - const InferValueRangePara ret_para; - if (operator_infer_value_range_paras_ == nullptr) { - GELOGI("operator_infervalue_paras_ is null, operator infer value registration is none"); - return ret_para; - } - const std::map::const_iterator - it = operator_infer_value_range_paras_->find(operator_type); - if (it == operator_infer_value_range_paras_->end()) { - GELOGD("optype[%s] has not registered infer value func", operator_type.c_str()); - return ret_para; - } - return it->second; -} - -VerifyFunc OperatorFactoryImpl::GetVerifyFunc(const std::string &operator_type) { - if (operator_verify_funcs_ == nullptr) { - return nullptr; - } - const std::map::const_iterator - it = operator_verify_funcs_->find(operator_type); - if (it == operator_verify_funcs_->cend()) { - return nullptr; - } - return it->second; -} - -InferDataSliceFunc OperatorFactoryImpl::GetInferDataSliceFunc(const std::string &operator_type) { - if (operator_infer_data_slice_funcs_ == nullptr) { - return nullptr; - } - const std::map::const_iterator - it = operator_infer_data_slice_funcs_->find(operator_type); - if (it == operator_infer_data_slice_funcs_->cend()) { - return nullptr; - } - return it->second; -} - -void OperatorFactoryImpl::SetRegisterOverridable(const bool &is_overridable) { - is_register_overridable.store(is_overridable); -} - -graphStatus OperatorFactoryImpl::RegisterOperatorCreator(const std::string &operator_type, - OpCreator const &op_creator) { - if (operator_creators_ == nullptr) { - operator_creators_ = MakeShared>(); - GE_CHECK_NOTNULL(operator_creators_); - } - const std::map::const_iterator it = operator_creators_->find(operator_type); - if (it != operator_creators_->cend()) { - return GRAPH_FAILED; - } - (void)operator_creators_->emplace(operator_type, op_creator); - GELOGD("Register operator creator for %s.", operator_type.c_str()); - return GRAPH_SUCCESS; -} - -graphStatus OperatorFactoryImpl::RegisterOperatorCreator(const std::string &operator_type, - OpCreatorV2 const &op_creator) { - if (operator_creators_v2_ == nullptr) { - operator_creators_v2_ = MakeShared>(); - GE_CHECK_NOTNULL(operator_creators_v2_); - } - auto it = operator_creators_v2_->find(operator_type); - if (it != operator_creators_v2_->cend()) { - if (is_register_overridable.load()) { - GELOGD("Override creator v2 for %s.", operator_type.c_str()); - it->second = op_creator; - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; - } - (void)operator_creators_v2_->emplace(operator_type, op_creator); - GELOGD("Register creator v2 for %s.", operator_type.c_str()); - return GRAPH_SUCCESS; -} - -graphStatus OperatorFactoryImpl::RegisterInferShapeFunc(const std::string &operator_type, - InferShapeFunc const infer_shape_func) { - if (operator_infershape_funcs_ == nullptr) { - GELOGI("operator_infershape_funcs_ init"); - operator_infershape_funcs_ = MakeShared>(); - GE_CHECK_NOTNULL(operator_infershape_funcs_); - } - const std::map::const_iterator - it = operator_infershape_funcs_->find(operator_type); - if (it != operator_infershape_funcs_->cend()) { - GELOGW("op [%s] has registered infer func", operator_type.c_str()); - return GRAPH_FAILED; - } - GELOGD("Register infer func for type: %s.", operator_type.c_str()); - (void)operator_infershape_funcs_->emplace(operator_type, infer_shape_func); - return GRAPH_SUCCESS; -} - -void OperatorFactoryImpl::RegisterInferShapeV2Func(InferShapeV2Func const infer_shape_func) { - if (operator_infer_shape_v2_func_ == nullptr) { - GELOGI("operator infer shape v2 funcs init"); - operator_infer_shape_v2_func_ = infer_shape_func; - } -} - -void OperatorFactoryImpl::RegisterInferDataTypeFunc(InferDataTypeFunc const infer_data_type_func) { - if (operator_infer_datatype_func_ == nullptr) { - GELOGI("operator infer data type funcs init"); - operator_infer_datatype_func_ = infer_data_type_func; - } -} - -void OperatorFactoryImpl::RegisterInferShapeRangeFunc(InferShapeRangeFunc const infer_shape_range_func) { - if (operator_infer_shape_range_func_ == nullptr) { - GELOGI("operator infer shape range funcs init"); - operator_infer_shape_range_func_ = infer_shape_range_func; - } -} - -graphStatus OperatorFactoryImpl::RegisterInferFormatFunc(const std::string &operator_type, - InferFormatFunc const infer_format_func) { - if (operator_inferformat_funcs_ == nullptr) { - GELOGI("operator_inferformat_funcs_ init"); - operator_inferformat_funcs_ = MakeShared>(); - GE_CHECK_NOTNULL(operator_inferformat_funcs_); - } - const std::map::const_iterator - it = operator_inferformat_funcs_->find(operator_type); - if (it != operator_inferformat_funcs_->cend()) { - return GRAPH_FAILED; - } - (void)operator_inferformat_funcs_->emplace(operator_type, infer_format_func); - return GRAPH_SUCCESS; -} - -graphStatus OperatorFactoryImpl::RegisterVerifyFunc(const std::string &operator_type, VerifyFunc const verify_func) { - if (operator_verify_funcs_ == nullptr) { - GELOGI("operator_verify_funcs_ init"); - operator_verify_funcs_ = MakeShared>(); - GE_CHECK_NOTNULL(operator_verify_funcs_); - } - const std::map::const_iterator it = operator_verify_funcs_->find(operator_type); - if (it != operator_verify_funcs_->cend()) { - return GRAPH_FAILED; - } - (void)operator_verify_funcs_->emplace(operator_type, verify_func); - return GRAPH_SUCCESS; -} - -graphStatus OperatorFactoryImpl::RegisterInferDataSliceFunc(const std::string &operator_type, - InferDataSliceFunc const infer_data_slice_func) { - if (operator_infer_data_slice_funcs_ == nullptr) { - GELOGI("operator_infer_data_slice_funcs_ init"); - operator_infer_data_slice_funcs_ = MakeShared>(); - GE_CHECK_NOTNULL(operator_infer_data_slice_funcs_); - } - const std::map::const_iterator - it = operator_infer_data_slice_funcs_->find(operator_type); - if (it != operator_infer_data_slice_funcs_->cend()) { - return GRAPH_FAILED; - } - (void)operator_infer_data_slice_funcs_->emplace(operator_type, infer_data_slice_func); - return GRAPH_SUCCESS; -} - -graphStatus OperatorFactoryImpl::RegisterInferValueRangeFunc(const std::string &operator_type) { - return RegisterInferValueRangeFunc(operator_type, INPUT_HAS_VALUE_RANGE, - true, nullptr); -} - -graphStatus OperatorFactoryImpl::RegisterInferValueRangeFunc(const std::string &operator_type, - const WHEN_CALL when_call, - const bool use_cpu_kernel, - const InferValueRangeFunc &infer_value_range_func) { - if (operator_infer_value_range_paras_ == nullptr) { - GELOGI("operator_infervalue_paras_ init"); - operator_infer_value_range_paras_ = MakeShared>(); - GE_CHECK_NOTNULL(operator_infer_value_range_paras_); - } - const std::map::const_iterator - it = operator_infer_value_range_paras_->find(operator_type); - if (it != operator_infer_value_range_paras_->cend()) { - GELOGW("optype[%s] has registered infervalue func", operator_type.c_str()); - return GRAPH_FAILED; - } - InferValueRangePara tmp_para(when_call, use_cpu_kernel, infer_value_range_func); - (void)operator_infer_value_range_paras_->emplace(operator_type, tmp_para); - - GELOGD("Optype[%s] infervalue func registered successfully, when_call = %d, use_cpu_kernel = %d", - operator_type.c_str(), static_cast(when_call), static_cast(use_cpu_kernel)); - return GRAPH_SUCCESS; -} - -InferAxisSliceFunc OperatorFactoryImpl::GetInferAxisSliceFunc(const std::string &operator_type) { - if (operator_infer_axis_slice_funcs_ == nullptr) { - return nullptr; - } - const std::map::const_iterator - it = operator_infer_axis_slice_funcs_->find(operator_type); - if (it == operator_infer_axis_slice_funcs_->cend()) { - return nullptr; - } - return it->second; -} - -graphStatus OperatorFactoryImpl::RegisterInferAxisSliceFunc(const std::string &operator_type, - const InferAxisSliceFunc &infer_axis_slice_func) { - if (operator_infer_axis_slice_funcs_ == nullptr) { - GELOGI("axis slice derivation funcs init"); - operator_infer_axis_slice_funcs_ = MakeShared>(); - GE_CHECK_NOTNULL(operator_infer_axis_slice_funcs_); - } - const std::map::const_iterator - it = operator_infer_axis_slice_funcs_->find(operator_type); - if (it != operator_infer_axis_slice_funcs_->cend()) { - return GRAPH_FAILED; - } - (void)operator_infer_axis_slice_funcs_->emplace(operator_type, infer_axis_slice_func); - return GRAPH_SUCCESS; -} - -InferAxisTypeInfoFunc OperatorFactoryImpl::GetInferAxisTypeInfoFunc(const std::string &operator_type) { - if (operator_infer_axis_type_info_funcs_ == nullptr) { - return nullptr; - } - const std::map::const_iterator - it = operator_infer_axis_type_info_funcs_->find(operator_type); - if (it == operator_infer_axis_type_info_funcs_->cend()) { - return nullptr; - } - return it->second; -} - -graphStatus OperatorFactoryImpl::RegisterInferAxisTypeInfoFunc(const std::string &operator_type, - const InferAxisTypeInfoFunc &infer_axis_type_info_func) { - if (operator_infer_axis_type_info_funcs_ == nullptr) { - GELOGI("axis type info derivation funcs init"); - operator_infer_axis_type_info_funcs_ = MakeShared>(); - GE_CHECK_NOTNULL(operator_infer_axis_type_info_funcs_); - } - const std::map::const_iterator - it = operator_infer_axis_type_info_funcs_->find(operator_type); - if (it != operator_infer_axis_type_info_funcs_->cend()) { - GELOGW("optype[%s] has registered axis type info func", operator_type.c_str()); - return GRAPH_FAILED; - } - (void)operator_infer_axis_type_info_funcs_->emplace(operator_type, infer_axis_type_info_func); - return GRAPH_SUCCESS; -} - -void OperatorFactoryImpl::RegisterInferFormatV2Func(InferFormatV2Func const infer_format_func) { - if (operator_infer_format_v2_func_ == nullptr) { - GELOGI("operator infer format v2 funcs init"); - operator_infer_format_v2_func_ = infer_format_func; - } -} - -InferFormatV2Func OperatorFactoryImpl::GetInferFormatV2Func() { - return operator_infer_format_v2_func_; -} - -void OperatorFactoryImpl::RegisterIsInferFormatV2RegisteredFunc( - IsInferFormatV2RegisteredFunc const is_infer_format_v2_registered_func) { - if (is_infer_format_v2_registered_func_ == nullptr) { - GELOGI("operator is_infer_format_v2_registered funcs init"); - is_infer_format_v2_registered_func_ = is_infer_format_v2_registered_func; - } -} - -IsInferFormatV2RegisteredFunc OperatorFactoryImpl::GetIsInferFormatV2RegisteredFunc() { - return is_infer_format_v2_registered_func_; -} - -void OperatorFactoryImpl::ReleaseRegInfo() { - ge::OperatorFactoryImpl::operator_infer_axis_type_info_funcs_ = nullptr; - ge::OperatorFactoryImpl::operator_infer_axis_slice_funcs_ = nullptr; - ge::OperatorFactoryImpl::operator_infer_value_range_paras_ = nullptr; - ge::OperatorFactoryImpl::operator_infer_data_slice_funcs_ = nullptr; - ge::OperatorFactoryImpl::operator_verify_funcs_ = nullptr; - ge::OperatorFactoryImpl::operator_inferformat_funcs_ = nullptr; - ge::OperatorFactoryImpl::operator_infershape_funcs_ = nullptr; - ge::OperatorFactoryImpl::operator_creators_v2_ = nullptr; - ge::OperatorFactoryImpl::operator_creators_ = nullptr; - GELOGI("Release ops proto reg info success."); -} -} // namespace ge diff --git a/graph/normal_graph/operator_impl.cc b/graph/normal_graph/operator_impl.cc deleted file mode 100644 index 775a4478170ab9dbb340090e6e2b64635ae40738..0000000000000000000000000000000000000000 --- a/graph/normal_graph/operator_impl.cc +++ /dev/null @@ -1,595 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/normal_graph/operator_impl.h" - -#include "graph/normal_graph/op_io.h" -#include "debug/ge_log.h" -#include "debug/ge_util.h" -#include "debug/ge_op_types.h" -#include "graph/compute_graph.h" -#include "graph/ge_context.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/tensor_adapter.h" -#include "graph/utils/constant_utils.h" -#include "graph/utils/node_utils.h" -#include "common/checker.h" - -namespace ge { -OperatorImpl::OperatorImpl(const std::string &name, const std::string &type) - : enable_shared_from_this(), op_desc_(ComGraphMakeShared(name, type)) { - if (op_desc_ == nullptr) { - GELOGW("[Check][Param] Make op_desc failed"); - } -} - -OperatorImpl::OperatorImpl(const OpDescPtr &op_desc) : enable_shared_from_this(), op_desc_(op_desc) {} - -OperatorImpl::OperatorImpl(const ConstNodePtr node) : enable_shared_from_this(), node_(node) { - if ((node_ != nullptr) && (node_->GetOpDesc() != nullptr)) { - op_desc_ = node_->GetOpDesc(); - } -} - -OperatorImpl::~OperatorImpl() {} - -void OperatorImpl::SetInputImpl(const std::string &dst_name, const ge::Operator &src_oprt) { - if (src_oprt.GetOutputsSize() != 1U) { - if ((src_oprt.operator_impl_ == nullptr) || (src_oprt.operator_impl_->op_desc_ == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "The source op is nullptr, check invalid."); - return; - } - GELOGE(ge::FAILED, "[Check][Param] The source operator[%s] must be single output operator", - src_oprt.operator_impl_->op_desc_->GetName().c_str()); - REPORT_INNER_ERR_MSG("E18888", "The source operator[%s] must be single output operator", - src_oprt.operator_impl_->op_desc_->GetName().c_str()); - return; - } - - const auto out_handler = src_oprt.GetOutput(0U); - if (out_handler == nullptr) { - return; - } - - return SetInputImpl(dst_name, out_handler); -} - -void OperatorImpl::SetInputImpl(const std::string &dst_name, const ge::OutHandler &out_handler) { - GE_CHK_BOOL_EXEC(out_handler != nullptr, REPORT_INNER_ERR_MSG("E18888", "param out_handler is nullptr, check invalid."); - return, "[Check][Param] SetInputImpl faild, as out_handler is nullptr."); - GE_CHK_BOOL_EXEC(!dst_name.empty(), REPORT_INNER_ERR_MSG("E18888", "param dst_name is empty, check invalid."); - return, "[Check][Param] dst name is empty"); - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "op_desc_ is nullptr."); - return, "[Check][Param] op_desc_ is nullptr."); - (void)input_link_.insert(std::make_pair(dst_name, *out_handler)); - - const std::string src_name = out_handler->GetName(); - const int32_t dst_index = op_desc_->GetInputIndexByName(dst_name); - GE_CHK_BOOL_EXEC(dst_index >= 0, - REPORT_INNER_ERR_MSG("E18888", "Find input index by name failed. name[%s], op name:%s", - dst_name.c_str(), op_desc_->GetName().c_str()); - return, "[Get][InputIndex] Find input index by name failed. name[%s], op name:%s", dst_name.c_str(), - op_desc_->GetName().c_str()); - const auto out_op_impl = out_handler->GetOwner(); - GE_CHK_BOOL_EXEC((out_op_impl != nullptr) && (out_op_impl->GetOpDescImpl() != nullptr), - REPORT_INNER_ERR_MSG("E18888", "out_handler invalid. name[%s]", dst_name.c_str()); - return, "[Get][Impl] out_handler invalid. name[%s]", dst_name.c_str()); - bool is_const = false; - if (out_op_impl->GetOpDescImpl()->GetType() == CONSTANT) { - is_const = true; - } - auto is_input_const = op_desc_->GetIsInputConst(); - for (int32_t i = static_cast(is_input_const.size()); i <= dst_index; ++i) { - is_input_const.push_back(false); - } - is_input_const[static_cast(dst_index)] = is_const; - op_desc_->SetIsInputConst(is_input_const); - - const OpIO in_handler(dst_name, dst_index, shared_from_this()); - GE_CHK_BOOL_EXEC(out_op_impl != nullptr, - REPORT_INNER_ERR_MSG("E18888", "out_handler invalid. name[%s]", dst_name.c_str()); - return, "[Get][Impl] of out_handler failed."); - - out_op_impl->UpdateLinkMapImpl(src_name, in_handler); - auto src_output_desc = out_op_impl->GetOutputDesc(src_name); - const auto dst_input_desc = op_desc_->GetInputDesc(dst_name); - if (dst_input_desc.GetFormat() == FORMAT_RESERVED) { - src_output_desc.SetFormat(FORMAT_ND); - src_output_desc.SetOriginFormat(FORMAT_ND); - } else { - src_output_desc.SetFormat(dst_input_desc.GetFormat()); - src_output_desc.SetOriginFormat(dst_input_desc.GetOriginFormat()); - } - // clear src tensor attr - for (const auto &attr : src_output_desc.GetAllAttrs()) { - (void) src_output_desc.DelAttr(attr.first); - } - // add dst tensor attr - for (const auto &attr : dst_input_desc.GetAllAttrs()) { - (void) src_output_desc.SetAttr(attr.first, attr.second); - } - - GE_CHK_BOOL_EXEC(op_desc_->UpdateInputDesc(dst_name, src_output_desc) == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "UpdateInputDesc failed, dst name is %s, src name is %s", - dst_name.c_str(), src_name.c_str()); - return, "[Update][InputDesc] failed, dst name is %s, src name is %s", dst_name.c_str(), - src_name.c_str()); // fix for linking opdesc -} - -void OperatorImpl::AddControlInputImp(const ge::Operator &src_oprt) { - if (src_oprt.operator_impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Src operator impl is nullptr, check invalid"); - GELOGE(FAILED, "[Check][Param] Src operator impl is nullptr"); - return; - } - for (auto &input : control_input_link_) { - if (input.lock() == src_oprt.operator_impl_) { - return; - } - } - control_input_link_.push_back(src_oprt.operator_impl_); - src_oprt.operator_impl_->control_output_link_.push_back(shared_from_this()); -} - -graphStatus OperatorImpl::GetInputImpl(const std::string &dst_name, ge::OpIO &out_handler) const { - const auto out = input_link_.find(dst_name); - if (out == input_link_.end()) { - return GRAPH_FAILED; - } - out_handler = out->second; - return GRAPH_SUCCESS; -} - -graphStatus OperatorImpl::GetInputImpl(const uint32_t idx, ge::OpIO &out_handler) const { - GE_CHECK_NOTNULL(op_desc_); - const std::string dst_name = op_desc_->GetInputNameByIndex(idx); - return GetInputImpl(dst_name, out_handler); -} - -namespace { -graphStatus GetFromInputDesc(const OpDescPtr &op_desc, const int32_t index, ConstGeTensorPtr &ge_tensor) { - // if tensor has host mem, init data by ATTR_NAME_VALUE first - const auto tensor = op_desc->MutableInputDesc(static_cast(index)); - GeTensorPtr tensor_value = nullptr; - if (AttrUtils::MutableTensor(tensor, ATTR_NAME_VALUE, tensor_value)) { - GELOGD("Get ATTR_NAME_VALUE from %d input of %s, Tensor addr is %p, tensor value data type is %d.", index, - op_desc->GetName().c_str(), tensor.get(), tensor_value->GetTensorDesc().GetDataType()); - ge_tensor = tensor_value; - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} -} // namespace - - -graphStatus OperatorImpl::GetFromPeerNode(NodePtr &peer_node, - const OutDataAnchorPtr &out_data_anchor, - ConstGeTensorPtr &ge_tensor) const { - auto peer_node_2_out_anchor = std::make_pair(peer_node, out_data_anchor); - if ((peer_node->GetType() == ENTER) || (peer_node->GetType() == REFENTER)) { - const auto enter_in_data_anchor = peer_node->GetInDataAnchor(0); - GE_CHECK_NOTNULL(enter_in_data_anchor); - const auto enter_peer_out_data_anchor = enter_in_data_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(enter_peer_out_data_anchor); - peer_node = enter_peer_out_data_anchor->GetOwnerNode(); - peer_node_2_out_anchor.first = peer_node; - peer_node_2_out_anchor.second = enter_peer_out_data_anchor; - } - const auto peer_op_desc = peer_node->GetOpDesc(); - GE_CHECK_NOTNULL(peer_op_desc); - const auto peer_op_type = peer_op_desc->GetType(); - if (ConstantUtils::IsConstant(peer_op_desc)) { - return ConstantUtils::GetWeight(peer_op_desc, static_cast(peer_node_2_out_anchor.second->GetIdx()), - ge_tensor) ? GRAPH_SUCCESS : GRAPH_FAILED; - } - if (peer_op_type == FILECONSTANT) { - return ConstantUtils::GetWeightFromFile(peer_op_desc, ge_tensor) ? GRAPH_SUCCESS : GRAPH_FAILED; - } - // Place holder operator, try to get the weight from `parentNode`; - // `parentNode` is the real node of the placeholder node in engine partition graph - if (peer_op_type == PLACEHOLDER) { - if ((NodeUtils::TryGetWeightByPlaceHolderNode(peer_node, ge_tensor) != GRAPH_SUCCESS) || (ge_tensor == nullptr)) { - return GRAPH_FAILED; - } else { - return GRAPH_SUCCESS; - } - } - - if (peer_op_type == DATA) { - if ((NodeUtils::TryGetWeightByDataNode(peer_node, ge_tensor) != GRAPH_SUCCESS) || (ge_tensor == nullptr)) { - return GRAPH_FAILED; - } else { - return GRAPH_SUCCESS; - } - } - return GRAPH_FAILED; -} - -graphStatus OperatorImpl::GetInputConstData(const std::string &dst_name, Tensor &data) { - GE_CHECK_NOTNULL(op_desc_); - const auto index = op_desc_->GetInputIndexByName(dst_name); - ConstGeTensorPtr ge_tensor = nullptr; - if (GetInputConstData(static_cast(index), ge_tensor) == GRAPH_SUCCESS) { - data = TensorAdapter::GeTensor2Tensor(ge_tensor); - return GRAPH_SUCCESS; - } - - return GRAPH_FAILED; -} - -graphStatus OperatorImpl::GetInputConstData(const uint32_t idx, ConstGeTensorPtr &ge_tensor) const { - if (ge_tensor != nullptr) { - GELOGE(GRAPH_PARAM_INVALID, "ge_tensor already has value"); - return GRAPH_PARAM_INVALID; - } - const auto node = GetNode(); - if (node == nullptr) { - // for out graph - return GetInputConstDataOut(idx, ge_tensor); - } - // from runtime context - if (get_const_input_runtime_ != nullptr) { - GeTensorPtr tensor_value = nullptr; - GE_CHK_GRAPH_STATUS_RET(get_const_input_runtime_(node, idx, tensor_value), - "Fail to get %d const input of %s from context.", idx, node->GetName().c_str()); - ge_tensor = tensor_value; - return GRAPH_SUCCESS; - } - - const auto in_data_anchor = node->GetInDataAnchor(static_cast(idx)); - GE_CHECK_NOTNULL(in_data_anchor); - const auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); - if (out_data_anchor == nullptr) { - GELOGW("[Check][op: %s][Param:out_data_anchor] is null, idx : %u.", GetName().c_str(), idx); - return ge::PARAM_INVALID; - } - auto peer_node = out_data_anchor->GetOwnerNode(); - if (runtime_context_ != nullptr) { - // deprecated, will delete when air support - GeTensorPtr tensor_value = nullptr; - if (runtime_context_->GetTensor(peer_node->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), tensor_value) == - GRAPH_SUCCESS) { - ge_tensor = tensor_value; - return GRAPH_SUCCESS; - } - } - const auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - // from input desc - if (GetFromInputDesc(op_desc, static_cast(idx), ge_tensor) == GRAPH_SUCCESS) { - return GRAPH_SUCCESS; - } - // from peer node - return GetFromPeerNode(peer_node, out_data_anchor, ge_tensor); -} - -graphStatus OperatorImpl::GetInputConstDataOut(const uint32_t idx, ConstGeTensorPtr &ge_tensor) const { - ge::OpIO out_handle("", 0, nullptr); - if (GetInputImpl(idx, out_handle) != GRAPH_SUCCESS) { - GELOGW("[Get][InputImpl] failed, op name: %s, input index: %u", GetName().c_str(), idx); - return GRAPH_FAILED; - } - if ((out_handle.GetOwner() != nullptr) && (out_handle.GetOwner()->GetOpDescImpl() != nullptr)) { - const auto &op_desc_impl_type = out_handle.GetOwner()->GetOpDescImpl()->GetType(); - const auto op_desc = out_handle.GetOwner()->GetOpDescImpl(); - if ((op_desc_impl_type == CONSTANTOP) || (op_desc_impl_type == CONSTANT)) { - if (AttrUtils::GetTensor(op_desc, ATTR_NAME_WEIGHTS, ge_tensor)) { - return GRAPH_SUCCESS; - } - } - if (op_desc_impl_type == FILECONSTANT) { - if (ConstantUtils::GetWeightFromFile(op_desc, ge_tensor)) { - return GRAPH_SUCCESS; - } - } - } - return GRAPH_FAILED; -} - -graphStatus OperatorImpl::GetInputConstDataOut(const std::string &dst_name, Tensor &data) const { - ge::OpIO out_handle("", 0, nullptr); - if (GetInputImpl(dst_name, out_handle) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "%s get input impl failed", dst_name.c_str()); - GELOGE(FAILED, "[Get][InputImpl] failed, dst_name:%s", dst_name.c_str()); - return GRAPH_FAILED; - } - if ((out_handle.GetOwner() != nullptr) && (out_handle.GetOwner()->GetOpDescImpl() != nullptr)) { - const Operator const_op(out_handle.GetOwner()); - const auto &op_desc_impl_type = out_handle.GetOwner()->GetOpDescImpl()->GetType(); - if ((op_desc_impl_type == CONSTANTOP) || (op_desc_impl_type == CONSTANT)) { - return const_op.GetAttr(ATTR_NAME_WEIGHTS.c_str(), data); - } - if (op_desc_impl_type == FILECONSTANT) { - const auto op_desc = out_handle.GetOwner()->GetOpDescImpl(); - ConstGeTensorPtr ge_tensor = nullptr; - if (ConstantUtils::GetWeightFromFile(op_desc, ge_tensor)) { - data = TensorAdapter::GeTensor2Tensor(ge_tensor); - return GRAPH_SUCCESS; - } - } - } - return GRAPH_FAILED; -} - -bool OperatorImpl::InputIsSet(const std::string &name) { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "op_desc_ is nullptr, check invalid."); - return false, "[Check][Param] op_desc_ is nullptr."); - return op_desc_->InputIsSet(name); -} - -std::string OperatorImpl::GetName() const { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "op_desc_ is nullptr, check invalid."); - return std::string(), "[Check][Param] op_desc_ is nullptr."); - return op_desc_->GetName(); -} - -GeTensorDesc OperatorImpl::GetInputDesc(const std::string &name) const { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "op_desc_ is nullptr, check invalid."); - return GeTensorDesc(), "[Check][Param] op_desc_ is nullptr."); - return op_desc_->GetInputDesc(name); -} - -GeTensorDesc OperatorImpl::GetInputDesc(const uint32_t index) const { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "op_desc_ is nullptr, check invalid."); - return GeTensorDesc(), "[Check][Param] op_desc_ is nullptr."); - return op_desc_->GetInputDesc(index); -} - -GeTensorDescPtr OperatorImpl::MutableInputDesc(const std::string &name) { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "op_desc_ is nullptr, check invalid."); - return nullptr, "[Check][Param] op_desc_ is nullptr."); - return op_desc_->MutableInputDesc(name); -} - -GeTensorDescPtr OperatorImpl::MutableInputDesc(const uint32_t index) { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "op_desc_ is nullptr, check invalid."); - return nullptr, "[Check][Param] op_desc_ is nullptr."); - return op_desc_->MutableInputDesc(index); -} - -graphStatus OperatorImpl::UpdateInputDesc(const std::string &name, const GeTensorDesc &tensor_desc) { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "op_desc_ is nullptr, check invalid."); - return GRAPH_FAILED, "[Check][Param] op_desc_ is nullptr."); - - return op_desc_->UpdateInputDesc(name, tensor_desc); -} - -OutHandler OperatorImpl::GetOutput(const std::string &name) { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "op_desc_ is nullptr, check invalid."); - return nullptr, "[Check][Param] op_desc_ is nullptr."); - - int32_t src_index = op_desc_->GetOutputIndexByName(name); - GE_CHK_BOOL_EXEC(src_index >= 0, - REPORT_INNER_ERR_MSG("E18888", "Find src index by name failed. name[%s]", name.c_str()); - return nullptr, "[Get][OutputIndex] Find src index by name failed. name[%s]", name.c_str()); - const shared_ptr output_ptr = ComGraphMakeShared(name, src_index, shared_from_this()); - if (output_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "OpIO make shared failed"); - GELOGE(GRAPH_FAILED, "[Call][ComGraphMakeShared] OpIO make shared failed"); - return nullptr; - } - return output_ptr; -} - -OutHandler OperatorImpl::GetOutput(uint32_t index) { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "op_desc_ is nullptr, check invalid."); - return nullptr, "[Check][Param] op_desc_ is nullptr."); - std::string name = op_desc_->GetOutputNameByIndex(index); - if (name.empty()) { - REPORT_INNER_ERR_MSG("E18888", "Find src name by index failed. index[%u]", index); - GELOGE(GRAPH_FAILED, "[Get][OutputName] Find src name by index failed. index[%u]", index); - return nullptr; - } - const shared_ptr output_ptr = ComGraphMakeShared(name, index, shared_from_this()); - if (output_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "OpIO make shared failed"); - GELOGE(GRAPH_FAILED, "[Call][ComGraphMakeShared] OpIO make shared failed"); - return nullptr; - } - return output_ptr; -} - -GeTensorDesc OperatorImpl::GetOutputDesc(const std::string &name) const { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "op_desc_ is nullptr, check invalid."); - return GeTensorDesc(), "[Check][Param] op_desc_ is nullptr."); - - return op_desc_->GetOutputDesc(name); -} - -GeTensorDesc OperatorImpl::GetOutputDesc(const uint32_t index) const { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "op_desc_ is nullptr, check invalid."); - return GeTensorDesc(), "[Check][Param] op_desc_ is nullptr."); - - return op_desc_->GetOutputDesc(index); -} - -GeTensorDescPtr OperatorImpl::MutableOutputDesc(const std::string &name) { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "op_desc_ is nullptr, check invalid."); - return nullptr, "[Check][Param] op_desc_ is nullptr."); - return op_desc_->MutableOutputDesc(name); -} - -GeTensorDescPtr OperatorImpl::MutableOutputDesc(const uint32_t index) { - GE_CHK_BOOL_EXEC(op_desc_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "op_desc_ is nullptr, check invalid."); - return nullptr, "[Check][Param] op_desc_ is nullptr."); - return op_desc_->MutableOutputDesc(index); -} - -graphStatus OperatorImpl::UpdateOutputDesc(const std::string &name, const GeTensorDesc &tensor_desc) { - GE_CHK_BOOL_RET_STATUS(op_desc_ != nullptr, GRAPH_FAILED, "[Check][Param] op_desc is nullptr."); - const auto res = op_desc_->UpdateOutputDesc(name, tensor_desc); - if (res == GRAPH_SUCCESS) { - // normalize ge tensor desc - auto normalized_tensor_desc = tensor_desc; - TensorAdapter::NormalizeGeTensorDesc(normalized_tensor_desc); - for (const auto &ol : output_links_[name]) { - if (ol.GetOwner() == nullptr) { - GELOGW("[Update][Check] %s get owner is nullptr", ol.GetName().c_str()); - continue; - } - GE_CHK_BOOL_RET_STATUS(ol.GetOwner()->UpdateInputDesc(ol.GetName(), normalized_tensor_desc) == GRAPH_SUCCESS, - GRAPH_FAILED, "[Update][InputDesc] Could not update next operator's input %s.", - ol.GetName().c_str()); - } - } - return res; -} - -size_t OperatorImpl::GetInputsSize() const { - GE_IF_BOOL_EXEC(op_desc_ == nullptr, return 0UL); - return op_desc_->GetInputsSize(); -} - -size_t OperatorImpl::GetOutputsSize() const { - GE_IF_BOOL_EXEC(op_desc_ == nullptr, return 0U); - return op_desc_->GetOutputsSize(); -} - -graphStatus OperatorImpl::SetAttr(const std::string &name, AnyValue &&attr_value) { - GE_CHK_BOOL_RET_STATUS(op_desc_ != nullptr, GRAPH_FAILED, "[Check][Param] op_desc is nullptr."); - return op_desc_->SetAttr(name, attr_value); -} - -graphStatus OperatorImpl::SetAttr(const std::string &name, const AnyValue &attr_value) { - GE_ASSERT_NOTNULL(op_desc_, "[Check][Param] inner source is invalid."); - return op_desc_->SetAttr(name, attr_value); -} - -graphStatus OperatorImpl::GetAttr(const std::string &name, AnyValue &attr_value) const { - GE_CHK_BOOL_RET_STATUS(op_desc_ != nullptr, GRAPH_FAILED, "[Check][Param] op_desc is nullptr."); - return op_desc_->GetAttr(name, attr_value); -} - -OpDescPtr OperatorImpl::GetOpDescImpl() const { - return op_desc_; -} - -void OperatorImpl::UpdateLinkMapImpl(const std::string &src_name, const OpIO &op_dst) { - const auto it_find = output_links_.find(src_name); - if (it_find == output_links_.end()) { - std::vector dsts{op_dst}; - (void)output_links_.insert(std::make_pair(src_name, dsts)); - } else { - it_find->second.push_back(op_dst); - } -} - -Operator OperatorImpl::ToOperator() { - return Operator(shared_from_this()); -} - -OpDescPtr OperatorImpl::GetOpDesc(const Operator &oprt) { - GE_IF_BOOL_EXEC(oprt.operator_impl_ == nullptr, return nullptr); - return oprt.operator_impl_->op_desc_; -} - -void OperatorImpl::ClearOutputLinks() noexcept { - output_links_.clear(); -} - -void OperatorImpl::ClearInputLinks() noexcept { - input_link_.clear(); -} - -ge::ConstNodePtr OperatorImpl::GetNode() const { - return node_; -} - -graphStatus OperatorImpl::SetNode(const ConstNodePtr &node) { - GE_IF_BOOL_EXEC(node_ != nullptr, return GRAPH_FAILED); - node_ = node; - return GRAPH_SUCCESS; -} - -void OperatorImpl::SetInferenceContext(const InferenceContextPtr &inference_context) { - inference_context_ = inference_context; -} - -InferenceContextPtr OperatorImpl::GetInferenceContext() const { - return inference_context_; -} - -void OperatorImpl::SubgraphRegister(const std::string &ir_name, const bool dynamic) { - op_desc_->RegisterSubgraphIrName(ir_name, dynamic ? kDynamic : kStatic); -} - -void OperatorImpl::SubgraphCountRegister(const std::string &ir_name, const uint32_t count) { - if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kStatic) { - (void)op_desc_->AddSubgraphName(ir_name); - subgraph_names_to_builders_[ir_name] = nullptr; - } else { - for (uint32_t i = 0U; i < count; ++i) { - const std::string key_name = NodeUtils::GenDynamicSubgraphName(ir_name, i); - (void)op_desc_->AddSubgraphName(key_name); - subgraph_names_to_builders_[key_name] = nullptr; - } - } -} - -void OperatorImpl::SetSubgraphBuilder(const std::string &ir_name, const uint32_t index, - const SubgraphBuilder &builder) { - std::string key_name = ir_name; - if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kDynamic) { - key_name += std::to_string(index); - } - - const auto it = subgraph_names_to_builders_.find(key_name); - if (it == subgraph_names_to_builders_.end()) { - REPORT_INNER_ERR_MSG("E18888", "Failed to set subgraph builder for name %s index %u.", ir_name.c_str(), index); - GELOGE(PARAM_INVALID, "[Check][Param] Failed to set subgraph builder for name %s index %u.", ir_name.c_str(), - index); - return; - } - it->second = builder; -} - -SubgraphBuilder OperatorImpl::GetSubgraphBuilder(const std::string &ir_name, const uint32_t index) const { - std::string key_name = ir_name; - if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kDynamic) { - key_name += std::to_string(index); - } - - return GetSubgraphBuilder(key_name); -} - -SubgraphBuilder OperatorImpl::GetSubgraphBuilder(const std::string &name) const { - const auto iter = subgraph_names_to_builders_.find(name); - if (iter == subgraph_names_to_builders_.end()) { - REPORT_INNER_ERR_MSG("E18888", "Failed to get subgraph builder for name %s", name.c_str()); - GELOGE(PARAM_INVALID, "[Check][Param] Failed to get subgraph builder for name %s", name.c_str()); - return nullptr; - } - - return iter->second; -} - -std::vector OperatorImpl::GetSubgraphNames() const { - auto &ir_names = op_desc_->GetSubgraphIrNames(); - std::vector names(ir_names.size()); - (void)std::transform(ir_names.begin(), ir_names.end(), names.begin(), - [](const std::pair &name_to_type) { - return name_to_type.first; - }); - return names; -} - -size_t OperatorImpl::GetSubgraphNamesCount() const { - return op_desc_->GetSubgraphIrNames().size(); -} - -graphStatus OperatorImpl::UpdateInputDesc(const uint32_t index, const GeTensorDesc &tensor_desc) { - GE_CHECK_NOTNULL(op_desc_); - return op_desc_->UpdateInputDesc(index, tensor_desc); -} - -graphStatus OperatorImpl::UpdateOutputDesc(const uint32_t index, const GeTensorDesc &tensor_desc) { - GE_CHECK_NOTNULL(op_desc_); - return op_desc_->UpdateOutputDesc(index, tensor_desc); -} -} // namespace ge diff --git a/graph/normal_graph/operator_impl.h b/graph/normal_graph/operator_impl.h deleted file mode 100644 index ecedad3ef3f01e7abd833c9107f491cdd855fea0..0000000000000000000000000000000000000000 --- a/graph/normal_graph/operator_impl.h +++ /dev/null @@ -1,149 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_OPERATOR_IMPL_H -#define METADEF_CXX_OPERATOR_IMPL_H -#include -#include -#include "graph/op_desc.h" -#include "graph/node.h" -#include "graph/operator.h" -#include "graph/inference_context.h" -#include "graph/runtime_inference_context.h" -#include "graph/normal_graph/op_io.h" -namespace ge { -class OperatorImpl : public std::enable_shared_from_this { - public: - using GetConstInputOnRuntimeFun = - std::function; - explicit OperatorImpl(const std::string &name, const std::string &type); - explicit OperatorImpl(const OpDescPtr &op_desc); - explicit OperatorImpl(const ConstNodePtr node); - ~OperatorImpl(); - - void SetInputImpl(const std::string &dst_name, const Operator &src_oprt); - void SetInputImpl(const std::string &dst_name, const OutHandler &out_handler); - void AddControlInputImp(const Operator &src_oprt); - graphStatus GetInputImpl(const std::string &dst_name, ge::OpIO &out_handler) const; - graphStatus GetInputImpl(const uint32_t idx, ge::OpIO &out_handler) const; - graphStatus GetInputConstData(const std::string &dst_name, Tensor &data); - graphStatus GetInputConstData(const uint32_t idx, ConstGeTensorPtr &ge_tensor) const; - graphStatus GetInputConstDataOut(const std::string &dst_name, Tensor &data) const; - graphStatus GetInputConstDataOut(const uint32_t idx, ConstGeTensorPtr &ge_tensor) const; - bool InputIsSet(const std::string &name); - std::string GetName() const; - GeTensorDesc GetInputDesc(const std::string &name) const; - GeTensorDesc GetInputDesc(const uint32_t index) const; - GeTensorDescPtr MutableInputDesc(const std::string &name); - GeTensorDescPtr MutableInputDesc(const uint32_t index); - graphStatus UpdateInputDesc(const std::string &name, const GeTensorDesc &tensor_desc); - OutHandler GetOutput(const std::string &name); - OutHandler GetOutput(uint32_t index); - GeTensorDesc GetOutputDesc(const std::string &name) const; - GeTensorDesc GetOutputDesc(const uint32_t index) const; - GeTensorDescPtr MutableOutputDesc(const std::string &name); - GeTensorDescPtr MutableOutputDesc(const uint32_t index); - graphStatus UpdateOutputDesc(const std::string &name, const GeTensorDesc &tensor_desc); - size_t GetInputsSize() const; - size_t GetOutputsSize() const; - graphStatus SetAttr(const std::string &name, AnyValue &&attr_value); - graphStatus SetAttr(const std::string &name, const AnyValue &attr_value); - graphStatus GetAttr(const std::string &name, AnyValue &attr_value) const; - OpDescPtr GetOpDescImpl() const; - void UpdateLinkMapImpl(const std::string &src_name, const OpIO &op_dst); - Operator ToOperator(); - void ClearOutputLinks() noexcept; - void ClearInputLinks() noexcept; - ge::ConstNodePtr GetNode() const; - graphStatus SetNode(const ConstNodePtr &node) ; - void SetInferenceContext(const InferenceContextPtr &inference_context); - InferenceContextPtr GetInferenceContext() const; - void SubgraphRegister(const std::string &ir_name, const bool dynamic); - void SubgraphCountRegister(const std::string &ir_name, const uint32_t count); - void SetSubgraphBuilder(const std::string &ir_name, const uint32_t index, const SubgraphBuilder &builder); - SubgraphBuilder GetSubgraphBuilder(const std::string &ir_name, const uint32_t index) const; - SubgraphBuilder GetSubgraphBuilder(const std::string &name) const; - std::vector GetSubgraphNames() const; - size_t GetSubgraphNamesCount() const; - - static OpDescPtr GetOpDesc(const Operator &oprt); - graphStatus UpdateInputDesc(const uint32_t index, const GeTensorDesc &tensor_desc); - graphStatus UpdateOutputDesc(const uint32_t index, const GeTensorDesc &tensor_desc); - - private: - graphStatus GetFromPeerNode(NodePtr &peer_node, const OutDataAnchorPtr &out_data_anchor, - ConstGeTensorPtr &ge_tensor) const; - - private: - OpDescPtr op_desc_ = nullptr; - ge::ConstNodePtr node_{nullptr}; - ge::InferenceContextPtr inference_context_; - std::map> output_links_{}; - std::map input_link_{}; - std::vector> control_input_link_{}; - std::vector> control_output_link_{}; - std::map subgraph_names_to_builders_; - RuntimeInferenceContext *runtime_context_{nullptr}; // depracated, will delete when air support - GetConstInputOnRuntimeFun get_const_input_runtime_ = nullptr; - - private: - friend class GraphBuilderImpl; - friend class MultiThreadGraphBuilder; - friend class OpDescUtils; -}; -// Used to manage OperatorImpl instances created by ge api. -class OperatorKeeper { - public: - static OperatorKeeper &GetInstance(); - void CheckInOperator(const OperatorImplPtr &op_impl) { - if (op_impl) { - const std::lock_guard lock(mutex_); - (void)(operators_.insert(op_impl)); - } - } - void CheckOutOperator(const OperatorImplPtr &op_impl) { - if (op_impl) { - const std::lock_guard lock(mutex_); - (void)(operators_.erase(op_impl)); - } - } - - void ClearInvalidOp() { - const std::lock_guard lock(mutex_); - for (auto iter = operators_.begin(); iter != operators_.end();) { - auto op = iter->lock(); - if (op == nullptr) { - iter = operators_.erase(iter); - } else { - ++iter; - } - } - } - - private: - OperatorKeeper() = default; - ~OperatorKeeper() { - for (const auto &iter : operators_) { - if (!iter.expired()) { - iter.lock()->ClearInputLinks(); - } - if (!iter.expired()) { - iter.lock()->ClearOutputLinks(); - } - } - // Manually clean up for `Operator` destructor may access `operators_` - auto operators = std::move(operators_); - operators.clear(); - } - std::set, std::owner_less>> operators_; - std::mutex mutex_; -}; -} // namespace ge - -#endif // METADEF_CXX_OPERATOR_IMPL_H diff --git a/graph/normal_graph/tensor.cc b/graph/normal_graph/tensor.cc deleted file mode 100644 index 6d1ac779ba23b63a1cc3ca07f8b94422377d1a70..0000000000000000000000000000000000000000 --- a/graph/normal_graph/tensor.cc +++ /dev/null @@ -1,1103 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "external/graph/tensor.h" - -#include -#include "debug/ge_util.h" -#include "graph/ge_tensor.h" -#include "graph/debug/ge_attr_define.h" -#include "securec.h" -#include "graph/utils/attr_utils.h" -#include "graph/utils/tensor_adapter.h" -#include "graph/utils/tensor_utils.h" -#include "graph/utils/type_utils.h" -#include "common/checker.h" -#include "expand_dimension.h" -#include "transfer_shape_according_to_format.h" -#include "attribute_group/attr_group_base.h" - -namespace { -const int64_t UNKNOWN_DIM_SIZE = -1; -} // namespace - -namespace ge { -// If not overflow return true -static bool Int64MulNotOverflow(const int64_t a, const int64_t b) { - if (a > 0) { - if (b > 0) { - if (a > (INT64_MAX / b)) { - return false; - } - } else { - if (b < (INT64_MIN / a)) { - return false; - } - } - } else { - if (b > 0) { - if (a < (INT64_MIN / b)) { - return false; - } - } else { - if ((a != 0) && (b < (INT64_MAX / a))) { - return false; - } - } - } - return true; -} - -class TensorDescValue { - public: - TensorDescValue() = default; - ~TensorDescValue() = default; - TensorDescValue(const TensorDescValue &other) { - if ((other.const_data_len_ == 0U) || (other.const_data_buffer_ == nullptr)) { - return; - } - if (!TensorDescValue::CloneValue(this->const_data_buffer_, other.const_data_buffer_, other.const_data_len_)) { - return; - } - this->const_data_len_ = other.const_data_len_; - return; - } - TensorDescValue &operator=(const TensorDescValue &other) { - if ((&other == this) || (other.const_data_len_ == 0U) || (other.const_data_buffer_ == nullptr)) { - return *this; - } - if (!TensorDescValue::CloneValue(this->const_data_buffer_, other.const_data_buffer_, other.const_data_len_)) { - return *this; - } - this->const_data_len_ = other.const_data_len_; - return *this; - } - - private: - std::unique_ptr const_data_buffer_ = nullptr; - size_t const_data_len_ = 0U; - - static bool CloneValue(std::unique_ptr &dst, const std::unique_ptr &src, - const std::size_t len) { - dst = ComGraphMakeUnique(len); - if (dst == nullptr) { - return false; - } - size_t remain_size = len; - auto dst_addr = PtrToValue(static_cast(dst.get())); - auto src_addr = PtrToValue(static_cast(src.get())); - while (remain_size > SECUREC_MEM_MAX_LEN) { - if (memcpy_s(ValueToPtr(dst_addr), SECUREC_MEM_MAX_LEN, - ValueToPtr(src_addr), SECUREC_MEM_MAX_LEN) != EOK) { - return false; - } - remain_size -= SECUREC_MEM_MAX_LEN; - dst_addr += SECUREC_MEM_MAX_LEN; - src_addr += SECUREC_MEM_MAX_LEN; - } - if ((remain_size != 0U) && (memcpy_s(ValueToPtr(dst_addr), remain_size, - ValueToPtr(src_addr), remain_size) != EOK)) { - return false; - } - return true; - } - friend class TensorDesc; -}; - -class TensorDescImpl { - public: - TensorDescImpl() = default; - ~TensorDescImpl() = default; - TensorDescImpl(const Shape &shape, const Format format, const DataType dt) - : shape_(shape), format_(format), data_type_(dt) {} - private: - Shape shape_; - std::vector> range_; - Format format_ = FORMAT_ND; - Format origin_format_ = FORMAT_ND; - bool origin_format_is_set_ = false; - DataType data_type_ = DT_FLOAT; - Shape origin_shape_; - bool origin_shape_is_set_ = false; - int64_t size_ = 0; - int64_t real_dim_cnt_ = 0; - std::string name_; - Placement placement_ = kPlacementHost; - TensorDescValue tensor_desc_value_; - std::string expand_dims_rule_; - bool reuse_input_ = false; - uint32_t reuse_input_index_ = 0U; - - friend class TensorDesc; - friend class TensorAdapter; -}; - -class TensorImpl { - public: - TensorImpl() = default; - ~TensorImpl() = default; - - explicit TensorImpl(const TensorDesc &tensor_desc) : ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)) {} - TensorImpl(const TensorDesc &tensor_desc, const std::vector &data) - : ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc), data) {} - TensorImpl(const TensorDesc &tensor_desc, const uint8_t * const data, const size_t size) - : ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc), data, size) {} - TensorImpl(TensorDesc &&tensor_desc, std::vector &&data) - : ge_tensor(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc), std::move(data)) {} - - graphStatus SetData(const std::string &data) { - if (!data.empty()) { - /// Extra 16 bytes store string head - /// Extra 1 byte store '\0' - const size_t total_size = data.size() + sizeof(StringHead) + 1U; - const std::unique_ptr buff = ComGraphMakeUnique(total_size); - if (buff == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "allocate string raw data buff failed, size:%zu", total_size); - GELOGE(GRAPH_FAILED, "[New][Buffer] allocate string raw data buff failed"); - return GRAPH_FAILED; - } - StringHead *const string_head = PtrToPtr(buff.get()); - // Front 8 bytes store pointer of string - char_t *const raw_data = PtrToPtr( - ValueToPtr(PtrToValue(PtrToPtr(buff.get())) + sizeof(*string_head))); - string_head->addr = static_cast(sizeof(StringHead)); - string_head->len = static_cast(data.size()); - const int32_t memcpy_ret = memcpy_s(raw_data, total_size - sizeof(StringHead), data.c_str(), data.size() + 1U); - if (memcpy_ret != EOK) { - REPORT_INNER_ERR_MSG("E18888", "memcpy data failed, ret:%d, size:%zu.", memcpy_ret, data.size() + 1U); - GELOGE(GRAPH_FAILED, "[Copy][Data] failed, ret:%d", memcpy_ret); - return GRAPH_FAILED; - } - (void)ge_tensor.SetData(PtrToPtr(buff.get()), total_size); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; - } - - graphStatus SetData(const std::vector &data) { - if (data.empty()) { - REPORT_INNER_ERR_MSG("E18888", "there is no data, please check the input variable"); - GELOGE(GRAPH_FAILED, "[Check][Param] there is no data, please check the input variable"); - return GRAPH_FAILED; - } - size_t total_size = 0U; - total_size = std::accumulate(data.begin(), data.end(), total_size, [](size_t total, const std::string& str) { - /// Extra 16 bytes store string head - /// Extra 1 byte store '\0' - total += str.size() + sizeof(StringHead) + 1U; - return total; - }); - - const std::unique_ptr buff = ComGraphMakeUnique(total_size); - if (buff == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "allocate string raw data buff failed, size:%zu", total_size); - GELOGE(GRAPH_FAILED, "[New][Buffer] allocate string raw data buff failed"); - return GRAPH_FAILED; - } - // Front some bytes store head of each string - StringHead * const string_head = PtrToPtr(buff.get()); - uint64_t raw_data = PtrToValue(static_cast(buff.get())) + (data.size() * sizeof(*string_head)); - uint64_t ptr_size = data.size() * sizeof(StringHead); - for (size_t i = 0U; i < data.size(); ++i) { - PtrAdd(string_head, data.size(), i)->addr = static_cast(ptr_size); - PtrAdd(string_head, data.size(), i)->len = static_cast(data[i].size()); - if (total_size < ptr_size) { - REPORT_INNER_ERR_MSG("E18888", "Subtraction invalid, total_size:%zu, ptr_size:%" PRIu64, total_size, ptr_size); - GELOGE(GRAPH_FAILED, "[Check][Param] Subtraction invalid, total_size: %zu, ptr_size: %" PRIu64, - total_size, ptr_size); - return GRAPH_FAILED; - } - const int32_t memcpy_ret = memcpy_s(ValueToPtr(raw_data), total_size - ptr_size, - data[i].c_str(), data[i].size() + 1U); - GE_CHK_BOOL_RET_STATUS(memcpy_ret == EOK, GRAPH_FAILED, "copy data failed"); - raw_data += (data[i].size() + 1U); - ptr_size += (data[i].size() + 1U); - } - - (void)ge_tensor.SetData(PtrToPtr(buff.get()), total_size); - return GRAPH_SUCCESS; - } - - private: - GeTensor ge_tensor; - friend class Tensor; - friend class TensorAdapter; -}; - -class ShapeImpl { - public: - ShapeImpl() = default; - ~ShapeImpl() = default; - explicit ShapeImpl(const std::vector &dims) { - bool is_unknown_dim_num = false; - for (const auto &dim : dims) { - if (dim == UNKNOWN_DIM_NUM) { - is_unknown_dim_num = true; - break; - } - } - dims_ = is_unknown_dim_num ? std::vector({UNKNOWN_DIM_NUM}) : dims; - } - - private: - std::vector dims_; - friend class Shape; -}; - -Shape::Shape() { impl_ = ComGraphMakeShared(); } - -Shape::Shape(const std::vector &dims) { impl_ = ComGraphMakeShared(dims); } - -size_t Shape::GetDimNum() const { - if (impl_ != nullptr) { - const bool is_dim_unknown = std::any_of(std::begin(impl_->dims_), std::end(impl_->dims_), - [](const int64_t i) { return i == UNKNOWN_DIM_NUM; }); - if (is_dim_unknown) { - GELOGI("Dim num is unknown, return 0U."); - return 0U; - } - return impl_->dims_.size(); - } - return 0U; -} - -int64_t Shape::GetDim(size_t idx) const { - if (impl_ != nullptr) { - if (idx >= impl_->dims_.size()) { - return 0; - } - return impl_->dims_[idx]; - } - return 0; -} - -graphStatus Shape::SetDim(size_t idx, int64_t value) { - if (impl_ != nullptr) { - if (idx >= impl_->dims_.size()) { - return GRAPH_FAILED; - } - impl_->dims_[idx] = value; - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -std::vector Shape::GetDims() const { - const std::vector dims; - if (impl_ != nullptr) { - return impl_->dims_; - } - return dims; -} - -int64_t Shape::GetShapeSize() const { - if (impl_ != nullptr) { - if (impl_->dims_.empty()) { - return 0; - } - int64_t size = 1; - for (const auto i : impl_->dims_) { - if ((i == UNKNOWN_DIM_NUM) || (i == UNKNOWN_DIM)) { - return UNKNOWN_DIM_SIZE; - } - - if (!Int64MulNotOverflow(size, i)) { - REPORT_INNER_ERR_MSG("E18888", "mul overflow: %" PRId64 ", %" PRId64, size, i); - GELOGE(GRAPH_FAILED, "[Check][Overflow] mul overflow: %" PRId64 ", %" PRId64, size, i); - size = 0; - return size; - } - size *= i; - } - return size; - } - return 0; -} - -TensorDesc::TensorDesc() { - impl = ComGraphMakeSharedAndThrow(); -} - -TensorDesc::TensorDesc(Shape shape, Format format, DataType dt) { - impl = ComGraphMakeSharedAndThrow(shape, format, dt); - SetRealDimCnt(static_cast(shape.GetDimNum())); -} - -TensorDesc::TensorDesc(const TensorDesc &desc) { - // Copy - impl = ComGraphMakeShared(); - if ((desc.impl != nullptr) && (impl != nullptr)) { - *impl = *desc.impl; - } -} - -TensorDesc::TensorDesc(TensorDesc &&desc) { - // Move - impl = std::move(desc.impl); -} - -TensorDesc &TensorDesc::operator=(const TensorDesc &desc) { - // Copy - if (&desc != this) { - impl = ComGraphMakeShared(); - if ((desc.impl != nullptr) && (impl != nullptr)) { - *impl = *desc.impl; - } - } - return *this; -} - -TensorDesc &TensorDesc::operator=(TensorDesc &&desc) { - if (&desc != this) { - impl = std::move(desc.impl); - } - return *this; -} - -void TensorDesc::Update(const Shape &shape, Format format, DataType dt) { - if (impl != nullptr) { - impl->shape_ = shape; - impl->format_ = format; - impl->data_type_ = dt; - } -} - -Shape TensorDesc::GetShape() const { - if (impl != nullptr) { - return impl->shape_; - } - return Shape(); -} - -void TensorDesc::SetShape(const Shape &shape) { - if (impl != nullptr) { - impl->shape_ = shape; - } -} - -// set shape with -2, it stand for unknown shape -graphStatus TensorDesc::SetUnknownDimNumShape() { - if (impl != nullptr) { - impl->shape_ = Shape({UNKNOWN_DIM_NUM}); - return GRAPH_SUCCESS; - } - REPORT_INNER_ERR_MSG("E18888", "Set unknown shape failed, because no impl class!"); - GELOGE(GRAPH_FAILED, "[Set][UnknownDimNumShape] failed, because no impl class!"); - return GRAPH_FAILED; -} - -// for unknown shape -graphStatus TensorDesc::SetShapeRange(const std::vector> &range) { - if (impl != nullptr) { - impl->range_ = range; - return GRAPH_SUCCESS; - } - REPORT_INNER_ERR_MSG("E18888", "SetShapeRange failed! impl is nullptr!"); - GELOGE(GRAPH_FAILED, "[Set][ShapeRange] failed! impl is nullptr!"); - return GRAPH_FAILED; -} -graphStatus TensorDesc::GetShapeRange(std::vector> &range) const { - if (impl != nullptr) { - range = impl->range_; - return GRAPH_SUCCESS; - } - REPORT_INNER_ERR_MSG("E18888", "impl is nullptr! check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] impl is nullptr! check invalid"); - return GRAPH_FAILED; -} - -Shape TensorDesc::GetOriginShape() const { - if (impl != nullptr) { - return impl->origin_shape_; - } - return Shape(); -} - -void TensorDesc::SetOriginShape(const Shape &origin_shape) { - if (impl != nullptr) { - impl->origin_shape_ = origin_shape; - impl->origin_shape_is_set_ = true; - } -} - -Format TensorDesc::GetFormat() const { - if (impl != nullptr) { - return impl->format_; - } - return FORMAT_RESERVED; -} - -void TensorDesc::SetFormat(Format format) { - if (impl != nullptr) { - impl->format_ = format; - } -} - -Format TensorDesc::GetOriginFormat() const { - if (impl != nullptr) { - return impl->origin_format_; - } - return FORMAT_RESERVED; -} - -void TensorDesc::SetOriginFormat(Format origin_format) { - if (impl != nullptr) { - impl->origin_format_ = origin_format; - impl->origin_format_is_set_ = true; - } -} - -DataType TensorDesc::GetDataType() const { - if (impl != nullptr) { - return impl->data_type_; - } - return DT_UNDEFINED; -} - -void TensorDesc::SetDataType(DataType dt) { - if (impl != nullptr) { - impl->data_type_ = dt; - } -} - -void TensorDesc::SetSize(int64_t size) { - if (impl != nullptr) { - impl->size_ = size; - } -} - -int64_t TensorDesc::GetSize() const { - if (impl != nullptr) { - return impl->size_; - } - return 0; -} - -void TensorDesc::SetRealDimCnt(const int64_t real_dim_cnt) { - if (impl != nullptr) { - impl->real_dim_cnt_ = real_dim_cnt; - } -} - -int64_t TensorDesc::GetRealDimCnt() const { - if (impl != nullptr) { - return impl->real_dim_cnt_; - } - return 0; -} - -std::string TensorDesc::GetName() const { - if (impl != nullptr) { - return impl->name_; - } - return ""; -} - -void TensorDesc::SetName(const std::string &name) { - if (impl != nullptr) { - impl->name_ = name; - } -} - -graphStatus TensorDesc::GetName(AscendString &name) { - if (impl != nullptr) { - name = AscendString(impl->name_.c_str()); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -graphStatus TensorDesc::GetName(AscendString &name) const { - if (impl != nullptr) { - name = AscendString(impl->name_.c_str()); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -void TensorDesc::SetName(const char_t *name) { - if ((impl != nullptr) && (name != nullptr)) { - impl->name_ = name; - } -} - -void TensorDesc::SetPlacement(Placement placement) { - if (impl != nullptr) { - impl->placement_ = placement; - } -} - -Placement TensorDesc::GetPlacement() const { - if (impl != nullptr) { - return impl->placement_; - } - return kPlacementHost; -} - -void TensorDesc::SetConstData(std::unique_ptr const_data_buffer, const size_t &const_data_len) { - if (impl != nullptr) { - impl->tensor_desc_value_.const_data_buffer_ = std::move(const_data_buffer); - impl->tensor_desc_value_.const_data_len_ = const_data_len; - } - return; -} - -bool TensorDesc::GetConstData(uint8_t **const_data_buffer, size_t &const_data_len) const { - if (impl != nullptr) { - *const_data_buffer = impl->tensor_desc_value_.const_data_buffer_.get(); - const_data_len = impl->tensor_desc_value_.const_data_len_; - return true; - } - return false; -} - -void TensorDesc::SetExpandDimsRule(const AscendString &expand_dims_rule) { - if (impl != nullptr) { - impl->expand_dims_rule_ = expand_dims_rule.GetString(); - } -} - -graphStatus TensorDesc::GetExpandDimsRule(AscendString &expand_dims_rule) const { - if (impl != nullptr) { - expand_dims_rule = AscendString(impl->expand_dims_rule_.c_str()); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -void TensorDesc::SetReuseInputIndex(const uint32_t idx) { - if (impl != nullptr) { - impl->reuse_input_ = true; - impl->reuse_input_index_ = idx; - } -} - -Tensor::Tensor() { impl = ComGraphMakeSharedAndThrow(); } - -Tensor::Tensor(const TensorDesc &tensor_desc) { - impl = ComGraphMakeSharedAndThrow(tensor_desc); -} - -static void CheckTensorParam(const uint64_t shape_size, const DataType data_type, const size_t data_size) { - uint32_t type_length; - const bool ret = TypeUtils::GetDataTypeLength(data_type, type_length); - if (!ret) { - GELOGW("[Create][Tensor] Datatype %d not found.", data_type); - } - - if (ret && ((shape_size != 0U) || (data_size != type_length))) { - if ((type_length != 0U) && ((UINT64_MAX / type_length) < shape_size)) { - GELOGW("[Create][Tensor] Calculate size failed, as mul overflow: %" PRIu64 " * %" PRIu32, - shape_size, type_length); - } else { - if ((shape_size * type_length) != data_size) { - GELOGW("[Create][Tensor] Tensor length not equal: shape_byte_size=%" PRIu64 ", dt_type=%s, data_size=%zu.", - shape_size * type_length, TypeUtils::DataTypeToSerialString(data_type).c_str(), data_size); - } - } - } -} - -Tensor::Tensor(const TensorDesc &tensor_desc, const std::vector &data) { - CheckTensorParam(static_cast(tensor_desc.GetShape().GetShapeSize()), - tensor_desc.GetDataType(), data.size()); - impl = ComGraphMakeShared(tensor_desc, data); -} - -Tensor::Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size) { - CheckTensorParam(static_cast(tensor_desc.GetShape().GetShapeSize()), - tensor_desc.GetDataType(), size); - impl = ComGraphMakeShared(tensor_desc, data, size); -} - -Tensor::Tensor(TensorDesc &&tensor_desc, std::vector &&data) { - CheckTensorParam(static_cast(tensor_desc.GetShape().GetShapeSize()), - tensor_desc.GetDataType(), data.size()); - impl = ComGraphMakeShared(std::move(tensor_desc), std::move(data)); -} - -TensorDesc Tensor::GetTensorDesc() const { - if (impl != nullptr) { - return TensorAdapter::GeTensorDesc2TensorDesc(impl->ge_tensor.MutableTensorDesc()); - } - return TensorDesc(); -} - -graphStatus Tensor::SetTensorDesc(const TensorDesc &tensor_desc) { - if (impl != nullptr) { - impl->ge_tensor.SetTensorDesc(TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc)); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -const uint8_t *Tensor::GetData() const { - if (impl != nullptr) { - return impl->ge_tensor.GetData().data(); - } - return nullptr; -} - -uint8_t *Tensor::GetData() { - if (impl != nullptr) { - return impl->ge_tensor.MutableData().data(); - } - return nullptr; -} - -size_t Tensor::GetSize() const { - if (impl != nullptr) { - return impl->ge_tensor.GetData().size(); - } - return 0U; -} - -std::unique_ptr Tensor::ResetData() { - if (impl != nullptr) { - auto aligned_ptr = impl->ge_tensor.GetAlignedPtr(); - if (aligned_ptr != nullptr) { - return aligned_ptr->Reset(); - } - } - return nullptr; -} - -graphStatus Tensor::SetData(std::vector &&data) { - if (impl != nullptr) { - (void)impl->ge_tensor.SetData(data); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -graphStatus Tensor::SetData(const std::vector &data) { - if (impl != nullptr) { - (void)impl->ge_tensor.SetData(data); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -graphStatus Tensor::SetData(const uint8_t *data, size_t size) { - if (impl != nullptr) { - (void)impl->ge_tensor.SetData(data, size); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -graphStatus Tensor::SetData(const std::string &data) { - if (impl != nullptr) { - if (impl->SetData(data) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Set][Data] %s failed.", data.c_str()); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -graphStatus Tensor::SetData(const std::vector &data) { - if (impl != nullptr) { - if (impl->SetData(data) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Call][SetData] Tensor set vector data failed."); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -graphStatus Tensor::SetData(const char_t *data) { - if ((impl != nullptr) && (data != nullptr)) { - const std::string tensor_data = data; - if (impl->SetData(tensor_data) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Call][SetData] Tensor set data(%s) failed.", data); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -graphStatus Tensor::SetData(const std::vector &datas) { - if (impl != nullptr) { - std::vector tensor_data; - for (auto &data : datas) { - if (data.GetString() == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Data is nullptr. check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Data is nullptr."); - return GRAPH_FAILED; - } - tensor_data.emplace_back(data.GetString()); - } - if (impl->SetData(tensor_data) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Call][SetData] Tensor set vector data failed."); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -graphStatus Tensor::SetData(uint8_t *data, size_t size, const Tensor::DeleteFunc &deleter_func) { - if (impl != nullptr) { - if (impl->ge_tensor.SetData(data, size, deleter_func) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Call][SetData] Tensor set data with deleter function failed"); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -graphStatus Tensor::IsValid() { - const uint64_t shape_size = static_cast(GetTensorDesc().GetShape().GetShapeSize()); - const DataType data_type = GetTensorDesc().GetDataType(); - uint32_t type_length; - const bool ret = TypeUtils::GetDataTypeLength(data_type, type_length); - if (!ret) { - GELOGW("[Check][Tensor] Datatype %d not found.", data_type); - return GRAPH_SUCCESS; - } - - const size_t data_size = GetSize(); - if (data_type == DT_STRING) { - return GRAPH_SUCCESS; - } - - if ((shape_size != 0U) || (data_size != type_length)) { - if ((type_length != 0U) && ((UINT64_MAX / type_length) < shape_size)) { - GELOGW("[Check][Tensor] Calculate size failed, as mul overflow: %" PRIu64 " * %" PRIu32, shape_size, type_length); - } else { - if ((shape_size * type_length) != data_size) { - GELOGW("[Check][Tensor] Tensor length not equal: shape_byte_size=%" PRIu64 ", dt_type=%s, data_size=%zu.", - shape_size * type_length, TypeUtils::DataTypeToSerialString(data_type).c_str(), data_size); - return GRAPH_FAILED; - } - } - } - - return GRAPH_SUCCESS; -} - -graphStatus Tensor::SetOriginShapeDimNum(const size_t dim_num) { - if (impl != nullptr) { - impl->ge_tensor.MutableTensorDesc().MutableOriginShape().SetDimNum(dim_num); - return ge::GRAPH_SUCCESS; - } - return ge::GRAPH_FAILED; -} - -size_t Tensor::GetOriginShapeDimNum() const { - if (impl != nullptr) { - return impl->ge_tensor.GetTensorDesc().GetOriginShape().GetDimNum(); - } - return 0U; -} - -graphStatus Tensor::SetOriginShapeDim(const size_t idx, const int64_t dim_value) { - if (impl != nullptr) { - return impl->ge_tensor.MutableTensorDesc().MutableOriginShape().SetDim(idx, dim_value); - } - return ge::GRAPH_FAILED; -} - -int64_t Tensor::GetOriginShapeDim(const size_t idx) const { - if (impl != nullptr) { - return impl->ge_tensor.GetTensorDesc().GetOriginShape().GetDim(idx); - } - return 0; -} - -graphStatus Tensor::SetOriginFormat(const ge::Format &format) { - if (impl != nullptr) { - impl->ge_tensor.MutableTensorDesc().SetOriginFormat(format); - return ge::GRAPH_SUCCESS; - } - return ge::GRAPH_FAILED; -} - -ge::Format Tensor::GetOriginFormat() const { - if (impl != nullptr) { - return impl->ge_tensor.GetTensorDesc().GetOriginFormat(); - } - return ge::FORMAT_RESERVED; -} - -graphStatus Tensor::SetShapeDimNum(const size_t dim_num) { - if (impl != nullptr) { - impl->ge_tensor.MutableTensorDesc().MutableShape().SetDimNum(dim_num); - return ge::GRAPH_SUCCESS; - } - return ge::GRAPH_FAILED; -} - -size_t Tensor::GetShapeDimNum() const { - if (impl != nullptr) { - return impl->ge_tensor.GetTensorDesc().GetShape().GetDimNum(); - } - return 0U; -} - -graphStatus Tensor::SetShapeDim(const size_t idx, const int64_t dim_value) { - if (impl != nullptr) { - return impl->ge_tensor.MutableTensorDesc().MutableShape().SetDim(idx, dim_value); - } - return ge::GRAPH_FAILED; -} - -int64_t Tensor::GetShapeDim(const size_t idx) const { - if (impl != nullptr) { - return impl->ge_tensor.GetTensorDesc().GetShape().GetDim(idx); - } - return 0; -} - -graphStatus Tensor::SetFormat(const ge::Format &format) { - if (impl != nullptr) { - impl->ge_tensor.MutableTensorDesc().SetFormat(format); - return ge::GRAPH_SUCCESS; - } - return ge::GRAPH_FAILED; -} - -ge::Format Tensor::GetFormat() const { - if (impl != nullptr) { - return impl->ge_tensor.GetTensorDesc().GetFormat(); - } - return ge::FORMAT_RESERVED; -} - -graphStatus Tensor::SetDataType(const ge::DataType &dtype) { - if (impl != nullptr) { - impl->ge_tensor.MutableTensorDesc().SetDataType(dtype); - return ge::GRAPH_SUCCESS; - } - return ge::GRAPH_FAILED; -} - -ge::DataType Tensor::GetDataType() const { - if (impl != nullptr) { - return impl->ge_tensor.GetTensorDesc().GetDataType(); - } - return ge::DT_UNDEFINED; -} - -graphStatus Tensor::SetPlacement(const ge::Placement &placement) { - if (impl != nullptr) { - impl->ge_tensor.MutableTensorDesc().SetPlacement(placement); - return ge::GRAPH_SUCCESS; - } - return ge::GRAPH_FAILED; -} - -ge::Placement Tensor::GetPlacement() const { - if (impl != nullptr) { - return impl->ge_tensor.GetTensorDesc().GetPlacement(); - } - return ge::Placement::kPlacementEnd; -} - -graphStatus Tensor::SetExpandDimsRule(const AscendString &expand_dims_rule) { - if (impl != nullptr) { - impl->ge_tensor.MutableTensorDesc().SetExpandDimsRule(expand_dims_rule.GetString()); - return ge::GRAPH_SUCCESS; - } - return ge::GRAPH_FAILED; -} - -graphStatus Tensor::GetExpandDimsRule(AscendString &expand_dims_rule) const { - if (impl != nullptr) { - expand_dims_rule = AscendString(impl->ge_tensor.GetTensorDesc().GetExpandDimsRule().c_str()); - return ge::GRAPH_SUCCESS; - } - return ge::GRAPH_FAILED; -} - -graphStatus Tensor::ResetData(uint8_t *data, size_t size, const Tensor::DeleteFunc &deleter_func) { - if (impl != nullptr) { - if (impl->ge_tensor.ResetData(data, size, deleter_func) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Call][SetData] Tensor set data with deleter function failed"); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -Tensor Tensor::Clone() const { - const Tensor tensor; - if ((impl != nullptr) && (tensor.impl != nullptr)) { - tensor.impl->ge_tensor = impl->ge_tensor.Clone(); - } - return tensor; -} - -GeTensorDesc TensorAdapter::TensorDesc2GeTensorDesc(const TensorDesc &tensor_desc) { - GeTensorDesc ge_tensor_desc(GeShape(tensor_desc.GetShape().GetDims()), tensor_desc.GetFormat(), - tensor_desc.GetDataType()); - if (tensor_desc.impl->origin_format_is_set_) { - (void)AttrUtils::SetBool(ge_tensor_desc, ATTR_NAME_ORIGIN_FORMAT_IS_SET, true); - } - if (tensor_desc.impl->origin_shape_is_set_) { - ge_tensor_desc.SetOriginShape(GeShape(tensor_desc.GetOriginShape().GetDims())); - } - ge_tensor_desc.SetOriginFormat(tensor_desc.GetOriginFormat()); - ge_tensor_desc.SetExpandDimsRule(tensor_desc.impl->expand_dims_rule_); - TensorUtils::SetReuseInput(ge_tensor_desc, tensor_desc.impl->reuse_input_); - TensorUtils::SetReuseInputIndex(ge_tensor_desc, tensor_desc.impl->reuse_input_index_); - - AscendString name(""); - (void) tensor_desc.GetName(name); - ge_tensor_desc.SetName(name.GetString()); - ge_tensor_desc.SetPlacement(tensor_desc.GetPlacement()); - std::vector> shape_range; - auto status = tensor_desc.GetShapeRange(shape_range); - if (status != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Get shape range failed! ret:%u", status); - GELOGE(GRAPH_FAILED, "[Get][ShapeRange] failed! ret:%u", status); - return ge_tensor_desc; - } - status = ge_tensor_desc.SetShapeRange(shape_range); - if (status != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Set shape range failed! ret:%u", status); - GELOGE(GRAPH_FAILED, "[Set][ShapeRange] failed! ret:%u", status); - return ge_tensor_desc; - } - const auto size = tensor_desc.GetSize(); - TensorUtils::SetSize(ge_tensor_desc, size); - - const auto real_dim_cnt = static_cast(tensor_desc.GetRealDimCnt()); - TensorUtils::SetRealDimCnt(ge_tensor_desc, real_dim_cnt); - return ge_tensor_desc; -} - -TensorDesc TensorAdapter::GeTensorDesc2TensorDesc(const GeTensorDesc &ge_tensor_desc) { - TensorDesc tensor_desc(Shape(ge_tensor_desc.GetShape().GetDims()), ge_tensor_desc.GetFormat(), - ge_tensor_desc.GetDataType()); - if (TensorUtils::IsOriginShapeInited(ge_tensor_desc)) { - tensor_desc.SetOriginShape(Shape(ge_tensor_desc.GetOriginShape().GetDims())); - } - tensor_desc.SetOriginFormat(ge_tensor_desc.GetOriginFormat()); - tensor_desc.SetName(ge_tensor_desc.GetName().c_str()); - tensor_desc.SetPlacement(ge_tensor_desc.GetPlacement()); - std::vector> shape_range; - auto status = ge_tensor_desc.GetShapeRange(shape_range); - if (status != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Get shape range failed! ret:%u", status); - GELOGE(GRAPH_FAILED, "[Get][ShapeRange] failed! ret:%u", status); - return tensor_desc; - } - status = tensor_desc.SetShapeRange(shape_range); - if (status != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Set shape range failed! ret:%u", status); - GELOGE(GRAPH_FAILED, "[Set][ShapeRange] failed! ret:%u", status); - return tensor_desc; - } - int64_t size = 0; - (void)TensorUtils::GetSize(ge_tensor_desc, size); - tensor_desc.SetSize(size); - - uint32_t real_dim_cnt = 0U; - (void)TensorUtils::GetRealDimCnt(ge_tensor_desc, real_dim_cnt); - tensor_desc.SetRealDimCnt(static_cast(real_dim_cnt)); - - tensor_desc.SetExpandDimsRule(AscendString(ge_tensor_desc.GetExpandDimsRule().c_str())); - return tensor_desc; -} - -Tensor TensorAdapter::GeTensor2Tensor(const ConstGeTensorPtr &ge_tensor) { - const Tensor tensor; - if ((ge_tensor != nullptr) && (tensor.impl != nullptr)) { - tensor.impl->ge_tensor = ge_tensor->Clone(); - } - return tensor; -} - -ConstGeTensorPtr TensorAdapter::AsGeTensorPtr(const Tensor &tensor) { - GeTensorPtr ge_tensor; - if (tensor.impl != nullptr) { - ge_tensor = ComGraphMakeShared(tensor.impl->ge_tensor); - } - return ge_tensor; -} - -GeTensorPtr TensorAdapter::AsGeTensorPtr(Tensor &tensor) { - GeTensorPtr ge_tensor; - if (tensor.impl != nullptr) { - ge_tensor = ComGraphMakeShared(tensor.impl->ge_tensor); - } - return ge_tensor; -} - -const GeTensor TensorAdapter::AsGeTensor(const Tensor &tensor) { - if (tensor.impl != nullptr) { - return tensor.impl->ge_tensor; - } - return GeTensor(); -} - -const GeTensor* TensorAdapter::AsBareGeTensorPtr(const Tensor &tensor) { - if (tensor.impl != nullptr) { - return &(tensor.impl->ge_tensor); - } - return nullptr; -} - -GeTensor TensorAdapter::AsGeTensorShared(const Tensor &tensor) { - if (tensor.impl != nullptr) { - // Construct new rvalue ge tensor to avoid call copy constructor - return GeTensor(tensor.impl->ge_tensor.impl_); - } - return {}; -} - -GeTensor TensorAdapter::NormalizeGeTensor(const GeTensor &tensor) { - auto normalized_tensor = tensor; - auto &desc = normalized_tensor.MutableTensorDesc(); - NormalizeGeTensorDesc(desc); - return normalized_tensor; -} - -void TensorAdapter::NormalizeGeTensorDesc(GeTensorDesc &desc) { - bool origin_format_is_set = false; - if (AttrUtils::GetBool(desc, ATTR_NAME_ORIGIN_FORMAT_IS_SET, origin_format_is_set) && origin_format_is_set && - TensorUtils::IsOriginShapeInited(desc)) { - (void) AttrUtils::SetInt(desc, ATTR_NAME_STORAGE_FORMAT, static_cast(desc.GetFormat())); - (void) AttrUtils::SetListInt(desc, ATTR_NAME_STORAGE_SHAPE, desc.GetShape().GetDims()); - desc.SetFormat(desc.GetOriginFormat()); - desc.SetShape(desc.GetOriginShape()); - (void) AttrUtils::SetBool(desc, ATTR_NAME_ORIGIN_FORMAT_IS_SET, false); - } -} - -GeTensor TensorAdapter::AsGeTensor(Tensor &tensor) { - if (tensor.impl != nullptr) { - return tensor.impl->ge_tensor; - } - return GeTensor(); -} - -const Tensor TensorAdapter::AsTensor(const GeTensor &ge_tensor) { - const Tensor tensor; - if (tensor.impl != nullptr) { - tensor.impl->ge_tensor = ge_tensor; - } - return tensor; -} - -Tensor TensorAdapter::AsTensor(GeTensor &ge_tensor) { - const Tensor tensor; - if (tensor.impl != nullptr) { - tensor.impl->ge_tensor = ge_tensor; - } - return tensor; -} -} // namespace ge diff --git a/graph/opsproto/opsproto_manager.cc b/graph/opsproto/opsproto_manager.cc deleted file mode 100644 index bca6aa503583d222cb95108a1b29fadeb7f97ed8..0000000000000000000000000000000000000000 --- a/graph/opsproto/opsproto_manager.cc +++ /dev/null @@ -1,167 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/opsproto_manager.h" -#include -#include -#include "common/ge_common/debug/ge_log.h" -#include "graph/debug/ge_log.h" -#include "graph/types.h" -#include "graph/def_types.h" -#include "graph/operator_factory_impl.h" -#include "mmpa/mmpa_api.h" -#include "common/plugin/plugin_manager.h" -#include "graph/operator_factory_impl.h" - -namespace ge { -OpsProtoManager *OpsProtoManager::Instance() { - static OpsProtoManager instance; - return &instance; -} - -bool OpsProtoManager::Initialize(const std::map &options) { - const std::lock_guard lock(mutex_); - - if (is_init_) { - GELOGI("OpsProtoManager is already initialized."); - return true; - } - - const std::map::const_iterator iter = options.find("ge.opsProtoLibPath"); - if (iter == options.end()) { - GELOGW("[Initialize][CheckOption] Option \"ge.opsProtoLibPath\" not set"); - return false; - } - - pluginPath_ = iter->second; - LoadOpsProtoPluginSo(pluginPath_); - - is_init_ = true; - - return true; -} - -void OpsProtoManager::Finalize() { - const std::lock_guard lock(mutex_); - - if (!is_init_) { - GELOGI("OpsProtoManager is not initialized."); - return; - } - - for (const auto handle : handles_) { - if (handle != nullptr) { - if (mmDlclose(handle) != 0) { - const char_t *error = mmDlerror(); - error = (error == nullptr) ? "" : error; - GELOGW("[Finalize][CloseHandle] close handle unsuccessfully, reason:%s", error); - continue; - } - GELOGI("close opsprotomanager handler success"); - } else { - GELOGW("[Finalize][CheckHandle] handler is null"); - } - } - - is_init_ = false; -} - -OpsProtoManager::~OpsProtoManager() { - OperatorFactoryImpl::ReleaseRegInfo(); - Finalize(); -} - -static std::vector SplitStr(const std::string &str, const char_t delim) { - std::vector elems; - if (str.empty()) { - elems.emplace_back(""); - return elems; - } - - std::stringstream str_stream(str); - std::string item; - - while (getline(str_stream, item, delim)) { - elems.push_back(item); - } - - const auto str_size = str.size(); - if ((str_size > 0UL) && (str[str_size - 1UL] == delim)) { - elems.emplace_back(""); - } - - return elems; -} - -void GetOpsProtoSoFileList(const std::string &path, std::vector &file_list) { - // Support multi lib directory with ":" as delimiter - const std::vector v_path = SplitStr(path, ':'); - - std::string os_type; - std::string cpu_type; - PluginManager::GetCurEnvPackageOsAndCpuType(os_type, cpu_type); - - for (size_t i = 0UL; i < v_path.size(); ++i) { - const std::string new_path = v_path[i] + "lib/" + os_type + "/" + cpu_type + "/"; - char_t resolved_path[MMPA_MAX_PATH] = {}; - const INT32 result = mmRealPath(new_path.c_str(), &(resolved_path[0U]), MMPA_MAX_PATH); - if (result == EN_OK) { - std::vector file_list_unfiltered; - PluginManager::GetFileListWithSuffix(new_path, ".so", file_list_unfiltered); - std::for_each(file_list_unfiltered.begin(), file_list_unfiltered.end(), [&file_list](const std::string &file) { - if (!PluginManager::IsEndWith(file, "rt2.0.so") && !PluginManager::IsEndWith(file, "rt.so")) { - file_list.emplace_back(file); - } - }); - } else { - GELOGW("[FindSo][Check] Get path with os&cpu type [%s] unsuccessfully, reason:%s", new_path.c_str(), strerror(errno)); - PluginManager::GetFileListWithSuffix(v_path[i], ".so", file_list); - } - } -} - -void OpsProtoManager::LoadOpsProtoPluginSo(const std::string &path) { - if (path.empty()) { - REPORT_INNER_ERR_MSG("E18888", "filePath is empty. please check your text file."); - GELOGE(GRAPH_FAILED, "[Check][Param] filePath is empty. please check your text file."); - return; - } - std::vector file_list; - - // If there is .so file in the lib path - GetOpsProtoSoFileList(path, file_list); - - // Not found any .so file in the lib path - if (file_list.empty()) { - GELOGW("[LoadSo][Check] OpsProtoManager can not find any plugin file in pluginPath: %s \n", path.c_str()); - return; - } - // Warning message - GELOGW("[LoadSo][Check] Shared library will not be checked. Please make sure that the source of shared library is " - "trusted."); - - // Load .so file - for (const auto &elem : file_list) { - OperatorFactoryImpl::SetRegisterOverridable(true); - void *const handle = mmDlopen(elem.c_str(), static_cast(static_cast(MMPA_RTLD_NOW) | - static_cast(MMPA_RTLD_GLOBAL))); - OperatorFactoryImpl::SetRegisterOverridable(false); - if (handle == nullptr) { - const char_t *error = mmDlerror(); - error = (error == nullptr) ? "" : error; - GELOGW("[LoadSo][Open] OpsProtoManager dlopen unsuccessfully, plugin name:%s. Message(%s).", elem.c_str(), error); - continue; - } else { - // Close dl when the program exist, not close here - GELOGI("OpsProtoManager plugin load %s successfully.", elem.c_str()); - handles_.push_back(handle); - } - } -} -} // namespace ge diff --git a/graph/option/ge_context.cc b/graph/option/ge_context.cc deleted file mode 100644 index 965cf71db535bc206a80f7a5261d5fa6085addf9..0000000000000000000000000000000000000000 --- a/graph/option/ge_context.cc +++ /dev/null @@ -1,185 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/ge_context.h" -#include -#include "graph/ge_global_options.h" -#include "graph/ge_local_context.h" -#include "graph/types.h" -#include "common/ge_common/debug/ge_log.h" -#include "utils/extern_math_util.h" -#include "external/ge_common/ge_api_types.h" - -namespace ge { -namespace { -const int32_t kDecimal = 10; -const char_t *kHostExecPlacement = "HOST"; -const char_t *kEnabled = "1"; - -template -ge::Status GetOptionValue(const std::string &option_name, T &var) { - std::string option; - if (ge::GetContext().GetOption(option_name, option) != GRAPH_SUCCESS) { - return ge::FAILED; - } - - int64_t value = 0; - try { - value = static_cast(std::stoi(option.c_str())); - } catch (std::invalid_argument &) { - GELOGW("[Init] Transform option %s %s to int failed, as catching invalid_argument exception", option_name.c_str(), - option.c_str()); - return ge::FAILED; - } catch (std::out_of_range &) { - GELOGW("[Init] Transform option %s %s to int failed, as catching out_of_range exception", option_name.c_str(), - option.c_str()); - return ge::FAILED; - } - if (!IntegerChecker::Compat(value)) { - GELOGW("[Init] Transform option %s %s to int failed, value is invalid_argument", option_name.c_str(), - option.c_str()); - return ge::FAILED; - } - var = value; - return ge::SUCCESS; -} -} // namespace - -GEContext &GetContext() { - static GEContext ge_context {}; - return ge_context; -} - -thread_local uint64_t GEContext::session_id_ = 0UL; -thread_local uint64_t GEContext::context_id_ = 0UL; - -graphStatus GEContext::GetOption(const std::string &key, std::string &option) { - return GetThreadLocalContext().GetOption(key, option); -} - -const std::string &GEContext::GetReadableName(const std::string &key) { - return GetThreadLocalContext().GetReadableName(key); -} - -bool GEContext::IsOverflowDetectionOpen() const { - std::string enable_overflow_detection; - if (GetThreadLocalContext().GetOption("ge.exec.overflow", enable_overflow_detection) != GRAPH_SUCCESS) { - return false; - } - GELOGD("Option ge.exec.overflow is %s.", enable_overflow_detection.c_str()); - return (enable_overflow_detection == kEnabled); -} - -bool GEContext::IsGraphLevelSat() const { - std::string graph_level_sat; - if (GetThreadLocalContext().GetOption("ge.graphLevelSat", graph_level_sat) != GRAPH_SUCCESS) { - return false; - } - GELOGD("Option ge.graphLevelSat is %s.", graph_level_sat.c_str()); - return (graph_level_sat == kEnabled); -} - -bool GEContext::GetHostExecFlag() const { - std::string exec_placement; - if (GetThreadLocalContext().GetOption("ge.exec.placement", exec_placement) != GRAPH_SUCCESS) { - return false; - } - GELOGD("Option ge.exec.placement is %s.", exec_placement.c_str()); - return exec_placement == kHostExecPlacement; -} - -bool GEContext::GetTrainGraphFlag() const { - std::string run_mode; - if ((GetThreadLocalContext().GetOption(ge::OPTION_GRAPH_RUN_MODE, run_mode) == ge::GRAPH_SUCCESS) && - (!run_mode.empty())) { - if (static_cast(std::strtol(run_mode.c_str(), nullptr, kDecimal)) >= ge::TRAIN) { - return true; - } - } - return false; -} - -uint64_t GEContext::GetInputFusionSize() const { - const uint64_t default_fusion_size = 128 * 1024U; // 128KB - const uint64_t max_fusion_size = 32 * 1024 * 1024U; // 32MB - - std::string fusion_size; - if (GetThreadLocalContext().GetOption(OPTION_EXEC_INPUT_FUSION_SIZE, fusion_size) != GRAPH_SUCCESS) { - return default_fusion_size; - } - - long value = std::strtol(fusion_size.c_str(), nullptr, kDecimal); - if (value < 0) { - GELOGI("%s is %s which is less than 0, return 0", OPTION_EXEC_INPUT_FUSION_SIZE, fusion_size.c_str()); - return 0U; - } - - uint64_t result = static_cast(value); - if (result > max_fusion_size) { - GELOGW("option [%s] is %s which is bigger than max(%" PRIu64 "), return max", OPTION_EXEC_INPUT_FUSION_SIZE, - fusion_size.c_str(), max_fusion_size); - return max_fusion_size; - } - return result; -} - -std::mutex &GetGlobalOptionsMutex() { - static std::mutex global_options_mutex; - return global_options_mutex; -} - -std::map &GetMutableGlobalOptions() { - static std::map context_global_options{}; - return context_global_options; -} - -void GEContext::Init() { - (void) GetOptionValue("ge.exec.sessionId", session_id_); - (void) GetOptionValue("ge.exec.deviceId", device_id_); - - int32_t stream_sync_timeout = -1; - (void) GetOptionValue("stream_sync_timeout", stream_sync_timeout); - SetStreamSyncTimeout(stream_sync_timeout); - - int32_t event_sync_timeout = -1; - (void) GetOptionValue("event_sync_timeout", event_sync_timeout); - SetEventSyncTimeout(event_sync_timeout); -} - -uint64_t GEContext::SessionId() const { return session_id_; } - -uint32_t GEContext::DeviceId() const { - uint32_t device_id = 0U; - // session device id has priority - auto status = GetOptionValue("ge.session_device_id", device_id); - return (status == ge::SUCCESS) ? device_id : device_id_; -} - -int32_t GEContext::StreamSyncTimeout() const { return GetThreadLocalContext().StreamSyncTimeout(); } - -int32_t GEContext::EventSyncTimeout() const { return GetThreadLocalContext().EventSyncTimeout(); } - -void GEContext::SetSessionId(const uint64_t session_id) { session_id_ = session_id; } - -void GEContext::SetContextId(const uint64_t context_id) { context_id_ = context_id; } - -void GEContext::SetCtxDeviceId(const uint32_t device_id) { device_id_ = device_id; } - -void GEContext::SetStreamSyncTimeout(const int32_t timeout) { GetThreadLocalContext().SetStreamSyncTimeout(timeout); } - -void GEContext::SetEventSyncTimeout(const int32_t timeout) { GetThreadLocalContext().SetEventSyncTimeout(timeout); } - -graphStatus GEContext::SetOptionNameMap(const std::string &option_name_map_json) { - return GetThreadLocalContext().SetOptionNameMap(option_name_map_json); -} - -OptimizationOption &GEContext::GetOo() const { - return GetThreadLocalContext().GetOo(); -} -} // namespace ge diff --git a/graph/option/ge_local_context.cc b/graph/option/ge_local_context.cc deleted file mode 100644 index bb5f16e146f06712a5550da508bc01a6328c49f5..0000000000000000000000000000000000000000 --- a/graph/option/ge_local_context.cc +++ /dev/null @@ -1,157 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/ge_local_context.h" -#include "nlohmann/json.hpp" -#include "common/ge_common/debug/ge_log.h" -#include - -namespace ge { -using Json = nlohmann::json; - -namespace { -int32_t GetTimeoutValue(const std::string &timeout_option) { - std::string timeout_str = "-1"; - (void) GetThreadLocalContext().GetOption(timeout_option, timeout_str); - int timeout_value; - try { - timeout_value = std::stoi(timeout_str); - } catch (...) { - timeout_value = -1; - GELOGW("option %s's value %s is invalid", timeout_option.c_str(), timeout_str.c_str()); - } - return timeout_value; -} -} // namespace - -GEThreadLocalContext &GetThreadLocalContext() { - static thread_local GEThreadLocalContext thread_context; - return thread_context; -} - -graphStatus GEThreadLocalContext::GetOption(const std::string &key, std::string &option) { - if (optimization_option_.GetValue(key, option) == GRAPH_SUCCESS) { - return GRAPH_SUCCESS; - } - const std::map::const_iterator graph_iter = graph_options_.find(key); - if (graph_iter != graph_options_.end()) { - option = graph_iter->second; - return GRAPH_SUCCESS; - } - const std::map::const_iterator session_iter = session_options_.find(key); - if (session_iter != session_options_.end()) { - option = session_iter->second; - return GRAPH_SUCCESS; - } - const std::map::const_iterator global_iter = global_options_.find(key); - if (global_iter != global_options_.end()) { - option = global_iter->second; - return GRAPH_SUCCESS; - } - return GRAPH_PARAM_INVALID; -} - -void GEThreadLocalContext::SetGlobalOption(std::map options_map) { - global_options_.clear(); - global_options_ = std::move(options_map); - - SetStreamSyncTimeout(GetTimeoutValue("stream_sync_timeout")); - SetEventSyncTimeout(GetTimeoutValue("event_sync_timeout")); - - std::string option_name_map = ""; - if (option_name_map_.empty() && - (GetThreadLocalContext().GetOption(ge::OPTION_NAME_MAP, option_name_map) == GRAPH_SUCCESS)) { - (void) SetOptionNameMap(option_name_map); - } -} - -void GEThreadLocalContext::SetSessionOption(std::map options_map) { - session_options_.clear(); - session_options_ = std::move(options_map); -} - -void GEThreadLocalContext::SetGraphOption(std::map options_map) { - graph_options_.clear(); - graph_options_ = std::move(options_map); -} - -graphStatus GEThreadLocalContext::SetOptionNameMap(const std::string &option_name_map_json) { - if (!option_name_map_.empty()) { - GELOGD("option name map has set, don't need reset"); - return ge::GRAPH_SUCCESS; - } - Json option_json; - try { - option_json = Json::parse(option_name_map_json); - } catch (nlohmann::json::parse_error &) { - GELOGE(ge::GRAPH_FAILED, "Parse JsonStr to Json failed, JsonStr: %s", option_name_map_json.c_str()); - return ge::GRAPH_FAILED; - } - for (auto iter : option_json.items()) { - if (iter.key().empty()) { - GELOGE(ge::GRAPH_FAILED, "Check option_name_map failed, key is null"); - return ge::GRAPH_FAILED; - } - if (static_cast(iter.value()).empty()) { - GELOGE(ge::GRAPH_FAILED, "Check option_name_map failed, value is null"); - return ge::GRAPH_FAILED; - } - option_name_map_.insert({iter.key(), static_cast(iter.value())}); - } - return ge::GRAPH_SUCCESS; -} - -const std::string &GEThreadLocalContext::GetReadableName(const std::string &key) { - auto iter = option_name_map_.find(key); - if (iter != option_name_map_.end()) { - GELOGD("Option %s's readable name is show name: %s", key.c_str(), iter->second.c_str()); - return iter->second; - } - GELOGD("Option %s's readable name is GE IR option: %s", key.c_str(), key.c_str()); - return key; -} - -std::map GEThreadLocalContext::GetAllGraphOptions() const { - return graph_options_; -} -std::map GEThreadLocalContext::GetAllSessionOptions() const { - return session_options_; -} -std::map GEThreadLocalContext::GetAllGlobalOptions() const { - return global_options_; -} - -std::map GEThreadLocalContext::GetAllOptions() const { - std::map options_all; - options_all.insert(graph_options_.cbegin(), graph_options_.cend()); - options_all.insert(session_options_.cbegin(), session_options_.cend()); - options_all.insert(global_options_.cbegin(), global_options_.cend()); - return options_all; -} - -void GEThreadLocalContext::SetStreamSyncTimeout(const int32_t timeout) { - stream_sync_timeout_ = timeout; -} - -void GEThreadLocalContext::SetEventSyncTimeout(const int32_t timeout) { - event_sync_timeout_ = timeout; -} - -int32_t GEThreadLocalContext::StreamSyncTimeout() const { - return stream_sync_timeout_; -} - -int32_t GEThreadLocalContext::EventSyncTimeout() const { - return event_sync_timeout_; -} - -OptimizationOption &GEThreadLocalContext::GetOo() { - return optimization_option_; -} -} // namespace ge diff --git a/graph/option/optimization_option.cc b/graph/option/optimization_option.cc deleted file mode 100644 index 79e075833f43b866c4ff92f7f504918154110a31..0000000000000000000000000000000000000000 --- a/graph/option/optimization_option.cc +++ /dev/null @@ -1,202 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/option/optimization_option.h" -#include "common/ge_common/debug/ge_log.h" -#include "common/ge_common/string_util.h" -#include "common/checker.h" -#include "graph/ge_local_context.h" - -namespace ge { -namespace { -const std::unordered_map kOptValToLevels{ - {"O1", OoLevel::kO1}, - {"O3", OoLevel::kO3}, -}; - -void ReportParamInvalid(const std::string &opt_name, const std::string &opt_value, const std::string &reason) { - REPORT_PREDEFINED_ERR_MSG("E10001", std::vector({"parameter", "value", "reason"}), - std::vector({opt_name.c_str(), opt_value.c_str(), reason.c_str()})); - GELOGE(GRAPH_PARAM_INVALID, "[Oo][Check] the value [%s] of option [%s] is invalid. %s", opt_value.c_str(), - opt_name.c_str(), reason.c_str()); -} -} // namespace - -graphStatus OptimizationOption::Initialize(const std::map &ge_options, - const std::unordered_map ®istered_options) { - return Initialize(ge_options, registered_options, std::unordered_set{}); -} - -graphStatus OptimizationOption::Initialize(const std::map &ge_options, - const std::unordered_map ®istered_options, - const std::unordered_set &forbidden_option_set) { - working_oo_level_ = OoLevel::kEnd; - working_opt_names_to_value_.clear(); - // 1. Initialize OoLevel if possible - if (InitWorkingOolevel(ge_options) != GRAPH_SUCCESS) { - return GRAPH_PARAM_INVALID; - } - // 2. Expand optimization template - for (const auto &opt_info : registered_options) { - if (OoInfoUtils::IsBitSet(opt_info.second.levels, static_cast(working_oo_level_))) { - const auto value_str = OoInfoUtils::GetDefaultValue(opt_info.second, working_oo_level_); - (void) working_opt_names_to_value_.emplace(opt_info.first, value_str); - } - } - - // 解析ge.optimizationSwitch的值,更新working_opt_names_to_value_ - UpdatePassSwitchByOption(ge_options, forbidden_option_set); - - // 3. Verify user-configured optimization options - for (const auto &ge_opt : ge_options) { - const auto iter = registered_options.find(ge_opt.first); - if (iter == registered_options.cend()) { - continue; - } - if (IsOptionValueValid(ge_opt.first, ge_opt.second, iter->second.checker) != GRAPH_SUCCESS) { - return GRAPH_PARAM_INVALID; - } - working_opt_names_to_value_[ge_opt.first] = ge_opt.second; - } - - PrintAllWorkingOo(); - GELOGI("Init optimization option success"); - return GRAPH_SUCCESS; -} - -graphStatus OptimizationOption::GetValue(const std::string &opt_name, std::string &opt_value) const { - const auto iter = working_opt_names_to_value_.find(opt_name); - if (iter == working_opt_names_to_value_.cend()) { - return GRAPH_FAILED; - } - opt_value = iter->second; - return GRAPH_SUCCESS; -} - -graphStatus OptimizationOption::IsOoLevelValid(const std::string &oo_level) { - const auto &oo_level_iter = kOptValToLevels.find(oo_level); - if (oo_level_iter == kOptValToLevels.end()) { - ReportParamInvalid(GetThreadLocalContext().GetReadableName(OO_LEVEL), oo_level, - "The optimization option level is unsupported."); - return GRAPH_PARAM_INVALID; - } - return GRAPH_SUCCESS; -} - -graphStatus OptimizationOption::IsOptionValueValid(const std::string &opt_name, const std::string &opt_value, - OoInfo::ValueChecker checker) { - if (checker == nullptr) { - return GRAPH_SUCCESS; - } - if (!checker(opt_value)) { - ReportParamInvalid(GetThreadLocalContext().GetReadableName(opt_name), opt_value, - "Invalid optimization option value."); - return PARAM_INVALID; - } - return GRAPH_SUCCESS; -} - -graphStatus OptimizationOption::InitWorkingOolevel(const std::map &ge_options) { - const auto opt_iter = ge_options.find(OO_LEVEL); - if (opt_iter == ge_options.end()) { - // default oo_level is O3 if ge_option is not set - working_oo_level_ = OoLevel::kO3; - } else { - if (IsOoLevelValid(opt_iter->second) != GRAPH_SUCCESS) { - return GRAPH_PARAM_INVALID; - } - working_oo_level_ = kOptValToLevels.at(opt_iter->second); - } - GELOGI("[Oo][Print]working_oo_level is %u.", working_oo_level_); - return GRAPH_SUCCESS; -} - -bool OptimizationOption::IsPassConfiguredWithOptimizationSwitch(const std::string &pass_name) const { - /** - * pass开关若是通过绑定oo level写入到working_opt_names_to_value_中,键是passname,值是true/false/空 - * 若是通过optimization_switch配置写入,键是passname,值是on/off - */ - const auto iter = working_opt_names_to_value_.find(pass_name); - return (iter != working_opt_names_to_value_.end() && (iter->second == "on" || iter->second == "off")); -} - -graphStatus OptimizationOption::SetPassSwitch(const std::string &pass_switch_str, - const std::unordered_set &forbidden_option_set, - bool force_update) { - GELOGI("Begin to set pass switch with option optimization_switch [%s]", pass_switch_str.c_str()); - std::stringstream ss(pass_switch_str); - std::string token; - - // 拆分每一对 pass:switch - while (std::getline(ss, token, ';')) { - size_t pos = token.find(':'); - // 1. 格式错误处理,记录Warning日志 - if (pos == std::string::npos) { - GELOGW("[Oo][SetPassSwitch] Invalid token format: %s", token.c_str()); - continue; - } - - // 2. 校验冒号前后内容是否为空,冒号后面只能是on/off - std::string pass_name = token.substr(0, pos); - std::string pass_switch = token.substr(pos + 1); - if (pass_name.empty() || (pass_switch != "on" && pass_switch != "off")) { - GELOGW("[Oo][SetPassSwitch] Invalid key or value in token: %s", token.c_str()); - continue; - } - - // 3. 黑名单检查:不能配置ge option,记录Warning日志 - if (forbidden_option_set.find(pass_name) != forbidden_option_set.end()) { - GELOGW("[Oo][SetPassSwitch] [%s] is one of ge option names, can not configured here", pass_name.c_str()); - continue; - } - - // 4. 如果不是强制更新,则跳过已经通过optimization_switch配置的pass - if (!force_update && IsPassConfiguredWithOptimizationSwitch(pass_name)) { - GELOGW("[Oo][SetPassSwitch] [%s] is already configured, skip it", pass_name.c_str()); - continue; - } - - working_opt_names_to_value_[pass_name] = pass_switch; - GELOGD("[Oo][SetPassSwitch]the switch of pass [%s] is [%s]", pass_name.c_str(), pass_switch.c_str()); - } - - return GRAPH_SUCCESS; -} - -graphStatus OptimizationOption::UpdatePassSwitchByOption(const std::map &ge_options, - const std::unordered_set &forbidden_option_set) { - const auto iter = ge_options.find(ge::OPTIMIZATION_SWITCH); - if (iter == ge_options.end()) { - GELOGI("No need to init optimization switch"); - return GRAPH_SUCCESS; - } - // ge.optimizationSwitch的配置为最高优先级,强制更新 - return SetPassSwitch(iter->second, forbidden_option_set, true); -} - -void OptimizationOption::PrintAllWorkingOo() { - for (const auto &iter : working_opt_names_to_value_) { - GELOGD("[Oo][Print]the value [%s] of option [%s] is set successfully", iter.second.c_str(), iter.first.c_str()); - } -} - -graphStatus OptimizationOption::RefreshPassSwitch(const std::string &fusion_config_str) { - std::string optimization_switch; - if (GetThreadLocalContext().GetOption(OPTIMIZATION_SWITCH, optimization_switch) == GRAPH_SUCCESS && - (optimization_switch != "forbidden_close_pass:on" && optimization_switch != "forbidden_close_pass:off")) { - // 1. 以optimization_switch的配置优先,重复配置的option,使用optimization_switch的配置;FE带过来的fusion_config_str中只有pass开关,不做黑名单校验 - // 2. tfa/atc场景默认写入optimization_switch为optimization_switch == "forbidden_close_pass:on/off",此时以fusion_config_str中的配置为准 - // 3. forbidden_close_pass对外不可见,因此torch入口传过来的optimization_switch不会为"forbidden_close_pass:on/off" - return SetPassSwitch(fusion_config_str, std::unordered_set{}, false); - } else { - // fusion_cfg的配置优先级高于O3;FE带过来的fusion_config_str中只有pass开关,不做黑名单校验 - return SetPassSwitch(fusion_config_str, std::unordered_set{}, true); - } -} -} // namespace ge diff --git a/graph/option/optimization_option_info.cc b/graph/option/optimization_option_info.cc deleted file mode 100644 index a5b2f2d84c12a47d287125088c05d5061a52fb14..0000000000000000000000000000000000000000 --- a/graph/option/optimization_option_info.cc +++ /dev/null @@ -1,81 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/option/optimization_option_info.h" -#include -#include "ge_common/debug/ge_log.h" - -namespace { -const std::map kOoLevelStr = {{ge::OoLevel::kO1, "O1"}, {ge::OoLevel::kO3, "O3"}}; -} // namespace -namespace ge { -bool OoInfoUtils::IsBitSet(const uint64_t bits, const uint32_t pos) { - if (pos < sizeof(uint32_t)) { - return ((bits & (1UL << pos)) != 0UL); - } - return false; -} - -uint64_t OoInfoUtils::GenOptLevelBits(const std::vector &levels) { - uint64_t level_bits = 0; - for (const auto level : levels) { - if (level <= OoLevel::kO3) { - /** O3 > O2 > O1 > O0, 当前认为四个级别之间是子集包含关系 (如果有变, 此处需要修改) - * 例如, O1级别的选项属于 O1/O2/O3, 对应的三个比特位被置为 1 - */ - for (auto i = static_cast(level); i <= static_cast(OoLevel::kO3); ++i) { - level_bits |= (1 << static_cast(i)); - } - } else if (level < OoLevel::kEnd) { - // 超过 kO3 的优化模板都属于某个独立功能集,不一定是子集包含关系 - level_bits |= (1 << static_cast(static_cast(level))); - } - } - return level_bits; -} - -uint64_t OoInfoUtils::GenOptVisibilityBits(const std::vector &entries) { - uint64_t vis_bits = 0; - for (const auto entry : entries) { - if (entry < OoEntryPoint::kEnd) { - vis_bits |= (1 << static_cast(static_cast(entry))); - } - } - return vis_bits; -} - -std::string OoInfoUtils::GenOoLevelStr(const uint64_t opt_level) { - std::string level_str; - for (const auto &level : kOoLevelStr) { - if (OoInfoUtils::IsBitSet(opt_level, static_cast(level.first))) { - level_str.append(level.second); - level_str.push_back('/'); - } - } - if (level_str.back() == '/') { - level_str.pop_back(); - } - return level_str; -} - -std::string OoInfoUtils::GetDefaultValue(const ge::OoInfo &info, ge::OoLevel target_level) { - if (info.default_values.count(target_level) > 0UL) { - return info.default_values.at(target_level); - } - return {}; -} - -bool OoInfoUtils::IsSwitchOptValueValid(const std::string &opt_value) { - if (opt_value.empty() || (opt_value == "true") || (opt_value == "false")) { - return true; - } - GELOGE(ge::GRAPH_PARAM_INVALID, "Valid switch option value: \"true\" or \"false\" or null"); - return false; -} -} // namespace ge \ No newline at end of file diff --git a/graph/parallelism/comm_task_builder.h b/graph/parallelism/comm_task_builder.h deleted file mode 100644 index ac1088424f67fc590b2ea6a2b6c318341bc95ab7..0000000000000000000000000000000000000000 --- a/graph/parallelism/comm_task_builder.h +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_GRAPH_PARALLELISM_COMM_TASK_BUILDER_H_ -#define METADEF_GRAPH_PARALLELISM_COMM_TASK_BUILDER_H_ - -#include "graph/parallelism/tensor_parallel_attrs.h" -#include "nlohmann/json.hpp" - -namespace ge { -namespace tp { -class CommTaskBuilder { - public: - static CommTaskBuilder &GetInstance() { - static CommTaskBuilder instance; - return instance; - } - - void BuildCommTask(const nlohmann::json &j, CommTask &comm_task); - Status ConvertToJson(const CommTask &comm_task, nlohmann::json &j); - - private: - CommTaskBuilder(); - ~CommTaskBuilder() = default; - - void InitCommTaskBuilders(); - void InitJsonConverters(); - template - static Status ConvertToJson(const T *reshard_task, nlohmann::json &j); - - std::map> builders_; - std::map> json_converters_; -}; -} // namespace tp -} // namespace ge - -#endif // METADEF_GRAPH_PARALLELISM_COMM_TASK_BUILDER_H_ diff --git a/graph/parallelism/tensor_parallel_attrs.cc b/graph/parallelism/tensor_parallel_attrs.cc deleted file mode 100644 index c5ea39aa6b917e899690726e86fb748c91ce865c..0000000000000000000000000000000000000000 --- a/graph/parallelism/tensor_parallel_attrs.cc +++ /dev/null @@ -1,890 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "parallelism/tensor_parallel_attrs.h" -#include "common/ge_common/util.h" -#include "graph/debug/ge_util.h" -#include "nlohmann/json.hpp" -#include "parallelism/comm_task_builder.h" - -#define USED_BY_JSON __attribute__((unused)) static - -namespace ge { -namespace tp { -namespace { -using Json = nlohmann::json; - -constexpr size_t kValidDimSliceItemNum = 2U; -constexpr size_t kIndexStepId = 0U; -constexpr size_t kIndexOutputIndex = 1U; - -Status StringToJson(const std::string &json_str, Json &json) { - std::stringstream ss; - ss << json_str; - try { - ss >> json; - } catch (const nlohmann::json::exception &e) { - GELOGE(PARAM_INVALID, "Failed to init json object, err = %s, json_str = %s", e.what(), json_str.c_str()); - return PARAM_INVALID; - } - return SUCCESS; -} - -template -Status ParseFromJson(const std::string &type, const std::string &json_str, T &value) { - Json json; - GE_CHK_STATUS_RET_NOLOG(StringToJson(json_str, json)); - try { - value = json.get(); - } catch (const nlohmann::json::exception &e) { - GELOGE(PARAM_INVALID, - "Failed to parse json object, type = %s, err = %s, json_str = %s", - type.c_str(), - e.what(), - json_str.c_str()); - return PARAM_INVALID; - } - return SUCCESS; -} - -template -std::shared_ptr CreateReshardTaskInfo(const Json &j) { - return ComGraphMakeShared(j.get()); -} - -template -std::string ToJsonString(const T &obj) { - try { - const Json j = obj; - return j.dump(); - } catch (const nlohmann::json::exception &e) { - GELOGE(FAILED, "Failed to dump object, err = %s", e.what()); - return ""; - } -} - -template -void GetValue(const Json &j, const std::string &key, T &value) { - value = j.at(key).template get(); -} -template -void TryGetValue(const Json &j, const std::string &key, T &value) { - if (j.contains(key)) { - value = j.at(key).template get(); - } -} -} // namespace - -void CommTaskBuilder::BuildCommTask(const Json &j, CommTask &comm_task) { - auto task_type = j.at("task_type").get(); - const decltype(builders_)::const_iterator it = builders_.find(task_type); - if (it == builders_.cend()) { - GELOGE(PARAM_INVALID, "unsupported op type %s", comm_task.task_type.c_str()); - return; - } - it->second(j, comm_task); // exception caught by caller - comm_task.task_type = std::move(task_type); -} - -Status CommTaskBuilder::ConvertToJson(const CommTask &comm_task, nlohmann::json &j) { - const decltype(json_converters_)::const_iterator it = json_converters_.find(comm_task.task_type); - GE_CHK_BOOL_RET_STATUS(it != json_converters_.cend(), - PARAM_INVALID, - "unsupported op type %s", - comm_task.task_type.c_str()); - return it->second(comm_task, j); // exception caught by caller -} - -std::string DeviceIndex::DebugString() const { - return engine_type + ToString(indices); -} - -USED_BY_JSON void to_json(Json &j, const DimSlice &dim_slice) { - j = std::vector{dim_slice.begin, dim_slice.end}; -} - -USED_BY_JSON void from_json(const Json &j, DimSlice &dim_slice) { - const auto &range = j.get>(); - if (range.size() == kValidDimSliceItemNum) { - dim_slice.begin = range.front(); - dim_slice.end = range.back(); - } else { - dim_slice.begin = -1; - dim_slice.end = -1; - GELOGE(PARAM_INVALID, "invalid DimSlice: %s", j.dump().c_str()); - } -} - -USED_BY_JSON void to_json(Json &j, const DeviceIndex &device_index) { - j = Json(); - j["engine_type"] = device_index.engine_type; - j["index"] = device_index.indices; -} - -USED_BY_JSON void from_json(const Json &j, DeviceIndex &device_index) { - GetValue(j, "engine_type", device_index.engine_type); - GetValue(j, "index", device_index.indices); -} - -USED_BY_JSON void to_json(Json &j, const ModelIndex &model_index) { - j = Json(); - j["device_index"] = model_index.device_index; - j["virtual_stage_id"] = model_index.virtual_stage_id; - j["stage_id"] = model_index.stage_id; -} - -USED_BY_JSON void from_json(const Json &j, ModelIndex &model_index) { - GetValue(j, "device_index", model_index.device_index); - GetValue(j, "virtual_stage_id", model_index.virtual_stage_id); - GetValue(j, "stage_id", model_index.stage_id); -} - -USED_BY_JSON void to_json(Json &j, const PipelineConfig &pipeline_config) { - j = Json(); - j["micro_batch"] = pipeline_config.micro_batch; - j["stage_id"] = pipeline_config.stage_id; - j["virtual_stage_id"] = pipeline_config.virtual_stage_id; -} - -USED_BY_JSON void from_json(const Json &j, PipelineConfig &pipeline_config) { - GetValue(j, "micro_batch", pipeline_config.micro_batch); - GetValue(j, "stage_id", pipeline_config.stage_id); - GetValue(j, "virtual_stage_id", pipeline_config.virtual_stage_id); -} - -USED_BY_JSON void to_json(Json &j, const TensorSliceDeployment &tensor_slice_deployment) { - j = Json(); - j["device_indices_each_slice"] = tensor_slice_deployment.device_indices_each_slice; - j["axis_slices"] = tensor_slice_deployment.axis_slices; -} - -USED_BY_JSON void from_json(const Json &j, TensorSliceDeployment &tensor_slice_deployment) { - GetValue(j, "device_indices_each_slice", tensor_slice_deployment.device_indices_each_slice); - GetValue(j, "axis_slices", tensor_slice_deployment.axis_slices); -} - -USED_BY_JSON void to_json(Json &j, const TensorDeployment &tensor_deployment) { - j = Json(); - j["shard_deployment"] = tensor_deployment.shard_deployment; - if (!tensor_deployment.verbose.empty()) { - j["verbose"] = tensor_deployment.verbose; - } -} - -USED_BY_JSON void from_json(const Json &j, TensorDeployment &tensor_deployment) { - GetValue(j, "shard_deployment", tensor_deployment.shard_deployment); - TryGetValue(j, "verbose", tensor_deployment.verbose); -} - -USED_BY_JSON void to_json(Json &j, const TensorDeployments &tensor_deployments) { - j = Json(); - j["deployments"] = tensor_deployments.deployments; -} - -USED_BY_JSON void from_json(const Json &j, NodeDeployments &node_deployments) { - GetValue(j, "deployments", node_deployments.deployments); -} - -USED_BY_JSON void from_json(const Json &j, TensorDeployments &tensor_deployments) { - GetValue(j, "deployments", tensor_deployments.deployments); -} - -USED_BY_JSON void to_json(Json &j, const NodeDeployment &node_deployment) { - j = Json(); - j["devices"] = node_deployment.devices; - j["pipeline_config"] = node_deployment.pipeline_config; -} - -USED_BY_JSON void from_json(const Json &j, NodeDeployment &node_deployment) { - GetValue(j, "devices", node_deployment.devices); - TryGetValue(j, "pipeline_config", node_deployment.pipeline_config); -} - -USED_BY_JSON void to_json(Json &j, const NodeDeployments &node_deployments) { - j = Json(); - j["deployments"] = node_deployments.deployments; -} - - -USED_BY_JSON void to_json(Json &j, const CommPair &comm_pair) { - j = Json(); - j["src_device_index"] = comm_pair.src_device_index; - j["dst_device_index"] = comm_pair.dst_device_index; - j["src_virtual_stage_id"] = comm_pair.src_virtual_stage_id; - j["dst_virtual_stage_id"] = comm_pair.dst_virtual_stage_id; -} - -USED_BY_JSON void from_json(const Json &j, CommPair &comm_pair) { - GetValue(j, "src_device_index", comm_pair.src_device_index); - GetValue(j, "dst_device_index", comm_pair.dst_device_index); - TryGetValue(j, "src_virtual_stage_id", comm_pair.src_virtual_stage_id); - TryGetValue(j, "dst_virtual_stage_id", comm_pair.dst_virtual_stage_id); -} - -USED_BY_JSON void to_json(Json &j, const FlowAttr &comm_group) { - j = Json(); - j["depth"] = comm_group.depth; - j["enqueue_policy"] = comm_group.enqueue_policy; -} - -USED_BY_JSON void from_json(const Json &j, FlowAttr &comm_group) { - GetValue(j, "depth", comm_group.depth); - GetValue(j, "enqueue_policy", comm_group.enqueue_policy); -} - -USED_BY_JSON void to_json(Json &j, const CommGroup &comm_group) { - j = comm_group.device_indices; -} - -USED_BY_JSON void from_json(const Json &j, CommGroup &comm_group) { - comm_group.device_indices = j.get>(); -} - -USED_BY_JSON void to_json(Json &j, const SendRecvReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeSendReceive; - j["comm_pairs"] = task_info.comm_pairs; - j["parallel_group"] = task_info.parallel_group; - j["comm_type"] = task_info.comm_type; - j["flow_attr"] = task_info.flow_attr; -} - -USED_BY_JSON void from_json(const Json &j, SendRecvReshardTask &task_info) { - GetValue(j, "comm_pairs", task_info.comm_pairs); - TryGetValue(j, "comm_type", task_info.comm_type); - TryGetValue(j, "parallel_group", task_info.parallel_group); - TryGetValue(j, "flow_attr", task_info.flow_attr); -} - -USED_BY_JSON void to_json(Json &j, const AllGatherReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeHcomAllGather; - j["axis"] = task_info.axis; - j["comm_groups"] = task_info.comm_groups; - j["parallel_group"] = task_info.parallel_group; - j["output_allocator"] = task_info.output_allocator; -} - -USED_BY_JSON void from_json(const Json &j, AllGatherReshardTask &all_gather_task_info) { - GetValue(j, "comm_groups", all_gather_task_info.comm_groups); - GetValue(j, "axis", all_gather_task_info.axis); - TryGetValue(j, "parallel_group", all_gather_task_info.parallel_group); - TryGetValue(j, "output_allocator", all_gather_task_info.output_allocator); -} - -USED_BY_JSON void to_json(Json &j, const AllToAllReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeHcomAllToAll; - j["comm_groups"] = task_info.comm_groups; - j["parallel_group"] = task_info.parallel_group; -} - -USED_BY_JSON void from_json(const Json &j, AllToAllReshardTask &all_to_all_task_info) { - GetValue(j, "comm_groups", all_to_all_task_info.comm_groups); - TryGetValue(j, "parallel_group", all_to_all_task_info.parallel_group); -} - -USED_BY_JSON void to_json(Json &j, const AllReduceReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeHcomAllReduce; - j["comm_groups"] = task_info.comm_groups; - j["reduction"] = task_info.reduction; - j["parallel_group"] = task_info.parallel_group; -} - -USED_BY_JSON void from_json(const Json &j, AllReduceReshardTask &all_reduce_task_info) { - GetValue(j, "reduction", all_reduce_task_info.reduction); - GetValue(j, "comm_groups", all_reduce_task_info.comm_groups); - TryGetValue(j, "parallel_group", all_reduce_task_info.parallel_group); -} - -USED_BY_JSON void to_json(Json &j, const AllReduceMeanReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeHcomAllReduceMean; - j["comm_groups"] = task_info.comm_groups; - j["axis"] = task_info.axis; - j["value"] = task_info.value; - j["parallel_group"] = task_info.parallel_group; -} - -USED_BY_JSON void from_json(const Json &j, AllReduceMeanReshardTask &task_info) { - GetValue(j, "comm_groups", task_info.comm_groups); - GetValue(j, "axis", task_info.axis); - GetValue(j, "value", task_info.value); - TryGetValue(j, "parallel_group", task_info.parallel_group); -} - -USED_BY_JSON void to_json(Json &j, const ReduceScatterReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeHcomReduceScatter; - j["comm_groups"] = task_info.comm_groups; - j["reduction"] = task_info.reduction; - j["parallel_group"] = task_info.parallel_group; -} - -USED_BY_JSON void from_json(const Json &j, ReduceScatterReshardTask &reduce_scatter_task_info) { - GetValue(j, "reduction", reduce_scatter_task_info.reduction); - GetValue(j, "comm_groups", reduce_scatter_task_info.comm_groups); - TryGetValue(j, "parallel_group", reduce_scatter_task_info.parallel_group); -} - -USED_BY_JSON void to_json(Json &j, const BroadcastReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeHcomBroadcast; - j["comm_groups"] = task_info.comm_groups; - j["roots"] = task_info.root_device_indices; - j["parallel_group"] = task_info.parallel_group; -} - -USED_BY_JSON void from_json(const Json &j, BroadcastReshardTask &broadcast_task_info) { - GetValue(j, "roots", broadcast_task_info.root_device_indices); - GetValue(j, "comm_groups", broadcast_task_info.comm_groups); - TryGetValue(j, "parallel_group", broadcast_task_info.parallel_group); -} - -USED_BY_JSON void to_json(Json &j, const SliceReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeSlice; - j["axes"] = task_info.axes; - j["offsets"] = task_info.offsets; - j["size"] = task_info.sizes; - j["device_index"] = task_info.device_index; -} - -USED_BY_JSON void from_json(const Json &j, SliceReshardTask &task_info) { - TryGetValue(j, "axes", task_info.axes); - GetValue(j, "offsets", task_info.offsets); - GetValue(j, "size", task_info.sizes); - TryGetValue(j, "device_index", task_info.device_index); -} - -USED_BY_JSON void to_json(Json &j, const SliceByAxisReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeSliceByAxis; - j["axis_to_slice_deployments"] = task_info.axis_to_slice_deployments; -} - -USED_BY_JSON void from_json(const Json &j, SliceByAxisReshardTask &task_info) { - GetValue(j, "axis_to_slice_deployments", task_info.axis_to_slice_deployments); -} - -USED_BY_JSON void to_json(Json &j, const ConcatReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeConcat; - j["concat_dim"] = task_info.concat_dim; -} - -USED_BY_JSON void from_json(const Json &j, ConcatReshardTask &task_info) { - GetValue(j, "concat_dim", task_info.concat_dim); -} - -USED_BY_JSON void to_json(Json &j, const UniqueConcatReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeUniqueConcat; - j["unique_id"] = task_info.unique_id; - j["concat_dim"] = task_info.concat_dim; - j["src_device_indices"] = task_info.src_device_indices; - j["dst_device_index"] = task_info.dst_device_index; -} - -USED_BY_JSON void from_json(const Json &j, UniqueConcatReshardTask &task_info) { - TryGetValue(j, "unique_id", task_info.unique_id); - GetValue(j, "concat_dim", task_info.concat_dim); - GetValue(j, "src_device_indices", task_info.src_device_indices); - GetValue(j, "dst_device_index", task_info.dst_device_index); -} - -USED_BY_JSON void to_json(Json &j, const SplitReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeSplit; - j["num_split"] = task_info.num_split; - j["split_dim"] = task_info.split_dim; -} - -USED_BY_JSON void from_json(const Json &j, SplitReshardTask &task_info) { - GetValue(j, "num_split", task_info.num_split); - GetValue(j, "split_dim", task_info.split_dim); -} - -USED_BY_JSON void to_json(Json &j, const TransposeReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeTranspose; - j["perm"] = task_info.perm; -} - -USED_BY_JSON void from_json(const Json &j, TransposeReshardTask &task_info) { - GetValue(j, "perm", task_info.perm); -} - -USED_BY_JSON void to_json(Json &j, const ReshapeReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeReshape; - j["shape"] = task_info.shape; -} - -USED_BY_JSON void from_json(const Json &j, ReshapeReshardTask &task_info) { - GetValue(j, "shape", task_info.shape); -} - -USED_BY_JSON void to_json(Json &j, const ModifyValueReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeModifyValue; - j["op_type"] = task_info.op_type; - j["value"] = task_info.value; -} - -USED_BY_JSON void from_json(const Json &j, ModifyValueReshardTask &task_info) { - GetValue(j, "op_type", task_info.op_type); - GetValue(j, "value", task_info.value); -} - -USED_BY_JSON void to_json(Json &j, const CastReshardTask &task_info) { - j = Json(); - j["task_type"] = kCommTaskTypeCast; - j["dst_type"] = static_cast(task_info.dst_type); -} - -USED_BY_JSON void from_json(const Json &j, CastReshardTask &task_info) { - int32_t dst_type = -1; - GetValue(j, "dst_type", dst_type); - task_info.dst_type = static_cast(dst_type); -} - -USED_BY_JSON void to_json(Json &j, const CommTask &comm_task) { - GE_CHK_STATUS(CommTaskBuilder::GetInstance().ConvertToJson(comm_task, j)); -} - -USED_BY_JSON void from_json(const Json &j, CommTask &comm_task) { - CommTaskBuilder::GetInstance().BuildCommTask(j, comm_task); -} - -USED_BY_JSON void to_json(Json &j, const CommStepInput &step_input) { - j = std::vector{step_input.step_id, step_input.output_index}; -} - -USED_BY_JSON void from_json(const Json &j, CommStepInput &step_input) { - const auto step_id_and_out_index = j.get>(); - const size_t num_items = step_id_and_out_index.size(); - if (num_items > kIndexStepId) { - step_input.step_id = step_id_and_out_index[kIndexStepId]; - } - if (num_items > kIndexOutputIndex) { - step_input.output_index = step_id_and_out_index[kIndexOutputIndex]; - } -} - -USED_BY_JSON void to_json(Json &j, const CommStep &comm_step) { - j = Json(); - j["id"] = comm_step.id; - if (!comm_step.inputs.empty()) { - j["input_ids"] = comm_step.inputs; - } - j["comm_task"] = comm_step.comm_task; -} - -USED_BY_JSON void from_json(const Json &j, CommStep &comm_step) { - comm_step.id = j.at("id").get(); - if (j.contains("input_ids")) { - comm_step.inputs = j.at("input_ids").get>(); - } - comm_step.comm_task = j.at("comm_task").get(); -} - -USED_BY_JSON void to_json(Json &j, const PeerInput &peer_input) { - j = Json(); - j["step_id"] = peer_input.step_id; - j["node_name"] = peer_input.node_name; - j["input_index"] = peer_input.input_index; - j["stage_id"] = peer_input.stage_id; - j["virtual_stage_id"] = peer_input.virtual_stage_id; -} - -USED_BY_JSON void from_json(const Json &j, PeerInput &peer_input) { - GetValue(j, "step_id", peer_input.step_id); - GetValue(j, "node_name", peer_input.node_name); - GetValue(j, "input_index", peer_input.input_index); - TryGetValue(j, "stage_id", peer_input.stage_id); - TryGetValue(j, "virtual_stage_id", peer_input.virtual_stage_id); -} - -USED_BY_JSON void to_json(Json &j, const OutputReshardRes &reshard_res) { - j = Json(); - j["comm_steps"] = reshard_res.comm_steps; - j["peer_inputs"] = reshard_res.peer_inputs; - j["device_list"] = reshard_res.device_indices; - j["stage_id"] = reshard_res.stage_id; - j["virtual_stage_id"] = reshard_res.virtual_stage_id; -} - -USED_BY_JSON void from_json(const Json &j, OutputReshardRes &reshard_res) { - GetValue(j, "comm_steps", reshard_res.comm_steps); - GetValue(j, "peer_inputs", reshard_res.peer_inputs); - GetValue(j, "device_list", reshard_res.device_indices); - TryGetValue(j, "stage_id", reshard_res.stage_id); - TryGetValue(j, "virtual_stage_id", reshard_res.virtual_stage_id); -} - -USED_BY_JSON void to_json(Json &j, const ReshardAttr &reshard_attr) { - j = reshard_attr.reshard_infos; -} - -USED_BY_JSON void to_json(Json &j, const ShardGraphExtAttrs &shard_graph_ext_attrs) { - j = Json(); - j["dev_index_to_logic_dev_id"] = shard_graph_ext_attrs.dev_index_to_logic_dev_id; - j["graph_name_to_endpoints"] = shard_graph_ext_attrs.graph_name_to_endpoints; - j["group_name_to_dev_ids"] = shard_graph_ext_attrs.group_name_to_dev_ids; -} - -USED_BY_JSON void from_json(const Json &j, ShardGraphExtAttrs &shard_graph_ext_attrs) { - shard_graph_ext_attrs.dev_index_to_logic_dev_id = - j.at("dev_index_to_logic_dev_id").get>>(); - shard_graph_ext_attrs.graph_name_to_endpoints = - j.at("graph_name_to_endpoints").get>>>(); - shard_graph_ext_attrs.group_name_to_dev_ids = - j.at("group_name_to_dev_ids").get>>(); -} - -USED_BY_JSON void from_json(const Json &j, ReshardAttr &reshard_attr) { - reshard_attr.reshard_infos = j.get>>(); -} - -Status TensorParallelAttrs::FromJson(const std::string &json_str, ShardGraphExtAttrs &shard_graph_ext_attrs) { - return ParseFromJson("ShardGraphExtAttrs", json_str, shard_graph_ext_attrs); -} - -Status TensorParallelAttrs::FromJson(const std::string &json_str, DeviceIndex &device_index) { - return ParseFromJson("DeviceIndex", json_str, device_index); -} - -Status TensorParallelAttrs::FromJson(const std::string &json_str, ModelIndex &model_index) { - return ParseFromJson("ModelIndex", json_str, model_index); -} - -Status TensorParallelAttrs::FromJson(const std::string &json_str, PipelineConfig &pipeline_config) { - return ParseFromJson("PipelineConfig", json_str, pipeline_config); -} - - -Status TensorParallelAttrs::FromJson(const std::string &json_str, - ReshardAttr &reshard_attr) { - return ParseFromJson("ReshardRes", json_str, reshard_attr); -} - -Status TensorParallelAttrs::FromJson(const std::string &json_str, - TensorDeployment &tensor_deployment) { - return ParseFromJson("TensorDeployment", json_str, tensor_deployment); -} - -Status TensorParallelAttrs::FromJson(const std::string &json_str, - TensorDeployments &tensor_deployments) { - return ParseFromJson("TensorDeployments", json_str, tensor_deployments); -} - -Status TensorParallelAttrs::FromJson(const std::string &json_str, - NodeDeployments &node_deployments) { - return ParseFromJson("NodeDeployments", json_str, node_deployments); -} - -Status TensorParallelAttrs::FromJson(const std::string &json_str, CommTask &comm_task) { - return ParseFromJson("CommTask", json_str, comm_task); -} - -Status TensorParallelAttrs::FromJson(const std::string &json_str, CommStep &comm_step) { - return ParseFromJson("CommStep", json_str, comm_step); -} - -Status TensorParallelAttrs::FromJson(const std::string &json_str, - OutputReshardRes &output_reshard_res) { - return ParseFromJson("TensorReshardInfo", json_str, output_reshard_res); -} - -Status TensorParallelAttrs::FromJson(const std::string &json_str, NodeDeployment &node_deployment) { - return ParseFromJson("NodeDeployment", json_str, node_deployment); -} - -std::string TensorParallelAttrs::ToJson(const ShardGraphExtAttrs &shard_graph_ext_attrs) { - return ToJsonString(shard_graph_ext_attrs); -} - -std::string TensorParallelAttrs::ToJson(const DeviceIndex &device_index) { - return ToJsonString(device_index); -} - -std::string TensorParallelAttrs::ToJson(const ModelIndex &model_index) { - return ToJsonString(model_index); -} - -std::string TensorParallelAttrs::ToJson(const PipelineConfig &pipeline_config) { - return ToJsonString(pipeline_config); -} - -std::string TensorParallelAttrs::ToJson(const NodeDeployment &node_deployment) { - return ToJsonString(node_deployment); -} - -std::string TensorParallelAttrs::ToJson(const TensorDeployment &tensor_deployment) { - return ToJsonString(tensor_deployment); -} - -std::string TensorParallelAttrs::ToJson(const ReshardAttr &reshard_attr) { - return ToJsonString(reshard_attr); -} - -std::string TensorParallelAttrs::ToJson(const TensorDeployments &tensor_deployments) { - return ToJsonString(tensor_deployments); -} - -std::string TensorParallelAttrs::ToJson(const NodeDeployments &node_deployments) { - return ToJsonString(node_deployments); -} - -CommTaskBuilder::CommTaskBuilder() { - InitCommTaskBuilders(); - InitJsonConverters(); -} - -void CommTaskBuilder::InitCommTaskBuilders() { - builders_[kCommTaskTypeSlice] = [](const Json &j, CommTask &comm_task) { - comm_task.slice_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeSliceByAxis] = [](const Json &j, CommTask &comm_task) { - comm_task.slice_by_axis_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeSplit] = [](const Json &j, CommTask &comm_task) { - comm_task.split_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeConcat] = [](const Json &j, CommTask &comm_task) { - comm_task.concat_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeUniqueConcat] = [](const Json &j, CommTask &comm_task) { - comm_task.unique_concat_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeTranspose] = [](const Json &j, CommTask &comm_task) { - comm_task.transpose_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeHcomAllGather] = [](const Json &j, CommTask &comm_task) { - comm_task.all_gather_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeHcomAllReduce] = [](const Json &j, CommTask &comm_task) { - comm_task.all_reduce_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeHcomAllReduceMean] = [](const Json &j, CommTask &comm_task) { - comm_task.all_reduce_mean_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeHcomReduceScatter] = [](const Json &j, CommTask &comm_task) { - comm_task.reduce_scatter_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeHcomBroadcast] = [](const Json &j, CommTask &comm_task) { - comm_task.broadcast_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeHcomAllToAll] = [](const Json &j, CommTask &comm_task) { - comm_task.all_to_all_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeSendReceive] = [](const Json &j, CommTask &comm_task) { - comm_task.send_recv_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeModifyValue] = [](const Json &j, CommTask &comm_task) { - comm_task.modify_value_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeReshape] = [](const Json &j, CommTask &comm_task) { - comm_task.reshape_reshard_task = CreateReshardTaskInfo(j); - }; - builders_[kCommTaskTypeCast] = [](const Json &j, CommTask &comm_task) { - comm_task.cast_reshard_task = CreateReshardTaskInfo(j); - }; -} - -template -Status CommTaskBuilder::ConvertToJson(const T *reshard_task, nlohmann::json &j) { - GE_CHECK_NOTNULL(reshard_task); - j = *reshard_task; - return SUCCESS; -} - -void CommTaskBuilder::InitJsonConverters() { - json_converters_[kCommTaskTypeSlice] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.slice_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeSliceByAxis] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.slice_by_axis_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeSplit] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.split_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeConcat] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.concat_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeUniqueConcat] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.unique_concat_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeTranspose] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.transpose_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeHcomAllGather] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.all_gather_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeHcomAllReduce] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.all_reduce_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeHcomAllReduceMean] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.all_reduce_mean_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeHcomReduceScatter] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.reduce_scatter_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeHcomBroadcast] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.broadcast_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeHcomAllToAll] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.all_to_all_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeSendReceive] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.send_recv_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeModifyValue] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.modify_value_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeReshape] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.reshape_reshard_task.get(), j); - }; - json_converters_[kCommTaskTypeCast] = [](const CommTask &comm_task, nlohmann::json &j) { - return ConvertToJson(comm_task.cast_reshard_task.get(), j); - }; -} - -bool operator==(const DeviceIndex &lhs, const DeviceIndex &rhs) { - return lhs.engine_type == rhs.engine_type && - lhs.indices == rhs.indices; -} - -bool operator!=(const DeviceIndex &lhs, const DeviceIndex &rhs) { - return !(rhs == lhs); -} - -bool operator<(const DeviceIndex &lhs, const DeviceIndex &rhs) { - if (lhs.engine_type < rhs.engine_type) { - return true; - } - if (rhs.engine_type < lhs.engine_type) { - return false; - } - return lhs.indices < rhs.indices; -} - -bool operator==(const ModelIndex &lhs, const ModelIndex &rhs) { - return (lhs.device_index == rhs.device_index) && (lhs.virtual_stage_id == rhs.virtual_stage_id); -} - -bool operator!=(const ModelIndex &lhs, const ModelIndex &rhs) { - return !(rhs == lhs); -} - -bool operator<(const ModelIndex &lhs, const ModelIndex &rhs) { - if (lhs.virtual_stage_id < rhs.virtual_stage_id) { - return true; - } - if (rhs.virtual_stage_id < lhs.virtual_stage_id) { - return false; - } - return lhs.device_index < rhs.device_index; -} - -bool operator==(const CommStepInput &lhs, const CommStepInput &rhs) { - return (lhs.step_id == rhs.step_id) && (lhs.output_index == rhs.output_index); -} - -bool operator<(const CommStepInput &lhs, const CommStepInput &rhs) { - if (lhs.step_id < rhs.step_id) { - return true; - } - if (rhs.step_id < lhs.step_id) { - return false; - } - return lhs.output_index < rhs.output_index; -} - -bool operator==(const SrcNodeInfo &lhs, const SrcNodeInfo &rhs) { - return (lhs.inserted_node_id == rhs.inserted_node_id) && (lhs.output_index == rhs.output_index); -} -bool operator<(const SrcNodeInfo &lhs, const SrcNodeInfo &rhs) { - if (lhs.inserted_node_id < rhs.inserted_node_id) { - return true; - } - if (rhs.inserted_node_id < lhs.inserted_node_id) { - return false; - } - return lhs.output_index < rhs.output_index; -} - -bool operator==(const OrigNodeInfo &lhs, const OrigNodeInfo &rhs) { - return (lhs.node_name == rhs.node_name) && (lhs.sliced_id == rhs.sliced_id); -} - -bool operator<(const OrigNodeInfo &lhs, const OrigNodeInfo &rhs) { - if (lhs.node_name < rhs.node_name) { - return true; - } - if (rhs.node_name < lhs.node_name) { - return false; - } - return lhs.sliced_id < rhs.sliced_id; -} - -bool operator==(const DstNodeInfo &lhs, const DstNodeInfo &rhs) { - return (lhs.orig_node_info == rhs.orig_node_info) && (lhs.input_indexes == rhs.input_indexes); -} - -bool operator<(const DstNodeInfo &lhs, const DstNodeInfo &rhs) { - if (lhs.orig_node_info < rhs.orig_node_info) { - return true; - } - if (rhs.orig_node_info < lhs.orig_node_info) { - return false; - } - return lhs.InputIndexesToString() < rhs.InputIndexesToString(); -} - -bool operator==(const InsertedNodeInput &lhs, const InsertedNodeInput &rhs) { - if ((lhs.input_info.inserted_node_id >= 0) && (rhs.input_info.inserted_node_id >= 0)) { - return (lhs.input_info == rhs.input_info); - } - if ((lhs.input_info.inserted_node_id < 0) && (rhs.input_info.inserted_node_id < 0)) { - return (lhs.input_info == rhs.input_info) && (lhs.orig_node_info == rhs.orig_node_info); - } - return false; -} -bool operator<(const InsertedNodeInput &lhs, const InsertedNodeInput &rhs) { - if (lhs.input_info < rhs.input_info) { - return true; - } - if (rhs.input_info < lhs.input_info) { - return false; - } - return lhs.orig_node_info < rhs.orig_node_info; -} - -bool operator==(const PeerOutNodeInfo &lhs, const PeerOutNodeInfo &rhs) { - return (lhs.input_info == rhs.input_info) && (lhs.node_info == rhs.node_info); -} - -bool operator<(const PeerOutNodeInfo &lhs, const PeerOutNodeInfo &rhs) { - if (lhs.input_info < rhs.input_info) { - return true; - } - if (rhs.input_info < lhs.input_info) { - return false; - } - return lhs.node_info < rhs.node_info; -} - -std::string ModelIndex::DebugString() const { - return device_index.DebugString() + "[S" + std::to_string(stage_id) + ", V" + std::to_string(virtual_stage_id) + "]"; -} -} // namespace tp -} // namespace ge diff --git a/graph/refiner/format_refiner.cc b/graph/refiner/format_refiner.cc deleted file mode 100644 index 81685320073f0b8c6406181b7a3d88ea33e770be..0000000000000000000000000000000000000000 --- a/graph/refiner/format_refiner.cc +++ /dev/null @@ -1,524 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/refiner/format_refiner.h" - -#include -#include -#include -#include - -#include "graph/ref_relation.h" -#include "debug/ge_log.h" -#include "debug/ge_op_types.h" -#include "graph/utils/node_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/tensor_utils.h" -#include "graph/utils/type_utils.h" -#include "graph/utils/type_utils_inner.h" -#include "graph/types.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/node_utils_ex.h" -#include "graph/utils/op_type_utils.h" - -namespace ge { -namespace { -const size_t kDimSizeOf4D = 4UL; -const std::set kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE}; -const char_t *const kIsGraphInferred = "_is_graph_inferred"; -thread_local RefRelations reflection_builder; - -static graphStatus ReflectionProcess(const std::unordered_set &reflection, - std::deque &nodes, const ge::Format to_be_set_format) { - for (const auto &reflection_cell : reflection) { - const auto &reflection_node = reflection_cell.node; - const auto in_out_idx = reflection_cell.in_out_idx; - GE_CHECK_NOTNULL(reflection_node); - if (reflection_cell.in_out == ge::NODE_IN) { - auto desc = reflection_node->GetOpDesc()->MutableInputDesc(static_cast(in_out_idx)); - GE_CHECK_NOTNULL(desc); - desc->SetOriginFormat(to_be_set_format); - desc->SetFormat(to_be_set_format); - } else { - auto desc = reflection_node->GetOpDesc()->MutableOutputDesc(static_cast(in_out_idx)); - GE_CHECK_NOTNULL(desc); - desc->SetOriginFormat(to_be_set_format); - desc->SetFormat(to_be_set_format); - } - nodes.push_back(reflection_cell.node); - } - - return GRAPH_SUCCESS; -} - -static graphStatus BiasAddFormatFixProcess(const ge::NodePtr &graph_node_ptr) { - // 5 meas dim num - if ((graph_node_ptr->GetType() != "BiasAdd") && (graph_node_ptr->GetType() != "BiasAddGrad")) { - return GRAPH_SUCCESS; - } - const std::unordered_map kTfFormatFix = { - {"NHWC", FORMAT_NDHWC}, - {"NCHW", FORMAT_NCDHW} - }; - for (size_t i = 0UL; i < graph_node_ptr->GetOpDesc()->GetInputsSize(); i++) { - const auto in_desc = graph_node_ptr->GetOpDesc()->MutableInputDesc(static_cast(i)); - GE_CHECK_NOTNULL(in_desc); - const auto dim_num = in_desc->MutableShape().GetDimNum(); - if (dim_num == 5UL) { // 5 means dim num - const auto org_format = in_desc->GetOriginFormat(); - const auto key = TypeUtils::FormatToSerialString(org_format); - const auto fixed_format = (kTfFormatFix.count(key) == 0UL) ? org_format : kTfFormatFix.at(key); - in_desc->SetOriginFormat(fixed_format); - in_desc->SetFormat(fixed_format); - GELOGD("Fix the %zu'th input of node[%s]. Origin format is %s , after fixed it is %s", - i, graph_node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(org_format).c_str(), - TypeUtils::FormatToSerialString(fixed_format).c_str()); - } else if (dim_num < 4UL) { - in_desc->SetOriginFormat(FORMAT_ND); - in_desc->SetFormat(FORMAT_ND); - GELOGD("Fix the %zu'th input of node[%s]. Origin format is %s , after fixed it is %s", - i, graph_node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(FORMAT_ND).c_str(), - TypeUtils::FormatToSerialString(FORMAT_ND).c_str()); - } else { - // do nothing - } - } - for (size_t i = 0UL; i < graph_node_ptr->GetOpDesc()->GetOutputsSize(); i++) { - const auto out_desc = graph_node_ptr->GetOpDesc()->MutableOutputDesc(static_cast(i)); - GE_CHECK_NOTNULL(out_desc); - if (out_desc->MutableShape().GetDimNum() != 5UL) { // 5 means dim num - continue; - } - const auto org_format = out_desc->GetOriginFormat(); - const auto key = TypeUtils::FormatToSerialString(org_format); - const auto fixed_format = (kTfFormatFix.count(key) == 0UL) ? org_format : kTfFormatFix.at(key); - out_desc->SetOriginFormat(fixed_format); - out_desc->SetFormat(fixed_format); - GELOGD("fix the %zu'th output of node[%s]. Origin format is %s , after fixed it is %s", - i, graph_node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(org_format).c_str(), - TypeUtils::FormatToSerialString(fixed_format).c_str()); - } - return GRAPH_SUCCESS; -} - - -static bool JudgeNodeIsAllNd(const OpDescPtr &one_op_desc, const ge::NodePtr &one_node_ptr, - std::vector &anchor_data_nodes) { - // consider special node save process - // Pre-save data node (only main graph data) and default infer fail - if (OpTypeUtils::IsDataNode(one_node_ptr->GetType())) { - anchor_data_nodes.push_back(one_node_ptr); - } - - // get all input desc format - const auto input_size = static_cast(one_op_desc->GetAllInputsSize()); - for (uint32_t i = 0U; i < input_size; i++) { - // Operator pre-set format but not origin format - const auto &input_desc = one_op_desc->MutableInputDesc(i); - GE_IF_BOOL_EXEC(input_desc == nullptr, continue); - const auto input_format = input_desc->GetFormat(); - if ((input_format != FORMAT_ND) && (input_format != FORMAT_RESERVED)) { - return false; - } - } - // Get all output desc format - const auto output_size = static_cast(one_op_desc->GetOutputsSize()); - for (uint32_t i = 0U; i < output_size; i++) { - const auto &output_desc = one_op_desc->MutableOutputDesc(i); - GE_IF_BOOL_EXEC(output_desc == nullptr, continue); - const auto output_format = output_desc->GetFormat(); - if ((output_format != FORMAT_ND) && (output_format != FORMAT_RESERVED)) { - return false; - } - } - return true; -} - -static graphStatus AnchorsInferProcess(std::deque &nodes, const OutDataAnchorPtr &out_data_anchor, - const Format to_be_set_format) { - for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { - GE_IF_BOOL_EXEC(peer_in_data_anchor == nullptr, continue); - - const auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); - GE_IF_BOOL_EXEC(peer_in_data_node == nullptr, continue); - const auto peer_in_data_opdesc = peer_in_data_node->GetOpDesc(); - GE_IF_BOOL_EXEC(peer_in_data_opdesc == nullptr, continue); - - // Check format whether have been set - const int32_t idx = peer_in_data_anchor->GetIdx(); - // do peer_out_node name and index as key to lookup reflections - const ge::RefCell key(peer_in_data_node->GetName(), peer_in_data_node, ge::NODE_IN, idx); - std::unordered_set reflection; - auto ret_status = reflection_builder.LookUpRefRelations(key, reflection); - if (ret_status != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "LookUpRefRelations failed! Node is [%s], the %d input edge", - (peer_in_data_node->GetName()).c_str(), idx); - GELOGE(GRAPH_FAILED, "[Call][LookUpRefRelations] failed! Node is [%s], the %d input edge", - (peer_in_data_node->GetName()).c_str(), idx); - return GRAPH_FAILED; - } - - bool format_locked = false; - (void)AttrUtils::GetBool(peer_in_data_opdesc, ATTR_NAME_FORMAT_LOCKED, format_locked); - GELOGD("Get format locked flag:%u (shape can not be changed while value is equal to 1) from peer in node:%s.", - static_cast(format_locked), peer_in_data_node->GetName().c_str()); - - auto ge_tensor_desc = peer_in_data_opdesc->MutableInputDesc(static_cast(idx)); - if (ge_tensor_desc == nullptr) { - continue; - } - if ((ge_tensor_desc->GetOriginFormat() == FORMAT_ND) && (!format_locked)) { - const auto dim_num = ge_tensor_desc->GetShape().GetDimNum(); - GE_IF_BOOL_EXEC(dim_num == 0UL, - GELOGI("node name:%s idx:%d in is scalar. stop forward infer!", peer_in_data_node->GetName().c_str(), idx); - continue); - - /// Check whether node to change dims () - /// Because some node will calculate with 5D, C dim maybe multi meaning - const auto peer_in_data_node_type = peer_in_data_node->GetType(); - const auto iter1 = kChangeDimNodes.find(peer_in_data_node_type); - // 4 means dims num - if ((iter1 != kChangeDimNodes.end()) && (dim_num < 4UL)) { - GELOGD("Node[%s] is change dim node. do not infer origin format", (peer_in_data_node->GetName()).c_str()); - continue; - } - - if (reflection.empty()) { - ge_tensor_desc->SetOriginFormat(to_be_set_format); - ge_tensor_desc->SetFormat(to_be_set_format); - - /// Because netoutput node added before infer format ,so netoutput is end condition - /// must set netoutput format , because saved result depend on format - GE_IF_BOOL_EXEC(peer_in_data_node_type == NETOUTPUT, continue); - - // Call operator infer format api (forward) to get out format - GELOGD("call infer format func[Back]!Node is [%s] ", (peer_in_data_node->GetName()).c_str()); - ret_status = NodeUtilsEx::InferOriginFormat(peer_in_data_node); - GE_IF_BOOL_EXEC(ret_status != GRAPH_SUCCESS, - GELOGE(GRAPH_FAILED, "[Infer][Format] failed, node:%s", (peer_in_data_node->GetName()).c_str()); - return GRAPH_FAILED); - nodes.push_back(peer_in_data_node); - } else { - const auto ret = ReflectionProcess(reflection, nodes, to_be_set_format); - GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "[Reflect][Node] failed! status:%d", ret); - return GRAPH_FAILED); - } - } - } - return GRAPH_SUCCESS; -} -} // namespace - -graphStatus FormatRefiner::RefreshConstantOutProcess(const ComputeGraphPtr &com_graph, const OpDescPtr &op_desc) { - if ((op_desc->GetType() == CONSTANTOP) && (!IsGraphInferred(com_graph))) { - ConstGeTensorPtr tensor_value; - if (!AttrUtils::GetTensor(op_desc, "value", tensor_value)) { - REPORT_INNER_ERR_MSG("E18888", "GetTensor failed, node name:%s.", op_desc->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][Tensor] failed, node name:%s.", op_desc->GetName().c_str()); - return GRAPH_FAILED; - } - GE_CHECK_NOTNULL(tensor_value); - (void)op_desc->UpdateOutputDesc(0U, tensor_value->GetTensorDesc()); - } - return GRAPH_SUCCESS; -} - -graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &com_graph, - std::vector &anchor_points, - std::vector &anchor_data_nodes) { - anchor_points.clear(); - // Get all anchor point nodes and switch nodes - for (auto &one_node_ptr : com_graph->GetAllNodes()) { - if (one_node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node ptr in graph(%s) should not be null", com_graph->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] node ptr in graph(%s) should not be null", com_graph->GetName().c_str()); - return GRAPH_FAILED; - } - const auto &one_op_desc = one_node_ptr->GetOpDesc(); - if (one_op_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node's opdesc is nullptr,graph:%s", com_graph->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] node's opdesc is nullptr,graph:%s", com_graph->GetName().c_str()); - return GRAPH_FAILED; - } - graphStatus ret_status = RefreshConstantOutProcess(com_graph, one_op_desc); - if (ret_status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Call][RefreshConstantOutProcess] failed! graph:%s, op:%s", - com_graph->GetName().c_str(), one_op_desc->GetName().c_str()); - return GRAPH_FAILED; - } - - // check anchor point valid - if (JudgeNodeIsAllNd(one_op_desc, one_node_ptr, anchor_data_nodes)) { - continue; - } - // special process for biasAdd op - // In tensorflow, biasAdd's format is alwayse NHWC even though set the arg - // "data_format" to NDHWC or NCDHW.It will destroy our format-infer mechanism - // so here do special process - ret_status = BiasAddFormatFixProcess(one_node_ptr); - if (ret_status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Call][BiasAddFormatFixProcess] failed! node:%s, graph:%s", - one_node_ptr->GetName().c_str(), com_graph->GetName().c_str()); - return GRAPH_FAILED; - } - - GELOGD("Node[%s] is anchor point!", one_node_ptr->GetName().c_str()); - anchor_points.push_back(one_node_ptr); - } - GELOGI("anchor_points number is %zu", anchor_points.size()); - return GRAPH_SUCCESS; -} - -graphStatus FormatRefiner::AnchorProcess(const ge::NodePtr &anchor_node) { - std::deque nodes; - nodes.push_back(anchor_node); - while (!nodes.empty()) { - const ge::NodePtr one_node = nodes.front(); - nodes.pop_front(); - GE_CHECK_NOTNULL(one_node); - GE_CHECK_NOTNULL(one_node->GetOpDesc()); - graphStatus ret_status = BackInferProcess(nodes, one_node); - if ((ret_status != GRAPH_SUCCESS) && (one_node != nullptr)) { - GELOGE(ret_status, "[Back][InferProcess] failed! status:%d, node name [%s]", - ret_status, one_node->GetName().c_str()); - return ret_status; - } - ret_status = ForwardInferProcess(nodes, one_node); - if ((ret_status != GRAPH_SUCCESS) && (one_node != nullptr)) { - GELOGE(ret_status, "[Forward][InferProcess] failed! status:%d, node name [%s]", - ret_status, one_node->GetName().c_str()); - return ret_status; - } - } - return GRAPH_SUCCESS; -} -graphStatus FormatRefiner::BackInferProcess(std::deque &nodes, const ge::NodePtr &node) { - GELOGD("Enter back infer format for Node [%s]", node->GetName().c_str()); - for (const auto &in_anchor : node->GetAllInDataAnchors()) { - const auto in_data_anchor_idx = in_anchor->GetIdx(); - GELOGD("Node [%s]:%d [B]", node->GetName().c_str(), in_data_anchor_idx); - const auto input_desc = node->GetOpDesc()->MutableInputDesc(static_cast(in_data_anchor_idx)); - GE_IF_BOOL_EXEC(input_desc == nullptr, continue); - const auto to_be_set_format = input_desc->GetOriginFormat(); - GE_IF_BOOL_EXEC(to_be_set_format == FORMAT_ND, GELOGD("Node [%s] format is ND.[B]", node->GetName().c_str()); - continue); - const auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor(); - GE_IF_BOOL_EXEC (peer_out_data_anchor == nullptr, continue); - const auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode(); - const int32_t idx = peer_out_data_anchor->GetIdx(); - // do peer_out_node name and index as key to lookup reflections - const ge::RefCell key(peer_out_data_node->GetName(), peer_out_data_node, ge::NODE_OUT, idx); - std::unordered_set reflection; - auto status = reflection_builder.LookUpRefRelations(key, reflection); - GE_IF_BOOL_EXEC(status != GRAPH_SUCCESS, - GELOGE(GRAPH_FAILED, "[Call][LookUpRefRelations] failed! Node is [%s], the %d out edge", - (peer_out_data_node->GetName()).c_str(), idx); return GRAPH_FAILED); - - // Check format whether have been set - // op_desc of node should not be null - auto ge_tensor_desc = peer_out_data_node->GetOpDesc()->MutableOutputDesc(static_cast(idx)); - - bool format_locked = false; - (void)AttrUtils::GetBool(peer_out_data_node->GetOpDesc(), ATTR_NAME_FORMAT_LOCKED, format_locked); - GELOGD("Get format locked flag:%u (shape is locked if value is equal to 1) from peer out node:%s.", - static_cast(format_locked), peer_out_data_node->GetName().c_str()); - - if ((ge_tensor_desc->GetOriginFormat() == FORMAT_ND) && (!format_locked)) { - const auto dim_num = ge_tensor_desc->GetShape().GetDimNum(); - GE_IF_BOOL_EXEC(dim_num == 0UL, GELOGD("node name:%s idx:%d out is scalar. stop back infer!", - peer_out_data_node->GetName().c_str(), idx); continue); - - /// Check whether node to change dims () - /// Because some node will calculate with 5D, C dim maybe multi meaning - const auto peer_out_data_node_type = peer_out_data_node->GetType(); - const auto iter1 = kChangeDimNodes.find(peer_out_data_node_type); - // 4 means dims num - if ((iter1 != kChangeDimNodes.end()) && (dim_num < 4UL)) { - GELOGD("Node[%s] is change dim node and shape is smaller than 4. do not modify format", - (peer_out_data_node->GetName()).c_str()); - continue; - } - - if (reflection.empty()) { - ge_tensor_desc->SetOriginFormat(to_be_set_format); - ge_tensor_desc->SetFormat(to_be_set_format); - - // Call operator infer format api (forward) to get out format - GELOGD("call infer format func[Back]!Node is [%s] ", (peer_out_data_node->GetName()).c_str()); - status = NodeUtilsEx::InferOriginFormat(peer_out_data_node); - GE_IF_BOOL_EXEC(status != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "[Infer][Format] failed, Node:%s", - (peer_out_data_node->GetName()).c_str()); return GRAPH_FAILED); - nodes.push_back(peer_out_data_node); - } else { - const auto ret = ReflectionProcess(reflection, nodes, to_be_set_format); - GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "[Reflect][Node] failed! status:%d", ret); - return GRAPH_FAILED); - } - } - } - return GRAPH_SUCCESS; -} - -graphStatus FormatRefiner::ForwardInferProcess(std::deque &nodes, const ge::NodePtr &node) { - GELOGD("Enter forward infer format for Node [%s]", node->GetName().c_str()); - for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { - GE_IF_BOOL_EXEC(out_data_anchor == nullptr, continue); - const auto out_data_anchor_idx = out_data_anchor->GetIdx(); - GELOGD("Node [%s]:%d [F]", node->GetName().c_str(), out_data_anchor_idx); - if (node->GetOpDesc()->MutableOutputDesc(static_cast(out_data_anchor_idx)) == nullptr) { - continue; - } - const auto to_be_set_format = - node->GetOpDesc()->MutableOutputDesc(static_cast(out_data_anchor_idx))->GetOriginFormat(); - if (to_be_set_format == FORMAT_ND) { - GELOGD("Node [%s] format is ND.[F]", node->GetName().c_str()); - continue; - } - const auto ret = AnchorsInferProcess(nodes, out_data_anchor, to_be_set_format); - if (ret != GRAPH_SUCCESS) { - return ret; - } - } - return GRAPH_SUCCESS; -} - -void FormatRefiner::RefreshOriginFormatOfAnchor(const std::vector &anchor_points) { - for (const auto &node : anchor_points) { - for (const auto &input_desc : node->GetOpDesc()->GetAllInputsDescPtr()) { - // single op support private format set, its origin format should not be override - const auto ori_format = input_desc->GetOriginFormat(); - const auto format = input_desc->GetFormat(); - if (TypeUtilsInner::IsInternalFormat(format)) { - continue; - } - if ((input_desc != nullptr) && ((ori_format == FORMAT_ND) || (ori_format == FORMAT_RESERVED))) { - input_desc->SetOriginFormat(input_desc->GetFormat()); - } - } - for (const auto &output_desc : node->GetOpDesc()->GetAllOutputsDescPtr()) { - const auto ori_format = output_desc->GetOriginFormat(); - const auto format = output_desc->GetFormat(); - if (TypeUtilsInner::IsInternalFormat(format)) { - continue; - } - if ((output_desc != nullptr) && ((ori_format == FORMAT_ND) || (ori_format == FORMAT_RESERVED))) { - output_desc->SetOriginFormat(output_desc->GetFormat()); - } - } - } -} - -graphStatus FormatRefiner::DataNodeFormatProcess(const ComputeGraphPtr &graph, - const std::vector &anchor_data_nodes, - const ge::Format data_format) { - if (!(IsGraphInferred(graph) && (!TypeUtilsInner::IsInternalFormat(data_format)) && (data_format != FORMAT_ND))) { - GELOGI("no necessary to do DataNodeFormatProcess. is_graph_inferred:%d, data_format:%s", - static_cast(IsGraphInferred(graph)), TypeUtils::FormatToSerialString(data_format).c_str()); - return GRAPH_SUCCESS; - } - GELOGD("Enter DataNodeFormatProcess"); - std::vector uninfered_data_nodes; - // Check and renew data nodes format - for (const auto &data_node : anchor_data_nodes) { - GE_CHECK_NOTNULL(data_node); - const auto op_desc = data_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - - const auto input_desc = op_desc->MutableInputDesc(0U); - const auto output_desc = op_desc->MutableOutputDesc(0U); - GE_CHECK_NOTNULL(input_desc); - GE_CHECK_NOTNULL(output_desc); - - const auto curr_format = output_desc->GetOriginFormat(); - if (curr_format != FORMAT_ND) { - // Data format has been infered , continue - continue; - } - // keep data format be ND because lacking of defination when input shape num is smaller than 4 - if (input_desc->MutableShape().GetDimNum() < kDimSizeOf4D) { - continue; - } - // Set format for un-infered data node - input_desc->SetOriginFormat(data_format); - input_desc->SetFormat(data_format); - output_desc->SetOriginFormat(data_format); - output_desc->SetFormat(data_format); - uninfered_data_nodes.push_back(data_node); - } - // Reinfer format from uninfered data nodes - for (const auto &node : uninfered_data_nodes) { - if (node == nullptr) { - continue; - } - GELOGD("data node [%s] start infer format process", node->GetName().c_str()); - const auto status = AnchorProcess(node); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Call][AnchorProcess] failed, status:%d, node:%s", status, node->GetName().c_str()); - return GRAPH_FAILED; - } - } - GELOGD("DataNodeFormatProcess success"); - return GRAPH_SUCCESS; -} - -graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) { - GELOGI("Enter InferOrigineFormat process!"); - - // True: infered false:no-infered - std::vector anchor_points; - std::vector anchor_data_nodes; - - if (graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param graph is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] input graph is nullptr"); - return GRAPH_FAILED; - } - // build reflection relations of boundary - (void)reflection_builder.Clear(); - auto status = reflection_builder.BuildRefRelations(*graph); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Call][BuildRefRelations] failed, graph:%s", graph->GetName().c_str()); - return GRAPH_FAILED; - } - // User set global net format - status = GetAnchorPoints(graph, anchor_points, anchor_data_nodes); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "GetAnchorPoints Process Faild! graph:%s", graph->GetName().c_str()); - return GRAPH_FAILED; - } - // Refresh origin format of anchor point - RefreshOriginFormatOfAnchor(anchor_points); - // Infer format process - for (const auto &anchor_node : anchor_points) { - if (anchor_node == nullptr) { - continue; - } - status = AnchorProcess(anchor_node); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Call][AnchorProcess] failed, node:%s", anchor_node->GetName().c_str()); - return GRAPH_FAILED; - } - } - /// According to discuss with sys-enginer, data node default format is ND.Its format - /// should be set by infered.But if some data-node can not be got by infer, set context's - /// format for these data nodes. - /// Notice: ignore 5D formats - const auto data_format = graph->GetDataFormat(); - status = DataNodeFormatProcess(graph, anchor_data_nodes, data_format); - - (void)AttrUtils::SetBool(graph, kIsGraphInferred, true); - - return status; -} - -bool FormatRefiner::IsGraphInferred(const ComputeGraphPtr &graph) { - bool is_graph_inferred = false; - return (AttrUtils::GetBool(graph, kIsGraphInferred, is_graph_inferred) && is_graph_inferred); -} -} // namespace ge diff --git a/graph/refiner/format_refiner.h b/graph/refiner/format_refiner.h deleted file mode 100644 index df0c5c86364a89a8c31b468fc4821e323d39b939..0000000000000000000000000000000000000000 --- a/graph/refiner/format_refiner.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef COMMON_GRAPH_FORMAT_REFINER_H_ -#define COMMON_GRAPH_FORMAT_REFINER_H_ - -#if defined(_MSC_VER) -#ifdef FUNC_VISIBILITY -#define METADEF_FUNC_VISIBILITY _declspec(dllexport) -#else -#define METADEF_FUNC_VISIBILITY -#endif -#else -#ifdef FUNC_VISIBILITY -#define METADEF_FUNC_VISIBILITY -#else -#define METADEF_FUNC_VISIBILITY __attribute__((visibility("hidden"))) -#endif -#endif - -#include -#include -#include -#include -#include "graph/compute_graph.h" -#include "graph/types.h" -#include "graph/ge_error_codes.h" - -namespace ge { -// ShapeRefiner performs shape inference for compute graphs -class METADEF_FUNC_VISIBILITY FormatRefiner { - public: - static graphStatus InferOrigineFormat(const ge::ComputeGraphPtr &graph); - - private: - static graphStatus RefreshConstantOutProcess(const ComputeGraphPtr &com_graph, const OpDescPtr &op_desc); - static graphStatus GetAnchorPoints(const ge::ComputeGraphPtr &com_graph, std::vector &anchor_points, - std::vector &anchor_data_nodes); - static graphStatus AnchorProcess(const ge::NodePtr &anchor_node); - static void RefreshOriginFormatOfAnchor(const std::vector &anchor_points); - static graphStatus BackInferProcess(std::deque &nodes, const ge::NodePtr &node); - static graphStatus ForwardInferProcess(std::deque &nodes, const ge::NodePtr &node); - static graphStatus DataNodeFormatProcess(const ComputeGraphPtr &graph, - const std::vector &anchor_data_nodes, - const ge::Format data_format); - static bool IsGraphInferred(const ComputeGraphPtr &graph); -}; -} // namespace ge -#endif // COMMON_GRAPH_FORMAT_REFINER_H_ diff --git a/graph/refiner/ref_relation.cc b/graph/refiner/ref_relation.cc deleted file mode 100644 index 4420a11f1e813cb5269151de1b9742560924946d..0000000000000000000000000000000000000000 --- a/graph/refiner/ref_relation.cc +++ /dev/null @@ -1,483 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/ref_relation.h" - -#include -#include -#include - -#include "common/util/mem_utils.h" -#include "debug/ge_log.h" -#include "debug/ge_op_types.h" -#include "graph/utils/graph_utils.h" -#include "graph/def_types.h" - -namespace ge { -namespace { - const char_t *kRefIdx = "_parent_node_index"; - const char_t *kWhile = "While"; - const char_t *kIf = "If"; - const char_t *kCase = "Case"; - const char_t *kStatelessWhile = "StatelessWhile"; - std::set function_op = {kWhile, kIf, kCase}; -} - -/* Impl */ -class RefRelations::Impl { -public: - graphStatus LookUpRefRelations(const RefCell &key, std::unordered_set &result) { - const auto iter = look_up_table_.find(key.hash_key); - if (iter != look_up_table_.end()) { - for (auto &c : iter->second) { - (void)result.insert(c); - } - return GRAPH_SUCCESS; - } - GELOGD("[RefRelations][Check] can not find any relations! key value of dest relation is %s", key.hash_key.c_str()); - return GRAPH_SUCCESS; - }; - graphStatus BuildRefRelations(ge::ComputeGraph &graph); - graphStatus Clear() { - GELOGD("Start clear boundary reflections between main graph and sub graph!"); - look_up_table_.clear(); - values_.clear(); - return GRAPH_SUCCESS; - }; -private: - friend class RefRelations; - graphStatus BuildLookUpTables(); - graphStatus BuildRefRelationsForBranch( - const NodePtr &root_node, - const std::vector> &classed_data_nodes, - const std::vector>> &classed_netoutput_nodes, - std::vector> &node_refs) const; - graphStatus BuildRefRelationsForWhile( - const NodePtr &root_node, - const std::vector> &classed_data_nodes, - const std::vector>> &classed_netoutput_nodes, - std::vector> &node_refs) const; - graphStatus BuildRelationsWithFuncNodeType( - const NodePtr &root_node, - const std::vector> &classed_data_nodes, - const std::vector>> &classed_netoutput_nodes, - std::vector> &node_refs) const; - void GetDataAndNetoutputOfSubGraph( - const ge::ComputeGraph &root_graph, - std::vector &graph_data_nodes, - std::vector &netoutput_nodes, - const std::vector &sub_graph_names, - const std::string &node_type) const; - - graphStatus GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph) const; - graphStatus ProcessSubgraphDataNodes(std::vector &graph_data_nodes, - std::vector> &classed_data_nodes) const; - graphStatus ProcessSubgraphNetoutput( - const std::vector &netoutput_nodes, - std::vector>> &classed_netoutput_nodes) const; - void BuildRelationsForVariables(const ge::ComputeGraph &root_graph); - - std::unordered_map> look_up_table_; - std::vector>> values_; -}; - -// Node Level -graphStatus RefRelations::Impl::BuildRefRelationsForBranch( - const NodePtr &root_node, - const std::vector> &classed_data_nodes, - const std::vector>> &classed_netoutput_nodes, - std::vector> &node_refs) const { - GELOGD("Enter BuildRefRelationsForBranch!"); - - size_t ref_i = 0UL; - for (const auto &ref_i_data_nodes : classed_data_nodes) { - std::vector in_ref_i_all_refs; - RefCell cell_root(root_node->GetName(), root_node, NODE_IN, static_cast(ref_i)); - in_ref_i_all_refs.emplace_back(cell_root); - for (const auto &data : ref_i_data_nodes) { - RefCell cell_in(data->GetName(), data, NODE_IN, 0); - RefCell cell_out(data->GetName(), data, NODE_OUT, 0); - in_ref_i_all_refs.emplace_back(cell_in); - in_ref_i_all_refs.emplace_back(cell_out); - } - node_refs.emplace_back(in_ref_i_all_refs); - ref_i++; - } - - size_t ref_o = 0UL; - for (const auto &ref_o_net_nodes : classed_netoutput_nodes) { - std::vector out_ref_i_all_refs; - RefCell cell_root(root_node->GetName(), root_node, NODE_OUT, static_cast(ref_o)); - out_ref_i_all_refs.emplace_back(cell_root); - for (const auto &ele : ref_o_net_nodes) { - RefCell cell_netoutput_in((ele.first)->GetName(), ele.first, NODE_IN, static_cast(ele.second)); - out_ref_i_all_refs.emplace_back(cell_netoutput_in); - } - node_refs.emplace_back(out_ref_i_all_refs); - ref_o++; - } - return GRAPH_SUCCESS; -} - -graphStatus RefRelations::Impl::BuildLookUpTables() { - GELOGD("start to build look up table!"); - for (size_t i = 0UL; i < values_.size(); i++) { - std::vector> &val = values_[i]; - for (const auto &ele : val) { - for (const auto &ref_cell : ele) { - look_up_table_.emplace(ref_cell.hash_key, ele); - } - } - } - return GRAPH_SUCCESS; -} - -graphStatus RefRelations::Impl::BuildRefRelationsForWhile( - const NodePtr &root_node, - const std::vector> &classed_data_nodes, - const std::vector>> &classed_netoutput_nodes, - std::vector> &node_refs) const { - GELOGD("Enter BuildRefRelations for while op!"); - // data_nodes has been sorted - // for while, input num must be same as output num - const auto input_num = root_node->GetAllInDataAnchorsSize(); - NodePtr netoutput = nullptr; - - size_t ref_i = 0UL; - while (ref_i < input_num) { - auto &ref_i_data_nodes = classed_data_nodes[ref_i]; - auto &ref_i_net_nodes = classed_netoutput_nodes[ref_i]; - - std::vector ref_i_all_refs; - RefCell cell_root_i(root_node->GetName(), root_node, NODE_IN, static_cast(ref_i)); - RefCell cell_root_o(root_node->GetName(), root_node, NODE_OUT, static_cast(ref_i)); - ref_i_all_refs.emplace_back(cell_root_i); - ref_i_all_refs.emplace_back(cell_root_o); - for (const auto &data : ref_i_data_nodes) { - RefCell cell_in(data->GetName(), data, NODE_IN, 0); - RefCell cell_out(data->GetName(), data, NODE_OUT, 0); - ref_i_all_refs.emplace_back(cell_in); - ref_i_all_refs.emplace_back(cell_out); - } - - for (const auto &ele : ref_i_net_nodes) { - RefCell cell_netoutput_in((ele.first)->GetName(), ele.first, NODE_IN, static_cast(ele.second)); - ref_i_all_refs.emplace_back(cell_netoutput_in); - netoutput = ele.first; - } - node_refs.emplace_back(ref_i_all_refs); - ref_i++; - } - /* There exist scene like the follows, it means data0 data1 netoutput 0'th - * and 1'th tensor should be the same addr. - * Data0 Data1 - * \/ - * /\ - * netoutput - */ - if (netoutput == nullptr) { - return GRAPH_SUCCESS; - } - for (const auto &in_anchor : netoutput->GetAllInDataAnchors()) { - const auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor(); - if (peer_out_data_anchor == nullptr) { - continue; - } - const auto peer_out_data_node = peer_out_data_anchor->GetOwnerNodeBarePtr(); - if ((peer_out_data_node == nullptr) || (peer_out_data_node->GetOpDesc() == nullptr)) { - GELOGW("[RefRelations][Check] Node[%s]\'s peer_out_data_node or peer_out_data_node desc is null", - netoutput->GetName().c_str()); - continue; - } - if (peer_out_data_node->GetType() != DATA) { - continue; - } - const auto in_data_anchor_idx = in_anchor->GetIdx(); - const auto net_in_desc = netoutput->GetOpDesc()->MutableInputDesc(static_cast(in_data_anchor_idx)); - int32_t ref_d = 0; - int32_t ref_n = 0; - (void)AttrUtils::GetInt(peer_out_data_node->GetOpDesc(), kRefIdx, ref_d); - (void)AttrUtils::GetInt(net_in_desc, kRefIdx, ref_n); - const size_t ref_desc = static_cast(ref_d); - const size_t ref_in = static_cast(ref_n); - - const size_t idx1 = node_refs[ref_in].size(); // 注意,不要删除idx1、idx2,存在ref_desc=ref_in的情况 - for (size_t i = 0U; i < idx1; ++i) { - node_refs[ref_desc].emplace_back(node_refs[ref_in][i]); - } - const size_t idx2 = node_refs[ref_desc].size(); - for (size_t i = 0U; i < idx2; ++i) { - node_refs[ref_in].emplace_back(node_refs[ref_desc][i]); - } - } - - return GRAPH_SUCCESS; -} -// build ref relations according to diff func op type -graphStatus RefRelations::Impl::BuildRelationsWithFuncNodeType( - const NodePtr &root_node, - const std::vector> &classed_data_nodes, - const std::vector>> &classed_netoutput_nodes, - std::vector> &node_refs) const { - // data_nodes has been sorted - const auto &node_type = root_node->GetType(); - - auto status = GRAPH_SUCCESS; - if ((node_type != kWhile) && (node_type != kStatelessWhile)) { - status = BuildRefRelationsForBranch(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs); - } else { - status = BuildRefRelationsForWhile(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs); - } - return status; -} - -void RefRelations::Impl::GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &root_graph, - std::vector &graph_data_nodes, - std::vector &netoutput_nodes, - const std::vector &sub_graph_names, - const std::string &node_type) const { - int32_t sub_graph_idx = 0; - for (const auto &name : sub_graph_names) { - const auto &sub_graph = root_graph.GetSubgraph(name); - if (sub_graph == nullptr) { - GELOGW("[RefRelations][Check] Can not find sub graph %s, root graph: %s.", name.c_str(), - root_graph.GetName().c_str()); - continue; - } - for (const auto &sub_graph_node : sub_graph->GetDirectNode()) { - const auto &sub_graph_node_type = sub_graph_node->GetType(); - if (sub_graph_node_type == DATA) { - graph_data_nodes.emplace_back(sub_graph_node); - } - if (sub_graph_node_type == NETOUTPUT) { - // if while, the first subgraph must be cond subgraph. - // There is no meaning for refs ,so continue - if (((node_type == kWhile) || (node_type == kStatelessWhile)) && (sub_graph_idx == 0)) { - continue; - } - netoutput_nodes.emplace_back(sub_graph_node); - } - } - sub_graph_idx++; - } -} - -graphStatus RefRelations::Impl::GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph) const { - const auto parent_graph_ptr = graph.GetParentGraph(); - if (parent_graph_ptr == nullptr) { - root_graph = graph; - return GRAPH_SUCCESS; - } - const auto root_graph_ptr = GraphUtils::FindRootGraph(parent_graph_ptr); - if (root_graph_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Get null root graph, graph:%s", parent_graph_ptr->GetName().c_str()); - GE_LOGE("[Find][Graph] Get null root graph"); - return GRAPH_PARAM_INVALID; - } - root_graph = *root_graph_ptr; - return GRAPH_SUCCESS; -} - -graphStatus RefRelations::Impl::ProcessSubgraphDataNodes(std::vector &graph_data_nodes, - std::vector> &classed_data_nodes) const { - GELOGD("start to process subgraph data nodes!"); - int32_t max_ref_idx = 0; - for (const auto &e : graph_data_nodes) { - int32_t i; - bool is_exist = true; - is_exist = AttrUtils::GetInt(e->GetOpDesc(), kRefIdx, i); - if (!is_exist) { - REPORT_INNER_ERR_MSG("E18888", "Invalid SubGraph NetOutput node[%s].no attr %s", e->GetName().c_str(), kRefIdx); - GELOGE(GRAPH_FAILED, "[Get][Int] Invalid SubGraph NetOutput node[%s].no attr %s", - e->GetName().c_str(), kRefIdx); - return GRAPH_FAILED; - } - max_ref_idx = (i > max_ref_idx) ? i : max_ref_idx; - } - classed_data_nodes.resize(static_cast(max_ref_idx) + 1UL); - while (!graph_data_nodes.empty()) { - auto data = graph_data_nodes.back(); - graph_data_nodes.pop_back(); - int32_t ref_idx = 0; - (void)AttrUtils::GetInt(data->GetOpDesc(), kRefIdx, ref_idx); - if (ref_idx >= static_cast(classed_data_nodes.size())) { - return GRAPH_FAILED; - } - classed_data_nodes[static_cast(ref_idx)].emplace_back(data); - } - return GRAPH_SUCCESS; -} - -graphStatus RefRelations::Impl::ProcessSubgraphNetoutput( - const std::vector &netoutput_nodes, - std::vector>> &classed_netoutput_nodes) const { - GELOGD("[RefRelations]Start to process subgraph netoutput!"); - // calc netoutput max_ref_idx - int32_t max_ref_idx = 0; - for (const auto &sub_netoutput_node : netoutput_nodes) { - const auto op_desc = sub_netoutput_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - - for (const auto &in_data_anchor : sub_netoutput_node->GetAllInDataAnchors()) { - const auto in_desc = op_desc->MutableInputDesc(static_cast(in_data_anchor->GetIdx())); - if (in_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Invalid NetOutput node [%s] idx [%d], no tensor on it", - sub_netoutput_node->GetName().c_str(), in_data_anchor->GetIdx()); - GELOGE(GRAPH_FAILED, "[Get][Tensor] Invalid NetOutput node [%s] idx [%d], no tensor on it", - sub_netoutput_node->GetName().c_str(), in_data_anchor->GetIdx()); - return GRAPH_FAILED; - } - int32_t ref_o; - if (AttrUtils::GetInt(in_desc, kRefIdx, ref_o)) { - max_ref_idx = (ref_o > max_ref_idx) ? ref_o : max_ref_idx; - } else { - REPORT_INNER_ERR_MSG("E18888", "Invalid NetOutput node [%s] idx [%d], no attr[_parent_node_index] on it", - sub_netoutput_node->GetName().c_str(), in_data_anchor->GetIdx()); - GELOGE(GRAPH_FAILED, "[Get][Int] Invalid NetOutput node [%s] idx [%d], no attr[_parent_node_index] on it", - sub_netoutput_node->GetName().c_str(), in_data_anchor->GetIdx()); - return GRAPH_FAILED; - } - } - } - classed_netoutput_nodes.resize(static_cast(max_ref_idx) + 1UL); - // re-sort according ref idx - for (const auto &sub_netoutput_node : netoutput_nodes) { - const auto op_desc = sub_netoutput_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - - for (const auto &in_data_anchor : sub_netoutput_node->GetAllInDataAnchors()) { - const auto in_desc = op_desc->MutableInputDesc(static_cast(in_data_anchor->GetIdx())); - int32_t ref_o; - if (AttrUtils::GetInt(in_desc, kRefIdx, ref_o)) { - if (ref_o >= static_cast(classed_netoutput_nodes.size())) { - return GRAPH_FAILED; - } - classed_netoutput_nodes[static_cast(ref_o)].emplace_back(std::pair( - {sub_netoutput_node, static_cast(in_data_anchor->GetIdx())} - )); - } - } - } - return GRAPH_SUCCESS; -} - -void RefRelations::Impl::BuildRelationsForVariables(const ge::ComputeGraph &root_graph) { - if (root_graph.GetAllSubgraphs().empty()) { - return; - } - - std::map> variables; - for (const auto &node : root_graph.GetAllNodes()) { - if (node->GetType() == VARIABLE) { - variables[node->GetName()].emplace_back(node); - } - } - - for (const auto &it : variables) { - const auto &instances = it.second; - if (instances.size() <= 1UL) { - continue; - } - - GELOGD("Variable [%s] has %zu instances", it.first.c_str(), instances.size()); - std::vector variable_all_refs; - for (const auto &variable : instances) { - RefCell variable_ref(it.first, variable, NODE_OUT, 0); - variable_all_refs.emplace_back(std::move(variable_ref)); - } - - std::vector> refs {variable_all_refs}; - values_.emplace_back(std::move(refs)); - } -} - -graphStatus RefRelations::Impl::BuildRefRelations(ge::ComputeGraph &graph) { - GELOGD("Start to build ref relations!"); - /* First Step: Get root graph */ - ge::ComputeGraph &root_graph = graph; - auto status = GetRootGraph(graph, root_graph); - if (status != GRAPH_SUCCESS) { - return status; - } - - for (const auto &node : graph.GetAllNodes()) { - const auto &node_type = node->GetType(); - const auto &op_desc = node->GetOpDesc(); - const auto &sub_graph_names = op_desc->GetSubgraphInstanceNames(); - if (sub_graph_names.empty()) { - continue; - } - std::vector graph_data_nodes; - std::vector netoutput_nodes; - // Get data and netoutput of sub_graph - GetDataAndNetoutputOfSubGraph(root_graph, graph_data_nodes, netoutput_nodes, sub_graph_names, node_type); - std::vector> classed_data_nodes; // resize according to ref_idx - std::vector>> classed_netoutput_nodes; // resize according to ref_idx - status = ProcessSubgraphDataNodes(graph_data_nodes, classed_data_nodes); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Process][SubgraphDataNodes] failed! ret:%d", status); - return status; - } - - // for netoutput - // check netoutput - // here main graph output number must be the same as every sub_graph netoutput node - // key: netoutput node_ptr , - status = ProcessSubgraphNetoutput(netoutput_nodes, classed_netoutput_nodes); - if (status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Process][SubgraphNetoutput] failed! ret:%d", status); - return status; - } - - std::vector> node_refs; - status = BuildRelationsWithFuncNodeType(node, classed_data_nodes, classed_netoutput_nodes, node_refs); - if (status != GRAPH_SUCCESS) { - GELOGE(status, "[Build][Relations] WithFuncNodeType Failed! Node is [%s]!", node->GetName().c_str()); - return status; - } - if (!node_refs.empty()) { - values_.push_back(node_refs); - } - } - - BuildRelationsForVariables(root_graph); - /* Seconde Step: generate map */ - status = BuildLookUpTables(); - if (status != GRAPH_SUCCESS) { - GELOGE(status, "[Build][LookUpTables] failed! ret:%d", status); - return status; - } - return GRAPH_SUCCESS; -} - -/* Ref Relations Interface */ -RefRelations::RefRelations() { - impl_ = MakeShared(); - if (impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "new impl failed."); - GELOGE(GRAPH_FAILED, "[New][Impl] MakeShared failed!"); - return; - } -} - -graphStatus RefRelations::LookUpRefRelations(const RefCell &key, std::unordered_set &result) { - GE_CHECK_NOTNULL(impl_); - return impl_->LookUpRefRelations(key, result); -} - -graphStatus RefRelations::BuildRefRelations(ge::ComputeGraph &graph) { - GE_CHECK_NOTNULL(impl_); - return impl_->BuildRefRelations(graph); -} - -graphStatus RefRelations::Clear() { - GE_CHECK_NOTNULL(impl_); - return impl_->Clear(); -} -} diff --git a/graph/refiner/shape_refiner.cc b/graph/refiner/shape_refiner.cc deleted file mode 100644 index ea572b774fab544f61fa00b6eda570692bcbb567..0000000000000000000000000000000000000000 --- a/graph/refiner/shape_refiner.cc +++ /dev/null @@ -1,1009 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/shape_refiner.h" - -#include -#include -#include -#include -#include -#include -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/graph_utils.h" - -#include "debug/ge_log.h" -#include "debug/ge_op_types.h" -#include "external/graph/operator_factory.h" -#include "graph/operator_factory_impl.h" -#include "graph/utils/node_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/tensor_utils.h" -#include "graph/utils/type_utils.h" -#include "graph/utils/op_desc_utils_ex.h" -#include "graph/utils/node_utils_ex.h" - -namespace ge { -namespace { -const char_t *const kPreOpInputShapeRange = "_pre_op_in_range"; - -const static std::set kDummyContextOpTypes{ "Enter", "Switch", "RefSwitch", "StackPush", "StackPop" }; -const static std::map kGeLocalOpMapping { - { "StreamMerge", "Merge" }, { "MemcpyAsync", "Identity" } -}; - -bool IsOpWithSubgraph(const NodePtr &node) { - const auto op_desc = node->GetOpDesc(); - if (op_desc == nullptr) { - return false; - } - const auto subgraph_name = op_desc->GetSubgraphInstanceNames(); - return !subgraph_name.empty(); -} - -graphStatus UpdateOutputForMultiBatch(const ConstNodePtr &node, - std::vector> &ref_out_tensors) { - // check sub_graph shape. Get max for update. - for (size_t i = 0UL; i < ref_out_tensors.size(); ++i) { - if (ref_out_tensors[i].empty()) { - continue; - } - - int64_t max_size = 0; - size_t max_shape_index = 0UL; - auto &ref_out_tensor = ref_out_tensors[i].at(0U); - for (size_t j = 0UL; j < ref_out_tensors[i].size(); ++j) { - auto &tensor = ref_out_tensors[i].at(j); - if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { - REPORT_INNER_ERR_MSG("E18888", "node[%s] does not support diff dtype among all ref output", - node->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype among all ref output", - node->GetName().c_str()); - return GRAPH_FAILED; - } - - const auto shape = tensor.MutableShape(); - int64_t size = 1; - for (const auto dim : shape.GetDims()) { - if ((dim != 0) && ((std::numeric_limits::max() / dim) < size)) { - REPORT_INNER_ERR_MSG("E18888", "The shape:%s size overflow, node:%s", shape.ToString().c_str(), - node->GetName().c_str()); - GELOGE(PARAM_INVALID, "[Check][Overflow] The shape size overflow"); - return PARAM_INVALID; - } - size *= dim; - } - - if (size > max_size) { - max_size = size; - max_shape_index = j; - } - } - - (void)node->GetOpDesc()->UpdateOutputDesc(static_cast(i), ref_out_tensors[i].at(max_shape_index)); - } - - return GRAPH_SUCCESS; -} - -graphStatus UpdateParentNodeForBranch(const ConstNodePtr &node, - std::vector> &ref_out_tensors) { - GELOGD("Enter update parent node shape for class branch op process"); - if (node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) { - return UpdateOutputForMultiBatch(node, ref_out_tensors); - } - - // check sub_graph shape.If not same ,do unknown shape process - for (size_t i = 0UL; i < ref_out_tensors.size(); i++) { - if (ref_out_tensors[i].empty()) { - continue; - } - auto &ref_out_tensor = ref_out_tensors[i].at(0U); - ge::GeShape &ref_out_tensor_shape = ref_out_tensor.MutableShape(); - for (auto &tensor : ref_out_tensors[i]) { - if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { - REPORT_INNER_ERR_MSG("E18888", "node[%s] does not support diff dtype among all ref output, shape:%s", - node->GetName().c_str(), ref_out_tensor_shape.ToString().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype output", node->GetName().c_str()); - return GRAPH_FAILED; - } - const auto shape = tensor.MutableShape(); - if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) { - GELOGD("node is %s, i : %zu, shape size: %" PRId64 ", ref_out_tensor_shape size: %" PRId64, - node->GetName().c_str(), i, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); - ref_out_tensor_shape = GeShape(UNKNOWN_RANK); - break; - } - for (size_t j = 0UL; j < ref_out_tensor_shape.GetDims().size(); j++) { - if (ref_out_tensor_shape.GetDim(j) != shape.GetDim(j)) { - GELOGD("node is %s, i : %zu, j: %zu ,shape size: %" PRId64 ", ref_out_tensor_shape size: %" PRId64, - node->GetName().c_str(), i, j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); - (void) ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM); - } - } - } - (void)node->GetOpDesc()->UpdateOutputDesc(static_cast(i), ref_out_tensor); - } - return GRAPH_SUCCESS; -} - -void SetShapeRangeForWhile(GeShape &data_shape, const GeShape &out_shape, bool &need_infer_again, - std::vector> &data_shape_range) { - for (size_t j = 0U; j < data_shape.GetDimNum(); ++j) { - if (data_shape.GetDim(j) != out_shape.GetDim(j)) { - if (data_shape.GetDim(j) != UNKNOWN_DIM) { - // if input data is fix shape, output is different, need_infer_again - need_infer_again = true; - } - (void) data_shape.SetDim(j, UNKNOWN_DIM); - } - // set shape rang of while, if dim is unknown ,set shape range as {0,-1} - if (data_shape.GetDim(j) == UNKNOWN_DIM) { - data_shape_range.emplace_back(std::make_pair(SHAPE_RANGE_LOWER_LIMIT, UNKNOWN_DIM)); - } else { - data_shape_range.emplace_back(std::make_pair(data_shape.GetDim(j), data_shape.GetDim(j))); - } - } -} - -graphStatus UpdateParentNodeForWhile(const ConstNodePtr &node, - std::vector> &ref_data_tensors, - std::vector> &ref_out_tensors) { - GELOGD("Enter update parent node shape for class while op process"); - if (ref_data_tensors.size() != ref_out_tensors.size()) { - REPORT_INNER_ERR_MSG("E18888", "op:%s(%s) input number[%zu] and output number[%zu] is not same!", - node->GetName().c_str(), node->GetType().c_str(), ref_data_tensors.size(), - ref_out_tensors.size()); - GELOGE(GRAPH_FAILED, "[Check][Param] while op [%s] input number[%zu] and output number[%zu] is not same!", - node->GetName().c_str(), ref_data_tensors.size(), ref_out_tensors.size()); - return GRAPH_FAILED; - } - for (size_t i = 0U; i < ref_data_tensors.size(); i++) { - if (ref_out_tensors[i].size() != 1U) { - REPORT_INNER_ERR_MSG("E18888", "while op, every output should only find one output tensor in all graph!"); - GELOGE(GRAPH_FAILED, "[Check][Param] while op, every output should only find one output tensor in all graph!"); - return GRAPH_FAILED; - } - } - bool need_infer_again = false; - // check input and output - for (size_t i = 0UL; i < ref_out_tensors.size(); i++) { - auto ref_out_tensor = ref_out_tensors[i].at(0U); - const auto out_shape = ref_out_tensor.MutableShape(); - // ref_i's data and output tensor shape should be same - for (auto &tensor : ref_data_tensors[i]) { - // if the input tensor shares multiple references, the ranges should ensure consistency - std::vector> data_shape_range; - if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { - REPORT_INNER_ERR_MSG("E18888", "node[%s] does not support diff dtype or format among all ref output", - node->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] does not support diff dtype or format output.", - node->GetName().c_str()); - return GRAPH_FAILED; - } - auto data_shape = tensor.MutableShape(); - // input is dynamic, here use dim_num - if (data_shape.GetDims() != out_shape.GetDims()) { - GELOGI("After infer, While %s %zu output shape [%s] is not match with input shape [%s].Need infer again.", - node->GetName().c_str(), i, out_shape.ToString().c_str(), data_shape.ToString().c_str()); - if (data_shape.GetDimNum() != out_shape.GetDimNum()) { - ref_out_tensor.SetUnknownDimNumShape(); - } else { - SetShapeRangeForWhile(data_shape, out_shape, need_infer_again, data_shape_range); - ref_out_tensor.SetShape(data_shape); - (void)ref_out_tensor.SetShapeRange(data_shape_range); - } - } - } - (void)node->GetOpDesc()->UpdateOutputDesc(static_cast(i), ref_out_tensor); - } - (void)AttrUtils::SetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, need_infer_again); - return GRAPH_SUCCESS; -} - -graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { - // if infer again, update output of while into subgraph data node - const auto op_desc = node->GetOpDesc(); - const auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); - if (sub_graph_names.empty()) { - return GRAPH_SUCCESS; - } - - const auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); - for (const auto &name : sub_graph_names) { - const auto sub_graph = root_graph->GetSubgraph(name); - if (sub_graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Can not find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); - GE_LOGE("[Get][Graph] can not find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); - return GRAPH_FAILED; - } - for (const auto &node_sub : sub_graph->GetDirectNode()) { - if (node_sub->GetType() != DATA) { - continue; - } - int32_t ref_i; - const auto data_opdesc = node_sub->GetOpDesc(); - if (data_opdesc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(), - node->GetName().c_str()); - GE_LOGE("[Get][OpDesc] Invalid data node on the sub graph %s parent node %s, no OpDesc", - name.c_str(), node->GetName().c_str()); - return GRAPH_FAILED; - } - if (!AttrUtils::GetInt(data_opdesc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { - REPORT_INNER_ERR_MSG("E18888", "Invalid data node on the sub graph %s parent node %s, no ref-index attribute", - name.c_str(), node->GetName().c_str()); - GE_LOGE("[Get][Int] Invalid data node on the sub graph %s parent node %s, no ref-index attribute", - name.c_str(), node->GetName().c_str()); - return GRAPH_FAILED; - } - if (data_opdesc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) { - continue; - } - auto input_desc = op_desc->MutableInputDesc(static_cast(ref_i)); - if (input_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", - "The ref index(%d) on the data %s on the sub graph %s " - "parent node %s are incompatible, inputs num %u", - ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), - node->GetAllInDataAnchorsSize()); - GE_LOGE("[Call][MutableInputDesc] The ref index(%d) on the data %s on the sub graph %s " - "parent node %s are incompatible, inputs num %u", ref_i, node_sub->GetName().c_str(), - name.c_str(), node->GetName().c_str(), node->GetAllInDataAnchorsSize()); - return GRAPH_FAILED; - } - GELOGI("Ref index is %d, input_desc dtype is %d, node name is %s", ref_i, input_desc->GetDataType(), - node->GetName().c_str()); - - // if need infer again, refresh subgraph input with output - bool is_infer_again = false; - (void)AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, is_infer_again); - if (is_infer_again) { - input_desc = op_desc->MutableOutputDesc(static_cast(ref_i)); - GE_CHECK_NOTNULL(input_desc, - "The ref index(%d) on the data %s on the subgraph %s " - "parent node %s are incompatible, outputs num %u.", - ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), - node->GetAllOutDataAnchorsSize()); - GELOGD("Update input desc of data %s on the sub graph %s of node %s,output idx: %d from [%s] to [%s]", - node_sub->GetName().c_str(), - name.c_str(), - node->GetName().c_str(), - ref_i, - data_opdesc->GetInputDescPtr(0U)->GetShape().ToString().c_str(), - input_desc->GetShape().ToString().c_str()); - } - - auto ret = data_opdesc->UpdateInputDesc(0U, *input_desc); - if (ret != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Failed to update input desc of data %s on the sub graph %s parent node %s", - node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); - GE_LOGE("[Update][InputDesc] of data %s on the sub graph %s parent node %s failed", - node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); - return ret; - } - ret = data_opdesc->UpdateOutputDesc(0U, *input_desc); - if (ret != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Failed to update output desc of data %s on the sub graph %s parent node %s", - node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); - GE_LOGE("[Update][OutputDesc] of data %s on the sub graph %s parent node %s failed", - node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); - return ret; - } - } - } - return GRAPH_SUCCESS; -} - -graphStatus FindSubgraphDataAndNetoutput(const std::shared_ptr &sub_graph, - NodePtr &netoutput, const ConstNodePtr &node, - std::vector> &ref_data_tensors) { - auto sub_nodes = sub_graph->GetDirectNode(); - for (size_t i = sub_nodes.size(); i > 0UL; --i) { - const auto sub_node = sub_nodes.at(i - 1UL); - if (sub_node->GetType() == NETOUTPUT) { - netoutput = sub_node; - } - if (sub_node->GetType() == DATA) { - if (sub_node->GetOpDesc() == nullptr) { - return GRAPH_FAILED; - } - - int32_t ref_i; - if (!AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { - REPORT_INNER_ERR_MSG("E18888", "subgraph data node[%s] has no parent node!", sub_node->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][Int] subgraph data node[%s] has no parent node!", sub_node->GetName().c_str()); - return GRAPH_FAILED; - } - if ((ref_i < 0) || (static_cast(ref_i) >= node->GetAllInDataAnchorsSize())) { - REPORT_INNER_ERR_MSG("E18888", "data node[%s]'s ref index[%d] is not in range [0, %u)!", - sub_node->GetName().c_str(), ref_i, node->GetAllInDataAnchorsSize()); - GELOGE(GRAPH_FAILED, "[Check][Param] data node[%s]'s ref index[%d] is not in range [0, %u)!", - sub_node->GetName().c_str(), ref_i, node->GetAllInDataAnchorsSize()); - return GRAPH_FAILED; - } - ref_data_tensors[static_cast(ref_i)].emplace_back(sub_node->GetOpDesc()->GetOutputDesc(0U)); - } - } - return GRAPH_SUCCESS; -} - -graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { - const auto op_desc = node->GetOpDesc(); - const auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); - if (sub_graph_names.empty()) { - return GRAPH_SUCCESS; - } - - std::vector> ref_data_tensors(static_cast(node->GetAllInDataAnchorsSize())); - std::vector> ref_out_tensors(static_cast(node->GetAllOutDataAnchorsSize())); - const auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); - - for (const auto &name : sub_graph_names) { - const auto sub_graph = root_graph->GetSubgraph(name); - if (sub_graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Can not find the subgraph %s for node %s", name.c_str(), node->GetName().c_str()); - GE_LOGE("[Get][Subgraph] Can not find the subgraph %s for node %s", name.c_str(), node->GetName().c_str()); - return GRAPH_FAILED; - } - NodePtr netoutput = nullptr; - const auto ret = FindSubgraphDataAndNetoutput(sub_graph, netoutput, node, ref_data_tensors); - if (ret != GRAPH_SUCCESS) { - return ret; - } - if (netoutput == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "No NetOutput node on sub graph %s, parent node %s", name.c_str(), - node->GetName().c_str()); - GE_LOGE("[Check][Param] No NetOutput node on sub graph %s, parent node %s", - name.c_str(), node->GetName().c_str()); - return GRAPH_FAILED; - } - const auto netoutput_opdesc = netoutput->GetOpDesc(); - if (netoutput_opdesc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", - name.c_str(), node->GetName().c_str()); - GE_LOGE("[Get][OpDesc] Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", - name.c_str(), node->GetName().c_str()); - return GRAPH_FAILED; - } - for (auto &edge_anchor : netoutput->GetAllInDataAnchors()) { - const auto edge_desc = netoutput_opdesc->MutableInputDesc(static_cast(edge_anchor->GetIdx())); - if (edge_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", - "Invalid NetOutput node on sub graph %s, parent node %s, " - "can not find input tensor %d", - name.c_str(), node->GetName().c_str(), edge_anchor->GetIdx()); - GE_LOGE("[Get][Tensor] Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d", - name.c_str(), node->GetName().c_str(), edge_anchor->GetIdx()); - return GRAPH_FAILED; - } - GELOGI("Netoutput in anchor index is %d, input tensor dim is %zu", - edge_anchor->GetIdx(), edge_desc->GetShape().GetDimNum()); - int32_t ref_i; - if (!AttrUtils::GetInt(edge_desc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { - // if there is no ref index on the TensorDesc, it means the output data will be ignored outer. - continue; - } - GELOGI("Parent node index of edge desc is %d", ref_i); - if ((ref_i < 0) || (static_cast(ref_i) >= node->GetAllOutDataAnchorsSize())) { - return GRAPH_FAILED; - } - ref_out_tensors[static_cast(ref_i)].emplace_back(*edge_desc); - } - } - - if (node->GetType() == WHILE) { - return UpdateParentNodeForWhile(node, ref_data_tensors, ref_out_tensors); - } - return UpdateParentNodeForBranch(node, ref_out_tensors); -} - -std::string Serial(const std::vector &dims) { - std::string serial_string; - serial_string += "["; - for (const int64_t dim : dims) { - serial_string += std::to_string(dim) + " "; - } - serial_string += "]"; - return serial_string; -} - -void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) { - desc_str += "["; - std::vector> shape_range; - (void)desc->GetShapeRange(shape_range); - for (const auto &pair : shape_range) { - desc_str += "{"; - desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second); - desc_str += "},"; - } - desc_str += "]"; - shape_range.clear(); - (void)desc->GetOriginShapeRange(shape_range); - for (const auto &pair : shape_range) { - desc_str += ",{"; - desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second); - desc_str += "},"; - } -} - -graphStatus UpdateOpInputDesc(const ConstNodePtr &node_ptr) { - GE_CHECK_NOTNULL(node_ptr); - GE_CHECK_NOTNULL(node_ptr->GetOpDesc()); - - for (const auto &in_anchor : node_ptr->GetAllInDataAnchors()) { - const auto in_idx = in_anchor->GetIdx(); - const auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor(); - if (peer_out_data_anchor == nullptr) { - continue; - } - const auto peer_out_data_node = peer_out_data_anchor->GetOwnerNodeBarePtr(); - if ((peer_out_data_node == nullptr) || (peer_out_data_node->GetOpDesc() == nullptr)) { - continue; - } - const int32_t peer_out_idx = peer_out_data_anchor->GetIdx(); - const auto peer_out_desc = peer_out_data_node->GetOpDesc()->MutableOutputDesc(static_cast(peer_out_idx)); - - // check shape and dtype continuity. do not stop process - const auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(static_cast(in_idx)); - if (in_desc == nullptr) { - continue; - } - const auto in_shape = in_desc->MutableShape().GetDims(); - const auto in_dtype = in_desc->GetDataType(); - const auto peer_out_shape = peer_out_desc->MutableShape().GetDims(); - const auto peer_out_dtype = peer_out_desc->GetDataType(); - if (peer_out_dtype != in_dtype) { - GELOGW("[Update][InputDesc] current node [%s] [%d]\'th in_dtype is [%s].peer output node [%s] [%d]\'th " - "output_dtype is [%s]. The two dtype should be same! Please check graph and fix it", - node_ptr->GetName().c_str(), in_idx, TypeUtils::DataTypeToSerialString(in_dtype).c_str(), - peer_out_data_node->GetName().c_str(), peer_out_idx, - TypeUtils::DataTypeToSerialString(peer_out_dtype).c_str()); - } else if ((!in_shape.empty()) && (in_shape != peer_out_shape)) { - const std::string in_shape_str = Serial(in_shape); - const std::string peer_out_shape_str = Serial(peer_out_shape); - GELOGW("[Update][InputDesc] current node [%s] [%d]\'th in_shape is [%s].peer output node [%s] [%d]\'th " - "output_shape is [%s]. The two shape should be same! Please check graph and fix it", - node_ptr->GetName().c_str(), in_idx, in_shape_str.c_str(), - peer_out_data_node->GetName().c_str(), peer_out_idx, peer_out_shape_str.c_str()); - } else { - // do nothing - } - // refresh current node input desc - in_desc->SetOriginShape(peer_out_desc->GetOriginShape()); - in_desc->SetShape(peer_out_desc->MutableShape()); - in_desc->SetDataType(peer_out_desc->GetDataType()); - in_desc->SetOriginDataType(peer_out_desc->GetOriginDataType()); - if (peer_out_desc->MutableShape().GetDims() != UNKNOWN_RANK) { - std::vector> shape_range; - (void)peer_out_desc->GetShapeRange(shape_range); - (void)in_desc->SetShapeRange(shape_range); - } - std::vector pre_op_in_range; - if (ge::AttrUtils::GetListInt(*peer_out_desc, kPreOpInputShapeRange, pre_op_in_range)) { - (void)ge::AttrUtils::SetListInt(*in_desc, kPreOpInputShapeRange, pre_op_in_range); - } - ge::TensorUtils::SetRealDimCnt(*in_desc, static_cast(peer_out_desc->MutableShape().GetDims().size())); - } - return GRAPH_SUCCESS; -} -} // namespace -void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) { - if (!IsLogEnable(GE, DLOG_DEBUG)) { - return; - } - const ge::OpDescPtr op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL_JUST_RETURN(op_desc); - std::stringstream ss; - ss << "{"; - int32_t in_idx = 0; - int32_t out_idx = 0; - for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { - if (input_desc == nullptr) { - in_idx++; - continue; - } - if (in_idx > 0) { - ss << " "; - } - ss << "input_" << in_idx << " " << "tensor: ["; - ss << "(shape:[" << input_desc->MutableShape().ToString() << "]),"; - ss << "(format:" << TypeUtils::FormatToSerialString(input_desc->GetFormat()) << "),"; - ss << "(dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) << "),"; - ss << "(origin_shape:" << input_desc->GetOriginShape().ToString() << "),"; - ss << "(origin_format:" << TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) << "),"; - ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) << "),"; - std::string range_str; - SerialShapeRange(input_desc, range_str); - ss << "(shape_range:" << range_str << ")]"; - in_idx++; - } - for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) { - if (output_desc == nullptr) { - out_idx++; - continue; - } - ss << " "; - ss << "output_" << out_idx << " " << "tensor: ["; - ss << "(shape:[" << output_desc->MutableShape().ToString() << "]),"; - ss << "(format:" << TypeUtils::FormatToSerialString(output_desc->GetFormat()) << "),"; - ss << "(dtype:" << TypeUtils::DataTypeToSerialString(output_desc->GetDataType()) << "),"; - ss << "(origin_shape:" << output_desc->GetOriginShape().ToString() << "),"; - ss << "(origin_format:" << TypeUtils::FormatToSerialString(output_desc->GetOriginFormat()) << "),"; - ss << "(origin_dtype:" << TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType()) << "),"; - std::string range_str; - SerialShapeRange(output_desc, range_str); - ss << "(shape_range:" << range_str << ")]"; - out_idx++; - } - ss << "}"; - GELOGD("Shape dump [%s], Node name[%s], type[%s]. %s", phase.c_str(), node->GetName().c_str(), - node->GetType().c_str(), ss.str().c_str()); -} - -namespace { -thread_local std::unordered_map context_map; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -void ShapeRefiner::ClearContextMap() { - context_map.clear(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -void ShapeRefiner::PushToContextMap(const NodePtr &node, const InferenceContextPtr &inference_context) { - (void)context_map.emplace(node, inference_context); -} - -static void GetRealOutNode(const OutDataAnchorPtr &peer_out_data_anchor, - std::stack> &node_to_indx_stack, - std::map &out_nodes) { - auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode(); - if (IsOpWithSubgraph(peer_out_data_node)) { - node_to_indx_stack.push(std::make_pair(peer_out_data_node, peer_out_data_anchor->GetIdx())); - } else if ((peer_out_data_node->GetType() == DATA) - && peer_out_data_node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX)) { - NodeToOutAnchor node_to_out_anchor = NodeUtils::GetParentInputAndAnchorCrossSubgraph(peer_out_data_node); - if ((node_to_out_anchor.first == nullptr) || (node_to_out_anchor.second == nullptr)) { - GELOGW("Get parent input node or anchor is nullptr."); - return; - } - if (IsOpWithSubgraph(node_to_out_anchor.first)) { - node_to_indx_stack.push(std::make_pair(node_to_out_anchor.first, node_to_out_anchor.second->GetIdx())); - } else { - (void)out_nodes.emplace(node_to_out_anchor.first, node_to_out_anchor.second->GetIdx()); - } - GELOGI("Ori peer node is:[%s][%s], change to real peer node:[%s][%s]", - peer_out_data_node->GetName().c_str(), peer_out_data_node->GetType().c_str(), - node_to_out_anchor.first->GetName().c_str(), node_to_out_anchor.first->GetType().c_str()); - } else { - (void)out_nodes.emplace(peer_out_data_node, peer_out_data_anchor->GetIdx()); - GELOGI("Peer node: %s, out index: %d.", peer_out_data_node->GetName().c_str(), peer_out_data_anchor->GetIdx()); - } - return; -} - -static Status GetOutNodesByParentNodeOutIndex(const NodePtr &parent_node, const int32_t out_idx, - std::map &out_nodes) { - out_nodes.clear(); - if (!IsOpWithSubgraph(parent_node)) { - return SUCCESS; - } - std::stack> node_to_indx_stack; - node_to_indx_stack.push(std::make_pair(parent_node, out_idx)); - while (!node_to_indx_stack.empty()) { - std::pair node_to_idx = node_to_indx_stack.top(); - node_to_indx_stack.pop(); - GELOGD("Node: %s, out index: %d.", node_to_idx.first->GetName().c_str(), node_to_idx.second); - const auto subgraph_output_nodes = NodeUtils::GetSubgraphOutputNodes(*(node_to_idx.first)); - for (const auto &netoutput : subgraph_output_nodes) { - GE_CHECK_NOTNULL(netoutput); - const auto output_desc = netoutput->GetOpDesc(); - GE_CHECK_NOTNULL(output_desc); - for (const auto &in_data_anchor : netoutput->GetAllInDataAnchors()) { - GE_CHECK_NOTNULL(in_data_anchor); - const auto in_desc = output_desc->MutableInputDesc(static_cast(in_data_anchor->GetIdx())); - GE_CHECK_NOTNULL(in_desc); - int32_t ref = 0; - if (AttrUtils::GetInt(in_desc, ATTR_NAME_PARENT_NODE_INDEX, ref) && (ref == node_to_idx.second)) { - const auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_out_data_anchor); - GetRealOutNode(peer_out_data_anchor, node_to_indx_stack, out_nodes); - } - } - } - } - return SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -graphStatus ShapeRefiner::GetRealInNodesAndIndex(NodePtr &input_node, int32_t &output_idx, - std::map &nodes_idx) { - auto op_desc = input_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - while ((input_node->GetType() == DATA) && (op_desc->HasAttr(ATTR_NAME_PARENT_NODE_INDEX))) { - int32_t ref_i = 0; - (void)AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, ref_i); - const auto owner_graph = input_node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(owner_graph); - const auto parent_node = owner_graph->GetParentNode(); - GE_CHECK_NOTNULL(parent_node); - const auto in_data_anchor = parent_node->GetInDataAnchor(ref_i); - GE_CHECK_NOTNULL(in_data_anchor); - const auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_out_data_anchor); - output_idx = peer_out_data_anchor->GetIdx(); - input_node = peer_out_data_anchor->GetOwnerNode(); - op_desc = input_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - GELOGD("In node[%s], type[%s], ref[%d].", input_node->GetName().c_str(), input_node->GetType().c_str(), ref_i); - } - - if (IsOpWithSubgraph(input_node)) { - if (GetOutNodesByParentNodeOutIndex(input_node, output_idx, nodes_idx) != SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Get outnodes of %s by parent node out index failed.", input_node->GetName().c_str()); - GELOGE(FAILED, "[Get][Outnodes] of %s by parent node out index failed.", input_node->GetName().c_str()); - return FAILED; - } - GELOGI("Out node num: %zu.", nodes_idx.size()); - } - if (nodes_idx.empty()) { - (void)nodes_idx.emplace(input_node, output_idx); - } - return SUCCESS; -} - - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -graphStatus ShapeRefiner::CreateInferenceContext(const NodePtr &node, InferenceContextPtr &inference_context) { - return CreateInferenceContext(node, nullptr, inference_context); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -graphStatus ShapeRefiner::CreateInferenceContext(const NodePtr &node, ResourceContextMgr *const resource_context_mgr, - InferenceContextPtr &inference_context) { - GE_CHECK_NOTNULL(node); - inference_context = std::shared_ptr(InferenceContext::Create(resource_context_mgr)); - GE_CHECK_NOTNULL(inference_context); - const auto all_in_data_anchors = node->GetAllInDataAnchors(); - std::vector> input_shapes_and_types(all_in_data_anchors.size()); - std::vector marks; - - bool has_input_shapes_and_types = false; - for (const auto &in_anchor : all_in_data_anchors) { - GE_CHECK_NOTNULL(in_anchor); - const auto out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr) { - continue; - } - - auto input_node = out_anchor->GetOwnerNode(); - auto out_idx = out_anchor->GetIdx(); - std::map input_nodes_2_out_idx; - if (GetRealInNodesAndIndex(input_node, out_idx, input_nodes_2_out_idx) != SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Failed to get real in nodes and index, node:%s", node->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][InNodesAndIndex] of node[%s] failed.", node->GetName().c_str()); - return GRAPH_FAILED; - } - - const auto input_idx = in_anchor->GetIdx(); - for (const auto &node_idx : input_nodes_2_out_idx) { - const auto in_node = node_idx.first; - GELOGD("Input node[%s], type[%s], context_map size[%zu].", in_node->GetName().c_str(), in_node->GetType().c_str(), - context_map.size()); - const auto iter = context_map.find(in_node); - if (iter != context_map.end()) { - const auto &src_context = iter->second; - GE_CHECK_NOTNULL(src_context); - std::vector src_marks; - src_context->GetMarks(src_marks); - GELOGD("node:%s get %zu marks from node:%s", - node->GetName().c_str(), src_marks.size(), in_node->GetName().c_str()); - for (const auto& mark : src_marks) { - if (marks.empty()) { - marks.emplace_back(mark); - } - } - const auto output_idx = node_idx.second; - const auto output_shape_and_type = src_context->GetOutputHandleShapesAndTypes(); - if (output_idx < static_cast(output_shape_and_type.size())) { - GELOGI("Add shape and type from %s:%d to %s:%d", in_node->GetName().c_str(), output_idx, - node->GetName().c_str(), input_idx); - input_shapes_and_types[static_cast(input_idx)] = - output_shape_and_type[static_cast(output_idx)]; - has_input_shapes_and_types = true; - } else { - GELOGI("[%s] Output out of range. index = %d, size = %zu", node->GetName().c_str(), output_idx, - output_shape_and_type.size()); - } - } - } - } - - if (has_input_shapes_and_types) { - inference_context->SetInputHandleShapesAndTypes(std::move(input_shapes_and_types)); - } - GELOGD("Node: %s, marks size: %zu.", node->GetName().c_str(), marks.size()); - inference_context->SetMarks(marks); - - return SUCCESS; -} - -graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op) { - return InferShapeAndType(node, op, true); -} - -graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op, const bool before_subgraph) { - const auto op_desc = node->GetOpDesc(); - const auto &op_type = op_desc->GetType(); - - graphStatus ret; - if (before_subgraph) { - ret = UpdateSubGraphDataNodes(node); - if (ret != GRAPH_SUCCESS) { - return ret; - } - } - // Get infer func and execute - ret = OpDescUtilsEx::CallInferFunc(op_desc, op); - if (ret == GRAPH_PARAM_INVALID) { - // Op ir no infer func, try to get infer func from operator factory - const auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType().c_str()); - if (node_op.IsEmpty()) { - GELOGW("[InferShape][Check] Get op from OperatorFactory failed, type: %s", op_type.c_str()); - return ret; - } - - GELOGD("get op from OperatorFactory success. opType: %s", op_type.c_str()); - const auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); - node_op.BreakConnect(); - if (temp_op_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "GetOpDescFromOperator failed, return nullptr."); - GELOGE(GRAPH_FAILED, "[Get][OpDesc] temp op desc is null"); - return GRAPH_FAILED; - } - if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) { - GELOGW("[InferShape][UpdateInputName] Update input name failed"); - for (const auto &out_desc : op_desc->GetAllOutputsDescPtr()) { - if ((out_desc != nullptr) && out_desc->GetShape().GetDims().empty()) { - break; - } - return GRAPH_SUCCESS; - } - } - if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) { - GELOGW("[InferShape][UpdateOutputName] Update output name failed"); - } - op_desc->AddInferFunc(temp_op_desc->GetInferFunc()); - ret = OpDescUtilsEx::CallInferFunc(op_desc, op); - GELOGI("op CallInferFunc second. ret: %u", ret); - } - if (ret != GRAPH_SUCCESS) { - return ret; - } - - if (!before_subgraph) { - return UpdateParentNodeOutTensor(node); - } - return GRAPH_SUCCESS; -} - -graphStatus ShapeRefiner::DoInferShapeAndTypeForRunning(const ConstNodePtr &node, Operator &op, - const bool before_subgraph) { - const auto op_desc = node->GetOpDesc(); - const auto origin_type = NodeUtils::GetNodeType(*node); - - graphStatus ret; - if (before_subgraph) { - ret = UpdateSubGraphDataNodes(node); - if (ret != GRAPH_SUCCESS) { - return ret; - } - } - - // Create InferenceContext to avoid null pointer access. - if (kDummyContextOpTypes.count(origin_type) > 0U) { - GELOGD("Set InferenceContext for node [%s]", op_desc->GetName().c_str()); - op.SetInferenceContext(std::shared_ptr(InferenceContext::Create())); - } - - // Get infer func and execute - ret = OpDescUtilsEx::CallInferFunc(op_desc, op); - if (ret == GRAPH_PARAM_INVALID) { - GELOGD("NodeUtils::GetNodeType return value is: [%s]", origin_type.c_str()); - const auto it = kGeLocalOpMapping.find(origin_type); - const auto infer_func = - OperatorFactoryImpl::GetInferShapeFunc((it == kGeLocalOpMapping.end()) ? origin_type : it->second); - if (infer_func == nullptr) { - REPORT_INNER_ERR_MSG("E18888", - "Failed to Get InferFunc. Reason: ASCEND_OPP_PATH is not set or it's invalid;" - " Or the infer func of %s is not registered.", - origin_type.c_str()); - GELOGE(GRAPH_FAILED, "[Get][InferFunc] failed. type is %s", origin_type.c_str()); - return GRAPH_FAILED; - } - op_desc->AddInferFunc(infer_func); - ret = OpDescUtilsEx::CallInferFunc(op_desc, op); - GELOGI("op CallInferFunc second. ret: %u", ret); - } - if (ret != GRAPH_SUCCESS) { - return ret; - } - - if (!before_subgraph) { - return UpdateParentNodeOutTensor(node); - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node) { - return InferShapeAndType(node, true); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -graphStatus ShapeRefiner::InferShapeAndTypeForRunning(const NodePtr &node, Operator &op, const bool before_subgraph) { - GE_CHECK_NOTNULL(node); - const auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - - std::vector temp_dtype; - for (auto &tensor_desc: op_desc->GetAllOutputsDescPtr()) { - temp_dtype.emplace_back(tensor_desc->GetDataType()); - } - PrintInOutTensorShape(node, "before_infershape when running"); - - const graphStatus status = DoInferShapeAndTypeForRunning(node, op, before_subgraph); - if ((status == GRAPH_PARAM_INVALID) || (status == GRAPH_SUCCESS)) { - // ensure the dtype is not changed after infershape in running - const auto after_opdesc = node->GetOpDesc(); - GE_CHECK_NOTNULL(after_opdesc); - auto all_output_tensor = after_opdesc->GetAllOutputsDescPtr(); - for (size_t i = 0UL; i < all_output_tensor.size(); ++i) { - auto &output_tensor = all_output_tensor.at(i); - if (output_tensor->GetDataType() != temp_dtype[i]) { - GELOGD("Op %s output %zu need reset dtype,original dtype is %s, new dtype is %s", - node->GetName().c_str(), i, - TypeUtils::DataTypeToSerialString(output_tensor->GetDataType()).c_str(), - TypeUtils::DataTypeToSerialString(temp_dtype[i]).c_str()); - output_tensor->SetDataType(temp_dtype[i]); - } - } - PrintInOutTensorShape(node, "after_infershape when running"); - return GRAPH_SUCCESS; - } else { - REPORT_INNER_ERR_MSG("EZ8888", "%s(%s) call infer function failed.", node->GetName().c_str(), - node->GetType().c_str()); - GELOGE(GRAPH_FAILED, "[Call][InferFunction] failed, node:%s(%s).", - node->GetName().c_str(), node->GetType().c_str()); - return GRAPH_FAILED; - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -graphStatus ShapeRefiner::UpdateInputOutputDesc(const NodePtr &node) { - GE_CHECK_NOTNULL(node); - const auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - for (const auto &out_anchor : node->GetAllOutDataAnchors()) { - auto const output_tensor = op_desc->MutableOutputDesc(static_cast(out_anchor->GetIdx())); - GE_IF_BOOL_EXEC(output_tensor == nullptr, continue); - GE_IF_BOOL_EXEC(output_tensor->MutableShape().GetDims().empty(), - output_tensor->SetOriginShape(output_tensor->GetShape())); - - ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast(output_tensor->GetOriginShape().GetDims() - .size())); - output_tensor->SetOriginDataType(output_tensor->GetDataType()); - // set output origin shape range - std::vector> range; - (void)output_tensor->GetShapeRange(range); - (void)output_tensor->SetOriginShapeRange(range); - GELOGD("node name is %s, origin shape is %" PRId64 ", origin format is %s, origin data type is %s", - node->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(), - TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), - TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); - } - for (const auto &in_anchor : node->GetAllInDataAnchors()) { - const auto input_tensor = op_desc->MutableInputDesc(static_cast(in_anchor->GetIdx())); - GE_IF_BOOL_EXEC(input_tensor == nullptr, continue); - - // set input origin shape range - std::vector> range; - (void)input_tensor->GetShapeRange(range); - (void)input_tensor->SetOriginShapeRange(range); - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -graphStatus ShapeRefiner::PostProcessAfterInfershape(const NodePtr &node, const Operator &op, - const bool is_unknown_graph) { - GE_CHECK_NOTNULL(node); - if (is_unknown_graph) { - PrintInOutTensorShape(node, "after_infershape when running"); - return GRAPH_SUCCESS; - } - - if (UpdateInputOutputDesc(node) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Update input and output desc of %s failed.", node->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Update][TensorDesc] Update input and output desc of %s failed.", node->GetName().c_str()); - return GRAPH_FAILED; - } - - if (!is_unknown_graph) { - auto ctx_after_infer = op.GetInferenceContext(); - if (ctx_after_infer != nullptr) { - std::vector marks; - ctx_after_infer->GetMarks(marks); - GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), marks.size()); - if ((!ctx_after_infer->GetOutputHandleShapesAndTypes().empty()) || (!marks.empty())) { - GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(), - marks.size()); - (void)context_map.emplace(node, ctx_after_infer); - } - } - } - PrintInOutTensorShape(node, "after_infershape"); - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node, const bool before_subgraph) { - GE_CHECK_NOTNULL(node); - const bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag(); - const auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - // some op can not infershape twice such as aipp - const bool need_update_input = (!is_unknown_graph) && (!op_desc->HasAttr("has_infered_verified")); - if (need_update_input) { - const auto status = UpdateOpInputDesc(node); - if (status != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "update op input_desc failed! ret:%u, node:%s", status, node->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Update][OpInputDesc] failed! ret:%u", status); - return status; - } - } - - if (NodeUtilsEx::Verify(node) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("EZ8888", "Verifying %s(%s) failed.", node->GetName().c_str(), node->GetType().c_str()); - GELOGE(GRAPH_FAILED, "[Call][Verify] Verifying %s(%s) failed.", node->GetName().c_str(), node->GetType().c_str()); - return GRAPH_FAILED; - } - PrintInOutTensorShape(node, "before_infershape"); - Operator op = OpDescUtils::CreateOperatorFromNode(node); // do not need runtime context - - if (!is_unknown_graph) { - InferenceContextPtr inference_context; - if (CreateInferenceContext(node, inference_context) != SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "CreateInferenceContext of %s failed.", node->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Create][Context] CreateInferenceContext of %s failed.", node->GetName().c_str()); - return GRAPH_FAILED; - } - GE_CHECK_NOTNULL(inference_context); - std::vector marks; - inference_context->GetMarks(marks); - GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), marks.size()); - op.SetInferenceContext(inference_context); - } - - const graphStatus status = InferShapeAndType(node, op, before_subgraph); - const bool check_status_valid = (status == GRAPH_PARAM_INVALID) || (status == GRAPH_SUCCESS); - if (!check_status_valid) { - REPORT_INNER_ERR_MSG("EZ8888", "%s(%s) call infer function failed.", node->GetName().c_str(), - node->GetType().c_str()); - GELOGE(GRAPH_FAILED, "[Call][InferFunction] failed, node:%s(%s).", - node->GetName().c_str(), node->GetType().c_str()); - return GRAPH_FAILED; - } - - return PostProcessAfterInfershape(node, op, is_unknown_graph); -} -} // namespace ge diff --git a/graph/serialization/attr_serializer.cc b/graph/serialization/attr_serializer.cc deleted file mode 100644 index 5dc85f32b72610436ac026dd804681527fac5361..0000000000000000000000000000000000000000 --- a/graph/serialization/attr_serializer.cc +++ /dev/null @@ -1,13 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "attr_serializer.h" - -namespace ge { -} diff --git a/graph/serialization/attr_serializer.h b/graph/serialization/attr_serializer.h deleted file mode 100644 index 0343b5413c8ad14f63180a3c83bc5ddca01ccc95..0000000000000000000000000000000000000000 --- a/graph/serialization/attr_serializer.h +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_ATTR_SERIALIZER_H -#define METADEF_CXX_ATTR_SERIALIZER_H - -#include "proto/ge_ir.pb.h" - -#include "graph/any_value.h" - -namespace ge { -/** - * 所有的serializer都应该是无状态的、可并发调用的,全局仅构造一份,后续多线程并发调用 - */ -class GeIrAttrSerializer { - public: - virtual graphStatus Serialize(const AnyValue &av, proto::AttrDef &def) = 0; - virtual graphStatus Deserialize(const proto::AttrDef &def, AnyValue &av) = 0; - virtual ~GeIrAttrSerializer() = default; - GeIrAttrSerializer() = default; - GeIrAttrSerializer(const GeIrAttrSerializer &) = delete; - GeIrAttrSerializer &operator=(const GeIrAttrSerializer &) = delete; - GeIrAttrSerializer(GeIrAttrSerializer &&) = delete; - GeIrAttrSerializer &operator=(GeIrAttrSerializer &&) = delete; -}; -} // namespace ge - -#endif // METADEF_CXX_ATTR_SERIALIZER_H diff --git a/graph/serialization/attr_serializer_registry.cc b/graph/serialization/attr_serializer_registry.cc deleted file mode 100644 index bb9b830440ddd7d32397141229b0f6eaadfc4328..0000000000000000000000000000000000000000 --- a/graph/serialization/attr_serializer_registry.cc +++ /dev/null @@ -1,109 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "attr_serializer_registry.h" - -#include "graph/serialization/bool_serializer.h" -#include "graph/serialization/buffer_serializer.h" -#include "graph/serialization/data_type_serializer.h" -#include "graph/serialization/float_serializer.h" -#include "graph/serialization/graph_serializer.h" -#include "graph/serialization/int_serializer.h" -#include "graph/serialization/list_list_float_serializer.h" -#include "graph/serialization/list_list_int_serializer.h" -#include "graph/serialization/list_value_serializer.h" -#include "graph/serialization/named_attrs_serializer.h" -#include "graph/serialization/string_serializer.h" -#include "graph/serialization/tensor_desc_serializer.h" -#include "graph/serialization/tensor_serializer.h" - -#include "common/ge_common/debug/ge_log.h" -#include "graph/debug/ge_log.h" - -namespace ge { -REG_GEIR_SERIALIZER(attr_bool, BoolSerializer, GetTypeId(), proto::AttrDef::kB); -REG_GEIR_SERIALIZER(attr_buffer, BufferSerializer, GetTypeId(), proto::AttrDef::kBt); -REG_GEIR_SERIALIZER(attr_data_type, DataTypeSerializer, GetTypeId(), proto::AttrDef::kDt); -REG_GEIR_SERIALIZER(attr_float, FloatSerializer, GetTypeId(), proto::AttrDef::kF); -REG_GEIR_SERIALIZER(attr_graph, GraphSerializer, GetTypeId(), proto::AttrDef::kG); -REG_GEIR_SERIALIZER(attr_int, IntSerializer, GetTypeId(), proto::AttrDef::kI); -REG_GEIR_SERIALIZER(attr_list, ListListFloatSerializer, - GetTypeId>>(), proto::AttrDef::kListListFloat); -REG_GEIR_SERIALIZER(attr_list_list_int, ListListIntSerializer, - GetTypeId>>(), proto::AttrDef::kListListInt); -REG_GEIR_SERIALIZER(attr_list_int, ListValueSerializer, GetTypeId>(), proto::AttrDef::kList); -REG_GEIR_SERIALIZER(attr_list_str, ListValueSerializer, GetTypeId>(), proto::AttrDef::kList); -REG_GEIR_SERIALIZER(attr_list_float, ListValueSerializer, GetTypeId>(), proto::AttrDef::kList); -REG_GEIR_SERIALIZER(attr_list_bool, ListValueSerializer, GetTypeId>(), proto::AttrDef::kList); -REG_GEIR_SERIALIZER(attr_list_tensor_desc, ListValueSerializer, - GetTypeId>(), proto::AttrDef::kList); -REG_GEIR_SERIALIZER(attr_list_tensor, ListValueSerializer, GetTypeId>(), proto::AttrDef::kList); -REG_GEIR_SERIALIZER(attr_list_buffer, ListValueSerializer, GetTypeId>(), proto::AttrDef::kList); -REG_GEIR_SERIALIZER(attr_list_graph_def, ListValueSerializer, - GetTypeId>(), proto::AttrDef::kList); -REG_GEIR_SERIALIZER(attr_list_named_attrs, ListValueSerializer, - GetTypeId>(), proto::AttrDef::kList); -REG_GEIR_SERIALIZER(attr_list_data_type, ListValueSerializer, - GetTypeId>(), proto::AttrDef::kList); -REG_GEIR_SERIALIZER(attr_named_attrs, NamedAttrsSerializer, GetTypeId(), proto::AttrDef::kFunc); -REG_GEIR_SERIALIZER(attr_str, StringSerializer, GetTypeId(), proto::AttrDef::kS); -REG_GEIR_SERIALIZER(attr_tensor_desc, TensorDescSerializer, GetTypeId(), proto::AttrDef::kTd); -REG_GEIR_SERIALIZER(attr_tensor, TensorSerializer, GetTypeId(), proto::AttrDef::kT); - -AttrSerializerRegistry &AttrSerializerRegistry::GetInstance() { - static AttrSerializerRegistry instance; - return instance; -} - -void AttrSerializerRegistry::RegisterGeIrAttrSerializer(const GeIrAttrSerializerBuilder& builder, - const TypeId obj_type, - const proto::AttrDef::ValueCase proto_type) { - const std::lock_guard lck_guard(mutex_); - if (serializer_map_.count(obj_type) > 0U) { - return; - } - std::unique_ptr serializer = builder(); - serializer_map_[obj_type] = serializer.get(); - deserializer_map_[proto_type] = serializer.get(); - serializer_holder_.push_back(std::move(serializer)); -} - -GeIrAttrSerializer *AttrSerializerRegistry::GetSerializer(const TypeId obj_type) { - const std::map::const_iterator iter = serializer_map_.find(obj_type); - if (iter == serializer_map_.cend()) { - // print type - REPORT_INNER_ERR_MSG("E18888", "Serializer for type has not been registered"); - GELOGE(FAILED, "Serializer for type has not been registered"); - return nullptr; - } - return iter->second; -} - -GeIrAttrSerializer *AttrSerializerRegistry::GetDeserializer(const proto::AttrDef::ValueCase proto_type) { - const std::map::const_iterator iter = - deserializer_map_.find(proto_type); - if (iter == deserializer_map_.cend()) { - REPORT_INNER_ERR_MSG("E18888", "Deserializer for type [%d] has not been registered", - static_cast(proto_type)); - GELOGE(FAILED, "Deserializer for type [%d] has not been registered", static_cast(proto_type)); - return nullptr; - } - return iter->second; -} - -AttrSerializerRegistrar::AttrSerializerRegistrar(const GeIrAttrSerializerBuilder builder, - const TypeId obj_type, - const proto::AttrDef::ValueCase proto_type) noexcept { - if (builder == nullptr) { - GELOGE(FAILED, "SerializerBuilder is nullptr."); - return; - } - AttrSerializerRegistry::GetInstance().RegisterGeIrAttrSerializer(builder, obj_type, proto_type); -} -} diff --git a/graph/serialization/attr_serializer_registry.h b/graph/serialization/attr_serializer_registry.h deleted file mode 100644 index 30d94f99cff7452cb2eff7ccec68d88bac7b9635..0000000000000000000000000000000000000000 --- a/graph/serialization/attr_serializer_registry.h +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_ATTR_SERIALIZER_REGISTRY_H -#define METADEF_CXX_ATTR_SERIALIZER_REGISTRY_H -#include -#include -#include -#include - -#include "graph/type_utils.h" -#include "attr_serializer.h" - -#define REG_GEIR_SERIALIZER(serializer_name, cls, obj_type, bin_type) \ - REG_GEIR_SERIALIZER_BUILDER_UNIQ_HELPER(serializer_name, __COUNTER__, cls, obj_type, bin_type) - -#define REG_GEIR_SERIALIZER_BUILDER_UNIQ_HELPER(name, ctr, cls, obj_type, bin_type) \ - REG_GEIR_SERIALIZER_BUILDER_UNIQ(name, ctr, cls, obj_type, bin_type) - -#define REG_GEIR_SERIALIZER_BUILDER_UNIQ(name, ctr, cls, obj_type, bin_type) \ - static ::ge::AttrSerializerRegistrar register_serialize_##name##ctr \ - __attribute__((unused)) = \ - ::ge::AttrSerializerRegistrar([]()->std::unique_ptr{ \ - return std::unique_ptr(new(std::nothrow)cls()); \ - }, obj_type, bin_type) - -namespace ge { -using GeIrAttrSerializerBuilder = std::function()>; -class AttrSerializerRegistry { - public: - AttrSerializerRegistry(const AttrSerializerRegistry &) = delete; - AttrSerializerRegistry(AttrSerializerRegistry &&) = delete; - AttrSerializerRegistry &operator=(const AttrSerializerRegistry &) = delete; - AttrSerializerRegistry &operator=(AttrSerializerRegistry &&) = delete; - - ~AttrSerializerRegistry() = default; - - static AttrSerializerRegistry &GetInstance(); - /** - * 注册一个GE IR的序列化、反序列化handler - * @param builder 调用该builder时,返回一个handler的实例 - * @param obj_type 内存中的数据类型,可以通过`GetTypeId`函数获得 - * @param proto_type protobuf数据类型枚举值 - */ - void RegisterGeIrAttrSerializer(const GeIrAttrSerializerBuilder &builder, - const TypeId obj_type, - const proto::AttrDef::ValueCase proto_type); - - GeIrAttrSerializer *GetSerializer(const TypeId obj_type); - GeIrAttrSerializer *GetDeserializer(const proto::AttrDef::ValueCase proto_type); - - private: - AttrSerializerRegistry() = default; - - std::mutex mutex_; - std::vector> serializer_holder_; - std::map serializer_map_; - std::map deserializer_map_; -}; - -class AttrSerializerRegistrar { - public: - AttrSerializerRegistrar(const GeIrAttrSerializerBuilder builder, - const TypeId obj_type, - const proto::AttrDef::ValueCase proto_type) noexcept; - ~AttrSerializerRegistrar() = default; -}; -} // namespace ge - -#endif // METADEF_CXX_ATTR_SERIALIZER_REGISTRY_H diff --git a/graph/serialization/bool_serializer.cc b/graph/serialization/bool_serializer.cc deleted file mode 100644 index c295408a8cd865b8ebfc0aa85fa28c9c7b7df3c3..0000000000000000000000000000000000000000 --- a/graph/serialization/bool_serializer.cc +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "bool_serializer.h" -#include -#include "proto/ge_ir.pb.h" -#include "graph/debug/ge_log.h" - -namespace ge { -graphStatus BoolSerializer::Serialize(const AnyValue &av, proto::AttrDef &def) { - bool val; - const graphStatus ret = av.GetValue(val); - if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "Failed to get bool attr."); - return GRAPH_FAILED; - } - def.set_b(val); - return GRAPH_SUCCESS; -} - -graphStatus BoolSerializer::Deserialize(const proto::AttrDef &def, AnyValue &av) { - return av.SetValue(def.b()); -} - -REG_GEIR_SERIALIZER(bool_serializer, BoolSerializer, GetTypeId(), proto::AttrDef::kB); -} // namespace ge diff --git a/graph/serialization/bool_serializer.h b/graph/serialization/bool_serializer.h deleted file mode 100644 index 794d6535f9a27e556ffdbf7d2ce738fa15134771..0000000000000000000000000000000000000000 --- a/graph/serialization/bool_serializer.h +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_GRAPH_SERIALIZATION_BOOL_SERIALIZER_H_ -#define METADEF_GRAPH_SERIALIZATION_BOOL_SERIALIZER_H_ - -#include "attr_serializer.h" -#include "attr_serializer_registry.h" -namespace ge { -class BoolSerializer : public GeIrAttrSerializer { - public: - BoolSerializer() = default; - graphStatus Serialize(const AnyValue &av, proto::AttrDef &def) override; - graphStatus Deserialize(const proto::AttrDef &def, AnyValue &av) override; -}; -} // namespace ge - -#endif // METADEF_GRAPH_SERIALIZATION_BOOL_SERIALIZER_H_ diff --git a/graph/serialization/buffer_serializer.cc b/graph/serialization/buffer_serializer.cc deleted file mode 100644 index 6d4beced7cbcbdc970ae028d344244a3050d91aa..0000000000000000000000000000000000000000 --- a/graph/serialization/buffer_serializer.cc +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "buffer_serializer.h" -#include -#include "proto/ge_ir.pb.h" -#include "graph/buffer.h" -#include "graph/debug/ge_log.h" - -namespace ge { -graphStatus BufferSerializer::Serialize(const AnyValue &av, proto::AttrDef &def) { - Buffer val; - const graphStatus ret = av.GetValue(val); - if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "Failed to get buffer attr."); - return GRAPH_FAILED; - } - if ((val.data()!= nullptr) && (val.size() > 0U)) { - def.set_bt(val.GetData(), val.GetSize()); - } - return GRAPH_SUCCESS; -} - -graphStatus BufferSerializer::Deserialize(const proto::AttrDef &def, AnyValue &av) { - Buffer buffer = Buffer::CopyFrom(reinterpret_cast(def.bt().data()), def.bt().size()); - return av.SetValue(std::move(buffer)); -} - -REG_GEIR_SERIALIZER(buffer_serializer, BufferSerializer, GetTypeId(), proto::AttrDef::kBt); -} // namespace ge diff --git a/graph/serialization/buffer_serializer.h b/graph/serialization/buffer_serializer.h deleted file mode 100644 index 1930da28bf8d07b57d31d479bb26ba99d4847f8e..0000000000000000000000000000000000000000 --- a/graph/serialization/buffer_serializer.h +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_GRAPH_SERIALIZATION_BUFFER_SERIALIZER_H_ -#define METADEF_GRAPH_SERIALIZATION_BUFFER_SERIALIZER_H_ - -#include "attr_serializer.h" -#include "attr_serializer_registry.h" -namespace ge { -class BufferSerializer : public GeIrAttrSerializer { - public: - BufferSerializer() = default; - graphStatus Serialize(const AnyValue &av, proto::AttrDef &def) override; - graphStatus Deserialize(const proto::AttrDef &def, AnyValue &av) override; -}; -} // namespace ge - -#endif // METADEF_GRAPH_SERIALIZATION_BUFFER_SERIALIZER_H_ diff --git a/graph/serialization/data_type_serializer.cc b/graph/serialization/data_type_serializer.cc deleted file mode 100644 index 08d141efa79faab56092686217258c6b0619fb4f..0000000000000000000000000000000000000000 --- a/graph/serialization/data_type_serializer.cc +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "data_type_serializer.h" -#include "proto/ge_ir.pb.h" -#include "graph/debug/ge_log.h" -#include "graph/types.h" - -namespace ge { -graphStatus DataTypeSerializer::Serialize(const AnyValue &av, proto::AttrDef &def) { - ge::DataType value; - const graphStatus ret = av.GetValue(value); - if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "Failed to get datatype attr."); - return GRAPH_FAILED; - } - def.set_dt(static_cast(value)); - return GRAPH_SUCCESS; -} - -graphStatus DataTypeSerializer::Deserialize(const proto::AttrDef &def, AnyValue &av) { - return av.SetValue(static_cast(def.dt())); -} - -REG_GEIR_SERIALIZER(data_type_serializer, DataTypeSerializer, GetTypeId(), proto::AttrDef::kDt); -} // namespace ge diff --git a/graph/serialization/data_type_serializer.h b/graph/serialization/data_type_serializer.h deleted file mode 100644 index a39721961a65658b442a83b060c55bf767dafdfb..0000000000000000000000000000000000000000 --- a/graph/serialization/data_type_serializer.h +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_GRAPH_SERIALIZATION_DATA_TYPE_SERIALIZER_H_ -#define METADEF_GRAPH_SERIALIZATION_DATA_TYPE_SERIALIZER_H_ - -#include "attr_serializer.h" -#include "attr_serializer_registry.h" - -namespace ge { -class DataTypeSerializer : public GeIrAttrSerializer { - public: - DataTypeSerializer() = default; - graphStatus Serialize(const AnyValue &av, proto::AttrDef &def) override; - graphStatus Deserialize(const proto::AttrDef &def, AnyValue &av) override; -}; -} // namespace ge - -#endif // METADEF_GRAPH_SERIALIZATION_DATA_TYPE_SERIALIZER_H_ diff --git a/graph/serialization/float_serializer.cc b/graph/serialization/float_serializer.cc deleted file mode 100644 index aafc4edec527ecaa5d4827e0887fb6760fb3eec0..0000000000000000000000000000000000000000 --- a/graph/serialization/float_serializer.cc +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "float_serializer.h" -#include "proto/ge_ir.pb.h" -#include "graph/debug/ge_log.h" -#include "graph/types.h" - -namespace ge { -graphStatus FloatSerializer::Serialize(const AnyValue &av, proto::AttrDef &def) { - float32_t val; - const graphStatus ret = av.GetValue(val); - if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "Failed to get float attr."); - return GRAPH_FAILED; - } - def.set_f(val); - return GRAPH_SUCCESS; -} - -graphStatus FloatSerializer::Deserialize(const proto::AttrDef &def, AnyValue &av) { - return av.SetValue(def.f()); -} - -REG_GEIR_SERIALIZER(float_serializer, FloatSerializer, GetTypeId(), proto::AttrDef::kF); -} // namespace ge diff --git a/graph/serialization/float_serializer.h b/graph/serialization/float_serializer.h deleted file mode 100644 index e1cd1143dacf4f2e36cdd9518337a955651e431f..0000000000000000000000000000000000000000 --- a/graph/serialization/float_serializer.h +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_GRAPH_SERIALIZATION_FLOAT_SERIALIZER_H_ -#define METADEF_GRAPH_SERIALIZATION_FLOAT_SERIALIZER_H_ - -#include "attr_serializer.h" -#include "attr_serializer_registry.h" -namespace ge { -class FloatSerializer : public GeIrAttrSerializer { - public: - FloatSerializer() = default; - graphStatus Serialize(const AnyValue &av, proto::AttrDef &def) override; - graphStatus Deserialize(const proto::AttrDef &def, AnyValue &av) override; -}; -} // namespace ge - -#endif // METADEF_GRAPH_SERIALIZATION_FLOAT_SERIALIZER_H_ diff --git a/graph/serialization/graph_serializer.cc b/graph/serialization/graph_serializer.cc deleted file mode 100644 index b72db69ae1190ee6e1f4db2a4863b23678937cb4..0000000000000000000000000000000000000000 --- a/graph/serialization/graph_serializer.cc +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph_serializer.h" -#include "graph/debug/ge_util.h" -#include "graph/debug/ge_log.h" -#include "graph/detail/model_serialize_imp.h" - -namespace ge { -graphStatus GraphSerializer::Serialize(const AnyValue &av, proto::AttrDef &def) { - const auto graph = def.mutable_g(); - GE_CHECK_NOTNULL(graph); - - if (av.GetValue(*graph) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Serialize graph failed"); - return GRAPH_FAILED; - } - - return GRAPH_SUCCESS; -} - -graphStatus GraphSerializer::Deserialize(const proto::AttrDef &def, AnyValue &av) { - return av.SetValue(def.g()); -} - -REG_GEIR_SERIALIZER(graph_serializer, GraphSerializer, GetTypeId(), proto::AttrDef::kG); -} // namespace ge diff --git a/graph/serialization/graph_serializer.h b/graph/serialization/graph_serializer.h deleted file mode 100644 index 8aa92467dc181aa9404bb0aa0b2a6cc3df583d33..0000000000000000000000000000000000000000 --- a/graph/serialization/graph_serializer.h +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_GRAPH_SERIALIZATION_GRAPH_SERIALIZER_H_ -#define METADEF_GRAPH_SERIALIZATION_GRAPH_SERIALIZER_H_ - -#include "attr_serializer.h" -#include "attr_serializer_registry.h" -#include "graph/compute_graph.h" -#include "proto/ge_ir.pb.h" -namespace ge { -class GraphSerializer : public GeIrAttrSerializer { - public: - GraphSerializer() = default; - graphStatus Serialize(const AnyValue &av, proto::AttrDef &def) override; - graphStatus Deserialize(const proto::AttrDef &def, AnyValue &av) override; -}; -} // namespace ge - -#endif // METADEF_GRAPH_SERIALIZATION_GRAPH_SERIALIZER_H_ diff --git a/graph/serialization/int_serializer.cc b/graph/serialization/int_serializer.cc deleted file mode 100644 index 58ef10106483567d63d75b4f8024d6b7f6511c92..0000000000000000000000000000000000000000 --- a/graph/serialization/int_serializer.cc +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "int_serializer.h" -#include "proto/ge_ir.pb.h" -#include "graph/debug/ge_log.h" - -namespace ge { -graphStatus IntSerializer::Serialize(const AnyValue &av, proto::AttrDef &def) { - int64_t value; - const graphStatus ret = av.GetValue(value); - if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "Failed to get int attr."); - return GRAPH_FAILED; - } - def.set_i(value); - return GRAPH_SUCCESS; -} - -graphStatus IntSerializer::Deserialize(const proto::AttrDef &def, AnyValue &av) { - return av.SetValue(def.i()); -} - -REG_GEIR_SERIALIZER(int_serializer, IntSerializer, GetTypeId(), proto::AttrDef::kI); -} // namespace ge diff --git a/graph/serialization/int_serializer.h b/graph/serialization/int_serializer.h deleted file mode 100644 index 126f64e0d0c77be6215c3737f55151acdfcf313e..0000000000000000000000000000000000000000 --- a/graph/serialization/int_serializer.h +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_GRAPH_SERIALIZATION_INT_SERIALIZER_H_ -#define METADEF_GRAPH_SERIALIZATION_INT_SERIALIZER_H_ - -#include "attr_serializer.h" -#include "attr_serializer_registry.h" -namespace ge { -class IntSerializer : public GeIrAttrSerializer { - public: - IntSerializer() = default; - graphStatus Serialize(const AnyValue &av, proto::AttrDef &def) override; - graphStatus Deserialize(const proto::AttrDef &def, AnyValue &av) override; -}; -} // namespace ge - -#endif // METADEF_GRAPH_SERIALIZATION_INT64_SERIALIZER_H_ diff --git a/graph/serialization/list_list_float_serializer.cc b/graph/serialization/list_list_float_serializer.cc deleted file mode 100644 index c65a384d2187fcbd1a11ac9adcf783bfcd927b39..0000000000000000000000000000000000000000 --- a/graph/serialization/list_list_float_serializer.cc +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "list_list_float_serializer.h" -#include -#include "graph/debug/ge_util.h" -#include "proto/ge_ir.pb.h" -#include "graph/debug/ge_log.h" - -namespace ge { -graphStatus ListListFloatSerializer::Serialize(const AnyValue &av, proto::AttrDef &def) { - std::vector> list_list_value; - const graphStatus ret = av.GetValue(list_list_value); - if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "Failed to get list_list_float attr."); - return GRAPH_FAILED; - } - const auto mutable_list_list = def.mutable_list_list_float(); - GE_CHECK_NOTNULL(mutable_list_list); - mutable_list_list->clear_list_list_f(); - for (const auto &list_value : list_list_value) { - const auto list_f = mutable_list_list->add_list_list_f(); - GE_CHECK_NOTNULL(list_f); - for (const auto val : list_value) { - list_f->add_list_f(val); - } - } - return GRAPH_SUCCESS; -} -graphStatus ListListFloatSerializer::Deserialize(const proto::AttrDef &def, AnyValue &av) { - std::vector> values; - for (auto idx = 0; idx < def.list_list_float().list_list_f_size(); ++idx) { - std::vector vec; - for (auto i = 0; i < def.list_list_float().list_list_f(idx).list_f_size(); ++i) { - vec.push_back(def.list_list_float().list_list_f(idx).list_f(i)); - } - values.push_back(vec); - } - - return av.SetValue(std::move(values)); -} - -REG_GEIR_SERIALIZER(list_list_float_serializer, ListListFloatSerializer, - GetTypeId>>(), proto::AttrDef::kListListFloat); -} // namespace ge diff --git a/graph/serialization/list_list_float_serializer.h b/graph/serialization/list_list_float_serializer.h deleted file mode 100644 index 5f4b59906b8c898e1b8851393fd38a109f793473..0000000000000000000000000000000000000000 --- a/graph/serialization/list_list_float_serializer.h +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_GRAPH_SERIALIZATION_LIST_LIST_FLOAT_SERIALIZER_H_ -#define METADEF_GRAPH_SERIALIZATION_LIST_LIST_FLOAT_SERIALIZER_H_ - -#include "attr_serializer.h" -#include "attr_serializer_registry.h" -namespace ge { -class ListListFloatSerializer : public GeIrAttrSerializer { - public: - ListListFloatSerializer() = default; - graphStatus Serialize(const AnyValue &av, proto::AttrDef &def) override; - graphStatus Deserialize(const proto::AttrDef &def, AnyValue &av) override; -}; - -} // namespace ge - -#endif // METADEF_GRAPH_SERIALIZATION_LIST_LIST_FLOAT_SERIALIZER_H_ diff --git a/graph/serialization/list_list_int_serializer.cc b/graph/serialization/list_list_int_serializer.cc deleted file mode 100644 index a175ba2a75e1262671ab8e5dd141acaa3e642423..0000000000000000000000000000000000000000 --- a/graph/serialization/list_list_int_serializer.cc +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "list_list_int_serializer.h" -#include -#include "graph/debug/ge_util.h" -#include "proto/ge_ir.pb.h" -#include "graph/debug/ge_log.h" - -namespace ge { -graphStatus ListListIntSerializer::Serialize(const AnyValue &av, proto::AttrDef &def) { - std::vector> list_list_value; - const graphStatus ret = av.GetValue(list_list_value); - if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "Failed to get list_list_int attr."); - return GRAPH_FAILED; - } - const auto mutable_list_list = def.mutable_list_list_int(); - GE_CHECK_NOTNULL(mutable_list_list); - mutable_list_list->clear_list_list_i(); - for (const auto &list_value : list_list_value) { - const auto list_i = mutable_list_list->add_list_list_i(); - GE_CHECK_NOTNULL(list_i); - for (const int64_t val : list_value) { - list_i->add_list_i(val); - } - } - return GRAPH_SUCCESS; -} -graphStatus ListListIntSerializer::Deserialize(const proto::AttrDef &def, AnyValue &av) { - std::vector> values; - for (auto idx = 0; idx < def.list_list_int().list_list_i_size(); ++idx) { - std::vector vec; - for (auto i = 0; i < def.list_list_int().list_list_i(idx).list_i_size(); ++i) { - vec.push_back(def.list_list_int().list_list_i(idx).list_i(i)); - } - values.push_back(vec); - } - - return av.SetValue(std::move(values)); -} - -REG_GEIR_SERIALIZER(list_list_int_serializer, ListListIntSerializer, - GetTypeId>>(), proto::AttrDef::kListListInt); -} // namespace ge diff --git a/graph/serialization/list_list_int_serializer.h b/graph/serialization/list_list_int_serializer.h deleted file mode 100644 index 6bae6eafc7f65c76ae9161b7b0e5352299d66f03..0000000000000000000000000000000000000000 --- a/graph/serialization/list_list_int_serializer.h +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_GRAPH_SERIALIZATION_LIST_LIST_INT_SERIALIZER_H_ -#define METADEF_GRAPH_SERIALIZATION_LIST_LIST_INT_SERIALIZER_H_ - -#include "attr_serializer.h" -#include "attr_serializer_registry.h" -namespace ge { -class ListListIntSerializer : public GeIrAttrSerializer { - public: - ListListIntSerializer() = default; - graphStatus Serialize(const AnyValue &av, proto::AttrDef &def) override; - graphStatus Deserialize(const proto::AttrDef &def, AnyValue &av) override; -}; - -} // namespace ge - -#endif // METADEF_GRAPH_SERIALIZATION_LIST_LIST_INT_SERIALIZER_H_ diff --git a/graph/serialization/list_value_serializer.cc b/graph/serialization/list_value_serializer.cc deleted file mode 100644 index da0c07466e1bf12e481d549f0c75875e33a62967..0000000000000000000000000000000000000000 --- a/graph/serialization/list_value_serializer.cc +++ /dev/null @@ -1,383 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "list_value_serializer.h" -#include -#include -#include - -#include "graph/debug/ge_log.h" -#include "graph/types.h" -#include "graph/utils/attr_utils.h" -#include "graph/debug/ge_util.h" -#include "tensor_desc_serializer.h" -#include "tensor_serializer.h" -#include "named_attrs_serializer.h" -#include "graph_serializer.h" -#include "graph/ge_tensor.h" -#include "graph/def_types.h" - -namespace ge { -using ComputeGraphPtr = std::shared_ptr; -using GeTensorPtr = std::shared_ptr; -using ListValue = proto::AttrDef::ListValue; -using std::placeholders::_1; -using std::placeholders::_2; - -graphStatus ListValueSerializer::Serialize(const AnyValue &av, proto::AttrDef &def) { - const static std::map> - type_serializer_map = { - {AnyValue::VT_LIST_INT, std::bind(&ListValueSerializer::SerializeListInt, _1, _2)}, - {AnyValue::VT_LIST_FLOAT, std::bind(&ListValueSerializer::SerializeListFloat, _1, _2)}, - {AnyValue::VT_LIST_BOOL, std::bind(&ListValueSerializer::SerializeListBool, _1, _2)}, - {AnyValue::VT_LIST_BYTES, std::bind(&ListValueSerializer::SerializeListBuffer, _1, _2)}, - {AnyValue::VT_LIST_DATA_TYPE, std::bind(&ListValueSerializer::SerializeListDataType, _1, _2)}, - {AnyValue::VT_LIST_STRING, std::bind(&ListValueSerializer::SerializeListString, _1, _2)}, - {AnyValue::VT_LIST_NAMED_ATTRS, std::bind(&ListValueSerializer::SerializeListNamedAttrs, _1, _2)}, - {AnyValue::VT_LIST_TENSOR_DESC, std::bind(&ListValueSerializer::SerializeListGeTensorDesc, _1, _2)}, - {AnyValue::VT_LIST_TENSOR, std::bind(&ListValueSerializer::SerializeListGeTensor, _1, _2)}, - {AnyValue::VT_LIST_GRAPH, std::bind(&ListValueSerializer::SerializeListGraphDef, _1, _2)}, - }; - - const auto iter = type_serializer_map.find(av.GetValueType()); - if (iter == type_serializer_map.end()) { - GELOGE(GRAPH_FAILED, "Value type [%d] not support.", static_cast(av.GetValueType())); - return GRAPH_FAILED; - } - return iter->second(av, def); -} -graphStatus ListValueSerializer::Deserialize(const proto::AttrDef &def, AnyValue &av) { - const static std::map> - type_deserializer_map = { - {ListValue::VT_LIST_INT, std::bind(&ListValueSerializer::DeserializeListInt, _1, _2)}, - {ListValue::VT_LIST_FLOAT, std::bind(&ListValueSerializer::DeserializeListFloat, _1, _2)}, - {ListValue::VT_LIST_STRING, std::bind(&ListValueSerializer::DeserializeListString, _1, _2)}, - {ListValue::VT_LIST_BYTES, std::bind(&ListValueSerializer::DeserializeListBuffer, _1, _2)}, - {ListValue::VT_LIST_BOOL, std::bind(&ListValueSerializer::DeserializeListBool, _1, _2)}, - {ListValue::VT_LIST_DATA_TYPE, std::bind(&ListValueSerializer::DeserializeListDataType, _1, _2)}, - {ListValue::VT_LIST_NAMED_ATTRS, std::bind(&ListValueSerializer::DeserializeListNamedAttrs, _1, _2)}, - {ListValue::VT_LIST_TENSOR_DESC, std::bind(&ListValueSerializer::DeserializeListGeTensorDesc, _1, _2)}, - {ListValue::VT_LIST_TENSOR, std::bind(&ListValueSerializer::DeserializeListGeTensor, _1, _2)}, - {ListValue::VT_LIST_GRAPH, std::bind(&ListValueSerializer::DeserializeListGraphDef, _1, _2)}, - }; - - const auto iter = type_deserializer_map.find(def.list().val_type()); - if (iter == type_deserializer_map.end()) { - GELOGE(GRAPH_FAILED, "Value type [%d] not support.", static_cast(def.list().val_type())); - return GRAPH_FAILED; - } - return iter->second(def, av); -} - -graphStatus ListValueSerializer::SerializeListInt(const AnyValue &av, proto::AttrDef &def) { - std::vector list_val; - const graphStatus ret = av.GetValue(list_val); - if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "Failed to get list_int attr."); - return GRAPH_FAILED; - } - const auto mutable_list = def.mutable_list(); - GE_CHECK_NOTNULL(mutable_list); - mutable_list->clear_i(); - for (const auto val : list_val) { - mutable_list->add_i(val); - } - mutable_list->set_val_type(proto::AttrDef::ListValue::VT_LIST_INT); - return GRAPH_SUCCESS; -} - -graphStatus ListValueSerializer::SerializeListString(const AnyValue &av, proto::AttrDef &def) { - std::vector list_val; - const graphStatus ret = av.GetValue(list_val); - if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "Failed to get list_string attr."); - return GRAPH_FAILED; - } - const auto mutable_list = def.mutable_list(); - GE_CHECK_NOTNULL(mutable_list); - mutable_list->clear_s(); - for (const auto &val : list_val) { - mutable_list->add_s(val); - } - mutable_list->set_val_type(proto::AttrDef::ListValue::VT_LIST_STRING); - return GRAPH_SUCCESS; -} - -graphStatus ListValueSerializer::SerializeListFloat(const AnyValue &av, proto::AttrDef &def) { - std::vector list_val; - const graphStatus ret = av.GetValue(list_val); - if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "Failed to get list_float attr."); - return GRAPH_FAILED; - } - const auto mutable_list = def.mutable_list(); - GE_CHECK_NOTNULL(mutable_list); - mutable_list->clear_f(); - for (const auto val : list_val) { - mutable_list->add_f(val); - } - mutable_list->set_val_type(proto::AttrDef::ListValue::VT_LIST_FLOAT); - - return GRAPH_SUCCESS; -} - -graphStatus ListValueSerializer::SerializeListBool(const AnyValue &av, proto::AttrDef &def) { - std::vector list_val; - const graphStatus ret = av.GetValue(list_val); - if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "Failed to get list_bool attr."); - return GRAPH_FAILED; - } - const auto mutable_list = def.mutable_list(); - GE_CHECK_NOTNULL(mutable_list); - mutable_list->clear_b(); - for (const auto val : list_val) { - mutable_list->add_b(val); - } - mutable_list->set_val_type(proto::AttrDef::ListValue::VT_LIST_BOOL); - - return GRAPH_SUCCESS; -} - -graphStatus ListValueSerializer::SerializeListGeTensorDesc(const AnyValue &av, proto::AttrDef &def) { - std::vector list_val; - const graphStatus ret = av.GetValue(list_val); - if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "Failed to get list_tensor_desc attr."); - return GRAPH_FAILED; - } - const auto mutable_list = def.mutable_list(); - GE_CHECK_NOTNULL(mutable_list); - mutable_list->clear_td(); - for (const auto &val : list_val) { - const auto attr_proto = mutable_list->add_td(); - GE_CHECK_NOTNULL(attr_proto); - GeTensorSerializeUtils::GeTensorDescAsProto(val, attr_proto); - } - - mutable_list->set_val_type(proto::AttrDef::ListValue::VT_LIST_TENSOR_DESC); - - return GRAPH_SUCCESS; -} - -graphStatus ListValueSerializer::SerializeListGeTensor(const AnyValue &av, proto::AttrDef &def) { - std::vector list_val; - const graphStatus ret = av.GetValue(list_val); - if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "Failed to get list_tensor attr_value."); - return GRAPH_FAILED; - } - const auto mutable_list = def.mutable_list(); - GE_CHECK_NOTNULL(mutable_list); - mutable_list->clear_t(); - for (const auto &val : list_val) { - const auto attr_proto = mutable_list->add_t(); - GE_CHECK_NOTNULL(attr_proto); - GeTensorSerializeUtils::GeTensorAsProto(val, attr_proto); - } - - mutable_list->set_val_type(proto::AttrDef::ListValue::VT_LIST_TENSOR); - - return GRAPH_SUCCESS; -} -graphStatus ListValueSerializer::SerializeListBuffer(const AnyValue &av, proto::AttrDef &def) { - std::vector list_val; - const graphStatus ret = av.GetValue(list_val); - if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "Failed to get list_buffer attr."); - return GRAPH_FAILED; - } - const auto mutable_list = def.mutable_list(); - GE_CHECK_NOTNULL(mutable_list); - mutable_list->clear_bt(); - for (auto val : list_val) { - if ((val.GetData() != nullptr) && (val.size() > 0U)) { - mutable_list->add_bt(val.GetData(), val.GetSize()); - } - } - mutable_list->set_val_type(proto::AttrDef::ListValue::VT_LIST_BYTES); - - return GRAPH_SUCCESS; -} - -graphStatus ListValueSerializer::SerializeListGraphDef(const AnyValue &av, proto::AttrDef &def) { - std::vector list_val; - const graphStatus ret = av.GetValue(list_val); - if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "Failed to get list_graph attr_value."); - return GRAPH_FAILED; - } - - const auto mutable_list = def.mutable_list(); - GE_CHECK_NOTNULL(mutable_list); - mutable_list->clear_g(); - for (const auto &graph : list_val) { - const auto mutable_graph = mutable_list->add_g(); - GE_CHECK_NOTNULL(mutable_graph); - *mutable_graph = graph; - } - - mutable_list->set_val_type(proto::AttrDef::ListValue::VT_LIST_GRAPH); - - return GRAPH_SUCCESS; -} - -graphStatus ListValueSerializer::SerializeListNamedAttrs(const AnyValue &av, proto::AttrDef &def) { - std::vector list_val; - const graphStatus ret = av.GetValue(list_val); - if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "Failed to get list_named_attr attr."); - return GRAPH_FAILED; - } - const auto mutable_list = def.mutable_list(); - GE_CHECK_NOTNULL(mutable_list); - mutable_list->clear_na(); - const auto attr_serializer = AttrSerializerRegistry::GetInstance().GetSerializer(GetTypeId()); - const auto named_attr_serializer = dynamic_cast(attr_serializer); - GE_CHECK_NOTNULL(named_attr_serializer); - - for (const auto &val : list_val) { - const auto attr_proto = mutable_list->add_na(); - GE_CHECK_NOTNULL(attr_proto); - if (named_attr_serializer->Serialize(val, attr_proto) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "NamedAttr [%s] serialize failed.", val.GetName().c_str()); - return GRAPH_FAILED; - } - } - - mutable_list->set_val_type(proto::AttrDef::ListValue::VT_LIST_NAMED_ATTRS); - - return GRAPH_SUCCESS; -} -graphStatus ListValueSerializer::SerializeListDataType(const AnyValue &av, proto::AttrDef &def) { - std::vector list_val; - const graphStatus ret = av.GetValue(list_val); - if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "Failed to get list_datatype attr."); - return GRAPH_FAILED; - } - const auto mutable_list = def.mutable_list(); - GE_CHECK_NOTNULL(mutable_list); - mutable_list->clear_dt(); - for (const auto val : list_val) { - mutable_list->add_dt(static_cast(val)); - } - mutable_list->set_val_type(proto::AttrDef::ListValue::VT_LIST_DATA_TYPE); - - return GRAPH_SUCCESS; -} - -graphStatus ListValueSerializer::DeserializeListInt(const proto::AttrDef &def, AnyValue &av) { - std::vector values(static_cast(def.list().i_size())); - for (auto idx = 0; idx < def.list().i_size(); ++idx) { - values[static_cast(idx)] = def.list().i(idx); - } - return av.SetValue(std::move(values)); -} - -graphStatus ListValueSerializer::DeserializeListString(const proto::AttrDef &def, AnyValue &av) { - std::vector values(static_cast(def.list().s_size())); - for (auto idx = 0; idx < def.list().s_size(); ++idx) { - values[static_cast(idx)] = def.list().s(idx); - } - return av.SetValue(std::move(values)); -} - -graphStatus ListValueSerializer::DeserializeListFloat(const proto::AttrDef &def, AnyValue &av) { - std::vector values(static_cast(def.list().f_size())); - for (auto idx = 0; idx < def.list().f_size(); ++idx) { - values[static_cast(idx)] = def.list().f(idx); - } - - return av.SetValue(std::move(values)); -} - -graphStatus ListValueSerializer::DeserializeListBool(const proto::AttrDef &def, AnyValue &av) { - std::vector values(static_cast(def.list().b_size())); - for (auto idx = 0; idx < def.list().b_size(); ++idx) { - values[static_cast(idx)] = def.list().b(idx); - } - return av.SetValue(std::move(values)); -} - -graphStatus ListValueSerializer::DeserializeListGeTensorDesc(const proto::AttrDef &def, AnyValue &av) { - std::vector values(static_cast(def.list().td_size())); - for (auto idx = 0; idx < def.list().td_size(); ++idx) { - GeTensorSerializeUtils::AssembleGeTensorDescFromProto(&def.list().td(idx), values[static_cast(idx)]); - } - - return av.SetValue(std::move(values)); -} - -graphStatus ListValueSerializer::DeserializeListGeTensor(const proto::AttrDef &def, AnyValue &av) { - std::vector values(static_cast(def.list().t_size())); - for (auto idx = 0; idx < def.list().t_size(); ++idx) { - GeTensorSerializeUtils::AssembleGeTensorFromProto(&def.list().t(idx), values[static_cast(idx)]); - } - - return av.SetValue(std::move(values)); -} - -graphStatus ListValueSerializer::DeserializeListBuffer(const proto::AttrDef &def, AnyValue &av) { - std::vector values(static_cast(def.list().bt_size())); - for (auto idx = 0; idx < def.list().bt_size(); ++idx) { - values[static_cast(idx)] = - Buffer::CopyFrom(PtrToPtr(def.list().bt(idx).data()), def.list().bt(idx).size()); - } - - return av.SetValue(std::move(values)); -} -graphStatus ListValueSerializer::DeserializeListGraphDef(const proto::AttrDef &def, AnyValue &av) { - std::vector values(static_cast(def.list().g_size())); - for (auto idx = 0; idx < def.list().g_size(); ++idx) { - values[static_cast(idx)] = def.list().g(idx); - } - return av.SetValue(std::move(values)); -} - -graphStatus ListValueSerializer::DeserializeListNamedAttrs(const proto::AttrDef &def, AnyValue &av) { - const auto attr_deserializer = AttrSerializerRegistry::GetInstance(). - GetDeserializer(proto::AttrDef::ValueCase::kFunc); - const auto named_attr_deserializer = dynamic_cast(attr_deserializer); - GE_CHECK_NOTNULL(named_attr_deserializer); - - std::vector values(static_cast(def.list().na_size())); - for (auto idx = 0; idx < def.list().na_size(); ++idx) { - if (named_attr_deserializer->Deserialize(def.list().na(idx), values[static_cast(idx)]) - != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "NamedAttr [%s] deserialize failed.", def.list().na(idx).name().c_str()); - return GRAPH_FAILED; - } - } - - return av.SetValue(std::move(values)); -} -graphStatus ListValueSerializer::DeserializeListDataType(const proto::AttrDef &def, AnyValue &av) { - std::vector values(static_cast(def.list().dt_size())); - for (auto idx = 0; idx < def.list().dt_size(); ++idx) { - values[static_cast(idx)] = static_cast(def.list().dt(idx)); - } - - return av.SetValue(std::move(values)); -} - -REG_GEIR_SERIALIZER(list_int, ListValueSerializer, GetTypeId>(), proto::AttrDef::kList); -REG_GEIR_SERIALIZER(list_str, ListValueSerializer, GetTypeId>(), proto::AttrDef::kList); -REG_GEIR_SERIALIZER(list_float, ListValueSerializer, GetTypeId>(), proto::AttrDef::kList); -REG_GEIR_SERIALIZER(list_bool, ListValueSerializer, GetTypeId>(), proto::AttrDef::kList); -REG_GEIR_SERIALIZER(list_tensor_desc, ListValueSerializer, - GetTypeId>(), proto::AttrDef::kList); -REG_GEIR_SERIALIZER(list_tensor, ListValueSerializer, GetTypeId>(), proto::AttrDef::kList); -REG_GEIR_SERIALIZER(list_buffer, ListValueSerializer, GetTypeId>(), proto::AttrDef::kList); -REG_GEIR_SERIALIZER(list_graph_def, ListValueSerializer, - GetTypeId>(), proto::AttrDef::kList); -REG_GEIR_SERIALIZER(list_named_attr, ListValueSerializer, - GetTypeId>(), proto::AttrDef::kList); -REG_GEIR_SERIALIZER(list_data_type, ListValueSerializer, GetTypeId>(), proto::AttrDef::kList); -} // namespace ge diff --git a/graph/serialization/list_value_serializer.h b/graph/serialization/list_value_serializer.h deleted file mode 100644 index ef93362ba3afedb280af016efdde055706367bf7..0000000000000000000000000000000000000000 --- a/graph/serialization/list_value_serializer.h +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_GRAPH_SERIALIZATION_LIST_VALUE_SERIALIZER_H_ -#define METADEF_GRAPH_SERIALIZATION_LIST_VALUE_SERIALIZER_H_ - -#include "attr_serializer.h" -#include "attr_serializer_registry.h" -#include "proto/ge_ir.pb.h" -#include "graph/ge_attr_value.h" - -namespace ge { -class ListValueSerializer : public GeIrAttrSerializer { - public: - ListValueSerializer() = default; - graphStatus Serialize(const AnyValue &av, proto::AttrDef &def) override; - graphStatus Deserialize(const proto::AttrDef &def, AnyValue &av) override; - - private: - static graphStatus SerializeListInt(const AnyValue &av, proto::AttrDef &def); - static graphStatus SerializeListString(const AnyValue &av, proto::AttrDef &def); - static graphStatus SerializeListFloat(const AnyValue &av, proto::AttrDef &def); - static graphStatus SerializeListBool(const AnyValue &av, proto::AttrDef &def); - static graphStatus SerializeListGeTensorDesc(const AnyValue &av, proto::AttrDef &def); - static graphStatus SerializeListGeTensor(const AnyValue &av, proto::AttrDef &def); - static graphStatus SerializeListBuffer(const AnyValue &av, proto::AttrDef &def); - static graphStatus SerializeListGraphDef(const AnyValue &av, proto::AttrDef &def); - static graphStatus SerializeListNamedAttrs(const AnyValue &av, proto::AttrDef &def); - static graphStatus SerializeListDataType(const AnyValue &av, proto::AttrDef &def); - - static graphStatus DeserializeListInt(const proto::AttrDef &def, AnyValue &av); - static graphStatus DeserializeListString(const proto::AttrDef &def, AnyValue &av); - static graphStatus DeserializeListFloat(const proto::AttrDef &def, AnyValue &av); - static graphStatus DeserializeListBool(const proto::AttrDef &def, AnyValue &av); - static graphStatus DeserializeListGeTensorDesc(const proto::AttrDef &def, AnyValue &av); - static graphStatus DeserializeListGeTensor(const proto::AttrDef &def, AnyValue &av); - static graphStatus DeserializeListBuffer(const proto::AttrDef &def, AnyValue &av); - static graphStatus DeserializeListGraphDef(const proto::AttrDef &def, AnyValue &av); - static graphStatus DeserializeListNamedAttrs(const proto::AttrDef &def, AnyValue &av); - static graphStatus DeserializeListDataType(const proto::AttrDef &def, AnyValue &av); -}; -} // namespace ge - -#endif // METADEF_GRAPH_SERIALIZATION_LIST_VALUE_SERIALIZER_H_ diff --git a/graph/serialization/model_serialize.cc b/graph/serialization/model_serialize.cc deleted file mode 100644 index 44dec7bf1979c67b29eb62e69d5739e993120aa2..0000000000000000000000000000000000000000 --- a/graph/serialization/model_serialize.cc +++ /dev/null @@ -1,1242 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/model_serialize.h" -#include -#include -#include -#include -#include -#include -#include - -#include "mmpa/mmpa_api.h" -#include "graph/debug/ge_attr_define.h" -#include "proto/ge_ir.pb.h" -#include "debug/ge_log.h" -#include "debug/ge_util.h" -#include "graph/utils/file_utils.h" -#include "graph/detail/model_serialize_imp.h" -#include "graph/normal_graph/op_desc_impl.h" -#include "graph/ge_tensor.h" -#include "graph/normal_graph/ge_tensor_impl.h" -#include "graph/normal_graph/compute_graph_impl.h" -#include "graph/serialization/attr_serializer_registry.h" -#include "graph/utils/graph_utils.h" -#include "debug/ge_op_types.h" -#include "common/util/mem_utils.h" -#include "common/checker.h" -#include "graph/attribute_group/attr_group_serialize.h" - -namespace { -const std::string kTmpWeight = "air_weight/"; -const std::string kSrcOutPeerIndex = "_src_out_peer_index_for_ge_txt_load"; // only exist in dump file -constexpr int64_t kInvalidIndex = -1; -constexpr int32_t kDecimal = 10; -constexpr int32_t kMaxThreadNum = 16; - -ge::Status CreateExternalWeightPath(const std::string &model_path, const std::string &model_name, - const std::string &op_tag, std::string &weight_real_path, - std::string &weight_relative_path) { - static std::mutex dir_mutex; // mutex for create dir - // regulate file path - std::string dir_path; - std::string file_name; - ge::SplitFilePath(model_path, dir_path, file_name); - if (!dir_path.empty()) { - dir_path += "/"; - } - std::string regulated_model_name = ge::GetRegulatedName(model_name); - if (!regulated_model_name.empty()) { - regulated_model_name += "/"; - } - std::string regulated_op_tag = ge::GetRegulatedName(op_tag); - // get weight path - weight_relative_path = kTmpWeight + regulated_model_name + regulated_op_tag + "_file"; - weight_real_path = dir_path + weight_relative_path; - ge::SplitFilePath(weight_real_path, dir_path, file_name); - // create weight dir - const bool weight_dir_exist = (mmAccess(dir_path.c_str()) == EN_OK); - if ((!dir_path.empty()) && (!weight_dir_exist)) { - const std::lock_guard lock(dir_mutex); - GE_ASSERT_TRUE((ge::CreateDir(dir_path) == EOK), "Create direct failed, path: %s.", dir_path.c_str()); - } - return ge::SUCCESS; -} - -ge::Buffer AllocBufferByModelDef(const ge::proto::ModelDef &model_def) { -#if !defined(__ANDROID__) && !defined(ANDROID) - ge::Buffer buffer(model_def.ByteSizeLong()); -#else - Buffer buffer(model_def.ByteSize()); -#endif - GE_ASSERT_TRUE(buffer.GetSize() != 0UL, "get size failed"); - GE_ASSERT_TRUE((buffer.GetData() != nullptr), "get size failed"); - return buffer; -} -} - -namespace ge { -bool ModelSerializeImp::ParseNodeIndex(const std::string &node_index, std::string &node_name, int32_t &index) const { - const auto sep = node_index.rfind(":"); - if (sep == std::string::npos) { - GELOGD("[Parse][CheckParam] Separator \":\" is not found in node_index."); - return false; - } - node_name = node_index.substr(0UL, sep); - const auto index_str = node_index.substr(sep + 1UL); - index = static_cast(std::strtol(index_str.c_str(), nullptr, kDecimal)); - return true; -} - -int64_t ModelSerializeImp::GenDataInputInfo(const OutDataAnchorPtr &src_anchor, - const InDataAnchorPtr &dst_anchor) const { - const auto peer_in_data_anchors = src_anchor->GetPeerInDataAnchors(); - for (size_t i = 0U; i < peer_in_data_anchors.size(); ++i) { - if (peer_in_data_anchors.at(i) == dst_anchor) { - return static_cast(i); - } - } - return kInvalidIndex; -} - -int64_t ModelSerializeImp::GenCtrlInputInfo(const OutControlAnchorPtr &src_anchor, - const InControlAnchorPtr &dst_anchor) const { - const auto peer_in_ctrl_anchors = src_anchor->GetPeerInControlAnchors(); - for (size_t i = 0U; i < peer_in_ctrl_anchors.size(); ++i) { - if (peer_in_ctrl_anchors.at(i) == dst_anchor) { - return static_cast(i); - } - } - return kInvalidIndex; -} - -bool ModelSerializeImp::SerializeEdge(const NodePtr &node, proto::OpDef *const op_def_proto, - const bool is_dump_graph) const { - GE_CHK_BOOL_EXEC(node != nullptr, REPORT_INNER_ERR_MSG("E18888", "param node is nullptr, check invalid."); - return false, "[Check][Param] node is null."); - GE_CHK_BOOL_EXEC(op_def_proto != nullptr, REPORT_INNER_ERR_MSG("E18888", "param op_def_proto is null, check invalid."); - return false, "[Check][Param] op_def_proto is null."); - - op_def_proto->clear_input(); - proto::AttrDef src_out_peer_index; - // Inputs - for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { - if (in_data_anchor != nullptr) { - const auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - if ((peer_out_anchor != nullptr) && peer_out_anchor->GetOwnerNodeBarePtr()) { - op_def_proto->add_input(peer_out_anchor->GetOwnerNodeBarePtr()->GetName() + ":" + - std::to_string(peer_out_anchor->GetIdx())); - src_out_peer_index.mutable_list()->add_i(GenDataInputInfo(peer_out_anchor, in_data_anchor)); - } else { - op_def_proto->add_input(""); - src_out_peer_index.mutable_list()->add_i(kInvalidIndex); - } - } - } - // Control edge - const auto in_control_anchor = node->GetInControlAnchor(); - if (in_control_anchor != nullptr) { - const auto peer_out_anchors = in_control_anchor->GetPeerOutControlAnchors(); - for (const auto &peer_out_anchor : peer_out_anchors) { - if ((peer_out_anchor != nullptr) && peer_out_anchor->GetOwnerNodeBarePtr()) { - op_def_proto->add_input(peer_out_anchor->GetOwnerNodeBarePtr()->GetName() + ":-1"); - src_out_peer_index.mutable_list()->add_i(GenCtrlInputInfo(peer_out_anchor, in_control_anchor)); - } - } - } - if (is_dump_graph) { - GELOGD("Save src out peer index for %s.", node->GetName().c_str()); - auto const op_desc_attr = op_def_proto->mutable_attr(); - (void)op_desc_attr->insert({kSrcOutPeerIndex, src_out_peer_index}); - } - - return true; -} - -void ModelSerializeImp::FixOpDefSubgraphInstanceName(const ConstOpDescPtr &op_desc) const { - op_desc->impl_->meta_data_.ClearSubgraphNames(); - for (const std::string &name : op_desc->GetSubgraphInstanceNames()) { - op_desc->impl_->meta_data_.AddSubGraphName(name); - } -} - -bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *const op_def_proto, - const bool not_dump_all) const { - GE_CHK_BOOL_EXEC(op_desc != nullptr, REPORT_INNER_ERR_MSG("E18888", "param op_desc is nullptr. check invalid."); - return false, "[Check][Param] op_desc is null."); - GE_CHK_BOOL_EXEC(op_def_proto != nullptr, REPORT_INNER_ERR_MSG("E18888", "param op_def_proto is null, check invalid."); - return false, "[Check][Param] op_def_proto is null."); - GE_CHK_BOOL_EXEC(op_desc->impl_ != nullptr, - REPORT_INNER_ERR_MSG("E18888", "param op_desc impl is null, check invalid."); - return false, "[Check][Param] op_desc impl is null."); - - FixOpDefSubgraphInstanceName(op_desc); - op_desc->impl_->SerializeMetaDataToOpDef(op_def_proto); - // Delete unnecessary attr - op_def_proto->clear_input_desc(); - op_def_proto->clear_output_desc(); - // Input descs - if (op_desc->GetAllInputsSize() > 0UL) { - const auto size = static_cast(op_desc->GetAllInputsSize()); - for (uint32_t i = 0U; i < size; i++) { - const auto tensor_desc = op_desc->GetInputDescPtrDfault(i); - if ((tensor_desc != nullptr) && (tensor_desc->impl_ != nullptr)) { - GeTensorSerializeUtils::GeTensorDescAsProto(*tensor_desc, op_def_proto->add_input_desc()); - } - } - } - // Output descs - if (op_desc->GetOutputsSize() > 0UL) { - const auto size = static_cast(op_desc->GetOutputsSize()); - for (uint32_t i = 0U; i < size; i++) { - const auto tensor_desc = op_desc->GetOutputDescPtr(i); - if ((tensor_desc != nullptr) && (tensor_desc->impl_ != nullptr)) { - GeTensorSerializeUtils::GeTensorDescAsProto(*tensor_desc, op_def_proto->add_output_desc()); - } - } - } - - op_def_proto->set_id(op_desc->GetId()); - OpDescToAttrDef(op_desc, op_def_proto, not_dump_all); - - return true; -} - -void ModelSerializeImp::OpDescIrDefToAttrDef(const ConstOpDescPtr &op_desc, - google::protobuf::Map *op_desc_attr) const { - if (!op_desc->impl_->GetIRMeta().GetIrAttrNames().empty()) { - proto::AttrDef ir_attr_names; - for (const auto &item : op_desc->impl_->GetIRMeta().GetIrAttrNames()) { - ir_attr_names.mutable_list()->add_s(item); - } - (*op_desc_attr)["_ir_attr_names"] = ir_attr_names; - } - if (!op_desc->impl_->GetIRMeta().GetIrInputs().empty()) { - proto::AttrDef key; - proto::AttrDef value; - for (const auto &input : op_desc->impl_->GetIRMeta().GetIrInputs()) { - key.mutable_list()->add_s(input.first); - value.mutable_list()->add_i(static_cast(input.second)); - } - (*op_desc_attr)["_ir_inputs_key"] = key; - (*op_desc_attr)["_ir_inputs_value"] = value; - } - if (!op_desc->impl_->GetIRMeta().GetIrOutputs().empty()) { - proto::AttrDef key; - proto::AttrDef value; - for (const auto &output : op_desc->impl_->GetIRMeta().GetIrOutputs()) { - key.mutable_list()->add_s(output.first); - value.mutable_list()->add_i(static_cast(output.second)); - } - (*op_desc_attr)["_ir_outputs_key"] = key; - (*op_desc_attr)["_ir_outputs_value"] = value; - } -} - -void ModelSerializeImp::OpDescToAttrDef(const ConstOpDescPtr &op_desc, proto::OpDef *const op_def_proto, - const bool not_dump_all) const { - auto const op_desc_attr = op_def_proto->mutable_attr(); - if ((op_desc == nullptr) || (op_desc->impl_ == nullptr)) { - GELOGE(FAILED, "[Check][Param] op desc or impl is nullptr."); - return; - } - if (!op_desc->impl_->input_name_idx_.empty()) { - proto::AttrDef key_in; - proto::AttrDef value_in; - for (auto &item : op_desc->impl_->input_name_idx_) { - key_in.mutable_list()->add_s(item.first); - value_in.mutable_list()->add_i(static_cast(item.second)); - } - (void) op_desc_attr->insert({"_input_name_key", key_in}); - (void) op_desc_attr->insert({"_input_name_value", value_in}); - } - if (!op_desc->impl_->output_name_idx_.empty()) { - proto::AttrDef key_out; - proto::AttrDef value_out; - for (auto &item : op_desc->impl_->output_name_idx_) { - key_out.mutable_list()->add_s(item.first); - value_out.mutable_list()->add_i(static_cast(item.second)); - } - (void) op_desc_attr->insert({"_output_name_key", key_out}); - (void) op_desc_attr->insert({"_output_name_value", value_out}); - } - if (!op_desc->impl_->GetIRMeta().GetOptionalInputName().empty()) { - proto::AttrDef opt_input; - for (auto &item : op_desc->impl_->GetIRMeta().GetOptionalInputName()) { - opt_input.mutable_list()->add_s(item); - } - (*op_desc_attr)["_opt_input"] = opt_input; - } - OpDescIrDefToAttrDef(op_desc, op_desc_attr); - - if (!SerializeAllAttrsFromAnyMap(op_desc->GetAllAttrs(), op_desc_attr)) { - GELOGE(GRAPH_FAILED, "OpDesc [%s] attr serialize failed.", op_desc->GetName().c_str()); - return; - } - - if (!op_desc->GetAttrMap().GetAttrsGroupPtr().empty() && - AttrGroupSerialize::SerializeAllAttr(*(op_def_proto->mutable_attr_groups()), op_desc->GetAttrMap()) != ge::SUCCESS) { - GELOGE(GRAPH_FAILED, "OpDesc attr group serialize failed."); - return; - } - - if (not_dump_all) { - (void) op_desc_attr->erase(ATTR_NAME_FRAMEWORK_NODE_DEF); - (void) op_desc_attr->erase(ATTR_NAME_FRAMEWORK_OP_DEF); - (void) op_desc_attr->erase(ATTR_NAME_FRAMEWORK_FUNC_DEF); - GE_IF_BOOL_EXEC(((op_def_proto->type() == CONSTANT) || (op_def_proto->type() == CONSTANTOP)), - (void) op_desc_attr->erase(ATTR_NAME_WEIGHTS)); - } -} - -bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *const op_def_proto, - const bool not_dump_all) const { - return SerializeNode(node, false, op_def_proto, not_dump_all); -} - -bool ModelSerializeImp::SerializeNode(const NodePtr &node, const bool is_dump_graph, proto::OpDef *const op_def_proto, - const bool not_dump_all) const { - if ((node == nullptr) || (op_def_proto == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param node or op_def_proto is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] param node or op_def_proto is nullptr, check invalid."); - return false; - } - if (!SerializeOpDesc(node->GetOpDesc(), op_def_proto, not_dump_all)) { - GELOGE(GRAPH_FAILED, "[Serialize][OpDesc] failed, node:%s", node->GetName().c_str()); - return false; - } - return SerializeEdge(node, op_def_proto, is_dump_graph); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeGraph(const ConstComputeGraphPtr &graph, - proto::GraphDef *const graph_proto, const bool not_dump_all) const { - return SerializeGraph(graph, false, graph_proto, not_dump_all); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeGraph(const ConstComputeGraphPtr &graph, - const bool is_dump_graph, proto::GraphDef *const graph_proto, const bool not_dump_all) const { - if ((graph == nullptr) || (graph_proto == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param graph or graph_proto is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] param graph or graph_proto is nullptr, check invalid."); - return false; - } - graph_proto->set_name(graph->GetName()); - // Inputs - for (const auto &input : graph->GetInputNodes()) { - if (input != nullptr) { - graph_proto->add_input(input->GetName() + ":0"); - } - } - // Outputs - for (const auto &output : graph->GetGraphOutNodesInfo()) { - if (output.first != nullptr) { - graph_proto->add_output(output.first->GetName() + ":" + std::to_string(output.second)); - GELOGI("Add output to graph proto, node name:%s, index:%d", output.first->GetName().c_str(), output.second); - } - } - // ComputeGraph中的属性序列化 - if (!SerializeAllAttrsFromAnyMap(graph->GetAllAttrs(), graph_proto->mutable_attr())) { - GELOGE(GRAPH_FAILED, "ComputeGraph [%s] serialize attr failed.", graph->GetName().c_str()); - return false; - } - - if (!graph->GetAttrMap().GetAttrsGroupPtr().empty() && - AttrGroupSerialize::SerializeAllAttr(*(graph_proto->mutable_attr_groups()), graph->GetAttrMap()) != ge::SUCCESS) { - GELOGE(GRAPH_FAILED, "Graph attr group serialize failed."); - return false; - } - - for (const auto &node : graph->GetDirectNode()) { - if (!SerializeNode(node, is_dump_graph, graph_proto->add_op(), not_dump_all)) { - return false; - } - } - return true; -} - -bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *const model_proto, - const bool not_dump_all) const { - return SerializeModel(model, false, model_proto, not_dump_all); -} - -bool ModelSerializeImp::SerializeModel(const Model &model, const bool is_dump_graph, proto::ModelDef *const model_proto, - const bool not_dump_all) const { - if (model_proto == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param model_proto is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] param model_proto is nullptr, check Invalid"); - return false; - } - model_proto->set_name(model.GetName()); - model_proto->set_custom_version(model.GetPlatformVersion()); - model_proto->set_version(model.GetVersion()); - - // Model属性序列化 - if (!SerializeAllAttrsFromAnyMap(model.GetAllAttrs(), model_proto->mutable_attr())) { - GELOGE(GRAPH_FAILED, "Model [%s] serialize attr failed.", model.GetName().c_str()); - return false; - } - // Model属性组序列化 - if (!model.GetAttrMap().GetAttrsGroupPtr().empty() && - AttrGroupSerialize::SerializeAllAttr(*(model_proto->mutable_attr_groups()), model.GetAttrMap()) != ge::SUCCESS) { - GELOGE(GRAPH_FAILED, "Model attr group serialize failed."); - return false; - } - - const auto compute_graph = model.graph_; - if (compute_graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "get compute graph from graph failed as graph is invalid."); - GELOGE(GRAPH_FAILED, "[Get][ComputeGraph] return nullptr"); - return false; - } - if (!SerializeGraph(compute_graph, is_dump_graph, model_proto->add_graph(), not_dump_all)) { - GELOGE(GRAPH_FAILED, "[Serialize][Graph] failed"); - return false; - } - - const auto root_graph = GraphUtils::FindRootGraph(compute_graph); - GE_RT_FALSE_CHECK_NOTNULL(root_graph); - std::vector> subgraphs; - if (compute_graph == root_graph) { - subgraphs = compute_graph->GetAllSubgraphs(); - } else { - GELOGD("[Serialize][Subgraph] compute_graph[%s] is not root graph[%s], get all subgraphs recursively", - compute_graph->GetName().c_str(), root_graph->GetName().c_str()); - if (ge::GraphUtils::GetSubgraphsRecursively(compute_graph, subgraphs) != SUCCESS) { - GELOGE(GRAPH_FAILED, "[Serialize][Subgraph] failed"); - return false; - } - } - for (const auto &subgraph : subgraphs) { - if (!SerializeGraph(subgraph, is_dump_graph, model_proto->add_graph(), not_dump_all)) { - GELOGE(GRAPH_FAILED, "[Serialize][Subgraph] failed"); - return false; - } - } - - return true; -} - -void ModelSerializeImp::AttrDefToOpDescIn(OpDescPtr &op_desc, std::vector &key_in, - std::vector &value_in) const { - if ((op_desc == nullptr) || (op_desc->impl_ == nullptr)) { - GELOGE(FAILED, "[Serialize][Opdesc] op desc or impl is nullptr."); - return; - } - if (!key_in.empty()) { - if (key_in.size() != value_in.size()) { - GELOGW("[ParseAttrDef][CheckParam] Input key and value vector size is different. key_size=%zu, value_size=%zu.", - key_in.size(), value_in.size()); - } else { - for (size_t i = 0UL; i < key_in.size(); ++i) { - (void) op_desc->impl_->input_name_idx_.insert(std::pair(key_in.at(i), value_in.at(i))); - } - } - } -} - -void ModelSerializeImp::AttrDefToOpDesc(OpDescPtr &op_desc, std::vector &key_out, - std::vector &value_out, - const std::vector &opt_input) const { - if ((op_desc == nullptr) || (op_desc->impl_ == nullptr)) { - GELOGE(FAILED, "[Serialize][Opdesc] op desc or impl is nullptr."); - return; - } - if (!key_out.empty()) { - if (key_out.size() != value_out.size()) { - GELOGW("[ParseAttrDef][CheckParam] Output key and value vector size is different. key_size=%zu, value_size=%zu.", - key_out.size(), value_out.size()); - } else { - for (size_t i = 0UL; i < key_out.size(); ++i) { - (void)op_desc->impl_->output_name_idx_.insert(std::pair(key_out.at(i), value_out.at(i))); - } - } - } - if (!opt_input.empty()) { - for (const auto &i : opt_input) { - (void) op_desc->impl_->MutableIRMeta().AddRegisterOptionalInputName(i); - } - } -} - -void ModelSerializeImp::AttrDefToOpDescIrDef(OpDescPtr &op_desc, proto::OpDef &op_def_proto) const { - if (op_def_proto.attr().count("_ir_attr_names") > 0UL) { - const auto &name_list = op_def_proto.attr().at("_ir_attr_names").list(); - for (const auto &item_s : name_list.s()) { - op_desc->impl_->MutableIRMeta().AppendIrAttrName(item_s); - } - (void) op_def_proto.mutable_attr()->erase("_ir_attr_names"); - } - - std::vector keys; - if (op_def_proto.attr().count("_ir_inputs_key") > 0UL) { - const auto &key_list = op_def_proto.attr().at("_ir_inputs_key").list(); - for (const auto &key : key_list.s()) { - keys.emplace_back(key); - } - (void) op_def_proto.mutable_attr()->erase("_ir_inputs_key"); - } - std::vector values; - if (op_def_proto.attr().count("_ir_inputs_value") > 0UL) { - const auto &value_list = op_def_proto.attr().at("_ir_inputs_value").list(); - for (const auto &value : value_list.i()) { - if (value >= kIrInputTypeEnd) { - GELOGW("[ParseAttrDef][CheckParam] ir inputs value[%" PRId64 "] is invalid, valid range is [%d-%d)", - value, kIrInputRequired, kIrInputTypeEnd); - return; - } - values.emplace_back(static_cast(value)); - } - (void) op_def_proto.mutable_attr()->erase("_ir_inputs_value"); - } - if (keys.size() != values.size()) { - GELOGW("[ParseAttrDef][CheckParam] ir inputs key and value vector size is different. key_size=%zu, value_size=%zu.", - keys.size(), values.size()); - return; - } - for (size_t i = 0U; i < keys.size(); ++i) { - op_desc->impl_->MutableIRMeta().AppendIrInput(std::move(keys[i]), values[i]); - } - - std::vector out_keys; - if (op_def_proto.attr().count("_ir_outputs_key") > 0UL) { - const auto &key_list = op_def_proto.attr().at("_ir_outputs_key").list(); - out_keys.reserve(key_list.s_size()); - for (const auto &key : key_list.s()) { - out_keys.emplace_back(key); - } - (void) op_def_proto.mutable_attr()->erase("_ir_outputs_key"); - } - std::vector out_types; - if (op_def_proto.attr().count("_ir_outputs_value") > 0UL) { - const auto &val_list = op_def_proto.attr().at("_ir_outputs_value").list(); - out_types.reserve(val_list.i_size()); - for (const auto &val : val_list.i()) { - if (val < kIrOutputTypeEnd) { - out_types.emplace_back(static_cast(val)); - } - } - (void) op_def_proto.mutable_attr()->erase("_ir_outputs_value"); - } - if (out_keys.size() == out_types.size()) { - for (size_t i = 0UL; i < out_keys.size(); ++i) { - op_desc->impl_->MutableIRMeta().AppendIrOutput(std::move(out_keys[i]), out_types[i]); - } - } -} - -bool ModelSerializeImp::UnserializeOpDesc(OpDescPtr &op_desc, proto::OpDef &op_def_proto) const { - std::vector opt_input; - std::vector key_in; - std::vector value_in; - std::vector key_out; - std::vector value_out; - - ExtractMetaDataAttrIn(op_def_proto, opt_input, key_in, value_in); - ExtractMetaDataAttr(op_def_proto, key_out, value_out); - - op_desc = ComGraphMakeShared(op_def_proto); - GE_CHK_BOOL_EXEC(op_desc != nullptr, REPORT_INNER_ERR_MSG("E18888", "create OpDesc failed."); - return false, "[Create][OpDesc] op_desc is nullptr."); - GE_CHK_BOOL_EXEC(op_desc->impl_ != nullptr, REPORT_INNER_ERR_MSG("E18888", "create OpDesc impl failed."); - return false, "[Create][OpDesc] op_desc impl is nullptr."); - // Input tensor - for (auto &input_desc : *op_def_proto.mutable_input_desc()) { - const std::shared_ptr temp_value = ComGraphMakeShared(&input_desc); - GE_CHK_BOOL_EXEC(temp_value != nullptr, REPORT_INNER_ERR_MSG("E18888", "create GeTensorDesc failed."); - return false, "[Create][GeTensorDesc] temp_value is nullptr."); - op_desc->impl_->inputs_desc_.push_back(temp_value); - } - // Output tensor - for (auto &output_desc : *op_def_proto.mutable_output_desc()) { - const std::shared_ptr temp_value = ComGraphMakeShared(&output_desc); - GE_CHK_BOOL_EXEC(temp_value != nullptr, REPORT_INNER_ERR_MSG("E18888", "create GeTensorDesc failed."); - return false, "[Create][GeTensorDesc] temp_value is nullptr."); - op_desc->impl_->outputs_desc_.push_back(temp_value); - } - - op_desc->SetId(op_def_proto.id()); - uint32_t graph_index = 0U; - for (const std::string &name : op_def_proto.subgraph_name()) { - if (!name.empty()) { - (void) op_desc->AddSubgraphName(name); - (void) op_desc->SetSubgraphInstanceName(graph_index++, name); - } - } - - // insert name index by key and value - AttrDefToOpDescIn(op_desc, key_in, value_in); - AttrDefToOpDesc(op_desc, key_out, value_out, opt_input); - AttrDefToOpDescIrDef(op_desc, op_def_proto); - if (!DeserializeAllAttrsToAttrHolder(op_def_proto.attr(), op_desc.get())) { - GELOGE(GRAPH_FAILED, "Opdesc [%s] attr deserialize failed", op_def_proto.name().c_str()); - return false; - } - GE_ASSERT_GRAPH_SUCCESS(AttrGroupSerialize::DeserializeAllAttr(op_def_proto.attr_groups(), op_desc.get())); - return true; -} - -void ModelSerializeImp::ExtractMetaDataAttrIn(proto::OpDef &op_def_proto, std::vector &opt_input, - std::vector &key_in, std::vector &value_in) const { - if (op_def_proto.attr().count("_opt_input") > 0UL) { - const auto &name_list = op_def_proto.attr().at("_opt_input").list(); - for (const auto &item_s : name_list.s()) { - opt_input.push_back(item_s); - } - (void) op_def_proto.mutable_attr()->erase("_opt_input"); - } - if (op_def_proto.attr().count("_input_name_key") > 0UL) { - const auto &output_name_key_list = op_def_proto.attr().at("_input_name_key").list(); - for (const auto &item_s : output_name_key_list.s()) { - key_in.push_back(item_s); - } - (void) op_def_proto.mutable_attr()->erase("_input_name_key"); - } - if (op_def_proto.attr().count("_input_name_value") > 0UL) { - const auto &input_name_value_list = op_def_proto.attr().at("_input_name_value").list(); - for (const auto &item_i : input_name_value_list.i()) { - value_in.push_back(static_cast(item_i)); - } - (void) op_def_proto.mutable_attr()->erase("_input_name_value"); - } -} - -void ModelSerializeImp::ExtractMetaDataAttr(proto::OpDef &op_def_proto, std::vector &key_out, - std::vector &value_out) const { - if (op_def_proto.attr().count("_output_name_key") > 0UL) { - const auto &output_name_key_list = op_def_proto.attr().at("_output_name_key").list(); - for (const auto &item_s : output_name_key_list.s()) { - key_out.push_back(item_s); - } - (void) op_def_proto.mutable_attr()->erase("_output_name_key"); - } - if (op_def_proto.attr().count("_output_name_value") > 0UL) { - const auto &output_name_value_list = op_def_proto.attr().at("_output_name_value").list(); - for (const auto &item_i : output_name_value_list.i()) { - value_out.push_back(static_cast(item_i)); - } - (void) op_def_proto.mutable_attr()->erase("_output_name_value"); - } -} - -bool ModelSerializeImp::UnserializeNode(ComputeGraphPtr &graph, proto::OpDef &op_def_proto) { - GE_RT_FALSE_CHECK_NOTNULL(graph); - OpDescPtr op_desc = nullptr; - if (!UnserializeOpDesc(op_desc, op_def_proto)) { - GELOGE(ge::INTERNAL_ERROR, "[Unserialize][OpDesc] error."); - return false; - } - - const NodePtr node = graph->AddNode(op_desc, op_desc->GetId()); - GE_CHK_BOOL_EXEC(node != nullptr, - REPORT_INNER_ERR_MSG("E18888", "add node to graph:%s failed", graph->GetName().c_str()); - return false, "[Add][Node] to graph:%s failed.", graph->GetName().c_str()); - - std::vector src_out_peer_index; - if (op_def_proto.attr().count(kSrcOutPeerIndex) > 0UL) { - const auto &src_out_peer_index_list = op_def_proto.attr().at(kSrcOutPeerIndex).list(); - for (const auto &item_i : src_out_peer_index_list.i()) { - src_out_peer_index.push_back(static_cast(item_i)); - } - (void)op_def_proto.mutable_attr()->erase(kSrcOutPeerIndex); - } - - // Inputs - int32_t dst_index = 0; - int32_t cur_index = 0; - const size_t input_size = op_def_proto.input().size(); - for (const auto &input : op_def_proto.input()) { - std::string node_name; - int32_t index = 0; - if (ParseNodeIndex(input, node_name, index)) { - int32_t peer_index = static_cast(kInvalidIndex); - if (src_out_peer_index.size() == input_size) { - peer_index = src_out_peer_index[cur_index]; - } - node_input_node_names_.push_back( - NodeNameNodeReq{node_name, index, peer_index, node, dst_index, op_def_proto.name()}); - } - if (index >= 0) { - dst_index++; - } - ++cur_index; - } - node_map_[op_def_proto.name()] = node; - return true; -} - -void ModelSerializeImp::SaveEdgeInfo(const AnchorPtr &src_anchor, const AnchorPtr &dst_anchor, - const int64_t src_out_peer_index, const int64_t cur_index, - std::unordered_map &edges) const { - // old version would be -1 - if (src_out_peer_index >= 0) { - edges[src_anchor].emplace(dst_anchor, src_out_peer_index); - } else { - edges[src_anchor].emplace(dst_anchor, cur_index); - } -} - -bool ModelSerializeImp::LinkEdges(const std::unordered_map &edges) const { - for (const auto &edge : edges) { - for (const auto &out_anchor_index : edge.second) { - GE_ASSERT_SUCCESS(GraphUtils::AddEdge(edge.first, out_anchor_index.first)); - } - } - return true; -} - -bool ModelSerializeImp::HandleNodeNameRef() { - // Edges - std::unordered_map edges; - int64_t cur_index = 0; - for (auto &item : node_input_node_names_) { - ++cur_index; - const auto src_node_it = node_map_.find(item.src_node_name); - GE_ASSERT_TRUE(src_node_it != node_map_.end()); - GE_IF_BOOL_EXEC((src_node_it->second == nullptr) || (item.dst_node == nullptr), continue); - if (item.src_out_index >= 0) { - const auto src_anchor = src_node_it->second->GetOutDataAnchor(item.src_out_index); - const auto dst_anchor = item.dst_node->GetInDataAnchor(item.dst_in_index); - GE_ASSERT_NOTNULL(src_anchor); - GE_ASSERT_NOTNULL(dst_anchor); - SaveEdgeInfo(src_anchor, dst_anchor, item.src_out_peer_index, cur_index, edges); - } else { - // Control edge - const auto src_anchor = src_node_it->second->GetOutControlAnchor(); - const auto dst_anchor = item.dst_node->GetInControlAnchor(); - if ((src_anchor != nullptr) && (dst_anchor != nullptr)) { - SaveEdgeInfo(src_anchor, dst_anchor, item.src_out_peer_index, cur_index, edges); - } - } - } - GE_ASSERT_TRUE(LinkEdges(edges)); - // Graph input - for (auto &item : graph_input_node_names_) { - const std::map::const_iterator node_it = node_map_.find(item.node_name); - GE_ASSERT_TRUE(node_it != node_map_.cend()); - GE_IF_BOOL_EXEC(item.graph == nullptr, continue); - GE_ASSERT_NOTNULL(item.graph->AddInputNode(node_it->second)); - } - // Graph output - for (auto &item : graph_output_node_names_) { - const std::map::const_iterator node_it = node_map_.find(item.node_name); - GE_ASSERT_TRUE(node_it != node_map_.cend()); - - GE_IF_BOOL_EXEC(item.graph == nullptr, continue); - const auto ret = item.graph->AddOutputNodeByIndex(node_it->second, item.index); - GELOGI("node name:%s, item.index:%d", node_it->second->GetName().c_str(), item.index); - GE_ASSERT_NOTNULL(ret); - } - node_input_node_names_.clear(); - graph_input_node_names_.clear(); - graph_output_node_names_.clear(); - node_map_.clear(); - return true; -} - -bool ModelSerializeImp::RebuildOwnership(ComputeGraphPtr &compute_graph, - std::map &subgraphs) const { - std::queue all_graphs; - all_graphs.emplace(compute_graph); - while (!all_graphs.empty()) { - const ComputeGraphPtr graph = all_graphs.front(); - all_graphs.pop(); - - for (const NodePtr &node : graph->GetDirectNode()) { - const OpDescPtr op_desc = node->GetOpDesc(); - for (const std::string &name : op_desc->GetSubgraphInstanceNames()) { - if (name.empty()) { - continue; - } - const auto it = subgraphs.find(name); - if (it == subgraphs.end()) { - REPORT_INNER_ERR_MSG("E18888", "Node:%s, Subgraph:%s not found, num:%zu.", op_desc->GetName().c_str(), - name.c_str(), subgraphs.size()); - GELOGE(GRAPH_FAILED, "[Check][Param] Node:%s, Subgraph:%s not found, num:%zu.", - op_desc->GetName().c_str(), name.c_str(), subgraphs.size()); - return false; - } - - ComputeGraphPtr &subgraph = it->second; - GE_ASSERT_NOTNULL(subgraph); - subgraph->SetParentGraph(graph); - subgraph->SetParentNode(node); - (void)compute_graph->AddSubgraph(subgraph->GetName(), subgraph); - all_graphs.emplace(subgraph); - } - } - } - - return true; -} - -Status ModelSerializeImp::ParallelUnserializeGraph( - std::map &graphs, - ::google::protobuf::RepeatedPtrField &graphs_proto) { - if (graphs_proto.empty()) { - GELOGW("Graph proto is empty"); - return SUCCESS; - } - // 当图个数小于16时,只需要拉起跟子图个数相同的线程数(这里需要拉起的线程数需要除去主线程) - const int32_t thread_num = std::min(graphs_proto.size() - 1, kMaxThreadNum); - GELOGI("Start to unserialize graph with multi thread, thread num[%d], graph num[%d]", - thread_num, graphs_proto.size()); - // 初始化子图表 - for (int32_t idx = 0; idx < graphs_proto.size(); ++idx) { - graphs.emplace(std::make_pair(graphs_proto[idx].name(), nullptr)); - } - std::vector threads; - std::atomic ret{ge::SUCCESS}; - std::atomic doing_num{0}; - auto path = air_path_; - auto func = [&graphs_proto, &path, &ret, &graphs, &doing_num] () { - int32_t cur_num = doing_num.fetch_add(1); - while ((cur_num < graphs_proto.size()) && (ret == ge::SUCCESS)) { - GELOGD("Unserialize graph, id: %ld, graph_name: %s", - cur_num, graphs_proto[cur_num].name().c_str()); - ge::ModelSerializeImp impl; - impl.SetAirModelPath(path); - if (!impl.UnserializeGraph(graphs[graphs_proto[cur_num].name()], graphs_proto[cur_num])) { - GELOGE(ge::FAILED, "Unserialize graph: %ld failed, graph_name: %s", - cur_num, graphs_proto[cur_num].name().c_str()); - ret = ge::PARAM_INVALID; - return; - } - cur_num = doing_num.fetch_add(1); - } - }; - for (int32_t i = 0; i < thread_num; i++) { - threads.emplace_back(std::thread([i, &func]() { - auto thread_name = "ge_dserigrh_" + std::to_string(i); - (void)pthread_setname_np(pthread_self(), thread_name.c_str()); - func(); - })); - } - // 当前线程也利用起来 - func(); - for (auto &t : threads) { - if (t.joinable()) { - t.join(); - } - } - GE_ASSERT_SUCCESS(ret, "Parallel unserialize graph failed."); - return SUCCESS; -} - -Status ModelSerializeImp::UnserializeGraph( - std::map &graphs, - ::google::protobuf::RepeatedPtrField &graphs_proto) { - if (graphs_proto.empty()) { - GELOGW("Graph proto is empty"); - return SUCCESS; - } - GELOGI("Start to unserialize graph, graph num[%d]", graphs_proto.size()); - // 初始化子图表 - for (int32_t idx = 0; idx < graphs_proto.size(); ++idx) { - GELOGD("Unserialize graph, id: %ld, graph_name: %s", - idx, graphs_proto[idx].name().c_str()); - ge::ModelSerializeImp impl; - impl.SetAirModelPath(air_path_); - GE_ASSERT_TRUE(impl.UnserializeGraph(graphs[graphs_proto[idx].name()], graphs_proto[idx]), - "Unserialize graph: %ld failed, graph_name: %s", idx, graphs_proto[idx].name().c_str()); - } - return SUCCESS; -} - -bool ModelSerializeImp::UnserializeModel(Model &model, proto::ModelDef &model_proto, - const bool is_enable_multi_thread) { - model.name_ = model_proto.name(); - model.version_ = model_proto.version(); - model.platform_version_ = model_proto.custom_version(); - // Model属性反序列化 - if (!DeserializeAllAttrsToAttrHolder(model_proto.attr(), &model)) { - GELOGE(GRAPH_FAILED, "Model [%s] deserialize attr failed.", model.GetName().c_str()); - return false; - } - // Model属性组反序列化 - GE_ASSERT_GRAPH_SUCCESS(AttrGroupSerialize::DeserializeAllAttr(model_proto.attr_groups(), &model)); - auto &graphs_proto = *model_proto.mutable_graph(); - std::map graphs; - if (is_enable_multi_thread) { - GE_ASSERT_SUCCESS(ParallelUnserializeGraph(graphs, graphs_proto)); - } else { - GE_ASSERT_SUCCESS(UnserializeGraph(graphs, graphs_proto)); - } - - if (!graphs_proto.empty()) { - // 从图集合中找到根图 - const auto it = graphs.find(graphs_proto[0].name()); - GE_ASSERT_TRUE(it != graphs.end(), "Can not find graph: %s in graph map", - graphs_proto[0].name().c_str()); - model.graph_ = it->second; - // 存在子图的情况下需要构造图直接的关系 - if (graphs.size() > 1) { - GE_ASSERT_TRUE(RebuildOwnership(model.graph_, graphs), - "[Rebuild][GraphOwnerShip] failed"); - } - } - return true; -} - -bool ModelSerializeImp::UnserializeGraphWithoutEdge(ComputeGraphPtr &graph, proto::GraphDef &graph_proto) { - graph = ComGraphMakeShared(graph_proto.name()); - if ((graph == nullptr) || (graph->impl_ == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "create ComputeGraph failed."); - GELOGE(GRAPH_FAILED, "[Create][ComputeGraph] ComputeGraph make shared failed"); - return false; - } - - // Inputs - for (const auto &input : graph_proto.input()) { - std::string node_name; - int32_t index; - if (ParseNodeIndex(input, node_name, index)) { - graph_input_node_names_.push_back(NodeNameGraphReq{node_name, index, graph}); - } - } - // Outputs - for (const auto &output : graph_proto.output()) { - std::string node_name; - int32_t index; - if (ParseNodeIndex(output, node_name, index)) { - graph_output_node_names_.push_back(NodeNameGraphReq{node_name, index, graph}); - } - } - // ComputeGraph 属性反序列化 - if (!DeserializeAllAttrsToAttrHolder(graph_proto.attr(), graph.get())) { - GELOGE(GRAPH_FAILED, "ComputeGraph [%s] deserialize attr failed.", graph->GetName().c_str()); - return false; - } - // ComputeGraph 属性组反序列化 - GE_ASSERT_GRAPH_SUCCESS(AttrGroupSerialize::DeserializeAllAttr(graph_proto.attr_groups(), graph.get())); - for (auto &op_def_proto : *graph_proto.mutable_op()) { - // 还原const的weight到算子proto上 - if ((op_def_proto.type() == CONSTANT) || (op_def_proto.type() == CONSTANTOP)) { - if (!SetWeightForModel(op_def_proto)) { - GELOGE(GRAPH_FAILED, "[Unserialize][Model] Set const weight for node: %s failed", op_def_proto.name().c_str()); - return false; - } - } - // 反序列化算子 - if (!UnserializeNode(graph, op_def_proto)) { - GELOGE(GRAPH_FAILED, "[Unserialize][Node] failed"); - return false; - } - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::UnserializeGraph(ComputeGraphPtr &graph, - proto::GraphDef &graph_proto) { - if (!UnserializeGraphWithoutEdge(graph, graph_proto)) { - GELOGE(GRAPH_FAILED, "[Deserialize][Graph] Deserialize graph without edges failed, graph_name: %s", - graph_proto.name().c_str()); - return false; - } - if (!HandleNodeNameRef()) { - GELOGE(GRAPH_FAILED, "[Call][HandleNodeNameRef] Link Anchor or set graph input or output fail"); - return false; - } - return true; -} - -static bool ReadProtoFromBinaryFile(const uint8_t *const data, const size_t len, - google::protobuf::Message *const proto) { - GE_CHK_BOOL_EXEC(data != nullptr, REPORT_INNER_ERR_MSG("E18888", "param data is nullptr, check invalid."); - return false, "[Check][Param] data is null."); - GE_CHK_BOOL_EXEC(proto != nullptr, REPORT_INNER_ERR_MSG("E18888", "param proto is nullptr, check invalid."); - return false, "[Check][Param] proto is null."); - - google::protobuf::io::CodedInputStream coded_stream(data, static_cast(len)); - // 2048M -1 - coded_stream.SetTotalBytesLimit(INT32_MAX); - if (!proto->ParseFromCodedStream(&coded_stream)) { - REPORT_INNER_ERR_MSG("E18888", "Read proto from BinaryFile failed, len %zu", len); - GELOGE(GRAPH_FAILED, "[Read][Proto] from BinaryFile failed, len %zu", len); - return false; - } - - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeAllAttrsFromAnyMap( - const std::map &attr_map, - google::protobuf::Map *const mutable_attr) { - if (mutable_attr == nullptr) { - GELOGE(GRAPH_FAILED, "mutable_attr is nullptr."); - return false; - } - - for (const auto &attr : attr_map) { - const AnyValue attr_value = attr.second; - const auto value_serializer = AttrSerializerRegistry::GetInstance().GetSerializer(attr_value.GetValueTypeId()); - if (value_serializer == nullptr) { - GELOGE(GRAPH_FAILED, "Get serialized failed,name:[%s] value type:%u.", - attr.first.c_str(), attr_value.GetValueType()); - return false; - } - proto::AttrDef attr_def; - if (value_serializer->Serialize(attr_value, attr_def) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Attr serialized failed, name:[%s].", attr.first.c_str()); - return false; - } - (*mutable_attr)[attr.first] = attr_def; - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::DeserializeAllAttrsToAttrHolder( - const google::protobuf::Map &proto_attr_map, AttrHolder *const attr_holder) { - if (attr_holder == nullptr) { - return false; - } - for (const auto &iter : proto_attr_map) { - // skip not set attribute - if ((iter.second.value_case() == proto::AttrDef::VALUE_NOT_SET) || - ((iter.second.value_case() == proto::AttrDef::kList) && - (iter.second.list().val_type() == ge::proto::AttrDef::ListValue::VT_LIST_NONE))) { - continue; - } - - const auto deserializer = - AttrSerializerRegistry::GetInstance().GetDeserializer(iter.second.value_case()); - if (deserializer == nullptr) { - GELOGE(GRAPH_FAILED, "Get deserialize failed, attr type:[%d].", static_cast(iter.second.value_case())); - return false; - } - AnyValue attr_value; - if (deserializer->Deserialize(iter.second, attr_value) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Attr deserialized failed, name:[%s].", iter.first.c_str()); - return false; - } - - if (attr_holder->SetAttr(iter.first, attr_value) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Set attr [%s] failed.", iter.first.c_str()); - return false; - } - } - return true; -} - -bool ModelSerializeImp::SeparateModelDef(Buffer &buffer, const std::string &path, proto::ModelDef &model_def) const { - for (auto &graph_def : *model_def.mutable_graph()) { - for (auto &op_def : *graph_def.mutable_op()) { - if ((op_def.type() != CONSTANT) && (op_def.type() != CONSTANTOP)) { - continue; - } - auto attr_map = op_def.mutable_attr(); - auto iter = attr_map->find(ATTR_NAME_WEIGHTS); - GE_ASSERT_TRUE(iter != attr_map->end(), "Find attr [%s] of op[%s] failed.", ATTR_NAME_WEIGHTS.c_str(), - op_def.name().c_str()); - auto tensor_def = iter->second.mutable_t(); - if (tensor_def->data().empty()) { - GELOGW("Weight attr of node: %s is empty", op_def.name().c_str()); - continue; - } - auto reuse_iter = attr_map->find(ATTR_NAME_IS_REUSE_EXTERNAL_WEIGHT); - if ((reuse_iter != attr_map->end()) && (reuse_iter->second.b())) { - GELOGD("op:%s of model:%s need reuse external weight, do not need dump weight.", op_def.name().c_str(), - model_def.name().c_str()); - tensor_def->set_data(""); - continue; - } - std::string relative_path; - std::string weight_real_path; - std::string op_tag = op_def.type() + "_" + graph_def.name() + "_" + op_def.name(); - GE_ASSERT_SUCCESS(CreateExternalWeightPath(path, model_def.name(), op_tag, weight_real_path, relative_path), - "[Create][ExternalWeightPath] failed, path:%s, model_name:%s, op_tag:%s.", path.c_str(), - model_def.name().c_str(), op_tag.c_str()); - GELOGD("Create external weight path:%s, model_name:%s, op_tag:%s, weight real path:%s, relative_path:%s", - path.c_str(), model_def.name().c_str(), op_tag.c_str(), weight_real_path.c_str(), relative_path.c_str()); - const char *const data = tensor_def->data().c_str(); - const auto op_name = op_def.name(); - const int64_t length = static_cast(tensor_def->data().length()); - GE_ASSERT_GRAPH_SUCCESS(SaveBinToFile(data, length, weight_real_path), - "Write data of attr [%s] of op[%s] to path[%s] failed.", ATTR_NAME_WEIGHTS.c_str(), - op_name.c_str(), weight_real_path.c_str()); - tensor_def->set_data(""); - // set file attr and length attr - proto::AttrDef file_attr; - file_attr.set_s(relative_path); - attr_map->insert({ATTR_NAME_LOCATION, file_attr}); - proto::AttrDef length_attr; - length_attr.set_i(length); - attr_map->insert({ATTR_NAME_LENGTH, length_attr}); - } - } - buffer = AllocBufferByModelDef(model_def); - return SerializeToBuffer(model_def, buffer); -} - -bool ModelSerializeImp::SerializeToBuffer(const proto::ModelDef &model_def, Buffer &buffer) const { - google::protobuf::io::ArrayOutputStream array_stream(buffer.GetData(), static_cast(buffer.GetSize())); - google::protobuf::io::CodedOutputStream output_stream(&array_stream); - output_stream.SetSerializationDeterministic(true); - return model_def.SerializeToCodedStream(&output_stream); -} - -Buffer ModelSerialize::SerializeModel(const Model &model, const bool not_dump_all) const { - std::string path; - return SerializeModel(model, path, true, not_dump_all); -} - -Buffer ModelSerialize::SerializeSeparateModel(const Model &model, const std::string &path, - const bool not_dump_all) const { - proto::ModelDef model_def; - ModelSerializeImp model_imp; - if (!model_imp.SerializeModel(model, true, &model_def, not_dump_all)) { - return Buffer(); - } - auto buffer = AllocBufferByModelDef(model_def); - if (!model_imp.SeparateModelDef(buffer, path, model_def)) { - return Buffer(); - } - return buffer; -} - -Buffer ModelSerialize::SerializeModel(const Model &model, const std::string &path, - const bool is_need_separate, const bool not_dump_all) const { - proto::ModelDef model_def; - ModelSerializeImp model_imp; - if (!model_imp.SerializeModel(model, &model_def, not_dump_all)) { - return Buffer(); - } - auto buffer = AllocBufferByModelDef(model_def); - // try serialize to buffer - if (model_imp.SerializeToBuffer(model_def, buffer)) { - return buffer; - } - // if is_need_separate is not enable, return failed - if (!is_need_separate) { - GELOGE(GRAPH_FAILED, "[Serialize][Model] Model is larger than 2G, " - "but can not separate in this scenario, you can use external_weight instead"); - return Buffer(); - } - GELOGW("[Serialize][Model] Model could larger than 2G, need separate"); - if (!model_imp.SeparateModelDef(buffer, path, model_def)) { - GELOGW("[Serialize][Model] Serialize to binary failed"); - return Buffer(); - } - return buffer; -} - -Status ModelSerialize::SerializeModel(const Model &model, const bool not_dump_all, proto::ModelDef &model_def) const { - ModelSerializeImp model_imp; - if (!model_imp.SerializeModel(model, true, &model_def, not_dump_all)) { - return FAILED; - } - return SUCCESS; -} - -bool ModelSerializeImp::LoadWeightFromFile(const std::string &file_path, - const int64_t &length, - std::string &weight) const { - if (length <= 0L) { - GELOGE(GRAPH_FAILED, "Value length is less than 0."); - return false; - } - auto bin_data = std::unique_ptr(new(std::nothrow) char_t[length]); - if (bin_data == nullptr) { - GELOGE(FAILED, "[Allocate][Mem]Allocate mem failed"); - return false; - } - std::string air_directory; - std::string air_filename; - SplitFilePath(air_path_, air_directory, air_filename); - std::string weight_path; - if (!air_directory.empty()) { - weight_path = air_directory + "/" + file_path; - } else { - weight_path = file_path; - } - size_t data_len = static_cast(length); - if (GetBinFromFile(weight_path, static_cast(bin_data.get()), data_len) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Get bin from file failed."); - return false; - } - weight = std::string(bin_data.get(), length); - return true; -} - -bool ModelSerializeImp::SetWeightForModel(proto::OpDef &op_def) const { - auto attr_map = op_def.mutable_attr(); - auto iter = attr_map->find(ATTR_NAME_LOCATION); - if (iter == attr_map->end()) { - return true; - } - const std::string file_path = iter->second.s(); - iter = attr_map->find(ATTR_NAME_LENGTH); - if (iter == attr_map->end()) { - return true; - } - const int64_t length = iter->second.i(); - std::string weight; - if (!LoadWeightFromFile(file_path, length, weight)) { - GELOGE(GRAPH_FAILED, "Load weight from path %s failed.", file_path.c_str()); - return false; - } - iter = attr_map->find(ATTR_NAME_WEIGHTS); - GE_ASSERT_TRUE(iter != attr_map->end(), "find attr [%s] of op[%s] failed.", ATTR_NAME_WEIGHTS.c_str(), - op_def.name().c_str()); - attr_map->erase(ATTR_NAME_LOCATION); - attr_map->erase(ATTR_NAME_LENGTH); - iter->second.mutable_t()->set_data(weight); - return true; -} - -bool ModelSerialize::UnserializeModel(const uint8_t *const data, const size_t len, - Model &model, const bool is_enable_multi_thread) const { - if (data == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param data is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] data is nullptr"); - return false; - } - - std::shared_ptr model_proto_ptr; - model_proto_ptr = ComGraphMakeShared(); - if (model_proto_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "create ModelDef failed."); - GELOGE(GRAPH_FAILED, "[Create][ModelDef] proto::ModelDef make shared failed"); - return false; - } - - auto &model_proto = *model_proto_ptr; - if (!ReadProtoFromBinaryFile(data, len, &model_proto)) { - GELOGE(GRAPH_FAILED, "[Read][Proto] from binaryfile failed."); - return false; - } - ModelSerializeImp model_imp; - model_imp.SetProtobufOwner(model_proto_ptr); - if (!model_imp.UnserializeModel(model, model_proto, is_enable_multi_thread)) { - GELOGE(GRAPH_FAILED, "[Unserialize][Model] failed"); - return false; - } - return model.IsValid(); -} - -bool ModelSerialize::UnserializeModel(ge::proto::ModelDef &model_def, Model &model) const { - std::string path; - return UnserializeModel(model_def, model, path); -} - -bool ModelSerialize::UnserializeModel(ge::proto::ModelDef &model_def, Model &model, const std::string &path) const { - const std::shared_ptr model_def_ptr = ComGraphMakeShared(model_def); - GE_CHK_BOOL_EXEC(model_def_ptr != nullptr, REPORT_INNER_ERR_MSG("E18888", "create ModelDef failed."); - return false, "[Create][ModelDef] mode_def make shared failed"); - - ModelSerializeImp model_imp; - model_imp.SetAirModelPath(path); - model_imp.SetProtobufOwner(model_def_ptr); - if (!model_imp.UnserializeModel(model, *model_def_ptr)) { - GELOGE(GRAPH_FAILED, "[Unserialize][Model] fail"); - return false; - } - return model.IsValid(); -} -} // namespace ge diff --git a/graph/serialization/named_attrs_serializer.cc b/graph/serialization/named_attrs_serializer.cc deleted file mode 100644 index 34652b0c7213555d2a7cc90da7f11bc52d923cda..0000000000000000000000000000000000000000 --- a/graph/serialization/named_attrs_serializer.cc +++ /dev/null @@ -1,79 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "named_attrs_serializer.h" -#include "graph/debug/ge_log.h" -#include "graph/debug/ge_util.h" -#include "graph/utils/attr_utils.h" - -namespace ge { -graphStatus NamedAttrsSerializer::Serialize(const AnyValue &av, proto::AttrDef &def) { - ge::NamedAttrs named_attrs; - const graphStatus ret = av.GetValue(named_attrs); - if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "Failed to get named attrs."); - return GRAPH_FAILED; - } - auto func = def.mutable_func(); - - return Serialize(named_attrs, func); -} - -graphStatus NamedAttrsSerializer::Serialize(const ge::NamedAttrs &named_attr, proto::NamedAttrs* proto_attr) const { - GE_CHECK_NOTNULL(proto_attr); - proto_attr->set_name(named_attr.GetName().c_str()); - const auto mutable_attr = proto_attr->mutable_attr(); - GE_CHECK_NOTNULL(mutable_attr); - - const auto attrs = AttrUtils::GetAllAttrs(named_attr); - for (const auto &attr : attrs) { - const AnyValue attr_value = attr.second; - const auto serializer = AttrSerializerRegistry::GetInstance().GetSerializer(attr_value.GetValueTypeId()); - GE_CHECK_NOTNULL(serializer); - proto::AttrDef attr_def; - if (serializer->Serialize(attr_value, attr_def) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Attr serialized failed, name:[%s].", attr.first.c_str()); - return FAILED; - } - (*mutable_attr)[attr.first] = attr_def; - } - return GRAPH_SUCCESS; -} - -graphStatus NamedAttrsSerializer::Deserialize(const proto::AttrDef &def, AnyValue &av) { - ge::NamedAttrs value; - if (Deserialize(def.func(), value) != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - - return av.SetValue(std::move(value)); -} - -graphStatus NamedAttrsSerializer::Deserialize(const proto::NamedAttrs &proto_attr, ge::NamedAttrs &named_attrs) const { - named_attrs.SetName(proto_attr.name()); - const auto proto_attr_map = proto_attr.attr(); - for (const auto &sub_proto_attr : proto_attr_map) { - const auto deserializer = AttrSerializerRegistry::GetInstance().GetDeserializer(sub_proto_attr.second.value_case()); - GE_CHECK_NOTNULL(deserializer); - AnyValue attr_value; - if (deserializer->Deserialize(sub_proto_attr.second, attr_value) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Attr deserialized failed, name:[%s].", sub_proto_attr.first.c_str()); - return FAILED; - } - if (named_attrs.SetAttr(sub_proto_attr.first, attr_value) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "NamedAttrs [%s] set attr [%s] failed.", - named_attrs.GetName().c_str(), sub_proto_attr.first.c_str()); - return GRAPH_FAILED; - } - } - return GRAPH_SUCCESS; -} - -REG_GEIR_SERIALIZER(named_attr_serializer, NamedAttrsSerializer, GetTypeId(), proto::AttrDef::kFunc); -} // namespace ge diff --git a/graph/serialization/named_attrs_serializer.h b/graph/serialization/named_attrs_serializer.h deleted file mode 100644 index 31e93087336bf305a2f6a970f7dac877f0d767f5..0000000000000000000000000000000000000000 --- a/graph/serialization/named_attrs_serializer.h +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_GRAPH_SERIALIZATION_NAMED_ATTRS_SERIALIZER_H_ -#define METADEF_GRAPH_SERIALIZATION_NAMED_ATTRS_SERIALIZER_H_ - -#include "attr_serializer.h" -#include "attr_serializer_registry.h" -#include "proto/ge_ir.pb.h" -#include "graph/ge_attr_value.h" - -namespace ge { -class NamedAttrsSerializer : public GeIrAttrSerializer { - public: - NamedAttrsSerializer() = default; - graphStatus Serialize(const AnyValue &av, proto::AttrDef &def) override; - graphStatus Deserialize(const proto::AttrDef &def, AnyValue &av) override; - - graphStatus Serialize(const ge::NamedAttrs &named_attr, proto::NamedAttrs *proto_attr) const; - graphStatus Deserialize(const proto::NamedAttrs &proto_attr, ge::NamedAttrs &named_attrs) const; -}; -} // namespace ge - -#endif // METADEF_GRAPH_SERIALIZATION_NAMED_ATTRS_SERIALIZER_H_ diff --git a/graph/serialization/string_serializer.cc b/graph/serialization/string_serializer.cc deleted file mode 100644 index 49d48f4e0f96eae09e5761315b39d171c3c41b53..0000000000000000000000000000000000000000 --- a/graph/serialization/string_serializer.cc +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "string_serializer.h" -#include -#include "proto/ge_ir.pb.h" -#include "graph/debug/ge_log.h" - -namespace ge { -graphStatus StringSerializer::Serialize(const AnyValue &av, proto::AttrDef &def) { - std::string value; - const graphStatus ret = av.GetValue(value); - if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "Failed to get string attr."); - return GRAPH_FAILED; - } - def.set_s(value); - return GRAPH_SUCCESS; -} - -graphStatus StringSerializer::Deserialize(const proto::AttrDef &def, AnyValue &av) { - return av.SetValue(def.s()); -} - -REG_GEIR_SERIALIZER(str_serializer, StringSerializer, GetTypeId(), proto::AttrDef::kS); -} // namespace ge diff --git a/graph/serialization/string_serializer.h b/graph/serialization/string_serializer.h deleted file mode 100644 index b92976f6218752ea2640c1d20085d6d5e0df21bb..0000000000000000000000000000000000000000 --- a/graph/serialization/string_serializer.h +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_GRAPH_SERIALIZATION_STRING_SERIALIZER_H_ -#define METADEF_GRAPH_SERIALIZATION_STRING_SERIALIZER_H_ - -#include "attr_serializer.h" -#include "attr_serializer_registry.h" -namespace ge { -class StringSerializer : public GeIrAttrSerializer { - public: - StringSerializer() = default; - graphStatus Serialize(const AnyValue &av, proto::AttrDef &def) override; - graphStatus Deserialize(const proto::AttrDef &def, AnyValue &av) override; -}; -} // namespace ge - -#endif // METADEF_GRAPH_SERIALIZATION_STRING_SERIALIZER_H_ diff --git a/graph/serialization/tensor_desc_serializer.cc b/graph/serialization/tensor_desc_serializer.cc deleted file mode 100644 index c2510caf666d3735b451cde5740399f675f5e569..0000000000000000000000000000000000000000 --- a/graph/serialization/tensor_desc_serializer.cc +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "tensor_desc_serializer.h" - -#include "graph/debug/ge_log.h" -#include "graph/debug/ge_util.h" -#include "graph/utils/attr_utils.h" -#include "graph/debug/ge_util.h" -#include "graph/ge_tensor.h" - -namespace ge { -graphStatus TensorDescSerializer::Serialize(const AnyValue &av, proto::AttrDef &def) { - GeTensorDesc tensor_desc; - const graphStatus ret = av.GetValue(tensor_desc); - if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "Failed to get tensor_desc attr."); - return GRAPH_FAILED; - } - GeTensorSerializeUtils::GeTensorDescAsProto(tensor_desc, def.mutable_td()); - return GRAPH_SUCCESS; -} - -graphStatus TensorDescSerializer::Deserialize(const proto::AttrDef &def, AnyValue &av) { - GeTensorDesc tensor_desc; - const proto::TensorDescriptor &descriptor = def.td(); - GeTensorSerializeUtils::AssembleGeTensorDescFromProto(&descriptor, tensor_desc); - return av.SetValue(std::move(tensor_desc)); -} - -REG_GEIR_SERIALIZER(tensor_desc_serialzier, TensorDescSerializer, GetTypeId(), proto::AttrDef::kTd); -} // namespace ge diff --git a/graph/serialization/tensor_desc_serializer.h b/graph/serialization/tensor_desc_serializer.h deleted file mode 100644 index e9ccaf8f62294c26147669c46137fa4b6aadfe79..0000000000000000000000000000000000000000 --- a/graph/serialization/tensor_desc_serializer.h +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_GRAPH_SERIALIZATION_GE_TENSOR_DESC_SERIALIZER_H_ -#define METADEF_GRAPH_SERIALIZATION_GE_TENSOR_DESC_SERIALIZER_H_ - -#include "attr_serializer.h" -#include "attr_serializer_registry.h" -#include "proto/ge_ir.pb.h" -#include "graph/ge_tensor.h" - -namespace ge { -class TensorDescSerializer : public GeIrAttrSerializer { - public: - TensorDescSerializer() noexcept = default; - graphStatus Serialize(const AnyValue &av, proto::AttrDef &def) override; - graphStatus Deserialize(const proto::AttrDef &def, AnyValue &av) override; -}; -} // namespace ge - -#endif // METADEF_GRAPH_SERIALIZATION_GE_TENSOR_DESC_SERIALIZER_H_ diff --git a/graph/serialization/tensor_serializer.cc b/graph/serialization/tensor_serializer.cc deleted file mode 100644 index 71c6b54ed894565d8a671bf0c4cff4ba0dc50ab6..0000000000000000000000000000000000000000 --- a/graph/serialization/tensor_serializer.cc +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "tensor_serializer.h" -#include "proto/ge_ir.pb.h" -#include "graph/debug/ge_util.h" -#include "graph/debug/ge_log.h" -#include "tensor_desc_serializer.h" -#include "graph/ge_tensor.h" - -namespace ge { -graphStatus TensorSerializer::Serialize(const AnyValue &av, proto::AttrDef &def) { - GeTensor ge_tensor; - const graphStatus ret = av.GetValue(ge_tensor); - if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "Failed to get tensor attr."); - return GRAPH_FAILED; - } - - GeTensorSerializeUtils::GeTensorAsProto(ge_tensor, def.mutable_t()); - return GRAPH_SUCCESS; -} - -graphStatus TensorSerializer::Deserialize(const proto::AttrDef &def, AnyValue &av) { - GeTensor ge_tensor; - GeTensorSerializeUtils::AssembleGeTensorFromProto(&def.t(), ge_tensor); - return av.SetValue(std::move(ge_tensor)); -} - -REG_GEIR_SERIALIZER(tesnor_serializer, TensorSerializer, GetTypeId(), proto::AttrDef::kT); -} // namespace ge diff --git a/graph/serialization/tensor_serializer.h b/graph/serialization/tensor_serializer.h deleted file mode 100644 index ade26f39a1fc6682a40e9fe2eec240342cc6a8e6..0000000000000000000000000000000000000000 --- a/graph/serialization/tensor_serializer.h +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_GRAPH_SERIALIZATION_GE_TENSOR_SERIALIZER_H_ -#define METADEF_GRAPH_SERIALIZATION_GE_TENSOR_SERIALIZER_H_ - -#include "attr_serializer_registry.h" -#include "graph/ge_tensor.h" -#include "attr_serializer.h" - -namespace ge { -class TensorSerializer : public GeIrAttrSerializer { - public: - TensorSerializer() noexcept = default; - graphStatus Serialize(const AnyValue &av, proto::AttrDef &def) override; - graphStatus Deserialize(const proto::AttrDef &def, AnyValue &av) override; -}; -} // namespace ge - -#endif // METADEF_GRAPH_SERIALIZATION_GE_TENSOR_SERIALIZER_H_ diff --git a/graph/serialization/utils/serialization_util.cc b/graph/serialization/utils/serialization_util.cc deleted file mode 100644 index d0970ec5a5d91eeb0d1da8ec82177839456a551f..0000000000000000000000000000000000000000 --- a/graph/serialization/utils/serialization_util.cc +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "serialization_util.h" -#include - -namespace ge { -const std::map kDataTypeMap = { - {DT_UNDEFINED, proto::DT_UNDEFINED}, - {DT_FLOAT, proto::DT_FLOAT}, - {DT_FLOAT16, proto::DT_FLOAT16}, - {DT_INT8, proto::DT_INT8}, - {DT_UINT8, proto::DT_UINT8}, - {DT_INT16, proto::DT_INT16}, - {DT_UINT16, proto::DT_UINT16}, - {DT_INT32, proto::DT_INT32}, - {DT_INT64, proto::DT_INT64}, - {DT_UINT32, proto::DT_UINT32}, - {DT_UINT64, proto::DT_UINT64}, - {DT_BOOL, proto::DT_BOOL}, - {DT_DOUBLE, proto::DT_DOUBLE}, - {DT_DUAL, proto::DT_DUAL}, - {DT_DUAL_SUB_INT8, proto::DT_DUAL_SUB_INT8}, - {DT_DUAL_SUB_UINT8, proto::DT_DUAL_SUB_UINT8}, - {DT_COMPLEX32, proto::DT_COMPLEX32}, - {DT_COMPLEX64, proto::DT_COMPLEX64}, - {DT_COMPLEX128, proto::DT_COMPLEX128}, - {DT_QINT8, proto::DT_QINT8}, - {DT_QINT16, proto::DT_QINT16}, - {DT_QINT32, proto::DT_QINT32}, - {DT_QUINT8, proto::DT_QUINT8}, - {DT_QUINT16, proto::DT_QUINT16}, - {DT_RESOURCE, proto::DT_RESOURCE}, - {DT_STRING_REF, proto::DT_STRING_REF}, - {DT_STRING, proto::DT_STRING}, - {DT_VARIANT, proto::DT_VARIANT}, - {DT_BF16, proto::DT_BF16}, - {DT_INT4, proto::DT_INT4}, - {DT_UINT1, proto::DT_UINT1}, - {DT_INT2, proto::DT_INT2}, - {DT_UINT2, proto::DT_UINT2}, - {DT_HIFLOAT8, proto::DT_HIFLOAT8}, - {DT_FLOAT8_E5M2, proto::DT_FLOAT8_E5M2}, - {DT_FLOAT8_E4M3FN, proto::DT_FLOAT8_E4M3FN}, - {DT_FLOAT8_E8M0, proto::DT_FLOAT8_E8M0}, - {DT_FLOAT6_E3M2, proto::DT_FLOAT6_E3M2}, - {DT_FLOAT6_E2M3, proto::DT_FLOAT6_E2M3}, - {DT_FLOAT4_E2M1, proto::DT_FLOAT4_E2M1}, - {DT_FLOAT4_E1M2, proto::DT_FLOAT4_E1M2} -}; - -void SerializationUtil::GeDataTypeToProto(const ge::DataType ge_type, proto::DataType &proto_type) { - auto iter = kDataTypeMap.find(ge_type); - if (iter != kDataTypeMap.end()) { - proto_type = iter->second; - return; - } - proto_type = proto::DT_UNDEFINED; -} - -void SerializationUtil::ProtoDataTypeToGe(const proto::DataType proto_type, ge::DataType &ge_type) { - for (auto iter : kDataTypeMap) { - if (iter.second == proto_type) { - ge_type = iter.first; - return; - } - } - ge_type = DT_UNDEFINED; -} -} // namespace ge diff --git a/graph/serialization/utils/serialization_util.h b/graph/serialization/utils/serialization_util.h deleted file mode 100644 index 9c19d71b1d9454b5d92158758fc0138fdca09180..0000000000000000000000000000000000000000 --- a/graph/serialization/utils/serialization_util.h +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_GRAPH_SERIALIZATION_UTILS_SERIALIZATION_UTIL_H_ -#define METADEF_GRAPH_SERIALIZATION_UTILS_SERIALIZATION_UTIL_H_ - -#include "proto/ge_ir.pb.h" -#include "graph/types.h" - -namespace ge { -class SerializationUtil { - public: - static void GeDataTypeToProto(const ge::DataType ge_type, proto::DataType &proto_type); - static void ProtoDataTypeToGe(const proto::DataType proto_type, ge::DataType &ge_type); - private: - SerializationUtil() = delete; -}; -} // namespace ge - -#endif // METADEF_GRAPH_SERIALIZATION_UTILS_SERIALIZATION_UTIL_H_ diff --git a/graph/stub/Makefile b/graph/stub/Makefile deleted file mode 100644 index 72610287c4475b6cb99f8540dab68429a4b1fcc8..0000000000000000000000000000000000000000 --- a/graph/stub/Makefile +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) 2024 Huawei Technologies Co., Ltd. -# This file is a part of the CANN Open Software. -# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. -# ====================================================================================================================== - -inc_path := $(shell pwd)/metadef/inc/external/ -out_path := $(shell pwd)/out/graph/lib64/stub/ -stub_path := $(shell pwd)/metadef/graph/stub/ - -mkdir_stub := $(shell mkdir -p $(out_path)) -graph_local_stub := $(shell $(HI_PYTHON) $(stub_path)/gen_stubapi.py $(inc_path) $(out_path)) diff --git a/graph/type/ascend_string.cc b/graph/type/ascend_string.cc deleted file mode 100644 index 90f68cd5bb9940d32b04d0eef84af34a48b13315..0000000000000000000000000000000000000000 --- a/graph/type/ascend_string.cc +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "external/graph/ascend_string.h" -#include "debug/ge_log.h" -#include "common/util/mem_utils.h" -#include "base/type/ascend_string_impl.h" - -namespace ge { -AscendString::AscendString(const char_t *const name) { - AscendStringImpl::Construct(*this, name); -} - -AscendString::AscendString(const char_t *const name, size_t length) { - AscendStringImpl::Construct(*this, name, length); -} - -const char_t *AscendString::GetString() const { - return AscendStringImpl::GetString(*this); -} - -size_t AscendString::GetLength() const { - return AscendStringImpl::GetLength(*this); -} - -size_t AscendString::Hash() const { - return AscendStringImpl::Hash(*this); -} - -bool AscendString::operator<(const AscendString &d) const { - return AscendStringImpl::Lt(*this, d); -} - -bool AscendString::operator>(const AscendString &d) const { - return AscendStringImpl::Gt(*this, d); -} - -bool AscendString::operator==(const AscendString &d) const { - return AscendStringImpl::Eq(*this, d); -} - -bool AscendString::operator<=(const AscendString &d) const { - return AscendStringImpl::Le(*this, d); -} - -bool AscendString::operator>=(const AscendString &d) const { - return AscendStringImpl::Ge(*this, d); -} - -bool AscendString::operator!=(const AscendString &d) const { - return AscendStringImpl::Ne(*this, d); -} - -bool AscendString::operator==(const char_t *const d) const { - return AscendStringImpl::Eq(*this, d); -} - -bool AscendString::operator!=(const char_t *const d) const { - return AscendStringImpl::Ne(*this, d); -} - -size_t AscendString::Find(const AscendString &ascend_string) const { - return AscendStringImpl::Find(*this, ascend_string); -} -} // namespace ge diff --git a/graph/type/axis_type_info.cc b/graph/type/axis_type_info.cc deleted file mode 100644 index 7c76a86a99dcc09917176f0eab343e826cdaeaf1..0000000000000000000000000000000000000000 --- a/graph/type/axis_type_info.cc +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "inc/graph/axis_type_info.h" - -namespace ge { -void AxisTypeInfo::AddInputCutInfo(CutInfo &input_cut_info) { - relate_inputs_.emplace_back(input_cut_info); -} - -void AxisTypeInfo::AddOutputCutInfo(CutInfo &output_cut_info) { - relate_outputs_.emplace_back(output_cut_info); -} - -graphStatus AxisTypeInfo::GetInputCutInfo(const size_t index, CutInfo &input_cut_info) const { - return DoGetCutInfo(relate_inputs_, index, input_cut_info); -} - -graphStatus AxisTypeInfo::GetOutputCutInfo(const size_t index, CutInfo &output_cut_info) const { - return DoGetCutInfo(relate_outputs_, index, output_cut_info); -} - -void AxisTypeInfo::AddInputValueCutInfo(const CutInfo &cut_info) { - relate_input_values_.emplace_back(cut_info); -} - -void AxisTypeInfo::AddOutputValueCutInfo(const CutInfo &cut_info) { - relate_output_values_.emplace_back(cut_info); -} - -graphStatus AxisTypeInfo::GetInputValueCutInfo(const size_t index, CutInfo &cut_info) const { - return DoGetCutInfo(relate_input_values_, index, cut_info); -} - -graphStatus AxisTypeInfo::GetOutputValueCutInfo(const size_t index, CutInfo &cut_info) const { - return DoGetCutInfo(relate_output_values_, index, cut_info); -} - -graphStatus AxisTypeInfo::DoGetCutInfo(const std::vector &cut_infos, - const size_t index, - CutInfo &cut_info) { - if (cut_infos.size() <= index) { - return GRAPH_FAILED; - } - cut_info = cut_infos[index]; - return GRAPH_SUCCESS; -} -} diff --git a/graph/type/sym_dtype.cc b/graph/type/sym_dtype.cc deleted file mode 100644 index c0e4dcaf24b744cf23d483ac89eb17bf071072d6..0000000000000000000000000000000000000000 --- a/graph/type/sym_dtype.cc +++ /dev/null @@ -1,723 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/type/sym_dtype.h" -#include "common/checker.h" -#include "graph/utils/attr_utils.h" -#include "graph/type/tensor_type_impl.h" -#include "graph/types.h" -#include "graph/utils/type_utils.h" -#include "op_common/data_type_utils.h" - -namespace ge { - -namespace { -graphStatus GetDtypeFromAttr(const OpDesc &op, const std::string &attr, DataType &dtype) { - GELOGI("Trying get dtype from attr %s of op %s", attr.c_str(), op.GetName().c_str()); - if (AttrUtils::GetDataType(op, attr, dtype)) { - return GRAPH_SUCCESS; - } - int32_t numeric_dtype = -1; - if (AttrUtils::GetInt(op, attr, numeric_dtype)) { - GE_WARN_ASSERT(numeric_dtype >= 0 && numeric_dtype < DT_MAX, "Invalid numeric dtype %d for sym %s of op %s", - numeric_dtype, attr.c_str(), op.GetName().c_str()); - dtype = static_cast(numeric_dtype); - return GRAPH_SUCCESS; - } - GELOGW("Op %s has no attr named %s", op.GetName().c_str(), attr.c_str()); - return GRAPH_FAILED; -} - -graphStatus GetListDtypeFromAttr(const OpDesc &op, const std::string &attr, std::vector &dtypes) { - GELOGI("Trying get list-dtype from attr %s of op %s", attr.c_str(), op.GetName().c_str()); - if (AttrUtils::GetListDataType(op, attr, dtypes)) { - return GRAPH_SUCCESS; - } - std::vector numeric_dtypes; - if (AttrUtils::GetListInt(op, attr, numeric_dtypes)) { - for (auto &numeric_dtype : numeric_dtypes) { - GE_WARN_ASSERT(numeric_dtype >= 0 && numeric_dtype < DT_MAX, "Invalid numeric dtype %d for sym %s of op %s", - numeric_dtype, attr.c_str(), op.GetName().c_str()); - dtypes.push_back(static_cast(numeric_dtype)); - } - return GRAPH_SUCCESS; - } - GELOGW("Op %s has no attr named %s", op.GetName().c_str(), attr.c_str()); - return GRAPH_FAILED; -} - -std::string ToString(const TensorType &types) { - std::string s = "["; - for (auto &dtype : types.tensor_type_impl_->GetMutableDateTypeSet()) { - s += TypeUtils::DataTypeToSerialString(dtype); - s += ","; - } - s += "]"; - return s; -} - -const char *ToString(const IrInputType &type) { - if (type == kIrInputRequired) { - return "Required"; - } - if (type == kIrInputOptional) { - return "Optional"; - } - if (type == kIrInputDynamic) { - return "Dynamic"; - } - return "Unknown"; -} - -graphStatus PromoteDtype(const DataType &left, const DataType &right, DataType &promoted_dtype) { - GE_WARN_ASSERT(left >= 0 && left < DT_MAX, "Invalid left dtype %d", left); - GE_WARN_ASSERT(right >= 0 && right < DT_MAX, "Invalid right dtype %d", right); - - promoted_dtype = opcommon::PromoteType(left, right); - GELOGD("Promoted dtype %s from %s and %s", TypeUtils::DataTypeToSerialString(promoted_dtype).c_str(), - TypeUtils::DataTypeToSerialString(left).c_str(), TypeUtils::DataTypeToSerialString(right).c_str()); - return GRAPH_SUCCESS; -} - -graphStatus PromoteDtype(const TypeOrTypes &left, const TypeOrTypes &right, TypeOrTypes &promoted_dtype) { - GE_WARN_ASSERT(left.IsListType() == right.IsListType(), "Trying promote %s with %s", left.DebugString().c_str(), - right.DebugString().c_str()); - if (left.IsListType()) { - std::vector left_dtypes; - std::vector right_dtypes; - - GE_WARN_ASSERT_GRAPH_SUCCESS(left.GetTypes(left_dtypes)); - GE_WARN_ASSERT_GRAPH_SUCCESS(right.GetTypes(right_dtypes)); - - GE_WARN_ASSERT(left_dtypes.size() == right_dtypes.size(), "Trying promote %s with %s", left.DebugString().c_str(), - right.DebugString().c_str()); - - std::vector data_types; - data_types.resize(left_dtypes.size()); - for (size_t i = 0U; i < left_dtypes.size(); i++) { - GE_WARN_ASSERT_GRAPH_SUCCESS(PromoteDtype(left_dtypes[i], right_dtypes[i], data_types[i])); - } - promoted_dtype.SetTypes(data_types); - return GRAPH_SUCCESS; - } - - DataType left_dtype; - DataType right_dtype; - GE_WARN_ASSERT_GRAPH_SUCCESS(left.GetType(left_dtype)); - GE_WARN_ASSERT_GRAPH_SUCCESS(right.GetType(right_dtype)); - - DataType dtype; - GE_WARN_ASSERT_GRAPH_SUCCESS(PromoteDtype(left_dtype, right_dtype, dtype)); - promoted_dtype.SetType(dtype); - return GRAPH_SUCCESS; -} -} // namespace - -graphStatus TypeOrTypes::GetType(DataType &type) const { - if (!initialized_ || is_list_ || types_.empty()) { - return GRAPH_FAILED; - } - type = types_[0]; - return GRAPH_SUCCESS; -} - -graphStatus TypeOrTypes::GetTypes(std::vector &types) const { - if (!initialized_ || !is_list_) { - return GRAPH_FAILED; - } - types = types_; - return GRAPH_SUCCESS; -} - -const DataType &TypeOrTypes::UnsafeGetType() const { - if (!initialized_ || is_list_ || (types_.size() != 1)) { - const static DataType kUndefined = DT_UNDEFINED; - return kUndefined; - } - return types_[0]; -} - -const std::vector &TypeOrTypes::UnsafeGetTypes() const { - if (!initialized_ || !is_list_) { - const static std::vector kUndefined{}; - return kUndefined; - } - return types_; -} - -void TypeOrTypes::SetType(const DataType &type) { - initialized_ = true; - is_list_ = false; - types_.clear(); - types_.emplace_back(type); -} - -void TypeOrTypes::SetTypes(const std::vector &types) { - initialized_ = true; - is_list_ = true; - types_ = types; -} - -std::string TypeOrTypes::DebugString() const { - if (!initialized_) { - return "Uninitialized"; - } - std::string ret = is_list_ ? "List[" : ""; - for (auto &type : types_) { - ret += TypeUtils::DataTypeToSerialString(type) + ","; - } - if (is_list_) { - ret += "]"; - } - return ret; -} - -// 不使用DATATYPE指定sym的取值范围时,sym的取值范围为所有数据类型 -SymDtype::SymDtype(const std::string &id) - : id_(id), - is_legacy_(true), - is_list_(false), - tensor_type_({}), - is_ordered_list_(false), - ordered_tensor_type_list_({}) {} - -const std::string &SymDtype::Id() const { - return id_; -} - -bool SymDtype::IsLegacy() const { - return is_legacy_; -} - -void SymDtype::BindIrInput(const std::string &ir_input, const IrInputType &input_type, size_t input_index) { - ir_inputs_.emplace_back(ir_input, input_type, input_index); -} - -void SymDtype::BindAllowedDtypes(const TensorType &types) { - is_legacy_ = false; - is_list_ = false; - tensor_type_ = types; -} - -void SymDtype::BindAllowedDtypes(const ListTensorType &types) { - is_legacy_ = false; - is_list_ = true; - tensor_type_ = types.tensor_type; -} - -void SymDtype::BindExpression(const std::shared_ptr &expression) { - is_legacy_ = false; - expression_ = expression; -} - -bool SymDtype::IsListType() const { - if (expression_ != nullptr) { - return expression_->IsListType(); - } - return is_list_; -} - -const std::string &SymDtype::DebugString() const { - std::string ret = id_ + ":"; - ret += (is_list_ ? "List" : "Oneof"); - ret += ToString(tensor_type_); - return id_; -} - -std::vector SymDtype::GetIrInputIndexes() const { - if (expression_ == nullptr) { - std::vector ir_input_indexes; - for (const auto &ir_input : ir_inputs_) { - ir_input_indexes.push_back(ir_input.index); - } - return ir_input_indexes; - } else { - return expression_->GetIrInputIndexes(); - } -} - -ExpressionType SymDtype::Type() const { - if (expression_ == nullptr) { - return ExpressionType::kSingle; - } else { - return expression_->Type(); - } -} - -graphStatus SymDtype::Eval(const OpDesc &op, TypeOrTypes &type_or_types) const { - GE_WARN_ASSERT(!is_legacy_, "Trying eval legacy sym dtype %s", id_.c_str()); - if (expression_ != nullptr) { - GELOGI("Eval sym dtype from expression of op %s", id_.c_str(), op.GetType().c_str()); - return expression_->Eval(op, type_or_types); - } - - if (IsListType()) { - std::vector dtypes; - GE_WARN_ASSERT_GRAPH_SUCCESS(Eval(op, dtypes)); - type_or_types.SetTypes(dtypes); - return GRAPH_SUCCESS; - } - - DataType single_dtype; - GE_WARN_ASSERT_GRAPH_SUCCESS(Eval(op, single_dtype)); - type_or_types.SetType(single_dtype); - return GRAPH_SUCCESS; -} - -std::string SymDtype::SymBackend::DebugString() const { - return std::string(ToString(type)) + "[" + std::to_string(index) + "] " + name; -} - -graphStatus SymDtype::Eval(const OpDesc &op, DataType &dtype) const { - if (ir_inputs_.empty()) { - GELOGI("Trying eval sym dtype from attr %s of op %s", id_.c_str(), op.GetType().c_str()); - if (AttrUtils::HasAttr(op, id_)) { - GE_WARN_ASSERT_GRAPH_SUCCESS(GetDtypeFromAttr(op, id_, dtype)); - GE_WARN_ASSERT(tensor_type_.tensor_type_impl_->IsDataTypeInRange(dtype)); - return GRAPH_SUCCESS; - } - GE_WARN_ASSERT(tensor_type_.tensor_type_impl_->GetMutableDateTypeSet().size() == 1, - "Op %s has no attr %s and sym %s allowed dtypes range is not one", op.GetType().c_str(), - id_.c_str()); - dtype = *tensor_type_.tensor_type_impl_->GetMutableDateTypeSet().begin(); - return GRAPH_SUCCESS; - } - - std::map> ir_input_2_range; - GE_WARN_ASSERT_GRAPH_SUCCESS(GetIrInputRawDescRange(const_cast(&op)->shared_from_this(), ir_input_2_range)); - GE_WARN_ASSERT(ir_input_2_range.size() == op.GetIrInputsSize(), "Failed get input instance info of %s %s", - op.GetName().c_str(), op.GetType().c_str()); - - std::set infered_dtypes; - for (auto &backend : ir_inputs_) { - auto &input_range = ir_input_2_range[backend.index]; - size_t start = input_range.first; - size_t end = input_range.first + input_range.second; - GELOGD("Sym %s of %s backend %s mapping to input desc[%zu:%zu)", id_.c_str(), op.GetName().c_str(), - backend.DebugString().c_str(), start, end); - - for (size_t i = start; i < end; i++) { - auto desc = op.MutableInputDesc(i); - GE_ASSERT_NOTNULL(desc); - GELOGI("Get dtype %s from %s input %s:%zu of op %s", - TypeUtils::DataTypeToSerialString(desc->GetDataType()).c_str(), ToString(backend.type), - backend.name.c_str(), i - start, op.GetName().c_str()); - infered_dtypes.insert(desc->GetDataType()); - } - } - - GE_WARN_ASSERT(infered_dtypes.size() == 1, "Infer dtype failed for op %s as %zu types infered", op.GetName().c_str(), - infered_dtypes.size()); - dtype = *infered_dtypes.begin(); - if (!tensor_type_.tensor_type_impl_->IsDataTypeInRange(dtype)) { - REPORT_INNER_ERR_MSG("EZ9999", "Sym %s of op %s %s infered dtype %s not in range %s", id_.c_str(), - op.GetName().c_str(), op.GetType().c_str(), TypeUtils::DataTypeToSerialString(dtype).c_str(), - ToString(tensor_type_).c_str()); - GELOGW("Sym %s infered dtype %s not in range %s", - id_.c_str(), TypeUtils::DataTypeToSerialString(dtype).c_str(), ToString(tensor_type_).c_str()); - return PARAM_INVALID; - } - return GRAPH_SUCCESS; -} - -graphStatus SymDtype::Eval(const OpDesc &op, std::vector &dtypes) const { - if (ir_inputs_.empty()) { - GELOGI("Eval sym list-dtype from attr %s of op %s", id_.c_str(), op.GetType().c_str()); - GE_WARN_ASSERT_GRAPH_SUCCESS(GetListDtypeFromAttr(op, id_, dtypes)); - for (auto &dtype : dtypes) { - GE_WARN_ASSERT(tensor_type_.tensor_type_impl_->IsDataTypeInRange(dtype), - "Sym %s infered one of list-dtype %s not in range %s", id_.c_str(), - TypeUtils::DataTypeToSerialString(dtype).c_str(), ToString(tensor_type_).c_str()); - } - return GRAPH_SUCCESS; - } - - std::map> ir_input_2_range; - GE_WARN_ASSERT_GRAPH_SUCCESS(GetIrInputRawDescRange(const_cast(&op)->shared_from_this(), ir_input_2_range)); - GE_WARN_ASSERT(ir_input_2_range.size() == op.GetIrInputsSize(), "Failed get input instance info of %s %s", - op.GetName().c_str(), op.GetType().c_str()); - - for (auto &backend : ir_inputs_) { - GE_WARN_ASSERT(backend.type == kIrInputDynamic, "List-type sym %s can not bind to %s input %s", id_.c_str(), - ToString(backend.type), backend.name.c_str()); - auto &input_range = ir_input_2_range[backend.index]; - size_t start = input_range.first; - size_t end = input_range.first + input_range.second; - GELOGD("Sym %s of %s backend %s mapping to input desc[%zu:%zu)", id_.c_str(), op.GetName().c_str(), - backend.DebugString().c_str(), start, end); - - std::vector input_dtypes; - for (size_t i = start; i < end; i++) { - auto desc = op.MutableInputDesc(i); - GE_ASSERT_NOTNULL(desc); - GELOGI("Get dtype %s from dynamic input %s:%zu of op %s", - TypeUtils::DataTypeToSerialString(desc->GetDataType()).c_str(), backend.name.c_str(), i - start, - op.GetName().c_str()); - input_dtypes.push_back(desc->GetDataType()); - } - - if (dtypes.empty()) { - dtypes = input_dtypes; - } else { - GE_WARN_ASSERT(input_dtypes.size() == dtypes.size(), "Infer dtype size mismatch %zu vs. %zu", input_dtypes.size(), - dtypes.size()); - for (size_t i = 0U; i < input_dtypes.size(); i++) { - GE_WARN_ASSERT(input_dtypes[i] == dtypes[i], "Sym list-dtype mismatch"); - } - } - } - - for (auto &dtype : dtypes) { - GE_WARN_ASSERT(tensor_type_.tensor_type_impl_->IsDataTypeInRange(dtype), - "Sym %s infered list-dtype %s not in range %s", id_.c_str(), - TypeUtils::DataTypeToSerialString(dtype).c_str(), ToString(tensor_type_).c_str()); - } - - return GRAPH_SUCCESS; -} - -void SymDtype::BindAllowedOrderedDtypes(const OrderedTensorTypeList &types) { - is_legacy_ = false; - is_list_ = true; - is_ordered_list_ = true; - ordered_tensor_type_list_ = types; -} - -PromotionSymDtypeExpression::PromotionSymDtypeExpression(const std::vector &syms) : syms_(syms) {} - -graphStatus PromotionSymDtypeExpression::Eval(const OpDesc &op, TypeOrTypes &type_or_types) const { - GE_WARN_ASSERT(syms_.size() > 1U, "Trying eval promotion sym with %zu syms", syms_.size()); - - GE_WARN_ASSERT_GRAPH_SUCCESS(syms_[0]->Eval(op, type_or_types)); - GELOGI("Promoting start with %s from sym %s", type_or_types.DebugString().c_str(), syms_[0]->DebugString().c_str()); - - TypeOrTypes next; - for (size_t i = 1U; i < syms_.size(); i++) { - GE_WARN_ASSERT_GRAPH_SUCCESS(syms_[i]->Eval(op, next)); - GELOGI("Promoting %s with %s from sym %s", type_or_types.DebugString().c_str(), next.DebugString().c_str(), - syms_[i]->DebugString().c_str()); - GE_WARN_ASSERT_GRAPH_SUCCESS(PromoteDtype(type_or_types, next, type_or_types)); - } - - return GRAPH_SUCCESS; -} - -ExpressionType PromotionSymDtypeExpression::Type() const { - return ExpressionType::kPromote; -} - -std::vector PromotionSymDtypeExpression::GetIrInputIndexes() const { - std::vector ir_input_indexes; - for (const auto &sym : syms_) { - const auto sym_ir_input_indexs = sym->GetIrInputIndexes(); - for (const auto sym_ir_input : sym_ir_input_indexs) { - ir_input_indexes.push_back(sym_ir_input); - } - } - return ir_input_indexes; -} - -namespace { -class DescEnv { - public: - DescEnv(const OpDescPtr &op, bool for_input) : op_(op), for_input_(for_input) {} - ~DescEnv() = default; - - bool IsDescValid(uint32_t index) const { - return for_input_ ? (op_->MutableInputDesc(index) != nullptr) : (op_->MutableOutputDesc(index) != nullptr); - } - - size_t NumDescs() const { - return for_input_ ? op_->GetAllInputsSize() : op_->GetOutputsSize(); - } - - std::string DebugString() const { - std::string str = "Env for "; - str += op_->GetName() + " "; - str += op_->GetType() + " "; - str += for_input_ ? "input" : "output"; - return str; - } - - private: - const OpDescPtr &op_; - bool for_input_; -}; -class IrIOSpec { - public: - IrIOSpec(const std::string &name, const IrInputType &type) { - name_ = name; - is_input_ = true; - if (type == kIrInputDynamic) { - is_dynamic_ = true; - } else if (type == kIrInputOptional) { - is_optional_ = true; - } else if (type == kIrInputRequired) { - is_required_ = true; - } else { - is_valid_ = false; - } - } - - IrIOSpec(const std::string &name, const IrOutputType &type) { - name_ = name; - if (type == kIrOutputDynamic) { - is_dynamic_ = true; - } else if (type == kIrOutputRequired) { - is_required_ = true; - } else { - is_valid_ = false; - } - } - ~IrIOSpec() = default; - - const std::string &GetName() const { - return name_; - } - std::string DebugString() const { - std::string str = (is_dynamic_ ? "Dynamic " : is_optional_ ? "Optional " : is_required_ ? "Required " : "Invalid "); - str += is_input_ ? "input " : "output "; - str += name_; - return str; - } - bool IsValid() const { - return is_valid_; - } - bool IsDynamic() const { - return is_dynamic_; - } - bool IsOptional() const { - return is_optional_; - } - bool IsRequired() const { - return is_required_; - } - - private: - std::string name_; - bool is_input_ = false; - bool is_valid_ = true; - bool is_dynamic_ = false; - bool is_optional_ = false; - bool is_required_ = false; -}; - -// 对于空的Dynamic输入和未传值的Optional输入,计算其起始index以展示更为友好 -size_t GetIrDescStartIndex(std::map> &ir_2_range, size_t ir_index) { - if (ir_index == 0U) { - return 0U; - } - - auto iter = ir_2_range.find(ir_index - 1U); - if (iter == ir_2_range.end()) { - return 0U; - } - return iter->second.first + iter->second.second; -} - -graphStatus MappingDynamicIrDesc(const std::vector &ir_specs, const DescEnv &desc_env, - const std::map &name2idx, - std::map> &ir_2_range) { - GELOGD("Start mapping dynamic ir desc for %s", desc_env.DebugString().c_str()); - for (size_t ir_io_idx = 0U; ir_io_idx < ir_specs.size(); ir_io_idx++) { - const auto &ir_spec = ir_specs[ir_io_idx]; - GE_WARN_ASSERT(ir_spec.IsValid(), "Invalid ir spec %s", ir_spec.DebugString().c_str()); - if (!ir_spec.IsDynamic()) { // 优先处理Dynamic类型的IR输入 - continue; - } - std::set indexes; // Dynamic类型的IR输入对应的多个index - size_t num_instances = 0U; - for (; num_instances < name2idx.size(); num_instances++) { - auto iter = name2idx.find(ir_spec.GetName() + std::to_string(num_instances)); - if (iter == name2idx.end()) { - break; - } - indexes.insert(iter->second); - } - // 校验Dynamic类型的IR IO对应的多个index连续 - GE_WARN_ASSERT((indexes.size() <= 1U) || (*indexes.rbegin() - *indexes.begin() == (indexes.size() - 1U))); - if (indexes.empty()) { - GELOGD("Dynamic ir spec %s has no instance", ir_spec.DebugString().c_str()); - ir_2_range.emplace(ir_io_idx, std::make_pair(GetIrDescStartIndex(ir_2_range, ir_io_idx), 0U)); - } else { - ir_2_range.emplace(ir_io_idx, std::make_pair(*indexes.begin(), indexes.size())); - GELOGD("Mapping %s to desc[%zu, %zu)", ir_spec.DebugString().c_str(), *indexes.begin(), - *indexes.begin() + indexes.size()); - } - } - return GRAPH_SUCCESS; -} - -void UpdateRawDescInstanceShifts(std::vector &desc_instance_shifts, size_t elim_index) { - if (elim_index >= desc_instance_shifts.size()) { - return; - } - auto iter = desc_instance_shifts.begin() + elim_index + 1U; - for (; iter != desc_instance_shifts.end(); iter++) { - (*iter)++; - } -} - -graphStatus MappingNonDynamicIrDesc(const std::vector &ir_specs, const DescEnv &desc_env, - const std::vector> &name2index_left, - const bool &require_raw_index, - std::map> &ir_2_range) { - GELOGD("Start mapping non-dynamic ir desc for %s", desc_env.DebugString().c_str()); - std::vector desc_instance_shifts; - desc_instance_shifts.resize(desc_env.NumDescs(), 0U); - - auto iter = name2index_left.begin(); - for (size_t ir_io_idx = 0U; ir_io_idx < ir_specs.size(); ir_io_idx++) { - const auto &ir_spec = ir_specs[ir_io_idx]; - if (ir_spec.IsDynamic()) { // 已经处理过Dynamic类型的IR输入 - continue; - } - - if (iter == name2index_left.end()) { // 只允许Optional的IR输入没有对应的desc,对应Optional在IR最后且没有Desc信息 - if (!ir_spec.IsOptional()) { - GELOGW("No desc left for %s", ir_spec.DebugString().c_str()); - return GRAPH_SUCCESS; - } - ir_2_range.emplace(ir_io_idx, std::make_pair(GetIrDescStartIndex(ir_2_range, ir_io_idx), 0U)); - continue; - } - - auto &name = iter->first; - auto &index = iter->second; - - if (ir_spec.GetName() != name) { // 如果当前名字和IR不一致,需要确保不是乱序,即没有与IR名字对应的Desc存在 - for (auto &name2index : name2index_left) { - GE_WARN_ASSERT(ir_spec.GetName() != name2index.first, "Found desc for %s index %u, while current name is %s", - ir_spec.DebugString().c_str(), name2index.second, name.c_str()); - } - } - - if (!ir_spec.IsOptional()) { // 非可选,则认为是自行构造的非标IR - iter++; - ir_2_range.emplace(ir_io_idx, std::make_pair(index, 1U)); - GELOGD("Mapping %s to desc %zu named %s", ir_spec.DebugString().c_str(), index, name.c_str()); - continue; - } - - if (name != ir_spec.GetName()) { // 对应Optional不在尾部且未传入 - GELOGD("Ir spec %s has no instance as desc[%u] named %s", ir_spec.DebugString().c_str(), index, name.c_str()); - ir_2_range.emplace(ir_io_idx, std::make_pair(index, 0U)); - continue; - } - - iter++; - if (desc_env.IsDescValid(index)) { // 对应Optional传入有效值 - GELOGD("Mapping %s desc[%zu]", ir_spec.DebugString().c_str(), index); - ir_2_range.emplace(ir_io_idx, std::make_pair(index, 1U)); - } else { // Optional传入无效值,对实例index进行调整(实例index只会保存非nullptr的输入) - GELOGD("Skip mapping %s to invalid desc[%zu]", ir_spec.DebugString().c_str(), index); - ir_2_range.emplace(ir_io_idx, std::make_pair(index, 0U)); - UpdateRawDescInstanceShifts(desc_instance_shifts, index); - } - } - - if (!require_raw_index) { - for (auto &item : ir_2_range) { - auto &start = item.second.first; - auto &num = item.second.second; - size_t shift = (start >= desc_instance_shifts.size() ? 0U : desc_instance_shifts[start]); - start = (start > shift) ? (start - shift) : 0U; - GELOGD("Re-mapping %s to desc[%zu, %zu) shift(-%zu)", ir_specs[item.first].DebugString().c_str(), start, - start + num, shift); - } - } - - return GRAPH_SUCCESS; -} - -graphStatus GetIrDescRange(const std::vector &ir_specs, const std::map &name2idx, - const DescEnv &desc_env, const bool &require_raw_index, - std::map> &ir_2_range) { - GELOGD("Start get desc range for %s", desc_env.DebugString().c_str()); - for (auto &ir_spec : ir_specs) { - GELOGD(" Spec %s", ir_spec.DebugString().c_str()); - } - - std::map idx2name; - for (auto &item : name2idx) { - GELOGD(" Desc name %s index %d", item.first.c_str(), item.second); - idx2name.emplace(item.second, item.first); - } - GE_WARN_ASSERT(idx2name.size() == name2idx.size(), "Found %zu names, while %zu indexes", idx2name.size(), - name2idx.size()); - if (!idx2name.empty()) { - GE_WARN_ASSERT(idx2name.rbegin()->first == idx2name.size() - 1U); // 拦截index不连续 - } - - // 首先确定Dynamic类型的IR IO对应的index范围,对于IR构图场景,用户会通过create_dynmaic_xx接口创建多个输入Desc, - // 但是Desc在所有desc中的位置,是受调用时的参数决定的,默认情况下,都向尾部追加,会出现先定义的IR输入或输出对应的desc,在后定义的之后 - GE_WARN_ASSERT_GRAPH_SUCCESS(MappingDynamicIrDesc(ir_specs, desc_env, name2idx, ir_2_range)); - - std::vector index_consumed; // index对应的desc是否已经决定对应关系 - index_consumed.resize(name2idx.size(), false); - for (auto &item : ir_2_range) { - auto &range = item.second; - for (size_t i = range.first; i < range.first + range.second; i++) { - index_consumed[i] = true; - } - } - - std::vector> name2index_left; - for (size_t i = 0U; i < index_consumed.size(); i++) { - if (!index_consumed[i]) { // 未被使用的index顺序排列 - name2index_left.emplace_back(idx2name[i], static_cast(i)); - } - } - - // 确定非Dynamic类型的IR IO对应的index范围 - GE_WARN_ASSERT_GRAPH_SUCCESS( - MappingNonDynamicIrDesc(ir_specs, desc_env, name2index_left, require_raw_index, ir_2_range)); - - // 不校验所有的index都决定了对应的IR输入(存在算子追加非IR输入的场景,CCB裁决框架适配支持) - - return GRAPH_SUCCESS; -} - -graphStatus GetIrInputDescRange(const OpDescPtr &op, const bool &require_raw_index, - std::map> &ir_input_2_range) { - GE_ASSERT_NOTNULL(op); - std::vector ir_specs; - for (auto &item : op->GetIrInputs()) { - ir_specs.emplace_back(item.first, item.second); - } - const std::map &name2idx = op->GetAllInputName(); - DescEnv desc_env(op, true); - - return GetIrDescRange(ir_specs, name2idx, desc_env, require_raw_index, ir_input_2_range); -} -} // namespace - -// 获取输入IR对应的实例Desc的index范围,实例Desc中会去除未传值的Optional输入Desc -graphStatus GetIrInputInstanceDescRange(const OpDescPtr &op, - std::map> &ir_input_2_range) { - return GetIrInputDescRange(op, false, ir_input_2_range); -} - -// 获取输入IR对应的全部Desc的index范围,包含未传值的Optional输入Desc -graphStatus GetIrInputRawDescRange(const OpDescPtr &op, std::map> &ir_input_2_range) { - return GetIrInputDescRange(op, true, ir_input_2_range); -} - -graphStatus GetIrOutputDescRange(const OpDescPtr &op, std::map> &ir_output_2_range) { - GE_ASSERT_NOTNULL(op); - std::vector ir_specs; - for (auto &item : op->GetIrOutputs()) { - ir_specs.emplace_back(item.first, item.second); - } - const std::map &name2idx = op->GetAllOutputName(); - DescEnv desc_env(op, false); - - return GetIrDescRange(ir_specs, name2idx, desc_env, true, ir_output_2_range); -} -} // namespace ge diff --git a/graph/type/sym_dtype.h b/graph/type/sym_dtype.h deleted file mode 100644 index 22775bc2597313943f12539cf489afa47ac64bb6..0000000000000000000000000000000000000000 --- a/graph/type/sym_dtype.h +++ /dev/null @@ -1,180 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_GRAPH_SYM_DTYPE_H_ -#define METADEF_CXX_GRAPH_SYM_DTYPE_H_ - -#include "graph/op_desc.h" -#include "graph/utils/type_utils.h" - -namespace ge { -class OrderedTensorTypeList { - public: - explicit OrderedTensorTypeList(const std::initializer_list &initial_types) : initial_types_(initial_types) { - } - std::vector GetOrderedDtypes() const { - return initial_types_; - } - bool IsDataTypeInRange(ge::DataType data_type) const { - return std::find(initial_types_.begin(), initial_types_.end(), data_type) != initial_types_.end(); - } - std::vector GetDtypeIndexs(ge::DataType data_type) const { - std::vector indices; - size_t idx = 0; - for (auto it = initial_types_.begin(); it != initial_types_.end(); ++it, ++idx) { - if (*it == data_type) { - indices.push_back(idx); - } - } - return indices; - } - std::string ToString() const { - std::string s = "["; - for (const auto &dtype : initial_types_) { - s += TypeUtils::DataTypeToSerialString(dtype); - s += ","; - } - s += "]"; - return s; - } - private: - std::vector initial_types_; -}; - -class TypeOrTypes { - public: - TypeOrTypes() : initialized_(false), is_list_(false) {} - ~TypeOrTypes() = default; - - bool IsListType() const { - return is_list_; - } - - graphStatus GetType(DataType &type) const; - graphStatus GetTypes(std::vector &types) const; - - const DataType &UnsafeGetType() const; - const std::vector &UnsafeGetTypes() const; - - void SetType(const DataType &type); - void SetTypes(const std::vector &types); - - std::string DebugString() const; - - private: - bool initialized_; - bool is_list_; - std::vector types_; -}; - -enum class ExpressionType { - kSingle, - kPromote -}; - -// 用于支持符号推导的数据类型表达式,它表达一个由Sym组成的表达式 -class SymDtypeExpression { - public: - // 对Sym表达式进行基于op上下文的实际值计算 - virtual graphStatus Eval(const OpDesc &op, TypeOrTypes &type_or_types) const = 0; - virtual bool IsListType() const = 0; - virtual ExpressionType Type() const = 0; - virtual std::vector GetIrInputIndexes() const = 0; - virtual ~SymDtypeExpression() = default; -}; - -// Sym类型,每个输入或输出对应一个Sym,多个输入或输出可以对应同一个Sym -class SymDtype : public SymDtypeExpression { - public: - explicit SymDtype(const std::string &id); - ~SymDtype() override = default; - - const std::string &Id() const; // Sym的标识,与DATATYPE中声明时的标识一致 - bool IsListType() const override; // 返回Sym是否对为ListType类型 - ExpressionType Type() const override; - - graphStatus Eval(const OpDesc &op, TypeOrTypes &type_or_types) const override; - - // 绑定sym对应的IR输入名称,以及IR输入类型。如果未绑定任何输入,在推导时会尝试从属性中获取 - void BindIrInput(const std::string &ir_input, const IrInputType &input_type, size_t input_index); - // 设置Sym的取值范围或计算方式(DATATYPE声明时调用) - void BindAllowedDtypes(const TensorType &types); - void BindAllowedDtypes(const ListTensorType &types); - void BindAllowedOrderedDtypes(const OrderedTensorTypeList &types); - void BindExpression(const std::shared_ptr &expression); - - bool IsLegacy() const; // 是否为Legacy的Sym,未通过DATATYPE声明的Sym为Legacy的Sym - bool IsOrderedList() const { - return is_ordered_list_; - } - - const std::string &DebugString() const; - - std::vector GetIrInputIndexes() const override; - - TensorType GetTensorType() const { - return tensor_type_; - } - - OrderedTensorTypeList GetOrderedTensorTypeList() const { - return ordered_tensor_type_list_; - } - protected: - graphStatus Eval(const OpDesc &op, DataType &dtype) const; - graphStatus Eval(const OpDesc &op, std::vector &dtypes) const; - - std::string id_; - bool is_legacy_; // 是否为Legacy方式的IR,对于Legacy方式的IR,不支持类型推导及类型校验 - - bool is_list_; // Sym是否为ListType类型 - TensorType tensor_type_; // Sym的取值范围 - bool is_ordered_list_; - OrderedTensorTypeList ordered_tensor_type_list_; - - struct SymBackend { - SymBackend(const std::string &input_name, const IrInputType &input_type, size_t input_index) - : type(input_type), index(input_index), name(input_name) {} - IrInputType type; - size_t index; - std::string name; - std::string DebugString() const; - }; - - std::vector ir_inputs_; // Sym的对应的输入实体,与expression_互斥 - std::shared_ptr expression_; // Sym的计算表达式,与ir_inputs_互斥 -}; - -// 表达类型提升的Sym表达式 -class PromotionSymDtypeExpression : public SymDtypeExpression { - public: - // 表达类型提升的Sym计算表达,入参syms中的sym进行两两提升,对ListType类型的sym,会继续对sym间对应位置的Dtype进行提升 - explicit PromotionSymDtypeExpression(const std::vector &syms); - - graphStatus Eval(const OpDesc &op, TypeOrTypes &type_or_types) const override; - bool IsListType() const override { - return std::all_of(syms_.begin(), syms_.end(), - [](const SymDtype *sym) { return (sym != nullptr) && sym->IsListType(); }); - } - ExpressionType Type() const override; - std::vector GetIrInputIndexes() const override; - - private: - std::vector syms_; -}; - -// 获取输入IR对应的实例Desc的index范围,实例Desc中会去除未传值的Optional输入Desc -graphStatus GetIrInputInstanceDescRange(const OpDescPtr &op, - std::map> &ir_input_2_range); - -// 获取输入IR对应的全部Desc的index范围,包含未传值的Optional输入Desc -graphStatus GetIrInputRawDescRange(const OpDescPtr &op, std::map> &ir_input_2_range); - -graphStatus GetIrOutputDescRange(const OpDescPtr &op, std::map> &ir_output_2_range); -} // namespace ge -#endif // METADEF_CXX_GRAPH_SYM_DTYPE_H_ diff --git a/graph/type/tensor_type_impl.h b/graph/type/tensor_type_impl.h deleted file mode 100644 index b172d162290d57ebff0f18fab84ad00d5796bfef..0000000000000000000000000000000000000000 --- a/graph/type/tensor_type_impl.h +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_TENSOR_TYPE_IMPL_H -#define METADEF_CXX_TENSOR_TYPE_IMPL_H -#include -#include "graph/types.h" -namespace ge { -class TensorTypeImpl { - public: - TensorTypeImpl() = default; - ~TensorTypeImpl() = default; - - std::set &GetMutableDateTypeSet() { - return dt_set_; - } - bool IsDataTypeInRange(const DataType data_type) const { - return (dt_set_.count(data_type) > 0); - } - private: - std::set dt_set_; -}; -} // namespace ge - -#endif // METADEF_CXX_TENSOR_TYPE_IMPL_H diff --git a/graph/type/types.cc b/graph/type/types.cc deleted file mode 100644 index 99fa65f20bed278e2274a7bb3b233cfb5075854f..0000000000000000000000000000000000000000 --- a/graph/type/types.cc +++ /dev/null @@ -1,170 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "external/graph/types.h" -#include -#include -#include "common/ge_common/debug/ge_log.h" -#include "graph/ge_error_codes.h" -#include "graph/utils/type_utils.h" - -namespace ge { -const char_t *GetFormatName(Format format) { - static const char_t *names[FORMAT_END] = { - "NCHW", - "NHWC", - "ND", - "NC1HWC0", - "FRACTAL_Z", - "NC1C0HWPAD", // 5 - "NHWC1C0", - "FSR_NCHW", - "FRACTAL_DECONV", - "C1HWNC0", - "FRACTAL_DECONV_TRANSPOSE", // 10 - "FRACTAL_DECONV_SP_STRIDE_TRANS", - "NC1HWC0_C04", - "FRACTAL_Z_C04", - "CHWN", - "DECONV_SP_STRIDE8_TRANS", // 15 - "HWCN", - "NC1KHKWHWC0", - "BN_WEIGHT", - "FILTER_HWCK", - "LOOKUP_LOOKUPS", // 20 - "LOOKUP_KEYS", - "LOOKUP_VALUE", - "LOOKUP_OUTPUT", - "LOOKUP_HITS", - "C1HWNCoC0", // 25 - "MD", - "NDHWC", - "UNKNOWN", // FORMAT_FRACTAL_ZZ - "FRACTAL_NZ", - "NCDHW", // 30 - "DHWCN", - "NDC1HWC0", - "FRACTAL_Z_3D", - "CN", - "NC", // 35 - "DHWNC", - "FRACTAL_Z_3D_TRANSPOSE", - "FRACTAL_ZN_LSTM", - "FRACTAL_Z_G", - "UNKNOWN", // 40, FORMAT_RESERVED - "UNKNOWN", // FORMAT_ALL - "UNKNOWN", // FORMAT_NULL - "ND_RNN_BIAS", - "FRACTAL_ZN_RNN", - "NYUV", // 45 - "NYUV_A", - "NCL", - "FRACTAL_Z_WINO", - "C1HWC0", - "FRACTAL_NZ_C0_16", - "FRACTAL_NZ_C0_32", - }; - if (format >= FORMAT_END) { - return "UNKNOWN"; - } - return names[format]; -} - -static int64_t CeilDiv(const int64_t n1, const int64_t n2) { - if (n1 == 0) { - return 0; - } - return (n2 != 0) ? (((n1 - 1) / n2) + 1) : 0; -} - -static Status CheckInt64MulOverflow(const int64_t a, const int64_t b) { - if (a > 0) { - if (b > 0) { - if (a > (INT64_MAX / b)) { - return FAILED; - } - } else { - if (b < (INT64_MIN / a)) { - return FAILED; - } - } - } else { - if (b > 0) { - if (a < (INT64_MIN / b)) { - return FAILED; - } - } else { - if ((a != 0) && (b < (INT64_MAX / a))) { - return FAILED; - } - } - } - return SUCCESS; -} - -int64_t GetSizeInBytes(int64_t element_count, DataType data_type) { - if (element_count < 0) { - GELOGW("[Check][param]GetSizeInBytes failed, element_count:%" PRId64 " less than 0.", element_count); - return -1; - } - uint32_t type_size = 0U; - if (!TypeUtils::GetDataTypeLength(data_type, type_size)) { - GELOGW("[Check][DataType]GetSizeInBytes failed, data_type:%d not support.", data_type); - return -1; - } else if (type_size > kDataTypeSizeBitOffset) { - const auto bit_size = type_size - kDataTypeSizeBitOffset; - if (CheckInt64MulOverflow(element_count, static_cast(bit_size)) == FAILED) { - GELOGW("[Check][overflow]GetSizeInBytes failed, when multiplying %" PRId64 " and %d.", element_count, bit_size); - return -1; - } - return CeilDiv(element_count * bit_size, kBitNumOfOneByte); - } else { - if (CheckInt64MulOverflow(element_count, static_cast(type_size)) == FAILED) { - GELOGW("[Check][overflow]GetSizeInBytes failed, when multiplying %" PRId64 " and %" PRId32 ".", - element_count, type_size); - return -1; - } - return element_count * type_size; - } -} - -std::vector Promote::Syms() const { - std::vector result; - if (data_ == nullptr) { - return result; - } - auto &syms = *static_cast *>(data_.get()); - result.reserve(syms.size()); - for (const auto &sym : syms) { - result.push_back(sym.c_str()); - } - return result; -} - -Promote::Promote(const std::initializer_list &syms) { - data_ = std::shared_ptr(new (std::nothrow) std::vector(), - [](void *ptr) { delete static_cast *>(ptr); }); - if (data_ != nullptr) { - for (const auto &sym : syms) { - static_cast *>(data_.get())->emplace_back((sym == nullptr) ? "" : sym); - } - } -} - -Promote::Promote(Promote &&other) noexcept { - data_ = std::move(other.data_); -} - -Promote &Promote::operator=(Promote &&other) noexcept { - if (this != &other) { - data_ = std::move(other.data_); - } - return *this; -} -} // namespace ge diff --git a/graph/utils/anchor_utils.cc b/graph/utils/anchor_utils.cc deleted file mode 100644 index c7c2b1b5cbd611aee4da7c1b7678e048962b8b67..0000000000000000000000000000000000000000 --- a/graph/utils/anchor_utils.cc +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/utils/anchor_utils.h" -#include -#include "graph/debug/ge_util.h" -#include "common/ge_common/debug/ge_log.h" - -namespace ge { -// Get anchor status -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorStatus AnchorUtils::GetStatus(const DataAnchorPtr &data_anchor) { - if (data_anchor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param data_anchor is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] The input data anchor is invalid."); - return ANCHOR_RESERVED; - } - return data_anchor->status_; -} - -// Set anchor status -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus AnchorUtils::SetStatus(const DataAnchorPtr &data_anchor, - const AnchorStatus anchor_status) { - if ((data_anchor == nullptr) || (anchor_status == ANCHOR_RESERVED)) { - REPORT_INNER_ERR_MSG("E18888", "The input data anchor or input data format is invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] The input data anchor or input data format is invalid."); - return GRAPH_FAILED; - } - data_anchor->status_ = anchor_status; - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY int32_t AnchorUtils::GetIdx(const AnchorPtr &anchor) { - // Check if it can add edge between DataAnchor - const auto data_anchor = Anchor::DynamicAnchorCast(anchor); - if (data_anchor != nullptr) { - return data_anchor->GetIdx(); - } - // Check if it can add edge between ControlAnchor - const auto ctrl_anchor = Anchor::DynamicAnchorCast(anchor); - if (ctrl_anchor != nullptr) { - return ctrl_anchor->GetIdx(); - } - return -1; -} -} // namespace ge diff --git a/graph/utils/args_format_desc.cc b/graph/utils/args_format_desc.cc deleted file mode 100644 index 8f53283857210313473b1c5cd6af329f6fccb5f2..0000000000000000000000000000000000000000 --- a/graph/utils/args_format_desc.cc +++ /dev/null @@ -1,903 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/args_format_desc.h" -#include -#include -#include -#include -#include "common/checker.h" -#include "common/ge_common/debug/ge_log.h" -#include "graph/op_desc.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/tensor_utils.h" -#include "graph/anchor.h" -#include "graph/compute_graph.h" -#include "graph/debug/ge_attr_define.h" - - -namespace ge { -constexpr size_t kMaxDimNum = 25UL; -constexpr size_t kMaxWorkspaceNum = 16UL; -constexpr int32_t kDecimalCarry = 10; -constexpr int32_t kAsciiZero = 48; -constexpr int32_t kDigitFormatCnt = 1; -constexpr int32_t kAmbiguousIrIdx = -1; - -using ParseFunc = - std::function &)>; -using GetArgsSize = std::function; -using SerializeFunc = std::function; - -struct PatternHandler { - ParseFunc parse; - GetArgsSize getArgsSize; - SerializeFunc serialize; - AddrType type; -}; - -graphStatus FindSkSubNode(const OpDescPtr &sk_op, const int32_t id, NodePtr &sub_node) { - GE_ASSERT_NOTNULL(sk_op); - ComputeGraphPtr sub_graph = nullptr; - sub_graph = sk_op->TryGetExtAttr("_sk_sub_graph", sub_graph); - GE_ASSERT_NOTNULL(sub_graph); - for (const auto &node : sub_graph->GetDirectNode()) { - GE_ASSERT_NOTNULL(node); - if (node->GetOpDesc()->GetId() == static_cast(id)) { - sub_node = node; - GELOGI("find %d sub node %s from sk node %s", id, node->GetNamePtr(), sk_op->GetNamePtr()); - return GRAPH_SUCCESS; - } - } - GELOGE(GRAPH_FAILED, "can not find %d sub node from sk node %s", id, sk_op->GetNamePtr()); - return GRAPH_FAILED; -} - -static std::set kNeedEasyParserTypes{ - AddrType::INPUT, AddrType::OUTPUT, AddrType::INPUT_DESC, AddrType::OUTPUT_DESC, - AddrType::INPUT_INSTANCE, AddrType::OUTPUT_INSTANCE, AddrType::WORKSPACE}; - -static graphStatus DefaultCalcSize(const OpDescPtr &op_desc, const ArgDesc &arg_desc, size_t &size) { - (void) op_desc; - (void) arg_desc; - size += sizeof(uintptr_t); - return GRAPH_SUCCESS; -} - -static graphStatus DefaultParser(const OpDescPtr &op_desc, const std::string &pattern_str, const AddrType type, - std::vector &args_desc) { - (void) op_desc; - (void) pattern_str; - args_desc.push_back({type, kAmbiguousIrIdx, false, {0}}); - return GRAPH_SUCCESS; -} - -static graphStatus PlaceholderParser(const OpDescPtr &op_desc, const std::string &pattern_str, const AddrType type, - std::vector &args_desc) { - (void) op_desc; - auto width = ArgsFormatWidth::BIT64; - if (pattern_str == ".32b") { - width = ArgsFormatWidth::BIT32; - } else if (!pattern_str.empty()) { - GELOGE(PARAM_INVALID, "Args format [%s] matched failed, it may be unsupported.", pattern_str.c_str()); - return GRAPH_FAILED; - } - args_desc.push_back({type, static_cast(width), false, {0}}); - return GRAPH_SUCCESS; -} - -static void PlaceholderSerializer(std::stringstream &ss, const std::string &pattern, const ArgDesc &arg_desc) { - ss << pattern; - if (arg_desc.ir_idx == static_cast(ArgsFormatWidth::BIT32)) { - ss << ".32b"; - } -} - -static void DefaultSerializer(std::stringstream &ss, const std::string &pattern, const ArgDesc &arg_desc) { - (void) arg_desc; - ss << pattern; -} - -static void FftsTilingSerializer(std::stringstream &ss, const std::string &pattern, const ArgDesc &arg_desc) { - ss << pattern; - if (arg_desc.ir_idx == 0) { - ss << ".non_tail"; - } else { - ss << ".tail"; - } -} - -static void ArrayLikeSerializer(std::stringstream &ss, const std::string &pattern, const ArgDesc &arg_desc) { - ss << pattern; - if (arg_desc.ir_idx >= 0) { - ss << std::to_string(arg_desc.ir_idx); - if (!arg_desc.folded) { - ss << '*'; - } - } else { - ss << '*'; - } -} - -static graphStatus WorkspaceCalcSize(const OpDescPtr &op_desc, const ArgDesc &arg_desc, size_t &size) { - (void) op_desc; - if (arg_desc.ir_idx == kAmbiguousIrIdx) { - size += sizeof(uintptr_t) * kMaxWorkspaceNum; - } else { - size += sizeof(uintptr_t); - } - return GRAPH_SUCCESS; -} - -static graphStatus EventAddrParser(const OpDescPtr &op_desc, const std::string &pattern_str, const AddrType type, - std::vector &args_desc) { - (void) op_desc; - int32_t ir_idx = 0; - const std::string prefix = "event_addr"; - for (size_t i = prefix.size(); i < pattern_str.size(); ++i) { - if (isdigit(pattern_str[i])) { - ir_idx = ir_idx * kDecimalCarry + static_cast(pattern_str[i]) - kAsciiZero; - } - } - args_desc.push_back({type, ir_idx, false, {0}}); - return GRAPH_SUCCESS; -} - -static void EventAddrSerializer(std::stringstream &ss, const std::string &pattern, const ArgDesc &arg_desc) { - ss << pattern << arg_desc.ir_idx << "*"; -} - -static graphStatus WorkspaceParser(const OpDescPtr &op_desc, const std::string &pattern_str, const AddrType type, - std::vector &args_desc) { - (void) op_desc; - if (pattern_str == "ws*") { - args_desc.push_back({type, kAmbiguousIrIdx, false, {0}}); - return GRAPH_SUCCESS; - } - int32_t ir_idx = 0; - for (size_t i = 2UL; i < pattern_str.size(); ++i) { - if (isdigit(pattern_str[i])) { - ir_idx = ir_idx * kDecimalCarry + static_cast(pattern_str[i]) - kAsciiZero; - } - } - args_desc.push_back({type, ir_idx, false, {0}}); - return GRAPH_SUCCESS; -} - -static graphStatus InputCalcSize(const OpDescPtr &op_desc, const ArgDesc &arg_desc, size_t &size) { - const auto &ir_inputs = op_desc->GetIrInputs(); - size_t count = 0UL; - if (arg_desc.ir_idx >= 0) { - // 非通配符场景 - GE_ASSERT((static_cast(arg_desc.ir_idx) < ir_inputs.size()), "ir_index [%d] is out of range", - arg_desc.ir_idx); - if (ir_inputs[arg_desc.ir_idx].second == IrInputType::kIrInputDynamic) { - if (arg_desc.folded) { - ++count; // pointer to addr - } - int32_t dyn_num = 0; - for (auto &iter : op_desc->GetAllInputName()) { - if (iter.first == ir_inputs[arg_desc.ir_idx].first + std::to_string(dyn_num)) { - ++dyn_num; - ++count; // real input_addr - } - } - } else { - ++count; - } - } else { - // 通配符场景,非动态输入默认展开, 动态输入按照i0形式折叠 - for (const auto &ir_input : ir_inputs) { - ++count; - if (ir_input.second == IrInputType::kIrInputDynamic) { - int32_t dyn_num = 0; - for (auto &iter : op_desc->GetAllInputName()) { - if (iter.first == ir_input.first + std::to_string(dyn_num)) { - ++count; // real input addr - ++dyn_num; - } - } - } - } - } - size += count * sizeof(uintptr_t); - - return GRAPH_SUCCESS; -} - -static graphStatus InputInstanceParser(const OpDescPtr &op_desc, const std::string &pattern_str, const AddrType type, - std::vector &arg_descs) { - GE_ASSERT_NOTNULL(op_desc); - const size_t valid_input_size = op_desc->GetInputsSize(); - if (pattern_str == "i_instance*") { - // 为了方便加载时使用,通配符场景解析后默认展开成多个 - for (size_t i = 0UL; i < valid_input_size; ++i) { - arg_descs.push_back({type, static_cast(i), false, {0}}); - } - } else { - int32_t ir_idx{0}; - GE_ASSERT_TRUE(sscanf_s(pattern_str.c_str(), "i_instance%d", &ir_idx) == kDigitFormatCnt, - "Arg format [%s] is invalid", pattern_str.c_str()); - GE_ASSERT(static_cast(ir_idx) < valid_input_size, "ir index [%d] is invalid.", ir_idx); - arg_descs.push_back({type, ir_idx, false, {0}}); - } - return SUCCESS; -} - -static graphStatus InputInstanceCalcSize(const OpDescPtr &op_desc, const ArgDesc &arg_desc, size_t &size) { - size_t count = 1UL; - if (arg_desc.ir_idx < 0) { - count = op_desc->GetInputsSize(); - } - if (arg_desc.folded) { - count *= 2UL; - } - size += count * sizeof(uintptr_t); - - return GRAPH_SUCCESS; -} - -static graphStatus OutputInstanceCalcSize(const OpDescPtr &op_desc, const ArgDesc &arg_desc, size_t &size) { - size_t count = 1UL; - if (arg_desc.ir_idx < 0) { - count = op_desc->GetOutputsSize(); - } - if (arg_desc.folded) { - count *= 2UL; - } - size += count * sizeof(uintptr_t); - return GRAPH_SUCCESS; -} - -static graphStatus OutputInstanceParser(const OpDescPtr &op_desc, const std::string &pattern_str, const AddrType type, - std::vector &arg_descs) { - GE_ASSERT_NOTNULL(op_desc); - const size_t valid_output_size = op_desc->GetOutputsSize(); - if (pattern_str == "o_instance*") { - // 为了方便加载时使用,通配符场景解析后默认展开成多个 - for (size_t i = 0UL; i < valid_output_size; ++i) { - arg_descs.push_back({type, static_cast(i), false, {0}}); - } - } else { - int32_t ir_idx{0}; - GE_ASSERT_TRUE(sscanf_s(pattern_str.c_str(), "o_instance%d", &ir_idx) == kDigitFormatCnt, - "Arg format [%s] is invalid", pattern_str.c_str()); - GE_ASSERT(static_cast(ir_idx) < valid_output_size, "ir index [%d] is invalid.", ir_idx); - arg_descs.push_back({type, ir_idx, false, {0}}); - } - return SUCCESS; -} - -static graphStatus InputParser(const OpDescPtr &op_desc, const std::string &pattern_str, const AddrType type, - std::vector &arg_descs) { - GE_ASSERT_NOTNULL(op_desc); - const auto &ir_inputs = op_desc->GetIrInputs(); - if (pattern_str == "i*") { - // 为了方便加载时使用,通配符场景解析后默认展开成多个 - for (size_t i = 0UL; i < ir_inputs.size(); ++i) { - bool folded{false}; - if (ir_inputs[i].second == IrInputType::kIrInputDynamic) { - folded = true; - } - arg_descs.push_back({type, static_cast(i), folded, {0}}); - } - } else { - int32_t ir_idx{0}; - bool has_idx{false}; - for (size_t i = 1UL; i < pattern_str.size(); ++i) { - if (isdigit(pattern_str[i])) { - ir_idx = ir_idx * kDecimalCarry + static_cast(pattern_str[i]) - kAsciiZero; - has_idx = true; - } - } - GE_ASSERT(has_idx, "Arg format [%s] is invalid", pattern_str.c_str()); - GE_ASSERT(static_cast(ir_idx) < ir_inputs.size(), "ir index [%d] is invalid.", ir_idx); - - bool folded{false}; - if (ir_inputs[static_cast(ir_idx)].second == IrInputType::kIrInputDynamic && - pattern_str[pattern_str.length() - 1UL] != '*') { - folded = true; - } - arg_descs.push_back({type, ir_idx, folded, {0}}); - } - - return GRAPH_SUCCESS; -} - -static graphStatus OutputCalcSize(const OpDescPtr &op_desc, const ArgDesc &arg_desc, size_t &size) { - const auto &ir_outputs = op_desc->GetIrOutputs(); - size_t count = 0UL; - if (arg_desc.ir_idx >= 0) { - // 非通配符场景 - GE_ASSERT((static_cast(arg_desc.ir_idx) < ir_outputs.size()), "ir_index [%d] is out of range", - arg_desc.ir_idx); - if (ir_outputs[arg_desc.ir_idx].second == IrOutputType::kIrOutputDynamic) { - if (arg_desc.folded) { - count++; // pointer to addr - } - int32_t dyn_num = 0; - for (auto &iter : op_desc->GetAllOutputName()) { - if (iter.first == ir_outputs[arg_desc.ir_idx].first + std::to_string(dyn_num)) { - ++count; // real input_addr - ++dyn_num; - } - } - } else { - count++; - } - } else { - // 通配符场景,非动态输入默认展开, 动态输入按照i0形式折叠 - for (const auto &ir_output : ir_outputs) { - count++; - if (ir_output.second == IrOutputType::kIrOutputDynamic) { - int32_t dyn_num = 0; - for (auto &iter : op_desc->GetAllOutputName()) { - if (iter.first == ir_output.first + std::to_string(dyn_num)) { - ++count; // real input addr - ++dyn_num; - } - } - } - } - } - size += count * sizeof(uintptr_t); - - return GRAPH_SUCCESS; -} - -static graphStatus OutputParser(const OpDescPtr &op_desc, const std::string &pattern_str, const AddrType type, - std::vector &arg_descs) { - GE_ASSERT_NOTNULL(op_desc); - const auto &ir_outputs = op_desc->GetIrOutputs(); - if (pattern_str == "o*") { - // 为了方便加载时使用,通配符场景解析后默认展开成多个 - for (size_t i = 0UL; i < ir_outputs.size(); ++i) { - bool folded{false}; - if (ir_outputs[i].second == IrOutputType::kIrOutputDynamic) { - folded = true; - } - arg_descs.push_back({type, static_cast(i), folded, {0}}); - } - } else { - int32_t ir_idx{0}; - bool has_idx{false}; - for (size_t i = 1UL; i < pattern_str.size(); ++i) { - if (isdigit(pattern_str[i])) { - ir_idx = ir_idx * kDecimalCarry + static_cast(pattern_str[i]) - kAsciiZero; - has_idx = true; - } - } - GE_ASSERT(has_idx, "Op[%s] arg format [%s] is invalid", op_desc->GetNamePtr(), pattern_str.c_str()); - GE_ASSERT(static_cast(ir_idx) < ir_outputs.size(), "Op[%s] ir index [%d] is invalid.", - op_desc->GetNamePtr(), ir_idx); - bool folded{false}; - if (ir_outputs[static_cast(ir_idx)].second == IrOutputType::kIrOutputDynamic && - pattern_str[pattern_str.length() - 1UL] != '*') { - folded = true; - } - arg_descs.push_back({type, ir_idx, folded, {0}}); - } - return GRAPH_SUCCESS; -} - -static graphStatus InputDescCalcSize(const OpDescPtr &op_desc, const ArgDesc &arg_desc, size_t &size) { - const auto &ir_inputs = op_desc->GetIrInputs(); - GE_ASSERT((arg_desc.ir_idx >= 0 && static_cast(arg_desc.ir_idx) < ir_inputs.size()), - "ir_index is out of range"); - auto ir_name = ir_inputs[static_cast(arg_desc.ir_idx)].first; - if (arg_desc.folded) { - size += sizeof(uintptr_t); // pointer to desc - } - size += sizeof(uintptr_t); // offset to addr - size_t dyn_num = 0UL; - for (auto &iter : op_desc->GetAllInputName()) { - if (iter.first == ir_name + std::to_string(dyn_num)) { - const auto &input_desc = op_desc->GetInputDesc(iter.second); - size += sizeof(uintptr_t) * 2UL; // dims_info + addr - if (input_desc.GetShape().IsUnknownDimNum()) { - size += sizeof(uintptr_t) * kMaxDimNum; - } else if (input_desc.GetShape().IsScalar()) { - size += sizeof(uintptr_t); - } else { - size += sizeof(uintptr_t) * input_desc.GetShape().GetDimNum(); - } - ++dyn_num; - } - } - return GRAPH_SUCCESS; -} - -static graphStatus OutputDescCalcSize(const OpDescPtr &op_desc, const ArgDesc &arg_desc, size_t &size) { - const auto &ir_outputs = op_desc->GetIrOutputs(); - GE_ASSERT((arg_desc.ir_idx >= 0 && static_cast(arg_desc.ir_idx) < ir_outputs.size()), - "ir_index [%d] is out of range", arg_desc.ir_idx); - auto ir_name = ir_outputs[static_cast(arg_desc.ir_idx)].first; - - if (arg_desc.folded) { - size += sizeof(uintptr_t); // pointer to desc - } - size += sizeof(uintptr_t); // offset to addr - size_t dyn_num = 0UL; - for (auto &iter : op_desc->GetAllOutputName()) { - if (iter.first == ir_name + std::to_string(dyn_num)) { - const auto &output_desc = op_desc->GetOutputDesc(iter.second); - size += sizeof(uintptr_t) * 2UL; // dims_info + addr - if (output_desc.GetShape().IsUnknownDimNum()) { - size += sizeof(uintptr_t) * kMaxDimNum; - } else if (output_desc.GetShape().IsScalar()) { - size += sizeof(uintptr_t); - } else { - size += sizeof(uintptr_t) * output_desc.GetShape().GetDimNum(); - } - ++dyn_num; - } - } - return GRAPH_SUCCESS; -} - -static graphStatus IODescParser(const OpDescPtr &op_desc, const std::string &pattern_str, const AddrType type, - std::vector &arg_descs) { - (void) op_desc; - bool folded{true}; - if (pattern_str[pattern_str.length() - 1] == '*') { - folded = false; - } - int32_t ir_idx{0}; - bool has_idx{false}; - for (size_t i = 6UL; i < pattern_str.size(); ++i) { // start after i_desc/o_desc - if (isdigit(pattern_str[i])) { - ir_idx = ir_idx * kDecimalCarry + static_cast(pattern_str[i]) - kAsciiZero; - has_idx = true; - } - } - GE_ASSERT(has_idx, "Dynamic intput/output should have a concrete ir idx."); - arg_descs.push_back({type, ir_idx, folded, {0}}); - return GRAPH_SUCCESS; -} - -static graphStatus HiddenInputParser(const OpDescPtr &op_desc, const std::string &pattern_str, const AddrType type, - std::vector &arg_descs) { - (void) op_desc; - ArgDesc arg = {type, kAmbiguousIrIdx, false, {0}}; - if (sscanf_s(pattern_str.c_str(), "hi.hcom%d*", &arg.ir_idx) == kDigitFormatCnt) { - *reinterpret_cast(arg.reserved) = static_cast(HiddenInputsType::HCOM); - arg_descs.emplace_back(arg); - return GRAPH_SUCCESS; - } - if (sscanf_s(pattern_str.c_str(), "hi.tilefwk%d*", &arg.ir_idx) == kDigitFormatCnt) { - *reinterpret_cast(arg.reserved) = static_cast(HiddenInputsType::TILEFWK); - arg_descs.emplace_back(arg); - return GRAPH_SUCCESS; - } - if (sscanf_s(pattern_str.c_str(), "hi.hcclsk%d*", &arg.ir_idx) == kDigitFormatCnt) { - *reinterpret_cast(arg.reserved) = static_cast(HiddenInputsType::HCCLSUPERKERNEL); - arg_descs.emplace_back(arg); - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "Hidden input type [%s] is unsupported.", pattern_str.c_str()); - return GRAPH_FAILED; -} - -static void HiddenInputSerializer(std::stringstream &ss, const std::string &pattern, const ArgDesc &arg_desc) { - if (*reinterpret_cast(arg_desc.reserved) == static_cast(HiddenInputsType::HCOM)) { - ss << pattern << ".hcom" << arg_desc.ir_idx << "*"; - } - if (*reinterpret_cast(arg_desc.reserved) == static_cast(HiddenInputsType::TILEFWK)) { - ss << pattern << ".tilefwk" << arg_desc.ir_idx << "*"; - } - if (*reinterpret_cast(arg_desc.reserved) == - static_cast(HiddenInputsType::HCCLSUPERKERNEL)) { - ss << pattern << ".hcclsk" << arg_desc.ir_idx << "*"; - } - return; -} - -static graphStatus TilingContextParser(const OpDescPtr &op_desc, const std::string &pattern_str, const AddrType type, - std::vector &arg_descs) { - (void) op_desc; - static const std::map pattern_to_subtype{ - {"tiling_context", TilingContextSubType::TILING_CONTEXT}, - {"tiling_context.tiling_data", TilingContextSubType::TILING_DATA}, - {"tiling_context.tiling_key", TilingContextSubType::TILING_KEY}, - {"tiling_context.block_dim", TilingContextSubType::BLOCK_DIM}, - }; - const auto iter = pattern_to_subtype.find(pattern_str); - GE_ASSERT_TRUE(iter != pattern_to_subtype.end(), "pattern [%s] is unsupported.", pattern_str.c_str()); - arg_descs.push_back({type, static_cast(iter->second), false, {0}}); - return GRAPH_SUCCESS; -} - -static void TilingContextSerializer(std::stringstream &ss, const std::string &pattern, const ArgDesc &arg_desc) { - ss << pattern; - const TilingContextSubType sub_type = static_cast(arg_desc.ir_idx); - switch (sub_type) { - case TilingContextSubType::TILING_DATA: - ss << ".tiling_data"; - break; - case TilingContextSubType::TILING_KEY: - ss << ".tiling_key"; - break; - case TilingContextSubType::BLOCK_DIM: - ss << ".block_dim"; - break; - default: - break; - } -} - -static graphStatus CustomValueParser(const OpDescPtr &op_desc, const std::string &pattern_str, const AddrType type, - std::vector &arg_descs) { - (void) op_desc; - auto width = ArgsFormatWidth::BIT64; - uint64_t payload; - if (sscanf_s(pattern_str.c_str(), "#.32b%lu", &payload) == kDigitFormatCnt) { - width = ArgsFormatWidth::BIT32; - } else if (sscanf_s(pattern_str.c_str(), "#%lu", &payload) != kDigitFormatCnt) { - GELOGE(GRAPH_FAILED, "Unsupported custom value format: [%s]", pattern_str.c_str()); - return GRAPH_FAILED; - } - ArgDesc arg = {type, static_cast(width), false, {0}}; - *reinterpret_cast(arg.reserved) = payload; - arg_descs.emplace_back(arg); - return GRAPH_SUCCESS; -} - -static void CustomValueSerializer(std::stringstream &ss, const std::string &pattern, const ArgDesc &arg_desc) { - ss << pattern; - if (arg_desc.ir_idx == static_cast(ArgsFormatWidth::BIT32)) { - ss << ".32b"; - } - ss << *reinterpret_cast(arg_desc.reserved); -} - -static graphStatus VariableWidthCalcSize(const OpDescPtr &, const ArgDesc &arg_desc, size_t &size) { - GE_ASSERT(arg_desc.addr_type == AddrType::PLACEHOLDER || arg_desc.addr_type == AddrType::CUSTOM_VALUE); - auto width = static_cast(arg_desc.ir_idx); - switch (width) { - case ArgsFormatWidth::BIT64: - size += sizeof(uint64_t); - break; - case ArgsFormatWidth::BIT32: - size += sizeof(uint32_t); - break; - default: - GELOGE(PARAM_INVALID, "Encountering undefined ArgsFormatWidth: %d", static_cast(width)); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -struct PatternCmp { - bool operator()(const std::string &lhs, const std::string &rhs) const { - if (lhs.size() != rhs.size()) { - return lhs.size() > rhs.size(); - } else { - return rhs.compare(lhs) > 0; - } - }; -}; - -static const std::map kSkPatternToHandler = { - {"i", {InputParser, InputCalcSize, ArrayLikeSerializer, AddrType::INPUT}}, - {"o", {OutputParser, OutputCalcSize, ArrayLikeSerializer, AddrType::OUTPUT}}, - {"ws", {WorkspaceParser, WorkspaceCalcSize, ArrayLikeSerializer, AddrType::WORKSPACE}}, - {"t", {DefaultParser, DefaultCalcSize, DefaultSerializer, AddrType::TILING}}, - {"i_desc", {IODescParser, InputDescCalcSize, ArrayLikeSerializer, AddrType::INPUT_DESC}}, - {"o_desc", {IODescParser, OutputDescCalcSize, ArrayLikeSerializer, AddrType::OUTPUT_DESC}}, - {"ffts_addr", {DefaultParser, DefaultCalcSize, DefaultSerializer, AddrType::FFTS_ADDR}}, - {"overflow_addr", {DefaultParser, DefaultCalcSize, DefaultSerializer, AddrType::OVERFLOW_ADDR}}, - {"t_ffts", {DefaultParser, DefaultCalcSize, FftsTilingSerializer, AddrType::TILING_FFTS}}, - {"hi", {HiddenInputParser, DefaultCalcSize, HiddenInputSerializer, AddrType::HIDDEN_INPUT}}, - {"*op_type", {DefaultParser, DefaultCalcSize, DefaultSerializer, AddrType::OP_TYPE}}, - {"tiling_context", {TilingContextParser, DefaultCalcSize, TilingContextSerializer, AddrType::TILING_CONTEXT}}, - {"", {PlaceholderParser, VariableWidthCalcSize, PlaceholderSerializer, AddrType::PLACEHOLDER}}, - {"#", {CustomValueParser, VariableWidthCalcSize, CustomValueSerializer, AddrType::CUSTOM_VALUE}}, - {"i_instance", {InputInstanceParser, InputInstanceCalcSize, ArrayLikeSerializer, ge::AddrType::INPUT_INSTANCE}}, - {"o_instance", {OutputInstanceParser, OutputInstanceCalcSize, ArrayLikeSerializer, ge::AddrType::OUTPUT_INSTANCE}}, - {"event_addr", {EventAddrParser, DefaultCalcSize, EventAddrSerializer, ge::AddrType::EVENT_ADDR}}, -}; - -static graphStatus ConvertArgDescNormal2Sk(const ArgDesc &normal_arg_desc, int32_t op_id, ArgDesc &sk_arg_desc) { - GE_ASSERT_TRUE(normal_arg_desc.addr_type != AddrType::CUSTOM_VALUE); - SkArgDescV2 sk_arg_desc_tmp{}; - sk_arg_desc_tmp.addr_type = AddrType::SUPER_KERNEL_SUB_NODE; - sk_arg_desc_tmp.ir_idx = op_id; - if (normal_arg_desc.addr_type != AddrType::HIDDEN_INPUT) { - sk_arg_desc_tmp.reserved = normal_arg_desc.folded; - } else { - sk_arg_desc_tmp.reserved = *reinterpret_cast(normal_arg_desc.reserved); - } - sk_arg_desc_tmp.sub_addr_type = normal_arg_desc.addr_type; - sk_arg_desc_tmp.sub_idx = normal_arg_desc.ir_idx; - sk_arg_desc = *reinterpret_cast(&sk_arg_desc_tmp); - return GRAPH_SUCCESS; -} - -static graphStatus ConvertArgDescSk2Normal(const ArgDesc &sk_arg_desc, ArgDesc &arg_desc, int32_t &sub_op_id) { - if (sk_arg_desc.addr_type != AddrType::SUPER_KERNEL_SUB_NODE) { - arg_desc = sk_arg_desc; - sub_op_id = INT32_MAX; - return GRAPH_SUCCESS; - } - ArgDesc tmp_arg_desc{}; - const SkArgDesc *sk_arg_desc_tmp = reinterpret_cast(&sk_arg_desc); - sub_op_id = sk_arg_desc_tmp->ir_idx; - if (sk_arg_desc_tmp->sub_addr_type != AddrType::HIDDEN_INPUT) { - tmp_arg_desc.addr_type = sk_arg_desc_tmp->sub_addr_type; - tmp_arg_desc.ir_idx = sk_arg_desc_tmp->sub_idx; - tmp_arg_desc.folded = sk_arg_desc_tmp->folded; - } else { - const SkArgDescV2 *sk_arg_desc_v2_tmp = reinterpret_cast(&sk_arg_desc); - tmp_arg_desc.addr_type = sk_arg_desc_v2_tmp->sub_addr_type; - tmp_arg_desc.ir_idx = sk_arg_desc_v2_tmp->sub_idx; - tmp_arg_desc.folded = false; - *reinterpret_cast(tmp_arg_desc.reserved) = sk_arg_desc_v2_tmp->reserved; - } - arg_desc = tmp_arg_desc; - return GRAPH_SUCCESS; -} - -static graphStatus SknParser(const OpDescPtr &op_desc, const std::string &pattern_str, const AddrType type, - std::vector &arg_descs) { - GELOGD("get pattern %s, type %d", pattern_str.c_str(), type); - const std::string skn_str = "skn"; - GE_ASSERT_TRUE(pattern_str.substr(0, skn_str.length()) == skn_str); - int32_t sub_idx{0}; - bool has_idx{false}; - size_t i = skn_str.length(); - for (; i < pattern_str.size(); ++i) { // start after skn - if (isdigit(pattern_str[i])) { - sub_idx = sub_idx * kDecimalCarry + static_cast(pattern_str[i]) - kAsciiZero; - has_idx = true; - } else { - break; - } - } - GE_ASSERT(has_idx, "skn should have a concrete sub idx."); - NodePtr sub_node; - GE_ASSERT_SUCCESS(FindSkSubNode(op_desc, sub_idx, sub_node)); - std::vector sub_arg_descs; - std::string sub_pattern_str = pattern_str.substr(i); - GELOGD("get sub_pattern_str %s, type %d", sub_pattern_str.c_str(), type); - for (const auto &iter : kSkPatternToHandler) { - if (strncmp(sub_pattern_str.c_str(), iter.first.c_str(), iter.first.length()) == 0) { - GE_ASSERT_SUCCESS(iter.second.parse(sub_node->GetOpDesc(), sub_pattern_str, iter.second.type, sub_arg_descs)); - break; - } - } - GE_ASSERT_TRUE(sub_arg_descs.size() == 1); - ArgDesc tmp_sk_desc{}; - GE_ASSERT_GRAPH_SUCCESS(ConvertArgDescNormal2Sk(sub_arg_descs[0], sub_idx, tmp_sk_desc)); - GELOGD("get sub_pattern_str %s, sub_type %d, sub id %d", - sub_pattern_str.c_str(), sub_arg_descs[0].addr_type, sub_arg_descs[0].ir_idx); - - arg_descs.emplace_back(tmp_sk_desc); - return GRAPH_SUCCESS; -} - -static Status SknSerializer(std::stringstream &ss, const std::string &pattern, const ArgDesc &sk_arg_desc) { - ArgDesc tmp_arg_desc{}; - int32_t sub_op_id = 0; - GE_ASSERT_GRAPH_SUCCESS(ConvertArgDescSk2Normal(sk_arg_desc, tmp_arg_desc, sub_op_id)); - ss << pattern << sub_op_id; - bool founded = false; - for (const auto &iter : kSkPatternToHandler) { - if (iter.second.type == tmp_arg_desc.addr_type) { - iter.second.serialize(ss, iter.first, tmp_arg_desc); - founded = true; - break; - } - } - GE_ASSERT_TRUE(founded, "find %d no serialize func", tmp_arg_desc.addr_type); - return GRAPH_SUCCESS; -} - -static graphStatus SknCalcSize(const OpDescPtr &op_desc, const ArgDesc &sk_arg_desc, size_t &size) { - ArgDesc tmp_arg_desc{}; - int32_t sub_op_id = 0; - GE_ASSERT_GRAPH_SUCCESS(ConvertArgDescSk2Normal(sk_arg_desc, tmp_arg_desc, sub_op_id)); - NodePtr sub_node; - GE_ASSERT_SUCCESS(FindSkSubNode(op_desc, sub_op_id, sub_node)); - - bool founded = false; - for (const auto &iter : kSkPatternToHandler) { - if (iter.second.type == tmp_arg_desc.addr_type) { - GE_ASSERT_SUCCESS(iter.second.getArgsSize(sub_node->GetOpDesc(), tmp_arg_desc, size)); - founded = true; - break; - } - } - GE_ASSERT_TRUE(founded, "find %d no serialize func", tmp_arg_desc.addr_type); - return GRAPH_SUCCESS; -} - -static const std::map kPatternToHandler = { - {"i", {InputParser, InputCalcSize, ArrayLikeSerializer, AddrType::INPUT}}, - {"o", {OutputParser, OutputCalcSize, ArrayLikeSerializer, AddrType::OUTPUT}}, - {"ws", {WorkspaceParser, WorkspaceCalcSize, ArrayLikeSerializer, AddrType::WORKSPACE}}, - {"t", {DefaultParser, DefaultCalcSize, DefaultSerializer, AddrType::TILING}}, - {"i_desc", {IODescParser, InputDescCalcSize, ArrayLikeSerializer, AddrType::INPUT_DESC}}, - {"o_desc", {IODescParser, OutputDescCalcSize, ArrayLikeSerializer, AddrType::OUTPUT_DESC}}, - {"ffts_addr", {DefaultParser, DefaultCalcSize, DefaultSerializer, AddrType::FFTS_ADDR}}, - {"overflow_addr", {DefaultParser, DefaultCalcSize, DefaultSerializer, AddrType::OVERFLOW_ADDR}}, - {"t_ffts", {DefaultParser, DefaultCalcSize, FftsTilingSerializer, AddrType::TILING_FFTS}}, - {"hi", {HiddenInputParser, DefaultCalcSize, HiddenInputSerializer, AddrType::HIDDEN_INPUT}}, - {"*op_type", {DefaultParser, DefaultCalcSize, DefaultSerializer, AddrType::OP_TYPE}}, - {"tiling_context", {TilingContextParser, DefaultCalcSize, TilingContextSerializer, AddrType::TILING_CONTEXT}}, - {"", {PlaceholderParser, VariableWidthCalcSize, PlaceholderSerializer, AddrType::PLACEHOLDER}}, - {"#", {CustomValueParser, VariableWidthCalcSize, CustomValueSerializer, AddrType::CUSTOM_VALUE}}, - {"i_instance", {InputInstanceParser, InputInstanceCalcSize, ArrayLikeSerializer, ge::AddrType::INPUT_INSTANCE}}, - {"o_instance", {OutputInstanceParser, OutputInstanceCalcSize, ArrayLikeSerializer, ge::AddrType::OUTPUT_INSTANCE}}, - {"event_addr", {EventAddrParser, DefaultCalcSize, EventAddrSerializer, ge::AddrType::EVENT_ADDR}}, - {"skn", {SknParser, SknCalcSize, SknSerializer, ge::AddrType::SUPER_KERNEL_SUB_NODE}}, - }; - -void ArgsFormatDesc::Append(AddrType type, int32_t ir_idx, bool folded) { - int32_t idx = (type == AddrType::HIDDEN_INPUT ? 0 : ir_idx); - arg_descs_.push_back({type, idx, folded, {0}}); -} - -void ArgsFormatDesc::AppendTilingContext(TilingContextSubType sub_type) { - arg_descs_.push_back({AddrType::TILING_CONTEXT, static_cast(sub_type), false, {0}}); -} - -void ArgsFormatDesc::AppendPlaceholder(ArgsFormatWidth width) { - arg_descs_.push_back({AddrType::PLACEHOLDER, static_cast(width), false, {0}}); -} - -void ArgsFormatDesc::AppendCustomValue(uint64_t value, ArgsFormatWidth width) { - ArgDesc arg = {AddrType::CUSTOM_VALUE, static_cast(width), false, {0}}; - *reinterpret_cast(arg.reserved) = value; - arg_descs_.push_back(arg); -} - -std::string ArgsFormatDesc::ToString() const { - return Serialize(arg_descs_); -} - -graphStatus ArgsFormatDesc::GetArgsSize(const OpDescPtr &op_desc, size_t &args_size) const { - GE_ASSERT_NOTNULL(op_desc); - size_t total_size{0UL}; - for (const auto &arg_desc : arg_descs_) { - for (const auto &iter : kPatternToHandler) { - if (iter.second.type == arg_desc.addr_type) { - GE_ASSERT_SUCCESS(iter.second.getArgsSize(op_desc, arg_desc, total_size)); - } - } - } - args_size = total_size; - return GRAPH_SUCCESS; -} - -graphStatus ArgsFormatDesc::GetArgSize(const OpDescPtr &op_desc, const ArgDesc arg_desc, size_t &arg_size) { - GE_ASSERT_NOTNULL(op_desc); - for (const auto &iter : kPatternToHandler) { - if (iter.second.type == arg_desc.addr_type) { - GE_ASSERT_SUCCESS(iter.second.getArgsSize(op_desc, arg_desc, arg_size)); - return GRAPH_SUCCESS; - } - } - GELOGE(GRAPH_PARAM_INVALID, "arg_desc type [%d] is unsupported.", static_cast(arg_desc.addr_type)); - return GRAPH_PARAM_INVALID; -} - -static graphStatus EasyParser(const std::string &pattern_str, const AddrType type, const std::string &prefix_str, - std::vector &arg_descs) { - if (prefix_str + "*" == pattern_str) { - // i*, o*,ws* - arg_descs.push_back({type, kAmbiguousIrIdx, false, {0U}}); - return GRAPH_SUCCESS; - } - // 处理:i0, o0,ws0, i0*, o0*等场景 - int32_t ir_idx{kAmbiguousIrIdx}; - std::string scan_str = prefix_str + "%d"; - GE_ASSERT_TRUE(sscanf_s(pattern_str.c_str(), scan_str.c_str(), &ir_idx) == kDigitFormatCnt, - "Args format {%s} is invalid.", pattern_str.c_str()); - const bool folded = pattern_str[pattern_str.length() - 1UL] != '*'; - arg_descs.push_back({type, ir_idx, folded, {0U}}); - return GRAPH_SUCCESS; -} - -static graphStatus SingleParser(const std::string &pattern_str, const OpDescPtr &op_desc, bool easy_mode, - std::vector &arg_descs, bool &parsed) { - for (const auto &iter : kPatternToHandler) { - if (strncmp(pattern_str.c_str(), iter.first.c_str(), iter.first.length()) == 0) { - if (easy_mode && kNeedEasyParserTypes.count(iter.second.type) > 0UL) { - GE_ASSERT_SUCCESS(EasyParser(pattern_str, iter.second.type, iter.first, arg_descs)); - } else { - GE_ASSERT_SUCCESS(iter.second.parse(op_desc, pattern_str, iter.second.type, arg_descs)); - } - parsed = true; - break; - } - } - return GRAPH_SUCCESS; -} - -// Compatible with the old offline model. -graphStatus ArgsFormatDesc::Parse(const OpDescPtr &op_desc, const std::string &str, std::vector &arg_descs) { - return ArgsFormatDesc::Parse(op_desc, str, arg_descs, false); -} - -graphStatus ArgsFormatDesc::Parse(const OpDescPtr &op_desc, const std::string &str, std::vector &arg_descs, - const bool easy_mode) { - arg_descs.clear(); - size_t start_idx = 0UL; - while (start_idx < str.size()) { - GE_ASSERT(str[start_idx] == '{', "SyntaxError: argsformat should be surrounded by '{','}'"); - size_t end_idx = start_idx + 1UL; - bool parsed{false}; - while (end_idx < str.size()) { - if (str[end_idx] == '}') { - std::string pattern_str = str.substr(start_idx + 1, end_idx - start_idx - 1); - GE_ASSERT_SUCCESS(SingleParser(pattern_str, op_desc, easy_mode, arg_descs, parsed), - "args format [%s] parse failed.", pattern_str.c_str()); - start_idx = end_idx + 1UL; - break; - } - ++end_idx; - } - GE_ASSERT(parsed, "SyntaxError: argsformat should be surrounded by '{','}'"); - } - return GRAPH_SUCCESS; -} - -std::string ArgsFormatDesc::Serialize(const std::vector &arg_descs) { - std::stringstream ss; - for (const auto &arg_desc : arg_descs) { - for (const auto &iter : kPatternToHandler) { - if (iter.second.type == arg_desc.addr_type) { - ss << '{'; - iter.second.serialize(ss, iter.first, arg_desc); - ss << '}'; - } - } - } - return ss.str(); -} - -void ArgsFormatDesc::Clear() { - arg_descs_.clear(); -} - -graphStatus ArgsFormatDesc::ConvertArgDescSkToNormal(const ArgDesc &sk_arg_desc, - ArgDesc &arg_desc, int32_t &sub_op_id) { - return ConvertArgDescSk2Normal(sk_arg_desc, arg_desc, sub_op_id); -} - -graphStatus ArgsFormatDesc::ConvertToSuperKernelArgFormat(const NodePtr &sk_node, - const NodePtr &sub_node, const std::string &sub_node_arg_format, std::string &sk_arg_format) { - GE_ASSERT_NOTNULL(sk_node); - GE_ASSERT_NOTNULL(sub_node); - auto sk_opdesc = sk_node->GetOpDesc(); - GE_ASSERT_NOTNULL(sk_opdesc); - auto sub_op = sub_node->GetOpDesc(); - GE_ASSERT_NOTNULL(sub_op); - GELOGI("current sub_op %s arg format %s, sk %s arg format %s", - sub_node->GetNamePtr(), sub_node_arg_format.c_str(), sk_node->GetNamePtr(), sk_arg_format.c_str()); - - std::vector cur_op_arg_descs; - ArgsFormatDesc::Parse(sub_op, sub_node_arg_format, cur_op_arg_descs, false); - std::vector append_sk_arg_descs; - for (auto &arg_desc : cur_op_arg_descs) { - ArgDesc tmp_arg_desc{}; - GE_ASSERT_GRAPH_SUCCESS(ConvertArgDescNormal2Sk(arg_desc, sub_op->GetId(), tmp_arg_desc)); - append_sk_arg_descs.emplace_back(tmp_arg_desc); - } - sk_arg_format += ArgsFormatDesc::Serialize(append_sk_arg_descs); - - const size_t max_log_string_len = 800U; - size_t index = 0U; - while (index < sk_arg_format.length()) { - GELOGI("%s", sk_arg_format.substr(index, max_log_string_len).c_str()); - index += max_log_string_len; - } - return GRAPH_SUCCESS; -} -} // namespace ge diff --git a/graph/utils/args_format_desc_utils.cc b/graph/utils/args_format_desc_utils.cc deleted file mode 100644 index 68820e1b34d1a0ac2015ed03960a8eae6f2cb0ac..0000000000000000000000000000000000000000 --- a/graph/utils/args_format_desc_utils.cc +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ - -#include "graph/utils/args_format_desc_utils.h" -#include -#include "common/checker.h" -#include "common/ge_common/debug/ge_log.h" -#include "graph/args_format_desc.h" - -namespace ge { - -void ArgsFormatDescUtils::Append(std::vector &arg_descs, ge::AddrType type, int32_t ir_idx, bool folded) { - arg_descs.push_back({type, ir_idx, folded, {0}}); -} - -void ArgsFormatDescUtils::AppendTilingContext(std::vector &arg_descs, ge::TilingContextSubType sub_type) { - arg_descs.push_back({AddrType::TILING_CONTEXT, static_cast(sub_type), false, {0}}); -} - -graphStatus ArgsFormatDescUtils::InsertHiddenInputs(std::vector &arg_descs, int32_t insert_pos, - HiddenInputsType hidden_type, size_t input_cnt) { - if (insert_pos < 0) { - insert_pos = arg_descs.size(); - } - GE_ASSERT_TRUE(static_cast(insert_pos) <= arg_descs.size()); - ArgDesc arg_desc = {AddrType::HIDDEN_INPUT, -1, false, {0}}; - for (size_t i = 0; i < input_cnt; ++i, ++insert_pos) { - arg_desc.ir_idx = static_cast(i); - *reinterpret_cast(arg_desc.reserved) = static_cast(hidden_type); - arg_descs.insert(arg_descs.begin() + insert_pos, arg_desc); - } - return GRAPH_SUCCESS; -} - -graphStatus ArgsFormatDescUtils::InsertCustomValue(std::vector &arg_descs, int32_t insert_pos, - uint64_t custom_value) { - ArgDesc arg = {AddrType::CUSTOM_VALUE, static_cast(ArgsFormatWidth::BIT64), false, {0}}; - *reinterpret_cast(arg.reserved) = custom_value; - if (insert_pos < 0) { - arg_descs.emplace_back(arg); - } else { - GE_ASSERT_TRUE(static_cast(insert_pos) <= arg_descs.size()); - arg_descs.insert(arg_descs.begin() + insert_pos, arg); - } - return GRAPH_SUCCESS; -} - -graphStatus ArgsFormatDescUtils::Parse(const std::string &str, std::vector &arg_descs) { - return ArgsFormatDesc::Parse(nullptr, str, arg_descs, true); -} - -std::string ArgsFormatDescUtils::Serialize(const std::vector &arg_descs) { - return ArgsFormatDesc::Serialize(arg_descs); -} - -std::string ArgsFormatDescUtils::ToString(const std::vector &arg_descs) { - return ArgsFormatDesc::Serialize(arg_descs); -} -} // namespace ge \ No newline at end of file diff --git a/graph/utils/connection_matrix.cc b/graph/utils/connection_matrix.cc deleted file mode 100644 index de8bbcecfbcfd4da1efc1d581abab154a6fdf166..0000000000000000000000000000000000000000 --- a/graph/utils/connection_matrix.cc +++ /dev/null @@ -1,198 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/utils/connection_matrix.h" -#include "connection_matrix_impl.h" -#include "graph/debug/ge_log.h" -#include "graph/debug/ge_util.h" - -namespace ge { -ConnectionMatrix::ConnectionMatrix(const ComputeGraphPtr &graph) - : impl_(ComGraphMakeUnique(graph)) {} - -bool ConnectionMatrix::IsConnected(const NodePtr &a, const NodePtr &b) const { - if (impl_ == nullptr) { - return false; - } - return impl_->IsConnected(a, b); -} - -void ConnectionMatrix::SetConnectivity(const Node::Vistor &inputs, const NodePtr &node) { - if (impl_ == nullptr) { - return; - } - impl_->SetConnectivity(inputs, node); -} - -graphStatus ConnectionMatrix::Generate(const ComputeGraphPtr &graph) { - GE_CHECK_NOTNULL(impl_); - return impl_->Generate(graph); -} - -void ConnectionMatrix::Update(const ComputeGraphPtr &graph, const std::vector &fusion_nodes) { - if (impl_ == nullptr) { - return; - } - impl_->Update(graph, fusion_nodes); -} - -void ConnectionMatrix::ExpandAndUpdate(const vector &fusion_nodes, const std::string &node_name) { - if (impl_ == nullptr) { - return; - } - impl_->ExpandAndUpdate(fusion_nodes, node_name); -} - -ConnectionMatrixImpl::ConnectionMatrixImpl(const ComputeGraphPtr &graph) : graph_(graph) { - const auto direct_nodes = graph->GetDirectNode(); - size_ = direct_nodes.size(); - bit_maps_.reserve(size_); - uint64_t index_loop = 0; - for (const auto &node : direct_nodes) { - name_to_index_[node->GetName()] = index_loop; - bit_maps_.emplace_back(size_); - index_loop++; - } - used_ = size_; -}; - -ConnectionMatrixImpl::~ConnectionMatrixImpl() { - bit_maps_.clear(); - name_to_index_.clear(); -} - -uint64_t ConnectionMatrixImpl::AddNode(const std::string &op_name) { - if (used_ + 1 >= size_) { - size_t new_size = size_ + expand_step_; - for (auto &m: bit_maps_) { - m.ResizeBits(new_size); - } - - ge::LargeBitmap new_bit_vector(new_size); - bit_maps_.resize(new_size, new_bit_vector); - for (size_t i = used_; i < new_size; ++i) { - bit_maps_[i].SetValues(0); - } - - size_ = new_size; - } - - uint64_t new_index = used_; - ++used_; - name_to_index_[op_name] = new_index; - return new_index; -} - -void ConnectionMatrixImpl::ExpandAndUpdate(const vector &fusion_nodes, const std::string &node_name) { - uint64_t new_index = AddNode(node_name); - ge::LargeBitmap &new_bit_vector = GetBitMap(new_index); - - // update - new_bit_vector.SetBit(new_index); - std::vector fusion_indexs(fusion_nodes.size(), 0); - for (size_t i = 0U; i < fusion_nodes.size(); ++i) { - auto index = GetIndex(fusion_nodes[i]); - new_bit_vector.Or(GetBitMap(index)); - fusion_indexs[i] = index; - } - - for (size_t i = 0; i < used_; ++i) { - ge::LargeBitmap &node_map = bit_maps_[i]; - for (size_t j = 0; j < fusion_nodes.size(); ++j) { - if (node_map.GetBit(fusion_indexs[j])) { - node_map.Or(new_bit_vector); - break; - } - } - } -} - -graphStatus ConnectionMatrixImpl::Generate(const ComputeGraphPtr &graph) { - if (graph_ == nullptr) { - graph_ = graph; - } - for (auto &node : graph->GetDirectNode()) { - const auto inputs = node->GetInAllNodes(); - SetConnectivity(inputs, node); - } - return GRAPH_SUCCESS; -} - -void ConnectionMatrixImpl::Update(const ComputeGraphPtr &graph, const vector &fusion_nodes) { - if (graph_ == nullptr) { - return; - } - if (graph != graph_) { - GELOGW("Input graph %s is not the same one %s when contribute connection matrix.", graph->GetName().c_str(), - graph_->GetName().c_str()); - return; - } - LargeBitmap new_bit_vector(graph->GetDirectNode().size()); - new_bit_vector.SetValues(0U); - for (size_t i = 0U; i < fusion_nodes.size(); i++) { - new_bit_vector.Or(GetBitMap(fusion_nodes[i])); - } - for (auto &node : graph->GetDirectNode()) { - bool is_connected_to_fusion = false; - for (size_t i = 0U; i < fusion_nodes.size(); i++) { - if (GetBitMap(node).GetBit(static_cast(GetIndex(fusion_nodes[i])))) { - is_connected_to_fusion = true; - break; - } - } - if (is_connected_to_fusion) { - GetBitMap(node).Or(new_bit_vector); - } - } -} - -void ConnectionMatrixImpl::SetConnectivity(const Node::Vistor &inputs, const NodePtr &node) { - LargeBitmap &bitmap = GetBitMap(node); - if (std::find(inputs.begin(), inputs.end(), node) == inputs.end()) { - bitmap.SetValues(0U); - } - - bitmap.SetBit(static_cast(GetIndex(node))); - for (const NodePtr &input : inputs) { - if (input != node) { - bitmap.Or(GetBitMap(input)); - } - } -} - -uint64_t ConnectionMatrixImpl::GetIndex(const std::string &op_name) const { - const auto iter = name_to_index_.find(op_name); - if (iter != name_to_index_.end()) { - return iter->second; - } else { - GELOGW("node %s is not found in name_to_index_", op_name.c_str()); - return 0; - } -} - -uint64_t ConnectionMatrixImpl::GetIndex(const NodePtr &node) const { - return GetIndex(node->GetName()); -} - -bool ConnectionMatrixImpl::IsConnected(const NodePtr &a, const NodePtr &b) const { - return GetBitMap(b).GetBit(static_cast(GetIndex(a))); -} - -const LargeBitmap &ConnectionMatrixImpl::GetBitMap(const NodePtr &node) const { - return bit_maps_[static_cast(GetIndex(node))]; -} - -LargeBitmap &ConnectionMatrixImpl::GetBitMap(const NodePtr &node) { - return bit_maps_[static_cast(GetIndex(node))]; -} - -LargeBitmap &ConnectionMatrixImpl::GetBitMap(uint64_t index) { - return bit_maps_[index]; -} -} // namespace ge diff --git a/graph/utils/connection_matrix_impl.h b/graph/utils/connection_matrix_impl.h deleted file mode 100644 index 679e7311b4d28b30e7223220057cae1a9e147b24..0000000000000000000000000000000000000000 --- a/graph/utils/connection_matrix_impl.h +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_CONNECTION_MATRIX_IMPL_H_ -#define GRAPH_CONNECTION_MATRIX_IMPL_H_ - -#include "graph/debug/ge_attr_define.h" -#include "graph/node.h" -#include "graph/graph.h" -#include "graph/compute_graph.h" -#include "common/large_bm.h" - -namespace ge { -class ConnectionMatrixImpl { -public: - explicit ConnectionMatrixImpl(const ComputeGraphPtr &graph); - - ~ConnectionMatrixImpl(); - - bool IsConnected(const NodePtr &a, const NodePtr &b) const; - - // inputs are all input nodes of parameter node. - // if there is a path between A->B, then B will own A's - // connectivity. The reason is --- - // If some node can reach A, than it can also reach B. - void SetConnectivity(const Node::Vistor &inputs, const NodePtr &node); - - /* Computes the connectivity between two nodes in the - * computation. The returned ConnectivityMatrix is constructed such that - * ConnectivityMatrix::IsConnected(a, b) returns true iff there exists a - * directed path (from producer to consumer) from 'a' to 'b'. Both data - * connection and control connection are considered for connectivity. - * A node is connected to itself. */ - graphStatus Generate(const ComputeGraphPtr &graph); - - // update reachablity map for fused nodes. - void Update(const ComputeGraphPtr &graph, const std::vector &fusion_nodes); - - uint64_t AddNode(const string &op_name); - - void ExpandAndUpdate(const vector &fusion_nodes, const std::string &node_name); - -private: - ConnectionMatrixImpl() = delete; - uint64_t GetIndex(const NodePtr &node) const; - - uint64_t GetIndex(const std::string &op_name) const; - - const LargeBitmap &GetBitMap(const NodePtr &node) const; - - LargeBitmap &GetBitMap(const NodePtr &node); - - LargeBitmap &GetBitMap(uint64_t index); - - size_t size_ = 0; - size_t used_ = 0; - size_t expand_step_ = 64; - - std::vector bit_maps_; - - std::unordered_map name_to_index_; - - ComputeGraphPtr graph_; -}; -} -#endif // GRAPH_CONNECTION_MATRIX_H_ diff --git a/graph/utils/constant_utils.cc b/graph/utils/constant_utils.cc deleted file mode 100644 index d8d5bb7dc407f5e66c56aafb8b26d862f48455c0..0000000000000000000000000000000000000000 --- a/graph/utils/constant_utils.cc +++ /dev/null @@ -1,201 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/utils/constant_utils.h" -#include "common/checker.h" -#include "debug/ge_util.h" -#include "graph/utils/file_utils.h" -#include "graph/debug/ge_op_types.h" -#include "graph/debug/ge_attr_define.h" -#include "common/ge_common/debug/ge_log.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/tensor_utils.h" -#include "graph/utils/tensor_adapter.h" - -namespace ge { -bool ConstantUtils::IsConstant(const NodePtr &node) { - return IsConstant(node->GetOpDesc()); -} - -bool ConstantUtils::IsConstant(const OpDescPtr &op_desc) { - if ((op_desc->GetType() == CONSTANT) || (op_desc->GetType() == CONSTANTOP)) { - return true; - } - return IsPotentialConst(op_desc); -} - -bool ConstantUtils::IsPotentialConst(const OpDescPtr &op_desc) { - bool is_potential_const = false; - const auto has_attr = AttrUtils::GetBool(op_desc, ATTR_NAME_POTENTIAL_CONST, is_potential_const); - return (has_attr && is_potential_const); -} - -bool ConstantUtils::IsRealConst(const OpDescPtr &op_desc) { - return ((op_desc->GetType() == CONSTANT) || (op_desc->GetType() == CONSTANTOP)); -} - -bool ConstantUtils::GetWeight(const OpDescPtr &op_desc, const uint32_t index, ConstGeTensorPtr &weight) { - if (AttrUtils::GetTensor(op_desc, ATTR_NAME_WEIGHTS, weight)) { - return true; - } - if (!IsPotentialConst(op_desc)) { - return false; - } - - std::vector weight_indices; - std::vector weights; - if (!GetPotentialWeight(op_desc, weight_indices, weights)) { - return false; - } - for (size_t i = 0U; i < weight_indices.size(); ++i) { - if (weight_indices[i] == index) { - weight = weights[i]; - return true; - } - } - return false; -} - -bool ConstantUtils::MutableWeight(const OpDescPtr &op_desc, const uint32_t index, GeTensorPtr &weight) { - if (AttrUtils::MutableTensor(op_desc, ATTR_NAME_WEIGHTS, weight)) { - return true; - } - if (!IsPotentialConst(op_desc)) { - return false; - } - std::vector weight_indices; - std::vector weights; - if (!MutablePotentialWeight(op_desc, weight_indices, weights)) { - return false; - } - - for (size_t i = 0U; i < weight_indices.size(); ++i) { - if (weight_indices[i] == index) { - weight = weights[i]; - return true; - } - } - return false; -} -bool ConstantUtils::SetWeight(const OpDescPtr &op_desc, const uint32_t index, const GeTensorPtr weight) { - if (IsRealConst(op_desc) && - AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight)) { - return true; - } - if (!IsPotentialConst(op_desc)) { - return false; - } - std::vector weight_indices; - std::vector weights; - if (!MutablePotentialWeight(op_desc, weight_indices, weights)) { - return false; - } - - for (size_t i = 0U; i < weight_indices.size(); ++i) { - if (weight_indices[i] == index) { - weights[i] = weight; - return AttrUtils::SetListTensor(op_desc, ATTR_NAME_POTENTIAL_WEIGHT, weights); - } - } - return false; -} -bool ConstantUtils::GetPotentialWeight(const OpDescPtr &op_desc, - std::vector &weight_indices, - std::vector &weights) { - // check potential const attrs - if (!AttrUtils::GetListInt(op_desc, ATTR_NAME_POTENTIAL_WEIGHT_INDICES, weight_indices)) { - GELOGW("Missing ATTR_NAME_POTENTIAL_WEIGHT_INDICES attr on potential const %s.", op_desc->GetName().c_str()); - return false; - } - if (!AttrUtils::GetListTensor(op_desc, ATTR_NAME_POTENTIAL_WEIGHT, weights)) { - GELOGW("Missing ATTR_NAME_POTENTIAL_WEIGHT attr on potential const %s.", op_desc->GetName().c_str()); - return false; - } - if (weight_indices.size() != weights.size()) { - GELOGW("Weight indices not match with weight size on potential const %s.", op_desc->GetName().c_str()); - return false; - } - return true; -} - -bool ConstantUtils::MutablePotentialWeight(const OpDescPtr &op_desc, std::vector &weight_indices, - std::vector &weights) { - // check potential const attrs - if (!AttrUtils::GetListInt(op_desc, ATTR_NAME_POTENTIAL_WEIGHT_INDICES, weight_indices)) { - GELOGW("Missing ATTR_NAME_POTENTIAL_WEIGHT_INDICES attr on potential const %s.", op_desc->GetName().c_str()); - return false; - } - if (!AttrUtils::MutableListTensor(op_desc, ATTR_NAME_POTENTIAL_WEIGHT, weights)) { - GELOGW("Missing ATTR_NAME_POTENTIAL_WEIGHT attr on potential const %s.", op_desc->GetName().c_str()); - return false; - } - if (weight_indices.size() != weights.size()) { - GELOGW("Weight indices not match with weight size on potential const %s.", op_desc->GetName().c_str()); - return false; - } - return true; -} -bool ConstantUtils::MarkPotentialConst(const OpDescPtr &op_desc, - const std::vector indices, - const std::vector weights) { - if (indices.size() != weights.size()) { - return false; - } - return (AttrUtils::SetBool(op_desc, ATTR_NAME_POTENTIAL_CONST, true) && - AttrUtils::SetListInt(op_desc, ATTR_NAME_POTENTIAL_WEIGHT_INDICES, indices) && - AttrUtils::SetListTensor(op_desc, ATTR_NAME_POTENTIAL_WEIGHT, weights)); -} -bool ConstantUtils::UnMarkPotentialConst(const OpDescPtr &op_desc) { - if (op_desc->HasAttr(ATTR_NAME_POTENTIAL_CONST) && - op_desc->HasAttr(ATTR_NAME_POTENTIAL_WEIGHT_INDICES) && - op_desc->HasAttr(ATTR_NAME_POTENTIAL_WEIGHT)) { - (void)op_desc->DelAttr(ATTR_NAME_POTENTIAL_CONST); - (void)op_desc->DelAttr(ATTR_NAME_POTENTIAL_WEIGHT_INDICES); - (void)op_desc->DelAttr(ATTR_NAME_POTENTIAL_WEIGHT); - return true; - } - return false; -} - -bool ConstantUtils::GetWeightFromFile(const OpDescPtr &op_desc, ConstGeTensorPtr &weight) { - if (op_desc->GetType() != FILECONSTANT) { - return false; - } - auto output_desc = op_desc->MutableOutputDesc(0U); - GE_ASSERT_NOTNULL(output_desc); - DataType out_type = ge::DT_UNDEFINED; - (void)AttrUtils::GetDataType(op_desc, "dtype", out_type); - output_desc->SetDataType(out_type); - int64_t weight_size = 0; - GE_ASSERT_SUCCESS(TensorUtils::GetTensorSizeInBytes(*output_desc, weight_size), "Failed to get weight size"); - std::string file_path; - (void)AttrUtils::GetStr(op_desc, ATTR_NAME_LOCATION, file_path); - int64_t attr_offset = 0; - (void)AttrUtils::GetInt(op_desc, ATTR_NAME_OFFSET, attr_offset); - const auto offset = static_cast(attr_offset); - int64_t attr_length = 0; - (void)AttrUtils::GetInt(op_desc, ATTR_NAME_LENGTH, attr_length); - const auto length = static_cast(attr_length); - if (file_path.empty()) { - (void)AttrUtils::GetStr(op_desc, ATTR_NAME_FILE_PATH, file_path); - if (file_path.empty()) { - GELOGW("Failed to get file constant weight path, node:%s", op_desc->GetName().c_str()); - return false; - } - } - const size_t file_length = (length == 0U ? static_cast(weight_size) : length); - const auto &bin_buff = GetBinFromFile(file_path, offset, file_length); - GE_ASSERT_NOTNULL(bin_buff); - const auto tensor = - ComGraphMakeShared(*output_desc, reinterpret_cast(bin_buff.get()), file_length); - GE_ASSERT_NOTNULL(tensor); - weight = tensor; - return true; -} -} // namespace ge diff --git a/graph/utils/cycle_detector.cc b/graph/utils/cycle_detector.cc deleted file mode 100644 index 10e7d0b8db0213ee6725843f2a60a1ac39180d27..0000000000000000000000000000000000000000 --- a/graph/utils/cycle_detector.cc +++ /dev/null @@ -1,107 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/utils/cycle_detector.h" -#include "graph/debug/ge_log.h" -#include "graph/debug/ge_util.h" - -namespace ge { -namespace { -void PrintAllNodes(const std::vector &scope_nodes) { - for (const auto &node : scope_nodes) { - if (node == nullptr) { - GELOGD("type: null, name: null"); - } else { - GELOGD("type: %s, name: %s", node->GetType().c_str(), node->GetName().c_str()); - } - } -} - -bool CheckEachPeerOut(const NodePtr &node, const std::unordered_set &scope_nodes_set, - const std::vector &scope_nodes, - const std::unique_ptr &connectivity) { - for (const auto &peer_out : node->GetOutAllNodes()) { - if (scope_nodes_set.count(peer_out) > 0) { - continue; - } - for (const auto &node_temp : scope_nodes) { - if ((node_temp == nullptr) || (node_temp == node)) { - continue; - } - GELOGD("Check %s and %s.", peer_out->GetName().c_str(), node_temp->GetName().c_str()); - - if (connectivity->IsConnected(peer_out, node_temp)) { - GELOGD("There is a path between %s and %s after fusing:", peer_out->GetName().c_str(), - node_temp->GetName().c_str()); - PrintAllNodes(scope_nodes); - return true; - } - } - } - return false; -} - -bool DetectOneScope(const std::vector &scope_nodes, - const std::unique_ptr &connectivity) { - /* Create a set for accelerating the searching. */ - const std::unordered_set scope_nodes_set(scope_nodes.begin(), scope_nodes.end()); - - for (const auto &node : scope_nodes) { - if (node == nullptr) { - continue; - } - if (CheckEachPeerOut(node, scope_nodes_set, scope_nodes, connectivity)) { - return true; - } - } - return false; -} -} // namespace -graphStatus CycleDetector::Init(const ComputeGraphPtr &graph) { - if (connectivity_ == nullptr) { - connectivity_ = ComGraphMakeUnique(graph); - if (connectivity_ == nullptr) { - GELOGW("Make shared failed"); - return FAILED; - } - - const Status ret = connectivity_->Generate(graph); - if (ret != SUCCESS) { - GE_LOGE("Cannot generate connection matrix for graph %s.", graph->GetName().c_str()); - return FAILED; - } - } - return SUCCESS; -} - -bool CycleDetector::HasDetectedCycle(const std::vector> &fusion_nodes) { - for (const auto &scope_nodes : fusion_nodes) { - if (DetectOneScope(scope_nodes, connectivity_)) { - return true; - } - } - return false; -} - -void CycleDetector::Update(const ComputeGraphPtr &graph, const std::vector &fusion_nodes) { - if (connectivity_ == nullptr) { - GELOGW("Connectivity is empty, please call HasDetectedCycle first."); - return; - } - connectivity_->Update(graph, fusion_nodes); -} - -void CycleDetector::ExpandAndUpdate(const vector &fusion_nodes, const std::string &node_name) { - if (connectivity_ == nullptr) { - GELOGW("Connectivity is empty, please generate first."); - return; - } - connectivity_->ExpandAndUpdate(fusion_nodes, node_name); -} -} // namespace ge diff --git a/graph/utils/dumper/ge_graph_dumper.cc b/graph/utils/dumper/ge_graph_dumper.cc deleted file mode 100644 index 6a5da6cf057782b991cdee3af2b2c5e891549ded..0000000000000000000000000000000000000000 --- a/graph/utils/dumper/ge_graph_dumper.cc +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "ge_graph_dumper.h" - -namespace ge { -namespace { -struct DefaultDumper : public GeGraphDumper { - void Dump(const ge::ComputeGraphPtr &graph, const std::string &suffix) override { - (void)graph; - (void)suffix; - } -}; -DefaultDumper default_dumper; -GeGraphDumper *register_checker = &default_dumper; -} - -GeGraphDumper &GraphDumperRegistry::GetDumper() { - return *register_checker; -} -void GraphDumperRegistry::Register(GeGraphDumper &dumper) { - register_checker = &dumper; -} -void GraphDumperRegistry::Unregister() { - register_checker = &default_dumper; -} -} // namespace ge diff --git a/graph/utils/dumper/ge_graph_dumper.h b/graph/utils/dumper/ge_graph_dumper.h deleted file mode 100644 index 8145fd4019be4230734c8677a16b7812f3aa6064..0000000000000000000000000000000000000000 --- a/graph/utils/dumper/ge_graph_dumper.h +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_UTILS_DUMPER_GE_GRAPH_DUMPER_H_ -#define GRAPH_UTILS_DUMPER_GE_GRAPH_DUMPER_H_ - -#include "graph/compute_graph.h" - -namespace ge { -struct GeGraphDumper { - GeGraphDumper() = default; - GeGraphDumper(const GeGraphDumper &) = delete; - GeGraphDumper &operator=(const GeGraphDumper &) = delete; - GeGraphDumper(GeGraphDumper &&) = delete; - GeGraphDumper &operator=(GeGraphDumper &&) = delete; - virtual void Dump(const ge::ComputeGraphPtr &graph, const std::string &suffix) = 0; - virtual ~GeGraphDumper() = default; -}; - -struct GraphDumperRegistry { - static GeGraphDumper &GetDumper(); - static void Register(GeGraphDumper &dumper); - static void Unregister(); -}; - -} // namespace ge - -#endif diff --git a/graph/utils/enum_attr_utils.cc b/graph/utils/enum_attr_utils.cc deleted file mode 100644 index a4cd43d08ae2413c83ee9c33d2d7e5dc6945d3f7..0000000000000000000000000000000000000000 --- a/graph/utils/enum_attr_utils.cc +++ /dev/null @@ -1,128 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/utils/enum_attr_utils.h" -#include - -namespace ge { -void EnumAttrUtils::GetEnumAttrName(vector &enum_attr_names, const string &attr_name, string &enum_attr_name, - bool &is_new_attr) { - uint32_t position; - const auto iter = std::find(enum_attr_names.begin(), enum_attr_names.end(), attr_name); - if (iter != enum_attr_names.end()) { - is_new_attr = false; - position = static_cast(std::distance(enum_attr_names.begin(), iter)); - } else { - is_new_attr = true; - position = static_cast(enum_attr_names.size()); - enum_attr_names.emplace_back(attr_name); - } - Encode(position, enum_attr_name); -} - -void EnumAttrUtils::GetEnumAttrValue(vector &enum_attr_values, const string &attr_value, - int64_t &enum_attr_value) { - const auto iter = std::find(enum_attr_values.begin(), enum_attr_values.end(), attr_value); - if (iter != enum_attr_values.end()) { - enum_attr_value = static_cast(std::distance(enum_attr_values.begin(), iter)); - } else { - enum_attr_value = static_cast(enum_attr_values.size()); - enum_attr_values.emplace_back(attr_value); - } -} - -void EnumAttrUtils::GetEnumAttrValues(vector &enum_attr_values, const vector &attr_values, - vector &enum_values) { - int64_t enum_attr_value; - for (const auto &attr_value : attr_values) { - GetEnumAttrValue(enum_attr_values, attr_value, enum_attr_value); - enum_values.emplace_back(enum_attr_value); - } -} - -graphStatus EnumAttrUtils::GetAttrName(const vector &enum_attr_names, const vector name_use_string_values, - const string &enum_attr_name, string &attr_name, bool &is_value_string) { - if (enum_attr_name.empty()) { - GELOGE(GRAPH_FAILED, "enum_attr_name is empty."); - return GRAPH_FAILED; - } - static std::string prefix_value(kAppendNum, prefix); - // 判断enum_attr_name字符串是否Enum化,Enum字符串以'\0'字符开始 - if (enum_attr_name.rfind(prefix_value, 0U) == 0U) { - size_t position = 0U; - Decode(enum_attr_name, position); - if (position < enum_attr_names.size() && position < name_use_string_values.size()) { - attr_name = enum_attr_names[position]; - is_value_string = name_use_string_values[position]; - return GRAPH_SUCCESS; - } else { - GELOGE(GRAPH_FAILED, - "position[%zu] is not less than enum_attr_names size[%zu] or name_use_string_values size[%zu].", - position, enum_attr_names.size(), name_use_string_values.size()); - return GRAPH_FAILED; - } - } else { - attr_name = enum_attr_name; - is_value_string = false; - return GRAPH_SUCCESS; - } - return GRAPH_SUCCESS; -} - -graphStatus EnumAttrUtils::GetAttrValue(const vector &enum_attr_values, const int64_t enum_attr_value, - string &attr_value) { - if (static_cast(enum_attr_value) < enum_attr_values.size()) { - attr_value = enum_attr_values[enum_attr_value]; - return GRAPH_SUCCESS; - } else { - GELOGE(GRAPH_FAILED, "enum_attr_value[%lld] is not less than enum_attr_values size[%zu].", - enum_attr_value, enum_attr_values.size()); - return GRAPH_FAILED; - } -} - -graphStatus EnumAttrUtils::GetAttrValues(const vector &enum_attr_values, const vector &enum_values, - vector &attr_values) { - string attr_value; - for (const auto enum_attr_value : enum_values) { - if (GetAttrValue(enum_attr_values, enum_attr_value, attr_value) == GRAPH_SUCCESS) { - attr_values.emplace_back(attr_value); - } else { - return GRAPH_FAILED; - } - } - return GRAPH_SUCCESS; -} - -// 属性名称定义为string类型,此处编码用Assci编码, Assci码的第一位为结束符,不用;可使用127位 -// 1位字符范围:[0, 126]; 两位字符范围:[0, 127^2 - 1]; N位字符的范围:[0, 127^N - 1] -void EnumAttrUtils::Encode(const uint32_t src, string &dst) { - // 按照上述字符范围获取源数据的位数 - uint32_t src_num = static_cast(log(src) / log(kMaxValueOfEachDigit)) + 1U; - - // 每个ENUM化字符串编码的前缀为'\0', 用于区分哪些字符串未做ENUM化 - dst.append(kAppendNum, prefix); - char_t data; - for (uint32_t i = 0U; i < src_num; i++) { - // 获取每一位的值,取位数后会加1,防止编码中出现'\0'字符 - data = static_cast((src / static_cast(pow(kMaxValueOfEachDigit, i))) % kMaxValueOfEachDigit); - dst.append(kAppendNum, data + 1); - } -} - -// 将127位编码的src转换成实际的数字 -void EnumAttrUtils::Decode(const string &src, size_t &dst) { - // 解码从第2位开始,第一位是标志符'\0' - for (size_t i = 1U; i < src.size(); i++) { - dst += static_cast(src[i] - 1) * static_cast(pow(kMaxValueOfEachDigit, (i - 1U))); - } -} - -} // namespace ge - diff --git a/graph/utils/execute_graph_adapter.cc b/graph/utils/execute_graph_adapter.cc deleted file mode 100644 index 4bda800d01e7924f63e68529eada3ea6a3d1c5c5..0000000000000000000000000000000000000000 --- a/graph/utils/execute_graph_adapter.cc +++ /dev/null @@ -1,198 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/utils/execute_graph_adapter.h" -#include "common/checker.h" -#include "fast_graph/fast_graph_impl.h" -#include "graph/compute_graph.h" -#include "graph/normal_graph/compute_graph_impl.h" -#include "graph/debug/ge_util.h" -#include "graph/utils/execute_graph_utils.h" -#include "graph/utils/graph_utils.h" -#include "mmpa/mmpa_api.h" - -namespace ge { -namespace { -constexpr int32_t kHybridSubgraphRecursion = 32; -} // namespace - -ComputeGraphPtr ExecuteGraphAdapter::ConvertExecuteGraphToComputeGraph(ExecuteGraph *src_graph) { - GE_ASSERT_NOTNULL(src_graph); - const auto dst_graph = ComGraphMakeShared(src_graph->GetName()); - GE_ASSERT_NOTNULL(dst_graph); - const int32_t depth = 0; - GE_ASSERT_GRAPH_SUCCESS(ConvertExecuteGraphToComputeGraph(src_graph, dst_graph, depth), - "Convert execute graph:%s to compute graph failed.", src_graph->GetName().c_str()); - return dst_graph; -} - -graphStatus ExecuteGraphAdapter::ConvertExecuteGraphToComputeGraph(ExecuteGraph *src_graph, - const ComputeGraphPtr &dst_graph, - const int32_t depth) { - GE_ASSERT_TRUE(depth <= kHybridSubgraphRecursion, "param depth:%d larger than %d(allow max subgraphs).", depth, - kHybridSubgraphRecursion); - std::unordered_map all_new_nodes; - GE_ASSERT_GRAPH_SUCCESS(CopyOpAndSubgraph(src_graph, dst_graph, all_new_nodes, depth), - "Copy op and subgraph from %s to %s failed.", src_graph->GetName().c_str(), - dst_graph->GetName().c_str()); - - for (const auto &n : src_graph->graph_shared_->nodes_) { - GE_ASSERT_NOTNULL(n); - GE_ASSERT_GRAPH_SUCCESS(RelinkGraphEdges(&FastGraphUtils::GetNode(n), all_new_nodes), - "Relink edge for node %s failed.", FastGraphUtils::GetNode(n).GetNamePtr()); - } - - std::vector new_subgraphs; - const auto &old_subgraphs = src_graph->GetAllSubgraphs(); - for (const auto &sub_graph : old_subgraphs) { - const auto new_subgraph = dst_graph->GetSubgraph(sub_graph->GetName()); - GE_CHECK_NOTNULL(new_subgraph); - new_subgraphs.emplace_back(new_subgraph); - } - dst_graph->SetAllSubgraphs(new_subgraphs); - - GE_ASSERT_GRAPH_SUCCESS(CopyMembers(src_graph, dst_graph, all_new_nodes)); - - // inherit all attr from old graph to new graph - InheritOriginalAttr(src_graph, dst_graph); - return GRAPH_SUCCESS; -} - -graphStatus ExecuteGraphAdapter::CopyOpAndSubgraph(ExecuteGraph *src_graph, const ComputeGraphPtr &dst_graph, - std::unordered_map &all_new_nodes, - const int32_t depth) { - const auto src_root_graph = ExecuteGraphUtils::FindRootGraph(src_graph); - GE_ASSERT_NOTNULL(src_root_graph); - const auto dst_root_graph = GraphUtils::FindRootGraph(dst_graph); - GE_ASSERT_NOTNULL(dst_root_graph); - for (const auto &src_node : src_graph->graph_shared_->nodes_) { - GE_ASSERT_NOTNULL(src_node); - // 复用原图的OpDesc对象,原图不能释放 - const auto &op_desc = FastGraphUtils::GetNode(src_node).GetOpDescPtr(); - GE_ASSERT_NOTNULL(op_desc); - const auto &dst_node = dst_graph->AddNode(op_desc, op_desc->GetId()); - GE_ASSERT_NOTNULL(dst_node, "Add node:%s for dst graph failed.", op_desc->GetName().c_str()); - all_new_nodes[&FastGraphUtils::GetNode(src_node)] = dst_node.get(); - - const auto &subgraph_names = op_desc->GetSubgraphInstanceNames(); - const auto subgraph_num = subgraph_names.size(); - for (size_t subgrah_idx = 0U; subgrah_idx < subgraph_num; ++subgrah_idx) { - const auto &subgraph_name = subgraph_names[subgraph_num - 1U - subgrah_idx]; - const auto src_subgraph = src_root_graph->GetSubGraph(subgraph_name); - if ((src_subgraph == nullptr) && subgraph_name.empty()) { - continue; - } - GE_ASSERT_NOTNULL(src_subgraph); - const auto dst_subgraph = ComGraphMakeShared(src_subgraph->GetName()); - GE_ASSERT_NOTNULL(dst_subgraph); - dst_subgraph->SetParentGraph(dst_root_graph); - GE_ASSERT_GRAPH_SUCCESS(ConvertExecuteGraphToComputeGraph(src_subgraph, dst_subgraph, depth + 1), - "Copy subgraph from %s to %s failed.", src_subgraph->GetName().c_str(), - dst_subgraph->GetName().c_str()); - (void) dst_root_graph->AddSubGraph(dst_subgraph); - dst_subgraph->SetParentNode(dst_node); - } - } - return GRAPH_SUCCESS; -} - -graphStatus ExecuteGraphAdapter::RelinkGraphEdges(FastNode *old_node, - const std::unordered_map &all_new_nodes) { - const auto &iter = all_new_nodes.find(old_node); - GE_ASSERT_TRUE(iter != all_new_nodes.end(), "all_new_nodes not contain %s", old_node->GetNamePtr()); - const auto &new_node = iter->second; - GE_ASSERT_NOTNULL(new_node); - const auto &old_out_edges = old_node->GetAllOutDataEdgesRef(); - for (size_t out_i = 0; out_i < old_out_edges.size(); ++out_i) { - for (const auto old_edge : old_out_edges[out_i]) { - if (old_edge == nullptr) { - continue; - } - const auto old_dst_node = old_edge->dst; - GE_ASSERT_NOTNULL(old_dst_node); - const auto dst_index = old_edge->dst_input; - - const auto &dst_iter = all_new_nodes.find(old_dst_node); - if (dst_iter != all_new_nodes.end()) { - const auto &new_dst_node = dst_iter->second; - GE_ASSERT_NOTNULL(new_dst_node); - GE_ASSERT_GRAPH_SUCCESS( - GraphUtils::AddEdge(new_node->GetOutDataAnchor(out_i), new_dst_node->GetInDataAnchor(dst_index)), - "Add edge %s:%d -> %s:%d failed.", new_node->GetName().c_str(), out_i, new_dst_node->GetName().c_str(), - dst_index); - } - } - } - - for (const auto old_control_out_edge : old_node->GetAllOutControlEdgesRef()) { - if (old_control_out_edge == nullptr) { - continue; - } - const auto old_dst_node = old_control_out_edge->dst; - GE_ASSERT_NOTNULL(old_dst_node); - - auto dst_iter = all_new_nodes.find(old_dst_node); - if (dst_iter != all_new_nodes.end()) { - const auto &new_dst_node = dst_iter->second; - GE_ASSERT_NOTNULL(new_dst_node); - GE_ASSERT_GRAPH_SUCCESS(GraphUtils::AddEdge(new_node->GetOutControlAnchor(), new_dst_node->GetInControlAnchor()), - "Add control edge %s -> %s failed.", new_node->GetName().c_str(), - new_dst_node->GetName().c_str()); - } - } - return GRAPH_SUCCESS; -} - -graphStatus ExecuteGraphAdapter::CopyMembers(ExecuteGraph *src_graph, const ComputeGraphPtr &dst_graph, - const std::unordered_map &all_new_nodes) { - GE_ASSERT_NOTNULL(src_graph); - GE_ASSERT_NOTNULL(dst_graph); - GE_ASSERT_NOTNULL(src_graph->graph_shared_); - GE_ASSERT_NOTNULL(dst_graph->impl_); - - // copy info of output nodes from old graph to new graph. - const auto &out_nodes_info = src_graph->graph_shared_->GetAllOutNodeInfo(); - std::vector> new_out_nodes_info; - // ExecuteGraph未开放OutNodeInfo的接口,只是预埋了实现,以下流程暂时业务流程走不到 - for (const auto &info : out_nodes_info) { - GE_ASSERT_NOTNULL(info.first); - const auto it = all_new_nodes.find(info.first); - if (it != all_new_nodes.end()) { - new_out_nodes_info.emplace_back(it->second, info.second); - } - } - dst_graph->SetGraphOutNodesInfo(new_out_nodes_info); - - // copy info of input nodes from old graph to new graph. - const auto &input_nodes = src_graph->graph_shared_->GetAllInputNodeInfo(); - for (const auto &node : input_nodes) { - GE_ASSERT_NOTNULL(node); - const auto &it = all_new_nodes.find(node); - if (it != all_new_nodes.end()) { - (void) dst_graph->AddInputNode(it->second->shared_from_this()); - } - } - - // ExecuteGraph没有target信息 & other members - // graph属性序列化 - dst_graph->impl_->attrs_ = src_graph->attrs_; - return GRAPH_SUCCESS; -} - -void ExecuteGraphAdapter::InheritOriginalAttr(ExecuteGraph *src_graph, const ComputeGraphPtr &dst_graph) { - const auto &original_attrs = AttrUtils::GetAllAttrs(src_graph); - for (const auto &attr_iter : original_attrs) { - if (dst_graph->TrySetAttr(attr_iter.first, attr_iter.second) != GRAPH_SUCCESS) { - GELOGW("Set inherit original attr[%s] failed, Please Check.", attr_iter.first.c_str()); - } - } - // copy ExtAttr to dst_graph - dst_graph->CopyFrom(*src_graph); -} -} // namespace ge diff --git a/graph/utils/execute_graph_utils.cc b/graph/utils/execute_graph_utils.cc deleted file mode 100644 index bcc5c672311919317a666b47bb5074962c4085a1..0000000000000000000000000000000000000000 --- a/graph/utils/execute_graph_utils.cc +++ /dev/null @@ -1,693 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/utils/execute_graph_utils.h" - -#include "common/checker.h" -#include "graph/fast_graph/fast_graph_impl.h" -#include "graph/fast_graph/fast_graph_utils.h" -#include "graph/utils/fast_node_utils.h" -#include "graph/utils/op_type_utils.h" - -namespace ge { -namespace { -using InFastNodesToOut = std::map, FastNodeCompareKey>; -InFastNodesToOut GetFullConnectIONodes(const FastNode *fast_node) { - GE_ASSERT_NOTNULL(fast_node); - InFastNodesToOut in_nodes_to_out; - const auto &in_nodes = fast_node->GetAllInNodes(); - const auto &out_nodes = fast_node->GetAllOutNodes(); - for (const auto &in_node : in_nodes) { - (void) in_nodes_to_out.emplace(in_node, out_nodes); - } - return in_nodes_to_out; -} - -graphStatus ReplaceOutDataEdges(FastNode *new_node, const FastNode *old_node, const std::vector &outputs_map, - ExecuteGraph *graph) { - const auto &new_outs = new_node->GetAllOutDataEdgesRef(); - const auto new_out_size = new_outs.size(); - GE_ASSERT_TRUE(new_out_size >= outputs_map.size(), - "Failed to replace out data edge, the actual size %zu is less than the mapping size %zu", new_out_size, - outputs_map.size()); - const auto &old_outs = old_node->GetAllOutDataEdgesRef(); - for (size_t i = 0U; i < new_out_size; ++i) { - if (i >= outputs_map.size()) { - return GRAPH_SUCCESS; - } - const auto old_index = outputs_map[i]; - if ((old_index < 0) || (static_cast(old_index) >= old_outs.size())) { - continue; - } - - for (const auto old_edge : old_outs[old_index]) { - if (old_edge == nullptr) { - continue; - } - const auto dst_node = old_edge->dst; - GE_ASSERT_NOTNULL(dst_node); - const auto dst_input = old_edge->dst_input; - GE_ASSERT_GRAPH_SUCCESS(graph->RemoveEdge(old_edge), "Remove edge %s:%d->%s:%d failed", old_node->GetNamePtr(), - old_index, dst_node->GetNamePtr(), dst_input); - GE_ASSERT_NOTNULL(graph->AddEdge(new_node, i, dst_node, dst_input), "Add edge %s:%d->%s:%d failed", - new_node->GetNamePtr(), i, dst_node->GetNamePtr(), dst_input); - } - } - return GRAPH_SUCCESS; -} - -graphStatus ReplaceInDataEdges(FastNode *new_node, const FastNode *old_node, const std::vector &inputs_map, - ExecuteGraph *graph) { - const auto &new_ins = new_node->GetAllInDataEdgesRef(); - const auto new_in_size = new_ins.size(); - GE_ASSERT_TRUE(new_in_size >= inputs_map.size(), - "Failed to replace in data edge, the actual size %zu is less than the mapping size %zu", new_in_size, - inputs_map.size()); - const auto &old_ins = old_node->GetAllInDataEdgesRef(); - for (size_t i = 0U; i < new_in_size; ++i) { - if (i >= inputs_map.size()) { - return GRAPH_SUCCESS; - } - const auto old_index = inputs_map[i]; - if ((old_index < 0) || (static_cast(old_index) >= old_ins.size())) { - continue; - } - - const auto old_edge = old_ins[old_index]; - if (old_edge == nullptr) { - continue; - } - const auto src_node = old_edge->src; - GE_ASSERT_NOTNULL(src_node); - const auto src_output = old_edge->src_output; - GE_ASSERT_GRAPH_SUCCESS(graph->RemoveEdge(old_edge), "Remove edge %s:%d->%s:%d failed", src_node->GetNamePtr(), - src_output, old_node->GetNamePtr(), old_index); - GE_ASSERT_NOTNULL(graph->AddEdge(src_node, src_output, new_node, i), "Add edge %s:%d->%s:%d failed", - src_node->GetNamePtr(), src_output, new_node->GetNamePtr(), i); - } - return GRAPH_SUCCESS; -} - -graphStatus ReplaceControlEdges(FastNode *new_node, const FastNode *old_node, ExecuteGraph *graph) { - const auto &new_control_in_edges = new_node->GetAllInControlEdges(); - const auto exist_control_in_edges = - std::unordered_set *>(new_control_in_edges.begin(), new_control_in_edges.end()); - for (const auto old_control_in_edge : old_node->GetAllInControlEdgesRef()) { - if ((old_control_in_edge == nullptr) || (exist_control_in_edges.count(old_control_in_edge) > 0U)) { - continue; - } - const auto src_node = old_control_in_edge->src; - GE_ASSERT_NOTNULL(src_node); - GE_ASSERT_NOTNULL(graph->AddEdge(src_node, kControlEdgeIndex, new_node, kControlEdgeIndex), - "Add control edge %s->%s failed", src_node->GetNamePtr(), new_node->GetNamePtr()); - } - - const auto &new_control_out_edges = new_node->GetAllOutControlEdges(); - const auto exist_control_out_edges = - std::unordered_set *>(new_control_out_edges.begin(), new_control_out_edges.end()); - for (const auto old_control_out_edge : old_node->GetAllOutControlEdgesRef()) { - if ((old_control_out_edge == nullptr) || (exist_control_out_edges.count(old_control_out_edge) > 0U)) { - continue; - } - const auto dst_node = old_control_out_edge->dst; - GE_ASSERT_NOTNULL(dst_node); - GE_ASSERT_NOTNULL(graph->AddEdge(new_node, kControlEdgeIndex, dst_node, kControlEdgeIndex), - "Add control edge %s->%s failed", new_node->GetNamePtr(), dst_node->GetNamePtr()); - } - return GRAPH_SUCCESS; -} - -graphStatus RelinkDataIO(ExecuteGraph *graph, FastNode *node, const std::vector &io_map, - InFastNodesToOut &in_nodes_to_out) { - const size_t out_data_endpoint_size = node->GetDataOutNum(); - const size_t in_data_endpoint_size = node->GetDataInNum(); - GE_ASSERT_TRUE(out_data_endpoint_size >= io_map.size(), - "The io_map specified for node %s type %s is larger %zu than the actual size %zu", - node->GetName().c_str(), node->GetType().c_str(), io_map.size(), out_data_endpoint_size); - const auto &all_in_data_edges = node->GetAllInDataEdgesRef(); - for (size_t i = 0U; i < out_data_endpoint_size; ++i) { - int32_t in_index = (i < io_map.size()) ? io_map[i] : -1; - if (in_index < 0) { - for (const auto old_out_edge : node->GetOutEdgesRefByIndex(i)) { - if (old_out_edge != nullptr) { - GE_ASSERT_GRAPH_SUCCESS(graph->RemoveEdge(old_out_edge), "Remove out data edge for %s:%d failed.", - node->GetNamePtr(), i); - } - } - } else { - GE_ASSERT_TRUE(in_index < static_cast(in_data_endpoint_size), - "Failed to relink for node %s type %s, invalid index %d specified for input(%zu)", - node->GetName().c_str(), node->GetType().c_str(), in_index, in_data_endpoint_size); - const auto old_in_edge = all_in_data_edges[in_index]; - if (old_in_edge == nullptr) { - continue; - } - const auto src_node = old_in_edge->src; - GE_ASSERT_NOTNULL(src_node); - const auto src_output = old_in_edge->src_output; - GE_ASSERT_GRAPH_SUCCESS( - graph->RemoveEdge(old_in_edge), - "Failed relink node %s type %s, failed to unlink the data link from %s(%d) to it at input-index %d", - node->GetName().c_str(), node->GetType().c_str(), src_node->GetName().c_str(), src_output, in_index); - - for (const auto old_out_edge : node->GetOutEdgesRefByIndex(i)) { - if (old_out_edge == nullptr) { - continue; - } - const auto dst_node = old_out_edge->dst; - GE_ASSERT_NOTNULL(dst_node); - const auto dst_input = old_out_edge->dst_input; - GE_ASSERT_GRAPH_SUCCESS(graph->RemoveEdge(old_out_edge), "Remove data edge %s:%d->%s:%d failed.", - node->GetNamePtr(), i, dst_node->GetNamePtr(), dst_input); - GE_ASSERT_NOTNULL(graph->AddEdge(src_node, src_output, dst_node, dst_input), - "Add data edge %s:%d->%s:%d failed.", src_node->GetNamePtr(), src_output, - dst_node->GetNamePtr(), dst_input); - in_nodes_to_out[src_node].emplace_back(dst_node); - } - } - } - - for (const auto in_data_edge : all_in_data_edges) { - if (in_data_edge != nullptr) { - GE_ASSERT_GRAPH_SUCCESS(graph->RemoveEdge(in_data_edge), "Remove in data edge for node:%s failed.", - node->GetNamePtr()); - } - } - return GRAPH_SUCCESS; -} - -graphStatus RelinkControlNodeIfNeed(ExecuteGraph *graph, const InFastNodesToOut &in_nodes_to_out, - InFastNodesToOut &connected_data_in_to_out) { - for (const auto &in_node_to_out : in_nodes_to_out) { - const auto in_node = in_node_to_out.first; - GE_ASSERT_NOTNULL(in_node); - const auto &connected_data_out = connected_data_in_to_out[in_node]; - const auto &out_control_nodes = in_node->GetOutControlNodes(); - const auto out_control_nodes_set = - std::unordered_set(out_control_nodes.begin(), out_control_nodes.end()); - for (const auto out_node : in_node_to_out.second) { - GE_ASSERT_NOTNULL(out_node); - if (std::find(connected_data_out.begin(), connected_data_out.end(), out_node) == connected_data_out.end()) { - if (out_control_nodes_set.count(out_node) > 0) { - continue; - } - GE_ASSERT_NOTNULL(graph->AddEdge(in_node, kControlEdgeIndex, out_node, kControlEdgeIndex), - "Add control edge %s->%s failed.", in_node->GetNamePtr(), out_node->GetNamePtr()); - } - } - } - return GRAPH_SUCCESS; -} -} // namespace - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ExecuteGraph *ExecuteGraphUtils::FindRootGraph(ExecuteGraph *exe_graph) { - ExecuteGraph *result = nullptr; - while (exe_graph != nullptr) { - result = exe_graph; - exe_graph = result->GetParentGraphBarePtr(); - } - return result; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY FastNode *ExecuteGraphUtils::FindNodeFromAllNodes( - ExecuteGraph *exe_graph, const char_t *const name) { - GE_ASSERT_NOTNULL(exe_graph); - GE_ASSERT_NOTNULL(exe_graph->graph_shared_); - GE_ASSERT_NOTNULL(name); - const auto root_graph = FindRootGraph(exe_graph); - GE_ASSERT_NOTNULL(root_graph); - - const auto insert_func = [](const ExecuteGraph *const exe_graph, std::deque &candidates) -> void { - auto iter = exe_graph->graph_shared_->nodes_.end(); - while (iter != exe_graph->graph_shared_->nodes_.begin()) { - --iter; - (void) candidates.insert(candidates.begin(), &FastGraphUtils::GetNode(iter.element_)); - } - }; - std::deque candidates; - insert_func(exe_graph, candidates); - while (!candidates.empty()) { - const auto fast_node = candidates.front(); - candidates.pop_front(); - if (fast_node == nullptr) { - continue; - } - if (strcmp(fast_node->GetNamePtr(), name) == 0) { - return fast_node; - } - const auto op_desc = fast_node->GetOpDescBarePtr(); - if (op_desc != nullptr) { - const auto &subgraph_names = op_desc->GetSubgraphInstanceNames(); - auto name_iter = subgraph_names.rbegin(); - while (name_iter != subgraph_names.rend()) { - const auto subgraph = root_graph->GetSubGraph(*name_iter); - if ((subgraph != nullptr) && (subgraph->graph_shared_ != nullptr)) { - insert_func(subgraph, candidates); - } - ++name_iter; - } - } - } - return nullptr; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector ExecuteGraphUtils::FindNodesByTypeFromAllNodes( - ExecuteGraph *exe_graph, const char_t *const type) { - GE_ASSERT_NOTNULL(exe_graph); - GE_ASSERT_NOTNULL(type); - const auto root_graph = FindRootGraph(exe_graph); - GE_ASSERT_NOTNULL(root_graph); - - std::vector nodes; - for (const auto node : root_graph->GetAllNodes()) { - if ((node != nullptr) && (strcmp(node->GetTypePtr(), type) == 0)) { - nodes.emplace_back(node); - } - } - return nodes; -} - -FastNode *ExecuteGraphUtils::FindFirstNodeMatchType(ExecuteGraph *exe_graph, const char_t *const type) { - GE_ASSERT_NOTNULL(exe_graph); - GE_ASSERT_NOTNULL(exe_graph->graph_shared_); - GE_ASSERT_NOTNULL(type); - for (const auto &node : exe_graph->graph_shared_->nodes_) { - if ((node != nullptr) && (strcmp(FastGraphUtils::GetNode(node).GetTypePtr(), type) == 0)) { - return &FastGraphUtils::GetNode(node); - } - } - return nullptr; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ExecuteGraphUtils::InsertNodeAfter( - const EdgeSrcEndpoint &src, const std::vector &dsts, FastNode *insert_node, - const uint32_t input_index, const uint32_t output_index) { - GE_ASSERT_NOTNULL(insert_node); - const auto src_node = src.node; - GE_ASSERT_NOTNULL(src_node); - const auto src_extend_info = src_node->GetExtendInfo(); - GE_ASSERT_NOTNULL(src_extend_info, "The extend info of src node:% is null", src_node->GetNamePtr()); - const auto graph = src_extend_info->GetOwnerGraphBarePtr(); - GE_ASSERT_NOTNULL(graph, "The own graph of src node:% is null", src_node->GetNamePtr()); - GE_ASSERT_NOTNULL(insert_node->GetExtendInfo(), "The extend info of insert node:% is null", - insert_node->GetNamePtr()); - GE_ASSERT_TRUE(graph == insert_node->GetExtendInfo()->GetOwnerGraphBarePtr(), - "rc:%s and insert_node:%s does not exist in the same graph.", src_node->GetNamePtr(), - insert_node->GetNamePtr()); - - const auto src_index = src.index; - GE_ASSERT_NOTNULL(graph->AddEdge(src_node, src_index, insert_node, input_index), "[Add][Edge] %s:%d->%s:%d failed.", - src_node->GetNamePtr(), src_index, insert_node->GetNamePtr(), input_index); - for (const auto &dst : dsts) { - const auto dst_node = dst.node; - GE_ASSERT_NOTNULL(dst_node); - const auto dst_index = dst.index; - const auto dst_extend_info = dst_node->GetExtendInfo(); - GE_ASSERT_NOTNULL(dst_extend_info, "The extend info of src node:% is null", dst_node->GetNamePtr()); - GE_ASSERT_TRUE(graph == dst_extend_info->GetOwnerGraphBarePtr(), - "[Check][Param] dst:%s and insert_node:%s does not exist in the same graph.", dst_node->GetNamePtr(), - insert_node->GetNamePtr()); - - const auto old_edge = dst_node->GetInDataEdgeByIndex(dst_index); - if (old_edge != nullptr) { - GE_ASSERT_GRAPH_SUCCESS(graph->RemoveEdge(old_edge), "Remove input edge %s:%d failed.", dst_node->GetNamePtr(), - dst_index); - } - GE_ASSERT_NOTNULL(graph->AddEdge(insert_node, output_index, dst_node, dst_index), "Add edge %s:%d->%s:%d failed.", - insert_node->GetNamePtr(), output_index, dst_node->GetNamePtr(), dst_index); - for (const auto &old_ctrl_edge : src_node->GetAllOutControlEdgesRef()) { - if (old_ctrl_edge == nullptr) { - continue; - } - GE_ASSERT_GRAPH_SUCCESS(graph->RemoveEdge(old_ctrl_edge), "Remove out control edge for %s failed.", - src_node->GetNamePtr()); - GE_ASSERT_NOTNULL(graph->AddEdge(insert_node, kControlEdgeIndex, dst_node, kControlEdgeIndex), - "Add control edge %s->%s failed.", insert_node->GetNamePtr(), dst_node->GetNamePtr()); - } - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ExecuteGraphUtils::InsertNodeBefore( - const EdgeDstEndpoint &dst, FastNode *insert_node, const uint32_t input_index, const uint32_t output_index) { - GE_ASSERT_NOTNULL(insert_node); - const auto dst_node = dst.node; - GE_ASSERT_NOTNULL(dst_node); - const auto dst_extend_info = dst_node->GetExtendInfo(); - GE_ASSERT_NOTNULL(dst_extend_info, "The extend info of src node:% is null", dst_node->GetNamePtr()); - const auto graph = dst_extend_info->GetOwnerGraphBarePtr(); - GE_ASSERT_NOTNULL(graph, "The own graph of src node:% is null", dst_node->GetNamePtr()); - GE_ASSERT_NOTNULL(insert_node->GetExtendInfo(), "The extend info of insert node:% is null", - insert_node->GetNamePtr()); - GE_ASSERT_TRUE(graph == insert_node->GetExtendInfo()->GetOwnerGraphBarePtr(), - "[Check][Param] src:%s and insert_node:%s does not exist in the same graph.", dst_node->GetNamePtr(), - insert_node->GetNamePtr()); - - const auto dst_index = dst.index; - const auto old_edge = dst_node->GetInDataEdgeByIndex(dst_index); - GE_ASSERT_NOTNULL(old_edge, "The input edge %s:%d is nullptr.", dst_node->GetNamePtr(), dst_index); - const auto src_node = old_edge->src; - GE_ASSERT_NOTNULL(src_node, "The src of %s:%d is nullptr.", dst_node->GetNamePtr(), dst_index); - const auto src_index = old_edge->src_output; - GE_ASSERT_GRAPH_SUCCESS(graph->RemoveEdge(old_edge), "Remove edge %s:%d->%s:%d failed.", src_node->GetNamePtr(), - src_index, dst_node->GetNamePtr(), dst_index); - GE_ASSERT_NOTNULL(graph->AddEdge(src_node, src_index, insert_node, input_index), "Add edge %s:%d->%s:%d failed.", - src_node->GetNamePtr(), src_index, insert_node->GetNamePtr(), input_index); - GE_ASSERT_NOTNULL(graph->AddEdge(insert_node, output_index, dst_node, dst_index), "Add edge %s:%d->%s:%d failed.", - insert_node->GetNamePtr(), output_index, dst_node->GetNamePtr(), dst_index); - - for (const auto old_ctrl_edge : dst_node->GetAllInControlEdgesRef()) { - if (old_ctrl_edge == nullptr) { - continue; - } - const auto src_ctrl_node = old_ctrl_edge->src; - GE_ASSERT_NOTNULL(src_ctrl_node); - GE_ASSERT_GRAPH_SUCCESS(graph->RemoveEdge(old_ctrl_edge), "Remove ctrl edge %s->%s failed.", - src_ctrl_node->GetNamePtr(), dst_node->GetNamePtr()); - GE_ASSERT_NOTNULL(graph->AddEdge(src_ctrl_node, kControlEdgeIndex, insert_node, kControlEdgeIndex), - "Add ctrl edge %s->%s failed.", src_ctrl_node->GetNamePtr(), insert_node->GetNamePtr()); - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ExecuteGraphUtils::CopyInCtrlEdges(const FastNode *src_node, - FastNode *dst_node) { - GE_ASSERT_NOTNULL(src_node); - GE_ASSERT_NOTNULL(dst_node); - - const auto &src_ctrl_in_nodes = src_node->GetInControlNodes(); - if (src_ctrl_in_nodes.empty()) { - return GRAPH_SUCCESS; - } - std::unordered_set exist_in_ctrl_nodes_set; - const auto &exist_in_ctrl_nodes = dst_node->GetInControlNodes(); - if (!exist_in_ctrl_nodes.empty()) { - exist_in_ctrl_nodes_set.insert(exist_in_ctrl_nodes.begin(), exist_in_ctrl_nodes.end()); - } - - const auto src_extend_info = src_node->GetExtendInfo(); - GE_ASSERT_NOTNULL(src_extend_info, "The extend info of src node:% is null", src_node->GetNamePtr()); - const auto graph = src_extend_info->GetOwnerGraphBarePtr(); - GE_ASSERT_NOTNULL(graph, "The graph of src node:% is null", src_node->GetNamePtr()); - for (const auto in_node : src_ctrl_in_nodes) { - if ((in_node != nullptr) && (exist_in_ctrl_nodes_set.count(in_node) == 0U)) { - GE_ASSERT_NOTNULL(graph->AddEdge(in_node, kControlEdgeIndex, dst_node, kControlEdgeIndex), - "Add ctrl edge %s->%s failed.", in_node->GetNamePtr(), dst_node->GetNamePtr()); - } - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ExecuteGraphUtils::MoveInCtrlEdges(const FastNode *src_node, - FastNode *dst_node) { - GE_ASSERT_NOTNULL(src_node); - GE_ASSERT_NOTNULL(dst_node); - GE_ASSERT_GRAPH_SUCCESS(CopyInCtrlEdges(src_node, dst_node), "Copy in ctrl edges failed, src_node:%s, dst_node:%s", - src_node->GetNamePtr(), dst_node->GetNamePtr()); - - const auto src_extend_info = src_node->GetExtendInfo(); - GE_ASSERT_NOTNULL(src_extend_info, "The extend info of src node:% is null", src_node->GetNamePtr()); - const auto graph = src_extend_info->GetOwnerGraphBarePtr(); - GE_ASSERT_NOTNULL(graph, "The graph of src node:% is null", src_node->GetNamePtr()); - for (const auto src_in_ctrl_edge : src_node->GetAllInControlEdgesRef()) { - if (src_in_ctrl_edge != nullptr) { - GE_ASSERT_GRAPH_SUCCESS(graph->RemoveEdge(src_in_ctrl_edge), "Remove in ctrl edge for %s failed.", - src_node->GetNamePtr()); - } - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ExecuteGraphUtils::CopyOutCtrlEdges(const FastNode *src_node, - FastNode *dst_node) { - GE_ASSERT_NOTNULL(src_node); - GE_ASSERT_NOTNULL(dst_node); - - const auto &out_ctrl_nodes = src_node->GetOutControlNodes(); - if (out_ctrl_nodes.empty()) { - return GRAPH_SUCCESS; - } - - std::unordered_set exist_out_ctrl_nodes_set; - const auto &exist_out_ctrl_nodes = dst_node->GetOutControlNodes(); - if (!exist_out_ctrl_nodes.empty()) { - exist_out_ctrl_nodes_set.insert(exist_out_ctrl_nodes.begin(), exist_out_ctrl_nodes.end()); - } - - const auto src_extend_info = src_node->GetExtendInfo(); - GE_ASSERT_NOTNULL(src_extend_info, "The extend info of src node:% is null", src_node->GetNamePtr()); - const auto graph = src_extend_info->GetOwnerGraphBarePtr(); - GE_ASSERT_NOTNULL(graph, "The graph of src node:% is null", src_node->GetNamePtr()); - for (const auto out_node : out_ctrl_nodes) { - if ((out_node != nullptr) && (exist_out_ctrl_nodes_set.count(out_node) == 0U)) { - GE_ASSERT_NOTNULL(graph->AddEdge(dst_node, kControlEdgeIndex, out_node, kControlEdgeIndex), - "Add ctrl edge %s->%s failed.", dst_node->GetNamePtr(), out_node->GetNamePtr()); - } - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ExecuteGraphUtils::MoveOutCtrlEdges(const FastNode *src_node, - FastNode *dst_node) { - GE_ASSERT_NOTNULL(src_node); - GE_ASSERT_NOTNULL(dst_node); - GE_ASSERT_GRAPH_SUCCESS(CopyOutCtrlEdges(src_node, dst_node), "Copy out ctrl edges failed, src_node:%s, dst_node:%s", - src_node->GetNamePtr(), dst_node->GetNamePtr()); - - const auto src_extend_info = src_node->GetExtendInfo(); - GE_ASSERT_NOTNULL(src_extend_info, "The extend info of src node:% is null", src_node->GetNamePtr()); - const auto graph = src_extend_info->GetOwnerGraphBarePtr(); - GE_ASSERT_NOTNULL(graph, "The graph of src node:% is null", src_node->GetNamePtr()); - for (const auto src_out_ctrl_edge : src_node->GetAllOutControlEdgesRef()) { - if (src_out_ctrl_edge != nullptr) { - GE_ASSERT_GRAPH_SUCCESS(graph->RemoveEdge(src_out_ctrl_edge), "Remove out ctrl edge for %s failed.", - src_node->GetNamePtr()); - } - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ExecuteGraphUtils::MoveNodeToGraph(FastNode *node, - ExecuteGraph *dst_graph) { - GE_ASSERT_GRAPH_SUCCESS(IsolateNode(node, {})); - GE_ASSERT_NOTNULL(node->GetExtendInfo(), "EntendInfo of node %s is null.", node->GetNamePtr()); - GE_ASSERT_GRAPH_SUCCESS(RemoveNodeWithoutRelink(node->GetExtendInfo()->GetOwnerGraphBarePtr(), node)); - GE_ASSERT_NOTNULL(dst_graph->AddNode(node)); - GE_ASSERT_GRAPH_SUCCESS(node->GetExtendInfo()->SetOwnerGraph(dst_graph, node)); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ExecuteGraphUtils::ReplaceNodeDataEdges( - FastNode *new_node, FastNode *old_node, const std::initializer_list inputs_map, - const std::initializer_list outputs_map, ExecuteGraph *graph) { - return ReplaceNodeDataEdges(new_node, old_node, std::vector(inputs_map), std::vector(outputs_map), - graph); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -ExecuteGraphUtils::ReplaceNodeDataEdges(FastNode *new_node, FastNode *old_node, const std::vector &inputs_map, - const std::vector &outputs_map, ExecuteGraph *graph) { - GE_ASSERT_NOTNULL(new_node); - GE_ASSERT_NOTNULL(old_node); - if (graph == nullptr) { - GE_ASSERT_NOTNULL(new_node->GetExtendInfo()); - graph = new_node->GetExtendInfo()->GetOwnerGraphBarePtr(); - GE_ASSERT_NOTNULL(graph); - GE_ASSERT_NOTNULL(old_node->GetExtendInfo()); - GE_ASSERT_TRUE(graph == old_node->GetExtendInfo()->GetOwnerGraphBarePtr()); - } - - GE_ASSERT_GRAPH_SUCCESS(ReplaceOutDataEdges(new_node, old_node, outputs_map, graph), - "Failed when replace node outputs from old node %s type %s to new node %s type %s", - old_node->GetNamePtr(), old_node->GetTypePtr(), new_node->GetNamePtr(), - new_node->GetTypePtr()); - GE_ASSERT_GRAPH_SUCCESS(ReplaceInDataEdges(new_node, old_node, inputs_map, graph), - "Failed when replace node inputs from old node %s type %s to new node %s type %s", - old_node->GetNamePtr(), old_node->GetTypePtr(), new_node->GetNamePtr(), - new_node->GetTypePtr()); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ExecuteGraphUtils::ReplaceNodeEdges( - FastNode *new_node, FastNode *old_node, const std::initializer_list inputs_map, - const std::initializer_list outputs_map) { - return ReplaceNodeEdges(new_node, old_node, std::vector(inputs_map), std::vector(outputs_map)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -ExecuteGraphUtils::ReplaceNodeEdges(FastNode *new_node, FastNode *old_node, const std::vector &inputs_map, - const std::vector &outputs_map) { - GE_ASSERT_NOTNULL(new_node); - GE_ASSERT_NOTNULL(old_node); - GE_ASSERT_NOTNULL(new_node->GetExtendInfo()); - const auto graph = new_node->GetExtendInfo()->GetOwnerGraphBarePtr(); - GE_ASSERT_NOTNULL(graph); - GE_ASSERT_NOTNULL(old_node->GetExtendInfo()); - GE_ASSERT_TRUE(graph == old_node->GetExtendInfo()->GetOwnerGraphBarePtr()); - GE_ASSERT_GRAPH_SUCCESS(ReplaceNodeDataEdges(new_node, old_node, inputs_map, outputs_map, graph), - "Replace data edgs from %s to %s failed.", old_node->GetNamePtr(), new_node->GetNamePtr()); - GE_ASSERT_GRAPH_SUCCESS(ReplaceControlEdges(new_node, old_node, graph), "Replace control edgs from %s to %s failed.", - old_node->GetNamePtr(), new_node->GetNamePtr()); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -ExecuteGraphUtils::IsolateNode(FastNode *node, const std::initializer_list &io_map) { - return IsolateNode(node, std::vector(io_map)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -ExecuteGraphUtils::IsolateNode(FastNode *node, const std::vector &io_map) { - GE_ASSERT_NOTNULL(node); - const auto &in_nodes_to_out = GetFullConnectIONodes(node); - GE_ASSERT_NOTNULL(node->GetExtendInfo()); - const auto graph = node->GetExtendInfo()->GetOwnerGraphBarePtr(); - GE_ASSERT_NOTNULL(graph); - InFastNodesToOut data_in_to_out; - GE_ASSERT_GRAPH_SUCCESS(RelinkDataIO(graph, node, io_map, data_in_to_out), "Relink data io failed for node %s", - node->GetNamePtr()); - GE_ASSERT_GRAPH_SUCCESS(RelinkControlNodeIfNeed(graph, in_nodes_to_out, data_in_to_out), - "Relink control io failed for node %s", node->GetNamePtr()); - FastNodeUtils::UnlinkAll(node); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -ExecuteGraphUtils::ReplaceEdgeSrc(FastEdge *old_edge, const EdgeSrcEndpoint &new_src) { - GE_ASSERT_NOTNULL(old_edge); - const auto dst_node = old_edge->dst; - GE_ASSERT_NOTNULL(dst_node); - GE_ASSERT_NOTNULL(dst_node->GetExtendInfo()); - const auto graph = dst_node->GetExtendInfo()->GetOwnerGraphBarePtr(); - GE_ASSERT_NOTNULL(graph, "Failed to replace edge source, node %s has null root graph", dst_node->GetNamePtr()); - if ((graph->RemoveEdge(old_edge) == GRAPH_SUCCESS) && - (graph->AddEdge(new_src.node, new_src.index, dst_node, old_edge->dst_input) != nullptr)) { - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "Replace edge failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -ExecuteGraphUtils::RemoveSubgraphRecursively(ExecuteGraph *execute_graph, FastNode *remove_node) { - GE_ASSERT_NOTNULL(execute_graph); - GE_ASSERT_NOTNULL(remove_node); - GE_ASSERT_NOTNULL(remove_node->GetOpDescBarePtr()); - if (remove_node->GetOpDescBarePtr()->GetSubgraphInstanceNames().empty()) { - GELOGD("Node %s has no subgraph.", remove_node->GetName().c_str()); - return GRAPH_SUCCESS; - } - - const auto remove_extend_info = remove_node->GetExtendInfo(); - GE_ASSERT_NOTNULL(remove_extend_info); - if (remove_extend_info->GetOwnerGraphBarePtr() == nullptr) { - GELOGW("Node %s has a owner graph with null value.", remove_node->GetNamePtr()); - return GRAPH_SUCCESS; - } - - if ((remove_extend_info->GetOwnerGraphBarePtr() != execute_graph) && - !execute_graph->CheckNodeIsInGraph(remove_node)) { - GELOGW("Can not find node %s in graph %s.", remove_node->GetName().c_str(), execute_graph->GetName().c_str()); - return GRAPH_FAILED; - } - // find all subgraphs connecting to remove_node - const auto root_graph = FindRootGraph(execute_graph); - std::vector subgraphs_to_remove; - std::deque nodes_to_visit; - nodes_to_visit.push_back(remove_node); - const auto insert_func = [](const ExecuteGraph *const exe_graph, std::deque &candidates) -> void { - auto iter = exe_graph->graph_shared_->nodes_.end(); - while (iter != exe_graph->graph_shared_->nodes_.begin()) { - --iter; - (void) candidates.insert(candidates.begin(), &FastGraphUtils::GetNode(iter.element_)); - } - }; - while (!nodes_to_visit.empty()) { - const auto curr_node = nodes_to_visit.front(); - nodes_to_visit.pop_front(); - const OpDesc *op_desc = curr_node->GetOpDescBarePtr(); - if (op_desc != nullptr) { - const auto &subgraph_names = op_desc->GetSubgraphInstanceNames(); - for (const auto &name : subgraph_names) { - const auto subgraph = root_graph->GetSubGraph(name); - if ((subgraph != nullptr) && (subgraph->graph_shared_ != nullptr)) { - subgraphs_to_remove.emplace_back(subgraph); - insert_func(subgraph, nodes_to_visit); - } - } - } - } - - // remove all subgraphs - for (const auto &remove_graph : subgraphs_to_remove) { - GE_ASSERT_GRAPH_SUCCESS(root_graph->RemoveSubGraph(remove_graph), - "[Remove][SubGraph] failed, sub graph name is %s, execute graph is %s.", - remove_node->GetNamePtr(), execute_graph->GetName().c_str()); - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -ExecuteGraphUtils::RemoveNodeWithoutRelink(ExecuteGraph *execute_graph, FastNode *node) { - GE_ASSERT_NOTNULL(execute_graph); - GE_ASSERT_NOTNULL(node, "param node is nullptr, check invalid."); - // If the node save as input node, delete it - (void) execute_graph->RemoveInputNode(node); - - // If the node save as output node, delete it - (void) execute_graph->RemoveOutputNode(node); - - // If the node has sub-graphs, delete them - GE_ASSERT_GRAPH_SUCCESS(RemoveSubgraphRecursively(execute_graph, node), "Remove subgraph of node %s failed.", - node->GetNamePtr()); - if (execute_graph->CheckNodeIsInGraph(node)) { - GE_ASSERT_GRAPH_SUCCESS(execute_graph->RemoveJustNode(node), "Remove node %s failed.", node->GetNamePtr()); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::unordered_map -ExecuteGraphUtils::GetNodeMapFromAllNodes(ExecuteGraph *exe_graph) { - GE_ASSERT_NOTNULL(exe_graph); - GE_ASSERT_NOTNULL(exe_graph->graph_shared_); - const auto root_graph = FindRootGraph(exe_graph); - GE_ASSERT_NOTNULL(root_graph); - - const auto insert_func = [](const ExecuteGraph *const exe_graph, std::deque &candidates) -> void { - auto iter = exe_graph->graph_shared_->nodes_.end(); - while (iter != exe_graph->graph_shared_->nodes_.begin()) { - --iter; - (void) candidates.insert(candidates.begin(), &FastGraphUtils::GetNode(iter.element_)); - } - }; - std::deque candidates; - insert_func(exe_graph, candidates); - std::unordered_map node_name_to_nodes; - while (!candidates.empty()) { - const auto fast_node = candidates.front(); - candidates.pop_front(); - if ((fast_node == nullptr) || (fast_node->GetOpDescBarePtr() == nullptr)) { - continue; - } - node_name_to_nodes.emplace(fast_node->GetName(), fast_node); - const auto &subgraph_names = fast_node->GetOpDescBarePtr()->GetSubgraphInstanceNames(); - auto name_iter = subgraph_names.rbegin(); - while (name_iter != subgraph_names.rend()) { - const auto subgraph = root_graph->GetSubGraph(*name_iter); - if ((subgraph != nullptr) && (subgraph->graph_shared_ != nullptr)) { - insert_func(subgraph, candidates); - } - ++name_iter; - } - } - return node_name_to_nodes; -} -} // namespace ge diff --git a/graph/utils/fast_node_utils.cc b/graph/utils/fast_node_utils.cc deleted file mode 100644 index 82af50de3e851e554d49b7161b9c39e47fea97d5..0000000000000000000000000000000000000000 --- a/graph/utils/fast_node_utils.cc +++ /dev/null @@ -1,233 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "inc/graph/utils/fast_node_utils.h" - -#include "inc/common/checker.h" -#include "inc/graph/compiler_options.h" -#include "inc/graph/ge_tensor.h" -#include "inc/graph/utils/execute_graph_utils.h" -#include "inc/graph/debug/ge_attr_define.h" -#include "graph/fast_graph/fast_graph_utils.h" -#include "graph/normal_graph/node_impl.h" -#include "graph/normal_graph/op_desc_impl.h" -#include "graph/debug/ge_op_types.h" -#include "graph/debug/ge_log.h" - -namespace ge { -FastNode *FastNodeUtils::GetInDataNodeByIndex(const FastNode *const node, const int32_t index) { - GE_ASSERT_NOTNULL(node); - const auto in_data_edge = node->GetInDataEdgeByIndex(index); - GE_ASSERT_NOTNULL(in_data_edge); - return in_data_edge->src; -} - -FastNode *FastNodeUtils::GetParentInput(const FastNode *const node) { - GE_ASSERT_NOTNULL(node); - uint32_t parent_index = 0U; - if (!AttrUtils::GetInt(node->GetOpDescPtr(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { - return nullptr; - } - - // Subgraph Data Node, check for constant input. - GE_ASSERT_NOTNULL(node->GetExtendInfo(), "EntendInfo of node %s is null.", node->GetNamePtr()); - const auto graph = node->GetExtendInfo()->GetOwnerGraphBarePtr(); - GE_ASSERT_NOTNULL(graph); - - const auto parent_node = graph->GetParentNodeBarePtr(); - if (parent_node == nullptr) { - GELOGW("Node {%s %s} has attr %s but has no parent node.", node->GetNamePtr(), node->GetTypePtr(), - ATTR_NAME_PARENT_NODE_INDEX.c_str()); - return nullptr; - } - - const auto edge = parent_node->GetInDataEdgeByIndex(static_cast(parent_index)); - GE_ASSERT_NOTNULL(edge); - const auto src_node = edge->src; - GE_ASSERT_NOTNULL(src_node); - - if (src_node->GetType() == DATA) { - GE_ASSERT_NOTNULL(src_node->GetOpDescBarePtr()); - if (src_node->GetOpDescBarePtr()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX)) { - return GetParentInput(src_node); - } - } - return src_node; -} - -bool FastNodeUtils::GetConstOpType(const FastNode *const node) { - if (node == nullptr) { - return false; - } - - const auto &node_type = node->GetType(); - if ((node_type == CONSTANT) || (node_type == CONSTANTOP) || (node_type == FILECONSTANT)) { - return true; - } - - if (node_type != DATA) { - return false; // not subgraph input node - } - - const auto parent_node = GetParentInput(node); - return GetConstOpType(parent_node); -} - -graphStatus FastNodeUtils::AppendSubgraphToNode(FastNode *const node, const std::string &subgraph_name, - const ExecuteGraphPtr &subgraph) { - GE_ASSERT_NOTNULL(node); - GE_ASSERT_NOTNULL(subgraph); - auto op_desc = node->GetOpDescBarePtr(); - GE_ASSERT_NOTNULL(op_desc); - - GE_ASSERT_GRAPH_SUCCESS(op_desc->AddSubgraphName(subgraph_name)); - const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes(); - const auto &iter = subgraph_names_to_index.find(subgraph_name); - GE_ASSERT_TRUE(iter != subgraph_names_to_index.cend()); - - return MountSubgraphToNode(node, iter->second, subgraph); -} - -ExecuteGraph *FastNodeUtils::GetSubgraphFromNode(const FastNode *const node, const uint32_t index) { - GE_ASSERT_NOTNULL(node); - const auto op_desc = node->GetOpDescBarePtr(); - GE_ASSERT_NOTNULL(op_desc); - - GE_ASSERT_NOTNULL(node->GetExtendInfo(), "EntendInfo of node %s is null.", node->GetNamePtr()); - const auto root_graph = ExecuteGraphUtils::FindRootGraph(node->GetExtendInfo()->GetOwnerGraphBarePtr()); - GE_ASSERT_NOTNULL(root_graph); - return root_graph->GetSubGraph(op_desc->GetSubgraphInstanceName(index)); -} - -graphStatus FastNodeUtils::MountSubgraphToNode(FastNode *const node, const uint32_t index, const ExecuteGraphPtr &subgraph) { - GE_ASSERT_NOTNULL(node); - GE_ASSERT_NOTNULL(subgraph, "[Check][Param] Failed to set subgraph to node %s index %u, null subgraph", - node->GetNamePtr(), index); - const auto op_desc = node->GetOpDescBarePtr(); - GE_ASSERT_NOTNULL(op_desc); - - GE_ASSERT_NOTNULL(node->GetExtendInfo(), "EntendInfo of node %s is null.", node->GetNamePtr()); - const auto root_graph = ExecuteGraphUtils::FindRootGraph(node->GetExtendInfo()->GetOwnerGraphBarePtr()); - GE_ASSERT_NOTNULL(root_graph, "[Get][Graph] Failed to add subgraph to node %s, null root graph", node->GetNamePtr()); - - const auto ret = op_desc->SetSubgraphInstanceName(index, subgraph->GetName()); - GE_CHK_GRAPH_STATUS_RET(ret, "[Set][Name] Failed to set subgraph to node %s index %u", node->GetNamePtr(), index); - - subgraph->SetParentNode(node); - GE_ASSERT_NOTNULL(node->GetExtendInfo(), "EntendInfo of node %s is null.", node->GetNamePtr()); - subgraph->SetParentGraph(node->GetExtendInfo()->GetOwnerGraphBarePtr()); - - return (root_graph->AddSubGraph(const_cast(subgraph)) != nullptr) ? GRAPH_SUCCESS : GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus FastNodeUtils::AppendInputEdgeInfo(FastNode *const node, - const uint32_t num) { - GE_ASSERT_NOTNULL(node); - - const GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT); - const auto op_desc = node->GetOpDescBarePtr(); - GE_ASSERT_NOTNULL(op_desc); - for (size_t i = op_desc->GetAllInputsSize(); i < num; ++i) { - GE_CHK_GRAPH_STATUS_RET(op_desc->AddInputDesc(data_desc), "[Add][InputDesc] failed, op:%s", op_desc->GetNamePtr()); - } - node->UpdateDataInNum(num); - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus FastNodeUtils::AppendOutputEdgeInfo(FastNode *const node, - const uint32_t num) { - GE_ASSERT_NOTNULL(node); - - const GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT); - const auto op_desc = node->GetOpDescBarePtr(); - GE_ASSERT_NOTNULL(op_desc); - for (size_t i = op_desc->GetOutputsSize(); i < num; ++i) { - GE_CHK_GRAPH_STATUS_RET(op_desc->AddOutputDesc(data_desc), "[Add][OutputDesc] failed, op:%s", - op_desc->GetNamePtr()); - } - node->UpdateDataOutNum(num); - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool FastNodeUtils::ClearInputDesc(const OpDesc *const op_desc, - const uint32_t index) { - GE_ASSERT_NOTNULL(op_desc); - GE_ASSERT_NOTNULL(op_desc->impl_); - GE_ASSERT_TRUE((index < op_desc->impl_->inputs_desc_.size()), - "[Check][Param] index %u is invalid, out of range(0, %zu).", index, - op_desc->impl_->inputs_desc_.size()); - - const auto iter = op_desc->impl_->inputs_desc_.begin() + static_cast(index); - (void)op_desc->impl_->inputs_desc_.erase(iter); - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus FastNodeUtils::RemoveInputEdgeInfo(FastNode *const node, - const uint32_t num) { - GE_ASSERT_NOTNULL(node); - - const auto &op_desc = node->GetOpDescBarePtr(); - GE_ASSERT_NOTNULL(op_desc); - uint32_t input_size = op_desc->GetAllInputsSize(); - while (input_size > num) { - if (!FastNodeUtils::ClearInputDesc(op_desc, input_size - 1)) { - return GRAPH_FAILED; - } - --input_size; - } - - const auto &input_names = op_desc->GetAllInputName(); - (void)op_desc->UpdateInputName(input_names); - auto is_input_consts = op_desc->GetIsInputConst(); - is_input_consts.resize(static_cast(num)); - op_desc->SetIsInputConst(is_input_consts); - - node->UpdateDataInNum(num); - - return GRAPH_SUCCESS; -} - -void FastNodeUtils::UnlinkAll(FastNode *const node) { - if (node == nullptr || node->GetExtendInfo() == nullptr) { - GELOGW("param node is null or node's extend info is null."); - return; - } - const auto owner_graph = node->GetExtendInfo()->GetOwnerGraphBarePtr(); - const auto remove_edge_func = [&owner_graph](FastEdge *e) { - if (e->src != nullptr) { - e->src->EraseEdge(e, DirectionType::kDirectionOutType); - e->src = nullptr; - } - if (e->dst != nullptr) { - e->dst->EraseEdge(e, DirectionType::kDirectionInType); - e->dst = nullptr; - } - if (FastGraphUtils::GetListElementAddr(e)->owner != nullptr) { - FastGraphUtils::GetListElementAddr(e)->owner->erase(FastGraphUtils::GetListElementAddr(e)); - } - auto ret = owner_graph->RecycleQuickEdge(e); - if ((ret != GRAPH_SUCCESS) && (e != nullptr)) { - delete e; - } - }; - node->RemoveAllEdge(remove_edge_func); -} - -EdgeDstEndpoint FastNodeUtils::GetDstEndpoint(const FastEdge *const edge) { - GE_ASSERT_NOTNULL(edge); - return {edge->dst, edge->dst_input}; -} - -EdgeSrcEndpoint FastNodeUtils::GetSrcEndpoint(const FastEdge *const edge) { - GE_ASSERT_NOTNULL(edge); - return {edge->src, edge->src_output}; -} -} // namespace ge diff --git a/graph/utils/ffts_graph_utils.cc b/graph/utils/ffts_graph_utils.cc deleted file mode 100644 index 3a7753e6c564ce3ad52e685b921223f87f5f5d55..0000000000000000000000000000000000000000 --- a/graph/utils/ffts_graph_utils.cc +++ /dev/null @@ -1,602 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/utils/ffts_graph_utils.h" - -#include - -#include "graph/debug/ge_util.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/node_utils.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/debug/ge_op_types.h" - -namespace { -static uint32_t g_fftsPlusSubgraphNum = 0U; -const uint32_t kMaxiumRecursionDepth = 10U; -} - -namespace ge { -graphStatus FftsGraphUtils::GraphPartition(ComputeGraph &graph, const std::set &unsupported_nodes) { - if (unsupported_nodes.empty()) { - GELOGI("Graph:%s, no node is unsupported, skip clipping", graph.GetName().c_str()); - return SUCCESS; - } - - const auto &ffts_plus_graph = GetFftsPlusGraph(graph); - GE_CHECK_NOTNULL(ffts_plus_graph); - std::unordered_set nodes_need_clip; - std::unordered_set graphs_need_split; - GE_CHK_STATUS_RET(CollectClipNodesAndGraphs(ffts_plus_graph, unsupported_nodes, nodes_need_clip, graphs_need_split), - "[Collect][NeedClip] nodes and subgraphs in graph %s failed", ffts_plus_graph->GetName().c_str()); - if (nodes_need_clip.empty() && graphs_need_split.empty()) { - GELOGI("Graph:%s, no node/subgraph need to be clipped, skip", ffts_plus_graph->GetName().c_str()); - return SUCCESS; - } - const auto &parent_node = ffts_plus_graph->GetParentNode(); - GE_CHECK_NOTNULL(parent_node); - // op_desc of node should not be null - (void)parent_node->GetOpDesc()->DelAttr(ATTR_NAME_FFTS_PLUS_SUB_GRAPH); - - (void)graphs_need_split.emplace(ffts_plus_graph); - for (const auto &subgraph : graphs_need_split) { - if (IsGraphNeedSplit(subgraph, nodes_need_clip)) { - std::vector>> split_nodes; - GE_CHK_STATUS_RET(SplitNodesWithCheck(subgraph, nodes_need_clip, split_nodes), - "[Split][Nodes] failed, graph:%s", subgraph->GetName().c_str()); - GE_CHK_STATUS_RET(SplitSubgraph(subgraph, split_nodes), - "[Split][Subgraph] %s failed", subgraph->GetName().c_str()); - } else { - GE_CHK_STATUS_RET(BuildFftsPlusSubgraphWithAllNodes(subgraph), - "[Build][FftsPlusSubgraph] failed, graph:%s", subgraph->GetName().c_str()); - } - } - - return GRAPH_SUCCESS; -} - -graphStatus FftsGraphUtils::CollectClipNodesAndGraphs(const ComputeGraphPtr &graph, - const std::set &unsupported_nodes, - std::unordered_set &nodes_need_clip, - std::unordered_set &graphs_need_split) { - for (const auto &node : graph->GetAllNodes()) { - if (unsupported_nodes.count(node) == 0U) { - continue; - } - - (void)nodes_need_clip.emplace(node); - ComputeGraphPtr cur_graph = node->GetOwnerComputeGraph(); - while (cur_graph != graph) { - const auto &parent_node = cur_graph->GetParentNode(); - if (parent_node == nullptr) { - break; - } - (void)nodes_need_clip.emplace(parent_node); - std::vector subgraphs; - GE_CHK_STATUS_RET(NodeUtils::GetDirectSubgraphs(parent_node, subgraphs), "[Get][Subgraphs] failed for node %s", - parent_node->GetName().c_str()); - for (const auto &subgraph : subgraphs) { - (void)graphs_need_split.emplace(subgraph); - } - cur_graph = cur_graph->GetParentGraph(); - } - } - - return GRAPH_SUCCESS; -} - -bool FftsGraphUtils::IsGraphNeedSplit(const ComputeGraphPtr &graph, - const std::unordered_set &nodes_need_clip) { - for (const auto &node : graph->GetDirectNode()) { - if (nodes_need_clip.count(node) > 0U) { - return true; - } - } - return false; -} - -graphStatus FftsGraphUtils::SplitNodesWithCheck(const ComputeGraphPtr &graph, - const std::unordered_set &nodes_need_clip, - std::vector>> &split_nodes) { - // collect src nodes - std::set cur_nodes; - std::set next_nodes; - for (const auto &node : graph->GetDirectNode()) { - if (node->GetInAllNodes().empty()) { - if (nodes_need_clip.count(node) == 0U) { - (void)cur_nodes.insert(node); - } else { - (void)next_nodes.insert(node); - } - } - } - // non-calc nodes should remain in ori-graph - std::set calc_nodes; - CollectCalcNodeInSubgraph(graph, calc_nodes); - // split nodes - bool support_flag = false; - std::set visited_nodes; - while (!(cur_nodes.empty() && next_nodes.empty())) { - const auto &is_cur_stage = [support_flag, nodes_need_clip](const NodePtr &node_ptr) -> bool { - return (support_flag == (nodes_need_clip.count(node_ptr) == 0U)); - }; - SplitNodes(calc_nodes, is_cur_stage, visited_nodes, cur_nodes, next_nodes); - std::set cur_split_nodes; - for (const auto &cur_node : cur_nodes) { - if (calc_nodes.count(cur_node) > 0U) { - (void)cur_split_nodes.insert(cur_node); - } - } - if (!cur_split_nodes.empty()) { - split_nodes.emplace_back(support_flag, cur_split_nodes); - } - support_flag = !support_flag; - cur_nodes.clear(); - std::swap(cur_nodes, next_nodes); - } - - return GRAPH_SUCCESS; -} - -void FftsGraphUtils::SplitNodes(const std::set &calc_nodes, - const std::function &is_cur_stage, - std::set &visited_nodes, - std::set &cur_nodes, - std::set &next_nodes) { - visited_nodes.insert(cur_nodes.cbegin(), cur_nodes.cend()); - std::queue nodes; - for (const auto &node : cur_nodes) { - nodes.push(node); - } - while (!nodes.empty()) { - const auto &node = nodes.front(); - nodes.pop(); - if (calc_nodes.find(node) != calc_nodes.end()) { - (void)cur_nodes.insert(node); - } else { - // op_desc of node should not be null - (void)node->GetOpDesc()->DelAttr(ATTR_NAME_THREAD_SCOPE_ID); - } - (void)visited_nodes.insert(node); - for (const auto &out_node : node->GetOutAllNodes()) { - const auto &in_nodes = out_node->GetInAllNodes(); - const bool all_in_node_seen = !std::any_of(in_nodes.begin(), in_nodes.end(), - [visited_nodes](const NodePtr &node_ptr) { - return visited_nodes.count(node_ptr) == 0U; - }); - if (!all_in_node_seen) { - continue; - } - if (is_cur_stage(out_node)) { - (void)nodes.push(out_node); - } else { - (void)next_nodes.insert(out_node); - } - } - } -} - -graphStatus FftsGraphUtils::SplitSubgraph(const ComputeGraphPtr &subgraph, - const std::vector>> &split_nodes) { - for (const auto &item : split_nodes) { - if ((item.first) && (!item.second.empty())) { - const auto &subgraph_name = "FFTS_Plus_Subgraph_" + std::to_string(g_fftsPlusSubgraphNum++); - const auto &new_subgraph = GraphUtils::BuildSubgraphWithNodes(subgraph, item.second, subgraph_name); - if (new_subgraph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Build subgraph %s failed", subgraph_name.c_str()); - GELOGE(GRAPH_FAILED, "[Build][Subgraph] %s failed", subgraph_name.c_str()); - return GRAPH_FAILED; - } - GE_CHK_STATUS_RET(SetAttrForFftsPlusSubgraph(new_subgraph), "[Set][Attr] failed for ffts+ subgraph"); - } else { - for (const auto &node : item.second) { - // op_desc of node should not be null - (void)node->GetOpDesc()->DelAttr(ATTR_NAME_THREAD_SCOPE_ID); - } - } - } - - return GRAPH_SUCCESS; -} - -graphStatus FftsGraphUtils::BuildFftsPlusSubgraphWithAllNodes(const ComputeGraphPtr &subgraph) { - GE_CHECK_NOTNULL(subgraph); - std::set calc_nodes; - CollectCalcNodeInSubgraph(subgraph, calc_nodes); - const auto &subgraph_name = "FFTS_Plus_Subgraph_" + std::to_string(g_fftsPlusSubgraphNum++); - const auto &new_subgraph = GraphUtils::BuildSubgraphWithNodes(subgraph, calc_nodes, subgraph_name); - if (new_subgraph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Build subgraph %s failed", subgraph_name.c_str()); - GELOGE(GRAPH_FAILED, "[Build][Subgraph] %s failed", subgraph_name.c_str()); - return GRAPH_FAILED; - } - GE_CHK_STATUS_RET(SetAttrForFftsPlusSubgraph(new_subgraph), "[Set][Attr] failed for ffts+ subgraph"); - - return GRAPH_SUCCESS; -} - -void FftsGraphUtils::CollectCalcNodeInSubgraph(const ComputeGraphPtr &subgraph, std::set &calc_nodes) { - std::set edge_nodes; - const std::set ctrl_goto_types = { LABELSET, LABELGOTOEX, LABELSWITCHBYINDEX }; - // collect end nodes - CollectEndNodeInSubgraph(subgraph, ctrl_goto_types, edge_nodes); - // collect start nodes - std::queue start_nodes; - for (const auto &node : subgraph->GetDirectNode()) { - if ((node->GetType() == DATA) || - ((node->GetInAllNodes().empty()) && (ctrl_goto_types.count(node->GetType()) > 0U))) { - start_nodes.push(node); - } - } - while (!start_nodes.empty()) { - const auto &cur_node = start_nodes.front(); - start_nodes.pop(); - (void)edge_nodes.insert(cur_node); - for (const auto &out_node : cur_node->GetOutAllNodes()) { - if (ctrl_goto_types.count(out_node->GetType()) > 0U) { - start_nodes.push(out_node); - } - } - } - - for (const auto &node : subgraph->GetDirectNode()) { - if (edge_nodes.count(node) == 0U) { - (void)calc_nodes.insert(node); - } - } -} - -void FftsGraphUtils::CollectEndNodeInSubgraph(const ComputeGraphPtr &subgraph, - const std::set &ctrl_goto_types, - std::set &edge_nodes) { - const auto &net_output_node = subgraph->FindFirstNodeMatchType(NETOUTPUT); - if (net_output_node == nullptr) { - return; - } - std::set out_nodes; - for (const auto &in_node : net_output_node->GetInAllNodes()) { - for (const auto &out_node : in_node->GetOutAllNodes()) { - (void)out_nodes.insert(out_node); - } - } - std::queue end_nodes; - end_nodes.push(net_output_node); - for (const auto &out_node : out_nodes) { - if (ctrl_goto_types.count(out_node->GetType()) > 0U) { - end_nodes.push(out_node); - } - } - while (!end_nodes.empty()) { - const auto &cur_node = end_nodes.front(); - end_nodes.pop(); - (void)edge_nodes.insert(cur_node); - for (const auto &out_node : cur_node->GetOutAllNodes()) { - end_nodes.push(out_node); - } - } -} - -ComputeGraphPtr FftsGraphUtils::GetFftsPlusGraph(ComputeGraph &graph) { - const auto &parent_node = graph.GetParentNode(); - if (parent_node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "parent node of graph %s is null", graph.GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] parent node of graph %s is null", graph.GetName().c_str()); - return nullptr; - } - std::vector subgraphs; - if (NodeUtils::GetDirectSubgraphs(parent_node, subgraphs) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Get subgraph failed, node:%s", parent_node->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][Subgraph] failed, node:%s", parent_node->GetName().c_str()); - return nullptr; - } - if (subgraphs.size() != 1U) { - REPORT_INNER_ERR_MSG("E18888", "Number of subgraphs in parent_node:%s is %zu, graph:%s", - parent_node->GetName().c_str(), subgraphs.size(), graph.GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Number of subgraphs in parent_node:%s is %zu, graph:%s", - parent_node->GetName().c_str(), subgraphs.size(), graph.GetName().c_str()); - return nullptr; - } - return subgraphs[0U]; -} - -graphStatus FftsGraphUtils::SetAttrForFftsPlusSubgraph(const ComputeGraphPtr &subgraph) { - const auto &parent_node = subgraph->GetParentNode(); - if (parent_node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Parent node of subgraph %s is null", subgraph->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Parent node of subgraph %s is null", subgraph->GetName().c_str()); - return GRAPH_FAILED; - } - (void)AttrUtils::SetStr(parent_node->GetOpDesc(), ATTR_NAME_FFTS_PLUS_SUB_GRAPH, subgraph->GetName().c_str()); - for (const auto &node : subgraph->GetAllNodes()) { - // depend on SGT api, need modify - (void)AttrUtils::SetInt(node->GetOpDesc(), ATTR_NAME_THREAD_SCOPE_ID, 0); - } - return GRAPH_SUCCESS; -} - -graphStatus FftsGraphUtils::GraphPartition(ComputeGraph &graph, - const CalcFunc &calc_func, - const std::vector &upper_limit) { - if ((calc_func == nullptr) || upper_limit.empty()) { - GELOGI("Graph:%s, calculate function or upper_limit is empty, skip graph partition", - graph.GetName().c_str()); - return SUCCESS; - } - - const auto &ffts_plus_graph = GetFftsPlusGraph(graph); - GE_CHECK_NOTNULL(ffts_plus_graph); - // calculate value per node / graph - // value of func_node equal to the sum of all node_value in subgraphs - std::map> node_value; - std::map> graph_value; - GE_CHK_STATUS_RET(Calculate(ffts_plus_graph, calc_func, node_value, graph_value), - "[Calculate][Value] failed for graph %s", ffts_plus_graph->GetName().c_str()); - if (!IsValueValid(ffts_plus_graph, upper_limit, node_value, graph_value)) { - REPORT_INNER_ERR_MSG("E18888", "Check value invalid"); - GELOGE(GRAPH_FAILED, "[Check][Value] invalid"); - return GRAPH_FAILED; - } - - // input graph not exceed the limit - if ((graph_value.count(ffts_plus_graph) > 0U) && (graph_value[ffts_plus_graph] <= upper_limit)) { - GELOGI("Graph %s not exceed limit, skip graph partition", ffts_plus_graph->GetName().c_str()); - return SUCCESS; - } - const auto &parent_node = ffts_plus_graph->GetParentNode(); - GE_CHECK_NOTNULL(parent_node); - // op_desc of node should not be null - (void)parent_node->GetOpDesc()->DelAttr(ATTR_NAME_FFTS_PLUS_SUB_GRAPH); - - GE_CHK_STATUS_RET(PartitionGraphWithLimit(ffts_plus_graph, node_value, graph_value, upper_limit), - "[Partition][Graph] failed, graph:%s", ffts_plus_graph->GetName().c_str()); - - // only non-Ffts+ subgraph of PARTITIONEDCALL need to be unfolded - const auto &filter = [](const ComputeGraphPtr &graph_ptr) { - const auto &parent = graph_ptr->GetParentNode(); - if ((parent == nullptr) || (parent->GetOpDesc() == nullptr)) { - return false; - } - // op_desc of node should not be null - if ((parent->GetType() != PARTITIONEDCALL) || - (parent->GetOpDesc()->GetSubgraphInstanceNames().size() != 1U)) { - return false; - } - return !parent->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_PLUS_SUB_GRAPH); - }; - GE_CHK_STATUS_RET(GraphUtils::UnfoldSubgraph(ffts_plus_graph, filter), "[Unfold][Subgraph] failed, graph:%s", - ffts_plus_graph->GetName().c_str()); - - return GRAPH_SUCCESS; -} - -graphStatus FftsGraphUtils::Calculate(const ComputeGraphPtr &graph, - const CalcFunc &calc_func, - std::map> &node_value, - std::map> &graph_value, - const uint32_t recursive_depth) { - if (recursive_depth >= kMaxiumRecursionDepth) { - REPORT_INNER_ERR_MSG("E18888", "param depth:%u >= %u(allow max subgraphs)", recursive_depth, kMaxiumRecursionDepth); - GELOGE(GRAPH_FAILED, "[Check][Param]exist too much subgraphs:%u > %u(allow max subgraphs)", - recursive_depth, kMaxiumRecursionDepth); - return GRAPH_FAILED; - } - GE_CHECK_NOTNULL(graph); - std::vector cur_graph_value; - for (const auto &node : graph->GetDirectNode()) { - std::vector cur_node_value; - if (node->GetOpDesc()->GetSubgraphInstanceNames().empty()) { - cur_node_value = calc_func(node); - } else { - cur_node_value = Calculate(node, calc_func, node_value, graph_value, recursive_depth); - if (cur_node_value.empty()) { - REPORT_INNER_ERR_MSG("E18888", "Calculate value for func node %s failed", node->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Calculate][Value] for func node %s failed", node->GetName().c_str()); - return GRAPH_FAILED; - } - } - node_value[node] = cur_node_value; - if (cur_graph_value.empty()) { - cur_graph_value = cur_node_value; - } else if (cur_graph_value.size() != cur_node_value.size()) { - REPORT_INNER_ERR_MSG("E18888", - "Value size not match, value size of graph %s is %zu, " - "value size of node %s is %zu", - graph->GetName().c_str(), cur_graph_value.size(), node->GetName().c_str(), - cur_node_value.size()); - GELOGE(GRAPH_FAILED, "[Check][Param] Value size not match, value size of graph %s is %zu, " - "value size of node %s is %zu", graph->GetName().c_str(), cur_graph_value.size(), - node->GetName().c_str(), cur_node_value.size()); - return GRAPH_FAILED; - } else { - (void) std::transform(cur_graph_value.begin(), cur_graph_value.end(), cur_node_value.begin(), - cur_graph_value.begin(), std::plus()); - } - } - graph_value[graph] = cur_graph_value; - return SUCCESS; -} - -std::vector FftsGraphUtils::Calculate(const NodePtr &node, const CalcFunc &calc_func, - std::map> &node_value, - std::map> &graph_value, - const uint32_t recursive_depth) { - std::vector subgraphs; - if (NodeUtils::GetDirectSubgraphs(node, subgraphs) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Get subgraphs failed"); - GELOGE(GRAPH_FAILED, "[Get][Subgraphs] failed"); - return {}; - } - std::vector cur_node_value; - for (const auto &subgraph : subgraphs) { - if (graph_value.count(subgraph) == 0U) { - if (Calculate(subgraph, calc_func, node_value, graph_value, recursive_depth + 1U) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Calculate value failed, graph %s, parent_node:%s", subgraph->GetName().c_str(), - node->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Calculate][Value] failed, graph %s, parent_node:%s", - subgraph->GetName().c_str(), node->GetName().c_str()); - return {}; - } - } - if (graph_value.find(subgraph) == graph_value.end()) { - REPORT_INNER_ERR_MSG("E18888", "Find value failed for graph %s", subgraph->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Find][Value] failed for graph %s", subgraph->GetName().c_str()); - return {}; - } - const auto &subgraph_value = graph_value[subgraph]; - if (cur_node_value.empty()) { - cur_node_value = subgraph_value; - } else if (cur_node_value.size() != subgraph_value.size()) { - REPORT_INNER_ERR_MSG("E18888", - "Value size not match, value size of node %s is %zu, value size of subgraph %s " - "is %zu", - node->GetName().c_str(), cur_node_value.size(), subgraph->GetName().c_str(), - subgraph_value.size()); - GELOGE(GRAPH_FAILED, "[Check][Param] Value size not match, value size of node %s is %zu, " - "value size of subgraph %s is %zu", node->GetName().c_str(), cur_node_value.size(), - subgraph->GetName().c_str(), subgraph_value.size()); - return {}; - } else { - (void) std::transform(cur_node_value.begin(), cur_node_value.end(), - subgraph_value.begin(), cur_node_value.begin(), std::plus()); - } - } - - return cur_node_value; -} - -bool FftsGraphUtils::IsValueValid(const ComputeGraphPtr &graph, const std::vector &upper_limit, - const std::map> &node_value, - const std::map> &graph_value) { - std::vector subgraphs; - if (GraphUtils::GetSubgraphsRecursively(graph, subgraphs) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Get subgraphs failed, graph:%s", graph->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][Subgraphs] failed, graph:%s", graph->GetName().c_str()); - return false; - } - for (const auto &subgraph : subgraphs) { - if (graph_value.count(subgraph) == 0U) { - REPORT_INNER_ERR_MSG("E18888", "Find graph value failed, graph:%s", subgraph->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Find graph value failed, graph:%s", subgraph->GetName().c_str()); - return false; - } - std::set calc_nodes; - CollectCalcNodeInSubgraph(subgraph, calc_nodes); - for (const auto &node : calc_nodes) { - if (node_value.count(node) == 0U) { - REPORT_INNER_ERR_MSG("E18888", "Find node value failed, node:%s", node->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Find node value failed, node:%s", node->GetName().c_str()); - return false; - } - } - } - - const auto is_node_value_match = [upper_limit](const std::pair> &pair_item) { - return pair_item.second.size() != upper_limit.size(); - }; - if (std::find_if(node_value.begin(), node_value.end(), is_node_value_match) != node_value.end()) { - REPORT_INNER_ERR_MSG("E18888", "Node value size not match"); - GELOGE(GRAPH_FAILED, "[Check][Param] Node value size not match"); - return false; - } - - const auto is_graph_value_match = [upper_limit](const std::pair> &pair_item) { - return pair_item.second.size() != upper_limit.size(); - }; - if (std::find_if(graph_value.begin(), graph_value.end(), is_graph_value_match) != graph_value.end()) { - REPORT_INNER_ERR_MSG("E18888", "Graph value size not match"); - GELOGE(GRAPH_FAILED, "[Check][Param] Graph value size not match"); - return false; - } - - return true; -} - -graphStatus FftsGraphUtils::PartitionGraphWithLimit(const ComputeGraphPtr &graph, - std::map> &node_value, - std::map> &graph_value, - const std::vector &upper_limit, - const uint32_t recursive_depth) { - if (recursive_depth >= kMaxiumRecursionDepth) { - REPORT_INNER_ERR_MSG("E18888", "param depth:%u >= %u(allow max subgraphs)", recursive_depth, kMaxiumRecursionDepth); - GELOGE(GRAPH_FAILED, "[Check][Param]exist too much subgraphs:%u > %u(allow max subgraphs)", - recursive_depth, kMaxiumRecursionDepth); - return GRAPH_FAILED; - } - GE_CHECK_NOTNULL(graph); - std::set calc_nodes; - CollectCalcNodeInSubgraph(graph, calc_nodes); - uint32_t split_level = 0U; - std::map> split_nodes; - std::vector exceed_single_node; - std::vector cur_value; - for (const auto &node : graph->GetDirectNode()) { - if (calc_nodes.count(node) == 0U) { - // op_desc of node should not be null - (void)node->GetOpDesc()->DelAttr(ATTR_NAME_THREAD_SCOPE_ID); - continue; - } - std::vector cur_node_value = node_value[node]; - if (cur_value.empty()) { - cur_value = cur_node_value; - } else { - (void)std::transform(cur_value.begin(), cur_value.end(), cur_node_value.begin(), cur_value.begin(), - std::plus()); - } - if (cur_value <= upper_limit) { - (void)split_nodes[split_level].emplace(node); - } else { - ++split_level; - if (cur_node_value > upper_limit) { - (void)exceed_single_node.emplace_back(node); - cur_value.clear(); - } else { - (void)split_nodes[split_level].emplace(node); - cur_value = cur_node_value; - } - } - } - - for (const auto &item : split_nodes) { - const auto &subgraph_name = "FFTS_Plus_Subgraph_" + std::to_string(g_fftsPlusSubgraphNum++); - const auto &subgraph = GraphUtils::BuildSubgraphWithNodes(graph, item.second, subgraph_name); - if (subgraph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Build subgraph %s failed", subgraph_name.c_str()); - GELOGE(GRAPH_FAILED, "[Build][Subgraph] %s failed", subgraph_name.c_str()); - return GRAPH_FAILED; - } - GE_CHK_STATUS_RET(SetAttrForFftsPlusSubgraph(subgraph), "[Set][Attr] failed for ffts+ subgraph"); - } - - return SplitFuncNode(exceed_single_node, node_value, graph_value, upper_limit, recursive_depth); -} - -graphStatus FftsGraphUtils::SplitFuncNode(const std::vector exceed_single_node, - std::map> &node_value, - std::map> &graph_value, - const std::vector &upper_limit, - const uint32_t recursive_depth) { - for (const auto &node : exceed_single_node) { - // op_desc of node should not be null - (void)node->GetOpDesc()->DelAttr(ATTR_NAME_THREAD_SCOPE_ID); - std::vector subgraphs; - GE_CHK_STATUS_RET(NodeUtils::GetDirectSubgraphs(node, subgraphs), "[Get][Subgraphs] of node %s failed", - node->GetName().c_str()); - for (const auto &subgraph : subgraphs) { - if (graph_value[subgraph] <= upper_limit) { - GE_CHK_STATUS_RET(BuildFftsPlusSubgraphWithAllNodes(subgraph), "[Build][FftsPlusSubgraph] failed, graph:%s ", - subgraph->GetName().c_str()); - } else { - GE_CHK_STATUS_RET(PartitionGraphWithLimit(subgraph, node_value, graph_value, upper_limit, recursive_depth + 1U), - "[Partition][Subgraph] failed, graph:%s ", subgraph->GetName().c_str()); - } - } - } - return GRAPH_SUCCESS; -} -} // namespace ge diff --git a/graph/utils/ge_dump_graph_whitelist.h b/graph/utils/ge_dump_graph_whitelist.h deleted file mode 100644 index 57c1b18a82ac9b2fb9b9186df1aa8f1aac22141d..0000000000000000000000000000000000000000 --- a/graph/utils/ge_dump_graph_whitelist.h +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef COMMON_GRAPH_UTILS_GE_DUMP_GRAPH_WHILELIST_H -#define COMMON_GRAPH_UTILS_GE_DUMP_GRAPH_WHILELIST_H - -namespace ge { -// kGeDumpWhitelistFullName + kGeDumpWhitelistKeyName size need small than 100 -const std::set kGeDumpWhitelistFullName = { - "PreRunBegin", // 用户原始图 - "PreRunAfterNormalizeGraph", // 图标准化出口图 - "AfterInfershape", // infershape出口图 - "PreRunAfterPrepare", // 图准备阶段之后的图 - "PreRunAfterOptimizeOriginalGraph", // 原图优化之后的图 - "PreRunAfterOptimizeAfterStage1", // 各算子信息库优化处理之后的图 - "PreRunAfterOptimizeSubgraph", // 子图优化之后的图 - "PreRunAfterOptimizeGraphBeforeBuild", // 模型编译入口图 - "Build", // 模型编译出口图 - "ComputeGraphBeforeLowering", // lowering前的计算图 - "ExeGraphBeforeOptimize", // lowering后,执行图优化前的执行图 - "ExecuteGraphAfterSplit" // 动态shape最终的执行图 -}; -const std::set kGeDumpWhitelistKeyName = { - "AutoFuseBeforeOptimize", // 自动融合优化之前的图 - "AutoFuseAfterOptimize", // 自动融合优化之后的图 - "RunCustomPass" // 用户自定义pass优化之后的图 -}; -} // namespace ge - -#endif // COMMON_GRAPH_UTILS__GE_DUMP_GRAPH_WHILELIST_H diff --git a/graph/utils/ge_ir_utils.cc b/graph/utils/ge_ir_utils.cc deleted file mode 100644 index b9a101e68af615c8fb84c9281088b42beac24d4f..0000000000000000000000000000000000000000 --- a/graph/utils/ge_ir_utils.cc +++ /dev/null @@ -1,1339 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/utils/ge_ir_utils.h" -#include -#include "common/ge_common/debug/ge_log.h" -#include "graph/detail/model_serialize_imp.h" -#include "graph/normal_graph/ge_tensor_impl.h" -#include "graph/normal_graph/node_impl.h" -#include "graph/normal_graph/op_desc_impl.h" -#include "mmpa/mmpa_api.h" -#include "attribute_group/attr_group_serialize.h" - -namespace { -const ge::char_t *const kControlAnchorIndex = ":-1"; -const ge::char_t *const kNodeTypeForSubgraph = "subgraph"; -const int8_t kMaxRecursiveDepth = 10; -const int32_t kDecimalBase = 10; -const uint64_t kInputPrefixLength = 5U; -const uint64_t kOutputPrefixLength = 6U; -// 因为存量的用例会校验原来的dump图的字段,所以我们设置一个触发tensordesc结构化的下限值 -const size_t kTensorDescStructuredThresholds = 20U; -} // namespace - -namespace ge { -// Part 1: from IR convert to ONNX Protobuf -namespace { -const std::map kGeDataTypeToOnnxMap = { - {DT_INT64, onnx::TensorProto_DataType_INT64}, {DT_UINT64, onnx::TensorProto_DataType_UINT64}, - {DT_FLOAT, onnx::TensorProto_DataType_FLOAT}, {DT_INT32, onnx::TensorProto_DataType_INT32}, - {DT_UINT32, onnx::TensorProto_DataType_UINT32}, {DT_INT8, onnx::TensorProto_DataType_INT8}, - {DT_UINT8, onnx::TensorProto_DataType_UINT8}, {DT_INT16, onnx::TensorProto_DataType_INT16}, - {DT_UINT16, onnx::TensorProto_DataType_UINT16}, {DT_FLOAT16, onnx::TensorProto_DataType_FLOAT16}, - {DT_DOUBLE, onnx::TensorProto_DataType_DOUBLE}, {DT_BOOL, onnx::TensorProto_DataType_BOOL}, - {DT_FLOAT8_E5M2, onnx::TensorProto_DataType_FLOAT8E5M2}, - {DT_FLOAT8_E4M3FN, onnx::TensorProto_DataType_FLOAT8E4M3FN}, -}; -} -DumpLevel OnnxUtils::dump_level_ = DumpLevel::DUMP_LEVEL_END; -const OnnxUtils::TensordescAttrHandlers OnnxUtils::ext_meta_attr_handlers_ = { - {"size", onnx::AttributeProto_AttributeType_INT, - [](const GeTensorDescImpl::ExtMeta &ext_meta) { return ext_meta.GetSize(); }}, - {"weight_size", onnx::AttributeProto_AttributeType_INT, - [](const GeTensorDescImpl::ExtMeta &ext_meta) { return ext_meta.GetWeightSize(); }}, - {"reuse_input", onnx::AttributeProto_AttributeType_INT, - [](const GeTensorDescImpl::ExtMeta &ext_meta) { return static_cast(ext_meta.GetReuseInput()); }}, - {"output_tensor", onnx::AttributeProto_AttributeType_INT, - [](const GeTensorDescImpl::ExtMeta &ext_meta) { return static_cast(ext_meta.GetOutputTensor()); }}, - {"device_type", onnx::AttributeProto_AttributeType_STRING, - [](const GeTensorDescImpl::ExtMeta &ext_meta) { return ext_meta.GetDeviceTypeStr(); }}, - {"input_tensor", onnx::AttributeProto_AttributeType_INT, - [](const GeTensorDescImpl::ExtMeta &ext_meta) { return static_cast(ext_meta.GetInputTensor()); }}, - {"real_dim_cnt", onnx::AttributeProto_AttributeType_INT, - [](const GeTensorDescImpl::ExtMeta &ext_meta) { return static_cast(ext_meta.GetRealDimCnt()); }}, - {"data_offset", onnx::AttributeProto_AttributeType_INT, - [](const GeTensorDescImpl::ExtMeta &ext_meta) { return ext_meta.GetDataOffset(); }}, - {"cmps_size", onnx::AttributeProto_AttributeType_INT, - [](const GeTensorDescImpl::ExtMeta &ext_meta) { return ext_meta.GetCmpsSize(); }}, - {"cmps_tab", onnx::AttributeProto_AttributeType_STRING, - [](const GeTensorDescImpl::ExtMeta &ext_meta) { return ext_meta.GetCmpsTab(); }}, - {"cmps_tab_offset", onnx::AttributeProto_AttributeType_INT, - [](const GeTensorDescImpl::ExtMeta &ext_meta) { return ext_meta.GetCmpsTabOffset(); }}, -}; - -const OnnxUtils::TensordescAttrHandlers OnnxUtils::normal_member_attr_handlers_ = { - {"dtype", onnx::AttributeProto_AttributeType_STRING, [](const ConstGeTensorDescPtr &desc) { - return ge::TypeUtils::DataTypeToSerialString(desc->GetDataType()); - }}, - {"origin_dtype", onnx::AttributeProto_AttributeType_STRING, [](const ConstGeTensorDescPtr &desc) { - return ge::TypeUtils::DataTypeToSerialString(desc->GetOriginDataType()); - }}, - {"shape", onnx::AttributeProto_AttributeType_INTS, [](const ConstGeTensorDescPtr &desc) { - return desc->GetShape().GetDims(); - }}, - {"origin_shape", onnx::AttributeProto_AttributeType_INTS, [](const ConstGeTensorDescPtr &desc) { - return desc->GetOriginShape().GetDims(); - }}, - {"layout", onnx::AttributeProto_AttributeType_STRING, [](const ConstGeTensorDescPtr &desc) { - return ge::TypeUtils::FormatToSerialString(desc->GetFormat()); - }}, - {"origin_layout", onnx::AttributeProto_AttributeType_STRING, [](const ConstGeTensorDescPtr &desc) { - return ge::TypeUtils::FormatToSerialString(desc->GetOriginFormat()); - }}, -}; -struct AttrNameComp { - inline bool operator()(const onnx::AttributeProto &lsh, const onnx::AttributeProto &rsh) const { - return lsh.name() < rsh.name(); - } -}; - -onnx::TensorProto_DataType OnnxUtils::EncodeDataType(const DataType data_type) { - const auto it = kGeDataTypeToOnnxMap.find(data_type); - if (it != kGeDataTypeToOnnxMap.end()) { - return it->second; - } else { - GELOGW("[Encode][DataType] Datatype %u not support", data_type); - return onnx::TensorProto_DataType_UNDEFINED; - } -} - -void OnnxUtils::AddAttrProtoFromAttribute(const std::pair &string_attr_value, - onnx::NodeProto *const node_proto) { - if (node_proto == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param node_proto is nullptr, check invalid"); - GELOGE(FAILED, "[Check][Param] Node proto is nullptr."); - return; - } - const auto attr = node_proto->add_attribute(); - if (attr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "add attr to node proto return nullptr."); - GELOGE(GRAPH_FAILED, "[Check][Param] attr is nullptr."); - return; - } - const auto attr_name = string_attr_value.first; - attr->set_name(attr_name); - const auto attr_value = string_attr_value.second; - const auto value_type = attr_value.GetValueType(); - switch (value_type) { - case GeAttrValue::VT_FLOAT: { - float32_t data_f = 0.0F; - (void)attr_value.GetValue(data_f); - attr->set_f(data_f); - attr->set_type(onnx::AttributeProto_AttributeType_FLOAT); - break; - } - case GeAttrValue::VT_LIST_FLOAT: { - std::vector data_fs = {}; - (void)attr_value.GetValue(data_fs); - attr->set_type(onnx::AttributeProto_AttributeType_FLOATS); - for (auto &v : data_fs) { - attr->add_floats(v); - } - break; - } - case GeAttrValue::VT_INT: { - int64_t data_i = 0; - (void)attr_value.GetValue(data_i); - attr->set_type(onnx::AttributeProto_AttributeType_INT); - attr->set_i(data_i); - break; - } - case GeAttrValue::VT_LIST_INT: { - std::vector data_is = {}; - (void)attr_value.GetValue(data_is); - attr->set_type(onnx::AttributeProto_AttributeType_INTS); - for (auto &v : data_is) { - attr->add_ints(v); - } - break; - } - case GeAttrValue::VT_STRING: { - std::string data_s; - (void)attr_value.GetValue(data_s); - attr->set_type(onnx::AttributeProto_AttributeType_STRING); - attr->set_s(data_s); - break; - } - case GeAttrValue::VT_LIST_STRING: { - std::vector data_ss = {}; - (void)attr_value.GetValue(data_ss); - attr->set_type(onnx::AttributeProto_AttributeType_STRINGS); - for (auto &v : data_ss) { - attr->add_strings(v); - } - break; - } - default: - GELOGW("[Add][Attr] ValueType %u is not supported", value_type); - break; - } -} - -void OnnxUtils::AddAttrProto(onnx::NodeProto *const node_proto, const onnx::AttributeProto_AttributeType type, - const std::string &name, const void *const data) { - if (node_proto == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param node_proto is nullptr."); - GELOGE(FAILED, "[Check][Param] Node_proto is nullptr."); - return; - } - const auto attr = node_proto->add_attribute(); - if (attr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "add attr to node proto return nullptr."); - GELOGE(GRAPH_FAILED, "[Check][Param] attr is nullptr."); - return; - } - attr->set_name(name); - switch (type) { - case onnx::AttributeProto_AttributeType_FLOAT: - attr->set_f((*(static_cast(data)))); - attr->set_type(onnx::AttributeProto_AttributeType_FLOAT); - break; - - case onnx::AttributeProto_AttributeType_FLOATS: - attr->set_type(onnx::AttributeProto_AttributeType_FLOATS); - for (auto &v : (*(static_cast *>(data)))) { - attr->add_floats(v); - } - break; - - case onnx::AttributeProto_AttributeType_INT: - attr->set_type(onnx::AttributeProto_AttributeType_INT); - attr->set_i((*(static_cast(data)))); - break; - - case onnx::AttributeProto_AttributeType_INTS: - attr->set_type(onnx::AttributeProto_AttributeType_INTS); - for (auto &v : *(static_cast *>(data))) { - attr->add_ints(v); - } - break; - - case onnx::AttributeProto_AttributeType_STRING: - attr->set_type(onnx::AttributeProto_AttributeType_STRING); - attr->set_s((*(static_cast(data)))); - break; - - case onnx::AttributeProto_AttributeType_STRINGS: - attr->set_type(onnx::AttributeProto_AttributeType_STRINGS); - for (auto &v : *(static_cast *>(data))) { - attr->add_strings(v); - } - break; - - default: - GELOGW("[Add][Attr] AttributeType %u is not supported", type); - break; - } -} - -void OnnxUtils::AddAttrProto(onnx::NodeProto *const node_proto, const onnx::AttributeProto_AttributeType type, - const std::string &name, - const ::google::protobuf::RepeatedField<::google::protobuf::int64> data) { - if (node_proto == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param node_proto is nullptr."); - GELOGE(FAILED, "[Check][Param] Node_proto is nullptr."); - return; - } - if (!data.empty()) { - const auto attr = node_proto->add_attribute(); - if (attr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "add attr to node proto return nullptr."); - GELOGE(GRAPH_FAILED, "[Check][Param] attr is nullptr."); - return; - } - attr->set_name(name); - for (auto &v : data) { - attr->add_ints(v); - } - attr->set_type(type); - } -} - -void OnnxUtils::AddAttrProto(onnx::NodeProto *const node_proto, const onnx::AttributeProto_AttributeType type, - const std::string &name, const ::google::protobuf::RepeatedField data) { - if (node_proto == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param node_proto is nullptr."); - GELOGE(FAILED, "[Check][Param] Node proto is nullptr."); - return; - } - if (!data.empty()) { - const auto attr = node_proto->add_attribute(); - if (attr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "add attr to node proto return nullptr."); - GELOGE(GRAPH_FAILED, "[Check][Param] attr is nullptr."); - return; - } - attr->set_name(name); - for (auto &v : data) { - attr->add_ints(static_cast(v)); - } - attr->set_type(type); - } -} - -void OnnxUtils::AddAttrProto(onnx::NodeProto *const node_proto, const onnx::AttributeProto_AttributeType type, - const std::string &name, const ::google::protobuf::RepeatedField data) { - if (node_proto == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param node_proto is nullptr."); - GELOGE(FAILED, "[Check][Param] Node_proto is nullptr."); - return; - } - if (!data.empty()) { - const auto attr = node_proto->add_attribute(); - if (attr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "add attr to node proto return nullptr."); - GELOGE(GRAPH_FAILED, "[Check][Param] attr is nullptr."); - return; - } - attr->set_name(name); - for (auto &v : data) { - attr->add_floats(v); - } - attr->set_type(type); - } -} - -void OnnxUtils::AddAttrProto(onnx::NodeProto *const node_proto, const onnx::AttributeProto_AttributeType type, - const std::string &name, const ::google::protobuf::RepeatedPtrField<::std::string> data) { - if (node_proto == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param node_proto is nullptr."); - GELOGE(FAILED, "[Check][Param] Node proto is nullptr."); - return; - } - if (!data.empty()) { - const auto attr = node_proto->add_attribute(); - if (attr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "add attr to node proto return nullptr."); - GELOGE(GRAPH_FAILED, "[Check][Param] attr is nullptr."); - return; - } - attr->set_name(name); - for (auto &v : data) { - attr->add_strings(v); - } - attr->set_type(type); - } -} - -void OnnxUtils::AddAllAttrToJson(const ConstGeTensorDescPtr &tensor_desc, nlohmann::json &tensor_json) { - const std::map attr_maps = tensor_desc->GetAllAttrs(); - google::protobuf::Map tensor_desc_map; - (void) ModelSerializeImp::SerializeAllAttrsFromAnyMap(attr_maps, &tensor_desc_map); - for (const auto &pair : tensor_desc_map) { - AddJson(pair.first, tensor_json, pair.second.DebugString()); - } -} - -void OnnxUtils::AddAllAttrToProto(onnx::NodeProto *const node_proto, const ConstGeTensorDescPtr &tensor_desc, - const char_t *const prefix, const uint32_t idx) { - const std::map attr_maps = tensor_desc->GetAllAttrs(); - google::protobuf::Map tensor_desc_map; - (void) ModelSerializeImp::SerializeAllAttrsFromAnyMap(attr_maps, &tensor_desc_map); - const std::string suffix = ":" + std::to_string(idx); - AddAttrProtoForAttrsFromAttrMap(tensor_desc_map, node_proto, prefix, suffix); -} - -void OnnxUtils::AddAllAttrGroupToJson(const ConstGeTensorDescPtr &tensor_desc, nlohmann::json &tensor_json) { - const auto attr_store = tensor_desc->GetAttrMap(); - ge::proto::AttrGroups attr_groups; - (void) AttrGroupSerialize::SerializeAllAttr(attr_groups, attr_store); - if (attr_groups.attr_group_def_size() > 0) { - AddJson("attr_groups", tensor_json, attr_groups.DebugString()); - } -} - -void OnnxUtils::AddAllAttrGroupToProto(onnx::NodeProto *const node_proto, - const ConstGeTensorDescPtr &tensor_desc, - const char_t *const prefix, - const uint32_t idx) { - const auto attr_store = tensor_desc->GetAttrMap(); - ge::proto::AttrGroups attr_groups; - (void) AttrGroupSerialize::SerializeAllAttr(attr_groups, attr_store); - if (attr_groups.attr_group_def_size() > 0) { - const std::string attr_groups_readable = attr_groups.DebugString(); - AddAttrProto(node_proto, - onnx::AttributeProto_AttributeType_STRING, - std::string(prefix) + "groups:" + std::to_string(idx), - &attr_groups_readable); - } -} -void OnnxUtils::AddShapeFormatAndDtypeToJson(const ge::ConstGeTensorDescPtr &desc, nlohmann::json &tensor_json) { - for (const auto &item :normal_member_attr_handlers_) { - switch (item.attr_type) { - case onnx::AttributeProto_AttributeType_INTS: - AddJson(item.name, tensor_json, item.member_ints_getter(desc)); - break; - case onnx::AttributeProto_AttributeType_STRING: - AddJson(item.name, tensor_json, item.member_str_getter(desc)); - break; - default:GELOGW("Unsupported ext meta type %ld", static_cast(item.attr_type)); - } - } -} - -void OnnxUtils::AddShapeFormatAndDtypeToProto(const ge::ConstGeTensorDescPtr &desc, - const std::string &prefix, - const uint32_t idx, - onnx::NodeProto *const node_proto) { - for (const auto &item :normal_member_attr_handlers_) { - const std::string attr_name = prefix + item.name + ":" + std::to_string(idx); - switch (item.attr_type) { - case onnx::AttributeProto_AttributeType_INTS: { - const std::vector value = item.member_ints_getter(desc); - AddAttrProto(node_proto, item.attr_type, attr_name, &value); - break; - }; - case onnx::AttributeProto_AttributeType_STRING: { - const std::string value = item.member_str_getter(desc); - AddAttrProto(node_proto, item.attr_type, attr_name, &value); - break; - }; - default:GELOGW("Unsupported ext meta type %ld", static_cast(item.attr_type)); - } - } -} - -void OnnxUtils::AddExtMetaToJson(const GeTensorDescImpl::ExtMeta &tensor_descriptor, nlohmann::json &tensor_json) { - for (const auto &item : ext_meta_attr_handlers_) { - switch (item.attr_type) { - case onnx::AttributeProto_AttributeType_INT: - AddJson(item.name, - tensor_json, - item.ext_meta_int_getter(tensor_descriptor)); - break; - case onnx::AttributeProto_AttributeType_STRING: - AddJson(item.name, - tensor_json, - item.ext_meta_str_getter(tensor_descriptor)); - break; - default:GELOGW("Unsupported ext meta type %ld", static_cast(item.attr_type)); - } - } -} - -void OnnxUtils::AddExtMetaToProto(const GeTensorDescImpl::ExtMeta &tensor_descriptor, - const std::string &prefix, - uint32_t index, - onnx::NodeProto *node_proto) { - for (const auto &item : ext_meta_attr_handlers_) { - const std::string attr_name = prefix + item.name + ":" + std::to_string(index); - switch (item.attr_type) { - case onnx::AttributeProto_AttributeType_INT: { - const int64_t value = item.ext_meta_int_getter(tensor_descriptor); - AddAttrProto(node_proto, item.attr_type, attr_name, &value); - break; - } - case onnx::AttributeProto_AttributeType_STRING: { - const std::string value = item.ext_meta_str_getter(tensor_descriptor); - AddAttrProto(node_proto, item.attr_type, attr_name, &value); - break; - } - default:GELOGW("Unsupported ext meta type %ld", static_cast(item.attr_type)); - } - } -} - -template -void OnnxUtils::ProcessTensorDescImpl(const OpDescPtr &op_desc, - const string &desc_type, - DescGetter desc_getter, - onnx::NodeProto *node_proto) { - const auto size = desc_type == "input" ? op_desc->GetAllInputsSize() : op_desc->GetOutputsSize(); - const std::string nums_name = desc_type + "_desc_nums"; - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, nums_name, &size); - const bool use_json = size >= kTensorDescStructuredThresholds; - const std::string prefix = desc_type + "_desc_"; - for (uint32_t i = 0U; i < size; ++i) { - const auto desc = desc_getter(op_desc, i); - if (desc == nullptr || desc->impl_ == nullptr) { - GELOGW("%s desc of %s with index:%u is nullptr", desc_type.c_str(), op_desc->GetNamePtr(), i); - continue; - } - const auto &tensor_descriptor = desc->impl_->ext_meta_; - if (use_json) { - nlohmann::json tensor_json; - AddShapeFormatAndDtypeToJson(desc, tensor_json); - AddExtMetaToJson(tensor_descriptor, tensor_json); - AddAllAttrToJson(desc, tensor_json); - AddAllAttrGroupToJson(desc, tensor_json); - const std::string json_value = tensor_json.dump(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + std::to_string(i), &json_value); - } else { - AddShapeFormatAndDtypeToProto(desc, prefix, i, node_proto); - AddExtMetaToProto(tensor_descriptor, prefix, i, node_proto); - const std::string attr_name_prefix = prefix + "attr_"; - AddAllAttrToProto(node_proto, desc, attr_name_prefix.c_str(), i); - AddAllAttrGroupToProto(node_proto, desc, attr_name_prefix.c_str(), i); - } - } -} - -void OnnxUtils::AddAttrProtoForOpInDesc(onnx::NodeProto *const node_proto, const OpDescPtr &op_desc) { - return ProcessTensorDescImpl(op_desc, "input", - [](const OpDescPtr &op, uint32_t i) { - return op->GetInputDescPtrDfault(i); - }, node_proto); -} - -void OnnxUtils::AddAttrProtoForOpOutDesc(onnx::NodeProto *const node_proto, const OpDescPtr &op_desc) { - return ProcessTensorDescImpl(op_desc, "output", - [](const OpDescPtr &op, uint32_t i) { - return op->GetOutputDescPtr(i); - }, node_proto); -} - -void OnnxUtils::AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *const node_proto, const OpDescPtr &op_desc) { - if ((node_proto == nullptr) || (op_desc == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param node_proto or op_desc is nullptr"); - GELOGE(GRAPH_FAILED, "[Check][Param] node_proto or op_desc is nullptr"); - return; - } - AddAttrProtoForOpInDesc(node_proto, op_desc); - AddAttrProtoForOpOutDesc(node_proto, op_desc); -} - -void OnnxUtils::AddAttrProtoForAttrsFromAttrMap( - const ::google::protobuf::Map &attr_map, onnx::NodeProto *const node_proto, - const std::string& prefix, const std::string& suffix) { - for (const auto &item : attr_map) { - const auto attr_name = item.first; - const auto attr_def = item.second; - const auto attr_type = attr_def.value_case(); - if (attr_type == ge::proto::AttrDef::kT) { - const auto &tensor_def = attr_def.t(); - const auto &tensor_desc = tensor_def.desc(); - const auto data_type = ge::proto::DataType_Name(tensor_desc.dtype()); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, - prefix + attr_name + "_desc_dtype" + suffix, &data_type); - const auto dims = tensor_desc.shape().dim(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, - prefix + attr_name + "_desc_shape" + suffix, dims); - const auto layout = tensor_desc.layout(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, - prefix + attr_name + "_desc_layout" + suffix, &layout); - const auto device_type = tensor_desc.device_type(); - AddAttrProto(node_proto, ge::onnx::AttributeProto_AttributeType_STRING, - prefix + attr_name + "_desc_device_type" + suffix, &device_type); - if (dump_level_ == DumpLevel::DUMP_ALL) { - const auto data = tensor_def.data(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, - prefix + attr_name + "_data" + suffix, &data); - } - } - if (attr_type == ge::proto::AttrDef::kS) { - if (dump_level_ == DumpLevel::DUMP_ALL) { - const auto str_value = attr_def.s(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + suffix, &str_value); - } - } - if (attr_type == ge::proto::AttrDef::kI) { - const auto int_value = attr_def.i(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value); - } - if (attr_type == ge::proto::AttrDef::kF) { - const auto float_value = attr_def.f(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOAT, prefix + attr_name + suffix, &float_value); - } - if (attr_type == ge::proto::AttrDef::kB) { - const auto int_value = static_cast(attr_def.b()); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value); - } - if (attr_type == ge::proto::AttrDef::kList) { - AddListAttrProto(attr_name, attr_def, prefix, suffix, node_proto); - } - if (attr_type == ge::proto::AttrDef::kListListInt) { - const auto &list_value = attr_def.list_list_int(); - const auto &list_ints = list_value.list_list_i(); - int64_t list_index = 0; - for (const auto &one_ints : list_ints) { - const auto &ints = one_ints.list_i(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, - prefix + attr_name + suffix + "_" + std::to_string(list_index++), ints); - } - } - } -} - -void OnnxUtils::AddListAttrProto(const std::string &attr_name, - const ::ge::proto::AttrDef &attr_def, const std::string &prefix, - const std::string &suffix, onnx::NodeProto *node_proto) { - const auto &list_value = attr_def.list(); - const auto &list_value_type = list_value.val_type(); - if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_STRING) { - if (dump_level_ == DumpLevel::DUMP_ALL) { - const auto &strings = list_value.s(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, prefix + attr_name + suffix, strings); - } - } - if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_FLOAT) { - const auto &floats = list_value.f(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOATS, prefix + attr_name + suffix, floats); - } - if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_INT) { - const auto &ints = list_value.i(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, ints); - } - if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_BOOL) { - const auto &bools = list_value.b(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, bools); - } -} - -void OnnxUtils::AddCommonAttrGroupIntoProto(const OpDescPtr &op_desc, onnx::NodeProto *const node_proto) { - const auto attr_store = op_desc->GetAttrMap(); - ge::proto::AttrGroups attr_groups; - (void) AttrGroupSerialize::SerializeAllAttr(attr_groups, attr_store); - if (attr_groups.attr_group_def_size() > 0) { - const std::string attr_groups_readable = attr_groups.DebugString(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, - "attr_groups", &attr_groups_readable); - } -} - -void OnnxUtils::AddCommonAttrIntoProto(onnx::NodeProto *const node_proto, const OpDescPtr &op_desc) { - const auto meta_data = op_desc->impl_->meta_data_; - const auto id = meta_data.GetId(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "id", &id); - const auto stream_id = meta_data.GetStreamId(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "stream_id", &stream_id); - const auto &input_name = meta_data.GetInputNames(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "input_name", &input_name); - const auto &src_name = meta_data.GetSrcNames(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "src_name", &src_name); - const auto src_index = meta_data.GetSrcIndexes(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "src_index", &src_index); - const auto &dst_name = meta_data.GetDstNames(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "dst_name", &dst_name); - const auto dst_index = meta_data.GetDstIndexes(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "dst_index", &dst_index); - const auto input_i = meta_data.GetInputOffsets(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "input_i", &input_i); - const auto output_i = meta_data.GetOutputOffsets(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "output_i", &output_i); - const auto workspace = op_desc->impl_->GetWorkspace(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "workspace", &workspace); - const auto workspace_bytes = op_desc->impl_->GetWorkspaceBytes(); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "workspace_bytes", &workspace_bytes); - const auto &is_input_const = meta_data.GetIsInputConsts(); - vector int_const(is_input_const.size()); - for (size_t idx = 0UL; idx < is_input_const.size(); ++idx) { - int_const[idx] = static_cast(is_input_const[idx]); - } - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "is_input_const", &int_const); - google::protobuf::Map op_def_attr_map; - const std::map attr_maps = op_desc->GetAllAttrs(); - (void)ModelSerializeImp::SerializeAllAttrsFromAnyMap(attr_maps, &op_def_attr_map); - AddAttrProtoForAttrsFromAttrMap(op_def_attr_map, node_proto); - AddCommonAttrGroupIntoProto(op_desc, node_proto); -} - -void OnnxUtils::AddAttrProtoFromNodeMembers(const NodePtr &node, onnx::NodeProto *const node_proto) { - if ((node == nullptr) || (node->impl_ == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param node is nullptr."); - GELOGE(GRAPH_FAILED, "[Check][Param] node is nullptr"); - return; - } - // 1.Attributes added from node's methods - const auto send_list = node->impl_->send_event_id_list_; - if (!send_list.empty()) { - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "send_event_id_list", &send_list); - } - const auto recv_list = node->impl_->recv_event_id_list_; - if (!recv_list.empty()) { - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "recv_event_id_list", &recv_list); - } - const auto op_desc = node->impl_->op_; - if ((op_desc != nullptr) && (op_desc->impl_ != nullptr)) { - // for input_name_idx_ in opdesc - const auto input_name_2_indexs = op_desc->GetAllInputName(); - ::google::protobuf::RepeatedPtrField<::std::string> input_names; - ::google::protobuf::RepeatedField<::google::protobuf::int64> input_indexes; - for (const auto &input_name_2_index: input_name_2_indexs) { - std::string input_name = input_name_2_index.first; - input_names.Add(std::move(input_name)); - input_indexes.Add(static_cast(input_name_2_index.second)); - } - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "_input_name_key", input_names); - AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "_input_name_value", input_indexes); - // 2.Attributes added from node's op_(message OpDef) - // Input and out describes - AddAttrProtoForOpInAndOutDesc(node_proto, op_desc); - // Others - AddCommonAttrIntoProto(node_proto, op_desc); - } else { - REPORT_INNER_ERR_MSG("E18888", "Opdesc is nullptr, node:%s", node->GetName().c_str()); - GELOGE(FAILED, "[Check][Param] Opdesc is nullptr"); - return; - } -} - -bool OnnxUtils::EncodeNodeDesc(const NodePtr &node, onnx::NodeProto *const node_proto) { - if ((node == nullptr) || (node->impl_ == nullptr) || (node_proto == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param node or node_proto is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] EncodeOpDesc: Input Para Node Invalid"); - return false; - } - // Encode ge::Node members to AttributeProto - AddAttrProtoFromNodeMembers(node, node_proto); - - // Sort node attributes by name. - std::sort(node_proto->mutable_attribute()->begin(), node_proto->mutable_attribute()->end(), AttrNameComp()); - return true; -} - -void OnnxUtils::EncodeNodeLinkForNetronVisual(const NodePtr &node, onnx::NodeProto *const node_proto) { - if ((node == nullptr) || (node_proto == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param node or node_proto is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] EncodeNodeLinkForNetronVisual: Input Para Node Invalid"); - return; - } - const auto &node_name = node->GetName(); - for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { - if ((out_data_anchor != nullptr) && (!out_data_anchor->GetPeerInDataAnchors().empty())) { - node_proto->add_output(node_name + ":" + std::to_string(out_data_anchor->GetIdx())); - } - } - const auto out_control_anchor = node->GetOutControlAnchor(); - if ((out_control_anchor != nullptr) && (!out_control_anchor->GetPeerInControlAnchors().empty())) { - node_proto->add_output(node_name + kControlAnchorIndex); - } -} - -bool OnnxUtils::EncodeNodeLink(const NodePtr &node, onnx::NodeProto *const node_proto) { - if ((node == nullptr) || (node_proto == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param node or node_proto is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] EncodeNodeLink: Input Para Node Invalid"); - return false; - } - node_proto->clear_input(); - // 1. Add input by in data edge - for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { - const auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - if ((peer_out_anchor != nullptr) && (peer_out_anchor->GetOwnerNodeBarePtr() != nullptr)) { - node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" + - std::to_string(peer_out_anchor->GetIdx())); - } else { - // Add "" input - node_proto->add_input(""); - } - } - - // 2. Add input by in control edge - const auto in_control_anchor = node->GetInControlAnchor(); - if (in_control_anchor != nullptr) { - const auto peer_out_anchors = in_control_anchor->GetPeerOutControlAnchors(); - for (const auto &peer_out_anchor : peer_out_anchors) { - if (peer_out_anchor->GetOwnerNodeBarePtr() != nullptr) { - node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + kControlAnchorIndex); - } - } - } else { - REPORT_INNER_ERR_MSG("E18888", "In control anchor of node(%s) is nullptr", node->GetName().c_str()); - GELOGE(FAILED, "[Check][Param] In control anchor of node(%s) is nullptr", node->GetName().c_str()); - return false; - } - - // 3. Add output for Netron visual support - EncodeNodeLinkForNetronVisual(node, node_proto); - return true; -} - -bool OnnxUtils::EncodeNode(const NodePtr &node, onnx::NodeProto *const node_proto) { - if ((node == nullptr) || (node_proto == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param node or node_proto is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] EncodeNode: Input Para Node Invalid"); - return false; - } - // 1. Encode name and type - node_proto->set_name(node->GetName()); - /// Netron believes that some operators, such as the activation operator of softplus, only have one input, - /// while the link relation of control anchor may exist in ge, resulting in two inputs. Therefore, "ge:" prefix - /// is added to correctly display the link relation at the expense of some color features - node_proto->set_op_type("ge:" + node->GetType()); - - if (dump_level_ != DumpLevel::DUMP_WITH_OUT_DESC) { - // 2.for attr - if (!EncodeNodeDesc(node, node_proto)) { - GELOGE(GRAPH_FAILED, "[Encode][NodeDesc] failed, node:%s", node->GetName().c_str()); - return false; - } - } - // 3.for link info - return EncodeNodeLink(node, node_proto); -} - -void OnnxUtils::EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_Tensor *const tensor_type) { - if ((node == nullptr) || (tensor_type == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param node or tensor type is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] EncodeTypeProtoTensorType: Input Para Node or tensor_type Invalid"); - return; - } - const auto &op_desc = node->GetOpDesc(); - if (op_desc == nullptr) { - GELOGW("[Encode][Tensor] op_desc is empty, name %s, type %s", node->GetName().c_str(), node->GetType().c_str()); - return; - } - for (size_t i = 0U; i < op_desc->GetOutputsSize(); ++i) { - const ConstGeTensorDescPtr &ge_tensor = op_desc->GetOutputDescPtr(static_cast(i)); - if (ge_tensor == nullptr) { - GELOGW("[Encode][Tensor] Output desc %zu of node %s is nullptr", i, node->GetName().c_str()); - continue; - } - const auto ge_data_type = ge_tensor->GetDataType(); - const auto onnx_data_type = EncodeDataType(ge_data_type); - tensor_type->set_elem_type(onnx_data_type); - onnx::TensorShapeProto *const shape = tensor_type->mutable_shape(); - if (shape == nullptr) { - GELOGW("[Encode][Tensor] Shape is nullptr"); - continue; - } - for (const auto d : ge_tensor->GetShape().GetDims()) { - const auto dim = shape->add_dim(); - dim->set_dim_value(d); - } - } -} - -void OnnxUtils::EncodeValueInfo(const NodePtr &node, onnx::ValueInfoProto *const value_info_proto) { - if ((node == nullptr) || (value_info_proto == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param node or value info proto is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] EncodeValueInfo: Input Param Node or value_info_proto Invalid"); - return; - } - value_info_proto->set_name(node->GetName()); - onnx::TypeProto *const t = value_info_proto->mutable_type(); - onnx::TypeProto_Tensor *const tensor_type = t->mutable_tensor_type(); - EncodeTypeProtoTensorType(node, tensor_type); -} - -bool OnnxUtils::EncodeGraph(const ConstComputeGraphPtr &graph, onnx::GraphProto *const graph_proto) { - if ((graph == nullptr) || (graph_proto == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param graph or graph proto is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] EncodeGraph: Input para Invalid"); - return false; - } - graph_proto->set_name(graph->GetName()); - // 1. Add graph inputs - for (const auto &input : graph->GetInputNodes()) { - const auto value_info_proto = graph_proto->add_input(); - EncodeValueInfo(input, value_info_proto); - } - // 2. Add graph outputs - for (const auto &output : graph->GetOutputNodes()) { - const auto value_info_proto = graph_proto->add_output(); - EncodeValueInfo(output, value_info_proto); - } - // 3. Add nodes - for (const auto &node : graph->GetDirectNode()) { - if (!EncodeNode(node, graph_proto->add_node())) { - GELOGW("[Encode][Graph] Encode node %s failed", node->GetName().c_str()); - continue; - } - } - return true; -} - -bool OnnxUtils::ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelProto &model_proto) { - const char_t *dump_ge_graph = nullptr; - MM_SYS_GET_ENV(MM_ENV_DUMP_GE_GRAPH, dump_ge_graph); - auto dump_level = (dump_ge_graph != nullptr) ? - static_cast(std::strtol(dump_ge_graph, nullptr, kDecimalBase)) : DumpLevel::NO_DUMP; - return ConvertGeModelToModelProto(model, model_proto, dump_level); -} - -// Part 2: from ONNX Protobuf convert to IR -static std::map onnxDataTypeToGeMap = { - {onnx::TensorProto_DataType_INT64, DT_INT64}, {onnx::TensorProto_DataType_UINT64, DT_UINT64}, - {onnx::TensorProto_DataType_FLOAT, DT_FLOAT}, {onnx::TensorProto_DataType_INT32, DT_INT32}, - {onnx::TensorProto_DataType_UINT32, DT_UINT32}, {onnx::TensorProto_DataType_INT8, DT_INT8}, - {onnx::TensorProto_DataType_UINT8, DT_UINT8}, {onnx::TensorProto_DataType_INT16, DT_INT16}, - {onnx::TensorProto_DataType_UINT16, DT_UINT16}, {onnx::TensorProto_DataType_FLOAT16, DT_FLOAT16}, - {onnx::TensorProto_DataType_DOUBLE, DT_DOUBLE}, {onnx::TensorProto_DataType_BOOL, DT_BOOL}, - {onnx::TensorProto_DataType_FLOAT8E5M2, DT_FLOAT8_E5M2}, - {onnx::TensorProto_DataType_FLOAT8E4M3FN, DT_FLOAT8_E4M3FN}, -}; - -bool OnnxUtils::ParseNameAndIndex(const std::string &node_name_index, std::string &node_name, int32_t &idx) { - const auto sep = node_name_index.rfind(':'); - if (sep == std::string::npos) { - return false; - } - node_name = node_name_index.substr(0U, sep); - const auto index_str = node_name_index.substr(sep + 1U); - idx = static_cast(std::strtol(index_str.c_str(), nullptr, kDecimalBase)); - return true; -} - -bool OnnxUtils::DecodeNodeLinkImp(const NodeLinkInfo &item, const NodePtr &node_ptr) { - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param node_ptr is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] DecodeNodeLinkImp: node_ptr is nullptr"); - return false; - } - // Data edge - if (item.GetSrcOutIndex() >= 0) { - const auto src_anchor = node_ptr->GetOutDataAnchor(item.GetSrcOutIndex()); - const auto dst_anchor = item.GetDstNode()->GetInDataAnchor(item.GetDstInIndex()); - if ((src_anchor == nullptr) || (dst_anchor == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "Get DataAnchor failed, %s:%d, %s:%d ", item.GetSrcNodeName().c_str(), - item.GetSrcOutIndex(), item.GetDstNodeName().c_str(), item.GetDstInIndex()); - GELOGE(GRAPH_FAILED, "[Get][DataAnchor] failed, %s:%d, %s:%d ", item.GetSrcNodeName().c_str(), - item.GetSrcOutIndex(), item.GetDstNodeName().c_str(), item.GetDstInIndex()); - return false; - } - if (src_anchor->LinkTo(dst_anchor) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "src anchor link to dst anchor failed."); - GELOGE(GRAPH_FAILED, "[Invoke][LinkTo] Data Anchor: src anchor link to dst anchor failed"); - return false; - } - // Control edge - } else { - const auto src_anchor = node_ptr->GetOutControlAnchor(); - const auto dst_anchor = item.GetDstNode()->GetInControlAnchor(); - if ((src_anchor == nullptr) || (dst_anchor == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "Get ControlAnchor failed, %s:%d, %s:%d ", item.GetSrcNodeName().c_str(), - item.GetSrcOutIndex(), item.GetDstNodeName().c_str(), item.GetDstInIndex()); - GELOGE(GRAPH_FAILED, "[Get][ControlAnchor] failed, %s:%d, %s:%d ", item.GetSrcNodeName().c_str(), - item.GetSrcOutIndex(), item.GetDstNodeName().c_str(), item.GetDstInIndex()); - return false; - } - if (src_anchor->LinkTo(dst_anchor) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "src anchor(%s) link to dst anchor(%s) failed.", - src_anchor->GetOwnerNode()->GetName().c_str(), - dst_anchor->GetOwnerNode()->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Invoke][LinkTo] Control Anchor: src anchor(%s) link to dst anchor(%s) failed", - src_anchor->GetOwnerNode()->GetName().c_str(), dst_anchor->GetOwnerNode()->GetName().c_str()); - return false; - } - } - return true; -} - -bool OnnxUtils::DecodeNodeLink(const std::vector &node_proto_vector, - const std::map &node_map) { - for (const auto &node_proto : node_proto_vector) { - const auto &node_name = node_proto.name(); - const auto dst_node = node_map.find(node_name); - if ((dst_node == node_map.end()) || (dst_node->second == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "destination node: %s find failed or is nullptr", node_name.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] destination node: %s find failed or is nullptr", node_name.c_str()); - return false; - } - int32_t dst_index = 0; - for (const auto &input : node_proto.input()) { - std::string input_node_name; - int32_t index = 0; - if (ParseNameAndIndex(input, input_node_name, index)) { - const auto item = NodeLinkInfo{input_node_name, index, dst_node->second, dst_index, node_proto.name()}; - const auto src_node = node_map.find(input_node_name); - if (src_node == node_map.end()) { - REPORT_INNER_ERR_MSG("E18888", "find src node: %s failed", input_node_name.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] find src node: %s failed", input_node_name.c_str()); - return false; - } - const auto node_ptr = src_node->second; - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "src node: %s is nullptr", input_node_name.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] src node: %s is nullptr", input_node_name.c_str()); - return false; - } - if (!DecodeNodeLinkImp(item, node_ptr)) { - GELOGE(GRAPH_FAILED, "[Invoke][DecodeNodeLinkImp] failed, node: %s", input_node_name.c_str()); - return false; - } - } - if (index >= 0) { - dst_index++; - } - } - } - return true; -} - -void OnnxUtils::DecodeAttribute(const ge::onnx::AttributeProto &attr_proto, std::vector &strings) { - if (attr_proto.type() != ge::onnx::AttributeProto_AttributeType_STRINGS) { - REPORT_INNER_ERR_MSG("E18888", "Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); - return; - } - for (int32_t i = 0; i < attr_proto.strings_size(); i++) { - strings.push_back(attr_proto.strings(i)); - } -} - -void OnnxUtils::DecodeAttribute(const ge::onnx::AttributeProto &attr_proto, std::string &value) { - if (attr_proto.type() != ge::onnx::AttributeProto_AttributeType_STRING) { - REPORT_INNER_ERR_MSG("E18888", "Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); - return; - } - value = attr_proto.s(); -} - -void OnnxUtils::DecodeAttribute(const ge::onnx::AttributeProto &attr_proto, std::vector &ints) { - if (attr_proto.type() != ge::onnx::AttributeProto_AttributeType_INTS) { - REPORT_INNER_ERR_MSG("E18888", "Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); - return; - } - for (int32_t i = 0; i < attr_proto.ints_size(); i++) { - ints.push_back(attr_proto.ints(i)); - } -} - -void OnnxUtils::DecodeAttribute(const ge::onnx::AttributeProto &attr_proto, int64_t &value) { - if (attr_proto.type() != ge::onnx::AttributeProto_AttributeType_INT) { - REPORT_INNER_ERR_MSG("E18888", "Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Attribute %s call wrong decode attribute function", attr_proto.name().c_str()); - return; - } - value = attr_proto.i(); -} - -void OnnxUtils::DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto, - const std::string &attr_name_for_input_desc, - const int32_t idx, - const OpDescPtr &op_desc) { - const auto tensor_desc = op_desc->MutableInputDesc(static_cast(idx)); - if ((tensor_desc == nullptr) || (tensor_desc->impl_ == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "MutableInputDesc index:%d return nullptr, op:%s, attr:%s", idx, - op_desc->GetName().c_str(), attr_name_for_input_desc.c_str()); - GELOGE(GRAPH_FAILED, "[Invoke][MutableInputDesc] index:%d return nullptr, op name %s, attr name %s", - idx, op_desc->GetName().c_str(), attr_name_for_input_desc.c_str()); - return; - } - if (attr_name_for_input_desc == "input_desc_dtype") { - const auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s()); - tensor_desc->SetDataType(data_type); - } else if (attr_name_for_input_desc == "input_desc_shape") { - std::vector ints; - DecodeAttribute(attr_proto, ints); - const GeShape ge_shape(ints); - tensor_desc->SetShape(ge_shape); - } else if (attr_name_for_input_desc == "input_desc_layout") { - const auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); - tensor_desc->SetFormat(data_format); - } else if (attr_name_for_input_desc == "input_desc_origin_shape") { - std::vector ints; - DecodeAttribute(attr_proto, ints); - const GeShape ge_shape(ints); - tensor_desc->SetOriginShape(ge_shape); - } else if (attr_name_for_input_desc == "input_desc_origin_layout") { - const auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); - tensor_desc->SetOriginFormat(data_format); - } else if (attr_name_for_input_desc == "input_desc_size") { - int64_t input_size = 0; - DecodeAttribute(attr_proto, input_size); - tensor_desc->impl_->ext_meta_.SetSize(input_size); - } else if (attr_name_for_input_desc == "input_desc_data_offset") { - int64_t offset = 0; - DecodeAttribute(attr_proto, offset); - tensor_desc->impl_->ext_meta_.SetDataOffset(offset); - } else { - return; - } -} - -void OnnxUtils::DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto, - const std::string &attr_name_for_output_desc, - const int32_t index, const OpDescPtr &op_desc) { - const auto tensor_desc = op_desc->MutableOutputDesc(static_cast(index)); - if ((tensor_desc == nullptr) || (tensor_desc->impl_ == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "MutableOutputDesc index:%d return nullptr, op:%s, attr:%s", index, - op_desc->GetName().c_str(), attr_name_for_output_desc.c_str()); - GELOGE(GRAPH_FAILED, "[Invoke][MutableOutputDesc] index:%d return nullptr, op name %s, attr name %s", - index, op_desc->GetName().c_str(), attr_name_for_output_desc.c_str()); - return; - } - if (attr_name_for_output_desc == "output_desc_dtype") { - const auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s()); - tensor_desc->SetDataType(data_type); - } else if (attr_name_for_output_desc == "output_desc_shape") { - std::vector ints; - DecodeAttribute(attr_proto, ints); - const GeShape ge_shape(ints); - tensor_desc->SetShape(ge_shape); - } else if (attr_name_for_output_desc == "output_desc_layout") { - const auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); - tensor_desc->SetFormat(data_format); - } else if (attr_name_for_output_desc == "output_desc_origin_shape") { - std::vector ints; - DecodeAttribute(attr_proto, ints); - const GeShape ge_shape(ints); - tensor_desc->SetOriginShape(ge_shape); - } else if (attr_name_for_output_desc == "output_desc_origin_layout") { - const auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); - tensor_desc->SetOriginFormat(data_format); - } else if (attr_name_for_output_desc == "output_desc_size") { - int64_t output_size = 0; - DecodeAttribute(attr_proto, output_size); - tensor_desc->impl_->ext_meta_.SetSize(output_size); - } else if (attr_name_for_output_desc == "output_desc_data_offset") { - int64_t offset = 0; - DecodeAttribute(attr_proto, offset); - tensor_desc->impl_->ext_meta_.SetDataOffset(offset); - } else { - return; - } -} - -void OnnxUtils::DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto, - const std::string &attr_name_for_input_output_desc, - const int32_t idx, - const OpDescPtr &op_desc) { - if (op_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param op_desc is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] op_desc is nullptr"); - return; - } - if (attr_name_for_input_output_desc.substr(0U, kInputPrefixLength) == "input") { - DecodeNodeAttributeForOpInDesc(attr_proto, attr_name_for_input_output_desc, idx, op_desc); - } else if (attr_name_for_input_output_desc.substr(0U, kOutputPrefixLength) == "output") { - DecodeNodeAttributeForOpOutDesc(attr_proto, attr_name_for_input_output_desc, idx, op_desc); - } else { - return; - } -} - -void OnnxUtils::DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc) { - if ((op_desc == nullptr) || (op_desc->impl_ == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param op_desc is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] DecodeNodeAttributeForOpDesc: op_desc is nullptr"); - return; - } - const auto &attr_name = attr_proto.name(); - std::string attr_name_for_input_output_desc; - int32_t index = 0; - if (!ParseNameAndIndex(attr_name, attr_name_for_input_output_desc, index)) { - if (attr_name == "id") { - op_desc->SetId(attr_proto.i()); - } else if (attr_name == "stream_id") { - op_desc->SetStreamId(attr_proto.i()); - } else if (attr_name == "src_name") { - std::vector strings; - DecodeAttribute(attr_proto, strings); - op_desc->SetSrcName(strings); - } else if (attr_name == "dst_name") { - std::vector strings; - DecodeAttribute(attr_proto, strings); - op_desc->SetDstName(strings); - } else if (attr_name == "src_index") { - std::vector ints; - DecodeAttribute(attr_proto, ints); - op_desc->SetSrcIndex(ints); - } else if (attr_name == "dst_index") { - std::vector ints; - DecodeAttribute(attr_proto, ints); - op_desc->SetDstIndex(ints); - } else if (attr_name == "fusion_scope") { - int64_t val = 0; - DecodeAttribute(attr_proto, val); - AnyValue av; - (void)av.SetValue(val); - (void)op_desc->SetAttr(attr_proto.name(), av); - } else if (attr_name == "input_i") { - std::vector ints; - DecodeAttribute(attr_proto, ints); - op_desc->SetInputOffset(ints); - } else if (attr_name == "output_i") { - std::vector ints; - DecodeAttribute(attr_proto, ints); - op_desc->SetOutputOffset(ints); - } else { - return; - } - // Update input and output desc - } else { - DecodeNodeAttributeForOpInAndOutDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc); - } -} - -bool OnnxUtils::DecodeNodeDesc(const onnx::NodeProto *const node_proto, OpDescPtr &op_desc) { - if ((op_desc == nullptr) || (node_proto == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param op_desc or node_proto is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Op_desc is nullptr or node_proto is nullptr"); - return false; - } - // 1. Decode node_proto name and type - op_desc->SetName(node_proto->name()); - const auto &node_type_with_ge_prefix = node_proto->op_type(); - const auto sep = node_type_with_ge_prefix.find(':'); - if (sep == std::string::npos) { - return false; - } - const auto node_type = node_type_with_ge_prefix.substr(sep + 1U); - op_desc->SetType(node_type); - // 2. Add empty input and output desc - for (const auto &attr : node_proto->attribute()) { - if (attr.name() == "input_desc_nums") { - const auto size_in = attr.i(); - for (int64_t i = 0; i < size_in; i++) { - const GeTensorDesc ge_tensor_desc; - GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(ge_tensor_desc) == GRAPH_SUCCESS, continue, "Add inputdesc failed."); - } - } - if (attr.name() == "output_desc_nums") { - const auto size_out = attr.i(); - for (int64_t i = 0; i < size_out; i++) { - const GeTensorDesc ge_tensor_desc; - GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(ge_tensor_desc) == GRAPH_SUCCESS, continue, "Add outputdesc failed."); - } - } - } - // 3.Decode node_proto attributes - for (decltype(node_proto->attribute_size()) i = 0; i < node_proto->attribute_size(); i++) { - DecodeNodeAttributeForOpDesc(node_proto->attribute(i), op_desc); - } - return true; -} - -bool OnnxUtils::AddInputAndOutputNodesForGraph(const onnx::GraphProto &graph_proto, - ComputeGraphPtr &graph, - const std::map &node_map) { - // Add inputs nodes for graph - for (const auto &input : graph_proto.input()) { - const auto &input_node_name = input.name(); - const auto input_node_item = node_map.find(input_node_name); - if (input_node_item == node_map.end()) { - REPORT_INNER_ERR_MSG("E18888", "cannot find graph's input node %s in node_", input_node_name.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] cannot find graph's input node %s in node_", input_node_name.c_str()); - return false; - } - const auto ret = graph->AddInputNode(input_node_item->second); - GE_CHK_BOOL_EXEC(ret != nullptr, continue, - "[Add][InputNode] %s failed, graph:%s", input_node_name.c_str(), graph->GetName().c_str()); - } - // Add outputs nodes for graph - for (const auto &output : graph_proto.output()) { - const auto &output_name = output.name(); - const auto output_node_item = node_map.find(output_name); - if (output_node_item == node_map.end()) { - REPORT_INNER_ERR_MSG("E18888", "cannot find graph's output node %s in node_", output_name.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] cannot find graph's output node %s in node_", output_name.c_str()); - return false; - } - const auto ret = graph->AddOutputNode(output_node_item->second); - if (ret == nullptr) { - GELOGW("[Decode][Graph] Add output node %s failed", output_name.c_str()); - continue; - } - } - return true; -} - -bool OnnxUtils::DecodeGraph(const int32_t recursion_depth, - const onnx::GraphProto &graph_proto, ComputeGraphPtr &graph) { - if (recursion_depth > kMaxRecursiveDepth) { - REPORT_INNER_ERR_MSG("E18888", "param recursion_depth:%d is bigger than kMaxRecursiveDepth:%d", recursion_depth, - kMaxRecursiveDepth); - GELOGE(GRAPH_FAILED, "[Check][Param] DecodeGraph: recursion depth is too large, abort"); - return false; - } - - graph = ComGraphMakeShared(graph_proto.name()); - GE_CHK_BOOL_EXEC(graph != nullptr, - REPORT_INNER_ERR_MSG("E18888", "create ComputeGraph failed."); - return false, "[Create][ComputeGraph]ComputeGraph make shared failed"); - /// 1. Decode all nodes first, node should include input - /// and output nodes and nodes which represent sub graphs - std::map node_map; - std::vector node_proto_vector; - for (const auto &node_proto : graph_proto.node()) { - // a. nodes represent sub graphs - if (node_proto.op_type() == kNodeTypeForSubgraph) { - ComputeGraphPtr compute_graph; - // in this case, node only have one attr, whose type is AttributeProto_AttributeType_GRAPH - const auto &node_attr = node_proto.attribute(0); - if ((node_attr.type() == onnx::AttributeProto_AttributeType_GRAPH) && - (DecodeGraph(recursion_depth + 1, node_attr.g(), compute_graph))) { - (void)graph->AddSubGraph(compute_graph); - } else { - REPORT_INNER_ERR_MSG("E18888", "Decode sub graph %s failed with node type:%d", node_proto.name().c_str(), - node_attr.type()); - GELOGE(GRAPH_FAILED, "[Check][Param] Decode sub graph %s failed with node type:%d", node_proto.name().c_str(), - node_attr.type()); - return false; - } - // b. direct nodes in graph - } else { - node_proto_vector.push_back(node_proto); - OpDescPtr op_desc = ComGraphMakeShared(); - // b.1 For node desc - if (!DecodeNodeDesc(&node_proto, op_desc)) { - GELOGE(GRAPH_FAILED, "[Decode][NodeDesc] %s failed ", node_proto.name().c_str()); - return false; - } - auto node = graph->AddNode(op_desc); - (void)node_map.insert(std::make_pair(node_proto.name(), node)); - } - } - /// We get all nodes in graph here - /// b.2 For node link - if (!DecodeNodeLink(node_proto_vector, node_map)) { - GELOGE(GRAPH_FAILED, "[Decode][NodeLink] failed"); - return false; - } - - return AddInputAndOutputNodesForGraph(graph_proto, graph, node_map); -} - -bool OnnxUtils::ConvertGeModelToModelProto(const Model &model, onnx::ModelProto &model_proto, DumpLevel dump_level) { - dump_level_ = dump_level; - GELOGD("DumpGEGraphToOnnx with dump_ge_graph_level %" PRId32 ".", dump_level_); - model_proto.set_model_version(static_cast(model.GetVersion())); - model_proto.set_ir_version(onnx::IR_VERSION); - model_proto.set_producer_name(model.GetName()); - const auto compute_graph = model.graph_; - if (compute_graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "GetComputeGraph for model return nullptr."); - GELOGE(GRAPH_FAILED, "[Invoke][GetComputeGraph] return nullptr"); - return false; - } - const auto graph_proto = model_proto.mutable_graph(); - if (graph_proto == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "mutable_graph return nullptr, graph:%s", compute_graph->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Invoke][MutableGraph] return nullptr, graph:%s", compute_graph->GetName().c_str()); - return false; - } - if (!EncodeGraph(compute_graph, graph_proto)) { - GELOGE(GRAPH_FAILED, "[Invoke][EncodeGraph] fail, graph:%s", compute_graph->GetName().c_str()); - return false; - } - graph_proto->clear_input(); - - // For subgraphs: a subgraph is represented by a node - for (const auto &sub_compute_graph : compute_graph->GetAllSubgraphs()) { - if (sub_compute_graph == nullptr) { - GELOGW("[Convert][GeModel] Graph %s subgraph is nullptr, skip EncodeGraph", compute_graph->GetName().c_str()); - continue; - } - const auto node_proto = graph_proto->add_node(); - if (node_proto == nullptr) { - GELOGW("[Convert][GeModel] Add node failed"); - continue; - } - node_proto->set_name(sub_compute_graph->GetName()); - node_proto->set_op_type(kNodeTypeForSubgraph); - const auto attr = node_proto->add_attribute(); - attr->set_name("graph"); - attr->set_type(onnx::AttributeProto_AttributeType_GRAPH); - const auto sub_graph_proto = attr->mutable_g(); - if (sub_graph_proto == nullptr) { - GELOGW("[Convert][GeModel] Sub graph proto is nullptr"); - continue; - } - if (!EncodeGraph(sub_compute_graph, sub_graph_proto)) { - GELOGW("[Convert][GeModel] Encode sub graph %s failed", sub_compute_graph->GetName().c_str()); - continue; - } - } - return true; -} -} // namespace ge diff --git a/graph/utils/ge_ir_utils.h b/graph/utils/ge_ir_utils.h deleted file mode 100644 index 062132c3088953940d37ca2ae92f61a09993b42f..0000000000000000000000000000000000000000 --- a/graph/utils/ge_ir_utils.h +++ /dev/null @@ -1,262 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ -#define COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "normal_graph/ge_tensor_impl.h" - -#include -#include -#include -#include -#include -#include -#include "nlohmann/json.hpp" - -#include "proto/ge_ir.pb.h" -#include "proto/onnx/ge_onnx.pb.h" - -namespace ge { -/// -/// @ingroup ge_ir_utils -/// @brief check, if not equal, log with tag -/// @param [in] const left_value, right_value reference, log_info_tag -/// @return bool -/// -template -bool IsEqual(const T &l_value, const T &r_value, const std::string &log_info_tag) { - if ((l_value == r_value)) { - return true; - } else { - GELOGD("Check not equal with %s", log_info_tag.c_str()); - return false; - } -} - -class OnnxUtils { - public: - static bool ConvertGeModelToModelProto(const ge::Model &model, ge::onnx::ModelProto &model_proto); - - static bool ConvertGeModelToModelProto(const ge::Model &model, ge::onnx::ModelProto &model_proto, DumpLevel dump_level); - private: - // Part 1: from IR convert to ONNX Protobuf - static void AddAttrProto(ge::onnx::NodeProto *const node_proto, const ge::onnx::AttributeProto_AttributeType type, - const std::string &name, const void *const data); - - static void AddAttrProto(ge::onnx::NodeProto *const node_proto, const ge::onnx::AttributeProto_AttributeType type, - const std::string &name, - const ::google::protobuf::RepeatedField<::google::protobuf::int64> data); - - static void AddAttrProto(ge::onnx::NodeProto *const node_proto, - const ge::onnx::AttributeProto_AttributeType type, - const std::string &name, const ::google::protobuf::RepeatedField data); - - static void AddAttrProto(ge::onnx::NodeProto *const node_proto, const ge::onnx::AttributeProto_AttributeType type, - const std::string &name, const ::google::protobuf::RepeatedField data); - - static void AddAttrProto(ge::onnx::NodeProto *const node_proto, - const ge::onnx::AttributeProto_AttributeType type, - const std::string &name, const ::google::protobuf::RepeatedPtrField<::std::string> data); - - static void AddListAttrProto(const std::string &attr_name, const ::ge::proto::AttrDef &attr_def, - const std::string &prefix, const std::string &suffix, onnx::NodeProto *node_proto); - - static void AddAttrProtoFromNodeMembers(const NodePtr &node, ge::onnx::NodeProto *const node_proto); - - static void AddAttrProtoFromAttribute(const std::pair &string_attr_value, - ge::onnx::NodeProto *const node_proto); - - static void AddAttrProtoForOpInDesc(onnx::NodeProto *const node_proto, const OpDescPtr &op_desc); - - static void AddAttrProtoForOpOutDesc(onnx::NodeProto *const node_proto, const OpDescPtr &op_desc); - - static void AddAttrProtoForOpInAndOutDesc(ge::onnx::NodeProto *const node_proto, const OpDescPtr &op_desc); - - static void AddAttrProtoForAttrsFromAttrMap(const ::google::protobuf::Map &attr_map, - ge::onnx::NodeProto *const node_proto, - const std::string &prefix = "", - const std::string &suffix = ""); - - static ge::onnx::TensorProto_DataType EncodeDataType(const ge::DataType data_type); - - static void EncodeNodeLinkForNetronVisual(const NodePtr &node, ge::onnx::NodeProto *const node_proto); - - static bool EncodeNodeLink(const NodePtr &node, ge::onnx::NodeProto *const node_proto); - - static bool EncodeNodeDesc(const NodePtr &node, ge::onnx::NodeProto *const node_proto); - - static bool EncodeNode(const NodePtr &node, ge::onnx::NodeProto *const node_proto); - - static void EncodeTypeProtoTensorType(const NodePtr &node, ge::onnx::TypeProto_Tensor *const tensor_type); - - static void EncodeValueInfo(const NodePtr &node, ge::onnx::ValueInfoProto *const value_info_proto); - - static bool EncodeGraph(const ConstComputeGraphPtr &graph, ge::onnx::GraphProto *const graph_proto); - - /// Part 2: from ONNX Protobuf convert to IR - /// Describes node's link relationships - class NodeLinkInfo { - public: - NodeLinkInfo() = default; - ~NodeLinkInfo() = default; - NodeLinkInfo(std::string src_name, - int32_t src_out_index, - NodePtr dst_node, - int32_t dst_in_index, - std::string dst_name) : - src_node_name_(std::move(src_name)), - src_out_index_(src_out_index), - dst_node_(std::move(dst_node)), - dst_in_index_(dst_in_index), - dst_node_name_(std::move(dst_name)) {} - - std::string GetSrcNodeName() const { return src_node_name_; }; - int32_t GetSrcOutIndex() const { return src_out_index_; }; - NodePtr GetDstNode() const { return dst_node_; }; - int32_t GetDstInIndex() const { return dst_in_index_; }; - std::string GetDstNodeName() const { return dst_node_name_; }; - - private: - std::string src_node_name_; - int32_t src_out_index_; - NodePtr dst_node_; - int32_t dst_in_index_; - std::string dst_node_name_; - }; - struct TensorDescToOnnxAttrHandler { - std::string name; - onnx::AttributeProto_AttributeType attr_type; - using FuncCase0 = int64_t(*)(const GeTensorDescImpl::ExtMeta &); - using FuncCase1 = std::string(*)(const GeTensorDescImpl::ExtMeta &); - using FuncCase2 = std::vector(*)(const ConstGeTensorDescPtr &); - using FuncCase3 = std::string(*)(const ConstGeTensorDescPtr &); - union { - FuncCase0 ext_meta_int_getter{nullptr}; - FuncCase1 ext_meta_str_getter; - FuncCase2 member_ints_getter; - FuncCase3 member_str_getter; - }; - TensorDescToOnnxAttrHandler(std::string s, - onnx::AttributeProto_AttributeType t, - FuncCase3 func) : name(std::move(s)), attr_type(t), member_str_getter(func) {}; - TensorDescToOnnxAttrHandler(std::string s, - onnx::AttributeProto_AttributeType t, - FuncCase2 func) : name(std::move(s)), attr_type(t), member_ints_getter(func) {}; - TensorDescToOnnxAttrHandler(std::string s, - onnx::AttributeProto_AttributeType t, - FuncCase1 func) : name(std::move(s)), attr_type(t), ext_meta_str_getter(func) {}; - TensorDescToOnnxAttrHandler(std::string s, - onnx::AttributeProto_AttributeType t, - FuncCase0 func) : name(std::move(s)), attr_type(t), ext_meta_int_getter(func) {}; - }; - using TensordescAttrHandlers = std::vector; - // Parse node name and index - static bool ParseNameAndIndex(const std::string &node_name_index, std::string &node_name, int32_t &idx); - - static void DecodeAttribute(const ge::onnx::AttributeProto &attr_proto, std::vector &strings); - - static void DecodeAttribute(const ge::onnx::AttributeProto &attr_proto, std::vector &ints); - - static void DecodeAttribute(const ge::onnx::AttributeProto &attr_proto, int64_t &value); - - static void DecodeAttribute(const ge::onnx::AttributeProto &attr_proto, std::string &value); - - static void DecodeNodeAttributeForOpOutDesc(const ge::onnx::AttributeProto &attr_proto, - const std::string &attr_name_for_output_desc, - const int32_t index, const OpDescPtr &op_desc); - - static void DecodeNodeAttributeForOpInDesc(const ge::onnx::AttributeProto &attr_proto, - const std::string &attr_name_for_input_desc, - const int32_t idx, - const OpDescPtr &op_desc); - - static void DecodeNodeAttributeForOpInAndOutDesc(const ge::onnx::AttributeProto &attr_proto, - const std::string &attr_name_for_input_output_desc, - const int32_t idx, - const OpDescPtr &op_desc); - - static void DecodeNodeAttributeForOpDesc(const ge::onnx::AttributeProto &attr_proto, OpDescPtr &op_desc); - - static bool DecodeNodeLinkImp(const NodeLinkInfo &item, const NodePtr &node_ptr); - - static bool DecodeNodeLink(const std::vector &node_proto_vector, - const std::map &node_map); - - static bool DecodeNodeDesc(const ge::onnx::NodeProto *const node_proto, OpDescPtr &op_desc); - - static bool DecodeGraph(const int32_t recursion_depth, - const ge::onnx::GraphProto &graph_proto, ComputeGraphPtr &graph); - - static void AddShapeFormatAndDtypeToJson(const ge::ConstGeTensorDescPtr &desc, nlohmann::json &tensor_json); - - static void AddShapeFormatAndDtypeToProto(const ge::ConstGeTensorDescPtr &desc, - const std::string &prefix, - const uint32_t idx, - onnx::NodeProto *const node_proto); - - static void AddAllAttrToJson(const ConstGeTensorDescPtr &tensor_desc, nlohmann::json &tensor_json); - - static void AddAllAttrToProto(onnx::NodeProto *const node_proto, const ConstGeTensorDescPtr &tensor_desc, - const char_t *const prefix, const uint32_t idx); - - static void AddAllAttrGroupToJson(const ConstGeTensorDescPtr &tensor_desc, nlohmann::json &tensor_json); - - static void AddAllAttrGroupToProto(onnx::NodeProto *const node_proto, const ConstGeTensorDescPtr &tensor_desc, - const char_t *const prefix, const uint32_t idx); - - static void AddCommonAttrIntoProto(onnx::NodeProto *const node_proto, const OpDescPtr &op_desc); - static void AddCommonAttrGroupIntoProto(const OpDescPtr &op_desc, onnx::NodeProto *const node_proto); - - static bool AddInputAndOutputNodesForGraph(const onnx::GraphProto &graph_proto, - ComputeGraphPtr &graph, - const std::map &node_map); - template - static void ProcessTensorDescImpl(const OpDescPtr &op_desc, - const string &desc_type, - DescGetter desc_getter, - onnx::NodeProto *node_proto); - static void AddExtMetaToJson(const GeTensorDescImpl::ExtMeta &tensor_descriptor, nlohmann::json &tensor_json); - static void AddExtMetaToProto(const GeTensorDescImpl::ExtMeta &tensor_descriptor, - const std::string &prefix, - uint32_t index, - onnx::NodeProto *node_proto); - template - static void AddJson(const std::string &name, nlohmann::json &json_holder, const T &json_obj) { - try { - json_holder[name] = json_obj; - } - catch (const std::exception &e) { - GELOGW("Failed to init json object, err = %s, name = %s", e.what(), name.c_str()); - return; - } - } - static DumpLevel dump_level_; - static const TensordescAttrHandlers ext_meta_attr_handlers_; - static const TensordescAttrHandlers normal_member_attr_handlers_; -}; -} // namespace ge - -#endif // COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ diff --git a/graph/utils/graph_thread_pool.cc b/graph/utils/graph_thread_pool.cc deleted file mode 100644 index a20eb8c4994f45927d106d3a6303f3655341da78..0000000000000000000000000000000000000000 --- a/graph/utils/graph_thread_pool.cc +++ /dev/null @@ -1,70 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/utils/graph_thread_pool.h" - -#include -#include -#include - -#include "register/register_types.h" -#include "graph/ge_context.h" -#include "mmpa/mmpa_api.h" - -namespace ge { -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GraphThreadPool::GraphThreadPool(const uint32_t size) - : is_stoped_(false) { - idle_thrd_num_ = (size < 1U) ? 1U : size; - - for (uint32_t i = 0U; i < idle_thrd_num_; ++i) { - pool_.emplace_back(&ThreadFunc, this); - } -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GraphThreadPool::~GraphThreadPool() { - is_stoped_.store(true); - { - const std::unique_lock lock{m_lock_}; - cond_var_.notify_all(); - } - - for (std::thread &thd : pool_) { - if (thd.joinable()) { - try { - thd.join(); - } catch (...) { - GELOGW("exception"); - } - } - } -} - -void GraphThreadPool::ThreadFunc(GraphThreadPool *const thread_pool) { - if (thread_pool == nullptr) { - return; - } - GELOGI("Thread started success"); - while (!thread_pool->is_stoped_) { - std::function task; - { - std::unique_lock lock{thread_pool->m_lock_}; - thread_pool->cond_var_.wait( - lock, [thread_pool]() -> bool { return thread_pool->is_stoped_.load() || (!thread_pool->tasks_.empty()); }); - if (thread_pool->is_stoped_ && thread_pool->tasks_.empty()) { - return; - } - task = std::move(thread_pool->tasks_.front()); - thread_pool->tasks_.pop(); - } - --thread_pool->idle_thrd_num_; - task(); - ++thread_pool->idle_thrd_num_; - } -} -} // namespace ge diff --git a/graph/utils/graph_utils.cc b/graph/utils/graph_utils.cc deleted file mode 100644 index 1f46b7a4c20e1e5a9fe0bdb9b7e08e5462a9b5c0..0000000000000000000000000000000000000000 --- a/graph/utils/graph_utils.cc +++ /dev/null @@ -1,5119 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/utils/graph_utils.h" - -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -#include "graph/ge_context.h" -#include "graph/debug/ge_util.h" -#include "graph/ge_local_context.h" -#include "proto/ge_ir.pb.h" -#include "graph/utils/file_utils.h" -#include "graph/utils/ge_ir_utils.h" -#include "graph/utils/node_utils.h" -#include "graph/utils/attr_utils.h" -#include "graph/utils/dumper/ge_graph_dumper.h" -#include "graph/debug/ge_op_types.h" -#include "external/ge_common/ge_api_types.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/tensor_utils.h" -#include "graph/detail/model_serialize_imp.h" -#include "graph/normal_graph/compute_graph_impl.h" -#include "graph/normal_graph/op_desc_impl.h" -#include "mmpa/mmpa_api.h" -#include "common/checker.h" -#include "graph/utils/op_type_utils.h" -#include "graph/utils/constant_utils.h" -#include "external/utils/extern_math_util.h" -#include "ge_dump_graph_whitelist.h" - -namespace ge { -enum class DumpGraphLevel { - kDumpLevel1 = 1, - kDumpLevel2, - kDumpLevel3, - kDumpLevel4, - kDumpLevelOther, -}; - -namespace { -const int32_t kBaseOfIntegerValue = 10; -#ifdef FMK_SUPPORT_DUMP -const int32_t kDumpGraphIndexWidth = 8; -#endif - -const char_t *const kDumpStrBuild = "Build"; -const char_t *const kDumpStrPreRunBegin = "PreRunBegin"; -const char_t *const kDumpStrPartition = "partition"; -const char_t *const kDumpStrOptimizeSubgraph = "OptimizeSubGraph"; -const char_t *const kDumpStrSubgraphFunc = "sub_graph"; -const char_t *const kDumpStrAicpu = "Aicpu"; -const char_t *const kOriginName4Recover = "_origin_name_4_recover"; -const char_t *const kOriginType4Recover = "_origin_type_4_recover"; -const char_t *const kLocation4Recover = "_location_4_recover"; -const char_t *const kLength4Recover = "_length_4_recover"; -const size_t kNameMax = 255U; -const int32_t kCopyGraphMaxRecursionDepth = 10; -const int32_t kNameWidth = 5; -const uint32_t kSubgraphIndexOfPartitionedCall = 0U; -const std::set kMergeInputSkipTypes{ STREAMACTIVE, STREAMSWITCH, CONSTANT, CONSTANTOP }; -constexpr int32_t kInvalidStream = -1; -constexpr size_t kNoOpOptimizeThreshold = 1000UL; -const std::string kSuperKernelScope = "_super_kernel_scope"; -const std::string kSuperKernelOptions = "_super_kernel_options"; -const std::vector kNecessaryStrAttrWhitelist = { - public_attr::USER_STREAM_LABEL, public_attr::OP_AI_CORE_NUM, public_attr::OP_VECTOR_CORE_NUM, kSuperKernelScope, - kSuperKernelOptions}; - -Status InheritAttr(const OpDescPtr &node_op_desc, const OpDescPtr &insert_op_desc) { - GE_ASSERT_NOTNULL(node_op_desc); - for (const auto &attr : kNecessaryStrAttrWhitelist) { - std::string attr_val; - if (AttrUtils::GetStr(node_op_desc, attr, attr_val)) { - GE_ASSERT_NOTNULL(insert_op_desc); - GE_ASSERT_TRUE(AttrUtils::SetStr(insert_op_desc, attr, attr_val)); - } - } - return SUCCESS; -} - -graphStatus ReLinkInputDataEdge(const NodePtr &input_node, - const NodePtr &target_node) { - GE_ASSERT_TRUE(input_node->GetType() == DATA, "Input node: %s should be Data", - input_node->GetNamePtr()); - int32_t index = -1; - (void)AttrUtils::GetInt(input_node->GetOpDesc(), ATTR_NAME_INDEX, index); - GE_ASSERT_TRUE(index >= 0, - "Attr index[%d] of node: %s is invalid", index, input_node->GetNamePtr()); - GE_ASSERT_TRUE(index < static_cast(target_node->GetAllInDataAnchorsSize()), - "Attr index[%d] of node: %s cannot larger than input num: %u of target node: %s", - index, input_node->GetNamePtr(), target_node->GetAllInDataAnchorsSize(), target_node->GetNamePtr()); - GELOGD("Begin to handle subgraph input node:%s with index:%d.", input_node->GetName().c_str(), index); - // get node's in data anchor and peer out anchor - auto node_in_anchor = target_node->GetInDataAnchor(index); - GE_ASSERT_NOTNULL(node_in_anchor); - auto src_out_anchor = node_in_anchor->GetPeerOutAnchor(); - GE_ASSERT_NOTNULL(src_out_anchor); - auto data_out_anchor = input_node->GetOutDataAnchor(0); - GE_ASSERT_NOTNULL(data_out_anchor); - GE_ASSERT_NOTNULL(src_out_anchor->GetOwnerNode()); - GE_ASSERT_SUCCESS(GraphUtils::RemoveEdge(src_out_anchor, node_in_anchor), - "Remove edge from %s to %s failed.", src_out_anchor->GetOwnerNode()->GetNamePtr(), - target_node->GetNamePtr()); - auto node_in_control_anchor = target_node->GetInControlAnchor(); - GE_ASSERT_NOTNULL(node_in_control_anchor); - for (const auto &peer_in_anchor : data_out_anchor->GetPeerInDataAnchors()) { - GE_ASSERT_NOTNULL(peer_in_anchor->GetOwnerNode()); - GE_ASSERT_SUCCESS(GraphUtils::ReplaceEdgeSrc(data_out_anchor, peer_in_anchor, src_out_anchor), - "Replace src: %s from dst: %s to src: %s failed", input_node->GetNamePtr(), - peer_in_anchor->GetOwnerNode()->GetNamePtr(), src_out_anchor->GetOwnerNode()->GetNamePtr()); - // add control edge - for (const auto &out_anchor : node_in_control_anchor->GetPeerOutControlAnchors()) { - const auto peer_in_anchor_node = peer_in_anchor->GetOwnerNode(); - GE_ASSERT_NOTNULL(peer_in_anchor_node); - GE_ASSERT_NOTNULL(out_anchor->GetOwnerNode()); - GE_ASSERT_SUCCESS(GraphUtils::AddEdge(out_anchor, peer_in_anchor_node->GetInControlAnchor()), - "Add control edge from %s to %s failed.", out_anchor->GetOwnerNode()->GetNamePtr(), - peer_in_anchor_node->GetNamePtr()); - } - } - return SUCCESS; -} - -graphStatus RelinkOutputNodeEdge(const NodePtr &out_node, - const int32_t out_index, const NodePtr &target_node, const size_t target_index) { - // 处理输出算子的连边关系 - GE_ASSERT_TRUE(target_index < static_cast(target_node->GetAllOutDataAnchorsSize()), - "Attr index[%d] of node: %s cannot larger than input num: %u of target node: %s", - target_index, out_node->GetNamePtr(), target_node->GetAllOutDataAnchorsSize(), target_node->GetNamePtr()); - auto node_out_anchor = target_node->GetOutDataAnchor(target_index); - GE_ASSERT_NOTNULL(node_out_anchor, "Get index: %zu of node: %s failed", - target_index, target_node->GetNamePtr()); - GE_ASSERT_NOTNULL(out_node); - GELOGD("Begin to handle subgraph output node:%s output:%d with index:%d of node: %s.", - out_node->GetNamePtr(), out_index, target_index, target_node->GetNamePtr()); - auto src_out_anchor = out_node->GetOutDataAnchor(out_index); - GE_CHECK_NOTNULL(src_out_anchor); - for (const auto &dst_in_anchor : node_out_anchor->GetPeerInDataAnchors()) { - GE_ASSERT_NOTNULL(dst_in_anchor->GetOwnerNode()); - GE_ASSERT_SUCCESS(GraphUtils::ReplaceEdgeSrc(node_out_anchor, dst_in_anchor, src_out_anchor), - "Replace src: %s from dst: %s to src: %s failed", target_node->GetNamePtr(), - dst_in_anchor->GetOwnerNode()->GetNamePtr(), out_node->GetNamePtr()); - } - auto node_out_control_anchor = target_node->GetOutControlAnchor(); - GE_ASSERT_NOTNULL(node_out_control_anchor); - for (const auto &peer_in_control_anchor : node_out_control_anchor->GetPeerInControlAnchors()) { - GE_ASSERT_NOTNULL(peer_in_control_anchor->GetOwnerNode()); - GE_ASSERT_SUCCESS(GraphUtils::AddEdge(out_node->GetOutControlAnchor(), peer_in_control_anchor), - "Add control edge from %s to %s failed.", out_node->GetNamePtr(), - peer_in_control_anchor->GetOwnerNode()->GetNamePtr()); - } - return SUCCESS; -} -} // namespace - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -graphStatus GraphUtils::GetIndependentCompileGraphs(const ComputeGraphPtr &compute_graph, - std::vector &independent_compile_subgraphs) { - bool is_pipeline_partitioned = false; - (void)AttrUtils::GetBool(*compute_graph, ATTR_NAME_PIPELINE_PARTITIONED, is_pipeline_partitioned); - if (is_pipeline_partitioned) { - for (const auto &node : compute_graph->GetDirectNode()) { - if (node->GetType() == PARTITIONEDCALL) { - auto sub_graph = NodeUtils::GetSubgraph(*node, kSubgraphIndexOfPartitionedCall); - GE_CHECK_NOTNULL(sub_graph); - independent_compile_subgraphs.emplace_back(sub_graph); - } - } - return GRAPH_SUCCESS; - } - independent_compile_subgraphs.emplace_back(compute_graph); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AddEdge(const OutDataAnchorPtr &src, - const InDataAnchorPtr &dst) { - if ((src != nullptr) && (src->LinkTo(dst) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - REPORT_INNER_ERR_MSG("E18888", "addedge failed because param src is nullptr or run LinkTo failed."); - GELOGE(GRAPH_FAILED, "[Add][Edge] Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AddEdge(const AnchorPtr &src, - const AnchorPtr &dst) { - const OutDataAnchorPtr src_data = Anchor::DynamicAnchorCast(src); - const InDataAnchorPtr dst_data = Anchor::DynamicAnchorCast(dst); - const OutControlAnchorPtr src_control = Anchor::DynamicAnchorCast(src); - const InControlAnchorPtr dst_control = Anchor::DynamicAnchorCast(dst); - if ((src_data != nullptr) && (dst_data != nullptr) && (src_data->LinkTo(dst_data) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - if ((src_data != nullptr) && (dst_control != nullptr) && (src_data->LinkTo(dst_control) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - if ((src_control != nullptr) && (dst_control != nullptr) && (src_control->LinkTo(dst_control) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - if ((src_control != nullptr) && (dst_data != nullptr) && (src_control->LinkTo(dst_data) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - REPORT_INNER_ERR_MSG("E18888", "AddEdge failed because src or dst is nullptr or run LinkTo failed."); - GELOGE(GRAPH_FAILED, "[Add][Edge] Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AddEdge(const OutControlAnchorPtr &src, - const InControlAnchorPtr &dst) { - if ((src != nullptr) && (src->LinkTo(dst) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - REPORT_INNER_ERR_MSG("E18888", "AddEdge failed because src is nullptr or run LinkTo failed."); - GELOGE(GRAPH_FAILED, "[Add][Edge] Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AddEdge(const OutDataAnchorPtr &src, - const InControlAnchorPtr &dst) { - if ((src != nullptr) && (src->LinkTo(dst) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - REPORT_INNER_ERR_MSG("E18888", "AddEdge failed because src is nullptr or run LinkTo failed."); - GELOGE(GRAPH_FAILED, "[Add][Edge] Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveEdge(const OutDataAnchorPtr &src, - const InDataAnchorPtr &dst) { - if ((src != nullptr) && (src->Unlink(dst) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - REPORT_INNER_ERR_MSG("E18888", "RemoveEdge failed because src is nullptr or run Unlink failed."); - GELOGE(GRAPH_FAILED, "[Remove][Edge] Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveEdge(const AnchorPtr &src, - const AnchorPtr &dst) { - if ((src != nullptr) && (src->Unlink(dst) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - REPORT_INNER_ERR_MSG("E18888", "RemoveEdge failed because src is nullptr or run Unlink failed."); - GELOGE(GRAPH_FAILED, "[Remove][Edge] Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveEdge(const OutControlAnchorPtr &src, - const InControlAnchorPtr &dst) { - if ((src != nullptr) && (src->Unlink(dst) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - REPORT_INNER_ERR_MSG("E18888", "RemoveEdge failed because src is nullptr or run Unlink failed."); - GELOGE(GRAPH_FAILED, "[Remove][Edge] Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveEdge(const OutDataAnchorPtr &src, - const InControlAnchorPtr &dst) { - if ((src != nullptr) && (src->Unlink(dst) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "[Remove][Edge] Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -graphStatus GraphUtils::ReplaceEdgeSrc(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, - const OutDataAnchorPtr &new_src) { - if ((RemoveEdge(src, dst) == GRAPH_SUCCESS) && (AddEdge(new_src, dst) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "[Replace][EdgeSrc] Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -graphStatus GraphUtils::ReplaceEdgeSrc(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst, - const OutControlAnchorPtr &new_src) { - if ((RemoveEdge(src, dst) == GRAPH_SUCCESS) && (AddEdge(new_src, dst) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "[Replace][EdgeSrc] Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -graphStatus GraphUtils::ReplaceEdgeDst(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, - const InDataAnchorPtr &new_dst) { - if ((RemoveEdge(src, dst) == GRAPH_SUCCESS) && (AddEdge(src, new_dst) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "[Replace][EdgeDst] Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -graphStatus GraphUtils::ReplaceEdgeDst(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst, - const InControlAnchorPtr &new_dst) { - if ((RemoveEdge(src, dst) == GRAPH_SUCCESS) && (AddEdge(src, new_dst) == GRAPH_SUCCESS)) { - return GRAPH_SUCCESS; - } - GELOGE(GRAPH_FAILED, "[Replace][EdgeDst] Failed."); - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::InsertNodeBetweenDataAnchors( - const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, const NodePtr &new_node) { - GE_CHECK_NOTNULL(src); - GE_CHECK_NOTNULL(dst); - GE_CHECK_NOTNULL(new_node); - - const InDataAnchorPtr node_in_anchor = new_node->GetInDataAnchor(0); - GE_CHK_BOOL_RET_STATUS(node_in_anchor != nullptr, GRAPH_FAILED, - "[Invoke][GetInDataAnchor] this node has not inDataAnchor"); - const OutDataAnchorPtr node_out_anchor = new_node->GetOutDataAnchor(0); - GE_CHK_BOOL_RET_STATUS(node_out_anchor != nullptr, GRAPH_FAILED, - "[Invoke][GetOutDataAnchor] this node has not outDataAnchor"); - GE_CHK_STATUS_RET(src->Insert(dst, node_in_anchor, node_out_anchor), "[Replace][Peer] Failed"); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::RemoveSubgraphRecursively(const ComputeGraphPtr &compute_graph, const NodePtr &remove_node) { - GE_CHECK_NOTNULL(compute_graph); - GE_CHECK_NOTNULL(remove_node); - GE_CHECK_NOTNULL(remove_node->GetOpDesc()); - if (remove_node->GetOwnerComputeGraph() == nullptr) { - GELOGW("Node %s has not been setted owner graph.", remove_node->GetName().c_str()); - return GRAPH_SUCCESS; - } - if ((remove_node->GetOwnerComputeGraph() != compute_graph) && - (std::find(compute_graph->impl_->nodes_.begin(), compute_graph->impl_->nodes_.end(), remove_node) == - compute_graph->impl_->nodes_.end())) { - GELOGW("Can not find node %s in graph %s.", remove_node->GetName().c_str(), compute_graph->GetName().c_str()); - return GRAPH_FAILED; - } - if (remove_node->GetOpDesc()->GetSubgraphInstanceNames().empty()) { - GELOGD("Node %s has no subgraph.", remove_node->GetName().c_str()); - return GRAPH_SUCCESS; - } - // Find all subgraph of this node - const auto &root_graph = GraphUtils::FindRootGraph(compute_graph); - std::vector subgraphs; - std::vector all_nodes; - std::deque candidates; - NodePtr remove_node_new = remove_node; - candidates.emplace_back(remove_node_new); - while (!candidates.empty()) { - const NodePtr node = candidates.front(); - all_nodes.emplace_back(node); - candidates.pop_front(); - - const OpDescPtr op_desc = node->GetOpDesc(); - if (op_desc == nullptr) { - continue; - } - - const auto &subgraph_names = op_desc->GetSubgraphInstanceNames(); - for (auto name_iter = subgraph_names.rbegin(); name_iter != subgraph_names.rend(); ++name_iter) { - auto subgraph = root_graph->GetSubgraph(*name_iter); - if ((subgraph != nullptr) && (subgraph->impl_ != nullptr)) { - subgraphs.emplace_back(subgraph); - (void)candidates.insert(candidates.begin(), subgraph->impl_->nodes_.begin(), subgraph->impl_->nodes_.end()); - } - } - } - // Remove all subgraph - for (const auto &remove_graph : subgraphs) { - if (root_graph->RemoveSubGraph(remove_graph) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "RemoveSubGraph failed, sub graph name is %s, compute graph is %s.", - remove_node->GetName().c_str(), compute_graph->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Remove][SubGraph] failed, sub graph name is %s, compute graph is %s.", - remove_node->GetName().c_str(), compute_graph->GetName().c_str()); - return GRAPH_FAILED; - } - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::RemoveNodesByTypeWithoutRelink(const ComputeGraphPtr &compute_graph, const std::string &node_type) { - GE_CHECK_NOTNULL(compute_graph); - GE_CHECK_NOTNULL(compute_graph->impl_); - GELOGI("Start remove %s from graph %s.", node_type.c_str(), compute_graph->GetName().c_str()); - for (auto iter = compute_graph->impl_->input_nodes_.begin(); - iter != compute_graph->impl_->input_nodes_.end();) { - if ((*iter)->GetType() == node_type) { - iter = compute_graph->impl_->input_nodes_.erase(iter); - } else { - iter++; - } - } - - for (auto iter = compute_graph->impl_->output_nodes_info_.begin(); - iter != compute_graph->impl_->output_nodes_info_.end();) { - if (iter->first->GetType() == node_type) { - iter = compute_graph->impl_->output_nodes_info_.erase(iter); - } else { - iter++; - } - } - - for (auto iter = compute_graph->impl_->nodes_.begin(); - iter != compute_graph->impl_->nodes_.end();) { - if ((*iter)->GetType() == node_type) { - if ((node_type != PLACEHOLDER) && (node_type != END)) { - const auto ret = RemoveSubgraphRecursively(compute_graph, (*iter)); - if (ret != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - } - iter = compute_graph->impl_->nodes_.erase(iter); - compute_graph->impl_->direct_nodes_size_--; - } else { - iter++; - } - } - GELOGI("End remove %s from graph %s.", node_type.c_str(), compute_graph->GetName().c_str()); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::RemoveNodeWithoutRelink(const ComputeGraphPtr &compute_graph, const NodePtr &node) { - GE_CHECK_NOTNULL(compute_graph); - GE_CHECK_NOTNULL(compute_graph->impl_); - if (node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param node is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] The node ptr should not be null."); - return GRAPH_FAILED; - } - - // If the node save as input node, delete it - (void)compute_graph->RemoveInputNode(node); - - // If the node save as output node, delete it - (void)compute_graph->RemoveOutputNode(node); - - // If the node has sub-graphs, delete them - const auto ret = RemoveSubgraphRecursively(compute_graph, node); - if (ret != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - - const auto iter = find(compute_graph->impl_->nodes_.begin(), compute_graph->impl_->nodes_.end(), node); - if (iter != compute_graph->impl_->nodes_.end()) { - compute_graph->EraseFromNodeList(iter); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::RemoveNodesWithoutRelink(const ComputeGraphPtr &compute_graph, const std::unordered_set &nodes) { - GE_CHECK_NOTNULL(compute_graph); - GE_CHECK_NOTNULL(compute_graph->impl_); - for (auto iter = compute_graph->impl_->input_nodes_.begin(); iter != compute_graph->impl_->input_nodes_.end();) { - if (nodes.count(*iter) > 0U) { - iter = compute_graph->impl_->input_nodes_.erase(iter); - } else { - iter++; - } - } - - for (auto iter = compute_graph->impl_->output_nodes_info_.begin(); - iter != compute_graph->impl_->output_nodes_info_.end();) { - if (nodes.count((*iter).first) > 0U) { - iter = compute_graph->impl_->output_nodes_info_.erase(iter); - } else { - iter++; - } - } - size_t success_removed_nodes_size = 0U; - for (auto iter = compute_graph->impl_->nodes_.begin(); iter != compute_graph->impl_->nodes_.end();) { - if (nodes.count(*iter) > 0U) { - const auto ret = RemoveSubgraphRecursively(compute_graph, (*iter)); - if (ret != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - GELOGD("Remove %s from graph %s.", (*iter)->GetName().c_str(), compute_graph->GetName().c_str()); - iter = compute_graph->impl_->nodes_.erase(iter); - compute_graph->impl_->direct_nodes_size_--; - success_removed_nodes_size++; - } else { - iter++; - } - } - const auto to_be_remove_nodes_size = nodes.size(); - if (success_removed_nodes_size != to_be_remove_nodes_size) { - GELOGW("Successfully remove %zu nodes but there are %zu nodes to be delete", success_removed_nodes_size, - to_be_remove_nodes_size); - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr GraphUtils::InsertNodeAfter(const OutDataAnchorPtr &src, - const std::vector &dsts, const OpDescPtr &insert_op, - const uint32_t input_index, const uint32_t output_index) { - GE_ASSERT_NOTNULL(src); - const NodePtr src_node = src->GetOwnerNode(); - GE_ASSERT_NOTNULL(src_node); - auto compute_graph = src_node->GetOwnerComputeGraphBarePtr(); - GE_ASSERT_NOTNULL(compute_graph); - auto insert_node = compute_graph->InsertNode(src_node, insert_op); - GE_ASSERT_GRAPH_SUCCESS(GraphUtils::InsertNodeAfter(src, dsts, - insert_node, input_index, output_index)); - return insert_node; -} - -/// @brief Insert node: src->insert_node:input_index, insert_node:output_index->dst -/// @param [in] src -/// @param [in] dsts -/// @param [in] insert_node -/// @param [in] input_index -/// @param [in] output_index -/// @return graphStatus -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::InsertNodeAfter(const OutDataAnchorPtr &src, - const std::vector &dsts, const NodePtr &insert_node, - const uint32_t input_index, const uint32_t output_index) { - GE_CHECK_NOTNULL(src); - GE_CHECK_NOTNULL(insert_node); - - const auto src_node = src->GetOwnerNodeBarePtr(); - GE_CHECK_NOTNULL(src_node); - if (src_node->GetOwnerComputeGraph() != insert_node->GetOwnerComputeGraph()) { - REPORT_INNER_ERR_MSG("E18888", "src:%s and insert_node:%s does not exist in the same graph.", - src_node->GetName().c_str(), insert_node->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] src:%s and insert_node:%s does not exist in the same graph.", - src_node->GetName().c_str(), insert_node->GetName().c_str()); - return GRAPH_FAILED; - } - - if (AddEdge(src, insert_node->GetInDataAnchor(static_cast(input_index))) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "AddEdge %s->%s failed.", src_node->GetName().c_str(), insert_node->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Add][Edge] %s->%s failed.", src_node->GetName().c_str(), insert_node->GetName().c_str()); - return GRAPH_FAILED; - } - - const OutControlAnchorPtr src_out_ctrl_anchor = src_node->GetOutControlAnchor(); - GE_CHECK_NOTNULL(src_out_ctrl_anchor); - - bool ctrl_edge_flag = true; - const std::string type = NodeUtils::GetNodeType(src->GetOwnerNode()); - if ((type == SWITCH) || (type == REFSWITCH) || (type == SWITCHN)) { - ctrl_edge_flag = false; - } - - for (auto &dst : dsts) { - GE_CHECK_NOTNULL(dst); - const auto dst_node = dst->GetOwnerNodeBarePtr(); - GELOGI("Insert node %s between %s->%s.", - insert_node->GetName().c_str(), src_node->GetName().c_str(), dst_node->GetName().c_str()); - if (src_node->GetOwnerComputeGraph() != dst_node->GetOwnerComputeGraph()) { - REPORT_INNER_ERR_MSG("E18888", "src:%s and dst:%s does not exist in the same graph.", src_node->GetName().c_str(), - dst_node->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] src:%s and dst:%s does not exist in the same graph.", - src_node->GetName().c_str(), dst_node->GetName().c_str()); - return GRAPH_FAILED; - } - - (void)RemoveEdge(src, dst); - if (AddEdge(insert_node->GetOutDataAnchor(static_cast(output_index)), dst) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "ReplaceEdge from %s->%s to %s->%s failed.", src_node->GetName().c_str(), - dst_node->GetName().c_str(), insert_node->GetName().c_str(), dst_node->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Replace][Edge] from %s->%s to %s->%s failed.", src_node->GetName().c_str(), - dst_node->GetName().c_str(), insert_node->GetName().c_str(), dst_node->GetName().c_str()); - return GRAPH_FAILED; - } - - if (!ctrl_edge_flag) { continue; } - for (const InControlAnchorPtr& peer_in_ctrl_anchor : src_out_ctrl_anchor->GetPeerInControlAnchors()) { - if ((RemoveEdge(src_out_ctrl_anchor, peer_in_ctrl_anchor) != GRAPH_SUCCESS) || - (AddEdge(insert_node->GetOutControlAnchor(), peer_in_ctrl_anchor) != GRAPH_SUCCESS)) { - REPORT_INNER_ERR_MSG("E18888", "ReplaceEdge from %s->%s to %s->%s failed.", src_node->GetName().c_str(), - peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str(), insert_node->GetName().c_str(), - peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Replace][Edge] from %s->%s to %s->%s failed.", - src_node->GetName().c_str(), peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str(), - insert_node->GetName().c_str(), peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); - return GRAPH_FAILED; - } - } - } - GE_ASSERT_SUCCESS(InheritAttr(src_node->GetOpDesc(), insert_node->GetOpDesc())); - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr GraphUtils::InsertNodeBefore(const InDataAnchorPtr &dst, - const OpDescPtr &insert_op, - const uint32_t input_index, - const uint32_t output_index) { - GE_ASSERT_NOTNULL(dst); - const auto src_node_out_anchor = dst->GetPeerOutAnchor(); - GE_ASSERT_NOTNULL(src_node_out_anchor); - const auto src_node = src_node_out_anchor->GetOwnerNode(); - GE_ASSERT_NOTNULL(src_node); - auto compute_graph = src_node->GetOwnerComputeGraphBarePtr(); - GE_ASSERT_NOTNULL(compute_graph); - auto insert_node = compute_graph->InsertNode(src_node, insert_op); - GE_ASSERT_GRAPH_SUCCESS(GraphUtils::InsertNodeBefore(dst, insert_node, - input_index, output_index)); - return insert_node; -} - -graphStatus GraphUtils::InsertNodeBefore(const InDataAnchorPtr &dst, - const NodePtr &insert_node, - const uint32_t input_index, - const uint32_t output_index) { - GE_CHECK_NOTNULL(dst); - GE_CHECK_NOTNULL(insert_node); - const auto dst_node = dst->GetOwnerNodeBarePtr(); - GE_CHECK_NOTNULL(dst_node); - if (dst_node->GetOwnerComputeGraph() != insert_node->GetOwnerComputeGraph()) { - GELOGE(GRAPH_FAILED, "[INSERT][NODE] dst:%s and insert_node:%s does not exist in the same graph.", - dst_node->GetName().c_str(), insert_node->GetName().c_str()); - return GRAPH_FAILED; - } - - const auto src_node_out_anchor = dst->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(src_node_out_anchor); - const auto src_node = src_node_out_anchor->GetOwnerNodeBarePtr(); - GE_CHECK_NOTNULL(src_node); - // insert node - if ((RemoveEdge(src_node_out_anchor, dst) != GRAPH_SUCCESS) || - (AddEdge(src_node_out_anchor, - insert_node->GetInDataAnchor(static_cast(input_index))) != GRAPH_SUCCESS) || - (AddEdge(insert_node->GetOutDataAnchor(static_cast(output_index)), dst) != GRAPH_SUCCESS)) { - GELOGE(GRAPH_FAILED, "[INSERT][NODE] %s between %s->%s failed", - insert_node->GetName().c_str(), - src_node->GetName().c_str(), - dst_node->GetName().c_str()); - return GRAPH_FAILED; - } - GELOGI("[INSERT][NODE] %s between %s->%s", - insert_node->GetName().c_str(), - src_node->GetName().c_str(), - dst_node->GetName().c_str()); - - // update control edges - const auto in_ctrl_anchor = dst_node->GetInControlAnchor(); - GE_CHECK_NOTNULL(in_ctrl_anchor); - const auto insert_node_in_ctrl_anchor = insert_node->GetInControlAnchor(); - for (const auto &peer_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) { - GE_CHECK_NOTNULL(peer_out_ctrl_anchor); - const auto peer_node = peer_out_ctrl_anchor->GetOwnerNode(); - if (NodeUtils::IsLikeAtomicClean(peer_node)) { - continue; - } - if ((RemoveEdge(peer_out_ctrl_anchor, in_ctrl_anchor) != GRAPH_SUCCESS) || - (AddEdge(peer_out_ctrl_anchor, insert_node_in_ctrl_anchor) != GRAPH_SUCCESS)) { - GELOGE(GRAPH_FAILED, "[INSERT][NODE] replace control edge from %s->%s to %s->%s failed.", - (peer_node != nullptr) ? peer_node->GetName().c_str() : "NULL", - dst_node->GetName().c_str(), - (peer_node != nullptr) ? peer_node->GetName().c_str() : "NULL", - insert_node->GetName().c_str()); - return GRAPH_FAILED; - } - } - GE_ASSERT_SUCCESS(InheritAttr(dst_node->GetOpDesc(), insert_node->GetOpDesc())); - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveJustNode(ComputeGraph &compute_graph, - const NodePtr &node) { - if (node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param node is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] The node ptr should be not null."); - return GRAPH_FAILED; - } - if (compute_graph.impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The compute graph impl should be not null, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] The compute graph impl should be not null."); - return GRAPH_FAILED; - } - const auto iter = find(compute_graph.impl_->nodes_.begin(), compute_graph.impl_->nodes_.end(), node); - if (iter != compute_graph.impl_->nodes_.end()) { - compute_graph.EraseFromNodeList(iter); - return GRAPH_SUCCESS; - } - return GRAPH_FAILED; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RemoveJustNode( - const ComputeGraphPtr compute_graph, const NodePtr &node) { - GE_CHECK_NOTNULL(compute_graph); - GE_CHECK_NOTNULL(node); - const graphStatus ret = ((RemoveJustNode(*compute_graph, node) == GRAPH_SUCCESS) ? GRAPH_SUCCESS : GRAPH_FAILED); - return ret; -} - -void GraphUtils::RecordOriginalNames(const std::vector original_nodes, const ge::NodePtr &node) { - GE_CHK_BOOL_EXEC(node != nullptr, REPORT_INNER_ERR_MSG("E18888", "param node is nullptr, check invalid."); - return, "[Check][Param] node is null."); - std::vector original_names; - std::vector original_types; - for (const auto &node_tmp : original_nodes) { - std::vector names_tmp; - std::vector types_tmp; - const ge::OpDescPtr opdesc_tmp = node_tmp->GetOpDesc(); - if (opdesc_tmp == nullptr) { - GELOGE(GRAPH_FAILED, "[Check][Param] Node %s get opdesc is nullptr", node_tmp->GetName().c_str()); - continue; - } - auto ret = ge::AttrUtils::GetListStr(opdesc_tmp, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, names_tmp); - ge::AttrUtils::GetListStr(opdesc_tmp, ATTR_NAME_DATA_DUMP_ORIGIN_OP_TYPES, types_tmp); - if (!ret) { - GELOGW("[Get][Attr] Get attr _datadump_original_op_names failed"); - continue; - } - if (names_tmp.size() != 0UL) { - (void)original_names.insert(original_names.end(), names_tmp.begin(), names_tmp.end()); - } else { - original_names.push_back(opdesc_tmp->GetName()); - } - if (types_tmp.size() != 0UL) { - (void)original_types.insert(original_types.end(), types_tmp.begin(), types_tmp.end()); - } else { - original_types.push_back(opdesc_tmp->GetType()); - } - } - GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListStr(node->GetOpDesc(), ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names), - REPORT_INNER_ERR_MSG("E18888", "Set original_op_names to node:%s fail.", node->GetName().c_str()); - return, "[Invoke][SetListStr] Set original_op_names to node:%s fail.", node->GetName().c_str()); - GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListStr(node->GetOpDesc(), ATTR_NAME_DATA_DUMP_ORIGIN_OP_TYPES, original_types), - REPORT_INNER_ERR_MSG("E18888", "Set original_op_types to node:%s fail.", node->GetName().c_str()); - return, "[Invoke][SetListStr] Set original_op_types to node:%s fail.", node->GetName().c_str()); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::RecordOriginalNames(std::vector names_tmp, - const ge::NodePtr &node) { - GE_CHK_BOOL_EXEC(node != nullptr, REPORT_INNER_ERR_MSG("E18888", "param node is nullptr, check invalid."); - return, "[Check][Param] node is null."); - std::vector original_names; - if (names_tmp.size() != 0UL) { - (void)original_names.insert(original_names.end(), names_tmp.begin(), names_tmp.end()); - } else { - const std::string tmp; - original_names.push_back(tmp); - } - GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListStr(node->GetOpDesc(), ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names), - REPORT_INNER_ERR_MSG("E18888", "Set original_op_names to node %s fail.", node->GetName().c_str()); - return, "[Invoke][SetListStr] Set original_op_names to node %s fail.", node->GetName().c_str()); -} - -namespace { -#ifdef FMK_SUPPORT_DUMP -void GetDumpGraphPrefix(std::stringstream& stream_file_name) { - static std::string path_prefix; - if (path_prefix.empty()) { - const char_t *npu_collect_path = nullptr; - MM_SYS_GET_ENV(MM_ENV_NPU_COLLECT_PATH, npu_collect_path); - if (npu_collect_path != nullptr) { - const std::string base_path_str(npu_collect_path); - stream_file_name << base_path_str << "/extra-info/graph/" << mmGetPid() << "_" << GetContext().DeviceId() << "/"; - } else { - const char_t *dump_graph_path = nullptr; - MM_SYS_GET_ENV(MM_ENV_DUMP_GRAPH_PATH, dump_graph_path); - if (dump_graph_path != nullptr) { - const std::string dump_graph_path_str(dump_graph_path); - stream_file_name << (dump_graph_path_str.empty() ? "" : dump_graph_path_str + "/"); - stream_file_name << "pid_" << mmGetPid() << "_deviceid_" << GetContext().DeviceId() << "/"; - } else { - stream_file_name << "./"; - std::string ascend_work_path; - (void)GetAscendWorkPath(ascend_work_path); - if (!ascend_work_path.empty()) { - stream_file_name.str(""); - stream_file_name << (ascend_work_path + "/"); - } - } - } - path_prefix = stream_file_name.str(); - } else { - stream_file_name << path_prefix; - } -} - -bool SetOptions2GraphInner(const std::map& option, - const std::string& attr_name, const ge::ComputeGraphPtr &graph) { - // set graph options - ge::NamedAttrs attr; - attr.SetName(attr_name); - for (auto itr_graph = option.begin(); itr_graph != option.end(); itr_graph++) { - auto const ret = attr.SetAttr(itr_graph->first, GeAttrValue::CreateFrom(itr_graph->second)); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "set [%s:] [%s]=[%s] to graph fail.", - attr.GetName().c_str(), itr_graph->first.c_str(), itr_graph->second.c_str()); - return false; - } - } - auto ret = ge::AttrUtils::SetNamedAttrs(graph, attr_name, attr); - if (!ret) { - GELOGE(GRAPH_FAILED, "set [%s] to graph fail.", attr_name.c_str()); - return false; - } - return true; -} -bool SetOptions2Graph(const int64_t dump_level, const ge::ComputeGraphPtr &graph) { - if (graph == nullptr) { - GELOGE(GRAPH_FAILED, "graph is nullptr"); - return false; - } - if (dump_level == static_cast(ge::DumpLevel::DUMP_ALL) - || dump_level == static_cast(ge::DumpLevel::DUMP_WITH_OUT_DATA)) { - GEThreadLocalContext &context = GetThreadLocalContext(); - const std::map& tmp_graph_options = context.GetAllGraphOptions(); - const std::map& tmp_session_options = context.GetAllSessionOptions(); - const std::map& tmp_global_options = context.GetAllGlobalOptions(); - if (!SetOptions2GraphInner(tmp_graph_options, "GraphOptions", graph)) { - return false; - } - if (!SetOptions2GraphInner(tmp_global_options, "GlobalOptions", graph)) { - return false; - } - if (!SetOptions2GraphInner(tmp_session_options, "SessionOptions", graph)) { - return false; - } - } - return true; -} -graphStatus GetDumpRealPath(const int64_t file_index, const std::string &suffix, - const std::string &user_graph_name, std::string &real_path_name) { - std::string relative_path; - if (user_graph_name.empty()) { - std::stringstream stream_file_name; - { - static std::mutex mutex; - const std::lock_guard lock(mutex); - GetDumpGraphPrefix(stream_file_name); - if (mmAccess2(stream_file_name.str().c_str(), M_F_OK) != EN_OK) { - if (CreateDir(stream_file_name.str()) != 0) { - GELOGW("[DumpGraph][CreateDir] Create dump graph dir failed, path:%s", stream_file_name.str().c_str()); - stream_file_name.str(""); - stream_file_name << "./"; - } - } - } - - stream_file_name << "ge_proto_" << std::setw(kDumpGraphIndexWidth) << std::setfill('0') << file_index; - stream_file_name << "_" << GetSanitizedName(suffix) << ".txt"; - relative_path = stream_file_name.str(); - } else { - const auto sep = user_graph_name.rfind(MMPA_PATH_SEPARATOR_STR); - if (sep == std::string::npos) { - (void)relative_path.append("./"); - (void)relative_path.append(user_graph_name); - } else { - const std::string file_name = user_graph_name.substr(sep + 1UL, user_graph_name.length()); - std::string path_dir = user_graph_name.substr(0UL, sep + 1UL); - if ((file_name.length() == 0UL) || (path_dir.length() == 0UL)) { - GELOGW("[Invalid]path or name invalid.user_graph_name:%s", user_graph_name.c_str()); - return GRAPH_PARAM_INVALID; - } - - if ((mmAccess2(path_dir.c_str(), M_F_OK) != EN_OK) && (CreateDir(path_dir) != 0)) { - GELOGW("[DumpGraph][CreateDir] Create dump graph dir failed, path:%s", path_dir.c_str()); - path_dir = "./"; - } - (void)relative_path.append(path_dir); - (void)relative_path.append(file_name); - } - } - - char_t real_path[MMPA_MAX_PATH] = {}; - auto const ret = mmRealPath(relative_path.c_str(), &(real_path[0]), MMPA_MAX_PATH); - if (ret != EN_OK) { - GELOGD("[Get][RealPath]file does not exist, it will be create. ret:%d", ret); - } - - real_path_name = real_path; - GELOGD("Get dump graph real_path_name:%s", real_path_name.c_str()); - return GRAPH_SUCCESS; -} - -bool NoNeedDumpGraph(int64_t &dump_content_level) { - const char_t *dump_ge_graph = nullptr; - MM_SYS_GET_ENV(MM_ENV_DUMP_GE_GRAPH, dump_ge_graph); - if (dump_ge_graph == nullptr) { - return true; - } - dump_content_level = (dump_ge_graph[0U] != '\0') - ? std::strtol(dump_ge_graph, nullptr, kBaseOfIntegerValue) - : static_cast(ge::DumpLevel::NO_DUMP); - if ((dump_content_level == static_cast(DumpLevel::NO_DUMP)) || - (dump_content_level >= static_cast(DumpLevel::DUMP_LEVEL_END))) { - GELOGD("Skip dump with DUMP_GE_GRAPH value:%" PRId64 ".", dump_content_level); - return true; - } - return false; -} -#endif - -inline graphStatus CheckDumpGraphNum(const int64_t file_index) { - thread_local int64_t max_dump_file_num = 0; - if (max_dump_file_num == 0) { - std::string opt = "0"; - (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); - max_dump_file_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); - } - if ((max_dump_file_num != 0) && (file_index > max_dump_file_num)) { - GELOGW("[DumpGraph][Check] dump_graph_num exceeds max_dump_file_num, dump_graph_num=%" PRId64 - ", max_dump_file_num=%" PRId64, - file_index, max_dump_file_num); - return GRAPH_PARAM_INVALID; - } - return GRAPH_SUCCESS; -} -bool IsDumpGraphExcludeSubGraphOnLevel1(const std::string &suffix) { - if ((suffix.find(kDumpStrPartition) != std::string::npos) || - (suffix.find(kDumpStrOptimizeSubgraph) != std::string::npos) || - (suffix.find(kDumpStrAicpu) != std::string::npos) || - (suffix.find(kDumpStrSubgraphFunc) != std::string::npos)) { - return false; - } - return true; -} - -bool IsDumpGraphWithinWhitelistOnLevel2(const std::string &suffix) { - // 如果是子图则过滤掉 - if (!IsDumpGraphExcludeSubGraphOnLevel1(suffix)) { - return false; - } - for (const auto &full_name : kGeDumpWhitelistFullName) { - if (suffix.compare(full_name) == 0) { - return true; - } - } - for (const auto &key_name : kGeDumpWhitelistKeyName) { - if (suffix.find(key_name) != std::string::npos) { - return true; - } - } - return false; -} - -bool IsStrNotNum(const std::string &val) { // avoid negative number '-' - for (const auto &ele : val) { - if (!isdigit(ele)) { - return true; - } - } - return false; -} - -bool IsDumpGraphByKeyName(const std::string &env_val, const std::string &suffix) { - const auto &key_names = StringUtils::Split(env_val, '|'); - for (const auto &name : key_names) { - if (suffix.find(name) != std::string::npos) { - return true; - } - } - return false; -} -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::NoNeedDumpGraphBySuffix(const std::string &suffix) { - const char_t *dump_level = nullptr; - MM_SYS_GET_ENV(MM_ENV_DUMP_GRAPH_LEVEL, dump_level); - if (dump_level == nullptr) { - return !IsDumpGraphWithinWhitelistOnLevel2(suffix); - } - - if ((suffix.empty() || (!IsStrNotNum(suffix)))) { - GELOGW("suffix %s is empty or is number, no need dump", suffix.c_str()); - return true; - } - - const std::string env_val(dump_level); - if (IsStrNotNum(env_val)) { - return !IsDumpGraphByKeyName(env_val, suffix); - } - - const int64_t dump_graph_level = std::strtol(dump_level, nullptr, kBaseOfIntegerValue); - if (dump_graph_level == static_cast(DumpGraphLevel::kDumpLevel1)) { - return !IsDumpGraphExcludeSubGraphOnLevel1(suffix); - } - - if (dump_graph_level == static_cast(DumpGraphLevel::kDumpLevel2)) { - return !IsDumpGraphWithinWhitelistOnLevel2(suffix); - } - - if (dump_graph_level == static_cast(DumpGraphLevel::kDumpLevel3)) { - return (suffix.compare(kDumpStrBuild) != 0); - } - - if (dump_graph_level == static_cast(DumpGraphLevel::kDumpLevel4)) { - return (suffix.compare(kDumpStrPreRunBegin) != 0); - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraph(const ge::ComputeGraphPtr &graph, - const std::string &suffix, - const bool is_always_dump, - const std::string &user_graph_name) { -#ifdef FMK_SUPPORT_DUMP - GraphDumperRegistry::GetDumper().Dump(graph, suffix); - // dump the graph according to different graph level - int64_t dump_level{0}; - const bool not_dump = (NoNeedDumpGraph(dump_level) || GraphUtils::NoNeedDumpGraphBySuffix(suffix)) - && (!is_always_dump); - if (not_dump) { - return; - } - - // file name - std::string real_path; - if (GenDumpTxtFileName(graph, suffix, user_graph_name, real_path) != GRAPH_SUCCESS) { - return; - } - - // Create model - ge::Model model("", ""); - if (!SetOptions2Graph(dump_level, graph) && (!is_always_dump)) { - return; - } - model.SetGraph(graph); - ge::proto::ModelDef ge_proto; - bool is_dump_graph_structure_only = (dump_level != static_cast(ge::DumpLevel::DUMP_ALL)) && (!is_always_dump); - if (model.Save(ge_proto, is_dump_graph_structure_only) != SUCCESS) { - return; - } - GraphUtils::WriteProtoToTextFile(ge_proto, real_path.c_str()); -#else - (void)graph; - (void)suffix; - (void)is_always_dump; - (void)user_graph_name; - GELOGW("[DumpGraph][Check] Need to define FMK_SUPPORT_DUMP for dump graph."); -#endif -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::DumpGEGraphByPath(const ge::ComputeGraphPtr &graph, const std::string &file_path, - const int64_t dump_level) { - return DumpGEGraphByPath(graph, file_path, static_cast(dump_level)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::DumpGEGraphByPath(const ge::ComputeGraphPtr &graph, const std::string &file_path, - const ge::DumpLevel dump_level) { - const auto sep = file_path.rfind(MMPA_PATH_SEPARATOR_STR); - if (sep == std::string::npos) { - REPORT_PREDEFINED_ERR_MSG("E19026", std::vector({"pathname", "reason"}), - std::vector({file_path.c_str(), "Separator is not found in file_path."})); - GELOGE(GRAPH_FAILED, "[CheckParam] Separator is not found in file_path.file_path:%s", file_path.c_str()); - return GRAPH_FAILED; - } - const std::string file_name = file_path.substr(sep + 1UL, file_path.length()); - const std::string path_dir = file_path.substr(0UL, sep + 1UL); - if ((file_name.length() == 0UL) || (path_dir.length() == 0UL)) { - REPORT_PREDEFINED_ERR_MSG("E19026", std::vector({"pathname", "reason"}), - std::vector({file_path.c_str(), "Path or filename is not set."})); - GELOGE(GRAPH_FAILED, "[Invalid]path or name invalid.file_path:%s", file_path.c_str()); - return GRAPH_FAILED; - } - - // Create Model - ge::Model model("", ""); - model.SetGraph(graph); - - // SerializeModel to ModelDef - ge::proto::ModelDef ge_proto; - if (model.Save(ge_proto, dump_level != ge::DumpLevel::DUMP_ALL) != SUCCESS) { - return GRAPH_FAILED; - } - // Write file - char_t real_path[MMPA_MAX_PATH] = {}; - if (mmRealPath(path_dir.c_str(), &(real_path[0U]), MMPA_MAX_PATH) != EN_OK) { - REPORT_PREDEFINED_ERR_MSG("E19026", std::vector({"pathname", "reason"}), - std::vector({path_dir.c_str(), "Directory does not exist."})); - GELOGE(GRAPH_FAILED, "[Get][RealPath]Directory %s does not exist.", path_dir.c_str()); - return GRAPH_FAILED; - } - const std::string path = real_path; - const std::string real_path_name = path + std::string(MMPA_PATH_SEPARATOR_STR) + file_name; - GraphUtils::WriteProtoToTextFile(ge_proto, real_path_name.c_str()); - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGrph(const ge::ComputeGraphPtr &graph, - const std::string &path, - const std::string &suffix) { - // file name - static std::atomic atomic_file_index(0); - const auto file_index = atomic_file_index.fetch_add(1); - GELOGD("Start to dump om txt: %" PRId64, file_index); - if (CheckDumpGraphNum(file_index) != GRAPH_SUCCESS) { return; } - - std::stringstream stream_file_name; - stream_file_name << path.c_str() << "/ge_proto_" << std::setw(kNameWidth) << std::setfill('0') - << file_index; - stream_file_name << "_graph_" << graph->GetGraphID() << "_" << GetSanitizedName(suffix) << ".txt"; - const std::string proto_file = stream_file_name.str(); - (void)DumpGEGraphByPath(graph, proto_file, ge::DumpLevel::NO_DUMP); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::LoadGEGraph(const char_t *const file, - ge::ComputeGraph &compute_graph) { - ge::proto::ModelDef model_def; - // Get ModelDef object from file generated by DumpGEGraph() - if (!ReadProtoFromTextFile(file, &model_def)) { - REPORT_PREDEFINED_ERR_MSG( - "E19003", std::vector({"file", "errmsg"}), - std::vector({((file == nullptr) ? "" : file), "Read proto from file failed"})); - GELOGE(GRAPH_FAILED, "[Get][ModelDef] failed from file:%s", (file == nullptr) ? "" : file); - return false; - } - ge::Model model; - // Get Model object from ModelDef by deserialize ModelDef - if (model.Load(model_def) == GRAPH_SUCCESS) { - GE_CHK_BOOL_EXEC(model.GetGraph() != nullptr, - REPORT_INNER_ERR_MSG("E18888", "Get computer graph is nullptr, model file:%s.", file); - return false, "[Get][ComputerGraph] is nullptr"); - compute_graph = *model.GetGraph(); - return true; - } else { - REPORT_PREDEFINED_ERR_MSG("E19003", std::vector({"file", "errmsg"}), - std::vector({file, "Get Model failed from ModelDef"})); - GELOGE(GRAPH_FAILED, "[Get][Model] failed from ModelDef:%s", file); - return false; - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::LoadGEGraph(const char_t *const file, - ge::ComputeGraphPtr &compute_graph) { - ge::proto::ModelDef model_def; - // Get ModelDef object from file generated by DumpGEGraph() - if (!ReadProtoFromTextFile(file, &model_def)) { - REPORT_PREDEFINED_ERR_MSG( - "E19003", std::vector({"file", "errmsg"}), - std::vector({((file == nullptr) ? "" : file), "Read proto from file failed"})); - GELOGE(GRAPH_FAILED, "[Get][ModelDef] failed from file:%s", (file == nullptr) ? "" : file); - return false; - } - ge::Model model; - // Get Model object from ModelDef by deserialize ModelDef - if (model.Load(model_def) == GRAPH_SUCCESS) { - GE_CHK_BOOL_EXEC(model.GetGraph() != nullptr, - REPORT_INNER_ERR_MSG("E18888", "Get computer graph is nullptr, model file:%s.", file); - return false, "[Get][ComputerGraph] is nullptr"); - compute_graph = model.GetGraph(); - for (const auto &node : compute_graph->GetDirectNode()) { - if (node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "ModeDef %s has nullptr node.", file); - GELOGE(GRAPH_FAILED, "[Get][Node]Nullptr node in graph:%s, model:%s", compute_graph->GetName().c_str(), file); - return false; - } - GELOGI("Node %s set owner graph", node->GetName().c_str()); - if (node->SetOwnerComputeGraph(compute_graph) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "SetOwnerComputeGraph failed for node:%s", node->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Invoke][SetGraph]Node %s set owner graph failed", node->GetName().c_str()); - return false; - } - } - GE_ASSERT_GRAPH_SUCCESS(ConvertFileConstToConst(compute_graph)); - return true; - } else { - REPORT_PREDEFINED_ERR_MSG("E19003", std::vector({"file", "errmsg"}), - std::vector({file, "Get Model failed from ModelDef"})); - GELOGE(GRAPH_FAILED, "[Get][Model] failed from ModelDef:%s", file); - return false; - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::ConvertFileConstToConst(const ComputeGraphPtr &graph) { - GE_CHECK_NOTNULL(graph); - std::vector file_consts; - for (const auto &node : graph->GetDirectNode()) { - if (node->GetType() != FILECONSTANT) { - continue; - } - const auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - - std::string file_path; - if (!AttrUtils::GetStr(op_desc, kLocation4Recover, file_path)) { - continue; - } - GE_ASSERT_TRUE(!file_path.empty()); - GE_ASSERT_GRAPH_SUCCESS(op_desc->DelAttr(kLocation4Recover)); - int64_t attr_length = 0; - GE_ASSERT_TRUE(AttrUtils::GetInt(op_desc, kLength4Recover, attr_length)); - GE_ASSERT_TRUE(attr_length > 0); - GE_ASSERT_GRAPH_SUCCESS(op_desc->DelAttr(kLength4Recover)); - size_t file_length = static_cast(attr_length); - - const auto bin_buff = ComGraphMakeUnique(file_length); - GE_CHECK_NOTNULL(bin_buff); - GE_ASSERT_GRAPH_SUCCESS(GetBinFromFile(file_path, bin_buff.get(), file_length)); - - const GeTensorPtr &weight = ComGraphMakeShared( - op_desc->GetOutputDesc(0U), PtrToPtr(bin_buff.get()), file_length); - GE_CHECK_NOTNULL(weight); - - std::string origin_type; - if (AttrUtils::GetStr(op_desc, kOriginType4Recover, origin_type) && (kConstOpTypes.count(origin_type) > 0U)) { - GE_ASSERT_SUCCESS(RecoverConstByWeightFile(op_desc, weight)); - continue; - } - - const auto const_op = OpDescUtils::CreateConstOp(weight); - GE_CHECK_NOTNULL(const_op); - const_op->SetName(op_desc->GetName() + "_" + CONSTANT); - const_op->SetId(op_desc->GetId()); - const auto const_node = graph->AddNode(const_op); - GE_CHECK_NOTNULL(const_node); - GE_ASSERT_GRAPH_SUCCESS(GraphUtils::ReplaceNodeAnchors(const_node, node, {}, {0})); - NodeUtils::UnlinkAll(*node); - GE_ASSERT_GRAPH_SUCCESS(GraphUtils::RemoveJustNode(graph, node)); - GELOGD("Convert node: %s from file constant to const by %s success.", node->GetName().c_str(), file_path.c_str()); - } - return SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::RecoverConstByWeightFile(const OpDescPtr &op_desc, const GeTensorPtr &weight) { - GE_CHECK_NOTNULL(op_desc); - std::string op_name; - GE_ASSERT_TRUE(AttrUtils::GetStr(op_desc, kOriginName4Recover, op_name)); - op_desc->SetName(op_name); - GE_ASSERT_GRAPH_SUCCESS(op_desc->DelAttr(kOriginName4Recover)); - std::string op_type; - GE_ASSERT_TRUE(AttrUtils::GetStr(op_desc, kOriginType4Recover, op_type)); - op_desc->SetType(op_type); - GE_ASSERT_GRAPH_SUCCESS(op_desc->DelAttr(kOriginType4Recover)); - GE_ASSERT_TRUE(AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight)); - GELOGD("Recover const node: %s, type: %s.", op_name.c_str(), op_type.c_str()); - return SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::WriteProtoToOStream( - const ascend_private::protobuf::Message &proto, std::ostream &o_stream) { - auto output = ComGraphMakeUnique(&o_stream); - if (output == nullptr) { - REPORT_CALL_ERROR("E18888", "create OstreamOutputStream failed."); - GELOGE(GRAPH_FAILED, "[Create][OstreamOutputStream] Output is nullptr"); - return GRAPH_FAILED; - } - const bool ret = google::protobuf::TextFormat::Print(proto, output.get()); - if (!ret) { - REPORT_CALL_ERROR("E18888", "write ostream failed."); - GELOGE(GRAPH_FAILED, "[Invoke][Print] Fail to write the ostream"); - return GRAPH_FAILED; - } - return SUCCESS; -} - -// Printing protocol messages in text format is useful for debugging and human editing of messages. -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::WriteProtoToTextFile( - const google::protobuf::Message &proto, const char_t *const real_path) { -#ifdef FMK_SUPPORT_DUMP - const MODE FILE_AUTHORITY = 384U; // 0600U in octal - const int32_t fd = mmOpen2(real_path, - static_cast( - static_cast(M_WRONLY) | static_cast(M_CREAT) | static_cast(O_TRUNC)), - FILE_AUTHORITY); - if (fd < 0) { - REPORT_INNER_ERR_MSG("E18888", "open file:%s failed, errormessage:%s", real_path, strerror(errno)); - GELOGE(GRAPH_FAILED, "[Open][File] failed for %s, reason:%s", real_path, strerror(errno)); - return; - } - - auto output = ComGraphMakeUnique(fd); - if (output == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "create FileOutputStream failed."); - GELOGE(GRAPH_FAILED, "[Create][FileOutputStream] Output is nullptr"); - if (mmClose(fd) != 0) { - REPORT_INNER_ERR_MSG("E18888", "close FileOutputStream failed, reason:%s.", strerror(errno)); - GELOGE(GRAPH_FAILED, "[Close][FileOutputStream] failed, reason:%s", strerror(errno)); - } - return; - } - const bool ret = google::protobuf::TextFormat::Print(proto, output.get()); - if (!ret) { - REPORT_INNER_ERR_MSG("E18888", "write file:%s failed.", real_path); - GELOGE(GRAPH_FAILED, "[Invoke][Print] Fail to write the file: %s", real_path); - GE_CHK_BOOL_EXEC(mmClose(fd) == 0, - REPORT_INNER_ERR_MSG("E18888", "close FileOutputStream failed, reason:%s.", strerror(errno)); - return, "[Close][FileOutputStream] failed, reason:%s", strerror(errno)); - return; - } - output.reset(); - GE_CHK_BOOL_EXEC(mmClose(fd) == 0, - REPORT_INNER_ERR_MSG("E18888", "close FileOutputStream failed, reason:%s.", strerror(errno)); - return, "[Close][FileOutputStream] failed, reason:%s.", strerror(errno)); - - FILE *const file = fopen(real_path, "rb"); - if (file == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "open file:%s failed, errormessage:%s", real_path, strerror(errno)); - GELOGE(GRAPH_FAILED, "[Invoke][FOpen] fail to open the file: %s, %s", real_path, strerror(errno)); - return; - } - if (fseek(file, 0L, SEEK_END) == 0) { - const int64_t fileSize = ftell(file); - thread_local int64_t max_dump_file_size = 0; - if (max_dump_file_size == 0) { - std::string opt = "0"; - // Can not check return value - (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_SIZE, opt); - max_dump_file_size = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); - } - if ((max_dump_file_size != 0) && (fileSize != -1) && (fileSize > max_dump_file_size)) { - GELOGW("[WriteProto][Check] dump_graph_num exceeds max_dump_file_num, dump_graph_num=%" PRId64 "," - " max_dump_file_num=%" PRId64, fileSize, max_dump_file_size); - GE_IF_BOOL_EXEC(remove(real_path) != 0, GELOGW("[WriteProto][RemovePath] Remove path %s failed", real_path)); - GE_CHK_BOOL_EXEC(fclose(file) == 0, - REPORT_INNER_ERR_MSG("E18888", "close file:%s failed, error:%s", real_path, strerror(errno)); - return, "[FClose][File] %s failed error:%s", real_path, strerror(errno)); - return; - } - } - GE_CHK_BOOL_EXEC(fclose(file) == 0, - REPORT_INNER_ERR_MSG("E18888", "close file:%s failed error:%s", real_path, strerror(errno)); - return, "[FClose][File] %s failed error:%s", real_path, strerror(errno)); -#else - (void)proto; - (void)real_path; - GELOGW("[Write][Proto] Need to define FMK_SUPPORT_DUMP for dump graph."); -#endif -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::ReadProtoFromTextFile( - const char_t *const file, google::protobuf::Message *const proto) { - if ((file == nullptr) || (proto == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param file or proto is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] incorrect parameter. file path or message is invalid"); - return false; - } - std::ifstream fs(file, std::ifstream::in); - if (!fs.is_open()) { - REPORT_INNER_ERR_MSG("E18888", "open file:%s failed.", file); - GELOGE(GRAPH_FAILED, "[Invoke][OpenFile]proto file '%s' open fail.", file); - return false; - } - google::protobuf::io::IstreamInputStream input(&fs); - const bool ret = google::protobuf::TextFormat::Parse(&input, proto); - if (!ret) { - REPORT_INNER_ERR_MSG("E18888", "parse proto from text ret fail, please check your text file '%s'.", file); - GELOGE(GRAPH_FAILED, "[Parse][Proto] from text ret fail, please check your text file '%s'.", file); - } - fs.close(); - return ret; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraphToOnnx( - const ge::ComputeGraph &compute_graph, const std::string &suffix, bool is_always_dump) { -#ifdef FMK_SUPPORT_DUMP - // dump the graph according to different graph level - int64_t dump_content_level{0}; - if (!is_always_dump && (NoNeedDumpGraph(dump_content_level) || GraphUtils::NoNeedDumpGraphBySuffix(suffix))) { - return; - } - return DumpGEGraphToOnnxByContentLevel(compute_graph, suffix, static_cast(dump_content_level)); -#else - (void)compute_graph; - (void)suffix; - (void)is_always_dump; - GELOGW("[DumpGraph][Check] Need to define FMK_SUPPORT_DUMP for dump graph."); -#endif -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraphToOnnxByContentLevel( - const ge::ComputeGraph &compute_graph, const std::string &suffix, DumpLevel content_level) { -#ifdef FMK_SUPPORT_DUMP - // 1.Get ge::onnx::ModelProto from ge::Model - ge::Model model("GE", ""); - const auto compute_graph_ptr = ComGraphMakeShared(compute_graph); - model.SetGraph(compute_graph_ptr); - onnx::ModelProto model_proto; - if (!OnnxUtils::ConvertGeModelToModelProto(model, model_proto, content_level)) { - GELOGE(GRAPH_FAILED, "[Convert][GeModel] DumpGEGraphToOnnx failed."); - return; - } - - // 2.Set file name - std::string real_path; - if (GenDumpOnnxFileName(compute_graph_ptr, suffix, real_path) != SUCCESS) { - return; - } - - // 3. Serialize to file in current path - GraphUtils::WriteProtoToTextFile(model_proto, real_path.c_str()); -#else - (void)compute_graph; - (void)suffix; - (void)content_level; - GELOGW("[DumpGraph][Check] Need to define FMK_SUPPORT_DUMP for dump graph."); -#endif -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, - const std::string &suffix) { - return DumpGEGraphToOnnx(compute_graph, suffix, false); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGrphToOnnx(const ge::ComputeGraph &compute_graph, - const std::string &path, - const std::string &suffix) { - // 1.Get ge::onnx::ModelProto from ge::Model - ge::Model model("GE", ""); - const auto compute_graph_ptr = ComGraphMakeShared(compute_graph); - model.SetGraph(compute_graph_ptr); - onnx::ModelProto model_proto; - if (!OnnxUtils::ConvertGeModelToModelProto(model, model_proto)) { - GELOGE(GRAPH_FAILED, "[Convert][GeModel] DumpGEGraphToOnnx failed."); - return; - } - - // 2.Set file name - static std::atomic atomic_file_index(0); - const auto file_index = atomic_file_index.fetch_add(1); - GELOGD("Start to dump ge onnx file: %" PRId64, file_index); - if (CheckDumpGraphNum(file_index) != GRAPH_SUCCESS) { return; } - - std::stringstream stream_file_name; - stream_file_name << "ge_onnx_" << std::setw(kNameWidth) << std::setfill('0') << file_index; - stream_file_name << "_graph_" << compute_graph.GetGraphID(); - stream_file_name << "_" << GetSanitizedName(suffix) << ".pbtxt"; - std::string proto_file = stream_file_name.str(); - if ((proto_file.length()) >= kNameMax) { - proto_file = proto_file.substr(0U, kNameMax - 7U); - proto_file = proto_file + ".pbtxt"; - GELOGW("[Check][Param] File name is too longer!, file:%s", proto_file.c_str()); - } - const std::string full_proto_file = path + "/" + proto_file; - const auto real_path = ComGraphMakeUnique(static_cast(MMPA_MAX_PATH)); - if (real_path == nullptr) { - GELOGE(GRAPH_FAILED, "[New][RealPath] failed."); - return; - } - /// Returning nullptr means 3 case as follows: - /// a.path is PATH_MAX chars or more - /// b.the file does not exist - /// c.the path has no permissions - /// Distinguish between last the two cases in the function WriteProtoToTextFile call open() - if (mmRealPath(full_proto_file.c_str(), real_path.get(), MMPA_MAX_PATH) != EN_OK) { - // Case a has been checked above - GELOGI("File %s does not exist, it will be created, realpath info[%s].", full_proto_file.c_str(), strerror(errno)); - } - - // 3. Serialize to file in current path - GraphUtils::WriteProtoToTextFile(model_proto, real_path.get()); -} - -namespace { -using InNodesToOut = std::map, NodeCompareKey>; -using OutNodesToIn = InNodesToOut; - -inline std::string GetNodeNameByAnchor(const Anchor *const anchor) { - if (anchor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param anchor is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Anchor is nullptr"); - return "Null"; - } - const auto node = anchor->GetOwnerNodeBarePtr(); - return (node == nullptr) ? "Null" : node->GetName(); -} - -graphStatus ReplaceOutDataAnchor(const OutDataAnchorPtr &new_anchor, const OutDataAnchorPtr &old_anchor, - InNodesToOut *const in_nodes_to_out = nullptr, OutNodesToIn *const out_nodes_to_in = nullptr) { - if ((new_anchor == nullptr) || (old_anchor == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param new_anchor or old_anchor is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] new_anchor or old_anchor is nullptr"); - return GRAPH_PARAM_INVALID; - } - const auto new_node = new_anchor->GetOwnerNode(); - for (const auto &peer_in_anchor : old_anchor->GetPeerInDataAnchors()) { - auto ret = peer_in_anchor->Unlink(old_anchor); - if (ret != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Failed to unlink old anchor link from %s(%d) to %s(%d)", - GetNodeNameByAnchor(old_anchor.get()).c_str(), old_anchor->GetIdx(), - GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx()); - GELOGE(GRAPH_FAILED, "[Remove][Link] Failed to unlink old anchor link from %s(%d) to %s(%d)", - GetNodeNameByAnchor(old_anchor.get()).c_str(), old_anchor->GetIdx(), - GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx()); - return GRAPH_FAILED; - } - ret = peer_in_anchor->LinkFrom(new_anchor); - if (ret != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "[Create][Link] Failed to relink new anchors from %s(%d) to %s(%d)", - GetNodeNameByAnchor(new_anchor.get()).c_str(), new_anchor->GetIdx(), - GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx()); - GELOGE(GRAPH_FAILED, "[Create][Link] Failed to relink new anchors from %s(%d) to %s(%d)", - GetNodeNameByAnchor(new_anchor.get()).c_str(), new_anchor->GetIdx(), - GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx()); - return GRAPH_FAILED; - } - - if (in_nodes_to_out != nullptr) { - (void)(*in_nodes_to_out)[new_node].emplace_back(peer_in_anchor->GetOwnerNode()); - } - if (out_nodes_to_in != nullptr) { - (void)(*out_nodes_to_in)[peer_in_anchor->GetOwnerNode()].emplace_back(new_node); - } - } - return GRAPH_SUCCESS; -} - -graphStatus RelinkDataIO(const NodePtr &node, const std::vector &io_map, InNodesToOut &in_nodes_to_out, - OutNodesToIn &out_nodes_to_in) { - GE_CHECK_NOTNULL(node); - auto in_data_anchors = node->GetAllInDataAnchors(); - auto out_data_anchors = node->GetAllOutDataAnchors(); - const size_t out_data_anchors_size = out_data_anchors.size(); - if (out_data_anchors_size < io_map.size()) { - REPORT_INNER_ERR_MSG("E18888", "param io_map size:%zu > the actual size:%zu, node:%s type:%s", io_map.size(), - out_data_anchors.size(), node->GetName().c_str(), node->GetType().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] The io_map specified for node %s type %s is larger %zu than " - "the actual size %zu", node->GetName().c_str(), node->GetType().c_str(), - io_map.size(), out_data_anchors.size()); - return GRAPH_PARAM_INVALID; - } - - for (size_t i = 0U; i < out_data_anchors_size; ++i) { - const auto out_data_anchor = out_data_anchors.at(i); - if (out_data_anchor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", - "Failed to relink for node %s type %s, the out data anchor " - "at index %zu is null", - node->GetName().c_str(), node->GetType().c_str(), i); - GELOGE(GRAPH_FAILED, "[Check][Param] Failed to relink for node %s type %s, the out data anchor " - "at index %zu is null", node->GetName().c_str(), node->GetType().c_str(), i); - return GRAPH_FAILED; - } - - int32_t in_index = -1; - if (i < io_map.size()) { - in_index = io_map.at(i); - } - if (in_index < 0) { - out_data_anchor->UnlinkAll(); - } else { - if (in_index >= static_cast(in_data_anchors.size())) { - REPORT_INNER_ERR_MSG("E18888", - "Failed to relink for node %s type %s, invalid index %d specified for input(%zu)", - node->GetName().c_str(), node->GetType().c_str(), in_index, in_data_anchors.size()); - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] Failed to relink for node %s type %s, invalid index %d " - "specified for input(%zu)", node->GetName().c_str(), node->GetType().c_str(), - in_index, in_data_anchors.size()); - return GRAPH_PARAM_INVALID; - } - const auto in_anchor = in_data_anchors.at(static_cast(in_index)); - if (in_anchor == nullptr) { - GELOGW("[Relink][Check] %d\'th in_data_anchor of node %s type %s is null, ignore it.", in_index, - node->GetName().c_str(), node->GetType().c_str()); - continue; - } - const auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); - if (peer_out_anchor == nullptr) { - continue; - } - const auto ret = ReplaceOutDataAnchor(peer_out_anchor, out_data_anchor, &in_nodes_to_out, &out_nodes_to_in); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Replace][OutDataAnchor] Failed to relink node %s type %s for relinking data anchors", - node->GetName().c_str(), node->GetType().c_str()); - return GRAPH_FAILED; - } - } - } - - for (const auto &in_anchor : node->GetAllInDataAnchors()) { - in_anchor->UnlinkAll(); - } - return GRAPH_SUCCESS; -} - -InNodesToOut GetFullConnectIONodes(const NodePtr &node, std::set &in_nodes_set, - std::set &out_nodes_set) { - InNodesToOut in_nodes_to_out; - if (node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param node is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Node is nullptr"); - return in_nodes_to_out; - } - const auto in_nodes_list = node->GetInNodes(); - const auto out_nodes_list = node->GetOutNodes(); - auto out_nodes = std::vector(out_nodes_list.begin(), out_nodes_list.end()); - in_nodes_set = std::set(in_nodes_list.begin(), in_nodes_list.end()); - out_nodes_set = std::set(out_nodes_list.begin(), out_nodes_list.end()); - - for (const auto &in_node : in_nodes_list) { - (void)in_nodes_to_out.emplace(in_node, out_nodes); - } - return in_nodes_to_out; -} - -graphStatus RelinkControlNodeIfNeed(const NodePtr &node, const InNodesToOut &in_nodes_to_out, - InNodesToOut &connected_data_in_to_out) { - GE_CHECK_NOTNULL(node); - for (const auto &in_node_to_out : in_nodes_to_out) { - auto &in_node = in_node_to_out.first; - GE_CHECK_NOTNULL(in_node); - auto &connected_data_out = connected_data_in_to_out[in_node]; - for (const auto &out_node : in_node_to_out.second) { - GE_CHECK_NOTNULL(out_node); - if (std::find(connected_data_out.begin(), connected_data_out.end(), out_node) == connected_data_out.end()) { - GE_CHECK_NOTNULL(in_node->GetOutControlAnchor()); - if (in_node->GetOutControlAnchor()->IsLinkedWith(out_node->GetInControlAnchor())) { - continue; - } - // Some pass, such as SameTransdataBreadFusionPass will generate a ring, so add a - // ring breaking operation here, and notice, this is an operation which will be - // delete later, so do not use this interface to break a ring - if (in_node == out_node) { - GELOGW("[Relink][CtrlNode] There is a cycle between %s to %s when isolating node %s type %s", - in_node->GetName().c_str(), out_node->GetName().c_str(), node->GetName().c_str(), - node->GetType().c_str()); - continue; - } - const auto ret = GraphUtils::AddEdge(in_node->GetOutControlAnchor(), out_node->GetInControlAnchor()); - if (ret != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Add ControlEdge from %s to %s failed, when isolating node %s type %s", - in_node->GetName().c_str(), out_node->GetName().c_str(), node->GetName().c_str(), - node->GetType().c_str()); - GELOGE(GRAPH_FAILED, "[Add][ControlEdge] from %s to %s failed, when isolating node %s type %s", - in_node->GetName().c_str(), out_node->GetName().c_str(), node->GetName().c_str(), - node->GetType().c_str()); - return GRAPH_FAILED; - } - } - } - } - return GRAPH_SUCCESS; -} -template -graphStatus ReplaceOutDataAnchors(const OutDataAnchorVisitor &new_outs, - const OutDataAnchorVisitor &old_outs, - const std::vector &outputs_map) { - const auto new_out_size = new_outs.size(); - if (new_out_size < outputs_map.size()) { - REPORT_INNER_ERR_MSG("E18888", - "Failed to replace out data anchors, the actual size %zu is less than " - "the mapping size %zu", - new_out_size, outputs_map.size()); - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] Failed to replace out data anchors, the actual size %zu is less than " - "the mapping size %zu", new_out_size, outputs_map.size()); - return GRAPH_PARAM_INVALID; - } - for (size_t i = 0U; i < new_out_size; ++i) { - auto &new_out_anchor = new_outs.at(i); - if (new_out_anchor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", - "Failed to replace out data anchors, " - "the out data anchor on new node is null, index %zu", - i); - GELOGE(GRAPH_FAILED, "[Check][Param] Failed to replace out data anchors, " - "the out data anchor on new node is null, index %zu", i); - return GRAPH_FAILED; - } - if (i >= outputs_map.size()) { - continue; - } - const auto old_index = outputs_map.at(i); - if ((old_index < 0) || (static_cast(old_index) >= old_outs.size())) { - continue; - } - - const OutDataAnchorPtr &old_out_anchor = old_outs.at(static_cast(old_index)); - if (old_out_anchor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", - "Failed to replace out data anchors, " - "the out data anchor on old node is null, index %d", - old_index); - GELOGE(GRAPH_FAILED, "[Check][Param] Failed to replace out data anchors, " - "the out data anchor on old node is null, index %d", old_index); - return GRAPH_FAILED; - } - const auto ret = ReplaceOutDataAnchor(new_out_anchor, old_out_anchor); - if (ret != GRAPH_SUCCESS) { - return ret; - } - } - - return GRAPH_SUCCESS; -} -template -graphStatus DoReplaceInDataAnchors(const InDataAnchorVisitor &new_ins, - const InDataAnchorVisitor &old_ins, - const std::vector &inputs_map, bool need_keep_origin = false) { - const auto new_in_size = new_ins.size(); - if (new_in_size < inputs_map.size()) { - REPORT_INNER_ERR_MSG("E18888", - "Failed to replace in data anchors, " - "the actual size %zu is less than the mapping size %zu", - new_in_size, inputs_map.size()); - GELOGE(GRAPH_FAILED, "[Check][Param] Failed to replace in data anchors, " - "the actual size %zu is less than the mapping size %zu", new_in_size, inputs_map.size()); - return GRAPH_PARAM_INVALID; - } - - for (size_t i = 0U; i < new_in_size; ++i) { - auto &new_in_anchor = new_ins.at(i); - if (new_in_anchor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", - "Failed to replace in data anchors, " - "the out data anchor on new node is null, index %zu", - i); - GELOGE(GRAPH_FAILED, "[Check][Param] Failed to replace in data anchors, " - "the out data anchor on new node is null, index %zu", i); - return GRAPH_FAILED; - } - if (i >= inputs_map.size()) { - continue; - } - const auto old_index = inputs_map.at(i); - if ((old_index < 0) || (static_cast(old_index) >= old_ins.size())) { - continue; - } - const InDataAnchorPtr &old_in_anchor = old_ins.at(static_cast(old_index)); - if (old_in_anchor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", - "Failed to replace in data anchors, " - "the out data anchor on old node is null, index %d", - old_index); - GELOGE(GRAPH_FAILED, "[Check][Param] Failed to replace in data anchors, " - "the out data anchor on old node is null, index %d", old_index); - return GRAPH_FAILED; - } - - const auto peer_out_anchor = old_in_anchor->GetPeerOutAnchor(); - if (peer_out_anchor == nullptr) { - continue; - } - auto ret = GRAPH_SUCCESS; - if (!need_keep_origin) { - ret = peer_out_anchor->Unlink(old_in_anchor); - if (ret != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Failed to unlink old anchors, unlink from %s(%d) to %s(%d)", - GetNodeNameByAnchor(peer_out_anchor.get()).c_str(), peer_out_anchor->GetIdx(), - GetNodeNameByAnchor(old_in_anchor.get()).c_str(), old_in_anchor->GetIdx()); - GELOGE(GRAPH_FAILED, "[Remove][Link] Failed to unlink old anchors, unlink from %s(%d) to %s(%d)", - GetNodeNameByAnchor(peer_out_anchor.get()).c_str(), peer_out_anchor->GetIdx(), - GetNodeNameByAnchor(old_in_anchor.get()).c_str(), old_in_anchor->GetIdx()); - return GRAPH_FAILED; - } - } - ret = peer_out_anchor->LinkTo(new_in_anchor); - if (ret != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Failed to link new anchors, link from %s(%d) to %s(%d)", - GetNodeNameByAnchor(peer_out_anchor.get()).c_str(), peer_out_anchor->GetIdx(), - GetNodeNameByAnchor(old_in_anchor.get()).c_str(), old_in_anchor->GetIdx()); - GELOGE(GRAPH_FAILED, "[Create][Link]Failed to link new anchors, link from %s(%d) to %s(%d)", - GetNodeNameByAnchor(peer_out_anchor.get()).c_str(), peer_out_anchor->GetIdx(), - GetNodeNameByAnchor(old_in_anchor.get()).c_str(), old_in_anchor->GetIdx()); - return GRAPH_FAILED; - } - } - return GRAPH_SUCCESS; -} - -graphStatus ReplaceControlAnchors(const NodePtr &new_node, const NodePtr &old_node) { - GE_CHECK_NOTNULL(new_node); - GE_CHECK_NOTNULL(new_node->GetInControlAnchor()); - GE_CHECK_NOTNULL(old_node); - GE_CHECK_NOTNULL(old_node->GetInControlAnchor()); - const auto peer_out_anchors = old_node->GetInControlAnchor()->GetPeerAnchors(); - const auto new_in_control_anchor = new_node->GetInControlAnchor(); - const auto exists_out_anchors = new_in_control_anchor->GetPeerAnchors(); - const auto exists_out_anchors_set = std::set(exists_out_anchors.begin(), exists_out_anchors.end()); - for (const auto &peer_out_anchor : peer_out_anchors) { - if (peer_out_anchor == nullptr) { - continue; - } - if (exists_out_anchors_set.count(peer_out_anchor) > 0U) { - continue; - } - const auto ret = GraphUtils::AddEdge(peer_out_anchor, new_in_control_anchor); - if (ret != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Add edge from %s to %s failed, ret:%d", - peer_out_anchor->GetOwnerNode()->GetName().c_str(), - new_in_control_anchor->GetOwnerNode()->GetName().c_str(), ret); - GELOGE(GRAPH_FAILED, "[Add][Edge] from %s to %s failed, ret:%d", - peer_out_anchor->GetOwnerNode()->GetName().c_str(), - new_in_control_anchor->GetOwnerNode()->GetName().c_str(), ret); - return GRAPH_FAILED; - } - } - const auto old_out_control_anchor = old_node->GetOutControlAnchor(); - GE_CHECK_NOTNULL(old_out_control_anchor); - const auto peer_in_anchors = old_out_control_anchor->GetPeerAnchors(); - const auto new_out_control_anchor = new_node->GetOutControlAnchor(); - GE_CHECK_NOTNULL(new_out_control_anchor); - auto exists_in_anchors = new_out_control_anchor->GetPeerAnchors(); - const auto exists_in_anchors_set = std::set(exists_in_anchors.begin(), exists_in_anchors.end()); - for (const auto &peer_in_anchor : peer_in_anchors) { - if (peer_in_anchor == nullptr) { - continue; - } - if (exists_in_anchors_set.count(peer_in_anchor) > 0U) { - continue; - } - const auto ret = GraphUtils::AddEdge(new_out_control_anchor, peer_in_anchor); - if (ret != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "AddEdge from %s to %s failed, ret:%d", - new_out_control_anchor->GetOwnerNode()->GetName().c_str(), - peer_in_anchor->GetOwnerNode()->GetName().c_str(), ret); - GELOGE(GRAPH_FAILED, "[Add][Edge] from %s to %s failed, ret:%d", - new_out_control_anchor->GetOwnerNode()->GetName().c_str(), - peer_in_anchor->GetOwnerNode()->GetName().c_str(), ret); - return GRAPH_FAILED; - } - } - - return GRAPH_SUCCESS; -} - -// check refdata in subgraph is ref from inner data -graphStatus CheckIsRefFromInnerData(const OutDataAnchorPtr &out_data_anchor, NodePtr &inner_data, - bool &is_ref_from_innerdata) { - is_ref_from_innerdata = false; - const auto owner_node = out_data_anchor->GetOwnerNode(); - if (owner_node->GetType() != REFDATA) { - return GRAPH_SUCCESS; - } - GE_ASSERT_NOTNULL(owner_node->GetOwnerComputeGraph()); - if (owner_node->GetOwnerComputeGraph()->GetParentNode() == nullptr) { - return GRAPH_SUCCESS; - } - - NodePtr peer_in_ctrl_inner_data = nullptr; - for (const auto &peer_out_ctrl : owner_node->GetInControlAnchor()->GetPeerOutControlAnchors()) { - const auto peer_in_ctrl_node = peer_out_ctrl->GetOwnerNode(); - GE_ASSERT_NOTNULL(peer_in_ctrl_node); - if (OpTypeUtils::IsSubgraphInnerData(peer_in_ctrl_node->GetOpDesc())) { - peer_in_ctrl_inner_data = peer_in_ctrl_node; - break; - } - } - GE_ASSERT_NOTNULL(peer_in_ctrl_inner_data, - "Invalid graph. Refdata[%s] in subgraph[%s] should has one control edge from inner data.", - owner_node->GetNamePtr(), owner_node->GetOwnerComputeGraph()->GetName().c_str()); - inner_data = peer_in_ctrl_inner_data; - is_ref_from_innerdata = true; - return GRAPH_SUCCESS; -} - -graphStatus CheckIsRefFromRefData(const OutDataAnchorPtr &out_data_anchor, NodePtr &refed_node, - bool &is_ref_from_refdata) { - is_ref_from_refdata = false; - const auto owner_node = out_data_anchor->GetOwnerNode(); - const auto out_desc = owner_node->GetOpDesc()->GetOutputDescPtr(static_cast(out_data_anchor->GetIdx())); - GE_ASSERT_NOTNULL(out_desc); - std::string ref_var_src_var_name; - bool has_ref_attr = ge::AttrUtils::GetStr(out_desc, REF_VAR_SRC_VAR_NAME, ref_var_src_var_name); - if (!has_ref_attr) { - return GRAPH_SUCCESS; - } - // find src ref_data_node - const auto &ower_graph = owner_node->GetOwnerComputeGraph(); - GE_ASSERT_NOTNULL(ower_graph); - const auto ref_data_node = ower_graph->FindNode(ref_var_src_var_name); - if (ref_data_node == nullptr) { - GELOGW("Can not find refdata named %s. Please check ref relation on graph.", ref_var_src_var_name.c_str()); - return GRAPH_SUCCESS; - } - if (ref_data_node->GetType() != REFDATA) { - return GRAPH_SUCCESS; - } - refed_node = ref_data_node; - is_ref_from_refdata = true; - return GRAPH_SUCCESS; -} - -bool IsNeedOptimizeWithNoOp(const size_t in_size, const size_t out_size) { - return ((in_size * out_size) > kNoOpOptimizeThreshold) && ((in_size * out_size) > (in_size + out_size)); -} - -graphStatus RelinkControlNodeWithNoOpOptimize(const NodePtr &node, const std::set &in_nodes, - const std::set &out_nodes, - InNodesToOut &connected_data_in_to_out, - InNodesToOut &connected_data_out_to_in) { - GE_CHECK_NOTNULL(node); - const auto in_node_size = in_nodes.size(); - const auto out_node_size = out_nodes.size(); - GELOGD("Relink control node with NoOp optimize for [%s][%s], as in_node_size is %zu, out_node_size is %zu", - node->GetNamePtr(), node->GetTypePtr(), in_node_size, out_node_size); - std::vector noop_in_nodes{}; - std::vector noop_out_nodes{}; - for (const auto &in_node : in_nodes) { - GE_ASSERT_NOTNULL(in_node); - const auto &iter = connected_data_in_to_out.find(in_node); - if ((iter != connected_data_in_to_out.end()) && (iter->second.size() == out_node_size)) { - GELOGD("The node %s will not add out ctrl to NoOp, as already link to all out nodes", in_node->GetNamePtr()); - continue; - } - noop_in_nodes.emplace_back(in_node); - GELOGD("The node %s will add ctrl to NoOp", in_node->GetNamePtr()); - } - for (const auto &out_node : out_nodes) { - GE_ASSERT_NOTNULL(out_node); - const auto &iter = connected_data_out_to_in.find(out_node); - if ((iter != connected_data_out_to_in.end()) && (iter->second.size() == in_node_size)) { - GELOGD("The node %s will not add in ctrl from NoOp, as already link to all in nodes", out_node->GetNamePtr()); - continue; - } - noop_out_nodes.emplace_back(out_node); - GELOGD("The node %s will add in ctrl from NoOp", out_node->GetNamePtr()); - } - - if (noop_in_nodes.empty() || noop_out_nodes.empty()) { - return GRAPH_SUCCESS; - } - - const auto &graph = node->GetOwnerComputeGraph(); - GE_ASSERT_NOTNULL(graph); - const auto &noop = graph->AddNode(OpDescBuilder(node->GetName() + "_noop", NOOP).Build()); - GE_ASSERT_NOTNULL(noop); - for (const auto &in_node : noop_in_nodes) { - GE_ASSERT_GRAPH_SUCCESS(GraphUtils::AddEdge(in_node->GetOutControlAnchor(), noop->GetInControlAnchor()), - "Add ControlEdge from %s to %s failed, when isolating node %s type %s", - in_node->GetNamePtr(), noop->GetNamePtr(), node->GetNamePtr(), node->GetTypePtr()); - } - - for (const auto &out_node : noop_out_nodes) { - GE_ASSERT_GRAPH_SUCCESS(GraphUtils::AddEdge(noop->GetOutControlAnchor(), out_node->GetInControlAnchor()), - "Add ControlEdge from %s to %s failed, when isolating node %s type %s", noop->GetNamePtr(), - out_node->GetNamePtr(), node->GetNamePtr(), node->GetTypePtr()); - } - return GRAPH_SUCCESS; -} -} // namespace - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::IsolateNode(const NodePtr &node, - const std::vector &io_map) { - if (node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param node is nullptr, check invalid."); - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] Failed to isolate node(null)"); - return GRAPH_PARAM_INVALID; - } - - /// We must get full connections info before re-link data io, because the data - /// edges may be unlinked when relink data io - std::set in_nodes{}; - std::set out_nodes{}; - const auto in_nodes_to_out = GetFullConnectIONodes(node, in_nodes, out_nodes); - - InNodesToOut data_in_to_out; - OutNodesToIn data_out_to_in; - auto ret = RelinkDataIO(node, io_map, data_in_to_out, data_out_to_in); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Relink][DataIO] failed, node %s type %s", node->GetName().c_str(), node->GetType().c_str()); - return ret; - } - - if (IsNeedOptimizeWithNoOp(in_nodes.size(), out_nodes.size())) { - ret = RelinkControlNodeWithNoOpOptimize(node, in_nodes, out_nodes, data_in_to_out, data_out_to_in); - } else { - ret = RelinkControlNodeIfNeed(node, in_nodes_to_out, data_in_to_out); - } - if (ret != GRAPH_SUCCESS) { - return ret; - } - NodeUtils::UnlinkAll(*node); - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::IsolateNode(const NodePtr &node, const std::initializer_list &io_map) { - return IsolateNode(node, std::vector(io_map)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::IsolateNodeOneIO(const NodePtr &node) { - if (node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param node is nullptr, check invalid."); - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] incorrect parameter. node is invalid"); - return GRAPH_PARAM_INVALID; - } - if (node->GetAllInDataAnchorsSize() != 1U) { - return GRAPH_PARAM_INVALID; - } - if (node->GetAllOutDataAnchorsSize() != 1U) { - return GRAPH_PARAM_INVALID; - } - return IsolateNode(node, {0}); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node, - const std::vector &inputs_map, - const std::vector &outputs_map) { - if ((new_node == nullptr) || (old_node == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param new_node or old_node is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Parameter is nullptr"); - return GRAPH_PARAM_INVALID; - } - auto ret = ReplaceNodeDataAnchors(new_node, old_node, inputs_map, outputs_map); - if (ret != GRAPH_SUCCESS) { - // The error log was printed in `ReplaceNodeDataAnchors` - return GRAPH_FAILED; - } - ret = ReplaceControlAnchors(new_node, old_node); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Replace][ControlAnchors] failed when replace node from old node %s type %s " - "to new node %s type %s", old_node->GetName().c_str(), old_node->GetType().c_str(), - new_node->GetName().c_str(), new_node->GetType().c_str()); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::ReplaceNodeAnchors( - const NodePtr &new_node, const NodePtr &old_node, const std::initializer_list inputs_map, - const std::initializer_list outputs_map) { - return ReplaceNodeAnchors(new_node, old_node, - std::vector(inputs_map), std::vector(outputs_map)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node, - const std::initializer_list inputs_map, - const std::initializer_list outputs_map) { - return ReplaceNodeDataAnchors(new_node, old_node, - std::vector(inputs_map), std::vector(outputs_map)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node, - const std::vector &inputs_map, - const std::vector &outputs_map) { - if ((new_node == nullptr) || (old_node == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param new_node or old_node is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Parameter is nullptr"); - return GRAPH_PARAM_INVALID; - } - - auto ret = ReplaceOutDataAnchors(new_node->GetAllOutDataAnchors(), old_node->GetAllOutDataAnchors(), outputs_map); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Replace][OutDataAnchors] failed when replace node from old node %s type %s " - "to new node %s type %s", old_node->GetName().c_str(), old_node->GetType().c_str(), - new_node->GetName().c_str(), new_node->GetType().c_str()); - return GRAPH_FAILED; - } - ret = DoReplaceInDataAnchors(new_node->GetAllInDataAnchors(), old_node->GetAllInDataAnchors(), inputs_map); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Replace][InDataAnchors] failed when replace node from old node %s type %s " - "to new node %s type %s", old_node->GetName().c_str(), old_node->GetType().c_str(), - new_node->GetName().c_str(), new_node->GetType().c_str()); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -// 检查并插入owner图是否唯一 -bool CheckAndInsertGraph(const NodePtr &node, - std::set &owner_graph, - ComputeGraph **graph_has_inserted_to_set) { - auto graph = node->GetOwnerComputeGraphBarePtr(); - if (graph != nullptr) { - if (owner_graph.empty()) { - GE_ASSERT_TRUE(owner_graph.insert(graph).second); - *graph_has_inserted_to_set = graph; - } else { - GE_ASSERT_EQ(owner_graph.size(), 1U); - *graph_has_inserted_to_set = *(owner_graph.begin()); - if (owner_graph.find(graph) == owner_graph.end()) { - GELOGE(GRAPH_FAILED, - "Node %s has diff owner graph %s with before nodes's graph %s", - node->GetNamePtr(), - graph->GetName().c_str(), - (*graph_has_inserted_to_set)->GetName().c_str()); - return false; - } - } - } - return true; -} - -graphStatus ExtractAndCheckInDataAnchorsByOrder(const std::vector &nodes, - std::vector &in_data_anchors, - std::set &owner_graph) { - in_data_anchors.clear(); - static ComputeGraph *graph_has_inserted_to_set = nullptr; - for (const auto &node: nodes) { - GE_ASSERT_NOTNULL(node); - if (!CheckAndInsertGraph(node, owner_graph, &graph_has_inserted_to_set)) { - return GRAPH_FAILED; - } - for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { - in_data_anchors.push_back(in_data_anchor); - } - } - return GRAPH_SUCCESS; -} - -graphStatus ExtractAndCheckOutDataAnchorsByOrder(const std::vector &nodes, - std::vector &out_data_anchors, - std::set &owner_graph) { - out_data_anchors.clear(); - static ComputeGraph *graph_has_inserted_to_set = nullptr; - for (const auto &node: nodes) { - GE_ASSERT_NOTNULL(node); - if (!CheckAndInsertGraph(node, owner_graph, &graph_has_inserted_to_set)) { - return GRAPH_FAILED; - } - for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { - out_data_anchors.push_back(out_data_anchor); - } - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::InheritExecutionOrder(const std::vector &new_nodes, - const std::vector &old_nodes, - const ComputeGraphPtr &graph, - bool need_convert_data_edges_2_ctrl_edges) { - GE_ASSERT_NOTNULL(graph); - GE_ASSERT_TRUE(!old_nodes.empty()); - GE_ASSERT_TRUE(!new_nodes.empty()); - const auto &first_new_old = new_nodes.front(); - const auto &last_new_old = new_nodes.back(); - // 防止name重名 - const auto &noop_in = graph->AddNode(OpDescBuilder("noop_in_" + first_new_old->GetName(), NOOP).Build()); - const auto &noop_out = graph->AddNode(OpDescBuilder("noop_out_" + last_new_old->GetName(), NOOP).Build()); - // 注意old_nodes内部的控制关系不需要带到noop上 - NodeFilter node_filter = - [&old_nodes](const Node &node) { - return std::find(old_nodes.begin(), old_nodes.end(), node.shared_from_this()) == old_nodes.end(); - }; - for (const auto &old_node : old_nodes) { - GE_ASSERT_GRAPH_SUCCESS(CopyInCtrlEdges(old_node, noop_in, node_filter)); - GE_ASSERT_GRAPH_SUCCESS(CopyOutCtrlEdges(old_node, noop_out, node_filter)); - if (need_convert_data_edges_2_ctrl_edges) { - GE_ASSERT_GRAPH_SUCCESS(ConvertInDataEdgesToInCtrlEdges(old_node, noop_in, node_filter)); - GE_ASSERT_GRAPH_SUCCESS(ConvertOutDataEdgesToOutCtrlEdges(old_node, noop_out, node_filter)); - } - } - if (noop_in->GetInControlNodesSize() > 0U) { - for (const auto &new_node : new_nodes) { - GE_ASSERT_GRAPH_SUCCESS(GraphUtils::AddEdge(noop_in->GetOutControlAnchor(), new_node->GetInControlAnchor())); - } - } else { - GE_ASSERT_GRAPH_SUCCESS(RemoveJustNode(graph, noop_in)); - } - if (noop_out->GetOutControlNodesSize() > 0U) { - for (const auto &new_node : new_nodes) { - GE_ASSERT_GRAPH_SUCCESS(GraphUtils::AddEdge(new_node->GetOutControlAnchor(), noop_out->GetInControlAnchor())); - } - } else { - GE_ASSERT_GRAPH_SUCCESS(RemoveJustNode(graph, noop_out)); - } - return SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::ReplaceNodesDataAnchors(const std::vector &new_nodes, - const std::vector &old_nodes, - const std::vector &inputs_map, - const std::vector &outputs_map) { - GE_ASSERT_GRAPH_SUCCESS(ReplaceNodesInDataAnchors(new_nodes, old_nodes, inputs_map)); - GE_ASSERT_GRAPH_SUCCESS(ReplaceNodesOutDataAnchors(new_nodes, old_nodes, outputs_map)); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -graphStatus GraphUtils::ReplaceNodesInDataAnchors(const std::vector &new_nodes, - const std::vector &old_nodes, - const std::vector &inputs_map) { - GE_ASSERT_TRUE(!old_nodes.empty()); - GE_ASSERT_TRUE(!new_nodes.empty()); - std::vector old_nodes_in_data_anchors; - std::set owner_graph; - GE_ASSERT_GRAPH_SUCCESS(ExtractAndCheckInDataAnchorsByOrder(old_nodes, - old_nodes_in_data_anchors, - owner_graph)); - std::vector new_nodes_in_data_anchors; - GE_ASSERT_GRAPH_SUCCESS(ExtractAndCheckInDataAnchorsByOrder(new_nodes, - new_nodes_in_data_anchors, - owner_graph)); - GE_ASSERT_EQ(owner_graph.size(), 1U); - return DoReplaceInDataAnchors(new_nodes_in_data_anchors, old_nodes_in_data_anchors, inputs_map); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -graphStatus GraphUtils::ReplaceNodesOutDataAnchors(const std::vector &new_nodes, - const std::vector &old_nodes, - const std::vector &outputs_map) { - GE_ASSERT_TRUE(!old_nodes.empty()); - GE_ASSERT_TRUE(!new_nodes.empty()); - std::vector old_nodes_out_data_anchors; - std::set owner_graph; - GE_ASSERT_GRAPH_SUCCESS(ExtractAndCheckOutDataAnchorsByOrder(old_nodes, - old_nodes_out_data_anchors, - owner_graph)); - std::vector new_nodes_out_data_anchors; - GE_ASSERT_GRAPH_SUCCESS(ExtractAndCheckOutDataAnchorsByOrder(new_nodes, - new_nodes_out_data_anchors, - owner_graph)); - GE_ASSERT_EQ(owner_graph.size(), 1U); - return ReplaceOutDataAnchors(new_nodes_out_data_anchors, old_nodes_out_data_anchors, outputs_map); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -graphStatus GraphUtils::CopyNodesInDataAnchors(const std::vector &new_nodes, - const std::vector &old_nodes, - const std::vector &inputs_map) { - GE_ASSERT_TRUE(!old_nodes.empty()); - GE_ASSERT_TRUE(!new_nodes.empty()); - std::vector old_nodes_in_data_anchors; - std::set owner_graph; - GE_ASSERT_GRAPH_SUCCESS(ExtractAndCheckInDataAnchorsByOrder(old_nodes, - old_nodes_in_data_anchors, - owner_graph)); - std::vector new_nodes_in_data_anchors; - GE_ASSERT_GRAPH_SUCCESS(ExtractAndCheckInDataAnchorsByOrder(new_nodes, - new_nodes_in_data_anchors, - owner_graph)); - GE_ASSERT_EQ(owner_graph.size(), 1U); - return DoReplaceInDataAnchors(new_nodes_in_data_anchors, old_nodes_in_data_anchors, inputs_map, true); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyInCtrlEdges(const NodePtr &src_node, - const NodePtr &dst_node) { - return CopyInCtrlEdges(src_node, dst_node, nullptr); -} - -graphStatus AddCtrlEdges(const vector> &src_nodes, - const vector> &dst_nodes) { - for (const auto &dst_node : dst_nodes) { - GE_ASSERT_NOTNULL(dst_node); - std::unordered_set exist_in_ctrl_nodes_set; - auto exist_in_ctrl_nodes = dst_node->GetInControlNodes(); - if (!exist_in_ctrl_nodes.empty()) { - exist_in_ctrl_nodes_set.insert(exist_in_ctrl_nodes.begin(), exist_in_ctrl_nodes.end()); - } - - const auto dst_ctrl = dst_node->GetInControlAnchor(); - for (const auto &in_node : src_nodes) { - GE_ASSERT_NOTNULL(in_node); - if (exist_in_ctrl_nodes_set.count(in_node) > 0U) { - continue; - } - const auto ret = GraphUtils::AddEdge(in_node->GetOutControlAnchor(), dst_ctrl); - if (ret != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Add ControlEdge from %s to %s failed", in_node->GetName().c_str(), - dst_node->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Add][ControlEdge] from %s to %s failed", - in_node->GetName().c_str(), dst_node->GetName().c_str()); - return ret; - } - } - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyInCtrlEdges(const NodePtr &src_node, - const NodePtr &dst_node, - const NodeFilter &node_filter) { - if ((src_node == nullptr) || (dst_node == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param src_node or dst_node is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Parameter is nullptr"); - return GRAPH_PARAM_INVALID; - } - const auto src_ctrl_in_nodes = NodeUtils::GetInControlNodes(*src_node, node_filter); - if (src_ctrl_in_nodes.empty()) { - return GRAPH_SUCCESS; - } - return AddCtrlEdges(src_ctrl_in_nodes, {dst_node}); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::MoveInCtrlEdges(const NodePtr &src_node, - const NodePtr &dst_node) { - if ((src_node == nullptr) || (dst_node == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param src_node or dst_node is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Parameter is nullptr"); - return GRAPH_FAILED; - } - const auto ret = CopyInCtrlEdges(src_node, dst_node); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Copy][InCtrlEdges] failed, ret:%d, src_node:%s, dst_node:%s", - ret, src_node->GetName().c_str(), dst_node->GetName().c_str()); - return ret; - } - GE_CHECK_NOTNULL(src_node->GetInControlAnchor()); - src_node->GetInControlAnchor()->UnlinkAll(); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyOutCtrlEdges(const NodePtr &src_node, - const NodePtr &dst_node) { - return CopyOutCtrlEdges(src_node, dst_node, nullptr); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyOutCtrlEdges(const NodePtr &src_node, - const NodePtr &dst_node, - const NodeFilter &node_filter) { - if ((src_node == nullptr) || (dst_node == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param src_node or dst_node is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Parameter is nullptr"); - return GRAPH_FAILED; - } - const auto &out_ctrl_nodes = NodeUtils::GetOutControlNodes(*src_node, node_filter); - if (out_ctrl_nodes.empty()) { - return GRAPH_SUCCESS; - } - return AddCtrlEdges({dst_node}, out_ctrl_nodes); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::MoveOutCtrlEdges(NodePtr &src_node, - NodePtr &dst_node) { - if ((src_node == nullptr) || (dst_node == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param src_node or dst_node is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Parameter is nullptr"); - return GRAPH_FAILED; - } - const auto ret = CopyOutCtrlEdges(src_node, dst_node); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Copy][OutCtrlEdges] failed, ret:%d", ret); - return ret; - } - GE_CHECK_NOTNULL(src_node->GetOutControlAnchor()); - src_node->GetOutControlAnchor()->UnlinkAll(); - return GRAPH_SUCCESS; -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::ConvertInDataEdgesToInCtrlEdges(const NodePtr &src_node, - const NodePtr &dst_node, - const NodeFilter &node_filter) { - if ((src_node == nullptr) || (dst_node == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param src_node or dst_node is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Parameter is nullptr"); - return GRAPH_PARAM_INVALID; - } - const auto src_in_data_nodes = NodeUtils::GetInDataNodes(*src_node, node_filter); - if (src_in_data_nodes.empty()) { - return GRAPH_SUCCESS; - } - return AddCtrlEdges(src_in_data_nodes, {dst_node}); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::ConvertOutDataEdgesToOutCtrlEdges(const NodePtr &src_node, - const NodePtr &dst_node, - const NodeFilter &node_filter) { - if ((src_node == nullptr) || (dst_node == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param src_node or dst_node is nullptr, check invalid"); - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] Parameter is nullptr"); - return GRAPH_PARAM_INVALID; - } - const auto &src_out_data_nodes = NodeUtils::GetOutDataNodes(*src_node, node_filter); - if (src_out_data_nodes.empty()) { - return GRAPH_SUCCESS; - } - return AddCtrlEdges({dst_node}, src_out_data_nodes); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::MoveNodesToGraphAfterTargetNode( - const ComputeGraphPtr &target_graph, - const NodePtr &target_node, const ComputeGraphPtr &src_graph) { - GE_ASSERT_NOTNULL(target_node); - GE_ASSERT_NOTNULL(target_graph); - GE_ASSERT_NOTNULL(src_graph); - // 将expandGraph上的算子移动到targetGraph - auto target_graph_impl = target_graph->impl_; - GE_ASSERT_NOTNULL(target_graph_impl); - auto target_iter = - std::find(target_graph_impl->nodes_.begin(), target_graph_impl->nodes_.end(), target_node); - GE_ASSERT_TRUE(target_iter != target_graph_impl->nodes_.end(), - "Target node: %s should in target graph: %s", target_node->GetNamePtr(), target_graph->GetName().c_str()); - target_iter = next(target_iter); - for (const auto &node : src_graph->GetDirectNode()) { - GE_ASSERT_NOTNULL(node); - // 输入输出算子不挪 - if ((node->GetType() == DATA) || (node->GetType() == NETOUTPUT)) { - continue; - } - target_graph_impl->InsertToNodeList(target_iter, node); - node->SetHostNode(target_graph_impl->is_valid_flag_); - GE_ASSERT_SUCCESS(node->SetOwnerComputeGraph(target_graph), - "SetOwnerComputeGraph:%s failed for node:%s", target_graph->GetName().c_str(), - node->GetNamePtr()); - auto op_desc = node->GetOpDesc(); - GE_ASSERT_NOTNULL(op_desc); - const auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); - for (size_t i = 0UL; i < sub_graph_names.size(); i++) { - auto sub_graph = src_graph->GetSubgraph(sub_graph_names[i]); - GE_ASSERT_NOTNULL(sub_graph); - sub_graph->SetParentGraph(target_graph); - } - } - // 将expandGraph中的剩余子图继承到根图 - const auto root_graph = GraphUtils::FindRootGraph(target_graph); - for (const auto &subgraph : src_graph->GetAllSubgraphs()) { - root_graph->AddSubGraph(subgraph); - } - return SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::ExpandNodeWithGraph( - const NodePtr &target_node, const ComputeGraphPtr &expand_graph) { - GE_ASSERT_NOTNULL(target_node); - auto target_graph = target_node->GetOwnerComputeGraph(); - GE_ASSERT_NOTNULL(target_graph); - // 搬移节点 - GE_ASSERT_SUCCESS(GraphUtils::MoveNodesToGraphAfterTargetNode(target_graph, target_node, expand_graph)); - // 处理输入连边关系 - for (const auto &input_node : expand_graph->GetInputNodes()) { - GE_ASSERT_SUCCESS(ReLinkInputDataEdge(input_node, target_node)); - NodeUtils::UnlinkAll(*input_node); - } - // 处理输出连边关系 - // 如果expand图存在netoutput,则先断开netoutput - const auto net_output_node = expand_graph->FindFirstNodeMatchType(NETOUTPUT); - if (net_output_node != nullptr) { - NodeUtils::UnlinkAll(*net_output_node); - } - const auto out_nodes_info = expand_graph->GetGraphOutNodesInfo(); - for (size_t index = 0UL; index < out_nodes_info.size(); index++) { - GE_ASSERT_SUCCESS(RelinkOutputNodeEdge(out_nodes_info[index].first, - out_nodes_info[index].second, target_node, index)); - } - // 处理输出信息的映射关系 - const auto target_graph_out_node_info = target_graph->GetGraphOutNodesInfo(); - const auto sub_graph_out_node_info = expand_graph->GetGraphOutNodesInfo(); - std::vector> new_output_info; - for (const auto &out_node_info : target_graph_out_node_info) { - if (out_node_info.first == target_node) { - GE_ASSERT_TRUE(static_cast(out_node_info.second) < sub_graph_out_node_info.size()); - (void)new_output_info.emplace_back(sub_graph_out_node_info[out_node_info.second]); - } else { - (void)new_output_info.emplace_back(out_node_info); - } - } - // 删除原算子 - NodeUtils::UnlinkAll(*target_node); - GE_ASSERT_SUCCESS(RemoveNodeWithoutRelink(target_graph, target_node)); - target_graph->SetGraphOutNodesInfo(new_output_info); - return SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AppendInputNode(const ComputeGraphPtr &graph, - const NodePtr &node) { - if (graph->AddInputNode(node) == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "AddInputNode %s(%s) failed, graph:%s", node->GetName().c_str(), - node->GetType().c_str(), graph->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Add][InputNode] %s(%s) failed, graph:%s", node->GetName().c_str(), - node->GetType().c_str(), graph->GetName().c_str()); - return GRAPH_FAILED; - } - graph->SetInputSize(graph->GetInputSize() + 1U); - if (graph->impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "Graph impl is nullptr."); - return GRAPH_FAILED; - } - graph->impl_->inputs_order_.emplace_back(node->GetName()); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -ComputeGraphPtr GraphUtils::FindRootGraph(ComputeGraphPtr graph) { - ComputeGraphPtr result = nullptr; - while (graph != nullptr) { - result = std::move(graph); - graph = result->GetParentGraph(); - } - return result; -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyComputeGraph( - const ComputeGraphPtr &src_compute_graph, const NodeFilter &node_filter, const GraphFilter &graph_filter, - const AttrFilter &attr_filter, ComputeGraphPtr &dst_compute_graph) { - GE_CHECK_NOTNULL(src_compute_graph); - if (src_compute_graph->GetParentGraph() != nullptr) { - GELOGE(GRAPH_FAILED, - "[Check][RootGraph] Only support copy root graph, current graph name:%s, " - "parent graph name:%s.", - src_compute_graph->GetName().c_str(), src_compute_graph->GetParentGraph()->GetName().c_str()); - return GRAPH_FAILED; - } - - const int32_t depth = 0; - std::map old_2_new_node; - std::map old_2_new_op_desc; - const graphStatus ret = CopyComputeGraph(src_compute_graph, node_filter, graph_filter, attr_filter, dst_compute_graph, - old_2_new_node, old_2_new_op_desc, depth); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Copy][ComputeGraphPtr] failed, ret:%d.", ret); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr GraphUtils::CloneOpDesc(const ConstOpDescPtr &org_op_desc) { - GE_CHECK_NOTNULL_EXEC(org_op_desc, return nullptr); - const auto op_def = ComGraphMakeShared(); - GE_CHECK_NOTNULL_EXEC(op_def, return nullptr); - - ModelSerializeImp imp; - (void)imp.SerializeOpDesc(org_op_desc, op_def.get()); - - imp.SetProtobufOwner(op_def); - OpDescPtr op_desc = nullptr; - GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), - REPORT_INNER_ERR_MSG("E18888", "UnserializeOpDesc failed"); - return op_desc, "[Call][UnserializeOpDesc] op_desc unserialize failed"); - - GE_CHECK_NOTNULL_EXEC(op_desc->impl_, return nullptr); - op_desc->ext_attrs_ = org_op_desc->ext_attrs_; - - // This function may be called by some passes of fusion engine, in this condition, do not need these attribute - if (!op_desc->impl_->input_name_idx_.empty()) { - op_desc->impl_->input_name_idx_.clear(); - } - if (!op_desc->impl_->output_name_idx_.empty()) { - op_desc->impl_->output_name_idx_.clear(); - } - op_desc->impl_->MutableIRMeta() = IRMetaData(op_desc->GetName()); - return op_desc; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr GraphUtils::CopyOpDesc(const ConstOpDescPtr &org_op_desc, - const AttrFilter &attr_filter) { - GE_ASSERT_NOTNULL(org_op_desc); - GE_ASSERT_NOTNULL(org_op_desc->impl_); - const auto op_def = ComGraphMakeShared(); - GE_ASSERT_NOTNULL(op_def); - - ModelSerializeImp imp; - (void) imp.SerializeOpDesc(org_op_desc, op_def.get()); - imp.SetProtobufOwner(op_def); - OpDescPtr op_desc = nullptr; - GE_ASSERT_TRUE(imp.UnserializeOpDesc(op_desc, *op_def)); - // weight's data call `Clone` for deep copy if needed - if (ConstantUtils::IsConstant(op_desc) && ((attr_filter == nullptr) || attr_filter(*op_desc, ATTR_NAME_WEIGHTS))) { - ConstGeTensorPtr weight = nullptr; - if (AttrUtils::GetTensor(org_op_desc, ATTR_NAME_WEIGHTS, weight)) { - const GeTensor copy_weight = weight->Clone(); - GE_ASSERT_TRUE(AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, copy_weight)); - GELOGD("Clone ATTR_NAME_WEIGHTS for node:%s success.", op_desc->GetName().c_str()); - } - } - // remove attr by attr_filter - for (const auto &attr_name : op_desc->GetAllAttrNames()) { - if ((attr_filter != nullptr) && (!attr_filter(*op_desc, attr_name))) { - GE_ASSERT_GRAPH_SUCCESS(op_desc->DelAttr(attr_name)); - } - } - GE_ASSERT_NOTNULL(op_desc->impl_); - op_desc->ext_attrs_ = org_op_desc->ext_attrs_; - op_desc->impl_->input_name_idx_.insert(org_op_desc->impl_->input_name_idx_.cbegin(), - org_op_desc->impl_->input_name_idx_.cend()); - op_desc->impl_->MutableIRMeta() = org_op_desc->impl_->GetIRMeta(); - op_desc->impl_->output_name_idx_.insert(org_op_desc->impl_->output_name_idx_.cbegin(), - org_op_desc->impl_->output_name_idx_.cend()); - - op_desc->impl_->infer_func_ = org_op_desc->impl_->infer_func_; - op_desc->impl_->infer_format_func_ = org_op_desc->impl_->infer_format_func_; - op_desc->impl_->verifier_func_ = org_op_desc->impl_->verifier_func_; - - return op_desc; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr GraphUtils::CopyOpDesc(const ConstOpDescPtr &org_op_desc) { - return CopyOpDesc(org_op_desc, nullptr); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::CopyComputeGraph(const ComputeGraphPtr &src_compute_graph, ComputeGraphPtr &dst_compute_graph) { - return CopyComputeGraph(src_compute_graph, nullptr, nullptr, nullptr, dst_compute_graph); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::CopyComputeGraph(const ComputeGraphPtr &src_compute_graph, ComputeGraphPtr &dst_compute_graph, - std::map &node_old_2_new, - std::map &op_desc_old_2_new, const int32_t depth) { - return CopyComputeGraph(src_compute_graph, nullptr, nullptr, nullptr, dst_compute_graph, node_old_2_new, - op_desc_old_2_new, depth); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyOpAndSubgraph( - const ComputeGraphPtr &src_compute_graph, const NodeFilter &node_filter, const GraphFilter &graph_filter, - const AttrFilter &attr_filter, ComputeGraphPtr &dst_compute_graph, std::map &node_old_2_new, - std::map &op_desc_old_2_new, std::unordered_map &all_new_nodes, - const int32_t depth) { - GE_CHECK_NOTNULL(src_compute_graph); - GE_CHECK_NOTNULL(dst_compute_graph); - const auto dst_root_compute_graph = FindRootGraph(dst_compute_graph); - GE_CHECK_NOTNULL(dst_root_compute_graph); - const auto src_root_compute_graph = FindRootGraph(src_compute_graph); - GE_CHECK_NOTNULL(src_root_compute_graph); - for (const auto &n : src_compute_graph->GetDirectNode()) { - if ((node_filter != nullptr) && (!node_filter(*n))) { - continue; - } - const auto &op_desc = GraphUtils::CopyOpDesc(n->GetOpDesc(), attr_filter); - GE_CHECK_NOTNULL(op_desc); - GE_CHECK_NOTNULL(op_desc->impl_); - op_desc->SetName(n->GetName()); - op_desc->impl_->MutableIRMeta() = n->GetOpDesc()->impl_->GetIRMeta(); - op_desc->impl_->subgraph_names_to_index_ = n->GetOpDesc()->impl_->subgraph_names_to_index_; - op_desc->impl_->subgraph_instance_names_ = n->GetOpDesc()->impl_->subgraph_instance_names_; - - const NodePtr node = dst_compute_graph->AddNode(op_desc, n->GetOpDesc()->GetId()); - if (node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "AddNode %s to graph:%s failed", op_desc->GetName().c_str(), - dst_compute_graph->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Add][Node][%s] to graph:%s failed", - op_desc->GetName().c_str(), dst_compute_graph->GetName().c_str()); - return GRAPH_FAILED; - } - all_new_nodes[node->GetName()] = node; - node_old_2_new[n] = node; - op_desc_old_2_new[n->GetOpDesc()] = op_desc; - - // copy subgraph from old graph to new graph - const auto &subgraph_names = n->GetOpDesc()->GetSubgraphInstanceNames(); - const size_t subgraph_num = subgraph_names.size(); - for (size_t subgraph_idx = 0U; subgraph_idx < subgraph_num; ++subgraph_idx) { - const auto &subgraph_name = subgraph_names[subgraph_num - 1U - subgraph_idx]; - const auto src_subgraph = src_root_compute_graph->GetSubgraph(subgraph_name); - if ((src_subgraph == nullptr) && subgraph_name.empty()) { - GELOGD("node=%s subgraph is empty, subgraph_idx=%zu, subgraph_num=%zu.", n->GetName().c_str(), subgraph_idx, - subgraph_num); - continue; - } - GE_CHECK_NOTNULL(src_subgraph, ", get subgraph[%s] failed, node=%s.", subgraph_name.c_str(), - n->GetName().c_str()); - if ((graph_filter != nullptr) && - (!graph_filter(*src_subgraph->GetParentNode(), src_subgraph->GetName().c_str(), src_subgraph))) { - op_desc->RemoveSubgraphInstanceName(subgraph_name); - continue; - } - ComputeGraphPtr dst_subgraph = ComGraphMakeShared(src_subgraph->GetName()); - GE_CHECK_NOTNULL(dst_subgraph); - dst_subgraph->SetParentGraph(dst_compute_graph); - std::map sub_node_old_2_new; - std::map sub_op_desc_old_2_new; - const graphStatus ret = CopyComputeGraph(src_subgraph, node_filter, graph_filter, attr_filter, dst_subgraph, - sub_node_old_2_new, sub_op_desc_old_2_new, depth + 1); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Copy][SubGraph] %s of parent node:%s failed.", - src_subgraph->GetName().c_str(), node->GetName().c_str()); - return GRAPH_FAILED; - } - (void)dst_root_compute_graph->AddSubGraph(dst_subgraph); - dst_subgraph->SetParentNode(node); - } - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyComputeGraph( - const ComputeGraphPtr &src_compute_graph, const NodeFilter &node_filter, const GraphFilter &graph_filter, - const AttrFilter &attr_filter, ComputeGraphPtr &dst_compute_graph, std::map &node_old_2_new, - std::map &op_desc_old_2_new, const int32_t depth) { - GE_CHECK_NOTNULL(dst_compute_graph); - GE_CHECK_NOTNULL(src_compute_graph); - - if (depth >= kCopyGraphMaxRecursionDepth) { - REPORT_INNER_ERR_MSG("E18888", "param depth:%d >= %d(allow max subgraphs)", depth, kCopyGraphMaxRecursionDepth); - GELOGE(GRAPH_FAILED, "[Check][Param]exist too much subgraphs:%d > %d(allow max subgraphs)", depth, - kCopyGraphMaxRecursionDepth); - return GRAPH_FAILED; - } - // copy op and subgraph from old graph to new graph - std::unordered_map all_new_nodes; - graphStatus ret = CopyOpAndSubgraph(src_compute_graph, node_filter, graph_filter, attr_filter, dst_compute_graph, - node_old_2_new, op_desc_old_2_new, all_new_nodes, depth); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Copy][OpAndSubGraph] failed."); - return GRAPH_FAILED; - } - - for (const auto &n : src_compute_graph->GetDirectNode()) { - if ((node_filter != nullptr) && (!node_filter(*n))) { - continue; - } - if (RelinkGraphEdges(n, "", all_new_nodes) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Relink][Edges] failed."); - return GRAPH_FAILED; - } - } - // To keep subgraph consistent with the source graph - std::vector new_subgraphs; - const auto old_subgraphs = src_compute_graph->GetAllSubgraphs(); - for (const auto &sub_graph : old_subgraphs) { - if ((graph_filter != nullptr) && - (!graph_filter(*sub_graph->GetParentNode(), sub_graph->GetName().c_str(), sub_graph))) { - continue; - } - const auto new_subgraph = dst_compute_graph->GetSubgraph(sub_graph->GetName()); - GE_CHECK_NOTNULL(new_subgraph); - GELOGD("Copy new subgraph:%s.", sub_graph->GetName().c_str()); - new_subgraphs.push_back(new_subgraph); - } - dst_compute_graph->SetAllSubgraphs(new_subgraphs); - - // copy members from old graph to new graph - ret = CopyMembers(src_compute_graph, dst_compute_graph, all_new_nodes); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Copy][Members] failed, ret:%d.", ret); - return GRAPH_FAILED; - } - - // inherit all attr from old graph to new graph - InheritOriginalAttr(src_compute_graph, dst_compute_graph); - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -graphStatus GraphUtils::CopyMembers(const ComputeGraphPtr &src_compute_graph, - ComputeGraphPtr &dst_compute_graph, - const std::unordered_map &all_new_nodes) { - if ((src_compute_graph == nullptr) || (src_compute_graph->impl_ == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "Param src_compute_graph is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Src compute graph is nullptr."); - return GRAPH_FAILED; - } - if ((dst_compute_graph == nullptr) || (dst_compute_graph->impl_ == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "Param dst_compute_graph is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Dst compute graph is nullptr."); - return GRAPH_FAILED; - } - // copy info of output nodes from old graph to new graph. - const std::vector> &out_nodes_info = src_compute_graph->GetGraphOutNodesInfo(); - std::vector> new_out_nodes_info; - for (const auto &info : out_nodes_info) { - const auto it = all_new_nodes.find(info.first->GetName()); - if (it == all_new_nodes.end()) { - GELOGW("[Check][Param] Find output node:%s failed.", info.first->GetName().c_str()); - continue; - } - new_out_nodes_info.emplace_back(it->second, info.second); - } - dst_compute_graph->SetGraphOutNodesInfo(new_out_nodes_info); - - // copy info of input nodes from old graph to new graph. - const ComputeGraph::Vistor &input_nodes = src_compute_graph->GetInputNodes(); - for (const auto &node : input_nodes) { - const auto it = all_new_nodes.find(node->GetName()); - if (it == all_new_nodes.end()) { - GELOGW("[Check][Param] Find input node:%s failed.", node->GetName().c_str()); - continue; - } - (void)dst_compute_graph->AddInputNode(it->second); - } - - // copy target info nodes from old graph to new graph. - const std::vector &src_traget_nodes_info = src_compute_graph->GetGraphTargetNodesInfo(); - std::vector dst_traget_nodes_info; - for (const auto &node : src_traget_nodes_info) { - const auto it = all_new_nodes.find(node->GetName()); - if (it == all_new_nodes.end()) { - GELOGW("[Check][Param] Find target info node:%s failed.", node->GetName().c_str()); - continue; - } - dst_traget_nodes_info.emplace_back(it->second); - } - dst_compute_graph->SetGraphTargetNodesInfo(dst_traget_nodes_info); - - // graph属性序列化 - dst_compute_graph->impl_->attrs_ = src_compute_graph->impl_->attrs_; - - // copy other members from old graph to new graph. - dst_compute_graph->impl_->data_format_ = src_compute_graph->impl_->data_format_; - dst_compute_graph->impl_->need_iteration_ = src_compute_graph->impl_->need_iteration_; - dst_compute_graph->impl_->is_summary_graph_ = src_compute_graph->impl_->is_summary_graph_; - dst_compute_graph->impl_->is_valid_flag_ = src_compute_graph->impl_->is_valid_flag_; - dst_compute_graph->impl_->input_size_ = src_compute_graph->impl_->input_size_; - dst_compute_graph->impl_->output_size_ = src_compute_graph->impl_->output_size_; - dst_compute_graph->impl_->inputs_order_ = src_compute_graph->impl_->inputs_order_; - dst_compute_graph->impl_->op_name_map_ = src_compute_graph->impl_->op_name_map_; - dst_compute_graph->impl_->out_nodes_map_ = src_compute_graph->impl_->out_nodes_map_; - dst_compute_graph->impl_->params_share_map_ = src_compute_graph->impl_->params_share_map_; - dst_compute_graph->impl_->graph_id_ = src_compute_graph->impl_->graph_id_; - return GRAPH_SUCCESS; -} - -/// Make a copy of ComputeGraph. -/// @param graph: original graph. -/// @param suffix: node name suffix of new graph. -/// @param output_nodes: output nodes of new graph. -/// @return ComputeGraphPtr -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -ComputeGraphPtr GraphUtils::CloneGraph(const ComputeGraphPtr &graph, const std::string &suffix, - std::vector &input_nodes, std::vector &output_nodes) { - GE_CHK_BOOL_EXEC(graph != nullptr, REPORT_INNER_ERR_MSG("E18888", "param graph is nullptr, check invalid."); - return nullptr, "[Check][Param] Original graph is null"); - ComputeGraphPtr new_graph = ComGraphMakeShared(graph->GetName()); - GE_CHK_BOOL_EXEC(new_graph != nullptr, - REPORT_INNER_ERR_MSG("E18888", "create computegraph %s failed.", graph->GetName().c_str()); - return nullptr, "[Create][ComputeGraph] %s failed", graph->GetName().c_str()); - - std::unordered_map all_new_nodes; - for (const auto &n : graph->GetDirectNode()) { - const OpDescPtr op_desc = GraphUtils::CopyOpDesc(n->GetOpDesc()); - GE_CHK_BOOL_EXEC(op_desc != nullptr, - REPORT_INNER_ERR_MSG("E18888", "Create node:%s failed.", n->GetOpDesc()->GetName().c_str()); - return nullptr, "[Create][Node] %s failed", n->GetOpDesc()->GetName().c_str()); - - if (CopyTensorAttrs(op_desc, n) != GRAPH_SUCCESS) { - return nullptr; - } - - const bool is_const_op = (n->GetType() == CONSTANT) || (n->GetType() == CONSTANTOP); - if (is_const_op) { - GeTensorPtr weight = nullptr; - if (!AttrUtils::MutableTensor(n->GetOpDesc(), ATTR_NAME_WEIGHTS, weight)) { - GELOGI("Can not find attr ATTR_NAME_WEIGHTS for node:%s.", n->GetName().c_str()); - continue; - } - const GeTensor copy_weight = weight->Clone(); - if (!AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, copy_weight)) { - REPORT_INNER_ERR_MSG("E18888", "Clone ATTR_NAME_WEIGHTS for node:%s failed.", op_desc->GetName().c_str()); - GELOGE(INTERNAL_ERROR, "[Set][Tensor] Clone ATTR_NAME_WEIGHTS for node:%s failed.", op_desc->GetName().c_str()); - return nullptr; - } - GELOGD("Clone ATTR_NAME_WEIGHTS for node:%s success.", op_desc->GetName().c_str()); - } - - op_desc->SetName(n->GetName() + suffix); - NodePtr node = new_graph->AddNode(op_desc); - GE_CHK_BOOL_EXEC(node != nullptr, - REPORT_INNER_ERR_MSG("E18888", "add node %s to graph:%s failed", op_desc->GetName().c_str(), - new_graph->GetName().c_str()); - return nullptr, "[Add][Node] [%s] to graph:%s failed", - op_desc->GetName().c_str(), new_graph->GetName().c_str()); - all_new_nodes[node->GetName()] = node; - - if (OpTypeUtils::IsDataNode(node->GetType())) { - input_nodes.emplace_back(node); - } else if (node->GetType() == NETOUTPUT) { - output_nodes.emplace_back(node); - } else { - // do nothing - } - } - - for (const auto &n : graph->GetDirectNode()) { - if (RelinkGraphEdges(n, suffix, all_new_nodes) != GRAPH_SUCCESS) { - return nullptr; - } - } - - // inherit all attr from old graph to new graph - InheritOriginalAttr(graph, new_graph); - - // copy info of output nodes from old graph to new graph. - const std::vector> out_nodes_info = graph->GetGraphOutNodesInfo(); - std::vector> new_out_nodes_info; - for (const auto &info : out_nodes_info) { - const auto it = all_new_nodes.find(info.first->GetName()); - if (it != all_new_nodes.end()) { - new_out_nodes_info.emplace_back(it->second, info.second); - } - } - new_graph->SetGraphOutNodesInfo(new_out_nodes_info); - return new_graph; -} - -/// Copy tensor attribute to new node. -/// @param [in] dst_node: cloned node. -/// @param [in] src_node: original node. -/// @return success: GRAPH_SUCESS -graphStatus GraphUtils::CopyTensorAttrs(const OpDescPtr &dst_desc, const NodePtr &src_node) { - if (dst_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param dst_desc is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Input param dst node not valid"); - return GRAPH_FAILED; - } - if ((src_node == nullptr) || (src_node->GetOpDesc() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param src_node is nullptr or it's opdesc is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Input param src node not valid"); - return GRAPH_FAILED; - } - - const auto &src_desc = src_node->GetOpDesc(); - dst_desc->CopyAttrsFrom(*src_desc); - - for (uint32_t i = 0U; i < src_node->GetAllInDataAnchorsSize(); ++i) { - const auto input_desc = dst_desc->MutableInputDesc(i); - if (input_desc == nullptr) { - continue; - } - input_desc->CopyAttrsFrom(src_desc->GetInputDesc(i)); - } - - for (uint32_t i = 0U; i < src_node->GetAllOutDataAnchorsSize(); ++i) { - const auto output_desc = dst_desc->MutableOutputDesc(i); - if (output_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Param dst node:%s not valid, output_desc[%d] is nullptr", - dst_desc->GetName().c_str(), i); - GELOGE(GRAPH_FAILED, "[Check][Param] Param dst node:%s not valid", dst_desc->GetName().c_str()); - return GRAPH_FAILED; - } - output_desc->CopyAttrsFrom(src_desc->GetOutputDesc(i)); - } - - return GRAPH_SUCCESS; -} - -/// Relink all edges for cloned ComputeGraph. -/// @param [in] node: original node. -/// @param [in] suffix: node name suffix of new node. -/// @param [in] all_nodes: all nodes in new graph. -/// @return success: GRAPH_SUCESS -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -graphStatus GraphUtils::RelinkGraphEdges(const NodePtr &node, const std::string &suffix, - const std::unordered_map &all_nodes) { - if ((node == nullptr) || (node->GetOpDesc() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param node is nullptr or it's opdesc is nullptr. check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Input node not valid"); - return GRAPH_FAILED; - } - - auto it = all_nodes.find(node->GetName() + suffix); - if (it == all_nodes.end()) { - REPORT_INNER_ERR_MSG("E18888", "all_nodes not contain node:%s.", node->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] node[%s] not found", node->GetName().c_str()); - return GRAPH_FAILED; - } - const auto &new_node = it->second; - - // traversing from the parent node can be completely restored in the original one-to-many order. - for (const auto &out_anchor : node->GetAllOutDataAnchors()) { - GE_CHK_BOOL_EXEC(out_anchor != nullptr, - REPORT_INNER_ERR_MSG("E18888", "out data anchor is null, node:%s.", node->GetName().c_str()); - return GRAPH_FAILED, "[Check][Param] Out data anchor is null, node:%s", node->GetName().c_str()); - for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) { - GE_CHECK_NOTNULL(peer_in_anchor); - GE_CHK_BOOL_EXEC(peer_in_anchor->GetOwnerNodeBarePtr() != nullptr, - REPORT_INNER_ERR_MSG("E18888", "Peer in node:%s is null", node->GetName().c_str()); - return GRAPH_FAILED, "Peer in node:%s is null", node->GetName().c_str()); - it = all_nodes.find(peer_in_anchor->GetOwnerNodeBarePtr()->GetName() + suffix); - if (it == all_nodes.end()) { - GELOGW("[Check][Param] node[%s] not found", peer_in_anchor->GetOwnerNode()->GetName().c_str()); - continue; - } - const auto &new_peer_in_node = it->second; - const auto ret = GraphUtils::AddEdge(new_node->GetOutAnchor(out_anchor->GetIdx()), - new_peer_in_node->GetInAnchor(peer_in_anchor->GetIdx())); - GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "add data edge from %s to %s failed", new_node->GetName().c_str(), - new_peer_in_node->GetName().c_str()); - return GRAPH_FAILED, "[Invoke][AddEdge] link data edge failed[%s to %s]", - new_node->GetName().c_str(), new_peer_in_node->GetName().c_str()); - } - } - - if (node->GetOutControlAnchor() != nullptr) { - for (const auto peer_in_control_anchor : node->GetOutControlAnchor()->GetPeerAnchorsPtr()) { - GE_CHECK_NOTNULL(peer_in_control_anchor); - GE_CHK_BOOL_EXEC(peer_in_control_anchor->GetOwnerNodeBarePtr() != nullptr, - REPORT_INNER_ERR_MSG("E18888", "Peer out node is null"); - return GRAPH_FAILED, "[Invoke][GetOwnerNode] Peer out node is null"); - it = all_nodes.find(peer_in_control_anchor->GetOwnerNodeBarePtr()->GetName() + suffix); - if (it == all_nodes.end()) { - GELOGW("[Check][Param] node[%s] not found", peer_in_control_anchor->GetOwnerNode()->GetName().c_str()); - continue; - } - const auto &new_peer_in_node = it->second; - const auto ret = GraphUtils::AddEdge(new_node->GetOutControlAnchor(), - new_peer_in_node->GetInAnchor(peer_in_control_anchor->GetIdx())); - GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "add control edge from %s to %s failed.", - new_node->GetName().c_str(), new_peer_in_node->GetName().c_str()); - return GRAPH_FAILED, "[Invoke][AddEdge] link control edge failed[%s to %s]", - new_node->GetName().c_str(), new_peer_in_node->GetName().c_str()); - } - } - return GRAPH_SUCCESS; -} - -graphStatus GraphUtils::GetRefMapping(const ComputeGraphPtr &graph, SymbolToAnchors &symbol_to_anchors, - AnchorToSymbol &anchor_to_symbol) { - GE_CHECK_NOTNULL(graph); - for (const auto &node : graph->GetAllNodes()) { - // in_data_anchor - GE_ASSERT_GRAPH_SUCCESS(HandleInAnchorMapping(graph, node, symbol_to_anchors, anchor_to_symbol), - "Find ref_mapping for in_data_anchors of node %s failed.", node->GetName().c_str()); - // out_data_anchor - GE_ASSERT_GRAPH_SUCCESS(HandleOutAnchorMapping(node, symbol_to_anchors, anchor_to_symbol), - "Find ref_mapping for out_data_anchors of node %s failed.", node->GetName().c_str()); - } - return GRAPH_SUCCESS; -} - -graphStatus GraphUtils::HandleInAnchorMapping(const ComputeGraphPtr &graph, const NodePtr &node, - SymbolToAnchors &symbol_to_anchors, - AnchorToSymbol &anchor_to_symbol) { - GE_CHECK_NOTNULL(node); - if (node->GetOwnerComputeGraph() != graph) { - // when curr graph is subgraph , to handle subgraph input/output ref mapping - if (NodeUtils::IsSubgraphOutput(node)) { - return HandleSubgraphOutput(node, symbol_to_anchors, anchor_to_symbol); - } - - if (NodeUtils::IsSubgraphInput(node)) { - return HandleSubgraphInput(node, symbol_to_anchors, anchor_to_symbol); - } - } - - const std::string &type = node->GetType(); - if ((type == MERGE) || (type == STREAMMERGE)) { - return HandleMergeInput(node, symbol_to_anchors, anchor_to_symbol); - } - - for (const auto in_data_anchor : node->GetAllInDataAnchorsPtr()) { - const NodeIndexIO cur_node_info(node, in_data_anchor->GetIdx(), kIn); - const OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - if (peer_out_anchor == nullptr) { - const std::string &symbol = cur_node_info.ToString(); - GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); - symbol_to_anchors[symbol] = { cur_node_info }; - anchor_to_symbol[symbol] = symbol; - } else { - const NodeIndexIO exist_node_info(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut); - GE_ASSERT_GRAPH_SUCCESS(UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol)); - } - } - - return GRAPH_SUCCESS; -} - -graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node, - SymbolToAnchors &symbol_to_anchors, - AnchorToSymbol &anchor_to_symbol) { - GE_CHECK_NOTNULL(node); - for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { - const NodeIndexIO cur_node_info(node, out_data_anchor->GetIdx(), kOut); - if (anchor_to_symbol.find(cur_node_info.ToString()) != anchor_to_symbol.end()) { - continue; - } - - NodePtr refed_node = nullptr; - bool is_ref_from_other = false; - GE_ASSERT_GRAPH_SUCCESS(CheckIsRefFromOther(out_data_anchor, refed_node, is_ref_from_other)); - NodeIndexIO exist_ref_data_info(refed_node, 0U, kOut); - if (is_ref_from_other && (anchor_to_symbol.find(exist_ref_data_info.ToString()) != anchor_to_symbol.end())) { - GELOGD("Node %s output:%d is ref form node: %s.", node->GetName().c_str(), out_data_anchor->GetIdx(), - exist_ref_data_info.ToString().c_str()); - GE_ASSERT_GRAPH_SUCCESS( - UpdateRefMapping(cur_node_info, exist_ref_data_info, symbol_to_anchors, anchor_to_symbol)); - } - - // 这里ref from input和ref from refdata不冲突 - int32_t reuse_in_index = -1; - const bool reuse_input_flag = IsRefFromInput(out_data_anchor, reuse_in_index); - if (reuse_input_flag && (node->GetInDataAnchor(reuse_in_index) != nullptr)) { - const NodeIndexIO exist_node_info(node, reuse_in_index, kIn); - if (UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { - GE_LOGE("[Update][SymbolMapping] failed."); - return GRAPH_FAILED; - } - } else { - if (reuse_input_flag) { - GELOGW("[GetRefMapping][Check] Invalid reuse_input attr on output %d of node %s, please check attr reuse_input " - "and reuse_input_index", out_data_anchor->GetIdx(), node->GetName().c_str()); - } - const std::string &symbol = cur_node_info.ToString(); - GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); - (void)symbol_to_anchors.emplace(std::make_pair(symbol, std::list{ cur_node_info })); - (void)anchor_to_symbol.emplace(std::make_pair(symbol, symbol)); - } - } - - return GRAPH_SUCCESS; -} - -graphStatus GraphUtils::HandleSubgraphInput(const NodePtr &node, - SymbolToAnchors &symbol_to_anchors, - AnchorToSymbol &anchor_to_symbol) { - GE_CHECK_NOTNULL(node); - GE_CHECK_NOTNULL(node->GetOpDesc()); - - // Data in subgraph - uint32_t index = 0U; - if (!ge::AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index)) { - REPORT_INNER_ERR_MSG("E18888", "Get Attr ATTR_NAME_PARENT_NODE_INDEX failed, node:%s.", node->GetName().c_str()); - GE_LOGE("[Get][Attr] ATTR_NAME_PARENT_NODE_INDEX failed, node:%s.", node->GetName().c_str()); - return GRAPH_FAILED; - } - const NodePtr parent_node = node->GetOwnerComputeGraph()->GetParentNode(); - GE_CHECK_NOTNULL(parent_node); - const InDataAnchorPtr parent_in_anchor = parent_node->GetInDataAnchor(static_cast(index)); - GE_CHECK_NOTNULL(parent_in_anchor); - const OutDataAnchorPtr peer_out_anchor = parent_in_anchor->GetPeerOutAnchor(); - if (peer_out_anchor != nullptr) { - // Data has and only has one input - const NodeIndexIO cur_node_info(node, 0, kIn); - const NodeIndexIO exist_node_info(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut); - if (UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { - GE_LOGE("[Update][SymbolMapping] failed."); - return GRAPH_FAILED; - } - } - - return GRAPH_SUCCESS; -} - -graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, - SymbolToAnchors &symbol_to_anchors, - AnchorToSymbol &anchor_to_symbol) { - GE_CHECK_NOTNULL(node); - std::vector exist_node_infos; - std::vector cur_node_infos; - for (const auto in_data_anchor : node->GetAllInDataAnchorsPtr()) { - auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - if (peer_out_anchor == nullptr) { - std::string next_name; - if ((AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_NEXT_ITERATION, next_name)) && (!next_name.empty())) { - ComputeGraphPtr graph = node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(graph); - const ge::NodePtr next_node = FindNodeFromAllNodes(graph, next_name); - GE_CHECK_NOTNULL(next_node); - // NextIteration has and only has one output - peer_out_anchor = next_node->GetOutDataAnchor(0); - GE_CHECK_NOTNULL(peer_out_anchor); - cur_node_infos.emplace_back(NodeIndexIO(node, in_data_anchor->GetIdx(), kIn)); - cur_node_infos.emplace_back(NodeIndexIO(next_node, peer_out_anchor->GetIdx(), kOut)); - } - } else { - cur_node_infos.emplace_back(NodeIndexIO(node, in_data_anchor->GetIdx(), kIn)); - exist_node_infos.emplace_back(NodeIndexIO(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut)); - } - } - - size_t anchor_nums = 0U; - NodeIndexIO max_node_index_io(static_cast(nullptr), 0, kOut); - for (const auto &temp_node_info : exist_node_infos) { - const auto iter1 = anchor_to_symbol.find(temp_node_info.ToString()); - if (iter1 != anchor_to_symbol.end()) { - const std::string &temp_symbol = iter1->second; - const auto iter2 = symbol_to_anchors.find(temp_symbol); - if (iter2 != symbol_to_anchors.end()) { - if (iter2->second.size() > anchor_nums) { - max_node_index_io = temp_node_info; - anchor_nums = iter2->second.size(); - } - } - } - } - - std::string symbol; - for (const auto &temp_node_info : exist_node_infos) { - if ((UnionSymbolMapping(max_node_index_io, temp_node_info, symbol_to_anchors, anchor_to_symbol, symbol) != - GRAPH_SUCCESS) || - symbol.empty()) { - GE_LOGE("[Union][SymbolMap] anchor1:%s & anchor2:%s failed.", max_node_index_io.ToString().c_str(), - temp_node_info.ToString().c_str()); - return GRAPH_FAILED; - } - } - - const auto iter = symbol_to_anchors.find(symbol); - if (iter != symbol_to_anchors.end()) { - for (const auto &temp_node_info : cur_node_infos) { - GELOGD("Add anchor %s, symbol %s.", temp_node_info.ToString().c_str(), symbol.c_str()); - iter->second.emplace_back(temp_node_info); - (void)anchor_to_symbol.emplace(std::make_pair(temp_node_info.ToString(), symbol)); - } - } - - return GRAPH_SUCCESS; -} - -graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node, - SymbolToAnchors &symbol_to_anchors, - AnchorToSymbol &anchor_to_symbol) { - GE_CHECK_NOTNULL(node); - const ComputeGraphPtr owner_graph = node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(owner_graph); - const NodePtr parent_node = owner_graph->GetParentNode(); - GE_CHECK_NOTNULL(parent_node); - - const OpDescPtr op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - for (const auto &in_data_anchor : node->GetAllInDataAnchorsPtr()) { - const OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_out_anchor); - - const auto &in_tensor = op_desc->GetInputDescPtr(static_cast(in_data_anchor->GetIdx())); - uint32_t index = 0U; - if (!ge::AttrUtils::GetInt(in_tensor, ATTR_NAME_PARENT_NODE_INDEX, index)) { - continue; - } - GE_CHECK_NOTNULL(parent_node->GetOutDataAnchor(static_cast(index))); - // Union symbol of peer_out_anchor & parent_out_anchor - const NodeIndexIO peer_node_info(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut); - const NodeIndexIO parent_node_info(parent_node, index, kOut); - std::string symbol; - if ((UnionSymbolMapping(peer_node_info, parent_node_info, symbol_to_anchors, anchor_to_symbol, - symbol) != GRAPH_SUCCESS) || symbol.empty()) { - GE_LOGE("[Union][SymbolMap] anchor1:%s, and anchor2:%s failed.", - peer_node_info.ToString().c_str(), parent_node_info.ToString().c_str()); - return GRAPH_FAILED; - } - - NodeIndexIO cur_node_info(node, in_data_anchor->GetIdx(), kIn); - GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); - symbol_to_anchors[symbol].emplace_back(cur_node_info); - (void)anchor_to_symbol.emplace(std::make_pair(cur_node_info.ToString(), symbol)); - } - - return GRAPH_SUCCESS; -} - -graphStatus GraphUtils::UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2, - SymbolToAnchors &symbol_to_anchors, - AnchorToSymbol &anchor_to_symbol, std::string &symbol) { - const std::string &symbol1 = anchor_to_symbol[exist_node_info1.ToString()]; - const std::string &symbol2 = anchor_to_symbol[exist_node_info2.ToString()]; - if (symbol1 == symbol2) { - symbol = symbol1; - GELOGI("no need to union."); - return GRAPH_SUCCESS; - } - - const auto iter1 = symbol_to_anchors.find(symbol1); - const auto iter2 = symbol_to_anchors.find(symbol2); - if ((iter1 == symbol_to_anchors.end()) || (iter2 == symbol_to_anchors.end())) { - REPORT_INNER_ERR_MSG("E18888", "symbol %s or %s does not exist.", symbol1.c_str(), symbol2.c_str()); - GE_LOGE("[Check][Param] symbol %s or %s does not exist.", symbol1.c_str(), symbol2.c_str()); - return GRAPH_FAILED; - } - - auto &max_iter = ((iter1->second.size() > iter2->second.size()) ? iter1 : iter2); - auto &min_iter = ((iter1->second.size() > iter2->second.size()) ? iter2 : iter1); - symbol = ((iter1->second.size() > iter2->second.size()) ? symbol1 : symbol2); - const std::string min_symbol = ((iter1->second.size() > iter2->second.size()) ? symbol2 : symbol1); - for (auto &node_index_io : min_iter->second) { - GELOGD("Update anchor %s, symbol %s.", node_index_io.ToString().c_str(), symbol.c_str()); - max_iter->second.emplace_back(node_index_io); - const auto iter = anchor_to_symbol.find(node_index_io.ToString()); - GE_ASSERT_TRUE(iter != anchor_to_symbol.end(), "anchor %s does not exist in anchor_to_symbol.", - node_index_io.ToString().c_str()); - if (iter->second != min_symbol) { - GELOGW("[GetRefMapping][Check] not expected symbol of anchor %s, expect %s but %s exactly.", iter->first.c_str(), - min_symbol.c_str(), iter->second.c_str()); - } - iter->second = symbol; - } - - GELOGI("Union symbol %s and %s succ.", symbol.c_str(), min_symbol.c_str()); - (void)symbol_to_anchors.erase(min_iter); - return GRAPH_SUCCESS; -} - -graphStatus GraphUtils::UpdateRefMapping(const NodeIndexIO &cur_node_info, const NodeIndexIO &exist_node_info, - SymbolToAnchors &symbol_to_anchors, - AnchorToSymbol &anchor_to_symbol) { - const auto iter1 = anchor_to_symbol.find(exist_node_info.ToString()); - if (iter1 == anchor_to_symbol.end()) { - REPORT_INNER_ERR_MSG("E18888", "data_anchor %s is not visible before data_anchor %s, maybe TopoSorting is missing.", - exist_node_info.ToString().c_str(), cur_node_info.ToString().c_str()); - GE_LOGE("[Check][Param] data_anchor %s is not visible before data_anchor %s, maybe TopoSorting is missing.", - exist_node_info.ToString().c_str(), cur_node_info.ToString().c_str()); - return GRAPH_FAILED; - } - - const std::string &symbol = iter1->second; - const auto iter2 = symbol_to_anchors.find(symbol); - if (iter2 == symbol_to_anchors.end()) { - REPORT_INNER_ERR_MSG("E18888", "symbol %s does not exist in symbol_to_anchors.", symbol.c_str()); - GE_LOGE("[Check][Param] symbol %s not found.", symbol.c_str()); - return GRAPH_FAILED; - } - GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); - iter2->second.emplace_back(cur_node_info); - const auto ret = anchor_to_symbol.emplace(std::make_pair(cur_node_info.ToString(), symbol)); - GE_ASSERT_TRUE(ret.first != anchor_to_symbol.end(), "failed to insert anchor to symbol mapping."); - GE_ASSERT_TRUE(ret.first->second == symbol, "update anchor's symbol failed. cur_node_info: %s, old_symbol: %s, " - "new_symbol: %s", cur_node_info.ToString().c_str(), ret.first->second.c_str(), symbol.c_str()); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -NodePtr GraphUtils::FindNodeFromAllNodes(ComputeGraphPtr &graph, const std::string &name) { - const auto root_graph = FindRootGraph(graph); - if (root_graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param graph is nullptr,check invalid."); - GE_LOGE("[Check][Param] param graph is nullptr, check invalid"); - return nullptr; - } - - std::deque candidates; - - (void) candidates.insert(candidates.begin(), graph->impl_->nodes_.begin(), graph->impl_->nodes_.end()); - while (!candidates.empty()) { - NodePtr node = candidates.front(); - candidates.pop_front(); - if (node == nullptr) { - continue; - } - if (NodeUtils::IsNameEqual(node, name.c_str())) { - return node; - } - const auto op_desc = node->GetOpDescBarePtr(); - if (op_desc != nullptr) { - const auto &subgraph_names = op_desc->GetSubgraphInstanceNames(); - auto name_iter = subgraph_names.rbegin(); - while (name_iter != subgraph_names.rend()) { - const auto subgraph = root_graph->GetSubgraph(*name_iter); - if (subgraph != nullptr) { - (void) (candidates.insert(candidates.begin(), subgraph->impl_->nodes_.begin(), - subgraph->impl_->nodes_.end())); - } - ++name_iter; - } - } - } - return nullptr; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -std::vector GraphUtils::FindNodesByTypeFromAllNodes(ComputeGraphPtr &graph, const std::string &type) { - std::vector nodes; - const auto &root_graph = FindRootGraph(graph); - if (root_graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param graph is nullptr,check invalid."); - GE_LOGE("[Check][Param] param graph is nullptr, check invalid"); - return nodes; - } - - for (const auto &node : root_graph->GetAllNodes()) { - if (node == nullptr) { - continue; - } - if (NodeUtils::IsTypeEqual(node, type.c_str())) { - nodes.emplace_back(node); - } - } - - return nodes; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -std::vector GraphUtils::FindBareNodesByTypeFromAllNodes(ComputeGraphPtr &graph, const char_t *const type) { - const auto &root_graph = FindRootGraph(graph); - GE_ASSERT_NOTNULL(root_graph); - - std::vector nodes; - for (const auto node : root_graph->GetAllNodesPtr()) { - if (strcmp(node->GetTypePtr(), type) == 0) { - nodes.emplace_back(node); - } - } - return nodes; -} - -graphStatus GraphUtils::GetSubgraphsRecursively(const ComputeGraphPtr &graph, std::vector &subgraphs) { - const auto root_graph = GraphUtils::FindRootGraph(graph); - if (root_graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Failed to find root graph"); - GELOGE(GRAPH_FAILED, "[Get][Graph] Failed to find root graph"); - return GRAPH_FAILED; - } - if (graph == root_graph) { - subgraphs = graph->GetAllSubgraphs(); - return GRAPH_SUCCESS; - } - for (const auto &node : graph->GetAllNodes()) { - // op_desc of node should not be null - for (const auto &graph_name : node->GetOpDesc()->GetSubgraphInstanceNames()) { - const auto &subgraph = root_graph->GetSubgraph(graph_name); - if (subgraph == nullptr) { - GELOGW("[Get][Subgraph] subgraph %s of node %s is null", graph_name.c_str(), node->GetName().c_str()); - continue; - } - subgraphs.emplace_back(subgraph); - } - } - return GRAPH_SUCCESS; -} - -/// Check if out_data_anchor is reference of input -/// @param [in] out_data_anchor -/// @param [out] reuse_in_index -/// @return bool -bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t &reuse_in_index) { - if (out_data_anchor == nullptr) { - GELOGW("[Check][Param] out_data_anchor is null"); - return false; - } - const int32_t output_index = out_data_anchor->GetIdx(); - - // pass-through op - const auto node = out_data_anchor->GetOwnerNodeBarePtr(); - const std::string &type = node->GetType(); - static const std::unordered_set pass_through_types = {NETOUTPUT, WHILE, _WHILE, STATELESSWHILE}; - if ((pass_through_types.count(type) > 0U) || (NodeUtils::IsSubgraphInput(node))) { - reuse_in_index = output_index; - GELOGI("Pass-Through node name[%s] index[%u].", node->GetName().c_str(), reuse_in_index); - return true; - } - - // Merge op 0th output - const bool is_ge_local_op = ((type == MERGE) || (type == RESHAPE)) && (output_index == 0); - if (is_ge_local_op) { - reuse_in_index = 0; - GELOGI("%s name[%s] output_index[0] reuse input 0.", type.c_str(), node->GetName().c_str()); - return true; - } - - // ref op - // op_desc of node should not be null - const auto op_desc = node->GetOpDescBarePtr(); - bool is_ref = false; - (void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_REFERENCE, is_ref); - if (is_ref) { - const std::string &output_name = op_desc->GetOutputNameByIndex(static_cast(output_index)); - for (const auto &input_name : op_desc->GetAllInputNames()) { - if ((!input_name.empty()) && (output_name == input_name)) { - reuse_in_index = op_desc->GetInputIndexByName(input_name); - GELOGD("Reference name[%s] output[%s][%d] ref to input[%s][%d].", op_desc->GetName().c_str(), - output_name.c_str(), output_index, input_name.c_str(), reuse_in_index); - return true; - } - } - } - - // reuse input - const auto output_op_desc = op_desc->GetOutputDescPtr(static_cast(output_index)); - if (output_op_desc != nullptr) { - bool reuse_input = false; - if ((TensorUtils::GetReuseInput(*output_op_desc, reuse_input) == GRAPH_SUCCESS) && reuse_input) { - uint32_t reuse_input_index = 0U; - if (TensorUtils::GetReuseInputIndex(*output_op_desc, reuse_input_index) == GRAPH_SUCCESS) { - reuse_in_index = static_cast(reuse_input_index); - GELOGI("ReuseInput name[%s] output[%d] reuse input[%d].", op_desc->GetName().c_str(), output_index, - reuse_in_index); - return true; - } - } - } - // nopadding reuse - return IsNoPaddingRefFromInput(out_data_anchor, reuse_in_index); -} - -graphStatus GraphUtils::CheckIsRefFromOther(const OutDataAnchorPtr &out_data_anchor, NodePtr &refed_node, - bool &is_ref_from_other) { - GE_ASSERT_NOTNULL(out_data_anchor); - const auto owner_node = out_data_anchor->GetOwnerNode(); - GE_ASSERT_NOTNULL(owner_node); - bool is_ref_from_refdata = false; - bool is_ref_from_innerdata = false; - GE_ASSERT_GRAPH_SUCCESS(CheckIsRefFromRefData(out_data_anchor, refed_node, is_ref_from_refdata)); - GE_ASSERT_GRAPH_SUCCESS(CheckIsRefFromInnerData(out_data_anchor, refed_node, is_ref_from_innerdata)); - is_ref_from_other = (is_ref_from_refdata || is_ref_from_innerdata); - return GRAPH_SUCCESS; -} - -bool GraphUtils::IsNoPaddingRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t &reuse_in_index) { - const auto node = out_data_anchor->GetOwnerNodeBarePtr(); - // nopadding means output[0] reuse input[0], but as history reason, - // other output index also return true for mem assign in block_mem_assigner - bool attr_reuse = false; - bool is_input_continuous = false; - bool is_out_continuous = false; - (void)ge::AttrUtils::GetBool(node->GetOpDescBarePtr(), ATTR_NAME_NOPADDING_CONTINUOUS_INPUT, is_input_continuous); - (void)ge::AttrUtils::GetBool(node->GetOpDescBarePtr(), ATTR_NAME_NOPADDING_CONTINUOUS_OUTPUT, is_out_continuous); - const bool get_reuse_flag = ge::AttrUtils::GetBool(node->GetOpDescBarePtr(), ATTR_NAME_OUTPUT_REUSE_INPUT, - attr_reuse); - const bool is_no_padding_reuse_input = (is_input_continuous || is_out_continuous) && get_reuse_flag && attr_reuse; - if (is_no_padding_reuse_input) { - reuse_in_index = 0; - GELOGI("Nopadding ReuseInput name[%s] output[%d] reuse input[%d].", node->GetName().c_str(), - out_data_anchor->GetIdx(), reuse_in_index); - return true; - } - return false; -} - -bool GraphUtils::IsNodeInGraphRecursively(const ComputeGraphPtr &graph, const Node &node) { - auto parent_graph = node.GetOwnerComputeGraph(); - while (parent_graph != nullptr) { - if (parent_graph == graph) { - return true; - } - parent_graph = parent_graph->GetParentGraph(); - } - return false; -} - -/// Determine if the graph is a UNKNOWN_SHAPE graph based on whether the graph and all subgraphs -/// of the graph have UNKNOWN_SHAPE operators or not. -/// Note: This function will only look 'down' from the graph, not 'up'. For example, the following -/// scenario (K for known shape, U for unknown shape), ROOT graph is UNKNOWN_SHAPE while SUB graph is KNOWN_SHAPE -/// ROOT graph: A -----> B -----> C -/// K subgraph U -/// | -/// V -/// SUB graph: D --> E --> F -/// K K K -/// @param [in] graph -/// @return bool -bool GraphUtils::IsUnknownShapeGraph(const ComputeGraphPtr &graph) { - if (graph == nullptr) { - GELOGW("[Check][Param] Input graph is nullptr."); - return false; - } - for (const auto &node : graph->GetDirectNode()) { - bool is_unknown = false; - const auto ret = NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown); - if (ret != GRAPH_SUCCESS) { - GELOGW("[Check][UnknownGraph] Get unknown status failed, node name:%s, type:%s", node->GetName().c_str(), - node->GetType().c_str()); - continue; - } - if (is_unknown) { - GELOGD("Node %s, type %s is unknown shape in graph %s.", - node->GetName().c_str(), node->GetType().c_str(), graph->GetName().c_str()); - return true; - } - } - GELOGD("Graph %s does not have unknown shape node.", graph->GetName().c_str()); - return false; -} - -ComputeGraphPtr GraphUtils::BuildSubgraphWithNodes(const ComputeGraphPtr &graph, const std::set &nodes, - const std::string &subgraph_name) { - if (graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Graph is null"); - GELOGE(FAILED, "[Check][Param] graph is null"); - return nullptr; - } - return BuildSubgraphWithNodes(*graph, nodes, subgraph_name); -} - -ComputeGraphPtr GraphUtils::BuildSubgraphWithNodes(ComputeGraph &graph, const std::set &nodes, - const std::string &subgraph_name) { - if (nodes.empty()) { - GELOGW("nodes is empty, no need to build subgraph"); - return nullptr; - } - - GraphInfo graph_info; - BuildGraphInfoFromNodes(nodes, graph_info); - - const NodePtr graph_node = BuildSubgraphNode(graph, subgraph_name, graph_info); - if (graph_node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Build SubgraphNode failed, subgraph_name:%s.", subgraph_name.c_str()); - GELOGE(FAILED, "[Build][SubgraphNode] failed, subgraph_name:%s.", subgraph_name.c_str()); - return nullptr; - } - - const ComputeGraphPtr subgraph = BuildSubgraph(graph_node, graph_info, subgraph_name); - if (subgraph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Build Subgraph %s failed", subgraph_name.c_str()); - GELOGE(FAILED, "[Build][Subgraph] %s failed", subgraph_name.c_str()); - return nullptr; - } - const auto &root_graph = GraphUtils::FindRootGraph(graph_node->GetOwnerComputeGraph()); - if (root_graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Find root graph failed, graph:%s", graph.GetName().c_str()); - GELOGE(FAILED, "[Find][RootGraph] failed, graph:%s", graph.GetName().c_str()); - return nullptr; - } - if (root_graph->AddSubgraph(subgraph) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Add subgraph %s failed, root graph:%s", subgraph->GetName().c_str(), - root_graph->GetName().c_str()); - GELOGE(FAILED, "[Add][SubGraph] %s failed, root graph:%s", subgraph->GetName().c_str(), - root_graph->GetName().c_str()); - return nullptr; - } - - if ((RelinkDataEdges(graph_node, graph_info) != GRAPH_SUCCESS) || - (RelinkCtrlEdges(graph_node, graph_info) != GRAPH_SUCCESS)) { - REPORT_INNER_ERR_MSG("E18888", "ReLink edges for graph %s failed, graph_node:%s", graph.GetName().c_str(), - graph_node->GetName().c_str()); - GELOGE(FAILED, "[ReLink][Edges] for graph %s failed, graph_node:%s", graph.GetName().c_str(), - graph_node->GetName().c_str()); - return nullptr; - } - - for (const auto &node : nodes) { - // op_desc of node should not be null - const auto subgraph_names_inner = node->GetOpDesc()->GetSubgraphInstanceNames(); - for (const auto &subgraph_name_inner : subgraph_names_inner) { - node->GetOpDesc()->RemoveSubgraphInstanceName(subgraph_name_inner); - } - if (RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node) != GRAPH_SUCCESS) { - GELOGW("Remove node %s failed.", node->GetName().c_str()); - } - } - - return subgraph; -} -template -void GraphUtils::BuildGraphInfoFromNodes(const Container &nodes, GraphInfo &graph_info) { - // 将 nodes 转换为有序容器(std::vector),并按节点的唯一标识符排序 - std::vector ordered_nodes(nodes.begin(), nodes.end()); - - // 按节点的唯一标识符排序 - std::sort(ordered_nodes.begin(), ordered_nodes.end(), [](const NodePtr &a, const NodePtr &b) { - return a->GetName() < b->GetName(); - }); - - std::map data_input_index_map; - for (const auto &node : ordered_nodes) { - // graph nodes - (void)graph_info.nodes_.emplace(node); - // in data - BuildInDataEdgesFromNode(node, nodes, data_input_index_map, graph_info); - // out data - std::list peer_data_anchors; - for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { - peer_data_anchors.clear(); - const auto &peer_in_anchors = out_data_anchor->GetPeerInDataAnchors(); - (void)std::copy_if(peer_in_anchors.begin(), peer_in_anchors.end(), std::back_inserter(peer_data_anchors), - [nodes](const InDataAnchorPtr &peer_in_anchor) { - return nodes.count(peer_in_anchor->GetOwnerNode()) == 0UL; - }); - if (!peer_data_anchors.empty()) { - const size_t output_index = graph_info.data_outputs_.size(); - graph_info.data_outputs_[output_index] = std::make_pair(out_data_anchor, peer_data_anchors); - } - } - // in ctrl - for (const auto &in_ctrl_node : node->GetInControlNodes()) { - if (nodes.count(in_ctrl_node) == 0UL) { - graph_info.ctrl_inputs_.emplace_back(in_ctrl_node->GetOutControlAnchor(), node->GetInControlAnchor()); - } else { - graph_info.inner_ctrl_edges_.emplace_back(std::make_pair(in_ctrl_node->GetOutControlAnchor(), - node->GetInControlAnchor())); - } - } - // out ctrl - for (const auto &out_ctrl_node : node->GetOutControlNodes()) { - if (nodes.count(out_ctrl_node) == 0UL) { - graph_info.ctrl_outputs_.emplace_back(node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()); - } - } - } -} -template -void GraphUtils::BuildInDataEdgesFromNode(const NodePtr &node, const Container &nodes, - std::map &data_input_index_map, - GraphInfo &graph_info) { - for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { - OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - if (peer_out_anchor == nullptr) { - continue; - } - if (nodes.count(peer_out_anchor->GetOwnerNode()) == 0UL) { - size_t input_index; - if (data_input_index_map.count(peer_out_anchor) == 0UL) { - input_index = graph_info.data_inputs_.size(); - data_input_index_map[peer_out_anchor] = input_index; - graph_info.data_inputs_[input_index].first = peer_out_anchor; - } else { - input_index = data_input_index_map[peer_out_anchor]; - } - graph_info.data_inputs_[input_index].second.emplace_back(in_data_anchor); - } else { - graph_info.inner_data_edges_.emplace_back(std::make_pair(peer_out_anchor, in_data_anchor)); - } - } -} - -NodePtr GraphUtils::BuildSubgraphNode(ComputeGraph &graph, const std::string &graph_name, - const GraphInfo &graph_info) { - OpDescBuilder op_desc_builder(graph_name + "_" + PARTITIONEDCALL, PARTITIONEDCALL); - int32_t i = 0; - for (const auto &item : graph_info.data_inputs_) { - for (const auto &in_data_anchor : item.second.second) { - const auto input_desc = in_data_anchor->GetOwnerNodeBarePtr()->GetOpDesc(); - if (input_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "op_desc is null, node:%s", in_data_anchor->GetOwnerNode()->GetName().c_str()); - GELOGE(PARAM_INVALID, "[Check][Param] op_desc is null, node:%s", - in_data_anchor->GetOwnerNode()->GetName().c_str()); - return nullptr; - } - (void)op_desc_builder.AddInput("args" + std::to_string(i), - input_desc->GetInputDesc(static_cast(in_data_anchor->GetIdx()))); - i++; - } - } - for (const auto &item : graph_info.data_outputs_) { - const auto output_desc = item.second.first->GetOwnerNodeBarePtr()->GetOpDesc(); - if (output_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "op_desc is null, node:%s", item.second.first->GetOwnerNode()->GetName().c_str()); - GELOGE(PARAM_INVALID, "[Check][Param] op_desc is null, node:%s", - item.second.first->GetOwnerNode()->GetName().c_str()); - return nullptr; - } - (void)op_desc_builder.AddOutput("output" + std::to_string(item.first), - output_desc->GetOutputDesc(static_cast(item.second.first->GetIdx()))); - } - - const OpDescPtr op_desc = op_desc_builder.Build(); - if (op_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Create op_desc for subgraph node failed, name:%s.", graph_name.c_str()); - GELOGE(FAILED, "[Create][OpDesc] for subgraph node failed, name:%s.", graph_name.c_str()); - return nullptr; - } - - (void)op_desc->AddSubgraphName("f"); - (void)op_desc->SetSubgraphInstanceName(0U, graph_name); - - return graph.AddNode(op_desc); -} - -ComputeGraphPtr GraphUtils::BuildGraph(const GraphInfo &graph_info, const std::string &name) { - return BuildGraphInternal(graph_info, name, nullptr); // 普通图,parent_node 为 nullptr -} - -ComputeGraphPtr GraphUtils::BuildSubgraph(const NodePtr &subgraph_node, const GraphInfo &graph_info, - const std::string &subgraph_name) { - return BuildGraphInternal(graph_info, subgraph_name, subgraph_node); // 子图,传入 parent_node -} - -graphStatus GraphUtils::RelinkDataEdges(const NodePtr &subgraph_node, const GraphInfo &graph_info) { - // in data nodes - int32_t i = 0; - for (const auto &item : graph_info.data_inputs_) { - for (const auto &in_data_anchor : item.second.second) { - GE_CHK_STATUS_RET(item.second.first->Unlink(in_data_anchor), "[Remove][DataEdge] %s:%d->%s:%d failed", - item.second.first->GetOwnerNode()->GetName().c_str(), item.second.first->GetIdx(), - in_data_anchor->GetOwnerNode()->GetName().c_str(), in_data_anchor->GetIdx()); - GE_CHK_STATUS_RET(item.second.first->LinkTo(subgraph_node->GetInDataAnchor(i)), - "[Add][DataEdge] %s:%d->%s:%u failed.", - item.second.first->GetOwnerNode()->GetName().c_str(), - item.second.first->GetIdx(), subgraph_node->GetName().c_str(), item.first); - i++; - } - } - // out data nodes - for (const auto &item : graph_info.data_outputs_) { - const auto &out_data_anchor = subgraph_node->GetOutDataAnchor(static_cast(item.first)); - GE_CHECK_NOTNULL(out_data_anchor); - for (const auto &peer_in_anchor : item.second.second) { - GE_CHK_STATUS_RET(item.second.first->Unlink(peer_in_anchor), "[Remove][DataEdge] %s:%d->%s:%d failed.", - item.second.first->GetOwnerNode()->GetName().c_str(), item.second.first->GetIdx(), - peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetIdx()); - GE_CHK_STATUS_RET(out_data_anchor->LinkTo(peer_in_anchor), "[Add][DataEdge] %s:%u->%s:%d failed.", - subgraph_node->GetName().c_str(), item.first, peer_in_anchor->GetOwnerNode()->GetName().c_str(), - peer_in_anchor->GetIdx()); - } - } - - return GRAPH_SUCCESS; -} - -graphStatus GraphUtils::RelinkCtrlEdges(const NodePtr &subgraph_node, const GraphInfo &graph_info) { - // in ctrl nodes - for (const auto &ctrl_input : graph_info.ctrl_inputs_) { - GE_CHK_STATUS_RET(ctrl_input.first->Unlink(ctrl_input.second), "[Remove][CtrlEdge] %s->%s failed", - ctrl_input.first->GetOwnerNode()->GetName().c_str(), - ctrl_input.second->GetOwnerNode()->GetName().c_str()); - if (!ctrl_input.first->IsLinkedWith(subgraph_node->GetInControlAnchor())) { - GE_CHK_STATUS_RET(ctrl_input.first->LinkTo(subgraph_node->GetInControlAnchor()), "[Add][CtrlEdge] %s->%s failed.", - ctrl_input.first->GetOwnerNode()->GetName().c_str(), subgraph_node->GetName().c_str()); - } - } - // out ctrl nodes - for (const auto &ctrl_output : graph_info.ctrl_outputs_) { - GE_CHK_STATUS_RET(ctrl_output.first->Unlink(ctrl_output.second), "[Remove][CtrlEdge] %s->%s failed.", - ctrl_output.first->GetOwnerNode()->GetName().c_str(), - ctrl_output.second->GetOwnerNode()->GetName().c_str()); - if (!subgraph_node->GetOutControlAnchor()->IsLinkedWith(ctrl_output.second)) { - GE_CHK_STATUS_RET(subgraph_node->GetOutControlAnchor()->LinkTo(ctrl_output.second), - "[Add][CtrlEdge] %s->%s failed.", subgraph_node->GetName().c_str(), - ctrl_output.second->GetOwnerNode()->GetName().c_str()); - } - } - - return GRAPH_SUCCESS; -} - -graphStatus GraphUtils::UnfoldSubgraph(const ComputeGraphPtr &graph, - const std::function &filter) { - GE_CHECK_NOTNULL(graph); - const auto &parent_graph = graph->GetParentGraph(); - const auto &parent_node = graph->GetParentNode(); - if ((parent_graph == nullptr) && (parent_node == nullptr)) { - return GRAPH_SUCCESS; - } - - return UnfoldGraph(graph, parent_graph, parent_node, filter); -} - -graphStatus GraphUtils::UnfoldGraph(const ComputeGraphPtr &graph, const ComputeGraphPtr &target_graph, - const NodePtr &target_node, const function &filter, - int32_t depth) { - if (depth >= kCopyGraphMaxRecursionDepth) { - REPORT_INNER_ERR_MSG("E18888", "param depth:%d >= %d(allow max subgraphs)", depth, kCopyGraphMaxRecursionDepth); - GELOGE(GRAPH_FAILED, "[Check][Param]exist too much subgraphs:%d > %d(allow max subgraphs)", depth, - kCopyGraphMaxRecursionDepth); - return GRAPH_FAILED; - } - GE_CHECK_NOTNULL(graph); - GE_CHECK_NOTNULL(target_graph); - GE_CHECK_NOTNULL(target_node); - - GE_CHK_STATUS_RET(MergeInputNodes(graph, target_node), - "[Invoke][MergeInputNodes] Merge data nodes for graph %s failed", graph->GetName().c_str()); - GE_CHK_STATUS_RET(MergeNetOutputNode(graph, target_node), - "[Invoke][MergeNetOutputNode] Merge net output nodes for graph %s failed", - graph->GetName().c_str()); - GELOGD("[%s] Merging graph inputs and outputs successfully", graph->GetName().c_str()); - - for (auto &node : graph->GetDirectNode()) { - GE_CHECK_NOTNULL(node); - if ((node->GetType() == DATA) || (node->GetType() == NETOUTPUT)) { - continue; - } - - std::vector subgraphs; - GE_CHK_STATUS_RET(NodeUtils::GetDirectSubgraphs(node, subgraphs), "[Get][Subgraphs] failed, graph:%s", - node->GetName().c_str()); - bool skip_add_node_flag = true; - for (const auto &subgraph : subgraphs) { - if ((filter != nullptr) && filter(subgraph)) { - GE_CHK_STATUS_RET( - UnfoldGraph(subgraph, subgraph->GetParentGraph(), subgraph->GetParentNode(), filter, depth + 1), - "[Invoke][UnfoldSubgraph] Failed to merge graph %s", subgraph->GetName().c_str()); - skip_add_node_flag = false; - } else { - subgraph->SetParentGraph(target_graph); - } - } - - if (skip_add_node_flag) { - (void) target_graph->AddNode(node); - GELOGD("[%s::%s] added to target graph: [%s].", graph->GetName().c_str(), node->GetName().c_str(), - target_graph->GetName().c_str()); - (void) node->SetOwnerComputeGraph(target_graph); - } - } - - GELOGD("[%s] Done merging graph. remove it from root graph", graph->GetName().c_str()); - - const auto &subgraph_name = graph->GetName(); - const auto &root_graph = GraphUtils::FindRootGraph(target_graph); - GE_CHECK_NOTNULL(root_graph); - root_graph->RemoveSubgraph(graph->GetName()); - target_node->GetOpDesc()->RemoveSubgraphInstanceName(subgraph_name); - if (RemoveNodeWithoutRelink(target_graph, target_node) != GRAPH_SUCCESS) { - GELOGW("Remove node %s failed, graph:%s.", target_node->GetName().c_str(), target_graph->GetName().c_str()); - } - - return SUCCESS; -} - -graphStatus GraphUtils::MergeInputNodes(const ComputeGraphPtr &graph, const NodePtr& target_node) { - GE_CHECK_NOTNULL(target_node); - - std::set src_nodes; - for (const auto &node : graph->GetDirectNode()) { - GE_CHECK_NOTNULL(node); - if (node->GetType() != DATA) { - if (node->GetInAllNodes().empty()) { - (void)src_nodes.emplace(node); - } - continue; - } - - uint32_t parent_index = 0U; - if ((!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) && - (!AttrUtils::GetInt(node->GetOpDesc(), "index", parent_index))) { - REPORT_INNER_ERR_MSG("E18888", "Get attr {%s} or {index} failed, node:%s", ATTR_NAME_PARENT_NODE_INDEX.c_str(), - node->GetName().c_str()); - GELOGE(FAILED, "[Get][Attr] {%s} or {index} failed, node:%s", ATTR_NAME_PARENT_NODE_INDEX.c_str(), - node->GetName().c_str()); - return GRAPH_FAILED; - } - - const auto parent_node_in_anchor = target_node->GetInDataAnchor(static_cast(parent_index)); - GE_CHECK_NOTNULL(parent_node_in_anchor); - const auto src_out_anchor = parent_node_in_anchor->GetPeerOutAnchor(); - if ((src_out_anchor == nullptr) || (src_out_anchor->GetOwnerNodeBarePtr() == nullptr)) { - continue; - } - parent_node_in_anchor->UnlinkAll(); - - // link src to outputs of DataNode - for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { - for (const auto &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { - auto dst_node = peer_in_anchor->GetOwnerNode(); - GE_CHECK_NOTNULL(dst_node); - const auto &in_nodes = dst_node->GetInDataNodes(); - if (std::all_of(in_nodes.begin(), in_nodes.end(), [](const NodePtr &n) { return n->GetType() == DATA; })) { - (void)src_nodes.emplace(dst_node); - } - GE_CHK_STATUS_RET(ReplaceEdgeSrc(out_data_anchor, peer_in_anchor, src_out_anchor), - "[Replace][DataEdge] failed"); - } - } - // when unfold partitonCall, if data have control edges, which will be left in final graph - // which cause topo sort failed. - auto out_control_anchor = node->GetOutControlAnchor(); - GE_CHECK_NOTNULL(out_control_anchor); - out_control_anchor->UnlinkAll(); - } - - // transfer in control edges to all root nodes - for (const auto &src_node : src_nodes) { - const auto &in_nodes = src_node->GetInAllNodes(); - const std::set in_node_set(in_nodes.begin(), in_nodes.end()); - for (const auto &in_control_node : target_node->GetInControlNodes()) { - GE_CHECK_NOTNULL(in_control_node); - if ((in_node_set.count(in_control_node) == 0UL) && (kMergeInputSkipTypes.count(src_node->GetType()) == 0UL)) { - GELOGD("[%s] Restore control edge to [%s]", in_control_node->GetName().c_str(), src_node->GetName().c_str()); - (void)AddEdge(in_control_node->GetOutControlAnchor(), src_node->GetInControlAnchor()); - } - } - } - - target_node->GetInControlAnchor()->UnlinkAll(); - return GRAPH_SUCCESS; -} - -graphStatus GraphUtils::MergeNetOutputNode(const ComputeGraphPtr &graph, const NodePtr& target_node) { - GE_CHECK_NOTNULL(target_node); - - const NodePtr &net_output = graph->FindFirstNodeMatchType(NETOUTPUT); - if (net_output == nullptr) { - GELOGD("Graph has no NetOutput node, no need to merge"); - return SUCCESS; - } - auto all_in_nodes = net_output->GetInAllNodes(); - auto all_out_nodes = target_node->GetOutAllNodes(); - net_output->GetInControlAnchor()->UnlinkAll(); - target_node->GetOutControlAnchor()->UnlinkAll(); - - for (const auto &in_data_anchor : net_output->GetAllInDataAnchors()) { - GE_CHECK_NOTNULL(in_data_anchor); - const auto index = in_data_anchor->GetIdx(); - uint32_t parent_index = index; - // op_desc of node should not be null - if (!AttrUtils::GetInt(net_output->GetOpDesc()->GetInputDesc(static_cast(index)), - ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { - GELOGW("SubGraph: %s NetOutput input tensor %d, attr %s not found, use anchor index %u.", - graph->GetName().c_str(), index, ATTR_NAME_PARENT_NODE_INDEX.c_str(), parent_index); - } - - const auto src_out_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(src_out_anchor); - GE_CHECK_NOTNULL(src_out_anchor->GetOwnerNodeBarePtr()); - GE_CHK_STATUS_RET(RemoveEdge(src_out_anchor, in_data_anchor), "[Remove][DataEdge] %s:%d->%s:%d failed", - src_out_anchor->GetOwnerNode()->GetName().c_str(), src_out_anchor->GetIdx(), - net_output->GetName().c_str(), in_data_anchor->GetIdx()); - - const OutDataAnchorPtr &parent_out_anchor = target_node->GetOutDataAnchor(static_cast(parent_index)); - GE_CHECK_NOTNULL(parent_out_anchor); - for (InDataAnchorPtr &dst_in_anchor : parent_out_anchor->GetPeerInDataAnchors()) { - GE_CHK_STATUS_RET(ReplaceEdgeSrc(parent_out_anchor, dst_in_anchor, src_out_anchor), - "[Replace][DataEdge] failed"); - } - } - - // transfer out control edges - const OrderedNodeSet in_node_set(all_in_nodes.begin(), all_in_nodes.end()); - const OrderedNodeSet out_node_set(all_out_nodes.begin(), all_out_nodes.end()); - for (auto &src_node : in_node_set) { - GELOGD("[%s] process in node.", src_node->GetName().c_str()); - auto out_nodes = src_node->GetOutAllNodes(); - const std::set node_set(out_nodes.begin(), out_nodes.end()); - for (auto &dst_node : out_node_set) { - if (node_set.count(dst_node) == 0UL) { - GELOGD("[%s] Restore control edge to [%s]", src_node->GetName().c_str(), dst_node->GetName().c_str()); - (void)src_node->GetOutControlAnchor()->LinkTo(dst_node->GetInControlAnchor()); - } - } - } - - return GRAPH_SUCCESS; -} - -void GraphUtils::InheritOriginalAttr(const ComputeGraphPtr &src_compute_graph, - ComputeGraphPtr &dst_compute_graph) { - const std::map &original_attrs = AttrUtils::GetAllAttrs(src_compute_graph); - for (auto const &attr_iter : original_attrs) { - if (dst_compute_graph->TrySetAttr(attr_iter.first, attr_iter.second) != GRAPH_SUCCESS) { - GELOGW("Set inherit original attr[%s] failed, Please Check.", attr_iter.first.c_str()); - } - } -} - -bool GraphUtils::IsSingleOpScene(const ComputeGraphPtr &graph) { - bool is_single_op = false; - if (AttrUtils::GetBool(graph, ATTR_SINGLE_OP_SCENE, is_single_op)) { - return is_single_op; - } - GELOGD("There is no _single_op_scene for graph:%s. Start search all node.", graph->GetName().c_str()); - for (const auto &node : graph->GetAllNodes()) { - GE_ASSERT_NOTNULL(node->GetOpDesc()); - if (AttrUtils::GetBool(node->GetOpDesc(), ATTR_SINGLE_OP_SCENE, is_single_op)) { - return is_single_op; - } - } - return is_single_op; -} - -CycleDetectorPtr GraphUtils::CreateCycleDetector(const ComputeGraphPtr &graph) { - CycleDetectorPtr detector = ComGraphMakeUnique(); - if (detector == nullptr) { - GELOGW("Fail to create cycle detector. Return null."); - return nullptr; - } - const auto ret = detector->Init(graph); - if (ret != SUCCESS) { - GELOGW("Fail to init cycle detector. Return null."); - return nullptr; - } - return detector; -} - -CycleDetectorSharedPtr GraphUtils::CreateSharedCycleDetector(const ComputeGraphPtr &graph) { - CycleDetectorSharedPtr detector = nullptr; - GE_MAKE_SHARED(detector = std::make_shared(), return nullptr); - if (detector == nullptr) { - GELOGW("Fail to create cycle detector. Return null."); - return nullptr; - } - const auto ret = detector->Init(graph); - if (ret != SUCCESS) { - GELOGW("Fail to init cycle detector. Return null."); - return nullptr; - } - return detector; -} - -/// @brief Add node to graph -/// @param [in] op_desc -/// @return ComputeGraphBuilder -ComputeGraphBuilder& ComputeGraphBuilder::AddNode(const OpDescPtr &op_desc) { - nodes_.emplace_back(op_desc); - return *this; -} - -/// @brief Add data-link among nodes in graph -/// @param [in] src_name -/// @param [in] out_anchor_ind -/// @param [in] dst_name -/// @param [in] in_anchor_ind -/// @return ComputeGraphBuilder -ComputeGraphBuilder& ComputeGraphBuilder::AddDataLink(const std::string &src_name, const uint32_t out_anchor_ind, - const std::string &dst_name, const uint32_t in_anchor_ind) { - data_links_.emplace_back(std::make_pair(std::make_pair(src_name, out_anchor_ind), - std::make_pair(dst_name, in_anchor_ind))); - return *this; -} - -/// @brief Add ctrl-link among nodes in graph -/// @param [in] src_name -/// @param [in] dst_name -/// @return ComputeGraphBuilder -ComputeGraphBuilder& ComputeGraphBuilder::AddControlLink(const std::string &src_name, const std::string &dst_name) { - ctrl_links_.emplace_back(std::make_pair(src_name, dst_name)); - return *this; -} - -/// @brief Build nodes -/// @param [out] error_code -/// @param [out] error_msg -/// @return void -void ComputeGraphBuilder::BuildNodes(graphStatus &error_code, std::string &error_msg) { - if (owner_graph_ == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "graph is NULL."; - return; - } - - std::string node_name; - for (auto &op_desc : nodes_) { - if (op_desc == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "op_desc is NULL."; - return; - } - - node_name = op_desc->GetName(); - const NodePtr node = owner_graph_->AddNode(op_desc); - if (node == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "Add node " + node_name + " failed."; - return; - } - - GELOGD("Add node name:%s, type:%s.", node_name.c_str(), op_desc->GetType().c_str()); - node_names_[node_name] = node; - } - - GELOGD("BuildNodes succ."); -} - -/// @brief Build data-links -/// @param [out] error_code -/// @param [out] error_msg -/// @return void -void ComputeGraphBuilder::BuildDataLinks(graphStatus &error_code, std::string &error_msg) { - for (auto &pair : data_links_) { - const std::string src_name = pair.first.first; - const auto out_ind = static_cast(pair.first.second); - const std::string dst_name = pair.second.first; - const auto in_ind = static_cast(pair.second.second); - std::string log_msg = "Add data-edge "; - (void)log_msg.append(src_name).append(":").append(std::to_string(out_ind)).append("->") - .append(dst_name).append(":").append(std::to_string(in_ind)); - - const auto src_iter = node_names_.find(src_name); - const auto dst_iter = node_names_.find(dst_name); - if ((src_iter == node_names_.end()) || (dst_iter == node_names_.end())) { - error_code = GRAPH_FAILED; - error_msg = log_msg + " failed: node does not exist in graph."; - return; - } - - const NodePtr src_node = node_names_[src_name]; - const NodePtr dst_node = node_names_[dst_name]; - if ((src_node == nullptr) || (dst_node == nullptr)) { - error_code = GRAPH_FAILED; - error_msg = log_msg + " failed: node is NULL."; - return; - } - - if (GraphUtils::AddEdge(src_node->GetOutDataAnchor(out_ind), dst_node->GetInDataAnchor(in_ind)) != GRAPH_SUCCESS) { - error_code = GRAPH_FAILED; - error_msg = log_msg + " failed."; - return; - } - - GELOGD("%s succ.", log_msg.c_str()); - } - - GELOGD("BuildDataLinks succ."); -} - -/// @brief Build ctrl-links -/// @param [out] error_code -/// @param [out] error_msg -/// @return void -void ComputeGraphBuilder::BuildCtrlLinks(graphStatus &error_code, std::string &error_msg) { - for (auto &pair : ctrl_links_) { - const std::string src_name = pair.first; - const std::string dst_name = pair.second; - std::string log_msg = "Add ctrl-edge "; - (void)log_msg.append(src_name).append("->").append(dst_name); - - const auto src_iter = node_names_.find(src_name); - const auto dst_iter = node_names_.find(dst_name); - if ((src_iter == node_names_.end()) || (dst_iter == node_names_.end())) { - error_code = GRAPH_FAILED; - error_msg = log_msg + " failed: node does not exist in graph."; - return; - } - - const NodePtr src_node = node_names_[src_name]; - const NodePtr dst_node = node_names_[dst_name]; - if ((src_node == nullptr) || (dst_node == nullptr)) { - error_code = GRAPH_FAILED; - error_msg = log_msg + " failed: node is NULL."; - return; - } - - if (GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor()) != GRAPH_SUCCESS) { - error_code = GRAPH_FAILED; - error_msg = log_msg + " failed."; - return; - } - - GELOGD("%s succ.", log_msg.c_str()); - } - - GELOGD("BuildCtrlLinks succ."); -} - -/// @brief Get node with name -/// @param [in] name -/// @return NodePtr -NodePtr ComputeGraphBuilder::GetNode(const std::string &name) { - const auto iter = node_names_.find(name); - if (iter == node_names_.end()) { - REPORT_INNER_ERR_MSG("E18888", "node %s does not exist.", name.c_str()); - GE_LOGE("[Check][Param] node %s does not exist.", name.c_str()); - return nullptr; - } - return iter->second; -} - -/// @brief Get all nodes -/// @return std::vector -std::vector ComputeGraphBuilder::GetAllNodes() { - std::vector nodes; - for (const auto &iter : node_names_) { - nodes.emplace_back(iter.second); - } - return nodes; -} - -/// @brief Add node to graph -/// @param [in] op_desc -/// @return CompleteGraphBuilder -CompleteGraphBuilder& CompleteGraphBuilder::AddNode(const OpDescPtr &op_desc) { - (void)ComputeGraphBuilder::AddNode(op_desc); - return *this; -} - -/// @brief Add data-link among nodes in graph -/// @param [in] src_name -/// @param [in] out_anchor_ind -/// @param [in] dst_name -/// @param [in] in_anchor_ind -/// @return CompleteGraphBuilder -CompleteGraphBuilder& CompleteGraphBuilder::AddDataLink(const std::string &src_name, const uint32_t out_anchor_ind, - const std::string &dst_name, const uint32_t in_anchor_ind) { - (void)ComputeGraphBuilder::AddDataLink(src_name, out_anchor_ind, dst_name, in_anchor_ind); - return *this; -} - -/// @brief Add ctrl-link among nodes in graph -/// @param [in] src_name -/// @param [in] dst_name -/// @return CompleteGraphBuilder -CompleteGraphBuilder& CompleteGraphBuilder::AddControlLink(const std::string &src_name, const std::string &dst_name) { - (void)ComputeGraphBuilder::AddControlLink(src_name, dst_name); - return *this; -} - -/// @brief Set index_th input anchor for graph -/// @param [in] index -/// @param [in] node_names -/// @param [in] anchor_inds -/// @return CompleteGraphBuilder -CompleteGraphBuilder& CompleteGraphBuilder::SetInput(const uint32_t index, const std::vector &node_names, - const std::vector &anchor_inds) { - graph_inputs_[index] = std::make_pair(node_names, anchor_inds); - return *this; -} - -/// @brief Set index_th input of graph as useless -/// @param [in] index -/// @return CompleteGraphBuilder -CompleteGraphBuilder& CompleteGraphBuilder::SetUselessInput(const uint32_t index) { - graph_inputs_[index] = std::make_pair(std::vector(), std::vector()); - return *this; -} - -/// @brief Add output anchor for graph -/// @param [in] owner_node_name -/// @param [in] anchor_ind -/// @return CompleteGraphBuilder -CompleteGraphBuilder& CompleteGraphBuilder::AddOutput(const std::string &owner_node_name, uint32_t anchor_ind) { - graph_outputs_.emplace_back(std::make_pair(owner_node_name, anchor_ind)); - return *this; -} - -/// @brief Add target for graph -/// @param [in] target_name -/// @return CompleteGraphBuilder -CompleteGraphBuilder& CompleteGraphBuilder::AddTarget(const std::string &target_name) { - graph_targets_.emplace_back(target_name); - return *this; -} - -/// @brief Set parent-node of graph -/// @param [in] parent_node -/// @return CompleteGraphBuilder -CompleteGraphBuilder& CompleteGraphBuilder::SetParentNode(const NodePtr &parent_node) { - parent_node_ = parent_node; - return *this; -} - -/// @brief Set mapping-relation of parent-node in_anchor_ind & Data-node -/// @param [in] input_mapping: index_of_graph_input -> in_anchor_index_of_parent_node -/// @return CompleteGraphBuilder -CompleteGraphBuilder& CompleteGraphBuilder::SetInputMapping(const std::map &input_mapping) { - for (auto &item : input_mapping) { - input_mapping_[item.first] = item.second; - } - return *this; -} - -/// @brief Set mapping-relation of parent-node out_anchor_ind & NetOutput-node out_anchor_ind -/// @param [in] output_mapping: index_of_graph_output -> out_anchor_index_of_parent_node -/// @return CompleteGraphBuilder -CompleteGraphBuilder& CompleteGraphBuilder::SetOutputMapping(const std::map &output_mapping) { - for (auto &item : output_mapping) { - output_mapping_[item.first] = item.second; - } - return *this; -} - -/// @brief Build graph -/// @param [out] error_code -/// @param [out] error_msg -/// @return ComputeGraphPtr -ComputeGraphPtr CompleteGraphBuilder::Build(graphStatus &error_code, std::string &error_msg) { - owner_graph_ = ComGraphMakeShared(name_); - if (owner_graph_ == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "graph is NULL."; - return nullptr; - } - - BuildNodes(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - BuildDataLinks(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - BuildCtrlLinks(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - AddDataNodes(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - if (retval_flag_) { - AddRetValNodes(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - BuildGraphTargets(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - } else { - AddNetOutputNode(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - } - - PostProcess(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - return owner_graph_; -} - -/// @brief Add data nodes -/// @param [out] error_code -/// @param [out] error_msg -/// @return void -void CompleteGraphBuilder::AddDataNodes(graphStatus &error_code, std::string &error_msg) { - for (auto &input : graph_inputs_) { - const NodePtr data_node = AddDataNode(input.first, error_code, error_msg); - if (data_node == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNodes failed: add node Data:" + std::to_string(input.first) + + " failed."; - return; - } - - if (owner_graph_->AddInputNode(data_node) == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNodes failed: add input node Data:" + std::to_string(input.first) + + " failed."; - return; - } - - // useless input - const std::vector input_names = input.second.first; - const std::vector anchor_indes = input.second.second; - if (input_names.size() != anchor_indes.size()) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNodes failed: num of input_names and indexs not equal."; - return; - } - if (input_names.empty()) { - continue; - } - - const size_t input_num = input_names.size(); - for (size_t i = 0U; i < input_num; i++) { - const std::string input_name = input_names[i]; - const int32_t ind = static_cast(anchor_indes[i]); - const auto iter = node_names_.find(input_name); - if (iter == node_names_.end()) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNodes failed: node " + input_name + " does not exist in graph."; - return; - } - - const NodePtr in_node = node_names_[input_name]; - if (in_node == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNodes failed: node " + input_name + " is NULL."; - return; - } - - if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), in_node->GetInDataAnchor(ind)) != GRAPH_SUCCESS) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNodes failed: add data-edge Data:" + std::to_string(input.first) + ":0->" + - input_name + ":" + std::to_string(ind) + " failed."; - return; - } - } - - GELOGD("AddDataNodes : Add %u input succ.", input.first); - } - - GELOGD("AddDataNodes succ."); -} - -/// @brief Add data node -/// @param [in] index -/// @param [out] error_code -/// @param [out] error_msg -/// @return void -NodePtr CompleteGraphBuilder::AddDataNode(const uint32_t index, graphStatus &error_code, std::string &error_msg) { - const std::string data_name = "Data_" + std::to_string(index); - OpDescBuilder op_desc_builder(data_name, "Data"); - const OpDescPtr op_desc = op_desc_builder.AddInput("x") - .AddOutput("y") - .Build(); - if (op_desc == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNode failed: create op_desc " + data_name + " failed."; - return nullptr; - } - - const auto index_iter = input_mapping_.find(index); - if (index_iter != input_mapping_.end()) { - if (!ge::AttrUtils::SetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, static_cast(index_iter->second))) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNode failed: set attr ATTR_NAME_PARENT_NODE_INDEX for " + data_name + " failed."; - return nullptr; - } - } - if (parent_node_ != nullptr) { - // op_desc should not be null - const auto &parent_desc = parent_node_->GetOpDesc()->GetInputDesc(index_iter->second); - if ((op_desc->UpdateInputDesc(0U, parent_desc) != GRAPH_SUCCESS) || - (op_desc->UpdateOutputDesc(0U, parent_desc) != GRAPH_SUCCESS)) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNode failed: update tensor_desc for " + data_name + " failed."; - return nullptr; - } - } - - const NodePtr data_node = owner_graph_->AddNode(op_desc); - if (data_node == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "AddDataNode failed: add node " + data_name + " failed."; - return nullptr; - } - node_names_[data_name] = data_node; - - return data_node; -} - -/// @brief Add RetVal nodes -/// @param [out] error_code -/// @param [out] error_msg -/// @return void -void CompleteGraphBuilder::AddRetValNodes(graphStatus &error_code, std::string &error_msg) { - const size_t output_num = graph_outputs_.size(); - for (size_t i = 0U; i < output_num; i++) { - const int32_t index = static_cast(graph_outputs_[i].second); - const auto out_iter = node_names_.find(graph_outputs_[i].first); - if (out_iter == node_names_.end()) { - error_code = GRAPH_FAILED; - error_msg = "AddRetValNode failed: node " + graph_outputs_[i].first + " does not exist in graph."; - return; - } - const NodePtr node = out_iter->second; - if ((node == nullptr) || (node->GetOpDesc() == nullptr)) { - error_code = GRAPH_FAILED; - error_msg = "AddRetValNode failed: node is NULL."; - return; - } - - const std::string name = node->GetName() + "_RetVal_"+ std::to_string(index); - const OpDescPtr ret_val_desc = ComGraphMakeShared(name, FRAMEWORKOP); - if (ret_val_desc == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "AddRetValNode " + name + " failed: op_desc is NULL."; - return; - } - const ge::GeTensorDesc tensor = node->GetOpDesc()->GetOutputDesc(static_cast(index)); - if ((ret_val_desc->AddInputDesc(tensor) != GRAPH_SUCCESS) || - (ret_val_desc->AddOutputDesc(tensor) != GRAPH_SUCCESS)) { - error_code = GRAPH_FAILED; - error_msg = "AddRetValNode " + name + " failed: add input_desc / output_desc failed."; - return; - } - - if (!(ge::AttrUtils::SetStr(ret_val_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_RetVal") && - ge::AttrUtils::SetInt(ret_val_desc, RETVAL_ATTR_NAME_INDEX, static_cast(i)))) { - error_code = GRAPH_FAILED; - error_msg = "AddRetValNode " + name + " failed: set FRAMEWORK_ORIGINAL_TYPE / RETVAL_ATTR_NAME_INDEX failed."; - return; - } - const auto iter = output_mapping_.find(i); - if (iter != output_mapping_.end()) { - if (!ge::AttrUtils::SetInt(ret_val_desc, ATTR_NAME_PARENT_NODE_INDEX, static_cast(iter->second))) { - error_code = GRAPH_FAILED; - error_msg = "AddRetValNode " + name + " failed: set attr PARENT_NODE_INDEX failed."; - return; - } - } - - const NodePtr ret_val_node = owner_graph_->AddNode(ret_val_desc); - if (ret_val_node == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "AddRetValNode " + name + " failed: add node failed."; - return; - } - - if (GraphUtils::AddEdge(node->GetOutDataAnchor(index), ret_val_node->GetInDataAnchor(0)) != GRAPH_SUCCESS) { - error_code = GRAPH_FAILED; - error_msg = "AddRetValNode " + name + " failed: add data-edge " + - node->GetName() + ":" + std::to_string(index) + "->" + ret_val_node->GetName() + ":0 failed."; - return; - } - } - - GELOGD("AddRetValNodes succ."); -} - -/// @brief Build target-nodes for graph -/// @param [out] error_code -/// @param [out] error_msg -/// @return void -void CompleteGraphBuilder::BuildGraphTargets(graphStatus &error_code, std::string &error_msg) { - std::vector target_nodes; - for (const std::string &target_name : graph_targets_) { - const auto target_iter = node_names_.find(target_name); - if ((target_iter == node_names_.end()) || (target_iter->second == nullptr)) { - error_code = GRAPH_FAILED; - error_msg = "BuildGraphTargets failed: target_node " + target_name + " does not exist in graph."; - return; - } - target_nodes.emplace_back(target_iter->second); - } - owner_graph_->SetGraphTargetNodesInfo(target_nodes); -} - -/// @brief Add NetOutput node -/// @param [out] error_code -/// @param [out] error_msg -/// @return void -void CompleteGraphBuilder::AddNetOutputNode(graphStatus &error_code, std::string &error_msg) { - if (graph_outputs_.empty() && graph_targets_.empty()) { - return; - } - const std::string node_name = "Node_Output"; - const std::string log_msg = "AddNetOutputNode name:" + node_name + ", type:" + NETOUTPUT; - const OpDescPtr net_output_desc = ComGraphMakeShared(node_name, NETOUTPUT); - if (net_output_desc == nullptr) { - error_code = GRAPH_FAILED; - error_msg = log_msg + " failed: op_desc is NULL."; - return; - } - - const size_t output_num = graph_outputs_.size(); - std::vector peer_out_anchors(output_num); - for (size_t i = 0U; i < output_num; i++) { - const uint32_t index = graph_outputs_[i].second; - const auto out_iter = node_names_.find(graph_outputs_[i].first); - if (out_iter == node_names_.end()) { - error_code = GRAPH_FAILED; - error_msg = "AddNetOutputNode failed: node " + graph_outputs_[i].first + " does not exist in graph."; - return; - } - const NodePtr node = out_iter->second; - if ((node == nullptr) || (node->GetOpDesc() == nullptr)) { - error_code = GRAPH_FAILED; - error_msg = "AddNetOutputNode failed: node is NULL."; - return; - } - - ge::GeTensorDesc tensor = node->GetOpDesc()->GetOutputDesc(index); - int64_t update_index = static_cast(i); - const auto iter = output_mapping_.find(i); - if (iter != output_mapping_.end()) { - update_index = static_cast(iter->second); - (void) ge::AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, update_index); - } - - if (net_output_desc->AddInputDesc(tensor) != GRAPH_SUCCESS) { - error_code = GRAPH_FAILED; - error_msg = "AddNetOutputNode failed: add input_desc ailed."; - return; - } - peer_out_anchors[i] = node->GetOutDataAnchor(static_cast(index)); - } - - BuildNetOutputNodeWithLink(net_output_desc, peer_out_anchors, error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return; - } - - GELOGD("%s succ.", log_msg.c_str()); -} - -/// @brief Build NetOutput nodes with data & ctrl edges -/// @param [in] net_output_desc -/// @param [in] peer_out_anchors -/// @param [out] error_code -/// @param [out] error_msg -/// @return void -void CompleteGraphBuilder::BuildNetOutputNodeWithLink(const OpDescPtr &net_output_desc, - const std::vector &peer_out_anchors, - graphStatus &error_code, std::string &error_msg) { - const std::string log_msg = "AddNetOutputNode name:" + std::string(NODE_NAME_NET_OUTPUT) + ", type:" + NETOUTPUT; - const NodePtr net_output = owner_graph_->AddNode(net_output_desc); - if (net_output == nullptr) { - error_code = GRAPH_FAILED; - error_msg = log_msg + " failed: add NetOutput node failed."; - return; - } - owner_graph_->SetNetOutputNode(net_output); - - const size_t output_num = graph_outputs_.size(); - for (size_t i = 0U; i < output_num; i++) { - if (GraphUtils::AddEdge(peer_out_anchors[i], - net_output->GetInDataAnchor(static_cast(i))) != GRAPH_SUCCESS) { - error_code = GRAPH_FAILED; - error_msg = "AddNetOutputNode failed: add data-edge " + - peer_out_anchors[i]->GetOwnerNode()->GetName() + ":" + std::to_string(peer_out_anchors[i]->GetIdx()) + - "->" + NODE_NAME_NET_OUTPUT + ":" + std::to_string(i) + " failed."; - return; - } - } - for (const std::string &target_name : graph_targets_) { - const auto target_iter = node_names_.find(target_name); - if ((target_iter == node_names_.end()) || (target_iter->second == nullptr)) { - error_code = GRAPH_FAILED; - error_msg = "BuildGraphTargets failed: target_node " + target_name + " does not exist in graph."; - return; - } - const auto &target_node = target_iter->second; - if (GraphUtils::AddEdge(target_node->GetOutControlAnchor(), net_output->GetInControlAnchor()) != GRAPH_SUCCESS) { - error_code = GRAPH_FAILED; - error_msg = "AddNetOutputNode failed: add ctrl-edge " + - target_node->GetName() + "->" + NODE_NAME_NET_OUTPUT + " failed."; - return; - } - } -} - -/// @brief process after build -/// @param [out] error_code -/// @param [out] error_msg -/// @return void -void CompleteGraphBuilder::PostProcess(graphStatus &error_code, std::string &error_msg) { - if (parent_node_ != nullptr) { - owner_graph_->SetParentNode(parent_node_); - const auto &parent_graph = parent_node_->GetOwnerComputeGraph(); - if (parent_graph == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "Parent graph is null, parent_node=" + parent_node_->GetName(); - return; - } - owner_graph_->SetParentGraph(parent_graph); - // ATTR_NAME_SESSION_GRAPH_ID - std::string graph_id; - if ((!AttrUtils::GetStr(parent_graph, ATTR_NAME_SESSION_GRAPH_ID, graph_id)) || - (!AttrUtils::SetStr(owner_graph_, ATTR_NAME_SESSION_GRAPH_ID, graph_id))) { - error_code = GRAPH_FAILED; - error_msg = "Copy attr session_graph_id failed."; - return; - } - if (parent_graph->HasAttr(ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED)) { - bool is_dynamic_shape = false; - if ((!AttrUtils::GetBool(parent_graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dynamic_shape)) || - (!AttrUtils::SetBool(owner_graph_, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dynamic_shape))) { - error_code = GRAPH_FAILED; - error_msg = "Copy attr _dynamic_shape_partitioned failed."; - return; - } - } - owner_graph_->SetGraphUnknownFlag(parent_graph->GetGraphUnknownFlag()); - - // refresh parent node/graph in subgraphs - for (const NodePtr &node : owner_graph_->GetDirectNode()) { - std::vector subgraphs; - if (NodeUtils::GetDirectSubgraphs(node, subgraphs) != GRAPH_SUCCESS) { - error_code = GRAPH_FAILED; - error_msg = "Get subgraphs for failed failed, node:" + node->GetName(); - return; - } - for (const auto &subgraph : subgraphs) { - subgraph->SetParentNode(node); - subgraph->SetParentGraph(subgraph); - } - } - } - - // refresh node name - for (const NodePtr &node : owner_graph_->GetDirectNode()) { - if ((node->GetOpDesc() == nullptr) || (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2)) { - continue; - } - node->GetOpDesc()->SetName(owner_graph_->GetName() + "/" + node->GetName()); - } -} - -/// @brief Add node to graph -/// @param [in] op_desc -/// @return PartialGraphBuilder -PartialGraphBuilder& PartialGraphBuilder::AddNode(const OpDescPtr &op_desc) { - (void)ComputeGraphBuilder::AddNode(op_desc); - return *this; -} - -/// @brief Add data-link among nodes in graph -/// @param [in] src_name -/// @param [in] out_anchor_ind -/// @param [in] dst_name -/// @param [in] in_anchor_ind -/// @return PartialGraphBuilder -PartialGraphBuilder& PartialGraphBuilder::AddDataLink(const std::string &src_name, const uint32_t out_anchor_ind, - const std::string &dst_name, const uint32_t in_anchor_ind) { - (void)ComputeGraphBuilder::AddDataLink(src_name, out_anchor_ind, dst_name, in_anchor_ind); - return *this; -} - -/// @brief Add ctrl-link among nodes in graph -/// @param [in] src_name -/// @param [in] dst_name -/// @return PartialGraphBuilder -PartialGraphBuilder& PartialGraphBuilder::AddControlLink(const std::string &src_name, const std::string &dst_name) { - (void)ComputeGraphBuilder::AddControlLink(src_name, dst_name); - return *this; -} - -/// @brief Set owner graph -/// @param [in] graph -/// @return PartialGraphBuilder -PartialGraphBuilder& PartialGraphBuilder::SetOwnerGraph(const ComputeGraphPtr &graph) { - owner_graph_ = graph; - return *this; -} - -/// @brief Add exist node -/// @param [in] node -/// @return PartialGraphBuilder -PartialGraphBuilder& PartialGraphBuilder::AddExistNode(const NodePtr &exist_node) { - exist_nodes_.emplace_back(exist_node); - return *this; -} - -/// @brief Build partial graph -/// @param [out] error_code -/// @param [out] error_msg -/// @return ComputeGraphPtr -ComputeGraphPtr PartialGraphBuilder::Build(graphStatus &error_code, std::string &error_msg) { - if (owner_graph_ == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "graph is NULL."; - return nullptr; - } - - BuildNodes(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - BuildExistNodes(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - BuildDataLinks(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - BuildCtrlLinks(error_code, error_msg); - if (error_code != GRAPH_SUCCESS) { - return nullptr; - } - - return owner_graph_; -} - -/// @brief Build exist nodes -/// @param [out] error_code -/// @param [out] error_msg -/// @return void -void PartialGraphBuilder::BuildExistNodes(graphStatus &error_code, std::string &error_msg) { - std::string node_name; - for (auto &exist_node : exist_nodes_) { - if (exist_node == nullptr) { - error_code = GRAPH_FAILED; - error_msg = "Build exist nodes failed: node is NULL."; - return; - } - - node_name = exist_node->GetName(); - if (exist_node->GetOwnerComputeGraph() != owner_graph_) { - error_code = GRAPH_FAILED; - error_msg = "Build exist nodes failed: node " + node_name + " not belongs to this graph."; - return; - } - - GELOGD("Add exist_node name:%s.", node_name.c_str()); - node_names_[node_name] = exist_node; - } - - GELOGD("Build exist nodes succ."); -} -graphStatus GraphUtils::MoveNodeToGraph(const NodePtr &node, ComputeGraph &dst_graph) { - GE_ASSERT_SUCCESS(IsolateNode(node, {})); - GE_ASSERT_SUCCESS(RemoveNodesWithoutRelink(node->GetOwnerComputeGraph(), {node})); - GE_ASSERT_NOTNULL(dst_graph.AddNode(node)); - GE_ASSERT_SUCCESS(node->SetOwnerComputeGraph(dst_graph.shared_from_this())); - return GRAPH_SUCCESS; -} - -graphStatus GraphUtils::RemoveJustNodes(const ComputeGraphPtr &compute_graph, - const std::unordered_set &nodes) { - GE_CHECK_NOTNULL(compute_graph); - GE_CHECK_NOTNULL(compute_graph->impl_); - - size_t success_removed_nodes_size = 0U; - for (auto iter = compute_graph->impl_->nodes_.begin(); iter != compute_graph->impl_->nodes_.end();) { - if (nodes.count(*iter) > 0U) { - GELOGD("Remove %s from graph %s.", (*iter)->GetNamePtr(), compute_graph->GetName().c_str()); - iter = compute_graph->impl_->nodes_.erase(iter); - --(compute_graph->impl_->direct_nodes_size_); - ++success_removed_nodes_size; - } else { - ++iter; - } - } - const auto to_be_remove_nodes_size = nodes.size(); - if (success_removed_nodes_size != to_be_remove_nodes_size) { - GELOGW("Successfully remove %zu nodes but there are %zu nodes to be delete", success_removed_nodes_size, - to_be_remove_nodes_size); - } - return GRAPH_SUCCESS; -} - -namespace { -bool IsCurrentNodeHasMaxTopid(const NodePtr &node, const std::vector out_nodes) { - GE_ASSERT_NOTNULL(node->GetOpDesc()); - int64_t cur_topid = node->GetOpDesc()->GetId(); - for (const auto &out_node : out_nodes) { - GE_ASSERT_NOTNULL(out_node->GetOpDesc()); - if (out_node->GetOpDesc()->GetId() > cur_topid) { - GELOGD("Current node %s does not have max topid.", node->GetName().c_str()); - return false; - } - } - return true; -} - -bool HasSameStreamId(const NodePtr &node, const std::vector out_nodes) { - GE_ASSERT_NOTNULL(node->GetOpDesc()); - int64_t cur_stream_id = node->GetOpDesc()->GetStreamId(); - std::string node_name = node->GetName(); - for (const auto &out_node : out_nodes) { - GE_ASSERT_NOTNULL(out_node->GetOpDesc()); - auto out_node_stream_id = out_node->GetOpDesc()->GetStreamId(); - if (cur_stream_id == kInvalidStream) { - cur_stream_id = out_node_stream_id; - node_name = out_node->GetName(); - continue; - } - if (out_node_stream_id != cur_stream_id && out_node_stream_id != kInvalidStream) { - GELOGD("Node %s stream id[%lld] is not same with node %s stream id[%lld].", - node_name.c_str(), cur_stream_id, out_node->GetName().c_str(), out_node_stream_id); - return false; - } - } - return true; -} - -bool HasRefAttr(const std::vector out_nodes) { - for (const auto &out_node : out_nodes) { - GE_ASSERT_NOTNULL(out_node->GetOpDesc()); - bool is_ref = false; - (void)ge::AttrUtils::GetBool(out_node->GetOpDesc(), ATTR_NAME_REFERENCE, is_ref); - if (is_ref) { - GELOGD("Node %s has ref attr.", out_node->GetName().c_str()); - return true; - } - } - return false; -} -} // namespace - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -graphStatus GraphUtils::GetSupportInplaceOutput(const NodePtr &node, - std::map> &out_index_to_refable_in_indexes) { - // 找到带output_inplace_ability属性的节点 - GE_ASSERT_NOTNULL(node); - auto node_op_desc = node->GetOpDesc(); - GE_ASSERT_NOTNULL(node_op_desc); - // 获取output_inplace_ability属性,Ge拿到的就是真实的index,不是ir index - std::vector> output_inplace_index_list; - if (!ge::AttrUtils::GetListListInt(node_op_desc, ATTR_NAME_OUTPUT_INPLACE_ABILITY, output_inplace_index_list)) { - return GRAPH_SUCCESS; - } - // 类似这种键值对:{{0,0},{0,1},{0,3},{1,4}} - constexpr size_t kInplaceAbilitySize = 2U; - for (auto &inplace_index : output_inplace_index_list) { - if (inplace_index.size() != kInplaceAbilitySize) { - GELOGW("The size %u of inplace index is not invalid, must be equal to 2.", inplace_index.size()); - return GRAPH_FAILED; - } - GE_ASSERT_TRUE(ge::IntegerChecker::Compat(inplace_index[0])); - GE_ASSERT_TRUE(ge::IntegerChecker::Compat(inplace_index[1])); - size_t output_index = inplace_index[0]; - size_t input_index = inplace_index[1]; - // 判断输入节点对应的所有输出之中,当前节点的topid是否最大,streamid是否相同 - auto in_node = NodeUtils::GetInDataNodeByIndex(*node, static_cast(input_index)); - GE_ASSERT_NOTNULL(in_node); - GE_ASSERT_NOTNULL(node->GetOutDataAnchor(static_cast(output_index))); - auto out_nodes = in_node->GetOutDataNodesPtr(); - GELOGD("Check whether node %s's %zu output can be inplaced, input node is %s input_index[%zu].", - node->GetName().c_str(), output_index, in_node->GetName().c_str(), input_index); - if (IsCurrentNodeHasMaxTopid(node, out_nodes) && HasSameStreamId(in_node, out_nodes) && !HasRefAttr(out_nodes)) { - // 一个输出对应多个输入 - out_index_to_refable_in_indexes[output_index].push_back(input_index); - } - } - - for (auto &item : out_index_to_refable_in_indexes) { - for (auto &index : item.second) { - GELOGD("Node %s's output[%zu] can inplace input[%zu].", node->GetName().c_str(), item.first, index); - } - } - - return GRAPH_SUCCESS; -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -ComputeGraphPtr GraphUtils::BuildGraphFromNodes(const std::unordered_set &nodes, const std::string &name) { - if (nodes.empty()) { - GELOGW("nodes is empty, no need to build subgraph"); - return nullptr; - } - - GraphInfo graph_info; - BuildGraphInfoFromNodes(nodes, graph_info); - return BuildGraph(graph_info, name); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -ComputeGraphPtr GraphUtils::BuildGraphInternal(const GraphUtils::GraphInfo &graph_info, - const string &name, - const NodePtr &parent_node) { - // 构造 graph_builder - CompleteGraphBuilder graph_builder(name, false); - - // 如果是子图,设置父节点 - if (parent_node != nullptr) { - (void)graph_builder.SetParentNode(parent_node); - } - - // 添加节点 - for (const auto &node : graph_info.nodes_) { - (void)graph_builder.AddNode(GraphUtils::CopyOpDesc(node->GetOpDesc())); - } - - // 设置输入 - uint32_t index = 0U; - std::map input_mapping; - for (const auto &item : graph_info.data_inputs_) { - for (const auto &in_data_anchor : item.second.second) { - (void)graph_builder.SetInput(index, {in_data_anchor->GetOwnerNodeBarePtr()->GetName()}, - {static_cast(in_data_anchor->GetIdx())}); - if (parent_node != nullptr) { // 仅子图需要设置输入映射 - input_mapping[index] = index; - } - index++; - } - } - - // 设置输入映射(仅子图) - if (parent_node != nullptr) { - (void)graph_builder.SetInputMapping(input_mapping); - } - - // 添加输出 - index = 0U; - std::map output_mapping; - for (const auto &item : graph_info.data_outputs_) { - (void)graph_builder.AddOutput(item.second.first->GetOwnerNodeBarePtr()->GetName(), - static_cast(item.second.first->GetIdx())); - if (parent_node != nullptr) { // 仅子图需要设置输出映射 - output_mapping[index] = index; - } - index++; - } - - // 设置输出映射(仅子图) - if (parent_node != nullptr) { - (void)graph_builder.SetOutputMapping(output_mapping); - } - - // 添加目标节点 - for (const auto &item : graph_info.ctrl_outputs_) { - (void)graph_builder.AddTarget(item.first->GetOwnerNodeBarePtr()->GetName()); - } - - // 添加数据边 - for (const auto &data_edge : graph_info.inner_data_edges_) { - (void)graph_builder.AddDataLink(data_edge.first->GetOwnerNodeBarePtr()->GetName(), - static_cast(data_edge.first->GetIdx()), - data_edge.second->GetOwnerNodeBarePtr()->GetName(), - static_cast(data_edge.second->GetIdx())); - } - - // 添加控制边 - for (const auto &ctrl_edge : graph_info.inner_ctrl_edges_) { - (void)graph_builder.AddControlLink(ctrl_edge.first->GetOwnerNodeBarePtr()->GetName(), - ctrl_edge.second->GetOwnerNodeBarePtr()->GetName()); - } - - // 构建图 - graphStatus error_code = GRAPH_SUCCESS; - std::string error_msg; - auto graph = graph_builder.Build(error_code, error_msg); - if (graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Build graph %s failed:%s.", name.c_str(), error_msg.c_str()); - GELOGE(error_code, "[Build][Graph] %s failed:%s.", name.c_str(), error_msg.c_str()); - } - return graph; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::GenDumpOnnxFileName( - const ComputeGraphPtr &compute_graph, const std::string &suffix, std::string &real_path_name) { -#ifdef FMK_SUPPORT_DUMP - static std::atomic atomic_file_index(0); - const auto file_index = atomic_file_index.fetch_add(1); - GELOGD("Start to dump ge onnx file: %" PRId64, file_index); - if (CheckDumpGraphNum(file_index) != GRAPH_SUCCESS) { - return FAILED; - } - - std::stringstream stream_file_name; - GetDumpGraphPrefix(stream_file_name); - if (mmAccess2(stream_file_name.str().c_str(), M_F_OK) != EN_OK) { - const int32_t ret = CreateDir(stream_file_name.str()); - if (ret != 0) { - GELOGW("[DumpGraph][CreateDir] Create dump graph dir failed, path:%s", stream_file_name.str().c_str()); - stream_file_name.str(""); - stream_file_name << "./"; - } - } - - std::string single_op = ""; - if (IsSingleOpScene(compute_graph)) { - single_op = "_aclop"; - } - std::stringstream ss; - ss << "ge_onnx_" << std::setw(kDumpGraphIndexWidth) << std::setfill('0') << file_index << single_op; - ss << "_graph_" << compute_graph->GetGraphID() << "_" << GetSanitizedName(suffix) << ".pbtxt"; - std::string dump_file_name = ss.str(); - if ((dump_file_name.length()) >= kNameMax) { - dump_file_name = dump_file_name.substr(0U, kNameMax - 7U) + ".pbtxt"; - GELOGW("[Check][Param] File name is too longer!, file:%s", dump_file_name.c_str()); - } - std::string proto_file = stream_file_name.str() + dump_file_name; - - char_t real_path[MMPA_MAX_PATH] = {}; - /// Returning nullptr means 3 case as follows: - /// a.path is MMPA_MAX_PATH chars or more - /// b.the file does not exist - /// c.the path has no permissions - /// Distinguish between last the two cases in the function WriteProtoToTextFile call open() - auto const ret = mmRealPath(proto_file.c_str(), &(real_path[0]), MMPA_MAX_PATH); - if (ret != EN_OK) { - GELOGD("[Get][RealPath]file does not exist, it will be create. ret:%d", ret); - } - real_path_name = real_path; - return SUCCESS; -#else - (void)compute_graph; - (void)suffix; - (void)real_path_name; - GELOGW("[Gen][OnnxFileName] Need to define FMK_SUPPORT_DUMP for dump graph."); - return FAILED; -#endif -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::GenDumpTxtFileName(const ComputeGraphPtr &compute_graph, const std::string &suffix, - const std::string &user_graph_name, std::string &real_path_name) { -#ifdef FMK_SUPPORT_DUMP - static std::atomic atomic_file_index(0); - const auto file_index = atomic_file_index.fetch_add(1); - GELOGD("Start to dump om txt: %" PRId64, file_index); - if (CheckDumpGraphNum(file_index) != GRAPH_SUCCESS) { - return FAILED; - } - std::stringstream stream_file_name; - std::string single_op = ""; - if (IsSingleOpScene(compute_graph)) { - single_op = "aclop_"; - } - stream_file_name << single_op << "graph_" << compute_graph->GetGraphID() << "_" << suffix; // add graphId, like graph_x_xxxx - auto const ret = GetDumpRealPath(file_index, stream_file_name.str(), user_graph_name, real_path_name); - if (ret != GRAPH_SUCCESS) { - GELOGW("[Get][RealPath]realpath invalid."); - return FAILED; - } - return GRAPH_SUCCESS; -#else - (void)compute_graph; - (void)suffix; - (void)user_graph_name; - (void)real_path_name; - GELOGW("[Gen][TxtFileName] Need to define FMK_SUPPORT_DUMP for dump graph."); - return FAILED; -#endif -} -} // namespace ge diff --git a/graph/utils/graph_utils_ex.cc b/graph/utils/graph_utils_ex.cc deleted file mode 100644 index dde8b275524f81f6d8c212b1d012e8fcdcd8fc67..0000000000000000000000000000000000000000 --- a/graph/utils/graph_utils_ex.cc +++ /dev/null @@ -1,129 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/utils/graph_utils_ex.h" - -#include "common/ge_common/util.h" -#include "common/util/trace_manager/trace_manager.h" -#include "graph/refiner/format_refiner.h" -#include "graph/normal_graph/operator_impl.h" -#include "graph/common_error_codes.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/node_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/op_desc_utils_ex.h" -#include "graph/utils/tensor_utils.h" -#include "graph/utils/transformer_utils.h" -#include "graph/utils/node_utils_ex.h" -#include "common/util/mem_utils.h" -#include "graph/utils/op_type_utils.h" - -namespace ge { -graphStatus GraphUtilsEx::InferOriginFormat(const ComputeGraphPtr &graph) { - return FormatRefiner::InferOrigineFormat(graph); -} - -graphStatus GraphUtilsEx::InferShapeInNeed(const ComputeGraphPtr &graph) { - GE_LOGW_IF(graph->TopologicalSorting() != GRAPH_SUCCESS, "Verify failed."); - for (const auto &node_ptr : graph->GetAllNodes()) { - GE_CHECK_NOTNULL(node_ptr); - const auto op_desc = node_ptr->GetOpDesc(); - bool is_need_infer = false; - (void)AttrUtils::GetBool(op_desc, NEED_INFER, is_need_infer); - if (is_need_infer) { - if (NodeUtilsEx::Verify(node_ptr) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Verifying %s failed.", node_ptr->GetName().c_str()); - GELOGE(FAILED, "[Call][Verify] Verifying %s failed.", node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - - const graphStatus status = NodeUtilsEx::InferShapeAndType(node_ptr); - if ((!OpTypeUtils::IsDataNode(node_ptr->GetType())) && (status == GRAPH_PARAM_INVALID)) { - GELOGI("Op %s does not have the IMPLEMT_INFERFUNC definition, " - "and subsequent operators no longer perform shape inference.", - node_ptr->GetName().c_str()); - break; - } - if (status != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Inferring %s failed.", node_ptr->GetName().c_str()); - GELOGE(FAILED, "[Call][InferShapeAndType] Inferring %s failed.", node_ptr->GetName().c_str()); - return GRAPH_FAILED; - } - - for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) { - GE_CHECK_NOTNULL(out_anchor->GetOwnerNodeBarePtr()->GetOpDesc()); - auto output_tensor = out_anchor->GetOwnerNodeBarePtr()->GetOpDesc()->MutableOutputDesc( - static_cast(out_anchor->GetIdx())); - GE_CHECK_NOTNULL(output_tensor); - TensorUtils::SetRealDimCnt(*(output_tensor.get()), - static_cast(output_tensor->GetShape().GetDims().size())); - - for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) { - const auto peer_in_tensor_desc = peer_anchor->GetOwnerNodeBarePtr()->GetOpDesc()->MutableInputDesc( - static_cast(peer_anchor->GetIdx())); - GE_CHECK_NOTNULL(peer_in_tensor_desc); - OpDescUtilsEx::UpdateShapeAndDType(output_tensor, peer_in_tensor_desc); - } - } - } - } - return GRAPH_SUCCESS; -} - -std::vector GraphUtilsEx::GetUserInputDataNodes(const ComputeGraphPtr &compute_graph) { - std::vector user_input_nodes; - for (const auto &node : compute_graph->GetInputNodes()) { - if (!AttrUtils::HasAttr(node->GetOpDesc(), "_is_multi_batch_shape_data")) { - user_input_nodes.emplace_back(node); - } - } - return user_input_nodes; -} - -graphStatus GraphUtilsEx::CopyGraph(const Graph &src_graph, Graph &dst_graph) { - std::string graph_name; - AscendString ascend_name; - if (dst_graph.GetName(ascend_name) == GRAPH_SUCCESS) { - graph_name = std::string((ascend_name.GetString() != nullptr) ? ascend_name.GetString() : ""); - } - if (graph_name.empty() && (src_graph.GetName(ascend_name) == GRAPH_SUCCESS)) { - graph_name = std::string((ascend_name.GetString() != nullptr) ? ascend_name.GetString() : ""); - } - - ComputeGraphPtr new_compute_graph = MakeShared(graph_name); - GE_CHECK_NOTNULL(new_compute_graph); - const ComputeGraphPtr src_compute_graph = GraphUtilsEx::GetComputeGraph(src_graph); - GE_CHECK_NOTNULL(src_compute_graph); - if (src_compute_graph->GetParentGraph() != nullptr) { - GELOGE(GRAPH_FAILED, "[Check][RootGraph] Only support copy root graph, current graph name:%s, " - "parent graph name:%s.", src_compute_graph->GetName().c_str(), - src_compute_graph->GetParentGraph()->GetName().c_str()); - return GRAPH_FAILED; - } - const int32_t depth = 0; - std::map node_old_2_new; - std::map op_desc_old_2_new; - graphStatus ret = GraphUtils::CopyComputeGraph(src_compute_graph, new_compute_graph, - node_old_2_new, op_desc_old_2_new, depth); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Copy][Graph] failed, ret:%d.", ret); - return GRAPH_FAILED; - } - Graph tmp_graph = GraphUtilsEx::CreateGraphFromComputeGraph(new_compute_graph); - ret = GraphUtilsEx::CopyGraphImpl(src_graph, tmp_graph, node_old_2_new, op_desc_old_2_new); - if (ret != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Copy][GraphImpl] failed, ret:%d.", ret); - return GRAPH_FAILED; - } - std::swap(dst_graph, tmp_graph); - return GRAPH_SUCCESS; -} -} // namespace ge - diff --git a/graph/utils/inference_rule.cc b/graph/utils/inference_rule.cc deleted file mode 100644 index 990b191e9fbe420639ae20a29ffc36eef6e45e97..0000000000000000000000000000000000000000 --- a/graph/utils/inference_rule.cc +++ /dev/null @@ -1,903 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd.|Hisilicon Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ -#include "inference_rule.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "common/checker.h" -#include "external/graph/ge_error_codes.h" -#include "graph/utils/attr_utils.h" -#include "graph/debug/ge_attr_define.h" - -using Json = nlohmann::json; -namespace ge { -namespace { -/** - * @brief 表达一个符号的来源 - * - * 用于描述某个符号源自输入的某个维度或某个值。并支持生成对应的C++定义代码片段。 - */ -class SymbolDef { - public: - explicit SymbolDef(const std::string &name) : name_(name), is_value_(name[0] == 'v') {} - - void RecordSource(size_t input_index, size_t offset) { - sources_.emplace_back(input_index, offset); - } - - [[nodiscard]] std::string Codegen() const { - std::stringstream ss; - if (!sources_.empty()) { - const size_t input = sources_.front().first; - const size_t offset = sources_.front().second; - if (is_value_) { - ss << " GET_SYMBOL_VALUE(" << name_ << ", " << input << ", " << offset << ");"; - } else { - ss << " GET_SYMBOL_DIM(" << name_ << ", " << input << ", " << offset << ");"; - } - } - return ss.str(); - } - - private: - std::string name_; - std::vector> sources_; - bool is_value_; -}; - -/** - * @brief 表达一个Shape维度由符号表达的输出Tensor - * - * 用于描述输出Shape每个维度的计算表达式,表达式是支持受限的表达式(+,-,*,Div,Floor,Ceil,Mod,Pow),也可以是常量表达式。 - */ -class SymbolTensor { - public: - explicit SymbolTensor(const size_t output_index) : output_index_(output_index) {} - - void AppendDim(const std::string &dim) { - dims_.push_back(dim); - } - - // 生成执行时的Shape设置代码片段 - [[nodiscard]] std::string Codegen() const { - std::stringstream ss; - ss << " SET_OUTPUT_RANK(" << output_index_ << ", " << dims_.size() << ");" << std::endl; - for (size_t i = 0; i < dims_.size(); i++) { - ss << " SET_OUTPUT_DIM(" << output_index_ << ", " << i << ", static_cast(" << dims_[i] << "));" - << std::endl; - } - return ss.str(); - } - - // 生成编译时的Shape设置代码片段 - [[nodiscard]] std::string CodegenCompileTime() const { - std::stringstream ss; - ss << " SET_OUTPUT_RANK(" << output_index_ << ", " << dims_.size() << ");" << std::endl; - for (size_t i = 0; i < dims_.size(); i++) { - const bool has_symbol = dims_[i].find('s') != std::string::npos || dims_[i].find('v') != std::string::npos; - ss << " SET_OUTPUT_DIM(" << output_index_ << ", " << i << ", " << (has_symbol ? "-1" : dims_[i]) << ");" - << std::endl; - } - return ss.str(); - } - - private: - size_t output_index_; - std::vector dims_; -}; - -/** - * @brief Shape推导规则的JSON解析器 - * - * 完成推导规则JSON的解析、合法性校验以及到InferShape代码的生成。 - */ -class RuleJsonParser { - public: - std::string ParseJson(const std::string &json_str) { - std::stringstream ss; - Json rule_json; - try { - rule_json = Json::parse(json_str); - } catch (const std::exception &e) { - ss << "Error parsing json: " << e.what(); - return ss.str(); - } - - if (!rule_json.contains("shape")) { - ss << "Missing 'shape' field in rule json."; - return ss.str(); - } - - auto shape_json = rule_json["shape"]; - std::vector> inputs; - std::vector> outputs; - - std::string error_msg = ParseJsonToVecVecString(shape_json["inputs"], inputs); - if (!error_msg.empty()) { - ss << "Invalid 'shape.inputs' field: " << shape_json["inputs"] << " " << error_msg; - return ss.str(); - } - error_msg = ParseJsonToVecVecString(shape_json["outputs"], outputs); - if (!error_msg.empty()) { - ss << "Invalid 'shape.outputs' field: " << shape_json["outputs"] << " " << error_msg; - return ss.str(); - } - std::map symbol_defs; - error_msg = GetInputSymbolDefs(inputs, symbol_defs); - if (!error_msg.empty()) { - ss << "Error parsing input symbols: " << error_msg; - return ss.str(); - } - error_msg = GetOutputSymbolTensors(outputs, symbol_defs, symbols_, symbol_tensors_); - if (!error_msg.empty()) { - ss << "Error parsing output tensors: " << error_msg; - return ss.str(); - } - return ss.str(); - } - - void CodegenInferShape(std::stringstream &code_ss) const { - code_ss << R"(extern "C" {)"; - code_ss << R"(bool infer_shape(Ctx *ctx) {)" << std::endl; - - for (const auto &symbol : symbols_) { - code_ss << symbol.Codegen() << std::endl; - } - - code_ss << std::endl; - - for (const auto &tensor : symbol_tensors_) { - code_ss << tensor.Codegen() << std::endl; - } - - code_ss << " return true;\n}" << std::endl; - - code_ss << R"(bool infer_shape_on_compile(Ctx *ctx) {)" << std::endl; - for (const auto &tensor : symbol_tensors_) { - code_ss << tensor.CodegenCompileTime() << std::endl; - } - code_ss << " return true;\n}"; - - code_ss << "}"; - } - - private: - std::vector symbols_; - std::vector symbol_tensors_; - - static std::string GetInputSymbolDefs(const std::vector> &inputs, - std::map &symbol_defs) { - for (size_t i = 0; i < inputs.size(); i++) { - const auto &dims = inputs[i]; - for (size_t j = 0; j < dims.size(); j++) { - const auto &dim = dims[j]; - if (dim.empty() || IsNumber(dim)) { - continue; - } - if (!IsSymbol(dim)) { - std::stringstream ss; - ss << "Invalid input[" << i << "].size(" << j << "): " << dim - << ", symbol dimension must start with 's' or 'v' and follow with a number"; - return ss.str(); - } - auto it = symbol_defs.find(dim); - if (it != symbol_defs.end()) { - // 已经存在,记录来源 - it->second.RecordSource(i, j); - } else { - // 新建符号定义 - SymbolDef symbol(dim); - symbol.RecordSource(i, j); - symbol_defs.emplace(dim, std::move(symbol)); - } - } - } - return ""; - } - - static std::string GetOutputSymbolTensors(const std::vector> &outputs, - const std::map &symbol_defs, - std::vector &used_symbol_defs, - std::vector &symbol_tensors) { - std::set used_symbols; - std::stringstream ss; - for (size_t i = 0; i < outputs.size(); i++) { - symbol_tensors.emplace_back(i); - const auto &dims = outputs[i]; - - for (size_t j = 0; j < dims.size(); j++) { - auto &dim = dims[j]; - if (dim.empty()) { - ss << "Invalid output[" << i << "].size(" << j << "): empty dimension"; - return ss.str(); - } - std::string error_msg = ValidateDimExpr(dim, used_symbols); - if (!error_msg.empty()) { - ss << "Invalid dim expr '" << dim << "': " << error_msg; - return ss.str(); - } - symbol_tensors.back().AppendDim(dim); - } - } - - for (const auto &symbol : used_symbols) { - auto it = symbol_defs.find(symbol); - if (it == symbol_defs.end()) { - ss << "Symbol '" << symbol << "' used in output but not defined in inputs"; - return ss.str(); - } - used_symbol_defs.emplace_back(it->second); - } - - return ""; - } - - static std::string ValidateDimExpr(std::string expr, std::set &used_symbols) { - expr.erase(remove_if(expr.begin(), expr.end(), isspace), expr.end()); - - // 2. 定义 token 正则 - // - 函数/变量名: [A-Za-z0-9_]* - // - 运算符: [+*()-,] - const std::regex token_regex(R"([A-Za-z0-9_]*|\+|\-|\*|\(|\)|,)"); - const auto begin = std::sregex_iterator(expr.begin(), expr.end(), token_regex); - const auto end = std::sregex_iterator(); - - std::vector tokens; // 存储匹配到的 token,应当为操作符、操作数、函数名、括号之一 - for (auto it = begin; it != end; ++it) { - if (!it->str().empty()) { - tokens.push_back(it->str()); - } - } - - // 检查是否所有字符都被匹配(防止非法字符) - size_t totalLen = 0U; - for (auto &t : tokens) totalLen += t.size(); - if (totalLen != expr.size()) { - return "Expression contains invalid characters"; - } - - // 3. 遍历 tokens 检查合法性 - std::stack func_stack; - for (size_t i = 0U; i < tokens.size(); i++) { - const std::string &token = tokens[i]; - - if (std::isalpha(token[0])) { - if (i + 1U < tokens.size() && tokens[i + 1U] == "(") { - if (!IsSupportedFunc(token)) { - return "Invalid function: " + token + ", supported [Div, Floor, Ceil, Pow, Mod]"; - } - } else { - used_symbols.insert(token); - } - } else if (token == "(") { - func_stack.emplace("("); - } else if (token == ")") { - if (func_stack.empty()) { - return "Unmatched ')'"; - } - func_stack.pop(); - } else if (IsSupportedOperator(token) || IsNumber(token)) { - // 运算符不做额外语法检查,由C++编译器处理 - } else { - return "Invalid identifier: '" + token + "', expected start with 's' or 'v' and follow with a number"; - } - } - - if (!func_stack.empty()) { - return "Unmatched '('"; - } - - return ""; - } - - static std::string ParseJsonToVecVecString(const Json &json, std::vector> &result) { - if (json.is_null()) { - return ""; - } - if (!json.is_array()) { - return "field must be an array or null."; - } - - for (const auto &dims : json) { - if (dims.is_null()) { - result.emplace_back(); - continue; - } - if (!dims.is_array()) { - return "element must be an array of dimension expressions."; - } - result.emplace_back(); - for (const auto &dim : dims) { - if (dim.is_null()) { - result.back().emplace_back(); - continue; - } - if (!dim.is_string() && !dim.is_number_integer()) { - return "dimension expression must be a string or integer."; - } - result.back().push_back(dim.is_string() ? dim.get() : std::to_string(dim.get())); - } - } - return ""; - } - - static bool IsSymbol(const std::string &token) { - // 符号必须以 's' 或 'v' 开头,后跟数字 - return token.size() > 1 && (token[0] == 's' || token[0] == 'v') && IsNumber(&token[1]); - } - - static bool IsSupportedFunc(const std::string &func) { - static const std::unordered_set kAllowedFuncs = {"Div", "Floor", "Ceil", "Pow", "Mod"}; - return kAllowedFuncs.find(func) != kAllowedFuncs.end(); - } - - static bool IsSupportedOperator(const std::string &op) { - // 支持的运算符 - return op == "+" || op == "-" || op == "*" || op == ","; - } - - static bool IsNumber(const std::string &s) { - try { - size_t idx; - std::stod(s, &idx); - return idx == s.size(); // 必须整个字符串都被解析 - } catch (...) { - return false; - } - } -}; - -/** - * @brief Cpp JIT编译器 - * - * 用于将生成的C++代码编译为内存中的.so,并加载以供调用。 - */ -class CppJitCompiler { - public: - std::string Error() const { - return err_.str(); - } - - std::vector Compile(const std::string &source_code) { - std::vector so_data; - - const int32_t cpp_fd = CreateMemFd("source.cpp"); - const int32_t so_fd = CreateMemFd("output.so"); - if (cpp_fd == -1 || so_fd == -1) { - err_ << "mem fd create failed: " << strerror(errno); - return {}; - } - - ClearCloexec(cpp_fd); - ClearCloexec(so_fd); - - if (!WriteToFd(cpp_fd, source_code)) { - err_ << "write source code to mem fd failed: " << strerror(errno); - return {}; - } - - lseek(cpp_fd, 0, SEEK_SET); - lseek(so_fd, 0, SEEK_SET); - - if (!CompileToSo(cpp_fd, so_fd)) { - return {}; - } - - lseek(so_fd, 0, SEEK_SET); - - char buf[4096]; - ssize_t n; - while ((n = read(so_fd, buf, sizeof(buf))) > 0) { - so_data.insert(so_data.end(), buf, buf + n); - } - - close(cpp_fd); - close(so_fd); - return so_data; - } - - void *Load(const uint8_t *so_binary, const size_t so_size) { - static std::atomic loaded{0}; - - char tmp_filename[256] = {}; - // make sure the filename is unique for disable cache for dlopen - const std::string filename = "/tmp/temp_so" + std::to_string(loaded++) + "XXXXXX"; - if (snprintf_s(tmp_filename, sizeof(tmp_filename), filename.size(), "%s", filename.c_str()) < 0) { - err_ << "snprintf file name failed: " << strerror(errno); - return nullptr; - } - - const int32_t fd = mkstemp(tmp_filename); - if (fd == -1) { - err_ << "mkstemp failed: " << strerror(errno); - return nullptr; - } - - const ssize_t written = write(fd, so_binary, so_size); - if (written != static_cast(so_size)) { - err_ << "write so binary to temp file failed: " << strerror(errno); - close(fd); - unlink(tmp_filename); - return nullptr; - } - - close(fd); - - void *handle = dlopen(tmp_filename, RTLD_NOW | RTLD_LOCAL); - if (!handle) { - err_ << "dlopen failed: " << dlerror(); - unlink(tmp_filename); - return nullptr; - } - - unlink(tmp_filename); - return handle; - } - - private: - std::stringstream err_; - static std::string GetSystemCompiler() { - if (system("g++ --version > /dev/null 2>&1") == 0) { - return "g++"; - } - if (system("gcc --version > /dev/null 2>&1") == 0) { - return "gcc"; - } - return ""; - } - - static int32_t CreateMemFd(const std::string &name) { - return syscall(__NR_memfd_create, name.c_str(), MFD_CLOEXEC); - } - - static void ClearCloexec(const int32_t fd) { - const int32_t flags = fcntl(fd, F_GETFD); - if (flags != -1) { - fcntl(fd, F_SETFD, flags & ~FD_CLOEXEC); - } - } - - static bool WriteToFd(const int32_t fd, const std::string &data) { - size_t written = 0; - while (written < data.size()) { - const ssize_t n = write(fd, data.data() + written, data.size() - written); - if (n <= 0) { - return false; - } - written += n; - } - return true; - } - - bool CompileToSo(const int32_t input_fd, const int32_t output_fd) { - const std::string input_path = "/proc/self/fd/" + std::to_string(input_fd); - const std::string output_path = "/proc/self/fd/" + std::to_string(output_fd); - - const std::string compiler = GetSystemCompiler(); - if (compiler.empty()) { - err_ << "No C++ compiler found (g++ or gcc) for jit compiling symbol infer"; - return false; - } - - const std::vector args = { - compiler.c_str(), "-x", "c++", "-shared", "-fPIC", "-o", output_path.c_str(), - input_path.c_str(), "-lstdc++", nullptr}; - - const pid_t pid = fork(); - if (pid == 0) { - execvp(compiler.c_str(), const_cast(args.data())); - _exit(1); - } - - int32_t status = 0; - waitpid(pid, &status, 0); - const bool succeed = WIFEXITED(status) && WEXITSTATUS(status) == 0; - if (!succeed) { - err_ << "syntax error"; - } - return succeed; - } -}; - -const std::string kHeader = R"( -#include -#include - -inline double Pow(const double base, const double exp) { return std::pow(base, exp); } -inline double Floor(const double x) { return std::floor(x); } -inline double Div(const double x, const double y) { return x / y; } -inline double Ceil(const double x) { return std::ceil(x); } -inline double Mod(const double a, const double b) { - double r = std::fmod(a, b); - if ((r != 0) && ((b < 0 && r > 0) || (b > 0 && r < 0))) { - r += b; - } - return r; -} - -extern "C" { -int64_t version() { return 1; } -} - -class Ctx { - public: - virtual ~Ctx() = default; - virtual bool GetInputDim(int64_t input, int64_t dim_index, int64_t &dim) = 0; - virtual bool GetInputValue(int64_t input, int64_t offset, int64_t &value) = 0; - virtual bool SetOutputDimNum(int64_t output, int64_t dim_num) = 0; - virtual bool SetOutputDim(int64_t output, int64_t dim_index, int64_t dim) = 0; - virtual void SetError(const char *) = 0; -}; - -#define GET_SYMBOL_DIM(S, INPUT, DIM) \ -int64_t S##_int; \ -if (!ctx->GetInputDim(INPUT, DIM, S##_int)) { \ - ctx->SetError("Failed to get dim sym '" #S "' from input[" #INPUT "], dim: " #DIM); \ - return false; \ -} \ -const double S = static_cast(S##_int); - -#define GET_SYMBOL_VALUE(S, INPUT, DIM) \ -int64_t S##_int; \ -if (!ctx->GetInputValue(INPUT, DIM, S##_int)) { \ - ctx->SetError("Failed to get value sym '" #S "' from input[" #INPUT "], offset: " #DIM); \ - return false; \ -} \ -const double S = static_cast(S##_int); - -#define SET_OUTPUT_RANK(OUTPUT, RANK) \ -if (!ctx->SetOutputDimNum(OUTPUT, RANK)) { \ - ctx->SetError("Failed to set rank " #RANK " for output[" #OUTPUT "]"); \ - return false; \ -} - -#define SET_OUTPUT_DIM(OUTPUT, INDEX, DIM) \ -if (!ctx->SetOutputDim(OUTPUT, INDEX, DIM)) { \ - ctx->SetError("Failed to set dim " #DIM " for output[" #OUTPUT "], dim: " #INDEX); \ - return false; \ -} -)"; - -/** - * @brief 适用于GertCtx的包装器 - * - * Jit生成InferShape代码时,设计时保证不使用任何本地头文件参与编译,通过运行时的Ctx封装,隔离本地文件依赖。 - */ -class GertContextWrapper final : public ShapeInferenceRule::Ctx { - public: - explicit GertContextWrapper(gert::InferShapeContext *ctx) : ctx_(ctx) {} - - bool GetInputDim(int64_t input, int64_t dim_index, int64_t &dim) override { - const auto shape = ctx_->GetInputShape(input); - if (shape == nullptr) { - return false; - } - dim = shape->GetDim(dim_index); - return true; - } - - bool GetInputValue(int64_t input, int64_t offset, int64_t &value) override { - auto *tensor = ctx_->GetInputTensor(input); - if (tensor == nullptr || tensor->GetAddr() == nullptr) { - return false; - } - if (offset < 0 || offset >= tensor->GetShapeSize()) { - return false; - } - if (tensor->GetDataType() == ge::DT_INT64) { - value = tensor->GetData()[offset]; - } else if (tensor->GetDataType() == ge::DT_INT32) { - value = tensor->GetData()[offset]; - } else if (tensor->GetDataType() == ge::DT_UINT32) { - value = tensor->GetData()[offset]; - } else { - SetError("Only int32, uint32 and int64 are supported for input value tensors"); - return false; - } - return true; - } - - bool SetOutputDimNum(int64_t output, int64_t dim_num) override { - const auto shape = ctx_->GetOutputShape(output); - if (shape == nullptr) { - return false; - } - shape->SetDimNum(dim_num); - return true; - } - - bool SetOutputDim(int64_t output, int64_t dim_index, int64_t dim) override { - const auto shape = ctx_->GetOutputShape(output); - if (shape == nullptr) { - return false; - } - shape->SetDim(dim_index, dim); - return true; - } - - void SetError(const char *msg) override { - if (msg != nullptr) { - error_message_ << msg << std::endl; - } - } - - std::string Error() const { - return error_message_.str(); - } - - private: - gert::InferShapeContext *ctx_ = nullptr; - std::stringstream error_message_; -}; - -template -class Cache { - public: - std::shared_ptr Get(const std::string &key) { - std::lock_guard lock(mtx_); - auto it = cache_.find(key); - if (it != cache_.end()) { - return it->second; - } - return nullptr; - } - - std::shared_ptr GetWithDefault(const std::string &key, const std::shared_ptr &value) { - std::lock_guard lock(mtx_); - return cache_.emplace(key, value).first->second; - } - - private: - std::mutex mtx_; - std::map> cache_; -}; - -Cache g_shape_rule_cache; -Cache g_dtype_rule_cache; -} // namespace - -ShapeInferenceRule::~ShapeInferenceRule() { - if (handle_) { - dlclose(handle_); - handle_ = nullptr; - infer_shape_ = nullptr; - infer_shape_on_compile_ = nullptr; - } -} - -ge::graphStatus ShapeInferenceRule::InferOnRuntime(Ctx *ctx) const { - if (!infer_shape_) { - ctx->SetError("infer_shape function is not set"); - return ge::GRAPH_FAILED; - } - if (!infer_shape_(ctx)) { - return ge::GRAPH_FAILED; - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus ShapeInferenceRule::InferOnCompile(Ctx *ctx) const { - if (!infer_shape_on_compile_) { - ctx->SetError("infer_shape_on_compile function is not set"); - return ge::GRAPH_FAILED; - } - if (!infer_shape_on_compile_(ctx)) { - return ge::GRAPH_FAILED; - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus ShapeInferenceRule::InferOnRuntime(gert::InferShapeContext *infer_shape_ctx) const { - GE_ASSERT_NOTNULL(infer_shape_ctx); - GertContextWrapper ctx(infer_shape_ctx); - const ge::graphStatus result = InferOnRuntime(&ctx); - if (result != ge::GRAPH_SUCCESS) { - GELOGE(ge::FAILED, "Failed infer shape by rule for op %s(%s): %s", infer_shape_ctx->GetNodeName(), - infer_shape_ctx->GetNodeType(), ctx.Error().c_str()); - } - return result; -} - -ge::graphStatus ShapeInferenceRule::InferOnCompile(gert::InferShapeContext *infer_shape_ctx) const { - GE_ASSERT_NOTNULL(infer_shape_ctx); - GertContextWrapper ctx(infer_shape_ctx); - const ge::graphStatus result = InferOnCompile(&ctx); - if (result != ge::GRAPH_SUCCESS) { - GELOGE(ge::FAILED, "Failed infer shape on compile by rule for op %s(%s): %s", infer_shape_ctx->GetNodeName(), - infer_shape_ctx->GetNodeType(), ctx.Error().c_str()); - } - return result; -} - -std::string InferenceRule::GetInferenceRule(const ge::OpDescPtr &op) { - if (op == nullptr) { - return ""; - } - std::string rule_json; - (void) ge::AttrUtils::GetStr(op, ge::ATTR_NAME_INFER_RULE, rule_json); - return rule_json; -} - -std::shared_ptr ShapeInferenceRule::FromOpDesc(const ge::OpDescPtr &op) { - std::string rule_json; - if (!ge::AttrUtils::GetStr(op, ge::ATTR_NAME_INFER_RULE, rule_json)) { - // Skip log error if op does not with rule - return nullptr; - } - return FromJsonString(rule_json); -} - -std::shared_ptr ShapeInferenceRule::FromJsonString(const std::string &json_str) { - auto cached = g_shape_rule_cache.Get(json_str); - if (cached != nullptr) { - return cached; - } - - const auto rule = std::make_shared(); - RuleJsonParser parser; - const std::string error_msg = parser.ParseJson(json_str); - if (!error_msg.empty()) { - *rule << error_msg; - return g_shape_rule_cache.GetWithDefault(json_str, rule); - } - - std::stringstream gen_code_ss; - parser.CodegenInferShape(gen_code_ss); - - std::stringstream code_ss; - code_ss << kHeader << std::endl; - code_ss << gen_code_ss.str() << std::endl; - - CppJitCompiler compiler; - const auto binary = compiler.Compile(code_ss.str()); - if (binary.empty()) { - *rule << "Failed to compile C++ code to shared object:\n" << gen_code_ss.str() << "\nError: " << compiler.Error(); - return g_shape_rule_cache.GetWithDefault(json_str, rule); - } - return g_shape_rule_cache.GetWithDefault(json_str, std::make_shared(FromCompiledBinary(binary))); -} - -ShapeInferenceRule ShapeInferenceRule::FromCompiledBinary(const uint8_t *binary, const size_t size) { - ShapeInferenceRule infer_handle; - CppJitCompiler compiler; - void *handle = compiler.Load(binary, size); - if (!handle) { - infer_handle << "Failed to load compiled shared object from memory: " << compiler.Error(); - return infer_handle; - } - - infer_handle.handle_ = handle; - infer_handle.infer_shape_ = (InferShapeFunc) dlsym(handle, "infer_shape"); - if (!infer_handle.infer_shape_) { - infer_handle << "dlsym infer_shape failed: " << dlerror(); - return infer_handle; - } - infer_handle.infer_shape_on_compile_ = (InferShapeFunc) dlsym(handle, "infer_shape_on_compile"); - if (!infer_handle.infer_shape_on_compile_) { - infer_handle << "dlsym infer_shape_on_compile failed: " << dlerror(); - return infer_handle; - } - return infer_handle; -} - -ShapeInferenceRule ShapeInferenceRule::FromCompiledBinary(const std::vector &binary) { - return FromCompiledBinary(binary.data(), binary.size()); -} - -ge::graphStatus ShapeInferenceRule::CompileJsonString(const std::string &json_str, std::vector &binary) { - RuleJsonParser parser; - const std::string error_msg = parser.ParseJson(json_str); - if (!error_msg.empty()) { - GELOGE(ge::FAILED, "%s", error_msg.c_str()); - return ge::GRAPH_FAILED; - } - - std::stringstream code_ss; - code_ss << kHeader << std::endl; - parser.CodegenInferShape(code_ss); - - CppJitCompiler compiler; - binary = compiler.Compile(code_ss.str()); - if (binary.empty()) { - GELOGE(ge::FAILED, "Failed to compile C++ code to shared object:%s,\nError:%s", code_ss.str().c_str(), - compiler.Error().c_str()); - return ge::GRAPH_FAILED; - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus DtypeInferenceRule::InferDtype(gert::InferDataTypeContext *infer_dtype_ctx) const { - GE_ASSERT_NOTNULL(infer_dtype_ctx); - if (!Error().empty()) { - GELOGE(ge::FAILED, "Failed infer dtype by rule for op %s(%s): %s", infer_dtype_ctx->GetNodeName(), - infer_dtype_ctx->GetNodeType(), Error().c_str()); - return ge::GRAPH_FAILED; - } - for (size_t i = 0U; i < dtypes_.size(); i++) { - GE_ASSERT_GRAPH_SUCCESS(infer_dtype_ctx->SetOutputDataType(i, dtypes_[i])); - } - return ge::GRAPH_SUCCESS; -} - -std::shared_ptr DtypeInferenceRule::FromOpDesc(const ge::OpDescPtr &op) { - std::string rule_json; - if (!ge::AttrUtils::GetStr(op, ge::ATTR_NAME_INFER_RULE, rule_json)) { - // Skip log error if op does not with rule - return nullptr; - } - return FromJsonString(rule_json); -} - -std::shared_ptr DtypeInferenceRule::FromJsonString(const std::string &json_str) { - auto cached = g_dtype_rule_cache.Get(json_str); - if (cached != nullptr) { - return cached; - } - - const auto rule = std::make_shared(); - Json rule_json; - try { - rule_json = Json::parse(json_str); - } catch (const std::exception &e) { - *rule << "Error parsing json: " << e.what(); - return g_dtype_rule_cache.GetWithDefault(json_str, rule); - } - - if (!rule_json.contains("dtype")) { - *rule << "Missing 'dtype' field in rule json."; - return g_dtype_rule_cache.GetWithDefault(json_str, rule); - } - - const auto dtype_json = rule_json["dtype"]; - if (dtype_json.is_null()) { - *rule << "Filed 'dtype' must not be null."; - return g_dtype_rule_cache.GetWithDefault(json_str, rule); - } - - if (!dtype_json.is_array()) { - *rule << "Field 'dtype' must be an array."; - return g_dtype_rule_cache.GetWithDefault(json_str, rule); - } - - for (const auto &dtype : dtype_json) { - if (dtype.is_null()) { - *rule << "Element in 'dtype' field must not be null."; - return g_dtype_rule_cache.GetWithDefault(json_str, rule); - } - if (!dtype.is_number_integer()) { - *rule << "Element in 'dtype' field must be an integer."; - return g_dtype_rule_cache.GetWithDefault(json_str, rule); - } - const int32_t dtype_value = dtype.get(); - if (dtype_value >= ge::DataType::DT_MAX || dtype_value < 0 || dtype_value == ge::DataType::DT_UNDEFINED) { - *rule << "Element " << dtype_value << " in 'dtype' field is out of range [0," << ge::DataType::DT_MAX - << "(DT_MAX)) and cannot be " << ge::DataType::DT_UNDEFINED << "(DT_UNDEFINED)."; - return g_dtype_rule_cache.GetWithDefault(json_str, rule); - } - rule->dtypes_.emplace_back(static_cast(dtype_value)); - } - - return g_dtype_rule_cache.GetWithDefault(json_str, rule); -} -} // namespace ge diff --git a/graph/utils/inference_rule.h b/graph/utils/inference_rule.h deleted file mode 100644 index a836ca818ee67d37414699d8f5b726901062b6ab..0000000000000000000000000000000000000000 --- a/graph/utils/inference_rule.h +++ /dev/null @@ -1,123 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd.|Hisilicon Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef REGISTER_INFERENCE_RULE_H -#define REGISTER_INFERENCE_RULE_H - -#include -#include - -#include "external/exe_graph/runtime/infer_shape_context.h" -#include "external/exe_graph/runtime/infer_datatype_context.h" -#include "graph/op_desc.h" - -namespace ge { -/** - * @brief 推导规则基类 - * - * 为了引导原始错误记录在对象上,不分散的打印日志,有助于向用户展示明确报错。 - */ -class InferenceRule { - public: - template - InferenceRule &operator<<(const T &msg) { - err_ << msg; - return *this; - } - - std::string Error() const { - return err_.str(); - } - - bool IsValid() const { - return err_.str().empty(); - } - static std::string GetInferenceRule(const ge::OpDescPtr &op); - - protected: - std::stringstream err_; -}; - -/** - * @brief Shape推导实现类 - * - * 负责从不同类型的输入编译并加载得到Shape推导可执行函数,并与GE数据结构配合工作。 - */ -class ShapeInferenceRule : public InferenceRule { - public: - // Ctx接口定义,供推导函数调用,不依赖任何头文件。实现与用户环境完全隔离。 - class Ctx { - public: - virtual ~Ctx() = default; - - virtual bool GetInputDim(int64_t input, int64_t dim_index, int64_t &dim) = 0; - - virtual bool GetInputValue(int64_t input, int64_t offset, int64_t &value) = 0; - - virtual bool SetOutputDimNum(int64_t output, int64_t dim_num) = 0; - - virtual bool SetOutputDim(int64_t output, int64_t dim_index, int64_t dim) = 0; - - virtual void SetError(const char *) = 0; - }; - - using InferShapeFunc = bool (*)(Ctx *); - - ShapeInferenceRule() : handle_(nullptr), infer_shape_(nullptr), infer_shape_on_compile_(nullptr) {} - ~ShapeInferenceRule(); - ShapeInferenceRule(const ShapeInferenceRule &) = delete; - ShapeInferenceRule &operator=(const ShapeInferenceRule &) = delete; - ShapeInferenceRule &operator=(ShapeInferenceRule &&other) = delete; - ShapeInferenceRule(ShapeInferenceRule &&other) noexcept { - handle_ = other.handle_; - infer_shape_ = other.infer_shape_; - infer_shape_on_compile_ = other.infer_shape_on_compile_; - err_ << other.err_.str(); - other.handle_ = nullptr; - other.infer_shape_ = nullptr; - other.infer_shape_on_compile_ = nullptr; - } - - static std::shared_ptr FromOpDesc(const ge::OpDescPtr &op); - static std::shared_ptr FromJsonString(const std::string &json_str); - - // 编译后的二进制以属性的方式保存在节点上,用于RT2执行时加载 - static ge::graphStatus CompileJsonString(const std::string &json_str, std::vector &binary); - static ShapeInferenceRule FromCompiledBinary(const std::vector &binary); - static ShapeInferenceRule FromCompiledBinary(const uint8_t *binary, size_t size); - - ge::graphStatus InferOnRuntime(gert::InferShapeContext *infer_shape_ctx) const; - ge::graphStatus InferOnCompile(gert::InferShapeContext *infer_shape_ctx) const; - - ge::graphStatus InferOnRuntime(Ctx *ctx) const; - ge::graphStatus InferOnCompile(Ctx *ctx) const; - - private: - void *handle_; - InferShapeFunc infer_shape_; - InferShapeFunc infer_shape_on_compile_; -}; - -/** - * @brief Dtype推导实现类 - * - * 负责从不同类型的解析得到Shape推导可执行函数,并与GE图结构配合工作。Dtype推导实现无需编译。 - */ -class DtypeInferenceRule : public InferenceRule { - public: - static std::shared_ptr FromOpDesc(const ge::OpDescPtr &op); - static std::shared_ptr FromJsonString(const std::string &json_str); - - ge::graphStatus InferDtype(gert::InferDataTypeContext *infer_dtype_ctx) const; - - private: - std::vector dtypes_; -}; -} // namespace ge -#endif diff --git a/graph/utils/multi_thread_graph_builder.cc b/graph/utils/multi_thread_graph_builder.cc deleted file mode 100644 index 522a548526e5c52ee600a3670633d0d8c955481b..0000000000000000000000000000000000000000 --- a/graph/utils/multi_thread_graph_builder.cc +++ /dev/null @@ -1,115 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/utils/multi_thread_graph_builder.h" -#include "graph/normal_graph/operator_impl.h" -#include "graph/debug/ge_util.h" - -namespace ge { -MultiThreadGraphBuilder::MultiThreadGraphBuilder(int32_t thread_num) - : thread_num_(thread_num < 1 ? 1 : thread_num) {} - -graphStatus MultiThreadGraphBuilder::GetGraphRelatedOperators(const std::vector &inputs, - std::vector &related_ops) { - std::vector vec_inputs; - for (auto &it : inputs) { - GE_CHECK_NOTNULL(it.operator_impl_); - vec_inputs.push_back(it.operator_impl_); - } - GE_CHK_GRAPH_STATUS_RET(WalkForwardOperators(vec_inputs, related_ops), - "Fail to walk all forward operators."); - return GRAPH_SUCCESS; -} - -void MultiThreadGraphBuilder::GetOutputLinkOps(const OperatorImplPtr &op_impl, - std::vector &output_op_impls) { - for (const auto &out_link : op_impl->output_links_) { - for (const auto &op_forward : out_link.second) { - output_op_impls.push_back(op_forward.GetOwner()); - } - } - auto &out_control_links = op_impl->control_output_link_; - for (const auto &out_control_link : out_control_links) { - output_op_impls.push_back(out_control_link.lock()); - } -} - -graphStatus MultiThreadGraphBuilder::WalkForwardOperators(const std::vector &vec_ops, - std::vector &related_ops) { - std::set all_impls; - std::queue> que; - que.push(vec_ops); - while (!que.empty()) { - const auto vec_tem = que.front(); - que.pop(); - for (const auto &op_impl : vec_tem) { - GE_CHECK_NOTNULL(op_impl); - if (all_impls.find(op_impl) == all_impls.cend()) { - all_impls.emplace(op_impl); - std::vector vec_op_forward{}; - GetOutputLinkOps(op_impl, vec_op_forward); - que.push(vec_op_forward); - } - } - } - - for (auto impl : all_impls) { - related_ops.emplace_back(impl); - } - return GRAPH_SUCCESS; -} - -void MultiThreadGraphBuilder::ResetOpSubgraphBuilder(const OpDescPtr &op_desc, OperatorImplPtr &op_impl) { - const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes(); - for (const auto &name_idx : subgraph_names_to_index) { - const SubgraphBuilder &builder = op_impl->GetSubgraphBuilder(name_idx.first.c_str()); - if (builder == nullptr) { - continue; - } - std::shared_future future_graph = pool_->commit([builder]() -> Graph { - return builder(); - }); - auto future_graph_ptr = std::make_shared>(future_graph); - auto graph_builder = [future_graph_ptr, builder]() mutable { - ge::Graph graph; - if (future_graph_ptr->valid()) { - graph = future_graph_ptr->get(); - // reset shared_future to release graph ownner, can not be invoked twice - *future_graph_ptr = std::shared_future(); - } else { - // use default builder - graph = builder(); - } - return graph; - }; - op_impl->SetSubgraphBuilder(name_idx.first.c_str(), name_idx.second, graph_builder); - } -} - -Graph &MultiThreadGraphBuilder::SetInputs(const std::vector &inputs, ge::Graph &graph) { - { - const std::lock_guard lock(mutex_); - if (thread_num_ > 1 && pool_ == nullptr) { - pool_ = ComGraphMakeUnique(thread_num_); - } - } - - if (pool_ != nullptr) { - GELOGI("Build subgraph async, thread num = %d.", thread_num_); - std::vector all_related_ops; - (void)GetGraphRelatedOperators(inputs, all_related_ops); - for (auto &op_impl : all_related_ops) { - if (op_impl->op_desc_ != nullptr) { - ResetOpSubgraphBuilder(op_impl->op_desc_, op_impl); - } - } - } - return graph.SetInputs(inputs); -} -} // namespace ge diff --git a/graph/utils/node_utils.cc b/graph/utils/node_utils.cc deleted file mode 100644 index d3bdae68b8cde1b865aeb4f116c57ea2a75180c5..0000000000000000000000000000000000000000 --- a/graph/utils/node_utils.cc +++ /dev/null @@ -1,1416 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/utils/node_utils.h" -#include -#include -#include "graph/utils/op_type_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/debug/ge_op_types.h" -#include "graph/debug/ge_util.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/normal_graph/node_impl.h" -#include "graph/normal_graph/op_desc_impl.h" -#include "graph/ge_context.h" -#include "graph/utils/tensor_utils.h" -#include "graph/utils/tensor_adapter.h" -#include "graph/utils/type_utils.h" -#include "graph/utils/constant_utils.h" -#include "common/checker.h" - -namespace ge { -const std::set kConstOpTypes{"Const", "Constant"}; - -const std::set kEnterOpTypes{"Enter", "RefEnter"}; -const std::set kMergeOpTypes{"Merge", "RefMerge"}; -const std::set kSwitchOpTypes{"Switch", "RefSwitch"}; -const std::set kNextIterationOpTypes{"NextIteration", "RefNextIteration"}; -const std::set kExitOpTypes{"Exit", "RefExit"}; - -const std::set kIfOpTypes{"If", "_If", "StatelessIf"}; -const std::set kWhileOpTypes{"While", "_While", "StatelessWhile"}; -const std::set kCaseOpTypes{"Case"}; -const std::set kForOpTypes{"For"}; - -const char_t *const kRefIndex = "_parent_node_index"; -const char_t *const kPartSrcGraph = "part_src_graph"; - -namespace { -constexpr int32_t kInvalidIndex = -1; -bool OpShapeIsUnknown(const OpDescPtr &desc) { - for (const auto &ptr : desc->GetAllInputsDescPtr()) { - const auto ge_shape = ptr->GetShape(); - auto dims = ge_shape.GetDims(); - if (std::any_of(dims.begin(), dims.end(), - [](const int64_t dim) { return ((dim == UNKNOWN_DIM) || (dim == (UNKNOWN_DIM_NUM))); })) { - return true; - } - } - for (const auto &ptr : desc->GetAllOutputsDescPtr()) { - const auto ge_shape = ptr->GetShape(); - auto dims = ge_shape.GetDims(); - if (std::any_of(dims.begin(), dims.end(), - [](const int64_t dim) { return ((dim == UNKNOWN_DIM) || (dim == (UNKNOWN_DIM_NUM))); })) { - return true; - } - } - return false; -} - -bool IsComputableOp(const NodePtr &node) { - if ((node->GetType() == DATA) || (node->GetType() == NETOUTPUT)) { - return false; - } - if (!node->GetOpDesc()->GetSubgraphInstanceNames().empty()) { - return false; - } - return true; -} -} // namespace - -graphStatus NodeUtils::ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor) { - GE_CHK_BOOL_EXEC((node_ptr != nullptr) && (node_ptr->impl_ != nullptr) && (in_data_anchor != nullptr), - REPORT_INNER_ERR_MSG("E18888", "param node or in_data_anchor is nullptr, check invalid."); - return GRAPH_FAILED, "[Check][Param] node or in_data_anchor is nullptr"); - bool find_flag = false; - uint32_t index = 0U; - std::vector::iterator it = node_ptr->impl_->in_data_anchors_.end(); - for (const auto &tmp : node_ptr->impl_->in_data_anchors_) { - if (tmp == in_data_anchor) { - find_flag = true; - const auto iter = node_ptr->impl_->in_data_anchors_.begin() + static_cast(index); - if (iter != node_ptr->impl_->in_data_anchors_.end()) { - it = node_ptr->impl_->in_data_anchors_.erase(iter); - } - break; - } - index++; - } - while (it != node_ptr->impl_->in_data_anchors_.end()) { - (*it)->SetIdx(static_cast(index)); - index++; - ++it; - } - - if (!find_flag) { - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::SetAllAnchorStatus(const NodePtr &node_ptr) { - GE_CHK_BOOL_EXEC(node_ptr != nullptr, REPORT_INNER_ERR_MSG("E18888", "param node_ptr is nullptr, check invalid"); - return GRAPH_FAILED, "[Check][Param] node is nullptr"); - GE_CHK_BOOL_EXEC(SetAllAnchorStatus(*node_ptr) == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "SetAllAnchorStatus failed, node:%s", node_ptr->GetName().c_str()); - return GRAPH_FAILED, "[Set][AllAnchorStatus] failed, node:%s", node_ptr->GetName().c_str()); - return GRAPH_SUCCESS; -} - -graphStatus NodeUtils::SetAllAnchorStatus(Node &node) { - if (node.impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Param node impl is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Node impl is nullptr."); - return GRAPH_FAILED; - } - node.impl_->anchor_status_updated_ = true; - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool NodeUtils::IsAnchorStatusSet(const NodePtr &node_ptr) { - GE_CHK_BOOL_EXEC(node_ptr != nullptr, REPORT_INNER_ERR_MSG("E18888", "param node_ptr is nullptr, check invalid"); - return false, "[Check][Param] node is nullptr"); - return IsAnchorStatusSet(*node_ptr); -} - -bool NodeUtils::IsAnchorStatusSet(const Node &node) { - if (node.impl_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Param node impl is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Node impl is nullptr."); - return false; - } - return node.impl_->anchor_status_updated_; -} - -graphStatus NodeUtils::MoveOutputEdges(const NodePtr &origin_node, const NodePtr &new_node) { - if ((origin_node == nullptr) || (new_node == nullptr)) { - return GRAPH_FAILED; - } - auto origin_out_data_anchors = origin_node->GetAllOutDataAnchors(); - const auto origin_out_data_anchors_size = origin_out_data_anchors.size(); - auto new_out_data_anchors = new_node->GetAllOutDataAnchors(); - if (origin_out_data_anchors_size != new_out_data_anchors.size()) { - return GRAPH_FAILED; - } - - for (size_t i = 0UL; i < origin_out_data_anchors_size; ++i) { - for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInDataAnchors()) { - GE_CHK_BOOL_EXEC( - origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "unlink peer_dataanchor failed, node:%s", origin_node->GetName().c_str()); - continue, "[Unlink][PeerAnchor] failed, node:%s", origin_node->GetName().c_str()); - GE_CHK_BOOL_EXEC( - new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "LinkTo peer_dataanchor failed, node:%s", new_node->GetName().c_str()); - continue, "[LinkTo][PeerAnchor] failed, node:%s", new_node->GetName().c_str()); - } - - for (const auto &peer_anchor : origin_out_data_anchors.at(i)->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC( - origin_out_data_anchors.at(i)->Unlink(peer_anchor) == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "unlink peer_controlanchor failed, node:%s", origin_node->GetName().c_str()); - continue, "[Unlink][PeerAnchor] failed, node:%s", origin_node->GetName().c_str()); - GE_CHK_BOOL_EXEC( - new_out_data_anchors.at(i)->LinkTo(peer_anchor) == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "LinkTo peer_controlanchor failed, node:%s", new_node->GetName().c_str()); - continue, "[LinkTo][PeerAnchor] failed, node:%s", new_node->GetName().c_str()); - } - } - - const auto origin_out_control_anchor = origin_node->GetOutControlAnchor(); - GE_CHECK_NOTNULL(origin_out_control_anchor); - const auto new_out_control_anchor = new_node->GetOutControlAnchor(); - GE_CHECK_NOTNULL(new_out_control_anchor); - for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInControlAnchors()) { - GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "linkto peer_anchor from %s to %s failed.", - new_out_control_anchor->GetOwnerNode()->GetName().c_str(), - peer_anchor->GetOwnerNode()->GetName().c_str()); - continue, "[LinkTo][PeerAnchor] from %s to %s failed", - new_out_control_anchor->GetOwnerNode()->GetName().c_str(), - peer_anchor->GetOwnerNode()->GetName().c_str()); - } - for (const auto &peer_anchor : origin_out_control_anchor->GetPeerInDataAnchors()) { - GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer_anchor) == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "linkto peer_anchor from %s to %s failed.", - new_out_control_anchor->GetOwnerNode()->GetName().c_str(), - peer_anchor->GetOwnerNode()->GetName().c_str()); - continue, "[LinkTo][PeerAnchor] from %s to %s failed", - new_out_control_anchor->GetOwnerNode()->GetName().c_str(), - peer_anchor->GetOwnerNode()->GetName().c_str()); - } - origin_out_control_anchor->UnlinkAll(); - - return GRAPH_SUCCESS; -} - -bool NodeUtils::IsConst(const Node &node) { - const auto src_node_type = node.GetType(); - const bool is_const = ((src_node_type == CONSTANT) || (src_node_type == CONSTANTOP)); - return is_const; -} - -void NodeUtils::UpdateIsInputConst(const NodePtr &node_ptr) { - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "param node_ptr is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] node is null"); - return; - } - UpdateIsInputConst(*node_ptr); -} - -/// update is_input_const -/// @param node -/// @return void -void NodeUtils::UpdateIsInputConst(Node &node) { - std::vector is_input_const; - const uint32_t anchor_num = node.GetAllInDataAnchorsSize(); - for (uint32_t i = 0UL; i < anchor_num; i++) { - const auto in_anchor = node.GetInDataAnchor(static_cast(i)); - if (in_anchor == nullptr) { - is_input_const.push_back(false); - continue; - } - const auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); - if (peer_out_anchor == nullptr) { - is_input_const.push_back(false); - continue; - } - const auto src_node = peer_out_anchor->GetOwnerNodeBarePtr(); - if (src_node == nullptr) { - is_input_const.push_back(false); - continue; - } - if (IsConst(*(src_node))) { - is_input_const.push_back(true); - } else { - is_input_const.push_back(false); - } - } - if (node.GetOpDesc() == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node has no opdesc."); - GELOGE(GRAPH_FAILED, "[Check][Param] Node get opdesc is nullptr"); - return; - } - node.GetOpDesc()->SetIsInputConst(is_input_const); -} - -void NodeUtils::UnlinkAll(const Node &node) { - for (const auto &anchor : node.GetAllOutAnchors()) { - anchor->UnlinkAll(); - } - for (const auto &anchor : node.GetAllInAnchors()) { - anchor->UnlinkAll(); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendInputAnchor(const NodePtr &node, - const uint32_t num) { - if ((node == nullptr) || (node->impl_ == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param node is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Input node is null"); - return GRAPH_FAILED; - } - - const GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT); - const auto &op_desc = node->GetOpDesc(); - for (size_t i = op_desc->GetAllInputsSize(); i < num; ++i) { - if (op_desc->AddInputDesc(data_desc) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "AddInputDesc failed, op:%s", op_desc->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Add][InputDesc] failed, op:%s", op_desc->GetName().c_str()); - return GRAPH_FAILED; - } - } - - for (size_t i = node->impl_->in_data_anchors_.size(); i < num; ++i) { - const auto anchor = ComGraphMakeShared(node, i); - if (anchor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Current in data anchor is null, make shared_ptr failed."); - GELOGE(OUT_OF_MEMORY, "[Create][InDataAnchor] Current in data anchor is null, make shared_ptr failed."); - return GRAPH_FAILED; - } - node->impl_->in_data_anchors_.push_back(anchor); - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool NodeUtils::ClearInputDesc(const OpDescPtr &op_desc, - const uint32_t index) { - GE_CHK_BOOL_EXEC((op_desc != nullptr) && (op_desc->impl_ != nullptr), - REPORT_INNER_ERR_MSG("E18888", "op_desc is nullptr, check invalid"); - return false, "[Check][Param] op_desc is nullptr"); - GE_CHK_BOOL_EXEC(index < op_desc->impl_->inputs_desc_.size(), - REPORT_INNER_ERR_MSG("E18888", "index %u is invalid, out of range(0, %zu).", - index, op_desc->impl_->inputs_desc_.size()); - return false, - "[Check][Param] index %u is invalid, out of range(0, %zu).", - index, op_desc->impl_->inputs_desc_.size()); - - const auto iter = op_desc->impl_->inputs_desc_.begin() + static_cast(index); - if (iter < op_desc->impl_->inputs_desc_.end()) { - (void)op_desc->impl_->inputs_desc_.erase(iter); - } else { - GELOGW("[Clear][InputDesc] inputs_desc_ iterator out of range."); - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool NodeUtils::ClearOutputDesc(const OpDescPtr &op_desc, - const uint32_t index) { - GE_CHK_BOOL_EXEC((op_desc != nullptr) && (op_desc->impl_ != nullptr), - REPORT_INNER_ERR_MSG("E18888", "param op_desc is nullptr, check invalid"); - return false, "[Check][Param] op_desc is nullptr"); - GE_CHK_BOOL_EXEC(index < op_desc->impl_->outputs_desc_.size(), - REPORT_INNER_ERR_MSG("E18888", "index %u is invalid. out of range(0, %zu)", - index, op_desc->impl_->outputs_desc_.size()); - return false, - "[Check][Param] index %u is invalid. out of range(0, %zu)", - index, op_desc->impl_->outputs_desc_.size()); - const auto iter = op_desc->impl_->outputs_desc_.begin() + static_cast(index); - if (iter < op_desc->impl_->outputs_desc_.end()) { - (void)op_desc->impl_->outputs_desc_.erase(iter); - } else { - GELOGW("[Clear][OutputDesc] outputs_desc_ iterator out of range."); - } - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveInputAnchor(const NodePtr &node, - const uint32_t num) { - if ((node == nullptr) || (node->impl_ == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "param node is null, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Input node is null"); - return GRAPH_FAILED; - } - - const auto &op_desc = node->GetOpDesc(); - while (op_desc->GetInputsSize() > num) { - if (!NodeUtils::ClearInputDesc(op_desc, num)) { - return GRAPH_FAILED; - } - } - - const auto input_names = op_desc->GetAllInputName(); - (void) op_desc->UpdateInputName(input_names); - auto is_input_const = op_desc->GetIsInputConst(); - is_input_const.resize(static_cast(num)); - op_desc->SetIsInputConst(is_input_const); - - while (node->impl_->in_data_anchors_.size() > num) { - node->impl_->in_data_anchors_.pop_back(); - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendOutputAnchor(const NodePtr &node, - const uint32_t num) { - if ((node == nullptr) || (node->impl_ == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "Input node is null, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Input node is null"); - return GRAPH_FAILED; - } - - const GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT); - const OpDescPtr &op_desc = node->GetOpDesc(); - for (size_t i = op_desc->GetOutputsSize(); i < num; ++i) { - if (op_desc->AddOutputDesc(data_desc) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Add output desc failed, op:%s", op_desc->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Add][OutputDesc] failed, op:%s", op_desc->GetName().c_str()); - return GRAPH_FAILED; - } - } - - for (size_t i = node->impl_->out_data_anchors_.size(); i < num; ++i) { - const auto anchor = ComGraphMakeShared(node, i); - if (anchor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Current out data anchor is null, make shared_ptr failed."); - GELOGE(OUT_OF_MEMORY, "[Create][OutDataAnchor] Current out data anchor is null, make shared_ptr failed."); - return GRAPH_FAILED; - } - node->impl_->out_data_anchors_.push_back(anchor); - } - - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveOutputAnchor(const NodePtr &node, - const uint32_t num) { - if ((node == nullptr) || (node->impl_ == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "Input node is null, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Input node is null"); - return GRAPH_FAILED; - } - - const auto &op_desc = node->GetOpDesc(); - const auto output_names = op_desc->GetAllOutputName(); - while (op_desc->GetOutputsSize() > num) { - if (!NodeUtils::ClearOutputDesc(op_desc, num)) { - return GRAPH_FAILED; - } - } - (void) op_desc->UpdateOutputName(output_names); - - while (node->impl_->out_data_anchors_.size() > num) { - node->impl_->out_data_anchors_.pop_back(); - } - - return GRAPH_SUCCESS; -} - -GeTensorDesc NodeUtils::GetOutputDesc(const Node &node, const uint32_t index) { - const auto desc = node.GetOpDesc(); - if (desc == nullptr) { - return {}; - } - return desc->GetOutputDesc(index); -} - -graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) { - const auto desc = node.GetOpDesc(); - GE_CHECK_NOTNULL(desc); - // check self - is_unknow = OpShapeIsUnknown(desc); - if (is_unknow) { - return GRAPH_SUCCESS; - } - const auto sub_graph_names = desc->GetSubgraphInstanceNames(); - if (sub_graph_names.empty()) { - return GRAPH_SUCCESS; - } else { - const auto owner_graph = node.GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(owner_graph); - // During graph splitting, get parent graph cannot be obtained in some scenarios, - // but the root graph can be set use the attribute. - ge::ComputeGraphPtr src_graph = owner_graph->TryGetExtAttr(kPartSrcGraph, ge::ComputeGraphPtr()); - if (src_graph == nullptr) { - GELOGD("src graph is null, owner graph name is %s", owner_graph->GetName().c_str()); - src_graph = owner_graph; - } - GELOGD("src graph is %s, owner graph name is %s", src_graph->GetName().c_str(), owner_graph->GetName().c_str()); - const auto root_graph = GraphUtils::FindRootGraph(src_graph); - if (root_graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node:%s has no root graph.", node.GetName().c_str()); - GE_LOGE("[Get][Graph] Node %s gets null root graph", node.GetName().c_str()); - return GRAPH_PARAM_INVALID; - } - for (auto &sub_graph_name : sub_graph_names) { - const auto sub_graph = root_graph->GetSubgraph(sub_graph_name); - if (sub_graph == nullptr) { - GELOGD("sub graph %s is empty", sub_graph_name.c_str()); - continue; - } - for (const auto &node_ptr : sub_graph->GetDirectNode()) { - const auto status = GetNodeUnknownShapeStatus(*node_ptr, is_unknow); - if (status != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "GetNodeUnknownShapeStatus failed, node:%s, status:%u", - node_ptr->GetName().c_str(), status); - GE_LOGE("[Get][NodeUnknownShapeStatus] failed! node:%s, status:%u", node_ptr->GetName().c_str(), status); - return status; - } - if (is_unknow) { - return GRAPH_SUCCESS; - } - } - } - } - return GRAPH_SUCCESS; -} - -std::string NodeUtils::GetNodeType(const Node &node) { - if (node.GetType() != FRAMEWORKOP) { - return node.GetType(); - } - - std::string type; - (void) AttrUtils::GetStr(node.GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); - return type; -} - -std::string NodeUtils::GetNodeType(const NodePtr &node) { - return (node == nullptr) ? "" : GetNodeType(*node); -} - -graphStatus NodeUtils::GetDirectSubgraphs(const NodePtr &node, std::vector &subgraphs) { - if ((node == nullptr) || (node->GetOpDesc() == nullptr)) { - REPORT_INNER_ERR_MSG("E18888", "node or op_desc is null"); - GELOGE(GRAPH_FAILED, "[Check][Param] node or op_desc is null"); - return GRAPH_FAILED; - } - - const auto &root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); - if (root_graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Failed to find root graph from node %s ", node->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][Graph] Failed to find root graph from node %s ", node->GetName().c_str()); - return GRAPH_FAILED; - } - - for (const auto &graph_name : node->GetOpDesc()->GetSubgraphInstanceNames()) { - const auto &graph = root_graph->GetSubgraph(graph_name); - if (graph == nullptr) { - GELOGW("[Get][Subgraph] subgraph %s of node %s is null", graph_name.c_str(), node->GetName().c_str()); - continue; - } - subgraphs.emplace_back(graph); - } - - return GRAPH_SUCCESS; -} - -ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, const uint32_t index) { - const auto op_desc = node.GetOpDesc(); - if (op_desc == nullptr) { - return nullptr; - } - const auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); - if (root_graph == nullptr) { - return nullptr; - } - return root_graph->GetSubgraph(op_desc->GetSubgraphInstanceName(index)); -} - -graphStatus NodeUtils::SetSubgraph(Node &node, const uint32_t index, const ComputeGraphPtr &subgraph) { - if (subgraph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Failed to set subgraph to node %s index %u, null subgraph", node.GetName().c_str(), - index); - GE_LOGE("[Check][Param] Failed to set subgraph to node %s index %u, null subgraph", node.GetName().c_str(), index); - return GRAPH_PARAM_INVALID; - } - const auto op_desc = node.GetOpDesc(); - if (op_desc == nullptr) { - return GRAPH_PARAM_INVALID; - } - const auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); - if (root_graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Failed to add subgraph to node %s, null root graph", node.GetName().c_str()); - GE_LOGE("[Get][Graph] Failed to add subgraph to node %s, null root graph", node.GetName().c_str()); - return GRAPH_PARAM_INVALID; - } - const auto ret = op_desc->SetSubgraphInstanceName(index, subgraph->GetName()); - if (ret != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Failed to set subgraph to node %s index %u", node.GetName().c_str(), index); - GE_LOGE("[Set][Name] Failed to set subgraph to node %s index %u", node.GetName().c_str(), index); - return ret; - } - subgraph->SetParentNode(node.shared_from_this()); - subgraph->SetParentGraph(node.GetOwnerComputeGraph()); - return root_graph->AddSubgraph(subgraph); -} -graphStatus NodeUtils::AddSubgraph(Node &node, const std::string &subgraph_ir_name, const ComputeGraphPtr &subgraph) { - GE_ASSERT_NOTNULL(subgraph); - auto op_desc = node.GetOpDesc(); - GE_ASSERT_NOTNULL(op_desc); - - // Will report inner warning if the Op is created using REG_OP format - // because during REG_OP it has already registered subgraph IR name - (void) op_desc->AddSubgraphName(subgraph_ir_name); - auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes(); - const auto &iter = subgraph_names_to_index.find(subgraph_ir_name); - GE_ASSERT_TRUE(iter != subgraph_names_to_index.cend()); - - return SetSubgraph(node, iter->second, subgraph); -} -graphStatus NodeUtils::AddSubgraph(const NodePtr &node_ptr, const std::string &subgraph_ir_name, - const ComputeGraphPtr &subgraph) { - GE_ASSERT_NOTNULL(node_ptr); - auto op_desc = node_ptr->GetOpDesc(); - GE_ASSERT_NOTNULL(op_desc); - auto subgraph_ir_type = op_desc->GetSubgraphTypeByIrName(subgraph_ir_name); - if (subgraph_ir_type == kSubgraphTypeEnd) { - op_desc->RegisterSubgraphIrName(subgraph_ir_name, kStatic); - } else { - GE_ASSERT_EQ(kStatic, subgraph_ir_type); - } - auto &node = *node_ptr.get(); - return AddSubgraph(node, subgraph_ir_name, subgraph); -} -graphStatus NodeUtils::AddSubgraphs(const NodePtr &node_ptr, const std::string &subgraph_ir_name, - const std::vector &subgraphs) { - GE_ASSERT_NOTNULL(node_ptr); - auto op_desc = node_ptr->GetOpDesc(); - GE_ASSERT_NOTNULL(op_desc); - auto subgraph_ir_type = op_desc->GetSubgraphTypeByIrName(subgraph_ir_name); - if (subgraph_ir_type == kSubgraphTypeEnd) { - op_desc->RegisterSubgraphIrName(subgraph_ir_name, kDynamic); - } else { - GE_ASSERT_EQ(kDynamic, subgraph_ir_type); - } - auto &node = *node_ptr.get(); - for (int64_t i = 0U; i < static_cast(subgraphs.size()); ++i) { - const auto& subgraph = subgraphs[i]; - GE_ASSERT_SUCCESS(AddSubgraph(node, GenDynamicSubgraphName(subgraph_ir_name, i), subgraph)); - } - return GRAPH_SUCCESS; -} -std::string NodeUtils::GenDynamicSubgraphName(const std::string &subgraph_ir_name, int64_t index) { - return subgraph_ir_name + std::to_string(index); -} - -/// Check if node is input of subgraph -/// @param [in] node -/// @return bool -bool NodeUtils::IsSubgraphInput(const NodePtr &node) { - return IsSubgraphInput(node.get()); -} - -bool NodeUtils::IsSubgraphInput(const Node *const node) { - if ((node == nullptr) || (node->GetOpDescBarePtr() == nullptr) || - (node->GetOwnerComputeGraphBarePtr()->GetParentNodeBarePtr() == nullptr)) { - return false; - } - - const auto parent_op_desc = node->GetOwnerComputeGraphBarePtr()->GetParentNodeBarePtr()->GetOpDescBarePtr(); - if (parent_op_desc == nullptr) { - return false; - } - - // dynamic shape unknown graph false - // dynamic shape known graph with functional subgraph maybe true - bool is_forced_unknown = false; - if (AttrUtils::GetBool(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_forced_unknown) && is_forced_unknown) { - if (node->GetOwnerComputeGraphBarePtr()->GetParentGraphBarePtr()->GetGraphUnknownFlag()) { - return false; - } else { - if (node->GetOwnerComputeGraphBarePtr()->GetParentNodeBarePtr()->GetOwnerComputeGraphBarePtr() - ->GetParentNodeBarePtr() == nullptr) { - return false; - } - } - } - - return node->GetOpDescBarePtr()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX); -} - -/// Check if node is output of subgraph -/// @param [in] node -/// @return bool -bool NodeUtils::IsSubgraphOutput(const NodePtr &node) { - if ((node == nullptr) || (node->GetOpDesc() == nullptr) || - (node->GetOwnerComputeGraph()->GetParentNode() == nullptr) || (node->GetType() != NETOUTPUT)) { - return false; - } - - const auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc(); - if (parent_op_desc == nullptr) { - return false; - } - - bool is_forced_unknown = false; - if (AttrUtils::GetBool(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_forced_unknown) && is_forced_unknown) { - if (node->GetOwnerComputeGraph()->GetParentGraph()->GetGraphUnknownFlag()) { - return false; - } else { - if (node->GetOwnerComputeGraph()->GetParentNode()->GetOwnerComputeGraph()->GetParentNode() == nullptr) { - return false; - } - } - } - - for (const auto &tensor : node->GetOpDesc()->GetAllInputsDescPtr()) { - if (AttrUtils::HasAttr(tensor, ATTR_NAME_PARENT_NODE_INDEX)) { - return true; - } - } - - return false; -} - -/// @brief Get subgraph original input node. -/// @param [in] node -/// @return Node -NodePtr NodeUtils::GetParentInput(const Node &node) { - uint32_t parent_index = 0U; - if (!AttrUtils::GetInt(node.GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { - return nullptr; - } - - // Subgraph Data Node, check for constant input. - const ComputeGraphPtr &graph = node.GetOwnerComputeGraph(); - GE_CHECK_NOTNULL_EXEC(graph, return nullptr); - - const NodePtr &parent_node = graph->GetParentNode(); - if (parent_node == nullptr) { - GELOGW("Node {%s %s} has attr %s but has no parent node.", - node.GetNamePtr(), - node.GetTypePtr(), - ATTR_NAME_PARENT_NODE_INDEX.c_str()); - return nullptr; - } - - const InDataAnchorPtr &in_anchor = parent_node->GetInDataAnchor(static_cast(parent_index)); - GE_CHECK_NOTNULL_EXEC(in_anchor, return nullptr); - - const OutDataAnchorPtr &peer_out_anchor = in_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL_EXEC(peer_out_anchor, return nullptr); - - auto peer_node = peer_out_anchor->GetOwnerNode(); - if (peer_node->GetType() == DATA) { - if (peer_node->GetOpDesc() == nullptr) { - return nullptr; - } - if (peer_node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX)) { - return GetParentInput(peer_node); - } - } - return peer_node; -} - -NodePtr NodeUtils::GetParentInput(const NodePtr &node) { - return (node == nullptr) ? node : GetParentInput(*node); -} -NodeToOutAnchor NodeUtils::GetParentInputAndAnchor(const NodePtr &node) { - uint32_t parent_index = 0U; - if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { - return {nullptr, nullptr}; - } - - // Subgraph Data Node, check for constant input. - const ComputeGraphPtr &graph = node->GetOwnerComputeGraph(); - if (graph == nullptr) { - return {nullptr, nullptr}; - } - - const NodePtr &parent_node = graph->GetParentNode(); - if (parent_node == nullptr) { - return {nullptr, nullptr}; - } - - const InDataAnchorPtr &in_anchor = parent_node->GetInDataAnchor(static_cast(parent_index)); - if (in_anchor == nullptr) { - return {nullptr, nullptr}; - } - - const OutDataAnchorPtr &peer_out_anchor = in_anchor->GetPeerOutAnchor(); - if (peer_out_anchor == nullptr) { - return {nullptr, nullptr}; - } - - return std::make_pair(peer_out_anchor->GetOwnerNode(), peer_out_anchor); -} - -NodeToOutAnchor NodeUtils::GetParentInputAndAnchorCrossSubgraph(const NodePtr &node) { - NodeToOutAnchor node_to_out_anchor = {nullptr, nullptr}; - std::stack s; - s.push(node); - while (!s.empty()) { - auto n = s.top(); - s.pop(); - node_to_out_anchor = GetParentInputAndAnchor(n); - auto peer_node = node_to_out_anchor.first; - if ((peer_node == nullptr) || (peer_node->GetType() != DATA)) { - continue; - } - - if ((peer_node->GetOpDesc() != nullptr) && peer_node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX)) { - s.push(peer_node); - } - } - return node_to_out_anchor; -} - -/// @brief Get is dynamic shape graph from node. -/// @param [in] node -/// @return bool -bool NodeUtils::IsDynamicShape(const Node &node) { - const auto graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); - if (graph == nullptr) { - return false; - } - - bool is_dynamic_shape = false; - (void) AttrUtils::GetBool(graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dynamic_shape); - return is_dynamic_shape; -} - -bool NodeUtils::IsDynamicShape(const NodePtr &node) { - return (node == nullptr) ? false : IsDynamicShape(*node); -} - -/// @brief Check is varying_input for while node -/// @param [in] node: Data node for subgraph -/// @return bool -bool NodeUtils::IsWhileVaryingInput(const ge::NodePtr &node) { - if (node == nullptr) { - return false; - } - if (node->GetType() != DATA) { - return false; // not input_node for subgraph - } - - const NodePtr &parent_node = node->GetOwnerComputeGraph()->GetParentNode(); - if (parent_node == nullptr) { - return false; // root graph - } - - if (kWhileOpTypes.count(parent_node->GetType()) == 0U) { - return false; // not input_node for while subgraph - } - - uint32_t index_i = 0U; - if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index_i)) { - GELOGW("[Check][Attr] Node %s has no attr PARENT_NODE_INDEX.", node->GetName().c_str()); - return false; - } - bool varying_flag = true; - for (const auto &item : node->GetOutDataNodesAndAnchors()) { - if (item.first->GetType() != NETOUTPUT) { - continue; - } - const OpDescPtr op_desc = item.first->GetOpDesc(); - uint32_t index_o = 0U; - if ((op_desc == nullptr) || - (!AttrUtils::GetInt(op_desc->GetInputDesc(static_cast(item.second->GetIdx())), - ATTR_NAME_PARENT_NODE_INDEX, index_o))) { - continue; // input for while-cond subgraph - } - if (index_i != index_o) { - continue; // varying input for while-body subgraph - } - varying_flag = false; - break; - } - return varying_flag; -} - -/// @brief Get subgraph input is constant. -/// @param [in] node -/// @param [out] string -/// @return bool -bool NodeUtils::GetConstOpType(const NodePtr &node, std::string &type) { - if (node == nullptr) { - return false; - } - - const auto node_type = node->GetType(); - if ((node_type == CONSTANT) || (node_type == CONSTANTOP) || (node_type == FILECONSTANT)) { - type = node->GetType(); - return true; - } - - if (node_type != DATA) { - return false; // not subgraph input node - } - - const auto &parent = GetParentInput(node); - return GetConstOpType(parent, type); -} - -/// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph. -/// @param [in] node -/// @return return GRAPH_SUCCESS if remove successfully, other for failed. -graphStatus NodeUtils::RemoveSubgraphsOnNode(const NodePtr &node) { - GE_CHECK_NOTNULL(node); - const auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - const auto subgraph_names = op_desc->GetSubgraphInstanceNames(); - if (subgraph_names.empty()) { - return GRAPH_SUCCESS; - } else { - const auto owner_graph = node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(owner_graph); - const auto root_graph = GraphUtils::FindRootGraph(owner_graph); - GE_CHECK_NOTNULL(root_graph); - - std::set subgraph_to_remove; - for (auto &subgraph_name : subgraph_names) { - std::deque queue; - queue.push_back(subgraph_name); - (void) subgraph_to_remove.insert(subgraph_name); - op_desc->RemoveSubgraphInstanceName(subgraph_name); - while (!queue.empty()) { - const auto graph_name = queue.front(); - queue.pop_front(); - - const auto subgraph = root_graph->GetSubgraph(graph_name); - GE_CHECK_NOTNULL(subgraph); - for (const auto &sub_node : subgraph->GetDirectNode()) { - const auto sub_op_desc = sub_node->GetOpDesc(); - GE_CHECK_NOTNULL(sub_op_desc); - const auto sub_names = sub_op_desc->GetSubgraphInstanceNames(); - // Subgraph and all nodes in it will be removed later, - // no need to remove 'SubgraphInstanceName' in op desc here. - for (auto &name : sub_names) { - if (subgraph_to_remove.insert(name).second) { - queue.push_back(name); - } - } - } - } - } - // Remove subgraph from root_graph - for (const auto &name : subgraph_to_remove) { - GELOGI("Remove subgraph:%s.", name.c_str()); - root_graph->RemoveSubgraph(name); - } - } - - return GRAPH_SUCCESS; -} - -std::vector NodeUtils::GetSubgraphDataNodesByIndex(const Node &node, const int32_t index) { - std::vector in_data_node_vec; - const auto op_desc = node.GetOpDescBarePtr(); - GE_CHECK_NOTNULL_EXEC(op_desc, return in_data_node_vec); - const auto subgraph_names = op_desc->GetSubgraphInstanceNames(); - if (subgraph_names.empty()) { - return in_data_node_vec; - } - const auto compute_graph = FindRootGraph(node); - for (const std::string &instance_name : subgraph_names) { - const auto subgraph = compute_graph->GetSubgraph(instance_name); - if (subgraph == nullptr) { - continue; - } - for (const auto &node_in_subgraph : subgraph->GetDirectNode()) { - if (IsTypeEqual(node_in_subgraph, DATA)) { - int32_t parent_index = -1; - (void) AttrUtils::GetInt(node_in_subgraph->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index); - if (parent_index == index) { - in_data_node_vec.emplace_back(node_in_subgraph); - break; - } - } - } - } - return in_data_node_vec; -} -/// @brief Get subgraph input data node by index. -/// @param [in] node -/// @return Node -std::vector NodeUtils::GetSubgraphOutputNodes(const Node &node) { - std::vector out_data_node_vec; - const auto op_desc = node.GetOpDesc(); - GE_CHECK_NOTNULL_EXEC(op_desc, return out_data_node_vec); - const auto subgraph_names = op_desc->GetSubgraphInstanceNames(); - if (subgraph_names.empty()) { - GELOGI("Node %s is single node without sub graph.", node.GetName().c_str()); - return out_data_node_vec; - } - const auto compute_graph = FindRootGraph(node); - for (const std::string &instance_name : subgraph_names) { - const auto subgraph = compute_graph->GetSubgraph(instance_name); - if (subgraph == nullptr) { - continue; - } - out_data_node_vec.emplace_back(subgraph->GetOrUpdateNetOutputNode()); - } - return out_data_node_vec; -} - -NodePtr NodeUtils::GetInDataNodeByIndex(const Node &node, const int32_t index) { - if (node.GetInDataAnchor(index) == nullptr) { - return nullptr; - } - if (node.GetInDataAnchor(index)->GetPeerOutAnchor() == nullptr) { - return nullptr; - } - return node.GetInDataAnchor(index)->GetPeerOutAnchor()->GetOwnerNode(); -} - -std::vector> NodeUtils::GetOutDataNodesWithAnchorByIndex(const Node &node, - const int32_t index) { - std::vector> out_data_nodes; - const auto out_data_anchor = node.GetOutDataAnchor(index); - if (out_data_anchor == nullptr) { - return out_data_nodes; - } - - for (const auto &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { - if (peer_in_anchor == nullptr) { - continue; - } - if (peer_in_anchor->GetOwnerNodeBarePtr() == nullptr) { - continue; - } - out_data_nodes.emplace_back(peer_in_anchor, peer_in_anchor->GetOwnerNode()); - } - return out_data_nodes; -} - -std::string NodeUtils::GetInConstNodeTypeCrossSubgraph(const NodePtr &node) { - const NodePtr input_node = GetInNodeCrossSubgraph(node); - if (input_node == nullptr) { - return ""; - } - - return input_node->GetType(); -} - -NodePtr NodeUtils::GetInNodeCrossSubgraph(const NodePtr &node) { - NodePtr input_node = node; - while (input_node != nullptr) { - if (input_node->GetType() != DATA) { - return input_node; - } - - const auto owner_graph = input_node->GetOwnerComputeGraph(); - const auto parent_node = owner_graph->GetParentNode(); - if ((parent_node == nullptr) || (kWhileOpTypes.count(parent_node->GetType()) > 0UL)) { - return input_node; // not in subgraph or while subgraph. - } - - input_node = GetParentInput(input_node); - } - - return input_node; -} - -NodePtr NodeUtils::CreatNodeWithoutGraph(const OpDescPtr op_desc) { - if (op_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "The OpDesc ptr should not be null."); - GELOGE(GRAPH_FAILED, "[Check][Param] The OpDesc ptr should not be null."); - return nullptr; - } - auto node_ptr = shared_ptr(new (std::nothrow) Node(op_desc, nullptr)); - if (node_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "create node failed."); - GELOGE(GRAPH_FAILED, "[Create][Node] node_ptr is NULL!"); - return nullptr; - } - return node_ptr; -} - -graphStatus NodeUtils::GetInNodeCrossPartionedCallNode(const NodePtr &node, uint32_t index, NodePtr &peer_node) { - int32_t peer_out_anchor_index = kInvalidIndex; - return GetInNodeCrossPartionedCallNode(node, index, peer_node, peer_out_anchor_index); -} - -graphStatus NodeUtils::GetInNodeCrossPartionedCallNode(const NodePtr &node, uint32_t index, NodePtr &peer_node, - int32_t &peer_out_anchor_index) { - GE_CHECK_NOTNULL(node); - peer_out_anchor_index = kInvalidIndex; - if ((node->GetAllInDataAnchorsSize() <= index) && (node->GetType() != DATA)) { - return GRAPH_FAILED; - } - GELOGD("in node:%s index:%d", node->GetName().c_str(), index); - peer_node = (node->GetType() == DATA) ? node : GetInDataNodeByIndex(*node, static_cast(index)); - if (peer_node == nullptr) { - // A->B - // Asuming A and B belongs to different engine, during graph partition, A will be set to B's extra attr as - // parent node. when FE get parent node A from B, check A's in_anchor peer_out_anchor is null. - return GRAPH_SUCCESS; - } - - if (node->GetType() != DATA) { - const auto in_anchor = node->GetInDataAnchor(static_cast(index)); - GE_CHECK_NOTNULL(in_anchor); - const auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_out_anchor); - peer_out_anchor_index = peer_out_anchor->GetIdx(); - } - while (!IsComputableOp(peer_node)) { - if (peer_node->GetType() == DATA) { - const auto parent_node_2_anchor = GetParentInputAndAnchor(peer_node); - if ((parent_node_2_anchor.first == nullptr) || (parent_node_2_anchor.second == nullptr)) { - GELOGW("Returned peer_out_node is nullptr because no attr[%s] on DATA[%s] node!", kRefIndex, - peer_node->GetName().c_str()); - peer_node = nullptr; - return GRAPH_SUCCESS; - } - peer_node = parent_node_2_anchor.first; - peer_out_anchor_index = parent_node_2_anchor.second->GetIdx(); - continue; - } - - if (peer_node->GetType() != PARTITIONEDCALL) { - if (peer_node->GetOpDesc()->GetSubgraphInstanceNames().empty()) { - GELOGI("Node [%s] type [%s], real peer in node [%s] type[%s].", node->GetName().c_str(), - node->GetType().c_str(), peer_node->GetName().c_str(), peer_node->GetType().c_str()); - return GRAPH_SUCCESS; - } - // other subgraph(if,while,case) currently not support, return node and warn - GELOGW("Node [%s] type [%s], real peer in node [%s] type[%s] has subgraph. Current not support.", - node->GetName().c_str(), node->GetType().c_str(), peer_node->GetName().c_str(), - peer_node->GetType().c_str()); - - return GRAPH_SUCCESS; - } - // if peer node is PartionedCall, return owner graph's correspond node - const auto sub_graph = GetSubgraph(*peer_node, 0U); - if (sub_graph == nullptr) { - GELOGW("SubGraph of node %s index 0 is null. Null is invalid.", peer_node->GetName().c_str()); - return ge::PARAM_INVALID; - } - const auto sub_graph_netoutput = sub_graph->GetOrUpdateNetOutputNode(); - GE_CHECK_NOTNULL(sub_graph_netoutput); - - for (const auto &in_data_anchor : sub_graph_netoutput->GetAllInDataAnchors()) { - const auto in_desc = - sub_graph_netoutput->GetOpDesc()->MutableInputDesc(static_cast(in_data_anchor->GetIdx())); - GE_CHECK_NOTNULL(in_desc); - int32_t ref_o = 0; - if (!AttrUtils::GetInt(in_desc, kRefIndex, ref_o)) { - return GRAPH_FAILED; - } - if (peer_out_anchor_index != ref_o) { - continue; - } - peer_node = NodeUtils::GetInDataNodeByIndex(*sub_graph_netoutput, in_data_anchor->GetIdx()); - GE_CHECK_NOTNULL(peer_node); - GE_CHECK_NOTNULL(in_data_anchor->GetPeerOutAnchor()); - peer_out_anchor_index = in_data_anchor->GetPeerOutAnchor()->GetIdx(); - GELOGD("in node[%s] peer_node[%s] type[%s] out anchor index[%d].", node->GetName().c_str(), - peer_node->GetName().c_str(), peer_node->GetType().c_str(), peer_out_anchor_index); - break; - } - } - return GRAPH_SUCCESS; -} - -graphStatus NodeUtils::SetNodeParallelGroup(Node &node, const char_t *const group_name) { - if (group_name == nullptr) { - GE_LOGE("[Check][Parameter]Get nullptr when set parallel group on node:%s", node.GetName().c_str()); - REPORT_INNER_ERR_MSG("E18888", "Get nullptr when set parallel group on node:%s", node.GetName().c_str()); - return GRAPH_FAILED; - } - std::string current_group; - const std::string new_group(group_name); - if (AttrUtils::GetStr(node.GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, current_group)) { - if (new_group != current_group) { - GE_LOGE("[Compare][Attr]Failed to set parallel group name %s on node %s, group conflict with existing %s", - new_group.c_str(), node.GetName().c_str(), group_name); - REPORT_INNER_ERR_MSG("E18888", "Failed to set parallel group name %s on node %s, group conflict with existing %s", - new_group.c_str(), node.GetName().c_str(), group_name); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; - } - if (!AttrUtils::SetStr(node.GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, new_group)) { - GE_LOGE("[Set][Attr] Failed to set parallel group name %s on node %s", group_name, node.GetName().c_str()); - REPORT_INNER_ERR_MSG("E18888", "Failed to set parallel group name %s on node %s", group_name, node.GetName().c_str()); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -graphStatus NodeUtils::UpdateInputOriginalShapeAndShape(const Node &node, const uint32_t index, const GeShape &shape) { - const auto desc = node.GetOpDesc(); - if (desc == nullptr) { - return GRAPH_PARAM_INVALID; - } - const auto input_desc = desc->MutableInputDesc(index); - if (input_desc == nullptr) { - return GRAPH_PARAM_INVALID; - } - input_desc->SetShape(shape); - input_desc->SetOriginShape(shape); - return GRAPH_SUCCESS; -} - -graphStatus NodeUtils::UpdateOutputOriginalShapeAndShape(const Node &node, const uint32_t index, const GeShape &shape) { - const auto desc = node.GetOpDesc(); - if (desc == nullptr) { - return GRAPH_PARAM_INVALID; - } - const auto output_desc = desc->MutableOutputDesc(index); - if (output_desc == nullptr) { - return GRAPH_PARAM_INVALID; - } - output_desc->SetShape(shape); - output_desc->SetOriginShape(shape); - return GRAPH_SUCCESS; -} -std::pair NodeUtils::GetInDataNodeAndAnchorByIndex(const Node &node, const int32_t index) { - const auto dst_anchor = node.GetInDataAnchor(index); - if (dst_anchor == nullptr) { - GE_LOGE("Failed to get in data anchor from index %d for node %s", index, node.GetName().c_str()); - return {nullptr, nullptr}; - } - auto src_anchor = dst_anchor->GetPeerOutAnchor(); - if (src_anchor == nullptr) { - GE_LOGE("Failed to get peer out data anchor from index %i for node %s", index, node.GetName().c_str()); - return {nullptr, nullptr}; - } - auto src_node = src_anchor->GetOwnerNode(); - if (src_node == nullptr) { - GE_LOGE("Failed to get in data node from index %d for node %s", index, node.GetName().c_str()); - return {nullptr, nullptr}; - } - return {src_node, src_anchor}; -} - -bool NodeUtils::IsDtResourceNode(const NodePtr &node) { - for (const auto &in_desc : node->GetOpDesc()->GetAllInputsDescPtr()) { - if (in_desc->GetDataType() == DT_RESOURCE) { - return true; - } - } - for (const auto &out_desc : node->GetOpDesc()->GetAllOutputsDescPtr()) { - if (out_desc->GetDataType() == DT_RESOURCE) { - return true; - } - } - return false; -} - -bool NodeUtils::IsLikeAtomicClean(const NodePtr &node) { - const auto node_type = NodeUtils::GetNodeType(node); - return (node_type == ATOMICADDRCLEAN) || (node_type == MEMSET); -} - -bool NodeUtils::IsIdentityUsefulForRWControl(const NodePtr &node_ptr) { - GE_ASSERT_NOTNULL(node_ptr); - if (!(OpTypeUtils::IsIdentityLikeNode(node_ptr->GetType()))) { - return false; - } - Node &node = *(node_ptr.get()); - if (node.GetOutControlNodesSize() == 0U) { - return false; - } - if (node.GetInDataNodesSize() != 1U) { - return false; - } - if (node.GetOutDataNodesSize() == 0U) { - return false; - } - const auto out_data_node = node.GetOutDataNodes().at(0U); - GE_ASSERT_NOTNULL(node.GetInDataAnchor(0U)); - const auto &in_node_out_data_anchor = node.GetInDataAnchor(0U)->GetPeerOutAnchor(); - if (in_node_out_data_anchor == nullptr) { - return false; - } - const auto in_node_ptr = in_node_out_data_anchor->GetOwnerNodeBarePtr(); // in_node_ptr must not be null - for (const auto out_control_node_in_control_anchor : node.GetOutControlAnchor()->GetPeerInControlAnchorsPtr()) { - const auto out_control_node = - out_control_node_in_control_anchor->GetOwnerNodeBarePtr(); // out_control node must not be null - for (const auto out_control_node_in_data_anchor : out_control_node->GetAllInDataAnchorsPtr()) { - // out_control_node_in_data_anchor must not be null - // out_control_node_in_data_anchor->GetOwnerNodeBarePtr() must not be null - if (in_node_out_data_anchor->IsLinkedWith(out_control_node_in_data_anchor->shared_from_this())) { - if ((OpTypeUtils::IsVarLikeNode(in_node_ptr->GetType())) && - (OpTypeUtils::IsAssignLikeNode(out_control_node->GetType()))) { - GELOGD("Node[%s %s] is useful for control relation, keep this node to ensure out data node[%s %s] read " - "in_data_node [%s %s] firstly, then out control node [%s %s] write in_data_node", - node.GetName().c_str(), node.GetType().c_str(), out_data_node->GetName().c_str(), - out_data_node->GetType().c_str(), in_node_ptr->GetName().c_str(), in_node_ptr->GetType().c_str(), - out_control_node->GetName().c_str(), out_control_node->GetType().c_str()); - return true; - } - } - } - } - return false; -} - -ComputeGraphPtr NodeUtils::FindRootGraph(const Node &node) { - return GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); -} - -std::vector NodeUtils::GetOutControlNodes(const Node &node, const NodeFilter &node_filter) { - std::vector out_ctrl_nodes; - const auto &out_control = node.GetOutControlAnchor(); - if (out_control == nullptr) { - return out_ctrl_nodes; - } - out_ctrl_nodes.reserve(node.GetOutControlNodesSize()); - for (const auto &in_anchor : out_control->GetPeerAnchorsPtr()) { - const auto &peer_node = in_anchor->GetOwnerNode(); - if ((node_filter == nullptr) || node_filter(*peer_node)) { - out_ctrl_nodes.push_back(peer_node); - } - } - return out_ctrl_nodes; -} - -std::vector NodeUtils::GetOutDataNodes(const Node &node, const NodeFilter &node_filter) { - std::vector out_data_nodes; - for (const auto &out_anchor : node.impl_->out_data_anchors_) { - GE_ASSERT_NOTNULL(out_anchor); - for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { - GE_ASSERT_NOTNULL(in_anchor); - const auto out_data_node = in_anchor->GetOwnerNode(); - if ((node_filter == nullptr) || node_filter(*out_data_node)) { - out_data_nodes.push_back(out_data_node); - } - } - } - return out_data_nodes; -} - - -std::vector NodeUtils::GetInControlNodes(const Node &node, const NodeFilter &node_filter) { - std::vector in_ctrl_nodes; - const auto &in_control = node.GetInControlAnchor(); - if (in_control == nullptr) { - return in_ctrl_nodes; - } - in_ctrl_nodes.reserve(node.GetInControlNodesSize()); - for (const auto out_anchor : in_control->GetPeerAnchorsPtr()) { - const auto &peer_node = out_anchor->GetOwnerNode(); - if ((node_filter == nullptr) || node_filter(*peer_node)) { - in_ctrl_nodes.push_back(peer_node); - } - } - return in_ctrl_nodes; -} - -std::vector NodeUtils::GetInDataNodes(const Node &node, const NodeFilter &node_filter) { - std::vector in_data_nodes; - in_data_nodes.reserve(node.GetInDataNodesSize()); - for (const auto &in_anchor : node.impl_->in_data_anchors_) { - GE_ASSERT_NOTNULL(in_anchor); - const auto anchor_ptr = in_anchor->GetPeerOutAnchor(); - if (anchor_ptr == nullptr) { - continue; - } - const auto in_node = anchor_ptr->GetOwnerNode(); - if ((node_filter == nullptr) || node_filter(*in_node)) { - in_data_nodes.push_back(in_node); - } - } - return in_data_nodes; -} - -graphStatus NodeUtils::TryGetWeightByPlaceHolderNode(const NodePtr &node_ptr, ConstGeTensorPtr &ge_tensor) { - if (ge_tensor != nullptr) { - GELOGE(GRAPH_PARAM_INVALID, "ge_tensor already has value"); - return GRAPH_PARAM_INVALID; - } - if (node_ptr->GetType() != PLACEHOLDER) { - return GRAPH_SUCCESS; - } - const auto &op_desc = node_ptr->GetOpDesc(); - GE_ASSERT_NOTNULL(op_desc); - // In some case, Placeholder operator may has it's peer const node's weight - if (ConstantUtils::GetWeight(op_desc, 0U, ge_tensor)) { - GELOGI("op [%s %s] has direct weight attr", op_desc->GetType().c_str(), op_desc->GetName().c_str()); - return GRAPH_SUCCESS; - } - NodePtr parent_node = nullptr; - parent_node = op_desc->TryGetExtAttr("parentNode", parent_node); - if (parent_node == nullptr) { - GELOGI("op [%s %s] get not any ext node attr", op_desc->GetType().c_str(), op_desc->GetName().c_str()); - return GRAPH_SUCCESS; - } - const auto &parent_op_desc = parent_node->GetOpDesc(); - GE_CHECK_NOTNULL(parent_op_desc); - if (ConstantUtils::IsConstant(parent_op_desc)) { - if (ConstantUtils::GetWeight(parent_op_desc, 0U, ge_tensor)) { - GELOGI("op [%s %s] has indirect weight attr from other op [%s %s]", op_desc->GetType().c_str(), - op_desc->GetName().c_str(), parent_op_desc->GetType().c_str(), parent_op_desc->GetName().c_str()); - return GRAPH_SUCCESS; - } - } - if (parent_op_desc->GetType() == DATA) { - return TryGetWeightByDataNode(parent_node, ge_tensor); - } - GELOGI("op [%s %s] get not any weight attr", op_desc->GetType().c_str(), op_desc->GetName().c_str()); - return GRAPH_SUCCESS; -} - -graphStatus NodeUtils::TryGetWeightByDataNode(const NodePtr &node_ptr, ConstGeTensorPtr &ge_tensor) { - if (ge_tensor != nullptr) { - GELOGE(GRAPH_PARAM_INVALID, "ge_tensor already has value"); - return GRAPH_PARAM_INVALID; - } - if (node_ptr->GetType() != DATA) { - return GRAPH_SUCCESS; - } - const auto &op_desc = node_ptr->GetOpDesc(); - GE_ASSERT_NOTNULL(op_desc); - // the input const data should not be obtained, - // as the input will change during multiple rounds of infershape for while - if ((node_ptr->GetOwnerComputeGraphBarePtr() != nullptr) && - (node_ptr->GetOwnerComputeGraphBarePtr()->GetParentNodeBarePtr() != nullptr) && - (kWhileOpTypes.count(node_ptr->GetOwnerComputeGraphBarePtr()->GetParentNodeBarePtr()->GetType()) > 0U)) { - GELOGI("The value of a const node should not be obtained, when the const node is outside a while node, " - "while node name: %s", - node_ptr->GetOwnerComputeGraphBarePtr()->GetParentNodeBarePtr()->GetName().c_str()); - return GRAPH_SUCCESS; - } - NodePtr real_parent_node = nullptr; - (void) NodeUtils::GetInNodeCrossPartionedCallNode(node_ptr, 0U, real_parent_node); - if ((real_parent_node != nullptr) && (ConstantUtils::IsConstant(real_parent_node->GetOpDesc()))) { - GELOGI("Get in really parent node:[%s %s] for node:[%s %s]", real_parent_node->GetName().c_str(), - real_parent_node->GetType().c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str()); - if (ConstantUtils::IsConstant(real_parent_node)) { - if (ConstantUtils::GetWeight(real_parent_node->GetOpDesc(), 0U, ge_tensor)) { - GELOGI("op [%s %s] has indirect weight attr from other op [%s %s]", op_desc->GetType().c_str(), - op_desc->GetName().c_str(), real_parent_node->GetType().c_str(), real_parent_node->GetName().c_str()); - return GRAPH_SUCCESS; - } - } - } - GELOGI("op [%s %s] get not any weight attr", op_desc->GetType().c_str(), op_desc->GetName().c_str()); - return GRAPH_SUCCESS; -} -bool NodeUtils::IsNameEqual(const NodePtr &node, const ge::char_t *const name) { - return strcmp(node->GetNamePtr(), name) == 0; -} -bool NodeUtils::IsTypeEqual(const NodePtr &node, const ge::char_t *const type) { - return strcmp(node->GetTypePtr(), type) == 0; -} - -NodePtr NodeUtils::GetNodeWithMinimalId(const std::vector &nodes) { - NodePtr min_id_node = nullptr; - int64_t min_id = -1; - for (const auto &node : nodes) { - const auto op_desc = node->GetOpDesc(); - GE_ASSERT_NOTNULL(op_desc); - const auto id = op_desc->GetId(); - if ((min_id == -1) || (id < min_id)) { - min_id = id; - min_id_node = node; - } - } - return min_id_node; -} -} // namespace ge diff --git a/graph/utils/node_utils_ex.cc b/graph/utils/node_utils_ex.cc deleted file mode 100644 index a4183d87e69f56c90df328ed94ef73688730589a..0000000000000000000000000000000000000000 --- a/graph/utils/node_utils_ex.cc +++ /dev/null @@ -1,171 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/utils/node_utils_ex.h" -#include "common/ge_common/util.h" -#include "common/util/trace_manager/trace_manager.h" -#include "graph/refiner/format_refiner.h" -#include "graph/shape_refiner.h" -#include "graph/normal_graph/operator_impl.h" -#include "graph/operator_factory_impl.h" -#include "graph/common_error_codes.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/debug/ge_op_types.h" -#include "graph/utils/node_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/tensor_utils.h" -#include "common/util/mem_utils.h" -#include "graph/utils/op_type_utils.h" -#include "graph/utils/op_desc_utils_ex.h" -#include "base/err_msg.h" - -namespace ge { -namespace { -bool NeedUpdateIOName(const OpDescPtr &op_desc) { - const auto &input_name_2_idx = op_desc->GetAllInputName(); - const bool is_input_names_empty = (op_desc->GetInputsSize() > 0U) && input_name_2_idx.empty(); - const bool is_default_input_name = !input_name_2_idx.empty() && - StringUtils::StartWith(input_name_2_idx.cbegin()->first, "__input"); - if (is_input_names_empty || is_default_input_name) { - return true; - } - - const auto &output_name_2_idx = op_desc->GetAllOutputName(); - const bool is_output_names_empty = (op_desc->GetOutputsSize() > 0U) && output_name_2_idx.empty(); - const bool is_default_output_name = !output_name_2_idx.empty() && - StringUtils::StartWith(output_name_2_idx.cbegin()->first, "__output"); - if (is_output_names_empty || is_default_output_name) { - return true; - } - return false; -} -std::string IoNameToString(const std::string &prefix, const std::map &io_names) { - std::stringstream ss; - ss << prefix << ":"; - if (io_names.empty()) { - ss << "empty"; - return ss.str(); - } - for (const auto &pair : io_names) { - ss << "[" << pair.second << "," << pair.first << "]"; - } - return ss.str(); -} -} // namespace -graphStatus NodeUtilsEx::InferShapeAndType(const NodePtr &node) { - GE_CHECK_NOTNULL(node, ", Node is null for Infer Shape."); - Operator op = OpDescUtils::CreateOperatorFromNode(node); - return ShapeRefiner::InferShapeAndType(node, op); -} - -graphStatus NodeUtilsEx::InferOriginFormat(const NodePtr &node) { - GE_CHECK_NOTNULL(node, ", Node is null for Infer Format."); - const auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc, ", Op is null for Infer Format."); - Operator op = OpDescUtils::CreateOperatorFromNode(node); - return OpDescUtilsEx::CallInferFormatFunc(op_desc, op); -} - -graphStatus NodeUtilsEx::IsInputsValid(const NodePtr &node) { - const auto &op_desc = node->GetOpDesc(); - for (const auto &in_anchor : node->GetAllInDataAnchors()) { - if (in_anchor == nullptr) { - GELOGW("[Verify][CheckParam] In data anchor is null"); - continue; - } - const bool valid_anchor = OpTypeUtils::IsDataNode(node->GetType()) || - (node->GetType() == CONSTANT) || (node->GetType() == VARIABLE) || - (node->GetType() == CONSTANTOP) || - (op_desc->MutableInputDesc(static_cast(in_anchor->GetIdx())) == nullptr) || - (in_anchor->GetPeerAnchorsSize() > 0UL); - if (!valid_anchor) { - REPORT_PREDEFINED_ERR_MSG( - "E11019", std::vector({"opname", "index"}), - std::vector({node->GetName().c_str(), std::to_string(in_anchor->GetIdx()).c_str()})); - GELOGE(GRAPH_FAILED, "[Check][Param] operator %s's input %d is not linked.", - node->GetName().c_str(), in_anchor->GetIdx()); - return GRAPH_FAILED; - } - } - return GRAPH_SUCCESS; -} - -graphStatus NodeUtilsEx::Verify(const NodePtr &node) { - GE_CHECK_NOTNULL(node, ", Node is null for Infer Verify."); - const bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag(); - if (is_unknown_graph) { - return GRAPH_SUCCESS; - } - - GE_CHK_STATUS_RET_NOLOG(IsInputsValid(node)); - -/* - 临时方案: - 如下代码使用原型库注册的creator构造临时op_desc,获取其input_names设置到当前op desc上有缺陷。 - 1.只能恢复靠前的必选输入 - 2.不能恢复dynamic input - 3.不能区分传入了哪几个可选输入,全部恢复 - - 且该行为归属parser, 不应该由infershape干预。但因为tf parser等前端没有正确设置input names。直接去掉会导致部分算子infershape失败。 - 因此判断若input names以'__input'打头才需要刷新,作为临时方案。 - - 正式方案: - tf、caffee、onnx parser要将op desc的必备字段设置完整 - */ - const auto op_desc = node->GetOpDesc(); - const bool need_update_name = (node->GetType() != FRAMEWORKOP) && NeedUpdateIOName(op_desc); - GELOGD("Before update %s(%s) io name, input size %zu, %s, output size %zu, %s", op_desc->GetNamePtr(), - op_desc->GetTypePtr(), op_desc->GetInputsSize(), - IoNameToString("Input names", op_desc->GetAllInputName()).c_str(), - op_desc->GetOutputsSize(), IoNameToString("Output names", op_desc->GetAllOutputName()).c_str()); - if (need_update_name) { - const auto node_op = ge::OperatorFactoryImpl::CreateOperator("node_op", node->GetType()); - if (node_op.IsEmpty()) { - GELOGW("[Verify][CheckParam] Get op from OperatorFactory failed, type: %s", node->GetType().c_str()); - } else { - GELOGD("get op from OperatorFactory success. opType: %s", node->GetType().c_str()); - const auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); - if (temp_op_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "GetOpDescFromOperator failed, as return nullptr, type:%s", - node->GetType().c_str()); - GELOGE(GRAPH_FAILED, "[Get][OpDesc] temp op desc is null, type:%s", node->GetType().c_str()); - return GRAPH_FAILED; - } - if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) { - GELOGW("[Verify][Update] Update input name failed"); - } - if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) { - GELOGW("[Verify][Update] Update output name failed"); - } - GELOGD("After update %s(%s) io name, input size %zu, %s, output size %zu, %s", op_desc->GetNamePtr(), - op_desc->GetTypePtr(), op_desc->GetInputsSize(), - IoNameToString("Input names", op_desc->GetAllInputName()).c_str(), - op_desc->GetOutputsSize(), IoNameToString("Output names", op_desc->GetAllOutputName()).c_str()); - } - node_op.BreakConnect(); - } - - if (op_desc->CommonVerify() == GRAPH_SUCCESS) { - Operator op = OpDescUtils::CreateOperatorFromNode(node); - auto verify_func = op_desc->GetVerifyFunc(); - if (verify_func == nullptr) { - verify_func = OperatorFactoryImpl::GetVerifyFunc(node->GetType()); - } - if (verify_func != nullptr) { - return static_cast(verify_func(op)); - } - return GRAPH_SUCCESS; - } else { - REPORT_INNER_ERR_MSG("E18888", "%s(%s) Verify failed.", node->GetName().c_str(), node->GetType().c_str()); - GELOGE(GRAPH_FAILED, "[Call][CommonVerify] %s(%s) failed.", node->GetName().c_str(), node->GetType().c_str()); - return GRAPH_FAILED; - } -} -} // namespace ge diff --git a/graph/utils/op_desc_utils.cc b/graph/utils/op_desc_utils.cc deleted file mode 100644 index 0cd451f8b70d642d34c5cf81455d303e9431b2ea..0000000000000000000000000000000000000000 --- a/graph/utils/op_desc_utils.cc +++ /dev/null @@ -1,1123 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/utils/op_desc_utils.h" - -#include - -#include "common/util/mem_utils.h" -#include "common/checker.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/debug/ge_op_types.h" -#include "graph/debug/ge_util.h" -#include "graph/anchor.h" -#include "graph/compute_graph.h" -#include "graph/ge_context.h" -#include "graph/normal_graph/op_desc_impl.h" -#include "graph/op_desc.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/node_utils.h" -#include "graph/utils/constant_utils.h" -#include "graph/utils/recover_ir_utils.h" -#include "graph/normal_graph/operator_impl.h" -#include "graph/type/sym_dtype.h" -#include "graph/detail/model_serialize_imp.h" -#include "mmpa/mmpa_api.h" -#include "graph/operator_factory_impl.h" - -/*lint -e512 -e737 -e752*/ -namespace ge { -const char_t OP_DESC_QUANT_PARAMS[] = "quantize_factor"; - -namespace { -const uint32_t CONST_OP_NORMAL_WEIGHT_SIZE = 1U; -const char* const kMultiThreadCompile = "MULTI_THREAD_COMPILE"; -const char* const kDisEnableFlag = "0"; -void GetConstantOpName(std::string &op_name) { - thread_local int64_t const_count = 0; - std::string compile_thread; - if ((ge::GetContext().GetOption(kMultiThreadCompile, compile_thread) == GRAPH_SUCCESS) - && (compile_thread.compare(kDisEnableFlag) == 0)) { - op_name = "dynamic_const_" + std::to_string(const_count); - } else { - op_name = "dynamic_const_" + std::to_string(GeLog::GetTid()) + "_" + std::to_string(const_count); - } - ++const_count; -} - -bool FindSubsequentMatches(const std::map &valid_index_2_names, size_t start_index, - const std::string &ir_name) { - for (size_t i = start_index; i < valid_index_2_names.size(); ++i) { - const auto name = valid_index_2_names.at(i); - if (name == ir_name) { - GELOGI("ir_name:%s, node input index:%zu", ir_name.c_str(), i); - return true; - } - } - return false; -} - -std::string InputsNamesStr(const OpDescPtr &op_desc) { - std::stringstream ss; - ss << "node: " << op_desc->GetName() << "(" << op_desc->GetType() << ") ir inputs names: ["; - for (const auto &ir_input : op_desc->GetIrInputs()) { - ss << ir_input.first << ", "; - } - ss << "], actual inputs names: ["; - for (size_t i = 0U; i < op_desc->GetAllInputsSize(); i++) { - if (op_desc->MutableInputDesc(static_cast(i)) != nullptr) { - const auto valid_name = op_desc->GetInputNameByIndex(static_cast(i)); - ss << valid_name << ", "; - } - } - ss << "]"; - return ss.str(); -} -} - -bool OpDescUtils::ClearInputDesc(const NodePtr &node) { - GE_CHK_BOOL_EXEC(node != nullptr, REPORT_INNER_ERR_MSG("E18888", "param node is nullptr, check invalid."); - return false, "[Check][Param] node is nullptr"); - GE_CHK_BOOL_EXEC(node->GetOpDesc() != nullptr, REPORT_INNER_ERR_MSG("E18888", "opdesc is nullptr."); - return false, "[Check][Param] opdesc is nullptr"); - std::vector index_list; - for (const auto &in_anchor : node->GetAllInDataAnchors()) { - if (in_anchor->GetPeerOutAnchor() == nullptr) { - index_list.push_back(in_anchor->GetIdx()); - } - } - std::sort(index_list.begin(), index_list.end()); - // Node's in anchor index need shrink - if (node->GetOpDesc()->impl_ == nullptr) { - GELOGE(FAILED, "[Clear][InputDesc] Op desc impl is nullptr. "); - return false; - } - for (size_t i = 0UL; i < index_list.size(); ++i) { - const auto iter = node->GetOpDesc()->impl_->inputs_desc_.begin() + static_cast(index_list[i]); - if (iter < node->GetOpDesc()->impl_->inputs_desc_.end()) { - (void)node->GetOpDesc()->impl_->inputs_desc_.erase(iter); - } else { - GELOGW("[Clear][InputDesc] inputs_desc_ iterator out of range."); - } - } - - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::ClearInputDesc(const OpDescPtr op_desc, - const uint32_t index) { - return NodeUtils::ClearInputDesc(op_desc, index); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::HasQuantizeFactorParams(const OpDescPtr &op_desc) { - if (op_desc == nullptr) { - GELOGI("op_desc is nullptr"); - return false; - } - return op_desc->HasAttr(OP_DESC_QUANT_PARAMS); -} - -bool OpDescUtils::ClearOutputDesc(const NodePtr &node) { - GE_CHK_BOOL_EXEC(node != nullptr, REPORT_INNER_ERR_MSG("E18888", "node is nullptr, check invalid."); - return false, "[Check][Param] node is nullptr"); - GE_CHK_BOOL_EXEC(node->GetOpDesc() != nullptr, REPORT_INNER_ERR_MSG("E18888", "opdesc is nullptr."); - return false, "[Check][Param] opdesc is nullptr"); - std::vector index_list; - for (const auto &out_anchor : node->GetAllOutDataAnchors()) { - if (out_anchor->GetPeerInDataAnchors().empty()) { - index_list.push_back(out_anchor->GetIdx()); - } - } - std::sort(index_list.begin(), index_list.end()); - // Node's out anchor index need shrink - if (node->GetOpDesc()->impl_ == nullptr) { - GELOGE(FAILED, "[Clear][OutputDesc] Op desc impl is nullptr. "); - return false; - } - for (size_t i = 0UL; i < index_list.size(); ++i) { - const auto iter = node->GetOpDesc()->impl_->outputs_desc_.begin() + static_cast(index_list[i]); - if (iter < node->GetOpDesc()->impl_->outputs_desc_.end()) { - (void)node->GetOpDesc()->impl_->outputs_desc_.erase(iter); - } else { - GELOGW("[Clear][OutputDesc] outputs_desc_ iterator out of range."); - } - } - - return true; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::ClearOutputDesc(const OpDescPtr &op_desc, - const uint32_t index) { - return NodeUtils::ClearOutputDesc(op_desc, index); -} - -bool OpDescUtils::HasQuantizeFactorParams(const OpDesc &op_desc) { return op_desc.HasAttr(OP_DESC_QUANT_PARAMS); } - -GeTensorPtr OpDescUtils::MutableWeights(OpDesc &op_desc) { - GeTensorPtr weight = nullptr; - (void)AttrUtils::MutableTensor(&op_desc, ATTR_NAME_WEIGHTS, weight); - return weight; -} - -GE_FUNC_HOST_VISIBILITY GeTensorPtr OpDescUtils::MutableWeights(const OpDescPtr op_desc) { - if (op_desc == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "op_desc is null, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] op_desc is null"); - return nullptr; - } - return MutableWeights(*op_desc); -} - -graphStatus OpDescUtils::SetWeights(OpDesc &op_desc, const GeTensorPtr weight) { - if (weight == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "weight is null, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] weight is null"); - return GRAPH_FAILED; - } - return AttrUtils::SetTensor(&op_desc, ATTR_NAME_WEIGHTS, weight) ? GRAPH_SUCCESS : GRAPH_FAILED; -} - -graphStatus OpDescUtils::SetWeights(OpDescPtr op_desc, const GeTensorPtr weight) { - GE_CHECK_NOTNULL(op_desc); - GE_CHECK_NOTNULL(weight); - return SetWeights(*op_desc, weight); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -std::vector OpDescUtils::GetWeights(const ge::Node &node) { - auto weights = MutableWeights(node); - std::vector ret(weights.size()); - (void)std::copy(weights.begin(), weights.end(), ret.begin()); - return ret; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector OpDescUtils::GetWeights( - const ge::ConstNodePtr &node) { - if (node == nullptr) { - return std::vector(); - } - return GetWeights(*node); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector OpDescUtils::GetConstInputNode( - const ge::Node &node) { - std::vector ret; - const auto in_anchors = node.GetAllInDataAnchors(); - for (const auto &in_anchor : in_anchors) { - const auto out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr) { - // normally out_anchor could be null, this is ok - GELOGD("node %s' peer_out_anchor is null", node.GetName().c_str()); - continue; - } - auto in_node = out_anchor->GetOwnerNode(); - while (true) { - if (in_node == nullptr) { - break; - } - if (ConstantUtils::IsConstant(in_node)) { - ret.push_back(in_node); - break; - } else if (in_node->GetType() == DATA) { - if (NodeUtils::IsWhileVaryingInput(in_node)) { - break; - } - in_node = NodeUtils::GetParentInput(in_node); - } else if ((in_node->GetType() == ENTER) || (in_node->GetType() == REFENTER)) { - bool is_constant = false; - (void)AttrUtils::GetBool(in_node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant); - if (!is_constant) { - break; - } - // Enter node has and only has one input - if (in_node->GetInDataNodesSize() != 1U) { - GELOGW("[Get][ConstInput] Check number of input_nodes for Enter node %s failed, input_node_num=%zu.", - in_node->GetName().c_str(), in_node->GetInDataNodes().size()); - break; - } - in_node = in_node->GetInDataNodes().at(0UL); - } else { - break; - } - } - } - return ret; -} - -std::vector OpDescUtils::GetConstInputNodeAndAnchor(const ge::Node &node) { - std::vector> ret; - const auto in_nodes_and_anchors = node.GetInDataNodesAndAnchors(); - for (const auto &in_node_2_anchor : in_nodes_and_anchors) { - auto in_node = in_node_2_anchor.first; - auto in_node_2_out_anchor = in_node_2_anchor; - while (true) { - if (in_node == nullptr) { - break; - } - if (ConstantUtils::IsConstant(in_node)) { - ret.push_back(in_node_2_out_anchor); - break; - } else if (in_node->GetType() == DATA) { - if (NodeUtils::IsWhileVaryingInput(in_node)) { - break; - } - in_node_2_out_anchor = NodeUtils::GetParentInputAndAnchor(in_node); - in_node = in_node_2_out_anchor.first; - } else if ((in_node->GetType() == ENTER) || (in_node->GetType() == REFENTER)) { - bool is_constant = false; - (void)AttrUtils::GetBool(in_node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant); - if (!is_constant) { - break; - } - // Enter node has and only has one input - if (in_node->GetInDataNodesSize() != 1U) { - GELOGW("[Get][ConstInput] Check number of input_nodes for Enter node %s failed, input_node_num=%zu.", - in_node->GetName().c_str(), in_node->GetInDataNodes().size()); - break; - } - if (in_node->GetInDataAnchor(0) == nullptr) { - break; - } - auto peer_out_anchor = in_node->GetInDataAnchor(0)->GetPeerOutAnchor(); - if (peer_out_anchor == nullptr) { - break; - } - in_node = peer_out_anchor->GetOwnerNode(); - in_node_2_out_anchor = std::make_pair(in_node, peer_out_anchor); - } else { - break; - } - } - } - return ret; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector OpDescUtils::GetInputData( - const std::vector &input_nodes) { - std::vector ret; - - for (const auto &input_node : input_nodes) { - const auto temp_weight = MutableWeights(input_node->GetOpDesc()); - if (temp_weight == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "const op's weight is null, name: %s", input_node->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Invoke][MutableWeights] const op's weight is null, name: %s", - input_node->GetName().c_str()); - return std::vector(); - } - ret.push_back(temp_weight); - } - - return ret; -} - -vector OpDescUtils::GetWeightsFromNodes( - const std::vector &input_nodes_2_out_anchors) { - std::vector ret; - for (const auto &input_node_2_anchor : input_nodes_2_out_anchors) { - const auto input_node = input_node_2_anchor.first; - GeTensorPtr temp_weight ; - (void)ConstantUtils::MutableWeight(input_node->GetOpDesc(), - static_cast(input_node_2_anchor.second->GetIdx()), - temp_weight); - if (temp_weight == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "const op's weight is null, name: %s", input_node->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Invoke][MutableWeights] const op's weight is null, name: %s", - input_node->GetName().c_str()); - return std::vector(); - } - ret.push_back(temp_weight); - } - - return ret; -} -size_t OpDescUtils::GetNonConstInputsSize(const ge::Node &node) { - if (NodeUtils::IsAnchorStatusSet(node)) { - size_t input_num = 0UL; - for (const auto &anchor : node.GetAllInDataAnchors()) { - if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) { - input_num++; - continue; - } - } - return input_num; // lint !e712 - } else { - GE_IF_BOOL_EXEC( - node.GetInDataNodesSize() < GetConstInputs(node).size(), - REPORT_INNER_ERR_MSG("E18888", "InDataNodes size:%zu is smaller than ConstInputs size:%zu", - node.GetInDataNodes().size(), GetConstInputs(node).size()); - GELOGE(GRAPH_FAILED, "[Check][Param] %zu is smaller than %zu", - node.GetInDataNodes().size(), GetConstInputs(node).size()); - return 0UL); - return node.GetInDataNodesSize() - GetConstInputs(node).size(); - } -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDescUtils::GetNonConstInputsSize(const ge::ConstNodePtr node) { - if (node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] Node is nullptr"); - return 0UL; - } - return GetNonConstInputsSize(*node); -} - -GeTensorDesc OpDescUtils::GetNonConstInputTensorDesc(const ge::Node &node, const size_t index_non_const) { - GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, REPORT_INNER_ERR_MSG("E18888", "node.GetOpDesc() is nullptr!"); - return GeTensorDesc(), "[Check][Param] node.GetOpDesc() is nullptr!"); - size_t i = 0UL; - if (NodeUtils::IsAnchorStatusSet(node)) { - for (const auto &anchor : node.GetAllInDataAnchors()) { - if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) { - if (index_non_const == i) { - return node.GetOpDesc()->GetInputDesc(static_cast(anchor->GetIdx())); - } - ++i; - } - } - } else { - for (const auto &anchor : node.GetAllInDataAnchors()) { - const auto peer_anchor = anchor->GetPeerOutAnchor(); - if (peer_anchor == nullptr) { - continue; - } - const auto owner_node = peer_anchor->GetOwnerNodeBarePtr(); - if (owner_node == nullptr) { - continue; - } - if (owner_node->GetType() == CONSTANT) { - continue; - } - if (index_non_const == i) { - return node.GetOpDesc()->GetInputDesc(static_cast(anchor->GetIdx())); - } - ++i; - } - } - return GeTensorDesc(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc -OpDescUtils::GetNonConstInputTensorDesc(const ge::ConstNodePtr &node, const size_t index_non_const) { - CHECK_FALSE_EXEC(node != nullptr, return GeTensorDesc()); - return GetNonConstInputTensorDesc(*node, index_non_const); -} - -bool OpDescUtils::GetNonConstInputIndex(const ge::Node &node, const size_t index_non_const, size_t &index) { - bool ret = false; - size_t i = 0UL; - if (NodeUtils::IsAnchorStatusSet(node)) { - for (const auto &anchor : node.GetAllInDataAnchors()) { - if (ge::AnchorUtils::GetStatus(anchor) == ANCHOR_DATA) { - if (index_non_const == i) { - index = static_cast(anchor->GetIdx()); - ret = true; - } - ++i; - } - } - } else { - for (const auto &anchor : node.GetAllInDataAnchors()) { - const auto peer_anchor = anchor->GetPeerOutAnchor(); - if (peer_anchor == nullptr) { - continue; - } - const auto owner_node = peer_anchor->GetOwnerNodeBarePtr(); - if (owner_node == nullptr) { - continue; - } - if (owner_node->GetType() == CONSTANT) { - continue; - } - if (index_non_const == i) { - index = static_cast(anchor->GetIdx()); - ret = true; - } - ++i; - } - } - return ret; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::GetNonConstInputIndex(const ge::ConstNodePtr &node, - const size_t index_non_const, - size_t &index) { - CHECK_FALSE_EXEC(node != nullptr, return false); - return GetNonConstInputIndex(*node, index_non_const, index); -} - -bool OpDescUtils::IsNonConstInput(const ge::Node &node, const size_t index) { - bool ret = false; - if (index < static_cast(node.GetAllInDataAnchorsSize())) { - if (NodeUtils::IsAnchorStatusSet(node)) { - ret = (ge::AnchorUtils::GetStatus(node.GetInDataAnchor(static_cast(index))) == - ANCHOR_DATA); // lint !e712 - } else { - for (const auto &anchor : node.GetAllInDataAnchors()) { - if (anchor->GetIdx() != static_cast(index)) { - continue; - } - const auto peer_anchor = anchor->GetPeerOutAnchor(); - if (peer_anchor == nullptr) { - break; - } - const auto owner_node = peer_anchor->GetOwnerNodeBarePtr(); - if (owner_node == nullptr) { - break; - } - ret = (owner_node->GetType() != CONSTANT); - } - } - } - - return ret; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDescUtils::IsNonConstInput(const ge::ConstNodePtr &node, - const size_t index) { - CHECK_FALSE_EXEC(node != nullptr, return false); - return IsNonConstInput(*node, index); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector OpDescUtils::GetConstInputs( - const ge::ConstNodePtr &node) { - if (node == nullptr) { - return std::vector(); - } - return GetConstInputs(*node); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector OpDescUtils::GetNonConstTensorDesc( - const ge::ConstNodePtr &node) { - if ((node == nullptr) || (node->GetOpDesc() == nullptr)) { - return std::vector(); - } - std::vector ret; - if (NodeUtils::IsAnchorStatusSet(*node)) { - for (const auto &in_anchor : node->GetAllInDataAnchors()) { - if (ge::AnchorUtils::GetStatus(in_anchor) == ANCHOR_DATA) { - ret.push_back(node->GetOpDesc()->GetInputDesc(static_cast(in_anchor->GetIdx()))); - } - } - } else { - for (const auto &in_anchor : node->GetAllInDataAnchors()) { - const auto out_anchor = in_anchor->GetPeerOutAnchor(); - if ((out_anchor == nullptr) || (out_anchor->GetOwnerNodeBarePtr()->GetOpDesc() == nullptr)) { - continue; - } - if (out_anchor->GetOwnerNodeBarePtr()->GetOpDesc()->GetType() != CONSTANT) { - ret.push_back(node->GetOpDesc()->GetInputDesc(static_cast(in_anchor->GetIdx()))); - } - } - } - return ret; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -std::vector OpDescUtils::GetConstInputs(const ge::Node &node, const uint32_t depth) { - std::vector ret; - if (depth == 0U) { - return ret; - } - - const auto in_anchors = node.GetAllInDataAnchors(); - for (const auto &in_anchor : in_anchors) { - const auto out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr) { - continue; - } - - const auto in_node = out_anchor->GetOwnerNode(); - if (in_node->GetType() == CONSTANT) { - ret.push_back(in_node); - } else if ((in_node->GetType() == SWITCH) && (node.GetType() == MATMUL)) { - // const --> switch --> matmul - auto switch_input = GetConstInputs(*in_node, depth - 1U); - if (switch_input.size() > 0U) { - (void)ret.insert(ret.end(), switch_input.begin(), switch_input.end()); - } - } else if (in_node->GetType() == DATA) { - const auto parent = NodeUtils::GetParentInput(in_node); - if ((parent != nullptr) && (parent->GetType() == CONSTANT)) { - ret.push_back(parent); - } - } else { - // do nothing - } - } - return ret; -} - - -graphStatus OpDescUtils::SetNoneConstNodeWeights(ge::Node &node, const std::vector &weights) { - const auto input_nodes = GetConstInputs(node); - if (weights.size() < input_nodes.size()) { - REPORT_INNER_ERR_MSG("E18888", "weights count:%zu can't be less than const input count:%zu, node:%s(%s)", - weights.size(), input_nodes.size(), node.GetName().c_str(), node.GetType().c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] weights count:%zu can't be less than const input count:%zu", - weights.size(), input_nodes.size()); - return GRAPH_PARAM_INVALID; - } - - ge::NamedAttrs named_attrs; - (void)ge::AttrUtils::SetListTensor(named_attrs, "key", weights); - std::vector copy_weights; - (void)ge::AttrUtils::MutableListTensor(named_attrs, "key", copy_weights); - - for (size_t i = 0UL; i < input_nodes.size(); ++i) { - if (input_nodes[i]->GetOpDesc() != nullptr) { - if (SetWeights(input_nodes[i]->GetOpDesc(), copy_weights[i]) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "set weights failed, node:%s(%s)", input_nodes[i]->GetName().c_str(), - input_nodes[i]->GetType().c_str()); - GELOGE(GRAPH_FAILED, "[Set][Weights] failed, node:%s(%s)", - input_nodes[i]->GetName().c_str(), input_nodes[i]->GetType().c_str()); - return GRAPH_FAILED; - } - } - } - - // If set more weights than constop, need to add constop - for (size_t i = input_nodes.size(); i < copy_weights.size(); ++i) { - // Use org weight before SetWeights Overwrite - const auto const_opdesc = CreateConstOp(copy_weights[i]); - GE_CHECK_NOTNULL(const_opdesc); - - const auto owner_graph = node.GetOwnerComputeGraph(); - if (owner_graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node's graph is empty, node name: %s", node.GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][Graph] node's graph is empty, name: %s", node.GetName().c_str()); - return GRAPH_PARAM_INVALID; - } - const auto const_node = owner_graph->AddNodeFront(const_opdesc); - GE_CHK_BOOL_EXEC(node.AddLinkFrom(const_node) == GRAPH_SUCCESS, - REPORT_INNER_ERR_MSG("E18888", "node:%s add link failed.", node.GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Invoke][AddLinkFrom] graph add link failed! node:%s", - node.GetName().c_str()); - return GRAPH_FAILED); - const std::vector original_nodes; - ge::GraphUtils::RecordOriginalNames(original_nodes, const_node); - } - return GRAPH_SUCCESS; -} - -graphStatus OpDescUtils::SetNoneConstNodeWeights(ge::Node &node, const std::map &weights_map) { - for (const auto &pair:weights_map) { - const auto idx = pair.first; - // idx = in data anchor size is valid, it meant to add a new const node - if ((idx < 0) || (static_cast(idx) > node.GetAllInDataAnchorsSize())) { - REPORT_INNER_ERR_MSG("E18888", "Invalid map key: %d of node[%s].", idx, node.GetName().c_str()); - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] Invalid map key: %d of node[%s].", idx, node.GetName().c_str()); - return GRAPH_PARAM_INVALID; - } - const auto peer_node = NodeUtils::GetInDataNodeByIndex(node, idx); - if (peer_node != nullptr) { - // a. update const input node - if (peer_node->GetType() != CONSTANT) { - REPORT_INNER_ERR_MSG("E18888", "op %s [%d]'s input node should be const, but is %s type:%s ", - node.GetName().c_str(), pair.first, peer_node->GetName().c_str(), - peer_node->GetType().c_str()); - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] op %s [%d]'s input node should be const, but is %s type:%s ", - node.GetName().c_str(), pair.first, peer_node->GetName().c_str(), peer_node->GetType().c_str()); - } - if (SetWeights(peer_node->GetOpDesc(), pair.second) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "set weights failed, node:%s(%s)", peer_node->GetName().c_str(), - peer_node->GetType().c_str()); - GELOGE(GRAPH_FAILED, "[Set][Weights] failed, node:%s(%s)", - peer_node->GetName().c_str(), peer_node->GetType().c_str()); - return GRAPH_FAILED; - } - } else { - // b. create new const input node - const auto const_opdesc = CreateConstOp(pair.second); - GE_CHECK_NOTNULL(const_opdesc); - const auto owner_graph = node.GetOwnerComputeGraph(); - if (owner_graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node's graph is empty, node name: %s", node.GetName().c_str()); - GELOGE(GRAPH_PARAM_INVALID, "[Get][Graph] node's graph is empty, name: %s", node.GetName().c_str()); - return GRAPH_PARAM_INVALID; - } - const auto const_node = owner_graph->AddNodeFront(const_opdesc); - if (node.AddLinkFrom(static_cast(pair.first), const_node) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "op %s add const to input index[%d] failed", node.GetName().c_str(), pair.first); - GELOGE(GRAPH_FAILED, "[Invoke][AddLinkFrom] op %s add const to input index[%d] failed", - node.GetName().c_str(), pair.first); - return GRAPH_FAILED; - } - } - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -std::vector OpDescUtils::MutableWeights(const ge::Node &node) { - std::vector ret; - auto op_desc = node.GetOpDesc(); - GE_CHK_BOOL_EXEC(op_desc != nullptr, REPORT_INNER_ERR_MSG("E18888", "param node's op_desc is nullptr."); - return ret, "[Check][Param] op_desc is nullptr!"); - // Const operator, take the weight directly - if ((op_desc->GetType() == CONSTANT) || (op_desc->GetType() == CONSTANTOP)) { - const auto weight = MutableWeights(op_desc); - if (weight == nullptr) { - GELOGD("op type %s has no weight, op name:%s", node.GetType().c_str(), node.GetName().c_str()); - return ret; - } - ret.push_back(weight); - return ret; - } - // Place holder operator, try to get the weight from parent node - // when parent node is const operator - if (node.GetType() == PLACEHOLDER) { - ConstGeTensorPtr ge_tensor = nullptr; - if (NodeUtils::TryGetWeightByPlaceHolderNode(std::const_pointer_cast(node.shared_from_this()), ge_tensor) == - GRAPH_SUCCESS && - ge_tensor != nullptr) { - ret.push_back(std::const_pointer_cast(ge_tensor)); - } - return ret; - } - - if (node.GetType() == DATA) { - ConstGeTensorPtr ge_tensor = nullptr; - if (NodeUtils::TryGetWeightByDataNode(std::const_pointer_cast(node.shared_from_this()), ge_tensor) == - GRAPH_SUCCESS && - ge_tensor != nullptr) { - ret.push_back(std::const_pointer_cast(ge_tensor)); - } - return ret; - } - - // Other operators, get weights from connected constop - const auto input_nodes = GetConstInputs(node); - for (const auto &input_node : input_nodes) { - const auto temp_weight = MutableWeights(input_node->GetOpDesc()); - if (temp_weight == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "const op's weight is null, name: %s", input_node->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Invoke][MutableWeights] const op's weight is null, name: %s", - input_node->GetName().c_str()); - return std::vector(); - } - ret.push_back(temp_weight); - } - - return ret; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -std::vector OpDescUtils::MutableWeights(const ge::NodePtr node) { - if (node == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Node is nullptr"); - return std::vector(); - } - return MutableWeights(*node); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -OpDescUtils::SetWeights(ge::Node &node, const std::vector &weights) { - GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, REPORT_INNER_ERR_MSG("E18888", "opdesc of node is nullptr."); - return GRAPH_PARAM_INVALID, "[Check][Param] node.GetOpDesc is nullptr!"); - if (node.GetOpDesc()->GetType() == CONSTANT) { - if (weights.size() == CONST_OP_NORMAL_WEIGHT_SIZE) { - return SetWeights(node.GetOpDesc(), weights[0UL]); - } - GELOGI("const op weight size %zu should be 1", weights.size()); - return GRAPH_PARAM_INVALID; - } - - return SetNoneConstNodeWeights(node, weights); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -OpDescUtils::SetWeights(ge::Node &node, const std::map &weights_map) { - GE_CHECK_NOTNULL(node.GetOpDesc()); - // 1. node is const - if (node.GetOpDesc()->GetType() == CONSTANT) { - if (weights_map.size() == CONST_OP_NORMAL_WEIGHT_SIZE) { - return SetWeights(node.GetOpDesc(), weights_map.begin()->second); - } - REPORT_INNER_ERR_MSG("E18888", "const op %s weight size %zu should be 1", node.GetName().c_str(), - weights_map.size()); - GELOGE(GRAPH_PARAM_INVALID, "[Check][Param] const op %s weight size %zu should be 1", - node.GetName().c_str(), weights_map.size()); - return GRAPH_PARAM_INVALID; - } - // 2. node is not const - auto const ret = SetNoneConstNodeWeights(node, weights_map); - if (ret != GRAPH_SUCCESS) { - return ret; - } - NodeUtils::UpdateIsInputConst(node); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescUtils::CloneOpDesc(const ConstOpDescPtr &org_op_desc) { - return GraphUtils::CloneOpDesc(org_op_desc); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescUtils::CopyOpDesc(const ConstOpDescPtr &org_op_desc) { - return GraphUtils::CopyOpDesc(org_op_desc); -} - -OpDescPtr OpDescUtils::CreateConstOp(const GeTensorPtr &tensor_ptr) { - return CreateConstOp(tensor_ptr, true); -} - -OpDescPtr OpDescUtils::CreateConstOpZeroCopy(const GeTensorPtr& tensor_ptr) { - return CreateConstOp(tensor_ptr, false); -} - -OpDescPtr OpDescUtils::CreateConstOp(const GeTensorPtr &tensor_ptr, const bool copy) { - GE_ASSERT_NOTNULL(tensor_ptr); - const shared_ptr const_opdesc = ComGraphMakeShared(); - GE_ASSERT_NOTNULL(const_opdesc, "[Create][OpDesc] failed."); - if (copy) { - GE_ASSERT_GRAPH_SUCCESS(SetWeights(const_opdesc, tensor_ptr), "[Set][Weights] failed, op[%s]", - const_opdesc->GetNamePtr()); - } else { - GE_ASSERT_TRUE(AttrUtils::SetShareTensor(const_opdesc, ATTR_NAME_WEIGHTS, *tensor_ptr), - "[Set][ShardTensor] success for %s.", const_opdesc->GetNamePtr()); - } - const_opdesc->SetType(CONSTANT); - std::string op_name; - GetConstantOpName(op_name); - const_opdesc->SetName(op_name); - GELOGI("add const op: %s", const_opdesc->GetNamePtr()); - (void)const_opdesc->AddOutputDesc("y", tensor_ptr->GetTensorDesc()); - GELOGI("after add const op: %s", const_opdesc->GetName().c_str()); - return const_opdesc; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -OpDescUtils::AddConstOpToAnchor(const InDataAnchorPtr in_anchor, const GeTensorPtr &tensor_ptr) { - GE_CHECK_NOTNULL(in_anchor); - GE_CHECK_NOTNULL(tensor_ptr); - const auto const_opdesc = CreateConstOp(tensor_ptr); - GE_CHECK_NOTNULL(const_opdesc); - const auto in_node = in_anchor->GetOwnerNodeBarePtr(); - GE_CHECK_NOTNULL(in_node); - const auto owner_graph = in_node->GetOwnerComputeGraph(); - if (owner_graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "node's graph is empty, name: %s", in_node->GetName().c_str()); - GELOGE(GRAPH_PARAM_INVALID, "[Get][Graph] node's graph is empty, name: %s", in_node->GetName().c_str()); - return GRAPH_PARAM_INVALID; - } - const auto const_node = in_node->GetOwnerComputeGraph()->AddNodeFront(const_opdesc); - GE_CHECK_NOTNULL(const_node); - if (GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), in_anchor) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "AddEdge const %s to node %s failed", const_node->GetName().c_str(), - in_node->GetName().c_str()); - GELOGE(GRAPH_PARAM_INVALID, "[Add][Edge] const %s to node %s failed.", const_node->GetName().c_str(), - in_node->GetName().c_str()); - return GRAPH_PARAM_INVALID; - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -OpDescUtils::SetWeights(ge::NodePtr node, const std::vector &weights) { - GE_CHECK_NOTNULL(node); - return SetWeights(*node, weights); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::ClearWeights(const ge::NodePtr node) { - GE_CHECK_NOTNULL(node); - const auto const_ops = GetConstInputs(node); - const auto graph = node->GetOwnerComputeGraph(); - if (graph == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "GetOwnerComputeGraph failed, graph is nullptr, node:%s", node->GetName().c_str()); - GELOGE(GRAPH_FAILED, "[Get][Graph] Graph is nullptr"); - return GRAPH_PARAM_INVALID; - } - for (const auto &const_op : const_ops) { - GE_CHK_STATUS_RET(GraphUtils::IsolateNode(const_op, {}), "[Isolate][Node] %s, type:%s failed", - const_op->GetName().c_str(), const_op->GetType().c_str()); - GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph, const_op), - "[Remove][Node] %s, type: %s without relink failed", const_op->GetName().c_str(), - const_op->GetType().c_str()); - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -graphStatus OpDescUtils::SetSubgraphInstanceName(const std::string &subgraph_name, - const std::string &subgraph_instance_name, - OpDescPtr &op_desc) { - const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes(); - const auto iter = subgraph_names_to_index.find(subgraph_name); - if (iter == subgraph_names_to_index.end()) { - REPORT_INNER_ERR_MSG( - "E18888", "Failed to set subgraph instance %s for node %s type %s, the subgraph name %s does not exist", - subgraph_instance_name.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(), subgraph_name.c_str()); - GELOGE(GRAPH_PARAM_INVALID, - "[Check][Param] Failed to set subgraph instance %s for node %s type %s, the subgraph name %s does not exist", - subgraph_instance_name.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(), subgraph_name.c_str()); - return GRAPH_PARAM_INVALID; - } - - return op_desc->SetSubgraphInstanceName(iter->second, subgraph_instance_name); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -ConstGeTensorBarePtr OpDescUtils::GetInputConstData(const Operator &op, const uint32_t idx) { - if (op.operator_impl_ == nullptr) { - AscendString op_name; - (void)op.GetName(op_name); - GELOGW("[Check][Param] Op(%s) operator_impl_ is nullptr.", op_name.GetString()); - return nullptr; - } - - ConstGeTensorPtr ge_tensor = nullptr; - if (op.operator_impl_->GetInputConstData(idx, ge_tensor) == GRAPH_SUCCESS) { - return ge_tensor.get(); - } - AscendString name; - (void) op.GetName(name); - AscendString type; - (void) op.GetOpType(type); - GELOGI("[Get][ConstInput] Op(%s %s) is unable to get const data with input index[%u] ", - name.GetString(), type.GetString(), idx); - return nullptr; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -void OpDescUtils::SetRuntimeContextToOperator(const Operator &op, RuntimeInferenceContext *const context) { - op.operator_impl_->runtime_context_ = context; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -void OpDescUtils::SetCallbackGetConstInputFuncToOperator(const Operator &op, - GetConstInputOnRuntimeFun get_const_input_func) { - op.operator_impl_->get_const_input_runtime_ = get_const_input_func; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY -bool OpDescUtils::HasCallbackGetConstInputFunc(const Operator &op) { - return (op.operator_impl_->get_const_input_runtime_ != nullptr); -} - -ge::graphStatus IrInputRequiredCall(const OpDescPtr &op_desc, size_t ir_index, size_t start_index, size_t all_ins_num, - const std::string &ir_name, - const std::map &valid_index_2_names, size_t &instance_num) { - (void)all_ins_num; - const auto max_index = valid_index_2_names.rbegin()->first; - if (start_index > max_index) { - GELOGW("Failed to get instance num for node %s, current name %s current index %zu out of range %u", - op_desc->GetName().c_str(), ir_name.c_str(), start_index, max_index); - instance_num = 1U; - return ge::SUCCESS; - } - const auto name = valid_index_2_names.at(start_index); - if (name != ir_name) { - GELOGW("Failed to get instance num for node %s, can not find the input for ir name %s, current index %zu, " - "current name %s", - op_desc->GetName().c_str(), ir_name.c_str(), start_index, name.c_str()); - if (FindSubsequentMatches(valid_index_2_names, start_index + 1U, ir_name)) { - GELOGE(ge::FAILED, "Find another input name that match ir name. ir_index:%zu, ir_name:%s, inputs names:%s", - ir_index, ir_name.c_str(), InputsNamesStr(op_desc).c_str()); - return FAILED; - } - } - instance_num = 1U; - GELOGD("Get instance num %zu for node %s, current name %s current ir index %zu, start_index %zu", instance_num, - op_desc->GetName().c_str(), ir_name.c_str(), ir_index, start_index); - return ge::SUCCESS; -} - -ge::graphStatus IrInputOptionalCall(const OpDescPtr &op_desc, size_t ir_index, size_t start_index, size_t all_ins_num, - const std::string &ir_name, - const std::map &valid_index_2_names, size_t &instance_num) { - (void)all_ins_num; - const auto max_index = valid_index_2_names.rbegin()->first; - // ooooooxxx - // o : required input - // x : option input - if (start_index > max_index) { - instance_num = 0U; - return ge::SUCCESS; - } - const auto name = valid_index_2_names.at(start_index); - if (name == ir_name) { - instance_num = 1U; - } else { - instance_num = 0U; - } - GELOGD("Get instance num %zu for node %s, current name %s current ir index %zu, start_index %zu", instance_num, - op_desc->GetName().c_str(), ir_name.c_str(), ir_index, start_index); - return ge::SUCCESS; -} - -ge::graphStatus IrDynamicCall(const OpDescPtr &op_desc, size_t ir_index, size_t start_index, size_t all_ins_num, - const std::string &ir_name, - const std::map &valid_index_2_names, size_t &instance_num) { - size_t dyn_i = 0; - const auto max_index = valid_index_2_names.rbegin()->first; - for (size_t i = start_index; i < all_ins_num; ++i, ++dyn_i) { - if (i > max_index) { - break; - } - const auto name = valid_index_2_names.at(i); - if (name != ir_name + std::to_string(dyn_i)) { - break; - } - } - instance_num = dyn_i; - GELOGD("Get instance num %zu for node %s, current name %s current ir index %zu, start_index %zu", instance_num, - op_desc->GetName().c_str(), ir_name.c_str(), ir_index, start_index); - return ge::SUCCESS; -} - -ge::graphStatus GetOutputInstanceNum(const OpDescPtr &op_desc, size_t ir_index, size_t start_index, - const std::map &valid_index_2_names, size_t &instance_num) { - GE_CHECK_NOTNULL(op_desc); - if (valid_index_2_names.empty()) { - GELOGD("Node %s has not any outputs, just return", op_desc->GetName().c_str()); - return ge::SUCCESS; - } - const auto &ir_outputs = op_desc->GetIrOutputs(); - const auto ir_type = ir_outputs[ir_index].second; - const auto ir_name = ir_outputs[ir_index].first; - using GetInstanceCall = std::function &valid_index_2_names, size_t &instance_num)>; - static std::map get_instance_calls = {{kIrOutputRequired, &IrInputRequiredCall}, - {kIrOutputDynamic, &IrDynamicCall}}; - const auto it = get_instance_calls.find(ir_type); - if (it != get_instance_calls.end()) { - const size_t all_ins_num = op_desc->GetAllOutputsDescSize(); - return (it->second)(op_desc, ir_index, start_index, all_ins_num, ir_name, valid_index_2_names, instance_num); - } - GELOGE(ge::FAILED, "Failed to get instance num for node %s, unknown ir output type %d, ir name %s", - op_desc->GetName().c_str(), ir_type, ir_name.c_str()); - return ge::FAILED; -} - -ge::graphStatus OpDescUtils::GetInstanceNum(const OpDescPtr &op_desc, size_t ir_index, size_t start_index, - const std::map &valid_index_2_names, - size_t &instance_num) { - GE_CHECK_NOTNULL(op_desc); - if (valid_index_2_names.empty()) { - GELOGD("Node %s has not any inputs, just return", op_desc->GetName().c_str()); - return ge::SUCCESS; - } - const auto &ir_inputs = op_desc->GetIrInputs(); - const auto ir_type = ir_inputs[ir_index].second; - const auto ir_name = ir_inputs[ir_index].first; - using GetInstanceCall = std::function &valid_index_2_names, size_t &instance_num)>; - static std::map get_instance_calls = {{kIrInputRequired, &IrInputRequiredCall}, - {kIrInputOptional, &IrInputOptionalCall}, - {kIrInputDynamic, &IrDynamicCall}}; - const auto it = get_instance_calls.find(ir_type); - if (it != get_instance_calls.end()) { - const size_t all_ins_num = op_desc->GetAllInputsSize(); - return (it->second)(op_desc, ir_index, start_index, all_ins_num, ir_name, valid_index_2_names, instance_num); - } - GELOGE(ge::FAILED, "Failed to get instance num for node %s, unknown ir input type %d, ir name %s", - op_desc->GetName().c_str(), ir_type, ir_name.c_str()); - return ge::FAILED; -} - -std::map> OpDescUtils::GetInputIrIndexes2InstanceIndexesPairMap( - const OpDescPtr &op_desc) { - std::map> ir_index_to_instance_index_pair_map; - - if (GetIrInputInstanceDescRange(op_desc, ir_index_to_instance_index_pair_map) == GRAPH_SUCCESS) { - return ir_index_to_instance_index_pair_map; - } - - return {}; -} - -std::map> OpDescUtils::GetOutputIrIndexes2InstanceIndexesPairMap( - const OpDescPtr &op_desc) { - std::map> ir_index_to_instance_index_pair_map; - - if (GetIrOutputDescRange(op_desc, ir_index_to_instance_index_pair_map) == GRAPH_SUCCESS) { - return ir_index_to_instance_index_pair_map; - } - - return {}; -} - -ge::graphStatus OpDescUtils::GetInputIrIndexByInstanceIndex(const OpDescPtr &op_desc, - size_t instance_index, size_t &ir_index) { - GE_CHECK_NOTNULL(op_desc); - auto ir_index_to_instance_index_pair_map = GetInputIrIndexes2InstanceIndexesPairMap(op_desc); - if (ir_index_to_instance_index_pair_map.empty()) { - GELOGE(ge::GRAPH_FAILED, - "node [%s(%s)] get ir indexes to instance indexes list failed, instance_index[%zu], which is empty", - op_desc->GetName().c_str(), op_desc->GetType().c_str(), instance_index); - return ge::GRAPH_FAILED; - } - ir_index = std::numeric_limits::max(); - for (size_t i = 0U; i < op_desc->GetIrInputs().size(); ++i) { - const auto &index_pair = ir_index_to_instance_index_pair_map[i]; - size_t ir_index_end = 0U; - GE_ASSERT_TRUE(!ge::AddOverflow(index_pair.first, index_pair.second, ir_index_end)); - if ((instance_index >= index_pair.first) && (instance_index < ir_index_end)) { - ir_index = i; - GELOGD("node [%s(%s)] get ir index [%zu] successfully!", op_desc->GetName().c_str(), op_desc->GetType().c_str(), - ir_index); - return ge::GRAPH_SUCCESS; - } - } - ir_index = std::numeric_limits::max(); - GELOGW("node [%s(%s)] failed to get ir index by instance index[%zu], set ir_index to %zu", op_desc->GetName().c_str(), - op_desc->GetType().c_str(), instance_index, ir_index); - return GRAPH_SUCCESS; -} - -graphStatus OpDescUtils::GetIrInputInstanceDescRange(const OpDescPtr &op, - std::map> &ir_input_2_range) { - return ge::GetIrInputInstanceDescRange(op, ir_input_2_range); -} - -graphStatus OpDescUtils::GetIrInputRawDescRange(const OpDescPtr &op, - std::map> &ir_input_2_range) { - return ge::GetIrInputRawDescRange(op, ir_input_2_range); -} - -graphStatus OpDescUtils::GetIrOutputDescRange(const OpDescPtr &op, - std::map> &ir_output_2_range) { - return ge::GetIrOutputDescRange(op, ir_output_2_range); -} - -graphStatus OpDescUtils::GetPromoteIrInputList(const OpDescPtr &op_desc, - std::vector> &promote_index_list) { - GE_ASSERT_NOTNULL(op_desc); - const ge::Operator operator_ir = ge::OperatorFactory::CreateOperator("temp_operator", op_desc->GetType().c_str()); - const auto opdesc_ir = ge::OpDescUtils::GetOpDescFromOperator(operator_ir); - GE_ASSERT_NOTNULL(opdesc_ir); - return opdesc_ir->GetPromoteIrInputList(promote_index_list); -} - -graphStatus OpDescUtils::GetPromoteInstanceInputList(const OpDescPtr &op_desc, - std::vector> &promote_index_list) { - GE_ASSERT_NOTNULL(op_desc); - GE_ASSERT_SUCCESS(ge::RecoverIrUtils::RecoverOpDescIrDefinition(op_desc, op_desc->GetTypePtr())); - auto ir_ranges = GetInputIrIndexes2InstanceIndexesPairMap(op_desc); - std::vector> ir_promote_index_list; - GE_ASSERT_SUCCESS(op_desc->GetPromoteIrInputList(ir_promote_index_list)); - for (const auto& ir_input_indexes : ir_promote_index_list) { - std::vector instance_input_indexes; - for (const auto& ir_input_index : ir_input_indexes) { - auto ir_range = ir_ranges.find(ir_input_index); - if (ir_range == ir_ranges.end()) { - continue; - } - for (size_t i = ir_range->second.first; i < ir_range->second.second + ir_range->second.first; i++) { - instance_input_indexes.push_back(i); - } - } - promote_index_list.push_back(instance_input_indexes); - } - return ge::GRAPH_SUCCESS; -} -} // namespace ge -/*lint +e512 +e737 +e752*/ diff --git a/graph/utils/op_desc_utils_ex.cc b/graph/utils/op_desc_utils_ex.cc deleted file mode 100644 index 9c5e97623cad435fca606bc306bfaae47df16d97..0000000000000000000000000000000000000000 --- a/graph/utils/op_desc_utils_ex.cc +++ /dev/null @@ -1,336 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/utils/op_desc_utils_ex.h" - -#include "common/ge_common/util.h" -#include "common/util/trace_manager/trace_manager.h" -#include "graph/normal_graph/operator_impl.h" -#include "graph/operator_factory_impl.h" -#include "graph/common_error_codes.h" -#include "graph/ge_context.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/node_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/tensor_utils.h" -#include "graph/utils/transformer_utils.h" -#include "graph/utils/node_utils_ex.h" -#include "graph/utils/recover_ir_utils.h" -#include "common/util/mem_utils.h" -#include "common/checker.h" -#include "debug/ge_op_types.h" -#include "mmpa/mmpa_api.h" - -namespace ge { -namespace { -std::function TryGetV1InferFunc(const OpDescPtr &op_desc) { - auto infer_func = op_desc->GetInferFunc(); - if (infer_func != nullptr) { - return infer_func; - } - return OperatorFactoryImpl::GetInferShapeFunc(op_desc->GetType()); -} -bool EnableIgnoreInferError() { - const char_t *env_value = nullptr; - MM_SYS_GET_ENV(MM_ENV_IGNORE_INFER_ERROR, env_value); - if (env_value == nullptr) { - GELOGD("Can not get env [IGNORE_INFER_ERROR]. Disable ignore infer validation."); - return false; - } - - std::string env_str_value = std::string(env_value); - GELOGI("Got value of env[IGNORE_INFER_ERROR] is [%s].", env_str_value.c_str()); - return !env_str_value.empty(); -} -} - -graphStatus OpDescUtilsEx::CallInferFuncV2(const OpDescPtr &op_desc, Operator &op) { - const auto call_infer_data_type = OperatorFactoryImpl::GetInferDataTypeFunc(); - const auto call_infer_shape_v2 = OperatorFactoryImpl::GetInferShapeV2Func(); - const auto call_infer_shape_range = OperatorFactoryImpl::GetInferShapeRangeFunc(); - if ((call_infer_data_type == nullptr) || (call_infer_shape_v2 == nullptr) || (call_infer_shape_range == nullptr)) { - GELOGW("[Call][InferFuncV2] Node %s(%s) has no infer func v2 either v1. Please check op proto to make sure at " - "least has one.", - op_desc->GetNamePtr(), op_desc->GetTypePtr()); - return GRAPH_FAILED; - } - if (op_desc->GetIrInputs().empty() && op_desc->GetIrOutputs().empty() && op_desc->GetAllOutputsDescSize() != 0U) { - GE_CHK_STATUS_RET(RecoverIrUtils::RecoverOpDescIrDefinition(op_desc), "Failed recover ir def for %s %s", - op_desc->GetNamePtr(), op_desc->GetTypePtr()); - } - GE_WARN_ASSERT_GRAPH_SUCCESS(call_infer_data_type(op_desc), - "[Call][InferFuncV2]Failed to infer data_type of node %s[%s].", op_desc->GetNamePtr(), - op_desc->GetTypePtr()); - GE_WARN_ASSERT_GRAPH_SUCCESS(call_infer_shape_v2(op, op_desc), - "[Call][InferFuncV2]Failed to infer shape of node %s[%s].", op_desc->GetNamePtr(), - op_desc->GetTypePtr()); - GE_WARN_ASSERT_GRAPH_SUCCESS(call_infer_shape_range(op, op_desc), - "[Call][InferFuncV2]Failed to infer shape_range of node %s[%s].", op_desc->GetNamePtr(), - op_desc->GetTypePtr()); - return GRAPH_SUCCESS; -} - -graphStatus OpDescUtilsEx::CallInferFuncV1(const OpDescPtr &op_desc, Operator &op) { - NodeShapeTransUtils transformer(op_desc); - const auto is_init_success = transformer.Init(); - if (!is_init_success) { - GELOGE(GRAPH_FAILED, "[Call][Init] for transformer failed"); - return GRAPH_FAILED; - } - if (!transformer.CatchFormatAndShape()) { - GELOGE(GRAPH_FAILED, "[Call][CatchFormatAndShape] for transformer failed!"); - return GRAPH_FAILED; - } - graphStatus graph_status = GRAPH_SUCCESS; - { - const auto &node_ptr = NodeUtilsEx::GetNodeFromOperator(op); - const bool empty_name = (node_ptr == nullptr) || (node_ptr->GetOwnerComputeGraph() == nullptr); - const auto &graph_name = empty_name ? std::string("") - : node_ptr->GetOwnerComputeGraph()->GetName(); - TraceOwnerGuard guard("OP", op_desc->GetName() + ":infershape", graph_name); - auto infer_func = op_desc->GetInferFunc(); - graph_status = infer_func(op); - } - if ((graph_status != GRAPH_SUCCESS) && (graph_status != GRAPH_NODE_NEED_REPASS)) { - GELOGE(GRAPH_FAILED, "[Call][InferFuncV1] for %s(%s) failed. ret:%u", op_desc->GetNamePtr(), op_desc->GetTypePtr(), - graph_status); - return GRAPH_FAILED; - } - if (!transformer.UpdateFormatAndShape()) { - GELOGE(GRAPH_FAILED, "[Call][UpdateFormatAndShape] for transformer failed!"); - return GRAPH_FAILED; - } - return graph_status; -} - -graphStatus OpDescUtilsEx::CallInferFunc(const OpDescPtr &op_desc, Operator &op) { - GE_CHECK_NOTNULL(op_desc, ", Op is null for Infer Shape."); - graphStatus ret; - // NoOp算子有原型,无infershape, 无数据输出, 此类需要跳过infer - const bool has_io = (op_desc->GetInputsSize() != 0U || op_desc->GetOutputsSize() != 0U); - const bool need_infer = - (has_io && (OperatorFactory::IsExistOp(op_desc->GetTypePtr()) || (op_desc->GetInferFunc() != nullptr))); - if (!need_infer) { - // todo 这是一个特殊错误码,早期版本与接口调用方约定的错误码,调用方认为这不是个错误,并且会依据该错误码作额外工作 - // 如,映射到已有的IR上等行为,或无痛地跳过infershape(如netoutput/framework - // 这里暂时为了v1动态shape执行时保留该错误码,后续整改 - GELOGD("Node %s(%s) no io or no prototype so does not need infer.", op_desc->GetNamePtr(), op_desc->GetTypePtr()); - ret = GRAPH_PARAM_INVALID; - } else { - // priority of use infer func v1 - // when v2 func is ready, remove v1 func, it will automatically follow the V2 process - auto infer_func = TryGetV1InferFunc(op_desc); - bool can_support_rt1 = (infer_func != nullptr); - GELOGD("Op %s[%s] Call InferShapeFuncV%s", op_desc->GetNamePtr(), op_desc->GetTypePtr(), - can_support_rt1 ? "1" : "2"); - if (can_support_rt1) { - op_desc->AddInferFunc(infer_func); - ret = CallInferFuncV1(op_desc, op); - } else { - ret = CallInferFuncV2(op_desc, op); - // 临时方案,避免客户自定义算子交付件不完备,为了快速恢复用例,提供临时环境变量 - // 后续引导客户正确补充交付件以后,删除该环境变量 - static bool enable_fast_ignore_infer_error = EnableIgnoreInferError(); - if (enable_fast_ignore_infer_error) { - ret = (ret == GRAPH_SUCCESS) ? GRAPH_SUCCESS : GRAPH_PARAM_INVALID; - } else if (ret != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG( - "EZ9999", - "Call InferShapeAndType for node:%s(%s) failed. You can ignore this validation by exporting " - "IGNORE_INFER_ERROR=1 if necessary, but it is highly recommended to fix this problem.", - op_desc->GetNamePtr(), op_desc->GetTypePtr()); - } - } - } - - if (ret == GRAPH_SUCCESS) { - GE_ASSERT_SUCCESS(InferShapeByOutputShapesAttr(op_desc), "[Infer][ByShapeValue] failed, op = %s", - op_desc->GetNamePtr()); - } - return ret; -} - -graphStatus OpDescUtilsEx::InferShapeByOutputShapesAttr(const OpDescPtr &op_desc) { - std::vector> shape_values; - const bool got = ge::AttrUtils::GetListListInt(op_desc, ATTR_NAME_PRESET_OUTPUT_SHAPES, shape_values); - if (!got) { - GELOGD("Do not need infer op = %s by shape value, shape_values = %zu.", - op_desc->GetNamePtr(), shape_values.size()); - return GRAPH_SUCCESS; - } - GE_ASSERT_TRUE(op_desc->GetAllOutputsDescSize() == static_cast(shape_values.size()), - "op = %s has output size = %u, but shape values size = %zu.", op_desc->GetNamePtr(), - op_desc->GetAllOutputsDescSize(), shape_values.size()); - size_t output_idx = 0UL; - for (const auto &shape_value : shape_values) { - const auto &output_desc = op_desc->MutableOutputDesc(output_idx); - GE_ASSERT_NOTNULL(output_desc, "[Get][Output] failed, id = %zu, op = %s.", output_idx, op_desc->GetNamePtr()); - output_idx++; - const auto output_shape = GeShape(shape_value); - GE_ASSERT_TRUE(TensorUtils::IsShapeEqual(output_desc->GetShape(), output_shape), - "[Check][ShapeEqual] op = %s inferred shape is %s, but shape value set shape is %s, is not same.", - op_desc->GetNamePtr(), output_desc->GetShape().ToString().c_str(), output_shape.ToString().c_str()); - output_desc->SetShape(output_shape); - output_desc->SetOriginShape(output_shape); - GELOGD("Update op = %s output[%zu] shape = %s", op_desc->GetNamePtr(), output_idx, - ToString(output_shape.GetDims()).c_str()); - } - return GRAPH_SUCCESS; -} - -graphStatus OpDescUtilsEx::CallInferFormatFuncV1(const OpDescPtr &op_desc, Operator &op) { - GE_CHECK_NOTNULL(op_desc, ", Op is null for Infer Format."); - auto infer_format_func = op_desc->GetInferFormatFunc(); - if (infer_format_func != nullptr) { - return static_cast(infer_format_func(op)); - } - infer_format_func = OperatorFactoryImpl::GetInferFormatFunc(op_desc->GetType()); - if (infer_format_func == nullptr) { - return op_desc->DefaultInferFormat(); - } - op_desc->AddInferFormatFunc(infer_format_func); - return infer_format_func(op); -} - -graphStatus OpDescUtilsEx::CallInferFormatFuncV2(const OpDescPtr &op_desc, Operator &op) { - const auto call_infer_format_v2 = OperatorFactoryImpl::GetInferFormatV2Func(); - GE_ASSERT_NOTNULL(call_infer_format_v2); - if (op_desc->GetIrInputs().empty() && op_desc->GetIrOutputs().empty() && op_desc->GetAllOutputsDescSize() != 0U) { - GE_CHK_STATUS_RET(RecoverIrUtils::RecoverOpDescIrDefinition(op_desc), "Failed recover ir def for %s %s", - op_desc->GetNamePtr(), op_desc->GetTypePtr()); - } - return call_infer_format_v2(op, op_desc); -} - -graphStatus OpDescUtilsEx::CallInferFormatFunc(const OpDescPtr &op_desc, Operator &op) { - const auto is_infer_format_v2_registered_func = OperatorFactoryImpl::GetIsInferFormatV2RegisteredFunc(); - if ((is_infer_format_v2_registered_func != nullptr) && is_infer_format_v2_registered_func(op_desc)) { - GELOGI("[Call][InferFormat] call V2 func for op [%s][%s]", op_desc->GetNamePtr(), op_desc->GetTypePtr()); - return CallInferFormatFuncV2(op_desc, op); - } - GELOGI("[Call][InferFormat] call V1 func for op [%s][%s]", op_desc->GetNamePtr(), op_desc->GetTypePtr()); - return CallInferFormatFuncV1(op_desc, op); -} - -graphStatus OpDescUtilsEx::CallInferValueRangeFunc(const OpDescPtr &op_desc, Operator &op) { - GE_CHECK_NOTNULL(op_desc, ", Op is null for Infer ValueRange."); - auto infer_value_range_func = op_desc->GetInferValueRangeFunc(); - if (infer_value_range_func != nullptr) { - return static_cast(infer_value_range_func(op)); - } - - const InferValueRangePara infer_value_range_param = OperatorFactoryImpl::GetInferValueRangePara(op_desc->GetType()); - if (!infer_value_range_param.is_initialized) { - REPORT_INNER_ERR_MSG("E18888", "Node %s does not register func to infer value range.", op_desc->GetName().c_str()); - GELOGE(GRAPH_PARAM_INVALID, "Node %s does not register func to infer value range.", op_desc->GetName().c_str()); - return GRAPH_PARAM_INVALID; - } - - infer_value_range_func = infer_value_range_param.infer_value_func; - if (infer_value_range_func == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Value range infer func of node %s has been registered, but infer func is nullptr.", - op_desc->GetName().c_str()); - GELOGE(GRAPH_PARAM_INVALID, "Value range infer func of node %s has been registered, but infer func is nullptr.", - op_desc->GetName().c_str()); - return GRAPH_PARAM_INVALID; - } - op_desc->AddInferValueRangeFunc(infer_value_range_func); - return infer_value_range_func(op); -} - -graphStatus OpDescUtilsEx::OpVerify(const OpDescPtr &op_desc) { - GE_CHECK_NOTNULL(op_desc, ", Op is null for Infer Verify."); - auto verify_func = op_desc->GetVerifyFunc(); - if (verify_func == nullptr) { - verify_func = OperatorFactoryImpl::GetVerifyFunc(op_desc->GetType()); - } - if (verify_func != nullptr) { - Operator op = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - const graphStatus ret = static_cast(verify_func(op)); - op_desc->AddVerifierFunc(verify_func); - op.BreakConnect(); - return ret; - } - return GRAPH_SUCCESS; -} - -graphStatus OpDescUtilsEx::InferShapeAndType(const OpDescPtr &op_desc) { - GE_CHECK_NOTNULL(op_desc, ", Op is null for Infer Shape."); - auto infer_func = op_desc->GetInferFunc(); - if (infer_func == nullptr) { - infer_func = OperatorFactoryImpl::GetInferShapeFunc(op_desc->GetType()); - if (infer_func == nullptr) { - GELOGW("[InferShape][Check] %s does not have infer_func.", op_desc->GetName().c_str()); - /// The infer_func has not been added for each operator in the current operator information library. - /// No infer_func added operator skips the call - /// and directly uses the shape information passed down by the upper framework - return GRAPH_SUCCESS; - } - } - Operator op = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - const graphStatus ret = static_cast(infer_func(op)); - op_desc->AddInferFunc(infer_func); - op.BreakConnect(); - return ret; -} - -graphStatus OpDescUtilsEx::InferDataSlice(const OpDescPtr &op_desc) { - GE_CHECK_NOTNULL(op_desc, ", Op is null for Infer Slice."); - auto infer_data_slice_func = op_desc->GetInferDataSliceFunc(); - if (infer_data_slice_func == nullptr) { - infer_data_slice_func = OperatorFactoryImpl::GetInferDataSliceFunc(op_desc->GetType()); - if (infer_data_slice_func == nullptr) { - GELOGW("[InferDataSlice][Check] %s does not have infer data slice func.", op_desc->GetName().c_str()); - return NO_DEPENDENCE_FUNC; - } - } - Operator op = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - const graphStatus ret = static_cast(infer_data_slice_func(op)); - op_desc->AddInferDataSliceFunc(infer_data_slice_func); - op.BreakConnect(); - return ret; -} - -void OpDescUtilsEx::SetType(OpDescPtr &op_desc, const std::string &type) { - // If the type changes, IR related variables should be modified accordingly - auto op = OperatorFactory::CreateOperator("tmp", type.c_str()); - op.BreakConnect(); - - op_desc->SetType(type); - op_desc->SetIrRelated(OpDescUtils::GetOpDescFromOperator(op)); - TRACE_GEN_RECORD(TraceManager::GetTraceHeader(), "modify", TraceManager::GetOutGraphName(), - op_desc->GetName(), "type", "", "", type); -} - -void OpDescUtilsEx::ResetFuncHandle(OpDescPtr &op_desc) { - op_desc->AddInferFunc(nullptr); - op_desc->AddInferFormatFunc(nullptr); - op_desc->AddInferValueRangeFunc(nullptr); - op_desc->AddVerifierFunc(nullptr); - op_desc->AddInferDataSliceFunc(nullptr); -} - -void OpDescUtilsEx::SetTypeAndResetFuncHandle(OpDescPtr &op_desc, const std::string &type) { - SetType(op_desc, type); - ResetFuncHandle(op_desc); -} - -void OpDescUtilsEx::UpdateShapeAndDType(const GeTensorDescPtr &src, const GeTensorDescPtr &dst) { - dst->SetOriginShape(src->GetOriginShape()); - dst->SetShape(src->GetShape()); - dst->SetDataType(src->GetDataType()); - dst->SetOriginDataType(src->GetOriginDataType()); - std::vector> src_shape_range; - src->GetShapeRange(src_shape_range); - dst->SetShapeRange(src_shape_range); - dst->SetOriginShapeRange(src_shape_range); - ge::TensorUtils::SetRealDimCnt(*dst, static_cast(src->GetShape().GetDims().size())); -} -} // namespace ge diff --git a/graph/utils/op_type_utils.cc b/graph/utils/op_type_utils.cc deleted file mode 100644 index 579ed479d7a164230304fc5ce2acd63a6289a49a..0000000000000000000000000000000000000000 --- a/graph/utils/op_type_utils.cc +++ /dev/null @@ -1,149 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/utils/op_type_utils.h" -#include -#include "graph/debug/ge_op_types.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/debug/ge_util.h" - -namespace ge { -namespace { -const std::unordered_set kDataOpSet = {DATA, REFDATA, AIPPDATA, ANN_DATA}; -const std::unordered_set kVariableOpSet = {VARIABLE, VARIABLEV2}; -const std::unordered_set kAssignOpSet = { - ASSIGNADD, ASSIGN, ASSIGNSUB, ASSIGNADDVARIABLEOP, ASSIGNSUBVARIABLEOP, ASSIGNVARIABLEOP}; -const std::unordered_set kIdentityOpSet = {IDENTITY, READVARIABLEOP}; -const std::unordered_set kConstPlaceHolderOpSet = {CONSTPLACEHOLDER}; -const std::unordered_set kConstOpSet = {CONSTANT, CONSTANTOP, CONSTPLACEHOLDER}; -const std::unordered_set kGraphOutputOpSet = {NETOUTPUT}; -const std::unordered_set kAutofuseNodeSet = {ASC_BC, FUSE_ASC_BC, EMPTY_ASC_BC}; -} // namespace - -/** - * @brief 判断类型是否为Autofuse - * @param type - * @return true - * @return false - */ -bool OpTypeUtils::IsAutofuseNode(const std::string &type) { - return (kAutofuseNodeSet.count(type) > 0); -} - -/** - * @brief 判断类型是否为Autofuse - * - * @param op_desc - * @return true - * @return false - */ -bool OpTypeUtils::IsAutofuseNode(const ge::OpDescPtr &op_desc) { - return IsAutofuseNode(op_desc->GetType()); -} - -/** - * @brief 判断类型是否为空tensor的Autofuse算子 - * - * @param type - * @return true - * @return false - */ -bool OpTypeUtils::IsEmptyAutofuseNode(const std::string &type) { - return (type == EMPTY_ASC_BC); -} - -/** - * @brief 判断类型是否为DATA - * 其中不包含QueueData, 该算子原型与其他Data不同,只有输出没有输出 - * 且在编译过程中有自由逻辑,不宜一起判断。 - * - * @param type - * @return true - * @return false - */ -bool OpTypeUtils::IsDataNode(const std::string &type) { - return (kDataOpSet.count(type) > 0); -} -/** - * @brief 判断类型是否为RefDATA并且为输入节点 - * @param node - * @return true - * @return false - */ -bool OpTypeUtils::IsInputRefData(const ge::OpDescPtr &op_desc) { - if ((op_desc == nullptr) || (op_desc->GetType() != REFDATA)) { - return false; - } - return !AttrUtils::HasAttr(op_desc, REF_VAR_SRC_VAR_NAME); -} - -bool OpTypeUtils::IsVariableNode(const std::string &type) { - return (kVariableOpSet.count(type) > 0); -} - -bool OpTypeUtils::IsVarLikeNode(const std::string &type) { - return IsVariableNode(type) || (type == REFDATA); -} - -bool OpTypeUtils::IsAssignLikeNode(const std::string &type) { - return kAssignOpSet.count(type) > 0U; -} - -bool OpTypeUtils::IsIdentityLikeNode(const std::string &type) { - return kIdentityOpSet.count(type) > 0U; -} - -bool OpTypeUtils::IsConstPlaceHolderNode(const std::string &type) { - return kConstPlaceHolderOpSet.count(type) > 0U; -} - -// CONST/CONSTANT/CONSTPLACEHOLDER -bool OpTypeUtils::IsConstNode(const std::string &type) { - return kConstOpSet.count(type) > 0U; -} - -// IsDataNode/IsVariableNode/IsVarLikeNode/IsConstNode -bool OpTypeUtils::IsGraphInputNode(const std::string &type) { - return (IsDataNode(type)) || (IsVariableNode(type)) || IsVarLikeNode(type) || IsConstNode(type); -} - -// NETOUTPUT -bool OpTypeUtils::IsGraphOutputNode(const std::string &type) { - return (kGraphOutputOpSet.count(type) > 0); -} - -/** - * @brief get the Original Type of FrameworkOp - * 其中不包含QueueData, 该算子原型与其他Data不同,只有输出没有输出 - * 且在编译过程中有自由逻辑,不宜一起判断。 - * - * @param [in] node - * @param [out] type - * @return graphStatus - */ -graphStatus OpTypeUtils::GetOriginalType(const ge::OpDescPtr &op_desc, std::string &type) { - GE_CHECK_NOTNULL(op_desc); - type = op_desc->GetType(); - GE_IF_BOOL_EXEC(type != FRAMEWORKOP, return GRAPH_SUCCESS); - const bool ret = ge::AttrUtils::GetStr(op_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); - if (!ret) { - REPORT_INNER_ERR_MSG("E19999", "Get Attr:%s fail from op:%s(%s)", ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE.c_str(), - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - GELOGE(INTERNAL_ERROR, "[Get][Attr] %s fail from op:%s(%s)", ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE.c_str(), - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - return INTERNAL_ERROR; - } - GELOGD("Get FrameWorkOp original type [%s]", type.c_str()); - return GRAPH_SUCCESS; -} - -bool OpTypeUtils::IsSubgraphInnerData(const ge::OpDescPtr &op_desc) { - return ((op_desc->GetType() == DATA) && op_desc->HasAttr(ATTR_NAME_PARENT_NODE_INDEX)); -} -} // namespace ge diff --git a/graph/utils/profiler.cc b/graph/utils/profiler.cc deleted file mode 100644 index bc3a4372b1d14764bb55330fc2378bf04b441565..0000000000000000000000000000000000000000 --- a/graph/utils/profiler.cc +++ /dev/null @@ -1,156 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/utils/profiler.h" -#include -#include "mmpa/mmpa_api.h" -#include "securec.h" -#include "graph/debug/ge_log.h" -#include "graph/debug/ge_util.h" -#include "graph/def_types.h" -#include "external/graph/types.h" - -namespace ge { -namespace profiling { -namespace { -constexpr char_t kVersion[] = "1.0"; -int64_t GetThread() { - thread_local static int64_t tid = static_cast(mmGetTid()); - return tid; -} -void DumpEventType(const EventType et, std::ostream &out_stream) { - switch (et) { - case EventType::kEventStart: - out_stream << "Start"; - break; - case EventType::kEventEnd: - out_stream << "End"; - break; - case EventType::kEventTimestamp: - break; - default: - out_stream << "UNKNOWN(" << static_cast(et) << ")"; - break; - } -} -} - -void Profiler::RecordCurrentThread(const int64_t element, const int64_t event, const EventType et) { - Record(element, GetThread(), event, et, std::chrono::system_clock::now()); -} - -void Profiler::RecordCurrentThread(const int64_t element, const int64_t event, const EventType et, - const std::chrono::time_point time_point) { - Record(element, GetThread(), event, et, time_point); -} - -void Profiler::UpdateHashByIndex(const int64_t index, const uint64_t hash) { - if (index >= kMaxStrIndex) { - return; - } - PtrAdd(GetStringHashes(), static_cast(kMaxStrIndex), static_cast(index))->hash = hash; -} - -void Profiler::RegisterString(const int64_t index, const std::string &str) { - if (index >= kMaxStrIndex) { - return; - } - - // can not use strcpy_s, which will copy nothing when the length of str beyond kMaxStrLen - const auto ret = strncpy_s(PtrAdd(GetStringHashes(), - static_cast(kMaxStrIndex), static_cast(index))->str, - kMaxStrLen, str.c_str(), kMaxStrLen - 1UL); - if (ret != EN_OK) { - GELOGW("Register string failed, index %ld, str %s", index, str.c_str()); - } -} - -void Profiler::RegisterStringHash(const int64_t index, const uint64_t hash, const std::string &str) { - if (index >= kMaxStrIndex) { - return; - } - - // can not use strcpy_s, which will copy nothing when the length of str beyond kMaxStrLen - const auto ret = strncpy_s(PtrAdd(GetStringHashes(), - static_cast(kMaxStrIndex), static_cast(index))->str, - kMaxStrLen, str.c_str(), kMaxStrLen - 1UL); - if (ret != EN_OK) { - GELOGW("Register string failed, index %ld, str %s", index, str.c_str()); - } - PtrAdd(GetStringHashes(), static_cast(kMaxStrIndex), static_cast(index))->hash = hash; -} - -void Profiler::Record(const int64_t element, const int64_t thread, const int64_t event, const EventType et, - const std::chrono::time_point time_point) { - auto current_index = record_size_++; - if (current_index >= kMaxRecordNum) { - return; - } - records_[current_index] = ProfilingRecord({element, thread, event, et, time_point}); -} -void Profiler::Dump(std::ostream &out_stream) const { - if (record_size_ == 0UL) { - return; - } - size_t print_size = record_size_; - out_stream << "Profiler version: " << &kVersion[0] - << ", dump start, records num: " << print_size << std::endl; - if (print_size > records_.size()) { - out_stream << "Too many records(" << print_size << "), the records after " - << records_.size() << " will be dropped" << std::endl; - print_size = records_.size(); - } - for (size_t i = 0UL; i < print_size; ++i) { - auto &rec = records_[i]; - // in format: - out_stream << std::chrono::duration_cast(rec.timestamp.time_since_epoch()).count() << ' '; - out_stream << rec.thread << ' '; - DumpByIndex(rec.element, out_stream); - out_stream << ' '; - DumpByIndex(rec.event, out_stream); - out_stream << ' '; - DumpEventType(rec.et, out_stream); - out_stream << std::endl; - } - out_stream << "Profiling dump end" << std::endl; -} -void Profiler::DumpByIndex(const int64_t index, std::ostream &out_stream) const { - if ((index < 0) || (index >= kMaxStrIndex) || - (strnlen(PtrAdd(GetStringHashes(), - static_cast(kMaxStrIndex), - static_cast(index))->str, kMaxStrLen) == 0UL)) { - out_stream << "UNKNOWN(" << index << ")"; - } else { - out_stream << '[' << PtrAdd(GetStringHashes(), - static_cast(kMaxStrIndex), static_cast(index))->str << "]"; - } -} -Profiler::Profiler() : record_size_(0UL), records_(), indexes_to_str_hashes_() {} -void Profiler::Reset() { - // 不完全reset,indexes_to_str_hashes_还是有值的 - record_size_ = 0UL; -} -std::unique_ptr Profiler::Create() { - return ComGraphMakeUnique(); -} -size_t Profiler::GetRecordNum() const noexcept { - return record_size_; -} -const ProfilingRecord *Profiler::GetRecords() const { - return &(records_[0UL]); -} -Profiler::ConstStringHashesPointer Profiler::GetStringHashes() const { - return indexes_to_str_hashes_; -} -Profiler::StringHashesPointer Profiler::GetStringHashes() { - return indexes_to_str_hashes_; -} -Profiler::~Profiler() = default; -} -} diff --git a/graph/utils/screen_printer.cc b/graph/utils/screen_printer.cc deleted file mode 100644 index 76b406d7c37edb4b63b88b6ea859819fb268252f..0000000000000000000000000000000000000000 --- a/graph/utils/screen_printer.cc +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "common/screen_printer.h" - -#include -#include "graph/debug/ge_log.h" -#include "graph/ge_context.h" -#include "mmpa/mmpa_api.h" - -namespace ge { -namespace { -constexpr size_t kMaxLogLen = 1024U; -constexpr size_t kMaxTimeLen = 128U; -constexpr int64_t kOneThousandMs = 1000L; -constexpr const char_t *kModeDisable = "disable"; - -std::string CurrentTimeFormatStr() { - mmSystemTime_t system_time; - if (mmGetSystemTime(&system_time) != EN_OK) { - return ""; - } - mmTimeval tv; - if (mmGetTimeOfDay(&tv, nullptr) != EN_OK) { - return ""; - } - char_t format_time[kMaxTimeLen] = {}; - if (snprintf_s(format_time, kMaxTimeLen, kMaxTimeLen - 1U, "[%04d-%02d-%02d-%02d:%02d:%02d.%03ld.%03ld]", - system_time.wYear, system_time.wMonth, system_time.wDay, system_time.wHour, system_time.wMinute, - system_time.wSecond, (tv.tv_usec / kOneThousandMs), (tv.tv_usec % kOneThousandMs)) == -1) { - return ""; - } - return format_time; -} -} - -ScreenPrinter &ScreenPrinter::GetInstance() { - static ScreenPrinter instance; - return instance; -} - -void ScreenPrinter::Log(const char *fmt, ...) { - if (fmt == nullptr) { - GELOGE(FAILED, "param is nullptr and will not print message."); - return; - } - if (print_mode_ == PrintMode::DISABLE) { - return; - } - va_list va_list; - va_start(va_list, fmt); - char_t str[kMaxLogLen + 1U] = {}; - if (vsnprintf_s(str, kMaxLogLen + 1U, kMaxLogLen, fmt, va_list) == -1) { - va_end(va_list); - GELOGE(FAILED, "sprintf log failed and will not print message."); - return; - } - va_end(va_list); - - const auto &format_time = CurrentTimeFormatStr(); - if (format_time.empty()) { - GELOGE(FAILED, "construct format time failed and will not print message."); - return; - } - - const std::lock_guard lk(mutex_); - std::cout << format_time << mmGetTid() << " " << str << std::endl; - return; -} - -void ScreenPrinter::Init(const std::string &print_mode) { - if ((!print_mode.empty()) && (print_mode == kModeDisable)) { - print_mode_ = PrintMode::DISABLE; - } else { - print_mode_ = PrintMode::ENABLE; - } - GELOGD("Screen print mode:%u", print_mode_); -} -} // namespace ge diff --git a/graph/utils/tensor_utils.cc b/graph/utils/tensor_utils.cc deleted file mode 100644 index d95bf8ae4f3d9a2ace8a9d390c397baacba3d96e..0000000000000000000000000000000000000000 --- a/graph/utils/tensor_utils.cc +++ /dev/null @@ -1,426 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/utils/tensor_utils.h" - -#include - -#include "graph/debug/ge_log.h" -#include "graph/utils/type_utils.h" -#include "mmpa/mmpa_api.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/attr_utils.h" -#include "base/err_msg.h" - -namespace ge { -namespace { -// Unknown shape element num -const int64_t kElementCntUnknownShape = -1; - -// Unknown shape mem size -const int64_t kUnknownShapeMemSize = -1; - -// Nchw and nhwc dim size must be 4 -const uint32_t kDimSize4d = 4U; - -// C1HWNCoC0 dim size must be 6 -const uint32_t kDimSizeC1hwncoc0 = 6U; - -const int64_t kDataMemAlignSize = 32; -const int64_t kNum2 = 2; - -const char_t *const kShapeRangeInvalid = "format of shape range is invalid"; -const char_t *const kShapeRangeSample = "\"[1~20,3,3~6,-1]\""; -} // namespace - -static bool CheckMultiplyOverflowInt64(const int64_t &a, const int64_t &b) { - if (a > 0) { - if (b > 0) { - if (a > (std::numeric_limits::max() / b)) { - return true; - } - } else { - if (b < (std::numeric_limits::min() / a)) { - return true; - } - } - } else { - if (b > 0) { - if (a < (std::numeric_limits::min() / b)) { - return true; - } - } else { - if ((a != 0) && (b < (std::numeric_limits::max() / a))) { - return true; - } - } - } - return false; -} - -/// -/// Calculate element num by dims directly. -/// @param dims dim info -/// @param element_cnt element count -/// @return GRAPH_SUCCESS:success -/// other:failed -/// -static graphStatus CalcElementCntByDims(const std::vector &dims, int64_t &element_cnt) { - element_cnt = 1; - for (const int64_t dim : dims) { - if (CheckMultiplyOverflowInt64(element_cnt, dim)) { - REPORT_INNER_ERR_MSG("E18888", "result will overflow when multiplying %" PRId64 " and %" PRId64 ".", element_cnt, - dim); - GELOGE(GRAPH_FAILED, - "[Check][Overflow] CalcElementCntByDims failed, when multiplying %" PRId64 " and %" PRId64 ".", - element_cnt, dim); - return GRAPH_FAILED; - } - element_cnt *= dim; - } - return GRAPH_SUCCESS; -} - -/// -/// Calculate fixed dims element num. -/// @param dims dim info -/// @param fixed_dim_size fixed dim size -/// @param element_cnt element count -/// @return GRAPH_SUCCESS:success -/// other:failed -/// -static graphStatus CalcElementCntOfFixedDims(const std::vector &dims, const Format format, - const uint32_t fixed_dim_size, int64_t &element_cnt) { - if (dims.size() != fixed_dim_size) { - GELOGD("[Util][CalcElemCnt] Format %d(%s) need dim size=%u but %zu, calc as ND.", - format, TypeUtils::FormatToSerialString(format).c_str(), fixed_dim_size, dims.size()); - } - return CalcElementCntByDims(dims, element_cnt); -} - -static graphStatus GetMaxShapeDimsFromNoTilingTensor(const GeTensorDesc &tensor_desc, - std::vector &output_dims) { - const auto &shape = tensor_desc.GetShape(); - const std::vector &dims = shape.GetDims(); - std::vector max_shape_list; - // use the max shape set by user - const bool has_attr = AttrUtils::GetListInt(tensor_desc, ATTR_NAME_TENSOR_MAX_SHAPE, max_shape_list); - if (has_attr) { - if (max_shape_list.size() == dims.size()) { - output_dims = std::move(max_shape_list); - return GRAPH_SUCCESS; - } - REPORT_INNER_ERR_MSG("E18888", "invalid input shape range."); - GELOGE(PARAM_INVALID, "[Check][Param]tensor invalid max_shape_list size[%zu], dim size[%zu].", - max_shape_list.size(), dims.size()); - return PARAM_INVALID; - } - // if max shape attr not set, use shape range - std::vector> range; - const graphStatus graph_status = tensor_desc.GetShapeRange(range); - if (graph_status != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Get shape range failed."); - GELOGE(PARAM_INVALID, "[Check][Param] GetShapeRange failed."); - return graph_status; - } - if (dims.size() != range.size()) { - REPORT_INNER_ERR_MSG("E18888", "Error shape range size."); - GELOGE(PARAM_INVALID, "[Check][Param] size not matched dims_size[%zu] range_size[%zu].", dims.size(), range.size()); - return PARAM_INVALID; - } - for (size_t i = 0U; i < dims.size(); ++i) { - const int64_t dim = (dims[i] < 0) ? range[i].second : dims[i]; - output_dims.push_back(dim); - } - return GRAPH_SUCCESS; -} - -/// Calculate tensor element num. -/// @param dims dim info -/// @param format tensor format -/// @param data_type data type -/// @param element_cnt element count -/// @return GRAPH_SUCCESS:success -/// other:failed -/// -static graphStatus CalcTensorElementCnt(const std::vector &dims, const Format format, const DataType data_type, - int64_t &element_cnt) { - const std::string format_str = TypeUtils::FormatToSerialString(format); - // Check dims - for (size_t i = 0U; i < dims.size(); ++i) { - const int64_t dim = dims[i]; - if (dim < 0) { - GELOGI("It's unknown shape, as dims[%zu]=%" PRId64 " negative, format=%d(%s).", i, dim, format, - format_str.c_str()); - element_cnt = kElementCntUnknownShape; - return GRAPH_SUCCESS; - } else if (dim == 0) { - GELOGI("No need calc element count, as dims[%zu]=%" PRId64 ", format=%d(%s).", i, dim, format, - format_str.c_str()); - element_cnt = 0; - return GRAPH_SUCCESS; - } else { - // else branch - } - } - - graphStatus graph_status; - switch (GetPrimaryFormat(format)) { - case FORMAT_ND: - case FORMAT_MD: - graph_status = CalcElementCntByDims(dims, element_cnt); - break; - case FORMAT_NCHW: - case FORMAT_HWCN: - case FORMAT_NHWC: - case FORMAT_CHWN: - case FORMAT_C1HWC0: - graph_status = CalcElementCntOfFixedDims(dims, format, kDimSize4d, element_cnt); - break; - case FORMAT_C1HWNCoC0: - graph_status = CalcElementCntOfFixedDims(dims, format, kDimSizeC1hwncoc0, element_cnt); - break; - case FORMAT_NC1HWC0: - case FORMAT_FRACTAL_Z: - case FORMAT_FILTER_HWCK: - case FORMAT_FRACTAL_NZ: - case FORMAT_FRACTAL_NZ_C0_16: - case FORMAT_FRACTAL_NZ_C0_32: - case FORMAT_FRACTAL_ZZ: - case FORMAT_NDHWC: - case FORMAT_NCDHW: - case FORMAT_DHWCN: - case FORMAT_DHWNC: - case FORMAT_FRACTAL_Z_3D: - case FORMAT_FRACTAL_Z_3D_TRANSPOSE: - case FORMAT_NDC1HWC0: - case FORMAT_FRACTAL_Z_C04: - case FORMAT_FRACTAL_ZN_LSTM: - case FORMAT_NC1HWC0_C04: - case FORMAT_ND_RNN_BIAS: - case FORMAT_FRACTAL_ZN_RNN: - case FORMAT_NYUV: - case FORMAT_NYUV_A: - case FORMAT_NCL: - case FORMAT_FRACTAL_Z_WINO: - graph_status = CalcElementCntByDims(dims, element_cnt); - break; - default: - REPORT_INNER_ERR_MSG("E18888", "unsupported format, format=%d(%s).", format, format_str.c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] unsupported format, format=%d(%s).", format, format_str.c_str()); - graph_status = GRAPH_FAILED; - break; - } - - const std::string type_str = TypeUtils::DataTypeToSerialString(data_type); - if (graph_status == GRAPH_SUCCESS) { - GELOGD( - "CalcTensorElementCnt end, format=%d(%s), data_type=%d(%s), element_cnt=%" PRId64 ".", - format, format_str.c_str(), data_type, type_str.c_str(), element_cnt); - } else { - REPORT_INNER_ERR_MSG("E18888", "CalcTensorElementCnt failed, format=%d(%s), data_type=%d(%s).", format, - format_str.c_str(), data_type, type_str.c_str()); - GELOGE(GRAPH_FAILED, "[Calc][TensorElementCnt] failed, format=%d(%s), data_type=%d(%s).", - format, format_str.c_str(), data_type, type_str.c_str()); - } - return graph_status; -} - -/// -/// Calculate tensor mem size. -/// @param shape tensor shape -/// @param format tensor format -/// @param data_type tensor data type -/// @param mem_size -1 means unknown shape,other means mem size -/// @return GRAPH_SUCCESS:success, other:failed -/// -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::CalcTensorMemSize(const GeShape &shape, - const Format format, - const DataType data_type, - int64_t &mem_size) { - const std::string format_str = TypeUtils::FormatToSerialString(format); - const std::string type_str = TypeUtils::DataTypeToSerialString(data_type); - - const std::vector dims = shape.GetDims(); - int64_t element_cnt = 0; - const graphStatus status = CalcTensorElementCnt(dims, format, data_type, element_cnt); - if (status != GRAPH_SUCCESS) { - GELOGE(status, "[Calc][TensorElementCnt] failed, shape[%s], status=%u format=%d(%s) data_type=%d(%s).", - shape.ToString().c_str(), status, format, format_str.c_str(), data_type, type_str.c_str()); - return status; - } - // Support unknown shape - if (element_cnt < 0) { - mem_size = kUnknownShapeMemSize; - GELOGD("element_cnt is unknown. shape[%s], format=%d(%s), data_type=%d(%s), mem_size=%" PRId64, - shape.ToString().c_str(), format, format_str.c_str(), data_type, type_str.c_str(), mem_size); - return GRAPH_SUCCESS; - } - - if ((data_type == DT_STRING) || (data_type == DT_STRING_REF)) { - uint32_t type_size = 0U; - const bool result = TypeUtils::GetDataTypeLength(data_type, type_size); - if (!result) { - REPORT_INNER_ERR_MSG("E18888", "GetDataTypeLength failed, data_type=%d(%s).", data_type, type_str.c_str()); - GELOGE(GRAPH_FAILED, "[Get][DataTypeLength] failed, data_type=%d(%s).", data_type, type_str.c_str()); - return GRAPH_FAILED; - } - const auto type_size_int64 = static_cast(type_size); - if (CheckMultiplyOverflowInt64(element_cnt, type_size_int64)) { - REPORT_PREDEFINED_ERR_MSG( - "E19013", std::vector({"function", "var1", "var2"}), - std::vector( - {"CheckMultiplyOverflowInt64", std::to_string(element_cnt).c_str(), std::to_string(type_size).c_str()})); - GELOGE(GRAPH_FAILED, - "[Check][Overflow] multiply %" PRId64 " and %u, shape[%s], format=%d(%s), data_type=%d(%s).", - element_cnt, type_size, shape.ToString().c_str(), format, format_str.c_str(), data_type, type_str.c_str()); - return GRAPH_FAILED; - } - mem_size = element_cnt * type_size_int64; - } else { - mem_size = ge::GetSizeInBytes(element_cnt, data_type); - } - - GELOGD("shape[%s], format=%d(%s), data_type=%d(%s), mem_size=%" PRId64, - shape.ToString().c_str(), format, format_str.c_str(), data_type, type_str.c_str(), mem_size); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -TensorUtils::GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) { - const graphStatus graph_status = GetTensorSizeInBytes(desc_temp, size_temp); - if (graph_status != GRAPH_SUCCESS) { - return GRAPH_FAILED; - } - - // 64-byte alignment, if size is 0, align to 32 bytes - if (size_temp > (std::numeric_limits::max() - (kNum2 * kDataMemAlignSize))) { - GELOGW("[Util][CalcBytesSize] Mem size %" PRId64 " after alignment is bigger than INT64_MAX", size_temp); - } else { - size_temp = ((size_temp + (kNum2 * kDataMemAlignSize) - 1) / kDataMemAlignSize) * kDataMemAlignSize; - } - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -TensorUtils::CalcTensorMemSizeForNoTiling(const GeTensorDesc &tensor, const Format format, - const DataType data_type, int64_t &mem_size) { - if (tensor.GetShape().IsUnknownShape()) { - std::vector dims; - GE_CHK_STATUS_RET(GetMaxShapeDimsFromNoTilingTensor(tensor, dims), - "[Calc][GetMaxShapeDimsFromNoTilingTensor] failed."); - return CalcTensorMemSize(GeShape(dims), format, data_type, mem_size); - } - return CalcTensorMemSize(tensor.GetShape(), format, data_type, mem_size); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -TensorUtils::GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) { - const Format format = desc_temp.GetFormat(); - const DataType data_type = desc_temp.GetDataType(); - int64_t output_mem_size = 0; - - bool is_no_tiling = false; - (void)AttrUtils::GetBool(desc_temp, ATTR_NAME_TENSOR_NO_TILING_MEM_TYPE, is_no_tiling); - graphStatus graph_status; - if (is_no_tiling) { - graph_status = CalcTensorMemSizeForNoTiling(desc_temp, format, data_type, output_mem_size); - } else { - graph_status = CalcTensorMemSize(desc_temp.GetShape(), format, data_type, output_mem_size); - } - if (graph_status != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Calc][TensorMemSize] failed! type:%s, is_no_tiling:%s", - TypeUtils::DataTypeToSerialString(data_type).c_str(), is_no_tiling ? "true" : "false"); - return GRAPH_FAILED; - } - - if (output_mem_size < 0) { - REPORT_INNER_ERR_MSG("E18888", - "After calc concat tensor memory size, output_mem_size = %" PRId64 "," - " out of data range [0, %" PRId64 "]", - output_mem_size, std::numeric_limits::max()); - GELOGW("[Check][Param] After calc concat tensor memory size, " - "output_mem_size = %" PRId64 ", out of data range [0, %" PRId64 "]", - output_mem_size, std::numeric_limits::max()); - return GRAPH_FAILED; - } - - size_temp = output_mem_size; - return GRAPH_SUCCESS; -} -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -TensorUtils::CheckShapeByShapeRange(const GeShape &shape, const std::vector> &shape_range) { - if ((shape.GetDimNum() == 0U) || shape_range.empty()) { - GELOGD(" Shape or shape range is empty, no need to check."); - return GRAPH_SUCCESS; - } - if (shape.GetDimNum() != shape_range.size()) { - REPORT_PREDEFINED_ERR_MSG("E10049", std::vector({"shape_range_size", "cur_dim_size"}), - std::vector({std::to_string(shape_range.size()).c_str(), - std::to_string(shape.GetDimNum()).c_str()})); - GELOGE(PARAM_INVALID, "[Check][Param] Given shape_range dim num [%zu] and current dim num [%zu] are not match. " - "Please check", shape_range.size(), shape.GetDimNum()); - return PARAM_INVALID; - } - - for (size_t idx = 0U; idx < shape.GetDimNum(); idx++) { - const auto cur_dim = shape.GetDim(idx); - if (cur_dim == UNKNOWN_DIM) { - GELOGD("[Check][InputShape]cur shape dim [%" PRId64 "] is dynamic, no need to check.", cur_dim); - continue; - } - const auto left_range = shape_range[idx].first; - const auto right_range = shape_range[idx].second; - if (left_range < 0) { - const std::string error_range = std::to_string(left_range) + " ~ " + std::to_string(right_range); - REPORT_PREDEFINED_ERR_MSG( - "E10048", std::vector({"shape_range", "reason", "sample"}), - std::vector({error_range.c_str(), kShapeRangeInvalid, kShapeRangeSample})); - GELOGE(PARAM_INVALID, "[Check][Param] Given shape range[%s] is invalid, reason: %s, correct sample is %s.", - error_range.c_str(), kShapeRangeInvalid, kShapeRangeSample); - return PARAM_INVALID; - } - - if (cur_dim < left_range) { - REPORT_PREDEFINED_ERR_MSG( - "E10050", std::vector({"cur_dim", "shape_range_left", "shape_range_right"}), - std::vector({std::to_string(cur_dim).c_str(), std::to_string(left_range).c_str(), - std::to_string(right_range).c_str()})); - GELOGE(PARAM_INVALID, "[Check][Param] Current dim shape [%" PRId64 "] is out of " - "shape range [%" PRId64 "~%" PRId64 "]. Please check.", - cur_dim, left_range, right_range); - return PARAM_INVALID; - } - - if (right_range < 0) { - if (right_range != UNKNOWN_DIM) { - const std::string error_range = std::to_string(left_range) + " ~ " + std::to_string(right_range); - REPORT_PREDEFINED_ERR_MSG( - "E10048", std::vector({"shape_range", "reason", "sample"}), - std::vector({error_range.c_str(), kShapeRangeInvalid, kShapeRangeSample})); - GELOGE(PARAM_INVALID, "[Check][Param] Given shape range[%s] is invalid, reason: %s, correct sample is %s.", - error_range.c_str(), kShapeRangeInvalid, kShapeRangeSample); - return PARAM_INVALID; - } - } else { - if (cur_dim > right_range) { - REPORT_PREDEFINED_ERR_MSG( - "E10050", std::vector({"cur_dim", "shape_range_left", "shape_range_right"}), - std::vector({std::to_string(cur_dim).c_str(), std::to_string(left_range).c_str(), - std::to_string(right_range).c_str()})); - GELOGE(PARAM_INVALID, "[Check][Param] Current dim shape [%" PRId64 "] is out of " - "shape range [%" PRId64 "~%" PRId64 "]. Please check.", - cur_dim, left_range, right_range); - return PARAM_INVALID; - } - } - } - return GRAPH_SUCCESS; -} -} // namespace ge diff --git a/graph/utils/trace/trace_manager.cc b/graph/utils/trace/trace_manager.cc deleted file mode 100644 index 880fed73aba307f1a215d4bec185ce8f7851b2c1..0000000000000000000000000000000000000000 --- a/graph/utils/trace/trace_manager.cc +++ /dev/null @@ -1,274 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "common/util/trace_manager/trace_manager.h" - -#include -#include -#include -#include -#include - -#include "mmpa/mmpa_api.h" -#include "graph/debug/ge_util.h" -#include "graph/ge_context.h" -#include "graph/utils/file_utils.h" - -namespace ge { -namespace { -class TraceFileHolder { - public: - explicit TraceFileHolder(int32_t fd) : fd_(fd) {} - TraceFileHolder(TraceFileHolder const &) = delete; - TraceFileHolder &operator=(TraceFileHolder const &) = delete; - ~TraceFileHolder() { - if (fd_ >= 0) { - (void)mmClose(fd_); - fd_ = -1; - } - } - - void Write(const char_t *data, const char *separator = "\r\n") const { - if (fd_ >= 0) { - const mmSsize_t written_count = mmWrite(fd_, const_cast(data), strlen(data)); - if ((written_count == EN_INVALID_PARAM) || (written_count == EN_ERROR)) { - GELOGE(INTERNAL_ERROR, "[trace] Failed write trace info to file %s", data); - } - (void) mmWrite(fd_, const_cast(separator), strlen(separator)); - } - } - - bool Valid() const { - return fd_ >= 0; - } - - private: - int32_t fd_; -}; - -std::string CurrentTimeInSecondsStr() { - mmSystemTime_t sysTime; - if (mmGetSystemTime(&sysTime) != EN_OK) { - GELOGE(INTERNAL_ERROR, "Get current time failed"); - const static std::string kInvalidTimeStr; - return kInvalidTimeStr; - } - - std::stringstream ss; - ss << sysTime.wYear << sysTime.wMonth << sysTime.wDay << sysTime.wHour << sysTime.wMinute << sysTime.wSecond; - return ss.str(); -} - -constexpr uint64_t kTraceSaveArraySize = (kTraceSaveTriggerNum << 1U); -constexpr uint64_t kTraceSaveCountsPerFile = 2000000U; -} // namespace - -thread_local std::string TraceManager::trace_header_; -thread_local std::string TraceManager::graph_name_; - -TraceManager &TraceManager::GetInstance() { - static TraceManager instance; - return instance; -} - -// Set owner -void TraceManager::SetTraceOwner(const std::string &owner, const std::string &stage, const std::string &graph_name) { - if (!enabled_) { - return; - } - trace_header_ = owner + ":" + stage; - graph_name_ = graph_name; -} - -// Clear owner -void TraceManager::ClearTraceOwner() { - if (!enabled_) { - return; - } - trace_header_.clear(); - graph_name_.clear(); -} - -std::string TraceManager::NextFileName() { - static std::atomic uuid(0U); - - std::stringstream ss; - // need 3 widths to express uuid - ss << trace_save_file_path_ << "trace_" << CurrentTimeInSecondsStr() << "_" << std::setw(3) << std::setfill('0') - << uuid++ << ".txt"; - - return ss.str(); -} - -std::unique_ptr OpenOrCreateFile(const std::string &file_path) { - if (strnlen(file_path.c_str(), MMPA_MAX_PATH) >= MMPA_MAX_PATH) { - GELOGE(PATH_INVALID, "[trace] Trace file name %s exceed max length %u", file_path.c_str(), - static_cast(MMPA_MAX_PATH)); - return nullptr; - } - - char_t real_path[MMPA_MAX_PATH] = {}; - if (mmRealPath(file_path.c_str(), &real_path[0], MMPA_MAX_PATH) != EN_OK) { - GELOGI("[trace] Create new trace file %s", file_path.c_str()); - } - - const static auto kFlag = static_cast(static_cast(M_WRONLY) | static_cast(M_CREAT) | - static_cast(M_APPEND)); - const static auto kMode = static_cast(static_cast(M_IRUSR) | static_cast(M_IWUSR)); - - return ComGraphMakeUnique(mmOpen2(&real_path[0], kFlag, kMode)); -} - -void TraceManager::SaveTraceBufferToFile(const ReadyPart ready_part) { - if (ready_part == ReadyPart::None) { - return; - } - - ScopeGuard guard([this, ready_part]() { - // Saved count must update for un-block add tracing thread - if (ready_part == ReadyPart::A) { - part1_ready_nums_ = 0U; - } else { - part2_ready_nums_ = 0U; - } - // Must update save nums after clear part ready nums - total_saved_nums_ += kTraceSaveTriggerNum; - }); - - if (current_saving_file_name_.empty() || (current_file_saved_nums_ >= kTraceSaveCountsPerFile)) { - current_saving_file_name_ = NextFileName(); - current_file_saved_nums_ = 0U; - } - - auto fh = OpenOrCreateFile(current_saving_file_name_); - if (fh == nullptr || (!fh->Valid())) { - GELOGE(INTERNAL_ERROR, "[trace] Failed get file holder for %s", current_saving_file_name_.c_str()); - return; - } - - while (((ready_part == ReadyPart::A) && (part1_ready_nums_ < kTraceSaveTriggerNum)) || - ((ready_part == ReadyPart::B) && (part2_ready_nums_ < kTraceSaveTriggerNum))) { - } - const size_t start = (ready_part == ReadyPart::A) ? 0U : kTraceSaveTriggerNum; - for (size_t i = start; i < (start + kTraceSaveTriggerNum); i++) { - if (!trace_array_[i].empty()) { - current_file_saved_nums_++; - fh->Write(trace_array_[i].c_str()); - } - } -} - -void TraceManager::SaveBufferToFileThreadFunc() { - (void)pthread_setname_np(pthread_self(), "ge_trace_savbuf"); - while (true) { - std::unique_lock lock_file(mu_); - while ((ready_part_ == ReadyPart::None) && (!stopped_)) { - data_ready_var_.wait(lock_file); - } - if (stopped_ && (ready_part_ == ReadyPart::None)) { // Keep save remain trace even request stop - break; - } - const auto ready_part = ready_part_; - ready_part_ = ReadyPart::None; - lock_file.unlock(); - - SaveTraceBufferToFile(ready_part); - } -} - -Status TraceManager::Initialize(const char_t *file_save_path) { - // init data - std::stringstream ss; - ss << file_save_path << MMPA_PATH_SEPARATOR_STR << "extra-info" << MMPA_PATH_SEPARATOR_STR << "graph_trace" - << MMPA_PATH_SEPARATOR_STR << ge::GetContext().DeviceId() << MMPA_PATH_SEPARATOR_STR; - trace_save_file_path_ = ss.str(); - if (CreateDir(trace_save_file_path_) != 0) { - GELOGE(INTERNAL_ERROR, "[trace] Trace not enabled as failed create trace file save directory[%s]", - trace_save_file_path_.c_str()); - return FAILED; - } - trace_array_.resize(kTraceSaveTriggerNum << 1U); - try { - save_thread_ = std::thread(&TraceManager::SaveBufferToFileThreadFunc, this); - } catch (const std::system_error &) { - GELOGE(INTERNAL_ERROR, "[trace] Trace not enabled as failed start trace saving thread"); - return FAILED; - } - return SUCCESS; -} - -void TraceManager::Finalize() { - std::thread([this]() { - (void)pthread_setname_np(pthread_self(), "ge_trace_final"); - // Trigger save for left trace info, trace added when or after dtor may lose - for (size_t i = 1; i < kTraceSaveTriggerNum; i++) { - AddTrace(""); - } - }).join(); - // After join the thread above, remain trace must have trigger save part A or B - std::unique_lock lk(mu_); - stopped_ = true; // stopping record any new trace here - data_ready_var_.notify_all(); - lk.unlock(); - if (save_thread_.joinable()) { - save_thread_.join(); - } -} - -TraceManager::TraceManager() { - const char_t *trace_env_path = nullptr; - MM_SYS_GET_ENV(MM_ENV_NPU_COLLECT_PATH, trace_env_path); - enabled_ = (trace_env_path != nullptr) && (trace_env_path[0U] != '\0'); - if (!enabled_) { - GELOGI("[trace] Trace not enabled as env 'NPU_COLLECT_PATH' not set"); - return; - } - - if (Initialize(trace_env_path) != SUCCESS) { - enabled_ = false; - GELOGE(INTERNAL_ERROR, "[trace] Trace not enabled as initialize failed"); - } -} - -TraceManager::~TraceManager() { - if (!enabled_) { - return; - } - Finalize(); -} - -void TraceManager::AddTrace(std::string &&trace_info) { - if (!enabled_) { - return; - } - // Assume kTraceSaveArraySize = 2 * kTraceSaveTriggerNum - const auto current_trace_nums = trace_index_.fetch_add(1); - // blocking when almost full to prevent re-trigger save - const static uint64_t kLeftNumTriggerBlock = 1U; - while (((current_trace_nums - total_saved_nums_) >= (kTraceSaveArraySize - kLeftNumTriggerBlock)) && (!stopped_)) { - } - if (stopped_) { // Drop trace after request stopping - return; - } - const auto index = current_trace_nums % kTraceSaveArraySize; - trace_array_[index] = std::move(trace_info); - if (index < kTraceSaveTriggerNum) { - part1_ready_nums_++; - } else { - part2_ready_nums_++; - } - // assume kTraceSaveTriggerNum is an aliquot part of kTraceSaveArraySize - if ((index + 1U) % kTraceSaveTriggerNum == 0) { - std::unique_lock lk(mu_); - ready_part_ = (index < kTraceSaveTriggerNum) ? ReadyPart::A : ReadyPart::B; - lk.unlock(); - data_ready_var_.notify_all(); - } -} -} // namespace ge diff --git a/graph/utils/transformer_utils.cc b/graph/utils/transformer_utils.cc deleted file mode 100644 index 9a2f17024602ca5a8abe199e02670f6983aff161..0000000000000000000000000000000000000000 --- a/graph/utils/transformer_utils.cc +++ /dev/null @@ -1,204 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "transformer_utils.h" - -#include "common/ge_common/debug/ge_log.h" -#include "graph/utils/type_utils.h" -#include "graph/utils/attr_utils.h" -#include "inc/graph/debug/ge_attr_define.h" -#include "expand_dimension.h" -#include "transfer_shape_according_to_format.h" - -namespace ge { -namespace { -bool OriginShapeInitialized(const GeTensorDescPtr &tensor_desc) { - // The caller guarantees that the pointer is not null - if (!tensor_desc->GetOriginShape().IsScalar()) { - return true; - } - return tensor_desc->IsOriginShapeInitialized(); -} -bool SameCurrentAndOrigin(const GeTensorDescPtr &tensor_desc) { - // The caller guarantees that the pointer is not null - if (tensor_desc->GetFormat() == tensor_desc->GetOriginFormat()) { - if (tensor_desc->GetShape() == tensor_desc->GetOriginShape()) { - return true; - } - return !OriginShapeInitialized(tensor_desc); - } - return false; -} -} -bool NodeShapeTransUtils::Init() { - if (op_desc_ == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "op_desc_ is nullptr, check invalid."); - GELOGE(GRAPH_FAILED, "[Check][Param] input op_desc_ is nullptr!"); - return false; - } - in_num_ = op_desc_->MutableAllInputName().size(); - out_num_ = op_desc_->MutableAllOutputName().size(); - map_format_in_.resize(in_num_, FORMAT_RESERVED); - map_ori_format_in_.resize(in_num_, FORMAT_RESERVED); - map_dtype_in_.resize(in_num_, DT_UNDEFINED); - map_format_out_.resize(out_num_, FORMAT_RESERVED); - map_ori_format_out_.resize(out_num_, FORMAT_RESERVED); - map_dtype_out_.resize(out_num_, DT_UNDEFINED); - return true; -} -bool NodeShapeTransUtils::CatchFormatAndShape() { - for (size_t i = 0UL; i < in_num_; i++) { - const auto tensor_desc_input = op_desc_->MutableInputDesc(static_cast(i)); - if (tensor_desc_input == nullptr) { - continue; - } - const auto format = tensor_desc_input->GetFormat(); - const auto ori_format = tensor_desc_input->GetOriginFormat(); - if ((format == ori_format) && - (tensor_desc_input->GetShape() == tensor_desc_input->GetOriginShape())) { - GELOGD("Node is %s, input tensor idx is %zu. ori format: %s, format: %s, ori shape:%s, shape:%s is same! " - "No need to catch format&shape!", op_desc_->GetName().c_str(), i, - TypeUtils::FormatToSerialString(ori_format).c_str(), - TypeUtils::FormatToSerialString(format).c_str(), - tensor_desc_input->GetOriginShape().ToString().c_str(), - tensor_desc_input->GetShape().ToString().c_str()); - continue; - } - map_format_in_[i] = format; - map_ori_format_in_[i] = ori_format; - map_dtype_in_[i] = tensor_desc_input->GetDataType(); - tensor_desc_input->SetFormat(ori_format); - tensor_desc_input->SetShape(tensor_desc_input->GetOriginShape()); - } - - for (size_t i = 0UL; i < out_num_; i++) { - const auto tensor_desc_output = op_desc_->MutableOutputDesc(static_cast(i)); - if (tensor_desc_output == nullptr) { - continue; - } - const auto format = tensor_desc_output->GetFormat(); - const auto ori_format = tensor_desc_output->GetOriginFormat(); - if (SameCurrentAndOrigin(tensor_desc_output)) { - GELOGD("Node is %s, output tensor idx is %zu. ori format: %s, format: %s, ori shape:%s, shape:%s is same!" - "or output original not initialized. No need to catch format&shape!", op_desc_->GetName().c_str(), i, - TypeUtils::FormatToSerialString(ori_format).c_str(), - TypeUtils::FormatToSerialString(format).c_str(), - tensor_desc_output->GetOriginShape().ToString().c_str(), - tensor_desc_output->GetShape().ToString().c_str()); - continue; - } - map_format_out_[i] = format; - map_ori_format_out_[i] = ori_format; - map_dtype_out_[i] = tensor_desc_output->GetDataType(); - - if (format == ori_format) { - continue; - } - tensor_desc_output->SetFormat(ori_format); - } - - return true; -} - -bool NodeShapeTransUtils::UpdateFormatAndShape() { - transformer::ShapeTransferAccordingToFormat shape_transfer; - for (size_t i = 0UL; i < in_num_; i++) { - const auto tensor_desc_input = op_desc_->MutableInputDesc(static_cast(i)); - if (tensor_desc_input == nullptr) { - continue; - } - // if can not find saved info, it says format and origin format is same when catched - if (map_format_in_[i] == FORMAT_RESERVED) { - GELOGD("Node is [%s], input tensor idx [%zu] is not been catched.Skip update action for it!", - op_desc_->GetName().c_str(), i); - tensor_desc_input->SetOriginFormat(tensor_desc_input->GetFormat()); - tensor_desc_input->SetOriginShape(tensor_desc_input->MutableShape()); - continue; - } - const auto ori_format = tensor_desc_input->GetFormat(); - auto &ori_shape = tensor_desc_input->MutableShape(); - const auto curr_format = map_format_in_[i]; - if (curr_format == FORMAT_ND) { - continue; - } - const ge::DataType dtype = map_dtype_in_[i]; - - // FE set and Ge get for PadDimention - std::string infer_reshape_type; - (void) AttrUtils::GetStr(*tensor_desc_input, ATTR_NAME_RESHAPE_INFER_TYPE, infer_reshape_type); - const bool is_success = transformer::ExpandDimension(op_desc_->GetType(), ori_format, curr_format, i, - infer_reshape_type, ori_shape); - if (!is_success) { - REPORT_INNER_ERR_MSG("E18888", "ExpandDimension failed, op type:%s", op_desc_->GetType().c_str()); - GELOGE(GRAPH_FAILED, "[Call][ExpandDimension] failed, op type:%s", op_desc_->GetType().c_str()); - return false; - } - transformer::ShapeAndFormat shape_and_format_info {ori_shape, ori_format, curr_format, dtype}; - (void)shape_transfer.GetShapeAccordingToFormat(op_desc_, shape_and_format_info); - tensor_desc_input->SetFormat(curr_format); - } - - for (size_t i = 0UL; i < out_num_; i++) { - const auto tensor_desc_output = op_desc_->MutableOutputDesc(static_cast(i)); - if (tensor_desc_output == nullptr) { - continue; - } - // if can not find saved info, it says format and origin format is same when catched - if (map_ori_format_out_[i] == FORMAT_RESERVED) { - GELOGD("Node is [%s], output tensor idx [%zu] is not been catched.Skip update action for it!", - op_desc_->GetName().c_str(), i); - tensor_desc_output->SetOriginFormat(tensor_desc_output->GetFormat()); - tensor_desc_output->SetOriginShape(tensor_desc_output->MutableShape()); - continue; - } - auto &ori_shape = tensor_desc_output->MutableShape(); - const auto curr_format = tensor_desc_output->GetFormat(); - if (curr_format != map_ori_format_out_[i]) { - REPORT_INNER_ERR_MSG("E18888", - "Node is %s, out tensor idx is %zu. format: %s, " - "recorded origin format: %s is not same", - op_desc_->GetName().c_str(), i, TypeUtils::FormatToSerialString(curr_format).c_str(), - TypeUtils::FormatToSerialString(map_ori_format_out_[i]).c_str()); - GELOGE(GRAPH_FAILED, "[Check][Param] Node is %s, out tensor idx is %zu. format: %s, " - "recorded origin format: %s is not same", op_desc_->GetName().c_str(), i, - TypeUtils::FormatToSerialString(curr_format).c_str(), - TypeUtils::FormatToSerialString(map_ori_format_out_[i]).c_str()); - return false; - } - tensor_desc_output->SetOriginShape(ori_shape); - const auto saved_format = map_format_out_[i]; - if (saved_format == FORMAT_ND) { - GELOGD("Node is %s, out tensor idx is %zu. ori format: %s, recorded format: %s is same! No need to transfer", - op_desc_->GetName().c_str(), i, TypeUtils::FormatToSerialString(curr_format).c_str(), - TypeUtils::FormatToSerialString(saved_format).c_str()); - continue; - } - tensor_desc_output->SetFormat(saved_format); - const ge::DataType dtype = tensor_desc_output->GetDataType(); - - // FE set and Ge get for PadDimention - std::string infer_reshape_type; - (void) AttrUtils::GetStr(*tensor_desc_output, ATTR_NAME_RESHAPE_INFER_TYPE, infer_reshape_type); - const bool is_success = transformer::ExpandDimension(op_desc_->GetType(), curr_format, saved_format, i, - infer_reshape_type, ori_shape); - if (!is_success) { - REPORT_INNER_ERR_MSG("E18888", "ExpandDimension failed, op type:%s.", op_desc_->GetType().c_str()); - GELOGE(GRAPH_FAILED, "[Call][ExpandDimension] failed, op type:%s.", op_desc_->GetType().c_str()); - return false; - } - transformer::ShapeAndFormat shape_and_format_info {ori_shape, curr_format, saved_format, dtype}; - (void)shape_transfer.GetShapeAccordingToFormat(op_desc_, shape_and_format_info); - GELOGD("Node is %s, out tensor idx is %zu. Update format and shape success, ori format: %s, format: %s", - op_desc_->GetName().c_str(), i, TypeUtils::FormatToSerialString(curr_format).c_str(), - TypeUtils::FormatToSerialString(saved_format).c_str()); - } - GELOGD("Node is %s. Update format and shape success", op_desc_->GetName().c_str()); - return true; -} -} // namespace ge diff --git a/graph/utils/transformer_utils.h b/graph/utils/transformer_utils.h deleted file mode 100644 index 5ed9a433488a6ffa1900cb7e6e7571c9dc31a43b..0000000000000000000000000000000000000000 --- a/graph/utils/transformer_utils.h +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef COMMON_GRAPH_UTILS_TRANSFORMER_UTILS_H_ -#define COMMON_GRAPH_UTILS_TRANSFORMER_UTILS_H_ -#include -#include - -#include "external/graph/types.h" -#include "graph/op_desc.h" -#include "graph/ge_tensor.h" -#include "graph/small_vector.h" -#include "graph/ascend_limits.h" - -namespace ge { - -class NodeShapeTransUtils { - public: - bool Init(); - bool CatchFormatAndShape(); - bool UpdateFormatAndShape(); - - explicit NodeShapeTransUtils(const OpDescPtr op_desc) : op_desc_(op_desc), in_num_(0U), out_num_(0U) { - } - - ~NodeShapeTransUtils() { - } - - private: - SmallVector map_format_in_; - SmallVector map_ori_format_in_; - SmallVector map_dtype_in_; - SmallVector map_format_out_; - SmallVector map_ori_format_out_; - SmallVector map_dtype_out_; - - OpDescPtr op_desc_; - size_t in_num_; - size_t out_num_; -}; -} // namespace ge -#endif // COMMON_GRAPH_UTILS_TRANSFORMER_UTILS_H_ diff --git a/graph/utils/tuning_utils.cc b/graph/utils/tuning_utils.cc deleted file mode 100644 index 0ca443522e93226a013ead3d911dfc94dabbff7f..0000000000000000000000000000000000000000 --- a/graph/utils/tuning_utils.cc +++ /dev/null @@ -1,1117 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/tuning_utils.h" - -#include "graph/debug/ge_util.h" -#include "graph/debug/ge_op_types.h" -#include "graph/normal_graph/node_impl.h" -#include "graph/utils/graph_utils_ex.h" -#include "graph/utils/file_utils.h" -#include "graph/utils/recover_ir_utils.h" -#include "inc/common/checker.h" -#include "mmpa/mmpa_api.h" - -namespace ge { -namespace { -const int64_t kControlIndex = -1; -const char_t *const peer_node_name_attr = "_peerNodeName"; -const char_t *const parent_node_name_attr = "_parentNodeName"; -const char_t *const alias_name_attr = "_aliasName"; -const char_t *const alias_indexes_attr = "_aliasIndexes"; -const char_t *const parent_node_anchor_index_attr = "_parentNodeAnchorIndex"; -const char_t *const tuning_subgraph_prefix = "/aicore_subgraph_"; -const char_t *const non_tuning_subgraph_prefix = "/subgraph_"; -const char_t *const kTmpWeightDir = "tmp_weight_"; -const char_t *const kOriginName4Recover = "_origin_name_4_recover"; -const char_t *const kOriginType4Recover = "_origin_type_4_recover"; -const char_t *const kLocation4Recover = "_location_4_recover"; -const char_t *const kLength4Recover = "_length_4_recover"; -const std::set kPartitionOpTypes = {PLACEHOLDER, END}; -const std::set kExeTypes = {DATA, CONSTANT, FILECONSTANT, NETOUTPUT}; -const size_t kConstOpNormalWeightSize = 1U; -const size_t kMaxDataLen = 1048576U; // 1M -} -const std::set ir_builder_supported_options_for_lx_fusion = { - BUILD_MODE, - BUILD_STEP, - TUNING_PATH -}; - -const std::set build_mode_options = { - BUILD_MODE_NORMAL, - BUILD_MODE_TUNING, - BUILD_MODE_BASELINE, - BUILD_MODE_OPAT_RESULT -}; - -const std::set build_step_options = { - BUILD_STEP_BEFORE_UB_MATCH, - BUILD_STEP_AFTER_UB_MATCH, - BUILD_STEP_AFTER_BUILDER, - BUILD_STEP_AFTER_BUILDER_SUB, - BUILD_STEP_BEFORE_BUILD, - BUILD_STEP_AFTER_BUILD, - BUILD_STEP_AFTER_MERGE -}; - -NodeNametoNodeNameMap TuningUtils::data_2_end_; -NodetoNodeNameMap TuningUtils::data_node_2_end_node_ ; -NodetoNodeMap TuningUtils::data_node_2_netoutput_node_; -NodeVec TuningUtils::netoutput_nodes_; -NodeVec TuningUtils::merged_graph_nodes_; -SubgraphCreateOutNode TuningUtils::create_output_; -std::mutex TuningUtils::mutex_; -std::set TuningUtils::reusable_weight_files_; -std::map TuningUtils::name_to_index_; -std::map> TuningUtils::hash_to_files_; - -std::string TuningUtils::PrintCheckLog() { - std::stringstream ss; - ss << "d2e:{"; - for (const auto &pair : data_2_end_) { - ss << "data:" << pair.first << "-" << "end:" << pair.second; - ss << " | "; - } - ss << "}"; - ss << "netoutputs:{"; - for (const auto &node : netoutput_nodes_) { - ss << "netoutput:" << node->GetName(); - ss << " | "; - } - ss << "}"; - return ss.str(); -} - -std::string TuningUtils::GetNodeNameByAnchor(const Anchor * const anchor) { - if (anchor == nullptr) { - REPORT_INNER_ERR_MSG("E18888", "Anchor is nullptr, check invalid"); - GELOGE(GRAPH_FAILED, "[Check][Param] Anchor is nullptr"); - return "Null"; - } - const auto node = anchor->GetOwnerNodeBarePtr(); - return (node == nullptr) ? "Null" : node->GetName(); -} - -// part 1 -graphStatus TuningUtils::ConvertGraphToFile(std::vector tuning_subgraphs, - std::vector non_tuning_subgraphs, - const bool exe_flag, const std::string &path, - const std::string &user_path) { - int64_t i = 0; - int64_t j = 0; - const std::lock_guard lock(mutex_); - reusable_weight_files_.clear(); - name_to_index_.clear(); - hash_to_files_.clear(); - GELOGI("Total tuning graph num: %zu, non tuning graph: %zu.", tuning_subgraphs.size(), non_tuning_subgraphs.size()); - for (auto &subgraph : tuning_subgraphs) { - (void)create_output_.emplace(subgraph, nullptr); - auto help_info = HelpInfo{i, exe_flag, true, path, user_path}; - help_info.need_preprocess_ = true; - if (MakeExeGraph(subgraph, help_info) != SUCCESS) { - GELOGE(GRAPH_FAILED, "[Invoke][MakeExeGraph] TUU:subgraph %zu generate exe graph failed", i); - return GRAPH_FAILED; - } - i++; - } - - for (auto &subgraph : non_tuning_subgraphs) { - (void)create_output_.emplace(subgraph, nullptr); - const auto help_info = HelpInfo{j, true, false, path, user_path}; - if (MakeExeGraph(subgraph, help_info) != SUCCESS) { - GELOGE(GRAPH_FAILED, "[Invoke][MakeExeGraph] TUU:non tuning_subgraph %zu generate exe graph failed", j); - return GRAPH_FAILED; - } - j++; - } - create_output_.clear(); - return SUCCESS; -} - -graphStatus TuningUtils::ConvertConstToWeightAttr(const ComputeGraphPtr &exe_graph) { - GELOGI("Start to convert const to weight attr of graph %s.", exe_graph->GetName().c_str()); - for (const auto &node : exe_graph->GetDirectNode()) { - GE_CHECK_NOTNULL(node); - if (node->GetType() != PLACEHOLDER) { - continue; - } - auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - std::vector weight; - TryGetWeight(node, weight); - if (weight.empty()) { - continue; - } - if (!ge::AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight[0U])) { - REPORT_INNER_ERR_MSG("E18888", "Set tensor to node[%s] failed", op_desc->GetName().c_str()); - GELOGE(FAILED, "[Set][Tensor] to node[%s] failed", op_desc->GetName().c_str()); - return FAILED; - } - GELOGI("Set tensor to node[%s].", op_desc->GetName().c_str()); - } - return SUCCESS; -} - -// +---------------+ -// | pld pld | -// | \ / | -// | relu relu | -// | \ / | -// | add | -// | | | -// | end | -// +---------------+ -// | -// | -// V -// +---------------+ -// | data data | -// | \ / | -// | relu relu | -// | \ / | -// | add | -// | | | -// | netoutput | -// +---------------+ -graphStatus TuningUtils::MakeExeGraph(ComputeGraphPtr &exe_graph, - const HelpInfo& help_info) { - GE_CHECK_NOTNULL(exe_graph); - graphStatus ret = exe_graph->TopologicalSortingGraph(true); - if (ret != SUCCESS) { - GraphUtils::DumpGEGraphToOnnx(*exe_graph, "black_box"); - REPORT_INNER_ERR_MSG("E18888", "TopologicalSortingGraph [%s] failed, saved to file black_box ret:%u.", - exe_graph->GetName().c_str(), ret); - GELOGE(ret, "[Sort][Graph] Graph[%s] topological sort failed, saved to file black_box ret:%u.", - exe_graph->GetName().c_str(), ret); - return ret; - } - // clear graph id - GE_ASSERT_TRUE(AttrUtils::SetStr(*exe_graph, ATTR_NAME_SESSION_GRAPH_ID, "")); - GELOGI("TUU:clear [%s] session_graph_id success", exe_graph->GetName().c_str()); - // if not make exe, just dump and return - if (!help_info.exe_flag_) { - if (ConvertConstToWeightAttr(exe_graph) != SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Convert const to weight attr of graph %s failed", exe_graph->GetName().c_str()); - GELOGE(FAILED, "[Convert][Const] to weight attr of graph %s failed", exe_graph->GetName().c_str()); - return FAILED; - } - DumpGraphToPath(exe_graph, help_info.index_, help_info.is_tuning_graph_, help_info.path_); - GELOGI("TUU:just return, dump original sub_graph[%s]index[%" PRId64 "]", exe_graph->GetName().c_str(), - help_info.index_); - return SUCCESS; - } - // modify sub graph - for (NodePtr &node : exe_graph->GetDirectNode()) { - // 1.handle pld - if (node->GetType() == PLACEHOLDER) { - GE_ASSERT_GRAPH_SUCCESS(HandlePld(node, help_info.path_)); - } - // 2.handle end - if (node->GetType() == END) { - GE_ASSERT_GRAPH_SUCCESS(HandleEnd(node)); - } - GE_ASSERT_GRAPH_SUCCESS(HandleConst(node, help_info.path_)); - if (help_info.need_preprocess_) { - GE_ASSERT_GRAPH_SUCCESS(PreProcessNode(node)); - } - } - GE_ASSERT_GRAPH_SUCCESS(GraphUtils::RemoveNodesByTypeWithoutRelink(exe_graph, std::string(PLACEHOLDER))); - GE_ASSERT_GRAPH_SUCCESS(GraphUtils::RemoveNodesByTypeWithoutRelink(exe_graph, std::string(END))); - GE_ASSERT_GRAPH_SUCCESS(exe_graph->TopologicalSortingGraph(true)); - // dump subgraphs which modified by us - if (help_info.user_path_.empty()) { - DumpGraphToPath(exe_graph, help_info.index_, help_info.is_tuning_graph_, help_info.path_); - } else { - GraphUtils::DumpGEGraph(exe_graph, "", true, help_info.user_path_); - } - return SUCCESS; -} - -void TuningUtils::DumpGraphToPath(const ComputeGraphPtr &exe_graph, const int64_t index, - const bool is_tuning_graph, std::string path) { - if (!path.empty()) { - if (is_tuning_graph) { - GraphUtils::DumpGEGraph(exe_graph, "", true, path + tuning_subgraph_prefix + std::to_string(index) + ".txt"); - } else { - GraphUtils::DumpGEGraph(exe_graph, "", true, path + non_tuning_subgraph_prefix + std::to_string(index) + ".txt"); - } - } else { - path = "./"; - if (is_tuning_graph) { - GraphUtils::DumpGEGraph(exe_graph, "", true, path + tuning_subgraph_prefix + std::to_string(index) + ".txt"); - } else { - GraphUtils::DumpGEGraph(exe_graph, "", true, path + non_tuning_subgraph_prefix + std::to_string(index) + ".txt"); - } - } -} - -void TuningUtils::TryGetWeight(const NodePtr &node, std::vector &weight) { - // The caller guarantees that the node is not null - ConstGeTensorPtr ge_tensor = nullptr; - (void) NodeUtils::TryGetWeightByPlaceHolderNode(node, ge_tensor); - if (ge_tensor != nullptr) { - weight.emplace_back(std::const_pointer_cast(ge_tensor)); - } -} - -graphStatus TuningUtils::HandleConst(NodePtr &node, const std::string &aoe_path) { - if (kConstOpTypes.count(node->GetType()) == 0U) { - return SUCCESS; - } - const auto &weights = OpDescUtils::MutableWeights(node); - GE_ASSERT_TRUE(weights.size() == kConstOpNormalWeightSize); - GE_CHECK_NOTNULL(weights[0]); - - const size_t data_length = weights[0]->GetData().GetSize(); - // empty tensor - if (data_length == 0U) { - return SUCCESS; - } - - const auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - GE_ASSERT_TRUE(AttrUtils::SetStr(op_desc, kOriginName4Recover, node->GetName())); - GE_ASSERT_TRUE(AttrUtils::SetStr(op_desc, kOriginType4Recover, node->GetType())); - op_desc->SetType(FILECONSTANT); - op_desc->SetName(op_desc->GetName() + "_" + FILECONSTANT); - - GE_ASSERT_SUCCESS(SetFileConstInfo(node, weights[0U], aoe_path, op_desc)); - weights[0U]->ClearData(); - return SUCCESS; -} - -std::string TuningUtils::GenerateFileConstPath(const std::string &aoe_path, const OpDescPtr &op_desc) { - std::string file_path; - if ((!AttrUtils::GetStr(op_desc, parent_node_name_attr, file_path)) || (file_path.empty())) { - file_path = op_desc->GetName(); - } - static std::atomic node_count{0}; - const auto iter = name_to_index_.find(file_path); - if (iter == name_to_index_.end()) { - name_to_index_[file_path] = node_count; - file_path = kTmpWeightDir + std::to_string(mmGetPid()) + "/" + std::to_string(node_count); - ++node_count; - } else { - file_path = kTmpWeightDir + std::to_string(mmGetPid()) + "/" + std::to_string(iter->second); - } - - if (aoe_path.empty()) { - return "./" + file_path; - } - return aoe_path + "/" + file_path; -} - -Status TuningUtils::CheckFilesSame(const std::string &file_name, const char_t *const data, const size_t data_length, - bool &is_content_same) { - const auto file_buff = ComGraphMakeUnique(data_length); - GE_CHECK_NOTNULL(file_buff); - const auto &real_path = RealPath(file_name.c_str()); - GE_ASSERT_TRUE(!real_path.empty()); - std::ifstream ifs(real_path, std::ifstream::binary); - GE_ASSERT_TRUE(ifs.is_open()); - (void)ifs.seekg(0, std::ifstream::end); - const size_t file_length = static_cast(ifs.tellg()); - if (data_length != file_length) { - ifs.close(); - return SUCCESS; - } - (void)ifs.seekg(0, std::ifstream::beg); - (void)ifs.read(static_cast(file_buff.get()), static_cast(file_length)); - GE_ASSERT_TRUE(ifs.good()); - ifs.close(); - if ((memcmp(data, file_buff.get(), data_length) == 0)) { - is_content_same = true; - GELOGD("Check files with same content success"); - } - return SUCCESS; -} - -Status TuningUtils::GetOrSaveReusableFileConst(const GeTensorPtr &tensor, std::string &file_path) { - if (reusable_weight_files_.count(file_path) != 0U) { - GELOGD("File: %s is reusable.", file_path.c_str()); - return SUCCESS; - } - - const char_t* data = PtrToPtr(tensor->GetData().GetData()); - const size_t data_length = tensor->GetData().GetSize(); - GE_ASSERT_TRUE(data_length > 0U); - const size_t file_buff_len = std::min(data_length, kMaxDataLen); - const std::string file_buff_str(data, data + file_buff_len); - const size_t hash_value = std::hash{}(file_buff_str); - GELOGD("Get hash of file[%s] success, value[%zu]", file_path.c_str(), hash_value); - if (hash_to_files_.find(hash_value) == hash_to_files_.end()) { - GE_ASSERT_SUCCESS(SaveBinToFile(data, data_length, file_path)); - reusable_weight_files_.emplace(file_path); - hash_to_files_[hash_value].emplace_back(file_path); - GELOGD("Save reusable weight file: %s, hash_value: %zu.", file_path.c_str(), hash_value); - return SUCCESS; - } - - for (const auto &file : hash_to_files_[hash_value]) { - bool has_same_content = false; - GE_ASSERT_SUCCESS(CheckFilesSame(file, data, data_length, has_same_content)); - if (has_same_content) { - GELOGD("External weight file[%s] can be reused, skip generate file:%s", file.c_str(), file_path.c_str()); - file_path = file; - return SUCCESS; - } - } - - GE_ASSERT_SUCCESS(SaveBinToFile(data, data_length, file_path)); - reusable_weight_files_.emplace(file_path); - hash_to_files_[hash_value].emplace_back(file_path); - GELOGD("Save reusable weight file: %s, hash_value: %zu.", file_path.c_str(), hash_value); - - return SUCCESS; -} - -graphStatus TuningUtils::SetFileConstInfo(const NodePtr &node, const GeTensorPtr &tensor, const std::string &aoe_path, - const OpDescPtr &op_desc) { - GE_CHECK_NOTNULL(node->GetOpDesc()); - std::string file_path = GenerateFileConstPath(aoe_path, node->GetOpDesc()); - GELOGD("Generate tmp weight file path: %s of %s.", file_path.c_str(), node->GetName().c_str()); - GE_ASSERT_SUCCESS(GetOrSaveReusableFileConst(tensor, file_path)); - GE_ASSERT_TRUE(AttrUtils::SetStr(op_desc, kLocation4Recover, file_path)); - - const int64_t length = static_cast(tensor->GetData().GetSize()); - GE_ASSERT_TRUE(AttrUtils::SetInt(op_desc, kLength4Recover, length)); - const auto tensor_desc = tensor->GetTensorDesc(); - GE_ASSERT_TRUE(AttrUtils::SetDataType(op_desc, VAR_ATTR_DTYPE, tensor_desc.GetDataType())); - GE_ASSERT_TRUE(AttrUtils::SetListInt(op_desc, VAR_ATTR_SHAPE, tensor_desc.GetShape().GetDims())); - - GELOGD("Convert node: %s to file constant: %s success, file path: %s, length: %ld.", node->GetName().c_str(), - op_desc->GetName().c_str(), file_path.c_str(), length); - - return SUCCESS; -} - -graphStatus TuningUtils::CreateDataNode(NodePtr &node, const std::string &aoe_path, NodePtr &data_node) { - const auto graph = node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(graph); - OpDescPtr data_op_desc; - std::vector weight; - TryGetWeight(node, weight); - GeTensorDesc output_desc; - if (!weight.empty()) { - GE_ASSERT_TRUE(weight.size() == kConstOpNormalWeightSize); - GE_CHECK_NOTNULL(weight[0U]); - const size_t data_length = weight[0U]->GetData().GetSize(); - // empty tensor - if (data_length == 0U) { - data_op_desc = ComGraphMakeShared(node->GetName(), CONSTANT); - } else { - const std::string file_const_name = node->GetName() + "_" + FILECONSTANT; - data_op_desc = ComGraphMakeShared(file_const_name, FILECONSTANT); - GE_CHECK_NOTNULL(data_op_desc); - GE_ASSERT_SUCCESS(SetFileConstInfo(node, weight[0U], aoe_path, data_op_desc)); - } - output_desc = weight[0U]->GetTensorDesc(); - std::string parent_node_name; - if (AttrUtils::GetStr(node->GetOpDesc(), parent_node_name_attr, parent_node_name) && (!parent_node_name.empty())) { - (void) AttrUtils::SetStr(data_op_desc, ATTR_NAME_SRC_CONST_NAME, parent_node_name); - } - GELOGD("Create const node for %s, output_desc shape is:%s", - node->GetName().c_str(), output_desc.GetShape().ToString().c_str()); - } else { - data_op_desc = ComGraphMakeShared(node->GetName(), DATA); - const auto pld_op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(pld_op_desc); - output_desc = pld_op_desc->GetOutputDesc(0U); // only one output for pld and data - GELOGD("Create data node for %s, output_desc shape is:%s", - node->GetName().c_str(), output_desc.GetShape().ToString().c_str()); - } - GE_CHECK_NOTNULL(data_op_desc); - // data inputdesc & outputdesc set as same - GE_ASSERT_GRAPH_SUCCESS(data_op_desc->AddInputDesc(output_desc)); - GE_ASSERT_GRAPH_SUCCESS(data_op_desc->AddOutputDesc(output_desc)); - data_node = graph->AddNode(data_op_desc); - GE_CHECK_NOTNULL(data_node); - if (data_node->GetType() == CONSTANT) { - if (OpDescUtils::SetWeights(data_node, weight) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "TUU:const node %s add weight failed", data_op_desc->GetName().c_str()); - GELOGE(FAILED, "[Set][Weights] TUU:const node %s add weight failed", data_op_desc->GetName().c_str()); - return FAILED; - } - } - GE_ASSERT_GRAPH_SUCCESS(data_node->SetOwnerComputeGraph(graph)); - return SUCCESS; -} - -graphStatus TuningUtils::AddAttrToDataNodeForMergeGraph(const NodePtr &pld, const NodePtr &data_node) { - const auto op_desc = data_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - - const auto pld_desc = pld->GetOpDesc(); - GE_CHECK_NOTNULL(pld_desc); - // inherit - // a. set `end's input node type` as attr - std::string parent_op_type; - if (!AttrUtils::GetStr(pld_desc, "parentOpType", parent_op_type)) { - REPORT_INNER_ERR_MSG("E18888", "TUU:pld %s get parentOpType failed", pld_desc->GetName().c_str()); - GELOGE(FAILED, "[Invoke][GetStr] TUU:pld %s get parentOpType failed", pld_desc->GetName().c_str()); - return FAILED; - } - (void) AttrUtils::SetStr(op_desc, "parentOpType", parent_op_type); - // b. set `end's input node name` as attr - std::string parent_op_name; - if (!AttrUtils::GetStr(pld_desc, parent_node_name_attr, parent_op_name)) { - REPORT_INNER_ERR_MSG("E18888", "TUU:pld %s get _parentNodeName failed", pld_desc->GetName().c_str()); - GELOGE(FAILED, "[Invoke][GetStr] TUU:pld %s get _parentNodeName failed", pld_desc->GetName().c_str()); - return FAILED; - } - (void) AttrUtils::SetStr(op_desc, parent_node_name_attr, parent_op_name); - // c. set `end's input node's out anchor index` as attr - int32_t parent_node_anchor_index; - if (!AttrUtils::GetInt(pld_desc, "anchorIndex", parent_node_anchor_index)) { - REPORT_INNER_ERR_MSG("E18888", "TUU:pld %s get anchorIndex failed", pld_desc->GetName().c_str()); - GELOGE(FAILED, "[Invoke][GetStr] TUU:pld %s get anchorIndex failed", pld_desc->GetName().c_str()); - return FAILED; - } - (void) AttrUtils::SetInt(op_desc, parent_node_anchor_index_attr, parent_node_anchor_index); - GELOGD("TUU:from node %s(%s) to add attr to node %s(%s) success", - pld->GetName().c_str(), pld->GetType().c_str(), data_node->GetName().c_str(), data_node->GetType().c_str()); - // d. set `end node name` as attr - std::string peer_end_name; - if (!AttrUtils::GetStr(pld_desc, peer_node_name_attr, peer_end_name)) { - REPORT_INNER_ERR_MSG("E18888", "TUU:pld %s get _peerNodeName failed", pld_desc->GetName().c_str()); - GELOGE(FAILED, "[Invoke][GetStr] TUU:pld %s get _peerNodeName failed", pld_desc->GetName().c_str()); - return FAILED; - } - (void) AttrUtils::SetStr(op_desc, peer_node_name_attr, peer_end_name); - GELOGD("TUU:from node %s(%s) to add attr to node %s(%s) success", - pld->GetName().c_str(), pld->GetType().c_str(), data_node->GetName().c_str(), data_node->GetType().c_str()); - return SUCCESS; -} - -graphStatus TuningUtils::ChangePld2Data(const NodePtr &node, const NodePtr &data_node) { - const auto type_pld = node->GetType(); - const auto type_data = data_node->GetType(); - if ((type_pld != PLACEHOLDER) || (kExeTypes.count(type_data) == 0U)) { - REPORT_INNER_ERR_MSG("E18888", "TUU:Failed to change node %s from type %s to type %s", node->GetName().c_str(), - type_pld.c_str(), type_data.c_str()); - GELOGE(FAILED, "[Check][Param] TUU:Failed to change node %s from type %s to type %s", - node->GetName().c_str(), type_pld.c_str(), type_data.c_str()); - return FAILED; - } - const auto graph = node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(graph); - std::vector output_map(static_cast(node->GetAllOutDataAnchorsSize())); - for (size_t i = 0UL; i < node->GetAllOutDataAnchorsSize(); ++i) { - output_map[i] = static_cast(i); - } - - const auto ret = GraphUtils::ReplaceNodeAnchors(data_node, node, {}, output_map); - if (ret != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "TUU:Failed to replace node %s by node %s, ret:%u", node->GetName().c_str(), - data_node->GetName().c_str(), ret); - GELOGE(FAILED, "[Replace][Node] %s by node %s failed, ret:%u", - node->GetName().c_str(), data_node->GetName().c_str(), ret); - return FAILED; - } - - NodeUtils::UnlinkAll(*node); - - GELOGD("TUU:Remove node %s(%s) by the ChangePld2Data process, replace it with node %s(%s)", - node->GetName().c_str(), node->GetType().c_str(), data_node->GetName().c_str(), data_node->GetType().c_str()); - return ret; -} - -graphStatus TuningUtils::HandlePld(NodePtr &node, const std::string &aoe_path) { - GE_CHECK_NOTNULL(node); - const auto graph = node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(graph); - - NodePtr data_node = nullptr; - // 1. create data node - if (CreateDataNode(node, aoe_path, data_node) != SUCCESS) { - GELOGE(FAILED, "[Create][DataNode] TUU:Failed to handle node %s from graph %s", - node->GetName().c_str(), graph->GetName().c_str()); - return FAILED; - } - // 2. add necessary info to data_node for recovery whole graph - if (AddAttrToDataNodeForMergeGraph(node, data_node) != SUCCESS) { - GELOGE(FAILED, "[Add][Attr] TUU:Failed to handle node %s from graph %s", - node->GetName().c_str(), graph->GetName().c_str()); - return FAILED; - } - // 3. replace pld node by data node created before - if (ChangePld2Data(node, data_node) != SUCCESS) { - GELOGE(FAILED, "[Change][Pld2Data] TUU:Failed to handle node %s from graph %s", - node->GetName().c_str(), graph->GetName().c_str()); - return FAILED; - } - GELOGD("TUU:pld[%s] handle success", node->GetName().c_str()); - return SUCCESS; -} - -graphStatus TuningUtils::CreateNetOutput(const NodePtr &node, NodePtr &out_node) { - GE_CHECK_NOTNULL(node); - const auto graph = node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(graph); - const auto search = create_output_.find(graph); - if (search == create_output_.end()) { - REPORT_INNER_ERR_MSG("E18888", "TUU:node %s's owner sub graph %s does not exist in create_output map", - node->GetName().c_str(), graph->GetName().c_str()); - GELOGE(FAILED, "[Check][Param] TUU:node %s's owner sub graph %s does not exist in create_output map", - node->GetName().c_str(), graph->GetName().c_str()); - return FAILED; - } - if (search->second != nullptr) { - out_node = search->second; - GELOGD("TUU:sub graph %s has created output node, just return", graph->GetName().c_str()); - return SUCCESS; - } - const auto out_op_desc = ComGraphMakeShared(node->GetName(), NETOUTPUT); - GE_CHECK_NOTNULL(out_op_desc); - out_node = graph->AddNode(out_op_desc); - GE_CHECK_NOTNULL(out_node); - if (out_node->SetOwnerComputeGraph(graph) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "TUU:SetOwnerComputeGraph failed, graph:%s", graph->GetName().c_str()); - GELOGE(FAILED, "[Set][Graph] TUU:SetOwnerComputeGraph failed, graph:%s", graph->GetName().c_str()); - return FAILED; - } - create_output_[graph] = out_node; - return SUCCESS; -} - -graphStatus TuningUtils::AddAttrToNetOutputForMergeGraph(const NodePtr &end, const NodePtr &out_node, - const int64_t index) { - GE_CHECK_NOTNULL(end); - GE_CHECK_NOTNULL(out_node); - const auto op_desc = out_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - std::vector alias_names = {}; - (void) AttrUtils::GetListStr(op_desc, alias_name_attr, alias_names); - alias_names.push_back(end->GetName()); - (void) AttrUtils::SetListStr(op_desc, alias_name_attr, alias_names); - - std::vector indexes = {}; - (void) AttrUtils::GetListInt(op_desc, alias_indexes_attr, indexes); - indexes.push_back(index); - (void) AttrUtils::SetListInt(op_desc, alias_indexes_attr, indexes); - - return SUCCESS; -} - -graphStatus TuningUtils::LinkEnd2NetOutput(NodePtr &end_node, NodePtr &out_node) { - GE_CHECK_NOTNULL(end_node); - GE_CHECK_NOTNULL(out_node); - GE_CHECK_NOTNULL(end_node->GetInDataAnchor(0)); - // get end in node is control node or normal node - const AnchorPtr end_in_anchor = (end_node->GetInDataAnchor(0)->GetFirstPeerAnchor() == nullptr) - ? Anchor::DynamicAnchorCast(end_node->GetInControlAnchor()) - : Anchor::DynamicAnchorCast(end_node->GetInDataAnchor(0)); - GE_CHECK_NOTNULL(end_in_anchor); - const auto src_anchor = end_in_anchor->GetFirstPeerAnchor(); // src_anchor should be only 1 - GE_CHECK_NOTNULL(src_anchor); - if (GraphUtils::RemoveEdge(src_anchor, end_in_anchor) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", - "TUU:remove end input edge from from %s(%d) to %s(%d) failed. " - "node_name:%s, graph_name:%s", - GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(), - GetNodeNameByAnchor(end_in_anchor.get()).c_str(), end_in_anchor->GetIdx(), - end_node->GetName().c_str(), end_node->GetOwnerComputeGraph()->GetName().c_str()); - GELOGE(FAILED, "[Remove][Edge] TUU:remove end input edge from from %s(%d) to %s(%d) failed. " - "node_name:%s, graph_name:%s", GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(), - GetNodeNameByAnchor(end_in_anchor.get()).c_str(), end_in_anchor->GetIdx(), - end_node->GetName().c_str(), end_node->GetOwnerComputeGraph()->GetName().c_str()); - return FAILED; - } - // add edge between `end in node` and `out_node` - if (src_anchor->IsTypeIdOf()) { - const std::shared_ptr - anchor = ComGraphMakeShared(out_node, out_node->GetAllInDataAnchors().size()); - GE_CHECK_NOTNULL(anchor); - GE_CHECK_NOTNULL(out_node->impl_); - out_node->impl_->in_data_anchors_.push_back(anchor); - if (GraphUtils::AddEdge(src_anchor, anchor) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "TUU:add edge from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s", - GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(), - GetNodeNameByAnchor(anchor.get()).c_str(), anchor->GetIdx(), end_node->GetName().c_str(), - end_node->GetOwnerComputeGraph()->GetName().c_str()); - GELOGE(FAILED, "[Add][Edge] from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s", - GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(), - GetNodeNameByAnchor(anchor.get()).c_str(), anchor->GetIdx(), - end_node->GetName().c_str(), end_node->GetOwnerComputeGraph()->GetName().c_str()); - return FAILED; - } - const auto end_op_desc = end_node->GetOpDesc(); - GE_CHECK_NOTNULL(end_op_desc); - const auto out_node_op_desc = out_node->GetOpDesc(); - GE_CHECK_NOTNULL(out_node_op_desc); - // end node always has one input - if (out_node_op_desc->AddInputDesc(end_op_desc->GetInputDesc(0U)) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "TUU:node %s add input desc failed.", out_node_op_desc->GetName().c_str()); - GELOGE(FAILED, "[Add][InputDesc] failed, TUU:node %s .", out_node_op_desc->GetName().c_str()); - return FAILED; - } - // add necessary info to out_node for recovery whole graph - if (AddAttrToNetOutputForMergeGraph(end_node, out_node, static_cast(anchor->GetIdx())) != SUCCESS) { - GELOGE(FAILED, "[Add][Attr] TUU:Failed to handle node %s from graph %s", - end_node->GetName().c_str(), end_node->GetOwnerComputeGraph()->GetName().c_str()); - return FAILED; - } - } else if (src_anchor->IsTypeIdOf()) { - OpDescPtr noop = nullptr; - noop = ComGraphMakeShared(end_node->GetName() + NOOP, NOOP); - GE_CHECK_NOTNULL(noop); - const auto noop_node = end_node->GetOwnerComputeGraph()->AddNode(noop); - GE_CHECK_NOTNULL(noop_node); - const auto out_in_anchor = out_node->GetInControlAnchor(); - if ((GraphUtils::AddEdge(src_anchor, noop_node->GetInControlAnchor()) != GRAPH_SUCCESS) || - (GraphUtils::AddEdge(noop_node->GetOutControlAnchor(), out_in_anchor) != GRAPH_SUCCESS)) { - REPORT_INNER_ERR_MSG("E18888", "TUU:add edge from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s", - GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(), - GetNodeNameByAnchor(noop_node->GetInControlAnchor().get()).c_str(), - noop_node->GetInControlAnchor()->GetIdx(), end_node->GetName().c_str(), - end_node->GetOwnerComputeGraph()->GetName().c_str()); - GELOGE(FAILED, "[Add][Edge] from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s", - GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(), - GetNodeNameByAnchor(noop_node->GetInControlAnchor().get()).c_str(), - noop_node->GetInControlAnchor()->GetIdx(), end_node->GetName().c_str(), - end_node->GetOwnerComputeGraph()->GetName().c_str()); - return FAILED; - } - // add necessary info to out_node for recovery whole graph - if (AddAttrToNetOutputForMergeGraph(end_node, out_node, kControlIndex) != SUCCESS) { - GELOGE(FAILED, "[Add][Attr] TUU:Failed to handle node %s from graph %s", end_node->GetName().c_str(), - end_node->GetOwnerComputeGraph()->GetName().c_str()); - return FAILED; - } - } else { - REPORT_INNER_ERR_MSG("E18888", "TUU: node_name:%s, graph_name:%s handled failed", end_node->GetName().c_str(), - end_node->GetOwnerComputeGraph()->GetName().c_str()); - GELOGE(FAILED, "[Handle][Node] TUU: node_name:%s, graph_name:%s handled failed", - end_node->GetName().c_str(), end_node->GetOwnerComputeGraph()->GetName().c_str()); - return FAILED; - } - - return SUCCESS; -} - -graphStatus TuningUtils::ChangeEnd2NetOutput(NodePtr &end_node, NodePtr &out_node) { - GE_CHECK_NOTNULL(end_node); - GE_CHECK_NOTNULL(out_node); - const auto type_end = end_node->GetType(); - const auto type_out = out_node->GetType(); - if ((type_end != END) || (type_out != NETOUTPUT)) { - REPORT_INNER_ERR_MSG("E18888", "TUU:Failed to change end_node %s from type %s to type %s", - end_node->GetName().c_str(), type_end.c_str(), type_out.c_str()); - GELOGE(FAILED, "[Check][Param] TUU:Failed to change end_node %s from type %s to type %s", - end_node->GetName().c_str(), type_end.c_str(), type_out.c_str()); - return FAILED; - } - // link all `end nodes's in node` to this out_node - if (LinkEnd2NetOutput(end_node, out_node) != SUCCESS) { - GELOGE(FAILED, "[Invoke][LinkEnd2NetOutput] failed, TUU:end_node [%s].", end_node->GetName().c_str()); - return FAILED; - } - // remove `end node` - NodeUtils::UnlinkAll(*end_node); - return SUCCESS; -} - -graphStatus TuningUtils::HandleEnd(NodePtr &node) { - GE_CHECK_NOTNULL(node); - const auto graph = node->GetOwnerComputeGraph(); - GE_CHECK_NOTNULL(graph); - NodePtr out_node = nullptr; - - // 1. create net_output node , add only one NetOutput node to one subgraph - if (CreateNetOutput(node, out_node) != SUCCESS) { - GELOGE(FAILED, "[Create][NetOutput] TUU:Failed to handle node %s from graph %s", - node->GetName().c_str(), graph->GetName().c_str()); - return FAILED; - } - // 2. replace all end nodes by one output node created before - if (ChangeEnd2NetOutput(node, out_node) != SUCCESS) { - GELOGE(FAILED, "[Invoke][ChangeEnd2NetOutput] TUU:Failed to handle node %s from graph %s", - node->GetName().c_str(), graph->GetName().c_str()); - return FAILED; - } - GELOGD("TUU:end[%s] handle success", node->GetName().c_str()); - return SUCCESS; -} - -// part 2 -graphStatus TuningUtils::ConvertFileToGraph(const std::map &options, ge::Graph &graph) { - // 1. get all subgraph object - std::vector root_graphs; - std::map> name_to_subgraphs; - if (LoadGraphFromFile(options, root_graphs, name_to_subgraphs) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "Load graph from file according to options failed"); - return GRAPH_FAILED; - } - - // 2. merge root graph - ComputeGraphPtr merged_root_graph = ComGraphMakeShared("whole_graph_after_tune"); - GE_CHECK_NOTNULL(merged_root_graph); - if (MergeGraph(root_graphs, merged_root_graph) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "merge root graph failed"); - return GRAPH_FAILED; - } - - // 3. merge subgraphs - std::map name_to_merged_subgraph; - for (const auto &pair : name_to_subgraphs) { - ComputeGraphPtr merged_subgraph = ComGraphMakeShared(pair.first); - GE_CHECK_NOTNULL(merged_subgraph); - if (MergeGraph(pair.second, merged_subgraph) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "merge root graph failed"); - return GRAPH_FAILED; - } - name_to_merged_subgraph[pair.first] = merged_subgraph; - } - - // 4. construct relation of root graph and subgraphs - const auto ret_link_subgraph = LinkSubgraph(merged_root_graph, merged_root_graph, name_to_merged_subgraph); - if (ret_link_subgraph != GRAPH_SUCCESS) { - return ret_link_subgraph; - } - - // 5. construct relation of root graph and subgraph of subgrah - for (const auto &subgraph_iter: name_to_merged_subgraph) { - const auto ret = LinkSubgraph(merged_root_graph, subgraph_iter.second, name_to_merged_subgraph); - if (ret != GRAPH_SUCCESS) { - return ret; - } - } - - graph = GraphUtilsEx::CreateGraphFromComputeGraph(merged_root_graph); - return SUCCESS; -} - -graphStatus TuningUtils::LinkSubgraph(ComputeGraphPtr &root_graph, const ComputeGraphPtr &graph, - const std::map &name_to_merged_subgraph) { - for (const auto &node : graph->GetDirectNode()) { - const auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - for (const auto &subgraph_name : op_desc->GetSubgraphInstanceNames()) { - const auto iter = name_to_merged_subgraph.find(subgraph_name); - if (iter == name_to_merged_subgraph.end()) { - REPORT_INNER_ERR_MSG("E18888", "TUU:can not find subgraph with name:%s for op:%s.", subgraph_name.c_str(), - op_desc->GetName().c_str()); - GELOGE(GRAPH_FAILED, "can not find subgraph with name:%s for op:%s", - subgraph_name.c_str(), op_desc->GetName().c_str()); - return GRAPH_FAILED; - } - - iter->second->SetParentNode(node); - iter->second->SetParentGraph(graph); - (void)root_graph->AddSubGraph(iter->second); - GELOGI("add subgraph:%s for node:%s success", subgraph_name.c_str(), op_desc->GetName().c_str()); - } - } - return GRAPH_SUCCESS; -} - -graphStatus TuningUtils::MergeGraph(const std::vector &subgraphs, - ComputeGraphPtr &output_merged_compute_graph) { - GE_CHECK_NOTNULL(output_merged_compute_graph); - const std::function callback = [&]() { - data_2_end_.clear(); - data_node_2_end_node_.clear(); - data_node_2_netoutput_node_.clear(); - netoutput_nodes_.clear(); - merged_graph_nodes_.clear(); - }; - GE_MAKE_GUARD(release, callback); - - // merge graph - if (MergeAllSubGraph(subgraphs, output_merged_compute_graph) != GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "[Merge][Graph] failed"); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -graphStatus TuningUtils::LoadGraphFromFile(const std::map &options, - std::vector &root_graphs, - std::map> &name_to_subgraphs) { - // options format like {index:"subgraph_path"} - for (const auto &pair : options) { - auto compute_graph = ComGraphMakeShared(std::to_string(pair.first)); - if (!ge::GraphUtils::LoadGEGraph(pair.second.c_str(), compute_graph)) { - REPORT_INNER_ERR_MSG("E18888", "LoadGEGraph from file:%s failed", pair.second.c_str()); - GELOGE(FAILED, "[Load][Graph] from file:%s failed", pair.second.c_str()); - } - bool is_root_graph = false; - if (ge::AttrUtils::GetBool(compute_graph, ATTR_NAME_IS_ROOT_GRAPH, is_root_graph) && - is_root_graph) { - root_graphs.emplace_back(compute_graph); - } else { - std::string parent_graph_name; - if (!ge::AttrUtils::GetStr(compute_graph, ATTR_NAME_PARENT_GRAPH_NAME, parent_graph_name)) { - REPORT_INNER_ERR_MSG("E18888", "TUU:get attr ATTR_NAME_PARENT_GRAPH_NAME failed for subgraph."); - GELOGE(GRAPH_FAILED, "get attr ATTR_NAME_PARENT_GRAPH_NAME failed for subgraph:%s", - compute_graph->GetName().c_str()); - return GRAPH_FAILED; - } - name_to_subgraphs[parent_graph_name].emplace_back(compute_graph); - } - } - - if (root_graphs.empty()) { - REPORT_INNER_ERR_MSG("E18888", "TUU:root graph has no subgraphs, can not merge."); - GELOGE(GRAPH_FAILED, "root graph has no subgraphs, can not merge"); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -// +----------------------------------+ -// | const const | -// | \ / | -// | netoutput(end,end) | -// +----------------------------------+ -// + -// +----------------------------------+ -// | data(pld) data(pld) | -// | \ / | -// | relu relu | -// | \ / | -// | \ / | -// | add | -// | | | -// | netoutput(end) | -// +----------------------------------+ -// + -// +----------------------------------+ -// | data(pld) | -// | / | -// | netoutput | -// +----------------------------------+ -// | -// | -// V -// +----------------------------------+ -// | const const | -// | \ / | -// | relu relu | -// | \ / | -// | \ / | -// | add | -// | | | -// | netoutput | -// +----------------------------------+ -graphStatus TuningUtils::MergeAllSubGraph(const std::vector &subgraphs, - ComputeGraphPtr &output_merged_compute_graph) { - GE_CHECK_NOTNULL(output_merged_compute_graph); - // 1. handle all subgraphs - for (auto &subgraph : subgraphs) { - const Status ret_status = MergeSubGraph(subgraph); - if (ret_status != SUCCESS) { - GELOGE(ret_status, "[Invoke][MergeSubGraph] TUU:subgraph %s merge failed", subgraph->GetName().c_str()); - return ret_status; - } - } - - for (const auto &node : merged_graph_nodes_) { - (void) output_merged_compute_graph->AddNode(node); - // set owner graph - GE_CHK_STATUS_RET(node->SetOwnerComputeGraph(output_merged_compute_graph), - "[Set][Graph] TUU:node %s set owner graph failed", node->GetName().c_str()); - GELOGD("TUU:graph %s add node %s success", output_merged_compute_graph->GetName().c_str(), node->GetName().c_str()); - } - - // 2. remove data and output node added by us - if (RemoveDataNetoutputEdge(output_merged_compute_graph) != SUCCESS) { - GELOGE(FAILED, "[Remove][Edge] TUU:Failed to merge graph %s", output_merged_compute_graph->GetName().c_str()); - return FAILED; - } - const graphStatus ret = output_merged_compute_graph->TopologicalSorting(); - if (ret != SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "Graph[%s] topological sort failed, ret:%u.", - output_merged_compute_graph->GetName().c_str(), ret); - GELOGE(ret, "[Sort][Graph] Graph[%s] topological sort failed, ret:%u.", - output_merged_compute_graph->GetName().c_str(), ret); - return ret; - } - GELOGD("TUU:Print-%s", PrintCheckLog().c_str()); - GELOGI("TUU:output_merged_compute_graph %s success", output_merged_compute_graph->GetName().c_str()); - return SUCCESS; -} - -graphStatus TuningUtils::MergeSubGraph(const ComputeGraphPtr &subgraph) { - for (auto &node : subgraph->GetDirectNode()) { - if (kPartitionOpTypes.count(node->GetType()) > 0UL) { - REPORT_INNER_ERR_MSG("E18888", "TUU:subgraph passed in should not contain nodes of end or pld type"); - GELOGE(FAILED, "[Check][Param] TUU:subgraph passed in should not contain nodes of end or pld type"); - return FAILED; - } - // handle data converted from pld node - if ((node->GetType() == DATA) || (node->GetType() == CONSTANT)) { - const auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - std::string peer_out_name; - const bool has_valid_str = - (AttrUtils::GetStr(op_desc, peer_node_name_attr, peer_out_name)) && (!peer_out_name.empty()); - if (has_valid_str) { - const std::lock_guard lock(mutex_); - (void)data_2_end_.emplace(op_desc->GetName(), peer_out_name); - (void)data_node_2_end_node_.emplace(node, peer_out_name); - continue; - } - } - // handle netoutput converted from end node - if (node->GetType() == NETOUTPUT) { - const auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - std::vector out_alias_name; - const bool has_valid_str = - (AttrUtils::GetListStr(op_desc, alias_name_attr, out_alias_name)) && (!out_alias_name.empty()); - if (has_valid_str) { - const std::lock_guard lock(mutex_); - netoutput_nodes_.emplace_back(node); - } - } - { - const std::lock_guard lock(mutex_); - merged_graph_nodes_.emplace_back(node); - } - GELOGD("TUU:subgraph %s add node %s success", subgraph->GetName().c_str(), node->GetName().c_str()); - } - GELOGI("TUU:merge subgraph %s success", subgraph->GetName().c_str()); - return SUCCESS; -} - -NodePtr TuningUtils::FindNode(const std::string &name, int64_t &in_index) { - for (const auto &node : netoutput_nodes_) { - if (node == nullptr) { - continue; - } - std::vector out_alias_name; - std::vector alias_indexes; - if (AttrUtils::GetListStr(node->GetOpDesc(), alias_name_attr, out_alias_name) && - AttrUtils::GetListInt(node->GetOpDesc(), alias_indexes_attr, alias_indexes) && - (out_alias_name.size() == alias_indexes.size())) { - for (size_t i = 0UL; i < out_alias_name.size(); i++) { - if (out_alias_name[i] == name) { - in_index = alias_indexes[i]; - return node; - } - } - } - } - return nullptr; -} - -graphStatus TuningUtils::RemoveDataNetoutputEdge(ComputeGraphPtr &graph) { - GE_CHECK_NOTNULL(graph); - // 1. traverse - for (auto &pair : data_node_2_end_node_) { - auto data_node = pair.first; - GE_CHECK_NOTNULL(data_node); - const auto end_name = pair.second; - int64_t index = 0; - auto netoutput_node = FindNode(end_name, index); - GELOGD("TUU:start to find info[%s][%s][%" PRId64 "] ", data_node->GetName().c_str(), end_name.c_str(), index); - GE_CHECK_NOTNULL(netoutput_node); - (void)data_node_2_netoutput_node_.emplace(data_node, netoutput_node); - // 2. get `data out anchor` and `net output in anchor` and `net output in node's out anchor` - GE_CHECK_NOTNULL(data_node->GetOutDataAnchor(0)); - const AnchorPtr data_out_anchor = (data_node->GetOutDataAnchor(0)->GetFirstPeerAnchor() == nullptr) - ? Anchor::DynamicAnchorCast(data_node->GetOutControlAnchor()) - : Anchor::DynamicAnchorCast(data_node->GetOutDataAnchor(0)); - AnchorPtr net_output_in_anchor = nullptr; - AnchorPtr src_out_anchor = nullptr; - if (index != kControlIndex) { - net_output_in_anchor = netoutput_node->GetInDataAnchor(static_cast(index)); - GE_CHECK_NOTNULL(net_output_in_anchor); - src_out_anchor = net_output_in_anchor->GetFirstPeerAnchor(); - } else { - net_output_in_anchor = netoutput_node->GetInControlAnchor(); - for (const auto &out_ctrl : net_output_in_anchor->GetPeerAnchorsPtr()) { - const auto noop_node = out_ctrl->GetOwnerNode(); - GE_CHECK_NOTNULL(noop_node); - if ((noop_node->GetType() == NOOP) && (noop_node->GetName() == (end_name + NOOP))) { - src_out_anchor = noop_node->GetInControlAnchor()->GetFirstPeerAnchor(); - // remove noop node - NodeUtils::UnlinkAll(*noop_node); - if (GraphUtils::RemoveJustNode(graph, noop_node) != SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "TUU:noop node [%s] RemoveNodeWithoutRelink failed.", - noop_node->GetName().c_str()); - GELOGE(FAILED, "[Remove][Node]TUU:noop node [%s] RemoveNodeWithoutRelink failed.", - noop_node->GetName().c_str()); - return FAILED; - } - break; - } - } - } - GE_CHECK_NOTNULL(src_out_anchor); - GELOGD("TUU:get out node:%s 's in anchor(%d) peer_src_node:%s 's out anchor(%d) match info[%s][%s][%" PRId64 "]", - netoutput_node->GetName().c_str(), net_output_in_anchor->GetIdx(), - src_out_anchor->GetOwnerNode()->GetName().c_str(), src_out_anchor->GetIdx(), data_node->GetName().c_str(), - end_name.c_str(), index); - - // 3. relink - // unlink netoutput_node with it's input in stage 4 - GE_CHECK_NOTNULL(data_out_anchor); - for (const auto &peer_in_anchor : data_out_anchor->GetPeerAnchors()) { - if (GraphUtils::RemoveEdge(data_out_anchor, peer_in_anchor) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", - "[Remove][Edge] from %s(%d) to %s(%d) failed. " - "node_name:(data:%s;netoutput:%s), graph_name:%s", - GetNodeNameByAnchor(data_out_anchor.get()).c_str(), data_out_anchor->GetIdx(), - GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx(), - data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str()); - GELOGE(FAILED, "[Remove][Edge] from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s", - GetNodeNameByAnchor(data_out_anchor.get()).c_str(), data_out_anchor->GetIdx(), - GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx(), - data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str()); - return FAILED; - } - if (GraphUtils::AddEdge(src_out_anchor, peer_in_anchor) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", - "TUU:add edge from %s(%d) to %s(%d) failed. " - "node_name:(data:%s;netoutput:%s), graph_name:%s", - GetNodeNameByAnchor(src_out_anchor.get()).c_str(), src_out_anchor->GetIdx(), - GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx(), - data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str()); - GELOGE(FAILED, "[Add][Edge] from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s", - GetNodeNameByAnchor(src_out_anchor.get()).c_str(), src_out_anchor->GetIdx(), - GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx(), - data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str()); - return FAILED; - } - } - } - // 4. remove out nodes added by us - for (auto &node: netoutput_nodes_) { - NodeUtils::UnlinkAll(*node); - if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E18888", "TUU:Failed to remove node %s from graph", node->GetName().c_str()); - GELOGE(FAILED, "[Remove][Node] %s from graph failed.", node->GetName().c_str()); - return FAILED; - } - GELOGD("TUU:Remove node %s by the RemoveDataNetoutputEdge process success", node->GetName().c_str()); - } - return SUCCESS; -} - -graphStatus TuningUtils::PreProcessNode(const NodePtr &node) { - const auto &op_desc = node->GetOpDesc(); - GE_ASSERT_NOTNULL(op_desc); - if (op_desc->GetType() == PLACEHOLDER || op_desc->GetType() == END) { - return GRAPH_SUCCESS; - } - // strep 0: recovery ir - if (op_desc->GetIrInputs().empty() && op_desc->GetIrOutputs().empty() && (op_desc->GetAllOutputsDescSize() != 0U)) { - GE_ASSERT_GRAPH_SUCCESS(RecoverIrUtils::RecoverOpDescIrDefinition(op_desc), - "Failed recover ir def for %s %s", - op_desc->GetNamePtr(), - op_desc->GetTypePtr()); - GELOGI("Node %s %s recover ir def successfully", node->GetNamePtr(), node->GetTypePtr()); - } - GELOGI("Node %s %s pre-process successfully", node->GetNamePtr(), node->GetTypePtr()); - return GRAPH_SUCCESS; -} -} // namespace ge diff --git a/graph/utils/type_utils.cc b/graph/utils/type_utils.cc deleted file mode 100644 index e2c17cb1b52f7d23ea13dd9162f62fbec1c9008b..0000000000000000000000000000000000000000 --- a/graph/utils/type_utils.cc +++ /dev/null @@ -1,151 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/utils/type_utils.h" -#include "base/utils/type_utils_impl.h" -#include "graph/utils/type_utils_inner.h" -#include -#include "graph/buffer.h" -#include "graph/debug/ge_util.h" -#include "external/graph/types.h" - -namespace ge { -namespace { -const std::map kDomiFormatToGeFormat = { - {domi::DOMI_TENSOR_NCHW, FORMAT_NCHW}, - {domi::DOMI_TENSOR_NHWC, FORMAT_NHWC}, - {domi::DOMI_TENSOR_ND, FORMAT_ND}, - {domi::DOMI_TENSOR_NC1HWC0, FORMAT_NC1HWC0}, - {domi::DOMI_TENSOR_FRACTAL_Z, FORMAT_FRACTAL_Z}, - {domi::DOMI_TENSOR_NC1C0HWPAD, FORMAT_NC1C0HWPAD}, - {domi::DOMI_TENSOR_NHWC1C0, FORMAT_NHWC1C0}, - {domi::DOMI_TENSOR_FSR_NCHW, FORMAT_FSR_NCHW}, - {domi::DOMI_TENSOR_FRACTAL_DECONV, FORMAT_FRACTAL_DECONV}, - {domi::DOMI_TENSOR_BN_WEIGHT, FORMAT_BN_WEIGHT}, - {domi::DOMI_TENSOR_CHWN, FORMAT_CHWN}, - {domi::DOMI_TENSOR_FILTER_HWCK, FORMAT_FILTER_HWCK}, - {domi::DOMI_TENSOR_NDHWC, FORMAT_NDHWC}, - {domi::DOMI_TENSOR_NCDHW, FORMAT_NCDHW}, - {domi::DOMI_TENSOR_DHWCN, FORMAT_DHWCN}, - {domi::DOMI_TENSOR_DHWNC, FORMAT_DHWNC}, - {domi::DOMI_TENSOR_RESERVED, FORMAT_RESERVED} -}; - -const std::set kInternalFormat = { - "NC1HWC0", - "FRACTAL_Z", - "NC1C0HWPAD", - "NHWC1C0", - "FRACTAL_DECONV", - "C1HWNC0", - "FRACTAL_DECONV_TRANSPOSE", - "FRACTAL_DECONV_SP_STRIDE_TRANS", - "NC1HWC0_C04", - "FRACTAL_Z_C04", - "FRACTAL_DECONV_SP_STRIDE8_TRANS", - "NC1KHKWHWC0", - "C1HWNCoC0", - "FRACTAL_ZZ", - "FRACTAL_NZ", - "NDC1HWC0", - "FRACTAL_Z_3D", - "FRACTAL_Z_3D_TRANSPOSE", - "FRACTAL_ZN_LSTM", - "FRACTAL_Z_G", - "ND_RNN_BIAS", - "FRACTAL_ZN_RNN", - "NYUV", - "NYUV_A" -}; - -const std::map kFmkTypeToString = { - {domi::CAFFE, "caffe"}, - {domi::MINDSPORE, "mindspore"}, - {domi::TENSORFLOW, "tensorflow"}, - {domi::ANDROID_NN, "android_nn"}, - {domi::ONNX, "onnx"}, - {domi::FRAMEWORK_RESERVED, "framework_reserved"}, -}; - -const std::map kImplyTypeToString = { - {domi::ImplyType::BUILDIN, "buildin"}, - {domi::ImplyType::TVM, "tvm"}, - {domi::ImplyType::CUSTOM, "custom"}, - {domi::ImplyType::AI_CPU, "ai_cpu"}, - {domi::ImplyType::CCE, "cce"}, - {domi::ImplyType::GELOCAL, "gelocal"}, - {domi::ImplyType::HCCL, "hccl"}, - {domi::ImplyType::INVALID, "invalid"} -}; -} - - -std::string TypeUtils::DataTypeToSerialString(const DataType data_type) { - return TypeUtilsImpl::DataTypeToAscendString(data_type).GetString(); -} - -DataType TypeUtils::SerialStringToDataType(const std::string &str) { - return TypeUtilsImpl::AscendStringToDataType(str.c_str()); -} - -std::string TypeUtils::FormatToSerialString(const Format format) { - return TypeUtilsImpl::FormatToAscendString(format).GetString(); -} - -Format TypeUtils::SerialStringToFormat(const std::string &str) { - return TypeUtilsImpl::AscendStringToFormat(str.c_str()); -} - -Format TypeUtils::DataFormatToFormat(const std::string &str) { - return TypeUtilsImpl::DataFormatToFormat(str.c_str()); -} - -bool TypeUtils::GetDataTypeLength(const ge::DataType data_type, uint32_t &length) { - return TypeUtilsImpl::GetDataTypeLength(data_type, length); -} - -std::string TypeUtilsInner::ImplyTypeToSerialString(const domi::ImplyType imply_type) { - const auto it = kImplyTypeToString.find(imply_type); - if (it != kImplyTypeToString.end()) { - return it->second; - } else { - REPORT_INNER_ERR_MSG("E18888", "ImplyTypeToSerialString: imply_type not support %u", - static_cast(imply_type)); - GELOGE(GRAPH_FAILED, "[Check][Param] ImplyTypeToSerialString: imply_type not support %u", - static_cast(imply_type)); - return "UNDEFINED"; - } -} - -bool TypeUtilsInner::IsInternalFormat(const Format format) { - const std::string serial_format = TypeUtils::FormatToSerialString(static_cast(GetPrimaryFormat(format))); - const auto iter = kInternalFormat.find(serial_format); - const bool result = (iter == kInternalFormat.cend()) ? false : true; - return result; -} - -Format TypeUtilsInner::DomiFormatToFormat(const domi::domiTensorFormat_t domi_format) { - const auto it = kDomiFormatToGeFormat.find(domi_format); - if (it != kDomiFormatToGeFormat.end()) { - return it->second; - } - GELOGW("[Check][Param] do not find domi Format %d from map", domi_format); - return FORMAT_RESERVED; -} - -std::string TypeUtilsInner::FmkTypeToSerialString(const domi::FrameworkType fmk_type) { - const auto it = kFmkTypeToString.find(fmk_type); - if (it != kFmkTypeToString.end()) { - return it->second; - } else { - GELOGW("[Util][Serialize] Framework type %d not support.", fmk_type); - return ""; - } -} -} // namespace ge diff --git a/graph/utils/type_utils_ex.cc b/graph/utils/type_utils_ex.cc deleted file mode 100644 index a531dd17d5a77a0cb07b22e4e3ed17938b73a1d6..0000000000000000000000000000000000000000 --- a/graph/utils/type_utils_ex.cc +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ - -#include "graph/utils/type_utils.h" -#include "base/utils/type_utils_impl.h" - -namespace ge { - -ge::AscendString TypeUtils::DataTypeToAscendString(const DataType &data_type) { - return TypeUtilsImpl::DataTypeToAscendString(data_type); -} - -DataType TypeUtils::AscendStringToDataType(const ge::AscendString &str) { - return TypeUtilsImpl::AscendStringToDataType(str); -} - -AscendString TypeUtils::FormatToAscendString(const Format &format) { - return TypeUtilsImpl::FormatToAscendString(format); -} - -Format TypeUtils::AscendStringToFormat(const AscendString &str) { - return TypeUtilsImpl::AscendStringToFormat(str); -} - -Format TypeUtils::DataFormatToFormat(const AscendString &str) { - return TypeUtilsImpl::DataFormatToFormat(str); -} - -} \ No newline at end of file diff --git a/inc/CMakeLists.txt b/inc/CMakeLists.txt index 96a273c95eaa1d244ce661e8b3b246f55cf72cdd..2ab2880c384cb46d1f4226e9d0bf641608f1fa8c 100644 --- a/inc/CMakeLists.txt +++ b/inc/CMakeLists.txt @@ -25,13 +25,6 @@ target_include_directories(metadef_headers INTERFACE # 下列头文件包含路径是非法的,需要在后续整改中删掉 # --------------------start------------------------ - $ - $ - $ - $ - $ - $ - $ $ # ---------------------end----------------------- diff --git a/inc/common/blocking_queue.h b/inc/common/blocking_queue.h deleted file mode 100644 index 6e4bf199b56753e97682fbe1bf7c49ddd5648e9f..0000000000000000000000000000000000000000 --- a/inc/common/blocking_queue.h +++ /dev/null @@ -1,159 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_COMMON_BLOCKING_QUEUE_H_ -#define INC_COMMON_BLOCKING_QUEUE_H_ - -#include -#include -#include - -namespace ge { -constexpr uint32_t kDefaultMaxQueueSize = 2048U; -constexpr int32_t kDefaultWaitTimeoutInSec = 600; - -template -class BlockingQueue { - public: - explicit BlockingQueue(const uint32_t max_size = kDefaultMaxQueueSize) : max_size_(max_size) {} - - ~BlockingQueue() = default; - - bool Pop(T &item, const int32_t time_out = INT32_MAX) { - std::unique_lock lock(mutex_); - - while (!empty_cond_.wait_for(lock, std::chrono::seconds(time_out), - [this]() -> bool { return (!queue_.empty()) || (is_stoped_); })) { - is_stuck_ = true; - return false; - } - - if (is_stoped_) { - return false; - } - - item = std::move(queue_.front()); - queue_.pop_front(); - - full_cond_.notify_one(); - - return true; - } - - bool Pop(T &item, bool &is_stuck) { - const auto ret = Pop(item, kDefaultWaitTimeoutInSec); - is_stuck = is_stuck_; - return ret; - } - - bool Push(const T &item, const bool is_wait = true) { - std::unique_lock lock(mutex_); - - while ((queue_.size() >= max_size_) && (!is_stoped_)) { - if (!is_wait) { - return false; - } - full_cond_.wait(lock); - } - - if (is_stoped_) { - return false; - } - - queue_.push_back(item); - - empty_cond_.notify_one(); - - return true; - } - - bool Push(T &&item, const bool is_wait = true) { - std::unique_lock lock(mutex_); - - while ((queue_.size() >= max_size_) && (!is_stoped_)) { - if (!is_wait) { - return false; - } - full_cond_.wait(lock); - } - - if (is_stoped_) { - return false; - } - - queue_.emplace_back(std::move(item)); - - empty_cond_.notify_one(); - - return true; - } - - void Stop() { - { - const std::unique_lock lock(mutex_); - is_stoped_ = true; - } - - full_cond_.notify_all(); - empty_cond_.notify_all(); - } - - void Restart() { - const std::unique_lock lock(mutex_); - is_stoped_ = false; - } - - // if the queue is stoped ,need call this function to release the unprocessed items - std::list GetRemainItems() { - const std::unique_lock lock(mutex_); - - if (!is_stoped_) { - return std::list(); - } - - return queue_; - } - - bool IsFull() { - const std::unique_lock lock(mutex_); - return queue_.size() >= max_size_; - } - - void Clear() { - const std::unique_lock lock(mutex_); - queue_.clear(); - } - - void SetMaxSize(const uint32_t size) { - const std::unique_lock lock(mutex_); - if (size == 0U) { - max_size_ = kDefaultMaxQueueSize; - return; - } - max_size_ = size; - } - - uint32_t Size() { - const std::unique_lock lock(mutex_); - return static_cast(queue_.size()); - } - - private: - std::list queue_; - std::mutex mutex_; - std::condition_variable empty_cond_; - std::condition_variable full_cond_; - uint32_t max_size_; - - bool is_stoped_{false}; - bool is_stuck_{false}; -}; -} // namespace ge - -#endif // INC_COMMON_BLOCKING_QUEUE_H_ diff --git a/inc/common/dynamic_aipp.h b/inc/common/dynamic_aipp.h deleted file mode 100644 index 47b22c6fd8f6aeef37344a9babf8fdf078df23a2..0000000000000000000000000000000000000000 --- a/inc/common/dynamic_aipp.h +++ /dev/null @@ -1,99 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_COMMON_DYNAMIC_AIPP_H_ -#define INC_COMMON_DYNAMIC_AIPP_H_ - -#include - -/** -* @ingroup dnn -* @brief struct define of dynamic aipp batch parameter. -*/ -struct tagAippDynamicBatchPara { - int8_t cropSwitch; // crop switch - int8_t scfSwitch; // resize switch - int8_t paddingSwitch; // 0: unable padding - // 1: padding config value,sfr_filling_hblank_ch0 ~ sfr_filling_hblank_ch2 - // 2: padding source picture data, single row/collumn copy - // 3: padding source picture data, block copy - // 4: padding source picture data, mirror copy - int8_t rotateSwitch; // rotate switch,0: non-ratate, - // 1: ratate 90° clockwise,2: ratate 180° clockwise,3: ratate 270° clockwise - int8_t reserve[4]; - int32_t cropStartPosW; // the start horizontal position of cropping - int32_t cropStartPosH; // the start vertical position of cropping - int32_t cropSizeW; // crop width - int32_t cropSizeH; // crop height - - int32_t scfInputSizeW; // input width of scf - int32_t scfInputSizeH; // input height of scf - int32_t scfOutputSizeW; // output width of scf - int32_t scfOutputSizeH; // output height of scf - - int32_t paddingSizeTop; // top padding size - int32_t paddingSizeBottom; // bottom padding size - int32_t paddingSizeLeft; // left padding size - int32_t paddingSizeRight; // right padding size - - int16_t dtcPixelMeanChn0; // mean value of channel 0 - int16_t dtcPixelMeanChn1; // mean value of channel 1 - int16_t dtcPixelMeanChn2; // mean value of channel 2 - int16_t dtcPixelMeanChn3; // mean value of channel 3 - - uint16_t dtcPixelMinChn0; // min value of channel 0 - uint16_t dtcPixelMinChn1; // min value of channel 1 - uint16_t dtcPixelMinChn2; // min value of channel 2 - uint16_t dtcPixelMinChn3; // min value of channel 3 - uint16_t dtcPixelVarReciChn0; // sfr_dtc_pixel_variance_reci_ch0 - uint16_t dtcPixelVarReciChn1; // sfr_dtc_pixel_variance_reci_ch1 - uint16_t dtcPixelVarReciChn2; // sfr_dtc_pixel_variance_reci_ch2 - uint16_t dtcPixelVarReciChn3; // sfr_dtc_pixel_variance_reci_ch3 - - int8_t reserve1[16]; // 32B assign, for ub copy -}; -using kAippDynamicBatchPara = tagAippDynamicBatchPara; - -/** -* @ingroup dnn -* @brief struct define of dynamic aipp parameter. lite:64+96*batchNum byte ; tiny:64+64*batchNum byte -*/ -struct tagAippDynamicPara { - uint8_t inputFormat; // input format:YUV420SP_U8/XRGB8888_U8/RGB888_U8 - int8_t cscSwitch; // csc switch - int8_t rbuvSwapSwitch; // rb/ub swap switch - int8_t axSwapSwitch; // RGBA->ARGB, YUVA->AYUV swap switch - int8_t batchNum; // batch parameter number - int8_t reserve1[3]; - int32_t srcImageSizeW; // source image width - int32_t srcImageSizeH; // source image height - int16_t cscMatrixR0C0; // csc_matrix_r0_c0 - int16_t cscMatrixR0C1; // csc_matrix_r0_c1 - int16_t cscMatrixR0C2; // csc_matrix_r0_c2 - int16_t cscMatrixR1C0; // csc_matrix_r1_c0 - int16_t cscMatrixR1C1; // csc_matrix_r1_c1 - int16_t cscMatrixR1C2; // csc_matrix_r1_c2 - int16_t cscMatrixR2C0; // csc_matrix_r2_c0 - int16_t cscMatrixR2C1; // csc_matrix_r2_c1 - int16_t cscMatrixR2C2; // csc_matrix_r2_c2 - int16_t reserve2[3]; - uint8_t cscOutputBiasR0; // output Bias for RGB to YUV, element of row 0, unsigned number - uint8_t cscOutputBiasR1; // output Bias for RGB to YUV, element of row 1, unsigned number - uint8_t cscOutputBiasR2; // output Bias for RGB to YUV, element of row 2, unsigned number - uint8_t cscInputBiasR0; // input Bias for YUV to RGB, element of row 0, unsigned number - uint8_t cscInputBiasR1; // input Bias for YUV to RGB, element of row 1, unsigned number - uint8_t cscInputBiasR2; // input Bias for YUV to RGB, element of row 2, unsigned number - uint8_t reserve3[2]; - int8_t reserve4[16]; // 32B assign, for ub copy - - kAippDynamicBatchPara aippBatchPara; // allow transfer several batch para. -}; -using kAippDynamicPara = tagAippDynamicPara; - -#endif // INC_COMMON_DYNAMIC_AIPP_H_ diff --git a/inc/common/fe_executor/ffts_plus_qos_update.h b/inc/common/fe_executor/ffts_plus_qos_update.h deleted file mode 100644 index 94a22aea257666813ce0cc1d5a75828ddcfec575..0000000000000000000000000000000000000000 --- a/inc/common/fe_executor/ffts_plus_qos_update.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef FFTS_PLUS_QOS_UPDATE_H_ -#define FFTS_PLUS_QOS_UPDATE_H_ - -#include "runtime/rt_ffts_plus_define.h" -#include "graph/utils/node_utils.h" -namespace ffts { - -bool UpdateAicAivCtxQos(rtFftsPlusAicAivCtx_t *ctx, int label, int device_id); -bool UpdateMixAicAivCtxQos(rtFftsPlusMixAicAivCtx_t *ctx, int label, int device_id); -bool UpdateDataCtxQos(rtFftsPlusDataCtx_t *ctx, int device_id); - -} -#endif diff --git a/inc/common/ge_common/fmk_error_codes.h b/inc/common/ge_common/fmk_error_codes.h deleted file mode 100644 index 7d6c44a0ec2c05b086cdc24ea8d889254c6a564a..0000000000000000000000000000000000000000 --- a/inc/common/ge_common/fmk_error_codes.h +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_COMMON_FMK_ERROR_CODES_H_ -#define INC_COMMON_FMK_ERROR_CODES_H_ - -#if defined(_MSC_VER) -#ifdef FUNC_VISIBILITY -#define GE_OBJECT_VISIBILITY -#else -#define GE_OBJECT_VISIBILITY -#endif -#else -#ifdef FUNC_VISIBILITY -#define GE_OBJECT_VISIBILITY -#else -#define GE_OBJECT_VISIBILITY __attribute__((visibility("hidden"))) -#endif -#endif - -#include -#include - -#include "common/ge_common/fmk_types.h" -#include "register/register_error_codes.h" -#include "external/ge_common/ge_error_codes.h" - -// Each module uses the following four macros to define error codes: -#define DECLARE_ERRORNO_OMG(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_OMG, name, value) -#define DECLARE_ERRORNO_OME(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_OME, name, value) -#define DECLARE_ERRORNO_CALIBRATION(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_CALIBRATION, name, value) - -#define DEF_ERRORNO(name, desc) \ - const bool g_##name##_errorno = StatusFactory::Instance()->RegisterErrorNo(name, desc) - -// Interface for Obtaining Error Code Description -#define GET_ERRORNO_STR(value) domi::StatusFactory::Instance()->GetErrDesc(value) - -namespace domi { -constexpr int32_t MODID_OMG = 1; // OMG module ID -constexpr int32_t MODID_OME = 2; // OME module ID -constexpr int32_t MODID_CALIBRATION = 3; // Calibration module ID - -class GE_FUNC_VISIBILITY StatusFactory { - public: - static StatusFactory *Instance(); - - bool RegisterErrorNo(const uint32_t err, const std::string &desc); - - std::string GetErrDesc(const uint32_t err); - - protected: - StatusFactory() = default; - virtual ~StatusFactory() = default; - - private: - std::map err_desc_; -}; - -// Common errocode -DECLARE_ERRORNO_COMMON(MEMALLOC_FAILED, 0); // 50331648 -DECLARE_ERRORNO_COMMON(CCE_FAILED, 2); // 50331650 -DECLARE_ERRORNO_COMMON(RT_FAILED, 3); // 50331651 -DECLARE_ERRORNO_COMMON(INTERNAL_ERROR, 4); // 50331652 -DECLARE_ERRORNO_COMMON(CSEC_ERROR, 5); // 50331653 -DECLARE_ERRORNO_COMMON(TEE_ERROR, 6); // 50331653 -DECLARE_ERRORNO_COMMON(UNSUPPORTED, 100); -DECLARE_ERRORNO_COMMON(OUT_OF_MEMORY, 101); - -// Omg errorcode -DECLARE_ERRORNO_OMG(PARSE_MODEL_FAILED, 0); -DECLARE_ERRORNO_OMG(PARSE_WEIGHTS_FAILED, 1); -DECLARE_ERRORNO_OMG(NOT_INITIALIZED, 2); -DECLARE_ERRORNO_OMG(TIMEOUT, 3); - -// Ome errorcode -DECLARE_ERRORNO_OME(MODEL_NOT_READY, 0); -DECLARE_ERRORNO_OME(PUSH_DATA_FAILED, 1); -DECLARE_ERRORNO_OME(DATA_QUEUE_ISFULL, 2); -} // namespace domi - -#endif // INC_COMMON_FMK_ERROR_CODES_H_ diff --git a/inc/common/ge_common/fmk_types.h b/inc/common/ge_common/fmk_types.h deleted file mode 100644 index 5447d982af0a8cfe9a885a4493464f0e212c695c..0000000000000000000000000000000000000000 --- a/inc/common/ge_common/fmk_types.h +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_COMMON_FMK_TYPES_H_ -#define INC_COMMON_FMK_TYPES_H_ - -#include "graph/types.h" -#include "register/register_types.h" - -#endif // INC_COMMON_FMK_TYPES_H_ diff --git a/inc/common/ge_common/ge_types.h b/inc/common/ge_common/ge_types.h deleted file mode 100644 index 32eb7002fcbbdad059dac232b78c1cefb9f85367..0000000000000000000000000000000000000000 --- a/inc/common/ge_common/ge_types.h +++ /dev/null @@ -1,532 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_COMMON_GE_TYPES_H_ -#define INC_COMMON_GE_TYPES_H_ - -#include -#include -#include - -#include "common/ge_common/fmk_error_codes.h" -#include "external/ge_common/ge_api_error_codes.h" -#include "external/graph/types.h" -#include "external/ge_common/ge_api_types.h" - -namespace ge { -enum RuntimeType { HOST = 0, DEVICE = 1 }; - -enum class PerfLevel : int32_t { - GEN_TASK_WITH_FUSION = -1, - GEN_TASK_WITHOUT_L2FUSION = 3, - GEN_TASK_WITHOUT_FUSION = 4 -}; - -enum FrameworkType { - CAFFE = 0, - MINDSPORE = 1, - TENSORFLOW = 3, - ANDROID_NN = 4, - ONNX = 5, -}; - -enum class GraphStage : int64_t { - GRAPH_STAGE_FUZZ = 0, - GRAPH_STAGE_RESERVED -}; - -const char_t *const kGraphDumpStage = "DumpStage"; - -const std::map kFwkTypeToStr = {{"0", "Caffe"}, - {"1", "MindSpore"}, - {"3", "TensorFlow"}, - {"4", "Android_NN"}, - {"5", "Onnx"}}; - -enum OpEngineType { - ENGINE_SYS = 0, // default engine - ENGINE_AICORE = 1, - ENGINE_VECTOR = 2, - ENGINE_AICUBE = 3, // not support - ENGINE_AIVECTOR = 4 // not support -}; - -enum InputAippType { DATA_WITHOUT_AIPP = 0, DATA_WITH_STATIC_AIPP, DATA_WITH_DYNAMIC_AIPP, DYNAMIC_AIPP_NODE }; - -enum OfflineModelFormat { OM_FORMAT_DEFAULT, OM_FORMAT_LITE, OM_FORMAT_NANO }; - -const char_t *const GE_ENGINE_ATTR_MEM_TYPE_HBM = "HBM"; -const char_t *const GE_OPTION_EXEC_PLACEMENT = "ge.exec.placement"; - -// profiling data - -const std::string kTaskTypeAicore = "AI_CORE"; -const std::string kTaskTypeAiv = "AIV"; -const std::string kTaskTypeMixAic = "MIX_AIC"; -const std::string kTaskTypeMixAiv = "MIX_AIV"; -const std::string kTaskTypeAicpu = "AI_CPU"; -const std::string kTaskTypeDsa = "DSA"; -const std::string kTaskTypeWriteBackData = "WRITE_BACK"; -const std::string kTaskTypeInvalidData = "INVALID"; -const std::string kTaskTypeInvalid = "TASK_TYPE_INVALID"; -const std::string kTaskTypeFftsPlus = "FFTS_PLUS"; -const std::string kEngineNameVectorCore = "VectorEngine"; - -const std::string kEngineNameHccl = "ops_kernel_info_hccl"; -const std::string kEngineNameRts = "DNN_VM_RTS_OP_STORE"; -const std::string kEngineNameHostCpu = "DNN_VM_HOST_CPU_OP_STORE"; -const std::string kEngineNameGeLocal = "DNN_VM_GE_LOCAL_OP_STORE"; -const std::string kEngineNameAiCpu = "aicpu_ascend_kernel"; -const std::string kEngineNameAiCpuTf = "aicpu_tf_kernel"; -const std::string kEngineNameAiCore = "AIcoreEngine"; -const std::string kEngineNameDvpp = "dvpp_ops_kernel"; -const std::string kEngineNameDsa = "DSAEngine"; -const std::string kAtomicOpType = "DynamicAtomicAddrClean"; -const char_t *const kAICpuKernelLibName = "aicpu_kernel_lib_name"; -const char_t *const kPartiallySupported = "partially_supported"; - -// runtime2.0 lowering func -const std::string kAttrLowingFunc = "_ge_attr_lowering_func"; -const std::string kFFTSAiCoreLowerFunc = "ffts_ai_core_lower_func"; -const std::string kFFTSGraphLowerFunc = "ffts_graph_lower_func"; -const std::string kFFTSStaticGraphLowerFunc = "ffts_static_graph_lower_func"; -const std::string kFFTSMixL2LowerFunc = "ffts_mix_l2_lower_func"; -// runtime2.0 calculate func -const std::string kAttrCalcArgsSizeFunc = "_ge_attr_calculate_func"; -const std::string kFFTSMixL2CalcFunc = "ffts_mix_l2_calc_func"; - -const std::string kInputTensorIndexs = "input_tensor_indexs"; -const std::string kOutputTensorIndexs = "output_tensor_indexs"; -const std::string kShapeTypeStatic = "static"; -const std::string kShapeTypeDynamic = "dynamic"; -const std::string kAtomicPrefix = "_atomic"; - -constexpr uint64_t kInferSessionId = 0U; -constexpr uint64_t kReleaseFlag = 1U; -constexpr uint32_t kInvalidModelId = 0xFFFFFFFFU; -constexpr size_t kNumTaskWithAtomicAddrCleanTask = 2U; -constexpr uint32_t INVALID_MODEL_ID = 0xFFFFFFFFU; - -// dynamic execute mode -const char_t *const kLazyRecompile = "lazy_recompile"; -const char_t *const kIsCopyOuputAddr = "1"; - -constexpr size_t kMaxHostMemInputLen = 128U; // 64 aligned - -// memory policy -const std::string kBalanceMode = "BalanceMode"; -const std::string kMemoryPriority = "MemoryPriority"; -const std::set kValidValues = {"", kBalanceMode, kMemoryPriority}; - -const uint32_t kManualThread = 0U; -const uint32_t kAutoThread = 1U; - -// model deploy mode -const std::string kModelDeployModeSpmd = "SPMD"; - -// dsa -constexpr size_t kDSASetInputAddr = 0U; -constexpr size_t kDSAOutputAddrSize = 1U; -constexpr size_t kDSAWorkspaceAddrSize = 2U; -constexpr size_t kDSAInputAddrSize = 3U; -constexpr size_t kDSAArgsInputAddrSize = 4U; -constexpr size_t kDSAStateInputAddrSize = 5U; -constexpr size_t k32Bits = 32U; - -// mix -constexpr size_t kMixMultiKernelPcAddrCnt = 1U; -constexpr size_t kMixSingleKernelPcAddrCnt = 2U; -constexpr size_t kMixSingleOnlyKernelPcAddrCnt = 1U; -constexpr size_t kMixSingleKernelAicPcIndex = 0U; -constexpr size_t kMixSingleKernelAivPcIndex = 1U; - -// Data cache, including data address and length -struct DataBuffer { - void *data; // Data address - uint64_t length; // Data length - bool isDataSupportMemShare; - uint32_t placement; - - DataBuffer(void *const data_in, const uint64_t data_len, const bool is_support_mem_share = false, - const uint32_t data_placement = 0U) : data(data_in), length(data_len), - isDataSupportMemShare(is_support_mem_share), - placement(data_placement) {} - DataBuffer() : data(nullptr), length(0UL), isDataSupportMemShare(false), - placement(0U) {} -}; - -/// -/// @ingroup domi_ome -/// @brief External input data -/// -struct InputData { - uint32_t index; // Index of input data - uint32_t timestamp; // Data creation time - uint32_t timeout; // Processing timeout - uint32_t model_id; // Model ID required for data processing - uint64_t request_id = 0UL; // Request ID - std::vector blobs; // Actual input data, currently only supports one input - bool is_dynamic_batch = false; // Whether is dynamic batch size scene, default:false - std::string batch_label; // Gear used for current inference in dynamic batch scene - std::vector> shapes; // Input shapes -}; - -/// Output result structure definition -struct OutputData { - uint32_t index; // Index of input data - uint32_t model_id; // The model ID corresponding to the processing result - /// Output data cache, arranged in sequence of output operators. - /// If the operator has multiple outputs, - /// the data buffer order of the operator is the same as that defined in the - /// offline model - std::vector blobs; -}; - -// The definition of command data structure -struct Command { - std::string cmd_type; // Command type - std::vector cmd_params; // Command params - uint64_t module_index; // prof module - uint32_t cache_flag; // clear prof cache flag -}; - -// The definition of I/O shape description -struct ShapeDescription { - int64_t num = 0L; - int64_t channel = 0L; - int64_t height = 0L; - int64_t width = 0L; - std::vector dims; - std::vector> shape_ranges; -}; - -// Definition of input and output description information -struct InputOutputDescInfo { - std::string name; - uint64_t size; - uint32_t data_type; - ShapeDescription shape_info; -}; - -// Definition of model io dims -struct InputOutputDims { - std::string name; - size_t dim_num; - uint32_t size; - std::vector dims; -}; - -// Definition of model io dims -struct OriginInputInfo { - Format format; - DataType data_type; - uint32_t dim_num; -}; - -// The structure of AIPP info -struct AippConfigInfo { - int8_t aipp_mode; - int8_t input_format; - int32_t src_image_size_w; - int32_t src_image_size_h; - int8_t crop; - int32_t load_start_pos_w; - int32_t load_start_pos_h; - int32_t crop_size_w; - int32_t crop_size_h; - int8_t resize; - int32_t resize_output_w; - int32_t resize_output_h; - int8_t padding; - int32_t left_padding_size; - int32_t right_padding_size; - int32_t top_padding_size; - int32_t bottom_padding_size; - int8_t csc_switch; - int8_t rbuv_swap_switch; - int8_t ax_swap_switch; - int8_t single_line_mode; - int32_t matrix_r0c0; - int32_t matrix_r0c1; - int32_t matrix_r0c2; - int32_t matrix_r1c0; - int32_t matrix_r1c1; - int32_t matrix_r1c2; - int32_t matrix_r2c0; - int32_t matrix_r2c1; - int32_t matrix_r2c2; - int32_t output_bias_0; - int32_t output_bias_1; - int32_t output_bias_2; - int32_t input_bias_0; - int32_t input_bias_1; - int32_t input_bias_2; - int32_t mean_chn_0; - int32_t mean_chn_1; - int32_t mean_chn_2; - int32_t mean_chn_3; - float32_t min_chn_0; - float32_t min_chn_1; - float32_t min_chn_2; - float32_t min_chn_3; - float32_t var_reci_chn_0; - float32_t var_reci_chn_1; - float32_t var_reci_chn_2; - float32_t var_reci_chn_3; - int8_t support_rotation; - uint32_t related_input_rank; - uint32_t max_src_image_size; -}; - -// The structure of offline Modeldata -struct ModelData { - void *model_data = nullptr; // Model binary data start addr - uint64_t model_len = 0UL; // Model binary data length - int32_t priority = 0; // Model priority - std::string key; // Key path for encrypt model, Empty for unencrypt - std::string om_name; // om file name, used for data dump - std::string om_path; // om file path, used for concatenating file constant path - std::string weight_path; // weight path, used for load weight -}; - -struct ModelParam { - ModelParam() : priority(0), mem_base(0U), mem_size(0U), weight_base(0U), weight_size(0U), fixed_mem_base(0U), - fixed_mem_size(0U), p2p_fixed_mem_base(0U), p2p_fixed_mem_size(0U) {} - ModelParam(const int32_t pri, const uintptr_t m_base, const size_t m_len, const uintptr_t w_base, const size_t w_len) - : priority(pri), mem_base(m_base), mem_size(m_len), weight_base(w_base), weight_size(w_len), fixed_mem_base(0U), - fixed_mem_size(0U), p2p_fixed_mem_base(0U), p2p_fixed_mem_size(0U) {} - virtual ~ModelParam() = default; - - int32_t priority; - uintptr_t mem_base; - size_t mem_size; - uintptr_t weight_base; - size_t weight_size; - uintptr_t fixed_mem_base; - size_t fixed_mem_size; - uintptr_t p2p_fixed_mem_base; - size_t p2p_fixed_mem_size; -}; - -// The definition of Model information -struct ModelInfo { - uint32_t version = 0U; - std::string name; - bool is_encrypt = false; // 0:unencrypt, 1:encrypt - std::vector input_desc; - std::vector output_desc; - uint8_t reserved[3] = {0U}; // 3-byte reserved field -}; - -// Asynchronous callback interface, implemented by the caller -class GE_FUNC_VISIBILITY ModelListener { - public: - virtual ~ModelListener() = default; - ModelListener() = default; - ModelListener(const ModelListener &) = delete; - ModelListener& operator=(const ModelListener &) & = delete; - /// - /// @brief Asynchronous callback interface - /// @param [in] model_id Model ID of the callback - /// @param [in] data_index Index of the input_data - /// @param [in] resultCode Execution results - /// - virtual Status OnComputeDone(uint32_t model_id, uint32_t data_index, uint32_t result_code, - std::vector &outputs) = 0; - - virtual void SetCallback(const RunAsyncCallback &callback) { - (void)callback; - } - - virtual uint32_t GetResultCode() { return 0U; }; - - virtual Status ResetResult() { return SUCCESS; }; -}; - -// Profiling info of task -struct TaskDescInfo { - uint64_t prof_time; - std::string model_name; - std::string op_name; - std::string op_type; - uint32_t block_dim; - uint32_t task_id; - uint32_t stream_id; - std::string shape_type; - int64_t cur_iter_num; - std::string task_type; - std::vector input_format; - std::vector> input_shape; - std::vector input_data_type; - std::vector output_format; - std::vector> output_shape; - std::vector output_data_type; - uint32_t context_id = 0xFFFFFFFFU; -}; - -struct OpDescInfoId { - uint32_t task_id; - uint32_t stream_id; - uint32_t context_id; - uint32_t thread_id; - int32_t device_id; - - OpDescInfoId(): task_id(0U), stream_id(0U), context_id(UINT32_MAX), thread_id(UINT32_MAX), device_id(0) {} - OpDescInfoId(const uint32_t task, const uint32_t stream) - : OpDescInfoId() { - task_id = task; - stream_id = stream; - } - OpDescInfoId(const uint32_t task, const uint32_t stream, const uint32_t context, const uint32_t thread) - : OpDescInfoId(task, stream) { - context_id = context; - thread_id = thread; - } - // for exception dump(one process may has multi thread for mutli device ids) - OpDescInfoId(const uint32_t task, const uint32_t stream, const int32_t dev_id) - : task_id(task), stream_id(stream), context_id(UINT32_MAX), thread_id(UINT32_MAX), device_id(dev_id) {} - OpDescInfoId(const uint32_t task, const uint32_t stream, const uint32_t context, const uint32_t thread, - const int32_t dev_id) - : task_id(task), stream_id(stream), context_id(context), thread_id(thread), device_id(dev_id) {} -}; - -struct OpDescInfo { - std::string op_name; - std::string op_type; - OpDescInfoId id; - uint32_t imply_type = 0U; - uint32_t block_dim = 0U; - std::string op_file_path; - std::string dev_func; - std::string tvm_magic; - uint32_t tiling_key = 0U; - uintptr_t args = 0U; - size_t args_size = 0UL; - std::map cust_to_relevant_offset_; - std::string tiling_data; - bool is_mem_log; - std::vector space_addrs; - std::string node_info; - std::vector workspace_bytes; - std::vector input_format; - std::vector> input_shape; - std::vector input_data_type; - std::vector input_addrs; - std::vector input_size; - std::vector output_format; - std::vector> output_shape; - std::vector output_data_type; - std::vector output_addrs; - std::vector output_size; - bool is_host_args{false}; - std::string all_attrs; - std::string args_before_execute; -}; - -struct DumpBlacklist { - std::string name; - std::vector pos; -}; - -struct ModelDumpConfig { - std::string model_name; - std::vector layers; - std::vector watcher_nodes; - std::vector optype_blacklist; - std::vector opname_blacklist; - std::vector> dump_op_ranges; -}; - -struct DumpConfig { - std::string dump_path; - std::string dump_mode; - std::string dump_status; - std::string dump_op_switch; - std::string dump_debug; - std::string dump_step; - std::string dump_exception; - std::vector dump_list; - std::string dump_data; - std::string dump_level; - std::vector dump_stats; -}; - -struct QueueAttrs { - uint32_t queue_id; - int32_t device_type; // CPU NPU - int32_t device_id; - uint32_t logic_id {0U}; -}; - -struct InputAlignAttrs { - uint32_t align_max_cache_num; // 0 means align not enable - int32_t align_timeout; // -1 means never timeout - bool drop_when_not_align; - uint8_t res[3]; -}; -static_assert(std::is_pod::value, "The class InputAlignAttrs must be a POD"); - -struct ModelQueueParam { - uint32_t group_total_count{1}; - uint32_t group_index{0U}; - uint32_t group_policy{0U}; - std::vector input_queues; - std::vector output_queues; - std::vector input_fusion_offsets; - std::vector input_events; - std::vector output_events; - std::vector input_queues_attrs; - std::vector output_queues_attrs; - QueueAttrs status_output_queue; - uint32_t model_uuid {0U}; - bool is_dynamic_sched {false}; - bool need_report_status {false}; - InputAlignAttrs input_align_attrs{}; - bool is_head {true}; - bool no_need_check_inputs {false}; // 废弃并删除,使用need_check_inputs代替 - bool need_check_inputs {false}; - bool need_model_config {false}; - bool mark_dump_step {false}; - bool io_with_tensor_desc {false}; - bool copy_inputs_for_non_zero_copy {false}; -}; - -// internal options -// 1: Graph resource evaluation does not limit model memory size. -const char_t *const EVALUATE_GRAPH_RESOURCE_MODE = "ge.evaluateGraphResourceMode"; - -// 3: Config all resource and device mesh -const char_t *const RESOURCE_CONFIG_PATH = "ge.resourceConfigPath"; - -// 5: auto recompute attribute -const char_t *const RECOMPUTE = "ge.recompute"; -const char_t *const GRAPH_SLICE_MODE = "ge.graphSliceMode"; - -// 6: Topological Sorting Mode -const char_t *const OPTION_TOPOSORTING_MODE = "ge.topoSortingMode"; - -const char_t *const OPTION_EXEC_RANK_TABLE = "ge.exec.rankTable"; -const char_t *const OPTION_EXEC_HCOM_GROUPLIST = "ge.exec.hcomGrouplist"; -const char_t *const OPTION_EXEC_HCOM_RANK_MAPPING = "ge.exec.hcomRankMapping"; - -const char_t *const OPTION_NUMA_CONFIG = "ge.numaConfig"; - -// 7: config format mode(expirimental option) -const char_t *const OPTION_EXEC_FORMAT_MODEL = "ge.exec.formatMode"; - -// 8: config build graph mode(online or offline) -const char_t *const OPTION_BUILD_GRAPH_MODE = "ge.buildGraphMode"; - -const std::set ir_builder_suppported_options_inner = {EVALUATE_GRAPH_RESOURCE_MODE, - RESOURCE_CONFIG_PATH, - RECOMPUTE, - OPTION_TOPOSORTING_MODE, - GRAPH_SLICE_MODE}; -} // namespace ge -#endif // INC_COMMON_GE_TYPES_H_ diff --git a/inc/common/large_bm.h b/inc/common/large_bm.h deleted file mode 100644 index 907c63d4e9a18dc4bde4dfa86f7228e52e2cb255..0000000000000000000000000000000000000000 --- a/inc/common/large_bm.h +++ /dev/null @@ -1,57 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_COMMON_LARGE_BM_H_ -#define INC_COMMON_LARGE_BM_H_ - -#include -#include - -/* LargeBitmap create a way to generate bitmaps larger than 64bit. */ -namespace ge { -class LargeBitmap { -public: - explicit LargeBitmap(const size_t &size); - - ~LargeBitmap() = default; - - bool operator==(const LargeBitmap &another_bm) const; - - bool operator!=(const LargeBitmap &another_bm) const; - - // set all vector to specific value - void SetValues(const uint64_t &value); - - // Get the value on position index - bool GetBit(const size_t &index) const; - - // Set the value on position index to 1 - void SetBit(const size_t &index); - - // Combine two bitmap with the following rule. - // If one bit of either one of the two bitmaps is 1, - // the result of final bitmap is 1. - void Or(const LargeBitmap &another_bm); - - // Combine two bitmap with the following rule. - // If one bit of either one of the two bitmaps is 0, - // the result of final bitmap is 0. - void And(const LargeBitmap &another_bm); - - void ClearBit(size_t bit_idx); - - void ResizeBits(size_t new_size); -private: - // Number of element in vector bits - size_t size_; - - std::vector bits_; -}; -} -#endif // INC_COMMON_LARGE_BM_H_ diff --git a/inc/common/npu_error_define.h b/inc/common/npu_error_define.h deleted file mode 100644 index f88c3ae9dc98ccc8057920932c483f5c8b6c4d61..0000000000000000000000000000000000000000 --- a/inc/common/npu_error_define.h +++ /dev/null @@ -1,87 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_COMMON_NPU_ERROR_DEFINE_H_ -#define INC_COMMON_NPU_ERROR_DEFINE_H_ - -typedef enum tagHiAiNpuLocal { - HIAI_HOST = 1, - HIAI_DEVICE = 2, -} HiAiNpuLocal; - -typedef enum tagHiAiNpuCodeType { - ERROR_CODE = 1, - EXCEPTION_CODE = 2, -} HiAiNpuCodeType; - -typedef enum tagHiAiNpuErrLevel { - NONE_LEVEL = 0, - SUGGESTION_LEVEL = 1, - NORMAL_LEVEL = 2, - SERIOUS_LEVEL = 3, - CRITICAL_ERROR = 4, -} HiAiNpuErrLevel; - -typedef enum tagHiAiNpuModuleId { - HIAI_DRIVER = 1, - HIAI_CTRLCPU = 2, - HIAI_TS = 3, - HIAI_RUNTIME = 4, - HIAI_AICPU = 5, - HIAI_CCE = 6, - HIAI_TVM = 7, - HIAI_FRAMEWORK = 8, - HiAI_ENGINE = 9, - HIAI_DVPP = 10, - HIAI_AIPP = 11, - HIAI_LOWPOWER = 12, - HIAI_MDC = 13, - HIAI_COMPILE = 14, - HIAI_TOOLCHIAN = 15, - HIAI_ALG = 16, - HIAI_PROFILING = 17, - HIAI_HCCL = 18, - HIAI_SIMULATION = 19, - HIAI_BIOS = 20, - HIAI_SEC = 21, - HIAI_TINY = 22, - HIAI_DP = 23, -} HiAiNpuModuleId; - -/* bit 31-bit30 to be hiai local */ -#define HIAI_NPULOCAL_MASK 0xC0000000 -#define SHIFT_LOCAL_MASK 30 -#define HIAI_NPULOCAL_VAL_MASK 0x3 -/* bit 29 -bit28 to be hiai aicpu code type */ -#define HIAI_CODE_TYPE_MASK 0x30000000 -#define SHIFT_CODE_MASK 28 -#define HIAI_CODE_TYPE_VAL_MASK 0x3 -/* bit 27 -bit25 to be hiai error level */ -#define HIAI_ERROR_LEVEL_MASK 0x0E000000 -#define SHIFT_ERROR_LVL_MASK 25 -#define HIAI_ERROR_LEVEL_VAL_MASK 0x7 -/* bit 24 -bit17 to be hiai mod */ -#define HIAI_MODE_ID_MASK 0x01FE0000 -#define SHIFT_MODE_MASK 17 -#define HIAI_MODE_ID_VAL_MASK 0xFF - -#define HIAI_NPU_LOC_BIT(a) \ - (HIAI_NPULOCAL_MASK & ((unsigned int)((HiAiNpuLocal)(a)) & HIAI_NPULOCAL_VAL_MASK) << SHIFT_LOCAL_MASK) -#define HIAI_NPU_CODE_TYPE_BIT(a) \ - (HIAI_CODE_TYPE_MASK & ((unsigned int)((HiAiNpuCodeType)(a)) & HIAI_CODE_TYPE_VAL_MASK) << SHIFT_CODE_MASK) -#define HIAI_NPU_ERR_LEV_BIT(a) \ - (HIAI_ERROR_LEVEL_MASK & ((unsigned int)((HiAiNpuErrLevel)(a)) & HIAI_ERROR_LEVEL_VAL_MASK) << SHIFT_ERROR_LVL_MASK) -#define HIAI_NPU_MOD_ID_BIT(a) \ - (HIAI_MODE_ID_MASK & ((unsigned int)((HiAiNpuModuleId)(a)) & HIAI_MODE_ID_VAL_MASK) << SHIFT_MODE_MASK) - -#define HIAI_NPU_ERR_CODE_HEAD(npuLocal, codeType, errLevel, moduleId) \ - (HIAI_NPU_LOC_BIT(npuLocal) + HIAI_NPU_CODE_TYPE_BIT(codeType) + HIAI_NPU_ERR_LEV_BIT(errLevel) + \ - HIAI_NPU_MOD_ID_BIT(moduleId)) - -#endif // INC_COMMON_NPU_ERROR_DEFINE_H_ diff --git a/inc/common/opskernel/ge_task_info.h b/inc/common/opskernel/ge_task_info.h deleted file mode 100644 index eec374718defa7921619ae81b84057d9536c77ea..0000000000000000000000000000000000000000 --- a/inc/common/opskernel/ge_task_info.h +++ /dev/null @@ -1,86 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_COMMON_OPSKERNEL_GE_TASK_INFO_H_ -#define INC_COMMON_OPSKERNEL_GE_TASK_INFO_H_ - -#include -#include -#include -#include "runtime/rt.h" -#include "graph/op_desc.h" - -namespace ge { -struct HcclDumpInfo { - uint32_t task_id; - uint32_t stream_id; - uint32_t sub_task_type; - void *input_addr; - uint64_t input_size; - void *output_addr; - uint64_t output_size; -}; - -struct DvppInfo { - OpDescPtr op_desc; - std::vector io_addrs; - uint32_t sqe[16]; -}; - -// when need to eliminate GETaskKernelHcclInfo, so not need DAVINCI_TRAIN/DAVINCI_CLOUD -struct GETaskKernelHcclInfo { - std::string input_name; - std::string hccl_type; - void *inputDataAddr; - void *outputDataAddr; - void *workSpaceAddr; - int64_t count; - int32_t dataType; - int32_t opType; - int64_t rootId; - uint64_t workSpaceMemSize; - std::vector dims; - std::vector hcclStreamList; - std::vector hccl_dump_info; - std::vector global_workspace_addr; - uint32_t hcclQosCfg; - std::vector inputDataAddrs; - std::vector outputDataAddrs; - std::vector workSpaceAddrs; - std::vector workSpaceMemSizes; - std::vector inputZeroCopyFlags; - std::vector outputZeroCopyFlags; -}; - -struct GETaskInfo { - uint32_t id; - uint16_t type; - uint32_t streamID; - void *stream; // rtKernelLaunch input argument - void *event; - void *privateDef; - uint32_t privateDefLen; - void *opsKernelStorePtr; - std::vector kernelHcclInfo; - DvppInfo dvpp_info; - bool needRefresh{false}; - std::vector rt_attached_streams; -}; - -struct HcomRemoteAccessAddrInfo -{ - uint32_t remotetRankID; - uint64_t remoteAddr; // host embedding table address - uint64_t localAddr; // device HBM address - uint64_t length; // memory Length in Bytes -}; - - -} // namespace ge -#endif // INC_COMMON_OPSKERNEL_GE_TASK_INFO_H_ diff --git a/inc/common/opskernel/ops_kernel_builder.h b/inc/common/opskernel/ops_kernel_builder.h deleted file mode 100644 index 4213f276ffbf4a6eb9de2d5c39d995ad717924d9..0000000000000000000000000000000000000000 --- a/inc/common/opskernel/ops_kernel_builder.h +++ /dev/null @@ -1,81 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_COMMON_OPSKERNEL_OPS_KERNEL_BUILDER_H_ -#define INC_COMMON_OPSKERNEL_OPS_KERNEL_BUILDER_H_ - -#include "external/ge_common/ge_api_error_codes.h" -#include "cce/aicpu_engine_struct.h" -#include "common/opskernel/ops_kernel_info_types.h" -#include "graph/node.h" -#include "external/ge_common/ge_api_types.h" -#include "proto/task.pb.h" - -namespace ge { -class OpsKernelBuilder { - public: - enum class Mode : uint32_t { - kNormal, - kFfts, - kFftsPlus - }; - OpsKernelBuilder() = default; - virtual ~OpsKernelBuilder() = default; - OpsKernelBuilder(const OpsKernelBuilder &) = delete; - OpsKernelBuilder(OpsKernelBuilder &&) = delete; - OpsKernelBuilder &operator=(const OpsKernelBuilder &)& = delete; - OpsKernelBuilder &operator=(OpsKernelBuilder &&)& = delete; - - // initialize OpsKernelBuilder - virtual Status Initialize(const std::map &options) = 0; - - // finalize OpsKernelBuilder - virtual Status Finalize() = 0; - - // memory allocation requirement - virtual Status CalcOpRunningParam(Node &node) = 0; - - // generate task for op - virtual Status GenerateTask(const Node &node, RunContext &context, - std::vector &tasks) = 0; - - // generate task for op with different mode - virtual Status GenerateTask(const Node &node, RunContext &context, std::vector &tasks, - OpsKernelBuilder::Mode) { - (void)node; - (void)context; - (void)tasks; - return SUCCESS; - } - - // update task which need stream event info, after SplitStream. Only change field in task, forbid change tasks size - virtual Status UpdateTask(const Node &node, std::vector &tasks) { - (void)node; - (void)tasks; - return SUCCESS; - } - - // only call aicpu interface to generate task struct - virtual Status GenSingleOpRunTask(const NodePtr &node, STR_FWK_OP_KERNEL &task, std::string &task_info) { - (void)node; - (void)task; - (void)task_info; - return FAILED; - } - - // only call aicpu interface to generate task struct - virtual Status GenMemCopyTask(const uint64_t count, STR_FWK_OP_KERNEL &task, std::string &task_info) { - (void)count; - (void)task; - (void)task_info; - return FAILED; - } -}; -} // namespace ge -#endif // INC_COMMON_OPSKERNEL_OPS_KERNEL_BUILDER_H_ diff --git a/inc/common/opskernel/ops_kernel_info_store.h b/inc/common/opskernel/ops_kernel_info_store.h deleted file mode 100644 index f9171baee8f4f7d397e416afbb37a9490b183148..0000000000000000000000000000000000000000 --- a/inc/common/opskernel/ops_kernel_info_store.h +++ /dev/null @@ -1,137 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_STORE_H_ -#define INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_STORE_H_ - -#include -#include -#include -#include -#include - -#include "common/opskernel/ge_task_info.h" -#include "common/opskernel/ops_kernel_info_types.h" -#include "common/ge_common/ge_inner_error_codes.h" -#include "graph/node.h" -#include "external/graph/operator.h" - -namespace ge { -class OpsKernelInfoStore { - public: - OpsKernelInfoStore() = default; - - virtual ~OpsKernelInfoStore() = default; - OpsKernelInfoStore(const OpsKernelInfoStore &) = delete; - OpsKernelInfoStore(OpsKernelInfoStore &&) = delete; - OpsKernelInfoStore &operator=(const OpsKernelInfoStore &)& = delete; - OpsKernelInfoStore &operator=(OpsKernelInfoStore &&)& = delete; - - // initialize opsKernelInfoStore - virtual Status Initialize(const std::map &options) = 0; - - // close opsKernelInfoStore - virtual Status Finalize() = 0; /*lint -e148*/ - - virtual Status CreateSession(const std::map &session_options) { - (void)session_options; - return SUCCESS; - } - - virtual Status DestroySession(const std::map &session_options) { - (void)session_options; - return SUCCESS; - } - - // get all opsKernelInfo - virtual void GetAllOpsKernelInfo(std::map &infos) const = 0; - - // whether the opsKernelInfoStore is supported based on the operator attribute - virtual bool CheckSupported(const OpDescPtr &opDescPtr, std::string &un_supported_reason) const = 0; - - virtual bool CheckAccuracySupported(const OpDescPtr &opDescPtr, std::string &un_supported_reason, - const bool realQuery = false) const { - (void)realQuery; - return CheckSupported(opDescPtr, un_supported_reason); - } - // opsFlag opsFlag[0] indicates constant folding is supported or not - virtual void opsFlagCheck(const ge::Node &node, std::string &opsFlag) { - (void)node; - (void)opsFlag; - }; - - // only call fe engine interface to compile single op - virtual Status CompileOp(std::vector &node_vec) { - (void) node_vec; - return SUCCESS; - } - virtual Status CompileOpRun(std::vector &node_vec) { - (void)node_vec; - return SUCCESS; - } - - // prepare task for op - virtual Status PrepareTaskAsync(GETaskInfo &task) { - (void)task; - return SUCCESS; - } - - // load task for op - virtual Status LoadTask(GETaskInfo &task) { - (void)task; - return SUCCESS; - } - - virtual bool CheckSupported(const ge::NodePtr &node, std::string &un_supported_reason) const { - if (node == nullptr) { - return false; - } - return CheckSupported(node->GetOpDesc(), un_supported_reason); - } - - virtual bool CheckAccuracySupported(const ge::NodePtr &node, std::string &un_supported_reason, - const bool realQuery = false) const { - (void)realQuery; - if (node == nullptr) { - return false; - } - return CheckAccuracySupported(node->GetOpDesc(), un_supported_reason, realQuery); - } - // Set cut support info - virtual Status SetCutSupportedInfo(const ge::NodePtr &node) { - (void)node; - return SUCCESS; - } - // unload task for op - virtual Status UnloadTask(GETaskInfo &task) { - (void)task; - return SUCCESS; - } - - // fuzz compile interface - virtual Status FuzzCompileOp(std::vector &node_vec) { - (void) node_vec; - return SUCCESS; - } - - // Query information such as foramt/dtype/impl supported by operators (extensible) - virtual bool GetNodeSupportInfo(const OperatorPtr &op, std::string &support_info) { - (void)op; - (void)support_info; - return false; - } - - virtual bool CheckSupported(const ge::NodePtr &node, std::string &un_supported_reason, - CheckSupportFlag &flag) const { - (void)flag; - return CheckSupported(node, un_supported_reason); - } -}; -} // namespace ge -#endif // INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_STORE_H_ diff --git a/inc/common/opskernel/ops_kernel_info_types.h b/inc/common/opskernel/ops_kernel_info_types.h deleted file mode 100644 index d3f7b4d07685408e415beb47111d0c90c051b591..0000000000000000000000000000000000000000 --- a/inc/common/opskernel/ops_kernel_info_types.h +++ /dev/null @@ -1,58 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_TYPES_H_ -#define INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_TYPES_H_ - -#include -#include -#include -#include "graph/buffer.h" -#include "runtime/rt_model.h" - -namespace ge { -/*lint -e148*/ -struct RunContext { - uint64_t sessionId; - uint64_t dataMemSize; - uint8_t *dataMemBase; - std::map mem_type_data_mem_size; - std::map mem_type_data_mem_base; - uint64_t weightMemSize; - uint8_t *weightMemBase; - ge::Buffer weightsBuffer; -}; - -/*lint +e148*/ -struct Task { - uint32_t id; - uint16_t type; - void *stream; - void *event; -}; - -struct OpInfo { - std::string engine; // which engin - /*lint -e148*/ - std::string opKernelLib; // which opsKernelStore - int32_t computeCost; // compute cost - bool flagPartial; // whether to support is related to shape - bool flagAsync; // Whether to support asynchronous - bool isAtomic; // whether to support atomic addr clean - std::string opFileName; // op file name - std::string opFuncName; // op function name -}; - -enum class CheckSupportFlag : uint32_t { - kDefault = 0, - kNotSupportDynamicShape -}; -} // namespace ge - -#endif // INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_TYPES_H_ diff --git a/inc/common/optimizer/graph_optimizer.h b/inc/common/optimizer/graph_optimizer.h deleted file mode 100644 index 8ea9ebea3b40c2930cfd41910ba53822f3f0f9dd..0000000000000000000000000000000000000000 --- a/inc/common/optimizer/graph_optimizer.h +++ /dev/null @@ -1,130 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_H_ -#define INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_H_ - -#include -#include -#include "graph_optimizer_types.h" -#include "optimize_utility.h" -#include "common/ge_common/ge_inner_error_codes.h" -#include "common/opskernel/ops_kernel_info_types.h" -#include "graph/compute_graph.h" -#include "graph/op_kernel_bin.h" - -/*lint -e148*/ -namespace ge { -class GraphOptimizer { - public: - virtual ~GraphOptimizer() {} - - // initialize graphOptimizer - virtual Status Initialize(const std::map &options, - OptimizeUtility *const optimize_utility) = 0; - - // close graphOptimizer - virtual Status Finalize() = 0; - - // init process for optimize graph every time because options may different in different build process - // 当前引擎获取编译option是在OptimizeGraphPrepare接口中获取,该接口默认会过滤vector engine。 - // 当前出现问题场景是子图优化阶段因为算子融合直接选择了vector engine的场景,出现了vector engine获取不到编译option导致问题。 - // 当前决策新增OptimizeGraphInit接口,该接口不会过滤引擎,全部调用.这样获取到build option操作就从OptimizeGraphPrepare剥离。 - virtual Status OptimizeGraphInit(ComputeGraph& graph) { - (void)graph; - return SUCCESS; - } - - // optimize original graph for FE quant optimize - virtual Status OptimizeGraphPrepare(ComputeGraph& graph) { - (void)graph; - return SUCCESS; - } - - // optimize graph after normalization, include multi dims and pre/post process - virtual Status OptimizeAfterGraphNormalization(const ComputeGraphPtr& graph) { - (void)graph; - return SUCCESS; - } - - // optimize graph before build for RTS - virtual Status OptimizeGraphBeforeBuild(ComputeGraph& graph) { - (void)graph; - return SUCCESS; - } - - // optimize original graph, using in graph preparation stage - virtual Status OptimizeOriginalGraph(ComputeGraph &graph) = 0; - - // optimize original graph, using for conversion operator insert in graph preparation stage - virtual Status OptimizeOriginalGraphJudgeInsert(ComputeGraph &graph) { - (void)graph; - return SUCCESS; - } - - // optimize fused graph - virtual Status OptimizeFusedGraph(ComputeGraph &graph) = 0; - - // optimize whole graph, using after graph merged stage - virtual Status OptimizeWholeGraph(ComputeGraph &graph) = 0; - - // get attribute of graph optimizer - virtual Status GetAttributes(GraphOptimizerAttribute &attrs) const = 0; - - // optimize streamed Graph - virtual Status OptimizeStreamGraph(ComputeGraph &graph, const RunContext &context) { - (void)graph; - (void)context; - return SUCCESS; - } - - // optimize streamed whole Graph - virtual Status OptimizeStreamedWholeGraph(ComputeGraph &graph) { - (void)graph; - return SUCCESS; - } - - // op compile - virtual Status OptimizeFusedGraphAfterGraphSlice(ComputeGraph &graph) { - (void)graph; - return SUCCESS; - } - - // optimize whole graph, using after stage1 - virtual Status OptimizeAfterStage1(ComputeGraph &graph) { - (void)graph; - return SUCCESS; - } - - // recover compile result of precompiled op - using KernelLookup = std::function; - virtual Status OptimizeSubgraphOfPrecompiledOp(ComputeGraph &graph, const KernelLookup &lookup) { - static_cast(graph); - static_cast(lookup); - return SUCCESS; - } - - // 为避免子图优化中多线程操作导致的数据读写冲突,提供子图优化前后的单线程接口,由引擎实现以实现改图功能 - virtual Status OptimizeSubgraphPreProc(ComputeGraph &graph) { - (void)graph; - return SUCCESS; - } - virtual Status OptimizeSubgraphPostProc(ComputeGraph &graph) { - (void)graph; - return SUCCESS; - } - // 格式选择接口,从OptimizeOriginalGraphJudgeInsert中独立出来格式选择能力,OptimizeOriginalGraphJudgeInsert接口实现精度选择能力 - virtual Status OptimizeOriginalGraphJudgeFormatInsert(ComputeGraph &graph) { - (void)graph; - return SUCCESS; - } -}; -} // namespace ge -/*lint +e148*/ -#endif // INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_H_ diff --git a/inc/common/optimizer/graph_optimizer_types.h b/inc/common/optimizer/graph_optimizer_types.h deleted file mode 100644 index 5b51a3aff02fc610409dfdec97702ca1cb9ce97c..0000000000000000000000000000000000000000 --- a/inc/common/optimizer/graph_optimizer_types.h +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_TYPES_H_ -#define INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_TYPES_H_ - -#include - -namespace ge { -enum OPTIMIZER_SCOPE { - UNIT = 0, - ENGINE, -}; - -struct GraphOptimizerAttribute { - std::string engineName; - OPTIMIZER_SCOPE scope; -}; -} // namespace ge - -#endif // INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_TYPES_H_ diff --git a/inc/common/optimizer/optimize_utility.h b/inc/common/optimizer/optimize_utility.h deleted file mode 100644 index 646b116e37241963f8df9f53065b7a43e554ca43..0000000000000000000000000000000000000000 --- a/inc/common/optimizer/optimize_utility.h +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_COMMON_OPTIMIZER_OPTIMIZE_UTILITY_H_ -#define INC_COMMON_OPTIMIZER_OPTIMIZE_UTILITY_H_ - -#include "common/ge_common/ge_inner_error_codes.h" -#include "graph/compute_graph.h" - -namespace ge { -class OptimizeUtility { - public: - virtual ~OptimizeUtility() = default; - - // Deprecated: will delete later. Graph infershape util - virtual Status InferShape(ComputeGraph &compute_graph) { - (void)compute_graph; - return SUCCESS; - } - - // Graph infershape util - virtual Status InferShape(const ComputeGraphPtr &compute_graph) = 0; - - // Mlti Dims and pre/post process - virtual Status MultiDimsProcess(const ComputeGraphPtr &compute_graph) { - (void)compute_graph; - return SUCCESS; - } - - // Constant folding - virtual Status ConstantFolding(NodePtr &node) { - (void)node; - return SUCCESS; - } -}; -} // namespace ge -#endif // INC_COMMON_OPTIMIZER_OPTIMIZE_UTILITY_H_ diff --git a/inc/common/screen_printer.h b/inc/common/screen_printer.h deleted file mode 100644 index 1e80f9c3c25d62ae8abcdaa363b0b13cc1d391c8..0000000000000000000000000000000000000000 --- a/inc/common/screen_printer.h +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_INC_COMMON_SCREEN_PRINTER_H_ -#define METADEF_INC_COMMON_SCREEN_PRINTER_H_ - -#include -#include -#include -#include "external/graph/types.h" - -namespace ge { -class ScreenPrinter { - public: - static ScreenPrinter &GetInstance(); - void Log(const char *fmt, ...); - void Init(const std::string &print_mode); - - private: - ScreenPrinter() = default; - ~ScreenPrinter() = default; - - ScreenPrinter(const ScreenPrinter &) = delete; - ScreenPrinter(const ScreenPrinter &&) = delete; - ScreenPrinter &operator=(const ScreenPrinter &)& = delete; - ScreenPrinter &operator=(const ScreenPrinter &&)& = delete; - - enum class PrintMode : uint32_t { - ENABLE = 0U, - DISABLE = 1U - }; - PrintMode print_mode_ = PrintMode::ENABLE; - std::mutex mutex_; -}; - -#define SCREEN_LOG(fmt, ...) \ - do { \ - ScreenPrinter::GetInstance().Log(fmt, ##__VA_ARGS__); \ - } while (false) -} // namespace ge -#endif // METADEF_INC_COMMON_SCREEN_PRINTER_H_ diff --git a/inc/common/sgt_slice_type.h b/inc/common/sgt_slice_type.h deleted file mode 100644 index 6b62f7a475f4bd35e5aeae792ef0742e81c07b6f..0000000000000000000000000000000000000000 --- a/inc/common/sgt_slice_type.h +++ /dev/null @@ -1,95 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_COMMON_SGT_SLICE_TYPES_H_ -#define INC_COMMON_SGT_SLICE_TYPES_H_ - -#include -#include -#include -#include - -namespace ffts { -const std::string kAttrSgtJsonInfo = "_sgt_json_info"; -const std::string kAttrSgtStructInfo = "_sgt_struct_info"; -const std::string kAttrSgtStructInfoDy = "_sgt_struct_info_dy"; -const size_t kSgtTillingNum = 2U; - -struct OpCut { - int16_t split_cut_idx = -1; - int16_t reduce_cut_idx = -1; - int64_t cut_id = -1; -}; - -struct DimRange { - int64_t lower; - int64_t higher; - bool operator==(const DimRange& dim_range) const { - return (this->higher == dim_range.higher) && (this->lower == dim_range.lower); - } -}; - -enum class AtomicType { - None = 0, - ADD = 1, - SUB, - MUL, - DIV -}; - -struct ThreadSliceMap { - uint32_t thread_scope_id; - bool is_first_node_in_topo_order; - uint32_t thread_mode; - uint32_t node_num_in_thread_scope; - bool is_input_node_of_thread_scope; - bool is_output_node_of_thread_scope; - std::vector>> ori_input_tensor_shape; - std::vector>> ori_output_tensor_shape; - std::string original_node; - uint32_t slice_instance_num; - uint32_t parallel_window_size; - uint32_t thread_id; - std::vector>> dependencies; - std::vector core_num; - std::vector cut_type; - std::vector atomic_types; - std::vector same_atomic_clean_nodes; - std::vector input_axis; - std::vector output_axis; - std::vector input_tensor_indexes; - std::vector output_tensor_indexes; - std::vector>> input_tensor_slice; - std::vector>> output_tensor_slice; - std::vector>> ori_input_tensor_slice; - std::vector>> ori_output_tensor_slice; - std::vector> input_cut_list; - std::vector> output_cut_list; - ThreadSliceMap() : thread_scope_id(1U), is_first_node_in_topo_order(false), thread_mode(0U), - node_num_in_thread_scope(1U), is_input_node_of_thread_scope(false), is_output_node_of_thread_scope(false), - slice_instance_num(1U), parallel_window_size(1U), thread_id(0U) {} - bool GetThreadMode() const { - return (thread_mode == 0U) ? false : true; - } -}; - -struct ThreadSliceMapDy { - uint32_t slice_instance_num; - uint32_t parallel_window_size; - std::vector input_tensor_indexes; - std::vector output_tensor_indexes; - std::vector>> input_tensor_slice; - std::vector>> output_tensor_slice; - ThreadSliceMapDy() : slice_instance_num(1U), parallel_window_size(1U) {} -}; - -using ThreadSliceMapPtr = std::shared_ptr; -using ThreadSliceMapDyPtr = std::shared_ptr; -} // namespace ffts -#endif // INC_COMMON_SGT_SLICE_TYPES_H_ diff --git a/inc/common/util/sanitizer_options.h b/inc/common/util/sanitizer_options.h deleted file mode 100644 index 8c84f8b50087b7ecdeb6437a11595a2886240769..0000000000000000000000000000000000000000 --- a/inc/common/util/sanitizer_options.h +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef COMMON_UTILS_SANITIZERS_SANITIZER_OPTIONS_H_ -#define COMMON_UTILS_SANITIZERS_SANITIZER_OPTIONS_H_ - -// active if -fsanitize=address -#if defined(ONLY_COMPILE_OPEN_SRC) && defined(__SANITIZE_ADDRESS__) -#include "sanitizer/lsan_interface.h" -/* 如果业务代码中存在已知的内存泄漏的代码块, 并且允许这部分内存泄漏存在, 可在代码块首尾添加开关, - * 控制地址消毒器关闭与开启. - * 开关仅在蓝区CI场景可用,仅在当前thread有效. - * DT_DETECT_LEAKS_OFF(); - * // code block with memory leak - * DT_DETECT_LEAKS_ON(); - */ -#define DT_DETECT_LEAKS_OFF() \ - do { \ - __lsan_disable(); \ - } while (0) -#define DT_DETECT_LEAKS_ON() \ - do { \ - __lsan_enable(); \ - } while (0) -#define DT_DO_DETECT_LEAKS() \ - do { \ - __lsan_do_leak_check(); \ - } while (0) -#else -#define DT_DETECT_LEAKS_OFF() \ - do { \ - } while (0) -#define DT_DETECT_LEAKS_ON() \ - do { \ - } while (0) -#define DT_DO_DETECT_LEAKS() \ - do { \ - } while (0) -#endif - -#define DT_ALLOW_LEAKS_GUARD(name) ::ge::LeaksGuarder leaks_guard_for_##name - -namespace ge { -class LeaksGuarder { - public: - LeaksGuarder(const LeaksGuarder &) = delete; - LeaksGuarder &operator=(const LeaksGuarder &) = delete; - - LeaksGuarder() { - DT_DETECT_LEAKS_OFF(); - } - - ~LeaksGuarder() { - DT_DETECT_LEAKS_ON(); - } -}; - -} // namespace ge - -#endif // COMMON_UTILS_SANITIZERS_SANITIZER_OPTIONS_H_ diff --git a/inc/common/util/trace_manager/trace_manager.h b/inc/common/util/trace_manager/trace_manager.h deleted file mode 100644 index e2ad0ae7871ec3853a07d08f771223259382bdf7..0000000000000000000000000000000000000000 --- a/inc/common/util/trace_manager/trace_manager.h +++ /dev/null @@ -1,110 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef COMMON_UTIL_TRACE_MANAGER_TRACE_MANAGER_H_ -#define COMMON_UTIL_TRACE_MANAGER_TRACE_MANAGER_H_ - -#include -#include -#include -#include -#include -#include "common/ge_common/util.h" - -namespace ge { -#define TRACE_GEN_RECORD(owner, action, graph_name, node_name, node_data, tensor_index, tensor_data, content) \ - do { \ - if (TraceManager::GetInstance().IsTraceEnabled()) { \ - if (TraceManager::GetTraceHeader().size() == 0) { \ - GELOGD("[Check][Param] owner and stage have not been set"); \ - } else { \ - std::stringstream ss; \ - ss << owner << "," << action << "," << graph_name << "," << node_name << "," << node_data << "," \ - << tensor_index << "," << tensor_data << "," << content; \ - TraceManager::GetInstance().AddTrace(ss.str()); \ - } \ - } \ - } while (false) - -using char_t = char; - -constexpr uint64_t kTraceSaveTriggerNum = 5000U; - -enum class ReadyPart { A, B, None }; - -class TraceManager { - public: - static TraceManager &GetInstance(); - - void AddTrace(std::string &&trace_info); - - bool IsTraceEnabled() const { - return enabled_; - } - void SetTraceOwner(const std::string &owner, const std::string &stage, const std::string &graph_name); - void ClearTraceOwner(); - static inline const std::string &GetTraceHeader() { - return trace_header_; - } - static inline const std::string &GetOutGraphName() { - return graph_name_; - } - - private: - TraceManager(); - ~TraceManager(); - TraceManager(const TraceManager &) = delete; - TraceManager(TraceManager &&) = delete; - TraceManager &operator=(const TraceManager &) = delete; - TraceManager &operator=(TraceManager &&) = delete; - Status Initialize(const char_t *file_save_path); - void Finalize(); - - std::string NextFileName(); - void SaveTraceBufferToFile(const ReadyPart ready_part); - void SaveBufferToFileThreadFunc(); - - static thread_local std::string trace_header_; - static thread_local std::string graph_name_; - - std::atomic enabled_{false}; - std::vector trace_array_; - std::atomic trace_index_{0}; - std::atomic total_saved_nums_{0}; - std::atomic part1_ready_nums_{0}; - std::atomic part2_ready_nums_{0}; - std::string trace_save_file_path_; - std::string current_saving_file_name_; - uint64_t current_file_saved_nums_ = 0; - ReadyPart ready_part_ = ReadyPart::None; - - std::mutex mu_; - std::thread save_thread_; - std::atomic stopped_{false}; - std::condition_variable data_ready_var_; -}; - -class TraceOwnerGuard { - public: - TraceOwnerGuard(const std::string &owner, const std::string &stage, const std::string &graph_name) { - TraceManager::GetInstance().SetTraceOwner(owner, stage, graph_name); - } - ~TraceOwnerGuard() { - TraceManager::GetInstance().ClearTraceOwner(); - } - TraceOwnerGuard(const TraceOwnerGuard &) = delete; - TraceOwnerGuard(TraceOwnerGuard &&) = delete; - TraceOwnerGuard &operator=(const TraceOwnerGuard &) = delete; - TraceOwnerGuard &operator=(TraceOwnerGuard &&) = delete; -}; - -#define TRACE TraceManager::GetInstance() -#define TRACE_HEADER TraceManager::GetTraceHeader() -} // namespace ge -#endif // COMMON_UTIL_TRACE_MANAGER_TRACE_MANAGER_H_ diff --git a/inc/exe_graph/lowering/bg_ir_attrs.h b/inc/exe_graph/lowering/bg_ir_attrs.h deleted file mode 100644 index 1b8b8145a2d3185a45ad5c13777e75299227d2e0..0000000000000000000000000000000000000000 --- a/inc/exe_graph/lowering/bg_ir_attrs.h +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef AIR_CXX_RUNTIME_V2_GRAPH_BUILDER_BG_IR_ATTRS_H_ -#define AIR_CXX_RUNTIME_V2_GRAPH_BUILDER_BG_IR_ATTRS_H_ -#include "graph/node.h" -#include "value_holder.h" -namespace gert { -namespace bg { -std::unique_ptr CreateAttrBuffer(const ge::NodePtr &node, size_t &size); -std::unique_ptr CreateAttrBuffer(const ge::NodePtr &node, - const std::vector &runtime_attrs_list, - size_t &size); -std::unique_ptr CreateAttrBufferWithoutIr(const ge::NodePtr &node, - const std::vector &runtime_attrs_list, - size_t &size); -} -} -#endif // AIR_CXX_RUNTIME_V2_GRAPH_BUILDER_BG_IR_ATTRS_H_ diff --git a/inc/exe_graph/lowering/bg_kernel_context_extend.h b/inc/exe_graph/lowering/bg_kernel_context_extend.h deleted file mode 100644 index 24abb1105d6b6447a9d3adba7576aafcd00c7cad..0000000000000000000000000000000000000000 --- a/inc/exe_graph/lowering/bg_kernel_context_extend.h +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef AIR_CXX_RUNTIME_V2_GRAPH_BUILDER_BG_KERNEL_CONTEXT_EXTEND_H_ -#define AIR_CXX_RUNTIME_V2_GRAPH_BUILDER_BG_KERNEL_CONTEXT_EXTEND_H_ -#include "graph/node.h" -#include "buffer_pool.h" -#include "register/op_impl_registry.h" -namespace gert { -namespace bg { -std::unique_ptr CreateComputeNodeInfo(const ge::NodePtr &node, BufferPool &buffer_pool); -std::unique_ptr CreateComputeNodeInfo(const ge::NodePtr &node, BufferPool &buffer_pool, size_t &total_size); -std::unique_ptr CreateComputeNodeInfo(const ge::NodePtr &node, BufferPool &buffer_pool, - const gert::OpImplRegisterV2::PrivateAttrList &private_attrs, size_t &total_size); -std::unique_ptr CreateComputeNodeInfoWithoutIrAttr(const ge::NodePtr &node, BufferPool &buffer_pool, - const gert::OpImplRegisterV2::PrivateAttrList &private_attrs, size_t &total_size); -} -} -#endif // AIR_CXX_RUNTIME_V2_GRAPH_BUILDER_BG_KERNEL_CONTEXT_EXTEND_H_ diff --git a/inc/exe_graph/lowering/buffer_pool.h b/inc/exe_graph/lowering/buffer_pool.h deleted file mode 100644 index 474134131ecac509ca82987d5df92a1f1447912a..0000000000000000000000000000000000000000 --- a/inc/exe_graph/lowering/buffer_pool.h +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef AIR_CXX_RUNTIME_V2_LOWERING_BUFFER_POOL_H_ -#define AIR_CXX_RUNTIME_V2_LOWERING_BUFFER_POOL_H_ -#include -#include -#include -#include - -namespace gert { -namespace bg { -class BufferPool { - public: - using BufId = size_t; - BufId AddStr(const char *data); - BufId AddBuf(const uint8_t *data, const size_t len); - std::unique_ptr Serialize(size_t &total_size) const; - std::unique_ptr Serialize() const; - size_t GetSize() const; - - // very slow, only use in UT - const char *GetBufById(const BufId id) const; - - private: - BufId AddBuf(std::string &&str); - BufId AddLargeBuf(std::string &&str); - - private: - std::unordered_map bufs_to_id_; - std::vector> large_bufs_to_id_; // large buf size, not do hash - uint64_t id_generator_{0U}; -}; -} // namespace bg -} // namespace gert -#endif // AIR_CXX_RUNTIME_V2_LOWERING_BUFFER_POOL_H_ diff --git a/inc/exe_graph/lowering/builtin_node_types.h b/inc/exe_graph/lowering/builtin_node_types.h deleted file mode 100644 index 2eb20238b4f9b8dc82327c9c604c78461033571b..0000000000000000000000000000000000000000 --- a/inc/exe_graph/lowering/builtin_node_types.h +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_INC_EXE_GRAPH_LOWERING_BUILTIN_NODE_TYPES_H_ -#define METADEF_CXX_INC_EXE_GRAPH_LOWERING_BUILTIN_NODE_TYPES_H_ -#include -#include "graph/types.h" - -namespace gert { -// 子图的输出,InnerNetOutput有多个输入,没有输出,每个InnerNetOutput的输入,对应相同Index的parent node的输出 -constexpr ge::char_t const *kInnerNetOutput = "InnerNetOutput"; - -// 子图的输入,InnerData没有输入,有一个个输出,代表其对应的parent node的输入。 -// InnerData具有一个key为index的属性,该属性类型是int32,代表该InnerData对应的parent node输入的index -constexpr ge::char_t const *kInnerData = "InnerData"; - -// 图的输出,NetOutput出现在Main图上,表达执行完成后,图的输出。NetOutput有多个输入,没有输出,每个输入对应相同Index的图输出 -constexpr ge::char_t const *kNetOutput = "NetOutput"; - -// 图的输入,Data没有输入,有一个个输出,代表图的输入。 -// Data具有一个key为index的属性,该属性类型是int32,代表图的输入的Index -constexpr ge::char_t const *kData = "Data"; - -// 图的输出,未来NetOutput的会被OutputData所代替 -// OutputData只在Main图上出现,执行完成后,图的输出会被写入到OutputData -// OutputData没有输入,有多个输出,每个输出对应相同Index的图输出 -constexpr ge::char_t const *kOutputData = "OutputData"; - -// 常量节点,该节点没有输入,有一个输出,代表常量的值 -// 常量节点有一个属性"value"代表该常量节点的值,value是一段二进制,常量节点本身不关注其内容的格式 -constexpr ge::char_t const *kConst = "Const"; - -// 常量输入节点,该节点没有输入,有一个输出, -// 其值在lowering阶段不可获得,加载时由外部传入,且在执行过程中不会被改变 -// ConstData具有一个key为type的属性,该属性类型是int32,由此代表ConstData的类型,也代表顺序 -// 详见air仓ConstDataType枚举 -constexpr ge::char_t const *kConstData = "ConstData"; - -inline bool IsTypeData(const ge::char_t *const node_type) { - return strcmp(kData, node_type) == 0; -} -inline bool IsTypeInnerData(const ge::char_t *const node_type) { - return strcmp(kInnerData, node_type) == 0; -} -inline bool IsTypeInnerNetOutput(const ge::char_t *const node_type) { - return strcmp(kInnerNetOutput, node_type) == 0; -} -inline bool IsTypeNetOutput(const ge::char_t *const node_type) { - return strcmp(kNetOutput, node_type) == 0; -} -inline bool IsTypeConst(const ge::char_t *const node_type) { - return strcmp(kConst, node_type) == 0; -} -inline bool IsTypeConstData(const ge::char_t *const node_type) { - return strcmp(kConstData, node_type) == 0; -} -inline bool IsTypeOutputData(const ge::char_t *const node_type) { - return strcmp(kOutputData, node_type) == 0; -} -} // namespace gert -#endif // METADEF_CXX_INC_EXE_GRAPH_LOWERING_BUILTIN_NODE_TYPES_H_ diff --git a/inc/exe_graph/lowering/dev_mem_value_holder.h b/inc/exe_graph/lowering/dev_mem_value_holder.h deleted file mode 100644 index e916f057316bfa8cc29c21b8b68b4e48e43db9e8..0000000000000000000000000000000000000000 --- a/inc/exe_graph/lowering/dev_mem_value_holder.h +++ /dev/null @@ -1,57 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef AIR_CXX_RUNTIME_V2_GRAPH_BUILDER_DEV_MEM_VALUE_HOLDER_H_ -#define AIR_CXX_RUNTIME_V2_GRAPH_BUILDER_DEV_MEM_VALUE_HOLDER_H_ - -#include "value_holder.h" - -namespace gert { -namespace bg { -constexpr int64_t kMainStream = 0; -class DevMemValueHolder; -using DevMemValueHolderPtr = std::shared_ptr; -/** - * Value holder with stream - * - */ -class DevMemValueHolder : public ValueHolder { - public: - explicit DevMemValueHolder(const int64_t logic_stream_id) : logic_stream_id_(logic_stream_id){}; - - DevMemValueHolder() = delete; - DevMemValueHolder(const DevMemValueHolder &other) = delete; - DevMemValueHolder &operator=(const DevMemValueHolder &other) = delete; - ~DevMemValueHolder() override = default; - - ValueHolderPtr CreateMateFromNode(ge::FastNode *node, int32_t index, ValueHolderType type) override; - - static DevMemValueHolderPtr CreateSingleDataOutput(const ge::char_t *node_type, - const std::vector &inputs, - int64_t logic_stream_id); - - static std::vector CreateDataOutput(const ge::char_t *node_type, - const std::vector &inputs, - size_t out_count, int64_t logic_stream_id); - - static DevMemValueHolderPtr CreateConst(const void *data, size_t size, int64_t logic_stream_id, - bool is_string = false); - - static DevMemValueHolderPtr CreateError(int64_t logic_stream_id, const char *fmt, va_list arg); - static DevMemValueHolderPtr CreateError(int64_t logic_stream_id, const char *fmt, ...); - - int64_t GetLogicStream() const; - - private: - int64_t logic_stream_id_{kMainStream}; -}; -} // namespace bg -} // namespace gert - -#endif // AIR_CXX_RUNTIME_V2_GRAPH_BUILDER_DEV_MEM_VALUE_HOLDER_H_ diff --git a/inc/exe_graph/lowering/device_tiling_context_builder.h b/inc/exe_graph/lowering/device_tiling_context_builder.h deleted file mode 100644 index 1703ca18ef9cbff4da2ddfd8f21adcc500994611..0000000000000000000000000000000000000000 --- a/inc/exe_graph/lowering/device_tiling_context_builder.h +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GE_COMMMON_RUNTIME_DEVICE_TILING_KERNEL_CONTEXT_BUILDER_H_ -#define GE_COMMMON_RUNTIME_DEVICE_TILING_KERNEL_CONTEXT_BUILDER_H_ - -#include "graph/node.h" -#include "exe_graph/runtime/compute_node_info.h" -#include "exe_graph/runtime/kernel_context.h" -#include "exe_graph/lowering/buffer_pool.h" -#include "exe_graph/runtime/tiling_context.h" -#include "exe_graph/lowering/kernel_run_context_builder.h" -#include "register/op_impl_space_registry.h" - -namespace gert { -struct AddrRefreshedTensor { - gert::Tensor *host_addr; - uint64_t device_addr; -}; - -struct TiledKernelContextHolder { - uint64_t dev_op_type_addr_{0UL}; - uint64_t dev_op_name_addr_{0UL}; - KernelContext *host_context_{nullptr}; - uint64_t dev_context_addr_{0UL}; - std::vector output_addrs_; - uint8_t *host_compute_node_info_{nullptr}; - size_t compute_node_info_size_{0UL}; -}; - -class DeviceTilingContextBuilder { - public: - static size_t CalcTotalTiledSize(const ge::OpDescPtr &op_desc); - DeviceTilingContextBuilder &CompileInfo(void *compile_info); - DeviceTilingContextBuilder &Deterministic(int32_t deterministic); - DeviceTilingContextBuilder &PlatformInfo(void *platform_info); - DeviceTilingContextBuilder &TilingData(void *tiling_data); - DeviceTilingContextBuilder &AddrRefreshedInputTensor(const std::map &index_to_tensor); - DeviceTilingContextBuilder &TiledHolder(uint8_t *host_addr, uint64_t dev_addr, size_t max_mem_size); - DeviceTilingContextBuilder &Workspace(void *workspace); - ge::graphStatus Build(const ge::NodePtr &node, TiledKernelContextHolder &holder); - - private: - ge::graphStatus BuildRtTensor(const ge::GeTensorDesc &tensor_desc, ConstTensorAddressPtr address); - ge::graphStatus BuildPlacementRtTensor(const ge::GeTensorDesc &tensor_desc, Tensor *rt_tensor) const; - ge::graphStatus BuildIOTensors(const ge::OpDesc *const op_desc); - - ge::graphStatus TiledBuild(const ge::NodePtr &node, TiledKernelContextHolder &holder); - - void *compile_info_{nullptr}; - void *platform_info_{nullptr}; - int32_t deterministic_{0}; - uint64_t dev_begin_{0UL}; - uint8_t *host_begin_{nullptr}; - size_t max_mem_size_{0UL}; - std::map index_to_tensor_; - std::vector inputs_; - std::vector outputs_{TilingContext::kOutputNum}; -}; -} // namespace gert -#endif // GE_COMMMON_RUNTIME_DEVICE_TILING_KERNEL_CONTEXT_BUILDER_H_ diff --git a/inc/exe_graph/lowering/exe_graph_attrs.h b/inc/exe_graph/lowering/exe_graph_attrs.h deleted file mode 100644 index 266318ac49caa13bec1c3f23d69540a73dd339f7..0000000000000000000000000000000000000000 --- a/inc/exe_graph/lowering/exe_graph_attrs.h +++ /dev/null @@ -1,70 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef AIR_CXX_RUNTIME_V2_METADEF_EXE_GRAPH_EXE_GRAPH_ATTRS_H_ -#define AIR_CXX_RUNTIME_V2_METADEF_EXE_GRAPH_EXE_GRAPH_ATTRS_H_ -#include "graph/types.h" - -namespace gert { -// 打在const节点上,表达const的值, -// todo 后面这个应该慢慢要废弃掉,通过buffer id代替 -constexpr const ge::char_t *kConstValue = "value"; - -constexpr const ge::char_t *kGraph = "graph"; - -// 打在exe node上,代表本node执行的stage(例如DeInit) -constexpr const ge::char_t *kStage = "stage"; - -// 打在feed节点上(Data、InnerData),代表这是第几个输入 -constexpr const ge::char_t *kFeedIndex = "index"; - -// 打在输出desc上,标识本输出不申请独立的内存,从某个node的某个输出index上引用过来使用 -constexpr const ge::char_t *kRefFromNode = "RefFromNode"; -constexpr const ge::char_t *kRefFromIndex = "RefFromIndex"; - -// 打在exe graph上,保存了本graph涉及的所有的ComputeNodeInfo -constexpr const ge::char_t *kComputeNodeInfo = "ComputeNodeInfo"; - -// 打在exe node上,用来标识本node所对应的计算图上的node的index -constexpr const ge::char_t *kComputeNodeIndex = "ComputeNodeIndex"; - -// 打在exe graph上,保存了本graph涉及的所有的KernelExtendInfo -constexpr const ge::char_t *kKernelExtendInfo = "KernelExtendInfo"; - -// 打在exe node上,用来标识本node所对应的kernel信息的index -constexpr const ge::char_t *kKernelExtendIndex = "KernelExtendInfoIndex"; - -// 打在exe node上,用来标识本node属于某一子图,且该子图内节点源自于子图外部,且对应子图外部节点拥有Guarder -constexpr const ge::char_t *kNodeWithGuarderOutside = "NodeWithGuarderOutside"; - -// 打在exe graph上,保存了本graph涉及的所有的二进制buffer(字符串、const值等) -constexpr const ge::char_t *kBuffer = "buffer"; - -// 打在exe graph上,保存了本graph涉及的ModelDesc信息 -constexpr const ge::char_t *kModelDesc = "ModelDesc"; - -// 打在exe node上,类型是int,代表两层含义:1. 本node释放一个资源;2. 本node释放的资源位于本node的第n的输入index;n为属性的值 -constexpr ge::char_t kReleaseResourceIndex[] = "ReleaseResourceIndex"; - -// 作为扩展属性打在exe graph上,类型是ge::ComputeGraphPtr,保存的是原来的计算图,未来会删除,因为无法做序列化,执行图序列化反序列化后会丢失该属性 -constexpr const ge::char_t *kComputeGraph = "_compute_graph"; - -// 作为扩展属性打在exe node上,类型是PassChangedKernels,记录执行图经过pass后的新旧exe nodes输出的对应关系 -constexpr const ge::char_t *kPassChangedInfo = "_pass_changed_info"; - -// 打在exe graph上,保存space_registry的智能指针 -constexpr const ge::char_t *kSpaceRegistry = "SpaceRegistry"; - -// 打在exe graph上,保存外置权重文件目录的string -constexpr const ge::char_t *kExternalFileConstantDir = "ExternalFileConstantDir"; - -// 打在exe node上,类型是string,保存它的guarder节点的node type -constexpr const ge::char_t *kGuarderNodeType = "GuarderNodeType"; -} // namespace gert -#endif // AIR_CXX_RUNTIME_V2_METADEF_EXE_GRAPH_EXE_GRAPH_ATTRS_H_ diff --git a/inc/exe_graph/lowering/exe_res_generation_ctx_builder.h b/inc/exe_graph/lowering/exe_res_generation_ctx_builder.h deleted file mode 100644 index deb4364557a8a9ee4116e76d594dc3b8d3c6fee9..0000000000000000000000000000000000000000 --- a/inc/exe_graph/lowering/exe_res_generation_ctx_builder.h +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_EXE_GRAPH_LOWERING_EXE_RES_GENERATION_CTX_BUILDER_H -#define INC_EXE_GRAPH_LOWERING_EXE_RES_GENERATION_CTX_BUILDER_H - -#include "external/exe_graph/runtime/exe_res_generation_context.h" -#include "graph/node.h" -#include "exe_graph/lowering/kernel_run_context_builder.h" - -namespace gert { -using ExeResGenerationCtxHolderPtr = std::shared_ptr; -class ExeResGenerationCtxBuilder { - public: - ExeResGenerationCtxHolderPtr CreateOpExeContext(ge::Node &node); - ExeResGenerationCtxHolderPtr CreateOpCheckContext(ge::Node &node); - private: - void CreateShapesInputs(const ge::Node &node, std::vector &inputs); - private: - ExeResGenerationCtxHolderPtr ctx_holder_ptr_; - std::vector input_shapes_; - std::vector output_shapes_; -}; -} // namespace fe - -#endif diff --git a/inc/exe_graph/lowering/frame_selector.h b/inc/exe_graph/lowering/frame_selector.h deleted file mode 100644 index 14b85806954abaebeac86cebb93178e548c9ac29..0000000000000000000000000000000000000000 --- a/inc/exe_graph/lowering/frame_selector.h +++ /dev/null @@ -1,92 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_INC_EXE_GRAPH_LOWERING_FRAME_SELECTOR_H_ -#define METADEF_CXX_INC_EXE_GRAPH_LOWERING_FRAME_SELECTOR_H_ -#include -#include "value_holder.h" -#include "common/checker.h" -namespace gert { -namespace bg { -/** - * frame选择器,通过frame选择器,可以将执行图逻辑生成到执行的frame上 - * - * 当前frame选择器还无法提供OnInitFrame/OnDeInitFrame功能,因为ValueHolder跨图连接的能力仅支持从父图向子图连接, - * 若出现了Init向Main图中节点连边的场景,当前无法处理。对于Init/DeInit需求,按需开发即可 - */ -class FrameSelector { - public: - /** - * 选择Init图,将builder中的逻辑生成到Init图上,并返回**Init节点**的输出ValueHolder。 - * 注意:如果builder中创建的输出ValueHolder具有guarder,那么此guarder节点会被自动移动到DeInit图中 - * @param builder 执行图构建函数 - * @return 成功时,将builder返回的ValueHolderPtr作为本函数的返回值;失败时,本函数返回空vector - */ - static std::vector OnInitRoot(const std::function()> &builder); - - static ge::graphStatus OnInitRoot(const std::function()> &builder, - std::vector &init_graph_outputs, - std::vector &init_node_outputs); - /** - * 选择DeInit图,将builder中的逻辑生成到DeInit图上,并返回builder返回的ValueHolder。 - * 注意:builder中创建无输出的ValueHolder,即CreateVoid,并将其返回。 - * @param builder 执行图构建函数 - * @return 成功时,将DeInit图中builder返回的ValueHolder节点作为本函数的返回值;失败时,本函数返回空vector - */ - static std::vector OnDeInitRoot(const std::function()> &builder); - - /** - * 选择Main图,将builder中的逻辑生成到Main图上。 - * - * 需要注意的是,本函数仅保证当前将builder中的逻辑生成到Main图上,但不保证其始终在Main图上。 - * 在lowering构图完成后,在图优化阶段,CEM等优化可能将Main图上的Node移动到Init图中。 - * - * @param builder 执行图构建函数 - * @return 成功时,将builder返回的ValueHolderPtr作为本函数的返回值;失败时,本函数返回空vector - */ - static std::vector OnMainRoot(const std::function()> &builder); - - static ge::graphStatus OnMainRoot(const std::function()> &builder, - std::vector &outputs); - /** - * 选择Main图,将builder中的逻辑生成到Main图上, 并且保证builder生成的节点在main图最开始执行 - * 当前已有阶段,请参考bg::OnMainRootFirstExecStage的枚举值 - * - * @param builder 执行图构建函数 - * @return 成功时,将builder返回的ValueHolderPtrs作为本函数的返回值;失败时,本函数返回空vector - */ - static std::vector OnMainRootFirst(const std::function()> &builder); - - static ValueHolderPtr OnMainRootLast(const std::function &builder); - - /** - * 选择Main图,将builder中的逻辑生成到Main图上, builder生成的节点在LastEventSync阶段执行. - * 当前已有阶段,请参考bg::OnMainRootLastExecStage的枚举值 - * - * @param builder 执行图构建函数 - * @return 成功时,将builder返回的ValueHolderPtrs作为本函数的返回值;失败时,本函数返回空vector - */ - static std::vector OnMainRootLastEventSync( - const std::function()> &builder); - - /** - * 选择Main图,将builder中的逻辑生成到Main图上, builder生成的节点在LastResourceClean阶段执行. - * 当前已有阶段,请参考bg::OnMainRootLastExecStage的枚举值 - * - * @param builder 执行图构建函数 - * @return 成功时,将builder返回的ValueHolderPtrs作为本函数的返回值;失败时,本函数返回空vector - */ - static std::vector OnMainRootLastResourceClean( - const std::function()> &builder); -}; - -ValueHolderPtr HolderOnInit(const ValueHolderPtr &holder); -} // namespace bg -} // namespace gert -#endif // METADEF_CXX_INC_EXE_GRAPH_LOWERING_FRAME_SELECTOR_H_ diff --git a/inc/exe_graph/lowering/generate_exe_graph.h b/inc/exe_graph/lowering/generate_exe_graph.h deleted file mode 100644 index 8f8b3a410757f121e1998e10b5870e89cb2cda83..0000000000000000000000000000000000000000 --- a/inc/exe_graph/lowering/generate_exe_graph.h +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_INC_EXE_GRAPH_LOWERING_GENERATE_EXE_GRAPH_H_ -#define METADEF_CXX_INC_EXE_GRAPH_LOWERING_GENERATE_EXE_GRAPH_H_ -#include - -#include "dev_mem_value_holder.h" -#include "graph/compute_graph.h" -#include "lowering_global_data.h" -namespace gert { -namespace bg { -class GenerateExeGraph { - public: - struct ExeGraphGenerator { - using InferShapeFunc = std::vector (*)(const ge::NodePtr &node, - const std::vector &shapes, - LoweringGlobalData &global_data); - using AllocOutputMemoryFunc = std::vector (*)(TensorPlacement placement, - const ge::NodePtr &node, - const std::vector &output_sizes, - LoweringGlobalData &global_data); - using CalcTensorSizeFunc = std::vector (*)(const ge::NodePtr &node, - const std::vector &output_shapes); - - InferShapeFunc infer_shape; - AllocOutputMemoryFunc alloc_output_memory; - CalcTensorSizeFunc calc_tensor_size; - }; - - public: - static std::vector InferShape(const ge::NodePtr &node, const std::vector &shapes, - LoweringGlobalData &global_data) { - if (generator_.infer_shape == nullptr) { - return {}; - } - return generator_.infer_shape(node, shapes, global_data); - } - static std::vector AllocOutputMemory(TensorPlacement placement, const ge::NodePtr &node, - const std::vector &output_sizes, - LoweringGlobalData &global_data) { - if (generator_.alloc_output_memory == nullptr) { - return {}; - } - return generator_.alloc_output_memory(placement, node, output_sizes, global_data); - } - static std::vector CalcTensorSize(const ge::NodePtr &node, - const std::vector &output_shapes) { - if (generator_.calc_tensor_size == nullptr) { - return {}; - } - return generator_.calc_tensor_size(node, output_shapes); - } - - static void AddBuilderImplement(ExeGraphGenerator generator) { - generator_ = generator; - } - - static ValueHolderPtr MakeSureTensorAtHost(const ge::Node *node, LoweringGlobalData &global_data, - const ValueHolderPtr &addr, const ValueHolderPtr &size); - - static ValueHolderPtr CalcTensorSizeFromShape(ge::DataType dt, const ValueHolderPtr &shape); - - static ValueHolderPtr FreeMemoryGuarder(const ValueHolderPtr &resource); - - private: - static ExeGraphGenerator generator_; -}; -} // namespace bg -} // namespace gert -#endif // METADEF_CXX_INC_EXE_GRAPH_LOWERING_GENERATE_EXE_GRAPH_H_ diff --git a/inc/exe_graph/lowering/getcdim.h b/inc/exe_graph/lowering/getcdim.h deleted file mode 100644 index a8f37aac4e62a8fc2d6d976ce94b519ad6f25f9a..0000000000000000000000000000000000000000 --- a/inc/exe_graph/lowering/getcdim.h +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_INC_EXE_GRAPH_RUNTIME_GETCDIM_H_ -#define METADEF_CXX_INC_EXE_GRAPH_RUNTIME_GETCDIM_H_ -#include "exe_graph/runtime/infer_shape_context.h" -#include "exe_graph/runtime/tiling_context.h" -namespace gert { - int64_t GetInputCDim(gert::TilingContext *kernel_context, const size_t index); - int64_t GetOutputCDim(gert::TilingContext *kernel_context, const size_t index); -} // namespace gert -#endif // METADEF_CXX_INC_EXE_GRAPH_RUNTIME_TILING_PARSE_CONTEXT_H_ diff --git a/inc/exe_graph/lowering/graph_frame.h b/inc/exe_graph/lowering/graph_frame.h deleted file mode 100644 index 029385714f077af32545be34d2998f5e15cc410e..0000000000000000000000000000000000000000 --- a/inc/exe_graph/lowering/graph_frame.h +++ /dev/null @@ -1,136 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef AIR_CXX_RUNTIME_V2_METADEF_EXE_GRAPH_GRAPH_FRAME_H_ -#define AIR_CXX_RUNTIME_V2_METADEF_EXE_GRAPH_GRAPH_FRAME_H_ -#include -#include -#include -#include -#include "graph/node.h" -#include "buffer_pool.h" -#include "bg_kernel_context_extend.h" -#include "graph/fast_graph/execute_graph.h" - -namespace gert { -namespace bg { -class ValueHolder; -using ValueHolderPtr = std::shared_ptr; -constexpr const ge::char_t *kStageIdsToLastPartitionedCall = "StageIdsToLastPartitionedCall"; -constexpr const ge::char_t *kStageIdsToFirstPartitionedCall = "StageIdsToFirstPartitionedCall"; -/* - * 执行的阶段, 越小越靠前执行,越大越靠后执行 - */ -enum class OnMainRootLastExecStage { - kFirstStage = 0, - kLastEventSyncStage, - kLastResourceClean, - // add level before this - kStageSize -}; - -/* - * 执行的阶段, 越小越靠前执行,越大越靠后执行 - */ -enum class OnMainRootFirstExecStage { - kFirstEventSyncStage = 0, - // add level before this - kStageSize -}; - -class GraphFrame { - public: - GraphFrame(const GraphFrame &) = delete; - GraphFrame(GraphFrame &&) = delete; - GraphFrame operator=(const GraphFrame &) = delete; - GraphFrame operator=(GraphFrame &&) = delete; - - GraphFrame(ge::ExecuteGraphPtr exe_graph, const GraphFrame &parent_frame) noexcept - : execute_graph_(std::move(exe_graph)), - current_compute_node_and_index_(), root_frame_(parent_frame.root_frame_), - nodes_to_index_(root_frame_.nodes_to_index_), indexes_to_node_(root_frame_.indexes_to_node_), - relevant_input_node_(root_frame_.relevant_input_node_) {} - - explicit GraphFrame(ge::ExecuteGraphPtr exe_graph) noexcept - : execute_graph_(std::move(exe_graph)), - current_compute_node_and_index_(), root_frame_(*this), - nodes_to_index_holder_(), nodes_to_index_(nodes_to_index_holder_), indexes_to_node_holder_(), - indexes_to_node_(indexes_to_node_holder_), relevant_input_node_holder_(), - relevant_input_node_(relevant_input_node_holder_) {} - - const ge::NodePtr &GetCurrentComputeNode() const { - return current_compute_node_and_index_.first; - } - void SetCurrentComputeNode(const ge::NodePtr ¤t_node) { - if (current_node == nullptr) { - current_compute_node_and_index_ = {nullptr, 0}; - return; - } - const auto result = nodes_to_index_.emplace(current_node, nodes_to_index_.size()); - current_compute_node_and_index_ = {current_node, result.first->second}; - if (result.second) { - indexes_to_node_.emplace_back(current_node); - } - } - void AddRelevantInputNode(const ge::NodePtr ¤t_node) { - relevant_input_node_.emplace_back(current_node); - } - bool GetCurrentNodeIndex(size_t &index) const { - if (current_compute_node_and_index_.first == nullptr) { - return false; - } - index = current_compute_node_and_index_.second; - return true; - } - - bool IsRootFrame() const { - return &root_frame_ == this; - } - - const ge::ExecuteGraphPtr &GetExecuteGraph() const { - return execute_graph_; - } - - const vector &GetIndexesToNode() const { - return indexes_to_node_; - } - - const std::unordered_map &GetNodesToIndex() const { - return nodes_to_index_; - } - - const std::vector &GetLastExecNodes() const { - return last_exec_nodes_; - } - - /* - * set last exec node, its priority is first level - */ - void SetLastExecNode(const ValueHolderPtr last_exec_node) { - if (last_exec_node != nullptr) { - last_exec_nodes_.emplace_back(last_exec_node); // todo to be deprecated - } - } - - private: - ge::ExecuteGraphPtr execute_graph_; - std::pair current_compute_node_and_index_; - GraphFrame &root_frame_; - std::unordered_map nodes_to_index_holder_; - std::unordered_map &nodes_to_index_; - std::vector indexes_to_node_holder_; - std::vector &indexes_to_node_; - std::vector last_exec_nodes_; // todo to be deprecated - std::vector relevant_input_node_holder_; - std::vector &relevant_input_node_; -}; -} // namespace bg -} // namespace gert - -#endif // AIR_CXX_RUNTIME_V2_METADEF_EXE_GRAPH_GRAPH_FRAME_H_ diff --git a/inc/exe_graph/lowering/lowering_definitions.h b/inc/exe_graph/lowering/lowering_definitions.h deleted file mode 100644 index 54d187b7546677a28f287a93dfd77510ca1dacc0..0000000000000000000000000000000000000000 --- a/inc/exe_graph/lowering/lowering_definitions.h +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef AIR_CXX_RUNTIME_V2_LOWERING_LOWERING_DEFINITIONS_H_ -#define AIR_CXX_RUNTIME_V2_LOWERING_LOWERING_DEFINITIONS_H_ -#include "graph/types.h" - -namespace gert { -constexpr const ge::char_t *kLoweringInputInfo = "_lowering_input_info"; -constexpr const ge::char_t *kLoweringResult = "_lowering_result"; -constexpr const ge::char_t *kLoweringTensorResult = "_lowering_tensor_result"; -constexpr const ge::char_t *kLoweringHostTensorResult = "_lowering_host_tensor_result"; -} // namespace gert -#endif // AIR_CXX_RUNTIME_V2_LOWERING_LOWERING_DEFINITIONS_H_ diff --git a/inc/exe_graph/lowering/lowering_global_data.h b/inc/exe_graph/lowering/lowering_global_data.h deleted file mode 100644 index ef02d63f245394dda5adb923a8215c682f709e92..0000000000000000000000000000000000000000 --- a/inc/exe_graph/lowering/lowering_global_data.h +++ /dev/null @@ -1,196 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef AIR_CXX_RUNTIME_V2_LOWERING_LOWERING_GLOBAL_DATA_H_ -#define AIR_CXX_RUNTIME_V2_LOWERING_LOWERING_GLOBAL_DATA_H_ -#include -#include "proto/task.pb.h" -#include "value_holder.h" -#include "exe_graph/runtime/tensor.h" -#include "exe_graph/runtime/allocator.h" -#include "exe_graph/runtime/execute_graph_types.h" -#include "base/registry/op_impl_space_registry_v2.h" -#include "register/op_impl_space_registry.h" -#include "exe_graph/lowering/lowering_opt.h" - -namespace gert { -constexpr int64_t kRtMemoryTypeHbm = 0x2; -constexpr int64_t kDefaultMainStreamId = 0; -// todo change to get stream num from model_desc const data -constexpr const ge::char_t *kGlobalDataModelStreamNum = "ModelStreamNum"; -class LoweringGlobalData { - public: - struct NodeCompileResult { - const std::vector &GetTaskDefs() const { - return task_defs; - } - std::vector task_defs; - }; - - std::vector LoweringAndSplitRtStreams(int64_t stream_num); - bg::ValueHolderPtr GetStreamById(int64_t logic_stream_id) const; - inline bg::ValueHolderPtr GetStream() const { - int64_t current_stream_id = kDefaultMainStreamId; - if ((bg::ValueHolder::GetCurrentFrame() != nullptr) && - (bg::ValueHolder::GetCurrentFrame()->GetCurrentComputeNode() != nullptr)) { - current_stream_id = bg::ValueHolder::GetCurrentFrame()->GetCurrentComputeNode()->GetOpDesc()->GetStreamId(); - } - return GetStreamById(current_stream_id); - } - - void SetRtNotifies(const std::vector ¬ify_holders); - bg::ValueHolderPtr GetNotifyById(int64_t logic_notify_id) const; - - const NodeCompileResult *FindCompiledResult(const ge::NodePtr &node) const; - LoweringGlobalData &AddCompiledResult(const ge::NodePtr &node, NodeCompileResult compile_result); - - void *GetGraphStaticCompiledModel(const std::string &graph_name) const; - LoweringGlobalData &AddStaticCompiledGraphModel(const std::string &graph_name, void *const model); - - bg::ValueHolderPtr GetL1Allocator(const AllocatorDesc &desc) const; - LoweringGlobalData &SetExternalAllocator(bg::ValueHolderPtr &&allocator); - LoweringGlobalData &SetExternalAllocator(bg::ValueHolderPtr &&allocator, const ExecuteGraphType graph_type); - - bg::ValueHolderPtr GetOrCreateL1Allocator(const AllocatorDesc desc); - bg::ValueHolderPtr GetOrCreateL2Allocator(int64_t logic_stream_id, const AllocatorDesc desc); - bg::ValueHolderPtr GetInitL2Allocator(const AllocatorDesc desc) const; - bg::ValueHolderPtr GetMainL2Allocator(int64_t logic_stream_id, const AllocatorDesc desc) const; - inline bg::ValueHolderPtr GetOrCreateAllocator(const AllocatorDesc desc) { - int64_t current_stream_id = kDefaultMainStreamId; - if ((bg::ValueHolder::GetCurrentFrame() != nullptr) && - (bg::ValueHolder::GetCurrentFrame()->GetCurrentComputeNode() != nullptr)) { - current_stream_id = bg::ValueHolder::GetCurrentFrame()->GetCurrentComputeNode()->GetOpDesc()->GetStreamId(); - } - return GetOrCreateL2Allocator(current_stream_id, desc); - } - bg::ValueHolderPtr GetOrCreateAllL2Allocators(); - - bg::ValueHolderPtr GetOrCreateUniqueValueHolder(const std::string &name, - const std::function &builder); - std::vector GetOrCreateUniqueValueHolder(const std::string &name, - const std::function()> &builder); - bg::ValueHolderPtr GetUniqueValueHolder(const std::string &name) const; - void SetUniqueValueHolder(const std::string &name, const bg::ValueHolderPtr &holder); - void SetValueHolders(const string &name, const bg::ValueHolderPtr &holder); - size_t GetValueHoldersSize(const string &name); - - void SetModelWeightSize(const size_t require_weight_size); - size_t GetModelWeightSize() const; - const OpImplSpaceRegistryArray &GetSpaceRegistries() { - return space_registries_; - } - const OpImplSpaceRegistryPtr GetSpaceRegistry(ge::OppImplVersion opp_impl_version = ge::OppImplVersion::kOpp) const { - if (opp_impl_version >= ge::OppImplVersion::kVersionEnd) { - return nullptr; - } - return space_registries_[static_cast(opp_impl_version)]; - }; - void SetSpaceRegistries(const gert::OpImplSpaceRegistryArray &space_registries) { - space_registries_ = space_registries; - } - // 兼容air, 随后air合入后删除 - void SetSpaceRegistry(gert::OpImplSpaceRegistryPtr space_registry) { - space_registries_[static_cast(ge::OppImplVersion::kOpp)] = space_registry; - } - - // TODOO: 先合入metadef,air适配结束,删除metadef中未调用的实现 - const OpImplSpaceRegistryV2Ptr GetSpaceRegistryV2( - OppImplVersionTag opp_impl_version = OppImplVersionTag::kOpp) const { - if (opp_impl_version >= OppImplVersionTag::kVersionEnd) { - return nullptr; - } - return space_registries_v2_[static_cast(opp_impl_version)]; - }; - const OpImplSpaceRegistryV2Array &GetSpaceRegistriesV2() const { - return space_registries_v2_; - }; - void SetSpaceRegistriesV2(const OpImplSpaceRegistryV2Array &space_registries) { - space_registries_v2_ = space_registries; - } - - const LoweringOption &GetLoweringOption() const; - void SetLoweringOption(const LoweringOption &lowering_option); - - void SetStaicModelWsSize(const int64_t require_ws_size) { - static_model_ws_size = require_ws_size; - } - - int64_t GetStaticModelWsSize() const { - return static_model_ws_size; - } - - void SetFixedFeatureMemoryBase(const void * const memory, const size_t size) { - fixed_feature_mem_[kRtMemoryTypeHbm] = std::make_pair(memory, size); - } - - const std::pair &GetFixedFeatureMemoryBase() const { - const auto iter = fixed_feature_mem_.find(kRtMemoryTypeHbm); - if (iter != fixed_feature_mem_.end()) { - return iter->second; - } - static std::pair dummy_result; - return dummy_result; - } - - void SetFixedFeatureMemoryBase(const int64_t type, const void * const memory, const size_t size) { - fixed_feature_mem_[type] = std::make_pair(memory, size); - } - - /* - * 获取图所需fixed feature memory地址和长度 - * 1 地址为nullptr,长度为0:用户设置的结果,比较特殊,表示不需要GE默认申请fixed内存 - * 2 地址为nullptr,长度不为0,表示需要fixed内存,但是用户没有设置。GE要默认申请fixed内存 - * 3 地址不为nullptr,长度不为0,用户设置的结果。 - */ - const std::map> &GetAllTypeFixedFeatureMemoryBase() const { - return fixed_feature_mem_; - } - - bool IsSingleStreamScene() const { - return is_single_stream_scene_; - } - - void SetHostResourceCenter(void *host_resource_center_ptr) { - host_resource_center_ = host_resource_center_ptr; - } - void *GetHostResourceCenter() { - return host_resource_center_; - } - - private: - struct HolderByGraphs { - bg::ValueHolderPtr holders[static_cast(ExecuteGraphType::kNum)]; - }; - struct HoldersByGraphs { - std::vector holders[static_cast(ExecuteGraphType::kNum)]; - }; - - bg::ValueHolderPtr GetOrCreateInitL2Allocator(const AllocatorDesc desc); - bg::ValueHolderPtr GetExternalAllocator(const bool from_init, const string &key, const AllocatorDesc &desc); - bool CanUseExternalAllocator(const ExecuteGraphType &graph_type, const TensorPlacement placement) const; - private: - std::unordered_map node_name_to_compile_result_holders_; - std::map graph_to_static_models_; - std::unordered_map> unique_name_to_value_holders_; - HoldersByGraphs streams_; - HoldersByGraphs notifies_; - HolderByGraphs external_allocators_; - // todo need delete and change to const_data after const_data is ready - int64_t model_weight_size_; - int64_t static_model_ws_size; - OpImplSpaceRegistryArray space_registries_; - OpImplSpaceRegistryV2Array space_registries_v2_; - LoweringOption lowering_option_; - // addr为nullptr,但size不为0,表示用户没有设置fixed内存,需要GE默认申请fixed内存 - std::map> fixed_feature_mem_; - bool is_single_stream_scene_{true}; - void *host_resource_center_{nullptr}; -}; -} // namespace gert -#endif // AIR_CXX_RUNTIME_V2_LOWERING_LOWERING_GLOBAL_DATA_H_ diff --git a/inc/exe_graph/lowering/shape_utils.h b/inc/exe_graph/lowering/shape_utils.h deleted file mode 100644 index 712b2c7d17edb027cf532a66cf364c68f738ab98..0000000000000000000000000000000000000000 --- a/inc/exe_graph/lowering/shape_utils.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_INC_EXE_GRAPH_RUNTIME_SHAPE_UTILS_H_ -#define METADEF_CXX_INC_EXE_GRAPH_RUNTIME_SHAPE_UTILS_H_ -#include -#include -#include "exe_graph/runtime/shape.h" -#include "graph/ge_error_codes.h" -#include "graph/types.h" - -namespace gert { -extern const Shape g_vec_1_shape; -/** - * 确保返回的shape是非scalar的。 - * 当一个shape的dim num为0时,此shape被认为表达了一个scalar。 - * 本函数在接受一个非scalar的shape时,会返回原有shape;在接收到scalar shape时,会返回返回一个{1}的vector shape - * @param in_shape 输入shape - * @return 保证非scalar的shape - */ -inline const Shape &EnsureNotScalar(const Shape &in_shape) { - if (in_shape.IsScalar()) { - return g_vec_1_shape; - } - return in_shape; -} -/** - * 返回shape的字符串,本函数性能较低,不可以在执行时的正常流程中使用 - * @param shape 需要转为字符串的shape实例 - * @param join_char 每个Dim的间隔,默认为`,` - * @return 转好的字符串 - */ -inline std::string ShapeToString(const Shape &shape, const char *join_char = ",") { - if (join_char == nullptr) { - join_char = ","; - } - std::stringstream ss; - for (size_t i = 0U; i < shape.GetDimNum(); ++i) { - if (i > 0U) { - ss << join_char; - } - ss << shape[i]; - } - return ss.str(); -} - -ge::graphStatus CalcAlignedSizeByShape(const Shape &shape, ge::DataType data_type, uint64_t &ret_tensor_size); -} // namespace gert -#endif // METADEF_CXX_INC_EXE_GRAPH_RUNTIME_SHAPE_UTILS_H_ diff --git a/inc/exe_graph/lowering/tiling_context_builder.h b/inc/exe_graph/lowering/tiling_context_builder.h deleted file mode 100644 index bc76b495e208b023f2953c71400d9cd6c3e74247..0000000000000000000000000000000000000000 --- a/inc/exe_graph/lowering/tiling_context_builder.h +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GE_COMMMON_RUNTIME_TILING_KERNEL_CONTEXT_BUILDER_H_ -#define GE_COMMMON_RUNTIME_TILING_KERNEL_CONTEXT_BUILDER_H_ - -#include "graph/node.h" -#include "exe_graph/runtime/compute_node_info.h" -#include "exe_graph/runtime/kernel_context.h" -#include "exe_graph/lowering/buffer_pool.h" -#include "exe_graph/runtime/tiling_context.h" -#include "exe_graph/lowering/kernel_run_context_builder.h" -#include "base/registry/op_impl_space_registry_v2.h" -#include "register/op_impl_space_registry.h" - -namespace gert { -class TilingContextBuilder { - public: - TilingContextBuilder &CompileInfo(void *compile_info); - TilingContextBuilder &Deterministic(int32_t deterministic); - TilingContextBuilder &PlatformInfo(void *platform_info); - TilingContextBuilder &TilingData(void *tiling_data); - TilingContextBuilder &Workspace(ContinuousVector *workspace); - // 兼容air, 随后air合入后删除 - TilingContextBuilder &SpaceRegistry(const gert::OpImplSpaceRegistryPtr &space_registry); - TilingContextBuilder &SpaceRegistries(const gert::OpImplSpaceRegistryArray &space_registries); - TilingContextBuilder &SetSpaceRegistryV2(const OpImplSpaceRegistryV2Ptr &space_registry, - OppImplVersionTag version_tag); - KernelContextHolder Build(const ge::Operator &op); // deprecated later - KernelContextHolder Build(const ge::Operator &op, ge::graphStatus &ret); - - private: - ge::graphStatus GetDependInputTensorAddr(const ge::Operator &op, const size_t input_idx, TensorAddress &address); - ge::graphStatus BuildRtTensor(const ge::GeTensorDesc &tensor_desc, ConstTensorAddressPtr address, - std::unique_ptr &rt_tensor_holder) const; - ge::graphStatus BuildRTInputTensors(const ge::Operator &op); - ge::graphStatus BuildRTOutputShapes(const ge::Operator &op); - - void *compile_info_{nullptr}; - void *platform_info_{nullptr}; - int32_t deterministic_; - std::vector> depend_ge_tensor_holders_; - std::vector> rt_tensor_holders_; - std::vector outputs_ {TilingContext::kOutputNum}; - KernelRunContextBuilder base_builder_; - OpImplSpaceRegistryArray space_registries_; - OpImplSpaceRegistryV2Array space_registries_v2_; - bool use_registry_v2_{false}; -}; - -class AtomicTilingContextBuilder { - public: - AtomicTilingContextBuilder &CompileInfo(void *compile_info); - AtomicTilingContextBuilder &CleanWorkspaceSizes(ContinuousVector *workspace_sizes); - AtomicTilingContextBuilder &CleanOutputSizes(const std::vector &output_sizes); - AtomicTilingContextBuilder &TilingData(void *tiling_data); - AtomicTilingContextBuilder &Workspace(ContinuousVector *workspace); - KernelContextHolder Build(const ge::Operator &op); // deprecated later - KernelContextHolder Build(const ge::Operator &op, ge::graphStatus &ret); - - private: - void *compile_info_{nullptr}; - void *worksapce_sizes_{nullptr}; - std::vector clean_output_sizes_; - std::vector outputs_ {TilingContext::kOutputNum}; - KernelRunContextBuilder base_builder_; -}; -} // namespace gert -#endif // GE_COMMMON_RUNTIME_TILING_KERNEL_CONTEXT_BUILDER_H_ diff --git a/inc/exe_graph/lowering/tiling_parse_context_builder.h b/inc/exe_graph/lowering/tiling_parse_context_builder.h deleted file mode 100644 index 188727540202d3de7b3efefe08dc26ba67cb3066..0000000000000000000000000000000000000000 --- a/inc/exe_graph/lowering/tiling_parse_context_builder.h +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GE_RUNTIME_TILING_PARSE_CONTEXT_BUILDER_H_ -#define GE_RUNTIME_TILING_PARSE_CONTEXT_BUILDER_H_ - -#include "exe_graph/runtime/kernel_context.h" -#include "graph/operator.h" -#include "exe_graph/lowering/kernel_run_context_builder.h" -#include "register/op_impl_registry.h" - -namespace gert { -class TilingParseContextBuilder { - public: - TilingParseContextBuilder &CompileJson(const ge::char_t *compile_json); - TilingParseContextBuilder &PlatformInfo(void *platform_info); - TilingParseContextBuilder &CompileInfoCreatorFunc(OpImplRegisterV2::CompileInfoCreatorFunc create_func); - TilingParseContextBuilder &CompileInfoDeleterFunc(OpImplRegisterV2::CompileInfoDeleterFunc delete_func); - KernelContextHolder Build(const ge::Operator &op); - - private: - void *compile_json_{ nullptr }; - void *platform_info_{ nullptr }; - OpImplRegisterV2::CompileInfoCreatorFunc create_func_{ nullptr }; - OpImplRegisterV2::CompileInfoDeleterFunc delete_func_{ nullptr }; -}; -} // namespace gert -#endif // GE_RUNTIME_TILING_PARSE_CONTEXT_BUILDER_H_ diff --git a/inc/exe_graph/lowering/value_holder.h b/inc/exe_graph/lowering/value_holder.h deleted file mode 100644 index a14a5e8e3d14f487571be01ffaa9a82f60ac32b5..0000000000000000000000000000000000000000 --- a/inc/exe_graph/lowering/value_holder.h +++ /dev/null @@ -1,225 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef AIR_CXX_RUNTIME_V2_GRAPH_BUILDER_VALUE_HOLDER_H_ -#define AIR_CXX_RUNTIME_V2_GRAPH_BUILDER_VALUE_HOLDER_H_ -#include -#include -#include -#include - -#include "graph/buffer.h" -#include "graph/any_value.h" -#include "graph/compute_graph.h" -#include "graph/utils/node_utils.h" -#include "graph/utils/fast_node_utils.h" -#include "graph/node.h" -#include "common/hyper_status.h" -#include "graph_frame.h" -#include "exe_graph/runtime/tensor.h" -#include "common/checker.h" -#include "common/util/mem_utils.h" -#include "graph/fast_graph/execute_graph.h" - -namespace gert { -namespace bg { -class ValueHolder; -using ValueHolderPtr = std::shared_ptr; -class ValueHolder { - public: - enum class ValueHolderType { - kConst, // 常量,执行时不变 - kFeed, // 执行时外部指定 - kOutput, // 由node产生,包含数据输出与控制输出 - kConstData, // 常量Const,执行时由外部指定,执行时不变 - // Add new type definitions here - kValueHolderTypeEnd - }; - - class CurrentComputeNodeGuarder { - public: - explicit CurrentComputeNodeGuarder(ge::NodePtr old_node) : old_node_(std::move(old_node)) {} - ~CurrentComputeNodeGuarder() { - ValueHolder::SetCurrentComputeNode(old_node_); - } - - private: - ge::NodePtr old_node_; - }; - - ValueHolder(const ValueHolder &other) = delete; - ValueHolder &operator=(const ValueHolder &other) = delete; - virtual ~ValueHolder(); - - bool IsOk() const noexcept; - - HyperStatus AddInnerDataToKVMap(int32_t index) const noexcept; - - int64_t GetId() const noexcept; - ValueHolderType GetType() const noexcept; - - ge::FastNode *GetFastNode() const noexcept; - ge::ExecuteGraph *GetExecuteGraph() const noexcept; - - ValueHolderPtr GetGuarder() const noexcept; - void SetGuarder(const bg::ValueHolderPtr &guarder) noexcept; - - int32_t GetOutIndex() const noexcept; - - // ref-from other的含义是,本value指向了other(本value没有独立的内存) - ge::graphStatus RefFrom(const ValueHolderPtr &other); - - // 在other产生后,本holder的生命周期才结束 - void ReleaseAfter(const ValueHolderPtr &other); - - const int32_t &GetPlacement() const; - void SetPlacement(const int32_t &placement); - - template - std::vector> AppendOutputs(size_t append_count, Args... args) { - auto start_index = fast_node_->GetDataOutNum(); - auto ret = ge::FastNodeUtils::AppendOutputEdgeInfo(fast_node_, start_index + append_count); - if (ret != ge::GRAPH_SUCCESS) { - return {}; - } - return CreateFromNode(fast_node_, start_index, append_count, args...); - } - // src nodes may come from different graph from current node and can add data edges to current node - // currently only support to pass through parent nodes with only one subgraph - ge::graphStatus AppendInputs(const std::vector &src); - - static ValueHolderPtr CreateError(const ge::char_t *fmt, ...); - static ValueHolderPtr CreateError(const ge::char_t *fmt, va_list arg); - - static ValueHolderPtr CreateConst(const void *data, size_t size, bool is_string = false); - - static ValueHolderPtr CreateFeed(int64_t index); - - static ValueHolderPtr CreateConstData(int64_t index); - - static ValueHolderPtr CreateSingleDataOutput(const ge::char_t *node_type, const std::vector &inputs); - - static std::vector CreateDataOutput(const ge::char_t *node_type, - const std::vector &inputs, size_t out_count); - - template - static std::shared_ptr CreateVoid(const ge::char_t *node_type, const std::vector &inputs, - Args... args) { - auto node = CreateNode(node_type, inputs, 0); - GE_ASSERT_NOTNULL(node); - return CreateFromNode(node, -1, ValueHolderType::kOutput, args...); - } - - static ValueHolderPtr CreateVoidGuarder(const ge::char_t *node_type, const ValueHolderPtr &resource, - const std::vector &args); - - static HyperStatus AddDependency(const ValueHolderPtr &src, const ValueHolderPtr &dst); - - /** - * 压栈一个Root GraphFrame,只有栈底的GraphFrame才被称为ROOT GraphFrame,因此调用此借口前,需要保证栈内不存在GraphFrame,否则会失败 - * @return 成功后,返回创建好的GraphFrame指针,失败时返回空指针 - */ - static GraphFrame *PushGraphFrame(); - /** - * 压栈一个非root的GraphFrame - * @param belongs 新加入的GraphFrame所归属的ValueHolder,新压栈的GraphFrame会被挂在该ValueHolder所归属的Node上 - * @param graph_name 挂接GraphFrame到Node时,使用的name - * @return 创建且挂接成功后,返回创建好的GraphFrame指针,失败时返回空指针 - */ - static GraphFrame *PushGraphFrame(const ValueHolderPtr &belongs, const ge::char_t *graph_name); - /** - * 压栈一个GraphFrame, 若该graph frame非root frame,需要保证栈顶frame为其父frame - * @return 成功后,返回该GraphFrame指针,失败时返回空指针 - */ - static GraphFrame *PushGraphFrame(GraphFrame *graph_frame); - - static std::unique_ptr PopGraphFrame(); - static std::unique_ptr PopGraphFrame(const std::vector &outputs, - const std::vector &targets); - - static std::unique_ptr PopGraphFrame(const std::vector &outputs, - const std::vector &targets, - const ge::char_t *out_node_type); - - static GraphFrame *GetCurrentFrame(); - - static void ClearGraphFrameResource(); - - static ge::ExecuteGraph *GetCurrentExecuteGraph(); - - static void SetCurrentComputeNode(const ge::NodePtr &node); - static void AddRelevantInputNode(const ge::NodePtr &node); - static std::unique_ptr SetScopedCurrentComputeNode(const ge::NodePtr &node); - - static ge::FastNode *AddNode(const ge::char_t *node_type, size_t input_count, size_t output_count, - const GraphFrame &frame); - - template - static std::vector> CreateFromNode(ge::FastNode *node, size_t start_index, - size_t create_count, Args... args) { - if (node == nullptr) { - return {create_count, nullptr}; - } - std::vector> holders; - for (size_t i = 0; i < create_count; ++i) { - holders.emplace_back( - CreateFromNode(node, static_cast(i + start_index), ValueHolderType::kOutput, args...)); - } - - return holders; - } - - template - static std::shared_ptr CreateFromNode(ge::FastNode *node, int32_t index, ValueHolderType type, Args... args) { - auto holder = std::shared_ptr(new (std::nothrow) T(args...)); - GE_ASSERT_NOTNULL(holder); - - holder->type_ = type; - holder->fast_node_ = node; - holder->index_ = index; - holder->op_desc_ = holder->fast_node_->GetOpDescPtr(); - return holder; - } - - virtual ValueHolderPtr CreateMateFromNode(ge::FastNode *node, int32_t index, ValueHolderType type); - - static std::string GenerateNodeName(const ge::char_t *node_type, const GraphFrame &frame); - - static std::vector GetLastExecNodes(); - - protected: - ValueHolder(); - - static ge::FastNode *CreateNode(const ge::char_t *node_type, const std::vector &inputs, - size_t out_count); - - template - static std::vector> CreateFromNodeStart(ge::FastNode *node, size_t out_count, - Args... args) { - return CreateFromNode(node, 0U, out_count, args...); - } - - void SetErrorMsg(const char *fmt, va_list arg); - - private: - static std::atomic id_generator_; - int64_t id_; - ValueHolder::ValueHolderType type_; - ge::FastNode *fast_node_; // 通过ValueHolder创建的fast_node节点如果后续在图中被删除,此处会是无效指针,不能直接使用 - ge::OpDescPtr op_desc_; - int32_t index_; - int32_t placement_; - std::unique_ptr error_msg_; - ValueHolderPtr guarder_; - friend class ValueHolderUtils; -}; -} // namespace bg -} // namespace gert - -#endif // AIR_CXX_RUNTIME_V2_GRAPH_BUILDER_VALUE_HOLDER_H_ diff --git a/inc/exe_graph/lowering/value_holder_utils.h b/inc/exe_graph/lowering/value_holder_utils.h deleted file mode 100644 index 5dbad024663a737e5aab01df5d96d7c5dd871236..0000000000000000000000000000000000000000 --- a/inc/exe_graph/lowering/value_holder_utils.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef AIR_CXX_RUNTIME_V2_GRAPH_BUILDER_VALUE_HOLDER_UTILS_H_ -#define AIR_CXX_RUNTIME_V2_GRAPH_BUILDER_VALUE_HOLDER_UTILS_H_ - -#include "value_holder.h" -#include "graph/op_desc.h" - -namespace gert { -namespace bg { -class ValueHolderUtils { -public: - static bool IsNodeValid(const ValueHolderPtr &holder); - - static bool IsNodeEqual(const ValueHolderPtr &src, const ValueHolderPtr &dst); - - static std::string GetNodeName(const ValueHolderPtr &holder); - static const char *GetNodeNameBarePtr(const ValueHolderPtr &holder); - - static std::string GetNodeType(const ValueHolderPtr &holder); - static const char *GetNodeTypeBarePtr(const ValueHolderPtr &holder); - - static ge::OpDescPtr GetNodeOpDesc(const ValueHolderPtr &holder); - static ge::OpDesc *GetNodeOpDescBarePtr(const ValueHolderPtr &holder); - - static bool IsDirectlyControlled(const bg::ValueHolderPtr &src, const bg::ValueHolderPtr &dst); -}; -} // bg -} // gert - -#endif // AIR_CXX_RUNTIME_V2_GRAPH_BUILDER_VALUE_HOLDER_UTILS_H_ diff --git a/inc/exe_graph/runtime/allocator.h b/inc/exe_graph/runtime/allocator.h deleted file mode 100644 index 9d7c44e49633879ba1edbebc14ffe78241f70ba7..0000000000000000000000000000000000000000 --- a/inc/exe_graph/runtime/allocator.h +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_INC_EXE_GRAPH_RUNTIME_ALLOCATOR_H_ -#define METADEF_INC_EXE_GRAPH_RUNTIME_ALLOCATOR_H_ -#include -#include "exe_graph/runtime/tensor.h" - -namespace gert { -enum class AllocatorUsage { - kAllocNodeOutput, - kAllocNodeWorkspace, - kAllocNodeShapeBuffer, - kEnd -}; -struct AllocatorDesc { - TensorPlacement placement; - AllocatorUsage usage; - bool operator<(const AllocatorDesc &other) const { - return std::tie(placement, usage) < std::tie(other.placement, other.usage); - } - std::string GetKey() const { - return "Allocator-" + std::to_string(placement); - } -}; -} -#endif // METADEF_INC_EXE_GRAPH_RUNTIME_ALLOCATOR_H_ diff --git a/inc/exe_graph/runtime/atomic_clean_tiling_context.h b/inc/exe_graph/runtime/atomic_clean_tiling_context.h deleted file mode 100644 index 1e7f4db0853837a4869cb6897a498e4d06f83402..0000000000000000000000000000000000000000 --- a/inc/exe_graph/runtime/atomic_clean_tiling_context.h +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_INC_EXE_GRAPH_RUNTIME_ATOMICCLEANTILINGCONTEXT_H_ -#define METADEF_CXX_INC_EXE_GRAPH_RUNTIME_ATOMICCLEANTILINGCONTEXT_H_ -#include "exe_graph/runtime/tiling_context.h" -#include "exe_graph/runtime/continuous_vector.h" -namespace gert { -class AtomicCleanTilingContext : public TilingContext { - public: - /** - * 获取workspace size的列表 - * @return workspace size列表 - */ - const ContinuousVector *GetCleanWorkspaceSizes() const { - return GetInputPointer(0); - } - - /** - * 通过节点的输出index,获取需要清理的输出内存的大小 - * @param index 节点输出index - * @return 需要清理的输出内存的大小 - */ - uint64_t GetCleanOutputSize(size_t index) const { - return GetInputValue(index + 1U); - } -}; -static_assert(std::is_standard_layout::value, - "The class AtomicCleanTilingContext must be a POD"); -} // namespace gert -#endif // METADEF_CXX_INC_EXE_GRAPH_RUNTIME_ATOMICCLEANTILINGCONTEXT_H_ diff --git a/inc/exe_graph/runtime/continuous_buffer.h b/inc/exe_graph/runtime/continuous_buffer.h deleted file mode 100644 index 69df5e9a6807715a38174e9eafa815c48698485c..0000000000000000000000000000000000000000 --- a/inc/exe_graph/runtime/continuous_buffer.h +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_INC_EXE_GRAPH_RUNTIME_CONTINUOUS_BUFFER_H_ -#define METADEF_CXX_INC_EXE_GRAPH_RUNTIME_CONTINUOUS_BUFFER_H_ -#include -#include -#include "graph/def_types.h" -namespace gert { -namespace bg { -class BufferPool; -} -class ContinuousBuffer { - public: - /** - * 获取buffer的数量 - * @return buffer的数量 - */ - size_t GetNum() const { - return num_; - } - /** - * 获取本实例的总长度 - * @return 本实例的总长度,单位为字节 - */ - size_t GetTotalLength() const { - return offsets_[num_]; - } - /** - * 获取一个buffer - * @tparam T buffer的类型 - * @param index buffer的index - * @return 指向该buffer的指针,若index非法,则返回空指针 - */ - template - const T *Get(size_t index) const { - if (index >= num_) { - return nullptr; - } - return ge::PtrToPtr(ge::PtrToPtr(this) + offsets_[index]); - } - /** - * 获取一个buffer,及其对应的长度 - * @tparam T buffer的类型 - * @param index buffer的index - * @param len buffer的长度 - * @return 指向该buffer的指针,若index非法,则返回空指针 - */ - template - const T *Get(size_t index, size_t &len) const { - if (index >= num_) { - return nullptr; - } - len = offsets_[index + 1] - offsets_[index]; - return ge::PtrToPtr(ge::PtrToPtr(this) + offsets_[index]); - } - /** - * 获取一个buffer - * @tparam T buffer的类型 - * @param index buffer的index - * @return 指向该buffer的指针,若index非法,则返回空指针 - */ - template - T *Get(size_t index) { - if (index >= num_) { - return nullptr; - } - return ge::PtrToPtr(ge::PtrToPtr(this) + offsets_[index]); - } - - private: - friend ::gert::bg::BufferPool; - size_t num_; - int64_t reserved_; // Reserved field, 8-byte aligned - size_t offsets_[1]; -}; -static_assert(std::is_standard_layout::value, "The class ContinuousText must be POD"); -} // namespace gert -#endif // METADEF_CXX_INC_EXE_GRAPH_RUNTIME_CONTINUOUS_BUFFER_H_ diff --git a/inc/exe_graph/runtime/dfx_info_filler.h b/inc/exe_graph/runtime/dfx_info_filler.h deleted file mode 100644 index d620c407a7d0962b25959e7fcbeeb36fa1054702..0000000000000000000000000000000000000000 --- a/inc/exe_graph/runtime/dfx_info_filler.h +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_INC_EXE_GRAPH_RUNTIME_DFX_INFO_FILLER_H -#define METADEF_INC_EXE_GRAPH_RUNTIME_DFX_INFO_FILLER_H - -#include -#include -#include - -namespace gert { -enum class NodeProfInfoType : uint32_t { - kOriginalNode, - kCmoPreFetch, - kCmoInvalidate, - kCmoWriteBack, - kMixVectorCore, - kNodeTypeMax -}; - -class ProfilingInfoWrapper { - public: - virtual ~ProfilingInfoWrapper() = default; - - virtual void SetBlockDim(uint32_t block_dim) { - (void)block_dim; - } - - virtual void SetBlockDim(const uint32_t block_dim, const NodeProfInfoType prof_info_type) { - (void) block_dim; - (void) prof_info_type; - } - - virtual void SetMixLaunchEnable(const bool mix_launch_enable) { - (void) mix_launch_enable; - } - - virtual void SetLaunchTimeStamp(const uint64_t begin_time, const uint64_t end_time, - const NodeProfInfoType prof_info_type) { - (void) begin_time; - (void) end_time; - (void) prof_info_type; - } - - virtual void SetBlockDimForAtomic(uint32_t block_dim) { - (void)block_dim; - } - - virtual ge::graphStatus FillShapeInfo(const std::vector> &input_shapes, - const std::vector> &output_shapes) { - (void)input_shapes; - (void)output_shapes; - return ge::GRAPH_SUCCESS; - } -}; - -class DataDumpInfoWrapper { - public: - virtual ~DataDumpInfoWrapper() = default; - virtual ge::graphStatus CreateFftsCtxInfo(uint32_t thread_id, uint32_t context_id) = 0; - virtual ge::graphStatus AddFftsCtxAddr(uint32_t thread_id, bool is_input, uint64_t address, uint64_t size) = 0; - virtual void AddWorkspace(uintptr_t addr, int64_t bytes) = 0; - virtual bool SetStrAttr(const std::string &name, const std::string &value) = 0; -}; - -class ExceptionDumpInfoWrapper { - public: - virtual ~ExceptionDumpInfoWrapper() = default; - virtual void SetTilingData(uintptr_t addr, size_t size) = 0; - virtual void SetTilingKey(uint32_t key) = 0; - virtual void SetHostArgs(uintptr_t addr, size_t size) = 0; - virtual void SetDeviceArgs(uintptr_t addr, size_t size) = 0; - virtual void AddWorkspace(uintptr_t addr, int64_t bytes) = 0; -}; -} - -#endif - diff --git a/inc/exe_graph/runtime/dvpp_context.h b/inc/exe_graph/runtime/dvpp_context.h deleted file mode 100644 index e4f6667fbacad5ad53a022565dcf404b77357f6b..0000000000000000000000000000000000000000 --- a/inc/exe_graph/runtime/dvpp_context.h +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_INC_EXE_GRAPH_RUNTIME_DVPP_CONTEXT_H_ -#define METADEF_CXX_INC_EXE_GRAPH_RUNTIME_DVPP_CONTEXT_H_ -#include -#include "exe_graph/runtime/storage_shape.h" -#include "exe_graph/runtime/tensor.h" -#include "exe_graph/runtime/extended_kernel_context.h" - -namespace gert { -/** - * Dvpp kernel的context - */ -class DvppContext : public ExtendedKernelContext { -public: - /** - * 获取输入shape,输入shape中包含了原始shape与运行时shape - * @param index 输入index - * @return 输入shape指针,index非法时返回空指针 - */ - const StorageShape *GetInputShape(size_t index) const { - auto compute_node_info = GetComputeNodeInfo(); - if (compute_node_info == nullptr) { - return nullptr; - } - - if (index >= compute_node_info->GetInputsNum()) { - return nullptr; - } - - return GetInputPointer(index); - } - - /** - * 获取输入tensor - * - * **注意:只有在`IMPL_OP`实现算子时, 将对应输入设置为数据依赖后, - * 才可以调用此接口获取tensor,否则行为是未定义的。** - * @param index 输入index - * @return 输入tensor指针,index非法时返回空指针 - */ - const Tensor *GetInputTensor(size_t index) const { - return GetInputPointer(index); - } - - /** - * 根据输出index,获取输出shape指针,shape中包含了原始shape与运行时shape - * @param index 输出index - * @return 输出shape指针,index非法时,返回空指针 - */ - const StorageShape *GetOutputShape(size_t index) const { - auto compute_node_info = GetComputeNodeInfo(); - if (compute_node_info == nullptr) { - return nullptr; - } - - if (index >= compute_node_info->GetOutputsNum()) { - return nullptr; - } - - size_t offset = compute_node_info->GetInputsNum(); - return GetInputPointer(offset + index); - } -}; -static_assert(std::is_standard_layout::value, - "The class DvppContext must be a POD"); -} // namespace gert -#endif // METADEF_CXX_INC_EXE_GRAPH_RUNTIME_DVPP_CONTEXT_H_ diff --git a/inc/exe_graph/runtime/execute_graph_types.h b/inc/exe_graph/runtime/execute_graph_types.h deleted file mode 100644 index 6ee4d6c35fd4f0c9edea39e16a23311cba368e34..0000000000000000000000000000000000000000 --- a/inc/exe_graph/runtime/execute_graph_types.h +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_INC_EXE_GRAPH_RUNTIME_EXECUTE_GRAPH_TYPES_H_ -#define METADEF_CXX_INC_EXE_GRAPH_RUNTIME_EXECUTE_GRAPH_TYPES_H_ -#include -namespace gert { -/** - * 执行图类型,在一个Model中,包含多张执行图,本枚举定义了所有执行图的类型 - */ -enum class ExecuteGraphType { - kInit, //!< 初始化图,本张图在图加载阶段执行 - kMain, //!< 主图,每次执行图时,均执行本张图 - kDeInit, //!< 去初始化图,在图卸载时,执行本张图 - kNum -}; - -/** - * 获取执行图的字符串描述 - * @param type 执行图类型枚举 - * @return - */ -inline const char *GetExecuteGraphTypeStr(const ExecuteGraphType type) { - if (type >= ExecuteGraphType::kNum) { - return nullptr; - } - constexpr const char *kStrs[static_cast(ExecuteGraphType::kNum)] = {"Init", "Main", "DeInit"}; - return kStrs[static_cast(type)]; -} -} // namespace gert -#endif // METADEF_CXX_INC_EXE_GRAPH_RUNTIME_EXECUTE_GRAPH_TYPES_H_ diff --git a/inc/exe_graph/runtime/infer_symbol_shape_context.h b/inc/exe_graph/runtime/infer_symbol_shape_context.h deleted file mode 100644 index a6d95947bbda8d3d745fd37fe895d820fb35127f..0000000000000000000000000000000000000000 --- a/inc/exe_graph/runtime/infer_symbol_shape_context.h +++ /dev/null @@ -1,124 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_INC_EXE_GRAPH_RUNTIME_INFER_SYMBOL_SHAPE_CONTEXT_H_ -#define METADEF_CXX_INC_EXE_GRAPH_RUNTIME_INFER_SYMBOL_SHAPE_CONTEXT_H_ - -#include -#include "exe_graph/runtime/runtime_attrs.h" -#include "symbolic_tensor.h" -#include "exe_graph/runtime/extended_kernel_context.h" - -namespace gert { -class InferSymbolShapeContext : public ExtendedKernelContext { - public: - /** - * 根据输入index,获取输入symbol shape指针,该接口仅在编译态使用; - * @param index 输入index; - * @return 输入symbol shape指针,index非法时,返回空指针。 - */ - const SymbolShape *GetInputSymbolShape(const size_t index) const { - if (GetInputSymbolTensor(index) == nullptr) { - return nullptr; - } - return &(GetInputSymbolTensor(index)->GetOriginSymbolShape()); - } - - /** - * 基于算子IR原型定义,获取`OPTIONAL_INPUT`类型的输入symbol shape指针,该接口仅在编译态使用; - * @param ir_index IR原型定义中的index; - * @return symbol shape指针,index非法,或该INPUT没有实例化时,返回空指针。 - */ - const SymbolShape *GetOptionalInputSymbolShape(const size_t ir_index) const { - if (GetDynamicInputSymbolTensor(ir_index, 0) == nullptr) { - return nullptr; - } - return &(GetDynamicInputSymbolTensor(ir_index, 0)->GetOriginSymbolShape()); - } - - /** - * 基于算子IR原型定义,获取`DYNAMIC_INPUT`类型的输入symbol shape指针,该接口仅在编译态使用; - * @param ir_index IR原型定义中的index; - * @param relative_index 该输入实例化后的相对index,例如某个DYNAMIC_INPUT实例化了3个输入,那么relative_index的有效范围是[0,2]; - * @return symbol shape指针,index或relative_index非法时,返回空指针。 - */ - const SymbolShape *GetDynamicInputSymbolShape(const size_t ir_index, const size_t relative_index) const { - if (GetDynamicInputSymbolTensor(ir_index, relative_index) == nullptr) { - return nullptr; - } - return &(GetDynamicInputSymbolTensor(ir_index, relative_index)->GetOriginSymbolShape()); - } - - /** - * 基于算子IR原型定义,获取`REQUIRED_INPUT`类型的输入symbol shape指针,该接口仅在编译态使用; - * @param ir_index IR原型定义中的index - * @return symbol shape指针,index非法,或该INPUT没有实例化时,返回空指针 - */ - const SymbolShape *GetRequiredInputSymbolShape(const size_t ir_index) const { - if (GetDynamicInputSymbolTensor(ir_index, 0) == nullptr) { - return nullptr; - } - return &(GetDynamicInputSymbolTensor(ir_index, 0)->GetOriginSymbolShape()); - } - - /** - * 根据输入index,获取输入SymbolTensor指针,该接口仅在编译态使用; - * 若算子被配置为'data'数据依赖,则返回的SymbolTensor对象中保存了的符号值;反之,内存地址为nullptr。 - * @param index 输入index - * @return 输入SymbolTensor指针,index非法时,返回空指针 - */ - const SymbolTensor *GetInputSymbolTensor(const size_t index) const { - return GetInputPointer(index); - } - - /** - * 基于算子IR原型定义,获取`OPTIONAL_INPUT`类型的输入SymbolTensor指针,该接口仅在编译态使用; - * 若算子被配置为'data'数据依赖,则返回的SymbolTensor对象中保存了的符号值;反之,内存地址为nullptr。 - * @param ir_index IR原型定义中的index - * @return SymbolTensor指针,index非法,或该INPUT没有实例化时,返回空指针 - */ - const SymbolTensor *GetOptionalInputSymbolTensor(const size_t ir_index) const { - return GetDynamicInputPointer(ir_index, 0); - } - - /** - * 基于算子IR原型定义,获取`DYNAMIC_INPUT`类型的输入tensor指针,该接口仅在编译态使用; - * 若算子被配置为'data'数据依赖,则返回的SymbolTensor对象中保存了的符号值;反之,内存地址为nullptr。 - * @param ir_index IR原型定义中的index - * @param relative_index 该输入实例化后的相对index,例如某个DYNAMIC_INPUT实例化了3个输入,那么relative_index的有效范围是[0,2] - * @return SymbolTensor指针,index或relative_index非法时,返回空指针 - */ - const SymbolTensor *GetDynamicInputSymbolTensor(const size_t ir_index, const size_t relative_index) const { - return GetDynamicInputPointer(ir_index, relative_index); - } - - /** - * 基于算子IR原型定义,获取`REQUIRED_INPUT`类型的输入SymbolTensor指针,该接口仅在编译态使用; - * 若算子被配置为'data'数据依赖,则返回的SymbolTensor对象中保存了的符号值;反之,内存地址为nullptr。 - * @param ir_index IR原型定义中的index - * @return SymbolTensor指针,index非法时,返回空指针 - */ - const SymbolTensor *GetRequiredInputSymbolTensor(const size_t ir_index) const { - return GetDynamicInputPointer(ir_index, 0); - } - - /** - * 根据输出index,获取输出符号化Symbolshape指针,该接口仅在编译态使用; - * @param index 输出index; - * @return 输出符号化Symbolshape指针,index非法时,返回空指针。 - */ - SymbolShape *GetOutputSymbolShape(const size_t index) { - return GetOutputPointer(index); - } -}; - -static_assert(std::is_standard_layout::value, - "The class InferSymbolShapeContext must be a POD"); -} // namespace gert -#endif // METADEF_CXX_INC_EXE_GRAPH_RUNTIME_INFER_SYMBOL_SHAPE_CONTEXT_H_ diff --git a/inc/exe_graph/runtime/symbolic_shape.h b/inc/exe_graph/runtime/symbolic_shape.h deleted file mode 100644 index 8df4e3b56e3f6a1a0c13360c69b1df4b8ae2a006..0000000000000000000000000000000000000000 --- a/inc/exe_graph/runtime/symbolic_shape.h +++ /dev/null @@ -1,191 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_INC_EXE_GRAPH_SYMBOL_SHAPE_H_ -#define METADEF_CXX_INC_EXE_GRAPH_SYMBOL_SHAPE_H_ - -#include -#include -#include -#include -#include -#include -#include "utils/extern_math_util.h" -#include "graph/symbolizer/symbolic.h" -#include "graph/debug/ge_util.h" - -namespace gert { -/* - * 注意:此类是一个符号化的shape,它的成员变量是一个Expression数组,只在编译态使用,暂不考虑POD形式组织该类 - * */ -class SymbolShape { - public: - /** - * 默认构造一个符号化symbol shape,默认构造的shape实例中,dims长度为空 - */ - SymbolShape() = default; - - /** - * 通过dims值构造符号化shape,例如:SymbolShape({&s0, &s1, &s2, &s3})创建一个Shape实例,有4个维度, - * 每个维度的值分别是s0, s1, s2, s3 - * @param dims shape的所有dim值 - */ - SymbolShape(const std::initializer_list &args) : dims_(args) {} - - /** - * 拷贝构造函数,移动构造函数 - * @param other - */ - SymbolShape(const SymbolShape &other) = default; - SymbolShape &operator=(const SymbolShape &other) = default; - SymbolShape(SymbolShape &&other) = default; - SymbolShape &operator=(SymbolShape &&other) = default; - - /** - * 判断与另外一个shape对象是否相等,如果两个shape的dim num并且dim num内每个dim中的symbol值都相等,那么认为两个symbol shape相等 - * @param rht 另一个Shape对象 - * @return true/false - */ - bool operator==(const SymbolShape &rht) const { - if (this->dims_.size() != rht.dims_.size()) { - return false; - } - for (size_t i = 0; i < this->dims_.size(); i++) { - if ((this->dims_[i].IsValid()) && (rht.dims_[i].IsValid()) && (this->dims_[i] != rht.dims_[i])) { - return false; - } - } - return true; - } - - /** - * 判断与另一个Shape对象是否不等 - * @param rht 另一个SymbolShape对象 - * @return true/false - */ - bool operator!=(const SymbolShape &rht) const { - return !(*this == rht); - } - - /** - * 获取shape size表达式,如果是scalar场景,返回Symbol(1),如果symbol_shape中某个表达式非法,那么返回Symbol(0) - * @return shape-size,是一个Expression表达式 - */ - const ge::Expression &GetSymbolShapeSize() const { - if (symsize_cache_is_valid_) { // 性能优化,避免重复计算 - return symbol_shape_size_; - } - symbol_shape_size_ = ge::Symbol(1); - for (const auto &dim : dims_) { - if (dim.IsValid()) { - symbol_shape_size_ = ge::sym::Mul(symbol_shape_size_, dim); - } else { - static auto kZero = ge::Symbol(0); - return kZero; - } - } - symsize_cache_is_valid_ = true; - return symbol_shape_size_; - } - - /** - * 判断本Symbol shape是否为标量,所谓标量,是指dims的长度为0,即shape为标量 - * @return true/false - */ - bool IsScalar() const { - return dims_.empty(); - } - - /** - * 设置shape为标量 - * @param none - */ - void SetScalar() { - MutableDims().clear(); - } - - /** - * 清空symbol shape的所有维度 - * @return none - */ - void Clear() { - MutableDims().clear(); - } - - /** - * 获取dim num - * @return - */ - size_t GetDimNum() const { - return dims_.size(); - } - - /** - * 向后扩展一个dim值,如果扩展的dim数量超出Shape的最大限制,那么本函数不做任何事情 - * @param 扩展的dim值 - * @return this引用 - */ - SymbolShape &AppendDim(const ge::Expression &dim_value) { - MutableDims().emplace_back(dim_value); - return *this; - } - - /** - * 获取只读的symbol shape的所有维度的常量引用 - * @return 返回一个常量list,返回所有维度的符号化表达,例如[s0, s1, s2],返回[s0, s1, s2] - */ - const std::vector &GetDims() const { - return dims_; - } - - /** - * 获取只读的第idx位置的dim值 - * @param idx dim的index,调用者需要保证index合法 - * @return dim值,Expression指针类型,在idx超出MaxDimNum时,会触发vector访问异常 - */ - const ge::Expression &GetDim(const size_t idx) const { - return dims_[idx]; - } - - /** - * 获取可修改的symbol shape的所有维度的引用 - * @return 返回一个常量list,返回所有维度的符号化表达,例如[s0, s1, s2],返回[s0, s1, s2] - */ - std::vector &MutableDims() { - symsize_cache_is_valid_ = false; - return dims_; - } - - /** - * 获取只读的第idx位置的dim值 - * @param idx dim的index,调用者需要保证index合法 - * @return dim值,Expression指针类型,在idx超出MaxDimNum时,会触发vector访问异常 - */ - const ge::Expression &GetDim(const size_t idx) { - return dims_[idx]; - } - - /** - * 获取可修改的第idx位置的dim值 - * @param idx dim的index,调用者需要保证index合法 - * @return dim值,Expression指针类型,在idx超出MaxDimNum时,会触发vector访问异常 - */ - ge::Expression &MutableDim(const size_t idx) { - symsize_cache_is_valid_ = false; - return dims_[idx]; - } - - private: - std::vector dims_; - mutable bool symsize_cache_is_valid_{false}; // 性能优化,避免重复计算Symbol shape size - mutable ge::Expression symbol_shape_size_; -}; -} // namespace gert - -#endif // METADEF_CXX_INC_EXE_GRAPH_SYMBOL_SHAPE_H_ diff --git a/inc/exe_graph/runtime/symbolic_tensor.h b/inc/exe_graph/runtime/symbolic_tensor.h deleted file mode 100644 index 56527c9afa32e154f7f7d6c468ae7f0111e555a0..0000000000000000000000000000000000000000 --- a/inc/exe_graph/runtime/symbolic_tensor.h +++ /dev/null @@ -1,129 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_INC_EXE_GRAPH_SYMBOL_TENSOR_H_ -#define METADEF_CXX_INC_EXE_GRAPH_SYMBOL_TENSOR_H_ - -#include -#include -#include -#include -#include -#include -#include "utils/extern_math_util.h" -#include "graph/debug/ge_util.h" -#include "symbolic_shape.h" - -namespace gert { -class SymbolTensor { - public: - SymbolTensor() = default; - SymbolTensor(const std::initializer_list &origin_symbol_shape) - : origin_symbol_shape_(origin_symbol_shape) {} - - SymbolTensor(const std::initializer_list &origin_symbol_shape, - const std::initializer_list &symbolic_values) - : origin_symbol_shape_(origin_symbol_shape), - symbolic_values_(ge::ComGraphMakeUnique>(symbolic_values)) {} - - // 拷贝构造函数 - SymbolTensor(const SymbolTensor &other) - : origin_symbol_shape_(other.origin_symbol_shape_), - symbolic_values_(other.symbolic_values_ - ? ge::ComGraphMakeUnique>(*other.symbolic_values_) - : nullptr) {} - - // 拷贝赋值运算符 - SymbolTensor &operator=(const SymbolTensor &other) { - if (this != &other) { - origin_symbol_shape_ = other.origin_symbol_shape_; - if (other.symbolic_values_) { - symbolic_values_ = ge::ComGraphMakeUnique>(*other.symbolic_values_); - } else { - symbolic_values_.reset(); - } - } - return *this; - } - - // 移动构造函数 - SymbolTensor(SymbolTensor &&other) noexcept = default; - - // 移动赋值运算符 - SymbolTensor &operator=(SymbolTensor &&other) noexcept = default; - - /** - * 获取只读的原始格式符号化shape - * @return 原始格式符号化shape引用 - */ - const SymbolShape &GetOriginSymbolShape() const { - return origin_symbol_shape_; - } - /** - * 获取可修改的原始格式符号化shape - * @return 原始格式符号化shape引用 - */ - SymbolShape &MutableOriginSymbolShape() { - return origin_symbol_shape_; - } - /** - * 设置原始格式符号化shape - * @return void - */ - void SetOriginSymbolShape(const SymbolShape &ori_symbol_shape) { - origin_symbol_shape_ = ori_symbol_shape; - } - - /** - * 设置原始格式、存储格式符号化shape - * @return void - */ - void SetSymbolShape(const SymbolShape &symbol_shape) { - SetOriginSymbolShape(symbol_shape); - } - /** - * 获取symbol tensor中存储的符号值,一般在data dependent算子场景使用 - * 1. 返回值为nullptr,表明无符号值 - * 2. 返回值为{},表明符号值为空 - * 3. 返回值不为空,表明存在符号值 - * - * @return 只读的symbol tensor符号值; - */ - const std::vector *GetSymbolicValue() const { - return symbolic_values_.get(); - } - /** - * 设置symbol tensor中存储的符号值 - * @param symbolic_values - */ - void SetSymbolicValue(std::unique_ptr> symbolic_values) { - symbolic_values_ = std::move(symbolic_values); - } - /** - * 获取tensor data - * @return 可写的symbol tensor符号值 - */ - std::vector *MutableSymbolicValue() { - if (!symbolic_values_) { - symbolic_values_ = ge::ComGraphMakeUnique>(); - } - - if (!symbolic_values_) { - return nullptr; - } - - return symbolic_values_.get(); - } - private: - SymbolShape origin_symbol_shape_; - std::unique_ptr> symbolic_values_; -}; -} // namespace gert - -#endif // METADEF_CXX_INC_EXE_GRAPH_SYMBOL_SHAPE_H_ diff --git a/inc/exe_graph/runtime/tensor_data_utils.h b/inc/exe_graph/runtime/tensor_data_utils.h deleted file mode 100644 index fd0cf55ca013d9e149ee45b5c584c98fcd5b7e2d..0000000000000000000000000000000000000000 --- a/inc/exe_graph/runtime/tensor_data_utils.h +++ /dev/null @@ -1,110 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_INC_EXE_GRAPH_TENSOR_DATA_UTILS_H_ -#define METADEF_CXX_INC_EXE_GRAPH_TENSOR_DATA_UTILS_H_ - -#include "exe_graph/runtime/tensor_data.h" -#include "graph/types.h" - -namespace gert { -namespace { -struct PlacementBase { - virtual ~PlacementBase() = default; -}; -struct PlacementDeviceHbm : public PlacementBase { - ~PlacementDeviceHbm() override = default; -}; -struct PlacementDeviceP2p : public PlacementDeviceHbm { - ~PlacementDeviceP2p() override = default; -}; -struct PlacementHost : public PlacementBase { - ~PlacementHost() override = default; -}; - -class PlacementClassFactory { - public: - const PlacementBase *Get(const TensorPlacement placement) const { - switch (placement) { - case kOnDeviceHbm: - return &hbm_; - case kOnDeviceP2p: - return &p2p_; - case kOnHost: - case kFollowing: - return &host_; - case kTensorPlacementEnd: - return &base_; - default: - return &base_; - } - } - bool CanSrcDynamicCastToDst(const TensorPlacement src, const TensorPlacement dst) const { - const auto *src_ptr = Get(src); - bool cast_success; - switch (dst) { - case kOnDeviceHbm: - cast_success = (dynamic_cast(src_ptr) != nullptr); - break; - case kOnDeviceP2p: - cast_success = (dynamic_cast(src_ptr) != nullptr); - break; - case kOnHost: - case kFollowing: - cast_success = (dynamic_cast(src_ptr) != nullptr); - break; - case kTensorPlacementEnd: - cast_success = (dynamic_cast(src_ptr) != nullptr); - break; - default: - cast_success = (dynamic_cast(src_ptr) != nullptr); - break; - } - return cast_success; - } - - private: - PlacementDeviceHbm hbm_; - PlacementDeviceP2p p2p_; - PlacementHost host_; - PlacementBase base_; -}; -} - -inline const ge::char_t *GetPlacementStr(const TensorPlacement placement) { - static const ge::char_t *placement_str[static_cast(kTensorPlacementEnd) + 1] = {"DeviceHbm", "HostDDR", - "HostDDR", "DeviceP2p", - "Unknown"}; - if ((placement >= kTensorPlacementEnd) || (placement < kOnDeviceHbm)) { - return placement_str[kTensorPlacementEnd]; - } - return placement_str[placement]; -} - -/** - * 判断源placement到目的placement是否需要拷贝 - * @param src_placement 源placement - * @param dst_placement 目的placement - */ -inline bool IsPlacementSrcToDstNeedCopy(const TensorPlacement src_placement, const TensorPlacement dst_placement) { - if ((src_placement >= kTensorPlacementEnd) || (dst_placement >= kTensorPlacementEnd)) { - return true; - } - - static PlacementClassFactory factory; - const auto *dst_class_ptr = factory.Get(dst_placement); - const auto *src_class_ptr = factory.Get(src_placement); - if (dst_class_ptr == src_class_ptr) { - return false; - } - - return !factory.CanSrcDynamicCastToDst(src_placement, dst_placement); -} -} // namespace gert -#endif // METADEF_CXX_INC_EXE_GRAPH_TENSOR_DATA_UTILS_H_ diff --git a/inc/external/exe_graph/lowering/lowering_opt.h b/inc/external/exe_graph/lowering/lowering_opt.h deleted file mode 100644 index f6de626de0d7e71b053fa293a2058c33816480e1..0000000000000000000000000000000000000000 --- a/inc/external/exe_graph/lowering/lowering_opt.h +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_EXTERNAL_LOWERING_OPT_H_ -#define INC_EXTERNAL_LOWERING_OPT_H_ - -#include - -namespace gert { -struct LoweringOption { - /** - * 是否相信用户传入的输出tensor上的shape,如果开启了本选项,可以省去计算图上输出节点的InferShape,提升一点Host调度性能。 - * 与此同时,也会损失掉对外部传入的输出Tensor Shape、TensorData长度的校验能力。 - * - * 约束: - * 1. 如果一个节点有多个输出,并且部分输出并不是网络的输出, - * 那么这个节点的InferShape不会被省掉,体现为在这个节点上,本选项会被忽略。 - * 2. 如果一个节点没有InferShape函数,例如第三类、第四类算子, - * 需要从Device拷贝回Shape,那么在这个节点上,本选项会被忽略。 - * 3. 本选项是个加载时选项,一旦选定后,意味着后续本model的每次调用都需要用户传入输出shape,否则可能会导致执行失败 - */ - bool trust_shape_on_out_tensor = false; - - /** - * 总是零拷贝开关,默认关闭。如果本开关打开,含义是外部调用者总是保证会正确地申请输出内存,包含: - * 1. 申请的输出内存大于等于输出shape所以计算出的Tensor大小 - * 2. 输出内存的placement正确 - * - * 打开本开关后,可以提升一点Host调度性能。与此同时,对于零拷贝失效的回退处理将不再进行, - * 在外部申请的输出内存错误、或未申请输出内存时,执行报错。 - */ - bool always_zero_copy = false; - - /** - * 总是使用外部allocator开关,默认关闭。如果本开关打开,含义是外部调用者总是保证会传入所有allocator,包含: - * 1. 创建所有在加载/执行阶段所需的allocator并传入 - * 2. 由于总是信任外部allocator,一旦开启后,如果在加载/执行阶段获取外置allocator失败,则报错。 - * - * 打开本开关后,在执行器内部不需要再创建allocator,减少资源浪费 - */ - bool always_external_allocator = false; - - /** - * 使能单流,默认关闭。如果本开关打开,执行时动态根图任务下发在一条流上。 - * 该开关由rt2的使用者(acl/hybrid model)根据设备上流是否充裕,来决定是否只使用一条流资源。 - * - */ - bool enable_single_stream = false; - - /** - * 二进制兼容保留字段,增加option时,对应缩减删除reserved长度 - */ - uint8_t reserved[4U + 8U] = {0U}; -}; -} // namespace ge - -#endif // INC_EXTERNAL_GRAPH_GRAPH_H_ diff --git a/inc/external/ge/framework/common/taskdown_common.h b/inc/external/ge/framework/common/taskdown_common.h deleted file mode 100644 index d31ee6af5e5b02c3b2a11d5250ec67b0bc51989b..0000000000000000000000000000000000000000 --- a/inc/external/ge/framework/common/taskdown_common.h +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_EXTERNAL_GE_FRAMEWORK_COMMON_TASKDOWN_COMMON_H_ -#define INC_EXTERNAL_GE_FRAMEWORK_COMMON_TASKDOWN_COMMON_H_ - -namespace ge { -enum class ccKernelType : uint32_t { - CCE_AI_CORE = 0, /* cce aicore */ - CCE_AI_CPU = 1, /* cce aicpu */ - TE = 2, /* te operator */ - CUSTOMIZED = 3, /* customized operator */ - TE_AI_CORE = 4, /* te aicore operator */ - TE_AI_CPU = 5, /* te aicpu operator */ - AI_CPU = 6, /* aicpu */ - CUST_AI_CPU = 7, /* custom aicpu */ - HOST_CPU = 8, /* host cpu */ - DVPP = 9, /* dvpp */ - AI_CPU_KFC = 10, /* aicpu kfc */ - MIX_AICORE = 11, - MIX_VECTOR_CORE = 12, /* vector core only */ - INVALID = 10000 /* unknown kernel type */ -}; -} // namespace ge - -#endif // INC_EXTERNAL_GE_FRAMEWORK_COMMON_TASKDOWN_COMMON_H_ diff --git a/inc/external/graph/arg_desc_info.h b/inc/external/graph/arg_desc_info.h deleted file mode 100644 index d41a19c1d430738125a09e8c33ff9f347b267a51..0000000000000000000000000000000000000000 --- a/inc/external/graph/arg_desc_info.h +++ /dev/null @@ -1,138 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ - -#ifndef METADEF_INC_EXTERNAL_GRAPH_ARG_DESC_INFO_H -#define METADEF_INC_EXTERNAL_GRAPH_ARG_DESC_INFO_H - -#include -#include -#include -#include "graph/ge_error_codes.h" -#include "ascend_string.h" - -namespace ge { -enum class ArgDescType { - kIrInput = 0, // 输入 - kIrOutput, // 输出 - kWorkspace, // workspace地址 - kTiling, // tiling地址 - kHiddenInput, // ir上不表达的额外输入 - kCustomValue, // 自定义内容 - kIrInputDesc, // 具有描述信息的输入地址 - kIrOutputDesc, // 具有描述信息的输入地址 - kInputInstance, // 实例化的输入 - kOutputInstance, // 实例化的输出 - kEnd -}; - -enum class HiddenInputSubType { - kHcom, // 用于通信的hiddenInput,mc2算子使用 - kEnd -}; - -class ArgDescInfoImpl; -using ArgDescInfoImplPtr = std::unique_ptr; -class ArgDescInfo { - public: - /** - * 构造ArgDescInfo对象,ArgDescInfo对象主要用于描述args中某一个地址所表达的含义 - * @param arg_type 当前args地址的类型 - * @param ir_index 当前args地址对应算子的ir索引 - * @param is_folded 当前args地址是否为二级指针 (即是否将多个地址折叠到一个二级指针中设置给args,设置为true) - */ - explicit ArgDescInfo(ArgDescType arg_type, - int32_t ir_index = -1, bool is_folded = false); - ~ArgDescInfo(); - ArgDescInfo(const ArgDescInfo &other); - ArgDescInfo(ArgDescInfo &&other) noexcept; - ArgDescInfo &operator=(const ArgDescInfo &other); - ArgDescInfo &operator=(ArgDescInfo &&other) noexcept; - /** - * 构造一个CustomValue类型的ArgDescInfo对象 - * @param custom_value 自定义内容 - * @return ArgDescInfo对象 - */ - static ArgDescInfo CreateCustomValue(uint64_t custom_value); - /** - * 构造一个HiddenInput类型的ArgDescInfo对象 - * @param hidden_type hidden输入的类型 - * @return ArgDescInfo对象 - */ - static ArgDescInfo CreateHiddenInput(HiddenInputSubType hidden_type); - /** - * 获取当前ArgDescInfo的类型 - * @return 当ArgDescType非法时,返回kEnd,合法时,返回此arg地址的类型(未设置时的默认值为kEnd) - */ - ArgDescType GetType() const; - /** - * 获取自定义内容的值,只有当type为kCustomValue时,才能获取到内容 - * @return 当ArgDescType非法时,返回uint64_max, 合法时,返回自定义内容(未设置时的默认值为0) - */ - uint64_t GetCustomValue() const; - /** - * 设置自定义内容,只有当type为kCustomValue时,才能设置此字段 - * @param custom_value 自定义内容 - * @return SUCCESS: 设置成功 其他:ArgDescInfo非法或者type为非kCustomValue - */ - graphStatus SetCustomValue(uint64_t custom_value); - /** - * 获取hidden输入的type,只有当type为kHiddenInput时,才能获取到内容 - * @return 当ArgDescType非法时,返回kEnd, 合法时,返回hidden输入的type(未设置时的默认值为kEnd) - */ - HiddenInputSubType GetHiddenInputSubType() const; - /** - * 设置hidden输入的type,只有当type为kHiddenInput时,才能设置此字段 - * @param hidden_type hidden输入的type - * @return SUCCESS: 设置成功 其他:ArgDescInfo非法或者type为非kHiddenInput - */ - graphStatus SetHiddenInputSubType(HiddenInputSubType hidden_type); - /** - * 获取当前arg地址对应的ir索引 - * @return 返回ir索引(未设置时的默认值为-1) - */ - int32_t GetIrIndex() const; - /** - * 设置当前arg地址对应的ir索引 - * @param ir_index ir索引 - */ - void SetIrIndex(int32_t ir_index); - /** - * 判断当前arg地址是否为二级指针 - * @return true: 是二级指针; false:不是二级指针(未设置时的默认值为false) - */ - bool IsFolded() const; - /** - * 设置当前arg地址是否为二级指针 - * @param is_folded:是否为二级指针 - */ - void SetFolded(bool is_folded); - private: - friend class ArgsFormatSerializer; - ArgDescInfo() = delete; - explicit ArgDescInfo(ArgDescInfoImplPtr &&impl); - std::unique_ptr impl_; -}; - -class ArgsFormatSerializer { - public: - /** - * 序列化argsFormat,argsFormat是由若干个ArgDescInfo组成,每一个ArgDescInfo表达当前args地址的信息 - * @param args_format args_format信息 - * @return 成功:返回序列化后的argsFormat; 失败:空字符串 - */ - static AscendString Serialize(const std::vector &args_format); - /** - * 将一个argsFormat的序列化字符串反序列化 - * @param args_str args_format的序列化字符串 - * @return 成功:返回反序列化后的argsFormat; 失败:空vector - */ - static std::vector Deserialize(const AscendString &args_str); -}; -} -#endif // METADEF_INC_EXTERNAL_GRAPH_ARG_DESC_INFO_H \ No newline at end of file diff --git a/inc/external/graph/ct_infer_shape_context.h b/inc/external/graph/ct_infer_shape_context.h deleted file mode 100644 index 04a490a87409875d2bc2c050b149abbb79f627c9..0000000000000000000000000000000000000000 --- a/inc/external/graph/ct_infer_shape_context.h +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_INC_GRAPH_CT_INFER_SHAPE_CONTEXT_H_ -#define METADEF_CXX_INC_GRAPH_CT_INFER_SHAPE_CONTEXT_H_ -#include -#include "exe_graph/runtime/infer_shape_context.h" -#include "graph/inference_context.h" - -namespace gert { -/** - * 在节点输入后的扩展输入的索引,若需要扩展,请新增枚举类型 - */ -enum class CtInferShapeInputExternLayout : uint32_t { - kInferShapeFunc = 0, - kInferenceContext = 1, -}; - -class CtInferShapeContext : public InferShapeContext { - public: - /** - * 获取InferenceContext指针 - * @param NA - * @return 输出InferenceContext指针 - */ - ge::InferenceContext *GetInferenceContext() const { - const auto compute_node_info = GetComputeNodeInfo(); - if (compute_node_info == nullptr) { - return nullptr; - } - const auto offset = - compute_node_info->GetInputsNum() + static_cast(CtInferShapeInputExternLayout::kInferenceContext); - return MutableInputPointer(offset); - } -}; -static_assert(std::is_standard_layout::value, "The class CtInferShapeContext must be a POD"); -} // namespace gert -#endif // METADEF_CXX_INC_GRAPH_CT_INFER_SHAPE_CONTEXT_H_ diff --git a/inc/external/graph/ct_infer_shape_range_context.h b/inc/external/graph/ct_infer_shape_range_context.h deleted file mode 100644 index 8dfc159984b4ef1349e158bc59158a1f3436bf11..0000000000000000000000000000000000000000 --- a/inc/external/graph/ct_infer_shape_range_context.h +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_INC_GRAPH_CT_INFER_SHAPE_RANGE_CONTEXT_H_ -#define METADEF_CXX_INC_GRAPH_CT_INFER_SHAPE_RANGE_CONTEXT_H_ -#include -#include "exe_graph/runtime/infer_shape_range_context.h" -#include "graph/inference_context.h" - -namespace gert { -/** - * 在节点输入后的扩展输入的索引,若需要扩展,请新增枚举类型 - */ -enum class CtInferShapeRangeInputExternLayout : uint32_t { - kInferShapeFunc = 0, - kInferenceContext = 1, -}; - -class CtInferShapeRangeContext : public InferShapeRangeContext { - public: - /** - * 获取InferenceContext指针 - * @param NA - * @return 输出InferenceContext指针 - */ - ge::InferenceContext *GetInferenceContext() const { - const auto compute_node_info = GetComputeNodeInfo(); - if (compute_node_info == nullptr) { - return nullptr; - } - const auto offset = - compute_node_info->GetInputsNum() + static_cast(CtInferShapeRangeInputExternLayout::kInferenceContext); - return MutableInputPointer(offset); - } -}; -static_assert(std::is_standard_layout::value, - "The class CtInferShapeRangeContext must be a POD"); -} // namespace gert -#endif // METADEF_CXX_INC_GRAPH_CT_INFER_SHAPE_RANGE_CONTEXT_H_ diff --git a/inc/external/graph/gnode.h b/inc/external/graph/gnode.h deleted file mode 100644 index 46a9381203c45204b08c2e6fc41fbc3e2421f593..0000000000000000000000000000000000000000 --- a/inc/external/graph/gnode.h +++ /dev/null @@ -1,147 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_EXTERNAL_GRAPH_GNODE_H_ -#define INC_EXTERNAL_GRAPH_GNODE_H_ - -#include -#include - -#include "./ge_error_codes.h" -#include "./types.h" -#include "./tensor.h" -#include "./ascend_string.h" - -namespace ge { -class AttrValue; -class GNode; -class OpDesc; -class Graph; -class ComputeGraph; -using GNodePtr = std::shared_ptr; -using GraphPtr = std::shared_ptr; -using OpBytes = std::vector; -using OpDescPtr = std::shared_ptr; -using ComputeGraphPtr = std::shared_ptr; - -class NodeImpl; -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GNode { - public: - GNode(); - - ~GNode() = default; - - graphStatus GetType(AscendString &type) const; - - graphStatus GetName(AscendString &name) const; - - std::pair GetInDataNodesAndPortIndexs(const int32_t index) const; - - std::vector GetInControlNodes() const; - - std::vector> GetOutDataNodesAndPortIndexs(const int32_t index) const; - - std::vector GetOutControlNodes() const; - - graphStatus GetInputConstData(const int32_t index, Tensor &data) const; - - graphStatus GetInputIndexByName(const AscendString &name, int32_t &index); - - graphStatus GetOutputIndexByName(const AscendString &name, int32_t &index); - - graphStatus GetDynamicInputIndexesByName(const AscendString &name, std::vector &indexes); - - graphStatus GetDynamicOutputIndexesByName(const AscendString &name, std::vector &indexes); - - size_t GetInputsSize() const; - - size_t GetOutputsSize() const; - - graphStatus GetInputDesc(const int32_t index, TensorDesc &tensor_desc) const; - - graphStatus UpdateInputDesc(const int32_t index, const TensorDesc &tensor_desc); - - graphStatus GetOutputDesc(const int32_t index, TensorDesc &tensor_desc) const; - - graphStatus UpdateOutputDesc(const int32_t index, const TensorDesc &tensor_desc); - - graphStatus GetAttr(const AscendString &name, int64_t &attr_value) const; - graphStatus GetAttr(const AscendString &name, int32_t &attr_value) const; - graphStatus GetAttr(const AscendString &name, uint32_t &attr_value) const; - graphStatus GetAttr(const AscendString &name, float32_t &attr_value) const; - graphStatus GetAttr(const AscendString &name, AscendString &attr_value) const; - graphStatus GetAttr(const AscendString &name, bool &attr_value) const; - graphStatus GetAttr(const AscendString &name, Tensor &attr_value) const; - graphStatus GetAttr(const AscendString &name, std::vector &attr_value) const; - graphStatus GetAttr(const AscendString &name, std::vector &attr_value) const; - graphStatus GetAttr(const AscendString &name, std::vector &attr_value) const; - graphStatus GetAttr(const AscendString &name, std::vector &attr_value) const; - graphStatus GetAttr(const AscendString &name, std::vector &attr_values) const; - graphStatus GetAttr(const AscendString &name, std::vector &attr_value) const; - graphStatus GetAttr(const AscendString &name, std::vector &attr_value) const; - graphStatus GetAttr(const AscendString &name, OpBytes &attr_value) const; - graphStatus GetAttr(const AscendString &name, std::vector> &attr_value) const; - graphStatus GetAttr(const AscendString &name, std::vector &attr_value) const; - graphStatus GetAttr(const AscendString &name, ge::DataType &attr_value) const; - graphStatus GetAttr(const AscendString &name, AttrValue &attr_value) const; - - graphStatus SetAttr(const AscendString &name, int64_t &attr_value) const; - graphStatus SetAttr(const AscendString &name, int32_t &attr_value) const; - graphStatus SetAttr(const AscendString &name, uint32_t &attr_value) const; - graphStatus SetAttr(const AscendString &name, float32_t &attr_value) const; - graphStatus SetAttr(const AscendString &name, AscendString &attr_value) const; - graphStatus SetAttr(const AscendString &name, bool &attr_value) const; - graphStatus SetAttr(const AscendString &name, Tensor &attr_value) const; - graphStatus SetAttr(const AscendString &name, std::vector &attr_value) const; - graphStatus SetAttr(const AscendString &name, std::vector &attr_value) const; - graphStatus SetAttr(const AscendString &name, std::vector &attr_value) const; - graphStatus SetAttr(const AscendString &name, std::vector &attr_value) const; - graphStatus SetAttr(const AscendString &name, std::vector &attr_values) const; - graphStatus SetAttr(const AscendString &name, std::vector &attr_value) const; - graphStatus SetAttr(const AscendString &name, std::vector &attr_value) const; - graphStatus SetAttr(const AscendString &name, OpBytes &attr_value) const; - graphStatus SetAttr(const AscendString &name, std::vector> &attr_value) const; - graphStatus SetAttr(const AscendString &name, std::vector &attr_value) const; - graphStatus SetAttr(const AscendString &name, ge::DataType &attr_value) const; - graphStatus SetAttr(const AscendString &name, AttrValue &attr_value) const; - - // 添加AttrValue类型的输入输出属性支持 - graphStatus GetOutputAttr(const AscendString &name, uint32_t output_index, AttrValue &attr_value) const; - graphStatus SetOutputAttr(const AscendString &name, uint32_t output_index, const AttrValue &attr_value); - graphStatus GetInputAttr(const AscendString &name, uint32_t input_index, AttrValue &attr_value) const; - graphStatus SetInputAttr(const AscendString &name, uint32_t input_index, const AttrValue &attr_value); - - bool HasAttr(const AscendString &name); - - graphStatus GetSubgraph(uint32_t index, GraphPtr &graph) const; - - graphStatus GetALLSubgraphs(std::vector &graph_list) const; - - /** - * @brief Add the subgraph to the node - * @param subgraph_ir_name IR subgraph name - * @param subgraph the subgraph to be added - * @return GRAPH_SUCCESS: success, others: failed - */ - graphStatus SetSubgraph(const AscendString &subgraph_ir_name, const Graph &subgraph); - - /** - * @brief Add subgraphs to the node - * @param subgraph_ir_name Dynamic IR subgraphs name - * @param subgraphs subgraphs to be added - * @return GRAPH_SUCCESS: success, others: failed - */ - graphStatus SetSubgraphs(const AscendString &subgraph_ir_name, const std::vector &subgraphs); - private: - std::shared_ptr impl_; - friend class NodeAdapter; -}; -} // namespace ge - -#endif // INC_EXTERNAL_GRAPH_GNODE_H_ diff --git a/inc/external/graph/graph.h b/inc/external/graph/graph.h deleted file mode 100644 index bf6d0c394edd401a5e7dbf3027eb81e85a4bff72..0000000000000000000000000000000000000000 --- a/inc/external/graph/graph.h +++ /dev/null @@ -1,185 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_EXTERNAL_GRAPH_GRAPH_H_ -#define INC_EXTERNAL_GRAPH_GRAPH_H_ - -#include -#include -#include -#include - -#include "./operator.h" -#include "./gnode.h" - -namespace ge { -class Graph; -class GraphImpl; -class GraphBuffer; - -using GraphImplPtr = std::shared_ptr; -using GraphPtr = std::shared_ptr; - -using ConstGraphPtr = std::shared_ptr; - -/*lint -e148*/ -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph { - friend class GraphUtils; - friend class GraphUtilsEx; - - public: - ATTRIBUTED_DEPRECATED(Graph(const char_t *)) - explicit Graph(const std::string &name); - - explicit Graph(const char_t *name); - - Graph() = default; - - ~Graph() = default; - /** - * 触发内部图的构建, 用于基于Operator的IR构图场景 - * @param inputs 图的输入节点 - * @return - */ - Graph &SetInputs(const std::vector &inputs); - - Graph &SetOutputs(const std::vector &outputs); - - Graph &SetOutputs(const std::vector>> &output_indexs); - - ATTRIBUTED_DEPRECATED(Graph &SetOutputs(const std::vector> &outputs); - - Graph &SetOutputs(const std::vector> &outputs); - - Graph &SetTargets(const std::vector &targets); - - bool IsValid() const; - graphStatus SetValid(); - - graphStatus AddOp(const ge::Operator &op); - - ATTRIBUTED_DEPRECATED(graphStatus FindOpByName(const char_t *, ge::Operator &)) - graphStatus FindOpByName(const std::string &name, ge::Operator &op) const; - - graphStatus FindOpByName(const char_t *name, ge::Operator &op) const; - - ATTRIBUTED_DEPRECATED(graphStatus FindOpByType(const char_t *, std::vector &)) - graphStatus FindOpByType(const std::string &type, std::vector &ops) const; - - graphStatus FindOpByType(const char_t *type, std::vector &ops) const; - - ATTRIBUTED_DEPRECATED(graphStatus GetAllOpName(std::vector &) const) - graphStatus GetAllOpName(std::vector &op_name) const; - - graphStatus GetAllOpName(std::vector &names) const; - - ATTRIBUTED_DEPRECATED(graphStatus SaveToFile(const char_t *file_name) const) - graphStatus SaveToFile(const std::string &file_name) const; - - graphStatus SaveToFile(const char_t *file_name) const; - - ATTRIBUTED_DEPRECATED(graphStatus LoadFromFile(const char_t *)) - graphStatus LoadFromFile(const std::string &file_name); - - graphStatus LoadFromFile(const char_t *file_name); - - graphStatus LoadFromSerializedModelArray(const void *serialized_model, size_t size); - - graphStatus SaveToMem(GraphBuffer &graph_buffer) const; - - graphStatus LoadFromMem(const GraphBuffer &graph_buffer); - - graphStatus LoadFromMem(const uint8_t *data, const size_t len); - - ATTRIBUTED_DEPRECATED(graphStatus GetName(AscendString &) const) - const std::string &GetName() const; - - graphStatus GetName(AscendString &name) const; - - /// - /// Set is need train iteration. - /// If set true, it means this graph need to be run iteration some - /// times(according variant "npu_runconfig/iterations_per_loop"). - /// @param need_iteration need_iteration:whether to set iteration or not - /// - void SetNeedIteration(bool need_iteration); - - std::vector GetAllNodes() const; - - std::vector GetDirectNode () const; - - graphStatus RemoveNode(GNode &node); - - graphStatus RemoveNode(GNode &node, bool contain_subgraph); - - graphStatus RemoveEdge(GNode &src_node, const int32_t src_port_index, GNode &dst_node, const int32_t dst_port_index); - - GNode AddNodeByOp(const Operator &op); - - graphStatus AddDataEdge(GNode &src_node, const int32_t src_port_index, - GNode &dst_node, const int32_t dst_port_index); - - graphStatus AddControlEdge(GNode &src_node, GNode &dst_node); - - graphStatus CopyFrom(const Graph &src_graph); - - /** - * @brief Find the GNode with the target node_name in the graph - * @param node_name GNode name - * @return GNodePtr GNode pointer in the graph, return nullptr if failed - */ - GNodePtr FindNodeByName(const AscendString &node_name) const; - - /** - * @brief Get the parent graph of current sub graph - * @return ConstGraphPtr The parent graph shared pointer of current graph, return nullptr if failed - */ - ConstGraphPtr GetParentGraph() const; - - /** - * @brief Get the parent node of current sub graph - * @return GNodePtr The parent node shared pointer of current graph, return nullptr if failed - */ - GNodePtr GetParentNode() const; - - static GraphPtr ConstructFromInputs(const std::vector &inputs, const AscendString &name); - - // 添加AttrValue类型的属性支持 - graphStatus SetAttr(const AscendString &name, const AttrValue &attr_value); - graphStatus GetAttr(const AscendString &name, AttrValue &attr_value) const; - - enum class DumpFormat : uint32_t { - kOnnx, - kTxt - }; - /** - * 将graph序列化到ostream中 - * 不包含权重等数据,只包含图结构及相关属性 - * @param format - * @param o_stream - * @return - */ - graphStatus Dump(DumpFormat format, std::ostream &o_stream) const; - - /** - * 将graph序列化到执行路径下的文件中 - * 不包含权重等数据,只包含图结构及相关属性 - * @param suffix - * @return - */ - graphStatus DumpToFile(DumpFormat format, const AscendString &suffix) const; - - private: - - GraphImplPtr impl_{nullptr}; -}; -} // namespace ge - -#endif // INC_EXTERNAL_GRAPH_GRAPH_H_ diff --git a/inc/external/graph/graph_buffer.h b/inc/external/graph/graph_buffer.h deleted file mode 100644 index a5e3f56be789c28a9364ef6d2b7f98b2bec57021..0000000000000000000000000000000000000000 --- a/inc/external/graph/graph_buffer.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_EXTERNAL_GRAPH_BUFFER_H_ -#define INC_EXTERNAL_GRAPH_BUFFER_H_ - -#include -#include -#include -#include "./types.h" - -namespace ge { -class Graph; -class Buffer; -using BufferPtr = std::shared_ptr; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GraphBuffer { - public: - GraphBuffer(); - GraphBuffer(const GraphBuffer &) = delete; - GraphBuffer &operator=(const GraphBuffer &) = delete; - ~GraphBuffer(); - - const std::uint8_t *GetData() const; - std::size_t GetSize() const; - - private: - BufferPtr buffer_{nullptr}; - friend class Graph; -}; -} // namespace ge -#endif // INC_EXTERNAL_GRAPH_BUFFER_H_ diff --git a/inc/external/graph/kernel_launch_info.h b/inc/external/graph/kernel_launch_info.h deleted file mode 100644 index 065d1b8b995c6bf7c37b034b2a8374c3368e74f3..0000000000000000000000000000000000000000 --- a/inc/external/graph/kernel_launch_info.h +++ /dev/null @@ -1,117 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_INC_EXTERNAL_GRAPH_KERNEL_LAUNCH_INFO_H -#define METADEF_INC_EXTERNAL_GRAPH_KERNEL_LAUNCH_INFO_H - -#include -#include -#include -#include "exe_graph/runtime/exe_res_generation_context.h" - -namespace ge { -class KernelLaunchInfoImpl; -using KernelLaunchInfoImplPtr = std::unique_ptr; -class KernelLaunchInfo { - public: - ~KernelLaunchInfo(); - KernelLaunchInfo(const KernelLaunchInfo &other); - KernelLaunchInfo(KernelLaunchInfo &&other) noexcept; - KernelLaunchInfo &operator=(const KernelLaunchInfo &other); - KernelLaunchInfo &operator=(KernelLaunchInfo &&other) noexcept; - - /** - * 从字符串中加载算子的Launch信息 - * @param context gentask callback函数的入参,保存了算子的基础信息 - * @param data 算子launch信息的序列化数据流 - * @return KernelLaunchInfo对象,保存了算子的Launch信息 - */ - static KernelLaunchInfo LoadFromData(const gert::ExeResGenerationContext *context, - const std::vector &data); - /** - * 创建一个Aicpu通信算子Task - * @param context gentask callback函数的入参,保存了算子的基础信息 - * @param so_name aicpu算子的so名字 - * @param kernel_name aicpu算子的入口函数名字 - * @return KernelLaunchInfo对象,保存了算子的Launch信息 - */ - static KernelLaunchInfo CreateAicpuKfcTask(const gert::ExeResGenerationContext *context, - const char *so_name, const char *kernel_name); - /** - * 创建一个Record Task,用于唤醒相同group_name的Wait Task - * @param context gentask callback函数的入参,保存了算子的基础信息 - * @param group_name Record task的分组名字,默认为group - * @return KernelLaunchInfo对象,保存了算子的Launch信息 - */ - static KernelLaunchInfo CreateHcomRecordTask(const gert::ExeResGenerationContext *context, - const char *group_name = "group"); - /** - * 创建一个Wait Task,用于阻塞当前流,当有相同group_name的Record Task被执行时,解除阻塞 - * @param context gentask callback函数的入参,保存了算子的基础信息 - * @param group_name Wait task的分组名字,默认为group - * @return KernelLaunchInfo对象,保存了算子的Launch信息 - */ - static KernelLaunchInfo CreateHcomWaitTask(const gert::ExeResGenerationContext *context, - const char *group_name = "group"); - /** - * 将KernelLaunchInfo序列化成数据流 - * @return 被序列化后的数据流 - */ - std::vector Serialize(); - /** - * 获取当前task所在流的id - * @return 当KernelLaucnhInfo合法时,返回当前task所在流的id(默认值为0),非法时返回int32_max - */ - uint32_t GetStreamId() const; - /** - * 设置task的流id - * @param stream_id 流id - */ - void SetStreamId(uint32_t stream_id); - /** - * 获取算子blockdim - * @return 当KernelLaucnhInfo合法时,返回当前算子的blockdim(默认值为0),非法时返回int32_max - */ - uint32_t GetBlockDim() const; - /** - * 设置blockdim - * @param block_dim 算子blockdim - * @return SUCCESS: 设置成功,其他:设置失败报错 - */ - graphStatus SetBlockDim(uint32_t block_dim); - /** - * 获取当前task的args_format, args_format信息是args内存的语义化表达,用户通过拼接一个argsFormat内容,告诉框架如何排布args内存, - * 只有aicpu和aicore算子有argsformat信息 - * @return 算子的args_format被设置时,返回args_format的序列化字符串,未设置时返回nullptr - */ - const char *GetArgsFormat() const; - /** - * 设置当前task的args_format, args_format信息是args内存的语义化表达,用户通过拼接一个argsFormat内容,告诉框架如何排布args内存, - * 只有aicpu和aicore算子有argsformat信息 - * @param args_format 算子的args_format信息 - * @return SUCCESS: 设置成功,其他:设置失败报错 - */ - graphStatus SetArgsFormat(const char *args_format); - /** - * 获取当前task的so_name, 只有aicpu算子可以获取到 - * @return 算子的so_name被设置时,返回so_name的字符串,未设置时返回nullptr - */ - const char *GetSoName() const; - /** - * 获取当前task的kernel_name, 只有aicpu算子可以获取到 - * @return 算子的kernel_name被设置时,返回kernel_name的字符串,未设置时返回nullptr - */ - const char *GetKernelName() const; - private: - KernelLaunchInfo() = delete; - explicit KernelLaunchInfo(KernelLaunchInfoImplPtr &&impl); - std::unique_ptr impl_; -}; -} -#endif // METADEF_INC_EXTERNAL_GRAPH_KERNEL_LAUNCH_INFO_H \ No newline at end of file diff --git a/inc/external/graph/operator_reg.h b/inc/external/graph/operator_reg.h deleted file mode 100644 index 6514c6568aa7feb2ba4b9f989aec4bf038d2d2ee..0000000000000000000000000000000000000000 --- a/inc/external/graph/operator_reg.h +++ /dev/null @@ -1,736 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_EXTERNAL_GRAPH_OPERATOR_REG_H_ -#define INC_EXTERNAL_GRAPH_OPERATOR_REG_H_ - -#include -#include -#include -#include - -#include "graph/operator.h" -#include "graph/operator_factory.h" -#include "graph/tensor.h" -#include "graph/types.h" -#include "graph/graph.h" - -#if defined(__GNUC__) || defined(__clang__) -#define FORCE_INLINE __attribute__((always_inline)) -#elif defined(_MSC_VER) || defined(__INTEL_COMPILER) -#define FORCE_INLINE __forceinline -#elif defined(__IBMCPP__) -#define FORCE_INLINE __inline(always) -#else -#define FORCE_INLINE inline -#endif - -template -ge::AscendString ConvertToAscendString(T str); - -template<> -inline ge::AscendString ConvertToAscendString(const char *str) { - return ge::AscendString(str); -} - -template<> -inline ge::AscendString ConvertToAscendString(std::string str) { - return ge::AscendString(str.c_str()); -} - -template<> -inline ge::AscendString ConvertToAscendString(ge::AscendString str) { - return str; -} - -template -std::vector ConvertToListAscendString(T strs); - -template<> -inline std::vector ConvertToListAscendString(std::vector strs) { - std::vector ascend_strs(strs.size()); - for (size_t i = 0; i < strs.size(); ++i) { - ascend_strs[i] = ge::AscendString(strs[i].c_str()); - } - return ascend_strs; -} - -template<> -inline std::vector ConvertToListAscendString(std::vector strs) { - return strs; -} -namespace ge { -using std::function; -using std::string; -using std::vector; - -#define ATTR_String(x, ...) \ - graphStatus get_attr_##x(AscendString &ret) const { \ - std::string ret_str = __VA_ARGS__; \ - if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ - ret = AscendString(ret_str.c_str()); \ - } \ - return GRAPH_SUCCESS; \ - } \ - _THIS_TYPE &set_attr_##x(const char *v) { \ - Operator::SetAttr(#x, v); \ - return *this; \ - } \ - _THIS_TYPE &set_attr_##x(const function &v) { \ - (void) v; \ - return *this; \ - } - -#define ATTR_ListString(x, ...) \ - graphStatus get_attr_##x(std::vector &ret) const { \ - std::vector ret_strs = __VA_ARGS__; \ - if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ - for (auto &ret_str : ret_strs) { \ - ret.emplace_back(ret_str.c_str()); \ - } \ - } \ - return GRAPH_SUCCESS; \ - } \ - _THIS_TYPE &set_attr_##x(const std::vector &v) { \ - Operator::SetAttr(#x, v); \ - return *this; \ - } \ - _THIS_TYPE &set_attr_##x(const function()> &v) { \ - (void) v; \ - return *this; \ - } - -#define ATTR_AscendString(x, ...) \ - graphStatus get_attr_##x(AscendString &ret) const { \ - AscendString ret_str = __VA_ARGS__; \ - if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ - ret = AscendString(ret_str); \ - } \ - return GRAPH_SUCCESS; \ - } - -#define ATTR_ListAscendString(x, ...) \ - graphStatus get_attr_##x(std::vector &ret) const { \ - std::vector ret_strs = __VA_ARGS__; \ - if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ - for (auto &ret_str : ret_strs) { \ - if (ret_str.GetString() != nullptr) { \ - ret.emplace_back(ret_str.GetString()); \ - } \ - } \ - } \ - return GRAPH_SUCCESS; \ - } - -#define ATTR_Int(x, ...) -#define ATTR_Float(x, ...) -#define ATTR_Bool(x, ...) -#define ATTR_Tensor(x, ...) -#define ATTR_Type(x, ...) -#define ATTR_NamedAttrs(x, ...) -#define ATTR_ListInt(x, ...) -#define ATTR_ListFloat(x, ...) -#define ATTR_ListBool(x, ...) -#define ATTR_ListTensor(x, ...) -#define ATTR_Bytes(x, ...) -#define ATTR_ListListInt(x, ...) -#define ATTR_ListType(x, ...) -#define ATTR_ListNamedAttrs(x, ...) - -#define SET_VALUE_String(x) auto value = ConvertToAscendString(x) -#define SET_VALUE_AscendString(x) auto value = ConvertToAscendString(x) - -#define SET_VALUE_ListString(x) \ - auto input = (x); \ - std::vector value = ConvertToListAscendString(input) - -#define SET_VALUE_ListAcendString(x) \ - auto input = (x); \ - std::vector value = ConvertToListAscendString(input) - -#define SET_VALUE_ListAscendString(x) \ - auto input = (x); \ - std::vector value = ConvertToListAscendString(input) - -#define SET_VALUE_Int(x) auto value = (x) -#define SET_VALUE_Float(x) auto value = (x) -#define SET_VALUE_Bool(x) auto value = (x) -#define SET_VALUE_Tensor(x) auto value = (x) -#define SET_VALUE_Type(x) auto value = (x) -#define SET_VALUE_NamedAttrs(x) auto value = (x) -#define SET_VALUE_ListInt(x) auto value = (x) -#define SET_VALUE_ListFloat(x) auto value = (x) -#define SET_VALUE_ListBool(x) auto value = (x) -#define SET_VALUE_ListTensor(x) auto value = (x) -#define SET_VALUE_Bytes(x) auto value = (x) -#define SET_VALUE_ListListInt(x) auto value = (x) -#define SET_VALUE_ListType(x) auto value = (x) -#define SET_VALUE_ListNamedAttrs(x) auto value = (x) - -#define REQUIRED_ATTR_String(x) \ - graphStatus get_attr_##x(AscendString &ret) const { \ - if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ - return GRAPH_FAILED; \ - } \ - return GRAPH_SUCCESS; \ - } \ - _THIS_TYPE &set_attr_##x(const char *v) { \ - Operator::SetAttr(#x, v); \ - return *this; \ - } \ - _THIS_TYPE &set_attr_##x(const function &v) { \ - (void) v; \ - return *this; \ - } - -#define REQUIRED_ATTR_ListString(x) \ - graphStatus get_attr_##x(std::vector &ret) const { \ - if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ - return GRAPH_FAILED; \ - } \ - return GRAPH_SUCCESS; \ - } \ - _THIS_TYPE &set_attr_##x(const std::vector &v) { \ - Operator::SetAttr(#x, v); \ - return *this; \ - } \ - _THIS_TYPE &set_attr_##x(const function()> &v) { \ - (void) v; \ - return *this; \ - } - -#define REQUIRED_ATTR_AscendString(x) \ - graphStatus get_attr_##x(AscendString &ret) const { \ - if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ - return GRAPH_FAILED; \ - } \ - return GRAPH_SUCCESS; \ - } - -#define REQUIRED_ATTR_ListAscendString(x) \ - graphStatus get_attr_##x(std::vector &ret) const { \ - if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ - return GRAPH_FAILED; \ - } \ - return GRAPH_SUCCESS; \ - } - -#define REQUIRED_ATTR_Int(x) -#define REQUIRED_ATTR_Float(x) -#define REQUIRED_ATTR_Bool(x) -#define REQUIRED_ATTR_Tensor(x) -#define REQUIRED_ATTR_Type(x) -#define REQUIRED_ATTR_NamedAttrs(x) -#define REQUIRED_ATTR_ListInt(x) -#define REQUIRED_ATTR_ListFloat(x) -#define REQUIRED_ATTR_ListBool(x) -#define REQUIRED_ATTR_ListTensor(x) -#define REQUIRED_ATTR_Bytes(x) -#define REQUIRED_ATTR_ListListInt(x) -#define REQUIRED_ATTR_ListType(x) -#define REQUIRED_ATTR_ListNamedAttrs(x) - -class OpReg { - public: - OpReg &N() { - return *this; - } - - OpReg &ATTR() { - return *this; - } - - OpReg &REQUIRED_ATTR() { - return *this; - } - - OpReg &INPUT() { - return *this; - } - - OpReg &OPTIONAL_INPUT() { - return *this; - } - - OpReg &OUTPUT() { - return *this; - } - - OpReg &GRAPH() { - return *this; - } - - OpReg &DYNAMIC_GRAPH() { - return *this; - } - - OpReg &INFER_SHAPE_AND_TYPE() { - return *this; - } -}; - -#define REG_OP(x) \ - namespace op { \ - class x : public Operator { \ - typedef x _THIS_TYPE; \ - \ - public: \ - ATTRIBUTED_DEPRECATED(x(const char *)) \ - explicit FORCE_INLINE x(const std::string &name) : Operator(name.c_str(), #x) { \ - __##x(); \ - } \ - explicit FORCE_INLINE x(const char *name) : Operator(name, #x) { \ - __##x(); \ - } \ - explicit FORCE_INLINE x(const AscendString &name) : Operator(name, #x) { \ - __##x(); \ - } \ - FORCE_INLINE x() : Operator(#x) { \ - __##x(); \ - } \ - \ - private: \ - void FORCE_INLINE __##x() { \ - OpReg() - -#define ATTR(x, Type, ...) \ - N(); \ - __attr_##x(); \ - } \ - \ - public: \ - ATTRIBUTED_DEPRECATED(static const void name_attr_##x(AscendString &)) \ - static const std::string name_attr_##x() { \ - return #x; \ - } \ - static void name_attr_##x(AscendString &attr) { \ - attr = AscendString(#x); \ - } \ - ATTR_##Type(x, __VA_ARGS__) Op##Type get_attr_##x() const { \ - Op##Type ret = __VA_ARGS__; \ - if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ - return ret; \ - } \ - return ret; \ - } \ - _THIS_TYPE &set_attr_##x(const Op##Type &v) { \ - Operator::SetAttr(#x, v); \ - return *this; \ - } \ - _THIS_TYPE &set_attr_##x(const function &v) { \ - (void) v; \ - return *this; \ - } \ - \ - private: \ - void FORCE_INLINE __attr_##x() { \ - SET_VALUE_##Type(Op##Type(__VA_ARGS__)); \ - Operator::AttrRegister(#x, value); \ - std::string attr_name(#x); \ - (void) OpReg() - -#define REQUIRED_ATTR(x, Type) \ - N(); \ - __required_attr_##x(); \ - } \ - \ - public: \ - ATTRIBUTED_DEPRECATED(static const void name_attr_##x(AscendString &)) \ - static const std::string name_attr_##x() { \ - return #x; \ - } \ - static void name_attr_##x(AscendString &attr_name) { \ - attr_name = AscendString(#x); \ - } \ - REQUIRED_ATTR_##Type(x) Op##Type get_attr_##x() const { \ - Op##Type ret; \ - if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ - return ret; \ - } \ - return ret; \ - } \ - _THIS_TYPE &set_attr_##x(const Op##Type &v) { \ - Operator::SetAttr(#x, v); \ - return *this; \ - } \ - _THIS_TYPE &set_attr_##x(const function &v) { \ - (void) v; \ - return *this; \ - } \ - \ - private: \ - void FORCE_INLINE __required_attr_##x() { \ - Operator::RequiredAttrWithTypeRegister(#x, #Type); \ - std::string attr_name(#x); \ - (void) OpReg() - -#define DATATYPE(x, t) \ - N(); \ - __datatype_##x(); \ - } \ - \ - private: \ - void FORCE_INLINE __datatype_##x() { \ - auto type_range = t; \ - Operator::DataTypeRegister(#x, type_range); \ - (void) OpReg() - -#define INPUT(x, t) \ - N(); \ - __input_##x(); \ - } \ - \ - public: \ - ATTRIBUTED_DEPRECATED(static const void name_in_##x(AscendString &)) \ - static const std::string name_in_##x() { \ - return #x; \ - } \ - static void name_in_##x(AscendString &name) { \ - name = AscendString(#x); \ - } \ - ATTRIBUTED_DEPRECATED(_THIS_TYPE &set_input_##x##_by_name(Operator &, const char *)) \ - _THIS_TYPE &set_input_##x(Operator &v, const std::string &srcName) { \ - Operator::SetInput(#x, v, srcName.c_str()); \ - return *this; \ - } \ - _THIS_TYPE &set_input_##x##_by_name(Operator &v, const char *srcName) { \ - Operator::SetInput(#x, v, srcName); \ - return *this; \ - } \ - _THIS_TYPE &set_input_##x(Operator &v, uint32_t index) { \ - Operator::SetInput(#x, v, index); \ - return *this; \ - } \ - _THIS_TYPE &set_input_##x(Operator &v) { \ - Operator::SetInput(#x, v); \ - return *this; \ - } \ - TensorDesc get_input_desc_##x() const { \ - return Operator::GetInputDescByName(#x); \ - } \ - graphStatus update_input_desc_##x(const TensorDesc &tensorDesc) { \ - return Operator::UpdateInputDesc(#x, tensorDesc); \ - } \ - \ - private: \ - void FORCE_INLINE __input_##x() { \ - Operator::InputRegister(#x, #t); \ - (void) OpReg() - -#define OPTIONAL_INPUT(x, t) \ - N(); \ - __optional_input_##x(); \ - } \ - \ - public: \ - ATTRIBUTED_DEPRECATED(static const void name_in_##x(AscendString &)) \ - static const std::string name_in_##x() { \ - return #x; \ - } \ - static void name_in_##x(AscendString &name) { \ - name = AscendString(#x); \ - } \ - _THIS_TYPE &set_input_##x(Operator &v) { \ - Operator::SetInput(#x, v); \ - return *this; \ - } \ - ATTRIBUTED_DEPRECATED(_THIS_TYPE &set_input_##x##_by_name(Operator &, const char *)) \ - _THIS_TYPE &set_input_##x(Operator &v, const std::string &srcName) { \ - Operator::SetInput(#x, v, srcName.c_str()); \ - return *this; \ - } \ - _THIS_TYPE &set_input_##x##_by_name(Operator &v, const char *srcName) { \ - Operator::SetInput(#x, v, srcName); \ - return *this; \ - } \ - _THIS_TYPE &set_input_##x(Operator &v, uint32_t index) { \ - Operator::SetInput(#x, v, index); \ - return *this; \ - } \ - TensorDesc get_input_desc_##x() const { \ - return Operator::GetInputDescByName(#x); \ - } \ - graphStatus update_input_desc_##x(const TensorDesc &tensorDesc) { \ - return Operator::UpdateInputDesc(#x, tensorDesc); \ - } \ - \ - private: \ - void FORCE_INLINE __optional_input_##x() { \ - Operator::OptionalInputRegister(#x, #t); \ - (void) OpReg() - -#define OUTPUT(x, t) \ - N(); \ - __out_##x(); \ - } \ - \ - public: \ - ATTRIBUTED_DEPRECATED(static const void name_out_##x(AscendString &)) \ - static const std::string name_out_##x() { \ - return #x; \ - } \ - static void name_out_##x(AscendString &name) { \ - name = AscendString(#x); \ - } \ - TensorDesc get_output_desc_##x() const { \ - return Operator::GetOutputDescByName(#x); \ - } \ - graphStatus update_output_desc_##x(const TensorDesc &tensorDesc) { \ - return Operator::UpdateOutputDesc(#x, tensorDesc); \ - } \ - \ - private: \ - void FORCE_INLINE __out_##x() { \ - Operator::OutputRegister(#x, #t); \ - (void) OpReg() - -#define DYNAMIC_INPUT(x, t) \ - N(); \ - __dy_input_##x(); \ - } \ - \ - public: \ - _THIS_TYPE &create_dynamic_input_##x(uint32_t num, bool isPushBack = true) { \ - Operator::DynamicInputRegister(#x, num, #t, isPushBack); \ - return *this; \ - } \ - _THIS_TYPE &create_dynamic_input_byindex_##x(uint32_t num, size_t index) { \ - Operator::DynamicInputRegisterByIndex(#x, num, index); \ - return *this; \ - } \ - TensorDesc get_dynamic_input_desc_##x(uint32_t index) const { \ - return Operator::GetDynamicInputDesc(#x, index); \ - } \ - graphStatus update_dynamic_input_desc_##x(uint32_t index, const TensorDesc &tensorDesc) { \ - return Operator::UpdateDynamicInputDesc(#x, index, tensorDesc); \ - } \ - _THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v) { \ - Operator::SetInput(#x, dstIndex, v); \ - return *this; \ - } \ - ATTRIBUTED_DEPRECATED(_THIS_TYPE &set_dynamic_input_##x(uint32_t, Operator &, const char *)) \ - _THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v, const std::string &srcName) { \ - Operator::SetInput(#x, dstIndex, v, srcName.c_str()); \ - return *this; \ - } \ - _THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v, const char *srcName) { \ - Operator::SetInput(#x, dstIndex, v, srcName); \ - return *this; \ - } \ - \ - private: \ - void FORCE_INLINE __dy_input_##x() { \ - Operator::DynamicInputRegister(#x, 0, #t, true); \ - (void) OpReg() - -#define DYNAMIC_OUTPUT(x, t) \ - N(); \ - __dy_output_##x(); \ - } \ - \ - public: \ - _THIS_TYPE &create_dynamic_output_##x(uint32_t num, bool isPushBack = true) { \ - Operator::DynamicOutputRegister(#x, num, #t, isPushBack); \ - return *this; \ - } \ - TensorDesc get_dynamic_output_desc_##x(uint32_t index) const { \ - return Operator::GetDynamicOutputDesc(#x, index); \ - } \ - graphStatus update_dynamic_output_desc_##x(uint32_t index, const TensorDesc &tensorDesc) { \ - return Operator::UpdateDynamicOutputDesc(#x, index, tensorDesc); \ - } \ - \ - private: \ - void FORCE_INLINE __dy_output_##x() { \ - Operator::DynamicOutputRegister(#x, 0, #t, true); \ - (void) OpReg() - -#define GRAPH(x) \ - N(); \ - __graph_##x(); \ - } \ - \ - public: \ - ATTRIBUTED_DEPRECATED(static const void name_graph_##x(AscendString &)) \ - static const std::string name_graph_##x() { \ - return #x; \ - } \ - static void name_graph_##x(AscendString &name) { \ - name = AscendString(#x); \ - } \ - SubgraphBuilder get_subgraph_builder_##x() const { \ - return Operator::GetSubgraphBuilder(#x); \ - } \ - _THIS_TYPE &set_subgraph_builder_##x(const SubgraphBuilder &v) { \ - Operator::SetSubgraphBuilder(#x, 0, v); \ - return *this; \ - } \ - Graph get_subgraph_##x() const { \ - return Operator::GetSubgraph(#x); \ - } \ - \ - private: \ - void FORCE_INLINE __graph_##x() { \ - Operator::SubgraphRegister(#x, false); \ - Operator::SubgraphCountRegister(#x, 1); \ - (void) OpReg() - -#define DYNAMIC_GRAPH(x) \ - N(); \ - __graph_##x(); \ - } \ - \ - public: \ - ATTRIBUTED_DEPRECATED(static const void name_graph_##x(AscendString &)) \ - static const std::string name_graph_##x() { \ - return #x; \ - } \ - static void name_graph_##x(AscendString &name) { \ - name = AscendString(#x); \ - } \ - _THIS_TYPE &create_dynamic_subgraph_##x(uint32_t num) { \ - Operator::SubgraphCountRegister(#x, num); \ - return *this; \ - } \ - SubgraphBuilder get_dynamic_subgraph_builder_##x(uint32_t index) const { \ - return Operator::GetDynamicSubgraphBuilder(#x, index); \ - } \ - Graph get_dynamic_subgraph_##x(uint32_t index) const { \ - return Operator::GetDynamicSubgraph(#x, index); \ - } \ - _THIS_TYPE &set_dynamic_subgraph_builder_##x(uint32_t index, const SubgraphBuilder &v) { \ - Operator::SetSubgraphBuilder(#x, index, v); \ - return *this; \ - } \ - \ - private: \ - void FORCE_INLINE __graph_##x() { \ - Operator::SubgraphRegister(#x, true); \ - (void) OpReg() - -#define PASTE(g_register, y) g_register##y - -#define __OP_END_IMPL_WITHOUT_REGISTER__(x) \ - N(); \ - } \ - static_assert( \ - std::is_same::value, \ - "The class name entered into the OP_END_FACTORY_REG needs to be the same as the operator name you define."); \ - } \ - ; \ - } - -#ifdef DISABLE_COMPILE_V1 -#define __OP_END_IMPL__(x, y) \ - __OP_END_IMPL_WITHOUT_REGISTER__(x) -#else -#define __OP_END_IMPL__(x, y) \ - N(); \ - } \ - static_assert( \ - std::is_same::value, \ - "The class name entered into the OP_END_FACTORY_REG needs to be the same as the operator name you define."); \ - } \ - ; \ - static const OperatorCreatorRegister PASTE(g_register, y)(#x, [](const AscendString &name) { return x(name); }); \ - } -#endif -#define OP_END_FACTORY_REG(x) __OP_END_IMPL__(x, __COUNTER__) - -// Specialized shape inferencer macro - -#define IMPLEMT_INFERFUNC(op_name, func_name) \ - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name &op) - -#define IMPLEMT_COMMON_INFERFUNC(func_name) \ - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(Operator &op) - -#define IMPLEMT_INFERFORMAT_FUNC(op_name, func_name) \ - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name &op) - -// Specialized verifier macro - -#define IMPLEMT_VERIFIER(op_name, func_name) \ - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name op) - -#define INFER_VERIFY_FUNC(op_name, x) [](Operator &v) { return x((op::op_name &) v); } - -#define COMMON_INFER_VERIFY_FUNC(x) [](Operator &v) { return x(v); } - -#define INFER_FORMAT_FUNC(op_name, x) [](Operator &v) { return x((op::op_name &) v); } - -#define __INFER_FUNC_REG_IMPL__(op_name, x, n) static const InferShapeFuncRegister PASTE(if_register, n)(#op_name, x) -#define __VERIFY_FUNC_REG_IMPL__(op_name, x, n) static const VerifyFuncRegister PASTE(vf_register, n)(#op_name, x) - -// Infer format func register -#define __INFER_FORMAT_FUNC_REG_IMPL__(op_name, x, n) \ - static const InferFormatFuncRegister PASTE(ff_register, n)(#op_name, x) - -// Shape inferencer & verifier register macro - -#define INFER_FUNC_REG(op_name, x) __INFER_FUNC_REG_IMPL__(op_name, INFER_VERIFY_FUNC(op_name, x), __COUNTER__) - -#define COMMON_INFER_FUNC_REG(op_name, x) __INFER_FUNC_REG_IMPL__(op_name, COMMON_INFER_VERIFY_FUNC((x)), __COUNTER__) - -#define VERIFY_FUNC_REG(op_name, x) __VERIFY_FUNC_REG_IMPL__(op_name, INFER_VERIFY_FUNC(op_name, x), __COUNTER__) - -// Value Range Infer -#define INFER_VALUE_RANGE_FUNC(op_name, x) [](Operator &v) { return x((op::op_name &) v); } - -#define INFER_VALUE_RANGE_DEFAULT_REG(op_name) __INFER_VALUE_RANGE_DEFAULT_REG_IMPL__(op_name, __COUNTER__) -#define __INFER_VALUE_RANGE_DEFAULT_REG_IMPL__(op_name, n) \ - static const InferValueRangeFuncRegister PASTE(iv_reg_default, n)(#op_name) - -#define INFER_VALUE_RANGE_CUSTOM_FUNC_REG(op_name, when_call, x) \ - __INFER_VALUE_RANGE_CUSTOM_FUNC_REG_IMPL__(op_name, when_call, INFER_VALUE_RANGE_FUNC(op_name, x), __COUNTER__) -#define __INFER_VALUE_RANGE_CUSTOM_FUNC_REG_IMPL__(op_name, when_call, x, n) \ - static const InferValueRangeFuncRegister PASTE(iv_reg_custom, n)(#op_name, when_call, x) - -#define IMPL_INFER_VALUE_RANGE_FUNC(op_name, func_name) \ - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name &op) - -// Infer format func reg -#define INFER_FORMAT_FUNC_REG(op_name, x) \ - __INFER_FORMAT_FUNC_REG_IMPL__(op_name, INFER_FORMAT_FUNC(op_name, x), __COUNTER__) - -// Common shape inferencer - -#define ELMTWISE_INFER_SHAPEANDTYPE(in_name, out_name) \ - [](Operator op) -> graphStatus { \ - auto x_input_desc = op.GetInputDescByName(in_name); \ - auto x_shape = x_input_desc.GetShape().GetDims(); \ - auto x_type = x_input_desc.GetDataType(); \ - std::vector> x_shape_range; \ - (void) x_input_desc.GetShapeRange(x_shape_range); \ - TensorDesc op_output_desc = op.GetOutputDescByName(out_name); \ - op_output_desc.SetShape(ge::Shape(x_shape)); \ - op_output_desc.SetOriginShape(ge::Shape(x_shape)); \ - op_output_desc.SetDataType(x_type); \ - if (!x_shape_range.empty()) { \ - op_output_desc.SetShapeRange(x_shape_range); \ - } \ - return op.UpdateOutputDesc(out_name, op_output_desc); \ - } - -graphStatus BroadCastInfer(const function()> &get_in1_shape, - const function()> &get_in2_shape, - const function &y_shape)> &set_out_shape); - -#define BROADCAST_INFER(in1_name, in2_name, out_name) \ - [](Operator op) -> graphStatus { \ - return BroadCastInfer([&]() { return op.GetInputDescByName(in1_name).GetShape().GetDims(); }, \ - [&]() { return op.GetInputDescByName(in2_name).GetShape().GetDims(); }, \ - [&](const std::vector &y_shape) { \ - TensorDesc op_output_desc = op.GetOutputDescByName(out_name); \ - op_output_desc.SetShape(ge::Shape(y_shape)); \ - (void) op.UpdateOutputDesc(out_name, op_output_desc); \ - }); \ - } -} // namespace ge -#endif // INC_EXTERNAL_GRAPH_OPERATOR_REG_H_ diff --git a/inc/external/graph/utils/args_format_desc_utils.h b/inc/external/graph/utils/args_format_desc_utils.h deleted file mode 100644 index a4ca245fd3add888199bb53476a8435050c7e94a..0000000000000000000000000000000000000000 --- a/inc/external/graph/utils/args_format_desc_utils.h +++ /dev/null @@ -1,90 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ - -#ifndef METADEF_CXX_ARGS_FORMAT_DESC_UTILS_H -#define METADEF_CXX_ARGS_FORMAT_DESC_UTILS_H - -#include -#include - -#include "graph/ge_error_codes.h" -#include "graph/op_desc.h" -#include "register/hidden_inputs_func_registry.h" - -namespace ge { -enum class AddrType { - INPUT = 0, - OUTPUT, - WORKSPACE, - TILING, - INPUT_DESC, - OUTPUT_DESC, - FFTS_ADDR, - OVERFLOW_ADDR, - TILING_FFTS, - HIDDEN_INPUT, - TILING_CONTEXT, - OP_TYPE, - PLACEHOLDER, - CUSTOM_VALUE, - INPUT_INSTANCE, - OUTPUT_INSTANCE, - SUPER_KERNEL_SUB_NODE, - EVENT_ADDR, - MAX // the end, add new value before MAX. -}; - -enum class TilingContextSubType { - TILING_CONTEXT = -1, - TILING_DATA, - TILING_KEY, - BLOCK_DIM, - MAX // the end, add new value before MAX. -}; - -// i* -> ir_idx = -1,folded=false -// 对于输入输出,idx表示ir定义的idx,-1表示所有输入、所有输出,此时非动态输入、输出默认展开,动态输出要i1*这样才表示展开 -// 对于workspace -1表示个数未知,folded暂时无意义 -// 对ffts尾块非尾块地址,idx=0表示非尾块,idx=1表示尾块 -// 对于hidden input,支持多个,idx表示索引,从0开始,reserved字段表示类型(uint32) -// 对于custom value,reserved字段表示需要透传的值(uint64),其他字段无意义 -// 对于其他类型, idx和fold暂时没有意义 -struct ArgDesc { - AddrType addr_type; - int32_t ir_idx; - bool folded; - uint8_t reserved[8]; -}; -static_assert(std::is_standard_layout::value, "The class ArgDesc must be a POD"); - -class ArgsFormatDescUtils { - public: - static void Append(std::vector &arg_descs, AddrType type, int32_t ir_idx = -1, bool folded = false); - - static void AppendTilingContext(std::vector &arg_descs, - TilingContextSubType sub_type = TilingContextSubType::TILING_CONTEXT); - - // insert_pos为插入位置,-1表示添加到最后,0表示添加到最前面,以此类推,注意不能超过arg_descs的个数,否则会报错 - // input_cnt为插入hidden input的个数 - static graphStatus InsertHiddenInputs(std::vector &arg_descs, int32_t insert_pos, - HiddenInputsType hidden_type, size_t input_cnt = 1U); - - // insert_pos为插入位置,-1表示添加到最后,0表示添加到最前面,以此类推,注意不能超过arg_descs的个数,否则会报错 - static graphStatus InsertCustomValue(std::vector &arg_descs, int32_t insert_pos, uint64_t custom_value); - - static std::string ToString(const std::vector &arg_descs); - - // 字符串用i*这样的通配符时,返回的argDesc不会按照实际个数展开 - static graphStatus Parse(const std::string &str, std::vector &arg_descs); - - static std::string Serialize(const std::vector &arg_descs); -}; -} - -#endif // METADEF_CXX_ARGS_FORMAT_DESC_UTILS_H diff --git a/inc/external/hcom/hcom_topo_info.h b/inc/external/hcom/hcom_topo_info.h deleted file mode 100644 index 1e1c44a6d58d6fbff9e856fdfe8cfb931ca311a5..0000000000000000000000000000000000000000 --- a/inc/external/hcom/hcom_topo_info.h +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_INC_EXTERNAL_HCOM_HCOM_TOPO_INFO_H_ -#define METADEF_CXX_INC_EXTERNAL_HCOM_HCOM_TOPO_INFO_H_ - -#include -#include -#include "ge_common/ge_api_types.h" - -namespace ge { -static constexpr uint32_t COMM_MESH = 0b1U; -static constexpr uint32_t COMM_SWITCH = (COMM_MESH << 1U); -static constexpr uint32_t COMM_RING = (COMM_MESH << 2U); -static constexpr uint32_t COMM_PAIRWISE = (COMM_MESH << 3U); -class HcomTopoInfo { - public: - enum class TopoLevel { - L0 = 0, - L1, - MAX, - }; - struct TopoLevelDesc { - uint32_t comm_sets; - uint32_t rank_size; - }; - using TopoDescs = TopoLevelDesc[static_cast(TopoLevel::MAX)]; - struct TopoInfo { - int64_t rank_size; - void *notify_handle; - TopoDescs topo_level_descs; - }; - static HcomTopoInfo &Instance(); - bool TopoInfoHasBeenSet(const char_t *group); - bool TryGetGroupTopoInfo(const char_t *group, TopoInfo &info); - Status SetGroupTopoInfo(const char_t *group, const TopoInfo &info); - Status GetGroupRankSize(const char_t *group, int64_t &rank_size); - TopoDescs *GetGroupTopoDesc(const char_t *group); - Status GetGroupNotifyHandle(const char_t *group, void *¬ify_handle); - void UnsetGroupTopoInfo(const char_t *group) { - const std::lock_guard lock(mutex_); - (void) rank_info_.erase(group); - } - - Status SetGroupOrderedStream(const int32_t device_id, const char_t *group, void *stream); - Status GetGroupOrderedStream(const int32_t device_id, const char_t *group, void *&stream); - void UnsetGroupOrderedStream(const int32_t device_id, const char_t *group); - private: - HcomTopoInfo() = default; - ~HcomTopoInfo() = default; - std::unordered_map rank_info_; - std::mutex mutex_; - std::unordered_map> device_id_to_group_to_ordered_stream_; // 通信域保序流 -}; -} - -#endif // METADEF_CXX_INC_EXTERNAL_HCOM_HCOM_TOPO_INFO_H_ diff --git a/inc/external/op_common/common_infershape_fns.h b/inc/external/op_common/common_infershape_fns.h deleted file mode 100644 index 35790d7921e6adb90c475ab621f2bc6c1a95fc28..0000000000000000000000000000000000000000 --- a/inc/external/op_common/common_infershape_fns.h +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -/*! - * \file common_infershape_fns.h - * \brief - */ - -#ifndef EXTERNAL_OP_COMMON_INFERSHAPE_FNS_H_ -#define EXTERNAL_OP_COMMON_INFERSHAPE_FNS_H_ - -#include "external/exe_graph/runtime/shape.h" -#include "external/exe_graph/runtime/infer_shape_context.h" - -namespace opcommon { -ge::graphStatus InferShape4BroadcastOp(gert::InferShapeContext* context); -ge::graphStatus InferShape4ReduceOp(gert::InferShapeContext* context); -ge::graphStatus InferShape4ElewiseOp(gert::InferShapeContext* context); -} // namespace opcommon - -#endif // EXTERNAL_OP_COMMON_INFERSHAPE_FNS_H_ diff --git a/inc/external/op_common/data_type_utils.h b/inc/external/op_common/data_type_utils.h deleted file mode 100644 index 9551e561a34ca5effe88e30dea80a012903726b6..0000000000000000000000000000000000000000 --- a/inc/external/op_common/data_type_utils.h +++ /dev/null @@ -1,130 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -/*! - * \file data_type_utils.h - * \brief - */ - -#ifndef EXTERNAL_OP_COMMON_DATA_TYPE_UTILS_H_ -#define EXTERNAL_OP_COMMON_DATA_TYPE_UTILS_H_ - -#include "graph/types.h" - -namespace opcommon { -inline bool IsComplexType(const ge::DataType type) { - return (type == ge::DataType::DT_COMPLEX32 || type == ge::DataType::DT_COMPLEX64 || - type == ge::DataType::DT_COMPLEX128); -} - -inline bool IsFloatingType(const ge::DataType type) -{ - return (type == ge::DataType::DT_DOUBLE || type == ge::DataType::DT_FLOAT || type == ge::DataType::DT_BF16 - || type == ge::DataType::DT_FLOAT16); -} - -inline bool IsIntegralType(const ge::DataType type) -{ - return (type == ge::DataType::DT_INT8 || type == ge::DataType::DT_INT16 || type == ge::DataType::DT_INT32 - || type == ge::DataType::DT_INT64 || type == ge::DataType::DT_UINT8 || type == ge::DataType::DT_UINT16 - || type == ge::DataType::DT_UINT32 || type == ge::DataType::DT_UINT64); -} - -inline bool IsIntegralType(const ge::DataType type, const bool include_bool) -{ - bool is_integral = IsIntegralType(type); - return include_bool ? (is_integral || (type == ge::DataType::DT_BOOL)) : is_integral; -} - -inline bool CanCast(const ge::DataType from, const ge::DataType to) -{ - if (IsComplexType(from) && !IsComplexType(to)) { - return false; - } - - if (IsFloatingType(from) && IsIntegralType(to, false)) { - return false; - } - - if (from != ge::DataType::DT_BOOL && to == ge::DataType::DT_BOOL) { - return false; - } - - return true; -} - -inline ge::DataType PromoteType(ge::DataType type_a, ge::DataType type_b) -{ - if (type_a < 0 || type_b < 0 || type_a >= ge::DataType::DT_MAX || type_b >= ge::DataType::DT_MAX) { - return ge::DataType::DT_UNDEFINED; - } - - if (type_a == type_b) { - return type_a; - } - - constexpr auto u1 = ge::DataType::DT_UINT8; - constexpr auto i1 = ge::DataType::DT_INT8; - constexpr auto i2 = ge::DataType::DT_INT16; - constexpr auto i4 = ge::DataType::DT_INT32; - constexpr auto i8 = ge::DataType::DT_INT64; - constexpr auto f2 = ge::DataType::DT_FLOAT16; - constexpr auto f4 = ge::DataType::DT_FLOAT; - constexpr auto f8 = ge::DataType::DT_DOUBLE; - constexpr auto c2 = ge::DataType::DT_COMPLEX32; - constexpr auto c4 = ge::DataType::DT_COMPLEX64; - constexpr auto c8 = ge::DataType::DT_COMPLEX128; - constexpr auto b1 = ge::DataType::DT_BOOL; - constexpr auto bf = ge::DataType::DT_BF16; - constexpr auto ud = ge::DataType::DT_UNDEFINED; - // @formatter:off - static constexpr ge::DataType kPromoteTypesLookup[static_cast( - ge::DataType::DT_MAX)][static_cast(ge::DataType::DT_MAX)] = { - /* f4 f2 i1 i4 u1 xx i2 u2 u4 i8 u8 f8 b1 sv d1 D1 c4 c8 q1 q2 q4 Q1 Q2 rs sr du va bf, ud t4 T1 t2 T2 c2*/ - /* f4 0 */ {f4, f4, f4, f4, f4, ud, f4, ud, ud, f4, ud, f8, f4, ud, ud, ud, c4, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, f4, ud, ud, ud, ud, ud, c4}, - /* f2 1 */ {f4, f2, f2, f2, f2, ud, f2, ud, ud, f2, ud, f8, f2, ud, ud, ud, c4, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, f4, ud, ud, ud, ud, ud, c2}, - /* i1 2 */ {f4, f2, i1, i4, i2, ud, i2, ud, ud, i8, ud, f8, i1, ud, ud, ud, c4, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, bf, ud, ud, ud, ud, ud, c2}, - /* i4 3 */ {f4, f2, i4, i4, i4, ud, i4, ud, ud, i8, ud, f8, i4, ud, ud, ud, c4, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, bf, ud, ud, ud, ud, ud, c2}, - /* u1 4 */ {f4, f2, i2, i4, u1, ud, i2, ud, ud, i8, ud, f8, u1, ud, ud, ud, c4, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, bf, ud, ud, ud, ud, ud, c2}, - /* xx 5 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, - /* i2 6 */ {f4, f2, i2, i4, i2, ud, i2, ud, ud, i8, ud, f8, i2, ud, ud, ud, c4, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, bf, ud, ud, ud, ud, ud, c2}, - /* u2 7 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, c4, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, - /* u4 8 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, c4, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, - /* i8 9 */ {f4, f2, i8, i8, i8, ud, i8, ud, ud, i8, ud, f8, i8, ud, ud, ud, c4, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, bf, ud, ud, ud, ud, ud, c2}, - /* u8 10*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, c8, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, - /* f8 11*/ {f8, f8, f8, f8, f8, ud, f8, ud, ud, f8, ud, f8, f8, ud, ud, ud, c8, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, f8, ud, ud, ud, ud, ud, c8}, - /* b1 12*/ {f4, f2, i1, i4, u1, ud, i2, ud, ud, i8, ud, f8, b1, ud, ud, ud, c4, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, bf, ud, ud, ud, ud, ud, c2}, - /* sv 13*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, - /* d1 14*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, - /* D1 15*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, - /* c4 16*/ {c4, c4, c4, c4, c4, ud, c4, ud, ud, c4, ud, c8, c4, ud, ud, ud, c4, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, c4, ud, ud, ud, ud, ud, c4}, - /* c8 17*/ {c8, c8, c8, c8, c8, ud, c8, ud, ud, c8, ud, c8, c8, ud, ud, ud, c8, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, c8, ud, ud, ud, ud, ud, c8}, - /* q1 18*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, - /* q2 19*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, - /* q4 20*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, - /* Q1 21*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, - /* Q2 22*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, - /* rs 23*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, - /* sr 24*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, - /* du 25*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, - /* va 26*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, - /* bf 27*/ {f4, f4, bf, bf, bf, ud, bf, ud, ud, bf, ud, f8, bf, ud, ud, ud, c4, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, bf, ud, ud, ud, ud, ud, c4}, - /* ud 28*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, - /* t4 29*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, - /* T1 30*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, - /* t2 31*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, - /* T2 32*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, - /* c2 33*/ {c4, c2, c2, c2, c2, ud, c2, ud, ud, c2, ud, c8, c2, ud, ud, ud, c4, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, c4, ud, ud, ud, ud, ud, c2}, - }; - // @formatter:on - return kPromoteTypesLookup[static_cast(type_a)][static_cast(type_b)]; -} -} // namespace opcommon - -#endif // EXTERNAL_OP_COMMON_DATA_TYPE_UTILS_H_ diff --git a/inc/external/op_common/op_dev.h b/inc/external/op_common/op_dev.h deleted file mode 100644 index 780bd8e2325a9e735445b21357588b2027be77a2..0000000000000000000000000000000000000000 --- a/inc/external/op_common/op_dev.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -/*! - * \file op_dev.h - * \brief - */ - -#ifndef EXTERNAL_OP_COMMON_OP_DEV_H_ -#define EXTERNAL_OP_COMMON_OP_DEV_H_ - -#include "common_infershape_fns.h" -#include "op_error_code.h" - -#endif // EXTERNAL_OP_COMMON_OP_DEV_H_ diff --git a/inc/external/op_common/op_error_code.h b/inc/external/op_common/op_error_code.h deleted file mode 100644 index b329a1ee2fca451f76bf0b3feeb58e264dd7a2ec..0000000000000000000000000000000000000000 --- a/inc/external/op_common/op_error_code.h +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -/*! - * \file error_code.h - * \brief - */ -#ifndef EXTERNAL_OP_COMMON_ERROR_CODE_H_ -#define EXTERNAL_OP_COMMON_ERROR_CODE_H_ - -namespace opcommon { -enum ViewErrorCode { - VECTOR_INNER_ERROR = 89999 -}; -} // namespace opcommon - -#endif // EXTERNAL_OP_COMMON_ERROR_CODE_H_ diff --git a/inc/external/op_common/tiling_aicpu_task.h b/inc/external/op_common/tiling_aicpu_task.h deleted file mode 100644 index 6d91d1f017f341f50ea17d89d80ad07adbaa1e49..0000000000000000000000000000000000000000 --- a/inc/external/op_common/tiling_aicpu_task.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef EXTERNAL_OP_COMMON_TILING_AICPU_TASK_H_ -#define EXTERNAL_OP_COMMON_TILING_AICPU_TASK_H_ -#include "exe_graph/runtime/tiling_context.h" - -namespace optiling { -struct TilingAicpuTask { - gert::TilingContext *tilingContext; - const char *opType; - char reserve[64]; -}; -} // namespace optiling - -#endif \ No newline at end of file diff --git a/inc/external/register/device_op_impl_registry.h b/inc/external/register/device_op_impl_registry.h deleted file mode 100644 index c13c2a234b52e13768304f390aca01f9a6b4c5b0..0000000000000000000000000000000000000000 --- a/inc/external/register/device_op_impl_registry.h +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -/*! - * \file device_op_impl_registry.h - * \brief - */ - -#ifndef REGISTER_DEVICE_OP_IMPL_REGISTRY_H -#define REGISTER_DEVICE_OP_IMPL_REGISTRY_H -#include -#include "graph/compiler_def.h" -#include "exe_graph/runtime/tiling_context.h" - -namespace optiling { -using SinkTilingFunc = std::function; - -class DeviceOpImplRegisterImpl; -class DeviceOpImplRegister { -public: - DeviceOpImplRegister(const char *opType); - ~DeviceOpImplRegister(); - DeviceOpImplRegister(DeviceOpImplRegister &&other) noexcept; - DeviceOpImplRegister(const DeviceOpImplRegister &other); - DeviceOpImplRegister &operator=(const DeviceOpImplRegister &) = delete; - DeviceOpImplRegister &operator=(DeviceOpImplRegister &&) = delete; - DeviceOpImplRegister &Tiling(SinkTilingFunc func); - -private: - std::unique_ptr impl_; -}; -} // namespace optiling - -#define DEVICE_IMPL_OP_OPTILING(optype) \ - static optiling::DeviceOpImplRegister VAR_UNUSED g_deviceOpImplRegister##optype = \ - optiling::DeviceOpImplRegister(#optype) -#endif \ No newline at end of file diff --git a/inc/external/register/ffts_node_calculater_registry.h b/inc/external/register/ffts_node_calculater_registry.h deleted file mode 100644 index 4e8b2ec277ad297037e7ba72ac8ce352a53e00dd..0000000000000000000000000000000000000000 --- a/inc/external/register/ffts_node_calculater_registry.h +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef AIR_CXX_RUNTIME_V2_LOWERING_FFTS_NODE_CALCULATER_REGISTRY_H_ -#define AIR_CXX_RUNTIME_V2_LOWERING_FFTS_NODE_CALCULATER_REGISTRY_H_ -#include -#include -#include -#include "graph/node.h" -#include "exe_graph/lowering/value_holder.h" -#include "exe_graph/lowering/lowering_global_data.h" - -namespace gert { -struct NodeMemPara { - size_t size {0}; - void *dev_addr {nullptr}; - void *host_addr {nullptr}; -}; - -class FFTSNodeCalculaterRegistry { - public: - /* - * Output param: - * 1.size_t total_size -- node need memory size - * 2.size_t pre_data_size -- node pre proc data size - * 3.std::unique_ptr pre_data_ptr -- node pre proc data memory with size(pre_data_size), framework will - * copy pre_data to new alloc memory - */ - using NodeCalculater = ge::graphStatus (*)(const ge::NodePtr &node, const LoweringGlobalData *global_data, - size_t &total_size, size_t &pre_data_size, std::unique_ptr &pre_data_ptr); - static FFTSNodeCalculaterRegistry &GetInstance(); - NodeCalculater FindNodeCalculater(const std::string &func_name); - void Register(const std::string &func_name, const NodeCalculater func); - - private: - std::unordered_map names_to_calculater_; -}; - -class FFTSNodeCalculaterRegister { - public: - FFTSNodeCalculaterRegister(const string &func_name, FFTSNodeCalculaterRegistry::NodeCalculater func) noexcept; -}; -} // namespace gert - -#ifdef __GNUC__ -#define ATTRIBUTE_USED __attribute__((used)) -#else -#define ATTRIBUTE_USED -#endif - -#define GERT_REGISTER_FFTS_NODE_CALCULATER_COUNTER2(type, func, counter) \ - static const gert::FFTSNodeCalculaterRegister g_register_node_calculater_##counter ATTRIBUTE_USED = \ - gert::FFTSNodeCalculaterRegister(type, func) -#define GERT_REGISTER_FFTS_NODE_CALCULATER_COUNTER(type, func, counter) \ - GERT_REGISTER_FFTS_NODE_CALCULATER_COUNTER2(type, func, counter) -#define FFTS_REGISTER_NODE_CALCULATER(type, func) \ - GERT_REGISTER_FFTS_NODE_CALCULATER_COUNTER(type, func, __COUNTER__) - -#endif // AIR_CXX_RUNTIME_V2_LOWERING_FFTS_NODE_CALCULATER_REGISTRY_H_ diff --git a/inc/external/register/hidden_input_func_registry.h b/inc/external/register/hidden_input_func_registry.h deleted file mode 100644 index 5452f9b0ec92be56c002e683b26d2ad53666978d..0000000000000000000000000000000000000000 --- a/inc/external/register/hidden_input_func_registry.h +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_EXTERNAL_REGISTER_HIDDEN_INPUT_FUNC_REGISTRY_H_ -#define INC_EXTERNAL_REGISTER_HIDDEN_INPUT_FUNC_REGISTRY_H_ - -#include -#include -#include "graph/op_desc.h" -namespace ge { -enum class HiddenInputType : uint32_t { HCOM }; - -using GetHiddenAddr = ge::graphStatus (*)(const ge::OpDescPtr &op_desc, void *&addr); -class HiddenInputFuncRegistry { - public: - static HiddenInputFuncRegistry &GetInstance(); - GetHiddenAddr FindHiddenInputFunc(const HiddenInputType input_type); - void Register(const HiddenInputType input_type, const GetHiddenAddr func); - - private: - std::map type_to_funcs_; -}; - -class HiddenInputFuncRegister { - public: - HiddenInputFuncRegister(const HiddenInputType input_type, const GetHiddenAddr func); -}; -} // namespace ge - -#ifdef __GNUC__ -#define ATTRIBUTE_USED __attribute__((used)) -#else -#define ATTRIBUTE_USED -#endif -#define REG_HIDDEN_INPUT_FUNC(type, func) REG_HIDDEN_INPUT_FUNC_UNIQ_HELPER(type, func, __COUNTER__) -#define REG_HIDDEN_INPUT_FUNC_UNIQ_HELPER(type, func, counter) REG_HIDDEN_INPUT_FUNC_UNIQ(type, func, counter) -#define REG_HIDDEN_INPUT_FUNC_UNIQ(type, func, counter) \ - static ::ge::HiddenInputFuncRegister register_hidden_func_##counter ATTRIBUTE_USED = \ - ge::HiddenInputFuncRegister(type, func) - -#endif // INC_EXTERNAL_REGISTER_HIDDEN_INPUT_FUNC_REGISTRY_H_ diff --git a/inc/external/register/hidden_inputs_func_registry.h b/inc/external/register/hidden_inputs_func_registry.h deleted file mode 100644 index 9f9ba0e79eb13ca842d7e963ce1b2bb444a5d43f..0000000000000000000000000000000000000000 --- a/inc/external/register/hidden_inputs_func_registry.h +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_EXTERNAL_REGISTER_HIDDEN_INPUTS_FUNC_REGISTRY_H_ -#define INC_EXTERNAL_REGISTER_HIDDEN_INPUTS_FUNC_REGISTRY_H_ - -#include -#include -#include "graph/op_desc.h" -namespace ge { -// 待废弃枚举,1230废弃,不推荐使用,推荐使用HiddenInputSubType (arg_desc_info.h) -enum class HiddenInputsType : uint32_t { HCOM, TILEFWK, HCCLSUPERKERNEL, MAX }; - -using GetHiddenAddrs = ge::graphStatus (*)(const ge::OpDescPtr &op_desc, std::vector &addr); -class HiddenInputsFuncRegistry { - public: - static HiddenInputsFuncRegistry &GetInstance(); - GetHiddenAddrs FindHiddenInputsFunc(const HiddenInputsType input_type); - void Register(const HiddenInputsType input_type, const GetHiddenAddrs func); - - private: - std::map type_to_funcs_; -}; - -class HiddenInputsFuncRegister { - public: - HiddenInputsFuncRegister(const HiddenInputsType input_type, const GetHiddenAddrs func); -}; -} // namespace ge - -#ifdef __GNUC__ -#define ATTRIBUTE_USED __attribute__((used)) -#else -#define ATTRIBUTE_USED -#endif -#define REG_HIDDEN_INPUTS_FUNC(type, func) REG_HIDDEN_INPUTS_FUNC_UNIQ_HELPER(type, func, __COUNTER__) -#define REG_HIDDEN_INPUTS_FUNC_UNIQ_HELPER(type, func, counter) REG_HIDDEN_INPUTS_FUNC_UNIQ(type, func, counter) -#define REG_HIDDEN_INPUTS_FUNC_UNIQ(type, func, counter) \ - static ::ge::HiddenInputsFuncRegister register_hidden_func_##counter ATTRIBUTE_USED = \ - ge::HiddenInputsFuncRegister(type, func) - -#endif // INC_EXTERNAL_REGISTER_HIDDEN_INPUTS_FUNC_REGISTRY_H_ diff --git a/inc/external/register/op_bin_info.h b/inc/external/register/op_bin_info.h deleted file mode 100644 index acec7520e8c0f2a7cf0eb864c7cb2ce8f3c485d5..0000000000000000000000000000000000000000 --- a/inc/external/register/op_bin_info.h +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef REGISTER_OP_BIN_INFO_H -#define REGISTER_OP_BIN_INFO_H - -#include -#include -#include -#include "graph/ascend_string.h" - -namespace ops { -using OpInfo = std::vector>; - -class OpBinInfo { -public: - OpBinInfo(const std::string& opType, const OpInfo& opInfo); - ~OpBinInfo(); - uint32_t Generate(ge::AscendString* opLibPath, const std::string& targetPath); - static bool Check(const std::string& path); - -private: - std::string opType_; - std::string basePath_; - const OpInfo& opInfo_; -}; - -} -#endif \ No newline at end of file diff --git a/inc/external/register/op_binary_resource_manager.h b/inc/external/register/op_binary_resource_manager.h deleted file mode 100644 index d01ed9eb2ff0e091856fd57a12b43a010a516d01..0000000000000000000000000000000000000000 --- a/inc/external/register/op_binary_resource_manager.h +++ /dev/null @@ -1,89 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include -#include "nlohmann/json.hpp" -#include "graph/ascend_string.h" -#include "graph/ge_error_codes.h" - -#ifndef INC_EXTERNAL_REGISTER_OP_BINARY_RESOURCE_MANAGER_H_ -#define INC_EXTERNAL_REGISTER_OP_BINARY_RESOURCE_MANAGER_H_ - -namespace nnopbase { -// 二进制内容,用于内存中的二进制.o文件描述 -struct Binary { - const uint8_t *content; // 二进制内容的起始指针 - uint32_t len; // 二进制的长度 -}; - -class OpBinaryResourceManager { -public: - static OpBinaryResourceManager &GetInstance() - { - static OpBinaryResourceManager manager; - return manager; - } - - ~OpBinaryResourceManager() = default; - - void AddOpFuncHandle(const ge::AscendString &opType, const std::vector &opResourceHandle); - ge::graphStatus AddBinary(const ge::AscendString &opType, - const std::vector> &opBinary); - ge::graphStatus AddRuntimeKB(const ge::AscendString &opType, - const std::vector> &opRuntimeKb); - -// 获取资源 -public: - // 获取所有算子的描述信息 - const std::map &GetAllOpBinaryDesc() const; - - // 获取某个算子的描述信息 - ge::graphStatus GetOpBinaryDesc(const ge::AscendString &opType, nlohmann::json &binDesc) const; - - // 根据json文件路径(算子json中存在)查找.json/.o的信息 - ge::graphStatus GetOpBinaryDescByPath(const ge::AscendString &jsonFilePath, - std::tuple &binInfo) const; - - // 根据simplifiedKey(算子json中存在)查找.json/.o的信息 - ge::graphStatus GetOpBinaryDescByKey(const ge::AscendString &simplifiedKey, - std::tuple &binInfo) const; - - // 二进制知识库 - ge::graphStatus GetOpRuntimeKB(const ge::AscendString &opType, std::vector &kbList) const; - -private: - OpBinaryResourceManager() = default; // 单例,禁止外部创建对象 - OpBinaryResourceManager &operator=(const OpBinaryResourceManager &) = delete; // 禁止拷贝 - OpBinaryResourceManager &operator=(OpBinaryResourceManager &&) = delete; - OpBinaryResourceManager(const OpBinaryResourceManager &) = delete; - OpBinaryResourceManager(OpBinaryResourceManager &&) = delete; - - mutable std::recursive_mutex mutex_; - - // 二进制描述信息 opType -> xxx.json - std::map opBinaryDesc_; - - // 二进制jsonPath simplifiedKey -> jsonPath, 可能存在多个simplifiedKey对应同一个jsonPath - std::map keyToPath_; - - // 二进制信息 jsonPath -> xxx.json, xxx.o - std::map> pathToBinary_; - - // 二进制知识库 opType -> xxx1.json, xxx2.json - std::map> runtimeKb_; - - // infershape/op tiling/runtime kb parser等全局变量指针,注册类资源,仅需持有 - std::map> resourceHandle_; -}; -} // nnopbase - -#endif // INC_EXTERNAL_REGISTER_OP_BINARY_RESOURCE_MANAGER_H_ diff --git a/inc/external/register/op_check.h b/inc/external/register/op_check.h deleted file mode 100644 index efc1d26f6d8af302419d03fcc76704bd6681792d..0000000000000000000000000000000000000000 --- a/inc/external/register/op_check.h +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_EXTERNAL_REGISTER_OP_CHECK_H_ -#define INC_EXTERNAL_REGISTER_OP_CHECK_H_ - -#include "op_check_register.h" -namespace optiling { - -#define REG_CHECK_SUPPORT(op_type, func) \ - static OpCheckFuncHelper op_check_registry_##op_type##_check_supported(FUNC_CHECK_SUPPORTED, #op_type, func) -#define REG_OP_SELECT_FORMAT(op_type, func) \ - static OpCheckFuncHelper op_check_registry_##op_type##_op_select_format(FUNC_OP_SELECT_FORMAT, #op_type, func) -#define REG_OP_SUPPORT_INFO(op_type, func) \ - static OpCheckFuncHelper op_check_registry_##op_type##_get_op_support_info(FUNC_GET_OP_SUPPORT_INFO, #op_type, func) -#define REG_OP_SPEC_INFO(op_type, func) \ - static OpCheckFuncHelper op_check_registry_##op_type##_get_specific_info(FUNC_GET_SPECIFIC_INFO, #op_type, func) - -#define REG_OP_PARAM_GENERALIZE(op_type, generalize_func) \ - static OpCheckFuncHelper op_check_generalize_registry_##op_type(#op_type, generalize_func) - -#define REG_REPLAY_FUNC(op_type, soc_version, func) \ - static ReplayFuncHelper op_replay_registry_##op_type_##soc_version(#op_type, #soc_version, func) -} // end of namespace optiling -#endif // INC_EXTERNAL_REGISTER_OP_CHECK_H_ \ No newline at end of file diff --git a/inc/external/register/op_compile_info_base.h b/inc/external/register/op_compile_info_base.h deleted file mode 100644 index 3becaeceacb55956563ad0b090ec28f083df129a..0000000000000000000000000000000000000000 --- a/inc/external/register/op_compile_info_base.h +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_EXTERNAL_REGISTER_OP_COMPILE_INFO_BASE_H_ -#define INC_EXTERNAL_REGISTER_OP_COMPILE_INFO_BASE_H_ - -#include - -namespace optiling { -class CompileInfoBase; -using CompileInfoPtr = std::shared_ptr; - -class CompileInfoBase { -public: - CompileInfoBase() {} - virtual ~CompileInfoBase() {} -}; -} // namespace optiling -#endif // INC_REGISTER_OP_TILING_REGISTRY_H_ diff --git a/inc/external/register/op_config_registry.h b/inc/external/register/op_config_registry.h deleted file mode 100644 index cb1119c9d926d17705106c9b065fc0a1a573da0d..0000000000000000000000000000000000000000 --- a/inc/external/register/op_config_registry.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_EXTERNAL_REGISTER_OP_CONFIG_REGISTRY_H_ -#define INC_EXTERNAL_REGISTER_OP_CONFIG_REGISTRY_H_ - -#include "register/op_def.h" - -namespace ops { -using OpAICoreConfigFunc = OpAICoreConfig (*)(); - -class OpConfigRegistry { -public: - OpConfigRegistry(); - void RegisterOpAICoreConfig(const char* name, const char* socVersion, OpAICoreConfigFunc func); -}; - -std::map GetOpAllAICoreConfig(const char* name); -} - -#define REGISTER_OP_AICORE_CONFIG(opType, socVersion, opFunc) REGISTER_OP_AICORE_CONFIG_UNIQ_HELPER(opType, socVersion, (opFunc), __COUNTER__) - -#define REGISTER_OP_AICORE_CONFIG_UNIQ_HELPER(opType, socVersion, opFunc, counter) REGISTER_OP_AICORE_CONFIG_UNIQ(opType, socVersion, (opFunc), counter) - -#define REGISTER_OP_AICORE_CONFIG_UNIQ(opType, socVersion, opFunc, counter) \ - static uint32_t g_##opType##Op##socVersion##ConfigRegistryInterfV1##counter = [](void) { \ - ops::OpConfigRegistry configRegistry; \ - configRegistry.RegisterOpAICoreConfig(#opType, #socVersion, opFunc); \ - return 0; \ - }() - -#endif // INC_EXTERNAL_REGISTER_OP_CONFIG_REGISTRY_H_ \ No newline at end of file diff --git a/inc/external/register/op_def.h b/inc/external/register/op_def.h deleted file mode 100644 index 1d6c812916aeb92cb18c9b4ba4ba3a0a48296b46..0000000000000000000000000000000000000000 --- a/inc/external/register/op_def.h +++ /dev/null @@ -1,508 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef OP_DEF_H -#define OP_DEF_H - -#include -#include -#include -#include "register/op_impl_registry.h" -#include "graph/operator_reg.h" - -namespace optiling { -#define FUNC_CHECK_SUPPORTED "check_supported" -#define FUNC_OP_SELECT_FORMAT "op_select_format" -#define FUNC_GET_OP_SUPPORT_INFO "get_op_support_info" -#define FUNC_GET_SPECIFIC_INFO "get_op_specific_info" - -using OP_CHECK_FUNC = ge::graphStatus (*)(const ge::Operator &op, ge::AscendString &result); - -using PARAM_GENERALIZE_FUNC = ge::graphStatus (*)(const ge::Operator &op, const ge::AscendString &generalize_config, - ge::AscendString &generalized_op_params); - -class OpCheckFuncHelper { -public: - OpCheckFuncHelper(const ge::AscendString &check_type, const ge::AscendString &op_type, OP_CHECK_FUNC func); - - OpCheckFuncHelper(const ge::AscendString &op_type, PARAM_GENERALIZE_FUNC func); -}; -} - -namespace ops { -class AclnnOpGenerator; -class Generator; -class OpProtoGenerator; -class GeneratorFactory; -class CfgGenerator; -class OpParamTrunk; - -enum Option { IGNORE = 0, OPTIONAL = 1, REQUIRED = 2, DYNAMIC = 3, VIRTUAL = 4 }; - -enum class FormatCheckOption : uint32_t { - DEFAULT = 0, - STRICT = 1, - MAX -}; - -enum class DependScope : uint32_t { - ALL = 0, - TILING = 1, - INVALID_SCOPE -}; - -enum class FollowType : uint32_t { - ALL = 0, - DTYPE = 1, - FORMAT = 2, - SHAPE = 3, - INVALID_TYPE -}; - -enum class AttrDataType { - ATTR_DT_BOOL = 0, - ATTR_DT_FLOAT = 1, - ATTR_DT_INT = 2, - ATTR_DT_STR = 3, - ATTR_DT_LIST_BOOL = 4, - ATTR_DT_LIST_FLOAT = 5, - ATTR_DT_LIST_INT = 6, - ATTR_DT_LIST_LIST_INT = 7, - ATTR_DT_MAX -}; - -enum class InitValueType : uint32_t { - INIT_VALUE_UINT64_T = 0, - INIT_VALUE_DEFAULT = static_cast(-1), -}; - -enum class CommentSection : uint32_t { - CATEGORY = 0, - BRIEF = 1, - CONSTRAINTS = 2, - RESTRICTIONS = 3, - SEE = 4, - THIRDPARTYFWKCOMPAT = 5, - SECTION_MAX -}; - -enum class ScalarType : uint32_t { - UINT64 = 0, - INT64 = 1, - UINT32 = 2, - INT32 = 3, - UINT16 = 4, - INT16 = 5, - UINT8 = 6, - INT8 = 7, - FLOAT32 = 8, - FLOAT16 = 9, - INVALID_DTYPE = static_cast(-1), -}; - -enum class HcclServerType : uint32_t { - AICPU = 0, - AICORE = 1, - MAX -}; - -union ScalarNum { - uint64_t value_u64; - int64_t value_i64; - float value_f32; - ScalarNum() : value_u64(0) {} - explicit ScalarNum(uint64_t value) : value_u64(value) {} - explicit ScalarNum(int64_t value) : value_i64(value) {} - explicit ScalarNum(float value) : value_f32(value) {} -}; - -using InitValueNum = ScalarNum; - -struct ScalarVar { - ScalarType scalar_type; - ScalarNum scalar_num; - ScalarVar() : scalar_type(ScalarType::INVALID_DTYPE) {} - ScalarVar(ScalarType type, uint64_t num) : scalar_type(type), scalar_num(num) { - if (type == ScalarType::FLOAT32 || type == ScalarType::FLOAT16) { - scalar_num = ScalarNum(static_cast(num)); - } - } - ScalarVar(ScalarType type, int64_t num) : scalar_type(type), scalar_num(num) { - if (type == ScalarType::FLOAT32 || type == ScalarType::FLOAT16) { - scalar_num = ScalarNum(static_cast(num)); - } - } - ScalarVar(ScalarType type, int num) : scalar_type(type), scalar_num(static_cast(num)) { - if (type == ScalarType::FLOAT32 || type == ScalarType::FLOAT16) { - scalar_num = ScalarNum(static_cast(num)); - } - } - ScalarVar(ScalarType type, unsigned int num) : scalar_type(type), scalar_num(static_cast(num)) { - if (type == ScalarType::FLOAT32 || type == ScalarType::FLOAT16) { - scalar_num = ScalarNum(static_cast(num)); - } - } - ScalarVar(ScalarType type, float num) : scalar_type(type), scalar_num(num) { - if (type != ScalarType::FLOAT32 && type != ScalarType::FLOAT16) { - if (type == ScalarType::UINT64) { - scalar_num = ScalarNum(static_cast(num)); - } - scalar_num = ScalarNum(static_cast(num)); - } - } - ScalarVar(ScalarType type, double num) : scalar_type(type), scalar_num(static_cast(num)) { - if (type != ScalarType::FLOAT32 && type != ScalarType::FLOAT16) { - if (type == ScalarType::UINT64) { - scalar_num = ScalarNum(static_cast(num)); - } - scalar_num = ScalarNum(static_cast(num)); - } - } - bool operator==(const ScalarVar& other) const { - if (scalar_type == other.scalar_type && scalar_num.value_u64 == other.scalar_num.value_u64) { - return true; - } - return false; - } -}; - -enum class ItemFindStatus { ITEM_FIND = 0, ITEM_NOEXIST = 1 }; - -class OpParamDefImpl; -class OpParamDef { -public: - explicit OpParamDef(const char *name); - OpParamDef(const OpParamDef &def); - ~OpParamDef(); - OpParamDef &operator=(const OpParamDef &def); - OpParamDef &ParamType(Option param_type); - OpParamDef &DataType(std::vector types); - OpParamDef &DataTypeList(std::vector types); - OpParamDef &Format(std::vector formats); - OpParamDef &FormatList(std::vector formats); - OpParamDef &DataTypeForBinQuery(std::vector types); - OpParamDef &FormatForBinQuery(std::vector formats); - OpParamDef &UnknownShapeFormat(std::vector formats); - OpParamDef &ValueDepend(Option value_depend); - OpParamDef &ValueDepend(Option value_depend, DependScope scope); - OpParamDef &IgnoreContiguous(void); - OpParamDef &AutoContiguous(); - OpParamDef &Scalar(); - OpParamDef &ScalarList(); - OpParamDef &To(const ge::DataType type); - OpParamDef &To(const char *name); - OpParamDef &Version(uint32_t version); - OpParamDef &InitValue(uint64_t value); - OpParamDef &InitValue(const ScalarVar &value); - OpParamDef &InitValue(const std::vector &value); - OpParamDef &OutputShapeDependOnCompute(); - OpParamDef &Follow(const char *paramName); - OpParamDef &Follow(const char *paramName, FollowType ftype); - OpParamDef &Comment(const char *comment); - -private: - friend class AclnnFallBackGenerator; - friend class AclnnOpGenerator; - friend class Generator; - friend class OpProtoGenerator; - friend class GeneratorFactory; - friend class CfgGenerator; - friend class OpParamTrunk; - friend class OpDef; - - bool operator==(const OpParamDef &def) const; - void MergeParam(const OpParamDef &def); - ge::AscendString &GetParamName(void) const; - Option GetParamType(void); - std::vector &GetDataTypes(void); - std::vector &GetOriginDataTypes(void); - std::vector &GetDataTypesList(void); - std::vector &GetDataTypesForBin(void) const; - std::vector &GetFormats(void); - std::vector &GetFormatsList(void); - std::vector &GetFormatsForBin(void) const; - std::vector &GetUnknownShapeFormats(void); - ge::AscendString &GetValueDepend(void) const; - DependScope &GetDependScope(void) const; - ge::AscendString &GetFollowName(void) const; - FollowType &GetFollowType(void) const; - ge::AscendString &GetComment(void) const; - bool GetIgnoreContiguous(void); - bool GetAutoContiguous(void); - bool IsScalar(void) const; - bool IsScalarList(void) const; - bool IsScalarOrScalarList(void) const; - bool IsScalarTypeSet(void) const; - bool IsScalarNameSet(void) const; - bool IsValueDepend(void) const; - bool IsDtype(void) const; - bool IsDtypeList(void) const; - bool IsFormat(void) const; - bool IsFormatList(void) const; - bool IsOutputShapeDependOnCompute(void) const; - bool IsSetDtypeForBin(void) const; - bool IsSetFormatForBin(void) const; - ge::AscendString &GetScalarName(void) const; - ge::DataType GetScalarType(void) const; - uint32_t GetVersion(void); - InitValueType &GetInitValueType(void); - InitValueNum &GetInitValue(void); - std::vector &GetInitValueList(void); - std::unique_ptr impl_; -}; - -class OpAttrDefImpl; -class OpAttrDef { -public: - explicit OpAttrDef(const char *name); - OpAttrDef(const OpAttrDef &attr_def); - ~OpAttrDef(); - OpAttrDef &operator=(const OpAttrDef &attr_def); - OpAttrDef &AttrType(Option attr_type); - OpAttrDef &Bool(void); - OpAttrDef &Bool(bool value); - OpAttrDef &Float(void); - OpAttrDef &Float(float value); - OpAttrDef &Int(void); - OpAttrDef &Int(int64_t value); - OpAttrDef &String(void); - OpAttrDef &String(const char *value); - OpAttrDef &ListBool(void); - OpAttrDef &ListBool(std::vector value); - OpAttrDef &ListFloat(void); - OpAttrDef &ListFloat(std::vector value); - OpAttrDef &ListInt(void); - OpAttrDef &ListInt(std::vector value); - OpAttrDef &ListListInt(void); - OpAttrDef &ListListInt(std::vector> value); - OpAttrDef &Version(uint32_t version); - OpAttrDef &Comment(const char *comment); - ge::AscendString &GetName(void) const; - bool IsRequired(void); - -private: - friend class AclnnFallBackGenerator; - friend class AclnnOpGenerator; - friend class Generator; - friend class OpProtoGenerator; - friend class GeneratorFactory; - friend class CfgGenerator; - friend class OpParamTrunk; - friend class OpDef; - - bool operator==(const OpAttrDef &attr_def) const; - ge::AscendString &GetCfgDataType(void) const; - ge::AscendString &GetProtoDataType(void) const; - ge::AscendString &GetAttrDefaultVal(const char *brac); - uint32_t GetVersion(void); - ge::AscendString &GetComment(void) const; - - std::unique_ptr impl_; -}; - -class OpAICoreConfigImpl; -class OpAICoreConfig { -public: - OpAICoreConfig(); - OpAICoreConfig(const char *soc); - OpAICoreConfig(const OpAICoreConfig &aicore_config); - ~OpAICoreConfig(); - OpAICoreConfig &operator=(const OpAICoreConfig &aicore_config); - OpParamDef &Input(const char *name); - OpParamDef &Output(const char *name); - OpAICoreConfig &DynamicCompileStaticFlag(bool flag); - OpAICoreConfig &DynamicFormatFlag(bool flag); - OpAICoreConfig &DynamicRankSupportFlag(bool flag); - OpAICoreConfig &DynamicShapeSupportFlag(bool flag); - OpAICoreConfig &NeedCheckSupportFlag(bool flag); - OpAICoreConfig &PrecisionReduceFlag(bool flag); - OpAICoreConfig &ExtendCfgInfo(const char *key, const char *value); - -private: - friend class AclnnFallBackGenerator; - friend class AclnnOpGenerator; - friend class Generator; - friend class OpProtoGenerator; - friend class GeneratorFactory; - friend class CfgGenerator; - friend class OpParamTrunk; - friend class OpDef; - - std::vector &GetInputs(void) const; - std::vector &GetOutputs(void) const; - std::vector &GetCfgKeys(void); - std::map &GetCfgInfo(void); - ge::AscendString &GetConfigValue(const char *key); - void AddCfgItem(const char *key, const char *value); - - std::unique_ptr impl_; -}; - -class OpAICoreDefImpl; -class OpAICoreDef { -public: - OpAICoreDef(); - OpAICoreDef(const OpAICoreDef &aicore_def); - ~OpAICoreDef(); - OpAICoreDef &operator=(const OpAICoreDef &aicore_def); - OpAICoreDef &SetTiling(gert::OpImplRegisterV2::TilingKernelFunc func); - OpAICoreDef &SetCheckSupport(optiling::OP_CHECK_FUNC func); - OpAICoreDef &SetOpSelectFormat(optiling::OP_CHECK_FUNC func); - OpAICoreDef &SetOpSupportInfo(optiling::OP_CHECK_FUNC func); - OpAICoreDef &SetOpSpecInfo(optiling::OP_CHECK_FUNC func); - OpAICoreDef &SetParamGeneralize(optiling::PARAM_GENERALIZE_FUNC func); - gert::OpImplRegisterV2::TilingKernelFunc &GetTiling(void); - optiling::OP_CHECK_FUNC &GetCheckSupport(void); - optiling::OP_CHECK_FUNC &GetOpSelectFormat(void); - optiling::OP_CHECK_FUNC &GetOpSupportInfo(void); - optiling::OP_CHECK_FUNC &GetOpSpecInfo(void); - optiling::PARAM_GENERALIZE_FUNC &GetParamGeneralize(void); - OpAICoreDef &AddConfig(const char *soc); - OpAICoreDef &AddConfig(const char *soc, OpAICoreConfig &aicore_config); - -private: - friend class AclnnFallBackGenerator; - friend class AclnnOpGenerator; - friend class Generator; - friend class OpProtoGenerator; - friend class GeneratorFactory; - friend class CfgGenerator; - friend class OpParamTrunk; - friend class OpDef; - - std::map &GetAICoreConfigs(void); - void Log(const char *op_type, const char *info) const; - - std::unique_ptr impl_; -}; - -class OpMC2DefImpl; -class OpMC2Def { -public: - OpMC2Def(); - OpMC2Def(const OpMC2Def &mc2_def); - ~OpMC2Def(); - OpMC2Def &operator=(const OpMC2Def &mc2_def); - OpMC2Def &HcclGroup(const char *value); - OpMC2Def &HcclGroup(std::vector value); - void HcclServerType(enum HcclServerType type, const char *soc = nullptr); - -private: - friend class AclnnFallBackGenerator; - friend class AclnnOpGenerator; - friend class Generator; - friend class OpProtoGenerator; - friend class GeneratorFactory; - friend class CfgGenerator; - friend class OpParamTrunk; - - std::vector &GetHcclGroups(void) const; - ops::HcclServerType GetHcclServerType(const ge::AscendString &soc_version = "") const; - std::unique_ptr impl_; -}; - -class OpDefImpl; -class OpDef { -public: - explicit OpDef(const char *type); - OpDef(const OpDef &op_def); - ~OpDef(); - OpDef &operator=(const OpDef &op_def); - OpParamDef &Input(const char *name); - OpParamDef &Output(const char *name); - OpAttrDef &Attr(const char *name); - OpDef &Comment(CommentSection section, const char *comment); - OpDef &SetInferShape(gert::OpImplRegisterV2::InferShapeKernelFunc func); - OpDef &SetInferShapeRange(gert::OpImplRegisterV2::InferShapeRangeKernelFunc func); - OpDef &SetInferDataType(gert::OpImplRegisterV2::InferDataTypeKernelFunc func); - gert::OpImplRegisterV2::InferShapeKernelFunc &GetInferShape(void); - gert::OpImplRegisterV2::InferShapeRangeKernelFunc &GetInferShapeRange(void); - gert::OpImplRegisterV2::InferDataTypeKernelFunc &GetInferDataType(void); - OpAICoreDef &AICore(void); - OpMC2Def &MC2(void); - OpDef &FormatMatchMode(FormatCheckOption option); - OpDef &EnableFallBack(void); - -private: - friend class AclnnFallBackGenerator; - friend class AclnnOpGenerator; - friend class Generator; - friend class OpProtoGenerator; - friend class GeneratorFactory; - friend class CfgGenerator; - friend class OpParamTrunk; - using ArrParam = std::pair; - struct DfsParam { - std::vector> full_types; - std::vector> full_formats; - std::vector types; - std::vector formats; - }; - enum class PortStat : uint32_t { - IN = 0, - OUT = 1, - INOUT = 2, - INVALID_STAT - }; - struct PortFollowInfo { - PortStat port_stat = PortStat::IN; - uint32_t index_in = 0; - uint32_t index_out = 0; - ge::AscendString follow_port_name = ""; - FollowType follow_type = FollowType::ALL; - }; - ge::AscendString &GetOpType(void); - ge::AscendString &GetCateGory(void) const; - std::vector &GetBrief(void) const; - std::vector &GetConstraints(void) const; - std::vector &GetRestrictions(void) const; - std::vector &GetSee(void) const; - std::vector &GetThirdPartyFwkCopat(void) const; - std::vector &GetInputs(void); - std::vector &GetOutputs(void); - std::vector &GetAttrs(void); - std::vector GetMergeInputs(OpAICoreConfig &aicore_config); - std::vector GetMergeOutputs(OpAICoreConfig &aicore_config); - void CheckIncompatible(const std::vector& all) const; - void FullPermutation(std::vector &input_param, std::vector &output_param); - void DfsFullPermutation(DfsParam &dfs_param, const std::vector &all_param, - uint32_t list_idx, uint32_t non_list_idx) const; - void DfsDataType(DfsParam &dfs_param, const std::vector &all_param, - uint32_t list_idx, uint32_t non_list_idx) const; - void DfsFormat(DfsParam &dfs_param, const std::vector &all_param, - uint32_t list_idx, uint32_t non_list_idx) const; - uint32_t GetNonListLen(std::vector &input_param, std::vector &output_param) const; - bool IsNonListTypes(const OpParamDef &def) const; - bool IsNonListFormats(const OpParamDef &def) const; - void SetDefaultND(std::vector &defs) const; - std::vector> GetMergeInputsOutputs(const OpAICoreConfig &aicore_config); - void SetPermutedParam(const DfsParam &dfs_param, std::vector &input, - std::vector &output); - void MergeParam(std::vector &merge, std::vector &aicore_params) const; - ItemFindStatus FindAttr(const char *name, OpAttrDef **attr); - OpAttrDef &AddAttr(OpAttrDef &attr); - OpAttrDef &GetOrCreateAttr(const char *name); - void FollowImpl(void); - void FollowListImpl(const DfsParam &dfs_param, std::vector& input, std::vector& output); - std::map GetFollowMap(void); - std::map>> GetFollowShapeMap(void); - std::map>> GetFollowTypeMap(void); - OpParamDef GetParamDef(const ge::AscendString& name, OpDef::PortStat stat); - FormatCheckOption GetFormatMatchMode(void); - bool IsEnableFallBack(void); - void UpdateInput(const DfsParam &dfs_param, std::vector &input); - void UpdateOutput(const DfsParam &dfs_param, std::vector &output); - void UpdateDtypeImpl(const DfsParam &dfs_param, OpParamDef ¶m, const uint32_t ¶m_idx); - void UpdateFormatImpl(const DfsParam &dfs_param, OpParamDef ¶m, const uint32_t ¶m_idx); - - std::unique_ptr impl_; -}; -} // namespace ops - -#endif diff --git a/inc/external/register/op_def_factory.h b/inc/external/register/op_def_factory.h deleted file mode 100644 index 9f5a99c4f258b49277f61e12f04423bc64f6e8bf..0000000000000000000000000000000000000000 --- a/inc/external/register/op_def_factory.h +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef OP_DEF_FACTORY_H -#define OP_DEF_FACTORY_H - -#include "register/op_def.h" - -namespace optiling { -class DeviceOpImplRegister; -} -namespace ops { -using OpDefCreator = std::function; -class OpDefFactory { -public: - static int OpDefRegister(const char *name, OpDefCreator creator); - -private: - friend class AclnnFallBackGenerator; - friend class AclnnOpGenerator; - friend class Generator; - friend class OpProtoGenerator; - friend class GeneratorFactory; - friend class CfgGenerator; - friend class optiling::DeviceOpImplRegister; - - static OpDef OpDefCreate(const char *name); - static std::vector &GetAllOp(void); - static void OpTilingSinkRegister(const char *opType); - static bool OpIsTilingSink(const char *opType); -}; -} // namespace ops - -#endif diff --git a/inc/external/register/op_def_registry.h b/inc/external/register/op_def_registry.h deleted file mode 100644 index 0c652a7a9e6ec34bdb1abd48e0e861592bcac77a..0000000000000000000000000000000000000000 --- a/inc/external/register/op_def_registry.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef OP_DEF_REGISTRY_H -#define OP_DEF_REGISTRY_H - -#include "register/op_def.h" -#include "register/op_def_factory.h" - -#if defined(OP_PROTO_LIB) - -#define OP_ADD(opType, ...) \ - static int g_##opType##_added = [](const char *name) { \ - opType op(#opType); \ - gert::OpImplRegisterV2 impl(#opType); \ - impl.InferShape(op.GetInferShape()) \ - .InferShapeRange(op.GetInferShapeRange()) \ - .InferDataType(op.GetInferDataType()); \ - gert::OpImplRegisterV2 implReg(impl); \ - return 0; \ - }(#opType) - -#elif defined(OP_TILING_LIB) - -#define OP_ADD(opType, ...) \ - struct OpAddCompilerInfoPlaceholder##opType {}; \ - static ge::graphStatus TilingPrepare##opType(gert::TilingParseContext *context) { return ge::GRAPH_SUCCESS; } \ - static int g_##opType##_added = [](const char *name) { \ - opType op(#opType); \ - gert::OpImplRegisterV2 impl(#opType); \ - impl.Tiling(op.AICore().GetTiling()); \ - impl.TilingParse(TilingPrepare##opType); \ - optiling::OpCheckFuncHelper(FUNC_CHECK_SUPPORTED, #opType, op.AICore().GetCheckSupport()); \ - optiling::OpCheckFuncHelper(FUNC_OP_SELECT_FORMAT, #opType, op.AICore().GetOpSelectFormat()); \ - optiling::OpCheckFuncHelper(FUNC_GET_OP_SUPPORT_INFO, #opType, op.AICore().GetOpSupportInfo()); \ - optiling::OpCheckFuncHelper(FUNC_GET_SPECIFIC_INFO, #opType, op.AICore().GetOpSpecInfo()); \ - optiling::OpCheckFuncHelper(#opType, op.AICore().GetParamGeneralize()); \ - gert::OpImplRegisterV2 implReg(impl); \ - return 0; \ - }(#opType) - -#else - -#define OP_ADD(opType, ...) \ - static int g_##opType##_added = \ - ops::OpDefFactory::OpDefRegister(#opType, [](const char *name) { return opType(#opType); }) - -#endif -#endif diff --git a/inc/external/register/op_info_record_registry.h b/inc/external/register/op_info_record_registry.h deleted file mode 100644 index 1c25f3070a26d3893bac4c03dee0e6f014f48534..0000000000000000000000000000000000000000 --- a/inc/external/register/op_info_record_registry.h +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_EXTERNAL_REGISTER_OP_INFO_RECORD_REGISTRY_H_ -#define INC_EXTERNAL_REGISTER_OP_INFO_RECORD_REGISTRY_H_ -#include - -#include "external/exe_graph/runtime/tiling_context.h" - -namespace OpInfoRecord { -struct OpCompilerOption { - explicit OpCompilerOption(const std::string &impl_mode_v, bool deterministic_v = true) : - impl_mode(impl_mode_v), deterministic(deterministic_v) {} - explicit OpCompilerOption(const char *impl_mode_v, bool deterministic_v = true) : - impl_mode(impl_mode_v), deterministic(deterministic_v) {} - std::string impl_mode; - bool deterministic; -}; - -struct OpKernelInfo { - explicit OpKernelInfo(const std::string &bin_info_v, int8_t bin_type_v) : - bin_info(bin_info_v), bin_type(bin_type_v) {} - explicit OpKernelInfo(const char *bin_info_v, int8_t bin_type_v) : - bin_info(bin_info_v), bin_type(bin_type_v) {} - std::string bin_info; - int8_t bin_type; -}; - -class __attribute__((visibility("default"))) OpInfoRecordRegister { -public: - using NotifyFn = void(*)(bool); - static OpInfoRecordRegister *Instance(); - /* - * @ingroup OpInfoRecord - * @brief Register the notification function - * @param notify_fn [IN] Callback notification function. - */ - void RegNotify(const NotifyFn notifyFn) const; - - /* - * @ingroup OpInfoRecord - * @brief Obtains the current switch status - * @retval true: The switch is enabled. - * @retval false: The switch is disablesd. - */ - bool GetSwitchState() const; - - /* - * @ingroup OpInfoRecord - * @brief Output the current operator information - * - * @param ctx [IN] Operator context information - * @param opt [IN] Operator compile option - */ - void ExeOptInfoStat( - const gert::TilingContext *ctx, - const OpCompilerOption &opt, - const OpKernelInfo *kernelInfo) const; - -private: - OpInfoRecordRegister() = default; - ~OpInfoRecordRegister() = default; - OpInfoRecordRegister(const OpInfoRecordRegister &) = delete; - OpInfoRecordRegister &operator=(const OpInfoRecordRegister &) = delete; -}; // class OpInfoRecordRegister -} // namespace OpInfoRecord -#endif // INC_EXTERNAL_REGISTER_OP_INFO_RECORD_REGISTRY_H_ diff --git a/inc/external/register/op_lib_register.h b/inc/external/register/op_lib_register.h deleted file mode 100644 index ea62be2b8a6e8aa28c6a56dd4988f4c74e33304f..0000000000000000000000000000000000000000 --- a/inc/external/register/op_lib_register.h +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_EXTERNAL_REGISTER_OP_LIB_REGISTER_H -#define INC_EXTERNAL_REGISTER_OP_LIB_REGISTER_H -#include "graph/compiler_def.h" -#include "graph/types.h" -#include "graph/ascend_string.h" - -namespace ge { -class OpLibRegisterImpl; -class OpLibRegister { - public: - explicit OpLibRegister(const char_t *vendor_name); - OpLibRegister(OpLibRegister &&other) noexcept; - OpLibRegister(const OpLibRegister &other); - OpLibRegister &operator=(const OpLibRegister &) = delete; - OpLibRegister &operator=(OpLibRegister &&) = delete; - ~OpLibRegister(); - - using OpLibInitFunc = uint32_t (*)(ge::AscendString&); - OpLibRegister &RegOpLibInit(OpLibInitFunc func); - - private: - std::unique_ptr impl_; -}; -} // namespace ge - -#define REGISTER_OP_LIB(vendor_name) REGISTER_OP_LIB_UNIQ_HELPER(vendor_name, __COUNTER__) - -#define REGISTER_OP_LIB_UNIQ_HELPER(vendor_name, counter) REGISTER_OP_LIB_UNIQ(vendor_name, counter) - -#define REGISTER_OP_LIB_UNIQ(vendor_name, counter) \ - static ge::OpLibRegister VAR_UNUSED g_##vendor_name##counter = ge::OpLibRegister(#vendor_name) - -#endif // INC_EXTERNAL_REGISTER_OP_LIB_REGISTER_H diff --git a/inc/external/register/op_tiling_attr_utils.h b/inc/external/register/op_tiling_attr_utils.h deleted file mode 100644 index a168d349ccb973d19f53ca47bef95cf9b352db79..0000000000000000000000000000000000000000 --- a/inc/external/register/op_tiling_attr_utils.h +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_EXTERNAL_REGISTER_OP_TILING_ATTR_UTILS_H_ -#define INC_EXTERNAL_REGISTER_OP_TILING_ATTR_UTILS_H_ - -#include -#include "external/graph/operator.h" - -namespace optiling { -class AttrData; -using AttrDataPtr = std::shared_ptr; - -class AttrData { -public: - AttrData() {} - virtual ~AttrData() {} - virtual size_t GetSize() const = 0; - virtual const std::uint8_t *GetData() = 0; -}; - -ge::graphStatus GetOperatorAttrValue(const ge::Operator &op, const char *attr_name, const char *attr_dtype, - AttrDataPtr &attr_data_ptr, const char *target_dtype = nullptr); - -} // namespace optiling -#endif // INC_EXTERNAL_REGISTER_OP_TILING_ATTR_UTILS_H_ diff --git a/inc/external/register/op_tiling_info.h b/inc/external/register/op_tiling_info.h deleted file mode 100644 index 6c853970f8413b6b30f5cc3b3b84a9aa8dae7962..0000000000000000000000000000000000000000 --- a/inc/external/register/op_tiling_info.h +++ /dev/null @@ -1,172 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_EXTERNAL_REGISTER_OP_TILING_INFO_H_ -#define INC_EXTERNAL_REGISTER_OP_TILING_INFO_H_ - -#include -#include -#include -#include "external/graph/ge_error_codes.h" -#include "external/graph/ascend_string.h" -#include "external/graph/tensor.h" - -namespace optiling { -using ByteBuffer = std::stringstream; - -enum class TensorArgType { - TA_NONE, - TA_SINGLE, - TA_LIST, -}; - -class TeOpVarAttrArgsImpl; -class TeOpVarAttrArgs { - friend class VarAttrHelper; - -public: - TeOpVarAttrArgs() = default; - ~TeOpVarAttrArgs() = default; - const uint8_t *GetData(const std::string &name, const std::string &dtype, size_t &size) const; - -private: - std::shared_ptr impl_; -}; - -struct TeOpTensor { - std::vector shape; - std::vector ori_shape; - std::string format; - std::string ori_format; - std::string dtype; - std::string name; - std::map attrs; -}; - -struct TeOpTensorArg { - TensorArgType arg_type; - std::vector tensor; -}; - -struct OpRunInfo { - uint32_t block_dim; - std::vector workspaces; - ByteBuffer tiling_data; - bool clear_atomic; - uint64_t tiling_key; - int32_t tiling_cond; -}; - -using TeOpAttrArgs = std::vector; -using TeConstTensorData = std::tuple; - -struct TeOpParas { - std::vector inputs; - std::vector outputs; - std::map const_inputs; - TeOpAttrArgs attrs; - std::string op_type; - TeOpVarAttrArgs var_attrs; -}; - -struct OpCompileInfo { - std::string str; - std::string key; -}; - -namespace utils { -class OpRunInfoImpl; -class OpRunInfo { -public: - OpRunInfo(); - ~OpRunInfo() = default; - - OpRunInfo(const uint32_t &block_dim, const bool &clear_atomic, const uint64_t &tiling_key); - // Copy - OpRunInfo(const OpRunInfo &runinfo); - // Move - OpRunInfo(OpRunInfo &&runinfo); - // Copy - OpRunInfo &operator=(const OpRunInfo &runinfo); - // Move - OpRunInfo &operator=(OpRunInfo &&runinfo); - - void SetBlockDim(const uint32_t &block_dim); - uint32_t GetBlockDim() const; - void SetAicpuBlockDim(uint32_t block_dim); - uint32_t GetAicpuBlockDim() const; - void SetScheduleMode(const uint32_t schedule_mode); - uint32_t GetScheduleMode() const; - void AddWorkspace(const int64_t &workspace); - size_t GetWorkspaceNum() const; - ge::graphStatus GetWorkspace(const size_t &idx, int64_t &workspace) const; - void GetAllWorkspaces(std::vector &workspaces) const; - const std::vector &GetAllWorkspaces() const; - void SetWorkspaces(const std::vector &workspaces); - - template - void AddTilingData(const T &value) { - AddTilingData(reinterpret_cast(&value), sizeof(value)); - } - template - void operator << (const T &value) { - AddTilingData(reinterpret_cast(&value), sizeof(T)); - } - void AddTilingData(const char *value, const size_t size); - void* GetAddrBase(uint64_t& max_size) const; - void SetAddrBaseOffset(const uint64_t size); - ByteBuffer &GetAllTilingData(); - const ByteBuffer &GetAllTilingData() const; - void InternelSetTiling(const ByteBuffer &value); - void SetClearAtomic(const bool clear_atomic); - bool GetClearAtomic() const; - - void SetTilingKey(const uint64_t &new_tiling_key); - uint64_t GetTilingKey() const; - uint64_t GetTilingDataSize() const; - void ResetWorkspace(); - void ResetAddrBase(void *const addr_base, const uint64_t max_size); - void AlignOffsetWith64(); - bool SetMemCheckBaseOffset(const uint64_t &offset); - void SetTilingCond(const int32_t tiling_cond); - int32_t GetTilingCond() const; - void SetLocalMemorySize(const uint32_t local_memory_size); - uint32_t GetLocalMemorySize() const; -private: - std::shared_ptr impl_; -}; - -class OpCompileInfoImpl; -class OpCompileInfo { -public: - OpCompileInfo(); - ~OpCompileInfo() = default; - OpCompileInfo(const ge::AscendString &key, const ge::AscendString &value); - OpCompileInfo(const std::string &key, const std::string &value); - // Copy - OpCompileInfo(const OpCompileInfo &compileinfo); - // Move - OpCompileInfo(OpCompileInfo &&compileinfo); - // Copy - OpCompileInfo &operator=(const OpCompileInfo &compileinfo); - // Move - OpCompileInfo &operator=(OpCompileInfo &&compileinfo); - - void SetKey(const ge::AscendString &key); - const ge::AscendString &GetKey() const; - - void SetValue(const ge::AscendString &value); - const ge::AscendString &GetValue() const; - -private: - std::shared_ptr impl_; -}; -} -} // namespace optiling -#endif // INC_REGISTER_OP_TILING_REGISTRY_H_ diff --git a/inc/external/register/op_tiling_registry.h b/inc/external/register/op_tiling_registry.h deleted file mode 100644 index e54639f2ae5bb2e106797f99dc67fbd3bf4ff185..0000000000000000000000000000000000000000 --- a/inc/external/register/op_tiling_registry.h +++ /dev/null @@ -1,152 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_EXTERNAL_REGISTER_OP_TILING_REGISTRY_H_ -#define INC_EXTERNAL_REGISTER_OP_TILING_REGISTRY_H_ - -#include -#include -#include -#include -#include -#include "external/graph/operator.h" -#include "external/register/register_error_codes.h" -#include "external/register/register_types.h" -#include "external/register/op_compile_info_base.h" -#include "external/register/op_tiling_info.h" - -#define REGISTER_OP_TILING(optype, opfunc) REGISTER_OP_TILING_UNIQ_HELPER(optype, (opfunc), __COUNTER__) - -#define REGISTER_OP_TILING_UNIQ_HELPER(optype, opfunc, counter) REGISTER_OP_TILING_UNIQ(optype, (opfunc), counter) - -#define REGISTER_OP_TILING_V2(optype, opfunc) REGISTER_OP_TILING_UNIQ_HELPER_V2(optype, (opfunc), __COUNTER__) - -#define REGISTER_OP_TILING_UNIQ_HELPER_V2(optype, opfunc, counter) REGISTER_OP_TILING_UNIQ_V2(optype, (opfunc), counter) - -#define REGISTER_OP_TILING_V3(optype, tilingfunc, parsefunc) \ - REGISTER_OP_TILING_UNIQ_HELPER_V3(optype, (tilingfunc), (parsefunc), __COUNTER__) - -#define REGISTER_OP_TILING_UNIQ_HELPER_V3(optype, tilingfunc, parsefunc, counter) \ - REGISTER_OP_TILING_UNIQ_V3(optype, (tilingfunc), (parsefunc), counter) - -#define REGISTER_OP_TILING_V4(optype, tilingfunc, parsefunc) \ - REGISTER_OP_TILING_UNIQ_HELPER_V4(optype, (tilingfunc), (parsefunc), __COUNTER__) - -#define REGISTER_OP_TILING_UNIQ_HELPER_V4(optype, tilingfunc, parsefunc, counter) \ - REGISTER_OP_TILING_UNIQ_V4(optype, (tilingfunc), (parsefunc), counter) - -#ifdef DISABLE_COMPILE_V1 -#define REGISTER_OP_TILING_UNIQ(optype, opfunc, counter) -#define REGISTER_OP_TILING_UNIQ_V2(optype, opfunc, counter) -#define REGISTER_OP_TILING_UNIQ_V3(optype, tilingfunc, parsefunc, counter) -#define REGISTER_OP_TILING_UNIQ_V4(optype, tilingfunc, parsefunc, counter) -#else -#define REGISTER_OP_TILING_UNIQ(optype, opfunc, counter) \ - static optiling::OpTilingFuncRegistry g_##optype##TilingRegistryInterfV1##counter(#optype, (opfunc)) - -#define REGISTER_OP_TILING_UNIQ_V2(optype, opfunc, counter) \ - static optiling::OpTilingFuncRegistry g_##optype##TilingRegistryInterfV2##counter(#optype, (opfunc)) - -#define REGISTER_OP_TILING_UNIQ_V3(optype, tilingfunc, parsefunc, counter) \ - static optiling::OpTilingFuncRegistry g_##optype##TilingRegistryInterfV3##counter(#optype, (tilingfunc), (parsefunc)) - -#define REGISTER_OP_TILING_UNIQ_V4(optype, tilingfunc, parsefunc, counter) \ - static optiling::OpTilingFuncRegistry g_##optype##TilingRegistryInterfV4##counter(#optype, (tilingfunc), (parsefunc)) -#endif - - -using Status = domi::Status; -namespace optiling { -template -ByteBuffer &ByteBufferPut(ByteBuffer &buf, const T &buffer_value) { - (void) buf.write(reinterpret_cast(&buffer_value), static_cast(sizeof(buffer_value))); - (void) buf.flush(); - return buf; -} - -template -ByteBuffer &ByteBufferGet(ByteBuffer &buf, T &buffer_value) { - (void) buf.read(reinterpret_cast(&buffer_value), static_cast(sizeof(buffer_value))); - return buf; -} - -size_t ByteBufferGetAll(ByteBuffer &buf, ge::char_t *dest, size_t dest_len); -ByteBuffer &ByteBufferPut(ByteBuffer &buf, const uint8_t *data, size_t data_len); - -using OpTilingFunc = std::function; -using OpTilingFuncPtr = std::shared_ptr; -class FMK_FUNC_HOST_VISIBILITY OpTilingRegistryInterf { - public: - OpTilingRegistryInterf(std::string op_type, OpTilingFunc func); - ~OpTilingRegistryInterf() = default; - static std::unordered_map &RegisteredOpInterf(); -}; - -using OpRunInfoV2 = utils::OpRunInfo; -using OpCompileInfoV2 = utils::OpCompileInfo; -using OpTilingFuncV2 = std::function; -using OpTilingFuncV2Ptr = std::shared_ptr; -class FMK_FUNC_HOST_VISIBILITY OpTilingRegistryInterf_V2 { -public: - OpTilingRegistryInterf_V2(const std::string &op_type, OpTilingFuncV2 func); - ~OpTilingRegistryInterf_V2() = default; - static std::unordered_map &RegisteredOpInterf(); -}; - -using OpTilingFuncV3 = std::function; -using OpParseFuncV3 = std::function; -using OpTilingFuncV4 = std::function; -using OpParseFuncV4 = std::function; - -class OpTilingFuncInfo { -public: - explicit OpTilingFuncInfo(const std::string &op_type); - OpTilingFuncInfo() = default; - ~OpTilingFuncInfo() = default; - - bool IsFunctionV4(); - bool IsFunctionV3(); - bool IsFunctionV2(); - bool IsFunctionV1(); - void SetOpTilingFunc(OpTilingFunc &tiling_func); - void SetOpTilingFuncV2(OpTilingFuncV2 &tiling_func); - void SetOpTilingFuncV3(OpTilingFuncV3 &tiling_func, OpParseFuncV3 &parse_func); - void SetOpTilingFuncV4(OpTilingFuncV4 &tiling_func, OpParseFuncV4 &parse_func); - const OpTilingFunc& GetOpTilingFunc(); - const OpTilingFuncV2& GetOpTilingFuncV2(); - const OpTilingFuncV3& GetOpTilingFuncV3(); - const OpParseFuncV3& GetOpParseFuncV3(); - const OpTilingFuncV4& GetOpTilingFuncV4(); - const OpParseFuncV4& GetOpParseFuncV4(); - const std::string& GetOpType() const { - return op_type_; - } - -private: - std::string op_type_; - OpTilingFunc tiling_func_; - OpTilingFuncV2 tiling_func_v2_; - OpTilingFuncV3 tiling_func_v3_; - OpParseFuncV3 parse_func_v3_; - OpTilingFuncV4 tiling_func_v4_; - OpParseFuncV4 parse_func_v4_; -}; - -class FMK_FUNC_HOST_VISIBILITY OpTilingFuncRegistry { -public: - OpTilingFuncRegistry(const std::string &op_type, OpTilingFunc tiling_func); - OpTilingFuncRegistry(const std::string &op_type, OpTilingFuncV2 tiling_func); - OpTilingFuncRegistry(const std::string &op_type, OpTilingFuncV3 tiling_func, OpParseFuncV3 parse_func); - OpTilingFuncRegistry(const std::string &op_type, OpTilingFuncV4 tiling_func, OpParseFuncV4 parse_func); - ~OpTilingFuncRegistry() = default; - static std::unordered_map &RegisteredOpFuncInfo(); -}; - -} // namespace optiling -#endif // INC_EXTERNAL_REGISTER_OP_TILING_REGISTRY_H_ diff --git a/inc/external/register/register_base.h b/inc/external/register/register_base.h deleted file mode 100644 index fb19f663658006dcb85fdff10361758afd62382f..0000000000000000000000000000000000000000 --- a/inc/external/register/register_base.h +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_EXTERNAL_REGISTER_BASE_H -#define INC_EXTERNAL_REGISTER_BASE_H - -#ifdef __cplusplus -extern "C" { -#endif - -const char *aclGetCustomOpLibPath(); - -#ifdef __cplusplus -} -#endif - -#endif // INC_EXTERNAL_REGISTER_BASE_H diff --git a/inc/external/register/register_custom_pass.h b/inc/external/register/register_custom_pass.h deleted file mode 100644 index d89c64c77ec7ac5b330420eabc6dfca6545523db..0000000000000000000000000000000000000000 --- a/inc/external/register/register_custom_pass.h +++ /dev/null @@ -1,113 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_EXTERNAL_REGISTER_REGISTER_PASS_H_ -#define INC_EXTERNAL_REGISTER_REGISTER_PASS_H_ - -#include -#include -#include - -#include "graph/graph.h" -#include "external/ge_common/ge_api_error_codes.h" -#include "register/register_types.h" - -namespace ge { -class PassRegistrationDataImpl; -class CustomPassContext; -class CustomPassContextImpl; -class StreamPassContext; -class StreamPassContextImpl; -using ConstGraphPtr = std::shared_ptr; -using CustomPassFunc = std::function; -using CustomAllocateStreamPassFunc = std::function; -constexpr int64_t INVALID_STREAM_ID = -1; - -/** - * 自定义pass执行阶段,若需扩展,请在kInvalid之前添加 - */ -enum class CustomPassStage : uint32_t { - kBeforeInferShape = 0, - kAfterInferShape = 1, - kAfterAssignLogicStream = 2, // only support CustomAllocateStreamPassFunc in this stage - kAfterBuiltinFusionPass = 3, - kInvalid -}; - -class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY PassRegistrationData { - public: - PassRegistrationData() = default; - ~PassRegistrationData() = default; - - PassRegistrationData(std::string pass_name); - - PassRegistrationData &CustomPassFn(const CustomPassFunc &custom_pass_fn); - - std::string GetPassName() const; - - CustomPassFunc GetCustomPassFn() const; - - PassRegistrationData &Stage(const CustomPassStage stage); - - CustomPassStage GetStage() const; - - PassRegistrationData &CustomAllocateStreamPassFn(const CustomAllocateStreamPassFunc &allocate_stream_pass_fn); - - CustomAllocateStreamPassFunc GetCustomAllocateStreamPass() const; - - private: - std::shared_ptr impl_; -}; - -class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY PassReceiver { - public: - PassReceiver(PassRegistrationData ®_data); - ~PassReceiver() = default; -}; - -class CustomPassContext { - public: - CustomPassContext(); - virtual ~CustomPassContext() = default; - - void SetErrorMessage(const AscendString &error_message); - - AscendString GetErrorMessage() const; - - private: - std::unique_ptr impl_; -}; - -class StreamPassContext : public CustomPassContext { -public: - explicit StreamPassContext(int64_t current_max_stream_id); - - ~StreamPassContext() override = default; - - graphStatus SetStreamId(const GNode &node, int64_t stream_id); - - int64_t GetStreamId(const GNode &node) const; - - int64_t AllocateNextStreamId(); - - int64_t GetCurrMaxStreamId() const; - -private: - std::unique_ptr impl_; -}; -} // namespace ge - -#define REGISTER_CUSTOM_PASS(name) REGISTER_CUSTOM_PASS_UNIQ_HELPER(__COUNTER__, (name)) -#define REGISTER_CUSTOM_PASS_UNIQ_HELPER(ctr, name) REGISTER_CUSTOM_PASS_UNIQ(ctr, (name)) -#define REGISTER_CUSTOM_PASS_UNIQ(ctr, name) \ - static ::ge::PassReceiver register_pass##ctr \ - __attribute__((unused)) = \ - ::ge::PassRegistrationData((name)) - -#endif // INC_EXTERNAL_REGISTER_REGISTER_PASS_H_ diff --git a/inc/external/register/scope/scope_fusion_pass_register.h b/inc/external/register/scope/scope_fusion_pass_register.h deleted file mode 100644 index 417ca38bf30fb1fc6a5f7897dccfc9345269fc6d..0000000000000000000000000000000000000000 --- a/inc/external/register/scope/scope_fusion_pass_register.h +++ /dev/null @@ -1,401 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_ -#define EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_ - -#include -#include -#include -#include -#include -#include "external/ge_common/ge_api_error_codes.h" -#include "register/register_error_codes.h" -#include "register/register_types.h" -#include "graph/operator.h" - -#define CHECK_INNER_NODE_CONDITION(cond, fusion_rlt) \ - do { \ - if (!(cond)) { \ - if ((fusion_rlt) != nullptr) { \ - (fusion_rlt)->SetType(ge::kScopeInvalidType); \ - } \ - return; \ - } \ - } while (0) - -namespace domi { -class TensorFlowModelParser; -} // namespace domi -namespace ge { -const int32_t kFusionDisableIndex = 99999; -const char_t *const kScopeToMultiNodes = "ScopeToMultiNodes"; -const char_t *const kScopeInvalidType = "ScopeInvalidType"; -const char_t *const kInputFromFusionScope = "InputFromFusionScope"; -const char_t *const kOutputToFusionScope = "OutputToFusionScope"; -class ScopePattern; -using ScopeFusionPatterns = std::vector>; - -class ScopePassManager; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY Scope { - public: - Scope(); - ATTRIBUTED_DEPRECATED(Status Init(const char_t *, const char_t *, Scope *)) - Status Init(const std::string &name, const std::string &sub_type = "", Scope *father_scope = nullptr); - Status Init(const char_t *name, const char_t *sub_type, Scope *father_scope = nullptr); - ~Scope(); - ATTRIBUTED_DEPRECATED(Status Name(AscendString &) const) - const std::string &Name() const; - Status Name(AscendString &name) const; - ATTRIBUTED_DEPRECATED(Status SubType(AscendString &) const) - const std::string &SubType() const; - Status SubType(AscendString &sub_type) const; - ATTRIBUTED_DEPRECATED(Status AllNodesMap(std::unordered_map &) const) - const std::unordered_map &AllNodesMap() const; - Status AllNodesMap(std::unordered_map &node_map) const; - ATTRIBUTED_DEPRECATED(Scope *GetSubScope(const char_t *scope_name) const) - Scope *GetSubScope(const std::string &scope_name) const; - Scope *GetSubScope(const char_t *scope_name) const; - ATTRIBUTED_DEPRECATED(Status LastName(AscendString &) const) - const std::string LastName() const; - Status LastName(AscendString &name) const; - const std::vector &GetAllSubScopes() const; - const Scope *GetFatherScope() const; - - private: - class ScopeImpl; - std::unique_ptr impl_; - friend class ScopeBasePass; - friend class ScopeTree; - friend class NodeOpTypeFeature; - friend class NodeAttrFeature; - friend class ScopeFeature; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY FusionScopesResult { - public: - FusionScopesResult(); - Status Init(); - ~FusionScopesResult(); - ATTRIBUTED_DEPRECATED(void SetName(const char_t *)) - void SetName(const std::string &name); - void SetName(const char_t *name); - ATTRIBUTED_DEPRECATED(void SetType(const char_t *)) - void SetType(const std::string &type); - void SetType(const char_t *type); - ATTRIBUTED_DEPRECATED(void SetDescription(const char_t *)) - void SetDescription(const std::string &description); - void SetDescription(const char_t *description); - ATTRIBUTED_DEPRECATED(const Status Name(AscendString &) const) - const std::string &Name() const; - Status Name(AscendString &name) const; - const std::vector &Nodes() const; - ATTRIBUTED_DEPRECATED(void InsertInputs(const char_t *, const std::vector &)) - void InsertInputs(const std::string &inner_op_name, const std::vector &index_map); - void InsertInputs(const char_t *inner_op_name, const std::vector &index_map); - ATTRIBUTED_DEPRECATED(void InsertOutputs(const char_t *, const std::vector &)) - void InsertOutputs(const std::string &inner_op_name, const std::vector &index_map); - void InsertOutputs(const char_t *inner_op_name, const std::vector &index_map); - - class InnerNodeInfo { - public: - ATTRIBUTED_DEPRECATED(InnerNodeInfo(const char_t *)) - explicit InnerNodeInfo(const std::string &fusion_node_name); - explicit InnerNodeInfo(const char_t *fusion_node_name); - ATTRIBUTED_DEPRECATED(InnerNodeInfo(const char_t *, const char_t *, const char_t *)) - InnerNodeInfo(const std::string &fusion_node_name, const std::string &name, const std::string &type); - InnerNodeInfo(const char_t *fusion_node_name, const char_t *name, const char_t *type); - InnerNodeInfo(InnerNodeInfo &&other) noexcept; - InnerNodeInfo &operator=(InnerNodeInfo &&other) noexcept; - InnerNodeInfo(const InnerNodeInfo &) = delete; - InnerNodeInfo &operator=(const InnerNodeInfo &) = delete; - ~InnerNodeInfo(); - ATTRIBUTED_DEPRECATED(InnerNodeInfo &SetName(const char_t *)) - InnerNodeInfo &SetName(const std::string &name); - InnerNodeInfo &SetName(const char_t *name); - ATTRIBUTED_DEPRECATED(InnerNodeInfo &SetType(const char_t *)) - InnerNodeInfo &SetType(const std::string &type); - InnerNodeInfo &SetType(const char_t *type); - ATTRIBUTED_DEPRECATED(InnerNodeInfo &InsertInput(const char_t *, int32_t)) - InnerNodeInfo &InsertInput(const std::string &input_node, int32_t peer_out_idx); - InnerNodeInfo &InsertInput(const char_t *input_node, int32_t peer_out_idx); - ATTRIBUTED_DEPRECATED(InnerNodeInfo &InsertOutput(const char_t *, int32_t)) - InnerNodeInfo &InsertOutput(const std::string &output_node, int32_t peer_in_idx); - InnerNodeInfo &InsertOutput(const char_t *output_node, int32_t peer_in_idx); - ge::graphStatus BuildInnerNode(); - ATTRIBUTED_DEPRECATED(ge::graphStatus SetInputFormat(const char_t *, const char_t *)) - ge::graphStatus SetInputFormat(const std::string &input_name, const std::string &format); - ge::graphStatus SetInputFormat(const char_t *input_name, const char_t *format); - ATTRIBUTED_DEPRECATED(ge::graphStatus SetOutputFormat(const char_t *, const char_t *)) - ge::graphStatus SetOutputFormat(const std::string &output_name, const std::string &format); - ge::graphStatus SetOutputFormat(const char_t *output_name, const char_t *format); - ATTRIBUTED_DEPRECATED(ge::graphStatus SetDynamicInputFormat(const char_t *, uint32_t index, const char_t *)) - ge::graphStatus SetDynamicInputFormat(const std::string &input_name, uint32_t index, const std::string &format); - ge::graphStatus SetDynamicInputFormat(const char_t *input_name, uint32_t index, const char_t *format); - ATTRIBUTED_DEPRECATED(ge::graphStatus SetDynamicOutputFormat(const char_t *, uint32_t, const char_t *)) - ge::graphStatus SetDynamicOutputFormat(const std::string &output_name, uint32_t index, const std::string &format); - ge::graphStatus SetDynamicOutputFormat(const char_t *output_name, uint32_t index, const char_t *format); - ge::Operator *MutableOperator(); - ATTRIBUTED_DEPRECATED(ge::graphStatus GetName(AscendString &) const) - std::string GetName() const; - ge::graphStatus GetName(AscendString &name) const; - ATTRIBUTED_DEPRECATED(ge::graphStatus GetType(AscendString &) const) - std::string GetType() const; - ge::graphStatus GetType(AscendString &type) const; - ATTRIBUTED_DEPRECATED(ge::graphStatus GetInputs(std::vector> &) const) - std::vector> GetInputs() const; - ge::graphStatus GetInputs(std::vector> &inputs) const; - ATTRIBUTED_DEPRECATED(ge::graphStatus GetOutputs(std::vector> &) const) - std::vector> GetOutputs() const; - ge::graphStatus GetOutputs(std::vector> &outputs) const; - private: - class InnerNodeInfoImpl; - std::unique_ptr impl_; - }; - ATTRIBUTED_DEPRECATED(InnerNodeInfo *AddInnerNode(const char_t *, const char_t *)) - InnerNodeInfo *AddInnerNode(const std::string &name, const std::string &type); - InnerNodeInfo *AddInnerNode(const char_t *name, const char_t *type); - InnerNodeInfo *MutableRecentInnerNode(); - InnerNodeInfo *MutableInnerNode(uint32_t index); - ge::graphStatus CheckInnerNodesInfo(); - - private: - class FusionScopesResultImpl; - std::unique_ptr impl_; - friend class ScopeGraph; - friend class ScopeBasePass; - friend class TensorFlowModelParser; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeTree { - public: - ScopeTree(); - Status Init(); - ScopeTree(const ScopeTree &scopetree) = delete; - ScopeTree &operator=(const ScopeTree &scopetree) = delete; - ~ScopeTree(); - - const std::vector &GetAllScopes() const; - - private: - class ScopeTreeImpl; - std::unique_ptr impl_; - friend class ScopeGraph; - friend class ScopeBasePass; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeGraph { - public: - ScopeGraph(); - Status Init(); - ScopeGraph(const ScopeGraph &scope_graph) = delete; - ScopeGraph &operator=(const ScopeGraph &scope_graph) = delete; - ~ScopeGraph(); - - const ScopeTree *GetScopeTree() const; - ATTRIBUTED_DEPRECATED(Status GetNodesMap(std::unordered_map &) const) - const std::unordered_map &GetNodesMap() const; - Status GetNodesMap(std::unordered_map &nodes_map) const; - - private: - class ScopeGraphImpl; - std::unique_ptr impl_; - friend class ScopePassManager; - friend class ScopeBasePass; - friend class TensorFlowModelParser; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeAttrValue { - public: - ScopeAttrValue(); - ScopeAttrValue(ScopeAttrValue const &attr_value); - ScopeAttrValue &operator=(ScopeAttrValue const &attr_value); - ~ScopeAttrValue(); - - void SetIntValue(int64_t value); - void SetFloatValue(float32_t value); - ATTRIBUTED_DEPRECATED(void SetStringValue(const char_t *)) - void SetStringValue(std::string value); - void SetStringValue(const char_t *value); - void SetBoolValue(bool value); - - private: - class ScopeAttrValueImpl; - std::unique_ptr impl_; - friend class NodeAttrFeature; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeBaseFeature { - public: - virtual bool Match(const Scope *scope) = 0; - ScopeBaseFeature() = default; - ScopeBaseFeature(const ScopeBaseFeature &) = delete; - ScopeBaseFeature &operator=(const ScopeBaseFeature &) = delete; - ScopeBaseFeature(ScopeBaseFeature &&) = delete; - ScopeBaseFeature &operator=(ScopeBaseFeature &&) = delete; - virtual ~ScopeBaseFeature()= default; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY NodeOpTypeFeature : ScopeBaseFeature { - public: - ATTRIBUTED_DEPRECATED(NodeOpTypeFeature(const char_t *, int, int)) - NodeOpTypeFeature(std::string nodeType, int32_t num, int32_t step = 0); - NodeOpTypeFeature(const char_t *node_type, int32_t num, int32_t step = 0); - NodeOpTypeFeature(NodeOpTypeFeature const &feature); - NodeOpTypeFeature &operator=(NodeOpTypeFeature const &feature); - ~NodeOpTypeFeature() override; - bool Match(const Scope *scope) override; - - private: - class NodeOpTypeFeatureImpl; - std::unique_ptr impl_; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY NodeAttrFeature : ScopeBaseFeature { - public: - ATTRIBUTED_DEPRECATED(NodeAttrFeature(const char_t *, const char_t *, ge::DataType, ScopeAttrValue &)) - NodeAttrFeature(std::string nodeType, std::string attr_name, - ge::DataType datatype, ScopeAttrValue &attr_value); - NodeAttrFeature(const char_t *node_type, const char_t *attr_name, - ge::DataType data_type, ScopeAttrValue &attr_value); - NodeAttrFeature(NodeAttrFeature const &feature); - NodeAttrFeature &operator=(NodeAttrFeature const &feature); - ~NodeAttrFeature() override; - bool Match(const Scope *scope) override; - - private: - class NodeAttrFeatureImpl; - std::unique_ptr impl_; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFeature : ScopeBaseFeature { - public: - ATTRIBUTED_DEPRECATED(ScopeFeature(const char_t *, int32_t, const char_t *, const char_t *, int)) - ScopeFeature(std::string sub_type, int32_t num, std::string suffix = "", - std::string sub_scope_mask = "", int32_t step = 0); - ScopeFeature(const char_t *sub_type, int32_t num, const char_t *suffix, - const char_t *sub_scope_mask, int32_t step = 0); - ScopeFeature(ScopeFeature const &feature); - ScopeFeature &operator=(ScopeFeature const &feature); - ~ScopeFeature() override; - bool Match(const Scope *scope) override; - - private: - class ScopeFeatureImpl; - std::unique_ptr impl_; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopePattern { - public: - ScopePattern(); - ~ScopePattern(); - ATTRIBUTED_DEPRECATED(ScopePattern &SetSubType(const char_t *)) - ScopePattern &SetSubType(const std::string &sub_type); - ScopePattern &SetSubType(const char_t *sub_type); - ScopePattern &AddNodeOpTypeFeature(NodeOpTypeFeature feature); - ScopePattern &AddNodeAttrFeature(NodeAttrFeature feature); - ScopePattern &AddScopeFeature(ScopeFeature feature); - - private: - class ScopePatternImpl; - std::unique_ptr impl_; - friend class ScopeBasePass; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopesResult { - public: - ScopesResult(); - ScopesResult(ScopesResult const &result); - ScopesResult &operator=(ScopesResult const &result); - ~ScopesResult(); - - void SetScopes(std::vector &scopes); - void SetNodes(std::vector &nodes); - - private: - class ScopesResultImpl; - std::unique_ptr impl_; - friend class ScopeBasePass; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeBasePass { - public: - ScopeBasePass(); - ScopeBasePass(const ScopeBasePass &) = delete; - ScopeBasePass &operator=(const ScopeBasePass &) = delete; - ScopeBasePass(ScopeBasePass &&) = delete; - ScopeBasePass &operator=(ScopeBasePass &&) = delete; - virtual ~ScopeBasePass(); - - protected: - // Subclasses implement respective fusion strategies and build the Patterns - virtual std::vector DefinePatterns() = 0; - // Define the name of the scope pass - virtual std::string PassName() = 0; - // Subclasses implement respective multi-scope or operator fusion methods across scopes - virtual Status LastMatchScopesAndOPs(std::shared_ptr &scope_graph, - std::vector &results) = 0; - // Subclasses implement their own results and set the input and output of the final fusion operator - virtual void GenerateFusionResult(const std::vector &scopes, FusionScopesResult *fusion_rlt) = 0; - - private: - class ScopeBasePassImpl; - std::unique_ptr impl_; - friend class ge::ScopePassManager; - friend class ScopeBasePassImpl; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFusionPassRegistry { - public: - using CreateFn = ScopeBasePass *(*)(); - ~ScopeFusionPassRegistry(); - - static ScopeFusionPassRegistry& GetInstance(); - - ATTRIBUTED_DEPRECATED(void RegisterScopeFusionPass(const char_t *, CreateFn, bool)) - void RegisterScopeFusionPass(const std::string &pass_name, CreateFn create_fn, bool is_general); - - void RegisterScopeFusionPass(const char_t *pass_name, CreateFn create_fn, bool is_general); - - private: - ScopeFusionPassRegistry(); - class ScopeFusionPassRegistryImpl; - /*lint -e148*/ - std::unique_ptr impl_; - friend class TensorFlowModelParser; -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeUtil { - public: - ATTRIBUTED_DEPRECATED(static AscendString StringReplaceAll(const char_t *, const char_t *, const char_t *)) - static std::string StringReplaceAll(std::string str, const std::string &old_value, const std::string &new_value); - static AscendString StringReplaceAll(const char_t *str, const char_t *old_value, const char_t *new_value); - static void FreeScopePatterns(ScopeFusionPatterns &patterns); - static void FreeOneBatchPattern(std::vector &one_batch_pattern); -}; - -class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFusionPassRegistrar { - public: - ScopeFusionPassRegistrar(const char_t *pass_name, ScopeBasePass *(*create_fn)(), bool is_general); - ~ScopeFusionPassRegistrar() = default; -}; -} // namespace ge -#define REGISTER_SCOPE_FUSION_PASS(pass_name, scope_pass, is_general) \ - REGISTER_SCOPE_FUSION_PASS_UNIQ_HELPER(__COUNTER__, (pass_name), scope_pass, (is_general)) - -#define REGISTER_SCOPE_FUSION_PASS_UNIQ_HELPER(ctr, pass_name, scope_pass, is_general) \ - REGISTER_SCOPE_FUSION_PASS_UNIQ(ctr, (pass_name), scope_pass, (is_general)) - -#define REGISTER_SCOPE_FUSION_PASS_UNIQ(ctr, pass_name, scope_pass, is_general) \ - static ::ge::ScopeFusionPassRegistrar register_scope_fusion_pass##ctr __attribute__((unused)) = \ - ::ge::ScopeFusionPassRegistrar( \ - (pass_name), []() -> ::ge::ScopeBasePass * { return new (std::nothrow) scope_pass(); }, (is_general)) -#endif // EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_ - diff --git a/inc/external/register/tilingdata_base.h b/inc/external/register/tilingdata_base.h deleted file mode 100644 index 8618184816139c193bc55f24587a816698d0b715..0000000000000000000000000000000000000000 --- a/inc/external/register/tilingdata_base.h +++ /dev/null @@ -1,266 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef __INC_REGISTER_ASCENDC_TILINGDATA_BASE_HEADER__ -#define __INC_REGISTER_ASCENDC_TILINGDATA_BASE_HEADER__ - -#include -#include -#include -#include -#include -#include "graph/ascend_string.h" - -namespace optiling { -struct CharPtrCmp { - bool operator()(const char *strLeft, const char *strRight) const - { - return strcmp(strLeft, strRight) < 0; - } -}; - -class StructSizeInfoBase { -public: - static StructSizeInfoBase &GetInstance() - { - static StructSizeInfoBase instance; - return instance; - } - void SetStructSize(const char *structType, const size_t structSize) - { - if (structSizeInfo.find(structType) != structSizeInfo.end()) { - return; - } - structSizeInfo[structType] = structSize; - } - size_t GetStructSize(const char *structType) - { - return structSizeInfo.at(structType); - } -private: - StructSizeInfoBase() { }; - ~StructSizeInfoBase() { }; - StructSizeInfoBase(const StructSizeInfoBase &); - StructSizeInfoBase &operator=(const StructSizeInfoBase &); - std::map structSizeInfo; -}; - -class FieldInfo { -public: - FieldInfo(const char *dtype, const char *name) - : dtype_(dtype), name_(name), classType_("0") {} - FieldInfo(const char *dtype, const char *name, size_t arrSize) - : dtype_(dtype), name_(name), arrSize_(arrSize), classType_("1") {} - FieldInfo(const char *dtype, const char *name, const char *structType, - size_t structSize) - : dtype_(dtype), name_(name), structType_(structType), structSize_(structSize), classType_("2") {} - -public: - const char *dtype_; - const char *name_; - size_t arrSize_; - const char *structType_; - size_t structSize_; - const char *classType_; -}; - -class TilingDef { -public: - ~TilingDef() - { - if (!inited_data_ptr && data_ptr_ != nullptr) { - delete[] data_ptr_; - } - data_ptr_ = nullptr; - class_name_ = nullptr; - } - void SaveToBuffer(void *pdata, size_t capacity); - std::vector GetFieldInfo() const; - const char *GetTilingClassName() const; - size_t GetDataSize() const; - void SetDataPtr(void *dataPtr); - void CheckAlignAndGenPlaceHolder(const char *name, size_t typeSize); -protected: - void InitData(); - void GeLogError(const std::string& str) const; - // dtype, name - std::vector field_info_; - uint8_t *data_ptr_ = nullptr; - size_t data_size_ = 0; - const char *class_name_; - std::vector> saveBufferPtr; - size_t struct_size_ = 0; - bool inited_data_ptr = false; - uint32_t feature_bit_flag = 0; - uint8_t reserved_buf[128] = {0}; -}; - -using TilingDataConstructor = std::shared_ptr (*)(); - -class CTilingDataClassFactory { -public: - static CTilingDataClassFactory &GetInstance(); - void RegisterTilingData(const char *op_type, const TilingDataConstructor constructor); - std::shared_ptr CreateTilingDataInstance(const char *op_type); - -private: - CTilingDataClassFactory() { }; - ~CTilingDataClassFactory() { }; - CTilingDataClassFactory(const CTilingDataClassFactory &); - CTilingDataClassFactory &operator=(const CTilingDataClassFactory &); - std::map instance_; -}; - -class TilingDataStructBase { -public: - static TilingDataStructBase &GetInstance() - { - static TilingDataStructBase instance; - return instance; - } - uint32_t __attribute__((weak)) RecordTilingStruct(const char* name, const char* file, uint32_t line); -private: - TilingDataStructBase() { }; - ~TilingDataStructBase() { }; - TilingDataStructBase(const TilingDataStructBase &); - TilingDataStructBase &operator=(const TilingDataStructBase &); - std::map, CharPtrCmp> records; -}; -} // end of namespace optiling - -/* -example: -// supported data_type: int8_t/uint8_t/int16_t/uint16_t/int32_t/uint32_t/int64_t/uint64_t -BEGIN_TILING_DATA_DEF(MaxPoolTilingData) - // format: TILING_DATA_FIELD_DEF(data_type, field_name); - TILING_DATA_FIELD_DEF(int32_t, dim_0); - TILING_DATA_FIELD_DEF(uint8_t, var_1); - TILING_DATA_FIELD_DEF(int64_t, factor_1); -END_TILING_DATA_DEF -REGISTER_TILING_DATA_CLASS(MaxPool, MaxPoolTilingData) -*/ -#define BEGIN_TILING_DATA_DEF(class_name) \ -namespace { \ - static uint32_t class_name##tiling = [](void) { \ - if (&TilingDataStructBase::RecordTilingStruct != nullptr) { \ - return TilingDataStructBase::GetInstance().RecordTilingStruct(#class_name, __FILE__, __LINE__); \ - } else { \ - return static_cast(0); \ - } \ - }(); \ -} \ - class class_name : public TilingDef { \ - public: \ - size_t FieldHandler(const char *dtype, const char *name, size_t typeSize, const char* namePh) { \ - CheckAlignAndGenPlaceHolder(namePh, typeSize); \ - field_info_.emplace_back(FieldInfo(dtype, name)); \ - size_t ret_val = data_size_; \ - data_size_ += typeSize; \ - return ret_val; \ - } \ - size_t FieldHandler(const char *dtype, const char *name, size_t typeSize, \ - size_t arrSize, const char* namePh) { \ - CheckAlignAndGenPlaceHolder(namePh, typeSize); \ - field_info_.emplace_back(FieldInfo(dtype, name, arrSize)); \ - size_t ret_val = data_size_; \ - data_size_ += typeSize * arrSize; \ - return ret_val; \ - } \ - size_t FieldHandler(const char *dtype, const char *name, \ - const char *structType, size_t structSize, void *ptr, const char* namePh) { \ - CheckAlignAndGenPlaceHolder(namePh, 8); \ - field_info_.emplace_back(FieldInfo(dtype, name, structType, structSize)); \ - size_t ret_val = data_size_; \ - data_size_ += structSize; \ - saveBufferPtr.emplace_back(std::make_pair(ptr, ret_val)); \ - struct_size_ += structSize; \ - return ret_val; \ - } \ - \ - public: \ - class_name() { \ - class_name_ = #class_name; \ - CheckAlignAndGenPlaceHolder(#class_name"PH", 8); \ - StructSizeInfoBase::GetInstance().SetStructSize(#class_name, data_size_); \ - InitData(); \ - } \ - explicit class_name(void *ptr) { \ - class_name_ = #class_name; \ - CheckAlignAndGenPlaceHolder(#class_name"PH", 8); \ - StructSizeInfoBase::GetInstance().SetStructSize(#class_name, data_size_); \ - if (ptr == nullptr) { \ - return; \ - } \ - SetDataPtr(ptr); \ - } - -#define TILING_DATA_FIELD_DEF(data_type, field_name) \ - public: \ - void set_##field_name(data_type field_name) { \ - field_name##_ = field_name; \ - *((data_type *) (data_ptr_ + field_name##_offset_)) = field_name; \ - } \ - data_type get_##field_name() { return field_name##_; } \ - \ - private: \ - data_type field_name##_ = 0; \ - size_t field_name##_offset_ = FieldHandler(#data_type, #field_name, sizeof(data_type), #field_name"PH"); \ - uint8_t field_name##_reserve_buf_[16] = {0}; - -#define TILING_DATA_FIELD_DEF_ARR(arr_type, arr_size, field_name) \ - public: \ - void set_##field_name(arr_type *field_name) { \ - field_name##_ = field_name; \ - auto offset = field_name##_offset_; \ - if (data_ptr_ + offset == (uint8_t *)field_name) { \ - return; \ - } \ - const auto err_t = memcpy_s(data_ptr_ + offset, data_size_ - offset, field_name, (arr_size) * sizeof(arr_type)); \ - if (err_t != EOK) { \ - GeLogError("tilingdata_base.h TILING_DATA_FIELD_DEF_ARR memcpy is failed !"); \ - } \ - } \ - arr_type *get_##field_name() { \ - return (arr_type *)(data_ptr_ + field_name##_offset_); \ - } \ - \ - private: \ - arr_type *field_name##_ = nullptr; \ - size_t field_name##_offset_ = FieldHandler(#arr_type, #field_name, sizeof(arr_type), arr_size, #field_name"PH"); \ - uint8_t field_name##_reserve_buf_[16] = {0}; - -#define TILING_DATA_FIELD_DEF_STRUCT(struct_type, field_name) \ - public: \ - struct_type field_name{nullptr}; \ - \ - private: \ - size_t field_name##_offset_ = \ - FieldHandler("struct", #field_name, #struct_type, \ - StructSizeInfoBase::GetInstance().GetStructSize(#struct_type), \ - (void *) &field_name, #field_name"PH"); \ - uint8_t field_name##_reserve_buf_[16] = {0}; - -#define END_TILING_DATA_DEF \ - } \ - ; - -#define REGISTER_TILING_DATA_CLASS(op_type, class_name) \ -namespace { \ - class op_type##class_name##Helper { \ - public: \ - op_type##class_name##Helper() { \ - CTilingDataClassFactory::GetInstance().RegisterTilingData(#op_type, \ - op_type##class_name##Helper::CreateTilingDataInstance); \ - } \ - static std::shared_ptr CreateTilingDataInstance() { return std::make_shared(); } \ - }; \ - static class_name g_##op_type##class_name##init; \ - static op_type##class_name##Helper g_tilingdata_##op_type##class_name##helper; \ -} -#endif // __INC_REGISTER_ASCENDC_TILINGDATA_BASE_HEADER__ diff --git a/inc/external/register/tuning_bank_key_registry.h b/inc/external/register/tuning_bank_key_registry.h deleted file mode 100644 index 94e79d896c0608c4c184e3c6ab389b2e971447bf..0000000000000000000000000000000000000000 --- a/inc/external/register/tuning_bank_key_registry.h +++ /dev/null @@ -1,196 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef __INC_REGISTER_TUNING_BANK_KEY_REGISTRY_HEADER__ -#define __INC_REGISTER_TUNING_BANK_KEY_REGISTRY_HEADER__ -#include -#include -#include -#include -#include "graph/ascend_string.h" -#include "register/register_types.h" -#include "exe_graph/runtime/tiling_context.h" - -// v1 stub -#define REGISTER_OP_BANK_KEY_CONVERT_FUN(op, opfunc) \ - REGISTER_OP_BANK_KEY_CONVERT_FUN_UNIQ_HELPER(op, (opfunc)) - -#define REGISTER_OP_BANK_KEY_CONVERT_FUN_UNIQ_HELPER(optype, opfunc) \ - REGISTER_OP_BANK_KEY_UNIQ(optype, (opfunc)) - -#define REGISTER_OP_BANK_KEY_UNIQ(optype, opfunc) \ - static tuningtiling::OpBankKeyFuncRegistry g_##optype##BankKeyRegistryInterf(#optype, (opfunc)) - -#define REGISTER_OP_BANK_KEY_PARSE_FUN(op, parse_func, load_func) \ - REGISTER_OP_BANK_KEY_PARSE_FUN_UNIQ_HELPER(op, (parse_func), (load_func)) - -#define REGISTER_OP_BANK_KEY_PARSE_FUN_UNIQ_HELPER(optype, parse_func, load_func) \ - REGISTER_OP_BANK_KEY_PARSE_UNIQ(optype, (parse_func), (load_func)) - -#define REGISTER_OP_BANK_KEY_PARSE_UNIQ(optype, parse_func, load_func) \ - static tuningtiling::OpBankKeyFuncRegistry g_##optype##BankParseInterf(#optype, (parse_func), (load_func)) - -// v2 -#define REGISTER_OP_BANK_KEY_CONVERT_FUN_V2(op, opfunc) \ - REGISTER_OP_BANK_KEY_CONVERT_FUN_UNIQ_HELPER_V2(op, (opfunc)) - -#define REGISTER_OP_BANK_KEY_CONVERT_FUN_UNIQ_HELPER_V2(optype, opfunc) \ - REGISTER_OP_BANK_KEY_UNIQ_V2(optype, (opfunc)) - -#define REGISTER_OP_BANK_KEY_UNIQ_V2(optype, opfunc) \ - static tuningtiling::OpBankKeyFuncRegistryV2 g_##optype##BankKeyRegistryInterf(#optype, (opfunc)) - -#define REGISTER_OP_BANK_KEY_PARSE_FUN_V2(op, parse_func, load_func) \ - REGISTER_OP_BANK_KEY_PARSE_FUN_UNIQ_HELPER_V2(op, (parse_func), (load_func)) - -#define REGISTER_OP_BANK_KEY_PARSE_FUN_UNIQ_HELPER_V2(optype, parse_func, load_func) \ - REGISTER_OP_BANK_KEY_PARSE_UNIQ_V2(optype, (parse_func), (load_func)) - -#define REGISTER_OP_BANK_KEY_PARSE_UNIQ_V2(optype, parse_func, load_func) \ - static tuningtiling::OpBankKeyFuncRegistryV2 g_##optype##BankParseInterf(#optype, (parse_func), (load_func)) - -#define TUNING_TILING_MAKE_SHARED(exec_expr0, exec_expr1) \ - do { \ - try { \ - exec_expr0; \ - } catch (...) { \ - exec_expr1; \ - } \ - } while (0) - -// v1 stub -#define DECLARE_STRUCT_RELATE_WITH_OP(op, bank_key, ...) \ - NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(bank_key, __VA_ARGS__); \ - static bool ParseFunc##op##bank_key(const std::shared_ptr &in_args, size_t len, ge::AscendString &bank_key_str) { \ - if (sizeof(bank_key_str) != len || in_args == nullptr) { \ - return false; \ - } \ - return false; \ - } \ - static bool LoadFunc##op##bank_key(std::shared_ptr &in_args, size_t &len, const ge::AscendString &bank_key_str) { \ - len = sizeof(bank_key_str); \ - TUNING_TILING_MAKE_SHARED(in_args = std::make_shared(), return false); \ - auto op_ky = std::static_pointer_cast(in_args); \ - return false; \ - } \ - REGISTER_OP_BANK_KEY_PARSE_FUN(op, ParseFunc##op##bank_key, LoadFunc##op##bank_key); - -// v2 -#define DECLARE_STRUCT_RELATE_WITH_OP_V2(op, bank_key, ...) \ - NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(bank_key, __VA_ARGS__); \ - static bool ParseFuncV2##op##bank_key(const std::shared_ptr &in_args, size_t len, \ - ge::AscendString &bank_key_json_str) { \ - if (sizeof(bank_key) != len || in_args == nullptr) { \ - return false; \ - } \ - nlohmann::json bank_key_json; \ - bank_key_json = *(std::static_pointer_cast(in_args)); \ - try { \ - std::string json_dump_str = bank_key_json.dump(); \ - bank_key_json_str = ge::AscendString(json_dump_str.c_str()); \ - } catch (std::exception& e) { \ - return false; \ - } \ - return true; \ - } \ - static bool LoadFuncV2##op##bank_key(std::shared_ptr &in_args, size_t &len, \ - const ge::AscendString &bank_key_json_str) { \ - len = sizeof(bank_key); \ - TUNING_TILING_MAKE_SHARED(in_args = std::make_shared(), return false); \ - nlohmann::json bank_key_json; \ - try { \ - bank_key_json = nlohmann::json::parse(bank_key_json_str.GetString()); \ - auto op_ky = std::static_pointer_cast(in_args); \ - *op_ky = bank_key_json.get(); \ - } catch (std::exception& e) { \ - return false; \ - } \ - return true; \ - } \ - REGISTER_OP_BANK_KEY_PARSE_FUN_V2(op, ParseFuncV2##op##bank_key, LoadFuncV2##op##bank_key); - - -namespace tuningtiling { -// v1兼容老版本om -using OpBankKeyConvertFun = std::function &, size_t &)>; -using OpBankParseFun = std::function &, size_t, ge::AscendString &)>; -using OpBankLoadFun = std::function &, size_t &, const ge::AscendString &)>; - -// v2 -using OpBankKeyConvertFunV2 = std::function &, size_t &)>; -using OpBankParseFunV2 = std::function &, size_t, ge::AscendString &)>; -using OpBankLoadFunV2 = std::function &, size_t &, const ge::AscendString &)>; -// v1兼容老版本om -class FMK_FUNC_HOST_VISIBILITY OpBankKeyFuncInfo { -public: - explicit OpBankKeyFuncInfo(const ge::AscendString &optype); - OpBankKeyFuncInfo() = default; - ~OpBankKeyFuncInfo() = default; - void SetOpConvertFunc(const OpBankKeyConvertFun &convert_func); - void SetOpParseFunc(const OpBankParseFun &parse_func); - void SetOpLoadFunc(const OpBankLoadFun &load_func); - const OpBankKeyConvertFun& GetBankKeyConvertFunc() const; - const OpBankParseFun& GetBankKeyParseFunc() const; - const OpBankLoadFun& GetBankKeyLoadFunc() const; - const ge::AscendString& GetOpType() const { - return optype_; - } - -private: - ge::AscendString optype_; - OpBankKeyConvertFun convert_func_; - OpBankParseFun parse_func_; - OpBankLoadFun load_func_; - -}; - -// v2 -class FMK_FUNC_HOST_VISIBILITY OpBankKeyFuncInfoV2 { -public: - explicit OpBankKeyFuncInfoV2(const ge::AscendString &optypeV2); - OpBankKeyFuncInfoV2() = default; - ~OpBankKeyFuncInfoV2() = default; - void SetOpConvertFuncV2(const OpBankKeyConvertFunV2 &convert_funcV2); - void SetOpParseFuncV2(const OpBankParseFunV2 &parse_funcV2); - void SetOpLoadFuncV2(const OpBankLoadFunV2 &load_funcV2); - const OpBankKeyConvertFunV2& GetBankKeyConvertFuncV2() const; - const OpBankParseFunV2& GetBankKeyParseFuncV2() const; - const OpBankLoadFunV2& GetBankKeyLoadFuncV2() const; - const ge::AscendString& GetOpTypeV2() const { - return optypeV2_; - } - -private: - ge::AscendString optypeV2_; - OpBankKeyConvertFunV2 convert_funcV2_; - OpBankParseFunV2 parse_funcV2_; - OpBankLoadFunV2 load_funcV2_; -}; - -// v1兼容老版本om -class FMK_FUNC_HOST_VISIBILITY OpBankKeyFuncRegistry { -public: - OpBankKeyFuncRegistry(const ge::AscendString &optype, const OpBankKeyConvertFun &convert_func); - OpBankKeyFuncRegistry(const ge::AscendString &optype, const OpBankParseFun &parse_func, - const OpBankLoadFun &load_func); - ~OpBankKeyFuncRegistry() = default; - static std::unordered_map &RegisteredOpFuncInfo(); -}; - -// v2 -class FMK_FUNC_HOST_VISIBILITY OpBankKeyFuncRegistryV2 { -public: - OpBankKeyFuncRegistryV2(const ge::AscendString &optype, const OpBankKeyConvertFunV2 &convert_funcV2); - OpBankKeyFuncRegistryV2(const ge::AscendString &optype, const OpBankParseFunV2 &parse_funcV2, - const OpBankLoadFunV2 &load_funcV2); - ~OpBankKeyFuncRegistryV2() = default; - static std::unordered_map &RegisteredOpFuncInfoV2(); -}; -} // namespace tuningtiling -#endif diff --git a/inc/external/register/tuning_tiling_reflection_utils.h b/inc/external/register/tuning_tiling_reflection_utils.h deleted file mode 100644 index bf44278486876bf45fb3b1f3c99bab032576ef2e..0000000000000000000000000000000000000000 --- a/inc/external/register/tuning_tiling_reflection_utils.h +++ /dev/null @@ -1,169 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef __INC_REGISTER_TUNING_TILING_REFLECTION_UTILS_HEADER__ -#define __INC_REGISTER_TUNING_TILING_REFLECTION_UTILS_HEADER__ -#include -#include -#include -#include - -namespace tuningtiling { -// implement for std c++11 -template -using decay_t = typename std::decay::type; - -template -using enable_if_t = typename std::enable_if::type; - -template -struct integer_sequence { - using value_type = T; - static constexpr std::size_t size() { - return sizeof...(Ints); - } -}; - -template -using index_sequence = integer_sequence; - -template -struct make_integer_sequence : make_integer_sequence {}; - -template -struct make_integer_sequence : integer_sequence {}; - -template -using make_index_sequence = make_integer_sequence; - -template -struct StructInfo { - static std::tuple<> Info() { - return std::make_tuple(); - } -}; - -#define DECLARE_SCHEMA(Struct, ...) \ - template<> \ - struct StructInfo { \ - static decltype(std::make_tuple(__VA_ARGS__)) Info() { \ - return std::make_tuple(__VA_ARGS__); \ - } \ - }; - -#define FIELD(class, FieldName) std::make_tuple(#FieldName, &class ::FieldName) - -template -void ForEachTuple(Tuple &&tuple, Field &&fields, Fn &&fn, index_sequence) { - (void) std::initializer_list { - (fn(std::get<0>(std::get(fields)), tuple.*std::get<1>(std::get(fields))), Is)...}; -} - -template -void ForEachTuple(Tuple &&tuple, Fn &&fn) { - const auto fields = StructInfo>::Info(); - ForEachTuple(std::forward(tuple), fields, std::forward(fn), - make_index_sequence::value> {}); -} - -template -struct is_optional : std::false_type {}; - -template -struct is_optional> : std::true_type {}; - -template -bool is_optional_v() { - return is_optional>::value; -} - -template -decltype(std::begin(T()), std::true_type {}) containable(size_t); - -template -std::false_type containable(...); - -template -using is_containable = decltype(containable(0U)); - -template -constexpr bool IsSerializeType() { - return ((!std::is_class>::value) || is_containable>()); -} - -template -void ForEachField(T &&value, Fn &&fn) { - ForEachTuple(std::forward(value), std::forward(fn)); -} - -template -struct DumpFunctor; - -template()>* = nullptr> -void DumpObj(T &&obj, const std::string &field_name, Js &j) { - if (field_name.empty()) { - ForEachField(std::forward(obj), DumpFunctor(j)); - return; - } - ForEachField(std::forward(obj), DumpFunctor(j[field_name])); -} - -template()>* = nullptr> -void DumpObj(T &&obj, const std::string &field_name, Js &j) { - if (field_name.empty()) { - return; - } - j[field_name] = std::forward(obj); -} - -template -struct DumpFunctor { - explicit DumpFunctor(T &j) : js(j) {} - template - void operator()(Name &&name, Field &&field) const { - DumpObj(std::forward(field), std::forward(name), js); - } - T &js; -}; - -template -struct FromJsonFunctor; - -template()>* = nullptr> -void FromJsonImpl(T &&obj, const std::string &field_name, const Js &j) { - if (field_name.empty()) { - ForEachField(std::forward(obj), FromJsonFunctor(j)); - return; - } - if (j.find(field_name) == j.cend()) { - return; - } - ForEachField(std::forward(obj), FromJsonFunctor(j[field_name])); -} - -template()>* = nullptr> -void FromJsonImpl(T &&obj, const std::string &field_name, const Js &j) { - // ignore missing field of optional - if ((tuningtiling::is_optional_v()) || (j.find(field_name) == j.cend())) { - return; - } - j.at(field_name).get_to(std::forward(obj)); -} - -template -struct FromJsonFunctor { - explicit FromJsonFunctor(const Js &j) : js(j) {} - template - void operator()(Name &&name, Field &&field) const { - FromJsonImpl(std::forward(field), std::forward(name), js); - } - const Js &js; -}; -} // namespace tuningtiling -#endif diff --git a/inc/external/register/tuning_tiling_registry.h b/inc/external/register/tuning_tiling_registry.h deleted file mode 100644 index a2fe31c6cc4d28ea205fe86725d95f0d26997f51..0000000000000000000000000000000000000000 --- a/inc/external/register/tuning_tiling_registry.h +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef __INC_REGISTER_TUNING_TILING_REGISTRY_HEADER__ -#define __INC_REGISTER_TUNING_TILING_REGISTRY_HEADER__ -#include -#include -#include -#include -#include "graph/ascend_string.h" -#include "register/tuning_tiling_reflection_utils.h" -namespace tuningtiling { -struct TilingItem { - ge::AscendString dtype_; - ge::AscendString name_; -}; - -class TuningTilingDef { -public: - virtual void FromJson(const nlohmann::json &j) = 0; - virtual void ToJson(nlohmann::json &j) = 0; - ge::AscendString GetClassName() const; - virtual std::vector GetItemInfo() const = 0; - -protected: - TuningTilingDef() = default; - virtual ~TuningTilingDef() = default; - // dtype , name - std::vector field_info_; - ge::AscendString class_name_; -}; - -#define BEGIN_TUNING_TILING_DEF(class_name) \ - class class_name : public TuningTilingDef { \ - public: \ - virtual void FromJson(const nlohmann::json &j) { \ - FromJsonImpl(*this, "", j); \ - } \ - \ - virtual void ToJson(nlohmann::json &j) { \ - DumpObj(*this, "", j); \ - } \ - \ - std::vector GetItemInfo() const { \ - return field_info_; \ - } \ - \ - class FieldHandler { \ - public: \ - FieldHandler(class_name *pinstance, const ge::AscendString &dtype, const ge::AscendString &name) { \ - pinstance->field_info_.push_back( {dtype, name}); \ - } \ - }; \ - friend class FieldHandler; \ - \ - public: \ - class_name() { \ - class_name_ = #class_name; \ - }; - -#define TUNING_TILING_DATA_FIELD_DEF(data_type, field_name) \ - public: \ - data_type field_name; \ - FieldHandler field_name##_handler_ = FieldHandler(this, #data_type, #field_name); - -#define END_TUNING_TILING_DEF \ - } \ - ; - -using TuningTilingDefConstructor = std::shared_ptr (*)(); -class TuningTilingClassFactory { -public: - static std::map &RegisterInfo(); - static void RegisterTilingData(const ge::AscendString &optype, TuningTilingDefConstructor const constructor); - static std::shared_ptr CreateTilingDataInstance(const ge::AscendString &optype); -}; - -#define REGISTER_TUNING_TILING_CLASS(optype, class_name) \ - class optype##Helper { \ - public: \ - optype##Helper() { \ - TuningTilingClassFactory::RegisterTilingData(#optype, optype##Helper::CreateTilingDataInstance); \ - } \ - static std::shared_ptr CreateTilingDataInstance() { \ - return std::make_shared(); \ - } \ - }; \ - optype##Helper g_tuning_tiling_##optype##Helper; -using TuningTilingDefPtr = std::shared_ptr; -} // namespace tuningtiling - -#endif diff --git a/inc/graph/args_format_desc.h b/inc/graph/args_format_desc.h deleted file mode 100644 index 74506b888de66e38a559cc2d0bd12be87734ac55..0000000000000000000000000000000000000000 --- a/inc/graph/args_format_desc.h +++ /dev/null @@ -1,100 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_ARGS_FORMAT_H -#define METADEF_CXX_ARGS_FORMAT_H - -#include -#include - -#include "common/ge_common/debug/ge_log.h" -#include "graph/ge_error_codes.h" -#include "graph/op_desc.h" -#include "graph/node.h" -#include "register/hidden_inputs_func_registry.h" -#include "graph/utils/args_format_desc_utils.h" - -namespace ge { - -// Meaningful only for PLACEHOLDER and CUSTOM_VALUE. -enum class ArgsFormatWidth : int32_t { - BIT64 = -1, - BIT32 = -2, -}; - -struct SkArgDesc { - AddrType addr_type; - int32_t ir_idx; - bool folded; - AddrType sub_addr_type; - int32_t sub_idx; -}; -static_assert(std::is_standard_layout::value, "The class SkArgDesc must be a POD"); - -struct SkArgDescV2 { - AddrType addr_type; - int32_t ir_idx; - uint32_t reserved; - AddrType sub_addr_type; - int32_t sub_idx; -}; -static_assert(std::is_standard_layout::value, "The class SkArgDescV2 must be a POD"); - -class ArgsFormatDesc { - public: - // i* -> ir_idx = -1, folded=false - // 对于输入输出,idx表示ir定义的idx,-1表示所有输入、所有输出,此时非动态输入、输出默认展开,动态输出要i1*这样才表示展开 - // 对于workspace -1表示个数未知,folded暂时无意义 - // 对ffts尾块非尾块地址,idx=0表示非尾块,idx=1表示尾块 - // 对于其他类型, idx和fold 暂时没有意义 - void Append(AddrType type, int32_t ir_idx = -1, bool folded = false); - - void Clear(); - - void AppendTilingContext(TilingContextSubType sub_type = TilingContextSubType::TILING_CONTEXT); - void AppendCustomValue(uint64_t value, ArgsFormatWidth width = ArgsFormatWidth::BIT64); - void AppendPlaceholder(ArgsFormatWidth width = ArgsFormatWidth::BIT64); - - std::string ToString() const; - - graphStatus GetArgsSize(const OpDescPtr &op_desc, size_t &args_size) const; - - static graphStatus GetArgSize(const OpDescPtr &op_desc, const ArgDesc arg_desc, size_t &arg_size); - - static graphStatus Parse(const OpDescPtr &op_desc, const std::string &str, std::vector &arg_descs); - - // 为了方便使用,字符串用i*这样的通配符时,返回的argDesc会按照实际个数展开 - // easy mode 不需要进行展开,只根据字面值做反序列化,允许不传op_desc - static graphStatus Parse(const OpDescPtr &op_desc, const std::string &str, std::vector &arg_descs, - const bool easy_mode); - - // 抽取公共序列化/反序列化函数 - static std::string Serialize(const std::vector &arg_descs); - - using const_iterator = std::vector::const_iterator; - const_iterator begin() const { return arg_descs_.begin(); } - const_iterator end() const { return arg_descs_.end(); } - - static graphStatus FromString(ArgsFormatDesc &format, - const OpDescPtr &op_desc, const std::string &str, const bool easy_mode = false) { - return Parse(op_desc, str, format.arg_descs_, easy_mode); - } - - static graphStatus ConvertArgDescSkToNormal(const ArgDesc &sk_arg_desc, ArgDesc &arg_desc, int32_t &sub_op_id); - - static graphStatus ConvertToSuperKernelArgFormat(const NodePtr &sk_node, - const NodePtr &sub_node, const std::string &sub_node_arg_format, - std::string &sk_arg_format); - - private: - std::vector arg_descs_; -}; -} // namespace ge - -#endif // METADEF_CXX_ARGS_FORMAT_H diff --git a/inc/graph/ascendc_ir/ascend_reg_ops.h b/inc/graph/ascendc_ir/ascend_reg_ops.h deleted file mode 100644 index e4080b394c07df628d4f93d8d9cc295ce55a375d..0000000000000000000000000000000000000000 --- a/inc/graph/ascendc_ir/ascend_reg_ops.h +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ -#ifndef GRAPH_ASCEND_REG_OPS_H -#define GRAPH_ASCEND_REG_OPS_H - -#include -#include -#include "graph/operator_reg.h" -#include "ascendc_ir/ascendc_ir_core/ascendc_ir.h" - -#define OP_END_FACTORY_REG_WITHOUT_REGISTER(x) __OP_END_IMPL_WITHOUT_REGISTER__(x) - -#endif // GRAPH_ASCEND_REG_OPS_H diff --git a/inc/graph/ascendc_ir/ascendc_ir_check.h b/inc/graph/ascendc_ir/ascendc_ir_check.h deleted file mode 100644 index 104850ede715b9b4599bcd55d77008a4c6d1e42e..0000000000000000000000000000000000000000 --- a/inc/graph/ascendc_ir/ascendc_ir_check.h +++ /dev/null @@ -1,74 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ -#ifndef METADEF_CXX_INC_GRAPH_ASCENDC_IR_ASCENDC_IR_CHECK_H_ -#define METADEF_CXX_INC_GRAPH_ASCENDC_IR_ASCENDC_IR_CHECK_H_ -#include "common/ge_common/debug/ge_log.h" -#include "common/checker.h" - -namespace ge { -class AscIRException : public std::exception { - public: - struct Info { - graphStatus error_code; - std::string error_msg; - }; - explicit AscIRException(const Info &info); - const Info &GetInfo() const; - const char *what() const noexcept override { - return info_.error_msg.c_str(); - } - private: - Info info_; -}; -} - -inline bool IsVarNameValidAllowEmpty(const std::string &str) { - if (str.empty()) { - return true; - } - // 首字符必须是字母或下划线 - char first = str[0]; - if (!std::isalpha(static_cast(first)) && first != '_') { - return false; - } - - // 后续字符只能是字母、数字或下划线 - for (size_t i = 1U; i < str.size(); ++i) { - char c = str[i]; - if (!std::isalnum(static_cast(c)) && c != '_') { - return false; - } - } - - return true; -} - -#define CHECK_NOTNULL_WITH_THROW_EXCEPTION(val, ...) \ - ASCIR_ASSERT_NOTNULL(val, __VA_ARGS__) - -#define CHECK_BOOL_WITH_THROW_EXCEPTION(val, ...) \ - ASCIR_ASSERT((val), __VA_ARGS__) - -#define ASCIR_ASSERT(exp, ...) \ - do { \ - if (!(exp)) { \ - auto msg = CreateErrorMsg(__VA_ARGS__); \ - if (msg.empty()) { \ - REPORT_INNER_ERR_MSG("E19999", "Assert %s failed", #exp); \ - GELOGE(ge::FAILED, "Assert %s failed", #exp); \ - throw ge::AscIRException({ge::FAILED, #exp}); \ - } else { \ - REPORT_INNER_ERR_MSG("E19999", "%s", msg.data()); \ - GELOGE(ge::FAILED, "%s", msg.data()); \ - throw ge::AscIRException({ge::FAILED, msg.data()}); \ - } \ - } \ - } while (false) -#define ASCIR_ASSERT_NOTNULL(v, ...) ASCIR_ASSERT(((v) != nullptr), __VA_ARGS__) -#endif // METADEF_CXX_INC_GRAPH_ASCENDC_IR_ASCENDC_IR_CHECK_H_ diff --git a/inc/graph/ascendc_ir/ascendc_ir_core/OWNERS b/inc/graph/ascendc_ir/ascendc_ir_core/OWNERS deleted file mode 100644 index 8910042248d7bd64f9a34c5dc16f6654c1ba5281..0000000000000000000000000000000000000000 --- a/inc/graph/ascendc_ir/ascendc_ir_core/OWNERS +++ /dev/null @@ -1,13 +0,0 @@ -approvers: -- wqtshg -- wangxiaotian22 -- zhangfan_hq - -reviewers: -- t86l -- sheng-nan -- xchu42 -- guang-jun-zhang2 - -options: - no_parent_owners: true diff --git a/inc/graph/ascendc_ir/ascendc_ir_core/ascendc_ir.h b/inc/graph/ascendc_ir/ascendc_ir_core/ascendc_ir.h deleted file mode 100644 index 97c7250e75106a28344ba826b1a814f8c2e05578..0000000000000000000000000000000000000000 --- a/inc/graph/ascendc_ir/ascendc_ir_core/ascendc_ir.h +++ /dev/null @@ -1,614 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_ASCENDC_IR_H -#define GRAPH_ASCENDC_IR_H - -#include -#include -#include "attr_store.h" -#include "graph/compute_graph.h" -#include "graph/symbolizer/symbolic.h" -#include "graph/node.h" -#include "graph/anchor.h" -#include "debug/ge_util.h" -#include "graph/utils/op_desc_utils.h" -#include "external/graph/operator.h" -#include "graph/utils/type_utils.h" -#include "graph/ascendc_ir/ascendc_ir_check.h" -#include "graph/expression/const_values.h" -#include "graph/ascendc_ir/ascendc_ir_core/ascendc_ir_def.h" - -namespace ge { -struct DiffAxesInfo { - std::vector add_axes; - std::vector del_axes; -}; -struct View { - std::vector axis_ids; - std::vector repeats; - std::vector strides; -}; -using TransInfoRoadOfGraph = std::vector; - -// 默认实现 -template -std::string ViewMemberToString(const std::vector &vec) { - std::ostringstream oss; - oss << "["; - for (size_t i = 0; i < vec.size(); ++i) { - oss << vec[i]; - if (i < vec.size() - 1) { - oss << ", "; - } - } - oss << "]"; - return oss.str(); -} - -// 特化实现,针对 ge::Expression 类型 -template<> -inline std::string ViewMemberToString(const std::vector &vec) { - std::ostringstream oss; - oss << "["; - for (size_t i = 0; i < vec.size(); ++i) { - const ge::Expression &expr = vec[i]; - oss << (expr.Str() != nullptr ? expr.Str().get() : std::string("null")); - if (i < vec.size() - 1) { - oss << ", "; - } - } - oss << "]"; - return oss.str(); -} - -inline std::string ViewToString(const View &view) { - std::string result = "{ axis: " + ViewMemberToString(view.axis_ids) + - ", repeats: " + ViewMemberToString(view.repeats) + - ", strides: " + ViewMemberToString(view.strides) + - " }"; - return result; -} - -class AscOutputAttrDataType { - public: - AscOutputAttrDataType(ge::Operator *op, uint32_t output_index) : op_(op), output_index_(output_index) {} - - AscOutputAttrDataType &operator=(const ge::DataType &value) { - if (op_ == nullptr) { - GELOGE(PARAM_INVALID, "op_ is null"); - return *this; - } - const auto desc = ge::OpDescUtils::GetOpDescFromOperator(*op_); - if (desc == nullptr) { - GELOGE(PARAM_INVALID, "desc is null"); - return *this; - } - if (desc->MutableOutputDesc(output_index_) == nullptr) { - GELOGE(PARAM_INVALID, - "output_index_ %u is invalid for %s %s", - output_index_, - desc->GetNamePtr(), - desc->GetTypePtr()); - return *this; - } - desc->MutableOutputDesc(output_index_)->SetDataType(value); - return *this; - } - - operator ge::DataType() const { - if (op_ == nullptr) { - GELOGE(PARAM_INVALID, "op_ is null"); - return ge::DT_UNDEFINED; - } - const auto desc = ge::OpDescUtils::GetOpDescFromOperator(*op_); - if (desc == nullptr) { - GELOGE(PARAM_INVALID, "desc is null"); - return ge::DT_UNDEFINED; - } - if (desc->MutableOutputDesc(output_index_) == nullptr) { - GELOGE(PARAM_INVALID, - "output_index_ %u is invalid for %s %s", - output_index_, - desc->GetNamePtr(), - desc->GetTypePtr()); - return ge::DT_UNDEFINED; - } - return desc->MutableOutputDesc(output_index_)->GetDataType(); - }; - - private: - ge::Operator *op_{nullptr}; - uint32_t output_index_{UINT32_MAX}; -}; - -class AscOutputAttrFormat { - public: - AscOutputAttrFormat(ge::Operator *op, uint32_t output_index) : op_(op), output_index_(output_index) {} - - AscOutputAttrFormat &operator=(const ge::Format &value) { - if (op_ == nullptr) { - GELOGE(PARAM_INVALID, "op_ is null"); - return *this; - } - const auto desc = ge::OpDescUtils::GetOpDescFromOperator(*op_); - if (desc == nullptr) { - GELOGE(PARAM_INVALID, "desc is null"); - return *this; - } - if (desc->MutableOutputDesc(output_index_) == nullptr) { - GELOGE(PARAM_INVALID, - "output_index_ %u is invalid for %s %s", - output_index_, - desc->GetNamePtr(), - desc->GetTypePtr()); - return *this; - } - desc->MutableOutputDesc(output_index_)->SetFormat(value); - return *this; - } - - operator ge::Format() const { - if (op_ == nullptr) { - GELOGE(PARAM_INVALID, "op_ is null"); - return ge::FORMAT_RESERVED; - } - const auto desc = ge::OpDescUtils::GetOpDescFromOperator(*op_); - if (desc == nullptr) { - GELOGE(PARAM_INVALID, "desc is null"); - return ge::FORMAT_RESERVED; - } - if (desc->MutableOutputDesc(output_index_) == nullptr) { - GELOGE(PARAM_INVALID, - "output_index_ %u is invalid for %s %s", - output_index_, - desc->GetNamePtr(), - desc->GetTypePtr()); - return ge::FORMAT_RESERVED; - } - return desc->MutableOutputDesc(output_index_)->GetFormat(); - }; - - private: - ge::Operator *op_{nullptr}; - uint32_t output_index_{UINT32_MAX}; -}; - -class AscOpOutput { - public: - class AscOpOutputOffsetHelper { - friend class AscOpOutput; - public: - ~AscOpOutputOffsetHelper() = default; - void operator=(const AscOpOutput &asc_op_output) { - // asc_op_output.op_保证非空 - output_.op_ = asc_op_output.op_; - output_.output_index = asc_op_output.output_index; - output_.dtype = asc_op_output.dtype; - output_.format = asc_op_output.format; - output_.axis = asc_op_output.axis; - output_.repeats = asc_op_output.repeats; - output_.strides = asc_op_output.strides; - *asc_op_output.vectorized_axis = output_.load_vectorized_axes_; - output_.vectorized_axis = asc_op_output.vectorized_axis; - output_.vectorized_strides = asc_op_output.vectorized_strides; - output_.mem = asc_op_output.mem; - output_.que = asc_op_output.que; - output_.buf = asc_op_output.buf; - output_.opt = asc_op_output.opt; - } - private: - explicit AscOpOutputOffsetHelper(AscOpOutput &output) : output_(output) {} - AscOpOutput &output_; - }; - template - friend class AscOpInput; - template - friend class AscOpDynamicInput; - friend class VectorizedOutTensor; - AscOpOutput() - : op_(nullptr), - load_vectorized_axes_({}), - output_index(UINT32_MAX), - dtype(nullptr, output_index), - format(nullptr, output_index), - axis(nullptr), - repeats(nullptr), - strides(nullptr), - vectorized_axis(nullptr), - vectorized_strides(nullptr), - mem(nullptr), - que(nullptr), - buf(nullptr), - opt(nullptr) {} - AscOpOutput(ge::Operator *op, uint32_t output_idx) - : op_(op), output_index(output_idx), dtype(op, output_index), format(op, output_index) { - TryInitTensorAttr(); - } - explicit AscOpOutput(std::vector axis_ids) - : op_(nullptr), - load_vectorized_axes_(std::move(axis_ids)), - output_index(0), - dtype(nullptr, output_index), - format(nullptr, output_index), - axis(nullptr), - repeats(nullptr), - strides(nullptr), - vectorized_axis(nullptr), - vectorized_strides(nullptr), - mem(nullptr), - que(nullptr), - buf(nullptr), - opt(nullptr) {} - AscOpOutput(const AscOpOutput &output) : AscOpOutput(output.op_, output.output_index) {} - AscOpOutput(AscOpOutput &&output) noexcept: AscOpOutput(output.op_, output.output_index) {} - AscOpOutput &operator=(AscOpOutput &&) = delete; - AscOpOutput &operator=(const AscOpOutput &) = delete; - - bool SetContiguousView(const std::vector &axes) { - GE_ASSERT_NOTNULL(axis, "output tensor should bind to API by API function or by AutoOffset"); - GE_ASSERT_NOTNULL(repeats, "output tensor should bind to API by API function or by AutoOffset"); - GE_ASSERT_NOTNULL(strides, "output tensor should bind to API by API function or by AutoOffset"); - std::vector axes_ids; - std::vector tmp_repeats; - std::vector tmp_strides; - axes_ids.reserve(axes.size()); - tmp_repeats.reserve(axes.size()); - tmp_strides.reserve(axes.size()); - - std::for_each(axes.rbegin(), axes.rend(), - [&axes_ids, &tmp_repeats, &tmp_strides](const Axis &tmp_axis) { - if (tmp_strides.empty()) { - tmp_strides.emplace_back(sym::kSymbolOne); - } else { - tmp_strides.emplace_back(*tmp_repeats.rbegin() * *tmp_strides.rbegin()); - } - tmp_repeats.emplace_back(tmp_axis.size); - axes_ids.emplace_back(tmp_axis.id); - }); - std::reverse(axes_ids.begin(), axes_ids.end()); - std::reverse(tmp_repeats.begin(), tmp_repeats.end()); - std::reverse(tmp_strides.begin(), tmp_strides.end()); - - *axis = axes_ids; - *repeats = tmp_repeats; - *strides = tmp_strides; - return true; - } - - const ge::Operator &GetOwnerOp() const { - return *op_; - } - - ge::Operator &MutableOwnerOp() const{ - return *op_; - } - - void TryInitTensorAttr() { - auto tensor_attr_ptr = AscTensorAttr::GetTensorAttrPtr(op_, output_index); - if (tensor_attr_ptr == nullptr) { - return; - } - auto &tensor_attr = *tensor_attr_ptr; - axis = &tensor_attr.axis; - repeats = &tensor_attr.repeats; - strides = &tensor_attr.strides; - vectorized_axis = &tensor_attr.vectorized_axis; - vectorized_strides = &tensor_attr.vectorized_strides; - mem = &tensor_attr.mem; - que = &tensor_attr.que; - buf = &tensor_attr.buf; - opt = &tensor_attr.opt; - } - AscOpOutput &Use(const AscOpOutput &used_out) { - if (op_ == nullptr) { - GELOGE(PARAM_INVALID, "output tensor should bind to API by API function or by AutoOffset"); - return *this; - } - if (HasBindToContainer()) { - GELOGE(PARAM_INVALID, " this tensor has been bound to a que or buf, can not be repeated bound."); - return *this; - } - if (!used_out.HasBindToContainer()) { - GELOGE(PARAM_INVALID, " tensor to be used has not been bound to any que or buf."); - return *this; - } - if (used_out.que->id != kIdNone) { - (void) UseTQue(used_out.mem->position, used_out.que->depth, used_out.que->buf_num, used_out.que->id); - } - if (used_out.buf->id != kIdNone) { - (void) UseTBuf(used_out.mem->position, used_out.buf->id); - } - mem->reuse_id = used_out.mem->reuse_id; - return *this; - } - - AscOpOutput &TQue(const Position pos, const int64_t depth, const int64_t buf_num) { - if (op_ == nullptr) { - GELOGE(PARAM_INVALID, "output tensor should bind to API by API function or by AutoOffset"); - return *this; - } - mem->reuse_id = GenNextReuseId(); - (void) UseTQue(pos, depth, buf_num); - return *this; - } - - AscOpOutput &TBuf(const Position pos) { - if (op_ == nullptr) { - GELOGE(PARAM_INVALID, "output tensor should bind to API by API function or by AutoOffset"); - return *this; - } - mem->reuse_id = GenNextReuseId(); - UseTBuf(pos); - return *this; - } - AscOpOutputOffsetHelper AutoOffset() { - return AscOpOutputOffsetHelper(*this); - } - - private: - int64_t GenContainerId(); - int64_t GenNextReuseId(); - bool UseTQue(const Position pos, const int64_t depth, const int64_t buf_num, const int64_t id = kIdNone); - bool UseTBuf(const Position pos, const int64_t id = kIdNone); - bool HasBindToContainer() const; - ge::Operator *op_; - std::vector load_vectorized_axes_; - public: - uint32_t output_index{UINT32_MAX}; - AscOutputAttrDataType dtype; - AscOutputAttrFormat format; - std::vector *axis{}; - std::vector *repeats{}; - std::vector *strides{}; - std::vector *vectorized_axis{}; - std::vector *vectorized_strides{}; - MemAttr *mem{}; - MemQueAttr *que{}; - MemBufAttr *buf{}; - MemOptAttr *opt{}; -}; - -class VectorizedOutTensor { - public: - explicit VectorizedOutTensor(std::vector vectorized_axis) : vectorized_axis_(std::move(vectorized_axis)) { - } - VectorizedOutTensor &operator=(const VectorizedOutTensor &) = delete; - VectorizedOutTensor(const VectorizedOutTensor &) = delete; - explicit operator AscOpOutput() const { - GE_ASSERT_NOTNULL(op_); - return AscOpOutput(op_, output_index_); - } - void operator=(AscOpOutput &&asc_op_output) { - // 不支持更改归属 - if (op_ != nullptr) { - AscendString name; - (void) op_->GetName(name); - GELOGE(FAILED, "Tensor has been bind to %s", name.GetString()); - return; - } - op_ = asc_op_output.op_; - output_index_ = asc_op_output.output_index; - // 修改归属op的向量化轴信息 - AscTensorAttr::GetTensorAttr(op_, output_index_).vectorized_axis = vectorized_axis_; - } - private: - std::vector vectorized_axis_; - ge::Operator *op_{nullptr}; - uint32_t output_index_{UINT32_MAX}; -}; - -graphStatus AddEdgeForNode(const ge::Operator &src_op, int32_t src_index, ge::Operator &dst_op, int32_t dst_index); -graphStatus LinkByIrIndex(const ge::Operator &src_op, - uint32_t src_ir_index, - ge::Operator &dst_op, - uint32_t dst_ir_index, - uint32_t dynamic_index = 0U); -graphStatus SetDynamicInputNumByIrIndex(ge::Operator &op, uint32_t ir_index, uint32_t dynamic_num); - -template -class AscOpInput { - public: - explicit AscOpInput(ge::Operator *op) : op_(op) {} - - AscOpInput &operator=(const AscOpOutput &output) { - if (op_ == nullptr || output.op_ == nullptr) { - GELOGE(FAILED, "Invalid op, make sure construct func is called."); - return *this; - } - LinkByIrIndex(*output.op_, output.output_index, *this->op_, INPUT_INDEX); - return *this; - } - - private: - ge::Operator *op_; -}; - -struct AscTensor { - explicit AscTensor(const ge::OutDataAnchor &an) : attr(AscTensorAttr::GetTensorAttr(an)), anchor(an) {} - ~AscTensor() = default; - AscTensorAttr &attr; // not owner - const ge::OutDataAnchor &anchor; // not owner -}; - -struct AscNodeOutputs { - friend class AscNode; - AscTensor &operator[](uint32_t index); - std::vector operator()(); - private: - explicit AscNodeOutputs(ge::Node *node) : node_(node) { - Init(); - } - void Init(); - std::vector tensors_; - Node *node_; -}; - -struct AscNodeInputs { - friend class AscNode; - AscTensor &operator[](uint32_t index); - std::vector operator()(); - uint32_t Size(); - private: - explicit AscNodeInputs(ge::Node *node) : node_(node) { - Init(); - } - void Init(); - std::vector tensors_; - Node *node_; -}; - -class AscNode : public Node { - public: - AscNode(const OpDescPtr &op_desc, const ComputeGraphPtr &compute_graph); - AscNodeInputs inputs; - AscNodeOutputs outputs; - AscNodeAttr &attr; -}; -using AscNodePtr = std::shared_ptr; - -class AscNodeIter { - public: - explicit AscNodeIter(ge::ComputeGraph::Vistor::Iterator &&iter); - AscNodeIter &operator++(); - AscNodePtr operator*(); - bool operator!=(const AscNodeIter &other) const; - private: - ge::ComputeGraph::Vistor::Iterator impl_; -}; - -class AscNodeVisitor { - public: - using Iterator = ge::ComputeGraph::Vistor::Iterator; - AscNodeIter begin(); - AscNodeIter end(); - explicit AscNodeVisitor(ge::ComputeGraph::Vistor &&visitor); - private: - ge::ComputeGraph::Vistor impl_; -}; - - -template -class AscOpDynamicInput { - public: - explicit AscOpDynamicInput(ge::Operator *op) : op_(op) {} - AscOpDynamicInput &operator=(const std::initializer_list &outputs) { - return AssignImpl(outputs); - } - AscOpDynamicInput &operator=(const std::vector &outputs) { - return AssignImpl(outputs); - } - - private: - template - AscOpDynamicInput &AssignImpl(const Container &outputs) { - if (op_ == nullptr) { - GELOGE(FAILED, "op_ in null"); - return *this; - } - if (inited_) { - AscendString op_name; - (void) op_->GetName(op_name); - GELOGE(FAILED, "It is not allowed to set the dynamic input repeatedly, node:[%s].", op_name.GetString()); - return *this; - } - const size_t input_nums = outputs.size(); - SetDynamicInputNumByIrIndex(*this->op_, INPUT_INDEX, input_nums); - size_t idx = 0UL; - for (const auto &output : outputs) { - if (output.op_ == nullptr) { - GELOGE(FAILED, "Src tensor is null"); - return *this; - } - LinkByIrIndex(*output.op_, output.output_index, *this->op_, INPUT_INDEX, idx++); - } - inited_ = true; - return *this; - } - ge::Operator *op_{nullptr}; - bool inited_{false}; -}; - -class AscGraphImpl; -namespace ascir { -namespace cg { -class CodeGenUtils; -} -} -class AscGraph { - friend class ascir::cg::CodeGenUtils; - friend class AscGraphUtils; - public: - explicit AscGraph(const char *name); - ~AscGraph(); - void SetTilingKey(const uint32_t tiling_key); - int64_t GetTilingKey() const; - void SetGraphType(const AscGraphType type); - AscGraphType GetGraphType() const; - graphStatus CreateSizeVar(const Expression &expression); - Expression CreateSizeVar(const int64_t value); - Expression CreateSizeVar(const std::string &name); - Axis &CreateAxis(const std::string &name, const Expression &size); - Axis &CreateAxis(const std::string &name, Axis::Type type, const Expression &size, const std::vector &from, - AxisId split_peer); - Axis *FindAxis(const int64_t axis_id); - AscNodePtr AddNode(ge::Operator &op); - AscNodePtr FindNode(const char *name) const; - std::pair BlockSplit(const int64_t axis_id, const std::string &outer_axis_name = "", - const std::string &inner_axis_name = ""); - std::pair TileSplit(const int64_t axis_id, const std::string &outer_axis_name = "", - const std::string &inner_axis_name = ""); - AxisPtr MergeAxis(const std::vector &axis_ids, const std::string &merge_axis_name = ""); - bool BindBlock(const int64_t outter_id, const int64_t inner_id); - bool ApplySplit(const AscNodePtr &node, const int64_t outter_id, const int64_t inner_id); - bool ApplyMerge(const AscNodePtr &node, const int64_t merged_axis_id); - bool ApplyReorder(const AscNodePtr &node, const std::vector &reordered_axis); - bool ApplySchedAxisMerge(const AscNodePtr &node, const int64_t merged_axis_id, const std::vector &original); - bool ApplySchedAxisMerge(const AscNodePtr &node, const int64_t merged_axis_id); - bool ApplyTensorAxisMerge(const AscNodePtr &node, const int64_t merged_axis_id, const std::vector &original); - bool ApplyTensorAxisMerge(const AscNodePtr &node, const int64_t merged_axis_id); - bool ApplySchedAxisReorder(const AscNodePtr &node, const std::vector &reordered_axis); - bool ApplyTensorAxisReorder(const AscNodePtr &node, const std::vector &reordered_axis); - bool TryApplyAxisReplace(const AscNodePtr &node, const Axis &src, const Axis &dst); - AscNodeVisitor GetAllNodes() const; - AscNodeVisitor GetInputNodes() const; - std::vector GetAllSizeVar() const; - std::vector GetAllAxis() const; - TransInfoRoadOfGraph GetAllAxisTransInfo() const; - std::string GetName() const; - bool CheckValid() const; - - AscOpOutput CreateContiguousData(const char *name, - const ge::DataType &dt, - const std::vector &axes, - const ge::Format &format = ge::FORMAT_ND); - - AscOpOutput CreateContiguousOut(const char *name, - const ge::DataType &dt, - const std::vector &axes, - const ge::Format &format = ge::FORMAT_ND); - void SortByExecOrder(); - bool CopyFrom(const ge::AscGraph &graph); - bool CopyAttrFrom(const AscGraph &src_graph); - static bool CopyAscNodeTensorAttr(const AscNodePtr &src_node, AscNodePtr &dst_node); - Status AddSubGraph(const ge::AscGraph &graph) const; - Status FindSubGraph(const std::string &name, ge::AscGraph &graph) const; - Status GetAllSubGraphs(std::vector &graphs) const; - - private: - bool CheckExprValid() const; - bool CheckAxisValid() const; - bool CheckExecOrderValid() const; - bool CheckTensorValid() const; - bool CheckNodeConnectionValid() const; - std::shared_ptr impl_; -}; -} // namespace ge - -#endif // GRAPH_ASCENDC_IR_H diff --git a/inc/graph/ascendc_ir/ascendc_ir_core/ascendc_ir_def.h b/inc/graph/ascendc_ir/ascendc_ir_core/ascendc_ir_def.h deleted file mode 100644 index ae90085f1b91515b95419e245e9b082735ae94af..0000000000000000000000000000000000000000 --- a/inc/graph/ascendc_ir/ascendc_ir_core/ascendc_ir_def.h +++ /dev/null @@ -1,452 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_ASCENDC_IR_DEF_H -#define METADEF_CXX_ASCENDC_IR_DEF_H - -#include -#include -#include "attr_store.h" -#include "graph/compute_graph.h" -#include "graph/symbolizer/symbolic.h" -#include "graph/node.h" -#include "graph/anchor.h" -#include "debug/ge_util.h" -#include "graph/utils/op_desc_utils.h" -#include "external/graph/operator.h" -#include "graph/utils/type_utils.h" -#include "graph/ascendc_ir/ascendc_ir_check.h" -#include "inc/graph/ascendc_ir/ascendc_ir_core/ascendc_ir_def.h" -#include "proto/ascendc_ir.pb.h" -#include "proto/ge_ir.pb.h" -#include "serialization/attr_serializer_registry.h" -// proto调整后,兼容老代码 -namespace ascendc_ir { -namespace proto { -using AscIrAttrDef = ge::proto::AscIrAttrDef; -using AscNodeAttrGroupsDef = ge::proto::AscNodeAttrGroupsDef; -using AscGraphAttrGroupsDef = ge::proto::AscGraphAttrGroupsDef; -using AscTensorAttrGroupsDef = ge::proto::AscTensorAttrGroupsDef; -} -} -namespace ge { -namespace { -constexpr int64_t kIdNone = -1; -const std::string kDataIndex = "index"; -} -struct SizeVar { - using Type = enum : int32_t { - kSizeTypeVar = 0, - kSizeTypeConst = 1, - }; - - // [HI] 符号`id`,从0开始,TODO:待删除 - int64_t id{}; - - // [HI] 符号名,图内唯一,符号名被用于全图的表达式,TODO:待删除 - std::string name; - - // [HI] 如果符号是常量,`const_value`表示常量的值,TODO:待删除,使用expr中的内容 - int64_t const_value{}; - - // [HI] 符号的类型,TODO:待删除,使用expr中的内容 - Type type; - - // [HI] TODO 这里只能使用Symbol创建,不允许使用Expression - explicit SizeVar(ge::Expression expr_other) : expr(std::move(expr_other)) {} - - // [HI] 符号,expr中的符号名图内唯一,符号名被用于全图的表达式 - ge::Expression expr; -}; -using SizeVarPtr = std::shared_ptr; - -struct Axis { - using Type = enum : int32_t { - kAxisTypeOriginal, - kAxisTypeBlockOuter, // outer axis after split by multicore - kAxisTypeBlockInner, // inner axis after split by multicore - kAxisTypeTileOuter, // outer axis after split by one core - kAxisTypeTileInner, // inner axis after split by one core - kAxisTypeMerged, - kAxisTypeInvalid - }; - - int64_t id{kIdNone}; // axis id - - // [HI] 轴的名字,图内唯一 - std::string name; // axis name - - // [HI] 轴的类型 - Type type{kAxisTypeInvalid}; - - // [I] 是否为`block`轴 - bool bind_block{false}; - - // [HI] 轴的大小 - ge::Expression size; - - // [I] TODO 轴的对齐要求,详细说明不同的值分别是什么含义 - int32_t align{-1}; - - // [I] 当轴为被切分轴时, - std::vector from; - - // [I] 如果轴是被切分出来的,`split_pair`表示切分出来的另一个轴的`id` - int64_t split_pair_other_id{kIdNone}; - // 自动融合场景的默认值,手写场景可以做配置,供ATT使用 - bool allow_oversize_axis{false}; - bool allow_unaligned_tail{true}; -}; -using AxisPtr = std::shared_ptr; -using AxisId = int64_t; -enum class TransType : int64_t { - kSplit = 0, - kMerge, - kValid -}; -struct OneTransInfo { - TransType trans_type; - std::vector src_axis; - std::vector dst_axis; -}; -using TransInfoRoadOfGraph = std::vector; - -enum class ComputeType : int32_t { - kComputeLoad, - kComputeStore, - kComputeReduceStore, - kComputeElewise, - kComputeBroadcast, - kComputeReduce, - kComputeTranspose, - kComputeConcat, - kComputeGather, - kComputeCube, - kComputeSplit, - kComputeInvalid, -}; - -enum class ComputeUnit : int32_t { - kUnitNone, - kUnitMTE1, - kUnitMTE2, - kUnitMTE3, - kUnitScalar, - kUnitVector, - kUnitCube, - kUnitInvalid, -}; - -enum class ApiType : int32_t { - kAPITypeBuffer, // Workspace/Data/Constant/IndexExpr/Output - kAPITypeCompute, // Load/Store/ReduceStore/Elewise/BroadCast/Reduce/Transpose - kAPITypeInvalid, -}; - -enum class ExecuteCondition: int32_t { - kNoCache = 0, // 不缓存 - kCacheBlockSplitFusedBroadcastAxis, // 缓存,条件是合轴后的广播轴拆分到T和t中 - kCacheBlockSplitOriginBroadcastAxis, // 缓存,条件是合轴后分到T中的原始轴都是广播轴 - kConditionInvalid, -}; - -struct ApiInfo { - // [I] `api`的类型 - ApiType type = ApiType::kAPITypeInvalid; - - // [I] `api`的计算类型 - ComputeType compute_type = ComputeType::kComputeInvalid; - - // [I] `api`的计算单元 - ComputeUnit unit = ComputeUnit::kUnitInvalid; -}; - -struct SchedInfo { - // [HI] 执行序,按值从小到大执行 - int64_t exec_order{kIdNone}; - - // [HI] 节点所处的多层嵌套循环的轴`id`,按循环表示从外层到内层的轴`id` - std::vector axis; - - // [I] 节点进行`api`计算的最内层循环,这个轴以内的部分将被映射为`api`的参数长度,这个轴以外的循环将会展开 - int64_t loop_axis{kIdNone}; - - // [I] 节点的执行时的附加条件,目前主要用于是否对`api`结果进行缓存 - ExecuteCondition exec_condition{ExecuteCondition::kNoCache}; -}; - -class AscIrAttrDefBase { - public: - AscIrAttrDefBase() = default; - virtual ~AscIrAttrDefBase() = default; - graphStatus Serialize(ascendc_ir::proto::AscIrAttrDef &asc_ir_attr_def); - graphStatus Deserialize(const ascendc_ir::proto::AscIrAttrDef &asc_ir_attr_def); - std::unique_ptr Clone();; - template - graphStatus GetAttrValue(const std::string &attr_name, T &attr_value) { - auto *const v = attr_store_.GetAnyValue(attr_name); - if (v == nullptr) { - GELOGW("Attr %s has not been set.", attr_name.c_str()); - return GRAPH_FAILED; - } - if (v->Get() == nullptr) { - GELOGW("Attr %s is set, however maybe type is not fit.", attr_name.c_str()); - return GRAPH_FAILED; - } - attr_value = *(v->Get()); - return GRAPH_SUCCESS; - } - template - T *DownCastTo() { - // 子类没有成员,所以可以这样搞 - static_assert(std::is_base_of::value, "Template parameter must be derived from IrAttrDefBase"); - return reinterpret_cast(this); - } - protected: - AttrStore::CustomDefinedAttrStore attr_store_; -}; - -enum class AllocType : int32_t { - kAllocTypeGlobal, - kAllocTypeL1, - kAllocTypeL2, - kAllocTypeBuffer, - kAllocTypeQueue, - kAllocTypeInvalid, -}; - -enum class MemHardware : int32_t { - kMemHardwareGM, - kMemHardwareUB, - kMemHardwareInvalid, -}; - -enum class Position : int32_t { - kPositionGM, - kPositionVecIn, - kPositionVecOut, - kPositionVecCalc, - kPositionInvalid, -}; - -struct MemAttr { - int64_t tensor_id = kIdNone; - AllocType alloc_type = AllocType::kAllocTypeGlobal; - Position position = Position::kPositionGM; - MemHardware hardware = MemHardware::kMemHardwareGM; - // TODO 待删除 - std::vector buf_ids; - // TODO 待删除 - std::string name; - // reuse_id配合que_id表达que的共用和复用 - // que_id相同,一个reuse_id对应一组tensor, 该组中的多个tensor共用该que_id, tensor使用该que的offset由使用者自己计算和维护 - // que_id相同,多个reuse_id对应多组tensor,每组tensor间复用该que_id - int64_t reuse_id = kIdNone; -}; - -struct MemQueAttr { - int64_t id = kIdNone; - int64_t depth{-1}; - int64_t buf_num{-1}; - // TODO 待删除 - std::string name{""}; -}; - -struct MemBufAttr { - int64_t id = kIdNone; - // TODO 待删除 - std::string name{""}; -}; - -struct MemOptAttr { - int64_t reuse_id = kIdNone; // TODO 待删除, 正式方案放在MemAttr - int64_t ref_tensor = kIdNone; - int64_t merge_scope = kIdNone; -}; - -struct TmpBufDesc { - Expression size; - int64_t life_time_axis_id = -1; // -1: 生命周期为API级别, >= 0: loop级别 -}; - -struct TmpBuffer { - TmpBufDesc buf_desc; - MemAttr mem{}; -}; - - -class AscNodeAttr : public ge::AttrGroupsBase { - public: - // [HI] 节点名,图内唯一 - std::string name; - // [HI] 节点类型 - std::string type; - // 调度信息 - SchedInfo sched{}; - ApiInfo api{}; - // Ir定义的属性,跟具体Ir有关 - std::unique_ptr ir_attr{nullptr}; - std::vector tmp_buffers; - AscNodeAttr() = default; - ~AscNodeAttr() override = default; - graphStatus SerializeAttr(ascendc_ir::proto::AscNodeAttrGroupsDef &asc_node_group) const; - graphStatus DeserializeAttr(const ascendc_ir::proto::AscNodeAttrGroupsDef &asc_node_group); - graphStatus Serialize(proto::AttrGroupDef &attr_group_def) override; - graphStatus Deserialize(const proto::AttrGroupDef &attr_group_def, AttrHolder *attr_holder) override; - AscNodeAttr &operator=(const AscNodeAttr &other); - AscNodeAttr(const AscNodeAttr &other) - : name(other.name), - type(other.type), - sched(other.sched), - api(other.api), - ir_attr(other.ir_attr ? other.ir_attr->Clone() : nullptr), - tmp_buffers(other.tmp_buffers) {} -// 没有注册ir属性时,调用这个接口 - static AscNodeAttr *Create(ge::Operator &op); - -// 注册了ir属性时,调用这个接口 - template - static AscNodeAttr *Create(ge::Operator &op) { - static_assert( - std::is_base_of::value && !std::is_same::value, - "Template parameter must be derived from IrAttrDefBase"); - return CreateImplWithIrAttrInit(op); - } - std::unique_ptr Clone() override; - private: - static AscNodeAttr *CreateImpl(ge::Operator &op); - template - static AscNodeAttr *CreateImplWithIrAttrInit(ge::Operator &op) { - auto attr_group = CreateImpl(op); - GE_ASSERT_NOTNULL(attr_group); - attr_group->ir_attr = std::move(ComGraphMakeUnique()); - GE_ASSERT_NOTNULL(attr_group->ir_attr); - return attr_group; - } -}; - -class AscDataIrAttrDef : public AscIrAttrDefBase { - // 子类不应该有自己的成员,只需要有对应的set,get函数 - public: - ~AscDataIrAttrDef() override = default; - graphStatus GetIndex(int64_t &index) const; - graphStatus SetIndex(int64_t index); -}; - -enum class AscGraphType : int64_t { - kHintGraph = 0, - kImplGraph, -}; - -class AscGraphAttr : public ge::AttrGroupsBase { - public: - // TODO 待确认正式方案 - int64_t tiling_key = -1; - - // [HI] 图上的轴 - std::vector axis; - - // TODO 待正式方案后删除 - TransInfoRoadOfGraph trans_info_road; - - // [HI] 图上的符号,TODO:未来不需要这个数据结构了,改成Expression即可 - std::vector size_vars; - AscGraphType type{AscGraphType::kHintGraph}; - graphStatus SerializeAttr(ascendc_ir::proto::AscGraphAttrGroupsDef &asc_graph_group); - graphStatus DeserializeAttr(const ascendc_ir::proto::AscGraphAttrGroupsDef &asc_graph_group); - std::unique_ptr Clone() override; - graphStatus Serialize(proto::AttrGroupDef &attr_group_def) override; - graphStatus Deserialize(const proto::AttrGroupDef &attr_group_def, AttrHolder *attr_holder) override; -}; - -class AscTensorDataType { - public: - operator ge::DataType() const { - if (tensor_desc_ == nullptr) { - GELOGE(FAILED, "tensor_desc_ is null"); - return ge::DT_UNDEFINED; - } - return tensor_desc_->GetDataType(); - }; - void operator=(const ge::DataType &other) { - if (tensor_desc_ == nullptr) { - GELOGE(FAILED, "tensor_desc_ is null"); - return; - } - tensor_desc_->SetDataType(other); - }; - AscTensorDataType &operator=(const AscTensorDataType &other) { - if (this == &other) { - return *this; - } - if ((tensor_desc_ != nullptr) && (other.tensor_desc_ != nullptr)) { - tensor_desc_->SetDataType(static_cast(other)); - } - if ((tensor_desc_ == nullptr) && (other.tensor_desc_ != nullptr)) { - // 浅拷贝,兼容已存在的用法,调用者需要保证声明周期有效 - tensor_desc_ = other.tensor_desc_; - } - return *this; - } - AscTensorDataType(const AscTensorDataType &other) { - if ((tensor_desc_ != nullptr) && (other.tensor_desc_ != nullptr)) { - tensor_desc_->SetDataType(static_cast(other)); - } - if ((tensor_desc_ == nullptr) && (other.tensor_desc_ != nullptr)) { - // 浅拷贝,兼容已存在的用法,调用者需要保证声明周期有效 - tensor_desc_ = other.tensor_desc_; - } - } - AscTensorDataType() = default; - private: - friend struct AscNodeOutputs; - friend class AscTensorAttr; - friend class AscGraphUtils; - GeTensorDesc *tensor_desc_{nullptr}; -}; - -class AscTensorAttr : public ge::AttrGroupsBase { - friend class AscGraphUtils; - public: - - // [HI] 该`Tensor`的数据类型 - AscTensorDataType dtype; - - // [HI] 该`Tensor`中包含的轴的`id` - std::vector axis; - - // [HI] `repeat[i]`表示该`Tensor`包含的第`i`个轴的大小的符号表达式 - std::vector repeats; - - // [HI] `stride[i]`表示该`Tensor`包含的第`i`个轴,在索引时的步长 - std::vector strides; - - // [I] `buffer`中存储哪些轴的内容 - std::vector vectorized_axis; - - // [I] `buffer`中存储的内容,按轴索引时的步长 - std::vector vectorized_strides; - MemAttr mem{}; - MemQueAttr que{}; - MemBufAttr buf{}; - MemOptAttr opt{}; - static AscTensorAttr &GetTensorAttr(ge::Operator *op, const uint32_t index); - static AscTensorAttr &GetTensorAttr(const OutDataAnchor &output); - static AscTensorAttr *GetTensorAttrPtr(ge::Operator *op, const uint32_t index); - static AscTensorAttr *GetTensorAttrPtr(const OutDataAnchor &output); - graphStatus SerializeAttr(ascendc_ir::proto::AscTensorAttrGroupsDef &asc_tensor_group); - graphStatus DeserializeAttr(const ascendc_ir::proto::AscTensorAttrGroupsDef &asc_tensor_group, - GeTensorDesc *tensor_desc); - std::unique_ptr Clone() override; - graphStatus Serialize(proto::AttrGroupDef &attr_group_def) override; - graphStatus Deserialize(const proto::AttrGroupDef &attr_group_def, AttrHolder *attr_holder) override; -}; -} // namespace ge - -#endif // METADEF_CXX_ASCENDC_IR_DEF_H diff --git a/inc/graph/ascendc_ir/ascir_register.h b/inc/graph/ascendc_ir/ascir_register.h deleted file mode 100644 index 8dd52a50405ae4b463d949c8201e69c096bce36e..0000000000000000000000000000000000000000 --- a/inc/graph/ascendc_ir/ascir_register.h +++ /dev/null @@ -1,102 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef AUTOFUSE_ASCIR_REGISTER_H -#define AUTOFUSE_ASCIR_REGISTER_H -#include -#include -#include "graph/ascendc_ir/ascir_registry.h" -#include "graph/ascendc_ir/ascendc_ir_core/ascendc_ir.h" -#include "graph/ascendc_ir/ascendc_ir_core/ascendc_ir_def.h" -#include "graph/ascend_string.h" -#include "graph/symbolizer/symbolic.h" - -namespace ge { -namespace ascir { -class AscirRegister { - public: - AscirRegister() = default; - AscirRegister(const char *type, const char *def_file_path, int64_t line); - AscirRegister &Inputs(std::vector &&input_names); - AscirRegister &Input(const char_t *input_name, const char_t *datatype_symbol); - AscirRegister &DataType(const char_t *datatype_symbol, const TensorType &type_range); - AscirRegister &DataType(const char_t *datatype_symbol, const OrderedTensorTypeList &type_range); - AscirRegister &DynamicInput(const std::string &input_name); - AscirRegister &DynamicInput(const char_t *input_name, const char_t *datatype_symbol); - AscirRegister &OptionalInput(const std::string &input_name); - AscirRegister &Outputs(std::vector &&output_names); - AscirRegister &Output(const char_t *output_name, const char_t *datatype_symbol); - AscirRegister &DynamicOutput(const std::string &output_name); - AscirRegister &DynamicOutput(const char_t *output_name, const char_t *datatype_symbol); - AscirRegister &Comment(const std::string &comment); - - template - AscirRegister &Attr(ge::AscendString &&name); - - AscirRegister &InferDataType(AscIrDef::CodeGenerator infer_data_type_generator); - AscirRegister &UseFirstInputDataType() { - const auto &output_defs = ir_def_.GetOutputDefs(); - return DataTypes(std::vector(output_defs.size(), DtypePolicy(0U))); - } - AscirRegister &UseSecondInputDataType() { - const auto &output_defs = ir_def_.GetOutputDefs(); - return DataTypes(std::vector(output_defs.size(), DtypePolicy(1U))); - } - - AscirRegister &InferView(AscIrDef::CodeGenerator infer_view_generator); - AscirRegister &UseFirstInputView() { - const auto &output_defs = ir_def_.GetOutputDefs(); - return Views(std::vector(output_defs.size(), ViewPolicy(0))); - } - - AscirRegister &StartNode(); - AscirRegister &Views(const std::vector &views_policy); - AscirRegister &DataTypes(const std::vector &data_types_policy); - AscirRegister(const AscirRegister &other); - AscirRegister &operator=(const AscirRegister &) = delete; - - AscirRegister(AscirRegister &&) noexcept = delete; - AscirRegister &operator=(AscirRegister &&) noexcept = delete; - - AscirRegister &CalcTmpBufSize(const std::string &calc_tmp_buf_size_func); - AscirRegister &SameTmpBufSizeFromFirstInput(); - - AscirRegister &ApiTilingDataType(const std::string &tiling_data_name); - - AscirRegister &Impl(const std::vector &soc_version, const AscIrImpl &impl); - - AscirRegister &Impl(const std::vector &soc_version, const AscIrImplV2 &impl); - - size_t GetSocImplSize() const; - - private: - AscirRegister &Attr(std::string name, std::string asc_type, std::string ge_type); - - private: - AscIrDef ir_def_; -}; - -#define REG_ASC_IR(type) static auto g_register_##type = ge::ascir::AscirRegister(#type, __FILE__, __LINE__) -#define REG_ASC_IR_START_NODE(type) REG_ASC_IR(type).Inputs({}).Outputs({"y"}).StartNode() -#define REG_ASC_IR_START_NODE_WITH_ATTR(type) \ - REG_ASC_IR(type).Inputs({}).Outputs({"y"}).Attr("index").StartNode() -#define REG_ASC_IR_1IO(type) REG_ASC_IR(type).Input("x", "T").Output("y", "T").DataType("T", TensorType::ALL()) -#define REG_ASC_IR_2I1O(type) \ - REG_ASC_IR(type).Input("x1", "T").Input("x2", "T").Output("y", "T").DataType("T", TensorType::ALL()) - -#define EXPAND_CHAIN_CALL(...) #__VA_ARGS__ -#define REG_ASC_IR_WITH_COMMENT(type, ...) \ - constexpr const char *comment_##type = \ - R"COMMENT(REG_ASC_IR()COMMENT" #type ")\n" EXPAND_CHAIN_CALL(__VA_ARGS__) ";"; \ - static auto g_register_##type = AscirRegister(#type, __FILE__, __LINE__) __VA_ARGS__.Comment(comment_##type) -#define EXPORT_GENERATOR() -} // namespace ascir -} // namespace ge - -#endif // AUTOFUSE_ASCIR_REGISTER_H diff --git a/inc/graph/ascendc_ir/ascir_registry.h b/inc/graph/ascendc_ir/ascir_registry.h deleted file mode 100644 index ef86d126f40b3ccdbfe457e9b9846aa73cc99be2..0000000000000000000000000000000000000000 --- a/inc/graph/ascendc_ir/ascir_registry.h +++ /dev/null @@ -1,467 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef AUTOFUSE_ASCIR_REGISTRY_H -#define AUTOFUSE_ASCIR_REGISTRY_H -#include -#include -#include -#include -#include -#include -#include -#include -#include "ascendc_ir/ascendc_ir_check.h" -#include "external/graph/types.h" -#include "op_desc.h" -#include "ir/ir_data_type_symbol_store.h" -#include "graph/ascendc_ir/ascendc_ir_core/ascendc_ir_def.h" -#include "graph/ascendc_ir/ascendc_ir_core/ascendc_ir.h" - -namespace ge { -namespace ascir { -using ApplyOutputView = std::function; -struct ViewPolicy { - public: - enum ViewType : int64_t { - kElementWise = 0, - kReduce, - kBroadCast, - kInvalid, - }; - ViewPolicy(uint32_t element_wise_input_index) : use_input_index(element_wise_input_index) { - view_type = kElementWise; - } - ViewPolicy(uint32_t reduce_input_index, std::string reduce_axis_name) - : use_input_index(reduce_input_index), reduce_axis_attr_name(std::move(reduce_axis_name)) { - view_type = kReduce; - } - - explicit ViewPolicy(std::vector broad_cast_in_indexs) - : broad_cast_input_indexs(std::move(broad_cast_in_indexs)) { - view_type = kBroadCast; - } - - ViewType view_type{kInvalid}; - uint32_t use_input_index{UINT32_MAX}; - std::string reduce_axis_attr_name; - std::vector broad_cast_input_indexs; -}; - -inline ViewPolicy ReduceView(uint32_t index, const std::string &attr_name) { - return ViewPolicy(index, attr_name); -} -inline ViewPolicy BroadCastView(const std::vector &broad_cast_input_indexs) { - return ViewPolicy(broad_cast_input_indexs); -} - -struct DtypePolicy { - enum PolicyType : int64_t { - kUseInput = 0, - kPromptInput, - kUseDtype, - kInvalid, - }; - - public: - DtypePolicy(uint32_t use_in_index) : use_input_index(use_in_index) { - policy_type = kUseInput; - }; - DtypePolicy(ge::DataType dtype) : data_type(dtype) { - policy_type = kUseDtype; - }; - PolicyType policy_type{kInvalid}; - uint32_t use_input_index{UINT32_MAX}; - ge::DataType data_type{ge::DataType::DT_UNDEFINED}; -}; - -inline DtypePolicy PromptDtype(uint32_t index) { - auto policy = DtypePolicy(index); - policy.policy_type = DtypePolicy::kPromptInput; - return policy; -} -// TODO: c++的类ABI兼容性不好,后面考虑换成C接口实现 -struct AscIrAttrDef { - std::string name; - std::string asc_ir_type; - std::string ge_ir_type; -}; -enum CalcTmpBufSizeFuncType : int64_t { - CommonType = 0, - CustomizeType, -}; -struct CalcTmpBufSizeFunc { - std::string func_name; - CalcTmpBufSizeFuncType func_type = CalcTmpBufSizeFuncType::CommonType; - CalcTmpBufSizeFunc() = default; - CalcTmpBufSizeFunc(std::string name, const CalcTmpBufSizeFuncType type) - : func_name(std::move(name)), func_type(type) {} -}; - -class AscIrCodegen { - public: - virtual std::vector> CalcTmpBufSize(const ge::AscNode &node) { - (void) node; - return std::vector>(); - } - virtual std::string GetApiTilingTypeName() const { - return ""; - } - - virtual uint32_t GetInstNum() const { - return 0U; - } - - // 返回api call类的名称 - virtual std::string GetApiCallName() const { - return ""; - } - - // 返回api的名称 - virtual std::string GetApiName() const { - return ""; - } - - // 返回api call类的名称 - virtual std::string GetMicroApiCallName() const { - return ""; - } - - // 返回api的名称 - virtual std::string GetMicroApiName() const { - return ""; - } - - // 微指令api包含的微指令的条数, 如果支持vector function, 该接口返回值才有意义 - virtual uint32_t GetMicroInstNum() const { - return 1U; - } - - // 返回需要加载的头文件 - virtual std::vector LoadApiHeaderFiles() const { - return std::vector(); - } - - virtual bool IsVectorFunctionSupported(const ge::AscNode &node) const { - (void)node; - return false; - } - virtual bool IsScalarInputSupported(const std::vector &is_scalar_list) const { - (void)is_scalar_list; - return false; - } - virtual bool IsScalarInputSupportedIfExchangeInputs(const std::vector &is_scalar_list) const { - (void)is_scalar_list; - return false; - } - - virtual bool IsInplaceSupported(const ge::AscNode &node) const { - (void)node; - return false; - } - - virtual bool IsBrcInlineSupported(const ge::AscNode &node) const { - (void)node; - return false; - } -}; - -class AscIrAtt { - public: - // 最内轴建议对齐值(默认32B对齐) - virtual uint32_t GetInnerDimPromptAlignSize() const { - return 32U; - } - // 最外轴建议对齐值(默认为1,表示对外轴无对齐要求) - virtual uint32_t GetOuterDimPromptAlignSize() const { - return 1U; - } - // 返回ASCIR API接口性能公式函数(不同硬件的ASCIR实现存在差异,性能公式形式存在差异) - virtual void *GetApiPerf() const = 0; - // 返回MicroApi性能函数公式(不同硬件的ASCIR的vf function实现存在差异) - virtual void *GetMicroApiPerf() const = 0; - // 返回AscendCApi的性能公式(不同硬件的性能公式参数存在差异) - virtual void *GetAscendCApiPerfTable() const = 0; -}; - -template -std::function()> AscIrImplCreator() { - return []() { return std::unique_ptr(new T()); }; -} - -using AscIrAttCreator = std::function()>; -using AscIrCodegenCreator = std::function()>; - -struct AscIrImpl { - AscIrAttCreator att; - AscIrCodegenCreator codegen; - std::vector> support_dtypes; -}; - -struct AscIrImplV2 { - AscIrAttCreator att; - AscIrCodegenCreator codegen; - std::vector> support_dtypes; -}; - -struct AscIrDefImpl; -class AscIrDef { - public: - AscIrDef(); - using CodeGenerator = void (*)(const AscIrDef &def, std::stringstream &ss); - bool IsAttrExisted(const std::string &attr_name) const; - - void Init(const char *type, const char *def_file_path, int64_t line) const; - - const std::vector> &GetInputDefs() const; - const std::vector> &GetOutputDefs() const; - bool HasDynamicOutput() const; - void AppendInput(const string &name, ge::IrInputType type) const; - void AppendOutput(const string &name, ge::IrOutputType type) const; - void StoreInputIrSymName(const std::string &ir_name, const std::string &sym_name) const; - void StoreOutputIrSymName(const std::string &ir_name, const std::string &sym_name) const; - const std::string &GetType() const; - void StartNode() const; - bool IsStartNode() const; - void SetAttr(const std::string &name, const std::string &asc_type, const std::string &ge_type) const; - void SetDtypePolicy(const std::vector &output_dtypes_policy) const; - const std::vector &GetOutputDtypePolicy() const; - void SetViewPolicy(const std::vector &view_policy) const; - const std::vector &GetViewPolicy() const; - void SetApiTilingDataName(const std::string &tiling_data_name) const; - const string &GetApiTilingDataName() const; - void SetCalcTmpBufSizeFunc(const std::string &calc_tmp_buf_size_func, CalcTmpBufSizeFuncType type) const; - const CalcTmpBufSizeFunc &GetCalcTmpBufSizeFunc() const; - const std::vector &GetAttrDefs() const; - std::vector &MutableAttrDefs() const; - void SetComment(const string &comment) const; - const string &GetComment() const; - const std::string &GetFilePath() const; - int64_t GetLine() const; - IRDataTypeSymbolStore &MutableDataTypeSymbolStore() const; - const IRDataTypeSymbolStore &GetDataTypeSymbolStore() const; - const std::map &GetSocToDataTypeSymbolStore() const; - void AddSocImpl(const std::vector &soc_versions, const AscIrImpl &impl) const; - void AddSocImplV2(const std::vector &soc_versions, const AscIrImplV2 &impl) const; - void AppendSocImpl(const AscIrDef &ir_def) const; - size_t GetSocImplSize() const; - CodeGenerator infer_data_type_generator{nullptr}; - CodeGenerator infer_view_generator{nullptr}; - std::unique_ptr GetAscIrAttImpl(const std::string &soc_version); - std::unique_ptr GetAscIrCodegenImpl(const std::string &soc_version); - - private: - friend class AscirRegister; - std::shared_ptr impl_; -}; - -inline std::string UpdateViewIfCrossLoop(const std::string &trans_infos, const std::string &input_api_sched_axis, - const std::string &op_attr_sched_axis, const std::string &tie_expression) { - return "AxisUtils::UpdateViewIfCrossLoop(" + trans_infos + ", " + input_api_sched_axis + ", " + op_attr_sched_axis + - ", " + "std::move(" + tie_expression + "))"; -} - -inline void GenChosenInputView(const AscIrDef &def, const uint32_t chosen_input_index, std::stringstream &ss) { - const auto &input_defs = def.GetInputDefs(); - ss << input_defs[chosen_input_index].first + "_tmp = " << "{*" << input_defs[chosen_input_index].first - << "_in.axis, *" << input_defs[chosen_input_index].first << "_in.repeats, *" - << input_defs[chosen_input_index].first << "_in.strides};" << std::endl; -} -template -void GenErrorIfPolicyInvalid(Policy policy, size_t range, std::stringstream &ss) { - if (policy.use_input_index < range) { - return; - } - ss << "Policy is invalid as use_input_index :" << policy.use_input_index - << " should be less than input size:" << range << std::endl; -} - -inline void DefineChosenInputView(const AscIrDef &def, const ViewPolicy &policy, uint32_t &chosen_input_index, - std::unordered_set &chosen_input_index_set, std::stringstream &ss) { - const auto &input_defs = def.GetInputDefs(); - ss << " // set tmp view to store input view and apply view transform" << std::endl; - const std::string view_type("View "); - GenErrorIfPolicyInvalid(policy, input_defs.size(), ss); - chosen_input_index = policy.use_input_index; - ss << " "; - if (chosen_input_index_set.insert(chosen_input_index).second) { - ss << view_type; - } - GenChosenInputView(def, chosen_input_index, ss); -} - -inline void SameDataTypeFromInput(const AscIrDef &def, std::stringstream &ss, const char *input_name) { - const auto &output_defs = def.GetOutputDefs(); - for (const auto &output_def : output_defs) { - ss << " op." << output_def.first << ".dtype = static_cast(" << input_name << "_in.dtype);" - << std::endl; - } -} - -inline void GenerateViewUpdateCode(const AscIrDef &def, const std::pair out_to_chosen_input, - const ApplyOutputView &apply_output_view, std::stringstream &ss, - bool &gen_trans_infos_instance) { - const auto &input_defs = def.GetInputDefs(); - const auto &output_defs = def.GetOutputDefs(); - const size_t output_index = out_to_chosen_input.first; - const size_t chosen_input_index = out_to_chosen_input.second; - if (!gen_trans_infos_instance) { - ss << " auto trans_infos = CodeGenUtils::GetOwnerGraphAscAttr(op." << output_defs[output_index].first - << ".GetOwnerOp())" << "->trans_info_road;" << std::endl; - gen_trans_infos_instance = true; - } - - const std::string which_input_api_sched_axis = output_defs[output_index].first + "_in_api_sched_axis"; - ss << " auto " << which_input_api_sched_axis << " = CodeGenUtils::GetOwnerOpAscAttr(" - << input_defs[chosen_input_index].first << "_in.GetOwnerOp())" - << "->sched.axis;" << std::endl; - std::string view = input_defs[chosen_input_index].first + "_tmp"; - ss << " {" << std::endl << " const auto &[axes, repeats, strides] = "; - std::string val = - UpdateViewIfCrossLoop("trans_infos", which_input_api_sched_axis, "op.attr.sched.axis", view).append(".second"); - // 应用输出的语义变换 - if (!(apply_output_view == nullptr)) { - ss << apply_output_view(val) << ";" << std::endl; - } else { - ss << val << ";" << std::endl; - } - ss << " std::tie(*op." << output_defs[output_index].first << ".axis, *op." << output_defs[output_index].first - << ".repeats, *op." << output_defs[output_index].first << ".strides) = std::make_tuple(axes, repeats, strides);" - << std::endl - << " }" << std::endl; -} - -inline ApplyOutputView GenApplyOutputViewFunc(const AscIrDef &def, const size_t output_index, - uint32_t &chosen_input_index, std::stringstream &ss) { - (void) chosen_input_index; - const auto &output_views_policy = def.GetViewPolicy(); - const auto &policy = output_views_policy[output_index]; - ApplyOutputView apply_output_view; - switch (policy.view_type) { - case ViewPolicy::kElementWise: - break; - case ViewPolicy::kReduce: - if (!def.IsAttrExisted(policy.reduce_axis_attr_name)) { - return apply_output_view; - } - apply_output_view = [&output_views_policy, output_index](const std::string &var) -> std::string { - return "AxisUtils::ReduceView(" + var + ", " + output_views_policy[output_index].reduce_axis_attr_name + ")"; - }; - break; - case ViewPolicy::kBroadCast: // TODO 广播代码后续支持 - case ViewPolicy::kInvalid: - default: - ss << "unsupported policy type: " << policy.view_type << std::endl; - break; - } - return apply_output_view; -} - -inline void InferViewByPolicy(const AscIrDef &def, std::stringstream &ss) { - const auto &output_defs = def.GetOutputDefs(); - const auto &output_views_policy = def.GetViewPolicy(); - if (output_defs.size() != output_views_policy.size()) { - std::string error_info = std::string("view_policy's size ") - .append(std::to_string(output_views_policy.size())) - .append(" should be equal with output_defs's size ") - .append(std::to_string(output_defs.size())); - ss << error_info; - return; - } - bool gen_trans_infos_instance = false; - std::unordered_set chosen_input_index_set; - for (size_t output_index = 0U; output_index < output_views_policy.size(); ++output_index) { - uint32_t chosen_input_index = 0U; - DefineChosenInputView(def, output_views_policy[output_index], chosen_input_index, chosen_input_index_set, ss); - GenerateViewUpdateCode(def, std::make_pair(output_index, chosen_input_index), - GenApplyOutputViewFunc(def, output_index, chosen_input_index, ss), ss, - gen_trans_infos_instance); - } -} - -inline void InferDtypeByPolicy(const AscIrDef &def, std::stringstream &ss) { - const auto &output_defs = def.GetOutputDefs(); - const auto &output_dtypes_policy = def.GetOutputDtypePolicy(); - if (output_defs.size() != output_dtypes_policy.size()) { - std::string error_info = std::string("dtype_policy's size ") - .append(std::to_string(output_dtypes_policy.size())) - .append("should be equal with output_defs's size ") - .append(std::to_string(output_defs.size())); - ss << error_info; - return; - } - const auto &input_defs = def.GetInputDefs(); - for (size_t output_index = 0U; output_index < output_dtypes_policy.size(); ++output_index) { - const auto &policy = output_dtypes_policy[output_index]; - switch (policy.policy_type) { - case DtypePolicy::kUseInput: - GenErrorIfPolicyInvalid(policy, input_defs.size(), ss); - ss << " op." << output_defs[output_index].first << ".dtype = static_cast(" - << input_defs[policy.use_input_index].first << "_in.dtype);" << std::endl; - break; - case DtypePolicy::kPromptInput: - GenErrorIfPolicyInvalid(policy, input_defs.size(), ss); - ss << " op." << output_defs[output_index].first - << ".dtype = DtypeTransformUtils::Prompt(static_cast(" - << input_defs[policy.use_input_index].first << "_in.dtype));" << std::endl; - break; - case DtypePolicy::kUseDtype: - ss << " op." << output_defs[output_index].first << ".dtype = static_cast(" << policy.data_type - << ");" << std::endl; - break; - case DtypePolicy::kInvalid: - default: - ss << "unsupported policy type: " << policy.policy_type << std::endl; - } - } -} - -inline void SameDataTypeFromFirstInput(const AscIrDef &def, std::stringstream &ss) { - const auto &input_defs = def.GetInputDefs(); - if (!input_defs.empty()) { - SameDataTypeFromInput(def, ss, input_defs[0].first.c_str()); - } -} -inline void SameDataTypeFromSecondInput(const AscIrDef &def, std::stringstream &ss) { - const auto &input_defs = def.GetInputDefs(); - if (input_defs.size() > 1U) { - SameDataTypeFromInput(def, ss, input_defs[1].first.c_str()); - } -} -inline void SameViewFromInput(const AscIrDef &def, std::stringstream &ss, const char *input_name) { - const auto &output_defs = def.GetOutputDefs(); - for (const auto &output_def : output_defs) { - ss << " op." << output_def.first << ".axis = " << input_name << "_in.axis;" << std::endl; - ss << " op." << output_def.first << ".repeats = " << input_name << "_in.repeats;" << std::endl; - ss << " op." << output_def.first << ".strides = " << input_name << "_in.strides;" << std::endl; - } -} -inline void SameViewFromFirstInput(const AscIrDef &def, std::stringstream &ss) { - const auto &input_defs = def.GetInputDefs(); - if (!input_defs.empty()) { - SameViewFromInput(def, ss, input_defs[0].first.c_str()); - } -} - -class AscirRegistry { - public: - static AscirRegistry &GetInstance(); - void RegisterAscIr(const std::string &type, const AscIrDef &def); - - const std::unordered_map &GetAll() const; - std::unique_ptr GetIrAttImpl(const std::string &soc_version, const std::string &type); - std::unique_ptr GetIrCodegenImpl(const std::string &soc_version, const std::string &type); - void ClearAll(); - - private: - std::unordered_map types_to_ascir_; -}; -} // namespace ascir -} // namespace ge -#endif // AUTOFUSE_ASCIR_REGISTRY_H diff --git a/inc/graph/ascendc_ir/utils/asc_graph_utils.h b/inc/graph/ascendc_ir/utils/asc_graph_utils.h deleted file mode 100644 index 99330502dcbef12d66bbe9ff0abfbf847010fea7..0000000000000000000000000000000000000000 --- a/inc/graph/ascendc_ir/utils/asc_graph_utils.h +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_ASC_GRAPH_UTILS_H -#define METADEF_CXX_ASC_GRAPH_UTILS_H - -#include "ascendc_ir/ascendc_ir_core/ascendc_ir.h" -#include "proto/ascendc_ir.pb.h" - -namespace ge { -class AscGraphUtils { - public: - static ComputeGraphPtr GetComputeGraph(const AscGraph &asc_graph); - static Status FromComputeGraph(const ge::ComputeGraphPtr &compute_graph, ge::AscGraph &graph); - /** - * @param compute_graph的node对象是Node类型时候,接口内部转换为AscNode - * @return - */ - static graphStatus ConvertComputeGraphToAscGraph(const ComputeGraphPtr &compute_graph, AscGraph &asc_graph); - static graphStatus SerializeToBinary(const AscGraph &asc_graph, std::string &output); - static graphStatus SerializeToReadable(const AscGraph &asc_graph, std::string &output); - static graphStatus SerializeToProto(const AscGraph &asc_graph, ascendc_ir::proto::AscGraphDef &asc_graph_def); - static graphStatus DeserializeFromBinary(const std::string &to_be_deserialized, AscGraph &out_asc_graph); - static graphStatus DeserializeFromReadable(const std::string &to_be_deserialized, AscGraph &out_asc_graph); - static graphStatus DeserializeFromProto(const ascendc_ir::proto::AscGraphDef &asc_graph_def, AscGraph &asc_graph); -}; -class AscNodeSerializeUtils { - public: - static graphStatus SerializeIrDef(const AscNode &node, ascendc_ir::proto::IrDef &ir_def); - static graphStatus SerializeAttrGroupsDef(const AscNode &node, - ascendc_ir::proto::AscNodeAttrGroupsDef &asc_node_attr_groups_def); -}; - -class AscNodeDeserializeUtils { - public: - static graphStatus DeserializeIrDef(const ascendc_ir::proto::IrDef &ir_def, AscNode &node); - static graphStatus DeserializeAttrGroupsDef(const ascendc_ir::proto::AscNodeAttrGroupsDef &asc_node_attr_groups_def, - AscNode &node); -}; -class ExpressionSerializer : public GeIrAttrSerializer { - public: - ExpressionSerializer() = default; - graphStatus Serialize(const AnyValue &av, proto::AttrDef &def) override; - graphStatus Deserialize(const proto::AttrDef &def, AnyValue &av) override; -}; -} // namespace ge - -#endif // METADEF_CXX_ASC_GRAPH_UTILS_H diff --git a/inc/graph/ascendc_ir/utils/asc_tensor_utils.h b/inc/graph/ascendc_ir/utils/asc_tensor_utils.h deleted file mode 100644 index b5c1a30230077ab64c33a60ecca12bb8a35b9f6a..0000000000000000000000000000000000000000 --- a/inc/graph/ascendc_ir/utils/asc_tensor_utils.h +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_ASC_TENSOR_UTILS_H -#define METADEF_CXX_ASC_TENSOR_UTILS_H - -#include "ascendc_ir/ascendc_ir_core/ascendc_ir.h" - -namespace ge { -namespace ascir { -class AscTensorUtils { - public: - static bool IsConstTensor(const AscTensor &t); - static Node *GetOwner(const AscTensor &t); - static int32_t Index(const AscTensor &t); -}; -} -} // namespace ge - -#endif // METADEF_CXX_ASC_TENSOR_UTILS_H diff --git a/inc/graph/ascendc_ir/utils/ascendc_ir_dump_utils.h b/inc/graph/ascendc_ir/utils/ascendc_ir_dump_utils.h deleted file mode 100644 index 41506468f43894585ecb0017564bc814eff17f6e..0000000000000000000000000000000000000000 --- a/inc/graph/ascendc_ir/utils/ascendc_ir_dump_utils.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include "ascendc_ir/ascendc_ir_core/ascendc_ir.h" -#include "inc/external/graph/utils/type_utils.h" -namespace ge { -class DumpAscirGraph { - public: - static std::string DumpGraph(AscGraph &graph); - static void WriteOutToFile(const std::string &filename, AscGraph &graph); - private: - static std::stringstream &TilingKeyStr(std::stringstream &ss, AscGraph &graph); - static std::stringstream &NameStr(std::stringstream &ss, AscGraph &graph); - static std::stringstream &AllAxisStr(std::stringstream &ss, AscGraph &graph); - static std::stringstream &AscNodeAttrStr(std::stringstream &ss, AscNodeAttr &attr); - static std::stringstream &AscTensorAttrStr(std::stringstream &ss, AscTensorAttr *attr); - static std::stringstream &MemAttrStr(std::stringstream &ss, AscTensorAttr *attr); - static std::stringstream &MemQueueAttrStr(std::stringstream &ss, AscTensorAttr *attr); - static std::stringstream &MemBufAttrStr(std::stringstream &ss, AscTensorAttr *attr); - static std::stringstream &MemOptAttrStr(std::stringstream &ss, AscTensorAttr *attr); - static std::stringstream &NodesStr(std::stringstream &ss, ge::AscNodeVisitor &nodes); - static std::string ApiTypeToString(ge::ApiType type); - static std::string ComputUnitToString(ge::ComputeUnit unit); - static std::string ComputeTypeToString(ge::ComputeType type); - static std::string AllocTypeToString(ge::AllocType type); - static std::string PositionToString(ge::Position position); - static std::string HardwareToString(ge::MemHardware hardware); -}; -} // namespace ge \ No newline at end of file diff --git a/inc/graph/ascendc_ir/utils/cg_calc_tmp_buff_common_funcs.h b/inc/graph/ascendc_ir/utils/cg_calc_tmp_buff_common_funcs.h deleted file mode 100644 index 5c53dbd75920b80269a4d26cb0df428f4292460b..0000000000000000000000000000000000000000 --- a/inc/graph/ascendc_ir/utils/cg_calc_tmp_buff_common_funcs.h +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. -* This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_CG_CALC_TMP_BUFF_COMMON_FUNCS_H -#define METADEF_CXX_CG_CALC_TMP_BUFF_COMMON_FUNCS_H - -#include "ascendc_ir/ascendc_ir_core/ascendc_ir.h" -#include "graph/symbolizer/symbolic.h" - -inline std::vector> SameTmpBufSizeWithFirstInput(const ge::AscNode &node) { - std::vector> tmp_buf_descs; - ge::AscNodeInputs node_inputs = node.inputs; - if (node_inputs.Size() <= 0) { - return tmp_buf_descs; - } - auto expr = ge::Expression(ge::Symbol(ge::GetSizeByDataType(node_inputs[0].attr.dtype))); - for (const auto &repeat : node_inputs[0].attr.repeats) { - expr = ge::sym::Mul(expr, repeat); - } - tmp_buf_descs.emplace_back(std::make_unique(ge::TmpBufDesc{expr})); - return tmp_buf_descs; -} - - -#endif // METADEF_CXX_CG_CALC_TMP_BUFF_COMMON_FUNCS_H diff --git a/inc/graph/attr_value_serializable.h b/inc/graph/attr_value_serializable.h deleted file mode 100644 index ee065807ead3d70b141dc57ad65a2c8e32b2622a..0000000000000000000000000000000000000000 --- a/inc/graph/attr_value_serializable.h +++ /dev/null @@ -1,17 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_ -#define INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_ - -#include -#include "graph/ge_attr_value.h" -#include "graph/compiler_options.h" - -#endif // INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_ diff --git a/inc/graph/attribute_group/attr_group_serialize.h b/inc/graph/attribute_group/attr_group_serialize.h deleted file mode 100644 index 13cd10df2af60815bc5fb4664372f2242ca4d51d..0000000000000000000000000000000000000000 --- a/inc/graph/attribute_group/attr_group_serialize.h +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ - -#ifndef INC_GRAPH_ATTR_GROUP_SERIALIZE_H -#define INC_GRAPH_ATTR_GROUP_SERIALIZE_H - -#include "graph/ge_error_codes.h" -#include "graph/attr_store.h" -#include "graph/detail/attributes_holder.h" -#include "proto/ge_ir.pb.h" - -namespace ge { -namespace proto { -class AttrGroups; -} -class AttrGroupSerialize { - public: - static graphStatus SerializeAllAttr(proto::AttrGroups &attr_groups, const AttrStore &attr_store); - static graphStatus DeserializeAllAttr(const proto::AttrGroups &attr_group, AttrHolder *attr_holder); - - private: - static graphStatus OtherGroupSerialize(proto::AttrGroups &attr_groups, const AttrStore &attr_store); - static graphStatus OtherGroupDeserialize(const proto::AttrGroups &attr_groups, AttrStore &attr_store) ; -}; -} - -#endif // INC_GRAPH_ATTR_GROUP_SERIALIZE_H diff --git a/inc/graph/attribute_group/attr_group_serializer_registry.h b/inc/graph/attribute_group/attr_group_serializer_registry.h deleted file mode 100644 index 09355814b49a6b4624e7669c8a2e3dc60664e28d..0000000000000000000000000000000000000000 --- a/inc/graph/attribute_group/attr_group_serializer_registry.h +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., Ltd. 2025. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ - -#ifndef METADEF_CXX_INC_GRAPH_ATTRIBUTE_GROUP_ATTR_GROUP_SERIALIZER_REGISTRY_H_ -#define METADEF_CXX_INC_GRAPH_ATTRIBUTE_GROUP_ATTR_GROUP_SERIALIZER_REGISTRY_H_ -#include -#include -#include -#include - -#include "graph/attribute_group/attr_group_base.h" -#include "proto/ge_ir.pb.h" - -#define REG_ATTR_GROUP_SERIALIZER(serializer_name, cls, obj_type, bin_type) \ - REG_ATTR_GROUP_SERIALIZER_BUILDER_UNIQ_HELPER(serializer_name, __COUNTER__, cls, obj_type, bin_type) - -#define REG_ATTR_GROUP_SERIALIZER_BUILDER_UNIQ_HELPER(name, ctr, cls, obj_type, bin_type) \ - REG_ATTR_GROUP_SERIALIZER_BUILDER_UNIQ(name, ctr, cls, obj_type, bin_type) - -#define REG_ATTR_GROUP_SERIALIZER_BUILDER_UNIQ(name, ctr, cls, obj_type, bin_type) \ - static ::ge::AttrGroupSerializerRegister register_serialize_##name##ctr \ - __attribute__((unused)) = \ - ::ge::AttrGroupSerializerRegister([]()->std::unique_ptr{ \ - return std::unique_ptr(new(std::nothrow)cls()); \ - }, obj_type, bin_type) - -namespace ge { -template -struct HashedPointer { - explicit HashedPointer(const T *ptr) : hash_value(std::hash{}(ptr)) {} - size_t hash_value; - std::string ToString() const { - return "Hashed_" + std::to_string(hash_value); - } -}; -using AttrGroupSerializeBuilder = std::function()>; -struct AttrGroupDeserializer { - AttrGroupDeserializer(std::unique_ptr impl_obj, TypeId id_obj) - : impl(std::move(impl_obj)), id(id_obj) {} - std::unique_ptr impl{nullptr}; - TypeId id{nullptr}; -}; -class AttrGroupSerializerRegistry { - public: - AttrGroupSerializerRegistry(const AttrGroupSerializerRegistry &) = delete; - AttrGroupSerializerRegistry(AttrGroupSerializerRegistry &&) = delete; - AttrGroupSerializerRegistry &operator=(const AttrGroupSerializerRegistry &) = delete; - AttrGroupSerializerRegistry &operator=(AttrGroupSerializerRegistry &&) = delete; - - ~AttrGroupSerializerRegistry() = default; - - static AttrGroupSerializerRegistry &GetInstance(); - /** - * 注册一个Attr Group的序列化、反序列化handler - * @param builder 调用该builder时,返回一个handler的实例 - * @param obj_type 内存中的数据类型,可以通过`GetTypeId`函数获得 - * @param proto_type protobuf数据类型枚举值 - */ - void RegisterAttrGroupSerialize(const AttrGroupSerializeBuilder &builder, - const TypeId obj_type, - const proto::AttrGroupDef::AttrGroupCase proto_type); - - std::unique_ptr GetSerializer(const TypeId obj_type); - AttrGroupDeserializer GetDeserializer(const proto::AttrGroupDef::AttrGroupCase proto_type); - - private: - AttrGroupSerializerRegistry() = default; - - std::mutex mutex_; - std::map serializer_builder_map_; - std::map> deserializer_builder_map_; -}; - -class AttrGroupSerializerRegister { - public: - AttrGroupSerializerRegister(const AttrGroupSerializeBuilder builder, - TypeId const obj_type, - const proto::AttrGroupDef::AttrGroupCase proto_type) noexcept; - ~AttrGroupSerializerRegister() = default; -}; -} // namespace ge -#endif // METADEF_CXX_INC_GRAPH_ATTRIBUTE_GROUP_ATTR_GROUP_SERIALIZER_REGISTRY_H_ diff --git a/inc/graph/attribute_group/attr_group_shape_env.h b/inc/graph/attribute_group/attr_group_shape_env.h deleted file mode 100644 index 9ab0145fc8d0e5bafb94a0aef76209fc1fee3554..0000000000000000000000000000000000000000 --- a/inc/graph/attribute_group/attr_group_shape_env.h +++ /dev/null @@ -1,227 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ - -#ifndef INC_GRAPH_ATTR_GROUP_ATTR_GROUP_SHAPE_ENV_H -#define INC_GRAPH_ATTR_GROUP_ATTR_GROUP_SHAPE_ENV_H - -#include -#include -#include -#include "attr_group_base.h" -#include "graph/symbolizer/symbolic.h" -#include "graph/symbolizer/symbolic_utils.h" -#include "common/checker.h" - -namespace ge { -namespace proto { -class AttrGroupDef; -class ShapeEnvAttrGroupsDef; -} - -ShapeEnvAttr *GetCurShapeEnvContext(); -void SetCurShapeEnvContext(ShapeEnvAttr *shape_env); - -class Source { - public: - virtual ~Source() = default; - // 目的是用于codegen GetAllSym时生成符号来源 - virtual std::string GetSourceStr() const = 0; - // 目的是用于codegen infershape\tiling时生成符号索引 - virtual std::string GetGlobalIndexStr() const; - - virtual size_t GetGlobalIndex() const { - return global_index_; - } - void SetGlobalIndex(size_t global_index) { - global_index_ = global_index; - }; - // todoo: 兼容上库,待删除 - virtual int32_t GetInputDataIdx() const { - return std::numeric_limits::max(); - }; - virtual size_t GetDimIdx() const { - return std::numeric_limits::max(); - } -private: - size_t global_index_{std::numeric_limits::max()}; -}; -using SourcePtr = std::shared_ptr; - -// todoo: 当前兼容ascgen仓使用,待删除 -struct HashSymbol { - size_t operator()(const Expression &e) const { - int64_t value_int = 0L; - double value_float = 0.0f; - bool value_bool = false; - switch (e.GetExprType()) { - case ExprType::kExprConstantBoolean: - GE_ASSERT_TRUE(e.GetConstValue(value_bool)); - return std::hash()(value_bool); - case ExprType::kExprConstantInteger: - GE_ASSERT_TRUE(e.GetConstValue(value_int)); - return std::hash()(value_int); - case ExprType::kExprConstantRealDouble: - case ExprType::kExprConstantRation: - GE_ASSERT_TRUE(e.GetConstValue(value_float)); - return std::hash()(value_float); - default: - return std::hash()(std::string(e.Serialize().get())); - } - } -}; -struct SymbolCheckInfo { - ge::Expression expr; - std::string file; - int64_t line{}; - std::string dfx_info; - explicit SymbolCheckInfo(const ge::Expression &in_expr, - const std::string &in_file = "", const int64_t in_line = -1, const std::string &dfx = "") - : expr(in_expr), file(in_file), line(in_line), dfx_info(dfx) {} - SymbolCheckInfo() = default; - bool operator==(const SymbolCheckInfo &other) const { - return this->expr == other.expr; - } -}; - -struct SymbolCheckInfoKeyLess { - bool operator()(const SymbolCheckInfo &a, const SymbolCheckInfo &b) const { - // 只比较expr, file与line暂不比较 - return a.expr.Compare(b.expr) < 0; - } -}; - -// 配置符号的生成方式 -// dynamic:不管hint值是否相等,均生成新符号 -// duck:当hint值相同时,则不生成新符号,使用之前生成过的符号 -// static:根据hint值生成符号,同时添加一个Assert(sym == hint)的guard -enum class DynamicMode { - kDynamic = 0, - kDuck = 1, - kStatic = 2, - kEnd = 3 -}; - -struct ShapeEnvSetting { - bool specialize_zero_one{false}; - DynamicMode dynamic_mode{DynamicMode::kDynamic}; - ShapeEnvSetting() = default; - ShapeEnvSetting(const bool in_specialize_zero_one, const DynamicMode &in_dynamic_mode) - : specialize_zero_one(in_specialize_zero_one), dynamic_mode(in_dynamic_mode) {} -}; - -struct Replacement { - ge::Expression replace_expr; - int32_t rank; - bool has_replace; - Replacement(const ge::Expression &a, const int32_t in_rank, bool in_has_replace = false) - : replace_expr(a), rank(in_rank), has_replace(in_has_replace) {} - Replacement() : rank(0), has_replace(false) {} - bool operator<=(const Replacement &other); -}; - -class ShapeEnvAttr : public AttrGroupsBase { -public: - ShapeEnvAttr() = default; - ~ShapeEnvAttr() override = default; - explicit ShapeEnvAttr(const ShapeEnvSetting &shape_env_setting) : shape_env_setting_(shape_env_setting) {} - - ShapeEnvAttr(const ShapeEnvAttr& other); - ShapeEnvAttr &operator=(const ShapeEnvAttr& other); - graphStatus Serialize(proto::AttrGroupDef &attr_group_def) override; - graphStatus Deserialize(const proto::AttrGroupDef &attr_group_def, AttrHolder *attr_holder) override; - // 只支持int32,uint32, int64, uint64 - template - typename std::enable_if::value, Symbol>::type CreateSymbol(T hint, const SourcePtr &source) { - const std::lock_guard lk(mutex_); - auto hint_int64 = static_cast(hint); - GE_ASSERT_TRUE((shape_env_setting_.dynamic_mode >= DynamicMode::kDynamic) && - (shape_env_setting_.dynamic_mode < DynamicMode::kEnd), - "Invalid dynamic mode: %d, create symbol failed", shape_env_setting_.dynamic_mode); - if (shape_env_setting_.specialize_zero_one && ((hint_int64 == 0) || (hint_int64 == 1))) { - GELOGI("Create symbol %d for in specialize_zero_one mode, source: %s", hint_int64, - source->GetSourceStr().c_str()); - return Symbol(hint_int64); - } - if (shape_env_setting_.dynamic_mode != DynamicMode::kDynamic) { - // 非动态模式,hint值相同使用同一个符号 - const auto &iter = value_to_symbol_.find(hint_int64); - if (iter != value_to_symbol_.end()) { - GE_ASSERT_TRUE(!iter->second.empty()); - return Symbol(iter->second.front().Serialize().get()); - } - } - auto global_index = unique_sym_id_++; - GE_ASSERT_TRUE(global_index < std::numeric_limits::max(), - "unique_sym_id_ is " PRIu64 ". will reach the maximum value of uint64.", unique_sym_id_); - source->SetGlobalIndex(global_index); - const std::string sym_name = "s" + std::to_string(global_index); - auto sym = Symbol(sym_name.c_str()); - symbol_to_source_.emplace(sym, source); - symbol_to_value_.emplace(sym, hint_int64); - const auto iter = value_to_symbol_.find(hint_int64); - if (iter != value_to_symbol_.end()) { - iter->second.emplace_back(sym); - } else { - std::vector syms = {sym}; - value_to_symbol_.emplace(hint_int64, syms); - } - // 静态场景需要增加一个s == hint的Assert信息 - if (shape_env_setting_.dynamic_mode == DynamicMode::kStatic) { - ASSERT_SYMBOL_EQ(sym, Symbol(hint_int64)); - } - return sym; - } - std::vector> GetAllSym2Src(); - - ge::Expression Simplify(const ge::Expression &expr); - void SimplifySymbolCheckInfo(); - ge::Expression EvaluateExpr(const ge::Expression &expr); - graphStatus AppendReplacement(const ge::Expression &expr1, const ge::Expression &expr2); - graphStatus AppendSymbolAssertInfo(const ge::Expression &expr, - const std::string &file = "", const int64_t line = 0L); - graphStatus AppendSymbolCheckInfo(const ge::Expression &expr, - const std::string &file = "", const int64_t line = 0L); - const std::vector GetAllSymbolCheckInfos() const; - const std::vector GetAllSymbolAssertInfos() const; - bool HasSymbolCheckInfo(const ge::Expression &expr) const; - bool HasSymbolAssertInfo(const ge::Expression &expr) const; - TriBool HasSymbolInfo(const ge::Expression &expr) const; - std::unique_ptr Clone() override; - void SetGuardDfxContextInfo(const std::string &guard_dfx_info); - void ClearGuardDfxContextInfo(); - private: - void SimplifySymbolCheckInfo(std::set &symbol_check_infos); - void AppendInitReplacement(const ge::Expression &expr); - ge::Expression FindReplacements(const ge::Expression &expr); - graphStatus MergeReplacement(const ge::Expression &expr1, const ge::Expression &expr2); - graphStatus FindRootExpr(const ge::Expression &expr, ge::Expression &root_expr); - graphStatus SerializeSymbolCheckInfos(proto::ShapeEnvAttrGroupsDef *shape_env_group); - graphStatus MergePath(); - graphStatus SerializeSymbolInfo(proto::ShapeEnvAttrGroupsDef *shape_env_group); - graphStatus DeserializeSymbolInfo(const proto::ShapeEnvAttrGroupsDef &shape_env_group); - graphStatus DeserializeSymbolCheckInfos(const proto::ShapeEnvAttrGroupsDef &shape_env_group); - std::string GetGuardDfxContextInfo() const; - using UMapExprReplacement = std::unordered_map; - using UMapExprInt = std::unordered_map; - using UMapExprSource= std::unordered_map; - UMapExprReplacement replacements_; - UMapExprInt symbol_to_value_; - UMapExprSource symbol_to_source_; - std::map> value_to_symbol_; - std::set symbol_check_infos_; - std::set symbol_assert_infos_; - ShapeEnvSetting shape_env_setting_; - size_t unique_sym_id_{0U}; - std::mutex mutex_; - thread_local static std::string guard_dfx_info_; -}; - -} - -#endif // INC_GRAPH_ATTR_GROUP_ATTR_GROUP_SHAPE_ENV_H diff --git a/inc/graph/attribute_group/attr_group_symbolic_desc.h b/inc/graph/attribute_group/attr_group_symbolic_desc.h deleted file mode 100644 index cada951f09947bb91035208bac1de1513c137c93..0000000000000000000000000000000000000000 --- a/inc/graph/attribute_group/attr_group_symbolic_desc.h +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ - -#ifndef INC_GRAPH_ATTR_GROUP_SYMBOLIC_DESC_H -#define INC_GRAPH_ATTR_GROUP_SYMBOLIC_DESC_H - -#include "graph/ge_error_codes.h" -#include "type_utils.h" -#include "attr_group_base.h" -#include "graph/tensor.h" -#include "exe_graph/runtime/symbolic_tensor.h" - -namespace ge { -namespace proto { -class AttrGroupDef; -} - -class SymbolicDescAttr : public AttrGroupsBase { - public: - SymbolicDescAttr() = default; - - ~SymbolicDescAttr() override = default; - graphStatus Serialize(proto::AttrGroupDef &attr_group_def) override; - graphStatus Deserialize(const proto::AttrGroupDef &attr_group_def, AttrHolder *attr_holder) override; - std::unique_ptr Clone() override; - - gert::SymbolTensor symbolic_tensor; -}; -} -#endif // INC_GRAPH_ATTR_GROUP_SYMBOLIC_DESC_H diff --git a/inc/graph/cache_policy/aging_policy.h b/inc/graph/cache_policy/aging_policy.h deleted file mode 100644 index 397b532fe03ce2ed04c0ecbd54453e93f6f35335..0000000000000000000000000000000000000000 --- a/inc/graph/cache_policy/aging_policy.h +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_CACHE_POLICY_POLICY_MANAGEMENT_AGING_POLICY_H_ -#define GRAPH_CACHE_POLICY_POLICY_MANAGEMENT_AGING_POLICY_H_ -#include "graph/cache_policy/cache_state.h" - -namespace ge { -constexpr const size_t kDefaultCacheQueueDepth = 1U; -class AgingPolicy { - public: - AgingPolicy() = default; - virtual ~AgingPolicy() = default; - virtual void SetCachedAgingDepth(size_t depth) = 0; - virtual std::vector DoAging(const CacheState &cache_state) const = 0; - virtual bool IsReadyToAddCache(const CacheHashKey hash_key, const CacheDescPtr &cache_desc) = 0; - private: - AgingPolicy &operator=(const AgingPolicy &anging_polocy) = delete; - AgingPolicy(const AgingPolicy &anging_polocy) = delete; -}; -} -#endif diff --git a/inc/graph/cache_policy/aging_policy_lru.h b/inc/graph/cache_policy/aging_policy_lru.h deleted file mode 100644 index 5eb7f804a92c6ae00b5e3a2e3c4656fcbc937d8b..0000000000000000000000000000000000000000 --- a/inc/graph/cache_policy/aging_policy_lru.h +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_CACHE_POLICY_POLICY_MANAGEMENT_AGING_POLICY_LRU_H_ -#define GRAPH_CACHE_POLICY_POLICY_MANAGEMENT_AGING_POLICY_LRU_H_ -#include "graph/cache_policy/aging_policy.h" -#include "graph/cache_policy/policy_register.h" - -namespace ge { -class AgingPolicyLru : public AgingPolicy { -public: - virtual ~AgingPolicyLru() override = default; - void SetDeleteInterval(const uint64_t &interval) { - delete_interval_ = interval; - } - void SetCachedAgingDepth(size_t depth) override { - (void)depth; - } - bool IsReadyToAddCache(const CacheHashKey hash_key, const CacheDescPtr &cache_desc) override { - (void) hash_key; - (void) cache_desc; - return true; - } - std::vector DoAging(const CacheState &cache_state) const override; - -private: - uint64_t delete_interval_ = 0U; -}; - -REGISTER_AGING_POLICY_CREATOR(AgingPolicyType::AGING_POLICY_LRU, - []() { - return make_shared(); - }); -} // namespace ge -#endif diff --git a/inc/graph/cache_policy/aging_policy_lru_k.h b/inc/graph/cache_policy/aging_policy_lru_k.h deleted file mode 100644 index bc5439b9d70520c4d50e5bbb8aa1dd3d039ce9f2..0000000000000000000000000000000000000000 --- a/inc/graph/cache_policy/aging_policy_lru_k.h +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_GRAPH_CACHE_POLICY_AGING_POLICY_LRU_K_H -#define METADEF_CXX_GRAPH_CACHE_POLICY_AGING_POLICY_LRU_K_H -#include "graph/cache_policy/aging_policy.h" -#include "graph/cache_policy/policy_register.h" - -namespace ge { -class AgingPolicyLruK : public AgingPolicy { - public: - AgingPolicyLruK() : depth_(kDefaultCacheQueueDepth) {} - explicit AgingPolicyLruK(size_t depth) : depth_(depth) {} - AgingPolicyLruK(size_t k_times, size_t depth) : k_times_(k_times), depth_(depth) {} - ~AgingPolicyLruK() override = default; - - void SetCachedAgingDepth(size_t depth) override { - depth_ = depth; - } - bool IsReadyToAddCache(const CacheHashKey hash_key, const CacheDescPtr &cache_desc) override { - return IsCacheDescAppearKTimes(hash_key, cache_desc); - } - std::vector DoAging(const CacheState &cache_state) const override; - - private: - bool IsCacheDescAppearKTimes(const CacheHashKey hash_key, const CacheDescPtr &cache_desc); - private: - size_t k_times_ = 2U; - size_t depth_; - // todo 历史缓存队列的老化 - std::mutex hash_2_cache_descs_and_count_mu_; - std::unordered_map>> hash_2_cache_descs_and_count_; -}; -REGISTER_AGING_POLICY_CREATOR(AgingPolicyType::AGING_POLICY_LRU_K, - []() { return std::make_shared(); }); -} -#endif // METADEF_CXX_GRAPH_CACHE_POLICY_AGING_POLICY_LRU_K_H diff --git a/inc/graph/cache_policy/cache_desc.h b/inc/graph/cache_policy/cache_desc.h deleted file mode 100644 index 012f2005161e3439b8b15058da8af2432c95e252..0000000000000000000000000000000000000000 --- a/inc/graph/cache_policy/cache_desc.h +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_CACHE_POLICY_CACHE_DESC_H -#define GRAPH_CACHE_POLICY_CACHE_DESC_H - -#include -#include "graph/utils/hash_utils.h" -namespace ge { -class CacheDesc; -using CacheDescPtr = std::shared_ptr; -class CacheDesc { - public: - CacheDesc() = default; - virtual ~CacheDesc() = default; - virtual bool IsEqual(const CacheDescPtr &other) const = 0; - virtual bool IsMatch(const CacheDescPtr &other) const = 0; - virtual CacheHashKey GetCacheDescHash() const = 0; -}; -} // namespace ge -#endif diff --git a/inc/graph/cache_policy/cache_policy.h b/inc/graph/cache_policy/cache_policy.h deleted file mode 100644 index 57708694b38a12d36cdf7730ce902b13da17f044..0000000000000000000000000000000000000000 --- a/inc/graph/cache_policy/cache_policy.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_CACHE_POLICY_CACHE_POLICY_H_ -#define GRAPH_CACHE_POLICY_CACHE_POLICY_H_ - -#include -#include -#include "cache_state.h" -#include "policy_register.h" -#include "graph/ge_error_codes.h" - -namespace ge { -class CachePolicy { - public: - ~CachePolicy() = default; - - CachePolicy(const CachePolicy &) = delete; - CachePolicy(CachePolicy &&) = delete; - CachePolicy &operator=(const CachePolicy &) = delete; - CachePolicy &operator=(CachePolicy &&) = delete; - - static std::unique_ptr Create(const MatchPolicyPtr &mp, const AgingPolicyPtr &ap); - static std::unique_ptr Create(const MatchPolicyType mp_type, const AgingPolicyType ap_type, - size_t cached_aging_depth = kDefaultCacheQueueDepth); - - graphStatus SetMatchPolicy(const MatchPolicyPtr mp); - - graphStatus SetAgingPolicy(const AgingPolicyPtr ap); - - CacheItemId AddCache(const CacheDescPtr &cache_desc); - - CacheItemId FindCache(const CacheDescPtr &cache_desc) const; - - std::vector DeleteCache(const DelCacheFunc &func); - - std::vector DeleteCache(const std::vector &delete_item); - - std::vector DoAging(); - - CachePolicy() = default; - - private: - CacheState compile_cache_state_; - MatchPolicyPtr mp_ = nullptr; - AgingPolicyPtr ap_ = nullptr; -}; -} // namespace ge -#endif diff --git a/inc/graph/cache_policy/cache_state.h b/inc/graph/cache_policy/cache_state.h deleted file mode 100644 index 881ba9429b0ac621da878d535f7971b02c8a5fdf..0000000000000000000000000000000000000000 --- a/inc/graph/cache_policy/cache_state.h +++ /dev/null @@ -1,118 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_CACHE_POLICY_CACHE_STATE_H -#define GRAPH_CACHE_POLICY_CACHE_STATE_H - -#include -#include -#include -#include -#include -#include - -#include "compile_cache_desc.h" - -namespace ge { -class CacheInfo; -using CacheItemId = uint64_t; -constexpr CacheItemId KInvalidCacheItemId = std::numeric_limits::max(); - -using DelCacheFunc = std::function; -using CCStatType = std::unordered_map>; - -class CacheInfo { -friend class CacheState; -public: - CacheInfo(const uint64_t timer_count, const CacheItemId item_id, const CacheDescPtr &desc) - : item_id_(item_id), desc_(desc), timer_count_(timer_count) {} - CacheInfo(const CacheInfo &other) - : item_id_(other.item_id_), desc_(other.desc_), timer_count_(other.timer_count_) {} - CacheInfo &operator=(const CacheInfo &other) { - timer_count_ = other.timer_count_; - item_id_ = other.item_id_; - desc_ = other.desc_; - return *this; - } - CacheInfo() = delete; - ~CacheInfo() = default; - - void RefreshTimerCount(uint64_t time_count) { - timer_count_ = time_count; - } - - uint64_t GetTimerCount() const noexcept { - return timer_count_; - } - - CacheItemId GetItemId() const noexcept { - return item_id_; - } - - const CacheDescPtr &GetCacheDesc() const noexcept { - return desc_; - } - -private: - CacheItemId item_id_; - CacheDescPtr desc_; - uint64_t timer_count_; -}; - -struct CacheInfoQueue { - void Insert(const CacheHashKey main_hash_key, std::vector &cache_info); - void EmplaceBack(const CacheHashKey main_hash_key, CacheInfo &cache_info); - void Erase(std::vector &delete_ids, const DelCacheFunc &is_need_delete_func); - - CCStatType cc_state_; - uint64_t cache_info_num_ = 0U; -}; - -class CacheState { -public: - CacheState() = default; - ~CacheState() = default; - - CacheItemId AddCache(const CacheHashKey main_hash_key, const CacheDescPtr &cache_desc); - - std::vector DelCache(const DelCacheFunc &func); - - std::vector DelCache(const std::vector &delete_item); - - const CCStatType &GetState() const { - return cache_info_queue.cc_state_; - } - - uint64_t GetCacheInfoNum() const { - return cache_info_queue.cache_info_num_; - } - - uint64_t GetCurTimerCount() const { - return cache_timer_count_; - } -private: - CacheItemId GetNextCacheItemId(); - void RecoveryCacheItemId(const std::vector &cache_items); - uint64_t GetNextTimerCount() { - const std::lock_guard lock(cache_timer_count_mu_); - return cache_timer_count_++; - } - - std::mutex cache_info_queue_mu_; - std::mutex cache_item_mu_; - - int64_t cache_item_counter_ = 0L; - std::queue cache_item_queue_; - CacheInfoQueue cache_info_queue; - - uint64_t cache_timer_count_ = 0U; - std::mutex cache_timer_count_mu_; -}; -} // namespace ge -#endif diff --git a/inc/graph/cache_policy/compile_cache_desc.h b/inc/graph/cache_policy/compile_cache_desc.h deleted file mode 100644 index 1a646ba7cdd4a37a0b88ebedd90a178a87ac5576..0000000000000000000000000000000000000000 --- a/inc/graph/cache_policy/compile_cache_desc.h +++ /dev/null @@ -1,123 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_CACHE_POLICY_COMPILE_CACHE_DESC_H -#define GRAPH_CACHE_POLICY_COMPILE_CACHE_DESC_H - -#include -#include -#include "cache_desc.h" -#include "graph/small_vector.h" -#include "graph/ascend_limits.h" -#include "graph/types.h" -#include "graph/def_types.h" -#include "graph/utils/hash_utils.h" -#include "common/ge_common/debug/ge_log.h" -#include "common/ge_common/debug/log.h" - -namespace ge { -class CompileCacheDesc; -using CompileCacheDescPtr = std::shared_ptr; -class BinaryHolder { - public: - BinaryHolder() = default; - - ~BinaryHolder() = default; - - BinaryHolder(const BinaryHolder &other); - BinaryHolder(BinaryHolder &&other); - BinaryHolder &operator=(const BinaryHolder &other); - BinaryHolder &operator=(BinaryHolder &&other); - - BinaryHolder(const uint8_t *const data, const size_t data_len); - - static std::unique_ptr createFrom(std::unique_ptr &&ptr, size_t length); - - const uint8_t *GetDataPtr() const noexcept; - - const size_t &GetDataLen() const noexcept; - - bool operator!=(const BinaryHolder &second) const; - - private: - std::unique_ptr holder_ = nullptr; - size_t data_len_ = 0UL; -}; - -class TensorInfoArgs { - public: - TensorInfoArgs(const Format format, const Format origin_format, const DataType data_type) - : format_(format), - origin_format_(origin_format), - data_type_(data_type) {} - - ~TensorInfoArgs() = default; - - bool IsUnknownShape() const; - bool IsShapeInRange(const TensorInfoArgs &other) const; - bool IsTensorInfoMatch(const TensorInfoArgs &other) const; - Format GetFormat() const; - Format GetOriginFormat() const; - DataType GetDataType() const; - void SetShape(const std::vector &shape); - void SetShape(const SmallVector &shape); - void SetOriginShape(const std::vector &origin_shape); - void SetOriginShape(const SmallVector &origin_shape); - void SetShapeRange(const std::vector> &ranges); - bool operator!=(const TensorInfoArgs &second) const; - - private: - Format format_; - Format origin_format_; - DataType data_type_; - SmallVector shape_; - SmallVector origin_shape_; - SmallVector, kDefaultMaxInputNum> shape_range_; -}; - -class CompileCacheDesc : public CacheDesc { - friend class CacheHasher; - public: - CompileCacheDesc() = default; - ~CompileCacheDesc() override = default; - bool IsEqual(const CacheDescPtr &other) const override; - bool IsMatch(const CacheDescPtr &other) const override; - CacheHashKey GetCacheDescHash() const override; - void SetOpType(const std::string &op_type); - void AddBinary(const BinaryHolder &holder); - void AddBinary(BinaryHolder &&holder); - void AddTensorInfo(const TensorInfoArgs &tensor_info); - void SetScopeId(const std::initializer_list scope_id); - size_t GetTensorInfoSize(); - TensorInfoArgs *MutableTensorInfo(size_t index); - - private: - bool CheckWithoutTensorInfo(const CompileCacheDesc *first, const CompileCacheDesc *second) const; - std::string op_type_; // op type - SmallVector scope_id_; // graph_id and session_id - SmallVector tensor_info_args_vec_; // input tensordescs - SmallVector other_desc_; // attrs float float size -}; -} // namespace ge - -namespace std { -template<> -struct hash { - size_t operator()(const ge::BinaryHolder &value) const { - GE_CHECK_NOTNULL(value.GetDataPtr()); - size_t seed = ge::HashUtils::MultiHash(); - const uint64_t u8_data = ge::PtrToValue(ge::PtrToPtr(value.GetDataPtr())); - for (size_t idx = 0UL; idx < value.GetDataLen(); idx++) { - seed = ge::HashUtils::HashCombine(seed, *(ge::PtrToPtr(ge::ValueToPtr(u8_data + idx)))); - } - return seed; - } -}; -} // namespace std -#endif diff --git a/inc/graph/cache_policy/match_policy.h b/inc/graph/cache_policy/match_policy.h deleted file mode 100644 index 7a2ce33248ce90d7cdab3d0f016c0322dad0380a..0000000000000000000000000000000000000000 --- a/inc/graph/cache_policy/match_policy.h +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_CACHE_POLICY_POLICY_MANAGEMENT_MATCH_POLICY_H_ -#define GRAPH_CACHE_POLICY_POLICY_MANAGEMENT_MATCH_POLICY_H_ -#include "graph/cache_policy/cache_state.h" - -namespace ge { -class MatchPolicy { - public: - MatchPolicy() = default; - virtual ~MatchPolicy() = default; - virtual CacheItemId GetCacheItemId(const CCStatType &cc_state, const CacheDescPtr &desc) const = 0; - private: - MatchPolicy &operator=(const MatchPolicy &match_polocy) = delete; - MatchPolicy(const MatchPolicy &match_polocy) = delete; -}; -} // namespace ge -#endif diff --git a/inc/graph/cache_policy/match_policy_exact_only.h b/inc/graph/cache_policy/match_policy_exact_only.h deleted file mode 100644 index 66ba40820bbfbd5f3f559545f1a6c0bc5fe224a4..0000000000000000000000000000000000000000 --- a/inc/graph/cache_policy/match_policy_exact_only.h +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_CACHE_POLICY_POLICY_MANAGEMENT_MATCH_POLICY_EXACT_ONLY_H_ -#define GRAPH_CACHE_POLICY_POLICY_MANAGEMENT_MATCH_POLICY_EXACT_ONLY_H_ -#include "graph/cache_policy/match_policy.h" -#include "graph/cache_policy/policy_register.h" - -namespace ge { -class MatchPolicyExactOnly : public MatchPolicy { -public: - CacheItemId GetCacheItemId(const CCStatType &cc_state, const CacheDescPtr &desc) const override; - ~MatchPolicyExactOnly() override = default; -}; - -REGISTER_MATCH_POLICY_CREATOR(MatchPolicyType::MATCH_POLICY_EXACT_ONLY, - []() { - return make_shared(); - }); -} // namespace ge -#endif - diff --git a/inc/graph/cache_policy/match_policy_for_exactly_the_same.h b/inc/graph/cache_policy/match_policy_for_exactly_the_same.h deleted file mode 100644 index 170c6edb51e50fdc88b1deea690188d12957ebbf..0000000000000000000000000000000000000000 --- a/inc/graph/cache_policy/match_policy_for_exactly_the_same.h +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_GRAPH_CACHE_POLICY_MATCH_POLICY_FOR_EXACTLY_THE_SAME_H -#define METADEF_CXX_GRAPH_CACHE_POLICY_MATCH_POLICY_FOR_EXACTLY_THE_SAME_H -#include "graph/cache_policy/match_policy.h" -#include "graph/cache_policy/policy_register.h" - -namespace ge { -class MatchPolicyForExactlyTheSame : public MatchPolicy { - public: - MatchPolicyForExactlyTheSame() = default; - ~MatchPolicyForExactlyTheSame() override = default; - - CacheItemId GetCacheItemId(const CCStatType &cc_state, const CacheDescPtr &cache_desc) const override; -}; -REGISTER_MATCH_POLICY_CREATOR(MatchPolicyType::MATCH_POLICY_FOR_EXACTLY_THE_SAME, - []() { return std::make_shared(); }); -} // namespace ge -#endif // METADEF_CXX_GRAPH_CACHE_POLICY_MATCH_POLICY_FOR_EXACTLY_THE_SAME_H diff --git a/inc/graph/cache_policy/policy_register.h b/inc/graph/cache_policy/policy_register.h deleted file mode 100644 index c676c936e604a14be771778bae40d3ba24120821..0000000000000000000000000000000000000000 --- a/inc/graph/cache_policy/policy_register.h +++ /dev/null @@ -1,104 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_CACHE_POLICY_POLICY_MANAGEMENT_POLICY_REGISTER_H_ -#define GRAPH_CACHE_POLICY_POLICY_MANAGEMENT_POLICY_REGISTER_H_ -#include -#include -#include "common/checker.h" -#include "match_policy.h" -#include "aging_policy.h" - -namespace ge { -using MatchPolicyPtr = std::shared_ptr; -using AgingPolicyPtr = std::shared_ptr; -using MatchPolicyCreator = std::function; -using AgingPolicyCreator = std::function; -enum class MatchPolicyType { - MATCH_POLICY_EXACT_ONLY = 0, - MATCH_POLICY_FOR_EXACTLY_THE_SAME = 1 -}; -enum class AgingPolicyType { - AGING_POLICY_LRU = 0, - AGING_POLICY_LRU_K = 1 -}; - -class PolicyRegister { - public: - ~PolicyRegister() = default; - PolicyRegister(const PolicyRegister&) = delete; - PolicyRegister &operator=(const PolicyRegister &other) = delete; - static PolicyRegister &GetInstance(); - void RegisterMatchPolicy(const MatchPolicyType match_policy_type, const MatchPolicyCreator &creator) { - const std::lock_guard lock(mu_); - (void)match_policy_registry_.emplace(match_policy_type, creator); - return; - } - - void RegisterAgingPolicy(const AgingPolicyType aging_policy_type, const AgingPolicyCreator &creator) { - const std::lock_guard lock(mu_); - (void)aging_policy_registry_.emplace(aging_policy_type, creator); - } - - MatchPolicyPtr GetMatchPolicy(const MatchPolicyType match_policy_type) { - const auto iter = match_policy_registry_.find(match_policy_type); - if (iter != match_policy_registry_.end()) { - GE_ASSERT_NOTNULL(iter->second, "[GetMatchPolicy] failed. Match policy type : %d was incorrectly registered", - static_cast(match_policy_type)); - return iter->second(); - } - GELOGE(ge::GRAPH_FAILED, "[GetMatchPolicy] failed. Match policy type : %d has not been registered", - static_cast(match_policy_type)); - return nullptr; - } - AgingPolicyPtr GetAgingPolicy(const AgingPolicyType aging_policy_type) { - const auto iter = aging_policy_registry_.find(aging_policy_type); - if (iter != aging_policy_registry_.end()) { - GE_ASSERT_NOTNULL(iter->second, "[GetAgingPolicy] failed. Aging policy type : %d was incorrectly registered", - static_cast(aging_policy_type)); - return iter->second(); - } - GELOGE(ge::GRAPH_FAILED, "[GetAgingPolicy] failed. Aging policy type : %d has not been registered", - static_cast(aging_policy_type)); - return nullptr; - } - private: - PolicyRegister() = default; - std::mutex mu_; - std::map match_policy_registry_; - std::map aging_policy_registry_; -}; - -class MatchPolicyRegister { - public: - MatchPolicyRegister(const MatchPolicyType match_policy_type, const MatchPolicyCreator &creator); - ~MatchPolicyRegister() = default; -}; - -class AgingPolicyRegister { - public: - AgingPolicyRegister(const AgingPolicyType aging_policy_type, const AgingPolicyCreator &creator); - ~AgingPolicyRegister() = default; -}; - -#define REGISTER_MATCH_POLICY_CREATOR_COUNTER(policy_type, func, counter) \ - static MatchPolicyRegister match_policy_register##counter(policy_type, func) -#define REGISTER_MATCH_POLICY_CREATOR_COUNTER_NUMBER(policy_type, func, counter) \ - REGISTER_MATCH_POLICY_CREATOR_COUNTER(policy_type, func, counter) -#define REGISTER_MATCH_POLICY_CREATOR(policy_type, func) \ - REGISTER_MATCH_POLICY_CREATOR_COUNTER_NUMBER(policy_type, func, __COUNTER__) - -#define REGISTER_AGING_POLICY_CREATOR_COUNTER(policy_type, func, counter) \ - static AgingPolicyRegister aging_policy_register##counter(policy_type, func) -#define REGISTER_AGING_POLICY_CREATOR_COUNTER_NUMBER(policy_type, func, counter) \ - REGISTER_AGING_POLICY_CREATOR_COUNTER(policy_type, func, counter) -#define REGISTER_AGING_POLICY_CREATOR(policy_type, func) \ - REGISTER_AGING_POLICY_CREATOR_COUNTER_NUMBER(policy_type, func, __COUNTER__) -} // namespace ge -#endif diff --git a/inc/graph/common_error_codes.h b/inc/graph/common_error_codes.h deleted file mode 100644 index a7267477dce5385b932dfbf3f26251be28585710..0000000000000000000000000000000000000000 --- a/inc/graph/common_error_codes.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_COMMON_ERROR_CODES_H_ -#define INC_GRAPH_COMMON_ERROR_CODES_H_ - -#include "external/graph/ge_error_codes.h" - -namespace ge { -constexpr graphStatus NO_DEPENDENCE_FUNC = 50331647U; -constexpr graphStatus NO_OVERLAP_DIM = 50331646U; -constexpr graphStatus NOT_SUPPORT_SLICE = 50331645U; -} // namespace ge - -#endif // INC_GRAPH_COMMON_ERROR_CODES_H_ diff --git a/inc/graph/detail/model_serialize_imp.h b/inc/graph/detail/model_serialize_imp.h deleted file mode 100644 index b39209c2f5a9eba5c58e54adc93b26904db91847..0000000000000000000000000000000000000000 --- a/inc/graph/detail/model_serialize_imp.h +++ /dev/null @@ -1,159 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_DETAIL_MODEL_SERIALIZE_IMP_H_ -#define INC_GRAPH_DETAIL_MODEL_SERIALIZE_IMP_H_ - -#include -#include -#include -#include -#include -#include "graph/buffer.h" -#include "graph/model.h" -#include "graph/anchor.h" -#include "graph/detail/attributes_holder.h" -#include "graph/ge_tensor.h" -#include "graph/graph.h" -#include "graph/node.h" -#include "external/ge_common/ge_api_types.h" - -namespace ge { -using ComputeGraphPtr = std::shared_ptr; -using AnchorWithIndex = std::pair; -struct MyCmp { - bool operator()(const AnchorWithIndex &anchor1, const AnchorWithIndex &anchor2) const { - return anchor1.second < anchor2.second; - } -}; -using DstAnchors = std::set; - -struct NodeNameGraphReq { - public: - NodeNameGraphReq(const std::string &name, const int32_t index, const ComputeGraphPtr &graph) - : node_name(name), index(index), graph(graph) {} - friend class ModelSerializeImp; - - private: - std::string node_name; - int32_t index; - ComputeGraphPtr graph; -}; - -struct NodeNameNodeReq { - public: - NodeNameNodeReq(const std::string &src_name, const int32_t src_index, const int32_t src_out_peer_index, - const NodePtr dst_node, const int32_t dst_index, const std::string &dst_name) - : src_node_name(src_name), - src_out_index(src_index), - src_out_peer_index(src_out_peer_index), - dst_node(dst_node), - dst_in_index(dst_index), - dst_node_name(dst_name) {} - - friend class ModelSerializeImp; - private: - std::string src_node_name; - int32_t src_out_index; - int32_t src_out_peer_index; - NodePtr dst_node; - int32_t dst_in_index; - std::string dst_node_name; -}; - -class ModelSerializeImp { - public: - bool SerializeModel(const Model &model, proto::ModelDef *const model_proto, const bool not_dump_all = false) const; - // if is_dump_graph is true, ensure peer anchors of node in the same order during serialization and deserialization - // if is_dump_graph is false, cannot guarantee peer anchors in the same order during serialization and deserialization - bool SerializeModel(const Model &model, const bool is_dump_graph, proto::ModelDef *const model_proto, - const bool not_dump_all = false) const; - - bool SerializeGraph(const ConstComputeGraphPtr &graph, proto::GraphDef *const graph_proto, - const bool not_dump_all = false) const; - bool SerializeGraph(const ConstComputeGraphPtr &graph, const bool is_dump_graph, proto::GraphDef *const graph_proto, - const bool not_dump_all = false) const; - - bool SerializeEdge(const NodePtr &node, proto::OpDef *const op_def_proto, const bool is_dump_graph = false) const; - - bool SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *const op_def_proto, - const bool not_dump_all = false) const; - - bool SerializeNode(const NodePtr &node, proto::OpDef *const op_def_proto, const bool not_dump_all = false) const; - bool SerializeNode(const NodePtr &node, const bool is_dump_graph, proto::OpDef *const op_def_proto, - const bool not_dump_all = false) const; - - bool SeparateModelDef(Buffer &buffer, const std::string &path, proto::ModelDef &model_def) const; - - bool SerializeToBuffer(const proto::ModelDef &model_def, Buffer &buffer) const; - - bool UnserializeModel(Model &model, proto::ModelDef &model_proto, - const bool is_enable_multi_thread = false); - bool SetWeightForModel(proto::OpDef &op_def) const; - - bool UnserializeGraphWithoutEdge(ComputeGraphPtr &graph, proto::GraphDef &graph_proto); - - bool UnserializeGraph(ComputeGraphPtr &graph, proto::GraphDef &graph_proto); - - bool HandleNodeNameRef(); - - void AttrDefToOpDescIrDef(OpDescPtr &op_desc, proto::OpDef &op_def_proto) const; - bool UnserializeOpDesc(OpDescPtr &op_desc, proto::OpDef &op_def_proto) const; - void AttrDefToOpDescIn(OpDescPtr &op_desc, std::vector &key_in, std::vector &value_in) const; - void AttrDefToOpDesc(OpDescPtr &op_desc, std::vector &key_out, std::vector &value_out, - const std::vector &opt_input) const; - void OpDescToAttrDef(const ConstOpDescPtr &op_desc, proto::OpDef *const op_def_proto, - const bool not_dump_all = false) const; - void OpDescIrDefToAttrDef(const ConstOpDescPtr &op_desc, - google::protobuf::Map *op_desc_attr) const; - bool UnserializeNode(ComputeGraphPtr &graph, proto::OpDef &op_def_proto); - - bool ParseNodeIndex(const std::string &node_index, std::string &node_name, int32_t &index) const; - - void SetProtobufOwner(const ProtoMsgOwner &buffer_proto_buf_onwer) { protobuf_owner_ = buffer_proto_buf_onwer; } - - bool LoadWeightFromFile(const std::string &file_path, const int64_t &length, std::string &weight) const; - - void SetAirModelPath(const std::string &path) { air_path_ = path; } - static bool SerializeAllAttrsFromAnyMap(const std::map &attr_map, - google::protobuf::Map *const mutable_attr); - static bool DeserializeAllAttrsToAttrHolder( - const google::protobuf::Map &proto_attr_map, AttrHolder *const attr_holder); - - private: - bool RebuildOwnership(ComputeGraphPtr &compute_graph, std::map &subgraphs) const; - Status ParallelUnserializeGraph( - std::map &graphs, - ::google::protobuf::RepeatedPtrField &graphs_proto); - Status UnserializeGraph( - std::map &graphs, - ::google::protobuf::RepeatedPtrField &graphs_proto); - void FixOpDefSubgraphInstanceName(const ConstOpDescPtr &op_desc) const; - - void ExtractMetaDataAttrIn(proto::OpDef &op_def_proto, std::vector &opt_input, - std::vector &key_in, std::vector &value_in) const; - void ExtractMetaDataAttr(proto::OpDef &op_def_proto, std::vector &key_out, - std::vector &value_out) const; - - int64_t GenDataInputInfo(const OutDataAnchorPtr &src_anchor, const InDataAnchorPtr &dst_anchor) const; - int64_t GenCtrlInputInfo(const OutControlAnchorPtr &src_anchor, const InControlAnchorPtr &dst_anchor) const; - void SaveEdgeInfo(const AnchorPtr &src_anchor, const AnchorPtr &dst_anchor, const int64_t src_out_peer_index, - const int64_t cur_index, std::unordered_map &edges) const; - bool LinkEdges(const std::unordered_map &edges) const; - - std::vector graph_input_node_names_; - std::vector graph_output_node_names_; - std::vector node_input_node_names_; - std::map node_map_; - ProtoMsgOwner protobuf_owner_; - std::string air_path_; // path store air model path -}; -} // namespace ge - -#endif // INC_GRAPH_DETAIL_MODEL_SERIALIZE_IMP_H_ diff --git a/inc/graph/fast_graph/edge.h b/inc/graph/fast_graph/edge.h deleted file mode 100644 index 148ec76d4583e1138640e4fd5af166ab3bf54c7b..0000000000000000000000000000000000000000 --- a/inc/graph/fast_graph/edge.h +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_EDGE_H -#define INC_GRAPH_EDGE_H - -#include "graph/anchor.h" - -namespace ge { -constexpr int32_t kInvalidEdgeIndex = -2; -constexpr int32_t kControlEdgeIndex = -1; - -enum class DirectionType { - kDirectionInType, - kDirectionOutType, -}; - -template -struct Edge { - T *src = nullptr; // src node or output node - T *dst = nullptr; // dst node or input node - int32_t src_output = -1; // the output index of output node - int32_t dst_input = -1; // the input index of input node - int32_t in_edge_index = -1; // the record index of input node, it used to quickly find the edge in node. - int32_t out_edge_index = -1; // the record index of output node, it used to quickly find the edge in node. - Anchor *src_anchor_ptr = nullptr; // the reserved information. - Anchor *dst_anchor_ptr = nullptr; // the reserved information. -}; - -class FastNode; - -struct EdgeEndpoint { - FastNode *node; - int32_t index; - DirectionType type; -}; -struct EdgeEndpointWithDirection { - EdgeEndpointWithDirection() : node(nullptr), index(kInvalidEdgeIndex) {} - EdgeEndpointWithDirection(FastNode *n, int32_t i) : node(n), index(i) {} - bool operator<(const EdgeEndpointWithDirection &rhs) const { - if (node < rhs.node) { - return true; - } - if (node > rhs.node) { - return false; - } - return index < rhs.index; - } - bool operator==(const EdgeEndpointWithDirection &rhs) const { - return (node == rhs.node) && (index == rhs.index); - } - FastNode *node; - int32_t index; -}; -using EdgeDstEndpoint = EdgeEndpointWithDirection; -using EdgeSrcEndpoint = EdgeEndpointWithDirection; -} // namespace ge -#endif // INC_GRAPH_EDGE_H diff --git a/inc/graph/fast_graph/execute_graph.h b/inc/graph/fast_graph/execute_graph.h deleted file mode 100644 index 056db5e4ea69d56ecc4730a80174252604512fba..0000000000000000000000000000000000000000 --- a/inc/graph/fast_graph/execute_graph.h +++ /dev/null @@ -1,254 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_EXECUTE_GRAPH_H -#define INC_GRAPH_EXECUTE_GRAPH_H - -#include -#include -#include -#include "graph/fast_graph/fast_node.h" -#include "graph/fast_graph/list_element.h" - -namespace ge { -template -class FastGraphImpl; - -using FastNodeFilter = std::function; - -class ExecuteGraph : public std::enable_shared_from_this, public AttrHolder { - public: - struct SubGraphInfo { - std::shared_ptr sub_graph; - ListElement *quick_graph; - }; - - explicit ExecuteGraph(const std::string &name); - ~ExecuteGraph() override{}; - - /** - * The function is shallow copy for ExecuteGraph - */ - ExecuteGraph &operator=(ge::ExecuteGraph &exec_graph); - - /** - * The function is deep copy for ExecuteGraph - */ - ExecuteGraph &CompleteCopy(ge::ExecuteGraph &exec_graph); - - /** - * The function is used to add node of graph. - * The node push back to container in graph. - */ - FastNode *AddNode(const OpDescPtr &op); - FastNode *AddNode(FastNode *const fast_node); - FastNode *AddNode(const OpDescPtr &op, int64_t id); - - /** - * The function is used to add node of graph. - * The node push front to container in graph. - */ - FastNode *AddNodeFront(const OpDescPtr &op); - FastNode *AddNodeFront(FastNode *const fast_node); - - /** - * The function is used to remove node of graph. - * The node don`t release and it will push to free container which is used to store free obj. - */ - graphStatus RemoveJustNode(const FastNode *const fast_node); - - /** - * The function is used to add input node of graph. - */ - FastNode *AddInputNode(FastNode *const fast_node); - - /** - * The function is used to remove input node of graph. - */ - graphStatus RemoveInputNode(FastNode *const fast_node); - - /** - * The function is used to add output node of graph. - */ - FastNode *AddOutputNodeByIndex(FastNode *const fast_node, int32_t index); - - /** - * The function is used to remove output node of graph. - */ - graphStatus RemoveOutputNode(const FastNode *const fast_node); - - /** - * The function is used to add edge of graph. - */ - FastEdge *AddEdge(FastNode *const src, int32_t src_index, FastNode *const dst, int32_t dst_index); - - /** - * The function is used to remove edge of graph. - * The edge don`t release and it will push to free container which is used to store free obj. - */ - graphStatus RemoveEdge(const FastEdge *const edge); - - const FastNode *GetParentNodeBarePtr() const; - FastNode *GetParentNodeBarePtr(); - void SetParentNode(FastNode *const node); - - /** - * The function is used to directly add subgraph of graph without any check. - * The shared pointer of subgraph will record in graph. - */ - ExecuteGraph *AddSubGraph(const std::shared_ptr &sub_graph); - - /** - * The function will add subgraph After strict checking the valid of subgraph. - * The shared pointer of subgraph will record in graph. - */ - ExecuteGraph *AddSubGraph(const std::shared_ptr &sub_graph_ptr, const std::string &name); - - /** - * The function is used to remove subgraph of graph. - * The shared pointer of subgraph will clear in graph. - */ - graphStatus RemoveSubGraph(const ExecuteGraph *const sub_graph); - graphStatus RemoveSubGraph(const std::string &name); - - /** - * The function is used to get subgraph with name. - */ - ExecuteGraph *GetSubGraph(const std::string &name) const; - - /** - * remove all subgraph from parent graph. - */ - void ClearAllSubGraph(); - - /** - * get the number of direct nodes form graph. - */ - size_t GetDirectNodesSize() const; - - /** - * get direct nodes from graph (it is convert to vector which is long time). - * external modifications don`t affect internal nodes. - */ - std::vector GetDirectNode() const; - - /** - * get all edges from graph (it is convert to vector which is long time). - * external modifications don`t affect internal edges. - */ - std::vector GetAllEdges() const; - - /** - * get all sub graph from graph (it is convert to vector which is long time). - * external modifications don`t affect internal edges. - */ - std::vector GetAllSubgraphs() const; - - /** - * find the node with node token in the graph. - */ - const FastNode *FindNode(size_t token) const; - - /** - * is is topo sort (include dfs, bfs, DFS_POSTORDER). - */ - graphStatus TopologicalSortingGraph(const ExecuteGraph *const execute_graph, const bool dfs_reverse); - - /** - * get name of graph. - */ - std::string GetName() const; - - /** - * set name of graph. - */ - void SetName(const std::string &name); - - void SetParentGraph(ExecuteGraph *const parent_graph); - - const ExecuteGraph *GetParentGraphBarePtr(void) const; - ExecuteGraph *GetParentGraphBarePtr(void); - - /** - * topo sort in the graph (include sub graph). - */ - graphStatus TopologicalSorting(); - - /** - * push edge to free edge. - */ - graphStatus RecycleQuickEdge(const FastEdge *const fast_edge); - - /** - * push node to free edge. - */ - graphStatus RecycleQuickNode(const FastNode *const fast_node); - - /** - * get all of nodes in graph (include subgraph). - */ - std::vector GetAllNodes() const; - std::vector GetAllNodes(const FastNodeFilter &fast_node_filter) const; - - /** - * It is used to set input order which is used in topo sorting - */ - void SetInputsOrder(const std::vector &inputs_order); - - void ReorderByNodeId(); - - void SetGraphId(size_t graph_id); - - size_t GetGraphId() const; - - bool CheckNodeIsInGraph(const FastNode *const node) const; - - bool CheckEdgeIsInGraph(const FastEdge *const edge) const; - - /** - * The edge belong to graph. - * somtime, we need to change the owner of edge to correct graph. - */ - graphStatus MoveEdgeToGraph(const FastEdge *const edge); - - protected: - ProtoAttrMap &MutableAttrMap() override; - ConstProtoAttrMap &GetAttrMap() const override; - - private: - std::vector AllGraphNodes(std::vector> &subgraphs, - const FastNodeFilter &fast_node_filter) const; - void GetAllNodesFromOpdesc(std::vector> &subgraphs, const OpDesc &op_desc, - std::deque &candidates) const; - void RemoveNodeFromNodesFree(const FastNode *const fast_node) const; - graphStatus SortNodes(std::vector &stack, std::map &map_in_edge_num) const; - void GetOutNodesFromEdgesToMap(std::map &map_in_edge_num, FastNode *node, - std::map &breadth_node_map) const; - graphStatus CollectBreadthOutNode(const FastNode *const node, std::map &map_in_edge_num, - std::map &breadth_node_map) const; - graphStatus BFSTopologicalSorting(std::vector &node_vec, const bool reverse, - const ExecuteGraph *const compute_graph) const; - graphStatus DFSTopologicalSorting(std::vector &node_vec, const bool reverse, - const ExecuteGraph *const compute_graph) const; - graphStatus RDFSTopologicalSorting(std::vector &node_vec, const bool reverse, - const ExecuteGraph *const compute_graph) const; - void GetInNodes(const FastNode *const current, std::vector &input_nodes) const; - - private: - std::shared_ptr> graph_shared_; - std::unordered_map names_to_subgraph_; - std::vector inputs_order_; - AttrStore attrs_; - - friend class ExecuteGraphAdapter; - friend class ExecuteGraphUtils; -}; -using ExecuteGraphPtr = std::shared_ptr; -} // namespace ge -#endif // INC_GRAPH_EXECUTE_GRAPH_H diff --git a/inc/graph/fast_graph/fast_node.h b/inc/graph/fast_graph/fast_node.h deleted file mode 100644 index 9fc0fb7da21d8b58aad9154dee99bbbb94c1aeff..0000000000000000000000000000000000000000 --- a/inc/graph/fast_graph/fast_node.h +++ /dev/null @@ -1,433 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef D_INC_GRAPH_FAST_NODE_H -#define D_INC_GRAPH_FAST_NODE_H - -#include -#include -#include -#include -#include "graph/op_desc.h" -#include "graph/fast_graph/edge.h" - -namespace ge { -struct OutDataEdgeStatisticsInfo { - size_t total_num = 0UL; // The total number edge of all inputs or all outputs - std::vector per_edges_num; // The number of one input or one output. -}; -constexpr uint64_t kInvalidSymbol = UINT64_MAX; - -class ExecuteGraph; -class FastNode; -using FastEdge = Edge; - -class ExtendInfo { - public: - virtual ~ExtendInfo() {} - /** - * set the input index of graph. - */ - void SetInputIndex(int32_t idx); - - /** - * get the input index of graph. - * Don`t use this function unless it is explicitly required. - */ - int32_t GetInputIndex() const; - - /** - * add the output index of graph. - * Don`t use this function unless it is explicitly required. - */ - void AddOneOutputIndex(int32_t idx); - - /** - * get the output index of graph. - * Don`t use this function unless it is explicitly required. - */ - std::vector &GetOutputIndex(); - - /** - * get the owner graph of node. - */ - ExecuteGraph *GetOwnerGraphBarePtr() const; - - /** - * set the owner graph of node. - */ - graphStatus SetOwnerGraph(ExecuteGraph *const graph, const FastNode *const fast_node); - - /** - * check the extend information is same with r_info. - */ - bool operator==(const ExtendInfo &r_info) const; - - /** - * get the flag of host node. - */ - bool GetHostNode() const; - - /** - * set the flag of host node. - */ - void SetHostNode(const bool is_host); - - /** - * clear members. - */ - void Clear(); - - /** - * update size of input_symbols. - */ - void UpdateInputSymbols(size_t data_in_num); - - /** - * update size of output_symbols. - */ - void UpdateOutputSymbols(size_t data_out_num); - - /** - * set symbol of the idx data input - */ - graphStatus SetInputSymbol(size_t idx, uint64_t symbol); - - /** - * set symbol of the idx data output - */ - graphStatus SetOutputSymbol(size_t idx, uint64_t symbol); - - /** - * get symbol of the idx data input - */ - uint64_t GetInputSymbol(size_t idx); - - /** - * get symbol of the idx data output - */ - uint64_t GetOutputSymbol(size_t idx); - - private: - bool IsDataIndexValid(size_t idx, const std::vector &symbols) const; - ExecuteGraph *execute_graph_ = nullptr; - std::vector output_index_; - int32_t input_index_ = kControlEdgeIndex; - bool host_node_ = false; - std::vector input_symbols_{}; - std::vector output_symbols_{}; -}; - -class FastNode { - public: - /** - * construct a fastnode. - * please call Init after construction - */ - FastNode(); - - ~FastNode(); - /** - * The function is used to init node with opdesc. - */ - graphStatus Init(const OpDescPtr &op); - - /** - * get the bare pointer of op desc. - */ - OpDesc *GetOpDescBarePtr() const; - - /** - * get the shared pointer of op desc. - */ - OpDescPtr GetOpDescPtr() const; - - /** - * get the type of node. - */ - std::string GetType() const; - - /** - * get the type of node. - */ - const char *GetTypePtr() const; - - /** - * get the name of node. - */ - std::string GetName() const; - - /** - * get the name of node. - */ - const char *GetNamePtr() const; - - /** - * record the edge info to node. - * The funcion is not recommended, please used the AddEdge funcion istead. - */ - graphStatus RecordEdge(Edge *const edge, DirectionType type); - - /** - * clear the edge info to node. - * The funcion is not recommended, please used the RemoveEdge funcion istead. - */ - graphStatus EraseEdge(const Edge *const edge, DirectionType type); - - /** - * adjust the position of the edge in the node record. - */ - graphStatus MoveEdge(DirectionType type, int32_t io_idx, int32_t cur_array_index, int32_t replace_array_index); - - /** - * get a unique identifier of node. - * please notes the unique identifier is in the graph. - */ - size_t GetNodeToken() const; - - /** - * get the number of input data in the node. - */ - size_t GetDataInNum() const; - - /** - * get the number of output data in the node. - */ - size_t GetDataOutNum() const; - - /** - * update the number of data input in the node. - */ - void UpdateDataInNum(size_t new_num); - - /** - * update the number of data ouput in the node. - */ - void UpdateDataOutNum(size_t new_num); - - /** - * get the total number of output edges from the node. - */ - size_t GetAllOutEdgesSize() const; - size_t GetAllOutDataEdgesSize() const; - size_t GetAllOutControlEdgesSize() const; - - /** - * get the total number of input edges from the node. - */ - size_t GetAllInDataEdgesSize() const; - size_t GetAllInControlEdgesSize() const; - - /** - * check the node is same with r_node. - */ - bool operator==(const FastNode &r_node) const; - - /** - * get the total number of in edge from the node. - * the number include data edge and control edge. - */ - size_t GetAllInEdgeSize() const; - - /** - * collecting all input edge. - * please check the item, the item from vector may be nullptr. - * if the item is nullptr, it just continue to get next, no error handing is required. - */ - const std::vector *> &GetAllInDataEdgesRef() const; - - /** - * collecting all output edge. - * please check the item, the item from vector may be nullptr. - * if the item is nullptr, it just continue to get next, no error handing is required. - */ - const std::vector *> &GetAllOutControlEdgesRef() const; - const std::vector *>> &GetAllOutDataEdgesRef() const; - - /** - * collecting all output or input edge. - * it already filter the nullptr item. - */ - std::vector *> GetAllOutDataEdges() const; - std::vector *> GetAllOutControlEdges() const; - std::vector *> GetAllInDataEdges() const; - std::vector *> &MutableAllInDataEdges(); - - /** - * collecting input control edges with input index. - * please check the item, the item from vector may be nullptr. - * if the item is nullptr, it just continue to get next, no error handing is required. - */ - std::vector *> GetAllInControlEdges() const; - const std::vector *> &GetAllInControlEdgesRef() const; - - /** - * Check the number of out edge is zero. - */ - bool OutNodesIsEmpty() const; - - /** - * Set the relative node information. - * Don`t use this function unless it is explicitly required. - */ - void SetNodePtr(const std::shared_ptr &node); - - /** - * clear the relative node information. - * Don`t use this function unless it is explicitly required. - */ - void ClearNodePtr(); - void ClearNodeBarePtr(); - - /** - * get the relative node information. - * Don`t use this function unless it is explicitly required. - */ - std::shared_ptr GetNodePtr() const; - Node *GetNodeBarePtr() const; - - /** - * get the total number of edge with input index. - */ - size_t GetInEdgesSizeByIndex(int32_t idx) const; - - /** - * get the total number of edge with output index. - */ - size_t GetOutEdgesSizeByIndex(int32_t idx) const; - - /** - * collecting input data edge with input index. - * please check the item, the item from vector may be nullptr. - * if the item is nullptr, it just continue to get next, no error handing is required. - */ - Edge *GetInDataEdgeByIndex(int32_t idx) const; - - bool IsDirectlyControlledByNode(FastNode const *node) const; - - /** - * collecting all output edge with output index. - * please check the item, the item from vector may be nullptr. - * if the item is nullptr, it just continue to get next, no error handing is required. - */ - std::vector *> GetOutEdgesByIndex(int32_t idx) const; - const std::vector *> &GetOutEdgesRefByIndex(int32_t idx) const; - - /** - * remove all of edge in the node. - * please define remove_edge_func to delete edge. - * The general process is as follow: - * 1. clear the edge information in src node and dst node (use EraseEdge); - * 2. remove edge in container. - * 3. add the free edge to free container. - * example: - * node[1]->RemoveAllEdge([&compute_graph](FastEdge *e) { - * auto src_node = e->src; - * auto dst_node = e->dst; - * Utils::GetNode(src_node).EraseEdge(e, DirectionType::kDirectionOutType); - * Utils::GetNode(dst_node).EraseEdge(e, DirectionType::kDirectionInType); - * if (Utils::GetListElementAddr(e)->owner != nullptr) { - * Utils::GetListElementAddr(e)->owner->erase(Utils::GetListElementAddr(e)); - * } - * auto ret = compute_graph->RecycleQuickEdge(e); - * if ((ret != GRAPH_SUCCESS) && (e != nullptr)) { - * delete e; - * } - * }); - */ - void RemoveAllEdge(std::function *)> const &remove_edge_func); - - /** - * get the extend info of node. - */ - ExtendInfo *GetExtendInfo() const; - - /** - * get the numbers input edges which peer node is not NEXTITERATION or REFNEXTITERATION. - */ - size_t GetInEdgeSize() const; - - void UpdateOpDesc(const OpDescPtr &new_opdesc); - - /** - * get peer nodes from all input data edges. - */ - std::vector GetInDataNodes() const; - - /** - * get peer nodes from out data edges with index. - */ - std::vector GetOutDataNodesByIndex(int32_t index) const; - - /** - * get peer nodes from all out data edges. - */ - std::vector GetOutDataNodes() const; - - /** - * get peer nodes from all out control edges. - */ - std::vector GetOutControlNodes() const; - - /** - * get peer nodes from all in control edges. - */ - std::vector GetInControlNodes() const; - - /** - * get peer nodes from all out edges. - */ - std::vector GetAllOutNodes() const; - - /** - * get peer nodes from all in edges. - */ - std::vector GetAllInNodes() const; - - private: - graphStatus CheckAllInputParamter(DirectionType type, int32_t io_idx, int32_t cur_array_index, - int32_t replace_array_index) const; - inline bool CheckDataIndexIsValid(int32_t index, DirectionType type) const; - graphStatus Reset(); - void UpdateDataForIoNumChange(); - graphStatus RecordInControlEdge(FastEdge *const edge); - graphStatus RecordOutControlEdge(FastEdge *const edge); - graphStatus RecordInDataEdge(FastEdge *const edge, int32_t index); - graphStatus RecordOutDataEdge(FastEdge *const edge, int32_t index); - graphStatus EraseInControlEdge(const FastEdge *const edge); - graphStatus EraseOutControlEdge(const FastEdge *const edge); - graphStatus EraseInDataEdge(const FastEdge *const edge); - graphStatus EraseOutDataEdge(const FastEdge *const edge, int32_t index); - graphStatus ModifySizeByNodeType(const FastEdge *const fast_edge, size_t &in_edge_size) const; - - private: - std::string name_; - size_t node_token_ = 0UL; - OpDescPtr opdesc_ = nullptr; - std::shared_ptr self_ptr_ = nullptr; - Node *node_bare_ptr_ = nullptr; - - size_t data_in_num_ = 0UL; - size_t data_out_num_ = 0UL; - - mutable std::vector *> in_data_edges_; - std::vector *> in_control_edges_; - std::vector *> out_control_edges_; - std::vector *>> out_data_edges_; - - size_t in_data_edges_count_ = 0UL; - size_t in_control_edge_count_ = 0UL; - size_t out_control_edges_count_ = 0UL; - OutDataEdgeStatisticsInfo out_data_edges_info_; - - std::unique_ptr extend_info_ = nullptr; -}; - -} // namespace ge -#endif // D_INC_GRAPH_FAST_NODE_H diff --git a/inc/graph/fast_graph/list_element.h b/inc/graph/fast_graph/list_element.h deleted file mode 100644 index 8e8a8843fd3039aa1a6e707a01955bfdb71d011c..0000000000000000000000000000000000000000 --- a/inc/graph/fast_graph/list_element.h +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef D_INC_GRAPH_LIST_NODE_H -#define D_INC_GRAPH_LIST_NODE_H - -namespace ge { -enum class ListMode { kWorkMode = 0, kFreeMode }; - -template -class QuickList; - -template -struct ListElement { - ListElement *next; - ListElement *prev; - QuickList *owner; - ListMode mode; - T data; - explicit ListElement(const T &x) : data(x), next(nullptr), prev(nullptr), owner(nullptr), mode(ListMode::kFreeMode) {} - bool operator==(const ListElement &r_ListElement) const { - return data == r_ListElement.data; - } - ListElement() : next(nullptr), prev(nullptr), owner(nullptr), mode(ListMode::kFreeMode) {} - void SetOwner(QuickList *new_owner) { - owner = new_owner; - } -}; -} // namespace ge -#endif diff --git a/inc/graph/fast_graph/repeated_iterator.h b/inc/graph/fast_graph/repeated_iterator.h deleted file mode 100644 index da475f63ddef6ab7600d8fcebf99e732d62652ab..0000000000000000000000000000000000000000 --- a/inc/graph/fast_graph/repeated_iterator.h +++ /dev/null @@ -1,58 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_REPEATED_ITERATOR_H -#define METADEF_CXX_REPEATED_ITERATOR_H -#include -#include - -namespace ge { -template -class RepeatedIterator { - public: - using iterator_category = std::forward_iterator_tag; - using difference_type = std::ptrdiff_t; - using value_type = T; - using pointer = T *; - using reference = T &; - using size_type = size_t; - - RepeatedIterator(size_type index, reference value) : index_(index), value_(value) {} - - reference operator*() const { - return value_; - } - - pointer operator->() const { - return &value_; - } - - RepeatedIterator &operator++() { - ++index_; - return *this; - } - RepeatedIterator operator++(int) { - RepeatedIterator ret = *this; - ++*this; - return ret; - } - - friend bool operator==(const RepeatedIterator &lhs, const RepeatedIterator &rhs) { - return (lhs.index_ == rhs.index_) && (&lhs.value_ == &rhs.value_); - } - friend bool operator!=(const RepeatedIterator &lhs, const RepeatedIterator &rhs) { - return !(lhs == rhs); - }; - - private: - size_type index_; - reference value_; -}; -} // namespace ge -#endif // METADEF_CXX_REPEATED_ITERATOR_H diff --git a/inc/graph/ge_context.h b/inc/graph/ge_context.h deleted file mode 100644 index 47d1fc406d7fa21fc04ba5f9f23ef5ae1c4d8d8d..0000000000000000000000000000000000000000 --- a/inc/graph/ge_context.h +++ /dev/null @@ -1,57 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_GE_CONTEXT_H_ -#define INC_GRAPH_GE_CONTEXT_H_ - -#include -#include -#include -#include "graph/ge_error_codes.h" -#include "graph/option/optimization_option.h" - -namespace ge { -class GEContext { - public: - graphStatus GetOption(const std::string &key, std::string &option); - const std::string &GetReadableName(const std::string &key); - bool GetHostExecFlag() const; - bool GetTrainGraphFlag() const; - bool IsOverflowDetectionOpen() const; - bool IsGraphLevelSat() const; - uint64_t GetInputFusionSize() const; - uint64_t SessionId() const; - uint32_t DeviceId() const; - int32_t StreamSyncTimeout() const; - int32_t EventSyncTimeout() const; - void Init(); - void SetSessionId(const uint64_t session_id); - void SetContextId(const uint64_t context_id); - void SetCtxDeviceId(const uint32_t device_id); - void SetStreamSyncTimeout(const int32_t timeout); - void SetEventSyncTimeout(const int32_t timeout); - graphStatus SetOptionNameMap(const std::string &option_name_map_json); - void SetMultiBatchShapeIndex(uint32_t graph_id, - const std::map> &data_index_and_shape_map); - const std::map> GetMultiBatchShapeIndex(uint32_t graph_id); - OptimizationOption &GetOo() const; - - private: - thread_local static uint64_t session_id_; - thread_local static uint64_t context_id_; - uint32_t device_id_ = 0U; - // GEContext不允许拓展新的成员变量 -}; // class GEContext - -/// Get context -/// @return -GEContext &GetContext(); -static_assert(sizeof(GEContext) == 4U, "Do not add member to a thread-safe global variable"); -} // namespace ge -#endif // INC_GRAPH_GE_CONTEXT_H_ diff --git a/inc/graph/ge_global_options.h b/inc/graph/ge_global_options.h deleted file mode 100644 index 520ffb4182d48559a2e0a6f6f42c4a21024a69cf..0000000000000000000000000000000000000000 --- a/inc/graph/ge_global_options.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_GE_GLOBAL_OPTIONS_H_ -#define INC_GRAPH_GE_GLOBAL_OPTIONS_H_ - -#include -#include -#include - -namespace ge { -std::mutex &GetGlobalOptionsMutex(); -std::map &GetMutableGlobalOptions(); -} -#endif // INC_GRAPH_GE_GLOBAL_OPTIONS_H_ diff --git a/inc/graph/ge_local_context.h b/inc/graph/ge_local_context.h deleted file mode 100644 index a51a9f8de24d8337d238288aef1b28917386065a..0000000000000000000000000000000000000000 --- a/inc/graph/ge_local_context.h +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_GE_LOCAL_CONTEXT_H_ -#define INC_GRAPH_GE_LOCAL_CONTEXT_H_ - -#include -#include -#include "graph/ge_error_codes.h" -#include "graph/option/optimization_option.h" - -namespace ge { -class GEThreadLocalContext { - public: - graphStatus GetOption(const std::string &key, std::string &option); - void SetGraphOption(std::map options_map); - void SetSessionOption(std::map options_map); - void SetGlobalOption(std::map options_map); - graphStatus SetOptionNameMap(const std::string &option_name_map_json); - const std::string &GetReadableName(const std::string &key); - - void SetStreamSyncTimeout(const int32_t timeout); - void SetEventSyncTimeout(const int32_t timeout); - int32_t StreamSyncTimeout() const; - int32_t EventSyncTimeout() const; - OptimizationOption &GetOo(); - - std::map GetAllGraphOptions() const; - std::map GetAllSessionOptions() const; - std::map GetAllGlobalOptions() const; - std::map GetAllOptions() const; - - private: - std::map graph_options_; - std::map session_options_; - std::map global_options_; - std::map option_name_map_; - int32_t stream_sync_timeout_ = -1; - int32_t event_sync_timeout_ = -1; - OptimizationOption optimization_option_; -}; // class GEThreadLocalContext - -GEThreadLocalContext &GetThreadLocalContext(); -} // namespace ge -#endif // INC_GRAPH_GE_LOCAL_CONTEXT_H_ diff --git a/inc/graph/graph_util.h b/inc/graph/graph_util.h deleted file mode 100644 index 42e5de4bf4c96a67e4f0906edbc6603071a2c23d..0000000000000000000000000000000000000000 --- a/inc/graph/graph_util.h +++ /dev/null @@ -1,127 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_GRAPH_UTIL_H_ -#define INC_GRAPH_GRAPH_UTIL_H_ - -#include - -namespace ge { -using AttrDefMap = ::google::protobuf::Map<::std::string, ::domi::AttrDef>; -bool HasOpAttr(const OpDef *opdef, std::string attr_name); -bool GetOpAttr(const std::string &key, int32_t *value, const OpDef *opdef); - -static const char OP_TYPE_DATA[] = "Data"; -static const char OP_TYPE_INPUT[] = "Input"; -static const char ATTR_KEY_INPUT_FORMAT[] = "input_format"; -static const char ATTR_KEY_OUTPUT_FORMAT[] = "output_format"; -static const char OP_TYPE_ANN_DATA[] = "AnnData"; -} // namespace ge - -#if !defined(__ANDROID__) && !defined(ANDROID) -#include "toolchain/slog.h" -const char levelStr[4][8] = {"ERROR", "WARN", "INFO", "DEBUG"}; -#else -#include -#include -const char levelStr[8][8] = {"EMERG", "ALERT", "CRIT", "ERROR", "WARNING", "NOTICE", "INFO", "DEBUG"}; -#endif - -#ifdef _MSC_VER -#define FUNC_NAME __FUNCTION__ -#else -#define FUNC_NAME __PRETTY_FUNCTION__ -#endif - -#if !defined(__ANDROID__) && !defined(ANDROID) -#define D_GRAPH_LOGI(MOD_NAME, fmt, ...) \ - dlog_info(FMK, "%s:%s:%d:" #fmt, __FUNCTION__, __FILE__, __LINE__, ##__VA_ARGS__) -#define D_GRAPH_LOGW(MOD_NAME, fmt, ...) \ - dlog_warn(FMK, "%s:%s:%d:" #fmt, __FUNCTION__, __FILE__, __LINE__, ##__VA_ARGS__) -#define D_GRAPH_LOGE(MOD_NAME, fmt, ...) \ - dlog_error(FMK, "%s:%s:%d:" #fmt, __FUNCTION__, __FILE__, __LINE__, ##__VA_ARGS__) -#else -#define D_GRAPH_LOG(level, format, ...) \ - do { \ - { \ - fprintf(stdout, "[%s] [%s] [%s] [%s] [%s:%d] " format "\n", "", "GRAPH", levelStr[level], __FUNCTION__, \ - __FILE__, __LINE__, ##__VA_ARGS__); \ - syslog(level, "%s %s:%d] [%s] %s " format "\n", "", __FILE__, __LINE__, "OPTIMIZER", __FUNCTION__, \ - ##__VA_ARGS__); \ - } \ - } while (0) -#define D_GRAPH_LOGI(MOD_NAME, fmt, ...) D_GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) -#define D_GRAPH_LOGW(MOD_NAME, fmt, ...) D_GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) -#define D_GRAPH_LOGE(MOD_NAME, fmt, ...) D_GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) -#endif - -#if !defined(__ANDROID__) && !defined(ANDROID) -#define GRAPH_LOGI(...) D_GRAPH_LOGI(GRAPH_MOD_NAME, __VA_ARGS__) -#define GRAPH_LOGW(...) D_GRAPH_LOGW(GRAPH_MOD_NAME, __VA_ARGS__) -#define GRAPH_LOGE(...) D_GRAPH_LOGE(GRAPH_MOD_NAME, __VA_ARGS__) -#else - -#define GRAPH_LOG(level, format, ...) \ - do { \ - { \ - fprintf(stdout, "[%s] [%s] [%s] [%s] [%s:%d] " format "\n", "", "GRAPH", levelStr[level], __FUNCTION__, \ - __FILE__, __LINE__, ##__VA_ARGS__); \ - syslog(level, "%s %s:%d] [%s] %s " format "\n", "", __FILE__, __LINE__, "OPTIMIZER", __FUNCTION__, \ - ##__VA_ARGS__); \ - } \ - } while (0) -#define GRAPH_LOGI(fmt, ...) GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) -#define GRAPH_LOGW(fmt, ...) GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) -#define GRAPH_LOGE(fmt, ...) GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) -#endif - -#define GRAPH_CHK_STATUS_RET_NOLOG(expr) \ - do { \ - const domi::graphStatus _status = (expr); \ - if (_status != domi::GRAPH_SUCCESS) { \ - return _status; \ - } \ - } while (0) - -#define GRAPH_CHK_BOOL_RET_STATUS(expr, _status, ...) \ - do { \ - bool b = (expr); \ - if (!b) { \ - GRAPH_LOGE(__VA_ARGS__); \ - return _status; \ - } \ - } while (0) - -// Do not add do...while(0), otherwise it wll introduce security issues -#define GRAPH_CHK_BOOL_EXEC_NOLOG(expr, exec_expr) \ - { \ - bool b = (expr); \ - if (!b) { \ - exec_expr; \ - } \ - } - -// Do not add do...while(0), otherwise it wll introduce security issues -#define GRAPH_IF_BOOL_EXEC(expr, exec_expr) \ - { \ - if (expr) { \ - exec_expr; \ - } \ - } - -#define GRAPH_RETURN_WITH_LOG_IF_ERROR(expr, ...) \ - do { \ - const ::domi::graphStatus _status = (expr); \ - if (_status) { \ - GRAPH_LOGE(__VA_ARGS__); \ - return _status; \ - } \ - } while (0) - -#endif // INC_GRAPH_GRAPH_UTIL_H_ diff --git a/inc/graph/host_resource/host_resource.h b/inc/graph/host_resource/host_resource.h deleted file mode 100644 index 69b69380780d557cd44e784dff05a91c79570d68..0000000000000000000000000000000000000000 --- a/inc/graph/host_resource/host_resource.h +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_CXX_HOST_RESOURCE_H -#define INC_GRAPH_CXX_HOST_RESOURCE_H -namespace ge { -class HostResource { - public: - virtual ~HostResource() = default; -}; -} // namespace ge -#endif // INC_GRAPH_CXX_HOST_RESOURCE_H diff --git a/inc/graph/ir_definitions_recover.h b/inc/graph/ir_definitions_recover.h deleted file mode 100644 index df562d3352ba1c6dcd8f50c95fa37d2b7b54b4c4..0000000000000000000000000000000000000000 --- a/inc/graph/ir_definitions_recover.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_IR_DEFINITIONS_RECOVER_H_ -#define GRAPH_IR_DEFINITIONS_RECOVER_H_ - -#include -#include "graph/compute_graph.h" - -namespace ge { -ge::graphStatus RecoverIrDefinitions(const ge::ComputeGraphPtr &graph, const vector &attr_names = {}); -ge::graphStatus RecoverOpDescIrDefinition(const ge::OpDescPtr &desc, const std::string &op_type = ""); -bool CheckIrSpec(const ge::OpDescPtr &desc); -} // namespace ge -#endif // GRAPH_IR_DEFINITIONS_RECOVER_H_ diff --git a/inc/graph/model.h b/inc/graph/model.h deleted file mode 100644 index c72a0636b12ac958cfc6b6067ef893beb816bc7c..0000000000000000000000000000000000000000 --- a/inc/graph/model.h +++ /dev/null @@ -1,97 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_MODEL_H_ -#define INC_GRAPH_MODEL_H_ - -#include -#include -#include "detail/attributes_holder.h" -#include "graph/ge_attr_value.h" -#include "graph/compute_graph.h" - -namespace ge { -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Model : public AttrHolder { - public: - Model(); - - ~Model() override = default; - - Model(const std::string &name, const std::string &custom_version); - - Model(const char_t *name, const char_t *custom_version); - - std::string GetName() const; - - void SetName(const std::string &name); - - uint32_t GetVersion() const; - - void SetVersion(const uint32_t version) { version_ = version; } - - std::string GetPlatformVersion() const; - - void SetPlatformVersion(const std::string version) { platform_version_ = version; } - - const ComputeGraphPtr GetGraph() const; - - void SetGraph(const ComputeGraphPtr &graph); - - void SetAttr(const ProtoAttrMap &attrs); - - using AttrHolder::GetAllAttrNames; - using AttrHolder::GetAllAttrs; - using AttrHolder::GetAttr; - using AttrHolder::HasAttr; - using AttrHolder::SetAttr; - - graphStatus Save(Buffer &buffer, const bool is_dump = false) const; - graphStatus Save(proto::ModelDef &model_def, const bool is_dump = false) const; - graphStatus SaveWithoutSeparate(Buffer &buffer, const bool is_dump = false) const; - graphStatus SaveToFile(const std::string &file_name, const bool force_separate = false) const; - // Model will be rewrite - static graphStatus Load(const uint8_t *data, size_t len, Model &model); - /** - * 多线程加载模型接口,将data中的内容反序列化到model对象中 - * 当模型图具有多个子图时,此接口可以多线程并行加载子图加速,线程上线为16 - * @param data 模型序列化后的内容指针 - * @param len 模型序列化后的内容长度 - * @param model 模型加载后的承载对象 - * @return 成功返回GRAPH_SUCCESS, 失败返回GRAPH_FAILED - */ - static graphStatus LoadWithMultiThread(const uint8_t *data, size_t len, Model &model); - graphStatus Load(ge::proto::ModelDef &model_def); - graphStatus LoadFromFile(const std::string &file_name); - - bool IsValid() const; - - protected: - ConstProtoAttrMap &GetAttrMap() const override; - ProtoAttrMap &MutableAttrMap() override; - - private: - void Init(); - graphStatus Load(ge::proto::ModelDef &model_def, const std::string &path); - graphStatus Save(Buffer &buffer, const std::string &path, const bool is_dump = false) const; - graphStatus SaveSeparateModel(Buffer &buffer, const std::string &path, const bool is_dump = false) const; - AttrStore attrs_; - friend class ModelSerializeImp; - friend class GraphDebugImp; - friend class OnnxUtils; - friend class ModelHelper; - friend class ModelBuilder; - std::string name_; - uint32_t version_; - std::string platform_version_{""}; - ComputeGraphPtr graph_; -}; -using ModelPtr = std::shared_ptr; -} // namespace ge - -#endif // INC_GRAPH_MODEL_H_ diff --git a/inc/graph/model_serialize.h b/inc/graph/model_serialize.h deleted file mode 100644 index 3e4644b3c309d3cfdcb9d14fb64820a5ef6d7636..0000000000000000000000000000000000000000 --- a/inc/graph/model_serialize.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_MODEL_SERIALIZE_H_ -#define INC_GRAPH_MODEL_SERIALIZE_H_ - -#include -#include -#include "graph/buffer.h" -#include "graph/compute_graph.h" -#include "graph/model.h" -#include "external/ge_common/ge_api_types.h" - -namespace ge { -class ModelSerialize { - public: - Buffer SerializeModel(const Model &model, const bool not_dump_all = false) const; - Buffer SerializeSeparateModel(const Model &model, const std::string &path, const bool not_dump_all = false) const; - Buffer SerializeModel(const Model &model, const std::string &path, - const bool is_need_separate, const bool not_dump_all = false) const; - Status SerializeModel(const Model &model, const bool not_dump_all, proto::ModelDef &model_def) const; - - bool UnserializeModel(const uint8_t *const data, const size_t len, - Model &model, const bool is_enable_multi_thread = false) const; - bool UnserializeModel(ge::proto::ModelDef &model_def, Model &model, const std::string &path) const; - bool UnserializeModel(ge::proto::ModelDef &model_def, Model &model) const; - private: - friend class ModelSerializeImp; - friend class GraphDebugImp; -}; -} // namespace ge -#endif // INC_GRAPH_MODEL_SERIALIZE_H_ diff --git a/inc/graph/op_kernel_bin.h b/inc/graph/op_kernel_bin.h deleted file mode 100644 index a8a7834c1dc597ff6cae9c9dc6ec7965a37a5962..0000000000000000000000000000000000000000 --- a/inc/graph/op_kernel_bin.h +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_OP_KERNEL_BIN_H_ -#define INC_GRAPH_OP_KERNEL_BIN_H_ - -#include -#include -#include -#include "graph/types.h" -#include "graph/def_types.h" -#include "graph/host_resource/host_resource.h" - -namespace ge { -class OpKernelBin : public HostResource { - public: - OpKernelBin(const std::string &name, std::vector &&data) : name_(name), data_(std::move(data)) {} - - ~OpKernelBin() override = default; - - const std::string &GetName() const { return name_; } - const uint8_t *GetBinData() const { return ge::PtrToPtr(data_.data()); } - size_t GetBinDataSize() const { return data_.size(); } - OpKernelBin(const OpKernelBin &) = delete; - const OpKernelBin &operator=(const OpKernelBin &) = delete; - - private: - std::string name_; - std::vector data_; -}; - -using OpKernelBinPtr = std::shared_ptr; -constexpr char_t OP_EXTATTR_NAME_TBE_KERNEL[] = "tbeKernel"; -constexpr char_t OP_EXTATTR_NAME_THREAD_TBE_KERNEL[] = "thread_tbeKernel"; -constexpr char_t OP_EXTATTR_CUSTAICPU_KERNEL[] = "cust_aicpu_kernel"; -} // namespace ge - -#endif // INC_GRAPH_OP_KERNEL_BIN_H_ diff --git a/inc/graph/op_types.h b/inc/graph/op_types.h deleted file mode 100644 index c1d7129190046da9f6a1fbf883601a9e2319e507..0000000000000000000000000000000000000000 --- a/inc/graph/op_types.h +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_OP_TYPES_H_ -#define INC_GRAPH_OP_TYPES_H_ - -#include -#include - -#include "graph/types.h" - -namespace ge { -class GE_FUNC_VISIBILITY OpTypeContainer { - public: - static OpTypeContainer &Instance() { - static OpTypeContainer instance; - return instance; - } - ~OpTypeContainer() = default; - - void Register(const std::string &op_type) { static_cast(op_type_list_.insert(op_type)); } - - bool IsExisting(const std::string &op_type) { - return op_type_list_.find(op_type) != op_type_list_.end(); - } - - protected: - OpTypeContainer() {} - - private: - std::set op_type_list_; -}; - -class GE_FUNC_VISIBILITY OpTypeRegistrar { - public: - explicit OpTypeRegistrar(const std::string &op_type) noexcept { OpTypeContainer::Instance().Register(op_type); } - ~OpTypeRegistrar() {} -}; - -#define REGISTER_OPTYPE_DECLARE(var_name, str_name) \ - FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char_t *var_name - -#define REGISTER_OPTYPE_DEFINE(var_name, str_name) \ - const char_t *var_name = str_name; \ - const ge::OpTypeRegistrar g_##var_name##_reg(str_name) - -#define IS_OPTYPE_EXISTING(str_name) (ge::OpTypeContainer::Instance().IsExisting(str_name)) -} // namespace ge - -#endif // INC_GRAPH_OP_TYPES_H_ diff --git a/inc/graph/opsproto_manager.h b/inc/graph/opsproto_manager.h deleted file mode 100644 index 4fb59c7639fa30d7b048b5507a42a5c6b65e76fc..0000000000000000000000000000000000000000 --- a/inc/graph/opsproto_manager.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_OPSPROTO_MANAGER_H_ -#define INC_GRAPH_OPSPROTO_MANAGER_H_ - -#include -#include -#include -#include - -namespace ge { -class OpsProtoManager { - public: - OpsProtoManager() = default; - ~OpsProtoManager(); - static OpsProtoManager *Instance(); - - bool Initialize(const std::map &options); - void Finalize(); - - private: - void LoadOpsProtoPluginSo(const std::string &path); - - std::string pluginPath_; - std::vector handles_; - bool is_init_ = false; - std::mutex mutex_; -}; -} // namespace ge - -#endif // INC_GRAPH_OPSPROTO_MANAGER_H_ diff --git a/inc/graph/option/optimization_option.h b/inc/graph/option/optimization_option.h deleted file mode 100644 index bfb612f28b1e3e0bf80f24a4a38fc12096f01f2c..0000000000000000000000000000000000000000 --- a/inc/graph/option/optimization_option.h +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_OPTION_OPTIMIZATION_OPTION_H_ -#define INC_GRAPH_OPTION_OPTIMIZATION_OPTION_H_ - -#include -#include -#include -#include "graph/ge_error_codes.h" -#include "optimization_option_info.h" - -namespace ge { -class OptimizationOption { - public: - OptimizationOption() = default; - ~OptimizationOption() = default; - - graphStatus Initialize(const std::map &ge_options, - const std::unordered_map ®istered_options); - graphStatus Initialize(const std::map &ge_options, - const std::unordered_map ®istered_options, - const std::unordered_set &forbidden_option_set); - graphStatus GetValue(const std::string &opt_name, std::string &opt_value) const; - static graphStatus IsOoLevelValid(const std::string &oo_level); - static graphStatus IsOptionValueValid(const std::string &opt_name, const std::string &opt_value, - OoInfo::ValueChecker checker); - graphStatus RefreshPassSwitch(const std::string &fusion_config_str); - - private: - graphStatus InitWorkingOolevel(const std::map &ge_options); - void PrintAllWorkingOo(); - graphStatus SetPassSwitch(const std::string &pass_switch_str, const std::unordered_set &forbidden_option_set, bool force_update); - graphStatus UpdatePassSwitchByOption(const std::map &ge_options, const std::unordered_set &forbidden_option_set); - bool IsPassConfiguredWithOptimizationSwitch(const std::string &pass_name) const; - - private: - OoLevel working_oo_level_{OoLevel::kEnd}; - std::unordered_map working_opt_names_to_value_; -}; -} // namespace ge -#endif // INC_GRAPH_OPTION_OPTIMIZATION_OPTION_H_ diff --git a/inc/graph/option/optimization_option_info.h b/inc/graph/option/optimization_option_info.h deleted file mode 100644 index edb29f8c7d6993c83759883515103376d43c3c9e..0000000000000000000000000000000000000000 --- a/inc/graph/option/optimization_option_info.h +++ /dev/null @@ -1,94 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_OPTION_OPTIMIZATION_OPTION_INFO_H -#define INC_GRAPH_OPTION_OPTIMIZATION_OPTION_INFO_H -#include -#include -#include -#include -#include -#include - -namespace ge { -// pre-defined set of optimizaion templates -enum class OoLevel : uint32_t { - kO0 = 0, - kO1 = 1, - kO2 = 2, - kO3 = 3, - kEnd, -}; -static_assert(static_cast(OoLevel::kEnd) < 64, "The number of OoLevels exceeds 64!"); - -// the hierarchy of optimizaion options -enum class OoHierarchy : uint32_t { - kH1 = 0, // Primary options - kH2 = 1, // Secondary options - kEnd, -}; - -// the entry points to GE -enum class OoEntryPoint : uint32_t { - kSession = 0, - kIrBuild = 1, - kAtc = 2, - kEnd, -}; - -enum class OoCategory : uint32_t { - kGeneral = 0, - kInput, - kOutput, - kTarget, - kFeature, - kModelTuning, - kOperatorTuning, - kDebug, - kEnd, -}; - -struct OoShowInfo { - OoCategory catagory; - std::string show_name; -}; - -struct OoInfo { - using ValueChecker = bool (*)(const std::string& opt_value); - // identifies the visibility of the option at different entrances of the program - uint64_t visibility; - uint64_t levels; - OoHierarchy hierarchy; - ValueChecker checker; - std::string name; - std::string help_text; - // Maps each entry point to its corresponding display option information - std::map show_infos; - std::map default_values; - - explicit OoInfo(std::string opt_name, OoHierarchy opt_hierarchy = OoHierarchy::kEnd, uint64_t opt_level = 0UL, - uint64_t opt_vis = 0UL, std::map opt_values = {}, - ValueChecker opt_checker = nullptr, std::map opt_entry_infos = {}, - std::string opt_help = "") - : visibility(opt_vis), levels(opt_level), hierarchy(opt_hierarchy), checker(opt_checker), - name(std::move(opt_name)), help_text(std::move(opt_help)), show_infos(std::move(opt_entry_infos)), - default_values(std::move(opt_values)) {} -}; - -class OoInfoUtils { - public: - static bool IsBitSet(const uint64_t bits, const uint32_t pos); - static uint64_t GenOptVisibilityBits(const std::vector &entries); - static uint64_t GenOptLevelBits(const std::vector &levels); - static std::string GenOoLevelStr(const uint64_t opt_level); - static std::string GetDefaultValue(const OoInfo &info, OoLevel target_level); - static bool IsSwitchOptValueValid(const std::string &opt_value); -}; -} // namespace ge -#endif // INC_GRAPH_OPTION_OPTIMIZATION_OPTION_INFO_H diff --git a/inc/graph/parallelism/graph_parallel_option.h b/inc/graph/parallelism/graph_parallel_option.h deleted file mode 100644 index 74196f19c8783248fe5580c0a726a20685d05414..0000000000000000000000000000000000000000 --- a/inc/graph/parallelism/graph_parallel_option.h +++ /dev/null @@ -1,82 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_INC_GRAPH_PARALLELISM_GRAPH_PARALLEL_OPTION_H_ -#define METADEF_INC_GRAPH_PARALLELISM_GRAPH_PARALLEL_OPTION_H_ - -#include -#include - -namespace ge { -struct PipelineParallelOption { - bool is_enabled = false; - bool is_auto = false; - std::string pipeline_strategy; - int32_t pipe_stage_num = -1; - int32_t schedule_opt_virtual_stage_num = -1; -}; - -struct TensorParallelOption { - bool is_enabled = false; - bool is_auto = false; - int32_t tensor_parallel_size = -1; - int32_t inter_batch_flow_num = 1; -}; - -struct DataParallelOption { - bool is_enabled = false; - bool is_auto = false; - // to be deleted below - bool optimizer_state_sharding = false; - bool gradient_sharding = false; - bool model_weight_sharding = false; - bool model_weight_prefetch = true; - int32_t data_parallel_size = -1; - // model weight prefetch buffer size(MB) - uint32_t model_weight_prefetch_buffer_size = 0U; -}; - -struct TensorShardingOption { - bool is_enabled = false; - bool optimizer_state_sharding = false; - bool gradient_sharding = false; - bool model_weight_sharding = false; - bool model_weight_prefetch = true; - // model weight prefetch buffer size(MB) - uint32_t model_weight_prefetch_buffer_size = 0U; -}; - -struct OptimizerOffloadGraphOption { - bool is_enabled = false; - std::string offload; // cpu or NVME, NVME is reserved - std::string offload_path; // NVME path, reserved -}; - -struct EngineParallelOption { - bool is_enabled = false; - bool is_auto = false; - std::string config_path; // used if is_auto == true -}; - -struct GraphParallelOption { - bool auto_deploy = false; - std::string mode; // AOE mode, search_strategy/search_and_shard_graph/load_strategy/load_and_eval_strategy - std::string work_dir; // AOE dump/load path for strategies - std::string opt_level; - int32_t global_batch_size = -1; - DataParallelOption data_parallel_option; - TensorParallelOption tensor_parallel_option; - TensorShardingOption tensor_sharding_option; - PipelineParallelOption pipeline_parallel_option; - OptimizerOffloadGraphOption optimizer_offload_option; - EngineParallelOption engine_parallel_option; -}; -} // namespace ge - -#endif // METADEF_INC_GRAPH_PARALLELISM_GRAPH_PARALLEL_OPTION_H_ diff --git a/inc/graph/parallelism/tensor_parallel_attrs.h b/inc/graph/parallelism/tensor_parallel_attrs.h deleted file mode 100644 index ea56fe1544cb10afe8b157902d9a6c39b9e8d0a9..0000000000000000000000000000000000000000 --- a/inc/graph/parallelism/tensor_parallel_attrs.h +++ /dev/null @@ -1,396 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_INC_GRAPH_PARALLELISM_TENSOR_PARALLEL_ATTRS_H_ -#define METADEF_INC_GRAPH_PARALLELISM_TENSOR_PARALLEL_ATTRS_H_ - -#include -#include -#include -#include -#include -#include "external/ge_common/ge_api_types.h" - -namespace ge { -namespace tp { -constexpr const char_t *kCommTaskTypeConcat = "Concat"; -constexpr const char_t *kCommTaskTypeUniqueConcat = "UniqueConcat"; -constexpr const char_t *kCommTaskTypeModifyValue = "ModifyValue"; -constexpr const char_t *kCommTaskTypeSlice = "Slice"; -constexpr const char_t *kCommTaskTypeSliceByAxis = "SliceByAxis"; -constexpr const char_t *kCommTaskTypeSplit = "Split"; -constexpr const char_t *kCommTaskTypeTranspose = "Transpose"; -constexpr const char_t *kCommTaskTypeReshape = "Reshape"; -constexpr const char_t *kCommTaskTypeCast = "Cast"; -constexpr const char_t *kCommTaskTypeHcomAllGather = "HcomAllGather"; -constexpr const char_t *kCommTaskTypeHcomAllReduce = "HcomAllReduce"; -constexpr const char_t *kCommTaskTypeHcomAllReduceMean = "HcomAllReduceMean"; -constexpr const char_t *kCommTaskTypeHcomReduceScatter = "HcomReduceScatter"; -constexpr const char_t *kCommTaskTypeHcomBroadcast = "HcomBroadcast"; -constexpr const char_t *kCommTaskTypeHcomAllToAll = "HcomAllToAll"; -constexpr const char_t *kCommTaskTypeSendReceive = "SendReceive"; -constexpr const char_t *kCommTaskTypeLocalReduce = "LocalReduce"; -constexpr const char_t *kGraphSlicingSuffix = "_by_graph_slice_"; -constexpr const char_t *kFlowAttrEnqueuePolicyFifo = "FIFO"; -constexpr const char_t *kFlowAttrEnqueuePolicyOverwrite = "OVERWRITE"; -constexpr const char_t *kSendRecvCommTypeQueue = "Queue"; -constexpr const char_t *kSendRecvCommTypeP2p = "P2pComm"; - -// tensor deployment attrs -struct DimSlice { - int64_t begin; - int64_t end; -}; - -struct DeviceIndex { - std::string engine_type; - std::vector indices; - std::string DebugString() const; -}; - -struct ModelIndex { - // use this construct when need use stage id - ModelIndex() = default; - ModelIndex(const DeviceIndex &device_index, const int64_t stage_id, const int64_t virtual_stage_id) - : device_index(device_index), virtual_stage_id(virtual_stage_id), stage_id(stage_id) {} - // use this construct when do not need use stage id - ModelIndex(const DeviceIndex &device_index, const int64_t virtual_stage_id) - : device_index(device_index), virtual_stage_id(virtual_stage_id), stage_id(0L) {} - ~ModelIndex() = default; - DeviceIndex device_index; - int64_t virtual_stage_id = 0L; - int64_t stage_id = 0L; - std::string DebugString() const; -}; - -struct PipelineConfig { - int64_t micro_batch = 1L; - int64_t stage_id = 0L; - std::vector virtual_stage_id {0L}; -}; - -bool operator==(const DeviceIndex &lhs, const DeviceIndex &rhs); -bool operator!=(const DeviceIndex &lhs, const DeviceIndex &rhs); -bool operator<(const DeviceIndex &lhs, const DeviceIndex &rhs); - -bool operator==(const ModelIndex &lhs, const ModelIndex &rhs); -bool operator!=(const ModelIndex &lhs, const ModelIndex &rhs); -bool operator<(const ModelIndex &lhs, const ModelIndex &rhs); - -struct TensorSliceDeployment { - std::vector> axis_slices; - std::vector> device_indices_each_slice; - std::string reduce_type; -}; - -struct TensorDeployment { - TensorSliceDeployment shard_deployment; - std::string verbose; -}; - -struct NodeDeployment { - std::vector devices; - PipelineConfig pipeline_config; -}; - -struct NodeDeployments { - std::map deployments; -}; - -struct TensorDeployments { - std::map deployments; -}; - -// P2P communications -struct CommPair { - DeviceIndex src_device_index; - int64_t src_virtual_stage_id = 0L; - DeviceIndex dst_device_index; - int64_t dst_virtual_stage_id = 0L; -}; - -struct FlowAttr { - int32_t depth = 1; - std::string enqueue_policy = kFlowAttrEnqueuePolicyFifo; -}; - -struct SendRecvReshardTask { - std::vector comm_pairs; - std::string parallel_group; - std::string comm_type = kSendRecvCommTypeQueue; - FlowAttr flow_attr; // used when comm_type is Queue -}; - -struct CastReshardTask { - DataType dst_type = DT_MAX; -}; - -// group communications -struct CommGroup { - std::vector device_indices; -}; - -struct AllToAllReshardTask { - std::vector comm_groups; - std::string parallel_group; -}; - -struct AllGatherReshardTask { - std::vector comm_groups; - int32_t axis; // axis to concat - std::string parallel_group; - std::string output_allocator; -}; - -struct AllReduceReshardTask { - std::string reduction; - std::vector comm_groups; - std::string parallel_group; -}; - -struct AllReduceMeanReshardTask { - std::vector comm_groups; - int32_t axis; - int32_t value; - std::string parallel_group; -}; - -struct ReduceScatterReshardTask { - std::string reduction; - std::vector comm_groups; - std::string parallel_group; -}; - -struct BroadcastReshardTask { - std::vector root_device_indices; // size == num_groups - std::vector comm_groups; - std::string parallel_group; -}; - -// local reshardings -struct SliceReshardTask { - std::vector axes; - std::vector offsets; - std::vector sizes; - DeviceIndex device_index; -}; - -struct SliceByAxisReshardTask { - // key: axis to split - // value: index: slice index - // element: devices to deploy - std::map>> axis_to_slice_deployments; -}; - -struct SplitReshardTask { - int32_t split_dim = 0; - int32_t num_split = 0; -}; - -struct ConcatReshardTask { - int32_t concat_dim = 0; -}; - -struct UniqueConcatReshardTask { - std::string unique_id; - int32_t concat_dim = 0; - std::vector src_device_indices; - DeviceIndex dst_device_index; -}; - -struct TransposeReshardTask { - std::vector perm; -}; - -struct ReshapeReshardTask { - std::vector shape; -}; - -struct ModifyValueReshardTask { - std::string op_type; // mul, div - std::vector value; -}; - -struct LocalReduceReshardTask { - std::string op_type; -}; - -struct CommTask { - std::string task_type; - std::shared_ptr send_recv_reshard_task; - std::shared_ptr all_gather_reshard_task; - std::shared_ptr all_to_all_reshard_task; - std::shared_ptr all_reduce_reshard_task; - std::shared_ptr all_reduce_mean_reshard_task; - std::shared_ptr reduce_scatter_reshard_task; - std::shared_ptr broadcast_reshard_task; - std::shared_ptr split_reshard_task; - std::shared_ptr concat_reshard_task; - std::shared_ptr unique_concat_reshard_task; - std::shared_ptr slice_reshard_task; - std::shared_ptr slice_by_axis_reshard_task; - std::shared_ptr transpose_reshard_task; - std::shared_ptr modify_value_reshard_task; - std::shared_ptr local_reduce_reshard_task; - std::shared_ptr reshape_reshard_task; - std::shared_ptr cast_reshard_task; -}; - -struct CommStepInput { - int32_t step_id = -1; - int32_t output_index = -1; -}; - -bool operator==(const CommStepInput &lhs, const CommStepInput &rhs); -bool operator<(const CommStepInput &lhs, const CommStepInput &rhs); - -struct CommStep { - int32_t id; - std::vector inputs; - CommTask comm_task; -}; - -struct PeerInput { - int32_t step_id = -1; - std::string node_name; - uint32_t input_index; - int64_t stage_id = 0L; - int64_t virtual_stage_id = 0L; -}; - -// reshard ops for one output tensor -struct OutputReshardRes { - std::vector comm_steps; - std::vector peer_inputs; - std::vector device_indices; - int64_t stage_id = 0L; - int64_t virtual_stage_id = 0L; -}; - -struct ReshardAttr { - std::vector> reshard_infos; // indexed by output index -}; - -struct SrcNodeInfo { - int32_t inserted_node_id = -1; - int32_t output_index = -1; -}; -bool operator==(const SrcNodeInfo &lhs, const SrcNodeInfo &rhs); -bool operator<(const SrcNodeInfo &lhs, const SrcNodeInfo &rhs); - -struct OrigNodeInfo { - std::string node_name; - int32_t sliced_id = -1; - - std::string Name() const { - return (sliced_id == -1) ? node_name : (node_name + kGraphSlicingSuffix + std::to_string(sliced_id)); - } -}; - -bool operator==(const OrigNodeInfo &lhs, const OrigNodeInfo &rhs); -bool operator<(const OrigNodeInfo &lhs, const OrigNodeInfo &rhs); - -struct DstNodeInfo { - OrigNodeInfo orig_node_info; - std::vector input_indexes; - - std::string InputIndexesToString() const { - std::string res; - for (const uint32_t input_index : input_indexes) { - res += std::to_string(input_index) + " "; - } - return res; - } -}; - -bool operator==(const DstNodeInfo &lhs, const DstNodeInfo &rhs); -bool operator<(const DstNodeInfo &lhs, const DstNodeInfo &rhs); - -struct InsertedNodeInput { - SrcNodeInfo input_info; - OrigNodeInfo orig_node_info; -}; - -bool operator==(const InsertedNodeInput &lhs, const InsertedNodeInput &rhs); -bool operator<(const InsertedNodeInput &lhs, const InsertedNodeInput &rhs); - -struct PeerOutNodeInfo { - SrcNodeInfo input_info; - DstNodeInfo node_info; -}; - -bool operator==(const PeerOutNodeInfo &lhs, const PeerOutNodeInfo &rhs); -bool operator<(const PeerOutNodeInfo &lhs, const PeerOutNodeInfo &rhs); - -struct InsertedNodeInfo { - uint32_t id; - CommTask task; - std::vector inputs; -}; - -struct OutputSlicedRes { - std::vector inserted_nodes_info; - std::vector peer_out_nodes; -}; - -struct SlicedEdgeInfo { - std::vector steps_sliced; -}; - -struct TensorShapeSlicedInfo { - std::vector> axis_slices; -}; - -struct NodeSliceStrategy { - std::map input_shape_sliced_info; - std::map output_shape_sliced_info; - - std::map outputs_sliced_edge_info; - std::vector>> dependencies; - size_t size = 1U; -}; - -struct ShardGraphExtAttrs { - // ExtAttr _device_index_to_logic_device_id, key is DeviceIndex, value is logic device id - std::map> dev_index_to_logic_dev_id; - // ExtAttr _model_events, key1 is graph name, key2 is endpoint name, value is serialized endpoints - std::map>> graph_name_to_endpoints; - // ExtAttr _hcomgroups, key is group name, value is device ids - std::map> group_name_to_dev_ids; -}; - -class TensorParallelAttrs { - public: - static Status FromJson(const std::string &json_str, DeviceIndex &device_index); - static Status FromJson(const std::string &json_str, ModelIndex &model_index); - static Status FromJson(const std::string &json_str, PipelineConfig &pipeline_config); - static Status FromJson(const std::string &json_str, NodeDeployment &node_deployment); - static Status FromJson(const std::string &json_str, TensorDeployment &tensor_deployment); - static Status FromJson(const std::string &json_str, TensorDeployments &tensor_deployments); - static Status FromJson(const std::string &json_str, NodeDeployments &node_deployments); - static Status FromJson(const std::string &json_str, CommTask &comm_task); - static Status FromJson(const std::string &json_str, CommStep &comm_step); - static Status FromJson(const std::string &json_str, OutputReshardRes &output_reshard_res); - static Status FromJson(const std::string &json_str, ReshardAttr &reshard_attr); - static Status FromJson(const std::string &json_str, ShardGraphExtAttrs &shard_graph_ext_attrs); - - static std::string ToJson(const NodeDeployment &node_deployment); - static std::string ToJson(const DeviceIndex &device_index); - static std::string ToJson(const ModelIndex &model_index); - static std::string ToJson(const PipelineConfig &pipeline_config); - static std::string ToJson(const TensorDeployment &tensor_deployment); - static std::string ToJson(const NodeDeployments &node_deployments); - static std::string ToJson(const ReshardAttr &reshard_attr); - static std::string ToJson(const TensorDeployments &tensor_deployments); - static std::string ToJson(const ShardGraphExtAttrs &shard_graph_ext_attrs); -}; -} // namespace tp -} // namespace ge - -#endif // METADEF_INC_GRAPH_PARALLELISM_TENSOR_PARALLEL_ATTRS_H_ diff --git a/inc/graph/ref_relation.h b/inc/graph/ref_relation.h deleted file mode 100644 index 8973e18c61a6705778262a229d9e2e864853677f..0000000000000000000000000000000000000000 --- a/inc/graph/ref_relation.h +++ /dev/null @@ -1,74 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef COMMON_GRAPH_REF_RELATION_H_ -#define COMMON_GRAPH_REF_RELATION_H_ - -#include -#include -#include -#include -#include - -#include "graph/compute_graph.h" -#include "graph/ge_error_codes.h" -#include "node.h" - -namespace ge { -enum InOutFlag { - NODE_IN = 0, // input flag - NODE_OUT = 1, // output flag -}; - -// RefCell的对象一经创建,便不允许去修改其数据成员。 -struct RefCell { - const std::string node_name; - const ge::NodePtr node; - const InOutFlag in_out; - const int32_t in_out_idx; - const std::string hash_key; - - explicit RefCell(const std::string &name, const ge::NodePtr &node_ptr, const InOutFlag in_out_flag, const int32_t idx) - : node_name(name), node(node_ptr), in_out(in_out_flag), in_out_idx(idx), - hash_key(std::string("") - .append(node_name) - .append(std::to_string(in_out)) - .append(std::to_string(in_out_idx)) - .append(std::to_string(PtrToValue(node.get())))) {} - RefCell(const RefCell &ref_cell) - : node_name(ref_cell.node_name), node(ref_cell.node), in_out(ref_cell.in_out), in_out_idx(ref_cell.in_out_idx), - hash_key(ref_cell.hash_key) {} - ge::RefCell &operator=(const ge::RefCell &ref_cell) = delete; - bool operator == (const RefCell &c) const { - return node_name == c.node_name && node == c.node && in_out == c.in_out && in_out_idx == c.in_out_idx; - } - ~RefCell() = default; -}; - -struct RefCellHash{ - size_t operator()(const RefCell &c) const { - return std::hash()(c.hash_key); - } -}; - -class RefRelations { - public: - graphStatus LookUpRefRelations(const RefCell &key, std::unordered_set &result); - graphStatus BuildRefRelations(ge::ComputeGraph &graph); - graphStatus Clear(); - - RefRelations(); - ~RefRelations() = default; - private: - class Impl; - std::shared_ptr impl_ = nullptr; -}; - -} // namespace ge -#endif // COMMON_GRAPH_REF_RELATION_H_ diff --git a/inc/graph/resource_context_mgr.h b/inc/graph/resource_context_mgr.h deleted file mode 100644 index af0c1fd3896ef136b38a0ee28a1800ee8bddc895..0000000000000000000000000000000000000000 --- a/inc/graph/resource_context_mgr.h +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_RESOURCE_CONTEXT_MGR_H_ -#define INC_GRAPH_RESOURCE_CONTEXT_MGR_H_ - -#include -#include -#include -#include "external/graph/resource_context.h" -#include "graph/ge_error_codes.h" -#include "graph/node.h" -#include "graph/utils/node_utils.h" - -namespace ge { -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ResourceContextMgr { - public: - ResourceContextMgr() = default; - ~ResourceContextMgr() = default; - /** - * Given resource_key , return corresponding resource pointer - * @param resource_key - * @return orresponding resource pointer - */ - ResourceContext *GetResourceContext(const std::string &resource_key); - /** - * Given resource_key , corresponding resource pointer, set resouce_context with new resource - * @param resource_key - * @param context - * @return status - */ - graphStatus SetResourceContext(const std::string &resource_key, ResourceContext *const context); - /** - * Given resource_key , node reiled on this resource, mgr will keep the relation - * @param resource_key - * @param node - * @return status - */ - graphStatus RegisterNodeReliedOnResource(const std::string &resource_key, NodePtr &node); - /** - * Given resource_key , mgr find node reiled on this reousrce. - * @param resource_key - * @param read_nodes - * @return status - */ - OrderedNodeSet &MutableNodesReliedOnResource(const std::string &resource_key); - /** - * Resource context need to be cleared when session finalize - * @return status - */ - graphStatus ClearContext(); - - private: - std::mutex ctx_mu_; - std::map> resource_keys_to_contexts_; - std::map resource_keys_to_read_nodes_; -}; -} // namespace ge -#endif // INC_GRAPH_RESOURCE_CONTEXT_MGR_H_ diff --git a/inc/graph/runtime_inference_context.h b/inc/graph/runtime_inference_context.h deleted file mode 100644 index f053e0a72f7b5fd2c18ef2654388d2b829a0a4b7..0000000000000000000000000000000000000000 --- a/inc/graph/runtime_inference_context.h +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_ -#define INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_ - -#include -#include -#include -#include -#include "external/graph/ge_error_codes.h" -#include "external/graph/tensor.h" -#include "ge_attr_value.h" - -namespace ge { -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY RuntimeInferenceContext { - public: - graphStatus SetTensor(int64_t node_id, int32_t output_id, GeTensorPtr tensor); - graphStatus GetTensor(const int64_t node_id, int32_t output_id, GeTensorPtr &tensor) const; - void Release(); - - private: - std::map> ge_tensors_; - mutable std::mutex mu_; -}; -} // namespace ge - -#endif // INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_ diff --git a/inc/graph/shape_refiner.h b/inc/graph/shape_refiner.h deleted file mode 100644 index ede6922aaca550fe50e289205705eeea4772da2b..0000000000000000000000000000000000000000 --- a/inc/graph/shape_refiner.h +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_SHAPE_REFINER_H_ -#define INC_GRAPH_SHAPE_REFINER_H_ - -#include -#include "external/graph/inference_context.h" - -#include "external/graph/ge_error_codes.h" -#include "graph/node.h" -#include "graph/resource_context_mgr.h" - -namespace ge { -// ShapeRefiner performs shape inference for compute graphs -class ShapeRefiner { - public: - static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op, const bool before_subgraph); - static graphStatus InferShapeAndType(const NodePtr &node, const bool before_subgraph); - static graphStatus InferShapeAndType(const NodePtr &node); - static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op); - static graphStatus DoInferShapeAndTypeForRunning(const ConstNodePtr &node, Operator &op, const bool before_subgraph); - static graphStatus InferShapeAndTypeForRunning(const NodePtr &node, Operator &op, const bool before_subgraph); - static void ClearContextMap(); - static graphStatus CreateInferenceContext(const NodePtr &node, - InferenceContextPtr &inference_context); - static graphStatus CreateInferenceContext(const NodePtr &node, - ResourceContextMgr *const resource_context_mgr, - InferenceContextPtr &inference_context); - static void PushToContextMap(const NodePtr &node, const InferenceContextPtr &inference_context); - - private: - static void PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase); - static graphStatus GetRealInNodesAndIndex(NodePtr &input_node, int32_t &output_idx, - std::map &nodes_idx); - static graphStatus PostProcessAfterInfershape(const NodePtr &node, const Operator &op, const bool is_unknown_graph); - static graphStatus UpdateInputOutputDesc(const NodePtr &node); -}; -} // namespace ge -#endif // INC_GRAPH_SHAPE_REFINER_H_ diff --git a/inc/graph/symbolizer/OWNERS b/inc/graph/symbolizer/OWNERS deleted file mode 100644 index d7e0ca0f2cefd8ffbcfa13c252ae675d55915bd3..0000000000000000000000000000000000000000 --- a/inc/graph/symbolizer/OWNERS +++ /dev/null @@ -1,15 +0,0 @@ -approvers: -- wqtshg -- wangxiaotian22 -- zhangfan_hq -- lipeiyang3699 -- zhangdepeng2 -- yskhhh -- ji_chen - -reviewers: -- sheng-nan -- zhan-jun - -options: - no_parent_owners: true diff --git a/inc/graph/symbolizer/guard_dfx_context.h b/inc/graph/symbolizer/guard_dfx_context.h deleted file mode 100644 index 05c9e17f09b04b13203b36f09d05a7de0bd87805..0000000000000000000000000000000000000000 --- a/inc/graph/symbolizer/guard_dfx_context.h +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GUARD_DFX_CONTEXT_H_ -#define GUARD_DFX_CONTEXT_H_ - -#include - -namespace ge { -class GuardDfxContext { - public: - explicit GuardDfxContext(const std::string &guard_dfx_info); - ~GuardDfxContext(); - GuardDfxContext(const GuardDfxContext &) = delete; - GuardDfxContext(const GuardDfxContext &&) = delete; - GuardDfxContext &operator=(const GuardDfxContext &) = delete; - GuardDfxContext &&operator=(const GuardDfxContext &&) = delete; -}; -} -#endif // GUARD_DFX_CONTEXT_H_ diff --git a/inc/graph/symbolizer/symbol_checker.h b/inc/graph/symbolizer/symbol_checker.h deleted file mode 100644 index 8b2d27c3f589af550a1d9e7b78f990ce1cee5dd0..0000000000000000000000000000000000000000 --- a/inc/graph/symbolizer/symbol_checker.h +++ /dev/null @@ -1,160 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_SYMBOLIZER_SYMBOL_CHECKER_H_ -#define GRAPH_SYMBOLIZER_SYMBOL_CHECKER_H_ - -/* - * 校验表达式e0是否与e1相等 - * 如果e0的hint值与e1的hint值相等,则宏返回true,并生成e0 == e1的guard - * 反之,则返回false,并生成 e0 != e1的guard - */ -#define EXPECT_SYMBOL_EQ(e0, e1) \ - ge::sym::ExpectSymbolEq(e0, e1, __FILE__, __LINE__) - -/* - * 校验表达式e0是否与e1不相等 - * 如果e0的hint值与e1的hint值不相等,则宏返回true,并生成e0 != e1的guard - * 反之,则返回false,并生成 e0 == e1的guard - */ -#define EXPECT_SYMBOL_NE(e0, e1) \ - ge::sym::ExpectSymbolNe(e0, e1, __FILE__, __LINE__) - -/* - * 校验表达式e0是否小于e1 - * 如果e0的hint值小于e1的hint值,则宏返回true,并生成e0 < e1的guard - * 反之,则返回false,并生成 e1 <= e0的guard - */ -#define EXPECT_SYMBOL_LT(e0, e1) \ - EXPECT_SYMBOL_CHECK(ge::sym::Lt(e0, e1), __FILE__, __LINE__) - -/* - * 校验表达式e0是否小于等于e1 - * 如果e0的hint值小于等于e1的hint值,则宏返回true,并生成e0 <= e1的guard - * 反之,则返回false,并生成 e1 < e0的guard - */ -#define EXPECT_SYMBOL_LE(e0, e1) \ - EXPECT_SYMBOL_CHECK(ge::sym::Le(e0, e1), __FILE__, __LINE__) - -/* - * 校验表达式e0是否大于e1 - * 如果e0的hint值大于e1的hint值,则宏返回true,并生成e1 < e0的guard - * 反之,则返回false,并生成 e0 <= e1的guard - */ -#define EXPECT_SYMBOL_GT(e0, e1) \ - EXPECT_SYMBOL_CHECK(ge::sym::Gt(e0, e1), __FILE__, __LINE__) - -/* - * 校验表达式e0是否大于等于e1 - * 如果e0的hint值大于等于e1的hint值,则宏返回true,并生成e1 <= e0的guard - * 反之,则返回false,并生成 e0 < e1的guard - */ -#define EXPECT_SYMBOL_GE(e0, e1) \ - EXPECT_SYMBOL_CHECK(ge::sym::Ge(e0, e1), __FILE__, __LINE__) - -/* - * 检查表达式列表是否都为true - * 如果表达式都为true,则宏返回true, - * 如果其中有一个是false,则返回false,否则并生成LogicAnd()的guard - * 例如校验表达式:EXPECT_SYMBOL_AND(Ge(s0, s1), Le(s2, s3), Eq(s4, s5)) - * hint值为true时添加guard:LogicAnd(ExpectEq(s4, s5), ExpectLe(s1, s0), ExpectLe(s2, s3)) - * hint值为false时添加guard:LogicOr(ExpectLt(s0, s1), ExpectLt(s3, s2), ExpectNe(s4, s5)) - */ -#define EXPECT_SYMBOL_AND(...) \ - EXPECT_SYMBOL_CHECK(ge::sym::LogicalAnd(std::vector{__VA_ARGS__}), __FILE__, __LINE__) - -/* - * 检查表达式列表是否有一个为true - * 如果表达式全部为false,则宏返回false, - * 如果其中有一个是true,则返回true,否则并生成LogicOr()的guard - * 例如校验表达式:EXPECT_SYMBOL_OR(Ge(s0, s1), Le(s2, s3), Eq(s4, s5)) - * hint值为true时添加guard:LogicOr(ExpectEq(s4, s5), ExpectLe(s1, s0), ExpectLe(s2, s3)) - * hint值为false时添加guard:LogicAnd(ExpectLt(s0, s1), ExpectLt(s3, s2), ExpectNe(s4, s5)) - */ -#define EXPECT_SYMBOL_OR(...) \ - EXPECT_SYMBOL_CHECK(ge::sym::LogicalOr(std::vector{__VA_ARGS__}), __FILE__, __LINE__) - -/* - * 强校验表达式e0是否等于e1 - * 如果e0的hint值等于e1的hint值,则生成e0 == e1的guard - * 反之,则报错 - */ -#define ASSERT_SYMBOL_EQ(e0, e1) \ - do { \ - if (!ge::sym::AssertSymbolEq(e0, e1, __FILE__, __LINE__)) { \ - return ::ErrorResult(); \ - } \ - } while (false) - -/* - * 强校验表达式e0是否不等于e1 - * 如果e0的hint值不等于e1的hint值,则生成e0 != e1的guard - * 反之,则报错 - */ -#define ASSERT_SYMBOL_NE(e0, e1) \ - ASSERT_SYMBOL_CHECK(ge::sym::Ne(e0, e1), __FILE__, __LINE__) - -/* - * 强校验表达式e0是否小于e1 - * 如果e0的hint值小于e1的hint值,则生成e0 < e1的guard - * 反之,则报错 - */ -#define ASSERT_SYMBOL_LT(e0, e1) \ - ASSERT_SYMBOL_CHECK(ge::sym::Lt(e0, e1), __FILE__, __LINE__) - -/* - * 强校验表达式e0是否小于等于e1 - * 如果e0的hint值小于等于e1的hint值,则生成e0 <= e1的guard - * 反之,则报错 - */ -#define ASSERT_SYMBOL_LE(e0, e1) \ - ASSERT_SYMBOL_CHECK(ge::sym::Le(e0, e1), __FILE__, __LINE__) - -/* - * 强校验表达式e0是否大于e1 - * 如果e0的hint值大于e1的hint值,则生成e1 < e0的guard - * 反之,则报错 - */ -#define ASSERT_SYMBOL_GT(e0, e1) \ - ASSERT_SYMBOL_CHECK(ge::sym::Gt(e0, e1), __FILE__, __LINE__) - -/* - * 强校验表达式e0是否大于等于e1 - * 如果e0的hint值大于等于e1的hint值,则生成e1 <= e0的guard - * 反之,则报错 - */ -#define ASSERT_SYMBOL_GE(e0, e1) \ - ASSERT_SYMBOL_CHECK(ge::sym::Ge(e0, e1), __FILE__, __LINE__) - -#define EXPECT_SYMBOL_CHECK(expr, file, line) \ - ge::sym::ExpectSymbolBool(expr, file, line) - -#define ASSERT_SYMBOL_CHECK(expr, file, line) \ - do { \ - if (!ge::sym::AssertSymbolBool(expr, file, line)) { \ - return ::ErrorResult(); \ - } \ - } while (false) - -namespace ge { -class Expression; -namespace sym { -bool ExpectSymbolEq(const Expression &e0, const Expression &e1, - const char_t *file, const int64_t line); -bool AssertSymbolEq(const Expression &e0, const Expression &e1, - const char_t *file, const int64_t line); -bool ExpectSymbolNe(const Expression &e0, const Expression &e1, - const char_t *file, const int64_t line); -bool ExpectSymbolBool(const Expression &expr, - const char_t *file, const int64_t line); -bool AssertSymbolBool(const Expression &expr, - const char_t *file, const int64_t line); -} // namespace sym -} -#endif // GRAPH_SYMBOLIZER_SYMBOL_CHECKER_H_ \ No newline at end of file diff --git a/inc/graph/symbolizer/symbol_operator.h b/inc/graph/symbolizer/symbol_operator.h deleted file mode 100644 index 697e9e5b8390a3dd33ebaba5123cf598653c5db8..0000000000000000000000000000000000000000 --- a/inc/graph/symbolizer/symbol_operator.h +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_SYMBOLIZER_SYMBOL_OPERATOR_H_ -#define GRAPH_SYMBOLIZER_SYMBOL_OPERATOR_H_ - -#include -namespace ge { -class Expression; -namespace sym { -Expression Add(const Expression &a, const Expression &b); -Expression Sub(const Expression &a, const Expression &b); -Expression Mul(const Expression &a, const Expression &b); -Expression Div(const Expression &a, const Expression &b); -Expression Max(const Expression &a, const Expression &b); -Expression Min(const Expression &a, const Expression &b); -Expression Pow(const Expression &base, const Expression &exp); -Expression Mod(const Expression &base, const Expression &exp); -Expression Abs(const Expression &a); -Expression Log(const Expression &a); // 默认以E为底 -Expression Log(const Expression &arg, const Expression &base); -Expression Coeff(const Expression &b, const Expression &x, const Expression &n); -Expression Rational(int32_t num, int32_t den); // 分数 -Expression Ceiling(const Expression &a); -Expression Align(const Expression &arg, uint32_t alignment); -Expression AlignWithPositiveInteger(const Expression &arg, uint32_t alignment); -Expression Floor(const Expression &arg); -Expression Eq(const Expression &a, const Expression &b); // == -Expression Ne(const Expression &a, const Expression &b); // != -Expression Ge(const Expression &a, const Expression &b); // >= -Expression Gt(const Expression &a, const Expression &b); // > -Expression Le(const Expression &a, const Expression &b); // <= -Expression Lt(const Expression &a, const Expression &b); // < -Expression Not(const Expression &a); // ! -Expression Neg(const Expression &a); // 负号 -Expression LogicalAnd(const std::vector &a); -Expression LogicalOr(const std::vector &a); -} // namespace sym -} -#endif // GRAPH_SYMBOLIZER_SYMBOL_OPERATOR_H_ \ No newline at end of file diff --git a/inc/graph/symbolizer/symbolic.h b/inc/graph/symbolizer/symbolic.h deleted file mode 100644 index 494adb585fefb90487f4f523dad2d005503febc4..0000000000000000000000000000000000000000 --- a/inc/graph/symbolizer/symbolic.h +++ /dev/null @@ -1,313 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -// 新的头文件,等air仓头文件切换后,搬移内容到这里 -#ifndef GRAPH_SYMBOLIZER_SYMBOLIC_H_ -#define GRAPH_SYMBOLIZER_SYMBOLIC_H_ - -#include -#include -#include "graph/ge_error_codes.h" -#include "graph/types.h" -#include "graph/type_utils.h" -#include "graph/symbolizer/symbol_operator.h" -#include "graph/symbolizer/symbol_checker.h" -// 不允许添加symbolic_utils.h,后续symbolic.h需要开源 - -namespace ge { -class Expression; -class ExpressionImpl; -class ShapeEnvAttr; -using ExpressionImplPtr = std::unique_ptr; -std::ostream &operator<<(std::ostream &os, const Expression &e); - -enum class ExprType : uint32_t { - kExprConstantInteger = 0, - kExprConstantRealDouble = 1, - kExprConstantRation = 2, - kExprConstantBoolean = 3, - // add const defination here - kExprVariable = 100, - // add variable defination here - kExprOperation = 200, - kExprOperationBoolean, - // add operation defination here - kExprNone = std::numeric_limits::max() -}; - -enum class StrType : size_t { - kStrCpp = 0, - kStrExpr = 1, // kStrExpr和kStrCpp只有在处理除法的时候有区别,例如Div(a,b): kStrExpr返回a/b,kStrCpp返回Rational(a,b) - kStrEnd = 2, -}; - -class Expression { - public: - Expression(); - ~Expression(); - Expression(const Expression &other); - Expression(Expression &&other) noexcept; - Expression &operator=(const Expression &other); - Expression &operator=(Expression &&other) noexcept; - /** - * @brief 获取表达式转换成字符串 - */ - std::unique_ptr Str(const StrType type = StrType::kStrCpp) const; - /** - * @brief 将字符串转换成表达式,与Str接口匹配 - */ - static Expression Parse(const char_t *str); - /** - * @brief 序列化,将表达式转换成字符串 - */ - std::unique_ptr Serialize() const; - - /** - * @brief 反序列化,与Serialize接口匹配,将字符串转换成表达式,同时会校验字符串格式是否为序列化接口序列化出的字符串,如果不是则会报错 - */ - static Expression Deserialize(const char_t *str); - - /** - * @brief 获取表达式的参数,例如表达式 2*s0 + Pow(s1, 2),函数返回[2*s0, Pow(s1, 2)] - */ - std::vector GetArgs(); - /** - * @brief 获取表达式的类型 - */ - ExprType GetExprType() const; - /** - * @brief 是否是ConstExpr类型 - */ - bool IsConstExpr() const; - - /** - * @brief 是否是Symbol类型 - */ - bool IsVariableExpr() const; - - /** - * @brief 是否是Bool类型 - */ - bool IsBooleanExpr() const; - - /** - * @brief 对当前表达式中的表达式进行替换,例如 y= x+2. y.replace({x, 2*x}) -> y = 2*x + 2 - * 注意当前symengine对div sub的表达式替换能力有缺失,需要用户自己保证,例如x/y*z->Replace({{x/y, m}})会替换失败 - * @param pair first为被替换的表达式,second为替换的表达式 - */ - Expression Replace(const std::vector> &replace_vars) const; - - /** - * @brief 对当前表达式中的符号进行替换,如对于表达式expr = x + y,expr.subs({x:2, y:z+1}) -> y + z + 1。于repalce比较功能较单一, - * 只能替换单一符号,无法处理复杂的表达式 - * @param subs_vars 待替换的符号列表,pair中first为被替换的表达式,second为替换的表达式 - * @return 替换后表达式 - */ - Expression Subs(const std::vector> &subs_vars) const; - /** - * @brief 对当前表达式进行化简。例如2+x+4 -> 6+x - */ - Expression Simplify() const; - /** - * @brief 判断当前表达式字符串中是否含有表达式e的子字符串,例如max((x+2), (4*y)) 含有 x和y - */ - bool ContainVar(const Expression &e) const; - - /** - * @brief 判断两个Expr是否相等 - */ - bool operator==(const Expression &e) const; - /** - * @brief 判断一个expr与常量是否相等 - */ - template - typename std::enable_if::value || std::is_floating_point::value, bool>::type - operator==(const T &e) const; - - /** - * @brief 判断两个Expr是否不相等 - */ - bool operator!=(const Expression &e) const; - - /** - * @brief 判断一个expr与常量是否不相等 - */ - template - typename std::enable_if::value || std::is_floating_point::value, bool>::type - operator!=(const T &e) const; - - /** - * @brief 获取表达式最基础的元素。例如x - (y * z),返回{x, y, z}, 注意该接口没有依据字符去重 - */ - std::vector FreeSymbols() const; - - /** - * @brief 获取表达式的值 - */ - graphStatus GetResult(const std::vector> &vars_value, double &result) const; - - /** - * @brief 判断表达式是否合法,成员变量impl_为null则不合法 - */ - bool IsValid() const; - - /** - * @brief 返回一个Expression类对象的hash值,主要目的是用于将Expression对象作为map的key时使用 - */ - uint64_t Hash() const; - - /** - * @brief 分别返回 -1, 0, 1 当 `this < e, this == e, this > e`. - */ - int64_t Compare(const Expression &e) const; - - /** - * @brief 获取常量的值,只有GetExprType为EXPR_CONSTANT时有效 - * @param value 常量的值 - * @return 成功返回true,失败返回false,失败时value的值无效 - */ - template - typename std::enable_if::value || std::is_floating_point::value, bool>::type - GetConstValue(T &value) const; - - /** - * @brief 获取表达式hint值 - * @param hint 获取表达式的hint值 - * @return 成功返回true,失败返回false,失败时value的值无效 - */ - template - typename std::enable_if::value || std::is_floating_point::value, bool>::type - GetHint(T &hint) const { - return ComputeHint(hint); - } - - Expression operator+(const Expression &other) const; - Expression operator-(const Expression &other) const; - Expression operator*(const Expression &other) const; - Expression operator/(const Expression &other) const; - - friend Expression sym::Add(const Expression &a, const Expression &b); - friend Expression sym::Sub(const Expression &a, const Expression &b); - friend Expression sym::Mul(const Expression &a, const Expression &b); - friend Expression sym::Div(const Expression &a, const Expression &b); - friend Expression sym::Max(const Expression &a, const Expression &b); - friend Expression sym::Min(const Expression &a, const Expression &b); - friend Expression sym::Pow(const Expression &base, const Expression &exp); - friend Expression sym::Mod(const Expression &base, const Expression &exp); - friend Expression sym::Abs(const Expression &a); - friend Expression sym::Log(const Expression &a); // 默认以E为底 - friend Expression sym::Log(const Expression &arg, const Expression &base); - friend Expression sym::Coeff(const Expression &b, const Expression &x, const Expression &n); - friend Expression sym::Rational(int32_t num, int32_t den); // 分数 - friend Expression sym::Ceiling(const Expression &a); - friend Expression sym::Floor(const Expression &arg); - friend Expression sym::Align(const Expression &arg, uint32_t alignment); - friend Expression sym::AlignWithPositiveInteger(const Expression &arg, uint32_t alignment); - friend std::ostream &operator<<(std::ostream &os, const Expression &e); - friend Expression sym::Eq(const Expression &a, const Expression &b); // == - friend Expression sym::Ne(const Expression &a, const Expression &b); // != - friend Expression sym::Ge(const Expression &a, const Expression &b); // >= - friend Expression sym::Gt(const Expression &a, const Expression &b); // > - friend Expression sym::Le(const Expression &a, const Expression &b); // <= - friend Expression sym::Lt(const Expression &a, const Expression &b); // < - friend Expression sym::Not(const Expression &a); // ! - friend Expression sym::Neg(const Expression &a); // 负号 - friend Expression sym::LogicalAnd(const std::vector &a); - friend Expression sym::LogicalOr(const std::vector &a); - friend class ShapeEnvAttr; - protected: - explicit Expression(ExpressionImplPtr &&e); - template - typename std::enable_if::value || std::is_floating_point::value, bool>::type - ComputeHint(T &hint) const; - ExpressionImplPtr impl_; - - private: - Expression CanonicalizeBoolExpr() const; -}; - -class Symbol : public Expression { - public: - // 拷贝构造、赋值、移动构造、移动赋值默认使用基类,需要保证Symbol类大小与Expression类大小一致 - /** - * @brief 创建常量 - * @param value 常量的值 - * @param name 常量的名称,默认为空,内部不持有该指针 - */ - explicit Symbol(int32_t value, const char_t *name = ""); - explicit Symbol(int64_t value, const char_t *name = ""); - explicit Symbol(uint32_t value, const char_t *name = ""); - explicit Symbol(uint64_t value, const char_t *name = ""); - explicit Symbol(double value, const char_t *name = ""); - - /** - * @brief 创建变量 - * @param name 变量的名称 - */ - explicit Symbol(const char_t *name = ""); - - /** - * @brief 获取symbol的name,返回值是一个unique_ptr,需要用户自己释放 - */ - std::unique_ptr GetName() const; - friend class ShapeEnvAttr; - private: - explicit Symbol(ExpressionImplPtr &&e); -}; - -template -typename std::enable_if::value || std::is_floating_point::value, bool>::type -Expression::operator==(const T &e) const { - Symbol symbol(e); - return (*this == symbol); -} - -template -typename std::enable_if::value || std::is_floating_point::value, bool>::type -Expression::operator!=(const T &e) const { - Symbol symbol(e); - return !(*this == symbol); -} - -// 为了保证ABI兼容性,禁用虚函数,Symbol的大小必须和Expression一样 -static_assert(sizeof(Symbol) == sizeof(Expression), - "The size of the subclass Symbol must be equal to the size of the base class Expression."); -template <> -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY inline TypeId GetTypeId() { - return reinterpret_cast(1024); -} - -// 目的是构建以Expression作为key值的map、set、与unordered_map -// todoo:后续考虑挪到symbolic_utils.h或symbolic_dict.h中 -struct ExpressionHash { - //! Returns the hashed value. - uint64_t operator()(const Expression &k) const { - return k.Hash(); - } -}; -struct ExpressionKeyEq { - //! Comparison Operator `==` - bool operator()(const Expression &x, const Expression &y) const { - return x == y; - } -}; -struct ExpressionKeyLess { - //! true if `x < y`, false otherwise - bool operator()(const Expression &x, const Expression &y) const { - int64_t xh = x.Hash(); - int64_t yh = y.Hash(); - if (xh != yh) - return xh < yh; - if (x == y) - return false; - return x.Compare(y) == -1; - } -}; -} // namespace ge -#endif \ No newline at end of file diff --git a/inc/graph/symbolizer/symbolic_utils.h b/inc/graph/symbolizer/symbolic_utils.h deleted file mode 100644 index 619768cb515727a6af2cb38b777ec3b911441477..0000000000000000000000000000000000000000 --- a/inc/graph/symbolizer/symbolic_utils.h +++ /dev/null @@ -1,96 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_SYMBOLIZER_SYMBOLIC_UTILS_H_ -#define GRAPH_SYMBOLIZER_SYMBOLIC_UTILS_H_ - -#include -namespace ge { -class Expression; -enum TriBool { kFalse = 0, kTrue = 1, kUnknown = -1 }; - -class SymbolicUtils { - public: - /** - * @param e 符号 - * @brief 序列化,将表达式转换成字符串 - */ - static std::string ToString(const Expression &e); - - /** - * @brief 基于之前生成的guard信息判断e1与e2是否相等,仅基于已有guard做校验,不生产新的guard,主要用于内存优化等编译态优化时判断使用 - * @param e1 表达式1 - * @param e2 表达式2 - * @return TriBool - 三态布尔返回值: - * kTrue: 可确定 e1 == e2 - * kFalse: 可确定 e1 != e2 - * kUnknown: 根据现有guard无法确定大小关系 - */ - static TriBool StaticCheckEq(const Expression &e1, const Expression &e2); - - /** - * @brief 基于之前生成的guard信息判断e1与e2是否不相等,仅基于已有guard做校验,不生产新的guard,主要用于内存优化等编译态优化时判断使用 - * @param e1 表达式1 - * @param e2 表达式2 - * @return TriBool - 三态布尔返回值: - * kTrue: 可确定 e1 != e2 - * kFalse: 可确定 e1 == e2 - * kUnknown: 根据现有guard无法确定大小关系 - */ - static TriBool StaticCheckNe(const Expression &e1, const Expression &e2); - - /** - * @brief 基于之前生成的guard信息判断e1是否小于e2,仅基于已有guard做校验,不生产新的guard,主要用于内存优化等编译态优化时判断使用 - * @param e1 表达式1 - * @param e2 表达式2 - * @return TriBool - 三态布尔返回值: - * kTrue: 可确定 e1 < e2 - * kFalse: 可确定 e1 >= e2 - * kUnknown: 根据现有guard无法确定大小关系 - */ - static TriBool StaticCheckLt(const Expression &e1, const Expression &e2); - - /** - * @brief 基于之前生成的guard信息判断e1是否小于等于e2,仅基于已有guard做校验,不生产新的guard,主要用于内存优化等编译态优化时判断使用 - * @param e1 表达式1 - * @param e2 表达式2 - * @return TriBool - 三态布尔返回值: - * kTrue: 可确定 e1 <= e2 - * kFalse: 可确定 e1 > e2 - * kUnknown: 根据现有guard无法确定大小关系 - */ - static TriBool StaticCheckLe(const Expression &e1, const Expression &e2); - - /** - * @brief 基于之前生成的guard信息判断e1是否大于e2,仅基于已有guard做校验,不生产新的guard,主要用于内存优化等编译态优化时判断使用 - * @param e1 表达式1 - * @param e2 表达式2 - * @return TriBool - 三态布尔返回值: - * kTrue: 可确定 e1 > e2 - * kFalse: 可确定 e1 <= e2 - * kUnknown: 根据现有guard无法确定大小关系 - */ - static TriBool StaticCheckGt(const Expression &e1, const Expression &e2); - - /** - * @brief 基于之前生成的guard信息判断e1是否大于等于e2,仅基于已有guard做校验,不生产新的guard,主要用于内存优化等编译态优化时判断使用 - * @param e1 表达式1 - * @param e2 表达式2 - * @return TriBool - 三态布尔返回值: - * kTrue: 可确定 e1 >= e2 - * kFalse: 可确定 e1 < e2 - * kUnknown: 根据现有guard无法确定大小关系 - */ - static TriBool StaticCheckGe(const Expression &e1, const Expression &e2); - - private: - static TriBool StaticCheckBool(const Expression &expr); -}; -} -#endif // GRAPH_SYMBOLIZER_SYMBOLIC_UTILS_H_ \ No newline at end of file diff --git a/inc/graph/tuning_utils.h b/inc/graph/tuning_utils.h deleted file mode 100644 index d61a1ea57d2c7fae34347f4bcd23867a5255f00c..0000000000000000000000000000000000000000 --- a/inc/graph/tuning_utils.h +++ /dev/null @@ -1,158 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef MAIN_TUNING_UTILS_H -#define MAIN_TUNING_UTILS_H - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "common/ge_common/debug/ge_log.h" -#include "utils/attr_utils.h" -#include "utils/node_utils.h" -#include "external/ge_common/ge_api_types.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/tensor_utils.h" -namespace ge { -// Configure build mode, default value is "normal" -constexpr char_t BUILD_MODE[] = "ge.buildMode"; -constexpr char_t BUILD_STEP[] = "ge.buildStep"; -// Configure tuning path -constexpr char_t TUNING_PATH[] = "ge.tuningPath"; -// for interface: aclgrphBuildModel -extern const std::set ir_builder_supported_options_for_lx_fusion; - -// Build model -constexpr char_t BUILD_MODE_NORMAL[] = "normal"; -constexpr char_t BUILD_MODE_TUNING[] = "tuning"; -constexpr char_t BUILD_MODE_BASELINE[] = "baseline"; -constexpr char_t BUILD_MODE_OPAT_RESULT[] = "opat_result"; -extern const std::set build_mode_options; - -// Build step -constexpr char_t BUILD_STEP_BEFORE_UB_MATCH[] = "before_ub_match"; -constexpr char_t BUILD_STEP_AFTER_UB_MATCH[] = "after_ub_match"; -constexpr char_t BUILD_STEP_AFTER_BUILDER[] = "after_builder"; -constexpr char_t BUILD_STEP_AFTER_BUILDER_SUB[] = "after_builder_sub"; -constexpr char_t BUILD_STEP_AFTER_MERGE[] = "after_merge"; -constexpr char_t BUILD_STEP_BEFORE_BUILD[] = "before_build"; -constexpr char_t BUILD_STEP_AFTER_BUILD[] = "after_build"; -extern const std::set build_step_options; - -using SubgraphCreateOutNode = std::unordered_map; -using NodetoNodeMap = std::unordered_map; -using NodeVec = std::vector; -using NodeNametoNodeNameMap = std::map; -using NodetoNodeNameMap = std::unordered_map; -class TuningUtils { - public: - TuningUtils() = default; - ~TuningUtils() = default; - // Dump all the subgraphs and modify - // the subgraphs in them to be executable subgraphs if exe_flag is true - // `tuning_path` means path to save the graphs - static graphStatus ConvertGraphToFile(std::vector tuning_subgraphs, - std::vector non_tuning_subgraphs = {}, - const bool exe_flag = false, - const std::string &path = "", - const std::string &user_path = ""); - // Recovery `graph` from graph dump files configured in options - static graphStatus ConvertFileToGraph(const std::map &options, ge::Graph &graph); - - static graphStatus LinkSubgraph(ComputeGraphPtr &root_graph, const ComputeGraphPtr &graph, - const std::map &name_to_merged_subgraph); - -private: - // part 1 - class HelpInfo { - HelpInfo(const int64_t index, const bool exe_flag, const bool is_tuning_graph, const std::string &path, - const std::string &user_path) : index_(index), - exe_flag_(exe_flag), - is_tuning_graph_(is_tuning_graph), - path_(path), - user_path_(user_path) {} - ~HelpInfo() = default; - private: - int64_t index_; - bool exe_flag_; - bool is_tuning_graph_; - const std::string &path_; - const std::string &user_path_; - bool need_preprocess_ = false; - friend class TuningUtils; - }; - static graphStatus MakeExeGraph(ComputeGraphPtr &exe_graph, - const HelpInfo& help_info); - static graphStatus ConvertConstToWeightAttr(const ComputeGraphPtr &exe_graph); - static graphStatus SetFileConstInfo(const NodePtr &node, const GeTensorPtr &tensor, const std::string &aoe_path, - const OpDescPtr &op_desc); - static graphStatus HandlePld(NodePtr &node, const std::string &aoe_path); - static graphStatus HandleConst(NodePtr &node, const std::string &aoe_path); - static graphStatus PreProcessNode(const NodePtr &node); - static graphStatus HandleEnd(NodePtr &node); - static graphStatus ChangePld2Data(const NodePtr &node, const NodePtr &data_node); - static graphStatus ChangeEnd2NetOutput(NodePtr &end_node, NodePtr &out_node); - static graphStatus LinkEnd2NetOutput(NodePtr &end_node, NodePtr &out_node); - static graphStatus CreateDataNode(NodePtr &node, const std::string &aoe_path, NodePtr &data_node); - static graphStatus CreateNetOutput(const NodePtr &node, NodePtr &out_node); - static graphStatus AddAttrToDataNodeForMergeGraph(const NodePtr &pld, const NodePtr &data_node); - static graphStatus AddAttrToNetOutputForMergeGraph(const NodePtr &end, const NodePtr &out_node, const int64_t index); - static void DumpGraphToPath(const ComputeGraphPtr &exe_graph, const int64_t index, - const bool is_tuning_graph, std::string path); - static void TryGetWeight(const NodePtr &node, std::vector &weight); - - static SubgraphCreateOutNode create_output_; - // part 2 - static graphStatus MergeGraph(const std::vector &subgraphs, - ComputeGraphPtr &output_merged_compute_graph); - static graphStatus MergeAllSubGraph(const std::vector &subgraphs, - ComputeGraphPtr &output_merged_compute_graph); - static graphStatus MergeSubGraph(const ComputeGraphPtr &subgraph); - // Deletes new data and output nodes added by call `MakeExeGraph()` func in part 1 - static graphStatus RemoveDataNetoutputEdge(ComputeGraphPtr &graph); - static NodePtr FindNode(const std::string &name, int64_t &in_index); - static graphStatus LoadGraphFromFile(const std::map &options, - std::vector &root_graphs, - std::map> &name_to_subgraphs); - static NodeNametoNodeNameMap data_2_end_; - static NodetoNodeNameMap data_node_2_end_node_; - static NodetoNodeMap data_node_2_netoutput_node_; - static NodeVec netoutput_nodes_; - static NodeVec merged_graph_nodes_; - static std::mutex mutex_; - static std::set reusable_weight_files_; - static std::map name_to_index_; - static std::map> hash_to_files_; - // for debug - static std::string PrintCheckLog(); - static std::string GetNodeNameByAnchor(const Anchor * const anchor); - static std::string GenerateFileConstPath(const std::string &aoe_path, const OpDescPtr &op_desc); - static Status GetOrSaveReusableFileConst(const GeTensorPtr &tensor, std::string &file_path); - static Status CheckFilesSame(const std::string &file_name, const char_t *const data, const size_t data_length, - bool &is_content_same); -}; -} -#endif // MAIN_TUNING_UTILS_H diff --git a/inc/graph/utils/anchor_utils.h b/inc/graph/utils/anchor_utils.h deleted file mode 100644 index 2c63a98e8dece1db80314762346e642cdfaa8edc..0000000000000000000000000000000000000000 --- a/inc/graph/utils/anchor_utils.h +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_UTILS_ANCHOR_UTILS_H_ -#define INC_GRAPH_UTILS_ANCHOR_UTILS_H_ - -#include "graph/anchor.h" -#include "graph/node.h" - -namespace ge { -class AnchorUtils { - public: - // Get anchor status - static AnchorStatus GetStatus(const DataAnchorPtr &data_anchor); - - // Set anchor status - static graphStatus SetStatus(const DataAnchorPtr &data_anchor, const AnchorStatus anchor_status); - - static int32_t GetIdx(const AnchorPtr &anchor); -}; -} // namespace ge -#endif // INC_GRAPH_UTILS_ANCHOR_UTILS_H_ diff --git a/inc/graph/utils/axis_utils.h b/inc/graph/utils/axis_utils.h deleted file mode 100644 index 510b05efed4d9e0ba2ac291e2651814b65877ec9..0000000000000000000000000000000000000000 --- a/inc/graph/utils/axis_utils.h +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ -#ifndef METADEF_CXX_INC_GRAPH_UTILS_AXIS_UTILS_H_ -#define METADEF_CXX_INC_GRAPH_UTILS_AXIS_UTILS_H_ -#include -#include "graph/symbolizer/symbolic.h" -#include "ascendc_ir/ascendc_ir_core/ascendc_ir.h" - -namespace ge { -class AxisUtils { - public: - static View ReduceView(const View &src_view, int64_t reduce_axis); - static View ReorderView(const View &src_view, const std::vector &my_api_sched_axes); - - static View SplitView(const View &src_view, const ge::Expression &split_size, const int64_t outter_id, - const int64_t inner_id, const int64_t original_id); - static View MergeView(const View &src_view, const int64_t merged_axis_id, - const std::vector &original); - static std::pair UpdateViewIfCrossLoop(const TransInfoRoadOfGraph &trans_info_road_of_graph, - const std::vector &input_api_sched_axes, - const std::vector &my_api_sched_axes, - View &&tensor_view_to_update); - static std::vector GetDefaultVectorizedAxis(const std::vector &tensor_axis, int64_t loop_axis); -}; -} // namespace ge -#endif // METADEF_CXX_INC_GRAPH_UTILS_AXIS_UTILS_H_ diff --git a/inc/graph/utils/cg_utils.h b/inc/graph/utils/cg_utils.h deleted file mode 100644 index 76441b84dc1887dcfff5062462093d59f04a8d4a..0000000000000000000000000000000000000000 --- a/inc/graph/utils/cg_utils.h +++ /dev/null @@ -1,204 +0,0 @@ -/** - * Copyright (c) Huawei Technologies Co., Ltd. 2024 All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef AUTOFUSE_CG_UTILS_H -#define AUTOFUSE_CG_UTILS_H -#include -#include "ascendc_ir/ascend_reg_ops.h" -#include "ascendc_ir/ascendc_ir_core/ascendc_ir.h" -#include "utils/axis_utils.h" -#include "utils/node_utils_ex.h" -#include "utils/dtype_transform_utils.h" - -namespace ge { -namespace ascir { -using ge::Axis; -using ge::AxisId; -namespace cg { -constexpr char RELATED_OP[] = "RelatedOp"; -#define THROW(condition) if (!(condition)) throw std::runtime_error("Check Failed: " #condition) -struct LoopOption { - bool pad_tensor_axes_to_loop; -}; -class CgContext { - public: - static CgContext *GetThreadLocalContext(); - static std::shared_ptr GetSharedThreadLocalContext(); - static void SetThreadLocalContext(const std::shared_ptr &context); - - void SetOption(const LoopOption &option); - const LoopOption &GetOption() const; - - const std::vector &GetLoopAxes() const; - const std::vector &GetLoopAxisIds() const; - void SetLoopAxes(std::vector axes); - void PushLoopAxis(const Axis &axis); - void PopBackLoopAxis(const Axis &axis); - - void SetBlockLoopEnd(AxisId id); - AxisId GetBlockLoopEnd() const; - - void SetVectorizedLoopEnd(AxisId id); - AxisId GetVectorizedLoopEnd() const; - - void SetLoopEnd(AxisId id); - AxisId GetLoopEnd() const; - - private: - LoopOption option_; - std::vector loop_axes_; - std::vector loop_axis_ids_cache_; // 与 loop_axes_ 同源,避免反复创建 - - AxisId block_loop_end_{0}; - AxisId vectorized_loop_end_{0}; - AxisId loop_end_{0}; -}; - -class LoopGuard { - public: - explicit LoopGuard(const Axis &axis); - ~LoopGuard(); - - static std::unique_ptr Create(const Axis &axis) { - return Create(axis, {}); - } - - static std::unique_ptr Create(const Axis &axis, const LoopOption &option); - - private: - Axis axis_; - std::shared_ptr context_; -}; -using Axes = std::vector; - -class BlockLoopGuard { - explicit BlockLoopGuard(std::vector axes); - ~BlockLoopGuard(); -}; - -class VectorizedLoopGuard { - explicit VectorizedLoopGuard(std::vector axes); - ~VectorizedLoopGuard(); -}; - -#define INNER_LOOP_COUNTER_1(counter, axis) \ - for (auto guarder_##counter = ascir::cg::LoopGuard::Create(axis); guarder_##counter != nullptr; \ - guarder_##counter = nullptr) -#define INNER_LOOP_COUNTER(counter, axis) INNER_LOOP_COUNTER_1(counter, axis) -#define LOOP(axis) INNER_LOOP_COUNTER(__COUNTER__, axis) - -#define OPTION_LOOP_COUNTER_1(counter, axis, option) \ - for (auto guarder_##counter = ascir::cg::LoopGuard::Create(axis, option); guarder_##counter != nullptr; \ - guarder_##counter = nullptr) -#define OPTION_LOOP_COUNTER(counter, axis, option) OPTION_LOOP_COUNTER_1(counter, axis, option) -#define OPTION_LOOP(axis, option) OPTION_LOOP_COUNTER(__COUNTER__, axis, option) - -#define SET_SCHED_AXIS_IF_IN_CONTEXT(op) \ - do { \ - auto context = ascir::cg::CgContext::GetThreadLocalContext(); \ - if (context != nullptr) { \ - (op).attr.sched.axis = (context)->GetLoopAxisIds(); \ - if (!((op).attr.sched.axis.empty())) { \ - (op).attr.sched.loop_axis = ((op).attr.sched.axis.back()); \ - } \ - } \ - } while (0) - -class CodeGenUtils { - public: - static int64_t GenNextExecId(const ge::Operator &op); - static int64_t GenNextExecId(const ge::AscGraph &graph); - static int64_t GenNextTensorId(const ge::Operator &op); - static int64_t GenNextContainerId(const ge::Operator &op); - static int64_t GenNextReuseId(const ge::Operator &op); - static AscGraphAttr *GetOwnerGraphAscAttr(const Operator &op) { - const auto &node = NodeUtilsEx::GetNodeFromOperator(op); - GE_ASSERT_NOTNULL(node, "Node is null."); - const auto &compute_graph = node->GetOwnerComputeGraph(); - GE_ASSERT_NOTNULL(compute_graph, "Compute graph is null."); - - auto attr = compute_graph->GetOrCreateAttrsGroup(); - GE_ASSERT_NOTNULL(attr, "AscGraphAttr is null."); - return attr; - } - - static AscNodeAttr *GetOwnerOpAscAttr(const Operator &op) { - const auto &op_desc = OpDescUtils::GetOpDescFromOperator(op); - GE_ASSERT_NOTNULL(op_desc, "op_desc is null."); - return op_desc->GetOrCreateAttrsGroup(); - } - private: - static int64_t GenNextExecId(const ge::ComputeGraphPtr &graph); -}; - -inline bool PadOutputViewToSched(ge::AscOpOutput &output) { - auto context = ascir::cg::CgContext::GetThreadLocalContext(); - if (context == nullptr || !context->GetOption().pad_tensor_axes_to_loop) { - return true; - } - - // check if need pad - auto &sched_ids = context->GetLoopAxisIds(); - const auto &origin_axis_ids = output.axis; - if (origin_axis_ids->size() == sched_ids.size()) { - return *origin_axis_ids == sched_ids; - } - - // calc pad indexes, if op_i not iter to the end, means the axis order in tensor is different from sched - // max: pad, positive: index of origin_axis_ids - std::vector indexes; - size_t op_i = 0U; - for (auto sched_axis_id : sched_ids) { - if (op_i < origin_axis_ids->size() && sched_axis_id == (*origin_axis_ids).at(op_i)) { - indexes.push_back(op_i++); - } else { - indexes.push_back(std::numeric_limits::max()); - } - } - if (op_i != origin_axis_ids->size()) { - return false; - } - - // do pad - const auto &origin_repeats = output.repeats; - const auto &origin_strides = output.strides; - std::vector padded_axis_ids; - std::vector padded_repeats; - std::vector padded_strides; - for (size_t i = 0U; i < indexes.size(); ++i) { - op_i = indexes[i]; - if (op_i == std::numeric_limits::max()) { - padded_axis_ids.push_back(sched_ids.at(i)); - padded_repeats.push_back(sym::kSymbolOne); - padded_strides.push_back(sym::kSymbolZero); - } else { - padded_axis_ids.push_back((*origin_axis_ids).at(op_i)); - padded_repeats.push_back((*origin_repeats).at(op_i)); - padded_strides.push_back((*origin_strides).at(op_i)); - } - } - - *output.axis = padded_axis_ids; - *output.repeats = padded_repeats; - *output.strides = padded_strides; - - return true; -} -} // namespace cg -} // namespace ascir -} - -#endif // AUTOFUSE_CG_UTILS_H diff --git a/inc/graph/utils/connection_matrix.h b/inc/graph/utils/connection_matrix.h deleted file mode 100644 index bbcfd2979c0bf2acabd189e3fd2a3a80847f812e..0000000000000000000000000000000000000000 --- a/inc/graph/utils/connection_matrix.h +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_CONNECTION_MATRIX_H_ -#define GRAPH_CONNECTION_MATRIX_H_ - -#include "graph/node.h" -#include "graph/graph.h" -#include "graph/compute_graph.h" - -namespace ge { -class ConnectionMatrixImpl; -using ConnectionMatrixImplPtr = std::shared_ptr; - -class ConnectionMatrix { -public: - explicit ConnectionMatrix(const ComputeGraphPtr &graph); - ~ConnectionMatrix() = default; - - bool IsConnected(const NodePtr &a, const NodePtr &b) const; - - // inputs are all input nodes of parameter node. - // if there is a path between A->B, then B will own A's - // connectivity. The reason is --- - // If some node can reach A, than it can also reach B. - void SetConnectivity(const Node::Vistor &inputs, const NodePtr &node); - - /* Computes the connectivity between two nodes in the - * computation. The returned ConnectivityMatrix is constructed such that - * ConnectivityMatrix::IsConnected(a, b) returns true iff there exists a - * directed path (from producer to consumer) from 'a' to 'b'. Both data - * connection and control connection are considered for connectivity. - * A node is connected to itself. */ - graphStatus Generate(const ComputeGraphPtr &graph); - - // update reachablity map for fused nodes. - void Update(const ComputeGraphPtr &graph, const std::vector &fusion_nodes); - - void ExpandAndUpdate(const vector &fusion_nodes, const std::string &node_name); -private: - ConnectionMatrixImplPtr impl_{nullptr}; -}; -} -#endif // GRAPH_CONNECTION_MATRIX_H_ diff --git a/inc/graph/utils/constant_utils.h b/inc/graph/utils/constant_utils.h deleted file mode 100644 index fdf423af7d4b66dc6f5781d30e69837404ccce2e..0000000000000000000000000000000000000000 --- a/inc/graph/utils/constant_utils.h +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef COMMON_GRAPH_UTILS_CONSTANT_UTILS_H_ -#define COMMON_GRAPH_UTILS_CONSTANT_UTILS_H_ -#include "graph/node.h" -#include "graph/op_desc.h" - -namespace ge { -class ConstantUtils { - public: - // check is constant - static bool IsConstant(const NodePtr &node); - static bool IsConstant(const OpDescPtr &op_desc); - static bool IsPotentialConst(const OpDescPtr &op_desc); - static bool IsRealConst(const OpDescPtr &op_desc); - // get/set weight - static bool GetWeight(const OpDescPtr &op_desc, const uint32_t index, ConstGeTensorPtr &weight); - static bool MutableWeight(const OpDescPtr &op_desc, const uint32_t index, GeTensorPtr &weight); - static bool SetWeight(const OpDescPtr &op_desc, const uint32_t index, const GeTensorPtr weight); - static bool MarkPotentialConst(const OpDescPtr &op_desc, const std::vector indices, - const std::vector weights); - static bool UnMarkPotentialConst(const OpDescPtr &op_desc); - // for fileconstant - static bool GetWeightFromFile(const OpDescPtr &op_desc, ConstGeTensorPtr &weight); - private: - static bool GetPotentialWeight(const OpDescPtr &op_desc, std::vector &weight_indices, - std::vector &weights); - static bool MutablePotentialWeight(const OpDescPtr &op_desc, std::vector &weight_indices, - std::vector &weights); -}; -} - -#endif // COMMON_GRAPH_UTILS_CONSTANT_UTILS_H_ diff --git a/inc/graph/utils/cycle_detector.h b/inc/graph/utils/cycle_detector.h deleted file mode 100644 index 5bc86e68aef16ae275bbc183f24b3562ebf37d56..0000000000000000000000000000000000000000 --- a/inc/graph/utils/cycle_detector.h +++ /dev/null @@ -1,70 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_CYCLE_DETECTOR_H_ -#define GRAPH_CYCLE_DETECTOR_H_ - -#include "graph/node.h" -#include "graph/compute_graph.h" -#include "connection_matrix.h" - -namespace ge { -class CycleDetector { - friend class GraphUtils; -public: - CycleDetector() = default; - ~CycleDetector() = default; - /* Detect whether there are cycles in graph - * after fusing all nodes in param fusion_nodes. - * Before call this func, you should call GenerateConnectionMatrix frist - * to generate connection_matrix based on current graph. - * - * Compared with Cycle Detection - * @param fusion_nodes: each vector in fusion_nodes - * will be fused into an entity(which could contains - * more than one node). The caller should put all original - * nodes which are expected to be fused into one larger node - * into each sub-vector of fusion_nodes. - * - * This function can tell whether there are a cycle after - * fusing all nodes in fusion_nodes. Each vector in 2-d - * vector fusion_nodes will be fused into an entity. - * - * - * This interface cannot detect whether there are cycles - * inside the fused nodes. - * - * e.g. {a, b, c, d} -> {e, f} - * Because the edge information is not given for e and f - * so this function we cannot tell if e and f are in a - * cycle. - * */ - bool HasDetectedCycle(const std::vector> &fusion_nodes); - - /** - * Update connection matrix based on graph. - * Connection matrix is served for cycle detection. - * - * The first param graph, it should be the same one graph when contribue cycle_detector - */ - void Update(const ComputeGraphPtr &graph, const std::vector &fusion_nodes); - - /** - * Expand dim and update connection matrix based on graph. - */ - void ExpandAndUpdate(const vector &fusion_nodes, const std::string &node_name); -private: - graphStatus Init(const ComputeGraphPtr &graph); - std::unique_ptr connectivity_{nullptr}; -}; - -using CycleDetectorPtr = std::unique_ptr; -using CycleDetectorSharedPtr = std::shared_ptr; -} -#endif // GRAPH_CYCLE_DETECTOR_H_ diff --git a/inc/graph/utils/dtype_transform_utils.h b/inc/graph/utils/dtype_transform_utils.h deleted file mode 100644 index a499754de0db301548fe552656a6aec45a0265c1..0000000000000000000000000000000000000000 --- a/inc/graph/utils/dtype_transform_utils.h +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ -#ifndef METADEF_CXX_INC_GRAPH_UTILS_DTYPE_TRANSFORM_UTILS_H_ -#define METADEF_CXX_INC_GRAPH_UTILS_DTYPE_TRANSFORM_UTILS_H_ -#include "graph/types.h" -class DtypeTransformUtils { - public: - static ge::DataType Prompt(ge::DataType src_type); -}; -#endif // METADEF_CXX_INC_GRAPH_UTILS_DTYPE_TRANSFORM_UTILS_H_ diff --git a/inc/graph/utils/enum_attr_utils.h b/inc/graph/utils/enum_attr_utils.h deleted file mode 100644 index 449e10e6284ac4f092023bd09f5471a2847566c0..0000000000000000000000000000000000000000 --- a/inc/graph/utils/enum_attr_utils.h +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef __INC_METADEF_ENUM_ATTR_UTILS_H -#define __INC_METADEF_ENUM_ATTR_UTILS_H - -#include -#include "common/ge_common/util.h" -#include "graph/ge_error_codes.h" -#include "graph/ge_tensor.h" - -namespace ge { -using namespace std; -constexpr uint16_t kMaxValueOfEachDigit = 127U; -constexpr size_t kAppendNum = 1U; -constexpr char_t prefix = '\0'; - -class EnumAttrUtils { - public: - static void GetEnumAttrName(vector &enum_attr_names, const string &attr_name, string &enum_attr_name, - bool &is_new_attr); - static void GetEnumAttrValue(vector &enum_attr_values, const string &attr_value, int64_t &enum_attr_value); - static void GetEnumAttrValues(vector &enum_attr_values, const vector &attr_values, - vector &enum_values); - - static graphStatus GetAttrName(const vector &enum_attr_names, const vector name_use_string_values, - const string &enum_attr_name, string &attr_name, bool &is_value_string); - static graphStatus GetAttrValue(const vector &enum_attr_values, const int64_t enum_attr_value, - string &attr_value); - static graphStatus GetAttrValues(const vector &enum_attr_values, const vector &enum_values, - vector &attr_values); - private: - static void Encode(const uint32_t src, string &dst); - static void Decode(const string &src, size_t &dst); -}; -} // namespace ge -#endif // __INC_METADEF_ENUM_ATTR_UTILS_H diff --git a/inc/graph/utils/execute_graph_adapter.h b/inc/graph/utils/execute_graph_adapter.h deleted file mode 100644 index 486e6af646116bba3f1d290c2c5ec4ab6356434b..0000000000000000000000000000000000000000 --- a/inc/graph/utils/execute_graph_adapter.h +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_UTILS_EXECUTE_GRAPH_ADAPTER_H -#define INC_GRAPH_UTILS_EXECUTE_GRAPH_ADAPTER_H - -#include "graph/fast_graph/execute_graph.h" -#include "external/graph/ge_error_codes.h" - -namespace ge { -class ExecuteGraphAdapter { - public: - ~ExecuteGraphAdapter() = default; - ExecuteGraphAdapter(const ExecuteGraphAdapter &adapter) = delete; - ExecuteGraphAdapter &operator=(const ExecuteGraphAdapter &adapter) = delete; - - // 返回的ComputeGraph复用了原图src_graph的OpDesc对象,返回后src_graph不能释放 - // src_graph和返回的ComputeGraph的生命周期需要保证一致 - static ComputeGraphPtr ConvertExecuteGraphToComputeGraph(ExecuteGraph *src_graph); - - private: - ExecuteGraphAdapter() = default; - static graphStatus ConvertExecuteGraphToComputeGraph(ExecuteGraph *src_graph, const ComputeGraphPtr &dst_graph, - const int32_t depth); - static graphStatus CopyOpAndSubgraph(ExecuteGraph *src_graph, const ComputeGraphPtr &dst_graph, - std::unordered_map &all_new_nodes, const int32_t depth); - static graphStatus RelinkGraphEdges(FastNode *old_node, - const std::unordered_map &all_new_nodes); - static graphStatus CopyMembers(ExecuteGraph *src_graph, const ComputeGraphPtr &dst_graph, - const std::unordered_map &all_new_nodes); - static void InheritOriginalAttr(ExecuteGraph *src_graph, const ComputeGraphPtr &dst_graph); -}; -} // namespace ge -#endif // INC_GRAPH_UTILS_EXECUTE_GRAPH_ADAPTER_H diff --git a/inc/graph/utils/execute_graph_utils.h b/inc/graph/utils/execute_graph_utils.h deleted file mode 100644 index 3cb48848562889db3c81145e45532a6a6e29cbe6..0000000000000000000000000000000000000000 --- a/inc/graph/utils/execute_graph_utils.h +++ /dev/null @@ -1,195 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_UTILS_EXECUTE_GRAPH_UTILS_H -#define INC_GRAPH_UTILS_EXECUTE_GRAPH_UTILS_H -#include "graph/fast_graph/fast_node.h" -#include "graph/fast_graph/execute_graph.h" - -namespace ge { -class ExecuteGraphUtils { - public: - /** - * 查找`exe_graph`的根图,如果当前图就是根图或者当前图没有父图,则返回当前图 - * @param exe_graph - * @return - */ - static ExecuteGraph *FindRootGraph(ExecuteGraph *exe_graph); - - /** - * 查找`exe_graph`中节点名为`name`的节点,包含子图 - * @param exe_graph - * @return - */ - static FastNode *FindNodeFromAllNodes(ExecuteGraph *exe_graph, const char_t *const name); - - /** - * 查找`exe_graph`中节点类型为`type`的节点,包含子图 - * @param exe_graph - * @return - */ - static std::vector FindNodesByTypeFromAllNodes(ExecuteGraph *exe_graph, const char_t *const type); - - /** - * 查找`exe_graph`中首个节点类型为`type`的节点,不包含子图 - * @param exe_graph - * @param type - * @return - */ - static FastNode *FindFirstNodeMatchType(ExecuteGraph *exe_graph, const char_t *const type); - - /** - * 接口行为是在'src'的源节点输出端和'dst'目的节点输入端们之间插入一个`insert_node`节点, - * 默认是`insert_node`的`0`号数据输入端和`0`号输出端参与连边,`insert_node`插入之后, `src_node`和`insert_node` - * 作为一个整体与原来的`src_node`具备等价的控制和数据关系 - * @param src 源数据输出端 - * @param dsts 源数据输出端连接的目的数据输入端,使用vector的原因是存在一个源节点输出端给到多个目的节点输入端的情况 - * @param insert_node 表示要插入的节点 - * @param input_index 表示插入节点的哪个数据输入端要跟src相连,如果不传递,默认取0 - * @param output_index 表示插入节点的哪个数据输出端要跟dsts依次相连,如果不传递,默认取0 - * @return 如果插入成功返回GRAPH_SUCCESS,失败返回GRAPH_FAILED - */ - static graphStatus InsertNodeAfter(const EdgeSrcEndpoint &src, const std::vector &dsts, - FastNode *insert_node, const uint32_t input_index = 0U, - const uint32_t output_index = 0U); - - /** - * 接口行为是在数据`dst`目的节点输入端和其对端源节点输出端之间插入一个`insert_node`节点, - * 默认是`insert_node`的`0`号数据输入端和`0`号数据输出数据端参与连边,`insert_node`插入之后, - * `dst_node`和`insert_node`作为一个整体与原来的`dst_node`具备等价的控制和数据关系 - * @param dst 目的数据输入端 - * @param insert_node 表示要插入的节点 - * @param input_index 表示插入节点的哪个数据输入端要跟dst的对端src输出端相连,如果不传递,默认取0 - * @param output_index 表示插入节点的哪个数据输出端要跟dst相连,如果不传递,默认取0 - * @return 如果插入成功返回GRAPH_SUCCESS,失败返回GRAPH_FAILED - */ - static graphStatus InsertNodeBefore(const EdgeSrcEndpoint &dst, FastNode *insert_node, - const uint32_t input_index = 0U, const uint32_t output_index = 0U); - - /** - * 移动`src_node`的控制输入边到`dst_node`上 - * @param src_node - * @param dst_node - * @return - */ - static graphStatus MoveInCtrlEdges(const FastNode *src_node, FastNode *dst_node); - - /** - * 移动`src_node`的控制输出边到`dst_node`上 - * @param src_node - * @param dst_node - * @return - */ - static graphStatus MoveOutCtrlEdges(const FastNode *src_node, FastNode *dst_node); - - /** - * 将node所有的输入、输出边断开,并移动到dst_graph - * @param dst_graph 目的Graph, - * @param node 需要移动的Node - * @return 成功时,返回ge::GRAPH_SUCCESS - */ - static graphStatus MoveNodeToGraph(FastNode *node, ExecuteGraph *dst_graph); - - /** - * 接口行为是根据`inputs_map`和`outputs_map`把`old_node`上的数据关系`移动`到`new_node`上;具体操作是 - * 把`old_node`的第`inputs_map[i]`/`outputs_map[i]`个数据输入、输出端点的数据关系替换到`new_node`的第`i`个 - * 输入、输出端点上, `i`的取值范围是[0, `inputs_map`/`outputs_map`的元素个数); 如果`inputs_map[i]`/`outputs_map[i]` - * 的值小于0或者不在`old_node`的输入、输出端点范围之内,那么`new_node`的第`i`个数据输入、输出端点的数据关系保持原样 - * @param new_node - * @param old_node - * @param inputs_map 用于指导输入数据端点的替换,注意元素个数不应该超过`new_node`的输入端点总个数 - * @param outputs_map 用于指导输出端点的替换,注意元素个数不应该超过`new_node`的输出端点总个数 - * @param graph 表示`new_node`和`old_node`所在的graph,如果不传会通过`new_node`获取 - * @return - */ - static graphStatus ReplaceNodeDataEdges(FastNode *new_node, FastNode *old_node, - const std::initializer_list inputs_map, - const std::initializer_list outputs_map, - ExecuteGraph *graph = nullptr); - static graphStatus ReplaceNodeDataEdges(FastNode *new_node, FastNode *old_node, - const std::vector &inputs_map, - const std::vector &outputs_map, ExecuteGraph *graph = nullptr); - - /** - * 此接口对数据关系的处理与`ReplaceNodeDataEdges`的处理行为一致, 在此基础上, - * 复制了`old_node`的所有控制关系到`new_node`上,这也是要注意的一点: - * `数据`关系是`移动`操作,`控制`关系是`复制`操作 - * @param new_node - * @param old_node - * @param inputs_map 用于指导输入数据锚点的替换,注意元素个数不应该超过`new_node`的输入短点总个数 - * @param outputs_map 用于指导输出锚点的替换,注意元素个数不应该超过`new_node`的输出短点总个数 - * @return 如果替换成功返回GRAPH_SUCCESS,失败返回GRAPH_FAILED - */ - static graphStatus ReplaceNodeEdges(FastNode *new_node, FastNode *old_node, - const std::initializer_list inputs_map, - const std::initializer_list outputs_map); - static graphStatus ReplaceNodeEdges(FastNode *new_node, FastNode *old_node, const std::vector &inputs_map, - const std::vector &outputs_map); - - /** - * 孤立`node`, 根据`io_map`完成node的输入输出数据边的重连;同时会添加必要的控制边保证`node`的所有输入节点 - * 均在`node`的输出节点之前执行 - * @param node - * @param io_map 把第`io_map[i]`个输入的对端输出,连接到第`i`个输出的对端输入。因此`io_map`的元素个数应该与 - * `node`的输出端点的个数相等,如果`io_map[i]`小于0,则仅断开第`i`个输出端点到对端的所有连边 - * @return - */ - static graphStatus IsolateNode(FastNode *node, const std::initializer_list &io_map); - static graphStatus IsolateNode(FastNode *node, const std::vector &io_map); - - /** - * 替换`old_edge`的``src`为`new_src`指示的`node`和`index` - * @param old_edge - * @param new_src - * @return 替换成功返回GRAPH_SUCCESS, 替换失败返回GRAPH_FAILED - */ - static graphStatus ReplaceEdgeSrc(FastEdge *old_edge, const EdgeSrcEndpoint &new_src); - - /** - * 从`execute_graph`上删除`直接`或者`间接`父节点为remove_node的所有子图对象 - * @param execute_graph - * @param remove_node - * @return 成功返回GRAPH_SUCCESS, 失败返回GRAPH_FAILED - */ - static graphStatus RemoveSubgraphRecursively(ExecuteGraph *execute_graph, FastNode *remove_node); - - /** - * 从`execute_graph`中删除`node`对象的所有关系,包括子图关系,从属关系,作为`execute_graph`的输入,输出的关系; - * 仅删除,不进行断边连边,不保证删除后节点前后的控制关系传递 - * @param execute_graph - * @param node - * @return 成功返回GRAPH_SUCCESS, 失败返回GRAPH_FAILED - */ - static graphStatus RemoveNodeWithoutRelink(ExecuteGraph *execute_graph, FastNode *node); - - /** - * 拷贝`src_node`的输入控制边到`dst_node`上 - * @param src_node - * @param dst_node - * @return 成功返回GRAPH_SUCCESS, 失败返回GRAPH_FAILED - */ - static graphStatus CopyInCtrlEdges(const FastNode *src_node, FastNode *dst_node); - - /** - * 拷贝`src_node`的输出控制边到`dst_node`上 - * @param src_node - * @param dst_node - * @return 成功返回GRAPH_SUCCESS, 失败返回GRAPH_FAILED - */ - static graphStatus CopyOutCtrlEdges(const FastNode *src_node, FastNode *dst_node); - - /** - * 构建并返回根图中所有节点名到节点的映射,包含子图中的节点 - * @param exe_graph - * @return 节点名到节点的映射 - */ - static std::unordered_map GetNodeMapFromAllNodes(ExecuteGraph *exe_graph); -}; -} // namespace ge -#endif // INC_GRAPH_UTILS_EXECUTE_GRAPH_UTILS_H diff --git a/inc/graph/utils/fast_node_utils.h b/inc/graph/utils/fast_node_utils.h deleted file mode 100644 index d42b8f58cdccb77b8ead5d6a150048c73ea35dbb..0000000000000000000000000000000000000000 --- a/inc/graph/utils/fast_node_utils.h +++ /dev/null @@ -1,168 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_UTILS_FAST_NODE_UTILS_H -#define INC_GRAPH_UTILS_FAST_NODE_UTILS_H - -#include "graph/fast_graph/fast_node.h" -#include "graph/fast_graph/edge.h" -#include "graph/fast_graph/execute_graph.h" -#include "external/graph/ge_error_codes.h" - -namespace ge { -class FastNodeUtils { - public: - // node utils - /** - * @brief 获取给定节点的父节点的输入。 - * - * @param node 指向待查询节点的指针。 - * @return 返回指向父节点的输入的指针,如果父节点不存在,则返回 nullptr。 - */ - static FastNode *GetParentInput(const FastNode *const node); - - /** - * @brief 获取指定索引处的输入数据节点。 - * - * @param node 指向要访问其输入数据节点的指针。 - * @param index 要检索的输入数据节点的索引。 - * @return 如果找到指定索引处的输入数据节点,则返回该节点的指针;否则返回 nullptr。 - */ - static FastNode *GetInDataNodeByIndex(const FastNode * const node, const int32_t index); - - /** - * @brief 获取给定节点是否为常量 Op 节点。 - * - * @param node 指向待查询节点的指针。 - * @return 如果给定节点是常量 Op 节点,则返回 true;否则返回 false。 - */ - static bool GetConstOpType(const FastNode *const node); - - // subgraph utils - /** - * @brief 向给定节点追加子图。 - * - * @param node 指向待添加子图的节点的指针。 - * @param subgraph_name 子图的名称。 - * @param subgraph 指向待添加的子图的智能指针,子图生命周期由 root_graph 管理。 - * @return 返回添加子图的状态。 - */ - static graphStatus AppendSubgraphToNode(FastNode *const node, const std::string &subgraph_name, - const ExecuteGraphPtr &subgraph); - - /** - * @brief 获取给定节点的指定索引处的子图。 - * - * @param node 指向待查询子图的节点的指针。 - * @param index 子图的索引。 - * @return 返回指向子图的指针,如果索引超出范围或者子图不存在,则返回 nullptr。 - */ - static ExecuteGraph *GetSubgraphFromNode(const FastNode *const node, const uint32_t index); - - /** - * @brief 在给定节点的指定索引 index 处挂载子图。若原先存在子图,该接口会替换 node 下索引为 index - * 的子图,但是不会从root_graph中移除原子图。 若需移除原子图,建议调用ExecuteGraph::RemoveSubGraph。 - * - * @param node 指向待设置子图的节点的指针。 - * @param index 子图的索引。 - * @param subgraph 指向待设置的子图的智能指针。 - * @return 返回设置子图的状态。 - */ - static graphStatus MountSubgraphToNode(FastNode *const node, const uint32_t index, const ExecuteGraphPtr &subgraph); - - // edge utils - /** - * @brief 向给定节点追加输入边信息,直至输入边信息数量达到 num。 - * 注意:该接口不会实际建立输入边,若需要为节点连边,建议调用ExecuteGraph::AddEdge。 - * - * @param node 待追加输入边信息的节点的指针。 - * @param num 追加操作后,node 所拥有的输入边信息数量。 - * @return 返回追加输入边信息的状态。 - */ - static graphStatus AppendInputEdgeInfo(FastNode *const node, const uint32_t num); - - /** - * @brief 向给定节点追加输出边信息,直至输出边信息数量达到 num。 - * 注意:该接口不会实际建立输出边,若需要为节点连边,建议调用ExecuteGraph::AddEdge。 - * - * @param node 待追加输出边信息的节点的指针。 - * @param num 追加操作后,node 所拥有的输出边信息数量。 - * @return 返回追加输出边信息的状态。 - */ - static graphStatus AppendOutputEdgeInfo(FastNode *const node, const uint32_t num); - - /** - * @brief 清除给定 OpDesc 的指定索引处的 InputDesc。 - * - * @param op_desc OpDesc 的指针。 - * @param index 要清除的 InputDesc 的索引。 - * @return 如果成功清除 InputDesc,则返回 true;否则返回 false。 - */ - static bool ClearInputDesc(const OpDesc *const op_desc, const uint32_t index); - - /** - * @brief 移除给定节点的输入边信息,直至输入边信息数量减少到 num。 - * 注意:该接口不会从执行图中移除输入边,若需要移除输入边,建议调用ExecuteGraph::RemoveEdge。 - * - * @param node 待移除输入边信息的节点的指针。 - * @param num 移除操作后,node 所拥有的输入边信息数量。 - * @return 返回追加输出边信息的状态。 - */ - static graphStatus RemoveInputEdgeInfo(FastNode *const node, const uint32_t num); - - /** - * @brief 断开给定节点与其所有相连节点之间的输入边和输出边。 - * - * @param node 目标节点。 - */ - static void UnlinkAll(FastNode *const node); - - /** - * @brief 获取给定边的输入端点,包含 dst 节点指针和输入 index。 - * (SrcNode:[OutEndpoint])->Edge->([InEndpoint]:DstNode) - * - * @param edge 指向待查询输入端点的边的指针。 - * @return 返回输入端点。 - */ - static EdgeDstEndpoint GetDstEndpoint(const FastEdge *const edge); - - /** - * @brief 获取给定边的输出端点,包含 src 节点指针和输出 index。 - * (SrcNode:[OutEndpoint])->Edge->([InEndpoint]:DstNode) - * - * @param edge 指向待查询输出端点的边的指针。 - * @return 返回输出端点。 - */ - static EdgeSrcEndpoint GetSrcEndpoint(const FastEdge *const edge); -}; - -struct FastNodeCompareKey { - bool operator()(const FastNode *const n0, const FastNode *const n1) const { - if ((n0 == nullptr) || (n1 == nullptr)) { - return false; - } - if (n0->GetName() == n1->GetName()) { - const ExtendInfo *const extend_info0 = n0->GetExtendInfo(); - const ExtendInfo *const extend_info1 = n1->GetExtendInfo(); - if ((extend_info0 == nullptr) || (extend_info1 == nullptr)) { - return false; - } - const ExecuteGraph *const g0 = extend_info0->GetOwnerGraphBarePtr(); - const ExecuteGraph *const g1 = extend_info1->GetOwnerGraphBarePtr(); - if ((g0 == nullptr) || (g1 == nullptr)) { - return false; - } - return (g0->GetName() < g1->GetName()); - } - return (n0->GetName() < n1->GetName()); - } -}; -} // namespace ge - -#endif // INC_GRAPH_UTILS_FAST_NODE_UTILS_H diff --git a/inc/graph/utils/ffts_graph_utils.h b/inc/graph/utils/ffts_graph_utils.h deleted file mode 100644 index ba19ae24c0ba6f1b90000736721e9cfbb8490470..0000000000000000000000000000000000000000 --- a/inc/graph/utils/ffts_graph_utils.h +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_UTILS_FFTS_GRAPH_UTILS_H_ -#define INC_GRAPH_UTILS_FFTS_GRAPH_UTILS_H_ - -#include "graph/anchor.h" -#include "graph/compute_graph.h" -#include "graph/graph.h" -#include "graph/node.h" - -namespace ge { -class FftsGraphUtils { - public: - using CalcFunc = std::function(const NodePtr &)>; - static graphStatus GraphPartition(ComputeGraph &graph, const std::set &unsupported_nodes); - - static graphStatus GraphPartition(ComputeGraph &graph, - const CalcFunc &calc_func, - const std::vector &upper_limit); - private: - static graphStatus CollectClipNodesAndGraphs(const ComputeGraphPtr &graph, - const std::set &unsupported_nodes, - std::unordered_set &nodes_need_clip, - std::unordered_set &graphs_need_split); - - static bool IsGraphNeedSplit(const ComputeGraphPtr &graph, const std::unordered_set &nodes_need_clip); - - static graphStatus SplitNodesWithCheck(const ComputeGraphPtr &graph, - const std::unordered_set &nodes_need_clip, - std::vector>> &split_nodes); - - static void SplitNodes(const std::set &calc_nodes, const std::function &is_cur_stage, - std::set &visited_nodes, std::set &cur_nodes, std::set &next_nodes); - - static graphStatus SplitSubgraph(const ComputeGraphPtr &subgraph, - const std::vector>> &split_nodes); - - static graphStatus BuildFftsPlusSubgraphWithAllNodes(const ComputeGraphPtr &subgraph); - - static void CollectCalcNodeInSubgraph(const ComputeGraphPtr &subgraph, std::set &calc_nodes); - - static void CollectEndNodeInSubgraph(const ComputeGraphPtr &subgraph, const std::set &ctrl_goto_types, - std::set &edge_nodes); - - static ComputeGraphPtr GetFftsPlusGraph(ComputeGraph &graph); - - static graphStatus SetAttrForFftsPlusSubgraph(const ComputeGraphPtr &subgraph); - - static graphStatus Calculate(const ComputeGraphPtr &graph, - const CalcFunc &calc_func, - std::map> &node_value, - std::map> &graph_value, - const uint32_t recursive_depth = 1U); - - static std::vector Calculate(const NodePtr &node, const CalcFunc &calc_func, - std::map> &node_value, - std::map> &graph_value, - const uint32_t recursive_depth); - - static bool IsValueValid(const ComputeGraphPtr &graph, const std::vector &upper_limit, - const std::map> &node_value, - const std::map> &graph_value); - - static graphStatus PartitionGraphWithLimit(const ComputeGraphPtr &graph, - std::map> &node_value, - std::map> &graph_value, - const std::vector &upper_limit, - const uint32_t recursive_depth = 1U); - - static graphStatus SplitFuncNode(const std::vector exceed_single_node, - std::map> &node_value, - std::map> &graph_value, - const std::vector &upper_limit, - const uint32_t recursive_depth); -}; -} // namespace ge -#endif // INC_GRAPH_UTILS_GRAPH_UTILS_H_ diff --git a/inc/graph/utils/graph_dump_utils.h b/inc/graph/utils/graph_dump_utils.h deleted file mode 100644 index 0251537d77f65add19bbd077bc32a4e4cdfb433d..0000000000000000000000000000000000000000 --- a/inc/graph/utils/graph_dump_utils.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_UTILS_GRAPH_DUMP_UTILS_H -#define INC_GRAPH_UTILS_GRAPH_DUMP_UTILS_H - -#include "graph/compute_graph.h" -#include "graph/fast_graph/execute_graph.h" -#include "graph/utils/execute_graph_adapter.h" -#include "graph/utils/graph_utils.h" -#include "mmpa/mmpa_api.h" - -namespace ge { -/** - * 将ComputeGraph落盘成文件 - * @param compute_graph 要落盘的对象 - * @param name 落盘的文件名,会拼接上默认的前后缀 - * @return - */ -inline void DumpGraph(const ComputeGraphPtr &compute_graph, const char_t *const name) { - ge::GraphUtils::DumpGEGraph(compute_graph, name); - ge::GraphUtils::DumpGEGraphToOnnx(*compute_graph, name); - uint64_t i = 0U; - for (const auto &sub_graph_func : compute_graph->GetAllSubgraphs()) { - const auto sub_graph_func_name = std::string(name) + std::string("_sub_graph_") + std::to_string(i++); - ge::GraphUtils::DumpGEGraph(sub_graph_func, sub_graph_func_name); - ge::GraphUtils::DumpGEGraphToOnnx(*sub_graph_func, sub_graph_func_name); - } -} - -/** - * 将ExecuteGraph落盘成文件 - * @param execute_graph 要落盘的对象 - * @param name 落盘的文件名,会拼接上默认的前后缀 - * @return - */ -inline void DumpGraph(ExecuteGraph *execute_graph, const char_t *const name) { - const char_t *dump_ge_graph = nullptr; - MM_SYS_GET_ENV(MM_ENV_DUMP_GE_GRAPH, dump_ge_graph); - if (dump_ge_graph == nullptr) { - return; - } - const auto compute_graph = ExecuteGraphAdapter::ConvertExecuteGraphToComputeGraph(execute_graph); - if (compute_graph != nullptr) { - DumpGraph(compute_graph, name); - } -} -} // namespace ge -#endif // INC_GRAPH_UTILS_GRAPH_DUMP_UTILS_H diff --git a/inc/graph/utils/graph_thread_pool.h b/inc/graph/utils/graph_thread_pool.h deleted file mode 100644 index 1fc065cb11276e08b3197141bc2bb68ea4d6581a..0000000000000000000000000000000000000000 --- a/inc/graph/utils/graph_thread_pool.h +++ /dev/null @@ -1,74 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_UTILS_GRAPH_THREAD_POOL_H_ -#define INC_GRAPH_UTILS_GRAPH_THREAD_POOL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "common/ge_common/debug/ge_log.h" -#include "common/ge_common/ge_inner_error_codes.h" -#include "external/ge_common/ge_api_error_codes.h" -#include "common/util/mem_utils.h" - -namespace ge { -using ThreadTask = std::function; - -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GraphThreadPool { - public: - explicit GraphThreadPool(const uint32_t size = 4U); - ~GraphThreadPool(); - - template - auto commit(Func &&func, Args &&... args) -> std::future { - GELOGD("commit run task enter."); - using retType = decltype(func(args...)); - std::future fail_future; - if (is_stoped_.load()) { - GELOGE(ge::FAILED, "thread pool has been stopped."); - return fail_future; - } - - const auto bindFunc = std::bind(std::forward(func), std::forward(args)...); - const auto task = MakeShared>(bindFunc); - if (task == nullptr) { - GELOGE(ge::FAILED, "Make shared failed."); - return fail_future; - } - std::future future = task->get_future(); - { - const std::lock_guard lock{m_lock_}; - tasks_.emplace([task]() { (*task)(); }); - } - cond_var_.notify_one(); - GELOGD("commit run task end"); - return future; - } - - static void ThreadFunc(GraphThreadPool *const thread_pool); - - private: - std::vector pool_; - std::queue tasks_; - std::mutex m_lock_; - std::condition_variable cond_var_; - std::atomic is_stoped_; - std::atomic idle_thrd_num_; -}; -} // namespace ge - -#endif // INC_GRAPH_UTILS_GRAPH_THREAD_POOL_H_ diff --git a/inc/graph/utils/graph_utils_ex.h b/inc/graph/utils/graph_utils_ex.h deleted file mode 100644 index 40df622c9c2d8256115f551b49be6549fa2751a1..0000000000000000000000000000000000000000 --- a/inc/graph/utils/graph_utils_ex.h +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef __INC_METADEF_GRAPH_UTILS_EX_H -#define __INC_METADEF_GRAPH_UTILS_EX_H - -#include "graph/node.h" -#include "graph/compute_graph.h" -#include "external/graph/graph.h" - -namespace ge { -class GraphUtilsEx { - public: - // Detach from ComputeGraph - static graphStatus InferOriginFormat(const ComputeGraphPtr &graph); - static graphStatus InferShapeInNeed(const ComputeGraphPtr &graph); - - // Detach from GraphUtils - static ComputeGraphPtr GetComputeGraph(const Graph &graph); - static ComputeGraphPtr CreateGraphFromOperator(const std::string &name, const std::vector &inputs); - - /** - * 使用ops中的算子为graph对象构造图,且构造出来的图中的算子需要按照ops中的顺序排序 - * @param graph 需要构造图的graph对象 - * @param ops 用于生成计算图的算子 - * @return 计算图指针,成功时,返回生成的ComputeGraph指针 失败返回nullptr - */ - static graphStatus CreateGraphFromOperatorWithStableTopo(Graph &graph, - const std::vector &ops); - /** - * 使用ops中的算子构造计算图,且构造出来的图中的算子需要按照ops中的顺序排序 - * @param graph 需要构造图的graph对象 - * @param ops 用于生成计算图的算子 - * @return 计算图指针,成功时,返回生成的ComputeGraph指针 失败返回nullptr - */ - static ComputeGraphPtr CreateComputeGraphFromOperatorWithStableTopo(const std::string &name, - const std::vector &ops); - - static Graph CreateGraphFromComputeGraph(const ComputeGraphPtr compute_graph); - static GraphPtr CreateGraphPtrFromComputeGraph(const ComputeGraphPtr compute_graph); - static std::unique_ptr CreateGraphUniquePtrFromComputeGraph(const ComputeGraphPtr &compute_graph); - static void BreakConnect(const std::map &all_nodes_infos); - static graphStatus RecoverGraphOperators(const Graph &graph); - static graphStatus CopyGraph(const Graph &src_graph, Graph &dst_graph); - - /** - * 获取所有需要用户传入输入Tensor的Data节点,当前会排除掉分档场景新插入的Data节点 - * @param graph 图对象 - * @return 用户输入节点集合,失败时返回空集合 - */ - static std::vector GetUserInputDataNodes(const ComputeGraphPtr &compute_graph); - private: - static graphStatus CopyGraphImpl(const Graph &src_graph, Graph &dst_graph, - const std::map &node_old_2_new, - const std::map &op_desc_old_2_new); -}; -} // namespace ge -#endif // __INC_METADEF_GRAPH_UTILS_EX_H diff --git a/inc/graph/utils/hash_utils.h b/inc/graph/utils/hash_utils.h deleted file mode 100644 index 9430d6ef3bd3d51d6f41ffd00824ee452f96fbc6..0000000000000000000000000000000000000000 --- a/inc/graph/utils/hash_utils.h +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef GRAPH_COMPILE_CACHE_POLICY_HASH_UTILS_H_ -#define GRAPH_COMPILE_CACHE_POLICY_HASH_UTILS_H_ - -#include -#include -#include -#include "graph/small_vector.h" -#include "graph/types.h" - -namespace ge { -using CacheHashKey = uint64_t; -class HashUtils { -public: - static constexpr CacheHashKey HASH_SEED = 0x7863a7deUL; - static constexpr CacheHashKey COMBINE_KEY = 0x9e3779b9UL; - - template - static inline CacheHashKey HashCombine(CacheHashKey seed, const T &value) { - const std::hash hasher; - seed ^= hasher(value) + COMBINE_KEY + (seed << 6U) + (seed >> 2U); - return seed; - } - - static inline CacheHashKey HashCombine(CacheHashKey seed, const Format &value) { - const std::hash hasher; - seed ^= hasher(static_cast(value)) + COMBINE_KEY + (seed << 6U) + (seed >> 2U); - return seed; - } - - static inline CacheHashKey HashCombine(CacheHashKey seed, const DataType &value) { - const std::hash hasher; - seed ^= hasher(static_cast(value)) + COMBINE_KEY + (seed << 6U) + (seed >> 2U); - return seed; - } - - template - static inline CacheHashKey HashCombine(CacheHashKey seed, const SmallVector &values) { - for (const auto &val : values) { - seed = HashCombine(seed, val); - } - return seed; - } - - template - static inline CacheHashKey HashCombine(CacheHashKey seed, const std::vector &values) { - for (const auto &val : values) { - seed = HashCombine(seed, val); - } - return seed; - } - - static inline CacheHashKey MultiHash() { - return HASH_SEED; - } - - template - static inline CacheHashKey MultiHash(const T &value, const M... args) { - return HashCombine(MultiHash(args...), value); - } -}; -} // namespace ge -#endif diff --git a/inc/graph/utils/mem_utils.h b/inc/graph/utils/mem_utils.h deleted file mode 100644 index 1f27f14179249d3137c45c561d550fdf4fff7ee3..0000000000000000000000000000000000000000 --- a/inc/graph/utils/mem_utils.h +++ /dev/null @@ -1,101 +0,0 @@ -/** - * Copyright (c) Huawei Technologies Co., Ltd. 2024 All rights reserved. - * - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_UTILS_MEM_UTILS_H_ -#define INC_GRAPH_UTILS_MEM_UTILS_H_ -#include -#include -#include "ascendc_ir/ascendc_ir_core/ascendc_ir.h" -namespace ge { -template -void CheckAscTensorAttr(T &) { - static_assert(std::is_same::value, "Expected AscTensorAttr type"); -} -/// Usage: -/// Create TQueConfig: MemUtils::CreateTQueConfig(position, depth, buffer_num) -/// TQueConfig BindTensors: config.BindTensors(ascend_tensor1, ascend_tensor2, ...) -class TQueConfig { - friend class MemUtils; - public: - template - TQueConfig &BindTensors(Args &&...tensors) { - int dummy[] = { (CheckAscTensorAttr(std::forward(tensors)), 0)... }; - (void)dummy; - - int dummy1[] = { (tensors.que = queue_attr_, 0)... }; - (void)dummy1; - int dummy2[] = { (tensors.buf.id = kIdNone, 0)... }; - (void)dummy2; - int dummy3[] = { (tensors.mem.position = pos_, 0)... }; - (void)dummy3; - int dummy4[] = { (tensors.mem.alloc_type = AllocType::kAllocTypeQueue, 0)... }; - (void)dummy4; - - return *this; - } - TQueConfig() = default; - private: - TQueConfig(const int64_t id, const ge::Position pos, const int64_t depth, const int64_t buf_num); - MemQueAttr queue_attr_{}; - ge::Position pos_{ge::Position::kPositionInvalid}; -}; - -/// Usage: -/// Create TBufConfig: MemUtils::CreateTBufConfig(position) -/// TBufConfig BindTensors: config.BindTensors(ascend_tensor1, ascend_tensor2, ...) -class TBufConfig { - friend class MemUtils; - public: - template - TBufConfig &BindTensors(Args &&...tensors) { - int dummy[] = { (CheckAscTensorAttr(std::forward(tensors)), 0)... }; - (void)dummy; - - int dummy1[] = { (tensors.buf = buf_attr_, 0)... }; - (void)dummy1; - int dummy2[] = { (tensors.que.id = kIdNone, 0)... }; - (void)dummy2; - int dummy3[] = { (tensors.mem.position = pos_, 0)... }; - (void)dummy3; - int dummy4[] = { (tensors.mem.alloc_type = AllocType::kAllocTypeBuffer, 0)... }; - (void)dummy4; - - return *this; - } - TBufConfig() = default; - private: - TBufConfig(const int64_t id, const ge::Position pos); - MemBufAttr buf_attr_; - ge::Position pos_; -}; - -// Only applicable to the three-stage(Tque/Tbuf alloc) ascend ir graph construction -class MemUtils { - public: - static TQueConfig CreateTQueConfig(const ge::Position pos, const int64_t depth, const int64_t buf_num); - static TBufConfig CreateTBufConfig(const ge::Position pos); - - template - static void MergeScope(Args &&...tensors) { - // 修改合并作用域的展开方式 - int dummy[] = { (CheckAscTensorAttr(std::forward(tensors)), 0)... }; - (void)dummy; - int dummy1[] = { (tensors.opt.merge_scope = scope_id_, 0)... }; - (void)dummy1; - scope_id_++; - } - - private: - static std::atomic gen_container_id_; - static std::atomic scope_id_; -}; -} // namespace ge -#endif // INC_GRAPH_UTILS_MEM_UTILS_H_ diff --git a/inc/graph/utils/multi_thread_graph_builder.h b/inc/graph/utils/multi_thread_graph_builder.h deleted file mode 100644 index 2e9f0512fd498d61a0920cec57288fbb2bbea285..0000000000000000000000000000000000000000 --- a/inc/graph/utils/multi_thread_graph_builder.h +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef __INC_METADEF_MULTI_THREAD_GRAPH_BUILDER_H -#define __INC_METADEF_MULTI_THREAD_GRAPH_BUILDER_H - -#include -#include -#include -#include "external/graph/graph.h" -#include "graph/utils/graph_thread_pool.h" - -namespace ge { -class MultiThreadGraphBuilder { - public: - explicit MultiThreadGraphBuilder(int32_t thread_num); - ~MultiThreadGraphBuilder() = default; - - Graph &SetInputs(const std::vector &inputs, ge::Graph &graph); - - private: - static graphStatus GetGraphRelatedOperators(const std::vector &inputs, - std::vector &related_ops); - static void GetOutputLinkOps(const OperatorImplPtr &op_impl, - std::vector &output_op_impls); - static graphStatus WalkForwardOperators(const std::vector &vec_ops, - std::vector &related_ops); - void ResetOpSubgraphBuilder(const OpDescPtr &op_desc, OperatorImplPtr &op_impl); - - int32_t thread_num_; - std::mutex mutex_; - std::unique_ptr pool_; -}; -} // namespace ge -#endif // __INC_METADEF_MULTI_THREAD_GRAPH_BUILDER_H diff --git a/inc/graph/utils/node_adapter.h b/inc/graph/utils/node_adapter.h deleted file mode 100644 index 9b8a60c4e9792be319391e3871cffbc9f2fe94c7..0000000000000000000000000000000000000000 --- a/inc/graph/utils/node_adapter.h +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_UTILS_NODE_ADAPTER_H_ -#define INC_GRAPH_UTILS_NODE_ADAPTER_H_ - -#include "graph/gnode.h" -#include "graph/node.h" - -namespace ge { -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodeAdapter { - public: - static GNode Node2GNode(const NodePtr &node); - static NodePtr GNode2Node(const GNode &node); - static GNodePtr Node2GNodePtr(const NodePtr &node); -}; -} // namespace ge -#endif // INC_GRAPH_UTILS_NODE_ADAPTER_H_ diff --git a/inc/graph/utils/node_utils.h b/inc/graph/utils/node_utils.h deleted file mode 100644 index 81f1d8b25c3b0fd2f6592ab36620236088f1fae7..0000000000000000000000000000000000000000 --- a/inc/graph/utils/node_utils.h +++ /dev/null @@ -1,315 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_UTILS_NODE_UTILS_H_ -#define INC_GRAPH_UTILS_NODE_UTILS_H_ - -#include -#include -#include -#include -#include "external/graph/operator.h" -#include "external/graph/types.h" -#include "graph/anchor.h" -#include "graph/node.h" -#include "graph/compute_graph.h" - -/*lint -e148*/ -namespace ge { -// Op types of Const like Opps. -extern const std::set kConstOpTypes; - -// Op types of Enter like Opps. -extern const std::set kEnterOpTypes; -// Op types of Merge like Opps. -extern const std::set kMergeOpTypes; -// Op types of Switch like Opps. -extern const std::set kSwitchOpTypes; -// Op types of NextIteration like Opps. -extern const std::set kNextIterationOpTypes; -// Op types of Exit like Opps. -extern const std::set kExitOpTypes; - -// Op types of If like Opps. -extern const std::set kIfOpTypes; -// Op types of While like Opps. -extern const std::set kWhileOpTypes; -// Op types of Case like Opps. -extern const std::set kCaseOpTypes; -// Op types of For like Opps. -extern const std::set kForOpTypes; - -class NodeUtils { - public: - static graphStatus ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor); - static graphStatus SetAllAnchorStatus(const NodePtr &node_ptr); - static graphStatus SetAllAnchorStatus(Node &node); - static bool IsAnchorStatusSet(const NodePtr &node_ptr); - static bool IsAnchorStatusSet(const Node &node); - - static graphStatus MoveOutputEdges(const NodePtr &origin_node, const NodePtr &new_node); - - static void UpdateIsInputConst(const NodePtr &node_ptr); - static void UpdateIsInputConst(Node &node); - static bool IsConst(const Node &node); - static void UnlinkAll(const Node &node); - - static bool ClearInputDesc(const OpDescPtr &op_desc, const uint32_t index); - static bool ClearOutputDesc(const OpDescPtr &op_desc, const uint32_t index); - - static graphStatus AppendInputAnchor(const NodePtr &node, const uint32_t num); - static graphStatus RemoveInputAnchor(const NodePtr &node, const uint32_t num); - - static graphStatus AppendOutputAnchor(const NodePtr &node, const uint32_t num); - static graphStatus RemoveOutputAnchor(const NodePtr &node, const uint32_t num); - - static GeTensorDesc GetOutputDesc(const Node &node, const uint32_t index); - // check node whether unknown shape.If node shape contain -1 or -2,out param "is_unknow" will be true; - // for func op, it will check subgraph yet, if some node shape of subgraph contain -1 or -2, - // the out param "is_unknow" will be true too - static graphStatus GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow); - - static std::string GetNodeType(const Node &node); - static std::string GetNodeType(const NodePtr &node); - - static graphStatus GetDirectSubgraphs(const NodePtr &node, std::vector &subgraphs); - static ComputeGraphPtr GetSubgraph(const Node &node, const uint32_t index); - static graphStatus SetSubgraph(Node &node, const uint32_t index, const ComputeGraphPtr &subgraph); - /** - * @brief Add the subgraph to the node with the ir name, will not register the ir name with type - * @param node the node the subgraph will be added - * @param subgraph_ir_name the subgraph ir name - * @param subgraph the subgraph - * @return GRAPH_SUCCESS: success, others: failed - */ - static graphStatus AddSubgraph(Node &node, const std::string &subgraph_ir_name, const ComputeGraphPtr &subgraph); - /** - * @brief Add the static subgraph to the node with the given ir name, will register the ir name with as kStatic - * @param node_ptr the node the subgraph will be added - * @param subgraph_ir_name the subgraph ir name - * @param subgraph the subgraph - * @return GRAPH_SUCCESS: success, others: failed - */ - static graphStatus AddSubgraph(const NodePtr &node_ptr, const std::string &subgraph_ir_name, - const ComputeGraphPtr &subgraph); - /** - * @brief Add the dynamic subgraph to the node with the given ir name, will register the ir name with as kDynamic - * @param node_ptr the node the subgraph will be added - * @param subgraph_ir_name the subgraph ir name - * @param subgraphs vector of dynamic subgraphs - * @return GRAPH_SUCCESS: success, others: failed - */ - static graphStatus AddSubgraphs(const NodePtr &node_ptr, const std::string &subgraph_ir_name, - const std::vector &subgraphs); - static std::string GenDynamicSubgraphName(const std::string &subgraph_ir_name, int64_t index); - - static NodePtr CreatNodeWithoutGraph(const OpDescPtr op_desc); - /// Check if node is input of subgraph - /// @param [in] node - /// @return bool - static bool IsSubgraphInput(const NodePtr &node); - static bool IsSubgraphInput(const Node *const node); - - /// Check if node is output of subgraph - /// @param [in] node - /// @return bool - static bool IsSubgraphOutput(const NodePtr &node); - - /// @brief Get subgraph original input node. - /// @param [in] node - /// @return Node - static NodePtr GetParentInput(const Node &node); - static NodePtr GetParentInput(const NodePtr &node); - /// @brief Get subgraph original input node and corresponding out_anchor. - /// @param [in] node - /// @return NodeToOutAnchor node and out_anchor which linked to in_param node - static NodeToOutAnchor GetParentInputAndAnchor(const NodePtr &node); - /// @brief Get subgraph original input node and corresponding out_anchor corss subgraph. - /// @param [in] node - /// @return NodeToOutAnchor node and out_anchor which linked to in_param node - static NodeToOutAnchor GetParentInputAndAnchorCrossSubgraph(const NodePtr &node); - - /// @brief Get is dynamic shape graph from node. - /// @param [in] node - /// @return bool - static bool IsDynamicShape(const Node &node); - static bool IsDynamicShape(const NodePtr &node); - - /// @brief Check is varying_input for while node - /// @param [in] node: Data node for subgraph - /// @return bool - static bool IsWhileVaryingInput(const ge::NodePtr &node); - - /// @brief Get subgraph input is constant. - /// @param [in] node - /// @param [out] string - /// @return bool - static bool GetConstOpType(const NodePtr &node, std::string &type); - - /// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph. - /// @param [in] node - /// @return return GRAPH_SUCCESS if remove successfully, other for failed. - static graphStatus RemoveSubgraphsOnNode(const NodePtr &node); - - /** - * 获取`node`挂载的所有子图中的索引为`index`的Data节点集合; - * 每个子图最多能找到一个跟`index`匹配的Data节点 - * @param node - * @param index - * @return - */ - static std::vector GetSubgraphDataNodesByIndex(const Node &node, const int32_t index); - - /** - * 获取`node`挂载的所有子图中的NetOutput节点集合; - * 每个子图有且只有一个NetOutput节点 - * @param node - * @return - */ - static std::vector GetSubgraphOutputNodes(const Node &node); - - /** - * 获取`node`所在的图对应的根图 - * @param node - * @return - */ - static ComputeGraphPtr FindRootGraph(const Node &node); - - /** - * 根据`node_filter`获取被node控制的输出节点 - * @param node - * @param node_filter 控制边拷贝白名单过滤器,可以通过传递此参数实现满足条件的输出节点的获取 - * @return - */ - static std::vector GetOutControlNodes(const Node &node, const NodeFilter &node_filter); - /** - * 根据`node_filter`获取node的输出数据消费节点 - * @param node - * @param node_filter 数据边拷贝白名单过滤器,可以通过传递此参数实现满足条件的输出节点的获取 - * @return - */ - static std::vector GetOutDataNodes(const Node &node, const NodeFilter &node_filter); - - /** - * 根据`node_filter`获取控制node的输入节点 - * @param node - * @param node_filter 控制边拷贝白名单过滤器,可以通过传递此参数实现满足条件的输入节点的获取 - * @return - */ - static std::vector GetInControlNodes(const Node &node, const NodeFilter &node_filter); - /** - * 根据`node_filter`获取node的数据输入节点 - * @param node - * @param node_filter 数据边拷贝白名单过滤器,可以通过传递此参数实现满足条件的输入节点的获取 - * @return - */ - static std::vector GetInDataNodes(const Node &node, const NodeFilter &node_filter); - - static NodePtr GetInDataNodeByIndex(const Node &node, const int32_t index); - static std::pair GetInDataNodeAndAnchorByIndex(const Node &node, const int32_t index); - - static std::vector> GetOutDataNodesWithAnchorByIndex(const Node &node, - const int32_t index); - - /** - * 适用于`node`节点作为子图中的Data占位节点时,获取根图中父节点对应的实际输入节点的类型 - * 其他情况返回`node`本身的节点类型 - * @param node - * @return - */ - static std::string GetInConstNodeTypeCrossSubgraph(const ge::NodePtr &node); - - /** -* 适用于`node`节点作为子图中的Data占位节点时,获取根图中父节点对应的实际输入节点对象 -* 其他情况返回`node`本身 -* @param node -* @return -*/ - static NodePtr GetInNodeCrossSubgraph(const ge::NodePtr &node); - - /// @brief Get peer input node, supported get cross PartitionedCall . - /// @param [in] node, current node - /// @param [in] index, current node the index'th input, if it is PartionedCall's subgraph Data, please assign 0 - /// @param [out] peer_node, - /// A(PartionedCall_0)->B(PartionedCall_1) - /// PartionedCall_0's subgraph: Data->A->Netoutput - /// PartionedCall_1's subgraph: Data1->B->Netoutput - /// If it is called like GetInNodeCrossPartionCallNode(B,0,peer_node)or(Data1,0,peer_node), peer_node is A - /// @param [out] peer_out_anchor_index, peer_node's corresponding out anchor's index - /// @return [graphStatus] running result of this function - static graphStatus GetInNodeCrossPartionedCallNode(const NodePtr &node, uint32_t index, NodePtr &peer_node); - static graphStatus GetInNodeCrossPartionedCallNode(const NodePtr &node, uint32_t index, NodePtr &peer_node, - int32_t &peer_out_anchor_index); - - static graphStatus SetNodeParallelGroup(Node &node, const char_t *const group_name); - - static graphStatus UpdateInputOriginalShapeAndShape(const Node &node, const uint32_t index, const GeShape &shape); - static graphStatus UpdateOutputOriginalShapeAndShape(const Node &node, const uint32_t index, const GeShape &shape); - static bool IsDtResourceNode(const NodePtr &node); - static bool IsLikeAtomicClean(const NodePtr &node); - /** - * 用于判断identity节点是否被用于控制先读后写顺序的,如果是的话, - * 则图优化的时候不能无脑删除identity节点来提升性能 - * @param node_ptr - * @return - */ - static bool IsIdentityUsefulForRWControl(const NodePtr &node_ptr); - /** - * 尝试通过pld占位节点对应的实际const节点来获取权重 - * @param node_ptr placeholder的占位节点,常见于图拆分中间状态的图的输入节点类型 - * @param ge_tensor 权重的承载对象,成功获取时ge_tensor被设置为非空 - * @return 失败时代表内部流程错误,成功时不代表一定获取到了权重 - */ - static graphStatus TryGetWeightByPlaceHolderNode(const NodePtr &node_ptr, ConstGeTensorPtr &ge_tensor); - /** - * 尝试通过Data占位节点对应的实际const节点来获取权重 - * @param node_ptr Data占位节点,常见于子图的输入节点类型 - * @param ge_tensor 权重的承载对象,成功获取时ge_tensor被设置为非空 - * @return 失败时代表内部流程错误,成功时不代表一定获取到了权重 - */ - static graphStatus TryGetWeightByDataNode(const NodePtr &node_ptr, ConstGeTensorPtr &ge_tensor); - /** - * 判断`node`的名称是否是`name` - * @param node - * @param name - * @return 如果是的话,返回true,否则 false - */ - static bool IsNameEqual(const NodePtr &node, const ge::char_t *const name); - /** - * 判断`node`的类型是否是`type` - * @param node - * @param type - * @return - */ - static bool IsTypeEqual(const NodePtr &node, const ge::char_t *const type); - - static NodePtr GetNodeWithMinimalId(const std::vector &nodes); -}; - -struct NodeCompareKey { - bool operator()(const NodePtr &n0, const NodePtr &n1) const { - if ((n0 == nullptr) || (n1 == nullptr)) { - return false; - } - int32_t comp_res = strcmp(n0->GetNamePtr(), n1->GetNamePtr()); - if (comp_res == 0) { - const auto graph0 = n0->GetOwnerComputeGraph(); - const auto graph1 = n1->GetOwnerComputeGraph(); - if ((graph0 == nullptr) || (graph1 == nullptr)) { - return false; - } - return (graph0->GetName() < graph1->GetName()); - } - return (comp_res < 0); - } -}; -using OrderedNodeSet = std::set; -} // namespace ge -/*lint +e148*/ -#endif // INC_GRAPH_UTILS_NODE_UTILS_H_ diff --git a/inc/graph/utils/node_utils_ex.h b/inc/graph/utils/node_utils_ex.h deleted file mode 100644 index 7e702139fee03be5430bdcc143089f67e40e066e..0000000000000000000000000000000000000000 --- a/inc/graph/utils/node_utils_ex.h +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef __INC_METADEF_NODE_UTILS_EX_H -#define __INC_METADEF_NODE_UTILS_EX_H - -#include "graph/node.h" -#include "graph/op_desc.h" - -namespace ge { -class NodeUtilsEx { - public: - // Detach from Node - static graphStatus Verify(const NodePtr &node); - static graphStatus InferShapeAndType(const NodePtr &node); - static graphStatus InferOriginFormat(const NodePtr &node); - // Detach from NodeUtils - static ConstNodePtr GetNodeFromOperator(const Operator &op); - static graphStatus SetNodeToOperator(Operator &op, const ConstNodePtr &node); - private: - static graphStatus IsInputsValid(const NodePtr &node); -}; -} // namespace ge -#endif // __INC_METADEF_NODE_UTILS_EX_H diff --git a/inc/graph/utils/object_pool.h b/inc/graph/utils/object_pool.h deleted file mode 100644 index 02cec46e799734b2f398a5041598cea3cb8c85f2..0000000000000000000000000000000000000000 --- a/inc/graph/utils/object_pool.h +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef EXECUTE_GRAPH_OBJECT_POOL_H -#define EXECUTE_GRAPH_OBJECT_POOL_H - -#include -#include -namespace ge { -constexpr size_t kDefaultPoolSize = 100UL; - -template -class ObjectPool { - public: - ObjectPool() = default; - ~ObjectPool() = default; - ObjectPool(ObjectPool &) = delete; - ObjectPool(ObjectPool &&) = delete; - ObjectPool &operator=(const ObjectPool &) = delete; - ObjectPool &operator=(ObjectPool &&) = delete; - - template - std::unique_ptr Acquire(Args &&...args) { - if (!handlers_.empty()) { - std::unique_ptr tmp(std::move(handlers_.front())); - handlers_.pop(); - return tmp; - } - return std::unique_ptr(new (std::nothrow) T(args...)); - } - - void Release(std::unique_ptr ptr) { - if ((handlers_.size() < N) && (ptr != nullptr)) { - handlers_.push(std::move(ptr)); - } - } - - bool IsEmpty() const { - return handlers_.empty(); - } - - bool IsFull() const { - return handlers_.size() >= N; - } - - private: - std::queue> handlers_; -}; -} // namespace ge -#endif // EXECUTE_GRAPH_OBJECT_POOL_H diff --git a/inc/graph/utils/op_desc_utils.h b/inc/graph/utils/op_desc_utils.h deleted file mode 100644 index 6f0931dff0c8999ffb13d755276b27a10b5a9544..0000000000000000000000000000000000000000 --- a/inc/graph/utils/op_desc_utils.h +++ /dev/null @@ -1,124 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_UTILS_OP_DESC_UTILS_H_ -#define INC_GRAPH_UTILS_OP_DESC_UTILS_H_ - -#include -#include -#include "graph/def_types.h" -#include "graph/node.h" -#include "graph/operator.h" -#include "graph/runtime_inference_context.h" - -/*lint -e148*/ -namespace ge { -using ConstGeTensorBarePtr = const GeTensor *; -class OpDescUtils { - public: - template - using Vistor = RangeVistor>; - using GetConstInputOnRuntimeFun = - std::function; - - OpDescUtils() = default; - ~OpDescUtils() = default; - static bool HasQuantizeFactorParams(const OpDescPtr& op_desc); - static bool HasQuantizeFactorParams(const OpDesc& op_desc); - static std::vector GetConstInputNode(const ge::Node& node); - static std::vector GetConstInputNodeAndAnchor(const ge::Node &node); - static std::vector GetInputData(const std::vector& input_nodes); - static std::vector GetWeightsFromNodes( - const std::vector& input_nodes_2_out_anchors); - - static std::vector GetWeights(const ge::Node& node); - static std::vector GetWeights(const ge::ConstNodePtr& node); - static std::vector MutableWeights(const ge::Node& node); - static std::vector MutableWeights(const ge::NodePtr node); - static graphStatus SetWeights(ge::Node& node, const std::vector& weights); - static graphStatus SetWeights(ge::NodePtr node, const std::vector &weights); - static graphStatus SetWeights(ge::Node &node, const std::map &weights_map); - static graphStatus ClearWeights(const ge::NodePtr node); - static graphStatus SetNoneConstNodeWeights(ge::Node &node, const std::map &weights_map); - static graphStatus SetNoneConstNodeWeights(ge::Node &node, const std::vector &weights); - static bool ClearInputDesc(const ge::OpDescPtr op_desc, const uint32_t index); - static bool ClearInputDesc(const ge::NodePtr& node); - static bool ClearOutputDesc(const ge::OpDescPtr& op_desc, const uint32_t index); - static bool ClearOutputDesc(const ge::NodePtr& node); - static std::vector GetConstInputs(const ge::Node& node, const uint32_t depth = 64U); - static std::vector GetConstInputs(const ge::ConstNodePtr& node); - static size_t GetNonConstInputsSize(const ge::Node& node); - static size_t GetNonConstInputsSize(const ge::ConstNodePtr node); - // Index: Indicates the index of all non const inputs - static GeTensorDesc GetNonConstInputTensorDesc(const ge::Node& node, const size_t index_non_const = 0UL); - static GeTensorDesc GetNonConstInputTensorDesc(const ge::ConstNodePtr& node, const size_t index_non_const = 0UL); - static bool GetNonConstInputIndex(const ge::Node& node, const size_t index_non_const, size_t& index); - static bool GetNonConstInputIndex(const ge::ConstNodePtr& node, const size_t index_non_const, size_t& index); - // Index: Indicates the index of all inputs - static bool IsNonConstInput(const ge::Node& node, const size_t index = 0UL); - static bool IsNonConstInput(const ge::ConstNodePtr& node, const size_t index = 0UL); - - static std::vector GetNonConstTensorDesc(const ge::ConstNodePtr& node); - static graphStatus AddConstOpToAnchor(const InDataAnchorPtr in_anchor, const GeTensorPtr& tensor_ptr); - - static Operator CreateOperatorFromOpDesc(OpDescPtr op_desc); - static Operator CreateOperatorFromNode(ge::ConstNodePtr node_ptr); - static OpDescPtr GetOpDescFromOperator(const Operator& oprt); - static graphStatus CopyOperatorLinks(const std::map &src_op_list, - std::map &dst_op_list); - static graphStatus CopyOperators(const ComputeGraphPtr &dst_compute_graph, - const std::map &node_old_2_new, - const std::map &op_desc_old_2_new, - const std::map &src_op_list, - std::map &dst_op_list); - static OpDescPtr CloneOpDesc(const ConstOpDescPtr &org_op_desc); - static OpDescPtr CopyOpDesc(const ConstOpDescPtr &org_op_desc); - static OpDescPtr CreateConstOp(const GeTensorPtr& tensor_ptr); - static OpDescPtr CreateConstOp(const GeTensorPtr& tensor_ptr, const bool copy); - static OpDescPtr CreateConstOpZeroCopy(const GeTensorPtr& tensor_ptr); - - static graphStatus SetSubgraphInstanceName(const std::string &subgraph_name, - const std::string &subgraph_instance_name, OpDescPtr &op_desc); - static ConstGeTensorBarePtr GetInputConstData(const Operator &op, const uint32_t idx); - // deprecated - static void SetRuntimeContextToOperator(const Operator &op, RuntimeInferenceContext *const context); - static void SetCallbackGetConstInputFuncToOperator(const Operator &op, - GetConstInputOnRuntimeFun get_const_input_func); - static bool HasCallbackGetConstInputFunc(const Operator &op); - static std::map> GetInputIrIndexes2InstanceIndexesPairMap(const OpDescPtr &op_desc); - static std::map> GetOutputIrIndexes2InstanceIndexesPairMap( - const OpDescPtr &op_desc); - - static graphStatus GetIrInputInstanceDescRange(const OpDescPtr &op, - std::map> &ir_input_2_range); - - static graphStatus GetIrInputRawDescRange(const OpDescPtr &op, - std::map> &ir_input_2_range); - - static graphStatus GetIrOutputDescRange(const OpDescPtr &op, - std::map> &ir_output_2_range); - - static ge::graphStatus GetInputIrIndexByInstanceIndex(const OpDescPtr &op_desc, - size_t instance_index, size_t &ir_index); - static ge::graphStatus GetInstanceNum(const OpDescPtr &op_desc, size_t ir_index, size_t start_index, - const std::map &valid_index_2_names, - size_t &instance_num); - static ge::graphStatus GetPromoteIrInputList(const OpDescPtr &op_desc, - std::vector> &promote_index_list); - static ge::graphStatus GetPromoteInstanceInputList(const OpDescPtr &op_desc, - std::vector> &promote_index_list); - private: - static GeTensorPtr MutableWeights(ge::OpDesc& op_desc); - static GeTensorPtr MutableWeights(const ge::OpDescPtr op_desc); - static graphStatus SetWeights(ge::OpDesc& op_desc, const GeTensorPtr weight); - static graphStatus SetWeights(ge::OpDescPtr op_desc, const GeTensorPtr weight); -}; -} // namespace ge -/*lint +e148*/ -#endif // INC_GRAPH_UTILS_OP_DESC_UTILS_H_ diff --git a/inc/graph/utils/op_desc_utils_ex.h b/inc/graph/utils/op_desc_utils_ex.h deleted file mode 100644 index 410e2bdbc3151b99546390c7638ef85dbc0cb723..0000000000000000000000000000000000000000 --- a/inc/graph/utils/op_desc_utils_ex.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef __INC_METADEF_OP_DESC_UTILS_EX_H -#define __INC_METADEF_OP_DESC_UTILS_EX_H - -#include "graph/op_desc.h" - -namespace ge { -class OpDescUtilsEx { - public: - // Detach from OpDesc - static graphStatus CallInferFunc(const OpDescPtr &op_desc, Operator &op); - static graphStatus CallInferFormatFunc(const OpDescPtr &op_desc, Operator &op); - static graphStatus CallInferValueRangeFunc(const OpDescPtr &op_desc, Operator &op); - static graphStatus OpVerify(const OpDescPtr &op_desc); - static graphStatus InferShapeAndType(const OpDescPtr &op_desc); - static graphStatus InferDataSlice(const OpDescPtr &op_desc); - static void SetType(OpDescPtr &op_desc, const std::string &type); - static void ResetFuncHandle(OpDescPtr &op_desc); - static void SetTypeAndResetFuncHandle(OpDescPtr &op_desc, const std::string &type); - static void UpdateShapeAndDType(const GeTensorDescPtr &src, const GeTensorDescPtr &dst); - - private: - static graphStatus CallInferFuncV1(const OpDescPtr &op_desc, Operator &op); - static graphStatus CallInferFuncV2(const OpDescPtr &op_desc, Operator &op); - static graphStatus InferShapeByOutputShapesAttr(const OpDescPtr &op_desc); - static graphStatus CallInferFormatFuncV1(const OpDescPtr &op_desc, Operator &op); - static graphStatus CallInferFormatFuncV2(const OpDescPtr &op_desc, Operator &op); -}; -} // namespace ge -#endif // __INC_METADEF_OP_DESC_UTILS_EX_H diff --git a/inc/graph/utils/op_type_utils.h b/inc/graph/utils/op_type_utils.h deleted file mode 100644 index 8142e97cd61c1b5a633d20bca167d02571fe84e5..0000000000000000000000000000000000000000 --- a/inc/graph/utils/op_type_utils.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef __INC_METADEF_OP_TYPE_UTILS_H -#define __INC_METADEF_OP_TYPE_UTILS_H -#include -#include "graph/node.h" - -namespace ge { -class OpTypeUtils { - public: - static bool IsDataNode(const std::string &type); - static bool IsInputRefData(const ge::OpDescPtr &op_desc); - static bool IsAutofuseNode(const std::string &type); - static bool IsAutofuseNode(const ge::OpDescPtr &op_desc); - static bool IsEmptyAutofuseNode(const std::string &type); - static bool IsVariableNode(const std::string &type); - static bool IsVarLikeNode(const std::string &type); - static bool IsAssignLikeNode(const std::string &type); - static bool IsIdentityLikeNode(const std::string &type); - static bool IsConstPlaceHolderNode(const std::string &type); - static graphStatus GetOriginalType(const ge::OpDescPtr &op_desc, std::string &type); - static bool IsSubgraphInnerData(const ge::OpDescPtr &op_desc); - // CONST/CONSTANT/CONSTPLACEHOLDER - static bool IsConstNode(const std::string &type); - // IsDataNode/IsInputRefData/IsVariableNode/IsVarLikeNode/IsConstNode - static bool IsGraphInputNode(const std::string &type); - // NETOUTPUT - static bool IsGraphOutputNode(const std::string &type); -}; -} // namespace ge -#endif // __INC_METADEF_OP_TYPE_UTILS_H diff --git a/inc/graph/utils/profiler.h b/inc/graph/utils/profiler.h deleted file mode 100644 index 1b98e58199d89142ee321a99de94083291806c40..0000000000000000000000000000000000000000 --- a/inc/graph/utils/profiler.h +++ /dev/null @@ -1,79 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_PROFILER_H -#define METADEF_CXX_PROFILER_H -#include -#include -#include -#include -#include -#include "external/graph/types.h" - -namespace ge { -namespace profiling { -constexpr size_t kMaxStrLen = 256UL; -constexpr int64_t kMaxStrIndex = 1024 * 1024; -constexpr size_t kMaxRecordNum = 10UL * 1024UL * 1024UL; -enum class EventType { - kEventStart, - kEventEnd, - kEventTimestamp, - kEventTypeEnd -}; -struct ProfilingRecord { - int64_t element; - int64_t thread; - int64_t event; - EventType et; - std::chrono::time_point timestamp; -}; - -struct StrHash { - char_t str[kMaxStrLen]; - uint64_t hash; -}; - -class Profiler { - public: - static std::unique_ptr Create(); - void UpdateHashByIndex(const int64_t index, const uint64_t hash); - void RegisterString(const int64_t index, const std::string &str); - void RegisterStringHash(const int64_t index, const uint64_t hash, const std::string &str); - void Record(const int64_t element, const int64_t thread, const int64_t event, const EventType et, - const std::chrono::time_point time_point); - void RecordCurrentThread(const int64_t element, const int64_t event, const EventType et); - void RecordCurrentThread(const int64_t element, const int64_t event, const EventType et, - const std::chrono::time_point time_point); - - void Reset(); - void Dump(std::ostream &out_stream) const; - - size_t GetRecordNum() const noexcept; - const ProfilingRecord *GetRecords() const; - - using ConstStringHashesPointer = StrHash const(*); - using StringHashesPointer = StrHash (*); - ConstStringHashesPointer GetStringHashes() const; - StringHashesPointer GetStringHashes() ; - - ~Profiler(); - Profiler(); - - private: - void DumpByIndex(const int64_t index, std::ostream &out_stream) const; - - private: - std::atomic record_size_; - std::array records_; - StrHash indexes_to_str_hashes_[kMaxStrIndex]; -}; -} -} -#endif // METADEF_CXX_PROFILER_H diff --git a/inc/graph/utils/recover_ir_utils.h b/inc/graph/utils/recover_ir_utils.h deleted file mode 100644 index 1cea3993ab9357a2717b377cc932c5682d0f6442..0000000000000000000000000000000000000000 --- a/inc/graph/utils/recover_ir_utils.h +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ -#ifndef METADEF_CXX_INC_GRAPH_UTILS_RECOVER_IR_UTILS_H_ -#define METADEF_CXX_INC_GRAPH_UTILS_RECOVER_IR_UTILS_H_ -namespace ge { -class RecoverIrUtils { - public: - using InputIrDefs = std::vector>; - using OutputIrDefs = std::vector>; - template - using IrDefAppender = - std::function; - - struct IrDefinition { - bool inited{false}; - bool has_ir_definition{false}; - std::vector attr_names; - std::map attr_value; - InputIrDefs inputs; - OutputIrDefs outputs; - ge::OpDescPtr op_desc{nullptr}; - }; - static ge::graphStatus RecoverOpDescIrDefinition(const ge::OpDescPtr &desc, - const std::string &op_type, - IrDefinition &ir_def); - static void InitIrDefinitionsIfNeed(const std::string &op_type, IrDefinition &ir_def); - static graphStatus RecoverIrAttrNames(const ge::OpDescPtr &desc, IrDefinition &ir_def); - static graphStatus RecoverIrInputAndOutput(const ge::OpDescPtr &desc, IrDefinition &ir_def); - static graphStatus RecoverIrDefinitions(const ge::ComputeGraphPtr &graph, const vector &attr_names = {}); - static graphStatus RecoverOpDescIrDefinition(const ge::OpDescPtr &desc, const std::string &op_type = ""); -}; -} - -#endif // METADEF_CXX_INC_GRAPH_UTILS_RECOVER_IR_UTILS_H_ diff --git a/inc/graph/utils/tensor_adapter.h b/inc/graph/utils/tensor_adapter.h deleted file mode 100644 index 89e8836f54966632c7cec115ddf91f31e81ccbc4..0000000000000000000000000000000000000000 --- a/inc/graph/utils/tensor_adapter.h +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_UTILS_TENSOR_ADAPTER_H_ -#define INC_GRAPH_UTILS_TENSOR_ADAPTER_H_ - -#include -#include "graph/ge_tensor.h" -#include "graph/tensor.h" -#include "graph/ge_attr_value.h" - -namespace ge { -class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TensorAdapter { - public: - static GeTensorDesc TensorDesc2GeTensorDesc(const TensorDesc &tensor_desc); - static TensorDesc GeTensorDesc2TensorDesc(const GeTensorDesc &ge_tensor_desc); - static Tensor GeTensor2Tensor(const ConstGeTensorPtr &ge_tensor); - - static ConstGeTensorPtr AsGeTensorPtr(const Tensor &tensor); // Share value - static GeTensorPtr AsGeTensorPtr(Tensor &tensor); // Share value - static const GeTensor AsGeTensor(const Tensor &tensor); // Share value - static GeTensor AsGeTensor(Tensor &tensor); // Share value - static const Tensor AsTensor(const GeTensor &ge_tensor); // Share value - static Tensor AsTensor(GeTensor &ge_tensor); // Share value - static GeTensor AsGeTensorShared(const Tensor &tensor); - static GeTensor NormalizeGeTensor(const GeTensor &tensor); - static void NormalizeGeTensorDesc(GeTensorDesc &tensor_desc); - static const GeTensor* AsBareGeTensorPtr(const Tensor &tensor); -}; -} // namespace ge -#endif // INC_GRAPH_UTILS_TENSOR_ADAPTER_H_ diff --git a/inc/graph/utils/tensor_utils.h b/inc/graph/utils/tensor_utils.h deleted file mode 100644 index f11e692fde36e3444e9716bd19fb7443956b70d4..0000000000000000000000000000000000000000 --- a/inc/graph/utils/tensor_utils.h +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_GRAPH_UTILS_TENSOR_UTILS_H_ -#define INC_GRAPH_UTILS_TENSOR_UTILS_H_ - -#include -#include "graph/attr_value_serializable.h" -#include "graph/def_types.h" -#include "graph/ge_error_codes.h" -#include "graph/ge_tensor.h" - -namespace ge { -class TensorUtils { - public: - static GeTensor CreateShareTensor(const GeTensor &other); - static GeTensor CreateShareTensor(const GeTensorDesc &tensor_desc, - std::shared_ptr aligned_ptr, - const size_t size); - static void ShareTensor(const GeTensor &from, GeTensor &to); - static TensorData CreateShareTensorData(const TensorData &other); - static void ShareTensorData(const TensorData &from, TensorData &to); - static void ShareAlignedPtr(std::shared_ptr ptr, const size_t size, TensorData &to); - static void ShareAlignedPtr(std::shared_ptr ptr, const size_t size, GeTensor &to); - static void CopyTensor(const GeTensor &from, GeTensor &to); - static ge::graphStatus GetSize(const GeTensorDesc &tensor_desc, int64_t &size); - static void SetSize(GeTensorDesc &tensor_desc, const int64_t size); - static int64_t GetWeightSize(const ConstGeTensorPtr &tensor_ptr); - static int64_t GetWeightSize(const GeTensor &tensor); - static int64_t GetWeightSize(const GeTensorDesc &tensor_desc); - static uint8_t *GetWeightAddr(const ConstGeTensorPtr &tensor_ptr, const uint8_t *const base); - static uint8_t *GetWeightAddr(const GeTensor &tensor, const uint8_t *const base); - static void SetWeightSize(GeTensorDesc &tensor_desc, const int64_t size); - static ge::graphStatus GetReuseInput(const GeTensorDesc &tensor_desc, bool &flag); - static void SetReuseInput(GeTensorDesc &tensor_desc, const bool flag); - static ge::graphStatus GetOutputTensor(const GeTensorDesc &tensor_desc, bool &flag); - static void SetOutputTensor(GeTensorDesc &tensor_desc, const bool flag); - static graphStatus GetDeviceType(const GeTensorDesc &tensor_desc, DeviceType &type); - static void SetDeviceType(GeTensorDesc &tensor_desc, const DeviceType type); - static ge::graphStatus GetInputTensor(const GeTensorDesc &tensor_desc, bool &flag); - static void SetInputTensor(GeTensorDesc &tensor_desc, const bool flag); - static ge::graphStatus GetRealDimCnt(const GeTensorDesc &tensor_desc, uint32_t &cnt); - static void SetRealDimCnt(GeTensorDesc &tensor_desc, const uint32_t cnt); - static ge::graphStatus GetReuseInputIndex(const GeTensorDesc &tensor_desc, uint32_t &idx); - static void SetReuseInputIndex(GeTensorDesc &tensor_desc, const uint32_t idx); - static ge::graphStatus GetDataOffset(const GeTensorDesc &tensor_desc, int64_t &offset); - static void SetDataOffset(GeTensorDesc &tensor_desc, const int64_t offset); - static ge::graphStatus GetRC(const GeTensorDesc &tensor_desc, uint32_t &rc); - static void SetRC(GeTensorDesc &tensor_desc, const uint32_t rc); - static bool IsOriginShapeInited(const GeTensorDesc &tensor_desc); - - static ge::graphStatus CalcTensorMemSize(const GeShape &shape, const Format format, - const DataType data_type, int64_t &mem_size); - static ge::graphStatus CalcTensorMemSizeForNoTiling(const GeTensorDesc &tensor, - const Format format, - const DataType data_type, - int64_t &mem_size); - static ge::graphStatus GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp); - static ge::graphStatus GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp); - static ge::graphStatus CheckShapeByShapeRange(const GeShape &shape, - const std::vector> &shape_range); - static bool IsShapeEqual(const GeShape &src, const GeShape &dst); -}; -} // namespace ge -#endif // INC_GRAPH_UTILS_TENSOR_UTILS_H_ diff --git a/inc/graph/yuv_subformat.h b/inc/graph/yuv_subformat.h deleted file mode 100644 index ad177bf6d92e505ba8fee5ae04b6a8776d3ebc7c..0000000000000000000000000000000000000000 --- a/inc/graph/yuv_subformat.h +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_YUV_SUBFORMAT_H_ -#define INC_YUV_SUBFORMAT_H_ -namespace ge { -enum YUVSubFormat { - YUV420_SP = 1, - YVU420_SP, - YUV422_SP, - YVU422_SP, - YUV440_SP, - YVU440_SP, - YUV444_SP, - YVU444_SP, - YUYV422_PACKED, - YVYU422_PACKED, - YUV444_PACKED, - YVU444_PACKED, - YUV400 -}; -} // namespace ge -#endif // INC_YUV_SUBFORMAT_H_ diff --git a/inc/register/custom_pass_helper.h b/inc/register/custom_pass_helper.h deleted file mode 100644 index dd57549d9c4b10697cb3033a76f395c29261f2f1..0000000000000000000000000000000000000000 --- a/inc/register/custom_pass_helper.h +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_CUSTOM_PASS_HELPER_H_ -#define INC_REGISTER_CUSTOM_PASS_HELPER_H_ - -#include -#include "external/ge_common/ge_api_error_codes.h" -#include "external/register/register_custom_pass.h" -#include "external/register/register_types.h" - -namespace ge { -class CustomPassHelper { - public: - static CustomPassHelper &Instance(); - - void Insert(const PassRegistrationData ®_data); - - Status Load(); - - Status Unload(); - - Status Run(GraphPtr &graph, CustomPassContext &custom_pass_context) const; - - Status Run(GraphPtr &graph, CustomPassContext &custom_pass_context, const CustomPassStage stage) const; - - ~CustomPassHelper() = default; - - private: - CustomPassHelper() = default; - std::vector registration_datas_; - std::vector handles_; -}; -} // namespace ge - -#endif // INC_REGISTER_CUSTOM_PASS_HELPER_H_ diff --git a/inc/register/ffts_node_converter_registry.h b/inc/register/ffts_node_converter_registry.h deleted file mode 100644 index 9b034ef545287d18d3bbea4ab7f934899b57b47e..0000000000000000000000000000000000000000 --- a/inc/register/ffts_node_converter_registry.h +++ /dev/null @@ -1,132 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef AIR_CXX_RUNTIME_V2_LOWERING_FFTS_NODE_CONVERTER_REGISTRY_H_ -#define AIR_CXX_RUNTIME_V2_LOWERING_FFTS_NODE_CONVERTER_REGISTRY_H_ -#include -#include -#include -#include "common/checker.h" -#include "node_converter_registry.h" -#include "graph/node.h" -#include "exe_graph/lowering/dev_mem_value_holder.h" -#include "exe_graph/lowering/lowering_global_data.h" - -namespace gert { -using FFTSPreThreadFunc = std::function &input_shapes, std::vector &output)>; -using FFTSPreThreadFuncNew = std::function &input_shapes, const std::vector &input_addrs, - std::vector &output)>; -using FFTSThreadFunc = std::function &input_shapes, - const std::vector &output_shapes, const bg::ValueHolderPtr thread_dim, - std::vector &output)>; - -struct SkipCtxRecord { - bool Init() { - ctx_id_v = std::unique_ptr>(new(std::nothrow) std::vector); - GE_ASSERT_NOTNULL(ctx_id_v); - ctx_type_v = std::unique_ptr>(new(std::nothrow) std::vector); - GE_ASSERT_NOTNULL(ctx_type_v); - return true; - } - size_t GetCtxNum() { - if (ctx_id_v == nullptr) { - return 0; - } - return ctx_id_v->size(); - } - bool SetSkipCtx(uint32_t ctx_id, uint32_t ctx_type) { - if (ctx_id_v == nullptr || ctx_type_v == nullptr) { - return false; - } - ctx_id_v->emplace_back(ctx_id); - ctx_type_v->emplace_back(ctx_type); - return true; - } - bool GetSkipCtx(size_t idx, uint32_t &ctx_id, uint32_t &ctx_type) { - if (ctx_id_v == nullptr || ctx_type_v == nullptr) { - return false; - } - if (idx >= ctx_id_v->size() || idx >= ctx_type_v->size()) { - return false; - } - ctx_id = ctx_id_v->at(idx); - ctx_type = ctx_type_v->at(idx); - return true; - } - void ClearRecord() { - if (ctx_id_v != nullptr) { - ctx_id_v->clear(); - } - if (ctx_type_v != nullptr) { - ctx_type_v->clear(); - } - return; - } - private: - std::unique_ptr> ctx_id_v{nullptr}; - std::unique_ptr> ctx_type_v{nullptr}; -}; - -struct FFTSLowerInput { - std::vector input_shapes; - std::vector input_addrs; - std::vector mem_pool_types; - LoweringGlobalData *global_data; - bg::ValueHolderPtr task_info; - bg::ValueHolderPtr thread_dim; - bg::ValueHolderPtr window_size; - bg::ValueHolderPtr args_para; - bg::ValueHolderPtr ffts_mem_allocator; - FFTSThreadFunc ffts_thread_fun; - bg::ValueHolderPtr skip_ctx_holder; -}; -class FFTSNodeConverterRegistry { - public: - using NodeConverter = LowerResult (*)(const ge::NodePtr &node, const FFTSLowerInput &lower_input); - struct ConverterRegisterData { - NodeConverter converter; - int32_t require_placement; - }; - static FFTSNodeConverterRegistry &GetInstance(); - NodeConverter FindNodeConverter(const std::string &func_name); - const ConverterRegisterData *FindRegisterData(const std::string &func_name) const; - void RegisterNodeConverter(const std::string &func_name, NodeConverter func); - void Register(const std::string &func_name, const ConverterRegisterData &data); - - private: - std::unordered_map names_to_register_data_; -}; - -class FFTSNodeConverterRegister { - public: - FFTSNodeConverterRegister(const char *lower_func_name, FFTSNodeConverterRegistry::NodeConverter func) noexcept; - FFTSNodeConverterRegister(const char *lower_func_name, int32_t require_placement, - FFTSNodeConverterRegistry::NodeConverter func) noexcept; -}; -} // namespace gert - -#ifdef __GNUC__ -#define ATTRIBUTE_USED __attribute__((used)) -#else -#define ATTRIBUTE_USED -#endif - -#define GERT_REGISTER_FFTS_NODE_CONVERTER_COUNTER2(type, placement, func, counter) \ - static const gert::FFTSNodeConverterRegister g_register_node_converter_##counter ATTRIBUTE_USED = \ - gert::FFTSNodeConverterRegister(type, placement, func) -#define GERT_REGISTER_FFTS_NODE_CONVERTER_COUNTER(type, placement, func, counter) \ - GERT_REGISTER_FFTS_NODE_CONVERTER_COUNTER2(type, placement, func, counter) -#define FFTS_REGISTER_NODE_CONVERTER_PLACEMENT(type, placement, func) \ - GERT_REGISTER_FFTS_NODE_CONVERTER_COUNTER(type, placement, func, __COUNTER__) -#define FFTS_REGISTER_NODE_CONVERTER(type, func) FFTS_REGISTER_NODE_CONVERTER_PLACEMENT(type, -1, func) - -#endif // AIR_CXX_RUNTIME_V2_LOWERING_FFTS_NODE_CONVERTER_REGISTRY_H_ diff --git a/inc/register/ffts_plus_engine_update.h b/inc/register/ffts_plus_engine_update.h deleted file mode 100644 index aaa83e56f376aceab7b0813cb2f137c72561fabd..0000000000000000000000000000000000000000 --- a/inc/register/ffts_plus_engine_update.h +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_FFTS_PLUS_ENGINE_UPDATE_H_ -#define INC_REGISTER_FFTS_PLUS_ENGINE_UPDATE_H_ -#include "graph/utils/graph_utils.h" -#include "graph/utils/tensor_utils.h" -#include "runtime/rt_ffts_plus.h" -#include "common/sgt_slice_type.h" -namespace ffts { -class FFTSPlusEngineUpdate { -public: - FFTSPlusEngineUpdate(); - ~FFTSPlusEngineUpdate(); - static bool UpdateCommonCtx(ge::ComputeGraphPtr &sgt_graph, rtFftsPlusTaskInfo_t &task_info); - static ThreadSliceMapDyPtr slice_info_ptr_; -}; -}; -#endif // INC_REGISTER_FFTS_PLUS_ENGINE_UPDATE_H_ diff --git a/inc/register/ffts_plus_task_update.h b/inc/register/ffts_plus_task_update.h deleted file mode 100644 index 65c6228d5904433be77176434c447e6ba8e65cab..0000000000000000000000000000000000000000 --- a/inc/register/ffts_plus_task_update.h +++ /dev/null @@ -1,87 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_FFTS_PLUS_TASK_UPDATE_H_ -#define INC_REGISTER_FFTS_PLUS_TASK_UPDATE_H_ - -#include - -#include "graph/node.h" -#include "register/op_tiling_registry.h" -#include "runtime/rt_ffts_plus.h" -#include "external/ge_common/ge_api_error_codes.h" - -namespace ge { -struct AutoThreadSubTaskFlush { - int32_t device_id{0}; - void *args_base{nullptr}; - std::vector op_run_info; - - uintptr_t aic_non_tail_task_start_pc{0U}; - uintptr_t aic_tail_task_start_pc{0U}; - uint32_t aic_icache_prefetch_cnt{0U}; - - uintptr_t aiv_non_tail_task_start_pc{0U}; - uintptr_t aiv_tail_task_start_pc{0U}; - uint32_t aiv_icache_prefetch_cnt{0U}; - - // Task I/O Addrs. - std::vector input_addr_base; - std::vector output_addr_base; -}; - -struct AutoThreadParam { - uint16_t thread_dim{0U}; // thread dim after Pre-Thread - uint32_t input_output_num{0U}; // input + output - std::vector task_addr_offset; // input + output + workspace - - // Task Thread Dims. - std::vector>> *task_input_shape{nullptr}; // thread - std::vector>> *task_output_shape{nullptr}; // thread -}; - -class FFTSPlusTaskUpdate { - public: - FFTSPlusTaskUpdate() = default; - virtual ~FFTSPlusTaskUpdate() = default; - - virtual Status GetAutoThreadParam(const NodePtr &node, const std::vector &op_run_info, - AutoThreadParam &auto_thread_param) { - (void)node; - (void)op_run_info; - (void)auto_thread_param; - return SUCCESS; - } - - virtual Status UpdateSubTaskAndCache(const NodePtr &node, const AutoThreadSubTaskFlush &sub_task_flush, - rtFftsPlusTaskInfo_t &ffts_plus_task_info) { - (void)node; - (void)sub_task_flush; - (void)ffts_plus_task_info; - return SUCCESS; - } - - virtual Status UpdateCommonCtx(const ComputeGraphPtr &sgt_graph, rtFftsPlusTaskInfo_t &task_info) { - (void)sgt_graph; - (void)task_info; - return SUCCESS; - } - - virtual Status UpdateStaticDataCtx(size_t ctx_num, std::vector &io_addrs, size_t align_offset, - size_t host_io_base, std::map> &ctx_ids_map) { - (void)ctx_num; - (void)io_addrs; - (void)align_offset; - (void)host_io_base; - (void)ctx_ids_map; - return SUCCESS; - } -}; -} // namespace ge -#endif // INC_REGISTER_FFTS_PLUS_TASK_UPDATE_H_ diff --git a/inc/register/ffts_plus_update_manager.h b/inc/register/ffts_plus_update_manager.h deleted file mode 100644 index 798598224f453845eb5d3133211b49ae4282b397..0000000000000000000000000000000000000000 --- a/inc/register/ffts_plus_update_manager.h +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_FFTS_PLUS_UPDATE_MANAGER_H_ -#define INC_REGISTER_FFTS_PLUS_UPDATE_MANAGER_H_ - -#include -#include -#include -#include -#include - -#include "register/ffts_plus_task_update.h" - -namespace ge { -using FftsCtxUpdatePtr = std::shared_ptr; -using FftsCtxUpdateCreatorFun = std::function; -class PluginManager; - -class FftsPlusUpdateManager { - public: - static FftsPlusUpdateManager &Instance(); - /** - * For load so to register FFTSPlusTaskUpdate subclass constructor. - */ - Status Initialize(); - - /** - * Get FFTS Plus context by core type. - * @param core_type: core type of Node - * @return FFTS Plus Update instance. - */ - FftsCtxUpdatePtr GetUpdater(const std::string &core_type) const; - - class FftsPlusUpdateRegistrar { - public: - FftsPlusUpdateRegistrar(const std::string &core_type, const FftsCtxUpdateCreatorFun &creator) { - FftsPlusUpdateManager::Instance().RegisterCreator(core_type, creator); - } - ~FftsPlusUpdateRegistrar() = default; - }; - - private: - FftsPlusUpdateManager() = default; - ~FftsPlusUpdateManager(); - - /** - * Register FFTS Plus context update executor. - * @param core_type: core type of Node - * @param creator: FFTS Plus Update instance Creator. - */ - void RegisterCreator(const std::string &core_type, const FftsCtxUpdateCreatorFun &creator); - - std::map creators_; - std::unique_ptr plugin_manager_; -}; -} // namespace ge - -#define REGISTER_FFTS_PLUS_CTX_UPDATER(core_type, task_clazz) \ - REGISTER_FFTS_PLUS_CTX_TASK_UPDATER_UNIQ_HELPER(__COUNTER__, core_type, task_clazz) - -#define REGISTER_FFTS_PLUS_CTX_TASK_UPDATER_UNIQ_HELPER(ctr, type, clazz) \ - REGISTER_FFTS_PLUS_CTX_TASK_UPDATER_UNIQ(ctr, type, clazz) - -#define REGISTER_FFTS_PLUS_CTX_TASK_UPDATER_UNIQ(ctr, type, clazz) \ - ge::FftsPlusUpdateManager::FftsPlusUpdateRegistrar g_##type##_creator##ctr((type), []() { \ - return std::shared_ptr(new(std::nothrow) (clazz)()); \ - }) - -#endif // INC_REGISTER_FFTS_PLUS_UPDATE_MANAGER_H_ diff --git a/inc/register/graph_optimizer/buffer_fusion/buffer_fusion_constant.h b/inc/register/graph_optimizer/buffer_fusion/buffer_fusion_constant.h deleted file mode 100644 index f2e418c95bc66f851d7da5db4f6aa08dcb327f4a..0000000000000000000000000000000000000000 --- a/inc/register/graph_optimizer/buffer_fusion/buffer_fusion_constant.h +++ /dev/null @@ -1,93 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_CONSTANT_H_ -#define INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_CONSTANT_H_ -#include -#include - -namespace fe { -const std::string MODIFY_GRAPH_IN_UB_FUSION_PASS = "ModifyGraphInUBFusionPass"; -const std::string UB_FUSION_OP_TYPE = "_ub_fusion_op_type"; -// add the op pattern -const std::string TBE_PATTERN_INPUT_NODE = "InputData"; -const std::string TBE_PATTERN_OP_TYPE_ANY = "OpTypeAny"; -const std::string TBE_PATTERN_OUTPUT_NODE = "OutputData"; -const std::string OP_PATTERN_ELEMWISE = "ElemWise"; -const std::string OP_PATTERN_COMMONREDUCE = "CommReduce"; -const std::string OP_PATTERN_BROAD_CAST = "Broadcast"; -const std::string OP_PATTERN_BROAD_CAST_NZ = "Broadcast_Nz"; -const std::string OP_PATTERN_SEGMENT = "Segment"; -const std::string OP_PATTERN_MAXPOOL = "MaxPool"; -const std::string OP_PATTERN_CONV = "Convolution"; -const std::string OP_PATTERN_MATMUL = "Matmul"; -const std::string OP_PATTERN_BNUPDATE = "bn_update"; -const std::string OP_PATTERN_BNREDUCE = "bn_reduce"; -const std::string OP_PATTERN_BNTUPLEREDUCE = "TupleReduce"; -const std::string OP_PATTERN_CONV_BACKPROP_INPUT = "Conv2d_backprop_input"; -const std::string OP_PATTERN_DEPTHWISE_CONV = "DepthwiseConvolution"; -const std::string OP_PATTERN_QUANT = "quant"; -const std::string OP_PATTERN_DEQUANT = "dequant"; -const std::string OP_PATTERN_REQUANT = "requant"; -const std::string OP_PATTERN_POOL2D = "Pool2d"; -const std::string OP_PATTERN_ANTIQUANT = "anti_quant"; -const std::string OP_PATTERN_STRIDED_WRITE = "strided_write"; -const std::string OP_PATTERN_STRIDED_READ = "strided_read"; -const std::string OP_PATTERN_AIPP = "aipp"; -const std::string OP_PATTERN_CONFUSION_TRANSPOSE = "confusiontranspose"; -const std::string OP_PATTERN_DEQUANTS16 = "dequant_s16"; -const std::string OP_PATTERN_REQUANTS16 = "requant_s16"; -const std::string OP_PATTERN_READ_SELECT = "read_select"; -const std::string OP_PATTERN_WRITE_SELECT = "write_select"; -const std::string OP_PATTERN_BATCH_MATMUL = "BatchMatmul"; -const std::string OP_PATTERN_CONV3D = "Conv3d"; -const std::string OP_PATTERN_DROPOUTDOMASKV3D = "DropOutDoMaskV3D"; -const std::string OP_PATTERN_CONV3D_BACKPROP_INPUT = "Conv3d_backprop_input"; -const std::string OP_PATTERN_CONV_BACKPROP_FILTER = "Conv2d_backprop_filter"; -const std::string OP_PATTERN_GEMM = "GEMM"; -const std::string OP_PATTERN_FIXPIPE = "fixpipe"; -const std::string OP_PATTERN_AVGPOOLUPDATE = "AvgPoolUpdate"; -const std::vector OP_PATTERN_VEC{OP_PATTERN_ELEMWISE, - OP_PATTERN_COMMONREDUCE, - OP_PATTERN_BROAD_CAST, - OP_PATTERN_BROAD_CAST_NZ, - OP_PATTERN_SEGMENT, - OP_PATTERN_MAXPOOL, - OP_PATTERN_CONV, - OP_PATTERN_MATMUL, - OP_PATTERN_BNUPDATE, - OP_PATTERN_BNREDUCE, - OP_PATTERN_BNTUPLEREDUCE, - OP_PATTERN_CONV_BACKPROP_INPUT, - OP_PATTERN_DEPTHWISE_CONV, - OP_PATTERN_QUANT, - OP_PATTERN_DEQUANT, - OP_PATTERN_REQUANT, - OP_PATTERN_POOL2D, - OP_PATTERN_ANTIQUANT, - OP_PATTERN_STRIDED_WRITE, - OP_PATTERN_STRIDED_READ, - OP_PATTERN_AIPP, - OP_PATTERN_CONFUSION_TRANSPOSE, - OP_PATTERN_DEQUANTS16, - OP_PATTERN_REQUANTS16, - OP_PATTERN_READ_SELECT, - OP_PATTERN_WRITE_SELECT, - OP_PATTERN_BATCH_MATMUL, - OP_PATTERN_CONV3D, - OP_PATTERN_DROPOUTDOMASKV3D, - OP_PATTERN_CONV3D_BACKPROP_INPUT, - OP_PATTERN_CONV_BACKPROP_FILTER, - OP_PATTERN_GEMM, - OP_PATTERN_FIXPIPE, - OP_PATTERN_AVGPOOLUPDATE -}; -} // namespace fe - -#endif // INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_CONSTANT_H_ diff --git a/inc/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_base.h b/inc/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_base.h deleted file mode 100644 index 6e076cb1e63739003132a56d45e9f55a37b291eb..0000000000000000000000000000000000000000 --- a/inc/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_base.h +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_PASS_BASE_H_ -#define INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_PASS_BASE_H_ - -#include -#include -#include -#include "register/graph_optimizer/buffer_fusion/buffer_fusion_constant.h" -#include "register/graph_optimizer/buffer_fusion/buffer_fusion_pattern.h" -#include "register/graph_optimizer/graph_optimize_register_error_codes.h" -#include "register/graph_optimizer/fusion_common/op_slice_info.h" - -namespace fe { -enum BufferFusionPassType { - BUILT_IN_AI_CORE_BUFFER_FUSION_PASS, - BUILT_IN_VECTOR_CORE_BUFFER_FUSION_PASS, - CUSTOM_AI_CORE_BUFFER_FUSION_PASS, - CUSTOM_VECTOR_CORE_BUFFER_FUSION_PASS, - BUFFER_FUSION_PASS_TYPE_RESERVED -}; - -class BufferFusionPassBase { - public: - explicit BufferFusionPassBase(); - virtual ~BufferFusionPassBase(); - virtual std::vector DefinePatterns() = 0; - virtual Status GetFusionNodes(const BufferFusionMapping &mapping, std::vector &fusion_nodes); - virtual Status GetMixl2FusionNodes(const BufferFusionMapping &mapping, std::vector &fusion_nodes); - virtual Status PostFusion(const ge::NodePtr &fused_node); - virtual Status CalcFusionOpSliceInfo(std::vector &fusion_nodes, OpCalcInfo &op_slice_info); - virtual Status CheckNodeCanFusion(const BufferFusionNodeDescMap &fusion_nodes, const ge::NodePtr &next_node); - static std::vector GetMatchedNodes(const BufferFusionMapping &mapping); - static std::vector GetMatchedNodesByDescName(const std::string &desc_name, - const BufferFusionMapping &mapping); - static ge::NodePtr GetMatchedHeadNode(const std::vector &matched_nodes); - static bool CheckNodeIsDynamicImpl(const ge::NodePtr &node); - static bool CheckTwoNodesImplConsistent(const ge::NodePtr &src_node, const ge::NodePtr &dst_node); - static bool CheckNodesImplConsistent(const BufferFusionMapping &mapping); - static bool CheckNodesImplConsistent(const std::vector &fusion_nodes); - static bool CheckNodeIsDynamicShape(const ge::NodePtr& node); - static bool CheckNodesIncDynamicShape(const BufferFusionMapping &mapping); - static bool CheckNodesIncDynamicShape(const std::vector &fusion_nodes); - void SetName(const std::string &name) { name_ = name; } - - std::string GetName() { return name_; } - - private: - std::string name_; -}; - -} // namespace fe - -#endif // INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_PASS_BASE_H_ diff --git a/inc/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_registry.h b/inc/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_registry.h deleted file mode 100644 index c193926288b02771562e52048fb92de24a33a90d..0000000000000000000000000000000000000000 --- a/inc/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_registry.h +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_BUFFER_FUSION_PASS_REGISTRY_H_ -#define INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_BUFFER_FUSION_PASS_REGISTRY_H_ -#include -#include -#include -#include -#include "register/graph_optimizer/fusion_common/fusion_pass_desc.h" -#include "register/graph_optimizer/buffer_fusion/buffer_fusion_pass_base.h" - -namespace fe { -class BufferFusionPassRegistry { - public: - using CreateFn = BufferFusionPassBase *(*)(); - struct PassDesc { - PassAttr attr; - CreateFn create_fn; - }; - ~BufferFusionPassRegistry(); - - static BufferFusionPassRegistry &GetInstance(); - - void RegisterPass(const BufferFusionPassType pass_type, const std::string &pass_name, - CreateFn create_fn, PassAttr attr); - - std::map GetPassDesc(const BufferFusionPassType &pass_type); - - std::map GetCreateFnByType(const BufferFusionPassType &pass_type); - - private: - BufferFusionPassRegistry(); - class BufferFusionPassRegistryImpl; - std::unique_ptr impl_; -}; - -class BufferFusionPassRegistrar { - public: - BufferFusionPassRegistrar(const BufferFusionPassType &pass_type, const std::string &pass_name, - BufferFusionPassBase *(*create_fun)(), PassAttr attr); - - ~BufferFusionPassRegistrar() {} -}; - -#define REGISTER_BUFFER_FUSION_PASS(pass_name, pass_type, pass_class) \ - REG_BUFFER_FUSION_PASS(pass_name, pass_type, pass_class, 0) - -#define REG_BUFFER_FUSION_PASS(pass_name, pass_type, pass_class, attr) \ - REG_BUFFER_FUSION_PASS_UNIQ_HELPER(__COUNTER__, pass_name, pass_type, pass_class, attr) - -#define REG_BUFFER_FUSION_PASS_UNIQ_HELPER(ctr, pass_name, pass_type, pass_class, attr) \ - REG_BUFFER_FUSION_PASS_UNIQ(ctr, pass_name, pass_type, pass_class, attr) - -#define REG_BUFFER_FUSION_PASS_UNIQ(ctr, pass_name, pass_type, pass_class, attr) \ - static ::fe::BufferFusionPassRegistrar register_buffer_fusion_##ctr __attribute__((unused)) = \ - ::fe::BufferFusionPassRegistrar( \ - (pass_type), (pass_name), \ - []() -> ::fe::BufferFusionPassBase * { return new (std::nothrow) pass_class();}, (attr)) -} // namespace fe -#endif // INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_BUFFER_FUSION_PASS_REGISTRY_H_ diff --git a/inc/register/graph_optimizer/buffer_fusion/buffer_fusion_pattern.h b/inc/register/graph_optimizer/buffer_fusion/buffer_fusion_pattern.h deleted file mode 100644 index 5b2d48414effdb2995580ef1676b4a46dbf31e6b..0000000000000000000000000000000000000000 --- a/inc/register/graph_optimizer/buffer_fusion/buffer_fusion_pattern.h +++ /dev/null @@ -1,152 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_PATTERN_H_ -#define INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_PATTERN_H_ -#include -#include -#include -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/attr_utils.h" -#include "graph/utils/graph_utils.h" - -namespace fe { -extern const int64_t TBE_FUSION_OP_NUM_MAX; -extern const int64_t TBE_PATTERN_NUM_MAX; -extern const int64_t TBE_PATTERN_NUM_NONE; -extern const int64_t TBE_PATTERN_NUM_DEFAULT; -extern const int64_t TBE_OUTPUT_BRANCH_DEFAULT; -extern const int64_t TBE_OUTPUT_BRANCH_SINGLE; -extern const int64_t TBE_OUTPUT_BRANCH_MULTI; -extern const int64_t TBE_PATTERN_GROUPID_INVALID; -extern const int32_t TBE_OUTPUT_MAX_NUM_LIMIT; - -enum SkipStatus { DISABLED = 0, AVAILABLE = 1, SKIPPED = 2 }; - -enum ShapeTypeRule { IGNORE_SHAPE_TYPE = 0, ONLY_SUPPORT_STATIC, ONLY_SUPPORT_DYNAMIC }; - -enum class PatternRelation { RELATIVE_POSITION_CONSISTENT = 0 }; - -extern const std::map kShapeTypeRuleToStr; - -struct BufferFusionOpDesc { - std::string desc_name; // description name - std::vector types; // description type - std::vector inputs; // all input op - std::vector outputs; // all output op - int64_t out_branch_type; // out desc type, 1:single, 2: multi - int64_t repeate_min; // opdesc min repeat num - int64_t repeate_max; // opdesc max repeat num - int64_t repeate_curr; // opdesc current repeat num - bool match_status; - bool not_pattern; - int64_t group_id; // record desc groupid, need one desc matched at least in - // the same group - std::vector shape_type_rules; - bool ignore_input_num; - bool ignore_output_num; - bool is_allow_series; // whether the nodes with the same pattern can be series in match graph - int32_t output_max_limit; - // used for two connected op, first opdesc has optional multiple nodes and - // ignore_output_num is true, second opdesc is same pattern type and - // out_branch_type is TBE_OUTPUT_BRANCH_MULTI - std::map multi_output_skip_status; - std::vector> relations; -}; - -struct MappingCmpKey { - bool operator() (const BufferFusionOpDesc *key1, const BufferFusionOpDesc *key2) const { - return (key1->desc_name) < (key2->desc_name); - } -}; -using BufferFusionMapping = std::map, MappingCmpKey>; -using BufferFusionNodeDescMap = std::unordered_map; - -class BufferFusionPattern { - public: - explicit BufferFusionPattern(std::string name = "", int64_t op_max_count = TBE_FUSION_OP_NUM_MAX); - - virtual ~BufferFusionPattern(); - - /* - * types vector use one ShapeTypeRule - */ - BufferFusionPattern &AddOpDesc(const std::string &desc_name, const std::vector &types, - const int64_t repeat_min = TBE_PATTERN_NUM_DEFAULT, - const int64_t repeat_max = TBE_PATTERN_NUM_DEFAULT, - const int64_t group_id = TBE_PATTERN_GROUPID_INVALID, - const ShapeTypeRule shape_type_rule = ONLY_SUPPORT_STATIC, - const bool not_pattern = false, const bool is_allow_series = true); - -/** - * add node desc - * @param desc_name - * @param types - * @param repeat_min - * @param repeat_max - * @param is_allow_series - * @return ref - */ - BufferFusionPattern &AddOpDesc(const std::string &desc_name, const std::vector &types, - const int64_t repeat_min, const int64_t repeat_max, const bool is_allow_series); - - /* - * types vector use ShapeTypeRule vector, and size should be same or ShapeTypeRule size equal 1 - */ - BufferFusionPattern &AddOpDescTypeRules(const std::string &desc_name, const std::vector &types, - const int64_t repeat_min, const int64_t repeat_max, const int64_t group_id, - const std::vector &shape_type_rules, - const bool not_pattern = false, const bool is_allow_series = true); - - BufferFusionPattern &SetOutputs(const std::string &desc_name, const std::vector &output_ids, - int64_t relation = TBE_OUTPUT_BRANCH_SINGLE, bool ignore_input_num = false, - bool ignore_output_num = false, int32_t output_max_limit = TBE_OUTPUT_MAX_NUM_LIMIT); - - BufferFusionPattern &SetHead(const std::vector &head_ids); - - /** - * add node desc - * @param desc_name - * @param types - * @param repeat_min - * @param repeat_max - * @param is_allow_series - * @return ref - */ - BufferFusionPattern &SetRelation(const std::string &src_desc_name, const std::string &dst_desc_name, - const PatternRelation pattern_relation); - - const std::string& GetName() const; - int64_t GetOpMaxCount() const; - const std::vector& GetOpDescs() const; - const std::vector& GetHead() const; - int64_t GetErrorCnt() const; - void SetGraphModType(int64_t graph_mod_type); - int64_t GetGraphModType() const; - bool GetOutputs(BufferFusionOpDesc *op_desc, std::vector &outputs, bool ignore_repeat = false); - void IncreaseErrorCount(); - - private: - bool IsOpDescValid(const std::string &desc_name, int64_t repeat_min, int64_t repeat_max) const; - bool IsShapeRulesSizeValid(const size_t &types_size, const size_t &rules_size) const; - BufferFusionOpDesc *GetOpDesc(const std::string &desc_name) const; - void UpdateSkipStatus(const BufferFusionOpDesc *op_desc) const; - - std::string name_; - int64_t op_max_count_; - std::vector ops_; - std::map op_map_; - std::vector head_; - int64_t error_count_; - // 0: this pattern will not modify graph(default) - // 1: this pattern will modify graph - int64_t graph_mod_type_; -}; -} // namespace fe -#endif // INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_PATTERN_H_ diff --git a/inc/register/graph_optimizer/fusion_common/fusion_config_info.h b/inc/register/graph_optimizer/fusion_common/fusion_config_info.h deleted file mode 100644 index 0c84ed59ea5ce1818501dcd40d0c937ec389369d..0000000000000000000000000000000000000000 --- a/inc/register/graph_optimizer/fusion_common/fusion_config_info.h +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_GRAPH_OPTIMIZER_FUSION_COMMON_FUSION_CONFIG_INFO_H_ -#define INC_REGISTER_GRAPH_OPTIMIZER_FUSION_COMMON_FUSION_CONFIG_INFO_H_ - -#include "register/graph_optimizer/graph_optimize_register_error_codes.h" - -namespace fe { -class FusionConfigInfo { -public: - FusionConfigInfo(const FusionConfigInfo &) = delete; - FusionConfigInfo &operator=(const FusionConfigInfo &) = delete; - static FusionConfigInfo& Instance(); - Status Initialize(); - Status Finalize(); - bool IsEnableNetworkAnalysis() const; -private: - FusionConfigInfo() = default; - ~FusionConfigInfo() = default; - void InitEnvParam(); - bool is_init_ = false; - bool is_enable_network_analysis_ = false; -}; -} -#endif // INC_REGISTER_GRAPH_OPTIMIZER_FUSION_COMMON_FUSION_CONFIG_INFO_H_ diff --git a/inc/register/graph_optimizer/fusion_common/fusion_pass_desc.h b/inc/register/graph_optimizer/fusion_common/fusion_pass_desc.h deleted file mode 100644 index 62067ab8561de308761a25da2342a05b7d4b82f4..0000000000000000000000000000000000000000 --- a/inc/register/graph_optimizer/fusion_common/fusion_pass_desc.h +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_GRAPH_OPTIMIZER_FUSION_COMMON_FUSION_PASS_DESC_H_ -#define INC_REGISTER_GRAPH_OPTIMIZER_FUSION_COMMON_FUSION_PASS_DESC_H_ -#include -#include "graph/option/optimization_option_info.h" -namespace fe { -using PassAttr = uint64_t; -const PassAttr FORBIDDEN_CLOSE = 0x01UL; // forbidden close, can not be closed by fusion switch -const PassAttr NEED_SORT = 0x02UL; // need topological sorting before executing -const PassAttr SINGLE_SCENE_OPEN = 0x04UL; // open for single op scene, can be close by fusion switch -const PassAttr FE_PASS = 0x08UL; // graph passes and ub passes in air project -constexpr PassAttr ENABLE_AUTO_FUSION = 0x10UL; // whether using auto match fusion frame -constexpr PassAttr ALWAYS_GENERALIZE = 0x20UL; -constexpr PassAttr PRUNING = 0x40UL; // whether the pass need to pre run before graph fusion -constexpr PassAttr ENABLE_FUSION_CHECK = 0x80UL; // enable do fusion check for matched fusion nodes during ub fusion -/* - * Compile level reg, if reg multi level, the lowest level will take effect. - * */ -constexpr PassAttr COMPILE_LEVEL_O0 = 0x100UL; // pure dynamic -constexpr PassAttr COMPILE_LEVEL_O1 = 0x200UL; // static functional optimize -constexpr PassAttr COMPILE_LEVEL_O2 = 0x400UL; // no time and space balance optimize -constexpr PassAttr COMPILE_LEVEL_O3 = 0x800UL; // open all optimize -constexpr PassAttr PASS_BIT_MASK = 0x1UL; // check if the loweset bit of pass is 1 - -enum class PassAttrType : int32_t { - FRBDN_CLOSE = 0, // Mark those passes that cannot be turned off in graph mode - NEED_TOPO_SORT = 1, // Mark those graph fusion passes that need topological sorting before executing - SINGLE_OP_SCENE_MUST_ON = 2, // Mark those passes that must be turned on in single-op mode or jit_compile=false - FE_PASS_FLAG = 3, // Mark those passes that belong to FE - AUTO_FUSION_FLAG = 4, // Using auto match fusion frame - /* The OpDescs in the patterns of this kind fusion pass are able to be generalized in all scenarios. - * For example, they can ignore the value dependency restrict. */ - ALWAYS_GENERALIZE_FLAG = 5, - PRUNING_FLAG = 6, - FUSION_CHECK_FLAG = 7, // Do fusion check for matched fusion nodes during ub fusion - COMPILE_O0 = 8, - COMPILE_O1 = 9, - COMPILE_O2 = 10, - COMPILE_O3 = 11, -}; -bool IsPassAttrTypeOn(PassAttr pass_attr, PassAttrType attr_type); -void RegPassCompileLevel(const std::string &pass_name, PassAttr pass_attr); -} // namespace fe -#endif // INC_REGISTER_GRAPH_OPTIMIZER_FUSION_COMMON_FUSION_PASS_DESC_H_ diff --git a/inc/register/graph_optimizer/fusion_common/fusion_statistic_recorder.h b/inc/register/graph_optimizer/fusion_common/fusion_statistic_recorder.h deleted file mode 100644 index 76759668df468d87a600072b579e475ca2fcab35..0000000000000000000000000000000000000000 --- a/inc/register/graph_optimizer/fusion_common/fusion_statistic_recorder.h +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_GRAPH_OPTIMIZER_FUSION_STATISTIC_RECORDER_H -#define INC_REGISTER_GRAPH_OPTIMIZER_FUSION_STATISTIC_RECORDER_H - -#include -#include -#include -#include -#include -#include - -namespace fe { - -class FusionInfo { - public: - explicit FusionInfo(const uint64_t session_id = 0, const std::string graph_id = "", const std::string pass_name = "", - const int32_t match_times = 0, const int32_t effect_times = 0, const int32_t repo_hit_times = 0); - - virtual ~FusionInfo(); - - void AddMatchTimes(const int32_t match_times); - - void AddEffectTimes(const int32_t effect_times); - - int32_t GetMatchTimes() const; - - void SetMatchTimes(const int32_t match_times); - - int32_t GetEffectTimes() const; - - void SetEffectTimes(const int32_t effect_times); - - int32_t GetRepoHitTimes() const; - - void SetRepoHitTimes(const int32_t repo_hit_times); - - std::string GetGraphId() const; - - std::string GetPassName() const; - - uint64_t GetSessionId() const; - - private: - uint64_t session_id_; - std::string graph_id_; - std::string pass_name_; - int32_t match_times_; - int32_t effect_times_; - int32_t repo_hit_times_; -}; - -using FusionStatisticMap = std::map>; - -class FusionStatisticRecorder { - public: - FusionStatisticRecorder(const FusionStatisticRecorder &) = delete; - - FusionStatisticRecorder &operator=(const FusionStatisticRecorder &) = delete; - - static FusionStatisticRecorder &Instance(); - - void UpdateGraphFusionMatchTimes(const FusionInfo &fusion_info); - - void UpdateGraphFusionEffectTimes(const FusionInfo &fusion_info); - - void UpdateBufferFusionMatchTimes(const FusionInfo &fusion_info); - - void UpdateBufferFusionEffectTimes(const FusionInfo &fusion_info); - - void GetAndClearFusionInfo(const std::string &session_graph_id, - std::map &graph_fusion_info_map, - std::map &buffer_fusion_info_map); - - void GetFusionInfo(const std::string &session_graph_id, std::map &graph_fusion_info_map, - std::map &buffer_fusion_info_map); - - void GetAllSessionAndGraphIdList(std::vector &session_graph_id_vec); - - private: - FusionStatisticRecorder(); - virtual ~FusionStatisticRecorder(); - FusionStatisticMap graph_fusion_info_map_; - FusionStatisticMap buffer_fusion_info_map_; - void ClearFusionInfo(const std::string& session_graph_id); - - std::recursive_mutex mutex_; -}; -} // namespace fe - -#endif // INC_REGISTER_GRAPH_OPTIMIZER_FUSION_STATISTIC_RECORDER_H diff --git a/inc/register/graph_optimizer/fusion_common/fusion_turbo.h b/inc/register/graph_optimizer/fusion_common/fusion_turbo.h deleted file mode 100644 index 2b59aabc252e802f1107eff20ed3fde0d29267b5..0000000000000000000000000000000000000000 --- a/inc/register/graph_optimizer/fusion_common/fusion_turbo.h +++ /dev/null @@ -1,249 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_GRAPH_OPTIMIZER_FUSION_COMMON_FUSION_TURBO_H -#define INC_REGISTER_GRAPH_OPTIMIZER_FUSION_COMMON_FUSION_TURBO_H -#include -#include -#include "graph/anchor.h" -#include "graph/compute_graph.h" -#include "graph/graph.h" -#include "graph/model.h" -#include "graph/node.h" -#include "graph/utils/anchor_utils.h" -#include "register/graph_optimizer/graph_optimize_register_error_codes.h" -#include "register/graph_optimizer/fusion_common/fusion_turbo_utils.h" - -namespace fe { -enum TensorUptType { - UPDATE_NONE = 0, - UPDATE_THIS = 1, - UPDATE_PEER, -}; - -struct WeightInfo { - ge::GeShape shape; - ge::GeShape ori_shape; - ge::DataType datatype; - ge::DataType ori_datatype; - ge::Format format; - ge::Format ori_format; - uint8_t *data; - int64_t shape_size; - size_t total_data_size; // data_size * sizeof(datatype). !!!Could be zero!!! - inline void CalcTotalDataSize() { - if (shape.GetDimNum() == 0) { - shape_size = 1; - } else { - shape_size = shape.GetShapeSize(); - } - - if ((shape_size > 0) && (datatype < data_type_size.size())) { - total_data_size = (static_cast(shape_size)) * data_type_size[datatype]; - } else { - total_data_size = 0; - } - } - - WeightInfo(const ge::GeTensorDesc &tensor_desc, - void *data_p); - - WeightInfo(const ge::NodePtr &node, const int32_t &index, - void *data_p); - - WeightInfo(const ge::GeShape &shape_p, const ge::GeShape &ori_shape_p, - const ge::DataType &datatype_p, const ge::DataType &ori_datatype_p, - const ge::Format &format_p, const ge::Format &ori_format_p, void *data_p); - - WeightInfo(ge::GeShape &&shape_p, ge::GeShape &&ori_shape_p, - const ge::DataType &datatype_p, const ge::DataType &ori_datatype_p, - const ge::Format &format_p, const ge::Format &ori_format_p, void *data_p); - - WeightInfo(const ge::GeShape &shape_p, const ge::DataType &datatype_p, - const ge::Format &format_p, void *data_p); - - WeightInfo(ge::GeShape &&shape_p, const ge::DataType &datatype_p, - const ge::Format &format_p, void *data_p); -}; - -class FusionTurbo { - public: - explicit FusionTurbo(const ge::ComputeGraphPtr &graph); - - explicit FusionTurbo(ge::ComputeGraph &graph); - - ~FusionTurbo(); - - static Status BreakInput(const ge::NodePtr &node, - const vector &input_index); - - static Status BreakOutput(const ge::NodePtr &node, - const vector &output_index); - - static Status BreakAllInput(const ge::NodePtr &node); - - static Status BreakAllOutput(const ge::NodePtr &node); - - Status RemoveNodeWithRelink(const ge::NodePtr &node, const std::initializer_list &io_map = {}); - - Status RemoveNodeWithRelink(const ge::NodePtr &node, const std::vector &io_map = {}); - - Status RemoveNodeOnly(const ge::NodePtr &node); - - /* If the node has no subsequent nodes, remove it. - * If the node has subsequent nodes, just return. - * Parameter include_control_nodes: - * If only_care_data_nodes = true, then we will ignore the control outputs. */ - Status RemoveDanglingNode(const ge::NodePtr &node, const bool &only_care_data_nodes = false); - - Status RemoveMultiNodesOnly(const std::vector &nodes); - - ge::NodePtr UpdateConst(const ge::NodePtr &node, const int32_t &index, const WeightInfo &w_info) const; - - /* 1. If index is larger than or equalt to the input size of node, add a weight - * tensor and node as the last input of node. - * 2. If index is less than the input size of node and: - * 2.1 If the peer node of this input index is nullptr, we add a const node - * as input and update tensor desc. ---> Call AddConstNode. - * 2.2 If the peer node of this input index is Const, we substitute the data - * of current Const and update tensor desc. ---> Call UpdateConst - * 2.3 If the peer node of this input is other type, we just skip it. */ - ge::NodePtr AddWeight(const ge::NodePtr &node, const int32_t &index, const WeightInfo &w_info) const; - - /* Add weight after one output of node. For example: - * NodeA----> NodeB - * \----> NodeC - * After calling AddWeightAfter(NodeA, 0, w_info), the graph will be like: - * NewWeight----> NodeB - * \----> NodeC - * NodeA(will be dangling) - * The rule is adding weight in front of every peer out node of NodeA. - */ - ge::NodePtr AddWeightAfter(const ge::NodePtr &node, const int32_t &index, const WeightInfo &w_info) const; - - ge::NodePtr AddWeight(const ge::NodePtr &node, const string& tensor_name, const WeightInfo &w_info) const; - - /* Add a weight tensor and node as the last input of node. */ - ge::NodePtr AddWeight(const ge::NodePtr &node, const WeightInfo &w_info) const; - - std::vector AddWeights(const ge::NodePtr &node, - const vector &w_infos) const; - - static ge::GeTensorPtr MutableWeight(const ge::NodePtr &node, int32_t index); - - ge::NodePtr AddNodeOnly(const string &op_name, const string &op_type) const; - - static ge::NodePtr AddNodeOnly(ge::ComputeGraph &graph, const string &op_name, const string &op_type); - - ge::NodePtr AddNodeOnly(const string &op_name, const string &op_type, - size_t dynamic_num) const; - - static ge::NodePtr AddNodeOnly(ge::ComputeGraph &graph, const string &op_name, const string &op_type, - size_t dynamic_num); - - ge::NodePtr InsertNodeOnly(const string &op_name, const string &op_type, - const ge::NodePtr &origin_node, - const size_t dynamic_num = 0UL) const; - - static ge::NodePtr InsertNodeOnly(ge::ComputeGraph &graph, const string &op_name, const string &op_type, - const ge::NodePtr &origin_node, - const size_t dynamic_num = 0UL); - - static ge::OpDescPtr CreateOpDesc(const string &op_name, - const string &op_type, const size_t dynamic_num); - - static Status TransferOutCtrlEdges(const std::vector &nodes, - const ge::NodePtr &new_node); - - static Status TransferInCtrlEdges(const std::vector &nodes, - const ge::NodePtr &new_node); - - ge::NodePtr InsertNodeBefore(const string &op_name, const string &op_type, - const ge::NodePtr &base_node, const int32_t &base_input_index, - const int32_t &input_index = 0, - const int32_t &output_index = 0) const; - - ge::NodePtr InsertNodeAfter(const string &op_name, const string &op_type, - const ge::NodePtr &base_node, const int32_t &base_output_index, - const int32_t &input_index = 0, const int32_t &output_index = 0) const; - - static Status LinkInput(Relations &input_relations, - const ge::NodePtr &dst_node, - const TensorUptType &update_tensor = UPDATE_THIS); - - static Status LinkOutput(Relations &output_relations, - const ge::NodePtr &src_node, - const TensorUptType &update_tensor = UPDATE_THIS); - - static ge::NodePtr GetPeerOutNode(const ge::NodePtr &node, const int32_t &this_node_input_index); - - static std::vector GetPeerInNodes(const ge::NodePtr &node, const int32_t &this_node_output_index); - - /* Check whether there is a path from [node1's] output [index1] to [node2]. - * The default value is -1 and -1 means any output is ok. */ - static bool CheckConnected(const ge::NodePtr &node1, const ge::NodePtr &node2, - const int32_t &index1 = -1); - - /* Default update input 0 of node. */ - Status UpdateInputByPeer(const ge::NodePtr &node, const int32_t &index, - const ge::NodePtr &peer_node, const int32_t &peer_index) const; - - Status UpdateOutputByPeer(const ge::NodePtr &node, const int32_t &index, - const ge::NodePtr &peer_node, const int32_t &peer_index) const; - - static bool IsUnknownShape(const ge::NodePtr &node, const int32_t &index, const bool &is_input = true); - - static bool IsUnknownOriShape(const ge::NodePtr &node, const int32_t &index, const bool &is_input = true); - - ge::NodePtr MultiInOne(const string &node_name, const string &node_type, - Relations &input_relations, - Relations &output_relations, - const std::vector &old_nodes = {}, - const bool &remove_old = true); - - Status MultiInOne(const ge::NodePtr &new_node, - Relations &input_relations, - Relations &output_relations, - const std::vector &old_nodes = {}, - const bool &remove_old = true); - - static bool HasControl(const ge::NodePtr &node); - - static bool HasInControl(const ge::NodePtr &node); - - static bool HasOutControl(const ge::NodePtr &node); - - static bool HasOutData(const ge::NodePtr &node); - - static Status MoveDataOutputUp(const ge::NodePtr &node, int32_t index); - - /* move node to pre node if pre node has subgraph - * @param node current need move node - * @param index node move input index - **/ - Status GraphNodeUpMigration(const ge::NodePtr &node, const int32_t index); - - /* move node to next node if next node has subgraph - * @param node current need move node - * @param index node move output index - **/ - Status GraphNodeDownMigration(const ge::NodePtr &node, const int32_t index); - - static NodeIndex GetPeerInFirstPair(const ge::NodePtr &node, int32_t index); - - static NodeIndex GetPeerOutPair(const ge::NodePtr &node, int32_t index); - private: - /* AddWeight will do either AddConstNode or UpdateConst. */ - ge::NodePtr AddConstNode(const ge::NodePtr &node, const WeightInfo &w_info, - const int32_t index) const; - - ge::ComputeGraphPtr graph_; -}; -} -#endif // INC_REGISTER_GRAPH_OPTIMIZER_FUSION_COMMON_FUSION_TURBO_H diff --git a/inc/register/graph_optimizer/fusion_common/fusion_turbo_utils.h b/inc/register/graph_optimizer/fusion_common/fusion_turbo_utils.h deleted file mode 100644 index 8c81b202023897c9437bb8e202714fdaf2b1dcc6..0000000000000000000000000000000000000000 --- a/inc/register/graph_optimizer/fusion_common/fusion_turbo_utils.h +++ /dev/null @@ -1,147 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_GRAPH_OPTIMIZER_FUSION_COMMON_FUSION_TURBO_UTILS_H -#define INC_REGISTER_GRAPH_OPTIMIZER_FUSION_COMMON_FUSION_TURBO_UTILS_H -#include -#include "graph/utils/op_desc_utils.h" -#include "graph/debug/ge_log.h" - -#define FUSION_TURBO_NOTNULL(val, ret) \ - do { \ - if ((val) == nullptr) { \ - GELOGD("Parameter[%s] must not be null.", #val); \ - return ret; \ - } \ - } while (0) - -namespace fe { -enum Direction { - CURRENT = 0, /* 表示NodeIndex指示的是当前节点的对应输入输出。 */ - /* 当连接输入的场景,PEER模式下会获取的对端输出节点和对端index。 */ - /* 当连接输出的场景,PEER模式下会获取的所有对端输入节点和所有对端index。 */ - PEER = 1, - /* 当连接输入的场景,PEER_SINGLE模式下会获取的对端输出节点和对端index。和PEER一致。 */ - /* 当连接输出的场景,PEER_SINGLE模式下会获取的第一个对端输出节点和对端index。 */ - PEER_SINGLE = 2 -}; - -struct NodeIndex { - ge::NodePtr node; - int32_t index; - Direction direction = CURRENT; - NodeIndex() { - node = nullptr; - index = -1; - } - NodeIndex(const ge::NodePtr &node_param, int32_t index_param) { - node = node_param; - index = index_param; - } - - NodeIndex(const std::pair &node_index_pair) { - node = node_index_pair.first; - index = node_index_pair.second; - } - - NodeIndex(const ge::NodePtr &node_param, int32_t index_param, Direction direction_param) { - node = node_param; - index = index_param; - direction = direction_param; - } -}; -using NodeIndices = std::vector; - -using ThisIndex = int32_t; - -class Relations { - public: - Relations(); - - Relations(const std::initializer_list &peer_indices); - - explicit Relations(const std::map &relations_param); - - explicit Relations(std::map &&relations_param); - - Relations(const Relations &relations_param); - - Relations(Relations &&relations_param) noexcept; - - Relations(ThisIndex this_index, const NodeIndex &peer_index); - - Relations(ThisIndex this_index, const NodeIndices &peer_indices); - - Relations(ThisIndex this_index, NodeIndex &&peer_index); - - Relations(ThisIndex this_index, NodeIndices &&peer_indices); - - Relations(const std::initializer_list> &peer_indices); - - Relations(const std::initializer_list>> &peer_indices_vec); - - /****** Interface Add from here. ******/ - Relations& Add(ThisIndex this_index, const NodeIndex &peer_index); - - Relations& Add(ThisIndex this_index, const std::initializer_list &peer_indices); - - Relations& Add(ThisIndex this_index, const NodeIndices &peer_indices); - - Relations& Add(ThisIndex this_index, NodeIndex &&peer_index); - - Relations& Add(ThisIndex this_index, NodeIndices &&peer_indices); - - /* 由于NodeIndex当连接输入或输出是完全不一样的,我们需要根据原始relations计算作为 - * 输入和输出的真正的对端节点,所以要求必须通过接口来修改relations。 */ - Relations& UpdatePeerIndex(ThisIndex this_index, const NodeIndices &peer_indices); - - Relations& UpdatePeerIndex(ThisIndex this_index, NodeIndices &&peer_indices); - - Relations& UpdatePeerIndex(const std::map &peer_indices); - - Relations& UpdatePeerIndex(std::map &&peer_indices); - - const std::map& GetRelations(); - - const std::map& GetInRelations(); - - const std::map& GetOutRelations(); - - Relations& operator=(const Relations &relations_param); - - Relations& operator=(Relations &&relations_param) noexcept; - private: - NodeIndex GetPeerInFirstPair(ThisIndex relation_index, const ge::NodePtr &node, int32_t index); - - void AppendPeerInAllPairs(ThisIndex relation_index, const ge::NodePtr &node, int32_t index); - - void PreProcessOneNodeIndex(ThisIndex index, const NodeIndex &node_index); - void PreProcessNodeIndices(ThisIndex index, const NodeIndices &node_indices); - - void PreProcess(); - /* 我们在添加ori_relations的时候就把两个方向的节点都计算好。 */ - std::map in_relations; - - std::map out_relations; - - /* 如果key是输出的index,那么vector里存放的就是对端输入的index;在单输出多引用场景, - * 对端输入的index可能有多个。 - * 如果key是输入的index,那么vector里存放的就是对端输出的index。对端输出只会有一个。 */ - std::map ori_relations; -}; - -extern const std::array(ge::DT_MAX + 1)> data_type_size; -class FusionTurboUtils { - public: - static NodeIndex GetPeerInFirstPair(const ge::NodePtr &node, int32_t index); - static NodeIndex GetPeerOutPair(const ge::NodePtr &node, int32_t index); - static ge::NodePtr GetConstInput(const ge::NodePtr &node, int32_t index); -}; -} -#endif // INC_REGISTER_GRAPH_OPTIMIZER_FUSION_COMMON_FUSION_TURBO_UTILS_H diff --git a/inc/register/graph_optimizer/fusion_common/graph_pass_util.h b/inc/register/graph_optimizer/fusion_common/graph_pass_util.h deleted file mode 100644 index adf27735da263084257047762a7f15fc945cd2ef..0000000000000000000000000000000000000000 --- a/inc/register/graph_optimizer/fusion_common/graph_pass_util.h +++ /dev/null @@ -1,167 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_PASS_UTIL_H_ -#define INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_PASS_UTIL_H_ -#include "graph/compute_graph.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/attr_utils.h" -#include "graph/utils/node_utils.h" -#include "graph/utils/type_utils.h" -#include "register/graph_optimizer/graph_optimize_register_error_codes.h" -#include "external/graph/types.h" - -#include -#include -#include -#include -#include - -namespace fe { -enum class BackWardInheritMode { - kInsertNode = 0, - kFusedNode = 1, - kInheritTrue = 2, - kDoNotInherit = 3 -}; - -using NodeTypeMap = std::unordered_map>; -using NodeTypeMapPtr = std::shared_ptr; -struct NodeMapInfo { - int64_t run_count; - NodeTypeMapPtr node_type_map; -}; -using NodeMapInfoPtr = std::shared_ptr; -/** @brief define graph pass, which provides two interface: 1. run pass; -* 2. record op names before fusion */ -class GraphPassUtil { - public: - using OriginOpAttrsVec = std::vector>; - using UnorderedMapping = std::unordered_map; - /** set outputdesc attr for data dump - * - * @param origin_index,usually is origin node output index - * - * @param fusion_index,usually is fusion node output index - * - * @param origin_node, usually is origin node - * - * @param fusion_node, usually is fusion node - */ - static void SetOutputDescAttr(const uint32_t &origin_index, const uint32_t &fusion_index, - const ge::NodePtr &origin_node, const ge::NodePtr &fusion_node); - - static void SetOutputDescAttr(ge::ConstGeTensorDescPtr &origin_tensor_desc, const int64_t origin_index, - const ge::OpDescPtr &origin_op_desc, const ge::GeTensorDescPtr &target_tensor_desc); - - /** get origin format for data dump - * - * @param tensor_desc,usually is output_desc - * - * @return format of this tensor_desc - */ - static ge::Format GetDataDumpOriginFormat(const ge::GeTensorDescPtr &tensor_desc); - - static ge::Format GetDataDumpOriginFormat(ge::ConstGeTensorDescPtr &tensor_desc); - - /** set origin format for data dump - * - * @param origin format - * - * @param tensor_desc,usually is output_desc - */ - static void SetDataDumpOriginFormat(const ge::Format &origin_format, const ge::GeTensorDescPtr &tensor_desc); - - /** set origin datatype for data dump - * - * @param origin datatype - * - * @param tensor_desc,usually is output_desc - */ - static void SetDataDumpOriginDataType(const ge::DataType origin_data_type, const ge::GeTensorDescPtr &tensor_desc); - - /** get origin datatype for data dump - * - * @param tensor_desc,usually is output_desc - * - * @return format of this tensor_desc - */ - static ge::DataType GetDataDumpOriginDataType(const ge::GeTensorDescPtr &tensor_desc); - - static ge::DataType GetDataDumpOriginDataType(ge::ConstGeTensorDescPtr &tensor_desc); - - static void AddNodeFromOpTypeMap(const NodeMapInfoPtr &node_map_info, const ge::NodePtr &node_ptr); - - static Status GetOpTypeMapToGraph(NodeMapInfoPtr &node_map_info, const ge::ComputeGraph &graph); - - static void RecordPassnameAndOriginalAttrs(const std::vector &original_nodes, - std::vector &fus_nodes, const string &pass_name, - const OriginOpAttrsVec &origin_op_attrs = OriginOpAttrsVec()); - - static Status StoreAndUpdataOriginFusionPassName(const ge::OpDescPtr &op_desc, - const std::vector &original_nodes, - const std::string &pass_name); - - static void GetBackWardAttr(const std::vector &original_nodes, - bool &backward, BackWardInheritMode inherit_mode); - - static void InheritGraphRelatedAttr(const std::vector &original_nodes, - const std::vector &fusion_nodes, - BackWardInheritMode inherit_mode); - - /* If one of the original node has attribute like keep_dtype, the fused node - * will inherit that attribute. - * param inherit_mode: if fusion_nodes are newly inserted after original_nodes, - * backward attr will only care about its farther nodes(pass farther nodes in - * param original_nodes). - * And if fusion_nodes are fused by a bunch of original_nodes, the backward attr - * will not only care about original_nodes but also the input nodes of original_nodes. */ - static void InheritAttrFromOriNodes(const std::vector &original_nodes, - const std::vector &fusion_nodes, - BackWardInheritMode inherit_mode = BackWardInheritMode::kFusedNode); - - static void RecordOriginalOpAttrs(const std::vector &original_nodes, - const ge::OpDescPtr &op_desc, const string &pass_name, - const OriginOpAttrsVec &origin_op_attrs = OriginOpAttrsVec()); - - static void RecordOriginalNames(const std::vector &original_nodes, const ge::NodePtr &node); - - static void AddNodeToNodeTypeMap(const NodeTypeMapPtr &node_type_map, const std::string &op_type, - const ge::NodePtr &node_ptr); - - static void RemoveNodeFromNodeTypeMap(NodeTypeMapPtr &node_type_map, const std::string &op_type, - const ge::NodePtr &node_ptr); - - static void GetNodesFromNodeTypeMap(NodeTypeMapPtr &node_type_map, const std::string &op_type, - std::vector &nodes); - - static void GetOpCustomImplModeFromOriNode(const std::vector &original_nodes, - std::set &op_impl_mode_priority_set, - std::map &origin_node_impl_mode_map); - - static void SetOpCustomImplModeToFusNode(const ge::OpDescPtr &fusion_op, - const std::map &origin_node_impl_mode_map, - const std::set &op_impl_mode_priority_set); - - static void GetOpCustomGroupIdFromOriginNodes(const std::vector &original_nodes, - uint32_t ¶llel_group_id); - - static void SetOpCustomGroupIdToFusNode(const ge::OpDescPtr &fusion_op, const uint32_t ¶llel_group_id); - - static ge::OutDataAnchorPtr GetPeerOutAnchorNotInDeleteList(const ge::NodePtr &node, size_t idx); - - static void SetPairTensorIntAttr(const ge::NodePtr &node, size_t idx, const std::map &attr_val); - - static void SetPairTensorAttr(const ge::NodePtr &node, size_t idx, const std::map &attr_val, - bool is_input = true); -}; - -} // namespace fe - -#endif // INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_PASS_UTIL_H_ diff --git a/inc/register/graph_optimizer/fusion_common/op_slice_info.h b/inc/register/graph_optimizer/fusion_common/op_slice_info.h deleted file mode 100644 index ce7cefbf85d0a92b1fbbf4d66f1a1c59ec691af3..0000000000000000000000000000000000000000 --- a/inc/register/graph_optimizer/fusion_common/op_slice_info.h +++ /dev/null @@ -1,191 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_COMMON_UTILS_AI_CORE_OP_SLICE_INFO_H -#define INC_COMMON_UTILS_AI_CORE_OP_SLICE_INFO_H - -#include -#include -#include - -namespace fe { -enum OpReduceType { REDUCE_MEAN = 0, REDUCE_ADD, REDUCE_MAX, REDUCE_MIN }; -enum OpL1FusionType { L1FUSION_DISABLE = 0, L1FUSION_BASIC, L1FUSION_INPUT_CTR }; - -class InputSplitInfoImpl; -using InputSplitInfoImplPtr = std::shared_ptr; -class InputSplitInfo; -using InputSplitInfoPtr = std::shared_ptr; -class OutputSplitInfoImpl; -using OutputSplitInfoImplPtr = std::shared_ptr; -class OutputSplitInfo; -using OutputSplitInfoPtr = std::shared_ptr; -class InputReduceInfoImpl; -using InputReduceInfoImplPtr = std::shared_ptr; -class InputReduceInfo; -using InputReduceInfoPtr = std::shared_ptr; -class OutputReduceInfoImpl; -using OutputReduceInfoImplPtr = std::shared_ptr; -class OutputReduceInfo; -using OutputReduceInfoPtr = std::shared_ptr; -class AxisSplitMapImpl; -using AxisSplitMapImplPtr = std::shared_ptr; -class AxisSplitMap; -using AxisSplitMapPtr = std::shared_ptr; -class AxisReduceMapImpl; -using AxisReduceMapImplPtr = std::shared_ptr; -class AxisReduceMap; -using AxisReduceMapPtr = std::shared_ptr; -class OpCalcInfoImpl; -using OpCalcInfoImplPtr = std::shared_ptr; -class OpCalcInfo; -using OpCalcInfoPtr = std::shared_ptr; - -class InputSplitInfo { - public: - InputSplitInfo(); - InputSplitInfo(const InputSplitInfo &input_split_info); - InputSplitInfo &operator = (const InputSplitInfo &input_split_info); - ~InputSplitInfo(); - bool Initialize(); - size_t GetIndex() const; - std::vector GetAxis() const; - std::vector GetHeadOverLap() const; - std::vector GetTailOverLap() const; - void SetIndex(const size_t& idx); - void SetAxis(std::vector& axis); - void SetHeadOverLap(std::vector& head_over_lap); - void SetTailOverLap(std::vector& tail_over_lap); - bool IsPtrNull() const; - private: - InputSplitInfoImplPtr split_impl_{nullptr}; -}; - -class OutputSplitInfo { - public: - OutputSplitInfo(); - OutputSplitInfo(const OutputSplitInfo &output_split_info); - OutputSplitInfo &operator = (const OutputSplitInfo &output_split_info); - ~OutputSplitInfo(); - bool Initialize(); - size_t GetIndex() const; - std::vector GetAxis() const; - void SetIndex(const size_t& idx); - void SetAxis(std::vector& axis); - bool IsPtrNull() const; - private: - OutputSplitInfoImplPtr split_impl_{nullptr}; -}; - -class InputReduceInfo { - public: - InputReduceInfo(); - InputReduceInfo(const InputReduceInfo &input_reduce_info); - InputReduceInfo &operator = (const InputReduceInfo &input_reduce_info); - ~InputReduceInfo(); - bool Initialize(); - size_t GetIndex() const; - std::vector GetAxis() const; - void SetIndex(const size_t& idx); - void SetAxis(std::vector& axis); - bool IsPtrNull() const; - private: - InputReduceInfoImplPtr reduce_impl_{nullptr}; -}; - -class OutputReduceInfo { - public: - OutputReduceInfo(); - OutputReduceInfo(const OutputReduceInfo &output_reduce_info); - OutputReduceInfo &operator = (const OutputReduceInfo &output_reduce_info); - ~OutputReduceInfo(); - bool Initialize(); - size_t GetIndex() const; - OpReduceType GetReduceType() const; - bool GetIsAtomic() const; - void SetIndex(const size_t& idx); - void SetReduceType(const OpReduceType& reduce_type); - void SetIsAtomic(const bool& is_atomic); - bool IsPtrNull() const; - private: - OutputReduceInfoImplPtr reduce_impl_{nullptr}; -}; - -class AxisSplitMap { - public: - friend class AxisSplitMapImpl; - AxisSplitMap(); - AxisSplitMap(const AxisSplitMap &axis_split_map); - AxisSplitMap &operator = (const AxisSplitMap &axis_split_map); - ~AxisSplitMap(); - bool Initialize(); - std::vector GetInputSplitInfos() const; - std::vector GetOutputSplitInfos() const; - std::vector GetInputSplitInfoVec() const; - std::vector GetOutputSplitInfoVec() const; - void AddInputSplitInfo(InputSplitInfo& input_split_info); - void SetInputSplitInfos(std::vector& input_split_vec); - void SetInputSplitInfos(std::vector& input_split_vec); - void AddOutputSplitInfo(OutputSplitInfo& output_split_info); - void SetOutputSplitInfos(std::vector& output_split_vec); - void SetOutputSplitInfos(std::vector& output_split_vec); - bool IsPtrNull() const; - private: - AxisSplitMapImplPtr aixs_split_impl_{nullptr}; -}; - -class AxisReduceMap { - public: - AxisReduceMap(); - AxisReduceMap(const AxisReduceMap &axis_reduce_map); - AxisReduceMap &operator = (const AxisReduceMap &axis_reduce_map); - ~AxisReduceMap(); - bool Initialize(); - friend class AxisReduceMapImpl; - std::vector GetInputReduceInfos() const; - std::vector GetOutputReduceInfos() const; - std::vector GetInputReduceInfoVec() const; - std::vector GetOutputReduceInfoVec() const; - void AddInputReduceInfo(InputReduceInfo& input_reduce_info); - void SetInputReduceInfos(std::vector& input_reduce_vec); - void SetInputReduceInfos(std::vector& input_reduce_vec); - void AddOutputReduceInfo(OutputReduceInfo& output_reduce_info); - void SetOutputReduceInfos(std::vector& output_reduce_vec); - void SetOutputReduceInfos(std::vector& output_reduce_vec); - bool IsPtrNull() const; - private: - AxisReduceMapImplPtr aixs_reduce_impl_{nullptr}; -}; - -class OpCalcInfo { - public: - OpCalcInfo(); - ~OpCalcInfo(); - bool Initialize(); - std::vector GetAxisSplitMaps() const; - std::vector GetAxisReduceMaps() const; - std::vector GetAxisSplitMapVec() const; - std::vector GetAxisReduceMapVec() const; - OpL1FusionType GetL1FusionEnable() const; - int64_t GetMinTbeL1Space() const; - void AddAxisSplitMap(AxisSplitMap& axis_split_map); - void SetAxisSplitMaps(std::vector& axis_split_vec); - void SetAxisSplitMaps(std::vector& axis_split_vec); - void AddAxisReduceMap(AxisReduceMap& axis_reduce_map); - void SetAxisReduceMaps(std::vector& axis_reduce_vec); - void SetAxisReduceMaps(std::vector& axis_reduce_vec); - void SetL1FusionEnable(const OpL1FusionType& l1_fusion_enable); - void SetMinTbeL1Space(const int64_t& min_tbe_l1_space); - void DelAxisSplitMapBaseAxis(std::vector& axis); - bool IsPtrNull() const; - private: - OpCalcInfoImplPtr op_calc_info_impl_{nullptr}; -}; -} // namespace fe -#endif // INC_COMMON_UTILS_AI_CORE_OP_SLICE_INFO_H diff --git a/inc/register/graph_optimizer/fusion_common/pattern_fusion_base_pass.h b/inc/register/graph_optimizer/fusion_common/pattern_fusion_base_pass.h deleted file mode 100644 index af2432bf14e87463537facd4da0d0150d29701f6..0000000000000000000000000000000000000000 --- a/inc/register/graph_optimizer/fusion_common/pattern_fusion_base_pass.h +++ /dev/null @@ -1,197 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_GRAPH_OPTIMIZER_PATTERN_FUSION_BASE_PASS_H_ -#define INC_REGISTER_GRAPH_OPTIMIZER_PATTERN_FUSION_BASE_PASS_H_ - -#include -#include -#include -#include -#include -#include - -#include "common/opskernel/ops_kernel_info_store.h" -#include "register/graph_optimizer/graph_fusion/fusion_pattern.h" -#include "register/graph_optimizer/graph_fusion/graph_pass.h" -#include "register/graph_optimizer/graph_fusion/connection_matrix.h" - -namespace fe { -using std::initializer_list; -using std::map; -using std::string; -using std::vector; -using namespace std; - -using OpsKernelInfoStorePtr = std::shared_ptr; -class PatternFusionBasePassImpl; -using PatternFusionBasePassImplPtr = std::shared_ptr; - -/** Pass based on pattern - * @ingroup FUSION_PASS_GROUP - * @note New virtual methods should be append at the end of this class - */ -class PatternFusionBasePass : public GraphPass { - public: - using OpDesc = FusionPattern::OpDesc; - using Mapping = std::map, std::vector, CmpKey>; - using Mappings = std::vector; - - PatternFusionBasePass(); - ~PatternFusionBasePass() override; - - - /** execute pass - * - * @param [in] graph, the graph waiting for pass level optimization - * @return SUCCESS, successfully optimized the graph by the pass - * @return NOT_CHANGED, the graph did not change - * @return FAILED, fail to modify graph - */ - virtual Status Run(ge::ComputeGraph &graph) override; - - /** execute pass - * - * @param [in] graph, the graph waiting for pass level optimization - * @param [ops_kernel_info_store_ptr, OP info kernel instance - * @return SUCCESS, successfully optimized the graph by the pass - * @return NOT_CHANGED, the graph did not change - * @return FAILED, fail to modify graph - */ - virtual Status Run(ge::ComputeGraph &graph, OpsKernelInfoStorePtr ops_kernel_info_store_ptr); - - /* Detect whether there are cycles in graph - * after fusing all nodes in param fusion_nodes. - * - * Compared with Cycle Detection - * @param fusion_nodes: each vector in fusion_nodes - * will be fused into an entity(which could contains - * more than one node). The caller should put all original - * nodes which are expected to be fused into one larger node - * into each sub-vector of fusion_nodes. - * - * This function can tell whether there are a cycle after - * fusing all nodes in fusion_nodes. Each vector in 2-d - * vector fusion_nodes will be fused into an entity. - * - * - * This interface cannot detect whether there are cycles - * inside the fused nodes. - * - * e.g. {a, b, c, d} -> {e, f} - * Because the edge information is not given for e and f - * so this function we cannot tell if e and f are in a - * cycle. - * */ - bool CycleDetection(const ge::ComputeGraph &graph, const std::vector> &fusion_nodes); - - bool CycleDetection(const ge::ComputeGraph &graph, const std::vector &fusion_nodes); - - void GetConnectionMatrix(std::unique_ptr &connection_matrix); - - void SetConnectionMatrix(std::unique_ptr &connection_matrix); - - const std::vector &GetPatterns(); - - const std::vector &GetInnerPatterns(); - - bool MatchFromOutput(const ge::NodePtr &output_node, const std::shared_ptr &output_op_desc, Mapping &mapping); - - protected: - virtual std::vector DefinePatterns() = 0; - - virtual Status Fusion(ge::ComputeGraph &graph, Mapping &mapping, std::vector &new_nodes) = 0; - - virtual std::vector DefineInnerPatterns(); - - virtual void SetDataDumpAttr(const std::vector &fused_nodes, - const std::vector &fusion_nodes); - - virtual void SetOriginalOutputDumpAttr(const std::vector &fused_nodes, - const std::vector &fusion_nodes); - - virtual void SetOriginalOpDumpAttr(const std::vector &fused_nodes, - const std::vector &fusion_nodes); - - void SetActualFusedNodes(const std::vector &fused_nodes); - - std::vector GetNodesFromMapping(const Mapping &mapping) const; - ge::NodePtr GetNodeFromMapping(const std::string &id, const Mapping &mapping) const; - - void RecordOutputAnchorMap(ge::NodePtr output_node); - - void ClearOutputAnchorMap(); - - bool CheckOpSupported(const ge::OpDescPtr &op_desc_ptr) const; - - bool CheckOpSupported(const ge::NodePtr &node) const; - - bool CheckAccuracySupported(const ge::NodePtr &node) const; - - /** check whether the input graph is Cyclic - * - * @param graph need to be checked - * @return false or true - */ - bool CheckGraphCycle(ge::ComputeGraph &graph) const; - - void DumpMapping(const FusionPattern &pattern, const Mapping &mapping) const; - - private: - /** match all nodes in graph according to pattern - * - * @param pattern fusion pattern defined - * @param mappings match result - * @return SUCCESS, successfully add edge - * @return FAILED, fail - */ - bool MatchAll(const ge::ComputeGraph &graph, const FusionPattern &pattern, Mappings &mappings); - - Status RunOnePattern(ge::ComputeGraph &graph, const FusionPattern &pattern, bool &changed); - - /* Check whether there are cycles after fusing scope_nodes as an - * entity. The algorithm is: - * If one of the output node of scope nodes has an edged linked to - * the scope nodes again, there will be a cycle. - * e.g. - * A - * / \ - * B \ - * / \ - * D------->C - * | | - * After fusion A/B/C, the graph looks like: - * <--- - * / \ - * ABC--->D - * There obviously a cycle in the fused graph. - * */ - bool DetectOneScope(const std::vector &scope_nodes) const; - - bool CheckEachPeerOut(const ge::NodePtr &node, - const std::unordered_set &scope_nodes_set, - const std::vector &scope_nodes) const; - - void StoreOriginOpNames(const Mapping &mapping, std::vector &origin_op_names) const; - - /** Internal implement class ptr */ - std::shared_ptr pattern_fusion_base_pass_impl_ptr_; - - std::unordered_map> origin_op_anchors_map_; - - /* For detecting cycles, we will only build connectivity once. - * One time generation of connectivity needs O(n+e) where n is - * total number of nodes and e is total number of edges, which is - * not tolerable. And this requires one pass only executed once. - * */ - std::unique_ptr connectivity_{nullptr}; -}; -} // namespace fe - -#endif // INC_REGISTER_GRAPH_OPTIMIZER_PATTERN_FUSION_BASE_PASS_H_ diff --git a/inc/register/graph_optimizer/fusion_common/unknown_shape_utils.h b/inc/register/graph_optimizer/fusion_common/unknown_shape_utils.h deleted file mode 100644 index 238eb351ee33367b9f1201c0188f71ba2d9c46da..0000000000000000000000000000000000000000 --- a/inc/register/graph_optimizer/fusion_common/unknown_shape_utils.h +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_GRAPH_OPTIMIZER_UNKNOWN_SHAPE_UTILS_H -#define INC_REGISTER_GRAPH_OPTIMIZER_UNKNOWN_SHAPE_UTILS_H -#include "graph/utils/graph_utils.h" -namespace fe { -class UnknownShapeUtils { -public: - /* - * @ingroup fe - * @brief check whether the node is unknown shape. - * @param [in] input or output tensor. - * @return true: unknown; false: known - */ - static bool IsUnknownShapeOp(const ge::OpDesc &op_desc); - - /* - * @ingroup fe - * @brief check whether the input or output shape contains -2. - * @param op_desc input or output desc. - * @return true: contains; false: not contains - */ - static bool IsContainUnknownDimNum(const ge::OpDesc &op_desc); - - /* - * @brief check whether the value is -1 or -2 - * @param input or ourput shape dim - * @return true: contains; false: not contains - */ - static bool IsUnknownShapeValue(const int64_t &value); -private: - static bool IsUnKnownShapeTensor(const ge::OpDesc &op_desc); -}; -} // namespace fe - -#endif diff --git a/inc/register/graph_optimizer/graph_fusion/connection_matrix.h b/inc/register/graph_optimizer/graph_fusion/connection_matrix.h deleted file mode 100644 index d50151a2fbe542f67787d374becf93238b9c6602..0000000000000000000000000000000000000000 --- a/inc/register/graph_optimizer/graph_fusion/connection_matrix.h +++ /dev/null @@ -1,80 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_FUSION_CONNECTION_MATRIX_H_ -#define INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_FUSION_CONNECTION_MATRIX_H_ - -#include "graph/debug/ge_attr_define.h" -#include "graph/node.h" -#include "graph/graph.h" -#include "graph/compute_graph.h" -#include "common/large_bm.h" -#include "register/graph_optimizer/graph_optimize_register_error_codes.h" - -namespace fe { -class ConnectionMatrix { -public: - ConnectionMatrix(); - explicit ConnectionMatrix(bool enable_data_flow); - explicit ConnectionMatrix(const ge::ComputeGraph &graph); - - ~ConnectionMatrix(); - - bool IsConnected(const ge::NodePtr &a, const ge::NodePtr &b) const; - - // inputs are all input nodes of parameter node. - // if there is a path between A->B, then B will own A's - // connectivity. The reason is --- - // If some node can reach A, than it can also reach B. - void SetConnectivity(const ge::Node::Vistor &inputs, const ge::NodePtr &node); - - bool IsDataConnected(const ge::NodePtr &a, const ge::NodePtr &b) const; - - /* Computes the connectivity between two nodes in the - * computation. The returned ConnectivityMatrix is constructed such that - * ConnectivityMatrix::IsConnected(a, b) returns true iff there exists a - * directed path (from producer to consumer) from 'a' to 'b'. Both data - * connection and control connection are considered for connectivity. - * A node is connected to itself. */ - void Generate(const ge::ComputeGraph &graph); - - // update reachablity map for fused nodes. - void Update(const ge::ComputeGraph &graph, const std::vector &fusion_nodes); - - void BackupBitMap(); - - void RestoreBitMap(); - -private: - int64_t GetIndex(const ge::NodePtr &node) const; - - const ge::LargeBitmap &GetBitMap(const ge::NodePtr &node) const; - - ge::LargeBitmap &GetBitMap(const ge::NodePtr &node); - - ge::LargeBitmap &GetBitMap(uint64_t index); - - const ge::LargeBitmap &GetDataBitMap(const ge::NodePtr &node) const; - - ge::LargeBitmap &GetDataBitMap(const ge::NodePtr &node); - - ge::LargeBitmap &GetDataBitMap(uint64_t index); - - void SetDataConnectivity(const ge::Node::Vistor &inputs, const ge::NodePtr &node); - - bool enable_data_flow_; - size_t size_ = 0; - std::vector bit_maps; - std::vector bit_maps_back_up_; - std::vector data_bit_maps_; - std::vector data_bit_maps_back_up_; - std::map name_to_index_; -}; -} -#endif // INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_FUSION_CONNECTION_MATRIX_H_ diff --git a/inc/register/graph_optimizer/graph_fusion/fusion_pass_manager/fusion_pass_registry.h b/inc/register/graph_optimizer/graph_fusion/fusion_pass_manager/fusion_pass_registry.h deleted file mode 100644 index 85bbff0155c7d9afbfb6685283b860fc1bc93d0e..0000000000000000000000000000000000000000 --- a/inc/register/graph_optimizer/graph_fusion/fusion_pass_manager/fusion_pass_registry.h +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_FUSION_FUSION_PASS_MANAGER_FUSION_PASS_REGISTRY_H_ -#define INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_FUSION_FUSION_PASS_MANAGER_FUSION_PASS_REGISTRY_H_ - -#include -#include -#include -#include -#include "register/graph_optimizer/fusion_common/fusion_pass_desc.h" -#include "register/graph_optimizer/graph_fusion/graph_fusion_pass_base.h" - -namespace fe { -class FusionPassRegistry { - public: - using CreateFn = GraphPass *(*)(); - struct PassDesc { - PassAttr attr; - CreateFn create_fn; - }; - ~FusionPassRegistry(); - - static FusionPassRegistry &GetInstance(); - - void RegisterPass(const GraphFusionPassType &pass_type, const std::string &pass_name, CreateFn create_fn, - PassAttr attr) const; - - std::map GetPassDesc(const GraphFusionPassType &pass_type); - - std::map GetCreateFnByType(const GraphFusionPassType &pass_type); - - private: - FusionPassRegistry(); - class FusionPassRegistryImpl; - std::unique_ptr impl_; -}; - -class FusionPassRegistrar { - public: - FusionPassRegistrar(const GraphFusionPassType &pass_type, const std::string &pass_name, - GraphPass *(*create_fn)(), PassAttr attr); - - ~FusionPassRegistrar() {} -}; - -#define REGISTER_PASS(pass_name, pass_type, pass_class) \ - REG_PASS(pass_name, pass_type, pass_class, 0) - -#define REG_PASS(pass_name, pass_type, pass_class, attr) \ - REG_PASS_UNIQ_HELPER(__COUNTER__, pass_name, pass_type, pass_class, attr) - -#define REG_PASS_UNIQ_HELPER(ctr, pass_name, pass_type, pass_class, attr) \ - REG_PASS_UNIQ(ctr, pass_name, pass_type, pass_class, attr) - -#define REG_PASS_UNIQ(ctr, pass_name, pass_type, pass_class, attr) \ - static ::fe::FusionPassRegistrar register_fusion_pass##ctr __attribute__((unused)) = ::fe::FusionPassRegistrar( \ - pass_type, pass_name, []() -> ::fe::GraphPass * { return new (std::nothrow) pass_class(); }, attr) -} // namespace fe -#endif // INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_FUSION_FUSION_PASS_MANAGER_FUSION_PASS_REGISTRY_H_ diff --git a/inc/register/graph_optimizer/graph_fusion/fusion_pattern.h b/inc/register/graph_optimizer/graph_fusion/fusion_pattern.h deleted file mode 100644 index 99a6e11e4d7a57a80eb43cc06b46ac6314f182d4..0000000000000000000000000000000000000000 --- a/inc/register/graph_optimizer/graph_fusion/fusion_pattern.h +++ /dev/null @@ -1,213 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_GRAPH_OPTIMIZER_FUSION_PATTERN_H_ -#define INC_REGISTER_GRAPH_OPTIMIZER_FUSION_PATTERN_H_ -#include -#include -#include -#include -#include - -namespace fe { -extern const uint32_t kFuzzyOutIndex; -/** Fusion pattern - * @ingroup FUSION_PASS_GROUP - * Describe Pattern of Ops waiting for fusion(Op type, etc) - */ -class FusionPattern { - public: - struct OpDesc; - using OpDescPtr = std::shared_ptr; - using OutputMapVecStr = std::initializer_list>>; - using OutputMapStr = std::initializer_list>; - using OutputMapDesc = std::map>; - /** - * @ingroup fe - * @brief description of Ops - */ - struct OpDesc { - std::string id; // Identifier - std::vector types; // the Op types of Ops - std::vector inputs; // all input Ops - OutputMapDesc outputs; // all output Ops - bool repeatable; // flag to show if match multiple Ops or not - bool check_unique; // flag op desc can be matched by only one node - bool is_output; // flag to show if the op is output node - bool is_output_fullmatch; // flag to match all output - size_t output_size; // all output size - bool allow_dumpable; - }; - - public: - explicit FusionPattern(const std::string name = ""); - ~FusionPattern(); - - /** set pattern name - * - * @param name pattern name - * @return FusionPattern - */ - FusionPattern &SetName(const std::string &name); - - /** add Op description with unknown number of args - * - * @param id pattern id - * @param types op type list - * @return FusionPattern - */ - FusionPattern &AddOpDesc(const std::string &id, const std::initializer_list &types = {}, - const bool allow_dumpable = false, const bool check_unique = false); - - /** add Op description with vector - * - * @param id pattern id - * @param types op type list - * - * @return FusionPattern - */ - FusionPattern &AddOpDesc(const std::string &id, const std::vector &types, - const bool allow_dumpable = false, const bool check_unique = false); - - /** set input Ops with unknown number of args - * - * @param id pattern id - * - * @param input_ids inputs to id op - * - * @return FusionPattern - */ - FusionPattern &SetInputs(const std::string &id, const std::initializer_list &input_ids); - - /** set input Ops with unknown number of args - * - * @param id pattern id - * - * @param input_ids inputs to id op - * - * @return FusionPattern - */ - FusionPattern &SetInputs(const std::string &id, const std::vector &input_ids); - - /** set output Ops with unknown number of args - * - * @param id pattern id - * - * @param output_map output map - * - * @param is_fullmatched flag of output full matched - * - * @return FusionPattern - */ - FusionPattern &SetOutputs(const std::string &id, const OutputMapStr &output_map, bool is_fullmatched = true); - - /** set output Ops with unknown number of args - * - * @param id pattern id - * - * @param output_map output map - * - * @param is_fullmatched flag of output full matched - * - * @return FusionPattern - */ - FusionPattern &SetOutputs(const std::string &id, const OutputMapVecStr &output_map, bool is_fullmatched = true); - - /** set output Op - * - * @param id pattern id - * - * @return FusionPattern - */ - FusionPattern &SetOutput(const std::string &id); - - /** build pattern and check if error exists - * - * @return True or False - */ - bool Build(); - - /** get pattern name - * - * @param id pattern id - * - * @return fusion pattern name - */ - const std::string &GetName() const; - - /** get the OpDesc of input Ops (const) - * - * @param op_desc op_desc for getting inputs - * - * @return op_desc's iniput opdesc list - */ - static const std::vector> *GetInputs(const std::shared_ptr op_desc); - - /** get the OpDesc of output Ops (const) - * - * @param op_desc op_desc for getting outputs - * - * @return op_desc's output opdesc map - */ - static const OutputMapDesc &GetOutputs(const OpDescPtr op_desc); - - /** get the OpDesc of output size - * - * @param op_desc op_desc for getting output size - * - * @return op_desc's output size - */ - static size_t GetOutputSize(const OpDescPtr op_desc); - - /** get the OpDesc of output Op - * - * @return pattern's output opdesc list - */ - const std::shared_ptr GetOutput() const; - - /** print pattern - * - */ - void Dump() const; - - /** get OpDesc based on ID, return nullptr if failed - * - * @param id pattern id - * - * @return pattern's output opdesc list - */ - std::shared_ptr GetOpDesc(const std::string &id) const; - - const std::vector> &GetOpDescs() const; - private: - FusionPattern(const FusionPattern &) = default; - FusionPattern &operator=(const FusionPattern &) = default; - - void SetError(); - - private: - std::string name_; - - std::vector> ops_; - - std::map> op_map_; - - std::shared_ptr output_; - - bool has_error_ = false; -}; -struct CmpKey { - bool operator() (const std::shared_ptr &key1, - const std::shared_ptr &key2) const { - return (key1->id) < (key2->id); - } -}; -} // namespace fe - -#endif // INC_REGISTER_GRAPH_OPTIMIZER_FUSION_PATTERN_H_ diff --git a/inc/register/graph_optimizer/graph_fusion/fusion_quant_util.h b/inc/register/graph_optimizer/graph_fusion/fusion_quant_util.h deleted file mode 100644 index 50a1455709a5c1c816effd764858cefff4091c42..0000000000000000000000000000000000000000 --- a/inc/register/graph_optimizer/graph_fusion/fusion_quant_util.h +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_FUSION_QUANT_UTIL_H_ -#define INC_FUSION_QUANT_UTIL_H_ -#include "graph/node.h" -#include "register/graph_optimizer/graph_optimize_register_error_codes.h" -#include - -struct BiasOptimizeEdges { - ge::InDataAnchorPtr quant_scale; - ge::InDataAnchorPtr quant_offset; - ge::InDataAnchorPtr cube_weight; - ge::InDataAnchorPtr cube_bias; - ge::InDataAnchorPtr deq_scale; - bool isValid() { - return !(cube_weight == nullptr || cube_bias == nullptr); - } -}; - -namespace fe { -struct QuantParam { - float quant_scale; - float quant_offset; -}; - -enum class WeightMode { - WEIGHTWITH2D = 0, - WEIGHTWITH5D = 1, - RESERVED -}; - -class QuantUtil { - public: - static Status BiasOptimizeByEdge(BiasOptimizeEdges ¶m, std::vector &fusion_nodes); - static Status BiasOptimizeByEdge(ge::NodePtr &quant_node, BiasOptimizeEdges ¶m, - std::vector &fusion_nodes); - static Status BiasOptimizeByEdge(QuantParam &quant_param, BiasOptimizeEdges ¶m, - std::vector &fusion_nodes, - WeightMode cube_type = WeightMode::RESERVED); - static Status InsertFixpipeDequantScaleConvert(ge::InDataAnchorPtr deq_scale, std::vector &fusion_nodes); - static Status InsertFixpipeDequantScaleConvert(ge::InDataAnchorPtr &deq_scale, ge::InDataAnchorPtr &quant_offset, - std::vector &fusion_nodes); - static Status InsertQuantScaleConvert(ge::InDataAnchorPtr &quant_scale, ge::InDataAnchorPtr &quant_offset, - std::vector &fusion_nodes); - static Status InsertRequantScaleConvert(ge::InDataAnchorPtr &req_scale, ge::InDataAnchorPtr &quant_offset, - ge::InDataAnchorPtr &cuba_bias, std::vector &fusion_nodes); -}; -} // namespace fe -#endif diff --git a/inc/register/graph_optimizer/graph_fusion/graph_fusion_pass_base.h b/inc/register/graph_optimizer/graph_fusion/graph_fusion_pass_base.h deleted file mode 100644 index 3e7e772e40cf299784bd0fe3ea44e0afb80d7447..0000000000000000000000000000000000000000 --- a/inc/register/graph_optimizer/graph_fusion/graph_fusion_pass_base.h +++ /dev/null @@ -1,118 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_FUSION_PASS_BASE_H_ -#define INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_FUSION_PASS_BASE_H_ - -#include -#include -#include -#include -#include - -#include "register/graph_optimizer/graph_fusion/fusion_pattern.h" -#include "register/graph_optimizer/graph_fusion/graph_pass.h" - -namespace fe { -using std::initializer_list; -using std::map; -using std::string; -using std::vector; -using namespace std; - -enum GraphFusionPassType { - BUILT_IN_GRAPH_PASS = 0, - BUILT_IN_VECTOR_CORE_GRAPH_PASS, - CUSTOM_AI_CORE_GRAPH_PASS, - CUSTOM_VECTOR_CORE_GRAPH_PASS, - SECOND_ROUND_BUILT_IN_GRAPH_PASS, - BUILT_IN_BEFORE_TRANSNODE_INSERTION_GRAPH_PASS, - BUILT_IN_PREPARE_GRAPH_PASS, - BUILT_IN_BEFORE_QUANT_OPTIMIZATION_GRAPH_PASS, - BUILT_IN_TF_TAG_NO_CONST_FODING_GRAPH_PASS, - BUILT_IN_TF_MERGE_SUB_GRAPH_PASS, - BUILT_IN_QUANT_OPTIMIZATION_GRAPH_PASS, - BUILT_IN_EN_ISA_ARCH_EXC_V300_AND_V220_GRAPH_PASS, - BUILT_IN_EN_ISA_ARCH_V100_GRAPH_PASS, - BUILT_IN_EN_ISA_ARCH_V200_GRAPH_PASS, - BUILT_IN_DELETE_NO_CONST_FOLDING_GRAPH_PASS, - BUILT_IN_AFTER_MULTI_DIMS_PASS, - BUILT_IN_AFTER_OPTIMIZE_STAGE1, - BUILT_IN_AFTER_OP_JUDGE, - BUILT_IN_AFTER_BUFFER_OPTIMIZE, - GRAPH_FUSION_PASS_TYPE_RESERVED -}; -class PatternFusionBasePassImpl; -using PatternFusionBasePassImplPtr = std::shared_ptr; - -/** Pass based on pattern - * @ingroup FUSION_PASS_GROUP - * @note New virtual methods should be append at the end of this class - */ -class GraphFusionPassBase : public GraphPass { - public: - using OpDesc = FusionPattern::OpDesc; - using Mapping = std::map, std::vector, CmpKey>; - using Mappings = std::vector; - - GraphFusionPassBase(); - virtual ~GraphFusionPassBase() override; - - /** execute pass - * - * @param [in] graph, the graph waiting for pass level optimization - * @return SUCCESS, successfully optimized the graph by the pass - * @return NOT_CHANGED, the graph did not change - * @return FAILED, fail to modify graph - */ - virtual Status Run(ge::ComputeGraph &graph) override; - - protected: - /** define pattern - * - * @return NA - */ - virtual std::vector DefinePatterns() = 0; - - /** do fusion according to nodes matched - * - * @param graph the graph waiting for pass level optimization - * @param new_nodes fusion result node(s) - * @return SUCCESS, successfully optimized the graph by the pass - * @return NOT_CHANGED, the graph did not change - * @return FAILED, fail to modify graph - */ - virtual Status Fusion(ge::ComputeGraph &graph, Mapping &mapping, std::vector &new_nodes) = 0; - - /** get nodes from matched result - * - * @param mapping match result - * @return nodes result - */ - static ge::NodePtr GetNodeFromMapping(const std::string &id, const Mapping &mapping); - - private: - /** match all nodes in graph according to pattern - * - * @param pattern fusion pattern defined - * @param mappings match result - * @return SUCCESS, successfully add edge - * @return FAILED, fail - */ - bool MatchAll(const ge::ComputeGraph &graph, const FusionPattern &pattern, Mappings &mappings) const; - - Status RunOnePattern(ge::ComputeGraph &graph, const FusionPattern &pattern, bool &changed); - - /** Internal implement class ptr */ - std::shared_ptr pattern_fusion_base_pass_impl_ptr_; -}; - -} // namespace fe - -#endif // INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_FUSION_PASS_BASE_H_ diff --git a/inc/register/graph_optimizer/graph_fusion/graph_pass.h b/inc/register/graph_optimizer/graph_fusion/graph_pass.h deleted file mode 100644 index 149faf0a14aaec93b35fc2009609ec4dd07367c1..0000000000000000000000000000000000000000 --- a/inc/register/graph_optimizer/graph_fusion/graph_pass.h +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_PASS_H_ -#define INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_PASS_H_ - -#include "register/graph_optimizer/graph_fusion/pass.h" - -namespace fe { - -/** graph pass - * @ingroup GRAPH_PASS_GROUP - * graph level pass - */ -class GraphPass : public Pass { - public: - /** execute pass - * - * @param [in] graph, the graph waiting for pass level optimization - * @return SUCCESS, successfully optimized the graph by the pass - * @return NOT_CHANGED, the graph did not change - * @return FAILED, fail to modify graph - */ - virtual Status Run(ge::ComputeGraph &graph) override = 0; -}; - -} // namespace fe - -#endif // INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_PASS_H_ diff --git a/inc/register/graph_optimizer/graph_fusion/pass.h b/inc/register/graph_optimizer/graph_fusion/pass.h deleted file mode 100644 index 643c2e93d18c18261055464986368eb8c248445d..0000000000000000000000000000000000000000 --- a/inc/register/graph_optimizer/graph_fusion/pass.h +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -/** @defgroup FUSION_PASS_GROUP Fusion Pass Interface */ - -#ifndef INC_REGISTER_GRAPH_OPTIMIZER_PASS_H_ -#define INC_REGISTER_GRAPH_OPTIMIZER_PASS_H_ - -#include "graph/compute_graph.h" -#include "register/graph_optimizer/graph_optimize_register_error_codes.h" - -namespace fe { - -/** fusion pass - * @ingroup GRAPH_PASS_GROUP - * network level pass - */ -template -class Pass { - public: - virtual ~Pass() {} - - /** execute pass - * - * @param [in] graph, the graph waiting for pass level optimization - * @return SUCCESS, successfully optimized the graph by the pass - * @return NOT_CHANGED, the graph did not change - * @return FAILED, fail to modify graph - */ - virtual Status Run(ge::ComputeGraph &graph) = 0; - - void SetName(const std::string &name) { name_ = name; } - - std::string GetName() { return name_; } - - private: - std::string name_; -}; - -} // namespace fe - -#endif // INC_REGISTER_GRAPH_OPTIMIZER_PASS_H_ diff --git a/inc/register/graph_optimizer/graph_optimize_register_error_codes.h b/inc/register/graph_optimizer/graph_optimize_register_error_codes.h deleted file mode 100644 index 2b7054b646e4036d62b63ffacf549e9bc17f45e5..0000000000000000000000000000000000000000 --- a/inc/register/graph_optimizer/graph_optimize_register_error_codes.h +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_GRAPH_OPTIMIZE_REGISTER_ERROR_CODES_H_ -#define INC_REGISTER_GRAPH_OPTIMIZE_REGISTER_ERROR_CODES_H_ - -#include -#include -#include - -namespace fe { - -/** Assigned SYS ID */ -const uint8_t SYSID_FE = 3; - -/** Common module ID */ -const uint8_t FE_MODID_COMMON = 50; - -/** FE error code definiton Macro -* Build error code -*/ -#define FE_DEF_ERRORNO(sysid, modid, name, value, desc) \ - static constexpr fe::Status name = \ - ((((static_cast((0xFF) & (static_cast(sysid)))) << 24) | \ - ((static_cast((0xFF) & (static_cast(modid)))) << 16)) | \ - ((0xFFFF) & (static_cast(value)))); - -using Status = uint32_t; - -#define FE_DEF_ERRORNO_COMMON(name, value, desc) \ - FE_DEF_ERRORNO(SYSID_FE, FE_MODID_COMMON, name, value, desc) - -using Status = uint32_t; - -FE_DEF_ERRORNO(0, 0, SUCCESS, 0, "success"); -FE_DEF_ERRORNO(0xFF, 0xFF, FAILED, 0xFFFF, "failed"); -FE_DEF_ERRORNO_COMMON(NOT_CHANGED, 201, "The nodes of the graph not changed."); -FE_DEF_ERRORNO_COMMON(PARAM_INVALID, 1, "Parameter's invalid!"); -FE_DEF_ERRORNO_COMMON(GRAPH_FUSION_CYCLE, 301, "Graph is cycle after fusion!"); - -} // namespace fe -#endif // INC_REGISTER_GRAPH_OPTIMIZE_REGISTER_ERROR_CODES_H_ diff --git a/inc/register/host_cpu_context.h b/inc/register/host_cpu_context.h deleted file mode 100644 index e28dadc941f2c6966b0bfd300cc007b3935026d0..0000000000000000000000000000000000000000 --- a/inc/register/host_cpu_context.h +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_HOST_CPU_CONTEXT_H_ -#define INC_REGISTER_HOST_CPU_CONTEXT_H_ - -#include "external/ge_common/ge_api_error_codes.h" -#include "register/register_types.h" - -namespace ge { -class HostCpuContext { - public: - HostCpuContext() = default; - ~HostCpuContext() = default; - private: - class Impl; - Impl *impl_; -}; -} // namespace ge - -#endif // INC_REGISTER_HOST_CPU_CONTEXT_H_ diff --git a/inc/register/kernel_registry.h b/inc/register/kernel_registry.h deleted file mode 100644 index 4a29665eb903a49d1a012151a90848639633724a..0000000000000000000000000000000000000000 --- a/inc/register/kernel_registry.h +++ /dev/null @@ -1,91 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef B369E37D560547C2B8DC137404F9713E_H -#define B369E37D560547C2B8DC137404F9713E_H -#include -#include -#include -#include -#include "graph/ge_error_codes.h" -#include "graph/types.h" -#include "graph/fast_graph/fast_node.h" -#include "exe_graph/runtime/base_type.h" -#include "exe_graph/runtime/kernel_context.h" -#include "exe_graph/runtime/dfx_info_filler.h" - -namespace ge { -class Node; -} // namespace ge - -namespace gert { -class KernelRegistry { - public: - static KernelRegistry &GetInstance(); - static void ReplaceKernelRegistry(std::shared_ptr registry); - - using CreateOutputsFunc = UINT32 (*)(const ge::FastNode *, KernelContext *); - using KernelFunc = UINT32 (*)(KernelContext *context); - using TracePrinter = std::vector (*)(const KernelContext *); - using ProfilingInfoFiller = ge::graphStatus (*)(const KernelContext *, ProfilingInfoWrapper &); - using DataDumpInfoFiller = ge::graphStatus (*)(const KernelContext *, DataDumpInfoWrapper &); - using ExceptionDumpInfoFiller = ge::graphStatus (*)(const KernelContext *, ExceptionDumpInfoWrapper &); - - struct KernelFuncs { - KernelFunc run_func; - CreateOutputsFunc outputs_creator; - TracePrinter trace_printer; - ProfilingInfoFiller profiling_info_filler; - DataDumpInfoFiller data_dump_info_filler; - ExceptionDumpInfoFiller exception_dump_info_filler; - }; - - struct KernelInfo { - KernelFuncs func; - std::string critical_section; - }; - - virtual ~KernelRegistry() = default; - virtual const KernelFuncs *FindKernelFuncs(const std::string &kernel_type) const = 0; - virtual const KernelInfo *FindKernelInfo(const std::string &kernel_type) const = 0; - virtual void RegisterKernel(std::string kernel_type, KernelInfo kernel_infos) { - (void) kernel_type; - (void) kernel_infos; - }; -}; - -class KernelRegisterData; -class KernelRegisterV2 { - public: - explicit KernelRegisterV2(const ge::char_t *kernel_type); - KernelRegisterV2(const KernelRegisterV2 &other); - ~KernelRegisterV2(); - KernelRegisterV2 &operator=(const KernelRegisterV2 &other) = delete; - KernelRegisterV2 &operator=(KernelRegisterV2 &&other) = delete; - KernelRegisterV2(KernelRegisterV2 &&other) = delete; - - KernelRegisterV2 &RunFunc(KernelRegistry::KernelFunc func); - KernelRegisterV2 &ConcurrentCriticalSectionKey(const std::string &critical_section_key); - - KernelRegisterV2 &OutputsCreator(KernelRegistry::CreateOutputsFunc func); - KernelRegisterV2 &TracePrinter(KernelRegistry::TracePrinter func); - KernelRegisterV2 &ProfilingInfoFiller(KernelRegistry::ProfilingInfoFiller func); - KernelRegisterV2 &DataDumpInfoFiller(KernelRegistry::DataDumpInfoFiller func); - KernelRegisterV2 &ExceptionDumpInfoFiller(KernelRegistry::ExceptionDumpInfoFiller func); - - private: - std::unique_ptr register_data_; -}; -} // namespace gert - -#define REGISTER_KERNEL_COUNTER2(type, counter) static auto g_register_kernel_##counter = gert::KernelRegisterV2(#type) -#define REGISTER_KERNEL_COUNTER(type, counter) REGISTER_KERNEL_COUNTER2(type, counter) -#define REGISTER_KERNEL(type) REGISTER_KERNEL_COUNTER(type, __COUNTER__) - -#endif diff --git a/inc/register/kernel_registry_impl.h b/inc/register/kernel_registry_impl.h deleted file mode 100644 index 71c91dc8b3351fef09a5987c6e3c7f447f847ef7..0000000000000000000000000000000000000000 --- a/inc/register/kernel_registry_impl.h +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_EXTERNAL_REGISTER_KERNEL_REGISTER_IMPL_H_ -#define INC_EXTERNAL_REGISTER_KERNEL_REGISTER_IMPL_H_ -#include -#include - -#include "kernel_registry.h" - -namespace gert { -class KernelRegistryImpl : public KernelRegistry { - public: - static KernelRegistryImpl &GetInstance(); - void RegisterKernel(std::string kernel_type, KernelInfo kernel_infos) override; - const KernelFuncs *FindKernelFuncs(const std::string &kernel_type) const override; - const KernelInfo *FindKernelInfo(const std::string &kernel_type) const override; - - const std::unordered_map &GetAll() const; - - private: - std::unordered_map kernel_infos_; -}; -} - -#endif // INC_EXTERNAL_REGISTER_KERNEL_REGISTER_IMPL_H_ diff --git a/inc/register/node_converter_registry.h b/inc/register/node_converter_registry.h deleted file mode 100644 index 944d5fac7571f9b71cf9249383332bc9a6a3e5bd..0000000000000000000000000000000000000000 --- a/inc/register/node_converter_registry.h +++ /dev/null @@ -1,79 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef AIR_CXX_RUNTIME_V2_LOWERING_NODE_CONVERTER_REGISTRY_H_ -#define AIR_CXX_RUNTIME_V2_LOWERING_NODE_CONVERTER_REGISTRY_H_ -#include -#include -#include - -#include "graph/node.h" -#include "external/graph/types.h" -#include "exe_graph/lowering/dev_mem_value_holder.h" -#include "exe_graph/lowering/lowering_global_data.h" - -namespace gert { -struct LowerInput { - std::vector input_shapes; - std::vector input_addrs; - LoweringGlobalData *global_data; -}; -struct LowerResult { - HyperStatus result; - std::vector order_holders; - std::vector out_shapes; - std::vector out_addrs; -}; - -struct LowerInputInfo { - std::vector input_shapes; - std::vector input_addrs; -}; - -class NodeConverterRegistry { - public: - using NodeConverter = LowerResult (*)(const ge::NodePtr &node, const LowerInput &lower_input); - struct ConverterRegisterData { - NodeConverter converter; - int32_t require_placement; - }; - static NodeConverterRegistry &GetInstance(); - NodeConverter FindNodeConverter(const std::string &func_name); - const ConverterRegisterData *FindRegisterData(const std::string &func_name) const; - void RegisterNodeConverter(const std::string &func_name, NodeConverter func); - void Register(const std::string &func_name, const ConverterRegisterData &data); - - private: - std::unordered_map names_to_register_data_; -}; - -class NodeConverterRegister { - public: - NodeConverterRegister(const ge::char_t *lower_func_name, NodeConverterRegistry::NodeConverter func) noexcept; - NodeConverterRegister(const ge::char_t *lower_func_name, int32_t require_placement, - NodeConverterRegistry::NodeConverter func) noexcept; -}; -} // namespace gert - -#ifdef __GNUC__ -#define ATTRIBUTE_USED __attribute__((used)) -#else -#define ATTRIBUTE_USED -#endif - -#define GERT_REGISTER_NODE_CONVERTER_COUNTER2(type, placement, func, counter) \ - static const gert::NodeConverterRegister g_register_node_converter_##counter ATTRIBUTE_USED = \ - gert::NodeConverterRegister(type, placement, func) -#define GERT_REGISTER_NODE_CONVERTER_COUNTER(type, placement, func, counter) \ - GERT_REGISTER_NODE_CONVERTER_COUNTER2(type, placement, func, counter) -#define REGISTER_NODE_CONVERTER_PLACEMENT(type, placement, func) \ - GERT_REGISTER_NODE_CONVERTER_COUNTER(type, placement, func, __COUNTER__) -#define REGISTER_NODE_CONVERTER(type, func) REGISTER_NODE_CONVERTER_PLACEMENT(type, -1, func) - -#endif // AIR_CXX_RUNTIME_V2_LOWERING_NODE_CONVERTER_REGISTRY_H_ diff --git a/inc/register/op_check_register.h b/inc/register/op_check_register.h deleted file mode 100644 index 7a1712751e6514dbb9b54728d6805be90e513233..0000000000000000000000000000000000000000 --- a/inc/register/op_check_register.h +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_OP_CHECK_REGISTER_H_ -#define INC_REGISTER_OP_CHECK_REGISTER_H_ - -#include - -#include "graph/ascend_string.h" -#include "graph/operator.h" -#include "graph/ge_error_codes.h" -#include "register/op_def.h" - -namespace optiling { -struct ReplayFuncParam { - int32_t block_dim = 0; - const char *tiling_data = nullptr; - const char *kernel_name = nullptr; - const char *entry_file = nullptr; - int32_t gentype = 0; - const char *output_kernel_file = nullptr; - char **objptr = nullptr; - int32_t task_ration = 0; - int32_t tiling_key = 0; -}; - -using REPLAY_FUNC = int32_t (*)(ReplayFuncParam ¶m, const int32_t core_type); -using GEN_SIMPLIFIEDKEY_FUNC = bool (*)(const ge::Operator &op, ge::AscendString &result); - -class OpCheckFuncRegistry { -public: - static void RegisterOpCapability(const ge::AscendString &check_type, const ge::AscendString &op_type, - OP_CHECK_FUNC func); - - static OP_CHECK_FUNC GetOpCapability(const ge::AscendString &check_type, const ge::AscendString &op_type); - - static void RegisterGenSimplifiedKeyFunc(const ge::AscendString &op_type, GEN_SIMPLIFIEDKEY_FUNC func); - - static GEN_SIMPLIFIEDKEY_FUNC GetGenSimplifiedKeyFun(const ge::AscendString &op_type); - - static PARAM_GENERALIZE_FUNC GetParamGeneralize(const ge::AscendString &op_type); - - static void RegisterParamGeneralize(const ge::AscendString &op_type, PARAM_GENERALIZE_FUNC func); - - static void RegisterReplay(const ge::AscendString &op_type, const ge::AscendString &soc_version, REPLAY_FUNC func); - static REPLAY_FUNC GetReplay(const ge::AscendString &op_type, const ge::AscendString &soc_version); - -private: - static std::map> check_op_capability_instance_; - static std::map gen_simplifiedkey_instance_; - static std::map param_generalize_instance_; - static std::map> replay_instance_; -}; - -class ReplayFuncHelper { -public: - ReplayFuncHelper(const ge::AscendString &op_type, const ge::AscendString &soc_version, REPLAY_FUNC func); -}; -} // end of namespace optiling -#endif // INC_REGISTER_OP_CHECK_REGISTER_H_ diff --git a/inc/register/op_ext_calc_param_registry.h b/inc/register/op_ext_calc_param_registry.h deleted file mode 100644 index 4ce08a222f9f624cdb63c3278feff9b383ee4d0e..0000000000000000000000000000000000000000 --- a/inc/register/op_ext_calc_param_registry.h +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_OP_EXT_CALC_PARAM_REGISTRY_H_ -#define INC_REGISTER_OP_EXT_CALC_PARAM_REGISTRY_H_ - -#include -#include -#include -#include "graph/node.h" -#include "external/ge_common/ge_api_types.h" - -namespace fe { -using OpExtCalcParamFunc = ge::Status (*)(const ge::Node &node); -class OpExtCalcParamRegistry { - public: - OpExtCalcParamRegistry() {}; - ~OpExtCalcParamRegistry() {}; - static OpExtCalcParamRegistry &GetInstance(); - OpExtCalcParamFunc FindRegisterFunc(const std::string &op_type) const; - void Register(const std::string &op_type, OpExtCalcParamFunc const func); - - private: - std::unordered_map names_to_register_func_; - }; -class OpExtGenCalcParamRegister { - public: - OpExtGenCalcParamRegister(const char *op_type, OpExtCalcParamFunc func) noexcept; -}; -} // namespace fe - -#ifdef __GNUC__ -#define ATTRIBUTE_USED __attribute__((used)) -#else -#define ATTRIBUTE_USED -#endif - -#define REGISTER_NODE_EXT_CALC_PARAM_COUNTER2(type, func, counter) \ - static const fe::OpExtGenCalcParamRegister g_reg_op_ext_gentask_##counter ATTRIBUTE_USED = \ - fe::OpExtGenCalcParamRegister(type, func) -#define REGISTER_NODE_EXT_CALC_PARAM_COUNTER(type, func, counter) \ - REGISTER_NODE_EXT_CALC_PARAM_COUNTER2(type, func, counter) -#define REGISTER_NODE_EXT_CALC_PARAM(type, func) \ - REGISTER_NODE_EXT_CALC_PARAM_COUNTER(type, func, __COUNTER__) -#endif // INC_REGISTER_OP_EXT_CALC_PARAM_REGISTRY_H_ diff --git a/inc/register/op_ext_gentask_registry.h b/inc/register/op_ext_gentask_registry.h deleted file mode 100644 index 069f6799053492063bc21477a9080b6af639d0b7..0000000000000000000000000000000000000000 --- a/inc/register/op_ext_gentask_registry.h +++ /dev/null @@ -1,94 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_OP_EXTRA_GENTASK_REGISTRY_H -#define INC_REGISTER_OP_EXTRA_GENTASK_REGISTRY_H -#include -#include -#include -#include "graph/node.h" -#include "proto/task.pb.h" -#include "external/ge_common/ge_api_types.h" -#include "common/opskernel/ops_kernel_info_types.h" -namespace fe { -using OpExtGenTaskFunc = ge::Status (*)(const ge::Node &node, - ge::RunContext &context, std::vector &tasks); -using SKExtGenTaskFunc = ge::Status (*)( - const ge::Node &node, std::vector> &subTasks, - const std::vector &sub_nodes, std::vector &tasks); - -enum class ExtTaskType { - kFftsPlusTask, - kAicoreTask -}; - -class OpExtGenTaskRegistry { - public: - OpExtGenTaskRegistry() {}; - ~OpExtGenTaskRegistry() {}; - static OpExtGenTaskRegistry &GetInstance(); - OpExtGenTaskFunc FindRegisterFunc(const std::string &op_type) const; - void Register(const std::string &op_type, OpExtGenTaskFunc const func); - SKExtGenTaskFunc FindSKRegisterFunc(const std::string &op_type) const; - void RegisterSKFunc(const std::string &op_type, SKExtGenTaskFunc const func); - ExtTaskType GetExtTaskType(const std::string &op_type) const; - void RegisterAicoreExtTask(const std::string &op_type); - - private: - std::unordered_map names_to_register_func_; - std::unordered_map types_to_sk_register_func_; - std::unordered_set aicore_ext_task_ops_; -}; - -class OpExtGenTaskRegister { -public: - OpExtGenTaskRegister(const char *op_type, OpExtGenTaskFunc func) noexcept; -}; - -class SKExtGenTaskRegister { - public: - SKExtGenTaskRegister(const char *op_type, SKExtGenTaskFunc func) noexcept; -}; - -class ExtTaskTypeRegister { - public: - ExtTaskTypeRegister(const char *op_type, ExtTaskType type) noexcept; -}; -} // namespace fe - -#ifdef __GNUC__ -#define ATTRIBUTE_USED __attribute__((used)) -#else -#define ATTRIBUTE_USED -#endif - -#define REGISTER_NODE_EXT_GENTASK_COUNTER2(type, func, counter) \ - static const fe::OpExtGenTaskRegister g_reg_op_ext_gentask_##counter ATTRIBUTE_USED = \ - fe::OpExtGenTaskRegister(type, func) -#define REGISTER_NODE_EXT_GENTASK_COUNTER(type, func, counter) \ - REGISTER_NODE_EXT_GENTASK_COUNTER2(type, func, counter) -#define REGISTER_NODE_EXT_GENTASK(type, func) \ - REGISTER_NODE_EXT_GENTASK_COUNTER(type, func, __COUNTER__) - -#define REGISTER_SK_EXT_GENTASK_COUNTER2(type, func, counter) \ - static const fe::SKExtGenTaskRegister g_reg_op_ext_gentask_##counter ATTRIBUTE_USED = \ - fe::SKExtGenTaskRegister(type, func) -#define REGISTER_SK_EXT_GENTASK_COUNTER(type, func, counter) \ - REGISTER_SK_EXT_GENTASK_COUNTER2(type, func, counter) -#define REGISTER_SK_EXT_GENTASK(type, func) \ - REGISTER_SK_EXT_GENTASK_COUNTER(type, func, __COUNTER__) - -#define REGISTER_EXT_TASK_TYPE_COUNTER2(type, task_type, counter) \ - static const fe::ExtTaskTypeRegister g_reg_op_ext_gentask_##counter ATTRIBUTE_USED = \ - fe::ExtTaskTypeRegister(#type, task_type) -#define REGISTER_EXT_TASK_TYPE_COUNTER(type, task_type, counter) \ - REGISTER_EXT_TASK_TYPE_COUNTER2(type, task_type, counter) -#define REGISTER_EXT_TASK_TYPE(type, task_type) \ - REGISTER_EXT_TASK_TYPE_COUNTER(type, task_type, __COUNTER__) -#endif // INC_REGISTER_OP_EXTRA_GENTASK_REGISTRY_H diff --git a/inc/register/op_kernel_registry.h b/inc/register/op_kernel_registry.h deleted file mode 100644 index 3fdd765c9b3a0db4c9f05e7db7a8f10623840d69..0000000000000000000000000000000000000000 --- a/inc/register/op_kernel_registry.h +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_OP_KERNEL_REGISTRY_H_ -#define INC_REGISTER_OP_KERNEL_REGISTRY_H_ -#include -#include -#include "register/register_types.h" -#include "register.h" - -namespace ge { -class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpKernelRegistry { - public: - using CreateFn = HostCpuOp* (*)(); - ~OpKernelRegistry(); - - static OpKernelRegistry& GetInstance(); - - bool IsRegistered(const std::string &op_type) const; - - void RegisterHostCpuOp(const std::string &op_type, const CreateFn create_fn); - - std::unique_ptr CreateHostCpuOp(const std::string &op_type) const; - - private: - OpKernelRegistry(); - class OpKernelRegistryImpl; - /*lint -e148*/ - std::unique_ptr impl_; -}; -} // namespace ge - -#endif // INC_REGISTER_OP_KERNEL_REGISTRY_H_ diff --git a/inc/register/op_lib_register_impl.h b/inc/register/op_lib_register_impl.h deleted file mode 100644 index efcff666c12a09051688dea2f9fe1b57fc213812..0000000000000000000000000000000000000000 --- a/inc/register/op_lib_register_impl.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_REGISTER_OP_LIB_REGISTER_IMPL_H_ -#define METADEF_CXX_REGISTER_OP_LIB_REGISTER_IMPL_H_ - -#include -#include -#include -#include "register/op_lib_register.h" -#include "graph/ge_error_codes.h" - -namespace ge { -class OpLibRegisterImpl { - public: - void SetVendorName(const std::string & vendor_name) { vendor_name_ = vendor_name; } - void SetInitFunc(const OpLibRegister::OpLibInitFunc init_func) { init_func_ = init_func; } - std::string GetVendorName() const { return vendor_name_; } - OpLibRegister::OpLibInitFunc GetInitFunc() const { return init_func_; } - - private: - std::string vendor_name_; - OpLibRegister::OpLibInitFunc init_func_ = nullptr; -}; - -class OpLibRegistry { - public: - static OpLibRegistry &GetInstance(); - ~OpLibRegistry(); - void RegisterInitFunc(OpLibRegisterImpl ®ister_impl); - graphStatus PreProcessForCustomOp(); - const char_t* GetCustomOpLibPath() const; - - private: - void ClearHandles(); - graphStatus GetAllCustomOpApiSoPaths(const std::string &custom_opp_path, - std::vector &so_real_paths) const; - graphStatus CallInitFunc(const std::string &custom_opp_path, - const std::vector &so_real_paths); - - std::mutex mu_; - std::vector> vendor_funcs_; - std::set vendor_names_set_; - std::vector handles_; - bool is_processed_ = false; - std::string op_lib_paths_; -}; -} // namespace ge -#endif // METADEF_CXX_REGISTER_OP_LIB_REGISTER_IMPL_H_ diff --git a/inc/register/op_registry.h b/inc/register/op_registry.h deleted file mode 100644 index 509fe7156d1163fd5b95bb53b5f680f11c257da4..0000000000000000000000000000000000000000 --- a/inc/register/op_registry.h +++ /dev/null @@ -1,90 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_OP_REGISTRY_H_ -#define INC_REGISTER_OP_REGISTRY_H_ - -#include -#include -#include -#include -#include -#include - -#include "register/register.h" - -namespace domi { -enum class RemoveInputType : uint16_t { - OMG_MOVE_TYPE_DTYPE = 0, - OMG_MOVE_TYPE_VALUE = 1, - OMG_MOVE_TYPE_SHAPE = 2, - OMG_MOVE_TYPE_FORMAT = 3, - OMG_MOVE_TYPE_AXIS = 4, - OMG_MOVE_TYPE_SCALAR_VALUE = 5, - OMG_REMOVE_TYPE_WITH_COND = 1000, - OMG_REMOVE_INPUT_WITH_ORIGINAL_TYPE = 1001, - OMG_INPUT_REORDER = 1002, -}; - -struct RemoveInputConfigure { - int32_t inputIdx = INT_MAX; - std::string attrName; - RemoveInputType moveType; - bool attrValue = false; - std::string originalType; - std::vector input_order; -}; - -class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistry { - public: - static OpRegistry *Instance(); - - std::vector registrationDatas; - - bool Register(const OpRegistrationData ®_data); - - domi::ImplyType GetImplyType(const std::string &op_type); - - void GetOpTypeByImplyType(std::vector &vec_op_type, const domi::ImplyType imply_type) const; - - domi::ParseParamFunc GetParseParamFunc(const std::string &op_type, const std::string &ori_type); - - domi::ParseParamByOpFunc GetParseParamByOperatorFunc(const std::string &ori_type); - - domi::FusionParseParamFunc GetFusionParseParamFunc(const std::string &op_type, const std::string &ori_type); - - domi::FusionParseParamByOpFunc GetFusionParseParamByOpFunc(const std::string &op_type, - const std::string &ori_type); - - domi::ParseSubgraphFunc GetParseSubgraphPostFunc(const std::string &op_type); - - Status GetParseSubgraphPostFunc(const std::string &op_type, domi::ParseSubgraphFuncV2 &parse_subgraph_func); - - domi::ImplyType GetImplyTypeByOriOpType(const std::string &ori_optype); - - const std::vector &GetRemoveInputConfigure(const std::string &ori_optype) const; - - bool GetOmTypeByOriOpType(const std::string &ori_optype, std::string &om_type); - - ParseOpToGraphFunc GetParseOpToGraphFunc(const std::string &op_type, const std::string &ori_type); - - private: - std::unordered_map op_run_mode_map_; - std::unordered_map op_parse_params_fn_map_; - std::unordered_map parse_params_by_op_func_map_; - std::unordered_map fusion_op_parse_params_fn_map_; - std::unordered_map fusion_parse_params_by_op_fn_map_; - std::unordered_map op_types_to_parse_subgraph_post_func_; - std::unordered_map> remove_input_configure_map_; - std::map origin_type_to_om_type_; - std::unordered_map parse_op_to_graph_fn_map_; - std::unordered_map op_types_to_parse_subgraph_post_func_v2_; -}; -} // namespace domi -#endif // INC_REGISTER_OP_REGISTRY_H_ diff --git a/inc/register/op_tiling.h b/inc/register/op_tiling.h deleted file mode 100644 index 18640179f31235a617dfebbc370a00492f95bd3d..0000000000000000000000000000000000000000 --- a/inc/register/op_tiling.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_OP_TILING_H -#define INC_REGISTER_OP_TILING_H - -#include "graph/debug/ge_attr_define.h" -#include "graph/node.h" -#include "register/op_tiling_registry.h" - -namespace optiling { -extern "C" ge::graphStatus OpParaCalculateV2(const ge::Operator &op, OpRunInfoV2 &run_info); -extern "C" ge::graphStatus OpAtomicCalculateV2(const ge::Node &node, OpRunInfoV2 &run_info); -extern "C" ge::graphStatus OpFftsPlusCalculate(const ge::Operator &op, std::vector &op_run_info); -} // namespace optiling -#endif // INC_REGISTER_OP_TILING_H diff --git a/inc/register/ops_kernel_builder_registry.h b/inc/register/ops_kernel_builder_registry.h deleted file mode 100644 index b282bf02e8cc8976d512cd25a78be8e33d319a28..0000000000000000000000000000000000000000 --- a/inc/register/ops_kernel_builder_registry.h +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_OPS_KERNEL_BUILDER_REGISTRY_H -#define INC_REGISTER_OPS_KERNEL_BUILDER_REGISTRY_H - -#include -#include "register/register_types.h" -#include "common/opskernel/ops_kernel_builder.h" - -namespace ge { -using OpsKernelBuilderPtr = std::shared_ptr; - -class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpsKernelBuilderRegistry { - public: - ~OpsKernelBuilderRegistry() noexcept; - static OpsKernelBuilderRegistry &GetInstance(); - - void Register(const std::string &lib_name, const OpsKernelBuilderPtr &instance); - - void Unregister(const std::string &lib_name); - - void UnregisterAll(); - - const std::map &GetAll() const; - - private: - OpsKernelBuilderRegistry() = default; - std::map kernel_builders_; -}; - -class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpsKernelBuilderRegistrar { - public: - using CreateFn = OpsKernelBuilder *(*)(); - OpsKernelBuilderRegistrar(const std::string &kernel_lib_name, const CreateFn fn); - ~OpsKernelBuilderRegistrar() noexcept; - -private: - std::string kernel_lib_name_; -}; -} // namespace ge - -#define REGISTER_OPS_KERNEL_BUILDER(kernel_lib_name, builder) \ - REGISTER_OPS_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_lib_name, builder) - -#define REGISTER_OPS_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_lib_name, builder) \ - REGISTER_OPS_KERNEL_BUILDER_UNIQ(ctr, kernel_lib_name, builder) - -#define REGISTER_OPS_KERNEL_BUILDER_UNIQ(ctr, kernel_lib_name, builder) \ - static ::ge::OpsKernelBuilderRegistrar register_op_kernel_builder_##ctr \ - __attribute__((unused)) = \ - ::ge::OpsKernelBuilderRegistrar((kernel_lib_name), []()->::ge::OpsKernelBuilder* { \ - return new (std::nothrow) (builder)(); \ - }) - -#endif // INC_REGISTER_OPS_KERNEL_BUILDER_REGISTRY_H diff --git a/inc/register/optimization_option_registry.h b/inc/register/optimization_option_registry.h deleted file mode 100644 index 1cbaefc056eb2ceb94e551771fbf84edf208672a..0000000000000000000000000000000000000000 --- a/inc/register/optimization_option_registry.h +++ /dev/null @@ -1,127 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_OPTIMIZATION_OPTION_REGISTRY_H -#define INC_REGISTER_OPTIMIZATION_OPTION_REGISTRY_H - -#include -#include -#include -#include -#include "graph/option/optimization_option_info.h" -#include "graph/ge_error_codes.h" - -namespace ge { -class OptionRegistry { - public: - static OptionRegistry &GetInstance(); - void Register(const OoInfo &option); - const OoInfo *FindOptInfo(const std::string &opt_name) const; - std::unordered_map GetVisibleOptions(OoEntryPoint entry_point) const; - - const std::unordered_map &GetRegisteredOptTable() const { - return registered_opt_table_; - } - OptionRegistry(const OptionRegistry &) = delete; - OptionRegistry &operator=(const OptionRegistry &) = delete; - - private: - OptionRegistry() = default; - ~OptionRegistry() = default; - - private: - std::unordered_map registered_opt_table_; -}; - -class PassOptionRegistry { - public: - using OptNameTable = - std::unordered_map(OoHierarchy::kEnd)>>; - static PassOptionRegistry &GetInstance(); - void Register(const std::string &pass_name, const std::map &option_names); - graphStatus FindOptionNamesByPassName(const std::string &pass_name, std::vector &option_names) const; - - PassOptionRegistry(const PassOptionRegistry &) = delete; - PassOptionRegistry &operator=(const PassOptionRegistry &) = delete; - - private: - PassOptionRegistry() = default; - ~PassOptionRegistry() = default; - - private: - OptNameTable pass_names_to_options_; -}; - -class OptionRegister { - public: - explicit OptionRegister(std::string opt_name, OoHierarchy hierarchy = OoHierarchy::kH1) - : opt_reg_data_(new(std::nothrow) OoInfo(std::move(opt_name), hierarchy)) {} - OptionRegister(const OptionRegister &other); - OptionRegister &SetDefaultValues(const std::map& opt_values); - OptionRegister &SetOptLevel(const std::vector &levels); - OptionRegister &SetVisibility(const std::vector &entry_points); - OptionRegister &SetOptValueChecker(OoInfo::ValueChecker opt_checker); - OptionRegister &SetHelpText(std::string opt_help); - OptionRegister &SetShowName(OoEntryPoint entry_point, std::string show_name, OoCategory category = OoCategory::kEnd); - - OptionRegister &operator=(const OptionRegister &) = delete; - OptionRegister &operator=(OptionRegister &&) = delete; - OptionRegister(OptionRegister &&) = delete; - - private: - std::unique_ptr opt_reg_data_; -}; - -class PassOptionRegister { - public: - explicit PassOptionRegister(std::string pass_name) - : pass_reg_data_(new(std::nothrow) PassOptRegData({std::move(pass_name), {}, {}})) {} - PassOptionRegister(const PassOptionRegister &other); - PassOptionRegister &SetOptLevel(const std::vector &levels); - PassOptionRegister &BindSwitchOption(const std::string &opt_name, OoHierarchy hierarchy = OoHierarchy::kH1); - - PassOptionRegister &operator=(const PassOptionRegister &) = delete; - PassOptionRegister &operator=(PassOptionRegister &&) = delete; - PassOptionRegister(PassOptionRegister &&) = delete; - - private: - struct PassOptRegData { - std::string pass_name; - std::vector levels; // levels is used when options is empty - std::map options; - }; - std::unique_ptr pass_reg_data_; -}; -} // namespace ge - -#ifdef __GNUC__ -#define ATTR_USED __attribute__((used)) -#else -#define ATTR_USED -#endif - -#define REG_OPTION(opt_name, ...) REG_UNIQUE_OPTION(__COUNTER__, opt_name, ##__VA_ARGS__) -#define REG_UNIQUE_OPTION(counter, opt_name, ...) REG_UNIQUE_OPTION_WRAPPER(counter, opt_name, ##__VA_ARGS__) -#define REG_UNIQUE_OPTION_WRAPPER(counter, opt_name, ...) \ - static ge::OptionRegister opt_register_##counter ATTR_USED = ge::OptionRegister((opt_name), ##__VA_ARGS__) - -#define DEFAULT_VALUES(...) SetDefaultValues(std::map{__VA_ARGS__}) -#define LEVELS(level1, ...) SetOptLevel(std::vector({(level1), ##__VA_ARGS__})) -#define VISIBILITY(...) SetVisibility(std::vector({__VA_ARGS__})) -#define CHECKER(checker_func) SetOptValueChecker((checker_func)) -#define HELP(help_text) SetHelpText((help_text)) -#define SHOW_NAME(entry, show_name, ...) SetShowName((entry), (show_name), ##__VA_ARGS__) -#define SWITCH_OPT(opt_name, ...) BindSwitchOption((opt_name), ##__VA_ARGS__) - -#define REG_PASS_OPTION(pass_name) REG_UNIQUE_PASS_OPTION(__COUNTER__, pass_name) -#define REG_UNIQUE_PASS_OPTION(counter, pass_name) REG_UNIQUE_PASS_OPTION_WRAPPER(counter, pass_name) -#define REG_UNIQUE_PASS_OPTION_WRAPPER(counter, pass_name) \ - static ge::PassOptionRegister pass_opt_register_##counter ATTR_USED = ge::PassOptionRegister((pass_name)) - -#endif // INC_REGISTER_OPTIMIZATION_OPTION_REGISTRY_H diff --git a/inc/register/pass_option_utils.h b/inc/register/pass_option_utils.h deleted file mode 100644 index a78b9925a7b87fcfdfec7470d5855f45c06a92ad..0000000000000000000000000000000000000000 --- a/inc/register/pass_option_utils.h +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ - -#ifndef INC_REGISTER_PASS_OPTION_UTILS_H -#define INC_REGISTER_PASS_OPTION_UTILS_H - -#include "optimization_option_registry.h" -#include "graph/ge_error_codes.h" -namespace ge { -class PassOptionUtils { - public: - static graphStatus CheckIsPassEnabled(const std::string &pass_name, bool &is_enabled); - - static graphStatus CheckIsPassEnabledByOption(const std::string &pass_name, bool &is_enabled); -}; -} // namespace ge - -#endif // INC_REGISTER_PASS_OPTION_UTILS_H diff --git a/inc/register/prototype_pass_registry.h b/inc/register/prototype_pass_registry.h deleted file mode 100644 index c90f3e9f8adda813287d88d7d45cd36c1048e1ce..0000000000000000000000000000000000000000 --- a/inc/register/prototype_pass_registry.h +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_PROTOTYPE_PASS_REGISTRY_H -#define METADEF_PROTOTYPE_PASS_REGISTRY_H - -#include - -#include -#include -#include - -#include "external/ge_common/ge_api_error_codes.h" -#include "external/graph/types.h" -#include "register/register_error_codes.h" -#include "register/register_fmk_types.h" - -namespace ge { -class ProtoTypeBasePass { - public: - ProtoTypeBasePass() = default; - virtual Status Run(google::protobuf::Message *message) = 0; - virtual ~ProtoTypeBasePass() = default; - - private: - ProtoTypeBasePass(const ProtoTypeBasePass &) = delete; - ProtoTypeBasePass &operator=(const ProtoTypeBasePass &) & = delete; -}; - -class ProtoTypePassRegistry { - public: - using CreateFn = std::function; - ~ProtoTypePassRegistry(); - - static ProtoTypePassRegistry &GetInstance(); - - void RegisterProtoTypePass(const char_t *const pass_name, const CreateFn &create_fn, - const domi::FrameworkType fmk_type); - - std::vector> GetCreateFnByType(const domi::FrameworkType fmk_type) const; - - private: - ProtoTypePassRegistry(); - class ProtoTypePassRegistryImpl; - std::unique_ptr impl_; -}; - -class ProtoTypePassRegistrar { - public: - ProtoTypePassRegistrar(const char_t *const pass_name, ProtoTypeBasePass *(*const create_fn)(), - const domi::FrameworkType fmk_type); - ~ProtoTypePassRegistrar() = default; -}; -} // namespace ge - -#define REGISTER_PROTOTYPE_PASS(pass_name, pass, fmk_type) \ - REGISTER_PROTOTYPE_PASS_UNIQ_HELPER(__COUNTER__, pass_name, pass, fmk_type) - -#define REGISTER_PROTOTYPE_PASS_UNIQ_HELPER(ctr, pass_name, pass, fmk_type) \ - REGISTER_PROTOTYPE_PASS_UNIQ(ctr, pass_name, pass, fmk_type) - -#define REGISTER_PROTOTYPE_PASS_UNIQ(ctr, pass_name, pass, fmk_type) \ - static ::ge::ProtoTypePassRegistrar register_prototype_pass##ctr __attribute__((unused)) = \ - ::ge::ProtoTypePassRegistrar( \ - (pass_name), []()->::ge::ProtoTypeBasePass * { return new (std::nothrow) pass(); }, (fmk_type)) - -#endif // METADEF_PROTOTYPE_PASS_REGISTRY_H diff --git a/inc/register/register.h b/inc/register/register.h deleted file mode 100644 index edc01187c91335a9deeb3226b897c8b24cbab2e0..0000000000000000000000000000000000000000 --- a/inc/register/register.h +++ /dev/null @@ -1,53 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_REGISTRY_H_ -#define INC_REGISTER_REGISTRY_H_ - -#include "external/register/register.h" -#include "external/ge_common/ge_api_error_codes.h" -#include "graph/ge_error_codes.h" - -namespace ge { -class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOp { - public: - HostCpuOp() = default; - HostCpuOp(HostCpuOp &&) = delete; - HostCpuOp &operator=(HostCpuOp &&) & = delete; - virtual ~HostCpuOp() = default; - virtual graphStatus Compute(Operator &op, - const std::map &inputs, - std::map &outputs) = 0; - - private: - HostCpuOp(const HostCpuOp &) = delete; - HostCpuOp &operator=(const HostCpuOp &) & = delete; -}; - -class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOpRegistrar { - public: - HostCpuOpRegistrar(const char_t *const op_type, HostCpuOp *(*const create_fn)()); - ~HostCpuOpRegistrar() = default; -}; -} // namespace ge - -#define REGISTER_HOST_CPU_OP_BUILDER(name, op) \ - REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(__COUNTER__, name, op) - -#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(ctr, name, op) \ - REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) - -#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) \ - static ::ge::HostCpuOpRegistrar register_host_cpu_op##ctr \ - __attribute__((unused)) = \ - ::ge::HostCpuOpRegistrar((name), []()->::ge::HostCpuOp* { \ - return new (std::nothrow) (op)(); \ - }) - -#endif // INC_REGISTER_REGISTRY_H_ diff --git a/inc/register/register_utils.h b/inc/register/register_utils.h deleted file mode 100644 index e4865ef073c44678d9d72efdd8129d8c335edc12..0000000000000000000000000000000000000000 --- a/inc/register/register_utils.h +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef REGISTER_REGISTER_UTILS_H -#define REGISTER_REGISTER_UTILS_H - -#include -#include "external/register/register_types.h" -#include "external/register/register_error_codes.h" -#include "external/register/register.h" -#include "external/graph/operator.h" - -namespace domi { -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OperatorAutoMapping(const Message *op_src, ge::Operator &op); -} // namespace domi -#endif // REGISTER_REGISTER_UTILS_H diff --git a/inc/register/scope/scope_graph_impl.h b/inc/register/scope/scope_graph_impl.h deleted file mode 100644 index 90fa5d86478a0e929e0ac5eef8622f944d42c142..0000000000000000000000000000000000000000 --- a/inc/register/scope/scope_graph_impl.h +++ /dev/null @@ -1,194 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef REGISTER_SCOPE_SCOPE_GRAPH_IMPL_H -#define REGISTER_SCOPE_SCOPE_GRAPH_IMPL_H - -#include "external/register/scope/scope_fusion_pass_register.h" -#include "external/graph/types.h" -#include "graph/operator_factory.h" -#include "proto/tensorflow/graph.pb.h" -#include "proto/tensorflow/node_def.pb.h" -#include "graph/utils/type_utils.h" - -namespace ge { -using FusionInnerNodesInfo = std::vector>, // inputs - std::vector>, // outputs - const ge::Operator *>>; // operator - -class Scope::ScopeImpl { - public: - ScopeImpl() = default; - Status Init(const std::string &name, const std::string &sub_type = "", Scope *const father_scope = nullptr); - ~ScopeImpl(); - - const std::string &Name() const { return name_; } - const std::string &SubType() const { return sub_type_; } - void SetSubType(const std::string &sub_type) { sub_type_ = sub_type; } - void ClearTypeAndSubType(); - void AddNode(ge::OperatorPtr &node_def); - const std::vector &Nodes() const { return nodes_; } - const std::unordered_map &AllNodesMap(); - void AddSubScope(Scope *const scope); - Scope *GetSubScope(const std::string &scope_name) const; - const std::unordered_map &GetSubScopes() const { return sub_scopes_; } - const std::vector &GetAllSubScopes(); - int32_t GetOpTypeNum(const std::string &op_type) const; - void OpsNumInc(const std::string &op_type); - const std::string LastName() const; - const Scope *GetFatherScope() const { return father_scope_; } - // trim scope_index - static std::string TrimScopeIndex(const std::string &scope_name); - - private: - std::string name_; - std::string sub_type_; - Scope *father_scope_ = nullptr; - std::map op_nums_; - std::unordered_map sub_scopes_; - std::vector nodes_; - std::unordered_map all_nodes_map_; - std::map all_nodes_map_new_; - std::vector all_sub_scopes_; -}; - -class FusionScopesResult::InnerNodeInfo::InnerNodeInfoImpl { - public: - explicit InnerNodeInfoImpl(const std::string &fusion_node_name) : fusion_node_name_(fusion_node_name) {} - InnerNodeInfoImpl(const std::string &fusion_node_name, const std::string &name, const std::string &type) - : fusion_node_name_(fusion_node_name), name_(name), type_(type) { - SetName(name); - } - ~InnerNodeInfoImpl() noexcept; - std::string GetFullNodeName(const std::string &relative_name); - void SetName(const std::string &name) { name_ = GetFullNodeName(name); } - void SetType(const std::string &type) { type_ = type; } - void InsertInput(const std::string &input_node, int32_t peer_out_idx); - void InsertOutput(const std::string &output_node, int32_t peer_in_idx); - ge::graphStatus BuildOperator(); - ge::graphStatus SetInputFormat(const std::string &input_name, const std::string &format) ; - ge::graphStatus SetOutputFormat(const std::string &output_name, const std::string &format); - ge::graphStatus SetDynamicInputFormat(const std::string &input_name, const uint32_t index, const std::string &format); - ge::graphStatus SetDynamicOutputFormat(const std::string &output_name, - const uint32_t index, - const std::string &format); - std::string GetName() const { return name_; } - std::string GetType() const { return type_; } - std::vector> GetInputs() const { return inner_node_inputs_; } - std::vector> GetOutputs() const { return inner_node_outputs_; } - ge::Operator *MutableOperator() { return &operator_; } - - private: - ge::Operator operator_; - std::string fusion_node_name_; - std::string name_; - std::string type_; - std::vector> inner_node_inputs_; - std::vector> inner_node_outputs_; -}; - -class FusionScopesResult::FusionScopesResultImpl { - public: - FusionScopesResultImpl() {} - ~FusionScopesResultImpl() = default; - void SetName(const std::string &name) { name_ = name; } - void SetType(const std::string &type) { type_ = type; } - void SetDescription(const std::string &description) { description_ = description; } - const std::string &Name() const { return name_; } - const std::string &Type() const { return type_; } - const std::string &Description() const { return description_; } - void AddNodes(const std::vector &nodes); - const std::vector &Nodes() const { return nodes_; } - void AddScopes(const std::vector &scopes) { - (void)scopes_.insert(scopes_.cend(), scopes.cbegin(), scopes.cend()); - } - const std::vector &Scopes() const { return scopes_; } - const std::map> &GetInputs() const { return inputs_; } - const std::map> &GetOutputs() const { return outputs_; } - void InsertInputs(const std::string &inner_op_name, const std::vector &index_map); - void InsertOutputs(const std::string &inner_op_name, const std::vector &index_map); - bool FindNodes(const std::string &node_name) const; - bool FindScopes(const std::string &scope_name) const; - - InnerNodeInfo *AddInnerNode(const std::string &name, const std::string &type); - InnerNodeInfo *MutableRecentInnerNode(); - InnerNodeInfo *MutableInnerNode(uint32_t index); - FusionInnerNodesInfo GetInnerNodesInfo(); - ge::graphStatus CheckInnerNodesInfo(); - - private: - std::string name_; - std::string type_; - std::string description_; - std::vector scopes_; - std::vector nodes_; - std::map> inputs_; - std::map> outputs_; - std::vector inner_node_infos_; -}; - -class ScopeTree::ScopeTreeImpl { - public: - ScopeTreeImpl() = default; - ScopeTreeImpl(const ScopeTreeImpl &) = delete; - ScopeTreeImpl &operator=(const ScopeTreeImpl &) & = delete; - Status Init(); - ~ScopeTreeImpl(); - - void AddNodeToScope(ge::OperatorPtr &node_def); - const std::vector &GetAllScopes() const { return scopes_; } - const Scope *Root() const { return root_; } - - private: - std::vector SplitNodeName(const std::string &node_name, const char_t delim) const; - Scope *root_ = nullptr; - std::vector scopes_; -}; - -struct ScopeFusionOpInfo { - std::string node_name; - std::string fusion_node_name; - std::string fusion_op_type; - std::string description; - bool scope_pass = true; -}; - -class ScopeGraph::ScopeGraphImpl { - public: - ScopeGraphImpl() = default; - ScopeGraphImpl(const ScopeGraphImpl &) = delete; - ScopeGraphImpl &operator=(const ScopeGraphImpl &) & = delete; - Status Init(); - ~ScopeGraphImpl() noexcept; - - const ScopeTree *GetScopeTree() const { return scope_tree_; } - void BuildScopeGraph(domi::tensorflow::GraphDef *graph_def); - void AddFusionScopesResult(FusionScopesResult *result); - const std::unordered_map &FusionScopesResults() const { return fusion_results_; } - FusionScopesResult *GetFusionScopesResults(const domi::tensorflow::NodeDef *const node_def) const; - FusionScopesResult *GetFusionScopesResults(const std::string &node_name) const; - const std::unordered_map &GetNodesMap() const { return nodes_map_; } - const std::map &GetNodesMapNew() const { return nodes_map_new_; } - bool IsFusionOpChild(const std::string &node_name, std::vector &info_list); - bool FusionOpChildIgnore(const ScopeFusionOpInfo &info); - bool IsFusionOp(const domi::tensorflow::NodeDef *const node_def); - Status GetInputOrOutputIndex(const ScopeFusionOpInfo &info, const int32_t old_index, - const bool input, int32_t &new_index); - - private: - std::vector GetFusionResultInputOrOutput(const ScopeFusionOpInfo &info, - const bool input); // input:true,output:false - std::unordered_map fusion_results_; - std::unordered_map nodes_map_; - std::map nodes_map_new_; - ScopeTree *scope_tree_ = nullptr; -}; -} // namespace ge -#endif // REGISTER_SCOPE_SCOPE_GRAPH_IMPL_H diff --git a/inc/register/scope/scope_pass_impl.h b/inc/register/scope/scope_pass_impl.h deleted file mode 100644 index 5b0654b3234ec4660c9299b7eed9758976e566f2..0000000000000000000000000000000000000000 --- a/inc/register/scope/scope_pass_impl.h +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef REGISTER_SCOPE_SCOPE_PASS_IMPL_H -#define REGISTER_SCOPE_SCOPE_PASS_IMPL_H - -#include "external/register/scope/scope_fusion_pass_register.h" - -namespace ge { -class ScopesResult::ScopesResultImpl { - public: - void SetScopes(const std::vector &scopes) { scopes_ = scopes; } - const std::vector &GetScopes() const { return scopes_; } - void SetNodes(const std::vector &nodes) { nodes_ = nodes; } - const std::vector &GetNodes() const { return nodes_; } - - private: - std::vector scopes_; // multiple scopes - std::vector nodes_; // op outside of scope -}; - -class ScopeBasePass::ScopeBasePassImpl { - public: - explicit ScopeBasePassImpl(ScopeBasePass *const parent) : parent_(parent) {} - virtual ~ScopeBasePassImpl(); - - Status Run(std::shared_ptr &scope_graph); - - private: - Status AddFusionScopesResultToScopeGraph(const std::shared_ptr &scope_graph, - std::vector &scope_results) const; - // Match rules one by one, support multiple sets of matching rules, and finally output a single scope - // Note: This function does not have to be rewritten. - // In order to match the fusion rules designed by you better, - // you can implement your specific versions separately. - bool MatchAllBatches(const ScopeTree *scope_tree, std::vector &results); - - bool MatchOneBatch(const ScopeTree *const scope_tree, const std::vector &patternlist, - std::vector &results) const; - bool MatchOneScope(const ScopePattern *pattern, Scope *scope, std::vector &results) const; - Status PrintFusionScopeInfo(std::shared_ptr &scope_graph) const; - - private: - std::vector patterns_; - ScopeBasePass *parent_; -}; -} // namespace ge -#endif // REGISTER_SCOPE_SCOPE_PASS_IMPL_H diff --git a/inc/register/scope/scope_pass_registry_impl.h b/inc/register/scope/scope_pass_registry_impl.h deleted file mode 100644 index 46636a6869adb5dbd883a605b2cd4fee183d7f02..0000000000000000000000000000000000000000 --- a/inc/register/scope/scope_pass_registry_impl.h +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef REGISTER_SCOPE_SCOPE_PASS_REGISTRY_IMPL_H -#define REGISTER_SCOPE_SCOPE_PASS_REGISTRY_IMPL_H - -#include -#include "external/register/scope/scope_fusion_pass_register.h" - -namespace ge { -struct CreatePassFnPack; -class ScopeFusionPassRegistry::ScopeFusionPassRegistryImpl { - public: - void RegisterScopeFusionPass(const std::string &pass_name, ScopeFusionPassRegistry::CreateFn create_fn, - bool is_general); - ScopeFusionPassRegistry::CreateFn GetCreateFn(const std::string &pass_name); - std::unique_ptr CreateScopeFusionPass(const std::string &pass_name); - std::vector GetAllRegisteredPasses(); - bool SetPassEnableFlag(const std::string pass_name, const bool flag); - - private: - std::mutex mu_; - std::vector pass_names_; // In the order of user registration - std::map create_fn_packs_; -}; -} // namespace ge -#endif // REGISTER_SCOPE_SCOPE_PASS_REGISTRY_IMPL_H diff --git a/inc/register/scope/scope_pattern_impl.h b/inc/register/scope/scope_pattern_impl.h deleted file mode 100644 index 70ad7fbeddbd5d94678b17f497e194d562ff6559..0000000000000000000000000000000000000000 --- a/inc/register/scope/scope_pattern_impl.h +++ /dev/null @@ -1,126 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef REGISTER_SCOPE_SCOPE_PATTERN_IMPL_H -#define REGISTER_SCOPE_SCOPE_PATTERN_IMPL_H - -#include -#include -#include -#include "external/register/scope/scope_fusion_pass_register.h" -#include "external/graph/types.h" -#include "graph/compute_graph.h" - -namespace ge { -constexpr float32_t kCompareRatio = 2.0F; -class ScopeAttrValue::ScopeAttrValueImpl { - public: - ScopeAttrValueImpl() = default; - ~ScopeAttrValueImpl() = default; - - void SetIntValue(const int64_t value) { int_value_ = value; } - void SetFloatValue(const float32_t value) { float_value_ = value; } - void SetStringValue(const std::string &value) { string_value_ = value; } - void SetBoolValue(const bool value) { bool_value_ = value; } - const int64_t &GetIntValue() const { return int_value_; } - const float32_t &GetFloatValue() const { return float_value_; } - const std::string &GetStrValue() const { return string_value_; } - const bool &GetBoolValue() const { return bool_value_; } - - private: - int64_t int_value_ = 0; - float32_t float_value_ = 0.0F; - std::string string_value_ = ""; - bool bool_value_ = false; -}; - - -class NodeOpTypeFeature::NodeOpTypeFeatureImpl : ScopeBaseFeature { - public: - NodeOpTypeFeatureImpl(const std::string &nodeType, const int64_t num, const int64_t step) - : ScopeBaseFeature(), node_type_(nodeType), num_(num), step_(step) {} - ~NodeOpTypeFeatureImpl() override = default; - bool Match(const Scope *const scope) override; - -private: - std::string node_type_; // Node type - int64_t num_; // Node number - int64_t step_; // step - friend class NodeOpTypeFeature; -}; - -class NodeAttrFeature::NodeAttrFeatureImpl : ScopeBaseFeature { - public: - NodeAttrFeatureImpl(const std::string &nodeType, const std::string &attr_name, const ge::DataType datatype, - const ScopeAttrValue &attr_value) - : ScopeBaseFeature(), node_type_(nodeType), attr_name_(attr_name), datatype_(datatype), - attr_value_(attr_value) {} - ~NodeAttrFeatureImpl() override = default; - bool Match(const Scope *scope) override; - Status CheckNodeAttrFeatureData(const bool init_value, const ge::OpDescPtr &op_desc, const Scope *const scope); - Status CheckNodeAttrFeatureData(const std::string &init_value, - const ge::OpDescPtr &op_desc, const Scope *const scope); - Status CheckNodeAttrFeatureData(const int64_t init_value, const ge::OpDescPtr &op_desc, const Scope *const scope); - Status CheckNodeAttrFeatureData(const float32_t init_value, const ge::OpDescPtr &op_desc, const Scope *const scope); - template - typename std::enable_if::is_integer, bool>::type FloatIsEqual(const T x, const T y) const - { - // It is used for floating point comparisons. - // It mainly uses relative precision to judge whether floating-point numbers are equal. - // the 2 is ULPs - return (std::fabs(x - y) <= (std::numeric_limits::epsilon() * std::fabs(x + y) * kCompareRatio)) || - (std::fabs(x - y) < std::numeric_limits::min()); - } - -private: - std::string node_type_; // Node type - std::string attr_name_; // attribute name - ge::DataType datatype_; // datatype - ScopeAttrValue attr_value_; // AttrValue - friend class NodeAttrFeature; -}; - -class ScopeFeature::ScopeFeatureImpl : ScopeBaseFeature { - public: - ScopeFeatureImpl(const std::string &sub_type, const int32_t num, const std::string &suffix, - const std::string &sub_scope_mask, const int64_t step) - : ScopeBaseFeature(), sub_type_(sub_type), num_(num), suffix_(suffix), sub_scope_mask_(sub_scope_mask), - step_(step) {} - ~ScopeFeatureImpl() override = default; - bool Match(const Scope *const scope) override; - bool SubScopesMatch(const std::vector &scopes); - - private: - std::string sub_type_; - int32_t num_; - std::string suffix_; - std::string sub_scope_mask_; - int64_t step_; - friend class ScopeFeature; -}; - -class ScopePattern::ScopePatternImpl { - public: - ScopePatternImpl() {} - ~ScopePatternImpl() = default; - bool Match(const Scope *scope) const; - void SetSubType(const std::string &sub_type); - const std::string &SubType() const { return sub_type_; } - void AddNodeOpTypeFeature(NodeOpTypeFeature &feature); - void AddNodeAttrFeature(NodeAttrFeature &feature); - void AddScopeFeature(ScopeFeature &feature); - - private: - std::string sub_type_; // get Scope sub type - std::vector node_optype_features_; - std::vector node_attr_features_; - std::vector scopes_features_; -}; -} // namespace ge -#endif // REGISTER_SCOPE_SCOPE_PATTERN_IMPL_H diff --git a/inc/register/stream_manage_func_registry.h b/inc/register/stream_manage_func_registry.h deleted file mode 100644 index b053d633f8992c541ea83357040dd7ec4ac516f6..0000000000000000000000000000000000000000 --- a/inc/register/stream_manage_func_registry.h +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_STREAM_MANAGE_FUNC_REGISTRY_H -#define INC_REGISTER_STREAM_MANAGE_FUNC_REGISTRY_H - -#include -#include -#include "runtime/rt.h" -#include "common/ge_common/debug/ge_log.h" - -namespace ge { -// acl action types -enum class MngActionType : uint32_t { - DESTROY_STREAM, - DESTROY_CONTEXT, - RESET_DEVICE, -}; - -typedef union { - rtStream_t stream; - rtContext_t context; - int32_t device_id; -} MngResourceHandle; - -enum class StreamMngFuncType : uint32_t { - ACLNN_STREAM_CALLBACK, // aclnn callback function for destroying sub-stream -}; - -using StreamMngFunc = uint32_t (*)(MngActionType action_type, MngResourceHandle handle); - -class StreamMngFuncRegistry { - public: - static StreamMngFuncRegistry &GetInstance(); - Status TryCallStreamMngFunc(const StreamMngFuncType func_type, MngActionType action_type, MngResourceHandle handle); - void Register(const StreamMngFuncType func_type, StreamMngFunc const manage_func); - StreamMngFunc LookUpStreamMngFunc(const StreamMngFuncType func_type); - - StreamMngFuncRegistry(const StreamMngFuncRegistry &other) = delete; - StreamMngFuncRegistry &operator=(const StreamMngFuncRegistry &other) = delete; - - private: - StreamMngFuncRegistry() = default; - ~StreamMngFuncRegistry() = default; - - std::mutex mutex_; - std::map type_to_func_; -}; - -class StreamMngFuncRegister { - public: - StreamMngFuncRegister(const StreamMngFuncType func_type, StreamMngFunc const manage_func); -}; -} // namespace ge - -#ifdef __GNUC__ -#define ATTRIBUTE_USED __attribute__((used)) -#else -#define ATTRIBUTE_USED -#endif -#define REG_STREAM_MNG_FUNC(type, func) REG_STREAM_MNG_FUNC_UNIQ_HELPER(type, func, __COUNTER__) -#define REG_STREAM_MNG_FUNC_UNIQ_HELPER(type, func, counter) REG_STREAM_MNG_FUNC_UNIQ(type, func, counter) -#define REG_STREAM_MNG_FUNC_UNIQ(type, func, counter) \ - static ::ge::StreamMngFuncRegister register_stream_mng_func_##counter ATTRIBUTE_USED = \ - ge::StreamMngFuncRegister(type, func) - -#endif // INC_REGISTER_STREAM_MANAGE_FUNC_REGISTRY_H diff --git a/inc/register/tensor_assign.h b/inc/register/tensor_assign.h deleted file mode 100644 index d3e08360f6ef7496e473e107b72a7ae9c583acbe..0000000000000000000000000000000000000000 --- a/inc/register/tensor_assign.h +++ /dev/null @@ -1,133 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef TENSOR_ASSIGN_H -#define TENSOR_ASSIGN_H - -#include -#include "graph/ge_tensor.h" -#include "graph/def_types.h" -#include "common/checker.h" -#include "common/ge_common/debug/ge_log.h" -#include "external/register/register_error_codes.h" -#include "external/utils/extern_math_util.h" -#include "proto/tensorflow/tensor.pb.h" - -namespace domi { -using GeTensorPtr = std::shared_ptr; -using Status = uint32_t; -constexpr int64_t kComplexWidth = 2; - -class TensorAssign { - public: - static Status SetGeTensor(const domi::tensorflow::TensorProto &tensor, GeTensorPtr &weight); - - static Status SetGeTensorDataType(const int64_t data_type, GeTensorPtr &weight); - - static ge::DataType ConvertTensorflowDataType(const uint32_t tf_data_type); - - private: - static bool CheckBoolVal(const tensorflow::DataType data_type); - - static bool CheckHalfVal(const tensorflow::DataType data_type); - - static bool CheckFloatVal(const tensorflow::DataType data_type); - - static bool CheckDoubleVal(const tensorflow::DataType data_type); - - static bool CheckComplex32Val(const tensorflow::DataType data_type); - - static bool CheckComplex64Val(const tensorflow::DataType data_type); - - static bool CheckComplex128Val(const tensorflow::DataType data_type); - - static bool CheckStringVal(const tensorflow::DataType data_type); - - static bool CheckByte(const tensorflow::DataType data_type); - - static bool CheckDoubleByte(const tensorflow::DataType data_type); - - static bool CheckSignedFourByte(const tensorflow::DataType data_type); - - static bool CheckUnsignedFourByte(const tensorflow::DataType data_type); - - static bool CheckSignedEightByte(const tensorflow::DataType data_type); - - static bool CheckUnsignedEightByte(const tensorflow::DataType data_type); - - static Status GetDoubleByteVal(const int64_t val_size, - const google::protobuf::RepeatedField &val_vector, - const int64_t count, GeTensorPtr &weight); - static Status GetByteVal(const int64_t val_size, - const google::protobuf::RepeatedField &val_vector, - const int64_t count, GeTensorPtr &weight); - - static Status GetStringVal(const int64_t val_size, const google::protobuf::RepeatedPtrField &val_vector, - const int64_t count, GeTensorPtr &weight); - - static void SetGeTensorWeightData(const domi::tensorflow::TensorProto &tensor, const int64_t val_size, - const int64_t count, GeTensorPtr &weight); - - static void SetWeightData(const tensorflow::DataType data_type, const int64_t count, - const std::string &tensor_content, GeTensorPtr &weight); - - template - static Status GetVal(const int64_t val_size, const google::protobuf::RepeatedField &val_vector, - const int64_t count, GeTensorPtr &weight, const bool is_complex = false) { - // val_size must be even, and complex value should be an integer multiple of 2 - if (is_complex && ((val_size % kComplexWidth) != 0)) { - GELOGE(FAILED, "complex value should be an integer multiple of 2."); - return FAILED; - } - const std::unique_ptr addr(new (std::nothrow) T[count]()); // Zero init default value - GE_CHECK_NOTNULL(addr); - if (val_size == 0) { - (void)weight->SetData(ge::PtrToPtr(addr.get()), static_cast(count) * sizeof(T)); - return SUCCESS; - } - // Complex numbers are made up of real and imaginary numbers - const bool zerosLike = ((count != val_size) && ((val_size == 1) || (is_complex && (val_size == 2)))); - if ((!zerosLike) && (val_size <= count)) { - for (size_t i = 0UL; i < static_cast(val_size); i++) { - addr[i] = val_vector.Get(static_cast(i)); - } - const int64_t value_r = val_size - 1; - GE_ASSERT_EQ(ge::IntegerChecker::Compat(value_r), true); - if (is_complex) { - // val_vector format is real value, complex value..., here is getting the corresponding value. - // real value and complex value are stored spaced apart, so use 2 and 1 to store in the correct addr. - const int64_t value_l = val_size - kComplexWidth; - GE_ASSERT_EQ(ge::IntegerChecker::Compat(value_l), true); - for (int64_t i = val_size; i < count; i += kComplexWidth) { - addr[static_cast(i)] = val_vector.Get(static_cast(value_l)); - addr[static_cast(i) + 1UL] = val_vector.Get(static_cast(value_r)); - } - } else { - for (int64_t i = val_size; i < count; i++) { - addr[static_cast(i)] = val_vector.Get(static_cast(value_r)); - } - } - } else { - if (is_complex) { - for (int64_t i = 0; i < count; i += kComplexWidth) { - addr[static_cast(i)] = val_vector.Get(0); - addr[static_cast(i) + 1UL] = val_vector.Get(1); - } - } else { - for (int64_t i = 0; i < count; i++) { - addr[static_cast(i)] = val_vector.Get(0); - } - } - } - (void)weight->SetData(ge::PtrToPtr(addr.get()), static_cast(count) * sizeof(T)); - return SUCCESS; - } -}; -} // namespace domi -#endif // TENSOR_ASSIGN_H diff --git a/ops/op_imp.cpp b/ops/op_imp.cpp deleted file mode 100644 index f2ef11313319bbee193846308acd502e547976a9..0000000000000000000000000000000000000000 --- a/ops/op_imp.cpp +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include -#include "graph/debug/ge_log.h" - -namespace ge { -namespace { -static graphStatus BroadCastRankAndDim( - const std::vector &x1_shape, const std::vector &x2_shape, const int64_t len_diff, - const std::function &out_shape)> &set_out_shape) { - std::vector y_shape; - y_shape.reserve(x1_shape.size()); - for (size_t i = 0UL; i < static_cast(len_diff); i++) { - y_shape.push_back(x1_shape[i]); - } - for (size_t i = 0UL; i < x2_shape.size(); i++) { - const size_t idx_diff = i + static_cast(len_diff); - if ((x1_shape[idx_diff] != x2_shape[i]) && (std::min(x1_shape[idx_diff], x2_shape[i]) != 1)) { - GE_LOGE("operands could not be broadcast together"); - return GRAPH_FAILED; - } - y_shape.push_back(std::max(x1_shape[idx_diff], x2_shape[i])); - } - set_out_shape(y_shape); - return GRAPH_SUCCESS; -} -} // namespace - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -BroadCastInfer(const std::function()> &get_in1_shape, - const std::function()> &get_in2_shape, - const std::function &out_shape)> &set_out_shape) { - const auto x1_shape = get_in1_shape(); - const auto x2_shape = get_in2_shape(); - - if ((x1_shape.size() >= static_cast(std::numeric_limits::max())) || - (x2_shape.size() >= static_cast(std::numeric_limits::max()))) { - return GRAPH_FAILED; - } - - if (x1_shape.empty()) { - set_out_shape(x2_shape); - return GRAPH_SUCCESS; - } - if (x2_shape.empty()) { - set_out_shape(x1_shape); - return GRAPH_SUCCESS; - } - - const int64_t len_diff = static_cast(x1_shape.size()) - static_cast(x2_shape.size()); - if (len_diff >= 0) { - return BroadCastRankAndDim(x1_shape, x2_shape, len_diff, set_out_shape); - } else { - return BroadCastRankAndDim(x2_shape, x1_shape, std::abs(len_diff), set_out_shape); - } -} -} // namespace ge diff --git a/proto/CMakeLists.txt b/proto/CMakeLists.txt deleted file mode 100644 index 82349b59a45fb4ccd24ccb7d628c156e62c6281e..0000000000000000000000000000000000000000 --- a/proto/CMakeLists.txt +++ /dev/null @@ -1,183 +0,0 @@ -# Copyright (c) 2024 Huawei Technologies Co., Ltd. -# This file is a part of the CANN Open Software. -# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. -# ====================================================================================================================== - -set(METADEF_PROTO_LIST - "${METADEF_DIR}/proto/task.proto" - "${METADEF_DIR}/proto/om.proto" - "${METADEF_DIR}/proto/ge_ir.proto" - "${METADEF_DIR}/proto/insert_op.proto" - "${METADEF_DIR}/proto/dump_task.proto" - "${METADEF_DIR}/proto/fwk_adapter.proto" - "${METADEF_DIR}/proto/op_mapping.proto" - "${METADEF_DIR}/proto/onnx/ge_onnx.proto" - "${METADEF_DIR}/proto/tensorflow/attr_value.proto" - "${METADEF_DIR}/proto/tensorflow/function.proto" - "${METADEF_DIR}/proto/tensorflow/graph.proto" - "${METADEF_DIR}/proto/tensorflow/node_def.proto" - "${METADEF_DIR}/proto/tensorflow/op_def.proto" - "${METADEF_DIR}/proto/tensorflow/resource_handle.proto" - "${METADEF_DIR}/proto/tensorflow/tensor.proto" - "${METADEF_DIR}/proto/tensorflow/tensor_shape.proto" - "${METADEF_DIR}/proto/tensorflow/types.proto" - "${METADEF_DIR}/proto/tensorflow/versions.proto" - "${METADEF_DIR}/proto/var_manager.proto" - "${METADEF_DIR}/proto/flow_model.proto" - "${METADEF_DIR}/proto/aicpu/cpu_attr.proto" - "${METADEF_DIR}/proto/aicpu/cpu_node_def.proto" - "${METADEF_DIR}/proto/aicpu/cpu_tensor_shape.proto" - "${METADEF_DIR}/proto/aicpu/cpu_tensor.proto" - "${METADEF_DIR}/proto/ascendc_ir.proto" - ) - -set(METADEF_ATTR_GROUP_PROTO_LIST - "${METADEF_DIR}/proto/attr_group_base.proto" - ) - -protobuf_generate(metadef_protos METADEF_PROTO_SRCS METADEF_PROTO_HDRS ${METADEF_PROTO_LIST} ${METADEF_ATTR_GROUP_PROTO_LIST} TARGET) - -set(METADEF_GRAPH_PROTO_SRCS - "${CMAKE_BINARY_DIR}/proto/metadef_protos/proto/om.pb.cc" - "${CMAKE_BINARY_DIR}/proto/metadef_protos/proto/ge_ir.pb.cc" - "${CMAKE_BINARY_DIR}/proto/metadef_protos/proto/ascendc_ir.pb.cc" - "${CMAKE_BINARY_DIR}/proto/metadef_protos/proto/insert_op.pb.cc" - "${CMAKE_BINARY_DIR}/proto/metadef_protos/proto/task.pb.cc" - "${CMAKE_BINARY_DIR}/proto/metadef_protos/proto/dump_task.pb.cc" - "${CMAKE_BINARY_DIR}/proto/metadef_protos/proto/fwk_adapter.pb.cc" - "${CMAKE_BINARY_DIR}/proto/metadef_protos/proto/op_mapping.pb.cc" - "${CMAKE_BINARY_DIR}/proto/metadef_protos/proto/onnx/ge_onnx.pb.cc" - "${CMAKE_BINARY_DIR}/proto/metadef_protos/proto/var_manager.pb.cc" - "${CMAKE_BINARY_DIR}/proto/metadef_protos/proto/flow_model.pb.cc" - ) - -set(METADEF_TENSORFLOW_PROTO_SRCS - "${CMAKE_BINARY_DIR}/proto/metadef_protos/proto/tensorflow/graph.pb.cc" - "${CMAKE_BINARY_DIR}/proto/metadef_protos/proto/tensorflow/node_def.pb.cc" - "${CMAKE_BINARY_DIR}/proto/metadef_protos/proto/tensorflow/function.pb.cc" - "${CMAKE_BINARY_DIR}/proto/metadef_protos/proto/tensorflow/versions.pb.cc" - "${CMAKE_BINARY_DIR}/proto/metadef_protos/proto/tensorflow/attr_value.pb.cc" - "${CMAKE_BINARY_DIR}/proto/metadef_protos/proto/tensorflow/op_def.pb.cc" - "${CMAKE_BINARY_DIR}/proto/metadef_protos/proto/tensorflow/tensor.pb.cc" - "${CMAKE_BINARY_DIR}/proto/metadef_protos/proto/tensorflow/tensor_shape.pb.cc" - "${CMAKE_BINARY_DIR}/proto/metadef_protos/proto/tensorflow/types.pb.cc" - "${CMAKE_BINARY_DIR}/proto/metadef_protos/proto/tensorflow/resource_handle.pb.cc" - ) - -set(METADEF_AICPU_PROTO_SRCS - "${CMAKE_BINARY_DIR}/proto/metadef_protos/proto/aicpu/cpu_attr.pb.cc" - "${CMAKE_BINARY_DIR}/proto/metadef_protos/proto/aicpu/cpu_node_def.pb.cc" - "${CMAKE_BINARY_DIR}/proto/metadef_protos/proto/aicpu/cpu_tensor_shape.pb.cc" - "${CMAKE_BINARY_DIR}/proto/metadef_protos/proto/aicpu/cpu_tensor.pb.cc" - ) - -set(METADEF_ATTR_GROUP_PROTO_SRCS - "${CMAKE_BINARY_DIR}/proto/metadef_protos/proto/attr_group_base.pb.cc" - ) - -add_library(metadef_graph_protos_obj OBJECT ${METADEF_GRAPH_PROTO_SRCS}) -add_dependencies(metadef_graph_protos_obj metadef_protos) -target_compile_definitions(metadef_graph_protos_obj PRIVATE - google=ascend_private - ) -target_link_libraries(metadef_graph_protos_obj PRIVATE ascend_protobuf intf_pub) -target_compile_options(metadef_graph_protos_obj PRIVATE - $<$:-O2 -fPIC -Wextra -Wfloat-equal> - $<$,$>:-fexceptions> - $<$,$>: -fno-common -Wextra -Wfloat-equal> - $<$,$>:/MTd> - $<$,$>:/MT> - ) - -add_library(metadef_tensorflow_protos_obj OBJECT ${METADEF_TENSORFLOW_PROTO_SRCS}) -add_dependencies(metadef_tensorflow_protos_obj metadef_protos) -target_compile_definitions(metadef_tensorflow_protos_obj PRIVATE - google=ascend_private - ) -target_link_libraries(metadef_tensorflow_protos_obj PRIVATE ascend_protobuf intf_pub) -target_compile_options(metadef_tensorflow_protos_obj PRIVATE - $<$:-O2 -fPIC -Wextra -Wfloat-equal> - $<$,$>:-fexceptions> - $<$,$>: -fno-common -Wextra -Wfloat-equal> - $<$,$>:/MTd> - $<$,$>:/MT> - ) - -add_library(metadef_aicpu_protos_obj OBJECT ${METADEF_AICPU_PROTO_SRCS}) -add_dependencies(metadef_aicpu_protos_obj metadef_protos) -target_compile_definitions(metadef_aicpu_protos_obj PRIVATE - google=ascend_private - ) -target_link_libraries(metadef_aicpu_protos_obj PRIVATE ascend_protobuf intf_pub) -target_compile_options(metadef_aicpu_protos_obj PRIVATE - $<$:-O2 -fPIC -Wextra -Wfloat-equal> - $<$,$>:-fexceptions> - $<$,$>: -fno-common -Wextra -Wfloat-equal> - $<$,$>:/MTd> - $<$,$>:/MT> - ) - -add_library(metadef_attr_group_protos_obj OBJECT ${METADEF_ATTR_GROUP_PROTO_SRCS}) -add_dependencies(metadef_attr_group_protos_obj metadef_protos) -target_compile_definitions(metadef_attr_group_protos_obj PRIVATE - google=ascend_private - ) -target_link_libraries(metadef_attr_group_protos_obj PRIVATE ascend_protobuf intf_pub) -target_compile_options(metadef_attr_group_protos_obj PRIVATE - $<$:-O2 -fPIC -Wextra -Wfloat-equal> - $<$,$>:-fexceptions> - $<$,$>: -fno-common -Wextra -Wfloat-equal> - $<$,$>:/MTd> - $<$,$>:/MT> - ) - -############ stub/libop_common.so ############ -set(STUB_SRC_LIST - ${CMAKE_CURRENT_BINARY_DIR}/stub_common_infershape_fns.cc -) - -add_custom_command( - OUTPUT ${STUB_SRC_LIST} - COMMAND echo "Generating stub files." - && ${HI_PYTHON} ${CMAKE_CURRENT_LIST_DIR}/stub/gen_stubapi.py ${METADEF_DIR}/inc/external ${CMAKE_CURRENT_BINARY_DIR} - && echo "Generating stub files end." -) - -add_custom_target(op_common_stub DEPENDS ${STUB_SRC_LIST}) - -add_library(stub_op_common SHARED ${STUB_SRC_LIST}) - -add_dependencies(stub_op_common op_common_stub) - -target_include_directories(stub_op_common PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${CMAKE_BINARY_DIR} -) - -target_compile_options(stub_op_common PRIVATE - -Wfloat-equal - -fno-common - -Werror=return-type -) - -target_link_libraries(stub_op_common - PRIVATE - intf_pub - c_sec - PUBLIC - metadef_headers -) - -set_target_properties(stub_op_common PROPERTIES - OUTPUT_NAME op_common - LIBRARY_OUTPUT_DIRECTORY stub -) - -############ install ############ -install(TARGETS stub_op_common OPTIONAL - LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}/${CMAKE_SYSTEM_PROCESSOR}/stub -) diff --git a/proto/aicpu/cpu_attr.proto b/proto/aicpu/cpu_attr.proto deleted file mode 100644 index 9c15e0db2cf7525a8d804fe9700d3d0def268bb8..0000000000000000000000000000000000000000 --- a/proto/aicpu/cpu_attr.proto +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; -package aicpuops; -import "cpu_tensor.proto"; -import "cpu_tensor_shape.proto"; - -message AttrValue { - - message ArrayValue { - repeated bytes s = 2; //"array(string)" - repeated int64 i = 3 [ packed = true ]; //"array(int)" - repeated float f = 4 [ packed = true ]; //"array(float)" - repeated bool b = 5 [ packed = true ]; //"array(bool)" - repeated int32 type = 6 [ packed = true ]; //"array(type)" - repeated TensorShape shape = 7; //"array(shape)" - repeated Tensor tensor = 8; //"array(tensor)" - } - - message ListListInt{ - message ListInt{ - repeated int64 list_i = 1; // list int - } - repeated ListInt list_list_i = 1; // list list int - } - - - oneof value { - ArrayValue array = 1; - bytes s = 2; //"string" - int64 i = 3; //"int" - float f = 4; //"float" - bool b = 5; //"bool" - int32 type = 6; //"type" - TensorShape shape = 7; //"shape" - Tensor tensor = 8; //"tensor" - ListListInt list_list_int = 9; // List List Int type - } -} diff --git a/proto/aicpu/cpu_node_def.proto b/proto/aicpu/cpu_node_def.proto deleted file mode 100644 index 9b1c5aa61d1d65b7a57d8781ccaf1f570ad4f348..0000000000000000000000000000000000000000 --- a/proto/aicpu/cpu_node_def.proto +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; -package aicpuops; -import "cpu_attr.proto"; -import "cpu_tensor.proto"; - -message DynamicIdx { - int32 idx = 1; - int32 num = 2; -} - -message NodeDef { - string op = 2; - map attrs = 3; - repeated Tensor inputs = 4; - repeated Tensor outputs = 5; - map dym_inputs = 6; - map dym_outputs = 7; -} diff --git a/proto/aicpu/cpu_tensor.proto b/proto/aicpu/cpu_tensor.proto deleted file mode 100644 index 72c3d6698b9e56ca389262c94e6882b7707697be..0000000000000000000000000000000000000000 --- a/proto/aicpu/cpu_tensor.proto +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; - -option cc_enable_arenas = true; -import "cpu_tensor_shape.proto"; -package aicpuops; - -message Tensor { - - // tensor shape info - TensorShape tensor_shape = 1; - - // tensor content data type - int32 tensor_type = 2; - - // tensor memory device - // data located memory device , "DDR" "HBM" OR "NONE" - string mem_device = 3; - string name = 4; - uint64 data_ptr = 5; - uint64 data_size = 6; -} diff --git a/proto/aicpu/cpu_tensor_shape.proto b/proto/aicpu/cpu_tensor_shape.proto deleted file mode 100644 index 8fe274a38e01600ea309243a22bd0b537e5a0e4e..0000000000000000000000000000000000000000 --- a/proto/aicpu/cpu_tensor_shape.proto +++ /dev/null @@ -1,29 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; -package aicpuops; - -message TensorShape { - // One dimension of the tensor. - message Dim { - // size must >=0 - int64 size = 1; - }; - - // group dim info - repeated Dim dim = 2; - - // If true, the number of dimensions in the shape is unknown. - // If true, "dim.size()" must be 0. - bool unknown_rank = 3; - - // data format "NHWC" "NCHW" "NC1HWC0" OR "NONE" - int32 data_format = 4; -}; diff --git a/proto/ascendc_ir.proto b/proto/ascendc_ir.proto deleted file mode 100644 index 09d72f04d489893ed5a7aff814cd9d4742f3bff3..0000000000000000000000000000000000000000 --- a/proto/ascendc_ir.proto +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., 2025 Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; -package ascendc_ir.proto; - -import "ge_ir.proto"; -message AscTensorDef { - ge.proto.AscTensorAttrGroupsDef attr = 1; -} - -message AscInputSourceDef { - string src_node_name = 1; - int32 src_out_index = 2; -} - -message IrDef { - repeated string input_names = 1; // IR上的输入名 - repeated string output_names = 2; // IR上的输出名 - repeated int64 input_ir_type = 3; // IR上的输入类型 - repeated int64 output_ir_type = 4; // IR上的输出类型 - string type = 5; // IR的类型 - repeated int64 input_nums = 6; // IR的输入个数 -} - -message AscNodeDef { - repeated AscInputSourceDef input_src = 1; - repeated AscTensorDef outputs = 2; - ge.proto.AscNodeAttrGroupsDef attr = 3; - IrDef ir_def = 4; -} - -message AscGraphDef { - ge.proto.AscGraphAttrGroupsDef asc_graph_attr = 1; - repeated AscNodeDef asc_node = 2; - string graph_name = 3; -} \ No newline at end of file diff --git a/proto/attr_group_base.proto b/proto/attr_group_base.proto deleted file mode 100644 index ef2d43b07e820962b32e85f4bae7366ad5aa5c91..0000000000000000000000000000000000000000 --- a/proto/attr_group_base.proto +++ /dev/null @@ -1,13 +0,0 @@ -syntax = "proto3"; -import "ge_ir.proto"; - -message OtherGroupDef { - map attr = 1; -} - -message AttrGroupsDef { - OtherGroupDef other_group_def = 1; - oneof attr_group { - ge.proto.AscendCIROpAttrGroupsDef op_attr_group = 2; - } -} diff --git a/proto/caffe/caffe.proto b/proto/caffe/caffe.proto deleted file mode 100644 index 73e885d1211731d4c8a5233435c12fd5fc0b8c8c..0000000000000000000000000000000000000000 --- a/proto/caffe/caffe.proto +++ /dev/null @@ -1,1836 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto2"; - -package domi.caffe; - -// Specifies the shape (dimensions) of a Blob. -message BlobShape { - repeated int64 dim = 1 [packed = true]; -} - -message BlobProto { - optional BlobShape shape = 7; - repeated float data = 5 [packed = true]; - repeated float diff = 6 [packed = true]; - repeated double double_data = 8 [packed = true]; - repeated double double_diff = 9 [packed = true]; - optional bytes int8_data = 10; - repeated int32 int32_data = 11 [packed = true]; - repeated uint64 uint64_data = 12 [packed = true]; - // 4D dimensions -- deprecated. Use "shape" instead. - optional int32 num = 1 [default = 0]; - optional int32 channels = 2 [default = 0]; - optional int32 height = 3 [default = 0]; - optional int32 width = 4 [default = 0]; -} - -// The BlobProtoVector is simply a way to pass multiple blobproto instances -// around. -message BlobProtoVector { - repeated BlobProto blobs = 1; -} - -message Datum { - optional int32 channels = 1; - optional int32 height = 2; - optional int32 width = 3; - // the actual image data, in bytes - optional bytes data = 4; - optional int32 label = 5; - // Optionally, the datum could also hold float data. - repeated float float_data = 6; - // If true data contains an encoded image that need to be decoded - optional bool encoded = 7 [default = false]; -} - -message FillerParameter { - // The filler type. - optional string type = 1 [default = 'constant']; - optional float value = 2 [default = 0]; // the value in constant filler - optional float min = 3 [default = 0]; // the min value in uniform filler - optional float max = 4 [default = 1]; // the max value in uniform filler - optional float mean = 5 [default = 0]; // the mean value in Gaussian filler - optional float std = 6 [default = 1]; // the std value in Gaussian filler - // The expected number of non-zero output weights for a given input in - // Gaussian filler -- the default -1 means don't perform sparsification. - optional int32 sparse = 7 [default = -1]; - // Normalize the filler variance by fan_in, fan_out, or their average. - // Applies to 'xavier' and 'msra' fillers. - enum VarianceNorm { - FAN_IN = 0; - FAN_OUT = 1; - AVERAGE = 2; - } - optional VarianceNorm variance_norm = 8 [default = FAN_IN]; -} - -message NetParameter { - optional string name = 1; // consider giving the network a name - // DEPRECATED. See InputParameter. The input blobs to the network. - repeated string input = 3; - // DEPRECATED. See InputParameter. The shape of the input blobs. - repeated BlobShape input_shape = 8; - - // 4D input dimensions -- deprecated. Use "input_shape" instead. - // If specified, for each input blob there should be four - // values specifying the num, channels, height and width of the input blob. - // Thus, there should be a total of (4 * #input) numbers. - repeated int32 input_dim = 4; - - // Whether the network will force every layer to carry out backward operation. - // If set False, then whether to carry out backward is determined - // automatically according to the net structure and learning rates. - optional bool force_backward = 5 [default = false]; - // The current "state" of the network, including the phase, level, and stage. - // Some layers may be included/excluded depending on this state and the states - // specified in the layers' include and exclude fields. - optional NetState state = 6; - - // Print debugging information about results while running Net::Forward, - // Net::Backward, and Net::Update. - optional bool debug_info = 7 [default = false]; - - // The layers that make up the net. Each of their configurations, including - // connectivity and behavior, is specified as a LayerParameter. - repeated LayerParameter layer = 100; // ID 100 so layers are printed last. - - // DEPRECATED: use 'layer' instead. - repeated V1LayerParameter layers = 2; -} - -// NOTE -// Update the next available ID when you add a new SolverParameter field. -// -// SolverParameter next available ID: 42 (last added: layer_wise_reduce) -message SolverParameter { - ////////////////////////////////////////////////////////////////////////////// - // Specifying the train and test networks - // - // Exactly one train net must be specified using one of the following fields: - // train_net_param, train_net, net_param, net - // One or more test nets may be specified using any of the following fields: - // test_net_param, test_net, net_param, net - // If more than one test net field is specified (e.g., both net and - // test_net are specified), they will be evaluated in the field order given - // above: (1) test_net_param, (2) test_net, (3) net_param/net. - // A test_iter must be specified for each test_net. - // A test_level and/or a test_stage may also be specified for each test_net. - ////////////////////////////////////////////////////////////////////////////// - - // Proto filename for the train net, possibly combined with one or more - // test nets. - optional string net = 24; - // Inline train net param, possibly combined with one or more test nets. - optional NetParameter net_param = 25; - - optional string train_net = 1; // Proto filename for the train net. - repeated string test_net = 2; // Proto filenames for the test nets. - optional NetParameter train_net_param = 21; // Inline train net params. - repeated NetParameter test_net_param = 22; // Inline test net params. - - // The states for the train/test nets. Must be unspecified or - // specified once per net. - // - // By default, all states will have solver = true; - // train_state will have phase = TRAIN, - // and all test_state's will have phase = TEST. - // Other defaults are set according to the NetState defaults. - optional NetState train_state = 26; - repeated NetState test_state = 27; - - // The number of iterations for each test net. - repeated int32 test_iter = 3; - - // The number of iterations between two testing phases. - optional int32 test_interval = 4 [default = 0]; - optional bool test_compute_loss = 19 [default = false]; - // If true, run an initial test pass before the first iteration, - // ensuring memory availability and printing the starting value of the loss. - optional bool test_initialization = 32 [default = true]; - optional float base_lr = 5; // The base learning rate - // the number of iterations between displaying info. If display = 0, no info - // will be displayed. - optional int32 display = 6; - // Display the loss averaged over the last average_loss iterations - optional int32 average_loss = 33 [default = 1]; - optional int32 max_iter = 7; // the maximum number of iterations - // accumulate gradients over `iter_size` x `batch_size` instances - optional int32 iter_size = 36 [default = 1]; - - // The learning rate decay policy. The currently implemented learning rate - // policies are as follows: - // - fixed: always return base_lr. - // - step: return base_lr * gamma ^ (floor(iter / step)) - // - exp: return base_lr * gamma ^ iter - // - inv: return base_lr * (1 + gamma * iter) ^ (- power) - // - multistep: similar to step but it allows non uniform steps defined by - // stepvalue - // - poly: the effective learning rate follows a polynomial decay, to be - // zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power) - // - sigmoid: the effective learning rate follows a sigmod decay - // return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize)))) - // - // where base_lr, max_iter, gamma, step, stepvalue and power are defined - // in the solver parameter protocol buffer, and iter is the current iteration. - optional string lr_policy = 8; - optional float gamma = 9; // The parameter to compute the learning rate. - optional float power = 10; // The parameter to compute the learning rate. - optional float momentum = 11; // The momentum value. - optional float weight_decay = 12; // The weight decay. - // regularization types supported: L1 and L2 - // controlled by weight_decay - optional string regularization_type = 29 [default = "L2"]; - // the stepsize for learning rate policy "step" - optional int32 stepsize = 13; - // the stepsize for learning rate policy "multistep" - repeated int32 stepvalue = 34; - - // Set clip_gradients to >= 0 to clip parameter gradients to that L2 norm, - // whenever their actual L2 norm is larger. - optional float clip_gradients = 35 [default = -1]; - - optional int32 snapshot = 14 [default = 0]; // The snapshot interval - optional string snapshot_prefix = 15; // The prefix for the snapshot. - // whether to snapshot diff in the results or not. Snapshotting diff will help - // debugging but the final protocol buffer size will be much larger. - optional bool snapshot_diff = 16 [default = false]; - enum SnapshotFormat { - HDF5 = 0; - BINARYPROTO = 1; - } - optional SnapshotFormat snapshot_format = 37 [default = BINARYPROTO]; - // the mode solver will use: 0 for CPU and 1 for GPU. Use GPU in default. - enum SolverMode { - CPU = 0; - GPU = 1; - } - optional SolverMode solver_mode = 17 [default = GPU]; - // the device_id will that be used in GPU mode. Use device_id = 0 in default. - optional int32 device_id = 18 [default = 0]; - // If non-negative, the seed with which the Solver will initialize the Caffe - // random number generator -- useful for reproducible results. Otherwise, - // (and by default) initialize using a seed derived from the system clock. - optional int64 random_seed = 20 [default = -1]; - - // type of the solver - optional string type = 40 [default = "SGD"]; - - // numerical stability for RMSProp, AdaGrad and AdaDelta and Adam - optional float delta = 31 [default = 1e-8]; - // parameters for the Adam solver - optional float momentum2 = 39 [default = 0.999]; - - // RMSProp decay value - // MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t) - optional float rms_decay = 38 [default = 0.99]; - - // If true, print information about the state of the net that may help with - // debugging learning problems. - optional bool debug_info = 23 [default = false]; - - // If false, don't save a snapshot after training finishes. - optional bool snapshot_after_train = 28 [default = true]; - - // DEPRECATED: old solver enum types, use string instead - enum SolverType { - SGD = 0; - NESTEROV = 1; - ADAGRAD = 2; - RMSPROP = 3; - ADADELTA = 4; - ADAM = 5; - } - // DEPRECATED: use type instead of solver_type - optional SolverType solver_type = 30 [default = SGD]; - - // Overlap compute and communication for data parallel training - optional bool layer_wise_reduce = 41 [default = true]; -} - -// A message that stores the solver snapshots -message SolverState { - optional int32 iter = 1; // The current iteration - optional string learned_net = 2; // The file that stores the learned net. - repeated BlobProto history = 3; // The history for sgd solvers - optional int32 current_step = 4 [default = 0]; // The current step for learning rate -} - -enum Phase { - TRAIN = 0; - TEST = 1; -} - -message NetState { - optional Phase phase = 1 [default = TEST]; - optional int32 level = 2 [default = 0]; - repeated string stage = 3; -} - -message NetStateRule { - // Set phase to require the NetState have a particular phase (TRAIN or TEST) - // to meet this rule. - optional Phase phase = 1; - - // Set the minimum and/or maximum levels in which the layer should be used. - // Leave undefined to meet the rule regardless of level. - optional int32 min_level = 2; - optional int32 max_level = 3; - - // Customizable sets of stages to include or exclude. - // The net must have ALL of the specified stages and NONE of the specified - // "not_stage"s to meet the rule. - // (Use multiple NetStateRules to specify conjunctions of stages.) - repeated string stage = 4; - repeated string not_stage = 5; -} - -// Specifies training parameters (multipliers on global learning constants, -// and the name and other settings used for weight sharing). -message ParamSpec { - // The names of the parameter blobs -- useful for sharing parameters among - // layers, but never required otherwise. To share a parameter between two - // layers, give it a (non-empty) name. - optional string name = 1; - - // Whether to require shared weights to have the same shape, or just the same - // count -- defaults to STRICT if unspecified. - optional DimCheckMode share_mode = 2; - enum DimCheckMode { - // STRICT (default) requires that num, channels, height, width each match. - STRICT = 0; - // PERMISSIVE requires only the count (num*channels*height*width) to match. - PERMISSIVE = 1; - } - - // The multiplier on the global learning rate for this parameter. - optional float lr_mult = 3 [default = 1.0]; - - // The multiplier on the global weight decay for this parameter. - optional float decay_mult = 4 [default = 1.0]; -} - -// NOTE -// Update the next available ID when you add a new LayerParameter field. -// -// LayerParameter next available layer-specific ID: 151 (last added: smooth_l1_loss_param) -message LayerParameter { - optional string name = 1; // the layer name - optional string type = 2; // the layer type - repeated string bottom = 3; // the name of each bottom blob - repeated string top = 4; // the name of each top blob - - // The train / test phase for computation. - optional Phase phase = 10; - - // The amount of weight to assign each top blob in the objective. - // Each layer assigns a default value, usually of either 0 or 1, - // to each top blob. - repeated float loss_weight = 5; - - // Specifies training parameters (multipliers on global learning constants, - // and the name and other settings used for weight sharing). - repeated ParamSpec param = 6; - - // The blobs containing the numeric parameters of the layer. - repeated BlobProto blobs = 7; - - // Specifies whether to backpropagate to each bottom. If unspecified, - // Caffe will automatically infer whether each input needs backpropagation - // to compute parameter gradients. If set to true for some inputs, - // backpropagation to those inputs is forced; if set false for some inputs, - // backpropagation to those inputs is skipped. - // - // The size must be either 0 or equal to the number of bottoms. - repeated bool propagate_down = 11; - - // Rules controlling whether and when a layer is included in the network, - // based on the current NetState. You may specify a non-zero number of rules - // to include OR exclude, but not both. If no include or exclude rules are - // specified, the layer is always included. If the current NetState meets - // ANY (i.e., one or more) of the specified rules, the layer is - // included/excluded. - repeated NetStateRule include = 8; - repeated NetStateRule exclude = 9; - - // Parameters for data pre-processing. - optional TransformationParameter transform_param = 100; - - // Parameters shared by loss layers. - optional LossParameter loss_param = 101; - - // Layer type-specific parameters. - // - // Note: certain layers may have more than one computational engine - // for their implementation. These layers include an Engine type and - // engine parameter for selecting the implementation. - // The default for the engine is set by the ENGINE switch at compile-time. - optional AccuracyParameter accuracy_param = 102; - optional ArgMaxParameter argmax_param = 103; - optional BatchNormParameter batch_norm_param = 139; - optional BiasParameter bias_param = 141; - optional ConcatParameter concat_param = 104; - optional ContrastiveLossParameter contrastive_loss_param = 105; - optional ConvolutionParameter convolution_param = 106; - optional CropParameter crop_param = 144; - optional DataParameter data_param = 107; - optional DetectionOutputParameter detection_output_param = 150; - optional DropoutParameter dropout_param = 108; - optional DummyDataParameter dummy_data_param = 109; - optional EltwiseParameter eltwise_param = 110; - optional ELUParameter elu_param = 140; - optional EmbedParameter embed_param = 137; - optional ExpParameter exp_param = 111; - optional FlattenParameter flatten_param = 135; - optional HDF5DataParameter hdf5_data_param = 112; - optional HDF5OutputParameter hdf5_output_param = 113; - optional HingeLossParameter hinge_loss_param = 114; - optional ImageDataParameter image_data_param = 115; - optional InfogainLossParameter infogain_loss_param = 116; - optional InnerProductParameter inner_product_param = 117; - optional InputParameter input_param = 143; - optional LogParameter log_param = 134; - optional LRNParameter lrn_param = 118; - optional MemoryDataParameter memory_data_param = 119; - optional MVNParameter mvn_param = 120; - optional ParameterParameter parameter_param = 145; - optional PoolingParameter pooling_param = 121; - optional PowerParameter power_param = 122; - optional PReLUParameter prelu_param = 131; - optional PythonParameter python_param = 130; - optional RecurrentParameter recurrent_param = 146; - optional ReductionParameter reduction_param = 136; - optional ReLUParameter relu_param = 123; - optional ReshapeParameter reshape_param = 133; - optional ScaleParameter scale_param = 142; - optional SigmoidParameter sigmoid_param = 124; - optional SmoothL1LossParameter smooth_l1_loss_param = 148; - optional SoftmaxParameter softmax_param = 125; - optional SPPParameter spp_param = 132; - optional SliceParameter slice_param = 126; - optional TanHParameter tanh_param = 127; - optional ThresholdParameter threshold_param = 128; - optional TileParameter tile_param = 138; - optional WindowDataParameter window_data_param = 129; - optional PermuteParameter permute_param = 202; - optional PriorBoxParameter prior_box_param = 203; - optional NormalizeParameter norm_param = 206; - optional PSROIPoolingParameter psroi_pooling_param = 207; - optional FreespaceExtractParameter freespace_extract_param = 151; - optional PostprocessParameter postprocess_param = 152; - optional SpatialTransformParameter spatial_transform_param = 153; - optional ROIAlignParameter roi_align_param = 154; - optional ReorgParameter reorg_param = 155; - optional RegionParameter region_param = 156; - optional ReverseParameter reverse_param = 157; - optional InterpParameter interp_param = 158; - optional ShuffleChannelParameter shuffle_channel_param = 159; - optional UpsampleParameter upsample_param = 160; - optional ROIPoolingParameter roi_pooling_param = 161; - optional YoloParameter yolo_param = 199; - optional YoloV3DetectionOutputParameter yolov3_detection_output_param = 200; - optional ProposalParameter proposal_param = 201; - optional FSRDetectionOutputParameter fsrdetectionoutput_param = 222; - optional SSDDetectionOutputParameter ssddetectionoutput_param = 232; - optional YoloV2DetectionOutputParameter yolov2_detection_output_param = 204; - optional QuantParameter quant_param = 208; - optional CondTakeParameter condtake_param = 233; - optional MatrixInverseParameter matrix_inverse_param = 210; - optional WarpPerspectiveParameter warp_perspective_param = 234; - optional BatchMatMulParameter batch_matmul_param = 235; - optional SpatialTransformerParameter st_param = 5000; - optional YoloV3DetectionOutputV2Parameter yolov3_detection_output_v2_param = 5001; - optional ContinuationIndicatorParameter continuation_indicator_param = 5002; -} - -// Message that stores parameters used to apply transformation -// to the data layer's data -message TransformationParameter { - // For data pre-processing, we can do simple scaling and subtracting the - // data mean, if provided. Note that the mean subtraction is always carried - // out before scaling. - optional float scale = 1 [default = 1]; - // Specify if we want to randomly mirror data. - optional bool mirror = 2 [default = false]; - // Specify if we would like to randomly crop an image. - optional uint32 crop_size = 3 [default = 0]; - // mean_file and mean_value cannot be specified at the same time - optional string mean_file = 4; - // if specified can be repeated once (would substract it from all the channels) - // or can be repeated the same number of times as channels - // (would subtract them from the corresponding channel) - repeated float mean_value = 5; - // Force the decoded image to have 3 color channels. - optional bool force_color = 6 [default = false]; - // Force the decoded image to have 1 color channels. - optional bool force_gray = 7 [default = false]; -} - -// Message that stores parameters shared by loss layers -message LossParameter { - // If specified, ignore instances with the given label. - optional int32 ignore_label = 1; - // How to normalize the loss for loss layers that aggregate across batches, - // spatial dimensions, or other dimensions. Currently only implemented in - // SoftmaxWithLoss and SigmoidCrossEntropyLoss layers. - enum NormalizationMode { - // Divide by the number of examples in the batch times spatial dimensions. - // Outputs that receive the ignore label will NOT be ignored in computing - // the normalization factor. - FULL = 0; - // Divide by the total number of output locations that do not take the - // ignore_label. If ignore_label is not set, this behaves like FULL. - VALID = 1; - // Divide by the batch size. - BATCH_SIZE = 2; - // Do not normalize the loss. - NONE = 3; - } - // For historical reasons, the default normalization for - // SigmoidCrossEntropyLoss is BATCH_SIZE and *not* VALID. - optional NormalizationMode normalization = 3 [default = VALID]; - // Deprecated. Ignored if normalization is specified. If normalization - // is not specified, then setting this to false will be equivalent to - // normalization = BATCH_SIZE to be consistent with previous behavior. - optional bool normalize = 2; -} - -// Messages that store parameters used by individual layer types follow, in -// alphabetical order. - -message AccuracyParameter { - // When computing accuracy, count as correct by comparing the true label to - // the top k scoring classes. By default, only compare to the top scoring - // class (i.e. argmax). - optional uint32 top_k = 1 [default = 1]; - - // The "label" axis of the prediction blob, whose argmax corresponds to the - // predicted label -- may be negative to index from the end (e.g., -1 for the - // last axis). For example, if axis == 1 and the predictions are - // (N x C x H x W), the label blob is expected to contain N*H*W ground truth - // labels with integer values in {0, 1, ..., C-1}. - optional int32 axis = 2 [default = 1]; - - // If specified, ignore instances with the given label. - optional int32 ignore_label = 3; -} - -message ArgMaxParameter { - // If true produce pairs (argmax, maxval) - optional bool out_max_val = 1 [default = false]; - optional uint32 top_k = 2 [default = 1]; - // The axis along which to maximise -- may be negative to index from the - // end (e.g., -1 for the last axis). - // By default ArgMaxLayer maximizes over the flattened trailing dimensions - // for each index of the first / num dimension. - optional int32 axis = 3; -} - -message ConcatParameter { - // The axis along which to concatenate -- may be negative to index from the - // end (e.g., -1 for the last axis). Other axes must have the - // same dimension for all the bottom blobs. - // By default, ConcatLayer concatenates blobs along the "channels" axis (1). - optional int32 axis = 2 [default = 1]; - - // DEPRECATED: alias for "axis" -- does not support negative indexing. - optional uint32 concat_dim = 1 [default = 1]; -} - -message BatchNormParameter { - // If false, normalization is performed over the current mini-batch - // and global statistics are accumulated (but not yet used) by a moving - // average. - // If true, those accumulated mean and variance values are used for the - // normalization. - // By default, it is set to false when the network is in the training - // phase and true when the network is in the testing phase. - optional bool use_global_stats = 1; - // What fraction of the moving average remains each iteration? - // Smaller values make the moving average decay faster, giving more - // weight to the recent values. - // Each iteration updates the moving average @f$S_{t-1}@f$ with the - // current mean @f$ Y_t @f$ by - // @f$ S_t = (1-\beta)Y_t + \beta \cdot S_{t-1} @f$, where @f$ \beta @f$ - // is the moving_average_fraction parameter. - optional float moving_average_fraction = 2 [default = .999]; - // Small value to add to the variance estimate so that we don't divide by - // zero. - optional float eps = 3 [default = 1e-5]; -} - -message BiasParameter { - // The first axis of bottom[0] (the first input Blob) along which to apply - // bottom[1] (the second input Blob). May be negative to index from the end - // (e.g., -1 for the last axis). - // - // For example, if bottom[0] is 4D with shape 100x3x40x60, the output - // top[0] will have the same shape, and bottom[1] may have any of the - // following shapes (for the given value of axis): - // (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60 - // (axis == 1 == -3) 3; 3x40; 3x40x60 - // (axis == 2 == -2) 40; 40x60 - // (axis == 3 == -1) 60 - // Furthermore, bottom[1] may have the empty shape (regardless of the value of - // "axis") -- a scalar bias. - optional int32 axis = 1 [default = 1]; - - // (num_axes is ignored unless just one bottom is given and the bias is - // a learned parameter of the layer. Otherwise, num_axes is determined by the - // number of axes by the second bottom.) - // The number of axes of the input (bottom[0]) covered by the bias - // parameter, or -1 to cover all axes of bottom[0] starting from `axis`. - // Set num_axes := 0, to add a zero-axis Blob: a scalar. - optional int32 num_axes = 2 [default = 1]; - - // (filler is ignored unless just one bottom is given and the bias is - // a learned parameter of the layer.) - // The initialization for the learned bias parameter. - // Default is the zero (0) initialization, resulting in the BiasLayer - // initially performing the identity operation. - optional FillerParameter filler = 3; - optional bool bias_from_blob = 4 [default = true]; -} - -message ContrastiveLossParameter { - // margin for dissimilar pair - optional float margin = 1 [default = 1.0]; - // The first implementation of this cost did not exactly match the cost of - // Hadsell et al 2006 -- using (margin - d^2) instead of (margin - d)^2. - // legacy_version = false (the default) uses (margin - d)^2 as proposed in the - // Hadsell paper. New models should probably use this version. - // legacy_version = true uses (margin - d^2). This is kept to support / - // reproduce existing models and results - optional bool legacy_version = 2 [default = false]; -} - -message ConvolutionParameter { - optional uint32 num_output = 1; // The number of outputs for the layer - optional bool bias_term = 2 [default = true]; // whether to have bias terms - - // Pad, kernel size, and stride are all given as a single value for equal - // dimensions in all spatial dimensions, or once per spatial dimension. - repeated uint32 pad = 3; // The padding size; defaults to 0 - repeated uint32 kernel_size = 4; // The kernel size - repeated uint32 stride = 6; // The stride; defaults to 1 - // Factor used to dilate the kernel, (implicitly) zero-filling the resulting - // holes. (Kernel dilation is sometimes referred to by its use in the - // algorithme à trous from Holschneider et al. 1987.) - repeated uint32 dilation = 18; // The dilation; defaults to 1 - - // For 2D convolution only, the *_h and *_w versions may also be used to - // specify both spatial dimensions. - optional uint32 pad_h = 9 [default = 0]; // The padding height (2D only) - optional uint32 pad_w = 10 [default = 0]; // The padding width (2D only) - optional uint32 kernel_h = 11; // The kernel height (2D only) - optional uint32 kernel_w = 12; // The kernel width (2D only) - optional uint32 stride_h = 13; // The stride height (2D only) - optional uint32 stride_w = 14; // The stride width (2D only) - - optional uint32 group = 5 [default = 1]; // The group size for group conv - - optional FillerParameter weight_filler = 7; // The filler for the weight - optional FillerParameter bias_filler = 8; // The filler for the bias - enum Engine { - DEFAULT = 0; - CAFFE = 1; - CUDNN = 2; - } - optional Engine engine = 15 [default = DEFAULT]; - - // The axis to interpret as "channels" when performing convolution. - // Preceding dimensions are treated as independent inputs; - // succeeding dimensions are treated as "spatial". - // With (N, C, H, W) inputs, and axis == 1 (the default), we perform - // N independent 2D convolutions, sliding C-channel (or (C/g)-channels, for - // groups g>1) filters across the spatial axes (H, W) of the input. - // With (N, C, D, H, W) inputs, and axis == 1, we perform - // N independent 3D convolutions, sliding (C/g)-channels - // filters across the spatial axes (D, H, W) of the input. - optional int32 axis = 16 [default = 1]; - - // Whether to force use of the general ND convolution, even if a specific - // implementation for blobs of the appropriate number of spatial dimensions - // is available. (Currently, there is only a 2D-specific convolution - // implementation; for input blobs with num_axes != 2, this option is - // ignored and the ND implementation will be used.) - optional bool force_nd_im2col = 17 [default = false]; -} - -message CropParameter { - // To crop, elements of the first bottom are selected to fit the dimensions - // of the second, reference bottom. The crop is configured by - // - the crop `axis` to pick the dimensions for cropping - // - the crop `offset` to set the shift for all/each dimension - // to align the cropped bottom with the reference bottom. - // All dimensions up to but excluding `axis` are preserved, while - // the dimensions including and trailing `axis` are cropped. - // If only one `offset` is set, then all dimensions are offset by this amount. - // Otherwise, the number of offsets must equal the number of cropped axes to - // shift the crop in each dimension accordingly. - // Note: standard dimensions are N,C,H,W so the default is a spatial crop, - // and `axis` may be negative to index from the end (e.g., -1 for the last - // axis). - optional int32 axis = 1 [default = 2]; - repeated uint32 offset = 2; -} - -message DataParameter { - enum DB { - LEVELDB = 0; - LMDB = 1; - } - // Specify the data source. - optional string source = 1; - // Specify the batch size. - optional uint32 batch_size = 4; - // The rand_skip variable is for the data layer to skip a few data points - // to avoid all asynchronous sgd clients to start at the same point. The skip - // point would be set as rand_skip * rand(0,1). Note that rand_skip should not - // be larger than the number of keys in the database. - // DEPRECATED. Each solver accesses a different subset of the database. - optional uint32 rand_skip = 7 [default = 0]; - optional DB backend = 8 [default = LEVELDB]; - // DEPRECATED. See TransformationParameter. For data pre-processing, we can do - // simple scaling and subtracting the data mean, if provided. Note that the - // mean subtraction is always carried out before scaling. - optional float scale = 2 [default = 1]; - optional string mean_file = 3; - // DEPRECATED. See TransformationParameter. Specify if we would like to randomly - // crop an image. - optional uint32 crop_size = 5 [default = 0]; - // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror - // data. - optional bool mirror = 6 [default = false]; - // Force the encoded image to have 3 color channels - optional bool force_encoded_color = 9 [default = false]; - // Prefetch queue (Increase if data feeding bandwidth varies, within the - // limit of device memory for GPU training) - optional uint32 prefetch = 10 [default = 4]; -} - -message DropoutParameter { - optional float dropout_ratio = 1 [default = 0.5]; // dropout ratio - optional bool scale_train = 2 [default = true]; // scale train or test phase -} - -// DummyDataLayer fills any number of arbitrarily shaped blobs with random -// (or constant) data generated by "Fillers" (see "message FillerParameter"). -message DummyDataParameter { - // This layer produces N >= 1 top blobs. DummyDataParameter must specify 1 or N - // shape fields, and 0, 1 or N data_fillers. - // - // If 0 data_fillers are specified, ConstantFiller with a value of 0 is used. - // If 1 data_filler is specified, it is applied to all top blobs. If N are - // specified, the ith is applied to the ith top blob. - repeated FillerParameter data_filler = 1; - repeated BlobShape shape = 6; - - // 4D dimensions -- deprecated. Use "shape" instead. - repeated uint32 num = 2; - repeated uint32 channels = 3; - repeated uint32 height = 4; - repeated uint32 width = 5; -} - -message EltwiseParameter { - enum EltwiseOp { - PROD = 0; - SUM = 1; - MAX = 2; - } - optional EltwiseOp operation = 1 [default = SUM]; // element-wise operation - repeated float coeff = 2; // blob-wise coefficient for SUM operation - - // Whether to use an asymptotically slower (for >2 inputs) but stabler method - // of computing the gradient for the PROD operation. (No effect for SUM op.) - optional bool stable_prod_grad = 3 [default = true]; -} - -// Message that stores parameters used by ELULayer -message ELUParameter { - // Described in: - // Clevert, D.-A., Unterthiner, T., & Hochreiter, S. (2015). Fast and Accurate - // Deep Network Learning by Exponential Linear Units (ELUs). arXiv - optional float alpha = 1 [default = 1]; -} - -// Message that stores parameters used by EmbedLayer -message EmbedParameter { - optional uint32 num_output = 1; // The number of outputs for the layer - // The input is given as integers to be interpreted as one-hot - // vector indices with dimension num_input. Hence num_input should be - // 1 greater than the maximum possible input value. - optional uint32 input_dim = 2; - - optional bool bias_term = 3 [default = true]; // Whether to use a bias term - optional FillerParameter weight_filler = 4; // The filler for the weight - optional FillerParameter bias_filler = 5; // The filler for the bias - -} - -// Message that stores parameters used by ExpLayer -message ExpParameter { - // ExpLayer computes outputs y = base ^ (shift + scale * x), for base > 0. - // Or if base is set to the default (-1), base is set to e, - // so y = exp(shift + scale * x). - optional float base = 1 [default = -1.0]; - optional float scale = 2 [default = 1.0]; - optional float shift = 3 [default = 0.0]; -} - -/// Message that stores parameters used by FlattenLayer -message FlattenParameter { - // The first axis to flatten: all preceding axes are retained in the output. - // May be negative to index from the end (e.g., -1 for the last axis). - optional int32 axis = 1 [default = 1]; - - // The last axis to flatten: all following axes are retained in the output. - // May be negative to index from the end (e.g., the default -1 for the last - // axis). - optional int32 end_axis = 2 [default = -1]; -} - -// Message that stores parameters used by HDF5DataLayer -message HDF5DataParameter { - // Specify the data source. - optional string source = 1; - // Specify the batch size. - optional uint32 batch_size = 2; - - // Specify whether to shuffle the data. - // If shuffle == true, the ordering of the HDF5 files is shuffled, - // and the ordering of data within any given HDF5 file is shuffled, - // but data between different files are not interleaved; all of a file's - // data are output (in a random order) before moving onto another file. - optional bool shuffle = 3 [default = false]; -} - -message HDF5OutputParameter { - optional string file_name = 1; -} - -message HingeLossParameter { - enum Norm { - L1 = 1; - L2 = 2; - } - // Specify the Norm to use L1 or L2 - optional Norm norm = 1 [default = L1]; -} - -message ImageDataParameter { - // Specify the data source. - optional string source = 1; - // Specify the batch size. - optional uint32 batch_size = 4 [default = 1]; - // The rand_skip variable is for the data layer to skip a few data points - // to avoid all asynchronous sgd clients to start at the same point. The skip - // point would be set as rand_skip * rand(0,1). Note that rand_skip should not - // be larger than the number of keys in the database. - optional uint32 rand_skip = 7 [default = 0]; - // Whether or not ImageLayer should shuffle the list of files at every epoch. - optional bool shuffle = 8 [default = false]; - // It will also resize images if new_height or new_width are not zero. - optional uint32 new_height = 9 [default = 0]; - optional uint32 new_width = 10 [default = 0]; - // Specify if the images are color or gray - optional bool is_color = 11 [default = true]; - // DEPRECATED. See TransformationParameter. For data pre-processing, we can do - // simple scaling and subtracting the data mean, if provided. Note that the - // mean subtraction is always carried out before scaling. - optional float scale = 2 [default = 1]; - optional string mean_file = 3; - // DEPRECATED. See TransformationParameter. Specify if we would like to randomly - // crop an image. - optional uint32 crop_size = 5 [default = 0]; - // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror - // data. - optional bool mirror = 6 [default = false]; - optional string root_folder = 12 [default = ""]; -} - -message InfogainLossParameter { - // Specify the infogain matrix source. - optional string source = 1; - optional int32 axis = 2 [default = 1]; // axis of prob -} - -message InnerProductParameter { - optional uint32 num_output = 1; // The number of outputs for the layer - optional bool bias_term = 2 [default = true]; // whether to have bias terms - optional FillerParameter weight_filler = 3; // The filler for the weight - optional FillerParameter bias_filler = 4; // The filler for the bias - - // The first axis to be lumped into a single inner product computation; - // all preceding axes are retained in the output. - // May be negative to index from the end (e.g., -1 for the last axis). - optional int32 axis = 5 [default = 1]; - // Specify whether to transpose the weight matrix or not. - // If transpose == true, any operations will be performed on the transpose - // of the weight matrix. The weight matrix itself is not going to be transposed - // but rather the transfer flag of operations will be toggled accordingly. - optional bool transpose = 6 [default = false]; -} - -message InputParameter { - // This layer produces N >= 1 top blob(s) to be assigned manually. - // Define N shapes to set a shape for each top. - // Define 1 shape to set the same shape for every top. - // Define no shape to defer to reshaping manually. - repeated BlobShape shape = 1; -} - -// Message that stores parameters used by LogLayer -message LogParameter { - // LogLayer computes outputs y = log_base(shift + scale * x), for base > 0. - // Or if base is set to the default (-1), base is set to e, - // so y = ln(shift + scale * x) = log_e(shift + scale * x) - optional float base = 1 [default = -1.0]; - optional float scale = 2 [default = 1.0]; - optional float shift = 3 [default = 0.0]; -} - -// Message that stores parameters used by LRNLayer -message LRNParameter { - optional uint32 local_size = 1 [default = 5]; - optional float alpha = 2 [default = 1.]; - optional float beta = 3 [default = 0.75]; - enum NormRegion { - ACROSS_CHANNELS = 0; - WITHIN_CHANNEL = 1; - } - optional NormRegion norm_region = 4 [default = ACROSS_CHANNELS]; - optional float k = 5 [default = 1.]; - enum Engine { - DEFAULT = 0; - CAFFE = 1; - CUDNN = 2; - } - optional Engine engine = 6 [default = DEFAULT]; -} - -message MemoryDataParameter { - optional uint32 batch_size = 1; - optional uint32 channels = 2; - optional uint32 height = 3; - optional uint32 width = 4; -} - -message MVNParameter { - // This parameter can be set to false to normalize mean only - optional bool normalize_variance = 1 [default = true]; - - // This parameter can be set to true to perform DNN-like MVN - optional bool across_channels = 2 [default = false]; - - // Epsilon for not dividing by zero while normalizing variance - optional float eps = 3 [default = 1e-9]; -} - -message ParameterParameter { - optional BlobShape shape = 1; -} - -message PoolingParameter { - enum PoolMethod { - MAX = 0; - AVE = 1; - STOCHASTIC = 2; - } - optional PoolMethod pool = 1 [default = MAX]; // The pooling method - // Pad, kernel size, and stride are all given as a single value for equal - // dimensions in height and width or as Y, X pairs. - optional uint32 pad = 4 [default = 0]; // The padding size (equal in Y, X) - optional uint32 pad_h = 9 [default = 0]; // The padding height - optional uint32 pad_w = 10 [default = 0]; // The padding width - optional uint32 kernel_size = 2; // The kernel size (square) - optional uint32 kernel_h = 5; // The kernel height - optional uint32 kernel_w = 6; // The kernel width - optional uint32 stride = 3 [default = 1]; // The stride (equal in Y, X) - optional uint32 stride_h = 7; // The stride height - optional uint32 stride_w = 8; // The stride width - enum Engine { - DEFAULT = 0; - CAFFE = 1; - CUDNN = 2; - } - optional Engine engine = 11 [default = DEFAULT]; - // If global_pooling then it will pool over the size of the bottom by doing - // kernel_h = bottom->height and kernel_w = bottom->width - optional bool global_pooling = 12 [default = false]; - optional bool ceil_mode = 13 [default = true]; - // How to calculate the output size - using ceil (default) or floor rounding. - enum RoundMode { - CEIL = 0; - FLOOR = 1; - } - optional RoundMode round_mode = 14 [default = CEIL]; -} - -message PowerParameter { - // PowerLayer computes outputs y = (shift + scale * x) ^ power. - optional float power = 1 [default = 1.0]; - optional float scale = 2 [default = 1.0]; - optional float shift = 3 [default = 0.0]; -} - -message PythonParameter { - optional string module = 1; - optional string layer = 2; - // This value is set to the attribute `param_str` of the `PythonLayer` object - // in Python before calling the `setup()` method. This could be a number, - // string, dictionary in Python dict format, JSON, etc. You may parse this - // string in `setup` method and use it in `forward` and `backward`. - optional string param_str = 3 [default = '']; - // Whether this PythonLayer is shared among worker solvers during data parallelism. - // If true, each worker solver sequentially run forward from this layer. - // This value should be set true if you are using it as a data layer. - optional bool share_in_parallel = 4 [default = false]; -} - -// Message that stores parameters used by RecurrentLayer -message RecurrentParameter { - // The dimension of the output (and usually hidden state) representation -- - // must be explicitly set to non-zero. - optional uint32 num_output = 1 [default = 0]; - - optional FillerParameter weight_filler = 2; // The filler for the weight - optional FillerParameter bias_filler = 3; // The filler for the bias - - // Whether to enable displaying debug_info in the unrolled recurrent net. - optional bool debug_info = 4 [default = false]; - - // Whether to add as additional inputs (bottoms) the initial hidden state - // blobs, and add as additional outputs (tops) the final timestep hidden state - // blobs. The number of additional bottom/top blobs required depends on the - // recurrent architecture -- e.g., 1 for RNNs, 2 for LSTMs. - optional bool expose_hidden = 5 [default = false]; -} - -// Message that stores parameters used by ReductionLayer -message ReductionParameter { - enum ReductionOp { - SUM = 1; - ASUM = 2; - SUMSQ = 3; - MEAN = 4; - } - - optional ReductionOp operation = 1 [default = SUM]; // reduction operation - - // The first axis to reduce to a scalar -- may be negative to index from the - // end (e.g., -1 for the last axis). - // (Currently, only reduction along ALL "tail" axes is supported; reduction - // of axis M through N, where N < num_axes - 1, is unsupported.) - // Suppose we have an n-axis bottom Blob with shape: - // (d0, d1, d2, ..., d(m-1), dm, d(m+1), ..., d(n-1)). - // If axis == m, the output Blob will have shape - // (d0, d1, d2, ..., d(m-1)), - // and the ReductionOp operation is performed (d0 * d1 * d2 * ... * d(m-1)) - // times, each including (dm * d(m+1) * ... * d(n-1)) individual data. - // If axis == 0 (the default), the output Blob always has the empty shape - // (count 1), performing reduction across the entire input -- - // often useful for creating new loss functions. - optional int32 axis = 2 [default = 0]; - - optional float coeff = 3 [default = 1.0]; // coefficient for output -} - -// Message that stores parameters used by ReLULayer -message ReLUParameter { - // Allow non-zero slope for negative inputs to speed up optimization - // Described in: - // Maas, A. L., Hannun, A. Y., & Ng, A. Y. (2013). Rectifier nonlinearities - // improve neural network acoustic models. In ICML Workshop on Deep Learning - // for Audio, Speech, and Language Processing. - optional float negative_slope = 1 [default = 0]; - enum Engine { - DEFAULT = 0; - CAFFE = 1; - CUDNN = 2; - } - optional Engine engine = 2 [default = DEFAULT]; -} - -message ReshapeParameter { - // Specify the output dimensions. If some of the dimensions are set to 0, - // the corresponding dimension from the bottom layer is used (unchanged). - // Exactly one dimension may be set to -1, in which case its value is - // inferred from the count of the bottom blob and the remaining dimensions. - // For example, suppose we want to reshape a 2D blob "input" with shape 2 x 8: - // - // layer { - // type: "Reshape" bottom: "input" top: "output" - // reshape_param { ... } - // } - // - // If "input" is 2D with shape 2 x 8, then the following reshape_param - // specifications are all equivalent, producing a 3D blob "output" with shape - // 2 x 2 x 4: - // - // reshape_param { shape { dim: 2 dim: 2 dim: 4 } } - // reshape_param { shape { dim: 0 dim: 2 dim: 4 } } - // reshape_param { shape { dim: 0 dim: 2 dim: -1 } } - // reshape_param { shape { dim: 0 dim:-1 dim: 4 } } - // - optional BlobShape shape = 1; - - // axis and num_axes control the portion of the bottom blob's shape that are - // replaced by (included in) the reshape. By default (axis == 0 and - // num_axes == -1), the entire bottom blob shape is included in the reshape, - // and hence the shape field must specify the entire output shape. - // - // axis may be non-zero to retain some portion of the beginning of the input - // shape (and may be negative to index from the end; e.g., -1 to begin the - // reshape after the last axis, including nothing in the reshape, - // -2 to include only the last axis, etc.). - // - // For example, suppose "input" is a 2D blob with shape 2 x 8. - // Then the following ReshapeLayer specifications are all equivalent, - // producing a blob "output" with shape 2 x 2 x 4: - // - // reshape_param { shape { dim: 2 dim: 2 dim: 4 } } - // reshape_param { shape { dim: 2 dim: 4 } axis: 1 } - // reshape_param { shape { dim: 2 dim: 4 } axis: -3 } - // - // num_axes specifies the extent of the reshape. - // If num_axes >= 0 (and axis >= 0), the reshape will be performed only on - // input axes in the range [axis, axis+num_axes]. - // num_axes may also be -1, the default, to include all remaining axes - // (starting from axis). - // - // For example, suppose "input" is a 2D blob with shape 2 x 8. - // Then the following ReshapeLayer specifications are equivalent, - // producing a blob "output" with shape 1 x 2 x 8. - // - // reshape_param { shape { dim: 1 dim: 2 dim: 8 } } - // reshape_param { shape { dim: 1 dim: 2 } num_axes: 1 } - // reshape_param { shape { dim: 1 } num_axes: 0 } - // - // On the other hand, these would produce output blob shape 2 x 1 x 8: - // - // reshape_param { shape { dim: 2 dim: 1 dim: 8 } } - // reshape_param { shape { dim: 1 } axis: 1 num_axes: 0 } - // - optional int32 axis = 2 [default = 0]; - optional int32 num_axes = 3 [default = -1]; -} - - -message ScaleParameter { - // The first axis of bottom[0] (the first input Blob) along which to apply - // bottom[1] (the second input Blob). May be negative to index from the end - // (e.g., -1 for the last axis). - // - // For example, if bottom[0] is 4D with shape 100x3x40x60, the output - // top[0] will have the same shape, and bottom[1] may have any of the - // following shapes (for the given value of axis): - // (axis == 0 == -4) 100; 100x3; 100x3x40; 100x3x40x60 - // (axis == 1 == -3) 3; 3x40; 3x40x60 - // (axis == 2 == -2) 40; 40x60 - // (axis == 3 == -1) 60 - // Furthermore, bottom[1] may have the empty shape (regardless of the value of - // "axis") -- a scalar multiplier. - optional int32 axis = 1 [default = 1]; - - // (num_axes is ignored unless just one bottom is given and the scale is - // a learned parameter of the layer. Otherwise, num_axes is determined by the - // number of axes by the second bottom.) - // The number of axes of the input (bottom[0]) covered by the scale - // parameter, or -1 to cover all axes of bottom[0] starting from `axis`. - // Set num_axes := 0, to multiply with a zero-axis Blob: a scalar. - optional int32 num_axes = 2 [default = 1]; - - // (filler is ignored unless just one bottom is given and the scale is - // a learned parameter of the layer.) - // The initialization for the learned scale parameter. - // Default is the unit (1) initialization, resulting in the ScaleLayer - // initially performing the identity operation. - optional FillerParameter filler = 3; - - // Whether to also learn a bias (equivalent to a ScaleLayer+BiasLayer, but - // may be more efficient). Initialized with bias_filler (defaults to 0). - optional bool bias_term = 4 [default = false]; - optional FillerParameter bias_filler = 5; - optional bool scale_from_blob = 6 [default = true]; -} - -message SigmoidParameter { - enum Engine { - DEFAULT = 0; - CAFFE = 1; - CUDNN = 2; - } - optional Engine engine = 1 [default = DEFAULT]; -} - -message SliceParameter { - // The axis along which to slice -- may be negative to index from the end - // (e.g., -1 for the last axis). - // By default, SliceLayer concatenates blobs along the "channels" axis (1). - optional int32 axis = 3 [default = 1]; - repeated uint32 slice_point = 2; - - // DEPRECATED: alias for "axis" -- does not support negative indexing. - optional uint32 slice_dim = 1 [default = 1]; -} - -message SmoothL1LossParameter { - // SmoothL1Loss(x) = - // 0.5 * (sigma * x) ** 2 -- if x < 1.0 / sigma / sigma - // |x| - 0.5 / sigma / sigma -- otherwise - optional float sigma = 1 [default = 1]; -} - -// Message that stores parameters used by SoftmaxLayer, SoftmaxWithLossLayer -message SoftmaxParameter { - enum Engine { - DEFAULT = 0; - CAFFE = 1; - CUDNN = 2; - } - optional Engine engine = 1 [default = DEFAULT]; - - // The axis along which to perform the softmax -- may be negative to index - // from the end (e.g., -1 for the last axis). - // Any other axes will be evaluated as independent softmaxes. - optional int32 axis = 2 [default = 1]; -} - -message TanHParameter { - enum Engine { - DEFAULT = 0; - CAFFE = 1; - CUDNN = 2; - } - optional Engine engine = 1 [default = DEFAULT]; -} - -// Message that stores parameters used by TileLayer -message TileParameter { - // The index of the axis to tile. - optional int32 axis = 1 [default = 1]; - - // The number of copies (tiles) of the blob to output. - optional int32 tiles = 2; -} - -// Message that stores parameters used by ThresholdLayer -message ThresholdParameter { - optional float threshold = 1 [default = 0]; // Strictly positive values -} - -message WindowDataParameter { - // Specify the data source. - optional string source = 1; - // For data pre-processing, we can do simple scaling and subtracting the - // data mean, if provided. Note that the mean subtraction is always carried - // out before scaling. - optional float scale = 2 [default = 1]; - optional string mean_file = 3; - // Specify the batch size. - optional uint32 batch_size = 4; - // Specify if we would like to randomly crop an image. - optional uint32 crop_size = 5 [default = 0]; - // Specify if we want to randomly mirror data. - optional bool mirror = 6 [default = false]; - // Foreground (object) overlap threshold - optional float fg_threshold = 7 [default = 0.5]; - // Background (non-object) overlap threshold - optional float bg_threshold = 8 [default = 0.5]; - // Fraction of batch that should be foreground objects - optional float fg_fraction = 9 [default = 0.25]; - // Amount of contextual padding to add around a window - // (used only by the window_data_layer) - optional uint32 context_pad = 10 [default = 0]; - // Mode for cropping out a detection window - // warp: cropped window is warped to a fixed size and aspect ratio - // square: the tightest square around the window is cropped - optional string crop_mode = 11 [default = "warp"]; - // cache_images: will load all images in memory for faster access - optional bool cache_images = 12 [default = false]; - // append root_folder to locate images - optional string root_folder = 13 [default = ""]; -} - -message SPPParameter { - enum PoolMethod { - MAX = 0; - AVE = 1; - STOCHASTIC = 2; - } - optional uint32 pyramid_height = 1; - optional PoolMethod pool = 2 [default = MAX]; // The pooling method - enum Engine { - DEFAULT = 0; - CAFFE = 1; - CUDNN = 2; - } - optional Engine engine = 6 [default = DEFAULT]; -} - -// DEPRECATED: use LayerParameter. -message V1LayerParameter { - repeated string bottom = 2; - repeated string top = 3; - optional string name = 4; - repeated NetStateRule include = 32; - repeated NetStateRule exclude = 33; - enum LayerType { - NONE = 0; - ABSVAL = 35; - ACCURACY = 1; - ARGMAX = 30; - BNLL = 2; - CONCAT = 3; - CONTRASTIVE_LOSS = 37; - CONVOLUTION = 4; - DATA = 5; - DECONVOLUTION = 39; - DROPOUT = 6; - DUMMY_DATA = 32; - EUCLIDEAN_LOSS = 7; - ELTWISE = 25; - EXP = 38; - FLATTEN = 8; - HDF5_DATA = 9; - HDF5_OUTPUT = 10; - HINGE_LOSS = 28; - IM2COL = 11; - IMAGE_DATA = 12; - INFOGAIN_LOSS = 13; - INNER_PRODUCT = 14; - LRN = 15; - MEMORY_DATA = 29; - MULTINOMIAL_LOGISTIC_LOSS = 16; - MVN = 34; - POOLING = 17; - POWER = 26; - RELU = 18; - SIGMOID = 19; - SIGMOID_CROSS_ENTROPY_LOSS = 27; - SILENCE = 36; - SOFTMAX = 20; - SOFTMAX_LOSS = 21; - SPLIT = 22; - SLICE = 33; - TANH = 23; - WINDOW_DATA = 24; - THRESHOLD = 31; - QUANT = 208; - DEQUANT = 209; - } - optional LayerType type = 5; - repeated BlobProto blobs = 6; - repeated string param = 1001; - repeated DimCheckMode blob_share_mode = 1002; - enum DimCheckMode { - STRICT = 0; - PERMISSIVE = 1; - } - repeated float blobs_lr = 7; - repeated float weight_decay = 8; - repeated float loss_weight = 35; - optional AccuracyParameter accuracy_param = 27; - optional ArgMaxParameter argmax_param = 23; - optional ConcatParameter concat_param = 9; - optional ContrastiveLossParameter contrastive_loss_param = 40; - optional ConvolutionParameter convolution_param = 10; - optional DataParameter data_param = 11; - optional DropoutParameter dropout_param = 12; - optional DummyDataParameter dummy_data_param = 26; - optional EltwiseParameter eltwise_param = 24; - optional ExpParameter exp_param = 41; - optional HDF5DataParameter hdf5_data_param = 13; - optional HDF5OutputParameter hdf5_output_param = 14; - optional HingeLossParameter hinge_loss_param = 29; - optional ImageDataParameter image_data_param = 15; - optional InfogainLossParameter infogain_loss_param = 16; - optional InnerProductParameter inner_product_param = 17; - optional LRNParameter lrn_param = 18; - optional MemoryDataParameter memory_data_param = 22; - optional MVNParameter mvn_param = 34; - optional PoolingParameter pooling_param = 19; - optional PowerParameter power_param = 21; - optional ReLUParameter relu_param = 30; - optional SigmoidParameter sigmoid_param = 38; - optional SoftmaxParameter softmax_param = 39; - optional SliceParameter slice_param = 31; - optional TanHParameter tanh_param = 37; - optional ThresholdParameter threshold_param = 25; - optional WindowDataParameter window_data_param = 20; - optional TransformationParameter transform_param = 36; - optional LossParameter loss_param = 42; - optional V0LayerParameter layer = 1; -} - -// DEPRECATED: V0LayerParameter is the old way of specifying layer parameters -// in Caffe. We keep this message type around for legacy support. -message V0LayerParameter { - optional string name = 1; // the layer name - optional string type = 2; // the string to specify the layer type - - // Parameters to specify layers with inner products. - optional uint32 num_output = 3; // The number of outputs for the layer - optional bool biasterm = 4 [default = true]; // whether to have bias terms - optional FillerParameter weight_filler = 5; // The filler for the weight - optional FillerParameter bias_filler = 6; // The filler for the bias - - optional uint32 pad = 7 [default = 0]; // The padding size - optional uint32 kernelsize = 8; // The kernel size - optional uint32 group = 9 [default = 1]; // The group size for group conv - optional uint32 stride = 10 [default = 1]; // The stride - enum PoolMethod { - MAX = 0; - AVE = 1; - STOCHASTIC = 2; - } - optional PoolMethod pool = 11 [default = MAX]; // The pooling method - optional float dropout_ratio = 12 [default = 0.5]; // dropout ratio - - optional uint32 local_size = 13 [default = 5]; // for local response norm - optional float alpha = 14 [default = 1.]; // for local response norm - optional float beta = 15 [default = 0.75]; // for local response norm - optional float k = 22 [default = 1.]; - - // For data layers, specify the data source - optional string source = 16; - // For data pre-processing, we can do simple scaling and subtracting the - // data mean, if provided. Note that the mean subtraction is always carried - // out before scaling. - optional float scale = 17 [default = 1]; - optional string meanfile = 18; - // For data layers, specify the batch size. - optional uint32 batchsize = 19; - // For data layers, specify if we would like to randomly crop an image. - optional uint32 cropsize = 20 [default = 0]; - // For data layers, specify if we want to randomly mirror data. - optional bool mirror = 21 [default = false]; - - // The blobs containing the numeric parameters of the layer - repeated BlobProto blobs = 50; - // The ratio that is multiplied on the global learning rate. If you want to - // set the learning ratio for one blob, you need to set it for all blobs. - repeated float blobs_lr = 51; - // The weight decay that is multiplied on the global weight decay. - repeated float weight_decay = 52; - - // The rand_skip variable is for the data layer to skip a few data points - // to avoid all asynchronous sgd clients to start at the same point. The skip - // point would be set as rand_skip * rand(0,1). Note that rand_skip should not - // be larger than the number of keys in the database. - optional uint32 rand_skip = 53 [default = 0]; - - // Fields related to detection (det_*) - // foreground (object) overlap threshold - optional float det_fg_threshold = 54 [default = 0.5]; - // background (non-object) overlap threshold - optional float det_bg_threshold = 55 [default = 0.5]; - // Fraction of batch that should be foreground objects - optional float det_fg_fraction = 56 [default = 0.25]; - - // optional bool OBSOLETE_can_clobber = 57 [default = true]; - - // Amount of contextual padding to add around a window - // (used only by the window_data_layer) - optional uint32 det_context_pad = 58 [default = 0]; - - // Mode for cropping out a detection window - // warp: cropped window is warped to a fixed size and aspect ratio - // square: the tightest square around the window is cropped - optional string det_crop_mode = 59 [default = "warp"]; - - // For ReshapeLayer, one needs to specify the new dimensions. - optional int32 new_num = 60 [default = 0]; - optional int32 new_channels = 61 [default = 0]; - optional int32 new_height = 62 [default = 0]; - optional int32 new_width = 63 [default = 0]; - - // Whether or not ImageLayer should shuffle the list of files at every epoch. - // It will also resize images if new_height or new_width are not zero. - optional bool shuffle_images = 64 [default = false]; - - // For ConcatLayer, one needs to specify the dimension for concatenation, and - // the other dimensions must be the same for all the bottom blobs. - // By default it will concatenate blobs along the channels dimension. - optional uint32 concat_dim = 65 [default = 1]; - - optional HDF5OutputParameter hdf5_output_param = 1001; -} - -message PReLUParameter { - // Parametric ReLU described in K. He et al, Delving Deep into Rectifiers: - // Surpassing Human-Level Performance on ImageNet Classification, 2015. - - // Initial value of a_i. Default is a_i=0.25 for all i. - optional FillerParameter filler = 1; - // Whether or not slope parameters are shared across channels. - optional bool channel_shared = 2 [default = false]; -} - -// Message that stores parameters used by DetectionOutputLayer -//message DetectionOutputParameter { -// optional int32 num_classes = 1 [default = 21]; -// optional float nms_threshold = 2 [default = 0.3]; -// optional int32 top_k = 3; -// optional float confidence_threshold = 4 [default = 0.8]; -//} - -// Message that store parameters used by PriorBoxLayer -message PriorBoxParameter { - // Encode/decode type. - enum CodeType { - CORNER = 1; - CENTER_SIZE = 2; - CORNER_SIZE = 3; - } - // Minimum box size (in pixels). Required! - repeated float min_size = 1; - // Maximum box size (in pixels). Required! - repeated float max_size = 2; - // Various of aspect ratios. Duplicate ratios will be ignored. - // If none is provided, we use default ratio 1. - repeated float aspect_ratio = 3; - // If true, will flip each aspect ratio. - // For example, if there is aspect ratio "r", - // we will generate aspect ratio "1.0/r" as well. - optional bool flip = 4 [default = true]; - // If true, will clip the prior so that it is within [0, 1] - optional bool clip = 5 [default = false]; - // Variance for adjusting the prior bboxes. - repeated float variance = 6; - // By default, we calculate img_height, img_width, step_x, step_y based on - // bottom[0] (feat) and bottom[1] (img). Unless these values are explicitely - // provided. - // Explicitly provide the img_size. - optional uint32 img_size = 7; - // Either img_size or img_h/img_w should be specified; not both. - optional uint32 img_h = 8; - optional uint32 img_w = 9; - - // Explicitly provide the step size. - optional float step = 10; - // Either step or step_h/step_w should be specified; not both. - optional float step_h = 11; - optional float step_w = 12; - - // Offset to the top left corner of each cell. - optional float offset = 13 [default = 0.5]; -} - -// Message that stores parameters used by PermutetLayer -message PermuteParameter { - // The new orders of the axes of data. Notice it should be with - // in the same range as the input data, and it starts from 0. - // Do not provide repeated order. - repeated uint32 order = 1; -} - -message NormalizeParameter { - optional bool across_spatial = 1 [default = true]; - // Initial value of scale. Default is 1.0 for all - optional FillerParameter scale_filler = 2; - // Whether or not scale parameters are shared across channels. - optional bool channel_shared = 3 [default = true]; - // Epsilon for not dividing by zero while normalizing variance - optional float eps = 4 [default = 1e-10]; -} - -// needed by ssd -message SaveOutputParameter { - // Output directory. If not empty, we will save the results. - optional string output_directory = 1; - // Output name prefix. - optional string output_name_prefix = 2; - // Output format. - // VOC - PASCAL VOC output format. - // COCO - MS COCO output format. - optional string output_format = 3; - // If you want to output results, must also provide the following two files. - // Otherwise, we will ignore saving results. - // label map file. - optional string label_map_file = 4; - // A file which contains a list of names and sizes with same order - // of the input DB. The file is in the following format: - // name height width - // ... - optional string name_size_file = 5; - // Number of test images. It can be less than the lines specified in - // name_size_file. For example, when we only want to evaluate on part - // of the test images. - optional uint32 num_test_image = 6; - // The resize parameter used in saving the data. - // optional ResizeParameter resize_param = 7; -} - -message NonMaximumSuppressionParameter { - // Threshold to be used in nms. - optional float nms_threshold = 1 [default = 0.3]; - // Maximum number of results to be kept. - optional int32 top_k = 2; - // Parameter for adaptive nms. - optional float eta = 3 [default = 1.0]; -} - -message GeneralNmsParameter { - optional int32 post_top_k = 1 ; - optional float nms_threshold = 2 [default = 0]; - optional float iou_threshold_decay = 3 [default = 1.0]; - optional float coor_scale_factor = 4 [default = 1.0]; -} - -// Message that store parameters used by DetectionOutputLayer, ssd/fasterRcnn -message DetectionOutputParameter { - optional int32 num_classes = 1; - optional bool share_location = 2 [default = true]; - optional int32 background_label_id = 3 [default = 0]; - optional NonMaximumSuppressionParameter nms_param = 4; - optional SaveOutputParameter save_output_param = 5; - optional PriorBoxParameter.CodeType code_type = 6 [default = CENTER_SIZE]; - optional bool variance_encoded_in_target = 8 [default = true]; - optional int32 keep_top_k = 7; - optional float confidence_threshold = 9; - optional float nms_threshold = 13; - optional int32 top_k = 14; - optional int32 boxes = 15 [default = 1]; - optional bool relative = 17 [default = true]; - optional float objectness_threshold = 18 [default = 0.5]; - optional float class_threshold = 19 [default = 0.5]; - repeated float biases = 20; - optional GeneralNmsParameter general_nms_param = 21; - optional float objectness_score = 22; -} -message PSROIPoolingParameter { - required float spatial_scale = 1; - required int32 output_dim = 2; // output channel number - required int32 group_size = 3; // number of groups to encode position-sensitive score maps -} -// Message that stores parameters used by FreespaceExtractLayer -message FreespaceExtractParameter { - optional float org_height = 1; -} - -// Message that stores parameters used by DetectpostprocessLayer -message PostprocessParameter { - optional float nms_thresh = 1 [default = 0.3]; - optional float conf_thresh = 2 [default = 0.5]; - optional uint32 post_nms_topn = 3 [default = 100]; - optional uint32 cls_num = 4 [default = 12]; - repeated float bbox_reg_weights = 5; -} - -// Message that stores parameters used by SpatialTransformLayer -message SpatialTransformParameter { - optional uint32 output_h = 1 [default = 0]; - optional uint32 output_w = 2 [default = 0]; - optional float border_value = 3 [default = 0]; - repeated float affine_transform = 4; - enum Engine { - DEFAULT = 0; - CAFFE = 1; - CUDNN = 2; - } - optional Engine engine = 15 [default = DEFAULT]; -} -message ROIAlignParameter { - // Pad, kernel size, and stride are all given as a single value for equal - // dimensions in height and width or as Y, X pairs. - optional uint32 pooled_h = 1 [default = 0]; // The pooled output height - optional uint32 pooled_w = 2 [default = 0]; // The pooled output width - // Multiplicative spatial scale factor to translate ROI coords from their - // input scale to the scale used when pooling - optional float spatial_scale = 3 [default = 1]; - optional int32 sampling_ratio = 4 [default = -1]; - optional int32 roi_end_mode = 5 [default = 0]; -} - -message RegionParameter { - optional uint32 classes = 1 [default = 20]; // Category of classification - optional uint32 coords = 2 [default = 4]; // Coordinates of box - optional uint32 boxes = 3 [default = 1]; // Number of boxes predicted per grid - optional uint32 softmax = 4 [default = 0]; - optional string softmax_tree = 5 [default = ""]; - optional uint32 background = 6 [default = 0]; -} -message ReorgParameter{ - optional uint32 stride = 2 [default = 2]; - optional bool reverse = 1 [default = false]; -} -message ReverseParameter{ - repeated int32 axis = 1; -} -message InterpParameter{ - optional int32 height = 1 [default = 0];//Height of output - optional int32 width = 2 [default = 0];//Width of output - optional int32 zoom_factor = 3 [default = 1];//zoom factor - optional int32 shrink_factor = 4 [default = 1];//shrink factor - optional int32 pad_beg = 5 [default = 0];//padding at begin of input - optional int32 pad_end = 6 [default = 0];//padding at end of input -} -message ShuffleChannelParameter{ - optional uint32 group = 1[default = 1]; // The number of group -} -message UpsampleParameter{ - optional float scale = 1[default = 1]; - optional int32 stride = 2[default = 2]; - optional int32 stride_h = 3[default = 2]; - optional int32 stride_w = 4[default=2]; -} -message ROIPoolingParameter { - required int32 pooled_h = 1; - required int32 pooled_w = 2; - optional float spatial_scale = 3 [default=0.0625]; - optional float spatial_scale_h = 4; - optional float spatial_scale_w = 5; -} - -message YoloParameter { - optional int32 boxes = 1 [default = 3]; - optional int32 coords = 2 [default = 4]; - optional int32 classes = 3 [default = 80]; - optional string yolo_version = 4 [default = "V3"]; - optional bool softmax = 5 [default = false]; - optional bool background = 6 [default = false]; - optional bool softmaxtree = 7 [default = false]; -} - -message YoloV3DetectionOutputParameter { - optional int32 boxes = 1 [default = 3]; - optional int32 classes = 2 [default = 80]; - optional bool relative = 3 [default = true]; - optional float obj_threshold = 4 [default = 0.5]; - optional float score_threshold = 5 [default = 0.5]; - optional float iou_threshold = 6 [default = 0.45]; - optional int32 pre_nms_topn = 7 [default = 512]; - optional int32 post_nms_topn = 8 [default = 1024]; - repeated float biases_high = 9; - repeated float biases_mid = 10; - repeated float biases_low = 11; - optional int32 coords = 12 [default = 4]; - repeated float biases = 13; - optional bool resize_origin_img_to_net = 14 [default = false]; -} - -message YoloV3DetectionOutputV2Parameter { - optional int32 boxes = 1 [default = 3]; - optional int32 classes = 2 [default = 80]; - optional bool relative = 3 [default = true]; - optional float obj_threshold = 4 [default = 0.5]; - optional float score_threshold = 5 [default = 0.5]; - optional float iou_threshold = 6 [default = 0.45]; - optional int32 pre_nms_topn = 7 [default = 512]; - optional int32 post_nms_topn = 8 [default = 1024]; - repeated float biases_high = 9; - repeated float biases_mid = 10; - repeated float biases_low = 11; - optional int32 coords = 12 [default = 4]; - repeated float biases = 13; - optional bool resize_origin_img_to_net = 14 [default = false]; - optional int32 out_box_dim = 15 [default = 3]; -} - -message ProposalParameter { - optional float feat_stride = 1 [default = 16]; - optional float base_size = 2 [default = 16]; - optional float min_size = 3 [default = 16]; - repeated float ratio = 4; - repeated float scale = 5; - optional int32 pre_nms_topn = 6 [default = 3000]; - optional int32 post_nms_topn = 7 [default = 304]; - optional float iou_threshold = 8 [default = 0.7]; - optional bool output_actual_rois_num = 9 [default = false]; -} - -message FSRDetectionOutputParameter { - required int32 num_classes = 1; - required float score_threshold = 2; - required float iou_threshold = 3; - optional int32 batch_rois = 4 [default = 1]; -} - -message SSDDetectionOutputParameter { - required int32 num_classes= 1 [default = 2]; - optional bool share_location = 2 [default = true]; - optional int32 background_label_id = 3 [default = 0]; - optional float iou_threshold = 4 [default = 0.3]; - optional int32 top_k = 5 [default = 200]; - optional float eta = 6 [default = 1.0]; - optional bool variance_encoded_in_target = 7 [default = false]; - optional int32 code_type = 8 [default = 1]; - optional int32 keep_top_k = 9 [default = -1]; - optional float confidence_threshold = 10 [default = 0.0]; -} -message YoloV2DetectionOutputParameter { - optional int32 boxes = 1 [default = 5]; - optional int32 classes = 2 [default = 80]; - optional bool relative = 3 [default = true]; - optional float obj_threshold = 4 [default = 0.5]; - optional float score_threshold = 5 [default = 0.5]; - optional float iou_threshold = 6 [default = 0.45]; - optional int32 pre_nms_topn = 7 [default = 512]; - optional int32 post_nms_topn = 8 [default = 1024]; - repeated float biases = 9; - optional int32 coords = 10 [default = 4]; - optional bool resize_origin_img_to_net = 11 [default = false]; -} - -message QuantParameter { - optional float scale = 2; - optional bytes offset = 3; -} - -message BatchMatMulParameter{ - optional bool adj_x1 = 1 [default = false]; - optional bool adj_x2 = 2 [default = false]; -} - -message CondTakeParameter { - required string mode = 1; - required float val = 2; - optional float eps = 3 [default = 1e-06]; -} - -message MatrixInverseParameter { - optional bool adjoint = 1 [default = false]; -} - -message WarpPerspectiveParameter { - required int32 out_height = 1; - required int32 out_width = 2; - optional float constant = 3; - optional string border_type = 4 [default = 'BORDER_CONSTANT']; -} - -message SpatialTransformerParameter { - // How to use the parameter passed by localisation network - optional string transform_type = 1 [default = "affine"]; - // What is the sampling technique - optional string sampler_type = 2 [default = "bilinear"]; - - // If not set,stay same with the input dimension H and W - optional int32 output_H = 3; - optional int32 output_W = 4; - // If false, only compute dTheta, DO NOT compute dU - optional bool to_compute_dU = 5 [default = true]; - - // The default value for some parameters - optional double theta_1_1 = 6; - optional double theta_1_2 = 7; - optional double theta_1_3 = 8; - optional double theta_2_1 = 9; - optional double theta_2_2 = 10; - optional double theta_2_3 = 11; -} - -message ContinuationIndicatorParameter { - optional uint32 time_step = 1 [default = 0]; - optional uint32 batch_size = 2 [default = 0]; -} diff --git a/proto/dump_task.proto b/proto/dump_task.proto deleted file mode 100644 index 36c3cdf48537756d07e562e714ea84e797c9d15d..0000000000000000000000000000000000000000 --- a/proto/dump_task.proto +++ /dev/null @@ -1,190 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; -package toolkit.dump; - -enum OutputDataType { - DT_UNDEFINED = 0; - DT_FLOAT = 1; - DT_FLOAT16 = 2; - DT_INT8 = 3; - DT_UINT8 = 4; - DT_INT16 = 5; - DT_UINT16 = 6; - DT_INT32 = 7; - DT_INT64 = 8; - DT_UINT32 = 9; - DT_UINT64 = 10; - DT_BOOL = 11; - DT_DOUBLE = 12; - DT_STRING = 13; - DT_DUAL_SUB_INT8 = 14; - DT_DUAL_SUB_UINT8 = 15; - DT_COMPLEX64 = 16; - DT_COMPLEX128 = 17; - DT_QINT8 = 18; - DT_QINT16 = 19; - DT_QINT32 = 20; - DT_QUINT8 = 21; - DT_QUINT16 = 22; - DT_RESOURCE = 23; - DT_STRING_REF = 24; - DT_DUAL = 25; - DT_VARIANT = 26; - DT_BF16 = 27; // bf16 type - DT_INT4 = 28; // int4 type - DT_UINT1 = 29; // uint1 type - DT_INT2 = 30; // int2 type - DT_UINT2 = 31; // uint2 type - DT_COMPLEX32 = 32; // complex32 type - DT_HIFLOAT8 = 33; - DT_FLOAT8_E5M2 = 34; - DT_FLOAT8_E4M3FN = 35; - DT_FLOAT8_E8M0 = 36; // float8_e8m0 type - DT_FLOAT6_E3M2 = 37; // float6_e3m2 type - DT_FLOAT6_E2M3 = 38; // float6_e2m3 type - DT_FLOAT4_E2M1 = 39; // float4_e2m1 type - DT_FLOAT4_E1M2 = 40; // float4_e1m2 type -} - -enum OutputFormat { - FORMAT_NCHW = 0; - FORMAT_NHWC = 1; - FORMAT_ND = 2; - FORMAT_NC1HWC0 = 3; - FORMAT_FRACTAL_Z = 4; - FORMAT_NC1C0HWPAD = 5; - FORMAT_NHWC1C0 = 6; - FORMAT_FSR_NCHW = 7; - FORMAT_FRACTAL_DECONV = 8; - FORMAT_C1HWNC0 = 9; - FORMAT_FRACTAL_DECONV_TRANSPOSE = 10; - FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS = 11; - FORMAT_NC1HWC0_C04 = 12; - FORMAT_FRACTAL_Z_C04 = 13; - FORMAT_CHWN = 14; - FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15; - FORMAT_HWCN = 16; - FORMAT_NC1KHKWHWC0 = 17; - FORMAT_BN_WEIGHT = 18; - FORMAT_FILTER_HWCK = 19; - FORMAT_HASHTABLE_LOOKUP_LOOKUPS=20; - FORMAT_HASHTABLE_LOOKUP_KEYS = 21; - FORMAT_HASHTABLE_LOOKUP_VALUE = 22; - FORMAT_HASHTABLE_LOOKUP_OUTPUT = 23; - FORMAT_HASHTABLE_LOOKUP_HITS=24; - FORMAT_C1HWNCoC0 = 25; - FORMAT_MD = 26; - FORMAT_NDHWC = 27; - FORMAT_FRACTAL_ZZ = 28; - FORMAT_FRACTAL_NZ = 29; - FORMAT_NCDHW = 30; - FORMAT_DHWCH = 31; - FORMAT_NDC1HWC0 = 32; - FORMAT_FRACTAL_Z_3D = 33; - FORMAT_CN = 34; - FORMAT_NC = 35; - FORMAT_DHWNC = 36; - FORMAT_FRACTAL_Z_3D_TRANSPOSE = 37; - FORMAT_FRACTAL_ZN_LSTM = 38; - FORMAT_FRACTAL_Z_G = 39; - FORMAT_RESERVED = 40; - FORMAT_ALL = 41; - FORMAT_NULL = 42; - FORMAT_ND_RNN_BIAS = 43; - FORMAT_FRACTAL_ZN_RNN = 44; - FORMAT_NYUV = 45; - FORMAT_NYUV_A = 46; - FORMAT_NCL = 47; - FORMAT_FRACTAL_Z_WINO = 48; - FORMAT_C1HWC0 = 49; - FORMAT_MAX = 0xff; -} - -message OriginalOp { - string name = 1; - uint32 output_index = 2; - OutputDataType data_type = 3; - OutputFormat format = 4; -} - -message Shape { - repeated uint64 dim = 1; -} - -message DimRange { - uint64 dim_start = 1; - uint64 dim_end = 2; -} - -message OpOutput { - OutputDataType data_type = 1; - OutputFormat format = 2; - Shape shape = 3; - OriginalOp original_op = 4; // the original op corresponding to the output - bytes data = 5; - uint64 size = 6; - Shape original_shape = 7; - int32 sub_format = 8; - uint64 address = 9; - repeated DimRange dim_range = 10; - uint32 arg_index = 11; -} - -message OpInput { - OutputDataType data_type = 1; - OutputFormat format = 2; - Shape shape = 3; - bytes data = 4; - uint64 size = 5; - Shape original_shape = 6; - int32 sub_format = 7; - uint64 address = 8; - uint64 offset = 9; - uint32 arg_index = 10; - uint32 input_type = 11; -} - -enum BufferType { - L1 = 0; -} - -message OpBuffer { - BufferType buffer_type = 1; - bytes data = 2; - uint64 size = 3; -} - -message OpAttr { - string name = 1; - string value = 2; -} - -message Workspace { - enum SpaceType { - LOG = 0; - } - SpaceType type = 1; - bytes data = 2; - uint64 size = 3; - uint32 arg_index = 4; -} - -message DumpData{ - string version = 1; - uint64 dump_time = 2; - repeated OpOutput output = 3; - repeated OpInput input = 4; - repeated OpBuffer buffer = 5; - string op_name = 6; - repeated OpAttr attr = 7; - repeated Workspace space = 8; - string dfx_message = 9; -} diff --git a/proto/flow_model.proto b/proto/flow_model.proto deleted file mode 100644 index 66400aad44956f93f779acd279e47e5d79521c17..0000000000000000000000000000000000000000 --- a/proto/flow_model.proto +++ /dev/null @@ -1,136 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; -package ge.flow_model.proto; - -import "ge_ir.proto"; - -message ModelRelationDef { - message QueueDef { - string name = 1; - uint32 depth = 2; - string enqueue_policy = 3; - bool is_control = 4; - } - message AttrValue { - oneof value { - bytes s = 1; //"string" - int64 i = 2; //"int" - bool b = 3; //"bool" - } - } - message Endpoint { - string name = 1; - int32 endpoint_type = 2; - map attrs = 3; - } - message InvokedModelQueueInfo { - repeated string input_queue_name = 1; - repeated string output_queue_name = 2; - } - message ModelQueueInfo { - string model_name = 1; - repeated string input_queue_name = 2; - repeated string output_queue_name = 3; - repeated string external_input_queue_name = 4; - repeated string external_output_queue_name = 5; - repeated string invoke_model_key = 6; - } - message ModelEndpointInfo { - string model_name = 1; - repeated string input_endpoint_name = 2; - repeated string output_endpoint_name = 3; - repeated string external_input_queue_name = 4; - repeated string external_output_queue_name = 5; - repeated string event_input_name = 6; - repeated string event_output_name = 7; - repeated string invoke_model_key = 8; - } - - repeated QueueDef queue_def = 1; - map submodel_queue_info = 2; - map invoked_model_queue_info = 3; - ModelQueueInfo root_model_queue_info = 4; - - repeated Endpoint endpoint = 5; - map submodel_endpoint_info = 6; - ModelEndpointInfo root_model_endpoint_info = 7; -} - -message RunningResource { - string type = 1; - int64 value = 2; -} - -message ModelDeployResource { - string resource_type = 1; - repeated RunningResource running_resource = 2; - bool is_heavy_load = 3; -} - -message ModelDeployInfo { - string logic_device_id = 1; -} - -message ModelRedundantDeployInfo { - string redundant_logic_device_id = 1; -} - -message SubmodelDef { - string model_name = 1; - string model_type = 2; - bytes om_data = 3; - ge.proto.GraphDef graph = 4; - ModelDeployResource deploy_resource = 5; - ModelDeployInfo deploy_info = 6; - map ext_attrs = 7; - ModelRedundantDeployInfo redundant_deploy_info = 8; - string om_data_file_path = 9; - bool is_builtin_udf = 10; -} - -message CompileResource { - string host_resource_type = 1; - map logic_device_id_to_resource_type = 2; - message RunningResourceList { - repeated RunningResource running_resource = 1; - } - map dev_to_resource_list = 3; // key is logic device id -} - -message FlowModelDef { - message EschedPriority { - map esched_priority = 1; - } - message RankIds { - repeated uint32 rank_id = 1; - } - message HcomClusterDef { - string name = 1; - string rank_table = 2; - map group_name_to_rank_ids = 3; - map device_to_rank_ids = 4; - } - message ModelClusterRankId { - string model_name = 1; - string cluster_name = 2; - uint32 rank_id = 3; - } - string model_name = 1; - ModelRelationDef relation = 2; - repeated string submodel_name = 3; - map models_esched_priority = 4; - map model_name_to_rank_id = 6; - map group_name_to_rank_ids = 5; - map device_to_rank_ids = 7; - CompileResource compile_resource = 8; - repeated HcomClusterDef hcom_cluster_defs = 9; - repeated ModelClusterRankId model_cluster_rank_ids = 10; -} diff --git a/proto/fusion_model.proto b/proto/fusion_model.proto deleted file mode 100644 index efea75653c34ac60f03ebca3caf0d24bf7897e3c..0000000000000000000000000000000000000000 --- a/proto/fusion_model.proto +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; - -import "om.proto"; - -package domi; - -message FusionModelDef { - string version = 1; - repeated OpDef fusion_op = 2; -} diff --git a/proto/fwk_adapter.proto b/proto/fwk_adapter.proto deleted file mode 100644 index 1f01485cb4e9e4e4db70ed2463017e22b9c6996b..0000000000000000000000000000000000000000 --- a/proto/fwk_adapter.proto +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; - -package aicpu.FWKAdapter; -option cc_enable_arenas = true; - - -// Defines an struct for input and output. -message TensorDataInfo { - - // value DataType - uint32 dtype = 1; - - // shape dim - repeated int64 dim = 2; - - // data point addr - int64 data_addr = 3; -} - -message KernelRunParam { - // input - repeated TensorDataInfo input = 1; - // output - repeated TensorDataInfo output = 2; -} - diff --git a/proto/ge_api.proto b/proto/ge_api.proto deleted file mode 100644 index 9055b3cfe0ff27a60cda59800d87bbc4b5465a4f..0000000000000000000000000000000000000000 --- a/proto/ge_api.proto +++ /dev/null @@ -1,97 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; -package ge.api_pb; - -import "ge_ir.proto"; - -// GE initialize -message GEInitialize { - map options = 1; -}; - -// initialize response -message GEInitializeResponse { - uint32 status = 1; - uint32 clientId = 2; -}; - -// GE finalize -message GEFinalize { - bool final = 1; - uint32 clientId = 2; -}; - -message GEFinalizeResponse { - uint32 status = 1; -}; - -// GE Session -message CreateSession{ - map options = 1; -}; - -message CreateSessionResponse { - uint32 status = 1; - uint64 sessionId = 2; -}; - -//GE AddGraph -//model serialize :: serializegraph -message SessionAddGraph{ - uint32 graphId = 1; - uint64 sessionId = 2; - ge.proto.GraphDef graph = 3; -}; - -message SessionAddGraphResponse { - uint32 status = 1; -}; - -//GE SessionRemoveGraph -message SessionRemoveGraph{ - uint32 graphId = 1; - uint64 sessionId = 2; -}; - -message SessionRemoveGraphResponse { - uint32 status = 1; -}; - -message SessionRunGraph{ - uint32 graphId = 1; - uint64 sessionId = 2; - repeated ge.proto.TensorDef tensor = 3; -}; - -message SessionBuildGraph{ - uint32 graphId = 1; - uint64 sessionId = 2; - repeated ge.proto.TensorDef tensor = 3; - string savePath = 4; -}; - -message SessionRunGraphResponse { - uint32 status = 1; - repeated ge.proto.TensorDef tensor = 2; -}; - -message SessionBuildGraphResponse { - uint32 status = 1; -}; - -message DestroySession{ - bool final = 1; - uint64 sessionId = 2; -}; - -message DestroySessionResponse { - uint32 status = 1; -}; diff --git a/proto/ge_ir.proto b/proto/ge_ir.proto deleted file mode 100644 index cb96d134b806aa7a911be1c2e6fc9e38c9a8576f..0000000000000000000000000000000000000000 --- a/proto/ge_ir.proto +++ /dev/null @@ -1,403 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; - -package ge.proto; -enum DataType -{ - DT_UNDEFINED = 0; // Used to indicate a DataType field has not been set. - DT_FLOAT = 1; // float type - DT_FLOAT16 = 2; // fp16 type - DT_INT8 = 3; // int8 type - DT_UINT8 = 4; // uint8 type - DT_INT16 = 5; // int16 type - DT_UINT16 = 6; // uint16 type - DT_INT32 = 7; // - DT_INT64 = 8; // int64 type - DT_UINT32 = 9; // unsigned int32 - DT_UINT64 = 10; // unsigned int64 - DT_BOOL = 11; // bool type - DT_DOUBLE = 12; // double type - DT_STRING = 13; // string type - DT_DUAL_SUB_INT8 = 14; /**< dual output int8 type */ - DT_DUAL_SUB_UINT8 = 15; /**< dual output uint8 type */ - DT_COMPLEX64 = 16; // complex64 type - DT_COMPLEX128 = 17; // complex128 type - DT_QINT8 = 18; // qint8 type - DT_QINT16 = 19; // qint16 type - DT_QINT32 = 20; // qint32 type - DT_QUINT8 = 21; // quint8 type - DT_QUINT16 = 22; // quint16 type - DT_RESOURCE = 23; // resource type - DT_STRING_REF = 24; // string_ref type - DT_DUAL = 25; /**< dual output type */ - DT_VARIANT = 26; // variant type - DT_BF16 = 27; // bf16 type - DT_INT4 = 28; // int4 type - DT_UINT1 = 29; // uint1 type - DT_INT2 = 30; // int2 type - DT_UINT2 = 31; // uint2 type - DT_COMPLEX32 = 32; // complex32 type - DT_HIFLOAT8 = 33; // hifloat8 type - DT_FLOAT8_E5M2 = 34; // float8_e5m2 type - DT_FLOAT8_E4M3FN = 35; // float8_e4m3fn type - DT_FLOAT8_E8M0 = 36; // float8_e8m0 type - DT_FLOAT6_E3M2 = 37; // float6_e3m2 type - DT_FLOAT6_E2M3 = 38; // float6_e2m3 type - DT_FLOAT4_E2M1 = 39; // float4_e2m1 type - DT_FLOAT4_E1M2 = 40; // float4_e1m2 type -} - -message AttrDef -{ - message ListValue - { - enum ListValueType{ - VT_LIST_NONE = 0; - VT_LIST_STRING = 1; - VT_LIST_INT = 2; - VT_LIST_FLOAT = 3; - VT_LIST_BOOL = 4; - VT_LIST_BYTES = 5; - VT_LIST_TENSOR_DESC = 6; - VT_LIST_TENSOR = 7; - VT_LIST_GRAPH = 8; - VT_LIST_NAMED_ATTRS = 9; - VT_LIST_DATA_TYPE = 10; - } - repeated bytes s = 2; // "list(string)" - repeated int64 i = 3; // "list(int)" - repeated float f = 4; // "list(float)" - repeated bool b = 5; // "list(bool)" - repeated bytes bt = 7; - repeated TensorDescriptor td = 8; - repeated TensorDef t = 9; - repeated GraphDef g = 10; - repeated NamedAttrs na = 11; - repeated int64 dt = 12; // list ge::DataType - - ListValueType val_type = 20; - } - - message ListListInt{ - message ListInt{ - repeated int64 list_i = 1; // list int - } - repeated ListInt list_list_i = 1; // list list int - } - - message ListListFloat{ - message ListFloat{ - repeated float list_f = 1; // list float - } - repeated ListFloat list_list_f = 1; // list list float - } - - oneof value - { - bytes s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - bytes bt = 7; - ListValue list = 1; // any "list(...)" - NamedAttrs func = 10; // Used to support attr nesting - TensorDescriptor td = 11; // GeTensorDesc type - TensorDef t = 12; // GeTensor type - GraphDef g = 13; // Graph type - ListListInt list_list_int = 14; // List List Int type - int64 dt = 15; // ge::DataType - ListListFloat list_list_float = 16; // List List Float type - bytes expression = 17; - } -} - -// replace AttrDef in ge_ir.proto in the future -// do not use this, only stub currently -message AttributeDef -{ - oneof value - { - bytes s = 1; // "string" - int64 i = 2; // "int" - float f = 3; // "float" - bool b = 4; // "bool" - bytes bt = 5; // "bytes"; - } -} - -message OtherGroupDef { - map attr = 1; -} - -message TensorDescAttrGroupsDef { - repeated string origin_symbol_shape = 1; // symbolic origin shape - repeated string symbolic_value = 2; // symbolic value -} - -message InputSourceDef { - int32 input_data_idx = 1; - int64 dim_idx = 2; -} - -message ReplacementDef { - string replace_expr = 1; - int32 rank = 2; -} - -message SymbolCheckInfoDef { - string expr = 1; - string file = 2; - int64 line = 3; - string dfx = 4; -} - -message ShapeEnvSettingDef { - bool specialize_zero_one = 1; - int32 dynamic_mode = 2; -} - -message SymbolInfoDef { - repeated string symbols = 1; -} - -message ShapeEnvAttrGroupsDef { - map symbol_to_value = 1; - map value_to_symbol = 2; - map symbol_to_source = 3; - map replacements = 4; - repeated SymbolCheckInfoDef symbol_check_infos = 5; - repeated SymbolCheckInfoDef symbol_assert_infos = 6; - ShapeEnvSettingDef shape_setting = 7; - uint64 unique_sym_id = 8; -} - -message SchedInfoDef { - int64 exec_order = 1; - repeated int64 axis = 2; - int64 loop_axis = 3; - int32 exec_condition = 4; -} -message ApiInfoDef { - int32 type = 1; - int32 compute_type = 2; - int32 unit = 3; -} - -message MemAttrDef { - int64 tensor_id = 1; - int32 alloc_type = 2; - int32 position = 3; - int32 hardware = 4; - repeated int64 buf_ids = 5; - string name = 6; - int64 reuse_id = 7; -} - -message MemQueueAttrDef { - int64 id = 1; - int64 depth = 2; - int64 buf_num = 3; - string name = 4; -} - -message MemBufAttrDef { - int64 id = 1; - string name = 2; -} - -message MemOptAttrDef { - int64 reuse_id = 1; - int64 ref_tensor = 2; - int64 merge_scope = 3; -} - -message AxisDef { - int64 id = 1; - string name = 2; - int32 axis_type = 3; - bool bind_block = 4; - string size = 5; // expression - int32 align = 6; - repeated int64 from = 7; - int64 split_pair_other_id = 8; - bool allow_oversize_axis = 9; - bool allow_unaligned_tail = 10; -} - -message AscendCIROpAttrGroupsDef { - string name = 1; - string type = 2; -} - -message AscTensorAttrGroupsDef { - int64 dtype = 1; - repeated int64 axis_ids = 2; - repeated string repeats = 3; // expression - repeated string strides = 4; // expression - repeated int64 vectorized_axis = 5; - repeated string vectorized_strides = 6; - MemAttrDef mem = 7; - MemQueueAttrDef que = 8; - MemBufAttrDef buf = 9; - MemOptAttrDef opt = 10; -} - -message AscGraphAttrGroupsDef { - int64 tiling_key = 1; - repeated AxisDef axis = 2; - int64 type = 3; - repeated string size_var = 4; -} - -message AscIrAttrDef { - map attr = 1; -} - - -message TmpBufDescDef { - string size = 1; // expression - int64 life_time_axis_id = 2; -} - -message TmpBufferGroupDef { - TmpBufDescDef buf_desc = 1; - MemAttrDef mem = 2; -} - -message AscNodeAttrGroupsDef { - string name = 1; - string type = 2; - SchedInfoDef sched = 3; - ApiInfoDef api = 4; - AscIrAttrDef ir_attr_def = 5; - repeated TmpBufferGroupDef tmp_buffers = 6; -} - -message AttrGroupDef { - oneof attr_group { - AscendCIROpAttrGroupsDef op_attr_group = 2; - TensorDescAttrGroupsDef tensor_attr_group = 3; - ShapeEnvAttrGroupsDef shape_env_attr_group = 4; - AscGraphAttrGroupsDef asc_graph_attr_group = 5; - AscNodeAttrGroupsDef asc_node_attr_group = 6; - AscTensorAttrGroupsDef asc_tensor_attr_group = 7; - } -} - -message AttrGroups { - OtherGroupDef other_group_def = 1; - repeated AttrGroupDef attr_group_def = 2; -} - -// A list of attr names and their values. The whole list is attached -// with a string name. E.g., MatMul[T=float]. -message NamedAttrs -{ - string name = 1; - map attr = 2; -} - -// Shape / dimension description, using row-major order -message ShapeDef -{ - repeated int64 dim = 1; // Size of each dimension -} - -// Multidimensional data description -message TensorDescriptor -{ - string name = 1; // Optional parameter, tensor name - - DataType dtype = 2; // tensor datatype - ShapeDef shape = 3; // Shape / dimension - string layout = 4; // Tensor format, eg: "NCHW", "NHWC", "CHW", "ND" - - bool has_out_attr = 9; - int64 size = 10; - int64 weight_size = 11; - bool reuse_input = 12; - bool output_tensor = 13; - string device_type = 14; - bool input_tensor =15; - int64 real_dim_cnt = 16; - int64 reuse_input_index = 17; - int64 data_offset = 18; - int64 cmps_size = 19; - string cmps_tab = 20; - int64 cmps_tab_offset = 21; - - map attr = 5; // Set of extra parameter fields - AttrGroups attr_groups = 6; // Set of attr groups -} - -// GeTensor definition -message TensorDef -{ - TensorDescriptor desc = 1; // Tensor description - bytes data = 2; // Tensor data -} - - -// Operator description -message OpDef -{ - string name = 1; // name - string type = 2; // type - - repeated string input = 5; // input original op name + outgoing index. op_name:index - - map attr = 10; // Set of operator parameter fields - AttrGroups attr_groups = 11; // Set of attr groups - - bool has_out_attr = 20; - int64 id = 21; - int64 stream_id =22; - repeated string input_name = 23; - repeated string src_name = 24; - repeated int64 src_index = 25; - repeated string dst_name = 26; - repeated int64 dst_index = 27; - repeated int64 input_i = 28; - repeated int64 output_i = 29; - repeated int64 workspace = 30; - repeated int64 workspace_bytes = 31; - repeated bool is_input_const = 32; - repeated TensorDescriptor input_desc = 33; - repeated TensorDescriptor output_desc = 34; - repeated string subgraph_name = 35; -} - -// Graph definition -message GraphDef -{ - string name = 1; // name - - repeated string input = 4; // Graph input - repeated string output = 5; // Graph output - - repeated OpDef op = 6; // List of operators - - map attr = 11; // Extended field - AttrGroups attr_groups = 12; // Set of attr groups -} - -// model definition -message ModelDef -{ - string name = 1; // name - uint32 version = 2; // IR Proto verion - string custom_version = 3; // User model version number, passed in by user - - repeated GraphDef graph = 7; // Graph definition,graph[0] represents the main diagram in modeldef - - map attr = 11; // Extended field - AttrGroups attr_groups = 12; // Set of attr groups -} diff --git a/proto/insert_op.proto b/proto/insert_op.proto deleted file mode 100644 index 8fb696e9bb31b21ba70869b10704dffd00a0f903..0000000000000000000000000000000000000000 --- a/proto/insert_op.proto +++ /dev/null @@ -1,149 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; - -package domi; - -message InsertNewOps { - repeated AippOpParams aipp_op = 1; - repeated MultiShapeOpParams multi_shape_op = 2; -} - -message AippOpParams { - enum InputFormat { - UNDEFINED = 0; - YUV420SP_U8 = 1; - XRGB8888_U8 = 2; - RGB888_U8 = 3; - YUV400_U8 = 4; - NC1HWC0DI_FP16 = 5; - NC1HWC0DI_S8 = 6; - ARGB8888_U8 = 7; - YUYV_U8 = 8; - YUV422SP_U8 = 9; - AYUV444_U8 = 10; - RAW10 = 11; - RAW12 = 12; - RAW16 = 13; - RAW24 = 14; - RGB16 = 15; - RGB20 = 16; - RGB24 = 17; - RGB8_IR = 18; - RGB16_IR = 19; - RGB24_IR = 20; - } - - enum AippMode { - undefined = 0; - static = 1; - dynamic = 2; - } - - // AIPP模式,区分静态AIPP和动态AIPP - AippMode aipp_mode = 1; - - // related_input_rank参数为必填,类型为整型,配置范围>=0, <=输入Data算子的个数,默认值为0。 - // 标识对模型的第几个输入做AIPP处理,例如模型有两个输入,需要对第2个输入做AIPP,则配置related_input_rank为1。 - uint32 related_input_rank = 2; - - // related_input_name is optional and the top name of data node which inserts aipp - string related_input_name = 6; - - // input_edge_idx参数为可选,类型为整型,配置范围为>=0。 - // 配置该参数的作用,在于对Data算子不同的输出做不同的AIPP处理,如果该参数没有配置,默认对related_input_rank指定的模型输入的所有输出边做AIPP。 - // 配置值 <= Data算子输出边的个数。 - repeated uint32 input_edge_idx = 3; - - // [Begin] 动态AIPP参数,配置静态AIPP时无效 - uint32 max_src_image_size = 4; - - // 是否支持旋转。默认不支持,开启支持旋转时,会有额外的空间和性能损失 - bool support_rotation = 5; - - // [End] 动态AIPP参数 - - - // [Begin] 静态AIPP参数,配置动态AIPP时无效 - InputFormat input_format = 51; - bool csc_switch = 52; - float cpadding_value = 53; - bool rbuv_swap_switch = 54; - bool ax_swap_switch = 55; - bool single_line_mode = 56; - - int32 src_image_size_w = 57; - int32 src_image_size_h = 58; - - bool crop = 59; - int32 load_start_pos_w = 60; - int32 load_start_pos_h = 61; - int32 crop_size_w = 62; - int32 crop_size_h = 63; - - bool resize = 64; - int32 resize_output_w = 65; - int32 resize_output_h = 66; - - bool padding = 67; - int32 left_padding_size = 68; - int32 right_padding_size = 69; - int32 top_padding_size = 70; - int32 bottom_padding_size = 71; - float padding_value = 72; - - int32 mean_chn_0 = 10; - int32 mean_chn_1 = 11; - int32 mean_chn_2 = 12; - int32 mean_chn_3 = 19; - float min_chn_0 = 13; - float min_chn_1 = 14; - float min_chn_2 = 15; - float min_chn_3 = 20; - repeated float var_reci_chn_0 = 16; - repeated float var_reci_chn_1 = 17; - repeated float var_reci_chn_2 = 18; - repeated float var_reci_chn_3 = 21; - - repeated int32 matrix_r0c0 = 30; - repeated int32 matrix_r0c1 = 31; - repeated int32 matrix_r0c2 = 32; - repeated int32 matrix_r1c0 = 33; - repeated int32 matrix_r1c1 = 34; - repeated int32 matrix_r1c2 = 35; - repeated int32 matrix_r2c0 = 36; - repeated int32 matrix_r2c1 = 37; - repeated int32 matrix_r2c2 = 38; - repeated int32 output_bias_0 = 39; - repeated int32 output_bias_1 = 40; - repeated int32 output_bias_2 = 41; - repeated int32 input_bias_0 = 42; - repeated int32 input_bias_1 = 43; - repeated int32 input_bias_2 = 44; - - // [End] 静态AIPP参数 - - // The n number that is used for raw/rgbir data into f16 transformation. - // The transformation equation is x/(2^n). If set to 0, no transform is performed. - uint32 raw_rgbir_to_f16_n = 45; -} - -message MultiShapeOpParams { - enum MultiShapeMode { - batch = 0; //动态batch - resolution = 1; //动态分辨率,扩展用 - } - - MultiShapeMode mode = 1; //算子模式 - uint32 related_input_rank = 2; //新增算子插入到哪个输入 - - - repeated uint32 batch_list = 11; //batch_list值,batch_list的个数是2到8之间 -} diff --git a/proto/om.proto b/proto/om.proto deleted file mode 100644 index fbbbd16a2a628cd969fc6f4e17d649fa51a056dc..0000000000000000000000000000000000000000 --- a/proto/om.proto +++ /dev/null @@ -1,394 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; - -package domi; - -enum TargetType -{ - MINI = 0; - TINY = 1; - LITE = 2; -} - -// offline model -message ModelDef { - string name = 1; - uint32 version = 2; - - uint64 memory_size = 10; - uint32 stream_num = 11; - uint32 event_num = 12; - uint64 weight_size = 13; - uint32 label_num = 15; - repeated OpDef op = 20; - TargetType target_type = 23; - - map attr = 30; -}; - -// operator define -message OpDef { - string name = 1; - string type = 2; - - uint32 id = 3; - uint32 stream_id = 4; - - repeated string input_name = 5; - - repeated string src_name = 8; - repeated int32 src_index = 9; - repeated int64 input = 10; - repeated int64 output = 11; - repeated TensorDescriptor input_desc = 12; - repeated TensorDescriptor output_desc = 13; - repeated WeightDef weights = 14; - repeated string dst_name = 15; - repeated int32 dst_index = 16; - - repeated int64 workspace = 20; - repeated uint32 workspace_bytes = 21; - - repeated string weight_name = 22; - repeated bool is_input_const = 23; - - map attr = 30; - - QuantizeFactorParams quantize_factor = 31; - - oneof op_params { - // start at 100 here - SendOpParams sender_param = 100; - RecvOpParams receiver_param = 200; - ConvolutionOpParams convolution_param = 300; - PoolingOpParams pooling_param = 400; - EltwiseOpParams eltwise_param = 500; - BatchNormOpParams batchnorm_param = 600; - ScaleOpParams scale_param = 700; - FullConnectionOpParams full_connection_param = 800; - SoftmaxOpParams softmax_param = 900; - ActivationOpParams activation_param = 1000; - ReshapeOpParams reshape_param = 1100; - } -}; - -message SendOpParams { - uint32 event_id = 1; -}; - -message RecvOpParams { - uint32 event_id = 1; -}; - -enum QuantizeScaleType -{ - VECTOR_SCALE = 0; - SCALAR_SCALE = 1; -} - -enum QuantizeScaleMode -{ - NORMAL_MODE = 0; - SQRT_MODE = 1; -} - -enum QuantizeAlgorithm -{ - NON_OFFSET_ALGO = 0; - HALF_OFFSET_ALGO = 1; - ALL_OFFSET_ALGO = 2; -} -message QuantizeFactor -{ - QuantizeScaleMode scale_mode = 1; - bytes scale_value = 2; - int64 scale_offset = 3; - bytes offset_data_value = 4; - int64 offset_data_offset = 5; - bytes offset_weight_value = 6; - int64 offset_weight_offset = 7; - bytes offset_pad_value = 8; - int64 offset_pad_offset = 9; -}; - -message QuantizeCalcFactor -{ - bytes offsetw = 1; - int64 offsetw_offset = 2; - bytes offsetd = 3; - int64 offsetd_offset = 4; - bytes scalereq = 5; - int64 scaledreq_offset = 6; - bytes offsetdnext = 7; - int64 offsetdnext_offset = 8; -} - -message QuantizeFactorParams -{ - QuantizeAlgorithm quantize_algo = 1; - QuantizeScaleType scale_type = 2; - QuantizeFactor quantize_param = 3; - QuantizeFactor dequantize_param = 4; - QuantizeFactor requantize_param = 5; - QuantizeCalcFactor quantizecalc_param = 6; -}; - -message ConvolutionOpParams { - int32 mode = 1; - int32 algo = 2; - int32 pad_mode = 3; - uint32 group = 4; - uint32 num_output = 5; - - repeated uint32 pad = 10; - repeated uint32 stride = 11; - repeated uint32 dilation = 12; - repeated uint32 kernel = 13; - - float alpha = 20; - float beta = 21; - - WeightDef filter = 40; - WeightDef bias = 41; - - bool relu_flag = 62; - repeated uint32 adj = 70; - repeated uint32 target_shape = 71; - repeated uint32 before_pad = 72; -}; - -message PoolingOpParams { - int32 mode = 1; - int32 nan_opt = 2; - int32 pad_mode = 3; - bool global_pooling = 4; - - repeated uint32 window = 10; - repeated uint32 pad = 11; - repeated uint32 stride = 12; - bool ceil_mode = 13; - int32 data_mode = 14; - - float alpha = 20; - float beta = 21; - repeated uint32 before_pad = 22; -}; - -message EltwiseOpParams { - int32 mode = 1; - repeated float coeff = 2; - float alpha = 3; - float beta = 4; - repeated WeightDef weight = 5; - bool relu_flag = 6; -}; - -message ActivationOpParams { - int32 mode = 1; - float coef = 2; - float alpha = 3; - float beta = 4; -}; - -message BatchNormOpParams { - int32 mode = 1; - - float alpha = 2; - float beta = 3; - double epsilon = 4;//optinal,[default = 1e-5] - bool use_global_stats = 5; //optinal,by default true,testing mode - float moving_average_fraction = 6; //optinal,[default = .999]; - - WeightDef estimated_mean = 7; - WeightDef estimated_variance = 8; - - WeightDef scale = 9; - WeightDef bias = 10; -}; - -message ScaleOpParams { - WeightDef scale = 1; - WeightDef bias = 2; -}; - -message ReshapeOpParams { - float alpha = 1; - float beta = 2; - ShapeDef shape = 3; - int32 axis = 4; - int32 num_axes = 5; - int32 format = 6; -}; - -message SoftmaxOpParams { - int32 algo = 1; - int32 mode = 2; - float alpha = 3; - float beta = 4; -}; - -message FullConnectionOpParams { - WeightDef filter = 1; - WeightDef bias = 2; - uint32 num_output = 3; - bool relu_flag = 12; -}; - -message FlattenOpParams { - float alpha = 1; - float beta = 2; - int32 start_axis = 3; - int32 end_axis = 4; -} - -message AddLimitedOpParams { - float alpha = 1; - float beta = 2; - int32 axis = 3; - bool broadcast = 4; - - repeated WeightDef weight = 10; -}; - -message MulLimitedOpParams { - float alpha = 1; - float beta = 2; - int32 axis = 3; - bool broadcast = 4; - - repeated WeightDef weight = 10; -}; - -message AddOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message MulOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message SubOpParams { - float alpha = 1; - float beta = 2; - - repeated WeightDef weight = 10; -}; - -message BiasAddOpParams { - float alpha = 1; - float beta = 2; - - WeightDef bias = 10; -}; - -message MatMulOpParams { - float alpha = 1; - float beta = 2; - bool transposeX = 3; - bool transposeW = 4; - - WeightDef filter = 10; - WeightDef bias = 12; -}; - -message RsqrtOpParams { - float alpha = 1; - float beta = 2; -}; - - -message WeightDef { - int32 format = 1; - int32 data_type = 2; - ShapeDef shape = 3; - bytes data = 4; - int64 data_offset = 5; - uint32 cmps_size = 6; - bytes cmps_tab = 7; - int64 cmps_tab_offset = 10; - CompressInfo cmps_info = 8; - AllOffsetQuantizeInfo alloffset_quantize_info = 11; -} - -message ShapeDef { - repeated int64 dim = 1; -} - -enum DeviceType { - NPU = 0; // In default, we will use NPU. - CPU = 1; // CPU -} - -message AllOffsetQuantizeInfo { - float scale = 1; - int32 offset = 2; -} - -message TensorDescriptor { - int32 format = 1; - int32 data_type = 2; - repeated int64 dim = 3; - uint32 size = 4; - bool reuse_input = 5; - bool output_tensor = 7; - DeviceType device_type = 8; - bool input_tensor = 9; - uint32 real_dim_cnt = 10; - uint32 reuse_input_index = 11; - AllOffsetQuantizeInfo alloffset_quantize_info = 12; -} - -message CompressInfo { - int32 blockRow = 1; // block row - int32 blockCol = 2; // block col - int32 fractalK = 3; // fractal K - int32 fractalN = 4; // fractal N - int32 lastFractalK = 5; // K of last fractal - int32 lastFractalN = 6; // N of last fractal - int32 cubeSize = 7; // cube's length - int32 loadDir = 8; // data load directtiono 0:col load 1:row load -} - -message AttrDef { - message ListValue { - repeated string s = 2; // "list(string)" - repeated int64 i = 3 [packed = true]; // "list(int)" - repeated float f = 4 [packed = true]; // "list(float)" - repeated bool b = 5 [packed = true]; // "list(bool)" - repeated uint32 u = 6 [packed = true]; // "list(uint)" - repeated bytes bt = 7; - } - - oneof value { - string s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - uint32 u = 6; // "uint32" - bytes bt = 7; - ListValue list = 1; // any "list(...)" - NamedAttrs func = 10; - } -} - -// A list of attr names and their values. The whole list is attached -// with a string name. E.g., MatMul[T=float]. -message NamedAttrs { - string name = 1; - map attr = 2; -} - diff --git a/proto/onnx/ge_onnx.proto b/proto/onnx/ge_onnx.proto deleted file mode 100644 index bbe9da37ee9c76ef481a40263011ed089374080d..0000000000000000000000000000000000000000 --- a/proto/onnx/ge_onnx.proto +++ /dev/null @@ -1,578 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; - -package ge.onnx; - -// Overview -// -// ONNX is an open specification that is comprised of the following components: -// -// 1) A definition of an extensible computation graph model. -// 2) Definitions of standard data types. -// 3) Definitions of built-in operators. -// -// This document describes the syntax of models and their computation graphs, -// as well as the standard data types. Together, they are referred to as the ONNX -// Intermediate Representation, or 'IR' for short. -// -// The normative semantic specification of the ONNX IR is found in docs/IR.md. -// Definitions of the built-in neural network operators may be found in docs/Operators.md. - -// Notes -// -// Release -// -// We are still in the very early stage of defining ONNX. The current -// version of ONNX is a starting point. While we are actively working -// towards a complete spec, we would like to get the community involved -// by sharing our working version of ONNX. -// -// Protobuf compatibility -// -// To simplify framework compatibility, ONNX is defined using the subset of protobuf -// that is compatible with both protobuf v2 and v3. This means that we do not use any -// protobuf features that are only available in one of the two versions. -// -// Here are the most notable contortions we have to carry out to work around -// these limitations: -// -// - No 'map' (added protobuf 3.0). We instead represent mappings as lists -// of key-value pairs, where order does not matter and duplicates -// are not allowed. - - -// Versioning -// -// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md -// -// To be compatible with both proto2 and proto3, we will use a version number -// that is not defined by the default value but an explicit enum number. -enum Version { - // proto3 requires the first enum value to be zero. - // We add this just to appease the compiler. - _START_VERSION = 0; - // The version field is always serialized and we will use it to store the - // version that the graph is generated from. This helps us set up version - // control. - // For the IR, we are using simple numbers starting with with 0x00000001, - // which was the version we published on Oct 10, 2017. - IR_VERSION_2017_10_10 = 0x0000000000000001; - - // IR_VERSION 2 published on Oct 30, 2017 - // - Added type discriminator to AttributeProto to support proto3 users - IR_VERSION_2017_10_30 = 0x0000000000000002; - - // IR VERSION 3 published on Nov 3, 2017 - // - For operator versioning: - // - Added new message OperatorSetIdProto - // - Added opset_import in ModelProto - // - For vendor extensions, added domain in NodeProto - IR_VERSION_2017_11_3 = 0x0000000000000003; - - // IR VERSION 4 published on Jan 22, 2019 - // - Relax constraint that initializers should be a subset of graph inputs - // - Add type BFLOAT16 - IR_VERSION_2019_1_22 = 0x0000000000000004; - - // IR VERSION 5 published on March 18, 2019 - // - Add message TensorAnnotation. - // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters. - IR_VERSION_2019_3_18 = 0x0000000000000005; - - // IR VERSION 6 published on Sep 19, 2019 - // - Add support for sparse tensor constants stored in model. - // - Add message SparseTensorProto - // - Add sparse initializers - IR_VERSION = 0x0000000000000006; -} - -// Attributes -// -// A named attribute containing either singular float, integer, string, graph, -// and tensor values, or repeated float, integer, string, graph, and tensor values. -// An AttributeProto MUST contain the name field, and *only one* of the -// following content fields, effectively enforcing a C/C++ union equivalent. -message AttributeProto { - - // Note: this enum is structurally identical to the OpSchema::AttrType - // enum defined in schema.h. If you rev one, you likely need to rev the other. - enum AttributeType { - UNDEFINED = 0; - FLOAT = 1; - INT = 2; - STRING = 3; - TENSOR = 4; - GRAPH = 5; - SPARSE_TENSOR = 11; - - FLOATS = 6; - INTS = 7; - STRINGS = 8; - TENSORS = 9; - GRAPHS = 10; - SPARSE_TENSORS = 12; - } - - // The name field MUST be present for this version of the IR. - string name = 1; // namespace Attribute - - // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. - // In this case, this AttributeProto does not contain data, and it's a reference of attribute - // in parent scope. - // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. - string ref_attr_name = 21; - - // A human-readable documentation for this attribute. Markdown is allowed. - string doc_string = 13; - - // The type field MUST be present for this version of the IR. - // For 0.0.1 versions of the IR, this field was not defined, and - // implementations needed to use has_field hueristics to determine - // which value field was in use. For IR_VERSION 0.0.2 or later, this - // field MUST be set and match the f|i|s|t|... field in use. This - // change was made to accomodate proto3 implementations. - AttributeType type = 20; // discriminator that indicates which field below is in use - - // Exactly ONE of the following fields must be present for this version of the IR - float f = 2; // float - int64 i = 3; // int - bytes s = 4; // UTF-8 string - TensorProto t = 5; // tensor value - GraphProto g = 6; // graph - SparseTensorProto sparse_tensor = 22; // sparse tensor value - // Do not use field below, it's deprecated. - // optional ValueProto v = 12; // value - subsumes everything but graph - - repeated float floats = 7; // list of floats - repeated int64 ints = 8; // list of ints - repeated bytes strings = 9; // list of UTF-8 strings - repeated TensorProto tensors = 10; // list of tensors - repeated GraphProto graphs = 11; // list of graph - repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors -} - -// Defines information on value, including the name, the type, and -// the shape of the value. -message ValueInfoProto { - // This field MUST be present in this version of the IR. - string name = 1; // namespace Value - // This field MUST be present in this version of the IR for - // inputs and outputs of the top-level graph. - TypeProto type = 2; - // A human-readable documentation for this value. Markdown is allowed. - string doc_string = 3; -} - -// Nodes -// -// Computation graphs are made up of a DAG of nodes, which represent what is -// commonly called a "layer" or "pipeline stage" in machine learning frameworks. -// -// For example, it can be a node of type "Conv" that takes in an image, a filter -// tensor and a bias tensor, and produces the convolved output. -message NodeProto { - repeated string input = 1; // namespace Value - repeated string output = 2; // namespace Value - - // An optional identifier for this node in a graph. - // This field MAY be absent in ths version of the IR. - string name = 3; // namespace Node - - // The symbolic identifier of the Operator to execute. - string op_type = 4; // namespace Operator - // The domain of the OperatorSet that specifies the operator named by op_type. - string domain = 7; // namespace Domain - - // Additional named attributes. - repeated AttributeProto attribute = 5; - - // A human-readable documentation for this node. Markdown is allowed. - string doc_string = 6; -} - -// Models -// -// ModelProto is a top-level file/container format for bundling a ML model and -// associating its computation graph with metadata. -// -// The semantics of the model are described by the associated GraphProto. -message ModelProto { - // The version of the IR this model targets. See Version enum above. - // This field MUST be present. - int64 ir_version = 1; - - // The OperatorSets this model relies on. - // All ModelProtos MUST have at least one entry that - // specifies which version of the ONNX OperatorSet is - // being imported. - // - // All nodes in the ModelProto's graph will bind against the operator - // with the same-domain/same-op_type operator with the HIGHEST version - // in the referenced operator sets. - repeated OperatorSetIdProto opset_import = 8; - - // The name of the framework or tool used to generate this model. - // This field SHOULD be present to indicate which implementation/tool/framework - // emitted the model. - string producer_name = 2; - - // The version of the framework or tool used to generate this model. - // This field SHOULD be present to indicate which implementation/tool/framework - // emitted the model. - string producer_version = 3; - - // Domain name of the model. - // We use reverse domain names as name space indicators. For example: - // `com.facebook.fair` or `com.microsoft.cognitiveservices` - // - // Together with `model_version` and GraphProto.name, this forms the unique identity of - // the graph. - string domain = 4; - - // The version of the graph encoded. See Version enum below. - int64 model_version = 5; - - // A human-readable documentation for this model. Markdown is allowed. - string doc_string = 6; - - // The parameterized graph that is evaluated to execute the model. - GraphProto graph = 7; - - // Named metadata values; keys should be distinct. - repeated StringStringEntryProto metadata_props = 14; -}; - -// StringStringEntryProto follows the pattern for cross-proto-version maps. -// See https://developers.google.com/protocol-buffers/docs/proto3#maps -message StringStringEntryProto { - string key = 1; - string value= 2; -}; - -message TensorAnnotation { - string tensor_name = 1; - // pairs to annotate tensor specified by above. - // The keys used in the mapping below must be pre-defined in ONNX spec. - // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as - // quantization parameter keys. - repeated StringStringEntryProto quant_parameter_tensor_names = 2; -} - - - -// Graphs -// -// A graph defines the computational logic of a model and is comprised of a parameterized -// list of nodes that form a directed acyclic graph based on their inputs and outputs. -// This is the equivalent of the "network" or "graph" in many deep learning -// frameworks. -message GraphProto { - // The nodes in the graph, sorted topologically. - repeated NodeProto node = 1; - - // The name of the graph. - string name = 2; // namespace Graph - - // A list of named tensor values, used to specify constant inputs of the graph. - // Each TensorProto entry must have a distinct name (within the list) that - // MAY also appear in the input list. - repeated TensorProto initializer = 5; - - // Initializers (see above) stored in sparse format. - repeated SparseTensorProto sparse_initializer = 15; - - // A human-readable documentation for this graph. Markdown is allowed. - string doc_string = 10; - - // The inputs and outputs of the graph. - repeated ValueInfoProto input = 11; - repeated ValueInfoProto output = 12; - - // Information for the values in the graph. The ValueInfoProto.name's - // must be distinct. It is optional for a value to appear in value_info list. - repeated ValueInfoProto value_info = 13; - - // This field carries information to indicate the mapping among a tensor and its - // quantization parameter tensors. For example: - // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated, - // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model. - repeated TensorAnnotation quantization_annotation = 14; - - // DO NOT USE the following fields, they were deprecated from earlier versions. - // repeated string input = 3; - // repeated string output = 4; - // optional int64 ir_version = 6; - // optional int64 producer_version = 7; - // optional string producer_tag = 8; - // optional string domain = 9; -} - -// Tensors -// -// A serialized tensor value. -message TensorProto { - enum DataType { - UNDEFINED = 0; - // Basic types. - FLOAT = 1; // float - UINT8 = 2; // uint8_t - INT8 = 3; // int8_t - UINT16 = 4; // uint16_t - INT16 = 5; // int16_t - INT32 = 6; // int32_t - INT64 = 7; // int64_t - STRING = 8; // string - BOOL = 9; // bool - - // IEEE754 half-precision floating-point format (16 bits wide). - // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits. - FLOAT16 = 10; - - DOUBLE = 11; - UINT32 = 12; - UINT64 = 13; - COMPLEX64 = 14; // complex with float32 real and imaginary components - COMPLEX128 = 15; // complex with float64 real and imaginary components - - // Non-IEEE floating-point format based on IEEE754 single-precision - // floating-point number truncated to 16 bits. - // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. - BFLOAT16 = 16; - - // Non-IEEE floating-point format based on papers - // FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433, - // 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf. - // Operators supported FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear. - // The computation usually happens inside a block quantize / dequantize - // fused by the runtime. - FLOAT8E4M3FN = 17; // float 8, mostly used for coefficients, supports nan, not inf - FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients - - // Future extensions go here. - } - - // The shape of the tensor. - repeated int64 dims = 1; - - // The data type of the tensor. - // This field MUST have a valid TensorProto.DataType value - int32 data_type = 2; - - // For very large tensors, we may want to store them in chunks, in which - // case the following fields will specify the segment that is stored in - // the current TensorProto. - message Segment { - int64 begin = 1; - int64 end = 2; - } - Segment segment = 3; - - // Tensor content must be organized in row-major order. - // - // Depending on the data_type field, exactly one of the fields below with - // name ending in _data is used to store the elements of the tensor. - - // For float and complex64 values - // Complex64 tensors are encoded as a single array of floats, - // with the real components appearing in odd numbered positions, - // and the corresponding imaginary component apparing in the - // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] - // is encoded as [1.0, 2.0 ,3.0 ,4.0] - // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. - repeated float float_data = 4 [packed = true]; - - // For int32, uint8, int8, uint16, int16, bool, and float16 values - // float16 values must be bit-wise converted to an uint16_t prior - // to writing to the buffer. - // When this field is present, the data_type field MUST be - // INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16 - repeated int32 int32_data = 5 [packed = true]; - - // For strings. - // Each element of string_data is a UTF-8 encoded Unicode - // string. No trailing null, no leading BOM. The protobuf "string" - // scalar type is not used to match ML community conventions. - // When this field is present, the data_type field MUST be STRING - repeated bytes string_data = 6; - - // For int64. - // When this field is present, the data_type field MUST be INT64 - repeated int64 int64_data = 7 [packed = true]; - - // Optionally, a name for the tensor. - string name = 8; // namespace Value - - // A human-readable documentation for this tensor. Markdown is allowed. - string doc_string = 12; - - // Serializations can either use one of the fields above, or use this - // raw bytes field. The only exception is the string case, where one is - // required to store the content in the repeated bytes string_data field. - // - // When this raw_data field is used to store tensor value, elements MUST - // be stored in as fixed-width, little-endian order. - // Floating-point data types MUST be stored in IEEE 754 format. - // Complex64 elements must be written as two consecutive FLOAT values, real component first. - // Complex128 elements must be written as two consecutive DOUBLE values, real component first. - // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). - // - // Note: the advantage of specific field rather than the raw_data field is - // that in some cases (e.g. int data), protobuf does a better packing via - // variable length storage, and may lead to smaller binary footprint. - // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED - bytes raw_data = 9; - - // Data can be stored inside the protobuf file using type-specific fields or raw_data. - // Alternatively, raw bytes data can be stored in an external file, using the external_data field. - // external_data stores key-value pairs describing data location. Recognized keys are: - // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX - // protobuf model was stored - // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string. - // Offset values SHOULD be multiples 4096 (page size) to enable mmap support. - // - "length" (optional) - number of bytes containing data. Integer stored as string. - // - "checksum" (optional) - SHA1 digest of file specified in under 'location' key. - repeated StringStringEntryProto external_data = 13; - - // Location of the data for this tensor. MUST be one of: - // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field. - // - EXTERNAL - data stored in an external location as described by external_data field. - enum DataLocation { - DEFAULT = 0; - EXTERNAL = 1; - } - - // If value not set, data is stored in raw_data (if set) otherwise in type-specified field. - DataLocation data_location = 14; - - // For double - // Complex128 tensors are encoded as a single array of doubles, - // with the real components appearing in odd numbered positions, - // and the corresponding imaginary component apparing in the - // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] - // is encoded as [1.0, 2.0 ,3.0 ,4.0] - // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 - repeated double double_data = 10 [packed = true]; - - // For uint64 and uint32 values - // When this field is present, the data_type field MUST be - // UINT32 or UINT64 - repeated uint64 uint64_data = 11 [packed = true]; -} - -// A serialized sparse-tensor value -message SparseTensorProto { - // The sequence of non-default values are encoded as a tensor of shape [NNZ]. - // The default-value is zero for numeric tensors, and empty-string for string tensors. - TensorProto values = 1; - - // The indices of the non-default values, which may be stored in one of two formats. - // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value - // corresponding to the j-th index of the i-th value (in the values tensor). - // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value - // must be the linearized-index of the i-th value (in the values tensor). - // The linearized-index can be converted into an index tuple (k_1,...,k_rank) - // using the shape provided below. - // The indices must appear in ascending order without duplication. - // In the first format, the ordering is lexicographic-ordering: - // e.g., index-value [1,4] must appear before [2,1] - TensorProto indices = 2; - - // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank] - repeated int64 dims = 3; -} - -// Defines a tensor shape. A dimension can be either an integer value -// or a symbolic variable. A symbolic variable represents an unknown -// dimension. -message TensorShapeProto { - message Dimension { - oneof value { - int64 dim_value = 1; - string dim_param = 2; // namespace Shape - }; - // Standard denotation can optionally be used to denote tensor - // dimensions with standard semantic descriptions to ensure - // that operations are applied to the correct axis of a tensor. - // Refer to https://github.com/onnx/onnx/blob/master/docs/DimensionDenotation.md#denotation-definition - // for pre-defined dimension denotations. - string denotation = 3; - }; - repeated Dimension dim = 1; -} - -// Types -// -// The standard ONNX data types. -message TypeProto { - - message Tensor { - // This field MUST NOT have the value of UNDEFINED - // This field MUST have a valid TensorProto.DataType value - // This field MUST be present for this version of the IR. - int32 elem_type = 1; - TensorShapeProto shape = 2; - } - - // repeated T - message Sequence { - // The type and optional shape of each element of the sequence. - // This field MUST be present for this version of the IR. - TypeProto elem_type = 1; - }; - - // map - message Map { - // This field MUST have a valid TensorProto.DataType value - // This field MUST be present for this version of the IR. - // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING - int32 key_type = 1; - // This field MUST be present for this version of the IR. - TypeProto value_type = 2; - }; - - oneof value { - // The type of a tensor. - Tensor tensor_type = 1; - - // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values - // as input and output to graphs and nodes. These types are needed to naturally - // support classical ML operators. DNN operators SHOULD restrict their input - // and output types to tensors. - - // The type of a sequence. - Sequence sequence_type = 4; - - // The type of a map. - Map map_type = 5; - - } - - // An optional denotation can be used to denote the whole - // type with a standard semantic description as to what is - // stored inside. Refer to https://github.com/onnx/onnx/blob/master/docs/TypeDenotation.md#type-denotation-definition - // for pre-defined type denotations. - string denotation = 6; -} - -// Operator Sets -// -// OperatorSets are uniquely identified by a (domain, opset_version) pair. -message OperatorSetIdProto { - // The domain of the operator set being identified. - // The empty string ("") or absence of this field implies the operator - // set that is defined as part of the ONNX specification. - // This field MUST be present in this version of the IR when referring to any other operator set. - string domain = 1; - - // The version of the operator set being identified. - // This field MUST be present in this version of the IR. - int64 version = 2; -} diff --git a/proto/op_mapping.proto b/proto/op_mapping.proto deleted file mode 100644 index ade24b352542ffaee86bd0c28653d4fe53f56020..0000000000000000000000000000000000000000 --- a/proto/op_mapping.proto +++ /dev/null @@ -1,143 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; -package toolkit.aicpu.dump; - -message Shape { - repeated uint64 dim = 1; -} - -enum AddressType { - TRADITIONAL_ADDR = 0; - NOTILING_ADDR = 1; - RAW_ADDR = 2; - NANO_IO_ADDR = 3; - NANO_WEIGHT_ADDR = 4; - NANO_WORK_ADDR = 5; -} - -message Output { - int32 data_type = 1; - int32 format = 2; - Shape shape = 3; - uint64 address = 4; - string original_name = 5; - int32 original_output_index = 6; - int32 original_output_data_type = 7; - int32 original_output_format = 8; - uint64 size = 9; - Shape origin_shape = 10; - AddressType addr_type = 11; - uint64 offset = 12; -} - -message Input { - int32 data_type =1; - int32 format = 2; - Shape shape = 3; - uint64 address = 4; - uint64 size = 5; - Shape origin_shape = 6; - AddressType addr_type = 7; - uint64 offset = 8; -} - -enum BufferType { - L1 = 0; -} - -message OpBuffer { - BufferType buffer_type = 1; - uint64 address = 2; - uint64 size = 3; -} - -message Op { - string op_name = 1; - string op_type = 2; -} - -message OpAttr { - string name = 1; - string value = 2; -} - -message Workspace { - enum SpaceType { - LOG = 0; - } - SpaceType type = 1; - uint64 data_addr = 2; - uint64 size = 3; -} - -message RealAddressAndSize { - uint64 address = 1; - uint64 size = 2; -} - -message Context { - uint32 context_id = 1; - uint32 thread_id = 2; - repeated RealAddressAndSize input = 3; - repeated RealAddressAndSize output = 4; -} - -message Task { - enum TaskType { - AICORE = 0; - AICPU = 1; - DEBUG = 2; - SDMA = 3; - FFTSPLUS = 4; - DSA = 5; - } - uint32 task_id = 1; - uint32 stream_id = 2; - Op op = 3; - repeated Output output = 4; - bool end_graph = 5; - repeated Input input = 6; - repeated OpBuffer buffer = 7; - TaskType task_type = 8; - uint32 context_id = 9; - repeated OpAttr attr = 10; - repeated Workspace space = 11; - repeated Context context = 12; - uint32 thread_id = 13; -} - -enum DumpData { - TENSOR_DUMP_DATA = 0; - STATS_DUMP_DATA = 1; -} - -message OpMappingInfo { - string dump_path = 1; - oneof model_name_param { - string model_name = 2; - } - oneof model_id_param { - uint32 model_id = 3; - } - oneof step_id { - uint64 step_id_addr = 4; - } - oneof iterations_per_loop { - uint64 iterations_per_loop_addr = 5; - } - oneof loop_cond { - uint64 loop_cond_addr = 6; - } - uint32 flag = 7; // 0x01 load, 0x00 unload - repeated Task task = 8; - string dump_step = 9; - DumpData dump_data = 10; -} diff --git a/proto/optimizer_priority.proto b/proto/optimizer_priority.proto deleted file mode 100644 index d2593e7d107b12aeac4ce4af0d39342dee44f845..0000000000000000000000000000000000000000 --- a/proto/optimizer_priority.proto +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; -package ge.optimizers; - -// Default: GE>FE>AICPU -message Priority{ - repeated string optimizer = 1; -} diff --git a/proto/stub/Makefile b/proto/stub/Makefile deleted file mode 100644 index ff94976d7782bd88ffebb117e7c68b15ca434c8b..0000000000000000000000000000000000000000 --- a/proto/stub/Makefile +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) 2024 Huawei Technologies Co., Ltd. -# This file is a part of the CANN Open Software. -# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. -# ====================================================================================================================== - -inc_path := $(shell pwd)/metadef/inc/external/ -out_path := $(shell pwd)/out/opcommon/lib64/stub/ -stub_path := $(shell pwd)/metadef/opcommon/stub/ - -mkdir_stub := $(shell mkdir -p $(out_path)) -opcommon_local_stub := $(shell $(HI_PYTHON) $(stub_path)/gen_stubapi.py $(inc_path) $(out_path)) diff --git a/proto/stub/gen_stubapi.py b/proto/stub/gen_stubapi.py deleted file mode 100644 index 7e468a134a98252ee7046d4f9fd3a645846ab8f7..0000000000000000000000000000000000000000 --- a/proto/stub/gen_stubapi.py +++ /dev/null @@ -1,613 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: UTF-8 -*- -#------------------------------------------------------------------- -# Copyright (c) 2024 Huawei Technologies Co., Ltd. -# This file is a part of the CANN Open Software. -# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. -# ====================================================================================================================== - -import os -import re -import sys -import logging - -""" - generate stub func body by return type -""" -RETURN_STATEMENTS = { - 'ge::graphStatus': - ' std::cout << "[ERROR]: stub library libop_common cannot be used for execution, please check your "\n' - ' << "environment variables and compilation options to make sure you use the correct library."\n' - ' << std::endl;\n' - ' return ge::GRAPH_FAILED;' -} - -""" - white_list_for_debug, include_dir_key_words is to - determines which header files to generate cc files from - when DEBUG on -""" -white_list_for_debug = ["common_infershape_fns.h"] -include_dir_key_words = ["op_common"] - -""" - this attr is used for symbol table visible -""" -GE_ATTR = 'GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY' - -""" - max code len per line in hua_wei software programming specifications -""" -MAX_CODE_LEN_PER_LINE = 100 - -DEBUG = True - -logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] [%(lineno)s] %(levelname)s: %(message)s', - level=logging.INFO) - - -def need_generate_func(func_line): - """ - :param func_line: - :return: - """ - if func_line.strip().endswith("default") or func_line.strip().endswith("delete") \ - or func_line.strip().startswith("typedef") or func_line.strip().startswith("using"): - return False - return True - - -def file_endswith_white_list_suffix(file): - """ - :param file: - :return: - """ - if DEBUG: - for suffix in white_list_for_debug: - suffix = re.sub(r'^/*', '/', suffix) - if file.endswith(suffix): - return True - return False - else: - return True - - -""" - belows are patterns used for analyse .h file -""" -# pattern function -pattern_func = re.compile(r"""(^[\s]*)([a-zA-Z~_].*[)](?!.*{).*)(;.*)\n$""", re.VERBOSE | re.MULTILINE | re.DOTALL) -# pattern comment -pattern_comment = re.compile(r'^\s*//') -pattern_comment_2_start = re.compile(r'^\s*/[*]') -pattern_comment_2_end = re.compile(r'[*]/\s*$') -# pattern define -pattern_define = re.compile(r'^\s*#define') -pattern_define_return = re.compile(r'\\\s*$') -# pattern static_assert -pattern_static_assert = re.compile(r'^\s*static_assert') -pattern_static_assert_return = re.compile(r'\);\s*$') -# blank line -pattern_blank_line = re.compile(r'^\s*$') -# virtual,explicit,friend,static -pattern_keyword = re.compile(r'(virtual\s+|explicit\s+|friend\s+|static\s+)') -# lead space -pattern_leading_space = re.compile(r'(^[\s]*)[a-zA-Z~_]') -# functions will have patterns such as func ( or func( -# but operator is an exception; the class name is preceded by an operator, and the above mode does not exist -# format like :"operator = ()" -pattern_func_name = re.compile(r'([a-zA-Z0-9~_\-]+\s*|operator?.*)[(]') -# template -pattern_template = re.compile(r'^\s*template') -pattern_template_end = re.compile(r'>\s*$') -# namespace -pattern_namespace = re.compile(r'namespace.*{') -# class : which can handle classA a and {not on the same line, but if found ';' after class,then don't deal with -pattern_class = re.compile(r'^[\s]*(class|struct)\s+(%s\s+)?([a-zA-Z0-9_\-]+ - | - std::(?:vector|shared_ptr)> - | - std::(?:map|unordered_map|pair)<[:\w]+[, ]+[:\w]+> - ) - ) - ([ ]+) - ([&*]+)""", re.VERBOSE) -# pattern for parsing ret_type & func_name -pat_search_func = re.compile(r"""^(?:const[ ]+)? - (?P - (?: - [:\w]+ - | - std::(?:vector|shared_ptr)<[:\w ]+> - | - std::(?:vector|shared_ptr)> - | - std::(?:map|unordered_map|pair)<[:\w]+[, ]+[:\w]+> - ) - (?:[&*]+)? - ) - [ ]+ - (?P\w+) - :: - \n? - (?P\w+|operator=) - [ ]* - \(""", re.VERBOSE) - - -class H2CC(object): - def __init__(self, input_file, output_file, shared_includes_content): - """ - :param input_file: - :param output_file: - :param shared_includes_content: - """ - self.input_file = input_file - self.output_file = output_file - self.shared_includes_content = shared_includes_content - self.line_index = 0 - self.input_fd = open(self.input_file, 'r') - self.input_content = self.input_fd.readlines() - self.output_fd = open(self.output_file, 'w') - - # The state may be normal_now(in the middle of {}),class_now,namespace_now - self.stack = [] - self.stack_class = [] - self.stack_template = [] - # record funcs generated by h2cc func - self.func_list_exist = [] - - def __del__(self): - self.input_fd.close() - self.output_fd.close() - del self.stack - del self.stack_class - del self.stack_template - del self.func_list_exist - - def just_skip(self): - # skip blank line or comment - if pattern_blank_line.search(self.input_content[self.line_index]) or pattern_comment.search( - self.input_content[self.line_index]): # /n or comment using // - self.line_index += 1 - if pattern_comment_2_start.search(self.input_content[self.line_index]): # comment using /* - while not pattern_comment_2_end.search(self.input_content[self.line_index]): # */ - self.line_index += 1 - self.line_index += 1 - # skip define - if pattern_define.search(self.input_content[self.line_index]): - while pattern_blank_line.search(self.input_content[self.line_index]) or pattern_define_return.search( - self.input_content[self.line_index]): - self.line_index += 1 - self.line_index += 1 - # skip static_assert - if pattern_static_assert.search(self.input_content[self.line_index]): - while not pattern_static_assert_return.search(self.input_content[self.line_index]): - self.line_index += 1 - self.line_index += 1 - - def write_inc_content(self): - for shared_include_content in self.shared_includes_content: - self.output_fd.write(shared_include_content) - - def h2cc(self): - """ - :return: - """ - logging.info("start generate cc_file[%s] from h_file[%s]", self.output_file, self.input_file) - # write inc content - self.write_inc_content() - # core processing cycle, process the input .h file by line - while self.line_index < len(self.input_content): - # handle comment and blank line - self.just_skip() - - # match namespace - self.handle_namespace() - - # match template - template_string = self.handle_template() - # match class - line = self.input_content[self.line_index] - match_class = pattern_class.search(line) - match_start = pattern_start.search(line) - handle_class_result = self.handle_class(template_string, line, match_start, match_class) - if handle_class_result == "continue": - continue - - # match "}" - handle_stack_result = self.handle_stack(match_start) - if handle_stack_result == "continue": - continue - # handle func - handle_func1_result, line, start_i = self.handle_func1(line) - if handle_func1_result == "continue": - continue - - # here means func is found - # delete key word - line = pattern_keyword.sub('', line) - logging.info("line[%s]", line) - - # Class member function - # if friend we will not add class name - friend_match = re.search('friend ', line) - if len(self.stack_class) > 0 and not friend_match: - line, func_name = self.handle_class_member_func(line, template_string) - # Normal functions - else: - line, func_name = self.handle_normal_func(line, template_string) - - need_generate = need_generate_func(line) - # func body - line += self.implement_function(line) - # comment - line = self.gen_comment(start_i) + line - # write to out file - self.write_func_content(line, func_name, need_generate) - # next loop - self.line_index += 1 - - logging.info('Added %s functions', len(self.func_list_exist)) - logging.info('Successfully converted,please see %s', self.output_file) - - def handle_func1(self, line): - """ - :param line: - :return: - """ - find1 = re.search('[(]', line) - if not find1: - self.line_index += 1 - return "continue", line, None - find2 = re.search('[)]', line) - start_i = self.line_index - space_match = pattern_leading_space.search(line) - # deal with - # int abc(int a, - # int b) - if find1 and (not find2): - self.line_index += 1 - line2 = self.input_content[self.line_index] - if space_match: - line2 = re.sub('^' + space_match.group(1), '', line2) - line += line2 - while self.line_index < len(self.input_content) and (not re.search('[)]', line2)): - self.line_index += 1 - line2 = self.input_content[self.line_index] - line2 = re.sub('^' + space_match.group(1), '', line2) - line += line2 - - match_start = pattern_start.search(self.input_content[self.line_index]) - match_end = pattern_end.search(self.input_content[self.line_index]) - if match_start: # like ) { or ) {} int the last line - if not match_end: - self.stack.append('normal_now') - ii = start_i - while ii <= self.line_index: - ii += 1 - self.line_index += 1 - return "continue", line, start_i - logging.info("line[%s]", line) - # ' int abc();'->'int abc()' - (line, match) = pattern_func.subn(r'\2\n', line) - logging.info("line[%s]", line) - # deal with case of 'return type' and 'func_name' are not in the same line, like: 'int \n abc(int a, int b)' - if re.search(r'^\s*(inline)?\s*[a-zA-Z0-9_]+\s*$', self.input_content[start_i - 1]): - line = self.input_content[start_i - 1] + line - line = line.lstrip() - if not match: - self.line_index += 1 - return "continue", line, start_i - return "pass", line, start_i - - def handle_stack(self, match_start): - """ - :param match_start: - :return: - """ - line = self.input_content[self.line_index] - match_end = pattern_end.search(line) - if match_start: - self.stack.append('normal_now') - if match_end: - top_status = self.stack.pop() - if top_status == 'namespace_now': - self.output_fd.write(line + '\n') - elif top_status == 'class_now': - self.stack_class.pop() - self.stack_template.pop() - if match_start or match_end: - self.line_index += 1 - return "continue" - - if len(self.stack) > 0 and self.stack[-1] == 'normal_now': - self.line_index += 1 - return "continue" - return "pass" - - def handle_class(self, template_string, line, match_start, match_class): - """ - :param template_string: - :param line: - :param match_start: - :param match_class: - :return: - """ - if not match_class: # we face a class - return "pass" - self.stack_template.append(template_string) - self.stack.append('class_now') - class_name = match_class.group(3) - - # class template specializations: class A > - if '<' in class_name: - k = line.index('<') - fit = 1 - for ii in range(k + 1, len(line)): - if line[ii] == '<': - fit += 1 - if line[ii] == '>': - fit -= 1 - if fit == 0: - break - class_name += line[k + 1:ii + 1] - logging.info('class_name[%s]', class_name) - self.stack_class.append(class_name) - while not match_start: - self.line_index += 1 - line = self.input_content[self.line_index] - match_start = pattern_start.search(line) - self.line_index += 1 - return "continue" - - def handle_template(self): - line = self.input_content[self.line_index] - match_template = pattern_template.search(line) - template_string = '' - if match_template: - match_template_end = pattern_template_end.search(line) - template_string = line - while not match_template_end: - self.line_index += 1 - line = self.input_content[self.line_index] - template_string += line - match_template_end = pattern_template_end.search(line) - self.line_index += 1 - return template_string - - def handle_namespace(self): - line = self.input_content[self.line_index] - match_namespace = pattern_namespace.search(line) - if match_namespace: # we face namespace - self.output_fd.write(line + '\n') - self.stack.append('namespace_now') - self.line_index += 1 - - def handle_normal_func(self, line, template_string): - template_line = '' - self.stack_template.append(template_string) - if self.stack_template[-1] != '': - template_line = re.sub(r'\s*template', 'template', self.stack_template[-1]) - # change '< class T = a, class U = A(3)>' to '' - template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line) - template_line = re.sub(r'\s*=.*,', ',', template_line) - template_line = re.sub(r'\s*=.*', '', template_line) - line = re.sub(r'\s*=.*,', ',', line) - line = re.sub(r'\s*=.*\)', ')', line) - line = template_line + line - self.stack_template.pop() - func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group() - logging.info("line[%s]", line) - logging.info("func_name[%s]", func_name) - return line, func_name - - def handle_class_member_func(self, line, template_string): - template_line = '' - x = '' - if template_string != '': - template_string = re.sub(r'\s*template', 'template', template_string) - template_string = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_string) - template_string = re.sub(r'\s*=.*,', ',', template_string) - template_string = re.sub(r'\s*=.*', '', template_string) - if self.stack_template[-1] != '': - if not (re.search(r'<\s*>', stack_template[-1])): - template_line = re.sub(r'^\s*template', 'template', stack_template[-1]) - if not (re.search(r'<.*>', self.stack_class[-1])): - # for x we get like template -> - x = re.sub(r'template\s*<', '<', template_line) # remove template -> - x = re.sub(r'\n', '', x) - x = re.sub(r'\s*=.*,', ',', x) - x = re.sub(r'\s*=.*\>', '>', x) - x = x.rstrip() # remove \n - x = re.sub(r'(class|typename)\s+|(|\s*class)', '', - x) # remove class,typename -> - x = re.sub(r'<\s+', '<', x) - x = re.sub(r'\s+>', '>', x) - x = re.sub(r'\s+,', ',', x) - x = re.sub(r',\s+', ', ', x) - line = re.sub(r'\s*=\s+0', '', line) - line = re.sub(r'\s*=\s+.*,', ',', line) - line = re.sub(r'\s*=\s+.*\)', ')', line) - logging.info("x[%s]\nline[%s]", x, line) - # if the function is long, void ABC::foo() - # breaks into two lines void ABC::\n foo() - rep_fmt = '%s%s::{}%s' % (self.stack_class[-1], x, r'\1(') - temp_line = pattern_func_name.sub(rep_fmt.format(''), line, count=1) - if len(temp_line) > MAX_CODE_LEN_PER_LINE: - line = pattern_func_name.sub(rep_fmt.format('\n'), line, count=1) - else: - line = temp_line - logging.info("line[%s]", line) - # add template as the above if there is one - template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line) - template_line = re.sub(r'\s*=.*,', ',', template_line) - template_line = re.sub(r'\s*=.*', '', template_line) - line = template_line + template_string + line - func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group() - logging.info("line[%s]", line) - logging.info("func_name[%s]", func_name) - return line, func_name - - def write_func_content(self, content, func_name, need_generate): - if not (func_name in self.func_list_exist) and need_generate: - self.output_fd.write(content) - self.func_list_exist.append(func_name) - logging.info('add func:[%s]', func_name) - - def gen_comment(self, start_i): - comment_line = '' - # Function comments are on top of function declarations, copy them over - k = start_i - 1 # one line before this func start - if pattern_template.search(self.input_content[k]): - k -= 1 - if pattern_comment_2_end.search(self.input_content[k]): - comment_line = self.input_content[k].lstrip() - while not pattern_comment_2_start.search(self.input_content[k]): - k -= 1 - comment_line = self.input_content[k].lstrip() + comment_line - else: - for j in range(k, 0, -1): - c_line = self.input_content[j] - if pattern_comment.search(c_line): - c_line = re.sub(r'\s*//', '//', c_line) - comment_line = c_line + comment_line - else: - break - return comment_line - - @staticmethod - def get_return_statements(func): - func = pat_format_func.sub(r'\1\3\2', func) - m = pat_search_func.search(func) - if not m: - return None - logging.info('ret_type: %s, class_name: %s, func_name: %s', *m.group('ret_type', 'class_name', 'func_name')) - type_cls_func_name = '%s %s::%s' % m.group('ret_type', 'class_name', 'func_name') - if type_cls_func_name in RETURN_STATEMENTS: - logging.info('type_cls_func_name:[%s] matched!', type_cls_func_name) - return RETURN_STATEMENTS[type_cls_func_name] - type_cls_name = '%s %s::' % m.group('ret_type', 'class_name') - if type_cls_name in RETURN_STATEMENTS: - logging.info('type_cls_name:[%s] matched!', type_cls_name) - return RETURN_STATEMENTS[type_cls_name] - type_only = m.group('ret_type') - if type_only in RETURN_STATEMENTS: - logging.info('type_only:[%s] matched!', type_only) - return RETURN_STATEMENTS[type_only] - return None - - @staticmethod - def implement_function(func): - function_def = '' - function_def += '{\n' - - return_statements = H2CC.get_return_statements(func) - if return_statements is not None: - function_def += return_statements - else: - all_items = func.split() - start = 0 - return_type = all_items[start] - if return_type == "const": - start += 1 - return_type = all_items[start] - if return_type.startswith(('std::map', 'std::set', 'std::vector')): - return_type = "std::map" - if return_type.endswith('*') or ( - len(all_items) > start + 1 and all_items[start + 1].startswith('*')) or return_type.startswith( - 'std::unique_ptr'): - return_type = "Ptr" - if len(all_items) > start + 1 and all_items[start + 1].startswith('&'): - return_type += "&" - if RETURN_STATEMENTS.__contains__(return_type): - function_def += RETURN_STATEMENTS[return_type] - else: - logging.info("Unhandled func[%s]", func) - logging.warning("Unhandled return type[%s]", return_type) - - function_def += '\n' - function_def += '}\n' - function_def += '\n' - return function_def - - -def collect_header_files(path): - """ - :param path: - :return: - """ - header_files = [] - shared_includes_content = [] - for root, dirs, files in os.walk(path): - files.sort() - dirs.sort() - for file in files: - if file.find("git") >= 0: - continue - if not file.endswith('.h'): - continue - file_path = os.path.join(root, file) - file_path = file_path.replace('\\', '/') - header_files.append(file_path) - include_str = '#include "{}"\n'.format(file_path[path.rindex('/') + 1:]) - shared_includes_content.append(include_str) - # for acl error code - shared_includes_content.append('#include \n') - return header_files, shared_includes_content - - -def generate_stub_file(inc_dir, out_cc_dir): - """ - :param inc_dir: - :param out_cc_dir: - :return: - """ - target_header_files, shared_includes_content = collect_header_files(inc_dir) - for header_file in target_header_files: - if not file_endswith_white_list_suffix(header_file): - continue - cc_file = re.sub(r'([^/]+)\.h$', r'stub_\1.cc', header_file) - h_2_cc = H2CC(header_file, out_cc_dir + cc_file[cc_file.rindex('/') + 1:], shared_includes_content) - h_2_cc.h2cc() - - -def gen_code(inc_dir, out_cc_dir): - """ - :param inc_dir: - :param out_cc_dir: - :return: - """ - if not inc_dir.endswith('/'): - inc_dir += '/' - if not out_cc_dir.endswith('/'): - out_cc_dir += '/' - for include_dir_key_word in include_dir_key_words: - generate_stub_file(inc_dir + include_dir_key_word, out_cc_dir) - - -def main(): - if len(sys.argv) != 3: - logging.error("script %s must have 2 input parameters!", sys.argv[0]) - return - inc_dir = sys.argv[1] - out_cc_dir = sys.argv[2] - gen_code(inc_dir, out_cc_dir) - - -if __name__ == '__main__': - main() diff --git a/proto/task.proto b/proto/task.proto deleted file mode 100644 index 06ab90fd5fe1ec36428434167611abd8b18d6785..0000000000000000000000000000000000000000 --- a/proto/task.proto +++ /dev/null @@ -1,1061 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; - -package domi; - -message ModelTaskDef { - string version = 1; - - map attr = 9; // Extended field - repeated TaskDef task = 10; - - uint64 memory_size = 11; - uint32 stream_num = 12; - uint32 event_num = 13; - uint64 weight_size = 14; - - repeated bytes op = 15; // input/output opdef in bytes - - uint64 base_addr = 16; // base addr - uint64 weight_addr = 17; // weight addr - uint32 batch_num = 18; -} - - -message TaskDef { - uint32 id = 1; - uint32 type = 2; - - uint32 stream_id = 10; - uint32 event_id = 11; - uint32 notify_id = 12; - uint32 sqe_num = 13; - - KernelDef kernel = 20; - KernelExDef kernel_ex = 21; - KernelHcclDef kernel_hccl = 25; - EventExDef event_ex = 26; - LogTimeStampDef log_timestamp = 28; - - uint32 label_id = 30; - - MemcpyAsyncDef memcpy_async = 31; - StreamSwitchDef stream_switch = 32; - StreamActiveDef stream_active = 33; - bytes private_def = 34; - uint64 ops_kernel_store_ptr = 35; // adjustments to other fields in the future - StreamSwitchNDef stream_switch_n = 36; - - LabelSetDef label_set = 37; - LabelGotoExDef label_goto_ex = 38; - LabelSwitchByIndexDef label_switch_by_index = 39; - KernelDefWithHandle kernel_with_handle = 40; - FftsTaskDef ffts_task = 41; - FftsPlusTaskDef ffts_plus_task = 42; - DSATaskDef dsa_task = 43; - CmoTaskDef cmo_task = 44; - CmoBarrierTaskDef cmo_barrier_task = 45; - NpuGetFloatStatusDef npu_get_float_status = 46; - NpuClearFloatStatusDef npu_clear_float_status = 47; - DvppTaskDef dvpp_task = 48; - NpuGetFloatDebugStatusDef npu_get_float_debug_status = 49; - NpuClearFloatDebugStatusDef npu_clear_float_debug_status = 50; - CmoAddrTaskDef cmo_addr_task = 51; - UpdatePcTaskDef update_pc_task = 52; - FusionTaskDef fusion_task = 53; -} - -message KernelDef { - KernelContext context = 1; - - string stub_func = 10; - uint32 block_dim = 11; - uint32 args_size = 12; - bytes args = 13; - bytes sm_desc = 14; - bytes flowtable = 15; - string so_name = 16; - string kernel_name = 17; - bytes kernel_ext_info = 18; - uint32 kernel_ext_info_size = 19; - repeated ArgsInfo args_info = 20; - uint32 schedule_mode = 21; - uint32 block_dim_offset = 22; // vector core offset -} - -message KernelDefWithHandle { - KernelContext context = 1; - - uint64 handle = 10; - string dev_func = 11; - uint32 block_dim = 12; - uint32 args_size = 13; - bytes args = 14; - bytes sm_desc = 15; - string original_kernel_key = 16; - string node_info = 17; - repeated ArgsInfo args_info = 18; - uint32 schedule_mode = 19; - uint32 block_dim_offset = 20; // vector core offset -} - -message KernelContext { - uint32 kernel_type = 1; - uint32 op_id = 2; // OP type in CCE - uint32 kernel_func_id = 3; - uint32 op_index = 4; // TE/Custom operator - bool is_flowtable = 5; // Identify whether args is a flowtable structure - bytes args_offset = 6; // args offset information - uint32 args_count = 7; // args count - string args_format = 10; - repeated uint32 origin_op_index = 8; -} - -message ArgsInfo { - enum ArgsType { - INPUT = 0; - OUTPUT = 1; - } - enum ArgsFormat { - DIRECT_ADDR = 0; - SECONDARY_ADDR = 1; - } - ArgsType arg_type = 1; // 表示args内存中的输入、输出等类型 - ArgsFormat arg_format = 2; // 一级指针 还是 二级指针 - int32 start_index = 3; // -1代表没有实际连边 - uint32 size = 4; // 实际数据个数 -} - -message KernelExDef { - uint32 flags = 1; - - uint32 op_index = 4; - uint32 args_size = 12; - bytes args = 13; - bytes task_info = 14; // serialized nodeDef, funcDef, inputoutput - uint32 task_info_size = 15; - bytes kernel_ext_info = 16; - uint32 kernel_ext_info_size = 17; -} - - -message KernelHcclDef { - uint32 op_index = 8; - string hccl_type = 9; - repeated int32 input_zero_copy_flag = 10; - repeated int32 output_zero_copy_flag = 11; -} - - -message EventExDef { - uint32 op_index = 1; - uint32 event_type = 2; -} - -message LogTimeStampDef { - uint64 logid = 1; - bool notify = 2; - uint32 flat = 3; -} - -message MemcpyAsyncDef { - uint64 dst = 1; - uint64 dst_max = 2; - uint64 src = 3; - uint64 count = 4; - uint32 kind = 5; - uint32 op_index = 6; - string args_format = 7; -} - -message StreamSwitchDef { - uint32 op_index = 1; - uint32 true_stream_id = 2; - int64 value = 3; - uint64 value_ptr = 4; - uint32 data_type = 5; -} - -message StreamActiveDef { - uint32 op_index = 1; - uint32 active_stream_id = 2; -} - -message StreamSwitchNDef { - uint32 op_index = 1; - uint32 size = 2; - repeated int64 target_value = 3; - repeated uint32 true_stream_id = 4; - uint32 element_size = 5; - uint32 data_type = 6; -} - -message LabelSetDef { - uint32 op_index = 1; - uint32 label_id = 2; - uint32 model_id = 3; -} - -message LabelGotoExDef { - uint32 op_index = 1; - uint32 label_id = 2; - uint32 model_id = 3; -} - -message LabelSwitchByIndexDef { - uint32 op_index = 1; - uint32 label_max = 2; -} - -message FftsTaskDef { - uint32 op_index = 1; - - uint32 ffts_type = 2; // 2: Auto Threading / 3: Manual Threading - uint32 addr_size = 3; // Total task addr num - FftsDescInfoDef ffts_desc = 4; - repeated FftsSubTaskDef sub_task = 5; - repeated TicketCacheDef ticket_cache = 6; -} - -message FftsDescInfoDef { - uint32 tm = 1; // thread subtask kickstart mode, 0:order, 1:disorder - uint32 di = 2; // discard invalidate - uint32 dw = 3; // discard write back - uint32 df = 4; // discard flush - uint32 data_split_unit = 5; // split source or ticket cache by 2~dataSplitUnit MB - uint32 prefetch_ost_num = 6; - uint32 cache_maintain_ost_num = 7; - uint32 aic_prefetch_upper = 8; - uint32 aic_prefetch_lower = 9; - uint32 aiv_prefetch_upper = 10; - uint32 aiv_prefetch_lower = 11; -} - -message FftsSubTaskDef { - uint32 sub_task_type = 1; // 0: AIC / 1: AIV / 2: DMU / 3: NOP / 4: Prefetch - - uint32 thread_dim = 2; // thread count - uint32 dst_tick_cache_vld_bitmap = 3; - uint32 src_tick_cache_vld_bitmap = 4; - uint32 src_data_out_of_subgraph_bitmap = 5; - - repeated uint32 dst_tick_cache_id = 6; // Max 8 output - repeated uint32 src_tick_cache_id = 7; // Max 8 input - - AutoThreadAicAivDef auto_thread_aic_aiv = 8; - - ManualThreadAicAivDef manual_thread_aic_aiv = 9; - ManualThreadNopDef manual_thread_nop = 10; -} - -message TicketCacheDef { - uint32 cache_option = 1; // 1: invalidate / 2: flush - uint32 ticket_cache_window = 2; - - AutoThreadCacheDef auto_thread_cache = 3; - - ManualThreadCacheDef manual_thread_cache = 4; -} - -message AutoThreadAicAivDef { - repeated uint64 task_addr = 1; // input / output address (runtime device memory) - repeated uint64 task_addr_offset = 2; - uint32 task_param_offset = 3; - - uint32 sat_mode = 4; - uint32 schedule_mode = 5; // 0:normal mode, 1:batch mode, 2:sync mode, 3:reserved - uint32 cache_prefetch_cnt = 6; // units is 2K - uint32 prefetch_enable_bitmap = 7; // 8 bit bitmap - uint32 prefetch_once_bitmap = 8; // 8 bit bitmap - - uint32 tail_blk_dim = 9; - uint32 non_tail_blk_dim = 10; - - string non_tail_task_func_stub = 11; - string tail_task_func_stub = 12; - - repeated AutoThreadPrefetchDef src_prefetch = 13; - uint32 input_output_count = 14; -} - -message AutoThreadCacheDef { - uint64 data_addr = 1; // device mem - uint32 data_addr_offset = 2; - uint32 non_tail_data_len = 3; - uint32 tail_data_len = 4; - uint32 ticket_cache_ref_cnt = 5; -} - -message AutoThreadPrefetchDef { - uint64 data_addr = 1; // device mem - uint32 data_addr_offset = 2; - uint32 non_tail_data_len = 3; - uint32 tail_data_len = 4; -} - -message ManualThreadAicAivDef { - repeated uint64 task_addr = 1; // input/output address(runtime device memory) - repeated uint64 task_addr_offset = 2; - uint32 task_param_offset = 3; - - uint32 sat_mode = 4; - uint32 schedule_mode = 5; // 0:normal mode; 1:batch mode; 2:sync mode; 3:reserved - uint32 cache_prefetch_cnt = 6; // units is 2K - uint32 prefetch_enable_bitmap = 7; // 8 bit bitmap 1010 - uint32 prefetch_once_bitmap = 8; // 8 bit bitmap 1010 - - uint32 prefetch_once_dmu_num = 9; // prefetch_once_dmu_descriptor_index in ffts - // num: thread0_prefetch_dmu_descriptor_index - prefetch_once_dmu_descriptor_index - // offset: PrefetchOnceDmuDescIndex - - repeated uint32 thread_prefetch_dmu_idx = 10; // max valid is thread dim - repeated uint32 thread_blk_dim = 11; - - repeated string thread_task_func_stub = 12; - - repeated ManualThreadDmuDef prefetch_list = 13; // dmu desc 0-64k - repeated ManualThreadDependencyDef src_dep_tbl = 14; - uint32 input_output_count = 15; -} - -message ManualThreadNopDef { - repeated ManualThreadDependencyDef src_dep_tbl = 1; -} - -message ManualThreadCacheDef { - repeated ManualThreadDmuDef dmu_list = 1; - - repeated uint32 slice_dmu_idx = 2; - repeated uint32 ticket_cache_ref_cnt_tbl = 3; -} - -message ManualThreadDmuDef { - uint64 data_addr = 1; // device mem - uint32 num_outer = 2; - uint32 num_inner = 3; - uint32 stride_outer = 4; - uint32 len_inner = 5; - uint32 stride_inner = 6; -} - -message ManualThreadDependencyDef { - repeated uint32 dependency = 1; -} - -message DSATaskDef { - uint32 op_index = 1; - uint32 start = 2; // start, the value is 1 - uint32 sqe_type = 3; - uint32 distribution_type = 4; - uint32 data_type = 5; - uint32 alg_type = 6; - uint32 input_vld = 7; - uint32 input_value_addr_flag = 8; - uint32 input1_value_or_ptr = 9; - uint32 input2_value_or_ptr = 10; - uint32 seed_value_or_ptr = 11; - uint32 random_count_value_or_ptr = 12; - DSATaskArgsDef args = 13; -} - -message DSATaskArgsDef { - uint64 output_addr = 1; - uint64 workspace_philox_count_addr = 2; - uint64 workspace_input_addr = 3; - bytes seed_value_or_addr = 4; - bytes random_count_value_or_addr = 5; - bytes input1_value_or_addr = 6; - bytes input2_value_or_addr = 7; -} - -message FftsPlusTaskDef { - uint32 op_index = 1; - uint32 addr_size = 2; // Total task addr num - FftsPlusSqeDef ffts_plus_sqe = 3; - repeated FftsPlusCtxDef ffts_plus_ctx = 4; // include total context - repeated AdditionalDataDef additional_data = 5; -} - -message AdditionalDataDef { - uint32 data_type = 1; // ModeInArgsFirstField - repeated uint32 context_id = 2; -} - -message FftsPlusSqeDef { - StarsSqeHeaderDef sqe_header = 1; - - uint32 wrr_ratio = 2; - uint32 sqe_index = 3; - - uint32 total_context_num = 4; - uint32 ready_context_num = 5; - uint32 preload_context_num = 6; - - uint32 prefetch_ost_num = 7; - uint32 cmaint_ost_num = 8; - - uint32 aic_prefetch_lower = 9; - uint32 aic_prefetch_upper = 10; - uint32 aiv_prefetch_lower = 11; - uint32 aiv_prefetch_upper = 12; - - uint32 data_split_unit = 13; -} - -message StarsSqeHeaderDef { - uint32 l1_lock = 1; - uint32 l1_unlock = 2; - uint32 block_dim = 3; -} - -// ffts plus context -message FftsPlusCtxDef { - enum OpType { - NORMAL = 0; - ATOMIC = 1; - } - uint32 op_index = 1; - string uniq_ctx_name = 2; - uint32 context_type = 3; - uint32 context_id = 4; - OpType op_type = 5; - - FftsPlusAicAivCtxDef aic_aiv_ctx = 6; - FftsPlusMixAicAivCtxDef mix_aic_aiv_ctx = 7; - FftsPlusSdmaCtxDef sdma_ctx = 8; - FftsPlusNotifyCtxDef notify_ctx = 9; - FftsPlusWriteValueCtxDef write_value_ctx = 10; - FftsPlusAicpuCtxDef aicpu_ctx = 11; - FftsPlusDataCtxDef data_ctx = 12; - FftsPlusAtStartCtxDef at_start_ctx = 13; - FftsPlusAtEndCtxDef at_end_ctx = 14; - FftsPlusLabelCtxDef label_ctx = 15; - FftsPlusCaseSwitchCtxDef case_switch_ctx = 16; - FftsPlusCaseDefaultCtxDef case_default_ctx = 17; - FftsPlusCondSwitchCtxDef cond_switch_ctx = 18; - FftsPlusCachePersistCtxDef cache_persist_ctx = 19; - FftsPlusDsaCtxDef dsa_ctx = 20; -} - -// aic/aiv context -message FftsPlusAicAivCtxDef { - uint32 successor_num = 1; - uint32 aten = 2; - uint32 prefetch_config = 3; - uint32 pred_cnt_init = 4; - uint32 pred_cnt = 5; - repeated uint32 successor_list = 6; // 16 bits, len = 26 - - uint32 schem = 7; - uint32 atm = 8; - uint32 prefetch_enable_bitmap = 9; - uint32 prefetch_once_bitmap = 10; - - uint32 pmg = 11; - uint32 ns = 12; - uint32 part_id = 13; - uint32 qos = 14; - - uint32 thread_id = 15; - uint32 thread_dim = 16; - - uint32 non_tail_block_dim = 17; - uint32 tail_block_dim = 18; - - uint32 task_param_ptr_offset = 19; - uint32 save_task_addr = 20; - repeated uint64 task_addr = 21; - repeated uint64 task_addr_offset = 22; - uint32 input_output_count = 23; - - repeated string kernel_name = 24; - - repeated uint32 src_slot = 25; // len = 4, context ID for source data which is out of subgraph - uint32 policy_pri = 26; - uint32 thread_window_size = 27; -} - -// mix aic/aiv context -message FftsPlusMixAicAivCtxDef { - uint32 successor_num = 1; - uint32 aten = 2; - uint32 prefetch_config = 3; - uint32 pred_cnt_init = 4; - uint32 pred_cnt = 5; - repeated uint32 successor_list = 6; // len = 26 - - uint32 schem = 7; - uint32 atm = 8; - uint32 prefetch_enable_bitmap = 9; - uint32 prefetch_once_bitmap = 10; - - uint32 pmg = 11; - uint32 ns = 12; - uint32 part_id = 13; - uint32 qos = 14; - - uint32 non_tail_block_ratio_n = 15; - uint32 tail_block_ratio_n = 16; - - uint32 thread_id = 17; - uint32 thread_dim = 18; - - uint32 non_tail_block_dim = 19; - uint32 tail_block_dim = 20; - - uint32 aic_task_param_ptr_offset = 21; - uint32 aiv_task_param_ptr_offset = 22; - - repeated string kernel_name = 23; - - repeated uint64 task_addr = 24; - repeated uint64 task_addr_offset = 25; - uint32 input_output_count = 26; - uint32 save_task_addr = 27; - repeated uint32 src_slot = 28; // len = 4, context ID for source data which is out of subgraph - uint32 policy_pri = 29; - uint32 thread_window_size = 30; - string args_format = 31; -} - -// sdma context -message FftsPlusSdmaCtxDef { - uint32 successor_num = 1; - uint32 aten = 2; - uint32 pred_cnt_init = 3; - uint32 pred_cnt = 4; - repeated uint32 successor_list = 5; // len = 26 - - uint32 atm = 6; - uint32 pmg = 7; - uint32 ns = 8; - uint32 part_id = 9; - uint32 qos = 10; - - uint32 thread_id = 11; - uint32 thread_dim = 12; - - uint32 sdma_sqe_header = 13; - - uint32 src_stream_id = 14; - uint32 src_sub_stream_id = 15; - uint32 dst_stream_id = 16; - uint32 dst_sub_stream_id = 17; - - uint64 src_addr_base = 18; - uint32 src_addr_offset = 19; - uint64 dst_addr_base = 20; - uint32 dst_addr_offset = 21; - - uint32 non_tail_data_len = 22; - uint32 tail_data_len = 23; -} - -// ffts plus notify record/wait context -message FftsPlusNotifyCtxDef { - uint32 successor_num = 1; - uint32 aten = 2; - uint32 pred_cnt_init = 3; - uint32 pred_cnt = 4; - repeated uint32 successor_list = 5; // len = 26 - - uint32 satm = 6; - uint32 atm = 7; - - uint32 thread_id = 8; - uint32 thread_dim = 9; - - uint32 notify_id_base = 10; - uint32 auto_window = 11; - - repeated uint32 notify_id = 12; -} - -// write value context -message FftsPlusWriteValueCtxDef { - uint32 successor_num = 1; - uint32 aten = 2; - uint32 pred_cnt_init = 3; - uint32 pred_cnt = 4; - repeated uint32 successor_list = 5; // len = 26 - - uint32 atm = 6; - uint32 thread_id = 7; - uint32 thread_dim = 8; - - uint32 aw_size = 9; - uint32 aw_snoop = 10; - uint32 aw_cache = 11; - uint32 aw_prot = 12; - uint32 aw_va = 13; - - uint32 ar_size = 14; - uint32 ar_snoop = 15; - uint32 ar_cache = 16; - uint32 ar_prot = 17; - uint32 ar_va = 18; - - uint64 write_addr_base = 19; - uint32 write_addr_offset = 20; - - repeated uint32 write_value = 21; -} - -// aicpu context -message FftsPlusAicpuCtxDef { - uint32 successor_num = 1; - uint32 aten = 2; - uint32 pred_cnt_init = 3; - uint32 pred_cnt = 4; - repeated uint32 successor_list = 5; // len = 26 - - uint32 atm = 6; - uint32 sqe_index = 7; - uint32 kernel_type = 8; - uint32 bm = 9; - uint32 topic_type = 10; - uint32 qos = 11; - - uint32 thread_id = 12; - uint32 thread_dim = 13; - - uint32 non_tail_block_dim = 14; - uint32 tail_block_dim = 15; - - uint32 sub_topic_id = 16; - uint32 topic_id = 17; - uint32 group_id = 18; - - uint32 task_param_offset = 19; - - aicpuKernelDef kernel = 20; -} - -message aicpuKernelDef { - uint32 args_size = 1; - bytes args = 2; - string so_name = 3; - string kernel_name = 4; - bytes kernel_ext_info = 5; - uint32 kernel_ext_info_size = 6; -} - -// data context -message FftsPlusDataCtxDef { - uint32 successor_num = 1; - uint32 aten = 2; - uint32 cnt_init = 3; - uint32 cnt = 4; - repeated uint32 successor_list = 5; // len = 26 - - uint32 atm = 6; - uint32 pmg = 7; - uint32 ns = 8; - uint32 part_id = 9; - uint32 qos = 10; - - uint32 orig_consumer_counter = 11; - uint32 run_consumer_counter = 12; - - uint32 thread_id = 13; - uint32 thread_dim = 14; - - uint64 addr_base = 15; - uint32 addr_offset = 16; - - uint32 non_tail_num_outter = 17; - uint32 non_tail_num_inner = 18; - uint32 non_tail_len_inner = 19; - uint32 non_tail_stride_outter = 20; - uint32 non_tail_stride_inner = 21; - - uint32 tail_num_outter = 22; - uint32 tail_num_inner = 23; - uint32 tail_len_inner = 24; - uint32 tail_stride_outter = 25; - uint32 tail_stride_inner = 26; -} - -// at start context -message FftsPlusAtStartCtxDef { - uint32 successor_num = 1; - uint32 aten = 2; - uint32 pred_cnt_init = 3; - uint32 pred_cnt = 4; - repeated uint32 successor_list = 5; // len = 26 - - uint32 thread_id = 6; - uint32 thread_dim = 7; - - uint32 thread_id_init = 8; - uint32 thread_window_size = 9; -} - -// at end context -message FftsPlusAtEndCtxDef { - uint32 at_start_slot_num = 1; - uint32 out_label_slot_num = 2; - uint32 aten = 3; - - uint32 pred_cnt_init = 4; - uint32 pred_cnt = 5; - - repeated uint32 succ_at_start_slot = 6; // len = 12 - repeated uint32 succ_out_label_slot = 7; // len = 12 - - uint32 thread_id = 8; -} - -// label context -message FftsPlusLabelCtxDef { - uint32 successor_num = 1; - uint32 aten = 2; - uint32 pred_cnt_init = 3; - uint32 pred_cnt = 4; - repeated uint32 successor_list = 5; // len = 26 - uint32 thread_id = 6; - uint32 thread_dim = 7; -} - -// switch context -message FftsPlusCaseSwitchCtxDef { - uint32 successor_num = 1; - uint32 aten = 2; - uint32 start_label_id = 3; - uint32 label_list_len = 4; - uint32 pred_cnt_init = 5; - uint32 pred_cnt = 6; - repeated uint32 successor_list = 7; // len = 26 - - uint32 atm = 8; - - uint32 thread_id = 9; - uint32 thread_dim = 10; - - uint32 ar_size = 11; - uint32 snoop = 12; - uint32 ar_cache = 13; - uint32 ar_prot = 14; - uint32 va = 15; - - uint64 load_addr0_base = 16; - uint32 ld0_en = 17; - uint32 load_addr0_offset = 18; - - uint64 load_addr1_base = 19; - uint32 ld1_en = 20; - uint32 load_addr1_offset = 21; -} - -// case default context -message FftsPlusCaseDefaultCtxDef { - uint32 successor_num = 1; - uint32 aten = 2; - uint32 start_label_id = 3; - uint32 label_list_len = 4; - uint32 pred_cnt_init = 5; - uint32 pred_cnt = 6; - repeated uint32 successor_list = 7; // len = 26 -} - -// cond context -message FftsPlusCondSwitchCtxDef { - uint32 true_successor_num = 1; - uint32 false_successor_num = 2; - uint32 aten = 3; - - uint32 condition = 4; - uint32 pred_cnt_init = 5; - uint32 pred_cnt = 6; - - repeated uint32 true_successor_list = 7; // len = 12 - repeated uint32 false_successor_list = 8; // len = 14 - - uint32 atm = 9; - - uint32 thread_id = 10; - uint32 thread_dim = 11; - - uint32 ar_size = 12; - uint32 snoop = 13; - uint32 ar_cache = 14; - uint32 ar_prot = 15; - uint32 va = 16; - - uint64 load_addr0_base = 17; - uint32 ld0_en = 18; - uint32 load_addr0_offset = 19; - - uint64 load_addr1_base = 20; - uint32 ld1_en = 21; - uint32 load_addr1_offset = 22; - - uint32 cmp_value_1 = 23; - uint32 cmp_value_2 = 24; -} - -message FftsPlusCachePersistCtxDef { - uint32 successor_num = 1; - uint32 aten = 2; - uint32 prefetch_config = 3; - uint32 pred_cnt_init = 4; - uint32 pred_cnt = 5; - repeated uint32 successor_list = 6; // 16 bits, len = 26 - - uint32 persistent_size = 7; - uint32 persistent_en = 8; - uint32 persistent_id = 9; -} - -//ffts dsa context -message FftsPlusDsaCtxDef { - uint32 successor_num = 1; - uint32 aten = 2; - uint32 pred_cnt_init = 3; - uint32 pred_cnt = 4; - repeated uint32 successor_list = 5; // 16 bits, len = 26 - uint32 atm = 6; - uint32 address_offset = 7; - uint32 thread_id = 8; - uint32 thread_dim = 9; - - uint32 start = 10; // start - uint32 distribution_type = 11; - uint32 data_type = 12; - uint32 alg_type = 13; - uint32 input_vld = 14; - uint32 input_value_addr_flag = 15; - uint32 input1_value_or_ptr = 16; - uint32 input2_value_or_ptr = 17; - uint32 seed_value_or_ptr = 18; - uint32 random_count_value_or_ptr = 19; - DSATaskArgsDef args = 20; -} - -message CmoTaskDef { - uint32 cmo_type = 1; - uint32 logic_id = 2; - uint32 op_code = 3; - uint32 qos = 4; - uint32 part_id = 5; - uint32 pmg = 6; - uint32 num_inner = 7; - uint32 num_outer = 8; - uint32 length_inner = 9; - uint64 source_addr = 10; - uint32 strider_outer = 11; - uint32 strider_inner = 12; -} - -message CmoAddrTaskDef { - uint32 cmo_op_code = 1; - uint32 op_index = 2; - uint32 num_inner = 3; - uint32 num_outer = 4; - uint64 src = 5; - uint32 length_inner = 6; - uint32 stride_outer = 7; - uint32 stride_inner = 8; - uint32 resv0 = 9; - uint32 resv1 = 10; -} - -message CmoBarrierTaskDef { - uint32 logic_id_num = 1; - repeated CmoBarrierInfoDef barrier_info = 2; -} - -message CmoBarrierInfoDef { - uint32 cmo_type = 1; - uint32 logic_id = 2; -} - -message NpuGetFloatStatusDef { - uint64 output_addr = 1; - uint32 output_size = 2; - uint32 mode = 3; - uint32 op_index = 4; -} - -message NpuClearFloatStatusDef { - uint32 mode = 1; - uint32 op_index = 2; -} - -message NpuGetFloatDebugStatusDef { - uint64 output_addr = 1; - uint32 output_size = 2; - uint32 mode = 3; - uint32 op_index = 4; -} - -message NpuClearFloatDebugStatusDef { - uint32 mode = 1; - uint32 op_index = 2; -} - -message DvppTaskDef { - uint32 op_index = 1; -} - -message UpdatePcTaskDef { - uint32 op_index = 1; - uint32 stream_id = 2; - string args_format = 3; -} - -message QueueAttrs { - uint32 queue_id = 1; - int32 device_type = 2; // CPU NPU - int32 device_id = 3; - uint32 logic_id = 4; // 动态请求全局id -} - -message QueueInfo { - QueueAttrs queue_attrs = 1; - uint32 logic_group_id = 2; - uint32 model_uuid = 3; - int32 trans_id_old = 4; - int32 route_label_old = 5; - uint32 choose_logic_id = 6; - uint32 root_model_id = 7; - bool need_cache = 8; - uint64 trans_id = 9; - uint32 route_label = 10; -} - -message FlowgwRequest { - int32 node_id = 1; - int32 input_index = 2; - repeated QueueInfo queue_infos = 3; -} - -message FlowgwResponse { - repeated QueueInfo queue_infos = 1; -} - -message QueueStatus { - QueueAttrs queue_attrs = 1; - uint32 queue_depth = 2; - uint32 input_consume_num = 3; -} - -message DataFlowException { - int32 exception_code = 1; - uint64 trans_id = 2; - string scope = 3; - uint64 user_context_id = 4; - bytes exception_context = 5; // user_data + xxx + HeadMsg -} - -message SubmodelStatus { - uint32 model_uuid = 1; - repeated QueueStatus queue_statuses = 2; - uint32 msg_type = 3; // 0:status 1:exception - DataFlowException exception = 4; -} - -message LaunchAttributeValue { - message Group { - uint32 group_dim = 1; - uint32 group_block_dim = 2; - } - - oneof attribute_value { - uint32 block_dim = 1; - uint32 dynamic_share_mem_size = 2; - Group group = 3; - uint32 qos = 4; - uint32 part_id = 5; - uint32 schem_model = 6; - uint32 block_dim_offset = 7; - uint32 dump_flag = 8; - } -} - -message LaunchAttribute { - enum LaunchAttributeId { - BLOCKDIM = 0; - DYNAMIC_SHARE_MEM_SIZE = 1; - GROUP = 2; - QOS = 3; - PARTID = 4; - SCHEMMODE = 5; - BLOCKDIM_OFFSET = 6; - DUMPFLAG = 7; - MAX = 8; - } - LaunchAttributeId id = 1; - LaunchAttributeValue value = 2; -} - -message LaunchConfig { - repeated LaunchAttribute launch_attribute = 1; -} - -message AicoreFusionTaskInfo { - KernelContext context = 1; - bool is_all_kernel = 2; - LaunchConfig config = 3; -} - -message AicpuFusionTaskInfo { - KernelContext context = 1; - uint32 flags = 2; - uint32 block_dim = 3; - string so_name = 4; - string kernel_name = 5; - bytes kernel_ext_info = 6; -} - -message CcuTaskInfo { - uint32 die_id = 1; - uint32 mission_id = 2; - uint32 timeout = 3; - uint32 inst_start_id = 4; - uint32 inst_cnt = 5; - uint32 key = 6; - uint32 arg_size = 7; - repeated uint64 args = 9; -} - -message CcuTaskGroup { - repeated CcuTaskInfo ccu_task_info = 1; -} - -message FusionSubTaskDef { - oneof sub_task_info { - AicoreFusionTaskInfo aicore_fusion_task_info = 1; - AicpuFusionTaskInfo aicpu_fusion_task_info = 2; - CcuTaskGroup ccu_task_group = 3; - } -} - -message FusionSubTaskInfo { - enum FusionType { - HCOM_CPU = 0; - AICPU = 1; - AICORE = 2; - CCU = 3; - END = 4; - } - - FusionType type = 1; - FusionSubTaskDef task = 2; -} - - -message FusionTaskDef { - uint32 op_index = 1; - string args_format = 2; - uint32 kfc_args_format_offset = 3; - repeated FusionSubTaskInfo fusion_sub_task_info = 4; -} diff --git a/proto/tensorflow/attr_value.proto b/proto/tensorflow/attr_value.proto deleted file mode 100644 index cfe0f31a12305b6ff037ce290c01d5562aa38638..0000000000000000000000000000000000000000 --- a/proto/tensorflow/attr_value.proto +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "AttrValueProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "tensor.proto"; -import "tensor_shape.proto"; -import "types.proto"; - -// Protocol buffer representing the value for an attr used to configure an Op. -// Comment indicates the corresponding attr type. Only the field matching the -// attr type may be filled. -message AttrValue { - // LINT.IfChange - message ListValue { - repeated bytes s = 2; // "list(string)" - repeated int64 i = 3 [packed = true]; // "list(int)" - repeated float f = 4 [packed = true]; // "list(float)" - repeated bool b = 5 [packed = true]; // "list(bool)" - repeated DataType type = 6 [packed = true]; // "list(type)" - repeated TensorShapeProto shape = 7; // "list(shape)" - repeated TensorProto tensor = 8; // "list(tensor)" - repeated NameAttrList func = 9; // "list(attr)" - } - // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.cc) - - oneof value { - bytes s = 2; // "string" - int64 i = 3; // "int" - float f = 4; // "float" - bool b = 5; // "bool" - DataType type = 6; // "type" - TensorShapeProto shape = 7; // "shape" - TensorProto tensor = 8; // "tensor" - ListValue list = 1; // any "list(...)" - - // "func" represents a function. func.name is a function's name or - // a primitive op's name. func.attr.first is the name of an attr - // defined for that function. func.attr.second is the value for - // that attr in the instantiation. - NameAttrList func = 10; - - // This is a placeholder only used in nodes defined inside a - // function. It indicates the attr value will be supplied when - // the function is instantiated. For example, let us suppose a - // node "N" in function "FN". "N" has an attr "A" with value - // placeholder = "foo". When FN is instantiated with attr "foo" - // set to "bar", the instantiated node N's attr A will have been - // given the value "bar". - string placeholder = 9; - } -} - -// A list of attr names and their values. The whole list is attached -// with a string name. E.g., MatMul[T=float]. -message NameAttrList { - string name = 1; - map attr = 2; -} diff --git a/proto/tensorflow/function.proto b/proto/tensorflow/function.proto deleted file mode 100644 index 5112e43ed3c5de5bf618892bdc8cf42d29f86faf..0000000000000000000000000000000000000000 --- a/proto/tensorflow/function.proto +++ /dev/null @@ -1,109 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "FunctionProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "attr_value.proto"; -import "node_def.proto"; -import "op_def.proto"; - -// A library is a set of named functions. -message FunctionDefLibrary { - repeated FunctionDef function = 1; - repeated GradientDef gradient = 2; -} - -// A function can be instantiated when the runtime can bind every attr -// with a value. When a GraphDef has a call to a function, it must -// have binding for every attr defined in the signature. -// * device spec, etc. -message FunctionDef { - // The definition of the function's name, arguments, return values, - // attrs etc. - OpDef signature = 1; - - // Attributes specific to this function definition. - map attr = 5; - - // NOTE: field id 2 deleted on Jan 11, 2017, GraphDef version 21. - reserved 2; - - // In both of the following fields, there is the need to specify an - // output that is used as either the input to another node (in - // `node_def`) or as a return value of the function (in `ret`). - // Unlike the NodeDefs in GraphDef, we need to be able to specify a - // list in some cases (instead of just single outputs). Also, we - // need to be able to deal with lists of unknown length (so the - // output index may not be known at function definition time). So - // we use the following format instead: - // * "fun_in" where "fun_in" is the name of a function input arg in - // the `signature` field above. This represents that input, whether - // it is a single tensor or a list. - // * "fun_in:0" gives the first element of a function input arg (a - // non-list input is considered a list of length 1 for these - // purposes). - // * "node:out" where "node" is the name of a node in `node_def` and - // "out" is the name one of its op's output arguments (the name - // comes from the OpDef of the node's op). This represents that - // node's output, whether it is a single tensor or a list. - // Note: We enforce that an op's output arguments are never - // renamed in the backwards-compatibility test. - // * "node:out:0" gives the first element of a node output arg (a - // non-list output is considered a list of length 1 for these - // purposes). - // - // NOT CURRENTLY SUPPORTED (but may be in the future): - // * "node:out:-1" gives last element in a node output list - // * "node:out:1:" gives a list with all but the first element in a - // node output list - // * "node:out::-1" gives a list with all but the last element in a - // node output list - - // The body of the function. Unlike the NodeDefs in a GraphDef, attrs - // may have values of type `placeholder` and the `input` field uses - // the "output" format above. - - // By convention, "op" in node_def is resolved by consulting with a - // user-defined library first. If not resolved, "func" is assumed to - // be a builtin op. - repeated NodeDef node_def = 3; - - // A mapping from the output arg names from `signature` to the - // outputs from `node_def` that should be returned by the function. - map ret = 4; -} - -// GradientDef defines the gradient function of a function defined in -// a function library. -// -// A gradient function g (specified by gradient_func) for a function f -// (specified by function_name) must follow the following: -// -// The function 'f' must be a numerical function which takes N inputs -// and produces M outputs. Its gradient function 'g', which is a -// function taking N + M inputs and produces N outputs. -// -// I.e. if we have -// (y1, y2, ..., y_M) = f(x1, x2, ..., x_N), -// then, g is -// (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N, -// dL/dy1, dL/dy2, ..., dL/dy_M), -// where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the -// loss function). dL/dx_i is the partial derivative of L with respect -// to x_i. -message GradientDef { - string function_name = 1; // The function name. - string gradient_func = 2; // The gradient function's name. -} diff --git a/proto/tensorflow/graph.proto b/proto/tensorflow/graph.proto deleted file mode 100644 index 6462a8183135398b1fd91cb42ac4b4219908e6d4..0000000000000000000000000000000000000000 --- a/proto/tensorflow/graph.proto +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "GraphProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "node_def.proto"; -import "function.proto"; -import "versions.proto"; - -// Represents the graph of operations -message GraphDef { - repeated NodeDef node = 1; - - // Compatibility versions of the graph. See core/public/version.h for version - // history. The GraphDef version is distinct from the TensorFlow version, and - // each release of TensorFlow will support a range of GraphDef versions. - VersionDef versions = 4; - - // Deprecated single version field; use versions above instead. Since all - // GraphDef changes before "versions" was introduced were forward - // compatible, this field is entirely ignored. - int32 version = 3; - - // EXPERIMENTAL. DO NOT USE OR DEPEND ON THIS YET. - // - // "library" provides user-defined functions. - // - // Naming: - // * library.function.name are in a flat namespace. - // NOTE: We may need to change it to be hierarchical to support - // different orgs. E.g., - // { "/google/nn", { ... }}, - // { "/google/vision", { ... }} - // { "/org_foo/module_bar", { ... }} - // map named_lib; - // * If node[i].op is the name of one function in "library", - // node[i] is deemed as a function call. Otherwise, node[i].op - // must be a primitive operation supported by the runtime. - // - // - // Function call semantics: - // - // * The callee may start execution as soon as some of its inputs - // are ready. The caller may want to use Tuple() mechanism to - // ensure all inputs are ready in the same time. - // - // * The consumer of return values may start executing as soon as - // the return values the consumer depends on are ready. The - // consumer may want to use Tuple() mechanism to ensure the - // consumer does not start until all return values of the callee - // function are ready. - FunctionDefLibrary library = 2; -}; diff --git a/proto/tensorflow/graph_library.proto b/proto/tensorflow/graph_library.proto deleted file mode 100644 index 8bbe136e3067d9b4df2575b66c0623dcf9637313..0000000000000000000000000000000000000000 --- a/proto/tensorflow/graph_library.proto +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; - -package domi.tensorflow; - -import "graph.proto"; - -message GeGraphDef { - string name = 1; - GraphDef graph = 2; -} - -message GraphDefLibrary { - repeated GeGraphDef graph_def = 1; -}; diff --git a/proto/tensorflow/node_def.proto b/proto/tensorflow/node_def.proto deleted file mode 100644 index 785aa9fdef23d89e13e7216d78cafe7f8f4c27b5..0000000000000000000000000000000000000000 --- a/proto/tensorflow/node_def.proto +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "NodeProto"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "attr_value.proto"; - -message NodeDef { - // The name given to this operator. Used for naming inputs, - // logging, visualization, etc. Unique within a single GraphDef. - // Must match the regexp "[A-Za-z0-9.][A-Za-z0-9_./]*". - string name = 1; - - // The operation name. There may be custom parameters in attrs. - // Op names starting with an underscore are reserved for internal use. - string op = 2; - - // Each input is "node:src_output" with "node" being a string name and - // "src_output" indicating which output tensor to use from "node". If - // "src_output" is 0 the ":0" suffix can be omitted. Regular inputs - // may optionally be followed by control inputs that have the format - // "^node". - repeated string input = 3; - - // A (possibly partial) specification for the device on which this - // node should be placed. - // The expected syntax for this string is as follows: - // - // DEVICE_SPEC ::= PARTIAL_SPEC - // - // PARTIAL_SPEC ::= ("/" CONSTRAINT) * - // CONSTRAINT ::= ("job:" JOB_NAME) - // | ("replica:" [1-9][0-9]*) - // | ("task:" [1-9][0-9]*) - // | ("device:" [A-Za-z]* ":" ([1-9][0-9]* | "*") ) - // - // Valid values for this string include: - // * "/job:worker/replica:0/task:1/device:GPU:3" (full specification) - // * "/job:worker/device:GPU:3" (partial specification) - // * "" (no specification) - // - // If the constraints do not resolve to a single device (or if this - // field is empty or not present), the runtime will attempt to - // choose a device automatically. - string device = 4; - - // Operation-specific graph-construction-time configuration. - // Note that this should include all attrs defined in the - // corresponding OpDef, including those with a value matching - // the default -- this allows the default to change and makes - // NodeDefs easier to interpret on their own. However, if - // an attr with a default is not specified in this list, the - // default will be used. - // The "names" (keys) must match the regexp "[a-z][a-z0-9_]+" (and - // one of the names from the corresponding OpDef's attr field). - // The values must have a type matching the corresponding OpDef - // attr's type field. - // Add some examples here showing best practices. - map attr = 5; -}; diff --git a/proto/tensorflow/op_def.proto b/proto/tensorflow/op_def.proto deleted file mode 100644 index 3bd583dbfe7b966f2577c707c1e35751e074330b..0000000000000000000000000000000000000000 --- a/proto/tensorflow/op_def.proto +++ /dev/null @@ -1,173 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "OpDefProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "attr_value.proto"; -import "types.proto"; - -// Defines an operation. A NodeDef in a GraphDef specifies an Op by -// using the "op" field which should match the name of a OpDef. -// LINT.IfChange -message OpDef { - // Op names starting with an underscore are reserved for internal use. - // Names should be CamelCase and match the regexp "[A-Z][a-zA-Z0-9_]*". - string name = 1; - - // For describing inputs and outputs. - message ArgDef { - // Name for the input/output. Should match the regexp "[a-z][a-z0-9_]*". - string name = 1; - - // Human readable description. - string description = 2; - - // Describes the type of one or more tensors that are accepted/produced - // by this input/output arg. The only legal combinations are: - // * For a single tensor: either the "type" field is set or the - // "type_attr" field is set to the name of an attr with type "type". - // * For a sequence of tensors with the same type: the "number_attr" - // field will be set to the name of an attr with type "int", and - // either the "type" or "type_attr" field will be set as for - // single tensors. - // * For a sequence of tensors, the "type_list_attr" field will be set - // to the name of an attr with type "list(type)". - DataType type = 3; - string type_attr = 4; // if specified, attr must have type "type" - string number_attr = 5; // if specified, attr must have type "int" - // If specified, attr must have type "list(type)", and none of - // type, type_attr, and number_attr may be specified. - string type_list_attr = 6; - - // For inputs: if true, the inputs are required to be refs. - // By default, inputs can be either refs or non-refs. - // For outputs: if true, outputs are refs, otherwise they are not. - bool is_ref = 16; - }; - - // Description of the input(s). - repeated ArgDef input_arg = 2; - - // Description of the output(s). - repeated ArgDef output_arg = 3; - - // Description of the graph-construction-time configuration of this - // Op. That is to say, this describes the attr fields that will - // be specified in the NodeDef. - message AttrDef { - // A descriptive name for the argument. May be used, e.g. by the - // Python client, as a keyword argument name, and so should match - // the regexp "[a-z][a-z0-9_]+". - string name = 1; - - // One of the type names from attr_value.proto ("string", "list(string)", - // "int", etc.). - string type = 2; - - // A reasonable default for this attribute if the user does not supply - // a value. If not specified, the user must supply a value. - AttrValue default_value = 3; - - // Human-readable description. - string description = 4; - - - // --- Constraints --- - // These constraints are only in effect if specified. Default is no - // constraints. - - // For type == "int", this is a minimum value. For "list(___)" - // types, this is the minimum length. - bool has_minimum = 5; - int64 minimum = 6; - - // The set of allowed values. Has type that is the "list" version - // of the "type" field above (uses the "list" field of AttrValue). - // If type == "type" or "list(type)" above, then the "type" field - // of "allowed_values.list" has the set of allowed DataTypes. - // If type == "string" or "list(string)", then the "s" field of - // "allowed_values.list" has the set of allowed strings. - AttrValue allowed_values = 7; - } - repeated AttrDef attr = 4; - - // Optional deprecation based on GraphDef versions. - OpDeprecation deprecation = 8; - - // One-line human-readable description of what the Op does. - string summary = 5; - - // Additional, longer human-readable description of what the Op does. - string description = 6; - - // ------------------------------------------------------------------------- - // Which optimizations this operation can participate in. - - // True if the operation is commutative ("op(a,b) == op(b,a)" for all inputs) - bool is_commutative = 18; - - // If is_aggregate is true, then this operation accepts N >= 2 - // inputs and produces 1 output all of the same type. Should be - // associative and commutative, and produce output with the same - // shape as the input. The optimizer may replace an aggregate op - // taking input from multiple devices with a tree of aggregate ops - // that aggregate locally within each device (and possibly within - // groups of nearby devices) before communicating. - bool is_aggregate = 16; // for things like add - - // Other optimizations go here, like - // can_alias_input, rewrite_when_output_unused, partitioning_strategy, etc. - - // ------------------------------------------------------------------------- - // Optimization constraints. - - // Ops are marked as stateful if their behavior depends on some state beyond - // their input tensors (e.g. variable reading op) or if they have - // a side-effect (e.g. printing or asserting ops). Equivalently, stateless ops - // must always produce the same output for the same input and have - // no side-effects. - // - // By default Ops may be moved between devices. Stateful ops should - // either not be moved, or should only be moved if that state can also - // be moved (e.g. via some sort of save / restore). - // Stateful ops are guaranteed to never be optimized away by Common - // Subexpression Elimination (CSE). - bool is_stateful = 17; // for things like variables, queue - - // ------------------------------------------------------------------------- - // Non-standard options. - - // By default, all inputs to an Op must be initialized Tensors. Ops - // that may initialize tensors for the first time should set this - // field to true, to allow the Op to take an uninitialized Tensor as - // input. - bool allows_uninitialized_input = 19; // for Assign, etc. -}; -// LINT.ThenChange( -// https://www.tensorflow.org/code/tensorflow/core/framework/op_def_util.cc) - -// Information about version-dependent deprecation of an op -message OpDeprecation { - // First GraphDef version at which the op is disallowed. - int32 version = 1; - - // Explanation of why it was deprecated and what to use instead. - string explanation = 2; -}; - -// A collection of OpDefs -message OpList { - repeated OpDef op = 1; -}; diff --git a/proto/tensorflow/resource_handle.proto b/proto/tensorflow/resource_handle.proto deleted file mode 100644 index 54a8696d1da51f29c87ddd70fdd7a5eee366e4bf..0000000000000000000000000000000000000000 --- a/proto/tensorflow/resource_handle.proto +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "ResourceHandle"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -// Protocol buffer representing a handle to a tensorflow resource. Handles are -// not valid across executions, but can be serialized back and forth from within -// a single run. -message ResourceHandleProto { - // Unique name for the device containing the resource. - string device = 1; - - // Container in which this resource is placed. - string container = 2; - - // Unique name of this resource. - string name = 3; - - // Hash code for the type of the resource. Is only valid in the same device - // and in the same execution. - uint64 hash_code = 4; - - // For debug-only, the name of the type pointed to by this handle, if - // available. - string maybe_type_name = 5; -}; diff --git a/proto/tensorflow/tensor.proto b/proto/tensorflow/tensor.proto deleted file mode 100644 index c4c6715993a711d2642c00e6f40eb0b8f962b2bf..0000000000000000000000000000000000000000 --- a/proto/tensorflow/tensor.proto +++ /dev/null @@ -1,109 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "TensorProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -import "resource_handle.proto"; -import "tensor_shape.proto"; -import "types.proto"; - -// Protocol buffer representing a tensor. -message TensorProto { - DataType dtype = 1; - - // Shape of the tensor. - TensorShapeProto tensor_shape = 2; - - // Only one of the representations below is set, one of "tensor_contents" and - // the "xxx_val" attributes. We are not using oneof because as oneofs cannot - // contain repeated fields it would require another extra set of messages. - - // Version number. - // - // In version 0, if the "repeated xxx" representations contain only one - // element, that element is repeated to fill the shape. This makes it easy - // to represent a constant Tensor with a single value. - int32 version_number = 3; - - // Serialized raw tensor content from either Tensor::AsProtoTensorContent or - // memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation - // can be used for all tensor types. The purpose of this representation is to - // reduce serialization overhead during RPC call by avoiding serialization of - // many repeated small items. - bytes tensor_content = 4; - - // Type specific representations that make it easy to create tensor protos in - // all languages. Only the representation corresponding to "dtype" can - // be set. The values hold the flattened representation of the tensor in - // row major order. - - // DT_HALF, DT_BFLOAT16. Note that since protobuf has no int16 type, we'll - // have some pointless zero padding for each value here. - repeated int32 half_val = 13 [packed = true]; - - // DT_FLOAT. - repeated float float_val = 5 [packed = true]; - - // DT_DOUBLE. - repeated double double_val = 6 [packed = true]; - - // DT_INT32, DT_INT16, DT_INT8, DT_UINT8. - repeated int32 int_val = 7 [packed = true]; - - // DT_STRING - repeated bytes string_val = 8; - - // DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real - // and imaginary parts of i-th single precision complex. - repeated float scomplex_val = 9 [packed = true]; - - // DT_INT64 - repeated int64 int64_val = 10 [packed = true]; - - // DT_BOOL - repeated bool bool_val = 11 [packed = true]; - - // DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real - // and imaginary parts of i-th double precision complex. - repeated double dcomplex_val = 12 [packed = true]; - - // DT_RESOURCE - repeated ResourceHandleProto resource_handle_val = 14; - - // DT_VARIANT - repeated VariantTensorDataProto variant_val = 15; - - // DT_UINT32 - repeated uint32 uint32_val = 16 [packed = true]; - - // DT_UINT64 - repeated uint64 uint64_val = 17 [packed = true]; - - // DT_COMPLEX32. icomplex_val(2*i) and icomplex_val(2*i+1) are real - // and imaginary parts of i-th single precision complex. - // Note that since protobuf has no int16 type, we'll have some - // pointless zero padding for each value here. - repeated int32 icomplex_val = 18 [packed = true]; -}; - -// Protocol buffer representing the serialization format of DT_VARIANT tensors. -message VariantTensorDataProto { - // Name of the type of objects being serialized. - string type_name = 1; - // Portions of the object that are not Tensors. - bytes metadata = 2; - // Tensors contained within objects being serialized. - repeated TensorProto tensors = 3; -} diff --git a/proto/tensorflow/tensor_shape.proto b/proto/tensorflow/tensor_shape.proto deleted file mode 100644 index 78c7814b0e40cc7fa35e475bc402923027947a7c..0000000000000000000000000000000000000000 --- a/proto/tensorflow/tensor_shape.proto +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; -option cc_enable_arenas = true; -option java_outer_classname = "TensorShapeProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -package domi.tensorflow; - -// Dimensions of a tensor. -message TensorShapeProto { - // One dimension of the tensor. - message Dim { - // Size of the tensor in that dimension. - // This value must be >= -1, but values of -1 are reserved for "unknown" - // shapes (values of -1 mean "unknown" dimension). Certain wrappers - // that work with TensorShapeProto may fail at runtime when deserializing - // a TensorShapeProto containing a dim value of -1. - int64 size = 1; - - // Optional name of the tensor dimension. - string name = 2; - }; - - // Dimensions of the tensor, such as {"input", 30}, {"output", 40} - // for a 30 x 40 2D tensor. If an entry has size -1, this - // corresponds to a dimension of unknown size. The names are - // optional. - // - // The order of entries in "dim" matters: It indicates the layout of the - // values in the tensor in-memory representation. - // - // The first entry in "dim" is the outermost dimension used to layout the - // values, the last entry is the innermost dimension. This matches the - // in-memory layout of RowMajor Eigen tensors. - // - // If "dim.size()" > 0, "unknown_rank" must be false. - repeated Dim dim = 2; - - // If true, the number of dimensions in the shape is unknown. - // - // If true, "dim.size()" must be 0. - bool unknown_rank = 3; -}; diff --git a/proto/tensorflow/types.proto b/proto/tensorflow/types.proto deleted file mode 100644 index 712a8c459c83563b81c388cdbedcb773559a9273..0000000000000000000000000000000000000000 --- a/proto/tensorflow/types.proto +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "TypesProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -// LINT.IfChange -enum DataType { - // Not a legal value for DataType. Used to indicate a DataType field - // has not been set. - DT_INVALID = 0; - - // Data types that all computation devices are expected to be - // capable to support. - DT_FLOAT = 1; - DT_DOUBLE = 2; - DT_INT32 = 3; - DT_UINT8 = 4; - DT_INT16 = 5; - DT_INT8 = 6; - DT_STRING = 7; - DT_COMPLEX64 = 8; // Single-precision complex - DT_INT64 = 9; - DT_BOOL = 10; - DT_QINT8 = 11; // Quantized int8 - DT_QUINT8 = 12; // Quantized uint8 - DT_QINT32 = 13; // Quantized int32 - DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops. - DT_QINT16 = 15; // Quantized int16 - DT_QUINT16 = 16; // Quantized uint16 - DT_UINT16 = 17; - DT_COMPLEX128 = 18; // Double-precision complex - DT_HALF = 19; - DT_RESOURCE = 20; - DT_VARIANT = 21; // Arbitrary C++ data types - DT_UINT32 = 22; - DT_UINT64 = 23; - DT_COMPLEX32 = 24; - - // Do not use! These are only for parameters. Every enum above - // should have a corresponding value below (verified by types_test). - DT_FLOAT_REF = 101; - DT_DOUBLE_REF = 102; - DT_INT32_REF = 103; - DT_UINT8_REF = 104; - DT_INT16_REF = 105; - DT_INT8_REF = 106; - DT_STRING_REF = 107; - DT_COMPLEX64_REF = 108; - DT_INT64_REF = 109; - DT_BOOL_REF = 110; - DT_QINT8_REF = 111; - DT_QUINT8_REF = 112; - DT_QINT32_REF = 113; - DT_BFLOAT16_REF = 114; - DT_QINT16_REF = 115; - DT_QUINT16_REF = 116; - DT_UINT16_REF = 117; - DT_COMPLEX128_REF = 118; - DT_HALF_REF = 119; - DT_RESOURCE_REF = 120; - DT_VARIANT_REF = 121; - DT_UINT32_REF = 122; - DT_UINT64_REF = 123; - DT_COMPLEX32_REF = 124; -} -// LINT.ThenChange( -// https://www.tensorflow.org/code/tensorflow/c/c_api.h, -// https://www.tensorflow.org/code/tensorflow/go/tensor.go, -// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc, -// https://www.tensorflow.org/code/tensorflow/core/framework/types.h, -// https://www.tensorflow.org/code/tensorflow/core/framework/types.cc, -// https://www.tensorflow.org/code/tensorflow/python/framework/dtypes.py, -// https://www.tensorflow.org/code/tensorflow/python/framework/function.py) diff --git a/proto/tensorflow/versions.proto b/proto/tensorflow/versions.proto deleted file mode 100644 index 707416eef6b0023e2722b8da736e240441ac8ee0..0000000000000000000000000000000000000000 --- a/proto/tensorflow/versions.proto +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; - -package domi.tensorflow; -option cc_enable_arenas = true; -option java_outer_classname = "VersionsProtos"; -option java_multiple_files = true; -option java_package = "org.tensorflow.framework"; - -// Version information for a piece of serialized data -// -// There are different types of versions for each type of data -// (GraphDef, etc.), but they all have the same common shape -// described here. -// -// Each consumer has "consumer" and "min_producer" versions (specified -// elsewhere). A consumer is allowed to consume this data if -// -// producer >= min_producer -// consumer >= min_consumer -// consumer not in bad_consumers -// -message VersionDef { - // The version of the code that produced this data. - int32 producer = 1; - - // Any consumer below this version is not allowed to consume this data. - int32 min_consumer = 2; - - // Specific consumer versions which are disallowed (e.g. due to bugs). - repeated int32 bad_consumers = 3; -}; diff --git a/proto/var_manager.proto b/proto/var_manager.proto deleted file mode 100644 index b658af53606d85a364e69bfd8572bac0c49f37ee..0000000000000000000000000000000000000000 --- a/proto/var_manager.proto +++ /dev/null @@ -1,106 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -syntax = "proto3"; -package deployer; -import "ge_ir.proto"; - -message VarAddrMgrInfo { - ge.proto.TensorDescriptor desc = 1; - uint64 address = 2; - uint64 offset = 3; - uint64 memory_type = 4; -} - -message VarDevAddrMgr { - ge.proto.TensorDescriptor desc = 1; - uint64 address = 2; - uint64 dev_addr = 3; -} - -message SingleTransNodeInfo { - string node_type = 1; - ge.proto.TensorDescriptor input = 2; - ge.proto.TensorDescriptor output = 3; -} - -message TransNodeMultiInfo { - repeated SingleTransNodeInfo node_info = 1; -} - -message BroadcastInfo { - string var_name = 1; - string broadcast_name = 2; - int32 idx =3; - int64 input_offset = 4; - uint64 input_size = 5; - int64 output_offset = 6; - uint64 output_size = 7; -} - -message BroadcastMultiInfo { - map broadcast_info = 1; -} - -message VarDescInfo { - map cur_var_tensor_desc_map = 1; - map var_to_trans_road = 2; - repeated string changed_var_names = 3; - map staged_var_tensor_desc_map = 4; -} - -message VarMatchInfo { - VarDescInfo desc_info_before_compile = 1; - VarDescInfo desc_info_after_compile = 2; -} - -message VarResourceInfo { - map var_offset_map = 1; - map var_addr_mgr_map = 2; - map cur_var_tensor_desc_map = 3; - map var_to_trans_road = 4; - map var_names_to_changed_graph_id = 5; - map var_names_to_allocated_graph_id = 6; - map var_broad_cast_info = 7; - map var_dev_addr_mgr_map = 8; -} - -message MemResourceInfo { - uint64 total_size = 1; - uint64 var_mem_size = 2; -} - -message VarManagerInfo { - uint32 version = 1; - uint64 session_id = 2; - uint32 device_id = 3; - uint64 job_id = 4; - uint64 graph_mem_max_size = 5; - uint64 var_mem_max_size = 6; - uint64 var_mem_logic_base = 7; - uint64 use_max_mem_size = 8; - VarResourceInfo var_resource = 9; - map mem_resource_map = 10; - bool var_mem_auto_malloc = 11; -} - -message MultiVarManagerInfo { - repeated VarManagerInfo var_manager_info = 1; -} - -message SharedContentDescription { - uint64 session_id = 1; - string node_name = 2; - uint64 head_offset = 3; - uint64 total_length = 4; - uint64 current_offset = 5; - uint32 mem_type = 6; - ge.proto.TensorDescriptor tensor_desc = 7; - bytes om_content = 8; -} diff --git a/register/CMakeLists.txt b/register/CMakeLists.txt deleted file mode 100644 index 541a94b2ad1c7df76007db3fbf5c87d6454e37c2..0000000000000000000000000000000000000000 --- a/register/CMakeLists.txt +++ /dev/null @@ -1,421 +0,0 @@ -# Copyright (c) 2024 Huawei Technologies Co., Ltd. -# This file is a part of the CANN Open Software. -# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. -# ====================================================================================================================== - -include(${METADEF_DIR}/cmake/build_type.cmake) -set(SRC_LIST - "register.cpp" - "register_custom_pass.cpp" - "prototype_pass_registry.cc" - "ops_kernel_builder_registry.cc" - "graph_optimizer/fusion_common/op_slice_info.cc" - "graph_optimizer/fusion_common/fusion_pass_desc.cc" - "graph_optimizer/fusion_common/fusion_turbo_utils.cc" - "graph_optimizer/fusion_common/fusion_turbo.cc" - "graph_optimizer/fusion_common/unknown_shape_utils.cc" - "graph_optimizer/fusion_common/fusion_config_info.cc" - "graph_optimizer/graph_fusion/graph_fusion_pass_base.cc" - "graph_optimizer/graph_fusion/connection_matrix.cc" - "graph_optimizer/graph_fusion/fusion_pass_registry.cc" - "graph_optimizer/graph_fusion/fusion_pattern.cc" - "graph_optimizer/graph_fusion/pattern_fusion_base_pass.cc" - "graph_optimizer/graph_fusion/graph_pass_util.cc" - "graph_optimizer/graph_fusion/pattern_fusion_base_pass_impl.cc" - "graph_optimizer/graph_fusion/pattern_fusion_base_pass_impl.h" - "graph_optimizer/graph_fusion/fusion_quant_util.cc" - "graph_optimizer/graph_fusion/fusion_quant_util_impl.cc" - "graph_optimizer/buffer_fusion/buffer_fusion_pass_registry.cc" - "graph_optimizer/buffer_fusion/buffer_fusion_pass_base.cc" - "graph_optimizer/buffer_fusion/buffer_fusion_pattern.cc" - "graph_optimizer/fusion_statistic/fusion_statistic_recorder.cc" - "op_kernel_registry.cpp" - "auto_mapping_util.cpp" - "ffts_plus_update_manager.cc" - "ffts_node_converter_registry.cc" - "ffts_node_calculater_registry.cc" - "op_ext_calc_param_registry.cc" - "op_ext_gentask_registry.cc" - "host_cpu_context.cc" - "hidden_input_func_registry.cc" - "hidden_inputs_func_registry.cc" - "tensor_assign.cpp" - "infer_data_slice_registry.cc" - "infer_axis_slice_registry.cc" - "kernel_register_data.cc" - "kernel_registry_impl.cc" - "node_converter_registry.cc" - "op_binary_resource_manager.cc" - "exe_res_generation_context.cc" - "op_lib_register.cc" - "register_base.cc" - "scope/scope_graph.cc" - "scope/scope_pass.cc" - "scope/scope_pattern.cc" - "scope/scope_util.cc" - "scope/scope_pass_registry.cc" - "shape_inference.cc" - "ascendc/ascendc_py.cc" - "ascendc/op_check.cc" - "ascendc/tilingdata_base.cc" - "opdef/op_def.cc" - "opdef/op_def_attr.cc" - "opdef/op_def_param.cc" - "opdef/op_def_aicore.cc" - "opdef/op_def_factory.cc" - "opdef/op_def_mc2.cc" - "opdef/op_config_registry.cc" - "device_op_impl_registry.cc" - "tuning_tiling_registry.cc" - "tuning_bank_key_registry.cc" - "stream_manage_func_registry.cc" - "optimization_option_registry.cc" - "pass_option_utils.cc" - "kernel_launch_info.cc" - "kernel_launch_info_impl.cc" - "${METADEF_DIR}/exe_graph/lowering/kernel_run_context_builder.cc" - "${METADEF_DIR}/exe_graph/lowering/bg_kernel_context_extend.cc" - "${METADEF_DIR}/exe_graph/lowering/buffer_pool.cc" - "${METADEF_DIR}/exe_graph/lowering/bg_ir_attrs.cc" -) - -############ libregister.so ############ -add_library(register SHARED - ${SRC_LIST} - $ - $ -) - -target_compile_options(register PRIVATE - $<$,$>: -fno-common -Wextra -Wfloat-equal> - $<$:/utf-8> - $<$,$>:/MTd> - $<$,$>:/MT> -) - -target_compile_definitions(register PRIVATE - google=ascend_private - $,OS_TYPE=WIN,OS_TYPE=0> - $<$:SECUREC_USING_STD_SECURE_LIB=0 NOMINMAX> - $<$:ONLY_COMPILE_OPEN_SRC> -) - -target_include_directories(register PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${METADEF_DIR} - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/metadef_protos -) - -target_link_options(register PRIVATE - -Wl,-Bsymbolic -) - -target_link_libraries(register - PRIVATE - intf_pub - msprof_headers - cce_headers - runtime_headers - -Wl,--whole-archive - op_tiling_o2 - -Wl,--no-whole-archive - -Wl,--no-as-needed - ascend_protobuf - c_sec - slog - platform - graph - error_manager - opp_registry - static_mmpa - -Wl,--as-needed - json - PUBLIC - metadef_headers -) -file(GLOB_RECURSE OPP_REGISTRY_SRC_LIST - ${METADEF_DIR}/base/registry/*.cc -) - -############ libregister.a ############ -add_library(register_static STATIC - ${SRC_LIST} - $ - $ - "op_tiling/op_tiling.cc" - "op_tiling/op_tiling_info.cc" - "op_tiling/op_tiling_utils.cc" - "op_tiling/op_tiling_attr_utils.cc" - "op_tiling/op_compile_info_manager.cc" - "op_tiling/op_tiling_registry.cc" - "op_tiling/op_tiling_py.cc" -) - -target_compile_options(register_static PRIVATE - $<$,$>: -fno-common -Wextra -Wfloat-equal> - $<$:/utf-8> - $<$,$>:/MTd> - $<$,$>:/MT> -) - -target_compile_definitions(register_static PRIVATE - google=ascend_private - $,OS_TYPE=WIN,OS_TYPE=0> - $<$:SECUREC_USING_STD_SECURE_LIB=0 NOMINMAX> - $<$:ONLY_COMPILE_OPEN_SRC> - LOG_CPP -) - -target_include_directories(register_static PUBLIC - ${METADEF_DIR}/graph - ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/base - ${METADEF_DIR}/inc/external - ${METADEF_DIR}/inc/external/base - ${METADEF_DIR}/inc/register - ${TOP_DIR}/ace/npuruntime/runtime/platform/inc -) - -target_include_directories(register_static PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${METADEF_DIR} - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/metadef_protos -) - -target_link_libraries(register_static PRIVATE - c_sec - json - intf_pub - slog_headers - msprof_headers - mmpa_headers - cce_headers - runtime_headers - metadef_headers -) - -target_link_libraries(register_static PRIVATE - ascend_protobuf_static - ) - -set_target_properties(register_static PROPERTIES - WINDOWS_EXPORT_ALL_SYMBOLS TRUE - OUTPUT_NAME $,libregister,register> -) - -############ libop_tiling_o2.a ############ -add_library(op_tiling_o2 STATIC - "op_tiling/op_tiling.cc" - "op_tiling/op_tiling_info.cc" - "op_tiling/op_tiling_utils.cc" - "op_tiling/op_tiling_attr_utils.cc" - "op_tiling/op_compile_info_manager.cc" - "op_tiling/op_tiling_registry.cc" - "op_tiling/op_tiling_py.cc") - -add_dependencies(op_tiling_o2 - metadef_protos -) - -target_include_directories(op_tiling_o2 PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${METADEF_DIR} - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/metadef_protos -) - -target_compile_options(op_tiling_o2 PRIVATE - -O2 - $<$,$>:/MTd> - $<$,$>:/MT> - $<$,$>: -fno-common -Wextra -Wfloat-equal> -) - -target_compile_definitions(op_tiling_o2 PRIVATE - $,OS_TYPE=WIN,OS_TYPE=0> - $<$:ONLY_COMPILE_OPEN_SRC> - $<$:SECUREC_USING_STD_SECURE_LIB=0 NOMINMAX> - LOG_CPP -) - -target_link_libraries(op_tiling_o2 PRIVATE - intf_pub - slog_headers - mmpa_headers - metadef_headers - ascend_protobuf - json - c_sec -) - -############ librt2_registry.a ############ -add_library(rt2_registry_objects OBJECT - "${METADEF_DIR}/base/registry/op_impl_registry.cc" - "${METADEF_DIR}/base/registry/op_ct_impl_registry.cc" - "${METADEF_DIR}/base/registry/op_impl_functions.cc" - "${METADEF_DIR}/register/op_bin_info.cc" -) - -target_compile_options(rt2_registry_objects PRIVATE - $<$,$>: -fvisibility=hidden -fno-common -fPIC -O2 -Werror -Wextra -Wfloat-equal> - $<$:/utf-8> - $<$,$>:/MTd> - $<$,$>:/MT> -) - -target_compile_definitions(rt2_registry_objects PRIVATE - $,OS_TYPE=WIN,OS_TYPE=0> - $<$:SECUREC_USING_STD_SECURE_LIB=0 NOMINMAX> - $<$:ONLY_COMPILE_OPEN_SRC> -) - -target_include_directories(rt2_registry_objects PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/external - ${METADEF_DIR}/third_party/graphengine/inc - ${METADEF_DIR}/third_party/graphengine/inc/external - ${METADEF_DIR}/third_party/fwkacllib/inc - ${TOP_DIR}/ace/npuruntime/runtime/platform/inc -) - -set_target_properties(rt2_registry_objects PROPERTIES - WINDOWS_EXPORT_ALL_SYMBOLS TRUE - OUTPUT_NAME $,librt2_registry,rt2_registry> -) - -target_link_libraries(rt2_registry_objects - PRIVATE - $ - c_sec - slog - PUBLIC - metadef_headers -) - -############ librt2_registry.a ############ -add_library(rt2_registry_static STATIC - $ -) - -set_target_properties(rt2_registry_static PROPERTIES - WINDOWS_EXPORT_ALL_SYMBOLS TRUE - OUTPUT_NAME $,librt2_registry,rt2_registry> -) - -target_link_libraries(rt2_registry_static - PUBLIC - metadef_headers -) - -############################################################## -set(STUB_HEADER_LIST - ${METADEF_DIR}/inc/external/register/op_def.h - ${METADEF_DIR}/inc/external/register/op_def_factory.h - ${METADEF_DIR}/inc/external/register/op_impl_registry.h - ${METADEF_DIR}/inc/external/register/op_lib_register.h - ${METADEF_DIR}/inc/external/register/register_base.h - ${METADEF_DIR}/inc/external/register/op_tiling_info.h - ${METADEF_DIR}/inc/external/register/op_tiling_registry.h - ${METADEF_DIR}/inc/external/register/register.h - ${METADEF_DIR}/inc/external/register/tilingdata_base.h - ${METADEF_DIR}/inc/external/register/tuning_bank_key_registry.h - ${METADEF_DIR}/inc/external/register/tuning_tiling_registry.h - ${METADEF_DIR}/inc/external/register/hidden_input_func_registry.h - ${METADEF_DIR}/inc/external/register/hidden_inputs_func_registry.h - ${METADEF_DIR}/inc/register/op_tiling.h - ${METADEF_DIR}/inc/register/op_impl_registry_base.h - ${METADEF_DIR}/inc/register/op_impl_space_registry.h - ${METADEF_DIR}/inc/register/stream_manage_func_registry.h - ${METADEF_DIR}/inc/register/op_impl_registry_holder_manager.h - ${METADEF_DIR}/inc/register/optimization_option_registry.h - ${METADEF_DIR}/inc/register/opp_so_manager.h - ${METADEF_DIR}/inc/register/op_lib_register_impl.h -) - -list(TRANSFORM STUB_HEADER_LIST - REPLACE "^.*/([^/]+)\\.h$" "${CMAKE_CURRENT_BINARY_DIR}/stub_\\1.cc" - OUTPUT_VARIABLE STUB_SRC_LIST -) - -add_custom_command( - OUTPUT ${STUB_SRC_LIST} - COMMAND echo "Generating stub files." - && ${HI_PYTHON} ${METADEF_DIR}/tests/stub/gen_stubapi.py ${CMAKE_CURRENT_BINARY_DIR} ${STUB_HEADER_LIST} - && echo "Generating stub files end." -) - -add_custom_target(register_stub DEPENDS ${STUB_SRC_LIST}) - -############ stub/libregister.so ############ -add_library(stub_register SHARED ${STUB_SRC_LIST}) - -add_dependencies(stub_register metadef_protos register_stub) - -target_include_directories(stub_register PRIVATE - ${CMAKE_CURRENT_LIST_DIR} - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/metadef_protos -) - -target_compile_options(stub_register PRIVATE - -Wfloat-equal - -fno-common - -Os - -Werror=return-type -) - -target_link_libraries(stub_register - PRIVATE - intf_pub - slog_headers - msprof_headers - runtime_headers - c_sec_headers - ascend_protobuf - c_sec - slog - json - PUBLIC - metadef_headers -) - -set_target_properties(stub_register PROPERTIES - OUTPUT_NAME register - LIBRARY_OUTPUT_DIRECTORY stub -) - - -############ stub/libregister.a ############ -if (NOT ENABLE_OPEN_SRC) - target_clone(stub_register stub_register_static STATIC) - - add_dependencies(stub_register_static metadef_protos register_stub) - - target_compile_options(stub_register_static PRIVATE -ffunction-sections -fdata-sections) - - set_target_properties(stub_register_static PROPERTIES - OUTPUT_NAME register - ARCHIVE_OUTPUT_DIRECTORY stub - ) -endif () - - -############ install ############ -install(TARGETS register_static - ARCHIVE DESTINATION ${INSTALL_LIBRARY_DIR} OPTIONAL -) - -install(TARGETS rt2_registry_static - ARCHIVE DESTINATION ${INSTALL_LIBRARY_DIR}/${CMAKE_SYSTEM_PROCESSOR} OPTIONAL -) - -install(TARGETS stub_register OPTIONAL - LIBRARY DESTINATION ${INSTALL_LIBRARY_DIR}/${CMAKE_SYSTEM_PROCESSOR}/stub -) diff --git a/register/ascendc/ascendc_py.cc b/register/ascendc/ascendc_py.cc deleted file mode 100644 index 7fd9e2f49e98e16661c5db32e9f7915e9980108c..0000000000000000000000000000000000000000 --- a/register/ascendc/ascendc_py.cc +++ /dev/null @@ -1,522 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/ge_tensor.h" -#include "graph/op_desc.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/type_utils.h" -#include "graph/debug/ge_log.h" -#include "register/op_tiling_info.h" -#include "register/op_tiling_registry.h" -#include "op_tiling/op_tiling_utils.h" -#include "op_tiling/op_tiling_constants.h" -#include "common/util/tiling_utils.h" -#include "register/op_impl_registry.h" -#include "exe_graph/runtime/storage_shape.h" -#include "exe_graph/lowering/kernel_run_context_builder.h" -#include "exe_graph/runtime/tiling_context.h" -#include "common/checker.h" -#include "common/util/mem_utils.h" -#include "register/op_check_register.h" -#include "register/tilingdata_base.h" - -namespace { -bool DumpResultInfo(const std::string &result_string, char *result_info_char, const size_t result_info_len) { - if (result_info_char == nullptr) { - GE_LOGE("run_info buffer is null"); - return false; - } - if (result_string.size() >= result_info_len) { - GE_LOGE("result_info too large. %zu/%zu", result_string.size(), result_info_len); - return false; - } - return memcpy_s(result_info_char, result_string.size() + 1, result_string.c_str(), result_string.size() + 1) == EOK; -} - -void CopyConstDataWithFloat16(const nlohmann::json &json_array, std::vector &value) { - std::vector const_value = json_array.get>(); - std::vector const_data_vec; - for (const auto i : const_value) { - uint16_t const_data_uint16 = optiling::Float32ToFloat16(i); - const_data_vec.emplace_back(const_data_uint16); - } - uint8_t *pv_begin = reinterpret_cast(const_data_vec.data()); - uint8_t *pv_end = pv_begin + (const_data_vec.size() * sizeof(uint16_t)); - value = std::vector(pv_begin, pv_end); -} - -template -void GetConstDataPointer(const nlohmann::json &json_array, std::vector &const_value) { - std::vector value = json_array.get>(); - uint8_t *pv_begin = reinterpret_cast(value.data()); - uint8_t *pv_end = pv_begin + (value.size() * sizeof(T)); - const_value = std::vector(pv_begin, pv_end); -} - -bool CopyConstData(const std::string &dtype, const nlohmann::json &json_array, std::vector &value) { - if (dtype == "int8") { - GetConstDataPointer(json_array, value); - } else if (dtype == "uint8") { - GetConstDataPointer(json_array, value); - } else if (dtype == "int16") { - GetConstDataPointer(json_array, value); - } else if (dtype == "uint16") { - GetConstDataPointer(json_array, value); - } else if (dtype == "int32") { - GetConstDataPointer(json_array, value); - } else if (dtype == "uint32") { - GetConstDataPointer(json_array, value); - } else if (dtype == "int64") { - GetConstDataPointer(json_array, value); - } else if (dtype == "uint64") { - GetConstDataPointer(json_array, value); - } else if (dtype == "float32") { - GetConstDataPointer(json_array, value); - } else if (dtype == "double") { - GetConstDataPointer(json_array, value); - } else if (dtype == "float16") { - CopyConstDataWithFloat16(json_array, value); - } else { - GE_LOGE("Unknown dtype: %s", dtype.c_str()); - return false; - } - return true; -} - -void ParseConstShapeDescV2(const nlohmann::json &shape_json, ge::Operator &op_para, - std::map> &const_values) { - std::vector shape; - std::string format_str; - std::string dtype_str; - - if (!shape_json.contains("const_value")) { - GELOGI("Not const tenosr"); - return; - } - if (!shape_json.contains("name")) { - REPORT_INNER_ERR_MSG("E19999", "const tensor has no name"); - return; - } - std::string name = shape_json["name"]; - - if (shape_json.contains("shape")) { - shape = shape_json["shape"].get>(); - } - if (shape_json.contains("format")) { - format_str = shape_json["format"].get(); - } - if (shape_json.contains("dtype")) { - dtype_str = shape_json["dtype"].get(); - } - - std::vector value; - const bool bres = CopyConstData(dtype_str, shape_json["const_value"], value); - if (!bres) { - REPORT_INNER_ERR_MSG("E19999", "CopyConstData faild. buffer is null"); - return; - } - auto res = const_values.emplace(name, std::move(value)); - if (res.first == const_values.end()) { - return; // CodeDEX complains 'CHECK_CONTAINER_EMPTY' - } - - const ge::GeShape ge_shape(shape); - ge::DataType ge_dtype = ge::DT_UNDEFINED; - if (!dtype_str.empty()) { - std::transform(dtype_str.begin(), dtype_str.end(), dtype_str.begin(), ::toupper); - dtype_str = "DT_" + dtype_str; - ge_dtype = ge::TypeUtils::SerialStringToDataType(dtype_str); - } - ge::Format ge_format = ge::FORMAT_RESERVED; - if (!format_str.empty()) { - std::transform(format_str.begin(), format_str.end(), format_str.begin(), ::toupper); - ge_format = ge::TypeUtils::SerialStringToFormat(format_str); - } - ge::GeTensorDesc ge_tensor(ge_shape, ge_format, ge_dtype); - ge_tensor.SetName(name); - ge::GeTensor const_tensor(ge_tensor, res.first->second); - ge::GeTensorPtr const_tensor_ptr = ge::MakeShared(const_tensor); - ge::OpDescPtr const_op_desc = ge::OpDescUtils::CreateConstOp(const_tensor_ptr); - ge::Operator const_op = ge::OpDescUtils::CreateOperatorFromOpDesc(const_op_desc); - (void) op_para.SetInput(name.c_str(), const_op); -} - -void ParseShapeDescV2(const nlohmann::json &shape, ge::OpDescPtr &op_desc, const bool is_input) { - ge::GeTensorDesc tensor; - if (shape.contains("shape")) { - tensor.SetShape(ge::GeShape(shape["shape"].get>())); - } - if (shape.contains("ori_shape")) { - tensor.SetOriginShape(ge::GeShape(shape["ori_shape"].get>())); - } - if (shape.contains("format")) { - std::string format_str = shape["format"].get(); - std::transform(format_str.begin(), format_str.end(), format_str.begin(), ::toupper); - ge::Format ge_format = ge::TypeUtils::SerialStringToFormat(format_str); - tensor.SetFormat(ge_format); - } - if (shape.contains("ori_format")) { - std::string format_str = shape["ori_format"].get(); - std::transform(format_str.begin(), format_str.end(), format_str.begin(), ::toupper); - ge::Format ge_format = ge::TypeUtils::SerialStringToFormat(format_str); - tensor.SetOriginFormat(ge_format); - } - if (shape.contains("dtype")) { - std::string dtype_str = shape["dtype"].get(); - std::transform(dtype_str.begin(), dtype_str.end(), dtype_str.begin(), ::toupper); - dtype_str = "DT_" + dtype_str; - ge::DataType ge_dtype = ge::TypeUtils::SerialStringToDataType(dtype_str); - tensor.SetDataType(ge_dtype); - } - if (shape.contains("name")) { - std::string name = shape["name"]; - tensor.SetName(name); - is_input ? op_desc->AddInputDesc(name, tensor) : op_desc->AddOutputDesc(name, tensor); - } else { - is_input ? op_desc->AddInputDesc(tensor) : op_desc->AddOutputDesc(tensor); - } -} - -void ParseShapeDescListV2(const nlohmann::json &shape_list, ge::OpDescPtr &op_desc, const bool is_input) { - for (const auto &elem : shape_list) { - if (elem.is_array()) { - for (const auto &shape : elem) { - ParseShapeDescV2(shape, op_desc, is_input); - } - } else { - ParseShapeDescV2(elem, op_desc, is_input); - } - } -} - -void ParseConstTensorListV2(const nlohmann::json &shape_list, ge::Operator &operator_para, - std::map> &const_values) { - for (const auto &elem : shape_list) { - if (elem.is_array()) { - for (const auto &shape : elem) { - ParseConstShapeDescV2(shape, operator_para, const_values); - } - } else { - ParseConstShapeDescV2(elem, operator_para, const_values); - } - } -} - -template -void ParseAndSetAttrValue(ge::Operator &op, const nlohmann::json &attr, const std::string &attr_name) { - T attr_value = attr["value"].get(); - (void) op.SetAttr(attr_name.c_str(), attr_value); -} -template -void ParseAndSetAttrListValue(ge::Operator &op, const nlohmann::json &attr, const std::string &attr_name) { - std::vector attr_value = attr["value"].get>(); - (void) op.SetAttr(attr_name.c_str(), attr_value); -} - -void ParseAndSetAttrListListValue(ge::Operator &op, const nlohmann::json &attr, const std::string &attr_name) { - std::vector> attr_value_int32 = attr["value"].get>>(); - std::vector> attr_value_int64; - std::vector temp_int64_vec; - for (const auto &vec_int32 : attr_value_int32) { - for (const auto &item : vec_int32) { - int64_t tmp = static_cast(item); - temp_int64_vec.emplace_back(tmp); - } - attr_value_int64.emplace_back(temp_int64_vec); - temp_int64_vec.clear(); - } - - (void) op.SetAttr(attr_name.c_str(), attr_value_int64); -} - -void ParseAndSetAttrListListInt64Value(ge::Operator &op, const nlohmann::json &attr, const std::string &attr_name) { - const std::vector> attr_value_int64 = attr["value"].get>>(); - (void) op.SetAttr(attr_name.c_str(), attr_value_int64); -} - -using ParseAndSetAttrValueFunc = std::function; -using ParseAndSetAttrValuePtr = std::shared_ptr; - -const std::map parse_attr_dtype_map = { - {"bool", ge::MakeShared(&ParseAndSetAttrValue)}, - {"float", ge::MakeShared(&ParseAndSetAttrValue)}, - {"float32", ge::MakeShared(&ParseAndSetAttrValue)}, - {"int", ge::MakeShared(&ParseAndSetAttrValue)}, - {"int32", ge::MakeShared(&ParseAndSetAttrValue)}, - {"int64", ge::MakeShared(&ParseAndSetAttrValue)}, - {"str", ge::MakeShared(&ParseAndSetAttrValue)}, - {"list_bool", ge::MakeShared(&ParseAndSetAttrListValue)}, - {"list_float", ge::MakeShared(&ParseAndSetAttrListValue)}, - {"list_float32", ge::MakeShared(&ParseAndSetAttrListValue)}, - {"list_int", ge::MakeShared(&ParseAndSetAttrListValue)}, - {"list_int32", ge::MakeShared(&ParseAndSetAttrListValue)}, - {"list_int64", ge::MakeShared(&ParseAndSetAttrListValue)}, - {"list_str", ge::MakeShared(&ParseAndSetAttrListValue)}, - {"list_list_int", ge::MakeShared(&ParseAndSetAttrListListValue)}, - {"list_list_int32", ge::MakeShared(&ParseAndSetAttrListListValue)}, - {"list_list_int64", ge::MakeShared(&ParseAndSetAttrListListInt64Value)}}; - -void ParseAndSetAttr(const nlohmann::json &attr, ge::Operator &op) { - if ((!attr.contains("name")) || (!attr.contains("dtype")) || (!attr.contains("value"))) { - REPORT_INNER_ERR_MSG("E19999", "cur attr does not contain name or dtype or value."); - return; - } - std::string attr_name; - std::string dtype; - attr_name = attr["name"].get(); - dtype = attr["dtype"].get(); - auto iter = parse_attr_dtype_map.find(dtype); - if (iter == parse_attr_dtype_map.end()) { - REPORT_INNER_ERR_MSG("E19999", "Unknown dtype[%s], which is unsupported.", dtype.c_str()); - return; - } - ParseAndSetAttrValuePtr func_ptr = iter->second; - if (func_ptr == nullptr) { - GE_LOGE("ParseAndSetAttrValueFunc ptr cannot be null!"); - return; - } - (*func_ptr)(op, attr, attr_name); - GELOGD("Finish to set attr[name: %s] to Operator.", attr_name.c_str()); -} - -void ParseAndSetAttrsList(const nlohmann::json &attrs_list, ge::Operator &op) { - for (const auto &attr : attrs_list) { - ParseAndSetAttr(attr, op); - } -} - -void CheckAndSetAttr(const char *attrs, ge::Operator &operator_param) { - if (attrs != nullptr) { - GELOGD("Attrs set from pyAPI is: %s", attrs); - const nlohmann::json attrs_json = nlohmann::json::parse(attrs); - ParseAndSetAttrsList(attrs_json, operator_param); - } else { - GELOGD("Attrs has not been set."); - } - return; -} - -void ParseInputsAndOutputs(const char *inputs, const char *outputs, ge::OpDescPtr &op_desc, - ge::Operator &operator_param, std::map> &const_values) { - const nlohmann::json inputs_json = nlohmann::json::parse(inputs); - const nlohmann::json outputs_json = nlohmann::json::parse(outputs); - ParseShapeDescListV2(inputs_json, op_desc, true); - ParseShapeDescListV2(outputs_json, op_desc, false); - operator_param = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc); - ParseConstTensorListV2(inputs_json, operator_param, const_values); -} -} // Anonymous Namespace - -using namespace optiling; - -extern "C" int32_t AscendCPyInterfaceCheckOp(const char *check_type, const char *optype, const char *inputs, - const char *outputs, const char *attrs, char *result_info, - const size_t result_info_len) { - if ((check_type == nullptr) || (optype == nullptr) || (inputs == nullptr) || (outputs == nullptr) || - (attrs == nullptr) || (result_info == nullptr)) { - GELOGE(ge::GRAPH_FAILED, "check_type/optype/inputs/outputs/attrs/result_info is null, %s, %s, %s, %s, %s, %s", - check_type, optype, inputs, outputs, attrs, result_info); - return 0; - } - ge::AscendString check_type_str = check_type; - ge::AscendString op_type_str = optype; - auto check_func = OpCheckFuncRegistry::GetOpCapability(check_type_str, op_type_str); - if (check_func == nullptr) { - GELOGW("Failed to GetOpCapability. check_type = %s, optype = %s", check_type, optype); - return 0; - } - - ge::OpDescPtr op_desc_ptr = ge::MakeShared("", op_type_str.GetString()); - std::map> const_values; - ge::Operator operator_param; - try { - ParseInputsAndOutputs(inputs, outputs, op_desc_ptr, operator_param, const_values); - CheckAndSetAttr(attrs, operator_param); - } catch (...) { - REPORT_INNER_ERR_MSG("E19999", - "Failed to parse json in AscendCPyInterfaceCheckOp. inputs = %s, outputs = %s, attrs = %s", - inputs, outputs, attrs); - return 0; - } - - ge::AscendString result; - try { - const ge::graphStatus rc = (check_func)(operator_param, result); - GELOGI("check cap return rc = %u, check_type = %s, optype = %s.", rc, check_type, optype); - if (rc != ge::GRAPH_SUCCESS) { - GELOGE(ge::GRAPH_FAILED, "check cap return failed. check_type = %s, optype = %s", check_type, optype); - return 0; - } - } catch (...) { - GELOGE(ge::GRAPH_FAILED, "check cap abnormal failed. check_type = %s, optype = %s", check_type, optype); - return 0; - } - std::string std_str = result.GetString(); - bool dump_res = DumpResultInfo(std_str, result_info, result_info_len); - if (!dump_res) { - REPORT_INNER_ERR_MSG("E19999", "DumpResultInfo failed. result = %s", std_str.c_str()); - return 0; - } - return 1; -} - -extern "C" int32_t AscendCPyInterfaceGeneralized(const char *optype, const char *inputs, const char *outputs, - const char *attrs, const char *generalize_config, char *result_info, - const size_t result_info_len) { - if ((optype == nullptr) || (inputs == nullptr) || (outputs == nullptr) || (attrs == nullptr) || - (generalize_config == nullptr) || (result_info == nullptr)) { - GELOGE(ge::GRAPH_FAILED, - "optype/inputs/outputs/attrs/generalize_config/result_info is null, %s, %s, %s, %s, %s, %s", optype, inputs, - outputs, attrs, generalize_config, result_info); - return 0; - } - ge::AscendString op_type_str = optype; - auto generalize_func = OpCheckFuncRegistry::GetParamGeneralize(op_type_str); - if (generalize_func == nullptr) { - GELOGW("Failed to GetParamGeneralize. optype = %s", optype); - return 0; - } - - ge::OpDescPtr op_desc_ptr = ge::MakeShared("", op_type_str.GetString()); - std::map> const_values; - ge::Operator operator_params; - try { - ParseInputsAndOutputs(inputs, outputs, op_desc_ptr, operator_params, const_values); - CheckAndSetAttr(attrs, operator_params); - } catch (...) { - GELOGE(ge::GRAPH_FAILED, "Failed to parse json in AscendCPyInterfaceGeneralized. %s, %s, %s", - inputs, outputs, attrs); - return 0; - } - ge::AscendString generalize_config_str(generalize_config); - ge::AscendString result; - try { - const ge::graphStatus rc = (generalize_func)(operator_params, generalize_config_str, result); - GELOGI("generalize_func return rc = %d, optype = %s, generalize_config = %s", rc, optype, generalize_config); - if (rc != ge::GRAPH_SUCCESS) { - GELOGE(ge::GRAPH_FAILED, "call generalize_func failed. optype = %s, generalize_config = %s", optype, - generalize_config); - return 0; - } - } catch (...) { - GELOGE(ge::GRAPH_FAILED, "call generalize_func failed. optype = %s, generalize_config = %s", optype, - generalize_config); - return 0; - } - std::string result_str = result.GetString(); - bool dump_res = DumpResultInfo(result_str, result_info, result_info_len); - if (!dump_res) { - REPORT_INNER_ERR_MSG("E19999", "DumpResultInfo failed. result = %s", result_str.c_str()); - return 0; - } - return 1; -} - -extern "C" int32_t AscendCPyInterfaceGetTilingDefInfo(const char *optype, char *result_info, size_t result_info_len) { - if ((optype == nullptr) || (result_info == nullptr)) { - GELOGE(ge::GRAPH_FAILED, "optype/result_info is null, %s, %s", optype, result_info); - return 0; - } - ge::AscendString op_type_str = optype; - auto tiling_def = CTilingDataClassFactory::GetInstance().CreateTilingDataInstance(optype); - if (tiling_def == nullptr) { - GELOGW("Failed to CreateTilingDataInstance. optype = %s", optype); - return 0; - } - - nlohmann::json json_obj; - json_obj["class_name"] = tiling_def->GetTilingClassName(); - json_obj["data_size"] = tiling_def->GetDataSize(); - const auto &field_list = tiling_def->GetFieldInfo(); - nlohmann::json json_field_list; - for (const auto &field : field_list) { - nlohmann::json json_field; - json_field["classType"] = field.classType_; - json_field["name"] = field.name_; - json_field["dtype"] = field.dtype_; - if (json_field["classType"] == "1") { - json_field["arrSize"] = field.arrSize_; - } else if (json_field["classType"] == "2") { - json_field["structType"] = field.structType_; - json_field["structSize"] = field.structSize_; - } - json_field_list.emplace_back(json_field); - } - json_obj["fields"] = json_field_list; - const std::string json_str = json_obj.dump(); - bool dump_res = DumpResultInfo(json_str, result_info, result_info_len); - if (!dump_res) { - GELOGE(ge::GRAPH_FAILED, "AscendCPyInterfaceGetTilingDefInfo DumpResultInfo failed. result = %s", json_str.c_str()); - return 0; - } - return 1; -} - -extern "C" int32_t AscendCPyInterfaceOpReplay(const char *optype, const char *soc_version, const int32_t block_dim, - const char *tiling_data, const char *kernel_name, const char *entry_file, - const char *output_kernel_file, const int32_t core_type, - const int32_t task_ration, const int32_t tiling_key) { - if ((optype == nullptr) || (soc_version == nullptr) || (tiling_data == nullptr) || (kernel_name == nullptr) || - (entry_file == nullptr) || (output_kernel_file == nullptr)) { - GELOGE(ge::GRAPH_FAILED, - "optype/soc_version/tiling_data/kernel_name/entry_file/output_kernel_file is null, " - "%s, %s, %s, %s, %s, %s", - optype, soc_version, tiling_data, kernel_name, entry_file, output_kernel_file); - return 0; - } - constexpr int32_t CORE_TYPE_BOTH = 0; - constexpr int32_t CORE_TYPE_CUBE = 1; - constexpr int32_t CORE_TYPE_VEC = 2; - if ((core_type != CORE_TYPE_BOTH) && (core_type != CORE_TYPE_CUBE) && (core_type != CORE_TYPE_VEC)) { - GELOGE(ge::GRAPH_FAILED, - "core_type is valid, should be one of 0/1/2, but args is " - "%d", - core_type); - return 0; - } - constexpr int32_t TASK_RATION_ONE = 1; - constexpr int32_t TASK_RATION_TWO = 2; - if ((task_ration != TASK_RATION_ONE) && (task_ration != TASK_RATION_TWO)) { - GELOGE(ge::GRAPH_FAILED, - "task_ration is valid, should be one of 1/2, but args is " - "%d", - task_ration); - return 0; - } - ge::AscendString op_type_str = optype; - ge::AscendString soc_version_str = soc_version; - auto replay_func = OpCheckFuncRegistry::GetReplay(op_type_str, soc_version_str); - if (replay_func == nullptr) { - GELOGE(ge::GRAPH_FAILED, "Failed to GetReplay. optype = %s, soc_version = %s", optype, soc_version); - return 0; - } - - try { - ReplayFuncParam replayParam; - replayParam.block_dim = block_dim; - replayParam.tiling_data = tiling_data; - replayParam.kernel_name = kernel_name; - replayParam.entry_file = entry_file; - replayParam.gentype = 0; - replayParam.output_kernel_file = output_kernel_file; - replayParam.task_ration = task_ration; - replayParam.tiling_key = tiling_key; - const int32_t rc = (replay_func)(replayParam, core_type); - if (rc <= 0) { - GELOGE(ge::GRAPH_FAILED, "call replay_func return %d. optype = %s, soc_version = %s", rc, optype, soc_version); - return 0; - } - GELOGI("replay_func return rc = %d, optype = %s, soc_version = %s.", rc, optype, soc_version); - } catch (...) { - GELOGE(ge::GRAPH_FAILED, "call replay_func segment fault. optype = %s, soc_version = %s", optype, soc_version); - return 0; - } - return 1; -} diff --git a/register/ascendc/op_check.cc b/register/ascendc/op_check.cc deleted file mode 100644 index 645ca9e89811ca39f6bc348a50e7fe23d3820ee3..0000000000000000000000000000000000000000 --- a/register/ascendc/op_check.cc +++ /dev/null @@ -1,109 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/op_check_register.h" -#include "common/ge_common/debug/ge_log.h" - -namespace optiling { -std::map> - OpCheckFuncRegistry::check_op_capability_instance_; -std::map OpCheckFuncRegistry::gen_simplifiedkey_instance_; -std::map OpCheckFuncRegistry::param_generalize_instance_; -std::map> OpCheckFuncRegistry::replay_instance_; - -void OpCheckFuncRegistry::RegisterOpCapability(const ge::AscendString &check_type, const ge::AscendString &op_type, - OP_CHECK_FUNC func) { - check_op_capability_instance_[check_type][op_type] = func; - GELOGI("RegisterOpCapability: check_type:%s, op_type:%s, funcPointer:%p, registered count:%zu", - check_type.GetString(), op_type.GetString(), func, check_op_capability_instance_[check_type].size()); -} - -OP_CHECK_FUNC OpCheckFuncRegistry::GetOpCapability(const ge::AscendString &check_type, - const ge::AscendString &op_type) { - const auto &check_map_it = check_op_capability_instance_.find(check_type); - if (check_map_it == check_op_capability_instance_.end()) { - GELOGW("GetOpCapability: check_type:%s, op_type:%s, cannot find check_type.", check_type.GetString(), - op_type.GetString()); - return nullptr; - } - const auto &func_it = check_map_it->second.find(op_type); - if (func_it == check_map_it->second.end()) { - GELOGW("GetOpCapability: check_type:%s, op_type:%s, cannot find op_type.", check_type.GetString(), - op_type.GetString()); - return nullptr; - } - return func_it->second; -} - -void OpCheckFuncRegistry::RegisterGenSimplifiedKeyFunc(const ge::AscendString &op_type, GEN_SIMPLIFIEDKEY_FUNC func) { - gen_simplifiedkey_instance_[op_type] = func; - GELOGI("RegisterGenSimplifiedKeyFunc: op_type:%s, registered count:%zu", op_type.GetString(), - gen_simplifiedkey_instance_.size()); -} - -GEN_SIMPLIFIEDKEY_FUNC OpCheckFuncRegistry::GetGenSimplifiedKeyFun(const ge::AscendString &op_type) { - const auto &func_it = gen_simplifiedkey_instance_.find(op_type); - if (func_it == gen_simplifiedkey_instance_.end()) { - GELOGW("GetGenSimplifiedKeyFun: op_type:%s, cannot find func by op_type.", op_type.GetString()); - return nullptr; - } - return func_it->second; -} - -PARAM_GENERALIZE_FUNC OpCheckFuncRegistry::GetParamGeneralize(const ge::AscendString &op_type) { - const auto &func_it = param_generalize_instance_.find(op_type); - if (func_it == param_generalize_instance_.end()) { - GELOGW("GetParamGeneralize: op_type:%s, cannot find op_type.", op_type.GetString()); - return nullptr; - } - return func_it->second; -} - -void OpCheckFuncRegistry::RegisterParamGeneralize(const ge::AscendString &op_type, PARAM_GENERALIZE_FUNC func) { - param_generalize_instance_[op_type] = func; - GELOGI("RegisterParamGeneralize: op_type:%s, funcPointer:%p, registered count:%zu", op_type.GetString(), func, - param_generalize_instance_.size()); -} - -void OpCheckFuncRegistry::RegisterReplay(const ge::AscendString &op_type, const ge::AscendString &soc_version, - REPLAY_FUNC func) { - replay_instance_[op_type][soc_version] = func; - GELOGI("RegisterReplay: op_type:%s, soc_version:%s funcPointer:%p, registered count:%zu", op_type.GetString(), - soc_version.GetString(), func, replay_instance_[op_type].size()); -} - -REPLAY_FUNC OpCheckFuncRegistry::GetReplay(const ge::AscendString &op_type, const ge::AscendString &soc_version) { - const auto &soc_map_it = replay_instance_.find(op_type); - if (soc_map_it == replay_instance_.end()) { - GELOGW("GetReplay: op_type:%s, soc_version:%s, cannot find op_type.", op_type.GetString(), soc_version.GetString()); - return nullptr; - } - const auto &func_it = soc_map_it->second.find(soc_version); - if (func_it == soc_map_it->second.end()) { - GELOGW("GetReplay: op_type:%s, soc_version:%s, cannot find soc_version.", op_type.GetString(), - soc_version.GetString()); - return nullptr; - } - return func_it->second; -} - -OpCheckFuncHelper::OpCheckFuncHelper(const ge::AscendString &check_type, const ge::AscendString &op_type, - OP_CHECK_FUNC func) { - OpCheckFuncRegistry::RegisterOpCapability(check_type, op_type, func); -} - -OpCheckFuncHelper::OpCheckFuncHelper(const ge::AscendString &op_type, PARAM_GENERALIZE_FUNC func) { - OpCheckFuncRegistry::RegisterParamGeneralize(op_type, func); -} - -ReplayFuncHelper::ReplayFuncHelper(const ge::AscendString &op_type, const ge::AscendString &soc_version, - REPLAY_FUNC func) { - OpCheckFuncRegistry::RegisterReplay(op_type, soc_version, func); -} -} // end of namespace optiling diff --git a/register/ascendc/tilingdata_base.cc b/register/ascendc/tilingdata_base.cc deleted file mode 100644 index 0cdcbcf7c71d54e29a2723aa28db3c12f4eed2c0..0000000000000000000000000000000000000000 --- a/register/ascendc/tilingdata_base.cc +++ /dev/null @@ -1,177 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/tilingdata_base.h" -#include -#include -#ifndef ASCENDC_DEVICE_REG_STATIC -#include "common/ge_common/debug/ge_log.h" -#endif -#include "graph/ascend_string.h" - -namespace optiling { -std::vector TilingDef::GetFieldInfo() const { - return field_info_; -} - -const char *TilingDef::GetTilingClassName() const { - return class_name_; -} - -size_t TilingDef::GetDataSize() const { - return data_size_; -} - -bool CheckPathIsHeader(std::string file) { - const std::string suffix = ".h"; - if (suffix.size() > file.size()) { - return false; - } - return file.substr(file.size() - suffix.size()) == suffix; -} - -const char* GetFileName(const char* path) { - const char* file_name = strrchr(path, '/'); - if (!file_name) { - return path; - } else { - file_name++; - return file_name; - } -} - -uint32_t __attribute__((weak)) TilingDataStructBase::RecordTilingStruct(const char* name, const char* file, \ - uint32_t line) { - const char* file_name = GetFileName(file); - bool is_header = CheckPathIsHeader(std::string(file_name)); - auto it = records.find(name); - if (it != records.end()) { - std::pair item = it->second; - if (!is_header) { - if (item.second != line) { - printf("[Warning]: tiling struct [%s] is conflict with one in file %s, line %d\n", \ - name, item.first, item.second); - } - return 0; - } - if (!CheckPathIsHeader(std::string(item.first))) { - return 0; - } - if ((strcmp(item.first, file_name) == 0) && item.second == line) { - return 0; - } - printf("[Warning]: tiling struct [%s] is conflict with one in file %s, line %d\n", \ - name, item.first, item.second); - } else { - records.emplace(name, std::make_pair(file_name, line)); - } - return 0; -} - -void TilingDef::GeLogError(const std::string &str) const { -#ifndef ASCENDC_DEVICE_REG_STATIC - GELOGE(ge::GRAPH_FAILED, "%s", str.c_str()); -#endif -} - -void TilingDef::SetDataPtr(void *dataPtr) { - if (!inited_data_ptr && data_ptr_ != nullptr) { - delete[] data_ptr_; - } - inited_data_ptr = true; - data_ptr_ = (uint8_t*)dataPtr; - for (auto &ptr : saveBufferPtr) { - TilingDef* sub_ptr = (TilingDef *)ptr.first; - size_t offset = ptr.second; - uint8_t* struct_ptr = data_ptr_ + offset; - sub_ptr->SetDataPtr(struct_ptr); - } -} - -void TilingDef::SaveToBuffer(void *pdata, size_t capacity) { - if (inited_data_ptr) { -#ifndef ASCENDC_DEVICE_REG_STATIC - GELOGD("TilingDef::SaveToBuffer, op %s, data had been saved.", class_name_); -#endif - return; - } - // copy tilingdata to buffer without struct tiling data. - auto mem_ret = memcpy_s(pdata, capacity, data_ptr_, data_size_); - if (mem_ret != EOK) { -#ifndef ASCENDC_DEVICE_REG_STATIC - GELOGE(ge::GRAPH_FAILED, - "TilingDef::SaveToBuffer failed: memcpy_s return op [%s] [%d], capacity = [%zu], data_size_ = [%zu].", - class_name_, mem_ret, capacity, data_size_); -#endif - } -} - -void TilingDef::CheckAlignAndGenPlaceHolder(const char *name, size_t typeSize) { - if (data_size_ % typeSize == 0) { - return; - } - size_t alignSize = typeSize - (data_size_ % typeSize); - field_info_.emplace_back(FieldInfo("uint8_t", name, alignSize)); - data_size_ += alignSize; - return; -} - -void TilingDef::InitData() { -#ifndef ASCENDC_DEVICE_REG_STATIC - GELOGD("TilingDef::InitData, op %s, data size %d.", class_name_, data_size_); -#endif - data_ptr_ = new (std::nothrow)uint8_t[data_size_](); - if (data_ptr_ == nullptr) { -#ifndef ASCENDC_DEVICE_REG_STATIC - GELOGE(ge::GRAPH_FAILED, "TilingDef::InitData failed: op %s, init data size %d.", class_name_, data_size_); -#endif - return; - } - for (auto &ptr : saveBufferPtr) { - TilingDef* sub_ptr = (TilingDef *)ptr.first; - size_t offset = ptr.second; - uint8_t* struct_ptr = data_ptr_ + offset; - sub_ptr->SetDataPtr(struct_ptr); - } -} - -CTilingDataClassFactory &CTilingDataClassFactory::GetInstance() -{ - static CTilingDataClassFactory instance; - return instance; -} - -void CTilingDataClassFactory::RegisterTilingData(const char *op_type, - const TilingDataConstructor constructor) { - instance_.emplace(op_type, constructor); -#ifndef ASCENDC_DEVICE_REG_STATIC - GELOGD("op_type: %s, registered count: %zu.", op_type, instance_.size()); -#endif -} - -std::shared_ptr CTilingDataClassFactory::CreateTilingDataInstance(const char *op_type) { - const auto it = instance_.find(op_type); - if (it == instance_.end()) { -#ifndef ASCENDC_DEVICE_REG_STATIC - GELOGW("cannot find op_type:%s.", op_type); -#endif - return nullptr; - } - - const TilingDataConstructor constructor = it->second; - if (constructor == nullptr) { -#ifndef ASCENDC_DEVICE_REG_STATIC - GELOGW("CreateTilingDataInstance: constructor is nullptr."); -#endif - return nullptr; - } - - return (*constructor)(); -} -} // end of namespace optiling diff --git a/register/auto_mapping_util.cpp b/register/auto_mapping_util.cpp deleted file mode 100644 index 5523bda261a26d800fcb6542fdbc451b20a68ef7..0000000000000000000000000000000000000000 --- a/register/auto_mapping_util.cpp +++ /dev/null @@ -1,124 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/auto_mapping_util.h" -#include "graph/debug/ge_util.h" - -namespace { -constexpr int32_t kMaxFuncRecursiveDepth = 30; -} // namespace -namespace ge { - -// Convert tensorflow property to ge property -bool AutoMappingUtil::FindAttrValue(const domi::tensorflow::NodeDef *const nodeDef, const string &attr_name, - domi::tensorflow::AttrValue &attr_value) { - if (nodeDef == nullptr) { - GE_LOGE("nodeDef is nullptr."); - return false; - } - const google::protobuf::Map &attr = nodeDef->attr(); - const google::protobuf::Map::const_iterator it = attr.find(attr_name); - if (it != attr.end()) { - attr_value = it->second; - return true; - } - return false; -} - -// Get the attribute shape of tensorflow -void AutoMappingUtil::ConvertShape(const domi::tensorflow::TensorShapeProto &shape, - vector& shape_dims) { - shape_dims.clear(); - if (!shape.unknown_rank()) { - for (auto &dim : shape.dim()) { - shape_dims.push_back(dim.size()); - } - } else { - shape_dims = ge::UNKNOWN_SHAPE; - } -} - -graphStatus AutoMappingUtil::ConvertTensor(const domi::tensorflow::TensorProto &tensor, ge::GeTensorPtr &weight) { - weight = ComGraphMakeShared(); - if (weight == nullptr) { - GE_LOGE("Weight is nullptr."); - return GRAPH_FAILED; - } - const domi::tensorflow::DataType tf_data_type = tensor.dtype(); - const ge::DataType ge_data_type = domi::TensorAssign::ConvertTensorflowDataType(tf_data_type); - if (domi::TensorAssign::SetGeTensorDataType(ge_data_type, weight) != domi::SUCCESS) { - GE_LOGE("Set Ge tensor data type failed."); - return GRAPH_FAILED; - } - if (domi::TensorAssign::SetGeTensor(tensor, weight) != domi::SUCCESS) { - GE_LOGE("Set Ge tensor failed."); - return GRAPH_FAILED; - } - return GRAPH_SUCCESS; -} - -void AutoMappingUtil::ConvertTensorList(const domi::tensorflow::AttrValue_ListValue &list, - std::vector &vec) { - vec.clear(); - for (auto &tensor : list.tensor()) { - ge::GeTensorPtr ge_tensor = nullptr; - if (ConvertTensor(tensor, ge_tensor) != GRAPH_SUCCESS) { - GE_LOGE("Convert tensor failed."); - return; - } - vec.push_back(ge_tensor); - } -} - -void AutoMappingUtil::ConvertFunc(const domi::tensorflow::NameAttrList& tf_func, - ge::NamedAttrs& ge_func, const int32_t recursive_depth) { - if (recursive_depth >= kMaxFuncRecursiveDepth) { - GELOGW("The call stack has exceeded the maximum recursive depth"); - return; - } - ge_func.SetName(tf_func.name()); - auto& attrs = tf_func.attr(); - for (auto &item : attrs) { - ConvertValue(item.first, item.second, ge_func, recursive_depth + 1); - } -} - -void AutoMappingUtil::ConvertDataTypeList(const domi::tensorflow::AttrValue_ListValue &list, - std::vector &vec) { - vec.clear(); - for (auto &e : list.type()) { - vec.push_back(domi::TensorAssign::ConvertTensorflowDataType(static_cast(e))); - } -} - -void AutoMappingUtil::ConvertShapeList(const domi::tensorflow::AttrValue_ListValue &list, - std::vector> &vec) { - vec.clear(); - for (const auto &e : list.shape()) { - vector shape_dims; - ConvertShape(e, shape_dims); - vec.push_back(shape_dims); - } -} - -void AutoMappingUtil::ConvertFuncList(const domi::tensorflow::AttrValue_ListValue &list, - std::vector &vec, const int32_t recursive_depth) { - if (recursive_depth >= kMaxFuncRecursiveDepth) { - GELOGW("The call stack has exceeded the maximum recursive depth"); - return; - } - vec.clear(); - for (const auto &e : list.func()) { - ge::NamedAttrs func; - ConvertFunc(e, func, recursive_depth + 1); - vec.push_back(func); - } -} - -} // namespace domi diff --git a/register/auto_mapping_util.h b/register/auto_mapping_util.h deleted file mode 100644 index 4f07578184b83db1dd3dd8dafc558d0d42207328..0000000000000000000000000000000000000000 --- a/register/auto_mapping_util.h +++ /dev/null @@ -1,151 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef COMMON_AUTO_MAPPING_UTIL_H_ -#define COMMON_AUTO_MAPPING_UTIL_H_ - -#include -#include "external/graph/types.h" -#include "common/ge_common/debug/ge_log.h" -#include "proto/tensorflow/attr_value.pb.h" -#include "proto/tensorflow/node_def.pb.h" -#include "graph/ge_tensor.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/type_utils.h" -#include "graph/debug/ge_log.h" -#include "register/tensor_assign.h" - -namespace ge { - -class AutoMappingUtil { - public: - static bool FindAttrValue(const domi::tensorflow::NodeDef *const nodeDef, const string &attr_name, - domi::tensorflow::AttrValue &attr_value); - static void ConvertShape(const domi::tensorflow::TensorShapeProto &shape, std::vector &shape_dims); - static graphStatus ConvertTensor(const domi::tensorflow::TensorProto &tensor, ge::GeTensorPtr &weight); - static void ConvertFunc(const domi::tensorflow::NameAttrList &tf_func, ge::NamedAttrs &ge_func, - const int32_t recursive_depth = 0); - - static void ConvertDataTypeList(const domi::tensorflow::AttrValue_ListValue &list, std::vector &vec); - static void ConvertShapeList(const domi::tensorflow::AttrValue_ListValue &list, std::vector> &vec); - static void ConvertTensorList(const domi::tensorflow::AttrValue_ListValue &list, std::vector &vec); - static void ConvertFuncList(const domi::tensorflow::AttrValue_ListValue &list, std::vector &vec, - const int32_t recursive_depth = 0); - - // Get the attribute list list of tensorflow and save it to obj according to the key - template - static void ConvertList(const std::string &key, const domi::tensorflow::AttrValue &value, T &obj, - const int32_t recursive_depth = 0) { - const domi::tensorflow::AttrValue_ListValue &list = value.list(); - if (list.s_size() > 0) { - std::vector vec; - for (const auto &e : list.s()) { - vec.push_back(e); - } - (void) ge::AttrUtils::SetListStr(obj, key, vec); - } else if (list.i_size() > 0) { - std::vector vec; - for (const int64_t e : list.i()) { - vec.push_back(e); - } - (void) ge::AttrUtils::SetListInt(obj, key, vec); - } else if (list.f_size() > 0) { - std::vector vec; - for (const float32_t e : list.f()) { - vec.push_back(e); - } - (void) ge::AttrUtils::SetListFloat(obj, key, vec); - } else if (list.b_size() > 0) { - std::vector vec; - for (const bool e : list.b()) { - vec.push_back(e); - } - (void) ge::AttrUtils::SetListBool(obj, key, vec); - } else if (list.type_size() > 0) { - std::vector vec; - ConvertDataTypeList(list, vec); - (void) ge::AttrUtils::SetListDataType(obj, key, vec); - } else if (list.shape_size() > 0) { - std::vector> shape_dims_vec; - ConvertShapeList(list, shape_dims_vec); - (void) ge::AttrUtils::SetListListInt(obj, key, shape_dims_vec); - } else if (list.tensor_size() > 0) { - std::vector vec; - ConvertTensorList(list, vec); - (void) ge::AttrUtils::SetListTensor(obj, key, vec); - } else if (list.func_size() > 0) { - std::vector vec; - ConvertFuncList(list, vec, recursive_depth + 1); - (void) ge::AttrUtils::SetListNamedAttrs(obj, key, vec); - } else { - GELOGD("The list has no value, key is %s.", key.c_str()); - } - } - - // According to the property type of tensorflow, set it to the corresponding property of obj - template - static void ConvertValue(const std::string &key, const domi::tensorflow::AttrValue &value, T &obj, - const int32_t recursive_depth = 0) { - switch (value.value_case()) { - case domi::tensorflow::AttrValue::kS: - (void) ge::AttrUtils::SetStr(obj, key, value.s()); - break; - case domi::tensorflow::AttrValue::kI: - (void) ge::AttrUtils::SetInt(obj, key, static_cast(value.i())); - break; - case domi::tensorflow::AttrValue::kF: - (void) ge::AttrUtils::SetFloat(obj, key, static_cast(value.f())); - break; - case domi::tensorflow::AttrValue::kB: - (void) ge::AttrUtils::SetBool(obj, key, static_cast(value.b())); - break; - case domi::tensorflow::AttrValue::kType: { - const ge::DataType ge_data_type = - domi::TensorAssign::ConvertTensorflowDataType(static_cast(value.type())); - (void) ge::AttrUtils::SetDataType(obj, key, ge_data_type); - break; - } - case domi::tensorflow::AttrValue::kList: - ConvertList(key, value, obj, recursive_depth + 1); - break; - case domi::tensorflow::AttrValue::kShape: { - std::vector shape_dims; - ConvertShape(value.shape(), shape_dims); - (void) ge::AttrUtils::SetListInt(obj, key, shape_dims); - break; - } - case domi::tensorflow::AttrValue::kTensor: { - ge::GeTensorPtr ge_tensor = nullptr; - if (ConvertTensor(value.tensor(), ge_tensor) != GRAPH_SUCCESS) { - GE_LOGE("Convert ge tensor failed, key is %s.", key.c_str()); - return; - } - (void) ge::AttrUtils::SetTensor(obj, key, ge_tensor); - break; - } - case domi::tensorflow::AttrValue::kFunc: { - ge::NamedAttrs func; - ConvertFunc(value.func(), func, recursive_depth + 1); - (void) ge::AttrUtils::SetNamedAttrs(obj, key, func); - break; - } - case domi::tensorflow::AttrValue::kPlaceholder: - (void) ge::AttrUtils::SetStr(obj, key, value.placeholder()); - break; - case domi::tensorflow::AttrValue::VALUE_NOT_SET: - GELOGD("the attr value of %s is not set.", key.c_str()); - break; - default: - GE_LOGE("the attr value type(%d) is invalid.", static_cast(value.value_case())); - break; - } - } -}; -} // namespace ge -#endif // COMMON_AUTO_MAPPING_UTIL_H_ diff --git a/register/custom_pass_context_impl.h b/register/custom_pass_context_impl.h deleted file mode 100644 index 757b9d83d7110b9192f19978cb171541640a230f..0000000000000000000000000000000000000000 --- a/register/custom_pass_context_impl.h +++ /dev/null @@ -1,80 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_CUSTOM_PASS_CONTEXT_IMPL_H_ -#define METADEF_CXX_CUSTOM_PASS_CONTEXT_IMPL_H_ -#include "common/checker.h" -#include "graph/utils/node_adapter.h" -#include "graph/ascend_string.h" -namespace ge { -class StreamPassContextImpl { - public: - explicit StreamPassContextImpl(int64_t current_max_stream_id) : current_stream_id_(current_max_stream_id) { - } - - ~StreamPassContextImpl() = default; - - int64_t GetCurrentMaxStreamId() const { - return current_stream_id_; - } - - int64_t AllocateNextStreamId() { - return ++current_stream_id_; - } - - graphStatus SetStreamId(const GNode &node, int64_t stream_id) const { - if (stream_id < 0) { - GELOGE(PARAM_INVALID, "Failed to set unassigned stream id %ld, stream id should be positive integer.", stream_id); - return FAILED; - } - if (stream_id > current_stream_id_) { - GELOGE(PARAM_INVALID, "Failed to set unassigned stream id %ld, current_stream_id is %ld.", stream_id, - current_stream_id_); - return FAILED; - } - const auto compute_node = NodeAdapter::GNode2Node(node); - GE_ASSERT_NOTNULL(compute_node); - const auto *op_desc = compute_node->GetOpDescBarePtr(); - GE_ASSERT_NOTNULL(op_desc); - GELOGI("Set node %s stream id from %ld to %ld by custom pass", op_desc->GetNamePtr(), op_desc->GetStreamId(), - stream_id); - compute_node->GetOpDesc()->SetStreamId(stream_id); - return GRAPH_SUCCESS; - } - - static graphStatus GetStreamId(const GNode &node, int64_t &stream_id) { - const auto compute_node = NodeAdapter::GNode2Node(node); - GE_ASSERT_NOTNULL(compute_node); - GE_ASSERT_NOTNULL(compute_node->GetOpDesc()); - stream_id = compute_node->GetOpDesc()->GetStreamId(); - return GRAPH_SUCCESS; - } - private: - int64_t current_stream_id_ = 0L; -}; - -class CustomPassContextImpl { - public: - CustomPassContextImpl() = default; - ~CustomPassContextImpl() = default; - - void SetErrorMessage(const AscendString &error_message) { - error_message_ = error_message; - } - - AscendString GetErrorMessage() const { - return error_message_; - } - - private: - AscendString error_message_{""}; -}; -} // namespace ge - -#endif // METADEF_CXX_CUSTOM_PASS_CONTEXT_IMPL_H_ diff --git a/register/device_op_impl_registry.cc b/register/device_op_impl_registry.cc deleted file mode 100644 index 8a1649b776a79fa4ba57f7ae475905f4eacc41b6..0000000000000000000000000000000000000000 --- a/register/device_op_impl_registry.cc +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/device_op_impl_registry.h" -#include "register/op_def_factory.h" -namespace optiling { -class DeviceOpImplRegisterImpl {}; -DeviceOpImplRegister::DeviceOpImplRegister(const char *opType) { - ops::OpDefFactory::OpTilingSinkRegister(opType); -} -DeviceOpImplRegister::~DeviceOpImplRegister() {} -DeviceOpImplRegister::DeviceOpImplRegister(DeviceOpImplRegister &&other) noexcept { - impl_ = std::move(other.impl_); -} -DeviceOpImplRegister::DeviceOpImplRegister(const DeviceOpImplRegister &other) { - (void)other; -} -DeviceOpImplRegister &DeviceOpImplRegister::Tiling(SinkTilingFunc func) { - (void)func; - return *this; -} -} \ No newline at end of file diff --git a/register/exe_res_generation_context.cc b/register/exe_res_generation_context.cc deleted file mode 100644 index 2ca71b634cfdd044443c6e76679837bae24743a0..0000000000000000000000000000000000000000 --- a/register/exe_res_generation_context.cc +++ /dev/null @@ -1,439 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ -#include "inc/external/exe_graph/runtime/exe_res_generation_context.h" -#include "common/checker.h" -#include "graph/any_value.h" -#include "graph/node.h" -#include "graph/debug/ge_util.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/compute_graph.h" -#include "graph/utils/graph_utils.h" -#include "graph/debug/ge_attr_define.h" -#include "exe_graph/lowering/exe_res_generation_ctx_builder.h" -#include "exe_graph/lowering/bg_kernel_context_extend.h" -namespace gert { -namespace { -enum class InputType : int32_t { - kNode, - kInputNum, - kOutputNum, - kShapeStart -}; -void GeShapeToGertShape(const ge::GeShape &ge_shape, gert::Shape &gert_shape) { - gert_shape.SetDimNum(ge_shape.GetDimNum()); - for (size_t i = 0; i < ge_shape.GetDimNum(); ++i) { - GELOGD("Dim[%zu] val[%ld].", i, ge_shape.GetDim(i)); - gert_shape.SetDim(i, ge_shape.GetDim(i)); - } -} -} - -void ExeResGenerationCtxBuilder::CreateShapesInputs(const ge::Node &node, std::vector &inputs) { - auto op_desc = node.GetOpDesc(); - input_shapes_.reserve(op_desc->GetInputsSize()); - output_shapes_.reserve(op_desc->GetOutputsSize()); - for (const auto &in_data_anchor : node.GetAllInDataAnchors()) { - if ((in_data_anchor == nullptr) || (in_data_anchor->GetPeerOutAnchor() == nullptr)) { - GELOGD("In anchor is unused, get next."); - continue; - } - const auto input_desc = op_desc->GetInputDescPtr(static_cast(in_data_anchor->GetIdx())); - if (input_desc == nullptr) { - continue; - } - GELOGD("In anchor[%ld] push in shape.", in_data_anchor->GetIdx()); - StorageShape shape; - GeShapeToGertShape(input_desc->GetShape(), shape.MutableStorageShape()); - GeShapeToGertShape(input_desc->GetOriginShape(), shape.MutableOriginShape()); - input_shapes_.emplace_back(std::move(shape)); - } - for (const auto &out_data_anchor : node.GetAllOutDataAnchors()) { - if (out_data_anchor == nullptr || out_data_anchor->GetPeerInDataNodesSize() == 0) { - GELOGD("Node[%s] out anchor is null or peer in size is zero.", op_desc->GetNamePtr()); - continue; - } - const auto &output_desc = op_desc->GetOutputDescPtr(static_cast(out_data_anchor->GetIdx())); - if (output_desc == nullptr) { - continue; - } - StorageShape shape; - GeShapeToGertShape(output_desc->GetShape(), shape.MutableStorageShape()); - GeShapeToGertShape(output_desc->GetOriginShape(), shape.MutableOriginShape()); - output_shapes_.emplace_back(std::move(shape)); - } - GELOGD("Node[%s] input size[%zu], output size[%zu].", op_desc->GetNamePtr(), input_shapes_.size(), - output_shapes_.size()); - inputs.emplace_back(reinterpret_cast(input_shapes_.size())); - inputs.emplace_back(reinterpret_cast(output_shapes_.size())); - for (auto &in_shape : input_shapes_) { - inputs.emplace_back(&in_shape); - } - for (auto &out_shape : output_shapes_) { - inputs.emplace_back(&out_shape); - } - return; -} - -ExeResGenerationCtxHolderPtr ExeResGenerationCtxBuilder::CreateOpExeContext(ge::Node &node) { - std::vector inputs; - inputs.emplace_back(&node); - CreateShapesInputs(node, inputs); - auto exe_res_ctx_holder = gert::KernelRunContextBuilder().Inputs(inputs).Build(node.GetOpDesc()); - ctx_holder_ptr_ = ge::ComGraphMakeShared(std::move(exe_res_ctx_holder)); - if (ctx_holder_ptr_ == nullptr || ctx_holder_ptr_->context_ == nullptr) { - GE_LOGE("Op[%s][%s] create context holder failed.", node.GetNamePtr(), node.GetTypePtr()); - return nullptr; - } - auto op_exe_res_ctx = reinterpret_cast(ctx_holder_ptr_->context_); - if (!op_exe_res_ctx->CheckContextValid()) { - GE_LOGE("Op[%s][%s] create context is invalid.", node.GetNamePtr(), node.GetTypePtr()); - return nullptr; - } - const auto input_num = op_exe_res_ctx->GetInputValue(static_cast(InputType::kInputNum)); - GELOGD("Input num:%zu.", input_num); - for (size_t i = 0; i < input_num; ++i) { - auto shape = op_exe_res_ctx->GetInputPointer(static_cast(InputType::kShapeStart) + i); - GE_ASSERT_NOTNULL(shape); - GELOGD("shape[%lx] Dim num[%zu].", shape, shape->GetStorageShape().GetDimNum()); - } - return ctx_holder_ptr_; -} - -ExeResGenerationCtxHolderPtr ExeResGenerationCtxBuilder::CreateOpCheckContext(ge::Node &node) { - std::vector inputs; - inputs.emplace_back(&node); - CreateShapesInputs(node, inputs); - auto kernel_ctx_holder = gert::KernelRunContextBuilder().Inputs(inputs).Build(node.GetOpDesc()); - ctx_holder_ptr_ = ge::ComGraphMakeShared(std::move(kernel_ctx_holder)); - if (ctx_holder_ptr_ == nullptr || ctx_holder_ptr_->context_ == nullptr) { - GE_LOGE("Op[%s][%s] create context holder failed.", node.GetNamePtr(), node.GetTypePtr()); - return nullptr; - } - auto check_ctx_ptr_ = reinterpret_cast(ctx_holder_ptr_->context_); - if (!check_ctx_ptr_->CheckContextValid()) { - GE_LOGE("Op[%s][%s] create context is invalid.", node.GetNamePtr(), node.GetTypePtr()); - return nullptr; - } - return ctx_holder_ptr_; -} - -bool ExeResGenerationContext::CheckContextValid() const { - auto node_ptr = MutableInputPointer(0); - if (node_ptr == nullptr) { - GE_LOGE("Op exe res context node is null."); - REPORT_INNER_ERR_MSG("E29999", "Node create exe context failed with null node."); - return false; - } - return true; -} - -ExecuteMode ExeResGenerationContext::GetExecuteMode() const { - auto node_ptr = MutableInputPointer(0); - GE_ASSERT_NOTNULL(node_ptr); - const auto own_graph = node_ptr->GetOwnerComputeGraph(); - GE_ASSERT_NOTNULL(own_graph); - const auto ret = own_graph->GetGraphUnknownFlag() ? ExecuteMode::kDynamicExecute : ExecuteMode::kStaticOffloadExecute; - GELOGD("Node[%s] exe mode is %d.", node_ptr->GetNamePtr(), ret); - return ret; -} - -// GetInputConstData is inaccurate(do not judge subgraph), need change new interface form GE later. -bool ExeResGenerationContext::IsConstInput(const ge::AscendString &name) const { - auto node_ptr = MutableInputPointer(0); - GE_ASSERT_NOTNULL(node_ptr); - auto op_desc = node_ptr->GetOpDesc(); - GE_ASSERT_NOTNULL(op_desc); - auto op = ge::OpDescUtils::CreateOperatorFromNode(node_ptr->shared_from_this()); - const auto index = op_desc->GetInputIndexByName(name.GetString()); - if (index < 0) { - GE_LOGE("Op[%s][%s] get invalid index[%d] by ir name[%s].", node_ptr->GetNamePtr(), node_ptr->GetTypePtr(), index, - name.GetString()); - REPORT_INNER_ERR_MSG("E29999", "Node[%s][%s] get invalid index[%d] by ir name[%s].", node_ptr->GetNamePtr(), - node_ptr->GetTypePtr(), index, name.GetString()); - return false; - } - const bool ret = (op_desc->GetInputDesc(static_cast(index)).IsValid() == ge::GRAPH_SUCCESS) && - (ge::OpDescUtils::GetInputConstData(op, static_cast(index)) != nullptr); - GELOGD("Node[%s] input[%d] is const flag:%d.", op_desc->GetNamePtr(), index, ret); - return ret; -} - -const gert::StorageShape* ExeResGenerationContext::GetInputShape(int64_t index) const { - auto node_ptr = MutableInputPointer(0); - GE_ASSERT_NOTNULL(node_ptr); - const auto input_num = GetInputValue(static_cast(InputType::kInputNum)); - GELOGD("Node[%s] input index[%ld] with input num:%zu.", node_ptr->GetNamePtr(), index, input_num); - if (index < 0 || static_cast(index) >= input_num) { - GE_LOGE("Op[%s] input index %ld is invalid, input num is %zu.", node_ptr->GetNamePtr(), index, input_num); - REPORT_INNER_ERR_MSG("E29999", "Node[%s][%s] input index %ld is invalid, input num is %zu.", node_ptr->GetNamePtr(), - node_ptr->GetTypePtr(), index, input_num); - return nullptr; - } - return GetInputPointer(static_cast(InputType::kShapeStart) + static_cast(index)); -} - -const gert::StorageShape* ExeResGenerationContext::GetOutputShape(int64_t index) const { - auto node_ptr = MutableInputPointer(0); - GE_ASSERT_NOTNULL(node_ptr); - const auto input_num = GetInputValue(static_cast(InputType::kInputNum)); - const auto output_num = GetInputValue(static_cast(InputType::kOutputNum)); - GELOGD("Node[%s] output index[%ld] with input num:%zu, out num:%zu.", node_ptr->GetNamePtr(), index, input_num, - output_num); - if (index < 0 || static_cast(index) >= output_num) { - GE_LOGE("Op[%s] output index %ld is invalid, output num is %zu.", node_ptr->GetNamePtr(), index, output_num); - REPORT_INNER_ERR_MSG("E29999", "Node[%s][%s] output index %ld is invalid, output num is %zu.", - node_ptr->GetNamePtr(), node_ptr->GetTypePtr(), index, output_num); - return nullptr; - } - return GetInputPointer(static_cast(InputType::kShapeStart) + input_num + static_cast(index)); -} - -ge::graphStatus ExeResGenerationContext::SetAttachedStreamInfos(std::vector &stream_info_vec) const { - auto node_ptr = MutableInputPointer(0); - GE_ASSERT_NOTNULL(node_ptr); - if (stream_info_vec.empty()) { - GELOGW("Node[%s] set empty stream info vector.", node_ptr->GetNamePtr()); - return ge::GRAPH_FAILED; - } - GELOGD("Node[%s] set stream info vector size:%zu.", node_ptr->GetNamePtr(), stream_info_vec.size()); - std::vector attached_stream_info; - for (const auto &stream_info : stream_info_vec) { - if (stream_info.name.GetLength() == 0) { - REPORT_INNER_ERR_MSG("E29999", "Node[%s][%s] stream info using name is empty.", node_ptr->GetNamePtr(), - node_ptr->GetTypePtr()); - return ge::GRAPH_FAILED; - } - ge::GeAttrValue::NAMED_ATTRS attached_stream; - (void)ge::AttrUtils::SetStr(attached_stream, ge::ATTR_NAME_ATTACHED_RESOURCE_NAME, - stream_info.name.GetString()); - (void)ge::AttrUtils::SetStr(attached_stream, ge::ATTR_NAME_ATTACHED_RESOURCE_REUSE_KEY, - stream_info.reuse_key.GetString()); - // ge::ATTR_NAME_ATTACHED_STREAM_DEPEND_VALUE_LIST - (void)ge::AttrUtils::SetListInt(attached_stream, ge::ATTR_NAME_ATTACHED_RESOURCE_DEPEND_VALUE_LIST_INT, - stream_info.depend_value_input_indices); - (void)ge::AttrUtils::SetBool(attached_stream, ge::ATTR_NAME_ATTACHED_RESOURCE_REQUIRED_FLAG, stream_info.required); - GELOGD("Stream info: name[%s], reuse_key[%s], required[%d].", stream_info.name.GetString(), - stream_info.reuse_key.GetString(), stream_info.required); - attached_stream_info.emplace_back(attached_stream); - } - (void)ge::AttrUtils::SetListNamedAttrs(node_ptr->GetOpDesc(), ge::ATTR_NAME_ATTACHED_STREAM_INFO_LIST, - attached_stream_info); - return ge::GRAPH_SUCCESS; -} - -std::vector ExeResGenerationContext::GetAttachedStreamInfos() const { - auto node_ptr = MutableInputPointer(0); - GE_ASSERT_NOTNULL(node_ptr); - GELOGD("Node[%s] get stream info vector.", node_ptr->GetNamePtr()); - std::vector stream_info_attrs; - (void)ge::AttrUtils::GetListNamedAttrs(node_ptr->GetOpDesc(), ge::ATTR_NAME_ATTACHED_STREAM_INFO_LIST, - stream_info_attrs); - GELOGD("Node[%s] get stream info vector size:%zu.", node_ptr->GetNamePtr(), stream_info_attrs.size()); - if (stream_info_attrs.empty()) { - GELOGD("Node[%s] get empty stream info vector.", node_ptr->GetNamePtr()); - return {}; - } - std::vector stream_info_vec; - stream_info_vec.reserve(stream_info_attrs.size()); - std::string tmp_str; - for (auto &stream_info_attr : stream_info_attrs) { - StreamInfo stream_info; - (void)ge::AttrUtils::GetStr(stream_info_attr, ge::ATTR_NAME_ATTACHED_RESOURCE_NAME, tmp_str); - stream_info.name = tmp_str.c_str(); - (void)ge::AttrUtils::GetStr(stream_info_attr, ge::ATTR_NAME_ATTACHED_RESOURCE_REUSE_KEY, tmp_str); - stream_info.reuse_key = tmp_str.c_str(); - (void)ge::AttrUtils::GetListInt(stream_info_attr, ge::ATTR_NAME_ATTACHED_RESOURCE_DEPEND_VALUE_LIST_INT, - stream_info.depend_value_input_indices); - (void)ge::AttrUtils::GetBool(stream_info_attr, ge::ATTR_NAME_ATTACHED_RESOURCE_REQUIRED_FLAG, - stream_info.required); - (void)ge::AttrUtils::GetBool(stream_info_attr, ge::ATTR_NAME_ATTACHED_RESOURCE_IS_VALID, stream_info.is_valid); - (void)ge::AttrUtils::GetInt(stream_info_attr, ge::ATTR_NAME_ATTACHED_RESOURCE_ID, stream_info.stream_id); - GELOGD("Get stream info:name[%s], reuse_key[%s], stream_id[%ld], required[%d], is_valid[%d].", - stream_info.name.GetString(), stream_info.reuse_key.GetString(), stream_info.stream_id, - stream_info.required, stream_info.is_valid); - stream_info_vec.emplace_back(stream_info); - } - return stream_info_vec; -} - -ge::graphStatus ExeResGenerationContext::SetListStr(const std::string &attr_name, - const std::vector &list) const { - auto node_ptr = MutableInputPointer(0); - GE_ASSERT_NOTNULL(node_ptr); - GELOGD("Node[%s] set list str:%s, size:%zu.", node_ptr->GetNamePtr(), attr_name.c_str(), list.size()); - (void)ge::AttrUtils::SetListStr(node_ptr->GetOpDesc(), attr_name, list); - return ge::GRAPH_SUCCESS; -} - -bool ExeResGenerationContext::GetStrAttrVal(const char *attr_name, - ge::AscendString &val) const { - auto node_ptr = MutableInputPointer(0); - GE_ASSERT_NOTNULL(node_ptr); - std::string val_str; - auto res = ge::AttrUtils::GetStr(node_ptr->GetOpDesc(), attr_name, val_str); - ge::AscendString val_tmp(val_str.c_str()); - val = val_tmp; - return res; -} - -bool ExeResGenerationContext::GetIntAttrVal(const char *attr_name, int64_t &val) const { - auto node_ptr = MutableInputPointer(0); - GE_ASSERT_NOTNULL(node_ptr); - return ge::AttrUtils::GetInt(node_ptr->GetOpDesc(), attr_name, val); -} - -bool ExeResGenerationContext::SetStrAttrVal(const char *attr_name, const char *val) const { - auto node_ptr = MutableInputPointer(0); - GE_ASSERT_NOTNULL(node_ptr); - return ge::AttrUtils::SetStr(node_ptr->GetOpDesc(), attr_name, val); -} - -bool ExeResGenerationContext::SetIntAttrVal(const char *attr_name, const int64_t val) const { - auto node_ptr = MutableInputPointer(0); - GE_ASSERT_NOTNULL(node_ptr); - return ge::AttrUtils::SetInt(node_ptr->GetOpDesc(), attr_name, val); -} - -ge::graphStatus ExeResGenerationContext::SetSyncResInfos(std::vector &sync_info_vec) const { - auto node_ptr = MutableInputPointer(0); - GE_ASSERT_NOTNULL(node_ptr); - GELOGD("Node[%s] set sync info vector size:%zu.", node_ptr->GetNamePtr(), sync_info_vec.size()); - if (sync_info_vec.empty()) { - GELOGW("Node[%s] set empty sync info vector.", node_ptr->GetNamePtr()); - return ge::GRAPH_FAILED; - } - std::vector sync_info_attrs; - for (const auto &sync_info : sync_info_vec) { - if (sync_info.name.GetLength() == 0) { - REPORT_INNER_ERR_MSG("E29999", "Node[%s][%s] sync info using name is empty.", node_ptr->GetNamePtr(), - node_ptr->GetTypePtr()); - return ge::GRAPH_FAILED; - } - ge::GeAttrValue::NAMED_ATTRS sync_info_attr; - (void)ge::AttrUtils::SetInt(sync_info_attr, ge::ATTR_NAME_ATTACHED_RESOURCE_TYPE, - static_cast(sync_info.type)); - (void)ge::AttrUtils::SetStr(sync_info_attr, ge::ATTR_NAME_ATTACHED_RESOURCE_NAME, sync_info.name.GetString()); - (void)ge::AttrUtils::SetStr(sync_info_attr, ge::ATTR_NAME_ATTACHED_RESOURCE_REUSE_KEY, - sync_info.reuse_key.GetString()); - (void)ge::AttrUtils::SetBool(sync_info_attr, ge::ATTR_NAME_ATTACHED_RESOURCE_REQUIRED_FLAG, sync_info.required); - GELOGD("Sync info:name[%s], reuse_key[%s], type[%d], required[%d].", - sync_info.name.GetString(), sync_info.reuse_key.GetString(), sync_info.type, - sync_info.required); - sync_info_attrs.emplace_back(sync_info_attr); - } - // ge::ATTR_NAME_ATTACHED_SYNC_RES_INFO - (void)ge::AttrUtils::SetListNamedAttrs(node_ptr->GetOpDesc(), ge::ATTR_NAME_ATTACHED_SYNC_RES_INFO_LIST, - sync_info_attrs); - return ge::GRAPH_SUCCESS; -} - -std::vector ExeResGenerationContext::GetSyncResInfos() const { - auto node_ptr = MutableInputPointer(0); - GE_ASSERT_NOTNULL(node_ptr); - GELOGD("Node[%s] get sync info vector.", node_ptr->GetNamePtr()); - std::vector sync_info_attrs; - (void)ge::AttrUtils::GetListNamedAttrs(node_ptr->GetOpDesc(), ge::ATTR_NAME_ATTACHED_SYNC_RES_INFO_LIST, - sync_info_attrs); - GELOGD("Node[%s] Get sync info vector size:%zu.", node_ptr->GetNamePtr(), sync_info_attrs.size()); - if (sync_info_attrs.empty()) { - GELOGD("Node[%s] get empty sync info vector.", node_ptr->GetNamePtr()); - return {}; - } - std::vector sync_info_vec; - sync_info_vec.reserve(sync_info_attrs.size()); - std::string tmp_str; - for (auto &sync_info_attr : sync_info_attrs) { - SyncResInfo sync_info; - int64_t type = 2; - (void)ge::AttrUtils::GetInt(sync_info_attr, ge::ATTR_NAME_ATTACHED_RESOURCE_TYPE, type); - sync_info.type = static_cast(type); - (void)ge::AttrUtils::GetStr(sync_info_attr, ge::ATTR_NAME_ATTACHED_RESOURCE_NAME, tmp_str); - sync_info.name = tmp_str.c_str(); - (void)ge::AttrUtils::GetStr(sync_info_attr, ge::ATTR_NAME_ATTACHED_RESOURCE_REUSE_KEY, tmp_str); - sync_info.reuse_key = tmp_str.c_str(); - (void)ge::AttrUtils::GetBool(sync_info_attr, ge::ATTR_NAME_ATTACHED_RESOURCE_REQUIRED_FLAG, sync_info.required); - (void)ge::AttrUtils::GetBool(sync_info_attr, ge::ATTR_NAME_ATTACHED_RESOURCE_IS_VALID, sync_info.is_valid); - (void)ge::AttrUtils::GetInt(sync_info_attr, ge::ATTR_NAME_ATTACHED_RESOURCE_ID, sync_info.sync_res_id); - sync_info_vec.emplace_back(sync_info); - GELOGD("Sync info:name[%s], reuse_key[%s], type[%d], required[%d], is_valid[%d], sync_id[%ld].", - sync_info.name.GetString(), sync_info.reuse_key.GetString(), sync_info.type, - sync_info.required, sync_info.is_valid, sync_info.sync_res_id); - } - return sync_info_vec; -} - -std::vector ExeResGenerationContext::GetWorkspaceBytes() const { - auto node_ptr = MutableInputPointer(0); - GE_ASSERT_NOTNULL(node_ptr); - GE_ASSERT_NOTNULL(node_ptr->GetOpDesc()); - return node_ptr->GetOpDesc()->GetWorkspaceBytes(); -} - -void ExeResGenerationContext::SetWorkspaceBytes(const std::vector &workspace_bytes) const { - auto node_ptr = MutableInputPointer(0); - if (node_ptr != nullptr && node_ptr->GetOpDesc() != nullptr) { - return node_ptr->GetOpDesc()->SetWorkspaceBytes(workspace_bytes); - } -} - -int64_t ExeResGenerationContext::GetStreamId() const { - auto node_ptr = MutableInputPointer(0); - GE_ASSERT_NOTNULL(node_ptr); - GE_ASSERT_NOTNULL(node_ptr->GetOpDesc()); - return node_ptr->GetOpDesc()->GetStreamId(); -} - -int64_t ExeResGenerationContext::GetOpId() const { - auto node_ptr = MutableInputPointer(0); - GE_ASSERT_NOTNULL(node_ptr); - GE_ASSERT_NOTNULL(node_ptr->GetOpDesc()); - return node_ptr->GetOpDesc()->GetId(); -} - -const StorageShape* OpCheckContext::GetInputShape(int64_t index) const { - auto node_ptr = MutableInputPointer(0); - GE_ASSERT_NOTNULL(node_ptr); - const auto input_num = GetInputValue(static_cast(InputType::kInputNum)); - GELOGD("Node[%s] input index[%ld] with input num:%zu.", node_ptr->GetNamePtr(), index, input_num); - if (index < 0 || static_cast(index) >= input_num) { - GE_LOGE("Op[%s] input index %ld is invalid, input num is %zu.", node_ptr->GetNamePtr(), index, input_num); - REPORT_INNER_ERR_MSG("E29999", "Node[%s][%s] input index %ld is invalid, input num is %zu.", node_ptr->GetNamePtr(), - node_ptr->GetTypePtr(), index, input_num); - return nullptr; - } - return GetInputPointer(static_cast(InputType::kShapeStart) + static_cast(index)); -} - -const StorageShape* OpCheckContext::GetOutputShape(int64_t index) const { - auto node_ptr = MutableInputPointer(0); - GE_ASSERT_NOTNULL(node_ptr); - const auto input_num = GetInputValue(static_cast(InputType::kInputNum)); - const auto output_num = GetInputValue(static_cast(InputType::kOutputNum)); - GELOGD("Node[%s] output index[%ld] with input num:%zu, out num:%zu.", node_ptr->GetNamePtr(), index, input_num, - output_num); - if (index < 0 || static_cast(index) >= output_num) { - GE_LOGE("Op[%s] output index %ld is invalid, output num is %zu.", node_ptr->GetNamePtr(), index, output_num); - REPORT_INNER_ERR_MSG("E29999", "Node[%s][%s] output index %ld is invalid, output num is %zu.", - node_ptr->GetNamePtr(), node_ptr->GetTypePtr(), index, output_num); - return nullptr; - } - return GetInputPointer(static_cast(InputType::kShapeStart) + input_num + static_cast(index)); -} - -bool OpCheckContext::CheckContextValid() const { - auto node_ptr = MutableInputPointer(0); - if (node_ptr == nullptr) { - GE_LOGE("Op exe res context node is null."); - REPORT_INNER_ERR_MSG("E29999", "Node create exe context failed with null node."); - return false; - } - return true; -} - -} // namespace gert diff --git a/register/ffts_node_calculater_registry.cc b/register/ffts_node_calculater_registry.cc deleted file mode 100644 index c541b34039186e7c95d55196af07fffd22bba598..0000000000000000000000000000000000000000 --- a/register/ffts_node_calculater_registry.cc +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/ffts_node_calculater_registry.h" -#include "common/hyper_status.h" - -namespace gert { -FFTSNodeCalculaterRegistry &FFTSNodeCalculaterRegistry::GetInstance() { - static FFTSNodeCalculaterRegistry registry; - return registry; -} - -FFTSNodeCalculaterRegistry::NodeCalculater FFTSNodeCalculaterRegistry::FindNodeCalculater(const string &func_name) { - auto iter = names_to_calculater_.find(func_name); - if (iter == names_to_calculater_.end()) { - return nullptr; - } - return iter->second; -} - -void FFTSNodeCalculaterRegistry::Register(const string &func_name, - const FFTSNodeCalculaterRegistry::NodeCalculater func) { - names_to_calculater_[func_name] = func; -} - -FFTSNodeCalculaterRegister::FFTSNodeCalculaterRegister(const string &func_name, - FFTSNodeCalculaterRegistry::NodeCalculater func) noexcept { - FFTSNodeCalculaterRegistry::GetInstance().Register(func_name, func); -} -} // namespace gert diff --git a/register/ffts_node_converter_registry.cc b/register/ffts_node_converter_registry.cc deleted file mode 100644 index 5f8893921b1b731bf70f24c7ea0d02706f3eddc1..0000000000000000000000000000000000000000 --- a/register/ffts_node_converter_registry.cc +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/ffts_node_converter_registry.h" -#include "common/hyper_status.h" - -namespace gert { -FFTSNodeConverterRegistry &FFTSNodeConverterRegistry::GetInstance() { - static FFTSNodeConverterRegistry registry; - return registry; -} - -FFTSNodeConverterRegistry::NodeConverter FFTSNodeConverterRegistry::FindNodeConverter(const string &func_name) { - auto data = FindRegisterData(func_name); - if (data == nullptr) { - return nullptr; - } - return data->converter; -} -void FFTSNodeConverterRegistry::RegisterNodeConverter(const std::string &func_name, NodeConverter func) { - names_to_register_data_[func_name] = {func, -1}; -} -const FFTSNodeConverterRegistry::ConverterRegisterData *FFTSNodeConverterRegistry::FindRegisterData( - const string &func_name) const { - auto iter = names_to_register_data_.find(func_name); - if (iter == names_to_register_data_.end()) { - return nullptr; - } - return &iter->second; -} -void FFTSNodeConverterRegistry::Register(const string &func_name, - const FFTSNodeConverterRegistry::ConverterRegisterData &data) { - names_to_register_data_[func_name] = data; -} -FFTSNodeConverterRegister::FFTSNodeConverterRegister(const char *lower_func_name, - FFTSNodeConverterRegistry::NodeConverter func) noexcept { - FFTSNodeConverterRegistry::GetInstance().Register(lower_func_name, {func, -1}); -} -FFTSNodeConverterRegister::FFTSNodeConverterRegister(const char *lower_func_name, int32_t require_placement, - FFTSNodeConverterRegistry::NodeConverter func) noexcept { - FFTSNodeConverterRegistry::GetInstance().Register(lower_func_name, {func, require_placement}); -} -} // namespace gert diff --git a/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_base.cc b/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_base.cc deleted file mode 100644 index fbcc3f54deeb434d46f2e5e02c4301ed3cf7f7d2..0000000000000000000000000000000000000000 --- a/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_base.cc +++ /dev/null @@ -1,159 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/graph_optimizer/buffer_fusion/buffer_fusion_pass_base.h" -#include -#include -#include -#include "register/graph_optimizer/fusion_common/fusion_turbo.h" - -namespace fe { -namespace { - const std::string kAttrNameIsOpDynamicImpl = "_is_op_dynamic_impl"; - constexpr uint32_t kNoNeedCompareSize = 2; -} -BufferFusionPassBase::BufferFusionPassBase() {} - -BufferFusionPassBase::~BufferFusionPassBase() {} - -Status BufferFusionPassBase::GetFusionNodes(const BufferFusionMapping &mapping, - std::vector &fusion_nodes) { - fusion_nodes = GetMatchedNodes(mapping); - return SUCCESS; -} - -Status BufferFusionPassBase::GetMixl2FusionNodes(const BufferFusionMapping &mapping, - std::vector &fusion_nodes) { - return NOT_CHANGED; -} - -Status BufferFusionPassBase::PostFusion(const ge::NodePtr &fused_node) { - return SUCCESS; -} - -Status BufferFusionPassBase::CalcFusionOpSliceInfo(vector &fusion_nodes, OpCalcInfo &op_slice_info) { - return SUCCESS; -} - -Status BufferFusionPassBase::CheckNodeCanFusion(const BufferFusionNodeDescMap &fusion_nodes, - const ge::NodePtr &next_node) { - return SUCCESS; -} - -std::vector BufferFusionPassBase::GetMatchedNodes(const BufferFusionMapping &mapping) { - std::vector nodes; - for (const auto &item : mapping) { - for (const auto &node : item.second) { - nodes.push_back(node); - } - } - return nodes; -} - -bool BufferFusionPassBase::CheckNodeIsDynamicImpl(const ge::NodePtr &node) { - if (node == nullptr) { - return false; - } - bool is_dynamic_impl = false; - (void)ge::AttrUtils::GetBool(node->GetOpDesc(), kAttrNameIsOpDynamicImpl, is_dynamic_impl); - return is_dynamic_impl; -} - -bool BufferFusionPassBase::CheckTwoNodesImplConsistent(const ge::NodePtr &src_node, const ge::NodePtr &dst_node) { - if (src_node == nullptr || dst_node == nullptr) { - return false; - } - bool src_dynamic_impl = false; - bool dst_dynamic_impl = false; - (void)ge::AttrUtils::GetBool(src_node->GetOpDesc(), kAttrNameIsOpDynamicImpl, src_dynamic_impl); - (void)ge::AttrUtils::GetBool(dst_node->GetOpDesc(), kAttrNameIsOpDynamicImpl, dst_dynamic_impl); - return src_dynamic_impl == dst_dynamic_impl; -} - -bool BufferFusionPassBase::CheckNodesImplConsistent(const BufferFusionMapping &mapping) { - const std::vector fusion_nodes = GetMatchedNodes(mapping); - return CheckNodesImplConsistent(fusion_nodes); -} - -bool BufferFusionPassBase::CheckNodesImplConsistent(const std::vector &fusion_nodes) { - if (fusion_nodes.size() < kNoNeedCompareSize) { - return true; - } - const ge::NodePtr first_node = fusion_nodes[0]; - for (size_t index = 1; index < fusion_nodes.size(); ++index) { - if (!CheckTwoNodesImplConsistent(first_node, fusion_nodes[index])) { - return false; - } - } - return true; -} - -bool BufferFusionPassBase::CheckNodeIsDynamicShape(const ge::NodePtr& node) { - const ge::OpDescPtr op_desc = node->GetOpDesc(); - for (size_t index = 0; index < op_desc->GetAllInputsSize(); ++index) { - if (FusionTurbo::IsUnknownShape(node, static_cast(index), true)) { - return true; - } - } - - for (size_t index = 0; index < op_desc->GetAllOutputsDescSize(); ++index) { - if (FusionTurbo::IsUnknownShape(node, static_cast(index), false)) { - return true; - } - } - return false; -} - -bool BufferFusionPassBase::CheckNodesIncDynamicShape(const BufferFusionMapping &mapping) { - const std::vector fusion_nodes = GetMatchedNodes(mapping); - return CheckNodesIncDynamicShape(fusion_nodes); -} - -bool BufferFusionPassBase::CheckNodesIncDynamicShape(const std::vector &fusion_nodes) { - for (const auto &node : fusion_nodes) { - if (CheckNodeIsDynamicShape(node)) { - return true; - } - } - return false; -} - -std::vector BufferFusionPassBase::GetMatchedNodesByDescName(const std::string &desc_name, - const BufferFusionMapping &mapping) { - std::vector nodes; - for (const auto &item : mapping) { - const BufferFusionOpDesc *const op_desc = item.first; - if ((op_desc != nullptr) && (op_desc->desc_name == desc_name)) { - for (const auto &node : item.second) { - nodes.push_back(node); - } - } - } - return nodes; -} - -ge::NodePtr BufferFusionPassBase::GetMatchedHeadNode(const std::vector &matched_nodes) { - for (const auto &node : matched_nodes) { - const auto input_nodes = node->GetInDataNodes(); - bool find_flag = false; - for (const auto &in_node : input_nodes) { - // find the node from fuison sub graph - if (std::find(matched_nodes.begin(), matched_nodes.end(), in_node) != matched_nodes.end()) { - find_flag = true; - break; - } - } - if (!find_flag) { - return node; - } - } - return nullptr; -} - -} // namespace fe diff --git a/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_registry.cc b/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_registry.cc deleted file mode 100644 index c34dc101f20848ec51139e132c11c3951a277bed..0000000000000000000000000000000000000000 --- a/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_registry.cc +++ /dev/null @@ -1,123 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/graph_optimizer/buffer_fusion/buffer_fusion_pass_registry.h" -#include -#include -#include -#include -#include -#include "graph/debug/ge_log.h" - -namespace fe { -class BufferFusionPassRegistry::BufferFusionPassRegistryImpl { - public: - void RegisterPass(const BufferFusionPassType pass_type, const std::string &pass_name, - BufferFusionPassRegistry::CreateFn const create_fn, PassAttr attr) { - RegPassCompileLevel(pass_name, attr); - const std::string pass_module = IsPassAttrTypeOn(attr, PassAttrType::FE_PASS_FLAG) ? "FE" : "TBE"; - const std::lock_guard lock(mu_); - std::map>::const_iterator iter = pass_descs_.find(pass_type); - if (iter != pass_descs_.cend()) { - pass_descs_[pass_type][pass_name].attr = attr; - pass_descs_[pass_type][pass_name].create_fn = create_fn; - GELOGI("type=%d, name=%s, attr=%lu, module=%s.", pass_type, pass_name.c_str(), attr, pass_module.c_str()); - return; - } - - std::map pass_desc; - pass_desc[pass_name] = {attr, create_fn}; - pass_descs_[pass_type] = pass_desc; - GELOGI("type=%d, name=%s, attr=%lu, module=%s.", pass_type, pass_name.c_str(), attr, pass_module.c_str()); - } - - std::map GetCreateFn(const BufferFusionPassType &pass_type) { - const std::lock_guard lock(mu_); - std::map ret; - std::map>::const_iterator iter = pass_descs_.find(pass_type); - if (iter == pass_descs_.cend()) { - return ret; - } - for (const auto &ele : iter->second) { - std::ignore = ret.emplace(std::make_pair(ele.first, ele.second.create_fn)); - } - return ret; - } - - std::map GetPassDesc(const BufferFusionPassType &pass_type) { - const std::lock_guard lock(mu_); - std::map>::const_iterator iter = pass_descs_.find(pass_type); - if (iter == pass_descs_.cend()) { - std::map result; - return result; - } - return iter->second; - } - private: - std::mutex mu_; - std::map> pass_descs_; -}; - -BufferFusionPassRegistry::BufferFusionPassRegistry() { - impl_ = std::unique_ptr(new (std::nothrow) BufferFusionPassRegistryImpl); -} - -BufferFusionPassRegistry::~BufferFusionPassRegistry() {} - -BufferFusionPassRegistry &BufferFusionPassRegistry::GetInstance() { - static BufferFusionPassRegistry instance; - return instance; -} - -void BufferFusionPassRegistry::RegisterPass(const BufferFusionPassType pass_type, const std::string &pass_name, - CreateFn create_fn, PassAttr attr) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "[Check][Param]UbFusionPass[type=%d,name=%s]: failed to register the ub fusion pass", - pass_type, pass_name.c_str()); - return; - } - impl_->RegisterPass(pass_type, pass_name, create_fn, attr); -} - -std::map BufferFusionPassRegistry::GetPassDesc( - const BufferFusionPassType &pass_type) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "[Check][Param]UbFusionPass[type=%d]: failed to get pass desc", pass_type); - std::map ret; - return ret; - } - return impl_->GetPassDesc(pass_type); -} - -std::map BufferFusionPassRegistry::GetCreateFnByType( - const BufferFusionPassType &pass_type) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "[Check][Param]UbFusionPass[type=%d]: failed to create the ub fusion pass", pass_type); - return std::map{}; - } - return impl_->GetCreateFn(pass_type); -} - -BufferFusionPassRegistrar::BufferFusionPassRegistrar(const BufferFusionPassType &pass_type, - const std::string &pass_name, - BufferFusionPassBase *(*create_fun)(), - PassAttr attr) { - if ((pass_type < BUILT_IN_AI_CORE_BUFFER_FUSION_PASS) || (pass_type >= BUFFER_FUSION_PASS_TYPE_RESERVED)) { - GELOGE(ge::PARAM_INVALID, "[Check][Param:pass_type] value %d is not supported.", pass_type); - return; - } - - if (pass_name.empty()) { - GELOGE(ge::PARAM_INVALID, "[Check][Param:pass_name]Failed to register the ub fusion pass, the pass name is empty."); - return; - } - - BufferFusionPassRegistry::GetInstance().RegisterPass(pass_type, pass_name, create_fun, attr); -} -} // namespace fe diff --git a/register/graph_optimizer/buffer_fusion/buffer_fusion_pattern.cc b/register/graph_optimizer/buffer_fusion/buffer_fusion_pattern.cc deleted file mode 100644 index be86960d2f317466010243395f189cf250f52117..0000000000000000000000000000000000000000 --- a/register/graph_optimizer/buffer_fusion/buffer_fusion_pattern.cc +++ /dev/null @@ -1,397 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/graph_optimizer/buffer_fusion/buffer_fusion_pattern.h" -#include -#include -#include "graph/debug/ge_log.h" -#include "register/graph_optimizer/graph_optimize_register_error_codes.h" - -namespace fe { -using std::map; -using std::string; -using std::vector; - -const int64_t TBE_FUSION_OP_NUM_MAX = 5L; -const int64_t TBE_PATTERN_NUM_MAX = 5L; -const int64_t TBE_PATTERN_NUM_NONE = 0L; -const int64_t TBE_PATTERN_NUM_DEFAULT = 1L; -const int64_t TBE_OUTPUT_BRANCH_DEFAULT = 0L; -const int64_t TBE_OUTPUT_BRANCH_SINGLE = 1L; -const int64_t TBE_OUTPUT_BRANCH_MULTI = 2L; -const int64_t TBE_PATTERN_GROUPID_INVALID = -1L; -const int32_t TBE_OUTPUT_MAX_NUM_LIMIT = 10; - -const std::map kShapeTypeRuleToStr { - {IGNORE_SHAPE_TYPE, "IGNORE_SHAPE_TYPE"}, - {ONLY_SUPPORT_STATIC, "ONLY_SUPPORT_STATIC"}, - {ONLY_SUPPORT_DYNAMIC, "ONLY_SUPPORT_DYNAMIC"} -}; - -inline bool IsAddOverflow(const int64_t &a, const int64_t &b) { - return ((b > 0) && (a > (static_cast(INT64_MAX) - b))) || \ - ((b < 0) && (a < (static_cast(INT64_MIN) - b))); -} - -BufferFusionPattern::BufferFusionPattern(string name, int64_t op_max_count) - : name_(name), op_max_count_(op_max_count), error_count_(0), - graph_mod_type_(0) {} - -BufferFusionPattern::~BufferFusionPattern() { - for (auto op : ops_) { - if (op == nullptr) { - continue; - } - delete (op); - } -} - -bool BufferFusionPattern::IsOpDescValid(const std::string &desc_name, int64_t repeat_min, int64_t repeat_max) const { - if (desc_name.empty()) { - GELOGW("[IsOpDescValid][Check] The desc_name cannot be empty."); - return false; - } - - if (repeat_min > repeat_max) { - GELOGW("[IsOpDescValid][Check] Check desc %s failed as repeat_min > repeat_max; repeat_min=%ld, repeat_max=%ld", - desc_name.c_str(), repeat_min, repeat_max); - return false; - } - - if (GetOpDesc(desc_name) != nullptr) { - GELOGW("[IsOpDescValid][Check] Desc_name repeated. (desc_name:%s)", desc_name.c_str()); - return false; - } - return true; -} - -bool BufferFusionPattern::IsShapeRulesSizeValid(const size_t &types_size, const size_t &rules_size) const { - if (rules_size == 1 || types_size == rules_size) { - return true; - } - GELOGW("[IsShapeRulesSizeValid][Check] rule size invalid, rules_size:%zu, types_size:%zu", rules_size, types_size); - return false; -} - -/* - * @brief: add op desc info - * @param [in] desc_name: node desc name - * @param [in] types: node desc type - * @param [in] repeate_min: the min count for fusion match, - * patter match failed if real count lower than the - * value - * @param [in] repeate_max: the max count for fusion match, - * the op will be ignored if current match count equal - * with the value - * @return BufferFusionPattern: pattern object - */ -BufferFusionPattern &BufferFusionPattern::AddOpDesc(const std::string &desc_name, const std::vector &types, - const int64_t repeat_min, const int64_t repeat_max, - const int64_t group_id, const ShapeTypeRule shape_type_rule, - const bool not_pattern, const bool is_allow_series) { - std::vector shape_type_rules = {shape_type_rule}; - return AddOpDescTypeRules(desc_name, types, repeat_min, repeat_max, group_id, shape_type_rules, - not_pattern, is_allow_series); -} - -BufferFusionPattern &BufferFusionPattern::AddOpDesc(const std::string &desc_name, const std::vector &types, - const int64_t repeat_min, const int64_t repeat_max, - bool is_allow_series) { - return AddOpDescTypeRules(desc_name, types, repeat_min, repeat_max, TBE_PATTERN_GROUPID_INVALID, - {ONLY_SUPPORT_STATIC}, false, is_allow_series); -} - -BufferFusionPattern &BufferFusionPattern::AddOpDescTypeRules(const std::string &desc_name, - const std::vector &types, - const int64_t repeat_min, const int64_t repeat_max, - const int64_t group_id, - const std::vector &shape_type_rules, - const bool not_pattern, const bool is_allow_series) { - if (!IsOpDescValid(desc_name, repeat_min, repeat_max)) { - IncreaseErrorCount(); - return *this; - } - if (!IsShapeRulesSizeValid(types.size(), shape_type_rules.size())) { - IncreaseErrorCount(); - return *this; - } - - BufferFusionOpDesc *op = new (std::nothrow) BufferFusionOpDesc(); - if (op == nullptr) { - GELOGW("[AddOpDesc][Check] Failed to create a new object."); - IncreaseErrorCount(); - return *this; - } - - op->desc_name = desc_name; - op->types = types; - op->repeate_min = repeat_min; - op->repeate_max = repeat_max; - op->repeate_curr = 0; - op->group_id = group_id; - op->shape_type_rules = shape_type_rules; - op->match_status = false; - op->out_branch_type = TBE_OUTPUT_BRANCH_DEFAULT; - op->output_max_limit = TBE_OUTPUT_MAX_NUM_LIMIT; - op->ignore_input_num = false; - op->ignore_output_num = false; - op->not_pattern = not_pattern; - op->is_allow_series = is_allow_series; - if (repeat_max > repeat_min) { - for (int64_t i = repeat_min; i < repeat_max; i++) { - (void)op->multi_output_skip_status.insert(std::pair(i, SkipStatus::DISABLED)); - } - } - ops_.push_back(op); - op_map_[desc_name] = op; - - op->outputs.clear(); - return *this; -} - -/* - * @brief: set output desc info - * @param [in] desc_name: node desc name - * @param [in] output_ids: output desc - * @param [in] relation: output desc relation (1: serial, 2:parallel) - * @return BufferFusionPattern: pattern object - */ -BufferFusionPattern &BufferFusionPattern::SetOutputs(const string &desc_name, const std::vector &output_ids, - int64_t relation, bool ignore_input_num, bool ignore_output_num, - int32_t output_max_limit) { - if (desc_name.empty()) { - GELOGW("[SetOutputs][Check] Desc_name must not be empty."); - IncreaseErrorCount(); - return *this; - } - - BufferFusionOpDesc *op_desc = GetOpDesc(desc_name); - if (op_desc == nullptr) { - GELOGW("[SetOutputs][Check] Desc_name %s does not exist", desc_name.c_str()); - IncreaseErrorCount(); - return *this; - } - op_desc->output_max_limit = output_max_limit; - op_desc->ignore_input_num = ignore_input_num; - op_desc->ignore_output_num = ignore_output_num; - if (op_desc->out_branch_type == TBE_OUTPUT_BRANCH_DEFAULT) { - op_desc->out_branch_type = relation; - } - - UpdateSkipStatus(op_desc); - - // support one multi output for one op_type - for (const string &output_id : output_ids) { - BufferFusionOpDesc *output_op_desc = GetOpDesc(output_id); - if (output_op_desc == nullptr) { - GELOGW("[SetOutputs][Check] Desc_name does not exist. (desc_name:%s)", desc_name.c_str()); - if (IsAddOverflow(error_count_, 1) != SUCCESS) { - GELOGW("[SetOutputs][Check] errorCount_++ overflow. (desc_name:%s)", desc_name.c_str()); - return *this; - } - IncreaseErrorCount(); - return *this; - } - if (op_desc == output_op_desc) { - continue; - } - - op_desc->outputs.push_back(output_op_desc); - output_op_desc->inputs.push_back(op_desc); - - if (op_desc->out_branch_type != relation) { - GELOGW("[SetOutputs][Check] Setting outputs relation failed. Current value: %ld, New value: %ld.", op_desc->out_branch_type, - relation); - return *this; - } - } - return *this; -} - -/* - * @brief: get output desc info - * @param [in] op_desc: current desc - * @param [out] outputs: candidate output desc set - * @return bool: get output desc ok or not - */ -bool BufferFusionPattern::GetOutputs(BufferFusionOpDesc *op_desc, std::vector &outputs, - bool ignore_repeat) { - if (op_desc == nullptr) { - GELOGW("[GetOutputs][Check] op_desc is null."); - return false; - } - - // add curr desc can be reused while repeat_curr < repeate_max - if ((!ignore_repeat) && (op_desc->repeate_curr < op_desc->repeate_max)) { - outputs.push_back(op_desc); - } - - // check candidate desc - for (auto desc : op_desc->outputs) { - if (desc == nullptr) { - continue; - } - // add out desc - outputs.push_back(desc); - - // add sub out_descs while repeate_min == 0 - if (desc->repeate_min == 0) { - std::vector sub_output; - if (GetOutputs(desc, sub_output, true)) { - for (const auto &sub_desc : sub_output) { - outputs.push_back(sub_desc); - } - } - } - } - - return true; -} - -void BufferFusionPattern::IncreaseErrorCount() { - if (error_count_ < std::numeric_limits::max()) { - error_count_++; - return; - } - GELOGW("[IncreaseErrorCount][Check] error_count_ has overflowed."); -} -/* - * @brief: set fusion pattern head - * @param [in] head_ids: node list - * @return bool: set head desc ok or not - */ -BufferFusionPattern &BufferFusionPattern::SetHead(const std::vector &head_ids) { - if (head_ids.empty()) { - GELOGW("[SetHead][Check] The input head_ids is empty."); - IncreaseErrorCount(); - return *this; - } - for (const string &head_id : head_ids) { - BufferFusionOpDesc *head_op_desc = GetOpDesc(head_id); - if (head_op_desc == nullptr) { - GELOGW("[SetHead][Check] descName does not exist. (desc_name:%s)", head_id.c_str()); - if (IsAddOverflow(error_count_, 1) != SUCCESS) { - GELOGW("[SetHead][Check] errorCount_++ overflow. (desc_name:%s)", head_id.c_str()); - return *this; - } - IncreaseErrorCount(); - return *this; - } - // Head desc repeat number can not exceed 1 - // if must be exceed 1, it can be realized by several descs - if (head_op_desc->repeate_max > 1) { - GELOGW("[SetHead][Check] Head description named %s repeats more than once, current max repeat count is %ld", head_id.c_str(), - head_op_desc->repeate_max); - if (IsAddOverflow(error_count_, 1) != SUCCESS) { - GELOGW("[SetHead][Check] errorCount_++ overflow. (desc_name:%s)", head_id.c_str()); - return *this; - } - IncreaseErrorCount(); - return *this; - } - head_.push_back(head_op_desc); - } - - // check head desc repeat min total value, it can not excceed 1 - int64_t desc_total_min = 0; - for (const auto &desc : head_) { - if (IsAddOverflow(desc_total_min, desc->repeate_min) != SUCCESS) { - GELOGW("[SetHead][Check] desc_total_min[%ld] + repeate_min[%ld] overflow", desc_total_min, desc->repeate_min); - return *this; - } - desc_total_min += desc->repeate_min; - } - - if (desc_total_min > 1) { - GELOGW("[SetHead][Check] Head desc repeat min total value cannot exceed 1, current value is %ld", - desc_total_min); - IncreaseErrorCount(); - return *this; - } - return *this; -} - -void BufferFusionPattern::UpdateSkipStatus(const BufferFusionOpDesc *op_desc) const { - if (op_desc->out_branch_type == TBE_OUTPUT_BRANCH_MULTI) { - for (auto &input_desc : op_desc->inputs) { - if (input_desc->types.size() != op_desc->types.size()) { - continue; - } - bool is_same_type = true; - for (size_t i = 0; i < input_desc->types.size(); i++) { - if (input_desc->types[i] != op_desc->types[i]) { - is_same_type = false; - break; - } - } - if (is_same_type && (input_desc->ignore_output_num)) { - for (int64_t i = input_desc->repeate_min; i < input_desc->repeate_max; i++) { - input_desc->multi_output_skip_status[i] = SkipStatus::AVAILABLE; - } - } - } - } -} - -BufferFusionPattern &BufferFusionPattern::SetRelation(const std::string &src_desc_name, - const std::string &dst_desc_name, - const PatternRelation pattern_relation) { - if (src_desc_name.empty() || dst_desc_name.empty()) { - GELOGW("[SetRelation][Check] Source description name or destination description name is empty."); - IncreaseErrorCount(); - return *this; - } - BufferFusionOpDesc *src_op_desc = GetOpDesc(src_desc_name); - if (src_op_desc == nullptr) { - GELOGW("[SetRelation][Check] Op desc of [%s] is null.", src_desc_name.c_str()); - IncreaseErrorCount(); - return *this; - } - BufferFusionOpDesc *dst_op_desc = GetOpDesc(dst_desc_name); - if (dst_op_desc == nullptr) { - GELOGW("[SetRelation][Check] Op desc of [%s] is null.", dst_desc_name.c_str()); - IncreaseErrorCount(); - return *this; - } - src_op_desc->relations.push_back(std::make_pair(dst_op_desc, pattern_relation)); - if (pattern_relation == PatternRelation::RELATIVE_POSITION_CONSISTENT) { - dst_op_desc->relations.push_back(std::make_pair(src_op_desc, pattern_relation)); - } - return *this; -} - -/* - * @brief: get description ptr by name - * @param [in] desc_name: fusion pattern desc name - * @return BufferFusionOpDesc*: description ptr - */ -BufferFusionOpDesc *BufferFusionPattern::GetOpDesc(const string &desc_name) const { - const auto it = op_map_.find(desc_name); - if (it != op_map_.end()) { - return it->second; - } - return nullptr; -} - -const std::vector& BufferFusionPattern::GetHead() const { return head_; } - -const std::string& BufferFusionPattern::GetName() const { return name_; } - -int64_t BufferFusionPattern::GetOpMaxCount() const { return op_max_count_; } - -int64_t BufferFusionPattern::GetErrorCnt() const { return error_count_; } - -void BufferFusionPattern::SetGraphModType(int64_t graph_mod_type) { - graph_mod_type_ = graph_mod_type; -} - -int64_t BufferFusionPattern::GetGraphModType() const { return graph_mod_type_; } - -const std::vector& BufferFusionPattern::GetOpDescs() const { return ops_; } -} // namespace fe diff --git a/register/graph_optimizer/fusion_common/fusion_config_info.cc b/register/graph_optimizer/fusion_common/fusion_config_info.cc deleted file mode 100644 index cc4166cbe1eb8b325a8460c6512ebf12b3a5ac3f..0000000000000000000000000000000000000000 --- a/register/graph_optimizer/fusion_common/fusion_config_info.cc +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/graph_optimizer/fusion_common/fusion_config_info.h" -#include "mmpa/mmpa_api.h" -#include "graph/debug/ge_log.h" - -namespace fe { -FusionConfigInfo& FusionConfigInfo::Instance() { - static FusionConfigInfo fusion_config_info; - return fusion_config_info; -} - -Status FusionConfigInfo::Initialize() { - if (is_init_) { - return SUCCESS; - } - - InitEnvParam(); - is_init_ = true; - return SUCCESS; -} - -void FusionConfigInfo::InitEnvParam() { - const char *env_value = nullptr; - MM_SYS_GET_ENV(MM_ENV_ENABLE_NETWORK_ANALYSIS_DEBUG, env_value); - if (env_value != nullptr) { - std::string env_str_value = std::string(env_value); - GELOGD("The value of env[ENABLE_NETWORK_ANALYSIS_DEBUG] is [%s].", env_str_value.c_str()); - is_enable_network_analysis_ = static_cast (std::stoi(env_str_value.c_str())); - } - GELOGD("Enable network analysis is set to [%d].", is_enable_network_analysis_); -} - -Status FusionConfigInfo::Finalize() { - is_init_ = false; - is_enable_network_analysis_ = false; - return SUCCESS; -} - -bool FusionConfigInfo::IsEnableNetworkAnalysis() const { - return is_enable_network_analysis_; -} -} diff --git a/register/graph_optimizer/fusion_common/fusion_pass_desc.cc b/register/graph_optimizer/fusion_common/fusion_pass_desc.cc deleted file mode 100644 index f52fc4c574e130e99a9962907d2c3942636d1378..0000000000000000000000000000000000000000 --- a/register/graph_optimizer/fusion_common/fusion_pass_desc.cc +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/graph_optimizer/fusion_common/fusion_pass_desc.h" -#include "ge_common/debug/ge_log.h" -#include "register/optimization_option_registry.h" -namespace fe { -bool IsPassAttrTypeOn(PassAttr pass_attr, PassAttrType attr_type) { - return ((pass_attr >> static_cast(attr_type)) & PASS_BIT_MASK) == 1; -} -void RegPassCompileLevel(const std::string &pass_name, PassAttr pass_attr) { - std::vector levels; - if (IsPassAttrTypeOn(pass_attr, PassAttrType::COMPILE_O0)) { - levels.emplace_back(ge::OoLevel::kO0); - } - if (IsPassAttrTypeOn(pass_attr, PassAttrType::COMPILE_O1)) { - levels.emplace_back(ge::OoLevel::kO1); - } - if (IsPassAttrTypeOn(pass_attr, PassAttrType::COMPILE_O2)) { - levels.emplace_back(ge::OoLevel::kO2); - } - if (IsPassAttrTypeOn(pass_attr, PassAttrType::COMPILE_O3)) { - levels.emplace_back(ge::OoLevel::kO3); - } - if (levels.empty()) { - return; - } - const uint64_t level_bits = ge::OoInfoUtils::GenOptLevelBits(levels); - GELOGD("Fusion name [%s] registered with compile level %lu.", pass_name.c_str(), level_bits); - ge::OoInfo opt{pass_name, ge::OoHierarchy::kH1, level_bits}; - ge::OptionRegistry::GetInstance().Register(opt); - ge::PassOptionRegistry::GetInstance().Register(pass_name, {{ge::OoHierarchy::kH1, opt.name}}); - return; -} -} diff --git a/register/graph_optimizer/fusion_common/fusion_turbo.cc b/register/graph_optimizer/fusion_common/fusion_turbo.cc deleted file mode 100644 index 7de5e7e4ac09815e5b8a28840cf52b08202a0015..0000000000000000000000000000000000000000 --- a/register/graph_optimizer/fusion_common/fusion_turbo.cc +++ /dev/null @@ -1,1294 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/graph_optimizer/fusion_common/fusion_turbo.h" -#include "graph/operator_factory.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/node_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/debug/ge_attr_define.h" - -#include - -namespace fe { -const std::string kNetOutput = "NetOutput"; -WeightInfo::WeightInfo(const ge::GeTensorDesc &tensor_desc, void *data_p) - : data(reinterpret_cast(data_p)) { - shape = tensor_desc.GetShape(); - ori_shape = tensor_desc.GetOriginShape(); - datatype = tensor_desc.GetDataType(); - ori_datatype = tensor_desc.GetOriginDataType(); - format = tensor_desc.GetFormat(); - ori_format = tensor_desc.GetOriginFormat(); - CalcTotalDataSize(); -} - -WeightInfo::WeightInfo(const ge::NodePtr &node, const int32_t &index, void *data_p) - : data(reinterpret_cast(data_p)) { - if (node == nullptr) { - return; - } - const auto tensor = node->GetOpDesc()->MutableInputDesc(static_cast(index)); - if (tensor == nullptr) { - return; - } - shape = tensor->GetShape(); - ori_shape = tensor->GetOriginShape(); - datatype = tensor->GetDataType(); - ori_datatype = tensor->GetOriginDataType(); - format = tensor->GetFormat(); - ori_format = tensor->GetOriginFormat(); - CalcTotalDataSize(); -} - -WeightInfo::WeightInfo(const ge::GeShape &shape_p, const ge::GeShape &ori_shape_p, - const ge::DataType &datatype_p, const ge::DataType &ori_datatype_p, - const ge::Format &format_p, const ge::Format &ori_format_p, void *data_p) - : shape(shape_p), - ori_shape(ori_shape_p), - datatype(datatype_p), - ori_datatype(ori_datatype_p), - format(format_p), - ori_format(ori_format_p), - data(reinterpret_cast(data_p)) { - CalcTotalDataSize(); -} - -WeightInfo::WeightInfo(ge::GeShape &&shape_p, ge::GeShape &&ori_shape_p, - const ge::DataType &datatype_p, const ge::DataType &ori_datatype_p, - const ge::Format &format_p, const ge::Format &ori_format_p, void *data_p) - : shape(std::move(shape_p)), - ori_shape(std::move(ori_shape_p)), - datatype(datatype_p), - ori_datatype(ori_datatype_p), - format(format_p), - ori_format(ori_format_p), - data(reinterpret_cast(data_p)) { - CalcTotalDataSize(); -} - -WeightInfo::WeightInfo(const ge::GeShape &shape_p, const ge::DataType &datatype_p, - const ge::Format &format_p, void *data_p) - : shape(shape_p), - ori_shape(shape_p), - datatype(datatype_p), - ori_datatype(datatype_p), - format(format_p), - ori_format(format_p), - data(reinterpret_cast(data_p)) { - CalcTotalDataSize(); -} - -WeightInfo::WeightInfo(ge::GeShape &&shape_p, const ge::DataType &datatype_p, - const ge::Format &format_p, void *data_p) - :shape(std::move(shape_p)), - ori_shape(shape_p), - datatype(datatype_p), - ori_datatype(datatype_p), - format(format_p), - ori_format(format_p), - data(reinterpret_cast(data_p)) { - CalcTotalDataSize(); -} - -FusionTurbo::FusionTurbo(const ge::ComputeGraphPtr &graph) : graph_(graph) {} - -FusionTurbo::FusionTurbo(ge::ComputeGraph &graph) : graph_(graph.shared_from_this()) {} - -FusionTurbo::~FusionTurbo() {} - -Status FusionTurbo::BreakInput(const ge::NodePtr &node, - const vector &input_index) { - for (const auto &index : input_index) { - const auto in_anchor = node->GetInDataAnchor(index); - if (in_anchor == nullptr) { - continue; - } - - in_anchor->UnlinkAll(); - } - return SUCCESS; -} - -Status FusionTurbo::BreakOutput(const ge::NodePtr &node, - const vector &output_index) { - for (const auto &index : output_index) { - const auto out_anchor = node->GetOutDataAnchor(index); - if (out_anchor == nullptr) { - continue; - } - - out_anchor->UnlinkAll(); - } - return SUCCESS; -} - -Status FusionTurbo::BreakAllInput(const ge::NodePtr &node) { - const auto input_anchors = node->GetAllInDataAnchors(); - for (const auto &in_anchor : input_anchors) { - if (in_anchor == nullptr) { - continue; - } - - in_anchor->UnlinkAll(); - } - return SUCCESS; -} - -Status FusionTurbo::BreakAllOutput(const ge::NodePtr &node) { - const auto output_anchors = node->GetAllOutDataAnchors(); - for (const auto &out_anchor : output_anchors) { - if (out_anchor == nullptr) { - continue; - } - - out_anchor->UnlinkAll(); - } - return SUCCESS; -} - -Status FusionTurbo::RemoveNodeWithRelink(const ge::NodePtr &node, const std::initializer_list &io_map) { - return RemoveNodeWithRelink(node, std::vector(io_map)); -} - -Status FusionTurbo::RemoveNodeWithRelink(const ge::NodePtr &node, const std::vector &io_map) { - FUSION_TURBO_NOTNULL(node, PARAM_INVALID); - if (ge::GraphUtils::IsolateNode(node, io_map) != ge::GRAPH_SUCCESS) { - return FAILED; - } - - if (ge::GraphUtils::RemoveNodeWithoutRelink(graph_, node) != ge::GRAPH_SUCCESS) { - return FAILED; - } - - return SUCCESS; -} - -/* Just remove the node and all its relative data and control anchors. */ -Status FusionTurbo::RemoveNodeOnly(const ge::NodePtr &node) { - FUSION_TURBO_NOTNULL(node, PARAM_INVALID); - ge::NodeUtils::UnlinkAll(*node); - - if (ge::GraphUtils::RemoveNodeWithoutRelink(graph_, node) != ge::GRAPH_SUCCESS) { - return FAILED; - } - return SUCCESS; -} - -Status FusionTurbo::RemoveDanglingNode(const ge::NodePtr &node, const bool &only_care_data_nodes) { - FUSION_TURBO_NOTNULL(node, PARAM_INVALID); - bool able_to_remove = false; - if (only_care_data_nodes) { - if (!HasOutData(node)) { - able_to_remove = true; - } - } else { - if (!HasOutData(node) && !HasOutControl(node)) { - able_to_remove = true; - } - } - if (able_to_remove) { - return RemoveNodeOnly(node); - } - return FAILED; -} - -Status FusionTurbo::RemoveMultiNodesOnly(const std::vector &nodes) { - for (const auto &ele : nodes) { - if (RemoveNodeOnly(ele) != SUCCESS) { - return FAILED; - } - } - return SUCCESS; -} - -static ge::GeTensorPtr GenerateWeightTensor(const WeightInfo &w_info) { - ge::GeTensorDesc new_weight_tensor; - new_weight_tensor.SetShape(w_info.shape); - new_weight_tensor.SetDataType(w_info.datatype); - new_weight_tensor.SetFormat(w_info.format); - new_weight_tensor.SetOriginShape(w_info.ori_shape); - new_weight_tensor.SetOriginDataType(w_info.ori_datatype); - new_weight_tensor.SetOriginFormat(w_info.ori_format); - if (w_info.total_data_size == 0) { - return nullptr; - } - ge::GeTensorPtr w = nullptr; - GE_MAKE_SHARED(w = std::make_shared( - new_weight_tensor, - reinterpret_cast(w_info.data), w_info.total_data_size), - return nullptr); - return w; -} - -static inline ge::NodePtr GetPeerOutNode(const ge::NodePtr &node, - const int32_t index) { - const auto in_anchor = node->GetInDataAnchor(index); - FUSION_TURBO_NOTNULL(in_anchor, nullptr); - const auto peer_anchor = in_anchor->GetPeerOutAnchor(); - FUSION_TURBO_NOTNULL(peer_anchor, nullptr); - auto peer_node = peer_anchor->GetOwnerNode(); - return peer_node; -} - -static void UpdateTensor(const ge::GeTensorDescPtr &tensor, const WeightInfo &w_info) { - if (tensor == nullptr) { - return; - } - tensor->SetDataType(w_info.datatype); - tensor->SetOriginDataType(w_info.ori_datatype); - tensor->SetFormat(w_info.format); - tensor->SetOriginFormat(w_info.ori_format); - tensor->SetShape(w_info.shape); - tensor->SetOriginShape(w_info.ori_shape); -} - -ge::NodePtr FusionTurbo::AddConstNode(const ge::NodePtr &node, const WeightInfo &w_info, - const int32_t index) const { - const auto node_in_tensor = std::const_pointer_cast( - node->GetOpDesc()->GetInputDescPtrDfault(static_cast(index))); - FUSION_TURBO_NOTNULL(node_in_tensor, nullptr); - UpdateTensor(node_in_tensor, w_info); - - ge::GeTensorPtr const_out_tenosr = nullptr; - GE_MAKE_SHARED(const_out_tenosr = std::make_shared(*node_in_tensor), return nullptr); - - const Status ret = const_out_tenosr->SetData(reinterpret_cast(w_info.data), - w_info.total_data_size); - if (ret != SUCCESS) { - GELOGE(FAILED, "[FusionTurbo][AddWeight][AddConstNode] Failed to set data."); - return nullptr; - } - ge::OpDescPtr const_op_desc = ge::OpDescUtils::CreateConstOp(const_out_tenosr); - - auto const_node = graph_->AddNode(const_op_desc); - if (const_node == nullptr) { - GELOGE(FAILED, "[FusionTurbo][AddWeight][AddConstNode] Failed to add const node."); - return nullptr; - } - - GELOGD("Successfully created const input [%s] for node [%s].", const_op_desc->GetName().c_str(), - node->GetName().c_str()); - if (ge::GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), node->GetInDataAnchor(index)) != SUCCESS) { - GELOGE(FAILED, "[FusionTurbo][AddWeight][AddConstNode] Failed to add edge between const %s and index %d of %s.", - const_node->GetName().c_str(), index, node->GetName().c_str()); - } - - return const_node; -} - -ge::NodePtr FusionTurbo::UpdateConst(const ge::NodePtr &node, const int32_t &index, - const WeightInfo &w_info) const { - auto const_node = FusionTurboUtils::GetConstInput(node, index); - if (const_node == nullptr) { - return nullptr; - } - const auto const_op = const_node->GetOpDesc(); - const auto const_out_tensor_desc = const_op->MutableOutputDesc(0); - UpdateTensor(const_out_tensor_desc, w_info); - - const auto node_in_tensor = node->GetOpDesc()->MutableInputDesc(static_cast(index)); - FUSION_TURBO_NOTNULL(node_in_tensor, nullptr); - UpdateTensor(node_in_tensor, w_info); - - std::vector weights = ge::OpDescUtils::MutableWeights(const_node); - /* Substitute the const value with a new one. */ - Status ret; - if (weights.empty()) { - GELOGD("The weight for %s is missing; creating a new one.", const_node->GetName().c_str()); - ge::GeTensorPtr const_out = nullptr; - FUSION_TURBO_NOTNULL(const_out_tensor_desc, nullptr); - GE_MAKE_SHARED(const_out = std::make_shared(*const_out_tensor_desc), return nullptr); - if (w_info.data == nullptr) { - GELOGE(FAILED, "[FusionTurbo][AddWeight][UpdateConst] Data is null."); - return nullptr; - } - ret = const_out->SetData(reinterpret_cast(w_info.data), w_info.total_data_size); - } else { - GELOGD("The weight for %s is not null, updating data.", const_node->GetName().c_str()); - ge::GeTensorPtr &const_out = weights.at(0); - FUSION_TURBO_NOTNULL(const_out_tensor_desc, nullptr); - const_out->SetTensorDesc(*const_out_tensor_desc); - ret = const_out->SetData(reinterpret_cast(w_info.data), w_info.total_data_size); - } - if (ret != SUCCESS) { - GELOGE(FAILED, "[FusionTurbo][AddWeight][UpdateConst] Failed to set data."); - return nullptr; - } - - return const_node; -} - -ge::NodePtr FusionTurbo::AddWeightAfter(const ge::NodePtr &node, const int32_t &index, - const WeightInfo &w_info) const { - FUSION_TURBO_NOTNULL(node, nullptr); - const auto output_anchor = node->GetOutDataAnchor(index); - FUSION_TURBO_NOTNULL(output_anchor, nullptr); - const auto peer_in_anchors = output_anchor->GetPeerInDataAnchors(); - if (peer_in_anchors.empty()) { - GELOGD("Node %s does not have peer in anchors.", node->GetName().c_str()); - return nullptr; - } - - const auto& first_peer_in_anchor = peer_in_anchors.at(0); - const auto first_peer_in_node = first_peer_in_anchor->GetOwnerNode(); - FUSION_TURBO_NOTNULL(first_peer_in_node, nullptr); - Relations output_relation(0, {node, index, PEER}); - - output_anchor->UnlinkAll(); - /* Add weight in front of first peer input of node. */ - auto const_node = AddWeight(first_peer_in_node, first_peer_in_anchor->GetIdx(), w_info); - FUSION_TURBO_NOTNULL(const_node, nullptr); - - if (LinkOutput(output_relation, const_node) != SUCCESS) { - return nullptr; - } - return const_node; -} - -ge::NodePtr FusionTurbo::AddWeight(const ge::NodePtr &node, const int32_t &index, const WeightInfo &w_info) const { - FUSION_TURBO_NOTNULL(node, nullptr); - const size_t input_size = node->GetAllInDataAnchorsSize(); - if (static_cast(index) >= input_size) { - GELOGD("Index %d is larger than input size %zu of %s.", index, input_size, node->GetName().c_str()); - return AddWeight(node, w_info); - } else { - const auto in_anchor = node->GetInDataAnchor(index); - /* 1. If the peer node of this input index is nullptr, we add a const node - * as input and update tensor desc. */ - FUSION_TURBO_NOTNULL(in_anchor, nullptr); - if (in_anchor->GetPeerOutAnchor() == nullptr) { - auto const_node = AddConstNode(node, w_info, index); - return const_node; - } - - /* 2. If the peer node of this input index is Const, we substitute the data - * of current Const and update tensor desc. */ - return UpdateConst(node, index, w_info); - } -} - -ge::NodePtr FusionTurbo::AddWeight(const ge::NodePtr &node, const string& tensor_name, const WeightInfo &w_info) const { - FUSION_TURBO_NOTNULL(node, nullptr); - const auto index = node->GetOpDesc()->GetInputIndexByName(tensor_name); - if (index == -1) { - return nullptr; - } - return AddWeight(node, index, w_info); -} - -ge::NodePtr FusionTurbo::AddWeight(const ge::NodePtr &node, - const WeightInfo &w_info) const { - FUSION_TURBO_NOTNULL(node, nullptr); - /* 1. Collect all existing weights. */ - vector weights = ge::OpDescUtils::MutableWeights(node); - - /* 2. Create new weight and link edges. */ - ge::GeTensorPtr w = GenerateWeightTensor(w_info); - if (w == nullptr) { - GELOGE(FAILED, "[FusionTurbo][AddWeight]Failed to generate weight for node %s.", node->GetName().c_str()); - return nullptr; - } - weights.emplace_back(w); - if (ge::OpDescUtils::SetWeights(node, weights) != ge::GRAPH_SUCCESS) { - return nullptr; - } - - /* 3. Return new weight node. */ - const auto in_size = static_cast(node->GetAllInDataAnchorsSize()); - const auto i = in_size - 1; - return GetPeerOutNode(node, i); -} - -std::vector FusionTurbo::AddWeights(const ge::NodePtr &node, - const vector &w_infos) const { - std::vector ret; - FUSION_TURBO_NOTNULL(node, ret); - /* 1. Colloect all existing weights. */ - vector weights = ge::OpDescUtils::MutableWeights(node); - - /* 2. Create new weights and link edges. */ - for (auto &w_info : w_infos) { - ge::GeTensorPtr w = GenerateWeightTensor(w_info); - if (w == nullptr) { - GELOGE(FAILED, "[FusionTurbo][AddWeights]Failed to generate weight for node %s.", node->GetName().c_str()); - return ret; - } - weights.emplace_back(w); - } - if (ge::OpDescUtils::SetWeights(node, weights) != ge::GRAPH_SUCCESS) { - return ret; - } - - /* 3. Return new weight nodes. */ - const auto in_size = static_cast(node->GetAllInDataAnchorsSize()); - GELOGD("in_size %zu, w_info size %zu", in_size, w_infos.size()); - for (size_t i = in_size - w_infos.size(); i < in_size; i++) { - auto peer_node = GetPeerOutNode(node, static_cast(i)); - ret.emplace_back(peer_node); - } - return ret; -} - -ge::GeTensorPtr FusionTurbo::MutableWeight(const ge::NodePtr &node, int32_t index) { - FUSION_TURBO_NOTNULL(node, nullptr); - const auto const_node = FusionTurboUtils::GetConstInput(node, index); - if (const_node == nullptr) { - return nullptr; - } - - std::vector weights = ge::OpDescUtils::MutableWeights(const_node); - if (weights.empty()) { - return nullptr; - } - - return weights.at(0); -} - -ge::NodePtr FusionTurbo::AddNodeOnly(const string &op_name, const string &op_type) const { - return AddNodeOnly(*graph_, op_name, op_type, 0); -} - -ge::NodePtr FusionTurbo::AddNodeOnly(ge::ComputeGraph &graph, const string &op_name, const string &op_type) { - return AddNodeOnly(graph, op_name, op_type, 0); -} - -ge::NodePtr FusionTurbo::AddNodeOnly(const string &op_name, const string &op_type, size_t dynamic_num) const { - return AddNodeOnly(*graph_, op_name, op_type, dynamic_num); -} - -ge::OpDescPtr FusionTurbo::CreateOpDesc(const string &op_name, - const string &op_type, const size_t dynamic_num) { - const auto op = ge::OperatorFactory::CreateOperator(op_name.c_str(), op_type.c_str()); - if (op.IsEmpty()) { - GELOGW("Failed to create operator %s %s.", op_name.c_str(), op_type.c_str()); - return nullptr; - } - - auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(op); - if (dynamic_num != 0) { - size_t index = 0; - const auto &ir_inputs = op_desc->GetIrInputs(); - for (auto &ir_input : ir_inputs) { - if (ir_input.second == ge::kIrInputDynamic) { - (void)op_desc->AddInputDescMiddle(ir_input.first, static_cast(dynamic_num), index); - index += dynamic_num; - } else { - ++index; - } - } - - index = 0; - const auto &ir_outputs = op_desc->GetIrOutputs(); - for (auto &ir_output : ir_outputs) { - if (ir_output.second == ge::kIrOutputDynamic) { - (void)op_desc->AddOutputDescMiddle(ir_output.first, static_cast(dynamic_num), index); - index += dynamic_num; - } else { - ++index; - } - } - } - return op_desc; -} - -ge::NodePtr FusionTurbo::AddNodeOnly(ge::ComputeGraph &graph, const string &op_name, - const string &op_type, size_t dynamic_num) { - auto op_desc = CreateOpDesc(op_name, op_type, dynamic_num); - auto ret_node = graph.AddNode(op_desc); - return ret_node; -} - -ge::NodePtr FusionTurbo::InsertNodeOnly(const string &op_name, const string &op_type, - const ge::NodePtr &origin_node, - const size_t dynamic_num) const { - return InsertNodeOnly(*graph_, op_name, op_type, origin_node, dynamic_num); -} - -ge::NodePtr FusionTurbo::InsertNodeOnly(ge::ComputeGraph &graph, const string &op_name, const string &op_type, - const ge::NodePtr &origin_node, - const size_t dynamic_num) { - auto op_desc = CreateOpDesc(op_name, op_type, dynamic_num); - auto ret_node = graph.InsertNode(origin_node, op_desc); - return ret_node; -} - -ge::NodePtr FusionTurbo::InsertNodeBefore(const string &op_name, const string &op_type, - const ge::NodePtr &base_node, const int32_t &base_input_index, - const int32_t &input_index, const int32_t &output_index) const { - FUSION_TURBO_NOTNULL(base_node, nullptr); - const auto base_desc = base_node->GetOpDesc(); - const auto base_input = base_desc->MutableInputDesc(static_cast(base_input_index)); - FUSION_TURBO_NOTNULL(base_input, nullptr); - - /* 1. Create new operator, OpDesc and Node. */ - const auto op = ge::OperatorFactory::CreateOperator(op_name.c_str(), op_type.c_str()); - if (op.IsEmpty()) { - GELOGE(FAILED, "[FusionTurbo][InstNodeBefore]Cannot find this op %s in op factory.", op_type.c_str()); - return nullptr; - } - - const auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(op); - auto ret_node = graph_->AddNode(op_desc); - - const auto base_in_anchor = base_node->GetInDataAnchor(base_input_index); - FUSION_TURBO_NOTNULL(base_in_anchor, nullptr); - const auto peer_out_anchor = base_in_anchor->GetPeerOutAnchor(); - /* 2. Update Output desc using base node's successor node. */ - if (op_desc->UpdateOutputDesc(static_cast(output_index), *base_input) != ge::GRAPH_SUCCESS) { - GELOGE(FAILED, "[FusionTurbo][InstNodeBefore]Failed to update output %d of node %s", output_index, op_name.c_str()); - goto failed_process; - } - - if (peer_out_anchor != nullptr) { - const auto peer_out_index = peer_out_anchor->GetIdx(); - const auto peer_output = peer_out_anchor->GetOwnerNode()->GetOpDesc()->MutableOutputDesc( - static_cast(peer_out_index)); - FUSION_TURBO_NOTNULL(peer_output, nullptr); - /* 3. Update input desc using base node's father node. */ - if (op_desc->UpdateInputDesc(static_cast(input_index), *peer_output) != ge::GRAPH_SUCCESS) { - GELOGE(FAILED, "[FusionTurbo][InstNodeBefore]Failed to update input %d of node %s", input_index, op_name.c_str()); - goto failed_process; - } - - /* 4.1. Insert new op into graph and between peer-out and base-in anchors. */ - if (ge::GraphUtils::InsertNodeBefore(base_in_anchor, ret_node, static_cast(input_index), - static_cast(output_index)) != ge::GRAPH_SUCCESS) { - goto failed_process; - } - } else { - GELOGD("Input %d of base node %s does not have peer out node.", base_input_index, base_node->GetName().c_str()); - /* 4.2. Just insert new op before base-in anchor. */ - FUSION_TURBO_NOTNULL(ret_node, nullptr); - const auto out_anchor = ret_node->GetOutDataAnchor(output_index); - FUSION_TURBO_NOTNULL(out_anchor, nullptr); - if (ge::GraphUtils::AddEdge(out_anchor, base_in_anchor) != ge::GRAPH_SUCCESS) { - goto failed_process; - } - } - GELOGD("Succeed inserting %s before %s.", op_name.c_str(), base_node->GetName().c_str()); - return ret_node; - -failed_process: - graph_->RemoveNode(ret_node); - return nullptr; -} - -ge::NodePtr FusionTurbo::InsertNodeAfter(const string &op_name, const string &op_type, const ge::NodePtr &base_node, - const int32_t &base_output_index, const int32_t &input_index, - const int32_t &output_index) const { - FUSION_TURBO_NOTNULL(base_node, nullptr); - const auto base_desc = base_node->GetOpDesc(); - const auto base_output = base_desc->MutableOutputDesc(static_cast(base_output_index)); - FUSION_TURBO_NOTNULL(base_output, nullptr); - - const auto base_out_anchor = base_node->GetOutDataAnchor(base_output_index); - FUSION_TURBO_NOTNULL(base_out_anchor, nullptr); - auto peer_in_anchors = base_out_anchor->GetPeerInDataAnchors(); - - /* 1. Create new operator, OpDesc and Node. */ - const auto op = ge::OperatorFactory::CreateOperator(op_name.c_str(), op_type.c_str()); - if (op.IsEmpty()) { - GELOGE(FAILED, "[FusionTurbo][InstNodeAfter]Cannot find this op %s in op factory.", op_type.c_str()); - return nullptr; - } - const auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(op); - auto ret_node = graph_->AddNode(op_desc); - - /* 2. Update input desc using base_node. */ - if (op_desc->UpdateInputDesc(static_cast(input_index), *base_output) != ge::GRAPH_SUCCESS) { - GELOGE(FAILED, "[FusionTurbo][InstNodeAfter]Failed to update input %d of node %s", input_index, op_name.c_str()); - goto failed_process; - } - - if (!peer_in_anchors.empty()) { - /* 3. Update output desc by peer input. */ - const auto peer_in_anchor = peer_in_anchors.at(0); - const auto peer_in_index = peer_in_anchor->GetIdx(); - const auto peer_node = peer_in_anchor->GetOwnerNode(); - const auto peer_input = peer_node->GetOpDesc()->MutableInputDesc(static_cast(peer_in_index)); - if (op_desc->UpdateOutputDesc(static_cast(output_index), *peer_input) != ge::GRAPH_SUCCESS) { - GELOGE(FAILED, "[FusionTurbo][InstNodeAfter]Failed to update output %d of node %s", - output_index, op_name.c_str()); - goto failed_process; - } - - /* 4.1. Insert new op between base-out anchor and every peer-in anchor. */ - const auto peer_in_anchors_vec = std::vector(peer_in_anchors.begin(), peer_in_anchors.end()); - if (ge::GraphUtils::InsertNodeAfter(base_out_anchor, peer_in_anchors_vec, ret_node, - static_cast(input_index), - static_cast(output_index)) != ge::GRAPH_SUCCESS) { - GELOGE(FAILED, "[FusionTurbo][InstNodeAfter]Failed to insert node after output %d of node %s", - base_output_index, base_node->GetName().c_str()); - goto failed_process; - } - } else { - GELOGD("Output %d of base node %s does not have a peer in the nodes.", base_output_index, base_node->GetName().c_str()); - FUSION_TURBO_NOTNULL(ret_node, nullptr); - const auto in_anchor = ret_node->GetInDataAnchor(input_index); - /* 4.2. Just insert new op after base-out anchor. */ - if (ge::GraphUtils::AddEdge(base_out_anchor, in_anchor) != ge::GRAPH_SUCCESS) { - GELOGE(FAILED, "[FusionTurbo][InstNodeAfter]Failed to add edge between %d of %s and %d of %s", - base_output_index, base_node->GetName().c_str(), input_index, op_name.c_str()); - goto failed_process; - } - } - GELOGD("Succeed inserting %s after %s.", op_name.c_str(), base_node->GetName().c_str()); - return ret_node; -failed_process: - graph_->RemoveNode(ret_node); - return nullptr; -} - -/* parent_node -> child_node */ -static Status HandleTensorUpdate(const ge::NodePtr &parent_node, const ge::NodePtr &child_node, - const uint32_t parent_index, const uint32_t child_index, - const bool update_child) { - if (update_child) { - const auto parent_out_tensor_desc = parent_node->GetOpDesc()->MutableOutputDesc(parent_index); - FUSION_TURBO_NOTNULL(parent_out_tensor_desc, FAILED); - if (child_node->GetOpDesc()->UpdateInputDesc(child_index, *parent_out_tensor_desc) != ge::GRAPH_SUCCESS) { - GELOGE(FAILED, "[FusionTurbo][LinkInput]Failed to update input %u of node %s", - child_index, child_node->GetName().c_str()); - return FAILED; - } - } else { - const auto child_in_tensor_desc = child_node->GetOpDesc()->MutableInputDesc(child_index); - FUSION_TURBO_NOTNULL(child_in_tensor_desc, FAILED); - if (parent_node->GetOpDesc()->UpdateOutputDesc(parent_index, *child_in_tensor_desc) != ge::GRAPH_SUCCESS) { - GELOGE(FAILED, "[FusionTurbo][LinkOutput]Failed to update output %d of node %s", - parent_index, parent_node->GetName().c_str()); - return FAILED; - } - } - return SUCCESS; -} - -Status FusionTurbo::LinkInput(Relations &input_relations, - const ge::NodePtr &dst_node, - const TensorUptType &update_tensor) { - FUSION_TURBO_NOTNULL(dst_node, PARAM_INVALID); - const auto dst_op_desc = dst_node->GetOpDesc(); - const auto &in_relations = input_relations.GetInRelations(); - if (in_relations.empty()) { - GELOGD("dst_node %s's input relations is empty.", dst_node->GetName().c_str()); - return PARAM_INVALID; - } - - const auto dst_input_size = dst_node->GetAllInDataAnchorsSize(); - for (const auto &relation : in_relations) { - const auto dst_in_index = static_cast(relation.first); - if (dst_in_index >= dst_input_size) { - GELOGW("Dst input index %u is larger than dst node %s's input size %u.", - dst_in_index, dst_node->GetName().c_str(), dst_input_size); - continue; - } - - if (relation.second.empty()) { - continue; - } - - const auto src_node = relation.second.at(0).node; - const auto src_out_index = relation.second.at(0).index; - FUSION_TURBO_NOTNULL(src_node, PARAM_INVALID); - const auto out_anchor = src_node->GetOutDataAnchor(src_out_index); - FUSION_TURBO_NOTNULL(out_anchor, PARAM_INVALID); - /* 1. Update tensor descs. We assume the input desc of src node is correct. */ - if (update_tensor == UPDATE_THIS) { - (void)HandleTensorUpdate(src_node, dst_node, static_cast(src_out_index), dst_in_index, true); - } else if (update_tensor == UPDATE_PEER) { - (void)HandleTensorUpdate(src_node, dst_node, static_cast(src_out_index), dst_in_index, false); - } else { - // do nothing - } - /* 2. Link anchors. */ - const auto dst_in_anchor = dst_node->GetInDataAnchor(static_cast(dst_in_index)); - if (ge::GraphUtils::AddEdge(out_anchor, dst_in_anchor) != ge::GRAPH_SUCCESS) { - return FAILED; - } - GELOGD("SuccessFully link input %s %d ---> %s %d.", src_node->GetName().c_str(), src_out_index, - dst_node->GetName().c_str(), dst_in_index); - } - return SUCCESS; -} - -Status FusionTurbo::LinkOutput(Relations &output_relations, const ge::NodePtr &src_node, - const TensorUptType &update_tensor) { - FUSION_TURBO_NOTNULL(src_node, PARAM_INVALID); - const auto dst_op_desc = src_node->GetOpDesc(); - const auto &out_relations = output_relations.GetOutRelations(); - if (out_relations.empty()) { - GELOGD("src_node %s's output relations is empty.", src_node->GetName().c_str()); - return PARAM_INVALID; - } - - const auto src_op_desc = src_node->GetOpDesc(); - const auto src_output_size = src_node->GetAllOutDataAnchorsSize(); - - for (auto &relation : out_relations) { - const auto src_out_index = static_cast(relation.first); - if (src_out_index >= src_output_size) { - GELOGW("Source output index %u is larger than src node %s's output size %u.", - src_out_index, src_node->GetName().c_str(), src_output_size); - continue; - } - - if (relation.second.empty()) { - continue; - } - - for (const auto &ele: relation.second) { - const auto dst_node = ele.node; - const auto dst_index = ele.index; - if (dst_node == nullptr) { - continue; - } - - /* 1. Update tensor descs. */ - if (update_tensor == UPDATE_THIS) { - (void)HandleTensorUpdate(src_node, dst_node, src_out_index, static_cast(dst_index), false); - } else if (update_tensor == UPDATE_PEER) { - (void)HandleTensorUpdate(src_node, dst_node, src_out_index, static_cast(dst_index), true); - } else { - // do nothing - } - /* 2. Link all peer in anchors. */ - const auto in_anchor = dst_node->GetInDataAnchor(dst_index); - FUSION_TURBO_NOTNULL(in_anchor, PARAM_INVALID); - const auto peer_out = in_anchor->GetPeerOutAnchor(); - if (peer_out != nullptr) { - GELOGD("Dst node %s's input %d already has a peer output [%s].", dst_node->GetName().c_str(), dst_index, - peer_out->GetOwnerNode()->GetName().c_str()); - in_anchor->UnlinkAll(); - } - - const auto src_out_anchor = src_node->GetOutDataAnchor(static_cast(src_out_index)); - if (ge::GraphUtils::AddEdge(src_out_anchor, in_anchor) != ge::GRAPH_SUCCESS) { - return FAILED; - } - GELOGD("SuccessFully link output %s %d ---> %s %d.", src_node->GetName().c_str(), src_out_index, - dst_node->GetName().c_str(), dst_index); - } - } - return SUCCESS; -} - - -ge::NodePtr FusionTurbo::GetPeerOutNode(const ge::NodePtr &node, const int32_t &this_node_input_index) { - FUSION_TURBO_NOTNULL(node, nullptr); - const auto input_anchor = node->GetInDataAnchor(this_node_input_index); - FUSION_TURBO_NOTNULL(input_anchor, nullptr); - const auto peer_out_anchor = input_anchor->GetPeerOutAnchor(); - FUSION_TURBO_NOTNULL(peer_out_anchor, nullptr); - return peer_out_anchor->GetOwnerNode(); -} - -std::vector FusionTurbo::GetPeerInNodes(const ge::NodePtr &node, const int32_t &this_node_output_index) { - std::vector ret; - FUSION_TURBO_NOTNULL(node, ret); - const auto output_anchor = node->GetOutDataAnchor(this_node_output_index); - FUSION_TURBO_NOTNULL(output_anchor, ret); - const auto peer_in_anchors = output_anchor->GetPeerInDataAnchors(); - for (const auto& ele : peer_in_anchors) { - ret.emplace_back(ele->GetOwnerNode()); - } - - return ret; -} - -bool FusionTurbo::CheckConnected(const ge::NodePtr &node1, const ge::NodePtr &node2, const int32_t &index1) { - FUSION_TURBO_NOTNULL(node1, false); - FUSION_TURBO_NOTNULL(node2, false); - if (index1 == -1) { - const auto all_output_of_node1 = node1->GetOutDataNodes(); - for (const auto &out_node : all_output_of_node1) { - return out_node == node2; - } - } else { - auto peer_in_nodes = GetPeerInNodes(node1, index1); - return (std::find(peer_in_nodes.begin(), peer_in_nodes.end(), node2) != peer_in_nodes.end()); - } - return false; -} - -Status FusionTurbo::UpdateInputByPeer(const ge::NodePtr &node, const int32_t &index, - const ge::NodePtr &peer_node, const int32_t &peer_index) const { - FUSION_TURBO_NOTNULL(node, PARAM_INVALID); - FUSION_TURBO_NOTNULL(peer_node, PARAM_INVALID); - - const auto peer_output_desc = peer_node->GetOpDesc()->MutableOutputDesc(static_cast(peer_index)); - FUSION_TURBO_NOTNULL(peer_output_desc, PARAM_INVALID); - - const auto input_desc = node->GetOpDesc()->MutableInputDesc(static_cast(index)); - FUSION_TURBO_NOTNULL(input_desc, PARAM_INVALID); - - *input_desc = *peer_output_desc; - - return SUCCESS; -} - -Status FusionTurbo::UpdateOutputByPeer(const ge::NodePtr &node, const int32_t &index, - const ge::NodePtr &peer_node, const int32_t &peer_index) const { - FUSION_TURBO_NOTNULL(node, PARAM_INVALID); - FUSION_TURBO_NOTNULL(peer_node, PARAM_INVALID); - - const auto peer_input_desc = peer_node->GetOpDesc()->MutableInputDesc(static_cast(peer_index)); - FUSION_TURBO_NOTNULL(peer_input_desc, PARAM_INVALID); - - const auto output_desc = node->GetOpDesc()->MutableOutputDesc(static_cast(index)); - FUSION_TURBO_NOTNULL(output_desc, PARAM_INVALID); - - *output_desc = *peer_input_desc; - return SUCCESS; -} - -bool FusionTurbo::IsUnknownShape(const ge::NodePtr &node, const int32_t &index, const bool &is_input) { - ge::GeTensorDescPtr tensor; - if (is_input) { - tensor = node->GetOpDesc()->MutableInputDesc(static_cast(index)); - } else { - tensor = node->GetOpDesc()->MutableOutputDesc(static_cast(index)); - } - FUSION_TURBO_NOTNULL(tensor, false); - const auto &shape = tensor->MutableShape(); - return shape.IsUnknownShape(); -} - -bool FusionTurbo::IsUnknownOriShape(const ge::NodePtr &node, const int32_t &index, const bool &is_input) { - ge::GeTensorDescPtr tensor; - if (is_input) { - tensor = node->GetOpDesc()->MutableInputDesc(static_cast(index)); - } else { - tensor = node->GetOpDesc()->MutableOutputDesc(static_cast(index)); - } - FUSION_TURBO_NOTNULL(tensor, false); - const auto &shape = tensor->GetOriginShape(); - return shape.IsUnknownShape(); -} - -Status FusionTurbo::TransferOutCtrlEdges(const std::vector &nodes, - const ge::NodePtr &new_node) { - FUSION_TURBO_NOTNULL(new_node, FAILED); - for (const auto &node : nodes) { - if (node == nullptr) { - continue; - } - const auto peer_in_ctrl_nodes = node->GetOutControlNodes(); - if (peer_in_ctrl_nodes.empty()) { - continue; - } - - for (const auto &in_node : peer_in_ctrl_nodes) { - if (new_node == in_node) { - GELOGD("Out Ctrl: Avoid same source and destination %s.", new_node->GetName().c_str()); - continue; - } - (void)ge::GraphUtils::AddEdge(new_node->GetOutControlAnchor(), in_node->GetInControlAnchor()); - } - } - return SUCCESS; -} - -Status FusionTurbo::TransferInCtrlEdges(const std::vector &nodes, - const ge::NodePtr &new_node) { - FUSION_TURBO_NOTNULL(new_node, FAILED); - for (const auto &node : nodes) { - if (node == nullptr) { - continue; - } - const auto peer_out_ctrl_nodes = node->GetInControlNodes(); - if (peer_out_ctrl_nodes.empty()) { - continue; - } - - for (const auto &out_node : peer_out_ctrl_nodes) { - if (out_node == new_node) { - GELOGD("In Ctrl: avoid same source and destination %s.", new_node->GetName().c_str()); - continue; - } - if (ge::GraphUtils::AddEdge(out_node->GetOutControlAnchor(), new_node->GetInControlAnchor()) != - ge::GRAPH_SUCCESS) { - return FAILED; - } - } - } - return SUCCESS; -} - -ge::NodePtr FusionTurbo::MultiInOne(const string &node_name, const string &node_type, - Relations &input_relations, - Relations &output_relations, - const std::vector &old_nodes, - const bool &remove_old) { - auto node = AddNodeOnly(node_name, node_type); - if (MultiInOne(node, input_relations, output_relations, old_nodes, remove_old) != SUCCESS) { - (void)graph_->RemoveNode(node); - return nullptr; - } - return node; -} - -Status FusionTurbo::MultiInOne(const ge::NodePtr &new_node, - Relations &input_relations, - Relations &output_relations, - const std::vector &old_nodes, - const bool &remove_old) { - FUSION_TURBO_NOTNULL(new_node, FAILED); - GELOGD("Merged multiple nodes into %s.", new_node->GetName().c_str()); - /* Check params. */ - const auto &in_ori_relaitons = input_relations.GetRelations(); - if (in_ori_relaitons.size() > new_node->GetAllInDataAnchorsSize()) { - GELOGE(FAILED, "[FusionTurbo][MultiInOne][ChkInput]Input relation size %zu is larger than %s's input size %u.", - in_ori_relaitons.size(), new_node->GetName().c_str(), new_node->GetAllInDataAnchorsSize()); - return FAILED; - } - const auto &out_ori_relaitons = output_relations.GetRelations(); - if (out_ori_relaitons.size() > new_node->GetAllOutDataAnchorsSize()) { - GELOGE(FAILED, "[FusionTurbo][MultiInOne][ChkOutput]Output relation size %zu is larger than %s's output size %u.", - out_ori_relaitons.size(), new_node->GetName().c_str(), new_node->GetAllOutDataAnchorsSize()); - return FAILED; - } - - /* Link data edges. */ - if (LinkInput(input_relations, new_node, UPDATE_THIS) != SUCCESS) { - GELOGE(FAILED, "[FusionTurbo][MultiInOne][LnkIn]Failed to link input for node %s.", new_node->GetName().c_str()); - return FAILED; - } - - if (output_relations.GetOutRelations().empty()) { - GELOGD("[FusionTurbo][MultiInOne][LnkOut] Output relations is empty, skip for node %s.", - new_node->GetName().c_str()); - } else if (LinkOutput(output_relations, new_node, UPDATE_THIS) != SUCCESS) { - GELOGE(FAILED, "[FusionTurbo][MultiInOne][LnkOut]Failed to link output for node %s.", new_node->GetName().c_str()); - return FAILED; - } else { - // No return value expected. - } - - /* Link control edges. */ - if (TransferInCtrlEdges(old_nodes, new_node) != SUCCESS) { - return FAILED; - } - - if (TransferOutCtrlEdges(old_nodes, new_node) != SUCCESS) { - return FAILED; - } - - if (remove_old) { - for (auto &old_node : old_nodes) { - (void)RemoveNodeOnly(old_node); - } - } - return SUCCESS; -} - -bool FusionTurbo::HasInControl(const ge::NodePtr &node) { - FUSION_TURBO_NOTNULL(node, false); - const auto in_control_anchor = node->GetInControlAnchor(); - for (const auto &peer_out_control_anchor : in_control_anchor->GetPeerOutControlAnchors()) { - if (peer_out_control_anchor->GetOwnerNode() != nullptr) { - return true; - } - } - return false; -} - -bool FusionTurbo::HasOutControl(const ge::NodePtr &node) { - FUSION_TURBO_NOTNULL(node, false); - const auto out_control_anchor = node->GetOutControlAnchor(); - for (const auto &peer_in_control_anchor : out_control_anchor->GetPeerInControlAnchors()) { - if (peer_in_control_anchor->GetOwnerNode() != nullptr) { - return true; - } - } - return false; -} - -bool FusionTurbo::HasOutData(const ge::NodePtr &node) { - FUSION_TURBO_NOTNULL(node, false); - const auto out_data_anchors = node->GetAllOutDataAnchors(); - for (const auto &out_anchor : out_data_anchors) { - for (const auto &peer_in_data_anchor : out_anchor->GetPeerInDataAnchors()) { - if (peer_in_data_anchor->GetOwnerNode() != nullptr) { - return true; - } - } - } - return false; -} - -bool FusionTurbo::HasControl(const ge::NodePtr &node) { - return HasInControl(node) || HasOutControl(node); -} - -Status FusionTurbo::MoveDataOutputUp(const ge::NodePtr &node, int32_t index) { - const NodeIndex subgraph_node = FusionTurboUtils::GetPeerOutPair(node, index); - FUSION_TURBO_NOTNULL(subgraph_node.node, FAILED); - - uint32_t subgraph_output_size = subgraph_node.node->GetAllOutDataAnchorsSize(); - const auto node_pair_peer_out_anchor = subgraph_node.node->GetOutDataAnchor(subgraph_node.index); - ge::OutDataAnchorPtr out_link_anchor = node_pair_peer_out_anchor; - - // for multi outputs, first output move to current node output index, others need add new output anchor - for (size_t node_outanchor_index = 0; node_outanchor_index < node->GetAllOutDataAnchorsSize(); - ++node_outanchor_index) { - if (node_outanchor_index != 0) { - (void)subgraph_node.node->GetOpDesc()->AddOutputDesc(node->GetOpDesc()->GetOutputDesc( - static_cast(node_outanchor_index))); - out_link_anchor = subgraph_node.node->GetOutDataAnchor(static_cast(subgraph_output_size)); - ++subgraph_output_size; - } else { - FUSION_TURBO_NOTNULL(out_link_anchor, FAILED); - (void)subgraph_node.node->GetOpDesc()->UpdateOutputDesc(static_cast(out_link_anchor->GetIdx()), - node->GetOpDesc()->GetOutputDesc(static_cast(node_outanchor_index))); - } - const auto node_outanchor = node->GetOutDataAnchor(static_cast(node_outanchor_index)); - FUSION_TURBO_NOTNULL(node_outanchor, FAILED); - for (auto &peer_in_anchor : node_outanchor->GetPeerInDataAnchors()) { - if (peer_in_anchor->Unlink(node_outanchor) != ge::GRAPH_SUCCESS) { - return FAILED; - } - if (peer_in_anchor->LinkFrom(out_link_anchor) != ge::GRAPH_SUCCESS) { - return FAILED; - } - } - } - return SUCCESS; -} - -static ge::NodePtr AddSubGraphDataWithIndex(const ge::ComputeGraphPtr &graph, const ge::GeTensorDesc &tensor_desc, - const int32_t node_input_size) { - const std::string data_name = "Data_" + std::to_string(node_input_size); - auto data_node = FusionTurbo::AddNodeOnly(*graph, data_name, "Data"); - FUSION_TURBO_NOTNULL(data_node, nullptr); - auto op_desc = data_node->GetOpDesc(); - (void)ge::AttrUtils::SetInt(op_desc, ge::ATTR_NAME_PARENT_NODE_INDEX, node_input_size); - (void)op_desc->UpdateInputDesc(0U, tensor_desc); - (void)op_desc->UpdateOutputDesc(0U, tensor_desc); - return data_node; -} - -static ge::NodePtr FindSubgraphData(const ge::ComputeGraphPtr &graph, const int32_t index) { - ge::NodePtr pair_data_node = nullptr; - for (auto &tmp_node : graph->GetDirectNode()) { - int64_t ref_i; - if ((tmp_node->GetType() == "Data") && - (ge::AttrUtils::GetInt(tmp_node->GetOpDesc(), ge::ATTR_NAME_PARENT_NODE_INDEX, ref_i)) && - (ref_i == static_cast(index))) { - pair_data_node = tmp_node; - break; - } - } - return pair_data_node; -} - -static int32_t GetNetOutputTensorIndex(const ge::NodePtr &node, const int32_t index) { - int32_t tensor_index = 0; - for (uint32_t input_index = 0; input_index < node->GetOpDesc()->GetAllInputsSize(); ++input_index) { - int64_t parent_index = -1; - const auto input_desc = node->GetOpDesc()->MutableInputDesc(input_index); - (void)ge::AttrUtils::GetInt(input_desc, ge::ATTR_NAME_PARENT_NODE_INDEX, parent_index); - if (parent_index == index) { - tensor_index = static_cast(input_index); - break; - } - } - - return tensor_index; -} - -static Status MoveDataInputUpToSubgraph(const ge::NodePtr &node, const int32_t index, Relations &input_relations) { - const NodeIndex subgraph_node = FusionTurboUtils::GetPeerOutPair(node, index); - FUSION_TURBO_NOTNULL(subgraph_node.node, FAILED); - const auto subgraph = ge::NodeUtils::GetSubgraph(*subgraph_node.node, 0); - FUSION_TURBO_NOTNULL(subgraph, FAILED); - const auto netout_node = subgraph->FindFirstNodeMatchType(kNetOutput); - FUSION_TURBO_NOTNULL(netout_node, FAILED); - - const auto netout_tensor_index = GetNetOutputTensorIndex(netout_node, subgraph_node.index); - uint32_t subgraph_node_input_size = subgraph_node.node->GetAllInDataAnchorsSize(); - - /* for current node multi inputs, subgraph at move direction record the netout peer anchor - ** other inputs need add new data node and record input - */ - for (uint32_t node_inanchor_index = 0; node_inanchor_index < node->GetAllInDataAnchorsSize(); ++node_inanchor_index) { - const auto node_inanchor = node->GetInDataAnchor(static_cast(node_inanchor_index)); - FUSION_TURBO_NOTNULL(node_inanchor, FAILED); - const auto out_data_anchor = node_inanchor->GetPeerOutAnchor(); - if (out_data_anchor == nullptr) { - continue; - } - (void)out_data_anchor->Unlink(node_inanchor); - if (node_inanchor_index == static_cast(index)) { - (void)input_relations.Add(static_cast(node_inanchor_index), {netout_node, netout_tensor_index, PEER}); - continue; - } - if (ge::NodeUtils::AppendInputAnchor(subgraph_node.node, subgraph_node_input_size + 1) != ge::GRAPH_SUCCESS) { - return FAILED; - } - FUSION_TURBO_NOTNULL(subgraph_node.node->GetInDataAnchor(static_cast(subgraph_node_input_size)), FAILED); - if (subgraph_node.node->GetInDataAnchor(static_cast(subgraph_node_input_size))->LinkFrom(out_data_anchor) - != ge::GRAPH_SUCCESS) { - return FAILED; - } - - const ge::NodePtr data_node = AddSubGraphDataWithIndex(subgraph, node->GetOpDesc()->GetInputDesc(node_inanchor_index), - static_cast(subgraph_node_input_size)); - (void)input_relations.Add(static_cast(node_inanchor_index), {data_node, 0}); - ++subgraph_node_input_size; - } - return SUCCESS; -} - -Status FusionTurbo::GraphNodeUpMigration(const ge::NodePtr &node, - const int32_t index) { - if (HasControl(node)) { - GELOGD("[FusionTurbo][GraphNodeUpMigration] node:%s has control anchors, cannot move", node->GetName().c_str()); - return NOT_CHANGED; - } - const NodeIndex pre_node_index = FusionTurboUtils::GetPeerOutPair(node, index); - FUSION_TURBO_NOTNULL(pre_node_index.node, FAILED); - const auto subgraph = ge::NodeUtils::GetSubgraph(*pre_node_index.node, 0); - FUSION_TURBO_NOTNULL(subgraph, FAILED); - const auto netout_node = subgraph->FindFirstNodeMatchType(kNetOutput); - FUSION_TURBO_NOTNULL(netout_node, FAILED); - const auto netout_tensor_index = GetNetOutputTensorIndex(netout_node, pre_node_index.index); - - /* move all data output up to parent subgraph node, clear current op all output */ - if (MoveDataOutputUp(node, index) != SUCCESS) { - GELOGE(FAILED, "[FusionTurbo][GraphNodeUpMigration][MoveDataOutputUp] Failed to relink output for node:%s", - node->GetName().c_str()); - return FAILED; - } - - /* move all data input up to parent subgraph node, and record the input relationship for node to added in subgraph */ - Relations input_relations; - if (MoveDataInputUpToSubgraph(node, index, input_relations) != SUCCESS) { - GELOGE(FAILED, "[FusionTurbo][GraphNodeUpMigration][MoveDataInputUpToSubgraph] Failed to relink output for node:%s", - node->GetName().c_str()); - return FAILED; - } - - (void)RemoveNodeOnly(node); - - Relations output_relations(0, {netout_node, netout_tensor_index}); - (void)BreakInput(netout_node, {netout_tensor_index}); - - const auto node_in_subgraph = subgraph->AddNode(node->GetOpDesc()); - FUSION_TURBO_NOTNULL(node_in_subgraph, FAILED); - if (LinkInput(input_relations, node_in_subgraph, UPDATE_NONE) != SUCCESS) { - GELOGE(FAILED, "[FusionTurbo][GraphNodeUpMigration][LnkIn] Failed to link input for node:%s", - node_in_subgraph->GetName().c_str()); - return FAILED; - } - - if (LinkOutput(output_relations, node_in_subgraph, UPDATE_NONE) != SUCCESS) { - GELOGE(FAILED, "[FusionTurbo][GraphNodeUpMigration][LnkOut] Failed to link input for node:%s", - node_in_subgraph->GetName().c_str()); - return FAILED; - } - return SUCCESS; -} - -static Status MoveDataInputDownToSubgraph(const ge::NodePtr &node, const int32_t index, Relations &input_relations) { - const NodeIndex out_node_index = FusionTurboUtils::GetPeerInFirstPair(node, index); - FUSION_TURBO_NOTNULL(out_node_index.node, FAILED); - const auto subgraph = ge::NodeUtils::GetSubgraph(*out_node_index.node, 0U); - FUSION_TURBO_NOTNULL(subgraph, FAILED); - const ge::NodePtr pair_data_node = FindSubgraphData(subgraph, out_node_index.index); - FUSION_TURBO_NOTNULL(pair_data_node, FAILED); - - uint32_t subgraph_node_input_size = out_node_index.node->GetAllInDataAnchorsSize(); - ge::InDataAnchorPtr linkin_anchor = out_node_index.node->GetInDataAnchor(out_node_index.index); - FUSION_TURBO_NOTNULL(linkin_anchor, FAILED); - linkin_anchor->UnlinkAll(); - - // for multi inputs, first input connect to current node data in subgraph, others need create new data node - for (uint32_t node_inanchor_index = 0; node_inanchor_index < node->GetAllInDataAnchorsSize(); ++node_inanchor_index) { - const auto input_tensor_desc = node->GetOpDesc()->GetInputDesc(node_inanchor_index); - if (node_inanchor_index != 0) { - if (ge::NodeUtils::AppendInputAnchor(out_node_index.node, subgraph_node_input_size + 1) != ge::GRAPH_SUCCESS) { - return FAILED; - } - linkin_anchor = out_node_index.node->GetInDataAnchor(static_cast(subgraph_node_input_size)); - subgraph_node_input_size++; - } else { - (void)out_node_index.node->GetOpDesc()->UpdateInputDesc(static_cast(linkin_anchor->GetIdx()), - input_tensor_desc); - } - const auto node_inanchor = node->GetInDataAnchor(static_cast(node_inanchor_index)); - FUSION_TURBO_NOTNULL(node_inanchor, FAILED); - const auto peer_out_anchor = node_inanchor->GetPeerOutAnchor(); - FUSION_TURBO_NOTNULL(peer_out_anchor, FAILED); - if (peer_out_anchor->Unlink(node_inanchor) != ge::GRAPH_SUCCESS) { - return FAILED; - } - if (peer_out_anchor->LinkTo(linkin_anchor) != ge::GRAPH_SUCCESS) { - return FAILED; - } - ge::NodePtr data_node(pair_data_node); - if (node_inanchor_index != 0) { - data_node = AddSubGraphDataWithIndex(subgraph, input_tensor_desc, - static_cast(subgraph_node_input_size) - 1); - } - (void)input_relations.Add(static_cast(node_inanchor_index), {data_node, 0}); - } - return SUCCESS; -} - -Status FusionTurbo::GraphNodeDownMigration(const ge::NodePtr &node, - const int32_t index) { - if ((node->GetOutDataNodesSize() != 1) || (HasControl(node))) { - GELOGD("[FusionTurbo][GraphNodeDownMigration] Node: %s has multiple outputs or contains control nodes, cannot migrate", - node->GetName().c_str()); - return NOT_CHANGED; - } - const NodeIndex out_node_index = FusionTurboUtils::GetPeerInFirstPair(node, index); - FUSION_TURBO_NOTNULL(out_node_index.node, FAILED); - const auto subgraph = ge::NodeUtils::GetSubgraph(*out_node_index.node, 0U); - FUSION_TURBO_NOTNULL(subgraph, FAILED); - const ge::NodePtr pair_data_node = FindSubgraphData(subgraph, out_node_index.index); - FUSION_TURBO_NOTNULL(pair_data_node, FAILED); - - /* move data input down to subgraph node, and record input relationship for node to be added in subgraph */ - Relations input_relations; - if (MoveDataInputDownToSubgraph(node, index, input_relations) != SUCCESS) { - GELOGE(FAILED, "[FusionTurbo][GraphNodeDownMigration][MoveDataInputDownSubgraph] Failed to link input for node:%s", - node->GetName().c_str()); - return FAILED; - } - (void)BreakOutput(node, {index}); - - Relations output_relations(0, {pair_data_node, 0, PEER}); - (void)BreakOutput(pair_data_node, {0}); - (void)RemoveNodeOnly(node); - - const auto node_in_subgraph = subgraph->AddNode(node->GetOpDesc()); - FUSION_TURBO_NOTNULL(node_in_subgraph, FAILED); - if (LinkInput(input_relations, node_in_subgraph, UPDATE_NONE) != SUCCESS) { - GELOGE(FAILED, "[FusionTurbo][GraphNodeDownMigration][LnkIn] Failed to link input for node:%s", - node_in_subgraph->GetName().c_str()); - return FAILED; - } - if (LinkOutput(output_relations, node_in_subgraph, UPDATE_NONE) != SUCCESS) { - GELOGE(FAILED, "[FusionTurbo][GraphNodeDownMigration][LnkOut] Failed to link input for node:%s", - node_in_subgraph->GetName().c_str()); - return FAILED; - } - return SUCCESS; -} - -NodeIndex FusionTurbo::GetPeerInFirstPair(const ge::NodePtr &node, int32_t index) { - return FusionTurboUtils::GetPeerInFirstPair(node, index); -} - -NodeIndex FusionTurbo::GetPeerOutPair(const ge::NodePtr &node, int32_t index) { - return FusionTurboUtils::GetPeerOutPair(node, index); -} -} diff --git a/register/graph_optimizer/fusion_common/fusion_turbo_utils.cc b/register/graph_optimizer/fusion_common/fusion_turbo_utils.cc deleted file mode 100644 index 0cc5297ccfa9138f7fd1cd3aa31006c14b002de7..0000000000000000000000000000000000000000 --- a/register/graph_optimizer/fusion_common/fusion_turbo_utils.cc +++ /dev/null @@ -1,364 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/graph_optimizer/fusion_common/fusion_turbo_utils.h" -#include "graph/debug/ge_op_types.h" -#include "graph/anchor.h" -#include "graph/utils/node_utils.h" - -namespace fe { -const std::array(ge::DT_MAX + 1)> data_type_size = { - 4, // DT_FLOAT = 0, - 2, // DT_FLOAT16 = 1, - 1, // DT_INT8 = 2, - 4, // DT_INT32 = 3, - 1, // DT_UINT8 = 4, - 1, // DT_xxxx = 5, - 2, // DT_INT16 = 6, - 2, // DT_UINT16 = 7, - 4, // DT_UINT32 = 8, - 8, // DT_INT64 = 9, - 8, // DT_UINT64 = 10, - 8, // DT_DOUBLE = 11, - 1, // DT_BOOL = 12, - 8, // DT_STRING = 13, - 1, // DT_DUAL_SUB_INT8 = 14, - 1, // DT_DUAL_SUB_UINT8 = 15, - 8, // DT_COMPLEX64 = 16, - 16, // DT_COMPLEX128 = 17, - 1, // DT_QINT8 = 18, - 2, // DT_QINT16 = 19, - 4, // DT_QINT32 = 20, - 1, // DT_QUINT8 = 21, - 2, // DT_QUINT16 = 22, - 1, // DT_RESOURCE = 23, - 1, // DT_STRING_REF = 24, - 1, // DT_DUAL = 25, - 1, // DT_VARIANT = 26, - 2, // DT_BF16 = 27, - 1, // DT_UNDEFINED = 28, - 1, // DT_INT4 = 29, - 1, // DT_UINT1 = 30, - 1, // DT_INT2 = 31, - 4, // DT_UINT2 = 32, - 4, // DT_COMPLEX32 = 33, - 0, // DT_MAX = 34 -}; - -ge::NodePtr FusionTurboUtils::GetConstInput(const ge::NodePtr &node, int32_t index) { - ge::NodePtr ret = nullptr; - - auto in_anchor = node->GetInDataAnchor(index); - - FUSION_TURBO_NOTNULL(in_anchor, nullptr); - const auto out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr) { - return nullptr; - } - - const auto in_node = out_anchor->GetOwnerNode(); - if (in_node->GetType() == ge::CONSTANT) { - ret = in_node; - } else if (in_node->GetType() == ge::DATA) { - const auto parent = ge::NodeUtils::GetParentInput(in_node); - if ((parent != nullptr) && (parent->GetType() == ge::CONSTANT)) { - ret = parent; - } - } else { - // do nothing - } - - return ret; -} - -NodeIndex FusionTurboUtils::GetPeerOutPair(const ge::NodePtr &node, int32_t index) { - NodeIndex ret; - FUSION_TURBO_NOTNULL(node, ret); - if (static_cast(index) >= node->GetAllInDataAnchorsSize()) { - return ret; - } - auto input_anchor = node->GetInDataAnchor(index); - if (input_anchor == nullptr) { - return ret; - } - auto peer_anchor = input_anchor->GetPeerOutAnchor(); - if (peer_anchor == nullptr) { - return ret; - } - - auto peer_anchor_index = peer_anchor->GetIdx(); - auto actual_node = peer_anchor->GetOwnerNode(); - ret.node = actual_node; - ret.index = peer_anchor_index; - return ret; -} - -void Relations::AppendPeerInAllPairs(ThisIndex relation_index, const ge::NodePtr &node, int32_t index) { - if (static_cast(index) >= node->GetAllOutDataAnchorsSize()) { - return; - } - auto output_anchor = node->GetOutDataAnchor(index); - if (output_anchor == nullptr) { - return; - } - - auto peer_anchors = output_anchor->GetPeerInDataAnchors(); - if (peer_anchors.empty()) { - return; - } - - for (const auto &ele : peer_anchors) { - NodeIndex temp(ele->GetOwnerNode(), ele->GetIdx()); - out_relations[relation_index].emplace_back(temp); - } -} - -NodeIndex FusionTurboUtils::GetPeerInFirstPair(const ge::NodePtr &node, int32_t index) { - NodeIndex ret; - if (static_cast(index) >= node->GetAllOutDataAnchorsSize()) { - return ret; - } - - auto output_anchor = node->GetOutDataAnchor(index); - if (output_anchor == nullptr) { - return ret; - } - auto peer_anchors = output_anchor->GetPeerInDataAnchors(); - if (peer_anchors.empty()) { - return ret; - } - - ret.index = peer_anchors.at(0)->GetIdx(); - ret.node = peer_anchors.at(0)->GetOwnerNode(); - return ret; -} - -void Relations::PreProcessNodeIndices(ThisIndex index, const NodeIndices &node_indices) { - for (const auto &node_index : node_indices) { - PreProcessOneNodeIndex(index, node_index); - } -} - -void Relations::PreProcessOneNodeIndex(ThisIndex index, const NodeIndex &node_index) { - if (node_index.node == nullptr) { - return; - } - - if (node_index.direction != PEER && node_index.direction != PEER_SINGLE) { - in_relations[index].emplace_back(node_index); - out_relations[index].emplace_back(node_index); - } else { - /* Update input's peer nodes */ - auto peer_out = FusionTurboUtils::GetPeerOutPair(node_index.node, node_index.index); - if (peer_out.node != nullptr) { - in_relations[index].emplace_back(peer_out); - } else { - GELOGD("Peer input for %s %u is nullptr", node_index.node->GetName().c_str(), node_index.index); - } - - /* Update output's peer nodes */ - if (node_index.direction == PEER) { - AppendPeerInAllPairs(index, node_index.node, node_index.index); - } else if (node_index.direction == PEER_SINGLE) { - auto peer_in = FusionTurboUtils::GetPeerInFirstPair(node_index.node, node_index.index); - if (peer_in.node == nullptr) { - GELOGD("Peer output for %s %u is nullptr", node_index.node->GetName().c_str(), node_index.index); - } else { - out_relations[index].emplace_back(peer_in); - } - } - } -} - -void Relations::PreProcess() { - in_relations.clear(); - out_relations.clear(); - for (auto &relation : ori_relations) { - for (auto &pair : relation.second) { - PreProcessOneNodeIndex(relation.first, pair); - } - } -} - -Relations::Relations() {} - -Relations::Relations(const std::initializer_list &peer_indices) { - int32_t index = 0; - for (const auto &node_index : peer_indices) { - NodeIndices temp = {node_index}; - std::ignore = ori_relations.emplace(std::make_pair(index, temp)); - } - PreProcess(); -} - -Relations::Relations(const std::map &relations_param) { - ori_relations = relations_param; - PreProcess(); -} - -Relations::Relations(std::map &&relations_param) { - ori_relations = std::move(relations_param); - PreProcess(); -} - -Relations::Relations(const Relations &relations_param) { - ori_relations = relations_param.ori_relations; - in_relations = relations_param.in_relations; - out_relations = relations_param.out_relations; -} - -Relations::Relations(Relations &&relations_param) noexcept { - ori_relations = std::move(relations_param.ori_relations); - in_relations = std::move(relations_param.in_relations); - out_relations = std::move(relations_param.out_relations); -} - -Relations& Relations::operator=(const Relations &relations_param) { - ori_relations = relations_param.ori_relations; - in_relations = relations_param.in_relations; - out_relations = relations_param.out_relations; - return *this; -} - -Relations& Relations::operator=(Relations &&relations_param) noexcept { - ori_relations = std::move(relations_param.ori_relations); - in_relations = std::move(relations_param.in_relations); - out_relations = std::move(relations_param.out_relations); - return *this; -} - -Relations::Relations(ThisIndex this_index, const NodeIndex &peer_index) { - std::ignore = Add(this_index, peer_index); -} - -Relations::Relations(ThisIndex this_index, const NodeIndices &peer_indices) { - std::ignore = Add(this_index, peer_indices); -} - -Relations::Relations(ThisIndex this_index, NodeIndex &&peer_index) { - std::ignore = Add(this_index, std::move(peer_index)); -} - -Relations::Relations(ThisIndex this_index, NodeIndices &&peer_indices) { - std::ignore = Add(this_index, std::move(peer_indices)); -} - -Relations::Relations( - const std::initializer_list> &peer_indices) { - for (const auto &index_pair: peer_indices) { - std::ignore = Add(index_pair.first, index_pair.second); - } -} - -Relations::Relations( - const std::initializer_list>> &peer_indices_vec) { - for (const auto &peer_indices: peer_indices_vec) { - std::ignore = Add(peer_indices.first, peer_indices.second); - } -} - -Relations& Relations::Add(ThisIndex this_index, const NodeIndex &peer_index) { - const auto iter = ori_relations.find(this_index); - if (iter == ori_relations.end()) { - NodeIndices temp = {peer_index}; - std::ignore = ori_relations.emplace(std::make_pair(this_index, temp)); - } else { - iter->second.emplace_back(peer_index); - } - PreProcessOneNodeIndex(this_index, peer_index); - return *this; -} - -Relations& Relations::Add(ThisIndex this_index, NodeIndex &&peer_index) { - PreProcessOneNodeIndex(this_index, peer_index); - const auto iter = ori_relations.find(this_index); - if (iter == ori_relations.end()) { - NodeIndices temp = {std::move(peer_index)}; - std::ignore = ori_relations.emplace(std::make_pair(this_index, std::move(temp))); - } else { - iter->second.emplace_back(std::move(peer_index)); - } - return *this; -} - -Relations& Relations::Add(ThisIndex this_index, const std::initializer_list &peer_indices) { - const auto iter = ori_relations.find(this_index); - if (iter == ori_relations.end()) { - std::ignore = ori_relations.emplace(std::make_pair(this_index, peer_indices)); - } else { - for (const auto &peer_index : peer_indices) { - iter->second.emplace_back(peer_index); - } - } - PreProcessNodeIndices(this_index, peer_indices); - return *this; -} - -Relations& Relations::Add(ThisIndex this_index, const NodeIndices &peer_indices) { - const auto iter = ori_relations.find(this_index); - if (iter == ori_relations.end()) { - std::ignore = ori_relations.emplace(std::make_pair(this_index, peer_indices)); - } else { - for (const auto &peer_index : peer_indices) { - iter->second.emplace_back(peer_index); - } - } - PreProcessNodeIndices(this_index, peer_indices); - return *this; -} - -Relations& Relations::Add(ThisIndex this_index, NodeIndices &&peer_indices) { - PreProcessNodeIndices(this_index, peer_indices); - const auto iter = ori_relations.find(this_index); - if (iter == ori_relations.end()) { - std::ignore = ori_relations.emplace(std::make_pair(this_index, std::move(peer_indices))); - } else { - for (auto &&peer_index : peer_indices) { - iter->second.emplace_back(std::move(peer_index)); - } - } - return *this; -} - -Relations& Relations::UpdatePeerIndex(ThisIndex this_index, NodeIndices &&peer_indices) { - ori_relations[this_index] = std::move(peer_indices); - PreProcess(); - return *this; -} - -Relations& Relations::UpdatePeerIndex(ThisIndex this_index, const NodeIndices &peer_indices) { - ori_relations[this_index] = peer_indices; - PreProcess(); - return *this; -} - -Relations& Relations::UpdatePeerIndex(const std::map &peer_indices) { - ori_relations = peer_indices; - PreProcess(); - return *this; -} - -Relations& Relations::UpdatePeerIndex(std::map &&peer_indices) { - ori_relations = std::move(peer_indices); - PreProcess(); - return *this; -} - -const std::map& Relations::GetRelations() { - return ori_relations; -} - -const std::map& Relations::GetInRelations() { - return in_relations; -} - -const std::map& Relations::GetOutRelations() { - return out_relations; -} -} diff --git a/register/graph_optimizer/fusion_common/op_slice_info.cc b/register/graph_optimizer/fusion_common/op_slice_info.cc deleted file mode 100644 index ca8c173ba141d8c4e7a0467a151df57299b0901c..0000000000000000000000000000000000000000 --- a/register/graph_optimizer/fusion_common/op_slice_info.cc +++ /dev/null @@ -1,715 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/graph_optimizer/fusion_common/op_slice_info.h" -#include -#include "graph/debug/ge_log.h" - -namespace fe { -#define FE_MAKE_SHARED(exec_expr0, exec_expr1) \ - do { \ - try { \ - exec_expr0; \ - } catch (...) { \ - GELOGW("Make shared failed"); \ - exec_expr1; \ - } \ - } while (0) - -class InputSplitInfoImpl { -public: - size_t GetIndex() const { return idx_; } - std::vector GetAxis() { return axis_; } - std::vector GetHeadOverLap() { return head_over_lap_; } - std::vector GetTailOverLap() { return tail_over_lap_; } - void SetIndex(const size_t& idx) { idx_ = idx; } - void SetAxis(std::vector& axis) { axis_ = axis; } - void SetHeadOverLap(std::vector& head_over_lap) { head_over_lap_ = head_over_lap; } - void SetTailOverLap(std::vector& tail_over_lap) { tail_over_lap_ = tail_over_lap; } - -private: - size_t idx_ = 0; - std::vector axis_; - std::vector head_over_lap_; - std::vector tail_over_lap_; -}; - -InputSplitInfo::InputSplitInfo() {} - -InputSplitInfo::~InputSplitInfo() {} - -InputSplitInfo::InputSplitInfo(const InputSplitInfo &input_split_info) { - this->split_impl_ = input_split_info.split_impl_; -} - -InputSplitInfo &InputSplitInfo::operator = (const InputSplitInfo &input_split_info) { - this->split_impl_ = input_split_info.split_impl_; - return *this; -} - -bool InputSplitInfo::IsPtrNull() const { - if (split_impl_ == nullptr) { - return true; - } - - return false; -} - -bool InputSplitInfo::Initialize() { - FE_MAKE_SHARED(split_impl_ = std::make_shared(), return false); - if (split_impl_== nullptr) { - return false; - } - return true; -} - -size_t InputSplitInfo::GetIndex() const { - return split_impl_->GetIndex(); -} - -std::vector InputSplitInfo::GetAxis() const { - return split_impl_->GetAxis(); -} - -std::vector InputSplitInfo::GetHeadOverLap() const { - return split_impl_->GetHeadOverLap(); -} - -std::vector InputSplitInfo::GetTailOverLap() const { - return split_impl_->GetTailOverLap(); -} - -void InputSplitInfo::SetIndex(const size_t& idx) { - split_impl_->SetIndex(idx); -} - -void InputSplitInfo::SetAxis(std::vector& axis) { - split_impl_->SetAxis(axis); -} - -void InputSplitInfo::SetHeadOverLap(std::vector& head_over_lap) { - split_impl_->SetHeadOverLap(head_over_lap); -} - -void InputSplitInfo::SetTailOverLap(std::vector& tail_over_lap) { - split_impl_->SetTailOverLap(tail_over_lap); -} - -class OutputSplitInfoImpl { -public: - size_t GetIndex() const { return idx_; } - std::vector GetAxis() { return axis_; } - void SetIndex(const size_t& idx) { idx_ = idx; } - void SetAxis(std::vector& axis) { axis_ = axis; } - -private: - size_t idx_ = 0; - std::vector axis_; -}; - -OutputSplitInfo::OutputSplitInfo() {} - -OutputSplitInfo::~OutputSplitInfo() {} - -OutputSplitInfo::OutputSplitInfo(const OutputSplitInfo &output_split_info) { - this->split_impl_ = output_split_info.split_impl_; -} - -OutputSplitInfo &OutputSplitInfo::operator = (const OutputSplitInfo &output_split_info) { - this->split_impl_ = output_split_info.split_impl_; - return *this; -} - -bool OutputSplitInfo::IsPtrNull() const { - if (split_impl_ == nullptr) { - return true; - } - return false; -} - -bool OutputSplitInfo::Initialize() { - FE_MAKE_SHARED(split_impl_ = std::make_shared(), return false); - if (split_impl_== nullptr) { - return false; - } - return true; -} - -size_t OutputSplitInfo::GetIndex() const { - return split_impl_->GetIndex(); -} - -std::vector OutputSplitInfo::GetAxis() const { - return split_impl_->GetAxis(); -} - -void OutputSplitInfo::SetIndex(const size_t& idx) { - split_impl_->SetIndex(idx); -} - -void OutputSplitInfo::SetAxis(std::vector& axis) { - split_impl_->SetAxis(axis); -} - -class InputReduceInfoImpl { -public: - size_t GetIndex() const { return idx_; } - std::vector GetAxis() { return axis_; } - void SetIndex(const size_t& idx) { idx_ = idx; } - void SetAxis(std::vector& axis) { axis_ = axis; } - -private: - size_t idx_ = 0; - std::vector axis_; -}; - -InputReduceInfo::InputReduceInfo() {} - -InputReduceInfo::~InputReduceInfo() {} - -bool InputReduceInfo::IsPtrNull() const { - if (reduce_impl_ == nullptr) { - return true; - } - return false; -} - -bool InputReduceInfo::Initialize() { - FE_MAKE_SHARED(reduce_impl_ = std::make_shared(), return false); - if (reduce_impl_ == nullptr) { - return false; - } - return true; -} - -InputReduceInfo::InputReduceInfo(const InputReduceInfo &input_reduce_info) { - this->reduce_impl_ = input_reduce_info.reduce_impl_; -} - -InputReduceInfo &InputReduceInfo::operator = (const InputReduceInfo &input_reduce_info) { - this->reduce_impl_ = input_reduce_info.reduce_impl_; - return *this; -} - -size_t InputReduceInfo::GetIndex() const { - return reduce_impl_->GetIndex(); -} - -std::vector InputReduceInfo::GetAxis() const { - return reduce_impl_->GetAxis(); -} - -void InputReduceInfo::SetIndex(const size_t& idx) { - reduce_impl_->SetIndex(idx); -} - -void InputReduceInfo::SetAxis(std::vector& axis) { - reduce_impl_->SetAxis(axis); -} - -class OutputReduceInfoImpl { -public: - size_t GetIndex() const { return idx_; } - OpReduceType GetReduceType() const { return reduce_type_; } - bool GetIsAtomic() const { return is_atomic_; } - void SetIndex(const size_t& idx) { idx_ = idx; } - void SetReduceType(const OpReduceType& reduce_type) { reduce_type_ = reduce_type; } - void SetIsAtomic(const bool& is_atomic) { is_atomic_ = is_atomic; } -private: - size_t idx_ = 0; - OpReduceType reduce_type_; - bool is_atomic_{false}; -}; - -OutputReduceInfo::OutputReduceInfo() {} - -OutputReduceInfo::~OutputReduceInfo() {} - -OutputReduceInfo::OutputReduceInfo(const OutputReduceInfo &output_reduce_info) { - this->reduce_impl_ = output_reduce_info.reduce_impl_; -} - -OutputReduceInfo &OutputReduceInfo::operator = (const OutputReduceInfo &output_reduce_info) { - this->reduce_impl_ = output_reduce_info.reduce_impl_; - return *this; -} - -bool OutputReduceInfo::IsPtrNull() const { - if (reduce_impl_ == nullptr) { - return true; - } - return false; -} - -bool OutputReduceInfo::Initialize() { - FE_MAKE_SHARED(reduce_impl_ = std::make_shared(), return false); - if (reduce_impl_ == nullptr) { - return false; - } - return true; -} - -size_t OutputReduceInfo::GetIndex() const { - return reduce_impl_->GetIndex(); -} - -OpReduceType OutputReduceInfo::GetReduceType() const { - return reduce_impl_->GetReduceType(); -} - -bool OutputReduceInfo::GetIsAtomic() const { - return reduce_impl_->GetIsAtomic(); -} - -void OutputReduceInfo::SetIndex(const size_t& idx) { - reduce_impl_->SetIndex(idx); -} - -void OutputReduceInfo::SetReduceType(const OpReduceType& reduce_type) { - reduce_impl_->SetReduceType(reduce_type); -} - -void OutputReduceInfo::SetIsAtomic(const bool& is_atomic) { - reduce_impl_->SetIsAtomic(is_atomic); -} - -class AxisSplitMapImpl { -public: - std::vector GetInputSplitInfos() { return input_split_vec_; } - std::vector GetOutputSplitInfos() { return output_split_vec_; } - void AddInputSplitInfo(const InputSplitInfo& input_split_info) { - InputSplitInfoPtr input_split_info_ptr = nullptr; - FE_MAKE_SHARED(input_split_info_ptr = std::make_shared(input_split_info), return); - if (input_split_info_ptr == nullptr) { - return; - } - if (input_split_info_ptr->IsPtrNull()) { - if (!input_split_info_ptr->Initialize()) { - return; - } - } - - input_split_vec_.push_back(input_split_info_ptr); - } - - void SetInputSplitInfos(std::vector& input_split_vec) { - input_split_vec_.clear(); - for (InputSplitInfo &input_split_info : input_split_vec) { - AddInputSplitInfo(input_split_info); - } - } - - void SetInputSplitInfos(std::vector& input_split_vec) { - input_split_vec_ = input_split_vec; - } - - void AddOutputSplitInfo(const OutputSplitInfo& output_split_info) { - OutputSplitInfoPtr output_split_info_ptr = nullptr; - FE_MAKE_SHARED(output_split_info_ptr = std::make_shared(output_split_info), return); - if (output_split_info_ptr == nullptr) { - return; - } - - if (output_split_info_ptr->IsPtrNull()) { - if (!output_split_info_ptr->Initialize()) { - return; - } - } - - output_split_vec_.push_back(output_split_info_ptr); - } - - void SetOutputSplitInfos(std::vector& output_split_vec) { - output_split_vec_.clear(); - for (OutputSplitInfo &output_split_info : output_split_vec) { - AddOutputSplitInfo(output_split_info); - } - } - - void SetOutputSplitInfos(std::vector& output_split_vec) { - output_split_vec_ = output_split_vec; - } - -private: - std::vector input_split_vec_; - std::vector output_split_vec_; -}; - -AxisSplitMap::AxisSplitMap() {} - -AxisSplitMap::~AxisSplitMap() {} - -AxisSplitMap::AxisSplitMap(const AxisSplitMap &axis_split_map) { - this->aixs_split_impl_ = axis_split_map.aixs_split_impl_; -} - -AxisSplitMap &AxisSplitMap::operator = (const AxisSplitMap &axis_split_map) { - this->aixs_split_impl_ = axis_split_map.aixs_split_impl_; - return *this; -} - -bool AxisSplitMap::IsPtrNull() const { - if (aixs_split_impl_ == nullptr) { - return true; - } - return false; -} - -bool AxisSplitMap::Initialize() { - FE_MAKE_SHARED(aixs_split_impl_ = std::make_shared(), return false); - if (aixs_split_impl_ == nullptr) { - return false; - } - return true; -} - -std::vector AxisSplitMap::GetInputSplitInfos() const { - return aixs_split_impl_->GetInputSplitInfos(); -} - -std::vector AxisSplitMap::GetInputSplitInfoVec() const { - std::vector ret; - for (InputSplitInfoPtr info_ptr : aixs_split_impl_->GetInputSplitInfos()) { - ret.push_back(*info_ptr); - } - return ret; -} - -std::vector AxisSplitMap::GetOutputSplitInfos() const { - return aixs_split_impl_->GetOutputSplitInfos(); -} - -std::vector AxisSplitMap::GetOutputSplitInfoVec() const { - std::vector ret; - for (OutputSplitInfoPtr info_ptr : aixs_split_impl_->GetOutputSplitInfos()) { - ret.push_back(*info_ptr); - } - return ret; -} - -void AxisSplitMap::AddInputSplitInfo(InputSplitInfo& input_split_info) { - aixs_split_impl_->AddInputSplitInfo(input_split_info); -} - -void AxisSplitMap::SetInputSplitInfos(std::vector& input_split_vec) { - aixs_split_impl_->SetInputSplitInfos(input_split_vec); -} - -void AxisSplitMap::SetInputSplitInfos(std::vector& input_split_vec) { - aixs_split_impl_->SetInputSplitInfos(input_split_vec); -} - -void AxisSplitMap::AddOutputSplitInfo(OutputSplitInfo& output_split_info) { - aixs_split_impl_->AddOutputSplitInfo(output_split_info); -} - -void AxisSplitMap::SetOutputSplitInfos(std::vector& output_split_vec) { - aixs_split_impl_->SetOutputSplitInfos(output_split_vec); -} - -void AxisSplitMap::SetOutputSplitInfos(std::vector& output_split_vec) { - aixs_split_impl_->SetOutputSplitInfos(output_split_vec); -} - -class AxisReduceMapImpl { -public: - std::vector GetInputReduceInfos() { return input_reduce_vec_; } - std::vector GetOutputReduceInfos() { return output_reduce_vec_; } - - void AddInputReduceInfo(const InputReduceInfo& input_reduce_info) { - InputReduceInfoPtr input_reduce_info_ptr = nullptr; - FE_MAKE_SHARED(input_reduce_info_ptr = std::make_shared(input_reduce_info), return); - if (input_reduce_info_ptr == nullptr) { - return; - } - - if(input_reduce_info_ptr->IsPtrNull()) { - if (!input_reduce_info_ptr->Initialize()) { - return; - } - } - - input_reduce_vec_.push_back(input_reduce_info_ptr); - } - - void SetInputReduceInfos(std::vector& input_reduce_vec) { - input_reduce_vec_.clear(); - for (InputReduceInfo &input_reduce_info : input_reduce_vec) { - AddInputReduceInfo(input_reduce_info); - } - } - - void SetInputReduceInfos(std::vector& input_reduce_vec) { - input_reduce_vec_ = input_reduce_vec; - } - - void AddOutputReduceInfo(const OutputReduceInfo& output_reduce_info) { - OutputReduceInfoPtr output_reduce_info_ptr = nullptr; - FE_MAKE_SHARED(output_reduce_info_ptr = std::make_shared(output_reduce_info), return); - if (output_reduce_info_ptr == nullptr) { - return; - } - - if (output_reduce_info_ptr->IsPtrNull()) { - if (!output_reduce_info_ptr->Initialize()) { - return; - } - } - output_reduce_vec_.push_back(output_reduce_info_ptr); - } - - void SetOutputReduceInfos(std::vector& output_reduce_vec) { - output_reduce_vec_.clear(); - for (OutputReduceInfo &output_reduce_info : output_reduce_vec) { - AddOutputReduceInfo(output_reduce_info); - } - } - - void SetOutputReduceInfos(std::vector& output_reduce_vec) { - output_reduce_vec_ = output_reduce_vec; - } - -private: - std::vector input_reduce_vec_; - std::vector output_reduce_vec_; -}; - -AxisReduceMap::AxisReduceMap() {} - -AxisReduceMap::~AxisReduceMap() {} - -AxisReduceMap::AxisReduceMap(const AxisReduceMap &axis_reduce_map) { - this->aixs_reduce_impl_ = axis_reduce_map.aixs_reduce_impl_; -} - -AxisReduceMap &AxisReduceMap::operator = (const AxisReduceMap &axis_reduce_map) { - this->aixs_reduce_impl_ = axis_reduce_map.aixs_reduce_impl_; - return *this; -} - -bool AxisReduceMap::IsPtrNull() const { - if (aixs_reduce_impl_ == nullptr) { - return true; - } - return false; -} - -bool AxisReduceMap::Initialize() { - FE_MAKE_SHARED(aixs_reduce_impl_ = std::make_shared(), return false); - if (aixs_reduce_impl_ == nullptr) { - return false; - } - return true; -} - -std::vector AxisReduceMap::GetInputReduceInfos() const { - return aixs_reduce_impl_->GetInputReduceInfos(); -} - -std::vector AxisReduceMap::GetInputReduceInfoVec() const { - std::vector ret; - for (InputReduceInfoPtr info_ptr : aixs_reduce_impl_->GetInputReduceInfos()) { - ret.push_back(*info_ptr); - } - return ret; -} - -std::vector AxisReduceMap::GetOutputReduceInfos() const { - return aixs_reduce_impl_->GetOutputReduceInfos(); -} - -std::vector AxisReduceMap::GetOutputReduceInfoVec() const { - std::vector ret; - for (OutputReduceInfoPtr info_ptr : aixs_reduce_impl_->GetOutputReduceInfos()) { - ret.push_back(*info_ptr); - } - return ret; -} - -void AxisReduceMap::AddInputReduceInfo(InputReduceInfo& input_reduce_info) { - aixs_reduce_impl_->AddInputReduceInfo(input_reduce_info); -} - -void AxisReduceMap::SetInputReduceInfos(std::vector& input_reduce_vec) { - aixs_reduce_impl_->SetInputReduceInfos(input_reduce_vec); -} - -void AxisReduceMap::SetInputReduceInfos(std::vector& input_reduce_vec) { - aixs_reduce_impl_->SetInputReduceInfos(input_reduce_vec); -} - -void AxisReduceMap::AddOutputReduceInfo(OutputReduceInfo& output_reduce_info) { - aixs_reduce_impl_->AddOutputReduceInfo(output_reduce_info); -} - -void AxisReduceMap::SetOutputReduceInfos(std::vector& output_reduce_vec) { - aixs_reduce_impl_->SetOutputReduceInfos(output_reduce_vec); -} - -void AxisReduceMap::SetOutputReduceInfos(std::vector& output_reduce_vec) { - aixs_reduce_impl_->SetOutputReduceInfos(output_reduce_vec); -} - -class OpCalcInfoImpl { -public: - std::vector GetAxisSplitMaps() { return axis_split_vec_; } - std::vector GetAxisReduceMaps() { return axis_reduce_vec_; } - OpL1FusionType GetL1FusionEnable() const { return l1_fusion_enable_; } - int64_t GetMinTbeL1Space() const { return min_tbe_l1_space_; } - - void AddAxisSplitMap(const AxisSplitMap& axis_split_map) { - AxisSplitMapPtr axis_split_map_ptr = nullptr; - FE_MAKE_SHARED(axis_split_map_ptr = std::make_shared(axis_split_map), return); - if (axis_split_map_ptr == nullptr) { - return; - } - axis_split_vec_.push_back(axis_split_map_ptr); - } - - void SetAxisSplitMaps(std::vector& axis_split_vec) { - axis_split_vec_.clear(); - for (AxisSplitMap &axis_split_map : axis_split_vec) { - AddAxisSplitMap(axis_split_map); - } - } - - void SetAxisSplitMaps(std::vector& axis_split_vec) { axis_split_vec_ = axis_split_vec; } - - void AddAxisReduceMap(const AxisReduceMap& axis_reduce_map) { - AxisReduceMapPtr axis_reduce_map_ptr = nullptr; - FE_MAKE_SHARED(axis_reduce_map_ptr = std::make_shared(axis_reduce_map), return); - if (axis_reduce_map_ptr == nullptr) { - return; - } - axis_reduce_vec_.push_back(axis_reduce_map_ptr); - } - - void SetAxisReduceMaps(std::vector& axis_reduce_vec) { - axis_reduce_vec_.clear(); - for (AxisReduceMap &axis_reduce_map : axis_reduce_vec) { - AddAxisReduceMap(axis_reduce_map); - } - } - - void SetAxisReduceMaps(std::vector& axis_reduce_vec) { axis_reduce_vec_ = axis_reduce_vec; } - void SetL1FusionEnable(const OpL1FusionType& l1_fusion_enable) { l1_fusion_enable_ = l1_fusion_enable; } - void SetMinTbeL1Space(const int64_t& min_tbe_l1_space) { min_tbe_l1_space_ = min_tbe_l1_space; } - - void DelAxisSplitMapBaseAxis(std::vector& axis) { - AxisSplitMapPtr temp_axis_split_map; - for (AxisSplitMapPtr axis_split : axis_split_vec_) { - for (InputSplitInfoPtr input_split : axis_split->GetInputSplitInfos()) { - if (input_split->GetAxis() == axis) { - temp_axis_split_map = axis_split; - } - } - } - - std::vector::iterator iter = - std::find(axis_split_vec_.begin(), axis_split_vec_.end(), temp_axis_split_map); - if (iter != axis_split_vec_.end()) { - std::ignore = axis_split_vec_.erase(iter); - } - } - -private: - std::vector axis_split_vec_; - std::vector axis_reduce_vec_; - OpL1FusionType l1_fusion_enable_ = L1FUSION_DISABLE; - int64_t min_tbe_l1_space_ = 0; -}; - -OpCalcInfo::OpCalcInfo() {} - -OpCalcInfo::~OpCalcInfo() {} - -bool OpCalcInfo::IsPtrNull() const { - if (op_calc_info_impl_ == nullptr) { - return true; - } - return false; -} - -bool OpCalcInfo::Initialize() { - FE_MAKE_SHARED(op_calc_info_impl_ = std::make_shared(), return false); - if (op_calc_info_impl_ == nullptr) { - return false; - } - return true; -} - -std::vector OpCalcInfo::GetAxisSplitMaps() const { - return op_calc_info_impl_->GetAxisSplitMaps(); -} - -std::vector OpCalcInfo::GetAxisReduceMaps() const { - return op_calc_info_impl_->GetAxisReduceMaps(); -} - -std::vector OpCalcInfo::GetAxisSplitMapVec() const { - std::vector ret; - for (AxisSplitMapPtr map_ptr : op_calc_info_impl_->GetAxisSplitMaps()) { - ret.push_back(*map_ptr); - } - return ret; -} - -std::vector OpCalcInfo::GetAxisReduceMapVec() const { - std::vector ret; - for (AxisReduceMapPtr map_ptr : op_calc_info_impl_->GetAxisReduceMaps()) { - ret.push_back(*map_ptr); - } - return ret; -} - -OpL1FusionType OpCalcInfo::GetL1FusionEnable() const { - return op_calc_info_impl_->GetL1FusionEnable(); -} - -int64_t OpCalcInfo::GetMinTbeL1Space() const { - return op_calc_info_impl_->GetMinTbeL1Space(); -} - -void OpCalcInfo::AddAxisSplitMap(AxisSplitMap& axis_split_map) { - op_calc_info_impl_->AddAxisSplitMap(axis_split_map); -} - -void OpCalcInfo::SetAxisSplitMaps(std::vector& axis_split_vec) { - op_calc_info_impl_->SetAxisSplitMaps(axis_split_vec); -} - -void OpCalcInfo::SetAxisSplitMaps(std::vector& axis_split_vec) { - op_calc_info_impl_->SetAxisSplitMaps(axis_split_vec); -} - -void OpCalcInfo::AddAxisReduceMap(AxisReduceMap& axis_reduce_map) { - op_calc_info_impl_->AddAxisReduceMap(axis_reduce_map); -} - -void OpCalcInfo::SetAxisReduceMaps(std::vector& axis_reduce_vec) { - op_calc_info_impl_->SetAxisReduceMaps(axis_reduce_vec); -} - -void OpCalcInfo::SetAxisReduceMaps(std::vector& axis_reduce_vec) { - op_calc_info_impl_->SetAxisReduceMaps(axis_reduce_vec); -} - -void OpCalcInfo::SetL1FusionEnable(const OpL1FusionType& l1_fusion_enable) { - op_calc_info_impl_->SetL1FusionEnable(l1_fusion_enable); -} - -void OpCalcInfo::SetMinTbeL1Space(const int64_t& min_tbe_l1_space) { - op_calc_info_impl_->SetMinTbeL1Space(min_tbe_l1_space); -} - -void OpCalcInfo::DelAxisSplitMapBaseAxis(std::vector& axis) { - op_calc_info_impl_->DelAxisSplitMapBaseAxis(axis); -} - -} // namespace fe diff --git a/register/graph_optimizer/fusion_common/unknown_shape_utils.cc b/register/graph_optimizer/fusion_common/unknown_shape_utils.cc deleted file mode 100644 index 5c58d9a9ce087bd188886af7a2405e95569663f8..0000000000000000000000000000000000000000 --- a/register/graph_optimizer/fusion_common/unknown_shape_utils.cc +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/graph_optimizer/fusion_common/unknown_shape_utils.h" -#include "graph/debug/ge_log.h" -namespace fe { -const std::string ATTR_NAME_UNKNOWN_SHAPE_OP = "_unknown_shape"; -bool UnknownShapeUtils::IsUnKnownShapeTensor(const ge::OpDesc &op_desc) { - for (auto &tenosr_desc_ptr : op_desc.GetAllInputsDescPtr()) { - if (tenosr_desc_ptr == nullptr) { - continue; - } - if (tenosr_desc_ptr->GetShape().IsUnknownShape()) { - return true; - } - } - - for (auto &tenosr_desc_ptr : op_desc.GetAllOutputsDescPtr()) { - if (tenosr_desc_ptr == nullptr) { - continue; - } - if (tenosr_desc_ptr->GetShape().IsUnknownShape()) { - return true; - } - } - - return false; -} - -bool UnknownShapeUtils::IsUnknownShapeOp(const ge::OpDesc &op_desc) { - bool unknown_shape_status = false; - if (ge::AttrUtils::GetBool(op_desc, ATTR_NAME_UNKNOWN_SHAPE_OP, unknown_shape_status)) { - return unknown_shape_status; - } - if (op_desc.GetAllInputsSize() != 0 || op_desc.GetOutputsSize() != 0) { - unknown_shape_status = IsUnKnownShapeTensor(op_desc); - } - ge::OpDesc *no_const_op_desc = const_cast(&op_desc); - (void)ge::AttrUtils::SetBool(*no_const_op_desc, ATTR_NAME_UNKNOWN_SHAPE_OP, unknown_shape_status); - GELOGD("Op[%s, %s] Set attr unknown_shape [%d].", op_desc.GetName().c_str(), op_desc.GetType().c_str(), - unknown_shape_status); - return unknown_shape_status; -} - -bool UnknownShapeUtils::IsContainUnknownDimNum(const ge::OpDesc &op_desc) { - for (auto &ptr : op_desc.GetAllInputsDescPtr()) { - if (ptr->GetShape().IsUnknownDimNum()) { - GELOGD("Op[name:%s,type:%s] has an input tensor whose shape contains -2.", op_desc.GetName().c_str(), - op_desc.GetType().c_str()); - return true; - } - } - - for (auto &ptr : op_desc.GetAllOutputsDescPtr()) { - if (ptr->GetShape().IsUnknownDimNum()) { - GELOGD("Op[name=%s, type=%s] has an output tensor whose shape contains -2.", op_desc.GetName().c_str(), - op_desc.GetType().c_str()); - return true; - } - } - - return false; -} - -bool UnknownShapeUtils::IsUnknownShapeValue(const int64_t &value) { - if (value == ge::UNKNOWN_DIM || value == ge::UNKNOWN_DIM_NUM) { - return true; - } - return false; -} -} // namespace fe diff --git a/register/graph_optimizer/fusion_statistic/fusion_statistic_recorder.cc b/register/graph_optimizer/fusion_statistic/fusion_statistic_recorder.cc deleted file mode 100644 index 6530905fbe2e5953468fbc8bf9da9691b629adb4..0000000000000000000000000000000000000000 --- a/register/graph_optimizer/fusion_statistic/fusion_statistic_recorder.cc +++ /dev/null @@ -1,164 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/graph_optimizer/fusion_common/fusion_statistic_recorder.h" -#include -#include "graph/debug/ge_log.h" - -namespace fe { - -FusionStatisticRecorder::FusionStatisticRecorder(){}; - -FusionStatisticRecorder::~FusionStatisticRecorder(){}; - -FusionStatisticRecorder &FusionStatisticRecorder::Instance() { - static FusionStatisticRecorder fusion_statistic_recoder; - return fusion_statistic_recoder; -} - -void FusionStatisticRecorder::UpdateGraphFusionMatchTimes(const FusionInfo &fusion_info) { - const std::lock_guard my_lock(mutex_); - if (fusion_info.GetMatchTimes() != 0) { - const std::string session_and_graph_id = std::to_string(fusion_info.GetSessionId()) + "_" + \ - fusion_info.GetGraphId(); - graph_fusion_info_map_[session_and_graph_id][fusion_info.GetPassName()].AddMatchTimes(fusion_info.GetMatchTimes()); - GELOGD("session %lu, graph %s, pass %s, match_times value: %d", fusion_info.GetSessionId(), - fusion_info.GetGraphId().c_str(), fusion_info.GetPassName().c_str(), - graph_fusion_info_map_[session_and_graph_id][fusion_info.GetPassName()].GetMatchTimes()); - } -} - -void FusionStatisticRecorder::UpdateGraphFusionEffectTimes(const FusionInfo &fusion_info) { - const std::lock_guard my_lock(mutex_); - if (fusion_info.GetEffectTimes() != 0) { - const std::string session_and_graph_id = std::to_string(fusion_info.GetSessionId()) + "_" + \ - fusion_info.GetGraphId(); - graph_fusion_info_map_[session_and_graph_id][fusion_info.GetPassName()].AddEffectTimes( - fusion_info.GetEffectTimes()); - GELOGD("session %lu, graph %s, pass %s, effect_times value: %d", fusion_info.GetSessionId(), - fusion_info.GetGraphId().c_str(), fusion_info.GetPassName().c_str(), - graph_fusion_info_map_[session_and_graph_id][fusion_info.GetPassName()].GetEffectTimes()); - } -} - -void FusionStatisticRecorder::UpdateBufferFusionMatchTimes(const FusionInfo &fusion_info) { - const std::lock_guard my_lock(mutex_); - const std::string session_graph_id = std::to_string(fusion_info.GetSessionId()) + "_" + fusion_info.GetGraphId(); - if (fusion_info.GetMatchTimes() != 0) { - buffer_fusion_info_map_[session_graph_id][fusion_info.GetPassName()].AddMatchTimes(fusion_info.GetMatchTimes()); - } - - if (fusion_info.GetRepoHitTimes() != 0) { - buffer_fusion_info_map_[session_graph_id][fusion_info.GetPassName()].SetRepoHitTimes(fusion_info.GetRepoHitTimes()); - } - GELOGD("Updated match time of pass [%s] for graph [%s] and session [%lu].", - fusion_info.GetPassName().c_str(), fusion_info.GetGraphId().c_str(), fusion_info.GetSessionId()); - GELOGD("Match times is [%d] and repo match times is [%d].", - buffer_fusion_info_map_[session_graph_id][fusion_info.GetPassName()].GetMatchTimes(), - buffer_fusion_info_map_[session_graph_id][fusion_info.GetPassName()].GetRepoHitTimes()); -} - -void FusionStatisticRecorder::UpdateBufferFusionEffectTimes(const FusionInfo &fusion_info) { - const std::lock_guard my_lock(mutex_); - if (fusion_info.GetEffectTimes() != 0) { - const std::string session_and_graph_id = std::to_string(fusion_info.GetSessionId()) + "_" + \ - fusion_info.GetGraphId(); - buffer_fusion_info_map_[session_and_graph_id][fusion_info.GetPassName()].AddEffectTimes( - fusion_info.GetEffectTimes()); - GELOGD("ub session %lu, graph %s, pass %s, effect_times value: %d", fusion_info.GetSessionId(), - fusion_info.GetGraphId().c_str(), fusion_info.GetPassName().c_str(), - buffer_fusion_info_map_[session_and_graph_id][fusion_info.GetPassName()].GetEffectTimes()); - } -} - -void FusionStatisticRecorder::GetAndClearFusionInfo(const std::string &session_graph_id, - std::map &graph_fusion_info_map, - std::map &buffer_fusion_info_map) { - const std::lock_guard my_lock(mutex_); - GELOGD("start to get graph map size %zu", graph_fusion_info_map_.size()); - GELOGD("start to get ub graph map size: %zu", buffer_fusion_info_map_.size()); - GetFusionInfo(session_graph_id, graph_fusion_info_map, buffer_fusion_info_map); - ClearFusionInfo(session_graph_id); -} - -void FusionStatisticRecorder::GetFusionInfo(const std::string &session_graph_id, - std::map &graph_fusion_info_map, - std::map &buffer_fusion_info_map) { - if (graph_fusion_info_map_.find(session_graph_id) != graph_fusion_info_map_.end()) { - graph_fusion_info_map = graph_fusion_info_map_[session_graph_id]; - } - if (buffer_fusion_info_map_.find(session_graph_id) != buffer_fusion_info_map_.end()) { - buffer_fusion_info_map = buffer_fusion_info_map_[session_graph_id]; - } -} - -void FusionStatisticRecorder::ClearFusionInfo(const std::string& session_graph_id) { - if (graph_fusion_info_map_.find(session_graph_id) != graph_fusion_info_map_.end()) { - (void)graph_fusion_info_map_.erase(session_graph_id); - } - if (buffer_fusion_info_map_.find(session_graph_id) != buffer_fusion_info_map_.end()) { - (void)buffer_fusion_info_map_.erase(session_graph_id); - } -} - -void FusionStatisticRecorder::GetAllSessionAndGraphIdList(std::vector &session_graph_id_vec) { - if (!graph_fusion_info_map_.empty()) { - for (auto iter = graph_fusion_info_map_.cbegin(); iter != graph_fusion_info_map_.cend(); iter++) { - session_graph_id_vec.push_back(iter->first); - } - } - if (!buffer_fusion_info_map_.empty()) { - for (auto iter = buffer_fusion_info_map_.cbegin(); iter != buffer_fusion_info_map_.cend(); iter++) { - if (std::find(session_graph_id_vec.begin(), session_graph_id_vec.end(), iter->first) - == session_graph_id_vec.end()) { - session_graph_id_vec.push_back(iter->first); - } - } - } -} - -FusionInfo::FusionInfo(const uint64_t session_id, const std::string graph_id, const std::string pass_name, - const int32_t match_times, const int32_t effect_times, const int32_t repo_hit_times) - : session_id_(session_id), graph_id_(graph_id), pass_name_(pass_name), - match_times_(match_times), effect_times_(effect_times), repo_hit_times_(repo_hit_times) {} - -FusionInfo::~FusionInfo() {} - -void FusionInfo::AddMatchTimes(const int32_t match_times) { - if (this->match_times_ > std::numeric_limits::max() - match_times) { - return; - } - this->match_times_ += match_times; -} - -void FusionInfo::AddEffectTimes(const int32_t effect_times) { - if (this->effect_times_ > std::numeric_limits::max() - effect_times) { - return; - } - this->effect_times_ += effect_times; -} - -int32_t FusionInfo::GetMatchTimes() const { return match_times_; } - -void FusionInfo::SetMatchTimes(const int32_t match_times) { this->match_times_ = match_times; } - -int32_t FusionInfo::GetEffectTimes() const { return effect_times_; } - -void FusionInfo::SetEffectTimes(const int32_t effect_times) { this->effect_times_ = effect_times; } - -int32_t FusionInfo::GetRepoHitTimes() const { return repo_hit_times_; } - -void FusionInfo::SetRepoHitTimes(const int32_t repo_hit_times) { this->repo_hit_times_ = repo_hit_times; } - -std::string FusionInfo::GetGraphId() const { return graph_id_; } - -std::string FusionInfo::GetPassName() const { return pass_name_; } - -uint64_t FusionInfo::GetSessionId() const { return session_id_; } -} diff --git a/register/graph_optimizer/graph_fusion/connection_matrix.cc b/register/graph_optimizer/graph_fusion/connection_matrix.cc deleted file mode 100644 index 40911f2cc607c2a85d53be4c7b97a09ac6ee00d9..0000000000000000000000000000000000000000 --- a/register/graph_optimizer/graph_fusion/connection_matrix.cc +++ /dev/null @@ -1,171 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/graph_optimizer/graph_fusion/connection_matrix.h" -#include "graph/debug/ge_log.h" - -namespace fe { -ConnectionMatrix::ConnectionMatrix() : enable_data_flow_(false) {} -ConnectionMatrix::ConnectionMatrix(bool enable_data_flow) : enable_data_flow_(enable_data_flow) {} -ConnectionMatrix::ConnectionMatrix(const ge::ComputeGraph &graph) : enable_data_flow_(false) { - (void)graph; -} - -ConnectionMatrix::~ConnectionMatrix() { - bit_maps.clear(); - name_to_index_.clear(); - bit_maps_back_up_.clear(); - data_bit_maps_.clear(); - data_bit_maps_back_up_.clear(); -} - -void ConnectionMatrix::Generate(const ge::ComputeGraph &graph) { - bit_maps.clear(); - data_bit_maps_.clear(); - name_to_index_.clear(); - const ge::ComputeGraph::Vistor direct_nodes = graph.GetDirectNode(); - size_ = direct_nodes.size(); - bit_maps.reserve(size_); - int64_t index_loop = 0; - for (const ge::NodePtr &node : direct_nodes) { - name_to_index_[node->GetName()] = index_loop; - bit_maps.emplace_back(size_); - index_loop++; - } - - if (enable_data_flow_) { - data_bit_maps_ = bit_maps; - } - - for (const ge::NodePtr &node : direct_nodes) { - const ge::Node::Vistor inputs = node->GetInAllNodes(); - SetConnectivity(inputs, node); - if (enable_data_flow_) { - const ge::Node::Vistor data_inputs = node->GetInDataNodes(); - SetDataConnectivity(data_inputs, node); - } - } -} - -void ConnectionMatrix::Update(const ge::ComputeGraph &graph, const std::vector &fusion_nodes) { - ge::LargeBitmap new_bit_vector(graph.GetDirectNodesSize()); - new_bit_vector.SetValues(0U); - std::vector fusion_indexs(fusion_nodes.size(), 0); - for (size_t i = 0U; i < fusion_nodes.size(); ++i) { - const uint64_t index = static_cast(GetIndex(fusion_nodes[i])); - new_bit_vector.Or(GetBitMap(index)); - fusion_indexs[i] = index; - } - - for (ge::LargeBitmap &node_map: bit_maps) { - for (size_t i = 0U; i < fusion_nodes.size(); ++i) { - if (node_map.GetBit(fusion_indexs[i])) { - node_map.Or(new_bit_vector); - break; - } - } - } - - if (enable_data_flow_) { - new_bit_vector.SetValues(0U); - for (size_t i = 0U; i < fusion_nodes.size(); ++i) { - const uint64_t index = static_cast(GetIndex(fusion_nodes[i])); - new_bit_vector.Or(GetDataBitMap(index)); - } - for (ge::LargeBitmap &node_map: data_bit_maps_) { - for (size_t i = 0U; i < fusion_nodes.size(); ++i) { - if (node_map.GetBit(fusion_indexs[i])) { - node_map.Or(new_bit_vector); - break; - } - } - } - } -} - -void ConnectionMatrix::BackupBitMap() { - bit_maps_back_up_ = bit_maps; - data_bit_maps_back_up_ = data_bit_maps_; -} - -void ConnectionMatrix::RestoreBitMap() { - bit_maps = bit_maps_back_up_; - data_bit_maps_ = data_bit_maps_back_up_; -} - -void ConnectionMatrix::SetConnectivity(const ge::Node::Vistor &inputs, const ge::NodePtr &node) { - ge::LargeBitmap &bitmap = GetBitMap(node); - if (std::find(inputs.begin(), inputs.end(), node) == inputs.end()) { - bitmap.SetValues(0); - } - - bitmap.SetBit(static_cast(GetIndex(node))); - for (const ge::NodePtr &input : inputs) { - if (input != node) { - bitmap.Or(GetBitMap(input)); - } - } -} - -void ConnectionMatrix::SetDataConnectivity(const ge::Node::Vistor &inputs, const ge::NodePtr &node) { - ge::LargeBitmap &bitmap = GetDataBitMap(node); - if (std::find(inputs.begin(), inputs.end(), node) == inputs.end()) { - bitmap.SetValues(0); - } - - bitmap.SetBit(static_cast(GetIndex(node))); - for (const ge::NodePtr &input : inputs) { - if (input != node) { - bitmap.Or(GetDataBitMap(input)); - } - } -} - -int64_t ConnectionMatrix::GetIndex(const ge::NodePtr &node) const { - const auto iter = name_to_index_.find(node->GetName()); - if (iter != name_to_index_.end()) { - return iter->second; - } else { - GELOGW("Node %s was not found in name_to_index_.", node->GetName().c_str()); - return 0; - } -} - -bool ConnectionMatrix::IsConnected(const ge::NodePtr &a, const ge::NodePtr &b) const { - return GetBitMap(b).GetBit(static_cast(GetIndex(a))); -} - -bool ConnectionMatrix::IsDataConnected(const ge::NodePtr &a, const ge::NodePtr &b) const { - return GetDataBitMap(b).GetBit(static_cast(GetIndex(a))); -} - -const ge::LargeBitmap &ConnectionMatrix::GetBitMap(const ge::NodePtr &node) const { - return bit_maps[static_cast(GetIndex(node))]; -} - -ge::LargeBitmap &ConnectionMatrix::GetBitMap(const ge::NodePtr &node) { - return bit_maps[static_cast(GetIndex(node))]; -} - -ge::LargeBitmap &ConnectionMatrix::GetBitMap(uint64_t index) { - return bit_maps[index]; -} - -const ge::LargeBitmap &ConnectionMatrix::GetDataBitMap(const ge::NodePtr &node) const { - return data_bit_maps_[static_cast(GetIndex(node))]; -} - -ge::LargeBitmap &ConnectionMatrix::GetDataBitMap(const ge::NodePtr &node) { - return data_bit_maps_[static_cast(GetIndex(node))]; -} - -ge::LargeBitmap &ConnectionMatrix::GetDataBitMap(uint64_t index) { - return data_bit_maps_[index]; -} -} diff --git a/register/graph_optimizer/graph_fusion/fusion_pass_registry.cc b/register/graph_optimizer/graph_fusion/fusion_pass_registry.cc deleted file mode 100644 index f942d62c9973a0b52df693629b7d96e972eaee08..0000000000000000000000000000000000000000 --- a/register/graph_optimizer/graph_fusion/fusion_pass_registry.cc +++ /dev/null @@ -1,127 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/graph_optimizer/graph_fusion/fusion_pass_manager/fusion_pass_registry.h" -#include -#include -#include -#include -#include -#include "graph/debug/ge_log.h" - -namespace fe { -class FusionPassRegistry::FusionPassRegistryImpl { - public: - void RegisterPass(const GraphFusionPassType pass_type, const std::string &pass_name, - FusionPassRegistry::CreateFn const create_fn, PassAttr attr) { - RegPassCompileLevel(pass_name, attr); - const std::string pass_module = IsPassAttrTypeOn(attr, PassAttrType::FE_PASS_FLAG) ? "FE" : "TBE"; - const std::lock_guard my_lock(mu_); - if (pass_descs_.find(pass_type) != pass_descs_.end()) { - pass_descs_[pass_type][pass_name].attr = attr; - pass_descs_[pass_type][pass_name].create_fn = create_fn; - GELOGD("GraphFusionPass[type=%d, name=%s, attr=%lu, module=%s]: the pass type has already existed.", - pass_type, pass_name.c_str(), attr, pass_module.c_str()); - return; - } - - std::map pass_desc; - pass_desc[pass_name] = {attr, create_fn}; - pass_descs_[pass_type] = pass_desc; - GELOGD("GraphFusionPass[type=%d, name=%s, attr=%lu, module=%s]: the pass type does not exist.", - pass_type, pass_name.c_str(), attr, pass_module.c_str()); - } - - std::map GetPassDesc(const GraphFusionPassType &pass_type) { - const std::lock_guard my_lock(mu_); - std::map>::const_iterator iter = pass_descs_.find(pass_type); - if (iter == pass_descs_.end()) { - std::map ret; - return ret; - } - - return iter->second; - } - - std::map GetCreateFn(const GraphFusionPassType &pass_type) { - const std::lock_guard my_lock(mu_); - const auto iter = pass_descs_.find(pass_type); - std::map ret; - if (iter == pass_descs_.end()) { - return ret; - } - - for (const auto &ele : iter->second) { - std::ignore = ret.emplace(std::make_pair(ele.first, ele.second.create_fn)); - } - return ret; - } -private: - std::mutex mu_; - std::map> pass_descs_; -}; - -FusionPassRegistry::FusionPassRegistry() { - impl_ = std::unique_ptr(new (std::nothrow) FusionPassRegistryImpl); -} - -FusionPassRegistry::~FusionPassRegistry() {} - -FusionPassRegistry &FusionPassRegistry::GetInstance() { - static FusionPassRegistry instance; - return instance; -} - -void FusionPassRegistry::RegisterPass(const GraphFusionPassType &pass_type, const std::string &pass_name, - CreateFn create_fn, PassAttr attr) const { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "[Check][Param]param impl is nullptr, GraphFusionPass[type=%d,name=%s]: " - "failed to register the graph fusion pass", - pass_type, pass_name.c_str()); - return; - } - impl_->RegisterPass(pass_type, pass_name, create_fn, attr); -} - -std::map FusionPassRegistry::GetPassDesc( - const GraphFusionPassType &pass_type) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "[Check][Param]param impl is nullptr, GraphFusionPass[type=%d]: " - "failed to get pass desc.", pass_type); - std::map ret; - return ret; - } - return impl_->GetPassDesc(pass_type); -} - -std::map FusionPassRegistry::GetCreateFnByType( - const GraphFusionPassType &pass_type) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "[Check][Param]param impl is nullptr, GraphFusionPass[type=%d]: " - "failed to create the graph fusion pass.", pass_type); - return std::map{}; - } - return impl_->GetCreateFn(pass_type); -} - -FusionPassRegistrar::FusionPassRegistrar(const GraphFusionPassType &pass_type, const std::string &pass_name, - GraphPass *(*create_fn)(), PassAttr attr) { - if ((pass_type < BUILT_IN_GRAPH_PASS) || (pass_type >= GRAPH_FUSION_PASS_TYPE_RESERVED)) { - GELOGE(ge::PARAM_INVALID, "[Check][Param:pass_type] value:%d is not supported.", pass_type); - return; - } - - if (pass_name.empty()) { - GELOGE(ge::PARAM_INVALID, "[Check][Param:pass_name]Failed to register the graph fusion pass, " - "param pass_name is empty."); - return; - } - FusionPassRegistry::GetInstance().RegisterPass(pass_type, pass_name, create_fn, attr); -} -} // namespace fe diff --git a/register/graph_optimizer/graph_fusion/fusion_pattern.cc b/register/graph_optimizer/graph_fusion/fusion_pattern.cc deleted file mode 100644 index f31649c2a86c127e5dd56b54d5085925eb732a00..0000000000000000000000000000000000000000 --- a/register/graph_optimizer/graph_fusion/fusion_pattern.cc +++ /dev/null @@ -1,312 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/graph_optimizer/graph_fusion/fusion_pattern.h" -#include -#include -#include -#include - -#include "graph/debug/ge_log.h" - -namespace fe { -const uint32_t kFuzzyOutIndex = 0xFFFFFFFF; -constexpr size_t MAX_LOG_LENGTH = 900; -#define FE_PATTERN_ERROR_RETURN_IF(condition, ...) \ - do { \ - if (condition) { \ - SetError(); \ - GELOGW(__VA_ARGS__); \ - return *this; \ - } \ - } while (0) - -#define FE_MAKE_SHARED(exec_expr0, exec_expr1) \ - do { \ - try { \ - exec_expr0; \ - } catch (...) { \ - GELOGW("Make shared failed"); \ - exec_expr1; \ - } \ - } while (0) - -FusionPattern::FusionPattern(const std::string name) : name_(name), output_(nullptr) {} - -FusionPattern::~FusionPattern() { - for (const auto &ops: ops_) { - ops->inputs.clear(); - ops->outputs.clear(); - } - ops_.clear(); - op_map_.clear(); -} - -/** - * @ingroup fe - * @brief set pattern name - */ -FusionPattern &FusionPattern::SetName(const std::string &name) { - name_ = name; - return *this; -} - -/** - * @ingroup fe - * @brief add Op description with unknown number of args - */ -FusionPattern &FusionPattern::AddOpDesc(const std::string &id, const std::initializer_list &types, - const bool allow_dumpable, const bool check_unique) { - return AddOpDesc(id, std::vector(types), allow_dumpable, check_unique); -} - -/** - * @ingroup fe - * @brief add Op description with vector - */ -FusionPattern &FusionPattern::AddOpDesc(const std::string &id, const std::vector &types, - const bool allow_dumpable, const bool check_unique) { - FE_PATTERN_ERROR_RETURN_IF(id.empty(), "ID cannot be empty."); - - FE_PATTERN_ERROR_RETURN_IF(GetOpDesc(id) != nullptr, "ID already exists. (id:%s)", id.c_str()); - - std::shared_ptr op; - FE_MAKE_SHARED(op = std::make_shared(), return *this); - FE_PATTERN_ERROR_RETURN_IF(op == nullptr, "new an object failed."); - - op->id = id; - op->types = types; - op->repeatable = false; - op->is_output = false; - op->is_output_fullmatch = true; - op->output_size = 0UL; - op->allow_dumpable = allow_dumpable; - op->check_unique = check_unique; - - ops_.push_back(op); - op_map_[id] = op; - - return *this; -} - -/** - * @ingroup fe - * @brief set input Ops with unknown number of args - */ -FusionPattern &FusionPattern::SetInputs(const std::string &id, const std::initializer_list &input_ids) { - return SetInputs(id, std::vector(input_ids)); -} - -/** - * @ingroup fe - * @brief set input Ops with vector - */ -FusionPattern &FusionPattern::SetInputs(const std::string &id, const std::vector &input_ids) { - FE_PATTERN_ERROR_RETURN_IF(id.empty(), "Id cannot be empty."); - const std::shared_ptr op_desc = GetOpDesc(id); - FE_PATTERN_ERROR_RETURN_IF(op_desc == nullptr, "Id does not exist. (id:%s)", id.c_str()); - - op_desc->inputs.clear(); - - for (const std::string &input_id : input_ids) { - const std::shared_ptr input_op_desc = GetOpDesc(input_id); - FE_PATTERN_ERROR_RETURN_IF(input_op_desc == nullptr, "Id does not exist. (id:%s)", input_id.c_str()); - op_desc->inputs.push_back(input_op_desc); - } - - return *this; -} - -/** - * @ingroup fe - * @brief set output Ops with vector - */ -FusionPattern &FusionPattern::SetOutputs(const std::string &id, const FusionPattern::OutputMapVecStr &output_map, - bool is_fullmatched) { - if (id.empty()) { - GELOGW("Id cannot be empty."); - return *this; - } - const std::shared_ptr op_desc = GetOpDesc(id); - FE_PATTERN_ERROR_RETURN_IF(op_desc == nullptr, "Id does not exist. (id:%s)", id.c_str()); - op_desc->outputs.clear(); - for (auto &iter : output_map) { - for (const std::string &output_id : iter.second) { - const std::shared_ptr output_op_desc = GetOpDesc(output_id); - FE_PATTERN_ERROR_RETURN_IF(output_op_desc == nullptr, "Id does not exist. (id:%s)", output_id.c_str()); - if (op_desc->outputs.find(iter.first) == op_desc->outputs.end()) { - op_desc->outputs[iter.first] = {}; - } - op_desc->outputs[iter.first].emplace_back(output_op_desc); - FE_PATTERN_ERROR_RETURN_IF(op_desc->output_size == std::numeric_limits::max(), - "op_desc->output_size has wrapped around."); - ++op_desc->output_size; - } - } - op_desc->is_output_fullmatch = is_fullmatched; - return *this; -} -/** - * @ingroup fe - * @brief set output Ops with vector - */ -FusionPattern &FusionPattern::SetOutputs(const std::string &id, const FusionPattern::OutputMapStr &output_map, - bool is_fullmatched) { - if (id.empty()) { - GELOGW("Id cannot be empty."); - return *this; - } - const std::shared_ptr op_desc = GetOpDesc(id); - FE_PATTERN_ERROR_RETURN_IF(op_desc == nullptr, "Id does not exist. (id:%s)", id.c_str()); - - op_desc->outputs.clear(); - for (auto &iter : output_map) { - const std::string output_id(iter.second); - const std::shared_ptr output_op_desc = GetOpDesc(output_id); - FE_PATTERN_ERROR_RETURN_IF(output_op_desc == nullptr, "Id does not exist. (id:%s)", output_id.c_str()); - op_desc->outputs[iter.first].emplace_back(output_op_desc); - FE_PATTERN_ERROR_RETURN_IF(op_desc->output_size == std::numeric_limits::max(), - "op_desc->output_size has wrapped around."); - ++op_desc->output_size; - } - op_desc->is_output_fullmatch = is_fullmatched; - return *this; -} - -/** - * @ingroup fe - * @brief set output Op - */ -FusionPattern &FusionPattern::SetOutput(const std::string &id) { - FE_PATTERN_ERROR_RETURN_IF(id.empty(), "Id cannot be empty."); - const std::shared_ptr op_desc = GetOpDesc(id); - FE_PATTERN_ERROR_RETURN_IF(op_desc == nullptr, "Id does not exist. (id:%s)", id.c_str()); - - op_desc->is_output = true; - - return *this; -} - -/** - * @ingroup fe - * @brief build pattern and check if error exists - */ -bool FusionPattern::Build() { - if (has_error_) { - return false; - } - - // check whether output node already exists - for (const std::shared_ptr op : ops_) { - if (op->is_output) { - if (output_ != nullptr) { - SetError(); - GELOGW("[FusionPattern][Build] Multiple outputs are not supported, (id:%s)", op->id.c_str()); - break; - } - output_ = op; - } - } - - if (output_ == nullptr) { - SetError(); - GELOGW("[FusionPattern][Build] Output must be set to a value."); - } - - return !has_error_; -} - -/** - * @ingroup fe - * @brief get pattern name - */ -const std::string &FusionPattern::GetName() const { return name_; } -/** - * @ingroup fe - * @brief get the OpDesc of input Ops (const) - */ - -const std::vector> *FusionPattern::GetInputs( - const std::shared_ptr op_desc) { - if (op_desc == nullptr) { - return nullptr; - } - return &(op_desc->inputs); -} - -const FusionPattern::OutputMapDesc &FusionPattern::GetOutputs(const OpDescPtr op_desc) { - return op_desc->outputs; -} - -size_t FusionPattern::GetOutputSize(const OpDescPtr op_desc) { - return op_desc->output_size; -} - -/** - * @ingroup fe - * @brief get the OpDesc of output Op - */ -const std::shared_ptr FusionPattern::GetOutput() const { return output_; } - -/** - * @ingroup fe - * @brief print pattern - */ -void FusionPattern::Dump() const { - std::ostringstream oss; - oss << std::endl << "Pattern (" << name_ << "):" << std::endl; - for (const auto &op : ops_) { - oss << " " << op->id << ": {"; - for (const std::string &type : op->types) { - oss << type << ", "; - } - oss << "} {"; - for (const auto &input : op->inputs) { - oss << input->id << ", "; - } - oss << "}"; - - if (op->is_output) { - oss << " [output]"; - } - - oss << std::endl; - } - size_t len = oss.str().length(); - size_t startIndex = 0; - size_t recursive_times = 0; - constexpr int32_t kMaxTurnCount = 10; - do { - recursive_times++; - const int32_t endIndex = static_cast(std::min(startIndex + MAX_LOG_LENGTH, len)); - string subStr = oss.str().substr(startIndex, static_cast(endIndex - startIndex)); - GELOGD("%s", subStr.c_str()); - startIndex = static_cast(endIndex); - } while (startIndex < len && static_cast(recursive_times) < kMaxTurnCount); -} - -/** - * @ingroup fe - * @brief get OpDesc based on ID, return nullptr if failed - */ -std::shared_ptr FusionPattern::GetOpDesc(const std::string &id) const { - const auto it = op_map_.find(id); - if (it != op_map_.end()) { - return it->second; - } - return nullptr; -} - -const std::vector> &FusionPattern::GetOpDescs() const { return ops_; } -/** - * @ingroup fe - * @brief record error - */ -void FusionPattern::SetError() { has_error_ = true; } -} diff --git a/register/graph_optimizer/graph_fusion/fusion_quant_util.cc b/register/graph_optimizer/graph_fusion/fusion_quant_util.cc deleted file mode 100644 index 0bee67100f811e084b24e772ebde8eacd323ac7b..0000000000000000000000000000000000000000 --- a/register/graph_optimizer/graph_fusion/fusion_quant_util.cc +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/graph_optimizer/graph_fusion/fusion_quant_util.h" -#include "register/graph_optimizer/graph_fusion/fusion_quant_util_impl.h" - -namespace fe { -Status QuantUtil::BiasOptimizeByEdge(ge::NodePtr &quant_node, BiasOptimizeEdges ¶m, - std::vector &fusion_nodes) { - return QuantUtilImpl::BiasOptimizeByEdge(quant_node, param, fusion_nodes); -} - -Status QuantUtil::BiasOptimizeByEdge(BiasOptimizeEdges ¶m, std::vector &fusion_nodes) { - return QuantUtilImpl::BiasOptimizeByEdge(param, fusion_nodes); -} - -Status QuantUtil::BiasOptimizeByEdge(QuantParam &quant_param, BiasOptimizeEdges ¶m, - std::vector &fusion_nodes, - WeightMode cube_type) { - return QuantUtilImpl::BiasOptimizeByEdge(quant_param, param, fusion_nodes, cube_type); -} - -Status QuantUtil::InsertFixpipeDequantScaleConvert(ge::InDataAnchorPtr deq_scale, - std::vector &fusion_nodes) { - return QuantUtilImpl::InsertFixpipeDequantScaleConvert(deq_scale, fusion_nodes); -} - -Status QuantUtil::InsertFixpipeDequantScaleConvert(ge::InDataAnchorPtr &deq_scale, - ge::InDataAnchorPtr &quant_offset, std::vector &fusion_nodes) { - return QuantUtilImpl::InsertFixpipeDequantScaleConvert(deq_scale, quant_offset, fusion_nodes); -} - -Status QuantUtil::InsertQuantScaleConvert(ge::InDataAnchorPtr &quant_scale, ge::InDataAnchorPtr &quant_offset, - std::vector &fusion_nodes) { - return QuantUtilImpl::InsertQuantScaleConvert(quant_scale, quant_offset, fusion_nodes); -} - -Status QuantUtil::InsertRequantScaleConvert(ge::InDataAnchorPtr &req_scale, ge::InDataAnchorPtr &quant_offset, - ge::InDataAnchorPtr &cube_bias, std::vector &fusion_nodes) { - return QuantUtilImpl::InsertRequantScaleConvert(req_scale, quant_offset, cube_bias, fusion_nodes); -} -} // namespace fe diff --git a/register/graph_optimizer/graph_fusion/fusion_quant_util_impl.cc b/register/graph_optimizer/graph_fusion/fusion_quant_util_impl.cc deleted file mode 100644 index 2b5661dfc5be6f07686c80d0da1ee5596dfcb5f7..0000000000000000000000000000000000000000 --- a/register/graph_optimizer/graph_fusion/fusion_quant_util_impl.cc +++ /dev/null @@ -1,1216 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/graph_optimizer/graph_fusion/fusion_quant_util_impl.h" -#include "common/ge_common/string_util.h" -#include "common/ge_common/util.h" -#include "graph/anchor.h" -#include "graph/compute_graph.h" -#include "graph/debug/ge_log.h" -#include "graph/ge_local_context.h" -#include "graph/node.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/type_utils.h" -#include "platform/platform_info.h" -#include "register/tensor_assign.h" -#include "graph/debug/ge_attr_define.h" -#include -#include - -namespace fe { -namespace { -const size_t kAicVersionSize = 3; -const std::string X1INPUTNAME = "x1"; -const std::string ATTR_OFFSET = "offset"; -const std::string ATTR_SCALE = "scale"; -const std::string PARAM_QUANT_NODE = "param_quant_node"; -const std::string ATTR_OUTDTYPE = "dequantouttype"; -const std::string BIAS_OPTIMIZATION_BIAS = "cube_optimization_bias"; -const std::string BIAS_OPTIMIZATION_FILTER = "cube_optimization_filter"; -const std::string BIAS_OPTIMIZATION_DEQUANT_SCALE = "dequant_scale"; -const std::string BIAS_OPTIMIZATION_OUTPUT = "cube_quant_roll_back_output"; -constexpr size_t BIAS_OPT_OP_OFFSET_IDX = 3; -constexpr size_t BIAS_OPT_OP_SCALE_IDX = 4; -constexpr size_t IDX_2 = 2; -const std::string kAttrSingleOp = "_is_single_op"; -constexpr bool ATTRTRUE = true; -const std::string kSocVersionAscend035 = "Ascend035"; -const std::string kSocVersionAscend035A = "Ascend035A"; -const std::string kSocVersionAscend035B = "Ascend035B"; -const std::string SOC_VERSION = "ge.socVersion"; -const std::string kAICoreSpec = "AICoreSpec"; -const std::string kSupportFixpipe = "support_fixpipe"; -const std::string kConstOpType = "Const"; -const std::string kRequantHostCpuOpType = "RequantHostCpuOpV2Re"; -const std::string kRequantInputName = "requant_input"; -const std::string kRequantOutputName = "requant_input"; -const std::string kAttrQuantMode = "quant_mode"; -const std::string kQuantHighPrecision = "quant_high_precision"; -const std::string kQuantHighPerformance = "quant_high_performance"; -const std::string kAttrReluFlag = "relu_flag"; -const std::string kAttrBiasSize = "bias_size"; -const std::string kAttrBiasValue = "bias_value"; -const std::string kAttrQuantScale = "quant_scale"; -const std::string kAttrQuantScaleVec = "quant_scale_vec"; -const std::string kAttrQuantOffsetVec = "quant_offset_vec"; -const std::string kAscendQuant = "AscendQuant"; -const std::string kQuantCinCoutReverse = "quant_cin_cout_reverse"; -constexpr uint32_t kBitShift3ByteSize = 24; -constexpr uint32_t kBitShift37 = 37; -const std::unordered_set kNanoSocVersionSet = {kSocVersionAscend035, - kSocVersionAscend035A, kSocVersionAscend035B}; -// maps aic version to ISA arch VERSION -const std::map kAicIsaArchVersionMap{{"100", "v100"}, {"200", "v200"}, {"202", "v200"}, - {"210", "v200"}, {"220", "v220"}, {"300", "v300"}, - {"310", "v300"}, {"350", "v350"}}; - -const std::set kNeedAddBiasWithWeightNd = {"FFN"}; - -const std::map> AXIS_INDEX_OF_FORMAT = { - {ge::Format::FORMAT_NCHW, {{"N", NCHW_DIM_N}, {"C", NCHW_DIM_C}, {"H", NCHW_DIM_H}, {"W", NCHW_DIM_W}}}, - {ge::Format::FORMAT_HWCN, {{"N", HWCN_DIM_N}, {"C", HWCN_DIM_C}, {"H", HWCN_DIM_H}, {"W", HWCN_DIM_W}}}, - {ge::Format::FORMAT_NHWC, {{"N", NHWC_DIM_N}, {"C", NHWC_DIM_C}, {"H", NHWC_DIM_H}, {"W", NHWC_DIM_W}}}, - {ge::Format::FORMAT_CHWN, {{"N", CHWN_DIM_N}, {"C", CHWN_DIM_C}, {"H", CHWN_DIM_H}, {"W", CHWN_DIM_W}}}, - {ge::Format::FORMAT_NDHWC, - {{"N", NDHWC_DIM_N}, {"C", NDHWC_DIM_C}, {"H", NDHWC_DIM_H}, {"W", NDHWC_DIM_W}, {"D", NDHWC_DIM_D}}}, - {ge::Format::FORMAT_NCDHW, - {{"N", NCDHW_DIM_N}, {"C", NCDHW_DIM_C}, {"H", NCDHW_DIM_H}, {"W", NCDHW_DIM_W}, {"D", NCDHW_DIM_D}}}, - {ge::Format::FORMAT_DHWCN, - {{"N", DHWCN_DIM_N}, {"C", DHWCN_DIM_C}, {"H", DHWCN_DIM_H}, {"W", DHWCN_DIM_W}, {"D", DHWCN_DIM_D}}}, - {ge::Format::FORMAT_DHWNC, - {{"N", DHWNC_DIM_N}, {"C", DHWNC_DIM_C}, {"H", DHWNC_DIM_H}, {"W", DHWNC_DIM_W}, {"D", DHWNC_DIM_D}}}}; - -const std::set kRootGraphData = {"Data", "RefData"}; -} - -static uint64_t GetHostCpuAtomicId() { - static std::atomic global_trans_atomic_id(0); - return global_trans_atomic_id.fetch_add(1, std::memory_order_relaxed); -} - -static uint64_t GetBiasNodeAtomicId() { - static std::atomic global_bias_node_atomic_id(0); - return global_bias_node_atomic_id.fetch_add(1, std::memory_order_relaxed); -} - -static uint64_t GetQuantOpAtomicId() { - static std::atomic global_quant_op_atomic_id(0); - return global_quant_op_atomic_id.fetch_add(1, std::memory_order_relaxed); -} - -static void GetIsaArchVersionStr(std::string &isa_version) { - PlatFormInfos platform_infos; - OptionalInfos optional_infos; - if (PlatformInfoManager::Instance().GetPlatformInfoWithOutSocVersion(platform_infos, optional_infos) != SUCCESS) { - GELOGE(ge::FAILED, "Get platform info failed."); - return; - } - - // short soc version - std::string short_soc_version; - if (!platform_infos.GetPlatformRes("version", "Short_SoC_version", short_soc_version) || short_soc_version.empty()) { - GELOGE(ge::FAILED, "Get short soc version failed."); - return; - } - GELOGD("Short soc version is [%s].", short_soc_version.c_str()); - - // aic version, ISAArchVersion - std::string aic_version_str; - if (!platform_infos.GetPlatformRes("version", "AIC_version", aic_version_str) || aic_version_str.empty()) { - GELOGE(ge::FAILED, "Aic version of [%s] is empty.", short_soc_version.c_str()); - return; - } - GELOGD("Aic version of [%s] is [%s].", short_soc_version.c_str(), aic_version_str.c_str()); - std::vector aic_version_vec = ge::StringUtils::Split(aic_version_str, '-'); - if (aic_version_vec.size() < kAicVersionSize) { - GELOGE(ge::FAILED, "The aic version[%s] is invalid.", aic_version_str.c_str()); - return; - } - auto iter = kAicIsaArchVersionMap.find(aic_version_vec[2]); - if (iter != kAicIsaArchVersionMap.end()) { - isa_version = iter->second; - } -} - -bool QuantUtilImpl::NeedBiasInput(const ge::InDataAnchorPtr &bias) { - auto peer_anchor = bias->GetPeerOutAnchor(); - if (peer_anchor == nullptr) { - return true; - } - auto bias_node = peer_anchor->GetOwnerNode(); - return bias_node == nullptr; -} - -Status QuantUtilImpl::PadShapeTo4Dim(const ge::Format &filter_format, const std::vector &filter_dims, - std::vector &filter_dims4_d) { - size_t size_of_filter = filter_dims.size(); - GELOGD("Size of filter is %zu bytes", size_of_filter); - for (size_t i = 0; i <= BIAS_OPT_OP_OFFSET_IDX; i++) { - if (i < size_of_filter) { - GELOGD("dim [%zu] is %ld", i, filter_dims.at(i)); - filter_dims4_d.emplace_back(filter_dims.at(i)); - } else { - if (filter_format == ge::Format::FORMAT_NCHW) { - filter_dims4_d.emplace_back(1); - } else if (filter_format == ge::Format::FORMAT_HWCN) { - (void)filter_dims4_d.insert(filter_dims4_d.cbegin(), 1); - } else if (filter_format == ge::Format::FORMAT_NHWC) { - (void)filter_dims4_d.insert(filter_dims4_d.cbegin() + 1, 1); - } else if (filter_format == ge::Format::FORMAT_ND) { - filter_dims4_d.emplace_back(0); - } else { - GELOGE(ge::FAILED, "[GraphOpt][Quant][PadShpTo4Dim] format %s can not pad shape.", - ge::TypeUtils::FormatToSerialString(filter_format).c_str()); - return FAILED; - } - } - } - - if (!filter_dims4_d.empty() && filter_dims4_d.size() >= BIAS_OPT_OP_OFFSET_IDX) { - GELOGD("Quant bias optimize, filter_format is %s, weight shape is [%ld %ld %ld %ld].", - ge::TypeUtils::FormatToSerialString(filter_format).c_str(), static_cast(filter_dims4_d[NCHW_DIM_N]), - static_cast(filter_dims4_d[NCHW_DIM_C]), static_cast(filter_dims4_d[NCHW_DIM_H]), - static_cast(filter_dims4_d[NCHW_DIM_W])); - } - return SUCCESS; -} - -int32_t QuantUtilImpl::GetAxisIndexByFormat(const ge::Format &format, const string &axis) { - auto iter = AXIS_INDEX_OF_FORMAT.find(format); - if (iter != AXIS_INDEX_OF_FORMAT.end()) { - auto iter2 = iter->second.find(axis); - if (iter2 != iter->second.end()) { - return iter2->second; - } else { - GELOGW("Unsupported axis: %s", axis.c_str()); - return -1; - } - } else { - GELOGW("Do not support this format %s", ge::TypeUtils::FormatToSerialString(format).c_str()); - return -1; - } -} - -inline Status CheckInt64MulOverflow(int64_t m, int64_t n) { - if (m > 0) { - if (n > 0) { - if (m > (static_cast(INT64_MAX) / n)) { - return FAILED; - } - } else { - if (n < (static_cast(INT64_MIN) / m)) { - return FAILED; - } - } - } else { - if (n > 0) { - if (m < (static_cast(INT64_MIN) / n)) { - return FAILED; - } - } else { - if ((m != 0) && (n < (static_cast(INT64_MAX) / m))) { - return FAILED; - } - } - } - return SUCCESS; -} - -Status QuantUtilImpl::GetCoValueByWeight(ge::NodePtr &cube_node, size_t idx, std::vector &bias_shape) { - FE_PARAM_CHECK_NOTNULL(cube_node->GetOpDesc()->MutableInputDesc(static_cast(idx))); - const ge::Format filter_format = - static_cast(ge::GetPrimaryFormat( - cube_node->GetOpDesc()->MutableInputDesc(static_cast(idx))->GetFormat())); - auto filter_shape = cube_node->GetOpDesc()->MutableInputDesc(static_cast(idx))->MutableShape(); - - if ((filter_format == ge::FORMAT_ND || filter_format == ge::FORMAT_NCHW) - && kNeedAddBiasWithWeightNd.count(cube_node->GetType()) != 0) { - auto filter_dims = filter_shape.GetDims(); - if (filter_dims.size() == 2) { // current only support 2D weight - bias_shape.emplace_back(filter_dims[1]); - } - - if (filter_dims.size() == kAicVersionSize) { - bias_shape.emplace_back(filter_dims[0]); - bias_shape.emplace_back(filter_dims[IDX_2]); - } - return SUCCESS; - } - if (filter_format != ge::FORMAT_ND) { - int64_t groups = 1; - std::vector filter_dims4_d; - (void) ge::AttrUtils::GetInt(cube_node->GetOpDesc(), "groups", groups); - (void) PadShapeTo4Dim(filter_format, filter_shape.GetDims(), filter_dims4_d); - if (filter_dims4_d.empty()) { - GELOGE(ge::FAILED, "[GraphOpt][AvgPolQntPcsFus][GetCoVal] Node[%s] filter_dims4_d is empty.", - cube_node->GetName().c_str()); - return FAILED; - } - int64_t index_co = GetAxisIndexByFormat(filter_format, "C"); - if (index_co < 0) { - GELOGE(ge::FAILED, "[GraphOpt][AvgPolQntPcsFus][GetCoVal] Node[%s] index_co is negative, Check filter_format.", - cube_node->GetName().c_str()); - return FAILED; - } - if (index_co >= static_cast(filter_dims4_d.size())) { - GELOGE(ge::FAILED, - "[GraphOpt][AvgPolQntPcsFus][GetCoVal] Node[%s] index_co is larger than the size of filter dims.", - cube_node->GetName().c_str()); - return FAILED; - } - if (CheckInt64MulOverflow(filter_dims4_d[static_cast(index_co)], groups) != SUCCESS) { - return FAILED; - } - bias_shape.emplace_back(filter_dims4_d[static_cast(index_co)] * groups); - } - return SUCCESS; -} - -TensorPtr QuantUtilImpl::CreateBiasTensor(const std::vector &shape) { - int64_t size = 1; - for (auto dim : shape) { - size *= dim; - } - std::unique_ptr bias_data_temp(new (std::nothrow) int32_t[size]()); - for (int64_t i = 0; i < size; i++) { - bias_data_temp[static_cast(i)] = 0; - } - - ge::GeTensorDesc tmp_desc; - ge::GeTensorPtr bias_ptr = nullptr; - GE_MAKE_SHARED(bias_ptr = std::make_shared(tmp_desc, reinterpret_cast(bias_data_temp.get()), - size * sizeof(int32_t)), - return nullptr); - - ge::GeShape bias_shape(shape); - bias_ptr->MutableTensorDesc().SetShape(bias_shape); - bias_ptr->MutableTensorDesc().SetDataType(ge::DT_INT32); - const Status ret = bias_ptr->SetData(reinterpret_cast(bias_data_temp.get()), static_cast(size) * static_cast(sizeof(int32_t))); - if (ret != SUCCESS) { - GELOGW("Failed to set bias data!"); - return nullptr; - } - return bias_ptr; -} - -ge::NodePtr QuantUtilImpl::CreateBiasNode(std::shared_ptr &graph, const ge::GeTensorPtr &bias_ptr, - const std::string &cube_node_name) { - ge::OpDescPtr const_opdesc = ge::OpDescUtils::CreateConstOp(bias_ptr); - if (const_opdesc == nullptr) { - GELOGE(ge::FAILED, "const_opdesc nullptr"); - return nullptr; - } - std::ostringstream oss; - oss << cube_node_name << "_quant_bias" << GetBiasNodeAtomicId(); - (void)oss.flush(); - const_opdesc->SetName(oss.str()); - ge::NodePtr const_node = graph->AddNode(const_opdesc); - return const_node; -} - -Status QuantUtilImpl::UpdateBiasOutputDesc(const ge::NodePtr &cube_node, const ge::GeShape &shape, - ge::Format format, const uint32_t index) { - FE_PARAM_CHECK_NOTNULL(cube_node->GetInDataAnchor(static_cast(index))); - FE_PARAM_CHECK_NOTNULL(cube_node->GetInDataAnchor(static_cast(index))->GetPeerOutAnchor()); - ge::NodePtr bias_node = cube_node->GetInDataAnchor(static_cast(index))->GetPeerOutAnchor()->GetOwnerNode(); - ge::OpDescPtr bias_op_desc = bias_node->GetOpDesc(); - // only has one output, index 0 - ge::GeTensorDesc bias_output_desc = bias_op_desc->GetOutputDesc(0); - bias_output_desc.SetShape(shape); - bias_output_desc.SetOriginFormat(format); - bias_output_desc.SetOriginShape(shape); - bias_output_desc.SetOriginDataType(ge::DT_INT32); - bias_output_desc.SetDataType(ge::DT_INT32); - if (bias_op_desc->UpdateOutputDesc(0, bias_output_desc) != ge::GRAPH_SUCCESS) { - GELOGE(ge::FAILED, "bias_op_desc faild"); - return FAILED; - } - return SUCCESS; -} - -Status QuantUtilImpl::UpdateCubeInputDesc(const ge::NodePtr &cube_node, const ge::GeShape &shape, - const ge::Format &format, const uint32_t index) { - ge::GeTensorDesc bias_desc = cube_node->GetOpDesc()->GetInputDesc(index); - bias_desc.SetShape(shape); - bias_desc.SetOriginFormat(format); - bias_desc.SetOriginShape(shape); - bias_desc.SetOriginDataType(ge::DT_INT32); - bias_desc.SetDataType(ge::DT_INT32); - if (cube_node->GetOpDesc()->UpdateInputDesc(index, bias_desc) != ge::GRAPH_SUCCESS) { - GELOGE(ge::FAILED, "cube_node UpdateInputDesc"); - return FAILED; - } - return SUCCESS; -} - -Status QuantUtilImpl::CreateBiasInput(std::shared_ptr &graph, ge::NodePtr &cube_node, - const std::vector &shape, const size_t &bias_idx) { - GELOGD("Node[name: %s, type: %s] has no bias, create bias and set data", cube_node->GetName().c_str(), - cube_node->GetType().c_str()); - - ge::GeTensorPtr bias_ptr = CreateBiasTensor(shape); - if (bias_ptr == nullptr) { - return FAILED; - } - ge::NodePtr const_node = CreateBiasNode(graph, bias_ptr, cube_node->GetName()); - if (const_node == nullptr) { - GELOGE(ge::FAILED, "[GraphOpt][BiasQuantPass][CreateBiasInput] Fail to add const node."); - return FAILED; - } - - if (cube_node->AddLinkFrom(static_cast(bias_idx), const_node) != ge::GRAPH_SUCCESS) { - GELOGE(ge::FAILED, "[GraphOpt][BiasQuantPass][CreateBiasInput] Fail to link const node with node[%s, %s].", - cube_node->GetName().c_str(), cube_node->GetType().c_str()); - return FAILED; - } - - const ge::GeShape bias_shape(shape); - auto bias_input_desc = cube_node->GetOpDesc()->GetInputDesc(static_cast(bias_idx)); - const ge::Format input_desc0_origin_format = bias_input_desc.GetOriginFormat(); - if (UpdateBiasOutputDesc(cube_node, bias_shape, input_desc0_origin_format, - static_cast(bias_idx)) != SUCCESS) { - return FAILED; - } - if (UpdateCubeInputDesc(cube_node, bias_shape, input_desc0_origin_format, - static_cast(bias_idx)) != SUCCESS) { - - return FAILED; - } - return SUCCESS; -} - -// 基于weight input anchor 获取weight的input node -Status QuantUtilImpl::GetWeightConstNode(const ge::InDataAnchorPtr &weight, ge::NodePtr &weight_const_node, - ge::NodePtr &ascend_weight_quant_node) { - auto peer_out_anchor_of_weight = weight->GetPeerOutAnchor(); - if (peer_out_anchor_of_weight == nullptr) { - GELOGE(ge::FAILED, "peer_out_anchor_of_weight is nullptr"); - return FAILED; - } - - auto weight_input_node = peer_out_anchor_of_weight->GetOwnerNode(); - if (weight_input_node == nullptr) { - GELOGE(ge::FAILED, "weight_input_node is nullptr"); - return FAILED; - } - - auto weight_input_node_first_input_anchor = weight_input_node->GetInDataAnchor(0); - // if dynamic batch or dynamic shape, cube_weight_input_node will be a Data node - if (weight_input_node_first_input_anchor == nullptr || - kRootGraphData.count(weight_input_node->GetOpDesc()->GetType()) != 0) { - ascend_weight_quant_node = nullptr; - weight_const_node = weight_input_node; - } else { - auto weight_const_out_anchor = weight_input_node_first_input_anchor->GetPeerOutAnchor(); - if (weight_const_out_anchor == nullptr) { - GELOGE(ge::FAILED, "weight_const_out_anchor is nullptr"); - return FAILED; - } - weight_const_node = weight_const_out_anchor->GetOwnerNode(); - if (weight_const_node == nullptr) { - GELOGE(ge::FAILED, "weight_const_node is nullptr"); - return FAILED; - } - ascend_weight_quant_node = weight_input_node; - } - return SUCCESS; -} - -Status QuantUtilImpl::GetInputDescByAnchor(const ge::InDataAnchorPtr &in_data_anchor, ge::GeTensorDesc &tensor_desc) { - auto owner_node = in_data_anchor->GetOwnerNode(); - const size_t anchor_idx = static_cast(in_data_anchor->GetIdx()); - if (owner_node == nullptr || anchor_idx >= owner_node->GetOpDesc()->GetAllInputsSize()) { - return FAILED; - } - - tensor_desc = owner_node->GetOpDesc()->GetInputDesc(static_cast(anchor_idx)); - return SUCCESS; -} - -void QuantUtilImpl::SetAttrsForBiasOptimizerOp(ge::OpDescPtr &op_desc, const ge::NodePtr &cube_node, - const ge::NodePtr &ascend_weight_quant_node, - const WeightMode cube_type) { - bool quant_cin_cout_reverse = false; - if (ge::AttrUtils::GetBool(cube_node->GetOpDesc(), kQuantCinCoutReverse, quant_cin_cout_reverse)) { - (void) ge::AttrUtils::SetBool(op_desc, kQuantCinCoutReverse, quant_cin_cout_reverse); - } else { - (void) ge::AttrUtils::SetBool(op_desc, kQuantCinCoutReverse, false); - } - int64_t groups = 1; - (void) ge::AttrUtils::GetInt(cube_node->GetOpDesc(), "groups", groups); - (void) ge::AttrUtils::SetInt(op_desc, "groups", groups); - (void) ge::AttrUtils::SetBool(op_desc, "_is_come_from_const_op", true); - if (cube_type == WeightMode::RESERVED) { - (void) ge::AttrUtils::SetStr(op_desc, "cube_op_type", cube_node->GetType()); - } else { - if (cube_type == WeightMode::WEIGHTWITH2D) { - (void) ge::AttrUtils::SetStr(op_desc, "cube_op_type", "MatMulV2"); - } else { - (void) ge::AttrUtils::SetStr(op_desc, "cube_op_type", "Conv3D"); - } - } - std::string soc_version = "v100"; - GetIsaArchVersionStr(soc_version); - (void) ge::AttrUtils::SetStr(op_desc, "soc_version", soc_version); - int dst_type = ge::DT_INT8; - if (ascend_weight_quant_node != nullptr) { - (void) ge::AttrUtils::GetInt(ascend_weight_quant_node->GetOpDesc(), "dst_type", dst_type); - } - (void) ge::AttrUtils::SetInt(op_desc, "dst_type", dst_type); -} - -Status QuantUtilImpl::SetQuantScaleAndOffset(const ge::NodePtr &quant_node, const BiasOptimizeEdges ¶m, - ge::OpDescPtr &host_op_desc) { - if (quant_node != nullptr) { - // get scale and offset from quant node attr - float_t scale_a = 0.0F; - (void) ge::AttrUtils::GetFloat(quant_node->GetOpDesc(), "scale", scale_a); - (void) ge::AttrUtils::SetFloat(host_op_desc, "scale", scale_a); - - float_t offset = 0.0F; - (void) ge::AttrUtils::GetFloat(quant_node->GetOpDesc(), "offset", offset); - (void) ge::AttrUtils::SetFloat(host_op_desc, "offset", offset); - return SUCCESS; - } - if (param.quant_offset == nullptr || param.quant_scale == nullptr) { - GELOGE(ge::FAILED, , "Invalid param! Quant_offset anchor and quant_scale anchor should not be nullptr, " - "please check in detail."); - return FAILED; - } - ge::GeTensorDesc quant_offset_tensor; - if (GetInputDescByAnchor(param.quant_offset, quant_offset_tensor) != SUCCESS) { - return FAILED; - } - (void)host_op_desc->AddInputDesc("offset", quant_offset_tensor); - - ge::GeTensorDesc quant_scale_tensor; - if (GetInputDescByAnchor(param.quant_scale, quant_scale_tensor) != SUCCESS) { - return FAILED; - } - (void)host_op_desc->AddInputDesc("scale", quant_scale_tensor); - - return SUCCESS; -} - -// 改图,区分有没有quant_offset 和 quant_scale输入的场景 -Status QuantUtilImpl::LinkBiasOptimizeHostOp(const ge::NodePtr &quant_node, const ge::NodePtr &weight_const_node, - const BiasOptimizeEdges ¶m, ge::NodePtr &host_op_node) { - // input index bias:dequant_scale:weight:quant_offset:quant_scale - // bias need delete ori link - auto bias_peer_out_anchor = param.cube_bias->GetPeerOutAnchor(); - if (bias_peer_out_anchor == nullptr) { - return FAILED; - } - if (ge::GraphUtils::RemoveEdge(bias_peer_out_anchor, param.cube_bias) != ge::GRAPH_SUCCESS) { - GELOGE(ge::FAILED, "[GraphOpt][CreateHostOp][LinkHostOpEdge] Remove Edge between bias output " - "and cube input anchor failed."); - return FAILED; - } - - if (ge::GraphUtils::AddEdge(bias_peer_out_anchor, host_op_node->GetInDataAnchor(0)) != SUCCESS) { - GELOGE(ge::FAILED, "add edge between bias peer out anchor and host op failed"); - return FAILED; - } - - // weight - auto weight_out_anchor = weight_const_node->GetOutDataAnchor(0); - if (ge::GraphUtils::AddEdge(weight_out_anchor, host_op_node->GetInDataAnchor(1)) != SUCCESS) { - GELOGE(ge::FAILED, "add edge between weight peer out anchor and host op failed"); - return FAILED; - } - - // dequant_scale - if (param.deq_scale == nullptr) { - GELOGD("No deq_scale node found."); - } else { - auto deq_scale_peer_out_anchor = param.deq_scale->GetPeerOutAnchor(); - if (deq_scale_peer_out_anchor == nullptr) { - return FAILED; - } - if (ge::GraphUtils::AddEdge(deq_scale_peer_out_anchor, - host_op_node->GetInDataAnchor(static_cast(IDX_2))) != SUCCESS) { - GELOGE(ge::FAILED, "add edge between dequant_scale peer out anchor and host op failed"); - return FAILED; - } - } - - // quant_offset - if (quant_node == nullptr && param.quant_offset != nullptr) { - auto quant_offset_peer_out_anchor = param.quant_offset->GetPeerOutAnchor(); - if (quant_offset_peer_out_anchor == nullptr || - host_op_node->GetAllInDataAnchorsSize() < BIAS_OPT_OP_SCALE_IDX + 1UL) { - GELOGE(ge::FAILED, "Quant_offset_peer_out_anchor is nullptr or invalid anchor size %u", - host_op_node->GetAllInDataAnchorsSize()); - return FAILED; - } - if (ge::GraphUtils::AddEdge(quant_offset_peer_out_anchor, - host_op_node->GetInDataAnchor(static_cast(BIAS_OPT_OP_OFFSET_IDX))) - != SUCCESS) { - GELOGE(ge::FAILED, "add edge between quant_offset peer out anchor and host op failed"); - return FAILED; - } - } - // quant_scale - if (quant_node == nullptr && param.quant_scale != nullptr) { - auto quant_scale_peer_out_anchor = param.quant_scale->GetPeerOutAnchor(); - if (quant_scale_peer_out_anchor == nullptr) { - return FAILED; - } - if (ge::GraphUtils::AddEdge(quant_scale_peer_out_anchor, - host_op_node->GetInDataAnchor(static_cast(BIAS_OPT_OP_SCALE_IDX))) - != SUCCESS) { - GELOGE(ge::FAILED, "add edge between quant_scale peer out anchor and host op failed"); - return FAILED; - } - } - - // output - auto quant_host_cpu_output_anchor = host_op_node->GetOutDataAnchor(0); - if (ge::GraphUtils::AddEdge(quant_host_cpu_output_anchor, param.cube_bias) != SUCCESS) { - GELOGE(ge::FAILED, "add edge between quant_scale peer out anchor and host op failed"); - return FAILED; - } - - return SUCCESS; -} - -Status QuantUtilImpl::CreateBiasOptimizeHostCpuOp(std::shared_ptr &graph, - const ge::NodePtr &quant_node, const BiasOptimizeEdges ¶m, - const ge::NodePtr &weight_const_node, - WeightMode cube_type, - std::vector &fusion_nodes) { - // create host cpu op desc - std::ostringstream oss; - oss << "QuantBiasOptimization" << GetHostCpuAtomicId(); - (void)oss.flush(); - - auto bias_optimizer_op_desc = std::make_shared(oss.str().c_str(), "QuantBiasOptimization"); - if (bias_optimizer_op_desc == nullptr || weight_const_node == nullptr) { - GELOGE(ge::FAILED, "bias_optimizer_op_desc or weight_const_node is nullptr"); - return FAILED; - } - // construct bias and deq_scale tensor - ge::GeTensorDesc bias_tensor_desc; - if (GetInputDescByAnchor(param.cube_bias, bias_tensor_desc) != SUCCESS) { - return FAILED; - } - (void)bias_optimizer_op_desc->AddInputDesc(BIAS_OPTIMIZATION_BIAS, bias_tensor_desc); - // get weight tensor desc - ge::GeTensorDesc weight_tensor_desc = weight_const_node->GetOpDesc()->GetOutputDesc(0); - (void)bias_optimizer_op_desc->AddInputDesc(BIAS_OPTIMIZATION_FILTER, weight_tensor_desc); - - if (param.deq_scale != nullptr) { - ge::GeTensorDesc deq_scale_tensor_desc; - if (GetInputDescByAnchor(param.deq_scale, deq_scale_tensor_desc) != SUCCESS) { - return FAILED; - } - (void)bias_optimizer_op_desc->AddInputDesc(BIAS_OPTIMIZATION_DEQUANT_SCALE, deq_scale_tensor_desc); - } - - // get offset and scale form quant attr or input - if (SetQuantScaleAndOffset(quant_node, param, bias_optimizer_op_desc) != SUCCESS) { - GELOGE(ge::FAILED, "error SetQuantScaleAndOffset"); - return FAILED; - } - // modify host cpu op input desc - for (uint32_t i = 0; i < bias_optimizer_op_desc->GetAllInputsSize(); ++i) { - auto tensor_desc_ptr = bias_optimizer_op_desc->MutableInputDesc(i); - if (tensor_desc_ptr == nullptr) { - GELOGI("The tensor_desc_ptr is null."); - continue; - } - /* Keep the original data type and format the same as the current ones */ - tensor_desc_ptr->SetOriginDataType(tensor_desc_ptr->GetDataType()); - tensor_desc_ptr->SetOriginFormat(static_cast(ge::GetPrimaryFormat(tensor_desc_ptr->GetFormat()))); - tensor_desc_ptr->SetOriginShape(tensor_desc_ptr->GetShape()); - } - // add output desc - (void)bias_optimizer_op_desc->AddOutputDesc(BIAS_OPTIMIZATION_OUTPUT, bias_tensor_desc); - FE_PARAM_CHECK_NOTNULL(bias_optimizer_op_desc->MutableOutputDesc(0)); - bias_optimizer_op_desc->MutableOutputDesc(0)->SetOriginFormat( - static_cast(ge::GetPrimaryFormat(bias_tensor_desc.GetFormat()))); - bias_optimizer_op_desc->MutableOutputDesc(0)->SetOriginDataType(bias_tensor_desc.GetDataType()); - bias_optimizer_op_desc->MutableOutputDesc(0)->SetOriginShape(bias_tensor_desc.GetShape()); - - SetAttrsForBiasOptimizerOp(bias_optimizer_op_desc, param.cube_weight->GetOwnerNode(), weight_const_node, cube_type); - // create host op node - auto bias_optimizer_node = graph->AddNode(bias_optimizer_op_desc); - if (bias_optimizer_node == nullptr) { - return FAILED; - } - fusion_nodes.emplace_back(bias_optimizer_node); - - // modify host op edge - if (SUCCESS != LinkBiasOptimizeHostOp(quant_node, weight_const_node, param, bias_optimizer_node)) { - return FAILED; - } - - return SUCCESS; -} -Status QuantUtilImpl::BiasOptimizeByEdgeCommon(const ge::NodePtr &quant_node, BiasOptimizeEdges ¶m, - std::vector &fusion_nodes, WeightMode cube_type) { - if (!param.isValid()) { - GELOGE(ge::FAILED, "param check failed, input param is invalid"); - return FAILED; - } - FE_PARAM_CHECK_NOTNULL(param.cube_weight); - auto cube_node = param.cube_weight->GetOwnerNode(); - auto graph = cube_node->GetOwnerComputeGraph(); - if (NeedBiasInput(param.cube_bias)) { - GELOGD("start creating bias node for node %s", cube_node->GetNamePtr()); - std::vector bias_shape; - if (GetCoValueByWeight(cube_node, static_cast(param.cube_weight->GetIdx()), bias_shape) != SUCCESS) { - GELOGE(ge::FAILED, "[GraphOpt][AvgPolQntPcsFus][BiasOpti] Get node[%s] co value.", cube_node->GetName().c_str()); - return FAILED; - } - GELOGD("start create bias input for node %s", cube_node->GetNamePtr()); - if (CreateBiasInput(graph, cube_node, bias_shape, static_cast(param.cube_bias->GetIdx())) != SUCCESS) { - GELOGE(ge::FAILED, "[GraphOpt][CreateBiasInput][BiasOpti] Get node[%s] co value.", cube_node->GetName().c_str()); - return FAILED; - } - } - ge::NodePtr weight_const_node = nullptr; - ge::NodePtr ascend_weight_quant_node = nullptr; - if (GetWeightConstNode(param.cube_weight, weight_const_node, ascend_weight_quant_node) != SUCCESS) { - GELOGE(ge::FAILED, - "[OriginGraphOptimize][GraphFusion][QuantOpt] Get weight const from node[%s, %s] failed, please check graph", - cube_node->GetName().c_str(), cube_node->GetType().c_str()); - } - const Status ret = CreateBiasOptimizeHostCpuOp(graph, quant_node, param, weight_const_node, cube_type, fusion_nodes); - if (ret != SUCCESS) { - GELOGE(ge::FAILED, "[OriginGraphOptimize][GraphFusion][QuantOpt] Create host op failed."); - return FAILED; - } - return SUCCESS; -} - -Status QuantUtilImpl::BiasOptimizeByEdge(const QuantParam &quant_param, BiasOptimizeEdges ¶m, - std::vector &fusion_nodes, - WeightMode cube_type) { - if (param.cube_weight == nullptr) { - GELOGE(ge::FAILED, "[OriginGraphOptimize][GraphFusion][QuantOpt] Cube weight null."); - return FAILED; - } - auto cube_node = param.cube_weight->GetOwnerNode(); - auto graph = cube_node->GetOwnerComputeGraph(); - ge::OpDescPtr opdesc = std::make_shared(PARAM_QUANT_NODE, kAscendQuant); - (void)ge::AttrUtils::SetFloat(opdesc, ATTR_OFFSET, quant_param.quant_offset); - (void)ge::AttrUtils::SetFloat(opdesc, ATTR_SCALE, quant_param.quant_scale); - auto quant_node = graph->AddNode(opdesc); - const Status ret = BiasOptimizeByEdgeCommon(quant_node, param, fusion_nodes, cube_type); - (void)ge::GraphUtils::IsolateNode(quant_node, {}); - (void)ge::GraphUtils::RemoveNodeWithoutRelink(graph, quant_node); - return ret; -} - -Status QuantUtilImpl::BiasOptimizeByEdge(ge::NodePtr &quant_node, BiasOptimizeEdges ¶m, - std::vector &fusion_nodes) { - if (quant_node == nullptr) { - GELOGE(ge::FAILED, - "[OriginGraphOptimize][GraphFusion][QuantOpt] Invalid parameter, quant node should not be nullptr!"); - return FAILED; - } - return BiasOptimizeByEdgeCommon(quant_node, param, fusion_nodes, WeightMode::RESERVED); -} - -Status QuantUtilImpl::BiasOptimizeByEdge(BiasOptimizeEdges ¶m, std::vector &fusion_nodes) { - ge::NodePtr quant_node = nullptr; - return BiasOptimizeByEdgeCommon(quant_node, param, fusion_nodes, WeightMode::RESERVED); -} - -bool QuantUtilImpl::IsNanoSoc() { - std::string soc_version_str; - if (ge::GetThreadLocalContext().GetOption(SOC_VERSION, soc_version_str) != ge::GRAPH_SUCCESS) { - GELOGD("getting option %s did not succeed.", SOC_VERSION.c_str()); - return false; - } - GELOGD("Option %s is %s.", SOC_VERSION.c_str(), soc_version_str.c_str()); - return kNanoSocVersionSet.count(soc_version_str) != 0; -} - -ge::OpDescPtr QuantUtilImpl::CreateDeqScaleHostOp(const std::string &op_name, const std::string &op_type, - const ge::OpDescPtr &cube_node, size_t index) { - GELOGD("Begin to create SetQuantScale Host op[%s, %s].", op_name.c_str(), op_type.c_str()); - // create set quant scale host op - ge::OpDescPtr op_desc = nullptr; - GE_MAKE_SHARED(op_desc = std::make_shared(op_name, op_type), return nullptr); - - ge::ConstGeTensorDescPtr prenode_inputdesc = cube_node->GetInputDescPtr(static_cast(index)); - ge::ConstGeTensorDescPtr prenode_outputdesc = cube_node->GetOutputDescPtr(0); - if (prenode_inputdesc == nullptr || prenode_outputdesc == nullptr) { - return nullptr; - } - - ge::GeTensorDesc out_tensor_desc = prenode_inputdesc->Clone(); - out_tensor_desc.SetDataType(ge::DT_UINT64); - out_tensor_desc.SetOriginDataType(ge::DT_UINT64); - (void)op_desc->AddInputDesc(X1INPUTNAME, *(prenode_inputdesc)); - (void)op_desc->AddOutputDesc("y", out_tensor_desc); - - // set attr - float offset = 0.0F; - if (ge::AttrUtils::GetFloat(cube_node, ATTR_OFFSET, offset)) { - (void) ge::AttrUtils::SetFloat(op_desc, ATTR_OFFSET, offset); - GELOGD("Set offset value [%f] for op[%s]", offset, op_name.c_str()); - } - (void) ge::AttrUtils::SetInt(op_desc, ATTR_OUTDTYPE, static_cast(prenode_outputdesc->GetDataType())); - (void) ge::AttrUtils::SetBool(op_desc, kAttrSingleOp, ATTRTRUE); - GELOGD("Host op [%s, %s] has been created.", op_desc->GetName().c_str(), op_desc->GetType().c_str()); - return op_desc; -} - -Status QuantUtilImpl::InsertFixpipeDequantScaleConvert(ge::InDataAnchorPtr deq_scale, - std::vector &fusion_nodes) { - if (deq_scale == nullptr) { - return FAILED; - } - // get post_fuze_node - ge::NodePtr post_fuze_node = deq_scale->GetOwnerNode(); - FE_PARAM_CHECK_NOTNULL(post_fuze_node); - // get pre_fuze_node - FE_PARAM_CHECK_NOTNULL(deq_scale->GetPeerOutAnchor()); - ge::NodePtr pre_fuze_node = deq_scale->GetPeerOutAnchor()->GetOwnerNode(); - FE_PARAM_CHECK_NOTNULL(pre_fuze_node); - // get graph - auto compute_graph = post_fuze_node->GetOwnerComputeGraph(); - // get deq_scale input index - auto deq_scale_index = deq_scale->GetIdx(); - // set op_desc - std::string new_op_type = QuantUtilImpl::IsNanoSoc() ? "SetQuantScale" : "SetM1Dequant"; - std::string new_op_name = post_fuze_node->GetName() + new_op_type + std::to_string(GetHostCpuAtomicId()); - ge::OpDescPtr new_op_desc = - QuantUtilImpl::CreateDeqScaleHostOp(new_op_name, new_op_type, post_fuze_node->GetOpDesc(), - static_cast(deq_scale_index)); - FE_PARAM_CHECK_NOTNULL(new_op_desc); - ge::NodePtr new_node = compute_graph->AddNode(new_op_desc); - FE_PARAM_CHECK_NOTNULL(new_node); - // edit edge - fusion_nodes.push_back(new_node); - (void)ge::GraphUtils::RemoveEdge(deq_scale, deq_scale->GetPeerOutAnchor()); - (void)ge::GraphUtils::AddEdge(pre_fuze_node->GetOutDataAnchor(0), new_node->GetInDataAnchor(0)); - (void)ge::GraphUtils::AddEdge(new_node->GetOutDataAnchor(0), post_fuze_node->GetInDataAnchor(deq_scale_index)); - return SUCCESS; -} - -Status QuantUtilImpl::InsertFixpipeDequantScaleConvert(ge::InDataAnchorPtr &deq_scale, - ge::InDataAnchorPtr &quant_offset, - std::vector &fusion_nodes) { - GELOGD("Begin to do InsertFixpipeDequantScaleConvert"); - FE_PARAM_CHECK_NOTNULL(deq_scale); - FE_PARAM_CHECK_NOTNULL(quant_offset); - - std::string new_op_type = QuantUtilImpl::IsNanoSoc() ? "SetQuantScale" : "SetM1Dequant"; - ge::NodePtr cube_node = deq_scale->GetOwnerNode(); - FE_PARAM_CHECK_NOTNULL(cube_node); - std::string new_op_name = cube_node->GetName() + "_" + new_op_type + "_" + std::to_string(GetHostCpuAtomicId()); - ge::OpDescPtr cube_op_desc = cube_node->GetOpDesc(); - FE_PARAM_CHECK_NOTNULL(cube_op_desc); - const uint8_t *quant_offset_data_tmp = GetDataByAnchor(quant_offset); - FE_PARAM_CHECK_NOTNULL(quant_offset_data_tmp); - const float *quant_offset_data = reinterpret_cast(quant_offset_data_tmp); - (void) ge::AttrUtils::SetFloat(cube_op_desc, ATTR_OFFSET, *quant_offset_data); - - ge::OpDescPtr new_op_desc = - QuantUtilImpl::CreateDeqScaleHostOp(new_op_name, new_op_type, cube_op_desc, static_cast(deq_scale->GetIdx())); - FE_PARAM_CHECK_NOTNULL(new_op_desc); - auto compute_graph = cube_node->GetOwnerComputeGraph(); - ge::NodePtr new_node = compute_graph->AddNode(new_op_desc); - FE_PARAM_CHECK_NOTNULL(new_node); - - ge::OutDataAnchorPtr deq_scale_peer_anchor = deq_scale->GetPeerOutAnchor(); - FE_PARAM_CHECK_NOTNULL(deq_scale_peer_anchor); - ge::NodePtr deq_scale_peer_node = deq_scale_peer_anchor->GetOwnerNode(); - FE_PARAM_CHECK_NOTNULL(deq_scale_peer_node); - if (ge::GraphUtils::RemoveEdge(deq_scale_peer_anchor, deq_scale) != ge::GRAPH_SUCCESS) { - GELOGE(ge::FAILED, "Node[%s, %s]: fail to remove edge.", cube_node->GetName().c_str(), - cube_node->GetType().c_str()); - return FAILED; - } - if (ge::GraphUtils::AddEdge(new_node->GetOutDataAnchor(0), deq_scale) != ge::GRAPH_SUCCESS) { - GELOGE(ge::FAILED, "Node[%s, %s]: fail to add edge.", cube_node->GetName().c_str(), - cube_node->GetType().c_str()); - return FAILED; - } - if (ge::GraphUtils::AddEdge(deq_scale_peer_node->GetOutDataAnchor(0), new_node->GetInDataAnchor(0)) != - ge::GRAPH_SUCCESS) { - GELOGE(ge::FAILED, "Node[%s, %s]: fail to add edge.", cube_node->GetName().c_str(), - cube_node->GetType().c_str()); - return FAILED; - } - fusion_nodes.push_back(new_node); - return SUCCESS; -} - -bool QuantUtilImpl::IsSupportFixpipe() { - PlatFormInfos platform_infos; - OptionalInfos optional_infos; - if (PlatformInfoManager::Instance().GetPlatformInfoWithOutSocVersion(platform_infos, optional_infos) != SUCCESS) { - GELOGW("Failed to get platform infos."); - return false; - } - bool is_support_fixpipe = false; - std::string is_support_fixpipe_str; - if (platform_infos.GetPlatformRes(kAICoreSpec, kSupportFixpipe, is_support_fixpipe_str)) { - is_support_fixpipe = static_cast(std::atoi(is_support_fixpipe_str.c_str())); - } - return is_support_fixpipe; -} - -ge::GeTensorPtr QuantUtilImpl::GetTensorByAnchor(ge::InDataAnchorPtr &anchor) { - auto peer_anchor = anchor->GetPeerOutAnchor(); - if (peer_anchor == nullptr) { - GELOGW("Peer_anchor is nullptr."); - return nullptr; - } - ge::NodePtr peer_node = peer_anchor->GetOwnerNode(); - std::vector weights = ge::OpDescUtils::MutableWeights(peer_node); - if (weights.empty()) { - GELOGW("Node [%s, %s]: failed to retrieve weights", peer_node->GetName().c_str(), peer_node->GetType().c_str()); - return nullptr; - } - return weights[0]; -} - -const uint8_t *QuantUtilImpl::GetDataByAnchor(ge::InDataAnchorPtr &anchor) { - ge::GeTensorPtr weight_tensor = GetTensorByAnchor(anchor); - if (weight_tensor == nullptr) { - GELOGW("Weight_tensor is nullptr."); - return nullptr; - } - return weight_tensor->GetData().GetData(); -} - -uint64_t QuantUtilImpl::TransM1Scale(const float &src_value) { - uint32_t value = 0; - if (memcpy_s(&value, sizeof(uint32_t), &src_value, sizeof(float)) != 0) { - GELOGW("Failed to execute memcpy_s."); - return 0; - } - uint64_t tmp_data = static_cast(value) & 0x000000000FFFFE000; - return tmp_data; -} - -uint64_t QuantUtilImpl::SetM1OfQuant(const float &scale, const float &offset, const ge::DataType &data_type) { - uint64_t uint64_offset = static_cast(static_cast(std::nearbyint(offset))); - uint64_t uint64_data = 0; - if (data_type == ge::DT_UINT16) { - uint64_data = TransM1Scale(scale) + ((uint64_offset >> kBitShift3ByteSize) & 0xFFUL) + - (((uint64_offset & 0x1FFUL) << kBitShift37) & 0x3FE000000000UL); - } else if (data_type == ge::DT_UINT8) { - uint64_data = TransM1Scale(scale) + (((uint64_offset & 0x1FFUL) << kBitShift37) & 0x3FE000000000UL); - } else if (data_type == ge::DT_INT4) { - uint64_data = TransM1Scale(scale) + (((uint64_offset & 0x1FUL) << kBitShift37) & 0x3E000000000UL) + - 0x400000000000UL; - } else if (data_type == ge::DT_INT16) { - uint64_data = TransM1Scale(scale) + ((uint64_offset >> kBitShift3ByteSize) & 0xFFUL) + - (((uint64_offset & 0x1FFUL) << kBitShift37) & 0x3FE000000000UL) + 0x400000000000UL; - } else if (data_type == ge::DT_INT8) { - uint64_data = TransM1Scale(scale) + (((uint64_offset & 0x1FFUL) << kBitShift37) & 0x3FE000000000UL) + - 0x400000000000UL; - } else { - // do nothing - } - return uint64_data; -} - -Status QuantUtilImpl::UpdateScalarInput(const float *quant_scale_data, const float *quant_offset_data, - const ge::DataType &data_type, ge::GeTensorDescPtr &scale_tensor_desc, - ge::GeTensorPtr &quant_op_tensor) { - scale_tensor_desc->SetDataType(ge::DT_UINT64); - scale_tensor_desc->SetOriginDataType(ge::DT_UINT64); - scale_tensor_desc->SetFormat(ge::FORMAT_ND); - scale_tensor_desc->SetOriginFormat(ge::FORMAT_ND); - int64_t dim_count = 1; - ge::GeShape scale_shape = scale_tensor_desc->GetShape(); - if (!scale_shape.IsScalar()) { - for (auto &dim : scale_shape.GetDims()) { - dim_count *= dim; - } - } - std::unique_ptr scale_data(new (std::nothrow) uint64_t[dim_count]()); - for (size_t i = 0; i < static_cast(dim_count); i++) { - if (quant_offset_data != nullptr) { - scale_data[i] = SetM1OfQuant(*quant_scale_data, *quant_offset_data, data_type); - quant_scale_data++; - quant_offset_data++; - } else { - scale_data[i] = SetM1OfQuant(*quant_scale_data, 0.0F, data_type); - quant_scale_data++; - } - } - size_t total_data_size = static_cast(dim_count) * sizeof(uint64_t); - if (quant_op_tensor->SetData(reinterpret_cast(scale_data.get()), total_data_size) != SUCCESS) { - GELOGE(ge::FAILED, "Fail to set data of quant_op_tensor."); - return FAILED; - } - return SUCCESS; -} - -Status QuantUtilImpl::CreateQuantOp(ge::NodePtr &cube_node, ge::InDataAnchorPtr &quant_scale, - ge::GeTensorDescPtr scale_tensor_desc, ge::GeTensorPtr quant_op_tensor, - std::vector &fusion_nodes) { - std::string quant_op_name = cube_node->GetName() + "_quant_op_" + std::to_string(GetQuantOpAtomicId()); - ge::OpDescPtr quant_op_desc = nullptr; - GE_MAKE_SHARED(quant_op_desc = std::make_shared(quant_op_name, kConstOpType), return FAILED); - if (quant_op_desc->AddOutputDesc(*scale_tensor_desc) != ge::GRAPH_SUCCESS) { - GELOGE(ge::FAILED, "Node[%s, %s]: fail to add output desc.", cube_node->GetName().c_str(), - cube_node->GetType().c_str()); - return FAILED; - } - (void)quant_op_tensor->SetTensorDesc(*scale_tensor_desc); - if (!ge::AttrUtils::SetTensor(quant_op_desc, ge::ATTR_NAME_WEIGHTS, quant_op_tensor)) { - GELOGE(ge::FAILED, "Node[%s, %s]: fail to set quant op tensor.", cube_node->GetName().c_str(), - cube_node->GetType().c_str()); - return FAILED; - } - auto compute_graph = cube_node->GetOwnerComputeGraph(); - FE_PARAM_CHECK_NOTNULL(compute_graph); - ge::NodePtr quant_op = compute_graph->AddNode(quant_op_desc); - FE_PARAM_CHECK_NOTNULL(quant_op); - - ge::OutDataAnchorPtr quant_scale_peer_anchor = quant_scale->GetPeerOutAnchor(); - FE_PARAM_CHECK_NOTNULL(quant_scale_peer_anchor); - if (ge::GraphUtils::RemoveEdge(quant_scale_peer_anchor, quant_scale) != ge::GRAPH_SUCCESS) { - GELOGE(ge::FAILED, "Node[%s, %s]: fail to remove edge.", cube_node->GetName().c_str(), - cube_node->GetType().c_str()); - return FAILED; - } - if (ge::GraphUtils::AddEdge(quant_op->GetOutDataAnchor(0), quant_scale) != ge::GRAPH_SUCCESS) { - GELOGE(ge::FAILED, "Node[%s, %s]: fail to add edge.", cube_node->GetName().c_str(), - cube_node->GetType().c_str()); - return FAILED; - } - ge::NodePtr quant_scale_peer_node = quant_scale_peer_anchor->GetOwnerNode(); - FE_PARAM_CHECK_NOTNULL(quant_scale_peer_node); - (void)compute_graph->RemoveNode(quant_scale_peer_node); - fusion_nodes.emplace_back(quant_op); - return SUCCESS; -} - -Status QuantUtilImpl::InsertFixpipeQuantScaleConvert(ge::InDataAnchorPtr &quant_scale, - ge::InDataAnchorPtr &quant_offset, - std::vector &fusion_nodes) { - GELOGD("Begin to do InsertFixpipeQuantScaleConvert"); - FE_PARAM_CHECK_NOTNULL(quant_scale); - const uint8_t *quant_scale_data_tmp = GetDataByAnchor(quant_scale); - FE_PARAM_CHECK_NOTNULL(quant_scale_data_tmp); - const float *quant_scale_data = reinterpret_cast(quant_scale_data_tmp); - - const float *quant_offset_data = nullptr; - if (quant_offset != nullptr) { - const uint8_t *quant_offset_data_tmp = GetDataByAnchor(quant_offset); - if (quant_offset_data_tmp != nullptr) { - quant_offset_data = reinterpret_cast(quant_offset_data_tmp); - } - } - ge::NodePtr cube_node = quant_scale->GetOwnerNode(); - FE_PARAM_CHECK_NOTNULL(cube_node); - auto cube_out_desc = cube_node->GetOpDesc()->MutableOutputDesc(0); - FE_PARAM_CHECK_NOTNULL(cube_out_desc); - ge::DataType cube_out_data_type = cube_out_desc->GetDataType(); - GELOGD("Node[%s, %s]: cube_out_data_type is %zu.", cube_node->GetName().c_str(), cube_node->GetType().c_str(), - static_cast(cube_out_data_type)); - ge::GeTensorDescPtr scale_tensor_desc = cube_node->GetOpDesc()->MutableInputDesc( - static_cast(quant_scale->GetIdx())); - bool has_desc = true; - if (scale_tensor_desc == nullptr) { - ge::GeShape shape{}; - ge::GeTensorDesc fake_desc(shape, ge::FORMAT_ND, ge::DT_UINT64); - GE_MAKE_SHARED(scale_tensor_desc = std::make_shared(fake_desc), return FAILED); - has_desc = false; - } - ge::GeTensorPtr quant_op_tensor = nullptr; - GE_MAKE_SHARED(quant_op_tensor = std::make_shared(), return FAILED); - if (UpdateScalarInput(quant_scale_data, quant_offset_data, cube_out_data_type, scale_tensor_desc, - quant_op_tensor) != SUCCESS) { - GELOGE(ge::FAILED, "Node[%s, %s]: fail to update scalar input.", cube_node->GetName().c_str(), - cube_node->GetType().c_str()); - return FAILED; - } - if (!has_desc) { - (void)cube_node->GetOpDesc()->UpdateInputDesc(static_cast(quant_scale->GetIdx()), *scale_tensor_desc); - } - if (CreateQuantOp(cube_node, quant_scale, scale_tensor_desc, quant_op_tensor, fusion_nodes) != SUCCESS) { - GELOGE(ge::FAILED, "Node[%s, %s]: fail to create quant op.", cube_node->GetName().c_str(), - cube_node->GetType().c_str()); - return FAILED; - } - return SUCCESS; -} - -Status QuantUtilImpl::InsertQuantScaleConvert(ge::InDataAnchorPtr &quant_scale, - ge::InDataAnchorPtr &quant_offset, - std::vector &fusion_nodes) { - if (IsSupportFixpipe()) { - return InsertFixpipeQuantScaleConvert(quant_scale, quant_offset, fusion_nodes); - } - return SUCCESS; -} - -Status QuantUtilImpl::SetAttrForRequantHostCpuOp(ge::OpDescPtr &req_host_op_desc, - const ge::GeTensorPtr &req_scale_tensor, - ge::InDataAnchorPtr &quant_offset, ge::InDataAnchorPtr &cube_bias, - int32_t &req_scale_size) { - ge::GeTensorPtr cube_bias_tensor = GetTensorByAnchor(cube_bias); - FE_PARAM_CHECK_NOTNULL(cube_bias_tensor); - const int32_t *bias_data = reinterpret_cast(cube_bias_tensor->GetData().GetData()); - int32_t bias_size = static_cast(cube_bias_tensor->GetData().size() / sizeof(int32_t)); - vector bias_value; - for (auto index = 0; index < bias_size; index++) { - bias_value.push_back(bias_data[index]); - } - const int32_t scale_size = static_cast(req_scale_tensor->GetData().size() / sizeof(uint64_t)); - req_scale_size = bias_size == 0 ? scale_size : bias_size; - GELOGD("Req_scale_size is %d.", req_scale_size); - - const uint8_t *req_scale_data_tmp = req_scale_tensor->GetData().GetData(); - FE_PARAM_CHECK_NOTNULL(req_scale_data_tmp); - const float *req_scale_data = reinterpret_cast(req_scale_data_tmp); - std::vector quant_scale_vec(1, *req_scale_data); - (void)quant_scale_vec.insert(quant_scale_vec.end(), static_cast(static_cast(req_scale_size) - 1), quant_scale_vec[0]); - const uint8_t *quant_offset_data_tmp = GetDataByAnchor(quant_offset); - FE_PARAM_CHECK_NOTNULL(quant_offset_data_tmp); - const int64_t *quant_offset_data = reinterpret_cast(quant_offset_data_tmp); - std::vector quant_offset_vec(1, *quant_offset_data); - (void)quant_offset_vec.insert(quant_offset_vec.end(), static_cast(static_cast(req_scale_size) - 1), quant_offset_vec[0]); - - ge::GeShape req_scale_shape = req_scale_tensor->GetTensorDesc().GetShape(); - if (req_scale_shape.GetDimNum() != 1) { - GELOGE(ge::FAILED, "Req_scale_shape %zu is invalid.", req_scale_shape.GetDimNum()); - return FAILED; - } - (void)ge::AttrUtils::SetStr(req_host_op_desc, kAttrQuantMode, kQuantHighPrecision); - int64_t req_co = req_scale_shape.GetDim(0); - const uint64_t *req_scale_data_int = reinterpret_cast(req_scale_data_tmp); - for (int64_t i = 0; i < req_co; i++) { - const int8_t req_n = static_cast(GET_REQUANT_N(req_scale_data_int[i])); - GELOGD("Qeq_scale N value[%ld] is %d", i, req_n); - if (req_n != 0) { - (void)ge::AttrUtils::SetStr(req_host_op_desc, kAttrQuantMode, kQuantHighPerformance); - break; - } - } - (void)ge::AttrUtils::SetBool(req_host_op_desc, kAttrReluFlag, false); - (void)ge::AttrUtils::SetInt(req_host_op_desc, kAttrBiasSize, static_cast(bias_size)); - (void)ge::AttrUtils::SetListInt(req_host_op_desc, kAttrBiasValue, bias_value); - (void)ge::AttrUtils::SetFloat(req_host_op_desc, kAttrQuantScale, *req_scale_data); - (void)ge::AttrUtils::SetListFloat(req_host_op_desc, kAttrQuantScaleVec, quant_scale_vec); - (void)ge::AttrUtils::SetListInt(req_host_op_desc, kAttrQuantOffsetVec, quant_offset_vec); - return SUCCESS; -} - -Status QuantUtilImpl::CreateRequantHostCpuOp(ge::InDataAnchorPtr &req_scale, ge::InDataAnchorPtr &cube_bias, - ge::InDataAnchorPtr &quant_offset, - std::vector &fusion_nodes) { - ge::NodePtr cube_node = req_scale->GetOwnerNode(); - FE_PARAM_CHECK_NOTNULL(cube_node); - ge::GeTensorDescPtr req_scale_tensor_desc = - cube_node->GetOpDesc()->MutableInputDesc(static_cast(req_scale->GetIdx())); - bool has_desc = true; - if (req_scale_tensor_desc == nullptr) { - ge::GeShape shape{}; - ge::GeTensorDesc fake_desc(shape, ge::FORMAT_ND, ge::DT_UINT64); - GE_MAKE_SHARED(req_scale_tensor_desc = std::make_shared(fake_desc), return FAILED); - has_desc = false; - } - - std::string req_host_op_name = cube_node->GetName() + "_" + kRequantHostCpuOpType + "_" + - std::to_string(GetHostCpuAtomicId()); - ge::OpDescPtr req_host_op_desc = nullptr; - GE_MAKE_SHARED(req_host_op_desc = - std::make_shared(req_host_op_name, kRequantHostCpuOpType), return FAILED); - if (req_host_op_desc->AddInputDesc(kRequantInputName, *req_scale_tensor_desc) != ge::GRAPH_SUCCESS) { - GELOGE(ge::FAILED, "Node[%s, %s]: fail to add input desc.", cube_node->GetName().c_str(), - cube_node->GetType().c_str()); - return FAILED; - } - req_scale_tensor_desc->SetDataType(ge::DT_UINT64); - req_scale_tensor_desc->SetOriginDataType(ge::DT_UINT64); - if (req_host_op_desc->AddOutputDesc(kRequantOutputName, *req_scale_tensor_desc) != ge::GRAPH_SUCCESS) { - GELOGE(ge::FAILED, "Node[%s, %s]: fail to add output desc.", cube_node->GetName().c_str(), - cube_node->GetType().c_str()); - return FAILED; - } - if (!has_desc) { - (void)cube_node->GetOpDesc()->UpdateInputDesc(static_cast(req_scale->GetIdx()), *req_scale_tensor_desc); - } - ge::GeTensorPtr req_host_op_tensor = nullptr; - GE_MAKE_SHARED(req_host_op_tensor = std::make_shared(), return FAILED); - (void)req_host_op_tensor->SetTensorDesc(*req_scale_tensor_desc); - if (!ge::AttrUtils::SetTensor(req_host_op_desc, ge::ATTR_NAME_WEIGHTS, req_host_op_tensor)) { - GELOGE(ge::FAILED, "Node[%s, %s]: fail to set requant host op tensor.", cube_node->GetName().c_str(), - cube_node->GetType().c_str()); - return FAILED; - } - int req_scale_size = 0; - ge::GeTensorPtr req_scale_tensor = GetTensorByAnchor(req_scale); - FE_PARAM_CHECK_NOTNULL(req_scale_tensor); - if (SetAttrForRequantHostCpuOp(req_host_op_desc, req_scale_tensor, quant_offset, cube_bias, req_scale_size) != - SUCCESS) { - return FAILED; - } - FE_PARAM_CHECK_NOTNULL(req_host_op_desc->MutableOutputDesc(0)); - req_host_op_desc->MutableOutputDesc(0)->SetShape(ge::GeShape({req_scale_size})); - req_host_op_desc->MutableOutputDesc(0)->SetOriginShape(ge::GeShape({req_scale_size})); - ge::OpDescPtr cube_op_desc = cube_node->GetOpDesc(); - FE_PARAM_CHECK_NOTNULL(cube_op_desc); - FE_PARAM_CHECK_NOTNULL(cube_op_desc->MutableInputDesc(static_cast(req_scale->GetIdx()))); - cube_op_desc->MutableInputDesc(static_cast(req_scale->GetIdx()))->SetShape(ge::GeShape({req_scale_size})); - cube_op_desc->MutableInputDesc( - static_cast(req_scale->GetIdx()))->SetOriginShape(ge::GeShape({req_scale_size})); - - auto compute_graph = cube_node->GetOwnerComputeGraph(); - FE_PARAM_CHECK_NOTNULL(compute_graph); - ge::NodePtr req_op = compute_graph->AddNode(req_host_op_desc); - FE_PARAM_CHECK_NOTNULL(req_op); - if (ge::GraphUtils::AddEdge(req_scale->GetPeerOutAnchor(), req_op->GetInDataAnchor(0)) != ge::GRAPH_SUCCESS) { - GELOGE(ge::FAILED, "Node[%s, %s]: fail to add edge.", cube_node->GetName().c_str(), - cube_node->GetType().c_str()); - return FAILED; - } - if (ge::GraphUtils::RemoveEdge(req_scale->GetPeerOutAnchor(), req_scale) != ge::GRAPH_SUCCESS) { - GELOGE(ge::FAILED, "Node[%s, %s]: fail to remove edge.", cube_node->GetName().c_str(), - cube_node->GetType().c_str()); - return FAILED; - } - if (ge::GraphUtils::AddEdge(req_op->GetOutDataAnchor(0), req_scale) != ge::GRAPH_SUCCESS) { - GELOGE(ge::FAILED, "Node[%s, %s]: fail to add edge.", cube_node->GetName().c_str(), - cube_node->GetType().c_str()); - return FAILED; - } - fusion_nodes.emplace_back(req_op); - return SUCCESS; -} - -Status QuantUtilImpl::InsertNotFixpipeRequantScaleConvert(ge::InDataAnchorPtr &req_scale, - ge::InDataAnchorPtr &quant_offset, - ge::InDataAnchorPtr &cube_bias, - std::vector &fusion_nodes) { - GELOGD("Begin to do InsertNotFixpipeRequantScaleConvert"); - FE_PARAM_CHECK_NOTNULL(req_scale); - FE_PARAM_CHECK_NOTNULL(quant_offset); - FE_PARAM_CHECK_NOTNULL(cube_bias); - if (CreateRequantHostCpuOp(req_scale, quant_offset, cube_bias, fusion_nodes) != SUCCESS) { - GELOGE(ge::FAILED, "Fail to create RequantHostCpuOp."); - return FAILED; - } - return SUCCESS; -} - -Status QuantUtilImpl::InsertRequantScaleConvert(ge::InDataAnchorPtr &req_scale, - ge::InDataAnchorPtr &quant_offset, - ge::InDataAnchorPtr &cube_bias, - std::vector &fusion_nodes) { - if (IsSupportFixpipe()) { - return InsertFixpipeDequantScaleConvert(req_scale, quant_offset, fusion_nodes); - } - return InsertNotFixpipeRequantScaleConvert(req_scale, quant_offset, cube_bias, fusion_nodes); -} -} // namespace fe diff --git a/register/graph_optimizer/graph_fusion/fusion_quant_util_impl.h b/register/graph_optimizer/graph_fusion/fusion_quant_util_impl.h deleted file mode 100644 index 2de96c0fe07ea708270c1f143ba778abb91cf4f9..0000000000000000000000000000000000000000 --- a/register/graph_optimizer/graph_fusion/fusion_quant_util_impl.h +++ /dev/null @@ -1,187 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_FUSION_QUANT_UTIL_IMPL_H_ -#define INC_FUSION_QUANT_UTIL_IMPL_H_ -#include "register/graph_optimizer/graph_fusion/fusion_quant_util.h" -#include "graph/node.h" -#include "common/ge_common/ge_inner_error_codes.h" -#include "register/graph_optimizer/graph_optimize_register_error_codes.h" -#include "graph/ge_tensor.h" -#include - -namespace fe { - -#define GET_REQUANT_N(req_scale_date) (((req_scale_date) & 0x000000FF00000000UL) >> 32) - -#define FE_PARAM_CHECK_NOTNULL(val) \ - do { \ - if ((val) == nullptr) { \ - GELOGE(ge::FAILED, "Parameter[%s] must not be null.", #val); \ - return fe::PARAM_INVALID; \ - } \ - } while (0) - -namespace { -const int32_t NCHW_DIM_N = 0; -const int32_t NCHW_DIM_C = 1; -const int32_t NCHW_DIM_H = 2; -const int32_t NCHW_DIM_W = 3; - -const int32_t NC1HWC0_DIM_N = 0; -const int32_t NC1HWC0_DIM_C1 = 1; -const int32_t NC1HWC0_DIM_C0 = 4; -const int32_t NC1HWC0_DIM_H = 2; -const int32_t NC1HWC0_DIM_W = 3; - -const int32_t NDC1HWC0_DIM_N = 0; -const int32_t NDC1HWC0_DIM_D = 1; -const int32_t NDC1HWC0_DIM_C1 = 2; -const int32_t NDC1HWC0_DIM_C0 = 5; -const int32_t NDC1HWC0_DIM_H = 3; -const int32_t NDC1HWC0_DIM_W = 4; - -const int32_t C1HWNCoC0_DIM_C1 = 0; -const int32_t C1HWNCoC0_DIM_H = 1; -const int32_t C1HWNCoC0_DIM_W = 2; -const int32_t C1HWNCoC0_DIM_N = 3; -const int32_t C1HWNCoC0_DIM_Co = 4; -const int32_t C1HWNCoC0_DIM_C0 = 5; - -const int32_t C1DHWNCoC0_DIM_C1 = 0; -const int32_t C1DHWNCoC0_DIM_D = 1; -const int32_t C1DHWNCoC0_DIM_H = 2; -const int32_t C1DHWNCoC0_DIM_W = 3; - -const int32_t NHWC_DIM_N = 0; -const int32_t NHWC_DIM_H = 1; -const int32_t NHWC_DIM_W = 2; -const int32_t NHWC_DIM_C = 3; - -const int32_t HWCN_DIM_H = 0; -const int32_t HWCN_DIM_W = 1; -const int32_t HWCN_DIM_C = 2; -const int32_t HWCN_DIM_N = 3; - -const int32_t CHWN_DIM_C = 0; -const int32_t CHWN_DIM_H = 1; -const int32_t CHWN_DIM_W = 2; -const int32_t CHWN_DIM_N = 3; - -const int32_t NDHWC_DIM_N = 0; -const int32_t NDHWC_DIM_D = 1; -const int32_t NDHWC_DIM_H = 2; -const int32_t NDHWC_DIM_W = 3; -const int32_t NDHWC_DIM_C = 4; -const uint32_t DIMENSION_NUM_FIVE = 5; - -const int32_t NCDHW_DIM_N = 0; -const int32_t NCDHW_DIM_C = 1; -const int32_t NCDHW_DIM_D = 2; -const int32_t NCDHW_DIM_H = 3; -const int32_t NCDHW_DIM_W = 4; - -const int32_t DHWCN_DIM_D = 0; -const int32_t DHWCN_DIM_H = 1; -const int32_t DHWCN_DIM_W = 2; -const int32_t DHWCN_DIM_C = 3; -const int32_t DHWCN_DIM_N = 4; - -const int32_t DHWNC_DIM_D = 0; -const int32_t DHWNC_DIM_H = 1; -const int32_t DHWNC_DIM_W = 2; -const int32_t DHWNC_DIM_N = 3; -const int32_t DHWNC_DIM_C = 4; -} - -using TensorPtr = std::shared_ptr; - -class QuantUtilImpl { - public: - static Status BiasOptimizeByEdge(BiasOptimizeEdges ¶m, std::vector &fusion_nodes); - static Status BiasOptimizeByEdge(ge::NodePtr &quant_node, BiasOptimizeEdges ¶m, - std::vector &fusion_nodes); - static Status BiasOptimizeByEdge(const QuantParam &quant_param, BiasOptimizeEdges ¶m, - std::vector &fusion_nodes, - WeightMode cube_type = WeightMode::RESERVED); - static Status InsertFixpipeDequantScaleConvert(ge::InDataAnchorPtr deq_scale, std::vector &fusion_nodes); - static Status InsertFixpipeDequantScaleConvert(ge::InDataAnchorPtr &deq_scale, ge::InDataAnchorPtr &quant_offset, - std::vector &fusion_nodes); - static Status InsertQuantScaleConvert(ge::InDataAnchorPtr &quant_scale, ge::InDataAnchorPtr &quant_offset, - std::vector &fusion_nodes); - static Status InsertRequantScaleConvert(ge::InDataAnchorPtr &req_scale, ge::InDataAnchorPtr &quant_offset, - ge::InDataAnchorPtr &cube_bias, std::vector &fusion_nodes); - - private: - static Status BiasOptimizeByEdgeCommon(const ge::NodePtr &quant_node, BiasOptimizeEdges ¶m, - std::vector &fusion_nodes, WeightMode cube_type); - static bool NeedBiasInput(const ge::InDataAnchorPtr &bias); - static Status GetCoValueByWeight(ge::NodePtr &cube_node, size_t idx, std::vector &bias_shape); - static Status PadShapeTo4Dim(const ge::Format &filter_format, const std::vector &filter_dims, - std::vector &filter_dims4_d); - static int32_t GetAxisIndexByFormat(const ge::Format &format, const string &axis); - static TensorPtr CreateBiasTensor(const std::vector &shape); - static ge::NodePtr CreateBiasNode(std::shared_ptr &graph, const ge::GeTensorPtr &bias_ptr, - const std::string &cube_node_name); - static Status UpdateBiasOutputDesc(const ge::NodePtr &cube_node, const ge::GeShape &shape, ge::Format format, - const uint32_t index); - static Status UpdateCubeInputDesc(const ge::NodePtr &cube_node, const ge::GeShape &shape, const ge::Format &format, - const uint32_t index); - static Status CreateBiasInput(std::shared_ptr &graph, ge::NodePtr &cube_node, - const std::vector &shape, const size_t &bias_idx); - static Status GetWeightConstNode(const ge::InDataAnchorPtr &weight, ge::NodePtr &weight_const_node, - ge::NodePtr &ascend_weight_quant_node); - static Status GetInputDescByAnchor(const ge::InDataAnchorPtr &in_data_anchor, ge::GeTensorDesc &tensor_desc); - static void SetAttrsForBiasOptimizerOp(ge::OpDescPtr &op_desc, const ge::NodePtr &cube_node, - const ge::NodePtr &ascend_weight_quant_node, const WeightMode cube_type); - static Status SetQuantScaleAndOffset(const ge::NodePtr &quant_node, const BiasOptimizeEdges ¶m, - ge::OpDescPtr &host_op_desc); - static Status LinkBiasOptimizeHostOp(const ge::NodePtr &quant_node, const ge::NodePtr &weight_const_node, - const BiasOptimizeEdges ¶m, ge::NodePtr &host_op_node); - static Status CreateBiasOptimizeHostCpuOp(std::shared_ptr &graph, const ge::NodePtr &quant_node, - const BiasOptimizeEdges ¶m, const ge::NodePtr &weight_const_node, - WeightMode cube_type, std::vector &fusion_nodes); - static ge::OpDescPtr CreateDeqScaleHostOp(const std::string &op_name, const std::string &op_type, - const ge::OpDescPtr &cube_node, size_t index); - static bool IsNanoSoc(); - - static bool IsSupportFixpipe(); - - static ge::GeTensorPtr GetTensorByAnchor(ge::InDataAnchorPtr &anchor); - - static const uint8_t *GetDataByAnchor(ge::InDataAnchorPtr &anchor); - - static uint64_t TransM1Scale(const float &src_value); - - static uint64_t SetM1OfQuant(const float &scale, const float &offset, const ge::DataType &data_type); - - static Status UpdateScalarInput(const float *quant_scale_data, const float *quant_offset_data, - const ge::DataType &data_type, ge::GeTensorDescPtr &scale_tensor_desc, - ge::GeTensorPtr &quant_op_tensor); - - static Status CreateQuantOp(ge::NodePtr &cube_node, ge::InDataAnchorPtr &quant_scale, - ge::GeTensorDescPtr scale_tensor_desc, ge::GeTensorPtr quant_op_tensor, - std::vector &fusion_nodes); - - static Status InsertFixpipeQuantScaleConvert(ge::InDataAnchorPtr &quant_scale, ge::InDataAnchorPtr &quant_offset, - std::vector &fusion_nodes); - - static Status SetAttrForRequantHostCpuOp(ge::OpDescPtr &req_host_op_desc, const ge::GeTensorPtr &req_scale_tensor, - ge::InDataAnchorPtr &quant_offset, ge::InDataAnchorPtr &cube_bias, - int32_t &req_scale_size); - - static Status CreateRequantHostCpuOp(ge::InDataAnchorPtr &req_scale, ge::InDataAnchorPtr &cube_bias, - ge::InDataAnchorPtr &quant_offset, std::vector &fusion_nodes); - - static Status InsertNotFixpipeRequantScaleConvert(ge::InDataAnchorPtr &req_scale, ge::InDataAnchorPtr &quant_offset, - ge::InDataAnchorPtr &cube_bias, - std::vector &fusion_nodes); -}; -} // namespace fe -#endif diff --git a/register/graph_optimizer/graph_fusion/graph_fusion_pass_base.cc b/register/graph_optimizer/graph_fusion/graph_fusion_pass_base.cc deleted file mode 100644 index 0e5551ce47ca36b33a9f5444b581b9275e9aa3df..0000000000000000000000000000000000000000 --- a/register/graph_optimizer/graph_fusion/graph_fusion_pass_base.cc +++ /dev/null @@ -1,203 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/graph_optimizer/graph_fusion/graph_fusion_pass_base.h" -#include -#include -#include -#include "graph/debug/ge_log.h" -#include "register/graph_optimizer/fusion_common/fusion_statistic_recorder.h" -#include "register/graph_optimizer/fusion_common/graph_pass_util.h" -#include "register/graph_optimizer/graph_fusion/pattern_fusion_base_pass_impl.h" - -namespace fe { -GraphFusionPassBase::GraphFusionPassBase() { - pattern_fusion_base_pass_impl_ptr_ = std::make_shared(); -} - -GraphFusionPassBase::~GraphFusionPassBase() {} - -/** - * @ingroup fe - * @brief execute pass - */ -Status GraphFusionPassBase::Run(ge::ComputeGraph &graph) { - bool is_patterns_ok = true; - // build Pattern - std::vector patterns; - std::string invalid_patterns; - pattern_fusion_base_pass_impl_ptr_->GetPatterns(patterns); - if (patterns.empty()) { - patterns = DefinePatterns(); - for (FusionPattern *pattern : patterns) { - if (pattern != nullptr) { - const bool ok = pattern->Build(); - if (!ok) { - GELOGW("[RunFusionPass][Check] Pattern: %s build failed", pattern->GetName().c_str()); - invalid_patterns += pattern->GetName() + ","; - } - pattern->Dump(); - is_patterns_ok = is_patterns_ok && ok; - } - } - - pattern_fusion_base_pass_impl_ptr_->SetPatterns(patterns); - } - if (!is_patterns_ok) { - GELOGE(FAILED, "[Check][Patterns]Pattern:%s invalid.", invalid_patterns.c_str()); - return FAILED; - } - - NodeMapInfoPtr node_map_info = nullptr; - if (GraphPassUtil::GetOpTypeMapToGraph(node_map_info, graph) == SUCCESS) { - if (node_map_info->run_count == std::numeric_limits::max()) { - GELOGE(ge::FAILED, "Run count is overflow."); - return FAILED; - } - node_map_info->run_count++; - } - // do matching and fusion for each pattern - bool final_changed = false; - for (const FusionPattern * const pattern : patterns) { - if (pattern != nullptr) { - bool changed = false; - const Status ret = RunOnePattern(graph, *pattern, changed); - if (ret != SUCCESS) { - GELOGW("[RunFusionPass][Check] Run pattern %s failed, the graph was not modified by it.", pattern->GetName().c_str()); - return ret; - } - final_changed = final_changed || changed; - } - } - return final_changed ? SUCCESS : NOT_CHANGED; -} - -/** - * @ingroup fe - * @brief do matching and fusion in graph based on the pattern - */ -Status GraphFusionPassBase::RunOnePattern(ge::ComputeGraph &graph, const FusionPattern &pattern, bool &changed) { - changed = false; - Mappings mappings; - int32_t effect_times = 0; - const uint32_t graph_id = graph.GetGraphID(); - FusionInfo fusion_info(graph.GetSessionID(), to_string(graph_id), GetName(), static_cast(mappings.size()), - effect_times); - // match all patterns in graph, and save them to mappings - if (!MatchAll(graph, pattern, mappings)) { - GELOGD("GraphFusionPass[%s]: pattern=%s, matched_times=%zu, effected_times=%d.", GetName().c_str(), - pattern.GetName().c_str(), mappings.size(), effect_times); - return SUCCESS; - } - - GELOGD("This graph has been matched with pattern[%s]. The mappings are as follows.", pattern.GetName().c_str()); - - // print the results of matching - pattern_fusion_base_pass_impl_ptr_->DumpMappings(pattern, mappings); - NodeMapInfoPtr node_map_info = nullptr; - // get nodes by type from node - (void)GraphPassUtil::GetOpTypeMapToGraph(node_map_info, graph); - // do fusion for each mapping - for (Mapping &mapping : mappings) { - std::vector fus_nodes; - ge::NodePtr first_node = nullptr; - for (auto &item : mapping) { - if (!item.second.empty()) { - first_node = item.second[0]; - break; - } - } - - const Status status = Fusion(graph, mapping, fus_nodes); - if ((status != SUCCESS) && (status != NOT_CHANGED)) { - GELOGE(status, "[Fuse][Graph]Fail with pattern[%s].", pattern.GetName().c_str()); - return status; - } - - if (status == SUCCESS) { - effect_times++; - if (!fus_nodes.empty()) { - // add fusednode to node map info - for (ge::NodePtr &node : fus_nodes) { - GraphPassUtil::AddNodeFromOpTypeMap(node_map_info, node); - } - } - } - changed = changed || (status == SUCCESS); - } - - // get match times and effect times - FusionStatisticRecorder &fusion_statistic_inst = FusionStatisticRecorder::Instance(); - fusion_info.SetMatchTimes(static_cast(mappings.size())); - fusion_info.SetEffectTimes(effect_times); - fusion_statistic_inst.UpdateGraphFusionMatchTimes(fusion_info); - fusion_statistic_inst.UpdateGraphFusionEffectTimes(fusion_info); - GELOGD("GraphId[%d], GraphFusionPass[%s]: pattern=%s, matched_times=%d, effected_times=%d.", graph_id, - GetName().c_str(), pattern.GetName().c_str(), static_cast(mappings.size()), effect_times); - return SUCCESS; -} - -/** - * @ingroup fe - * @brief match all nodes in graph according to pattern - * match nodes in graph according to pattern, the algorithm is shown as following: - * 1. get output node from pattern - * 2. Search for candidate nodes in Graph (network Graph generated after parsing) according to Op Type and - * (optional), and add the candidate node to the list of candidates - * 3. For each Node in the candidate list, check whether the type and the number - * of precursors are consistent with the description of corresponding Op - * in pattern. If they are consistent, add the precursor Node to the - * candidate list, and add "PatternOp-GraphNode" to the mapping; otherwise, return an empty mapping - * 4. repeat step 3 until all the Ops in pattern are matched - * 5. if all the Ops in pattern are matched successfully, return the mapping of PatternOp and GraphNode - */ -bool GraphFusionPassBase::MatchAll(const ge::ComputeGraph &graph, const FusionPattern &pattern, - Mappings &mappings) const { - std::vector matched_output_nodes; - - // find all the output nodes of pattern in the graph based on Op type - std::shared_ptr output_op_desc = pattern.GetOutput(); - if (output_op_desc == nullptr) { - return false; - } - - if (!pattern_fusion_base_pass_impl_ptr_->GetMatchOutputNodes(graph, pattern, matched_output_nodes)) { - return false; - } - - // begin matching from every output node - for (ge::NodePtr &output_node : matched_output_nodes) { - Mapping mapping; - if (pattern_fusion_base_pass_impl_ptr_->MatchFromOutput(output_node, output_op_desc, mapping)) { - mappings.push_back(mapping); - } - } - // if matching is successful, return true; otherwise false - return !mappings.empty(); -} - -/** - * @ingroup fe - * @brief get an op from mapping according to ID - */ -ge::NodePtr GraphFusionPassBase::GetNodeFromMapping(const std::string &id, const Mapping &mapping) { - for (auto &item : mapping) { - const std::shared_ptr op_desc = item.first; - if ((op_desc != nullptr) && (op_desc->id == id)) { - if (item.second.empty()) { - return nullptr; - } else { - return item.second[0]; - } - } - } - return nullptr; -} - -} // namespace fe diff --git a/register/graph_optimizer/graph_fusion/graph_pass_util.cc b/register/graph_optimizer/graph_fusion/graph_pass_util.cc deleted file mode 100644 index 7e578b382077ef2198c5d67467ebabf893291a00..0000000000000000000000000000000000000000 --- a/register/graph_optimizer/graph_fusion/graph_pass_util.cc +++ /dev/null @@ -1,730 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/graph_optimizer/fusion_common/graph_pass_util.h" -#include -#include "graph/debug/ge_log.h" -#include "register/graph_optimizer/fusion_common/fusion_turbo_utils.h" -#include "mmpa/mmpa_api.h" - -#define REGISTER_MAKE_SHARED(exec_expr0, exec_expr1) \ - do { \ - try { \ - exec_expr0; \ - } catch (...) { \ - GELOGW("Make shared failed"); \ - exec_expr1; \ - } \ - } while (0) - -namespace fe { -namespace { -const std::string kPassName = "pass_name"; -const std::string kBackWard = "_backward"; -const std::string kRecompute = "_recompute"; -const std::string kOptimizer = "_optimizer"; -const std::string kOpDebugCompile = "_op_debug_compile"; -const std::array kBoolAttrNeedInherit = {kRecompute, kOptimizer}; -// Indicates custom impl mode for specified op -const std::string kOpCustomImplModeEnum = "_op_custom_impl_mode_enum"; -// Indicates impl mode for specified op -const std::string kOpImplModeEnum = "_op_impl_mode_enum"; -// impl_mode priority from high to low -const std::map kOpImplIntToPriorityMap = { - {0x40, 1}, // enable_hi_float_32_execution - {0x20, 2}, // enable_float_32_execution - {0x4, 3}, // high_precision - {0x2, 4}, // high_performance - {0x10, 5}, // support_of_bound_index - {0x8, 6}, // super_performance -}; -const uint32_t DefaultGroupId = 0xFFFFFFFF; -const std::set HeavyOpList = {"QuantBatchMatmul", "BatchMatmulFixpipe", "WeightQuantBatchmatmul", "MatMul", - "BatchMatMul", "BatchMatMulV2", "MatMulV2", "MatMulV3", "MatMulV2Compress"}; -const std::unordered_set kGeLocalOpType = {"Expanddims", "Reshape", "ReFormat", "Squeeze", - "Unsqueeze", "SqueezeV2", "UnsqueezeV2", "SqueezeV3", - "UnsqueezeV3", "Size", "Shape", "ShapeN", "Rank"}; -} - -void GraphPassUtil::SetOutputDescAttr(const uint32_t &origin_index, const uint32_t &fusion_index, - const ge::NodePtr &origin_node, const ge::NodePtr &fusion_node) { - if (origin_node == nullptr || fusion_node == nullptr) { - return; - } - - if (fusion_node->GetOpDesc() == nullptr) { - return; - } - - const ge::OpDescPtr origin_op_desc = origin_node->GetOpDesc(); - if (origin_op_desc == nullptr) { - return; - } - - auto origin_node_output_desc = origin_node->GetOpDesc()->GetOutputDescPtr(origin_index); - if (origin_node_output_desc == nullptr) { - return; - } - - const ge::GeTensorDescPtr fusion_node_output_desc = fusion_node->GetOpDesc()->MutableOutputDesc(fusion_index); - if (fusion_node_output_desc == nullptr) { - return; - } - - SetOutputDescAttr(origin_node_output_desc, static_cast(origin_index), origin_op_desc, - fusion_node_output_desc); -} - -void GraphPassUtil::SetOutputDescAttr(ge::ConstGeTensorDescPtr &origin_tensor_desc, const int64_t origin_index, - const ge::OpDescPtr &origin_op_desc, - const ge::GeTensorDescPtr &target_tensor_desc) { - if (origin_tensor_desc == nullptr || target_tensor_desc == nullptr || origin_op_desc == nullptr) { - return; - } - - // set origin name - std::string original_name; - if (!ge::AttrUtils::GetStr(origin_tensor_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_NAME, original_name) || - original_name.empty()) { - std::vector original_names; - if (ge::AttrUtils::GetListStr(origin_op_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names) && - !original_names.empty()) { - original_name = original_names[0]; - } else { - original_name = origin_op_desc->GetName(); - } - } - (void)ge::AttrUtils::SetStr(target_tensor_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_NAME, original_name); - - // set origin output index - int64_t origin_output_index = 0; - if (ge::AttrUtils::GetInt(origin_tensor_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX, origin_output_index)) { - (void)ge::AttrUtils::SetInt(target_tensor_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX, origin_output_index); - } else { - (void)ge::AttrUtils::SetInt(target_tensor_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX, origin_index); - } - - // set origin output data type - const ge::DataType origin_data_type = GetDataDumpOriginDataType(origin_tensor_desc); - if (origin_data_type != ge::DT_UNDEFINED) { - SetDataDumpOriginDataType(origin_data_type, target_tensor_desc); - } else { - SetDataDumpOriginDataType(origin_tensor_desc->GetOriginDataType(), target_tensor_desc); - } - - // set origin output format - const ge::Format origin_format = GetDataDumpOriginFormat(origin_tensor_desc); - if (origin_format != ge::FORMAT_RESERVED) { - SetDataDumpOriginFormat(origin_format, target_tensor_desc); - } else { - SetDataDumpOriginFormat(origin_tensor_desc->GetOriginFormat(), target_tensor_desc); - } -} - -/** get origin format for data dump - * - * @param tensor_desc,usually is output_desc - * - * @return format of this tensor_desc - */ -ge::Format GraphPassUtil::GetDataDumpOriginFormat(const ge::GeTensorDescPtr &tensor_desc) { - std::string origin_format_str; - if (!ge::AttrUtils::GetStr(tensor_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT, origin_format_str)) { - // Can not get the certificate and it's not set,return directly - return ge::FORMAT_RESERVED; - } - if (origin_format_str == "RESERVED") { - return ge::FORMAT_RESERVED; - } - return ge::TypeUtils::SerialStringToFormat(origin_format_str); -} - -ge::Format GraphPassUtil::GetDataDumpOriginFormat(ge::ConstGeTensorDescPtr &tensor_desc) { - std::string origin_format_str; - if (!ge::AttrUtils::GetStr(tensor_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT, origin_format_str)) { - // Can not get the certificate and it's not set,return directly - return ge::FORMAT_RESERVED; - } - if (origin_format_str == "RESERVED") { - return ge::FORMAT_RESERVED; - } - return ge::TypeUtils::SerialStringToFormat(origin_format_str); -} - -/** set origin format for data dump - * - * @param origin format - * - * @param tensor_desc,usually is output_desc - */ -void GraphPassUtil::SetDataDumpOriginFormat(const ge::Format &origin_format, - const ge::GeTensorDescPtr &tensor_desc) { - std::string origin_format_str = "RESERVED"; - if (origin_format != ge::FORMAT_RESERVED) { - origin_format_str = ge::TypeUtils::FormatToSerialString(origin_format); - } - (void)ge::AttrUtils::SetStr(tensor_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT, origin_format_str); -} - -/** set origin datatype for data dump - * - * @param origin datatype - * - * @param tensor_desc,usually is output_desc - */ -void GraphPassUtil::SetDataDumpOriginDataType(const ge::DataType origin_data_type, - const ge::GeTensorDescPtr &tensor_desc) { - std::string origin_data_type_str = "RESERVED"; - if (origin_data_type != ge::DT_UNDEFINED) { - origin_data_type_str = ge::TypeUtils::DataTypeToSerialString(origin_data_type); - } - (void)ge::AttrUtils::SetStr(tensor_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE, origin_data_type_str); -} - -/** get origin datatype for data dump - * - * @param tensor_desc,usually is output_desc - * - * @return format of this tensor_desc - */ -ge::DataType GraphPassUtil::GetDataDumpOriginDataType(const ge::GeTensorDescPtr &tensor_desc) { - std::string origin_data_type_str; - if (!ge::AttrUtils::GetStr(tensor_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE, origin_data_type_str)) { - return ge::DT_UNDEFINED; - } - if (origin_data_type_str == "RESERVED") { - return ge::DT_UNDEFINED; - } - return ge::TypeUtils::SerialStringToDataType(origin_data_type_str); -} - -ge::DataType GraphPassUtil::GetDataDumpOriginDataType(ge::ConstGeTensorDescPtr &tensor_desc) { - std::string origin_data_type_str; - if (!ge::AttrUtils::GetStr(tensor_desc, ge::ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE, origin_data_type_str)) { - return ge::DT_UNDEFINED; - } - if (origin_data_type_str == "RESERVED") { - return ge::DT_UNDEFINED; - } - return ge::TypeUtils::SerialStringToDataType(origin_data_type_str); -} - -void GraphPassUtil::AddNodeFromOpTypeMap(const NodeMapInfoPtr &node_map_info, const ge::NodePtr &node_ptr) { - if ((node_map_info == nullptr) || (node_ptr == nullptr)) { - return; - } - const NodeTypeMapPtr node_type_map = node_map_info->node_type_map; - std::string real_op_type = ge::NodeUtils::GetNodeType(*node_ptr); - const auto iter = node_type_map->find(real_op_type); - if (iter != node_type_map->end()) { - iter->second[node_ptr->GetName()] = node_ptr; - } else { - (void)node_type_map->emplace(std::make_pair(real_op_type, - std::map{{node_ptr->GetName(), node_ptr}})); - } -} - -Status GraphPassUtil::GetOpTypeMapToGraph(NodeMapInfoPtr &node_map_info, const ge::ComputeGraph &graph) { - node_map_info = graph.TryGetExtAttr("NodeMapInfo", node_map_info); - if (node_map_info == nullptr) { - return FAILED; - } - return SUCCESS; -} - -void GraphPassUtil::RecordPassnameAndOriginalAttrs(const std::vector &original_nodes, - std::vector &fus_nodes, const string &pass_name, - const OriginOpAttrsVec &origin_op_attrs) { - for (auto &node : fus_nodes) { - (void)StoreAndUpdataOriginFusionPassName(node->GetOpDesc(), original_nodes, pass_name); - RecordOriginalOpAttrs(original_nodes, node->GetOpDesc(), pass_name, origin_op_attrs); - } -} - -Status GraphPassUtil::StoreAndUpdataOriginFusionPassName(const ge::OpDescPtr &op_desc, - const std::vector &original_nodes, - const std::string &pass_name) { - std::vector pass_names; - std::vector pass_names_tmp; - if (op_desc == nullptr) { - return FAILED; - } - for (const ge::NodePtr &original_node : original_nodes) { - if ((original_node == nullptr)) { - return FAILED; - } - const ge::OpDescPtr origin_op_desc_ptr = original_node->GetOpDesc(); - if (!ge::AttrUtils::GetListStr(origin_op_desc_ptr, kPassName, pass_names_tmp) || pass_names_tmp.empty()) { - continue; - } - (void)pass_names.insert(pass_names.cend(), pass_names_tmp.cbegin(), pass_names_tmp.cend()); - } - pass_names.push_back(pass_name); - if (!ge::AttrUtils::SetListStr(op_desc, kPassName, pass_names)) { - return FAILED; - } - return SUCCESS; -} - -void GraphPassUtil::GetBackWardAttr(const std::vector &original_nodes, - bool &backward, BackWardInheritMode inherit_mode) { - if (inherit_mode == BackWardInheritMode::kInheritTrue) { - backward = true; - return; - } - - if (inherit_mode != BackWardInheritMode::kDoNotInherit) { - for (const auto &origin_node : original_nodes) { - (void) ge::AttrUtils::GetBool(origin_node->GetOpDesc(), kBackWard, backward); - if (!backward) { - continue; - } - - if (inherit_mode != BackWardInheritMode::kFusedNode) { - break; - } - - bool has_in_node_backward = false; - for (const auto &in_node : origin_node->GetInNodes()) { - (void) ge::AttrUtils::GetBool(in_node->GetOpDesc(), kBackWard, has_in_node_backward); - if (has_in_node_backward) { - return; - } - } - - if (!has_in_node_backward) { - backward = false; - } - } - } -} - -void GraphPassUtil::InheritGraphRelatedAttr(const std::vector &original_nodes, - const std::vector &fusion_nodes, - BackWardInheritMode inherit_mode) { - vector bool_attrs(kBoolAttrNeedInherit.size(), false); - size_t i = 0; - for (const auto &attr : kBoolAttrNeedInherit) { - for (const auto &origin_node : original_nodes) { - bool value = false; - (void)ge::AttrUtils::GetBool(origin_node->GetOpDesc(), attr, value); - if (value) { - bool_attrs[i] = value; - break; - } - } - ++i; - } - - bool backward = false; - GetBackWardAttr(original_nodes, backward, inherit_mode); - - for (const auto &fusion_node : fusion_nodes) { - const ge::OpDescPtr fusion_op = fusion_node->GetOpDesc(); - if (backward && !ge::AttrUtils::HasAttr(fusion_op, kBackWard)) { - (void) ge::AttrUtils::SetBool(fusion_op, kBackWard, backward); - } - - if (bool_attrs.size() != kBoolAttrNeedInherit.size()) { - GELOGW("[Fusion][InheritAttr]Integer attribute size %zu is incorrect, should be %zu.", - bool_attrs.size(), kBoolAttrNeedInherit.size()); - return; - } - - i = 0; - for (const auto &attr : kBoolAttrNeedInherit) { - if (bool_attrs[i] != 0 && !ge::AttrUtils::HasAttr(fusion_op, attr)) { - (void) ge::AttrUtils::SetBool(fusion_op, attr, bool_attrs[i]); - } - ++i; - } - } -} - -void GraphPassUtil::GetOpCustomImplModeFromOriNode(const std::vector &original_nodes, - std::set &op_impl_mode_priority_set, - std::map &origin_node_impl_mode_map) { - for (const auto &origin_node : original_nodes) { - int64_t tmp_op_impl_mode = 0; - (void)ge::AttrUtils::GetInt(origin_node->GetOpDesc(), kOpCustomImplModeEnum, tmp_op_impl_mode); - if (tmp_op_impl_mode == 0) { - continue; - } - GELOGD("Node [%s, %s] has _op_custom_impl_mode_enum set to 0x%llx.", origin_node->GetName().c_str(), - origin_node->GetType().c_str(), tmp_op_impl_mode); - auto iter = kOpImplIntToPriorityMap.find(tmp_op_impl_mode); - if (iter != kOpImplIntToPriorityMap.end()) { - GELOGD("Node[%s, %s] has impl_mode priority %zu.", origin_node->GetName().c_str(), - origin_node->GetType().c_str(), iter->second); - (void)op_impl_mode_priority_set.emplace(iter->second); - origin_node_impl_mode_map[origin_node->GetName()] = tmp_op_impl_mode; - } - } -} - -void GraphPassUtil::SetOpCustomImplModeToFusNode(const ge::OpDescPtr &fusion_op, - const std::map &origin_node_impl_mode_map, - const std::set &op_impl_mode_priority_set) { - auto iter = origin_node_impl_mode_map.find(fusion_op->GetName()); - if (iter != origin_node_impl_mode_map.end()) { - (void)ge::AttrUtils::SetInt(fusion_op, kOpCustomImplModeEnum, iter->second); - GELOGD("Node[%s, %s] set _op_impl_mode_enum 0x%llx by op_name.", fusion_op->GetName().c_str(), - fusion_op->GetType().c_str(), iter->second); - } else { - if (op_impl_mode_priority_set.empty()) { - return; - } - for (auto iter1 = kOpImplIntToPriorityMap.begin(); iter1 != kOpImplIntToPriorityMap.end(); ++iter1) { - if (iter1->second == *op_impl_mode_priority_set.begin()) { - (void)ge::AttrUtils::SetInt(fusion_op, kOpCustomImplModeEnum, iter1->first); - GELOGD("Node[%s, %s] set _op_impl_mode_enum 0x%llx by priority.", fusion_op->GetName().c_str(), - fusion_op->GetType().c_str(), iter1->first); - } - } - } - return; -} - -void GraphPassUtil::GetOpCustomGroupIdFromOriginNodes(const std::vector &original_nodes, - uint32_t ¶llel_group_id) { - for (auto const &node : original_nodes) { - if (HeavyOpList.count(node->GetOpDesc()->GetType().c_str()) == 1) { - if (ge::AttrUtils::GetInt(node->GetOpDesc(), ge::ATTR_NAME_PARALLEL_GROUP_ID, parallel_group_id)) { - GELOGD("Node [%s, %s] has a _parallel_group_id of %u.", node->GetName().c_str(), - node->GetType().c_str(), parallel_group_id); - if (parallel_group_id == DefaultGroupId) { - continue; - } - return; - } - GELOGD("Node [%s, %s] is in the HeavyOpList, but lacks the _parallel_group_id attribute.", - node->GetName().c_str(), node->GetType().c_str()); - } else { - GELOGD("Node[%s, %s] not in HeavyOpList.", node->GetName().c_str(), node->GetType().c_str()); - } - } -} - -void GraphPassUtil::SetOpCustomGroupIdToFusNode(const ge::OpDescPtr &fusion_op, const uint32_t ¶llel_group_id) { - if (HeavyOpList.count(fusion_op->GetType().c_str()) == 1) { - uint32_t tmp_parallel_group_id = DefaultGroupId; - if (!ge::AttrUtils::GetInt(fusion_op, ge::ATTR_NAME_PARALLEL_GROUP_ID, tmp_parallel_group_id) || - tmp_parallel_group_id == DefaultGroupId) { - (void)ge::AttrUtils::SetInt(fusion_op, ge::ATTR_NAME_PARALLEL_GROUP_ID, static_cast(parallel_group_id)); - GELOGD("Fuse Node[%s, %s] in HeavyOpList, without _parallel_group_id; parallel_group_id is %u.", - fusion_op->GetName().c_str(), fusion_op->GetType().c_str(), parallel_group_id); - } else { - GELOGD("Fuse Node [%s, %s] is in the HeavyOpList, and this node already has a parallel_group_id of %u.", - fusion_op->GetNamePtr(), fusion_op->GetTypePtr(), tmp_parallel_group_id); - } - return; - } - GELOGD("Fuse Node[%s, %s] is not in the HeavyOpList; no need to set the parallel_group_id.", - fusion_op->GetName().c_str(), fusion_op->GetType().c_str()); -} - -void GraphPassUtil::InheritAttrFromOriNodes(const std::vector &original_nodes, - const std::vector &fusion_nodes, - BackWardInheritMode inherit_mode) { - std::string op_compile_strategy; - for (const auto &origin_node : original_nodes) { - if (ge::AttrUtils::GetStr(origin_node->GetOpDesc(), ge::ATTR_NAME_OP_COMPILE_STRATEGY, op_compile_strategy) && - !op_compile_strategy.empty()) { - break; - } - } - - int64_t keep_dtype = 0; - for (const auto &origin_node : original_nodes) { - if (ge::AttrUtils::GetInt(origin_node->GetOpDesc(), ge::ATTR_NAME_KEEP_DTYPE, keep_dtype) && - keep_dtype != 0) { - break; - } - } - - bool op_debug_compile = false; - for (const auto &origin_node : original_nodes) { - if (ge::AttrUtils::GetBool(origin_node->GetOpDesc(), kOpDebugCompile, op_debug_compile) && - op_debug_compile) { - break; - } - } - - std::set op_impl_mode_priority_set; - std::map origin_node_impl_mode_map; - GetOpCustomImplModeFromOriNode(original_nodes, op_impl_mode_priority_set, origin_node_impl_mode_map); - uint32_t parallel_group_id = DefaultGroupId; - GetOpCustomGroupIdFromOriginNodes(original_nodes, parallel_group_id); - - for (const auto &fusion_node : fusion_nodes) { - const ge::OpDescPtr fusion_op = fusion_node->GetOpDesc(); - if (!op_compile_strategy.empty() && !ge::AttrUtils::HasAttr(fusion_op, ge::ATTR_NAME_OP_COMPILE_STRATEGY)) { - (void) ge::AttrUtils::SetStr(fusion_op, ge::ATTR_NAME_OP_COMPILE_STRATEGY, op_compile_strategy); - } - - if (keep_dtype != 0 && !ge::AttrUtils::HasAttr(fusion_op, ge::ATTR_NAME_KEEP_DTYPE)) { - (void) ge::AttrUtils::SetInt(fusion_op, ge::ATTR_NAME_KEEP_DTYPE, keep_dtype); - } - - if (op_debug_compile && !ge::AttrUtils::HasAttr(fusion_op, kOpDebugCompile)) { - (void) ge::AttrUtils::SetBool(fusion_op, kOpDebugCompile, op_debug_compile); - } - - if (parallel_group_id != DefaultGroupId) { - SetOpCustomGroupIdToFusNode(fusion_op, parallel_group_id); - } - - SetOpCustomImplModeToFusNode(fusion_op, origin_node_impl_mode_map, op_impl_mode_priority_set); - } - InheritGraphRelatedAttr(original_nodes, fusion_nodes, inherit_mode); -} - -void GraphPassUtil::RecordOriginalOpAttrs(const std::vector &original_nodes, - const ge::OpDescPtr &op_desc, const string &pass_name, - const OriginOpAttrsVec &origin_op_attrs) { - const char *dump_ge_graph = nullptr; - MM_SYS_GET_ENV(MM_ENV_DUMP_GE_GRAPH, dump_ge_graph); - FUSION_TURBO_NOTNULL(dump_ge_graph,); - if (op_desc == nullptr) { - GELOGD("op_desc is nullptr"); - return; - } - // 1. get the original_names - GELOGD("Start to record op[%s] origin op attrs after pass[%s]", op_desc->GetName().c_str(), pass_name.c_str()); - std::shared_ptr origin_op_attrs_map = nullptr; - REGISTER_MAKE_SHARED(origin_op_attrs_map = std::make_shared(), return); - OriginOpAttrsVec origin_op_attrs_vec; - size_t index = 0; - for (const ge::NodePtr &original_node : original_nodes) { - if (original_node == nullptr) { - return; - } - const ge::OpDescPtr origin_op_desc_ptr = original_node->GetOpDesc(); - if (origin_op_desc_ptr == nullptr) { - return; - } - std::shared_ptr op_attrs_maps_tmp = nullptr; - REGISTER_MAKE_SHARED(op_attrs_maps_tmp = std::make_shared(), return); - op_attrs_maps_tmp = origin_op_desc_ptr->TryGetExtAttr(ge::ATTR_NAME_ORIGIN_OP_ATTRS_MAP, op_attrs_maps_tmp); - if ((op_attrs_maps_tmp != nullptr) && (!op_attrs_maps_tmp->empty())) { - size_t op_attrs_index = 0; - std::vector pass_names; - if ((!ge::AttrUtils::GetListStr(origin_op_desc_ptr, kPassName, pass_names)) || pass_names.empty()) { - continue; - } - for (const auto &pass_name_tmp : pass_names) { - if (op_attrs_maps_tmp->find(pass_name_tmp) == op_attrs_maps_tmp->cend()) { - GELOGD("Not find pass_name[%s] in ATTR_NAME_ORIGIN_OP_ATTRS_MAP", pass_name_tmp.c_str()); - continue; - } - (void)origin_op_attrs_map->insert(std::pair(pass_name_tmp, - (*op_attrs_maps_tmp)[pass_name_tmp])); - // get last item of op_attrs_maps_tmp and push all origin_op_attrs into vector - if (op_attrs_index == (pass_names.size() - 1UL)) { - for (const auto &origin_op_attrs_tmp : (*op_attrs_maps_tmp)[pass_name_tmp]) { - origin_op_attrs_vec.push_back(origin_op_attrs_tmp); - } - } - ++op_attrs_index; - } - } else if (origin_op_attrs.empty()) { - std::vector origin_op_attrs_single_vec; - origin_op_attrs_single_vec.push_back(origin_op_desc_ptr->GetName().c_str()); - origin_op_attrs_single_vec.push_back(origin_op_desc_ptr->GetType().c_str()); - origin_op_attrs_vec.push_back(origin_op_attrs_single_vec); - } else if (index < origin_op_attrs.size()) { - origin_op_attrs_vec.push_back(origin_op_attrs.at(index)); - } - ++index; - } - (void)origin_op_attrs_map->insert(std::pair(pass_name, origin_op_attrs_vec)); - - // 2. set the dump attr - (void)op_desc->SetExtAttr(ge::ATTR_NAME_ORIGIN_OP_ATTRS_MAP, origin_op_attrs_map); -} - -void GraphPassUtil::RecordOriginalNames(const std::vector &original_nodes, - const ge::NodePtr &node) { - // 1. get the original_names - std::vector original_names; - std::vector original_types; - for (const ge::NodePtr &original_node : original_nodes) { - if ((original_node == nullptr) || (original_node->GetOpDesc() == nullptr)) { - return; - } - - const ge::OpDescPtr origin_op_desc_ptr = original_node->GetOpDesc(); - std::vector names_tmp; - std::vector types_tmp; - const bool is_has_attr = - ge::AttrUtils::GetListStr(origin_op_desc_ptr, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, names_tmp) && - !names_tmp.empty(); - (void)ge::AttrUtils::GetListStr(origin_op_desc_ptr, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_TYPES, types_tmp); - if (is_has_attr) { - for (const auto &node_name : names_tmp) { - if (!node_name.empty()) { - original_names.push_back(node_name); - } - } - for (const auto &node_type : types_tmp) { - if (!node_type.empty()) { - original_types.push_back(node_type); - } - } - } else { - original_names.push_back(origin_op_desc_ptr->GetName()); - original_types.push_back(origin_op_desc_ptr->GetType()); - } - } - - // 2. set the dump attr - if ((node == nullptr) || (node->GetOpDesc() == nullptr)) { - return; - } - const ge::OpDescPtr node_op_desc_ptr = node->GetOpDesc(); - (void)ge::AttrUtils::SetListStr(node_op_desc_ptr, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names); - (void)ge::AttrUtils::SetListStr(node_op_desc_ptr, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_TYPES, original_types); -} - -void GraphPassUtil::AddNodeToNodeTypeMap(const NodeTypeMapPtr &node_type_map, const std::string &op_type, - const ge::NodePtr &node_ptr) { - if ((node_type_map == nullptr) || (node_ptr == nullptr)) { - return; - } - const auto iter = node_type_map->find(op_type); - if (iter == node_type_map->end()) { - (void)node_type_map->emplace(std::make_pair(op_type, - std::map{{node_ptr->GetName(), node_ptr}})); - } else { - (void)iter->second.emplace(node_ptr->GetName(), node_ptr); - } -} - -void GraphPassUtil::RemoveNodeFromNodeTypeMap(NodeTypeMapPtr &node_type_map, const std::string &op_type, - const ge::NodePtr &node_ptr) { - if ((node_type_map == nullptr) || (node_ptr == nullptr)) { - return; - } - const auto iter = node_type_map->find(op_type); - if (iter != node_type_map->end()) { - (void)iter->second.erase(node_ptr->GetName()); - } -} - -void GraphPassUtil::GetNodesFromNodeTypeMap(NodeTypeMapPtr &node_type_map, const std::string &op_type, - std::vector &nodes) { - if (node_type_map == nullptr) { - return; - } - - const auto iter = node_type_map->find(op_type); - if (iter == node_type_map->end()) { - return; - } - if (iter->second.empty()) { - return; - } - - for (auto node_iter = iter->second.cbegin(); node_iter != iter->second.cend(); node_iter++) { - nodes.push_back(node_iter->second); - } -} - -ge::OutDataAnchorPtr GraphPassUtil::GetPeerOutAnchorNotInDeleteList(const ge::NodePtr &node, size_t idx) { - if (node == nullptr) { - return nullptr; - } - auto anchor = node->GetInDataAnchor(static_cast(idx)); - if (anchor == nullptr) { - return nullptr; - } - auto peer_anchor = anchor->GetPeerOutAnchor(); - if (peer_anchor == nullptr) { - return nullptr; - } - auto peer_node = peer_anchor->GetOwnerNode(); - if (peer_node == nullptr) { - return nullptr; - } - if (kGeLocalOpType.count(peer_node->GetType()) != 0) { - if (!peer_node->GetInDataNodes().empty()) { - return GetPeerOutAnchorNotInDeleteList(peer_node, 0); - } - } - return peer_anchor; -} - -void GraphPassUtil::SetPairTensorIntAttr(const ge::NodePtr &node, size_t idx, - const std::map &attr_val) { - auto peer_anchor = GetPeerOutAnchorNotInDeleteList(node, idx); - if (peer_anchor == nullptr) { - return; - } - const auto peer_idx = peer_anchor->GetIdx(); - auto in_tensor = node->GetOpDesc()->MutableInputDesc(static_cast(idx)); - auto peer_tensor = peer_anchor->GetOwnerNode()->GetOpDesc()->MutableOutputDesc(static_cast(peer_idx)); - if (peer_tensor == nullptr || in_tensor == nullptr) { - return; - } - for (auto &iter : attr_val) { - (void)ge::AttrUtils::SetInt(in_tensor, iter.first, iter.second); - (void)ge::AttrUtils::SetInt(peer_tensor, iter.first, iter.second); - for (auto &peer_peer_anchor : peer_anchor->GetPeerInDataAnchors()) { - if (peer_peer_anchor == nullptr) { - continue; - } - auto peer_peer_node = peer_peer_anchor->GetOwnerNode(); - if (peer_peer_node == nullptr) { - continue; - } - if (kGeLocalOpType.count(peer_peer_node->GetType()) != 0) { - continue; - } - const auto peer_peer_idx = peer_peer_anchor->GetIdx(); - auto peer_peer_tensor = peer_peer_node->GetOpDesc()->MutableInputDesc(static_cast(peer_peer_idx)); - if (peer_peer_tensor == nullptr) { - continue; - } - (void)ge::AttrUtils::SetInt(peer_peer_tensor, iter.first, iter.second); - } - } -} - -void GraphPassUtil::SetPairTensorAttr(const ge::NodePtr &node, size_t idx, - const std::map &attr_val, bool is_input) { - if (is_input) { - SetPairTensorIntAttr(node, idx, attr_val); - return; - } - auto anchor = node->GetOutDataAnchor(static_cast(idx)); - if (anchor == nullptr) { - return; - } - auto peer_anchors = anchor->GetPeerInDataAnchors(); - for (auto &in_anchor : peer_anchors) { - if (in_anchor == nullptr) { - continue; - } - auto peer_node = in_anchor->GetOwnerNode(); - if (peer_node == nullptr) { - continue; - } - const auto peer_idx = in_anchor->GetIdx(); - auto out_tensor = node->GetOpDesc()->MutableOutputDesc(static_cast(idx)); - auto peer_tensor = peer_node->GetOpDesc()->MutableInputDesc(static_cast(peer_idx)); - if (peer_tensor == nullptr || out_tensor == nullptr) { - return; - } - for (auto &iter : attr_val) { - (void)ge::AttrUtils::SetInt(out_tensor, iter.first, iter.second); - (void)ge::AttrUtils::SetInt(peer_tensor, iter.first, iter.second); - } - } -} -} diff --git a/register/graph_optimizer/graph_fusion/pattern_fusion_base_pass.cc b/register/graph_optimizer/graph_fusion/pattern_fusion_base_pass.cc deleted file mode 100644 index 9f7eaaf190e71e972fe557da8e5d80051267c09a..0000000000000000000000000000000000000000 --- a/register/graph_optimizer/graph_fusion/pattern_fusion_base_pass.cc +++ /dev/null @@ -1,619 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/graph_optimizer/fusion_common/pattern_fusion_base_pass.h" -#include -#include -#include -#include -#include -#include "mmpa/mmpa_api.h" -#include "graph/debug/ge_log.h" -#include "graph/utils/graph_utils.h" -#include "register/graph_optimizer/fusion_common/fusion_statistic_recorder.h" -#include "register/graph_optimizer/fusion_common/graph_pass_util.h" -#include "register/graph_optimizer/graph_fusion/pattern_fusion_base_pass_impl.h" -#include "register/graph_optimizer/fusion_common/fusion_config_info.h" - -namespace fe { -namespace { -void PrintAllNodes(const std::vector &scope_nodes) { - for (const auto &node : scope_nodes) { - if (node == nullptr) { - GELOGD("type: null, name: null"); - } else { - GELOGD("type: %s, name: %s", node->GetType().c_str(), node->GetName().c_str()); - } - } -} - -void StoreOriginNodes(const Mapping &mapping, - GraphPassUtil::OriginOpAttrsVec &origin_op_attrs, - std::vector &original_nodes) { - for (const auto &item : mapping) { - if (item.second.empty()) { - continue; - } - for (const auto &node : item.second) { - original_nodes.emplace_back(node); - std::vector origin_op_attrs_vec; - origin_op_attrs_vec.push_back(node->GetName()); - origin_op_attrs_vec.push_back(node->GetType()); - origin_op_attrs.emplace_back(origin_op_attrs_vec); - } - } -} -} -static const std::string STREAM_LABEL = "_stream_label"; -static const std::string ATTR_OP_COMPILE_STRATEGY = "_op_compile_strategy"; -static const std::string ATTR_KEEP_DTYPE = "_keep_dtype"; -PatternFusionBasePass::PatternFusionBasePass() { - pattern_fusion_base_pass_impl_ptr_ = std::make_shared(); -} - -PatternFusionBasePass::~PatternFusionBasePass() {} - -Status PatternFusionBasePass::Run(ge::ComputeGraph &graph, OpsKernelInfoStorePtr ops_kernel_info_store_ptr) { - // save the opskernelstoreptr which will be uesd while checking op support - pattern_fusion_base_pass_impl_ptr_->SetOpsKernelInfoStore(ops_kernel_info_store_ptr); - - Status ret = Run(graph); - if (ret != SUCCESS) { - // do not update cache when not success - int64_t run_count_attr; - if (ge::AttrUtils::GetInt(graph, "run_count", run_count_attr)) { - (void)ge::AttrUtils::SetInt(graph, "run_count", ++run_count_attr); - } - } - return ret; -} -/** - * @ingroup fe - * @brief execute pass - */ -Status PatternFusionBasePass::Run(ge::ComputeGraph &graph) { - bool is_patterns_ok = true; - // build Pattern - std::vector patterns; - pattern_fusion_base_pass_impl_ptr_->GetPatterns(patterns); - if (patterns.empty()) { - patterns = DefinePatterns(); - for (FusionPattern *pattern : patterns) { - if (pattern != nullptr) { - const bool ok = pattern->Build(); - if (!ok) { - GELOGW("[RunFusionPass][Check] pattern %s build failed", pattern->GetName().c_str()); - } - pattern->Dump(); - is_patterns_ok = is_patterns_ok && ok; - } - } - - pattern_fusion_base_pass_impl_ptr_->SetPatterns(patterns); - } - - if (!is_patterns_ok) { - return FAILED; - } - NodeMapInfoPtr node_map_info = nullptr; - if (GraphPassUtil::GetOpTypeMapToGraph(node_map_info, graph) == SUCCESS) { - if (node_map_info->run_count == std::numeric_limits::max()) { - GELOGE(ge::FAILED, "Run count is overflow."); - return FAILED; - } - node_map_info->run_count++; - } - - // do matching and fusion for each pattern - bool final_changed = false; - for (const FusionPattern * const pattern : patterns) { - if (pattern != nullptr) { - bool changed = false; - const Status ret = RunOnePattern(graph, *pattern, changed); - if (ret != SUCCESS) { - GELOGW("[RunFusionPass][Check] Running pattern %s failed; the graph was not altered by it.", pattern->GetName().c_str()); - return ret; - } - - final_changed = final_changed || changed; - } - } - return final_changed ? SUCCESS : NOT_CHANGED; -} - -static bool CheckStreamLabel(std::vector &fused_nodes) { - std::string stream_label = ""; - for (auto &n : fused_nodes) { - std::string stream_label_tmp = ""; - if (!ge::AttrUtils::GetStr(n->GetOpDesc(), STREAM_LABEL, stream_label_tmp)) { - stream_label_tmp = "null"; - } - if (stream_label == "") { - stream_label = stream_label_tmp; - } else if ((stream_label != "") && (stream_label != stream_label_tmp)) { - return false; - } - } - return true; -} - -static bool SetStreamLabelToFusedNodes(std::vector &fused_nodes, - const std::vector &original_nodes) { - if (original_nodes.empty() || original_nodes[0] == nullptr) { - return true; - } - - std::string stream_label = ""; - if (ge::AttrUtils::GetStr(original_nodes[0]->GetOpDesc(), STREAM_LABEL, stream_label)) { - for (ge::NodePtr &node : fused_nodes) { - if (!ge::AttrUtils::SetStr(node->GetOpDesc(), STREAM_LABEL, stream_label)) { - GELOGW("[Set][Attr] node %s set attr _stream_label failed", node->GetName().c_str()); - return false; - } - } - } - return true; -} - -void PatternFusionBasePass::DumpMapping(const FusionPattern &pattern, const Mapping &mapping) const { - std::ostringstream oss; - oss << std::endl << "Mapping of pattern "; - oss << pattern.GetName() << ":" << std::endl; - oss << " Mapping: " << std::endl; - for (const auto &item : mapping) { - const std::shared_ptr op_desc = item.first; - const ge::NodePtr node = item.second[0U]; - if ((op_desc != nullptr) && (node != nullptr)) { - oss << " " << op_desc->id << " -> " << node->GetName() << std::endl; - } - } - GELOGE(FAILED, "%s", oss.str().c_str()); -} - -/** - * @ingroup fe - * @brief do matching and fusion in graph based on the pattern - */ -Status PatternFusionBasePass::RunOnePattern(ge::ComputeGraph &graph, const FusionPattern &pattern, bool &changed) { - changed = false; - Mappings mappings; - int32_t effect_times = 0; - const uint32_t graph_id = graph.GetGraphID(); - FusionInfo fusion_info(graph.GetSessionID(), to_string(graph_id), GetName(), static_cast(mappings.size()), - effect_times); - origin_op_anchors_map_.clear(); - // match all patterns in graph, and save them to mappings - if (!MatchAll(graph, pattern, mappings)) { - GELOGD("GraphFusionPass[%s]: pattern=%s, matched_times=%zu, effected_times=%d.", GetName().c_str(), - pattern.GetName().c_str(), mappings.size(), effect_times); - return SUCCESS; - } - - GELOGD("This graph has been matched with pattern[%s]. The mappings are as follows.", pattern.GetName().c_str()); - - // print the results of matching - pattern_fusion_base_pass_impl_ptr_->DumpMappings(pattern, mappings); - NodeMapInfoPtr node_map_info = nullptr; - // get nodes by type from node - (void)GraphPassUtil::GetOpTypeMapToGraph(node_map_info, graph); - // do fusion for each mapping - for (Mapping &mapping : mappings) { - GraphPassUtil::OriginOpAttrsVec origin_op_attrs; - std::vector original_nodes; - StoreOriginNodes(mapping, origin_op_attrs, original_nodes); - bool backward = false; - GraphPassUtil::GetBackWardAttr(original_nodes, backward, BackWardInheritMode::kFusedNode); - std::vector fus_nodes; - const Status status = Fusion(graph, mapping, fus_nodes); - - const bool isGraphCycle = FusionConfigInfo::Instance().IsEnableNetworkAnalysis() && CheckGraphCycle(graph); - if (isGraphCycle) { - GELOGE(FAILED, "Failed to do topological sorting after graph fusion, graph is cyclic, graph name:%s", - graph.GetName().c_str()); - GELOGE(FAILED, "This graph is cyclic. The mapping and new nodes are as follows."); - DumpMapping(pattern, mapping); - - std::ostringstream oss; - for (const auto &node_ : fus_nodes) { - oss << "name:" << node_->GetName() << ", type:" << node_->GetType() << std::endl; - } - GELOGE(FAILED, "%s", oss.str().c_str()); - ge::GraphUtils::DumpGEGraphToOnnx(graph, "graph_cyclic_after " + pattern.GetName()); - return GRAPH_FUSION_CYCLE; - } - - if (!SetStreamLabelToFusedNodes(fus_nodes, original_nodes)) { - return FAILED; - } - - if ((status != SUCCESS) && (status != NOT_CHANGED)) { - GELOGE(status, "[Fuse][Graph]Fail with pattern[%s].", pattern.GetName().c_str()); - return status; - } - - if (status == SUCCESS) { - effect_times++; - SetDataDumpAttr(original_nodes, fus_nodes); - for (ge::NodePtr &node : fus_nodes) { - const ge::OpDescPtr fusion_op = node->GetOpDesc(); - GraphPassUtil::RecordOriginalOpAttrs(original_nodes, fusion_op, GetName(), origin_op_attrs); - (void)GraphPassUtil::StoreAndUpdataOriginFusionPassName(fusion_op, original_nodes, GetName()); - (void)GraphPassUtil::AddNodeFromOpTypeMap(node_map_info, node); - } - const BackWardInheritMode inherit_mode = backward ? BackWardInheritMode::kInheritTrue : - BackWardInheritMode::kDoNotInherit; - GraphPassUtil::InheritAttrFromOriNodes(original_nodes, fus_nodes, inherit_mode); - } - changed = (changed || (status == SUCCESS)); - } - - // get match times and effect times - FusionStatisticRecorder &fusion_statistic_inst = FusionStatisticRecorder::Instance(); - fusion_info.SetMatchTimes(static_cast(mappings.size())); - fusion_info.SetEffectTimes(effect_times); - fusion_statistic_inst.UpdateGraphFusionMatchTimes(fusion_info); - fusion_statistic_inst.UpdateGraphFusionEffectTimes(fusion_info); - GELOGI("GraphId[%d], GraphFusionPass[%s]: pattern=%s, matched_times=%zu, effected_times=%d.", graph_id, - GetName().c_str(), pattern.GetName().c_str(), mappings.size(), effect_times); - return SUCCESS; -} - -std::vector PatternFusionBasePass::DefineInnerPatterns() { - std::vector ret; - return ret; -} - -void PatternFusionBasePass::SetDataDumpAttr(const std::vector &fused_nodes, - const std::vector &fusion_nodes) { - // if pass do not specify fused nodes, all matched nodes will be handled as the fused nodes - const std::vector &actual_fused_nodes = pattern_fusion_base_pass_impl_ptr_->GetActualFusedNodes(); - if (actual_fused_nodes.empty()) { - SetOriginalOpDumpAttr(fused_nodes, fusion_nodes); - } else { - SetOriginalOpDumpAttr(actual_fused_nodes, fusion_nodes); - } - - if (fusion_nodes.size() > 1) { - const bool is_multi_op = true; - for (const ge::NodePtr &node : fusion_nodes) { - (void)ge::AttrUtils::SetBool(node->GetOpDesc(), ge::ATTR_NAME_DATA_DUMP_IS_MULTIOP, is_multi_op); - } - } - - SetOriginalOutputDumpAttr(fused_nodes, fusion_nodes); -} - -void PatternFusionBasePass::SetOriginalOutputDumpAttr(const std::vector &fused_nodes, - const std::vector &fusion_nodes) { - for (const ge::NodePtr &ori_node : fused_nodes) { - const auto iter = origin_op_anchors_map_.find(ori_node); - if (iter != origin_op_anchors_map_.end()) { - for (const auto &anchor_iter : iter->second) { - const auto next_node_in_anchor = anchor_iter.first; - const auto fusion_node_out_data_anchor = next_node_in_anchor->GetPeerOutAnchor(); - if (fusion_node_out_data_anchor == nullptr) { - GELOGW("[Set][Attr] peer_out_anchor is null"); - return; - } - - // owner_node of anchor should not be null - auto fusion_node = fusion_node_out_data_anchor->GetOwnerNode(); - if (fusion_node == nullptr) { - GELOGW("[Set][Attr] fusion_node is null."); - return; - } - if (pattern_fusion_base_pass_impl_ptr_->IsNodesExist(fusion_node, fusion_nodes)) { - const auto origin_node_out_anchor = anchor_iter.second; - if (origin_node_out_anchor == nullptr) { - GELOGW("[Set][Attr] The ori_out_anchor of node %s is null.", ori_node->GetName().c_str()); - return; - } - - // owner_node of anchor should not be null - auto origin_node = origin_node_out_anchor->GetOwnerNode(); - if (origin_node == nullptr) { - GELOGW("[Set][Attr] origin_node is null"); - return; - } - const uint32_t origin_index = static_cast(origin_node_out_anchor->GetIdx()); - const uint32_t fusion_index = static_cast(fusion_node_out_data_anchor->GetIdx()); - GraphPassUtil::SetOutputDescAttr(origin_index, fusion_index, origin_node, fusion_node); - } - } - } - } -} - -void PatternFusionBasePass::SetOriginalOpDumpAttr(const std::vector &fused_nodes, - const std::vector &fusion_nodes) { - for (const ge::NodePtr &node : fusion_nodes) { - GraphPassUtil::RecordOriginalNames(fused_nodes, node); - } -} - -void PatternFusionBasePass::SetActualFusedNodes(const std::vector &fused_nodes) { - pattern_fusion_base_pass_impl_ptr_->SetActualFusedNodes(fused_nodes); -} - -bool PatternFusionBasePass::CheckOpSupported(const ge::OpDescPtr &op_desc_ptr) const { - return pattern_fusion_base_pass_impl_ptr_->CheckOpSupported(op_desc_ptr); -} - -bool PatternFusionBasePass::CheckOpSupported(const ge::NodePtr &node) const { - return pattern_fusion_base_pass_impl_ptr_->CheckOpSupported(node); -} - -bool PatternFusionBasePass::CheckAccuracySupported(const ge::NodePtr &node) const { - return pattern_fusion_base_pass_impl_ptr_->CheckAccuracySupported(node); -} - -bool PatternFusionBasePass::CheckEachPeerOut(const ge::NodePtr &node, - const std::unordered_set &scope_nodes_set, - const std::vector &scope_nodes) const { - for (const auto &peer_out : node->GetOutAllNodes()) { - if (scope_nodes_set.count(peer_out) > 0) { - continue; - } - for (const auto &node_temp :scope_nodes) { - if ((node_temp == nullptr) || (node_temp == node)) { - continue; - } - GELOGD("Check %s and %s.", peer_out->GetName().c_str(), node_temp->GetName().c_str()); - - if (connectivity_->IsConnected(peer_out, node_temp)) { - GELOGD("There is a path between %s and %s after fusion:", - peer_out->GetName().c_str(), - node_temp->GetName().c_str()); - PrintAllNodes(scope_nodes); - return true; - } - } - } - return false; -} - -bool PatternFusionBasePass::DetectOneScope(const std::vector &scope_nodes) const { - /* Create a set for accelerating the searching. */ - const std::unordered_set scope_nodes_set(scope_nodes.begin(), scope_nodes.end()); - - for (const auto &node: scope_nodes) { - if (node == nullptr) { - continue; - } - if (CheckEachPeerOut(node, scope_nodes_set, scope_nodes)) { - return true; - } - } - return false; -} - -void PatternFusionBasePass::GetConnectionMatrix(std::unique_ptr &connection_matrix) { - connection_matrix = std::move(connectivity_); -} - -void PatternFusionBasePass::SetConnectionMatrix(std::unique_ptr &connection_matrix) { - connectivity_ = std::move(connection_matrix); -} - -bool PatternFusionBasePass::CycleDetection(const ge::ComputeGraph &graph, - const std::vector> &fusion_nodes) { - if (connectivity_ == nullptr) { - try { - connectivity_ = std::unique_ptr(new(std::nothrow) fe::ConnectionMatrix(graph)); - } catch (...) { - GELOGW("Make shared failed"); - return false; - } - connectivity_->Generate(graph); - } - - for (const auto &scope_nodes : fusion_nodes) { - if (DetectOneScope(scope_nodes)) { - return true; - } - } - return false; -} - -const std::vector &PatternFusionBasePass::GetPatterns() { - const auto &patterns = pattern_fusion_base_pass_impl_ptr_->GetPatterns(); - if (!patterns.empty()) { - return patterns; - } - - const auto new_defined_patterns = DefinePatterns(); - for (FusionPattern *pattern : new_defined_patterns) { - if (pattern != nullptr) { - const bool build_result = pattern->Build(); - if (!build_result) { - GELOGW("[GetPatterns][Check] Pattern %s build failed", pattern->GetName().c_str()); - return patterns; - } - pattern->Dump(); - } - } - - pattern_fusion_base_pass_impl_ptr_->SetPatterns(new_defined_patterns); - return pattern_fusion_base_pass_impl_ptr_->GetPatterns(); -} - -const std::vector &PatternFusionBasePass::GetInnerPatterns() { - const auto &inner_patterns = pattern_fusion_base_pass_impl_ptr_->GetInnerPatterns(); - if (!inner_patterns.empty()) { - return inner_patterns; - } - - const auto new_defined_inner_patterns = DefineInnerPatterns(); - for (FusionPattern *inner_pattern : new_defined_inner_patterns) { - if (inner_pattern != nullptr) { - const bool build_result = inner_pattern->Build(); - if (!build_result) { - GELOGW("[GetPatterns][Check] Pattern %s build failed", inner_pattern->GetName().c_str()); - return inner_patterns; - } - inner_pattern->Dump(); - } - } - - pattern_fusion_base_pass_impl_ptr_->SetInnerPatterns(new_defined_inner_patterns); - return pattern_fusion_base_pass_impl_ptr_->GetInnerPatterns(); -} - -bool PatternFusionBasePass::MatchFromOutput(const ge::NodePtr &output_node, - const std::shared_ptr &output_op_desc, Mapping &mapping) { - return pattern_fusion_base_pass_impl_ptr_->MatchFromOutput(output_node, output_op_desc, mapping); -} - -bool PatternFusionBasePass::CycleDetection(const ge::ComputeGraph &graph, - const std::vector &fusion_nodes) { - if (connectivity_ == nullptr) { - try { - connectivity_ = std::unique_ptr(new(std::nothrow) fe::ConnectionMatrix(graph)); - } catch (...) { - GELOGW("Make shared failed"); - return false; - } - connectivity_->Generate(graph); - } - - return DetectOneScope(fusion_nodes); -} - -bool PatternFusionBasePass::CheckGraphCycle(ge::ComputeGraph &graph) const { - const Status ret = graph.TopologicalSorting(); - if (ret != ge::GRAPH_SUCCESS) { - return true; - } - return false; -} - -/** - * @ingroup fe - * @brief match all nodes in graph according to pattern - * match nodes in graph according to pattern, the algorithm is shown as following: - * 1. get output node from pattern - * 2. Search for candidate nodes in Graph (network Graph generated after parsing) according to Op Type and - * (optional), and add the candidate node to the list of candidates - * 3. For each Node in the candidate list, check whether the type and the number - * of precursors are consistent with the description of corresponding Op in pattern. - * If they are consistent, add the precursor Node to the - * candidate list, and add "PatternOp-GraphNode" to the mapping; otherwise, return an empty mapping - * 4. repeat step 3 until all the Ops in pattern are matched - * 5. if all the Ops in pattern are matched successfully, return the mapping of PatternOp and GraphNode - */ -bool PatternFusionBasePass::MatchAll(const ge::ComputeGraph &graph, const FusionPattern &pattern, - Mappings &mappings) { - std::vector matched_output_nodes; - - // find all the output nodes of pattern in the graph based on Op type - std::shared_ptr output_op_desc = pattern.GetOutput(); - if (output_op_desc == nullptr) { - return false; - } - - if (!pattern_fusion_base_pass_impl_ptr_->GetMatchOutputNodes(graph, pattern, matched_output_nodes)) { - return false; - } - - // begin matching from every output node - for (ge::NodePtr &output_node : matched_output_nodes) { - Mapping mapping; - if (pattern_fusion_base_pass_impl_ptr_->MatchFromOutput(output_node, output_op_desc, mapping)) { - // node attr _stream_label must be equal - auto fusion_nodes = GetNodesFromMapping(mapping); - if (!CheckStreamLabel(fusion_nodes)) { - return false; - } - std::string reason_not_support; - const bool ret = graph.IsSupportFuse(fusion_nodes, reason_not_support); - if (!ret) { - GELOGD("IsSupportFuse did not succeed, reason is [%s].", reason_not_support.c_str()); - continue; - } - mappings.push_back(mapping); - - // Record output nodes anchor vs succeed node anchor map - RecordOutputAnchorMap(output_node); - } - } - // if matching is successful, return true; otherwise false - return !mappings.empty(); -} - -/* - * @brief: get all fusion nodes matched - * @param [in] mapping: fusion node group - * @return std::vector: all fusion nodes list - */ -std::vector PatternFusionBasePass::GetNodesFromMapping(const Mapping &mapping) const { - std::vector nodes; - for (auto &item : mapping) { - for (const auto &node : item.second) { - nodes.push_back(node); - } - } - return nodes; -} - -/** - * @ingroup fe - * @brief get an op from mapping according to ID - */ -ge::NodePtr PatternFusionBasePass::GetNodeFromMapping(const std::string &id, const Mapping &mapping) const { - for (auto &item : mapping) { - const std::shared_ptr op_desc = item.first; - if ((op_desc != nullptr) && (op_desc->id == id)) { - if (item.second.empty()) { - return nullptr; - } else { - return item.second[0]; - } - } - } - return nullptr; -} - -void PatternFusionBasePass::StoreOriginOpNames(const Mapping &mapping, - std::vector &origin_op_names) const { - for (const auto &item : mapping) { - if (item.second.empty()) { - continue; - } - for (const auto &node : item.second) { - origin_op_names.push_back(node->GetOpDesc()->GetName()); - } - } -} - -void PatternFusionBasePass::RecordOutputAnchorMap(ge::NodePtr output_node) { - for (const auto &output_anchor : output_node->GetAllOutDataAnchors()) { - if (output_anchor == nullptr) { - continue; - } - - for (const auto &peer_in_anchor : output_anchor->GetPeerInDataAnchors()) { - if (peer_in_anchor == nullptr) { - continue; - } - - // Record anchor map - const auto iter = origin_op_anchors_map_.find(output_node); - if (iter == origin_op_anchors_map_.end()) { - std::map anchorMap; - anchorMap[peer_in_anchor] = output_anchor; - (void)origin_op_anchors_map_.emplace(make_pair(output_node, anchorMap)); - } else { - (void)iter->second.emplace(make_pair(peer_in_anchor, output_anchor)); - } - } - } -} - -void PatternFusionBasePass::ClearOutputAnchorMap() { origin_op_anchors_map_.clear(); } -} // namespace fe diff --git a/register/graph_optimizer/graph_fusion/pattern_fusion_base_pass_impl.cc b/register/graph_optimizer/graph_fusion/pattern_fusion_base_pass_impl.cc deleted file mode 100644 index b8582585de121beeade6999ee94fd27a54287eef..0000000000000000000000000000000000000000 --- a/register/graph_optimizer/graph_fusion/pattern_fusion_base_pass_impl.cc +++ /dev/null @@ -1,469 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/graph_optimizer/graph_fusion/pattern_fusion_base_pass_impl.h" -#include "register/graph_optimizer/fusion_common/graph_pass_util.h" - -namespace fe { -namespace { -const std::string kAttrDumpAble = "_dump_able"; -} -PatternFusionBasePassImpl::PatternFusionBasePassImpl() {} - -PatternFusionBasePassImpl::~PatternFusionBasePassImpl() { - for (auto pattern : patterns_) { - if (pattern != nullptr) { - delete pattern; - pattern = nullptr; - } - } - for (auto inner_pattern : inner_patterns_) { - if (inner_pattern != nullptr) { - delete inner_pattern; - inner_pattern = nullptr; - } - } -} - -const std::vector &PatternFusionBasePassImpl::GetPatterns() { return patterns_; } - -void PatternFusionBasePassImpl::GetPatterns(std::vector &patterns) { patterns = patterns_; } - -const std::vector &PatternFusionBasePassImpl::GetInnerPatterns() { return inner_patterns_; } - -void PatternFusionBasePassImpl::GetInnerPatterns(std::vector &inner_patterns) { - inner_patterns = inner_patterns_; -} - -void PatternFusionBasePassImpl::SetPatterns(const std::vector &patterns) { patterns_ = patterns; } - -void PatternFusionBasePassImpl::SetInnerPatterns(const std::vector &inner_patterns) { - inner_patterns_ = inner_patterns; -} - -void PatternFusionBasePassImpl::SetOpsKernelInfoStore(const OpsKernelInfoStorePtr &ops_kernel_info_store_ptr) { - ops_kernel_info_store_ptr_ = ops_kernel_info_store_ptr; -} - -bool PatternFusionBasePassImpl::CheckOpSupported(const ge::OpDescPtr &op_desc_ptr) const { - std::string un_supported_reason; - - if (ops_kernel_info_store_ptr_ == nullptr) { - un_supported_reason = "opsKernelInfoStorePtr in PatternFusionBasePass is nullptr."; - return false; - } - - return ops_kernel_info_store_ptr_->CheckSupported(op_desc_ptr, un_supported_reason); -} - -bool PatternFusionBasePassImpl::CheckOpSupported(const ge::NodePtr &node) const { - std::string un_supported_reason; - - if (ops_kernel_info_store_ptr_ == nullptr) { - un_supported_reason = "opsKernelInfoStorePtr in PatternFusionBasePass is nullptr."; - return false; - } - - return ops_kernel_info_store_ptr_->CheckSupported(node, un_supported_reason); -} - -bool PatternFusionBasePassImpl::CheckAccuracySupported(const ge::NodePtr &node) const { - if (node == nullptr) { - GELOGD("Node is null."); - return false; - } - if (ops_kernel_info_store_ptr_ == nullptr) { - GELOGD("Ops kernel info store is null."); - return false; - } - std::string un_supported_reason; - const bool ret = ops_kernel_info_store_ptr_->CheckAccuracySupported(node, un_supported_reason, true); - GELOGD("Check result for op[%s, %s] is [%d], reason is [%s].", - node->GetName().c_str(), node->GetType().c_str(), ret, un_supported_reason.c_str()); - return ret; -} - -bool PatternFusionBasePassImpl::IsNodesExist(const ge::NodePtr ¤t_node, const std::vector &nodes) { - return find(nodes.begin(), nodes.end(), current_node) != nodes.end(); -} - -bool PatternFusionBasePassImpl::IsMatched(const std::shared_ptr op_desc, const ge::NodePtr node, - const Mapping &mapping) { - if ((op_desc == nullptr) || (node == nullptr)) { - GELOGD("opDesc or node could not be null"); - return false; - } - - const auto iter = mapping.find(op_desc); - - // check op_desc does not exist in mapping - return (iter != mapping.end()) && (find(iter->second.begin(), iter->second.end(), node) != iter->second.end()); -} - -void PatternFusionBasePassImpl::DumpMappings(const FusionPattern &pattern, const Mappings &mappings) const { - std::ostringstream oss; - oss << std::endl << "Mappings of pattern "; - oss << pattern.GetName() << ":" << std::endl; - for (size_t i = 0; i < mappings.size(); i++) { - const Mapping &mapping = mappings[i]; - oss << " Mapping " << (i + 1) << "/" << mappings.size() << ":" << std::endl; - for (const auto &item : mapping) { - const std::shared_ptr op_desc = item.first; - const ge::NodePtr node = item.second[0]; - if ((op_desc != nullptr) && (node != nullptr)) { - oss << " " << op_desc->id << " -> " << node->GetName() << std::endl; - } - } - } - GELOGD("%s", oss.str().c_str()); -} - -bool PatternFusionBasePassImpl::IsOpTypeExist(const std::string &type, const std::vector &types) { - return find(types.begin(), types.end(), type) != types.end(); -} - -bool PatternFusionBasePassImpl::GetSortedInAnchors(const ge::NodePtr &node, const std::string&op_id, - std::vector &in_anchors) const { - if (node->GetInDataNodes().empty()) { - GELOGW("[Match][Output] in data nodes of op %s is empty, pattern matching failed.", op_id.c_str()); - return false; - } - - /* Input anchors should have an order. */ - GetInDataAnchors(node, in_anchors); - if (in_anchors.empty()) { - GELOGW("[Match][Output] The data anchor for op %s is empty, leading to a failure in pattern matching.", op_id.c_str()); - return false; - } - - std::sort(in_anchors.begin(), in_anchors.end(), - [](const ge::InDataAnchorPtr &a, const ge::InDataAnchorPtr &b) { return a->GetIdx() < b->GetIdx(); }); - return true; -} - -bool PatternFusionBasePassImpl::MatchFromOutput(const ge::NodePtr output_node, - const std::shared_ptr output_op_desc, Mapping &mapping) const { - if ((output_node == nullptr) || (output_op_desc == nullptr)) { - GELOGW("[Match][Output] Output node or op_desc is null, pattern matching failed."); - return false; - } - CandidateAndMapping cand(mapping); - cand.candidate_nodes = {output_node}; - cand.candidate_op_descs = {output_op_desc}; - - // store the nodes matched - cand.mapping[output_op_desc].push_back(output_node); - - // match candidate node one by one - while ((!cand.candidate_nodes.empty()) && (!cand.candidate_op_descs.empty())) { - // get the first candidate node - bool result = MatchFromOutput(cand); - if (!result) { - return false; - } - - result = MatchOutputs(cand); - if (!result) { - return false; - } - // current op is matched successfully, thus remove it from candidate list - (void)cand.candidate_nodes.erase(cand.candidate_nodes.cbegin()); - (void)cand.candidate_op_descs.erase(cand.candidate_op_descs.cbegin()); - - // the sizes of candidate_nodes and candidate_op_descs should always keep the same - if (cand.candidate_nodes.size() != cand.candidate_op_descs.size()) { - GELOGW("[Match][Output] candidate_nodes_num != candidate_op_descs_num, pattern matching failed."); - return false; - } - } - - // if candidate_nodes(or candidate_op_descs) is empty, the matching is done - // successfully - return cand.candidate_op_descs.empty(); -} - -bool PatternFusionBasePassImpl::VerifyInputDescNodes(const ge::NodePtr &input_node, - const std::shared_ptr &input_desc, - const Mapping &mapping) { - if (input_node == nullptr) { - return true; - } - if (!input_desc->check_unique) { - return true; - } - - // if this input desc has been matched before, the current nodes should among the matched nodes - auto iter = mapping.find(input_desc); - if (iter == mapping.cend() || iter->second.empty()) { - return true; - } - return std::find(iter->second.begin(), iter->second.end(), input_node) != iter->second.end(); -} - -bool PatternFusionBasePassImpl::MatchFromOutput(CandidateAndMapping &cand) const { - if (cand.candidate_nodes.empty() || cand.candidate_op_descs.empty()) { - GELOGW("[Match][Output] Either candidate_nodes or candidate_op_descs is empty, resulting in pattern matching failure."); - return false; - } - const ge::NodePtr node = cand.candidate_nodes.front(); - std::shared_ptr op_desc = cand.candidate_op_descs.front(); - const std::string op_id = op_desc->id; - // add the input nodes into candidate list - const std::vector> * const inputs_desc = FusionPattern::GetInputs(op_desc); - if (inputs_desc == nullptr) { - GELOGW("[Match][Output] Failed to get input_desc for op %s, pattern matching failed.", op_id.c_str()); - return false; - } - - if (inputs_desc->empty()) { - return true; - } - std::vector in_anchors; - if (!GetSortedInAnchors(node, op_id, in_anchors)) { - return false; - } - // set flag for edge using - const std::unique_ptr usage_flags(new (std::nothrow) bool[inputs_desc->size()]{}); - for (const auto &in_anchor : in_anchors) { - if (in_anchor->GetPeerOutAnchor() == nullptr) { - GELOGE(ge::FAILED, "Peer anchor is null."); - return false; - } - const ge::NodePtr input_node = in_anchor->GetPeerOutAnchor()->GetOwnerNode(); - for (uint32_t j = 0U; j < inputs_desc->size(); j++) { - const std::shared_ptr &input_desc = inputs_desc->at(static_cast(j)); - if (input_desc == nullptr) { - GELOGW("[Match][Output] input_desc %u for op %s is null, pattern matching failed.", j, op_id.c_str()); - return false; - } - - const bool matching_result = - (IsOpTypeExist(ge::NodeUtils::GetNodeType(*input_node), input_desc->types) || input_desc->types.empty()) && - ((!usage_flags[static_cast(j)]) || input_desc->repeatable) && - IsOpFusible(input_node->GetOpDesc(), input_desc) && - VerifyInputDescNodes(input_node, input_desc, cand.mapping); - if (!matching_result) { - continue; - } - - // some nodes might be the input of multiple nodes, we use - // IsMatched() to avoid repeat - AddCandidateQueue(input_desc, input_node, cand); - usage_flags[static_cast(j)] = true; - break; - } - } - - // return false if not all edges are matched - if (!MatchAllEdges(inputs_desc->size(), usage_flags)) { - GELOGD("[Match][Output] Not all inputs of op %s were matched; pattern matching did not succeed.", op_id.c_str()); - return false; - } - - return true; -} - -void PatternFusionBasePassImpl::AddCandidateQueue(const FusionPattern::OpDescPtr &op_desc, - const ge::NodePtr &node, - CandidateAndMapping &cand) const { - if (IsMatched(op_desc, node, cand.mapping)) { - return; - } - cand.candidate_nodes.emplace_back(node); - cand.candidate_op_descs.emplace_back(op_desc); - cand.mapping[op_desc].emplace_back(node); -} - -void PatternFusionBasePassImpl::MatchOneOutputNode(const ge::NodePtr &output_node, - const std::vector &outputs_desc, - size_t &out_idx, const std::unique_ptr &usage_flags, - CandidateAndMapping &cand) const { - if (output_node == nullptr) { - return; - } - for (size_t i = 0; i < outputs_desc.size(); i++) { - const FusionPattern::OpDescPtr &output_desc = outputs_desc.at(i); - const bool is_matched = - (IsOpTypeExist(ge::NodeUtils::GetNodeType(*output_node), output_desc->types) || output_desc->types.empty()) && - (!usage_flags[out_idx + i]) && IsOpFusible(output_node->GetOpDesc(), output_desc); - if (!is_matched) { - continue; - } - AddCandidateQueue(output_desc, output_node, cand); - usage_flags[out_idx + i] = true; - break; - } -} - -void PatternFusionBasePassImpl::MatchFuzzyOutputs(const ge::NodePtr &node, const FusionPattern::OpDescPtr &op_desc, - size_t &out_idx, const std::unique_ptr &usage_flags, - CandidateAndMapping &cand) const { - const FusionPattern::OutputMapDesc &outputs_desc_map = FusionPattern::GetOutputs(op_desc); - auto peer_in_nodes = node->GetOutDataNodes(); - for (const auto &outputs_desc_pair : outputs_desc_map) { - if (outputs_desc_pair.first != kFuzzyOutIndex) { - continue; - } - - for (const auto &peer_in_node : peer_in_nodes) { - MatchOneOutputNode(peer_in_node, outputs_desc_pair.second, out_idx, usage_flags, cand); - } - if (out_idx > (std::numeric_limits::max() - outputs_desc_pair.second.size())) { - GELOGE(ge::FAILED, "Out idx %zu is overflow.", out_idx); - return; - } - out_idx += outputs_desc_pair.second.size(); - } -} - -void PatternFusionBasePassImpl::UpdateCandidates( - const CandidateAndMapping &temp_cand, CandidateAndMapping &cand) const { - if (temp_cand.candidate_op_descs.size() != temp_cand.candidate_nodes.size()) { - return; - } - - for (size_t i = 0; i < temp_cand.candidate_nodes.size(); i++) { - AddCandidateQueue(temp_cand.candidate_op_descs[i], temp_cand.candidate_nodes[i], cand); - } -} - -bool PatternFusionBasePassImpl::MatchOutputs(CandidateAndMapping &cand) const { - const auto &node = cand.candidate_nodes.front(); - const FusionPattern::OpDescPtr &op_desc = cand.candidate_op_descs.front(); - const std::string op_id = op_desc->id; - const FusionPattern::OutputMapDesc &outputs_desc_map = FusionPattern::GetOutputs(op_desc); - if (outputs_desc_map.empty()) { - return true; - } - const size_t outputs_desc_size = FusionPattern::GetOutputSize(op_desc); - if (op_desc->is_output_fullmatch && node->GetOutDataNodesSize() != outputs_desc_size) { - GELOGW("[Match][Input] Full match mode: op %s description size (%zu) does not match output data node size (%u)", op_id.c_str(), - outputs_desc_size, node->GetOutDataNodesSize()); - return false; - } - - const std::unique_ptr usage_flags(new (std::nothrow) bool[outputs_desc_size] {}); - std::vector out_anchors; - GetOutDataAnchors(node, out_anchors); - - size_t out_idx = 0; - MatchFuzzyOutputs(node, op_desc, out_idx, usage_flags, cand); - for (const auto &out_anchor : out_anchors) { - if (outputs_desc_map.find(out_anchor->GetIdx()) == outputs_desc_map.end()) { - GELOGW("[Match][Input] op %s out anchor idx: %d not configured in pattern", op_id.c_str(), out_anchor->GetIdx()); - continue; - } - const std::vector &outputs_desc = outputs_desc_map.at(out_anchor->GetIdx()); - for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) { - const ge::NodePtr output_node = peer_in_anchor->GetOwnerNode(); - MatchOneOutputNode(output_node, outputs_desc, out_idx, usage_flags, cand); - } - out_idx += outputs_desc.size(); - } - - if (!MatchAllEdges(outputs_desc_size, usage_flags)) { - GELOGW("[Match][Input] Not all outputs of op %s are matched; pattern matching failed.", op_id.c_str()); - return false; - } - return true; -} - -bool PatternFusionBasePassImpl::MatchAllEdges(const size_t &input_size, const std::unique_ptr &usage_flags) { - for (size_t i = 0; i != input_size; i++) { - if (!usage_flags[i]) { - return false; - } - } - return true; -} - -void PatternFusionBasePassImpl::GetInDataAnchors(const ge::NodePtr &node, - std::vector &in_anchor_vec) { - for (const auto in_anchor : node->GetAllInDataAnchors()) { - if ((in_anchor == nullptr) || (in_anchor->GetPeerOutAnchor() == nullptr) || - (in_anchor->GetPeerOutAnchor()->GetOwnerNode() == nullptr)) { - continue; - } - in_anchor_vec.push_back(in_anchor); - } -} - -void PatternFusionBasePassImpl::GetOutDataAnchors(const ge::NodePtr &node, - std::vector &out_anchor_vec) { - for (const auto out_anchor : node->GetAllOutDataAnchors()) { - if (out_anchor == nullptr || out_anchor->GetPeerInDataNodesSize() == 0) { - continue; - } - out_anchor_vec.emplace_back(out_anchor); - } -} - -bool PatternFusionBasePassImpl::GetMatchOutputNodes(const ge::ComputeGraph &graph, const FusionPattern &pattern, - std::vector &matched_output_nodes) const { - const FusionPattern::OpDescPtr output_op_desc = pattern.GetOutput(); - if (output_op_desc == nullptr) { - GELOGW("[Get][Output] output op_desc is null, pattern matching failed"); - return false; - } - - NodeMapInfoPtr node_map_info = nullptr; - // get nodes by type from node - if (GraphPassUtil::GetOpTypeMapToGraph(node_map_info, graph) == SUCCESS) { - for (auto &OutOpType : output_op_desc->types) { - const auto iter = node_map_info->node_type_map->find(OutOpType); - if (iter != node_map_info->node_type_map->end()) { - for (auto iter_node = iter->second.cbegin(); iter_node != iter->second.cend(); iter_node++) { - const ge::NodePtr node_ptr = iter_node->second; - - if (node_ptr->GetInDataNodes().empty() && node_ptr->GetOutAllNodes().empty()) { - continue; - } - if (ge::NodeUtils::GetNodeType(*node_ptr) == OutOpType && - IsOpFusible(node_ptr->GetOpDesc(), output_op_desc)) { - matched_output_nodes.push_back(node_ptr); - } - } - } - } - } else { // for each graph to find type - for (ge::NodePtr &n : graph.GetDirectNode()) { - if (IsOpTypeExist(ge::NodeUtils::GetNodeType(*n), output_op_desc->types) && - IsOpFusible(n->GetOpDesc(), output_op_desc)) { - matched_output_nodes.push_back(n); - } - } - } - - if (matched_output_nodes.empty()) { - return false; - } - return true; -} - -const std::vector& PatternFusionBasePassImpl::GetActualFusedNodes() const { - return actual_fused_nodes_; -} - -void PatternFusionBasePassImpl::SetActualFusedNodes(const std::vector &fused_nodes) { - actual_fused_nodes_ = fused_nodes; -} - -bool PatternFusionBasePassImpl::IsOpFusible(const ge::OpDescPtr &op_desc, const FusionPattern::OpDescPtr &pattern_desc) -{ - if (op_desc == nullptr || pattern_desc == nullptr) { - return false; - } - if (pattern_desc->allow_dumpable) { - return true; - } - bool is_dump_able = false; - (void)ge::AttrUtils::GetBool(op_desc, kAttrDumpAble, is_dump_able); - return !is_dump_able; -} -} diff --git a/register/graph_optimizer/graph_fusion/pattern_fusion_base_pass_impl.h b/register/graph_optimizer/graph_fusion/pattern_fusion_base_pass_impl.h deleted file mode 100644 index f09c2a67d2b045a1f4827d7248b6b3f48cfa9063..0000000000000000000000000000000000000000 --- a/register/graph_optimizer/graph_fusion/pattern_fusion_base_pass_impl.h +++ /dev/null @@ -1,134 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef FE_PATTERN_FUSION_BASE_PASS_IMPL_H -#define FE_PATTERN_FUSION_BASE_PASS_IMPL_H - -#include -#include -#include -#include -#include -#include -#include "graph/debug/ge_log.h" -#include "common/opskernel/ops_kernel_info_store.h" -#include "register/graph_optimizer/graph_fusion/fusion_pattern.h" - -namespace fe { -using OpDesc = FusionPattern::OpDesc; -using Mapping = std::map, std::vector, CmpKey>; -using Mappings = std::vector; -using OpsKernelInfoStorePtr = std::shared_ptr; -struct CandidateAndMapping { - std::vector candidate_nodes; - std::vector candidate_op_descs; - Mapping &mapping; - CandidateAndMapping(Mapping &mapping_param) : mapping(mapping_param) {} -}; - -/** Base pattern impl - * @ingroup FUSION_PASS_GROUP - * @note New virtual methods should be append at the end of this class - */ -class PatternFusionBasePassImpl { - public: - PatternFusionBasePassImpl(); - - virtual ~PatternFusionBasePassImpl(); - - const std::vector &GetPatterns(); - - void GetPatterns(std::vector &patterns); - - const std::vector &GetInnerPatterns(); - - void GetInnerPatterns(std::vector &inner_patterns); - - void SetPatterns(const std::vector &patterns); - - void SetInnerPatterns(const std::vector &inner_patterns); - - void SetOpsKernelInfoStore(const OpsKernelInfoStorePtr &ops_kernel_info_store_ptr); - - PatternFusionBasePassImpl &operator=(const PatternFusionBasePassImpl &) = delete; - - PatternFusionBasePassImpl(const PatternFusionBasePassImpl &another_pattern_fusion) = delete; - - bool CheckOpSupported(const ge::OpDescPtr &op_desc_ptr) const; - - bool CheckOpSupported(const ge::NodePtr &node) const; - - bool CheckAccuracySupported(const ge::NodePtr &node) const; - - static bool IsNodesExist(const ge::NodePtr ¤t_node, const std::vector &nodes); - - static bool IsMatched(const std::shared_ptr op_desc, const ge::NodePtr node, const Mapping &mapping); - - void DumpMappings(const FusionPattern &pattern, const Mappings &mappings) const; - - static bool IsOpTypeExist(const std::string &type, const std::vector &types); - - bool MatchFromOutput(const ge::NodePtr output_node, const std::shared_ptr output_op_desc, - Mapping &mapping) const; - - bool GetMatchOutputNodes(const ge::ComputeGraph &graph, const FusionPattern &pattern, - std::vector &matched_output_nodes) const; - - const std::vector& GetActualFusedNodes() const; - - void SetActualFusedNodes(const std::vector &fused_nodes); - - private: - std::vector patterns_; - - std::vector inner_patterns_; - - OpsKernelInfoStorePtr ops_kernel_info_store_ptr_; - - std::vector actual_fused_nodes_; - - bool GetSortedInAnchors(const ge::NodePtr &node, const std::string &op_id, - std::vector &in_anchors) const; - - void MatchOneOutputNode(const ge::NodePtr &output_node, - const std::vector &outputs_desc, - size_t &out_idx, const std::unique_ptr &usage_flags, - CandidateAndMapping &cand) const; - - bool MatchFromOutput(CandidateAndMapping &cand) const; - - void MatchFuzzyOutputs(const ge::NodePtr &node, const FusionPattern::OpDescPtr &op_desc, - size_t &out_idx, const std::unique_ptr &usage_flags, - CandidateAndMapping &cand) const; - - bool MatchOutputs(CandidateAndMapping &cand) const; - - void UpdateCandidates(const CandidateAndMapping &temp_cand, CandidateAndMapping &cand) const; - - void AddCandidateQueue(const FusionPattern::OpDescPtr &op_desc, const ge::NodePtr &node, - CandidateAndMapping &cand) const; - - bool MatchAsInput(std::vector &candidate_nodes, - std::vector &candidate_op_descs, Mapping &mapping) const; - - static bool MatchAllEdges(const size_t &input_size, const std::unique_ptr &usage_flags); - - static void GetInDataAnchors(const ge::NodePtr &node, std::vector &in_anchor_vec); - - static void GetOutDataAnchors(const ge::NodePtr &node, std::vector &out_anchor_vec); - - static bool IsOpFusible(const ge::OpDescPtr &op_desc, const FusionPattern::OpDescPtr &pattern_desc); - - static bool VerifyInputDescNodes(const ge::NodePtr &input_node, const std::shared_ptr &input_desc, - const Mapping &mapping); -}; - -} // namespace fe - -#endif // FE_PATTERN_FUSION_BASE_PASS_H diff --git a/register/hidden_input_func_registry.cc b/register/hidden_input_func_registry.cc deleted file mode 100644 index f18ba7e02a05117d336474c08cac26585dc2dc55..0000000000000000000000000000000000000000 --- a/register/hidden_input_func_registry.cc +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/hidden_input_func_registry.h" -#include -#include "common/ge_common/debug/ge_log.h" -namespace ge { -GetHiddenAddr HiddenInputFuncRegistry::FindHiddenInputFunc(const HiddenInputType input_type) { - const auto &iter = type_to_funcs_.find(input_type); - if (iter != type_to_funcs_.end()) { - return iter->second; - } - GELOGW("Hidden input func not found, type:[%d].", static_cast(input_type)); - return nullptr; -} - -void HiddenInputFuncRegistry::Register(const HiddenInputType input_type, const GetHiddenAddr func) { - type_to_funcs_[input_type] = func; -} -HiddenInputFuncRegistry &HiddenInputFuncRegistry::GetInstance() { - static HiddenInputFuncRegistry registry; - return registry; -} - -HiddenInputFuncRegister::HiddenInputFuncRegister(const HiddenInputType input_type, const GetHiddenAddr func) { - HiddenInputFuncRegistry::GetInstance().Register(input_type, func); -} -} // namespace ge diff --git a/register/hidden_inputs_func_registry.cc b/register/hidden_inputs_func_registry.cc deleted file mode 100644 index 9a47ec94bc796d9c4075130b45e527ed08e5d045..0000000000000000000000000000000000000000 --- a/register/hidden_inputs_func_registry.cc +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/hidden_inputs_func_registry.h" -#include -#include "common/ge_common/debug/ge_log.h" -namespace ge { -GetHiddenAddrs HiddenInputsFuncRegistry::FindHiddenInputsFunc(const HiddenInputsType input_type) { - const auto &iter = type_to_funcs_.find(input_type); - if (iter != type_to_funcs_.end()) { - return iter->second; - } - GELOGW("Hidden input func not found, type:[%d].", static_cast(input_type)); - return nullptr; -} - -void HiddenInputsFuncRegistry::Register(const HiddenInputsType input_type, const GetHiddenAddrs func) { - GELOGD("Hidden input func reg, type:[%d].", static_cast(input_type)); - type_to_funcs_[input_type] = func; -} -HiddenInputsFuncRegistry &HiddenInputsFuncRegistry::GetInstance() { - static HiddenInputsFuncRegistry registry; - return registry; -} - -HiddenInputsFuncRegister::HiddenInputsFuncRegister(const HiddenInputsType input_type, const GetHiddenAddrs func) { - HiddenInputsFuncRegistry::GetInstance().Register(input_type, func); -} -} // namespace ge diff --git a/register/host_cpu_context.cc b/register/host_cpu_context.cc deleted file mode 100644 index 7b12ff7b85cbf50b7c4c579b4f8541f666b891fe..0000000000000000000000000000000000000000 --- a/register/host_cpu_context.cc +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/host_cpu_context.h" - -namespace ge { -class HostCpuContext::Impl { - public: - Impl() = default; - ~Impl() = default; -}; -} // namespace ge diff --git a/register/infer_axis_slice_registry.cc b/register/infer_axis_slice_registry.cc deleted file mode 100644 index 1aa63ce561a279f397781b457a006a20db7a2ea5..0000000000000000000000000000000000000000 --- a/register/infer_axis_slice_registry.cc +++ /dev/null @@ -1,24 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/infer_axis_slice_registry.h" -#include "external/graph/types.h" -#include "graph/operator_factory_impl.h" - -namespace ge { -InferAxisTypeInfoFuncRegister::InferAxisTypeInfoFuncRegister(const char_t *const operator_type, - const InferAxisTypeInfoFunc &infer_axis_type_info_func) { - (void)OperatorFactoryImpl::RegisterInferAxisTypeInfoFunc(operator_type, infer_axis_type_info_func); -} - -InferAxisSliceFuncRegister::InferAxisSliceFuncRegister(const char_t *const operator_type, - const InferAxisSliceFunc &infer_axis_slice_func) { - (void)OperatorFactoryImpl::RegisterInferAxisSliceFunc(operator_type, infer_axis_slice_func); -} -} // namespace ge diff --git a/register/infer_data_slice_registry.cc b/register/infer_data_slice_registry.cc deleted file mode 100644 index 373f7d184724bb6d261892776b5368dbcfa3f36a..0000000000000000000000000000000000000000 --- a/register/infer_data_slice_registry.cc +++ /dev/null @@ -1,19 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/infer_data_slice_registry.h" -#include "external/graph/types.h" -#include "graph/operator_factory_impl.h" - -namespace ge { -InferDataSliceFuncRegister::InferDataSliceFuncRegister(const char_t *const operator_type, - const InferDataSliceFunc &infer_data_slice_func) { - (void)OperatorFactoryImpl::RegisterInferDataSliceFunc(operator_type, infer_data_slice_func); -} -} // namespace ge diff --git a/register/kernel_launch_info.cc b/register/kernel_launch_info.cc deleted file mode 100644 index bd4f83f9ac34ec9c5e6ddc0227b197863e9d541e..0000000000000000000000000000000000000000 --- a/register/kernel_launch_info.cc +++ /dev/null @@ -1,115 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/kernel_launch_info.h" -#include "graph/debug/ge_util.h" -#include "kernel_launch_info_impl.h" -#include "common/checker.h" - -namespace ge { -KernelLaunchInfo::~KernelLaunchInfo() {} -KernelLaunchInfo::KernelLaunchInfo(const KernelLaunchInfo &other) { - impl_ = ComGraphMakeUnique(); - if ((other.impl_ != nullptr) && (impl_ != nullptr)) { - *impl_ = *other.impl_; - } -} -KernelLaunchInfo::KernelLaunchInfo(KernelLaunchInfo &&other) noexcept { - impl_ = std::move(other.impl_); -} -KernelLaunchInfo &KernelLaunchInfo::operator=(const KernelLaunchInfo &other) { - if (&other != this) { - impl_ = ComGraphMakeUnique(); - if ((other.impl_ != nullptr) && (impl_ != nullptr)) { - *impl_ = *other.impl_; - } - } - return *this; -} -KernelLaunchInfo &KernelLaunchInfo::operator=(KernelLaunchInfo &&other) noexcept { - if (&other != this) { - impl_ = std::move(other.impl_); - } - return *this; -} - -KernelLaunchInfo::KernelLaunchInfo(KernelLaunchInfoImplPtr &&impl) : impl_(std::move(impl)) {} - -KernelLaunchInfo KernelLaunchInfo::LoadFromData(const gert::ExeResGenerationContext *context, - const std::vector &data) { - return KernelLaunchInfo(KernelLaunchInfoImpl::LoadFromData(context, data)); -} -KernelLaunchInfo KernelLaunchInfo::CreateAicpuKfcTask(const gert::ExeResGenerationContext *context, - const char *so_name, const char *kernel_name) { - return KernelLaunchInfo(KernelLaunchInfoImpl::CreateAicpuKfcTask(context, so_name, kernel_name)); -} - -KernelLaunchInfo KernelLaunchInfo::CreateHcomRecordTask(const gert::ExeResGenerationContext *context, - const char *group_name) { - return KernelLaunchInfo(KernelLaunchInfoImpl::CreateHcomRecordTask(context, group_name)); -} -KernelLaunchInfo KernelLaunchInfo::CreateHcomWaitTask(const gert::ExeResGenerationContext *context, - const char *group_name) { - return KernelLaunchInfo(KernelLaunchInfoImpl::CreateHcomWaitTask(context, group_name)); -} -std::vector KernelLaunchInfo::Serialize() { - if (impl_ != nullptr) { - return impl_->Serialize(); - } - return {}; -} -uint32_t KernelLaunchInfo::GetStreamId() const { - if (impl_ != nullptr) { - return impl_->GetStreamId(); - } - return std::numeric_limits::max(); -} - -void KernelLaunchInfo::SetStreamId(uint32_t stream_id) { - if (impl_ != nullptr) { - impl_->SetStreamId(stream_id); - } -} - -uint32_t KernelLaunchInfo::GetBlockDim() const { - if (impl_ != nullptr) { - return impl_->GetBlockDim(); - } - return std::numeric_limits::max(); -} - -graphStatus KernelLaunchInfo::SetBlockDim(uint32_t block_dim) { - GE_ASSERT_NOTNULL(impl_); - return impl_->SetBlockDim(block_dim); -} - -const char *KernelLaunchInfo::GetArgsFormat() const { - if (impl_ != nullptr) { - return impl_->GetArgsFormat(); - } - return nullptr; -} -graphStatus KernelLaunchInfo::SetArgsFormat(const char *args_format) { - GE_ASSERT_NOTNULL(impl_); - return impl_->SetArgsFormat(args_format); -} - -const char *KernelLaunchInfo::GetSoName() const { - if (impl_ != nullptr) { - return impl_->GetSoName(); - } - return nullptr; -} -const char *KernelLaunchInfo::GetKernelName() const { - if (impl_ != nullptr) { - return impl_->GetKernelName(); - } - return nullptr; -} -} \ No newline at end of file diff --git a/register/kernel_launch_info_impl.cc b/register/kernel_launch_info_impl.cc deleted file mode 100644 index 4c945bbe1f49e645bb09849d657a7575c69ce49d..0000000000000000000000000000000000000000 --- a/register/kernel_launch_info_impl.cc +++ /dev/null @@ -1,158 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "kernel_launch_info_impl.h" -#include "graph/debug/ge_util.h" -#include "common/checker.h" -#include "runtime/rt_model.h" -#include "ge/framework/common/taskdown_common.h" - -namespace ge { -namespace { -bool IsAllKernel(const domi::TaskDef &task_def) { - return (task_def.type() == RT_MODEL_TASK_ALL_KERNEL) || (task_def.type() == RT_MODEL_TASK_VECTOR_ALL_KERNEL); -} -} -KernelLaunchInfoImplPtr KernelLaunchInfoImpl::LoadFromData(const gert::ExeResGenerationContext *context, - const std::vector &data) { - GE_ASSERT_NOTNULL(context); - auto impl_ptr = ComGraphMakeUnique(); - GE_ASSERT_NOTNULL(impl_ptr); - GE_ASSERT_TRUE(impl_ptr->task_def_.ParseFromArray(data.data(), data.size())); - impl_ptr->context_ = const_cast(context); - return impl_ptr; -} -KernelLaunchInfoImplPtr KernelLaunchInfoImpl::CreateAicpuKfcTask(const gert::ExeResGenerationContext *context, - const char *so_name, const char *kernel_name) { - GE_ASSERT_NOTNULL(context); - auto impl_ptr = ComGraphMakeUnique(); - GE_ASSERT_NOTNULL(impl_ptr); - impl_ptr->context_ = const_cast(context); - impl_ptr->task_def_.set_type(RT_MODEL_TASK_KERNEL); - auto kernel_def = impl_ptr->task_def_.mutable_kernel(); - GE_ASSERT_NOTNULL(kernel_def); - kernel_def->set_so_name(so_name); - kernel_def->set_kernel_name(kernel_name); - auto kernel_context = kernel_def->mutable_context(); - GE_ASSERT_NOTNULL(kernel_context); - kernel_context->set_kernel_type(static_cast(ccKernelType::AI_CPU_KFC)); - kernel_context->set_op_index(context->GetOpId()); - return impl_ptr; -} - -KernelLaunchInfoImplPtr KernelLaunchInfoImpl::CreateHcomRecordTask(const gert::ExeResGenerationContext *context, - const char *group_name) { - GE_ASSERT_NOTNULL(context); - GE_ASSERT_NOTNULL(group_name); - auto impl_ptr = ComGraphMakeUnique(); - GE_ASSERT_NOTNULL(impl_ptr); - impl_ptr->context_ = const_cast(context); - impl_ptr->task_def_.set_id(context->GetOpId()); - impl_ptr->task_def_.set_notify_id(UINT32_MAX); - impl_ptr->task_def_.set_type(RT_MODEL_TASK_NOTIFY_RECORD); - impl_ptr->task_def_.set_private_def(group_name); - return impl_ptr; -} - -KernelLaunchInfoImplPtr KernelLaunchInfoImpl::CreateHcomWaitTask(const gert::ExeResGenerationContext *context, - const char *group_name) { - GE_ASSERT_NOTNULL(context); - GE_ASSERT_NOTNULL(group_name); - auto impl_ptr = ComGraphMakeUnique(); - GE_ASSERT_NOTNULL(impl_ptr); - impl_ptr->context_ = const_cast(context); - impl_ptr->task_def_.set_id(context->GetOpId()); - impl_ptr->task_def_.set_notify_id(UINT32_MAX); - impl_ptr->task_def_.set_type(RT_MODEL_TASK_NOTIFY_WAIT); - impl_ptr->task_def_.set_private_def(group_name); - return impl_ptr; -} - -std::vector KernelLaunchInfoImpl::Serialize() { - auto buffer_size = task_def_.ByteSizeLong(); - std::vector buffer(buffer_size, 0); - GE_ASSERT_TRUE(task_def_.SerializeToArray(buffer.data(), buffer_size)); - return buffer; -} -uint32_t KernelLaunchInfoImpl::GetStreamId() const { - return task_def_.stream_id(); -} -void KernelLaunchInfoImpl::SetStreamId(uint32_t stream_id) { - task_def_.set_stream_id(stream_id); -} - -uint32_t KernelLaunchInfoImpl::GetBlockDim() const { - uint32_t block_dim = 0; - if (task_def_.type() == RT_MODEL_TASK_KERNEL) { - block_dim = task_def_.kernel().block_dim(); - } else if (IsAllKernel(task_def_)) { - block_dim = task_def_.kernel_with_handle().block_dim(); - } else { - GELOGE(FAILED, "Only aicpu and aicore task has block_dim, but get[%d]", - task_def_.type()); - } - return block_dim; -} - -graphStatus KernelLaunchInfoImpl::SetBlockDim(uint32_t block_dim) { - if (task_def_.type() == RT_MODEL_TASK_KERNEL) { - auto kernel_def = task_def_.mutable_kernel(); - GE_ASSERT_NOTNULL(kernel_def); - kernel_def->set_block_dim(block_dim); - } else if (IsAllKernel(task_def_)) { - auto kernel_with_handle = task_def_.mutable_kernel_with_handle(); - GE_ASSERT_NOTNULL(kernel_with_handle); - kernel_with_handle->set_block_dim(block_dim); - } else { - // 报错 - GE_ASSERT_TRUE(false, "Only aicpu and aicore task can set args format, but get[%d]", - task_def_.type()); - } - return SUCCESS; -} - -const char *KernelLaunchInfoImpl::GetArgsFormat() const { - domi::KernelContext kernel_context; - if (task_def_.type() == RT_MODEL_TASK_KERNEL) { - return task_def_.kernel().context().args_format().c_str(); - } - if (IsAllKernel(task_def_)) { - return task_def_.kernel_with_handle().context().args_format().c_str(); - } - GELOGE(FAILED, "Only aicpu and aicore task has args format, but get[%d]", - task_def_.type()); - return nullptr; -} -graphStatus KernelLaunchInfoImpl::SetArgsFormat(const char *args_format) { - GE_ASSERT_NOTNULL(args_format); - domi::KernelContext *kernel_context = nullptr; - if (task_def_.type() == RT_MODEL_TASK_KERNEL) { - auto kernel_def = task_def_.mutable_kernel(); - GE_ASSERT_NOTNULL(kernel_def); - kernel_context = kernel_def->mutable_context(); - } else if (IsAllKernel(task_def_)) { - auto kernel_with_handle = task_def_.mutable_kernel_with_handle(); - GE_ASSERT_NOTNULL(kernel_with_handle); - kernel_context = kernel_with_handle->mutable_context(); - } else { - GELOGE(FAILED, "Only aicpu and aicore task can set args format, but get[%d]", - task_def_.type()); - } - GE_ASSERT_NOTNULL(kernel_context); - kernel_context->set_args_format(args_format); - return SUCCESS; -} - -const char *KernelLaunchInfoImpl::GetSoName() const { - return task_def_.kernel().so_name().c_str(); -} -const char *KernelLaunchInfoImpl::GetKernelName() const { - return task_def_.kernel().kernel_name().c_str(); -} -} \ No newline at end of file diff --git a/register/kernel_launch_info_impl.h b/register/kernel_launch_info_impl.h deleted file mode 100644 index 5f2ec58fedcae0dd40f96226b8f17afb0c21572d..0000000000000000000000000000000000000000 --- a/register/kernel_launch_info_impl.h +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_REGISTER_KERNEL_LAUNCH_INFO_IMPL_H -#define METADEF_REGISTER_KERNEL_LAUNCH_INFO_IMPL_H - -#include "graph/kernel_launch_info.h" -#include "proto/task.pb.h" - -namespace ge { -class KernelLaunchInfoImpl { - public: - ~KernelLaunchInfoImpl() = default; - KernelLaunchInfoImpl() = default; - static KernelLaunchInfoImplPtr LoadFromData(const gert::ExeResGenerationContext *context, - const std::vector &data); - static KernelLaunchInfoImplPtr CreateAicpuKfcTask(const gert::ExeResGenerationContext *context, - const char *so_name, const char *kernel_name); - static KernelLaunchInfoImplPtr CreateHcomRecordTask(const gert::ExeResGenerationContext *context, - const char *group_name); - static KernelLaunchInfoImplPtr CreateHcomWaitTask(const gert::ExeResGenerationContext *context, - const char *group_name); - std::vector Serialize(); - uint32_t GetStreamId() const; - void SetStreamId(uint32_t stream_id); - uint32_t GetBlockDim() const; - graphStatus SetBlockDim(uint32_t block_dim); - const char *GetArgsFormat() const; - graphStatus SetArgsFormat(const char *args_format); - const char *GetSoName() const; - const char *GetKernelName() const; - private: - domi::TaskDef task_def_; - gert::ExeResGenerationContext *context_; -}; -} -#endif // METADEF_REGISTER_KERNEL_LAUNCH_INFO_IMPL_H \ No newline at end of file diff --git a/register/kernel_register_data.cc b/register/kernel_register_data.cc deleted file mode 100644 index 6074819ee9ab4e4ae50c575ad0a3e8619f14e251..0000000000000000000000000000000000000000 --- a/register/kernel_register_data.cc +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "kernel_register_data.h" - -namespace gert { -namespace { -ge::graphStatus NullCreator(const ge::FastNode *node, KernelContext *context) { - (void) node; - (void) context; - return ge::GRAPH_SUCCESS; -} -} // namespace -KernelRegisterData::KernelRegisterData(const ge::char_t *kernel_type) : kernel_type_(kernel_type) { - funcs_.outputs_creator = NullCreator; - funcs_.trace_printer = nullptr; - critical_section_ = ""; - funcs_.profiling_info_filler = nullptr; - funcs_.data_dump_info_filler = nullptr; - funcs_.exception_dump_info_filler = nullptr; -} -} // namespace gert diff --git a/register/kernel_register_data.h b/register/kernel_register_data.h deleted file mode 100644 index 05217da573f7dc049bf723bd5dd024e4e06ab2e9..0000000000000000000000000000000000000000 --- a/register/kernel_register_data.h +++ /dev/null @@ -1,38 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_REGISTER_KERNEL_REGISTER_DATA_H_ -#define METADEF_CXX_REGISTER_KERNEL_REGISTER_DATA_H_ -#include -#include "register/kernel_registry.h" -namespace gert { -class KernelRegisterData { - public: - explicit KernelRegisterData(const ge::char_t *kernel_type); - - KernelRegistry::KernelFuncs &GetFuncs() { - return funcs_; - } - - const std::string &GetKernelType() const { - return kernel_type_; - } - - std::string &GetCriticalSection() { - return critical_section_; - } - - private: - std::string critical_section_; - std::string kernel_type_; - KernelRegistry::KernelFuncs funcs_; -}; -} // namespace gert - -#endif // METADEF_CXX_REGISTER_KERNEL_REGISTER_DATA_H_ diff --git a/register/kernel_registry_impl.cc b/register/kernel_registry_impl.cc deleted file mode 100644 index 8b6226a1dbdff35e53f631bd0585726cce414a2a..0000000000000000000000000000000000000000 --- a/register/kernel_registry_impl.cc +++ /dev/null @@ -1,115 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/kernel_registry_impl.h" -#include -#include "graph/debug/ge_log.h" -#include "kernel_register_data.h" -namespace gert { -namespace { -std::shared_ptr g_user_defined_registry = nullptr; -} // namespace - -KernelRegistry &KernelRegistry::GetInstance() { - if (g_user_defined_registry != nullptr) { - return *g_user_defined_registry; - } else { - return KernelRegistryImpl::GetInstance(); - } -} -void KernelRegistry::ReplaceKernelRegistry(std::shared_ptr registry) { - g_user_defined_registry = std::move(registry); -} - -KernelRegistryImpl &KernelRegistryImpl::GetInstance() { - static KernelRegistryImpl registry; - return registry; -} -void KernelRegistryImpl::RegisterKernel(std::string kernel_type, KernelInfo kernel_infos) { - kernel_infos_[std::move(kernel_type)] = std::move(kernel_infos); -} - -const KernelRegistry::KernelFuncs *KernelRegistryImpl::FindKernelFuncs(const std::string &kernel_type) const { - const auto iter = kernel_infos_.find(kernel_type); - if (iter == kernel_infos_.end()) { - return nullptr; - } - return &iter->second.func; -} -const KernelRegistry::KernelInfo *KernelRegistryImpl::FindKernelInfo(const std::string &kernel_type) const { - const auto iter = kernel_infos_.find(kernel_type); - if (iter == kernel_infos_.end()) { - return nullptr; - } - return &iter->second; -} -const std::unordered_map &KernelRegistryImpl::GetAll() const { - return kernel_infos_; -} - -KernelRegisterV2::KernelRegisterV2(const char *kernel_type) - : register_data_(new(std::nothrow) KernelRegisterData(kernel_type)) {} -KernelRegisterV2::~KernelRegisterV2() = default; -KernelRegisterV2 &KernelRegisterV2::RunFunc(KernelRegistry::KernelFunc func) { - if (register_data_ != nullptr) { - register_data_->GetFuncs().run_func = func; - } - return *this; -} -KernelRegisterV2 &KernelRegisterV2::ConcurrentCriticalSectionKey(const std::string &critical_section_key) { - if (register_data_ != nullptr) { - register_data_->GetCriticalSection() = critical_section_key; - } - return *this; -} -KernelRegisterV2 &KernelRegisterV2::OutputsCreator(KernelRegistry::CreateOutputsFunc func) { - if (register_data_ != nullptr) { - register_data_->GetFuncs().outputs_creator = func; - } - return *this; -} -KernelRegisterV2 &KernelRegisterV2::TracePrinter(KernelRegistry::TracePrinter func) { - if (register_data_ != nullptr) { - register_data_->GetFuncs().trace_printer = func; - } - return *this; -} - -KernelRegisterV2 &KernelRegisterV2::ProfilingInfoFiller(KernelRegistry::ProfilingInfoFiller func) { - if (register_data_ != nullptr) { - register_data_->GetFuncs().profiling_info_filler = func; - } - return *this; -} - -KernelRegisterV2 &KernelRegisterV2::DataDumpInfoFiller(KernelRegistry::DataDumpInfoFiller func) { - if (register_data_ != nullptr) { - register_data_->GetFuncs().data_dump_info_filler = func; - } - return *this; -} - -KernelRegisterV2 &KernelRegisterV2::ExceptionDumpInfoFiller(KernelRegistry::ExceptionDumpInfoFiller func) { - if (register_data_ != nullptr) { - register_data_->GetFuncs().exception_dump_info_filler = func; - } - return *this; -} - -KernelRegisterV2::KernelRegisterV2(const KernelRegisterV2 &other) : register_data_(nullptr) { - const auto register_data = other.register_data_.get(); - if (register_data == nullptr) { - GE_LOGE("The register_data_ in register object is nullptr, failed to register funcs"); - return; - } - GELOGD("GERT kernel type %s registered", register_data->GetKernelType().c_str()); - KernelRegistry::GetInstance().RegisterKernel(register_data->GetKernelType(), - {register_data->GetFuncs(), register_data->GetCriticalSection()}); -} -} // namespace gert diff --git a/register/node_converter_registry.cc b/register/node_converter_registry.cc deleted file mode 100644 index 6e2b41cdb89c6d78ec79cbcfa2308d688f8ee25e..0000000000000000000000000000000000000000 --- a/register/node_converter_registry.cc +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/node_converter_registry.h" -#include "common/hyper_status.h" - -namespace gert { -NodeConverterRegistry &NodeConverterRegistry::GetInstance() { - static NodeConverterRegistry registry; - return registry; -} - -NodeConverterRegistry::NodeConverter NodeConverterRegistry::FindNodeConverter(const string &func_name) { - auto data = FindRegisterData(func_name); - if (data == nullptr) { - return nullptr; - } - return data->converter; -} -void NodeConverterRegistry::RegisterNodeConverter(const std::string &func_name, NodeConverter func) { - names_to_register_data_[func_name] = {func, -1}; -} -const NodeConverterRegistry::ConverterRegisterData *NodeConverterRegistry::FindRegisterData( - const string &func_name) const { - auto iter = names_to_register_data_.find(func_name); - if (iter == names_to_register_data_.end()) { - return nullptr; - } - return &iter->second; -} -void NodeConverterRegistry::Register(const string &func_name, - const NodeConverterRegistry::ConverterRegisterData &data) { - names_to_register_data_[func_name] = data; -} -NodeConverterRegister::NodeConverterRegister(const char *lower_func_name, - NodeConverterRegistry::NodeConverter func) noexcept { - NodeConverterRegistry::GetInstance().Register(lower_func_name, {func, -1}); -} -NodeConverterRegister::NodeConverterRegister(const char *lower_func_name, int32_t require_placement, - NodeConverterRegistry::NodeConverter func) noexcept { - NodeConverterRegistry::GetInstance().Register(lower_func_name, {func, require_placement}); -} -} // namespace gert diff --git a/register/op_binary_resource_manager.cc b/register/op_binary_resource_manager.cc deleted file mode 100644 index 94129cf28a761fdcc57fb3176a04fc96bb1b6c97..0000000000000000000000000000000000000000 --- a/register/op_binary_resource_manager.cc +++ /dev/null @@ -1,204 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/op_binary_resource_manager.h" -#include "common/ge_common/debug/ge_log.h" -#include "common/checker.h" - -namespace nnopbase { -namespace { -ge::graphStatus GetStr(const std::tuple &input, std::string &str) -{ - const uint8_t *start = std::get<0U>(input); - const uint8_t *end = std::get<1U>(input); - if ((end < start) || (start == nullptr) || (end == nullptr)) { - GELOGE(ge::GRAPH_PARAM_INVALID, "Parse json failed, end is %p, start is %p!", end, start); - return ge::GRAPH_PARAM_INVALID; - } - const size_t len = end - start; - str = std::string((const char *)start, len); - GELOGD("Parse str addr is %p, len is %zu, %s.", start, len, str.c_str()); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus ParseJson(const std::tuple &input, nlohmann::json &res) -{ - std::string jsonStr; - GE_ASSERT_GRAPH_SUCCESS(GetStr(input, jsonStr)); - try { - res = nlohmann::json::parse(jsonStr); - } catch (const nlohmann::json::exception &e) { - GELOGE(ge::GRAPH_PARAM_INVALID, "Parse json failed, resion %s, json info %s.", e.what(), jsonStr.c_str()); - return ge::GRAPH_PARAM_INVALID; - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus ParseBinary(const std::tuple &input, Binary &binaryInfo) -{ - const uint8_t *start = std::get<0U>(input); - const uint8_t *end = std::get<1U>(input); - if ((end < start) || (start == nullptr) || (end == nullptr)) { - GELOGE(ge::GRAPH_PARAM_INVALID, "Parse json failed, end is %p, start is %p!", end, start); - return ge::GRAPH_PARAM_INVALID; - } - binaryInfo.len = end - start; - binaryInfo.content = start; - GELOGD("Parse binary info, addr is %p, len is %us.", binaryInfo.content, binaryInfo.len); - return ge::GRAPH_SUCCESS; -} -} // namepsace - -void OpBinaryResourceManager::AddOpFuncHandle(const ge::AscendString &opType, - const std::vector &opResourceHandle) -{ - const std::lock_guard lk(mutex_); - const auto &it = resourceHandle_.find(opType); - if (it != resourceHandle_.end()) { - return; - } - GELOGI("Add op %s func handle, num is %zu!", opType.GetString(), opResourceHandle.size()); - for (auto func : opResourceHandle) { - (void)resourceHandle_[opType].emplace_back(func); - } -} - -// 首个信息一定存在,是算子描述json,后续是成对的二进制信息 -ge::graphStatus OpBinaryResourceManager::AddBinary(const ge::AscendString &opType, - const std::vector> &opBinary) -{ - const std::lock_guard lk(mutex_); - const auto &it = opBinaryDesc_.find(opType); - if (it != opBinaryDesc_.end()) { - return ge::GRAPH_SUCCESS; - } - - // 首个信息是op的描述信息 - if (opBinary.size() >= 1U) { - nlohmann::json opDesc; - GE_ASSERT_GRAPH_SUCCESS(ParseJson(opBinary[0], opDesc), "Parse op %s json failed!", opType.GetString()); - opBinaryDesc_[opType] = opDesc; - } - - for (size_t i = 1U; i + 1U < opBinary.size(); i += 2U) { // 2 for json & binary - nlohmann::json binaryDesc; - Binary binaryInfo; - GE_ASSERT_GRAPH_SUCCESS(ParseJson(opBinary[i], binaryDesc), "Parse op %s binary json file [%zu] failed!", - opType.GetString(), i / 2U); // 2 for idx - GE_ASSERT_GRAPH_SUCCESS(ParseBinary(opBinary[i + 1U], binaryInfo), "Parse op %s binary file [%zu] failed!", - opType.GetString(), i / 2U); // 2 for idx - - std::string filePath; - try { - filePath = binaryDesc["filePath"].get(); - } catch (const nlohmann::json::exception &e) { - GELOGE(ge::GRAPH_PARAM_INVALID, "Parse op %s json filePath from binary json failed, reason %s.", - opType.GetString(), e.what()); - return ge::GRAPH_PARAM_INVALID; - } - pathToBinary_[filePath.c_str()] = std::tuple(binaryDesc, binaryInfo); - GELOGI("Add op %s binary, filePath %s, bin addr is %p, bin len %u.", opType.GetString(), filePath.c_str(), - binaryInfo.content, binaryInfo.len); - - std::vector keys; - try { - auto supportInfo = binaryDesc["supportInfo"]; - keys = supportInfo["simplifiedKey"].get>(); - } catch (const nlohmann::json::exception &e) { - GELOGW("Get op %s json simplifiedKey from binary json failed, reason %s.", opType.GetString(), e.what()); - } - for (auto key : keys) { - GELOGI("Add op %s binary, simplifiedKey %s, filePath %s, bin addr %p, bin len %u.", opType.GetString(), - key.c_str(), filePath.c_str(), binaryInfo.content, binaryInfo.len); - keyToPath_[key.c_str()] = filePath.c_str(); - } - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus OpBinaryResourceManager::AddRuntimeKB(const ge::AscendString &opType, - const std::vector> &opRuntimeKb) -{ - const std::lock_guard lk(mutex_); - const auto &it = runtimeKb_.find(opType); - if (it != runtimeKb_.end()) { - return ge::GRAPH_SUCCESS; - } - for (const auto &kbInfo : opRuntimeKb) { - std::string kbStr; - GE_ASSERT_GRAPH_SUCCESS(GetStr(kbInfo, kbStr), "Parse op %s runtime kb json file!", opType.GetString()); - (void)runtimeKb_[opType].emplace_back(kbStr.c_str()); - } - GELOGI("Add op %s runtime kb num %zu!", opType.GetString(), runtimeKb_[opType].size()); - return ge::GRAPH_SUCCESS; -} - -const std::map &OpBinaryResourceManager::GetAllOpBinaryDesc() const -{ - const std::lock_guard lk(mutex_); - GELOGI("Get all op binary desc, num is %zu.", opBinaryDesc_.size()); - return opBinaryDesc_; -} - -ge::graphStatus OpBinaryResourceManager::GetOpBinaryDesc(const ge::AscendString &opType, nlohmann::json &binDesc) const -{ - const std::lock_guard lk(mutex_); - const auto &it = opBinaryDesc_.find(opType); - if (it == opBinaryDesc_.end()) { - // 返回错误码表示该optype不存在,但不打印error日志,可以在执行时调用,根据返回判断是否存在静态二进制 - GELOGW("Get op %s json info failed!", opType.GetString()); - return ge::GRAPH_PARAM_INVALID; - } - binDesc = it->second; - GELOGI("Get op %s binary desc.", opType.GetString()); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus OpBinaryResourceManager::GetOpBinaryDescByPath(const ge::AscendString &jsonFilePath, - std::tuple &binInfo) const -{ - const std::lock_guard lk(mutex_); - const auto &it = pathToBinary_.find(jsonFilePath); - if (it == pathToBinary_.end()) { - GELOGW("Get binaryInfo by json path failed, path is %s.", jsonFilePath.GetString()); - return ge::GRAPH_PARAM_INVALID; - } - binInfo = it->second; - GELOGI("Get binary info, json path is %s.", jsonFilePath.GetString()); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus OpBinaryResourceManager::GetOpBinaryDescByKey(const ge::AscendString &simplifiedKey, - std::tuple &binInfo) const -{ - const std::lock_guard lk(mutex_); - const auto &it = keyToPath_.find(simplifiedKey); - const auto &simplified = simplifiedKey; - if (it == keyToPath_.end()) { - GELOGW("Get binaryInfo by simplified failed, simplified is %s.", simplified.GetString()); - return ge::GRAPH_PARAM_INVALID; - } - GELOGI("Get binary info, simplified is %s.", simplified.GetString()); - return GetOpBinaryDescByPath(it->second, binInfo); -} - -ge::graphStatus OpBinaryResourceManager::GetOpRuntimeKB(const ge::AscendString &opType, - std::vector &kbList) const -{ - const std::lock_guard lk(mutex_); - const auto &it = runtimeKb_.find(opType); - if (it == runtimeKb_.end()) { - GELOGW("Get op %s RuntimeKB info failed.", opType.GetString()); - return ge::GRAPH_PARAM_INVALID; - } - kbList = it->second; - GELOGI("Get op %s RuntimeKB info, num is %zu.", opType.GetString(), kbList.size()); - return ge::GRAPH_SUCCESS; -} -} // nnopbase diff --git a/register/op_ext_calc_param_registry.cc b/register/op_ext_calc_param_registry.cc deleted file mode 100644 index 4665879b4a006486f89c493cac863ca1b874dbe7..0000000000000000000000000000000000000000 --- a/register/op_ext_calc_param_registry.cc +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/op_ext_calc_param_registry.h" -#include "proto/task.pb.h" - -namespace fe { -OpExtCalcParamRegistry &OpExtCalcParamRegistry::GetInstance() { - static OpExtCalcParamRegistry registry; - return registry; -} - -OpExtCalcParamFunc OpExtCalcParamRegistry::FindRegisterFunc(const std::string &op_type) const { - auto iter = names_to_register_func_.find(op_type); - if (iter == names_to_register_func_.end()) { - return nullptr; - } - return iter->second; -} - -void OpExtCalcParamRegistry::Register(const std::string &op_type, OpExtCalcParamFunc const func) { - names_to_register_func_[op_type] = func; -} - -OpExtGenCalcParamRegister::OpExtGenCalcParamRegister(const char *op_type, OpExtCalcParamFunc func) noexcept { - OpExtCalcParamRegistry::GetInstance().Register(op_type, func); -} -} // namespace fe diff --git a/register/op_ext_gentask_registry.cc b/register/op_ext_gentask_registry.cc deleted file mode 100644 index 3a3b0eea7d5b504db37bd2a7007eacb0bc83f3c2..0000000000000000000000000000000000000000 --- a/register/op_ext_gentask_registry.cc +++ /dev/null @@ -1,66 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/op_ext_gentask_registry.h" - -namespace fe { -OpExtGenTaskRegistry &OpExtGenTaskRegistry::GetInstance() { - static OpExtGenTaskRegistry registry; - return registry; -} - -OpExtGenTaskFunc OpExtGenTaskRegistry::FindRegisterFunc(const std::string &op_type) const { - auto iter = names_to_register_func_.find(op_type); - if (iter == names_to_register_func_.end()) { - return nullptr; - } - return iter->second; -} - -SKExtGenTaskFunc OpExtGenTaskRegistry::FindSKRegisterFunc(const std::string &op_type) const { - auto iter = types_to_sk_register_func_.find(op_type); - if (iter == types_to_sk_register_func_.end()) { - return nullptr; - } - return iter->second; -} - -ExtTaskType OpExtGenTaskRegistry::GetExtTaskType(const std::string &op_type) const { - if (aicore_ext_task_ops_.count(op_type) > 0) { - return ExtTaskType::kAicoreTask; - } - return ExtTaskType::kFftsPlusTask; -} - -void OpExtGenTaskRegistry::Register(const std::string &op_type, OpExtGenTaskFunc const func) { - names_to_register_func_[op_type] = func; -} - -void OpExtGenTaskRegistry::RegisterSKFunc(const std::string &op_type, SKExtGenTaskFunc const func) { - types_to_sk_register_func_[op_type] = func; -} - -void OpExtGenTaskRegistry::RegisterAicoreExtTask(const std::string &op_type) { - aicore_ext_task_ops_.emplace(op_type); -} - -OpExtGenTaskRegister::OpExtGenTaskRegister(const char *op_type, OpExtGenTaskFunc func) noexcept { - OpExtGenTaskRegistry::GetInstance().Register(op_type, func); -} - -SKExtGenTaskRegister::SKExtGenTaskRegister(const char *op_type, SKExtGenTaskFunc func) noexcept { - OpExtGenTaskRegistry::GetInstance().RegisterSKFunc(op_type, func); -} - -ExtTaskTypeRegister::ExtTaskTypeRegister(const char *op_type, ExtTaskType type) noexcept { - if (type == ExtTaskType::kAicoreTask) { - OpExtGenTaskRegistry::GetInstance().RegisterAicoreExtTask(op_type); - } -} -} // namespace fe diff --git a/register/op_kernel_registry.cpp b/register/op_kernel_registry.cpp deleted file mode 100644 index ecefed433017ee30513dce6050a6efcc17ffe665..0000000000000000000000000000000000000000 --- a/register/op_kernel_registry.cpp +++ /dev/null @@ -1,96 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/op_kernel_registry.h" -#include -#include -#include "graph/debug/ge_log.h" -#include "graph/debug/ge_util.h" - -namespace ge { -class OpKernelRegistry::OpKernelRegistryImpl { - public: - void RegisterHostCpuOp(const std::string &op_type, const OpKernelRegistry::CreateFn create_fn) { - const std::lock_guard lock(mu_); - create_fns_[op_type] = create_fn; - } - - OpKernelRegistry::CreateFn GetCreateFn(const std::string &op_type) { - const std::lock_guard lock(mu_); - const auto it = create_fns_.find(op_type); - if (it == create_fns_.end()) { - return nullptr; - } - - return it->second; - } - - private: - std::mutex mu_; - std::map create_fns_; -}; - -OpKernelRegistry::OpKernelRegistry() { - impl_ = ge::ComGraphMakeUnique(); -} - -OpKernelRegistry::~OpKernelRegistry() = default; - -OpKernelRegistry& OpKernelRegistry::GetInstance() { - static OpKernelRegistry instance; - return instance; -} - -bool OpKernelRegistry::IsRegistered(const std::string &op_type) const { - if (impl_ == nullptr) { - GELOGE(MEMALLOC_FAILED, - "[Check][Param:impl_]Failed to invoke IsRegistered %s, OpKernelRegistry is not properly initialized", - op_type.c_str()); - return false; - } - - return impl_->GetCreateFn(op_type) != nullptr; -} - -void OpKernelRegistry::RegisterHostCpuOp(const std::string &op_type, const CreateFn create_fn) { - if (impl_ == nullptr) { - GELOGE(MEMALLOC_FAILED, - "[Check][Param:impl_]Failed to register %s, OpKernelRegistry is not properly initialized", - op_type.c_str()); - return; - } - - impl_->RegisterHostCpuOp(op_type, create_fn); -} -std::unique_ptr OpKernelRegistry::CreateHostCpuOp(const std::string &op_type) const { - if (impl_ == nullptr) { - GELOGE(MEMALLOC_FAILED, - "[Check][Param:impl_]Failed to create op for %s, OpKernelRegistry is not properly initialized", - op_type.c_str()); - return nullptr; - } - - const auto create_fn = impl_->GetCreateFn(op_type); - if (create_fn == nullptr) { - GELOGD("Host Cpu op is not registered. op type = %s", op_type.c_str()); - return nullptr; - } - - return std::unique_ptr(create_fn()); -} - -HostCpuOpRegistrar::HostCpuOpRegistrar(const char_t *const op_type, HostCpuOp *(*const create_fn)()) { - if (op_type == nullptr) { - GELOGE(PARAM_INVALID, "[Check][Param:op_type]is null,Failed to register host cpu op"); - return; - } - - OpKernelRegistry::GetInstance().RegisterHostCpuOp(op_type, create_fn); -} -} // namespace ge diff --git a/register/op_lib_register.cc b/register/op_lib_register.cc deleted file mode 100644 index 3f2c4df080c9ad3ada67dbe0703c30f1107d7cd6..0000000000000000000000000000000000000000 --- a/register/op_lib_register.cc +++ /dev/null @@ -1,179 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/op_lib_register_impl.h" -#include "graph/debug/ge_util.h" - -#include "mmpa/mmpa_api.h" -#include "common/ge_common/debug/ge_log.h" -#include "common/ge_common/string_util.h" -#include "common/checker.h" -#include "common/plugin/plugin_manager.h" -#include "graph/utils/file_utils.h" - -namespace { - const std::string custom_so_name = "libcust_opapi.so"; -} - -namespace ge { -OpLibRegister::OpLibRegister(const char_t *vendor_name) : impl_(ComGraphMakeUnique()) { - if (impl_ != nullptr) { - impl_->SetVendorName(vendor_name); - } -} - -OpLibRegister::OpLibRegister(const OpLibRegister &other) { - if (other.impl_ != nullptr) { - OpLibRegistry::GetInstance().RegisterInitFunc(*other.impl_); - } -} - -OpLibRegister::OpLibRegister(OpLibRegister &&other) noexcept { - if (other.impl_ != nullptr) { - OpLibRegistry::GetInstance().RegisterInitFunc(*other.impl_); - } -} - -OpLibRegister::~OpLibRegister() = default; - -OpLibRegister &OpLibRegister::RegOpLibInit(OpLibRegister::OpLibInitFunc func) { - if (impl_ != nullptr) { - impl_->SetInitFunc(func); - } - return *this; -} - -OpLibRegistry &OpLibRegistry::GetInstance() { - static OpLibRegistry instance; - return instance; -} - -const char_t* OpLibRegistry::GetCustomOpLibPath() const { - GELOGI("get op lib path is %s", op_lib_paths_.c_str()); - return op_lib_paths_.c_str(); -} - -void OpLibRegistry::RegisterInitFunc(OpLibRegisterImpl ®ister_impl) { - const std::string vendor_name = register_impl.GetVendorName(); - auto func = register_impl.GetInitFunc(); - const std::lock_guard lk(mu_); - const auto it = vendor_names_set_.insert(vendor_name); - // ignore same vendor_name op lib when register secondly - if (it.second) { - if (func != nullptr) { - vendor_funcs_.emplace_back(vendor_name, func); - } - GELOGI("%s op lib register successfully", vendor_name.c_str()); - } else { - GELOGW("%s op lib has already registered", vendor_name.c_str()); - } -} - -/** - * @brief 对环境变量下ASCEND_CUSTOM_OPP_PATH新so交付的自定义算子目录作预处理,需要保证在获取自定义算子目录前调用, - * @brief 当前提供metadef接口, air仓各个流程初始化靠前的位置调用 - * - * 当前最新的自定义算子工程交付分为run包交付和so交付(新做的)两种形式: - * 新的so交付的形式下:export ASCEND_CUSTOM_OPP_PATH=/path/to/customize:/path/to/mdc:/path/to/lhisi - * 三个目录下都只有一个libcust_opapi.so - * - * 老的run包交付的形式下:export ASCEND_CUSTOM_OPP_PATH=/path/to/customize:/path/to/mdc:/path/to/lhisi - * 三个目录下都有完整的算子子目录,如op_proto,op_impl子目录等 - * - * 当前支持两种方式混用。混用优先级以新的so交付方式优先。 - * 例如export ASCEND_CUSTOM_OPP_PATH=/home/a:/home/b:/home/c,其中只有/home/b是新so交付的方式 - * 则最终优先级别顺序为b,a,c - * @return - */ -graphStatus OpLibRegistry::PreProcessForCustomOp() { - if (is_processed_) { - GELOGD("pre process for custom op has already been called"); - return GRAPH_SUCCESS; - } - std::string custom_opp_path; - const char_t *custom_opp_path_env = nullptr; - MM_SYS_GET_ENV(MM_ENV_ASCEND_CUSTOM_OPP_PATH, custom_opp_path_env); - if (custom_opp_path_env != nullptr) { - custom_opp_path = custom_opp_path_env; - } - std::vector so_real_paths; - GE_ASSERT_GRAPH_SUCCESS(GetAllCustomOpApiSoPaths(custom_opp_path, so_real_paths)); - GE_ASSERT_GRAPH_SUCCESS(CallInitFunc(custom_opp_path, so_real_paths)); - is_processed_ = true; - return GRAPH_SUCCESS; -} - -graphStatus OpLibRegistry::GetAllCustomOpApiSoPaths(const std::string &custom_opp_path, - std::vector &so_real_paths) const { - if (custom_opp_path.empty()) { - GELOGI("custom_opp_path is empty, no need to get custom op so"); - return GRAPH_SUCCESS; - } - GELOGI("value of env ASCEND_CUSTOM_OPP_PATH is %s.", custom_opp_path.c_str()); - std::vector current_custom_opp_path = StringUtils::Split(custom_opp_path, ':'); - - if (current_custom_opp_path.empty()) { - GELOGI("find no custom opp path, just return"); - return GRAPH_SUCCESS; - } - - for (const auto &path : current_custom_opp_path) { - if (path.empty()) { - continue; - } - const std::string so_path = path + "/" + custom_so_name; - std::string so_real_path = RealPath(so_path.c_str()); - if (!so_real_path.empty()) { - GELOGI("find so_real_path %s", so_real_path.c_str()); - so_real_paths.emplace_back(so_real_path); - } - } - return GRAPH_SUCCESS; -} - -graphStatus OpLibRegistry::CallInitFunc(const std::string &custom_opp_path, - const std::vector &so_real_paths) { - // dlopen so orderly - for (const auto &so_path : so_real_paths) { - GELOGI("begin dlopen %s", so_path.c_str()); - void* const handle = mmDlopen(so_path.c_str(), static_cast(static_cast(MMPA_RTLD_NOW))); - GE_ASSERT_NOTNULL(handle, "Failed to dlopen %s! errmsg:%s", so_path.c_str(), mmDlerror()); - handles_.emplace_back(handle); - } - - // call init func orderly - const std::lock_guard lk(mu_); - for (auto &vendor_func : vendor_funcs_) { - GELOGI("begin to call %s init func", vendor_func.first.c_str()); - AscendString tmp_dir(""); - GE_ASSERT_GRAPH_SUCCESS(vendor_func.second(tmp_dir)); - GELOGI("end to call %s init func, tmp_dir is %s", vendor_func.first.c_str(), tmp_dir.GetString()); - op_lib_paths_ += (std::string(tmp_dir.GetString()) + ":"); - } - if (custom_opp_path.empty()) { // ignore the end : - op_lib_paths_ = op_lib_paths_.substr(0, op_lib_paths_.find_last_of(':')); - } else { - op_lib_paths_ += custom_opp_path; // add origin env path to ensure priority(so mode first, runbag mode second) - } - PluginManager::SetCustomOpLibPath(op_lib_paths_); - GELOGI("CallInitFunc %zu successfully, op_lib_paths_ is %s", vendor_funcs_.size(), op_lib_paths_.c_str()); - return GRAPH_SUCCESS; -} - -void OpLibRegistry::ClearHandles() { - for (auto handle : handles_) { - (void)mmDlclose(handle); - } - handles_.clear(); -} - -OpLibRegistry::~OpLibRegistry() { - ClearHandles(); -} -} // namespace ge diff --git a/register/op_tiling/op_compile_info_manager.cc b/register/op_tiling/op_compile_info_manager.cc deleted file mode 100644 index cf4f518901e13892ebc89f2752469adab7126fa0..0000000000000000000000000000000000000000 --- a/register/op_tiling/op_compile_info_manager.cc +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "op_tiling/op_compile_info_manager.h" - -namespace optiling { -CompileInfoManager::CompileInfoManager() {} -CompileInfoManager::~CompileInfoManager() {} - -CompileInfoManager& CompileInfoManager::Instance() { - static CompileInfoManager compile_info_manager_instance; - return compile_info_manager_instance; -} - -bool CompileInfoManager::HasCompileInfo(const std::string &key) { - return this->compile_info_map_.find(key) != this->compile_info_map_.end(); -} - -CompileInfoPtr CompileInfoManager::GetCompileInfo(const std::string &key) { - std::lock_guard lock_guard(compile_info_mutex_); - const auto iter = this->compile_info_map_.find(key); - if (iter == this->compile_info_map_.end()) { - return nullptr; - } - return iter->second; -} - -void CompileInfoManager::SetCompileInfo(const std::string &key, CompileInfoPtr compile_info_ptr) { - std::lock_guard lock_guard(compile_info_mutex_); - (void)this->compile_info_map_.emplace(key, compile_info_ptr); -} - -CompileInfoCache::CompileInfoCache() {} -CompileInfoCache::~CompileInfoCache() {} - -CompileInfoCache& CompileInfoCache::Instance() { - static CompileInfoCache compile_info_cache_instance; - return compile_info_cache_instance; -} - -bool CompileInfoCache::HasCompileInfo(const std::string &key) { - return this->compile_info_map_.find(key) != this->compile_info_map_.end(); -} - -void* CompileInfoCache::GetCompileInfo(const std::string &key) { - std::lock_guard lock_guard(compile_info_mutex_); - const auto iter = this->compile_info_map_.find(key); - if (iter == this->compile_info_map_.end()) { - return nullptr; - } - return iter->second; -} - -void CompileInfoCache::SetCompileInfo(const std::string &key, void *value) { - std::lock_guard lock_guard(compile_info_mutex_); - (void)this->compile_info_map_.emplace(key, value); -} -} // namespace optiling diff --git a/register/op_tiling/op_compile_info_manager.h b/register/op_tiling/op_compile_info_manager.h deleted file mode 100644 index 5d0ecf6e360c4e54acd4d8d448abadf522b9ea13..0000000000000000000000000000000000000000 --- a/register/op_tiling/op_compile_info_manager.h +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef REGISTER_OP_TILING_COMPILE_INFO_MANAGER_H_ -#define REGISTER_OP_TILING_COMPILE_INFO_MANAGER_H_ - -#include -#include -#include - -#include "register/op_compile_info_base.h" - -namespace optiling { -class CompileInfoManager { -public: - CompileInfoManager(const CompileInfoManager &) = delete; - CompileInfoManager &operator=(const CompileInfoManager &) = delete; - static CompileInfoManager& Instance(); - bool HasCompileInfo(const std::string &key); - CompileInfoPtr GetCompileInfo(const std::string &key); - void SetCompileInfo(const std::string &key, CompileInfoPtr compile_info_ptr); - -private: - CompileInfoManager(); - ~CompileInfoManager(); - mutable std::mutex compile_info_mutex_; - std::unordered_map compile_info_map_; -}; - -class CompileInfoCache { -public: - CompileInfoCache(const CompileInfoCache &) = delete; - CompileInfoCache &operator=(const CompileInfoCache &) = delete; - static CompileInfoCache& Instance(); - bool HasCompileInfo(const std::string &key); - void* GetCompileInfo(const std::string &key); - void SetCompileInfo(const std::string &key, void* value); - -private: - CompileInfoCache(); - ~CompileInfoCache(); - mutable std::mutex compile_info_mutex_; - std::unordered_map compile_info_map_; -}; -} // namespace optiling -#endif // REGISTER_OP_TILING_COMPILE_INFO_MANAGER_H_ diff --git a/register/op_tiling/op_tiling.cc b/register/op_tiling/op_tiling.cc deleted file mode 100644 index 7d49971b293e8cc194aa9c14123551fdacf2cb6d..0000000000000000000000000000000000000000 --- a/register/op_tiling/op_tiling.cc +++ /dev/null @@ -1,1145 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/op_tiling.h" - -#include -#include "external/graph/operator.h" -#include "graph/debug/ge_log.h" -#include "graph/debug/ge_util.h" -#include "graph/utils/type_utils.h" -#include "graph/utils/node_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/tensor_utils.h" -#include "graph/utils/anchor_utils.h" -#include "op_tiling/op_tiling_constants.h" -#include "op_tiling/op_tiling_utils.h" -#include "op_tiling/op_compile_info_manager.h" -#include "common/sgt_slice_type.h" -#include "graph/def_types.h" -#include "graph/utils/node_utils_ex.h" - -namespace optiling { -using Status = domi::Status; -using DataBuf = std::tuple; - -class AnyValueBase { -public: - virtual ~AnyValueBase() = default; - virtual DataBuf GetDataBuf() const = 0; -}; - -template -class AnyValue : public AnyValueBase { -public: - explicit AnyValue(const VT &value) : value_(value) {} - virtual ~AnyValue() override = default; - virtual DataBuf GetDataBuf() const override { - return DataBuf(reinterpret_cast(&value_), sizeof(value_)); - } - -private: - VT value_; -}; - -template -class AnyVecValue : public AnyValueBase { -public: - explicit AnyVecValue(const std::vector &value) : value_(std::move(value)) {} - virtual ~AnyVecValue() override = default; - virtual DataBuf GetDataBuf() const override { - return DataBuf(reinterpret_cast(value_.data()), sizeof(VT) * value_.size()); - } - -private: - std::vector value_; -}; - -template -struct Getter; - -template -struct Getter::value>::type> { - using ST = int64_t; - static constexpr bool (*func)(ge::AttrUtils::ConstAttrHolderAdapter &&, const string &, - int64_t &) = ge::AttrUtils::GetInt; - static constexpr bool (*list_func)(ge::AttrUtils::ConstAttrHolderAdapter &&, const string &, - vector &) = ge::AttrUtils::GetListInt; -}; -template -struct Getter::value>::type> { - using ST = float; - static constexpr bool (*func)(ge::AttrUtils::ConstAttrHolderAdapter &&, const string &, - float &) = ge::AttrUtils::GetFloat; - static constexpr bool (*list_func)(ge::AttrUtils::ConstAttrHolderAdapter &&, const string &, - vector &) = ge::AttrUtils::GetListFloat; -}; - -class TeOpVarAttrArgsImpl { -public: - explicit TeOpVarAttrArgsImpl(const ge::OpDescPtr &op_desc) : op_desc_(op_desc){}; - ~TeOpVarAttrArgsImpl() = default; - - Status GetDataByName(const string &name, const string &dtype, DataBuf &data); - -private: - template - Status GetNodeAttrDataIntListList(const std::string &name, DataBuf &data) { - std::vector> value; - const bool res = ge::AttrUtils::GetListListInt(op_desc_, name, value); - if (!res) { - GE_LOGE("Attribute not found: %s", name.c_str()); - return domi::FAILED; - } - - std::vector dest; - for (const auto &vec : value) { - for (const auto elem : vec) { - dest.emplace_back(static_cast(elem)); - } - } - const auto dest_ptr = std::make_shared>(dest); - data_map_.emplace(name + '_' + typeid(T).name(), dest_ptr); - data = dest_ptr->GetDataBuf(); - GELOGI("IntListList Attribute found: %s", name.c_str()); - return domi::SUCCESS; - } - - template::type = true> - Status GetNodeAttrDataTmpl(const std::string &name, DataBuf &data) { - const auto func = Getter::func; - typename Getter::ST value; - const bool res = func(op_desc_, name, value); - if (!res) { - GE_LOGE("Attribute not found: %s", name.c_str()); - return domi::FAILED; - } - - const auto dest_ptr = std::make_shared>(static_cast(value)); - (void)data_map_.emplace(name + '_' + typeid(T).name(), dest_ptr); - data = dest_ptr->GetDataBuf(); - GELOGI("Single Attribute found: %s", name.c_str()); - return domi::SUCCESS; - } - - template::type = true> - Status GetNodeAttrDataTmpl(const std::string &name, DataBuf &data) { - const auto func = Getter::list_func; - std::vector::ST> value; - const bool res = func(op_desc_, name, value); - if (!res) { - GE_LOGE("List Attribute missing: %s", name.c_str()); - return domi::FAILED; - } - - std::vector dest; - for (const auto elem : value) { - dest.emplace_back(static_cast(elem)); - } - const auto dest_ptr = std::make_shared>(dest); - (void)data_map_.emplace(name + '_' + typeid(T).name(), dest_ptr); - data = dest_ptr->GetDataBuf(); - GELOGI("Attribute found: %s", name.c_str()); - return domi::SUCCESS; - } - -private: - static std::map> - data_getter_; - ge::OpDescPtr op_desc_; - std::map> data_map_; -}; - -std::map> - TeOpVarAttrArgsImpl::data_getter_ = {{"Int8", &TeOpVarAttrArgsImpl::GetNodeAttrDataTmpl}, - {"Int16", &TeOpVarAttrArgsImpl::GetNodeAttrDataTmpl}, - {"Int32", &TeOpVarAttrArgsImpl::GetNodeAttrDataTmpl}, - {"Int64", &TeOpVarAttrArgsImpl::GetNodeAttrDataTmpl}, - {"UInt8", &TeOpVarAttrArgsImpl::GetNodeAttrDataTmpl}, - {"UInt16", &TeOpVarAttrArgsImpl::GetNodeAttrDataTmpl}, - {"UInt32", &TeOpVarAttrArgsImpl::GetNodeAttrDataTmpl}, - {"UInt64", &TeOpVarAttrArgsImpl::GetNodeAttrDataTmpl}, - {"Float", &TeOpVarAttrArgsImpl::GetNodeAttrDataTmpl}, - {"ListInt8", &TeOpVarAttrArgsImpl::GetNodeAttrDataTmpl}, - {"ListInt16", &TeOpVarAttrArgsImpl::GetNodeAttrDataTmpl}, - {"ListInt32", &TeOpVarAttrArgsImpl::GetNodeAttrDataTmpl}, - {"ListInt64", &TeOpVarAttrArgsImpl::GetNodeAttrDataTmpl}, - {"ListUInt8", &TeOpVarAttrArgsImpl::GetNodeAttrDataTmpl}, - {"ListUInt16", &TeOpVarAttrArgsImpl::GetNodeAttrDataTmpl}, - {"ListUInt32", &TeOpVarAttrArgsImpl::GetNodeAttrDataTmpl}, - {"ListUInt64", &TeOpVarAttrArgsImpl::GetNodeAttrDataTmpl}, - {"ListFloat", &TeOpVarAttrArgsImpl::GetNodeAttrDataTmpl}}; - -Status TeOpVarAttrArgsImpl::GetDataByName(const std::string &name, const std::string &dtype, DataBuf &data) { - const auto iter = data_getter_.find(dtype); - if (iter == data_getter_.end()) { - GE_LOGE("wrong dtype: %s", dtype.c_str()); - return domi::FAILED; - } else { - return iter->second(this, name, data); - } -} - -const uint8_t *TeOpVarAttrArgs::GetData(const std::string &name, const std::string &dtype, size_t &size) const { - DataBuf data(nullptr, 0); - const auto rc = impl_->GetDataByName(name, dtype, data); - if (rc == domi::SUCCESS) { - GELOGI("Attribute found: %s, %s, %p, %ld", name.c_str(), dtype.c_str(), std::get<0>(data), std::get<1>(data)); - } - size = std::get<1>(data); - return std::get<0>(data); -} - -class VarAttrHelper { -public: - static bool InitTeOpVarAttr(const ge::OpDescPtr &op_desc_ptr, TeOpVarAttrArgs &attr) { - OP_TILING_MAKE_SHARED(attr.impl_ = std::make_shared(op_desc_ptr), return false); - return true; - } -}; - -bool FeedTeOpTensorArg(ge::OpDesc::Vistor &tensor_desc_vec, - std::vector &tensor_arg, const ge::OpDescPtr &op_desc) { - size_t index = 0U; - for (ge::GeTensorDescPtr &tensor_desc_ptr : tensor_desc_vec) { - TeOpTensorArg arg_tensor; - TeOpTensor tensor; - arg_tensor.arg_type = TensorArgType::TA_SINGLE; - tensor.shape = tensor_desc_ptr->MutableShape().GetDims(); - if (tensor.shape.empty()) { - tensor.shape = {1}; - } - tensor.ori_shape = tensor_desc_ptr->GetOriginShape().GetDims(); - tensor.name = op_desc->GetInputNameByIndex(static_cast(index)); - - const ge::Format primary_format = static_cast(ge::GetPrimaryFormat(tensor_desc_ptr->GetFormat())); - tensor.format = ge::TypeUtils::FormatToSerialString(primary_format); - tensor.ori_format = ge::TypeUtils::FormatToSerialString(tensor_desc_ptr->GetOriginFormat()); - - const ge::DataType dtype = tensor_desc_ptr->GetDataType(); - const auto dataTypeIter = DATATYPE_STRING_MAP.find(dtype); - if (dataTypeIter == DATATYPE_STRING_MAP.end()) { - GE_LOGE("datatype error %d", static_cast(dtype)); - return false; - } - tensor.dtype = dataTypeIter->second; - if (IsLogEnable(GE_MODULE_NAME, DLOG_INFO)) { - std::stringstream shapestr; - shapestr << "shape:["; - for (auto &i : tensor.shape) { - shapestr << i << ","; - } - shapestr << "], ori_shape:["; - for (auto &i : tensor.ori_shape) { - shapestr << i << ","; - } - shapestr << "], format:" << tensor.format; - shapestr << ", ori_format:" << tensor.ori_format; - shapestr << ", dtype: " << tensor.dtype; - GELOGI("calling optiling shape info: %s", shapestr.str().c_str()); - } - - arg_tensor.tensor.emplace_back(tensor); - tensor_arg.emplace_back(arg_tensor); - index++; - } - return true; -} - -void FeedTeOpConstTensor(const ge::Operator &op, const ge::OpDescPtr &op_desc, - std::map &const_inputs) { - std::vector depend_names; - (void)ge::AttrUtils::GetListStr(op_desc, ATTR_NAME_OP_INFER_DEPENDS, depend_names); - for (const std::string &depend : depend_names) { - ge::Tensor data; - const ge::graphStatus rc = op.GetInputConstData(depend.c_str(), data); - GELOGI("GetInputConstData: %s, %d", depend.c_str(), rc); - if (rc != ge::GRAPH_SUCCESS) { - continue; - } - - const uint8_t * const pbuf = data.GetData(); - const size_t buflen = data.GetSize(); - - GELOGI("Const input tensor data: %s, %p %zu", depend.c_str(), pbuf, buflen); - (void)const_inputs.emplace(depend, TeConstTensorData{pbuf, buflen, data}); - } -} - -ge::graphStatus OpParaCalculate(const ge::Operator &op, OpRunInfo &run_info, const OpTilingFunc &tiling_func) { - ge::OpDescPtr op_desc = ge::OpDescUtils::GetOpDescFromOperator(op); - GELOGI("Do optiling, op_type: %s, op_name: %s", op_desc->GetType().c_str(), op_desc->GetName().c_str()); - TeOpParas op_param; - op_param.op_type = op_desc->GetType(); - (void)VarAttrHelper::InitTeOpVarAttr(op_desc, op_param.var_attrs); - - ge::OpDesc::Vistor inputs = op_desc->GetAllInputsDescPtr(); - if (!FeedTeOpTensorArg(inputs, op_param.inputs, op_desc)) { - GE_LOGE("Do optiling, op_type: %s, op_name: %s", op_desc->GetType().c_str(), op_desc->GetName().c_str()); - return ge::GRAPH_FAILED; - } - ge::OpDesc::Vistor outputs = op_desc->GetAllOutputsDescPtr(); - if (!FeedTeOpTensorArg(outputs, op_param.outputs, op_desc)) { - return ge::GRAPH_FAILED; - } - FeedTeOpConstTensor(op, op_desc, op_param.const_inputs); - - OpCompileInfo op_compile_info; - if (!ge::AttrUtils::GetStr(op_desc, COMPILE_INFO_KEY, op_compile_info.key)) { - GE_LOGE("Op [%s] does not have attribute [%s].", op_desc->GetName().c_str(), COMPILE_INFO_KEY.c_str()); - return ge::GRAPH_FAILED; - } - if (!ge::AttrUtils::GetStr(op_desc, COMPILE_INFO_JSON, op_compile_info.str)) { - GE_LOGE("Op [%s] does not have attribute [%s].", op_desc->GetName().c_str(), COMPILE_INFO_JSON.c_str()); - return ge::GRAPH_FAILED; - } - - const bool ret = (tiling_func)(op_param, op_compile_info, run_info); - if (ret) { - GELOGI("Do optiling succeed. op_type:%s, op_name:%s", - op_desc->GetType().c_str(), op_desc->GetName().c_str()); - } else { - GELOGW("Failed to call tiling function v1 of op [%s, %s].", - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - } - return ret ? ge::GRAPH_SUCCESS : ge::GRAPH_FAILED; -} - -ge::graphStatus TurnToOpParaCalculateV1(const ge::Operator &op, OpRunInfoV2 &run_info, - const OpTilingFunc &tiling_func) { - OpRunInfo run_info_struct; - run_info_struct.block_dim = run_info.GetBlockDim(); - run_info_struct.clear_atomic = run_info.GetClearAtomic(); - run_info_struct.tiling_key = run_info.GetTilingKey(); - if (OpParaCalculate(op, run_info_struct, tiling_func) != ge::GRAPH_SUCCESS) { - ge::AscendString op_type; - (void)op.GetOpType(op_type); - ge::AscendString op_name; - (void)op.GetName(op_name); - REPORT_INNER_ERR_MSG("E19999", "OpParaCalculate failed, op_type[%s], op_name[%s]", op_type.GetString(), - op_name.GetString()); - return ge::GRAPH_FAILED; - } - - run_info.SetBlockDim(run_info_struct.block_dim); - run_info.SetClearAtomic(run_info_struct.clear_atomic); - run_info.SetTilingKey(run_info_struct.tiling_key); - run_info.InternelSetTiling(run_info_struct.tiling_data); - if (!run_info_struct.workspaces.empty()) { - for (const int64_t &workspace : run_info_struct.workspaces) { - run_info.AddWorkspace(workspace); - } - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus TurnToOpParaCalculateV2(const ge::Operator &op_param, OpRunInfoV2 &run_info, - const OpTilingFuncV2 &tiling_func) { - const ge::OpDescPtr op_desc = ge::OpDescUtils::GetOpDescFromOperator(op_param); - GELOGI("Do optiling, op_type: %s, op_name: %s", op_desc->GetType().c_str(), op_desc->GetName().c_str()); - const std::string *op_compile_info_key = ge::AttrUtils::GetStr(op_desc, COMPILE_INFO_KEY); - if (op_compile_info_key == nullptr) { - GE_LOGE("Op [%s] does not have attribute [%s].", op_desc->GetName().c_str(), COMPILE_INFO_KEY.c_str()); - return ge::GRAPH_FAILED; - } - const std::string *op_compile_info_json = ge::AttrUtils::GetStr(op_desc, COMPILE_INFO_JSON); - if (op_compile_info_json == nullptr) { - GE_LOGE("Op [%s] does not have attribute [%s].", op_desc->GetName().c_str(), COMPILE_INFO_JSON.c_str()); - return ge::GRAPH_FAILED; - } - const OpCompileInfoV2 op_compile_info(*op_compile_info_key, *op_compile_info_json); - - std::vector indexes; - ReplaceEmptyShapeOfTensorDesc(op_desc, indexes); - - const bool ret = (tiling_func)(op_param, op_compile_info, run_info); - if (ret) { - GELOGI("Do optiling v2 succeed. op_type: %s, op_name: %s", - op_desc->GetType().c_str(), op_desc->GetName().c_str()); - } else { - GELOGW("Failed to call the tiling function v2 for op [%s, %s].", - op_desc->GetType().c_str(), op_desc->GetName().c_str()); - } - RecoveryEmptyShapeOfTensorDesc(op_desc, indexes); - return ret ? ge::GRAPH_SUCCESS : ge::GRAPH_FAILED; -} - -ge::graphStatus TurnToOpParaCalculateV3(const ge::Operator &op_param, OpRunInfoV2 &run_info, - const OpTilingFuncV3 &tiling_func, const OpParseFuncV3 &parse_func) { - const ge::OpDescPtr op_desc = ge::OpDescUtils::GetOpDescFromOperator(op_param); - GELOGI("Do optiling, op_type: %s, op_name: %s", op_desc->GetType().c_str(), op_desc->GetName().c_str()); - const std::string *op_compile_info_key = ge::AttrUtils::GetStr(op_desc, COMPILE_INFO_KEY); - if (op_compile_info_key == nullptr) { - GE_LOGE("Op [%s] does not have attribute [%s].", op_desc->GetName().c_str(), COMPILE_INFO_KEY.c_str()); - return ge::GRAPH_FAILED; - } - void* op_compile_json_ptr = CompileInfoCache::Instance().GetCompileInfo(*op_compile_info_key); - if (op_compile_json_ptr == nullptr) { - const std::string *op_compile_info_json = ge::AttrUtils::GetStr(op_desc, COMPILE_INFO_JSON); - if (op_compile_info_json == nullptr) { - GE_LOGE("Op [%s] does not have attribute [%s].", op_desc->GetName().c_str(), COMPILE_INFO_JSON.c_str()); - return ge::GRAPH_FAILED; - } - const ge::AscendString compile_info_json_str = op_compile_info_json->c_str(); - op_compile_json_ptr = (parse_func)(op_param, compile_info_json_str); - if (op_compile_json_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E19999", "Failed to parse compile json[%s] for op [%s, %s].", op_compile_info_json->c_str(), - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - GE_LOGE("Failed to parse compile json[%s] for op [%s, %s].", op_compile_info_json->c_str(), - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - return ge::GRAPH_FAILED; - } - CompileInfoCache::Instance().SetCompileInfo(*op_compile_info_key, op_compile_json_ptr); - } - - std::vector indexes; - ReplaceEmptyShapeOfTensorDesc(op_desc, indexes); - - const bool ret = (tiling_func)(op_param, op_compile_json_ptr, run_info); - if (ret) { - GELOGI("Do optiling v3 succeed. op_type: %s, op_name: %s", - op_desc->GetType().c_str(), op_desc->GetName().c_str()); - } else { - GELOGW("Failed to call the tiling function v3 for op [%s, %s].", - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - } - RecoveryEmptyShapeOfTensorDesc(op_desc, indexes); - return ret ? ge::GRAPH_SUCCESS : ge::GRAPH_FAILED; -} - -ge::graphStatus TurnToOpParaCalculateV4(const ge::Operator &op_param, OpRunInfoV2 &run_info, - const OpTilingFuncV4 &tiling_func, const OpParseFuncV4 &parse_func) { - const ge::OpDescPtr op_desc = ge::OpDescUtils::GetOpDescFromOperator(op_param); - GELOGI("Do optiling, op_type: %s, op_name: %s", op_desc->GetType().c_str(), op_desc->GetName().c_str()); - const std::string *op_compile_info_key = ge::AttrUtils::GetStr(op_desc, COMPILE_INFO_KEY); - if (op_compile_info_key == nullptr) { - GE_LOGE("Op [%s] does not have attribute [%s].", op_desc->GetName().c_str(), COMPILE_INFO_KEY.c_str()); - return ge::GRAPH_FAILED; - } - CompileInfoPtr op_compile_info_ptr = CompileInfoManager::Instance().GetCompileInfo(*op_compile_info_key); - if (op_compile_info_ptr == nullptr) { - const std::string *op_compile_info_json = ge::AttrUtils::GetStr(op_desc, COMPILE_INFO_JSON); - if (op_compile_info_json == nullptr) { - GE_LOGE("Op [%s] does not have attribute [%s].", op_desc->GetName().c_str(), COMPILE_INFO_JSON.c_str()); - return ge::GRAPH_FAILED; - } - const ge::AscendString compile_info_json_str = op_compile_info_json->c_str(); - op_compile_info_ptr = (parse_func)(op_param, compile_info_json_str); - if (op_compile_info_ptr == nullptr) { - REPORT_INNER_ERR_MSG("E19999", "Failed to parse compile json[%s] for op [%s, %s].", op_compile_info_json->c_str(), - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - GE_LOGE("Failed to parse compile json[%s] for op [%s, %s].", op_compile_info_json->c_str(), - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - return ge::GRAPH_FAILED; - } - CompileInfoManager::Instance().SetCompileInfo(*op_compile_info_key, op_compile_info_ptr); - } - - std::vector indexes; - ReplaceEmptyShapeOfTensorDesc(op_desc, indexes); - - const bool ret = (tiling_func)(op_param, op_compile_info_ptr, run_info); - if (ret) { - GELOGI("Do optiling v4 succeed. op_type:%s, op_name:%s", - op_desc->GetType().c_str(), op_desc->GetName().c_str()); - } else { - GELOGW("Failed to call the tiling function v4 of op [%s, %s].", - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - } - RecoveryEmptyShapeOfTensorDesc(op_desc, indexes); - return ret ? ge::GRAPH_SUCCESS : ge::GRAPH_FAILED; -} - -ge::graphStatus PostProcCalculateV2(const ge::Operator &op, OpRunInfoV2 &run_info) -{ - const ge::OpDescPtr op_desc = ge::OpDescUtils::GetOpDescFromOperator(op); - GE_CHECK_NOTNULL(op_desc); - const std::vector all_workspaces = op_desc->GetWorkspaceBytes(); - std::vector op_workspaces; - run_info.GetAllWorkspaces(op_workspaces); - const size_t op_work_size = op_workspaces.size(); - if (op_work_size > all_workspaces.size()) { - GELOGW("Op name:%s tiling return workspace number(%zu) large than all workspace num(%zu).", - op_desc->GetName().c_str(), op_work_size, all_workspaces.size()); - return ge::GRAPH_SUCCESS; - } - - if (op_work_size == all_workspaces.size()) { - return ge::GRAPH_SUCCESS; - } - - GELOGD("Op name: %s post proc, op work num: %zu, all work num: %zu.", op_desc->GetName().c_str(), op_work_size, - all_workspaces.size()); - - // mixl2--pass will add additional works after op_workspaces - for (size_t i = op_work_size; i < all_workspaces.size(); ++i) { - op_workspaces.emplace_back(all_workspaces[i]); - } - for (size_t i = 0; i < op_workspaces.size(); ++i) { - GELOGD("Op's workspace: %zu, value: %ld.", i, op_workspaces[i]); - } - run_info.SetWorkspaces(op_workspaces); - return ge::GRAPH_SUCCESS; -} - -OpTilingFuncInfo *GetOpTilingInfo(const ge::OpDescPtr &op_desc) { - if (op_desc == nullptr) { - GE_LOGE("[Get][OpTilingInfo] failed, op_desc is nullptr."); - REPORT_INNER_ERR_MSG("EZ9999", "[Get][OpTilingInfo] failed, op_desc is nullptr."); - return nullptr; - } - if (op_desc->GetTilingFuncInfo() == nullptr) { - const std::string op_type = op_desc->GetType(); - auto &op_func_map = OpTilingFuncRegistry::RegisteredOpFuncInfo(); - auto iter = op_func_map.find(op_type); - if (iter == op_func_map.end()) { - GELOGI("The optiling function is not found by op type [%s].", op_type.c_str()); - iter = op_func_map.find(OP_TYPE_AUTO_TILING); - if (iter == op_func_map.end()) { - GE_LOGE("Optiling function of op type[%s] is not found by Autotiling.", op_type.c_str()); - REPORT_INNER_ERR_MSG("EZ9999", "Optiling function not found. op_type[%s].", op_type.c_str()); - return nullptr; - } - } - op_desc->SetTilingFuncInfo(::ge::PtrToPtr(&(iter->second))); - return &(iter->second); - } - return ::ge::PtrToPtr(op_desc->GetTilingFuncInfo()); -} - -void parse_tiling_data(const void* base, const size_t max_size) { - std::stringstream result; - int32_t tmp = 0; - const char* base_addr = static_cast(base); - for (size_t i = 0U; i < max_size; i += sizeof(int32_t)) { - if ((max_size - i) < sizeof(tmp)) { - return; - } - if (memcpy_s(&tmp, sizeof(tmp), base_addr + i, sizeof(tmp)) != EOK) { - return; - } - result << std::to_string(tmp); - result << " "; - } - GELOGD("Parse tiling data %s.", result.str().c_str()); - return; -} - -ge::graphStatus PostProcMemoryCheck(const ge::Operator &op, OpRunInfoV2 &run_info) -{ - const ge::OpDescPtr op_desc = ge::OpDescUtils::GetOpDescFromOperator(op); - bool value = false; - if (!ge::AttrUtils::GetBool(op_desc, kMemoryCheck, value) || !value) { - return ge::GRAPH_SUCCESS; - } - uint64_t ori_op_para_size = 0; - if (ge::AttrUtils::GetInt(op_desc, kOriOpParaSize, ori_op_para_size)) { - GELOGD("The ori_op_para_size of node [%s] is %lu.", op_desc->GetName().c_str(), ori_op_para_size); - if (!run_info.SetMemCheckBaseOffset(ori_op_para_size)) { - REPORT_INNER_ERR_MSG("E19999", - "[register][op_tiling][PostProcMemoryCheck]Node:%s set mem check offset:%lu failed.", - op_desc->GetName().c_str(), ori_op_para_size); - return ge::GRAPH_FAILED; - } - } else { - run_info.AlignOffsetWith64(); - } - for (size_t i = 0U; i < op_desc->GetAllInputsSize(); ++i) { - const ge::GeTensorDescPtr tensor = op_desc->MutableInputDesc(static_cast(i)); - if (tensor == nullptr) { - continue; - } - int64_t clean_size = 0; - if (ge::TensorUtils::GetSize(*tensor, clean_size) != ge::GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E19999", "[register][op_tiling][PostProcMemoryCheck]Get op:%s tensor:%zu size failed.", - op_desc->GetName().c_str(), i); - return ge::GRAPH_FAILED; - } - GELOGD("Op input tensor: %zu has a size of %ld.", i, clean_size); - run_info.AddTilingData(clean_size); - } - for (size_t j = 0U; j < op_desc->GetOutputsSize(); ++j) { - const ge::GeTensorDescPtr tensor = op_desc->MutableOutputDesc(static_cast(j)); - if (tensor == nullptr) { - continue; - } - int64_t clean_size = 0; - if (ge::TensorUtils::GetSize(*tensor, clean_size) != ge::GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E19999", "[register][op_tiling][PostProcMemoryCheck]Get op:%s tensor:%zu size failed.", - op_desc->GetName().c_str(), j); - return ge::GRAPH_FAILED; - } - GELOGD("Op output tensor: %zu with size %ld.", j, clean_size); - run_info.AddTilingData(clean_size); - } - for (size_t k = 0U; k < run_info.GetWorkspaceNum(); ++k) { - int64_t workspace = 0; - (void)run_info.GetWorkspace(k, workspace); - GELOGD("Op workspace: %zu size is %ld bytes.", k, workspace); - run_info.AddTilingData(workspace); - } - const uint64_t cur_size = run_info.GetTilingDataSize(); - GELOGD("Adding tiling data; current size: %lu.", cur_size); - run_info.AddTilingData(cur_size); - - uint64_t max_size = 0U; - const void* base = run_info.GetAddrBase(max_size); - parse_tiling_data(base, static_cast(max_size)); - return ge::GRAPH_SUCCESS; -} - -extern "C" ge::graphStatus OpParaCalculateV2(const ge::Operator &op, OpRunInfoV2 &run_info) { - const ge::OpDescPtr op_desc = ge::OpDescUtils::GetOpDescFromOperator(op); - OpTilingFuncInfo *op_func_info = GetOpTilingInfo(op_desc); - if (op_func_info == nullptr) { - GE_LOGE("Optiling function not found."); - REPORT_INNER_ERR_MSG("EZ9999", "Optiling function not found."); - return ge::GRAPH_FAILED; - } - ge::graphStatus ret = ge::GRAPH_FAILED; - if (op_func_info->IsFunctionV4()) { - const OpTilingFuncV4 &tiling_func = op_func_info->GetOpTilingFuncV4(); - const OpParseFuncV4 &parse_func = op_func_info->GetOpParseFuncV4(); - ret = TurnToOpParaCalculateV4(op, run_info, tiling_func, parse_func); - } else if (op_func_info->IsFunctionV3()) { - const OpTilingFuncV3 &tiling_func = op_func_info->GetOpTilingFuncV3(); - const OpParseFuncV3 &parse_func = op_func_info->GetOpParseFuncV3(); - ret = TurnToOpParaCalculateV3(op, run_info, tiling_func, parse_func); - } else if (op_func_info->IsFunctionV2()) { - const OpTilingFuncV2 &tiling_func = op_func_info->GetOpTilingFuncV2(); - ret = TurnToOpParaCalculateV2(op, run_info, tiling_func); - } else if (op_func_info->IsFunctionV1()) { - const OpTilingFunc &tiling_func = op_func_info->GetOpTilingFunc(); - ret = TurnToOpParaCalculateV1(op, run_info, tiling_func); - } else { - GE_LOGE("Optiling function for op type [%s] is entirely empty.", op_desc->GetType().c_str()); - } - if (ret != ge::GRAPH_SUCCESS) { - return ret; - } - ret = PostProcCalculateV2(op, run_info); - if (ret == ge::GRAPH_SUCCESS) { - return PostProcMemoryCheck(op, run_info); - } - return ret; -} - -void GenerateCompileInfoKey(const std::vector &workspace_size_list, std::string &op_compile_info_key) { - for (const int64_t &workspace_size : workspace_size_list) { - (void)op_compile_info_key.append(",").append(std::to_string(workspace_size)); - } -} - -ge::graphStatus AssembleCompileInfoJson(const ge::OpDescPtr &op_desc_ptr, - const std::vector &workspace_size_list, - std::string &op_compile_info_json) { - nlohmann::json compile_info_json; - try { - compile_info_json = nlohmann::json::parse(op_compile_info_json); - } catch (nlohmann::json::parse_error& ex) { - REPORT_INNER_ERR_MSG("E19999", - "Failed to set compile_info_value to the json format for op[%s]. op_compile_info_json: %s", - op_desc_ptr->GetName().c_str(), op_compile_info_json.c_str()); - GE_LOGE("Failed to set compile_info_value to the json format for op[%s]. op_compile_info_json: %s", - op_desc_ptr->GetName().c_str(), op_compile_info_json.c_str()); - return ge::GRAPH_FAILED; - } - for (const int64_t &workspace_size : workspace_size_list) { - compile_info_json[COMPILE_INFO_WORKSPACE_SIZE_LIST].push_back(workspace_size); - } - op_compile_info_json = compile_info_json.dump(); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus AssembleWorkspaceList(const ge::OpDescPtr &op_desc_ptr, - int64_t &first_clean_size, - std::vector &workspace_size_list) { - std::vector atomic_output_indices; - (void) ge::AttrUtils::GetListInt(op_desc_ptr, ge::ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_indices); - std::map> atomic_workspace_info; - atomic_workspace_info = op_desc_ptr->TryGetExtAttr(ge::EXT_ATTR_ATOMIC_WORKSPACE_INFO, atomic_workspace_info); - const bool atomic_flag = atomic_output_indices.empty() && atomic_workspace_info.empty(); - if (atomic_flag) { - GE_LOGE("Do not find ATOMIC_ATTR_OUTPUT_INDEX and EXT_ATTR_ATOMIC_WORKSPACE_INFO, op_type:%s, op_name:%s", - OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN.c_str(), op_desc_ptr->GetName().c_str()); - return ge::GRAPH_FAILED; - } - - if (!atomic_output_indices.empty()) { - bool is_first_index = true; - for (const int64_t &atomic_output_indice : atomic_output_indices) { - const ge::ConstGeTensorDescPtr tensor = - op_desc_ptr->GetOutputDescPtr(static_cast(atomic_output_indice)); - if (tensor == nullptr) { - GE_LOGE("Failed to get atomic_output_indice. op_type: %s, op_name: %s", - OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN.c_str(), op_desc_ptr->GetName().c_str()); - return ge::GRAPH_FAILED; - } - - int64_t clean_size = 0; - if (ge::TensorUtils::GetSize(*tensor, clean_size) != ge::GRAPH_SUCCESS) { - GE_LOGE("Failed to get size of tensor desc. op_type: %s, op_name: %s", - OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN.c_str(), op_desc_ptr->GetName().c_str()); - return ge::GRAPH_FAILED; - } - workspace_size_list.push_back(clean_size); - if (is_first_index) { - first_clean_size = clean_size; - is_first_index = false; - } - } - } - GELOGI("Atomic clean size: %ld, op_name:%s", first_clean_size, op_desc_ptr->GetName().c_str()); - - if (!atomic_workspace_info.empty()) { - const std::vector workspace_bytes = op_desc_ptr->GetWorkspaceBytes(); - const std::map workspace_bytes_map = atomic_workspace_info[op_desc_ptr->GetName()]; - for (auto &workspace_idxs : workspace_bytes_map) { - if (workspace_idxs.first < static_cast(workspace_bytes.size())) { - workspace_size_list.push_back(static_cast( - workspace_bytes[static_cast(workspace_idxs.first)])); - } - } - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus OpAtomicCalculateV1(const ge::OpDescPtr &op_desc_ptr, OpRunInfo &run_info, - const OpTilingFunc &tiling_func) { - GELOGI("Begin to perform atomic optiling. op_type: %s, op_name: %s", - OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN.c_str(), op_desc_ptr->GetName().c_str()); - OpCompileInfo op_compile_info; - if (!ge::AttrUtils::GetStr(op_desc_ptr, ATOMIC_COMPILE_INFO_KEY, op_compile_info.key)) { - GE_LOGE("Op [%s] does not have attribute [%s].", op_desc_ptr->GetName().c_str(), ATOMIC_COMPILE_INFO_KEY.c_str()); - return ge::GRAPH_FAILED; - } - if (!ge::AttrUtils::GetStr(op_desc_ptr, ATOMIC_COMPILE_INFO_JSON, op_compile_info.str)) { - GE_LOGE("Op [%s] does not have attribute [%s].", op_desc_ptr->GetName().c_str(), ATOMIC_COMPILE_INFO_JSON.c_str()); - return ge::GRAPH_FAILED; - } - - int64_t first_clean_size = 0; - std::vector workspace_size_list; - if (AssembleWorkspaceList(op_desc_ptr, first_clean_size, workspace_size_list) != ge::GRAPH_SUCCESS) { - GE_LOGE("Failed to retrieve the workspace size list from op[%s, %s].", - op_desc_ptr->GetName().c_str(), op_desc_ptr->GetType().c_str()); - return ge::GRAPH_FAILED; - } - - TeOpParas op_param; - op_param.op_type = OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN; - (void)op_param.const_inputs.emplace("workspace_size", - TeConstTensorData(nullptr, static_cast(first_clean_size), ge::Tensor())); - - GenerateCompileInfoKey(workspace_size_list, op_compile_info.key); - if (AssembleCompileInfoJson(op_desc_ptr, workspace_size_list, op_compile_info.str) != ge::GRAPH_SUCCESS) { - GE_LOGE("Failed to assemble compile info json for op [%s, %s].", - op_desc_ptr->GetName().c_str(), op_desc_ptr->GetType().c_str()); - return ge::GRAPH_FAILED; - } - - const bool ret = (tiling_func)(op_param, op_compile_info, run_info); - if (ret) { - GELOGI("Atomic optiling v1 operation succeeded. op_type: %s, op_name: %s.", - op_desc_ptr->GetType().c_str(), op_desc_ptr->GetName().c_str()); - } else { - GELOGW("Failed to call the tiling v1 function of atomic op [%s, %s].", - op_desc_ptr->GetName().c_str(), op_desc_ptr->GetType().c_str()); - } - return ret ? ge::GRAPH_SUCCESS : ge::GRAPH_FAILED; -} - -ge::graphStatus TurnToOpAtomicCalculateV1(const ge::OpDescPtr &op_desc_ptr, OpRunInfoV2 &run_info, - const OpTilingFunc &tiling_func) { - OpRunInfo run_info_struct; - run_info_struct.block_dim = run_info.GetBlockDim(); - run_info_struct.clear_atomic = run_info.GetClearAtomic(); - run_info_struct.tiling_key = run_info.GetTilingKey(); - if (OpAtomicCalculateV1(op_desc_ptr, run_info_struct, tiling_func) != ge::GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E19999", "Do OpAtomicCalculateV1 failed, op_type[%s], op_name[%s]", - op_desc_ptr->GetType().c_str(), op_desc_ptr->GetName().c_str()); - return ge::GRAPH_FAILED; - } - run_info.InternelSetTiling(run_info_struct.tiling_data); - run_info.SetBlockDim(run_info_struct.block_dim); - run_info.SetClearAtomic(run_info_struct.clear_atomic); - run_info.SetTilingKey(run_info_struct.tiling_key); - if (!run_info_struct.workspaces.empty()) { - for (const int64_t &workspace : run_info_struct.workspaces) { - run_info.AddWorkspace(workspace); - } - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus AssembleWorkspaceList(const ge::OpDescPtr &op_desc_ptr, - std::vector &workspace_list, - std::vector &workspace_size_list) { - std::vector atomic_output_indices; - (void) ge::AttrUtils::GetListInt(op_desc_ptr, ge::ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_indices); - std::map> atomic_workspace_info; - atomic_workspace_info = op_desc_ptr->TryGetExtAttr(ge::EXT_ATTR_ATOMIC_WORKSPACE_INFO, atomic_workspace_info); - const bool atomic_flag = atomic_output_indices.empty() && atomic_workspace_info.empty(); - if (atomic_flag) { - REPORT_INNER_ERR_MSG("E19999", - "No ATOMIC_ATTR_OUTPUT_INDEX and EXT_ATTR_ATOMIC_WORKSPACE_INFO found,op_type:%s, op_name:%s", - OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN.c_str(), op_desc_ptr->GetName().c_str()); - return ge::GRAPH_FAILED; - } - - if (!atomic_output_indices.empty()) { - bool is_first_index = true; - for (const int64_t &atomic_output_indice : atomic_output_indices) { - const ge::ConstGeTensorDescPtr tensor = - op_desc_ptr->GetOutputDescPtr(static_cast(atomic_output_indice)); - if (tensor == nullptr) { - REPORT_INNER_ERR_MSG("E19999", "Get MutableOutputDesc failed. op_type:%s, op_name:%s", - OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN.c_str(), op_desc_ptr->GetName().c_str()); - return ge::GRAPH_FAILED; - } - int64_t clean_size = 0; - if (ge::TensorUtils::GetSize(*tensor, clean_size) != ge::GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E19999", "Get size of tensor desc failed. op_type:%s, op_name:%s", - OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN.c_str(), op_desc_ptr->GetName().c_str()); - return ge::GRAPH_FAILED; - } - workspace_size_list.push_back(clean_size); - if (is_first_index) { - workspace_list.push_back(clean_size); - is_first_index = false; - } - } - } - - if (!atomic_workspace_info.empty()) { - const std::vector workspace_bytes = op_desc_ptr->GetWorkspaceBytes(); - const std::map workspace_bytes_map = atomic_workspace_info[op_desc_ptr->GetName()]; - for (auto &workspace_idxs : workspace_bytes_map) { - if (workspace_idxs.first < static_cast(workspace_bytes.size())) { - workspace_size_list.push_back(workspace_bytes[workspace_idxs.first]); - workspace_list.push_back(static_cast(workspace_bytes[static_cast(workspace_idxs.first)])); - } - } - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus TurnToOpAtomicCalculateV2(const ge::OpDescPtr &op_desc_ptr, OpRunInfoV2 &run_info, - const OpTilingFuncV2 &tiling_func) { - GELOGI("Begin atomic optiling V2 for op [%s, %s].", - op_desc_ptr->GetName().c_str(), op_desc_ptr->GetType().c_str()); - std::vector workspace_list; - std::vector workspace_size_list; - if (AssembleWorkspaceList(op_desc_ptr, workspace_list, workspace_size_list) != ge::GRAPH_SUCCESS) { - GE_LOGE("Failed to retrieve the workspace size list from op[%s, %s].", - op_desc_ptr->GetName().c_str(), op_desc_ptr->GetType().c_str()); - return ge::GRAPH_FAILED; - } - ge::Operator op_param(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN.c_str()); - (void)op_param.SetAttr(ATTR_NAME_ATOMIC_CLEAN_WORKSPACE.c_str(), workspace_list); - - std::string op_compile_info_key; - if (!ge::AttrUtils::GetStr(op_desc_ptr, ATOMIC_COMPILE_INFO_KEY, op_compile_info_key)) { - GE_LOGE("Op [%s] does not have attribute [%s].", op_desc_ptr->GetName().c_str(), ATOMIC_COMPILE_INFO_KEY.c_str()); - return ge::GRAPH_FAILED; - } - std::string op_compile_info_json; - if (!ge::AttrUtils::GetStr(op_desc_ptr, ATOMIC_COMPILE_INFO_JSON, op_compile_info_json)) { - GE_LOGE("Op [%s] does not have attribute [%s].", op_desc_ptr->GetName().c_str(), ATOMIC_COMPILE_INFO_JSON.c_str()); - return ge::GRAPH_FAILED; - } - GenerateCompileInfoKey(workspace_size_list, op_compile_info_key); - if (AssembleCompileInfoJson(op_desc_ptr, workspace_size_list, op_compile_info_json) != ge::GRAPH_SUCCESS) { - GE_LOGE("Failed to assemble compile info json for op [%s, %s].", - op_desc_ptr->GetName().c_str(), op_desc_ptr->GetType().c_str()); - return ge::GRAPH_FAILED; - } - const OpCompileInfoV2 op_compile_info(op_compile_info_key, op_compile_info_json); - const bool ret = (tiling_func)(op_param, op_compile_info, run_info); - if (ret) { - GELOGI("Atomic optiling v2 operation succeeded. op_type: %s, op_name: %s.", - op_desc_ptr->GetType().c_str(), op_desc_ptr->GetName().c_str()); - } else { - GELOGW("Failed to call the tiling v2 function of atomic op [%s, %s].", - op_desc_ptr->GetName().c_str(), op_desc_ptr->GetType().c_str()); - } - op_param.BreakConnect(); - return ret ? ge::GRAPH_SUCCESS : ge::GRAPH_FAILED; -} - -ge::graphStatus TurnToOpAtomicCalculateV3(const ge::OpDescPtr &op_desc_ptr, OpRunInfoV2 &run_info, - const OpTilingFuncV3 &tiling_func, const OpParseFuncV3 &parse_func) { - GELOGI("Begin Atomic optiling V3 for op [%s, %s].", - op_desc_ptr->GetName().c_str(), op_desc_ptr->GetType().c_str()); - std::vector workspace_list; - std::vector workspace_size_list; - if (AssembleWorkspaceList(op_desc_ptr, workspace_list, workspace_size_list) != ge::GRAPH_SUCCESS) { - GE_LOGE("Failed to retrieve workspace list from op[%s, %s].", - op_desc_ptr->GetName().c_str(), op_desc_ptr->GetType().c_str()); - return ge::GRAPH_FAILED; - } - ge::Operator op_param(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN.c_str()); - (void)op_param.SetAttr(ATTR_NAME_ATOMIC_CLEAN_WORKSPACE.c_str(), workspace_list); - - std::string op_compile_info_key; - if (!ge::AttrUtils::GetStr(op_desc_ptr, ATOMIC_COMPILE_INFO_KEY, op_compile_info_key)) { - GE_LOGE("Op [%s] does not have attribute [%s].", op_desc_ptr->GetName().c_str(), ATOMIC_COMPILE_INFO_KEY.c_str()); - return ge::GRAPH_FAILED; - } - GenerateCompileInfoKey(workspace_size_list, op_compile_info_key); - void* op_compile_json_ptr = CompileInfoCache::Instance().GetCompileInfo(op_compile_info_key); - if (op_compile_json_ptr == nullptr) { - std::string op_compile_info_json; - if (!ge::AttrUtils::GetStr(op_desc_ptr, ATOMIC_COMPILE_INFO_JSON, op_compile_info_json)) { - GE_LOGE("Op [%s] does not have attribute [%s].", op_desc_ptr->GetName().c_str(), ATOMIC_COMPILE_INFO_JSON.c_str()); - return ge::GRAPH_FAILED; - } - if (AssembleCompileInfoJson(op_desc_ptr, workspace_size_list, op_compile_info_json) != ge::GRAPH_SUCCESS) { - GE_LOGE("Failed to assemble compile info json for op [%s, %s].", - op_desc_ptr->GetName().c_str(), op_desc_ptr->GetType().c_str()); - return ge::GRAPH_FAILED; - } - - const ge::AscendString compile_info_json_str = op_compile_info_json.c_str(); - op_compile_json_ptr = (parse_func)(op_param, compile_info_json_str); - if (op_compile_json_ptr == nullptr) { - GE_LOGE("Failed to parse compile json [%s] for op [%s, %s].", op_compile_info_json.c_str(), - op_desc_ptr->GetName().c_str(), op_desc_ptr->GetType().c_str()); - return ge::GRAPH_FAILED; - } - CompileInfoCache::Instance().SetCompileInfo(op_compile_info_key, op_compile_json_ptr); - } - - const bool ret = (tiling_func)(op_param, op_compile_json_ptr, run_info); - if (ret) { - GELOGI("Atomic optiling v3 succeeded. op_type: %s, op_name: %s.", - op_desc_ptr->GetType().c_str(), op_desc_ptr->GetName().c_str()); - } else { - GELOGW("Failed to call the tiling v3 function for atomic op [%s, %s].", - op_desc_ptr->GetName().c_str(), op_desc_ptr->GetType().c_str()); - } - op_param.BreakConnect(); - return ret ? ge::GRAPH_SUCCESS : ge::GRAPH_FAILED; -} - -ge::graphStatus TurnToOpAtomicCalculateV4(const ge::OpDescPtr &op_desc_ptr, OpRunInfoV2 &run_info, - const OpTilingFuncV4 &tiling_func, const OpParseFuncV4 &parse_func) { - GELOGI("Begin Atomic optiling V4 for op [%s, %s].", - op_desc_ptr->GetName().c_str(), op_desc_ptr->GetType().c_str()); - std::vector workspace_list; - std::vector workspace_size_list; - if (AssembleWorkspaceList(op_desc_ptr, workspace_list, workspace_size_list) != ge::GRAPH_SUCCESS) { - GE_LOGE("Failed to retrieve workspace list from op[%s, %s].", - op_desc_ptr->GetName().c_str(), op_desc_ptr->GetType().c_str()); - return ge::GRAPH_FAILED; - } - ge::Operator op_param(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN.c_str()); - (void)op_param.SetAttr(ATTR_NAME_ATOMIC_CLEAN_WORKSPACE.c_str(), workspace_list); - - std::string op_compile_info_key; - if (!ge::AttrUtils::GetStr(op_desc_ptr, ATOMIC_COMPILE_INFO_KEY, op_compile_info_key)) { - GE_LOGE("Op [%s] does not have attribute [%s].", op_desc_ptr->GetName().c_str(), ATOMIC_COMPILE_INFO_KEY.c_str()); - return ge::GRAPH_FAILED; - } - GenerateCompileInfoKey(workspace_size_list, op_compile_info_key); - CompileInfoPtr op_compile_info_ptr = CompileInfoManager::Instance().GetCompileInfo(op_compile_info_key); - if (op_compile_info_ptr == nullptr) { - std::string op_compile_info_json; - if (!ge::AttrUtils::GetStr(op_desc_ptr, ATOMIC_COMPILE_INFO_JSON, op_compile_info_json)) { - GE_LOGE("Op [%s] does not have attribute [%s].", op_desc_ptr->GetName().c_str(), ATOMIC_COMPILE_INFO_JSON.c_str()); - return ge::GRAPH_FAILED; - } - if (AssembleCompileInfoJson(op_desc_ptr, workspace_size_list, op_compile_info_json) != ge::GRAPH_SUCCESS) { - GE_LOGE("Failed to assemble compile info json for op [%s, %s].", - op_desc_ptr->GetName().c_str(), op_desc_ptr->GetType().c_str()); - return ge::GRAPH_FAILED; - } - - const ge::AscendString compile_info_json_str = op_compile_info_json.c_str(); - op_compile_info_ptr = (parse_func)(op_param, compile_info_json_str); - if (op_compile_info_ptr == nullptr) { - GE_LOGE("Failed to parse compile json [%s] for op [%s, %s].", op_compile_info_json.c_str(), - op_desc_ptr->GetName().c_str(), op_desc_ptr->GetType().c_str()); - return ge::GRAPH_FAILED; - } - CompileInfoManager::Instance().SetCompileInfo(op_compile_info_key, op_compile_info_ptr); - } - - const bool ret = (tiling_func)(op_param, op_compile_info_ptr, run_info); - if (ret) { - GELOGI("Atomic optiling v4 succeeded. op_type: %s, op_name: %s.", - op_desc_ptr->GetType().c_str(), op_desc_ptr->GetName().c_str()); - } else { - GELOGW("Failed to call the tiling v4 function of atomic op[%s, %s].", - op_desc_ptr->GetName().c_str(), op_desc_ptr->GetType().c_str()); - } - op_param.BreakConnect(); - return ret ? ge::GRAPH_SUCCESS : ge::GRAPH_FAILED; -} - -OpTilingFuncInfo *GetOpAtomicTilingInfo(const ge::OpDescPtr &op_desc) { - if (op_desc == nullptr) { - return nullptr; - } - if (op_desc->GetAtomicTilingFuncInfo() == nullptr) { - auto &op_func_map = OpTilingFuncRegistry::RegisteredOpFuncInfo(); - const auto iter = op_func_map.find(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - if (iter == op_func_map.end()) { - GE_LOGE("Atomic optiling func not found of op[%s, %s].", - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - return nullptr; - } - op_desc->SetAtomicTilingFuncInfo(::ge::PtrToPtr(&(iter->second))); - return &(iter->second); - } - return ::ge::PtrToPtr(op_desc->GetAtomicTilingFuncInfo()); -} - -extern "C" ge::graphStatus OpAtomicCalculateV2(const ge::Node &node, OpRunInfoV2 &run_info) { - const ge::OpDescPtr op_desc_ptr = node.GetOpDesc(); - OpTilingFuncInfo *op_func_info = GetOpAtomicTilingInfo(op_desc_ptr); - GE_CHECK_NOTNULL(op_func_info); - ge::graphStatus status = ge::GRAPH_FAILED; - if (op_func_info->IsFunctionV4()) { - const OpTilingFuncV4 &tiling_func = op_func_info->GetOpTilingFuncV4(); - const OpParseFuncV4 &parse_func = op_func_info->GetOpParseFuncV4(); - status = TurnToOpAtomicCalculateV4(op_desc_ptr, run_info, tiling_func, parse_func); - } else if (op_func_info->IsFunctionV3()) { - const OpTilingFuncV3 &tiling_func = op_func_info->GetOpTilingFuncV3(); - const OpParseFuncV3 &parse_func = op_func_info->GetOpParseFuncV3(); - status = TurnToOpAtomicCalculateV3(op_desc_ptr, run_info, tiling_func, parse_func); - } else if (op_func_info->IsFunctionV2()) { - const OpTilingFuncV2 &tiling_func = op_func_info->GetOpTilingFuncV2(); - status = TurnToOpAtomicCalculateV2(op_desc_ptr, run_info, tiling_func); - } else if (op_func_info->IsFunctionV1()) { - const OpTilingFunc &tiling_func = op_func_info->GetOpTilingFunc(); - status = TurnToOpAtomicCalculateV1(op_desc_ptr, run_info, tiling_func); - } else { - GE_LOGE("Optiling function for op type [%s] is entirely empty.", op_desc_ptr->GetType().c_str()); - } - return status; -} - -ge::graphStatus UpDateNodeShapeBySliceInfo(const ffts::ThreadSliceMapDyPtr slice_info_ptr, const ge::OpDescPtr op_desc, - const uint32_t thread_id, vector &ori_shape, - bool &same_shape) -{ - if ((thread_id >= slice_info_ptr->input_tensor_slice.size()) - || (thread_id >= slice_info_ptr->output_tensor_slice.size())) { - REPORT_INNER_ERR_MSG("E19999", "Update node shape thread id(%u) err.", thread_id); - return ge::GRAPH_FAILED; - } - ge::GeTensorDescPtr tensor_ptr = nullptr; - for (auto &index : slice_info_ptr->input_tensor_indexes) { - tensor_ptr = op_desc->MutableInputDesc(index); - GE_CHECK_NOTNULL(tensor_ptr); - ge::GeShape& shape = tensor_ptr->MutableShape(); - auto &tmp_dim = slice_info_ptr->input_tensor_slice[static_cast(thread_id)][index]; - if (tmp_dim.empty()) { - return ge::GRAPH_FAILED; - } - if (thread_id == 0U) { - ori_shape.emplace_back(shape.GetDim(0)); - auto &tail_dim = slice_info_ptr->input_tensor_slice[slice_info_ptr->slice_instance_num - 1][index]; - if (tail_dim.empty()) { - return ge::GRAPH_FAILED; - } - if (tail_dim[0] != tmp_dim[0]) { - same_shape = false; - } - } - (void)shape.SetDim(0, tmp_dim[0]); - } - for (auto &index : slice_info_ptr->output_tensor_indexes) { - tensor_ptr = op_desc->MutableOutputDesc(index); - GE_CHECK_NOTNULL(tensor_ptr); - ge::GeShape& shape = tensor_ptr->MutableShape(); - if (thread_id == 0U) { - ori_shape.emplace_back(shape.GetDim(0)); - } - auto &tmp_dim = slice_info_ptr->output_tensor_slice[static_cast(thread_id)][index]; - if (tmp_dim.empty()) { - return ge::GRAPH_FAILED; - } - (void)shape.SetDim(0, tmp_dim[0]); - GELOGD("Output anchor: %u set dim 0 to %ld", index, tmp_dim[0]); - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus UpDateNodeShapeBack(const ge::OpDescPtr op_desc, const ffts::ThreadSliceMapDyPtr slice_info_ptr, - vector &ori_shape) -{ - if (slice_info_ptr == nullptr || - ori_shape.size() != (slice_info_ptr->input_tensor_indexes.size() + - slice_info_ptr->output_tensor_indexes.size())) { - REPORT_INNER_ERR_MSG("E19999", "Update back node shape size err."); - return ge::GRAPH_FAILED; - } - size_t idx = 0; - for (auto &index : slice_info_ptr->input_tensor_indexes) { - ge::GeTensorDescPtr tensor_ptr = op_desc->MutableInputDesc(index); - GE_CHECK_NOTNULL(tensor_ptr); - ge::GeShape& shape = tensor_ptr->MutableShape(); - (void)shape.SetDim(0, ori_shape[idx++]); - } - for (auto &index : slice_info_ptr->output_tensor_indexes) { - ge::GeTensorDescPtr tensor_ptr = op_desc->MutableOutputDesc(index); - GE_CHECK_NOTNULL(tensor_ptr); - ge::GeShape& shape = tensor_ptr->MutableShape(); - (void)shape.SetDim(0, ori_shape[idx++]); - } - GELOGD("Node shape update reverted successfully."); - return ge::GRAPH_SUCCESS; -} - -// For FFTS+ dynamic shape -extern "C" ge::graphStatus OpFftsPlusCalculate(const ge::Operator &op, std::vector &op_run_info) -{ - const auto node = ge::NodeUtilsEx::GetNodeFromOperator(op); - GE_CHECK_NOTNULL(node); - const auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - GELOGD("[OpFftsPlusCalculate]Op_type:%s, op_name:%s", op_desc->GetType().c_str(), op_desc->GetName().c_str()); - ffts::ThreadSliceMapDyPtr slice_info_ptr = nullptr; - slice_info_ptr = op_desc->TryGetExtAttr(ffts::kAttrSgtStructInfoDy, slice_info_ptr); - GE_CHECK_NOTNULL(slice_info_ptr); - if (slice_info_ptr->slice_instance_num != slice_info_ptr->input_tensor_slice.size() || - slice_info_ptr->slice_instance_num != slice_info_ptr->output_tensor_slice.size()) { - REPORT_INNER_ERR_MSG("E19999", "Slice num not equal."); - return ge::GRAPH_FAILED; - } - vector ori_shape; // save original shape - uint32_t thread_id = 0U; - op_run_info.resize(ffts::kSgtTillingNum); - bool same_shape = true; - for (size_t i = 0U; i < static_cast(ffts::kSgtTillingNum); i++) { - // update node shape by thread slice info - if (UpDateNodeShapeBySliceInfo(slice_info_ptr, op_desc, thread_id, ori_shape, same_shape) == ge::GRAPH_FAILED) { - REPORT_INNER_ERR_MSG("E19999", "Update shape failed."); - return ge::GRAPH_FAILED; - } - // call original interface - const ge::graphStatus rc = OpParaCalculateV2(op, op_run_info[i]); - if (rc != ge::GRAPH_SUCCESS) { - REPORT_INNER_ERR_MSG("E19999", "OpParaCalculateV2 failed, op_type:%s, op_name:%s", op_desc->GetType().c_str(), - op_desc->GetName().c_str()); - return rc; - } - if (same_shape) { - op_run_info[1] = op_run_info[0]; - break; - } - thread_id = slice_info_ptr->slice_instance_num - 1U; - } - // node shape write_back - (void)UpDateNodeShapeBack(op_desc, slice_info_ptr, ori_shape); - return ge::GRAPH_SUCCESS; -} -} // namespace optiling diff --git a/register/op_tiling/op_tiling_attr_utils.cc b/register/op_tiling/op_tiling_attr_utils.cc deleted file mode 100644 index 2cbc6853a16609ca8c8e1e43a6155b785fee7079..0000000000000000000000000000000000000000 --- a/register/op_tiling/op_tiling_attr_utils.cc +++ /dev/null @@ -1,404 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/op_tiling_attr_utils.h" -#include -#include -#include "graph/debug/ge_log.h" -#include "op_tiling/op_tiling_utils.h" -#include "common/util/tiling_utils.h" - -namespace optiling { -template -class AttrDataImpl : public AttrData { -public: - explicit AttrDataImpl(std::vector &value) : data_value_(std::move(value)) {} - ~AttrDataImpl() override {} - size_t GetSize() const override { - return data_value_.size() * sizeof(T); - } - const std::uint8_t *GetData() override { - if (data_value_.empty()) { - return nullptr; - } else { - return reinterpret_cast(data_value_.data()); - } - } - -private: - std::vector data_value_; -}; - -enum class AttrDataType { - BOOL, - STRING, - INT32, - UINT32, - FLOAT32, - FLOAT16, - BFLOAT16, - LIST_BOOL, - LIST_STRING, - LIST_INT32, - LIST_UINT32, - LIST_FLOAT32, - LIST_FLOAT16, - LIST_LIST_INT32 -}; - -static const std::map kAttrDataTypeMap { - {"bool", AttrDataType::BOOL}, - {"str", AttrDataType::STRING}, - {"string", AttrDataType::STRING}, - {"int", AttrDataType::INT32}, - {"int32", AttrDataType::INT32}, - {"uint", AttrDataType::UINT32}, - {"uint32", AttrDataType::UINT32}, - {"float", AttrDataType::FLOAT32}, - {"float32", AttrDataType::FLOAT32}, - {"float16", AttrDataType::FLOAT16}, - {"bfloat16", AttrDataType::BFLOAT16}, - {"list_int", AttrDataType::LIST_INT32}, - {"list_int32", AttrDataType::LIST_INT32}, - {"list_uint", AttrDataType::LIST_UINT32}, - {"list_uint32", AttrDataType::LIST_UINT32}, - {"list_float16", AttrDataType::LIST_FLOAT16}, - {"list_float", AttrDataType::LIST_FLOAT32}, - {"list_float32", AttrDataType::LIST_FLOAT32} -}; - -static const std::vector kValidSrcDTypeList { - AttrDataType::BOOL, - AttrDataType::STRING, - AttrDataType::INT32, - AttrDataType::FLOAT32, - AttrDataType::LIST_INT32, - AttrDataType::LIST_FLOAT32 -}; - -static const std::vector kValidDstDTypeList { - AttrDataType::UINT32, - AttrDataType::INT32, - AttrDataType::FLOAT32, - AttrDataType::FLOAT16, - AttrDataType::BFLOAT16, - AttrDataType::LIST_INT32, - AttrDataType::LIST_UINT32, - AttrDataType::LIST_FLOAT16, - AttrDataType::LIST_FLOAT32 -}; - -static const uint32_t kBitsOfByte = 8; - -inline uint32_t GenerateAttrFuncKey(const AttrDataType attr_dtype) { - return ((static_cast(attr_dtype) & 0xFFU) << kBitsOfByte) | (static_cast(attr_dtype) & 0xFFU); -} - -inline uint32_t GenerateAttrFuncKey(const AttrDataType src_dtype, const AttrDataType dest_dtype) { - return ((static_cast(src_dtype) & 0xFFU) << kBitsOfByte) | (static_cast(dest_dtype) & 0xFFU); -} - -class AttrDataManager; -using GetOpAttrValueFunc = std::function; - -class AttrDataManager { -public: - AttrDataManager(const AttrDataManager &) = delete; - AttrDataManager &operator=(const AttrDataManager &) = delete; - static AttrDataManager& Instance() { - static AttrDataManager attr_data_manager; - return attr_data_manager; - } - - bool VerifyAttrDtype(const AttrDataType src_dtype, const AttrDataType dst_dtype) { - const uint32_t func_key = GenerateAttrFuncKey(src_dtype, dst_dtype); - const auto iter = attr_func_.find(func_key); - return iter != attr_func_.end(); - } - - AttrDataPtr GetOpAttrValue(const ge::Operator &op, const char *attr_name, const AttrDataType src_dtype, - const AttrDataType dst_dtype) { - const uint32_t func_key = GenerateAttrFuncKey(src_dtype, dst_dtype); - const auto iter = attr_func_.find(func_key); - if (iter == attr_func_.end()) { - return nullptr; - } - return iter->second(this, op, attr_name); - } - -private: - AttrDataManager() {} - ~AttrDataManager() {} - template::type = true> - AttrDataPtr GetAttrValue(const ge::Operator &op, const char *attr_name) const { - T attr_value; - if (op.GetAttr(attr_name, attr_value) != ge::GRAPH_SUCCESS) { - GELOGW("Failed to retrieve attribute [%s] from op.", attr_name); - return nullptr; - } - std::vector attr_vec; - attr_vec.push_back(attr_value); - AttrDataPtr attr_data_ptr = nullptr; - OP_TILING_MAKE_SHARED(attr_data_ptr = std::make_shared>(attr_vec), return nullptr); - return attr_data_ptr; - } - - template::type = true> - AttrDataPtr GetAttrValue(const ge::Operator &op, const char *attr_name) const { - std::vector attr_vec; - if (op.GetAttr(attr_name, attr_vec) != ge::GRAPH_SUCCESS) { - GELOGW("Failed to retrieve attribute [%s] from op.", attr_name); - return nullptr; - } - if (attr_vec.empty()) { - GELOGW("The vector value of attribute [%s] is empty.", attr_name); - return nullptr; - } - AttrDataPtr attr_data_ptr = nullptr; - OP_TILING_MAKE_SHARED(attr_data_ptr = std::make_shared>(attr_vec), return nullptr); - return attr_data_ptr; - } - - AttrDataPtr GetBoolAttrValue(const ge::Operator &op, const char *attr_name) const { - bool attr_value = false; - if (op.GetAttr(attr_name, attr_value) != ge::GRAPH_SUCCESS) { - GELOGW("Failed to retrieve attribute [%s] from op.", attr_name); - return nullptr; - } - - std::vector attr_vec; - attr_vec.push_back(static_cast(attr_value)); - - AttrDataPtr attr_data_ptr = nullptr; - OP_TILING_MAKE_SHARED(attr_data_ptr = std::make_shared>(attr_vec), return nullptr); - return attr_data_ptr; - } - - AttrDataPtr GetStrAttrValue(const ge::Operator &op, const char *attr_name) const { - std::string attr_value; - if (op.GetAttr(attr_name, attr_value) != ge::GRAPH_SUCCESS) { - GELOGW("Failed to retrieve attribute [%s] from op.", attr_name); - return nullptr; - } - - std::vector attr_vec; - for (const char &c : attr_value) { - attr_vec.push_back(c); - } - - AttrDataPtr attr_data_ptr = nullptr; - OP_TILING_MAKE_SHARED(attr_data_ptr = std::make_shared>(attr_vec), return nullptr); - return attr_data_ptr; - } - - AttrDataPtr GetIntAttrValueAndToUint(const ge::Operator &op, const char *attr_name) const { - int32_t attr_value = 0; - if (op.GetAttr(attr_name, attr_value) != ge::GRAPH_SUCCESS) { - GELOGW("Failed to retrieve attribute [%s] from op.", attr_name); - return nullptr; - } - - std::vector attr_vec; - attr_vec.push_back(static_cast(attr_value)); - - AttrDataPtr attr_data_ptr = nullptr; - OP_TILING_MAKE_SHARED(attr_data_ptr = std::make_shared>(attr_vec), return nullptr); - return attr_data_ptr; - } - - AttrDataPtr GetListIntAttrValueAndToListUint(const ge::Operator &op, const char *attr_name) const { - std::vector attr_value; - if (op.GetAttr(attr_name, attr_value) != ge::GRAPH_SUCCESS) { - GELOGW("Failed to retrieve attribute [%s] from op.", attr_name); - return nullptr; - } - - if (attr_value.empty()) { - GELOGW("The vector value of attribute [%s] is empty.", attr_name); - return nullptr; - } - - std::vector attr_vec; - for (const int32_t &int_value : attr_value) { - attr_vec.push_back(static_cast(int_value)); - } - - AttrDataPtr attr_data_ptr = nullptr; - OP_TILING_MAKE_SHARED(attr_data_ptr = std::make_shared>(attr_vec), return nullptr); - return attr_data_ptr; - } - - AttrDataPtr GetFloatAttrValueAndToFp16(const ge::Operator &op, const char *attr_name) const { - float attr_value = 0.0f; - if (op.GetAttr(attr_name, attr_value) != ge::GRAPH_SUCCESS) { - GELOGW("Failed to retrieve attribute [%s] from op.", attr_name); - return nullptr; - } - - std::vector attr_vec; - attr_vec.push_back(Float32ToFloat16(attr_value)); - - AttrDataPtr attr_data_ptr = nullptr; - OP_TILING_MAKE_SHARED(attr_data_ptr = std::make_shared>(attr_vec), return nullptr); - return attr_data_ptr; - } - - AttrDataPtr GetFloatAttrValueAndToBf16(const ge::Operator &op, const char *attr_name) const { - float attr_value = 0.0F; - if (op.GetAttr(attr_name, attr_value) != ge::GRAPH_SUCCESS) { - GELOGW("Failed to retrieve attribute [%s] from op.", attr_name); - return nullptr; - } - - std::vector attr_vec; - attr_vec.push_back(Float32ToBfloat16(attr_value)); - - AttrDataPtr attr_data_ptr = nullptr; - OP_TILING_MAKE_SHARED(attr_data_ptr = std::make_shared>(attr_vec), return nullptr); - return attr_data_ptr; - } - - AttrDataPtr GetFloatAttrValueAndToInt(const ge::Operator &op, const char *attr_name) const { - float attr_value = 0.0f; - if (op.GetAttr(attr_name, attr_value) != ge::GRAPH_SUCCESS) { - GELOGW("Failed to retrieve attribute [%s] from op.", attr_name); - return nullptr; - } - - std::vector attr_vec; - attr_vec.push_back(static_cast(attr_value)); - - AttrDataPtr attr_data_ptr = nullptr; - OP_TILING_MAKE_SHARED(attr_data_ptr = std::make_shared>(attr_vec), return nullptr); - return attr_data_ptr; - } - - AttrDataPtr GetListFloatAttrValueAndToListFp16(const ge::Operator &op, const char *attr_name) const { - std::vector attr_value; - if (op.GetAttr(attr_name, attr_value) != ge::GRAPH_SUCCESS) { - GELOGW("Failed to retrieve attribute [%s] from op.", attr_name); - return nullptr; - } - - if (attr_value.empty()) { - GELOGW("The vector value of attribute [%s] is empty.", attr_name); - return nullptr; - } - - std::vector attr_vec; - for (const float &fp_value : attr_value) { - attr_vec.push_back(Float32ToFloat16(fp_value)); - } - - AttrDataPtr attr_data_ptr = nullptr; - OP_TILING_MAKE_SHARED(attr_data_ptr = std::make_shared>(attr_vec), return nullptr); - return attr_data_ptr; - } - - AttrDataPtr GetListFloatAttrValueAndToListInt(const ge::Operator &op, const char *attr_name) const { - std::vector attr_value; - if (op.GetAttr(attr_name, attr_value) != ge::GRAPH_SUCCESS) { - GELOGW("Failed to retrieve attribute [%s] from op.", attr_name); - return nullptr; - } - - if (attr_value.empty()) { - GELOGW("The vector value of attribute [%s] is empty.", attr_name); - return nullptr; - } - - std::vector attr_vec; - for (const float &fp_value : attr_value) { - attr_vec.push_back(static_cast(fp_value)); - } - - AttrDataPtr attr_data_ptr = nullptr; - OP_TILING_MAKE_SHARED(attr_data_ptr = std::make_shared>(attr_vec), return nullptr); - return attr_data_ptr; - } - - static const std::map attr_func_; -}; - -const std::map AttrDataManager::attr_func_ = { - {GenerateAttrFuncKey(AttrDataType::BOOL), &AttrDataManager::GetBoolAttrValue}, - {GenerateAttrFuncKey(AttrDataType::INT32), &AttrDataManager::GetAttrValue}, - {GenerateAttrFuncKey(AttrDataType::FLOAT32), &AttrDataManager::GetAttrValue}, - {GenerateAttrFuncKey(AttrDataType::LIST_INT32), &AttrDataManager::GetAttrValue}, - {GenerateAttrFuncKey(AttrDataType::LIST_FLOAT32), &AttrDataManager::GetAttrValue}, - {GenerateAttrFuncKey(AttrDataType::STRING), &AttrDataManager::GetStrAttrValue}, - {GenerateAttrFuncKey(AttrDataType::INT32, AttrDataType::UINT32), - &AttrDataManager::GetIntAttrValueAndToUint}, - {GenerateAttrFuncKey(AttrDataType::LIST_INT32, AttrDataType::LIST_UINT32), - &AttrDataManager::GetListIntAttrValueAndToListUint}, - {GenerateAttrFuncKey(AttrDataType::FLOAT32, AttrDataType::FLOAT16), - &AttrDataManager::GetFloatAttrValueAndToFp16}, - {GenerateAttrFuncKey(AttrDataType::FLOAT32, AttrDataType::BFLOAT16), - &AttrDataManager::GetFloatAttrValueAndToBf16}, - {GenerateAttrFuncKey(AttrDataType::LIST_FLOAT32, AttrDataType::LIST_FLOAT16), - &AttrDataManager::GetListFloatAttrValueAndToListFp16}, - {GenerateAttrFuncKey(AttrDataType::FLOAT32, AttrDataType::INT32), - &AttrDataManager::GetFloatAttrValueAndToInt}, - {GenerateAttrFuncKey(AttrDataType::LIST_FLOAT32, AttrDataType::LIST_INT32), - &AttrDataManager::GetListFloatAttrValueAndToListInt} -}; - -ge::graphStatus GetOperatorAttrValue(const ge::Operator &op, const char *attr_name, const char *attr_dtype, - AttrDataPtr &attr_data_ptr, const char *target_dtype) { - ge::AscendString op_name; - ge::AscendString op_type; - (void)op.GetName(op_name); - (void)op.GetOpType(op_type); - if (attr_name == nullptr || attr_dtype == nullptr) { - GE_LOGE("Attribute name or attribute data type is null for op [%s].", op_name.GetString()); - return ge::GRAPH_FAILED; - } - - GELOGD("Begin to retrieve attribute [%s] of data type [%s] from op [%s, %s].", - attr_name, attr_dtype, op_name.GetString(), op_type.GetString()); - const std::string attr_dtype_str = attr_dtype; - auto iter = kAttrDataTypeMap.find(attr_dtype_str); - if (iter == kAttrDataTypeMap.end()) { - GELOGW("Attr data type [%s] for attribute [%s] is not supported.", attr_dtype, attr_name); - return ge::GRAPH_FAILED; - } - const AttrDataType src_dtype = iter->second; - if (std::find(kValidSrcDTypeList.begin(), kValidSrcDTypeList.end(), src_dtype) == kValidSrcDTypeList.end()) { - GELOGW("Attr data type [%s] for attribute [%s] is not supported.", attr_dtype, attr_name); - return ge::GRAPH_FAILED; - } - AttrDataType dst_dtype = src_dtype; - if (target_dtype != nullptr) { - GELOGD("Attempting to retrieve attribute [%s] and transform its value from [%s] to [%s].", - attr_name, attr_dtype, target_dtype); - const std::string target_dtype_str = target_dtype; - iter = kAttrDataTypeMap.find(target_dtype_str); - if (iter == kAttrDataTypeMap.end()) { - GELOGW("Target attr data type[%s] of attr[%s] is not supported.", target_dtype, attr_name); - return ge::GRAPH_FAILED; - } - dst_dtype = iter->second; - if (dst_dtype != src_dtype) { - if (std::find(kValidDstDTypeList.begin(), kValidDstDTypeList.end(), dst_dtype) == kValidDstDTypeList.end()) { - GELOGW("Target attr data type[%s] of attr[%s] is not supported.", target_dtype, attr_name); - return ge::GRAPH_FAILED; - } - if (!AttrDataManager::Instance().VerifyAttrDtype(src_dtype, dst_dtype)) { - GELOGW("Get attr[%s] and transform from [%s] to [%s] is not supported.", - attr_name, attr_dtype, target_dtype); - return ge::GRAPH_FAILED; - } - } - } - attr_data_ptr = AttrDataManager::Instance().GetOpAttrValue(op, attr_name, src_dtype, dst_dtype); - GELOGD("Finished getting attr [%s] of data type [%s] from op [%s, %s].", - attr_name, attr_dtype, op_name.GetString(), op_type.GetString()); - return attr_data_ptr == nullptr ? ge::GRAPH_FAILED : ge::GRAPH_SUCCESS; -} -} // namespace optiling diff --git a/register/op_tiling/op_tiling_constants.h b/register/op_tiling/op_tiling_constants.h deleted file mode 100644 index ec77ece828fcb5fc133a1e0a9aee17141c6e9fd9..0000000000000000000000000000000000000000 --- a/register/op_tiling/op_tiling_constants.h +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef REGISTER_OP_TILING_OP_TILING_CONSTANTS_H_ -#define REGISTER_OP_TILING_OP_TILING_CONSTANTS_H_ - -#include -#include -#include "graph/types.h" - -namespace optiling { -const std::string COMPILE_INFO_JSON = "compile_info_json"; -const std::string COMPILE_INFO_KEY = "compile_info_key"; -const std::string COMPILE_INFO_WORKSPACE_SIZE_LIST = "_workspace_size_list"; -const std::string ATOMIC_COMPILE_INFO_JSON = "_atomic_compile_info_json"; -const std::string ATOMIC_COMPILE_INFO_KEY = "_atomic_compile_info_key"; -const std::string ATTR_NAME_ATOMIC_CLEAN_WORKSPACE = "_optiling_atomic_add_mem_size"; -const std::string ATTR_NAME_OP_INFER_DEPENDS = "_op_infer_depends"; -const std::string OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN = "DynamicAtomicAddrClean"; -const std::string OP_TYPE_AUTO_TILING = "AutoTiling"; -const std::string kMemoryCheck = "_memcheck"; -const std::string kOriOpParaSize = "ori_op_para_size"; -const std::map DATATYPE_STRING_MAP { - {ge::DT_FLOAT, "float32"}, - {ge::DT_FLOAT16, "float16"}, - {ge::DT_INT8, "int8"}, - {ge::DT_INT16, "int16"}, - {ge::DT_INT32, "int32"}, - {ge::DT_INT64, "int64"}, - {ge::DT_UINT8, "uint8"}, - {ge::DT_UINT16, "uint16"}, - {ge::DT_UINT32, "uint32"}, - {ge::DT_UINT64, "uint64"}, - {ge::DT_BOOL, "bool"}, - {ge::DT_DOUBLE, "double"}, - {ge::DT_DUAL, "dual"}, - {ge::DT_DUAL_SUB_INT8, "dual_sub_int8"}, - {ge::DT_DUAL_SUB_UINT8, "dual_sub_uint8"} -}; - -} // namespace optiling - -#endif // REGISTER_OP_TILING_OP_TILING_CONSTANTS_H_ diff --git a/register/op_tiling/op_tiling_info.cc b/register/op_tiling/op_tiling_info.cc deleted file mode 100644 index 2275fb7824c26d867a96775aafbd319f3ce85092..0000000000000000000000000000000000000000 --- a/register/op_tiling/op_tiling_info.cc +++ /dev/null @@ -1,404 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/op_tiling_info.h" -#include -#include "graph/debug/ge_log.h" -#include "graph/def_types.h" - -namespace optiling { -using std::make_shared; - -namespace utils { -class OpRunInfoImpl { -public: - OpRunInfoImpl() = default; - ~OpRunInfoImpl() = default; - - OpRunInfoImpl(const uint32_t &block_dim, const bool &clear_atomic, const uint64_t &tiling_key) - : block_dim_(block_dim), - clear_atomic_(clear_atomic), - tiling_key_(tiling_key), - addr_base_(nullptr), - max_size_(0), - offset_(0), - tiling_cond_(-1), - schedule_mode_(0U) {} - - void SetBlockDim(const uint32_t &block_dim) { block_dim_ = block_dim; } - - uint32_t GetBlockDim() const { return block_dim_; } - - void SetAicpuBlockDim(uint32_t block_dim) { aicpu_block_dim_ = block_dim; } - - uint32_t GetAicpuBlockDim() const { return aicpu_block_dim_; } - - void SetScheduleMode(const uint32_t schedule_mode) { - schedule_mode_ = schedule_mode; - } - - uint32_t GetScheduleMode() const { - return schedule_mode_; - } - - void AddWorkspace(const int64_t &workspace) { workspaces_.push_back(workspace); } - - size_t GetWorkspaceNum() const { return workspaces_.size(); } - - ge::graphStatus GetWorkspace(const size_t &idx, int64_t &workspace) const { - if ((!workspaces_.empty()) && (idx < workspaces_.size())) { - workspace = workspaces_[idx]; - return ge::GRAPH_SUCCESS; - } - return ge::GRAPH_FAILED; - } - - void GetAllWorkspaces(std::vector &workspaces) const { workspaces = workspaces_; } - - const std::vector &GetAllWorkspaces() const { return workspaces_; } - - void SetWorkspaces(const std::vector &workspaces) { workspaces_ = workspaces; } - - void AddTilingData(const char *value, const size_t size) { - if (addr_base_ == nullptr) { - (void)tiling_data_.write(value, static_cast(size)); - (void)tiling_data_.flush(); - } else { - auto addr = ::ge::ValueToPtr(::ge::PtrToValue(addr_base_) + offset_); - if (memcpy_s(addr, static_cast(max_size_ - offset_), value, size) != EOK) { - GELOGE(ge::GRAPH_FAILED, "[Add][TilingData] Memcpy tiling data failed, " - "dst size = %zu, src size = %zu.", static_cast(max_size_ - offset_), size); - REPORT_INNER_ERR_MSG("E19999", "[Add][TilingData] Memcpy tiling data failed, dst size = %zu, src size = %zu.", - static_cast(max_size_ - offset_), size); - return; - } - offset_ += size; - } - } - - void AlignOffsetWith64() { - const uint64_t offset = (offset_ + sizeof(uint64_t) - 1U) / sizeof(uint64_t); - offset_ = offset * sizeof(uint64_t); - } - - bool SetMemCheckBaseOffset(uint64_t offset) { - GELOGD("When max size is %lu, set a new offset [%lu] to replace the original offset [%lu].", max_size_, offset, offset_); - if (offset >= std::numeric_limits::max() - sizeof(uint64_t)) { - GELOGE(ge::GRAPH_FAILED, "Offset overflow."); - return false; - } - uint64_t new_offset = (offset + sizeof(uint64_t) - 1U) / sizeof(uint64_t); - new_offset = new_offset * sizeof(uint64_t); - if (new_offset < offset_ || new_offset >= max_size_) { - return false; - } - offset_ = new_offset; - return true; - } - - void* GetAddrBase(uint64_t& max_size) const { - max_size = max_size_; - return addr_base_; - } - - void SetAddrBaseOffset(const uint64_t size) { - offset_ = size; - } - - const ByteBuffer &GetAllTilingData() const { return tiling_data_; } - - ByteBuffer &GetAllTilingData() { return tiling_data_; } - - uint64_t GetTilingDataSize() const { return offset_; } - void SetAllTilingData(const ByteBuffer &value) { - tiling_data_.clear(); - offset_ = 0; - AddTilingData(value.str().c_str(), value.str().size()); - } - - void SetClearAtomic(const bool clear_atomic) { clear_atomic_ = clear_atomic; } - - bool GetClearAtomic() const { return clear_atomic_; } - - void SetTilingKey(const uint64_t &tiling_key) { tiling_key_ = tiling_key; } - - uint64_t GetTilingKey() const { return tiling_key_; } - - void ResetWorkspace() { - workspaces_.clear(); - } - - void ResetAddrBase(void *const addr_base, const uint64_t max_size) { - addr_base_ = addr_base; - max_size_ = max_size; - offset_ = 0; - } - - void SetTilingCond(const int32_t tiling_cond) { tiling_cond_ = tiling_cond; } - - int32_t GetTilingCond() const { return tiling_cond_; } - - void SetLocalMemorySize(const uint32_t local_memory_size) { - local_memory_size_ = local_memory_size; - } - - uint32_t GetLocalMemorySize() const { - return local_memory_size_; - } - -private: - uint32_t block_dim_; - bool clear_atomic_; - uint64_t tiling_key_; - ByteBuffer tiling_data_; - std::vector workspaces_; - void *addr_base_; - uint64_t max_size_; - uint64_t offset_; - int32_t tiling_cond_; - uint32_t schedule_mode_; - uint32_t local_memory_size_ = 0U; - uint32_t aicpu_block_dim_ = 0U; -}; - -OpRunInfo::OpRunInfo() { - impl_ = make_shared(); -} - -OpRunInfo::OpRunInfo(const uint32_t &block_dim, const bool &clear_atomic, const uint64_t &tiling_key) { - impl_ = make_shared(block_dim, clear_atomic, tiling_key); -} - -OpRunInfo::OpRunInfo(const OpRunInfo &runinfo) { - impl_ = make_shared(runinfo.GetBlockDim(), runinfo.GetClearAtomic(), runinfo.GetTilingKey()); - std::vector workspaces; - runinfo.GetAllWorkspaces(workspaces); - impl_->SetWorkspaces(workspaces); - impl_->SetAllTilingData(runinfo.GetAllTilingData()); - impl_->SetLocalMemorySize(runinfo.GetLocalMemorySize()); -} - -OpRunInfo::OpRunInfo(OpRunInfo &&runinfo) { - impl_ = std::move(runinfo.impl_); -} - -OpRunInfo &OpRunInfo::operator=(const OpRunInfo &runinfo) { - if (&runinfo != this) { - impl_ = make_shared(runinfo.GetBlockDim(), runinfo.GetClearAtomic(), runinfo.GetTilingKey()); - std::vector workspaces; - runinfo.GetAllWorkspaces(workspaces); - impl_->SetWorkspaces(workspaces); - impl_->SetAllTilingData(runinfo.GetAllTilingData()); - impl_->SetLocalMemorySize(runinfo.GetLocalMemorySize()); - } - return *this; -} - -OpRunInfo &OpRunInfo::operator=(OpRunInfo &&runinfo) { - if (&runinfo != this) { - impl_ = std::move(runinfo.impl_); - } - return *this; -} - -void OpRunInfo::SetBlockDim(const uint32_t &block_dim) { - impl_->SetBlockDim(block_dim); -} - -void OpRunInfo::SetAicpuBlockDim(uint32_t block_dim) { - impl_->SetAicpuBlockDim(block_dim); -} - -void OpRunInfo::SetScheduleMode(const uint32_t schedule_mode) { - impl_->SetScheduleMode(schedule_mode); -} - -uint32_t OpRunInfo::GetScheduleMode() const { - return impl_->GetScheduleMode(); -} - -uint32_t OpRunInfo::GetBlockDim() const { - return impl_->GetBlockDim(); -} - -uint32_t OpRunInfo::GetAicpuBlockDim() const { - return impl_->GetAicpuBlockDim(); -} - -void OpRunInfo::AddWorkspace(const int64_t &workspace) { - impl_->AddWorkspace(workspace); -} - -size_t OpRunInfo::GetWorkspaceNum() const { - return impl_->GetWorkspaceNum(); -} - -ge::graphStatus OpRunInfo::GetWorkspace(const size_t &idx, int64_t &workspace) const { - return impl_->GetWorkspace(idx, workspace); -} - -void OpRunInfo::GetAllWorkspaces(std::vector &workspaces) const { - impl_->GetAllWorkspaces(workspaces); -} - -const std::vector &OpRunInfo::GetAllWorkspaces() const { - return impl_->GetAllWorkspaces(); -} - -void OpRunInfo::SetWorkspaces(const std::vector &workspaces) { - impl_->SetWorkspaces(workspaces); -} - -void OpRunInfo::InternelSetTiling(const ByteBuffer &value) { - impl_->SetAllTilingData(value); -} - -void OpRunInfo::AddTilingData(const ge::char_t *value, const size_t size) { - impl_->AddTilingData(value, size); -} - -void OpRunInfo::AlignOffsetWith64() { - return impl_->AlignOffsetWith64(); -} - -bool OpRunInfo::SetMemCheckBaseOffset(const uint64_t &offset) { - return impl_->SetMemCheckBaseOffset(offset); -} - -void* OpRunInfo::GetAddrBase(uint64_t& max_size) const { - return impl_->GetAddrBase(max_size); -} - -void OpRunInfo::SetAddrBaseOffset(const uint64_t size) { - impl_->SetAddrBaseOffset(size); -} - -ByteBuffer &OpRunInfo::GetAllTilingData() { - return impl_->GetAllTilingData(); -} - -const ByteBuffer &OpRunInfo::GetAllTilingData() const { - return impl_->GetAllTilingData(); -} -uint64_t OpRunInfo::GetTilingDataSize() const { - return impl_->GetTilingDataSize(); -} -void OpRunInfo::SetClearAtomic(const bool clear_atomic) { - impl_->SetClearAtomic(clear_atomic); -} - -bool OpRunInfo::GetClearAtomic() const { - return impl_->GetClearAtomic(); -} - -void OpRunInfo::SetTilingKey(const uint64_t &new_tiling_key) { - impl_->SetTilingKey(new_tiling_key); -} - -uint64_t OpRunInfo::GetTilingKey() const { - return impl_->GetTilingKey(); -} - -void OpRunInfo::ResetWorkspace() { - impl_->ResetWorkspace(); -} - -void OpRunInfo::ResetAddrBase(void *const addr_base, const uint64_t max_size) { - impl_->ResetAddrBase(addr_base, max_size); -} - -void OpRunInfo::SetTilingCond(const int32_t tiling_cond) { - impl_->SetTilingCond(tiling_cond); -} - -int32_t OpRunInfo::GetTilingCond() const { - return impl_->GetTilingCond(); -} - -void OpRunInfo::SetLocalMemorySize(const uint32_t local_memory_size) { - impl_->SetLocalMemorySize(local_memory_size); -} - -uint32_t OpRunInfo::GetLocalMemorySize() const { - return impl_->GetLocalMemorySize(); -} - -class OpCompileInfoImpl { -public: - OpCompileInfoImpl() : key_(), value_() {} - ~OpCompileInfoImpl() = default; - OpCompileInfoImpl(const ge::AscendString &key, const ge::AscendString &value) : key_(key), value_(value) {} - OpCompileInfoImpl(const std::string &key, const std::string &value) : key_(key.c_str()), value_(value.c_str()) {} - - void SetKey(const ge::AscendString &key) { key_ = key; } - - void SetValue(const ge::AscendString &value) { value_ = value; } - - const ge::AscendString &GetKey() const { return key_; } - - const ge::AscendString &GetValue() const { return value_; } - -private: - ge::AscendString key_; - ge::AscendString value_; -}; - -OpCompileInfo::OpCompileInfo() { - impl_ = make_shared(); -} - -OpCompileInfo::OpCompileInfo(const ge::AscendString &key, const ge::AscendString &value) { - impl_ = make_shared(key, value); -} - -OpCompileInfo::OpCompileInfo(const std::string &key, const std::string &value) { - impl_ = make_shared(key, value); -} - -OpCompileInfo::OpCompileInfo(const OpCompileInfo &compileinfo) { - impl_ = make_shared(); - *impl_ = *compileinfo.impl_; -} - -OpCompileInfo::OpCompileInfo(OpCompileInfo &&compileinfo) { - impl_ = std::move(compileinfo.impl_); -} - -OpCompileInfo &OpCompileInfo::operator=(const OpCompileInfo &compileinfo) { - if (&compileinfo != this) { - impl_ = make_shared(); - *impl_ = *compileinfo.impl_; - } - return *this; -} - -OpCompileInfo &OpCompileInfo::operator=(OpCompileInfo &&compileinfo) { - if (&compileinfo != this) { - impl_ = std::move(compileinfo.impl_); - } - return *this; -} - -void OpCompileInfo::SetKey(const ge::AscendString &key) { - impl_->SetKey(key); -} - -void OpCompileInfo::SetValue(const ge::AscendString &value) { - impl_->SetValue(value); -} - -const ge::AscendString &OpCompileInfo::GetKey() const { - return impl_->GetKey(); -} - -const ge::AscendString &OpCompileInfo::GetValue() const { - return impl_->GetValue(); -} -} // namespace utils -} // namespace optiling diff --git a/register/op_tiling/op_tiling_py.cc b/register/op_tiling/op_tiling_py.cc deleted file mode 100644 index 643e335e3501121c95c76df8fdc58bfde9c5a337..0000000000000000000000000000000000000000 --- a/register/op_tiling/op_tiling_py.cc +++ /dev/null @@ -1,1985 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/ge_tensor.h" -#include "graph/op_desc.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/type_utils.h" -#include "graph/debug/ge_log.h" -#include "graph/debug/ge_attr_define.h" -#include "register/op_tiling_info.h" -#include "register/op_tiling_registry.h" -#include "register/op_impl_space_registry.h" -#include "op_tiling/op_tiling_utils.h" -#include "op_tiling/op_tiling_constants.h" -#include "common/util/tiling_utils.h" -#include "platform/platform_info.h" -#include "register/op_impl_registry.h" -#include "register/op_impl_registry_base.h" -#include "exe_graph/runtime/storage_shape.h" -#include "common/util/error_manager/error_manager.h" -#include "exe_graph/lowering/kernel_run_context_builder.h" -#include "exe_graph/runtime/tiling_context.h" -#include "common/checker.h" -#include "graph/utils/math_util.h" -#include "external/hcom/hcom_topo_info.h" -#include "inc/external/ge_common/ge_api_types.h" -#include "common/ge_common/util.h" -namespace ge { -__attribute__((unused)) void to_json(nlohmann::json &j, const HcomTopoInfo::TopoLevelDesc &desc) { - j = nlohmann::json{ - {"comm_sets", desc.comm_sets}, - {"rank_size", desc.rank_size} - }; -} - -__attribute__((unused)) void from_json(const nlohmann::json &j, HcomTopoInfo::TopoLevelDesc &desc) { - (void) j.at("comm_sets").get_to(desc.comm_sets); - (void) j.at("rank_size").get_to(desc.rank_size); -} - -__attribute__((unused)) void to_json(nlohmann::json &j, const HcomTopoInfo::TopoInfo &info) { - j = nlohmann::json{ - {"rank_size", info.rank_size}, - {"topo_level_descs", nlohmann::json::array()} - }; - - for (const auto &topo_level_desc : info.topo_level_descs) { - j["topo_level_descs"].push_back(topo_level_desc); - } -} - -__attribute__((unused)) void from_json(const nlohmann::json &j, HcomTopoInfo::TopoInfo &info) { - (void) j.at("rank_size").get_to(info.rank_size); - - const auto &arr = j.at("topo_level_descs"); - if (arr.size() != static_cast(HcomTopoInfo::TopoLevel::MAX)) { - std::ostringstream oss; - oss << "Invalid topo_level_descs array length " << arr.size() << ", should be " - << static_cast(HcomTopoInfo::TopoLevel::MAX); - throw std::out_of_range(oss.str()); - } - - for (size_t i = 0; i < static_cast(HcomTopoInfo::TopoLevel::MAX); ++i) { - (void) arr[i].get_to(info.topo_level_descs[i]); - } -} -} -namespace optiling { -using ParseAttrFunc = std::function; -using CopyConstDataFunc = std::function &)>; - -class FuncTable { -public: - FuncTable() = default; - FuncTable &Init() { - funcs_.resize(ge::DT_MAX, nullptr); - return *this; - } - - FuncTable &Insert(ge::DataType index, CopyConstDataFunc func) { - funcs_[index] = func; - return *this; - } - - CopyConstDataFunc Find(ge::DataType index) const { - return funcs_[index]; - } - -private: - std::vector funcs_; -}; - -namespace { -constexpr uint32_t kRightShiftBits = 4; -constexpr uint32_t kAndBits = 15; -const std::string kHexDigits = "0123456789ABCDEF"; -constexpr size_t kSize = 4UL; -constexpr size_t kDeterministicOffset = 3UL; -const std::string kMaxTilingSize = "op_para_size"; -constexpr size_t kMaxTilingDataSize = 16UL * 1024UL; -constexpr size_t kWorkspaceHolerSize = 8UL; -const std::string kAttrGroup = "group"; - -struct ContextComponent { - std::vector storage_shapes; - std::vector>> index_to_tensors; - ge::OpDescPtr op_desc {nullptr}; - std::unique_ptr tiling_data; - std::unique_ptr workspace_size; - bool atomic_flag = true; - int32_t tiling_cond = 0; - uint32_t schedule_mode = 0; -}; - -bool FindImplFuncsOld(const ge::char_t *op_type, const gert::OpImplKernelRegistry::OpImplFunctions *&funcs) { - funcs = gert::OpImplRegistry::GetInstance().GetOpImpl(op_type); - if (funcs == nullptr || funcs->tiling == nullptr || funcs->tiling_parse == nullptr) { - funcs = gert::OpImplRegistry::GetInstance().GetOpImpl("DefaultImpl"); - if (funcs == nullptr || funcs->tiling == nullptr || funcs->tiling_parse == nullptr) { - GELOGE(ge::GRAPH_FAILED, - "failed to find implfuncs in 1.0 way, funcs/tiling/tiling_parse is null. op type is %s.", op_type); - REPORT_INNER_ERR_MSG("E19999", "old funcs/tiling/tiling_parse is null. op type is %s.", op_type); - return false; - } - GELOGD("Finding default implfuncs in 1.0 way, op type is %s.", op_type); - return true; - } - GELOGD("Finding implfuncs in 1.0 way, op type is %s.", op_type); - return true; -} - -bool FindImplFuncs(const ge::char_t *op_type, const gert::OpImplKernelRegistry::OpImplFunctions *&funcs) { - auto registry = gert::DefaultOpImplSpaceRegistry::GetInstance().GetDefaultSpaceRegistry(); - if (registry == nullptr) { - GELOGW("Failed to find implfuncs in 2.0 way, registery is null. op type is %s.", op_type); - return FindImplFuncsOld(op_type, funcs); - } - const std::string op_type_str(op_type); - funcs = registry->GetOpImpl(op_type_str); - if (funcs == nullptr || funcs->tiling == nullptr || funcs->tiling_parse == nullptr) { - std::string default_impl_str("DefaultImpl"); - funcs = registry->GetOpImpl(default_impl_str); - if (funcs == nullptr || funcs->tiling == nullptr || funcs->tiling_parse == nullptr) { - GELOGW("failed to find implfuncs in 2.0 way, funcs/tiling/tiling_parse is null. op type is %s.", op_type); - return FindImplFuncsOld(op_type, funcs); - } - GELOGD("Finding default implfuncs in 2.0 way, op type is %s.", op_type); - return true; - } - GELOGD("Finding implfuncs using the 2.0 method, op type is %s.", op_type); - return true; -} - -template -bool ParseValueNullDesc(const nlohmann::json &value_null_desc, std::vector &data) { - GE_ASSERT_TRUE(!value_null_desc.is_null(), "value_null desc is null"); - std::string null_desc = value_null_desc.get(); - if (std::numeric_limits::has_infinity && std::numeric_limits::has_quiet_NaN) { - if (null_desc == "inf") { - data.emplace_back(std::numeric_limits::infinity()); - } else if (null_desc == "-inf") { - data.emplace_back(-std::numeric_limits::infinity()); - } else if (null_desc == "nan") { - data.emplace_back(std::numeric_limits::quiet_NaN()); - } else { - GELOGE(ge::GRAPH_PARAM_INVALID, "value_null desc: %s is not supported", null_desc.c_str()); - return false; - } - } else { - GELOGE(ge::GRAPH_PARAM_INVALID, "this type doesn't support infinity and nan"); - return false; - } - return true; -} - -bool ParseAndSetFloatAttr(ge::OpDescPtr &op_desc, const nlohmann::json &attr, const std::string &attr_name) { - const auto value = attr["value"]; - std::vector data; - const auto value_null_desc = attr.find("value_null_desc"); - if (value_null_desc == attr.end()) { - data.emplace_back(value.get()); - } else { - if (value.is_null()) { - GE_ASSERT_TRUE(ParseValueNullDesc(value_null_desc.value(), data)); - } else { - GELOGE(ge::GRAPH_PARAM_INVALID, "value_null_desc is set, but value is not null"); - return false; - } - } - op_desc->AppendIrAttrName(attr_name); - (void)op_desc->SetAttr(attr_name, ge::AnyValue::CreateFrom(data.front())); - return true; -} - -template -bool ParseAndSetAttr(ge::OpDescPtr &op_desc, const nlohmann::json &attr, const std::string &attr_name) { - const T attr_value = attr["value"].get(); - op_desc->AppendIrAttrName(attr_name); - (void)op_desc->SetAttr(attr_name, ge::AnyValue::CreateFrom(attr_value)); - return true; -} - -bool ParseAndSetFloatListAttr(ge::OpDescPtr &op_desc, const nlohmann::json &attr, const std::string &attr_name) { - const auto value = attr["value"]; - std::vector data; - const auto value_null_desc = attr.find("value_null_desc"); - if (value_null_desc == attr.end()) { - data = value.get>(); - } else { - GE_ASSERT_TRUE(value.size() == value_null_desc->size(), "value size is not equal to value_null_desc size"); - for (size_t i = 0U; i < value.size(); ++i) { - if (value.at(i).is_null()) { - GE_ASSERT_TRUE(ParseValueNullDesc(value_null_desc->at(i), data)); - } else { - data.emplace_back(value.at(i).get()); - } - } - } - op_desc->AppendIrAttrName(attr_name); - (void)op_desc->SetAttr(attr_name, ge::AnyValue::CreateFrom>(data)); - return true; -} - -template -bool ParseAndSetListAttr(ge::OpDescPtr &op_desc, const nlohmann::json &attr, const std::string &attr_name) { - const std::vector attr_value = attr["value"].get>(); - op_desc->AppendIrAttrName(attr_name); - (void)op_desc->SetAttr(attr_name, ge::AnyValue::CreateFrom>(attr_value)); - return true; -} - -bool ParseAndSetListListAttr(ge::OpDescPtr &op_desc, const nlohmann::json &attr, const std::string &attr_name) { - std::vector> attr_value_int32 = attr["value"].get>>(); - std::vector> attr_value_int64; - std::vector temp_int64_vec; - for (const auto &vec_int32 : attr_value_int32) { - for (const auto &item : vec_int32) { - int64_t tmp = static_cast(item); - temp_int64_vec.emplace_back(tmp); - } - attr_value_int64.emplace_back(temp_int64_vec); - temp_int64_vec.clear(); - } - op_desc->AppendIrAttrName(attr_name); - (void)op_desc->SetAttr(attr_name, ge::AnyValue::CreateFrom>>(attr_value_int64)); - return true; -} - -bool ParseAndSetListListInt64Attr(ge::OpDescPtr &op_desc, const nlohmann::json &attr, const std::string &attr_name) { - const std::vector> attr_value_int64 = attr["value"].get>>(); - op_desc->AppendIrAttrName(attr_name); - (void)op_desc->SetAttr(attr_name, ge::AnyValue::CreateFrom>>(attr_value_int64)); - return true; -} - -template -bool GetConstData(const nlohmann::json &json_array, const size_t total_size, - std::unique_ptr &tensor_holder) { - auto tensor = reinterpret_cast(tensor_holder.get()); - std::vector value; - const auto const_value = json_array["const_value"]; - const auto const_value_null_desc = json_array.find("const_value_null_desc"); - if (const_value_null_desc == json_array.end()) { - value = const_value.get>(); - } else { - GE_ASSERT_TRUE(const_value.size() == const_value_null_desc->size(), - "const_value size is not equal to const_value_null_desc size"); - for (size_t i = 0U; i < const_value.size(); ++i) { - if (const_value.at(i).is_null()) { - GE_ASSERT_TRUE(ParseValueNullDesc(const_value_null_desc->at(i), value)); - } else { - value.emplace_back(const_value.at(i).get()); - } - } - } - if (memcpy_s(tensor->GetData(), total_size - sizeof(gert::Tensor), value.data(), value.size() * sizeof(T)) != - EOK) { - GELOGE(ge::FAILED, "Call memcpy failed, total value size is %zu.", value.size() * sizeof(T)); - return false; - } - return true; -} - -bool GetConstDataWithFloat16(const nlohmann::json &json_array, const size_t total_size, - std::unique_ptr &tensor_holder) { - std::vector const_value = json_array["const_value"].get>(); - std::vector const_data_vec; - for (size_t i = 0UL; i < const_value.size(); ++i) { - uint16_t const_data_uint16 = Float32ToFloat16(const_value[i]); - const_data_vec.emplace_back(const_data_uint16); - } - auto tensor = reinterpret_cast(tensor_holder.get()); - if (memcpy_s(tensor->GetData(), total_size - sizeof(gert::Tensor), const_data_vec.data(), - const_data_vec.size() * sizeof(uint16_t)) != EOK) { - GELOGE(ge::FAILED, "Call memcpy failed, total value size is %zu.", const_data_vec.size() * sizeof(uint16_t)); - return false; - } - return true; -} - -bool GetConstDataWithBF16(const nlohmann::json &json_array, const size_t total_size, - std::unique_ptr &tensor_holder) { - std::vector const_value = json_array["const_value"].get>(); - std::vector const_data_vec; - for (size_t i = 0UL; i < const_value.size(); ++i) { - uint16_t const_data_uint16 = Float32ToBfloat16(const_value[i]); - const_data_vec.emplace_back(const_data_uint16); - } - auto tensor = reinterpret_cast(tensor_holder.get()); - GE_CHK_BOOL_RET_STATUS((memcpy_s(tensor->GetData(), total_size - sizeof(gert::Tensor), const_data_vec.data(), - const_data_vec.size() * sizeof(uint16_t)) == EOK), - false, "Call memcpy failed, total value size is %zu.", - const_data_vec.size() * sizeof(uint16_t)); - return true; -} - -const std::unordered_map kDtypeToAttrFunc = { - {"bool", ParseAndSetAttr}, - {"float", ParseAndSetFloatAttr}, - {"float32", ParseAndSetFloatAttr}, - {"int", ParseAndSetAttr}, - {"int32", ParseAndSetAttr}, - {"int64", ParseAndSetAttr}, - {"str", ParseAndSetAttr}, - {"list_bool", ParseAndSetListAttr}, - {"list_float", ParseAndSetFloatListAttr}, - {"list_float32", ParseAndSetFloatListAttr}, - {"list_int", ParseAndSetListAttr}, - {"list_int32", ParseAndSetListAttr}, - {"list_int64", ParseAndSetListAttr}, - {"list_str", ParseAndSetListAttr}, - {"list_list_int", ParseAndSetListListAttr}, - {"list_list_int32", ParseAndSetListListAttr}, - {"list_list_int64", ParseAndSetListListInt64Attr}}; - -const FuncTable kFuncTable = FuncTable() - .Init() - .Insert(ge::DT_INT8, GetConstData) - .Insert(ge::DT_UINT8, GetConstData) - .Insert(ge::DT_INT16, GetConstData) - .Insert(ge::DT_UINT16, GetConstData) - .Insert(ge::DT_INT32, GetConstData) - .Insert(ge::DT_UINT32, GetConstData) - .Insert(ge::DT_INT64, GetConstData) - .Insert(ge::DT_UINT64, GetConstData) - .Insert(ge::DT_FLOAT, GetConstData) - .Insert(ge::DT_DOUBLE, GetConstData) - .Insert(ge::DT_FLOAT16, GetConstDataWithFloat16) - .Insert(ge::DT_BF16, GetConstDataWithBF16) - .Insert(ge::DT_BOOL, GetConstData); - -void ParseDtype(const nlohmann::json &json, ge::GeTensorDesc &tensor_desc) { - if (json.contains("dtype")) { - std::string dtype_str = json["dtype"].get(); - (void)std::transform(dtype_str.begin(), dtype_str.end(), dtype_str.begin(), ::toupper); - dtype_str = "DT_" + dtype_str; - const ge::DataType ge_dtype = ge::TypeUtils::SerialStringToDataType(dtype_str); - tensor_desc.SetDataType(ge_dtype); - } -} - -void ParseStorageShape(const nlohmann::json &json, gert::StorageShape &storage_shape, - std::vector &storage_shapes) { - if (json.contains("shape")) { - gert::Shape shape; - const auto dims = json["shape"].get>(); - for (const int64_t &dim : dims) { - (void)shape.AppendDim(dim); - } - storage_shape.MutableStorageShape() = shape; - } - if (json.contains("ori_shape")) { - gert::Shape shape; - const auto dims = json["ori_shape"].get>(); - for (const int64_t dim : dims) { - (void)shape.AppendDim(dim); - } - storage_shape.MutableOriginShape() = shape; - } - storage_shapes.emplace_back(storage_shape); -} - -void ParseStorageFormat(const nlohmann::json &json, ge::GeTensorDesc &tensor_desc) { - if (json.contains("format")) { - std::string format_str = json["format"].get(); - (void)std::transform(format_str.begin(), format_str.end(), format_str.begin(), ::toupper); - ge::Format ge_format = ge::TypeUtils::SerialStringToFormat(format_str); - if (json.contains("sub_format")) { - int32_t sub_format = json["sub_format"].get(); - GELOGD("Sub format: %d, Primary format: %d", sub_format, static_cast(ge_format)); - ge_format = static_cast(ge::GetFormatFromSub(static_cast(ge_format), sub_format)); - } - tensor_desc.SetFormat(ge_format); - } - if (json.contains("ori_format")) { - std::string format_str = json["ori_format"].get(); - (void)std::transform(format_str.begin(), format_str.end(), format_str.begin(), ::toupper); - const ge::Format ge_format = ge::TypeUtils::SerialStringToFormat(format_str); - tensor_desc.SetOriginFormat(ge_format); - } -} - -ge::graphStatus ParseConstValue(const nlohmann::json &input, const gert::StorageShape &storage_shape, - const ge::GeTensorDesc &tensor_desc, const uint32_t index, - std::vector>> &index_to_tensor) { - if (input.contains("const_value")) { - size_t total_size = 0UL; - const size_t tensor_size = static_cast(ge::GetSizeInBytes(storage_shape.GetStorageShape().GetShapeSize(), - tensor_desc.GetDataType())); - auto tensor_holder = gert::Tensor::CreateFollowing(tensor_desc.GetDataType(), tensor_size, total_size); - GE_CHECK_NOTNULL(tensor_holder); - - if (tensor_size != 0UL) { - auto func = kFuncTable.Find(tensor_desc.GetDataType()); - GE_CHECK_NOTNULL(func); - if (!func(input, total_size, tensor_holder)) { - GELOGE(ge::GRAPH_FAILED, "Make tensor failed."); - return ge::GRAPH_FAILED; - } - } - auto tensor = reinterpret_cast(tensor_holder.get()); - tensor->MutableOriginShape() = storage_shape.GetOriginShape(); - tensor->MutableStorageShape() = storage_shape.GetStorageShape(); - tensor->SetDataType(tensor_desc.GetDataType()); - tensor->SetStorageFormat(tensor_desc.GetFormat()); - tensor->SetOriginFormat(tensor_desc.GetOriginFormat()); - index_to_tensor.emplace_back(index, std::move(tensor_holder)); - } else { - auto tensor_holder = std::unique_ptr(new (std::nothrow) uint8_t[sizeof(gert::Tensor)]); - GE_ASSERT_NOTNULL(tensor_holder); - new (tensor_holder.get()) gert::Tensor({{}, {}}, {tensor_desc.GetOriginFormat(), tensor_desc.GetFormat(), {}}, - gert::kOnHost, tensor_desc.GetDataType(), nullptr); - reinterpret_cast(tensor_holder.get())->MutableStorageShape() = storage_shape.GetStorageShape(); - reinterpret_cast(tensor_holder.get())->MutableOriginShape() = storage_shape.GetOriginShape(); - index_to_tensor.emplace_back(index, std::move(tensor_holder)); - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus ParseInput(const nlohmann::json &input, const uint32_t index, const ge::IrInputType input_type, - ContextComponent &context_com) { - ge::GeTensorDesc tensor_desc; - gert::StorageShape storage_shape; - ParseDtype(input, tensor_desc); - ParseStorageShape(input, storage_shape, context_com.storage_shapes); - ParseStorageFormat(input, tensor_desc); - const auto ret = ParseConstValue(input, storage_shape, tensor_desc, index, context_com.index_to_tensors); - if (ret != ge::GRAPH_SUCCESS) { - return ret; - } - - if (input_type == ge::kIrInputRequired) { - (void) context_com.op_desc->AddInputDesc(std::to_string(index), tensor_desc); - } else if (input_type == ge::kIrInputDynamic) { - (void) context_com.op_desc->UpdateInputDesc(index, tensor_desc); - } else { - GELOGE(ge::GRAPH_FAILED, "Unsupported ir type."); - return ge::GRAPH_FAILED; - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus ParseInputs(const char* inputs, ContextComponent& context_com) { - nlohmann::json desc_list; - try { - desc_list = nlohmann::json::parse(inputs); - } catch (const nlohmann::json::exception &e) { - GELOGE(ge::GRAPH_FAILED, "Parse json exception. %s", inputs); - return ge::GRAPH_FAILED; - } - uint32_t index = 0U; - uint32_t optional_index = 0U; - for (const auto &desc : desc_list) { - if (desc.is_array()) { - const auto input_num = desc.size(); - (void)context_com.op_desc->AddDynamicInputDesc("dynamic_" + std::to_string(index) + "_" + - std::to_string(input_num), static_cast(input_num)); - context_com.op_desc->AppendIrInput("dynamic_" + std::to_string(index) + "_" + std::to_string(input_num), - ge::kIrInputDynamic); - for (const auto &ele : desc) { - if (ele.is_null()) { - GELOGW("Empty input at current index %u", index); - continue; - } - if (ParseInput(ele, index, ge::kIrInputDynamic, context_com) != ge::GRAPH_SUCCESS) { - return ge::GRAPH_FAILED; - } - ++index; - } - } else { - if (desc.is_null()) { - context_com.op_desc->AppendIrInput("optional" + std::to_string(optional_index), ge::kIrInputOptional); - (void)context_com.op_desc->AddOptionalInputDesc( - "optional" + std::to_string(optional_index), - ge::GeTensorDesc(ge::GeShape(), ge::FORMAT_RESERVED, ge::DT_UNDEFINED)); - GELOGI("Optional input at index %u is null.", optional_index); - ++optional_index; - continue; - } - context_com.op_desc->AppendIrInput(std::to_string(index), ge::kIrInputRequired); - if (ParseInput(desc, index, ge::kIrInputRequired, context_com) != ge::GRAPH_SUCCESS) { - return ge::GRAPH_FAILED; - } - ++index; - } - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus ParseOutput(const nlohmann::json &output, ge::IrOutputType output_type, const uint32_t index, - ContextComponent &context_com) { - ge::GeTensorDesc tensor_desc; - gert::StorageShape storage_shape; - ParseDtype(output, tensor_desc); - ParseStorageShape(output, storage_shape, context_com.storage_shapes); - ParseStorageFormat(output, tensor_desc); - - if (output_type == ge::kIrOutputRequired) { - (void) context_com.op_desc->AddOutputDesc(std::to_string(index), tensor_desc); - } else if (output_type == ge::kIrOutputDynamic) { - (void) context_com.op_desc->UpdateOutputDesc(index, tensor_desc); - } else { - GELOGE(ge::GRAPH_FAILED, "Unsupported ir type."); - return ge::GRAPH_FAILED; - } - return ge::GRAPH_SUCCESS; -} - -void ParseTopoInfo(const ascend_nlohmann::json &extra_info, ge::OpDescPtr &op_desc) { - std::string group; - if (!ge::AttrUtils::GetStr(op_desc, kAttrGroup, group) || group.empty()) { - GELOGW("hcom_topo_info need bind valid %s, which is needed to set on op %s %s", kAttrGroup.c_str(), - op_desc->GetNamePtr(), op_desc->GetTypePtr()); - return; - } - if (ge::HcomTopoInfo::Instance().TopoInfoHasBeenSet(group.c_str())) { - return; - } - ge::HcomTopoInfo::TopoInfo topo_info_parsed{}; - // 兼容老的场景 - if (extra_info.contains("rank_size")) { - topo_info_parsed.rank_size = extra_info["rank_size"]; - GELOGD("Extra info contains rank size, rank size is %ld", topo_info_parsed.rank_size); - (void) ge::HcomTopoInfo::Instance().SetGroupTopoInfo(group.c_str(), topo_info_parsed); - } else if (extra_info.contains("hcom_topo_info")) { - try { - const auto &json_hcom_topo_info = extra_info.at("hcom_topo_info"); - GELOGD("Extra info contains topo info [%s]", json_hcom_topo_info.dump().c_str()); - (void) json_hcom_topo_info.get_to(topo_info_parsed); - } catch (const std::exception &e) { - GELOGE(ge::GRAPH_FAILED, "Parse error %s", e.what()); - return; - } - (void) ge::HcomTopoInfo::Instance().SetGroupTopoInfo(group.c_str(), topo_info_parsed); - } else { - return; - } - GELOGD("Set topo info for group %s successfully", group.c_str()); -} - -void ParseExtraInfo(const nlohmann::json &extra_info, ge::OpDescPtr &op_desc) { - if (extra_info.contains("op_name")) { - const std::string name = extra_info["op_name"]; - op_desc->SetName(name); - } - if (extra_info.contains("deterministic")) { - const int32_t deterministic = extra_info["deterministic"]; - (void)ge::AttrUtils::SetInt(op_desc, "deterministic", deterministic); - } - if (extra_info.contains(ge::public_attr::OP_AI_CORE_NUM)) { - const std::string op_aicore_num = extra_info[ge::public_attr::OP_AI_CORE_NUM]; - GELOGI("Set op_aicore_num from extra info: %s", op_aicore_num.c_str()); - (void)ge::AttrUtils::SetStr(op_desc, ge::public_attr::OP_AI_CORE_NUM, op_aicore_num); - } - if (extra_info.contains(ge::public_attr::OP_VECTOR_CORE_NUM)) { - const std::string op_vectorcore_num = extra_info[ge::public_attr::OP_VECTOR_CORE_NUM]; - GELOGI("Set op_vectorcore_num from extra info: %s", op_vectorcore_num.c_str()); - (void)ge::AttrUtils::SetStr(op_desc, ge::public_attr::OP_VECTOR_CORE_NUM, op_vectorcore_num); - } - ParseTopoInfo(extra_info, op_desc); -} - -ge::graphStatus ParseExtraInfos(const char *const extra_info, ge::OpDescPtr &op_desc) { - if (extra_info == nullptr) { - GELOGI("Extra info is nullptr."); - return ge::GRAPH_SUCCESS; - } - nlohmann::json desc; - try { - desc = nlohmann::json::parse(extra_info); - } catch (const nlohmann::json::exception &e) { - GELOGE(ge::GRAPH_FAILED, "Parse json exception. %s", extra_info); - return ge::GRAPH_FAILED; - } - if (desc.is_array()) { - for (const auto &ele : desc) { - ParseExtraInfo(ele, op_desc); - } - } else { - ParseExtraInfo(desc, op_desc); - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus ParseOutputs(const char *outputs, ContextComponent &context_com) { - nlohmann::json desc_list; - try { - desc_list = nlohmann::json::parse(outputs); - } catch (const nlohmann::json::exception &e) { - GELOGE(ge::GRAPH_FAILED, "Parse json exception. %s", outputs); - return ge::GRAPH_FAILED; - } - uint32_t index = 0; - for (const auto &desc : desc_list) { - if (desc.is_array()) { - const size_t output_num = desc.size(); - // 可能传过来的输入没有指定名字,所有用"dynamic_"+"index"+"_"+"num"拼一个统一的假名字,输出也是相同的处理 - (void)context_com.op_desc->AddDynamicOutputDesc("dynamic_" + std::to_string(index) + "_" + - std::to_string(output_num), static_cast(output_num)); - context_com.op_desc->AppendIrOutput("dynamic_" + std::to_string(index) + "_" + std::to_string(output_num), - ge::kIrOutputDynamic); - for (const auto &ele : desc) { - if (ele.is_null()) { - GELOGW("Empty output, cur index %u", index); - continue; - } - GE_ASSERT_GRAPH_SUCCESS(ParseOutput(ele, ge::kIrOutputDynamic, index, context_com)); - ++index; - } - } else { - if (desc.is_null()) { - GELOGW("Empty output, cur index %u", index); - continue; - } - context_com.op_desc->AppendIrOutput(std::to_string(index), ge::kIrOutputRequired); - GE_ASSERT_GRAPH_SUCCESS(ParseOutput(desc, ge::kIrOutputRequired, index, context_com)); - ++index; - } - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus ParseAttrs(const char *attrs, ge::OpDescPtr &op_desc) { - if (attrs == nullptr) { - GELOGD("Attribute has not been set."); - } else { - nlohmann::json attrs_json; - try { - attrs_json = nlohmann::json::parse(attrs); - } catch (const nlohmann::json::exception &e) { - GELOGE(ge::GRAPH_FAILED, "Parse json exception. %s", attrs); - return ge::GRAPH_FAILED; - } - for (const auto &attr : attrs_json) { - if (!attr.contains("name") || !attr.contains("dtype") || !attr.contains("value")) { - GELOGE(ge::GRAPH_FAILED, "cur attr does not contain name or dtype or value."); - return ge::GRAPH_FAILED; - } - const std::string attr_name = attr["name"].get(); - const std::string dtype = attr["dtype"].get(); - const auto iter = kDtypeToAttrFunc.find(dtype); - if (iter == kDtypeToAttrFunc.end()) { - GELOGE(ge::GRAPH_FAILED, "Unknown dtype[%s], which is unsupported.", dtype.c_str()); - return ge::GRAPH_FAILED; - } - GE_ASSERT_TRUE((iter->second)(op_desc, attr, attr_name)); - GELOGD("Finished setting attribute [name: %s] for op.", attr_name.c_str()); - } - } - return ge::GRAPH_SUCCESS; -} - -std::string DumpTilingData(gert::TilingData *tiling_data) { - std::string output; - if (tiling_data == nullptr) { - return output; - } - if (tiling_data->GetDataSize() >= std::numeric_limits::max() / kSize) { - GELOGE(ge::GRAPH_FAILED, "Tiling data size overflow."); - return output; - } - output.reserve(tiling_data->GetDataSize() * kSize); - char *data = reinterpret_cast(tiling_data->GetData()); - for (size_t i = 0UL; i < tiling_data->GetDataSize(); ++i) { - const unsigned char ch = static_cast(data[i]); - output.push_back(kHexDigits[ch >> kRightShiftBits]); - output.push_back(kHexDigits[ch & kAndBits]); - } - return output; -} - -bool DumpRunInfo(gert::KernelContext *kernel_context, char *run_info_json, const size_t run_info_len) { - GE_ASSERT_NOTNULL(run_info_json); - nlohmann::json json_obj; - auto ws = kernel_context->GetOutputPointer(gert::TilingContext::kOutputWorkspace); - GE_ASSERT_NOTNULL(ws); - std::vector workspaces(reinterpret_cast(ws->GetData()), - reinterpret_cast(ws->GetData()) + ws->GetSize()); - GE_ASSERT_NOTNULL(kernel_context->GetOutputPointer(gert::TilingContext::kOutputBlockDim)); - GE_ASSERT_NOTNULL(kernel_context->GetOutputPointer(gert::TilingContext::kOutputAtomicCleanFlag)); - GE_ASSERT_NOTNULL(kernel_context->GetOutputPointer(gert::TilingContext::kOutputTilingKey)); - GE_ASSERT_NOTNULL(kernel_context->GetOutputPointer(gert::TilingContext::kOutputTilingCond)); - GE_ASSERT_NOTNULL(kernel_context->GetOutputPointer(gert::TilingContext::kOutputScheduleMode)); - json_obj["block_dim"] = *kernel_context->GetOutputPointer(gert::TilingContext::kOutputBlockDim); - json_obj["workspaces"] = workspaces; - json_obj["tiling_data"] = - DumpTilingData(kernel_context->GetOutputPointer(gert::TilingContext::kOutputTilingData)); - json_obj["clear_atomic"] = *kernel_context->GetOutputPointer(gert::TilingContext::kOutputAtomicCleanFlag); - json_obj["tiling_key"] = *kernel_context->GetOutputPointer(gert::TilingContext::kOutputTilingKey); - json_obj["tiling_cond"] = *kernel_context->GetOutputPointer(gert::TilingContext::kOutputTilingCond); - json_obj["schedule_mode"] = *kernel_context->GetOutputPointer(gert::TilingContext::kOutputScheduleMode); - - const auto local_mem_size = kernel_context->GetOutputPointer(gert::TilingContext::kOutputLocalMemorySize); - if (local_mem_size != nullptr) { - json_obj["local_memory_size"] = *local_mem_size; - } - - const auto aicpu_block_dim = kernel_context->GetOutputPointer(gert::TilingContext::kOutputAicpuBlockDim); - GE_ASSERT_NOTNULL(aicpu_block_dim); - json_obj["aicpu_block_dim"] = *aicpu_block_dim; - - const std::string str = json_obj.dump(); - return memcpy_s(run_info_json, run_info_len, str.c_str(), str.size() + 1UL) == EOK; -} -} // namespace - -using ParseAndSetAttrValueFunc = std::function; -using ParseAndSetAttrValuePtr = std::shared_ptr; - -thread_local int64_t last_op_tiling_perf = -1; - -template -void ParseAndSetAttrValue(ge::Operator &op, const nlohmann::json &attr, const std::string &attr_name) { - const T attr_value = attr["value"].get(); - (void)op.SetAttr(attr_name.c_str(), attr_value); -} - -template -void ParseAndSetAttrListValue(ge::Operator &op, const nlohmann::json &attr, const std::string &attr_name) { - const std::vector attr_value = attr["value"].get>(); - (void)op.SetAttr(attr_name.c_str(), attr_value); -} -namespace { -thread_local std::string error_string; -constexpr int64_t ret_success = 0; -constexpr int64_t ret_fail = 1; -constexpr int64_t outter_error_type = 1; -constexpr int64_t inner_error_type = 2; -std::string GetRawErrorMessage() { - try { - nlohmann::json ret_json; - const auto &error_messages = ErrorManager::GetInstance().GetRawErrorMessages(); - if (error_messages.empty()) { - ret_json["ret_code"] = ret_success; - return ret_json.dump(); - } - ret_json["ret_code"] = ret_fail; - nlohmann::json error_messages_json = {}; - for (const auto &item : error_messages) { - nlohmann::json item_json; - item_json["errorcode"] = item.error_id; - if (item.args_map.empty()) { - item_json["type"] = inner_error_type; - item_json["errormsg"] = item.error_message; - } else { - item_json["type"] = outter_error_type; - item_json["errormsg"] = item.args_map; - } - error_messages_json.push_back(item_json); - } - ret_json["error_messages"] = error_messages_json; - return ret_json.dump(); - } catch (const nlohmann::json::exception &e) { - GELOGE(ge::GRAPH_FAILED, "get failed when call json api, reason: %s", e.what()); - return ""; - } -} -void ParseAndSetAttrListListValue(ge::Operator &op, const nlohmann::json &attr, const std::string &attr_name) { - std::vector> attr_value_int32 = attr["value"].get>>(); - std::vector> attr_value_int64; - std::vector temp_int64_vec; - for (const auto &vec_int32 : attr_value_int32) { - for (const auto &item : vec_int32) { - int64_t tmp = static_cast(item); - temp_int64_vec.emplace_back(tmp); - } - attr_value_int64.emplace_back(temp_int64_vec); - temp_int64_vec.clear(); - } - - (void)op.SetAttr(attr_name.c_str(), attr_value_int64); -} - -void ParseAndSetAttrListListInt64Value(ge::Operator &op, const nlohmann::json &attr, const std::string &attr_name) { - const std::vector> attr_value_int64 = attr["value"].get>>(); - (void)op.SetAttr(attr_name.c_str(), attr_value_int64); -} - -const std::map parse_attr_dtype_map = { - {"bool", std::make_shared(&ParseAndSetAttrValue)}, - {"float", std::make_shared(&ParseAndSetAttrValue)}, - {"float32", std::make_shared(&ParseAndSetAttrValue)}, - {"int", std::make_shared(&ParseAndSetAttrValue)}, - {"int32", std::make_shared(&ParseAndSetAttrValue)}, - {"int64", std::make_shared(&ParseAndSetAttrValue)}, - {"str", std::make_shared(&ParseAndSetAttrValue)}, - {"list_bool", std::make_shared(&ParseAndSetAttrListValue)}, - {"list_float", std::make_shared(&ParseAndSetAttrListValue)}, - {"list_float32", std::make_shared(&ParseAndSetAttrListValue)}, - {"list_int", std::make_shared(&ParseAndSetAttrListValue)}, - {"list_int32", std::make_shared(&ParseAndSetAttrListValue)}, - {"list_int64", std::make_shared(&ParseAndSetAttrListValue)}, - {"list_str", std::make_shared(&ParseAndSetAttrListValue)}, - {"list_list_int", std::make_shared(&ParseAndSetAttrListListValue)}, - {"list_list_int32", std::make_shared(&ParseAndSetAttrListListValue)}, - {"list_list_int64", std::make_shared(&ParseAndSetAttrListListInt64Value)}}; - -void ParseShapeDesc(const nlohmann::json &shape, std::vector &tensors) { - TeOpTensor tensor; - if (shape.contains("shape")) { - tensor.shape = shape["shape"].get>(); - } - if (shape.contains("ori_shape")) { - tensor.ori_shape = shape["ori_shape"].get>(); - } - if (shape.contains("format")) { - tensor.format = shape["format"].get(); - } - if (shape.contains("ori_format")) { - tensor.ori_format = shape["ori_format"].get(); - } - if (shape.contains("dtype")) { - tensor.dtype = shape["dtype"].get(); - } - tensors.emplace_back(tensor); -} - -void ParseShapeDescList(const nlohmann::json &shape_list, std::vector &op_args) { - for (const auto &elem : shape_list) { - TeOpTensorArg tensor_arg; - tensor_arg.arg_type = TensorArgType::TA_NONE; - - if (elem.is_array()) { - tensor_arg.arg_type = TensorArgType::TA_LIST; - for (const auto &shape : elem) { - ParseShapeDesc(shape, tensor_arg.tensor); - } - } else { - tensor_arg.arg_type = TensorArgType::TA_SINGLE; - ParseShapeDesc(elem, tensor_arg.tensor); - } - op_args.emplace_back(tensor_arg); - } -} - -void ParseShapeDescV2(const nlohmann::json &shape, ge::OpDescPtr &op_desc, const bool &is_input) { - ge::GeTensorDesc tensor; - std::string name; - if (shape.contains("shape")) { - tensor.SetShape(ge::GeShape(shape["shape"].get>())); - } - if (shape.contains("ori_shape")) { - tensor.SetOriginShape(ge::GeShape(shape["ori_shape"].get>())); - } - if (shape.contains("format")) { - std::string format_str = shape["format"].get(); - (void)std::transform(format_str.begin(), format_str.end(), format_str.begin(), ::toupper); - const ge::Format ge_format = ge::TypeUtils::SerialStringToFormat(format_str); - tensor.SetFormat(ge_format); - } - if (shape.contains("ori_format")) { - std::string format_str = shape["ori_format"].get(); - (void)std::transform(format_str.begin(), format_str.end(), format_str.begin(), ::toupper); - const ge::Format ge_format = ge::TypeUtils::SerialStringToFormat(format_str); - tensor.SetOriginFormat(ge_format); - } - if (shape.contains("dtype")) { - std::string dtype_str = shape["dtype"].get(); - (void)std::transform(dtype_str.begin(), dtype_str.end(), dtype_str.begin(), ::toupper); - dtype_str = "DT_" + dtype_str; - const ge::DataType ge_dtype = ge::TypeUtils::SerialStringToDataType(dtype_str); - tensor.SetDataType(ge_dtype); - } - if (shape.contains("name")) { - name = shape["name"]; - tensor.SetName(name); - is_input ? op_desc->AddInputDesc(name, tensor) : op_desc->AddOutputDesc(name, tensor); - } else { - is_input ? op_desc->AddInputDesc(tensor) : op_desc->AddOutputDesc(tensor); - } -} - -void ParseAndSetAttr(const nlohmann::json &attr, ge::Operator &op) { - if (!attr.contains("name") || !attr.contains("dtype") || !attr.contains("value")) { - REPORT_INNER_ERR_MSG("E19999", "cur attr does not contain name or dtype or value."); - return; - } - std::string attr_name; - std::string dtype; - attr_name = attr["name"].get(); - dtype = attr["dtype"].get(); - auto iter = parse_attr_dtype_map.find(dtype); - if (iter == parse_attr_dtype_map.end()) { - REPORT_INNER_ERR_MSG("E19999", "Unknown dtype[%s], which is unsupported.", dtype.c_str()); - return; - } - ParseAndSetAttrValuePtr func_ptr = iter->second; - if (func_ptr == nullptr) { - GE_LOGE("ParseAndSetAttrValueFunc ptr cannot be null!"); - return; - } - (*func_ptr)(op, attr, attr_name); - GELOGD("Finished setting attribute [name: %s] for op.", attr_name.c_str()); -} - -void ParseShapeDescListV2(const nlohmann::json &shape_list, ge::OpDescPtr &op_desc, const bool &is_input) { - for (const auto &elem : shape_list) { - if (elem.is_array()) { - for (const auto &shape : elem) { - if (shape.is_null()) { - GELOGW("Empty input."); - continue; - } - ParseShapeDescV2(shape, op_desc, is_input); - } - } else { - if (elem.is_null()) { - GELOGW("Empty input."); - continue; - } - ParseShapeDescV2(elem, op_desc, is_input); - } - } -} - -void ParseAndSetAttrsList(const nlohmann::json &attrs_list, ge::Operator &op) { - for (const auto &attr : attrs_list) { - ParseAndSetAttr(attr, op); - } -} - -template -void GetConstDataPointer(const nlohmann::json &json_array, std::vector &const_value) { - std::vector value = json_array.get>(); - uint8_t *pv_begin = reinterpret_cast(value.data()); - uint8_t *pv_end = pv_begin + (value.size() * sizeof(T)); - const_value = std::vector(pv_begin, pv_end); -} - -void CopyConstDataWithFloat16(const nlohmann::json &json_array, std::vector &value) { - std::vector const_value = json_array.get>(); - float *const_data_ptr = const_value.data(); - if (const_data_ptr == nullptr) { - GE_LOGE("Failed to get constant data pointer"); - return; - } - std::vector const_data_vec; - const size_t size = sizeof(const_value)/sizeof(float); - for (size_t i = 0; i < size; ++i) { - const float const_data = *(const_data_ptr + i); - uint16_t const_data_uint16 = optiling::Float32ToFloat16(const_data); - const_data_vec.emplace_back(const_data_uint16); - } - uint8_t *pv_begin = reinterpret_cast(const_data_vec.data()); - uint8_t *pv_end = pv_begin + (const_data_vec.size() * sizeof(uint16_t)); - value = std::vector(pv_begin, pv_end); -} - -bool CopyConstData(const std::string &dtype, const nlohmann::json &json_array, std::vector &value) { - if (dtype == "int8") { - GetConstDataPointer(json_array, value); - } else if (dtype == "uint8") { - GetConstDataPointer(json_array, value); - } else if (dtype == "int16") { - GetConstDataPointer(json_array, value); - } else if (dtype == "uint16") { - GetConstDataPointer(json_array, value); - } else if (dtype == "int32") { - GetConstDataPointer(json_array, value); - } else if (dtype == "uint32") { - GetConstDataPointer(json_array, value); - } else if (dtype == "int64") { - GetConstDataPointer(json_array, value); - } else if (dtype == "uint64") { - GetConstDataPointer(json_array, value); - } else if (dtype == "float32") { - GetConstDataPointer(json_array, value); - } else if (dtype == "double") { - GetConstDataPointer(json_array, value); - } else if (dtype == "float16") { - CopyConstDataWithFloat16(json_array, value); - } else { - GE_LOGE("Unknown dtype: %s", dtype.c_str()); - return false; - } - return true; -} - -void ParseConstShapeDesc(const nlohmann::json &shape_json, std::map &const_tensors, - std::map> &const_values) { - std::vector shape; - std::string format_str; - std::string dtype_str; - - if (!shape_json.contains("const_value")) { - GELOGI("Not constant tensor"); - return; - } - if (!shape_json.contains("name")) { - GE_LOGE("const tensor has no name"); - return; - } - std::string name = shape_json["name"]; - - if (shape_json.contains("shape")) { - shape = shape_json["shape"].get>(); - } - if (shape_json.contains("format")) { - format_str = shape_json["format"].get(); - } - if (shape_json.contains("dtype")) { - dtype_str = shape_json["dtype"].get(); - } - - std::vector value; - const bool bres = CopyConstData(dtype_str, shape_json["const_value"], value); - if (!bres) { - GE_LOGE("CopyConstData failed. Buffer is null"); - return; - } - auto res = const_values.emplace(name, std::move(value)); - if (res.first == const_values.end()) { - return; // CodeDEX complains 'CHECK_CONTAINER_EMPTY' - } - - ge::Shape ge_shape(shape); - (void)std::transform(dtype_str.begin(), dtype_str.end(), dtype_str.begin(), ::toupper); - dtype_str = "DT_" + dtype_str; - const ge::DataType ge_dtype = ge::TypeUtils::SerialStringToDataType(dtype_str); - (void)std::transform(format_str.begin(), format_str.end(), format_str.begin(), ::toupper); - const ge::Format ge_format = ge::TypeUtils::SerialStringToFormat(format_str); - ge::Tensor const_tensor(ge::TensorDesc(ge_shape, ge_format, ge_dtype), res.first->second); - (void)const_tensors.emplace(name, std::make_tuple(const_tensor.GetData(), const_tensor.GetSize(), const_tensor)); - return; -} - -void ParseConstTensorList(const nlohmann::json &shape_list, std::map &const_tensors, - std::map> &const_values) { - for (const auto &elem : shape_list) { - if (elem.is_array()) { - for (const auto &shape : elem) { - ParseConstShapeDesc(shape, const_tensors, const_values); - } - } else { - ParseConstShapeDesc(elem, const_tensors, const_values); - } - } -} - -void ParseConstShapeDescV2(const nlohmann::json &shape_json, ge::Operator &op_para, - std::map> &const_values) { - std::vector shape; - std::string format_str; - std::string dtype_str; - - if (!shape_json.contains("const_value")) { - GELOGI("Not constant tensor"); - return; - } - if (!shape_json.contains("name")) { - REPORT_INNER_ERR_MSG("E19999", "const tensor has no name"); - return; - } - std::string name = shape_json["name"]; - - if (shape_json.contains("shape")) { - shape = shape_json["shape"].get>(); - } - if (shape_json.contains("format")) { - format_str = shape_json["format"].get(); - } - if (shape_json.contains("dtype")) { - dtype_str = shape_json["dtype"].get(); - } - - std::vector value; - const bool bres = CopyConstData(dtype_str, shape_json["const_value"], value); - if (!bres) { - REPORT_INNER_ERR_MSG("E19999", "CopyConstData faild. buffer is null"); - return; - } - auto res = const_values.emplace(name, std::move(value)); - if (res.first == const_values.end()) { - return; // CodeDEX complains 'CHECK_CONTAINER_EMPTY' - } - - const ge::GeShape ge_shape(shape); - ge::DataType ge_dtype = ge::DT_UNDEFINED; - if (!dtype_str.empty()) { - (void)std::transform(dtype_str.begin(), dtype_str.end(), dtype_str.begin(), ::toupper); - dtype_str = "DT_" + dtype_str; - ge_dtype = ge::TypeUtils::SerialStringToDataType(dtype_str); - } - ge::Format ge_format = ge::FORMAT_RESERVED; - if (!format_str.empty()) { - (void)std::transform(format_str.begin(), format_str.end(), format_str.begin(), ::toupper); - ge_format = ge::TypeUtils::SerialStringToFormat(format_str); - } - ge::GeTensorDesc ge_tensor(ge_shape, ge_format, ge_dtype); - ge_tensor.SetName(name); - ge::GeTensor const_tensor(ge_tensor, res.first->second); - ge::GeTensorPtr const_tensor_ptr = std::make_shared(const_tensor); - ge::OpDescPtr const_op_desc = ge::OpDescUtils::CreateConstOp(const_tensor_ptr); - ge::Operator const_op = ge::OpDescUtils::CreateOperatorFromOpDesc(const_op_desc); - (void)op_para.SetInput(name.c_str(), const_op); - return; -} - -void ParseConstTensorListV2(const nlohmann::json &shape_list, ge::Operator &operator_para, - std::map> &const_values) { - for (const auto &elem : shape_list) { - if (elem.is_array()) { - for (const auto &shape : elem) { - ParseConstShapeDescV2(shape, operator_para, const_values); - } - } else { - ParseConstShapeDescV2(elem, operator_para, const_values); - } - } -} - -std::string DumpByteBuffer(const ByteBuffer &buf) { - static const std::string hex_digits = "0123456789ABCDEF"; - std::string str = buf.str(); - std::string output; - const uint32_t num_two = 2; - const uint32_t num_four = 4; - const uint32_t num_fifteen = 15; - output.reserve(str.size() * num_two); - for (const unsigned char c : str) { - output.push_back(hex_digits[c >> num_four]); - output.push_back(hex_digits[c & num_fifteen]); - } - return output; -} - -bool DumpRunInfo(const OpRunInfo &run_info, char *run_info_json, const size_t &run_info_len) { - if (run_info_json == nullptr) { - GE_LOGE("run_info buffer is null"); - return false; - } - - nlohmann::json json_obj; - json_obj["block_dim"] = run_info.block_dim; - json_obj["workspaces"] = run_info.workspaces; - json_obj["tiling_data"] = DumpByteBuffer(run_info.tiling_data); - json_obj["clear_atomic"] = run_info.clear_atomic; - json_obj["tiling_key"] = run_info.tiling_key; - - const std::string str = json_obj.dump(); - if (str.size() >= run_info_len) { - GE_LOGE("runinfo too large. %zu/%zu", str.size(), run_info_len); - return false; - } - return memcpy_s(run_info_json, str.size() + 1, str.c_str(), str.size() + 1) == EOK; -} - -bool DumpRunInfoV2(const OpRunInfoV2 &run_info, char *run_info_json, const size_t &run_info_len) { - if (run_info_json == nullptr) { - REPORT_INNER_ERR_MSG("E19999", "run_info buffer is null"); - return false; - } - - nlohmann::json json_obj; - std::vector workspaces; - int64_t workspace; - for (size_t i = 0; i < run_info.GetWorkspaceNum(); ++i) { - (void) run_info.GetWorkspace(i, workspace); - workspaces.push_back(workspace); - } - json_obj["block_dim"] = run_info.GetBlockDim(); - json_obj["workspaces"] = workspaces; - json_obj["tiling_data"] = DumpByteBuffer(run_info.GetAllTilingData()); - json_obj["clear_atomic"] = run_info.GetClearAtomic(); - json_obj["tiling_key"] = run_info.GetTilingKey(); - - const std::string str = json_obj.dump(); - if (str.size() >= run_info_len) { - REPORT_INNER_ERR_MSG("E19999", "runinfo too large. %zu/%zu", str.size(), run_info_len); - return false; - } - return memcpy_s(run_info_json, str.size() + 1, str.c_str(), str.size() + 1) == EOK; -} - -int TbeOpTilingPyInterfaceEx2BackUpInner(const char *const optype, const char *const compile_info, - const char *const inputs, const char *const outputs, char *run_info_json, - size_t run_info_len, const char *const compile_info_hash, uint64_t *elapse, - const OpTilingFunc &tiling_func) { - if ((optype == nullptr) || (compile_info == nullptr) || (inputs == nullptr) || (outputs == nullptr)) { - REPORT_INNER_ERR_MSG("E19999", "optype/compile_info/inputs/outputs is null, %s, %s, %s, %s", optype, compile_info, - inputs, outputs); - return 0; - } - - std::chrono::time_point before_tiling; - std::chrono::time_point after_tiling; - TeOpParas op_params; - op_params.op_type = optype; - std::map> const_values; - try { - const nlohmann::json inputs_json = nlohmann::json::parse(inputs); - const nlohmann::json outputs_json = nlohmann::json::parse(outputs); - ParseShapeDescList(inputs_json, op_params.inputs); - ParseShapeDescList(outputs_json, op_params.outputs); - ParseConstTensorList(inputs_json, op_params.const_inputs, const_values); - } catch (...) { - REPORT_INNER_ERR_MSG("E19999", "Failed to parse json_str. %s, %s, %s", compile_info, inputs, outputs); - return 0; - } - GELOGI("Optiling func found, op_type:%s", optype); - - OpCompileInfo op_compile_info{compile_info, ""}; - if (compile_info_hash != nullptr) { - op_compile_info.key = compile_info_hash; - } - - OpRunInfo run_info; - if (elapse != nullptr) { - before_tiling = std::chrono::steady_clock::now(); - } - - const bool rc = (tiling_func)(op_params, op_compile_info, run_info); - - if (elapse != nullptr) { - after_tiling = std::chrono::steady_clock::now(); - } - if (!rc) { - GELOGW("Optiling failed. op_type: %s", optype); - return 0; - } - - if (elapse != nullptr) { - *elapse = static_cast(std::chrono::duration_cast(\ - after_tiling - before_tiling).count()); - *(elapse + 1) = static_cast(last_op_tiling_perf); - last_op_tiling_perf = -1; - } - - GELOGI("Optiling succeeded. op_type: %s", optype); - (void)DumpRunInfo(run_info, run_info_json, run_info_len); - return 1; -} - -void CheckAndSetAttr(const char *attrs, ge::Operator &operator_param) { - if (attrs != nullptr) { - GELOGD("Attrs set from pyAPI is: %s", attrs); - const nlohmann::json attrs_json = nlohmann::json::parse(attrs); - ParseAndSetAttrsList(attrs_json, operator_param); - } else { - GELOGD("Attribute has not been set."); - } - return; -} - -void ParseInputsAndOutputs(const char *inputs, const char *outputs, ge::OpDescPtr &op_desc, - ge::Operator &operator_param, std::map> &const_values) { - const nlohmann::json inputs_json = nlohmann::json::parse(inputs); - const nlohmann::json outputs_json = nlohmann::json::parse(outputs); - ParseShapeDescListV2(inputs_json, op_desc, true); - ParseShapeDescListV2(outputs_json, op_desc, false); - operator_param = ge::OpDescUtils::CreateOperatorFromOpDesc(op_desc); - ParseConstTensorListV2(inputs_json, operator_param, const_values); -} - -int TbeOpTilingPyInterfaceEx2NewInner(const char *const optype, const char *const compile_info, - const char *const inputs, const char *const outputs, char *run_info_json, - size_t run_info_len, const char *const compile_info_hash, uint64_t *elapse, - const OpTilingFuncV2 &tiling_func, const char *const attrs) { - if ((optype == nullptr) || (compile_info == nullptr) || (inputs == nullptr) || (outputs == nullptr)) { - REPORT_INNER_ERR_MSG("E19999", "optype/compile_info/inputs/outputs is null, %s, %s, %s, %s", optype, compile_info, - inputs, outputs); - return 0; - } - GELOGI("Optiling func v2 found, op_type: %s", optype); - - std::chrono::time_point before_tiling; - std::chrono::time_point after_tiling; - const std::string compile_info_str = compile_info; - std::string optype_str = optype; - ge::OpDescPtr op_desc = std::make_shared("", optype_str); - std::map> const_values; - ge::Operator operator_param; - try { - ParseInputsAndOutputs(inputs, outputs, op_desc, operator_param, const_values); - CheckAndSetAttr(attrs, operator_param); - } catch (...) { - REPORT_INNER_ERR_MSG("E19999", "Failed to parse json_str. %s, %s, %s", compile_info, inputs, outputs); - return 0; - } - - OpCompileInfoV2 op_compile_info{" ", compile_info_str}; - const ge::AscendString opCompileInfoHash(compile_info_hash); - if (compile_info_hash != nullptr) { - op_compile_info.SetKey(opCompileInfoHash); - } - - OpRunInfoV2 run_info(static_cast(0), false, static_cast(0)); - if (elapse != nullptr) { - before_tiling = std::chrono::steady_clock::now(); - } - - const bool rc = (tiling_func)(operator_param, op_compile_info, run_info); - - if (elapse != nullptr) { - after_tiling = std::chrono::steady_clock::now(); - } - if (!rc) { - GELOGW("Optiling failed. op_type: %s", optype); - return 0; - } - - if (elapse != nullptr) { - *elapse = static_cast(std::chrono::duration_cast(\ - after_tiling - before_tiling).count()); - *(elapse + 1) = static_cast(last_op_tiling_perf); - last_op_tiling_perf = -1; - } - - GELOGI("Op tiling v2 succeeded. op_type: %s", optype); - (void)DumpRunInfoV2(run_info, run_info_json, run_info_len); - return 1; -} - -int TbeOpTilingPyInterfaceEx3Inner(const char *const optype, const char *const compile_info, const char *const inputs, - const char *const outputs, char *run_info_json, size_t run_info_len, - const char *const compile_info_hash, uint64_t *elapse, - const OpTilingFuncV3 &tiling_func, const OpParseFuncV3 &parse_func, - const char *const attrs) { - if ((optype == nullptr) || (compile_info == nullptr) || (inputs == nullptr) || (outputs == nullptr)) { - REPORT_INNER_ERR_MSG("E19999", "optype/compile_info/inputs/outputs is null, %s, %s, %s, %s", optype, compile_info, - inputs, outputs); - return 0; - } - GELOGI("Optiling func v3 found, op_type: %s", optype); - - std::chrono::time_point before_tiling; - std::chrono::time_point after_tiling; - std::string optype_str = optype; - ge::OpDescPtr op_desc = std::make_shared("", optype_str); - std::map> const_values; - ge::Operator operator_param; - try { - ParseInputsAndOutputs(inputs, outputs, op_desc, operator_param, const_values); - CheckAndSetAttr(attrs, operator_param); - } catch (...) { - GELOGE(ge::FAILED, "Failed to parse json_str. %s, %s, %s", compile_info, inputs, outputs); - REPORT_INNER_ERR_MSG("E19999", "Failed to parse json_str. %s, %s, %s", compile_info, inputs, outputs); - return 0; - } - if (compile_info_hash == nullptr) { - return 0; - } - - const ge::AscendString compile_info_json_str = compile_info; - void* op_compile_json_ptr = (parse_func)(operator_param, compile_info_json_str); - - OpRunInfoV2 run_info(static_cast(0), false, static_cast(0)); - if (elapse != nullptr) { - before_tiling = std::chrono::steady_clock::now(); - } - const bool rc = (tiling_func)(operator_param, op_compile_json_ptr, run_info); - - if (elapse != nullptr) { - after_tiling = std::chrono::steady_clock::now(); - } - if (!rc) { - GELOGW("Optiling failed. op_type: %s", optype); - return 0; - } - - if (elapse != nullptr) { - *elapse = static_cast(std::chrono::duration_cast\ - (after_tiling - before_tiling).count()); - *(elapse + 1) = static_cast(last_op_tiling_perf); - last_op_tiling_perf = -1; - } - - GELOGI("Op tiling v3 succeeded. op_type: %s", optype); - (void)DumpRunInfoV2(run_info, run_info_json, run_info_len); - return 1; -} - -int TbeOpTilingPyInterfaceEx4Inner(const char *const optype, const char *const compile_info, const char *const inputs, - const char *const outputs, char *run_info_json, size_t run_info_len, - const char *const compile_info_hash, uint64_t *elapse, - const OpTilingFuncV4 &tiling_func, const OpParseFuncV4 &parse_func, - const char *const attrs) { - if ((optype == nullptr) || (compile_info == nullptr) || (inputs == nullptr) || (outputs == nullptr)) { - REPORT_INNER_ERR_MSG("E19999", "optype/compile_info/inputs/outputs is null, %s, %s, %s, %s", optype, compile_info, - inputs, outputs); - return 0; - } - GELOGI("Optiling func v4 found, op_type:%s", optype); - - std::chrono::time_point before_tiling; - std::chrono::time_point after_tiling; - std::string op_type_str = optype; - ge::OpDescPtr op_desc_ptr = std::make_shared("", op_type_str); - std::map> const_values; - ge::Operator operator_param; - try { - ParseInputsAndOutputs(inputs, outputs, op_desc_ptr, operator_param, const_values); - CheckAndSetAttr(attrs, operator_param); - } catch (...) { - REPORT_INNER_ERR_MSG("E19999", "Failed to parse json during tiling v4. %s, %s, %s", compile_info, inputs, outputs); - return 0; - } - if (compile_info_hash == nullptr) { - return 0; - } - - const ge::AscendString compile_info_json = compile_info; - const CompileInfoPtr op_compile_json_ptr = (parse_func)(operator_param, compile_info_json); - - OpRunInfoV2 run_info(static_cast(0), false, static_cast(0)); - if (elapse != nullptr) { - before_tiling = std::chrono::steady_clock::now(); - } - const bool rc = (tiling_func)(operator_param, op_compile_json_ptr, run_info); - - if (elapse != nullptr) { - after_tiling = std::chrono::steady_clock::now(); - } - if (!rc) { - GELOGW("Optiling failed. op_type: %s", optype); - return 0; - } - - if (elapse != nullptr) { - *elapse = static_cast( - std::chrono::duration_cast(after_tiling - before_tiling).count()); - *(elapse + 1) = static_cast(last_op_tiling_perf); - last_op_tiling_perf = -1; - } - - GELOGI("Op tiling v4 succeed. op_type:%s", optype); - (void) DumpRunInfoV2(run_info, run_info_json, run_info_len); - return 1; -} - -ge::graphStatus UpdateCoreCountWithOpDesc(std::map& res, const std::string& key_ini, const std::string& key_op, ge::OpDesc::OpDescPtr op_desc) { - auto it = res.find(key_ini); - if (it == res.end()) { - return ge::GRAPH_SUCCESS; - } - string core_num_str = ""; - const int32_t core_num_ini = std::stoi(it->second); - - if (!ge::AttrUtils::HasAttr(op_desc, key_op)) { - GELOGI("No attr: %s exist in op_desc", key_op.c_str()); - return ge::GRAPH_SUCCESS; - } - if (ge::AttrUtils::GetStr(op_desc, key_op, core_num_str)) { - GELOGI("Attr: %s exists in op_desc, value: %s", key_op.c_str(), core_num_str.c_str()); - int32_t core_num = -1; - try { - core_num = std::stoi(core_num_str); - } catch (...) { - GELOGE(ge::GRAPH_FAILED, "[Parse][Param]Failed, digit str:%s cannot change to int", core_num_str.c_str()); - return ge::GRAPH_FAILED; - } - if (core_num > 0) { - GELOGD("Change %s from platform %ld to the op_desc %ld.", key_ini.c_str(), core_num_ini, core_num); - res[key_ini] = core_num_str; - } - } - return ge::GRAPH_SUCCESS; -} - -gert::KernelContextHolder BuildTilingParseContextHolder(ge::OpDescPtr &op_desc, const char *compile_info, - const char *op_type, fe::PlatFormInfos &platform_info, - const gert::OpImplRegistry::OpImplFunctions *funcs) { - std::vector> tiling_parse_outputs(1, std::make_pair(nullptr, nullptr)); - if (op_desc->GetType() != OP_TYPE_AUTO_TILING) { - tiling_parse_outputs[0].first = funcs->compile_info_creator(); - tiling_parse_outputs[0].second = funcs->compile_info_deleter; - } - std::string socVersionStr; - (void)platform_info.GetPlatformResWithLock("version", "Short_SoC_version", socVersionStr); - GELOGI("Short_SoC_version in platform_info: %s", socVersionStr.c_str()); - - static fe::PlatFormInfos tmp_info = platform_info; - std::map res; - (void)tmp_info.GetPlatformResWithLock("SoCInfo", res); - - GELOGI("Begin to UpdateCoreCountWithOpDesc, opName: %s, opType: %s", op_desc->GetName().c_str(), - op_desc->GetTypePtr()); - bool needUpdate = false; - if (ge::AttrUtils::HasAttr(op_desc, ge::public_attr::OP_AI_CORE_NUM) || ge::AttrUtils::HasAttr(op_desc, ge::public_attr::OP_VECTOR_CORE_NUM)) { - needUpdate = true; - if (UpdateCoreCountWithOpDesc(res, ge::public_attr::AI_CORE_CNT, ge::public_attr::OP_AI_CORE_NUM, op_desc) != ge::GRAPH_SUCCESS) { - return gert::KernelContextHolder(); - } - res[ge::public_attr::CUBE_CORE_CNT] = res[ge::public_attr::AI_CORE_CNT]; - if (UpdateCoreCountWithOpDesc(res, ge::public_attr::VECTOR_CORE_CNT, ge::public_attr::OP_VECTOR_CORE_NUM, op_desc) != ge::GRAPH_SUCCESS) { - return gert::KernelContextHolder(); - } - tmp_info.SetPlatformResWithLock(ge::public_attr::SOC_INFO, res); - } - - if (needUpdate) { - GELOGI("Need to update platform info, use tmp_info"); - return gert::KernelRunContextBuilder() - .Inputs({std::make_pair(const_cast(compile_info), nullptr), - std::make_pair(reinterpret_cast(&tmp_info), nullptr), - std::make_pair(const_cast(op_type), nullptr)}) - .Outputs(tiling_parse_outputs) - .Build(op_desc); - } else { - GELOGI("No need to update platform info"); - return gert::KernelRunContextBuilder() - .Inputs({std::make_pair(const_cast(compile_info), nullptr), - std::make_pair(reinterpret_cast(&platform_info), nullptr), - std::make_pair(const_cast(op_type), nullptr)}) - .Outputs(tiling_parse_outputs) - .Build(op_desc); - } -} - -gert::KernelContextHolder BuildTilingContext(ContextComponent &context_com, gert::KernelContext *tiling_parse_context, - fe::PlatFormInfos &platform_info) { - if (context_com.storage_shapes.size() >= std::numeric_limits::max() - kSize) { - GELOGE(ge::GRAPH_FAILED, "Context storage size overflow."); - return gert::KernelContextHolder(); - } - std::vector tiling_context_inputs(context_com.storage_shapes.size() + kSize, nullptr); - for (size_t i = 0UL; i < context_com.index_to_tensors.size(); ++i) { - tiling_context_inputs[context_com.index_to_tensors[i].first] = - reinterpret_cast(context_com.index_to_tensors[i].second.get()); - } - for (size_t i = 0UL; i < context_com.storage_shapes.size(); ++i) { - if (tiling_context_inputs[i] == nullptr) { - tiling_context_inputs[i] = &context_com.storage_shapes[i]; - } - } - if (tiling_parse_context->GetOutputPointer(0) == nullptr) { - GELOGE(ge::GRAPH_FAILED, "Output Pointer is null."); - return gert::KernelContextHolder(); - } - std::string socVersionStr; - (void)platform_info.GetPlatformResWithLock("version", "Short_SoC_version", socVersionStr); - GELOGI("Short_SoC_version in platform_info: %s", socVersionStr.c_str()); - - static fe::PlatFormInfos tmp_info = platform_info; - std::map res; - (void)tmp_info.GetPlatformResWithLock("SoCInfo", res); - - bool needUpdate = false; - if (ge::AttrUtils::HasAttr(context_com.op_desc, ge::public_attr::OP_AI_CORE_NUM) - || ge::AttrUtils::HasAttr(context_com.op_desc, ge::public_attr::OP_VECTOR_CORE_NUM)) { - needUpdate = true; - GELOGI("Begin to UpdateCoreCountWithOpDesc, opName: %s, opType: %s", context_com.op_desc->GetName().c_str(), context_com.op_desc->GetTypePtr()); - GE_ASSERT_SUCCESS(UpdateCoreCountWithOpDesc(res, ge::public_attr::AI_CORE_CNT, ge::public_attr::OP_AI_CORE_NUM, context_com.op_desc)); - res[ge::public_attr::CUBE_CORE_CNT] = res[ge::public_attr::AI_CORE_CNT]; - GE_ASSERT_SUCCESS(UpdateCoreCountWithOpDesc(res, ge::public_attr::VECTOR_CORE_CNT, ge::public_attr::OP_VECTOR_CORE_NUM, context_com.op_desc)); - - tmp_info.SetPlatformResWithLock(ge::public_attr::SOC_INFO, res); - } - - tiling_context_inputs[context_com.storage_shapes.size()] = *tiling_parse_context->GetOutputPointer(0); - if (needUpdate) { - GELOGI("Need to update platform info, use tmp_info"); - tiling_context_inputs[context_com.storage_shapes.size() + 1UL] = reinterpret_cast(&tmp_info); - } else { - GELOGI("No need to update platform info"); - tiling_context_inputs[context_com.storage_shapes.size() + 1UL] = reinterpret_cast(&platform_info); - } - - int32_t deterministic = 0; - (void)ge::AttrUtils::GetInt(context_com.op_desc, "deterministic", deterministic); - GELOGI("Get deterministic: %d from node: %s", deterministic, context_com.op_desc->GetName().c_str()); - tiling_context_inputs[context_com.storage_shapes.size() + kDeterministicOffset] = - reinterpret_cast(deterministic); - return gert::KernelRunContextBuilder() - .Inputs(tiling_context_inputs) - .Outputs( - {nullptr, nullptr, &context_com.atomic_flag, context_com.tiling_data.get(), context_com.workspace_size.get(), - &context_com.tiling_cond, &context_com.schedule_mode, nullptr, nullptr}) - .Build(context_com.op_desc); -} - -ge::graphStatus DoTilingParse(const gert::OpImplRegistry::OpImplFunctions *funcs, - gert::KernelContextHolder &tiling_parse_context_holder) { - GE_CHECK_NOTNULL(tiling_parse_context_holder.context_); - return (funcs->tiling_parse)(tiling_parse_context_holder.context_); -} - -ge::graphStatus DoTilingWithTiming(const gert::OpImplRegistry::OpImplFunctions *funcs, uint64_t *elapse, - gert::KernelContextHolder &tiling_context_holder) { - GE_CHECK_NOTNULL(tiling_context_holder.context_); - // calcu tiling cost time - std::chrono::time_point before_tiling; - std::chrono::time_point after_tiling; - if (elapse != nullptr) { - before_tiling = std::chrono::steady_clock::now(); - } - const auto ret = (funcs->tiling)(reinterpret_cast(tiling_context_holder.context_)); - if (elapse != nullptr) { - after_tiling = std::chrono::steady_clock::now(); - } - if (ret != ge::GRAPH_SUCCESS) { - return ret; - } - - if (elapse != nullptr) { - *elapse = static_cast( - std::chrono::duration_cast(after_tiling - before_tiling).count()); - *(elapse + 1) = static_cast(last_op_tiling_perf); - last_op_tiling_perf = -1; - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus ParseJson(const char *const inputs, const char *const outputs, const char *const attrs, - const char *const extra_info, ContextComponent &context_com) { - if (ParseInputs(inputs, context_com) != ge::GRAPH_SUCCESS) { - GELOGE(ge::GRAPH_FAILED, "Parse inputs failed."); - REPORT_INNER_ERR_MSG("E19999", "Parse inputs failed."); - return ge::GRAPH_FAILED; - } - if (ParseOutputs(outputs, context_com) != ge::GRAPH_SUCCESS) { - GELOGE(ge::GRAPH_FAILED, "Parse outputs failed."); - REPORT_INNER_ERR_MSG("E19999", "Parse outputs failed."); - return ge::GRAPH_FAILED; - } - if (ParseAttrs(attrs, context_com.op_desc) != ge::GRAPH_SUCCESS) { - GELOGE(ge::GRAPH_FAILED, "Parse attrs failed."); - REPORT_INNER_ERR_MSG("E19999", "Parse attrs failed."); - return ge::GRAPH_FAILED; - } - if (ParseExtraInfos(extra_info, context_com.op_desc) != ge::GRAPH_SUCCESS) { - GELOGE(ge::GRAPH_FAILED, "Parse extra info failed."); - REPORT_INNER_ERR_MSG("E19999", "Parse extra info failed."); - return ge::GRAPH_FAILED; - } - return ge::GRAPH_SUCCESS; -} - -int32_t ParseDeviceIdAndCoreType(const char *compile_info, uint32_t &device_id, std::string &core_type) { - const std::string compile_str = compile_info; - if (compile_str.empty()) { - GELOGD("compile info is empty."); - return 1; - } - nlohmann::json info_list; - try { - info_list = nlohmann::json::parse(compile_info); - } catch (const nlohmann::json::exception &e) { - GELOGE(ge::GRAPH_FAILED, "Parse json exception. %s", compile_info); - return 0; - } - GELOGD("Parsing compile info: %s.", info_list.dump().c_str()); - - if (info_list.contains("device_id")) { - if (info_list["device_id"].is_null()) { - GELOGD("device_id is null."); - } else { - device_id = std::atoi(info_list["device_id"].get().c_str()); - GELOGI("Parse device id: %u.", device_id); - } - } - if (info_list.contains(ge::ATTR_NAME_SGT_CUBE_VECTOR_CORE_TYPE)) { - if (info_list[ge::ATTR_NAME_SGT_CUBE_VECTOR_CORE_TYPE].is_null()) { - GELOGD("Attribute %s is null.", ge::ATTR_NAME_SGT_CUBE_VECTOR_CORE_TYPE.c_str()); - } else { - core_type = info_list[ge::ATTR_NAME_SGT_CUBE_VECTOR_CORE_TYPE].get(); - GELOGI("Parsing core type: %s.", core_type.c_str()); - } - } else { - if (info_list.contains(ge::ATTR_NAME_CUBE_VECTOR_CORE_TYPE)) { - if (info_list[ge::ATTR_NAME_CUBE_VECTOR_CORE_TYPE].is_null()) { - GELOGD("Attribute %s is null.", ge::ATTR_NAME_CUBE_VECTOR_CORE_TYPE.c_str()); - } else { - core_type = info_list[ge::ATTR_NAME_CUBE_VECTOR_CORE_TYPE].get(); - GELOGI("Parsing core type: %s.", core_type.c_str()); - } - } - } - - return 1; -} - -int32_t GetPlatformInfo(const char *compile_info, fe::PlatFormInfos &platform_info) { - uint32_t device_id = 0U; - std::string core_type; - if (ParseDeviceIdAndCoreType(compile_info, device_id, core_type) == 0) { - return 0; - } - - if (fe::PlatformInfoManager::Instance().InitializePlatformInfo() != 0U) { - GELOGE(ge::GRAPH_FAILED, "InitializePlatformInfo failed."); - REPORT_INNER_ERR_MSG("E19999", "InitializePlatformInfo failed."); - return 0; - } - - if (fe::PlatformInfoManager::Instance().GetPlatformInstanceByDevice(device_id, platform_info) != 0) { - GELOGE(ge::GRAPH_FAILED, "GetPlatformInstanceByDevice failed."); - REPORT_INNER_ERR_MSG("E19999", "GetPlatformInstanceByDevice failed."); - return 0; - } - platform_info.SetCoreNumByCoreType(core_type); - GELOGD("device id: %u, core type: %s, core num: %u.", device_id, core_type.c_str(), platform_info.GetCoreNum()); - - return 1; -} - -int64_t GetNewMaxTilingSize(const char *const attrs) { - if (attrs == nullptr) { - return 0; - } - nlohmann::json attr_json = nlohmann::json::parse(attrs); - for (const auto &attr : attr_json) { - if (attr.contains("name") && attr.contains("value") && - attr["name"].get() == "ascendc_op_para_size") { // new max tiling size - return attr["value"].get(); - } - } - return 0; -} - -int TbeOptilingPyInterfaceNew(const char *const op_type, const char *const compile_info, const char *const inputs, - const char *const outputs, char *run_info_json, size_t run_info_len, uint64_t *elapse, - const char *const attrs, const char *const extra_info) { - if ((compile_info == nullptr) || (inputs == nullptr) || (outputs == nullptr)) { - GELOGE(ge::GRAPH_FAILED, "compile_info/inputs/outputs is null."); - REPORT_INNER_ERR_MSG("E19999", "compile_info/inputs/outputs is null."); - return 0; - } - - const gert::OpImplKernelRegistry::OpImplFunctions *funcs; - if (!FindImplFuncs(op_type, funcs)) { - return 0; - } - ContextComponent context_com {}; - context_com.op_desc = std::make_shared("", op_type); - if ((context_com.op_desc == nullptr) || - (ParseJson(inputs, outputs, attrs, extra_info, context_com) != ge::GRAPH_SUCCESS)) { - return 0; - } - - fe::PlatFormInfos platform_info; - if (GetPlatformInfo(compile_info, platform_info) == 0) { - return 0; - } - - // tiling parse - auto tiling_parse_context_holder = BuildTilingParseContextHolder(context_com.op_desc, compile_info, op_type, - platform_info, funcs); - if (DoTilingParse(funcs, tiling_parse_context_holder) != ge::GRAPH_SUCCESS) { - GELOGE(ge::GRAPH_FAILED, "Op %s tiling parse failed", op_type); - REPORT_INNER_ERR_MSG("E19999", "Op %s tiling parse failed", op_type); - return 0; - } - - // tiling - int64_t max_size = -1; - const int64_t new_max_tiling_size = GetNewMaxTilingSize(attrs); - if (!ge::AttrUtils::GetInt(context_com.op_desc, kMaxTilingSize, max_size)) { - GELOGI("Missing maximum tiling size in opdesc."); - if (new_max_tiling_size != 0) { - max_size = new_max_tiling_size; - } - } - if (max_size == -1) { - max_size = static_cast(kMaxTilingDataSize); - } - const auto aligned_max_size = ge::RoundUp(static_cast(max_size), sizeof(uintptr_t)); - context_com.tiling_data = gert::TilingData::CreateCap(aligned_max_size); - context_com.workspace_size = gert::ContinuousVector::Create(kWorkspaceHolerSize); - gert::KernelContextHolder tiling_context_holder = - BuildTilingContext(context_com, tiling_parse_context_holder.context_, platform_info); - if (tiling_context_holder.context_ == nullptr) { - GELOGE(ge::GRAPH_FAILED, "Output build tiling context failed."); - return 0; - } - if (tiling_context_holder.GetKernelContext()->GetOutputPointer( - gert::TilingContext::kOutputTilingCond) == nullptr) { - GELOGE(ge::GRAPH_FAILED, "Output tiling cond is null."); - return 0; - } - if (tiling_context_holder.GetKernelContext()->GetOutputPointer( - gert::TilingContext::kOutputScheduleMode) == nullptr) { - GELOGE(ge::GRAPH_FAILED, "Output tiling cond is null."); - return 0; - } - - // BuildTilingContext will not initialize schedule mode, initialize it here - *tiling_context_holder.GetKernelContext()->GetOutputPointer(gert::TilingContext::kOutputTilingCond) = 0; - *tiling_context_holder.GetKernelContext()->GetOutputPointer(gert::TilingContext::kOutputScheduleMode) = 0; - if (DoTilingWithTiming(funcs, elapse, tiling_context_holder) != ge::GRAPH_SUCCESS) { - GELOGE(ge::GRAPH_FAILED, "Op %s tiling failed", op_type); - REPORT_INNER_ERR_MSG("E19999", "Op %s tiling failed", op_type); - return 0; - } - - if (!DumpRunInfo(tiling_context_holder.context_, run_info_json, run_info_len)) { - GELOGE(ge::GRAPH_FAILED, "Dump op %s tiling result failed", op_type); - REPORT_INNER_ERR_MSG("E19999", "Dump op %s tiling result failed", op_type); - return 0; - } - GELOGI("Op tiling succeed. op_type:%s", op_type); - return 1; -} - -int TbeOpTilingPyInterfaceOld(const char *const optype, const char *const compile_info, - const char *const compile_info_hash, const char *const inputs, const char *const outputs, - const char *const attrs, char *run_info_json, size_t run_info_len, uint64_t *elapse, - const char *const extra_info) { - auto &op_func_map = OpTilingFuncRegistry::RegisteredOpFuncInfo(); - auto iter = op_func_map.find(optype); - if (iter == op_func_map.end()) { - GELOGI("Op tiling function for op_type [%s] not found.", optype); - return TbeOptilingPyInterfaceNew(optype, compile_info, inputs, outputs, run_info_json, run_info_len, elapse, attrs, - extra_info); - } - OpTilingFuncInfo &op_func_info = iter->second; - int ret = 0; - if (op_func_info.IsFunctionV4()) { - const OpTilingFuncV4 &tiling_func = op_func_info.GetOpTilingFuncV4(); - const OpParseFuncV4 &parse_func = op_func_info.GetOpParseFuncV4(); - ret = TbeOpTilingPyInterfaceEx4Inner(optype, compile_info, inputs, outputs, run_info_json, run_info_len, - compile_info_hash, elapse, tiling_func, parse_func, attrs); - } else if (op_func_info.IsFunctionV3()) { - const OpTilingFuncV3 &tiling_func = op_func_info.GetOpTilingFuncV3(); - const OpParseFuncV3 &parse_func = op_func_info.GetOpParseFuncV3(); - ret = TbeOpTilingPyInterfaceEx3Inner(optype, compile_info, inputs, outputs, run_info_json, run_info_len, - compile_info_hash, elapse, tiling_func, parse_func, attrs); - } else if (op_func_info.IsFunctionV2()) { - const OpTilingFuncV2 &tiling_func = op_func_info.GetOpTilingFuncV2(); - ret = TbeOpTilingPyInterfaceEx2NewInner(optype, compile_info, inputs, outputs, run_info_json, run_info_len, - compile_info_hash, elapse, tiling_func, attrs); - } else if (op_func_info.IsFunctionV1()) { - const OpTilingFunc &tiling_func = op_func_info.GetOpTilingFunc(); - ret = TbeOpTilingPyInterfaceEx2BackUpInner(optype, compile_info, inputs, outputs, run_info_json, run_info_len, - compile_info_hash, elapse, tiling_func); - } else { - GE_LOGE("Optiling func of op type [%s] is completely empty.", optype); - } - return ret; -} - -extern "C" int OpTilingForCompile(const char *optype, const char *compile_info, const char *compile_info_hash, - const char *inputs, const char *outputs, const char *attrs, char *run_info_json, - size_t run_info_len, uint64_t *elapse, const char *extra_info) { - if (optype == nullptr) { - GELOGE(ge::GRAPH_FAILED, "op type is null."); - REPORT_INNER_ERR_MSG("E19999", "op type is null."); - return 0; - } - - if (strcmp(optype, OP_TYPE_AUTO_TILING.c_str()) == 0) { - GELOGI("The tiling function is automatically enabled for tiling on rt2."); - return TbeOptilingPyInterfaceNew(optype, compile_info, inputs, outputs, run_info_json, run_info_len, elapse, attrs, - extra_info); - } - return TbeOpTilingPyInterfaceOld(optype, compile_info, compile_info_hash, inputs, outputs, attrs, run_info_json, - run_info_len, elapse, extra_info); -} - -extern "C" const char *DoOpTilingForCompile(const char *optype, - const char *compile_info, - const char *compile_info_hash, - const char *inputs, - const char *outputs, - const char *attrs, - char *run_info_json, - size_t run_info_len, - uint64_t *elapse, - const char *extra_info) { - if (optype == nullptr) { - GELOGE(ge::GRAPH_FAILED, "op type is null."); - REPORT_INNER_ERR_MSG("E19999", "op type is null."); - error_string = GetRawErrorMessage(); - return error_string.data(); - } - - if (strcmp(optype, OP_TYPE_AUTO_TILING.c_str()) == 0) { - GELOGI("The tiling function is automatically enabled for tiling on rt2."); - if (TbeOptilingPyInterfaceNew(optype, compile_info, inputs, outputs, run_info_json, run_info_len, elapse, attrs, - extra_info) == 0) { - GELOGE(ge::GRAPH_FAILED, "TbeOptilingPyInterfaceNew failed."); - REPORT_INNER_ERR_MSG("E19999", "TbeOptilingPyInterfaceNew failed."); - } - error_string = GetRawErrorMessage(); - return error_string.data(); - } - if (TbeOpTilingPyInterfaceOld(optype, compile_info, compile_info_hash, inputs, outputs, attrs, run_info_json, - run_info_len, elapse, extra_info) == 0) { - GELOGE(ge::GRAPH_FAILED, "TbeOpTilingPyInterfaceOld failed."); - REPORT_INNER_ERR_MSG("E19999", "TbeOpTilingPyInterfaceOld failed."); - } - error_string = GetRawErrorMessage(); - return error_string.data(); -} - -extern "C" int TbeOpTilingPyInterface(const char *optype, const char *compile_info, const char *compile_info_hash, - const char *inputs, const char *outputs, const char *attrs, char *run_info_json, - size_t run_info_len, uint64_t *elapse) { - GELOGW("Deprecated api, use OpTilingForCompile instead."); - if (optype == nullptr) { - GELOGE(ge::GRAPH_FAILED, "op type is null."); - REPORT_INNER_ERR_MSG("E19999", "op type is null."); - return 0; - } - - if (strcmp(optype, OP_TYPE_AUTO_TILING.c_str()) == 0) { - GELOGI("The tiling function is automatically enabled for tiling on rt2."); - return TbeOptilingPyInterfaceNew(optype, compile_info, inputs, outputs, run_info_json, run_info_len, elapse, attrs, - nullptr); - } - - return TbeOpTilingPyInterfaceOld(optype, compile_info, compile_info_hash, inputs, outputs, attrs, run_info_json, - run_info_len, elapse, nullptr); -} - -extern "C" int TbeOpTilingPyInterfaceEx2(const char *optype, const char *compile_info, const char *inputs, - const char *outputs, char *run_info_json, size_t run_info_len, - const char *compile_info_hash, uint64_t *elapse) { - GELOGW("Deprecated api, use OpTilingForCompile instead."); - return TbeOpTilingPyInterface(optype, compile_info, compile_info_hash, inputs, outputs, nullptr, run_info_json, - run_info_len, elapse); -} - -extern "C" int TbeOpTilingPyInterfaceEx4(const char *optype, const char *compile_info, const char *inputs, - const char *outputs, char *run_info_json, size_t run_info_len, - const char *compile_info_hash, uint64_t *elapse, - const OpTilingFuncV4 &tiling_func, const OpParseFuncV4 &parse_func, - const char *attrs) { - GELOGW("Deprecated api, use OpTilingForCompile instead."); - return TbeOpTilingPyInterfaceEx4Inner(optype, compile_info, inputs, outputs, run_info_json, run_info_len, - compile_info_hash, elapse, tiling_func, parse_func, attrs); -} - -extern "C" int TbeOpTilingPyInterfaceEx3(const char *optype, const char *compile_info, const char *inputs, - const char *outputs, char *run_info_json, size_t run_info_len, - const char *compile_info_hash, uint64_t *elapse, - const OpTilingFuncV3 &tiling_func, const OpParseFuncV3 &parse_func, - const char *attrs) { - GELOGW("Deprecated api, use OpTilingForCompile instead."); - return TbeOpTilingPyInterfaceEx3Inner(optype, compile_info, inputs, outputs, run_info_json, run_info_len, - compile_info_hash, elapse, tiling_func, parse_func, attrs); -} - -extern "C" int TbeOpTilingPyInterfaceEx2New(const char *optype, const char *compile_info, const char *inputs, - const char *outputs, char *run_info_json, size_t run_info_len, - const char *compile_info_hash, uint64_t *elapse, - const OpTilingFuncV2 &tiling_func, const char *attrs) { - GELOGW("Deprecated api, use OpTilingForCompile instead."); - return TbeOpTilingPyInterfaceEx2NewInner(optype, compile_info, inputs, outputs, run_info_json, run_info_len, - compile_info_hash, elapse, tiling_func, attrs); -} - -extern "C" int TbeOpTilingPyInterfaceEx2BackUp(const char *optype, const char *compile_info, const char *inputs, - const char *outputs, char *run_info_json, size_t run_info_len, - const char *compile_info_hash, uint64_t *elapse, - const OpTilingFunc &tiling_func) { - GELOGW("Deprecated api, use OpTilingForCompile instead."); - return TbeOpTilingPyInterfaceEx2BackUpInner(optype, compile_info, inputs, outputs, run_info_json, run_info_len, - compile_info_hash, elapse, tiling_func); -} - -extern "C" Status TbeLoadSoAndSaveToRegistry(const char *so_path) { - GE_ASSERT_NOTNULL(so_path); - GELOGD("start TbeLoadSoAndSaveToRegistry, so path: %s, pid is %d", so_path, getpid()); - return gert::OpImplSpaceRegistry::LoadSoAndSaveToRegistry(so_path); -} -} -} // namespace optiling diff --git a/register/op_tiling/op_tiling_registry.cc b/register/op_tiling/op_tiling_registry.cc deleted file mode 100644 index dd2705a1233966209df715da9e196cd2a2fee5e5..0000000000000000000000000000000000000000 --- a/register/op_tiling/op_tiling_registry.cc +++ /dev/null @@ -1,164 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/op_tiling_registry.h" -#include "common/ge_common/debug/ge_log.h" - -namespace optiling { -size_t ByteBufferGetAll(ByteBuffer &buf, ge::char_t *dest, size_t dest_len) { - size_t nread = 0; - size_t rn = 0; - do { - rn = static_cast(buf.readsome(dest + nread, static_cast(dest_len - nread))); - nread += rn; - } while ((rn > 0) && (dest_len > nread)); - - return nread; -} - -ByteBuffer &ByteBufferPut(ByteBuffer &buf, const uint8_t *data, size_t data_len) { - (void)buf.write(reinterpret_cast(data), static_cast(data_len)); - (void)buf.flush(); - return buf; -} - -std::unordered_map &OpTilingRegistryInterf::RegisteredOpInterf() { - static std::unordered_map interf; - return interf; -} - -OpTilingRegistryInterf::OpTilingRegistryInterf(std::string op_type, OpTilingFunc func) { - auto &interf = RegisteredOpInterf(); - (void)interf.emplace(op_type, func); - GELOGI("Register tiling function: op_type:%s, funcPointer:%p, registered count:%zu", op_type.c_str(), - func.target(), interf.size()); -} - -std::unordered_map &OpTilingRegistryInterf_V2::RegisteredOpInterf() { - static std::unordered_map interf; - GELOGI("Generated interface by new method, registered count: %zu", interf.size()); - return interf; -} - -OpTilingRegistryInterf_V2::OpTilingRegistryInterf_V2(const std::string &op_type, OpTilingFuncV2 func) { - auto &interf = RegisteredOpInterf(); - (void)interf.emplace(op_type, std::move(func)); - GELOGI("Registering tiling function with new method: op_type=%s, registered_count=%zu", op_type.c_str(), interf.size()); -} - -OpTilingFuncInfo::OpTilingFuncInfo(const std::string &op_type) - : op_type_(op_type), - tiling_func_(nullptr), - tiling_func_v2_(nullptr), - tiling_func_v3_(nullptr), - parse_func_v3_(nullptr) {} - -bool OpTilingFuncInfo::IsFunctionV4() { - return this->tiling_func_v4_ != nullptr && this->parse_func_v4_ != nullptr; -} -bool OpTilingFuncInfo::IsFunctionV3() { - return this->tiling_func_v3_ != nullptr && this->parse_func_v3_ != nullptr; -} -bool OpTilingFuncInfo::IsFunctionV2() { - return this->tiling_func_v2_ != nullptr; -} -bool OpTilingFuncInfo::IsFunctionV1() { - return this->tiling_func_ != nullptr; -} -void OpTilingFuncInfo::SetOpTilingFunc(OpTilingFunc &tiling_func) { - this->tiling_func_ = tiling_func; -} -void OpTilingFuncInfo::SetOpTilingFuncV2(OpTilingFuncV2 &tiling_func) { - this->tiling_func_v2_ = tiling_func; -} -void OpTilingFuncInfo::SetOpTilingFuncV3(OpTilingFuncV3 &tiling_func, OpParseFuncV3 &parse_func) { - this->tiling_func_v3_ = tiling_func; - this->parse_func_v3_ = parse_func; -} -void OpTilingFuncInfo::SetOpTilingFuncV4(OpTilingFuncV4 &tiling_func, OpParseFuncV4 &parse_func) { - this->tiling_func_v4_ = tiling_func; - this->parse_func_v4_ = parse_func; -} -const OpTilingFunc& OpTilingFuncInfo::GetOpTilingFunc() { - return this->tiling_func_; -} -const OpTilingFuncV2& OpTilingFuncInfo::GetOpTilingFuncV2() { - return this->tiling_func_v2_; -} -const OpTilingFuncV3& OpTilingFuncInfo::GetOpTilingFuncV3() { - return this->tiling_func_v3_; -} -const OpParseFuncV3& OpTilingFuncInfo::GetOpParseFuncV3() { - return this->parse_func_v3_; -} -const OpTilingFuncV4& OpTilingFuncInfo::GetOpTilingFuncV4() { - return this->tiling_func_v4_; -} -const OpParseFuncV4& OpTilingFuncInfo::GetOpParseFuncV4() { - return this->parse_func_v4_; -} - -std::unordered_map &OpTilingFuncRegistry::RegisteredOpFuncInfo() { - static std::unordered_map op_func_map; - return op_func_map; -} - -OpTilingFuncRegistry::OpTilingFuncRegistry(const std::string &op_type, OpTilingFunc tiling_func) { - auto &op_func_map = RegisteredOpFuncInfo(); - const auto iter = op_func_map.find(op_type); - if (iter == op_func_map.end()) { - OpTilingFuncInfo op_func_info(op_type); - op_func_info.SetOpTilingFunc(tiling_func); - (void)op_func_map.emplace(op_type, op_func_info); - } else { - iter->second.SetOpTilingFunc(tiling_func); - } - GELOGI("Register op tiling function V1 for op_type:%s", op_type.c_str()); -} -OpTilingFuncRegistry::OpTilingFuncRegistry(const std::string &op_type, OpTilingFuncV2 tiling_func) { - auto &op_func_map = RegisteredOpFuncInfo(); - const auto iter = op_func_map.find(op_type); - if (iter == op_func_map.end()) { - OpTilingFuncInfo op_func_info(op_type); - op_func_info.SetOpTilingFuncV2(tiling_func); - (void)op_func_map.emplace(op_type, op_func_info); - } else { - iter->second.SetOpTilingFuncV2(tiling_func); - } - GELOGI("Register op tiling function V2 for op_type:%s", op_type.c_str()); -} - -OpTilingFuncRegistry::OpTilingFuncRegistry(const std::string &op_type, - OpTilingFuncV3 tiling_func, OpParseFuncV3 parse_func) { - auto &op_func_map = RegisteredOpFuncInfo(); - const auto iter = op_func_map.find(op_type); - if (iter == op_func_map.end()) { - OpTilingFuncInfo op_func_info(op_type); - op_func_info.SetOpTilingFuncV3(tiling_func, parse_func); - (void)op_func_map.emplace(op_type, op_func_info); - } else { - iter->second.SetOpTilingFuncV3(tiling_func, parse_func); - } - GELOGI("Register op tiling and parse function V3 for op_type:%s", op_type.c_str()); -} - -OpTilingFuncRegistry::OpTilingFuncRegistry(const std::string &op_type, - OpTilingFuncV4 tiling_func, OpParseFuncV4 parse_func) { - auto &op_func_map = RegisteredOpFuncInfo(); - const auto iter = op_func_map.find(op_type); - if (iter == op_func_map.end()) { - OpTilingFuncInfo op_func_info(op_type); - op_func_info.SetOpTilingFuncV4(tiling_func, parse_func); - (void)op_func_map.emplace(op_type, op_func_info); - } else { - iter->second.SetOpTilingFuncV4(tiling_func, parse_func); - } - GELOGI("Registering tiling and parsing function V4 for op_type: %s", op_type.c_str()); -} -} // namespace optiling diff --git a/register/op_tiling/op_tiling_utils.cc b/register/op_tiling/op_tiling_utils.cc deleted file mode 100644 index 767581679135abe55beb806e5301c78aea1ce773..0000000000000000000000000000000000000000 --- a/register/op_tiling/op_tiling_utils.cc +++ /dev/null @@ -1,57 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "op_tiling/op_tiling_utils.h" -#include -#include "graph/utils/attr_utils.h" - -namespace optiling { -void ReplaceEmptyShapeOfTensorDesc(const ge::OpDescPtr &op_desc, std::vector &indexes) { - const size_t input_size = op_desc->GetAllInputsSize(); - for (size_t i = 0; i < input_size; ++i) { - const ge::GeTensorDescPtr tensor_desc_ptr = op_desc->MutableInputDesc(static_cast(i)); - if (tensor_desc_ptr == nullptr) { - continue; - } - if (tensor_desc_ptr->MutableShape().IsScalar()) { - indexes.push_back(static_cast(i)); - tensor_desc_ptr->MutableShape().SetDimNum(1); - (void)tensor_desc_ptr->MutableShape().SetDim(0, 1); - } - } - - const size_t output_size = op_desc->GetOutputsSize(); - for (size_t i = 0; i < output_size; ++i) { - const ge::GeTensorDescPtr tensor_desc_ptr = op_desc->MutableOutputDesc(static_cast(i)); - if (tensor_desc_ptr == nullptr) { - continue; - } - if (tensor_desc_ptr->MutableShape().IsScalar()) { - indexes.push_back(static_cast(-1 - i)); - tensor_desc_ptr->MutableShape().SetDimNum(1); - (void)tensor_desc_ptr->MutableShape().SetDim(0, 1); - } - } -} - -void RecoveryEmptyShapeOfTensorDesc(const ge::OpDescPtr &op_desc, const std::vector &indexes) { - for (const int32_t &index : indexes) { - ge::GeTensorDescPtr tensor_desc_ptr; - if (index >= 0) { - tensor_desc_ptr = op_desc->MutableInputDesc(static_cast(index)); - } else { - tensor_desc_ptr = op_desc->MutableOutputDesc(static_cast(std::abs(index) - 1)); - } - if (tensor_desc_ptr == nullptr) { - continue; - } - tensor_desc_ptr->MutableShape().SetDimNum(0); - } -} -} // namespace optiling diff --git a/register/op_tiling/op_tiling_utils.h b/register/op_tiling/op_tiling_utils.h deleted file mode 100644 index 3383643493f664a0b17d3f3d4a1858ac90413256..0000000000000000000000000000000000000000 --- a/register/op_tiling/op_tiling_utils.h +++ /dev/null @@ -1,33 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef REGISTER_OP_TILING_OP_TILING_UTILS_H_ -#define REGISTER_OP_TILING_OP_TILING_UTILS_H_ - -#include -#include -#include "graph/op_desc.h" -#include "graph/debug/ge_log.h" - -namespace optiling { -void ReplaceEmptyShapeOfTensorDesc(const ge::OpDescPtr &op_desc, std::vector &indexes); -void RecoveryEmptyShapeOfTensorDesc(const ge::OpDescPtr &op_desc, const std::vector &indexes); - -#define OP_TILING_MAKE_SHARED(exec_expr0, exec_expr1) \ - do { \ - try { \ - exec_expr0; \ - } catch (...) { \ - GE_LOGE("Make shared failed"); \ - exec_expr1; \ - } \ - } while (0) - -} // namespace optiling -#endif // REGISTER_OP_TILING_OP_TILING_UTILS_H_ diff --git a/register/opdef/op_config_registry.cc b/register/opdef/op_config_registry.cc deleted file mode 100644 index d7b58beb97a6b4a24faddf24570ed1e051ccefdd..0000000000000000000000000000000000000000 --- a/register/opdef/op_config_registry.cc +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "op_config_registry_impl.h" -#include "common/ge_common/debug/ge_log.h" -#include "op_def_impl.h" - -namespace ops { - -OpConfigRegistry::OpConfigRegistry() {} - -void OpConfigRegistry::RegisterOpAICoreConfig(const char* name, const char* socVersion, OpAICoreConfigFunc func = nullptr) { - if (name == nullptr) { - GELOGE(ge::PARAM_INVALID, "Register Op name is null."); - return; - } - - if (socVersion == nullptr) { - GELOGE(ge::PARAM_INVALID, "Register socVersion for op[%s] is null.", name); - return; - } - - GELOGD("Add aicore config for op[%s] at socVersion[%s] in OpConfigRegistry.", name, socVersion); - OpConfigRegistryImpl::GetInstance().AddAICoreConfig(name, socVersion, func); -} - -OpConfigRegistryImpl& OpConfigRegistryImpl::GetInstance() { - static OpConfigRegistryImpl instance; - return instance; -} - -void OpConfigRegistryImpl::AddAICoreConfig(const char* name, const char* socVersion, OpAICoreConfigFunc func) { - if (name == nullptr) { - GELOGE(ge::PARAM_INVALID, "AddAICoreConfig Op name is null."); - return; - } - - if (socVersion == nullptr) { - GELOGE(ge::PARAM_INVALID, "AddAICoreConfig socVersion for op is null."); - return; - } - - GELOGD("Add aicore config for op[%s] at socVersion[%s] in OpConfigRegistryImpl.", name, socVersion); - funcData_[ge::AscendString(name)][ge::AscendString(socVersion)] = func; -} - -std::map OpConfigRegistryImpl::GetOpAllAICoreConfig(const char* name) { - if (name == nullptr) { - GELOGE(ge::PARAM_INVALID, "GetOpAllAICoreConfig Op name is null."); - } - - GELOGD("Aicore config size: %zu", funcData_.size()); - - auto iter = funcData_.find(ge::AscendString(name)); - if (iter == funcData_.end()) { - GELOGD("Can not find aicore config for op[%s].", name); - return std::map(); - } - GELOGD("Found aicore config for op[%s].", name); - return iter->second; -} - -std::map GetOpAllAICoreConfig(const char* name) { - if (name == nullptr) { - GELOGE(ge::PARAM_INVALID, "GetOpAllAICoreConfig Op name is null."); - return {}; - } - return OpConfigRegistryImpl::GetInstance().GetOpAllAICoreConfig(name); -} - -} \ No newline at end of file diff --git a/register/opdef/op_config_registry_impl.h b/register/opdef/op_config_registry_impl.h deleted file mode 100644 index 8128cd17b073031e153514c083185e319128ce32..0000000000000000000000000000000000000000 --- a/register/opdef/op_config_registry_impl.h +++ /dev/null @@ -1,28 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef INC_REGISTER_OP_CONFIG_REGISTRY_IMPL_H_ -#define INC_REGISTER_OP_CONFIG_REGISTRY_IMPL_H_ - -#include -#include "external/register/op_config_registry.h" - -namespace ops { -class OpConfigRegistryImpl { -public: - static OpConfigRegistryImpl &GetInstance(); - void AddAICoreConfig(const char* name, const char* socVersion, OpAICoreConfigFunc func); - std::map GetOpAllAICoreConfig(const char* name); - -private: - std::map> funcData_; -}; -} - -#endif // INC_REGISTER_OP_CONFIG_REGISTRY_IMPL_H_ \ No newline at end of file diff --git a/register/opdef/op_def.cc b/register/opdef/op_def.cc deleted file mode 100644 index 7faa5e323217b8befe47d9a9b8472bfd255070f2..0000000000000000000000000000000000000000 --- a/register/opdef/op_def.cc +++ /dev/null @@ -1,549 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "op_def_impl.h" -#include "common/ge_common/debug/ge_log.h" -#include "register/op_def.h" -#include "register/op_config_registry.h" - -namespace ops { -OpDef::OpDef(const char *type) : impl_(new(std::nothrow) OpDefImpl) { - this->impl_->op_type = type; - auto regConfigs = GetOpAllAICoreConfig(type); - GELOGD("Aicore op[%s] configs size: %zu", type, regConfigs.size()); - for (auto it = regConfigs.cbegin(); it!= regConfigs.cend(); ++it) { - GELOGD("Found aicore op[%s] at socVersion[%s] registerd by REGISTER_OP_AICORE_CONFIG.", - type, it->first.GetString()); - if (it->second == nullptr) { - GELOGE(ge::PARAM_INVALID, "Aicore func of op[%s] at socVersion[%s] registerd by REGISTER_OP_AICORE_CONFIG is nullptr."); - return; - } - auto config = it->second(); - this->impl_->op_aicore.AddConfig(it->first.GetString(), config); - } -} - -OpDef::OpDef(const OpDef &op_def) : impl_(new(std::nothrow) OpDefImpl) { - this->impl_->op_type = op_def.impl_->op_type; - this->impl_->op_params = op_def.impl_->op_params; - this->impl_->attrs = op_def.impl_->attrs; - this->impl_->op_aicore = op_def.impl_->op_aicore; - this->impl_->has_workspace = op_def.impl_->has_workspace; - this->impl_->infer_shape = op_def.impl_->infer_shape; - this->impl_->infer_shape_range = op_def.impl_->infer_shape_range; - this->impl_->infer_data_type = op_def.impl_->infer_data_type; - this->impl_->op_mc2 = op_def.impl_->op_mc2; - this->impl_->non_list_len = op_def.impl_->non_list_len; - this->impl_->category = op_def.impl_->category; - this->impl_->comment_map = op_def.impl_->comment_map; - this->impl_->format_mode = op_def.impl_->format_mode; - this->impl_->enable_fall_back = op_def.impl_->enable_fall_back; -} - -OpDef::~OpDef() = default; - -OpDef &OpDef::operator=(const OpDef &op_def) { - if (this != &op_def) { - *this->impl_ = *op_def.impl_; - } - return *this; -} - -OpParamDef &OpDef::Input(const char *name) { - return this->impl_->op_params.Input(name); -} - -OpParamDef &OpDef::Output(const char *name) { - return this->impl_->op_params.Output(name); -} - -OpAttrDef &OpDef::Attr(const char *name) { - return this->GetOrCreateAttr(name); -} -OpDef &OpDef::Comment(CommentSection section, const char *comment) { - if (section >= CommentSection::SECTION_MAX) { - GELOGE(ge::PARAM_INVALID, "Ops %s : Comment Section is Invalid", this->GetOpType().GetString()); - return *this; - } - if (comment == nullptr || strlen(comment) == 0) { - GELOGE(ge::PARAM_INVALID, "Ops %s : Comment content cannot be empty", this->GetOpType().GetString()); - return *this; - } - if (section == CommentSection::CATEGORY) { - if (strchr(comment, ' ') != nullptr) { - GELOGE(ge::PARAM_INVALID, "Ops %s : category names cannot be split by spaces", this->GetOpType().GetString()); - return *this; - } - this->impl_->category = comment; - return *this; - } - this->impl_->comment_map[section].emplace_back(comment); - return *this; -} -ItemFindStatus OpDef::FindAttr(const char *name, OpAttrDef **attr) { - std::vector *attrList = &this->impl_->attrs; - for (auto it = attrList->begin(); it != attrList->end(); it++) { - if (ge::AscendString(it->GetName()) == ge::AscendString(name)) { - *attr = &(*it); - return ItemFindStatus::ITEM_FIND; - } - } - return ItemFindStatus::ITEM_NOEXIST; -} - -OpAttrDef &OpDef::AddAttr(OpAttrDef &attr) { - this->impl_->attrs.emplace_back(attr); - return this->impl_->attrs.back(); -} - -OpAttrDef &OpDef::GetOrCreateAttr(const char *name) { - OpAttrDef *pAttr; - if (this->FindAttr(name, &pAttr) == ItemFindStatus::ITEM_FIND) { - return *pAttr; - } else { - OpAttrDef attr(name); - return this->AddAttr(attr); - } -} - -std::vector &OpDef::GetAttrs(void) { - return this->impl_->attrs; -} - -OpDef &OpDef::SetInferShape(gert::OpImplRegisterV2::InferShapeKernelFunc func) { - this->impl_->infer_shape = func; - return *this; -} - -OpDef &OpDef::SetInferShapeRange(gert::OpImplRegisterV2::InferShapeRangeKernelFunc func) { - this->impl_->infer_shape_range = func; - return *this; -} - -OpDef &OpDef::SetInferDataType(gert::OpImplRegisterV2::InferDataTypeKernelFunc func) { - this->impl_->infer_data_type = func; - return *this; -} - -gert::OpImplRegisterV2::InferShapeKernelFunc &OpDef::GetInferShape(void) { - return this->impl_->infer_shape; -} -gert::OpImplRegisterV2::InferShapeRangeKernelFunc &OpDef::GetInferShapeRange(void) { - return this->impl_->infer_shape_range; -} -gert::OpImplRegisterV2::InferDataTypeKernelFunc &OpDef::GetInferDataType(void) { - return this->impl_->infer_data_type; -} -ge::AscendString &OpDef::GetOpType(void) { - return this->impl_->op_type; -} -ge::AscendString &OpDef::GetCateGory(void) const { - return this->impl_->category; -} -std::vector &OpDef::GetBrief(void) const { - return this->impl_->comment_map[ops::CommentSection::BRIEF]; -} -std::vector &OpDef::GetConstraints(void) const { - return this->impl_->comment_map[ops::CommentSection::CONSTRAINTS]; -} -std::vector &OpDef::GetRestrictions(void) const { - return this->impl_->comment_map[ops::CommentSection::RESTRICTIONS]; -} -std::vector &OpDef::GetSee(void) const { - return this->impl_->comment_map[ops::CommentSection::SEE]; -} -std::vector &OpDef::GetThirdPartyFwkCopat(void) const { - return this->impl_->comment_map[ops::CommentSection::THIRDPARTYFWKCOMPAT]; -} -std::vector &OpDef::GetInputs(void) { - return this->impl_->op_params.GetInputs(); -} - -std::vector &OpDef::GetOutputs(void) { - return this->impl_->op_params.GetOutputs(); -} - -void OpDef::MergeParam(std::vector &merge, std::vector &aicore_params) const { - for (auto &aicoreParam : aicore_params) { - bool find = false; - for (auto &mergeParam : merge) { - if (mergeParam == aicoreParam) { - mergeParam.MergeParam(aicoreParam); - find = true; - break; - } - } - if (!find) { - merge.emplace_back(aicoreParam); - } - } -} - -void OpDef::DfsDataType(DfsParam &dfs_param, const std::vector &all_param, - uint32_t list_idx, uint32_t non_list_idx) const { - constexpr uint32_t two = 2; - const OpParamDef &def = all_param[list_idx / two]; - if (def.IsScalarOrScalarList() && (def.IsScalarTypeSet() || def.IsScalarNameSet())) { - dfs_param.types.push_back(OpDef::ArrParam(static_cast(def.GetScalarType()), false)); - DfsFullPermutation(dfs_param, all_param, list_idx + 1, non_list_idx); - dfs_param.types.pop_back(); - } else if (def.IsDtypeList()) { - for (uint32_t i = 0; i < def.impl_->types_list.size(); ++i) { - dfs_param.types.push_back(OpDef::ArrParam(i, true)); - DfsFullPermutation(dfs_param, all_param, list_idx + 1, non_list_idx); - dfs_param.types.pop_back(); - } - } else { - dfs_param.types.push_back(OpDef::ArrParam(non_list_idx, true)); - DfsFullPermutation(dfs_param, all_param, list_idx + 1, non_list_idx); - dfs_param.types.pop_back(); - } -} - -void OpDef::DfsFormat(DfsParam &dfs_param, const std::vector &all_param, - uint32_t list_idx, uint32_t non_list_idx) const { - constexpr uint32_t two = 2; - const OpParamDef &def = all_param[list_idx / two]; - if ((def.IsScalarOrScalarList() || def.IsValueDepend())) { - dfs_param.formats.push_back(OpDef::ArrParam(static_cast(ge::FORMAT_ND), false)); - DfsFullPermutation(dfs_param, all_param, list_idx + 1, non_list_idx); - dfs_param.formats.pop_back(); - } else if (def.IsFormatList()) { - for (uint32_t i = 0; i < def.impl_->formats_list.size(); ++i) { - dfs_param.formats.push_back(OpDef::ArrParam(i, true)); - DfsFullPermutation(dfs_param, all_param, list_idx + 1, non_list_idx); - dfs_param.formats.pop_back(); - } - } else { - dfs_param.formats.push_back(OpDef::ArrParam(non_list_idx, true)); - DfsFullPermutation(dfs_param, all_param, list_idx + 1, non_list_idx); - dfs_param.formats.pop_back(); - } -} - -void OpDef::DfsFullPermutation(DfsParam &dfs_param, const std::vector &all_param, - uint32_t list_idx, uint32_t non_list_idx) const { - constexpr uint32_t two = 2; - if (list_idx == all_param.size() * two) { - dfs_param.full_types.push_back(dfs_param.types); - dfs_param.full_formats.push_back(dfs_param.formats); - return; - } - // process types while list_idx is even; process formats while list_idx is odd - if (list_idx % two == 0) { - DfsDataType(dfs_param, all_param, list_idx, non_list_idx); - } else { - DfsFormat(dfs_param, all_param, list_idx, non_list_idx); - } -} - -bool OpDef::IsNonListTypes(const OpParamDef &def) const { - return (!def.IsScalarOrScalarList() && def.IsDtype()) || - (def.IsScalarOrScalarList() && (!def.IsScalarTypeSet() && !def.IsScalarNameSet()) && def.IsDtype()); -} - -bool OpDef::IsNonListFormats(const OpParamDef &def) const { - return (!def.IsScalarOrScalarList() && !def.IsValueDepend() && def.IsFormat()); -} - -uint32_t OpDef::GetNonListLen(std::vector &input_param, std::vector &output_param) const { - std::unordered_set non_list_lens; - auto set_non_list_len = [this, &non_list_lens](const std::vector ¶ms) { - for (auto &def : params) { - if (this->IsNonListTypes(def)) { - non_list_lens.insert(def.impl_->types.size()); - } - if (this->IsNonListFormats(def)) { - non_list_lens.insert(def.impl_->formats.size()); - } - } - }; - set_non_list_len(input_param); - set_non_list_len(output_param); - - if (non_list_lens.empty()) { - return 1; - } - if (non_list_lens.size() > 1) { - GELOGE(ge::PARAM_INVALID, "Element num of DataType and Format is not aligned."); - return 0; - } - if (*non_list_lens.begin() == 0) { - GELOGE(ge::PARAM_INVALID, "DataType or Format cannot be empty."); - return 0; - } - return *non_list_lens.begin(); -} - -void OpDef::UpdateDtypeImpl(const DfsParam &dfs_param, OpParamDef ¶m, const uint32_t ¶m_idx) { - uint32_t param_type = dfs_param.full_types[0][param_idx].first; - bool have_scalar_param = !(dfs_param.full_types[0][param_idx].second); - if (have_scalar_param && static_cast(param_type) != ge::DT_UNDEFINED) { - if (param.IsSetDtypeForBin()) { - GELOGW("DataTypeForBinQuery is incompatible with To Type."); - param.impl_->set_type_for_bin = false; - } - param.impl_->types = std::vector(dfs_param.full_types.size(), static_cast(param_type)); - return; - } - if (have_scalar_param && static_cast(param_type) == ge::DT_UNDEFINED) { - return; - } - uint32_t num = 0; - bool is_idx = false; - std::vector data_types; - std::vector data_types_for_bin; - bool is_follow_list = - (param.impl_->follow_type == FollowType::ALL || param.impl_->follow_type == FollowType::DTYPE) && - param.IsDtypeList(); - for (uint32_t type_idx = 0; type_idx < dfs_param.full_types.size(); ++type_idx) { - std::tie(num, is_idx) = dfs_param.full_types[type_idx][param_idx]; - if (param.IsSetDtypeForBin() && is_idx && !is_follow_list) { - data_types_for_bin.emplace_back(param.impl_->types_for_bin[num]); - } - if (param.IsDtype()) { - data_types.emplace_back(param.impl_->types[num]); - } - if (param.IsDtypeList()) { - data_types.emplace_back(param.impl_->types_list[num]); - } - } - if (!data_types_for_bin.empty()) { - param.impl_->types_for_bin = data_types_for_bin; - } - param.impl_->types = data_types; -} - -void OpDef::UpdateFormatImpl(const DfsParam &dfs_param, OpParamDef ¶m, const uint32_t ¶m_idx) { - uint32_t param_format = dfs_param.full_formats[0][param_idx].first; - bool have_scalar_param = !(dfs_param.full_formats[0][param_idx].second); - if (have_scalar_param) { - if (param.IsSetFormatForBin()) { - GELOGW("FormatForBinQuery is incompatible with Scalar/ScalarList or ValueDepend."); - param.impl_->set_format_for_bin = false; - } - param.impl_->formats = - std::vector(dfs_param.full_formats.size(), static_cast(param_format)); - return; - } - uint32_t num = 0; - bool is_idx = false; - std::vector data_formats; - std::vector data_formats_for_bin; - bool is_follow_list = - (param.impl_->follow_type == FollowType::ALL || param.impl_->follow_type == FollowType::FORMAT) && - param.IsFormatList(); - for (uint32_t type_idx = 0; type_idx < dfs_param.full_formats.size(); ++type_idx) { - std::tie(num, is_idx) = dfs_param.full_formats[type_idx][param_idx]; - if (param.IsSetFormatForBin() && is_idx && !is_follow_list) { - data_formats_for_bin.emplace_back(param.impl_->formats_for_bin[num]); - } - if (param.IsFormat()) { - data_formats.emplace_back(param.impl_->formats[num]); - } - if (param.IsFormatList()) { - data_formats.emplace_back(param.impl_->formats_list[num]); - } - } - if (!data_formats_for_bin.empty()) { - param.impl_->formats_for_bin = data_formats_for_bin; - } - param.impl_->formats = data_formats; -} - -void OpDef::UpdateInput(const DfsParam &dfs_param, std::vector &input) { - std::vector> to_list; - for (uint32_t param_idx = 0; param_idx < input.size(); ++param_idx) { - if (input[param_idx].IsScalarNameSet()) { - to_list.emplace_back(param_idx, input[param_idx].GetScalarName()); - } - this->UpdateDtypeImpl(dfs_param, input[param_idx], param_idx); - this->UpdateFormatImpl(dfs_param, input[param_idx], param_idx); - } - auto follow_map = this->GetFollowMap(); - uint32_t input_idx = 0; - ge::AscendString to_name = ""; - for (const auto &to : to_list) { - std::tie(input_idx, to_name) = to; - if (follow_map.find(to_name) == follow_map.end()) { - GELOGE(ge::PARAM_INVALID, "Param %s : Cannot find param to be set To.", - input[input_idx].GetParamName().GetString()); - continue; - } - const PortFollowInfo &to_param = follow_map.at(to_name); - if (to_param.port_stat == OpDef::PortStat::OUT) { - GELOGE(ge::PARAM_INVALID, "Param %s : Cannot set To to output param.", - input[input_idx].GetParamName().GetString()); - continue; - } - if (input[to_param.index_in].IsScalarNameSet()) { - GELOGE(ge::PARAM_INVALID, "Param %s : Chained parameter setting is not supported in To with name.", - input[input_idx].GetParamName().GetString()); - continue; - } - input[input_idx].impl_->types = input[to_param.index_in].impl_->types; - if (input[input_idx].IsSetDtypeForBin()) { - std::vector data_types_for_bin; - for (uint32_t type_idx = 0; type_idx < dfs_param.full_types.size(); ++type_idx) { - uint32_t idx = dfs_param.full_types[type_idx][to_param.index_in].first; - data_types_for_bin.emplace_back(input[input_idx].impl_->types_for_bin[idx]); - } - input[input_idx].impl_->types_for_bin = data_types_for_bin; - } - } -} - -void OpDef::UpdateOutput(const DfsParam &dfs_param, std::vector &output) { - for (uint32_t param_idx = 0; param_idx < output.size(); ++param_idx) { - if (output[param_idx].IsScalarOrScalarList()) { - GELOGE(ge::PARAM_INVALID, "Output %s : output cannot be set to Scalar or ScalarList.", - output[param_idx].GetParamName().GetString()); - continue; - } - uint32_t dfs_full_idx = dfs_param.full_types[0].size() - output.size() + param_idx; - this->UpdateDtypeImpl(dfs_param, output[param_idx], dfs_full_idx); - this->UpdateFormatImpl(dfs_param, output[param_idx], dfs_full_idx); - } -} - -void OpDef::SetPermutedParam(const DfsParam &dfs_param, - std::vector &input, - std::vector &output) { - this->UpdateInput(dfs_param, input); - this->UpdateOutput(dfs_param, output); - this->FollowListImpl(dfs_param, input, output); -} - -void OpDef::CheckIncompatible(const std::vector& all) const { - bool is_unknown_shape_format = false; - for (auto &def : all) { - if (!def.impl_->unknown_shape_formats.empty()) { - is_unknown_shape_format = true; - break; - } - } - if (is_unknown_shape_format) { - for (auto &def : all) { - if (def.impl_->formats_list.size() > 1 || def.impl_->types_list.size() > 1) { - GELOGW("UnknownShapeFormat is incompatible with FormatList/DataTypeList."); - return; - } - } - } -} - -void OpDef::FullPermutation(std::vector &input_param, - std::vector &output_param) { - this->impl_->non_list_len = GetNonListLen(input_param, output_param); - std::vector all_param = input_param; - all_param.insert(all_param.end(), output_param.begin(), output_param.end()); - CheckIncompatible(all_param); - struct DfsParam dfs_param; - for (uint32_t i = 0; i < this->impl_->non_list_len; ++i) { - DfsFullPermutation(dfs_param, all_param, 0, i); - } - if (dfs_param.full_types.empty() || dfs_param.full_formats.empty()) { - for (auto &def : input_param) { - def.impl_->types.clear(); - def.impl_->formats.clear(); - } - for (auto &def : output_param) { - def.impl_->types.clear(); - def.impl_->formats.clear(); - } - return; - } - SetPermutedParam(dfs_param, input_param, output_param); -} - -void OpDef::SetDefaultND(std::vector &defs) const { - for (auto &def : defs) { - if (def.impl_->formats.empty() && def.impl_->formats_list.empty()) { - def.impl_->formats_status = LIST; - def.impl_->formats_list = {ge::FORMAT_ND}; - } - } -} - -std::vector> OpDef::GetMergeInputsOutputs(const OpAICoreConfig &aicore_config) { - this->FollowImpl(); - std::vector inputs = this->GetInputs(); - std::vector outputs = this->GetOutputs(); - MergeParam(inputs, aicore_config.GetInputs()); - MergeParam(outputs, aicore_config.GetOutputs()); - SetDefaultND(inputs); - SetDefaultND(outputs); - this->FullPermutation(inputs, outputs); - std::vector> inputs_outputs; - inputs_outputs.push_back(inputs); - inputs_outputs.push_back(outputs); - return inputs_outputs; -} - -std::vector OpDef::GetMergeInputs(OpAICoreConfig &aicore_config) { - std::vector> inputs_outputs = GetMergeInputsOutputs(aicore_config); - return inputs_outputs[0]; -} - -std::vector OpDef::GetMergeOutputs(OpAICoreConfig &aicore_config) { - std::vector> inputs_outputs = GetMergeInputsOutputs(aicore_config); - return inputs_outputs[1]; -} - -OpAICoreDef &OpDef::AICore(void) { - return this->impl_->op_aicore; -} - -OpMC2Def &OpDef::MC2(void) { - return this->impl_->op_mc2; -} - -void OpDef::FollowImpl(void) { - this->impl_->op_params.FollowDataImpl(); - return; -} - -void OpDef::FollowListImpl(const DfsParam &dfs_param, std::vector& input, std::vector& output) { - this->impl_->op_params.FollowListDataImpl(dfs_param, input, output); - return; -} - -std::map OpDef::GetFollowMap(void) { - return this->impl_->op_params.GetFollowMap(); -} -std::map>> OpDef::GetFollowShapeMap(void) { - return this->impl_->op_params.GetShapeMap(); -} -std::map>> OpDef::GetFollowTypeMap(void) { - return this->impl_->op_params.GetDtypeMap(); -} -OpParamDef OpDef::GetParamDef(const ge::AscendString& name, OpDef::PortStat stat) { - return this->impl_->op_params.GetParamDef(name, stat); -} - -OpDef &OpDef::FormatMatchMode(FormatCheckOption option) { - this->impl_->format_mode = option; - return *this; -} - -FormatCheckOption OpDef::GetFormatMatchMode(void) { - return this->impl_->format_mode; -} - -OpDef &OpDef::EnableFallBack(void) { - this->impl_->enable_fall_back = true; - return *this; -} - -bool OpDef::IsEnableFallBack(void) { - return this->impl_->enable_fall_back; -} - -} // namespace ops diff --git a/register/opdef/op_def_aicore.cc b/register/opdef/op_def_aicore.cc deleted file mode 100644 index fc1537146ee274eef9a72922535613891c4a35fe..0000000000000000000000000000000000000000 --- a/register/opdef/op_def_aicore.cc +++ /dev/null @@ -1,213 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/op_def.h" -#include "op_def_impl.h" -#include "common/ge_common/debug/ge_log.h" - -namespace ops { -OpAICoreConfig::OpAICoreConfig() : impl_(new(std::nothrow) OpAICoreConfigImpl) {} - -OpAICoreConfig::OpAICoreConfig(const char *soc) : impl_(new(std::nothrow) OpAICoreConfigImpl) { - (void)soc; - this->AddCfgItem("dynamicCompileStatic.flag", "true"); - this->AddCfgItem("dynamicFormat.flag", "true"); - this->AddCfgItem("dynamicRankSupport.flag", "true"); - this->AddCfgItem("dynamicShapeSupport.flag", "true"); - this->AddCfgItem("needCheckSupport.flag", "false"); - this->AddCfgItem("precision_reduce.flag", "true"); -} - -OpAICoreConfig::OpAICoreConfig(const OpAICoreConfig &aicore_config) : impl_(new(std::nothrow) OpAICoreConfigImpl) { - this->impl_->op_params = aicore_config.impl_->op_params; - this->impl_->cfg_keys = aicore_config.impl_->cfg_keys; - this->impl_->cfg_info = aicore_config.impl_->cfg_info; -} - -OpAICoreConfig::~OpAICoreConfig() = default; - -OpAICoreConfig &OpAICoreConfig::operator=(const OpAICoreConfig &aicore_config) { - if (this != &aicore_config) { - *this->impl_ = *aicore_config.impl_; - } - return *this; -} - -OpParamDef &OpAICoreConfig::Input(const char *name) { - return this->impl_->op_params.Input(name); -} - -OpParamDef &OpAICoreConfig::Output(const char *name) { - return this->impl_->op_params.Output(name); -} - -OpAICoreConfig &OpAICoreConfig::DynamicCompileStaticFlag(bool flag) { - this->AddCfgItem("dynamicCompileStatic.flag", flag ? "true" : "false"); - return *this; -} - -OpAICoreConfig &OpAICoreConfig::DynamicFormatFlag(bool flag) { - this->AddCfgItem("dynamicFormat.flag", flag ? "true" : "false"); - return *this; -} - -OpAICoreConfig &OpAICoreConfig::DynamicRankSupportFlag(bool flag) { - this->AddCfgItem("dynamicRankSupport.flag", flag ? "true" : "false"); - return *this; -} - -OpAICoreConfig &OpAICoreConfig::DynamicShapeSupportFlag(bool flag) { - this->AddCfgItem("dynamicShapeSupport.flag", flag ? "true" : "false"); - return *this; -} - -OpAICoreConfig &OpAICoreConfig::NeedCheckSupportFlag(bool flag) { - this->AddCfgItem("needCheckSupport.flag", flag ? "true" : "false"); - return *this; -} - -OpAICoreConfig &OpAICoreConfig::PrecisionReduceFlag(bool flag) { - this->AddCfgItem("precision_reduce.flag", flag ? "true" : "false"); - return *this; -} - -OpAICoreConfig &OpAICoreConfig::ExtendCfgInfo(const char *key, const char *value) { - this->AddCfgItem(key, value); - return *this; -} - -std::vector &OpAICoreConfig::GetInputs(void) const { - return this->impl_->op_params.GetInputs(); -} -std::vector &OpAICoreConfig::GetOutputs(void) const { - return this->impl_->op_params.GetOutputs(); -} -void OpAICoreConfig::AddCfgItem(const char *key, const char *value) { - auto it = this->impl_->cfg_info.find(key); - if (it == this->impl_->cfg_info.cend()) { - this->impl_->cfg_keys.emplace_back(key); - } else { - this->impl_->cfg_info.erase(key); - } - this->impl_->cfg_info.emplace(key, value); -} - -std::vector &OpAICoreConfig::GetCfgKeys(void) { - return this->impl_->cfg_keys; -} - -std::map &OpAICoreConfig::GetCfgInfo(void) { - return this->impl_->cfg_info; -} - -ge::AscendString &OpAICoreConfig::GetConfigValue(const char *key) { - return this->impl_->cfg_info[key]; -} - -OpAICoreDef::OpAICoreDef() : impl_(new(std::nothrow) OpAICoreDefImpl) {} - -OpAICoreDef::OpAICoreDef(const OpAICoreDef &aicore_def) : impl_(new(std::nothrow) OpAICoreDefImpl) { - this->impl_->tiling_func = aicore_def.impl_->tiling_func; - this->impl_->tiling_parse = aicore_def.impl_->tiling_parse; - this->impl_->ci_creator = aicore_def.impl_->ci_creator; - this->impl_->ci_deleter = aicore_def.impl_->ci_deleter; - this->impl_->op_chk_support = aicore_def.impl_->op_chk_support; - this->impl_->op_sel_format = aicore_def.impl_->op_sel_format; - this->impl_->op_get_support = aicore_def.impl_->op_get_support; - this->impl_->op_get_spec = aicore_def.impl_->op_get_spec; - this->impl_->op_generlize_func = aicore_def.impl_->op_generlize_func; - this->impl_->aicore_configs = aicore_def.impl_->aicore_configs; -} - -OpAICoreDef::~OpAICoreDef() = default; - -OpAICoreDef &OpAICoreDef::operator=(const OpAICoreDef &aicore_def) { - if (this != &aicore_def) { - *this->impl_ = *aicore_def.impl_; - } - return *this; -} - -ge::graphStatus TilingParsePlaceHolder(gert::TilingParseContext* context) -{ - (void)context; - return ge::GRAPH_SUCCESS; -} - -OpAICoreDef &OpAICoreDef::SetTiling(gert::OpImplRegisterV2::TilingKernelFunc func) { - this->impl_->tiling_func = func; - this->impl_->tiling_parse = TilingParsePlaceHolder; - return *this; -} - -OpAICoreDef &OpAICoreDef::SetCheckSupport(optiling::OP_CHECK_FUNC func) { - this->impl_->op_chk_support = func; - return *this; -} - -OpAICoreDef &OpAICoreDef::SetOpSelectFormat(optiling::OP_CHECK_FUNC func) { - this->impl_->op_sel_format = func; - return *this; -} - -OpAICoreDef &OpAICoreDef::SetOpSupportInfo(optiling::OP_CHECK_FUNC func) { - this->impl_->op_get_support = func; - return *this; -} - -OpAICoreDef &OpAICoreDef::SetOpSpecInfo(optiling::OP_CHECK_FUNC func) { - this->impl_->op_get_spec = func; - return *this; -} - -OpAICoreDef &OpAICoreDef::SetParamGeneralize(optiling::PARAM_GENERALIZE_FUNC func) { - this->impl_->op_generlize_func = func; - return *this; -} - -OpAICoreDef &OpAICoreDef::AddConfig(const char *soc) { - OpAICoreConfig aicore_config(soc); - this->AddConfig(soc, aicore_config); - return *this; -} - -OpAICoreDef &OpAICoreDef::AddConfig(const char *soc, OpAICoreConfig &aicore_config) { - GELOGD("Call AddConfig for soc[%s].", soc); - this->impl_->aicore_configs.erase(ge::AscendString(soc)); - this->impl_->aicore_configs.emplace(ge::AscendString(soc), aicore_config); - return *this; -} - -std::map &OpAICoreDef::GetAICoreConfigs(void) { - return this->impl_->aicore_configs; -} - -gert::OpImplRegisterV2::TilingKernelFunc &OpAICoreDef::GetTiling(void) { - return this->impl_->tiling_func; -} - -optiling::OP_CHECK_FUNC &OpAICoreDef::GetCheckSupport(void) { - return this->impl_->op_chk_support; -} -optiling::OP_CHECK_FUNC &OpAICoreDef::GetOpSelectFormat(void) { - return this->impl_->op_sel_format; -} -optiling::OP_CHECK_FUNC &OpAICoreDef::GetOpSupportInfo(void) { - return this->impl_->op_get_support; -} -optiling::OP_CHECK_FUNC &OpAICoreDef::GetOpSpecInfo(void) { - return this->impl_->op_get_spec; -} -optiling::PARAM_GENERALIZE_FUNC &OpAICoreDef::GetParamGeneralize(void) { - return this->impl_->op_generlize_func; -} -void OpAICoreDef::Log(const char *op_type, const char *info) const { - GELOGD("%s, op_type:%s.", info, op_type); -} -} // namespace ops diff --git a/register/opdef/op_def_attr.cc b/register/opdef/op_def_attr.cc deleted file mode 100644 index 7684d5c02da0bb686119162813f6a8636a04aaa4..0000000000000000000000000000000000000000 --- a/register/opdef/op_def_attr.cc +++ /dev/null @@ -1,239 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include "register/op_def.h" -#include "op_def_impl.h" -#include "common/ge_common/debug/ge_log.h" - -namespace ops { -OpAttrDef::OpAttrDef(const char *name) : impl_(new(std::nothrow) OpAttrDefImpl) { - this->impl_->name = name; -} - -OpAttrDef::OpAttrDef(const OpAttrDef &attr_def) : impl_(new(std::nothrow) OpAttrDefImpl) { - this->impl_->name = attr_def.impl_->name; - this->impl_->data_type = attr_def.impl_->data_type; - this->impl_->required = attr_def.impl_->required; - this->impl_->bool_value = attr_def.impl_->bool_value; - this->impl_->float_value = attr_def.impl_->float_value; - this->impl_->int_value = attr_def.impl_->int_value; - this->impl_->str_value = attr_def.impl_->str_value; - this->impl_->list_bool = attr_def.impl_->list_bool; - this->impl_->list_float = attr_def.impl_->list_float; - this->impl_->list_int = attr_def.impl_->list_int; - this->impl_->list_list_int = attr_def.impl_->list_list_int; - this->impl_->version = attr_def.impl_->version; - this->impl_->comment = attr_def.impl_->comment; -} - -OpAttrDef::~OpAttrDef() = default; - -OpAttrDef &OpAttrDef::operator=(const OpAttrDef &attr_def) { - if (this != &attr_def) { - *this->impl_ = *attr_def.impl_; - } - return *this; -} - -bool OpAttrDef::operator==(const OpAttrDef &attr_def) const { - if (this->impl_->name == attr_def.impl_->name) { - return true; - } - return false; -} - -OpAttrDef &OpAttrDef::AttrType(Option attr_type) { - if (attr_type == Option::OPTIONAL) { - this->impl_->required = false; - } - return *this; -} - -OpAttrDef &OpAttrDef::Bool(void) { - this->impl_->data_type = AttrDataType::ATTR_DT_BOOL; - return *this; -} - -OpAttrDef &OpAttrDef::Bool(bool value) { - this->impl_->bool_value = value; - return this->Bool(); -} - -OpAttrDef &OpAttrDef::Float(void) { - this->impl_->data_type = AttrDataType::ATTR_DT_FLOAT; - return *this; -} - -OpAttrDef &OpAttrDef::Float(float value) { - this->impl_->float_value = value; - return this->Float(); -} - -OpAttrDef &OpAttrDef::Int(void) { - this->impl_->data_type = AttrDataType::ATTR_DT_INT; - return *this; -} - -OpAttrDef &OpAttrDef::Int(int64_t value) { - this->impl_->int_value = value; - return this->Int(); -} - -OpAttrDef &OpAttrDef::String(void) { - this->impl_->data_type = AttrDataType::ATTR_DT_STR; - return *this; -} - -OpAttrDef &OpAttrDef::String(const char *value) { - this->impl_->str_value = value; - return this->String(); -} - -OpAttrDef &OpAttrDef::ListBool(void) { - this->impl_->data_type = AttrDataType::ATTR_DT_LIST_BOOL; - return *this; -} - -OpAttrDef &OpAttrDef::ListBool(std::vector value) { - this->impl_->list_bool = value; - return this->ListBool(); -} - -OpAttrDef &OpAttrDef::ListFloat(void) { - this->impl_->data_type = AttrDataType::ATTR_DT_LIST_FLOAT; - return *this; -} - -OpAttrDef &OpAttrDef::ListFloat(std::vector value) { - this->impl_->list_float = value; - return this->ListFloat(); -} - -OpAttrDef &OpAttrDef::ListInt(void) { - this->impl_->data_type = AttrDataType::ATTR_DT_LIST_INT; - return *this; -} - -OpAttrDef &OpAttrDef::ListInt(std::vector value) { - this->impl_->list_int = value; - return this->ListInt(); -} - -OpAttrDef &OpAttrDef::ListListInt(void) { - this->impl_->data_type = AttrDataType::ATTR_DT_LIST_LIST_INT; - return *this; -} - -OpAttrDef &OpAttrDef::ListListInt(std::vector> value) { - this->impl_->list_list_int = value; - return this->ListListInt(); -} - -OpAttrDef &OpAttrDef::Version(uint32_t version) { - this->impl_->version = version; - return *this; -} - -OpAttrDef &OpAttrDef::Comment(const char *comment) { - if (comment == nullptr || strlen(comment) == 0) { - GELOGE(ge::PARAM_INVALID, "Attr %s : Comment content cannot be empty", this->GetName().GetString()); - return *this; - } - this->impl_->comment = comment; - return *this; -} - -ge::AscendString &OpAttrDef::GetComment(void) const { - return this->impl_->comment; -} - -uint32_t OpAttrDef::GetVersion(void) { - return this->impl_->version; -} - -ge::AscendString &OpAttrDef::GetName(void) const { - return this->impl_->name; -} - -bool OpAttrDef::IsRequired(void) { - return this->impl_->required; -} - -ge::AscendString &OpAttrDef::GetCfgDataType(void) const { - static ge::AscendString dtype_names[] = {"bool", "float", "int", "str", - "listBool", "listFloat", "listInt", "listListInt"}; - return dtype_names[static_cast(this->impl_->data_type)]; -} - -ge::AscendString &OpAttrDef::GetProtoDataType(void) const { - static ge::AscendString dtype_names[] = {"Bool", "Float", "Int", "String", - "ListBool", "ListFloat", "ListInt", "ListListInt"}; - return dtype_names[static_cast(this->impl_->data_type)]; -} - -template -std::string GetListStr(std::vector list, const char *brac, void (*pfSout)(std::stringstream &s, T v)) { - std::string str = ""; - std::stringstream sstream; - if (brac == nullptr || brac[0] == '\0' || brac[1] == '\0') { - return str.c_str(); - } - sstream << brac[0]; - for (auto v : list) { - pfSout(sstream, v); - } - str += sstream.str(); - if (list.size() > 0) { - str.resize(str.size() - 1); - } - str += brac[1]; - return str; -} - -ge::AscendString &OpAttrDef::GetAttrDefaultVal(const char *brac) { - std::stringstream sstream; - std::vector strList; - - if (this->impl_->data_type == AttrDataType::ATTR_DT_BOOL) { - sstream << (this->impl_->bool_value ? "true" : "false"); - this->impl_->value = sstream.str().c_str(); - } else if (this->impl_->data_type == AttrDataType::ATTR_DT_FLOAT) { - sstream << this->impl_->float_value; - this->impl_->value = sstream.str().c_str(); - } else if (this->impl_->data_type == AttrDataType::ATTR_DT_INT) { - sstream << this->impl_->int_value; - this->impl_->value = sstream.str().c_str(); - } else if (this->impl_->data_type == AttrDataType::ATTR_DT_STR) { - this->impl_->value = this->impl_->str_value; - } else if (this->impl_->data_type == AttrDataType::ATTR_DT_LIST_BOOL) { - this->impl_->value = GetListStr(this->impl_->list_bool, brac, [](std::stringstream &s, bool v) { - s << (v ? "true" : "false") << ","; - }).c_str(); - } else if (this->impl_->data_type == AttrDataType::ATTR_DT_LIST_FLOAT) { - this->impl_->value = - GetListStr(this->impl_->list_float, brac, [](std::stringstream &s, float v) { s << v << ","; }).c_str(); - } else if (this->impl_->data_type == AttrDataType::ATTR_DT_LIST_INT) { - this->impl_->value = GetListStr(this->impl_->list_int, brac, [](std::stringstream &s, int64_t v) { - s << v << ","; - }).c_str(); - } else if (this->impl_->data_type == AttrDataType::ATTR_DT_LIST_LIST_INT) { - for (auto listInt : this->impl_->list_list_int) { - strList.emplace_back(GetListStr(listInt, brac, [](std::stringstream &s, int64_t v) { s << v << ","; })); - } - this->impl_->value = - GetListStr(strList, brac, [](std::stringstream &s, std::string v) { s << v << ","; }).c_str(); - } else { - this->impl_->value = ""; - } - return this->impl_->value; -} -} // namespace ops diff --git a/register/opdef/op_def_factory.cc b/register/opdef/op_def_factory.cc deleted file mode 100644 index b660a64ee0ca3766cd490073a4bc037e51de1560..0000000000000000000000000000000000000000 --- a/register/opdef/op_def_factory.cc +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include -#include "register/op_def.h" -#include "op_def_impl.h" -#include "register/op_def_factory.h" - -namespace ops { -static std::map g_opsdef_creator; -static std::vector g_ops_list; -static std::set g_ops_sink_list; - -int OpDefFactory::OpDefRegister(const char *name, OpDefCreator creator) { - g_opsdef_creator.emplace(name, creator); - g_ops_list.emplace_back(name); - return 0; -} -OpDef OpDefFactory::OpDefCreate(const char *name) { - auto it = g_opsdef_creator.find(name); - if (it != g_opsdef_creator.cend()) { - return it->second(name); - } - return OpDef("default"); -} - -std::vector &OpDefFactory::GetAllOp(void) { - return g_ops_list; -} - -void OpDefFactory::OpTilingSinkRegister(const char *opType) { - g_ops_sink_list.emplace(opType); -} - -bool OpDefFactory::OpIsTilingSink(const char *opType) { - return g_ops_sink_list.find(opType) != g_ops_sink_list.end(); -} -} // namespace ops diff --git a/register/opdef/op_def_impl.h b/register/opdef/op_def_impl.h deleted file mode 100644 index ced0d90c2ebd6322b25c8b957d86f5ce65163ca7..0000000000000000000000000000000000000000 --- a/register/opdef/op_def_impl.h +++ /dev/null @@ -1,164 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef OP_DEF_IMPL_H -#define OP_DEF_IMPL_H - -#include "register/op_def.h" -#include "register/op_impl_registry.h" -#include "register/op_check_register.h" -#include "graph/operator_reg.h" - -namespace ops { -enum ListParamStatus : int32_t { - UNSET = 0, - LIST = 1, - NON_LIST = 2, -}; -class OpParamDefImpl { -public: - ge::AscendString name; - Option param_type = Option::REQUIRED; - std::vector types; - std::vector origin_types; - std::vector types_list; - std::vector formats; - std::vector formats_list; - std::vector types_for_bin; - std::vector formats_for_bin; - ListParamStatus types_status = UNSET; - ListParamStatus formats_status = UNSET; - ge::AscendString need_compile = ""; - ge::AscendString reshape_type = ""; - ge::AscendString value_depend = ""; - DependScope depend_scope = DependScope::ALL; - std::vector unknown_shape_formats; - bool ignore_contiguous = false; - bool auto_contiguous = false; - bool is_scalar = false; - bool is_scalar_list = false; - bool set_type_for_bin = false; - bool set_format_for_bin = false; - ge::AscendString scalar_name = ""; - ge::DataType scalar_type = ge::DT_UNDEFINED; - uint32_t version{0}; - InitValueType init_value_type = InitValueType::INIT_VALUE_DEFAULT; - InitValueNum init_value; - std::vector init_value_list; - bool is_output_shape_depend_on_compute = false; - ge::AscendString follow_port_name = ""; - FollowType follow_type = FollowType::INVALID_TYPE; - ge::AscendString comment = ""; -}; - -class OpParamTrunk { -public: - OpParamDef &Input(const char *name); - OpParamDef &Output(const char *name); - std::vector &GetInputs(void); - std::vector &GetOutputs(void); - -private: - friend class OpDef; - friend class OpProtoGenerator; - - ItemFindStatus ParamFind(const char *name, bool is_output, OpParamDef **param); - OpParamDef &ParamAdd(OpParamDef ¶m, bool is_output); - OpParamDef &ParamGetOrCreate(const char *name, bool is_output); - OpParamDef &GetParamDef(const ge::AscendString& name, OpDef::PortStat stat); - void FollowMapUpdate(OpParamDef ¶m, bool is_output); - void FollowDataImpl(void); - void DfsFollow(OpParamDef& op_param_def, OpDef::PortStat stat); - void ParamFollow(OpParamDef &op_param_def, OpParamDef &target_param, OpDef::PortStat stat); - void FollowListDataImpl(const OpDef::DfsParam &dfs_param, std::vector &input, - std::vector &output); - std::map GetFollowMap(void); - std::map>> GetShapeMap(void); - std::map>> GetDtypeMap(void); - bool follow_isimpl = false; - std::vector inputs_; - std::vector outputs_; - std::map follow_map; - std::vector> follow_dtypelist; - std::vector> follow_formatlist; - std::map>> follow_shape_map; - std::map>> follow_dtype_map; -}; - -class OpAttrDefImpl { -public: - ge::AscendString name; - AttrDataType data_type = AttrDataType::ATTR_DT_BOOL; - bool required = true; - bool bool_value = false; - float float_value = 0; - int64_t int_value = 0; - ge::AscendString str_value = ""; - std::vector list_bool = {}; - std::vector list_float = {}; - std::vector list_int = {}; - std::vector> list_list_int = {}; - ge::AscendString value = ""; - uint32_t version = 0; - ge::AscendString comment = ""; -}; - -class OpAICoreConfigImpl { -public: - OpParamTrunk op_params; - std::vector cfg_keys; - std::map cfg_info; -}; - -class OpAICoreDefImpl { -public: - gert::OpImplRegisterV2::TilingKernelFunc tiling_func = nullptr; - gert::OpImplRegisterV2::TilingParseFunc tiling_parse = nullptr; - gert::OpImplRegisterV2::CompileInfoCreatorFunc ci_creator = nullptr; - gert::OpImplRegisterV2::CompileInfoDeleterFunc ci_deleter = nullptr; - optiling::OP_CHECK_FUNC op_chk_support = nullptr; - optiling::OP_CHECK_FUNC op_sel_format = nullptr; - optiling::OP_CHECK_FUNC op_get_support = nullptr; - optiling::OP_CHECK_FUNC op_get_spec = nullptr; - optiling::PARAM_GENERALIZE_FUNC op_generlize_func = nullptr; - std::map aicore_configs = {}; -}; - -class OpMC2DefImpl { -public: - std::vector group_list = {}; - std::map server_type_ = {}; -}; - -class OpDefImpl { -public: - gert::OpImplRegisterV2::InferShapeKernelFunc infer_shape = nullptr; - gert::OpImplRegisterV2::InferShapeRangeKernelFunc infer_shape_range = nullptr; - gert::OpImplRegisterV2::InferDataTypeKernelFunc infer_data_type = nullptr; - OpParamTrunk op_params; - std::vector attrs; - OpAICoreDef op_aicore; - ge::AscendString op_type; - ge::AscendString category = "op_proto"; - std::map> comment_map = { - {ops::CommentSection::BRIEF, {}}, - {ops::CommentSection::CONSTRAINTS, {}}, - {ops::CommentSection::RESTRICTIONS, {}}, - {ops::CommentSection::SEE, {}}, - {ops::CommentSection::THIRDPARTYFWKCOMPAT, {}} - }; - bool has_workspace = true; - uint32_t non_list_len = 0; - OpMC2Def op_mc2; - FormatCheckOption format_mode = FormatCheckOption::MAX; - bool enable_fall_back = false; -}; -} // namespace ops - -#endif diff --git a/register/opdef/op_def_mc2.cc b/register/opdef/op_def_mc2.cc deleted file mode 100644 index 97d1a3307d525494f8b86342ce80693d2a098cec..0000000000000000000000000000000000000000 --- a/register/opdef/op_def_mc2.cc +++ /dev/null @@ -1,86 +0,0 @@ -/** - * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. - * [graph-engine] is licensed under Mulan PSL v2. - * You can use this software according to the terms and conditions of the Mulan PSL v2. - * You may obtain a copy of Mulan PSL v2 at: - * http://license.coscl.org.cn/MulanPSL2 - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, - * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, - * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. - * See the Mulan PSL v2 for more details. - */ - -#include -#include "op_def_impl.h" - -namespace ops { -OpMC2Def::OpMC2Def() : impl_(new(std::nothrow) OpMC2DefImpl) {} - -OpMC2Def::OpMC2Def(const OpMC2Def &mc2_def) : impl_(new(std::nothrow) OpMC2DefImpl) { - this->impl_->group_list = mc2_def.impl_->group_list; -} - -OpMC2Def::~OpMC2Def() = default; - -OpMC2Def &OpMC2Def::operator=(const OpMC2Def &mc2_def) { - if (this != &mc2_def) { - *this->impl_ = *mc2_def.impl_; - } - return *this; -} - -OpMC2Def &OpMC2Def::HcclGroup(const char *value) { - if (std::find(this->impl_->group_list.begin(), this->impl_->group_list.end(), value) == - this->impl_->group_list.end()) { - this->impl_->group_list.emplace_back(value); - } - return *this; -} - -OpMC2Def &OpMC2Def::HcclGroup(std::vector value) { - for (const char *val : value) { - if (std::find(this->impl_->group_list.begin(), this->impl_->group_list.end(), val) == - this->impl_->group_list.end()) { - this->impl_->group_list.emplace_back(val); - } - } - return *this; -} - -std::vector &OpMC2Def::GetHcclGroups(void) const { - return this->impl_->group_list; -} - -void OpMC2Def::HcclServerType(enum HcclServerType type, const char* soc) { - ge::AscendString soc_version; - if (soc == nullptr || strlen(soc) == 0) { - soc_version = ""; - } else { - soc_version = soc; - } - this->impl_->server_type_[soc_version] = type; -} - -/** - * @brief get hccl server type by soc version - * @param soc_version "" means checking if any hccl server type has been set - * @return hccl server type corresponding to soc version. - For scenarios where soc version is empty, return MAX if not set, AICPU if set. - */ -enum HcclServerType OpMC2Def::GetHcclServerType(const ge::AscendString &soc_version) const { - if (this->impl_->server_type_.empty()) { - return HcclServerType::MAX; - } - if (soc_version.GetLength() == 0) { - return HcclServerType::AICPU; - } - if (this->impl_->server_type_.find(soc_version) != this->impl_->server_type_.end()) { - return this->impl_->server_type_[soc_version]; - } - if (this->impl_->server_type_.find("") != this->impl_->server_type_.end()) { - return this->impl_->server_type_[""]; - } - return HcclServerType::MAX; -} - -} // namespace ops diff --git a/register/opdef/op_def_param.cc b/register/opdef/op_def_param.cc deleted file mode 100644 index 09b21545bc635376dc88bdaee5553a1ff7153f9a..0000000000000000000000000000000000000000 --- a/register/opdef/op_def_param.cc +++ /dev/null @@ -1,773 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "register/op_def.h" -#include "op_def_impl.h" -#include "common/ge_common/debug/ge_log.h" - -namespace ops { -OpParamDef::OpParamDef(const char *name) : impl_(new(std::nothrow) OpParamDefImpl) { - this->impl_->name = name; -} - -OpParamDef::OpParamDef(const OpParamDef &def) : impl_(new(std::nothrow) OpParamDefImpl) { - this->impl_->name = def.impl_->name; - this->impl_->param_type = def.impl_->param_type; - this->impl_->types = def.impl_->types; - this->impl_->origin_types = def.impl_->origin_types; - this->impl_->formats = def.impl_->formats; - this->impl_->formats_list = def.impl_->formats_list; - this->impl_->types_list = def.impl_->types_list; - this->impl_->need_compile = def.impl_->need_compile; - this->impl_->reshape_type = def.impl_->reshape_type; - this->impl_->value_depend = def.impl_->value_depend; - this->impl_->depend_scope = def.impl_->depend_scope; - this->impl_->unknown_shape_formats = def.impl_->unknown_shape_formats; - this->impl_->ignore_contiguous = def.impl_->ignore_contiguous; - this->impl_->auto_contiguous = def.impl_->auto_contiguous; - this->impl_->is_scalar = def.impl_->is_scalar; - this->impl_->is_scalar_list = def.impl_->is_scalar_list; - this->impl_->types_status = def.impl_->types_status; - this->impl_->formats_status = def.impl_->formats_status; - this->impl_->scalar_name = def.impl_->scalar_name; - this->impl_->scalar_type = def.impl_->scalar_type; - this->impl_->version = def.impl_->version; - this->impl_->init_value_type = def.impl_->init_value_type; - this->impl_->init_value = def.impl_->init_value; - this->impl_->init_value_list = def.impl_->init_value_list; - this->impl_->is_output_shape_depend_on_compute = def.impl_->is_output_shape_depend_on_compute; - this->impl_->follow_port_name = def.impl_->follow_port_name; - this->impl_->follow_type = def.impl_->follow_type; - this->impl_->comment = def.impl_->comment; - this->impl_->types_for_bin = def.impl_->types_for_bin; - this->impl_->formats_for_bin = def.impl_->formats_for_bin; - this->impl_->set_type_for_bin = def.impl_->set_type_for_bin; - this->impl_->set_format_for_bin = def.impl_->set_format_for_bin; -} - - -OpParamDef &OpParamDef::operator=(const OpParamDef &def) { - if (this != &def) { - *this->impl_ = *def.impl_; - } - return *this; -} - -void OpParamDef::MergeParam(const OpParamDef &def) { - this->impl_->param_type = def.impl_->param_type; - if (!def.impl_->types.empty()) { - this->impl_->types = def.impl_->types; - this->impl_->origin_types = def.impl_->origin_types; - } - if (!def.impl_->types_list.empty()) { - this->impl_->types_list = def.impl_->types_list; - } - if (!def.impl_->formats.empty()) { - this->impl_->formats = def.impl_->formats; - } - if (!def.impl_->formats_list.empty()) { - this->impl_->formats_list = def.impl_->formats_list; - } - if (def.impl_->need_compile.GetLength() > 0) { - this->impl_->need_compile = def.impl_->need_compile; - } - if (def.impl_->reshape_type.GetLength() > 0) { - this->impl_->reshape_type = def.impl_->reshape_type; - } - if (def.impl_->value_depend.GetLength() > 0) { - this->impl_->value_depend = def.impl_->value_depend; - } - if (!def.impl_->unknown_shape_formats.empty()) { - this->impl_->unknown_shape_formats = def.impl_->unknown_shape_formats; - } - if (!def.impl_->types_for_bin.empty()) { - this->impl_->types_for_bin = def.impl_->types_for_bin; - this->impl_->set_type_for_bin = def.impl_->set_type_for_bin; - } - if (!def.impl_->formats_for_bin.empty()) { - this->impl_->formats_for_bin = def.impl_->formats_for_bin; - this->impl_->set_format_for_bin = def.impl_->set_format_for_bin; - } - this->impl_->init_value_type = def.impl_->init_value_type; - this->impl_->init_value = def.impl_->init_value; - this->impl_->init_value_list = def.impl_->init_value_list; - this->impl_->ignore_contiguous = def.impl_->ignore_contiguous; - this->impl_->auto_contiguous = def.impl_->auto_contiguous; - this->impl_->is_scalar = def.impl_->is_scalar; - this->impl_->is_scalar_list = def.impl_->is_scalar_list; - this->impl_->types_status = def.impl_->types_status; - this->impl_->formats_status = def.impl_->formats_status; - this->impl_->scalar_name = def.impl_->scalar_name; - this->impl_->scalar_type = def.impl_->scalar_type; - this->impl_->version = def.impl_->version; - this->impl_->is_output_shape_depend_on_compute = def.impl_->is_output_shape_depend_on_compute; - this->impl_->depend_scope = def.impl_->depend_scope; - this->impl_->follow_port_name = def.impl_->follow_port_name; - this->impl_->follow_type = def.impl_->follow_type; - this->impl_->comment = def.impl_->comment; -} - -OpParamDef::~OpParamDef() = default; - -bool OpParamDef::operator==(const OpParamDef &def) const { - if (this->impl_->name == def.impl_->name) { - return true; - } - return false; -} - -OpParamDef &OpParamDef::ParamType(Option param_type) { - this->impl_->param_type = param_type; - return *this; -} - -bool OpParamDef::IsDtype(void) const { - return this->impl_->types_status == NON_LIST; -} - -bool OpParamDef::IsDtypeList(void) const { - return this->impl_->types_status == LIST; -} - -bool OpParamDef::IsFormat(void) const { - return this->impl_->formats_status == NON_LIST; -} - -bool OpParamDef::IsFormatList(void) const { - return this->impl_->formats_status == LIST; -} - -bool OpParamDef::IsScalarOrScalarList(void) const { - return this->IsScalar() || this->IsScalarList(); -} - -bool OpParamDef::IsScalarTypeSet(void) const { - return this->impl_->scalar_type != ge::DT_UNDEFINED; -} - -bool OpParamDef::IsScalarNameSet(void) const { - return std::strcmp(this->impl_->scalar_name.GetString(), "") != 0; -} - -bool OpParamDef::IsValueDepend(void) const { - return std::strcmp(this->impl_->value_depend.GetString(), "") != 0; -} - -OpParamDef &OpParamDef::DataType(std::vector types) { - if (this->IsDtypeList()) { - GELOGE(ge::PARAM_INVALID, "DataTypeList and DataType can not be called at the same time!"); - return *this; - } - if (types.empty()) { - GELOGE(ge::PARAM_INVALID, "DataType can not be empty"); - return *this; - } - if (this->impl_->set_type_for_bin && types.size() != this->impl_->types_for_bin.size()) { - GELOGE(ge::PARAM_INVALID, "Param %s : DataType size is not equal to DataTypeForBinQuery size", - this->impl_->name.GetString()); - return *this; - } - this->impl_->types_status = NON_LIST; - this->impl_->types = types; - this->impl_->origin_types = types; - return *this; -} - -OpParamDef &OpParamDef::DataTypeList(std::vector types) { - if (this->IsDtype()) { - GELOGE(ge::PARAM_INVALID, "DataTypeList and DataType can not be called at the same time!"); - return *this; - } - if (types.empty()) { - GELOGE(ge::PARAM_INVALID, "DataTypeList can not be empty"); - return *this; - } - std::unordered_set dtype_set(types.begin(), types.end()); - if (dtype_set.size() < types.size()) { - GELOGE(ge::PARAM_INVALID, "Element of DataTypeList must be unique!"); - return *this; - } - if (this->impl_->set_type_for_bin && types.size() != this->impl_->types_for_bin.size()) { - GELOGE(ge::PARAM_INVALID, "Param %s : DataTypeList size is not equal to DataTypeForBinQuery size", - this->impl_->name.GetString()); - return *this; - } - this->impl_->types_status = LIST; - this->impl_->types_list = types; - return *this; -} - -OpParamDef &OpParamDef::Format(std::vector formats) { - if (this->IsFormatList()) { - GELOGE(ge::PARAM_INVALID, "FormatList and Format can not be called at the same time!"); - return *this; - } - if (formats.empty()) { - GELOGE(ge::PARAM_INVALID, "Format can not be empty"); - return *this; - } - if (this->impl_->set_format_for_bin && formats.size() != this->impl_->formats_for_bin.size()) { - GELOGE(ge::PARAM_INVALID, "Param %s : Format size is not equal to FormatForBinQuery size", - this->impl_->name.GetString()); - return *this; - } - this->impl_->formats_status = NON_LIST; - this->impl_->formats = formats; - return *this; -} - -OpParamDef &OpParamDef::FormatList(std::vector formats) { - if (this->IsFormat()) { - GELOGE(ge::PARAM_INVALID, "FormatList and Format can not be called at the same time!"); - return *this; - } - if (formats.empty()) { - GELOGE(ge::PARAM_INVALID, "Format can not be empty"); - return *this; - } - std::unordered_set format_set(formats.begin(), formats.end()); - if (format_set.size() < formats.size()) { - GELOGE(ge::PARAM_INVALID, "Element of FormatList must be unique!"); - return *this; - } - if (this->impl_->set_format_for_bin && formats.size() != this->impl_->formats_for_bin.size()) { - GELOGE(ge::PARAM_INVALID, "Param %s : FormatList size is not equal to FormatForBinQuery size", - this->impl_->name.GetString()); - return *this; - } - this->impl_->formats_status = LIST; - this->impl_->formats_list = formats; - return *this; -} - -OpParamDef &OpParamDef::DataTypeForBinQuery(std::vector types) { - if (types.empty()) { - GELOGE(ge::PARAM_INVALID, "DataTypeForBinList can not be empty!"); - return *this; - } - if (this->impl_->types_status == NON_LIST && this->impl_->types.size() != types.size()) { - GELOGE(ge::PARAM_INVALID, "Param %s : DataTypeForBinQuery size is not equal to DataType size", - this->impl_->name.GetString()); - return *this; - } - if (this->impl_->types_status == LIST && this->impl_->types_list.size() != types.size()) { - GELOGE(ge::PARAM_INVALID, "Param %s : DataTypeForBinQuery size is not equal to DataTypeList size", - this->impl_->name.GetString()); - return *this; - } - this->impl_->types_for_bin = types; - this->impl_->set_type_for_bin = true; - return *this; -} - -OpParamDef &OpParamDef::FormatForBinQuery(std::vector formats) { - if (formats.empty()) { - GELOGE(ge::PARAM_INVALID, "FormatForBinList can not be empty!"); - return *this; - } - if (this->impl_->formats_status == NON_LIST && this->impl_->formats.size() != formats.size()) { - GELOGE(ge::PARAM_INVALID, "Param %s : FormatForBinQuery size is not equal to Format size", - this->impl_->name.GetString()); - return *this; - } - if (this->impl_->formats_status == LIST && this->impl_->formats_list.size() != formats.size()) { - GELOGE(ge::PARAM_INVALID, "Param %s : FormatForBinQuery size is not equal to FormatList size", - this->impl_->name.GetString()); - return *this; - } - this->impl_->formats_for_bin = formats; - this->impl_->set_format_for_bin = true; - return *this; -} - -OpParamDef &OpParamDef::UnknownShapeFormat(std::vector formats) { - this->impl_->unknown_shape_formats = formats; - return *this; -} - -OpParamDef &OpParamDef::ValueDepend(Option value_depend) { - if (value_depend == Option::REQUIRED) { - this->impl_->value_depend = "required"; - } else if (value_depend == Option::OPTIONAL) { - this->impl_->value_depend = "optional"; - } else { - this->impl_->value_depend = ""; - GELOGW("Param %s : ValueDepend Option is Invalid", this->impl_->name.GetString()); - return *this; - } - this->impl_->depend_scope = DependScope::ALL; - return *this; -} - -OpParamDef &OpParamDef::ValueDepend(Option value_depend, DependScope scope) { - if (scope >= DependScope::INVALID_SCOPE) { - GELOGE(ge::PARAM_INVALID, "Param %s : ValueDepend DependScope is Invalid", this->impl_->name.GetString()); - return *this; - } - if (this->ValueDepend(value_depend).impl_->value_depend.GetLength() > 0) { - this->impl_->depend_scope = scope; - } - return *this; -} - -OpParamDef &OpParamDef::IgnoreContiguous(void) { - this->impl_->ignore_contiguous = true; - return *this; -} - -OpParamDef &OpParamDef::AutoContiguous() { - this->impl_->auto_contiguous = true; - return *this; -} - -OpParamDef &OpParamDef::Scalar() { - this->impl_->is_scalar = true; - return *this; -} - -OpParamDef &OpParamDef::ScalarList() { - this->impl_->is_scalar_list = true; - return *this; -} - -OpParamDef &OpParamDef::To(const ge::DataType type) { - if (!this->impl_->is_scalar && !this->impl_->is_scalar_list) { - GELOGE(ge::PARAM_INVALID, "Param %s : To must be set on the Scalar/ScalarList parameter.", - this->impl_->name.GetString()); - return *this; - } - if (this->impl_->follow_type != FollowType::INVALID_TYPE) { - GELOGE(ge::PARAM_INVALID, "Param %s : To is incompatible with Follow", this->impl_->name.GetString()); - return *this; - } - - this->impl_->scalar_type = type; - return *this; -} - -OpParamDef &OpParamDef::To(const char *name) { - if (!this->impl_->is_scalar && !this->impl_->is_scalar_list) { - GELOGE(ge::PARAM_INVALID, "Param %s : To must be set on the Scalar/ScalarList parameter.", - this->impl_->name.GetString()); - return *this; - } - if (this->impl_->follow_type != FollowType::INVALID_TYPE) { - GELOGE(ge::PARAM_INVALID, "Param %s : To is incompatible with Follow", this->impl_->name.GetString()); - return *this; - } - this->impl_->scalar_name = name; - return *this; -} - -OpParamDef &OpParamDef::Version(uint32_t version) { - this->impl_->version = version; - return *this; -} - -OpParamDef &OpParamDef::InitValue(uint64_t value) { - this->impl_->init_value.value_u64 = value; - this->impl_->init_value_type = InitValueType::INIT_VALUE_UINT64_T; - return *this; -} - -OpParamDef &OpParamDef::InitValue(const ScalarVar &value) { - if (!this->impl_->init_value_list.empty()) { - GELOGW("InitValue has been set, %s InitValue will be reset, please check whether it is correct.", - this->impl_->name.GetString()); - this->impl_->init_value_list.clear(); - } - this->impl_->init_value_list.emplace_back(value); - return *this; -} - -OpParamDef &OpParamDef::InitValue(const std::vector &value) { - if (!this->impl_->init_value_list.empty()) { - GELOGW("InitValue has been set, %s InitValue will be reset, please check whether it is correct.", - this->impl_->name.GetString()); - this->impl_->init_value_list.clear(); - } - this->impl_->init_value_list.assign(value.begin(), value.end()); - return *this; -} - -OpParamDef &OpParamDef::OutputShapeDependOnCompute() { - this->impl_->is_output_shape_depend_on_compute = true; - return *this; -} - -OpParamDef &OpParamDef::Follow(const char *paramName) -{ - if (this->IsScalarTypeSet() || this->IsScalarNameSet()) { - GELOGE(ge::PARAM_INVALID, "Param %s : Follow is incompatible with To", this->impl_->name.GetString()); - return *this; - } - this->impl_->follow_port_name = paramName; - this->impl_->follow_type = FollowType::ALL; - return *this; -} - -OpParamDef &OpParamDef::Follow(const char *paramName, FollowType ftype) -{ - if (this->IsScalarTypeSet() || this->IsScalarNameSet()) { - GELOGE(ge::PARAM_INVALID, "Param %s : Follow is incompatible with To", this->impl_->name.GetString()); - return *this; - } - if (ftype >= FollowType::INVALID_TYPE) { - GELOGE(ge::PARAM_INVALID, "Port %s : FollowType is Invalid", this->impl_->name.GetString()); - return *this; - } - this->impl_->follow_port_name = paramName; - this->impl_->follow_type = ftype; - return *this; -} - -OpParamDef &OpParamDef::Comment(const char *comment) { - if (comment == nullptr || strlen(comment) == 0) { - GELOGE(ge::PARAM_INVALID, "Param %s : Comment content cannot be empty", this->GetParamName().GetString()); - return *this; - } - this->impl_->comment = comment; - return *this; -} - -bool OpParamDef::IsOutputShapeDependOnCompute(void) const { - return this->impl_->is_output_shape_depend_on_compute; -} - -ge::AscendString &OpParamDef::GetParamName(void) const { - return this->impl_->name; -} -Option OpParamDef::GetParamType(void) { - return this->impl_->param_type; -} -std::vector &OpParamDef::GetDataTypes(void) { - if (this->impl_->types.empty()) { - GELOGW("GetDataTypes returns types_list because types is empty!"); - return this->impl_->types_list; - } - return this->impl_->types; -} - -std::vector &OpParamDef::GetOriginDataTypes(void) { - if (this->impl_->origin_types.empty()) { - GELOGE(ge::PARAM_INVALID, "origin types is empty, please check!"); - return this->impl_->origin_types; - } - return this->impl_->origin_types; -} - -std::vector &OpParamDef::GetDataTypesList(void) { - return this->impl_->types_list; -} -std::vector &OpParamDef::GetDataTypesForBin(void) const { - return this->impl_->types_for_bin; -} -bool OpParamDef::IsSetDtypeForBin(void) const { - return this->impl_->set_type_for_bin; -} -std::vector &OpParamDef::GetFormats(void) { - return this->impl_->formats; -} -std::vector &OpParamDef::GetFormatsList(void) { - return this->impl_->formats_list; -} -std::vector &OpParamDef::GetFormatsForBin(void) const { - return this->impl_->formats_for_bin; -} -bool OpParamDef::IsSetFormatForBin(void) const { - return this->impl_->set_format_for_bin; -} -std::vector &OpParamDef::GetUnknownShapeFormats(void) { - return this->impl_->unknown_shape_formats; -} -ge::AscendString &OpParamDef::GetValueDepend(void) const { - return this->impl_->value_depend; -} -DependScope &OpParamDef::GetDependScope(void) const { - return this->impl_->depend_scope; -} -ge::AscendString &OpParamDef::GetFollowName(void) const { - return this->impl_->follow_port_name; -} -FollowType &OpParamDef::GetFollowType(void) const { - return this->impl_->follow_type; -} -ge::AscendString &OpParamDef::GetComment(void) const { - return this->impl_->comment; -} -bool OpParamDef::GetIgnoreContiguous(void) { - return this->impl_->ignore_contiguous; -} -bool OpParamDef::GetAutoContiguous(void) { - return this->impl_->auto_contiguous; -} -bool OpParamDef::IsScalar(void) const { - return this->impl_->is_scalar; -} -bool OpParamDef::IsScalarList(void) const { - return this->impl_->is_scalar_list; -} -ge::AscendString &OpParamDef::GetScalarName(void) const { - return this->impl_->scalar_name; -} -ge::DataType OpParamDef::GetScalarType(void) const { - return this->impl_->scalar_type; -} - -uint32_t OpParamDef::GetVersion(void) { - return this->impl_->version; -} - -InitValueType &OpParamDef::GetInitValueType(void) { - return this->impl_->init_value_type; -} - -InitValueNum &OpParamDef::GetInitValue(void) { - return this->impl_->init_value; -} - -std::vector &OpParamDef::GetInitValueList(void) { - return this->impl_->init_value_list; -} - -OpParamDef &OpParamTrunk::Input(const char *name) { - return this->ParamGetOrCreate(name, false); -} - -OpParamDef &OpParamTrunk::Output(const char *name) { - return this->ParamGetOrCreate(name, true); -} - -OpParamDef &OpParamTrunk::ParamGetOrCreate(const char *name, bool is_output) { - OpParamDef *param; - if (this->ParamFind(name, is_output, ¶m) == ItemFindStatus::ITEM_FIND) { - return *param; - } else { - OpParamDef addParam(name); - return this->ParamAdd(addParam, is_output); - } -} - -ItemFindStatus OpParamTrunk::ParamFind(const char *name, bool is_output, OpParamDef **param) { - std::vector *paramList; - - if (is_output) { - paramList = &(this->outputs_); - } else { - paramList = &(this->inputs_); - } - for (auto it = paramList->begin(); it != paramList->end(); it++) { - if (it->GetParamName() == name) { - *param = &(*it); - return ItemFindStatus::ITEM_FIND; - } - } - return ItemFindStatus::ITEM_NOEXIST; -} - -OpParamDef &OpParamTrunk::ParamAdd(OpParamDef ¶m, bool is_output) { - FollowMapUpdate(param, is_output); - if (is_output) { - this->outputs_.emplace_back(param); - return this->outputs_.back(); - } else { - this->inputs_.emplace_back(param); - return this->inputs_.back(); - } -} - -std::vector &OpParamTrunk::GetInputs(void) { - return this->inputs_; -} - -std::vector &OpParamTrunk::GetOutputs(void) { - return this->outputs_; -} -void OpParamTrunk::FollowMapUpdate(OpParamDef ¶m, bool is_output) { - ge::AscendString& cur_name = param.GetParamName(); - if (this->follow_map.find(cur_name) != this->follow_map.end()) { - OpDef::PortFollowInfo& follow_info = this->follow_map[param.GetParamName()]; - follow_info.port_stat = OpDef::PortStat::INOUT; - if (is_output) { - follow_info.index_out = this->outputs_.size(); - } else { - follow_info.index_in = this->inputs_.size(); - } - return; - } - OpDef::PortFollowInfo follow_info; - if (is_output) { - follow_info.port_stat = OpDef::PortStat::OUT; - follow_info.index_out = this->outputs_.size(); - } else { - follow_info.port_stat = OpDef::PortStat::IN; - follow_info.index_in = this->inputs_.size(); - } - this->follow_map.emplace(cur_name, follow_info); - return; -} - -OpParamDef &OpParamTrunk::GetParamDef(const ge::AscendString& name, OpDef::PortStat stat) { - OpDef::PortFollowInfo& follow_info = this->follow_map[name]; - if (stat == OpDef::PortStat::OUT) { - return this->outputs_[follow_info.index_out]; - } else { - return this->inputs_[follow_info.index_in]; - } -} - -void OpParamTrunk::FollowDataImpl(void) { - if (this->follow_isimpl == true) { - return; - } - for (auto& op_param_def : this->inputs_) { - this->DfsFollow(op_param_def, OpDef::PortStat::IN); - } - for (auto& op_param_def : this->outputs_) { - this->DfsFollow(op_param_def, OpDef::PortStat::OUT); - } - this->follow_isimpl = true; - return; -} -void OpParamTrunk::ParamFollow(OpParamDef &op_param_def, OpParamDef &target_param, OpDef::PortStat stat) { - ge::AscendString cur_name = op_param_def.GetParamName(); - FollowType ftype = op_param_def.GetFollowType(); - ge::AscendString follow_name = target_param.GetParamName(); - if (ftype == FollowType::ALL || ftype == FollowType::DTYPE) { - this->follow_dtype_map[follow_name].emplace_back(std::make_pair(cur_name, stat)); - if (target_param.IsDtype()) { - op_param_def.impl_->types_status = target_param.impl_->types_status; - op_param_def.impl_->types = target_param.impl_->types; - op_param_def.impl_->origin_types = target_param.impl_->origin_types; - } - if (target_param.IsDtypeList()) { - op_param_def.impl_->types_status = target_param.impl_->types_status; - op_param_def.impl_->types_list = std::vector(1, target_param.impl_->types_list[0]); - this->follow_dtypelist.emplace_back(std::make_pair(cur_name, stat)); - } - } - if (ftype == FollowType::ALL || ftype == FollowType::FORMAT) { - if (target_param.IsFormat()) { - op_param_def.impl_->formats_status = target_param.impl_->formats_status; - op_param_def.impl_->formats = target_param.impl_->formats; - } - if (target_param.IsFormatList()) { - op_param_def.impl_->formats_status = target_param.impl_->formats_status; - op_param_def.impl_->formats_list = std::vector(1, target_param.impl_->formats_list[0]); - this->follow_formatlist.emplace_back(std::make_pair(cur_name, stat)); - } - } - if (ftype == FollowType::ALL || ftype == FollowType::SHAPE) { - this->follow_shape_map[follow_name].emplace_back(std::make_pair(cur_name, stat)); - } -} - -void OpParamTrunk::DfsFollow(OpParamDef& op_param_def, OpDef::PortStat stat) { - if (op_param_def.GetFollowType() >= FollowType::INVALID_TYPE || - op_param_def.GetFollowName() == ge::AscendString("")) { - return; - } - ge::AscendString cur_name = op_param_def.GetParamName(); - ge::AscendString follow_name = op_param_def.GetFollowName(); - FollowType ftype = op_param_def.GetFollowType(); - std::map& flw_mp = this->follow_map; - OpDef::PortFollowInfo& follow_info = flw_mp[cur_name]; - if (flw_mp.find(follow_name) == flw_mp.end()) { - GELOGE(ge::PARAM_INVALID, "PortName %s : FollowPort is Not Exist", cur_name.GetString()); - return; - } - if (cur_name == follow_name && flw_mp[cur_name].port_stat != OpDef::PortStat::INOUT) { - GELOGE(ge::PARAM_INVALID, "PortName %s : FollowPort ParamData is Not Found", cur_name.GetString()); - return; - } - if (cur_name != follow_name) { - std::map ring_check_map; - while (flw_mp.find(follow_name) != flw_mp.end()) { - if (ring_check_map.find(follow_name) != ring_check_map.end()) { - GELOGE(ge::PARAM_INVALID, "Port %s : FollowData not Found", cur_name.GetString()); - return; - } - if (flw_mp[follow_name].follow_port_name == ge::AscendString("")) { - break; - } - if (flw_mp[follow_name].follow_type != ftype) { - GELOGE(ge::PARAM_INVALID, "Port %s : FollowType cannot be changed.", cur_name.GetString()); - return; - } - ring_check_map.emplace(follow_name, 1); - follow_name = flw_mp[follow_name].follow_port_name; - } - } - OpDef::PortFollowInfo& target_follow_info = flw_mp[follow_name]; - if (target_follow_info.port_stat == OpDef::PortStat::OUT) { - GELOGE(ge::PARAM_INVALID, "Port %s : FollowData not Found", cur_name.GetString()); - return; - } - follow_info.follow_port_name = follow_name; - follow_info.follow_type = ftype; - op_param_def.impl_->follow_port_name = follow_name; - OpParamDef& target_param = this->inputs_[target_follow_info.index_in]; - this->ParamFollow(op_param_def, target_param, stat); -} - -void OpParamTrunk::FollowListDataImpl(const OpDef::DfsParam &dfs_param, std::vector &input, - std::vector &output) { - ge::AscendString name; - OpDef::PortStat stat; - uint32_t idx = 0; - auto get_param_ref = [this, &input, &output](const ge::AscendString &port_name, - OpDef::PortStat port_stat) -> OpParamDef & { - if (port_stat == OpDef::PortStat::OUT) { - return output[this->follow_map[port_name].index_out]; - } else { - return input[this->follow_map[port_name].index_in]; - } - }; - for (const auto& param_pair : this->follow_dtypelist) { - std::tie(name, stat) = param_pair; - uint32_t target_index = this->follow_map[this->follow_map[name].follow_port_name].index_in; - OpParamDef& target_param = input[target_index]; - OpParamDef& op_param_def = get_param_ref(name, stat); - op_param_def.impl_->types = target_param.impl_->types; - if (op_param_def.IsSetDtypeForBin()) { - std::vector data_types_for_bin; - for (uint32_t type_idx = 0; type_idx < dfs_param.full_types.size(); ++type_idx) { - idx = dfs_param.full_types[type_idx][target_index].first; - data_types_for_bin.emplace_back(op_param_def.impl_->types_for_bin[idx]); - } - op_param_def.impl_->types_for_bin = data_types_for_bin; - } - } - for (const auto& param_pair : this->follow_formatlist) { - std::tie(name, stat) = param_pair; - uint32_t target_index = this->follow_map[this->follow_map[name].follow_port_name].index_in; - OpParamDef& target_param = input[target_index]; - OpParamDef& op_param_def = get_param_ref(name, stat); - op_param_def.impl_->formats = target_param.impl_->formats; - if (op_param_def.IsSetFormatForBin()) { - std::vector data_formats_for_bin; - for (uint32_t format_idx = 0; format_idx < dfs_param.full_formats.size(); ++format_idx) { - idx = dfs_param.full_formats[format_idx][input.size() + target_index].first; - data_formats_for_bin.emplace_back(op_param_def.impl_->formats_for_bin[idx]); - } - op_param_def.impl_->formats_for_bin = data_formats_for_bin; - } - } -} -std::map OpParamTrunk::GetFollowMap(void) { - return this->follow_map; -} -std::map>> OpParamTrunk::GetShapeMap() { - return this->follow_shape_map; -} -std::map>> OpParamTrunk::GetDtypeMap() { - return this->follow_dtype_map; -} -} // namespace ops diff --git a/register/ops_kernel_builder_registry.cc b/register/ops_kernel_builder_registry.cc deleted file mode 100644 index d3bf9bd891a03bd6fab8f6fd83eda0288fd905de..0000000000000000000000000000000000000000 --- a/register/ops_kernel_builder_registry.cc +++ /dev/null @@ -1,74 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/ops_kernel_builder_registry.h" -#include "graph/debug/ge_log.h" -#include "inc/common/util/sanitizer_options.h" - -namespace ge { -OpsKernelBuilderRegistry::~OpsKernelBuilderRegistry() noexcept { - // disable memory leaks within the scope - DT_ALLOW_LEAKS_GUARD(OpsKernelBuilderRegistry); - for (auto &it : kernel_builders_) { - GELOGW("[Unregister][Destruct] %s was not unregistered", it.first.c_str()); - // to avoid core dump when unregister is not called when so was close - // this is called only when app is shutting down, so no release would be leaked - new (std::nothrow) std::shared_ptr(it.second); - } -} -void OpsKernelBuilderRegistry::Register(const string &lib_name, const OpsKernelBuilderPtr &instance) { - const auto it = kernel_builders_.emplace(lib_name, instance); - if (it.second) { - GELOGI("Register OpsKernelBuilder successfully, kernel lib name = %s", lib_name.c_str()); - } else { - GELOGW("[Register][Check] OpsKernelBuilder already registered. kernel lib name = %s", lib_name.c_str()); - } -} - -void OpsKernelBuilderRegistry::UnregisterAll() { - kernel_builders_.clear(); - GELOGI("All builders are unregistered"); -} - -void OpsKernelBuilderRegistry::Unregister(const string &lib_name) { - (void)kernel_builders_.erase(lib_name); - GELOGI("OpsKernelBuilder of %s is unregistered", lib_name.c_str()); -} - -const std::map &OpsKernelBuilderRegistry::GetAll() const { - return kernel_builders_; -} -OpsKernelBuilderRegistry &OpsKernelBuilderRegistry::GetInstance() { - static OpsKernelBuilderRegistry instance; - return instance; -} - -OpsKernelBuilderRegistrar::OpsKernelBuilderRegistrar(const string &kernel_lib_name, - const CreateFn fn) - : kernel_lib_name_(kernel_lib_name) { - GELOGI("Register kernel lib name = %s", kernel_lib_name.c_str()); - std::shared_ptr builder; - if (fn != nullptr) { - builder.reset(fn()); - if (builder == nullptr) { - GELOGE(INTERNAL_ERROR, "[Create][OpsKernelBuilder]kernel lib name = %s", kernel_lib_name.c_str()); - } - } else { - GELOGE(INTERNAL_ERROR, "[Check][Param:fn]Creator is nullptr, kernel lib name = %s", kernel_lib_name.c_str()); - } - - // May add empty ptr, so that error can be found afterward - OpsKernelBuilderRegistry::GetInstance().Register(kernel_lib_name, builder); -} - -OpsKernelBuilderRegistrar::~OpsKernelBuilderRegistrar() noexcept { - GELOGI("Unregister kernel lib name = %s", kernel_lib_name_.c_str()); - OpsKernelBuilderRegistry::GetInstance().Unregister(kernel_lib_name_); -} -} // namespace ge diff --git a/register/optimization_option_registry.cc b/register/optimization_option_registry.cc deleted file mode 100644 index ecf1ae19ae981edb1893ace3965c76217a3e8bc2..0000000000000000000000000000000000000000 --- a/register/optimization_option_registry.cc +++ /dev/null @@ -1,231 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/optimization_option_registry.h" -#include "graph/debug/ge_log.h" - -namespace { -bool IsPassOptionValid(const std::string& pass_name, const std::map &options) { - // 一级选项必须有且仅有一个,二级选项有且仅有一个或零个 - if (options.count(ge::OoHierarchy::kH1) == 0UL) { - GELOGW("The pass [%s] has no primary switch option", pass_name.c_str()); - return false; - } - - for (const auto &opt : options) { - if (opt.second.empty()) { - GELOGW("The pass [%s] has an empty option, pass option is not registered", - pass_name.c_str()); - return false; - } - if (opt.first >= ge::OoHierarchy::kEnd) { - GELOGW("The hierarchy [%u] of option [%s] is InValid, pass name is [%s]", static_cast(opt.first), - opt.second.c_str(), pass_name.c_str()); - return false; - } - } - return true; -} -} -namespace ge { -OptionRegistry &OptionRegistry::GetInstance() { - static OptionRegistry instance; - return instance; -} - -void OptionRegistry::Register(const OoInfo &option) { - if (registered_opt_table_.count(option.name) > 0UL) { - GELOGW("Repeatedly register option [%s]", option.name.c_str()); - return; - } - (void) registered_opt_table_.emplace(option.name, option); - GELOGD("Add optimization option [%s], OoLevel is [%s]", option.name.c_str(), - OoInfoUtils::GenOoLevelStr(option.levels).c_str()); -} - -const OoInfo *OptionRegistry::FindOptInfo(const std::string &opt_name) const { - const auto iter = registered_opt_table_.find(opt_name); - if (iter != registered_opt_table_.cend()) { - return &iter->second; - } - return nullptr; -} - -std::unordered_map OptionRegistry::GetVisibleOptions(OoEntryPoint entry_point) const { - std::unordered_map visible_options; - for (const auto &opt : registered_opt_table_) { - const auto &option_info = opt.second; - if (OoInfoUtils::IsBitSet(option_info.visibility, static_cast(entry_point))) { - const auto iter = option_info.show_infos.find(entry_point); - if (iter != option_info.show_infos.end()) { - (void) visible_options.emplace(iter->second.show_name, option_info); - } - } - } - return visible_options; -} - -PassOptionRegistry &PassOptionRegistry::GetInstance() { - static PassOptionRegistry instance; - return instance; -} - -void PassOptionRegistry::Register(const std::string &pass_name, - const std::map &option_names) { - if (pass_names_to_options_.count(pass_name) > 0UL) { - GELOGW("Repeatedly register optimization option for Pass [%s]", pass_name.c_str()); - return; - } - auto &opt_array = pass_names_to_options_[pass_name]; - for (const auto &opt : option_names) { - opt_array[static_cast(opt.first)] = opt.second; - GELOGD("Add optimization option [%s] for Pass [%s]", opt.second.c_str(), pass_name.c_str()); - } -} - -graphStatus PassOptionRegistry::FindOptionNamesByPassName(const std::string &pass_name, - std::vector &option_names) const { - const auto iter = pass_names_to_options_.find(pass_name); - if (iter == pass_names_to_options_.end()) { - return GRAPH_FAILED; - } - for (const auto &opt_name : iter->second) { - if (!opt_name.empty()) { - option_names.emplace_back(opt_name); - } - } - return GRAPH_SUCCESS; -} - -OptionRegister::OptionRegister(const OptionRegister &other) : opt_reg_data_(nullptr) { - const auto reg_data_ptr = other.opt_reg_data_.get(); - if (reg_data_ptr == nullptr) { - GELOGW("The opt_reg_data_ is null, failed to register optimization option"); - return; - } - - if (reg_data_ptr->name.empty()) { - GELOGW("The option name is empty, failed to register optimization option"); - return; - } - - if (reg_data_ptr->levels == 0UL) { - GELOGW("The option level is not set or invalid, failed to register optimization option"); - return; - } - - if (reg_data_ptr->hierarchy >= OoHierarchy::kEnd) { - GELOGW("The option hierarchy is not set or invalid, failed to register optimization option"); - return; - } - - OptionRegistry::GetInstance().Register(*reg_data_ptr); -} - -OptionRegister &OptionRegister::SetDefaultValues(const std::map &opt_values) { - if (opt_reg_data_ != nullptr) { - if (opt_reg_data_->default_values.empty()) { - opt_reg_data_->default_values = opt_values; - } - } - return *this; -} - -OptionRegister &OptionRegister::SetOptLevel(const std::vector &levels) { - if (opt_reg_data_ != nullptr) { - if (opt_reg_data_->levels == 0UL) { - opt_reg_data_->levels = OoInfoUtils::GenOptLevelBits(levels); - } - } - return *this; -} - -OptionRegister &OptionRegister::SetVisibility(const std::vector &entry_points) { - if (opt_reg_data_ != nullptr) { - if (opt_reg_data_->visibility == 0UL) { - opt_reg_data_->visibility = OoInfoUtils::GenOptVisibilityBits(entry_points); - } - } - return *this; -} - -OptionRegister &OptionRegister::SetOptValueChecker(OoInfo::ValueChecker opt_checker) { - if (opt_reg_data_ != nullptr) { - if (opt_reg_data_->checker == nullptr) { - opt_reg_data_->checker = opt_checker; - } - } - return *this; -} - -OptionRegister &OptionRegister::SetHelpText(std::string opt_help) { - if (opt_reg_data_ != nullptr) { - if (opt_reg_data_->help_text.empty()) { - opt_reg_data_->help_text = std::move(opt_help); - } - } - return *this; -} - -OptionRegister &OptionRegister::SetShowName(OoEntryPoint entry_point, std::string show_name, ge::OoCategory category) { - if (opt_reg_data_ != nullptr) { - if (opt_reg_data_->show_infos.count(entry_point) == 0UL) { - opt_reg_data_->show_infos.emplace(entry_point, OoShowInfo{category, std::move(show_name)}); - } - } - return *this; -} - -PassOptionRegister::PassOptionRegister(const PassOptionRegister &other) : pass_reg_data_(nullptr) { - const auto pass_reg_data_ptr = other.pass_reg_data_.get(); - if (pass_reg_data_ptr == nullptr) { - GELOGW("The pass_reg_data_ is null, failed to bind optimization option to pass"); - return; - } - if (pass_reg_data_ptr->pass_name.empty()) { - GELOGW("The pass name is empty, failed to bind optimization option for pass"); - return; - } - - if (pass_reg_data_ptr->options.empty()) { - for (const auto level : pass_reg_data_ptr->levels) { - if (level >= OoLevel::kEnd) { - GELOGW("The option level [%u] of pass [%s] is invalid", static_cast(level), - pass_reg_data_ptr->pass_name.c_str()); - return; - } - } - // 功能性 Pass 没有对外开放的 Option, 所以使用 PassName 作为 OptionName - OoInfo default_option{pass_reg_data_ptr->pass_name, OoHierarchy::kH1, - OoInfoUtils::GenOptLevelBits(pass_reg_data_ptr->levels)}; - OptionRegistry::GetInstance().Register(default_option); - PassOptionRegistry::GetInstance().Register(pass_reg_data_ptr->pass_name, {{OoHierarchy::kH1, default_option.name}}); - } else { - if (IsPassOptionValid(pass_reg_data_ptr->pass_name, pass_reg_data_ptr->options)) { - PassOptionRegistry::GetInstance().Register(pass_reg_data_ptr->pass_name, pass_reg_data_ptr->options); - } - } -} - -PassOptionRegister &PassOptionRegister::SetOptLevel(const std::vector &levels) { - if (pass_reg_data_ != nullptr) { - if (pass_reg_data_->levels.empty()) { - pass_reg_data_->levels = levels; - } - } - return *this; -} - -PassOptionRegister &PassOptionRegister::BindSwitchOption(const std::string &opt_name, OoHierarchy hierarchy) { - if (pass_reg_data_ != nullptr) { - (void) pass_reg_data_->options.emplace(hierarchy, opt_name); - } - return *this; -} -} // namespace ge \ No newline at end of file diff --git a/register/pass_option_utils.cc b/register/pass_option_utils.cc deleted file mode 100644 index cc508909337572a7a7faac104a0c37e72098f1cb..0000000000000000000000000000000000000000 --- a/register/pass_option_utils.cc +++ /dev/null @@ -1,94 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ - -#include "register/pass_option_utils.h" -#include "register/optimization_option_registry.h" -#include "ge_common/debug/ge_log.h" -#include "common/checker.h" -#include "graph/ge_local_context.h" - -namespace ge { -graphStatus PassOptionUtils::CheckIsPassEnabled(const std::string &pass_name, bool &is_enabled) { - // 查看是否通过optimization_swtich配置过pass开关,配置过的话working_opt_names_to_value_中有pass_name这个option,且值为"on"或"off" - const auto res = CheckIsPassEnabledByOption(pass_name, is_enabled); - if (res == GRAPH_SUCCESS) { - return GRAPH_SUCCESS; - } - GELOGD("The pass [%s] is not enabled by graph options", pass_name.c_str()); - /** - * 原有逻辑分几种情况: - * 1.PassOptionRegistry::pass_names_to_options_中没有pass关联的option,即pass未被注册,返回非零错误码 - * 2.PassOptionRegistry::pass_names_to_options_中有pass关联的option,但是其关联的option都没有注册,返回非零错误码 - * 3.PassOptionRegistry::pass_names_to_options_中有pass关联的option且有关联的option注册,根据passname到其关联的option中获取开关选项值, - * 1) 若为空或者"true",则is_enabled返回true(需要注册关联的option在注册时设置了该级别的默认值) - * 2) 若为"false",则is_enabled返回false - */ - std::vector opt_names; - const auto ret = PassOptionRegistry::GetInstance().FindOptionNamesByPassName(pass_name, opt_names); - if (ret != SUCCESS) { - // 若Pass未被注册,返回非零错误码,由调用方判断如何处理 - GELOGI("The pass [%s] is not registered", pass_name.c_str()); - return ret; - } - - // 当前最多支持两级开关,opt_names.size() <= 2 - std::vector opt_infos; - for (const auto &opt_name : opt_names) { - const auto info_ptr = OptionRegistry::GetInstance().FindOptInfo(opt_name); - if (info_ptr == nullptr) { - GELOGW("Option [%s] of pass [%s] is not registered", opt_name.c_str(), pass_name.c_str()); - continue; - } - opt_infos.emplace_back(info_ptr); - } - // Pass关联的选项均未注册,说明注册阶段遗漏了选项 - if (opt_infos.empty()) { - GELOGW("the pass [%s] has no registered option", pass_name.c_str()); - return GRAPH_FAILED; - } - - is_enabled = false; - const auto &oo = GetThreadLocalContext().GetOo(); - for (auto it = opt_infos.crbegin(); it != opt_infos.crend(); ++it) { - const auto opt = *it; - std::string opt_value; - if (oo.GetValue(opt->name, opt_value) == GRAPH_SUCCESS) { - if (opt_value.empty() || (opt_value == "true")) { - GELOGD("the pass [%s] is enabled, option [%s] is [%s]", pass_name.c_str(), opt->name.c_str(), - opt_value.c_str()); - is_enabled = true; - return GRAPH_SUCCESS; - } else { - GELOGD("the pass [%s] is disabled, option [%s] is [%s]", pass_name.c_str(), opt->name.c_str(), - opt_value.c_str()); - return GRAPH_SUCCESS; - } - } - } - // OoTable中没有配置该Pass的开关选项,说明不使能 - GELOGD("the pass [%s] is disabled, option is not in working option table", pass_name.c_str()); - return GRAPH_SUCCESS; -} - -graphStatus PassOptionUtils::CheckIsPassEnabledByOption(const std::string &pass_name, bool &is_enabled) { - std::string opt_value; - const auto res = GetThreadLocalContext().GetOo().GetValue(pass_name, opt_value); - if (res != SUCCESS) { - return GRAPH_FAILED; - } - - if (opt_value == "on" || opt_value == "off") { - GELOGD("The pass [%s] is configured by graph options, switch is [%s]", pass_name.c_str(), opt_value.c_str()); - is_enabled = (opt_value == "on"); - return GRAPH_SUCCESS; - } - - return GRAPH_FAILED; -} -} // namespace ge \ No newline at end of file diff --git a/register/prototype_pass_registry.cc b/register/prototype_pass_registry.cc deleted file mode 100644 index fb5e4b18366d8f8901fac55f1251fecb0dad4031..0000000000000000000000000000000000000000 --- a/register/prototype_pass_registry.cc +++ /dev/null @@ -1,100 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "mutex" -#include "register/prototype_pass_registry.h" - -#include "graph/debug/ge_log.h" -#include "graph/types.h" -#include "graph/debug/ge_util.h" - -namespace ge { -class ProtoTypePassRegistry::ProtoTypePassRegistryImpl { - public: - void RegisterProtoTypePass(const std::string &pass_name, ProtoTypePassRegistry::CreateFn create_fn, - const domi::FrameworkType fmk_type) { - const std::lock_guard lock(mu_); - if (std::find(pass_names_.begin(), pass_names_.end(), pass_name) != pass_names_.end()) { - GELOGW("[Register][Check] The prototype pass %s has been registered and will not overwrite the previous one", - pass_name.c_str()); - return; - } - pass_names_.push_back(pass_name); - - const auto iter = create_fns_.find(fmk_type); - if (iter != create_fns_.end()) { - create_fns_[fmk_type].push_back(std::make_pair(pass_name, create_fn)); - GELOGD("Register prototype pass, pass name = %s", pass_name.c_str()); - return; - } - - std::vector> create_fn_vector; - create_fn_vector.push_back(std::make_pair(pass_name, create_fn)); - create_fns_[fmk_type] = create_fn_vector; - GELOGD("Register prototype pass, pass name = %s", pass_name.c_str()); - } - - std::vector> GetCreateFnByType( - domi::FrameworkType fmk_type) { - const std::lock_guard lock(mu_); - const auto iter = create_fns_.find(fmk_type); - if (iter == create_fns_.end()) { - return std::vector>{}; - } - return iter->second; - } - - private: - std::mutex mu_; - std::vector pass_names_; - std::map>> create_fns_; -}; - -ProtoTypePassRegistry::ProtoTypePassRegistry() { - impl_ = ge::ComGraphMakeUnique(); -} - -ProtoTypePassRegistry::~ProtoTypePassRegistry() = default; - -ProtoTypePassRegistry &ProtoTypePassRegistry::GetInstance() { - static ProtoTypePassRegistry instance; - return instance; -} - -void ProtoTypePassRegistry::RegisterProtoTypePass(const char_t *const pass_name, const CreateFn &create_fn, - const domi::FrameworkType fmk_type) { - if (impl_ == nullptr) { - GELOGE(MEMALLOC_FAILED, "ProtoTypePassRegistry is not properly initialized."); - return; - } - std::string str_pass_name; - if (pass_name != nullptr) { - str_pass_name = pass_name; - } - impl_->RegisterProtoTypePass(str_pass_name, create_fn, fmk_type); -} - -std::vector> ProtoTypePassRegistry::GetCreateFnByType( - const domi::FrameworkType fmk_type) const { - if (impl_ == nullptr) { - GELOGE(MEMALLOC_FAILED, "ProtoTypePassRegistry is not properly initialized."); - return std::vector>{}; - } - return impl_->GetCreateFnByType(fmk_type); -} - -ProtoTypePassRegistrar::ProtoTypePassRegistrar(const char_t *const pass_name, ProtoTypeBasePass *(*const create_fn)(), - const domi::FrameworkType fmk_type) { - if (pass_name == nullptr) { - GELOGE(PARAM_INVALID, "Failed to register ProtoType pass, pass name is null."); - return; - } - ProtoTypePassRegistry::GetInstance().RegisterProtoTypePass(pass_name, create_fn, fmk_type); -} -} // namespace ge diff --git a/register/register.cpp b/register/register.cpp deleted file mode 100644 index 85bc1327122c251bb61d6c0088d58abe414fb095..0000000000000000000000000000000000000000 --- a/register/register.cpp +++ /dev/null @@ -1,1192 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "external/register/register.h" -#include -#include "graph/debug/ge_util.h" -#include "graph/debug/ge_op_types.h" -#include "common/ge_common/debug/ge_log.h" -#include "graph/debug/ge_log.h" -#include "graph/debug/ge_util.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/type_utils.h" -#include "proto/tensorflow/attr_value.pb.h" -#include "proto/tensorflow/node_def.pb.h" -#include "register/auto_mapping_util.h" -#include "register/op_registry.h" -#include "register/register_utils.h" -#include "graph/graph.h" -#include "graph/debug/ge_util.h" -#include "graph/ascend_limits.h" -#include "graph/def_types.h" -#include "graph/utils/graph_utils_ex.h" - -namespace domi { -/*lint -e1073*/ -namespace { -const std::string kDefaultFormat = "ND"; -const std::string kSrcFormat = "src_format"; -const std::string kDstFormat = "dst_format"; -const std::string kDataFormat = "data_format"; -const std::string kTfInputDesc = "input_tensor_desc"; -const std::string kTfOutputDesc = "output_tensor_desc"; -const std::string kFuncNameKey = "name"; - -struct DynamicInfo { - DynamicInfo(const DynamicType dynamic_type, const uint32_t index, const uint32_t num) - : type(dynamic_type), - inset_index(index), - tensor_num(num) {} - explicit DynamicInfo() : DynamicInfo(kInvalid, 0U, 0U) {} - - DynamicType GetType() const {return type;} - uint32_t GetInsetIndex() const {return inset_index;} - uint32_t GetTensorNum() const {return tensor_num;} - void SetInsetIndex(const uint32_t insetIndex) {inset_index = insetIndex;} -private: - DynamicType type; - uint32_t inset_index; - uint32_t tensor_num; -}; - -std::set GetSubgraphAttrNames(const ge::Operator &op) { - if (op.GetSubgraphNamesCount() == 0U) { - return std::set(); - } - std::vector subgraph_names; - (void) op.GetSubgraphNames(subgraph_names); - std::vector subgraph_name_strings; - for (const auto &subgraph_name : subgraph_names) { - subgraph_name_strings.emplace_back(subgraph_name.GetString()); - } - return std::set(subgraph_name_strings.begin(), subgraph_name_strings.end()); -} - -/// there are two forms to represent functions in TF: -/// case 1(subgraph of a `if` node) normal subgraph: -/// attr { -/// key: "else_branch" -/// value { -/// func { -/// name: "cond_false_9" -/// } -/// } -/// } - -/// case 2(subgraph of a `case` node) dynamic subgraph: -/// attr { -/// key: "branches" -/// value { -/// list { -/// func { -/// name: "two_J6Sc96RZs5g" -/// } -/// func { -/// name: "three_3pYv7KFNs2M" -/// } -/// func { -/// name: "four_MdtG6T4LHxA" -/// } -/// } -/// } -/// } -/// \param func_attr -/// \param op_desc -/// \return -Status AutoMappingFunction(const std::pair &func_attr, - std::shared_ptr &op_desc) { - if (func_attr.second.value_case() == domi::tensorflow::AttrValue::kFunc) { - const auto &func_signature = func_attr.second.func().name(); - if (ge::OpDescUtils::SetSubgraphInstanceName(func_attr.first, func_signature, op_desc) != ge::GRAPH_SUCCESS) { - GE_LOGE("Failed to set subgraph instance %s for node %s type %s, instance name %s", - func_attr.first.c_str(), op_desc->GetName().c_str(), - op_desc->GetType().c_str(), func_signature.c_str()); - return FAILED; - } - } else if (func_attr.second.value_case() == domi::tensorflow::AttrValue::kList) { - uint32_t i = 0U; - for (auto &dyn_func_attr : func_attr.second.list().func()) { - const auto &func_signature = dyn_func_attr.name(); - const auto subgraph_name = func_attr.first + std::to_string(i++); - auto ret = op_desc->AddSubgraphName(subgraph_name); - if (ret != ge::GRAPH_SUCCESS) { - GE_LOGE("Failed to add subgraph name %s to node %s type %s", - subgraph_name.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str()); - return FAILED; - } - ret = ge::OpDescUtils::SetSubgraphInstanceName(subgraph_name, func_signature, op_desc); - if (ret != ge::GRAPH_SUCCESS) { - GE_LOGE("Failed to set dynamic subgraph instance %s for node %s type %s, instance name %s", - func_attr.first.c_str(), op_desc->GetName().c_str(), - op_desc->GetType().c_str(), func_signature.c_str()); - return FAILED; - } - } - } else { - GE_LOGE("Unexpected attr value type %d for func", static_cast(func_attr.second.value_case())); - return FAILED; - } - return SUCCESS; -} - -Status CheckDynamicInfo(const vector &dynamic_name_attr_value) { - for (const auto &dynamic_info : dynamic_name_attr_value) { - if ((dynamic_info.port_name_len == 0) || (dynamic_info.port_name_len > kMaxNameLength) || - (dynamic_info.attr_name_len == 0) || (dynamic_info.attr_name_len > kMaxNameLength)) { - GELOGE(PARAM_INVALID, "[Check][Param]port_name_len:%ld, attr_name_len:%ld", - dynamic_info.port_name_len, dynamic_info.attr_name_len); - return PARAM_INVALID; - } - - const int64_t port_name_len = static_cast(strnlen(dynamic_info.port_name, ge::kMaxNameLen)); - if ((dynamic_info.port_name == nullptr) || (port_name_len != dynamic_info.port_name_len)) { - GELOGE(PARAM_INVALID, "[Check][Param]port_name:%s, port_name_len:%ld", - dynamic_info.port_name, dynamic_info.port_name_len); - return PARAM_INVALID; - } - - const int64_t attr_name_len = static_cast(strnlen(dynamic_info.attr_name, ge::kMaxNameLen)); - if ((dynamic_info.attr_name == nullptr) || (attr_name_len != dynamic_info.attr_name_len)) { - GELOGE(PARAM_INVALID, "[Check][Param]attr_name:%s, attr_name_len:%ld", - dynamic_info.attr_name, dynamic_info.attr_name_len); - return PARAM_INVALID; - } - } - - return SUCCESS; -} - -Status GetDynamicTensorNum(const std::shared_ptr &op_desc, const string &attr_name, uint32_t &tensor_num) { - GE_CHECK_NOTNULL(op_desc); - - ge::GeAttrValue attr_value; - const ge::graphStatus ret = op_desc->GetAttr(attr_name, attr_value); - if (ret != SUCCESS) { - GELOGE(FAILED, "[Get][Attr:%s]op name:%s", attr_name.c_str(), op_desc->GetName().c_str()); - return FAILED; - } - - const ge::GeAttrValue::ValueType value_type = attr_value.GetValueType(); - switch (value_type) { - case ge::GeAttrValue::VT_LIST_DATA_TYPE: { - vector vec_d; - (void)ge::AttrUtils::GetListDataType(op_desc, attr_name, vec_d); - tensor_num = static_cast(vec_d.size()); - break; - } - case ge::GeAttrValue::VT_INT: { - (void)ge::AttrUtils::GetInt(op_desc, attr_name, tensor_num); - break; - } - default: - GELOGI("Default other value type: %d", static_cast(value_type)); - break; - } - - return SUCCESS; -} - -Status GetDynamicAttrNum(const std::shared_ptr &op_desc, - const vector &dynamic_name_attrs, std::map &port_dynamic_info) { - GE_CHECK_NOTNULL(op_desc); - for (const auto &dynamic_name_attr : dynamic_name_attrs) { - const std::string attr_name = dynamic_name_attr.attr_name; - uint32_t dynamic_tensor_num = 0U; - if (op_desc->HasAttr(attr_name)) { - if (GetDynamicTensorNum(op_desc, attr_name, dynamic_tensor_num) != SUCCESS) { - GELOGE(FAILED, "[Get][DynamicTensorNum]op_name:%s, attr_name:%s", - op_desc->GetName().c_str(), attr_name.c_str()); - return FAILED; - } - } else { - GELOGW("[UpdateDynamic][GetAttr] Dynamic attr %s does not exist in op %s", attr_name.c_str(), - op_desc->GetName().c_str()); - continue; - } - GELOGI("In Op %s dynamic attr [%s] is exist, tensor num: %u.", op_desc->GetName().c_str(), attr_name.c_str(), - dynamic_tensor_num); - port_dynamic_info[dynamic_name_attr.port_name] = DynamicInfo(dynamic_name_attr.type, 0U, dynamic_tensor_num); - } - return SUCCESS; -} - -Status UpdateDynamicInputOutPutIndex(const std::shared_ptr &op_desc, - const vector &dynamic_name_attrs, std::map &port_dynamic_info) { - GE_CHECK_NOTNULL(op_desc); - if (GetDynamicAttrNum(op_desc, dynamic_name_attrs, port_dynamic_info) != SUCCESS) { - GELOGE(FAILED, "[Get][DynamicAttrNum] fail, op_name:%s", op_desc->GetName().c_str()); - return FAILED; - } - - const vector register_input_names = op_desc->GetRegisterInputName(); - uint32_t input_index = 0U; - uint32_t input_increment = 0U; - for (const auto &input_name : register_input_names) { - const auto input_iter = port_dynamic_info.find(input_name); - if (input_iter != port_dynamic_info.end()) { - port_dynamic_info[input_name].SetInsetIndex(input_index + input_increment); - const uint32_t tensor_num = port_dynamic_info[input_name].GetTensorNum(); - if (tensor_num == 0U) { - (void)port_dynamic_info.erase(input_iter); - continue; - } - input_increment += (tensor_num > 0U) ? (tensor_num - 1U) : 0U; - GELOGI("Dynamic input name[%s] insert index: %u, tensor num: %u, op proto index: %u", input_name.c_str(), - port_dynamic_info[input_name].GetInsetIndex(), tensor_num, input_index); - } - input_index++; - } - const vector register_output_names = op_desc->GetRegisterOutputName(); - uint32_t output_index = 0U; - uint32_t out_increment = 0U; - for (const auto &output_name : register_output_names) { - const auto output_iter = port_dynamic_info.find(output_name); - if (output_iter != port_dynamic_info.end()) { - port_dynamic_info[output_name].SetInsetIndex(output_index + out_increment); - const uint32_t tensor_num = port_dynamic_info[output_name].GetTensorNum(); - if (tensor_num == 0U) { - (void)port_dynamic_info.erase(output_iter); - continue; - } - out_increment += (tensor_num > 0U) ? (tensor_num - 1U) : 0U; - GELOGI("Dynamic output name[%s] insert index: %u, tensor num: %u, op proto index: %u", output_name.c_str(), - port_dynamic_info[output_name].GetInsetIndex(), tensor_num, output_index); - } - output_index++; - } - return SUCCESS; -} - -Status SetOpdescInputOutputFormat(std::shared_ptr &op_desc) { - GE_CHECK_NOTNULL(op_desc); - - const auto inputDescsPtr = op_desc->GetAllInputsDescPtr(); - const auto outputDescsPtr = op_desc->GetAllOutputsDescPtr(); - - string src_data_format = kDefaultFormat; - string dst_data_format = kDefaultFormat; - if (op_desc->HasAttr(kSrcFormat)) { - (void)ge::AttrUtils::GetStr(op_desc, kSrcFormat, src_data_format); - } - if (op_desc->HasAttr(kDstFormat)) { - (void)ge::AttrUtils::GetStr(op_desc, kDstFormat, dst_data_format); - } - if (op_desc->HasAttr(kDataFormat)) { - (void)ge::AttrUtils::GetStr(op_desc, kDataFormat, src_data_format); - dst_data_format = src_data_format; - } - ge::Format format = ge::TypeUtils::DataFormatToFormat(src_data_format); - for (const auto &inputDescPtr : inputDescsPtr) { - inputDescPtr->SetOriginFormat(format); - inputDescPtr->SetFormat(format); - } - format = ge::TypeUtils::DataFormatToFormat(dst_data_format); - for (const auto &outputDescPtr : outputDescsPtr) { - outputDescPtr->SetOriginFormat(format); - outputDescPtr->SetFormat(format); - } - return SUCCESS; -} -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status AutoMappingFnDynamic( - const google::protobuf::Message *op_src, ge::Operator &op, - std::map> dynamic_name_attr_value, - int32_t in_pos, int32_t out_pos) { - // 1. automapping for parser - const std::shared_ptr op_desc = ge::OpDescUtils::GetOpDescFromOperator(op); - GE_CHECK_NOTNULL(op_desc); - GE_CHECK_NOTNULL(op_src); - const Status ret = OperatorAutoMapping(op_src, op); - if (ret != SUCCESS) { - GE_LOGE("Op: %s call auto mapping function failed.", op_desc->GetName().c_str()); - return FAILED; - } - - GELOGI("op[%s] call auto mapping function success.", op_desc->GetName().c_str()); - - if (dynamic_name_attr_value.size() > 2U) { // attr value size should be less than 2 - GE_LOGE("attr set size [%zu] should be less than 2.", dynamic_name_attr_value.size()); - return FAILED; - } - - // add dynamic input and output - const domi::tensorflow::NodeDef *const node = ge::PtrToPtr(op_src); - for (const auto &it : dynamic_name_attr_value) { - const std::string flag = it.first; - const std::pair name_value = it.second; - const std::string dynamic_name = name_value.first; - const std::string attr_name = name_value.second; - - tensorflow::AttrValue attr_num; - int32_t dynamic_tensor_num = 0; - if (!(ge::AutoMappingUtil::FindAttrValue(node, attr_name, attr_num))) { - GELOGW("[AutoMappingFn][GetAttr] Dynamic attr %s in node %s does not exist.", attr_name.c_str(), node->name().c_str()); - } - - dynamic_tensor_num = (attr_num.has_list()) ? attr_num.list().type_size() : static_cast(attr_num.i()); - if (dynamic_tensor_num <= 0) { - GELOGW("[AutoMappingFn][Check] Dynamic num %d in node %s is less than 0.", dynamic_tensor_num, - node->name().c_str()); - continue; - } - - GELOGI("In NodeDef %s dynamic attr [%s] is exist: %d.", node->name().c_str(), attr_name.c_str(), - dynamic_tensor_num); - - if (flag == "in") { - const bool is_pushback = (in_pos == -1); - (void)op_desc->AddDynamicInputDesc(dynamic_name, static_cast(dynamic_tensor_num), is_pushback); - (void)ge::AttrUtils::SetInt(op_desc, DYNAMIC_INPUT_TD_NUM(dynamic_name), dynamic_tensor_num); - GELOGI("In NodeDef %s add dynamic input[%d]", node->name().c_str(), dynamic_tensor_num); - } else if (flag == "out") { - const bool is_pushback = (out_pos == -1); - (void)op_desc->AddDynamicOutputDesc(dynamic_name, static_cast(dynamic_tensor_num), is_pushback); - (void)ge::AttrUtils::SetInt(op_desc, DYNAMIC_OUTPUT_TD_NUM(dynamic_name), dynamic_tensor_num); - GELOGI("In NodeDef %s add dynamic output[%d]", node->name().c_str(), dynamic_tensor_num); - } else { - // no operation - } - } - return SUCCESS; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status AutoMappingByOpFnDynamic(const ge::Operator &op_src, - ge::Operator &op, const vector &dynamic_name_attr_value) { - // 1. auto mapping for parser - const std::shared_ptr op_desc_dst = ge::OpDescUtils::GetOpDescFromOperator(op); - GE_CHECK_NOTNULL(op_desc_dst); - - const Status ret = AutoMappingByOpFn(op_src, op); - if (ret != SUCCESS) { - GELOGE(ret, "[Mapping][Operator]op_name:%s", op_desc_dst->GetName().c_str()); - return FAILED; - } - - GELOGI("Op[%s] call auto mapping function success.", op_desc_dst->GetName().c_str()); - // 2. check dynamic input output info; - if (CheckDynamicInfo(dynamic_name_attr_value) != SUCCESS) { - GELOGE(FAILED, "[Check][DynamicInfo]op_name:%s", op_desc_dst->GetName().c_str()); - return FAILED; - } - // 3. update dynamic input output index by tensor num; - std::map port_dynamic_info; - if (UpdateDynamicInputOutPutIndex(op_desc_dst, dynamic_name_attr_value, port_dynamic_info) != SUCCESS) { - GELOGE(FAILED, "[Update][DynamicIndex]op_name:%s", op_desc_dst->GetName().c_str()); - return FAILED; - } - // 4. sort map by port name insert index. - vector> port_dynamic_info_vec(port_dynamic_info.cbegin(), port_dynamic_info.cend()); - std::sort(port_dynamic_info_vec.begin(), port_dynamic_info_vec.end(), - [](const pair &p1, const pair &p2) - { return p1.second.GetInsetIndex() < p2.second.GetInsetIndex(); }); - // 5. add dynamic input and output - for (const auto &dynamic_info : port_dynamic_info_vec) { - const string port_name = dynamic_info.first; - const DynamicType dynamic_type = dynamic_info.second.GetType(); - const uint32_t insert_index = dynamic_info.second.GetInsetIndex(); - const uint32_t tensor_num = dynamic_info.second.GetTensorNum(); - - if (dynamic_type == kInput) { - (void)op_desc_dst->AddInputDescMiddle(port_name, tensor_num, static_cast(insert_index)); - (void)ge::AttrUtils::SetInt(op_desc_dst, DYNAMIC_INPUT_TD_NUM(port_name), static_cast(tensor_num)); - GELOGI("Op[%s] add dynamic input[%u]", op_desc_dst->GetName().c_str(), tensor_num); - } else if (dynamic_type == kOutput) { - (void)op_desc_dst->AddOutputDescMiddle(port_name, tensor_num, static_cast(insert_index)); - (void)ge::AttrUtils::SetInt(op_desc_dst, DYNAMIC_OUTPUT_TD_NUM(port_name), - static_cast(tensor_num)); - GELOGI("Op[%s] add dynamic output[%u]", op_desc_dst->GetName().c_str(), tensor_num); - } else { - GELOGW("Do not add input or output desc with dynamic type :[%d].", static_cast(dynamic_type)); - continue; - } - } - - return SUCCESS; -} - -// Convert tensorflow property to ge property -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OperatorAutoMapping(const Message *op_src, ge::Operator &op) { - std::shared_ptr op_dst = ge::OpDescUtils::GetOpDescFromOperator(op); - // Analysis of tensorflow operator parameters based on key value - GE_CHECK_NOTNULL(op_src); - GE_CHECK_NOTNULL(op_dst); - - const auto subgraph_attr_names = GetSubgraphAttrNames(op); - const domi::tensorflow::NodeDef *const node_src = ge::PtrToPtr(op_src); - GE_CHECK_NOTNULL(node_src); - op_dst->SetName(node_src->name()); - for (const auto &attr_pair : node_src->attr()) { - if ((attr_pair.first == kTfInputDesc) || (attr_pair.first == kTfOutputDesc)) { - continue; - } - if (subgraph_attr_names.count(attr_pair.first) > 0U) { - const Status ret = AutoMappingFunction(attr_pair, op_dst); - if (ret != SUCCESS) { - return ret; - } - } else { - ge::AutoMappingUtil::ConvertValue(attr_pair.first, attr_pair.second, op_dst); - } - } - - const Status ret = SetOpdescInputOutputFormat(op_dst); - if (ret != SUCCESS) { - GELOGE(FAILED, "[Set][Format]Set op[%s] desc input output format failed.", op_dst->GetName().c_str()); - return FAILED; - } - return SUCCESS; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status AutoMappingFn(const Message *op_src, ge::Operator &op) { - return OperatorAutoMapping(op_src, op); -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status AutoMappingByOpFn(const ge::Operator &op_src, - ge::Operator &op) { - const std::shared_ptr op_desc_src = ge::OpDescUtils::GetOpDescFromOperator(op_src); - std::shared_ptr op_desc_dst = ge::OpDescUtils::GetOpDescFromOperator(op); - GE_CHECK_NOTNULL(op_desc_src); - GE_CHECK_NOTNULL(op_desc_dst); - - op_desc_dst->SetName(op_desc_src->GetName()); - const auto subgraph_name_indexs = op_desc_src->GetSubgraphNameIndexes(); - for (const auto &subgraph_name_index : subgraph_name_indexs) { - const auto ret = op_desc_dst->AddSubgraphName(subgraph_name_index.first); - if (ret != ge::GRAPH_SUCCESS) { - GELOGW("[AutoMappingFn][Check] %s subgraph of node %s, type %s already exist.", - subgraph_name_index.first.c_str(), op_desc_dst->GetName().c_str(), op_desc_dst->GetType().c_str()); - } - } - - const auto subgraph_instance_names = op_desc_src->GetSubgraphInstanceNames(); - uint32_t index = 0U; - for (const auto &subgraph_instance_name : subgraph_instance_names) { - const auto ret = op_desc_dst->SetSubgraphInstanceName(index, subgraph_instance_name); - if (ret != ge::GRAPH_SUCCESS) { - GELOGE(FAILED, "[Add][SubGraphInstance] subgraph_name: %s, index: %u, for node %s type %s.", - subgraph_instance_name.c_str(), index, op_desc_dst->GetType().c_str(), op_desc_dst->GetName().c_str()); - return FAILED; - } - index++; - } - - for (const auto &iter : op_desc_src->GetAllAttrs()) { - (void) op_desc_dst->SetAttr(iter.first, iter.second); - } - - const Status ret = SetOpdescInputOutputFormat(op_desc_dst); - if (ret != SUCCESS) { - GELOGE(FAILED, "[Set][Format]op_name:%s", op_desc_dst->GetName().c_str()); - return FAILED; - } - return SUCCESS; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY -Status AutoMappingSubgraphIndex(const ge::Graph &graph, - const std::function &input, - const std::function &output) { - GE_CHECK_NOTNULL(input); - GE_CHECK_NOTNULL(output); - return AutoMappingSubgraphIndex(graph, - [&input](const int32_t i, int32_t &o) -> Status { - o = input(i); - return SUCCESS; - }, - [&output](const int32_t i, int32_t &o) -> Status { - o = output(i); - return SUCCESS; - }); -} - -namespace { - std::vector> FindNodesByType(const ge::ComputeGraphPtr &graph, const std::string &type) { - std::vector> nodes; - for (const auto &node : graph->GetDirectNode()) { - GELOGI("Find node %s, node type is %s.", type.c_str(), node->GetOpDesc()->GetType().c_str()); - if (node->GetOpDesc()->GetType() == type) { - nodes.push_back(node); - continue; - } - if (node->GetOpDesc()->GetType() == "FrameworkOp") { - std::string original_type; - if (!ge::AttrUtils::GetStr(node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type)) { - // if there is no ref index on the TensorDesc, it means the output data will be ignored outer. - continue; - } - if (original_type == type) { - nodes.push_back(node); - } - } - } - return nodes; - } -} - -static Status AutoMappingSubgraphOutput(const ge::ComputeGraphPtr &graph, - const std::function &output) { - GE_CHECK_NOTNULL(graph); - GE_CHECK_NOTNULL(output); - const auto &output_node = graph->FindFirstNodeMatchType(ge::NETOUTPUT); - if (output_node == nullptr) { // Graph from parser no NetOutput. - return SUCCESS; - } - - const auto &op_desc = output_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - - for (size_t index = 0U; index < op_desc->GetInputsSize(); ++index) { - int32_t parent_index = -1; - const auto ret = output(static_cast(index), parent_index); - if (ret != SUCCESS) { - GELOGE(FAILED, "[Get][ParentIndex:output]net output index %ld, error code %u", index, ret); - return FAILED; - } - - GELOGI("Generate subgraph output map for subgraph %s, index %ld, parent node index %d", - graph->GetName().c_str(), index, parent_index); - if (parent_index == -1) { - continue; - } - - const ge::GeTensorDescPtr tensor = op_desc->MutableInputDesc(static_cast(index)); - GE_CHECK_NOTNULL(tensor); - if (!ge::AttrUtils::SetInt(tensor, ge::ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { - GELOGE(FAILED, "[Set][Attr:%s]Failed for graph %s, op_name:%s, parent_index:%d", - ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), graph->GetName().c_str(), - op_desc->GetName().c_str(), parent_index); - return FAILED; - } - } - - return SUCCESS; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY -static Status AutoMappingSubgraphIndexByDataNode(const ge::ComputeGraphPtr &compute_graph, - const std::function &input) { - const auto nodes = FindNodesByType(compute_graph, "Data"); - for (size_t i = 0U; i < nodes.size(); ++i) { - int32_t parent_index = -1; - int32_t index = -1; - if (!ge::AttrUtils::GetInt(nodes[i]->GetOpDesc(), "index", index)) { - GELOGE(FAILED, "[Get][Attr:index]data_index:%zu, op_name:%s", i, nodes[i]->GetOpDesc()->GetName().c_str()); - return FAILED; - } - GELOGI("Get index %d from data[%zu]", index, i); - const auto ret = input(index, parent_index); - if (ret != SUCCESS) { - GELOGE(FAILED, "[Get][ParentIndex:input]data index %zu, error code %u", i, ret); - return FAILED; - } - if (!ge::AttrUtils::SetInt(nodes[i]->GetOpDesc(), ge::ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { - GELOGE(FAILED, "[Set][Attr:%s]data_index:%zu, op_name:%s, ", - ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), i, nodes[i]->GetName().c_str()); - return FAILED; - } - GELOGI("Generate subgraph input map for subgraph %s, data index %zu, parent node index %d", - compute_graph->GetName().c_str(), i, parent_index); - } - return SUCCESS; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY -Status AutoMappingSubgraphIndex(const ge::Graph &graph, - const std::function &input, - const std::function &output) { - GE_CHECK_NOTNULL(input); - GE_CHECK_NOTNULL(output); - const auto compute_graph = ge::GraphUtilsEx::GetComputeGraph(graph); - GE_CHECK_NOTNULL(compute_graph); - ge::AscendString graph_name; - (void) graph.GetName(graph_name); - auto ret = AutoMappingSubgraphIndexByDataNode(compute_graph, input); - if (ret != SUCCESS) { - GELOGE(ret, "[Mapping][Index] auto mapping graph:%s input index failed,", graph_name.GetString()); - return ret; - } - - const auto nodes = FindNodesByType(compute_graph, "_Retval"); - for (auto &retval : nodes) { - int64_t index = -1; - if (!ge::AttrUtils::GetInt(retval->GetOpDesc(), "retval_index", index)) { - GELOGE(FAILED, "[Get][Attr:retval_index]retval index %ld, op_name:%s", - index, retval->GetOpDesc()->GetName().c_str()); - return FAILED; - } - int32_t parent_index = -1; - ret = output(static_cast(index), parent_index); - if (ret != SUCCESS) { - GELOGE(FAILED, "[Get][ParentIndex:output]retval index %ld, error code %u", index, ret); - return FAILED; - } - if (!ge::AttrUtils::SetInt(retval->GetOpDesc(), ge::ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { - GELOGE(FAILED, "[Set][Attr:%s]op_name:%s, parent_index:%d", - ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), retval->GetName().c_str(), parent_index); - return FAILED; - } - GELOGI("Generate subgraph output map for subgraph %s, retval index %ld, parent node index %d", - graph_name.GetString(), index, parent_index); - } - - return nodes.empty() ? AutoMappingSubgraphOutput(compute_graph, output) : SUCCESS; -} - -class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FrameworkRegistryImpl { - public: - void AddAutoMappingSubgraphIOIndexFunc(const domi::FrameworkType framework, AutoMappingSubgraphIOIndexFunc fun); - AutoMappingSubgraphIOIndexFunc GetAutoMappingSubgraphIOIndexFunc(const domi::FrameworkType framework); - private: - std::map fmk_type_to_auto_mapping_subgraph_index_fun_; -}; - -void FrameworkRegistryImpl::AddAutoMappingSubgraphIOIndexFunc( - const domi::FrameworkType framework, AutoMappingSubgraphIOIndexFunc fun) { - GELOGD("Regitser auto mapping function: framework type:%d.", framework); - fmk_type_to_auto_mapping_subgraph_index_fun_[framework] = std::move(fun); -} - -AutoMappingSubgraphIOIndexFunc FrameworkRegistryImpl::GetAutoMappingSubgraphIOIndexFunc( - const domi::FrameworkType framework) { - const auto itr = fmk_type_to_auto_mapping_subgraph_index_fun_.find(framework); - if (itr != fmk_type_to_auto_mapping_subgraph_index_fun_.end()) { - return itr->second; - } - return nullptr; -} - -FrameworkRegistry::FrameworkRegistry() { - impl_ = ge::ComGraphMakeUnique(); - if (impl_ == nullptr) { - GELOGW("[Check][Param] make impl failed"); - } -} - -FrameworkRegistry::~FrameworkRegistry() = default; - -FrameworkRegistry& FrameworkRegistry::Instance() { - static FrameworkRegistry instance; - return instance; -} - -void FrameworkRegistry::AddAutoMappingSubgraphIOIndexFunc( - domi::FrameworkType framework, AutoMappingSubgraphIOIndexFunc fun) { - if (impl_ != nullptr) { - impl_->AddAutoMappingSubgraphIOIndexFunc(framework, fun); - } -} - -AutoMappingSubgraphIOIndexFunc FrameworkRegistry::GetAutoMappingSubgraphIOIndexFunc( - domi::FrameworkType framework) { - if (impl_ != nullptr) { - return impl_->GetAutoMappingSubgraphIOIndexFunc(framework); - } - return nullptr; -} - -AutoMappingSubgraphIOIndexFuncRegister::AutoMappingSubgraphIOIndexFuncRegister( - domi::FrameworkType framework, AutoMappingSubgraphIOIndexFunc fun) { - FrameworkRegistry::Instance().AddAutoMappingSubgraphIOIndexFunc(framework, fun); -} - -OpReceiver::OpReceiver(OpRegistrationData ®_data) { OpRegistry::Instance()->registrationDatas.push_back(reg_data); } - -class OpRegistrationDataImpl { - public: - OpRegistrationDataImpl() = default; - ~OpRegistrationDataImpl() = default; - explicit OpRegistrationDataImpl(const std::string &om_optype); -private: - friend class OpRegistrationData; - friend class OpRegistry; - domi::FrameworkType fmk_type_; - std::set ori_optype_set_; // OP type in the original model, there may be multiple - std::string om_optype_; // OP type in OM model - domi::ImplyType imply_type_; // execution type - ParseParamFunc parseParamFn_; // parseParam function - ParseParamByOpFunc parse_param_by_op_fn_; // parse param by op function - FusionParseParamFunc fusionParseParamFn_; // fusion parseParam function - FusionParseParamByOpFunc fusion_parse_param_by_op_fn_; // fusion parseParam by op function - ParseSubgraphFunc parse_subgraph_post_fn_; // a function called after the subgraph was generated - ParseSubgraphFuncV2 parse_subgraph_post_fn_v2_; // a function called after the subgraph was generated - std::vector remove_input_configure_vec_; - ParseOpToGraphFunc parse_op_to_graph_fn_; -}; - -OpRegistrationDataImpl::OpRegistrationDataImpl(const std::string &om_optype) - : fmk_type_(FRAMEWORK_RESERVED), - om_optype_(om_optype), - imply_type_(domi::ImplyType::BUILDIN), - parseParamFn_(nullptr), - parse_param_by_op_fn_(nullptr), - fusionParseParamFn_(nullptr), - fusion_parse_param_by_op_fn_(nullptr), - parse_subgraph_post_fn_(nullptr), - parse_subgraph_post_fn_v2_(nullptr), - parse_op_to_graph_fn_(nullptr) {} - -OpRegistrationData::~OpRegistrationData() = default; - -OpRegistrationData::OpRegistrationData(const std::string &om_optype) { - impl_ = ge::ComGraphMakeShared(om_optype); - if (impl_ == nullptr) { - GELOGW("[Check][Param] make impl failed!"); - } -} - -OpRegistrationData::OpRegistrationData(const char_t *om_optype) { - std::string op_type; - if (om_optype != nullptr) { - op_type = om_optype; - } - impl_ = ge::ComGraphMakeShared(op_type); - if (impl_ == nullptr) { - GELOGW("[Check][Param] make impl failed!"); - } -} - -std::string OpRegistrationData::GetOmOptype() const { - if (impl_ != nullptr) { - return impl_->om_optype_; - } - return ""; -} - -Status OpRegistrationData::GetOmOptype(ge::AscendString &om_op_type) const { - if (impl_ != nullptr) { - om_op_type = ge::AscendString(impl_->om_optype_.c_str()); - } - return SUCCESS; -} - -OpRegistrationData &OpRegistrationData::FrameworkType(const domi::FrameworkType &fmk_type) { - if (impl_ != nullptr) { - impl_->fmk_type_ = fmk_type; - } - return *this; -} - -domi::FrameworkType OpRegistrationData::GetFrameworkType() const { - if (impl_ != nullptr) { - return impl_->fmk_type_; - } - return FRAMEWORK_RESERVED; -} - -OpRegistrationData &OpRegistrationData::OriginOpType(const std::initializer_list &ori_optype_list) { - if (impl_ != nullptr) { - for (const auto &ori_optype : ori_optype_list) { - (void)impl_->ori_optype_set_.insert(ori_optype); - } - } - return *this; -} - -OpRegistrationData &OpRegistrationData::OriginOpType(const std::vector &ori_op_type_list) { - if (impl_ != nullptr) { - for (auto &ori_op_type : ori_op_type_list) { - std::string tmp_ori_op_type; - if (ori_op_type.GetString() != nullptr) { - tmp_ori_op_type = ori_op_type.GetString(); - } - (void)impl_->ori_optype_set_.insert(tmp_ori_op_type); - } - } - return *this; -} - -OpRegistrationData &OpRegistrationData::OriginOpType(const std::string &ori_optype) { - if (impl_ != nullptr) { - (void)impl_->ori_optype_set_.insert(ori_optype); - } - return *this; -} - -OpRegistrationData &OpRegistrationData::OriginOpType(const char_t *ori_op_type) { - if (impl_ != nullptr) { - std::string tmp_ori_op_type; - if (ori_op_type != nullptr) { - tmp_ori_op_type = ori_op_type; - } - (void)impl_->ori_optype_set_.insert(tmp_ori_op_type); - } - return *this; -} - -std::set OpRegistrationData::GetOriginOpTypeSet() const { - const std::set ori_optype_set; - if (impl_ != nullptr) { - return impl_->ori_optype_set_; - } - return ori_optype_set; -} - -Status OpRegistrationData::GetOriginOpTypeSet(std::set &ori_op_type) const { - std::set ori_op_type_set; - if (impl_ != nullptr) { - ori_op_type_set = impl_->ori_optype_set_; - } - for (auto &op_type : ori_op_type_set) { - (void)ori_op_type.insert(ge::AscendString(op_type.c_str())); - } - return SUCCESS; -} - -OpRegistrationData &OpRegistrationData::ParseParamsFn(const ParseParamFunc &parseParamFn) { - if (impl_ != nullptr) { - impl_->parseParamFn_ = parseParamFn; - } - return *this; -} - -ParseParamFunc OpRegistrationData::GetParseParamFn() const { - if (impl_ != nullptr) { - return impl_->parseParamFn_; - } - return nullptr; -} - -OpRegistrationData &OpRegistrationData::ParseParamsByOperatorFn(const ParseParamByOpFunc &parse_param_by_op_fn) { - if (impl_ != nullptr) { - impl_->parse_param_by_op_fn_ = parse_param_by_op_fn; - } - return *this; -} - -ParseParamByOpFunc OpRegistrationData::GetParseParamByOperatorFn() const { - if (impl_ != nullptr) { - return impl_->parse_param_by_op_fn_; - } - return nullptr; -} - -OpRegistrationData &OpRegistrationData::FusionParseParamsFn(const FusionParseParamFunc &fusionParseParamFn) { - if (impl_ != nullptr) { - impl_->fusionParseParamFn_ = fusionParseParamFn; - } - return *this; -} - -FusionParseParamFunc OpRegistrationData::GetFusionParseParamFn() const { - if (impl_ != nullptr) { - return impl_->fusionParseParamFn_; - } - return nullptr; -} - -OpRegistrationData &OpRegistrationData::FusionParseParamsFn(const FusionParseParamByOpFunc &fusion_parse_param_fn) { - if (impl_ != nullptr) { - impl_->fusion_parse_param_by_op_fn_ = fusion_parse_param_fn; - } - return *this; -} - -FusionParseParamByOpFunc OpRegistrationData::GetFusionParseParamByOpFn() const { - if (impl_ != nullptr) { - return impl_->fusion_parse_param_by_op_fn_; - } - return nullptr; -} - -OpRegistrationData &OpRegistrationData::ImplyType(const domi::ImplyType &imply_type) { - if (impl_ != nullptr) { - impl_->imply_type_ = imply_type; - } - return *this; -} - -domi::ImplyType OpRegistrationData::GetImplyType() const { - constexpr domi::ImplyType imply_type = domi::ImplyType::BUILDIN; - if (impl_ != nullptr) { - return impl_->imply_type_; - } - return imply_type; -} - -OpRegistrationData &OpRegistrationData::DelInputWithCond(int32_t inputIdx, const std::string &attrName, - bool attrValue) { - if (impl_ != nullptr) { - struct RemoveInputConfigure registerStu; - registerStu.inputIdx = inputIdx; - registerStu.attrName = attrName; - registerStu.moveType = RemoveInputType::OMG_REMOVE_TYPE_WITH_COND; - registerStu.attrValue = attrValue; - impl_->remove_input_configure_vec_.push_back(registerStu); - } - return *this; -} - -OpRegistrationData &OpRegistrationData::DelInputWithCond(int32_t input_idx, const char_t *attr_name, bool attr_value) { - std::string tmp_attr_name; - if (attr_name != nullptr) { - tmp_attr_name = attr_name; - } - if (impl_ != nullptr) { - struct RemoveInputConfigure registerStu; - registerStu.inputIdx = input_idx; - registerStu.attrName = tmp_attr_name; - registerStu.moveType = RemoveInputType::OMG_REMOVE_TYPE_WITH_COND; - registerStu.attrValue = attr_value; - impl_->remove_input_configure_vec_.push_back(registerStu); - } - return *this; -} - -OpRegistrationData &OpRegistrationData::InputReorderVector(const vector &input_order) { - if (impl_ != nullptr) { - struct RemoveInputConfigure register_input; - register_input.inputIdx = 0; - register_input.input_order = input_order; - register_input.moveType = RemoveInputType::OMG_INPUT_REORDER; - impl_->remove_input_configure_vec_.push_back(register_input); - } - return *this; -} - -OpRegistrationData &OpRegistrationData::DelInputWithOriginalType(int32_t input_idx, const std::string &ori_type) { - if (impl_ != nullptr) { - struct RemoveInputConfigure register_input; - register_input.inputIdx = input_idx; - register_input.originalType = ori_type; - register_input.moveType = RemoveInputType::OMG_REMOVE_INPUT_WITH_ORIGINAL_TYPE; - impl_->remove_input_configure_vec_.push_back(register_input); - } - return *this; -} - -OpRegistrationData &OpRegistrationData::DelInputWithOriginalType(int32_t input_idx, const char_t *ori_type) { - std::string tmp_ori_type; - if (ori_type != nullptr) { - tmp_ori_type = ori_type; - } - if (impl_ != nullptr) { - struct RemoveInputConfigure register_input; - register_input.inputIdx = input_idx; - register_input.originalType = tmp_ori_type; - register_input.moveType = RemoveInputType::OMG_REMOVE_INPUT_WITH_ORIGINAL_TYPE; - impl_->remove_input_configure_vec_.push_back(register_input); - } - return *this; -} - -OpRegistrationData &OpRegistrationData::ParseSubgraphPostFn(const ParseSubgraphFunc &subgraph_post_fn) { - if (impl_ != nullptr) { - impl_->parse_subgraph_post_fn_ = subgraph_post_fn; - } - return *this; -} - -ParseSubgraphFunc OpRegistrationData::GetParseSubgraphPostFn() const { - if (impl_ == nullptr) { - return nullptr; - } - return impl_->parse_subgraph_post_fn_; -} - -OpRegistrationData &OpRegistrationData::ParseOpToGraphFn(const ParseOpToGraphFunc &parse_op_to_graph_fn) { - if (impl_ != nullptr) { - impl_->parse_op_to_graph_fn_ = parse_op_to_graph_fn; - } - return *this; -} - -OpRegistrationData &OpRegistrationData::ParseSubgraphPostFn(const ParseSubgraphFuncV2 &subgraph_post_fn) { - if (impl_ != nullptr) { - impl_->parse_subgraph_post_fn_v2_ = subgraph_post_fn; - } - return *this; -} - -ParseOpToGraphFunc OpRegistrationData::GetParseOpToGraphFn() const { - if (impl_ == nullptr) { - return nullptr; - } - return impl_->parse_op_to_graph_fn_; -} - -Status OpRegistrationData::GetParseSubgraphPostFn(ParseSubgraphFuncV2 &func) const { - if (impl_ == nullptr) { - return FAILED; - } - func = impl_->parse_subgraph_post_fn_v2_; - return SUCCESS; -} - -OpRegistry *OpRegistry::Instance() { - static OpRegistry g_instance; - return &g_instance; -} - -namespace { -std::string GetParserKey(const std::string &om_type, const std::string &ori_type) { - return om_type + "_" + ori_type; -} -} // namespace - -bool OpRegistry::Register(const OpRegistrationData ®_data) { - if (reg_data.impl_ == nullptr) { - return false; - } - for (const auto &ori_type : reg_data.impl_->ori_optype_set_) { - const std::string om_ori_type = GetParserKey(reg_data.impl_->om_optype_, ori_type); - if (op_parse_params_fn_map_.find(om_ori_type) != op_parse_params_fn_map_.end()) { - GELOGI("[Register][Check] Plugin of op type:%s, original type:%s already registered, skip", - reg_data.impl_->om_optype_.c_str(), ori_type.c_str()); - continue; - } - - GELOGD("The plugin of type:%s will be registered.", om_ori_type.c_str()); - op_parse_params_fn_map_[om_ori_type] = reg_data.impl_->parseParamFn_; - fusion_op_parse_params_fn_map_[om_ori_type] = reg_data.impl_->fusionParseParamFn_; - fusion_parse_params_by_op_fn_map_[om_ori_type] = reg_data.impl_->fusion_parse_param_by_op_fn_; - parse_params_by_op_func_map_[om_ori_type] = reg_data.impl_->parse_param_by_op_fn_; - remove_input_configure_map_[om_ori_type] = reg_data.impl_->remove_input_configure_vec_; - parse_op_to_graph_fn_map_[om_ori_type] = reg_data.impl_->parse_op_to_graph_fn_; - - if (origin_type_to_om_type_.find(ori_type) == origin_type_to_om_type_.end()) { - origin_type_to_om_type_[ori_type] = reg_data.impl_->om_optype_; - } - } - - if (op_run_mode_map_.find(reg_data.impl_->om_optype_) != op_run_mode_map_.end()) { - GELOGI("[Register][Check] Plugin of %s already registered, skip", reg_data.impl_->om_optype_.c_str()); - return true; - } - op_run_mode_map_[reg_data.impl_->om_optype_] = reg_data.impl_->imply_type_; - op_types_to_parse_subgraph_post_func_[reg_data.impl_->om_optype_] = reg_data.impl_->parse_subgraph_post_fn_; - op_types_to_parse_subgraph_post_func_v2_[reg_data.impl_->om_optype_] = reg_data.impl_->parse_subgraph_post_fn_v2_; - return true; -} - -domi::ImplyType OpRegistry::GetImplyTypeByOriOpType(const std::string &ori_optype) { - domi::ImplyType result = domi::ImplyType::BUILDIN; - const auto iter = origin_type_to_om_type_.find(ori_optype); - if (iter != origin_type_to_om_type_.end()) { - result = GetImplyType(iter->second); - } - return result; -} - -domi::ImplyType OpRegistry::GetImplyType(const std::string &op_type) { - const auto it_find = op_run_mode_map_.find(op_type); - if (it_find == op_run_mode_map_.end()) { - return domi::ImplyType::BUILDIN; - } - return it_find->second; -} - -domi::ParseParamByOpFunc OpRegistry::GetParseParamByOperatorFunc(const std::string &ori_type) { - std::string om_type; - const auto iter = origin_type_to_om_type_.find(ori_type); - if (iter != origin_type_to_om_type_.end()) { - om_type = iter->second; - } - const std::string type = GetParserKey(om_type, ori_type); - const auto it_find = parse_params_by_op_func_map_.find(type); - if (it_find == parse_params_by_op_func_map_.end()) { - return nullptr; - } - return it_find->second; -} - -domi::ParseParamFunc OpRegistry::GetParseParamFunc(const std::string &op_type, const std::string &ori_type) { - const std::string type = GetParserKey(op_type, ori_type); - const auto it_find = op_parse_params_fn_map_.find(type); - if (it_find == op_parse_params_fn_map_.end()) { - return nullptr; - } - return it_find->second; -} - -domi::FusionParseParamFunc OpRegistry::GetFusionParseParamFunc(const std::string &op_type, - const std::string &ori_type) { - const std::string type = GetParserKey(op_type, ori_type); - const auto it_find = fusion_op_parse_params_fn_map_.find(type); - if (it_find == fusion_op_parse_params_fn_map_.end()) { - return nullptr; - } - return it_find->second; -} - -domi::FusionParseParamByOpFunc OpRegistry::GetFusionParseParamByOpFunc(const std::string &op_type, - const std::string &ori_type) { - const std::string type = GetParserKey(op_type, ori_type); - const auto it_find = fusion_parse_params_by_op_fn_map_.find(type); - if (it_find == fusion_parse_params_by_op_fn_map_.end()) { - return nullptr; - } - return it_find->second; -} - -domi::ParseSubgraphFunc OpRegistry::GetParseSubgraphPostFunc(const std::string &op_type) { - const auto it_find = op_types_to_parse_subgraph_post_func_.find(op_type); - if (it_find == op_types_to_parse_subgraph_post_func_.end()) { - return nullptr; - } - return it_find->second; -} - -Status OpRegistry::GetParseSubgraphPostFunc(const std::string &op_type, - domi::ParseSubgraphFuncV2 &parse_subgraph_func) { - const auto it_find = op_types_to_parse_subgraph_post_func_v2_.find(op_type); - if (it_find == op_types_to_parse_subgraph_post_func_v2_.end()) { - return FAILED; - } - parse_subgraph_func = it_find->second; - return SUCCESS; -} - -void OpRegistry::GetOpTypeByImplyType(std::vector &vec_op_type, const domi::ImplyType imply_type) const { - for (const auto &iter : op_run_mode_map_) { - if (iter.second == imply_type) { - vec_op_type.push_back(iter.first); - } - } - return; -} - -const std::vector &OpRegistry::GetRemoveInputConfigure(const std::string &ori_optype) const { - static const std::vector empty_ = {}; - const auto iter = origin_type_to_om_type_.find(ori_optype); - if (iter != origin_type_to_om_type_.end()) { - const std::string type = GetParserKey(iter->second, ori_optype); - const auto it = remove_input_configure_map_.find(type); - if (it != remove_input_configure_map_.end()) { - return it->second; - } - } - return empty_; -} - -bool OpRegistry::GetOmTypeByOriOpType(const std::string &ori_optype, std::string &om_type) { - const auto iter = origin_type_to_om_type_.find(ori_optype); - if (iter != origin_type_to_om_type_.end()) { - om_type = iter->second; - return true; - } - return false; -} - -ParseOpToGraphFunc OpRegistry::GetParseOpToGraphFunc(const std::string &op_type, const std::string &ori_type) { - const std::string type = GetParserKey(op_type, ori_type); - const auto iter = parse_op_to_graph_fn_map_.find(type); - if (iter == parse_op_to_graph_fn_map_.end()) { - return nullptr; - } - return iter->second; -} -/*lint +e1073*/ -} // namespace domi diff --git a/register/register_base.cc b/register/register_base.cc deleted file mode 100644 index c625499075f47cb550208a4682a16e621718ddfa..0000000000000000000000000000000000000000 --- a/register/register_base.cc +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/register_base.h" -#include "op_lib_register_impl.h" - -extern "C" const char *aclGetCustomOpLibPath() { - return ge::OpLibRegistry::GetInstance().GetCustomOpLibPath(); -} - diff --git a/register/register_custom_pass.cpp b/register/register_custom_pass.cpp deleted file mode 100644 index 4a872da96bb775697b8a9837fc5d06604768a762..0000000000000000000000000000000000000000 --- a/register/register_custom_pass.cpp +++ /dev/null @@ -1,301 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "external/register/register_custom_pass.h" - -#include "common/checker.h" - -#include "register/custom_pass_helper.h" -#include "graph/debug/ge_log.h" -#include "graph/debug/ge_util.h" -#include "common/plugin/plugin_manager.h" -#include "register/custom_pass_context_impl.h" -#include "graph/utils/graph_utils_ex.h" -#include "graph/utils/graph_utils.h" - -namespace ge { -const std::set kConstGraphStages = {CustomPassStage::kAfterAssignLogicStream}; -const std::map kCustomPassStageToStringMap = { - {CustomPassStage::kBeforeInferShape, "BeforeInferShape"}, - {CustomPassStage::kAfterInferShape, "AfterInferShape"}, - {CustomPassStage::kAfterAssignLogicStream, "AfterAssignLogicStream"}, - {CustomPassStage::kAfterBuiltinFusionPass, "AfterBuiltinFusionPass"}, - {CustomPassStage::kInvalid, "InvalidStage"} -}; - -namespace { -std::string CustomPassStageToString(CustomPassStage stage) { - GE_ASSERT_TRUE(stage <= CustomPassStage::kInvalid); - return kCustomPassStageToStringMap.find(stage)->second; -} - -Status RunAllocateStreamPass(const PassRegistrationData ®_data, const GraphPtr &graph, - CustomPassContext &custom_pass_context) { - GE_ASSERT_NOTNULL(graph); - const auto allocate_stream_pass_func = reg_data.GetCustomAllocateStreamPass(); - if (allocate_stream_pass_func == nullptr) { - GE_LOGE("[Check][Param] It is required CustomAllocateStreamPassFunc of [%s] at stage[%s] but got nullptr.", - reg_data.GetPassName().c_str(), CustomPassStageToString(reg_data.GetStage()).c_str()); - std::stringstream reason; - reason << "It is required CustomAllocateStreamPassFunc in stage " << CustomPassStageToString(reg_data.GetStage()) << - ", but got nullptr"; - REPORT_PREDEFINED_ERR_MSG("E19030", std::vector({"passname", "reason"}), - std::vector({reg_data.GetPassName().c_str(), reason.str().c_str()})); - return FAILED; - } - - const auto compute_graph = GraphUtilsEx::GetComputeGraph(*graph); - GE_ASSERT_NOTNULL(compute_graph); - const auto root_graph = GraphUtils::FindRootGraph(compute_graph); - GE_ASSERT_NOTNULL(root_graph); - - GE_DUMP(root_graph, "RunCustomPass_BeforeAssignLogicStream" + reg_data.GetPassName()); - // 此处框架保证传入的custom_pass_context实例为stream_pass_context - auto *stream_pass_context = dynamic_cast(&custom_pass_context); - GE_ASSERT_NOTNULL(stream_pass_context, "Failed to transfer CustomPassContext to StreamPassContext"); - const auto ret = allocate_stream_pass_func(graph, *stream_pass_context); - GE_DUMP(root_graph, "RunCustomPass_AfterAssignLogicStream" + reg_data.GetPassName()); - if (ret != SUCCESS) { - GE_LOGE("Execution of custom pass [%s] failed! Reason: %s.", reg_data.GetPassName().c_str(), - custom_pass_context.GetErrorMessage().GetString()); - REPORT_PREDEFINED_ERR_MSG( - "E19028", std::vector({"passname", "retcode", "reason"}), - std::vector({reg_data.GetPassName().c_str(), std::to_string(ret).c_str(), - std::string(custom_pass_context.GetErrorMessage().GetString()).c_str()})); - return FAILED; - } - return SUCCESS; -} - -Status RunCustomPass(const PassRegistrationData ®_data, GraphPtr &graph, CustomPassContext &custom_pass_context) { - const auto custom_pass_fn = reg_data.GetCustomPassFn(); - if (custom_pass_fn == nullptr) { - GELOGW("[Check][Param] Failed to retrieve custom_pass_fn for custom pass %s failed", - reg_data.GetPassName().c_str()); - return SUCCESS; - } - const auto ret = custom_pass_fn(graph, custom_pass_context); - if (ret != SUCCESS) { - GE_LOGE("Execution of custom pass [%s] failed! Reason: %s.", reg_data.GetPassName().c_str(), - custom_pass_context.GetErrorMessage().GetString()); - REPORT_PREDEFINED_ERR_MSG( - "E19028", std::vector({"passname", "retcode", "reason"}), - std::vector({reg_data.GetPassName().c_str(), std::to_string(ret).c_str(), - std::string(custom_pass_context.GetErrorMessage().GetString()).c_str()})); - return FAILED; - } - return SUCCESS; -} -} // namespace -PassReceiver::PassReceiver(PassRegistrationData ®_data) { - CustomPassHelper::Instance().Insert(reg_data); -} - -class PassRegistrationDataImpl { - public: - PassRegistrationDataImpl() = default; - ~PassRegistrationDataImpl() = default; - - explicit PassRegistrationDataImpl(const std::string &pass_name); - -private: - friend class PassRegistrationData; - std::string pass_name_; - CustomPassFunc custom_pass_; - CustomAllocateStreamPassFunc allocate_stream_pass_; - CustomPassStage stage_ = CustomPassStage::kBeforeInferShape; -}; - -PassRegistrationDataImpl::PassRegistrationDataImpl(const std::string &pass_name) - : pass_name_(pass_name), custom_pass_(nullptr) {} - -PassRegistrationData::PassRegistrationData(std::string pass_name) { - impl_ = ge::ComGraphMakeShared(pass_name); - if (impl_ == nullptr) { - GELOGW("[Check][Param] make impl failed, pass_name:%s", pass_name.c_str()); - } -} - -std::string PassRegistrationData::GetPassName() const { - if (impl_ == nullptr) { - return ""; - } - return impl_->pass_name_; -} - -PassRegistrationData &PassRegistrationData::CustomPassFn(const CustomPassFunc &custom_pass_fn) { - if (impl_ != nullptr) { - impl_->custom_pass_ = custom_pass_fn; - } - return *this; -} - -PassRegistrationData &PassRegistrationData::CustomAllocateStreamPassFn( - const CustomAllocateStreamPassFunc &allocate_stream_pass_fn) { - if (impl_ != nullptr) { - impl_->allocate_stream_pass_ = allocate_stream_pass_fn; - impl_->stage_ = CustomPassStage::kAfterAssignLogicStream; - } - return *this; -} - -CustomPassFunc PassRegistrationData::GetCustomPassFn() const { - if (impl_ == nullptr) { - return nullptr; - } - return impl_->custom_pass_; -} - -CustomAllocateStreamPassFunc PassRegistrationData::GetCustomAllocateStreamPass() const { - GE_ASSERT_NOTNULL(impl_); - return impl_->allocate_stream_pass_; -} - -PassRegistrationData &PassRegistrationData::Stage(const CustomPassStage stage) { - if ((impl_ != nullptr) && (stage < CustomPassStage::kInvalid)) { - impl_->stage_ = stage; - GELOGD("Setting pass [%s] stage to [%s]", impl_->pass_name_.c_str(), CustomPassStageToString(stage).c_str()); - } - return *this; -} - -CustomPassStage PassRegistrationData::GetStage() const { - if (impl_ == nullptr) { - return CustomPassStage::kInvalid; - } - return impl_->stage_; -} - -CustomPassContext::CustomPassContext() { - impl_ = ComGraphMakeUnique(); - if (impl_ == nullptr) { - GELOGW("[Check][Param] make impl failed"); - } -} - -void CustomPassContext::SetErrorMessage(const AscendString &error_message) { - if (impl_ != nullptr) { - impl_->SetErrorMessage(error_message); - } -} - -AscendString CustomPassContext::GetErrorMessage() const { - if (impl_ != nullptr) { - return impl_->GetErrorMessage(); - } - return ""; -} - -StreamPassContext::StreamPassContext(int64_t current_max_stream_id) { - impl_ = ge::ComGraphMakeUnique(current_max_stream_id); - if (impl_ == nullptr) { - GELOGW("[Check][Param] make StreamPassContextImpl failed"); - } -} - -graphStatus StreamPassContext::SetStreamId(const GNode &node, int64_t stream_id) { - GE_ASSERT_NOTNULL(impl_); - return impl_->SetStreamId(node, stream_id); -} - -int64_t StreamPassContext::GetStreamId(const GNode &node) const { - int64_t stream_id; - if ((impl_ == nullptr) || impl_->GetStreamId(node, stream_id) != SUCCESS) { - return INVALID_STREAM_ID; - } - return stream_id; -} - -int64_t StreamPassContext::GetCurrMaxStreamId() const { - GE_ASSERT_NOTNULL(impl_); - return impl_->GetCurrentMaxStreamId(); -} - -int64_t StreamPassContext::AllocateNextStreamId() { - if (impl_ != nullptr) { - return impl_->AllocateNextStreamId(); - } - return INT64_MAX; -} - -CustomPassHelper &CustomPassHelper::Instance() { - static CustomPassHelper instance; - return instance; -} - -void CustomPassHelper::Insert(const PassRegistrationData ®_data) { - (void)registration_datas_.emplace_back(reg_data); -} - -Status CustomPassHelper::Load() { - GELOGD("[Load][CustomPassLibs] Start to load custom pass libs"); - std::string opp_path; - GE_ASSERT_SUCCESS(ge::PluginManager::GetOppPath(opp_path)); - std::string vendors_path = opp_path + "/vendors"; - - // 存储所有的 .so 文件路径 - std::vector so_files; - ge::PluginManager::FindSoFilesInCustomPassDirs(vendors_path, so_files); - - if (so_files.empty()) { - GELOGD("No custom pass libs found in %s, skip loading custom pass libs", vendors_path.c_str()); - return ge::SUCCESS; - } - - // 逐个 dlopen - for (const auto &so_file : so_files) { - void *handle = dlopen(so_file.c_str(), RTLD_NOW | RTLD_LOCAL); - if (handle == nullptr) { - const char* error = dlerror(); - REPORT_PREDEFINED_ERR_MSG( - "E19029", std::vector({"passlibname", "reason"}), - std::vector({so_file.c_str(), error})); - GELOGE(ge::FAILED, "Failed to load %s: %s", so_file.c_str(), error); - return ge::FAILED; - } - handles_.emplace_back(handle); - GELOGI("Load custom pass lib %s success", so_file.c_str()); - } - return ge::SUCCESS; -} - -Status CustomPassHelper::Unload() { - registration_datas_.clear(); - for (auto &handle : handles_) { - if (handle != nullptr && dlclose(handle) != 0) { - GELOGE(ge::FAILED, "[Unload][CustomPassLibs] Failed to unload custom pass lib: %s", dlerror()); - return ge::FAILED; - } - GELOGI("Unload custom pass lib success"); - } - handles_.clear(); - return ge::SUCCESS; -} - -Status CustomPassHelper::Run(GraphPtr &graph, CustomPassContext &custom_pass_context) const { - return Run(graph, custom_pass_context, CustomPassStage::kBeforeInferShape); -} - -Status CustomPassHelper::Run(GraphPtr &graph, CustomPassContext &custom_pass_context, - const CustomPassStage stage) const { - for (auto &item : registration_datas_) { - if (item.GetStage() != stage) { - continue; - } - GELOGD("Starting custom pass [%s] in stage [%s]!", item.GetPassName().c_str(), CustomPassStageToString(stage).c_str()); - if (stage == CustomPassStage::kAfterAssignLogicStream) { - GE_ASSERT_SUCCESS(RunAllocateStreamPass(item, graph, custom_pass_context)); - } else { - GE_ASSERT_SUCCESS(RunCustomPass(item, graph, custom_pass_context)); - } - GELOGD("Run custom pass [%s] successfully!", item.GetPassName().c_str()); - } - return SUCCESS; -} -} // namespace ge diff --git a/register/scope/scope_graph.cc b/register/scope/scope_graph.cc deleted file mode 100644 index bd65ac2caa693bf7abd5747868523aef166deb6c..0000000000000000000000000000000000000000 --- a/register/scope/scope_graph.cc +++ /dev/null @@ -1,1373 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "register/scope/scope_graph_impl.h" -#include "register/register_utils.h" -#include "external/register/register.h" -#include "common/ge_common/debug/ge_log.h" -#include "common/ge_common/string_util.h" -#include "graph/debug/ge_util.h" -#include "graph/ge_tensor.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/types.h" -#include "graph/debug/ge_util.h" - -namespace ge { -namespace { -using NodeEdges = std::map>>; -using GraphNodesInOut = std::unordered_map>; -constexpr const char_t *const kTfIdentityType = "Identity"; -constexpr const char_t *const kTfConstType = "Const"; -constexpr const char_t *const kNumerics = "0123456789"; -constexpr size_t kInputPartsSize = 2U; -constexpr size_t kInputNodeName = 0U; -constexpr size_t kPeerOutIndex = 1U; -constexpr int32_t kControlSlot = -1; - -Status DecomposeInputName(const std::string &input_name, std::string &node_name, int32_t &index, bool &is_control) { - if (StringUtils::StartWith(input_name, "^")) { - is_control = true; - node_name = input_name.substr(1U); - index = kControlSlot; - return SUCCESS; - } - is_control = false; - if (input_name.find(":") == std::string::npos) { - node_name = input_name; - index = 0; - return SUCCESS; - } - const std::vector parts = StringUtils::Split(input_name, ':'); - if (parts.size() != kInputPartsSize) { - GELOGE(PARAM_INVALID, "Input name [%s] is invalid.", input_name.c_str()); - return PARAM_INVALID; - } - try { - index = static_cast(std::stoi(parts[kPeerOutIndex])); - } catch (std::invalid_argument &) { - GELOGE(PARAM_INVALID, "Peer out index [%s] is invalid.", parts[kPeerOutIndex].c_str()); - return PARAM_INVALID; - } catch (...) { - GELOGE(PARAM_INVALID, "Peer out index [%s] is out of range.", parts[kPeerOutIndex].c_str()); - return PARAM_INVALID; - } - if (index < 0) { - GELOGE(PARAM_INVALID, "Peer out index [%d] is invalid.", index); - return PARAM_INVALID; - } - node_name = parts[kInputNodeName]; - return SUCCESS; -} - -void AddEdgeCtx(const std::string &peer_node_name, const int32_t peer_index, - const int32_t curr_index, NodeEdges &node_edges) { - (void) node_edges[curr_index].insert({peer_node_name, peer_index}); -} - -Status GetGraphDefInOutMap(domi::tensorflow::GraphDef *graph_def, GraphNodesInOut &in_out_map) { - GE_CHECK_NOTNULL(graph_def); - for (int32_t i = 0; i < graph_def->node_size(); i++) { - const domi::tensorflow::NodeDef *node = graph_def->mutable_node(i); - const std::string &node_name = node->name(); - int32_t input_index = 0; - for (const auto &input : node->input()) { - std::string peer_node_name; - int32_t peer_out_index = 0; - bool is_control = false; - const auto ret = DecomposeInputName(input, peer_node_name, peer_out_index, is_control); - if (ret != SUCCESS) { - GELOGE(PARAM_INVALID, "Input[%s] of node[%s] is invalid.", input.c_str(), node_name.c_str()); - return PARAM_INVALID; - } - const int32_t curr_input_index = is_control ? kControlSlot : input_index++; - AddEdgeCtx(peer_node_name, peer_out_index, curr_input_index, in_out_map[node_name].first); - AddEdgeCtx(node_name, curr_input_index, peer_out_index, in_out_map[peer_node_name].second); - } - } - return SUCCESS; -} - -void GetInOutStr(const GraphNodesInOut &in_out_map, const string &node_name, - std::vector &inputs, std::vector &outputs) { - const auto in_out_iter = in_out_map.find(node_name); - if (in_out_iter == in_out_map.end()) { - GELOGI("Not find input or output info for node:%s.", node_name.c_str()); - return; - } - const auto inputs_data = in_out_iter->second.first; - for (const auto &input_data : inputs_data) { - for (const auto &name_index : input_data.second) { - const std::string item = std::to_string(input_data.first) + ":" + name_index.first + - ":" + std::to_string(name_index.second); - inputs.push_back(item); - } - } - - const auto outputs_data = in_out_iter->second.second; - for (const auto &output_data : outputs_data) { - for (const auto &name_index : output_data.second) { - const std::string item = std::to_string(output_data.first) + ":" + name_index.first + - ":" + std::to_string(name_index.second); - outputs.push_back(item); - } - } -} - -Status SetNodeInputOutputAttr(const GraphNodesInOut &in_out_map, OperatorPtr &op) { - GE_CHECK_NOTNULL(op); - std::vector inputs; - std::vector outputs; - ge::AscendString op_name; - (void) op->GetName(op_name); - GetInOutStr(in_out_map, op_name.GetString(), inputs, outputs); - (void)op->SetAttr(ATTR_NAME_ORIGIN_GRAPH_NODE_INPUTS, inputs); - (void)op->SetAttr(ATTR_NAME_ORIGIN_GRAPH_NODE_OUTPUTS, outputs); - return SUCCESS; -} -} // namespace - -Status Scope::ScopeImpl::Init(const std::string &name, const std::string &sub_type, Scope *const father_scope) { - name_ = name; - sub_type_ = sub_type; - father_scope_ = father_scope; - return SUCCESS; -} - -Scope::ScopeImpl::~ScopeImpl() { - for (auto &scope : sub_scopes_) { - if (scope.second != nullptr) { - delete scope.second; - scope.second = nullptr; - } - } -} - -void Scope::ScopeImpl::ClearTypeAndSubType() { - sub_type_ = ""; - const std::vector &sub_scopes = GetAllSubScopes(); - for (auto &sub_scope : sub_scopes) { - auto &impl = sub_scope->impl_; - impl->SetSubType(""); - } -} - -void Scope::ScopeImpl::AddNode(ge::OperatorPtr &node_def) { - if (node_def == nullptr) { - GELOGE(PARAM_INVALID, "Input node_def is nullptr."); - return; - } - - nodes_.push_back(node_def); -} - -const std::unordered_map &Scope::ScopeImpl::AllNodesMap() { - if (!all_nodes_map_.empty()) { - return all_nodes_map_; - } - - if (!nodes_.empty()) { - AscendString name; - for (const auto &node : nodes_) { - (void)node->GetName(name); - (void)all_nodes_map_.insert(std::pair(name.GetString(), node)); - } - } - const std::vector &scopes = GetAllSubScopes(); - for (auto &scope : scopes) { - auto &impl = scope->impl_; - const std::vector &sub_nodes = impl->Nodes(); - if (!sub_nodes.empty()) { - AscendString name; - for (const auto &sub_node : sub_nodes) { - (void) sub_node->GetName(name); - (void) all_nodes_map_.insert(std::pair(name.GetString(), sub_node)); - } - } - } - return all_nodes_map_; -} - -void Scope::ScopeImpl::AddSubScope(Scope *const scope) { - AscendString name; - (void)scope->Name(name); - sub_scopes_[name.GetString()] = scope; -} - -Scope *Scope::ScopeImpl::GetSubScope(const std::string &scope_name) const { - const auto iter = sub_scopes_.find(scope_name); - if (iter != sub_scopes_.end()) { - return iter->second; - } - return nullptr; -} - -const std::vector &Scope::ScopeImpl::GetAllSubScopes() { - if (!all_sub_scopes_.empty()) { - return all_sub_scopes_; - } - - for (auto &iter : sub_scopes_) { - Scope *const scope = iter.second; - all_sub_scopes_.push_back(scope); - - std::stack scopes; - scopes.push(scope); - while (!scopes.empty()) { - Scope *const sub_scope = scopes.top(); - scopes.pop(); - auto &impl = sub_scope->impl_; - const std::unordered_map &sub_scopes = impl->GetSubScopes(); - for (auto &iter_sub : sub_scopes) { - all_sub_scopes_.push_back(iter_sub.second); - scopes.push(iter_sub.second); - } - } - } - return all_sub_scopes_; -} - -int32_t Scope::ScopeImpl::GetOpTypeNum(const std::string &op_type) const { - const auto iter = op_nums_.find(op_type); - if (iter != op_nums_.end()) { - return iter->second; - } else { - return -1; - } -} - -void Scope::ScopeImpl::OpsNumInc(const std::string &op_type) { - const auto iter = op_nums_.find(op_type); - if (iter != op_nums_.end()) { - op_nums_[op_type] = iter->second + 1; - } else { - op_nums_[op_type] = 1; - } -} - -const std::string Scope::ScopeImpl::LastName() const { - const std::vector names = ge::StringUtils::Split(name_, '/'); - // if vector size is less than 2, there is no multilevel directory, return origin name. - if (names.size() < 2U) { - GELOGI("Input name is already the last name, input name:%s.", name_.c_str()); - return name_; - } - const std::string last_name = names[names.size() - 2U]; // minus 2 to get the last name - return ScopeImpl::TrimScopeIndex(last_name); -} - -std::string Scope::ScopeImpl::TrimScopeIndex(const std::string &scope_name) { - std::string scope_name_new = scope_name; - // deal D_index, only keep name D - const auto index = scope_name.find_last_of("_"); - if (index != std::string::npos) { - // index_str after "_" is integer - const std::string index_str = scope_name.substr(index + 1U, scope_name.length()); - if (index_str.find_first_not_of(kNumerics) != std::string::npos) { - return scope_name; - } - try { - if (std::stoi(index_str.c_str()) > 0) { - scope_name_new = scope_name.substr(0U, index); - } - } catch (std::invalid_argument &) { - scope_name_new = scope_name; - } catch (std::out_of_range &) { - scope_name_new = scope_name; - } - } - return scope_name_new; -} - -Scope::Scope() {} - -Status Scope::Init(const std::string &name, const std::string &sub_type, Scope *father_scope) { - impl_ = ge::ComGraphMakeUnique(); - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "Make unique_ptr of ScopeImpl failed."); - return ge::MEMALLOC_FAILED; - } - - return impl_->Init(name, sub_type, father_scope); -} - -Status Scope::Init(const char_t *name, const char_t *sub_type, Scope *father_scope) { - std::string scope_name; - std::string scope_sub_type; - if (name != nullptr) { - scope_name = name; - } - if (sub_type != nullptr) { - scope_sub_type = sub_type; - } - impl_ = ge::ComGraphMakeUnique(); - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "Make unique_ptr of ScopeImpl failed."); - return ge::MEMALLOC_FAILED; - } - - return impl_->Init(scope_name, scope_sub_type, father_scope); -} - -Scope::~Scope() = default; - -const std::string &Scope::Name() const { - return impl_->Name(); -} - -Status Scope::Name(AscendString &name) const { - name = AscendString(impl_->Name().c_str()); - return SUCCESS; -} - -const std::string &Scope::SubType() const { - return impl_->SubType(); -} - -Status Scope::SubType(AscendString &sub_type) const { - sub_type = AscendString(impl_->SubType().c_str()); - return SUCCESS; -} - -const std::unordered_map &Scope::AllNodesMap() const { - return impl_->AllNodesMap(); -} - -Status Scope::AllNodesMap(std::unordered_map &node_map) const { - const std::unordered_map nodes = impl_->AllNodesMap(); - for (auto &node : nodes) { - const AscendString tmp(node.first.c_str()); - node_map[tmp] = node.second; - } - return SUCCESS; -} - -Scope *Scope::GetSubScope(const std::string &scope_name) const { - return impl_->GetSubScope(scope_name); -} - -Scope *Scope::GetSubScope(const char_t *scope_name) const { - std::string str_scope_name; - if (scope_name != nullptr) { - str_scope_name = scope_name; - } - return impl_->GetSubScope(str_scope_name); -} - -const std::string Scope::LastName() const { - return impl_->LastName(); -} - -Status Scope::LastName(AscendString &name) const { - name = AscendString(impl_->LastName().c_str()); - return SUCCESS; -} - -const Scope *Scope::GetFatherScope() const { - return impl_->GetFatherScope(); -} - -const std::vector &Scope::GetAllSubScopes() const { - return impl_->GetAllSubScopes(); -} - -FusionScopesResult::InnerNodeInfo::InnerNodeInfoImpl::~InnerNodeInfoImpl() noexcept { - operator_.BreakConnect(); -} - -std::string FusionScopesResult::InnerNodeInfo::InnerNodeInfoImpl::GetFullNodeName(const std::string &relative_name) { - if (fusion_node_name_.empty()) { - return relative_name; - } - return (fusion_node_name_.at(fusion_node_name_.size() - 1U) == '/') ? (fusion_node_name_ + relative_name) - : (fusion_node_name_ + "/" + relative_name); -} - -void FusionScopesResult::InnerNodeInfo::InnerNodeInfoImpl::InsertInput(const std::string &input_node, - int32_t peer_out_idx) { - std::string input_name = (input_node != kInputFromFusionScope) ? GetFullNodeName(input_node) : input_node; - inner_node_inputs_.emplace_back(std::make_pair(input_name, peer_out_idx)); -} - -void FusionScopesResult::InnerNodeInfo::InnerNodeInfoImpl::InsertOutput(const std::string &output_node, - int32_t peer_in_idx) { - std::string output_name = (output_node != kOutputToFusionScope) ? GetFullNodeName(output_node) : output_node; - inner_node_outputs_.emplace_back(std::make_pair(output_name, peer_in_idx)); -} - -ge::graphStatus FusionScopesResult::InnerNodeInfo::InnerNodeInfoImpl::BuildOperator() { - operator_ = ge::OperatorFactory::CreateOperator(name_.c_str(), type_.c_str()); - ge::AscendString operator_name; - (void) operator_.GetName(operator_name); - if (operator_name.GetString() != name_) { - GELOGE(ge::GRAPH_FAILED, "IR for op is not registered, op name:%s, op type:%s", name_.c_str(), type_.c_str()); - return ge::GRAPH_FAILED; - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus FusionScopesResult::InnerNodeInfo::InnerNodeInfoImpl::SetInputFormat(const std::string &input_name, - const std::string &format) { - ge::TensorDesc input_tesor_desc = operator_.GetInputDescByName(input_name.c_str()); - const auto ge_format = ge::TypeUtils::SerialStringToFormat(format); - input_tesor_desc.SetOriginFormat(ge_format); - input_tesor_desc.SetFormat(ge_format); - return operator_.UpdateInputDesc(input_name.c_str(), input_tesor_desc); -} - -ge::graphStatus FusionScopesResult::InnerNodeInfo::InnerNodeInfoImpl::SetOutputFormat(const std::string &output_name, - const std::string &format) { - ge::TensorDesc output_tesor_desc = operator_.GetOutputDescByName(output_name.c_str()); - const auto ge_format = ge::TypeUtils::SerialStringToFormat(format); - output_tesor_desc.SetOriginFormat(ge_format); - output_tesor_desc.SetFormat(ge_format); - return operator_.UpdateOutputDesc(output_name.c_str(), output_tesor_desc); -} - -ge::graphStatus FusionScopesResult::InnerNodeInfo::InnerNodeInfoImpl::SetDynamicInputFormat( - const std::string &input_name, const uint32_t index, const std::string &format) { - ge::TensorDesc input_tesor_desc = operator_.GetDynamicInputDesc(input_name.c_str(), index); - const auto ge_format = ge::TypeUtils::SerialStringToFormat(format); - input_tesor_desc.SetOriginFormat(ge_format); - input_tesor_desc.SetFormat(ge_format); - return operator_.UpdateDynamicInputDesc(input_name.c_str(), index, input_tesor_desc); -} - -ge::graphStatus FusionScopesResult::InnerNodeInfo::InnerNodeInfoImpl::SetDynamicOutputFormat( - const std::string &output_name, const uint32_t index, const std::string &format) { - ge::TensorDesc output_tesor_desc = operator_.GetDynamicOutputDesc(output_name.c_str(), index); - const auto ge_format = ge::TypeUtils::SerialStringToFormat(format); - output_tesor_desc.SetOriginFormat(ge_format); - output_tesor_desc.SetFormat(ge_format); - return operator_.UpdateDynamicOutputDesc(output_name.c_str(), index, output_tesor_desc); -} - -FusionScopesResult::InnerNodeInfo::InnerNodeInfo(const std::string &fusion_node_name) { - impl_ = ge::ComGraphMakeUnique(fusion_node_name); -} - -FusionScopesResult::InnerNodeInfo::InnerNodeInfo(const char_t *fusion_node_name) { - std::string str_fusion_node_name; - if (fusion_node_name != nullptr) { - str_fusion_node_name = fusion_node_name; - } - impl_ = ge::ComGraphMakeUnique(str_fusion_node_name); -} - -FusionScopesResult::InnerNodeInfo::InnerNodeInfo(const std::string &fusion_node_name, const std::string &name, - const std::string &type) { - impl_ = ge::ComGraphMakeUnique(fusion_node_name, name, type); -} - -FusionScopesResult::InnerNodeInfo::InnerNodeInfo(const char_t *fusion_node_name, const char_t *name, - const char_t *type) { - impl_ = ge::ComGraphMakeUnique(fusion_node_name, name, type); -} - -FusionScopesResult::InnerNodeInfo::InnerNodeInfo(FusionScopesResult::InnerNodeInfo &&other) noexcept - : impl_(std::move(other.impl_)) {} - -FusionScopesResult::InnerNodeInfo::~InnerNodeInfo() = default; - -FusionScopesResult::InnerNodeInfo &FusionScopesResult::InnerNodeInfo::operator=( - FusionScopesResult::InnerNodeInfo &&other) noexcept { - if (&other != this) { - impl_ = std::move(other.impl_); - } - return *this; -} - -FusionScopesResult::InnerNodeInfo &FusionScopesResult::InnerNodeInfo::SetName(const std::string &name) { - if (impl_ != nullptr) { - impl_->SetName(name); - } - return *this; -} - -FusionScopesResult::InnerNodeInfo &FusionScopesResult::InnerNodeInfo::SetName(const char_t *name) { - if ((impl_ != nullptr) && (name != nullptr)) { - const std::string str_name = name; - impl_->SetName(str_name); - } - return *this; -} - -FusionScopesResult::InnerNodeInfo &FusionScopesResult::InnerNodeInfo::SetType(const std::string &type) { - if (impl_ != nullptr) { - impl_->SetType(type); - } - return *this; -} - -FusionScopesResult::InnerNodeInfo &FusionScopesResult::InnerNodeInfo::SetType(const char_t *type) { - if ((impl_ != nullptr) && (type != nullptr)) { - const std::string str_type = type; - impl_->SetType(str_type); - } - return *this; -} - -FusionScopesResult::InnerNodeInfo &FusionScopesResult::InnerNodeInfo::InsertInput(const std::string &input_node, - int32_t peer_out_idx) { - if (impl_ != nullptr) { - impl_->InsertInput(input_node, peer_out_idx); - } - return *this; -} - -FusionScopesResult::InnerNodeInfo &FusionScopesResult::InnerNodeInfo::InsertInput(const char_t *input_node, - int32_t peer_out_idx) { - if ((impl_ != nullptr) && (input_node != nullptr)) { - const std::string str_input_node = input_node; - impl_->InsertInput(str_input_node, peer_out_idx); - } - return *this; -} - -FusionScopesResult::InnerNodeInfo &FusionScopesResult::InnerNodeInfo::InsertOutput(const std::string &output_node, - int32_t peer_in_idx) { - if (impl_ != nullptr) { - impl_->InsertOutput(output_node, peer_in_idx); - } - return *this; -} - -FusionScopesResult::InnerNodeInfo &FusionScopesResult::InnerNodeInfo::InsertOutput(const char_t *output_node, - int32_t peer_in_idx) { - if ((impl_ != nullptr) && (output_node != nullptr)) { - const std::string str_output_node = output_node; - impl_->InsertOutput(str_output_node, peer_in_idx); - } - return *this; -} - -ge::graphStatus FusionScopesResult::InnerNodeInfo::BuildInnerNode() { - if (impl_ != nullptr) { - return impl_->BuildOperator(); - } - return ge::GRAPH_PARAM_INVALID; -} - -ge::Operator *FusionScopesResult::InnerNodeInfo::MutableOperator() { - if (impl_ != nullptr) { - return impl_->MutableOperator(); - } - return nullptr; -} - -ge::graphStatus FusionScopesResult::InnerNodeInfo::SetInputFormat(const std::string &input_name, - const std::string &format) { - if (impl_ != nullptr) { - return impl_->SetInputFormat(input_name, format); - } - return ge::GRAPH_PARAM_INVALID; -} - -ge::graphStatus FusionScopesResult::InnerNodeInfo::SetInputFormat(const char_t *input_name, - const char_t *format) { - if ((impl_ != nullptr) && (input_name != nullptr) && (format != nullptr)) { - const std::string str_input_name = input_name; - const std::string str_format = format; - return impl_->SetInputFormat(str_input_name, str_format); - } - return ge::GRAPH_PARAM_INVALID; -} - -ge::graphStatus FusionScopesResult::InnerNodeInfo::SetOutputFormat(const std::string &output_name, - const std::string &format) { - if (impl_ != nullptr) { - return impl_->SetOutputFormat(output_name, format); - } - return ge::GRAPH_PARAM_INVALID; -} - -ge::graphStatus FusionScopesResult::InnerNodeInfo::SetOutputFormat(const char_t *output_name, - const char_t *format) { - if ((impl_ != nullptr) && (output_name != nullptr) && (format != nullptr)) { - const std::string str_output_name = output_name; - const std::string str_format = format; - return impl_->SetOutputFormat(str_output_name, str_format); - } - return ge::GRAPH_PARAM_INVALID; -} - -ge::graphStatus FusionScopesResult::InnerNodeInfo::SetDynamicInputFormat(const std::string &input_name, uint32_t index, - const std::string &format) { - if (impl_ != nullptr) { - return impl_->SetDynamicInputFormat(input_name, index, format); - } - return ge::GRAPH_PARAM_INVALID; -} - -ge::graphStatus FusionScopesResult::InnerNodeInfo::SetDynamicInputFormat(const char_t *input_name, uint32_t index, - const char_t *format) { - if ((impl_ != nullptr) && (input_name != nullptr) && (format != nullptr)) { - const std::string str_input_name = input_name; - const std::string str_format = format; - return impl_->SetDynamicInputFormat(str_input_name, index, str_format); - } - return ge::GRAPH_PARAM_INVALID; -} - -ge::graphStatus FusionScopesResult::InnerNodeInfo::SetDynamicOutputFormat(const std::string &output_name, - uint32_t index, const std::string &format) { - if (impl_ != nullptr) { - return impl_->SetDynamicOutputFormat(output_name, index, format); - } - return ge::GRAPH_PARAM_INVALID; -} - -ge::graphStatus FusionScopesResult::InnerNodeInfo::SetDynamicOutputFormat(const char_t *output_name, - uint32_t index, const char_t *format) { - if ((impl_ != nullptr) && (output_name != nullptr) && (format != nullptr)) { - const std::string str_output_name = output_name; - const std::string str_format = format; - return impl_->SetDynamicOutputFormat(str_output_name, index, str_format); - } - return ge::GRAPH_PARAM_INVALID; -} - -std::string FusionScopesResult::InnerNodeInfo::GetName() const { - if (impl_ != nullptr) { - return impl_->GetName(); - } - return ""; -} - -ge::graphStatus FusionScopesResult::InnerNodeInfo::GetName(AscendString &name) const { - if (impl_ != nullptr) { - name = AscendString(impl_->GetName().c_str()); - } - return GRAPH_SUCCESS; -} - -std::string FusionScopesResult::InnerNodeInfo::GetType() const { - if (impl_ != nullptr) { - return impl_->GetType(); - } - return ""; -} - -ge::graphStatus FusionScopesResult::InnerNodeInfo::GetType(AscendString &type) const { - if (impl_ != nullptr) { - type = AscendString(impl_->GetType().c_str()); - } - return GRAPH_SUCCESS; -} - -std::vector> FusionScopesResult::InnerNodeInfo::GetInputs() const { - const std::vector> tmp; - if (impl_ != nullptr) { - return impl_->GetInputs(); - } - return tmp; -} - -ge::graphStatus FusionScopesResult::InnerNodeInfo::GetInputs( - std::vector> &inputs) const { - std::vector> tmps; - if (impl_ != nullptr) { - tmps = impl_->GetInputs(); - } - for (auto &tmp : tmps) { - inputs.emplace_back(std::pair(AscendString(tmp.first.c_str()), tmp.second)); - } - return GRAPH_SUCCESS; -} - -std::vector> FusionScopesResult::InnerNodeInfo::GetOutputs() const { - const std::vector> tmp; - if (impl_ != nullptr) { - return impl_->GetOutputs(); - } - return tmp; -} - -ge::graphStatus FusionScopesResult::InnerNodeInfo::GetOutputs( - std::vector> &outputs) const { - std::vector> tmps; - if (impl_ != nullptr) { - tmps = impl_->GetOutputs(); - } - for (auto &tmp : tmps) { - outputs.emplace_back(std::pair(tmp.first.c_str(), tmp.second)); - } - return GRAPH_SUCCESS; -} - -void FusionScopesResult::FusionScopesResultImpl::AddNodes(const std::vector &nodes) { - (void)nodes_.insert(nodes_.cend(), nodes.cbegin(), nodes.cend()); -} - -void FusionScopesResult::FusionScopesResultImpl::InsertInputs(const std::string &inner_op_name, - const std::vector &index_map) { - (void)inputs_.insert(make_pair(inner_op_name, index_map)); -} -void FusionScopesResult::FusionScopesResultImpl::InsertOutputs(const std::string &inner_op_name, - const std::vector &index_map) { - (void)outputs_.insert(make_pair(inner_op_name, index_map)); -} - -bool FusionScopesResult::FusionScopesResultImpl::FindNodes(const std::string &node_name) const { - for (auto &node : nodes_) { - ge::AscendString name; - (void) node->GetName(name); - if (name.GetString() == node_name) { - return true; - } - } - return false; -} - -bool FusionScopesResult::FusionScopesResultImpl::FindScopes(const std::string &scope_name) const { - for (auto &scope : scopes_) { - AscendString name; - (void)scope->Name(name); - if ((std::string(name.GetString()).length() < scope_name.length()) && - (scope_name.find(std::string(name.GetString())) == 0U)) { - return true; - } - } - return false; -} - -FusionScopesResult::InnerNodeInfo *FusionScopesResult::FusionScopesResultImpl::AddInnerNode(const std::string &name, - const std::string &type) { - inner_node_infos_.emplace_back(InnerNodeInfo(name_.c_str(), name.c_str(), type.c_str())); - return &(inner_node_infos_[inner_node_infos_.size() - 1U]); -} - -FusionScopesResult::InnerNodeInfo *FusionScopesResult::FusionScopesResultImpl::MutableRecentInnerNode() { - const size_t size = inner_node_infos_.size(); - if (size >= 1U) { - return &(inner_node_infos_[size - 1U]); - } - return nullptr; -} - -FusionScopesResult::InnerNodeInfo *FusionScopesResult::FusionScopesResultImpl::MutableInnerNode(uint32_t index) { - if (static_cast(index) < inner_node_infos_.size()) { - return &(inner_node_infos_[static_cast(index)]); - } - return nullptr; -} - -FusionInnerNodesInfo FusionScopesResult::FusionScopesResultImpl::GetInnerNodesInfo() { - FusionInnerNodesInfo nodes_info; - for (auto &info : inner_node_infos_) { - ge::AscendString name; - (void) info.GetName(name); - ge::AscendString type; - (void) info.GetType(type); - std::vector> inputs; - (void) info.GetInputs(inputs); - std::vector> input_strings; - for (const auto &input : inputs) { - input_strings.emplace_back(input.first.GetString(), input.second); - } - std::vector> outputs; - (void) info.GetOutputs(outputs); - std::vector> output_strings; - for (const auto &output : outputs) { - output_strings.emplace_back(output.first.GetString(), output.second); - } - nodes_info.emplace_back( - std::make_tuple(name.GetString(), type.GetString(), input_strings, output_strings, info.MutableOperator())); - } - return nodes_info; -} - -ge::graphStatus FusionScopesResult::FusionScopesResultImpl::CheckInnerNodesInfo() { - size_t input_from_scope = 0U; - size_t output_to_scope = 0U; - std::set name_set; - for (const auto &info : inner_node_infos_) { - ge::AscendString name; - (void) info.GetName(name); - if (!(name_set.insert(name.GetString()).second)) { - GELOGE(ge::GRAPH_PARAM_INVALID, "There are duplicate internal node name, please check."); - return ge::GRAPH_PARAM_INVALID; - } - std::vector> inputs; - (void) info.GetInputs(inputs); - for (const auto &input : inputs) { - input_from_scope += - static_cast((std::string(input.first.GetString()) == kInputFromFusionScope) ? 1UL : 0UL); - } - std::vector> outputs; - (void) info.GetOutputs(outputs); - for (const auto &output : outputs) { - output_to_scope += - static_cast((std::string(output.first.GetString()) == kOutputToFusionScope) ? 1UL : 0UL); - } - } - size_t scope_input = 0U; - size_t scope_output = 0U; - for (const auto &input : inputs_) { - for (const auto &idx : input.second) { - scope_input += static_cast((idx != kFusionDisableIndex) ? 1UL : 0UL); - } - } - for (const auto &output : outputs_) { - for (const auto &idx : output.second) { - scope_output += static_cast((idx != kFusionDisableIndex) ? 1UL : 0UL); - } - } - if ((input_from_scope != scope_input) || (output_to_scope != scope_output)) { - GELOGE(ge::GRAPH_PARAM_INVALID, - "Input or Output mismatched, please check. " - "Inner input_from_scope:%zu, scope input:%zu, " - "inner output_to_scope:%zu, scope output:%zu.", - input_from_scope, scope_input, output_to_scope, scope_output); - return ge::GRAPH_PARAM_INVALID; - } - return ge::GRAPH_SUCCESS; -} - -FusionScopesResult::FusionScopesResult() {} - -Status FusionScopesResult::Init() { - impl_ = ge::ComGraphMakeUnique(); - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "Make unique_ptr of FusionScopesResultImpl failed."); - return ge::MEMALLOC_FAILED; - } - - return SUCCESS; -} - -FusionScopesResult::~FusionScopesResult() = default; - -void FusionScopesResult::SetName(const std::string &name) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "FusionScopesResult is not properly initialized."); - return; - } - impl_->SetName(name); -} - -void FusionScopesResult::SetName(const char_t *name) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "FusionScopesResult is not properly initialized."); - return; - } - std::string str_name; - if (name != nullptr) { - str_name = name; - } - impl_->SetName(str_name); -} - -void FusionScopesResult::SetType(const std::string &type) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "FusionScopesResult is not properly initialized."); - return; - } - impl_->SetType(type); -} - -void FusionScopesResult::SetType(const char_t *type) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "FusionScopesResult is not properly initialized."); - return; - } - std::string str_type; - if (type != nullptr) { - str_type = type; - } - impl_->SetType(str_type); -} - -void FusionScopesResult::SetDescription(const std::string &description) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "FusionScopesResult is not properly initialized."); - return; - } - impl_->SetDescription(description); -} - -void FusionScopesResult::SetDescription(const char_t *description) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "FusionScopesResult is not properly initialized."); - return; - } - std::string str_desc; - if (description != nullptr) { - str_desc = description; - } - impl_->SetDescription(str_desc); -} - -const std::string &FusionScopesResult::Name() const { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "FusionScopesResult is not properly initialized."); - static std::string name; - return name; - } - return impl_->Name(); -} - -Status FusionScopesResult::Name(AscendString &name) const { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "FusionScopesResult is not properly initialized."); - return ge::GRAPH_PARAM_INVALID; - } - name = AscendString(impl_->Name().c_str()); - return SUCCESS; -} - -const std::vector &FusionScopesResult::Nodes() const { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "FusionScopesResult is not properly initialized."); - static std::vector nodes; - return nodes; - } - return impl_->Nodes(); -} - -void FusionScopesResult::InsertInputs(const std::string &inner_op_name, const std::vector &index_map) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "FusionScopesResult is not properly initialized."); - return; - } - impl_->InsertInputs(inner_op_name, index_map); -} - -void FusionScopesResult::InsertInputs(const char_t *inner_op_name, const std::vector &index_map) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "FusionScopesResult is not properly initialized."); - return; - } - std::string op_name; - if (inner_op_name != nullptr) { - op_name = inner_op_name; - } - impl_->InsertInputs(op_name, index_map); -} - -void FusionScopesResult::InsertOutputs(const std::string &inner_op_name, const std::vector &index_map) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "FusionScopesResult is not properly initialized."); - return; - } - impl_->InsertOutputs(inner_op_name, index_map); -} - -void FusionScopesResult::InsertOutputs(const char_t *inner_op_name, const std::vector &index_map) { - std::string op_name; - if (inner_op_name != nullptr) { - op_name = inner_op_name; - } - impl_->InsertOutputs(op_name, index_map); -} - -FusionScopesResult::InnerNodeInfo *FusionScopesResult::AddInnerNode(const std::string &name, const std::string &type) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "FusionScopesResult is not properly initialized."); - return nullptr; - } - return impl_->AddInnerNode(name, type); -} - -FusionScopesResult::InnerNodeInfo *FusionScopesResult::AddInnerNode(const char_t *name, const char_t *type) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "FusionScopesResult is not properly initialized."); - return nullptr; - } - std::string str_name; - if (name != nullptr) { - str_name = name; - } - std::string str_type; - if (type != nullptr) { - str_type = type; - } - return impl_->AddInnerNode(str_name, str_type); -} - -FusionScopesResult::InnerNodeInfo *FusionScopesResult::MutableRecentInnerNode() { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "FusionScopesResult is not properly initialized."); - return nullptr; - } - return impl_->MutableRecentInnerNode(); -} - -FusionScopesResult::InnerNodeInfo *FusionScopesResult::MutableInnerNode(uint32_t index) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "FusionScopesResult is not properly initialized."); - return nullptr; - } - return impl_->MutableInnerNode(index); -} - -ge::graphStatus FusionScopesResult::CheckInnerNodesInfo() { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "FusionScopesResult is not properly initialized."); - return ge::GRAPH_PARAM_INVALID; - } - return impl_->CheckInnerNodesInfo(); -} - -Status ScopeTree::ScopeTreeImpl::Init() { - root_ = new (std::nothrow) Scope(); - if (root_ == nullptr) { - GELOGE(FAILED, "Alloc root scope failed."); - return FAILED; - } - const std::string name("root"); - if (root_->Init(name.c_str(), nullptr) != SUCCESS) { - GELOGE(FAILED, "Init root scope failed."); - return FAILED; - } - scopes_.push_back(root_); - return SUCCESS; -} - -ScopeTree::ScopeTreeImpl::~ScopeTreeImpl() { - if (root_ != nullptr) { - delete root_; - root_ = nullptr; - } -} - -void ScopeTree::ScopeTreeImpl::AddNodeToScope(ge::OperatorPtr &node_def) { - if (node_def == nullptr) { - GELOGE(PARAM_INVALID, "Input node_def is nullptr."); - return; - } - ge::AscendString node_name; - (void) node_def->GetName(node_name); - const std::vector scopes = SplitNodeName(node_name.GetString(), '/'); - Scope *super_scope = root_; - for (size_t i = 0U; i < scopes.size(); ++i) { - auto &impl = super_scope->impl_; - ge::AscendString node_type; - (void) node_def->GetOpType(node_type); - impl->OpsNumInc(node_type.GetString()); - - if (i == (scopes.size() - 1U)) { - impl->AddNode(node_def); - } else { - Scope *sub_scope = impl->GetSubScope(scopes[i]); - if (sub_scope == nullptr) { - sub_scope = new (std::nothrow) Scope(); - if (sub_scope == nullptr) { - GELOGE(FAILED, "Alloc Scope failed."); - return; - } - const auto ret = sub_scope->Init(scopes[i].c_str(), nullptr, super_scope); - if (ret != SUCCESS) { - GELOGE(FAILED, "Init Scope failed."); - delete sub_scope; - sub_scope = nullptr; - return; - } - scopes_.push_back(sub_scope); - impl->AddSubScope(sub_scope); - } - super_scope = sub_scope; - } - } -} - -std::vector ScopeTree::ScopeTreeImpl::SplitNodeName(const std::string &node_name, - const char_t delim) const { - std::vector items; - std::vector scopes; - if (node_name == "") { - return items; - } - - items = ge::StringUtils::Split(node_name, delim); - std::string scope; - for (uint32_t i = 0U; i < items.size(); ++i) { - if (items[static_cast(i)].length() == 0U) { - continue; - } - - if (i == 0U) { - scope = items[static_cast(i)]; - } else { - scope = scope + items[static_cast(i)]; - } - - if (i != (items.size() - 1U)) { - scope = scope + delim; - } - - scopes.push_back(scope); - } - - return scopes; -} - -ScopeTree::ScopeTree() {} - -ScopeTree::~ScopeTree() = default; - -Status ScopeTree::Init() { - impl_ = ge::ComGraphMakeUnique(); - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "Make unique_ptr of FusionScopesResultImpl failed."); - return ge::MEMALLOC_FAILED; - } - return impl_->Init(); -} - -const std::vector &ScopeTree::GetAllScopes() const { - return impl_->GetAllScopes(); -} - -Status ScopeGraph::ScopeGraphImpl::Init() { - scope_tree_ = new (std::nothrow) ScopeTree(); - if (scope_tree_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "Alloc scope tree failed."); - return ge::MEMALLOC_FAILED; - } - const Status ret = scope_tree_->Init(); - if (ret != SUCCESS) { - GELOGE(FAILED, "Scope tree init failed."); - return FAILED; - } - return SUCCESS; -} - -ScopeGraph::ScopeGraphImpl::~ScopeGraphImpl() noexcept { - if (scope_tree_ != nullptr) { - delete scope_tree_; - scope_tree_ = nullptr; - } - - for (auto &fusion_result : fusion_results_) { - if (fusion_result.second != nullptr) { - delete fusion_result.second; - fusion_result.second = nullptr; - } - } - - for (const auto &item : nodes_map_) { - item.second->BreakConnect(); - } -} - -void ScopeGraph::ScopeGraphImpl::BuildScopeGraph(domi::tensorflow::GraphDef *graph_def) { - if (graph_def == nullptr) { - GELOGE(PARAM_INVALID, "Input graph_def is nullptr."); - return; - } - GraphNodesInOut graph_nodes_in_out; - const auto status = GetGraphDefInOutMap(graph_def, graph_nodes_in_out); - if (status != SUCCESS) { - GELOGE(FAILED, "Failed to get node input output map for graph."); - return; - } - - for (int32_t i = 0; i < graph_def->node_size(); ++i) { - const domi::tensorflow::NodeDef *const node_def = graph_def->mutable_node(i); - ge::OperatorPtr op(new (std::nothrow) ge::Operator(node_def->name().c_str(), node_def->op().c_str())); - if (op == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "Make shared_ptr falied."); - return; - } - const auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(*op); - Status ret = domi::OperatorAutoMapping(node_def, *op); - if (ret != SUCCESS) { - GELOGE(FAILED, "Op: %s call auto mapping function failed.", op_desc->GetName().c_str()); - return; - } - - for (int32_t j = 0; j < node_def->input_size(); j++) { - ge::GeTensorDesc tensor_desc; - tensor_desc.SetName(node_def->input(j)); - (void)op_desc->AddInputDesc(tensor_desc); - } - ret = SetNodeInputOutputAttr(graph_nodes_in_out, op); - if (ret != SUCCESS) { - ge::AscendString op_name; - (void) op->GetName(op_name); - GELOGE(FAILED, "Failed to set input output attr, op:%s.", op_name.GetString()); - return; - } - AscendString name; - (void)op->GetName(name); - (void)nodes_map_.emplace(std::string(name.GetString()), op); - AscendString type; - (void)op->GetOpType(type); - if ((type.GetString() != kTfIdentityType) || (type.GetString() != kTfConstType)) { - auto &impl = scope_tree_->impl_; - impl->AddNodeToScope(op); - } - } -} - -void ScopeGraph::ScopeGraphImpl::AddFusionScopesResult(FusionScopesResult *result) { - if (result == nullptr) { - GELOGE(PARAM_INVALID, "Input params invalid, result is nullptr."); - return; - } - ge::AscendString result_name; - (void) result->Name(result_name); - fusion_results_[result_name.GetString()] = result; -} - -bool ScopeGraph::ScopeGraphImpl::IsFusionOpChild(const std::string &node_name, - std::vector &info_list) { - bool find = false; - for (auto &fusion_result : fusion_results_) { - const FusionScopesResult *const fusion_node = fusion_result.second; - auto &impl = fusion_node->impl_; - - if (impl->FindNodes(node_name) || impl->FindScopes(node_name)) { - ScopeFusionOpInfo info; - ge::AscendString name; - (void) fusion_node->Name(name); - info.fusion_node_name = name.GetString(); - info.fusion_op_type = impl->Type(); - info.node_name = node_name; - info.description = impl->Description(); - info.scope_pass = true; - info_list.push_back(info); - - find = true; - } - } - - return find; -} - -bool ScopeGraph::ScopeGraphImpl::FusionOpChildIgnore(const ScopeFusionOpInfo &info) { - if ((!(GetFusionResultInputOrOutput(info, true).empty())) || - (!(GetFusionResultInputOrOutput(info, false).empty()))) { - return false; - } - return true; -} - -std::vector ScopeGraph::ScopeGraphImpl::GetFusionResultInputOrOutput(const ScopeFusionOpInfo &info, - const bool input) { - std::vector indexs; - const auto fusion_iter = fusion_results_.find(info.fusion_node_name); - if (fusion_iter == fusion_results_.end()) { - GELOGE(FAILED, "Get fusion result failed, not found node:%s", info.fusion_node_name.c_str()); - return indexs; - } - - const FusionScopesResult *const fusion_node = fusion_iter->second; - std::map> inout_map; - auto &impl = fusion_node->impl_; - if (input) { - inout_map = impl->GetInputs(); - } else { - inout_map = impl->GetOutputs(); - } - - for (auto &iter : inout_map) { - const std::string input_name = iter.first; - const std::string op_name = (info.node_name.length() > input_name.length()) - ? info.node_name.substr(info.node_name.length() - input_name.length()) - : info.node_name; - if (input_name == op_name) { - (void)indexs.insert(indexs.cend(), iter.second.cbegin(), iter.second.cend()); - break; - } - } - - return indexs; -} - -bool ScopeGraph::ScopeGraphImpl::IsFusionOp(const domi::tensorflow::NodeDef *const node_def) { - if (node_def == nullptr) { - GELOGE(PARAM_INVALID, "Input node_def is nullptr."); - return false; - } - for (auto &fusion_result : fusion_results_) { - const FusionScopesResult *const fusion_node = fusion_result.second; - auto &impl = fusion_node->impl_; - AscendString name; - (void)fusion_node->Name(name); - if ((impl->Type() == node_def->op()) && (name.GetString() == node_def->name())) { - return true; - } - } - return false; -} - -Status ScopeGraph::ScopeGraphImpl::GetInputOrOutputIndex(const ScopeFusionOpInfo &info, const int32_t old_index, - const bool input, int32_t &new_index) { - if (old_index == -1) { - new_index = -1; - return SUCCESS; - } - - const std::vector indexs = GetFusionResultInputOrOutput(info, input); - GELOGD("GetNodeindex, node_name:%s, fusion_node_name:%s, fusion_op_type:%s, old_index:%d, size:%zu.", - info.node_name.c_str(), info.fusion_node_name.c_str(), info.fusion_op_type.c_str(), old_index, indexs.size()); - if (static_cast(indexs.size()) < (old_index + 1)) { - GELOGD("GetNodeindex fusionDisableIndex, node_name:%s, fusion_node_name:%s, fusion_op_type:%s, old_index:%d .", - info.node_name.c_str(), info.fusion_node_name.c_str(), info.fusion_op_type.c_str(), old_index); - new_index = kFusionDisableIndex; - } else { - new_index = indexs[static_cast(old_index)]; - } - GELOGD("RESULT: new index:%d.", new_index); - return SUCCESS; -} - -FusionScopesResult *ScopeGraph::ScopeGraphImpl::GetFusionScopesResults( - const domi::tensorflow::NodeDef *const node_def) const { - if (node_def == nullptr) { - return nullptr; - } - return GetFusionScopesResults(node_def->name()); -} - -FusionScopesResult *ScopeGraph::ScopeGraphImpl::GetFusionScopesResults(const string &node_name) const { - const auto iter = fusion_results_.find(node_name); - if (iter != fusion_results_.end()) { - return iter->second; - } else { - return nullptr; - } -} - -ScopeGraph::ScopeGraph() {} - -ScopeGraph::~ScopeGraph() = default; - -Status ScopeGraph::Init() { - impl_ = ge::ComGraphMakeUnique(); - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "Make unique_ptr of ScopeGraphImpl failed."); - return ge::MEMALLOC_FAILED; - } - return impl_->Init(); -} - -const ScopeTree *ScopeGraph::GetScopeTree() const { - return impl_->GetScopeTree(); -} - -const std::unordered_map &ScopeGraph::GetNodesMap() const { - return impl_->GetNodesMap(); -} - -Status ScopeGraph::GetNodesMap(std::unordered_map &nodes_map) const { - std::unordered_map tmps; - if (impl_ != nullptr) { - tmps = impl_->GetNodesMap(); - } - for (auto &tmp : tmps) { - const AscendString node(tmp.first.c_str()); - nodes_map[node] = tmp.second; - } - return SUCCESS; -} -} // namespace ge diff --git a/register/scope/scope_pass.cc b/register/scope/scope_pass.cc deleted file mode 100644 index d4d2f0ac0a7921989286c4e295f3db24b4317471..0000000000000000000000000000000000000000 --- a/register/scope/scope_pass.cc +++ /dev/null @@ -1,331 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include "register/scope/scope_pass_impl.h" -#include "register/scope/scope_graph_impl.h" -#include "register/scope/scope_pattern_impl.h" -#include "common/ge_common/debug/ge_log.h" -#include "graph/debug/ge_util.h" - -namespace ge { -ScopesResult::ScopesResult() { - impl_ = ge::ComGraphMakeUnique(); -} - -ScopesResult::ScopesResult(ScopesResult const &result) { - impl_ = ge::ComGraphMakeUnique(); - if ((impl_ == nullptr) || (result.impl_ == nullptr)) { - GELOGE(ge::MEMALLOC_FAILED, "ScopesResult is not properly initialized."); - return; - } - const std::vector &scopes = result.impl_->GetScopes(); - const std::vector &nodes = result.impl_->GetNodes(); - impl_->SetScopes(scopes); - impl_->SetNodes(nodes); -} -ScopesResult &ScopesResult::operator=(ScopesResult const &result) { - if (&result == this) { - return *this; - } - if ((impl_ == nullptr) || (result.impl_ == nullptr)) { - GELOGE(ge::MEMALLOC_FAILED, "ScopesResult is not properly initialized."); - return *this; - } - const std::vector &scopes = result.impl_->GetScopes(); - const std::vector &nodes = result.impl_->GetNodes(); - impl_->SetScopes(scopes); - impl_->SetNodes(nodes); - return *this; -} - -ScopesResult::~ScopesResult() = default; - -void ScopesResult::SetScopes(std::vector &scopes) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke SetScopes(), ScopesResult is not properly initialized."); - return; - } - - impl_->SetScopes(scopes); -} - -void ScopesResult::SetNodes(std::vector &nodes) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke SetNodes(), ScopesResult is not properly initialized."); - return; - } - - impl_->SetNodes(nodes); -} - -ScopeBasePass::ScopeBasePassImpl::~ScopeBasePassImpl() { - for (auto &scope_patterns : patterns_) { - for (auto &batch_patterns : scope_patterns) { - for (auto &pattern : batch_patterns) { - if (pattern != nullptr) { - delete pattern; - pattern = nullptr; - } - } - } - } -} - -Status ScopeBasePass::ScopeBasePassImpl::AddFusionScopesResultToScopeGraph( - const std::shared_ptr &scope_graph, std::vector &scope_results) const { - for (auto &rlt : scope_results) { - std::unique_ptr fusion_rlt = ComGraphMakeUnique(); - if (fusion_rlt == nullptr) { - GELOGE(FAILED, "Alloc fusion_rlt failed."); - return FAILED; - } - if (fusion_rlt->Init() != SUCCESS) { - GELOGE(FAILED, "Init fusion_rlt failed."); - return FAILED; - } - auto &impl_fusion_rlt = fusion_rlt->impl_; - auto &impl_scope_rlt = rlt.impl_; - if (impl_scope_rlt == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "ScopesResult is not properly initialized."); - continue; - } - - impl_fusion_rlt->AddNodes(impl_scope_rlt->GetNodes()); - impl_fusion_rlt->AddScopes(impl_scope_rlt->GetScopes()); - parent_->GenerateFusionResult(impl_scope_rlt->GetScopes(), fusion_rlt.get()); - if (impl_fusion_rlt->Type() == kScopeInvalidType) { - GELOGE(FAILED, "Failed to set inner node for fusion op %s.", impl_fusion_rlt->Type().c_str()); - return FAILED; - } - auto &impl_scope_graph = scope_graph->impl_; - impl_scope_graph->AddFusionScopesResult(fusion_rlt.release()); - } - - return SUCCESS; -} - -Status ScopeBasePass::ScopeBasePassImpl::Run(std::shared_ptr &scope_graph) { - GE_CHECK_NOTNULL(scope_graph); - const ScopeTree *const scope_tree = scope_graph->GetScopeTree(); - GE_CHECK_NOTNULL(scope_tree); - GE_CHECK_NOTNULL(parent_); - patterns_ = parent_->DefinePatterns(); - std::vector results; - if (!MatchAllBatches(scope_tree, results)) { - GELOGI("[scope_fusion] Scope pass %s's patterns is not matched and ignored.", parent_->PassName().c_str()); - return domi::SCOPE_NOT_CHANGED; - } - GELOGI("[scope_fusion] Scope pass %s's patterns is matched.", parent_->PassName().c_str()); - - std::vector scope_results; - Status ret = parent_->LastMatchScopesAndOPs(scope_graph, scope_results); - if (ret != SUCCESS) { - for (auto &result : results) { - GE_CHECK_NOTNULL(result); - auto &impl_scope = result->impl_; - impl_scope->ClearTypeAndSubType(); - } - GELOGW("[ScopeFusion][RunPass] Scope pass %s's patterns is ignored, because LastMatchScopesAndOPs failed.", - parent_->PassName().c_str()); - return domi::SCOPE_NOT_CHANGED; - } - - if (!results.empty()) { - ret = AddFusionScopesResultToScopeGraph(scope_graph, scope_results); - if (ret != SUCCESS) { - GELOGE(FAILED, "Scope pass %s add fusion scopes result to scope graph failed.", parent_->PassName().c_str()); - return domi::SCOPE_NOT_CHANGED; - } - } else { - GELOGI("[scope_fusion] Scope pass %s not match any scope.", parent_->PassName().c_str()); - } - - ret = PrintFusionScopeInfo(scope_graph); - if (ret != SUCCESS) { - GELOGI("[scope_fusion] Can not print scope pass %s fusion info.", parent_->PassName().c_str()); - return FAILED; - } - - return SUCCESS; -} - -bool ScopeBasePass::ScopeBasePassImpl::MatchAllBatches(const ScopeTree *scope_tree, std::vector &results) { - if (scope_tree == nullptr) { - GELOGE(PARAM_INVALID, "Input param [scope_tree] is nullptr."); - return false; - } - - for (auto &scope_patterns : patterns_) { - std::vector tmp_results; - std::vector last_results; - uint32_t batch_num = 0U; - for (auto &batch_patterns : scope_patterns) { - ++batch_num; - std::vector one_results; - const bool is_matched = MatchOneBatch(scope_tree, batch_patterns, one_results); - if (!is_matched) { - break; - } - if (batch_num == scope_patterns.size()) { - (void)last_results.insert(last_results.cend(), one_results.cbegin(), one_results.cend()); - } else { - (void)tmp_results.insert(tmp_results.cend(), one_results.cbegin(), one_results.cend()); - } - } - for (auto &tmp : tmp_results) { - bool rollback = true; - for (auto &result : last_results) { - AscendString result_name; - AscendString tmp_name; - (void) result->Name(result_name); - (void) tmp->Name(tmp_name); - if ((result_name.GetLength() <= tmp_name.GetLength()) && (tmp_name.Find(result_name) == 0U)) { - rollback = false; - break; - } - } - if (rollback) { - auto &impl = tmp->impl_; - impl->SetSubType(""); - } - } - (void)results.insert(results.cend(), last_results.cbegin(), last_results.cend()); - } - - return !(results.empty()); -} - -bool ScopeBasePass::ScopeBasePassImpl::MatchOneBatch(const ScopeTree *const scope_tree, - const std::vector &patternlist, - std::vector &results) const { - if (scope_tree == nullptr) { - GELOGE(PARAM_INVALID, "Input param [scope_tree] is nullptr"); - return false; - } - - int32_t find = 0; - auto &impl_scope_tree = scope_tree->impl_; - const Scope *const root = impl_scope_tree->Root(); - if (root != nullptr) { - auto &impl_scope = root->impl_; - const std::unordered_map &sub_scopes = impl_scope->GetSubScopes(); - for (auto &pattern : patternlist) { - for (auto &scope : sub_scopes) { - if (MatchOneScope(pattern, scope.second, results)) { - ++find; - } - } - } - } - - return (find > 0) ? true : false; -} - -bool ScopeBasePass::ScopeBasePassImpl::MatchOneScope(const ScopePattern *pattern, Scope *scope, - std::vector &results) const { - if ((pattern == nullptr) || (scope == nullptr)) { - GELOGE(PARAM_INVALID, "Input param is nullptr"); - return false; - } - auto &impl_scope_pattern = pattern->impl_; - if (impl_scope_pattern == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "ScopePattern is not properly initialized."); - return false; - } - if (impl_scope_pattern->Match(scope)) { - auto &scope_impl = scope->impl_; - scope_impl->SetSubType(impl_scope_pattern->SubType()); - results.push_back(scope); - return true; - } - int32_t find = 0; - std::stack scopes; - scopes.push(scope); - while (!scopes.empty()) { - const Scope *const current_scope = scopes.top(); - scopes.pop(); - auto ¤t_scope_impl = current_scope->impl_; - const std::unordered_map &sub_scopes = current_scope_impl->GetSubScopes(); - for (auto &sub_scope : sub_scopes) { - if (impl_scope_pattern->Match(sub_scope.second)) { - auto &sub_scope_impl = sub_scope.second->impl_; - sub_scope_impl->SetSubType(impl_scope_pattern->SubType()); - results.push_back(sub_scope.second); - ++find; - } else { - scopes.push(sub_scope.second); - } - } - } - return (find > 0) ? true : false; -} - -Status ScopeBasePass::ScopeBasePassImpl::PrintFusionScopeInfo(std::shared_ptr &scope_graph) const { - if (scope_graph == nullptr) { - GELOGE(PARAM_INVALID, "Input param scope_graph is nullptr."); - return PARAM_INVALID; - } - auto &impl_scope_graph = scope_graph->impl_; - const std::unordered_map &final_results = impl_scope_graph->FusionScopesResults(); - for (auto &result : final_results) { - if (result.second == nullptr) { - GELOGE(PARAM_INVALID, "Fusion scope is nullptr."); - return PARAM_INVALID; - } - AscendString name; - (void) result.second->Name(name); - GELOGI("FusionScope:%s", name.GetString()); - auto &impl = result.second->impl_; - const std::map> &inputs = impl->GetInputs(); - for (auto &input : inputs) { - const std::vector indexs = input.second; - for (const int32_t index : indexs) { - GELOGI("FusionScope input node:%s,%d", input.first.c_str(), index); - } - } - - const std::map> &outputs = impl->GetOutputs(); - for (auto &output : outputs) { - const std::vector indexs = output.second; - for (const int32_t index : indexs) { - GELOGI("FusionScope output node:%s,%d", output.first.c_str(), index); - } - } - - for (auto &scope : impl->Scopes()) { - if (scope == nullptr) { - GELOGE(PARAM_INVALID, "Scope in fusion scope is nullptr."); - return PARAM_INVALID; - } - AscendString scope_name; - (void) scope->Name(scope_name); - GELOGI("FusionScope GetScope:%s", scope_name.GetString()); - } - - for (auto &node : result.second->Nodes()) { - if (node == nullptr) { - GELOGE(PARAM_INVALID, "Node in scope is nullptr."); - return PARAM_INVALID; - } - AscendString node_name; - (void) node->GetName(node_name); - GELOGI("FusionScope Node:%s", node_name.GetString()); - } - } - return SUCCESS; -} - -ScopeBasePass::ScopeBasePass() { - impl_ = ge::ComGraphMakeUnique(this); -} - -ScopeBasePass::~ScopeBasePass() = default; -} // namespace ge diff --git a/register/scope/scope_pass_registry.cc b/register/scope/scope_pass_registry.cc deleted file mode 100644 index fc54fc55f1d960f83be58e6662165052929da4ed..0000000000000000000000000000000000000000 --- a/register/scope/scope_pass_registry.cc +++ /dev/null @@ -1,142 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include "register/scope/scope_pass_registry_impl.h" -#include "graph/debug/ge_log.h" -#include "graph/debug/ge_util.h" - -namespace ge { -struct CreatePassFnPack { - bool is_enable; - ScopeFusionPassRegistry::CreateFn create_fn; -}; - -void ScopeFusionPassRegistry::ScopeFusionPassRegistryImpl::RegisterScopeFusionPass( - const std::string &pass_name, ScopeFusionPassRegistry::CreateFn create_fn, bool is_general) { - const std::lock_guard lock(mu_); - const auto iter = std::find(pass_names_.begin(), pass_names_.end(), pass_name); - if (iter != pass_names_.end()) { - GELOGW("[Register][Check] ScopeFusionPass %s already exists and will not be overwritten", - pass_name.c_str()); - return; - } - - CreatePassFnPack create_fn_pack; - create_fn_pack.is_enable = is_general; - create_fn_pack.create_fn = create_fn; - create_fn_packs_[pass_name] = create_fn_pack; - pass_names_.push_back(pass_name); - GELOGI("Register pass name = %s, is_enable = %s.", pass_name.c_str(), is_general ? "true" : "false"); -} - -ScopeFusionPassRegistry::CreateFn ScopeFusionPassRegistry::ScopeFusionPassRegistryImpl::GetCreateFn( - const std::string &pass_name) { - const std::lock_guard lock(mu_); - const auto it = create_fn_packs_.find(pass_name); - if (it == create_fn_packs_.end()) { - GELOGW("[Get][CreateFun] ScopeFusionPass %s not registered", pass_name.c_str()); - return nullptr; - } - - CreatePassFnPack &create_fn_pack = it->second; - if (create_fn_pack.is_enable) { - return create_fn_pack.create_fn; - } else { - GELOGW("[Get][CreateFun] ScopeFusionPass %s is disabled", pass_name.c_str()); - return nullptr; - } -} - -std::vector ScopeFusionPassRegistry::ScopeFusionPassRegistryImpl::GetAllRegisteredPasses() { - const std::lock_guard lock(mu_); - std::vector all_passes; - for (size_t i = 0U; i < pass_names_.size(); ++i) { - if (create_fn_packs_[pass_names_[i]].is_enable) { - all_passes.push_back(pass_names_[i]); - } - } - - return all_passes; -} - -bool ScopeFusionPassRegistry::ScopeFusionPassRegistryImpl::SetPassEnableFlag( - const std::string pass_name, const bool flag) { - const std::lock_guard lock(mu_); - const auto it = create_fn_packs_.find(pass_name); - if (it == create_fn_packs_.end()) { - GELOGW("[Set][EnableFlag] ScopeFusionPass %s not registered", pass_name.c_str()); - return false; - } - - CreatePassFnPack &create_fn_pack = it->second; - create_fn_pack.is_enable = flag; - GELOGI("enable flag of scope fusion pass:%s is set with %s.", pass_name.c_str(), flag ? "true" : "false"); - - return true; -} - -std::unique_ptr ScopeFusionPassRegistry::ScopeFusionPassRegistryImpl::CreateScopeFusionPass( - const std::string &pass_name) { - const auto create_fn = GetCreateFn(pass_name); - if (create_fn == nullptr) { - GELOGD("Create scope fusion pass failed, pass name = %s.", pass_name.c_str()); - return nullptr; - } - GELOGI("Create scope fusion pass, pass name = %s.", pass_name.c_str()); - return std::unique_ptr(create_fn()); -} - -ScopeFusionPassRegistry::ScopeFusionPassRegistry() { - impl_ = ge::ComGraphMakeUnique(); -} - -ScopeFusionPassRegistry::~ScopeFusionPassRegistry() = default; - -ScopeFusionPassRegistry& ScopeFusionPassRegistry::GetInstance() { - static ScopeFusionPassRegistry instance; - return instance; -} - -void ScopeFusionPassRegistry::RegisterScopeFusionPass(const std::string &pass_name, CreateFn create_fn, - bool is_general) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "Failed to register %s, ScopeFusionPassRegistry is not properly initialized.", - pass_name.c_str()); - return; - } - impl_->RegisterScopeFusionPass(pass_name, create_fn, is_general); -} - -void ScopeFusionPassRegistry::RegisterScopeFusionPass(const char_t *pass_name, CreateFn create_fn, - bool is_general) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "Failed to register %s, ScopeFusionPassRegistry is not properly initialized.", - pass_name); - return; - } - std::string str_pass_name; - if (pass_name != nullptr) { - str_pass_name = pass_name; - } - impl_->RegisterScopeFusionPass(str_pass_name, create_fn, is_general); -} - -ScopeFusionPassRegistrar::ScopeFusionPassRegistrar(const char_t *pass_name, ScopeBasePass *(*create_fn)(), - bool is_general) { - if (pass_name == nullptr) { - GELOGE(PARAM_INVALID, "Failed to register scope fusion pass, pass name is null."); - return; - } - - ScopeFusionPassRegistry::GetInstance().RegisterScopeFusionPass(pass_name, create_fn, is_general); -} -} // namespace ge diff --git a/register/scope/scope_pattern.cc b/register/scope/scope_pattern.cc deleted file mode 100644 index 727ec2105aa6e8429d0e94e1ffeeeb8eb7ef4eb8..0000000000000000000000000000000000000000 --- a/register/scope/scope_pattern.cc +++ /dev/null @@ -1,542 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/scope/scope_pattern_impl.h" -#include "register/scope/scope_graph_impl.h" -#include "common/ge_common/debug/ge_log.h" -#include "graph/debug/ge_util.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/attr_utils.h" -#include "graph/types.h" - -namespace ge { -ScopeAttrValue::ScopeAttrValue() { - impl_ = ge::ComGraphMakeUnique(); -} - -ScopeAttrValue::ScopeAttrValue(ScopeAttrValue const &attr_value) { - impl_ = ge::ComGraphMakeUnique(); - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "ScopeAttrValue is not properly initialized."); - return; - } - impl_->SetIntValue(attr_value.impl_->GetIntValue()); - impl_->SetFloatValue(attr_value.impl_->GetFloatValue()); - impl_->SetStringValue(attr_value.impl_->GetStrValue()); - impl_->SetBoolValue(attr_value.impl_->GetBoolValue()); -} - -ScopeAttrValue &ScopeAttrValue::operator=(ScopeAttrValue const &attr_value) { - if (&attr_value == this) { - return *this; - } - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "ScopeAttrValue is not properly initialized."); - return *this; - } - impl_->SetIntValue(attr_value.impl_->GetIntValue()); - impl_->SetFloatValue(attr_value.impl_->GetFloatValue()); - impl_->SetStringValue(attr_value.impl_->GetStrValue()); - impl_->SetBoolValue(attr_value.impl_->GetBoolValue()); - return *this; -} - -ScopeAttrValue::~ScopeAttrValue() = default; - -void ScopeAttrValue::SetIntValue(int64_t value) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke SetIntValue(), ScopeAttrValue is not properly initialized."); - return; - } - impl_->SetIntValue(value); -} - -void ScopeAttrValue::SetFloatValue(float32_t value) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke SetFloatValue(), ScopeAttrValue is not properly initialized."); - return; - } - impl_->SetFloatValue(value); -} - -void ScopeAttrValue::SetStringValue(std::string value) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke SetStringValue(), ScopeAttrValue is not properly initialized."); - return; - } - impl_->SetStringValue(value); -} - -void ScopeAttrValue::SetStringValue(const char_t *value) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke SetStringValue(), ScopeAttrValue is not properly initialized."); - return; - } - std::string str_value; - if (value != nullptr) { - str_value = value; - } - impl_->SetStringValue(str_value); -} - -void ScopeAttrValue::SetBoolValue(bool value) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke SetBoolValue(), ScopeAttrValue is not properly initialized."); - return; - } - impl_->SetBoolValue(value); -} - -bool NodeOpTypeFeature::NodeOpTypeFeatureImpl::Match(const Scope *const scope) { - if (scope == nullptr) { - GELOGE(PARAM_INVALID, "Input scope is nullptr."); - return false; - } - auto &impl = scope->impl_; - AscendString scope_name; - (void) scope->Name(scope_name); - if (step_ == 0) { - if (impl->GetOpTypeNum(node_type_) == num_) { - GELOGI("NodeOpTypeFeature, node type:%s, num:%ld, match scope:%s", node_type_.c_str(), num_, - scope_name.GetString()); - return true; - } - } else { - if ((impl->GetOpTypeNum(node_type_) != -1) && ((impl->GetOpTypeNum(node_type_) % step_) == num_)) { - GELOGI("NodeOpTypeFeature, node type:%s, num:%ld, match scope:%s", node_type_.c_str(), num_, - scope_name.GetString()); - return true; - } - } - - return false; -} - -NodeOpTypeFeature::NodeOpTypeFeature(std::string nodeType, int32_t num, int32_t step) - : ScopeBaseFeature() { - impl_ = ge::ComGraphMakeUnique(nodeType, num, step); -} - -NodeOpTypeFeature::NodeOpTypeFeature(const char_t *node_type, int32_t num, int32_t step) - : ScopeBaseFeature() { - std::string op_type; - if (node_type != nullptr) { - op_type = node_type; - } - impl_ = ge::ComGraphMakeUnique(op_type, num, step); -} - -NodeOpTypeFeature::NodeOpTypeFeature(NodeOpTypeFeature const &feature) : ScopeBaseFeature() { - impl_ = ge::ComGraphMakeUnique(feature.impl_->node_type_, feature.impl_->num_, - feature.impl_->step_); -} - -NodeOpTypeFeature &NodeOpTypeFeature::operator=(NodeOpTypeFeature const &feature) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "NodeOpTypeFeature is not properly initialized."); - return *this; - } - - if (&feature == this) { - return *this; - } - - impl_->node_type_ = feature.impl_->node_type_; - impl_->num_ = feature.impl_->num_; - impl_->step_ = feature.impl_->step_; - return *this; -} - -NodeOpTypeFeature::~NodeOpTypeFeature() = default; - -bool NodeOpTypeFeature::Match(const Scope *scope) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke Match(), NodeOpTypeFeature is not properly initialized."); - return false; - } - - return impl_->Match(scope); -} - -bool NodeAttrFeature::NodeAttrFeatureImpl::Match(const Scope *scope) { - if ((scope == nullptr) || (scope->impl_ == nullptr)) { - GELOGE(ge::PARAM_INVALID, "Input scope is nullptr."); - return false; - } - auto &impl = scope->impl_; - const std::vector &nodes = impl->Nodes(); - for (auto &node_op : nodes) { - ge::AscendString op_type; - (void) node_op->GetOpType(op_type); - if (op_type.GetString() != node_type_) { - continue; - } - const auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(*node_op); - if (op_desc == nullptr) { - GELOGE(ge::PARAM_INVALID, "Op desc is nullptr."); - return false; - } - - Status result = SUCCESS; - switch (datatype_) { - case ge::DT_FLOAT: - result = CheckNodeAttrFeatureData(0.0F, op_desc, scope); - break; - case ge::DT_INT32: - result = CheckNodeAttrFeatureData(static_cast(0), op_desc, scope); - break; - case ge::DT_STRING: - result = CheckNodeAttrFeatureData("", op_desc, scope); - break; - case ge::DT_BOOL: - result = CheckNodeAttrFeatureData(false, op_desc, scope); - break; - default: - break; - } - if (result != FAILED) { - return (result == PARAM_INVALID) ? false : true; - } - } - return false; -} - -Status NodeAttrFeature::NodeAttrFeatureImpl::CheckNodeAttrFeatureData(const bool init_value, - const ge::OpDescPtr &op_desc, - const Scope *const scope) { - bool value = init_value; - if (!ge::AttrUtils::GetBool(op_desc, attr_name_, value)) { - GELOGE(ge::PARAM_INVALID, "op:%s %s attr is null", op_desc->GetName().c_str(), attr_name_.c_str()); - return PARAM_INVALID; - } - if (attr_value_.impl_->GetBoolValue() == value) { - AscendString scope_name; - (void) scope->Name(scope_name); - GELOGI("NodeAttrFeature, match scope:%s", scope_name.GetString()); - return SUCCESS; - } - return FAILED; -} - -Status NodeAttrFeature::NodeAttrFeatureImpl::CheckNodeAttrFeatureData(const std::string &init_value, - const ge::OpDescPtr &op_desc, - const Scope *const scope) { - std::string value = init_value; - if (!ge::AttrUtils::GetStr(op_desc, attr_name_, value)) { - GELOGE(ge::PARAM_INVALID, "op:%s %s attr is null", op_desc->GetName().c_str(), attr_name_.c_str()); - return PARAM_INVALID; - } - if (attr_value_.impl_->GetStrValue() == value) { - AscendString scope_name; - (void) scope->Name(scope_name); - GELOGI("NodeAttrFeature, match scope:%s", scope_name.GetString()); - return SUCCESS; - } - return FAILED; -} - -Status NodeAttrFeature::NodeAttrFeatureImpl::CheckNodeAttrFeatureData(const int64_t init_value, - const ge::OpDescPtr &op_desc, - const Scope *const scope) { - int64_t value = init_value; - if (!ge::AttrUtils::GetInt(op_desc, attr_name_, value)) { - GELOGE(ge::PARAM_INVALID, "op:%s %s attr is null", op_desc->GetName().c_str(), attr_name_.c_str()); - return PARAM_INVALID; - } - if (attr_value_.impl_->GetIntValue() == value) { - AscendString scope_name; - (void) scope->Name(scope_name); - GELOGI("NodeAttrFeature, match scope:%s", scope_name.GetString()); - return SUCCESS; - } - return FAILED; -} - -Status NodeAttrFeature::NodeAttrFeatureImpl::CheckNodeAttrFeatureData(const float32_t init_value, - const ge::OpDescPtr &op_desc, - const Scope *const scope) { - float32_t value = init_value; - if (!ge::AttrUtils::GetFloat(op_desc, attr_name_, value)) { - GELOGE(ge::PARAM_INVALID, "op:%s %s attr is null", op_desc->GetName().c_str(), attr_name_.c_str()); - return PARAM_INVALID; - } - - if (FloatIsEqual(attr_value_.impl_->GetFloatValue(), value)) { - AscendString scope_name; - (void) scope->Name(scope_name); - GELOGI("NodeAttrFeature, match scope:%s", scope_name.GetString()); - return SUCCESS; - } - return FAILED; -} - -NodeAttrFeature::NodeAttrFeature(std::string nodeType, std::string attr_name, - ge::DataType datatype, ScopeAttrValue &attr_value) - : ScopeBaseFeature() { - impl_ = ge::ComGraphMakeUnique(nodeType, attr_name, datatype, attr_value); -} - -NodeAttrFeature::NodeAttrFeature(const char_t *node_type, const char_t *attr_name, - ge::DataType data_type, ScopeAttrValue &attr_value) - : ScopeBaseFeature() { - std::string str_node_type; - if (node_type != nullptr) { - str_node_type = node_type; - } - std::string str_attr_name; - if (attr_name != nullptr) { - str_attr_name = attr_name; - } - impl_ = ge::ComGraphMakeUnique(str_node_type, str_attr_name, data_type, attr_value); -} - -NodeAttrFeature::NodeAttrFeature(NodeAttrFeature const &feature) : ScopeBaseFeature() { - impl_ = ge::ComGraphMakeUnique(feature.impl_->node_type_, feature.impl_->attr_name_, - feature.impl_->datatype_, feature.impl_->attr_value_); -} - -NodeAttrFeature &NodeAttrFeature::operator=(NodeAttrFeature const &feature) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "NodeAttrFeature is not properly initialized."); - return *this; - } - if (&feature == this) { - return *this; - } - impl_->node_type_ = feature.impl_->node_type_; - impl_->attr_name_ = feature.impl_->attr_name_; - impl_->datatype_ = feature.impl_->datatype_; - impl_->attr_value_ = feature.impl_->attr_value_; - return *this; -} - -NodeAttrFeature::~NodeAttrFeature() = default; - -bool NodeAttrFeature::Match(const Scope *scope) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke Match(), NodeAttrFeature is not properly initialized."); - return false; - } - - return impl_->Match(scope); -} - -bool ScopeFeature::ScopeFeatureImpl::SubScopesMatch(const std::vector &scopes) { - int32_t count = 0; - bool sub_scope_name_matched = false; - for (auto &scp : scopes) { - AscendString scp_sub_type; - (void) scp->SubType(scp_sub_type); - if ((sub_type_.length() > 0UL) && (sub_type_ == scp_sub_type.GetString())) { - ++count; - } - if (sub_scope_name_matched) { - continue; - } - auto &sub_impl = scp->impl_; - AscendString name; - (void) scp->Name(name); - sub_scope_name_matched = (sub_scope_mask_.length() > 0UL) && (sub_scope_mask_.length() < name.GetLength()) && - (sub_impl->LastName().find(sub_scope_mask_) != std::string::npos); - } - - if ((sub_type_.length() > 0UL) && (step_ == 0) && (count != num_)) { - return false; - } - if ((sub_scope_mask_.length() > 0UL) && (!sub_scope_name_matched)) { - return false; - } - - return true; -} - -bool ScopeFeature::ScopeFeatureImpl::Match(const Scope *const scope) { - auto &impl = scope->impl_; - AscendString scope_name; - (void) scope->Name(scope_name); - GELOGD("NodeAttrFeature, match scope:%s", scope_name.GetString()); - if (suffix_.length() > scope_name.GetLength()) { - return false; - } - if (suffix_.length() > 0UL) { - const std::string &last_name = impl->LastName(); - if (suffix_ != last_name) { - return false; - } - } - - const std::vector &scopes = impl->GetAllSubScopes(); - if (SubScopesMatch(scopes)) { - GELOGI("ScopeFeature, match scope:%s", scope_name.GetString()); - return true; - } - - return false; -} - -ScopeFeature::ScopeFeature(std::string sub_type, int32_t num, std::string suffix, - std::string sub_scope_mask, int32_t step) - : ScopeBaseFeature() { - impl_ = ge::ComGraphMakeUnique(sub_type, num, suffix, sub_scope_mask, step); -} - -ScopeFeature::ScopeFeature(const char_t *sub_type, int32_t num, const char_t *suffix, - const char_t *sub_scope_mask, int32_t step) - : ScopeBaseFeature() { - std::string str_sub_type; - if (sub_type != nullptr) { - str_sub_type = sub_type; - } - std::string str_suffix; - if (suffix != nullptr) { - str_suffix = suffix; - } - std::string str_sub_scope_mask; - if (sub_scope_mask != nullptr) { - str_sub_scope_mask = sub_scope_mask; - } - impl_ = ge::ComGraphMakeUnique(str_sub_type, num, str_suffix, str_sub_scope_mask, step); -} - -ScopeFeature::ScopeFeature(ScopeFeature const &feature) : ScopeBaseFeature() { - impl_ = ge::ComGraphMakeUnique(feature.impl_->sub_type_, feature.impl_->num_, - feature.impl_->suffix_, feature.impl_->sub_scope_mask_, - feature.impl_->step_); -} - -ScopeFeature &ScopeFeature::operator=(ScopeFeature const &feature) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "ScopeFeature is not properly initialized."); - return *this; - } - if (&feature == this) { - return *this; - } - impl_->sub_type_ = feature.impl_->sub_type_; - impl_->num_ = feature.impl_->num_; - impl_->suffix_ = feature.impl_->suffix_; - impl_->sub_scope_mask_ = feature.impl_->sub_scope_mask_; - impl_->step_ = feature.impl_->step_; - return *this; -} - -ScopeFeature::~ScopeFeature() = default; - -bool ScopeFeature::Match(const Scope *scope) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke Match(), ScopeFeature is not properly initialized."); - return false; - } - - return impl_->Match(scope); -} - -bool ScopePattern::ScopePatternImpl::Match(const Scope *scope) const { - if (scope == nullptr) { - GELOGE(PARAM_INVALID, "Input scope is nullptr."); - return false; - } - for (auto feature : node_optype_features_) { - if (!feature.Match(scope)) { - return false; - } - } - - for (auto feature : node_attr_features_) { - if (!feature.Match(scope)) { - return false; - } - } - - for (auto feature : scopes_features_) { - if (!feature.Match(scope)) { - return false; - } - } - - // If there is a _Retval node in the scope, the scope will not be fused. - NodeOpTypeFeature comm_node_feature = NodeOpTypeFeature("_Retval", -1, 0); - if (!comm_node_feature.Match(scope)) { - return false; - } - - return true; -} - -void ScopePattern::ScopePatternImpl::SetSubType(const std::string &sub_type) { - sub_type_ = sub_type; -} - -void ScopePattern::ScopePatternImpl::AddNodeOpTypeFeature(NodeOpTypeFeature &feature) { - node_optype_features_.push_back(feature); -} - -void ScopePattern::ScopePatternImpl::AddNodeAttrFeature(NodeAttrFeature &feature) { - node_attr_features_.push_back(feature); -} - -void ScopePattern::ScopePatternImpl::AddScopeFeature(ScopeFeature &feature) { - scopes_features_.push_back(feature); -} - -ScopePattern::ScopePattern() { - impl_ = ge::ComGraphMakeUnique(); -} - -ScopePattern::~ScopePattern() = default; - -ScopePattern &ScopePattern::SetSubType(const std::string &sub_type) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke SetSubType(), ScopePattern is not properly initialized."); - return *this; - } - impl_->SetSubType(sub_type); - return *this; -} - -ScopePattern &ScopePattern::SetSubType(const char_t *sub_type) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke SetSubType(), ScopePattern is not properly initialized."); - return *this; - } - std::string str_sub_type; - if (sub_type != nullptr) { - str_sub_type = sub_type; - } - impl_->SetSubType(str_sub_type); - return *this; -} - -ScopePattern &ScopePattern::AddNodeOpTypeFeature(NodeOpTypeFeature feature) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke AddNodeOpTypeFeature(), ScopePattern is not properly initialized."); - return *this; - } - impl_->AddNodeOpTypeFeature(feature); - return *this; -} - -ScopePattern &ScopePattern::AddNodeAttrFeature(NodeAttrFeature feature) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke AddNodeAttrFeature(), ScopePattern is not properly initialized."); - return *this; - } - impl_->AddNodeAttrFeature(feature); - return *this; -} - -ScopePattern &ScopePattern::AddScopeFeature(ScopeFeature feature) { - if (impl_ == nullptr) { - GELOGE(ge::MEMALLOC_FAILED, "Failed to invoke AddScopeFeature(), ScopePattern is not properly initialized."); - return *this; - } - impl_->AddScopeFeature(feature); - return *this; -} -} // namespace ge diff --git a/register/scope/scope_util.cc b/register/scope/scope_util.cc deleted file mode 100644 index fb8e03b04aadc1f1f5c1ab485290d553e7b83ca5..0000000000000000000000000000000000000000 --- a/register/scope/scope_util.cc +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "external/register/scope/scope_fusion_pass_register.h" -#include "common/ge_common/string_util.h" - -namespace ge { -std::string ScopeUtil::StringReplaceAll(std::string str, const std::string &old_value, const std::string &new_value) { - return ge::StringUtils::ReplaceAll(str, old_value, new_value); -} - -AscendString ScopeUtil::StringReplaceAll(const char_t *str, const char_t *old_value, const char_t *new_value) { - std::string tmp_str; - if (str != nullptr) { - tmp_str = str; - } - std::string tmp_old_value; - if (old_value != nullptr) { - tmp_old_value = old_value; - } - std::string tmp_new_value; - if (new_value != nullptr) { - tmp_new_value = new_value; - } - const std::string ret = ge::StringUtils::ReplaceAll(tmp_str, tmp_old_value, tmp_new_value); - return AscendString(ret.c_str()); -} - -void ScopeUtil::FreeScopePatterns(ScopeFusionPatterns &patterns) { - for (auto &batch_pattern : patterns) { - FreeOneBatchPattern(batch_pattern); - } - patterns.clear(); -} - -void ScopeUtil::FreeOneBatchPattern(std::vector &one_batch_pattern) { - for (auto &one_pattern : one_batch_pattern) { - if (one_pattern != nullptr) { - delete one_pattern; - one_pattern = nullptr; - } - } - one_batch_pattern.clear(); -} -} // namespace ge diff --git a/register/shape_inference.cc b/register/shape_inference.cc deleted file mode 100644 index 59d6fa31d56848b5959779e6e8a8c8c51f85dc12..0000000000000000000000000000000000000000 --- a/register/shape_inference.cc +++ /dev/null @@ -1,774 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/shape_inference.h" -#include "exe_graph/lowering/kernel_run_context_builder.h" -#include "graph/debug/ge_util.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/operator_factory_impl.h" -#include "graph/compiler_def.h" -#include "graph/utils/node_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/transformer_utils.h" -#include "register/op_impl_space_registry.h" -#include "common/checker.h" -#include "graph/utils/inference_rule.h" - -namespace gert { -namespace { -using Index = struct InputIndex { - size_t input_index; - size_t invalid_index_num; -}; -bool IsInputDescValid(const ge::GeTensorDesc &input_desc, size_t &invalid_index_num) { - if (input_desc.IsValid() != ge::GRAPH_SUCCESS) { - if (invalid_index_num < std::numeric_limits::max()) { - invalid_index_num++; - } - return false; - } - return true; -} - -void GetStorageShape(const ge::GeTensorDesc &input_desc, gert::StorageShape &storage_shape) { - const auto &dims = input_desc.GetOriginShape().GetDims(); - for (const auto &dim : dims) { - (void)storage_shape.MutableOriginShape().AppendDim(dim); - (void)storage_shape.MutableStorageShape().AppendDim(dim); - } -} - -void GetMinMaxStorageShape(const ge::GeTensorDesc &input_desc, gert::StorageShape &min_storage_shape, - gert::StorageShape &max_storage_shape) { - auto ge_shape = input_desc.GetShape(); - if (ge_shape.IsUnknownShape()) { - std::vector> shape_range; - (void)input_desc.GetShapeRange(shape_range); - for (size_t j = 0UL; j < shape_range.size(); ++j) { - (void)min_storage_shape.MutableOriginShape().AppendDim(shape_range[j].first); - (void)min_storage_shape.MutableStorageShape().AppendDim(shape_range[j].first); - (void)max_storage_shape.MutableOriginShape().AppendDim(shape_range[j].second); - (void)max_storage_shape.MutableStorageShape().AppendDim(shape_range[j].second); - } - } else { - const auto &dims = input_desc.GetOriginShape().GetDims(); - for (const auto &dim : dims) { - (void)min_storage_shape.MutableOriginShape().AppendDim(dim); - (void)min_storage_shape.MutableStorageShape().AppendDim(dim); - (void)max_storage_shape.MutableOriginShape().AppendDim(dim); - (void)max_storage_shape.MutableStorageShape().AppendDim(dim); - } - } -} - -ge::graphStatus GetTensorAddress(const ge::Operator &op, const ge::OpDescPtr &op_desc, - const Index &index, TensorAddress &address, - std::vector> &ge_tensors_holder) { - const auto *const space_registry = DefaultOpImplSpaceRegistry::GetInstance() - .GetDefaultSpaceRegistry(op_desc->GetOppImplVersion()).get(); - GE_ASSERT_NOTNULL(space_registry); - - const auto &functions = space_registry->GetOpImpl(op_desc->GetType()); - const size_t instance_index = index.input_index - index.invalid_index_num; - // check valid map - const auto valid_op_ir_map = ge::OpDescUtils::GetInputIrIndexes2InstanceIndexesPairMap(op_desc); - if (valid_op_ir_map.empty()) { - return ge::GRAPH_PARAM_INVALID; - } - size_t ir_index; - GE_ASSERT_GRAPH_SUCCESS(ge::OpDescUtils::GetInputIrIndexByInstanceIndex(op_desc, instance_index, ir_index), - "[Get][InputIrIndexByInstanceIndex] failed, op[%s], instance index[%zu], input_index[%zu]", - op_desc->GetName().c_str(), instance_index, index.input_index); - GE_ASSERT_NOTNULL(functions); - if (functions->IsInputDataDependency(ir_index)) { - ge_tensors_holder[index.input_index] = ge::ComGraphMakeUnique(); - GE_ASSERT_NOTNULL(ge_tensors_holder[index.input_index], "Create ge tensor holder inputs failed."); - const auto index_name = op_desc->GetInputNameByIndex(static_cast(index.input_index)); - if (op.GetInputConstData(index_name.c_str(), *(ge_tensors_holder[index.input_index].get())) == ge::GRAPH_SUCCESS) { - address = ge_tensors_holder[index.input_index]->GetData(); - } - } - return ge::GRAPH_SUCCESS; -} - -bool IsTensorDependencyValid(const ge::Operator &op, const ge::OpDescPtr &op_desc, - const size_t input_index, const size_t invalid_index_num) { - const auto *const space_registry = DefaultOpImplSpaceRegistry::GetInstance() - .GetDefaultSpaceRegistry(op_desc->GetOppImplVersion()).get(); - GE_ASSERT_NOTNULL(space_registry); - - const auto &functions = space_registry->GetOpImpl(op_desc->GetType()); - GE_ASSERT_NOTNULL(functions); - const size_t instance_index = input_index - invalid_index_num; - size_t ir_index; - GE_ASSERT_GRAPH_SUCCESS(ge::OpDescUtils::GetInputIrIndexByInstanceIndex(op_desc, instance_index, ir_index), - "[Get][InputIrIndexByInstanceIndex] failed, op[%s], instance index[%zu], input_index[%zu]", - op_desc->GetName().c_str(), instance_index, input_index); - if (functions->IsInputDataDependency(ir_index)) { - ge::Tensor data; - const auto index_name = op_desc->GetInputNameByIndex(static_cast(input_index)); - if (op.GetInputConstData(index_name.c_str(), data) == ge::GRAPH_SUCCESS) { - return true; - } else { - return false; - } - } - return true; -} - -ge::graphStatus GetTensorHolder(const ge::GeTensorDesc &input_desc, const gert::StorageShape &storage_shape, - TensorAddress address, std::unique_ptr &tensor_holder) { - tensor_holder = ge::ComGraphMakeUnique(sizeof(gert::Tensor)); - GE_ASSERT_NOTNULL(tensor_holder, "Create context holder inputs failed."); - if (address == nullptr) { - new (tensor_holder.get()) - gert::Tensor(storage_shape, - {input_desc.GetOriginFormat(), input_desc.GetFormat(), {}}, - input_desc.GetDataType()); - } else { - new (tensor_holder.get()) - gert::Tensor(storage_shape, - {input_desc.GetOriginFormat(), input_desc.GetFormat(), {}}, - gert::kOnHost, input_desc.GetDataType(), address); - } - return ge::GRAPH_SUCCESS; -} - - -ge::graphStatus ConstructCompileKernelContextInputs(const ge::Operator &op, const ge::OpDescPtr &op_desc, - std::vector> &inputs, - std::vector> &ge_tensors_holder) { - size_t invalid_index_num = 0UL; - for (size_t i = 0UL; i < op_desc->GetAllInputsSize(); i++) { - if (!IsInputDescValid(op_desc->GetInputDesc(static_cast(i)), invalid_index_num)) { - GELOGD("input desc is not valid, skip add input[%zu] into context inputs.", i); - continue; - } - gert::StorageShape storage_shape; - GetStorageShape(op_desc->GetInputDesc(static_cast(i)), storage_shape); - // init tensor address, if can not get const tensor input, set it to nullptr - TensorAddress address = nullptr; - Index index; - index.input_index = i; - index.invalid_index_num = invalid_index_num; - auto status = GetTensorAddress(op, op_desc, index, address, ge_tensors_holder); - if (status != ge::GRAPH_SUCCESS) { - return status; - } - std::unique_ptr tensor_holder; - status = GetTensorHolder(op_desc->GetInputDesc(static_cast(i)), storage_shape, address, tensor_holder); - if (status != ge::GRAPH_SUCCESS) { - return status; - } - inputs.emplace_back(std::move(tensor_holder)); - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus ConstructInferShapeContextInputs(const ge::Operator &op, const ge::OpDescPtr &op_desc, - std::vector> &inputs, - std::vector> &ge_tensors_holder) { - GE_ASSERT_GRAPH_SUCCESS(ConstructCompileKernelContextInputs(op, op_desc, inputs, ge_tensors_holder)); - // set infer shape_func to NULL - inputs.emplace_back(nullptr); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus ConstructInferShapeRangeContextInputs( - const ge::Operator &op, const ge::OpDescPtr &op_desc, std::vector> &inputs, - std::vector> &ge_tensors_holder, - std::vector> &input_tensor_ranges_holder) { - size_t invalid_index_num = 0UL; - GE_ASSERT(input_tensor_ranges_holder.size() == op_desc->GetAllInputsSize()); - for (size_t i = 0UL; i < op_desc->GetAllInputsSize(); i++) { - const auto &input_desc = op_desc->GetInputDesc(static_cast(i)); - if (!IsInputDescValid(input_desc, invalid_index_num)) { - GELOGD("input desc is not valid, skip add input[%zu] into context inputs.", i); - continue; - } - - GetMinMaxStorageShape(op_desc->GetInputDesc(static_cast(i)), - input_tensor_ranges_holder[i].first.GetShape(), - input_tensor_ranges_holder[i].second.GetShape()); - input_tensor_ranges_holder[i].first.SetOriginFormat(input_desc.GetOriginFormat()); - input_tensor_ranges_holder[i].first.SetStorageFormat(input_desc.GetFormat()); - input_tensor_ranges_holder[i].first.SetDataType(input_desc.GetDataType()); - input_tensor_ranges_holder[i].second.SetOriginFormat(input_desc.GetOriginFormat()); - input_tensor_ranges_holder[i].second.SetStorageFormat(input_desc.GetFormat()); - input_tensor_ranges_holder[i].second.SetDataType(input_desc.GetDataType()); - - // init tensor address, if can not get const tensor input, set it to nullptr - TensorAddress address = nullptr; - Index index; - index.input_index = i; - index.invalid_index_num = invalid_index_num; - const auto status = GetTensorAddress(op, op_desc, index, address, ge_tensors_holder); - if (status != ge::GRAPH_SUCCESS) { - return status; - } - std::unique_ptr tensor_range_holder = ge::ComGraphMakeUnique(sizeof(gert::TensorRange)); - GE_ASSERT_NOTNULL(tensor_range_holder, "Create context holder inputs failed."); - - if (address != nullptr) { - (void) input_tensor_ranges_holder[i].first.MutableTensorData().SetAddr(address, nullptr); - (void) input_tensor_ranges_holder[i].first.MutableTensorData().SetPlacement(gert::kOnHost); - (void) input_tensor_ranges_holder[i].second.MutableTensorData().SetAddr(address, nullptr); - input_tensor_ranges_holder[i].second.MutableTensorData().SetPlacement(gert::kOnHost); - } - new (tensor_range_holder.get()) - gert::TensorRange(&input_tensor_ranges_holder[i].first, &input_tensor_ranges_holder[i].second); - inputs.emplace_back(std::move(tensor_range_holder)); - } - // set infer shape_func to NULL - inputs.emplace_back(nullptr); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus ConstructInferShapeContextInputs(const ge::Operator &op, const ge::OpDescPtr &op_desc, - std::vector> &min_inputs, - std::vector> &max_inputs, - std::vector> &ge_tensors_holder) { - size_t invalid_index_num = 0UL; - for (size_t i = 0UL; i < op_desc->GetAllInputsSize(); i++) { - if (!IsInputDescValid(op_desc->GetInputDesc(static_cast(i)), invalid_index_num)) { - GELOGD("input desc is not valid, skip add input[%zu] into context inputs.", i); - continue; - } - gert::StorageShape min_storage_shape; - gert::StorageShape max_storage_shape; - GetMinMaxStorageShape(op_desc->GetInputDesc(static_cast(i)), min_storage_shape, max_storage_shape); - - // init tensor address, if can not get const tensor input, set it to nullptr - TensorAddress address = nullptr; - Index index; - index.input_index = i; - index.invalid_index_num = invalid_index_num; - auto status = GetTensorAddress(op, op_desc, index, address, ge_tensors_holder); - if (status != ge::GRAPH_SUCCESS) { - return status; - } - std::unique_ptr min_tensor_holder; - status = GetTensorHolder(op_desc->GetInputDesc(static_cast(i)), min_storage_shape, address, min_tensor_holder); - if (status != ge::GRAPH_SUCCESS) { - return status; - } - std::unique_ptr max_tensor_holder; - status = GetTensorHolder(op_desc->GetInputDesc(static_cast(i)), max_storage_shape, address, max_tensor_holder); - if (status != ge::GRAPH_SUCCESS) { - return status; - } - min_inputs.emplace_back(std::move(min_tensor_holder)); - max_inputs.emplace_back(std::move(max_tensor_holder)); - } - // set infer shape_func to NULL - min_inputs.emplace_back(nullptr); - max_inputs.emplace_back(nullptr); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus ConstructCompileKernelContextOutputs(const ge::OpDescPtr &op_desc, - std::vector> &outputs) { - auto size = op_desc->GetAllOutputsDescSize(); - while (size-- > 0) { - auto tensor_holder = ge::ComGraphMakeUnique(sizeof(gert::Tensor)); - GE_ASSERT_NOTNULL(tensor_holder, "Create context holder outputs failed, op[%s]", op_desc->GetName().c_str()); - outputs.emplace_back(std::move(tensor_holder)); - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus ConstructInferShapeRangeContextOutputs( - const ge::OpDescPtr &op_desc, std::vector> &outputs, - std::vector> &output_range_holder) { - for (size_t i = 0UL; i < op_desc->GetAllOutputsDescSize(); i++) { - auto tensor_holder = ge::ComGraphMakeUnique(sizeof(Range)); - GE_ASSERT_NOTNULL(tensor_holder, "Create context holder outputs failed, op[%s]", op_desc->GetName().c_str()); - reinterpret_cast *>(tensor_holder.get())->SetMin(&(output_range_holder[i].first.MutableOriginShape())); - reinterpret_cast *>(tensor_holder.get()) - ->SetMax(&(output_range_holder[i].second.MutableOriginShape())); - outputs.emplace_back(std::move(tensor_holder)); - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus UpdateOpDescOutShape(const ge::OpDescPtr &op_desc, gert::InferShapeContext *infer_shape_ctx) { - for (size_t index = 0UL; index < op_desc->GetOutputsSize(); index++) { - auto &dst_out_shape = op_desc->MutableOutputDesc(static_cast(index))->MutableShape(); - const auto *shape = infer_shape_ctx->GetOutputShape(index); - GE_ASSERT_NOTNULL(shape); - dst_out_shape.SetDimNum(shape->GetDimNum()); - for (size_t dim = 0UL; dim < shape->GetDimNum(); dim++) { - (void)dst_out_shape.SetDim(dim, shape->GetDim(dim)); - } - op_desc->MutableOutputDesc(static_cast(index))->SetOriginShape(dst_out_shape); - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus AddRange(std::vector> &shape_range, - const gert::Shape * const min_shape, const gert::Shape * const max_shape) { - for (size_t i = 0UL; i < min_shape->GetDimNum(); ++i) { - GELOGD("min dim:%ld, max dim:%ld", min_shape->GetDim(i), max_shape->GetDim(i)); - if (max_shape->GetDim(i) != -1) { - GE_CHECK_LE(min_shape->GetDim(i), max_shape->GetDim(i)); - } - shape_range.emplace_back(std::make_pair(min_shape->GetDim(i), max_shape->GetDim(i))); - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus UpdateOpDescOutShapeRange(const ge::OpDescPtr &op_desc, gert::InferShapeContext *min_ctx, - gert::InferShapeContext *max_ctx) { - for (size_t index = 0UL; index < op_desc->GetOutputsSize(); index++) { - auto output_desc = op_desc->MutableOutputDesc(static_cast(index)); - auto ge_shape = output_desc->GetShape(); - if (ge_shape.IsUnknownShape()) { - std::vector> shape_range; - const auto *min_shape = min_ctx->GetOutputShape(index); - const auto *max_shape = max_ctx->GetOutputShape(index); - GE_ASSERT_NOTNULL(min_shape); - GE_ASSERT_NOTNULL(max_shape); - GELOGD("min dim num:%zu, max dim num:%zu", min_shape->GetDimNum(), max_shape->GetDimNum()); - GE_RETURN_WITH_LOG_IF_TRUE((min_shape->GetDimNum()) != (max_shape->GetDimNum())); - const auto ret = AddRange(shape_range, min_shape, max_shape); - if (ret != ge::GRAPH_SUCCESS) { - return ret; - } - (void)output_desc->SetShapeRange(shape_range); - } - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus UpdateOpDescOutShapeRange(const ge::OpDescPtr &op_desc, - gert::InferShapeRangeContext *infer_shape_range_ctx) { - for (size_t i = 0UL; i < op_desc->GetOutputsSize(); ++i) { - const auto &output_tensor = op_desc->MutableOutputDesc(static_cast(i)); - std::vector> shape_range; - const auto out_range = infer_shape_range_ctx->GetOutputShapeRange(i); - GE_ASSERT_NOTNULL(out_range, "out range is nullptr."); - GE_ASSERT_NOTNULL(out_range->GetMax(), "out range max is nullptr."); - GE_ASSERT_NOTNULL(out_range->GetMin(), "out range min is nullptr."); - for (size_t j = 0UL; j < out_range->GetMax()->GetDimNum(); ++j) { - shape_range.emplace_back(std::make_pair(out_range->GetMin()->GetDim(j), out_range->GetMax()->GetDim(j))); - } - (void)output_tensor->SetShapeRange(shape_range); - } - return ge::GRAPH_SUCCESS; -} - -void ConstructDataTypeContextInputs(const ge::OpDescPtr &op_desc, std::vector &inputs) { - for (size_t i = 0UL; i < op_desc->GetAllInputsSize(); ++i) { - const auto &compile_tensor = op_desc->MutableInputDesc(static_cast(i)); - if (compile_tensor == nullptr) { - GELOGD("OpDesc[%s]type[%s], input desc[%zu] is nullptr, skip constructing rt2 ctx for it.", op_desc->GetNamePtr(), - op_desc->GetTypePtr(), i); - continue; - } - inputs.emplace_back(reinterpret_cast(compile_tensor->GetDataType())); - } -} - -void ConstructDataTypeContextOutputs(const ge::OpDescPtr &op_desc, std::vector &outputs) { - for (size_t i = 0UL; i < op_desc->GetAllOutputsDescSize(); i++) { - const auto &compile_tensor = op_desc->GetOutputDesc(static_cast(i)); - outputs.emplace_back(reinterpret_cast(compile_tensor.GetDataType())); - } -} - -// inputs layout is input tensors -std::vector GetInputs(const std::vector> &inputs_holders) { - std::vector inputs; - inputs.reserve(inputs_holders.size()); - for (const auto &input_holder : inputs_holders) { - inputs.emplace_back(input_holder.get()); - } - return inputs; -} - -std::vector GetInputs(const ge::Operator &op, const std::vector> &inputs_holders) { - std::vector inputs; - inputs.reserve(inputs_holders.size() + 1UL); - for (const auto &input_holder : inputs_holders) { - inputs.emplace_back(input_holder.get()); - } - // inputs layout is input tensors + infer func + inference context ptr - inputs.emplace_back(op.GetInferenceContext().get()); - return inputs; -} - -std::vector GetOutputs(const std::vector> &outputs_holders) { - std::vector outputs; - outputs.reserve(outputs_holders.size()); - for (const auto &output_holder : outputs_holders) { - outputs.emplace_back(output_holder.get()); - } - return outputs; -} - -bool NeedInferShapeRange(const ge::Operator &op, const ge::OpDescPtr &op_desc) { - bool need_infer = false; - size_t invalid_index_num = 0UL; - for (size_t i = 0UL; i < op_desc->GetAllInputsSize(); ++i) { - const auto &input_desc = op_desc->GetInputDesc(static_cast(i)); - if (!IsInputDescValid(input_desc, invalid_index_num)) { - GELOGD("input desc is not valid, skip add input[%zu] into context inputs.", i); - continue; - } - auto ge_shape = input_desc.GetShape(); - if (ge_shape.IsUnknownShape()) { - std::vector> shape_range; - need_infer = true; - (void)input_desc.GetShapeRange(shape_range); - if (shape_range.size() == 0UL) { - GELOGD("No need to infer shape range, because shape is unknown shape but no shape range, input[%zu].", i); - return false; - } - if (!IsTensorDependencyValid(op, op_desc, i, invalid_index_num)) { - GELOGD("No need to infer shape range, because dependency tensor is not const, input[%zu].", i); - return false; - } - } - } - return need_infer; -} - -ge::graphStatus InferShapeRangeCustom(const ge::Operator &op, const ge::OpDescPtr &op_desc, - OpImplRegisterV2::InferShapeRangeKernelFunc const infer_shape_range) { - std::vector> inputs_holder; - std::vector> outputs_holder; - std::vector> ge_tensors_holder; - std::vector> input_tensor_range_holder; - std::vector> output_range_holder; - - ge_tensors_holder.resize(op_desc->GetAllInputsSize()); - input_tensor_range_holder.resize(static_cast(op_desc->GetAllInputsSize())); - output_range_holder.resize(static_cast(op_desc->GetAllOutputsDescSize())); - GE_ASSERT_GRAPH_SUCCESS( - ConstructInferShapeRangeContextInputs(op, op_desc, inputs_holder, ge_tensors_holder, input_tensor_range_holder), - "[Construct][InferShapeContextInputs] failed, op_desc[%s]", op_desc->GetName().c_str()); - GE_ASSERT_GRAPH_SUCCESS(ConstructInferShapeRangeContextOutputs(op_desc, outputs_holder, output_range_holder), - "[Construct][InferShapeContextOutputs] failed, op_desc[%s]", op_desc->GetName().c_str()); - const auto kernel_context_holder = gert::KernelRunContextBuilder() - .Inputs(GetInputs(op, inputs_holder)).Outputs(GetOutputs(outputs_holder)).Build(op_desc); - auto infer_shape_range_ctx = reinterpret_cast(kernel_context_holder.context_); - const auto ret = infer_shape_range(infer_shape_range_ctx); - GE_CHK_STATUS_RET(ret, "[Call][InferShapeRange] failed, ret[%d]", ret); - (void)UpdateOpDescOutShapeRange(op_desc, infer_shape_range_ctx); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus InferShapeRangeAutomaticly(const ge::Operator &op, const ge::OpDescPtr &op_desc, - OpImplRegisterV2::InferShapeKernelFunc const infer_shape) { - GELOGD("Need to infer shape range, op[%s]", op_desc->GetName().c_str()); - std::vector> min_inputs_holder; - std::vector> max_inputs_holder; - std::vector> min_outputs_holder; - std::vector> max_outputs_holder; - std::vector> ge_tensors_holder; - ge_tensors_holder.resize(op_desc->GetAllInputsSize()); - GE_ASSERT_GRAPH_SUCCESS( - ConstructInferShapeContextInputs(op, op_desc, min_inputs_holder, max_inputs_holder, ge_tensors_holder), - "[Construct][InferShapeRangeAutomaticly] failed, op_desc[%s]", op_desc->GetName().c_str()); - GE_ASSERT_GRAPH_SUCCESS(ConstructCompileKernelContextOutputs(op_desc, min_outputs_holder), - "[Construct][InferShapeRangeAutomaticly] failed, op_desc[%s]", op_desc->GetName().c_str()); - GE_ASSERT_GRAPH_SUCCESS(ConstructCompileKernelContextOutputs(op_desc, max_outputs_holder), - "[Construct][InferShapeRangeAutomaticly] failed, op_desc[%s]", op_desc->GetName().c_str()); - // min output - const auto min_kernel_context_holder = gert::KernelRunContextBuilder() - .Inputs(GetInputs(op, min_inputs_holder)).Outputs(GetOutputs(min_outputs_holder)).Build(op_desc); - auto min_infer_shape_ctx = reinterpret_cast(min_kernel_context_holder.context_); - auto ret = infer_shape(min_infer_shape_ctx); - GE_CHK_STATUS_RET(ret, "[InferV2][MinShape] failed, op_desc[%s], ret[%d]", op_desc->GetName().c_str(), ret); - // max output - const auto max_kernel_context_holder = gert::KernelRunContextBuilder() - .Inputs(GetInputs(op, max_inputs_holder)).Outputs(GetOutputs(max_outputs_holder)).Build(op_desc); - auto max_infer_shape_ctx = reinterpret_cast(max_kernel_context_holder.context_); - ret = infer_shape(max_infer_shape_ctx); - GE_CHK_STATUS_RET(ret, "[InferV2][MaxShape] failed, op_desc[%s], ret[%d]", op_desc->GetName().c_str(), ret); - ret = UpdateOpDescOutShapeRange(op_desc, min_infer_shape_ctx, max_infer_shape_ctx); - return ret; -} - -ge::graphStatus UpdateOpDescOutFormat(const ge::OpDescPtr &op_desc, gert::InferFormatContext *infer_format_ctx) { - size_t in_index = 0UL; - for (size_t index = 0UL; index < op_desc->GetInputsSize(); index++) { - const auto desc = op_desc->MutableInputDesc(static_cast(index)); - if (desc == nullptr) { - continue; - } - const auto format = infer_format_ctx->GetInputFormat(in_index++); - GE_ASSERT_NOTNULL(format); - desc->SetOriginFormat(format->GetOriginFormat()); - desc->SetFormat(format->GetStorageFormat()); - } - - size_t out_index = 0UL; - for (size_t index = 0UL; index < op_desc->GetOutputsSize(); index++) { - const auto desc = op_desc->MutableOutputDesc(static_cast(index)); - if (desc == nullptr) { - continue; - } - const auto format = infer_format_ctx->GetOutputFormat(out_index++); - GE_ASSERT_NOTNULL(format); - desc->SetOriginFormat(format->GetOriginFormat()); - desc->SetFormat(format->GetStorageFormat()); - } - return ge::GRAPH_SUCCESS; -} - - -ge::graphStatus InferShapeByRegisteredFuncOrRule(const OpImplKernelRegistry::OpImplFunctionsV2 *functions, - const ge::OpDescPtr &op_desc, - gert::InferShapeContext *infer_shape_ctx) { - if (functions && functions->infer_shape) { - if (functions->IsOutputShapeDependOnCompute()) { - GELOGD("OpDesc %s(%s) is third class operator", op_desc->GetNamePtr(), op_desc->GetTypePtr()); - (void) ge::AttrUtils::SetInt(op_desc, ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, - static_cast(ge::DEPEND_SHAPE_RANGE)); - } - GELOGD("Infer shape for %s[%s] by registered func", op_desc->GetNamePtr(), op_desc->GetTypePtr()); - return functions->infer_shape(infer_shape_ctx); - } - const auto shape_infer_rule = ge::ShapeInferenceRule::FromOpDesc(op_desc); - if (shape_infer_rule == nullptr) { - REPORT_INNER_ERR_MSG("EZ9999", - "Can not find infer_shape func of node %s[%s]. Please confirm whether the op_proto shared " - "library (.so) has been loaded " - "successfully, and that you have already developed the infer_shape func.", - op_desc->GetNamePtr(), op_desc->GetTypePtr()); - GELOGE(ge::GRAPH_FAILED, - "Can not find infer_shape func of node %s[%s]. Please confirm whether the op_proto shared library (.so) " - "has been loaded " - "successfully, and that you have already developed the infer_shape func.", - op_desc->GetNamePtr(), op_desc->GetTypePtr()); - return ge::GRAPH_FAILED; - } - if (!shape_infer_rule->IsValid()) { - REPORT_INNER_ERR_MSG( - "EZ9999", - "No infer shape func registered for node %s[%s], and inference rule: %s is set but failed to parse: %s.", - op_desc->GetNamePtr(), op_desc->GetTypePtr(), ge::InferenceRule::GetInferenceRule(op_desc).c_str(), - shape_infer_rule->Error().c_str()); - GELOGE(ge::GRAPH_FAILED, - "No infer shape func registered for node %s[%s], and inference rule: %s is set but failed to parse: %s.", - op_desc->GetNamePtr(), op_desc->GetTypePtr(), ge::InferenceRule::GetInferenceRule(op_desc).c_str(), - shape_infer_rule->Error().c_str()); - return ge::GRAPH_FAILED; - } - GELOGD("Infer shape for %s[%s] by inference rule", op_desc->GetNamePtr(), op_desc->GetTypePtr()); - return shape_infer_rule->InferOnCompile(infer_shape_ctx); -} - -ge::graphStatus InferDtypeByRegisteredFuncOrRule(const OpImplKernelRegistry::OpImplFunctionsV2 *functions, - const ge::OpDescPtr &op_desc, - gert::InferDataTypeContext *infer_dtype_ctx) { - if (functions && functions->infer_datatype) { - GELOGD("Infer dtype for %s[%s] by registered func", op_desc->GetNamePtr(), op_desc->GetTypePtr()); - return functions->infer_datatype(infer_dtype_ctx); - } - const auto dtype_infer_rule = ge::DtypeInferenceRule::FromOpDesc(op_desc); - if (dtype_infer_rule == nullptr) { - REPORT_INNER_ERR_MSG("EZ9999", - "Can not find Node %s[%s] custom infer_datatype func. Please confirm whether the op_proto " - "shared library (.so) has been " - "loaded successfully, and that you have already developed the infer_datatype func or marked " - "the T-derivation rules on the IR.", - op_desc->GetNamePtr(), op_desc->GetTypePtr()); - GELOGE(ge::GRAPH_FAILED, - "Can not find Node %s[%s] custom infer_datatype func. Please confirm whether the op_proto shared library " - "(.so) has been " - "loaded successfully, and that you have already developed the infer_datatype func or marked " - "the T-derivation rules on the IR.", - op_desc->GetNamePtr(), op_desc->GetTypePtr()); - return ge::GRAPH_FAILED; - } - if (!dtype_infer_rule->IsValid()) { - REPORT_INNER_ERR_MSG( - "EZ9999", - "No infer dtype func registered for node %s[%s], and inference rule: %s is set but failed to parse: %s.", - op_desc->GetNamePtr(), op_desc->GetTypePtr(), ge::InferenceRule::GetInferenceRule(op_desc).c_str(), - dtype_infer_rule->Error().c_str()); - GELOGE(ge::GRAPH_FAILED, - "No infer dtype func registered for node %s[%s], and inference rule: %s is set but failed to parse: %s.", - op_desc->GetNamePtr(), op_desc->GetTypePtr(), ge::InferenceRule::GetInferenceRule(op_desc).c_str(), - dtype_infer_rule->Error().c_str()); - return ge::GRAPH_FAILED; - } - GELOGD("Infer dtype for %s[%s] by inference rule", op_desc->GetNamePtr(), op_desc->GetTypePtr()); - return dtype_infer_rule->InferDtype(infer_dtype_ctx); -} -} - -ge::graphStatus InferShapeRangeOnCompile(const ge::Operator &op, const ge::OpDescPtr &op_desc) { - if (!NeedInferShapeRange(op, op_desc)) { - GELOGD("No need to infer shape range, op[%s]", op_desc->GetName().c_str()); - return ge::GRAPH_SUCCESS; - } - - const auto *const space_registry = DefaultOpImplSpaceRegistry::GetInstance() - .GetDefaultSpaceRegistry(op_desc->GetOppImplVersion()).get(); - GE_ASSERT_NOTNULL(space_registry); - - const auto &functions = space_registry->GetOpImpl(op_desc->GetType()); - GE_ASSERT_NOTNULL(functions); - if (functions->infer_shape_range != nullptr) { - GELOGD("Op[%s], type[%s] use custom derivation strategy.", op_desc->GetName().c_str(), op_desc->GetType().c_str()); - return InferShapeRangeCustom(op, op_desc, functions->infer_shape_range); - } else if (functions->infer_shape != nullptr) { - GELOGD("Can not get infer shape range func op[%s], type[%s], will use an automatic derivation strategy.", - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - return InferShapeRangeAutomaticly(op, op_desc, functions->infer_shape); - } else { - GELOGE(ge::PARAM_INVALID, "infer_shape_range and infer_shape is nullptr."); - return ge::PARAM_INVALID; - } -} - -ge::graphStatus InferShapeOnCompile(const ge::Operator &op, const ge::OpDescPtr &op_desc) { - const auto *const space_registry = - DefaultOpImplSpaceRegistry::GetInstance().GetDefaultSpaceRegistry(op_desc->GetOppImplVersion()).get(); - GE_ASSERT_NOTNULL(space_registry); - - ge::NodeShapeTransUtils transformer(op_desc); - GE_CHK_BOOL_RET_STATUS(transformer.Init(), ge::GRAPH_FAILED, "Failed to init transformer for %s", - op_desc->GetNamePtr()); - GE_CHK_BOOL_RET_STATUS(transformer.CatchFormatAndShape(), ge::GRAPH_FAILED, "Failed to catch format and shape for %s", - op_desc->GetNamePtr()); - std::vector> inputs_holder; - std::vector> outputs_holder; - std::vector> ge_tensors_holder; - ge_tensors_holder.resize(op_desc->GetAllInputsSize()); - auto ret = ConstructInferShapeContextInputs(op, op_desc, inputs_holder, ge_tensors_holder); - if (ret == ge::GRAPH_PARAM_INVALID) { - return ret; - } - GE_ASSERT_GRAPH_SUCCESS(ret, "[Construct][InferShapeContextInputs] failed, op_desc[%s]", op_desc->GetName().c_str()); - GE_ASSERT_GRAPH_SUCCESS(ConstructCompileKernelContextOutputs(op_desc, outputs_holder), - "[Construct][InferShapeContextOutputs] failed, op_desc[%s]", op_desc->GetName().c_str()); - const auto kernel_context_holder = gert::KernelRunContextBuilder() - .Inputs(GetInputs(op, inputs_holder)) - .Outputs(GetOutputs(outputs_holder)) - .Build(op_desc); - auto infer_shape_ctx = reinterpret_cast(kernel_context_holder.context_); - - const auto &functions = space_registry->GetOpImpl(op_desc->GetType()); - - ret = InferShapeByRegisteredFuncOrRule(functions, op_desc, infer_shape_ctx); - GE_CHK_STATUS_RET(ret, "[Call][InferShapeV2Func] failed, op_desc[%s], ret[%d]", op_desc->GetName().c_str(), ret); - - GE_ASSERT_GRAPH_SUCCESS(UpdateOpDescOutShape(op_desc, infer_shape_ctx), - "UpdateOpDescOutShape failed, OutputShape is nullptr. op_desc[%s]", - op_desc->GetName().c_str()); - GE_CHK_BOOL_RET_STATUS(transformer.UpdateFormatAndShape(), ge::GRAPH_FAILED, - "Failed to update format and shape for %s", op_desc->GetNamePtr()); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus InferDataTypeOnCompile(const ge::OpDescPtr &op_desc) { - const auto *const space_registry = - DefaultOpImplSpaceRegistry::GetInstance().GetDefaultSpaceRegistry(op_desc->GetOppImplVersion()).get(); - if (space_registry == nullptr) { - GELOGW("Default space registry has not been initialized!"); - if (op_desc->IsSupportSymbolicInferDataType()) { - return op_desc->SymbolicInferDataType(); - } - GELOGW("Space_registry is null, neither Node %s[%s] not support symbolic infer datatype. Please declare symbol T " - "on IR or check Space_registry.", - op_desc->GetNamePtr(), op_desc->GetTypePtr()); - return ge::GRAPH_FAILED; - } - - const auto &functions = space_registry->GetOpImpl(op_desc->GetType()); - if ((!functions || !functions->infer_datatype) && op_desc->IsSupportSymbolicInferDataType()) { - GELOGD("Infer dtype for %s[%s] by ir symbol", op_desc->GetNamePtr(), op_desc->GetTypePtr()); - return op_desc->SymbolicInferDataType(); - } - - std::vector inputs; - std::vector outputs; - ConstructDataTypeContextInputs(op_desc, inputs); - ConstructDataTypeContextOutputs(op_desc, outputs); - const auto kernel_context_holder = gert::KernelRunContextBuilder().Inputs(inputs).Outputs(outputs).Build(op_desc); - const auto kernel_context = reinterpret_cast(kernel_context_holder.context_); - - ge::graphStatus ret = InferDtypeByRegisteredFuncOrRule(functions, op_desc, kernel_context); - GE_CHK_STATUS_RET(ret, "[Check][InferDataType] result failed, op_desc[%s], ret[%d]", op_desc->GetName().c_str(), ret); - for (size_t i = 0UL; i < op_desc->GetOutputsSize(); i++) { - const auto &out_desc = op_desc->MutableOutputDesc(static_cast(i)); - out_desc->SetDataType(kernel_context->GetOutputDataType(i)); - } - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus InferFormatOnCompile(const ge::Operator &op, const ge::OpDescPtr &op_desc) { - const auto *const space_registry = - DefaultOpImplSpaceRegistry::GetInstance().GetDefaultSpaceRegistry(op_desc->GetOppImplVersion()).get(); - GE_ASSERT_NOTNULL(space_registry); - - const auto &functions = space_registry->GetOpImpl(op_desc->GetType()); - if ((functions == nullptr) || (functions->infer_format_func == nullptr)) { - REPORT_INNER_ERR_MSG("EZ9999", - "Can not find infer_format func of node %s[%s]. Please confirm whether the op_proto shared " - "library (.so) has been loaded " - "successfully, and that you have already developed the infer_format func.", - op_desc->GetNamePtr(), op_desc->GetTypePtr()); - GELOGE( - ge::GRAPH_FAILED, - "Can not find infer_format func of node %s[%s]. Please confirm whether the op_proto shared library (.so) has been loaded " - "successfully, and that you have already developed the infer_format func.", - op_desc->GetNamePtr(), op_desc->GetTypePtr()); - return ge::GRAPH_FAILED; - } - - std::vector> inputs_holder; - std::vector> outputs_holder; - std::vector> ge_tensors_holder; - ge_tensors_holder.resize(op_desc->GetAllInputsSize()); - GE_ASSERT_GRAPH_SUCCESS(ConstructCompileKernelContextInputs(op, op_desc, inputs_holder, ge_tensors_holder), - "[Construct][InferFormatContextInputs] failed, op_desc[%s]", op_desc->GetName().c_str()); - GE_ASSERT_GRAPH_SUCCESS(ConstructCompileKernelContextOutputs(op_desc, outputs_holder), - "[Construct][InferShapeContextOutputs] failed, op_desc[%s]", op_desc->GetName().c_str()); - const auto kernel_context_holder = gert::KernelRunContextBuilder() - .Inputs(GetInputs(inputs_holder)) - .Outputs(GetOutputs(outputs_holder)) - .Build(op_desc); - const auto infer_format_ctx = reinterpret_cast(kernel_context_holder.context_); - const auto ret = functions->infer_format_func(infer_format_ctx); - GE_CHK_STATUS_RET(ret, "[Call][InferFormatV2Func] failed, op_desc[%s], ret[%d]", op_desc->GetName().c_str(), ret); - GE_ASSERT_GRAPH_SUCCESS(UpdateOpDescOutFormat(op_desc, infer_format_ctx), "UpdateOpDescOutFormat failed for op[%s]", - op_desc->GetName().c_str()); - return ge::GRAPH_SUCCESS; -} - -bool IsInferFormatV2Registered(const ge::OpDescPtr &op_desc) { - const auto *const space_registry = - gert::DefaultOpImplSpaceRegistry::GetInstance().GetDefaultSpaceRegistry(op_desc->GetOppImplVersion()).get(); - if (space_registry != nullptr) { - const auto &functions = space_registry->GetOpImpl(op_desc->GetType()); - if ((functions != nullptr) && (functions->infer_format_func != nullptr)) { - return true; - } - } - return false; -} - -class CompileAdaptFunctionsRegister { - public: - CompileAdaptFunctionsRegister() { - // only infer shape is necessary, as register all infer func in infer shape - (void) ge::OperatorFactoryImpl::RegisterInferShapeV2Func(&gert::InferShapeOnCompile); - (void) ge::OperatorFactoryImpl::RegisterInferShapeRangeFunc(&gert::InferShapeRangeOnCompile); - (void) ge::OperatorFactoryImpl::RegisterInferDataTypeFunc(&gert::InferDataTypeOnCompile); - (void) ge::OperatorFactoryImpl::RegisterInferFormatV2Func(&gert::InferFormatOnCompile); - (void) ge::OperatorFactoryImpl::RegisterIsInferFormatV2RegisteredFunc(&gert::IsInferFormatV2Registered); - } -}; -static CompileAdaptFunctionsRegister VAR_UNUSED g_register_adapt_funcs; -} // namespace gert diff --git a/register/stream_manage_func_registry.cc b/register/stream_manage_func_registry.cc deleted file mode 100644 index 1487dac1ad3950947c5b93917a6005144ad6ca00..0000000000000000000000000000000000000000 --- a/register/stream_manage_func_registry.cc +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/stream_manage_func_registry.h" - -namespace ge { -StreamMngFuncRegistry &StreamMngFuncRegistry::GetInstance() { - static StreamMngFuncRegistry registry; - return registry; -} - -Status StreamMngFuncRegistry::TryCallStreamMngFunc(const StreamMngFuncType func_type, MngActionType action_type, - MngResourceHandle handle) { - StreamMngFunc callback_func = LookUpStreamMngFunc(func_type); - if (callback_func == nullptr) { - GELOGI("Stream manage func is not found, FuncType is [%u]", static_cast(func_type)); - return SUCCESS; - } - const uint32_t ret = callback_func(action_type, handle); - GELOGI("Call stream manage func, ret = %u!", ret); - return SUCCESS; -} - -void StreamMngFuncRegistry::Register(const StreamMngFuncType func_type, StreamMngFunc const manage_func) { - std::lock_guard lock(mutex_); - type_to_func_[func_type] = manage_func; -} - -StreamMngFunc StreamMngFuncRegistry::LookUpStreamMngFunc(const StreamMngFuncType func_type) { - std::lock_guard lock(mutex_); - const auto iter = type_to_func_.find(func_type); - if (iter == type_to_func_.end()) { - return nullptr; - } - return iter->second; -} - -StreamMngFuncRegister::StreamMngFuncRegister(const StreamMngFuncType func_type, StreamMngFunc const manage_func) { - StreamMngFuncRegistry::GetInstance().Register(func_type, manage_func); -} -} // namespace ge diff --git a/register/stub/Makefile b/register/stub/Makefile deleted file mode 100755 index 587864d15fa3006ae567f611b30c52124f562c32..0000000000000000000000000000000000000000 --- a/register/stub/Makefile +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) 2024 Huawei Technologies Co., Ltd. -# This file is a part of the CANN Open Software. -# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. -# ====================================================================================================================== - -inc_path := $(shell pwd)/metadef/inc/external/ -out_path := $(shell pwd)/out/register/lib64/stub/ -stub_path := $(shell pwd)/metadef/register/stub/ - -mkdir_stub := $(shell mkdir -p $(out_path)) -register_local_stub := $(shell $(HI_PYTHON) $(stub_path)/gen_stubapi.py $(inc_path) $(out_path)) diff --git a/register/tensor_assign.cpp b/register/tensor_assign.cpp deleted file mode 100644 index 5a219638a24b9f3cc20309e606d4c7da8cd7a258..0000000000000000000000000000000000000000 --- a/register/tensor_assign.cpp +++ /dev/null @@ -1,559 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include -#include "securec.h" -#include "graph/debug/ge_log.h" -#include "graph/debug/ge_util.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/type_utils.h" -#include "graph/utils/type_utils_inner.h" -#include "graph/utils/attr_utils.h" -#include "register/register_error_codes.h" -#include "graph/types.h" -#include "graph/def_types.h" -#include "graph/debug/ge_util.h" -#include "register/tensor_assign.h" - -namespace domi { -namespace { -using GeTensorDesc = ge::GeTensorDesc; -using GeShape = ge::GeShape; -using domi::tensorflow::TensorProto; -using google::protobuf::int32; -using google::protobuf::int64; -const char_t *const kOriginElementNumAttrName = "origin_element_num"; -const std::map data_type_map = { - {domi::tensorflow::DataType::DT_FLOAT, ge::DataType::DT_FLOAT}, - {domi::tensorflow::DataType::DT_HALF, ge::DataType::DT_FLOAT16}, - {domi::tensorflow::DataType::DT_INT8, ge::DataType::DT_INT8}, - {domi::tensorflow::DataType::DT_INT16, ge::DataType::DT_INT16}, - {domi::tensorflow::DataType::DT_UINT16, ge::DataType::DT_UINT16}, - {domi::tensorflow::DataType::DT_UINT8, ge::DataType::DT_UINT8}, - {domi::tensorflow::DataType::DT_INT32, ge::DataType::DT_INT32}, - {domi::tensorflow::DataType::DT_INT64, ge::DataType::DT_INT64}, - {domi::tensorflow::DataType::DT_UINT32, ge::DataType::DT_UINT32}, - {domi::tensorflow::DataType::DT_UINT64, ge::DataType::DT_UINT64}, - {domi::tensorflow::DataType::DT_BOOL, ge::DataType::DT_BOOL}, - {domi::tensorflow::DataType::DT_DOUBLE, ge::DataType::DT_DOUBLE}, - {domi::tensorflow::DataType::DT_COMPLEX32, ge::DataType::DT_COMPLEX32}, - {domi::tensorflow::DataType::DT_COMPLEX64, ge::DataType::DT_COMPLEX64}, - {domi::tensorflow::DataType::DT_QINT8, ge::DataType::DT_INT8}, - {domi::tensorflow::DataType::DT_QUINT8, ge::DataType::DT_UINT8}, - {domi::tensorflow::DataType::DT_QINT32, ge::DataType::DT_INT32}, - {domi::tensorflow::DataType::DT_QINT16, ge::DataType::DT_INT16}, - {domi::tensorflow::DataType::DT_QUINT16, ge::DataType::DT_UINT16}, - {domi::tensorflow::DataType::DT_COMPLEX128, ge::DataType::DT_COMPLEX128}, - {domi::tensorflow::DataType::DT_RESOURCE, ge::DataType::DT_RESOURCE}, - {domi::tensorflow::DataType::DT_BFLOAT16, ge::DataType::DT_BF16}, - {domi::tensorflow::DataType::DT_STRING, ge::DataType::DT_STRING}, - {domi::tensorflow::DataType::DT_FLOAT_REF, ge::DataType::DT_FLOAT}, - {domi::tensorflow::DataType::DT_DOUBLE_REF, ge::DataType::DT_DOUBLE}, - {domi::tensorflow::DataType::DT_INT32_REF, ge::DataType::DT_INT32}, - {domi::tensorflow::DataType::DT_INT8_REF, ge::DataType::DT_INT8}, - {domi::tensorflow::DataType::DT_UINT8_REF, ge::DataType::DT_UINT8}, - {domi::tensorflow::DataType::DT_INT16_REF, ge::DataType::DT_INT16}, - {domi::tensorflow::DataType::DT_UINT16_REF, ge::DataType::DT_UINT16}, - {domi::tensorflow::DataType::DT_COMPLEX32_REF, ge::DataType::DT_COMPLEX32}, - {domi::tensorflow::DataType::DT_COMPLEX64_REF, ge::DataType::DT_COMPLEX64}, - {domi::tensorflow::DataType::DT_QINT8_REF, ge::DataType::DT_INT8}, - {domi::tensorflow::DataType::DT_QUINT8_REF, ge::DataType::DT_UINT8}, - {domi::tensorflow::DataType::DT_QINT32_REF, ge::DataType::DT_INT32}, - {domi::tensorflow::DataType::DT_QINT16_REF, ge::DataType::DT_INT16}, - {domi::tensorflow::DataType::DT_QUINT16_REF, ge::DataType::DT_UINT16}, - {domi::tensorflow::DataType::DT_COMPLEX128_REF, ge::DataType::DT_COMPLEX128}, - {domi::tensorflow::DataType::DT_RESOURCE_REF, ge::DataType::DT_RESOURCE}, - {domi::tensorflow::DataType::DT_BFLOAT16_REF, ge::DataType::DT_FLOAT16}, - {domi::tensorflow::DataType::DT_UINT32_REF, ge::DataType::DT_UINT32}, - {domi::tensorflow::DataType::DT_UINT64_REF, ge::DataType::DT_UINT64}, - {domi::tensorflow::DataType::DT_INT64_REF, ge::DataType::DT_INT64}, - {domi::tensorflow::DataType::DT_BOOL_REF, ge::DataType::DT_BOOL}, - {domi::tensorflow::DataType::DT_HALF_REF, ge::DataType::DT_FLOAT16}, - {domi::tensorflow::DataType::DT_STRING_REF, ge::DataType::DT_STRING}, - {domi::tensorflow::DataType::DT_VARIANT, ge::DataType::DT_VARIANT}, -}; -} // namespace - -ge::DataType TensorAssign::ConvertTensorflowDataType(const uint32_t tf_data_type) { - const auto search = data_type_map.find(tf_data_type); - if (search != data_type_map.end()) { - return search->second; - } else { - return ge::DataType::DT_UNDEFINED; - } -} - -bool TensorAssign::CheckBoolVal(const tensorflow::DataType data_type) { - return ((data_type == tensorflow::DT_BOOL) || (data_type == tensorflow::DT_BOOL_REF)); -} - -bool TensorAssign::CheckHalfVal(const tensorflow::DataType data_type) { - return ((data_type == tensorflow::DT_HALF) || (data_type == tensorflow::DT_BFLOAT16) || - (data_type == tensorflow::DT_HALF_REF) || (data_type == tensorflow::DT_BFLOAT16_REF)); -} - -bool TensorAssign::CheckFloatVal(const tensorflow::DataType data_type) { - return ((data_type == tensorflow::DT_FLOAT) || (data_type == tensorflow::DT_FLOAT_REF)); -} - -bool TensorAssign::CheckDoubleVal(const tensorflow::DataType data_type) { - return ((data_type == tensorflow::DT_DOUBLE) || (data_type == tensorflow::DT_DOUBLE_REF)); -} - -bool TensorAssign::CheckComplex32Val(const tensorflow::DataType data_type) { - return ((data_type == tensorflow::DT_COMPLEX32) || (data_type == tensorflow::DT_COMPLEX32_REF)); -} - -bool TensorAssign::CheckComplex64Val(const tensorflow::DataType data_type) { - return ((data_type == tensorflow::DT_COMPLEX64) || (data_type == tensorflow::DT_COMPLEX64_REF)); -} - -bool TensorAssign::CheckComplex128Val(const tensorflow::DataType data_type) { - return ((data_type == tensorflow::DT_COMPLEX128) || (data_type == tensorflow::DT_COMPLEX128_REF)); -} - -bool TensorAssign::CheckStringVal(const tensorflow::DataType data_type) { - return ((data_type == tensorflow::DT_STRING) || (data_type == tensorflow::DT_STRING_REF)); -} - -bool TensorAssign::CheckByte(const tensorflow::DataType data_type) { - return ((data_type == tensorflow::DT_UINT8) || (data_type == tensorflow::DT_INT8) || - (data_type == tensorflow::DT_QINT8) || (data_type == tensorflow::DT_QUINT8) || - (data_type == tensorflow::DT_UINT8_REF) || (data_type == tensorflow::DT_INT8_REF) || - (data_type == tensorflow::DT_QINT8_REF) || (data_type == tensorflow::DT_QUINT8_REF)); -} - -bool TensorAssign::CheckDoubleByte(const tensorflow::DataType data_type) { - return ((data_type == tensorflow::DT_INT16) || (data_type == tensorflow::DT_UINT16) || - (data_type == tensorflow::DT_QINT16) || (data_type == tensorflow::DT_QUINT16) || - (data_type == tensorflow::DT_INT16_REF) || (data_type == tensorflow::DT_UINT16_REF) || - (data_type == tensorflow::DT_QINT16_REF) || (data_type == tensorflow::DT_QUINT16_REF)); -} - -bool TensorAssign::CheckSignedFourByte(const tensorflow::DataType data_type) { - return ((data_type == tensorflow::DT_INT32) || (data_type == tensorflow::DT_QINT32) || - (data_type == tensorflow::DT_INT32_REF) || (data_type == tensorflow::DT_QINT32_REF)); -} - -bool TensorAssign::CheckUnsignedFourByte(const tensorflow::DataType data_type) { - return ((data_type == tensorflow::DT_UINT32) || (data_type == tensorflow::DT_UINT32_REF)); -} - -bool TensorAssign::CheckSignedEightByte(const tensorflow::DataType data_type) { - return ((data_type == tensorflow::DT_INT64) || (data_type == tensorflow::DT_INT64_REF)); -} - -bool TensorAssign::CheckUnsignedEightByte(const tensorflow::DataType data_type) { - return ((data_type == tensorflow::DT_UINT64) || (data_type == tensorflow::DT_UINT64_REF)); -} - -Status TensorAssign::GetDoubleByteVal(const int64_t val_size, const google::protobuf::RepeatedField &val_vector, - const int64_t count, GeTensorPtr &weight) { - GE_CHECK_NOTNULL(weight); - const bool zerosLike = ((count != val_size) && (val_size == 1)); - std::vector addr(static_cast(count)); - if (val_size == 0) { // addr has been zero initialized - (void)weight->SetData(ge::PtrToPtr(addr.data()), static_cast(count) * sizeof(uint16_t)); - return SUCCESS; - } - if (!zerosLike) { - const int64_t minCount = (count > val_size) ? val_size : count; - for (int64_t i = 0; i < minCount; i++) { - GE_ASSERT_EQ(ge::IntegerChecker::Compat(i), true); - addr[static_cast(i)] = static_cast(val_vector.Get(static_cast(i))); - } - const int64_t value_index = minCount - 1; - GE_ASSERT_EQ(ge::IntegerChecker::Compat(value_index), true); - for (int64_t i = minCount; i < count; i++) { - addr[static_cast(i)] = static_cast(val_vector.Get(static_cast(value_index))); - } - } else { - for (int64_t i = 0; i < count; i++) { - addr[static_cast(i)] = static_cast(val_vector.Get(0)); - } - } - (void)weight->SetData(ge::PtrToPtr(addr.data()), static_cast(count) * sizeof(uint16_t)); - return SUCCESS; -} - -Status TensorAssign::GetByteVal(const int64_t val_size, const google::protobuf::RepeatedField &val_vector, - const int64_t count, GeTensorPtr &weight) { - GE_CHECK_NOTNULL(weight); - const bool zerosLike = ((count != val_size) && (val_size == 1)); - std::vector addr(static_cast(count)); - if (val_size == 0) { // addr has been zero initialized - (void)weight->SetData(addr.data(), static_cast(count) * sizeof(uint8_t)); - return SUCCESS; - } - if (!zerosLike) { - const int64_t minCount = (count > val_size) ? val_size : count; - for (int64_t i = 0; i < minCount; i++) { - GE_ASSERT_EQ(ge::IntegerChecker::Compat(i), true); - addr[static_cast(i)] = static_cast(val_vector.Get(static_cast(i))); - } - const int64_t value_index = minCount - 1; - GE_ASSERT_EQ(ge::IntegerChecker::Compat(value_index), true); - for (int64_t i = minCount; i < count; i++) { - addr[static_cast(i)] = static_cast(val_vector.Get(static_cast(value_index))); - } - } else { - for (int64_t i = 0; i < count; i++) { - addr[static_cast(i)] = static_cast(val_vector.Get(0)); - } - } - (void)weight->SetData(addr.data(), static_cast(count) * sizeof(uint8_t)); - return SUCCESS; -} - -Status TensorAssign::GetStringVal(const int64_t val_size, - const google::protobuf::RepeatedPtrField &val_vector, - const int64_t count, GeTensorPtr &weight) { - GE_CHECK_NOTNULL(weight); - const bool flag = ((count != val_size) && (val_size == 1)); - size_t total_size = 0U; - if (!flag) { - const int64_t min_count = (count > val_size) ? val_size : count; - for (int64_t i = 0; i < min_count; i++) { - // extra 16 bytes store head of string - // extra 1 byte store '\0' - GE_ASSERT_EQ(ge::IntegerChecker::Compat(i), true); - total_size += (val_vector[static_cast(i)].size() + sizeof(ge::StringHead) + 1U); - } - total_size += (static_cast(count) - static_cast(min_count)) * (sizeof(ge::StringHead) + 1U); - std::vector addr(total_size); - ge::StringHead *const string_head = ge::PtrToPtr(addr.data()); - // front 16 bytes store head of each string - auto raw_data = ge::PtrAdd(addr.data(), total_size + 1U, - static_cast(count) * sizeof(ge::StringHead)); - GE_ASSERT_TRUE(count > 0); - GE_ASSERT_EQ(ge::TypeUtilsInner::CheckUint64MulOverflow(static_cast(count), - static_cast(sizeof(ge::StringHead))), - false); - uint64_t ptr_size = static_cast(count) * sizeof(ge::StringHead); - for (int64_t i = 0; i < count; ++i) { - ge::PtrAdd(string_head, static_cast(count) + 1U, - static_cast(i))->addr = static_cast(ptr_size); - if (i < val_size) { - GE_ASSERT_EQ(ge::IntegerChecker::Compat(i), true); - const string &str = val_vector.Get(static_cast(i)); - ge::PtrAdd(string_head, static_cast(count) + 1U, - static_cast(i))->len = static_cast(str.size()); - CHECK_FALSE_EXEC(memcpy_s(raw_data, str.size() + 1U, str.c_str(), str.size() + 1U) == EOK, - GELOGW("[GetStringVal][Copy] memcpy failed")); - raw_data = ge::PtrAdd(raw_data, total_size + 1U, str.size() + 1U); - ptr_size += (str.size() + 1U); - } else { - ge::PtrAdd(string_head, static_cast(count) + 1U, static_cast(i))->len = 0; - raw_data = ge::PtrAdd(raw_data, total_size + 1U, 1U); - ptr_size += 1U; - } - } - (void)weight->SetData(ge::PtrToPtr(addr.data()), total_size); - } else { - const string &str = val_vector.Get(0); - // extra 16 bytes store head of string - // extra 1 byte store '\0' - total_size = (str.size() + sizeof(ge::StringHead) + 1U) * static_cast(count); - std::vector addr(total_size); - // front 16 bytes store head of each string - ge::StringHead *const string_head = ge::PtrToPtr(addr.data()); - auto raw_data = ge::PtrAdd(addr.data(), total_size + 1U, - static_cast(count) * sizeof(ge::StringHead)); - GE_ASSERT_TRUE(count > 0); - GE_ASSERT_EQ(ge::TypeUtilsInner::CheckUint64MulOverflow(static_cast(count), - static_cast(sizeof(ge::StringHead))), - false); - uint64_t ptr_size = static_cast(count) * sizeof(ge::StringHead); - for (int64_t i = 0; i < count; ++i) { - ge::PtrAdd(string_head, static_cast(count) + 1U, - static_cast(i))->addr = static_cast(ptr_size); - ge::PtrAdd(string_head, static_cast(count) + 1U, - static_cast(i))->len = static_cast(str.size()); - const bool b = memcpy_s(raw_data, str.size() + 1U, str.c_str(), str.size() + 1U) == EOK; - if (!b) { - GELOGW("[GetStringVal][Copy] memcpy failed"); - } - raw_data = ge::PtrAdd(raw_data, total_size + 1U, str.size() + 1U); - ptr_size += (str.size() + 1U); - } - (void)weight->SetData(ge::PtrToPtr(addr.data()), total_size); - } - return SUCCESS; -} - -static Status GetComplex32Val(const int64_t val_size, const google::protobuf::RepeatedField &val_vector, - const int64_t count, GeTensorPtr &weight) { - // val_size must be even, and complex value should be an integer multiple of 2 - GE_ASSERT_TRUE((val_size % kComplexWidth) == 0, "complex value should be an integer multiple of 2."); - const std::unique_ptr addr = ge::ComGraphMakeUnique(static_cast(count)); - GE_CHECK_NOTNULL(addr); - // Complex numbers are made up of real and imaginary numbers - const bool zerosLike = ((count != val_size) && (val_size == 2)); - if (!zerosLike) { - GE_ASSERT_TRUE(val_size <= count); - for (size_t i = 0UL; i < static_cast(val_size); i++) { - addr[i] = static_cast(val_vector.Get(static_cast(i))); - } - const int64_t value_r = val_size - 1; - GE_ASSERT_EQ(ge::IntegerChecker::Compat(value_r), true); - // val_vector format is real value, complex value..., here is getting the corresponding value. - // real value and complex value are stored spaced apart, so use 2 and 1 to store in the correct addr. - const int64_t value_l = val_size - kComplexWidth; - GE_ASSERT_EQ(ge::IntegerChecker::Compat(value_l), true); - for (int64_t i = val_size; i < count; i += kComplexWidth) { - addr[static_cast(i)] = static_cast(val_vector.Get(static_cast(value_l))); - addr[static_cast(i) + 1UL] = static_cast(val_vector.Get(static_cast(value_r))); - } - } else { - for (int64_t i = 0; i < count; i += kComplexWidth) { - addr[static_cast(i)] = static_cast(val_vector.Get(0)); - addr[static_cast(i) + 1UL] = static_cast(val_vector.Get(1)); - } - } - (void)weight->SetData(ge::PtrToPtr(addr.get()), static_cast(count) * sizeof(uint16_t)); - return SUCCESS; -} - -void TensorAssign::SetGeTensorWeightData(const TensorProto &tensor, const int64_t val_size, - const int64_t count, GeTensorPtr &weight) { - const tensorflow::DataType data_type = tensor.dtype(); - constexpr int64_t kNumElementOfComplex = 2; - if (CheckFloatVal(data_type)) { - (void)GetVal(val_size, tensor.float_val(), count, weight); - } else if (CheckComplex32Val(data_type)) { - (void)GetComplex32Val(val_size, tensor.icomplex_val(), count * kNumElementOfComplex, weight); - } else if (CheckComplex64Val(data_type)) { - (void)GetVal(val_size, tensor.scomplex_val(), count * kNumElementOfComplex, weight, true); - } else if (CheckSignedFourByte(data_type)) { - (void)GetVal(val_size, tensor.int_val(), count, weight); - } else if (CheckUnsignedFourByte(data_type)) { - (void)GetVal(val_size, tensor.uint32_val(), count, weight); - } else if (CheckSignedEightByte(data_type)) { - (void)GetVal(val_size, tensor.int64_val(), count, weight); - } else if (CheckUnsignedEightByte(data_type)) { - (void)GetVal(val_size, tensor.uint64_val(), count, weight); - } else if (CheckBoolVal(data_type)) { - (void)GetVal(val_size, tensor.bool_val(), count, weight); - } else if (CheckStringVal(data_type)) { - (void)GetStringVal(val_size, tensor.string_val(), count, weight); - } else if (CheckHalfVal(data_type)) { - (void)GetDoubleByteVal(val_size, tensor.half_val(), count, weight); - } else if (CheckDoubleByte(data_type)) { - (void)GetDoubleByteVal(val_size, tensor.int_val(), count, weight); - } else if (CheckByte(data_type)) { - (void)GetByteVal(val_size, tensor.int_val(), count, weight); - } else if (CheckDoubleVal(data_type)) { - (void)GetVal(val_size, tensor.double_val(), count, weight); - } else if (CheckComplex128Val(data_type)) { - (void)GetVal(val_size, tensor.dcomplex_val(), count * kNumElementOfComplex, weight, true); - } else { - GELOGI("data_type:%s.", DataType_Name(data_type).c_str()); - } -} - -void TensorAssign::SetWeightData(const tensorflow::DataType data_type, const int64_t count, - const std::string &tensor_content, GeTensorPtr &weight) { - if (weight == nullptr) { - GE_LOGE("weight is nullptr."); - return; - } - GELOGD("Set data from tensor_content, count = %ld, data_type = %s.", - count, DataType_Name(data_type).c_str()); - const auto tensor_content_data = tensor_content.data(); - const bool is_four_byte = - CheckSignedFourByte(data_type) || CheckUnsignedFourByte(data_type) || CheckComplex32Val(data_type); - const bool is_double_byte = CheckHalfVal(data_type) || CheckDoubleByte(data_type); - const bool is_eight_byte = CheckSignedEightByte(data_type) || CheckUnsignedEightByte(data_type); - if (CheckByte(data_type)) { - (void)weight->SetData(ge::PtrToPtr(tensor_content_data), - static_cast(count) * sizeof(uint8_t)); - } else if (CheckBoolVal(data_type)) { - (void)weight->SetData(ge::PtrToPtr(tensor_content_data), - static_cast(count) * sizeof(bool)); - } else if (is_double_byte) { - (void)weight->SetData(ge::PtrToPtr(tensor_content_data), - static_cast(count) * sizeof(uint16_t)); - } else if (is_four_byte) { - (void)weight->SetData(ge::PtrToPtr(tensor_content_data), - static_cast(count) * sizeof(uint32_t)); - } else if (is_eight_byte) { - (void)weight->SetData(ge::PtrToPtr(tensor_content_data), - static_cast(count) * sizeof(uint64_t)); - } else if (CheckDoubleVal(data_type)) { - (void)weight->SetData(ge::PtrToPtr(tensor_content_data), - static_cast(count) * sizeof(double)); - } else if (CheckComplex128Val(data_type)) { - (void)weight->SetData(ge::PtrToPtr(tensor_content_data), - static_cast(count) * sizeof(std::complex)); - } else if (CheckComplex64Val(data_type)) { - (void)weight->SetData(ge::PtrToPtr(tensor_content_data), - static_cast(count) * sizeof(std::complex)); - } else if (CheckStringVal(data_type)) { - if (ge::TypeUtilsInner::CheckUint64MulOverflow(static_cast(count), - static_cast(sizeof(ge::StringHead)))) { - GELOGE(ge::FAILED, "count multiply StringHead is overflow uint64, count: %u", static_cast(count)); - return; - } - std::string weight_content; - if (count > 0) { - // each byte of top count bytes is each string length - weight_content = tensor_content.substr(static_cast(count)); - } - const size_t total_size = weight_content.size() + static_cast(count) * (sizeof(ge::StringHead) + 1U); - std::vector addr(total_size); - ge::StringHead *const string_head = ge::PtrToPtr(addr.data()); - auto raw_data = - ge::PtrAdd(addr.data(), total_size + 1U, static_cast(count) * sizeof(ge::StringHead)); - uint64_t ptr_size = static_cast(count) * sizeof(ge::StringHead); - size_t str_start_index = 0U; - for (int64_t i = 0; i < count; ++i) { - ge::PtrAdd(string_head, static_cast(count) + 1U, static_cast(i))->addr = - static_cast(ptr_size); - const size_t str_len = static_cast(tensor_content.at(static_cast(i))); - const string &str = weight_content.substr(str_start_index, str_len); - str_start_index += str_len; - ge::PtrAdd(string_head, static_cast(count) + 1U, static_cast(i))->len = - static_cast(str.size()); - CHECK_FALSE_EXEC(memcpy_s(raw_data, str.size() + 1U, str.c_str(), str.size() + 1U) == EOK, - GELOGW("[SetWeight][Copy] memcpy failed")); - raw_data = ge::PtrAdd(raw_data, total_size + 1U, str.size() + 1U); - ptr_size += static_cast(str.size()) + 1U; - } - (void)weight->SetData(ge::PtrToPtr(addr.data()), total_size); - } else { - (void)weight->SetData(ge::PtrToPtr(tensor_content_data), - static_cast(count) * sizeof(float)); - } -} - -Status TensorAssign::SetGeTensor(const TensorProto &tensor, GeTensorPtr &weight) { - GE_CHECK_NOTNULL(weight); - std::map datatype_val_size_map = { - {tensorflow::DT_FLOAT, tensor.float_val().size()}, - {tensorflow::DT_INT32, tensor.int_val().size()}, - {tensorflow::DT_INT64, tensor.int64_val().size()}, - {tensorflow::DT_BOOL, tensor.bool_val().size()}, - {tensorflow::DT_HALF, tensor.half_val().size()}, - {tensorflow::DT_INT8, tensor.int_val().size()}, - {tensorflow::DT_UINT8, tensor.int_val().size()}, - {tensorflow::DT_INT16, tensor.int_val().size()}, - {tensorflow::DT_UINT16, tensor.int_val().size()}, - {tensorflow::DT_DOUBLE, tensor.double_val().size()}, - {tensorflow::DT_STRING, tensor.string_val().size()}, - {tensorflow::DT_QINT8, tensor.int_val().size()}, - {tensorflow::DT_QINT16, tensor.int_val().size()}, - {tensorflow::DT_QINT32, tensor.int_val().size()}, - {tensorflow::DT_QUINT8, tensor.int_val().size()}, - {tensorflow::DT_QUINT16, tensor.int_val().size()}, - {tensorflow::DT_COMPLEX32, tensor.icomplex_val().size()}, - {tensorflow::DT_COMPLEX64, tensor.scomplex_val().size()}, - {tensorflow::DT_COMPLEX128, tensor.dcomplex_val().size()}, - {tensorflow::DT_BFLOAT16, tensor.half_val().size()}, - {tensorflow::DT_UINT32, tensor.uint32_val().size()}, - {tensorflow::DT_UINT64, tensor.uint64_val().size()}, - {tensorflow::DT_RESOURCE, tensor.resource_handle_val().size()}, - {tensorflow::DT_VARIANT, tensor.variant_val().size()}, - {tensorflow::DT_FLOAT_REF, tensor.float_val().size()}, - {tensorflow::DT_INT32_REF, tensor.int_val().size()}, - {tensorflow::DT_INT64_REF, tensor.int64_val().size()}, - {tensorflow::DT_BOOL_REF, tensor.bool_val().size()}, - {tensorflow::DT_HALF_REF, tensor.half_val().size()}, - {tensorflow::DT_INT8_REF, tensor.int_val().size()}, - {tensorflow::DT_UINT8_REF, tensor.int_val().size()}, - {tensorflow::DT_INT16_REF, tensor.int_val().size()}, - {tensorflow::DT_UINT16_REF, tensor.int_val().size()}, - {tensorflow::DT_DOUBLE_REF, tensor.double_val().size()}, - {tensorflow::DT_STRING_REF, tensor.string_val().size()}, - {tensorflow::DT_QINT8_REF, tensor.int_val().size()}, - {tensorflow::DT_QINT16_REF, tensor.int_val().size()}, - {tensorflow::DT_QINT32_REF, tensor.int_val().size()}, - {tensorflow::DT_QUINT8_REF, tensor.int_val().size()}, - {tensorflow::DT_QUINT16_REF, tensor.int_val().size()}, - {tensorflow::DT_COMPLEX32_REF, tensor.icomplex_val().size()}, - {tensorflow::DT_COMPLEX64_REF, tensor.scomplex_val().size()}, - {tensorflow::DT_COMPLEX128_REF, tensor.dcomplex_val().size()}, - {tensorflow::DT_BFLOAT16_REF, tensor.half_val().size()}, - {tensorflow::DT_UINT32_REF, tensor.uint32_val().size()}, - {tensorflow::DT_UINT64_REF, tensor.uint64_val().size()}, - {tensorflow::DT_RESOURCE_REF, tensor.resource_handle_val().size()}, - {tensorflow::DT_VARIANT_REF, tensor.variant_val().size()}, - }; - const tensorflow::DataType data_type = tensor.dtype(); - int64_t datatype_val_size = 0; - - const auto iter = datatype_val_size_map.find(data_type); - if (iter != datatype_val_size_map.cend()) { - datatype_val_size = iter->second; - } else { - GE_CHECK_GE(data_type, 0); - GE_LOGE("datatype:%s not support.", DataType_Name(data_type).c_str()); - return FAILED; - } - - std::vector shape_vec; - // There is tensor shape, get the dimension - int64_t count = 1; - GE_IF_BOOL_EXEC( - tensor.has_tensor_shape(), const tensorflow::TensorShapeProto &tensor_shape = tensor.tensor_shape(); - for (int32_t i = 0; i < tensor_shape.dim_size(); i++) { - const tensorflow::TensorShapeProto_Dim &shape_dim = tensor_shape.dim(i); - shape_vec.push_back(shape_dim.size()); - const int64_t dim = shape_vec[static_cast(i)]; - // tensorflow support weights shape [0],have no weights - if (dim < 0) { - GELOGE(FAILED, "Dim size invalid"); - return FAILED; - } - if ((count != 0) && (dim >= (std::numeric_limits::max() / count))) { - GELOGE(FAILED, "Dim size exceeds INT64_MAX"); - return FAILED; - } - count *= dim; - }); - const GeShape shape(shape_vec); - GeTensorDesc tmp_desc = weight->GetTensorDesc(); - tmp_desc.SetShape(shape); - - // Fixed input ND - tmp_desc.SetFormat(ge::Format::FORMAT_ND); - tmp_desc.SetOriginFormat(ge::Format::FORMAT_ND); - - weight->SetTensorDesc(tmp_desc); - - if (datatype_val_size > 0 || ((datatype_val_size == 0) && (count > 0) && tensor.tensor_content().empty())) { - SetGeTensorWeightData(tensor, datatype_val_size, count, weight); - const int64_t origin_element_num = static_cast(datatype_val_size); - GE_CHK_BOOL_EXEC(ge::AttrUtils::SetInt(weight->MutableTensorDesc(), kOriginElementNumAttrName, origin_element_num), - return FAILED, "Set origin element num failed."); - } else if (!tensor.tensor_content().empty()) { - const auto &tensor_content = tensor.tensor_content(); - SetWeightData(data_type, count, tensor_content, weight); - } else { - if (count == 0) { - GELOGI("Empty tensor, has no data."); - return SUCCESS; - } - GE_LOGE("value Attr tensor should have val() or tensor_content"); - return FAILED; - } - - return SUCCESS; -} - -Status TensorAssign::SetGeTensorDataType(const int64_t data_type, GeTensorPtr &weight) { - GE_CHECK_NOTNULL(weight); - GeTensorDesc tmp_desc = weight->GetTensorDesc(); - tmp_desc.SetDataType(static_cast(data_type)); - weight->SetTensorDesc(tmp_desc); - return SUCCESS; -} -} // namespace domi diff --git a/register/tuning_bank_key_registry.cc b/register/tuning_bank_key_registry.cc deleted file mode 100644 index 1f41cdb429c16678fab6e3a98f08762cb97f713e..0000000000000000000000000000000000000000 --- a/register/tuning_bank_key_registry.cc +++ /dev/null @@ -1,141 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/tuning_bank_key_registry.h" -#include "common/ge_common/debug/ge_log.h" - -namespace tuningtiling { -OpBankKeyFuncInfo::OpBankKeyFuncInfo(const ge::AscendString &optype) : optype_(optype) {} - -OpBankKeyFuncInfoV2::OpBankKeyFuncInfoV2(const ge::AscendString &optypeV2) : optypeV2_(optypeV2) {} - -// v1兼容老版本om -void OpBankKeyFuncInfo::SetOpConvertFunc(const OpBankKeyConvertFun &convert_func) { - convert_func_ = convert_func; -} - -void OpBankKeyFuncInfo::SetOpParseFunc(const OpBankParseFun &parse_func) { - parse_func_ = parse_func; -} - -void OpBankKeyFuncInfo::SetOpLoadFunc(const OpBankLoadFun &load_func) { - load_func_ = load_func; -} - - -// v2 -void OpBankKeyFuncInfoV2::SetOpConvertFuncV2(const OpBankKeyConvertFunV2 &convert_funcV2) { - convert_funcV2_ = convert_funcV2; -} - -void OpBankKeyFuncInfoV2::SetOpParseFuncV2(const OpBankParseFunV2 &parse_funcV2) { - parse_funcV2_ = parse_funcV2; -} - -void OpBankKeyFuncInfoV2::SetOpLoadFuncV2(const OpBankLoadFunV2 &load_funcV2) { - load_funcV2_ = load_funcV2; -} - -// v1兼容老版本om -const OpBankKeyConvertFun& OpBankKeyFuncInfo::OpBankKeyFuncInfo::GetBankKeyConvertFunc() const { - return convert_func_; -} - -const OpBankParseFun& OpBankKeyFuncInfo::GetBankKeyParseFunc() const { - return parse_func_; -} - -const OpBankLoadFun& OpBankKeyFuncInfo::GetBankKeyLoadFunc() const { - return load_func_; -} - -// v2 -const OpBankKeyConvertFunV2& OpBankKeyFuncInfoV2::OpBankKeyFuncInfoV2::GetBankKeyConvertFuncV2() const { - return convert_funcV2_; -} - -const OpBankParseFunV2& OpBankKeyFuncInfoV2::GetBankKeyParseFuncV2() const { - return parse_funcV2_; -} - -const OpBankLoadFunV2& OpBankKeyFuncInfoV2::GetBankKeyLoadFuncV2() const { - return load_funcV2_; -} - -// v1兼容老版本om -std::unordered_map &OpBankKeyFuncRegistry::RegisteredOpFuncInfo() { - static std::unordered_map op_func_map; - return op_func_map; -} - -// v2 -std::unordered_map &OpBankKeyFuncRegistryV2::RegisteredOpFuncInfoV2() { - static std::unordered_map op_func_mapV2; - return op_func_mapV2; -} - -extern "C" void _ZN12tuningtiling21OpBankKeyFuncRegistryC1ERKN2ge12AscendStringERKSt8functionIFbRKSt10shared_ptrIvEmRN15ascend_nlohmann10basic_jsonISt3mapSt6vectorSsblmdSaNSA_14adl_serializerESD_IhSaIhEEEEEERKS5_IFbRS7_RmRKSH_EE() {} - -extern "C" void _ZN12tuningtiling21OpBankKeyFuncRegistryC1ERKN2ge12AscendStringERKSt8functionIFbRKSt10shared_ptrIvEmRN15ascend_nlohmann16json_abi_v3_11_210basic_jsonISt3mapSt6vectorSsblmdSaNSB_14adl_serializerESE_IhSaIhEEEEEERKS5_IFbRS7_RmRKSI_EE() {} -// v1兼容老版本om -OpBankKeyFuncRegistry::OpBankKeyFuncRegistry(const ge::AscendString &optype, const OpBankKeyConvertFun &convert_func) { - auto &op_func_map = RegisteredOpFuncInfo(); - const auto iter = op_func_map.find(optype); - if (iter == op_func_map.cend()) { - OpBankKeyFuncInfo op_func_info(optype); - op_func_info.SetOpConvertFunc(convert_func); - (void)op_func_map.emplace(optype, op_func_info); - } else { - iter->second.SetOpConvertFunc(convert_func); - } -} - -OpBankKeyFuncRegistry::OpBankKeyFuncRegistry(const ge::AscendString &optype, - const OpBankParseFun &parse_func, const OpBankLoadFun &load_func) { - auto &op_func_map = RegisteredOpFuncInfo(); - const auto iter = op_func_map.find(optype); - if (iter == op_func_map.cend()) { - OpBankKeyFuncInfo op_func_info(optype); - op_func_info.SetOpParseFunc(parse_func); - op_func_info.SetOpLoadFunc(load_func); - (void)op_func_map.emplace(optype, op_func_info); - } else { - iter->second.SetOpParseFunc(parse_func); - iter->second.SetOpLoadFunc(load_func); - } -} - -// v2接口 -OpBankKeyFuncRegistryV2::OpBankKeyFuncRegistryV2(const ge::AscendString &optype, const OpBankKeyConvertFunV2 &convert_funcV2) { - auto &op_func_mapV2 = RegisteredOpFuncInfoV2(); - const auto iter = op_func_mapV2.find(optype); - if (iter == op_func_mapV2.cend()) { - OpBankKeyFuncInfoV2 op_func_info(optype); - op_func_info.SetOpConvertFuncV2(convert_funcV2); - (void)op_func_mapV2.emplace(optype, op_func_info); - } else { - iter->second.SetOpConvertFuncV2(convert_funcV2); - } -} - -OpBankKeyFuncRegistryV2::OpBankKeyFuncRegistryV2(const ge::AscendString &optype, const OpBankParseFunV2 &parse_funcV2, - const OpBankLoadFunV2 &load_funcV2) { - auto &op_func_map = RegisteredOpFuncInfoV2(); - const auto iter = op_func_map.find(optype); - if (iter == op_func_map.cend()) { - OpBankKeyFuncInfoV2 op_func_info(optype); - op_func_info.SetOpParseFuncV2(parse_funcV2); - op_func_info.SetOpLoadFuncV2(load_funcV2); - (void)op_func_map.emplace(optype, op_func_info); - } else { - iter->second.SetOpParseFuncV2(parse_funcV2); - iter->second.SetOpLoadFuncV2(load_funcV2); - } -} -} // namespace tuningtiling diff --git a/register/tuning_tiling_registry.cc b/register/tuning_tiling_registry.cc deleted file mode 100644 index 2482cb64da5ddd4ab51862fdd465d7c543c2c734..0000000000000000000000000000000000000000 --- a/register/tuning_tiling_registry.cc +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/tuning_tiling_registry.h" -#include "common/ge_common/debug/ge_log.h" - -namespace tuningtiling { -ge::AscendString TuningTilingDef::GetClassName() const { - return class_name_; -} - -std::map &TuningTilingClassFactory::RegisterInfo() { - static std::map instance; - return instance; -} - -void TuningTilingClassFactory::RegisterTilingData(const ge::AscendString &optype, - TuningTilingDefConstructor const constructor) { - if (constructor == nullptr) { - return; - } - auto &instance = TuningTilingClassFactory::RegisterInfo(); - instance[optype] = constructor; - GELOGI("optype: %s, registered count: %zu", optype.GetString(), instance.size()); -} - -std::shared_ptr TuningTilingClassFactory::CreateTilingDataInstance(const ge::AscendString &optype) { - const auto &instance = TuningTilingClassFactory::RegisterInfo(); - const auto it = instance.find(optype); - if (it == instance.cend()) { - GELOGW("can not find optype: %s", optype.GetString()); - return nullptr; - } - - TuningTilingDefConstructor const constructor = it->second; - - if (constructor == nullptr) { - GELOGW("CreateTilingDataInstance: constructor is nullptr"); - return nullptr; - } - - return (*constructor)(); -} -} // namespace tuningtiling diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index cd061a066f1248029d63f9b7a9e65bb55462e5e4..f963318931ceaed24c76e94f8f9c560795cce46c 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -12,8 +12,6 @@ if (ENABLE_METADEF_UT OR ENABLE_METADEF_ST OR ENABLE_METADEF_COV) set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -fsanitize=address -fsanitize=leak -fsanitize-recover=address") endif() -target_compile_definitions(graph PRIVATE FUNC_VISIBILITY) - add_subdirectory(depends/slog) add_subdirectory(depends/mmpa) add_subdirectory(depends/platform) @@ -25,4 +23,3 @@ stub_module(platform platform_stub) stub_module(runtime runtime_stub) add_subdirectory(ut) -add_subdirectory(benchmark) diff --git a/tests/benchmark/CMakeLists.txt b/tests/benchmark/CMakeLists.txt deleted file mode 100644 index 9f204d92dfaf049fe274cc9817f129b5005b7d17..0000000000000000000000000000000000000000 --- a/tests/benchmark/CMakeLists.txt +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright (c) 2024 Huawei Technologies Co., Ltd. -# This file is a part of the CANN Open Software. -# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. -# ====================================================================================================================== - -add_subdirectory(exe_graph) -add_subdirectory(fast_graph) diff --git a/tests/benchmark/exe_graph/CMakeLists.txt b/tests/benchmark/exe_graph/CMakeLists.txt deleted file mode 100644 index c85da946827feac1e9e673dec102847672fe5e75..0000000000000000000000000000000000000000 --- a/tests/benchmark/exe_graph/CMakeLists.txt +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright (c) 2024 Huawei Technologies Co., Ltd. -# This file is a part of the CANN Open Software. -# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. -# ====================================================================================================================== - -file(GLOB_RECURSE EXE_GRAPH_SRCS CONFIGURE_DEPENDS "*.cc") - -add_executable(exec_graph_benchmark ${EXE_GRAPH_SRCS}) - -target_link_libraries(exec_graph_benchmark PRIVATE intf_pub) - -target_include_directories(exec_graph_benchmark PRIVATE - ${AIR_CODE_DIR}/tests/ut/ge/runtime/fast_v2 - ${AIR_CODE_DIR}/runtime/v2 - ${AIR_CODE_DIR}/inc/framework - ${METADEF_DIR}/exe_graph - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/metadef_protos - ) - -target_link_libraries(exec_graph_benchmark PRIVATE benchmark::benchmark exe_graph lowering error_manager - slog_stub ascend_protobuf c_sec mmpa_stub -lrt -ldl - metadef_headers - graph - ) - -set_target_properties(exec_graph_benchmark PROPERTIES CXX_STANDARD 17) -target_compile_options(exec_graph_benchmark PRIVATE -O2) diff --git a/tests/benchmark/exe_graph/tiling_data_benchmark.cc b/tests/benchmark/exe_graph/tiling_data_benchmark.cc deleted file mode 100644 index 8d04183a02012bba22050c16ff297901843fa7b1..0000000000000000000000000000000000000000 --- a/tests/benchmark/exe_graph/tiling_data_benchmark.cc +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/runtime/tiling_data.h" -#include - -namespace gert { -namespace { -struct TestData { - int64_t a; - int32_t b; - int16_t c; - int16_t d; -}; -} -static void TilingData_AppendBasicType(benchmark::State &state) { - auto data = TilingData::CreateCap(2 * 1024); - auto tiling_data = reinterpret_cast(data.get()); - - for (auto _ : state) { - tiling_data->Append(10); - tiling_data->SetDataSize(0); - } -} -BENCHMARK(TilingData_AppendBasicType); - -static void TilingData_AppendStruct(benchmark::State &state) { - auto data = TilingData::CreateCap(2048); - auto tiling_data = reinterpret_cast(data.get()); - TestData td { - .a = 1024, - .b = 512, - .c = 256, - .d = 128 - }; - - for (auto _ : state) { - tiling_data->Append(td); - tiling_data->SetDataSize(0); - } -} -BENCHMARK(TilingData_AppendStruct); - -} - -BENCHMARK_MAIN(); diff --git a/tests/benchmark/fast_graph/CMakeLists.txt b/tests/benchmark/fast_graph/CMakeLists.txt deleted file mode 100644 index 7bb51bd251b20d3d3f5a72ff3cd2f5e11816b9cd..0000000000000000000000000000000000000000 --- a/tests/benchmark/fast_graph/CMakeLists.txt +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright (c) 2024 Huawei Technologies Co., Ltd. -# This file is a part of the CANN Open Software. -# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. -# ====================================================================================================================== - -set(BENCHMARK_TEST "compute_graph_base_unittest.cc") - -add_executable(fast_graph_benchmark ${BENCHMARK_TEST}) - -target_include_directories(fast_graph_benchmark PRIVATE - ${AIR_CODE_DIR}/tests/ut/ge/runtime/fast_v2 - ${AIR_CODE_DIR}/runtime/v2 - ${AIR_CODE_DIR}/inc/framework - ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/external - ${METADEF_DIR}/third_party/inc - ${METADEF_DIR}/third_party/inc/external - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/metadef_protos - ${CMAKE_BINARY_DIR}/proto/ge - ${CMAKE_BINARY_DIR}/proto/ge/proto - ) - -target_compile_definitions(fast_graph_benchmark PRIVATE - google=ascend_private - FUNC_VISIBILITY - ) - -set_target_properties(fast_graph_benchmark PROPERTIES CXX_STANDARD 17) - -target_compile_options(fast_graph_benchmark PRIVATE -O2 -std=c++17) - -target_link_libraries(fast_graph_benchmark PRIVATE benchmark::benchmark - intf_pub - metadef_headers - -Wl,--no-as-needed - graph - graph_base - c_sec - error_manager - slog - ascend_protobuf - $<$>:-lrt> - -ldl -) diff --git a/tests/benchmark/fast_graph/compute_graph_base_unittest.cc b/tests/benchmark/fast_graph/compute_graph_base_unittest.cc deleted file mode 100644 index 2787524f3c19cd65ff339c89d7a3c3344add0ea3..0000000000000000000000000000000000000000 --- a/tests/benchmark/fast_graph/compute_graph_base_unittest.cc +++ /dev/null @@ -1,1543 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include -#include -#include "graph/node.h" -#include "graph/compute_graph.h" -#include "graph/fast_graph/execute_graph.h" -#include "graph/normal_graph/compute_graph_impl.h" -#include "fast_graph/fast_graph_impl.h" -#include "fast_node_utils.h" -#include "node_utils.h" -#include "graph/op_desc.h" - -using namespace ge; - -#define GRAPH_CHECK_RET true - -class NodeBuilder { - public: - NodeBuilder(string name, string type, const std::shared_ptr &owner_graph) - : name_(std::move(name)), type_(std::move(type)), owner_graph_(owner_graph) {} - NodeBuilder &InputNum(int64_t num) { - input_num_ = num; - return *this; - } - NodeBuilder &OutputNum(int64_t num) { - output_num_ = num; - return *this; - } - - NodeBuilder &IoNum(int64_t input_num, int64_t output_num) { - return InputNum(input_num).OutputNum(output_num); - } - - FastNode *Build() const { - auto op_desc = std::make_shared(name_, type_); - auto td = GeTensorDesc(); - for (int64_t i = 0; i < input_num_; ++i) { - op_desc->AddInputDesc(td); - } - for (int64_t i = 0; i < output_num_; ++i) { - op_desc->AddOutputDesc(td); - } - return owner_graph_->AddNode(op_desc); - } - - private: - std::string name_; - std::string type_; - std::shared_ptr owner_graph_; - int64_t input_num_ = 0; - int64_t output_num_ = 0; -}; - -void OpDescCreate(int64_t node_num, std::shared_ptr *op_desc, int64_t io_num) { - for (int64_t j = 0; j < node_num; j++) { - op_desc[j] = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - - for (int64_t i = 0; i < io_num; ++i) { - op_desc[j]->AddInputDesc(td); - } - for (int64_t i = 0; i < io_num; ++i) { - op_desc[j]->AddOutputDesc(td); - } - } -} - -static void OLD_Graph_Creation(benchmark::State &state) { - for (auto _ : state) { - auto compute_graph = std::make_shared("graph"); - benchmark::DoNotOptimize(compute_graph); - benchmark::ClobberMemory(); - } -} -BENCHMARK(OLD_Graph_Creation); - -static void NEW_Graph_Creation(benchmark::State &state) { - for (auto _ : state) { - auto compute_graph = std::make_shared("graph"); - benchmark::DoNotOptimize(compute_graph); - benchmark::ClobberMemory(); - } -} -BENCHMARK(NEW_Graph_Creation); - -static void TEST_HASH_TIME(benchmark::State &state) { - std::string test = "hello, world."; - int loop_num = 1000; - for (int i = 0; i < loop_num; ++i) { - test += "a"; - } - for (auto _ : state) { - auto size = std::hash{}(test); - benchmark::DoNotOptimize(size); - benchmark::ClobberMemory(); - } -} -BENCHMARK(TEST_HASH_TIME); - -static void OLD_Graph_AddAndRemoveSingleNode(benchmark::State &state) { - auto compute_graph = std::make_shared("graph0"); - int64_t io_num = state.range(0); - int64_t node_num = state.range(1); - std::shared_ptr op_desc[node_num] = {nullptr}; - OpDescCreate(node_num, op_desc, io_num); - for (int j = 0; j < node_num; j++) { - if (j != 0) compute_graph->AddNode(op_desc[j]); - } - - ge::NodePtr node[node_num]; - for (auto _ : state) { - node[0] = compute_graph->AddNode(op_desc[0]); -#if GRAPH_CHECK_RET - if (node[0] == nullptr) { - std::cout << "Graph_AddNode Error" << std::endl; - return; - } -#endif - GraphUtils::RemoveJustNode(compute_graph, node[0]); - benchmark::DoNotOptimize(node[0]); - benchmark::ClobberMemory(); - } -} -BENCHMARK(OLD_Graph_AddAndRemoveSingleNode)->Args({20, 10})->Args({20, 100})->Args({20, 1000})->Args({20, 10000}); - -static void NEW_Graph_AddAndRemoveSingleNode(benchmark::State &state) { - auto compute_graph = std::make_shared("graph1"); - int64_t node_num = state.range(1); - std::shared_ptr op_desc[node_num] = {nullptr}; - - int64_t io_num = state.range(0); - OpDescCreate(node_num, op_desc, io_num); - for (int j = 0; j < node_num; j++) { - if (j != 0) compute_graph->AddNode(op_desc[j]); - } - - FastNode *node[node_num] = {}; - for (auto _ : state) { - node[0] = compute_graph->AddNode(op_desc[0]); -#if GRAPH_CHECK_RET - if (node[0] == nullptr) { - std::cout << "Graph_AddNode Error" << std::endl; - return; - } -#endif - compute_graph->RemoveJustNode(node[0]); - - benchmark::DoNotOptimize(node[0]); - benchmark::ClobberMemory(); - } -} -BENCHMARK(NEW_Graph_AddAndRemoveSingleNode)->Args({20, 10})->Args({20, 100})->Args({20, 1000})->Args({20, 10000}); - -static void NEW_Graph_AddAndRemoveMultiNode(benchmark::State &state) { - auto compute_graph = std::make_shared("graph2"); - int64_t io_num = state.range(0); - int64_t node_num = state.range(1); - std::shared_ptr op_desc[node_num] = {nullptr}; - OpDescCreate(node_num, op_desc, io_num); - - FastNode *node[node_num] = {}; - for (auto _ : state) { - for (int64_t j = 0; j < node_num; j++) { - node[j] = compute_graph->AddNode(op_desc[j]); - } - - for (int64_t j = 0; j < node_num; j++) { - compute_graph->RemoveJustNode(node[j]); - } - benchmark::DoNotOptimize(node[0]); - benchmark::ClobberMemory(); - } -} -BENCHMARK(NEW_Graph_AddAndRemoveMultiNode)->Args({20, 10000})->Args({20, 100000}); - -static void OLD_Graph_AddAndRemoveMultiNode(benchmark::State &state) { - auto compute_graph = std::make_shared("graph01"); - int64_t io_num = state.range(0); - int64_t node_num = state.range(1); - std::shared_ptr op_desc[node_num] = {nullptr}; - OpDescCreate(node_num, op_desc, io_num); - - ge::NodePtr node[node_num]; - for (auto _ : state) { - for (int64_t j = 0; j < node_num; j++) { - node[j] = compute_graph->AddNode(op_desc[j]); - } - - for (int64_t j = 0; j < node_num; j++) { - GraphUtils::RemoveJustNode(compute_graph, node[j]); - } - benchmark::ClobberMemory(); - } -} -BENCHMARK(OLD_Graph_AddAndRemoveMultiNode)->Args({20, 10000})->Args({20, 100000}); - -static void OLD_Graph_ADD_NODE_WITH_NODE(benchmark::State &state) { - int64_t io_num = state.range(0); - int64_t node_num = state.range(1); - auto root_graph2 = std::make_shared("root_graph2"); - auto root_graph = std::make_shared("root_graph"); - std::shared_ptr op_desc[node_num] = {nullptr}; - OpDescCreate(node_num, op_desc, io_num); - - NodePtr node[node_num] = {}; - for (auto _ : state) { - for (int i = 0; i < node_num; i++) { - node[i] = root_graph2->AddNode(op_desc[i]); -#if GRAPH_CHECK_RET - if (node[i] == nullptr) { - std::cout << "NEW_Graph_ADD_NODE_WITH_NODE AddNode Error" << std::endl; - return; - } -#endif - auto ret = GraphUtils::RemoveJustNode(root_graph2, node[i]); -#if GRAPH_CHECK_RET - if (ret != GRAPH_SUCCESS) { - std::cout << "NEW_Graph_ADD_NODE_WITH_NODE RemoveJustNode Error" << std::endl; - return; - } -#endif - - node[i] = root_graph->AddNode(node[i]); -#if GRAPH_CHECK_RET - if (node[i] == nullptr) { - std::cout << "NEW_Graph_ADD_NODE_WITH_NODE AddNode Error" << std::endl; - return; - } -#endif - - ret = GraphUtils::RemoveJustNode(root_graph, node[i]); -#if GRAPH_CHECK_RET - if (ret != GRAPH_SUCCESS) { - std::cout << "NEW_Graph_ADD_NODE_WITH_NODE RemoveJustNode Error" << std::endl; - return; - } -#endif - } - } -} -BENCHMARK(OLD_Graph_ADD_NODE_WITH_NODE)->Args({20, 100}); -BENCHMARK(OLD_Graph_ADD_NODE_WITH_NODE)->Args({20, 1000}); -BENCHMARK(OLD_Graph_ADD_NODE_WITH_NODE)->Args({20, 10000}); -BENCHMARK(OLD_Graph_ADD_NODE_WITH_NODE)->Args({20, 50000}); - -static void NEW_Graph_ADD_NODE_WITH_NODE(benchmark::State &state) { - auto new_root_graph2 = std::make_shared("new_graph2"); - auto new_root_graph = std::make_shared("new_graph"); - int64_t io_num = state.range(0); - int64_t node_num = state.range(1); - FastNode *node[node_num] = {}; - std::shared_ptr op_desc[node_num] = {nullptr}; - OpDescCreate(node_num, op_desc, io_num); - - for (auto _ : state) { - for (int i = 0; i < node_num; i++) { - node[i] = new_root_graph2->AddNode(op_desc[i]); -#if GRAPH_CHECK_RET - if (node[i] == nullptr) { - std::cout << "NEW_Graph_ADD_NODE_WITH_NODE Add Node Error" << std::endl; - return; - } -#endif - auto ret = new_root_graph2->RemoveJustNode(node[i]); -#if GRAPH_CHECK_RET - if (ret != GRAPH_SUCCESS) { - std::cout << "NEW_Graph_ADD_NODE_WITH_NODE Remove Node Error" << std::endl; - return; - } -#endif - - node[i] = new_root_graph->AddNode(node[i]); -#if GRAPH_CHECK_RET - if (node[i] == nullptr) { - std::cout << "NEW_Graph_ADD_NODE_WITH_NODE Add Node Error" << std::endl; - return; - } -#endif - - ret = new_root_graph->RemoveJustNode(node[i]); -#if GRAPH_CHECK_RET - if (ret != GRAPH_SUCCESS) { - std::cout << "NEW_Graph_ADD_NODE_WITH_NODE Remove Node Error" << std::endl; - return; - } -#endif - } - } -} -BENCHMARK(NEW_Graph_ADD_NODE_WITH_NODE)->Args({20, 100}); -BENCHMARK(NEW_Graph_ADD_NODE_WITH_NODE)->Args({20, 1000}); -BENCHMARK(NEW_Graph_ADD_NODE_WITH_NODE)->Args({20, 10000}); -BENCHMARK(NEW_Graph_ADD_NODE_WITH_NODE)->Args({20, 50000}); - -static void OLD_Graph_GetDirectNode(benchmark::State &state) { - auto compute_graph = std::make_shared("graph"); - int64_t io_num = state.range(0); - int64_t node_num = state.range(1); - std::shared_ptr op_desc[node_num] = {nullptr}; - OpDescCreate(node_num, op_desc, io_num); - - for (int j = 0; j < node_num; j++) { - compute_graph->AddNode(op_desc[j]); - } - - for (auto _ : state) { - auto ret = compute_graph->GetDirectNode(); - if (ret.size() == 0) { - std::cout << "OLD GetDirectNode Error " << std::endl; - return; - } - - benchmark::DoNotOptimize(ret); - benchmark::ClobberMemory(); - } -} -BENCHMARK(OLD_Graph_GetDirectNode)->Args({20, 10})->Args({20, 100})->Args({20, 1000})->Args({20, 10000})->Iterations(1); - -static void New_Graph_GetDirectNode(benchmark::State &state) { - auto compute_graph = std::make_shared(std::string("graph")); - int64_t io_num = state.range(0); - int64_t node_num = state.range(1); - std::shared_ptr op_desc[node_num] = {nullptr}; - OpDescCreate(node_num, op_desc, io_num); - for (int j = 0; j < node_num; j++) { - (void)compute_graph->AddNode(op_desc[j]); - } - - for (auto _ : state) { - auto ret = compute_graph->GetDirectNode(); - if (ret.size() == 0) { - std::cout << "OLD GetDirectNode Error " << std::endl; - return; - } - - benchmark::DoNotOptimize(ret); - benchmark::ClobberMemory(); - } -} -BENCHMARK(New_Graph_GetDirectNode)->Args({20, 10})->Args({20, 100})->Args({20, 1000})->Args({20, 10000})->Iterations(1); - -static void Graph_AddAndRemoveEdge(benchmark::State &state) { - auto compute_graph = std::make_shared("graph"); - int num = state.range(1); - int vec_size = state.range(0); - std::vector vec; - vec.resize(vec_size); - - for (int i = 0; i < vec_size; i++) { - vec[i] = NodeBuilder("Node" + std::to_string(i), "Node", compute_graph).IoNum(num, num).Build(); -#if GRAPH_CHECK_RET - if (vec[i] == nullptr) { - std::cout << "Graph_AddEdge vec[i] Error " << i << std::endl; - return; - } -#endif - } - - FastEdge *edge[num] = {}; - for (auto _ : state) { - for (int i = 0; i < num; i++) { - edge[i] = compute_graph->AddEdge(vec[1], i, vec[0], i); -#if GRAPH_CHECK_RET - if (edge[i] == nullptr) { - std::cout << "Graph_AddEdge Error " << std::endl; - return; - } -#endif - } - - for (int i = 0; i < num; i++) { - auto ret = compute_graph->RemoveEdge(edge[i]); -#if GRAPH_CHECK_RET - if (ret != GRAPH_SUCCESS) { - std::cout << "OLD_Graph_AddAndRemoveEdge RemoveEdge Error " << i << std::endl; - return; - } -#endif - } - benchmark::ClobberMemory(); - } -} -BENCHMARK(Graph_AddAndRemoveEdge)->Args({20, 20})->Args({20, 100})->Args({20, 1000})->Args({20, 10000})->Iterations(1); - -static void OLD_Graph_AddAndRemoveEdge(benchmark::State &state) { - auto compute_graph = std::make_shared("graph"); - int64_t io_num = state.range(0); - int64_t node_num = state.range(1); - std::vector nodes; - nodes.resize(node_num); - - std::shared_ptr op_desc[node_num] = {nullptr}; - OpDescCreate(node_num, op_desc, io_num); - for (int j = 0; j < node_num; j++) { - auto tmp = compute_graph->AddNode(op_desc[j]); - nodes[j] = tmp.get(); - } - - for (auto _ : state) { - for (int i = 1; i < node_num; i++) { - auto ret = GraphUtils::AddEdge(nodes[i]->GetOutDataAnchor(0), nodes[i - 1]->GetInDataAnchor(0)); -#if GRAPH_CHECK_RET - if (ret != GRAPH_SUCCESS) { - std::cout << "OLD_Graph_AddAndRemoveEdge AddEdge Error " << i << std::endl; - return; - } -#endif - } - - for (int i = 1; i < node_num; i++) { - auto ret = GraphUtils::RemoveEdge(nodes[i]->GetOutDataAnchor(0), nodes[i - 1]->GetInDataAnchor(0)); -#if GRAPH_CHECK_RET - if (ret != GRAPH_SUCCESS) { - std::cout << "OLD_Graph_AddAndRemoveEdge RemoveEdge Error " << i << std::endl; - return; - } -#endif - } - benchmark::ClobberMemory(); - } -} -BENCHMARK(OLD_Graph_AddAndRemoveEdge) - ->Args({20, 20}) - ->Args({20, 100}) - ->Args({20, 1000}) - ->Args({20, 10000}) - ->Iterations(1); - -static void Graph_GetAllEdge(benchmark::State &state) { - auto compute_graph = std::make_shared("graph"); - std::vector vec; - int node_num = 2; - vec.resize(node_num); - int edge_num = state.range(0); - - for (int i = 0; i < node_num; i++) { - vec[i] = NodeBuilder("Node" + std::to_string(i), "Node", compute_graph).IoNum(edge_num, edge_num).Build(); - } - - FastEdge *edge[edge_num] = {}; - for (int i = 0; i < edge_num; i++) { - edge[i] = compute_graph->AddEdge(vec[1], i, vec[0], i); - } - - for (auto _ : state) { - auto ret = compute_graph->GetAllEdges(); -#if GRAPH_CHECK_RET - if (ret.size() == 0) { - std::cout << "Graph_GetAllOutEdge Error " << std::endl; - return; - } -#endif - - benchmark::DoNotOptimize(ret); - benchmark::ClobberMemory(); - } -} -BENCHMARK(Graph_GetAllEdge)->Arg(20)->Arg(100)->Arg(1000)->Arg(10000); - -static void Graph_AddAndRemoveSubgraph(benchmark::State &state) { - auto root_graph = std::make_shared("root_graph"); - auto subgraph_num = state.range(); - int edge_num = 5; - FastNode *node[subgraph_num] = {}; - std::shared_ptr op_desc[subgraph_num] = {nullptr}; - OpDescCreate(subgraph_num, op_desc, edge_num); - - for (int i = 0; i < subgraph_num; ++i) { - node[i] = root_graph->AddNode(op_desc[i]); - } - - std::shared_ptr sub_graph[subgraph_num] = {nullptr}; - for (int i = 0; i < subgraph_num; i++) { - sub_graph[i] = std::make_shared("subgraph_" + std::to_string(i)); - sub_graph[i]->SetParentGraph(root_graph.get()); - sub_graph[i]->SetParentNode(node[i]); - } - - for (int i = 0; i < subgraph_num - 1; i++) { - auto ret = root_graph->AddSubGraph(sub_graph[i]); -#if GRAPH_CHECK_RET - if (ret == nullptr) { - std::cout << "Graph_RemoveHeadEdge Error" << std::endl; - return; - } -#endif - } - - for (auto _ : state) { - auto ret = root_graph->AddSubGraph(sub_graph[subgraph_num - 1]); -#if GRAPH_CHECK_RET - if (ret == nullptr) { - std::cout << "Graph_RemoveHeadEdge Error" << std::endl; - return; - } -#endif - root_graph->RemoveSubGraph(ret); - - benchmark::DoNotOptimize(ret); - benchmark::ClobberMemory(); - } -} -BENCHMARK(Graph_AddAndRemoveSubgraph)->Arg(10)->Arg(100)->Arg(1000)->Arg(10000); - -static void OLD_Graph_AddAndRemoveSubgraph(benchmark::State &state) { - auto root_graph = std::make_shared("root_graph"); - auto subgraph_num = state.range(); - int edge_num = 5; - NodePtr node[subgraph_num] = {}; - std::shared_ptr op_desc[subgraph_num] = {nullptr}; - ComputeGraphPtr sub_graph[subgraph_num] = {nullptr}; - - OpDescCreate(subgraph_num, op_desc, edge_num); - for (int i = 0; i < subgraph_num; ++i) { - node[i] = root_graph->AddNode(op_desc[i]); - } - for (int i = 0; i < subgraph_num; i++) { - sub_graph[i] = std::make_shared("subgraph_" + std::to_string(i)); - sub_graph[i]->SetParentGraph(root_graph); - sub_graph[i]->SetParentNode(node[i]); - } - for (int i = 0; i < subgraph_num - 1; i++) { - auto ret = root_graph->AddSubgraph(sub_graph[i]); -#if GRAPH_CHECK_RET - if (ret != GRAPH_SUCCESS) { - std::cout << "OLD_Graph_AddAndRemoveSubgraph 0 Error" << std::endl; - return; - } -#endif - } - - for (auto _ : state) { - auto ret = root_graph->AddSubgraph(sub_graph[subgraph_num - 1]); -#if GRAPH_CHECK_RET - if (ret != GRAPH_SUCCESS) { - std::cout << "OLD_Graph_AddAndRemoveSubgraph 1 Error" << std::endl; - return; - } -#endif - root_graph->RemoveSubGraph(sub_graph[subgraph_num - 1]); - - benchmark::DoNotOptimize(ret); - benchmark::ClobberMemory(); - } -} -BENCHMARK(OLD_Graph_AddAndRemoveSubgraph)->Arg(10)->Arg(100)->Arg(1000)->Arg(10000); - -static void Graph_Sort(benchmark::State &state) { - auto compute_graph = std::make_shared("graph"); - std::vector vec; - int io_num = state.range(0); - int node_num = state.range(1); - vec.resize(node_num); - - for (int i = 0; i < node_num; i++) { - vec[i] = NodeBuilder("Node" + std::to_string(i), "Node", compute_graph).IoNum(io_num, io_num).Build(); -#if GRAPH_CHECK_RET - if (vec[i] == nullptr) { - std::cout << "Graph_Sort Error." << std::endl; - return; - } -#endif - } - - FastEdge *edge[node_num] = {}; - for (int j = 1; j < node_num; j++) { - edge[j] = compute_graph->AddEdge(vec[j - 1], 1, vec[j], 0); -#if GRAPH_CHECK_RET - if (edge[j] == nullptr) { - std::cout << "Graph_Sort Error." << std::endl; - return; - } -#endif - } - - for (auto _ : state) { - auto ret = compute_graph->TopologicalSortingGraph(compute_graph.get(), true); -#if GRAPH_CHECK_RET - if (ret != GRAPH_SUCCESS) { - std::cout << "Graph_Sort Error: " << ret << std::endl; - return; - } -#endif - benchmark::DoNotOptimize(ret); - benchmark::ClobberMemory(); - } -} -BENCHMARK(Graph_Sort)->Args({20, 10000})->Args({20, 50000})->Args({20, 100000})->Args({20, 200000})->Iterations(1); - -static void OLD_Graph_Sort(benchmark::State &state) { - auto compute_graph = std::make_shared("graph"); - int io_num = state.range(0); - int node_num = state.range(1); - std::vector vec; - vec.resize(node_num); - - std::shared_ptr op_desc[node_num] = {nullptr}; - OpDescCreate(node_num, op_desc, io_num); - for (int i = 0; i < node_num; i++) { - vec[i] = compute_graph->AddNode(op_desc[i]); -#if GRAPH_CHECK_RET - if (vec[i] == nullptr) { - std::cout << "OLD_Graph_Sort Error: 0" << std::endl; - return; - } -#endif - } - - for (int j = 1; j < node_num; j++) { - auto ret = GraphUtils::AddEdge(vec[j - 1]->GetOutDataAnchor(1), vec[j]->GetInDataAnchor(0)); -#if GRAPH_CHECK_RET - if (ret != GRAPH_SUCCESS) { - std::cout << "OLD_Graph_Sort Error: 1" << std::endl; - return; - } -#endif - } - - for (auto _ : state) { - auto ret = compute_graph->TopologicalSortingGraph(true); -#if GRAPH_CHECK_RET - if (ret != GRAPH_SUCCESS) { - std::cout << "OLD_Graph_Sort Error: 2" << std::endl; - return; - } -#endif - - benchmark::DoNotOptimize(ret); - benchmark::ClobberMemory(); - } -} -BENCHMARK(OLD_Graph_Sort)->Args({20, 10000})->Args({20, 50000})->Args({20, 100000})->Args({20, 200000})->Iterations(1); - -static void Graph_ALL_RUN(benchmark::State &state) { - int node_num = state.range(1); - std::shared_ptr op_desc[node_num] = {nullptr}; - int edge_num = state.range(0); - OpDescCreate(node_num, op_desc, edge_num); - - auto subgraph_num = state.range(2); - auto subgraph_node_num = state.range(3); - std::shared_ptr sub_graph[subgraph_num] = {nullptr}; - - FastNode *node[node_num] = {}; - FastEdge *edge[node_num] = {}; - ExecuteGraph *quick_graph[subgraph_num] = {nullptr}; - std::shared_ptr sub_op_desc[subgraph_num][subgraph_node_num] = {}; - for (int i = 0; i < subgraph_num; i++) { - OpDescCreate(subgraph_node_num, sub_op_desc[i], edge_num); - } - - auto root_graph = std::make_shared("root_graph"); - for (auto _ : state) { - for (int i = 0; i < node_num; i++) { - node[i] = root_graph->AddNode(op_desc[i]); -#if GRAPH_CHECK_RET - if (node[i] == nullptr) { - std::cout << "0 Graph_ALL_RUN Error" << std::endl; - return; - } -#endif - } - - for (int i = 1; i < node_num; i++) { - edge[i] = root_graph->AddEdge(node[i], 1, node[i - 1], 0); -#if GRAPH_CHECK_RET - if (edge[i] == nullptr) { - std::cout << "1 Graph_ALL_RUN Add Edge Error " << i << std::endl; - return; - } -#endif - } - - FastNode *sub_node[subgraph_num][subgraph_node_num] = {}; - for (int i = 0; i < subgraph_num; i++) { - sub_graph[i] = std::make_shared("subgraph_" + std::to_string(i)); - for (int j = 0; j < subgraph_node_num; j++) { - sub_node[i][j] = sub_graph[i]->AddNode(sub_op_desc[i][j]); -#if GRAPH_CHECK_RET - if (sub_node[i][j] == nullptr) { - std::cout << "Graph_ALL_RUN add subgraph node error." << std::endl; - return; - } -#endif - } - } - - for (int i = 0; i < subgraph_num; i++) { - for (int j = 1; j < subgraph_node_num; j++) { - auto ret = sub_graph[i]->AddEdge(sub_node[i][j], 1, sub_node[i][j - 1], 0); -#if GRAPH_CHECK_RET - if (ret == nullptr) { - std::cout << "1 Graph_ALL_RUN sub graph edge Error " << j << std::endl; - return; - } -#endif - } - } - - for (int i = 0; i < subgraph_num; ++i) { - quick_graph[i] = root_graph->AddSubGraph(sub_graph[i]); -#if GRAPH_CHECK_RET - if (quick_graph[i] == nullptr) { - std::cout << "2 Graph_ALL_RUN add subgraph Error" << std::endl; - return; - } -#endif - } - - root_graph->TopologicalSortingGraph(root_graph.get(), true); - - for (int i = 1; i < node_num; i++) { - root_graph->RemoveEdge(edge[i]); - } - - for (int i = 0; i < node_num; i++) { - root_graph->RemoveJustNode(node[i]); - } - - for (int i = 0; i < subgraph_num; ++i) { - root_graph->RemoveSubGraph(quick_graph[i]); - } - - benchmark::ClobberMemory(); - } -} -BENCHMARK(Graph_ALL_RUN)->Args({20, 2000, 1000, 10})->Iterations(1); -BENCHMARK(Graph_ALL_RUN)->Args({20, 4000, 1000, 10})->Iterations(1); -BENCHMARK(Graph_ALL_RUN)->Args({20, 6000, 1000, 10})->Iterations(1); -BENCHMARK(Graph_ALL_RUN)->Args({20, 8000, 1000, 10})->Iterations(1); - -static void OLD_Graph_ALL_RUN(benchmark::State &state) { - auto subgraph_num = state.range(2); - int node_num = state.range(1); - int edge_num = state.range(0); - std::shared_ptr op_desc[node_num] = {nullptr}; - OpDescCreate(node_num, op_desc, edge_num); - - auto subgraph_node_num = state.range(3); - ComputeGraphPtr sub_graph[subgraph_num] = {nullptr}; - std::shared_ptr sub_op_desc[subgraph_num][subgraph_node_num] = {}; - for (int i = 0; i < subgraph_num; i++) { - OpDescCreate(subgraph_node_num, sub_op_desc[i], edge_num); - } - - NodePtr node[node_num] = {}; - ComputeGraphPtr quick_graph[subgraph_num] = {nullptr}; - auto old_root_graph = std::make_shared("root_graph"); - for (auto _ : state) { - for (int i = 0; i < node_num; i++) { - node[i] = old_root_graph->AddNode(op_desc[i]); -#if GRAPH_CHECK_RET - if (node[i] == nullptr) { - std::cout << "OLD_Graph_ALL_RUN add node error." << std::endl; - return; - } -#endif - } - - for (int i = 1; i < node_num; i++) { - auto ret = GraphUtils::AddEdge(node[i]->GetOutDataAnchor(1), node[i - 1]->GetInDataAnchor(0)); -#if GRAPH_CHECK_RET - if (ret != GRAPH_SUCCESS) { - std::cout << "OLD_Graph_ALL_RUN add edge error" << std::endl; - return; - } -#endif - } - - NodePtr sub_graph_node[subgraph_num][subgraph_node_num] = {}; - for (int i = 0; i < subgraph_num; i++) { - sub_graph[i] = std::make_shared("subgraph_" + std::to_string(i)); - for (int j = 0; j < subgraph_node_num; j++) { - sub_graph_node[i][j] = sub_graph[i]->AddNode(sub_op_desc[i][j]); -#if GRAPH_CHECK_RET - if (sub_graph_node[i][j] == nullptr) { - std::cout << "OLD_Graph_ALL_RUN add node error." << std::endl; - return; - } -#endif - } - } - - for (int i = 0; i < subgraph_num; i++) { - for (int j = 1; j < subgraph_node_num; j++) { - GraphUtils::AddEdge(sub_graph_node[i][j]->GetOutDataAnchor(1), sub_graph_node[i][j - 1]->GetInDataAnchor(0)); - } - } - - for (int64_t i = 0; i < subgraph_num; ++i) { - quick_graph[i] = old_root_graph->AddSubGraph(sub_graph[i]); -#if GRAPH_CHECK_RET - if (quick_graph[i] == nullptr) { - std::cout << "2 OLD_Graph_ALL_RUN Error" << std::endl; - return; - } -#endif - } - - old_root_graph->TopologicalSortingGraph(true); - - for (int i = 1; i < node_num; i++) { - auto ret = GraphUtils::RemoveEdge(node[i]->GetOutDataAnchor(1), node[i - 1]->GetInDataAnchor(0)); -#if GRAPH_CHECK_RET - if (ret != GRAPH_SUCCESS) { - std::cout << "0 OLD_Graph_ALL_RUN Error" << std::endl; - return; - } -#endif - } - - for (int i = 0; i < node_num; i++) { - auto ret = GraphUtils::RemoveJustNode(old_root_graph, node[i]); -#if GRAPH_CHECK_RET - if (ret != GRAPH_SUCCESS) { - std::cout << "OLD_Graph_ALL_RUN remove node error." << std::endl; - return; - } -#endif - } - - for (int64_t i = 0; i < subgraph_num; ++i) { - old_root_graph->RemoveSubGraph(quick_graph[i]); - } - - benchmark::ClobberMemory(); - } -} -BENCHMARK(OLD_Graph_ALL_RUN)->Args({20, 2000, 1000, 10})->Iterations(1); -BENCHMARK(OLD_Graph_ALL_RUN)->Args({20, 4000, 1000, 10})->Iterations(1); -BENCHMARK(OLD_Graph_ALL_RUN)->Args({20, 6000, 1000, 10})->Iterations(1); -BENCHMARK(OLD_Graph_ALL_RUN)->Args({20, 8000, 1000, 10})->Iterations(1); - -static void Graph_AddAndRemoveSubgraph_Multi(benchmark::State &state) { - auto root_graph = std::make_shared("root_graph"); - auto subgraph_num = state.range(); - - std::shared_ptr sub_graph[subgraph_num] = {nullptr}; - for (int i = 0; i < subgraph_num; i++) { - sub_graph[i] = std::make_shared("subgraph_" + std::to_string(i)); - } - - int edge_num = 5; - FastNode *node[subgraph_num] = {}; - std::shared_ptr op_desc[subgraph_num] = {nullptr}; - OpDescCreate(subgraph_num, op_desc, edge_num); - for (int i = 0; i < subgraph_num; ++i) { - node[i] = root_graph->AddNode(op_desc[i]); - } - - for (int i = 0; i < subgraph_num; i++) { - sub_graph[i] = std::make_shared("subgraph_" + std::to_string(i)); - sub_graph[i]->SetParentGraph(root_graph.get()); - sub_graph[i]->SetParentNode(node[i]); - } - - ExecuteGraph *new_sub_graph[subgraph_num] = {nullptr}; - for (auto _ : state) { - for (int64_t i = 0; i < subgraph_num; ++i) { - new_sub_graph[i] = root_graph->AddSubGraph(sub_graph[i]); -#if GRAPH_CHECK_RET - if (new_sub_graph[i] == nullptr) { - std::cout << "Graph_RemoveHeadEdge Error" << std::endl; - return; - } -#endif - } - for (int64_t i = 0; i < subgraph_num; ++i) { - root_graph->RemoveSubGraph(new_sub_graph[i]); - } - - benchmark::ClobberMemory(); - } -} -BENCHMARK(Graph_AddAndRemoveSubgraph_Multi)->Arg(10)->Arg(100)->Arg(1000)->Arg(10000); - -static void Graph_AddAndRemoveSubgraph_Multi_OLD(benchmark::State &state) { - auto root_graph = std::make_shared("root_graph"); - auto subgraph_num = state.range(); - int edge_num = 5; - NodePtr node[subgraph_num] = {}; - std::shared_ptr op_desc[subgraph_num] = {nullptr}; - OpDescCreate(subgraph_num, op_desc, edge_num); - for (int i = 0; i < subgraph_num; ++i) { - node[i] = root_graph->AddNode(op_desc[i]); - } - - ComputeGraphPtr sub_graph[subgraph_num] = {nullptr}; - for (int i = 0; i < subgraph_num; i++) { - sub_graph[i] = std::make_shared("subgraph_" + std::to_string(i)); - sub_graph[i]->SetParentGraph(root_graph); - sub_graph[i]->SetParentNode(node[i]); - } - - for (auto _ : state) { - for (int64_t i = 0; i < subgraph_num; ++i) { - auto ret = root_graph->AddSubGraph(sub_graph[i]); -#if GRAPH_CHECK_RET - if (ret == nullptr) { - std::cout << "Graph_RemoveHeadEdge Error" << std::endl; - return; - } -#endif - } - for (int64_t i = 0; i < subgraph_num; ++i) { - root_graph->RemoveSubGraph(sub_graph[i]); - } - - benchmark::ClobberMemory(); - } -} -BENCHMARK(Graph_AddAndRemoveSubgraph_Multi_OLD)->Arg(10)->Arg(100)->Arg(1000)->Arg(10000); - -static void TEST_ANCHOR(benchmark::State &state) { - auto root_graph = std::make_shared("root_graph"); - auto edge_num = state.range(); - int node_num = 2; - NodePtr node[node_num] = {}; - - std::shared_ptr op_desc[node_num] = {nullptr}; - OpDescCreate(node_num, op_desc, edge_num); - for (int i = 0; i < node_num; ++i) { - node[i] = root_graph->AddNode(op_desc[i]); - } - - InDataAnchorPtr ptr[edge_num] = {}; - OutDataAnchorPtr out_ptr[edge_num] = {}; - - for (auto _ : state) { - node[0]->GetAllInAnchors(); - node[0]->GetAllOutAnchors(); - benchmark::ClobberMemory(); - } -} -BENCHMARK(TEST_ANCHOR)->Arg(10)->Arg(100)->Arg(1000)->Iterations(1); - -static void TEST_ANCHOR_PEER_GET(benchmark::State &state) { - auto root_graph = std::make_shared("root_graph"); - auto anchor_num = state.range(); - int node_num = 2; - NodePtr node[node_num] = {}; - InDataAnchorPtr ptr[anchor_num] = {}; - OutDataAnchorPtr out_ptr[anchor_num] = {}; - std::shared_ptr op_desc[node_num] = {nullptr}; - - OpDescCreate(node_num, op_desc, anchor_num); - for (int i = 0; i < node_num; ++i) { - node[i] = root_graph->AddNode(op_desc[i]); - } - - for (int j = 0; j < anchor_num; j++) { - GraphUtils::AddEdge(node[0]->GetOutDataAnchor(0), node[1]->GetInDataAnchor(j)); - } - - for (auto _ : state) { - auto ret = node[0]->GetOutDataAnchor(0)->GetPeerAnchors(); - benchmark::DoNotOptimize(ret); - benchmark::ClobberMemory(); - } -} -BENCHMARK(TEST_ANCHOR_PEER_GET)->Arg(10)->Arg(100)->Arg(1000); - -static void TEST_GET_IN_NODES(benchmark::State &state) { - auto root_graph = std::make_shared("root_graph"); - auto anchor_num = state.range(); - int node_num = 1001; - NodePtr node[node_num] = {}; - - for (int j = 0; j < node_num; j++) { - OpDescPtr op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int64_t i = 0; i < anchor_num; ++i) { - op_desc->AddInputDesc(td); - } - for (int64_t i = 0; i < anchor_num; ++i) { - op_desc->AddOutputDesc(td); - } - node[j] = root_graph->AddNode(op_desc); - } - - for (int j = 0; j < anchor_num; j++) { - GraphUtils::AddEdge(node[0]->GetOutDataAnchor(j), node[j]->GetInDataAnchor(j)); - } - - InDataAnchorPtr ptr[anchor_num] = {}; - OutDataAnchorPtr out_ptr[anchor_num] = {}; - - for (auto _ : state) { - auto ret = node[0]->GetOutNodes(); - benchmark::DoNotOptimize(ret); - benchmark::ClobberMemory(); - } -} -BENCHMARK(TEST_GET_IN_NODES)->Arg(10)->Arg(100)->Arg(1000); - -static void TEST_OLD_Graph_ALL_RUN(benchmark::State &state) { - int edge_num = state.range(0); - int node_num = state.range(1); - auto subgraph_num = state.range(2); - auto subgraph_node_num = state.range(3); - - std::shared_ptr op_desc[node_num] = {nullptr}; - ComputeGraphPtr sub_graph[subgraph_num] = {nullptr}; - std::shared_ptr sub_op_desc[subgraph_num][subgraph_node_num] = {}; - NodePtr node[node_num] = {}; - ComputeGraphPtr quick_graph[subgraph_num] = {nullptr}; - - OpDescCreate(node_num, op_desc, edge_num); - for (int i = 0; i < subgraph_num; i++) { - OpDescCreate(subgraph_num, sub_op_desc[i], edge_num); - } - - auto root_graph = std::make_shared("root_graph"); - for (auto _ : state) { - for (int i = 0; i < node_num; i++) { - node[i] = root_graph->AddNode(op_desc[i]); -#if GRAPH_CHECK_RET - if (node[i] == nullptr) { - std::cout << "0 OLD_Graph_ALL_RUN Error" << std::endl; - return; - } -#endif - } - - for (int i = 1; i < node_num; i++) { - auto ret = GraphUtils::AddEdge(node[i]->GetOutDataAnchor(0), node[i - 1]->GetInDataAnchor(0)); -#if GRAPH_CHECK_RET - if (ret != GRAPH_SUCCESS) { - std::cout << "0 OLD_Graph_ALL_RUN Error" << std::endl; - return; - } -#endif - ret = GraphUtils::AddEdge(node[i]->GetOutDataAnchor(1), node[i - 1]->GetInDataAnchor(1)); -#if GRAPH_CHECK_RET - if (ret != GRAPH_SUCCESS) { - std::cout << "0 OLD_Graph_ALL_RUN Error" << std::endl; - return; - } -#endif - } - - root_graph->TopologicalSortingGraph(true); - - for (int i = 1; i < node_num; i++) { - auto ret = GraphUtils::RemoveEdge(node[i]->GetOutDataAnchor(0), node[i - 1]->GetInDataAnchor(0)); -#if GRAPH_CHECK_RET - if (ret != GRAPH_SUCCESS) { - std::cout << "0 OLD_Graph_ALL_RUN Error" << std::endl; - return; - } -#endif - ret = GraphUtils::RemoveEdge(node[i]->GetOutDataAnchor(1), node[i - 1]->GetInDataAnchor(1)); - } - - for (int i = 0; i < node_num; i++) { - auto ret = GraphUtils::RemoveJustNode(root_graph, node[i]); -#if GRAPH_CHECK_RET - if (ret != GRAPH_SUCCESS) { - std::cout << "TEST_OLD_Graph_ALL_RUN remove node error." << std::endl; - return; - } -#endif - } - - benchmark::ClobberMemory(); - } -} -BENCHMARK(TEST_OLD_Graph_ALL_RUN)->Args({20, 500, 1, 0})->Iterations(100); - -static void OLD_GRAPH_DEEPCOPY(benchmark::State &state) { - auto subgraph_num = state.range(2); - int node_num = state.range(1); - std::shared_ptr op_desc[node_num] = {nullptr}; - for (int j = 0; j < node_num; j++) { - op_desc[j] = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int64_t i = 0; i < state.range(0); ++i) { - op_desc[j]->AddInputDesc(td); - } - for (int64_t i = 0; i < state.range(0); ++i) { - op_desc[j]->AddOutputDesc(td); - } - } - - auto subgraph_node_num = state.range(3); - ComputeGraphPtr sub_graph[subgraph_num] = {nullptr}; - - std::shared_ptr sub_op_desc[subgraph_num][subgraph_node_num] = {}; - for (int i = 0; i < subgraph_num; i++) { - for (int j = 0; j < subgraph_node_num; j++) { - sub_op_desc[i][j] = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - - for (int64_t x = 0; x < state.range(0); ++x) { - sub_op_desc[i][j]->AddInputDesc(td); - } - for (int64_t x = 0; x < state.range(0); ++x) { - sub_op_desc[i][j]->AddOutputDesc(td); - } - } - } - - NodePtr node[node_num] = {}; - ComputeGraphPtr quick_graph[subgraph_num] = {nullptr}; - auto root_graph = std::make_shared("root_graph"); - - for (int i = 0; i < node_num; i++) { - node[i] = root_graph->AddNode(op_desc[i]); -#if GRAPH_CHECK_RET - if (node[i] == nullptr) { - std::cout << "0 OLD_Graph_ALL_RUN Error" << std::endl; - return; - } -#endif - } - - for (int i = 1; i < node_num; i++) { - auto ret = GraphUtils::AddEdge(node[i]->GetOutDataAnchor(1), node[i - 1]->GetInDataAnchor(0)); -#if GRAPH_CHECK_RET - if (ret != GRAPH_SUCCESS) { - std::cout << "0 OLD_Graph_ALL_RUN Error" << std::endl; - return; - } -#endif - } - - for (int i = 0; i < subgraph_num; i++) { - sub_graph[i] = std::make_shared("subgraph_" + std::to_string(i)); - NodePtr sub_graph_node[subgraph_node_num] = {}; - for (int j = 0; j < subgraph_node_num; j++) { - sub_graph_node[j] = sub_graph[i]->AddNode(sub_op_desc[i][j]); -#if GRAPH_CHECK_RET - if (sub_graph_node[j] == nullptr) { - std::cout << "1111 Graph_ALL_RUN Error" << std::endl; - return; - } -#endif - } - for (int j = 1; j < subgraph_node_num; j++) { - GraphUtils::AddEdge(sub_graph_node[j]->GetOutDataAnchor(1), sub_graph_node[j - 1]->GetInDataAnchor(0)); - } - } - - for (int64_t i = 0; i < subgraph_num; ++i) { - quick_graph[i] = root_graph->AddSubGraph(sub_graph[i]); -#if GRAPH_CHECK_RET - if (quick_graph[i] == nullptr) { - std::cout << "2 OLD_Graph_ALL_RUN Error" << std::endl; - return; - } -#endif - } - - auto test_graph = std::make_shared("test"); - - for (auto _ : state) { - GraphUtils::CopyComputeGraph(root_graph, test_graph); - benchmark::ClobberMemory(); - } -} -BENCHMARK(OLD_GRAPH_DEEPCOPY)->Args({20, 10000, 1, 10})->Iterations(1); -BENCHMARK(OLD_GRAPH_DEEPCOPY)->Args({20, 50000, 10, 10})->Iterations(1); -BENCHMARK(OLD_GRAPH_DEEPCOPY)->Args({20, 50000, 100, 10})->Iterations(1); -BENCHMARK(OLD_GRAPH_DEEPCOPY)->Args({20, 50000, 1000, 10})->Iterations(1); - -static void NEW_GRAPH_DEEPCOPY(benchmark::State &state) { - int node_num = state.range(1); - std::shared_ptr op_desc[node_num] = {nullptr}; - for (int j = 0; j < node_num; j++) { - op_desc[j] = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - - for (int64_t i = 0; i < state.range(0); ++i) { - op_desc[j]->AddInputDesc(td); - } - for (int64_t i = 0; i < state.range(0); ++i) { - op_desc[j]->AddOutputDesc(td); - } - } - - auto subgraph_num = state.range(2); - auto subgraph_node_num = state.range(3); - std::shared_ptr sub_graph[subgraph_num] = {nullptr}; - - FastNode *node[node_num] = {}; - FastEdge *edge[node_num] = {}; - ExecuteGraph *quick_graph[subgraph_num] = {nullptr}; - std::shared_ptr sub_op_desc[subgraph_num][subgraph_node_num] = {}; - for (int i = 0; i < subgraph_num; i++) { - for (int j = 0; j < subgraph_node_num; j++) { - sub_op_desc[i][j] = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - - for (int64_t x = 0; x < state.range(0); ++x) { - sub_op_desc[i][j]->AddInputDesc(td); - } - for (int64_t x = 0; x < state.range(0); ++x) { - sub_op_desc[i][j]->AddOutputDesc(td); - } - } - } - - auto root_graph = std::make_shared("root_graph"); - for (int i = 0; i < node_num; i++) { - node[i] = root_graph->AddNode(op_desc[i]); -#if GRAPH_CHECK_RET - if (node[i] == nullptr) { - std::cout << "0 Graph_ALL_RUN Error" << std::endl; - return; - } -#endif - } - - for (int i = 1; i < node_num; i++) { - edge[i] = root_graph->AddEdge(node[i], 1, node[i - 1], 0); -#if GRAPH_CHECK_RET - if (edge[i] == nullptr) { - std::cout << "1 Graph_ALL_RUN Add Edge Error " << i << std::endl; - return; - } -#endif - } - - for (int i = 0; i < subgraph_num; i++) { - sub_graph[i] = std::make_shared("subgraph_" + std::to_string(i)); - FastNode *sub_graph_node[subgraph_node_num] = {}; - for (int j = 0; j < subgraph_node_num; j++) { - sub_graph_node[j] = sub_graph[i]->AddNode(sub_op_desc[i][j]); -#if GRAPH_CHECK_RET - if (sub_graph_node[j] == nullptr) { - std::cout << "1111 Graph_ALL_RUN Error" << std::endl; - return; - } -#endif - } - for (int j = 1; j < subgraph_node_num; j++) { - auto ret = sub_graph[i]->AddEdge(sub_graph_node[j], 1, sub_graph_node[j - 1], 0); -#if GRAPH_CHECK_RET - if (ret == nullptr) { - std::cout << "1 Graph_ALL_RUN sub graph Error" << j << std::endl; - return; - } -#endif - } - } - - for (int i = 0; i < subgraph_num; ++i) { - quick_graph[i] = root_graph->AddSubGraph(sub_graph[i]); -#if GRAPH_CHECK_RET - if (quick_graph[i] == nullptr) { - std::cout << "2 Graph_ALL_RUN Error" << std::endl; - return; - } -#endif - } - - auto test1_graph = std::make_shared("root_graph"); - for (auto _ : state) { - test1_graph->CompleteCopy(*(root_graph.get())); - benchmark::ClobberMemory(); - } -} -BENCHMARK(NEW_GRAPH_DEEPCOPY)->Args({20, 10000, 1, 10})->Iterations(1); -BENCHMARK(NEW_GRAPH_DEEPCOPY)->Args({20, 50000, 10, 10})->Iterations(1); -BENCHMARK(NEW_GRAPH_DEEPCOPY)->Args({20, 50000, 100, 10})->Iterations(1); -BENCHMARK(NEW_GRAPH_DEEPCOPY)->Args({20, 50000, 1000, 10})->Iterations(1); - -static void TEST_REMOVE_NODE(benchmark::State &state) { - int node_num = state.range(1); - int edge_num = state.range(0); - std::shared_ptr op_desc[node_num] = {nullptr}; - OpDescCreate(node_num, op_desc, edge_num); - - NodePtr node[node_num] = {}; - auto old_root_graph = std::make_shared("root_graph"); - - for (int i = 0; i < node_num; i++) { - node[i] = old_root_graph->AddNode(op_desc[i]); -#if GRAPH_CHECK_RET - if (node[i] == nullptr) { - std::cout << "OLD_Graph_ALL_RUN add node error." << std::endl; - return; - } -#endif - } - - for (int i = 1; i < node_num; i++) { - auto ret = GraphUtils::AddEdge(node[i]->GetOutDataAnchor(1), node[i - 1]->GetInDataAnchor(0)); -#if GRAPH_CHECK_RET - if (ret != GRAPH_SUCCESS) { - std::cout << "OLD_Graph_ALL_RUN add edge error" << std::endl; - return; - } -#endif - } - - for (auto _ : state) { - for (int i = 0; i < node_num; i++) { - auto ret = old_root_graph->RemoveNode(node[i]); -#if GRAPH_CHECK_RET - if (ret != GRAPH_SUCCESS) { - std::cout << "OLD_Graph_ALL_RUN remove node error." << std::endl; - return; - } -#endif - } - - benchmark::ClobberMemory(); - } -} -BENCHMARK(TEST_REMOVE_NODE)->Args({20, 10000, 2000, 10})->Iterations(1); -BENCHMARK(TEST_REMOVE_NODE)->Args({20, 10000, 4000, 10})->Iterations(1); -BENCHMARK(TEST_REMOVE_NODE)->Args({20, 10000, 6000, 10})->Iterations(1); -BENCHMARK(TEST_REMOVE_NODE)->Args({20, 10000, 10000, 10})->Iterations(1); - -static void TEST_GetSubGraph(benchmark::State &state) { - auto root_graph = std::make_shared("root_graph"); - int node_num = state.range(0); - size_t subgraph_num = state.range(1); - int edge_num = 5; - NodePtr node[node_num] = {}; - for (int i = 0; i < node_num; ++i) { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - node[i] = root_graph->AddNode(op_desc); - } - - std::shared_ptr sub_graph[subgraph_num] = {nullptr}; - for (size_t i = 0; i < subgraph_num; i++) { - sub_graph[i] = std::make_shared("subgraph_" + std::to_string(i)); - sub_graph[i]->SetParentGraph(root_graph); - sub_graph[i]->SetParentNode(node[i]); - } - - for (size_t i = 0; i < subgraph_num; i++) { - std::string name = "subgraph_" + std::to_string(i); - root_graph->AddSubgraph(name, sub_graph[i]); - } - - for (auto _ : state) { - auto subgraphs = root_graph->GetAllSubgraphs(); -#if GRAPH_CHECK_RET - if (subgraphs.size() != subgraph_num) { - std::cout << "0 TEST_GetSubGraph Error" << std::endl; - exit(1); - } -#endif - benchmark::DoNotOptimize(subgraphs); - benchmark::ClobberMemory(); - } -} -BENCHMARK(TEST_GetSubGraph)->Args({10000, 2000}); -BENCHMARK(TEST_GetSubGraph)->Args({10000, 4000}); -BENCHMARK(TEST_GetSubGraph)->Args({10000, 6000}); -BENCHMARK(TEST_GetSubGraph)->Args({10000, 8000}); - -static void New_Graph_AddNodeAndUpdateIo(benchmark::State &state) { - auto compute_graph = std::make_shared("graph0"); - int node_num = 1; - int io_num = 10; - int new_num = state.range(0); - std::shared_ptr op_desc[node_num] = {nullptr}; - OpDescCreate(node_num, op_desc, io_num); - FastNode *node = compute_graph->AddNode(op_desc[0]); - - for (auto _ : state) { - auto ret = FastNodeUtils::AppendInputEdgeInfo(node, new_num); - benchmark::DoNotOptimize(ret); - benchmark::ClobberMemory(); - } -} -BENCHMARK(New_Graph_AddNodeAndUpdateIo)->Arg(11)->Arg(21)->Arg(110)->Arg(1010)->Iterations(1); - -static void New_Graph_AddNodeAndUpdateOutput_Step(benchmark::State &state) { - auto compute_graph = std::make_shared("graph0"); - int node_num = 1; - int io_num = 10; - int new_num = state.range(0); - std::shared_ptr op_desc[node_num] = {nullptr}; - OpDescCreate(node_num, op_desc, io_num); - FastNode *node = compute_graph->AddNode(op_desc[0]); - - const GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT); - for (int i = op_desc[0]->GetOutputsSize(); i < new_num; ++i) { - op_desc[0]->AddOutputDesc(data_desc); - } - - for (auto _ : state) { - for (int i = io_num; i < new_num; ++i) { - node->UpdateDataOutNum(i); - } - benchmark::ClobberMemory(); - } -} -BENCHMARK(New_Graph_AddNodeAndUpdateOutput_Step)->Arg(20)->Arg(110)->Arg(1010)->Iterations(1); - -static void New_Graph_AddNodeAndUpdateInput_Step(benchmark::State &state) { - auto compute_graph = std::make_shared("graph0"); - int node_num = 1; - int io_num = 10; - int new_num = state.range(0); - std::shared_ptr op_desc[node_num] = {nullptr}; - OpDescCreate(node_num, op_desc, io_num); - FastNode *node = compute_graph->AddNode(op_desc[0]); - - const GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT); - for (int i = op_desc[0]->GetOutputsSize(); i < new_num; ++i) { - op_desc[0]->AddOutputDesc(data_desc); - } - - for (auto _ : state) { - for (int i = io_num; i < new_num; ++i) { - node->UpdateDataInNum(i); - } - benchmark::ClobberMemory(); - } -} -BENCHMARK(New_Graph_AddNodeAndUpdateInput_Step)->Arg(20)->Arg(110)->Arg(1010)->Iterations(1); - -static void OLD_Graph_AddNodeAndUpdateIo(benchmark::State &state) { - auto compute_graph = std::make_shared("graph0"); - int node_num = 1; - int io_num = 10; - std::shared_ptr op_desc[node_num] = {nullptr}; - OpDescCreate(node_num, op_desc, io_num); - NodePtr node = compute_graph->AddNode(op_desc[0]); - int new_num = state.range(0); - - for (auto _ : state) { - auto ret = NodeUtils::AppendInputAnchor(node, new_num); - benchmark::DoNotOptimize(ret); - benchmark::ClobberMemory(); - } -} -BENCHMARK(OLD_Graph_AddNodeAndUpdateIo)->Arg(11)->Arg(21)->Arg(110)->Arg(1010)->Iterations(1); - -static void OLD_ChangeEdgeAndNodeOwner(benchmark::State &state) { - auto compute_graph = std::make_shared("graph"); - int node_num = state.range(0); - int edge_num = 5; - NodePtr node[node_num] = {}; - std::shared_ptr op_desc[node_num] = {nullptr}; - OpDescCreate(node_num, op_desc, edge_num); - - for (int i = 0; i < node_num; ++i) { - node[i] = compute_graph->AddNode(op_desc[i]); - if (node[i] == nullptr) { - return; - } - } - - for (int i = 1; i < node_num; ++i) { - GraphUtils::AddEdge(node[i - 1]->GetOutDataAnchor(0), node[i]->GetInDataAnchor(0)); - GraphUtils::AddEdge(node[i - 1]->GetOutControlAnchor(), node[i]->GetInControlAnchor()); - } - - auto graph1 = std::make_shared("graph1"); - for (auto _ : state) { - for (int i = 0; i < node_num; ++i) { - graph1->AddNode(node[i]); - GraphUtils::RemoveJustNode(compute_graph, node[i]); - } - - for (int i = 0; i < node_num; ++i) { - compute_graph->AddNode(node[i]); - GraphUtils::RemoveJustNode(graph1, node[i]); - } - - benchmark::ClobberMemory(); - } -} -BENCHMARK(OLD_ChangeEdgeAndNodeOwner)->Arg(10)->Arg(100)->Arg(1000); - -static void NEW_ChangeEdgeAndNodeOwner(benchmark::State &state) { - auto compute_graph = std::make_shared("graph"); - int node_num = state.range(0); - int edge_num = 5; - FastNode *node[node_num] = {}; - std::shared_ptr op_desc[node_num] = {nullptr}; - OpDescCreate(node_num, op_desc, edge_num); - - for (int i = 0; i < node_num; ++i) { - node[i] = compute_graph->AddNode(op_desc[i]); - if (node[i] == nullptr) { - return; - } - } - - FastEdge *edge[node_num] = {}; - FastEdge *ctrl_edge[node_num] = {}; - for (int i = 1; i < node_num; ++i) { - edge[i] = compute_graph->AddEdge(node[i - 1], 0, node[i], 0); - ctrl_edge[i] = compute_graph->AddEdge(node[i - 1], kControlEdgeIndex, node[i], kControlEdgeIndex); - } - - auto graph1 = std::make_shared("graph1"); - for (auto _ : state) { - for (int i = 0; i < node_num; ++i) { - graph1->AddNode(node[i]); - } - - for (int i = 0; i < node_num; ++i) { - auto &edges = node[i]->GetAllInDataEdgesRef(); - for (auto edge : edges) { - if (edge != nullptr) { - graph1->MoveEdgeToGraph(edge); - } - } - } - - for (int i = 0; i < node_num; ++i) { - compute_graph->AddNode(node[i]); - } - - for (int i = 0; i < node_num; ++i) { - auto &edges = node[i]->GetAllInDataEdgesRef(); - for (auto edge : edges) { - if (edge != nullptr) { - compute_graph->MoveEdgeToGraph(edge); - } - } - } - benchmark::ClobberMemory(); - } -} -BENCHMARK(NEW_ChangeEdgeAndNodeOwner)->Arg(10)->Arg(100)->Arg(1000); - -BENCHMARK_MAIN(); diff --git a/tests/depends/faker/kernel_run_context_faker.cc b/tests/depends/faker/kernel_run_context_faker.cc index 96b90b975122253a32541e0287dfd61df5857623..6bd14daf6d81f9412a5b943d3689461469712d81 100644 --- a/tests/depends/faker/kernel_run_context_faker.cc +++ b/tests/depends/faker/kernel_run_context_faker.cc @@ -9,8 +9,13 @@ #include "kernel_run_context_faker.h" #include "graph/compute_graph.h" -#include "exe_graph/lowering/bg_kernel_context_extend.h" #include "exe_graph/runtime/tiling_context.h" +#include "graph/operator_factory_impl.h" +#include "base/context_builder/op_kernel_run_context_builder.h" + +namespace ge { +void ge::OperatorFactoryImpl::ReleaseRegInfo() {} +} namespace gert { FakeKernelContextHolder BuildKernelRunContext(size_t input_num, size_t output_num) { @@ -44,70 +49,17 @@ KernelRunContextFaker &KernelRunContextFaker::IrOutputInstanceNum(std::vector("node", "node"); - size_t input_index = 0; - for (size_t ir_index = 0; ir_index < ir_instance_num_.size(); ++ir_index) { - auto ir_ins_num = ir_instance_num_[ir_index]; - auto prefix = "x_" + std::to_string(ir_index) + "_"; - op_desc->AppendIrInput(prefix, ge::kIrInputDynamic); - for (size_t i = 0; i < ir_ins_num; ++i, ++input_index) { - auto td = ge::GeTensorDesc(); - if (node_input_tds_.size() > input_index) { - td.SetOriginFormat(node_input_tds_[input_index].GetOriginFormat()); - td.SetFormat(node_input_tds_[input_index].GetStorageFormat()); - td.SetDataType(node_input_tds_[input_index].GetDataType()); - td.SetOriginDataType(node_input_tds_[input_index].GetDataType()); - } - op_desc->AddInputDesc(prefix + std::to_string(i), td); - } - } - // fill it when not set - std::vector ir_output_instance_num; - if (ir_output_instance_num_.empty()) { - for (size_t i = 0; i < node_output_num_; ++i) { - ir_output_instance_num.emplace_back(1U); - } - } else { - ir_output_instance_num = ir_output_instance_num_; - } - size_t output_index = 0; - for (size_t ir_index = 0; ir_index < ir_output_instance_num.size(); ++ir_index) { - auto ir_ins_num = ir_output_instance_num[ir_index]; - auto prefix = "y_" + std::to_string(ir_index) + "_"; - op_desc->AppendIrOutput(prefix, ge::kIrOutputDynamic); - for (size_t i = 0; i < ir_ins_num; ++i, ++output_index) { - auto td = ge::GeTensorDesc(); - if (node_output_tds_.size() > output_index) { - td.SetOriginFormat(node_output_tds_[output_index].GetOriginFormat()); - td.SetFormat(node_output_tds_[output_index].GetStorageFormat()); - td.SetDataType(node_output_tds_[output_index].GetDataType()); - td.SetOriginDataType(node_output_tds_[output_index].GetDataType()); - } - op_desc->AddOutputDesc(prefix + std::to_string(i), td); - } - } - - for (const auto &attr : attrs_) { - op_desc->AppendIrAttrName(attr.first); - op_desc->SetAttr(attr.first, attr.second); - } - return op_desc; -} FakeKernelContextHolder KernelRunContextFaker::Build() const { FakeKernelContextHolder fake_holder; fake_holder.kernel_input_num = kernel_input_num_; fake_holder.kernel_output_num = kernel_output_num_; - KernelRunContextBuilder kernel_context_builder; - auto op_desc = FakeOp(); + OpKernelContextBuilder kernel_context_builder; if (inputs_.size() != kernel_input_num_ || outputs_.size() != kernel_output_num_) { std::vector inputs(kernel_input_num_, nullptr); std::vector outputs(kernel_output_num_, nullptr); - fake_holder.holder = kernel_context_builder.Inputs(inputs).Outputs(outputs).Build(op_desc); return fake_holder; } - fake_holder.holder = kernel_context_builder.Inputs(inputs_).Outputs(outputs_).Build(op_desc); return fake_holder; } KernelRunContextFaker &KernelRunContextFaker::NodeInputTd(int32_t index, ge::DataType dt, ge::Format origin_format, diff --git a/tests/depends/faker/kernel_run_context_faker.h b/tests/depends/faker/kernel_run_context_faker.h index e4f7787408d83a1cc5150317436768c7a9587942..c2b771c6508f780da272c9497e9ea36d05921d98 100644 --- a/tests/depends/faker/kernel_run_context_faker.h +++ b/tests/depends/faker/kernel_run_context_faker.h @@ -16,24 +16,21 @@ #include "exe_graph/runtime/context_extend.h" #include "exe_graph/runtime/storage_shape.h" #include "exe_graph/runtime/tiling_context.h" -#include "exe_graph/lowering/buffer_pool.h" #include "graph/any_value.h" #include "graph/node.h" -#include "lowering/kernel_run_context_builder.h" #include "exe_graph/runtime/gert_mem_allocator.h" +#include "base/context_builder/context_holder.h" namespace gert { struct FakeKernelContextHolder { template T *GetContext() { - return reinterpret_cast(holder.context_); - } - ComputeNodeInfo *MutableComputeNodeInfo() { - return reinterpret_cast(holder.compute_node_extend_holder_.get()); + return reinterpret_cast(holder.GetContext()); } + size_t kernel_input_num; size_t kernel_output_num; - KernelContextHolder holder; + ContextHolderVoid holder; }; FakeKernelContextHolder BuildKernelRunContext(size_t input_num, size_t output_num); diff --git a/tests/run_test.sh b/tests/run_test.sh index 5a6c0ac2e6aefbcd87b84eb5f1d274510197ac4b..c4bd2199bcfcf1423e4a158ec18218e9e35d4ba8 100755 --- a/tests/run_test.sh +++ b/tests/run_test.sh @@ -154,9 +154,7 @@ build_metadef() { cmake_generate_make "${BUILD_PATH}" "${CMAKE_ARGS}" if [[ "X$ENABLE_METADEF_UT" = "Xon" || "X$ENABLE_METADEF_COV" = "Xon" ]]; then - make ut_metadef ut_graph ut_register ut_error_manager ut_exe_graph ut_exe_meta_device ut_ascendc_ir ut_expression ut_sc_check ${VERBOSE} -j${THREAD_NUM} - elif [ "X$ENABLE_BENCHMARK" = "Xon" ]; then - make exec_graph_benchmark fast_graph_benchmark ut_ascendc_ir ut_expression ${VERBOSE} -j${THREAD_NUM} + make ut_metadef ut_register ut_error_manager ut_exe_meta_device ut_sc_check ${VERBOSE} -j${THREAD_NUM} fi if [ 0 -ne $? ]; then @@ -180,28 +178,15 @@ main() { build_metadef || { echo "Metadef llt build failed."; exit 1; } echo "---------------- Metadef llt build finished ----------------" - if [ "X$ENABLE_BENCHMARK" = "Xon" ]; then - RUN_TEST_CASE=${BUILD_PATH}/tests/benchmark/exe_graph/exec_graph_benchmark && ${RUN_TEST_CASE} - RUN_TEST_CASE=${BUILD_PATH}/tests/benchmark/fast_graph/fast_graph_benchmark && ${RUN_TEST_CASE} - fi - if [[ "X$ENABLE_METADEF_UT" = "Xon" || "X$ENABLE_METADEF_COV" = "Xon" ]]; then cp ${BUILD_PATH}/tests/ut/base/ut_metadef ${OUTPUT_PATH} - cp ${BUILD_PATH}/tests/ut/graph/ut_graph ${OUTPUT_PATH} cp ${BUILD_PATH}/tests/ut/register/ut_register ${OUTPUT_PATH} cp ${BUILD_PATH}/tests/ut/error_manager/ut_error_manager ${OUTPUT_PATH} - cp ${BUILD_PATH}/tests/ut/exe_graph/ut_exe_graph ${OUTPUT_PATH} - cp ${BUILD_PATH}/tests/ut/ascendc_ir/ut_ascendc_ir ${OUTPUT_PATH} - cp ${BUILD_PATH}/tests/ut/expression/ut_expression ${OUTPUT_PATH} cp ${BUILD_PATH}/tests/ut/sc_check/ut_sc_check ${OUTPUT_PATH} export ASAN_OPTIONS=detect_container_overflow=0 RUN_TEST_CASE=${OUTPUT_PATH}/ut_metadef && ${RUN_TEST_CASE} && - RUN_TEST_CASE=${OUTPUT_PATH}/ut_graph && ${RUN_TEST_CASE} && RUN_TEST_CASE=${OUTPUT_PATH}/ut_register && ${RUN_TEST_CASE} && RUN_TEST_CASE=${OUTPUT_PATH}/ut_error_manager && ${RUN_TEST_CASE} && - RUN_TEST_CASE=${OUTPUT_PATH}/ut_exe_graph && ${RUN_TEST_CASE} && - RUN_TEST_CASE=${OUTPUT_PATH}/ut_ascendc_ir && ${RUN_TEST_CASE} && - RUN_TEST_CASE=${OUTPUT_PATH}/ut_expression && ${RUN_TEST_CASE} && RUN_TEST_CASE=${OUTPUT_PATH}/ut_sc_check && ${RUN_TEST_CASE} if [[ "$?" -ne 0 ]]; then echo "!!! UT FAILED, PLEASE CHECK YOUR CHANGES !!!" @@ -218,19 +203,9 @@ main() { lcov -c \ -d ${BUILD_RELATIVE_PATH}/base/CMakeFiles/metadef.dir \ -d ${BUILD_RELATIVE_PATH}/base/CMakeFiles/opp_registry.dir \ - -d ${BUILD_RELATIVE_PATH}/graph/CMakeFiles/graph.dir \ - -d ${BUILD_RELATIVE_PATH}/graph/CMakeFiles/graph_base.dir \ - -d ${BUILD_RELATIVE_PATH}/graph/ascendc_ir/CMakeFiles/aihac_ir.dir \ - -d ${BUILD_RELATIVE_PATH}/graph/expression/CMakeFiles/aihac_symbolizer.dir \ - -d ${BUILD_RELATIVE_PATH}/graph/ascendc_ir/generator/CMakeFiles/ascir_generate.dir \ - -d ${BUILD_RELATIVE_PATH}/graph/ascendc_ir/generator/CMakeFiles/aihac_ir_register.dir \ - -d ${BUILD_RELATIVE_PATH}/graph/ascendc_ir/generator/CMakeFiles/ascir_ops_header_generator.dir \ - -d ${BUILD_RELATIVE_PATH}/register/CMakeFiles/register.dir \ - -d ${BUILD_RELATIVE_PATH}/register/CMakeFiles/rt2_registry_objects.dir \ -d ${BUILD_RELATIVE_PATH}/error_manager/CMakeFiles/error_manager.dir \ -d ${BUILD_RELATIVE_PATH}/base/CMakeFiles/exe_graph.dir \ - -d ${BUILD_RELATIVE_PATH}/exe_graph/CMakeFiles/lowering.dir \ - -d ${BUILD_RELATIVE_PATH}/tests/ut/exe_graph/CMakeFiles/ut_exe_graph.dir \ + -d ${BUILD_RELATIVE_PATH}/base/CMakeFiles/rt2_registry_objects.dir \ -o cov/tmp.info lcov -r cov/tmp.info '*/output/*' "*/${BUILD_RELATIVE_PATH}/opensrc/*" "*/${BUILD_RELATIVE_PATH}/proto/*" \ '*/third_party/*' '*/tests/*' '/usr/*' \ diff --git a/tests/ut/CMakeLists.txt b/tests/ut/CMakeLists.txt index e57132ca589018e0196efaf18f734322a688f139..8c9f340caac59616cc32de6a6efc9ae64bcdef46 100644 --- a/tests/ut/CMakeLists.txt +++ b/tests/ut/CMakeLists.txt @@ -8,13 +8,7 @@ # ====================================================================================================================== add_subdirectory(base) -add_subdirectory(graph) -add_subdirectory(register) add_subdirectory(error_manager) -add_subdirectory(exe_graph) add_subdirectory(exe_meta_device) add_subdirectory(sc_check) -add_subdirectory(expression) -if (ENABLE_OPEN_SRC) -add_subdirectory(ascendc_ir) -endif () +add_subdirectory(register) diff --git a/tests/ut/ascendc_ir/CMakeLists.txt b/tests/ut/ascendc_ir/CMakeLists.txt deleted file mode 100644 index 0cb2291b989a020e033831c1c18de9582a047775..0000000000000000000000000000000000000000 --- a/tests/ut/ascendc_ir/CMakeLists.txt +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright (c) 2024 Huawei Technologies Co., Ltd. -# This file is a part of the CANN Open Software. -# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. -# ====================================================================================================================== - -set(CMAKE_CXX_STANDARD 17) -# include directories -include_directories(${CMAKE_CURRENT_LIST_DIR}) -include_directories(${METADEF_DIR}/inc/common/util/trace_manager) -include_directories(${CMAKE_BINARY_DIR}/proto/metadef_protos) -include_directories(${CMAKE_BINARY_DIR}/proto/metadef_protos/proto) -include_directories(${METADEF_DIR}) -include_directories(${METADEF_DIR}/graph) -include_directories(${CMAKE_BINARY_DIR}) -include_directories(${CMAKE_BINARY_DIR}/proto/ge) -include_directories(${CMAKE_BINARY_DIR}/proto/ge/proto) - -include_directories(${ASCEND_INSTALL_PATH}/include/experiment) -include_directories(${ASCEND_INSTALL_PATH}/include/experiment/runtime) -include_directories(${ASCEND_INSTALL_PATH}/include/experiment/msprof) - -add_subdirectory(stub) -add_compile_definitions(CMAKE_BINARY_DIR=\"${CMAKE_BINARY_DIR}\") -file(GLOB_RECURSE UT_FILES CONFIGURE_DEPENDS "${METADEF_DIR}/tests/ut/ascendc_ir/testcase/*.cc") -file(GLOB_RECURSE FAKE_FILES CONFIGURE_DEPENDS "${METADEF_DIR}/tests/depends/cache_desc_stub/runtime_cache_desc.cc") -add_executable(ut_ascendc_ir ${UT_FILES} ${FAKE_FILES} ${UTILS_FILES}) - -target_compile_options(ut_ascendc_ir PRIVATE - -g --coverage -fprofile-arcs -ftest-coverage - -Wno-deprecated-declarations - -Wno-error=unused-variable - -Wall -Wfloat-equal -Werror - -D_GLIBCXX_USE_CXX11_ABI=0 - -fno-access-control -) - -target_compile_definitions(ut_ascendc_ir PRIVATE - $<$:ONLY_COMPILE_OPEN_SRC> - google=ascend_private - FUNC_VISIBILITY -) - -# intf_pub包含stdc++11 当前需要按照更高版本来编译ut -target_link_libraries(ut_ascendc_ir PRIVATE - -lgcov - -Wl,--no-as-needed - platform_stub - runtime_stub - slog - slog_headers - metadef_headers - ascir_generate #正常不应该依赖这个,当前因为ut要校验generate能力 - aihac_ir aihac_symbolizer aihac_ir_register graph graph_base error_manager mmpa ascir_stub_ops_headers - GTest::gtest GTest::gtest_main ascend_protobuf slog_stub slog c_sec json mmpa_stub -lrt -ldl -) - -target_include_directories(ut_ascendc_ir PRIVATE - ${METADEF_DIR}/tests/depends - ${METADEF_DIR}/tests/ut/ascendc_ir -) diff --git a/tests/ut/ascendc_ir/code_extractor.h b/tests/ut/ascendc_ir/code_extractor.h deleted file mode 100644 index 96104c1dec7cdc155e935e3ae08455ac3ca3cd35..0000000000000000000000000000000000000000 --- a/tests/ut/ascendc_ir/code_extractor.h +++ /dev/null @@ -1,101 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ -#ifndef METADEF_CXX_TESTS_UT_ASCENDC_IR_CODE_EXTRACTOR_H_ -#define METADEF_CXX_TESTS_UT_ASCENDC_IR_CODE_EXTRACTOR_H_ -#include -#include -#include -#include - -class CodeExtractor { - public: - struct FunctionCode { - std::string signature; - std::string body; - }; - - static FunctionCode ExtractFunction(const std::string &header_path, - const std::string &class_name, - const std::string &target_func) { - std::ifstream file(header_path); - std::vector lines; - std::string line; - - while (std::getline(file, line)) { - lines.push_back(line); - } - - return AnalyzeLines(lines, class_name, target_func); - } - - private: - static FunctionCode AnalyzeLines(const std::vector &lines, - const std::string &class_name, - const std::string &target_func) { - enum State { SEARCHING, IN_CLASS, IN_FUNCTION }; - State state = SEARCHING; - int brace_level = 0; - FunctionCode result; - size_t start_line = 0; - - for (size_t i = 0; i < lines.size(); ++i) { - std::string trimmed = Trim(lines[i]); - - if (state == SEARCHING) { - if (class_name.empty() || IsClassStart(trimmed, class_name)) { - state = IN_CLASS; - continue; - } - } else if (state == IN_CLASS) { - if (IsFunctionStart(trimmed, target_func)) { - state = IN_FUNCTION; - start_line = i; - result.signature = trimmed; - brace_level += CountBraces(trimmed); - } - } else { - brace_level += CountBraces(lines[i]); - if (brace_level == 0) { - // 收集函数体代码 - for (size_t j = start_line; j <= i; ++j) { - result.body += (lines[j] + "\n"); - } - break; - } - } - } - return result; - } - - static std::string Trim(const std::string &s) { - size_t start = s.find_first_not_of(" \t"); - size_t end = s.find_last_not_of(" \t"); - return (start == std::string::npos) ? "" : s.substr(start, end - start + 1); - } - - static int CountBraces(const std::string &s) { - return std::count(s.begin(), s.end(), '{') - - std::count(s.begin(), s.end(), '}'); - } - - static bool IsFunctionStart(const std::string &line, - const std::string &func_name) { - return line.find(func_name + "(") != std::string::npos && - (line.find('{') != std::string::npos || - line.find(';') == std::string::npos); - } - - static bool IsClassStart(const std::string &line, - const std::string &class_name) { - return line.find(class_name + " :") != std::string::npos && - (line.find('{') != std::string::npos || - line.find(';') == std::string::npos); - } -}; -#endif //METADEF_CXX_TESTS_UT_ASCENDC_IR_CODE_EXTRACTOR_H_ diff --git a/tests/ut/ascendc_ir/stub/CMakeLists.txt b/tests/ut/ascendc_ir/stub/CMakeLists.txt deleted file mode 100644 index d3d0f83f740e87e14e79a81cb35da9270cc01af7..0000000000000000000000000000000000000000 --- a/tests/ut/ascendc_ir/stub/CMakeLists.txt +++ /dev/null @@ -1,30 +0,0 @@ -include(${METADEF_DIR}/graph/ascendc_ir/generator/generator.cmake) -add_library(ascir_stub_builtin_ops SHARED ascir_stub_builtin_ops.cc ascir_stub_builtin_ops_v2.cc) -target_compile_options(ascir_stub_builtin_ops PRIVATE - -D_GLIBCXX_USE_CXX11_ABI=0) -target_link_options(ascir_stub_builtin_ops PRIVATE - -rdynamic - -Wl,-Bsymbolic - -Wl,--exclude-libs,All - -L ${INSTALL_BASE_DIR}/lib - ) -target_link_libraries(ascir_stub_builtin_ops PRIVATE aihac_ir_register) -target_include_directories(ascir_stub_builtin_ops PRIVATE - ${METADEF_DIR} - ${METADEF_DIR}/inc - ${METADEF_DIR}/inc/external - ${METADEF_DIR}/graph/ascendc_ir/generator - ) -message(${CMAKE_BINARY_DIR}) -set(ops_header_dir ${CMAKE_BINARY_DIR}/ascir_stub_builtin_ops) -set_target_properties(ascir_ops_header_generator PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${ops_header_dir}) -message("ops_header_dir is " ${ops_header_dir}/ascir_ops.h) -ascir_generate(ascir_stub_builtin_ops - ${ops_header_dir} - ${CMAKE_BINARY_DIR}/tests/ut/ascendc_ir/stub/libascir_stub_builtin_ops.so - ${ops_header_dir}/ascir_ops.h) - -add_custom_target(ascir_stub_builtin_ops_header ALL DEPENDS ${ops_header_dir}/ascir_ops.h) -add_library(ascir_stub_ops_headers INTERFACE) -target_include_directories(ascir_stub_ops_headers INTERFACE ${ops_header_dir}) -add_dependencies(ascir_stub_ops_headers ascir_stub_builtin_ops_header) diff --git a/tests/ut/ascendc_ir/stub/ascir_stub_builtin_ops.cc b/tests/ut/ascendc_ir/stub/ascir_stub_builtin_ops.cc deleted file mode 100644 index 7e8f96764639b4a0af6bad526029a120d0046c5d..0000000000000000000000000000000000000000 --- a/tests/ut/ascendc_ir/stub/ascir_stub_builtin_ops.cc +++ /dev/null @@ -1,333 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ -#include "graph/ascendc_ir/ascir_register.h" -#include "graph/types.h" -namespace ge { -namespace ascir { -EXPORT_GENERATOR() -REG_ASC_IR_START_NODE_WITH_ATTR(Data); -REG_ASC_IR_START_NODE(Constant).Attr("value").Attr("dtype"); -REG_ASC_IR_START_NODE(IndexExpr).Attr("expr"); -REG_ASC_IR_START_NODE(Workspace); -REG_ASC_IR_START_NODE(TbufData); -REG_ASC_IR_1IO(Output); - -REG_ASC_IR_1IO(Load).UseFirstInputView(); -REG_ASC_IR_1IO(Broadcast); -REG_ASC_IR_1IO(Store).UseFirstInputView().Attr("offset"); -//这里先打桩用来测试 -REG_ASC_IR_1IO(WorkspaceWithInput).UseFirstInputView(); - -/* - * todo nop比较特别,不确定是不是缺陷,原定义中,GEIR与ASCIR是不同的,GEIR多了个必选属性 -namespace ge { -REG_OP(Nop) - .REQUIRED_ATTR(dst_type, Int) - .INPUT(x, TensorType::ALL()) - .OUTPUT(y, TensorType::ALL()) -.OP_END_FACTORY_REG(Nop) -} - -namespace ascir::ops { -REG_OPS(Nop) -OPS_INPUT(0, x) -OPS_OUTPUT(0, y) -END_OPS(Nop) -} - */ -REG_ASC_IR_1IO(Nop); -REG_ASC_IR_1IO(Cast).Attr("dst_type"); - -REG_ASC_IR_1IO(Abs).UseFirstInputView(); -REG_ASC_IR_1IO(Exp).UseFirstInputView(); - -REG_ASC_IR_1IO(Max); -REG_ASC_IR_1IO(Sum); - -REG_ASC_IR_2I1O(Add).UseFirstInputView(); -REG_ASC_IR_2I1O(Sub).UseFirstInputView(); -REG_ASC_IR_2I1O(Div).UseFirstInputView(); -REG_ASC_IR_2I1O(Mul).UseFirstInputView(); - -REG_ASC_IR_2I1O(GT).UseFirstInputView(); -REG_ASC_IR_2I1O(Muls).UseFirstInputView(); - -// REG_ASC_IR_2I1O(MatMul) -REG_ASC_IR(MatMul) - .Inputs({"x1", "x2"}) - .OptionalInput("x3") - .Outputs({"y"}); - -REG_ASC_IR(FlashSoftmax) - .Inputs({"x1", "x2", "x3"}) - .Outputs({"y1", "y2", "y3"}); -REG_ASC_IR_2I1O(Dropout); -REG_ASC_IR_2I1O(Select); -// 适配add_layer_norm新增的api -REG_ASC_IR(CalcMean).Inputs({"x1", "x2", "x3"}).Outputs({"y1", "y2", "y3"}); -REG_ASC_IR(CalcMeanSlice).Inputs({"x1", "x2", "x3"}).Outputs({"y1", "y2", "y3"}); -REG_ASC_IR(CalcRstd).Inputs({"x1", "x2", "x3"}).Outputs({"y1", "y2"}); -REG_ASC_IR(CalcRstdSlice).Inputs({"x1", "x2"}).Outputs({"y1", "y2"}); -REG_ASC_IR(VFWelfordPart1Update) - .Inputs({"x1", "x2", "x3"}) - .Outputs({"y1", "y2", "y3", "y4"}) - .UseFirstInputView(); -REG_ASC_IR(VFWelfordPart1Finalize).Inputs({"x1", "x2"}).Outputs({"y1", "y2"}); -REG_ASC_IR(VFCalcYWelford).Inputs({"x1", "x2", "x3"}).Outputs({"y1"}).UseSecondInputDataType().UseFirstInputView(); -REG_ASC_IR(Concat).DynamicInput("x").Outputs({"y"}); -REG_ASC_IR(VectorFunction).DynamicInput("x").DynamicOutput("y", "T").DataType("T", TensorType::ALL()); -REG_ASC_IR(FakeOpA).DynamicInput("dx").OptionalInput("x2").Inputs({"x3", "x4"}).Outputs({"y"}); -REG_ASC_IR(CalcY).Inputs({"x1", "x2", "x3", "x4"}).Outputs({"y1"}).UseSecondInputDataType().UseFirstInputView(); -REG_ASC_IR(CalcMeanStub) - .Inputs({"x1", "x2", "x3"}) - .Outputs({"y1", "y2", "y3", "y4"}).Attr("reduce_axis_dim") - .DataTypes({PromptDtype(0U), 0U, PromptDtype(0U), - ge::DT_DOUBLE}) - .Views({ReduceView(0U, "reduce_axis_dim"), 0U, 0U, 0U}); -// 打桩测试专用op -REG_ASC_IR_WITH_COMMENT(StubOp1, - .Input("x", "T") - .Output("y", "T") - .DataType("T", TensorType::ALL()) - .Attr("my_int") - .Attr("my_string") - .Attr("my_float") - .Attr("offset") -); -/* codgen生成的类如下 -namespace ge { -namespace ascir_op { -struct StubOp1 : public ge::op::StubOp1 { - static constexpr const char *Type = "StubOp1"; - AscNodeAttr &attr; - struct AscStubOp1IrAttrDef : public AscIrAttrDefBase { - ~AscStubOp1IrAttrDef() override = default; - graphStatus GetMy_int(int64_t &my_int) const { - auto attr_value = attr_store_.GetAnyValue("my_int"); - GE_WARN_ASSERT(attr_value != nullptr); - return attr_value->GetValue(my_int); - } - graphStatus SetMy_int(int64_t my_int) { - auto attr_value = attr_store_.GetOrCreateAnyValue("my_int"); - ASCIR_ASSERT_NOTNULL(attr_value); - return attr_value->SetValue(my_int); - } - graphStatus GetMy_string(std::string &my_string) const { - auto attr_value = attr_store_.GetAnyValue("my_string"); - GE_WARN_ASSERT(attr_value != nullptr); - return attr_value->GetValue(my_string); - } - graphStatus SetMy_string(std::string my_string) { - auto attr_value = attr_store_.GetOrCreateAnyValue("my_string"); - ASCIR_ASSERT_NOTNULL(attr_value); - return attr_value->SetValue(my_string); - } - graphStatus GetMy_float(float &my_float) const { - auto attr_value = attr_store_.GetAnyValue("my_float"); - GE_WARN_ASSERT(attr_value != nullptr); - return attr_value->GetValue(my_float); - } - graphStatus SetMy_float(float my_float) { - auto attr_value = attr_store_.GetOrCreateAnyValue("my_float"); - ASCIR_ASSERT_NOTNULL(attr_value); - return attr_value->SetValue(my_float); - } - }; - AscStubOp1IrAttrDef &ir_attr; - AscOpInput<0> x; - AscOpOutput y; - inline StubOp1(const char *name) - : ge::op::StubOp1(name), - attr(AscNodeAttr::Create(*this)), - ir_attr(dynamic_cast(*(attr.ir_attr))), - x(this), - y(this, 0) {} -}; -} -} - */ -REG_ASC_IR_WITH_COMMENT(StubOp2, - .Input("x1", "T") - .Input("x2", "T") - .Output("y", "T") - .DataType("T", TensorType{DT_INT32, DT_INT64}) -); - -REG_ASC_IR_WITH_COMMENT(StubOp2New, - .Input("x1", "T") - .Input("x2", "T") - .Output("y", "T") - .Impl({"socv1"}, - {nullptr, nullptr, - {{"T", TensorType{DT_INT32, DT_INT64}}}}) -); - -REG_ASC_IR(StubOp3) - .Input("x1", "T1") - .Input("x2", "T2") - .Input("x3", "T1") - .Output("y1", "T1") - .Output("y2", "T2") - .DataType("T1", TensorType{DT_INT32, DT_INT64}) - .DataType("T2", TensorType{DT_FLOAT16, DT_FLOAT}); - -REG_ASC_IR(StubOp3New) - .Input("x1", "T1") - .Input("x2", "T2") - .Input("x3", "T1") - .Output("y1", "T1") - .Output("y2", "T2") - .Impl({"socv1"}, - {nullptr, nullptr, {{"T1", TensorType{DT_INT32, DT_INT64}}, {"T2", TensorType{DT_FLOAT16, DT_FLOAT}}}}); - -REG_ASC_IR(StubOp4) - .Input("x1", "T1") - .Input("x2", "T2") - .Output("y1", "T3") - .Output("y2", "T3") - .Output("y3", "T2") - .DataType("T1", TensorType{DT_INT32, DT_INT64}) - .DataType("T2", TensorType{DT_FLOAT16, DT_FLOAT}) - .DataType("T3", TensorType{DT_DOUBLE, DT_BOOL}); - -REG_ASC_IR(StubOp4New) - .Input("x1", "T1") - .Input("x2", "T2") - .Output("y1", "T3") - .Output("y2", "T3") - .Output("y3", "T2") - .Impl({"socv1"}, - {nullptr, - nullptr, - {{"T1", TensorType{DT_INT32, DT_INT64}}, - {"T2", TensorType{DT_FLOAT16, DT_FLOAT}}, - {"T3", TensorType{DT_DOUBLE, DT_BOOL}}}}); - -REG_ASC_IR(StubOp5) - .Input("x1", "T1") - .DynamicInput("x2", "T2") - .Output("y1", "T1") - .Output("y2", "T2") - .DataType("T1", TensorType{DT_INT32, DT_INT64}) - .DataType("T2", TensorType{DT_FLOAT16, DT_FLOAT}); - -REG_ASC_IR(StubOp5New) - .Input("x1", "T1") - .DynamicInput("x2", "T2") - .Output("y1", "T1") - .Output("y2", "T2") - .Impl({"socv1"}, - {nullptr, nullptr, {{"T1", TensorType{DT_INT32, DT_INT64}}, {"T2", TensorType{DT_FLOAT16, DT_FLOAT}}}}); - -REG_ASC_IR(StubOp6) - .Input("x1", "T1") - .Input("x2", "T2") - .Input("x3", "T1") - .Output("y1", "T3") - .DataType("T1", OrderedTensorTypeList{DT_INT32, DT_INT64}) - .DataType("T2", OrderedTensorTypeList{DT_FLOAT16, DT_FLOAT}) - .DataType("T3", OrderedTensorTypeList{DT_BOOL, DT_INT8}); - -REG_ASC_IR(StubOp6New) - .Input("x1", "T1") - .Input("x2", "T2") - .Input("x3", "T1") - .Output("y1", "T3") - .Impl({"socv1"}, - {nullptr, - nullptr, - {{"T1", OrderedTensorTypeList{DT_INT32, DT_INT64}}, - {"T2", OrderedTensorTypeList{DT_FLOAT16, DT_FLOAT}}, - {"T3", OrderedTensorTypeList{DT_BOOL, DT_INT8}}}}); - -REG_ASC_IR(StubOp7) - .Input("x1", "T1") - .Input("x2", "T2") - .Input("x3", "T1") - .Output("y1", "T3") - .Output("y2", "T2") - .DataType("T1", OrderedTensorTypeList{DT_INT32, DT_INT32, DT_INT64}) - .DataType("T2", OrderedTensorTypeList{DT_FLOAT16, DT_FLOAT16, DT_FLOAT}) - .DataType("T3", OrderedTensorTypeList{DT_BOOL, DT_INT4, DT_INT8}); - -REG_ASC_IR(StubOp7New) - .Input("x1", "T1") - .Input("x2", "T2") - .Input("x3", "T1") - .Output("y1", "T3") - .Output("y2", "T2") - .Impl({"socv1"}, - {nullptr, - nullptr, - {{"T1", OrderedTensorTypeList{DT_INT32, DT_INT32, DT_INT64}}, - {"T2", OrderedTensorTypeList{DT_FLOAT16, DT_FLOAT16, DT_FLOAT}}, - {"T3", OrderedTensorTypeList{DT_BOOL, DT_INT4, DT_INT8}}}}); - -REG_ASC_IR(StubOp8) - .Input("x", "T1") - .Output("y", "T2") - .DataType("T1", OrderedTensorTypeList{DT_INT32, DT_INT32, DT_INT64}) - .DataType("T2", OrderedTensorTypeList{DT_BF16, DT_BF16, DT_FLOAT}); - -REG_ASC_IR(StubOp8New) - .Input("x", "T1") - .Output("y", "T2") - .Impl({"socv1"}, - {nullptr, - nullptr, - {{"T1", OrderedTensorTypeList{DT_INT32, DT_INT32, DT_INT64}}, {"T2", OrderedTensorTypeList{DT_BF16, DT_BF16, DT_FLOAT}}}}); - -REG_ASC_IR(StubOp9) - .Input("x1", "T1") - .Input("x2", "T2") - .Input("x3", "T3") - .Output("y1", "T2") - .Output("y2", "T1") - .Output("y3", "T4") - .Output("y4", "T5") - .DataType("T1", OrderedTensorTypeList{DT_INT32, DT_INT32, DT_INT64}) - .DataType("T2", OrderedTensorTypeList{DT_BF16, DT_BF16, DT_FLOAT}) - .DataType("T3", OrderedTensorTypeList{DT_INT8, DT_INT8, DT_FLOAT}) - .DataType("T4", OrderedTensorTypeList{DT_BOOL, DT_DOUBLE, DT_FLOAT}) - .DataType("T5", OrderedTensorTypeList{DT_BOOL, DT_COMPLEX128, DT_DUAL}); - -REG_ASC_IR(StubOp9New) - .Input("x1", "T1") - .Input("x2", "T2") - .Input("x3", "T3") - .Output("y1", "T2") - .Output("y2", "T1") - .Output("y3", "T4") - .Output("y4", "T5") - .Impl({"socv1"}, - {nullptr, - nullptr, - {{"T1", OrderedTensorTypeList{DT_INT32, DT_INT32, DT_INT64}}, - {"T2", OrderedTensorTypeList{DT_BF16, DT_BF16, DT_FLOAT}}, - {"T3", OrderedTensorTypeList{DT_INT8, DT_INT8, DT_FLOAT}}, - {"T4", OrderedTensorTypeList{DT_BOOL, DT_DOUBLE, DT_FLOAT}}, - {"T5", OrderedTensorTypeList{DT_BOOL, DT_COMPLEX128, DT_DUAL}}}}); - -REG_ASC_IR_1IO(StubOp10) - .SameTmpBufSizeFromFirstInput() - .CalcTmpBufSize("CalcTmpSizeForStubOp11"); - -REG_ASC_IR_1IO(StubOp11) - .CalcTmpBufSize("CalcTmpSizeForStubOp11") - .SameTmpBufSizeFromFirstInput(); - -REG_ASC_IR(StubRemovePad) - .Input("x", "T") - .Output("y", "T") - .Impl({"socv1"}, - {nullptr, - nullptr, - {{"T", OrderedTensorTypeList{DT_INT16, DT_UINT16, DT_INT32}}}}); -} // namespace ascir -} // namespace ge diff --git a/tests/ut/ascendc_ir/stub/ascir_stub_builtin_ops_v2.cc b/tests/ut/ascendc_ir/stub/ascir_stub_builtin_ops_v2.cc deleted file mode 100644 index 577b142ab6a41ac54959dede3f3e4df0bab8c4d9..0000000000000000000000000000000000000000 --- a/tests/ut/ascendc_ir/stub/ascir_stub_builtin_ops_v2.cc +++ /dev/null @@ -1,103 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ -#include "graph/ascendc_ir/ascir_register.h" -#include "graph/types.h" -namespace ge { -namespace ascir { -REG_ASC_IR_WITH_COMMENT(StubOp2New, - .Input("x1", "T") - .Input("x2", "T") - .Output("y", "T") - .Impl({"socv2", "socv3"}, - {nullptr, nullptr, - {{"T", TensorType{DT_INT32, DT_INT64}}}}) -); - -REG_ASC_IR(StubOp3New) - .Input("x1", "T1") - .Input("x2", "T2") - .Input("x3", "T1") - .Output("y1", "T1") - .Output("y2", "T2") - .Impl({"socv2", "socv3"}, - {nullptr, nullptr, {{"T1", TensorType{DT_FLOAT16, DT_INT32, DT_INT64}}, {"T2", TensorType{DT_FLOAT16, DT_FLOAT16, DT_FLOAT}}}}); - -REG_ASC_IR(StubOp4New) - .Input("x1", "T1") - .Input("x2", "T2") - .Output("y1", "T3") - .Output("y2", "T3") - .Output("y3", "T2") - .Impl({"socv2", "socv3"}, - {nullptr, - nullptr, - {{"T1", TensorType{DT_INT32, DT_INT64, DT_UINT16}}, - {"T2", TensorType{DT_FLOAT16, DT_FLOAT, DT_UINT16}}, - {"T3", TensorType{DT_DOUBLE, DT_BOOL, DT_UINT16}}}}); - -REG_ASC_IR(StubOp5New) - .Input("x1", "T1") - .DynamicInput("x2", "T2") - .Output("y1", "T1") - .Output("y2", "T2") - .Impl({"socv2", "socv3"}, - {nullptr, nullptr, {{"T1", TensorType{DT_INT32, DT_INT64,DT_UINT16}}, {"T2", TensorType{DT_FLOAT16, DT_FLOAT,DT_UINT16}}}}); - -REG_ASC_IR(StubOp6New) - .Input("x1", "T1") - .Input("x2", "T2") - .Input("x3", "T1") - .Output("y1", "T3") - .Impl({"socv2", "socv3"}, - {nullptr, - nullptr, - {{"T1", OrderedTensorTypeList{DT_INT32, DT_INT64,DT_UINT16,DT_UINT16}}, - {"T2", OrderedTensorTypeList{DT_FLOAT16, DT_FLOAT,DT_UINT16,DT_UINT16}}, - {"T3", OrderedTensorTypeList{DT_BOOL, DT_INT8,DT_UINT16,DT_UINT64}}}}); - -REG_ASC_IR(StubOp7New) - .Input("x1", "T1") - .Input("x2", "T2") - .Input("x3", "T1") - .Output("y1", "T3") - .Output("y2", "T2") - .Impl({"socv2", "socv3"}, - {nullptr, - nullptr, - {{"T1", OrderedTensorTypeList{DT_INT32, DT_INT32, DT_INT64}}, - {"T2", OrderedTensorTypeList{DT_FLOAT16, DT_FLOAT16, DT_FLOAT}}, - {"T3", OrderedTensorTypeList{DT_BOOL, DT_INT4, DT_INT8}}}}); - -REG_ASC_IR(StubOp8New) - .Input("x", "T1") - .Output("y", "T2") - .Impl({"socv2", "socv3"}, - {nullptr, - nullptr, - {{"T1", OrderedTensorTypeList{DT_INT32, DT_INT32, DT_INT64}}, {"T2", OrderedTensorTypeList{DT_BF16, DT_BF16, DT_FLOAT}}}}); - -REG_ASC_IR(StubOp9New) - .Input("x1", "T1") - .Input("x2", "T2") - .Input("x3", "T3") - .Output("y1", "T2") - .Output("y2", "T1") - .Output("y3", "T4") - .Output("y4", "T5") - .Impl({"socv2", "socv3"}, - {nullptr, - nullptr, - {{"T1", OrderedTensorTypeList{DT_INT32, DT_INT32, DT_INT64}}, - {"T2", OrderedTensorTypeList{DT_BF16, DT_BF16, DT_FLOAT}}, - {"T3", OrderedTensorTypeList{DT_INT8, DT_INT8, DT_FLOAT}}, - {"T4", OrderedTensorTypeList{DT_BOOL, DT_DOUBLE, DT_FLOAT}}, - {"T5", OrderedTensorTypeList{DT_BOOL, DT_COMPLEX128, DT_DUAL}}}}); - -} // namespace ascir -} // namespace ge diff --git a/tests/ut/ascendc_ir/testcase/asc_graph_utils_unittest.cc b/tests/ut/ascendc_ir/testcase/asc_graph_utils_unittest.cc deleted file mode 100644 index c387997c27d348fabf951980d381fb81410f825a..0000000000000000000000000000000000000000 --- a/tests/ut/ascendc_ir/testcase/asc_graph_utils_unittest.cc +++ /dev/null @@ -1,607 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ -#include -#include "graph/ascendc_ir/utils/asc_graph_utils.h" -#include "graph/ascendc_ir/utils/asc_tensor_utils.h" -#include "graph/ascendc_ir/core/ascendc_ir_impl.h" -#include "testcase/ascendc_ir_dump_test/stub_graph.h" -#include "utils/graph_utils.h" -#include "mmpa/mmpa_api.h" -#include "graph/ge_context.h" -#include -using namespace ge; -class UtestAscirGraphUtils : public testing::Test { - protected: - void SetUp() { - dlog_setlevel(0, 3, 0); - } - - void TearDown() {} -}; -REG_OP(Constant) - .INPUT(x, TensorType::ALL()) - .OUTPUT(y, TensorType::ALL()) - .OP_END_FACTORY_REG(Constant); - -REG_OP(Abs) - .INPUT(x, TensorType::ALL()) - .OUTPUT(y, TensorType::ALL()) - .OP_END_FACTORY_REG(Abs); -namespace { -std::stringstream GetFilePathWhenDumpPathSet(const string &ascend_work_path) { - std::stringstream dump_file_path; - dump_file_path << ascend_work_path << "/pid_" << mmGetPid() << "_deviceid_" << GetContext().DeviceId() << "/"; - return dump_file_path; -} -struct AscNodeInfo { - std::string name; - std::string type; - size_t input_num; - size_t output_num; - std::vector axis_ids; -}; -NodePtr BuildNode(const ComputeGraphPtr &asc_graph, - const AscNodeInfo &node, bool make_asc_node_directly = true) { - OpDescBuilder op_desc_builder(node.name, node.type); - for (size_t input_index = 0U; input_index < node.input_num; ++input_index) { - op_desc_builder.AddInput("x_" + std::to_string(input_index)); - } - for (size_t output_index = 0U; output_index < node.output_num; ++output_index) { - op_desc_builder.AddOutput("y_" + std::to_string(output_index)); - } - const auto &op_desc = op_desc_builder.Build(); - auto node_attr_group = op_desc->GetOrCreateAttrsGroup(); - - node_attr_group->sched.exec_condition = ExecuteCondition::kNoCache; - node_attr_group->sched.axis = node.axis_ids; - if (!node_attr_group->sched.axis.empty()) { - node_attr_group->sched.loop_axis = node_attr_group->sched.axis.back(); - } - if (make_asc_node_directly) { - const auto &asc_node = ComGraphMakeShared(op_desc, asc_graph); - asc_node->Init(); - return asc_graph->AddNode(asc_node); - } - return asc_graph->AddNode(op_desc); -} -std::string GetSpecificFilePath(const std::string &file_path, const string &suffix) { - DIR *dir; - struct dirent *ent; - dir = opendir(file_path.c_str()); - if (dir == nullptr) { - return ""; - } - while ((ent = readdir(dir)) != nullptr) { - if (strstr(ent->d_name, suffix.c_str()) != nullptr) { - std::string d_name(ent->d_name); - closedir(dir); - return file_path + "/" + d_name; - } - } - closedir(dir); - return ""; -} -} -TEST_F(UtestAscirGraphUtils, SampleGraphSerializeDeserializeReadableSuccess) { - AscGraph g("graph"); - auto op = ascir_op::Constant("abc"); - auto node = g.AddNode(op); - node->inputs(); - node->outputs(); - g.SetTilingKey(0x5a5a); - std::string output; - EXPECT_EQ(AscGraphUtils::SerializeToReadable(g, output), GRAPH_SUCCESS); - AscGraph out_asc_graph(""); - EXPECT_EQ(AscGraphUtils::DeserializeFromReadable(output, out_asc_graph), GRAPH_SUCCESS); - EXPECT_EQ(out_asc_graph.GetName(), "graph"); - EXPECT_EQ(out_asc_graph.GetTilingKey(), 0x5a5a); -} - -TEST_F(UtestAscirGraphUtils, SampleGraphSerializeDeserializeBinarySuccess) { - AscGraph g("graph"); - auto op = ascir_op::Constant("abc"); - auto node = g.AddNode(op); - node->inputs(); - node->outputs(); - g.SetTilingKey(0x5a5a); - std::string output; - EXPECT_EQ(AscGraphUtils::SerializeToBinary(g, output), GRAPH_SUCCESS); - AscGraph out_asc_graph(""); - EXPECT_EQ(AscGraphUtils::DeserializeFromBinary(output, out_asc_graph), GRAPH_SUCCESS); - EXPECT_EQ(out_asc_graph.GetName(), "graph"); - EXPECT_EQ(out_asc_graph.GetTilingKey(), 0x5a5a); -} - -// constant->cast1 -// expr->cast2 -TEST_F(UtestAscirGraphUtils, ConstantCastIndexExprGraphSerializeDeserializeTestSuccess) { - AscGraph g("graph"); - ascir_op::Constant constant("constant", g); - constant.ir_attr.SetValue(0); - ascir_op::IndexExpr expr("expr", g); - expr.ir_attr.SetExpr(0x5a); - ascir_op::Cast cast1("cast1"); - cast1.x = constant.y; - ascir_op::Cast cast2("cast2"); - cast2.x = expr.y; - cast2.ir_attr.SetDst_type(ge::DT_FLOAT16); - g.SetTilingKey(0x5a5a); - std::string output; - EXPECT_EQ(AscGraphUtils::SerializeToReadable(g, output), GRAPH_SUCCESS); - AscGraph out_asc_graph(""); - EXPECT_EQ(AscGraphUtils::DeserializeFromReadable(output, out_asc_graph), GRAPH_SUCCESS); - EXPECT_EQ(out_asc_graph.GetName(), "graph"); - EXPECT_EQ(out_asc_graph.GetTilingKey(), 0x5a5a); -} - -TEST_F(UtestAscirGraphUtils, FaGraphSerializeDeserializeBinarySuccess) { - std::string graph_name("test_graph"); - AscGraph graph(graph_name.c_str()); - FaBeforeAutoFuse(graph); - FaAfterScheduler(graph); - FaAfterQueBufAlloc(graph); - std::string output; - ASSERT_EQ(AscGraphUtils::SerializeToBinary(graph, output), GRAPH_SUCCESS); - AscGraph out_asc_graph(""); - ASSERT_EQ(AscGraphUtils::DeserializeFromBinary(output, out_asc_graph), GRAPH_SUCCESS); - AscGraphAttr *out_asc_graph_attr = out_asc_graph.impl_->GetOrCreateGraphAttrsGroup(); - ASSERT_NE(out_asc_graph_attr, nullptr); - AscGraphAttr *graph_attr = graph.impl_->GetOrCreateGraphAttrsGroup(); - ASSERT_NE(out_asc_graph_attr, nullptr); - ASSERT_NE(graph_attr, nullptr); - EXPECT_EQ(out_asc_graph_attr->tiling_key, graph_attr->tiling_key); - ASSERT_EQ(out_asc_graph_attr->axis.size(), graph_attr->axis.size()); - for (size_t id = 0UL; id < out_asc_graph_attr->axis.size(); id++) { - EXPECT_EQ(out_asc_graph_attr->axis[id]->name, graph_attr->axis[id]->name); - EXPECT_EQ(out_asc_graph_attr->axis[id]->type, graph_attr->axis[id]->type); - EXPECT_EQ(out_asc_graph_attr->axis[id]->allow_unaligned_tail, graph_attr->axis[id]->allow_unaligned_tail); - EXPECT_EQ(out_asc_graph_attr->axis[id]->allow_oversize_axis, graph_attr->axis[id]->allow_oversize_axis); - EXPECT_EQ(out_asc_graph_attr->axis[id]->bind_block, graph_attr->axis[id]->bind_block); - EXPECT_EQ(out_asc_graph_attr->axis[id]->align, graph_attr->axis[id]->align); - EXPECT_EQ(out_asc_graph_attr->axis[id]->split_pair_other_id, graph_attr->axis[id]->split_pair_other_id); - EXPECT_EQ(out_asc_graph_attr->axis[id]->from, graph_attr->axis[id]->from); - EXPECT_EQ(string(out_asc_graph_attr->axis[id]->size.Str().get()), string(graph_attr->axis[id]->size.Str().get())); - } - std::vector asc_nodes; - for (const auto &n : graph.GetAllNodes()) { - asc_nodes.emplace_back(n); - } - EXPECT_EQ(graph_name, out_asc_graph.GetName()); - std::vector out_asc_nodes; - const auto out_graph = out_asc_graph.impl_->compute_graph_; - for (const auto &n : out_asc_graph.GetAllNodes()) { - out_asc_nodes.emplace_back(n); - } - ASSERT_EQ(asc_nodes.size(), out_asc_nodes.size()); - for (size_t i = 0UL; i < asc_nodes.size(); i++) { - // check ascend node - auto &asc_node = asc_nodes[i]; - auto &out_asc_node = out_asc_nodes[i]; - ASSERT_EQ(asc_node->GetName(), out_asc_node->GetName()) << " node name = " << asc_node->GetName(); - if (asc_node->GetType() == "Data") { - int64_t index_src{0}; - int64_t index_dst{-1}; - auto src_ir_attr = asc_node->attr.ir_attr.get(); - auto dst_ir_attr = out_asc_node->attr.ir_attr.get(); - if (src_ir_attr->GetAttrValue("index", index_src) == GRAPH_SUCCESS) { - EXPECT_EQ(dst_ir_attr->GetAttrValue("index", index_dst), GRAPH_SUCCESS) - << " node name = " << asc_node->GetName(); - EXPECT_EQ(index_src, index_dst); - } - } - EXPECT_EQ(asc_node->attr.type, out_asc_node->attr.type) << " node name = " << asc_node->GetName(); - EXPECT_EQ(asc_node->attr.name, out_asc_node->attr.name) << " node name = " << asc_node->GetName(); - EXPECT_EQ(asc_node->attr.api.type, out_asc_node->attr.api.type) << " node name = " << asc_node->GetName(); - EXPECT_EQ(asc_node->attr.api.unit, out_asc_node->attr.api.unit) << " node name = " << asc_node->GetName(); - EXPECT_EQ(asc_node->attr.api.compute_type, out_asc_node->attr.api.compute_type) - << " node name = " << asc_node->GetName(); - EXPECT_EQ(asc_node->attr.sched.exec_order, out_asc_node->attr.sched.exec_order) - << " node name = " << asc_node->GetName(); - EXPECT_EQ(asc_node->attr.sched.loop_axis, out_asc_node->attr.sched.loop_axis) - << " node name = " << asc_node->GetName(); - EXPECT_EQ(asc_node->attr.sched.exec_condition, ExecuteCondition::kNoCache) - << " node name = " << asc_node->GetName(); - EXPECT_EQ(asc_node->attr.sched.axis, out_asc_node->attr.sched.axis) << " node name = " << asc_node->GetName(); - // check input tensor - auto inputs = asc_node->inputs.tensors_; - auto new_inputs = out_asc_node->inputs.tensors_; - for (size_t id = 0UL; id < inputs.size(); id++) { - // check anchor ref - ASSERT_NE(inputs[id].anchor.GetOwnerNode(), nullptr); - ASSERT_NE(new_inputs[id].anchor.GetOwnerNode(), nullptr); - EXPECT_EQ(new_inputs[id].anchor.GetOwnerNode()->GetName(), inputs[id].anchor.GetOwnerNode()->GetName()); - EXPECT_EQ(new_inputs[id].anchor.GetIdx(), inputs[id].anchor.GetIdx()); - // check attr - const auto owner_node = ge::ascir::AscTensorUtils::GetOwner(inputs[id]); - EXPECT_EQ(inputs[id].attr.axis, new_inputs[id].attr.axis); - EXPECT_EQ(inputs[id].attr.dtype, new_inputs[id].attr.dtype) - << "node name=" << asc_node->GetNamePtr() << ", in id=" << id - << ", from=" << ((owner_node != nullptr) ? "null" : owner_node->GetName()); - // mem attr - EXPECT_EQ(inputs[id].attr.mem.tensor_id, new_inputs[id].attr.mem.tensor_id); - EXPECT_EQ(inputs[id].attr.mem.name, new_inputs[id].attr.mem.name); - EXPECT_EQ(inputs[id].attr.mem.buf_ids, new_inputs[id].attr.mem.buf_ids); - EXPECT_EQ(inputs[id].attr.mem.position, new_inputs[id].attr.mem.position); - EXPECT_EQ(inputs[id].attr.mem.hardware, new_inputs[id].attr.mem.hardware); - EXPECT_EQ(inputs[id].attr.mem.alloc_type, new_inputs[id].attr.mem.alloc_type); - // que/buf attr - EXPECT_EQ(inputs[id].attr.que.id, new_inputs[id].attr.que.id); - EXPECT_EQ(inputs[id].attr.que.buf_num, new_inputs[id].attr.que.buf_num); - EXPECT_EQ(inputs[id].attr.que.depth, new_inputs[id].attr.que.depth); - EXPECT_EQ(inputs[id].attr.buf.id, new_inputs[id].attr.buf.id); - // opt attr - EXPECT_EQ(inputs[id].attr.opt.reuse_id, new_inputs[id].attr.opt.reuse_id); - EXPECT_EQ(inputs[id].attr.opt.ref_tensor, new_inputs[id].attr.opt.ref_tensor); - EXPECT_EQ(inputs[id].attr.opt.merge_scope, new_inputs[id].attr.opt.merge_scope); - ASSERT_EQ(inputs[id].attr.strides.size(), new_inputs[id].attr.strides.size()); - for (size_t stride_id = 0UL; stride_id < inputs[id].attr.strides.size(); stride_id++) { - EXPECT_EQ(std::string(inputs[id].attr.strides[stride_id].Str().get()), - std::string(new_inputs[id].attr.strides[stride_id].Str().get())); - } - ASSERT_EQ(inputs[id].attr.repeats.size(), new_inputs[id].attr.repeats.size()); - for (size_t repeat_id = 0UL; repeat_id < inputs[id].attr.repeats.size(); repeat_id++) { - EXPECT_EQ(std::string(inputs[id].attr.repeats[repeat_id].Str().get()), - std::string(new_inputs[id].attr.repeats[repeat_id].Str().get())); - } - EXPECT_EQ(inputs[id].attr.vectorized_axis, new_inputs[id].attr.vectorized_axis); - ASSERT_EQ(inputs[id].attr.vectorized_strides.size(), new_inputs[id].attr.vectorized_strides.size()); - for (size_t vectorized_stride_id = 0UL; vectorized_stride_id < inputs[id].attr.vectorized_strides.size(); - vectorized_stride_id++) { - EXPECT_EQ(std::string(inputs[id].attr.vectorized_strides[vectorized_stride_id].Str().get()), - std::string(new_inputs[id].attr.vectorized_strides[vectorized_stride_id].Str().get())) - << " node name=" << asc_node->GetName() << " input index=" << id; - } - } - - auto outputs = asc_node->outputs.tensors_; - auto new_outputs = out_asc_node->outputs.tensors_; - for (size_t id = 0UL; id < outputs.size(); id++) { - // check anchor ref - ASSERT_NE(outputs[id].anchor.GetOwnerNode(), nullptr); - ASSERT_NE(new_outputs[id].anchor.GetOwnerNode(), nullptr); - EXPECT_EQ(new_outputs[id].anchor.GetOwnerNode()->GetName(), outputs[id].anchor.GetOwnerNode()->GetName()); - EXPECT_EQ(new_outputs[id].anchor.GetIdx(), outputs[id].anchor.GetIdx()); - // check attr - EXPECT_EQ(outputs[id].attr.axis, new_outputs[id].attr.axis); - EXPECT_EQ(outputs[id].attr.dtype, new_outputs[id].attr.dtype) - << "node name=" << asc_node->GetNamePtr() << ",out id=" << id; - // mem attr - EXPECT_EQ(outputs[id].attr.mem.tensor_id, new_outputs[id].attr.mem.tensor_id); - EXPECT_EQ(outputs[id].attr.mem.name, new_outputs[id].attr.mem.name); - EXPECT_EQ(outputs[id].attr.mem.buf_ids, new_outputs[id].attr.mem.buf_ids); - EXPECT_EQ(outputs[id].attr.mem.position, new_outputs[id].attr.mem.position); - EXPECT_EQ(outputs[id].attr.mem.hardware, new_outputs[id].attr.mem.hardware); - EXPECT_EQ(outputs[id].attr.mem.alloc_type, new_outputs[id].attr.mem.alloc_type); - // que/buf attr - EXPECT_EQ(outputs[id].attr.que.id, new_outputs[id].attr.que.id); - EXPECT_EQ(outputs[id].attr.que.buf_num, new_outputs[id].attr.que.buf_num); - EXPECT_EQ(outputs[id].attr.que.depth, new_outputs[id].attr.que.depth); - EXPECT_EQ(outputs[id].attr.buf.id, new_outputs[id].attr.buf.id); - // opt attr - EXPECT_EQ(outputs[id].attr.opt.reuse_id, new_outputs[id].attr.opt.reuse_id); - EXPECT_EQ(outputs[id].attr.opt.ref_tensor, new_outputs[id].attr.opt.ref_tensor); - EXPECT_EQ(outputs[id].attr.opt.merge_scope, new_outputs[id].attr.opt.merge_scope); - ASSERT_EQ(outputs[id].attr.strides.size(), new_outputs[id].attr.strides.size()); - for (size_t stride_id = 0UL; stride_id < outputs[id].attr.strides.size(); stride_id++) { - EXPECT_EQ(std::string(outputs[id].attr.strides[stride_id].Str().get()), - std::string(new_outputs[id].attr.strides[stride_id].Str().get())); - } - ASSERT_EQ(outputs[id].attr.repeats.size(), new_outputs[id].attr.repeats.size()); - for (size_t repeat_id = 0UL; repeat_id < outputs[id].attr.repeats.size(); repeat_id++) { - EXPECT_EQ(std::string(outputs[id].attr.repeats[repeat_id].Str().get()), - std::string(new_outputs[id].attr.repeats[repeat_id].Str().get())); - } - EXPECT_EQ(outputs[id].attr.vectorized_axis, new_outputs[id].attr.vectorized_axis); - ASSERT_EQ(outputs[id].attr.vectorized_strides.size(), new_outputs[id].attr.vectorized_strides.size()); - for (size_t vectorized_stride_id = 0UL; vectorized_stride_id < outputs[id].attr.vectorized_strides.size(); - vectorized_stride_id++) { - EXPECT_EQ(std::string(outputs[id].attr.vectorized_strides[vectorized_stride_id].Str().get()), - std::string(new_outputs[id].attr.vectorized_strides[vectorized_stride_id].Str().get())) - << " node name=" << asc_node->GetName() << ", output index=" << id; - } - } - } -} - -TEST_F(UtestAscirGraphUtils, FaGraphSerializeDeserializeReadableSuccess) { - std::string graph_name("test_graph"); - AscGraph graph(graph_name.c_str()); - FaBeforeAutoFuse(graph); - FaAfterScheduler(graph); - FaAfterQueBufAlloc(graph); - std::string output; - ASSERT_EQ(AscGraphUtils::SerializeToReadable(graph, output), GRAPH_SUCCESS); - AscGraph out_asc_graph(""); - ASSERT_EQ(AscGraphUtils::DeserializeFromReadable(output, out_asc_graph), GRAPH_SUCCESS); - AscGraphAttr *out_asc_graph_attr = out_asc_graph.impl_->GetOrCreateGraphAttrsGroup(); - ASSERT_NE(out_asc_graph_attr, nullptr); - AscGraphAttr *graph_attr = graph.impl_->GetOrCreateGraphAttrsGroup(); - ASSERT_NE(out_asc_graph_attr, nullptr); - ASSERT_NE(graph_attr, nullptr); - EXPECT_EQ(out_asc_graph_attr->tiling_key, graph_attr->tiling_key); - ASSERT_EQ(out_asc_graph_attr->axis.size(), graph_attr->axis.size()); - for (size_t id = 0UL; id < out_asc_graph_attr->axis.size(); id++) { - EXPECT_EQ(out_asc_graph_attr->axis[id]->name, graph_attr->axis[id]->name); - EXPECT_EQ(out_asc_graph_attr->axis[id]->type, graph_attr->axis[id]->type); - EXPECT_EQ(out_asc_graph_attr->axis[id]->allow_unaligned_tail, graph_attr->axis[id]->allow_unaligned_tail); - EXPECT_EQ(out_asc_graph_attr->axis[id]->allow_oversize_axis, graph_attr->axis[id]->allow_oversize_axis); - EXPECT_EQ(out_asc_graph_attr->axis[id]->bind_block, graph_attr->axis[id]->bind_block); - EXPECT_EQ(out_asc_graph_attr->axis[id]->align, graph_attr->axis[id]->align); - EXPECT_EQ(out_asc_graph_attr->axis[id]->split_pair_other_id, graph_attr->axis[id]->split_pair_other_id); - EXPECT_EQ(out_asc_graph_attr->axis[id]->from, graph_attr->axis[id]->from); - EXPECT_EQ(string(out_asc_graph_attr->axis[id]->size.Str().get()), string(graph_attr->axis[id]->size.Str().get())); - } - std::vector asc_nodes; - for (const auto &n : graph.GetAllNodes()) { - asc_nodes.emplace_back(n); - } - EXPECT_EQ(graph_name, out_asc_graph.GetName()); - std::vector out_asc_nodes; - const auto out_graph = out_asc_graph.impl_->compute_graph_; - for (const auto &n : out_asc_graph.GetAllNodes()) { - out_asc_nodes.emplace_back(n); - } - ASSERT_EQ(asc_nodes.size(), out_asc_nodes.size()); - for (size_t i = 0UL; i < asc_nodes.size(); i++) { - // check ascend node - auto &asc_node = asc_nodes[i]; - auto &out_asc_node = out_asc_nodes[i]; - ASSERT_EQ(asc_node->GetName(), out_asc_node->GetName()) << " node name = " << asc_node->GetName(); - if (asc_node->GetType() == "Data") { - int64_t index_src{0}; - int64_t index_dst{-1}; - auto src_ir_attr = asc_node->attr.ir_attr.get(); - auto dst_ir_attr = out_asc_node->attr.ir_attr.get(); - if (src_ir_attr->GetAttrValue("index", index_src) == GRAPH_SUCCESS) { - EXPECT_EQ(dst_ir_attr->GetAttrValue("index", index_dst), GRAPH_SUCCESS) - << " node name = " << asc_node->GetName(); - EXPECT_EQ(index_src, index_dst); - } - } - EXPECT_EQ(asc_node->attr.type, out_asc_node->attr.type) << " node name = " << asc_node->GetName(); - EXPECT_EQ(asc_node->attr.name, out_asc_node->attr.name) << " node name = " << asc_node->GetName(); - EXPECT_EQ(asc_node->attr.api.type, out_asc_node->attr.api.type) << " node name = " << asc_node->GetName(); - EXPECT_EQ(asc_node->attr.api.unit, out_asc_node->attr.api.unit) << " node name = " << asc_node->GetName(); - EXPECT_EQ(asc_node->attr.api.compute_type, out_asc_node->attr.api.compute_type) - << " node name = " << asc_node->GetName(); - EXPECT_EQ(asc_node->attr.sched.exec_order, out_asc_node->attr.sched.exec_order) - << " node name = " << asc_node->GetName(); - EXPECT_EQ(asc_node->attr.sched.loop_axis, out_asc_node->attr.sched.loop_axis) - << " node name = " << asc_node->GetName(); - EXPECT_EQ(asc_node->attr.sched.exec_condition, ExecuteCondition::kNoCache) - << " node name = " << asc_node->GetName(); - EXPECT_EQ(asc_node->attr.sched.axis, out_asc_node->attr.sched.axis) << " node name = " << asc_node->GetName(); - // check input tensor - auto inputs = asc_node->inputs.tensors_; - auto new_inputs = out_asc_node->inputs.tensors_; - for (size_t id = 0UL; id < inputs.size(); id++) { - // check anchor ref - ASSERT_NE(inputs[id].anchor.GetOwnerNode(), nullptr); - ASSERT_NE(new_inputs[id].anchor.GetOwnerNode(), nullptr); - EXPECT_EQ(new_inputs[id].anchor.GetOwnerNode()->GetName(), inputs[id].anchor.GetOwnerNode()->GetName()); - EXPECT_EQ(new_inputs[id].anchor.GetIdx(), inputs[id].anchor.GetIdx()); - // check attr - const auto owner_node = ge::ascir::AscTensorUtils::GetOwner(inputs[id]); - EXPECT_EQ(inputs[id].attr.axis, new_inputs[id].attr.axis); - EXPECT_EQ(inputs[id].attr.dtype, new_inputs[id].attr.dtype) - << "node name=" << asc_node->GetNamePtr() << ", in id=" << id - << ", from=" << ((owner_node != nullptr) ? "null" : owner_node->GetName()); - // mem attr - EXPECT_EQ(inputs[id].attr.mem.tensor_id, new_inputs[id].attr.mem.tensor_id); - EXPECT_EQ(inputs[id].attr.mem.name, new_inputs[id].attr.mem.name); - EXPECT_EQ(inputs[id].attr.mem.buf_ids, new_inputs[id].attr.mem.buf_ids); - EXPECT_EQ(inputs[id].attr.mem.position, new_inputs[id].attr.mem.position); - EXPECT_EQ(inputs[id].attr.mem.hardware, new_inputs[id].attr.mem.hardware); - EXPECT_EQ(inputs[id].attr.mem.alloc_type, new_inputs[id].attr.mem.alloc_type); - // que/buf attr - EXPECT_EQ(inputs[id].attr.que.id, new_inputs[id].attr.que.id); - EXPECT_EQ(inputs[id].attr.que.buf_num, new_inputs[id].attr.que.buf_num); - EXPECT_EQ(inputs[id].attr.que.depth, new_inputs[id].attr.que.depth); - EXPECT_EQ(inputs[id].attr.buf.id, new_inputs[id].attr.buf.id); - // opt attr - EXPECT_EQ(inputs[id].attr.opt.reuse_id, new_inputs[id].attr.opt.reuse_id); - EXPECT_EQ(inputs[id].attr.opt.ref_tensor, new_inputs[id].attr.opt.ref_tensor); - EXPECT_EQ(inputs[id].attr.opt.merge_scope, new_inputs[id].attr.opt.merge_scope); - ASSERT_EQ(inputs[id].attr.strides.size(), new_inputs[id].attr.strides.size()); - for (size_t stride_id = 0UL; stride_id < inputs[id].attr.strides.size(); stride_id++) { - EXPECT_EQ(std::string(inputs[id].attr.strides[stride_id].Str().get()), - std::string(new_inputs[id].attr.strides[stride_id].Str().get())); - } - ASSERT_EQ(inputs[id].attr.repeats.size(), new_inputs[id].attr.repeats.size()); - for (size_t repeat_id = 0UL; repeat_id < inputs[id].attr.repeats.size(); repeat_id++) { - EXPECT_EQ(std::string(inputs[id].attr.repeats[repeat_id].Str().get()), - std::string(new_inputs[id].attr.repeats[repeat_id].Str().get())); - } - EXPECT_EQ(inputs[id].attr.vectorized_axis, new_inputs[id].attr.vectorized_axis); - ASSERT_EQ(inputs[id].attr.vectorized_strides.size(), new_inputs[id].attr.vectorized_strides.size()); - for (size_t vectorized_stride_id = 0UL; vectorized_stride_id < inputs[id].attr.vectorized_strides.size(); - vectorized_stride_id++) { - EXPECT_EQ(std::string(inputs[id].attr.vectorized_strides[vectorized_stride_id].Str().get()), - std::string(new_inputs[id].attr.vectorized_strides[vectorized_stride_id].Str().get())) - << " node name=" << asc_node->GetName() << " input index=" << id; - } - } - - auto outputs = asc_node->outputs.tensors_; - auto new_outputs = out_asc_node->outputs.tensors_; - for (size_t id = 0UL; id < outputs.size(); id++) { - // check anchor ref - ASSERT_NE(outputs[id].anchor.GetOwnerNode(), nullptr); - ASSERT_NE(new_outputs[id].anchor.GetOwnerNode(), nullptr); - EXPECT_EQ(new_outputs[id].anchor.GetOwnerNode()->GetName(), outputs[id].anchor.GetOwnerNode()->GetName()); - EXPECT_EQ(new_outputs[id].anchor.GetIdx(), outputs[id].anchor.GetIdx()); - // check attr - EXPECT_EQ(outputs[id].attr.axis, new_outputs[id].attr.axis); - EXPECT_EQ(outputs[id].attr.dtype, new_outputs[id].attr.dtype) - << "node name=" << asc_node->GetNamePtr() << ",out id=" << id; - // mem attr - EXPECT_EQ(outputs[id].attr.mem.tensor_id, new_outputs[id].attr.mem.tensor_id); - EXPECT_EQ(outputs[id].attr.mem.name, new_outputs[id].attr.mem.name); - EXPECT_EQ(outputs[id].attr.mem.buf_ids, new_outputs[id].attr.mem.buf_ids); - EXPECT_EQ(outputs[id].attr.mem.position, new_outputs[id].attr.mem.position); - EXPECT_EQ(outputs[id].attr.mem.hardware, new_outputs[id].attr.mem.hardware); - EXPECT_EQ(outputs[id].attr.mem.alloc_type, new_outputs[id].attr.mem.alloc_type); - // que/buf attr - EXPECT_EQ(outputs[id].attr.que.id, new_outputs[id].attr.que.id); - EXPECT_EQ(outputs[id].attr.que.buf_num, new_outputs[id].attr.que.buf_num); - EXPECT_EQ(outputs[id].attr.que.depth, new_outputs[id].attr.que.depth); - EXPECT_EQ(outputs[id].attr.buf.id, new_outputs[id].attr.buf.id); - // opt attr - EXPECT_EQ(outputs[id].attr.opt.reuse_id, new_outputs[id].attr.opt.reuse_id); - EXPECT_EQ(outputs[id].attr.opt.ref_tensor, new_outputs[id].attr.opt.ref_tensor); - EXPECT_EQ(outputs[id].attr.opt.merge_scope, new_outputs[id].attr.opt.merge_scope); - ASSERT_EQ(outputs[id].attr.strides.size(), new_outputs[id].attr.strides.size()); - for (size_t stride_id = 0UL; stride_id < outputs[id].attr.strides.size(); stride_id++) { - EXPECT_EQ(std::string(outputs[id].attr.strides[stride_id].Str().get()), - std::string(new_outputs[id].attr.strides[stride_id].Str().get())); - } - ASSERT_EQ(outputs[id].attr.repeats.size(), new_outputs[id].attr.repeats.size()); - for (size_t repeat_id = 0UL; repeat_id < outputs[id].attr.repeats.size(); repeat_id++) { - EXPECT_EQ(std::string(outputs[id].attr.repeats[repeat_id].Str().get()), - std::string(new_outputs[id].attr.repeats[repeat_id].Str().get())); - } - EXPECT_EQ(outputs[id].attr.vectorized_axis, new_outputs[id].attr.vectorized_axis); - ASSERT_EQ(outputs[id].attr.vectorized_strides.size(), new_outputs[id].attr.vectorized_strides.size()); - for (size_t vectorized_stride_id = 0UL; vectorized_stride_id < outputs[id].attr.vectorized_strides.size(); - vectorized_stride_id++) { - EXPECT_EQ(std::string(outputs[id].attr.vectorized_strides[vectorized_stride_id].Str().get()), - std::string(new_outputs[id].attr.vectorized_strides[vectorized_stride_id].Str().get())) - << " node name=" << asc_node->GetName() << ", output index=" << id; - } - } - } -} - -TEST_F(UtestAscirGraphUtils, ConvertComputeGraphToAscGraph_Success) { - auto compute_graph = ComGraphMakeShared("test"); - const auto graph_attr_group_ptr = compute_graph->GetOrCreateAttrsGroup(); - EXPECT_TRUE(graph_attr_group_ptr != nullptr); - std::vector axes{1, 2, 3}; - std::vector axis_ptrs; - auto axis_ptr = ComGraphMakeShared(); - axis_ptr->name = "axis1"; - axis_ptr->size = sym::kSymbolOne; - axis_ptr->id = 0; - axis_ptrs.push_back(axis_ptr); - graph_attr_group_ptr->axis = axis_ptrs; - auto data = BuildNode(compute_graph, {"data0", "Data", 0, 1, axes}, false); - auto load = BuildNode(compute_graph, {"load0", "Load", 1, 1, axes}, false); - GraphUtils::AddEdge(data->GetOutDataAnchor(0), load->GetInDataAnchor(0)); - EXPECT_TRUE(dynamic_cast(data.get()) == nullptr); - AscGraph asc_graph(""); - EXPECT_EQ(AscGraphUtils::ConvertComputeGraphToAscGraph(compute_graph, asc_graph), GRAPH_SUCCESS); - EXPECT_EQ(asc_graph.GetName(), "test"); - - auto asc_data = asc_graph.FindNode("data0"); - EXPECT_TRUE(asc_data != nullptr); - EXPECT_EQ(asc_data->attr.sched.axis, axes); - EXPECT_EQ(asc_data->attr.sched.loop_axis, axes.back()); - EXPECT_EQ(asc_data->attr.sched.exec_condition, ExecuteCondition::kNoCache); - EXPECT_EQ(asc_data->outputs().size(), 1U); - auto &tensor_attr = asc_data->outputs[0U]; - tensor_attr.attr.axis = axes; - EXPECT_EQ(asc_data->GetOpDesc()->GetOutputDescPtr(0U)->GetAttrsGroup()->axis, axes); - auto asc_load = asc_graph.FindNode("load0"); - EXPECT_EQ(asc_load->inputs.Size(), 1U); - EXPECT_EQ(asc_load->GetInDataNodesSize(), 1U); - EXPECT_EQ(asc_load->GetInNodesPtr()[0U], asc_data.get()); - auto axis_to_find = asc_graph.FindAxis(0); - EXPECT_TRUE(axis_to_find != nullptr); - EXPECT_EQ(axis_to_find->name, "axis1"); - EXPECT_EQ(axis_to_find->id, 0); - EXPECT_EQ(axis_to_find->size, Symbol(1)); -} - -TEST_F(UtestAscirGraphUtils, ConcatGraphSerializeDeserializeReadableSuccess) { - std::string graph_name("concat_graph"); - AscGraph graph(graph_name.c_str()); - CreatConcatAscGraph(graph); - std::string output; - ASSERT_EQ(AscGraphUtils::SerializeToReadable(graph, output), GRAPH_SUCCESS); - AscGraph out_asc_graph(""); - ASSERT_EQ(AscGraphUtils::DeserializeFromReadable(output, out_asc_graph), GRAPH_SUCCESS); - AscGraphAttr *out_asc_graph_attr = out_asc_graph.impl_->GetOrCreateGraphAttrsGroup(); - ASSERT_NE(out_asc_graph_attr, nullptr); - AscGraphAttr *graph_attr = graph.impl_->GetOrCreateGraphAttrsGroup(); - ASSERT_NE(out_asc_graph_attr, nullptr); - ASSERT_NE(graph_attr, nullptr); - EXPECT_EQ(out_asc_graph_attr->tiling_key, graph_attr->tiling_key); - EXPECT_EQ(out_asc_graph_attr->tiling_key, graph_attr->tiling_key); - ASSERT_EQ(out_asc_graph_attr->axis.size(), graph_attr->axis.size()); - for (size_t id = 0UL; id < out_asc_graph_attr->axis.size(); id++) { - EXPECT_EQ(out_asc_graph_attr->axis[id]->name, graph_attr->axis[id]->name); - EXPECT_EQ(out_asc_graph_attr->axis[id]->type, graph_attr->axis[id]->type); - EXPECT_EQ(out_asc_graph_attr->axis[id]->allow_unaligned_tail, graph_attr->axis[id]->allow_unaligned_tail); - EXPECT_EQ(out_asc_graph_attr->axis[id]->allow_oversize_axis, graph_attr->axis[id]->allow_oversize_axis); - EXPECT_EQ(out_asc_graph_attr->axis[id]->bind_block, graph_attr->axis[id]->bind_block); - EXPECT_EQ(out_asc_graph_attr->axis[id]->align, graph_attr->axis[id]->align); - EXPECT_EQ(out_asc_graph_attr->axis[id]->split_pair_other_id, graph_attr->axis[id]->split_pair_other_id); - EXPECT_EQ(out_asc_graph_attr->axis[id]->from, graph_attr->axis[id]->from); - EXPECT_EQ(string(out_asc_graph_attr->axis[id]->size.Str().get()), string(graph_attr->axis[id]->size.Str().get())); - } -} - -TEST_F(UtestAscirGraphUtils, FaGraphDumpWithAttrGroupSuccess) { - std::string graph_name("test_graph"); - AscGraph graph(graph_name.c_str()); - FaBeforeAutoFuse(graph); - FaAfterScheduler(graph); - FaAfterQueBufAlloc(graph); - std::string ascend_work_path = "./test_ge_graph_path"; - setenv("DUMP_GRAPH_PATH", ascend_work_path.c_str(), 1); - EXPECT_NO_THROW(GraphUtils::DumpGEGraph(AscGraphUtils::GetComputeGraph(graph), "attr_group_test", true);); - EXPECT_NO_THROW(GraphUtils::DumpGEGraphToOnnx(*AscGraphUtils::GetComputeGraph(graph), "attr_group_test", true)); - std::stringstream dump_file_path = GetFilePathWhenDumpPathSet(ascend_work_path); - std::string dump_graph_path = ge::RealPath(dump_file_path.str().c_str()); - std::string - dump_txt_graph_path = GetSpecificFilePath(ge::RealPath(dump_file_path.str().c_str()), "attr_group_test.txt"); - ComputeGraphPtr com_graph = std::make_shared("load_test_graph"); - // 测试反序列化之后的图 - auto state = GraphUtils::LoadGEGraph(dump_txt_graph_path.c_str(), *com_graph); - ASSERT_EQ(state, true); - auto data = com_graph->FindNode("query"); - EXPECT_NE(data, nullptr); - auto data_op = data->GetOpDesc(); - EXPECT_NE(data_op, nullptr); - auto data_attr_group = data_op->GetAttrsGroup(); - EXPECT_NE(data_attr_group, nullptr); - ascendc_ir::proto::AscNodeAttrGroupsDef asc_node_group; - EXPECT_EQ(data_attr_group->SerializeAttr(asc_node_group), GRAPH_SUCCESS); - EXPECT_EQ(asc_node_group.DebugString(), R"PROTO(name: "query" -type: "Data" -sched { - axis: 10 - axis: 11 - axis: 12 - axis: 8 - axis: 13 - axis: 5 - axis: 6 - loop_axis: 12 -} -api { - type: 2 - compute_type: 11 - unit: 7 -} -ir_attr_def { - attr { - key: "index" - value { - i: 0 - } - } -} -)PROTO"); - auto data_desc_attr_group = data_op->GetOutputDescPtr(0U)->GetAttrsGroup(); - EXPECT_NE(data_desc_attr_group, nullptr); - EXPECT_EQ(data_desc_attr_group->dtype, DT_FLOAT16); - unsetenv("DUMP_GRAPH_PATH"); - system(("rm -rf " + ascend_work_path).c_str()); -} diff --git a/tests/ut/ascendc_ir/testcase/asc_tensor_utils_unittest.cc b/tests/ut/ascendc_ir/testcase/asc_tensor_utils_unittest.cc deleted file mode 100644 index 7deea2094a92ba6cb61ccd6a04176f5672560c42..0000000000000000000000000000000000000000 --- a/tests/ut/ascendc_ir/testcase/asc_tensor_utils_unittest.cc +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ -#include -#include "graph/ascendc_ir/utils/asc_tensor_utils.h" -#include "testcase/ascendc_ir_dump_test/stub_graph.h" -#include -using namespace ge; -class UtestAscirTensorUtils : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; -REG_OP(Constant) - .INPUT(x, TensorType::ALL()) - .OUTPUT(y, TensorType::ALL()) - .OP_END_FACTORY_REG(Constant); - -REG_OP(Abs) - .INPUT(x, TensorType::ALL()) - .OUTPUT(y, TensorType::ALL()) - .OP_END_FACTORY_REG(Abs); - -TEST_F(UtestAscirTensorUtils, IsConstTensorTrue) { - AscGraph g("graph"); - auto op = ascir_op::Constant("abc"); - auto node = g.AddNode(op); - node->inputs(); - node->outputs(); - EXPECT_EQ(ge::ascir::AscTensorUtils::IsConstTensor(node->outputs[0]), true); -} - -TEST_F(UtestAscirTensorUtils, IsConstTensorFalse) { - AscGraph g("graph"); - auto op = ascir_op::Abs("abs"); - auto node = g.AddNode(op); - node->inputs(); - node->outputs(); - EXPECT_EQ(ge::ascir::AscTensorUtils::IsConstTensor(node->outputs[0]), false); -} diff --git a/tests/ut/ascendc_ir/testcase/ascendc_ir_dump_test/ascendc_ir_dump_utils_unittest.cc b/tests/ut/ascendc_ir/testcase/ascendc_ir_dump_test/ascendc_ir_dump_utils_unittest.cc deleted file mode 100644 index 22dc1262def874cb89772c6734d2b4316a4433c0..0000000000000000000000000000000000000000 --- a/tests/ut/ascendc_ir/testcase/ascendc_ir_dump_test/ascendc_ir_dump_utils_unittest.cc +++ /dev/null @@ -1,2684 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ -#include -#include "inc/graph/ascendc_ir/ascendc_ir_core/ascendc_ir.h" -#include "inc/graph/ascendc_ir/utils/ascendc_ir_dump_utils.h" -#include "stub_graph.h" -#include -#include -class UtestAscirDump : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; -using namespace ge; - -TEST_F(UtestAscirDump, DumpAscirGraphTest) { - AscGraph graph("test_graph"); - FaBeforeAutoFuse(graph); - FaAfterScheduler(graph); - FaAfterQueBufAlloc(graph); - std::string res = R"(TilingKey: 1 -Graph Name: test_graph -Axis: - axis1: - name: b - id: 0 - type: ORIGINAL - bind_block: false - size: B - align: 1 - split_pair_other_id: 0 - allow_oversize_axis: 0 - allow_unaligned_tail: 0 - axis2: - name: n - id: 1 - type: ORIGINAL - bind_block: false - size: N - align: 1 - split_pair_other_id: 0 - allow_oversize_axis: 0 - allow_unaligned_tail: 0 - axis3: - name: g - id: 2 - type: ORIGINAL - bind_block: false - size: G - align: 1 - split_pair_other_id: 0 - allow_oversize_axis: 0 - allow_unaligned_tail: 0 - axis4: - name: s1 - id: 3 - type: ORIGINAL - bind_block: false - size: S1 - align: 1 - split_pair_other_id: 0 - allow_oversize_axis: 0 - allow_unaligned_tail: 0 - axis5: - name: s2 - id: 4 - type: ORIGINAL - bind_block: false - size: S2 - align: 1 - split_pair_other_id: 0 - allow_oversize_axis: 0 - allow_unaligned_tail: 0 - axis6: - name: d - id: 5 - type: ORIGINAL - bind_block: false - size: D - align: 1 - split_pair_other_id: 0 - allow_oversize_axis: 0 - allow_unaligned_tail: 0 - axis7: - name: l - id: 6 - type: ORIGINAL - bind_block: false - size: 8 - align: 1 - split_pair_other_id: 0 - allow_oversize_axis: 0 - allow_unaligned_tail: 0 - axis8: - name: s1T - id: 7 - type: TILE_OUTER - bind_block: false - size: Ceiling((S1 / (s1t_size))) - align: 1 - from: {3, } - split_pair_other_id: 8 - allow_oversize_axis: 0 - allow_unaligned_tail: 0 - axis9: - name: s1t - id: 8 - type: TILE_INNER - bind_block: false - size: s1t_size - align: 128 - from: {3, } - split_pair_other_id: 7 - allow_oversize_axis: 0 - allow_unaligned_tail: 0 - axis10: - name: bngs1T - id: 9 - type: MERGED - bind_block: false - size: (B * Ceiling((S1 / (s1t_size))) * G * N) - align: 1 - from: {0, 1, 2, 7, } - split_pair_other_id: 0 - allow_oversize_axis: 0 - allow_unaligned_tail: 0 - axis11: - name: bngs1TB - id: 10 - type: BLOCK_OUTER - bind_block: false - size: Ceiling((B * Ceiling((S1 / (s1t_size))) * G * N / (bngs1Tb_size))) - align: 1 - from: {9, } - split_pair_other_id: 11 - allow_oversize_axis: 0 - allow_unaligned_tail: 0 - axis12: - name: bngs1Tb - id: 11 - type: BLOCK_INNER - bind_block: false - size: bngs1Tb_size - align: 1 - from: {9, } - split_pair_other_id: 10 - allow_oversize_axis: 0 - allow_unaligned_tail: 0 - axis13: - name: s2T - id: 12 - type: TILE_OUTER - bind_block: false - size: Ceiling((S2 / (s2t_size))) - align: 1 - from: {4, } - split_pair_other_id: 13 - allow_oversize_axis: 0 - allow_unaligned_tail: 0 - axis14: - name: s2t - id: 13 - type: TILE_INNER - bind_block: false - size: s2t_size - align: 256 - from: {4, } - split_pair_other_id: 12 - allow_oversize_axis: 0 - allow_unaligned_tail: 0 - axis15: - name: s1tT - id: 14 - type: TILE_OUTER - bind_block: false - size: Ceiling((s1t_size / (s1tt_size))) - align: 1 - from: {8, } - split_pair_other_id: 15 - allow_oversize_axis: 0 - allow_unaligned_tail: 0 - axis16: - name: s1tt - id: 15 - type: TILE_INNER - bind_block: false - size: s1tt_size - align: 1 - from: {8, } - split_pair_other_id: 14 - allow_oversize_axis: 0 - allow_unaligned_tail: 0 - axis17: - name: s1tT2 - id: 16 - type: TILE_OUTER - bind_block: false - size: Ceiling((s1t_size / (s1tt2_size))) - align: 1 - from: {8, } - split_pair_other_id: 17 - allow_oversize_axis: 0 - allow_unaligned_tail: 0 - axis18: - name: s1tt2 - id: 17 - type: TILE_INNER - bind_block: false - size: s1tt2_size - align: 1 - from: {8, } - split_pair_other_id: 16 - allow_oversize_axis: 0 - allow_unaligned_tail: 0 -nodes: - node1 info: - node name: query - inputs: - outputs: - AscTensor: - DataType: DT_FLOAT16 - axis: 10, 11, 12, 8, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, 1, s1t_size, 1, D, 1, - strides: (D * bngs1Tb_size * s1t_size), (D * s1t_size), 0, D, 0, 1, 0, - vectorized_axis: 8, 13, 5, - vectorized_strides: X,Y,Z, - MemAttr: - tensor_id: -1 - alloc_type: GLOBAL - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 0 - axis: 10, 11, 12, 8, 13, 5, 6, - loop_axis: 12 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node2 info: - node name: key - inputs: - outputs: - AscTensor: - DataType: DT_FLOAT16 - axis: 12, 8, 5, 13, 6, - repeats: (S2 / (s2t_size)), 1, D, s2t_size, 1, - strides: (D * s2t_size), 0, 1, D, 0, - vectorized_axis: 8, 5, 13, - vectorized_strides: - MemAttr: - tensor_id: -1 - alloc_type: GLOBAL - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 1 - axis: 10, 11, 12, 8, 5, 13, 6, - loop_axis: 12 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node3 info: - node name: bmm1 - inputs: - AscTensor: - DataType: DT_FLOAT16 - axis: 10, 11, 12, 8, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, 1, s1t_size, 1, D, 1, - strides: (D * bngs1Tb_size * s1t_size), (D * s1t_size), 0, D, 0, 1, 0, - vectorized_axis: 8, 13, 5, - vectorized_strides: X,Y,Z, - MemAttr: - tensor_id: -1 - alloc_type: GLOBAL - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - AscTensor: - DataType: DT_FLOAT16 - axis: 12, 8, 5, 13, 6, - repeats: (S2 / (s2t_size)), 1, D, s2t_size, 1, - strides: (D * s2t_size), 0, 1, D, 0, - vectorized_axis: 8, 5, 13, - vectorized_strides: - MemAttr: - tensor_id: -1 - alloc_type: GLOBAL - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - outputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 8, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), s1t_size, s2t_size, 1, 1, - strides: (S2 * bngs1Tb_size * s1t_size), (S2 * s1t_size), s2t_size, S2, 1, 0, 0, - vectorized_axis: 8, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 0 - alloc_type: QUEUE - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: 0 - depth: 2 - buf_num: 2 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 2 - axis: 10, 11, 12, 8, 13, 5, 6, - loop_axis: 12 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node4 info: - node name: load1 - inputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 8, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), s1t_size, s2t_size, 1, 1, - strides: (S2 * bngs1Tb_size * s1t_size), (S2 * s1t_size), s2t_size, S2, 1, 0, 0, - vectorized_axis: 8, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 0 - alloc_type: QUEUE - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: 0 - depth: 2 - buf_num: 2 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - outputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 1, - strides: (S2 * bngs1Tb_size * s1t_size), (S2 * s1t_size), s2t_size, (S2 * s1tt_size), S2, 1, 0, 0, - vectorized_axis: 15, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 1 - alloc_type: QUEUE - position: VECIN - hardware: UB - buf_ids: - name: - MemQueAttr: - id: 1 - depth: 2 - buf_num: 2 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 3 - axis: 10, 11, 12, 14, 15, 13, 5, 6, - loop_axis: 14 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node5 info: - node name: pse - inputs: - outputs: - AscTensor: - DataType: DT_FLOAT16 - axis: 0, 1, 2, 3, 4, 5, 6, - repeats: B, N, G, S1, S2, 1, 1, - strides: (G * N * S1 * S2), (G * S1 * S2), (S1 * S2), S2, 1, 0, 0, - vectorized_axis: - vectorized_strides: - MemAttr: - tensor_id: -1 - alloc_type: GLOBAL - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 4 - axis: - loop_axis: -1 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node6 info: - node name: loadPse - inputs: - AscTensor: - DataType: DT_FLOAT16 - axis: 0, 1, 2, 3, 4, 5, 6, - repeats: B, N, G, S1, S2, 1, 1, - strides: (G * N * S1 * S2), (G * S1 * S2), (S1 * S2), S2, 1, 0, 0, - vectorized_axis: - vectorized_strides: - MemAttr: - tensor_id: -1 - alloc_type: GLOBAL - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - outputs: - AscTensor: - DataType: DT_FLOAT16 - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 1, - strides: (S2 * bngs1Tb_size * s1t_size), (S2 * s1t_size), s2t_size, (S2 * s1tt_size), S2, 1, 0, 0, - vectorized_axis: 15, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 2 - alloc_type: BUFFER - position: VECIN - hardware: UB - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: 0 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 5 - axis: 10, 11, 12, 14, 15, 13, 5, 6, - loop_axis: 14 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node7 info: - node name: castPse - inputs: - AscTensor: - DataType: DT_FLOAT16 - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 1, - strides: (S2 * bngs1Tb_size * s1t_size), (S2 * s1t_size), s2t_size, (S2 * s1tt_size), S2, 1, 0, 0, - vectorized_axis: 15, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 2 - alloc_type: BUFFER - position: VECIN - hardware: UB - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: 0 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - outputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 1, - strides: (S2 * bngs1Tb_size * s1t_size), (S2 * s1t_size), s2t_size, (S2 * s1tt_size), S2, 1, 0, 0, - vectorized_axis: 15, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 3 - alloc_type: BUFFER - position: VECOUT - hardware: UB - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: 1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 6 - axis: 10, 11, 12, 14, 15, 13, 5, 6, - loop_axis: 14 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node8 info: - node name: add1 - inputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 1, - strides: (S2 * bngs1Tb_size * s1t_size), (S2 * s1t_size), s2t_size, (S2 * s1tt_size), S2, 1, 0, 0, - vectorized_axis: 15, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 1 - alloc_type: QUEUE - position: VECIN - hardware: UB - buf_ids: - name: - MemQueAttr: - id: 1 - depth: 2 - buf_num: 2 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 1, - strides: (S2 * bngs1Tb_size * s1t_size), (S2 * s1t_size), s2t_size, (S2 * s1tt_size), S2, 1, 0, 0, - vectorized_axis: 15, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 3 - alloc_type: BUFFER - position: VECOUT - hardware: UB - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: 1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - outputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 1, - strides: (S2 * bngs1Tb_size * s1t_size), (S2 * s1t_size), s2t_size, (S2 * s1tt_size), S2, 1, 0, 0, - vectorized_axis: 15, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 4 - alloc_type: QUEUE - position: VECOUT - hardware: UB - buf_ids: - name: - MemQueAttr: - id: 1 - depth: 2 - buf_num: 2 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 7 - axis: 10, 11, 12, 14, 15, 13, 5, 6, - loop_axis: 14 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node9 info: - node name: scaleValue - inputs: - outputs: - AscTensor: - DataType: DT_FLOAT - axis: - repeats: - strides: - vectorized_axis: - vectorized_strides: - MemAttr: - tensor_id: -1 - alloc_type: GLOBAL - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 8 - axis: - loop_axis: -1 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node10 info: - node name: mul1 - inputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 1, - strides: (S2 * bngs1Tb_size * s1t_size), (S2 * s1t_size), s2t_size, (S2 * s1tt_size), S2, 1, 0, 0, - vectorized_axis: 15, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 4 - alloc_type: QUEUE - position: VECOUT - hardware: UB - buf_ids: - name: - MemQueAttr: - id: 1 - depth: 2 - buf_num: 2 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - AscTensor: - DataType: DT_FLOAT - axis: - repeats: - strides: - vectorized_axis: - vectorized_strides: - MemAttr: - tensor_id: -1 - alloc_type: GLOBAL - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - outputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 1, - strides: (S2 * bngs1Tb_size * s1t_size), (S2 * s1t_size), s2t_size, (S2 * s1tt_size), S2, 1, 0, 0, - vectorized_axis: 15, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 5 - alloc_type: QUEUE - position: VECOUT - hardware: UB - buf_ids: - name: - MemQueAttr: - id: 1 - depth: 2 - buf_num: 2 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 9 - axis: 10, 11, 12, 14, 15, 13, 5, 6, - loop_axis: 14 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node11 info: - node name: attenMask - inputs: - outputs: - AscTensor: - DataType: DT_UINT8 - axis: 0, 1, 2, 3, 4, 5, 6, - repeats: B, 1, 1, S1, S2, 1, 1, - strides: (S1 * S2), (S1 * S2), (S1 * S2), S2, 1, 0, 0, - vectorized_axis: - vectorized_strides: - MemAttr: - tensor_id: -1 - alloc_type: GLOBAL - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 10 - axis: - loop_axis: -1 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node12 info: - node name: loadAttenMask - inputs: - AscTensor: - DataType: DT_UINT8 - axis: 0, 1, 2, 3, 4, 5, 6, - repeats: B, 1, 1, S1, S2, 1, 1, - strides: (S1 * S2), (S1 * S2), (S1 * S2), S2, 1, 0, 0, - vectorized_axis: - vectorized_strides: - MemAttr: - tensor_id: -1 - alloc_type: GLOBAL - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - outputs: - AscTensor: - DataType: DT_UINT8 - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 1, - strides: (S2 * bngs1Tb_size * s1t_size), (S2 * s1t_size), s2t_size, (S2 * s1tt_size), S2, 1, 0, 0, - vectorized_axis: 15, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 12 - alloc_type: BUFFER - position: VECIN - hardware: UB - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: 2 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 11 - axis: 10, 11, 12, 14, 15, 13, 5, 6, - loop_axis: 14 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node13 info: - node name: select - inputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 1, - strides: (S2 * bngs1Tb_size * s1t_size), (S2 * s1t_size), s2t_size, (S2 * s1tt_size), S2, 1, 0, 0, - vectorized_axis: 15, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 5 - alloc_type: QUEUE - position: VECOUT - hardware: UB - buf_ids: - name: - MemQueAttr: - id: 1 - depth: 2 - buf_num: 2 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - AscTensor: - DataType: DT_UINT8 - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 1, - strides: (S2 * bngs1Tb_size * s1t_size), (S2 * s1t_size), s2t_size, (S2 * s1tt_size), S2, 1, 0, 0, - vectorized_axis: 15, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 12 - alloc_type: BUFFER - position: VECIN - hardware: UB - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: 2 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - outputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 1, - strides: (S2 * bngs1Tb_size * s1t_size), (S2 * s1t_size), s2t_size, (S2 * s1tt_size), S2, 1, 0, 0, - vectorized_axis: 15, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 6 - alloc_type: QUEUE - position: VECOUT - hardware: UB - buf_ids: - name: - MemQueAttr: - id: 1 - depth: 2 - buf_num: 2 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 12 - axis: 10, 11, 12, 14, 15, 13, 5, 6, - loop_axis: 14 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node14 info: - node name: softmaxExp - inputs: - outputs: - AscTensor: - DataType: DT_FLOAT - axis: 12, 14, 15, 13, 5, 6, - repeats: 1, (s1t_size / (s1tt_size)), s1tt_size, 1, 1, 8, - strides: 0, (8 * s1tt_size), 8, 0, 0, 1, - vectorized_axis: 14, 15, 13, 5, 6, - vectorized_strides: - MemAttr: - tensor_id: 7 - alloc_type: QUEUE - position: VECIN - hardware: UB - buf_ids: - name: - MemQueAttr: - id: 3 - depth: 2 - buf_num: 2 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 13 - axis: 10, 11, 12, 14, 15, 13, 5, 6, - loop_axis: 14 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node15 info: - node name: softmaxApiTmpBuf - inputs: - outputs: - AscTensor: - DataType: DT_FLOAT - axis: 12, 14, 15, 13, 5, 6, - repeats: (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 1, - strides: s2t_size, (S2 * s1tt_size), S2, 1, 0, 0, - vectorized_axis: 15, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 8 - alloc_type: BUFFER - position: VECIN - hardware: UB - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: 1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 14 - axis: 10, 11, 12, 14, 15, 13, 5, 6, - loop_axis: 14 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node16 info: - node name: flashSoftmax - inputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 1, - strides: (S2 * bngs1Tb_size * s1t_size), (S2 * s1t_size), s2t_size, (S2 * s1tt_size), S2, 1, 0, 0, - vectorized_axis: 15, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 6 - alloc_type: QUEUE - position: VECOUT - hardware: UB - buf_ids: - name: - MemQueAttr: - id: 1 - depth: 2 - buf_num: 2 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - AscTensor: - DataType: DT_FLOAT - axis: 12, 14, 15, 13, 5, 6, - repeats: 1, (s1t_size / (s1tt_size)), s1tt_size, 1, 1, 8, - strides: 0, (8 * s1tt_size), 8, 0, 0, 1, - vectorized_axis: 14, 15, 13, 5, 6, - vectorized_strides: - MemAttr: - tensor_id: 7 - alloc_type: QUEUE - position: VECIN - hardware: UB - buf_ids: - name: - MemQueAttr: - id: 3 - depth: 2 - buf_num: 2 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - AscTensor: - DataType: DT_FLOAT - axis: 12, 14, 15, 13, 5, 6, - repeats: (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 1, - strides: s2t_size, (S2 * s1tt_size), S2, 1, 0, 0, - vectorized_axis: 15, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 8 - alloc_type: BUFFER - position: VECIN - hardware: UB - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: 1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - outputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 1, - strides: (S2 * bngs1Tb_size * s1t_size), (S2 * s1t_size), s2t_size, (S2 * s1tt_size), S2, 1, 0, 0, - vectorized_axis: 15, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 9 - alloc_type: QUEUE - position: VECOUT - hardware: UB - buf_ids: - name: - MemQueAttr: - id: 1 - depth: 2 - buf_num: 2 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 8, - strides: (8 * bngs1Tb_size * s1t_size), (8 * s1t_size), 0, (8 * s1tt_size), 8, 0, 0, 1, - vectorized_axis: 14, 15, 13, 5, 6, - vectorized_strides: - MemAttr: - tensor_id: 10 - alloc_type: BUFFER - position: VECOUT - hardware: UB - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: 4 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 8, - strides: (8 * bngs1Tb_size * s1t_size), (8 * s1t_size), 0, (8 * s1tt_size), 8, 0, 0, 1, - vectorized_axis: 14, 15, 13, 5, 6, - vectorized_strides: - MemAttr: - tensor_id: 11 - alloc_type: QUEUE - position: VECOUT - hardware: UB - buf_ids: - name: - MemQueAttr: - id: 2 - depth: 2 - buf_num: 2 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 15 - axis: 10, 11, 12, 14, 15, 13, 5, 6, - loop_axis: 14 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node17 info: - node name: storeSoftmaxMax - inputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 8, - strides: (8 * bngs1Tb_size * s1t_size), (8 * s1t_size), 0, (8 * s1tt_size), 8, 0, 0, 1, - vectorized_axis: 14, 15, 13, 5, 6, - vectorized_strides: - MemAttr: - tensor_id: 11 - alloc_type: QUEUE - position: VECOUT - hardware: UB - buf_ids: - name: - MemQueAttr: - id: 2 - depth: 2 - buf_num: 2 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - outputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 8, - strides: (8 * bngs1Tb_size * s1t_size), (8 * s1t_size), 0, (8 * s1tt_size), 8, 0, 0, 1, - vectorized_axis: 14, 15, 13, 5, 6, - vectorized_strides: - MemAttr: - tensor_id: 26 - alloc_type: GLOBAL - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: 0 - merge_scope: 0 - attr: - AscNode: - sched: - exec_order: 16 - axis: 10, 11, 12, 14, 15, 13, 5, 6, - loop_axis: 14 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node18 info: - node name: softmaxMax - inputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 8, - strides: (8 * bngs1Tb_size * s1t_size), (8 * s1t_size), 0, (8 * s1tt_size), 8, 0, 0, 1, - vectorized_axis: 14, 15, 13, 5, 6, - vectorized_strides: - MemAttr: - tensor_id: 26 - alloc_type: GLOBAL - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: 0 - merge_scope: 0 - outputs: - AscTensor: - DataType: DT_FLOAT - axis: - repeats: - strides: - vectorized_axis: - vectorized_strides: - MemAttr: - tensor_id: -1 - alloc_type: GLOBAL - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 17 - axis: - loop_axis: -1 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node19 info: - node name: dropMask - inputs: - outputs: - AscTensor: - DataType: DT_UINT8 - axis: 0, 1, 2, 3, 4, 5, 6, - repeats: B, N, G, S1, S2, 1, 1, - strides: (G * N * S1 * S2), (G * S1 * S2), (S1 * S2), S2, 1, 0, 0, - vectorized_axis: - vectorized_strides: - MemAttr: - tensor_id: -1 - alloc_type: GLOBAL - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 18 - axis: - loop_axis: -1 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node20 info: - node name: loadDropMask - inputs: - AscTensor: - DataType: DT_UINT8 - axis: 0, 1, 2, 3, 4, 5, 6, - repeats: B, N, G, S1, S2, 1, 1, - strides: (G * N * S1 * S2), (G * S1 * S2), (S1 * S2), S2, 1, 0, 0, - vectorized_axis: - vectorized_strides: - MemAttr: - tensor_id: -1 - alloc_type: GLOBAL - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - outputs: - AscTensor: - DataType: DT_UINT8 - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 1, - strides: (S2 * bngs1Tb_size * s1t_size), (S2 * s1t_size), s2t_size, (S2 * s1tt_size), S2, 1, 0, 0, - vectorized_axis: 15, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 13 - alloc_type: BUFFER - position: VECIN - hardware: UB - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: 3 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 19 - axis: 10, 11, 12, 14, 15, 13, 5, 6, - loop_axis: 14 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node21 info: - node name: dropout - inputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 1, - strides: (S2 * bngs1Tb_size * s1t_size), (S2 * s1t_size), s2t_size, (S2 * s1tt_size), S2, 1, 0, 0, - vectorized_axis: 15, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 9 - alloc_type: QUEUE - position: VECOUT - hardware: UB - buf_ids: - name: - MemQueAttr: - id: 1 - depth: 2 - buf_num: 2 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - AscTensor: - DataType: DT_UINT8 - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 1, - strides: (S2 * bngs1Tb_size * s1t_size), (S2 * s1t_size), s2t_size, (S2 * s1tt_size), S2, 1, 0, 0, - vectorized_axis: 15, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 13 - alloc_type: BUFFER - position: VECIN - hardware: UB - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: 3 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - outputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 1, - strides: (S2 * bngs1Tb_size * s1t_size), (S2 * s1t_size), s2t_size, (S2 * s1tt_size), S2, 1, 0, 0, - vectorized_axis: 15, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 14 - alloc_type: QUEUE - position: VECOUT - hardware: UB - buf_ids: - name: - MemQueAttr: - id: 1 - depth: 2 - buf_num: 2 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 20 - axis: 10, 11, 12, 14, 15, 13, 5, 6, - loop_axis: 14 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node22 info: - node name: castVec1Res - inputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 1, - strides: (S2 * bngs1Tb_size * s1t_size), (S2 * s1t_size), s2t_size, (S2 * s1tt_size), S2, 1, 0, 0, - vectorized_axis: 15, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 14 - alloc_type: QUEUE - position: VECOUT - hardware: UB - buf_ids: - name: - MemQueAttr: - id: 1 - depth: 2 - buf_num: 2 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - outputs: - AscTensor: - DataType: DT_FLOAT16 - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 1, - strides: (S2 * bngs1Tb_size * s1t_size), (S2 * s1t_size), s2t_size, (S2 * s1tt_size), S2, 1, 0, 0, - vectorized_axis: 15, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 15 - alloc_type: BUFFER - position: VECOUT - hardware: UB - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: 0 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 21 - axis: 10, 11, 12, 14, 15, 13, 5, 6, - loop_axis: 14 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node23 info: - node name: storeVec1Res - inputs: - AscTensor: - DataType: DT_FLOAT16 - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 1, - strides: (S2 * bngs1Tb_size * s1t_size), (S2 * s1t_size), s2t_size, (S2 * s1tt_size), S2, 1, 0, 0, - vectorized_axis: 15, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 15 - alloc_type: BUFFER - position: VECOUT - hardware: UB - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: 0 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - outputs: - AscTensor: - DataType: DT_FLOAT16 - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 1, - strides: (S2 * bngs1Tb_size * s1t_size), (S2 * s1t_size), s2t_size, (S2 * s1tt_size), S2, 1, 0, 0, - vectorized_axis: 14, 15, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 16 - alloc_type: QUEUE - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: 4 - depth: 2 - buf_num: 2 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 22 - axis: 10, 11, 12, 14, 15, 13, 5, 6, - loop_axis: 14 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node24 info: - node name: value - inputs: - outputs: - AscTensor: - DataType: DT_FLOAT16 - axis: 12, 8, 13, 5, 6, - repeats: (S2 / (s2t_size)), 1, s2t_size, D, 1, - strides: (D * s2t_size), 0, D, 1, 0, - vectorized_axis: 8, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: -1 - alloc_type: GLOBAL - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 23 - axis: 10, 11, 12, 8, 13, 5, 6, - loop_axis: 12 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node25 info: - node name: bmm2 - inputs: - AscTensor: - DataType: DT_FLOAT16 - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 1, - strides: (S2 * bngs1Tb_size * s1t_size), (S2 * s1t_size), s2t_size, (S2 * s1tt_size), S2, 1, 0, 0, - vectorized_axis: 14, 15, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 16 - alloc_type: QUEUE - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: 4 - depth: 2 - buf_num: 2 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - AscTensor: - DataType: DT_FLOAT16 - axis: 12, 8, 13, 5, 6, - repeats: (S2 / (s2t_size)), 1, s2t_size, D, 1, - strides: (D * s2t_size), 0, D, 1, 0, - vectorized_axis: 8, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: -1 - alloc_type: GLOBAL - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - outputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 8, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, 1, s1t_size, 1, D, 1, - strides: (D * bngs1Tb_size * s1t_size), (D * s1t_size), 0, D, 0, 1, 0, - vectorized_axis: 8, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 17 - alloc_type: QUEUE - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: 5 - depth: 2 - buf_num: 2 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 24 - axis: 10, 11, 12, 8, 13, 5, 6, - loop_axis: 12 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node26 info: - node name: load2 - inputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 8, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, 1, s1t_size, 1, D, 1, - strides: (D * bngs1Tb_size * s1t_size), (D * s1t_size), 0, D, 0, 1, 0, - vectorized_axis: 8, 13, 5, - vectorized_strides: - MemAttr: - tensor_id: 17 - alloc_type: QUEUE - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: 5 - depth: 2 - buf_num: 2 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - outputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 16, 17, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, 1, (s1t_size / (s1tt2_size)), s1tt2_size, 1, D, 1, - strides: (D * bngs1Tb_size * s1t_size), (D * s1t_size), 0, (D * s1tt2_size), D, 0, 1, 0, - vectorized_axis: 17, 5, 13, - vectorized_strides: - MemAttr: - tensor_id: 18 - alloc_type: BUFFER - position: VECIN - hardware: UB - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: 1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 25 - axis: 10, 11, 12, 16, 17, 13, 5, 6, - loop_axis: 16 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node27 info: - node name: addResOut - inputs: - outputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 16, 17, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, 1, (s1t_size / (s1tt2_size)), s1tt2_size, 1, D, 1, - strides: (D * bngs1Tb_size * s1t_size), (D * s1t_size), 0, (D * s1tt2_size), D, 0, 1, 0, - vectorized_axis: 17, 5, 13, - vectorized_strides: - MemAttr: - tensor_id: 19 - alloc_type: QUEUE - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: 6 - depth: 2 - buf_num: 2 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 26 - axis: 10, 11, 12, 16, 17, 13, 5, 6, - loop_axis: 16 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node28 info: - node name: loadAddResOut - inputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 16, 17, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, 1, (s1t_size / (s1tt2_size)), s1tt2_size, 1, D, 1, - strides: (D * bngs1Tb_size * s1t_size), (D * s1t_size), 0, (D * s1tt2_size), D, 0, 1, 0, - vectorized_axis: 17, 5, 13, - vectorized_strides: - MemAttr: - tensor_id: 19 - alloc_type: QUEUE - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: 6 - depth: 2 - buf_num: 2 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - outputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 16, 17, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, 1, (s1t_size / (s1tt2_size)), s1tt2_size, 1, D, 1, - strides: (D * bngs1Tb_size * s1t_size), (D * s1t_size), 0, (D * s1tt2_size), D, 0, 1, 0, - vectorized_axis: 17, 5, 13, - vectorized_strides: - MemAttr: - tensor_id: 20 - alloc_type: BUFFER - position: VECIN - hardware: UB - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: 5 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 27 - axis: 10, 11, 12, 16, 17, 13, 5, 6, - loop_axis: 16 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node29 info: - node name: mulRes - inputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 16, 17, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, 1, (s1t_size / (s1tt2_size)), s1tt2_size, 1, D, 1, - strides: (D * bngs1Tb_size * s1t_size), (D * s1t_size), 0, (D * s1tt2_size), D, 0, 1, 0, - vectorized_axis: 17, 5, 13, - vectorized_strides: - MemAttr: - tensor_id: 20 - alloc_type: BUFFER - position: VECIN - hardware: UB - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: 5 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - AscTensor: - DataType: DT_FLOAT - axis: 12, 14, 15, 13, 5, 6, - repeats: 1, (s1t_size / (s1tt_size)), s1tt_size, 1, 1, 8, - strides: 0, (8 * s1tt_size), 8, 0, 0, 1, - vectorized_axis: 14, 15, 13, 5, 6, - vectorized_strides: - MemAttr: - tensor_id: 7 - alloc_type: QUEUE - position: VECIN - hardware: UB - buf_ids: - name: - MemQueAttr: - id: 3 - depth: 2 - buf_num: 2 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - outputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 16, 17, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, 1, (s1t_size / (s1tt2_size)), s1tt2_size, 1, D, 1, - strides: (D * bngs1Tb_size * s1t_size), (D * s1t_size), 0, (D * s1tt2_size), D, 0, 1, 0, - vectorized_axis: 17, 5, 13, - vectorized_strides: - MemAttr: - tensor_id: 21 - alloc_type: BUFFER - position: VECOUT - hardware: UB - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: 5 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 28 - axis: 10, 11, 12, 16, 17, 13, 5, 6, - loop_axis: 16 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node30 info: - node name: addRes - inputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 16, 17, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, 1, (s1t_size / (s1tt2_size)), s1tt2_size, 1, D, 1, - strides: (D * bngs1Tb_size * s1t_size), (D * s1t_size), 0, (D * s1tt2_size), D, 0, 1, 0, - vectorized_axis: 17, 5, 13, - vectorized_strides: - MemAttr: - tensor_id: 18 - alloc_type: BUFFER - position: VECIN - hardware: UB - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: 1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 16, 17, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, 1, (s1t_size / (s1tt2_size)), s1tt2_size, 1, D, 1, - strides: (D * bngs1Tb_size * s1t_size), (D * s1t_size), 0, (D * s1tt2_size), D, 0, 1, 0, - vectorized_axis: 17, 5, 13, - vectorized_strides: - MemAttr: - tensor_id: 21 - alloc_type: BUFFER - position: VECOUT - hardware: UB - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: 5 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - outputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 16, 17, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, 1, (s1t_size / (s1tt2_size)), s1tt2_size, 1, D, 1, - strides: (D * bngs1Tb_size * s1t_size), (D * s1t_size), 0, (D * s1tt2_size), D, 0, 1, 0, - vectorized_axis: 17, 5, 13, - vectorized_strides: - MemAttr: - tensor_id: 22 - alloc_type: BUFFER - position: VECOUT - hardware: UB - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: 5 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 29 - axis: 10, 11, 12, 16, 17, 13, 5, 6, - loop_axis: 16 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node31 info: - node name: div - inputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 16, 17, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, 1, (s1t_size / (s1tt2_size)), s1tt2_size, 1, D, 1, - strides: (D * bngs1Tb_size * s1t_size), (D * s1t_size), 0, (D * s1tt2_size), D, 0, 1, 0, - vectorized_axis: 17, 5, 13, - vectorized_strides: - MemAttr: - tensor_id: 22 - alloc_type: BUFFER - position: VECOUT - hardware: UB - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: 5 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 14, 15, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, (S2 / (s2t_size)), (s1t_size / (s1tt_size)), s1tt_size, s2t_size, 1, 8, - strides: (8 * bngs1Tb_size * s1t_size), (8 * s1t_size), 0, (8 * s1tt_size), 8, 0, 0, 1, - vectorized_axis: 14, 15, 13, 5, 6, - vectorized_strides: - MemAttr: - tensor_id: 11 - alloc_type: QUEUE - position: VECOUT - hardware: UB - buf_ids: - name: - MemQueAttr: - id: 2 - depth: 2 - buf_num: 2 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - outputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 16, 17, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, 1, (s1t_size / (s1tt2_size)), s1tt2_size, 1, D, 1, - strides: (D * bngs1Tb_size * s1t_size), (D * s1t_size), 0, (D * s1tt2_size), D, 0, 1, 0, - vectorized_axis: 17, 5, 13, - vectorized_strides: - MemAttr: - tensor_id: 23 - alloc_type: BUFFER - position: VECOUT - hardware: UB - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: 5 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 30 - axis: 10, 11, 12, 16, 17, 13, 5, 6, - loop_axis: 16 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node32 info: - node name: castBmm2Res - inputs: - AscTensor: - DataType: DT_FLOAT - axis: 10, 11, 12, 16, 17, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, 1, (s1t_size / (s1tt2_size)), s1tt2_size, 1, D, 1, - strides: (D * bngs1Tb_size * s1t_size), (D * s1t_size), 0, (D * s1tt2_size), D, 0, 1, 0, - vectorized_axis: 17, 5, 13, - vectorized_strides: - MemAttr: - tensor_id: 23 - alloc_type: BUFFER - position: VECOUT - hardware: UB - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: 5 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - outputs: - AscTensor: - DataType: DT_FLOAT16 - axis: 10, 11, 12, 16, 17, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, 1, (s1t_size / (s1tt2_size)), s1tt2_size, 1, D, 1, - strides: (D * bngs1Tb_size * s1t_size), (D * s1t_size), 0, (D * s1tt2_size), D, 0, 1, 0, - vectorized_axis: 17, 5, 13, - vectorized_strides: - MemAttr: - tensor_id: 24 - alloc_type: BUFFER - position: VECOUT - hardware: UB - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: 5 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: 24 - attr: - AscNode: - sched: - exec_order: 31 - axis: 10, 11, 12, 16, 17, 13, 5, 6, - loop_axis: 16 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node33 info: - node name: store - inputs: - AscTensor: - DataType: DT_FLOAT16 - axis: 10, 11, 12, 16, 17, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, 1, (s1t_size / (s1tt2_size)), s1tt2_size, 1, D, 1, - strides: (D * bngs1Tb_size * s1t_size), (D * s1t_size), 0, (D * s1tt2_size), D, 0, 1, 0, - vectorized_axis: 17, 5, 13, - vectorized_strides: - MemAttr: - tensor_id: 24 - alloc_type: BUFFER - position: VECOUT - hardware: UB - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: 5 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: 24 - outputs: - AscTensor: - DataType: DT_FLOAT16 - axis: 10, 11, 12, 16, 17, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, 1, (s1t_size / (s1tt2_size)), s1tt2_size, 1, D, 1, - strides: (D * bngs1Tb_size * s1t_size), (D * s1t_size), 0, (D * s1tt2_size), D, 0, 1, 0, - vectorized_axis: 17, 5, 13, - vectorized_strides: - MemAttr: - tensor_id: 25 - alloc_type: GLOBAL - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: 0 - merge_scope: 0 - attr: - AscNode: - sched: - exec_order: 32 - axis: 10, 11, 12, 16, 17, 13, 5, 6, - loop_axis: 16 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node34 info: - node name: buf - inputs: - AscTensor: - DataType: DT_FLOAT16 - axis: 10, 11, 12, 16, 17, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, 1, (s1t_size / (s1tt2_size)), s1tt2_size, 1, D, 1, - strides: (D * bngs1Tb_size * s1t_size), (D * s1t_size), 0, (D * s1tt2_size), D, 0, 1, 0, - vectorized_axis: 17, 5, 13, - vectorized_strides: - MemAttr: - tensor_id: 25 - alloc_type: GLOBAL - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: 0 - merge_scope: 0 - outputs: - AscTensor: - DataType: DT_FLOAT16 - axis: 0, 1, 2, 3, 4, 5, 6, - repeats: B, N, G, S1, 1, D, 1, - strides: (D * G * N * S1), (D * G * S1), (D * S1), D, 0, 1, 0, - vectorized_axis: - vectorized_strides: - MemAttr: - tensor_id: -1 - alloc_type: GLOBAL - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 33 - axis: - loop_axis: -1 - - Api: - Api type: INVALID - Compute unit: INVALID - Compute type: INVALID - - node35 info: - node name: buf_ - inputs: - AscTensor: - DataType: DT_FLOAT16 - axis: 10, 11, 12, 16, 17, 13, 5, 6, - repeats: (B * G * N * S1 / (bngs1Tb_size * s1t_size)), bngs1Tb_size, 1, (s1t_size / (s1tt2_size)), s1tt2_size, 1, D, 1, - strides: (D * bngs1Tb_size * s1t_size), (D * s1t_size), 0, (D * s1tt2_size), D, 0, 1, 0, - vectorized_axis: 17, 5, 13, - vectorized_strides: - MemAttr: - tensor_id: 25 - alloc_type: GLOBAL - position: GM - hardware: GM - buf_ids: - name: - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: 0 - merge_scope: 0 - outputs: - AscTensor: - DataType: DT_DUAL_SUB_UINT8 - axis: 0, 1, 2, 3, 4, 5, 6, - repeats: B, N, G, S1, 1, D, 1, - strides: (D * G * N * S1), (D * G * S1), (D * S1), D, 0, 1, 0, - vectorized_axis: - vectorized_strides: - MemAttr: - tensor_id: 1 - alloc_type: L1 - position: GM - hardware: UB - buf_ids: 1, 2, 3, 4, 5, - name: Mem_ - MemQueAttr: - id: -1 - depth: -1 - buf_num: -1 - name: - MemBufAttr: - id: -1 - name: - MemOptAttr: - reuse_id: -1 - ref_tensor: -1 - merge_scope: -1 - attr: - AscNode: - sched: - exec_order: 34 - axis: 1, 2, 3, 4, 5, - loop_axis: 3 - - Api: - Api type: BUFFER - Compute unit: MTE1 - Compute type: REDUCE - -)"; - //EXPECT_EQ(res, ge::DumpAscirGraph::DumpGraph(graph)); - ge::DumpAscirGraph::WriteOutToFile("../ascendc_ir_dump_test/dump_graph.txt", graph); -} diff --git a/tests/ut/ascendc_ir/testcase/ascendc_ir_dump_test/stub_graph.cc b/tests/ut/ascendc_ir/testcase/ascendc_ir_dump_test/stub_graph.cc deleted file mode 100644 index 4022f9bdfaf70d151046818908af879f3ac6f935..0000000000000000000000000000000000000000 --- a/tests/ut/ascendc_ir/testcase/ascendc_ir_dump_test/stub_graph.cc +++ /dev/null @@ -1,1109 +0,0 @@ -#include -#include -#include "stub_graph.h" -#include "graph/symbolizer/symbolic.h" - -namespace ge{ -using namespace ge::ascir_op; -void FaBeforeAutoFuse(ge::AscGraph &graph) { - graph.SetTilingKey(1); - using Expr = ge::Expression; - - auto B = Symbol("B"); - auto N = Symbol("N"); - auto G = Symbol("G"); - auto S1 = Symbol("S1"); - auto S2 = Symbol("S2"); - auto D = Symbol("D"); - auto BL = Symbol(8, "BL"); - auto ONE = Symbol(1, "ONE"); - auto ZERO = Symbol(0, "ZERO"); - - auto b = graph.CreateAxis("b", B); - auto n = graph.CreateAxis("n", N); - auto g = graph.CreateAxis("g", G); - auto s1 = graph.CreateAxis("s1", S1); - auto s2 = graph.CreateAxis("s2", S2); - auto d = graph.CreateAxis("d", D); - auto bl = graph.CreateAxis("l", BL); - - auto bmm1ResAxis = {b.id, n.id, g.id, s1.id, s2.id, d.id, bl.id}; - std::initializer_list bmm1ResRepeat = {B, N, G, S1, S2, ONE, ONE}; - std::initializer_list bmm1ResStride = {N*G*S1*S2, G*S1*S2, S1*S2, S2, ONE, ZERO, ZERO}; - - std::initializer_list vec1ResRepeat = {B, N, G, S1, S2, ONE, ONE}; - std::initializer_list vec1ResStride = {N*G*S1*S2, G*S1*S2, S1*S2, S2, ONE, ZERO, ZERO}; - - auto bmm2ResAxis = {b.id, n.id, g.id, s1.id, s2.id, d.id, bl.id}; - std::initializer_list bmm2ResRepeat = {B, N, G, S1, ONE, D, ONE}; - std::initializer_list bmm2ResStride = {N*G*S1*D, G*S1*D, S1*D, D, ZERO, ONE, ZERO}; - - std::initializer_list vec2ResRepeat = {B, N, G, S1, ONE, D, ONE}; - std::initializer_list vec2ResStride = {N*G*S1*D, G*S1*D, S1*D, D, ZERO, ONE, ZERO}; - - std::initializer_list reduceResRepeat = {ONE, ONE, ONE, S1, ONE, ONE, BL}; - std::initializer_list reduceResStride = {ZERO, ZERO, ZERO, BL, ZERO, ZERO, ONE}; - - int32_t exec_order = 0; - Data query("query", graph); - query.attr.sched.exec_order = exec_order++; - query.attr.sched.axis = bmm1ResAxis; - query.y.dtype = ge::DT_FLOAT16; - *query.y.axis = bmm1ResAxis; - *query.y.repeats = {B, N, G, S1, ONE, D, ONE}; - *query.y.strides = {N*G*S1*D, G*S1*D, S1*D, D, ZERO, ONE, ZERO}; - query.ir_attr.SetIndex(0); - - Data key("key", graph); - key.attr.sched.exec_order = exec_order++; - key.attr.sched.axis = bmm1ResAxis; - key.y.dtype = ge::DT_FLOAT16; - *key.y.axis = bmm1ResAxis; - *key.y.repeats = {B, N, G, ONE, S2, D, ONE}; - *key.y.strides = {N*S1*D, S2*D, S2*D, ZERO, D, ONE, ZERO}; - key.ir_attr.SetIndex(1); - - MatMul bmm1("bmm1"); - bmm1.x1 = query.y; - bmm1.x2 = key.y; - bmm1.attr.sched.exec_order = exec_order++; - bmm1.attr.sched.axis = bmm1ResAxis; - bmm1.y.dtype = ge::DT_FLOAT; - *bmm1.y.axis = bmm1ResAxis; - *bmm1.y.repeats = bmm1ResRepeat; - *bmm1.y.strides = bmm1ResStride; - - Load load1("load1"); - load1.x = bmm1.y; - load1.attr.sched.exec_order = exec_order++; - load1.attr.sched.axis = bmm1ResAxis; - load1.y.dtype = ge::DT_FLOAT; - *load1.y.axis = bmm1ResAxis; - *load1.y.repeats = vec1ResRepeat; - *load1.y.strides = vec1ResStride; - - Data pse("pse", graph); - pse.attr.sched.exec_order = exec_order++; - pse.y.dtype = ge::DT_FLOAT16; - *pse.y.axis = bmm1ResAxis; - *pse.y.repeats = vec1ResRepeat; - *pse.y.strides = vec1ResStride; - - Load loadPse("loadPse"); - loadPse.x = pse.y; - loadPse.attr.sched.exec_order = exec_order++; - loadPse.attr.sched.axis = bmm1ResAxis; - loadPse.y.dtype = ge::DT_FLOAT16; - *loadPse.y.axis = bmm1ResAxis; - *loadPse.y.repeats = vec1ResRepeat; - *loadPse.y.strides = vec1ResStride; - - Cast castPse("castPse"); - castPse.x = loadPse.y; - castPse.attr.sched.exec_order = exec_order++; - castPse.attr.sched.axis = bmm1ResAxis; - castPse.y.dtype = ge::DT_FLOAT; - *castPse.y.axis = bmm1ResAxis; - *castPse.y.repeats = vec1ResRepeat; - *castPse.y.strides = vec1ResStride; - - ge::ascir_op::Add add1("add1"); - add1.x1 = load1.y; - add1.x2 = castPse.y; - add1.attr.sched.exec_order = exec_order++; - add1.attr.sched.axis = bmm1ResAxis; - add1.y.dtype = ge::DT_FLOAT; - *add1.y.axis = bmm1ResAxis; - *add1.y.repeats = vec1ResRepeat; - *add1.y.strides = vec1ResStride; - - Data scaleValue("scaleValue", graph); - scaleValue.attr.sched.exec_order = exec_order++; - scaleValue.y.dtype = ge::DT_FLOAT; - - ge::ascir_op::Muls mul1("mul1"); - mul1.x1 = add1.y; - mul1.x2 = scaleValue.y; - mul1.attr.sched.exec_order = exec_order++; - mul1.attr.sched.axis = bmm1ResAxis; - mul1.y.dtype = ge::DT_FLOAT; - *mul1.y.axis = bmm1ResAxis; - *mul1.y.repeats = vec1ResRepeat; - *mul1.y.strides = vec1ResStride; - - Data attenMask("attenMask", graph); - attenMask.attr.sched.exec_order = exec_order++; - attenMask.y.dtype = ge::DT_UINT8; - *attenMask.y.axis = bmm1ResAxis; - *attenMask.y.repeats = {B, ONE, ONE, S1, S2, ONE, ONE}; - *attenMask.y.strides = {S1*S2, S1*S2, S1*S2, S2, ONE, ZERO, ZERO}; - - Load loadAttenMask("loadAttenMask"); - loadAttenMask.x = attenMask.y; - loadAttenMask.attr.sched.exec_order = exec_order++; - loadAttenMask.attr.sched.axis = bmm1ResAxis; - loadAttenMask.y.dtype = ge::DT_UINT8; - *loadAttenMask.y.axis = bmm1ResAxis; - *loadAttenMask.y.repeats = {B, ONE, ONE, S1, S2, ONE, ONE}; - *loadAttenMask.y.strides = {S1*S2, S1*S2, S1*S2, S2, ONE, ZERO, ZERO}; - - Select select("select"); - select.x1 = mul1.y; - select.x2 = loadAttenMask.y; - select.attr.sched.exec_order = exec_order++; - select.attr.sched.axis = bmm1ResAxis; - select.y.dtype = ge::DT_FLOAT; - *select.y.axis = bmm1ResAxis; - *select.y.repeats = vec1ResRepeat; - *select.y.strides = vec1ResStride; - - TbufData softmaxExp("softmaxExp", graph); - softmaxExp.attr.sched.exec_order = exec_order++; - softmaxExp.attr.sched.axis = bmm1ResAxis; - softmaxExp.y.dtype = ge::DT_FLOAT; - *softmaxExp.y.axis = bmm1ResAxis; - *softmaxExp.y.repeats = reduceResRepeat; - *softmaxExp.y.strides = reduceResStride; - - TbufData softmaxApiTmpBuf("softmaxApiTmpBuf", graph); - softmaxApiTmpBuf.attr.sched.exec_order = exec_order++; - softmaxApiTmpBuf.attr.sched.axis = bmm1ResAxis; - softmaxApiTmpBuf.y.dtype = ge::DT_FLOAT; - *softmaxApiTmpBuf.y.axis = bmm1ResAxis; - *softmaxApiTmpBuf.y.repeats = {ONE, ONE, ONE, S1, S2, ONE, ONE}; - *softmaxApiTmpBuf.y.strides = {ZERO, ZERO, ZERO, S2, ONE, ZERO, ZERO}; - - FlashSoftmax flashSoftmax("flashSoftmax"); - flashSoftmax.x1 = select.y; - flashSoftmax.x2 = softmaxExp.y; - flashSoftmax.x3 = softmaxApiTmpBuf.y; - flashSoftmax.attr.sched.exec_order = exec_order++; - flashSoftmax.attr.sched.axis = bmm1ResAxis; - flashSoftmax.y1.dtype = ge::DT_FLOAT; - *flashSoftmax.y1.axis = bmm1ResAxis; - *flashSoftmax.y1.repeats = vec1ResRepeat; - *flashSoftmax.y1.strides = vec1ResStride; - - flashSoftmax.y2.dtype = ge::DT_FLOAT; - *flashSoftmax.y2.axis = bmm1ResAxis; - *flashSoftmax.y2.repeats = {B, N, G, S1, S2, ONE, BL}; - *flashSoftmax.y2.strides = {N*G*S1*BL, G*S1*BL, S1*BL, BL, ZERO, ZERO, ONE}; - - flashSoftmax.y3.dtype = ge::DT_FLOAT; - *flashSoftmax.y3.axis = bmm1ResAxis; - *flashSoftmax.y3.repeats = {B, N, G, S1, S2, ONE, BL}; - *flashSoftmax.y3.strides = {N*G*S1*BL, G*S1*BL, S1*BL, BL, ZERO, ZERO, ONE}; - - Store storeSoftmaxMax("storeSoftmaxMax"); - storeSoftmaxMax.x = flashSoftmax.y3; - storeSoftmaxMax.attr.sched.exec_order = exec_order++; - storeSoftmaxMax.attr.sched.axis = bmm1ResAxis; - storeSoftmaxMax.y.dtype = ge::DT_FLOAT; - *storeSoftmaxMax.y.axis = bmm1ResAxis; - *storeSoftmaxMax.y.repeats = {B, N, G, S1, S2, ONE, BL}; - *storeSoftmaxMax.y.strides = {N*G*S1*BL, G*S1*BL, S1*BL, BL, ZERO, ZERO, ONE}; - - Output softmaxMax("softmaxMax"); - softmaxMax.x = storeSoftmaxMax.y; - softmaxMax.attr.sched.exec_order = exec_order++; - - Data dropMask("dropMask", graph); - dropMask.attr.sched.exec_order = exec_order++; - dropMask.y.dtype = ge::DT_UINT8; - *dropMask.y.axis = bmm1ResAxis; - *dropMask.y.repeats = vec1ResRepeat; - *dropMask.y.strides = vec1ResStride; - - Load loadDropMask("loadDropMask"); - loadDropMask.x = dropMask.y; - loadDropMask.attr.sched.exec_order = exec_order++; - loadDropMask.attr.sched.axis = bmm1ResAxis; - loadDropMask.y.dtype = ge::DT_UINT8; - *loadDropMask.y.axis = bmm1ResAxis; - *loadDropMask.y.repeats = vec1ResRepeat; - *loadDropMask.y.strides = vec1ResStride; - - Dropout dropout("dropout"); - dropout.x1 = flashSoftmax.y1; - dropout.x2 = loadDropMask.y; - dropout.attr.sched.exec_order = exec_order++; - dropout.attr.sched.axis = bmm1ResAxis; - dropout.y.dtype = ge::DT_FLOAT; - *dropout.y.axis = bmm1ResAxis; - *dropout.y.repeats = vec1ResRepeat; - *dropout.y.strides = vec1ResStride; - - Cast castVec1Res("castVec1Res"); - castVec1Res.x = dropout.y; - castVec1Res.attr.sched.exec_order = exec_order++; - castVec1Res.attr.sched.axis = bmm1ResAxis; - castVec1Res.y.dtype = ge::DT_FLOAT16; - *castVec1Res.y.axis = bmm1ResAxis; - *castVec1Res.y.repeats = vec1ResRepeat; - *castVec1Res.y.strides = vec1ResStride; - - Store storeVec1Res("storeVec1Res"); - storeVec1Res.x = castVec1Res.y; - storeVec1Res.attr.sched.exec_order = exec_order++; - storeVec1Res.attr.sched.axis = bmm1ResAxis; - storeVec1Res.y.dtype = ge::DT_FLOAT16; - *storeVec1Res.y.axis = bmm1ResAxis; - *storeVec1Res.y.repeats = vec1ResRepeat; - *storeVec1Res.y.strides = vec1ResStride; - - Data value("value", graph); - value.attr.sched.exec_order = exec_order++; - value.attr.sched.axis = bmm2ResAxis; - value.y.dtype = ge::DT_FLOAT16; - *value.y.axis = bmm2ResAxis; - *value.y.repeats = {B, N, G, ONE, S2, D, ONE}; - *value.y.strides = {N*S2*D, S2*D, S2*D, ZERO, D, ONE, ZERO}; - - MatMul bmm2("bmm2"); - bmm2.x1 = storeVec1Res.y; - bmm2.x2 = value.y; - bmm2.attr.sched.exec_order = exec_order++; - bmm2.attr.sched.axis = bmm2ResAxis; - bmm2.y.dtype = ge::DT_FLOAT; - *bmm2.y.axis = bmm2ResAxis; - *bmm2.y.repeats = {B, N, G, S1, ONE, D, ONE}; - *bmm2.y.strides = {N*G*S1*D, G*S1*D, S1*D, D, ZERO, ONE, ZERO}; - - Load load2("load2"); - load2.x = bmm2.y; - load2.attr.sched.exec_order = exec_order++; - load2.attr.sched.axis = bmm2ResAxis; - load2.y.dtype = ge::DT_FLOAT; - *load2.y.axis = bmm2ResAxis; - *load2.y.repeats = vec2ResRepeat; - *load2.y.strides = vec2ResStride; - - Workspace addResOut("addResOut", graph); - addResOut.attr.sched.exec_order = exec_order++; - addResOut.attr.sched.axis = bmm2ResAxis; - addResOut.y.dtype = ge::DT_FLOAT; - *addResOut.y.axis = bmm2ResAxis; - *addResOut.y.repeats = vec2ResRepeat; - *addResOut.y.strides = vec2ResStride; - - Load loadAddResOut("loadAddResOut"); - loadAddResOut.x = addResOut.y; - loadAddResOut.attr.sched.exec_order = exec_order++; - loadAddResOut.attr.sched.axis = bmm2ResAxis; - loadAddResOut.y.dtype = ge::DT_FLOAT; - *loadAddResOut.y.axis = bmm2ResAxis; - *loadAddResOut.y.repeats = vec2ResRepeat; - *loadAddResOut.y.strides = vec2ResStride; - - ge::ascir_op::Mul mulRes("mulRes"); - mulRes.x1 = loadAddResOut.y; - mulRes.x2 = softmaxExp.y; - mulRes.attr.sched.exec_order = exec_order++; - mulRes.attr.sched.axis = bmm2ResAxis; - mulRes.y.dtype = ge::DT_FLOAT; - *mulRes.y.axis = bmm2ResAxis; - *mulRes.y.repeats = vec2ResRepeat; - *mulRes.y.strides = vec2ResStride; - - ge::ascir_op::Add addRes("addRes"); - addRes.x1 = load2.y; - addRes.x2 = mulRes.y; - addRes.attr.sched.exec_order = exec_order++; - addRes.attr.sched.axis = bmm2ResAxis; - addRes.y.dtype = ge::DT_FLOAT; - *addRes.y.axis = bmm2ResAxis; - *addRes.y.repeats = vec2ResRepeat; - *addRes.y.strides = vec2ResStride; - - ge::ascir_op::Div div("div"); - div.x1 = addRes.y; - div.x2 = flashSoftmax.y3; - div.attr.sched.exec_order = exec_order++; - div.attr.sched.axis = bmm2ResAxis; - div.y.dtype = ge::DT_FLOAT; - *div.y.axis = bmm2ResAxis; - *div.y.repeats = vec2ResRepeat; - *div.y.strides = vec2ResStride; - - Cast castBmm2Res("castBmm2Res"); - castBmm2Res.x = div.y; - castBmm2Res.attr.sched.exec_order = exec_order++; - castBmm2Res.attr.sched.axis = bmm2ResAxis; - castBmm2Res.y.dtype = ge::DT_FLOAT16; - *castBmm2Res.y.axis = bmm2ResAxis; - *castBmm2Res.y.repeats = vec2ResRepeat; - *castBmm2Res.y.strides = vec2ResStride; - - Store store("store"); - store.x = castBmm2Res.y; - store.attr.sched.exec_order = exec_order++; - store.attr.sched.axis = bmm2ResAxis; - store.y.dtype = ge::DT_FLOAT16; - *store.y.axis = bmm2ResAxis; - *store.y.repeats = vec2ResRepeat; - *store.y.strides = vec2ResStride; - - Output buf("buf"); - buf.x = store.y; - buf.attr.sched.exec_order = exec_order++; - buf.y.dtype = ge::DT_FLOAT16; - *buf.y.axis = bmm2ResAxis; - *buf.y.repeats = vec2ResRepeat; - *buf.y.strides = vec2ResStride; - - Output buf_("buf_"); - buf_.x = store.y; - buf_.attr.sched.exec_order = exec_order++; - buf_.attr.sched.axis = {1, 2, 3, 4, 5}; - buf_.attr.sched.loop_axis = {3}; - buf_.attr.api.type = ge::ApiType::kAPITypeBuffer; - buf_.attr.api.unit = ge::ComputeUnit::kUnitMTE1; - buf_.attr.api.compute_type = ge::ComputeType::kComputeReduce; - buf_.y.dtype = ge::DT_DUAL_SUB_UINT8; - buf_.y.format = ge::FORMAT_C1HWC0; - *buf_.y.axis = bmm2ResAxis; - *buf_.y.repeats = vec2ResRepeat; - *buf_.y.strides = vec2ResStride; -} - -void FaAfterScheduler(ge::AscGraph &graph) { - auto b = graph.GetAllAxis()[0]->id; - auto n = graph.GetAllAxis()[1]->id; - auto g = graph.GetAllAxis()[2]->id; - auto s1 = graph.GetAllAxis()[3]->id; - auto s2 = graph.GetAllAxis()[4]->id; - auto d = graph.GetAllAxis()[5]->id; - auto bl = graph.GetAllAxis()[6]->id; - - std::tuple split = graph.TileSplit(s1); - auto s1T = *(std::get<0>(split)); - auto s1t = *(std::get<1>(split)); - graph.FindAxis(s1t.id)->align = 128; - auto mcAxis = *graph.MergeAxis({b, n, g, s1T.id}); - split = graph.BlockSplit(mcAxis.id); - auto mcAxisB = *(std::get<0>(split)); - auto mcAxisb = *(std::get<1>(split)); - - split = graph.TileSplit(s2); - auto s2T = *(std::get<0>(split)); - auto s2t = *(std::get<1>(split)); - graph.FindAxis(s2t.id)->align = 256; - - split = graph.TileSplit(s1t.id); - auto s1tT = *(std::get<0>(split)); - auto s1tt = *(std::get<1>(split)); - vector bmm1VectorizedAxis{s1t.id, s2t.id, d}; - vector vec1VectorizedAxis{s1tt.id, s2t.id, d}; - vector bmm2VectorizedAxis{s1t.id, d, s2t.id}; - - auto X = Symbol("X"); - auto Y = Symbol("Y"); - auto Z = Symbol("Z"); - - auto query = graph.FindNode("query"); - graph.ApplySplit(query, s1T.id, s1t.id); - graph.ApplyMerge(query, mcAxis.id); - graph.ApplySplit(query, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(query, s2T.id, s2t.id); - graph.ApplyReorder(query, {mcAxisB.id, mcAxisb.id, s2T.id, s1t.id, s2t.id, d, bl}); - query->attr.sched.loop_axis = s2T.id; - query->outputs[0].attr.vectorized_axis = {s1t.id, s2t.id, d}; - query->outputs[0].attr.vectorized_strides = {X, Y, Z}; - - auto key = graph.FindNode("key"); - graph.ApplySplit(key, s1T.id, s1t.id); - graph.ApplyMerge(key, mcAxis.id); - graph.ApplySplit(key, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(key, s2T.id, s2t.id); - graph.ApplyReorder(key, {mcAxisB.id, mcAxisb.id, s2T.id, s1t.id, d, s2t.id, bl}); - key->attr.sched.loop_axis = s2T.id; - key->outputs[0].attr.vectorized_axis = {s1t.id, d, s2t.id}; - - auto bmmReorderedAxis = {mcAxisB.id, mcAxisb.id, s2T.id, s1t.id, s2t.id, d, bl}; - auto vecReorderedAxis = {mcAxisB.id, mcAxisb.id, s2T.id, s1tT.id, s1tt.id, s2t.id, d, bl}; - - auto bmm1 = graph.FindNode("bmm1"); - graph.ApplySplit(bmm1, s1T.id, s1t.id); - graph.ApplyMerge(bmm1, mcAxis.id); - graph.ApplySplit(bmm1, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(bmm1, s2T.id, s2t.id); - graph.ApplyReorder(bmm1, bmmReorderedAxis); - bmm1->attr.sched.loop_axis = s2T.id; - bmm1->outputs[0].attr.vectorized_axis = bmm1VectorizedAxis; - - auto load1 = graph.FindNode("load1"); - graph.ApplySplit(load1, s1T.id, s1t.id); - graph.ApplyMerge(load1, mcAxis.id); - graph.ApplySplit(load1, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(load1, s2T.id, s2t.id); - graph.ApplySplit(load1, s1tT.id, s1tt.id); - graph.ApplyReorder(load1, vecReorderedAxis); - load1->attr.sched.loop_axis = s1tT.id; - load1->outputs[0].attr.vectorized_axis = vec1VectorizedAxis; - - auto loadPse = graph.FindNode("loadPse"); - graph.ApplySplit(loadPse, s1T.id, s1t.id); - graph.ApplyMerge(loadPse, mcAxis.id); - graph.ApplySplit(loadPse, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(loadPse, s2T.id, s2t.id); - graph.ApplySplit(loadPse, s1tT.id, s1tt.id); - graph.ApplyReorder(loadPse, vecReorderedAxis); - loadPse->attr.sched.loop_axis = s1tT.id; - loadPse->outputs[0].attr.vectorized_axis = vec1VectorizedAxis; - - auto castPse = graph.FindNode("castPse"); - graph.ApplySplit(castPse, s1T.id, s1t.id); - graph.ApplyMerge(castPse, mcAxis.id); - graph.ApplySplit(castPse, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(castPse, s2T.id, s2t.id); - graph.ApplySplit(castPse, s1tT.id, s1tt.id); - graph.ApplyReorder(castPse, vecReorderedAxis); - castPse->attr.sched.loop_axis = s1tT.id; - castPse->outputs[0].attr.vectorized_axis = vec1VectorizedAxis; - - auto add1 = graph.FindNode("add1"); - graph.ApplySplit(add1, s1T.id, s1t.id); - graph.ApplyMerge(add1, mcAxis.id); - graph.ApplySplit(add1, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(add1, s2T.id, s2t.id); - graph.ApplySplit(add1, s1tT.id, s1tt.id); - graph.ApplyReorder(add1, vecReorderedAxis); - add1->attr.sched.loop_axis = s1tT.id; - add1->outputs[0].attr.vectorized_axis = vec1VectorizedAxis; - - auto mul1 = graph.FindNode("mul1"); - graph.ApplySplit(mul1, s1T.id, s1t.id); - graph.ApplyMerge(mul1, mcAxis.id); - graph.ApplySplit(mul1, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(mul1, s2T.id, s2t.id); - graph.ApplySplit(mul1, s1tT.id, s1tt.id); - graph.ApplyReorder(mul1, vecReorderedAxis); - mul1->attr.sched.loop_axis = s1tT.id; - mul1->outputs[0].attr.vectorized_axis = vec1VectorizedAxis; - - auto loadAttenMask = graph.FindNode("loadAttenMask"); - graph.ApplySplit(loadAttenMask, s1T.id, s1t.id); - graph.ApplyMerge(loadAttenMask, mcAxis.id); - graph.ApplySplit(loadAttenMask, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(loadAttenMask, s2T.id, s2t.id); - graph.ApplySplit(loadAttenMask, s1tT.id, s1tt.id); - graph.ApplyReorder(loadAttenMask, vecReorderedAxis); - loadAttenMask->attr.sched.loop_axis = s1tT.id; - loadAttenMask->outputs[0].attr.vectorized_axis = vec1VectorizedAxis; - - auto select = graph.FindNode("select"); - graph.ApplySplit(select, s1T.id, s1t.id); - graph.ApplyMerge(select, mcAxis.id); - graph.ApplySplit(select, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(select, s2T.id, s2t.id); - graph.ApplySplit(select, s1tT.id, s1tt.id); - graph.ApplyReorder(select, vecReorderedAxis); - select->attr.sched.loop_axis = s1tT.id; - select->outputs[0].attr.vectorized_axis = vec1VectorizedAxis; - - auto loadDropMask = graph.FindNode("loadDropMask"); - graph.ApplySplit(loadDropMask, s1T.id, s1t.id); - graph.ApplyMerge(loadDropMask, mcAxis.id); - graph.ApplySplit(loadDropMask, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(loadDropMask, s2T.id, s2t.id); - graph.ApplySplit(loadDropMask, s1tT.id, s1tt.id); - graph.ApplyReorder(loadDropMask, vecReorderedAxis); - loadDropMask->attr.sched.loop_axis = s1tT.id; - loadDropMask->outputs[0].attr.vectorized_axis = vec1VectorizedAxis; - - auto dropout = graph.FindNode("dropout"); - graph.ApplySplit(dropout, s1T.id, s1t.id); - graph.ApplyMerge(dropout, mcAxis.id); - graph.ApplySplit(dropout, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(dropout, s2T.id, s2t.id); - graph.ApplySplit(dropout, s1tT.id, s1tt.id); - graph.ApplyReorder(dropout, vecReorderedAxis); - dropout->attr.sched.loop_axis = s1tT.id; - dropout->outputs[0].attr.vectorized_axis = vec1VectorizedAxis; - - auto castVec1Res = graph.FindNode("castVec1Res"); - graph.ApplySplit(castVec1Res, s1T.id, s1t.id); - graph.ApplyMerge(castVec1Res, mcAxis.id); - graph.ApplySplit(castVec1Res, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(castVec1Res, s2T.id, s2t.id); - graph.ApplySplit(castVec1Res, s1tT.id, s1tt.id); - graph.ApplyReorder(castVec1Res, vecReorderedAxis); - castVec1Res->attr.sched.loop_axis = s1tT.id; - castVec1Res->outputs[0].attr.vectorized_axis = vec1VectorizedAxis; - - auto storeVec1Res = graph.FindNode("storeVec1Res"); - graph.ApplySplit(storeVec1Res, s1T.id, s1t.id); - graph.ApplyMerge(storeVec1Res, mcAxis.id); - graph.ApplySplit(storeVec1Res, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(storeVec1Res, s2T.id, s2t.id); - graph.ApplySplit(storeVec1Res, s1tT.id, s1tt.id); - graph.ApplyReorder(storeVec1Res, vecReorderedAxis); - storeVec1Res->attr.sched.loop_axis = s1tT.id; - storeVec1Res->outputs[0].attr.vectorized_axis = {s1tT.id, s1tt.id, s2t.id, d}; - - auto softmaxExp = graph.FindNode("softmaxExp"); - graph.ApplySplit(softmaxExp, s1T.id, s1t.id); - graph.ApplyMerge(softmaxExp, mcAxis.id); - graph.ApplySplit(softmaxExp, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(softmaxExp, s2T.id, s2t.id); - graph.ApplySplit(softmaxExp, s1tT.id, s1tt.id); - graph.ApplyReorder(softmaxExp, vecReorderedAxis); - softmaxExp->attr.sched.loop_axis = s1tT.id; - softmaxExp->outputs[0].attr.vectorized_axis = {s1tT.id, s1tt.id, s2t.id, d, bl}; - - auto softmaxApiTmpBuf = graph.FindNode("softmaxApiTmpBuf"); - graph.ApplySplit(softmaxApiTmpBuf, s1T.id, s1t.id); - graph.ApplyMerge(softmaxApiTmpBuf, mcAxis.id); - graph.ApplySplit(softmaxApiTmpBuf, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(softmaxApiTmpBuf, s2T.id, s2t.id); - graph.ApplySplit(softmaxApiTmpBuf, s1tT.id, s1tt.id); - graph.ApplyReorder(softmaxApiTmpBuf, vecReorderedAxis); - softmaxApiTmpBuf->attr.sched.loop_axis = s1tT.id; - softmaxApiTmpBuf->outputs[0].attr.vectorized_axis = vec1VectorizedAxis; - - auto flashSoftmax = graph.FindNode("flashSoftmax"); - graph.ApplySplit(flashSoftmax, s1T.id, s1t.id); - graph.ApplyMerge(flashSoftmax, mcAxis.id); - graph.ApplySplit(flashSoftmax, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(flashSoftmax, s2T.id, s2t.id); - graph.ApplySplit(flashSoftmax, s1tT.id, s1tt.id); - graph.ApplyReorder(flashSoftmax, vecReorderedAxis); - flashSoftmax->attr.sched.loop_axis = s1tT.id; - flashSoftmax->outputs[0].attr.vectorized_axis = vec1VectorizedAxis; - flashSoftmax->outputs[1].attr.vectorized_axis = {s1tT.id, s1tt.id, s2t.id, d, bl}; - flashSoftmax->outputs[2].attr.vectorized_axis = {s1tT.id, s1tt.id, s2t.id, d, bl}; - - auto storeSoftmaxMax = graph.FindNode("storeSoftmaxMax"); - graph.ApplySplit(storeSoftmaxMax, s1T.id, s1t.id); - graph.ApplyMerge(storeSoftmaxMax, mcAxis.id); - graph.ApplySplit(storeSoftmaxMax, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(storeSoftmaxMax, s2T.id, s2t.id); - graph.ApplySplit(storeSoftmaxMax, s1tT.id, s1tt.id); - graph.ApplyReorder(storeSoftmaxMax, vecReorderedAxis); - storeSoftmaxMax->attr.sched.loop_axis = s1tT.id; - storeSoftmaxMax->outputs[0].attr.vectorized_axis = {s1tT.id, s1tt.id, s2t.id, d, bl}; - - auto value = graph.FindNode("value"); - graph.ApplySplit(value, s1T.id, s1t.id); - graph.ApplyMerge(value, mcAxis.id); - graph.ApplySplit(value, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(value, s2T.id, s2t.id); - graph.ApplyReorder(value, bmmReorderedAxis); - value->attr.sched.loop_axis = s2T.id; - value->outputs[0].attr.vectorized_axis = {s1t.id, s2t.id, d}; - - auto bmm2 = graph.FindNode("bmm2"); - graph.ApplySplit(bmm2, s1T.id, s1t.id); - graph.ApplyMerge(bmm2, mcAxis.id); - graph.ApplySplit(bmm2, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(bmm2, s2T.id, s2t.id); - graph.ApplyReorder(bmm2, bmmReorderedAxis); - bmm2->attr.sched.loop_axis = s2T.id; - bmm2->outputs[0].attr.vectorized_axis = {s1t.id, s2t.id, d}; - - split = graph.TileSplit(s1t.id, "s1tT2", "s1tt2"); - auto s1Vec2tT = *(std::get<0>(split)); - auto s1Vec2tt = *(std::get<1>(split)); - vector vec2VectorizedAxis{s1Vec2tt.id, d, s2t.id}; - auto vec2ReorderedAxis = {mcAxisB.id, mcAxisb.id, s2T.id, s1Vec2tT.id, s1Vec2tt.id, s2t.id, d, bl}; - - auto load2 = graph.FindNode("load2"); - graph.ApplySplit(load2, s1T.id, s1t.id); - graph.ApplyMerge(load2, mcAxis.id); - graph.ApplySplit(load2, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(load2, s2T.id, s2t.id); - graph.ApplySplit(load2, s1Vec2tT.id, s1Vec2tt.id); - graph.ApplyReorder(load2, vec2ReorderedAxis); - load2->attr.sched.loop_axis = s1Vec2tT.id; - load2->outputs[0].attr.vectorized_axis = vec2VectorizedAxis; - - auto addResOut = graph.FindNode("addResOut"); - graph.ApplySplit(addResOut, s1T.id, s1t.id); - graph.ApplyMerge(addResOut, mcAxis.id); - graph.ApplySplit(addResOut, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(addResOut, s2T.id, s2t.id); - graph.ApplySplit(addResOut, s1Vec2tT.id, s1Vec2tt.id); - graph.ApplyReorder(addResOut, vec2ReorderedAxis); - addResOut->attr.sched.loop_axis = s1Vec2tT.id; - addResOut->outputs[0].attr.vectorized_axis = vec2VectorizedAxis; - - auto loadAddResOut = graph.FindNode("loadAddResOut"); - graph.ApplySplit(loadAddResOut, s1T.id, s1t.id); - graph.ApplyMerge(loadAddResOut, mcAxis.id); - graph.ApplySplit(loadAddResOut, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(loadAddResOut, s2T.id, s2t.id); - graph.ApplySplit(loadAddResOut, s1Vec2tT.id, s1Vec2tt.id); - graph.ApplyReorder(loadAddResOut, vec2ReorderedAxis); - loadAddResOut->attr.sched.loop_axis = s1Vec2tT.id; - loadAddResOut->outputs[0].attr.vectorized_axis = vec2VectorizedAxis; - - auto mulRes = graph.FindNode("mulRes"); - graph.ApplySplit(mulRes, s1T.id, s1t.id); - graph.ApplyMerge(mulRes, mcAxis.id); - graph.ApplySplit(mulRes, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(mulRes, s2T.id, s2t.id); - graph.ApplySplit(mulRes, s1Vec2tT.id, s1Vec2tt.id); - graph.ApplyReorder(mulRes, vec2ReorderedAxis); - mulRes->attr.sched.loop_axis = s1Vec2tT.id; - mulRes->outputs[0].attr.vectorized_axis = vec2VectorizedAxis; - - auto addRes = graph.FindNode("addRes"); - graph.ApplySplit(addRes, s1T.id, s1t.id); - graph.ApplyMerge(addRes, mcAxis.id); - graph.ApplySplit(addRes, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(addRes, s2T.id, s2t.id); - graph.ApplySplit(addRes, s1Vec2tT.id, s1Vec2tt.id); - graph.ApplyReorder(addRes, vec2ReorderedAxis); - addRes->attr.sched.loop_axis = s1Vec2tT.id; - addRes->outputs[0].attr.vectorized_axis = vec2VectorizedAxis; - - auto div = graph.FindNode("div"); - graph.ApplySplit(div, s1T.id, s1t.id); - graph.ApplyMerge(div, mcAxis.id); - graph.ApplySplit(div, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(div, s2T.id, s2t.id); - graph.ApplySplit(div, s1Vec2tT.id, s1Vec2tt.id); - graph.ApplyReorder(div, vec2ReorderedAxis); - div->attr.sched.loop_axis = s1Vec2tT.id; - div->outputs[0].attr.vectorized_axis = vec2VectorizedAxis; - - auto castBmm2Res = graph.FindNode("castBmm2Res"); - graph.ApplySplit(castBmm2Res, s1T.id, s1t.id); - graph.ApplyMerge(castBmm2Res, mcAxis.id); - graph.ApplySplit(castBmm2Res, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(castBmm2Res, s2T.id, s2t.id); - graph.ApplySplit(castBmm2Res, s1Vec2tT.id, s1Vec2tt.id); - graph.ApplyReorder(castBmm2Res, vec2ReorderedAxis); - castBmm2Res->attr.sched.loop_axis = s1Vec2tT.id; - castBmm2Res->outputs[0].attr.vectorized_axis = vec2VectorizedAxis; - - auto store = graph.FindNode("store"); - graph.ApplySplit(store, s1T.id, s1t.id); - graph.ApplyMerge(store, mcAxis.id); - graph.ApplySplit(store, mcAxisB.id, mcAxisb.id); - graph.ApplySplit(store, s2T.id, s2t.id); - graph.ApplySplit(store, s1Vec2tT.id, s1Vec2tt.id); - graph.ApplyReorder(store, vec2ReorderedAxis); - store->attr.sched.loop_axis = s1Vec2tT.id; - store->outputs[0].attr.vectorized_axis = vec2VectorizedAxis; -} - -void FaAfterQueBufAlloc(ge::AscGraph &graph) { - int32_t tensorID = 0; - int32_t queID = 0; - int32_t bufID = 0; - int32_t mmRes1Que = queID++; - int32_t stage1Que = queID++; - int32_t pseTBuf = bufID++; - int32_t commonTBuf = bufID++; - int32_t maskTbufPing = bufID++; - int32_t maskTbufPong = bufID++; - int32_t softmaxMaxBuf = bufID++; - int32_t softmaxSumQueue = queID++; - int32_t softmaxExpQueue = queID++; - int32_t stage2Buf = bufID++; - int32_t stage1ResQueue = queID++; - int32_t mm2ResQueue = queID++; - int32_t vec2ResQueue = queID++; - - auto query = graph.FindNode("query"); - query->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareGM; - query->outputs[0].attr.mem.position = ge::Position::kPositionGM; - - auto key = graph.FindNode("key"); - key->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareGM; - key->outputs[0].attr.mem.position = ge::Position::kPositionGM; - - auto value = graph.FindNode("value"); - value->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareGM; - value->outputs[0].attr.mem.position = ge::Position::kPositionGM; - - auto bmm1 = graph.FindNode("bmm1"); - bmm1->outputs[0].attr.mem.tensor_id = tensorID++; - bmm1->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeQueue; - bmm1->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareGM; - bmm1->outputs[0].attr.mem.position = ge::Position::kPositionGM; - bmm1->outputs[0].attr.buf.id = ge::kIdNone; - bmm1->outputs[0].attr.que.id = mmRes1Que; - bmm1->outputs[0].attr.que.depth = 2; - bmm1->outputs[0].attr.que.buf_num = 2; - bmm1->outputs[0].attr.opt.ref_tensor = ge::kIdNone; - bmm1->outputs[0].attr.opt.merge_scope = ge::kIdNone; - - auto load1 = graph.FindNode("load1"); - load1->outputs[0].attr.mem.tensor_id = tensorID++; - load1->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeQueue; - load1->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - load1->outputs[0].attr.mem.position = ge::Position::kPositionVecIn; - load1->outputs[0].attr.buf.id = ge::kIdNone; - load1->outputs[0].attr.que.id = stage1Que; - load1->outputs[0].attr.que.depth = 2; - load1->outputs[0].attr.que.buf_num = 2; - load1->outputs[0].attr.opt.ref_tensor = ge::kIdNone; - load1->outputs[0].attr.opt.merge_scope = ge::kIdNone; - - auto loadPse = graph.FindNode("loadPse"); - loadPse->outputs[0].attr.mem.tensor_id = tensorID++; - loadPse->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeBuffer; - loadPse->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - loadPse->outputs[0].attr.mem.position = ge::Position::kPositionVecIn; - loadPse->outputs[0].attr.buf.id = pseTBuf; - loadPse->outputs[0].attr.que.id = ge::kIdNone; - loadPse->outputs[0].attr.que.depth = ge::kIdNone; - loadPse->outputs[0].attr.que.buf_num = ge::kIdNone; - loadPse->outputs[0].attr.opt.ref_tensor = ge::kIdNone; - loadPse->outputs[0].attr.opt.merge_scope = ge::kIdNone; - - auto castPse = graph.FindNode("castPse"); - castPse->outputs[0].attr.mem.tensor_id = tensorID++; - castPse->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeBuffer; - castPse->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - castPse->outputs[0].attr.mem.position = ge::Position::kPositionVecOut; - castPse->outputs[0].attr.buf.id = commonTBuf; - castPse->outputs[0].attr.que.id = ge::kIdNone; - castPse->outputs[0].attr.que.depth = ge::kIdNone; - castPse->outputs[0].attr.que.buf_num = ge::kIdNone; - castPse->outputs[0].attr.opt.ref_tensor = ge::kIdNone; - castPse->outputs[0].attr.opt.merge_scope = ge::kIdNone; - - auto add1 = graph.FindNode("add1"); - add1->outputs[0].attr.mem.tensor_id = tensorID++; - add1->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeQueue; - add1->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - add1->outputs[0].attr.mem.position = ge::Position::kPositionVecOut; - add1->outputs[0].attr.buf.id = ge::kIdNone; - add1->outputs[0].attr.que.id = stage1Que; - add1->outputs[0].attr.que.depth = 2; - add1->outputs[0].attr.que.buf_num = 2; - add1->outputs[0].attr.opt.ref_tensor = ge::kIdNone; - add1->outputs[0].attr.opt.merge_scope = ge::kIdNone; - - auto mul1 = graph.FindNode("mul1"); - mul1->outputs[0].attr.mem.tensor_id = tensorID++; - mul1->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeQueue; - mul1->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - mul1->outputs[0].attr.mem.position = ge::Position::kPositionVecOut; - mul1->outputs[0].attr.buf.id = ge::kIdNone; - mul1->outputs[0].attr.que.id = stage1Que; - mul1->outputs[0].attr.que.depth = 2; - mul1->outputs[0].attr.que.buf_num = 2; - mul1->outputs[0].attr.opt.ref_tensor = ge::kIdNone; - mul1->outputs[0].attr.opt.merge_scope = ge::kIdNone; - - auto select = graph.FindNode("select"); - select->outputs[0].attr.mem.tensor_id = tensorID++; - select->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeQueue; - select->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - select->outputs[0].attr.mem.position = ge::Position::kPositionVecOut; - select->outputs[0].attr.buf.id = ge::kIdNone; - select->outputs[0].attr.que.id = stage1Que; - select->outputs[0].attr.que.depth = 2; - select->outputs[0].attr.que.buf_num = 2; - select->outputs[0].attr.opt.ref_tensor = ge::kIdNone; - select->outputs[0].attr.opt.merge_scope = ge::kIdNone; - - auto softmaxExp = graph.FindNode("softmaxExp"); - softmaxExp->outputs[0].attr.mem.tensor_id = tensorID++; - softmaxExp->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeQueue; - softmaxExp->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - softmaxExp->outputs[0].attr.mem.position = ge::Position::kPositionVecIn; - softmaxExp->outputs[0].attr.buf.id = ge::kIdNone; - softmaxExp->outputs[0].attr.que.id = softmaxExpQueue; - softmaxExp->outputs[0].attr.que.depth = 2; - softmaxExp->outputs[0].attr.que.buf_num = 2; - softmaxExp->outputs[0].attr.opt.ref_tensor = ge::kIdNone; - softmaxExp->outputs[0].attr.opt.merge_scope = ge::kIdNone; - - auto softmaxApiTmpBuf = graph.FindNode("softmaxApiTmpBuf"); - softmaxApiTmpBuf->outputs[0].attr.mem.tensor_id = tensorID++; - softmaxApiTmpBuf->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeBuffer; - softmaxApiTmpBuf->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - softmaxApiTmpBuf->outputs[0].attr.mem.position = ge::Position::kPositionVecIn; - softmaxApiTmpBuf->outputs[0].attr.buf.id = commonTBuf; - softmaxApiTmpBuf->outputs[0].attr.que.id = ge::kIdNone; - softmaxApiTmpBuf->outputs[0].attr.que.depth = ge::kIdNone; - softmaxApiTmpBuf->outputs[0].attr.que.buf_num = ge::kIdNone; - softmaxApiTmpBuf->outputs[0].attr.opt.ref_tensor = ge::kIdNone; - softmaxApiTmpBuf->outputs[0].attr.opt.merge_scope = ge::kIdNone; - - auto flashSoftmax = graph.FindNode("flashSoftmax"); - flashSoftmax->outputs[0].attr.mem.tensor_id = tensorID++; - flashSoftmax->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeQueue; - flashSoftmax->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - flashSoftmax->outputs[0].attr.mem.position = ge::Position::kPositionVecOut; - flashSoftmax->outputs[0].attr.buf.id = ge::kIdNone; - flashSoftmax->outputs[0].attr.que.id = stage1Que; - flashSoftmax->outputs[0].attr.que.depth = 2; - flashSoftmax->outputs[0].attr.que.buf_num = 2; - flashSoftmax->outputs[0].attr.opt.ref_tensor = ge::kIdNone; - flashSoftmax->outputs[0].attr.opt.merge_scope = ge::kIdNone; - - flashSoftmax->outputs[1].attr.mem.tensor_id = tensorID++; - flashSoftmax->outputs[1].attr.mem.alloc_type = ge::AllocType::kAllocTypeBuffer; - flashSoftmax->outputs[1].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - flashSoftmax->outputs[1].attr.mem.position = ge::Position::kPositionVecOut; - flashSoftmax->outputs[1].attr.buf.id = softmaxMaxBuf; - flashSoftmax->outputs[1].attr.que.id = ge::kIdNone; - flashSoftmax->outputs[1].attr.que.depth = ge::kIdNone; - flashSoftmax->outputs[1].attr.que.buf_num = ge::kIdNone; - flashSoftmax->outputs[1].attr.opt.ref_tensor = ge::kIdNone; - flashSoftmax->outputs[1].attr.opt.merge_scope = ge::kIdNone; - - flashSoftmax->outputs[2].attr.mem.tensor_id = tensorID++; - flashSoftmax->outputs[2].attr.mem.alloc_type = ge::AllocType::kAllocTypeQueue; - flashSoftmax->outputs[2].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - flashSoftmax->outputs[2].attr.mem.position = ge::Position::kPositionVecOut; - flashSoftmax->outputs[2].attr.buf.id = ge::kIdNone; - flashSoftmax->outputs[2].attr.que.id = softmaxSumQueue; - flashSoftmax->outputs[2].attr.que.depth = 2; - flashSoftmax->outputs[2].attr.que.buf_num = 2; - flashSoftmax->outputs[2].attr.opt.ref_tensor = ge::kIdNone; - flashSoftmax->outputs[2].attr.opt.merge_scope = ge::kIdNone; - - auto loadAttenMask = graph.FindNode("loadAttenMask"); - loadAttenMask->outputs[0].attr.mem.tensor_id = tensorID++; - loadAttenMask->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeBuffer; - loadAttenMask->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - loadAttenMask->outputs[0].attr.mem.position = ge::Position::kPositionVecIn; - loadAttenMask->outputs[0].attr.buf.id = maskTbufPing; - loadAttenMask->outputs[0].attr.que.id = ge::kIdNone; - loadAttenMask->outputs[0].attr.que.depth = ge::kIdNone; - loadAttenMask->outputs[0].attr.que.buf_num = ge::kIdNone; - loadAttenMask->outputs[0].attr.opt.ref_tensor = ge::kIdNone; - loadAttenMask->outputs[0].attr.opt.merge_scope = ge::kIdNone; - - auto loadDropMask = graph.FindNode("loadDropMask"); - loadDropMask->outputs[0].attr.mem.tensor_id = tensorID++; - loadDropMask->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeBuffer; - loadDropMask->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - loadDropMask->outputs[0].attr.mem.position = ge::Position::kPositionVecIn; - loadDropMask->outputs[0].attr.buf.id = maskTbufPong; - loadDropMask->outputs[0].attr.que.id = ge::kIdNone; - loadDropMask->outputs[0].attr.que.depth = ge::kIdNone; - loadDropMask->outputs[0].attr.que.buf_num = ge::kIdNone; - loadDropMask->outputs[0].attr.opt.ref_tensor = ge::kIdNone; - loadDropMask->outputs[0].attr.opt.merge_scope = ge::kIdNone; - - auto dropout = graph.FindNode("dropout"); - dropout->outputs[0].attr.mem.tensor_id = tensorID++; - dropout->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeQueue; - dropout->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - dropout->outputs[0].attr.mem.position = ge::Position::kPositionVecOut; - dropout->outputs[0].attr.buf.id = ge::kIdNone; - dropout->outputs[0].attr.que.id = stage1Que; - dropout->outputs[0].attr.que.depth = 2; - dropout->outputs[0].attr.que.buf_num = 2; - dropout->outputs[0].attr.opt.ref_tensor = ge::kIdNone; - dropout->outputs[0].attr.opt.merge_scope = ge::kIdNone; - - auto castVec1Res = graph.FindNode("castVec1Res"); - castVec1Res->outputs[0].attr.mem.tensor_id = tensorID++; - castVec1Res->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeBuffer; - castVec1Res->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - castVec1Res->outputs[0].attr.mem.position = ge::Position::kPositionVecOut; - castVec1Res->outputs[0].attr.buf.id = pseTBuf; - castVec1Res->outputs[0].attr.que.id = ge::kIdNone; - castVec1Res->outputs[0].attr.que.depth = ge::kIdNone; - castVec1Res->outputs[0].attr.que.buf_num = ge::kIdNone; - castVec1Res->outputs[0].attr.opt.ref_tensor = ge::kIdNone; - castVec1Res->outputs[0].attr.opt.merge_scope = ge::kIdNone; - - auto storeVec1Res = graph.FindNode("storeVec1Res"); - storeVec1Res->outputs[0].attr.mem.tensor_id = tensorID++; - storeVec1Res->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeQueue; - storeVec1Res->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareGM; - storeVec1Res->outputs[0].attr.mem.position = ge::Position::kPositionGM; - storeVec1Res->outputs[0].attr.buf.id = ge::kIdNone; - storeVec1Res->outputs[0].attr.que.id = stage1ResQueue; - storeVec1Res->outputs[0].attr.que.depth = 2; - storeVec1Res->outputs[0].attr.que.buf_num = 2; - storeVec1Res->outputs[0].attr.opt.ref_tensor = ge::kIdNone; - storeVec1Res->outputs[0].attr.opt.merge_scope = ge::kIdNone; - - auto bmm2 = graph.FindNode("bmm2"); - bmm2->outputs[0].attr.mem.tensor_id = tensorID++; - bmm2->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeQueue; - bmm2->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareGM; - bmm2->outputs[0].attr.mem.position = ge::Position::kPositionGM; - bmm2->outputs[0].attr.buf.id = ge::kIdNone; - bmm2->outputs[0].attr.que.id = mm2ResQueue; - bmm2->outputs[0].attr.que.depth = 2; - bmm2->outputs[0].attr.que.buf_num = 2; - bmm2->outputs[0].attr.opt.ref_tensor = ge::kIdNone; - bmm2->outputs[0].attr.opt.merge_scope = ge::kIdNone; - - auto load2 = graph.FindNode("load2"); - load2->outputs[0].attr.mem.tensor_id = tensorID++; - load2->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeBuffer; - load2->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - load2->outputs[0].attr.mem.position = ge::Position::kPositionVecIn; - load2->outputs[0].attr.buf.id = commonTBuf; - load2->outputs[0].attr.que.id = ge::kIdNone; - load2->outputs[0].attr.que.depth = ge::kIdNone; - load2->outputs[0].attr.que.buf_num = ge::kIdNone; - load2->outputs[0].attr.opt.ref_tensor = ge::kIdNone; - load2->outputs[0].attr.opt.merge_scope = ge::kIdNone; - - auto addResOut = graph.FindNode("addResOut"); - addResOut->outputs[0].attr.mem.tensor_id = tensorID++; - addResOut->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeQueue; - addResOut->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareGM; - addResOut->outputs[0].attr.mem.position = ge::Position::kPositionGM; - addResOut->outputs[0].attr.buf.id = ge::kIdNone; - addResOut->outputs[0].attr.que.id = vec2ResQueue; - addResOut->outputs[0].attr.que.depth = 2; - addResOut->outputs[0].attr.que.buf_num = 2; - addResOut->outputs[0].attr.opt.ref_tensor = ge::kIdNone; - addResOut->outputs[0].attr.opt.merge_scope = ge::kIdNone; - - auto loadAddResOut = graph.FindNode("loadAddResOut"); - loadAddResOut->outputs[0].attr.mem.tensor_id = tensorID++; - loadAddResOut->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeBuffer; - loadAddResOut->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - loadAddResOut->outputs[0].attr.mem.position = ge::Position::kPositionVecIn; - loadAddResOut->outputs[0].attr.buf.id = stage2Buf; - loadAddResOut->outputs[0].attr.que.id = ge::kIdNone; - loadAddResOut->outputs[0].attr.que.depth = ge::kIdNone; - loadAddResOut->outputs[0].attr.que.buf_num = ge::kIdNone; - loadAddResOut->outputs[0].attr.opt.ref_tensor = ge::kIdNone; - loadAddResOut->outputs[0].attr.opt.merge_scope = ge::kIdNone; - - auto mulRes = graph.FindNode("mulRes"); - mulRes->outputs[0].attr.mem.tensor_id = tensorID++; - mulRes->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeBuffer; - mulRes->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - mulRes->outputs[0].attr.mem.position = ge::Position::kPositionVecOut; - mulRes->outputs[0].attr.buf.id = stage2Buf; - mulRes->outputs[0].attr.que.id = ge::kIdNone; - mulRes->outputs[0].attr.que.depth = ge::kIdNone; - mulRes->outputs[0].attr.que.buf_num = ge::kIdNone; - mulRes->outputs[0].attr.opt.ref_tensor = ge::kIdNone; - mulRes->outputs[0].attr.opt.merge_scope = ge::kIdNone; - - auto addRes = graph.FindNode("addRes"); - addRes->outputs[0].attr.mem.tensor_id = tensorID++; - addRes->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeBuffer; - addRes->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - addRes->outputs[0].attr.mem.position = ge::Position::kPositionVecOut; - addRes->outputs[0].attr.buf.id = stage2Buf; - addRes->outputs[0].attr.que.id = ge::kIdNone; - addRes->outputs[0].attr.que.depth = ge::kIdNone; - addRes->outputs[0].attr.que.buf_num = ge::kIdNone; - addRes->outputs[0].attr.opt.ref_tensor = ge::kIdNone; - addRes->outputs[0].attr.opt.merge_scope = ge::kIdNone; - - auto div = graph.FindNode("div"); - div->outputs[0].attr.mem.tensor_id = tensorID++; - div->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeBuffer; - div->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - div->outputs[0].attr.mem.position = ge::Position::kPositionVecOut; - div->outputs[0].attr.buf.id = stage2Buf; - div->outputs[0].attr.que.id = ge::kIdNone; - div->outputs[0].attr.que.depth = ge::kIdNone; - div->outputs[0].attr.que.buf_num = ge::kIdNone; - div->outputs[0].attr.opt.ref_tensor = ge::kIdNone; - div->outputs[0].attr.opt.merge_scope = ge::kIdNone; - - auto castBmm2Res = graph.FindNode("castBmm2Res"); - castBmm2Res->outputs[0].attr.mem.tensor_id = tensorID++; - castBmm2Res->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeBuffer; - castBmm2Res->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - castBmm2Res->outputs[0].attr.mem.position = ge::Position::kPositionVecOut; - castBmm2Res->outputs[0].attr.buf.id = stage2Buf; - castBmm2Res->outputs[0].attr.que.id = ge::kIdNone; - castBmm2Res->outputs[0].attr.que.depth = ge::kIdNone; - castBmm2Res->outputs[0].attr.que.buf_num = ge::kIdNone; - castBmm2Res->outputs[0].attr.opt.ref_tensor = ge::kIdNone; - castBmm2Res->outputs[0].attr.opt.merge_scope = castBmm2Res->outputs[0].attr.mem.tensor_id; - - auto store = graph.FindNode("store"); - store->outputs[0].attr.mem.tensor_id = tensorID++; - store->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeGlobal; - store->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareGM; - store->outputs[0].attr.mem.position = ge::Position::kPositionGM; - store->outputs[0].attr.opt.ref_tensor = 0; - store->outputs[0].attr.opt.merge_scope = 0; - - auto storeSoftmaxMax = graph.FindNode("storeSoftmaxMax"); - storeSoftmaxMax->outputs[0].attr.mem.tensor_id = tensorID++; - storeSoftmaxMax->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeGlobal; - storeSoftmaxMax->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareGM; - storeSoftmaxMax->outputs[0].attr.mem.position = ge::Position::kPositionGM; - storeSoftmaxMax->outputs[0].attr.opt.ref_tensor = 0; - storeSoftmaxMax->outputs[0].attr.opt.merge_scope = 0; - - auto buf_ = graph.FindNode("buf_"); - buf_->outputs[0].attr.mem.tensor_id = 1; - buf_->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeL1; - buf_->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - buf_->outputs[0].attr.mem.buf_ids = {1, 2, 3, 4, 5}; - buf_->outputs[0].attr.mem.name = "Mem_"; -} - -void CreatConcatAscGraph(ge::AscGraph &graph) { - auto ONE = Symbol(1); - const Expression s0 = graph.CreateSizeVar("s0"); - const Expression s1 = graph.CreateSizeVar("s1"); - const Expression s2 = graph.CreateSizeVar("s2"); - - auto z0 = graph.CreateAxis("z0", s0); - auto z1 = graph.CreateAxis("z1", s1 + s2 + s2); - - ge::ascir_op::Data x1("concat_data0", graph); - x1.attr.sched.axis = {z0.id, z1.id}; - *x1.y.axis = {z0.id, z1.id}; - *x1.y.repeats = {s0, s1}; - *x1.y.strides = {s1, ONE}; - - ge::ascir_op::Load load0("concat_load0"); - load0.x = x1.y; - load0.attr.sched.axis = {z0.id, z1.id}; - *load0.y.axis = {z0.id, z1.id}; - *load0.y.repeats = {s0, s1}; - *load0.y.strides = {s1, ONE}; - - ge::ascir_op::Data x2("concat_data1", graph); - x2.attr.sched.axis = {z0.id, z1.id}; - *x2.y.axis = {z0.id, z1.id}; - *x2.y.repeats = {s0, s2}; - *x2.y.strides = {s2, ONE}; - - ge::ascir_op::Load load1("concat_load1"); - load1.x = x2.y; - load1.attr.sched.axis = {z0.id, z1.id}; - *load1.y.axis = {z0.id, z1.id}; - *load1.y.repeats = {s0, s2}; - *load1.y.strides = {s2, ONE}; - - ge::ascir_op::Data x3("x3", graph); - x3.attr.sched.axis = {z0.id, z1.id}; - *x3.y.axis = {z0.id, z1.id}; - *x3.y.repeats = {s0, s2}; - *x3.y.strides = {s2, ONE}; - - ge::ascir_op::Load load2("load2"); - load2.x = x3.y; - load2.attr.sched.axis = {z0.id, z1.id}; - *load2.y.axis = {z0.id, z1.id}; - *load2.y.repeats = {s0, s2}; - *load2.y.strides = {s2, ONE}; - - ge::ascir_op::Concat concat("concat"); - concat.x = {load0.y, load1.y, load2.y}; - concat.attr.sched.axis = {z0.id, z1.id}; - *concat.y.axis = {z0.id, z1.id}; - *concat.y.repeats = {s0, s1 + s2 + s2}; - *concat.y.strides = {s1 + s2 + s2, ONE}; - - ge::ascir_op::Store x_out("concat_store"); - x_out.x = concat.y; - x_out.attr.sched.axis = {z0.id, z1.id}; - *x_out.y.axis = {z0.id, z1.id}; - *x_out.y.repeats = {s0, s1 + s2 + s2}; - *x_out.y.strides = {s1 + s2 + s2, ONE}; - - ge::ascir_op::Output y("concat_out"); - y.x = x_out.y; - y.y.dtype = ge::DT_FLOAT16; -} -} //namespace ge diff --git a/tests/ut/ascendc_ir/testcase/ascendc_ir_dump_test/stub_graph.h b/tests/ut/ascendc_ir/testcase/ascendc_ir_dump_test/stub_graph.h deleted file mode 100644 index 4ab7b0f83aee9bf3d7be31eddd998ab5559862f5..0000000000000000000000000000000000000000 --- a/tests/ut/ascendc_ir/testcase/ascendc_ir_dump_test/stub_graph.h +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef STUB_GRAPH_H_ -#define STUB_GRAPH_H_ -#include "ascir_ops.h" -#include "inc/graph/ascendc_ir/ascendc_ir_core/ascendc_ir.h" - -namespace ge { -void FaBeforeAutoFuse(ge::AscGraph &graph); -void FaAfterScheduler(ge::AscGraph &graph); -void FaAfterQueBufAlloc(ge::AscGraph &graph); -void CreatConcatAscGraph(ge::AscGraph &graph); -} // namespace ge -#endif \ No newline at end of file diff --git a/tests/ut/ascendc_ir/testcase/ascendc_ir_unittest.cc b/tests/ut/ascendc_ir/testcase/ascendc_ir_unittest.cc deleted file mode 100644 index 3aef9048109966925113bc52a3071a51f6c8386e..0000000000000000000000000000000000000000 --- a/tests/ut/ascendc_ir/testcase/ascendc_ir_unittest.cc +++ /dev/null @@ -1,4090 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ -#include -#include -#include "ascir_ops.h" -#include "ascendc_ir/ascendc_ir_core/ascendc_ir.h" -#include "inc/graph/symbolizer/symbolic.h" -#include "graph/utils/node_utils_ex.h" -#include "graph/ascendc_ir/ascir_registry.h" -#include "graph/ascendc_ir/ascir_register.h" -#include "graph/utils/graph_utils.h" -#include "slog.h" -#include "expression/const_values.h" -#include "code_extractor.h" -#include "ascendc_ir/utils/asc_graph_utils.h" - -using namespace ge::ascir_op; -namespace { -constexpr int64_t ID_NONE = -1; -} -class UtestAscendCIR : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; -using namespace ge; -using ge::Expression; -using ge::Symbol; -namespace { -struct AscNodeInfo { - std::string name; - std::string type; - size_t input_num; - size_t output_num; - std::vector axis_ids; -}; -template -class OpDtypeInfer { - public: - OpDtypeInfer &Input(ge::DataType input_type) { - input_dtypes_.push_back(input_type); - return *this; - } - OpDtypeInfer &Expect(const DataType &type) { - expected_dtypes_.emplace_back(type); - return *this; - } - void AssertSucceed() { - if (expected_dtypes_.empty()) { - std::vector dtypes; - ASSERT_EQ(T::InferDataType(input_dtypes_, dtypes), GRAPH_SUCCESS); - std::vector dtypes2; - ASSERT_EQ(ascir::CommonInferDtype(T::Type, input_dtypes_, dtypes2), GRAPH_SUCCESS); - } else { - ASSERT_EQ(T::InferDataType(input_dtypes_, expected_dtypes_), GRAPH_SUCCESS); - ASSERT_EQ(ascir::CommonInferDtype(T::Type, input_dtypes_, expected_dtypes_), GRAPH_SUCCESS); - } - } - - void AssertFailed() { - if (expected_dtypes_.empty()) { - std::vector dtypes; - ASSERT_NE(T::InferDataType(input_dtypes_, dtypes), GRAPH_SUCCESS); - std::vector dtypes2; - ASSERT_NE(ascir::CommonInferDtype(T::Type, input_dtypes_, dtypes), GRAPH_SUCCESS); - } else { - ASSERT_NE(T::InferDataType(input_dtypes_, expected_dtypes_), GRAPH_SUCCESS); - ASSERT_NE(ascir::CommonInferDtype(T::Type, input_dtypes_, expected_dtypes_), GRAPH_SUCCESS); - } - } - private: - std::vector input_dtypes_; - std::vector expected_dtypes_; -}; -} -TEST_F(UtestAscendCIR, TilingKey_OK) { - AscGraph graph("test_graph"); - graph.SetTilingKey(10); - EXPECT_EQ(graph.GetTilingKey(), 10); -} - -TEST_F(UtestAscendCIR, CreateSizeVar_OK) { - AscGraph graph("test_graph"); - const auto &s0 = graph.CreateSizeVar("s0"); - const auto &s1 = graph.CreateSizeVar("s1"); - const auto &s2 = graph.CreateSizeVar("s2"); - const auto &const_10 = graph.CreateSizeVar(10); - Symbol symbol1(1, "MyOne"); - graph.CreateSizeVar(symbol1); - Symbol symbol2("s3"); - graph.CreateSizeVar(symbol2); - auto all_size_var = graph.GetAllSizeVar(); - EXPECT_EQ(all_size_var.size(), 6u); - EXPECT_EQ(all_size_var[4]->expr.IsConstExpr(), true); - int64_t i_get(-1); - EXPECT_EQ(all_size_var[4]->expr.GetConstValue<>(i_get), true); - EXPECT_EQ(i_get, 1); - EXPECT_EQ(all_size_var[5]->expr.IsConstExpr(), false); - EXPECT_EQ(all_size_var[5]->expr.Str().get(), std::string("s3")); -} - -TEST_F(UtestAscendCIR, CreateAxis) { - AscGraph graph("test_graph"); - const Expression s0 = graph.CreateSizeVar("s0"); - Axis &s0_axis = graph.CreateAxis("S0", s0); - s0_axis.align = 10; - auto s0_axis_find = graph.FindAxis(s0_axis.id); - EXPECT_NE(s0_axis_find, nullptr); - EXPECT_EQ(s0_axis_find->name, "S0"); - EXPECT_EQ(s0_axis_find->align, 10); - - auto axis_invalid = graph.FindAxis(-1); - EXPECT_EQ(axis_invalid, nullptr); -} - -TEST_F(UtestAscendCIR, BlockSplit) { - AscGraph graph("test_graph"); - const Expression s0 = graph.CreateSizeVar("s0"); - Axis &s0_axis = graph.CreateAxis("S0", s0); - auto split_axis = graph.BlockSplit(s0_axis.id); - EXPECT_NE(split_axis.first, nullptr); - EXPECT_NE(split_axis.second, nullptr); - auto &outer_axis = *split_axis.first; - auto &inner_axis = *split_axis.second; - EXPECT_EQ(inner_axis.type, Axis::kAxisTypeBlockInner); - EXPECT_EQ(outer_axis.type, Axis::kAxisTypeBlockOuter); -} - -TEST_F(UtestAscendCIR, TileSplit) { - AscGraph graph("test_graph"); - const Expression s0 = graph.CreateSizeVar("s0"); - Axis &s0_axis = graph.CreateAxis("S0", s0); - s0_axis.align = 10; - auto split_axis = graph.TileSplit(s0_axis.id); - EXPECT_NE(split_axis.first, nullptr); - EXPECT_NE(split_axis.second, nullptr); - auto &outer_axis = *split_axis.first; - auto &inner_axis = *split_axis.second; - EXPECT_EQ(inner_axis.type, Axis::kAxisTypeTileInner); - EXPECT_EQ(outer_axis.type, Axis::kAxisTypeTileOuter); -} - -TEST_F(UtestAscendCIR, TileSplitSizeOneAxis) { - AscGraph graph("test_graph"); - Axis &s0_axis = graph.CreateAxis("S0", ge::sym::kSymbolOne); - auto split_axis = graph.TileSplit(s0_axis.id); - EXPECT_NE(split_axis.first, nullptr); - EXPECT_NE(split_axis.second, nullptr); - auto &outer_axis = *split_axis.first; - auto &inner_axis = *split_axis.second; - EXPECT_EQ(inner_axis.type, Axis::kAxisTypeTileInner); - EXPECT_EQ(inner_axis.size, ge::sym::kSymbolOne); - EXPECT_EQ(outer_axis.type, Axis::kAxisTypeTileOuter); - EXPECT_EQ(outer_axis.size, ge::sym::kSymbolOne); -} - -TEST_F(UtestAscendCIR, MergeAxis) { - AscGraph graph("test_graph"); - const Expression s0 = graph.CreateSizeVar("s0"); - Axis &s0_axis = graph.CreateAxis("S0", s0); - const Expression s1 = graph.CreateSizeVar("s1"); - Axis &s1_axis = graph.CreateAxis("S1", s1); - auto merge_axis = graph.MergeAxis({s0_axis.id, s1_axis.id}); - EXPECT_NE(merge_axis, nullptr); - EXPECT_EQ(merge_axis->type, Axis::kAxisTypeMerged); -} - -TEST_F(UtestAscendCIR, BindBlock) { - AscGraph graph("test_graph"); - const Expression s0 = graph.CreateSizeVar("s0"); - Axis &s0_axis = graph.CreateAxis("S0", s0); - const Expression s1 = graph.CreateSizeVar("s1"); - Axis &s1_axis = graph.CreateAxis("S1", s1); - - s0_axis.align = 10; - auto split_axis = graph.TileSplit(s0_axis.id); - EXPECT_NE(split_axis.first, nullptr); - EXPECT_NE(split_axis.second, nullptr); - auto &outer_axis = *split_axis.first; - auto &inner_axis = *split_axis.second; - EXPECT_EQ(inner_axis.type, Axis::kAxisTypeTileInner); - EXPECT_EQ(outer_axis.type, Axis::kAxisTypeTileOuter); - - auto merge_axis = graph.MergeAxis({outer_axis.id, s1_axis.id}); - EXPECT_TRUE(graph.BindBlock(merge_axis->id, inner_axis.id)); - EXPECT_EQ(merge_axis->type, Axis::kAxisTypeBlockOuter); - EXPECT_EQ(inner_axis.type, Axis::kAxisTypeBlockInner); -} - -TEST_F(UtestAscendCIR, GetAllAxisTransInfo) { - AscGraph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto [aBO, aBI] = graph.BlockSplit(a.id); - EXPECT_EQ(graph.GetAllAxisTransInfo().size(), 1U); - EXPECT_EQ(graph.GetAllAxisTransInfo().front().trans_type, TransType::kSplit); - EXPECT_EQ(graph.GetAllAxisTransInfo().front().src_axis.size(), 1U); - EXPECT_EQ(graph.GetAllAxisTransInfo().front().src_axis.front()->id, a.id); - EXPECT_EQ(graph.GetAllAxisTransInfo().front().dst_axis.size(), 2U); - EXPECT_EQ(graph.GetAllAxisTransInfo().front().dst_axis, std::vector({aBO, aBI})); - auto aBIb = graph.MergeAxis({aBI->id, b.id}); - EXPECT_EQ(graph.GetAllAxisTransInfo().size(), 2U); - EXPECT_EQ(graph.GetAllAxisTransInfo()[1U].trans_type, TransType::kMerge); - EXPECT_EQ(graph.GetAllAxis().size(), 6U); -} - -TEST_F(UtestAscendCIR, ApplySplit) { - AscGraph graph("test_graph"); - const Expression s0 = graph.CreateSizeVar("s0"); - Axis &s0_axis = graph.CreateAxis("S0", s0); - ascir_op::Data data("data", graph); - auto data_node = ge::NodeUtilsEx::GetNodeFromOperator(data); - EXPECT_NE(data_node, nullptr); - data.attr.sched.exec_order = 1; - data.attr.sched.axis = {s0_axis.id}; - data.y.dtype = ge::DT_FLOAT16; - data.y.format = ge::FORMAT_ND; - *data.y.axis = {s0_axis.id}; - *data.y.repeats = {s0}; - *data.y.strides = {sym::kSymbolOne}; - - auto split_axis = graph.TileSplit(s0_axis.id); - EXPECT_NE(split_axis.first, nullptr); - EXPECT_NE(split_axis.second, nullptr); - auto &outer_axis = *split_axis.first; - auto &inner_axis = *split_axis.second; - EXPECT_EQ(inner_axis.type, Axis::kAxisTypeTileInner); - EXPECT_EQ(outer_axis.type, Axis::kAxisTypeTileOuter); - - auto data_node_find = graph.FindNode("data"); - EXPECT_NE(data_node_find, nullptr); - graph.ApplySplit(data_node_find, outer_axis.id, inner_axis.id); -} - -TEST_F(UtestAscendCIR, ApplySplit_BroadCast) { - AscGraph graph("test_graph"); - auto A = graph.CreateSizeVar("A"); - auto R = graph.CreateSizeVar("R"); - auto BL = graph.CreateSizeVar("BL"); - // 定义轴 - auto a = graph.CreateAxis("A", A); - auto r = graph.CreateAxis("R", R); - auto bl = graph.CreateAxis("BL", BL); - - ascir_op::Data data("data", graph); - auto data_node = ge::NodeUtilsEx::GetNodeFromOperator(data); - EXPECT_NE(data_node, nullptr); - data.attr.sched.exec_order = 1; - data.attr.sched.axis = {a.id, r.id, bl.id}; - data.y.dtype = ge::DT_FLOAT16; - *data.y.axis = {a.id, r.id, bl.id}; - *data.y.repeats = {A, R, sym::kSymbolOne}; - *data.y.strides = {R, sym::kSymbolOne, sym::kSymbolZero}; - - auto split_axis = graph.TileSplit(bl.id); - EXPECT_NE(split_axis.first, nullptr); - EXPECT_NE(split_axis.second, nullptr); - auto &outer_axis = *split_axis.first; - auto &inner_axis = *split_axis.second; - EXPECT_EQ(inner_axis.type, Axis::kAxisTypeTileInner); - EXPECT_EQ(outer_axis.type, Axis::kAxisTypeTileOuter); - - auto data_node_find = graph.FindNode("data"); - EXPECT_NE(data_node_find, nullptr); - graph.ApplySplit(data_node_find, outer_axis.id, inner_axis.id); -} - -TEST_F(UtestAscendCIR, ApplyMerge) { - AscGraph graph("test_graph"); - const Expression s0 = graph.CreateSizeVar("s0"); - Axis &s0_axis = graph.CreateAxis("S0", s0); - const Expression s1 = graph.CreateSizeVar("s1"); - Axis &s1_axis = graph.CreateAxis("S1", s1); - ascir_op::Data data("data", graph); - auto data_node = ge::NodeUtilsEx::GetNodeFromOperator(data); - EXPECT_NE(data_node, nullptr); - data.attr.sched.exec_order = 1; - data.attr.sched.axis = {s0_axis.id, s1_axis.id}; - data.y.dtype = ge::DT_FLOAT16; - data.y.format = ge::FORMAT_ND; - *data.y.axis = {s0_axis.id, s1_axis.id}; - *data.y.repeats = {s0, s1}; - *data.y.strides = {s1, sym::kSymbolOne}; - - auto merge_axis = graph.MergeAxis({s0_axis.id, s1_axis.id}); - EXPECT_NE(merge_axis, nullptr); - EXPECT_EQ(merge_axis->type, Axis::kAxisTypeMerged); - - auto data_node_find = graph.FindNode("data"); - data_node_find->attr.sched.exec_order = 1; - data_node_find->attr.sched.axis = {s0_axis.id, s1_axis.id}; - EXPECT_NE(data_node_find, nullptr); - graph.ApplyMerge(data_node_find, merge_axis->id); - EXPECT_EQ(data_node_find->attr.sched.axis.size(), 1U); - EXPECT_EQ(data_node_find->outputs[0].attr.axis.size(), 1U); -} - -TEST_F(UtestAscendCIR, ApplyMerge_0_not_merge_tensor_but_merge_node) { - AscGraph graph("test_graph"); - const Expression s0 = graph.CreateSizeVar("s0"); - Axis &s0_axis = graph.CreateAxis("S0", s0); - const Expression s1 = graph.CreateSizeVar("s1"); - Axis &s1_axis = graph.CreateAxis("S1", s1); - ascir_op::Data data("data", graph); - auto data_node = ge::NodeUtilsEx::GetNodeFromOperator(data); - EXPECT_NE(data_node, nullptr); - data.attr.sched.exec_order = 1; - data.attr.sched.axis = {s0_axis.id, s1_axis.id}; - data.y.dtype = ge::DT_FLOAT16; - data.y.format = ge::FORMAT_ND; - *data.y.axis = {s0_axis.id, s1_axis.id}; - *data.y.repeats = {s0, s1}; - *data.y.strides = {sym::kSymbolOne, sym::kSymbolOne}; - - auto merge_axis = graph.MergeAxis({s0_axis.id, s1_axis.id}); - EXPECT_NE(merge_axis, nullptr); - EXPECT_EQ(merge_axis->type, Axis::kAxisTypeMerged); - - auto data_node_find = graph.FindNode("data"); - EXPECT_NE(data_node_find, nullptr); - data_node_find->attr.sched.exec_order = 1; - data_node_find->attr.sched.axis = {s0_axis.id, s1_axis.id}; - graph.ApplyMerge(data_node_find, merge_axis->id); - EXPECT_EQ(data_node_find->attr.sched.axis.size(), 1U); - EXPECT_EQ(data_node_find->outputs[0].attr.axis.size(), 2U); -} - -TEST_F(UtestAscendCIR, ApplyMerge_1) { - AscGraph graph("test_graph"); - const auto s0 = graph.CreateSizeVar("s0"); - Axis &s0_axis = graph.CreateAxis("S0", s0); - const auto s1 = graph.CreateSizeVar("s1"); - Axis &s1_axis = graph.CreateAxis("S1", s1); - ascir_op::Data data("data", graph); - auto data_node = ge::NodeUtilsEx::GetNodeFromOperator(data); - EXPECT_NE(data_node, nullptr); - data.attr.sched.exec_order = 1; - data.attr.sched.axis = {s0_axis.id, s1_axis.id}; - data.y.dtype = ge::DT_FLOAT16; - data.y.format = ge::FORMAT_ND; - *data.y.axis = {s0_axis.id, s1_axis.id}; - *data.y.repeats = {s0, s1}; - *data.y.strides = {s1, sym::kSymbolOne}; - - auto merge_axis = graph.MergeAxis({s0_axis.id, s1_axis.id}); - EXPECT_NE(merge_axis, nullptr); - EXPECT_EQ(merge_axis->type, Axis::kAxisTypeMerged); - - auto data_node_find = graph.FindNode("data"); - EXPECT_NE(data_node_find, nullptr); - graph.ApplyMerge(data_node_find, merge_axis->id); - EXPECT_EQ((data_node_find->outputs[0].attr.repeats[0U]), (s0*s1)); - EXPECT_EQ((data_node_find->outputs[0].attr.repeats[0U]), (s0*s1)); - EXPECT_EQ((data_node_find->outputs[0].attr.strides[0U]), 1UL); -} - -TEST_F(UtestAscendCIR, ApplySchedAxisMerge) { - { - AscGraph graph("test_graph"); - const Expression s0 = graph.CreateSizeVar("s0"); - Axis &s0_axis = graph.CreateAxis("S0", s0); - const Expression s1 = graph.CreateSizeVar("s1"); - Axis &s1_axis = graph.CreateAxis("S1", s1); - ascir_op::Data data("data", graph); - auto data_node = ge::NodeUtilsEx::GetNodeFromOperator(data); - EXPECT_NE(data_node, nullptr); - data.attr.sched.exec_order = 1; - data.attr.sched.axis = {s0_axis.id, s1_axis.id}; - data.y.dtype = ge::DT_FLOAT16; - data.y.format = ge::FORMAT_ND; - *data.y.axis = {s0_axis.id, s1_axis.id}; - *data.y.repeats = {s0, s1}; - *data.y.strides = {s1, sym::kSymbolOne}; - - auto merge_axis = graph.MergeAxis({s0_axis.id, s1_axis.id}); - EXPECT_NE(merge_axis, nullptr); - EXPECT_EQ(merge_axis->type, Axis::kAxisTypeMerged); - - auto data_node_find = graph.FindNode("data"); - EXPECT_NE(data_node_find, nullptr); - graph.ApplySchedAxisMerge(data_node_find, merge_axis->id); - EXPECT_EQ(data_node_find->attr.sched.axis.size(), 1U); - auto sched_axis = data_node_find->attr.sched.axis[0]; - EXPECT_EQ(sched_axis, merge_axis->id); - } - - { - AscGraph graph("test_graph"); - const Expression s0 = graph.CreateSizeVar("s0"); - Axis &s0_axis = graph.CreateAxis("S0", s0); - const Expression s1 = graph.CreateSizeVar("s1"); - Axis &s1_axis = graph.CreateAxis("S1", s1); - ascir_op::Data data("data", graph); - auto data_node = ge::NodeUtilsEx::GetNodeFromOperator(data); - EXPECT_NE(data_node, nullptr); - data.attr.sched.exec_order = 1; - data.attr.sched.axis = {s0_axis.id, s1_axis.id}; - data.y.dtype = ge::DT_FLOAT16; - data.y.format = ge::FORMAT_ND; - *data.y.axis = {s0_axis.id, s1_axis.id}; - *data.y.repeats = {s0, s1}; - *data.y.strides = {s1, sym::kSymbolOne}; - - auto merge_axis = graph.MergeAxis({s0_axis.id, s1_axis.id}); - EXPECT_NE(merge_axis, nullptr); - EXPECT_EQ(merge_axis->type, Axis::kAxisTypeMerged); - - auto data_node_find = graph.FindNode("data"); - EXPECT_NE(data_node_find, nullptr); - graph.ApplySchedAxisMerge(data_node_find, merge_axis->id, {s0_axis.id, s1_axis.id}); - auto sched_axis = data_node_find->attr.sched.axis[0]; - EXPECT_EQ(sched_axis, merge_axis->id); - } - - { - // 非连续,正序场景 - AscGraph graph("test_graph"); - const Expression s0 = graph.CreateSizeVar("s0"); - Axis &s0_axis = graph.CreateAxis("S0", s0); - const Expression s1 = graph.CreateSizeVar("s1"); - Axis &s1_axis = graph.CreateAxis("S1", s1); - const Expression s2 = graph.CreateSizeVar("s2"); - Axis &s2_axis = graph.CreateAxis("S2", s2); - const Expression s3 = graph.CreateSizeVar("s3"); - Axis &s3_axis = graph.CreateAxis("S3", s3); - ascir_op::Data data("data", graph); - auto data_node = ge::NodeUtilsEx::GetNodeFromOperator(data); - EXPECT_NE(data_node, nullptr); - data.attr.sched.exec_order = 1; - data.attr.sched.axis = {s0_axis.id, s1_axis.id, s2_axis.id, s3_axis.id}; - data.y.dtype = ge::DT_FLOAT16; - data.y.format = ge::FORMAT_ND; - *data.y.axis = {s0_axis.id, s1_axis.id, s2_axis.id, s3_axis.id}; - *data.y.repeats = {s0, s1, s2, s3}; - *data.y.strides = {s1*s2*s3, s2*s3, s3, sym::kSymbolOne}; - - auto merge_axis = graph.MergeAxis({s0_axis.id, s3_axis.id}); - EXPECT_NE(merge_axis, nullptr); - EXPECT_EQ(merge_axis->type, Axis::kAxisTypeMerged); - - auto data_node_find = graph.FindNode("data"); - EXPECT_NE(data_node_find, nullptr); - graph.ApplySchedAxisMerge(data_node_find, merge_axis->id, {s0_axis.id, s3_axis.id}); - EXPECT_EQ(data_node_find->attr.sched.axis[0], merge_axis->id); - EXPECT_EQ(data_node_find->attr.sched.axis[1], s1_axis.id); - EXPECT_EQ(data_node_find->attr.sched.axis[2], s2_axis.id); - } - - { - // 连续,倒序场景 - AscGraph graph("test_graph"); - const Expression s0 = graph.CreateSizeVar("s0"); - Axis &s0_axis = graph.CreateAxis("S0", s0); - const Expression s1 = graph.CreateSizeVar("s1"); - Axis &s1_axis = graph.CreateAxis("S1", s1); - const Expression s2 = graph.CreateSizeVar("s2"); - Axis &s2_axis = graph.CreateAxis("S2", s2); - const Expression s3 = graph.CreateSizeVar("s3"); - Axis &s3_axis = graph.CreateAxis("S3", s3); - ascir_op::Data data("data", graph); - auto data_node = ge::NodeUtilsEx::GetNodeFromOperator(data); - EXPECT_NE(data_node, nullptr); - data.attr.sched.exec_order = 1; - data.attr.sched.axis = {s0_axis.id, s1_axis.id, s2_axis.id, s3_axis.id}; - data.y.dtype = ge::DT_FLOAT16; - data.y.format = ge::FORMAT_ND; - *data.y.axis = {s0_axis.id, s1_axis.id, s2_axis.id, s3_axis.id}; - *data.y.repeats = {s0, s1, s2, s3}; - *data.y.strides = {s1*s2*s3, s2*s3, s3, sym::kSymbolOne}; - - auto merge_axis = graph.MergeAxis({s1_axis.id, s0_axis.id}); - EXPECT_NE(merge_axis, nullptr); - EXPECT_EQ(merge_axis->type, Axis::kAxisTypeMerged); - - auto data_node_find = graph.FindNode("data"); - EXPECT_NE(data_node_find, nullptr); - graph.ApplySchedAxisMerge(data_node_find, merge_axis->id, {s1_axis.id, s0_axis.id}); - EXPECT_EQ(data_node_find->attr.sched.axis[0], merge_axis->id); - EXPECT_EQ(data_node_find->attr.sched.axis[1], s2_axis.id); - EXPECT_EQ(data_node_find->attr.sched.axis[2], s3_axis.id); - } - - { - // 非合并轴场景 - AscGraph graph("test_graph"); - const Expression s0 = graph.CreateSizeVar("s0"); - Axis &s0_axis = graph.CreateAxis("S0", s0); - const Expression s1 = graph.CreateSizeVar("s1"); - Axis &s1_axis = graph.CreateAxis("S1", s1); - const Expression s2 = graph.CreateSizeVar("s2"); - Axis &s2_axis = graph.CreateAxis("S2", s2); - const Expression s3 = graph.CreateSizeVar("s3"); - Axis &s3_axis = graph.CreateAxis("S3", s3); - const Expression s4 = graph.CreateSizeVar("s4"); - Axis &s4_axis = graph.CreateAxis("S4", s4); - const Expression s5 = graph.CreateSizeVar("s5"); - Axis &s5_axis = graph.CreateAxis("S5", s5); - ascir_op::Data data("data", graph); - auto data_node = ge::NodeUtilsEx::GetNodeFromOperator(data); - EXPECT_NE(data_node, nullptr); - data.attr.sched.exec_order = 1; - data.attr.sched.axis = {s0_axis.id, s1_axis.id, s2_axis.id, s3_axis.id}; - data.y.dtype = ge::DT_FLOAT16; - data.y.format = ge::FORMAT_ND; - *data.y.axis = {s0_axis.id, s1_axis.id, s2_axis.id, s3_axis.id}; - *data.y.repeats = {s0, s1, s2, s3}; - *data.y.strides = {s1*s2*s3, s2*s3, s3, sym::kSymbolOne}; - - auto merge_axis = graph.MergeAxis({s4_axis.id, s5_axis.id}); - EXPECT_NE(merge_axis, nullptr); - EXPECT_EQ(merge_axis->type, Axis::kAxisTypeMerged); - - auto data_node_find = graph.FindNode("data"); - EXPECT_NE(data_node_find, nullptr); - graph.ApplySchedAxisMerge(data_node_find, merge_axis->id, {s4_axis.id, s5_axis.id}); - EXPECT_EQ(data_node_find->attr.sched.axis[0], s0_axis.id); - EXPECT_EQ(data_node_find->attr.sched.axis[1], s1_axis.id); - EXPECT_EQ(data_node_find->attr.sched.axis[2], s2_axis.id); - EXPECT_EQ(data_node_find->attr.sched.axis[3], s3_axis.id); - } -} - -TEST_F(UtestAscendCIR, ApplyTensorAxisMerge) { - { - AscGraph graph("test_graph"); - const Expression s0 = graph.CreateSizeVar("s0"); - Axis &s0_axis = graph.CreateAxis("S0", s0); - const Expression s1 = graph.CreateSizeVar("s1"); - Axis &s1_axis = graph.CreateAxis("S1", s1); - ascir_op::Data data("data", graph); - auto data_node = ge::NodeUtilsEx::GetNodeFromOperator(data); - EXPECT_NE(data_node, nullptr); - data.attr.sched.exec_order = 1; - data.attr.sched.axis = {s0_axis.id, s1_axis.id}; - data.y.dtype = ge::DT_FLOAT16; - data.y.format = ge::FORMAT_ND; - *data.y.axis = {s0_axis.id, s1_axis.id}; - *data.y.repeats = {s0, s1}; - *data.y.strides = {s1, sym::kSymbolOne}; - - auto merge_axis = graph.MergeAxis({s0_axis.id, s1_axis.id}); - EXPECT_NE(merge_axis, nullptr); - EXPECT_EQ(merge_axis->type, Axis::kAxisTypeMerged); - - auto data_node_find = graph.FindNode("data"); - EXPECT_NE(data_node_find, nullptr); - graph.ApplyTensorAxisMerge(data_node_find, merge_axis->id); - auto tensor_axis = data_node_find->outputs[0].attr.axis[0]; - EXPECT_EQ(tensor_axis, merge_axis->id); - auto tensor_stride = data_node_find->outputs[0].attr.strides[0]; - EXPECT_TRUE(tensor_stride == 1); - auto tensor_repeat = data_node_find->outputs[0].attr.repeats[0]; - EXPECT_EQ(tensor_repeat, merge_axis->size); - } - - { - AscGraph graph("test_graph"); - const Expression s0 = graph.CreateSizeVar("s0"); - Axis &s0_axis = graph.CreateAxis("S0", s0); - const Expression s1 = graph.CreateSizeVar("s1"); - Axis &s1_axis = graph.CreateAxis("S1", s1); - ascir_op::Data data("data", graph); - auto data_node = ge::NodeUtilsEx::GetNodeFromOperator(data); - EXPECT_NE(data_node, nullptr); - data.attr.sched.exec_order = 1; - data.attr.sched.axis = {s0_axis.id, s1_axis.id}; - data.y.dtype = ge::DT_FLOAT16; - data.y.format = ge::FORMAT_ND; - *data.y.axis = {s0_axis.id, s1_axis.id}; - *data.y.repeats = {s0, s1}; - *data.y.strides = {s1, sym::kSymbolOne}; - - auto merge_axis = graph.MergeAxis({s0_axis.id, s1_axis.id}); - EXPECT_NE(merge_axis, nullptr); - EXPECT_EQ(merge_axis->type, Axis::kAxisTypeMerged); - - auto data_node_find = graph.FindNode("data"); - EXPECT_NE(data_node_find, nullptr); - graph.ApplyTensorAxisMerge(data_node_find, merge_axis->id, {s0_axis.id, s1_axis.id}); - auto tensor_axis = data_node_find->outputs[0].attr.axis[0]; - EXPECT_EQ(tensor_axis, merge_axis->id); - auto tensor_stride = data_node_find->outputs[0].attr.strides[0]; - EXPECT_TRUE(tensor_stride == 1); - auto tensor_repeat = data_node_find->outputs[0].attr.repeats[0]; - EXPECT_EQ(tensor_repeat, merge_axis->size); - } - - { - // 连续,正序场景 - AscGraph graph("test_graph"); - const Expression s0 = graph.CreateSizeVar("s0"); - Axis &s0_axis = graph.CreateAxis("S0", s0); - const Expression s1 = graph.CreateSizeVar("s1"); - Axis &s1_axis = graph.CreateAxis("S1", s1); - const Expression s2 = graph.CreateSizeVar("s2"); - Axis &s2_axis = graph.CreateAxis("S2", s2); - const Expression s3 = graph.CreateSizeVar("s3"); - Axis &s3_axis = graph.CreateAxis("S3", s3); - ascir_op::Data data("data", graph); - auto data_node = ge::NodeUtilsEx::GetNodeFromOperator(data); - EXPECT_NE(data_node, nullptr); - data.attr.sched.exec_order = 1; - data.attr.sched.axis = {s0_axis.id, s1_axis.id, s2_axis.id, s3_axis.id}; - data.y.dtype = ge::DT_FLOAT16; - data.y.format = ge::FORMAT_ND; - *data.y.axis = {s0_axis.id, s1_axis.id, s2_axis.id, s3_axis.id}; - *data.y.repeats = {s0, s1, s2, s3}; - *data.y.strides = {s1*s2*s3, s2*s3, s3, sym::kSymbolOne}; - - auto merge_axis = graph.MergeAxis({s1_axis.id, s2_axis.id}); - EXPECT_NE(merge_axis, nullptr); - EXPECT_EQ(merge_axis->type, Axis::kAxisTypeMerged); - - auto data_node_find = graph.FindNode("data"); - EXPECT_NE(data_node_find, nullptr); - graph.ApplyTensorAxisMerge(data_node_find, merge_axis->id, {s1_axis.id, s2_axis.id}); - EXPECT_EQ(data_node_find->outputs[0].attr.axis[0], s0_axis.id); - EXPECT_TRUE(data_node_find->outputs[0].attr.strides[0] == s1*s2*s3); - EXPECT_EQ(data_node_find->outputs[0].attr.repeats[0], s0_axis.size); - EXPECT_EQ(data_node_find->outputs[0].attr.axis[1], merge_axis->id); - EXPECT_TRUE(data_node_find->outputs[0].attr.strides[1] == s3); - EXPECT_EQ(data_node_find->outputs[0].attr.repeats[1], merge_axis->size); - EXPECT_EQ(data_node_find->outputs[0].attr.axis[2], s3_axis.id); - EXPECT_TRUE(data_node_find->outputs[0].attr.strides[2] == sym::kSymbolOne); - EXPECT_EQ(data_node_find->outputs[0].attr.repeats[2], s3_axis.size); - } - - { - // 连续,倒序场景 - AscGraph graph("test_graph"); - const Expression s0 = graph.CreateSizeVar("s0"); - Axis &s0_axis = graph.CreateAxis("S0", s0); - const Expression s1 = graph.CreateSizeVar("s1"); - Axis &s1_axis = graph.CreateAxis("S1", s1); - const Expression s2 = graph.CreateSizeVar("s2"); - Axis &s2_axis = graph.CreateAxis("S2", s2); - const Expression s3 = graph.CreateSizeVar("s3"); - Axis &s3_axis = graph.CreateAxis("S3", s3); - ascir_op::Data data("data", graph); - auto data_node = ge::NodeUtilsEx::GetNodeFromOperator(data); - EXPECT_NE(data_node, nullptr); - data.attr.sched.exec_order = 1; - data.attr.sched.axis = {s0_axis.id, s1_axis.id, s2_axis.id, s3_axis.id}; - data.y.dtype = ge::DT_FLOAT16; - data.y.format = ge::FORMAT_ND; - *data.y.axis = {s0_axis.id, s1_axis.id, s2_axis.id, s3_axis.id}; - *data.y.repeats = {s0, s1, s2, s3}; - *data.y.strides = {s1*s2*s3, s2*s3, s3, sym::kSymbolOne}; - - auto merge_axis = graph.MergeAxis({s2_axis.id, s1_axis.id}); - EXPECT_NE(merge_axis, nullptr); - EXPECT_EQ(merge_axis->type, Axis::kAxisTypeMerged); - - auto data_node_find = graph.FindNode("data"); - EXPECT_NE(data_node_find, nullptr); - graph.ApplyTensorAxisMerge(data_node_find, merge_axis->id, {s2_axis.id, s1_axis.id}); - EXPECT_EQ(data_node_find->outputs[0].attr.axis[0], s0_axis.id); - EXPECT_TRUE(data_node_find->outputs[0].attr.strides[0] == s1*s2*s3); - EXPECT_EQ(data_node_find->outputs[0].attr.repeats[0], s0_axis.size); - EXPECT_EQ(data_node_find->outputs[0].attr.axis[1], merge_axis->id); - EXPECT_TRUE(data_node_find->outputs[0].attr.strides[1] == s3); - EXPECT_EQ(data_node_find->outputs[0].attr.repeats[1], merge_axis->size); - EXPECT_EQ(data_node_find->outputs[0].attr.axis[2], s3_axis.id); - EXPECT_TRUE(data_node_find->outputs[0].attr.strides[2] == sym::kSymbolOne); - EXPECT_EQ(data_node_find->outputs[0].attr.repeats[2], s3_axis.size); - } - - { - // 非连续,倒序场景 - AscGraph graph("test_graph"); - const Expression s0 = graph.CreateSizeVar("s0"); - Axis &s0_axis = graph.CreateAxis("S0", s0); - const Expression s1 = graph.CreateSizeVar("s1"); - Axis &s1_axis = graph.CreateAxis("S1", s1); - const Expression s2 = graph.CreateSizeVar("s2"); - Axis &s2_axis = graph.CreateAxis("S2", s2); - const Expression s3 = graph.CreateSizeVar("s3"); - Axis &s3_axis = graph.CreateAxis("S3", s3); - ascir_op::Data data("data", graph); - auto data_node = ge::NodeUtilsEx::GetNodeFromOperator(data); - EXPECT_NE(data_node, nullptr); - data.attr.sched.exec_order = 1; - data.attr.sched.axis = {s0_axis.id, s1_axis.id, s2_axis.id, s3_axis.id}; - data.y.dtype = ge::DT_FLOAT16; - data.y.format = ge::FORMAT_ND; - *data.y.axis = {s0_axis.id, s1_axis.id, s2_axis.id, s3_axis.id}; - *data.y.repeats = {s0, s1, s2, s3}; - *data.y.strides = {s1*s2*s3, s2*s3, s3, sym::kSymbolOne}; - - auto merge_axis = graph.MergeAxis({s3_axis.id, s1_axis.id}); - EXPECT_NE(merge_axis, nullptr); - EXPECT_EQ(merge_axis->type, Axis::kAxisTypeMerged); - - auto data_node_find = graph.FindNode("data"); - EXPECT_NE(data_node_find, nullptr); - graph.ApplyTensorAxisMerge(data_node_find, merge_axis->id, {s3_axis.id, s1_axis.id}); - EXPECT_EQ(data_node_find->outputs[0].attr.axis.size(), 4); - EXPECT_EQ(data_node_find->outputs[0].attr.axis[0], s0_axis.id); - EXPECT_EQ(data_node_find->outputs[0].attr.axis[1], s1_axis.id); - EXPECT_EQ(data_node_find->outputs[0].attr.axis[2], s2_axis.id); - EXPECT_EQ(data_node_find->outputs[0].attr.axis[3], s3_axis.id); - } - - { - // 未触发合轴 - AscGraph graph("test_graph"); - const Expression s0 = graph.CreateSizeVar("s0"); - Axis &s0_axis = graph.CreateAxis("S0", s0); - const Expression s1 = graph.CreateSizeVar("s1"); - Axis &s1_axis = graph.CreateAxis("S1", s1); - const Expression s2 = graph.CreateSizeVar("s2"); - Axis &s2_axis = graph.CreateAxis("S2", s2); - const Expression s3 = graph.CreateSizeVar("s3"); - Axis &s3_axis = graph.CreateAxis("S3", s3); - const Expression s4 = graph.CreateSizeVar("s4"); - Axis &s4_axis = graph.CreateAxis("S4", s4); - const Expression s5 = graph.CreateSizeVar("s5"); - Axis &s5_axis = graph.CreateAxis("S5", s5); - ascir_op::Data data("data", graph); - auto data_node = ge::NodeUtilsEx::GetNodeFromOperator(data); - EXPECT_NE(data_node, nullptr); - data.attr.sched.exec_order = 1; - data.attr.sched.axis = {s0_axis.id, s1_axis.id, s2_axis.id, s3_axis.id}; - data.y.dtype = ge::DT_FLOAT16; - data.y.format = ge::FORMAT_ND; - *data.y.axis = {s0_axis.id, s1_axis.id, s2_axis.id, s3_axis.id}; - *data.y.repeats = {s0, s1, s2, s3}; - *data.y.strides = {s1*s2*s3, s2*s3, s3, sym::kSymbolOne}; - - auto merge_axis = graph.MergeAxis({s4_axis.id, s5_axis.id}); - EXPECT_NE(merge_axis, nullptr); - EXPECT_EQ(merge_axis->type, Axis::kAxisTypeMerged); - - auto data_node_find = graph.FindNode("data"); - EXPECT_NE(data_node_find, nullptr); - graph.ApplyTensorAxisMerge(data_node_find, merge_axis->id, {s4_axis.id, s5_axis.id}); - EXPECT_EQ(data_node_find->outputs[0].attr.axis.size(), 4); - EXPECT_EQ(data_node_find->outputs[0].attr.axis[0], s0_axis.id); - EXPECT_EQ(data_node_find->outputs[0].attr.axis[1], s1_axis.id); - EXPECT_EQ(data_node_find->outputs[0].attr.axis[2], s2_axis.id); - EXPECT_EQ(data_node_find->outputs[0].attr.axis[3], s3_axis.id); - } -} - -TEST_F(UtestAscendCIR, ApplyReorder) { - AscGraph graph("test_graph"); - const Expression s0 = graph.CreateSizeVar("s0"); - Axis &s0_axis = graph.CreateAxis("S0", s0); - const Expression s1 = graph.CreateSizeVar("s1"); - Axis &s1_axis = graph.CreateAxis("S1", s1); - ascir_op::Data data("data", graph); - auto data_node = ge::NodeUtilsEx::GetNodeFromOperator(data); - EXPECT_NE(data_node, nullptr); - data.attr.sched.exec_order = 1; - data.attr.sched.axis = {s0_axis.id, s1_axis.id}; - data.y.dtype = ge::DT_FLOAT16; - data.y.format = ge::FORMAT_ND; - *data.y.axis = {s0_axis.id, s1_axis.id}; - *data.y.repeats = {s0, s1}; - *data.y.strides = {s1, sym::kSymbolOne}; - - auto data_node_find = graph.FindNode("data"); - EXPECT_NE(data_node_find, nullptr); - graph.ApplyReorder(data_node_find, {s1_axis.id, s0_axis.id}); -} - -TEST_F(UtestAscendCIR, ApplyReorder_Sched_Tensor) { - { - AscGraph graph("test_graph"); - const Expression s0 = graph.CreateSizeVar("s0"); - Axis &s0_axis = graph.CreateAxis("S0", s0); - const Expression s1 = graph.CreateSizeVar("s1"); - Axis &s1_axis = graph.CreateAxis("S1", s1); - ascir_op::Data data("data", graph); - auto data_node = ge::NodeUtilsEx::GetNodeFromOperator(data); - EXPECT_NE(data_node, nullptr); - data.attr.sched.exec_order = 1; - data.attr.sched.axis = {s0_axis.id, s1_axis.id}; - data.y.dtype = ge::DT_FLOAT16; - data.y.format = ge::FORMAT_ND; - *data.y.axis = {s0_axis.id, s1_axis.id}; - *data.y.repeats = {s0, s1}; - *data.y.strides = {s1, sym::kSymbolOne}; - - auto data_node_find = graph.FindNode("data"); - EXPECT_NE(data_node_find, nullptr); - graph.ApplySchedAxisReorder(data_node_find, {s1_axis.id, s0_axis.id}); - } - - { - AscGraph graph("test_graph"); - const Expression s0 = graph.CreateSizeVar("s0"); - Axis &s0_axis = graph.CreateAxis("S0", s0); - const Expression s1 = graph.CreateSizeVar("s1"); - Axis &s1_axis = graph.CreateAxis("S1", s1); - ascir_op::Data data("data", graph); - auto data_node = ge::NodeUtilsEx::GetNodeFromOperator(data); - EXPECT_NE(data_node, nullptr); - data.attr.sched.exec_order = 1; - data.attr.sched.axis = {s0_axis.id, s1_axis.id}; - data.y.dtype = ge::DT_FLOAT16; - data.y.format = ge::FORMAT_ND; - *data.y.axis = {s0_axis.id, s1_axis.id}; - *data.y.repeats = {s0, s1}; - *data.y.strides = {s1, sym::kSymbolOne}; - - auto data_node_find = graph.FindNode("data"); - EXPECT_NE(data_node_find, nullptr); - graph.ApplyTensorAxisReorder(data_node_find, {s1_axis.id, s0_axis.id}); - } -} - -TEST_F(UtestAscendCIR, TryApplyReplace) { - AscGraph graph("test_graph"); - const Expression s0 = graph.CreateSizeVar("s0"); - const Expression s1 = graph.CreateSizeVar("s1"); - - Axis &s0_axis = graph.CreateAxis("S0", s0); - Axis &s1_axis = graph.CreateAxis("S1", s1); - Axis &s1_new_axis = graph.CreateAxis("s1_new", s1); - - ascir_op::Data data("data", graph); - auto data_node = ge::NodeUtilsEx::GetNodeFromOperator(data); - EXPECT_NE(data_node, nullptr); - data.attr.sched.exec_order = 1; - data.attr.sched.axis = {s0_axis.id, s1_axis.id}; - data.y.dtype = ge::DT_FLOAT16; - data.y.format = ge::FORMAT_ND; - *data.y.axis = {s0_axis.id, s1_axis.id}; - *data.y.repeats = {s0, s1}; - *data.y.strides = {s1, sym::kSymbolOne}; - - auto data_node_find = graph.FindNode("data"); - ASSERT_NE(data_node_find, nullptr); - EXPECT_EQ(graph.TryApplyAxisReplace(data_node_find, s1_new_axis, s1_axis), false); - EXPECT_EQ(data.attr.sched.axis.size(), 2UL); - EXPECT_EQ(data.attr.sched.axis[0], s0_axis.id); - EXPECT_EQ(data.attr.sched.axis[1], s1_axis.id); - EXPECT_EQ(data.y.axis->size(), 2UL); - EXPECT_EQ((*data.y.axis)[0], s0_axis.id); - EXPECT_EQ((*data.y.axis)[1], s1_axis.id); - - EXPECT_EQ(data.attr.sched.axis[0], s0_axis.id); - EXPECT_EQ(graph.TryApplyAxisReplace(data_node_find, s1_axis, s1_new_axis), true); - EXPECT_EQ(data.attr.sched.axis[0], s0_axis.id); - EXPECT_EQ(data.attr.sched.axis[1], s1_new_axis.id); - EXPECT_EQ(data.y.axis->size(), 2UL); - EXPECT_EQ((*data.y.axis)[0], s0_axis.id); - EXPECT_EQ((*data.y.axis)[1], s1_new_axis.id); -} - -TEST_F(UtestAscendCIR, Operator_OK) { - AscGraph graph("test_graph"); - Expression s0 = graph.CreateSizeVar("s0"); - Axis &s0_axis = graph.CreateAxis("S0", s0); - - ascir_op::Data data("data", graph); - auto data_node = ge::NodeUtilsEx::GetNodeFromOperator(data); - EXPECT_NE(data_node, nullptr); - data.attr.sched.exec_order = 1; - data.attr.sched.axis = {s0_axis.id}; - data.y.dtype = ge::DT_FLOAT16; - data.y.format = ge::FORMAT_ND; - *data.y.axis = {s0_axis.id}; - *data.y.repeats = {s0}; - *data.y.strides = {sym::kSymbolOne}; - - ascir_op::Abs abs("abs"); - auto abs_node = ge::NodeUtilsEx::GetNodeFromOperator(abs); - EXPECT_EQ(abs_node, nullptr); - abs.x = data.y; - // invalid case - abs.x = AscOpOutput(); - abs_node = ge::NodeUtilsEx::GetNodeFromOperator(abs); - EXPECT_NE(abs_node, nullptr); - - // find Node - auto data_node_find = graph.FindNode("data"); - EXPECT_NE(data_node_find, nullptr); - EXPECT_EQ(data_node_find->attr.sched.exec_order, 1); - EXPECT_EQ(data_node_find->attr.sched.axis.size(), 1U); - EXPECT_EQ(data_node_find->attr.sched.axis[0], s0_axis.id); - EXPECT_EQ(data_node_find->outputs[0].attr.axis.size(), 1); - EXPECT_EQ(ge::DataType(data_node_find->outputs[0].attr.dtype), ge::DT_FLOAT16); - data_node_find->outputs[0].attr.dtype = ge::DT_FLOAT; - EXPECT_EQ(ge::DataType(data_node_find->outputs[0].attr.dtype), ge::DT_FLOAT); - auto abs_node_find = graph.FindNode("abs"); - EXPECT_NE(abs_node_find, nullptr); - - // GetAllNodes - int num = 0; - for (const auto &node : graph.GetAllNodes()) { - if (num == 0) { - EXPECT_EQ(node->GetName(), "data"); - EXPECT_EQ(node->attr.sched.exec_order, 1); - EXPECT_EQ(node->attr.sched.axis.size(), 1U); - EXPECT_EQ(node->attr.sched.axis[0], s0_axis.id); - EXPECT_EQ(node->outputs[0].attr.axis.size(), 1); - const auto outputs = node->outputs(); - EXPECT_EQ(outputs.size(), 1U); - EXPECT_NE(outputs[0], nullptr); - EXPECT_EQ(outputs[0]->attr.axis.size(), 1); - } - if (num == 1) { - EXPECT_EQ(node->inputs.Size(), 1U); - EXPECT_EQ(node->inputs[0].attr.axis.size(), 1); - } - num++; - } - EXPECT_EQ(num, 2); - - // GetAllNodes - int input_nodes_num = 0; - for (auto node : graph.GetInputNodes()) { - if (input_nodes_num == 0) { - EXPECT_EQ(node->GetName(), "data"); - EXPECT_EQ(node->attr.sched.exec_order, 1); - EXPECT_EQ(node->attr.sched.axis.size(), 1U); - EXPECT_EQ(node->attr.sched.axis[0], s0_axis.id); - EXPECT_EQ(node->outputs[0].attr.axis.size(), 1); - } - input_nodes_num++; - } - EXPECT_EQ(input_nodes_num, 1); - EXPECT_EQ(graph.GetName(), "test_graph"); - - // GetAllAxis - const AscGraph &const_graph = graph; - const auto all_axis = const_graph.GetAllAxis(); - EXPECT_EQ(all_axis.size(), 1U); -} - -TEST_F(UtestAscendCIR, Operator_Fail) { - ascir_op::Abs abs("abs"); - ascir_op::Output output("output"); - output.x = abs.y; - EXPECT_TRUE(ge::NodeUtilsEx::GetNodeFromOperator(abs) == nullptr); - EXPECT_TRUE(ge::NodeUtilsEx::GetNodeFromOperator(output) == nullptr); -} - -void Add_Layer_Norm_Normal_BeforeAutofuse(AscGraph &graph) { - auto ONE = sym::kSymbolOne; - auto ZERO = sym::kSymbolZero; - // 定义轴的大小 - auto A = Symbol("A"); - auto R = Symbol("R"); - auto BL = Symbol(8, "BL"); - - // 定义轴 - auto a = graph.CreateAxis("A", A); - auto r = graph.CreateAxis("R", R); - auto bl = graph.CreateAxis("BL", BL); - - // 定义节点 - int exec_order = 0; - Data x1("x1", graph); - x1.attr.sched.exec_order = exec_order++; - x1.attr.sched.axis = {a.id, r.id, bl.id}; - x1.y.dtype = ge::DT_FLOAT16; - *x1.y.axis = {a.id, r.id, bl.id}; - *x1.y.repeats = {A, R, ONE}; - *x1.y.strides = {R, ONE, ZERO}; - - Load x1Local("x1Local"); - x1Local.x = x1.y; - x1Local.attr.sched.exec_order = exec_order++; - x1Local.attr.sched.axis = {a.id, r.id, bl.id}; - x1Local.y.dtype = ge::DT_FLOAT16; - *x1Local.y.axis = {a.id, r.id, bl.id}; - *x1Local.y.repeats = {A, R, ONE}; - *x1Local.y.strides = {R, ONE, ZERO}; - - Data x2("x2", graph); - x2.attr.sched.exec_order = exec_order++; - x2.attr.sched.axis = {a.id, r.id, bl.id}; - x2.y.dtype = ge::DT_FLOAT16; - *x2.y.axis = {a.id, r.id, bl.id}; - *x2.y.repeats = {A, R, ONE}; - *x2.y.strides = {R, ONE, ZERO}; - - Load x2Local("x2Local"); - x2Local.x = x2.y; - x2Local.attr.sched.exec_order = exec_order++; - x2Local.attr.sched.axis = {a.id, r.id, bl.id}; - x2Local.y.dtype = ge::DT_FLOAT16; - *x2Local.y.axis = {a.id, r.id, bl.id}; - *x2Local.y.repeats = {A, R, ONE}; - *x2Local.y.strides = {R, ONE, ZERO}; - - Data bias("bias", graph); - bias.attr.sched.exec_order = exec_order++; - bias.attr.sched.axis = {a.id, r.id, bl.id}; - bias.y.dtype = ge::DT_FLOAT16; - *bias.y.axis = {a.id, r.id, bl.id}; - *bias.y.repeats = {A, R, ONE}; - *bias.y.strides = {R, ONE, ZERO}; - - Load biasLocal("biasLocal"); - biasLocal.x = bias.y; - biasLocal.attr.sched.exec_order = exec_order++; - biasLocal.attr.sched.axis = {a.id, r.id, bl.id}; - biasLocal.y.dtype = ge::DT_FLOAT16; - *biasLocal.y.axis = {a.id, r.id, bl.id}; - *biasLocal.y.repeats = {A, R, ONE}; - *biasLocal.y.strides = {R, ONE, ZERO}; - - CalcMean mean("mean"); - mean.x1 = x1Local.y; - mean.x2 = x2Local.y; - mean.x3 = biasLocal.y; - mean.attr.sched.exec_order = exec_order++; - mean.attr.sched.axis = {a.id, r.id, bl.id}; - mean.y1.dtype = ge::DT_FLOAT; // mean - *mean.y1.axis = {a.id, r.id, bl.id}; - *mean.y1.repeats = {A, ONE, ONE}; - *mean.y1.strides = {ONE, ZERO, ZERO}; - mean.y2.dtype = ge::DT_FLOAT16; // x out - *mean.y2.axis = {a.id, r.id, bl.id}; - *mean.y2.repeats = {A, R, ONE}; - *mean.y2.strides = {R, ONE, ZERO}; - mean.y3.dtype = ge::DT_FLOAT; // x fp32 - *mean.y3.axis = {a.id, r.id, bl.id}; - *mean.y3.repeats = {A, R, ONE}; - *mean.y3.strides = {R, ONE, ZERO}; - Store x_out("x_out"); - x_out.attr.sched.exec_order = exec_order++; - x_out.attr.sched.axis = {a.id, r.id, bl.id}; - x_out.x = mean.y2; - x_out.y.dtype = ge::DT_FLOAT16; - *x_out.y.axis = {a.id, r.id, bl.id}; - *x_out.y.repeats = {A, R, ONE}; - *x_out.y.strides = {R, ONE, ZERO}; - - Store mean_out("mean_out"); - mean_out.attr.sched.exec_order = exec_order++; - mean_out.attr.sched.axis = {a.id, r.id, bl.id}; - mean_out.x = mean.y1; - mean_out.y.dtype = ge::DT_FLOAT; - *mean_out.y.axis = {a.id, r.id, bl.id}; - *mean_out.y.repeats = {A, ONE, ONE}; - *mean_out.y.strides = {ONE, ZERO, ZERO}; - - TbufData one("one", graph); - one.attr.sched.exec_order = exec_order++; - one.attr.sched.axis = {a.id, r.id, bl.id}; - one.y.dtype = ge::DT_FLOAT; - *one.y.axis = {a.id, r.id, bl.id}; - *one.y.repeats = {ONE, ONE, BL}; - *one.y.strides = {ZERO, ZERO, ONE}; - - CalcRstd rstd("rstd"); - rstd.attr.sched.exec_order = exec_order++; - rstd.attr.sched.axis = {a.id, r.id, bl.id}; - rstd.x1 = mean.y3; - rstd.x2 = mean.y1; - rstd.x3 = one.y; - rstd.y1.dtype = ge::DT_FLOAT; // x-mean - *rstd.y1.axis = {a.id, r.id, bl.id}; - *rstd.y1.repeats = {A, R, ONE}; - *rstd.y1.strides = {R, ONE, ZERO}; - rstd.y2.dtype = ge::DT_FLOAT; // rstd - *rstd.y2.axis = {a.id, r.id, bl.id}; - *rstd.y2.repeats = {A, ONE, ONE}; - *rstd.y2.strides = {ONE, ZERO, ZERO}; - - Store rstd_out("rstd_out"); - rstd_out.attr.sched.exec_order = exec_order++; - rstd_out.attr.sched.axis = {a.id, r.id, bl.id}; - rstd_out.x = rstd.y2; - rstd_out.y.dtype = ge::DT_FLOAT; - *rstd_out.y.axis = {a.id, r.id, bl.id}; - *rstd_out.y.repeats = {A, ONE, ONE}; - *rstd_out.y.strides = {ONE, ZERO, ZERO}; - - Data beta("beta", graph); - beta.attr.sched.exec_order = exec_order++; - beta.attr.sched.axis = {a.id, r.id, bl.id}; - beta.y.dtype = ge::DT_FLOAT16; - *beta.y.axis = {a.id, r.id, bl.id}; - *beta.y.repeats = {ONE, R, ONE}; - *beta.y.strides = {ZERO, ONE, ZERO}; - - Load betaLocal("betaLocal"); - betaLocal.x = beta.y; - betaLocal.attr.sched.exec_order = exec_order++; - betaLocal.attr.sched.axis = {a.id, r.id, bl.id}; - betaLocal.y.dtype = ge::DT_FLOAT16; - *betaLocal.y.axis = {a.id, r.id, bl.id}; - *betaLocal.y.repeats = {ONE, R, ONE}; - *betaLocal.y.strides = {ZERO, ONE, ZERO}; - - Data gamma("gamma", graph); - gamma.attr.sched.exec_order = exec_order++; - gamma.attr.sched.axis = {a.id, r.id, bl.id}; - gamma.y.dtype = ge::DT_FLOAT16; - *gamma.y.axis = {a.id, r.id, bl.id}; - *gamma.y.repeats = {ONE, R, ONE}; - *gamma.y.strides = {ZERO, ONE, ZERO}; - - Load gammaLocal("gammaLocal"); - gammaLocal.x = gamma.y; - gammaLocal.attr.sched.exec_order = exec_order++; - gammaLocal.attr.sched.axis = {a.id, r.id, bl.id}; - gammaLocal.y.dtype = ge::DT_FLOAT16; - *gammaLocal.y.axis = {a.id, r.id, bl.id}; - *gammaLocal.y.repeats = {ONE, R, ONE}; - *gammaLocal.y.strides = {ZERO, ONE, ZERO}; - - CalcY y("y"); - y.attr.sched.exec_order = exec_order++; - y.attr.sched.axis = {a.id, r.id, bl.id}; - y.x1 = rstd.y1; // x-mean - y.x2 = betaLocal.y; - y.x3 = gammaLocal.y; - y.x4 = rstd.y2; // rstd - y.y1.dtype = ge::DT_FLOAT16; - *y.y1.axis = {a.id, r.id, bl.id}; - *y.y1.repeats = {A, R, ONE}; - *y.y1.strides = {R, ONE, ZERO}; - - Store y_out("y_out"); - y_out.attr.sched.exec_order = exec_order++; - y_out.attr.sched.axis = {a.id, r.id, bl.id}; - y_out.x = y.y1; - y_out.y.dtype = ge::DT_FLOAT16; - *y_out.y.axis = {a.id, r.id, bl.id}; - *y_out.y.repeats = {A, R, ONE}; - *y_out.y.strides = {R, ONE, ZERO}; - - Output buf1("buf1"); - buf1.x = x_out.y; - buf1.attr.sched.exec_order = exec_order++; - buf1.y.dtype = ge::DT_FLOAT16; - *buf1.y.axis = {a.id, r.id, bl.id}; - *buf1.y.repeats = {A, R, ONE}; - *buf1.y.strides = {R, ONE, ZERO}; - - Output buf2("buf2"); - buf2.x = mean_out.y; - buf2.attr.sched.exec_order = exec_order++; - buf2.y.dtype = ge::DT_FLOAT; - *buf2.y.axis = {a.id, r.id, bl.id}; - *buf2.y.repeats = {A, ONE, ONE}; - *buf2.y.strides = {ONE, ZERO, ZERO}; - - Output buf3("buf3"); - buf3.x = rstd_out.y; - buf3.attr.sched.exec_order = exec_order++; - buf3.y.dtype = ge::DT_FLOAT; - *buf3.y.axis = {a.id, r.id, bl.id}; - *buf3.y.repeats = {A, ONE, ONE}; - *buf3.y.strides = {ONE, ZERO, ZERO}; - - Output buf("buf"); - buf.x = y_out.y; - buf.attr.sched.exec_order = exec_order++; - buf.y.dtype = ge::DT_FLOAT16; - *buf.y.axis = {a.id, r.id, bl.id}; - *buf.y.repeats = {A, R, ONE}; - *buf.y.strides = {R, ONE, ZERO}; -} - -/* -for aBO - for aBIO - for aBII - for r - load x1 - load x2 - load bias - CalcMean - CalcRstd - Store X - Store mean - Load beta - Load gamma - CalcRstd - Store rstd - CalcY - Store y -*/ - -void Add_Layer_Norm_Normal_AfterScheduler(AscGraph &graph) { - auto a = graph.FindAxis(0)->id; - auto r = graph.FindAxis(1)->id; - - auto [aBO, aBI] = graph.BlockSplit(a, "nbi", "nbo"); // AB Ab - auto [aBIO, aBII] = graph.TileSplit(aBI->id, "nii", "nio"); // AbT Abt - // graph.UpdateAxisAlign(aBI.id, 1u); - // graph.UpdateAxisAlign(aBII.id, 8u); - auto x1 = graph.FindNode("x1"); - graph.ApplySplit(x1, aBO->id, aBI->id); - graph.ApplySplit(x1, aBIO->id, aBII->id); - x1->attr.sched.loop_axis = aBIO->id; - x1->outputs[0].attr.vectorized_axis = {aBII->id, r}; - - auto x2 = graph.FindNode("x2"); - graph.ApplySplit(x2, aBO->id, aBI->id); - graph.ApplySplit(x2, aBIO->id, aBII->id); - x2->attr.sched.loop_axis = aBIO->id; - x2->outputs[0].attr.vectorized_axis = {aBII->id, r}; - - auto bias = graph.FindNode("bias"); - graph.ApplySplit(bias, aBO->id, aBI->id); - graph.ApplySplit(bias, aBIO->id, aBII->id); - bias->attr.sched.loop_axis = aBIO->id; - bias->outputs[0].attr.vectorized_axis = {aBII->id, r}; - - auto x1Local = graph.FindNode("x1Local"); - graph.ApplySplit(x1Local, aBO->id, aBI->id); - graph.ApplySplit(x1Local, aBIO->id, aBII->id); - x1Local->attr.sched.loop_axis = aBIO->id; - x1Local->outputs[0].attr.vectorized_axis = {aBII->id, r}; - - auto x2Local = graph.FindNode("x2Local"); - graph.ApplySplit(x2Local, aBO->id, aBI->id); - graph.ApplySplit(x2Local, aBIO->id, aBII->id); - x2Local->attr.sched.loop_axis = aBIO->id; - x2Local->outputs[0].attr.vectorized_axis = {aBII->id, r}; - - auto biasLocal = graph.FindNode("biasLocal"); - graph.ApplySplit(biasLocal,aBO->id, aBI->id); - graph.ApplySplit(biasLocal, aBIO->id, aBII->id); - biasLocal->attr.sched.loop_axis = aBIO->id; - biasLocal->outputs[0].attr.vectorized_axis = {aBII->id, r}; - - auto mean = graph.FindNode("mean"); - graph.ApplySplit(mean,aBO->id, aBI->id); - graph.ApplySplit(mean,aBIO->id, aBII->id); - mean->attr.sched.loop_axis = aBIO->id; - mean->outputs[0].attr.vectorized_axis = {aBII->id, r}; - mean->outputs[1].attr.vectorized_axis = {aBII->id, r}; - mean->outputs[2].attr.vectorized_axis = {aBII->id, r}; - - auto x_out = graph.FindNode("x_out"); - graph.ApplySplit(x_out, aBO->id, aBI->id); - graph.ApplySplit(x_out, aBIO->id, aBII->id); - x_out->attr.sched.loop_axis = aBIO->id; - x_out->outputs[0].attr.vectorized_axis = {aBII->id, r}; - - auto mean_out = graph.FindNode("mean_out"); - graph.ApplySplit(mean_out, aBO->id, aBI->id); - graph.ApplySplit(mean_out, aBIO->id, aBII->id); - mean_out->attr.sched.loop_axis = aBIO->id; - mean_out->outputs[0].attr.vectorized_axis = {aBII->id, r}; - - auto rstd = graph.FindNode("rstd"); - graph.ApplySplit(rstd,aBO->id, aBI->id); - graph.ApplySplit(rstd,aBIO->id, aBII->id); - rstd->attr.sched.loop_axis = aBIO->id; - rstd->outputs[0].attr.vectorized_axis = {aBII->id, r}; - rstd->outputs[1].attr.vectorized_axis = {aBII->id, r}; - - auto rstd_out = graph.FindNode("rstd_out"); - graph.ApplySplit(rstd_out,aBO->id, aBI->id); - graph.ApplySplit(rstd_out,aBIO->id, aBII->id); - rstd_out->attr.sched.loop_axis = aBIO->id; - rstd_out->outputs[0].attr.vectorized_axis = {aBII->id, r}; - - auto betaLocal = graph.FindNode("betaLocal"); - graph.ApplySplit(betaLocal,aBO->id, aBI->id); - graph.ApplySplit(betaLocal,aBIO->id, aBII->id); - betaLocal->attr.sched.loop_axis = aBIO->id; - betaLocal->outputs[0].attr.vectorized_axis = {r}; - - auto gammaLocal = graph.FindNode("gammaLocal"); - graph.ApplySplit(gammaLocal,aBO->id, aBI->id); - graph.ApplySplit(gammaLocal,aBIO->id, aBII->id); - gammaLocal->attr.sched.loop_axis = aBIO->id; - gammaLocal->outputs[0].attr.vectorized_axis = {r}; - - auto y = graph.FindNode("y"); - graph.ApplySplit(y,aBO->id, aBI->id); - graph.ApplySplit(y,aBIO->id, aBII->id); - y->attr.sched.loop_axis = aBIO->id; - y->outputs[0].attr.vectorized_axis = {aBII->id, r}; - - auto y_out = graph.FindNode("y_out"); - graph.ApplySplit(y_out,aBO->id, aBI->id); - graph.ApplySplit(y_out,aBIO->id, aBII->id); - y_out->attr.sched.loop_axis = aBIO->id; - y_out->outputs[0].attr.vectorized_axis = {aBII->id, r}; -} - -void Add_Layer_Norm_Normal_AfterQueBufAlloc(AscGraph &graph) { - int tensorID = 0; - int queID = 0; - int bufID = 0; - int x1Que = queID++; - int x2Que = queID++; - int biasQue = queID++; - int gammaQue = queID++; - int betaQue = queID++; - int meanQue = queID++; - int rstdQue = queID++; - int yQue = queID++; - int xQue = queID++; - int x32Queue = queID++; - int oneTBuf = bufID++; - - auto x1 = graph.FindNode("x1"); - x1->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareGM; - x1->outputs[0].attr.mem.position = ge::Position::kPositionGM; - - auto x2 = graph.FindNode("x2"); - x2->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareGM; - x2->outputs[0].attr.mem.position = ge::Position::kPositionGM; - - auto bias = graph.FindNode("bias"); - bias->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareGM; - bias->outputs[0].attr.mem.position = ge::Position::kPositionGM; - - auto x1Local = graph.FindNode("x1Local"); - x1Local->outputs[0].attr.mem.tensor_id = tensorID++; - x1Local->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeQueue; - x1Local->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - x1Local->outputs[0].attr.mem.position = ge::Position::kPositionVecIn; - x1Local->outputs[0].attr.buf.id = ID_NONE; - x1Local->outputs[0].attr.que.id = x1Que; - x1Local->outputs[0].attr.que.depth = 1; - x1Local->outputs[0].attr.que.buf_num = 1; - x1Local->outputs[0].attr.opt.ref_tensor = ID_NONE; - x1Local->outputs[0].attr.opt.merge_scope = ID_NONE; - - auto x2Local = graph.FindNode("x2Local"); - x2Local->outputs[0].attr.mem.tensor_id = tensorID++; - x2Local->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeQueue; - x2Local->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - x2Local->outputs[0].attr.mem.position = ge::Position::kPositionVecIn; - x2Local->outputs[0].attr.buf.id = ID_NONE; - x2Local->outputs[0].attr.que.id = x2Que; - x2Local->outputs[0].attr.que.depth = 1; - x2Local->outputs[0].attr.que.buf_num = 1; - x2Local->outputs[0].attr.opt.ref_tensor = ID_NONE; - x2Local->outputs[0].attr.opt.merge_scope = ID_NONE; - - auto biasLocal = graph.FindNode("biasLocal"); - biasLocal->outputs[0].attr.mem.tensor_id = tensorID++; - biasLocal->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeQueue; - biasLocal->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - biasLocal->outputs[0].attr.mem.position = ge::Position::kPositionVecIn; - biasLocal->outputs[0].attr.buf.id = ID_NONE; - biasLocal->outputs[0].attr.que.id = biasQue; - biasLocal->outputs[0].attr.que.depth = 1; - biasLocal->outputs[0].attr.que.buf_num = 1; - biasLocal->outputs[0].attr.opt.ref_tensor = ID_NONE; - biasLocal->outputs[0].attr.opt.merge_scope = ID_NONE; - - auto mean = graph.FindNode("mean"); - mean->outputs[0].attr.mem.tensor_id = tensorID++; - mean->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeQueue; - mean->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - mean->outputs[0].attr.mem.position = ge::Position::kPositionVecOut; - mean->outputs[0].attr.buf.id = ID_NONE; - mean->outputs[0].attr.que.id = meanQue; - mean->outputs[0].attr.que.depth = 1; - mean->outputs[0].attr.que.buf_num = 1; - mean->outputs[0].attr.opt.ref_tensor = ID_NONE; - mean->outputs[0].attr.opt.merge_scope = ID_NONE; - mean->outputs[1].attr.mem.tensor_id = tensorID++; - mean->outputs[1].attr.mem.alloc_type = ge::AllocType::kAllocTypeQueue; - mean->outputs[1].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - mean->outputs[1].attr.mem.position = ge::Position::kPositionVecOut; - mean->outputs[1].attr.buf.id = ID_NONE; - mean->outputs[1].attr.que.id = xQue; - mean->outputs[1].attr.que.depth = 1; - mean->outputs[1].attr.que.buf_num = 1; - mean->outputs[1].attr.opt.ref_tensor = ID_NONE; - mean->outputs[1].attr.opt.merge_scope = ID_NONE; - mean->outputs[2].attr.mem.tensor_id = tensorID++; - mean->outputs[2].attr.mem.alloc_type = ge::AllocType::kAllocTypeQueue; - mean->outputs[2].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - mean->outputs[2].attr.mem.position = ge::Position::kPositionVecOut; - mean->outputs[2].attr.buf.id = ID_NONE; - mean->outputs[2].attr.que.id = x32Queue; - mean->outputs[2].attr.que.depth = 1; - mean->outputs[2].attr.que.buf_num = 1; - mean->outputs[2].attr.opt.ref_tensor = ID_NONE; - mean->outputs[2].attr.opt.merge_scope = ID_NONE; - - auto x_out = graph.FindNode("x_out"); - x_out->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareGM; - x_out->outputs[0].attr.mem.position = ge::Position::kPositionGM; - - auto mean_out = graph.FindNode("mean_out"); - mean_out->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareGM; - mean_out->outputs[0].attr.mem.position = ge::Position::kPositionGM; - - auto one = graph.FindNode("one"); - one->outputs[0].attr.mem.tensor_id = tensorID++; - one->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeBuffer; - one->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - one->outputs[0].attr.mem.position = ge::Position::kPositionVecIn; - one->outputs[0].attr.buf.id = oneTBuf; - one->outputs[0].attr.que.id = ID_NONE; - one->outputs[0].attr.que.depth = ID_NONE; - one->outputs[0].attr.que.buf_num = ID_NONE; - one->outputs[0].attr.opt.ref_tensor = ID_NONE; - one->outputs[0].attr.opt.merge_scope = ID_NONE; - - auto rstd = graph.FindNode("rstd"); - rstd->outputs[0].attr.mem.tensor_id = tensorID++; - rstd->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeQueue; - rstd->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - rstd->outputs[0].attr.mem.position = ge::Position::kPositionVecOut; - rstd->outputs[0].attr.buf.id =ID_NONE; - rstd->outputs[0].attr.que.id = yQue; - rstd->outputs[0].attr.que.depth = 1; - rstd->outputs[0].attr.que.buf_num = 1; - rstd->outputs[0].attr.opt.ref_tensor = ID_NONE; - rstd->outputs[0].attr.opt.merge_scope = ID_NONE; - rstd->outputs[1].attr.mem.tensor_id = tensorID++; - rstd->outputs[1].attr.mem.alloc_type = ge::AllocType::kAllocTypeQueue; - rstd->outputs[1].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - rstd->outputs[1].attr.mem.position = ge::Position::kPositionVecOut; - rstd->outputs[1].attr.buf.id = ID_NONE; - rstd->outputs[1].attr.que.id = rstdQue; - rstd->outputs[1].attr.que.depth = 1; - rstd->outputs[1].attr.que.buf_num = 1; - rstd->outputs[1].attr.opt.ref_tensor = ID_NONE; - rstd->outputs[1].attr.opt.merge_scope = ID_NONE; - - auto rstd_out = graph.FindNode("rstd_out"); - rstd_out->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareGM; - rstd_out->outputs[0].attr.mem.position = ge::Position::kPositionGM; - - auto beta = graph.FindNode("beta"); - beta->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareGM; - beta->outputs[0].attr.mem.position = ge::Position::kPositionGM; - - auto betaLocal = graph.FindNode("betaLocal"); - betaLocal->outputs[0].attr.mem.tensor_id = tensorID++; - betaLocal->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeQueue; - betaLocal->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - betaLocal->outputs[0].attr.mem.position = ge::Position::kPositionVecIn; - betaLocal->outputs[0].attr.buf.id = ID_NONE; - betaLocal->outputs[0].attr.que.id = betaQue; - betaLocal->outputs[0].attr.que.depth = 1; - betaLocal->outputs[0].attr.que.buf_num = 1; - betaLocal->outputs[0].attr.opt.ref_tensor = ID_NONE; - betaLocal->outputs[0].attr.opt.merge_scope = ID_NONE; - - auto gamma = graph.FindNode("gamma"); - gamma->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareGM; - gamma->outputs[0].attr.mem.position = ge::Position::kPositionGM; - - auto gammaLocal = graph.FindNode("gammaLocal"); - gammaLocal->outputs[0].attr.mem.tensor_id = tensorID++; - gammaLocal->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeQueue; - gammaLocal->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - gammaLocal->outputs[0].attr.mem.position = ge::Position::kPositionVecIn; - gammaLocal->outputs[0].attr.buf.id = ID_NONE; - gammaLocal->outputs[0].attr.que.id = gammaQue; - gammaLocal->outputs[0].attr.que.depth = 1; - gammaLocal->outputs[0].attr.que.buf_num = 1; - gammaLocal->outputs[0].attr.opt.ref_tensor = ID_NONE; - gammaLocal->outputs[0].attr.opt.merge_scope = ID_NONE; - - auto y = graph.FindNode("y"); - y->outputs[0].attr.mem.tensor_id = tensorID++; - y->outputs[0].attr.mem.alloc_type = ge::AllocType::kAllocTypeQueue; - y->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareUB; - y->outputs[0].attr.mem.position = ge::Position::kPositionVecOut; - y->outputs[0].attr.buf.id = ID_NONE; - y->outputs[0].attr.que.id = yQue; - y->outputs[0].attr.que.depth = 1; - y->outputs[0].attr.que.buf_num = 1; - y->outputs[0].attr.opt.ref_tensor = ID_NONE; - y->outputs[0].attr.opt.merge_scope = ID_NONE; - - auto y_out = graph.FindNode("y_out"); - y_out->outputs[0].attr.mem.hardware = ge::MemHardware::kMemHardwareGM; - y_out->outputs[0].attr.mem.position = ge::Position::kPositionGM; -} - -TEST_F(UtestAscendCIR, CheckValid) { - AscGraph graph_normal("graph_normal"); - graph_normal.SetTilingKey(1101u); - Add_Layer_Norm_Normal_BeforeAutofuse(graph_normal); - Add_Layer_Norm_Normal_AfterScheduler(graph_normal); - Add_Layer_Norm_Normal_AfterQueBufAlloc(graph_normal); - EXPECT_EQ(graph_normal.CheckValid(), true); - auto betaLocal = graph_normal.FindNode("betaLocal"); - betaLocal->outputs[0].attr.que.id = ID_NONE; - EXPECT_EQ(graph_normal.CheckValid(), false); -} -TEST_F(UtestAscendCIR, CreateContiguousData) { - AscGraph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto D = Symbol("d"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto d = graph.CreateAxis("d", D); - auto data = graph.CreateContiguousData("data0", ge::DT_INT8, {a}); - EXPECT_EQ(data.output_index, 0U); - EXPECT_EQ(static_cast(data.dtype), ge::DT_INT8); - EXPECT_EQ(data.format, ge::FORMAT_ND); - EXPECT_EQ(*data.axis, std::vector({a.id})); - const auto &attr = graph.FindNode("data0")->attr; - EXPECT_EQ(attr.sched.exec_order, 0U); - EXPECT_EQ(attr.sched.loop_axis, -1); - EXPECT_TRUE(attr.ir_attr != nullptr); - int64_t value_get{-1}; - EXPECT_EQ(attr.ir_attr->GetAttrValue("index", value_get), GRAPH_SUCCESS); - EXPECT_EQ(value_get, 0U); - auto data1 = graph.CreateContiguousData("data1", ge::DT_FLOAT, {a, b}, ge::FORMAT_DHWCN); - EXPECT_EQ(data1.output_index, 0U); - EXPECT_EQ(static_cast(data1.dtype), ge::DT_FLOAT); - EXPECT_EQ(data1.format, ge::FORMAT_DHWCN); - EXPECT_EQ(*data1.axis, std::vector({a.id, b.id})); - const auto &attr1 = graph.FindNode("data1")->attr; - EXPECT_EQ(attr1.sched.exec_order, 1U); - EXPECT_EQ(attr1.sched.loop_axis, -1); - EXPECT_TRUE(attr1.ir_attr != nullptr); - const auto &ir_attr1 = dynamic_cast(*attr1.ir_attr); - EXPECT_EQ(ir_attr1.GetIndex(value_get), GRAPH_SUCCESS); - EXPECT_EQ(value_get, 1U); -} - -TEST_F(UtestAscendCIR, CreateContiguousOut) { - AscGraph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto D = Symbol("d"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto d = graph.CreateAxis("d", D); - auto output1 = graph.CreateContiguousOut("out1", ge::DT_FLOAT16, {a, b}, ge::FORMAT_ND); - - EXPECT_EQ(output1.output_index, 0U); - EXPECT_EQ(static_cast(output1.dtype), ge::DT_FLOAT16); - EXPECT_EQ(output1.format, ge::FORMAT_ND); - const auto *attr1 = dynamic_cast< AscNodeAttr *>(&(graph.FindNode("out1")->attr)); - EXPECT_EQ(attr1->sched.exec_order, -1); - EXPECT_EQ(attr1->sched.loop_axis, -1); -} - -TEST_F(UtestAscendCIR, StoreToOut) { - AscGraph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto D = Symbol("d"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto d = graph.CreateAxis("d", D); - auto data1 = graph.CreateContiguousData("data1", ge::DT_FLOAT, {a, b}, ge::FORMAT_DHWCN); - auto output1 = graph.CreateContiguousOut("out1", ge::DT_FLOAT16, {a, b, c}, ge::FORMAT_ND); - auto load1 = ascir::cg::Load("load1", data1); - ascir::cg::Store("StoreLoad1ToOutput1", load1, output1); - EXPECT_EQ(output1.GetOwnerOp().GetInputsSize(), 1U); - EXPECT_EQ(output1.output_index, 0U); - EXPECT_EQ(static_cast(output1.dtype), ge::DT_FLOAT16); - EXPECT_EQ(output1.format, ge::FORMAT_ND); - EXPECT_EQ(*output1.axis, std::vector({a.id, b.id, c.id})); - const auto *attr1 = dynamic_cast< AscNodeAttr *>(&(graph.FindNode("out1")->attr)); - EXPECT_EQ(attr1->sched.exec_order, 3); - EXPECT_EQ(attr1->sched.loop_axis, -1); -} - -TEST_F(UtestAscendCIR, StoreWithOffset) { - AscGraph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto D = Symbol("d"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto d = graph.CreateAxis("d", D); - auto data1 = graph.CreateContiguousData("data1", ge::DT_FLOAT, {a, b}, ge::FORMAT_DHWCN); - auto output1 = graph.CreateContiguousOut("out1", ge::DT_FLOAT16, {a, b, c}, ge::FORMAT_ND); - auto load1 = ascir::cg::Load("load1", data1); - int64_t offset1 = 1024; - ascir::cg::Store("StoreLoad1ToOutput1", load1, output1, offset1); - EXPECT_EQ(output1.GetOwnerOp().GetInputsSize(), 1U); - auto store_node = graph.FindNode("StoreLoad1ToOutput1"); - EXPECT_NE(store_node->attr.ir_attr, nullptr); - int64_t offset_get{-1}; - EXPECT_EQ(store_node->attr.ir_attr->GetAttrValue("offset", offset_get), GRAPH_SUCCESS); - EXPECT_EQ(offset_get, offset1); - auto data2 = graph.CreateContiguousData("data1", ge::DT_FLOAT, {a, b}, ge::FORMAT_DHWCN); - auto load2 = ascir::cg::Load("load2", data2); - auto store2 = ascir_op::Store("Store2"); - store2.x = load2; - int64_t offset2 = 256; - // 设置属性 - store2.ir_attr.SetOffset(offset2); - auto store2_node = graph.FindNode("Store2"); - EXPECT_NE(store2_node->attr.ir_attr, nullptr); - // 获取的方式1,调用子类的函数 - EXPECT_EQ(dynamic_cast(store2_node->attr.ir_attr.get())->GetOffset(offset_get), - GRAPH_SUCCESS); - EXPECT_EQ(offset_get, offset2); - // 获取方式2,调用基类的函数 - EXPECT_EQ(store2_node->attr.ir_attr->GetAttrValue("offset", offset_get), GRAPH_SUCCESS); - EXPECT_EQ(offset_get, offset2); -} - -TEST_F(UtestAscendCIR, IrAttrTest) { - AscGraph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto D = Symbol("(1 + D)"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto d = graph.CreateAxis("d", D); - auto data1 = graph.CreateContiguousData("data1", ge::DT_FLOAT, {a, b}, ge::FORMAT_DHWCN); - auto stub_op1 = ascir_op::StubOp1("stub_op1"); - // input - stub_op1.x = data1; - // 通过op的方式设置attr - stub_op1.ir_attr.SetMy_float(0.1); - stub_op1.ir_attr.SetMy_int(1); - stub_op1.ir_attr.SetMy_string("stub_test"); - stub_op1.ir_attr.SetOffset(D); - auto node = graph.FindNode("stub_op1"); - EXPECT_NE(node, nullptr); - EXPECT_NE(node->attr.ir_attr, nullptr); - // 通过node的方式获取属性 - auto my_ir_attrs = dynamic_cast(node->attr.ir_attr.get()); - EXPECT_NE(my_ir_attrs, nullptr); - int64_t get_valuei; - float get_valuef; - std::string get_values; - Expression get_expression; - EXPECT_EQ(my_ir_attrs->GetMy_int(get_valuei), GRAPH_SUCCESS); - EXPECT_FLOAT_EQ(my_ir_attrs->GetMy_float(get_valuef), GRAPH_SUCCESS); - EXPECT_EQ(my_ir_attrs->GetMy_string(get_values), GRAPH_SUCCESS); - EXPECT_EQ(my_ir_attrs->GetOffset(get_expression), GRAPH_SUCCESS); - EXPECT_EQ(get_valuei, 1); - EXPECT_FLOAT_EQ(get_valuef, 0.1); - EXPECT_EQ(get_values, "stub_test"); - EXPECT_EQ(get_expression, D); - // 成员函数测试 - ascendc_ir::proto::AscIrAttrDef asc_ir_attr_def; - my_ir_attrs->Serialize(asc_ir_attr_def); - const std::string kExpected = R"PROTO(attr { - key: "my_float" - value { - f: 0.1 - } -} -attr { - key: "my_int" - value { - i: 1 - } -} -attr { - key: "my_string" - value { - s: "stub_test" - } -} -attr { - key: "offset" - value { - expression: "(1 + D)" - } -} -)PROTO"; - EXPECT_EQ(asc_ir_attr_def.DebugString(), kExpected); - ascir_op::StubOp1::AscStubOp1IrAttrDef ir_attr_obj2; - EXPECT_EQ(ir_attr_obj2.Deserialize(asc_ir_attr_def), GRAPH_SUCCESS); - EXPECT_EQ(ir_attr_obj2.GetMy_int(get_valuei), GRAPH_SUCCESS); - EXPECT_FLOAT_EQ(ir_attr_obj2.GetMy_float(get_valuef), GRAPH_SUCCESS); - EXPECT_EQ(ir_attr_obj2.GetMy_string(get_values), GRAPH_SUCCESS); - EXPECT_EQ(ir_attr_obj2.GetOffset(get_expression), GRAPH_SUCCESS); - EXPECT_EQ(get_valuei, 1); - EXPECT_FLOAT_EQ(get_valuef, 0.1); - EXPECT_EQ(get_values, "stub_test"); - EXPECT_TRUE(get_expression.IsValid()); - EXPECT_TRUE(D.IsValid()); - EXPECT_EQ(std::string(get_expression.Str().get()), std::string(D.Str().get())); - - auto ir_attr_obj_base = ir_attr_obj2.Clone(); - EXPECT_NE(ir_attr_obj_base, nullptr); - EXPECT_EQ(ir_attr_obj_base->GetAttrValue("my_int", get_valuei), GRAPH_SUCCESS); - EXPECT_NE(ir_attr_obj_base->GetAttrValue("others_int", get_valuei), GRAPH_SUCCESS); - EXPECT_NE(ir_attr_obj_base->GetAttrValue("my_int", get_valuef), GRAPH_SUCCESS); - EXPECT_EQ(ir_attr_obj_base->GetAttrValue("my_float", get_valuef), GRAPH_SUCCESS); - EXPECT_EQ(ir_attr_obj_base->GetAttrValue("my_string", get_values), GRAPH_SUCCESS); - EXPECT_EQ(ir_attr_obj_base->GetAttrValue("offset", get_expression), GRAPH_SUCCESS); - EXPECT_EQ(get_valuei, 1); - EXPECT_FLOAT_EQ(get_valuef, 0.1); - EXPECT_EQ(get_values, "stub_test"); - EXPECT_TRUE(get_expression.IsValid()); - EXPECT_TRUE(D.IsValid()); - EXPECT_EQ(std::string(get_expression.Str().get()), std::string(D.Str().get())); -} - -//REG_ASC_IR(StubOp2) -//.Input("x1", "T") -//.Input("x2", "T") -//.Output("y", "T") -//.DataType("T", TensorType{DT_INT32, DT_INT64}); -TEST_F(UtestAscendCIR, CheckInferDtypeImplementation_StubOp2_InferDataType) { - const std::string file_path = std::string(CMAKE_BINARY_DIR) + "/ascir_stub_builtin_ops/ascir_ops.h"; - const std::string target_class = "StubOp2"; - const std::string target_func = "InferDataType"; - - auto[sig, actual_code] = CodeExtractor::ExtractFunction(file_path, target_class, target_func); - - const std::string expected_code = R"EXPECT( - inline static Status InferDataType(const std::vector& input_dtypes, - std::vector& expect_output_dtypes) { - // 校验入参容器的元素个数是否合法 - GE_ASSERT_EQ(input_dtypes.size(), 2U); - GE_ASSERT_TRUE(expect_output_dtypes.empty() || expect_output_dtypes.size() == 1U); - - // 校验同sym的输入的dtype是否在注册范围内并且一致 - GE_WARN_ASSERT(input_dtypes[0] == input_dtypes[1]); - const static std::set support_dtypes_of_sym_T = {DT_INT32, DT_INT64}; - GE_WARN_ASSERT(support_dtypes_of_sym_T.find(input_dtypes[0]) != support_dtypes_of_sym_T.end()); - - // 输出外部不指定的时候,生成推导的代码 - if (expect_output_dtypes.empty()) { - expect_output_dtypes.push_back(input_dtypes[0]); - return SUCCESS; - } - // 输出外部指定,生成校验的代码 - GE_WARN_ASSERT(input_dtypes[0] == expect_output_dtypes[0]); - return SUCCESS; - }; -)EXPECT"; - - auto Normalize = [](const std::string &code) { - std::string str = code; - str.erase(std::remove_if(str.begin(), str.end(), ::isspace), str.end()); - return str; - }; - - EXPECT_EQ(Normalize(actual_code), Normalize(expected_code)) - << "Actual code:\n" << actual_code << "\nExpected:\n" << expected_code; - - //REG_ASC_IR(StubOp2) - //.Input("x1", "T") - //.Input("x2", "T") - //.Output("y", "T") - //.DataType("T", TensorType{DT_INT32, DT_INT64}); - OpDtypeInfer().Input(DT_INT32).Input(DT_INT32).Expect(DT_INT32).AssertSucceed(); - // check input and output num - OpDtypeInfer().Input(DT_INT32).Input(DT_INT32).Input(DT_INT32).Expect(DT_INT32).AssertFailed(); - OpDtypeInfer().Input(DT_INT32).Input(DT_INT32).Expect(DT_INT32).Expect(DT_INT32).AssertFailed(); - // check input same dtype of same sym - OpDtypeInfer().Input(DT_INT32).Input(DT_INT64).Expect(DT_INT32).AssertFailed(); - // check output same dtype of same sym of input - OpDtypeInfer().Input(DT_INT32).Input(DT_INT32).Expect(DT_INT64).AssertFailed(); -} - -TEST_F(UtestAscendCIR, CheckInferDataTypeWithNoCheckImplementation_StubOp2_InferDataTypeWithNoCheck) { - const std::string file_path = std::string(CMAKE_BINARY_DIR) + "/ascir_stub_builtin_ops/ascir_ops.h"; - const std::string target_class = "StubOp2"; - const std::string target_func = "InferDataTypeWithNoCheck"; - - auto[sig, actual_code] = CodeExtractor::ExtractFunction(file_path, target_class, target_func); - - const std::string expected_code = R"EXPECT( - inline static Status InferDataTypeWithNoCheck(const std::vector& input_dtypes, - std::vector& expect_output_dtypes) { - // 校验入参容器的元素个数是否合法 - GE_ASSERT_EQ(input_dtypes.size(), 2U); - GE_ASSERT_TRUE(expect_output_dtypes.empty()); - - // 校验同sym的输入的dtype是否一致 - GE_WARN_ASSERT(input_dtypes[0] == input_dtypes[1]); - - expect_output_dtypes.push_back(input_dtypes[0]); - return SUCCESS; - }; -)EXPECT"; - - auto Normalize = [](const std::string &code) { - std::string str = code; - str.erase(std::remove_if(str.begin(), str.end(), ::isspace), str.end()); - return str; - }; - - EXPECT_EQ(Normalize(actual_code), Normalize(expected_code)) - << "Actual code:\n" << actual_code << "\nExpected:\n" << expected_code; -} - -//REG_ASC_IR(StubOp3) -//.Input("x1", "T1") -//.Input("x2", "T2") -//.Input("x3", "T1") -//.Output("y1", "T1") -//.Output("y2", "T2") -//.DataType("T1", TensorType{DT_INT32, DT_INT64}) -//.DataType("T2", TensorType{DT_FLOAT16, DT_FLOAT}); -TEST_F(UtestAscendCIR, CheckInferDtypeImplementation_StubOp3_InferDataType) { - const std::string file_path = std::string(CMAKE_BINARY_DIR) + "/ascir_stub_builtin_ops/ascir_ops.h"; - const std::string target_class = "StubOp3"; - const std::string target_func = "InferDataType"; - - auto[sig, actual_code] = CodeExtractor::ExtractFunction(file_path, target_class, target_func); - - const std::string expected_code = R"EXPECT( - inline static Status InferDataType(const std::vector& input_dtypes, - std::vector& expect_output_dtypes) { - // 校验入参容器的元素个数是否合法 - GE_ASSERT_EQ(input_dtypes.size(), 3U); - GE_ASSERT_TRUE(expect_output_dtypes.empty() || expect_output_dtypes.size() == 2U); - - // 校验同sym的输入的dtype是否在注册范围内并且一致 - GE_WARN_ASSERT(input_dtypes[0] == input_dtypes[2]); - const static std::set support_dtypes_of_sym_T1 = {DT_INT32, DT_INT64}; - GE_WARN_ASSERT(support_dtypes_of_sym_T1.find(input_dtypes[0]) != support_dtypes_of_sym_T1.end()); - const static std::set support_dtypes_of_sym_T2 = {DT_FLOAT, DT_FLOAT16}; - GE_WARN_ASSERT(support_dtypes_of_sym_T2.find(input_dtypes[1]) != support_dtypes_of_sym_T2.end()); - - // 输出外部不指定的时候,生成推导的代码 - if (expect_output_dtypes.empty()) { - expect_output_dtypes.push_back(input_dtypes[0]); - expect_output_dtypes.push_back(input_dtypes[1]); - return SUCCESS; - } - // 输出外部指定,生成校验的代码 - GE_WARN_ASSERT(input_dtypes[0] == expect_output_dtypes[0]); - GE_WARN_ASSERT(input_dtypes[1] == expect_output_dtypes[1]); - return SUCCESS; - }; -)EXPECT"; - - auto Normalize = [](const std::string &code) { - std::string str = code; - str.erase(std::remove_if(str.begin(), str.end(), ::isspace), str.end()); - return str; - }; - - EXPECT_EQ(Normalize(actual_code), Normalize(expected_code)) - << "Actual code:\n" << actual_code << "\nExpected:\n" << expected_code; - - //REG_ASC_IR(StubOp3) - //.Input("x1", "T1") - //.Input("x2", "T2") - //.Input("x3", "T1") - //.Output("y1", "T1") - //.Output("y2", "T2") - //.DataType("T1", TensorType{DT_INT32, DT_INT64}) - //.DataType("T2", TensorType{DT_FLOAT16, DT_FLOAT}); - OpDtypeInfer().Input(DT_INT32).Input(DT_FLOAT).Input(DT_INT32).Expect(DT_INT32).Expect(DT_FLOAT).AssertSucceed(); - // check input and output num - OpDtypeInfer().Input(DT_INT32).Input(DT_FLOAT).Input(DT_INT32).Input(DT_INT32).Expect(DT_INT32).Expect( - DT_FLOAT).AssertFailed(); - OpDtypeInfer().Input(DT_INT32).Input(DT_FLOAT).Input(DT_INT32).Expect(DT_INT32).Expect(DT_INT32).Expect( - DT_FLOAT).AssertFailed(); - // check input same dtype of same sym - OpDtypeInfer().Input(DT_INT32).Input(DT_FLOAT).Input(DT_INT64).Expect(DT_INT32).Expect(DT_FLOAT).AssertFailed(); - // check output same dtype of same sym of input - OpDtypeInfer().Input(DT_INT32).Input(DT_FLOAT).Input(DT_INT32).Expect(DT_INT32).Expect(DT_INT64).AssertFailed(); -} - -TEST_F(UtestAscendCIR, CheckInferDataTypeWithNoCheckImplementation_StubOp3_InferDataTypeWithNoCheck) { - const std::string file_path = std::string(CMAKE_BINARY_DIR) + "/ascir_stub_builtin_ops/ascir_ops.h"; - const std::string target_class = "StubOp3"; - const std::string target_func = "InferDataTypeWithNoCheck"; - - auto[sig, actual_code] = CodeExtractor::ExtractFunction(file_path, target_class, target_func); - - const std::string expected_code = R"EXPECT( - inline static Status InferDataTypeWithNoCheck(const std::vector& input_dtypes, - std::vector& expect_output_dtypes) { - // 校验入参容器的元素个数是否合法 - GE_ASSERT_EQ(input_dtypes.size(), 3U); - GE_ASSERT_TRUE(expect_output_dtypes.empty()); - - // 校验同sym的输入的dtype是否一致 - GE_WARN_ASSERT(input_dtypes[0] == input_dtypes[2]); - - expect_output_dtypes.push_back(input_dtypes[0]); - expect_output_dtypes.push_back(input_dtypes[1]); - return SUCCESS; - }; -)EXPECT"; - - auto Normalize = [](const std::string &code) { - std::string str = code; - str.erase(std::remove_if(str.begin(), str.end(), ::isspace), str.end()); - return str; - }; - - EXPECT_EQ(Normalize(actual_code), Normalize(expected_code)) - << "Actual code:\n" << actual_code << "\nExpected:\n" << expected_code; -} - -//REG_ASC_IR(StubOp4) -//.Input("x1", "T1") -//.Input("x2", "T2") -//.Output("y1", "T3") -//.Output("y2", "T3") -//.Output("y3", "T2") -//.DataType("T1", TensorType{DT_INT32, DT_INT64}) -//.DataType("T2", TensorType{DT_FLOAT16, DT_FLOAT}) -//.DataType("T3", TensorType{DT_DOUBLE, DT_BOOL}); -TEST_F(UtestAscendCIR, CheckInferDtypeImplementation_StubOp4_InferDataType) { - const std::string file_path = std::string(CMAKE_BINARY_DIR) + "/ascir_stub_builtin_ops/ascir_ops.h"; - const std::string target_class = "StubOp4"; - const std::string target_func = "InferDataType"; - - auto[sig, actual_code] = CodeExtractor::ExtractFunction(file_path, target_class, target_func); - - const std::string expected_code = R"EXPECT( - inline static Status InferDataType(const std::vector& input_dtypes, - std::vector& expect_output_dtypes) { - // 校验入参容器的元素个数是否合法 - GE_ASSERT_EQ(input_dtypes.size(), 2U); - GE_ASSERT_TRUE(expect_output_dtypes.empty() || expect_output_dtypes.size() == 3U); - - // 校验同sym的输入的dtype是否在注册范围内并且一致 - const static std::set support_dtypes_of_sym_T1 = {DT_INT32, DT_INT64}; - GE_WARN_ASSERT(support_dtypes_of_sym_T1.find(input_dtypes[0]) != support_dtypes_of_sym_T1.end()); - const static std::set support_dtypes_of_sym_T2 = {DT_FLOAT, DT_FLOAT16}; - GE_WARN_ASSERT(support_dtypes_of_sym_T2.find(input_dtypes[1]) != support_dtypes_of_sym_T2.end()); - - // 输出外部不指定的时候,生成推导的代码 - if (expect_output_dtypes.empty()) { - GELOGW("Output ir_index [0] has multi result {DT_DOUBLE, DT_BOOL}, can not infer."); - GELOGW("Output ir_index [1] has multi result {DT_DOUBLE, DT_BOOL}, can not infer."); - return FAILED; - } - // 输出外部指定,生成校验的代码 - GE_WARN_ASSERT(expect_output_dtypes[0] == expect_output_dtypes[1]); - static std::set support_dtypes_of_sym_T3 = {DT_DOUBLE, DT_BOOL}; - GE_WARN_ASSERT(support_dtypes_of_sym_T3.find(expect_output_dtypes[0]) != support_dtypes_of_sym_T3.end()); - GE_WARN_ASSERT(input_dtypes[1] == expect_output_dtypes[2]); - return SUCCESS; - }; -)EXPECT"; - - auto Normalize = [](const std::string &code) { - std::string str = code; - str.erase(std::remove_if(str.begin(), str.end(), ::isspace), str.end()); - return str; - }; - - EXPECT_EQ(Normalize(actual_code), Normalize(expected_code)) - << "Actual code:\n" << actual_code << "\nExpected:\n" << expected_code; - - //REG_ASC_IR(StubOp4) - //.Input("x1", "T1") - //.Input("x2", "T2") - //.Output("y1", "T3") - //.Output("y2", "T3") - //.Output("y3", "T2") - //.DataType("T1", TensorType{DT_INT32, DT_INT64}) - //.DataType("T2", TensorType{DT_FLOAT16, DT_FLOAT}) - //.DataType("T3", TensorType{DT_DOUBLE, DT_BOOL}); - OpDtypeInfer().Input(DT_INT32).Input(DT_FLOAT16).Expect(DT_DOUBLE).Expect(DT_DOUBLE).Expect( - DT_FLOAT16).AssertSucceed(); - OpDtypeInfer().Input(DT_INT32).Input(DT_FLOAT).Expect(DT_DOUBLE).Expect(DT_DOUBLE).Expect( - DT_FLOAT).AssertSucceed(); - // check input and output num - OpDtypeInfer().Input(DT_INT32).Input(DT_FLOAT16).Input(DT_INT32).Expect(DT_DOUBLE).Expect(DT_DOUBLE).Expect( - DT_FLOAT16).AssertFailed(); - OpDtypeInfer().Input(DT_INT32).Input(DT_FLOAT16).Expect(DT_DOUBLE).Expect(DT_DOUBLE).Expect( - DT_FLOAT16).Expect(DT_FLOAT16).AssertFailed(); - // check out dtype of same sym - OpDtypeInfer().Input(DT_INT32).Input(DT_FLOAT16).Expect(DT_DOUBLE).Expect(DT_FLOAT16).Expect( - DT_FLOAT16).AssertFailed(); - // infer out failed - OpDtypeInfer().Input(DT_INT32).Input(DT_FLOAT16).AssertFailed(); -} - -TEST_F(UtestAscendCIR, CheckInferDataTypeWithNoCheckImplementation_StubOp4_InferDataTypeWithNoCheck) { - const std::string file_path = std::string(CMAKE_BINARY_DIR) + "/ascir_stub_builtin_ops/ascir_ops.h"; - const std::string target_class = "StubOp4"; - const std::string target_func = "InferDataTypeWithNoCheck"; - - auto[sig, actual_code] = CodeExtractor::ExtractFunction(file_path, target_class, target_func); - - const std::string expected_code = R"EXPECT( - inline static Status InferDataTypeWithNoCheck(const std::vector& input_dtypes, - std::vector& expect_output_dtypes) { - // 校验入参容器的元素个数是否合法 - GE_ASSERT_EQ(input_dtypes.size(), 2U); - GE_ASSERT_TRUE(expect_output_dtypes.empty()); - - // 校验同sym的输入的dtype是否一致 - - GELOGW("Output ir_index [0] has multi result {DT_DOUBLE, DT_BOOL}, can not infer."); - GELOGW("Output ir_index [1] has multi result {DT_DOUBLE, DT_BOOL}, can not infer."); - return FAILED; - }; -)EXPECT"; - - auto Normalize = [](const std::string &code) { - std::string str = code; - str.erase(std::remove_if(str.begin(), str.end(), ::isspace), str.end()); - return str; - }; - - EXPECT_EQ(Normalize(actual_code), Normalize(expected_code)) - << "Actual code:\n" << actual_code << "\nExpected:\n" << expected_code; -} - -//REG_ASC_IR(StubOp4New) -// .Input("x1", "T1") -// .Input("x2", "T2") -// .Output("y1", "T3") -// .Output("y2", "T3") -// .Output("y3", "T2") -// .Impl({"socv1"}, -// {nullptr, -// nullptr, -// {{"T1", TensorType{DT_INT32, DT_INT64}}, -// {"T2", TensorType{DT_FLOAT16, DT_FLOAT}}, -// {"T3", TensorType{DT_DOUBLE, DT_BOOL}}}}); -TEST_F(UtestAscendCIR, CheckInferDtypeImplementation_StubOp4New_InferDataType) { - const std::string file_path = std::string(CMAKE_BINARY_DIR) + "/ascir_stub_builtin_ops/ascir_ops.h"; - const std::string target_class = "StubOp4New"; - const std::string target_func = "InferDataType"; - - auto [sig, actual_code] = CodeExtractor::ExtractFunction(file_path, target_class, target_func); - - const std::string expected_code = R"EXPECT( - inline static Status InferDataType(const std::vector& input_dtypes, - std::vector& expect_output_dtypes) { - // 校验入参容器的元素个数是否合法 - GE_ASSERT_EQ(input_dtypes.size(), 2U); - GE_ASSERT_TRUE(expect_output_dtypes.empty() || expect_output_dtypes.size() == 3U); - - // 校验同sym的输入的dtype是否在注册范围内并且一致 - char soc_version[128] = {}; - auto res = rtGetSocVersion(soc_version, 128U); - GE_ASSERT_TRUE(res == RT_ERROR_NONE, "Failed to get soc version str."); - auto soc_str = std::string(soc_version); - std::set support_dtypes_of_sym_T1; - if (soc_str == "socv1") { - support_dtypes_of_sym_T1 = {DT_INT32, DT_INT64}; - } else if (soc_str == "socv2") { - support_dtypes_of_sym_T1 = {DT_INT32, DT_UINT16, DT_INT64}; - } else if (soc_str == "socv3") { - support_dtypes_of_sym_T1 = {DT_INT32, DT_UINT16, DT_INT64}; - } else { - GELOGE(ge::FAILED, "Failed to get soc version, res:%s", soc_str.c_str()); - return ge::FAILED; - } - GE_WARN_ASSERT(support_dtypes_of_sym_T1.find(input_dtypes[0]) != support_dtypes_of_sym_T1.end()); - std::set support_dtypes_of_sym_T2; - if (soc_str == "socv1") { - support_dtypes_of_sym_T2 = {DT_FLOAT, DT_FLOAT16}; - } else if (soc_str == "socv2") { - support_dtypes_of_sym_T2 = {DT_FLOAT, DT_FLOAT16, DT_UINT16}; - } else if (soc_str == "socv3") { - support_dtypes_of_sym_T2 = {DT_FLOAT, DT_FLOAT16, DT_UINT16}; - } else { - GELOGE(ge::FAILED, "Failed to get soc version, res:%s", soc_str.c_str()); - return ge::FAILED; - } - GE_WARN_ASSERT(support_dtypes_of_sym_T2.find(input_dtypes[1]) != support_dtypes_of_sym_T2.end()); - - // 输出外部不指定的时候,生成推导的代码 - if (expect_output_dtypes.empty()) { - GELOGW("Output ir_index [0] has multi result {DT_DOUBLE, DT_BOOL}, can not infer."); - GELOGW("Output ir_index [1] has multi result {DT_DOUBLE, DT_BOOL}, can not infer."); - return FAILED; - } - // 输出外部指定,生成校验的代码 - GE_WARN_ASSERT(expect_output_dtypes[0] == expect_output_dtypes[1]); - static std::set support_dtypes_of_sym_T3 = {DT_DOUBLE, DT_BOOL}; - GE_WARN_ASSERT(support_dtypes_of_sym_T3.find(expect_output_dtypes[0]) != support_dtypes_of_sym_T3.end()); - GE_WARN_ASSERT(input_dtypes[1] == expect_output_dtypes[2]); - return SUCCESS; - }; -)EXPECT"; - - auto Normalize = [](const std::string &code) { - std::string str = code; - str.erase(std::remove_if(str.begin(), str.end(), ::isspace), str.end()); - return str; - }; - - EXPECT_EQ(Normalize(actual_code), Normalize(expected_code)) << "Actual code:\n" - << actual_code << "\nExpected:\n" - << expected_code; - - OpDtypeInfer() - .Input(DT_INT32) - .Input(DT_FLOAT16) - .Expect(DT_DOUBLE) - .Expect(DT_DOUBLE) - .Expect(DT_FLOAT16) - .AssertSucceed(); - OpDtypeInfer() - .Input(DT_INT32) - .Input(DT_FLOAT) - .Expect(DT_DOUBLE) - .Expect(DT_DOUBLE) - .Expect(DT_FLOAT) - .AssertSucceed(); - // check input and output num - OpDtypeInfer() - .Input(DT_INT32) - .Input(DT_FLOAT16) - .Input(DT_INT32) - .Expect(DT_DOUBLE) - .Expect(DT_DOUBLE) - .Expect(DT_FLOAT16) - .AssertFailed(); - OpDtypeInfer() - .Input(DT_INT32) - .Input(DT_FLOAT16) - .Expect(DT_DOUBLE) - .Expect(DT_DOUBLE) - .Expect(DT_FLOAT16) - .Expect(DT_FLOAT16) - .AssertFailed(); - // check out dtype of same sym - OpDtypeInfer() - .Input(DT_INT32) - .Input(DT_FLOAT16) - .Expect(DT_DOUBLE) - .Expect(DT_FLOAT16) - .Expect(DT_FLOAT16) - .AssertFailed(); - // infer out failed - OpDtypeInfer().Input(DT_INT32).Input(DT_FLOAT16).AssertFailed(); -} - -TEST_F(UtestAscendCIR, CheckInferDataTypeWithNoCheckImplementation_StubOp4New_InferDataTypeWithNoCheck) { - const std::string file_path = std::string(CMAKE_BINARY_DIR) + "/ascir_stub_builtin_ops/ascir_ops.h"; - const std::string target_class = "StubOp4New"; - const std::string target_func = "InferDataTypeWithNoCheck"; - - auto [sig, actual_code] = CodeExtractor::ExtractFunction(file_path, target_class, target_func); - - const std::string expected_code = R"EXPECT( - inline static Status InferDataTypeWithNoCheck(const std::vector& input_dtypes, - std::vector& expect_output_dtypes) { - // 校验入参容器的元素个数是否合法 - GE_ASSERT_EQ(input_dtypes.size(), 2U); - GE_ASSERT_TRUE(expect_output_dtypes.empty()); - - // 校验同sym的输入的dtype是否一致 - - GELOGW("Output ir_index [0] has multi result {DT_DOUBLE, DT_BOOL}, can not infer."); - GELOGW("Output ir_index [1] has multi result {DT_DOUBLE, DT_BOOL}, can not infer."); - return FAILED; - }; -)EXPECT"; - - auto Normalize = [](const std::string &code) { - std::string str = code; - str.erase(std::remove_if(str.begin(), str.end(), ::isspace), str.end()); - return str; - }; - - EXPECT_EQ(Normalize(actual_code), Normalize(expected_code)) << "Actual code:\n" - << actual_code << "\nExpected:\n" - << expected_code; -} - -//REG_ASC_IR(StubOp5) -//.Input("x1", "T1") -//.DynamicInput("x2", "T2") -//.Output("y1", "T1") -//.Output("y2", "T2") -//.DataType("T1", TensorType{DT_INT32, DT_INT64}) -//.DataType("T2", TensorType{DT_FLOAT16, DT_FLOAT}); -TEST_F(UtestAscendCIR, CheckInferDtypeImplementation_StubOp5_InferDataType) { - const std::string file_path = std::string(CMAKE_BINARY_DIR) + "/ascir_stub_builtin_ops/ascir_ops.h"; - const std::string target_class = "StubOp5"; - const std::string target_func = "InferDataType"; - - auto[sig, actual_code] = CodeExtractor::ExtractFunction(file_path, target_class, target_func); - - const std::string expected_code = R"EXPECT( - inline static Status InferDataType(const std::vector& input_dtypes, - std::vector& expect_output_dtypes) { - // 校验入参容器的元素个数是否合法 - GE_ASSERT_EQ(input_dtypes.size(), 2U); - GE_ASSERT_TRUE(expect_output_dtypes.empty() || expect_output_dtypes.size() == 2U); - - // 校验同sym的输入的dtype是否在注册范围内并且一致 - const static std::set support_dtypes_of_sym_T1 = {DT_INT32, DT_INT64}; - GE_WARN_ASSERT(support_dtypes_of_sym_T1.find(input_dtypes[0]) != support_dtypes_of_sym_T1.end()); - const static std::set support_dtypes_of_sym_T2 = {DT_FLOAT, DT_FLOAT16}; - GE_WARN_ASSERT(support_dtypes_of_sym_T2.find(input_dtypes[1]) != support_dtypes_of_sym_T2.end()); - - // 输出外部不指定的时候,生成推导的代码 - if (expect_output_dtypes.empty()) { - expect_output_dtypes.push_back(input_dtypes[0]); - expect_output_dtypes.push_back(input_dtypes[1]); - return SUCCESS; - } - // 输出外部指定,生成校验的代码 - GE_WARN_ASSERT(input_dtypes[0] == expect_output_dtypes[0]); - GE_WARN_ASSERT(input_dtypes[1] == expect_output_dtypes[1]); - return SUCCESS; - }; -)EXPECT"; - - auto Normalize = [](const std::string &code) { - std::string str = code; - str.erase(std::remove_if(str.begin(), str.end(), ::isspace), str.end()); - return str; - }; - - EXPECT_EQ(Normalize(actual_code), Normalize(expected_code)) - << "Actual code:\n" << actual_code << "\nExpected:\n" << expected_code; - - //REG_ASC_IR(StubOp5) - //.Input("x1", "T1") - //.DynamicInput("x2", "T2") - //.Output("y1", "T1") - //.Output("y2", "T2") - //.DataType("T1", TensorType{DT_INT32, DT_INT64}) - //.DataType("T2", TensorType{DT_FLOAT16, DT_FLOAT}); - OpDtypeInfer().Input(DT_INT32).Input(DT_FLOAT16).Expect(DT_INT32).Expect(DT_FLOAT16).AssertSucceed(); - OpDtypeInfer().Input(DT_INT32).Input(DT_FLOAT).Expect(DT_INT32).Expect(DT_FLOAT).AssertSucceed(); - // infer out successfully - OpDtypeInfer().Input(DT_INT32).Input(DT_FLOAT).AssertSucceed(); - // check input and output num - OpDtypeInfer().Input(DT_INT32).Input(DT_FLOAT16).Input(DT_INT32).Expect(DT_DOUBLE).Expect(DT_DOUBLE).AssertFailed(); - OpDtypeInfer().Input(DT_INT32).Input(DT_FLOAT16).Input(DT_FLOAT16).Expect(DT_DOUBLE).Expect( - DT_DOUBLE).AssertFailed(); - // check out dtype of same sym - OpDtypeInfer().Input(DT_INT32).Input(DT_FLOAT16).Expect(DT_INT32).Expect(DT_INT32).AssertFailed(); - OpDtypeInfer().Input(DT_INT32).Input(DT_FLOAT16).Expect(DT_FLOAT16).Expect(DT_FLOAT16).AssertFailed(); -} - -TEST_F(UtestAscendCIR, CheckInferDataTypeWithNoCheckImplementation_StubOp5_InferDataTypeWithNoCheck) { - const std::string file_path = std::string(CMAKE_BINARY_DIR) + "/ascir_stub_builtin_ops/ascir_ops.h"; - const std::string target_class = "StubOp5"; - const std::string target_func = "InferDataTypeWithNoCheck"; - - auto[sig, actual_code] = CodeExtractor::ExtractFunction(file_path, target_class, target_func); - - const std::string expected_code = R"EXPECT( - inline static Status InferDataTypeWithNoCheck(const std::vector& input_dtypes, - std::vector& expect_output_dtypes) { - // 校验入参容器的元素个数是否合法 - GE_ASSERT_EQ(input_dtypes.size(), 2U); - GE_ASSERT_TRUE(expect_output_dtypes.empty()); - - // 校验同sym的输入的dtype是否一致 - - expect_output_dtypes.push_back(input_dtypes[0]); - expect_output_dtypes.push_back(input_dtypes[1]); - return SUCCESS; - }; -)EXPECT"; - - auto Normalize = [](const std::string &code) { - std::string str = code; - str.erase(std::remove_if(str.begin(), str.end(), ::isspace), str.end()); - return str; - }; - - EXPECT_EQ(Normalize(actual_code), Normalize(expected_code)) - << "Actual code:\n" << actual_code << "\nExpected:\n" << expected_code; -} - -//REG_ASC_IR(StubOp6) -//.Input("x1", "T1") -//.Input("x2", "T2") -//.Input("x3", "T1") -//.Output("y1", "T3") -//.DataType("T1", OrderedTensorTypeList{DT_INT32, DT_INT64}) -//.DataType("T2", OrderedTensorTypeList{DT_FLOAT16, DT_FLOAT}) -//.DataType("T3", OrderedTensorTypeList{DT_BOOL, DT_INT8}); -TEST_F(UtestAscendCIR, CheckInferDtypeImplementation_StubOp6_InferDataType) { - const std::string file_path = std::string(CMAKE_BINARY_DIR) + "/ascir_stub_builtin_ops/ascir_ops.h"; - const std::string target_class = "StubOp6"; - const std::string target_func = "InferDataType"; - - auto[sig, actual_code] = CodeExtractor::ExtractFunction(file_path, target_class, target_func); - - const std::string expected_code = R"EXPECT( - inline static Status InferDataType(const std::vector& input_dtypes, - std::vector& expect_output_dtypes) { - // 校验入参容器的元素个数是否合法 - GE_ASSERT_EQ(input_dtypes.size(), 3U); - GE_ASSERT_TRUE(expect_output_dtypes.empty() || expect_output_dtypes.size() == 1U); - - GE_WARN_ASSERT(input_dtypes[0] == input_dtypes[2]); - const static std::map, ge::DataType> results = { - {{DT_INT32, DT_FLOAT16}, DT_BOOL}, - {{DT_INT64, DT_FLOAT}, DT_INT8} - }; - auto iter = results.find(std::vector{input_dtypes[0], input_dtypes[1]}); - GE_WARN_ASSERT(iter != results.end()); - // 输出外部不指定的时候,生成推导的代码 - if (expect_output_dtypes.empty()) { - expect_output_dtypes.push_back(iter->second); - return SUCCESS; - } - // 输出外部指定,生成校验的代码 - GE_WARN_ASSERT(iter->second == expect_output_dtypes[0]); - - return SUCCESS; - }; -)EXPECT"; - - auto Normalize = [](const std::string &code) { - std::string str = code; - str.erase(std::remove_if(str.begin(), str.end(), ::isspace), str.end()); - return str; - }; - - EXPECT_EQ(Normalize(actual_code), Normalize(expected_code)) - << "Actual code:\n" << actual_code << "\nExpected:\n" << expected_code; - -// REG_ASC_IR(StubOp6) -//.Input("x1", "T1") -//.Input("x2", "T2") -//.Input("x3", "T1") -//.Output("y1", "T3") -//.DataType("T1", OrderedTensorTypeList{DT_INT32, DT_INT64}) -//.DataType("T2", OrderedTensorTypeList{DT_FLOAT16, DT_FLOAT}) -//.DataType("T3", OrderedTensorTypeList{DT_BOOL, DT_INT8}); - OpDtypeInfer().Input(DT_INT32).Input(DT_FLOAT16).Input(DT_INT32).Expect(DT_BOOL).AssertSucceed(); - OpDtypeInfer().Input(DT_INT64).Input(DT_FLOAT).Input(DT_INT64).Expect(DT_INT8).AssertSucceed(); - // infer out successfully - OpDtypeInfer().Input(DT_INT32).Input(DT_FLOAT16).Input(DT_INT32).AssertSucceed(); - OpDtypeInfer().Input(DT_INT64).Input(DT_FLOAT).Input(DT_INT64).AssertSucceed(); - // check input dtype of same sym - OpDtypeInfer().Input(DT_INT64).Input(DT_FLOAT).Input(DT_INT32).AssertFailed(); - // check inputs indicies not match - OpDtypeInfer().Input(DT_INT64).Input(DT_FLOAT16).Input(DT_INT64).Expect(DT_BOOL).AssertFailed(); - // check output input indicies not match - OpDtypeInfer().Input(DT_INT64).Input(DT_FLOAT).Input(DT_INT64).Expect(DT_BOOL).AssertFailed(); -} - -TEST_F(UtestAscendCIR, CheckInferDataTypeWithNoCheckImplementation_StubOp6_InferDataTypeWithNoCheck) { - const std::string file_path = std::string(CMAKE_BINARY_DIR) + "/ascir_stub_builtin_ops/ascir_ops.h"; - const std::string target_class = "StubOp6"; - const std::string target_func = "InferDataTypeWithNoCheck"; - - auto[sig, actual_code] = CodeExtractor::ExtractFunction(file_path, target_class, target_func); - - const std::string expected_code = R"EXPECT( - inline static Status InferDataTypeWithNoCheck(const std::vector& input_dtypes, - std::vector& expect_output_dtypes) { - // 输入输出存在关联, 无法进行推导 - GELOGW("Node type %s is not supported to infernocheck for dtype.", Type); - return ge::FAILED; - }; -)EXPECT"; - - auto Normalize = [](const std::string &code) { - std::string str = code; - str.erase(std::remove_if(str.begin(), str.end(), ::isspace), str.end()); - return str; - }; - - EXPECT_EQ(Normalize(actual_code), Normalize(expected_code)) - << "Actual code:\n" << actual_code << "\nExpected:\n" << expected_code; -} - -// 输出存在多个解的情况 -//REG_ASC_IR(StubOp7) -// .Input("x1", "T1") -// .Input("x2", "T2") -// .Input("x3", "T1") -// .Output("y1", "T3") -// .Output("y2", "T2") -//.DataType("T1", OrderedTensorTypeList{DT_INT32, DT_INT32, DT_INT64}) -//.DataType("T2", OrderedTensorTypeList{DT_FLOAT16, DT_FLOAT16, DT_FLOAT}) -//.DataType("T3", OrderedTensorTypeList{DT_BOOL, DT_INT4, DT_INT8}); -TEST_F(UtestAscendCIR, CheckInferDtypeImplementation_StubOp7_InferDataType) { - const std::string file_path = std::string(CMAKE_BINARY_DIR) + "/ascir_stub_builtin_ops/ascir_ops.h"; - const std::string target_class = "StubOp7"; - const std::string target_func = "InferDataType"; - - auto[sig, actual_code] = CodeExtractor::ExtractFunction(file_path, target_class, target_func); - - const std::string expected_code = R"EXPECT( - inline static Status InferDataType(const std::vector& input_dtypes, - std::vector& expect_output_dtypes) { - // 校验入参容器的元素个数是否合法 - GE_ASSERT_EQ(input_dtypes.size(), 3U); - GE_ASSERT_TRUE(expect_output_dtypes.empty() || expect_output_dtypes.size() == 2U); - - GE_WARN_ASSERT(input_dtypes[0] == input_dtypes[2]); - const static std::map, std::set> results = { - {{DT_INT32, DT_FLOAT16}, {DT_BOOL, DT_INT4}}, - {{DT_INT64, DT_FLOAT}, {DT_INT8}} - }; - auto iter = results.find(std::vector{input_dtypes[0], input_dtypes[1]}); - GE_WARN_ASSERT(iter != results.end()); - // 输出外部不指定的时候,生成推导的代码 - if (expect_output_dtypes.empty()) { - GE_WARN_ASSERT(iter->second.size() == 1U); - expect_output_dtypes.push_back(*(iter->second.begin())); - expect_output_dtypes.push_back(input_dtypes[1]); - return SUCCESS; - } - // 输出外部指定,生成校验的代码 - GE_WARN_ASSERT(iter->second.find(expect_output_dtypes[0]) != iter->second.end()); - GE_WARN_ASSERT(input_dtypes[1] == expect_output_dtypes[1]); - - return SUCCESS; - }; -)EXPECT"; - - auto Normalize = [](const std::string &code) { - std::string str = code; - str.erase(std::remove_if(str.begin(), str.end(), ::isspace), str.end()); - return str; - }; - - EXPECT_EQ(Normalize(actual_code), Normalize(expected_code)) - << "Actual code:\n" << actual_code << "\nExpected:\n" << expected_code; - - //REG_ASC_IR(StubOp7) - // .Input("x1", "T1") - // .Input("x2", "T2") - // .Input("x3", "T1") - // .Output("y1", "T3") - // .Output("y2", "T2") - //.DataType("T1", OrderedTensorTypeList{DT_INT32, DT_INT32, DT_INT64}) - //.DataType("T2", OrderedTensorTypeList{DT_FLOAT16, DT_FLOAT16, DT_FLOAT}) - //.DataType("T3", OrderedTensorTypeList{DT_BOOL, DT_INT4, DT_INT8}); - OpDtypeInfer().Input(DT_INT32).Input(DT_FLOAT16).Input(DT_INT32).Expect(DT_BOOL).Expect(DT_FLOAT16).AssertSucceed(); - OpDtypeInfer().Input(DT_INT32).Input(DT_FLOAT16).Input(DT_INT32).Expect(DT_INT4).Expect(DT_FLOAT16).AssertSucceed(); - OpDtypeInfer().Input(DT_INT64).Input(DT_FLOAT).Input(DT_INT64).Expect(DT_INT8).Expect(DT_FLOAT).AssertSucceed(); - // infer out successfully - OpDtypeInfer().Input(DT_INT64).Input(DT_FLOAT).Input(DT_INT64).AssertSucceed(); - // infer out failed by multi result - OpDtypeInfer().Input(DT_INT32).Input(DT_FLOAT16).Input(DT_INT32).AssertFailed(); - // check input dtype of same sym - OpDtypeInfer().Input(DT_INT32).Input(DT_FLOAT16).Input(DT_INT64).Expect(DT_BOOL).Expect(DT_FLOAT16).AssertFailed(); - // check output input indicies not match - OpDtypeInfer().Input(DT_INT64).Input(DT_FLOAT).Input(DT_INT64).Expect(DT_INT4).Expect(DT_FLOAT).AssertFailed(); -} - -TEST_F(UtestAscendCIR, CheckInferDataTypeWithNoCheckImplementation_StubOp7_InferDataTypeWithNoCheck) { - const std::string file_path = std::string(CMAKE_BINARY_DIR) + "/ascir_stub_builtin_ops/ascir_ops.h"; - const std::string target_class = "StubOp7"; - const std::string target_func = "InferDataTypeWithNoCheck"; - - auto[sig, actual_code] = CodeExtractor::ExtractFunction(file_path, target_class, target_func); - - const std::string expected_code = R"EXPECT( - inline static Status InferDataTypeWithNoCheck(const std::vector& input_dtypes, - std::vector& expect_output_dtypes) { - // 输入输出存在关联, 无法进行推导 - GELOGW("Node type %s is not supported to infernocheck for dtype.", Type); - return ge::FAILED; - }; -)EXPECT"; - - auto Normalize = [](const std::string &code) { - std::string str = code; - str.erase(std::remove_if(str.begin(), str.end(), ::isspace), str.end()); - return str; - }; - - EXPECT_EQ(Normalize(actual_code), Normalize(expected_code)) - << "Actual code:\n" << actual_code << "\nExpected:\n" << expected_code; -} - -// 单输入单输出完全一对一,带有重复解优化测试 -//REG_ASC_IR(StubOp8) -//.Input("x", "T1") -//.Output("y", "T2") -//.DataType("T1", OrderedTensorTypeList{DT_INT32, DT_INT32, DT_INT64}) -//.DataType("T2", OrderedTensorTypeList{DT_BF16, DT_BF16, DT_FLOAT}); -TEST_F(UtestAscendCIR, CheckInferDtypeImplementation_StubOp8_InferDataType) { - const std::string file_path = std::string(CMAKE_BINARY_DIR) + "/ascir_stub_builtin_ops/ascir_ops.h"; - const std::string target_class = "StubOp8"; - const std::string target_func = "InferDataType"; - - auto[sig, actual_code] = CodeExtractor::ExtractFunction(file_path, target_class, target_func); - - const std::string expected_code = R"EXPECT( - inline static Status InferDataType(const std::vector& input_dtypes, - std::vector& expect_output_dtypes) { - // 校验入参容器的元素个数是否合法 - GE_ASSERT_EQ(input_dtypes.size(), 1U); - GE_ASSERT_TRUE(expect_output_dtypes.empty() || expect_output_dtypes.size() == 1U); - - const static std::map results = { - {DT_INT32, DT_BF16}, - {DT_INT64, DT_FLOAT} - }; - auto iter = results.find(input_dtypes[0]); - GE_WARN_ASSERT(iter != results.end()); - // 输出外部不指定的时候,生成推导的代码 - if (expect_output_dtypes.empty()) { - expect_output_dtypes.push_back(iter->second); - return SUCCESS; - } - // 输出外部指定,生成校验的代码 - GE_WARN_ASSERT(iter->second == expect_output_dtypes[0]); - - return SUCCESS; - }; -)EXPECT"; - - auto Normalize = [](const std::string &code) { - std::string str = code; - str.erase(std::remove_if(str.begin(), str.end(), ::isspace), str.end()); - return str; - }; - - EXPECT_EQ(Normalize(actual_code), Normalize(expected_code)) - << "Actual code:\n" << actual_code << "\nExpected:\n" << expected_code; - - //REG_ASC_IR(StubOp8) - //.Input("x", "T1") - //.Output("y", "T2") - //.DataType("T1", OrderedTensorTypeList{DT_INT32, DT_INT32, DT_INT64}) - //.DataType("T2", OrderedTensorTypeList{DT_BF16, DT_BF16, DT_FLOAT}); - OpDtypeInfer().Input(DT_INT32).Expect(DT_BF16).AssertSucceed(); - OpDtypeInfer().Input(DT_INT64).Expect(DT_FLOAT).AssertSucceed(); - // infer out successfully - OpDtypeInfer().Input(DT_INT32).AssertSucceed(); - OpDtypeInfer().Input(DT_INT64).AssertSucceed(); - // check output input indicies not match - OpDtypeInfer().Input(DT_INT64).Expect(DT_BF16).AssertFailed(); -} - -TEST_F(UtestAscendCIR, CheckInferDataTypeWithNoCheckImplementation_StubOp8_InferDataTypeWithNoCheck) { - const std::string file_path = std::string(CMAKE_BINARY_DIR) + "/ascir_stub_builtin_ops/ascir_ops.h"; - const std::string target_class = "StubOp8"; - const std::string target_func = "InferDataTypeWithNoCheck"; - - auto[sig, actual_code] = CodeExtractor::ExtractFunction(file_path, target_class, target_func); - - const std::string expected_code = R"EXPECT( - inline static Status InferDataTypeWithNoCheck(const std::vector& input_dtypes, - std::vector& expect_output_dtypes) { - // 输入输出存在关联, 无法进行推导 - GELOGW("Node type %s is not supported to infernocheck for dtype.", Type); - return ge::FAILED; - }; -)EXPECT"; - - auto Normalize = [](const std::string &code) { - std::string str = code; - str.erase(std::remove_if(str.begin(), str.end(), ::isspace), str.end()); - return str; - }; - - EXPECT_EQ(Normalize(actual_code), Normalize(expected_code)) - << "Actual code:\n" << actual_code << "\nExpected:\n" << expected_code; -} - -// 单输入单输出完全一对一,带有重复解优化测试 -//REG_ASC_IR(StubOp8New) -// .Input("x", "T1") -// .Output("y", "T2") -// .Impl({"socv1"}, -// {nullptr, -// nullptr, -// {{"T1", OrderedTensorTypeList{DT_INT32, DT_INT32, DT_INT64}}, {"T2", OrderedTensorTypeList{DT_BF16, DT_BF16, DT_FLOAT}}}}); - -TEST_F(UtestAscendCIR, CheckInferDtypeImplementation_StubOp8New_InferDataType) { - const std::string file_path = std::string(CMAKE_BINARY_DIR) + "/ascir_stub_builtin_ops/ascir_ops.h"; - const std::string target_class = "StubOp8New"; - const std::string target_func = "InferDataType"; - - auto[sig, actual_code] = CodeExtractor::ExtractFunction(file_path, target_class, target_func); - - const std::string expected_code = R"EXPECT( - inline static Status InferDataType(const std::vector& input_dtypes, - std::vector& expect_output_dtypes) { - // 校验入参容器的元素个数是否合法 - GE_ASSERT_EQ(input_dtypes.size(), 1U); - GE_ASSERT_TRUE(expect_output_dtypes.empty() || expect_output_dtypes.size() == 1U); - - char soc_version[128] = {}; - auto res = rtGetSocVersion(soc_version, 128U); - GE_ASSERT_TRUE(res == RT_ERROR_NONE, "Failed to get soc version str."); - auto soc_str = std::string(soc_version); - std::map> results; - if (soc_str == "socv1") { - results = { - {DT_INT32, {DT_BF16}}, - {DT_INT64, {DT_FLOAT}} - }; - } else if (soc_str == "socv2") { - results = { - {DT_INT32, {DT_BF16}}, - {DT_INT64, {DT_FLOAT}} - }; - } else if (soc_str == "socv3") { - results = { - {DT_INT32, {DT_BF16}}, - {DT_INT64, {DT_FLOAT}} - }; - } else { - GELOGE(ge::FAILED, "Failed to get soc version, res:%s", soc_str.c_str()); - return ge::FAILED; - } - - auto iter = results.find(input_dtypes[0]); - GE_WARN_ASSERT(iter != results.end()); - // 输出外部不指定的时候,生成推导的代码 - if (expect_output_dtypes.empty()) { - GE_WARN_ASSERT(iter->second.size() == 1U); - expect_output_dtypes.push_back(*(iter->second.begin())); - return ge::SUCCESS; - } - // 输出外部指定,生成校验的代码 - GE_WARN_ASSERT(iter->second.find(expect_output_dtypes[0]) != iter->second.end()); - - return SUCCESS; - }; -)EXPECT"; - - auto Normalize = [](const std::string &code) { - std::string str = code; - str.erase(std::remove_if(str.begin(), str.end(), ::isspace), str.end()); - return str; - }; - - EXPECT_EQ(Normalize(actual_code), Normalize(expected_code)) - << "Actual code:\n" << actual_code << "\nExpected:\n" << expected_code; - - OpDtypeInfer().Input(DT_INT32).Expect(DT_BF16).AssertSucceed(); - OpDtypeInfer().Input(DT_INT64).Expect(DT_FLOAT).AssertSucceed(); - // infer out successfully - OpDtypeInfer().Input(DT_INT32).AssertSucceed(); - OpDtypeInfer().Input(DT_INT64).AssertSucceed(); - // check output input indicies not match - OpDtypeInfer().Input(DT_INT64).Expect(DT_BF16).AssertFailed(); -} - -TEST_F(UtestAscendCIR, CheckInferDataTypeWithNoCheckImplementation_StubOp8New_InferDataTypeWithNoCheck) { - const std::string file_path = std::string(CMAKE_BINARY_DIR) + "/ascir_stub_builtin_ops/ascir_ops.h"; - const std::string target_class = "StubOp8New"; - const std::string target_func = "InferDataTypeWithNoCheck"; - - auto[sig, actual_code] = CodeExtractor::ExtractFunction(file_path, target_class, target_func); - - const std::string expected_code = R"EXPECT( - inline static Status InferDataTypeWithNoCheck(const std::vector& input_dtypes, - std::vector& expect_output_dtypes) { - // 输入输出存在关联, 无法进行推导 - GELOGW("Node type %s is not supported to infernocheck for dtype.", Type); - return ge::FAILED; - }; -)EXPECT"; - - auto Normalize = [](const std::string &code) { - std::string str = code; - str.erase(std::remove_if(str.begin(), str.end(), ::isspace), str.end()); - return str; - }; - - EXPECT_EQ(Normalize(actual_code), Normalize(expected_code)) - << "Actual code:\n" << actual_code << "\nExpected:\n" << expected_code; -} - -// 多输入,多输出, 输出唯一解和多个解混合的复杂场景 -//REG_ASC_IR(StubOp9) -//.Input("x1", "T1") -//.Input("x2", "T2") -//.Input("x3", "T3") -//.Output("y1", "T2") -//.Output("y2", "T1") -//.Output("y3", "T4") -//.Output("y4", "T5") -//.DataType("T1", OrderedTensorTypeList{DT_INT32, DT_INT32, DT_INT64}) -//.DataType("T2", OrderedTensorTypeList{DT_BF16, DT_BF16, DT_FLOAT}) -//.DataType("T3", OrderedTensorTypeList{DT_INT8, DT_INT8, DT_FLOAT}) -//.DataType("T4", OrderedTensorTypeList{DT_BOOL, DT_DOUBLE, DT_FLOAT}) -//.DataType("T5", OrderedTensorTypeList{DT_BOOL, DT_COMPLEX128, DT_DUAL}); -TEST_F(UtestAscendCIR, CheckInferDtypeImplementation_StubOp9_InferDataType) { - const std::string file_path = std::string(CMAKE_BINARY_DIR) + "/ascir_stub_builtin_ops/ascir_ops.h"; - const std::string target_class = "StubOp9"; - const std::string target_func = "InferDataType"; - - auto[sig, actual_code] = CodeExtractor::ExtractFunction(file_path, target_class, target_func); - - const std::string expected_code = R"EXPECT( - inline static Status InferDataType(const std::vector& input_dtypes, - std::vector& expect_output_dtypes) { - // 校验入参容器的元素个数是否合法 - GE_ASSERT_EQ(input_dtypes.size(), 3U); - GE_ASSERT_TRUE(expect_output_dtypes.empty() || expect_output_dtypes.size() == 4U); - - const static std::map, std::vector>> results = { - {{DT_INT32, DT_BF16, DT_INT8}, {{DT_DOUBLE, DT_BOOL}, {DT_BOOL, DT_COMPLEX128}}}, - {{DT_INT64, DT_FLOAT, DT_FLOAT}, {{DT_FLOAT}, {DT_DUAL}}} - }; - auto iter = results.find(std::vector{input_dtypes[0], input_dtypes[1], input_dtypes[2]}); - GE_WARN_ASSERT(iter != results.end()); - // 输出外部不指定的时候,生成推导的代码 - if (expect_output_dtypes.empty()) { - expect_output_dtypes.push_back(input_dtypes[1]); - expect_output_dtypes.push_back(input_dtypes[0]); - GE_WARN_ASSERT(iter->second[0].size() == 1U); - expect_output_dtypes.push_back(*(iter->second[0].begin())); - GE_WARN_ASSERT(iter->second[1].size() == 1U); - expect_output_dtypes.push_back(*(iter->second[1].begin())); - return SUCCESS; - } - // 输出外部指定,生成校验的代码 - GE_WARN_ASSERT(input_dtypes[1] == expect_output_dtypes[0]); - GE_WARN_ASSERT(input_dtypes[0] == expect_output_dtypes[1]); - GE_WARN_ASSERT(iter->second[0].find(expect_output_dtypes[2]) != iter->second[0].end()); - GE_WARN_ASSERT(iter->second[1].find(expect_output_dtypes[3]) != iter->second[1].end()); - - return SUCCESS; - }; -)EXPECT"; - - auto Normalize = [](const std::string &code) { - std::string str = code; - str.erase(std::remove_if(str.begin(), str.end(), ::isspace), str.end()); - return str; - }; - - EXPECT_EQ(Normalize(actual_code), Normalize(expected_code)) - << "Actual code:\n" << actual_code << "\nExpected:\n" << expected_code; - - // 多输入,多输出, 输出唯一解和多个解混合的复杂场景 - //REG_ASC_IR(StubOp9) - //.Input("x1", "T1") - //.Input("x2", "T2") - //.Input("x3", "T3") - //.Output("y1", "T2") - //.Output("y2", "T1") - //.Output("y3", "T4") - //.Output("y4", "T5") - //.DataType("T1", OrderedTensorTypeList{DT_INT32, DT_INT32, DT_INT64}) - //.DataType("T2", OrderedTensorTypeList{DT_BF16, DT_BF16, DT_FLOAT}) - //.DataType("T3", OrderedTensorTypeList{DT_INT8, DT_INT8, DT_FLOAT}) - //.DataType("T4", OrderedTensorTypeList{DT_BOOL, DT_DOUBLE, DT_FLOAT}) - //.DataType("T5", OrderedTensorTypeList{DT_BOOL, DT_COMPLEX128, DT_DUAL}); - OpDtypeInfer().Input(DT_INT32).Input(DT_BF16).Input(DT_INT8).Expect(DT_BF16).Expect(DT_INT32).Expect( - DT_BOOL).Expect(DT_BOOL).AssertSucceed(); - OpDtypeInfer().Input(DT_INT32).Input(DT_BF16).Input(DT_INT8).Expect(DT_BF16).Expect(DT_INT32).Expect( - DT_DOUBLE).Expect(DT_COMPLEX128).AssertSucceed(); - OpDtypeInfer().Input(DT_INT64).Input(DT_FLOAT).Input(DT_FLOAT).Expect(DT_FLOAT).Expect(DT_INT64).Expect( - DT_FLOAT).Expect(DT_DUAL).AssertSucceed(); - // infer out successfully - OpDtypeInfer().Input(DT_INT64).Input(DT_FLOAT).Input(DT_FLOAT).AssertSucceed(); - // infer out failed of multi result - OpDtypeInfer().Input(DT_INT32).Input(DT_BF16).Input(DT_INT8).AssertFailed(); - // check failed of error indicies - OpDtypeInfer().Input(DT_INT64).Input(DT_FLOAT).Input(DT_FLOAT).Expect(DT_FLOAT).Expect(DT_INT64).Expect( - DT_DOUBLE).Expect(DT_DUAL).AssertFailed(); -} - -TEST_F(UtestAscendCIR, CheckInferDataTypeWithNoCheckImplementation_StubOp9_InferDataTypeWithNoCheck) { - const std::string file_path = std::string(CMAKE_BINARY_DIR) + "/ascir_stub_builtin_ops/ascir_ops.h"; - const std::string target_class = "StubOp9"; - const std::string target_func = "InferDataTypeWithNoCheck"; - - auto[sig, actual_code] = CodeExtractor::ExtractFunction(file_path, target_class, target_func); - - const std::string expected_code = R"EXPECT( - inline static Status InferDataTypeWithNoCheck(const std::vector& input_dtypes, - std::vector& expect_output_dtypes) { - // 输入输出存在关联, 无法进行推导 - GELOGW("Node type %s is not supported to infernocheck for dtype.", Type); - return ge::FAILED; - }; -)EXPECT"; - - auto Normalize = [](const std::string &code) { - std::string str = code; - str.erase(std::remove_if(str.begin(), str.end(), ::isspace), str.end()); - return str; - }; - - EXPECT_EQ(Normalize(actual_code), Normalize(expected_code)) - << "Actual code:\n" << actual_code << "\nExpected:\n" << expected_code; -} - -// 多输入,多输出, 输出唯一解和多个解混合的复杂场景 -REG_ASC_IR(StubOp9New) - .Input("x1", "T1") - .Input("x2", "T2") - .Input("x3", "T3") - .Output("y1", "T2") - .Output("y2", "T1") - .Output("y3", "T4") - .Output("y4", "T5") - .Impl({"socv1"}, - {nullptr, - nullptr, - {{"T1", OrderedTensorTypeList{DT_INT32, DT_INT32, DT_INT64}}, - {"T2", OrderedTensorTypeList{DT_BF16, DT_BF16, DT_FLOAT}}, - {"T3", OrderedTensorTypeList{DT_INT8, DT_INT8, DT_FLOAT}}, - {"T4", OrderedTensorTypeList{DT_BOOL, DT_DOUBLE, DT_FLOAT}}, - {"T5", OrderedTensorTypeList{DT_BOOL, DT_COMPLEX128, DT_DUAL}}}}); -TEST_F(UtestAscendCIR, CheckInferDtypeImplementation_StubOp9New_InferDataType) { - const std::string file_path = std::string(CMAKE_BINARY_DIR) + "/ascir_stub_builtin_ops/ascir_ops.h"; - const std::string target_class = "StubOp9New"; - const std::string target_func = "InferDataType"; - - auto[sig, actual_code] = CodeExtractor::ExtractFunction(file_path, target_class, target_func); - - const std::string expected_code = R"EXPECT( - inline static Status InferDataType(const std::vector& input_dtypes, - std::vector& expect_output_dtypes) { - // 校验入参容器的元素个数是否合法 - GE_ASSERT_EQ(input_dtypes.size(), 3U); - GE_ASSERT_TRUE(expect_output_dtypes.empty() || expect_output_dtypes.size() == 4U); - - char soc_version[128] = {}; - auto res = rtGetSocVersion(soc_version, 128U); - GE_ASSERT_TRUE(res == RT_ERROR_NONE, "Failed to get soc version str."); - auto soc_str = std::string(soc_version); - std::map, std::vector>> results; - if (soc_str == "socv1") { - results = { - {{DT_INT32, DT_BF16, DT_INT8}, {{DT_DOUBLE, DT_BOOL}, {DT_BOOL, DT_COMPLEX128}}}, - {{DT_INT64, DT_FLOAT, DT_FLOAT}, {{DT_FLOAT}, {DT_DUAL}}} - }; - } else if (soc_str == "socv2") { - results = { - {{DT_INT32, DT_BF16, DT_INT8}, {{DT_DOUBLE, DT_BOOL}, {DT_BOOL, DT_COMPLEX128}}}, - {{DT_INT64, DT_FLOAT, DT_FLOAT}, {{DT_FLOAT}, {DT_DUAL}}} - }; - } else if (soc_str == "socv3") { - results = { - {{DT_INT32, DT_BF16, DT_INT8}, {{DT_DOUBLE, DT_BOOL}, {DT_BOOL, DT_COMPLEX128}}}, - {{DT_INT64, DT_FLOAT, DT_FLOAT}, {{DT_FLOAT}, {DT_DUAL}}} - }; - } else { - GELOGE(ge::FAILED, "Failed to get soc version, res:%s", soc_str.c_str()); - return ge::FAILED; - } - - auto iter = results.find(std::vector{input_dtypes[0], input_dtypes[1], input_dtypes[2]}); - GE_WARN_ASSERT(iter != results.end()); - // 输出外部不指定的时候,生成推导的代码 - if (expect_output_dtypes.empty()) { - expect_output_dtypes.push_back(input_dtypes[1]); - expect_output_dtypes.push_back(input_dtypes[0]); - GE_WARN_ASSERT(iter->second[0].size() == 1U); - expect_output_dtypes.push_back(*(iter->second[0].begin())); - GE_WARN_ASSERT(iter->second[1].size() == 1U); - expect_output_dtypes.push_back(*(iter->second[1].begin())); - return ge::SUCCESS; - } - // 输出外部指定,生成校验的代码 - GE_WARN_ASSERT(input_dtypes[1] == expect_output_dtypes[0]); - GE_WARN_ASSERT(input_dtypes[0] == expect_output_dtypes[1]); - GE_WARN_ASSERT(iter->second[0].find(expect_output_dtypes[2]) != iter->second[0].end()); - GE_WARN_ASSERT(iter->second[1].find(expect_output_dtypes[3]) != iter->second[1].end()); - - return SUCCESS; - }; -)EXPECT"; - - auto Normalize = [](const std::string &code) { - std::string str = code; - str.erase(std::remove_if(str.begin(), str.end(), ::isspace), str.end()); - return str; - }; - - EXPECT_EQ(Normalize(actual_code), Normalize(expected_code)) - << "Actual code:\n" << actual_code << "\nExpected:\n" << expected_code; - - OpDtypeInfer().Input(DT_INT32).Input(DT_BF16).Input(DT_INT8).Expect(DT_BF16).Expect(DT_INT32).Expect( - DT_BOOL).Expect(DT_BOOL).AssertSucceed(); - OpDtypeInfer().Input(DT_INT32).Input(DT_BF16).Input(DT_INT8).Expect(DT_BF16).Expect(DT_INT32).Expect( - DT_DOUBLE).Expect(DT_COMPLEX128).AssertSucceed(); - OpDtypeInfer().Input(DT_INT64).Input(DT_FLOAT).Input(DT_FLOAT).Expect(DT_FLOAT).Expect(DT_INT64).Expect( - DT_FLOAT).Expect(DT_DUAL).AssertSucceed(); - // infer out successfully - OpDtypeInfer().Input(DT_INT64).Input(DT_FLOAT).Input(DT_FLOAT).AssertSucceed(); - // infer out failed of multi result - OpDtypeInfer().Input(DT_INT32).Input(DT_BF16).Input(DT_INT8).AssertFailed(); - // check failed of error indicies - OpDtypeInfer().Input(DT_INT64).Input(DT_FLOAT).Input(DT_FLOAT).Expect(DT_FLOAT).Expect(DT_INT64).Expect( - DT_DOUBLE).Expect(DT_DUAL).AssertFailed(); -} - -TEST_F(UtestAscendCIR, CheckInferDataTypeWithNoCheckImplementation_StubOp9New_InferDataTypeWithNoCheck) { - const std::string file_path = std::string(CMAKE_BINARY_DIR) + "/ascir_stub_builtin_ops/ascir_ops.h"; - const std::string target_class = "StubOp9New"; - const std::string target_func = "InferDataTypeWithNoCheck"; - - auto[sig, actual_code] = CodeExtractor::ExtractFunction(file_path, target_class, target_func); - - const std::string expected_code = R"EXPECT( - inline static Status InferDataTypeWithNoCheck(const std::vector& input_dtypes, - std::vector& expect_output_dtypes) { - // 输入输出存在关联, 无法进行推导 - GELOGW("Node type %s is not supported to infernocheck for dtype.", Type); - return ge::FAILED; - }; -)EXPECT"; - - auto Normalize = [](const std::string &code) { - std::string str = code; - str.erase(std::remove_if(str.begin(), str.end(), ::isspace), str.end()); - return str; - }; - - EXPECT_EQ(Normalize(actual_code), Normalize(expected_code)) - << "Actual code:\n" << actual_code << "\nExpected:\n" << expected_code; -} - -TEST_F(UtestAscendCIR, AscNodeAttr_copy_constrcut_for_ir_attr) { - AscGraph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto D = Symbol("d"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto d = graph.CreateAxis("d", D); - auto data1 = graph.CreateContiguousData("data1", ge::DT_FLOAT, {a, b}, ge::FORMAT_DHWCN); - auto stub_op1 = ascir_op::StubOp1("stub_op1"); - // input - stub_op1.x = data1; - // 通过op的方式设置attr - stub_op1.ir_attr.SetMy_float(0.1); - stub_op1.ir_attr.SetMy_int(1); - stub_op1.ir_attr.SetMy_string("stub_test"); - auto node = graph.FindNode("stub_op1"); - EXPECT_NE(node, nullptr); - - auto node_attr2 = node->attr; - // 测试拷贝构造时,如果有ir_attr,则ir_attr的clone方法会被调用,进行ir_attr的拷贝 - EXPECT_NE(node_attr2.ir_attr, nullptr); - auto my_ir_attrs = node_attr2.ir_attr->DownCastTo(); - EXPECT_NE(my_ir_attrs, nullptr); - int64_t get_valuei; - float get_valuef; - std::string get_values; - EXPECT_EQ(my_ir_attrs->GetMy_int(get_valuei), GRAPH_SUCCESS); - EXPECT_FLOAT_EQ(my_ir_attrs->GetMy_float(get_valuef), GRAPH_SUCCESS); - EXPECT_EQ(my_ir_attrs->GetMy_string(get_values), GRAPH_SUCCESS); - EXPECT_EQ(get_valuei, 1); - EXPECT_FLOAT_EQ(get_valuef, 0.1); - EXPECT_EQ(get_values, "stub_test"); -} - -TEST_F(UtestAscendCIR, Concat_OK) { - AscGraph graph("test_graph"); - Expression s0 = graph.CreateSizeVar("s0"); - Expression s1 = graph.CreateSizeVar("s1"); - Axis &s0_axis = graph.CreateAxis("S0", s0); - Axis &s1_axis = graph.CreateAxis("S1", s1); - - ascir_op::Data data("data", graph); - auto data_node = ge::NodeUtilsEx::GetNodeFromOperator(data); - EXPECT_NE(data_node, nullptr); - data.attr.sched.exec_order = 1; - data.attr.sched.axis = {s0_axis.id, s1_axis.id}; - data.y.dtype = ge::DT_FLOAT16; - data.y.format = ge::FORMAT_ND; - *data.y.axis = {s0_axis.id, s1_axis.id}; - *data.y.repeats = {s0, s1}; - *data.y.strides = {s1, sym::kSymbolOne}; - - ascir_op::Abs abs("abs"); - auto abs_node = ge::NodeUtilsEx::GetNodeFromOperator(abs); - EXPECT_EQ(abs_node, nullptr); - abs.x = data.y; - abs.attr.sched.exec_order = 2; - abs.attr.sched.axis = {s0_axis.id, s1_axis.id}; - abs.y.dtype = ge::DT_FLOAT16; - abs.y.format = ge::FORMAT_ND; - *abs.y.axis = {s0_axis.id, s1_axis.id}; - *abs.y.repeats = {s0, s1}; - *abs.y.strides = {s1, sym::kSymbolOne}; - - ascir_op::Exp exp("exp"); - auto exp_node = ge::NodeUtilsEx::GetNodeFromOperator(exp); - EXPECT_EQ(exp_node, nullptr); - exp.x = data.y; - exp.attr.sched.exec_order = 3; - exp.attr.sched.axis = {s0_axis.id, s1_axis.id}; - exp.y.dtype = ge::DT_FLOAT16; - exp.y.format = ge::FORMAT_ND; - *exp.y.axis = {s0_axis.id, s1_axis.id}; - *exp.y.repeats = {s0, s1}; - *exp.y.strides = {s1, sym::kSymbolOne}; - - ascir_op::Concat concat("concat"); - auto concat_node = ge::NodeUtilsEx::GetNodeFromOperator(concat); - EXPECT_EQ(concat_node, nullptr); - concat.x = {abs.y, exp.y}; - concat.attr.sched.exec_order = 4; - concat.attr.sched.axis = {s0_axis.id, s1_axis.id}; - concat.y.dtype = ge::DT_FLOAT16; - concat.y.format = ge::FORMAT_ND; - *concat.y.axis = {s0_axis.id, s1_axis.id}; - *concat.y.repeats = {s0, s1 * ge::Symbol(2)}; - *concat.y.strides = {s1* ge::Symbol(2), sym::kSymbolOne}; - - - // find Node - auto data_node_find = graph.FindNode("data"); - EXPECT_NE(data_node_find, nullptr); - EXPECT_EQ(data_node_find->attr.sched.exec_order, 1); - EXPECT_EQ(data_node_find->attr.sched.axis.size(), 2U); - EXPECT_EQ(data_node_find->attr.sched.axis[0], s0_axis.id); - EXPECT_EQ(data_node_find->outputs[0].attr.axis.size(), 2U); - EXPECT_EQ(ge::DataType(data_node_find->outputs[0].attr.dtype), ge::DT_FLOAT16); - auto abs_node_find = graph.FindNode("abs"); - EXPECT_NE(abs_node_find, nullptr); - - // GetAllNodes - int num = 0; - for (const auto &node : graph.GetAllNodes()) { - if (num == 0) { - EXPECT_EQ(node->GetName(), "data"); - EXPECT_EQ(node->attr.sched.exec_order, 1); - EXPECT_EQ(node->attr.sched.axis.size(), 2U); - EXPECT_EQ(node->attr.sched.axis[0], s0_axis.id); - const auto outputs = node->outputs(); - EXPECT_EQ(outputs.size(), 1U); - EXPECT_NE(outputs[0], nullptr); - EXPECT_EQ(outputs[0]->attr.axis.size(), 2); - } - if (node->GetName() == "concat") { - EXPECT_EQ(node->attr.sched.axis.size(), 2U); - EXPECT_EQ(node->attr.sched.axis[0], s0_axis.id); - EXPECT_EQ(node->outputs[0].attr.axis.size(), 2); - const auto outputs = node->outputs(); - EXPECT_EQ(outputs.size(), 1U); - EXPECT_NE(outputs[0], nullptr); - EXPECT_EQ(outputs[0]->attr.axis.size(), 2); - EXPECT_EQ(outputs[0]->attr.axis[0], s0_axis.id); - EXPECT_EQ(outputs[0]->attr.axis[1], s1_axis.id); - } - num++; - } - EXPECT_EQ(num, 4); - - // GetAllNodes - int input_nodes_num = 0; - for (auto node : graph.GetInputNodes()) { - if (input_nodes_num == 0) { - EXPECT_EQ(node->GetName(), "data"); - EXPECT_EQ(node->attr.sched.exec_order, 1); - EXPECT_EQ(node->attr.sched.axis.size(), 2U); - EXPECT_EQ(node->attr.sched.axis[0], s0_axis.id); - EXPECT_EQ(node->attr.sched.axis[1], s1_axis.id); - EXPECT_EQ(node->outputs[0].attr.axis.size(), 2); - } - input_nodes_num++; - } - EXPECT_EQ(input_nodes_num, 1); - EXPECT_EQ(graph.GetName(), "test_graph"); - - // GetAllAxis - const AscGraph &const_graph = graph; - const auto all_axis = const_graph.GetAllAxis(); - EXPECT_EQ(all_axis.size(), 2U); -} - -TEST_F(UtestAscendCIR, CreateStartNodesWithoutGraph) { - AscGraph graph("test_graph"); - - ascir_op::Data data("data"); - ascir_op::Constant constant("constant"); - ascir_op::Workspace ws("workspace"); - ascir_op::TbufData t_buf("t_buf"); - graph.AddNode(data); - graph.AddNode(constant); - graph.AddNode(ws); - graph.AddNode(t_buf); - - auto data_node = ge::NodeUtilsEx::GetNodeFromOperator(data); - EXPECT_NE(data_node, nullptr); - auto const_node = ge::NodeUtilsEx::GetNodeFromOperator(constant); - EXPECT_NE(const_node, nullptr); - auto ws_node = ge::NodeUtilsEx::GetNodeFromOperator(ws); - EXPECT_NE(ws_node, nullptr); - auto t_buf_node = ge::NodeUtilsEx::GetNodeFromOperator(t_buf); - EXPECT_NE(t_buf_node, nullptr); -} - -TEST_F(UtestAscendCIR, CopyFrom) { - AscGraph sub_graph("Sub1"); - ascir_op::Data sub_data("sub_data", sub_graph); - sub_data.attr.api.type = ApiType::kAPITypeBuffer; - sub_data.attr.api.unit = ComputeUnit::kUnitMTE2; - - ascir_op::Abs sub_abs("sub_abs"); - sub_abs.x = sub_data.y; - sub_abs.attr.sched.exec_order = 2; - sub_abs.y.dtype = ge::DT_FLOAT16; - sub_abs.y.format = ge::FORMAT_ND; - - ascir_op::Output sub_out("sub_out"); - sub_out.x = sub_abs.y; - sub_out.attr.api.type = ApiType::kAPITypeBuffer; - sub_out.attr.api.unit = ComputeUnit::kUnitMTE2; - - AscGraph sub_graph2("Sub2"); - ascir_op::Data sub_data2("sub_data", sub_graph2); - sub_data2.attr.api.type = ApiType::kAPITypeBuffer; - sub_data2.attr.api.unit = ComputeUnit::kUnitMTE2; - - ascir_op::Abs sub_abs2("sub_abs"); - sub_abs2.x = sub_data2.y; - sub_abs2.attr.sched.exec_order = 2; - sub_abs2.y.dtype = ge::DT_FLOAT16; - sub_abs2.y.format = ge::FORMAT_ND; - - ascir_op::Output sub_out2("sub_out"); - sub_out2.x = sub_abs2.y; - sub_out2.attr.api.type = ApiType::kAPITypeBuffer; - sub_out2.attr.api.unit = ComputeUnit::kUnitMTE2; - - AscGraph graph("test_graph"); - Expression s0 = graph.CreateSizeVar("s0"); - Expression s1 = graph.CreateSizeVar("s1"); - Axis &s0_axis = graph.CreateAxis("S0", s0); - Axis &s1_axis = graph.CreateAxis("S1", s1); - - ascir_op::Data data("data", graph); - auto data_node = ge::NodeUtilsEx::GetNodeFromOperator(data); - EXPECT_NE(data_node, nullptr); - data.attr.api.type = ApiType::kAPITypeBuffer; - data.attr.api.unit = ComputeUnit::kUnitMTE1; - data.attr.api.compute_type = ComputeType::kComputeLoad; // fake to check - data.attr.sched.exec_order = 1; - data.attr.sched.axis = {s0_axis.id, s1_axis.id}; - data.y.dtype = ge::DT_FLOAT16; - data.y.format = ge::FORMAT_ND; - *data.y.axis = {s0_axis.id, s1_axis.id}; - *data.y.repeats = {s0, s1}; - *data.y.strides = {s1, sym::kSymbolOne}; - - ascir_op::Abs abs("abs"); - auto abs_node = ge::NodeUtilsEx::GetNodeFromOperator(abs); - EXPECT_EQ(abs_node, nullptr); - abs.x = data.y; - abs.attr.sched.exec_order = 2; - abs.attr.sched.axis = {s0_axis.id, s1_axis.id}; - abs.y.dtype = ge::DT_FLOAT16; - abs.y.format = ge::FORMAT_ND; - *abs.y.axis = {s0_axis.id, s1_axis.id}; - *abs.y.repeats = {s0, s1}; - *abs.y.strides = {s1, sym::kSymbolOne}; - - ascir_op::Exp exp("exp"); - auto exp_node = ge::NodeUtilsEx::GetNodeFromOperator(exp); - EXPECT_EQ(exp_node, nullptr); - exp.x = data.y; - exp.attr.sched.exec_order = 3; - exp.attr.sched.axis = {s0_axis.id, s1_axis.id}; - exp.y.dtype = ge::DT_FLOAT16; - exp.y.format = ge::FORMAT_ND; - *exp.y.axis = {s0_axis.id, s1_axis.id}; - *exp.y.repeats = {s0, s1}; - *exp.y.strides = {s1, sym::kSymbolOne}; - - ascir_op::Concat concat("concat"); - auto concat_node = ge::NodeUtilsEx::GetNodeFromOperator(concat); - EXPECT_EQ(concat_node, nullptr); - concat.x = {abs.y, exp.y}; - concat.x = {abs.y, exp.y}; // reinit - concat.attr.sched.exec_order = 4; - concat.attr.sched.axis = {s0_axis.id, s1_axis.id}; - concat.y.dtype = ge::DT_FLOAT16; - concat.y.format = ge::FORMAT_ND; - *concat.y.axis = {s0_axis.id, s1_axis.id}; - *concat.y.repeats = {s0, s1 * ge::Symbol(2)}; - *concat.y.strides = {s1* ge::Symbol(2), sym::kSymbolOne}; - EXPECT_EQ(graph.AddSubGraph(sub_graph), ge::SUCCESS); - EXPECT_EQ(graph.AddSubGraph(sub_graph2), ge::SUCCESS); - - auto cg = ge::AscGraphUtils::GetComputeGraph(graph); - auto attr = cg->GetOrCreateAttrsGroup(); - ASSERT_NE(attr, nullptr); - attr->type = ge::AscGraphType::kImplGraph; - - AscGraph copy_graph(graph.GetName().c_str()); - copy_graph.CopyFrom(graph); - ge::AscGraph sub1("tmp"); - EXPECT_EQ(copy_graph.FindSubGraph("Sub1", sub1), ge::SUCCESS); - EXPECT_NE(copy_graph.FindSubGraph("Sub3", sub1), ge::SUCCESS); - - auto new_cg = ge::AscGraphUtils::GetComputeGraph(copy_graph); - auto new_attr = new_cg->GetOrCreateAttrsGroup(); - ASSERT_NE(new_attr, nullptr); - EXPECT_EQ(new_attr->type, ge::AscGraphType::kImplGraph); - - std::vector copied_subs; - EXPECT_EQ(copy_graph.GetAllSubGraphs(copied_subs), ge::SUCCESS); - ASSERT_EQ(copied_subs.size(), 2UL); - - auto tmp_node = graph.FindNode("concat"); - ASSERT_NE(tmp_node, nullptr); - ge::AscGraph owner_graph("onwer"); - ASSERT_EQ(ge::AscGraphUtils::FromComputeGraph(tmp_node->GetOwnerComputeGraph(), owner_graph), ge::SUCCESS); - - std::vector copied_subs_new; - EXPECT_EQ(owner_graph.GetAllSubGraphs(copied_subs_new), ge::SUCCESS); - ASSERT_EQ(copied_subs_new.size(), 2UL); - - // check graph attr - auto all_axis = copy_graph.GetAllAxis(); - EXPECT_EQ(all_axis.size(), 2); - EXPECT_EQ(all_axis[0]->name, "S0"); - EXPECT_EQ(all_axis[0]->size, s0); - EXPECT_EQ(all_axis[1]->name, "S1"); - EXPECT_EQ(all_axis[1]->size, s1); - auto all_sizevar = copy_graph.GetAllSizeVar(); - EXPECT_EQ(all_sizevar.size(), 2); - EXPECT_EQ(all_sizevar[0]->expr, Symbol("s0")); - EXPECT_EQ(all_sizevar[1]->expr, Symbol("s1")); - - // check node tensor attr - // data - auto data_node_find = graph.FindNode("data"); - ASSERT_NE(data_node_find, nullptr); - EXPECT_EQ(data_node_find->attr.sched.exec_order, 1); - EXPECT_EQ(data_node_find->attr.sched.axis.size(), 2U); - EXPECT_EQ(data_node_find->attr.sched.axis[0], s0_axis.id); - EXPECT_EQ(data_node_find->attr.api.unit, ComputeUnit::kUnitMTE1); - EXPECT_EQ(data_node_find->attr.api.type, ApiType::kAPITypeBuffer); - EXPECT_EQ(data_node_find->attr.api.compute_type, ComputeType::kComputeLoad); - EXPECT_EQ(data_node_find->outputs[0].attr.axis.size(), 2U); - EXPECT_EQ(ge::DataType(data_node_find->outputs[0].attr.dtype), ge::DT_FLOAT16); - - auto data_node_copy = copy_graph.FindNode("data"); - ASSERT_NE(data_node_copy, nullptr); - EXPECT_EQ(data_node_copy->attr.sched.exec_order, 1); - ASSERT_EQ(data_node_copy->attr.sched.axis.size(), 2U); - EXPECT_EQ(data_node_copy->attr.api.unit, data_node_find->attr.api.unit); - EXPECT_EQ(data_node_copy->attr.api.type, data_node_find->attr.api.type); - EXPECT_EQ(data_node_copy->attr.api.compute_type, data_node_find->attr.api.compute_type); - EXPECT_EQ(data_node_copy->outputs[0].attr.axis.size(), 2U); - EXPECT_EQ(ge::DataType(data_node_copy->outputs[0].attr.dtype), ge::DT_FLOAT16); - - // 测试深拷贝 - data_node_find->outputs[0].attr.dtype = ge::DT_INT8; - EXPECT_EQ(ge::DataType(data_node_copy->outputs[0].attr.dtype), ge::DT_FLOAT16); - auto data_type_copy = data_node_find->outputs[0].attr.dtype; - EXPECT_TRUE(data_type_copy == ge::DT_INT8); - AscTensorDataType data_type_assign; - // 异常测试 - data_type_assign = ge::DT_INT8; - EXPECT_EQ(data_type_assign, ge::DT_UNDEFINED); - // 测试浅拷贝 - data_type_assign = data_node_find->outputs[0].attr.dtype; - EXPECT_TRUE(data_type_assign == ge::DT_INT8); - - // concat - auto concat_node_find = graph.FindNode("concat"); - EXPECT_NE(concat_node_find, nullptr); - EXPECT_EQ(concat_node_find->attr.sched.exec_order, 4); - EXPECT_EQ(concat_node_find->attr.sched.axis.size(), 2U); - EXPECT_EQ(concat_node_find->attr.sched.axis[1], s1_axis.id); - EXPECT_EQ(concat_node_find->outputs[0].attr.axis.size(), 2U); - EXPECT_EQ(ge::DataType(concat_node_find->outputs[0].attr.dtype), ge::DT_FLOAT16); - - auto concat_node_copy = copy_graph.FindNode("concat"); - EXPECT_NE(concat_node_copy, nullptr); - EXPECT_EQ(concat_node_copy->attr.sched.exec_order, 4); - EXPECT_EQ(concat_node_copy->attr.sched.axis.size(), 2U); - EXPECT_EQ(concat_node_copy->attr.sched.axis[1], s1_axis.id); - EXPECT_EQ(concat_node_copy->outputs[0].attr.axis.size(), 2U); - EXPECT_EQ(ge::DataType(concat_node_copy->outputs[0].attr.dtype), ge::DT_FLOAT16); - - // check link (concat) - auto in_node = concat_node_find->GetInDataNodes(); - EXPECT_EQ(in_node.size(), 2); - EXPECT_EQ(in_node.at(0)->GetName(), "abs"); - EXPECT_EQ(in_node.at(1)->GetName(), "exp"); - - auto in_node_copy = concat_node_copy->GetInDataNodes(); - EXPECT_EQ(in_node_copy.size(), 2); - EXPECT_EQ(in_node_copy.at(0)->GetName(), "abs"); - EXPECT_EQ(in_node_copy.at(1)->GetName(), "exp"); -} - -TEST_F(UtestAscendCIR, CopyAttrFrom) { - AscGraph graph("graph"); - graph.SetTilingKey(160); - graph.SetGraphType(ge::AscGraphType::kImplGraph); - auto s0 = graph.CreateSizeVar("s0"); - auto &z0 = graph.CreateAxis("z0", s0); - z0.type = ge::Axis::Type::kAxisTypeBlockOuter; - z0.bind_block = true; - auto &z1 = graph.CreateAxis("z1", ge::sym::kSymbolOne); - z1.type = ge::Axis::Type::kAxisTypeTileInner; - z1.from = {z0.id}; - - AscGraph target_graph("target"); - ASSERT_TRUE(target_graph.CopyAttrFrom(graph)); - - EXPECT_EQ(target_graph.GetTilingKey(), 160); - EXPECT_EQ(target_graph.GetGraphType(), ge::AscGraphType::kImplGraph); - auto all_size_var = target_graph.GetAllSizeVar(); - ASSERT_EQ(all_size_var.size(), 1UL); - EXPECT_EQ(all_size_var[0]->expr, s0); - - auto all_axes = target_graph.GetAllAxis(); - ASSERT_EQ(all_axes.size(), 2UL); - EXPECT_EQ(all_axes[0]->name, z0.name); - EXPECT_EQ(all_axes[0]->type, ge::Axis::Type::kAxisTypeBlockOuter); - EXPECT_EQ(all_axes[0]->bind_block, true); - EXPECT_EQ(all_axes[0]->size, s0); - - EXPECT_EQ(all_axes[1]->name, z1.name); - EXPECT_EQ(all_axes[1]->size, ge::sym::kSymbolOne); - EXPECT_EQ(all_axes[1]->type, ge::Axis::Type::kAxisTypeTileInner); - EXPECT_EQ(all_axes[1]->from, std::vector{z0.id}); -} - -TEST_F(UtestAscendCIR, CopyNodeAttr) { - AscGraph graph("graph"); - Expression s0 = graph.CreateSizeVar("s0"); - Expression s1 = graph.CreateSizeVar("s1"); - Axis &z0 = graph.CreateAxis("S0", s0); - Axis &z1 = graph.CreateAxis("S1", s1); - - ascir_op::Data data("data", graph); - ascir_op::Abs abs("abs"); - abs.x = data.y; - abs.attr.sched.axis = {z0.id, z1.id}; - abs.attr.sched.loop_axis = 1; - abs.attr.api.type = ge::ApiType::kAPITypeCompute; - abs.attr.api.compute_type = ge::ComputeType::kComputeElewise; - abs.y.dtype = ge::DT_FLOAT16; - abs.y.format = ge::FORMAT_ND; - *abs.y.axis = {z0.id, z1.id}; - *abs.y.repeats = {s0, s1}; - *abs.y.strides = {s1, sym::kSymbolOne}; - *abs.y.vectorized_axis = {z0.id, z1.id}; - *abs.y.vectorized_strides = {s1, sym::kSymbolOne}; - ascir_op::Abs abs1("abs1"); - abs1.x = data.y; - - auto abs_node = graph.FindNode("abs"); - auto abs1_node = graph.FindNode("abs1"); - ASSERT_NE(abs_node, nullptr); - ASSERT_NE(abs1_node, nullptr); - ASSERT_TRUE(AscGraph::CopyAscNodeTensorAttr(abs_node, abs1_node)); - - std::vector golden_axis{z0.id, z1.id}; - std::vector golden_repeats{s0, s1}; - std::vector golden_strides{s1, sym::kSymbolOne}; - - EXPECT_EQ(abs1_node->attr.sched.axis, golden_axis); - EXPECT_EQ(abs1_node->attr.sched.loop_axis, 1); - EXPECT_EQ(abs1_node->attr.api.type, ge::ApiType::kAPITypeCompute); - EXPECT_EQ(abs1_node->attr.api.compute_type, ge::ComputeType::kComputeElewise); - - EXPECT_EQ(abs1_node->outputs[0].attr.dtype, ge::DT_FLOAT16); - EXPECT_EQ(abs1_node->outputs[0].attr.axis, golden_axis); - EXPECT_EQ(abs1_node->outputs[0].attr.repeats, golden_repeats); - EXPECT_EQ(abs1_node->outputs[0].attr.strides, golden_strides); - EXPECT_EQ(abs1_node->outputs[0].attr.vectorized_axis, golden_axis); - EXPECT_EQ(abs1_node->outputs[0].attr.vectorized_strides, golden_strides); -} - - -TEST_F(UtestAscendCIR, AscGraphAttr_Clone_Success) { - AscGraphAttr asc_graph_attr; - constexpr uint32_t kMagicNum = 0x5a5a; - asc_graph_attr.tiling_key = kMagicNum; - EXPECT_EQ(asc_graph_attr.type, ge::AscGraphType::kHintGraph); - asc_graph_attr.type = ge::AscGraphType::kImplGraph; - auto clone_attr = asc_graph_attr.Clone(); - ASSERT_NE(clone_attr, nullptr); - auto clone_graph_attr = dynamic_cast(clone_attr.get()); - ASSERT_NE(clone_graph_attr, nullptr); - EXPECT_EQ(clone_graph_attr->tiling_key, kMagicNum); - EXPECT_EQ(clone_graph_attr->type, ge::AscGraphType::kImplGraph); -} - -TEST_F(UtestAscendCIR, AscGraphAttr_Ser_And_Des_Success) { - AscGraphAttr asc_graph_attr; - constexpr uint32_t kMagicNum = 0x5a5a; - asc_graph_attr.tiling_key = kMagicNum; - EXPECT_EQ(asc_graph_attr.type, ge::AscGraphType::kHintGraph); - asc_graph_attr.type = ge::AscGraphType::kImplGraph; - ascendc_ir::proto::AscGraphAttrGroupsDef asc_graph_group; - EXPECT_EQ(asc_graph_attr.SerializeAttr(asc_graph_group), GRAPH_SUCCESS); - EXPECT_EQ(asc_graph_group.tiling_key(), asc_graph_attr.tiling_key); - EXPECT_EQ(asc_graph_group.type(), static_cast(asc_graph_attr.type)); - AscGraphAttr asc_graph_attr2; - asc_graph_attr2.DeserializeAttr(asc_graph_group); - EXPECT_EQ(asc_graph_attr2.tiling_key, asc_graph_attr.tiling_key); - EXPECT_EQ(asc_graph_attr2.type, asc_graph_attr.type); -} - -TEST_F(UtestAscendCIR, AscNodeAttr_Clone_Success) { - AscNodeAttr asc_node_attr; - auto data_ir_attr = ComGraphMakeUnique(); - data_ir_attr->SetIndex(10); - asc_node_attr.ir_attr = std::move(data_ir_attr); - asc_node_attr.api.type = ApiType::kAPITypeCompute; - MemAttr mem_attr{1, AllocType::kAllocTypeGlobal, Position::kPositionGM, MemHardware::kMemHardwareGM, {1}, "mem_name", 2}; - asc_node_attr.tmp_buffers = {TmpBuffer{TmpBufDesc{Expression(), 1}, mem_attr}}; - auto clone_attr = asc_node_attr.Clone(); - ASSERT_NE(clone_attr, nullptr); - auto clone_node_attr = dynamic_cast(clone_attr.get()); - ASSERT_NE(clone_node_attr, nullptr); - EXPECT_EQ(clone_node_attr->api.type, ApiType::kAPITypeCompute); - EXPECT_NE(clone_node_attr->ir_attr, nullptr); - int64_t value_get{-1}; - EXPECT_EQ(clone_node_attr->ir_attr->GetAttrValue("index", value_get), GRAPH_SUCCESS); - EXPECT_EQ(value_get, 10); - EXPECT_EQ(clone_node_attr->tmp_buffers[0].buf_desc.life_time_axis_id, 1); - EXPECT_EQ(clone_node_attr->tmp_buffers[0].mem.name, "mem_name"); - EXPECT_EQ(clone_node_attr->tmp_buffers[0].mem.reuse_id, 2); -} - -TEST_F(UtestAscendCIR, AscTensorAttr_Clone_Success) { - AscTensorAttr asc_tensor_attr; - asc_tensor_attr.mem.alloc_type = AllocType::kAllocTypeL1; - auto clone_attr = asc_tensor_attr.Clone(); - ASSERT_NE(clone_attr, nullptr); - auto clone_tensor_attr = dynamic_cast(clone_attr.get()); - ASSERT_NE(clone_tensor_attr, nullptr); - EXPECT_EQ(clone_tensor_attr->mem.alloc_type, AllocType::kAllocTypeL1); -} - -TEST_F(UtestAscendCIR, AscTensorAttr_Create_Success) { - ascir_op::Data data("data0"); - AscTensorAttr asc_tensor_attr = AscTensorAttr::GetTensorAttr(&data, 0); - asc_tensor_attr.dtype = DT_INT8; - EXPECT_EQ(data.y.dtype, DT_INT8); -} - -namespace ge { -namespace ascir { -inline std::vector> CalcTmpSizeForStubOp11(const ge::AscNode &node) { - std::vector> tmp_buf_descs; - return tmp_buf_descs; -} -} -} - -TEST_F(UtestAscendCIR, CalcAscNodeTmpSize) { - AscGraph graph("test_graph"); - Expression s0 = graph.CreateSizeVar("s0"); - Expression s1 = graph.CreateSizeVar("s1"); - Axis &s0_axis = graph.CreateAxis("S0", s0); - Axis &s1_axis = graph.CreateAxis("S1", s1); - ascir_op::Data data("data", graph); - data.attr.api.type = ApiType::kAPITypeBuffer; - data.attr.api.unit = ComputeUnit::kUnitMTE1; - data.attr.api.compute_type = ComputeType::kComputeLoad; // fake to check - data.attr.sched.exec_order = 1; - data.attr.sched.axis = {s0_axis.id, s1_axis.id}; - auto data_node = ge::NodeUtilsEx::GetNodeFromOperator(data); - data.y.dtype = ge::DT_FLOAT16; - data.y.format = ge::FORMAT_ND; - *data.y.axis = {s0_axis.id, s1_axis.id}; - *data.y.repeats = {s0, s1}; - *data.y.strides = {s1, sym::kSymbolOne}; - - ascir_op::StubOp10 stubOp10("StubOp10"); - stubOp10.x = data.y; - auto stubOp10_node = std::static_pointer_cast(::NodeUtilsEx::GetNodeFromOperator(stubOp10)); - auto tmp_buf_desc_stubOp10 = ge::ascir::CalcAscNodeTmpSize(*stubOp10_node); - EXPECT_EQ(tmp_buf_desc_stubOp10.size(), 1); - EXPECT_EQ(tmp_buf_desc_stubOp10[0]->life_time_axis_id, -1); - EXPECT_EQ(tmp_buf_desc_stubOp10[0]->size, ge::sym::Mul(ge::sym::Mul(Expression(Symbol(2)), s0), s1)); - - ascir_op::StubOp11 stubOp11("StubOp11"); - stubOp11.x = data.y; - auto stubOp11_node = std::static_pointer_cast(::NodeUtilsEx::GetNodeFromOperator(stubOp11)); - auto tmp_buf_desc_stubOp11 = ge::ascir::CalcAscNodeTmpSize(*stubOp11_node); - EXPECT_EQ(tmp_buf_desc_stubOp11.size(), 0); -} - -TEST_F(UtestAscendCIR, AscNodeAttr_TmpBuffer_Serialize) { - AscNodeAttr asc_node_attr; - const TmpBufDesc tmp_buf_desc{Expression(Symbol(1)), 1}; - const MemAttr mem_attr{1, AllocType::kAllocTypeGlobal, Position::kPositionGM, MemHardware::kMemHardwareGM, {1}, "mem_name", 2}; - asc_node_attr.tmp_buffers.emplace_back(TmpBuffer{tmp_buf_desc, mem_attr}); - ascendc_ir::proto::AscNodeAttrGroupsDef asc_node_attr_def; - asc_node_attr.SerializeAttr(asc_node_attr_def); - EXPECT_EQ(asc_node_attr_def.tmp_buffers(0).buf_desc().life_time_axis_id(), 1); - EXPECT_EQ(asc_node_attr_def.tmp_buffers(0).buf_desc().size(), Expression(Symbol(1)).Serialize().get()); - EXPECT_EQ(asc_node_attr_def.tmp_buffers(0).mem().alloc_type(), static_cast(AllocType::kAllocTypeGlobal)); - EXPECT_EQ(asc_node_attr_def.tmp_buffers(0).mem().position(), static_cast(Position::kPositionGM)); - EXPECT_EQ(asc_node_attr_def.tmp_buffers(0).mem().hardware(), static_cast(MemHardware::kMemHardwareGM)); - EXPECT_EQ(asc_node_attr_def.tmp_buffers(0).mem().name(), "mem_name"); - EXPECT_EQ(asc_node_attr_def.tmp_buffers(0).mem().tensor_id(), 1); - EXPECT_EQ(asc_node_attr_def.tmp_buffers(0).mem().buf_ids(0), 1); - EXPECT_EQ(asc_node_attr_def.tmp_buffers(0).mem().reuse_id(), 2); -} - -TEST_F(UtestAscendCIR, AscNodeAttr_TmpBuffer_Deserialize) { - ascendc_ir::proto::AscNodeAttrGroupsDef asc_node_attr_def; - auto tmp_buffer = asc_node_attr_def.add_tmp_buffers(); - auto buf_desc = tmp_buffer->mutable_buf_desc(); - buf_desc->set_life_time_axis_id(1); - buf_desc->set_size(Expression(Symbol(1)).Serialize().get()); - auto mem_attr = tmp_buffer->mutable_mem(); - mem_attr->set_alloc_type(static_cast(AllocType::kAllocTypeGlobal)); - mem_attr->set_position(static_cast(Position::kPositionVecCalc)); - mem_attr->set_hardware(static_cast(MemHardware::kMemHardwareGM)); - mem_attr->set_name("mem_name"); - mem_attr->set_tensor_id(2); - mem_attr->add_buf_ids(1); - mem_attr->add_buf_ids(2); - AscNodeAttr asc_node_attr; - asc_node_attr.DeserializeAttr(asc_node_attr_def); - EXPECT_EQ(asc_node_attr.tmp_buffers[0].buf_desc.life_time_axis_id, 1); - EXPECT_EQ(asc_node_attr.tmp_buffers[0].buf_desc.size, Expression(Symbol(1))); - EXPECT_EQ(asc_node_attr.tmp_buffers[0].mem.alloc_type, AllocType::kAllocTypeGlobal); - EXPECT_EQ(asc_node_attr.tmp_buffers[0].mem.position, Position::kPositionVecCalc); - EXPECT_EQ(asc_node_attr.tmp_buffers[0].mem.hardware, MemHardware::kMemHardwareGM); - EXPECT_EQ(asc_node_attr.tmp_buffers[0].mem.name, "mem_name"); - EXPECT_EQ(asc_node_attr.tmp_buffers[0].mem.tensor_id, 2); - EXPECT_EQ(asc_node_attr.tmp_buffers[0].mem.buf_ids[0], 1); - EXPECT_EQ(asc_node_attr.tmp_buffers[0].mem.buf_ids[1], 2); -} - -TEST_F(UtestAscendCIR, AscNodeAttr_Create_invalid) { - auto invalid_op = Operator(); - auto attr = AscNodeAttr::Create(invalid_op); - EXPECT_TRUE(attr == nullptr); -} - -TEST_F(UtestAscendCIR, AscTensorAttr_Create_invalid) { - auto invalid_op = Operator(); - auto attr = AscTensorAttr::GetTensorAttr(&invalid_op, 0); - EXPECT_EQ(attr.dtype, DT_UNDEFINED); - OutDataAnchor out_data_anchor(nullptr, -1); - auto attr2 = AscTensorAttr::GetTensorAttr(&invalid_op, 0); - EXPECT_EQ(attr2.dtype, DT_UNDEFINED); -} - -TEST_F(UtestAscendCIR, AscOutputAttrFormat_invalid) { - AscOutputAttrFormat asc_output_attr_format(nullptr, UINT32_MAX); - EXPECT_EQ(asc_output_attr_format, FORMAT_RESERVED); - asc_output_attr_format = FORMAT_ND; - EXPECT_EQ(asc_output_attr_format, FORMAT_RESERVED); - auto op = Operator(); - AscOutputAttrFormat asc_output_attr_format1(&op, UINT32_MAX); - EXPECT_EQ(asc_output_attr_format1, FORMAT_RESERVED); - asc_output_attr_format1 = FORMAT_ND; - EXPECT_EQ(asc_output_attr_format1, FORMAT_RESERVED); - auto op2 = Operator("stub", "stub"); - AscOutputAttrFormat asc_output_attr_format2(&op2, 1); - EXPECT_EQ(asc_output_attr_format2, FORMAT_RESERVED); - asc_output_attr_format2 = FORMAT_ND; - EXPECT_EQ(asc_output_attr_format2, FORMAT_RESERVED); -} - -TEST_F(UtestAscendCIR, AscOutputAttrDataType_invalid) { - AscOutputAttrDataType asc_output_attr_data_type(nullptr, UINT32_MAX); - EXPECT_EQ(asc_output_attr_data_type, DT_UNDEFINED); - asc_output_attr_data_type = DT_INT32; - EXPECT_EQ(asc_output_attr_data_type, DT_UNDEFINED); - auto op = Operator(); - AscOutputAttrDataType asc_output_attr_data_type1(&op, UINT32_MAX); - EXPECT_EQ(asc_output_attr_data_type1, DT_UNDEFINED); - asc_output_attr_data_type1 = DT_INT32; - EXPECT_EQ(asc_output_attr_data_type1, DT_UNDEFINED); - auto op2 = Operator("stub", "stub"); - AscOutputAttrDataType asc_output_attr_data_type2(&op2, 1); - EXPECT_EQ(asc_output_attr_data_type2, DT_UNDEFINED); - asc_output_attr_data_type2 = DT_INT32; - EXPECT_EQ(asc_output_attr_data_type2, DT_UNDEFINED); -} - -TEST_F(UtestAscendCIR, CalcAscNodeTmpSizeFunc) { - const std::string file_path = std::string(CMAKE_BINARY_DIR) + "/ascir_stub_builtin_ops/ascir_ops.h"; - const std::string target_class; - const std::string target_func = "CalcAscNodeTmpSize"; - - auto[sig, actual_code] = CodeExtractor::ExtractFunction(file_path, target_class, target_func); - - const std::string expected_code = R"EXPECT( -inline std::vector> CalcAscNodeTmpSize(const ge::AscNode &node) { - typedef std::vector> (*calc_func_ptr) (const AscNode &node); - static const std::unordered_map node_calc_tmp_buff_map = { - {"StubOp10", &SameTmpBufSizeWithFirstInput}, - {"StubOp11", &CalcTmpSizeForStubOp11}, - }; - ge::AscNodeAttr attr = node.attr; - if (node_calc_tmp_buff_map.find(attr.type) != node_calc_tmp_buff_map.end()) { - return node_calc_tmp_buff_map.at(node.attr.type)(node); - } - return std::vector>(); -} -)EXPECT"; - - auto Normalize = [](const std::string &code) { - std::string str = code; - str.erase(std::remove_if(str.begin(), str.end(), ::isspace), str.end()); - return str; - }; - - EXPECT_EQ(Normalize(actual_code), Normalize(expected_code)) - << "Actual code:\n" << actual_code << "\nExpected:\n" << expected_code; -} - -TEST_F(UtestAscendCIR, CommonInferDtypeFuncGen) { - const std::string file_path = std::string(CMAKE_BINARY_DIR) + "/ascir_stub_builtin_ops/ascir_ops.h"; - const std::string target_class; - const std::string target_func = "CommonInferDtype"; - - auto[sig, actual_code] = CodeExtractor::ExtractFunction(file_path, target_class, target_func); - - const std::string expected_code = R"EXPECT( -inline ge::Status CommonInferDtype(const std::string &type, const std::vector &input_dtypes, - std::vector &expect_output_dtypes) { - using func = ge::Status (*)(const std::vector &input_dtypes, - std::vector &expect_output_dtypes); - static const std::unordered_map func_table = { - {"Data", ::ge::ascir_op::Data::InferDataType}, - {"Constant", ::ge::ascir_op::Constant::InferDataType}, - {"IndexExpr", ::ge::ascir_op::IndexExpr::InferDataType}, - {"Workspace", ::ge::ascir_op::Workspace::InferDataType}, - {"TbufData", ::ge::ascir_op::TbufData::InferDataType}, - {"Output", ::ge::ascir_op::Output::InferDataType}, - {"Load", ::ge::ascir_op::Load::InferDataType}, - {"Broadcast", ::ge::ascir_op::Broadcast::InferDataType}, - {"Store", ::ge::ascir_op::Store::InferDataType}, - {"WorkspaceWithInput", ::ge::ascir_op::WorkspaceWithInput::InferDataType}, - {"Nop", ::ge::ascir_op::Nop::InferDataType}, - {"Cast", ::ge::ascir_op::Cast::InferDataType}, - {"Abs", ::ge::ascir_op::Abs::InferDataType}, - {"Exp", ::ge::ascir_op::Exp::InferDataType}, - {"Max", ::ge::ascir_op::Max::InferDataType}, - {"Sum", ::ge::ascir_op::Sum::InferDataType}, - {"Add", ::ge::ascir_op::Add::InferDataType}, - {"Sub", ::ge::ascir_op::Sub::InferDataType}, - {"Div", ::ge::ascir_op::Div::InferDataType}, - {"Mul", ::ge::ascir_op::Mul::InferDataType}, - {"GT", ::ge::ascir_op::GT::InferDataType}, - {"Muls", ::ge::ascir_op::Muls::InferDataType}, - {"MatMul", ::ge::ascir_op::MatMul::InferDataType}, - {"FlashSoftmax", ::ge::ascir_op::FlashSoftmax::InferDataType}, - {"Dropout", ::ge::ascir_op::Dropout::InferDataType}, - {"Select", ::ge::ascir_op::Select::InferDataType}, - {"CalcMean", ::ge::ascir_op::CalcMean::InferDataType}, - {"CalcMeanSlice", ::ge::ascir_op::CalcMeanSlice::InferDataType}, - {"CalcRstd", ::ge::ascir_op::CalcRstd::InferDataType}, - {"CalcRstdSlice", ::ge::ascir_op::CalcRstdSlice::InferDataType}, - {"VFWelfordPart1Update", ::ge::ascir_op::VFWelfordPart1Update::InferDataType}, - {"VFWelfordPart1Finalize", ::ge::ascir_op::VFWelfordPart1Finalize::InferDataType}, - {"VFCalcYWelford", ::ge::ascir_op::VFCalcYWelford::InferDataType}, - {"Concat", ::ge::ascir_op::Concat::InferDataType}, - {"VectorFunction", ::ge::ascir_op::VectorFunction::InferDataType}, - {"FakeOpA", ::ge::ascir_op::FakeOpA::InferDataType}, - {"CalcY", ::ge::ascir_op::CalcY::InferDataType}, - {"CalcMeanStub", ::ge::ascir_op::CalcMeanStub::InferDataType}, - {"StubOp1", ::ge::ascir_op::StubOp1::InferDataType}, - {"StubOp2", ::ge::ascir_op::StubOp2::InferDataType}, - {"StubOp2New", ::ge::ascir_op::StubOp2New::InferDataType}, - {"StubOp3", ::ge::ascir_op::StubOp3::InferDataType}, - {"StubOp3New", ::ge::ascir_op::StubOp3New::InferDataType}, - {"StubOp4", ::ge::ascir_op::StubOp4::InferDataType}, - {"StubOp4New", ::ge::ascir_op::StubOp4New::InferDataType}, - {"StubOp5", ::ge::ascir_op::StubOp5::InferDataType}, - {"StubOp5New", ::ge::ascir_op::StubOp5New::InferDataType}, - {"StubOp6", ::ge::ascir_op::StubOp6::InferDataType}, - {"StubOp6New", ::ge::ascir_op::StubOp6New::InferDataType}, - {"StubOp7", ::ge::ascir_op::StubOp7::InferDataType}, - {"StubOp7New", ::ge::ascir_op::StubOp7New::InferDataType}, - {"StubOp8", ::ge::ascir_op::StubOp8::InferDataType}, - {"StubOp8New", ::ge::ascir_op::StubOp8New::InferDataType}, - {"StubOp9", ::ge::ascir_op::StubOp9::InferDataType}, - {"StubOp9New", ::ge::ascir_op::StubOp9New::InferDataType}, - {"StubOp10", ::ge::ascir_op::StubOp10::InferDataType}, - {"StubOp11", ::ge::ascir_op::StubOp11::InferDataType}, - {"StubRemovePad", ::ge::ascir_op::StubRemovePad::InferDataType}, - }; - const auto &iter = func_table.find(type); - if (iter != func_table.end()) { - return iter->second(input_dtypes, expect_output_dtypes); - } - GELOGW("Node type %s is not supported to infer for now!", type.c_str()); - return ge::FAILED; -} -)EXPECT"; - - auto Normalize = [](const std::string &code) { - std::string str = code; - str.erase(std::remove_if(str.begin(), str.end(), ::isspace), str.end()); - return str; - }; - - EXPECT_EQ(Normalize(actual_code), Normalize(expected_code)) - << "Actual code:\n" << actual_code << "\nExpected:\n" << expected_code; -} - -TEST_F(UtestAscendCIR, CommonInferDtypeWithNoCheckFuncGen) { - const std::string file_path = std::string(CMAKE_BINARY_DIR) + "/ascir_stub_builtin_ops/ascir_ops.h"; - const std::string target_class; - const std::string target_func = "CommonInferDtypeWithNoCheck"; - - auto[sig, actual_code] = CodeExtractor::ExtractFunction(file_path, target_class, target_func); - - const std::string expected_code = R"EXPECT( -inline ge::Status CommonInferDtypeWithNoCheck(const std::string &type, const std::vector &input_dtypes, - std::vector &expect_output_dtypes) { - using func = ge::Status (*)(const std::vector &input_dtypes, - std::vector &expect_output_dtypes); - static const std::unordered_map func_table = { - {"Data", ::ge::ascir_op::Data::InferDataTypeWithNoCheck}, - {"Constant", ::ge::ascir_op::Constant::InferDataTypeWithNoCheck}, - {"IndexExpr", ::ge::ascir_op::IndexExpr::InferDataTypeWithNoCheck}, - {"Workspace", ::ge::ascir_op::Workspace::InferDataTypeWithNoCheck}, - {"TbufData", ::ge::ascir_op::TbufData::InferDataTypeWithNoCheck}, - {"Output", ::ge::ascir_op::Output::InferDataTypeWithNoCheck}, - {"Load", ::ge::ascir_op::Load::InferDataTypeWithNoCheck}, - {"Broadcast", ::ge::ascir_op::Broadcast::InferDataTypeWithNoCheck}, - {"Store", ::ge::ascir_op::Store::InferDataTypeWithNoCheck}, - {"WorkspaceWithInput", ::ge::ascir_op::WorkspaceWithInput::InferDataTypeWithNoCheck}, - {"Nop", ::ge::ascir_op::Nop::InferDataTypeWithNoCheck}, - {"Cast", ::ge::ascir_op::Cast::InferDataTypeWithNoCheck}, - {"Abs", ::ge::ascir_op::Abs::InferDataTypeWithNoCheck}, - {"Exp", ::ge::ascir_op::Exp::InferDataTypeWithNoCheck}, - {"Max", ::ge::ascir_op::Max::InferDataTypeWithNoCheck}, - {"Sum", ::ge::ascir_op::Sum::InferDataTypeWithNoCheck}, - {"Add", ::ge::ascir_op::Add::InferDataTypeWithNoCheck}, - {"Sub", ::ge::ascir_op::Sub::InferDataTypeWithNoCheck}, - {"Div", ::ge::ascir_op::Div::InferDataTypeWithNoCheck}, - {"Mul", ::ge::ascir_op::Mul::InferDataTypeWithNoCheck}, - {"GT", ::ge::ascir_op::GT::InferDataTypeWithNoCheck}, - {"Muls", ::ge::ascir_op::Muls::InferDataTypeWithNoCheck}, - {"MatMul", ::ge::ascir_op::MatMul::InferDataTypeWithNoCheck}, - {"FlashSoftmax", ::ge::ascir_op::FlashSoftmax::InferDataTypeWithNoCheck}, - {"Dropout", ::ge::ascir_op::Dropout::InferDataTypeWithNoCheck}, - {"Select", ::ge::ascir_op::Select::InferDataTypeWithNoCheck}, - {"CalcMean", ::ge::ascir_op::CalcMean::InferDataTypeWithNoCheck}, - {"CalcMeanSlice", ::ge::ascir_op::CalcMeanSlice::InferDataTypeWithNoCheck}, - {"CalcRstd", ::ge::ascir_op::CalcRstd::InferDataTypeWithNoCheck}, - {"CalcRstdSlice", ::ge::ascir_op::CalcRstdSlice::InferDataTypeWithNoCheck}, - {"VFWelfordPart1Update", ::ge::ascir_op::VFWelfordPart1Update::InferDataTypeWithNoCheck}, - {"VFWelfordPart1Finalize", ::ge::ascir_op::VFWelfordPart1Finalize::InferDataTypeWithNoCheck}, - {"VFCalcYWelford", ::ge::ascir_op::VFCalcYWelford::InferDataTypeWithNoCheck}, - {"Concat", ::ge::ascir_op::Concat::InferDataTypeWithNoCheck}, - {"VectorFunction", ::ge::ascir_op::VectorFunction::InferDataTypeWithNoCheck}, - {"FakeOpA", ::ge::ascir_op::FakeOpA::InferDataTypeWithNoCheck}, - {"CalcY", ::ge::ascir_op::CalcY::InferDataTypeWithNoCheck}, - {"CalcMeanStub", ::ge::ascir_op::CalcMeanStub::InferDataTypeWithNoCheck}, - {"StubOp1", ::ge::ascir_op::StubOp1::InferDataTypeWithNoCheck}, - {"StubOp2", ::ge::ascir_op::StubOp2::InferDataTypeWithNoCheck}, - {"StubOp2New", ::ge::ascir_op::StubOp2New::InferDataTypeWithNoCheck}, - {"StubOp3", ::ge::ascir_op::StubOp3::InferDataTypeWithNoCheck}, - {"StubOp3New", ::ge::ascir_op::StubOp3New::InferDataTypeWithNoCheck}, - {"StubOp4", ::ge::ascir_op::StubOp4::InferDataTypeWithNoCheck}, - {"StubOp4New", ::ge::ascir_op::StubOp4New::InferDataTypeWithNoCheck}, - {"StubOp5", ::ge::ascir_op::StubOp5::InferDataTypeWithNoCheck}, - {"StubOp5New", ::ge::ascir_op::StubOp5New::InferDataTypeWithNoCheck}, - {"StubOp6", ::ge::ascir_op::StubOp6::InferDataTypeWithNoCheck}, - {"StubOp6New", ::ge::ascir_op::StubOp6New::InferDataTypeWithNoCheck}, - {"StubOp7", ::ge::ascir_op::StubOp7::InferDataTypeWithNoCheck}, - {"StubOp7New", ::ge::ascir_op::StubOp7New::InferDataTypeWithNoCheck}, - {"StubOp8", ::ge::ascir_op::StubOp8::InferDataTypeWithNoCheck}, - {"StubOp8New", ::ge::ascir_op::StubOp8New::InferDataTypeWithNoCheck}, - {"StubOp9", ::ge::ascir_op::StubOp9::InferDataTypeWithNoCheck}, - {"StubOp9New", ::ge::ascir_op::StubOp9New::InferDataTypeWithNoCheck}, - {"StubOp10", ::ge::ascir_op::StubOp10::InferDataTypeWithNoCheck}, - {"StubOp11", ::ge::ascir_op::StubOp11::InferDataTypeWithNoCheck}, - {"StubRemovePad", ::ge::ascir_op::StubRemovePad::InferDataTypeWithNoCheck}, - }; - const auto &iter = func_table.find(type); - if (iter != func_table.end()) { - return iter->second(input_dtypes, expect_output_dtypes); - } - GELOGW("Node type %s is not supported to infer for now!", type.c_str()); - return ge::FAILED; -} -)EXPECT"; - - auto Normalize = [](const std::string &code) { - std::string str = code; - str.erase(std::remove_if(str.begin(), str.end(), ::isspace), str.end()); - return str; - }; - - EXPECT_EQ(Normalize(actual_code), Normalize(expected_code)) - << "Actual code:\n" << actual_code << "\nExpected:\n" << expected_code; -} - -// 正常场景已经在OpDtypeInfer类中校验,这个用例校验异常场景 -TEST_F(UtestAscendCIR, CommonInferDtypeFunc_invalid_case) { - std::vector outputs; - EXPECT_EQ(ascir::CommonInferDtype("not_support_op", {}, outputs), ge::FAILED); -} -TEST_F(UtestAscendCIR, DataCopyConstructor) { - AscGraph graph("test_graph"); - std::vector data_ops; - data_ops.reserve(2); - std::vector outputs; - for (size_t i = 0; i < 2; ++i) { - std::string name = "x" + std::to_string(i); - auto x_op = Data(name.c_str(), graph); - x_op.y.dtype = ge::DT_FLOAT; - x_op.ir_attr.SetIndex(static_cast(i)); - data_ops.push_back(x_op); - outputs.push_back(data_ops[i].y); - } - int64_t data_ops_0_index, data_ops_1_index; - data_ops[0].ir_attr.GetIndex(data_ops_0_index); - data_ops[1].ir_attr.GetIndex(data_ops_1_index); - EXPECT_EQ(data_ops[0].y.dtype, ge::DT_FLOAT); - EXPECT_EQ(data_ops[1].y.dtype, ge::DT_FLOAT); - EXPECT_EQ(data_ops_0_index, 0); - EXPECT_EQ(data_ops_1_index, 1); - - data_ops[0].y.dtype = ge::DT_FLOAT16; - data_ops[0].ir_attr.SetIndex(1); - data_ops[1].y.dtype = ge::DT_INT16; - data_ops[1].ir_attr.SetIndex(0); - - data_ops[0].ir_attr.GetIndex(data_ops_0_index); - data_ops[1].ir_attr.GetIndex(data_ops_1_index); - EXPECT_EQ(data_ops[0].y.dtype, ge::DT_FLOAT16); - EXPECT_EQ(data_ops[1].y.dtype, ge::DT_INT16); - EXPECT_EQ(data_ops_0_index, 1); - EXPECT_EQ(data_ops_1_index, 0); -} - -TEST_F(UtestAscendCIR, OutputCopyConstructor) { - AscGraph graph("test_graph"); - std::vector output_ops; - output_ops.reserve(2); - std::vector outputs; - for (size_t i = 0; i < 2; ++i) { - std::string name = "y" + std::to_string(i); - auto y_op = Output(name.c_str()); - y_op.y.dtype = ge::DT_FLOAT; - output_ops.push_back(y_op); - outputs.push_back(output_ops[i].y); - } - EXPECT_EQ(output_ops[0].y.dtype, ge::DT_FLOAT); - EXPECT_EQ(output_ops[1].y.dtype, ge::DT_FLOAT); - - output_ops[0].y.dtype = ge::DT_FLOAT16; - output_ops[1].y.dtype = ge::DT_INT16; - - EXPECT_EQ(output_ops[0].y.dtype, ge::DT_FLOAT16); - EXPECT_EQ(output_ops[1].y.dtype, ge::DT_INT16); -} - -TEST_F(UtestAscendCIR, AscOpDynamicInputVectorConstructor) { - AscGraph graph("test_graph"); - std::vector data_ops; - data_ops.reserve(2); - std::vector outputs; - for (size_t i = 0; i < 2; ++i) { - std::string name = "x" + std::to_string(i); - auto x_op = Data(name.c_str(), graph); - x_op.y.dtype = ge::DT_FLOAT; - x_op.ir_attr.SetIndex(static_cast(i)); - data_ops.push_back(x_op); - outputs.push_back(data_ops[i].y); - } - Concat concat_op("concat"); - concat_op.x = outputs; - const auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(concat_op); - const std::string name = "x"; - std::vector indexes; - op_desc->GetDynamicInputIndexesByName(name, indexes); - EXPECT_EQ(indexes.size(), 2); - EXPECT_EQ(indexes[0], static_cast(0)); - EXPECT_EQ(indexes[1], static_cast(1)); - auto concat_node = ge::NodeUtilsEx::GetNodeFromOperator(concat_op); - auto input_nodes = concat_node->GetInNodesPtr(); - EXPECT_EQ(input_nodes.size(), 2); - EXPECT_EQ(input_nodes[0]->GetName(), "x0"); - EXPECT_EQ(input_nodes[1]->GetName(), "x1"); -} - -TEST_F(UtestAscendCIR, AscOpDynamicInputAndOutputToDynamicInput) { - AscGraph graph("test_graph"); - std::vector data_ops; - data_ops.reserve(2); - std::vector outputs; - for (size_t i = 0; i < 2; ++i) { - std::string name = "x" + std::to_string(i); - auto x_op = Data(name.c_str(), graph); - x_op.y.dtype = ge::DT_FLOAT; - x_op.ir_attr.SetIndex(static_cast(i)); - data_ops.push_back(x_op); - outputs.push_back(data_ops[i].y); - } - - VectorFunction vf_op("vf"); - - // 指明有两个输出 - vf_op.InstanceOutputy(2); - - vf_op.x = outputs; - const auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(vf_op); - const std::string name = "x"; - std::vector indexes; - op_desc->GetDynamicInputIndexesByName(name, indexes); - EXPECT_EQ(indexes.size(), 2); - EXPECT_EQ(indexes[0], static_cast(0)); - EXPECT_EQ(indexes[1], static_cast(1)); - auto concat_node = ge::NodeUtilsEx::GetNodeFromOperator(vf_op); - auto input_nodes = concat_node->GetInNodesPtr(); - EXPECT_EQ(input_nodes.size(), 2); - EXPECT_EQ(input_nodes[0]->GetName(), "x0"); - EXPECT_EQ(input_nodes[1]->GetName(), "x1"); - - { - Concat concat_op("concat"); - concat_op.x = vf_op.y; - const auto op_desc2 = ge::OpDescUtils::GetOpDescFromOperator(concat_op); - const std::string name2 = "x"; - std::vector indexes; - op_desc2->GetDynamicInputIndexesByName(name2, indexes); - EXPECT_EQ(indexes.size(), 2); - EXPECT_EQ(indexes[0], static_cast(0)); - EXPECT_EQ(indexes[1], static_cast(1)); - auto concat_node = ge::NodeUtilsEx::GetNodeFromOperator(concat_op); - auto input_nodes = concat_node->GetInNodesPtr(); - EXPECT_EQ(input_nodes.size(), 2); - EXPECT_EQ(input_nodes[0]->GetName(), "vf"); - EXPECT_EQ(input_nodes[1]->GetName(), "vf"); - } -} - -TEST_F(UtestAscendCIR, AscOpDynamicInputAndOutputToNonDynamicInput) { - AscGraph graph("test_graph"); - std::vector data_ops; - data_ops.reserve(2); - std::vector outputs; - for (size_t i = 0; i < 2; ++i) { - std::string name = "x" + std::to_string(i); - auto x_op = Data(name.c_str(), graph); - x_op.y.dtype = ge::DT_FLOAT; - x_op.ir_attr.SetIndex(static_cast(i)); - data_ops.push_back(x_op); - outputs.push_back(data_ops[i].y); - } - - VectorFunction vf_op("vf"); - // 指明有两个输出 - vf_op.InstanceOutputy(2); - - vf_op.x = outputs; - const auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(vf_op); - const std::string name = "x"; - std::vector indexes; - op_desc->GetDynamicInputIndexesByName(name, indexes); - EXPECT_EQ(indexes.size(), 2); - EXPECT_EQ(indexes[0], static_cast(0)); - EXPECT_EQ(indexes[1], static_cast(1)); - auto concat_node = ge::NodeUtilsEx::GetNodeFromOperator(vf_op); - auto input_nodes = concat_node->GetInNodesPtr(); - EXPECT_EQ(input_nodes.size(), 2); - EXPECT_EQ(input_nodes[0]->GetName(), "x0"); - EXPECT_EQ(input_nodes[1]->GetName(), "x1"); - - { - Add add1_op("add1"); - add1_op.x1 = vf_op.y[0]; - add1_op.x2 = vf_op.y[1]; - const auto op_desc2 = ge::OpDescUtils::GetOpDescFromOperator(add1_op); - auto add_node = ge::NodeUtilsEx::GetNodeFromOperator(add1_op); - auto input_nodes = add_node->GetInNodesPtr(); - EXPECT_EQ(input_nodes.size(), 2); - EXPECT_EQ(input_nodes[0]->GetName(), "vf"); - EXPECT_EQ(input_nodes[1]->GetName(), "vf"); - } - - { - Add add2_op("add2"); - add2_op.x1 = vf_op.y[0]; - const auto op_desc2 = ge::OpDescUtils::GetOpDescFromOperator(add2_op); - auto add_node = ge::NodeUtilsEx::GetNodeFromOperator(add2_op); - auto input_nodes = add_node->GetInNodesPtr(); - EXPECT_EQ(input_nodes.size(), 1); - EXPECT_EQ(input_nodes[0]->GetName(), "vf"); - } -} - -TEST_F(UtestAscendCIR, RegisterTilingData) { - auto ir_defs = ge::ascir::AscirRegistry::GetInstance().GetAll(); - EXPECT_NE(ir_defs.find("StubTilingData"), ir_defs.end()); - EXPECT_EQ(ir_defs["StubTilingData"].GetApiTilingDataName(), "StubTilingData"); -} - -TEST_F(UtestAscendCIR, AscirRegisterImpTest) { - class AscIrAttStub : public ge::ascir::AscIrAtt { - virtual void *GetApiPerf() const { - return nullptr; - } - - virtual void *GetMicroApiPerf() const { - return nullptr; - } - virtual void *GetAscendCApiPerfTable() const { - return nullptr; - } - }; - class AscIrCodegenStub : public ge::ascir::AscIrCodegen { - public: - virtual bool IsVectorFunctionSupported(const ge::AscNode &node) const { - return true; - } - bool IsScalarInputSupported(const std::vector &is_scalar_list) const override { - return false; - } - bool IsScalarInputSupportedIfExchangeInputs(const std::vector &is_scalar_list) const override { - return true; - } - - bool IsInplaceSupported(const ge::AscNode &node) const override { - return true; - } - - bool IsBrcInlineSupported(const ge::AscNode &node) const override { - return true; - } - }; - ge::ascir::AscirRegister reg_test; - reg_test.Impl( - {"v1", "v2", "v3"}, - {ge::ascir::AscIrImplCreator(), - ge::ascir::AscIrImplCreator(), - {{"T1", OrderedTensorTypeList{DT_INT8, DT_INT16}}, {"T2", OrderedTensorTypeList{DT_UINT8, DT_INT16}}}}); - - EXPECT_EQ(reg_test.GetSocImplSize(), 3); - - { - REG_ASC_IR(StubAbs).Input("x", "T").Output("y", "T").Impl({"910v1"}, - {ge::ascir::AscIrImplCreator(), - ge::ascir::AscIrImplCreator(), - {{"T", OrderedTensorTypeList{DT_FLOAT16, DT_FLOAT}}}}); - - auto codegen_impl = ge::ascir::AscirRegistry::GetInstance().GetIrCodegenImpl("910v1", "StubAbs"); - auto att_impl = ge::ascir::AscirRegistry::GetInstance().GetIrAttImpl("910v1", "StubAbs"); - EXPECT_NE(att_impl, nullptr); - EXPECT_NE(codegen_impl, nullptr); - AscGraph graph_normal("graph_normal"); - Data x1("x1", graph_normal); - for (const auto &node : graph_normal.GetAllNodes()) { - EXPECT_EQ(codegen_impl->IsVectorFunctionSupported(*node), false); - } - } - - { - REG_ASC_IR(StubAbs2).Input("x", "T").Output("y", "T").Impl({"910v1"}, - {ge::ascir::AscIrImplCreator(), - ge::ascir::AscIrImplCreator(), - {{"T", TensorType{DT_FLOAT16, DT_FLOAT}}}}); - auto codegen_impl = ge::ascir::AscirRegistry::GetInstance().GetIrCodegenImpl("910v1", "StubAbs2"); - auto att_impl = ge::ascir::AscirRegistry::GetInstance().GetIrAttImpl("910v1", "StubAbs2"); - EXPECT_NE(att_impl, nullptr); - EXPECT_NE(codegen_impl, nullptr); - AscGraph graph_normal("graph_normal"); - Data x1("x1", graph_normal); - for (const auto &node : graph_normal.GetAllNodes()) { - EXPECT_EQ(codegen_impl->IsVectorFunctionSupported(*node), true); - } - } - { - REG_ASC_IR(StubAbs3).Input("x", "T").Output("y", "T"); - auto codegen_impl = ge::ascir::AscirRegistry::GetInstance().GetIrCodegenImpl("910v1", "StubAbs3"); - auto att_impl = ge::ascir::AscirRegistry::GetInstance().GetIrAttImpl("910v1", "StubAbs3"); - EXPECT_EQ(att_impl, nullptr); - EXPECT_EQ(codegen_impl, nullptr); - } - { - REG_ASC_IR(StubAdd2).Input("x1", "T").Input("x2", "T").Output("y", "T").Impl({"910v1"}, - {ge::ascir::AscIrImplCreator(), - ge::ascir::AscIrImplCreator(), - {{"T", OrderedTensorTypeList{DT_FLOAT16, DT_FLOAT}}}}); - auto codegen_impl = ge::ascir::AscirRegistry::GetInstance().GetIrCodegenImpl("910v1", "StubAdd2"); - EXPECT_NE(codegen_impl, nullptr); - std::vector is_scalar_list = {false, true}; - EXPECT_EQ(codegen_impl->IsScalarInputSupported(is_scalar_list), false); - EXPECT_EQ(codegen_impl->IsScalarInputSupportedIfExchangeInputs(is_scalar_list), true); - AscGraph graph_normal("graph_normal"); - Data x1("x1", graph_normal); - for (const auto &node : graph_normal.GetAllNodes()) { - EXPECT_EQ(codegen_impl->IsInplaceSupported(*node), true); - EXPECT_EQ(codegen_impl->IsBrcInlineSupported(*node), true); - } - } -} \ No newline at end of file diff --git a/tests/ut/ascendc_ir/testcase/ascir_ops_generator_unittest.cc b/tests/ut/ascendc_ir/testcase/ascir_ops_generator_unittest.cc deleted file mode 100644 index b4944f4d2003a3f878934940591ba4df548c1c7a..0000000000000000000000000000000000000000 --- a/tests/ut/ascendc_ir/testcase/ascir_ops_generator_unittest.cc +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ -#include -#include - -#include "graph/ascendc_ir/ascir_register.h" -#include "graph/types.h" -namespace ge { -namespace ascir { -REG_ASC_IR_1IO(StubData).StartNode(); -REG_ASC_IR_START_NODE(StubConstant).Attr("value"); -REG_ASC_IR_1IO(StubOutput); - -REG_ASC_IR_1IO(StubLoad).UseFirstInputDataType().UseFirstInputView(); -REG_ASC_IR_1IO(StubStore).UseFirstInputDataType().UseFirstInputView(); - -REG_ASC_IR_1IO(StubCast) - .Attr("dst_type").Attr("stub_attr") - .InferDataType([](const AscIrDef &def, std::stringstream &ss) { - ss << " op.y.dtype = dst_type;" << std::endl; - }); - -REG_ASC_IR_2I1O(StubAdd).UseFirstInputDataType().UseFirstInputView(); -REG_ASC_IR(StubFlashSoftmax) - .Inputs({"x1", "x2", "x3"}) - .Outputs({"y1", "y2", "y3"}) - .UseFirstInputDataType(); - -REG_ASC_IR(StubConcat).DynamicInput("x").Outputs({"y"}).UseFirstInputDataType(); - -REG_ASC_IR(StubVectorFunction).DynamicInput("x").DynamicOutput({"y"}); - -REG_ASC_IR_1IO(StubTilingData).ApiTilingDataType("StubTilingData").ApiTilingDataType("TilingData"); - -} // namespace ascir -namespace ascir { -void GenHeaderFileToStream(const char *, std::stringstream &ss); -class GeneratorUT : public testing::Test {}; -TEST(GeneratorUT, Gnerate_Ops_Ok) { - EXPECT_NO_THROW( - std::stringstream ss; - GenHeaderFileToStream("/path/to/hello.h", ss); - std::cout << "===================:" << std::endl; - std::cout << ss.str() << std::endl; - std::cout << "===================:" << std::endl; - ); -} -} // namespace ascir -} diff --git a/tests/ut/ascendc_ir/testcase/axis_utils_unittest.cc b/tests/ut/ascendc_ir/testcase/axis_utils_unittest.cc deleted file mode 100644 index 5fadac3d572f9d7f76bb9ab8fe9af47e6d917033..0000000000000000000000000000000000000000 --- a/tests/ut/ascendc_ir/testcase/axis_utils_unittest.cc +++ /dev/null @@ -1,455 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ -#include -#include - -#include "ascendc_ir/ascend_reg_ops.h" -#include "ascendc_ir/core/ascendc_ir_impl.h" -#include "ascir_ops.h" -#include "ascendc_ir/ascendc_ir_core/ascendc_ir.h" -#include "graph/symbolizer/symbolic.h" -#include "graph/utils/node_utils_ex.h" -#include "graph/utils/axis_utils.h" -#include "graph/expression/const_values.h" - -class UtestAxisUtils : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; -namespace ge { -namespace ascir { -namespace cg { -using ge::Expression; -TEST_F(UtestAxisUtils, ReduceView_ok) { - AscGraph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto[aBO, aBI] = graph.BlockSplit(a.id, "nbi", "nbo"); - (void) aBO; - auto[aBIO, aBII] = graph.TileSplit(aBI->id, "nii", "nio"); - auto data = graph.CreateContiguousData("data0", DT_FLOAT, {a, b}); - - LOOP(*aBO) { - LOOP(*aBIO) { - LOOP(*aBII) { - auto load = Load("load0", data); - auto before_view = View{*load.axis, *load.repeats, *load.strides}; - auto after_view = AxisUtils::ReduceView(before_view, b.id); - EXPECT_EQ(ge::ViewToString(before_view), - "{ axis: [2, 4, 5, 1], repeats: [(A / (nbo_size)), (nbo_size / (nio_size)), nio_size, B], strides: [(B * nbo_size), (B * nio_size), B, 1] }"); - EXPECT_EQ(ge::ViewToString(after_view), - "{ axis: [2, 4, 5, 1], repeats: [(A / (nbo_size)), (nbo_size / (nio_size)), nio_size, B], strides: [nbo_size, nio_size, 1, 0] }"); - } - } - } -} - -TEST_F(UtestAxisUtils, GetDefaultVectorizedAxis_ok) { - std::vector axis = {0, 1, 2, 3}; - EXPECT_EQ(AxisUtils::GetDefaultVectorizedAxis(axis, 0), std::vector({1, 2, 3})); - EXPECT_EQ(AxisUtils::GetDefaultVectorizedAxis(axis, 1), std::vector({2, 3})); - EXPECT_EQ(AxisUtils::GetDefaultVectorizedAxis(axis, 2), std::vector({3})); - EXPECT_EQ(AxisUtils::GetDefaultVectorizedAxis(axis, 3), std::vector({})); - EXPECT_EQ(AxisUtils::GetDefaultVectorizedAxis(axis, 4), std::vector({0, 1, 2, 3})); -} - -TEST_F(UtestAxisUtils, UpdateViewIfCrossLoop_No_need_update) { - AscGraph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto[aBO, aBI] = graph.BlockSplit(a.id, "nbi", "nbo"); - (void) aBO; - auto[aBIO, aBII] = graph.TileSplit(aBI->id, "nii", "nio"); - auto trans_infos = graph.GetAllAxisTransInfo(); - LOOP(*aBO) { - LOOP(*aBIO) { - LOOP(*aBII) { - auto data = ContiguousData("data0", graph, DT_FLOAT, {a, b, c}); - auto load = Load("load0", data); - auto data_attr = - OpDescUtils::GetOpDescFromOperator(data.GetOwnerOp())->GetOrCreateAttrsGroup(); - auto load_attr = CodeGenUtils::GetOwnerOpAscAttr(load.GetOwnerOp()); - EXPECT_EQ(data_attr->sched.axis, std::vector({aBO->id, aBIO->id, aBII->id})); - EXPECT_EQ(data_attr->sched.axis, load_attr->sched.axis); - EXPECT_FALSE(AxisUtils::UpdateViewIfCrossLoop(trans_infos, - data_attr->sched.axis, - load_attr->sched.axis, - {*load.axis, *load.repeats, *load.strides}).first); - } - } - } -} - -TEST_F(UtestAxisUtils, UpdateViewIfCrossLoop_Update_success1) { - AscGraph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto[aBO, aBI] = graph.BlockSplit(a.id, "nbi", "nbo"); - (void) aBO; - auto[aBIO, aBII] = graph.TileSplit(aBI->id, "nii", "nio"); - auto trans_infos = graph.GetAllAxisTransInfo(); - auto data = ContiguousData("data0", graph, DT_FLOAT, {a, b, c}); - auto data_sched_axis = graph.FindNode("data0")->attr.sched.axis; - EXPECT_TRUE(data_sched_axis.empty()); - LOOP(*aBO) { - LOOP(*aBIO) { - LOOP(*aBII) { - // Load接口内部会调用UpdateViewIfCrossLoop - auto load = Load("load0", data); - auto load_attr = OpDescUtils::GetOpDescFromOperator(load.GetOwnerOp())->GetOrCreateAttrsGroup(); - EXPECT_EQ(load_attr->sched.axis, std::vector({aBO->id, aBIO->id, aBII->id})); - EXPECT_NE(data_sched_axis, load_attr->sched.axis); - EXPECT_EQ(*load.axis, std::vector({aBO->id, aBIO->id, aBII->id, b.id, c.id})); - std::vector repeats_expect; - repeats_expect.emplace_back(A / aBI->size); - repeats_expect.emplace_back(aBI->size / aBII->size); - repeats_expect.emplace_back(aBII->size); - repeats_expect.emplace_back(B); - repeats_expect.emplace_back(C); - EXPECT_EQ(load.repeats->size(), repeats_expect.size()); - size_t index = 0U; - for (const auto &re : repeats_expect) { - EXPECT_EQ((*load.repeats)[index++], re); - } - std::vector strides_expect; - strides_expect.emplace_back(B * C * aBI->size); - strides_expect.emplace_back(B * C * aBII->size); - strides_expect.emplace_back(B * C); - strides_expect.emplace_back(C); - strides_expect.emplace_back(sym::kSymbolOne); - EXPECT_EQ(load.strides->size(), strides_expect.size()); - index = 0U; - for (const auto &se : strides_expect) { - EXPECT_EQ((*load.strides)[index++], se); - } - } - } - } -} - -TEST_F(UtestAxisUtils, UpdateViewIfCrossLoop_Update_success2) { - AscGraph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto[aBO, aBI] = graph.BlockSplit(a.id, "nbi", "nbo"); - (void) aBO; - auto aBIB = graph.MergeAxis({aBI->id, b.id}); - auto trans_infos = graph.GetAllAxisTransInfo(); - EXPECT_EQ(trans_infos.size(), 2U); - auto data = ContiguousData("data0", graph, DT_FLOAT, {a, b, c}); - auto data_sched_axis = graph.FindNode("data0")->attr.sched.axis; - EXPECT_TRUE(data_sched_axis.empty()); - LOOP(*aBO) { - LOOP(*aBIB) { - // Load接口内部会调用UpdateViewIfCrossLoop - auto load = Load("load0", data); - auto load_attr = OpDescUtils::GetOpDescFromOperator(load.GetOwnerOp())->GetOrCreateAttrsGroup(); - EXPECT_EQ(load_attr->sched.axis, std::vector({aBO->id, aBIB->id})); - EXPECT_NE(data_sched_axis, load_attr->sched.axis); - // 测试多次调用UpdateViewIfCrossLoop - auto pair = AxisUtils::UpdateViewIfCrossLoop(trans_infos, - data_sched_axis, - load_attr->sched.axis, - {*load.axis, *load.repeats, *load.strides}); - EXPECT_TRUE(pair.first); - View view{*load.axis, *load.repeats, *load.strides}; - view = pair.second; - EXPECT_EQ(*load.axis, std::vector({aBO->id, aBIB->id, c.id})); - std::vector repeats_expect; - repeats_expect.emplace_back(A / aBI->size); - repeats_expect.emplace_back(aBI->size * B); - repeats_expect.emplace_back(C); - EXPECT_EQ(load.repeats->size(), repeats_expect.size()); - size_t index = 0U; - for (const auto &re : repeats_expect) { - EXPECT_EQ((*load.repeats)[index++], re); - } - std::vector strides_expect; - strides_expect.emplace_back(aBI->size * B * C); - strides_expect.emplace_back(C); - strides_expect.emplace_back(sym::kSymbolOne); - EXPECT_EQ(load.strides->size(), strides_expect.size()); - index = 0U; - for (const auto &se : strides_expect) { - EXPECT_EQ((*load.strides)[index++], se); - } - } - } -} - -TEST_F(UtestAxisUtils, UpdateViewIfCrossLoop_DelSceduleAxes_success3) { - AscGraph graph("test_graph"); - auto A = Symbol("A"); - auto L = Symbol("L"); - auto R = Symbol("R"); - auto axis = graph.CreateAxis("axes", A); // 0 - auto loop = graph.CreateAxis("loop", L); // 1 - auto r = graph.CreateAxis("r", R); // 2 - auto[axisB, axisb] = graph.BlockSplit(axis.id, "axisb", "axisB"); // 3,4 - auto[loopT, loopt] = graph.TileSplit(loop.id, "loopt", "loopT"); // 5,6 - auto data0 = ContiguousData("data0", graph, DT_FLOAT, {axis, loop, r}); // 0, 1, 2 - auto data1 = ContiguousData("data1", graph, DT_FLOAT, {axis, loop, r}); - auto data2 = ContiguousData("data2", graph, DT_FLOAT, {axis, loop, r}); - auto data_sched_axis = graph.FindNode("data0")->attr.sched.axis; - LOOP(*axisB) { // 3 - LOOP(*axisb) { // 4 - AscOpOutput y1({loop.id, r.id}); - auto x1 = Load("load1", data0); - auto x2 = Load("load2", data1); - auto x3 = Load("load3", data2); - LOOP(*loopT) { // 5 - auto out1 = CalcY("calc_y", x1, x2, x3, x3); - EXPECT_EQ(*out1.axis, std::vector({axisB->id, axisb->id, loopT->id, loopt->id, r.id})); - std::vector repeats_expect; - repeats_expect.emplace_back(axis.size / axisb->size); - repeats_expect.emplace_back(axisb->size); - repeats_expect.emplace_back(loop.size / loopt->size); - repeats_expect.emplace_back(loopt->size); - repeats_expect.emplace_back(r.size); - EXPECT_EQ(out1.repeats->size(), repeats_expect.size()); - size_t index = 0U; - for (const auto &re : repeats_expect) { - EXPECT_EQ((*out1.repeats)[index++], re) << " index=" << index; - } - std::vector strides_expect; - strides_expect.emplace_back(axisb->size * loop.size * r.size); - strides_expect.emplace_back(loop.size * r.size); - strides_expect.emplace_back(r.size * loopt->size); - strides_expect.emplace_back(r.size); - strides_expect.emplace_back(sym::kSymbolOne); - EXPECT_EQ(out1.strides->size(), strides_expect.size()); - index = 0U; - for (const auto &se : strides_expect) { - EXPECT_EQ((*out1.strides)[index++], se) << " index=" << index; - } - y1.AutoOffset() = out1; - EXPECT_EQ(*y1.vectorized_axis, std::vector({loop.id, r.id})); - } - auto output = Store("store", y1); - EXPECT_EQ(*output.axis, std::vector({axisB->id, axisb->id, loop.id, r.id})); - std::vector repeats_expect; - repeats_expect.emplace_back(axis.size / axisb->size); - repeats_expect.emplace_back(axisb->size); - repeats_expect.emplace_back(loop.size); - repeats_expect.emplace_back(r.size); - EXPECT_EQ(output.repeats->size(), repeats_expect.size()); - size_t index = 0U; - for (const auto &re : repeats_expect) { - EXPECT_EQ((*output.repeats)[index++], re) << " index=" << index; - } - std::vector strides_expect; - strides_expect.emplace_back(axisb->size * loop.size * r.size); - strides_expect.emplace_back(loop.size * r.size); - strides_expect.emplace_back(r.size); - strides_expect.emplace_back(sym::kSymbolOne); - EXPECT_EQ(output.strides->size(), strides_expect.size()); - index = 0U; - for (const auto &se:strides_expect) { - EXPECT_EQ((*output.strides)[index++], se) << " index=" << index; - } - } - } -} - -TEST_F(UtestAxisUtils, UpdateViewIfCrossLoop_AddDelSceduleAxes_success3) { - AscGraph graph("test_graph"); - auto A = Symbol("A"); - auto L = Symbol("L"); - auto R = Symbol("R"); - auto axis = graph.CreateAxis("axes", A); // 0 - auto loop = graph.CreateAxis("loop", L); // 1 - auto r = graph.CreateAxis("r", R); // 2 - auto [axisB, axisb] = graph.BlockSplit(axis.id); // 3,4 - auto [loopT, loopt] = graph.TileSplit(loop.id); // 5,6 - auto [rT, rt] = graph.TileSplit(r.id); // 7,8 - auto data0 = ContiguousData("data0", graph, DT_FLOAT, {axis, loop, r}); - auto data1 = ContiguousData("data1", graph, DT_FLOAT, {axis, loop, r}); - auto data2 = ContiguousData("data2", graph, DT_FLOAT, {axis, loop, r}); - auto data_sched_axis = graph.FindNode("data0")->attr.sched.axis; - LOOP(*axisB) { // 3 - LOOP(*axisb) { // 4 - AscOpOutput y1({loop.id, r.id}); - auto x1 = Load("load1", data0); - auto x2 = Load("load2", data1); - auto x3 = Load("load3", data2); - LOOP(*loopT) { // 5 - auto out1 = CalcY("calc_y", x1, x2, x3, x3); - EXPECT_EQ(*out1.axis, std::vector({axisB->id, axisb->id, loopT->id, loopt->id, r.id})); - std::vector repeats_expect; - repeats_expect.emplace_back(axis.size / axisb->size); - repeats_expect.emplace_back(axisb->size); - repeats_expect.emplace_back(loop.size / loopt->size); - repeats_expect.emplace_back(loopt->size); - repeats_expect.emplace_back(r.size); - EXPECT_EQ(out1.repeats->size(), repeats_expect.size()); - size_t index = 0U; - for (const auto &re : repeats_expect) { - EXPECT_EQ((*out1.repeats)[index++], re); - } - std::vector strides_expect; - strides_expect.emplace_back(axisb->size * loop.size * r.size); - strides_expect.emplace_back(loop.size * r.size); - strides_expect.emplace_back(r.size * loopt->size); - strides_expect.emplace_back(r.size); - strides_expect.emplace_back(sym::kSymbolOne); - EXPECT_EQ(out1.strides->size(), strides_expect.size()); - index = 0U; - for (const auto &se : strides_expect) { - EXPECT_EQ((*out1.strides)[index++], se); - } - y1.AutoOffset() = out1; - EXPECT_EQ(*y1.vectorized_axis, std::vector({loop.id, r.id})); - } - LOOP(*rT) { // 7 - auto output = Store("store", y1); - // 3,4,1,7,8 - EXPECT_EQ(*output.axis, std::vector({axisB->id, axisb->id, rT->id, loop.id, rt->id})); - std::vector repeats_expect; - repeats_expect.emplace_back(axis.size / axisb->size); - repeats_expect.emplace_back(axisb->size); - repeats_expect.emplace_back(r.size / rt->size); - repeats_expect.emplace_back(loop.size); - repeats_expect.emplace_back(rt->size); - EXPECT_EQ(output.repeats->size(), repeats_expect.size()); - size_t index = 0U; - for (const auto &re : repeats_expect) { - EXPECT_EQ((*output.repeats)[index++], re) << " index=" << index; - } - std::vector strides_expect; - strides_expect.emplace_back(axisb->size * loop.size * r.size); - strides_expect.emplace_back(loop.size * r.size); - strides_expect.emplace_back(rt->size); - strides_expect.emplace_back(r.size); - strides_expect.emplace_back(sym::kSymbolOne); - EXPECT_EQ(output.strides->size(), strides_expect.size()); - index = 0U; - for (const auto &se:strides_expect) { - EXPECT_EQ((*output.strides)[index++], se) << " index=" << index; - } - } - } - } -} - -TEST_F(UtestAxisUtils, UpdateViewIfCrossLoop_ReorderViewSuccess) { - AscGraph graph("test_graph"); - auto A = Symbol("A"); - auto L = Symbol("L"); - auto R = Symbol("R"); - auto axis = graph.CreateAxis("axes", A); // 0 - auto loop = graph.CreateAxis("loop", L); // 1 - auto r = graph.CreateAxis("r", R); // 2 - auto [axisB, axisb] = graph.BlockSplit(axis.id); // 3,4 - auto [loopT, loopt] = graph.TileSplit(loop.id); // 5,6 - auto [rT, rt] = graph.TileSplit(r.id); // 7,8 - std::vector axes = {3, 4, 7, 8, 1}; - std::vector repeats_expect; - repeats_expect.emplace_back(axis.size / axisb->size); // 3 - repeats_expect.emplace_back(axisb->size); // 4 - repeats_expect.emplace_back(r.size / rt->size); // 7 - repeats_expect.emplace_back(loop.size); // 8 - repeats_expect.emplace_back(rt->size); // 1 - - std::vector strides_expect; - strides_expect.emplace_back(axisb->size * loop.size * r.size); - strides_expect.emplace_back(loop.size * r.size); - strides_expect.emplace_back(rt->size); - strides_expect.emplace_back(r.size); - strides_expect.emplace_back(sym::kSymbolOne); - - View src_view{axes, repeats_expect, strides_expect}; - std::vector my_api_sched_axes = {3, 4, 7}; - auto dst_view = AxisUtils::ReorderView(src_view, my_api_sched_axes); - auto [axes_res, repeats, strides] = dst_view; - std::vector expect_axes = {3, 4, 7, 8, 1}; - EXPECT_EQ(axes_res, expect_axes); - size_t index = 0U; - for (const auto &re : repeats_expect) { - EXPECT_EQ(repeats[index++], re) << " index=" << index; - } - index = 0U; - for (const auto &re : strides_expect) { - EXPECT_EQ(strides[index++], re) << " index=" << index; - } -} - -TEST_F(UtestAxisUtils, UpdateViewIfCrossLoop_ReorderViewSuccess2) { - AscGraph graph("test_graph"); - auto A = Symbol("A"); - auto L = Symbol("L"); - auto R = Symbol("R"); - auto axis = graph.CreateAxis("axes", A); // 0 - auto loop = graph.CreateAxis("loop", L); // 1 - auto r = graph.CreateAxis("r", R); // 2 - auto [axisB, axisb] = graph.BlockSplit(axis.id); // 3,4 - auto [loopT, loopt] = graph.TileSplit(loop.id); // 5,6 - auto [rT, rt] = graph.TileSplit(r.id); // 7,8 - std::vector axes = {8, 7, 1, 3, 4}; - std::vector repeats; - repeats.emplace_back(loop.size); // 8 - repeats.emplace_back(r.size / rt->size); // 7 - repeats.emplace_back(rt->size); // 1 - repeats.emplace_back(axis.size / axisb->size); // 3 - repeats.emplace_back(axisb->size); // 4 - - std::vector strides; - strides.emplace_back(r.size); // 8 - strides.emplace_back(rt->size); // 7 - strides.emplace_back(sym::kSymbolOne); // 1 - strides.emplace_back(axisb->size * loop.size * r.size); // 3 - strides.emplace_back(loop.size * r.size); // 4 - - View src_view{axes, repeats, strides}; - std::vector my_api_sched_axes = {3, 4, 7}; - auto dst_view = AxisUtils::ReorderView(src_view, my_api_sched_axes); - auto [axes_res, repeats_res, strides_res] = dst_view; - std::vector expect_axes = {3, 4, 7, 8, 1}; - - std::vector repeats_expect; - repeats_expect.emplace_back(axis.size / axisb->size); // 3 - repeats_expect.emplace_back(axisb->size); // 4 - repeats_expect.emplace_back(r.size / rt->size); // 7 - repeats_expect.emplace_back(loop.size); // 8 - repeats_expect.emplace_back(rt->size); // 1 - - std::vector strides_expect; - strides_expect.emplace_back(axisb->size * loop.size * r.size); // 3 - strides_expect.emplace_back(loop.size * r.size); // 4 - strides_expect.emplace_back(rt->size); // 7 - strides_expect.emplace_back(r.size); // 8 - strides_expect.emplace_back(sym::kSymbolOne); // 1 - EXPECT_EQ(axes_res, expect_axes); - size_t index = 0U; - for (const auto &re : repeats_expect) { - EXPECT_EQ(repeats_res[index++], re) << " index=" << index; - } - index = 0U; - for (const auto &re : strides_expect) { - EXPECT_EQ(strides_res[index++], re) << " index=" << index; - } -} -} -} -} \ No newline at end of file diff --git a/tests/ut/ascendc_ir/testcase/code_gen_utils_unittest.cc b/tests/ut/ascendc_ir/testcase/code_gen_utils_unittest.cc deleted file mode 100644 index b6488ed74598aabcb2fc9f4bd76e791a5794ff65..0000000000000000000000000000000000000000 --- a/tests/ut/ascendc_ir/testcase/code_gen_utils_unittest.cc +++ /dev/null @@ -1,1077 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ -#include -#include -#include "ascir_ops.h" -#include "graph/utils/cg_utils.h" -#include "inc/graph/symbolizer/symbolic.h" -#include "expression/const_values.h" - -#define EXPECT_VIEW_PTR_EQ(tensor0, tensor1) \ - EXPECT_EQ(*tensor0.axis, *tensor1.axis);\ - EXPECT_EQ(*tensor0.strides, *tensor1.strides);\ - EXPECT_EQ(*tensor0.repeats, *tensor1.repeats); - -#define EXPECT_VIEW_EQ(tensor0, tensor1) \ - EXPECT_EQ(tensor0.axis, tensor1.axis); \ - EXPECT_EQ(tensor0.strides, tensor1.strides); \ - EXPECT_EQ(tensor0.repeats, tensor1.repeats); - -#define EXPECT_VIEW_AND_DTYPE_EQ(tensor0, tensor1) \ - EXPECT_VIEW_EQ(tensor0, tensor1) \ - EXPECT_EQ(tensor0.dtype, tensor1.dtype) - -namespace ge { -namespace ascir { -namespace cg { -using Graph = ge::AscGraph; -using ge::Expression; -using ge::Symbol; -Graph ConstructTestGraph(const std::string &graph_name) { - Graph graph(graph_name.c_str()); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto D = Symbol("d"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto d = graph.CreateAxis("d", D); - LOOP(a) { - LOOP(b) { - LOOP(c) { - OPTION_LOOP(d, LoopOption{true}) { - // data0(GM1)-------------------------------| - // |->load0(TQue1)-->mm(TQue3)--------->y(TQue4) - // data1(GM2)--->load1(TQue2)-| | - // |_______________________________| - // data2(TBuf1)________________________________| - auto data0 = ContiguousData("data0", graph, ge::DT_FLOAT16, {a, b, d}); - auto data1 = ContiguousData("data1", graph, ge::DT_FLOAT16, {a, c, d}); - AscendString name; - data1.GetOwnerOp().GetName(name); - EXPECT_EQ("data1", std::string(name.GetString())); - auto load0 = Load("load0", data0).TQue(Position::kPositionVecIn, 1, 2); - auto load1 = Load("load1", data1).TQue(Position::kPositionVecIn, 1, 2); - auto mm = MatMul("mm", load0, load1).TQue(Position::kPositionVecOut, 1, 1); - auto data2 = ContiguousData("data2", graph, ge::DT_FLOAT, {a, c, d}).TBuf(Position::kPositionVecOut); - auto y = CalcY("y", data0, data2, data1, mm).TQue(Position::kPositionVecOut, 1, 1); - EXPECT_EQ(y.dtype, ge::DT_FLOAT); - } - } - } - } - return graph; -} - -TEST(CgUtils, SetGetContextOk) { - Graph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto ctx = CgContext::GetSharedThreadLocalContext(); - ASSERT_EQ(ctx, nullptr); - auto ctx_obj = std::make_shared(); - CgContext::SetThreadLocalContext(ctx_obj); - ctx = CgContext::GetSharedThreadLocalContext(); - ASSERT_NE(ctx, nullptr); - ctx->SetLoopAxes({a, b, c}); - ASSERT_EQ(ctx->GetLoopAxes().size(), 3); - ctx->SetBlockLoopEnd(a.id); - ASSERT_EQ(ctx->GetBlockLoopEnd(), a.id); - ctx->SetVectorizedLoopEnd(c.id); - ASSERT_EQ(ctx->GetVectorizedLoopEnd(), c.id); - ctx->SetLoopEnd(c.id); - ASSERT_EQ(ctx->GetLoopEnd(), c.id); -} - -TEST(CgUtils, LoopGuardContextOk) { - Graph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - - int64_t count = 0; - ASSERT_EQ(CgContext::GetThreadLocalContext(), nullptr); - ASSERT_EQ(CgContext::GetSharedThreadLocalContext(), nullptr); - LOOP(a) { - LOOP(b) { - LOOP(c) { - ++count; - ASSERT_NE(CgContext::GetThreadLocalContext(), nullptr); - ASSERT_NE(CgContext::GetSharedThreadLocalContext(), nullptr); - } - } - } - ASSERT_EQ(count, 1); - ASSERT_EQ(CgContext::GetThreadLocalContext(), nullptr); - ASSERT_EQ(CgContext::GetSharedThreadLocalContext(), nullptr); -} -TEST(CgUtils, OptionLoopGuardContextOk) { - Graph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - - int64_t count = 0; - ASSERT_EQ(CgContext::GetThreadLocalContext(), nullptr); - ASSERT_EQ(CgContext::GetSharedThreadLocalContext(), nullptr); - LOOP(a) { - LOOP(b) { - OPTION_LOOP(c, LoopOption{}) { - ++count; - ASSERT_NE(CgContext::GetThreadLocalContext(), nullptr); - ASSERT_EQ(CgContext::GetThreadLocalContext()->GetOption().pad_tensor_axes_to_loop, false); - ASSERT_NE(CgContext::GetSharedThreadLocalContext(), nullptr); - } - } - } - ASSERT_EQ(count, 1); - ASSERT_EQ(CgContext::GetThreadLocalContext(), nullptr); - ASSERT_EQ(CgContext::GetSharedThreadLocalContext(), nullptr); - - OPTION_LOOP(a, LoopOption{.pad_tensor_axes_to_loop = true}) { - ASSERT_NE(CgContext::GetThreadLocalContext(), nullptr); - ASSERT_EQ(CgContext::GetThreadLocalContext()->GetOption().pad_tensor_axes_to_loop, true); - ASSERT_NE(CgContext::GetSharedThreadLocalContext(), nullptr); - LOOP(b) { - LOOP(c) { - ++count; - ASSERT_NE(CgContext::GetThreadLocalContext(), nullptr); - ASSERT_EQ(CgContext::GetThreadLocalContext()->GetOption().pad_tensor_axes_to_loop, false); - ASSERT_NE(CgContext::GetSharedThreadLocalContext(), nullptr); - } - } - } - ASSERT_EQ(count, 2); - ASSERT_EQ(CgContext::GetThreadLocalContext(), nullptr); - ASSERT_EQ(CgContext::GetSharedThreadLocalContext(), nullptr); -} -TEST(CgUtils, NestedLoopGuardContextOk) { - Graph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto D = Symbol("D"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto d = graph.CreateAxis("d", D); - - ASSERT_EQ(CgContext::GetThreadLocalContext(), nullptr); - ASSERT_EQ(CgContext::GetSharedThreadLocalContext(), nullptr); - LOOP(a) { - ASSERT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes().size(), 1UL); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[0].name, a.name); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[0].id, a.id); - - LOOP(b) { - LOOP(c) { - ASSERT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes().size(), 3UL); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[0].name, a.name); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[0].id, a.id); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[1].name, b.name); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[1].id, b.id); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[2].name, c.name); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[2].id, c.id); - - LOOP(d) { - ASSERT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes().size(), 4UL); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[0].name, a.name); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[0].id, a.id); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[1].name, b.name); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[1].id, b.id); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[2].name, c.name); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[2].id, c.id); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[3].name, d.name); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[3].id, d.id); - } - - ASSERT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes().size(), 3UL); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[0].name, a.name); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[0].id, a.id); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[1].name, b.name); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[1].id, b.id); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[2].name, c.name); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[2].id, c.id); - } - } - - ASSERT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes().size(), 1UL); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[0].name, a.name); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[0].id, a.id); - } - ASSERT_EQ(CgContext::GetThreadLocalContext(), nullptr); - ASSERT_EQ(CgContext::GetSharedThreadLocalContext(), nullptr); -} -TEST(CgUtils, LoopGuardAxisOk) { - Graph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - - LOOP(a) { - LOOP(b) { - LOOP(c) { - ASSERT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes().size(), 3UL); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[0].name, a.name); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[0].id, a.id); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[1].name, b.name); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[1].id, b.id); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[2].name, c.name); - EXPECT_EQ(CgContext::GetThreadLocalContext()->GetLoopAxes()[2].id, c.id); - } - } - } -} -TEST(CgUtils, LoopGuard_SchedAxis_Ok) { - Graph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - - LOOP(a) { - LOOP(b) { - LOOP(c) { - auto data0 = ContiguousData("data0", graph, ge::DT_FLOAT, {a, b}); - auto data1 = ContiguousData("data1", graph, ge::DT_FLOAT, {b, c}); - auto mm = MatMul("mm", data0, data1); - (void) mm; // -Werror=unused-but-set-variable - } - } - } - - auto data0 = graph.FindNode("data0"); - auto data1 = graph.FindNode("data1"); - auto mm = graph.FindNode("mm"); - ASSERT_EQ(std::vector(data0->attr.sched.axis), std::vector({a.id, b.id, c.id})); - ASSERT_EQ(std::vector(data1->attr.sched.axis), std::vector({a.id, b.id, c.id})); - ASSERT_EQ(std::vector(mm->attr.sched.axis), std::vector({a.id, b.id, c.id})); -} - -TEST(PadTensorAxisToSched, NoContext_DoNotPad) { - Graph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - - ge::ascir_op::Data data("data", graph); - *data.y.axis = {a.id}; - *data.y.repeats = {A}; - *data.y.strides = {sym::kSymbolOne}; - - ASSERT_TRUE(PadOutputViewToSched(data.y)); - EXPECT_EQ(*data.y.axis, std::vector({a.id})); - EXPECT_TRUE((*data.y.repeats)[0] == A); - EXPECT_TRUE((*data.y.strides)[0] == 1); -} - -TEST(PadTensorAxisToSched, NotConfigPad_DoNotPad) { - Graph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - - ge::ascir_op::Data data("data", graph); - *data.y.axis = {a.id}; - *data.y.repeats = {A}; - *data.y.strides = {sym::kSymbolOne}; - - LOOP(a) { - LOOP(b) { - LOOP(c) { - ASSERT_TRUE(PadOutputViewToSched(data.y)); - } - } - } - EXPECT_EQ(*data.y.axis, std::vector({a.id})); - EXPECT_TRUE((*data.y.repeats)[0] == A); - EXPECT_TRUE((*data.y.strides)[0] == 1); -} - -TEST(PadTensorAxisToSched, NoNeedPad_Ok) { - Graph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - - ge::ascir_op::Data data("data", graph); - *data.y.axis = {a.id, b.id, c.id}; - *data.y.repeats = {A, B, C}; - *data.y.strides = {C, sym::kSymbolZero, sym::kSymbolOne}; - LOOP(a) { - LOOP(b) { - OPTION_LOOP(c, LoopOption{true}) { - ASSERT_TRUE(PadOutputViewToSched(data.y)); - } - } - } - - EXPECT_EQ(*data.y.axis, std::vector({a.id, b.id, c.id})); - EXPECT_TRUE((*data.y.repeats)[0] == A); - EXPECT_TRUE((*data.y.repeats)[0] == A); - EXPECT_TRUE((*data.y.repeats)[1] == B); - EXPECT_TRUE((*data.y.repeats)[2] == C); - EXPECT_TRUE((*data.y.strides)[0] == C); - EXPECT_TRUE((*data.y.strides)[1] == sym::kSymbolZero); - EXPECT_TRUE((*data.y.strides)[2] == sym::kSymbolOne); -} - -TEST(PadTensorAxisToSched, PadHead) { - Graph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto D = Symbol("d"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto d = graph.CreateAxis("d", D); - - ge::ascir_op::Data data("data", graph); - data.y.SetContiguousView({c, d}); - LOOP(a) { - LOOP(b) { - LOOP(c) { - OPTION_LOOP(d, LoopOption{true}) { - ASSERT_TRUE(PadOutputViewToSched(data.y)); - } - } - } - } - EXPECT_EQ(*data.y.axis, std::vector({a.id, b.id, c.id, d.id})); - EXPECT_TRUE((*data.y.repeats)[0] == sym::kSymbolOne); - EXPECT_TRUE((*data.y.repeats)[1] == sym::kSymbolOne); - EXPECT_TRUE((*data.y.repeats)[2] == C); - EXPECT_TRUE((*data.y.repeats)[3] == D); - std::cout << "strides 0:" << (*data.y.strides)[0] << std::endl; - std::cout << "strides 1:" << (*data.y.strides)[1] << std::endl; - EXPECT_TRUE((*data.y.strides)[0] == sym::kSymbolZero); - EXPECT_TRUE((*data.y.strides)[1] == sym::kSymbolZero); - EXPECT_TRUE((*data.y.strides)[2] == D); - EXPECT_TRUE((*data.y.strides)[3] == sym::kSymbolOne); -} -TEST(PadTensorAxisToSched, PadTail) { - Graph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto D = Symbol("d"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto d = graph.CreateAxis("d", D); - - ge::ascir_op::Data data("data", graph); - data.y.SetContiguousView({a, b}); - LOOP(a) { - LOOP(b) { - LOOP(c) { - OPTION_LOOP(d, LoopOption{true}) { - ASSERT_TRUE(PadOutputViewToSched(data.y)); - } - } - } - } - - EXPECT_EQ(*data.y.axis, std::vector({a.id, b.id, c.id, d.id})); - EXPECT_TRUE((*data.y.repeats)[0] == A); - EXPECT_TRUE((*data.y.repeats)[1] == B); - EXPECT_TRUE((*data.y.repeats)[2] == sym::kSymbolOne); - EXPECT_TRUE((*data.y.repeats)[3] == sym::kSymbolOne); - EXPECT_TRUE((*data.y.strides)[0] == B); - EXPECT_TRUE((*data.y.strides)[1] == sym::kSymbolOne); - EXPECT_TRUE((*data.y.strides)[2] == sym::kSymbolZero); - EXPECT_TRUE((*data.y.strides)[3] == sym::kSymbolZero); -} - -TEST(PadTensorAxisToSched, PadTail_NotContiguous) { - Graph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto D = Symbol("d"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto d = graph.CreateAxis("d", D); - - ge::ascir_op::Data data("data", graph); - *data.y.axis = {a.id, b.id, c.id}; - *data.y.repeats = {A, sym::kSymbolOne, C}; - *data.y.strides = {C, sym::kSymbolZero, sym::kSymbolOne}; - - LOOP(a) { - LOOP(b) { - LOOP(c) { - OPTION_LOOP(d, LoopOption{true}) { - ASSERT_TRUE(PadOutputViewToSched(data.y)); - } - } - } - } - EXPECT_EQ(*data.y.axis, std::vector({a.id, b.id, c.id, d.id})); - EXPECT_TRUE((*data.y.repeats)[0] == A); - EXPECT_TRUE((*data.y.repeats)[1] == sym::kSymbolOne); - EXPECT_TRUE((*data.y.repeats)[2] == C); - EXPECT_TRUE((*data.y.repeats)[3] == sym::kSymbolOne); - EXPECT_TRUE((*data.y.strides)[0] == C); - EXPECT_TRUE((*data.y.strides)[1] == sym::kSymbolZero); - EXPECT_TRUE((*data.y.strides)[2] == sym::kSymbolOne); - EXPECT_TRUE((*data.y.strides)[3] == sym::kSymbolZero); -} -TEST(PadTensorAxisToSched, PadMiddle) { - Graph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto D = Symbol("d"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto d = graph.CreateAxis("d", D); - - ge::ascir_op::Data data("data", graph); - data.y.SetContiguousView({a, d}); - - LOOP(a) { - LOOP(b) { - LOOP(c) { - OPTION_LOOP(d, LoopOption{true}) { - ASSERT_TRUE(PadOutputViewToSched(data.y)); - } - } - } - } - EXPECT_EQ(*data.y.axis, std::vector({a.id, b.id, c.id, d.id})); - EXPECT_TRUE((*data.y.repeats)[0] == A); - EXPECT_TRUE((*data.y.repeats)[1] == sym::kSymbolOne); - EXPECT_TRUE((*data.y.repeats)[2] == sym::kSymbolOne); - EXPECT_TRUE((*data.y.repeats)[3] == D); - EXPECT_TRUE((*data.y.strides)[0] == D); - EXPECT_TRUE((*data.y.strides)[1] == sym::kSymbolZero); - EXPECT_TRUE((*data.y.strides)[2] == sym::kSymbolZero); - EXPECT_TRUE((*data.y.strides)[3] == sym::kSymbolOne); -} -TEST(PadTensorAxisToSched, PadMultiple) { - Graph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto D = Symbol("d"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto d = graph.CreateAxis("d", D); - - ge::ascir_op::Data data("data", graph); - data.y.SetContiguousView({b, d}); - - LOOP(a) { - LOOP(b) { - LOOP(c) { - OPTION_LOOP(d, LoopOption{true}) { - ASSERT_TRUE(PadOutputViewToSched(data.y)); - } - } - } - } - EXPECT_EQ(*data.y.axis, std::vector({a.id, b.id, c.id, d.id})); - EXPECT_EQ(*data.y.axis, std::vector({a.id, b.id, c.id, d.id})); - EXPECT_TRUE((*data.y.repeats)[0] == sym::kSymbolOne); - EXPECT_TRUE((*data.y.repeats)[1] == B); - EXPECT_TRUE((*data.y.repeats)[2] == sym::kSymbolOne); - EXPECT_TRUE((*data.y.repeats)[3] == D); - EXPECT_TRUE((*data.y.strides)[0] == sym::kSymbolZero); - EXPECT_TRUE((*data.y.strides)[1] == D); - EXPECT_TRUE((*data.y.strides)[2] == sym::kSymbolZero); - EXPECT_TRUE((*data.y.strides)[3] == sym::kSymbolOne); -} - -TEST(PadTensorAxisToSched, SameAxisNumButNotMatch_Failed) { - Graph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto D = Symbol("d"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto d = graph.CreateAxis("d", D); - - ge::ascir_op::Data data("data", graph); - data.y.SetContiguousView({b, a, c, d}); - - LOOP(a) { - LOOP(b) { - LOOP(c) { - OPTION_LOOP(d, LoopOption{true}) { - ASSERT_FALSE(PadOutputViewToSched(data.y)); - } - } - } - } -} -TEST(PadTensorAxisToSched, DiffAxisNumAndNotMatch1_Failed) { - Graph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto D = Symbol("d"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto d = graph.CreateAxis("d", D); - - ge::ascir_op::Data data("data", graph); - data.y.SetContiguousView({a, b, c, d}); - - LOOP(a) { - LOOP(b) { - OPTION_LOOP(c, LoopOption{true}) { - ASSERT_FALSE(PadOutputViewToSched(data.y)); - } - } - } -} -TEST(PadTensorAxisToSched, DiffAxisNumAndNotMatch2_Failed) { - Graph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto D = Symbol("d"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto d = graph.CreateAxis("d", D); - - ge::ascir_op::Data data("data", graph); - data.y.SetContiguousView({a, c, b}); - - LOOP(a) { - LOOP(b) { - LOOP(c) { - OPTION_LOOP(d, LoopOption{true}) { - ASSERT_FALSE(PadOutputViewToSched(data.y)); - } - } - } - } -} -TEST(AutoPadAxis, Ok) { - Graph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto D = Symbol("d"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto d = graph.CreateAxis("d", D); - - LOOP(a) { - LOOP(b) { - LOOP(c) { - OPTION_LOOP(d, LoopOption{true}) { - auto data0 = ContiguousData("data0", graph, ge::DT_FLOAT16, {a, b, d}); - auto data1 = ContiguousData("data1", graph, ge::DT_FLOAT16, {a, c, d}); - AscendString name; - data1.GetOwnerOp().GetName(name); - EXPECT_EQ("data1", std::string(name.GetString())); - auto load0 = Load("load0", data0); - auto load1 = Load("load1", data0); - auto mm = MatMul("mm", load0, load1); - mm.SetContiguousView({a, b, c}); - PadOutputViewToSched(mm); - auto data2 = ContiguousData("data1", graph, ge::DT_FLOAT, {a, c, d}); - auto y = CalcY("y", data0, data2, data1, mm); - EXPECT_EQ(y.dtype, ge::DT_FLOAT); - } - } - } - } - - auto d0 = graph.FindNode("data0"); - EXPECT_EQ(d0->outputs[0].attr.axis, std::vector({a.id, b.id, c.id, d.id})); - EXPECT_TRUE(d0->outputs[0].attr.repeats[0] == A); - EXPECT_TRUE(d0->outputs[0].attr.repeats[1] == B); - EXPECT_TRUE(d0->outputs[0].attr.repeats[2] == sym::kSymbolOne); - EXPECT_TRUE(d0->outputs[0].attr.repeats[3] == D); - EXPECT_TRUE(d0->outputs[0].attr.strides[0] == (B*D)); - EXPECT_TRUE(d0->outputs[0].attr.strides[1] == D); - EXPECT_TRUE(d0->outputs[0].attr.strides[2] == sym::kSymbolZero); - EXPECT_TRUE(d0->outputs[0].attr.strides[3] == sym::kSymbolOne); - - auto d1 = graph.FindNode("data1"); - - EXPECT_EQ(d0->outputs[0].attr.axis, std::vector({a.id, b.id, c.id, d.id})); - EXPECT_TRUE(d1->outputs[0].attr.repeats[0] == A); - EXPECT_TRUE(d1->outputs[0].attr.repeats[1] == sym::kSymbolOne); - EXPECT_TRUE(d1->outputs[0].attr.repeats[2] == C); - EXPECT_TRUE(d1->outputs[0].attr.repeats[3] == D); - EXPECT_TRUE(d1->outputs[0].attr.strides[0] == (C*D)); - EXPECT_TRUE(d1->outputs[0].attr.strides[1] == sym::kSymbolZero); - EXPECT_TRUE(d1->outputs[0].attr.strides[2] == D); - EXPECT_TRUE(d1->outputs[0].attr.strides[3] == sym::kSymbolOne); - - auto mm = graph.FindNode("mm"); - EXPECT_EQ(mm->outputs[0].attr.axis, std::vector({a.id, b.id, c.id, d.id})); - EXPECT_TRUE(mm->outputs[0].attr.repeats[0] == A); - EXPECT_TRUE(mm->outputs[0].attr.repeats[1] == B); - EXPECT_TRUE(mm->outputs[0].attr.repeats[2] == C); - EXPECT_TRUE(mm->outputs[0].attr.repeats[3] == sym::kSymbolOne); - EXPECT_TRUE(mm->outputs[0].attr.strides[0] == (B*C)); - EXPECT_TRUE(mm->outputs[0].attr.strides[1] == C); - EXPECT_TRUE(mm->outputs[0].attr.strides[2] == sym::kSymbolOne); - EXPECT_TRUE(mm->outputs[0].attr.strides[3] == sym::kSymbolZero); -} - - -TEST(CgApi, VectorizedTensor_move_assign) { - AscGraph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto D = Symbol("D"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto d = graph.CreateAxis("d", D); - - LOOP(a) { - AscOpOutput v({b.id, c.id, d.id}); - LOOP(b) { - LOOP(c) { - auto data0 = ContiguousData("data0", graph, ge::DT_FLOAT16, {a, b, c, d}); - auto load0 = Load("load0", data0); - v.AutoOffset() = Load("load1", data0); - // make sure compile error -// v = Load("load1", data0); // 异常场景 - auto abs0 = Abs("abs0", load0); - auto abs1 = Abs("abs1", static_cast(v)); - (void) abs0; - (void) abs1; - } - } - } - EXPECT_EQ(graph.FindNode("load1")->outputs[0U].attr.vectorized_axis, std::vector({b.id, c.id, d.id})); - // dtype推导api接口还没切换到新的dtype注册机制,暂时不校验dtype - EXPECT_VIEW_EQ(graph.FindNode("load1")->outputs[0U].attr, graph.FindNode("data0")->outputs[0U].attr); - EXPECT_VIEW_EQ(graph.FindNode("abs1")->inputs[0U].attr, graph.FindNode("load1")->outputs[0U].attr); - EXPECT_VIEW_EQ(graph.FindNode("abs1")->outputs[0U].attr, graph.FindNode("load1")->outputs[0U].attr); - EXPECT_VIEW_EQ(graph.FindNode("load0")->outputs[0U].attr, graph.FindNode("data0")->outputs[0U].attr); -} - -TEST(CgApi, ViewInfer_ok) { - AscGraph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto D = Symbol("D"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto d = graph.CreateAxis("d", D); - LOOP(a) { - LOOP(b) { - LOOP(c) { - auto data0 = ContiguousData("data0", graph, ge::DT_FLOAT16, {a, b, c, d}); - auto data1 = ContiguousData("data1", graph, ge::DT_FLOAT16, {a, b, c, d}); - auto data2 = ContiguousData("data2", graph, ge::DT_FLOAT16, {d}); - auto load0 = Load("load0", data0); - auto load1 = Load("load1", data1); - auto load2 = Load("load2", data2); - auto[out0, out1, out2, out3] = CalcMeanStub("CalcMeanStub", load0, load1, load2, d.id); - // out0 is reduced axis_d - EXPECT_EQ(*out0.axis, *load0.axis); - EXPECT_EQ(*out0.repeats, *load0.repeats); - EXPECT_NE(*out0.strides, *load0.strides); - std::vector strides_expect; - strides_expect.emplace_back(B * C * D / D); - strides_expect.emplace_back(C * D / D); - strides_expect.emplace_back(D / D); - strides_expect.emplace_back(sym::kSymbolZero); - EXPECT_EQ(out0.strides->size(), strides_expect.size()); - size_t index = 0U; - for (const auto &se : strides_expect) { - EXPECT_EQ((*out0.strides)[index++], se); - } - EXPECT_VIEW_PTR_EQ(out1, load0); - EXPECT_VIEW_PTR_EQ(out2, load0); - EXPECT_VIEW_PTR_EQ(out3, load0); - } - } - } -} - -TEST(CgApi, VectorizedAxisInfer_ok) { - AscGraph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto D = Symbol("D"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto d = graph.CreateAxis("d", D); - - LOOP(a) { - AscOpOutput v({b.id, c.id, d.id}); - LOOP(b) { - LOOP(c) { - auto data0 = ContiguousData("data0", graph, ge::DT_FLOAT16, {a, b, c, d}); - auto load0 = Load("load0", data0); - v.AutoOffset() = Load("load1", data0); - auto abs0 = Abs("abs0", load0); - auto abs1 = Abs("abs1", static_cast(v)); - (void) abs0; - (void) abs1; - } - } - } - EXPECT_EQ(graph.FindNode("load1")->outputs[0U].attr.vectorized_axis, std::vector({b.id, c.id, d.id})); - EXPECT_EQ(graph.FindNode("abs1")->inputs[0U].attr.vectorized_axis, graph.FindNode("load1")->outputs[0U].attr.vectorized_axis); - EXPECT_EQ(graph.FindNode("abs1")->outputs[0U].attr.vectorized_axis, std::vector({d.id})); - EXPECT_EQ(graph.FindNode("load0")->attr.sched.loop_axis, c.id); - EXPECT_EQ(graph.FindNode("load0")->outputs[0U].attr.axis, std::vector({a.id, b.id, c.id, d.id})); - EXPECT_EQ(graph.FindNode("load0")->outputs[0U].attr.vectorized_axis, std::vector({d.id})); -} - -TEST(SetDataNodeAttr, Ok) { - Graph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto D = Symbol("d"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto d = graph.CreateAxis("d", D); - std::vector> vec; - vec.emplace_back(1); - vec.emplace_back(2); - vec.emplace_back(3); - LOOP(a) { - LOOP(b) { - LOOP(c) { - OPTION_LOOP(d, LoopOption{true}) { - auto data0 = ContiguousData("data0", graph, ge::DT_FLOAT16, {a, b, d}, 0); - auto data1 = ContiguousData("data1", graph, ge::DT_FLOAT16, {a, c, d}, 1); - AscendString name; - data0.GetOwnerOp().GetName(name); - EXPECT_EQ("data0", std::string(name.GetString())); - data1.GetOwnerOp().GetName(name); - EXPECT_EQ("data1", std::string(name.GetString())); - } - } - } - } - - auto d0 = graph.FindNode("data0"); - ge::GeAttrValue attr_value; - int64_t index_value = -1; - auto d1 = graph.FindNode("data1"); - index_value = -1; - (void) d1->GetOpDesc()->GetAttr("index", attr_value); - attr_value.GetValue(index_value); - EXPECT_TRUE(index_value == -1); -} - -TEST(TBufTQue, CreatTQueFailed){ - Graph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto D = Symbol("d"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto d = graph.CreateAxis("d", D); - auto data0 = ContiguousData("data0", graph, ge::DT_FLOAT16, {a, b, d}); - EXPECT_EQ(Load("load0", data0).TQue(Position::kPositionVecIn, -1, 1).que->id, kIdNone); - EXPECT_EQ(Load("load0", data0).TQue(Position::kPositionVecIn, 0, 1).que->id, kIdNone); - EXPECT_EQ(Load("load0", data0).TQue(Position::kPositionVecIn, 1, -1).que->id, kIdNone); - EXPECT_EQ(Load("load0", data0).TQue(Position::kPositionVecIn, 1, 0).que->id, kIdNone); -} - -TEST(TBufTQue, CreateOk) { - Graph graph = ConstructTestGraph("test_graph1"); - EXPECT_EQ(graph.FindNode("data0")->outputs[0].attr.que.id, kIdNone); - EXPECT_EQ(graph.FindNode("data0")->outputs[0].attr.buf.id, kIdNone); - EXPECT_EQ(graph.FindNode("data0")->outputs[0].attr.mem.alloc_type, AllocType::kAllocTypeGlobal); - EXPECT_EQ(graph.FindNode("data0")->outputs[0].attr.mem.position, Position::kPositionGM); - - EXPECT_EQ(graph.FindNode("data1")->outputs[0].attr.que.id, kIdNone); - EXPECT_EQ(graph.FindNode("data1")->outputs[0].attr.buf.id, kIdNone); - EXPECT_EQ(graph.FindNode("data1")->outputs[0].attr.mem.alloc_type, AllocType::kAllocTypeGlobal); - EXPECT_EQ(graph.FindNode("data1")->outputs[0].attr.mem.position, Position::kPositionGM); - - EXPECT_EQ(graph.FindNode("data2")->outputs[0].attr.que.id, kIdNone); - EXPECT_NE(graph.FindNode("data2")->outputs[0].attr.buf.id, kIdNone); - EXPECT_EQ(graph.FindNode("data2")->outputs[0].attr.mem.alloc_type, AllocType::kAllocTypeBuffer); - EXPECT_EQ(graph.FindNode("data2")->outputs[0].attr.mem.position, Position::kPositionVecOut); - - EXPECT_NE(graph.FindNode("load0")->outputs[0].attr.que.id, kIdNone); - EXPECT_EQ(graph.FindNode("load0")->outputs[0].attr.buf.id, kIdNone); - EXPECT_EQ(graph.FindNode("load0")->outputs[0].attr.que.depth, 1); - EXPECT_EQ(graph.FindNode("load0")->outputs[0].attr.que.buf_num, 2); - EXPECT_EQ(graph.FindNode("load0")->outputs[0].attr.mem.alloc_type, AllocType::kAllocTypeQueue); - EXPECT_EQ(graph.FindNode("load0")->outputs[0].attr.mem.position, Position::kPositionVecIn); - - EXPECT_NE(graph.FindNode("load1")->outputs[0].attr.que.id, kIdNone); - EXPECT_EQ(graph.FindNode("load1")->outputs[0].attr.buf.id, kIdNone); - EXPECT_EQ(graph.FindNode("load1")->outputs[0].attr.que.depth, 1); - EXPECT_EQ(graph.FindNode("load1")->outputs[0].attr.que.buf_num, 2); - EXPECT_EQ(graph.FindNode("load1")->outputs[0].attr.mem.alloc_type, AllocType::kAllocTypeQueue); - EXPECT_EQ(graph.FindNode("load1")->outputs[0].attr.mem.position, Position::kPositionVecIn); - - EXPECT_NE(graph.FindNode("mm")->outputs[0].attr.que.id, kIdNone); - EXPECT_EQ(graph.FindNode("mm")->outputs[0].attr.buf.id, kIdNone); - EXPECT_EQ(graph.FindNode("mm")->outputs[0].attr.que.depth, 1); - EXPECT_EQ(graph.FindNode("mm")->outputs[0].attr.que.buf_num, 1); - EXPECT_EQ(graph.FindNode("mm")->outputs[0].attr.mem.alloc_type, AllocType::kAllocTypeQueue); - EXPECT_EQ(graph.FindNode("mm")->outputs[0].attr.mem.position, Position::kPositionVecOut); - Graph graph2 = ConstructTestGraph("test_graph2"); - EXPECT_EQ(graph.FindNode("data0")->outputs[0].attr.mem.tensor_id, 0); - EXPECT_EQ(graph.FindNode("data0")->outputs[0].attr.mem.tensor_id, graph2.FindNode("data0")->outputs[0].attr.mem.tensor_id); - EXPECT_EQ(graph.FindNode("data1")->outputs[0].attr.mem.tensor_id, graph2.FindNode("data1")->outputs[0].attr.mem.tensor_id); - EXPECT_EQ(graph.FindNode("load0")->outputs[0].attr.mem.tensor_id, graph2.FindNode("load0")->outputs[0].attr.mem.tensor_id); - EXPECT_EQ(graph.FindNode("load1")->outputs[0].attr.mem.tensor_id, graph2.FindNode("load1")->outputs[0].attr.mem.tensor_id); - EXPECT_EQ(graph.FindNode("mm")->outputs[0].attr.mem.tensor_id, graph2.FindNode("mm")->outputs[0].attr.mem.tensor_id); - EXPECT_EQ(graph.FindNode("data2")->outputs[0].attr.mem.tensor_id, graph2.FindNode("data2")->outputs[0].attr.mem.tensor_id); -} - -TEST(TBufTQue, RepeatBindingFailed){ - Graph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto D = Symbol("d"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto d = graph.CreateAxis("d", D); - auto data0 = ContiguousData("data0", graph, ge::DT_FLOAT16, {a, b, d}); - auto test1 = Load("load0", data0).TQue(Position::kPositionVecIn, 1, 2); - EXPECT_EQ(test1.mem->position, Position::kPositionVecIn); - EXPECT_EQ(test1.TBuf(Position::kPositionVecOut).mem->position, Position::kPositionVecIn); - - auto test2 = Load("load0", data0).TBuf(Position::kPositionVecIn); - EXPECT_EQ(test2.TQue(Position::kPositionVecIn, 1, 1).mem->position, Position::kPositionVecIn); - auto test3 = Load("load0", data0).TBuf(Position::kPositionVecIn); - EXPECT_EQ(test3.TBuf(Position::kPositionVecIn).mem->position, Position::kPositionVecIn); - auto test4 = Load("load0", data0).TQue(Position::kPositionVecIn, 1, 2); - EXPECT_EQ(test4.TQue(Position::kPositionVecIn, 1, 2).mem->position, Position::kPositionVecIn); -} - -TEST(ScopeUse, Ok) { - Graph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto D = Symbol("d"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto d = graph.CreateAxis("d", D); - - LOOP(a) { - LOOP(b) { - LOOP(c) { - OPTION_LOOP(d, LoopOption{true}) { - // data0(GM1)-------------------------------| - // |->load0(TQue1)-->mm(TQue3)--------->y(ScopeUse(data2)) - // data1(GM2)--->load1(TQue2)-| | - // |_______________________________| - // data2(TBuf1)________________________________| - auto data0 = ContiguousData("data0", graph, ge::DT_FLOAT16, {a, b, d}); - auto data1 = ContiguousData("data1", graph, ge::DT_FLOAT16, {a, c, d}); - AscendString name; - data1.GetOwnerOp().GetName(name); - EXPECT_EQ("data1", std::string(name.GetString())); - auto load0 = Load("load0", data0).TQue(Position::kPositionVecIn, 1, 2); - auto load1 = Load("load1", data1).TQue(Position::kPositionVecIn, 1, 2); - auto mm = MatMul("mm", load0, load1).TQue(Position::kPositionVecOut, 1, 1); - auto data2 = ContiguousData("data2", graph, ge::DT_FLOAT, {a, c, d}).TBuf(Position::kPositionVecOut); - auto [rstd0, rstd1] = CalcRstd("rstd", data2, data1, mm); - EXPECT_EQ(rstd0.dtype, ge::DT_FLOAT); - EXPECT_EQ(rstd1.dtype, ge::DT_FLOAT); - rstd0.Use(load1); - rstd1.Use(mm); - } - } - } - } - EXPECT_EQ(graph.FindNode("data0")->outputs[0].attr.que.id, kIdNone); - EXPECT_EQ(graph.FindNode("data0")->outputs[0].attr.buf.id, kIdNone); - EXPECT_EQ(graph.FindNode("data0")->outputs[0].attr.mem.alloc_type, AllocType::kAllocTypeGlobal); - EXPECT_EQ(graph.FindNode("data0")->outputs[0].attr.mem.position, Position::kPositionGM); - - EXPECT_EQ(graph.FindNode("data1")->outputs[0].attr.que.id, kIdNone); - EXPECT_EQ(graph.FindNode("data1")->outputs[0].attr.buf.id, kIdNone); - EXPECT_EQ(graph.FindNode("data1")->outputs[0].attr.mem.alloc_type, AllocType::kAllocTypeGlobal); - EXPECT_EQ(graph.FindNode("data1")->outputs[0].attr.mem.position, Position::kPositionGM); - - EXPECT_EQ(graph.FindNode("data2")->outputs[0].attr.que.id, kIdNone); - EXPECT_NE(graph.FindNode("data2")->outputs[0].attr.buf.id, kIdNone); - EXPECT_EQ(graph.FindNode("data2")->outputs[0].attr.mem.alloc_type, AllocType::kAllocTypeBuffer); - EXPECT_EQ(graph.FindNode("data2")->outputs[0].attr.mem.position, Position::kPositionVecOut); - - EXPECT_EQ(graph.FindNode("load0")->outputs[0].attr.que.id, 0); - EXPECT_NE(graph.FindNode("load0")->outputs[0].attr.que.id, kIdNone); - EXPECT_EQ(graph.FindNode("load0")->outputs[0].attr.buf.id, kIdNone); - EXPECT_EQ(graph.FindNode("load0")->outputs[0].attr.que.depth, 1); - EXPECT_EQ(graph.FindNode("load0")->outputs[0].attr.que.buf_num, 2); - EXPECT_EQ(graph.FindNode("load0")->outputs[0].attr.mem.alloc_type, AllocType::kAllocTypeQueue); - EXPECT_EQ(graph.FindNode("load0")->outputs[0].attr.mem.position, Position::kPositionVecIn); - - EXPECT_EQ(graph.FindNode("load1")->outputs[0].attr.que.id, 1); - EXPECT_NE(graph.FindNode("load1")->outputs[0].attr.que.id, kIdNone); - EXPECT_EQ(graph.FindNode("load1")->outputs[0].attr.buf.id, kIdNone); - EXPECT_EQ(graph.FindNode("load1")->outputs[0].attr.que.depth, 1); - EXPECT_EQ(graph.FindNode("load1")->outputs[0].attr.que.buf_num, 2); - EXPECT_EQ(graph.FindNode("load1")->outputs[0].attr.mem.alloc_type, AllocType::kAllocTypeQueue); - EXPECT_EQ(graph.FindNode("load1")->outputs[0].attr.mem.position, Position::kPositionVecIn); - - EXPECT_NE(graph.FindNode("mm")->outputs[0].attr.que.id, kIdNone); - EXPECT_EQ(graph.FindNode("mm")->outputs[0].attr.buf.id, kIdNone); - EXPECT_EQ(graph.FindNode("mm")->outputs[0].attr.que.depth, 1); - EXPECT_EQ(graph.FindNode("mm")->outputs[0].attr.que.buf_num, 1); - EXPECT_EQ(graph.FindNode("mm")->outputs[0].attr.mem.alloc_type, AllocType::kAllocTypeQueue); - EXPECT_EQ(graph.FindNode("mm")->outputs[0].attr.mem.position, Position::kPositionVecOut); - - EXPECT_EQ(graph.FindNode("rstd")->outputs[0].attr.que.id, graph.FindNode("load1")->outputs[0].attr.que.id); - EXPECT_NE(graph.FindNode("rstd")->outputs[0].attr.buf.id, graph.FindNode("load1")->outputs[0].attr.que.id); - EXPECT_EQ(graph.FindNode("rstd")->outputs[0].attr.mem.alloc_type, graph.FindNode("load1")->outputs[0].attr.mem.alloc_type); - EXPECT_EQ(graph.FindNode("rstd")->outputs[0].attr.mem.position, graph.FindNode("load1")->outputs[0].attr.mem.position); - - EXPECT_EQ(graph.FindNode("mm")->outputs[0].attr.que.id, graph.FindNode("rstd")->outputs[1].attr.que.id); - EXPECT_EQ(graph.FindNode("mm")->outputs[0].attr.buf.id, graph.FindNode("rstd")->outputs[1].attr.buf.id); - EXPECT_EQ(graph.FindNode("mm")->outputs[0].attr.que.depth, graph.FindNode("rstd")->outputs[1].attr.que.depth); - EXPECT_EQ(graph.FindNode("mm")->outputs[0].attr.que.buf_num, graph.FindNode("rstd")->outputs[1].attr.que.buf_num); - EXPECT_EQ(graph.FindNode("mm")->outputs[0].attr.mem.alloc_type, graph.FindNode("rstd")->outputs[1].attr.mem.alloc_type); - EXPECT_EQ(graph.FindNode("mm")->outputs[0].attr.mem.position, graph.FindNode("rstd")->outputs[1].attr.mem.position); - EXPECT_EQ(graph.FindNode("mm")->outputs[0].attr.opt.merge_scope, graph.FindNode("rstd")->outputs[1].attr.opt.merge_scope); -} - -TEST(ScopeUse, AlreadyBindFailed) { - Graph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto D = Symbol("d"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto d = graph.CreateAxis("d", D); - auto data0 = ContiguousData("data0", graph, ge::DT_FLOAT16, {a, b, d}); - auto data1 = ContiguousData("data1", graph, ge::DT_FLOAT16, {a, c, d}); - auto load0 = Load("load0", data0).TQue(Position::kPositionVecIn, 1, 2); - auto load0_id = load0.mem->reuse_id; - auto load1 = Load("load1", data1).TQue(Position::kPositionVecIn, 1, 2); - auto load1_id = load1.mem->reuse_id; - EXPECT_NE(load0_id, load1_id); - EXPECT_EQ(load1.Use(load0).mem->reuse_id, load1_id); - EXPECT_NE(load0_id, load1_id); -} - -TEST(ScopeUse, ReuseIdSame) { - Graph graph("test_graph"); - auto ND = ge::Symbol("ND"); - auto nd = graph.CreateAxis("nd", ND); - auto [ndB, ndb] = graph.BlockSplit(nd.id); - auto [ndbT, ndbt] = graph.TileSplit(ndb->id); - auto data1 = graph.CreateContiguousData("input1", DT_FLOAT, {nd}); - auto data2 = graph.CreateContiguousData("input2", DT_FLOAT, {nd}); - auto data3 = graph.CreateContiguousData("input3", DT_FLOAT, {nd}); - LOOP(*ndB) { - LOOP(*ndbT) { - auto load1 = Load("load1", data1).TQue(Position::kPositionVecIn, 1, 2); - auto load2 = Load("load2", data2).TQue(Position::kPositionVecIn, 1, 2); - auto load3 = Load("load3", data3).TQue(Position::kPositionVecIn, 1, 2); - auto relu = Cast("FakeRelu", load1).TBuf(Position::kPositionVecOut); - auto mul1 = Mul("Mul1", relu, load2).Use(relu); - auto sig_mod = Abs("FakeSigmod", load3).Use(mul1); - auto mul2 = Mul("Mul2", mul1, sig_mod).TQue(Position::kPositionVecOut, 1, 2); - auto store1 = Store("store1", relu); - auto output1 = Output("output1", store1); - EXPECT_EQ(relu.mem->reuse_id, sig_mod.mem->reuse_id); - EXPECT_EQ(relu.mem->reuse_id, mul1.mem->reuse_id); - } - } -} - -TEST(ScopeUse, UsedNotBindFailed) { - Graph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto D = Symbol("d"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto d = graph.CreateAxis("d", D); - auto data0 = ContiguousData("data0", graph, ge::DT_FLOAT16, {a, b, d}); - auto data1 = ContiguousData("data1", graph, ge::DT_FLOAT16, {a, c, d}); - auto load0 = Load("load0", data0); - auto load1 = Load("load1", data1); - EXPECT_EQ(load1.Use(load0).mem->reuse_id, kIdNone); - AscOpOutput asc_op_output; - EXPECT_EQ(asc_op_output.Use(load0).output_index, UINT32_MAX); -} - -TEST(CodeGenUtils, GenNextExecIdOk) { - ge::AscGraph graph("test"); - EXPECT_EQ(CodeGenUtils::GenNextExecId(graph), 0L); -} - -TEST(CodeGenUtils, PopBackLoopAxisFailed) { - Graph graph("test_graph"); - auto A = Symbol("A"); - auto B = Symbol("B"); - auto C = Symbol("C"); - auto a = graph.CreateAxis("a", A); - auto b = graph.CreateAxis("b", B); - auto c = graph.CreateAxis("c", C); - auto ctx = CgContext::GetSharedThreadLocalContext(); - ASSERT_EQ(ctx, nullptr); - auto ctx_obj = std::make_shared(); - CgContext::SetThreadLocalContext(ctx_obj); - ctx = CgContext::GetSharedThreadLocalContext(); - ASSERT_NE(ctx, nullptr); - // pop empty - ctx->PopBackLoopAxis(a); - ctx->PushLoopAxis(a); - // pop order unmatch - ctx->PopBackLoopAxis(b); -} - -} // namespace cg -} // namespace ascir -} diff --git a/tests/ut/ascendc_ir/testcase/mem_utils_unittest.cc b/tests/ut/ascendc_ir/testcase/mem_utils_unittest.cc deleted file mode 100644 index 98fb473c1b161e4bf01ff3d1f6614f19f1c35a84..0000000000000000000000000000000000000000 --- a/tests/ut/ascendc_ir/testcase/mem_utils_unittest.cc +++ /dev/null @@ -1,134 +0,0 @@ -/** - * Copyright (c) Huawei Technologies Co., Ltd. 2024 All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include -#include -#include "ascendc_ir/ascendc_ir_core/ascendc_ir.h" -#include "graph/utils/mem_utils.h" - -namespace ge { -class UtestMemUtils : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; -TEST(UtestMemUtils, CreateTQueConfigSuccess) { - auto tque = MemUtils::CreateTQueConfig(Position::kPositionVecIn, 1, 2); - EXPECT_EQ(tque.pos_, Position::kPositionVecIn); - EXPECT_EQ(tque.queue_attr_.buf_num, 2); - EXPECT_EQ(tque.queue_attr_.depth, 1); -} - -TEST(UtestMemUtils, CreateTQueConfigFailed) { - EXPECT_EQ(MemUtils::CreateTQueConfig(Position::kPositionGM, 1, 2).pos_, Position::kPositionInvalid); -} - -TEST(UtestMemUtils, CreateTBufConfigIncSuccess) { - auto tbuf = MemUtils::CreateTBufConfig(Position::kPositionVecIn); - EXPECT_EQ(tbuf.pos_, Position::kPositionVecIn); - auto old = tbuf.buf_attr_.id; - auto tbuf2 = MemUtils::CreateTBufConfig(Position::kPositionVecOut); - EXPECT_EQ(tbuf2.buf_attr_.id - old, 1); -} - -TEST(UtestMemUtils, CreateTQueConfigIncSuccess) { - auto tque1 = MemUtils::CreateTQueConfig(Position::kPositionVecIn, 1, 2); - auto old = tque1.queue_attr_.id; - auto tque2 = MemUtils::CreateTQueConfig(Position::kPositionVecIn, 2, 4); - EXPECT_EQ(tque1.queue_attr_.depth, 1); - EXPECT_EQ(tque1.queue_attr_.buf_num, 2); - EXPECT_EQ(tque2.queue_attr_.depth, 2); - EXPECT_EQ(tque2.queue_attr_.buf_num, 4); - EXPECT_EQ(tque2.queue_attr_.id - old, 1); -} - -TEST(UtestMemUtils, CreateTQueConfigBindTensorsSuccess) { - auto tque = MemUtils::CreateTQueConfig(Position::kPositionVecIn, 10, 20); - auto tque1 = MemUtils::CreateTQueConfig(Position::kPositionVecIn, 10, 20); - auto old = tque1.queue_attr_.id; - EXPECT_EQ(tque1.queue_attr_.depth, 10); - EXPECT_EQ(tque1.queue_attr_.buf_num, 20); - AscTensorAttr output1; - AscTensorAttr output2; - AscTensorAttr output3; - AscTensorAttr output4; - AscTensorAttr output5; - AscTensorAttr output6; - tque1.BindTensors(output1, output2); - tque1.BindTensors(output3, output4, output5); - tque1.BindTensors(output6); - EXPECT_EQ(output1.que.id, old); - EXPECT_EQ(output1.buf.id, kIdNone); - EXPECT_EQ(output1.que.depth, 10); - EXPECT_EQ(output1.que.buf_num, 20); - EXPECT_EQ(output2.que.id, old); - EXPECT_EQ(output2.buf.id, kIdNone); - EXPECT_EQ(output3.que.id, old); - EXPECT_EQ(output3.buf.id, kIdNone); - EXPECT_EQ(output4.que.id, old); - EXPECT_EQ(output4.buf.id, kIdNone); - EXPECT_EQ(output5.que.id, old); - EXPECT_EQ(output5.buf.id, kIdNone); - EXPECT_EQ(output6.que.id, old); - EXPECT_EQ(output6.buf.id, kIdNone); - EXPECT_EQ(output6.que.depth, 10); - EXPECT_EQ(output6.que.buf_num, 20); -} - -TEST(UtestMemUtils, CreateTBufConfigBindTensorsSuccess) { - auto tbuf = MemUtils::CreateTBufConfig(Position::kPositionVecIn); - auto tbuf1 = MemUtils::CreateTBufConfig(Position::kPositionVecIn); - auto old = tbuf1.buf_attr_.id; - EXPECT_EQ(tbuf1.pos_, Position::kPositionVecIn); - AscTensorAttr output1; - AscTensorAttr output2; - AscTensorAttr output3; - AscTensorAttr output4; - AscTensorAttr output5; - AscTensorAttr output6; - tbuf1.BindTensors(output1, output2); - tbuf1.BindTensors(output3, output4, output5); - tbuf1.BindTensors(output6); - EXPECT_EQ(output1.buf.id, old); - EXPECT_EQ(output2.buf.id, old); - EXPECT_EQ(output3.buf.id, old); - EXPECT_EQ(output4.buf.id, old); - EXPECT_EQ(output5.buf.id, old); - EXPECT_EQ(output6.buf.id, old); -} - -TEST(UtestMemUtils, MergeScopeSuccess) { - auto tbuf = MemUtils::CreateTBufConfig(Position::kPositionVecIn); - auto tbuf1 = MemUtils::CreateTBufConfig(Position::kPositionVecIn); - auto tbuf2 = MemUtils::CreateTBufConfig(Position::kPositionVecIn); - auto tbuf3 = MemUtils::CreateTBufConfig(Position::kPositionVecIn); - AscTensorAttr output1; - output1.opt.merge_scope = 2; - AscTensorAttr output2; - output2.opt.merge_scope = 3; - AscTensorAttr output3; - output3.opt.merge_scope = 4; - AscTensorAttr output4; - output4.opt.merge_scope = 5; - AscTensorAttr output5; - output5.opt.merge_scope = 6; - MemUtils::MergeScope(output1, output2, output3, output4, output5); - EXPECT_EQ(output1.opt.merge_scope, output2.opt.merge_scope); - EXPECT_EQ(output1.opt.merge_scope, output3.opt.merge_scope); - EXPECT_EQ(output1.opt.merge_scope, output4.opt.merge_scope); - EXPECT_EQ(output1.opt.merge_scope, output5.opt.merge_scope); -} -} diff --git a/tests/ut/base/CMakeLists.txt b/tests/ut/base/CMakeLists.txt index f1b21a2c31ec2269ebd860798dadd9861cc6050f..28a175c721ad9448f3946a10ff91d0ee6fd7a4e7 100644 --- a/tests/ut/base/CMakeLists.txt +++ b/tests/ut/base/CMakeLists.txt @@ -23,12 +23,17 @@ file(GLOB_RECURSE UT_FILES CONFIGURE_DEPENDS "${METADEF_DIR}/tests/ut/base/testcase/*.cc" ) -add_executable(ut_metadef ${UT_FILES}) +file(GLOB_RECURSE FAKER_SRCS CONFIGURE_DEPENDS "${METADEF_DIR}/tests/depends/faker/kernel_run_context_faker*.cc" + "${METADEF_DIR}/tests/depends/faker/allocator_*.cc" + ) + +add_executable(ut_metadef ${UT_FILES} ${FAKER_SRCS}) target_compile_options(ut_metadef PRIVATE -g --coverage -fprofile-arcs -ftest-coverage -Wno-deprecated-declarations -Wall -Wfloat-equal -Werror + -fno-access-control ) target_compile_definitions(ut_metadef PRIVATE @@ -40,9 +45,8 @@ target_compile_definitions(ut_metadef PRIVATE target_link_libraries(ut_metadef PRIVATE intf_pub -lgcov -Wl,--no-as-needed - platform_stub slog_headers - metadef_headers metadef error_manager + metadef_headers metadef opp_registry error_manager GTest::gtest GTest::gtest_main ascend_protobuf slog_stub c_sec json mmpa_stub -lrt -ldl ) diff --git a/tests/ut/graph/testcase/aligned_ptr_unittest.cc b/tests/ut/base/testcase/aligned_ptr_unittest.cc similarity index 100% rename from tests/ut/graph/testcase/aligned_ptr_unittest.cc rename to tests/ut/base/testcase/aligned_ptr_unittest.cc diff --git a/tests/ut/graph/testcase/any_value_ut.cc b/tests/ut/base/testcase/any_value_ut.cc similarity index 91% rename from tests/ut/graph/testcase/any_value_ut.cc rename to tests/ut/base/testcase/any_value_ut.cc index b78da83f47a65f8b73d57ed2edbcdbd281e7a956..0421c555055d07fa924ad226e2dd6f01bac04421 100644 --- a/tests/ut/graph/testcase/any_value_ut.cc +++ b/tests/ut/base/testcase/any_value_ut.cc @@ -15,7 +15,7 @@ #include "graph/ge_attr_value.h" #include "graph/compute_graph.h" #include "graph/ge_tensor.h" -#include "graph/buffer.h" + namespace ge { namespace { struct InlineFuncCounter : public FuncCounter { @@ -530,28 +530,6 @@ TEST_F(AnyValueUt, GetTypeOk) { EXPECT_EQ(av.GetValueTypeId(), GetTypeId()); EXPECT_EQ(av.GetValueType(), AnyValue::VT_INT); - av.SetValue(GeTensorDesc()); - EXPECT_EQ(av.GetValueTypeId(), GetTypeId()); - EXPECT_EQ(av.GetValueType(), AnyValue::VT_TENSOR_DESC); - - av.SetValue(GeTensor()); - EXPECT_EQ(av.GetValueTypeId(), GetTypeId()); - EXPECT_EQ(av.GetValueType(), AnyValue::VT_TENSOR); - - av.SetValue(Buffer()); - EXPECT_EQ(av.GetValueTypeId(), GetTypeId()); - EXPECT_EQ(av.GetValueType(), AnyValue::VT_BYTES); - -// auto graph = proto::GraphDef(nullptr); -// av.SetValue(graph); -// EXPECT_EQ(av.GetValueTypeId(), GetTypeId()); -// EXPECT_EQ(av.GetValueType(), AnyValue::VT_GRAPH); - - av.SetValue(NamedAttrs()); - EXPECT_EQ(av.GetValueTypeId(), GetTypeId()); - EXPECT_EQ(av.GetValueType(), AnyValue::VT_NAMED_ATTRS); - - av.SetValue(std::vector>({{1,2,3}, {1,2,3}})); EXPECT_EQ(av.GetValueTypeId(), GetTypeId>>()); EXPECT_EQ(av.GetValueType(), AnyValue::VT_LIST_LIST_INT); @@ -580,26 +558,6 @@ TEST_F(AnyValueUt, GetTypeOk) { EXPECT_EQ(av.GetValueTypeId(), GetTypeId>()); EXPECT_EQ(av.GetValueType(), AnyValue::VT_LIST_INT); - av.SetValue(std::vector()); - EXPECT_EQ(av.GetValueTypeId(), GetTypeId>()); - EXPECT_EQ(av.GetValueType(), AnyValue::VT_LIST_TENSOR_DESC); - - av.SetValue(std::vector()); - EXPECT_EQ(av.GetValueTypeId(), GetTypeId>()); - EXPECT_EQ(av.GetValueType(), AnyValue::VT_LIST_TENSOR); - - av.SetValue(std::vector()); - EXPECT_EQ(av.GetValueTypeId(), GetTypeId>()); - EXPECT_EQ(av.GetValueType(), AnyValue::VT_LIST_BYTES); - -// av.SetValue(std::vector()); -// EXPECT_EQ(av.GetValueTypeId(), GetTypeId>()); -// EXPECT_EQ(av.GetValueType(), AnyValue::VT_LIST_GRAPH); - - av.SetValue(std::vector()); - EXPECT_EQ(av.GetValueTypeId(), GetTypeId>()); - EXPECT_EQ(av.GetValueType(), AnyValue::VT_LIST_NAMED_ATTRS); - av.SetValue(std::vector()); EXPECT_EQ(av.GetValueTypeId(), GetTypeId>()); EXPECT_EQ(av.GetValueType(), AnyValue::VT_LIST_DATA_TYPE); diff --git a/tests/ut/graph/testcase/attr_store_ut.cc b/tests/ut/base/testcase/attr_store_ut.cc similarity index 74% rename from tests/ut/graph/testcase/attr_store_ut.cc rename to tests/ut/base/testcase/attr_store_ut.cc index 0ff4b8683a0def835035ae7b26a3005d804a42d3..d7b996102a79b5b710e697bc9bd2a217bd196b0c 100644 --- a/tests/ut/graph/testcase/attr_store_ut.cc +++ b/tests/ut/base/testcase/attr_store_ut.cc @@ -11,11 +11,6 @@ #include "graph/attr_store.h" #include "graph/op_desc.h" #include "ge_common/debug/ge_log.h" -#include "graph/attribute_group/attr_group_serialize.h" -#include "graph/attribute_group/attr_group_base.h" -#include "graph/attribute_group/attr_group_serializer_registry.h" -#include "graph/compute_graph.h" -#include "graph/attribute_group/attr_group_shape_env.h" #include "graph/debug/ge_util.h" #include "common/checker.h" #include "test_structs.h" @@ -35,8 +30,12 @@ struct TestAscendCIROpAttrGroups : public AttrGroupsBase { std::string name; std::string type; - graphStatus Serialize(proto::AttrGroupDef &attr_group_def) override; - graphStatus Deserialize(const proto::AttrGroupDef &attr_group_def, AttrHolder *attr_holder) override; + graphStatus Serialize(proto::AttrGroupDef &attr_group_def) override { + return ge::SUCCESS; + } + graphStatus Deserialize(const proto::AttrGroupDef &attr_group_def, AttrHolder *attr_holder) override { + return ge::SUCCESS; + }; std::unique_ptr Clone() override; bool operator==(const TestAscendCIROpAttrGroups &other) const { @@ -48,30 +47,6 @@ std::unique_ptr TestAscendCIROpAttrGroups::Clone() { return std::unique_ptr(new (std::nothrow) TestAscendCIROpAttrGroups(*this)); } -graphStatus TestAscendCIROpAttrGroups::Serialize(proto::AttrGroupDef &attr_group_def) { - auto op_attr_groups = attr_group_def.mutable_op_attr_group(); - if (op_attr_groups == nullptr) { - return GRAPH_FAILED; - } - - op_attr_groups->set_name(name); - op_attr_groups->set_type(type); - return GRAPH_SUCCESS; -} - -graphStatus TestAscendCIROpAttrGroups::Deserialize(const proto::AttrGroupDef &attr_group_def, AttrHolder *attr_holder) { - (void) attr_holder; - auto &op_attr_groups = attr_group_def.op_attr_group(); - - name = op_attr_groups.name(); - type = op_attr_groups.type(); - return GRAPH_SUCCESS; -} - -REG_ATTR_GROUP_SERIALIZER(TestAscendCIROpAttrGroups, - TestAscendCIROpAttrGroups, - GetTypeId(), - proto::AttrGroupDef::kOpAttrGroup); struct TestAscendCIROpAttrGroupsFailed : public AttrGroupsBase { std::string name; @@ -584,141 +559,4 @@ TEST_F(AttrStoreUt, ErrorTest2) { ASSERT_EQ(ret, GRAPH_FAILED); } -TEST_F(AttrStoreUt, AttrGroupSerializeAndDeSeralize) { - auto s = AttrStore::Create(1); - std::string attr_name = "Max memory"; - auto flag = s.CheckAttrIsExistInOtherGroup(attr_name); - ASSERT_EQ(flag, false); - AnyValue any_value; - any_value.SetValue(1); - - auto ret = s.SetAttrToOtherGroup(attr_name, any_value); - ASSERT_EQ(ret, GRAPH_SUCCESS); - - auto m = s.GetAllAttrsFromOtherGroup(); - ASSERT_EQ(m.size(), 1); - - auto ptr = s.GetOrCreateAttrsGroup(); - ASSERT_NE(ptr, nullptr); - ptr->name = "test attr group"; - ptr->type = "test type"; - - proto::AttrGroups attr_group; - ret = AttrGroupSerialize::SerializeAllAttr(attr_group, s); - ASSERT_EQ(ret, GRAPH_SUCCESS); - - auto op_desc = std::make_shared("Stub", "Stub"); - ASSERT_NE(op_desc, nullptr); - ret = AttrGroupSerialize::DeserializeAllAttr(attr_group, op_desc.get()); - ASSERT_EQ(ret, GRAPH_SUCCESS); - - auto ptr_new = op_desc->GetAttrsGroup(); - ASSERT_NE(ptr_new, nullptr); - ASSERT_EQ(*ptr_new, *ptr); -} - -TEST_F(AttrStoreUt, AttrGroupSerializer_invalid) { - EXPECT_NE(AttrGroupSerializerRegistry::GetInstance().GetSerializer(GetTypeId()), nullptr); - // invalid builder - AttrGroupSerializerRegistry::GetInstance().RegisterAttrGroupSerialize([]() -> std::unique_ptr { return nullptr; }, - GetTypeId(), - proto::AttrGroupDef::kOpAttrGroup); - // repeat reg - AttrGroupSerializerRegistry::GetInstance().RegisterAttrGroupSerialize([]() -> std::unique_ptr { - return std::unique_ptr(new(std::nothrow)TestAscendCIROpAttrGroups()); - }, GetTypeId(), - proto::AttrGroupDef::kOpAttrGroup); - // builder is null - AttrGroupSerializerRegister attr_group_serializer_registrar - (nullptr, GetTypeId(), proto::AttrGroupDef::kOpAttrGroup); -} - -TEST_F(AttrStoreUt, GetOrCreateAttrGroupWith0Args) { - auto s = AttrStore::Create(1); - auto ptr = s.CreateAttrsGroup(); - ASSERT_NE(ptr, nullptr); - ASSERT_EQ(ptr->a, 0); - ASSERT_EQ(ptr->b, 0); - - auto ptr_1 = s.CreateAttrsGroup(); - ASSERT_EQ(ptr_1, nullptr); - - auto ptr_2 = s.CreateAttrsGroup(1); - ASSERT_EQ(ptr_2, nullptr); - - auto ptr_3 = s.CreateAttrsGroup(1, 2); - ASSERT_EQ(ptr_3, nullptr); - - auto ptr_4 = s.GetAttrsGroup(); - ASSERT_EQ(ptr_4, ptr); - - auto ptr_5 = s.GetOrCreateAttrsGroup(); - ASSERT_EQ(ptr_5, ptr); -} - -TEST_F(AttrStoreUt, GetOrCreateAttrGroupWith1Args) { - auto s = AttrStore::Create(1); - auto ptr = s.CreateAttrsGroup(1); - ASSERT_NE(ptr, nullptr); - ASSERT_EQ(ptr->a, 1); - ASSERT_EQ(ptr->b, 0); - - auto ptr_1 = s.CreateAttrsGroup(); - ASSERT_EQ(ptr_1, nullptr); - - auto ptr_2 = s.CreateAttrsGroup(1); - ASSERT_EQ(ptr_2, nullptr); - - auto ptr_3 = s.CreateAttrsGroup(1, 2); - ASSERT_EQ(ptr_3, nullptr); - - auto ptr_4 = s.GetAttrsGroup(); - ASSERT_EQ(ptr_4, ptr); - - auto ptr_5 = s.GetOrCreateAttrsGroup(); - ASSERT_EQ(ptr_5, ptr); -} - -TEST_F(AttrStoreUt, DeleteAttrsGroup) { - auto s = AttrStore::Create(1); - ASSERT_FALSE(s.DeleteAttrsGroup()); - s.CreateAttrsGroup(); - ASSERT_TRUE(s.DeleteAttrsGroup()); - ASSERT_FALSE(s.DeleteAttrsGroup()); - ASSERT_EQ(s.GetAttrsGroup(), nullptr); - ASSERT_FALSE(s.CheckAttrGroupIsExist()); - s.CreateAttrsGroup(); - ASSERT_TRUE(s.DeleteAttrsGroup()); - - ge::ComputeGraph cg("simple"); - auto graph_attr_group_ptr = cg.GetOrCreateAttrsGroup(); - ASSERT_TRUE(cg.DeleteAttrsGroup()); - ASSERT_EQ(cg.GetAttrsGroup(), nullptr); - ASSERT_FALSE(cg.DeleteAttrsGroup()); - graph_attr_group_ptr = cg.CreateAttrsGroup(); - EXPECT_TRUE(graph_attr_group_ptr != nullptr); -} - -TEST_F(AttrStoreUt, GetOrCreateAttrGroupWith2Args) { - auto s = AttrStore::Create(1); - auto ptr = s.CreateAttrsGroup(1, 2); - ASSERT_NE(ptr, nullptr); - ASSERT_EQ(ptr->a, 1); - ASSERT_EQ(ptr->b, 2); - - auto ptr_1 = s.CreateAttrsGroup(); - ASSERT_EQ(ptr_1, nullptr); - - auto ptr_2 = s.CreateAttrsGroup(1); - ASSERT_EQ(ptr_2, nullptr); - - auto ptr_3 = s.CreateAttrsGroup(1, 2); - ASSERT_EQ(ptr_3, nullptr); - - auto ptr_4 = s.GetAttrsGroup(); - ASSERT_EQ(ptr_4, ptr); - - auto ptr_5 = s.GetOrCreateAttrsGroup(); - ASSERT_EQ(ptr_5, ptr); -} } diff --git a/tests/ut/graph/testcase/attributes_holder_unittest.cc b/tests/ut/base/testcase/attributes_holder_unittest.cc similarity index 54% rename from tests/ut/graph/testcase/attributes_holder_unittest.cc rename to tests/ut/base/testcase/attributes_holder_unittest.cc index f9dc15db80cc1bd0c07bbce0f730eaaac9ecbbaa..3110b0eec46be2d3be6329e22ccb45201651fd62 100644 --- a/tests/ut/graph/testcase/attributes_holder_unittest.cc +++ b/tests/ut/base/testcase/attributes_holder_unittest.cc @@ -13,7 +13,6 @@ #include "graph/detail/attributes_holder.h" #include "graph/ge_attr_value.h" #include "graph/any_value.h" -#include "ge_ir.pb.h" namespace ge { namespace { @@ -65,39 +64,6 @@ void oper(AnyValue::OperateType ot, const AnyValue *av, void *out){ class AttrHolderUt : public testing::Test {}; -TEST_F(AttrHolderUt, All) { - EXPECT_NO_THROW( - GeIrProtoHelper helper1; - helper1.InitDefault(); - - GeIrProtoHelper helper2; - helper2.InitDefault(); - - GeIrProtoHelper helper3; - helper3.InitDefault(); - - GeIrProtoHelper helper4; - helper4.InitDefault(); - - GeIrProtoHelper helper5; - helper5.InitDefault(); - - GeIrProtoHelper helper6; - helper6.InitDefault(); - ); -} - -TEST_F(AttrHolderUt, Plus) { - - SubAttrHolder sub_attr_hodler = SubAttrHolder(); - AnyValue av = AnyValue::CreateFrom(1); - av.operate_ = oper; - EXPECT_EQ(sub_attr_hodler.SetAttr("name", av), GRAPH_SUCCESS); - av.operate_ = nullptr; - EXPECT_EQ(sub_attr_hodler.TrySetAttr("name", av), GRAPH_FAILED); - EXPECT_EQ(sub_attr_hodler.AddRequiredAttr("name"), GRAPH_FAILED); -} - TEST_F(AttrHolderUt, ExtAttrGetSuccess) { SubAttrHolder holder; EXPECT_EQ(holder.GetExtAttr("TestName"), nullptr); @@ -169,92 +135,4 @@ TEST_F(AttrHolderUt, ExtAttrEraseFailedWhenAttrNotExsit) { SubAttrHolder holder; EXPECT_FALSE(holder.DelExtAttr("TestName")); } -TEST_F(AttrHolderUt, GetOrCreateAttrsGroup_AutoCreate_Ok) { - SubAttrHolder holder; - ASSERT_NE(holder.GetOrCreateAttrsGroup(), nullptr); -} -TEST_F(AttrHolderUt, GetOrCreateAttrsGroup_MultiTimes_SameRet_Ok) { - SubAttrHolder holder; - auto ret1 = holder.GetOrCreateAttrsGroup(); - auto ret2 = holder.GetOrCreateAttrsGroup(); - ASSERT_EQ(ret1, ret2); - ASSERT_NE(ret1, nullptr); -} -TEST_F(AttrHolderUt, GetAttrsGroup_NotExists_ReturnNull) { - SubAttrHolder holder; - ASSERT_EQ(holder.GetAttrsGroup(), nullptr); -} -TEST_F(AttrHolderUt, GetAttrsGroup_Ok) { - SubAttrHolder holder; - auto ret1 = holder.GetOrCreateAttrsGroup(); - ASSERT_NE(ret1, nullptr); - auto ret2 = holder.GetAttrsGroup(); - ASSERT_EQ(ret1, ret2); -} -TEST_F(AttrHolderUt, CreateAttrGroupWith0Args) { - SubAttrHolder s; - auto ptr = s.CreateAttrsGroup(); - ASSERT_NE(ptr, nullptr); - ASSERT_EQ(ptr->a, 0); - ASSERT_EQ(ptr->b, 0); - - auto ptr_1 = s.CreateAttrsGroup(); - ASSERT_EQ(ptr_1, nullptr); - - auto ptr_2 = s.CreateAttrsGroup(1); - ASSERT_EQ(ptr_2, nullptr); - - auto ptr_3 = s.CreateAttrsGroup(1, 2); - ASSERT_EQ(ptr_3, nullptr); - - auto ptr_4 = s.GetAttrsGroup(); - ASSERT_EQ(ptr_4, ptr); - - auto ptr_5 = s.GetOrCreateAttrsGroup(); - ASSERT_EQ(ptr_5, ptr); -} -TEST_F(AttrHolderUt, CreateAttrGroupWith1Args) { - SubAttrHolder s; - auto ptr = s.CreateAttrsGroup(1); - ASSERT_NE(ptr, nullptr); - ASSERT_EQ(ptr->a, 1); - ASSERT_EQ(ptr->b, 0); - - auto ptr_1 = s.CreateAttrsGroup(); - ASSERT_EQ(ptr_1, nullptr); - - auto ptr_2 = s.CreateAttrsGroup(1); - ASSERT_EQ(ptr_2, nullptr); - - auto ptr_3 = s.CreateAttrsGroup(1, 2); - ASSERT_EQ(ptr_3, nullptr); - - auto ptr_4 = s.GetAttrsGroup(); - ASSERT_EQ(ptr_4, ptr); - - auto ptr_5 = s.GetOrCreateAttrsGroup(); - ASSERT_EQ(ptr_5, ptr); -} -TEST_F(AttrHolderUt, CreateAttrGroupWith2Args) { - SubAttrHolder s; - auto ptr = s.CreateAttrsGroup(1, 2); - ASSERT_NE(ptr, nullptr); - ASSERT_EQ(ptr->a, 1); - ASSERT_EQ(ptr->b, 2); - - auto ptr_1 = s.CreateAttrsGroup(); - ASSERT_EQ(ptr_1, nullptr); - - auto ptr_2 = s.CreateAttrsGroup(1); - ASSERT_EQ(ptr_2, nullptr); - - auto ptr_3 = s.CreateAttrsGroup(1, 2); - ASSERT_EQ(ptr_3, nullptr); - - auto ptr_4 = s.GetAttrsGroup(); - ASSERT_EQ(ptr_4, ptr); - - auto ptr_5 = s.GetOrCreateAttrsGroup(); - ASSERT_EQ(ptr_5, ptr); -} } // namespace ge diff --git a/tests/ut/graph/testcase/file_utils_unittest.cc b/tests/ut/base/testcase/file_utils_unittest.cc similarity index 80% rename from tests/ut/graph/testcase/file_utils_unittest.cc rename to tests/ut/base/testcase/file_utils_unittest.cc index e1a85079c6f613d8da6a669681b4a9b77fb87345..b7cf53090cd8357da6fa8081bc1e93969343d5aa 100644 --- a/tests/ut/graph/testcase/file_utils_unittest.cc +++ b/tests/ut/base/testcase/file_utils_unittest.cc @@ -82,6 +82,28 @@ TEST_F(UtestFileUtils, GetBinFileFromFileSuccess) { system(("rm -f " + so_bin).c_str()); } +TEST_F(UtestFileUtils, GetBinFileFromFileSuccess_offset) { + std::string so_bin = "./opsptoro.so"; + system(("touch " + so_bin).c_str()); + system(("echo '123' > " + so_bin).c_str()); + size_t data_len = 4; + size_t offset = 0; + std::unique_ptr so_data = GetBinFromFile(so_bin, offset, data_len); + ASSERT_NE(so_data, nullptr); + ASSERT_EQ(data_len, 4); + ASSERT_EQ(so_data.get()[0], '1'); + ASSERT_EQ(so_data.get()[1], '2'); + ASSERT_EQ(so_data.get()[2], '3'); + + ASSERT_EQ(GetBinFromFile(so_bin, static_cast(so_data.get()), data_len), GRAPH_SUCCESS); + ASSERT_NE(so_data, nullptr); + ASSERT_EQ(data_len, 4); + ASSERT_EQ(so_data.get()[0], '1'); + ASSERT_EQ(so_data.get()[1], '2'); + ASSERT_EQ(so_data.get()[2], '3'); + system(("rm -f " + so_bin).c_str()); +} + TEST_F(UtestFileUtils, GetBinFilePathNullFail) { std::string so_bin = ""; uint32_t data_len; @@ -100,7 +122,7 @@ TEST_F(UtestFileUtils, WriteBinToFileSuccess) { uint32_t data_len = 4; char so_data[4] = {'1', '2', '3'}; ASSERT_EQ(WriteBinToFile(so_bin, so_data, data_len), GRAPH_SUCCESS); - + ASSERT_EQ(SaveBinToFile(so_data, data_len, so_bin), GRAPH_SUCCESS); system(("rm -f " + so_bin).c_str()); } diff --git a/tests/ut/graph/testcase/func_counter.cc b/tests/ut/base/testcase/func_counter.cc similarity index 100% rename from tests/ut/graph/testcase/func_counter.cc rename to tests/ut/base/testcase/func_counter.cc diff --git a/tests/ut/graph/testcase/func_counter.h b/tests/ut/base/testcase/func_counter.h similarity index 100% rename from tests/ut/graph/testcase/func_counter.h rename to tests/ut/base/testcase/func_counter.h diff --git a/tests/ut/base/testcase/ge_attr_value_unittest.cc b/tests/ut/base/testcase/ge_attr_value_unittest.cc new file mode 100644 index 0000000000000000000000000000000000000000..f38804add2fd5eccea85a8b003af4dc349090e63 --- /dev/null +++ b/tests/ut/base/testcase/ge_attr_value_unittest.cc @@ -0,0 +1,129 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#include +#include "graph/op_desc.h" +#include "graph/ge_attr_value.h" +#include "graph/utils/attr_utils.h" +#include "external/graph/attr_value.h" +#include "external/graph/tensor.h" +#include "external/graph/ascend_string.h" +#include "external/graph/types.h" +#include +#include +#include + +namespace ge { +graphStatus AttrValue::SetAttrValue(const AscendString &attr_value) const { + return impl->geAttrValue_.SetValue(std::string(attr_value.GetString())); +} + +graphStatus AttrValue::SetAttrValue(const AttrValue::FLOAT &attr_value) const { + return impl->geAttrValue_.SetValue(attr_value); +} + +graphStatus AttrValue::SetAttrValue(const AttrValue::INT &attr_value) const { + return impl->geAttrValue_.SetValue(attr_value); +} + +class UtestGeAttrValue : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; + +// extern "C" wrapper for AttrValue SetAttrValue methods to avoid C++ name mangling +extern "C" { + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus aclCom_AttrValue_SetAttrValue_Int64(void *attr_value_ptr, + int64_t value) { + if (attr_value_ptr == nullptr) { + return GRAPH_FAILED; + } + auto *attr_value = static_cast(attr_value_ptr); + return attr_value->SetAttrValue(value); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus aclCom_AttrValue_SetAttrValue_String(void *attr_value_ptr, + const char_t *value) { + if (attr_value_ptr == nullptr || value == nullptr) { + return GRAPH_FAILED; + } + auto *attr_value = static_cast(attr_value_ptr); + return attr_value->SetAttrValue(ge::AscendString(value)); +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus aclCom_AttrValue_SetAttrValue_Float(void *attr_value_ptr, + AttrValue::FLOAT value) { + if (attr_value_ptr == nullptr) { + return GRAPH_FAILED; + } + auto *attr_value = static_cast(attr_value_ptr); + return attr_value->SetAttrValue(value); +} +} + +TEST_F(UtestGeAttrValue, ExternC_AttrValue_SetAttrValue_Int64_Success) { + AttrValue attr_value; + int64_t int64_value = 12345; + + // 测试成功情况 + EXPECT_EQ(aclCom_AttrValue_SetAttrValue_Int64(&attr_value, int64_value), GRAPH_SUCCESS); + + // 验证设置的值 + int64_t get_value = 0; + EXPECT_EQ(attr_value.GetValue(get_value), GRAPH_SUCCESS); + EXPECT_EQ(get_value, int64_value); + + // 测试nullptr参数 + EXPECT_EQ(aclCom_AttrValue_SetAttrValue_Int64(nullptr, int64_value), GRAPH_FAILED); +} + +TEST_F(UtestGeAttrValue, ExternC_AttrValue_SetAttrValue_Float_Success) { + AttrValue attr_value; + AttrValue::FLOAT float_value = 12345.0; + + // 测试成功情况 + EXPECT_EQ(aclCom_AttrValue_SetAttrValue_Float(&attr_value, float_value), GRAPH_SUCCESS); + + // 验证设置的值 + AttrValue::FLOAT get_value = 0; + EXPECT_EQ(attr_value.GetValue(get_value), GRAPH_SUCCESS); + EXPECT_EQ(get_value, float_value); + + // 测试nullptr参数 + EXPECT_EQ(aclCom_AttrValue_SetAttrValue_Float(nullptr, float_value), GRAPH_FAILED); +} + +TEST_F(UtestGeAttrValue, ExternC_AttrValue_SetAttrValue_String_Success) { + AttrValue attr_value; + const char_t *char_value = "12345"; + + // 测试成功情况 + EXPECT_EQ(aclCom_AttrValue_SetAttrValue_String(&attr_value, char_value), GRAPH_SUCCESS); + + // 验证设置的值 + ge::AscendString get_value = "0"; + EXPECT_EQ(attr_value.GetValue(get_value), GRAPH_SUCCESS); + EXPECT_EQ(get_value, char_value); + + // 测试nullptr参数 + EXPECT_EQ(aclCom_AttrValue_SetAttrValue_String(nullptr, char_value), GRAPH_FAILED); +} + +TEST_F(UtestGeAttrValue, ExternC_AttrValue_SetAttrValue_Int64_EdgeCases) { + AttrValue attr_value; + + // 测试边界值 + EXPECT_EQ(aclCom_AttrValue_SetAttrValue_Int64(&attr_value, std::numeric_limits::max()), GRAPH_SUCCESS); + EXPECT_EQ(aclCom_AttrValue_SetAttrValue_Int64(&attr_value, std::numeric_limits::min()), GRAPH_SUCCESS); + EXPECT_EQ(aclCom_AttrValue_SetAttrValue_Int64(&attr_value, 0), GRAPH_SUCCESS); + EXPECT_EQ(aclCom_AttrValue_SetAttrValue_Int64(&attr_value, -1), GRAPH_SUCCESS); +} +} diff --git a/tests/ut/exe_graph/kernel_context_unittest.cc b/tests/ut/base/testcase/kernel_context_unittest.cc similarity index 45% rename from tests/ut/exe_graph/kernel_context_unittest.cc rename to tests/ut/base/testcase/kernel_context_unittest.cc index 880874debb42e58471640fe6314d99401271cb6c..292c7c6ae2943ab81ad6a41f5d61a8df1525d4db 100644 --- a/tests/ut/exe_graph/kernel_context_unittest.cc +++ b/tests/ut/base/testcase/kernel_context_unittest.cc @@ -84,78 +84,4 @@ TEST_F(KernelContextUT, ChainSetAndUseStructOk) { av.deleter(av.data.pointer); } - -TEST_F(KernelContextUT, GetInputNumOk) { - auto context_holder = KernelRunContextFaker().KernelIONum(5, 6).Build(); - auto context = context_holder.GetContext(); - EXPECT_EQ(context->GetInputNum(), 5); - EXPECT_EQ(context->GetOutputNum(), 6); -} - -TEST_F(KernelContextUT, GetInnerInputOk) { - auto context_holder = KernelRunContextFaker().KernelIONum(5, 6).Build(); - auto context = context_holder.GetContext(); - - EXPECT_EQ(context->GetInputPointer(0), context_holder.holder.value_holder_[0].GetPointer()); - EXPECT_EQ(context->GetInputPointer(1), context_holder.holder.value_holder_[1].GetPointer()); - EXPECT_EQ(context->GetInputPointer(2), context_holder.holder.value_holder_[2].GetPointer()); - - EXPECT_EQ(context->GetInputValue(4), context_holder.holder.value_holder_[4].GetValue()); - - EXPECT_EQ(context->GetInput(0), &context_holder.holder.value_holder_[0]); - EXPECT_EQ(context->GetInput(4), &context_holder.holder.value_holder_[4]); -} - -TEST_F(KernelContextUT, GetAllocInputOk) { - auto context_holder = KernelRunContextFaker().KernelIONum(5, 6).Build(); - std::vector t_holder(11); - for (size_t i = 0; i < 11; ++i) { - context_holder.holder.value_holder_[i].any_value_.data.pointer = &t_holder[i]; - } - auto context = context_holder.GetContext(); - - EXPECT_EQ(context->GetInputPointer(0), - reinterpret_cast((context_holder.holder.value_holder_[0].any_value_.data.pointer))); - EXPECT_EQ(context->GetInputPointer(1), - reinterpret_cast((context_holder.holder.value_holder_[1].any_value_.data.pointer))); - EXPECT_EQ(context->GetInputPointer(4), - reinterpret_cast((context_holder.holder.value_holder_[4].any_value_.data.pointer))); - - EXPECT_EQ(context->GetInput(0), &context_holder.holder.value_holder_[0]); - EXPECT_EQ(context->GetInput(4), &context_holder.holder.value_holder_[4]); -} - -TEST_F(KernelContextUT, GetInnerOutputOk) { - auto context_holder = KernelRunContextFaker().KernelIONum(5, 6).Build(); - auto context = context_holder.GetContext(); - - EXPECT_EQ(context->GetOutputPointer(0), - reinterpret_cast(&(context_holder.holder.value_holder_[5].any_value_.data))); - EXPECT_EQ(context->GetOutputPointer(1), - reinterpret_cast(&(context_holder.holder.value_holder_[6].any_value_.data))); - EXPECT_EQ(context->GetOutputPointer(2), - reinterpret_cast(&(context_holder.holder.value_holder_[7].any_value_.data))); - - EXPECT_EQ(context->GetOutput(0), &context_holder.holder.value_holder_[5]); - EXPECT_EQ(context->GetOutput(5), &context_holder.holder.value_holder_[10]); -} - -TEST_F(KernelContextUT, GetAllocOutputOk) { - auto context_holder = KernelRunContextFaker().KernelIONum(5, 6).Build(); - std::vector t_holder(11); - for (size_t i = 0; i < 11; ++i) { - context_holder.holder.value_holder_[i].any_value_.data.pointer = &t_holder[i]; - } - auto context = context_holder.GetContext(); - - EXPECT_EQ(context->GetOutputPointer(0), - reinterpret_cast((context_holder.holder.value_holder_[5].any_value_.data.pointer))); - EXPECT_EQ(context->GetOutputPointer(1), - reinterpret_cast((context_holder.holder.value_holder_[6].any_value_.data.pointer))); - EXPECT_EQ(context->GetOutputPointer(4), - reinterpret_cast((context_holder.holder.value_holder_[9].any_value_.data.pointer))); - - EXPECT_EQ(context->GetOutput(0), &context_holder.holder.value_holder_[5]); - EXPECT_EQ(context->GetOutput(4), &context_holder.holder.value_holder_[9]); -} } // namespace gert diff --git a/register/ffts_plus_update_manager.cc b/tests/ut/base/testcase/op_execute_context_unittest.cc similarity index 32% rename from register/ffts_plus_update_manager.cc rename to tests/ut/base/testcase/op_execute_context_unittest.cc index 361c2212b700aa46f6ce2e3ce474527f71b95279..8c6357c945299281fb13720604b246cdaafaf6ad 100644 --- a/register/ffts_plus_update_manager.cc +++ b/tests/ut/base/testcase/op_execute_context_unittest.cc @@ -7,48 +7,47 @@ * See LICENSE in the root of the software repository for the full text of the License. * ===================================================================================================================*/ -#include "register/ffts_plus_update_manager.h" -#include "graph/debug/ge_util.h" -#include "common/plugin/plugin_manager.h" - -namespace ge { -FftsPlusUpdateManager &FftsPlusUpdateManager::Instance() { - static FftsPlusUpdateManager instance; - return instance; +#include "exe_graph/runtime/op_execute_context.h" +#include "graph/ge_error_codes.h" +#include +#include "faker/kernel_run_context_faker.h" +#include "faker/allocator_faker.h" +#include "exe_graph/runtime/storage_shape.h" +#include "base/context_builder/op_kernel_run_context_builder.h" + +namespace gert { +class OpExecuteContextUT : public testing::Test {}; + + +TEST_F(OpExecuteContextUT, MallocFreeWorkSpaceOk) { + OpKernelContextBuilder ctx_builder; + gert::StorageShape shape0 = {{10, 20}, {10, 20}}; + AllocatorFaker gert_allocator; + auto output_block_memory = std::make_shared>(); + ASSERT_NE(output_block_memory, nullptr); + output_block_memory->reserve(1UL); + + auto holder = ctx_builder.OpType("Add") + .OpName("add_1") + .IONum(2, 1) + .InputTensorDesc(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) + .InputTensorDesc(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) + .OutputTensorDesc(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) + .Inputs({&shape0, &shape0, &gert_allocator, &gert_allocator}) + .Outputs({output_block_memory.get()}) + .Build(); + + auto context = reinterpret_cast(holder.GetContext()); + ASSERT_NE(context, nullptr); + + auto block = context->MallocWorkspace(1024); + ASSERT_NE(block, nullptr); + + auto kernel_context = reinterpret_cast(context); + auto memory_vec = kernel_context->GetOutputPointer>(0UL); + ASSERT_NE(memory_vec, nullptr); + EXPECT_EQ(memory_vec->size(), 1UL); + context->FreeWorkspace(); + EXPECT_EQ(memory_vec->size(), 0UL); } - -FftsCtxUpdatePtr FftsPlusUpdateManager::GetUpdater(const std::string &core_type) const { - const std::map::const_iterator it = creators_.find(core_type); - if (it == creators_.cend()) { - GELOGW("Cannot find creator for core type: %s.", core_type.c_str()); - return nullptr; - } - - return it->second(); -} - -void FftsPlusUpdateManager::RegisterCreator(const std::string &core_type, const FftsCtxUpdateCreatorFun &creator) { - if (creator == nullptr) { - GELOGW("Register null creator for core type: %s", core_type.c_str()); - return; - } - - const auto it = creators_.find(core_type); - if (it != creators_.end()) { - GELOGW("Creator already exist for core type: %s", core_type.c_str()); - return; - } - - GELOGI("Register creator for core type: %s", core_type.c_str()); - creators_[core_type] = creator; -} - -Status FftsPlusUpdateManager::Initialize() { - return SUCCESS; -} - -FftsPlusUpdateManager::~FftsPlusUpdateManager() { - creators_.clear(); // clear must be called before `plugin_manager_.reset` which would close so - plugin_manager_.reset(); } -} // namespace ge diff --git a/tests/ut/graph/testcase/plugin_manager_unittest.cc b/tests/ut/base/testcase/plugin_manager_unittest.cc similarity index 100% rename from tests/ut/graph/testcase/plugin_manager_unittest.cc rename to tests/ut/base/testcase/plugin_manager_unittest.cc diff --git a/tests/ut/graph/testcase/test_structs.h b/tests/ut/base/testcase/test_structs.h similarity index 73% rename from tests/ut/graph/testcase/test_structs.h rename to tests/ut/base/testcase/test_structs.h index 06c754eb5e87eb07a543bd6bc7f80903699247a5..0171a80897c5a605cd94a6b53cea8254949b0809 100644 --- a/tests/ut/graph/testcase/test_structs.h +++ b/tests/ut/base/testcase/test_structs.h @@ -9,7 +9,7 @@ #ifndef METADEF_CXX_TEST_STRUCTS_H #define METADEF_CXX_TEST_STRUCTS_H -#include "graph/attribute_group/attr_group_base.h" + #include "debug/ge_util.h" namespace ge { struct TestStructA { @@ -65,26 +65,5 @@ struct InlineStructB { private: int32_t *a; }; -struct TestAttrGroup : public AttrGroupsBase { - TestAttrGroup(int32_t a, int32_t b) : a(a), b(b) {} - TestAttrGroup(int32_t a) : a(a), b(0) {} - TestAttrGroup() : a(0), b(0) {} - int32_t a; - int32_t b; - graphStatus status{GRAPH_SUCCESS}; - graphStatus Serialize(proto::AttrGroupDef &attr_group_def) override { - (void) attr_group_def; - return status; - } - - graphStatus Deserialize(const proto::AttrGroupDef &attr_group_def, AttrHolder *attr_holder) override { - (void) attr_holder; - (void) attr_group_def; - return status; - } - std::unique_ptr Clone() override { - return ComGraphMakeUnique(*this); // UT,由使用者判空 - } -}; } #endif //METADEF_CXX_TEST_STRUCTS_H diff --git a/tests/ut/exe_graph/tiling_data_unittest.cc b/tests/ut/base/testcase/tiling_data_unittest.cc similarity index 82% rename from tests/ut/exe_graph/tiling_data_unittest.cc rename to tests/ut/base/testcase/tiling_data_unittest.cc index e148124c3e89199e9d96c5736bd49e7a1746ae9d..0c4f1bb3dc21a67957ecdbe48a01a62ff2ae65be 100644 --- a/tests/ut/exe_graph/tiling_data_unittest.cc +++ b/tests/ut/base/testcase/tiling_data_unittest.cc @@ -10,6 +10,8 @@ #include "exe_graph/runtime/tiling_data.h" #include "common/util/tiling_utils.h" #include "faker/kernel_run_context_faker.h" +#include "base/context_builder/op_kernel_run_context_builder.h" +#include "base/context_builder/op_tiling_context_builder.h" #include "common/ge_common/debug/ge_log.h" #include namespace gert { @@ -22,23 +24,25 @@ struct TestData { int16_t d; }; -FakeKernelContextHolder BuildTestContext() { - auto holder = gert::KernelRunContextFaker() - .NodeIoNum(1, 1) - .IrInputNum(1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeAttrs({{"str", ge::AnyValue::CreateFrom("Hello!")}, - {"float", ge::AnyValue::CreateFrom(10.101)}, - {"list_float", ge::AnyValue::CreateFrom>({1.2, 2.3, 3.4})}, - {"list_list_float", ge::AnyValue::CreateFrom>>( - {{1.2, 2.3, 3.4}, {4.5, 5.6, 6.7}})}, - {"int", ge::AnyValue::CreateFrom(0x7fUL)}, - {"list_int", ge::AnyValue::CreateFrom>({1, 2, 3})}, - {"list_list_int", ge::AnyValue::CreateFrom>>( - {{1, 2, 3}, {4, 5, 6}})}, - {"bool", ge::AnyValue::CreateFrom(true)}, - {"list_bool", ge::AnyValue::CreateFrom>({true, false, true})}}) +gert::ContextHolder BuildTestContext() { + uint8_t tmp_compile_info[] = {1, 2, 3, 4, 5, 6, 7}; + uint8_t tmp_platform_info[] = {1, 2, 3, 4, 5, 6, 7}; + OpTilingContextBuilder ctx_builder; + auto holder = ctx_builder + .IONum(1, 1) + .OpName("tmp") + .OpType("DIY") + .CompileInfo(tmp_compile_info) + .PlatformInfo(tmp_platform_info) + .AppendAttr(ge::AscendString("Hello!")) + .AppendAttr(float(10.101)) + .AppendAttr(std::vector({1.2, 2.3, 3.4})) + .AppendAttr(std::vector>({{1, 2, 3}, {4, 5, 6}})) + .AppendAttr(int64_t(0x7fUL)) + .AppendAttr(std::vector({1, 2, 3})) + .AppendAttr(std::vector>({{1, 2, 3}, {4, 5, 6}})) + .AppendAttr(bool(true)) + .AppendAttr(std::vector({true, false, true})) .Build(); return holder; } @@ -139,7 +143,7 @@ TEST_F(TilingDataUT, AppendAttrStrOk) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 0, AttrDataType::kString, AttrDataType::kString), ge::GRAPH_SUCCESS); @@ -154,7 +158,7 @@ TEST_F(TilingDataUT, AppendAttrBoolOk) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 7, AttrDataType::kBool, AttrDataType::kBool), ge::GRAPH_SUCCESS); @@ -166,7 +170,7 @@ TEST_F(TilingDataUT, AppendAttrListBoolOk) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 8, AttrDataType::kListBool, AttrDataType::kListBool), ge::GRAPH_SUCCESS); @@ -182,7 +186,7 @@ TEST_F(TilingDataUT, AppendAttrBoolToFloat16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 7, AttrDataType::kBool, AttrDataType::kFloat16), ge::GRAPH_SUCCESS); @@ -194,7 +198,7 @@ TEST_F(TilingDataUT, AppendAttrListBoolToListFloat16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 8, AttrDataType::kListBool, AttrDataType::kListFloat16), ge::GRAPH_SUCCESS); @@ -210,7 +214,7 @@ TEST_F(TilingDataUT, AppendAttrBoolToBfloat16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 7, AttrDataType::kBool, AttrDataType::kBfloat16), ge::GRAPH_SUCCESS); @@ -222,7 +226,7 @@ TEST_F(TilingDataUT, AppendAttrListBoolToListBfloat16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 8, AttrDataType::kListBool, AttrDataType::kListBfloat16), ge::GRAPH_SUCCESS); @@ -238,7 +242,7 @@ TEST_F(TilingDataUT, AppendAttrBoolToFloat32Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 7, AttrDataType::kBool, AttrDataType::kFloat32), ge::GRAPH_SUCCESS); @@ -250,7 +254,7 @@ TEST_F(TilingDataUT, AppendAttrListBoolToListFloat32Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 8, AttrDataType::kListBool, AttrDataType::kListFloat32), ge::GRAPH_SUCCESS); @@ -266,7 +270,7 @@ TEST_F(TilingDataUT, AppendAttrBoolToInt8Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 7, AttrDataType::kBool, AttrDataType::kInt8), ge::GRAPH_SUCCESS); @@ -278,7 +282,7 @@ TEST_F(TilingDataUT, AppendAttrListBoolToListInt8Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 8, AttrDataType::kListBool, AttrDataType::kListInt8), ge::GRAPH_SUCCESS); @@ -294,7 +298,7 @@ TEST_F(TilingDataUT, AppendAttrBoolToInt16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 7, AttrDataType::kBool, AttrDataType::kInt16), ge::GRAPH_SUCCESS); @@ -306,7 +310,7 @@ TEST_F(TilingDataUT, AppendAttrListBoolToListInt16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 8, AttrDataType::kListBool, AttrDataType::kListInt16), ge::GRAPH_SUCCESS); @@ -322,7 +326,7 @@ TEST_F(TilingDataUT, AppendAttrBoolToInt32Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 7, AttrDataType::kBool, AttrDataType::kInt32), ge::GRAPH_SUCCESS); @@ -334,7 +338,7 @@ TEST_F(TilingDataUT, AppendAttrListBoolToListInt32Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 8, AttrDataType::kListBool, AttrDataType::kListInt32), ge::GRAPH_SUCCESS); @@ -350,7 +354,7 @@ TEST_F(TilingDataUT, AppendAttrBoolToInt64Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 7, AttrDataType::kBool, AttrDataType::kInt64), ge::GRAPH_SUCCESS); @@ -362,7 +366,7 @@ TEST_F(TilingDataUT, AppendAttrListBoolToListInt64Ok) { auto data = TilingData::CreateCap(30); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 8, AttrDataType::kListBool, AttrDataType::kListInt64), ge::GRAPH_SUCCESS); @@ -378,7 +382,7 @@ TEST_F(TilingDataUT, AppendAttrBoolToUint8Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 7, AttrDataType::kBool, AttrDataType::kUint8), ge::GRAPH_SUCCESS); @@ -390,7 +394,7 @@ TEST_F(TilingDataUT, AppendAttrListBoolToListUint8Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 8, AttrDataType::kListBool, AttrDataType::kListUint8), ge::GRAPH_SUCCESS); @@ -406,7 +410,7 @@ TEST_F(TilingDataUT, AppendAttrBoolToUint16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 7, AttrDataType::kBool, AttrDataType::kUint16), ge::GRAPH_SUCCESS); @@ -418,7 +422,7 @@ TEST_F(TilingDataUT, AppendAttrListBoolToListUint16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 8, AttrDataType::kListBool, AttrDataType::kListUint16), ge::GRAPH_SUCCESS); @@ -434,7 +438,7 @@ TEST_F(TilingDataUT, AppendAttrBoolToUint32Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 7, AttrDataType::kBool, AttrDataType::kUint32), ge::GRAPH_SUCCESS); @@ -446,7 +450,7 @@ TEST_F(TilingDataUT, AppendAttrListBoolToListUint32Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 8, AttrDataType::kListBool, AttrDataType::kListUint32), ge::GRAPH_SUCCESS); @@ -462,7 +466,7 @@ TEST_F(TilingDataUT, AppendAttrBoolToUint64Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 7, AttrDataType::kBool, AttrDataType::kUint64), ge::GRAPH_SUCCESS); @@ -474,7 +478,7 @@ TEST_F(TilingDataUT, AppendAttrListBoolToListUint64Ok) { auto data = TilingData::CreateCap(30); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 8, AttrDataType::kListBool, AttrDataType::kListUint64), ge::GRAPH_SUCCESS); @@ -490,7 +494,7 @@ TEST_F(TilingDataUT, AppendAttrFloat32ToBoolOk) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 1, AttrDataType::kFloat32, AttrDataType::kBool), ge::GRAPH_SUCCESS); @@ -502,7 +506,7 @@ TEST_F(TilingDataUT, AppendAttrListFloat32ToListBoolOk) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 2, AttrDataType::kListFloat32, AttrDataType::kListBool), ge::GRAPH_SUCCESS); @@ -513,26 +517,11 @@ TEST_F(TilingDataUT, AppendAttrListFloat32ToListBoolOk) { } } -TEST_F(TilingDataUT, AppendAttrListListFloat32ToListListBoolOk) { - auto data = TilingData::CreateCap(20); - auto tiling_data = reinterpret_cast(data.get()); - auto holder = BuildTestContext(); - auto context = holder.GetContext(); - EXPECT_NE(context, nullptr); - EXPECT_EQ(tiling_data->AppendConvertedAttrVal( - context->GetAttrs(), 3, AttrDataType::kListListFloat32, AttrDataType::kListListBool), ge::GRAPH_SUCCESS); - EXPECT_EQ(tiling_data->GetDataSize(), sizeof(bool) * 6); - auto ele = reinterpret_cast(tiling_data->GetData()); - for (size_t i = 0UL; i < 6UL; ++i) { - EXPECT_EQ(ele[i], true); - } -} - TEST_F(TilingDataUT, AppendAttrFloat32ToFloat16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 1, AttrDataType::kFloat32, AttrDataType::kFloat16), ge::GRAPH_SUCCESS); @@ -544,7 +533,7 @@ TEST_F(TilingDataUT, AppendAttrListFloat32ToListFloat16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 2, AttrDataType::kListFloat32, AttrDataType::kListFloat16), ge::GRAPH_SUCCESS); @@ -556,27 +545,11 @@ TEST_F(TilingDataUT, AppendAttrListFloat32ToListFloat16Ok) { } } -TEST_F(TilingDataUT, AppendAttrListListFloat32ToListListFloat16Ok) { - auto data = TilingData::CreateCap(20); - auto tiling_data = reinterpret_cast(data.get()); - auto holder = BuildTestContext(); - auto context = holder.GetContext(); - EXPECT_NE(context, nullptr); - EXPECT_EQ(tiling_data->AppendConvertedAttrVal( - context->GetAttrs(), 3, AttrDataType::kListListFloat32, AttrDataType::kListListFloat16), ge::GRAPH_SUCCESS); - EXPECT_EQ(tiling_data->GetDataSize(), sizeof(uint16_t) * 6); - auto ele = reinterpret_cast(tiling_data->GetData()); - std::vector expect_data{1.2, 2.3, 3.4, 4.5, 5.6, 6.7}; - for (size_t i = 0UL; i < 6UL; ++i) { - EXPECT_EQ(ele[i], optiling::Float32ToFloat16(expect_data[i])); - } -} - TEST_F(TilingDataUT, AppendAttrFloat32ToBfloat16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 1, AttrDataType::kFloat32, AttrDataType::kBfloat16), ge::GRAPH_SUCCESS); @@ -588,7 +561,7 @@ TEST_F(TilingDataUT, AppendAttrListFloat32ToListBfloat16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 2, AttrDataType::kListFloat32, AttrDataType::kListBfloat16), ge::GRAPH_SUCCESS); @@ -600,28 +573,11 @@ TEST_F(TilingDataUT, AppendAttrListFloat32ToListBfloat16Ok) { } } -TEST_F(TilingDataUT, AppendAttrListListFloat32ToListListBfloat16Ok) { - auto data = TilingData::CreateCap(20); - auto tiling_data = reinterpret_cast(data.get()); - auto holder = BuildTestContext(); - auto context = holder.GetContext(); - EXPECT_NE(context, nullptr); - EXPECT_EQ(tiling_data->AppendConvertedAttrVal( - context->GetAttrs(), 3, AttrDataType::kListListFloat32, AttrDataType::kListListBfloat16), - ge::GRAPH_SUCCESS); - EXPECT_EQ(tiling_data->GetDataSize(), sizeof(uint16_t) * 6); - auto ele = reinterpret_cast(tiling_data->GetData()); - std::vector expect_data{1.2, 2.3, 3.4, 4.5, 5.6, 6.7}; - for (size_t i = 0UL; i < 6UL; ++i) { - EXPECT_EQ(ele[i], optiling::Float32ToBfloat16(expect_data[i])); - } -} - TEST_F(TilingDataUT, AppendAttrFloat32Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 1, AttrDataType::kFloat32, AttrDataType::kFloat32), ge::GRAPH_SUCCESS); @@ -633,7 +589,7 @@ TEST_F(TilingDataUT, AppendAttrListFloat32Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 2, AttrDataType::kListFloat32, AttrDataType::kListFloat32), ge::GRAPH_SUCCESS); @@ -645,28 +601,11 @@ TEST_F(TilingDataUT, AppendAttrListFloat32Ok) { } } -TEST_F(TilingDataUT, AppendAttrListListFloat32Ok) { - auto data = TilingData::CreateCap(30); - auto tiling_data = reinterpret_cast(data.get()); - auto holder = BuildTestContext(); - auto context = holder.GetContext(); - EXPECT_NE(context, nullptr); - EXPECT_EQ(tiling_data->AppendConvertedAttrVal( - context->GetAttrs(), 3, AttrDataType::kListListFloat32, AttrDataType::kListListFloat32), - ge::GRAPH_SUCCESS); - EXPECT_EQ(tiling_data->GetDataSize(), sizeof(float) * 6); - auto ele = reinterpret_cast(tiling_data->GetData()); - std::vector expect_data{1.2, 2.3, 3.4, 4.5, 5.6, 6.7}; - for (size_t i = 0UL; i < 6UL; ++i) { - EXPECT_EQ(ele[i], expect_data[i]); - } -} - TEST_F(TilingDataUT, AppendAttrFloat32ToInt8Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 1, AttrDataType::kFloat32, AttrDataType::kInt8), ge::GRAPH_SUCCESS); @@ -678,7 +617,7 @@ TEST_F(TilingDataUT, AppendAttrListFloat32ToListInt8Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 2, AttrDataType::kListFloat32, AttrDataType::kListInt8), ge::GRAPH_SUCCESS); @@ -690,27 +629,11 @@ TEST_F(TilingDataUT, AppendAttrListFloat32ToListInt8Ok) { } } -TEST_F(TilingDataUT, AppendAttrListListFloat32ToListListInt8Ok) { - auto data = TilingData::CreateCap(20); - auto tiling_data = reinterpret_cast(data.get()); - auto holder = BuildTestContext(); - auto context = holder.GetContext(); - EXPECT_NE(context, nullptr); - EXPECT_EQ(tiling_data->AppendConvertedAttrVal( - context->GetAttrs(), 3, AttrDataType::kListListFloat32, AttrDataType::kListListInt8), ge::GRAPH_SUCCESS); - EXPECT_EQ(tiling_data->GetDataSize(), sizeof(int8_t) * 6); - auto ele = reinterpret_cast(tiling_data->GetData()); - std::vector expect_data{1, 2, 3, 4, 5, 6}; - for (size_t i = 0UL; i < 6UL; ++i) { - EXPECT_EQ(ele[i], expect_data[i]); - } -} - TEST_F(TilingDataUT, AppendAttrFloat32ToInt16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 1, AttrDataType::kFloat32, AttrDataType::kInt16), ge::GRAPH_SUCCESS); @@ -722,7 +645,7 @@ TEST_F(TilingDataUT, AppendAttrListFloat32ToListInt16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 2, AttrDataType::kListFloat32, AttrDataType::kListInt16), ge::GRAPH_SUCCESS); @@ -734,27 +657,11 @@ TEST_F(TilingDataUT, AppendAttrListFloat32ToListInt16Ok) { } } -TEST_F(TilingDataUT, AppendAttrListListFloat32ToListListInt16Ok) { - auto data = TilingData::CreateCap(20); - auto tiling_data = reinterpret_cast(data.get()); - auto holder = BuildTestContext(); - auto context = holder.GetContext(); - EXPECT_NE(context, nullptr); - EXPECT_EQ(tiling_data->AppendConvertedAttrVal( - context->GetAttrs(), 3, AttrDataType::kListListFloat32, AttrDataType::kListListInt16), ge::GRAPH_SUCCESS); - EXPECT_EQ(tiling_data->GetDataSize(), sizeof(int16_t) * 6); - auto ele = reinterpret_cast(tiling_data->GetData()); - std::vector expect_data{1, 2, 3, 4, 5, 6}; - for (size_t i = 0UL; i < 6UL; ++i) { - EXPECT_EQ(ele[i], expect_data[i]); - } -} - TEST_F(TilingDataUT, AppendAttrFloat32ToInt32Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 1, AttrDataType::kFloat32, AttrDataType::kInt32), ge::GRAPH_SUCCESS); @@ -766,7 +673,7 @@ TEST_F(TilingDataUT, AppendAttrListFloat32ToListInt32Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 2, AttrDataType::kListFloat32, AttrDataType::kListInt32), ge::GRAPH_SUCCESS); @@ -778,27 +685,11 @@ TEST_F(TilingDataUT, AppendAttrListFloat32ToListInt32Ok) { } } -TEST_F(TilingDataUT, AppendAttrListListFloat32ToListListInt32Ok) { - auto data = TilingData::CreateCap(30); - auto tiling_data = reinterpret_cast(data.get()); - auto holder = BuildTestContext(); - auto context = holder.GetContext(); - EXPECT_NE(context, nullptr); - EXPECT_EQ(tiling_data->AppendConvertedAttrVal( - context->GetAttrs(), 3, AttrDataType::kListListFloat32, AttrDataType::kListListInt32), ge::GRAPH_SUCCESS); - EXPECT_EQ(tiling_data->GetDataSize(), sizeof(int32_t) * 6); - auto ele = reinterpret_cast(tiling_data->GetData()); - std::vector expect_data{1, 2, 3, 4, 5, 6}; - for (size_t i = 0UL; i < 6UL; ++i) { - EXPECT_EQ(ele[i], expect_data[i]); - } -} - TEST_F(TilingDataUT, AppendAttrFloat32ToInt64Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 1, AttrDataType::kFloat32, AttrDataType::kInt64), ge::GRAPH_SUCCESS); @@ -810,7 +701,7 @@ TEST_F(TilingDataUT, AppendAttrListFloat32ToListInt64Ok) { auto data = TilingData::CreateCap(30); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 2, AttrDataType::kListFloat32, AttrDataType::kListInt64), ge::GRAPH_SUCCESS); @@ -822,27 +713,11 @@ TEST_F(TilingDataUT, AppendAttrListFloat32ToListInt64Ok) { } } -TEST_F(TilingDataUT, AppendAttrListListFloat32ToListListInt64Ok) { - auto data = TilingData::CreateCap(50); - auto tiling_data = reinterpret_cast(data.get()); - auto holder = BuildTestContext(); - auto context = holder.GetContext(); - EXPECT_NE(context, nullptr); - EXPECT_EQ(tiling_data->AppendConvertedAttrVal( - context->GetAttrs(), 3, AttrDataType::kListListFloat32, AttrDataType::kListListInt64), ge::GRAPH_SUCCESS); - EXPECT_EQ(tiling_data->GetDataSize(), sizeof(int64_t) * 6); - auto ele = reinterpret_cast(tiling_data->GetData()); - std::vector expect_data{1, 2, 3, 4, 5, 6}; - for (size_t i = 0UL; i < 6UL; ++i) { - EXPECT_EQ(ele[i], expect_data[i]); - } -} - TEST_F(TilingDataUT, AppendAttrFloat32ToUint8Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 1, AttrDataType::kFloat32, AttrDataType::kUint8), ge::GRAPH_SUCCESS); @@ -854,7 +729,7 @@ TEST_F(TilingDataUT, AppendAttrListFloat32ToListUint8Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 2, AttrDataType::kListFloat32, AttrDataType::kListUint8), ge::GRAPH_SUCCESS); @@ -866,27 +741,11 @@ TEST_F(TilingDataUT, AppendAttrListFloat32ToListUint8Ok) { } } -TEST_F(TilingDataUT, AppendAttrListListFloat32ToListListUint8Ok) { - auto data = TilingData::CreateCap(20); - auto tiling_data = reinterpret_cast(data.get()); - auto holder = BuildTestContext(); - auto context = holder.GetContext(); - EXPECT_NE(context, nullptr); - EXPECT_EQ(tiling_data->AppendConvertedAttrVal( - context->GetAttrs(), 3, AttrDataType::kListListFloat32, AttrDataType::kListListUint8), ge::GRAPH_SUCCESS); - EXPECT_EQ(tiling_data->GetDataSize(), sizeof(uint8_t) * 6); - auto ele = reinterpret_cast(tiling_data->GetData()); - std::vector expect_data{1, 2, 3, 4, 5, 6}; - for (size_t i = 0UL; i < 6UL; ++i) { - EXPECT_EQ(ele[i], expect_data[i]); - } -} - TEST_F(TilingDataUT, AppendAttrFloat32ToUint16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 1, AttrDataType::kFloat32, AttrDataType::kUint16), ge::GRAPH_SUCCESS); @@ -898,7 +757,7 @@ TEST_F(TilingDataUT, AppendAttrListFloat32ToListUint16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 2, AttrDataType::kListFloat32, AttrDataType::kListUint16), ge::GRAPH_SUCCESS); @@ -910,27 +769,11 @@ TEST_F(TilingDataUT, AppendAttrListFloat32ToListUint16Ok) { } } -TEST_F(TilingDataUT, AppendAttrListListFloat32ToListListUint16Ok) { - auto data = TilingData::CreateCap(20); - auto tiling_data = reinterpret_cast(data.get()); - auto holder = BuildTestContext(); - auto context = holder.GetContext(); - EXPECT_NE(context, nullptr); - EXPECT_EQ(tiling_data->AppendConvertedAttrVal( - context->GetAttrs(), 3, AttrDataType::kListListFloat32, AttrDataType::kListListUint16), ge::GRAPH_SUCCESS); - EXPECT_EQ(tiling_data->GetDataSize(), sizeof(uint16_t) * 6); - auto ele = reinterpret_cast(tiling_data->GetData()); - std::vector expect_data{1, 2, 3, 4, 5, 6}; - for (size_t i = 0UL; i < 6UL; ++i) { - EXPECT_EQ(ele[i], expect_data[i]); - } -} - TEST_F(TilingDataUT, AppendAttrFloat32ToUint32Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 1, AttrDataType::kFloat32, AttrDataType::kUint32), ge::GRAPH_SUCCESS); @@ -942,7 +785,7 @@ TEST_F(TilingDataUT, AppendAttrListFloat32ToListUint32Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 2, AttrDataType::kListFloat32, AttrDataType::kListUint32), ge::GRAPH_SUCCESS); @@ -954,27 +797,11 @@ TEST_F(TilingDataUT, AppendAttrListFloat32ToListUint32Ok) { } } -TEST_F(TilingDataUT, AppendAttrListListFloat32ToListListUint32Ok) { - auto data = TilingData::CreateCap(30); - auto tiling_data = reinterpret_cast(data.get()); - auto holder = BuildTestContext(); - auto context = holder.GetContext(); - EXPECT_NE(context, nullptr); - EXPECT_EQ(tiling_data->AppendConvertedAttrVal( - context->GetAttrs(), 3, AttrDataType::kListListFloat32, AttrDataType::kListListUint32), ge::GRAPH_SUCCESS); - EXPECT_EQ(tiling_data->GetDataSize(), sizeof(uint32_t) * 6); - auto ele = reinterpret_cast(tiling_data->GetData()); - std::vector expect_data{1, 2, 3, 4, 5, 6}; - for (size_t i = 0UL; i < 6UL; ++i) { - EXPECT_EQ(ele[i], expect_data[i]); - } -} - TEST_F(TilingDataUT, AppendAttrFloat32ToUint64Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 1, AttrDataType::kFloat32, AttrDataType::kUint64), ge::GRAPH_SUCCESS); @@ -986,7 +813,7 @@ TEST_F(TilingDataUT, AppendAttrListFloat32ToListUint64Ok) { auto data = TilingData::CreateCap(30); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 2, AttrDataType::kListFloat32, AttrDataType::kListUint64), ge::GRAPH_SUCCESS); @@ -998,27 +825,11 @@ TEST_F(TilingDataUT, AppendAttrListFloat32ToListUint64Ok) { } } -TEST_F(TilingDataUT, AppendAttrListListFloat32ToListListUint64Ok) { - auto data = TilingData::CreateCap(50); - auto tiling_data = reinterpret_cast(data.get()); - auto holder = BuildTestContext(); - auto context = holder.GetContext(); - EXPECT_NE(context, nullptr); - EXPECT_EQ(tiling_data->AppendConvertedAttrVal( - context->GetAttrs(), 3, AttrDataType::kListListFloat32, AttrDataType::kListListUint64), ge::GRAPH_SUCCESS); - EXPECT_EQ(tiling_data->GetDataSize(), sizeof(uint64_t) * 6); - auto ele = reinterpret_cast(tiling_data->GetData()); - std::vector expect_data{1, 2, 3, 4, 5, 6}; - for (size_t i = 0UL; i < 6UL; ++i) { - EXPECT_EQ(ele[i], expect_data[i]); - } -} - TEST_F(TilingDataUT, AppendAttrInt32ToBoolOk) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 4, AttrDataType::kInt32, AttrDataType::kBool), ge::GRAPH_SUCCESS); @@ -1030,7 +841,7 @@ TEST_F(TilingDataUT, AppendAttrListInt32ToListBoolOk) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 5, AttrDataType::kListInt32, AttrDataType::kListBool), ge::GRAPH_SUCCESS); @@ -1046,7 +857,7 @@ TEST_F(TilingDataUT, AppendAttrListListInt32ToListListBoolOk) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 6, AttrDataType::kListListInt32, AttrDataType::kListListBool), ge::GRAPH_SUCCESS); @@ -1062,7 +873,7 @@ TEST_F(TilingDataUT, AppendAttrInt32ToFloat16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 4, AttrDataType::kInt32, AttrDataType::kFloat16), ge::GRAPH_SUCCESS); @@ -1074,7 +885,7 @@ TEST_F(TilingDataUT, AppendAttrListInt32ToListFloat16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 5, AttrDataType::kListInt32, AttrDataType::kListFloat16), ge::GRAPH_SUCCESS); @@ -1090,7 +901,7 @@ TEST_F(TilingDataUT, AppendAttrListListInt32ToListListFloat16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 6, AttrDataType::kListListInt32, AttrDataType::kListListFloat16), ge::GRAPH_SUCCESS); @@ -1106,7 +917,7 @@ TEST_F(TilingDataUT, AppendAttrInt32ToBfloat16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 4, AttrDataType::kInt32, AttrDataType::kBfloat16), ge::GRAPH_SUCCESS); @@ -1118,7 +929,7 @@ TEST_F(TilingDataUT, AppendAttrListInt32ToListBfloat16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 5, AttrDataType::kListInt32, AttrDataType::kListBfloat16), ge::GRAPH_SUCCESS); @@ -1134,7 +945,7 @@ TEST_F(TilingDataUT, AppendAttrListListInt32ToListListBfloat16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 6, AttrDataType::kListListInt32, AttrDataType::kListListBfloat16), @@ -1151,7 +962,7 @@ TEST_F(TilingDataUT, AppendAttrInt32ToFloat32Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 4, AttrDataType::kInt32, AttrDataType::kFloat32), ge::GRAPH_SUCCESS); @@ -1163,7 +974,7 @@ TEST_F(TilingDataUT, AppendAttrListInt32ToListFloat32Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 5, AttrDataType::kListInt32, AttrDataType::kListFloat32), ge::GRAPH_SUCCESS); @@ -1179,7 +990,7 @@ TEST_F(TilingDataUT, AppendAttrListListInt32ToListListFloat32Ok) { auto data = TilingData::CreateCap(30); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 6, AttrDataType::kListListInt32, AttrDataType::kListListFloat32), @@ -1196,7 +1007,7 @@ TEST_F(TilingDataUT, AppendAttrInt32ToInt8Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 4, AttrDataType::kInt32, AttrDataType::kInt8), ge::GRAPH_SUCCESS); @@ -1208,7 +1019,7 @@ TEST_F(TilingDataUT, AppendAttrListInt32ToListInt8Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 5, AttrDataType::kListInt32, AttrDataType::kListInt8), ge::GRAPH_SUCCESS); @@ -1224,7 +1035,7 @@ TEST_F(TilingDataUT, AppendAttrListListInt32ToListListInt8Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 6, AttrDataType::kListListInt32, AttrDataType::kListListInt8), ge::GRAPH_SUCCESS); @@ -1240,7 +1051,7 @@ TEST_F(TilingDataUT, AppendAttrInt32ToInt16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 4, AttrDataType::kInt32, AttrDataType::kInt16), ge::GRAPH_SUCCESS); @@ -1252,7 +1063,7 @@ TEST_F(TilingDataUT, AppendAttrListInt32ToListInt16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 5, AttrDataType::kListInt32, AttrDataType::kListInt16), ge::GRAPH_SUCCESS); @@ -1268,7 +1079,7 @@ TEST_F(TilingDataUT, AppendAttrListListInt32ToListListInt16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 6, AttrDataType::kListListInt32, AttrDataType::kListListInt16), ge::GRAPH_SUCCESS); @@ -1284,7 +1095,7 @@ TEST_F(TilingDataUT, AppendAttrInt32Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 4, AttrDataType::kInt32, AttrDataType::kInt32), ge::GRAPH_SUCCESS); @@ -1296,7 +1107,7 @@ TEST_F(TilingDataUT, AppendAttrListInt32Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 5, AttrDataType::kListInt32, AttrDataType::kListInt32), ge::GRAPH_SUCCESS); @@ -1312,7 +1123,7 @@ TEST_F(TilingDataUT, AppendAttrListListInt32Ok) { auto data = TilingData::CreateCap(30); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 6, AttrDataType::kListListInt32, AttrDataType::kListListInt32), ge::GRAPH_SUCCESS); @@ -1328,7 +1139,7 @@ TEST_F(TilingDataUT, AppendAttrInt32ToInt64Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 4, AttrDataType::kInt32, AttrDataType::kInt64), ge::GRAPH_SUCCESS); @@ -1340,7 +1151,7 @@ TEST_F(TilingDataUT, AppendAttrListInt32ToListInt64Ok) { auto data = TilingData::CreateCap(30); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 5, AttrDataType::kListInt32, AttrDataType::kListInt64), ge::GRAPH_SUCCESS); @@ -1356,7 +1167,7 @@ TEST_F(TilingDataUT, AppendAttrListListInt32ToListListInt64Ok) { auto data = TilingData::CreateCap(50); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 6, AttrDataType::kListListInt32, AttrDataType::kListListInt64), ge::GRAPH_SUCCESS); @@ -1372,7 +1183,7 @@ TEST_F(TilingDataUT, AppendAttrInt32ToUint8Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 4, AttrDataType::kInt32, AttrDataType::kUint8), ge::GRAPH_SUCCESS); @@ -1384,7 +1195,7 @@ TEST_F(TilingDataUT, AppendAttrListInt32ToListUint8Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 5, AttrDataType::kListInt32, AttrDataType::kListUint8), ge::GRAPH_SUCCESS); @@ -1400,7 +1211,7 @@ TEST_F(TilingDataUT, AppendAttrListListInt32ToListListUint8Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 6, AttrDataType::kListListInt32, AttrDataType::kListListUint8), ge::GRAPH_SUCCESS); @@ -1416,7 +1227,7 @@ TEST_F(TilingDataUT, AppendAttrInt32ToUint16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 4, AttrDataType::kInt32, AttrDataType::kUint16), ge::GRAPH_SUCCESS); @@ -1428,7 +1239,7 @@ TEST_F(TilingDataUT, AppendAttrListInt32ToListUint16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 5, AttrDataType::kListInt32, AttrDataType::kListUint16), ge::GRAPH_SUCCESS); @@ -1444,7 +1255,7 @@ TEST_F(TilingDataUT, AppendAttrListListInt32ToListListUint16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 6, AttrDataType::kListListInt32, AttrDataType::kListListUint16), ge::GRAPH_SUCCESS); @@ -1460,7 +1271,7 @@ TEST_F(TilingDataUT, AppendAttrInt32ToUint32Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 4, AttrDataType::kInt32, AttrDataType::kUint32), ge::GRAPH_SUCCESS); @@ -1472,7 +1283,7 @@ TEST_F(TilingDataUT, AppendAttrListInt32ToListUint32Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 5, AttrDataType::kListInt32, AttrDataType::kListUint32), ge::GRAPH_SUCCESS); @@ -1488,7 +1299,7 @@ TEST_F(TilingDataUT, AppendAttrListListInt32ToListListUint32Ok) { auto data = TilingData::CreateCap(30); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 6, AttrDataType::kListListInt32, AttrDataType::kListListUint32), ge::GRAPH_SUCCESS); @@ -1504,7 +1315,7 @@ TEST_F(TilingDataUT, AppendAttrInt32ToUint64Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 4, AttrDataType::kInt32, AttrDataType::kUint64), ge::GRAPH_SUCCESS); @@ -1516,7 +1327,7 @@ TEST_F(TilingDataUT, AppendAttrListInt32ToListUint64Ok) { auto data = TilingData::CreateCap(30); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 5, AttrDataType::kListInt32, AttrDataType::kListUint64), ge::GRAPH_SUCCESS); @@ -1532,7 +1343,7 @@ TEST_F(TilingDataUT, AppendAttrListListInt32ToListListUint64Ok) { auto data = TilingData::CreateCap(50); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 6, AttrDataType::kListListInt32, AttrDataType::kListListUint64), ge::GRAPH_SUCCESS); @@ -1548,7 +1359,7 @@ TEST_F(TilingDataUT, AppendAttrInt64ToBoolOk) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 4, AttrDataType::kInt64, AttrDataType::kBool), ge::GRAPH_SUCCESS); @@ -1560,7 +1371,7 @@ TEST_F(TilingDataUT, AppendAttrListInt64ToListBoolOk) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 5, AttrDataType::kListInt64, AttrDataType::kListBool), ge::GRAPH_SUCCESS); @@ -1575,7 +1386,7 @@ TEST_F(TilingDataUT, AppendAttrListListInt64ToListListBoolOk) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 6, AttrDataType::kListListInt64, AttrDataType::kListListBool), ge::GRAPH_SUCCESS); @@ -1590,7 +1401,7 @@ TEST_F(TilingDataUT, AppendAttrInt64ToFloat16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 4, AttrDataType::kInt64, AttrDataType::kFloat16), ge::GRAPH_SUCCESS); @@ -1602,7 +1413,7 @@ TEST_F(TilingDataUT, AppendAttrListInt64ToListFloat16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 5, AttrDataType::kListInt64, AttrDataType::kListFloat16), ge::GRAPH_SUCCESS); @@ -1618,7 +1429,7 @@ TEST_F(TilingDataUT, AppendAttrListListInt64ToListListFloat16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 6, AttrDataType::kListListInt64, AttrDataType::kListListFloat16), ge::GRAPH_SUCCESS); @@ -1634,7 +1445,7 @@ TEST_F(TilingDataUT, AppendAttrInt64ToBfloat16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 4, AttrDataType::kInt64, AttrDataType::kBfloat16), ge::GRAPH_SUCCESS); @@ -1646,7 +1457,7 @@ TEST_F(TilingDataUT, AppendAttrListInt64ToListBfloat16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 5, AttrDataType::kListInt64, AttrDataType::kListBfloat16), ge::GRAPH_SUCCESS); @@ -1662,7 +1473,7 @@ TEST_F(TilingDataUT, AppendAttrListListInt64ToListListBfloat16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 6, AttrDataType::kListListInt64, AttrDataType::kListListBfloat16), @@ -1679,7 +1490,7 @@ TEST_F(TilingDataUT, AppendAttrInt64ToFloat32Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 4, AttrDataType::kInt64, AttrDataType::kFloat32), ge::GRAPH_SUCCESS); @@ -1691,7 +1502,7 @@ TEST_F(TilingDataUT, AppendAttrListInt64ToListFloat32Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 5, AttrDataType::kListInt64, AttrDataType::kListFloat32), ge::GRAPH_SUCCESS); @@ -1707,7 +1518,7 @@ TEST_F(TilingDataUT, AppendAttrListListInt64ToListListFloat32Ok) { auto data = TilingData::CreateCap(30); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 6, AttrDataType::kListListInt64, AttrDataType::kListListFloat32), @@ -1724,7 +1535,7 @@ TEST_F(TilingDataUT, AppendAttrInt64ToInt8Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 4, AttrDataType::kInt64, AttrDataType::kInt8), ge::GRAPH_SUCCESS); @@ -1736,7 +1547,7 @@ TEST_F(TilingDataUT, AppendAttrListInt64ToListInt8Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 5, AttrDataType::kListInt64, AttrDataType::kListInt8), ge::GRAPH_SUCCESS); @@ -1752,7 +1563,7 @@ TEST_F(TilingDataUT, AppendAttrListListInt64ToListListInt8Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 6, AttrDataType::kListListInt64, AttrDataType::kListListInt8), ge::GRAPH_SUCCESS); @@ -1768,7 +1579,7 @@ TEST_F(TilingDataUT, AppendAttrInt64ToInt16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 4, AttrDataType::kInt64, AttrDataType::kInt16), ge::GRAPH_SUCCESS); @@ -1780,7 +1591,7 @@ TEST_F(TilingDataUT, AppendAttrListInt64ToListInt16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 5, AttrDataType::kListInt64, AttrDataType::kListInt16), ge::GRAPH_SUCCESS); @@ -1796,7 +1607,7 @@ TEST_F(TilingDataUT, AppendAttrListListInt64ToListListInt16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 6, AttrDataType::kListListInt64, AttrDataType::kListListInt16), ge::GRAPH_SUCCESS); @@ -1812,7 +1623,7 @@ TEST_F(TilingDataUT, AppendAttrInt64ToInt32Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 4, AttrDataType::kInt64, AttrDataType::kInt32), ge::GRAPH_SUCCESS); @@ -1824,7 +1635,7 @@ TEST_F(TilingDataUT, AppendAttrListInt64ToListInt32Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 5, AttrDataType::kListInt64, AttrDataType::kListInt32), ge::GRAPH_SUCCESS); @@ -1840,7 +1651,7 @@ TEST_F(TilingDataUT, AppendAttrListListInt64ListListInt32Ok) { auto data = TilingData::CreateCap(30); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 6, AttrDataType::kListListInt64, AttrDataType::kListListInt32), ge::GRAPH_SUCCESS); @@ -1856,7 +1667,7 @@ TEST_F(TilingDataUT, AppendAttrInt64Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 4, AttrDataType::kInt64, AttrDataType::kInt64), ge::GRAPH_SUCCESS); @@ -1868,7 +1679,7 @@ TEST_F(TilingDataUT, AppendAttrListInt64Ok) { auto data = TilingData::CreateCap(30); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 5, AttrDataType::kListInt64, AttrDataType::kListInt64), ge::GRAPH_SUCCESS); @@ -1884,7 +1695,7 @@ TEST_F(TilingDataUT, AppendAttrListListInt64Ok) { auto data = TilingData::CreateCap(50); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 6, AttrDataType::kListListInt64, AttrDataType::kListListInt64), ge::GRAPH_SUCCESS); @@ -1900,7 +1711,7 @@ TEST_F(TilingDataUT, AppendAttrInt64ToUint8Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 4, AttrDataType::kInt64, AttrDataType::kUint8), ge::GRAPH_SUCCESS); @@ -1912,7 +1723,7 @@ TEST_F(TilingDataUT, AppendAttrListInt64ToListUint8Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 5, AttrDataType::kListInt64, AttrDataType::kListUint8), ge::GRAPH_SUCCESS); @@ -1928,7 +1739,7 @@ TEST_F(TilingDataUT, AppendAttrListListInt64ToListListUint8Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 6, AttrDataType::kListListInt64, AttrDataType::kListListUint8), ge::GRAPH_SUCCESS); @@ -1944,7 +1755,7 @@ TEST_F(TilingDataUT, AppendAttrInt64ToUint16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 4, AttrDataType::kInt64, AttrDataType::kUint16), ge::GRAPH_SUCCESS); @@ -1956,7 +1767,7 @@ TEST_F(TilingDataUT, AppendAttrListInt64ToListUint16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 5, AttrDataType::kListInt64, AttrDataType::kListUint16), ge::GRAPH_SUCCESS); @@ -1972,7 +1783,7 @@ TEST_F(TilingDataUT, AppendAttrListListInt64ToListListUint16Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 6, AttrDataType::kListListInt64, AttrDataType::kListListUint16), ge::GRAPH_SUCCESS); @@ -1988,7 +1799,7 @@ TEST_F(TilingDataUT, AppendAttrInt64ToUint32Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 4, AttrDataType::kInt64, AttrDataType::kUint32), ge::GRAPH_SUCCESS); @@ -2000,7 +1811,7 @@ TEST_F(TilingDataUT, AppendAttrListInt64ToListUint32Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 5, AttrDataType::kListInt64, AttrDataType::kListUint32), ge::GRAPH_SUCCESS); @@ -2016,7 +1827,7 @@ TEST_F(TilingDataUT, AppendAttrListListInt64ToListListUint32Ok) { auto data = TilingData::CreateCap(30); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 6, AttrDataType::kListListInt64, AttrDataType::kListListUint32), ge::GRAPH_SUCCESS); @@ -2032,7 +1843,7 @@ TEST_F(TilingDataUT, AppendAttrInt64ToUint64Ok) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 4, AttrDataType::kInt64, AttrDataType::kUint64), ge::GRAPH_SUCCESS); @@ -2044,7 +1855,7 @@ TEST_F(TilingDataUT, AppendAttrListInt64ToListUint64Ok) { auto data = TilingData::CreateCap(30); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 5, AttrDataType::kListInt64, AttrDataType::kListUint64), ge::GRAPH_SUCCESS); @@ -2060,7 +1871,7 @@ TEST_F(TilingDataUT, AppendAttrListListInt64ToListListUint64Ok) { auto data = TilingData::CreateCap(50); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 6, AttrDataType::kListListInt64, AttrDataType::kListListUint64), ge::GRAPH_SUCCESS); @@ -2076,7 +1887,7 @@ TEST_F(TilingDataUT, AppendAttrIndexInvalid) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 13, AttrDataType::kInt64, AttrDataType::kInt32), ge::GRAPH_FAILED); @@ -2086,7 +1897,7 @@ TEST_F(TilingDataUT, AppendAttrSrcTypeInvalid) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal(context->GetAttrs(), 7, AttrDataType::kString, AttrDataType::kInt32), ge::GRAPH_FAILED); @@ -2098,7 +1909,7 @@ TEST_F(TilingDataUT, AppendAttrDstTypeInvalid) { auto data = TilingData::CreateCap(20); auto tiling_data = reinterpret_cast(data.get()); auto holder = BuildTestContext(); - auto context = holder.GetContext(); + auto context = holder.GetContext(); EXPECT_NE(context, nullptr); EXPECT_EQ(tiling_data->AppendConvertedAttrVal( context->GetAttrs(), 7, AttrDataType::kBool, AttrDataType::kListListInt32), ge::GRAPH_FAILED); diff --git a/tests/ut/exe_graph/CMakeLists.txt b/tests/ut/exe_graph/CMakeLists.txt deleted file mode 100644 index 86c06ee67d16352f5315626a04f6d2041a41e250..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/CMakeLists.txt +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright (c) 2024 Huawei Technologies Co., Ltd. -# This file is a part of the CANN Open Software. -# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. -# ====================================================================================================================== - -file(GLOB_RECURSE TEST_SRCS CONFIGURE_DEPENDS "${METADEF_DIR}/tests/ut/exe_graph/*.cc") -file(GLOB_RECURSE FAKER_SRCS CONFIGURE_DEPENDS "${METADEF_DIR}/tests/depends/faker/*.cc") -file(GLOB_RECURSE GE_RUNTIME_STUB_SRCS CONFIGURE_DEPENDS "${METADEF_DIR}/tests/depends/ge_runtime_stub/*.cc") - -add_executable(ut_exe_graph ${TEST_SRCS} ${FAKER_SRCS} ${GE_RUNTIME_STUB_SRCS}) - -target_include_directories(ut_exe_graph PRIVATE - ${METADEF_DIR} - ${METADEF_DIR}/tests/depends - ${METADEF_DIR}/exe_graph - ${CMAKE_BINARY_DIR} - ${CMAKE_BINARY_DIR}/proto/metadef_protos - ) - -target_compile_options(ut_exe_graph PRIVATE - -g --coverage -fprofile-arcs -ftest-coverage - -Wno-deprecated-declarations - -Wall -Wfloat-equal -Werror - -fno-access-control - ) - -target_compile_definitions(ut_exe_graph PRIVATE - $<$:ONLY_COMPILE_OPEN_SRC> - google=ascend_private - ) - -target_link_libraries(ut_exe_graph PRIVATE - intf_pub - -Wl,--no-as-needed - slog_headers - platform_stub - msprof_headers - runtime_headers - exe_graph lowering register opp_registry rt2_registry_static error_manager - mmpa - GTest::gtest GTest::gtest_main slog_stub ascend_protobuf c_sec mmpa_stub -lrt -ldl -lgcov - metadef_headers - graph - graph_base - aihac_symbolizer - -Wl,--as-needed - ) diff --git a/tests/ut/exe_graph/abi_compatibility_for_exe_graph_unittest.cc b/tests/ut/exe_graph/abi_compatibility_for_exe_graph_unittest.cc deleted file mode 100644 index 74434fe1fcb46cdd5156868ba81a0972ec182455..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/abi_compatibility_for_exe_graph_unittest.cc +++ /dev/null @@ -1,334 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "exe_graph/runtime/tensor.h" -#include "exe_graph/runtime/compute_node_info.h" -#include "base/runtime/runtime_attrs_def.h" -#include "exe_graph/runtime/infer_shape_range_context.h" -#include "exe_graph/runtime/infer_shape_context.h" -#include "exe_graph/runtime/infer_datatype_context.h" -#include "exe_graph/runtime/atomic_clean_tiling_context.h" -#include "exe_graph/runtime/tiling_parse_context.h" -#include "exe_graph/runtime/tiling_data.h" - -namespace gert { -namespace { -constexpr const size_t kPointerSize = 8U; -constexpr const size_t kReservedFieldSize = 40U; -constexpr const size_t kComputeNodeInfoReservedFieldSize = 24U; -constexpr const size_t kExtendInfoReservedFieldSize = 56U; - -constexpr const size_t kShapeSize = 248U; -constexpr const size_t kExpandDimsTypeSize = 48U; -constexpr const size_t kStorageShapeSize = 536U; -constexpr const size_t kStorageFormatSize = 96U; -constexpr const size_t kTensorDataSize = 72U; -constexpr const size_t kTensorSize = 752U; -constexpr const size_t kCompileTimeTensorDescSize = 144U; -constexpr const size_t kAnchorInstanceInfoSize = 48U; -constexpr const size_t kComputeNodeInfoSize = 88U; -constexpr const size_t kRuntimeAttrsSize = 8U; -constexpr const size_t kRuntimeAttrsDefSize = 48U; -constexpr const size_t kContinuousVectorSize = 64U; -constexpr const size_t kContinuousVectorVectorSize = 64U; -constexpr const size_t kChainSize = 16U; -constexpr const size_t kRangeSize = 56U; -constexpr const size_t kKernelExtendInfoSize = 72U; -constexpr const size_t kKernelRunContextSize = 48U; -constexpr const size_t kTilingDataSize = 64U; -} // namespace -constexpr size_t Shape::kMaxDimNum; -class AbiCompatibilityForExeGraphUT : public testing::Test {}; - -TEST_F(AbiCompatibilityForExeGraphUT, Shape_CheckMemLayoutNotChanged) { - Shape s; - ASSERT_EQ(sizeof(s), kShapeSize); - ASSERT_EQ(static_cast(&s), static_cast(&s.dim_num_)); - - EXPECT_EQ(reinterpret_cast(&s.dims_) - reinterpret_cast(&s.dim_num_), sizeof(size_t)); - EXPECT_EQ(reinterpret_cast(&s.reserved_) - reinterpret_cast(&s.dims_), 25 * sizeof(int64_t)); - EXPECT_EQ(sizeof(s.reserved_), kReservedFieldSize); - - EXPECT_EQ(Shape::kMaxDimNum, 25); -} - -TEST_F(AbiCompatibilityForExeGraphUT, StorageShape_CheckMemLayoutNotChanged) { - StorageShape s; - ASSERT_EQ(sizeof(s), kStorageShapeSize); - ASSERT_EQ(static_cast(&s), static_cast(&s.origin_shape_)); - - EXPECT_EQ(reinterpret_cast(&s.storage_shape_) - reinterpret_cast(&s.origin_shape_), kShapeSize); - EXPECT_EQ(reinterpret_cast(&s.reserved_) - reinterpret_cast(&s.storage_shape_), kShapeSize); - EXPECT_EQ(sizeof(s.reserved_), kReservedFieldSize); -} - -TEST_F(AbiCompatibilityForExeGraphUT, ExpandDimsType_CheckMemLayoutNotChanged) { - ExpandDimsType e; - ASSERT_EQ(sizeof(e), kExpandDimsTypeSize); - - EXPECT_EQ(reinterpret_cast(&e.reserved_) - reinterpret_cast(&e), 8); - EXPECT_EQ(sizeof(e.reserved_), kReservedFieldSize); -} - -TEST_F(AbiCompatibilityForExeGraphUT, StorageFormat_CheckMemLayoutNotChanged) { - StorageFormat s; - ASSERT_EQ(sizeof(s), kStorageFormatSize); - ASSERT_EQ(static_cast(&s), static_cast(&s.origin_format_)); - - EXPECT_EQ(reinterpret_cast(&s.storage_format_) - reinterpret_cast(&s.origin_format_), - sizeof(ge::Format)); - EXPECT_EQ(reinterpret_cast(&s.expand_dims_type_) - reinterpret_cast(&s.storage_format_), - sizeof(ge::Format)); - EXPECT_EQ(reinterpret_cast(&s.reserved_) - reinterpret_cast(&s.expand_dims_type_), - kExpandDimsTypeSize); - EXPECT_EQ(sizeof(s.reserved_), kReservedFieldSize); -} - -TEST_F(AbiCompatibilityForExeGraphUT, TensorData_CheckMemLayoutNotChanged) { - TensorData t; - ASSERT_EQ(sizeof(t), kTensorDataSize); - ASSERT_EQ(static_cast(&t), static_cast(&t.addr_)); - - EXPECT_EQ(reinterpret_cast(&t.manager_) - reinterpret_cast(&t.addr_), kPointerSize); - EXPECT_EQ(reinterpret_cast(&t.size_) - reinterpret_cast(&t.manager_), kPointerSize); - EXPECT_EQ(reinterpret_cast(&t.placement_) - reinterpret_cast(&t.size_), sizeof(size_t)); - EXPECT_EQ(reinterpret_cast(&t.reserved_0_) - reinterpret_cast(&t.placement_), - sizeof(gert::TensorPlacement)); - EXPECT_EQ(reinterpret_cast(&t.reserved_1_) - reinterpret_cast(&t.reserved_0_), - sizeof(uint32_t)); - EXPECT_EQ(sizeof(t.reserved_1_), kReservedFieldSize); -} - -TEST_F(AbiCompatibilityForExeGraphUT, Tensor_CheckMemLayoutNotChanged) { - Tensor t; - ASSERT_EQ(sizeof(t), kTensorSize); - ASSERT_EQ(static_cast(&t), static_cast(&t.storage_shape_)); - - EXPECT_EQ(reinterpret_cast(&t.storage_format_) - reinterpret_cast(&t.storage_shape_), - kStorageShapeSize); - EXPECT_EQ(reinterpret_cast(&t.reserved_) - reinterpret_cast(&t.storage_format_), - kStorageFormatSize); - EXPECT_EQ(reinterpret_cast(&t.data_type_) - reinterpret_cast(&t.reserved_), 4); - EXPECT_EQ(reinterpret_cast(&t.tensor_data_) - reinterpret_cast(&t.data_type_), - sizeof(ge::DataType)); - EXPECT_EQ(reinterpret_cast(&t.reserved_field_) - reinterpret_cast(&t.tensor_data_), - kTensorDataSize); - EXPECT_EQ(sizeof(t.reserved_field_), kReservedFieldSize); -} - -TEST_F(AbiCompatibilityForExeGraphUT, CompileTimeTensorDesc_CheckMemLayoutNotChanged) { - CompileTimeTensorDesc t; - ASSERT_EQ(sizeof(t), kCompileTimeTensorDescSize); - ASSERT_EQ(static_cast(&t), static_cast(&t.data_type_)); - - EXPECT_EQ(reinterpret_cast(&t.storage_format_) - reinterpret_cast(&t.data_type_), 8); - EXPECT_EQ(reinterpret_cast(&t.reserved_) - reinterpret_cast(&t.storage_format_), - kStorageFormatSize); - EXPECT_EQ(sizeof(t.reserved_), kReservedFieldSize); -} - -TEST_F(AbiCompatibilityForExeGraphUT, AnchorInstanceInfo_CheckMemLayoutNotChanged) { - AnchorInstanceInfo a; - ASSERT_EQ(sizeof(a), kAnchorInstanceInfoSize); - ASSERT_EQ(static_cast(&a), static_cast(&a.instance_start_)); - - EXPECT_EQ(reinterpret_cast(&a.instantiation_num_) - reinterpret_cast(&a.instance_start_), 4); - EXPECT_EQ(reinterpret_cast(&a.reserved_) - reinterpret_cast(&a.instantiation_num_), 4); - EXPECT_EQ(sizeof(a.reserved_), kReservedFieldSize); -} - -TEST_F(AbiCompatibilityForExeGraphUT, ComputeNodeInfo_CheckMemLayoutNotChanged) { - auto holder = malloc(kComputeNodeInfoSize); - auto c = reinterpret_cast(holder); - - ASSERT_EQ(static_cast(c), static_cast(&c->node_type_)); - ASSERT_EQ(reinterpret_cast(&c->place_holder) - reinterpret_cast(c), kComputeNodeInfoSize - 8); - - EXPECT_EQ(reinterpret_cast(&c->node_name_) - reinterpret_cast(&c->node_type_), sizeof(char *)); - EXPECT_EQ(reinterpret_cast(&c->ir_inputs_num_) - reinterpret_cast(&c->node_name_), - sizeof(char *)); - EXPECT_EQ(reinterpret_cast(&c->inputs_num_) - reinterpret_cast(&c->ir_inputs_num_), - sizeof(size_t)); - EXPECT_EQ(reinterpret_cast(&c->outputs_num_) - reinterpret_cast(&c->inputs_num_), - sizeof(size_t)); - EXPECT_EQ(reinterpret_cast(&c->reserved_) - reinterpret_cast(&c->outputs_num_), sizeof(size_t) * 3); - EXPECT_EQ(reinterpret_cast(&c->place_holder) - reinterpret_cast(&c->reserved_), - kComputeNodeInfoReservedFieldSize); - - free(holder); -} - -TEST_F(AbiCompatibilityForExeGraphUT, RuntimeAttrs_CheckMemLayoutNotChanged) { - auto holder = malloc(kRuntimeAttrsSize); - auto c = reinterpret_cast(holder); - - ASSERT_EQ(static_cast(c), static_cast(&c->placeholder_)); - EXPECT_EQ(sizeof(c->placeholder_), sizeof(uint64_t)); - - free(holder); -} - -TEST_F(AbiCompatibilityForExeGraphUT, RuntimeAttrsDef_CheckMemLayoutNotChanged) { - RuntimeAttrsDef r; - ASSERT_EQ(sizeof(r), kRuntimeAttrsDefSize); - ASSERT_EQ(static_cast(&r), static_cast(&r.attr_num)); - - EXPECT_EQ(reinterpret_cast(&r.reserved_) - reinterpret_cast(&r.attr_num), sizeof(size_t)); - EXPECT_EQ(reinterpret_cast(&r.offset) - reinterpret_cast(&r.reserved_), kReservedFieldSize); - EXPECT_EQ(sizeof(r.offset), 0); -} - -TEST_F(AbiCompatibilityForExeGraphUT, ContinuousVector_CheckMemLayoutNotChanged) { - ContinuousVector c; - ASSERT_EQ(sizeof(c), kContinuousVectorSize); - ASSERT_EQ(static_cast(&c), static_cast(&c.capacity_)); - - EXPECT_EQ(reinterpret_cast(&c.size_) - reinterpret_cast(&c.capacity_), sizeof(size_t)); - EXPECT_EQ(reinterpret_cast(&c.reserved_) - reinterpret_cast(&c.size_), sizeof(size_t)); - EXPECT_EQ(reinterpret_cast(&c.elements) - reinterpret_cast(&c.reserved_), kReservedFieldSize); - EXPECT_EQ(sizeof(c.elements), 8); -} - -TEST_F(AbiCompatibilityForExeGraphUT, TypedContinuousVector_CheckMemLayoutNotChanged) { - TypedContinuousVector c; - ASSERT_EQ(sizeof(c), kContinuousVectorSize); - ASSERT_EQ(static_cast(&c), static_cast(&c.capacity_)); -} - -TEST_F(AbiCompatibilityForExeGraphUT, ContinuousVectorVector_CheckMemLayoutNotChanged) { - ContinuousVectorVector c; - ASSERT_EQ(sizeof(c), kContinuousVectorVectorSize); - ASSERT_EQ(static_cast(&c), static_cast(&c.capacity_)); - - EXPECT_EQ(reinterpret_cast(&c.size_) - reinterpret_cast(&c.capacity_), sizeof(size_t)); - EXPECT_EQ(reinterpret_cast(&c.reserved_) - reinterpret_cast(&c.size_), sizeof(size_t)); - EXPECT_EQ(reinterpret_cast(&c.offset_) - reinterpret_cast(&c.reserved_), kReservedFieldSize); - EXPECT_EQ(sizeof(c.offset_), 8); -} - -TEST_F(AbiCompatibilityForExeGraphUT, Chain_CheckMemLayoutNotChanged) { - Chain c; - ASSERT_EQ(sizeof(c), kChainSize); - ASSERT_EQ(static_cast(&c), static_cast(&c.any_value_)); - - EXPECT_EQ(sizeof(c.any_value_), kChainSize); -} - -TEST_F(AbiCompatibilityForExeGraphUT, Range_CheckMemLayoutNotChanged) { - Range r; - ASSERT_EQ(sizeof(r), kRangeSize); - ASSERT_EQ(static_cast(&r), static_cast(&r.min_)); - - EXPECT_EQ(reinterpret_cast(&r.max_) - reinterpret_cast(&r.min_), kPointerSize); - EXPECT_EQ(reinterpret_cast(&r.reserved_) - reinterpret_cast(&r.max_), kPointerSize); - EXPECT_EQ(sizeof(r.reserved_), kReservedFieldSize); -} - -TEST_F(AbiCompatibilityForExeGraphUT, KernelExtendInfo_CheckMemLayoutNotChanged) { - auto holder = malloc(kKernelExtendInfoSize); - auto k = reinterpret_cast(holder); - - ASSERT_EQ(reinterpret_cast(&k->reserved_) - reinterpret_cast(k), - kKernelExtendInfoSize - kExtendInfoReservedFieldSize); - ASSERT_EQ(static_cast(k), static_cast(&k->kernel_name_)); - - EXPECT_EQ(reinterpret_cast(&k->kernel_type_) - reinterpret_cast(&k->kernel_name_), - kPointerSize); - free(holder); -} - -TEST_F(AbiCompatibilityForExeGraphUT, KernelRunContext_CheckMemLayoutNotChanged) { - KernelRunContext k; - ASSERT_EQ(sizeof(KernelRunContext), kKernelRunContextSize); - ASSERT_EQ(static_cast(&k), static_cast(&k.input_size)); - - EXPECT_EQ(reinterpret_cast(&k.output_size) - reinterpret_cast(&k.input_size), sizeof(size_t)); - EXPECT_EQ(reinterpret_cast(&k.compute_node_info) - reinterpret_cast(&k.output_size), - sizeof(size_t)); - EXPECT_EQ(reinterpret_cast(&k.kernel_extend_info) - reinterpret_cast(&k.compute_node_info), - kPointerSize); - EXPECT_EQ(reinterpret_cast(&k.output_start) - reinterpret_cast(&k.kernel_extend_info), - kPointerSize); - EXPECT_EQ(reinterpret_cast(&k.values) - reinterpret_cast(&k.output_start), kPointerSize); - EXPECT_EQ(sizeof(k.values), kPointerSize); -} - - -TEST_F(AbiCompatibilityForExeGraphUT, KernelContext_CheckMemLayoutNotChanged) { - KernelContext k; - ASSERT_EQ(sizeof(KernelContext), kKernelRunContextSize); - ASSERT_EQ(static_cast(&k), static_cast(&k.context_)); - EXPECT_EQ(sizeof(k.context_), kKernelRunContextSize); -} - -TEST_F(AbiCompatibilityForExeGraphUT, ExtendedKernelContext_CheckMemLayoutNotChanged) { - ExtendedKernelContext k; - ASSERT_EQ(sizeof(ExtendedKernelContext), kKernelRunContextSize); - ASSERT_EQ(static_cast(&k), static_cast(&k.context_)); - EXPECT_EQ(sizeof(k.context_), kKernelRunContextSize); -} - -TEST_F(AbiCompatibilityForExeGraphUT, InferShapeContext_CheckMemLayoutNotChanged) { - InferShapeContext c; - ASSERT_EQ(sizeof(InferShapeContext), kKernelRunContextSize); - ASSERT_EQ(static_cast(&c), static_cast(&c.context_)); - EXPECT_EQ(sizeof(c.context_), kKernelRunContextSize); -} - -TEST_F(AbiCompatibilityForExeGraphUT, InferShapeRangeContext_CheckMemLayoutNotChanged) { - InferShapeRangeContext c; - ASSERT_EQ(sizeof(InferShapeRangeContext), kKernelRunContextSize); - ASSERT_EQ(static_cast(&c), static_cast(&c.context_)); - EXPECT_EQ(sizeof(c.context_), kKernelRunContextSize); -} - -TEST_F(AbiCompatibilityForExeGraphUT, InferDataTypeContext_CheckMemLayoutNotChanged) { - InferDataTypeContext c; - ASSERT_EQ(sizeof(InferDataTypeContext), kKernelRunContextSize); - ASSERT_EQ(static_cast(&c), static_cast(&c.context_)); - EXPECT_EQ(sizeof(c.context_), kKernelRunContextSize); -} - -TEST_F(AbiCompatibilityForExeGraphUT, TilingContext_CheckMemLayoutNotChanged) { - TilingContext c; - ASSERT_EQ(sizeof(TilingContext), kKernelRunContextSize); - ASSERT_EQ(static_cast(&c), static_cast(&c.context_)); - EXPECT_EQ(sizeof(c.context_), kKernelRunContextSize); -} - -TEST_F(AbiCompatibilityForExeGraphUT, TilingParseContext_CheckMemLayoutNotChanged) { - TilingParseContext c; - ASSERT_EQ(sizeof(TilingParseContext), kKernelRunContextSize); - ASSERT_EQ(static_cast(&c), static_cast(&c.context_)); - EXPECT_EQ(sizeof(c.context_), kKernelRunContextSize); -} - -TEST_F(AbiCompatibilityForExeGraphUT, AtomicCleanTilingContext_CheckMemLayoutNotChanged) { - AtomicCleanTilingContext c; - ASSERT_EQ(sizeof(AtomicCleanTilingContext), kKernelRunContextSize); - ASSERT_EQ(static_cast(&c), static_cast(&c.context_)); - EXPECT_EQ(sizeof(c.context_), kKernelRunContextSize); -} - -TEST_F(AbiCompatibilityForExeGraphUT, TilingData_CheckMemLayoutNotChanged) { - auto holder = malloc(kTilingDataSize); - auto t = reinterpret_cast(holder); - ASSERT_EQ(reinterpret_cast(&t->reserved_) - reinterpret_cast(t), - kTilingDataSize - kReservedFieldSize); - ASSERT_EQ(static_cast(t), static_cast(&t->capacity_)); - - EXPECT_EQ(reinterpret_cast(&t->data_size_) - reinterpret_cast(&t->capacity_), sizeof(size_t)); - EXPECT_EQ(reinterpret_cast(&t->data_) - reinterpret_cast(&t->data_size_), sizeof(size_t)); - EXPECT_EQ(reinterpret_cast(&t->reserved_) - reinterpret_cast(&t->data_), kPointerSize); - EXPECT_EQ(sizeof(t->reserved_), kReservedFieldSize); - - free(holder); -} -} // namespace gert diff --git a/tests/ut/exe_graph/bg_ir_attrs_unittest.cc b/tests/ut/exe_graph/bg_ir_attrs_unittest.cc deleted file mode 100644 index 5b072e3b8174e041e490f9f98552b8f988d7ac90..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/bg_ir_attrs_unittest.cc +++ /dev/null @@ -1,236 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/lowering/bg_kernel_context_extend.h" -#include -#include -#include -#include "graph/compute_graph.h" -#include "graph/utils/node_utils.h" -#include "exe_graph/runtime/context_extend.h" -#include "exe_graph/runtime/continuous_vector.h" -#include "runtime/runtime_attrs_def.h" -#include "exe_graph/runtime/tensor.h" -#include "exe_graph/lowering/bg_ir_attrs.h" -#include "graph/debug/ge_attr_define.h" -#include "expand_dimension.h" - -namespace gert { -class BgIrAttrsUT : public testing::Test {}; -// 构造tensorAttr,其shape小于size,测试AppendTensorAtrr函数能够正常内存拷贝 -TEST_F(BgIrAttrsUT, ShapeSmallerThanSizeOfTensorAttr) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc ge_td; - ge_td.SetOriginFormat(ge::FORMAT_NHWC); - ge_td.SetFormat(ge::FORMAT_NHWC); - ge_td.SetDataType(ge::DT_FLOAT16); - ge_td.SetOriginShape(ge::GeShape({10, 10})); - ge_td.SetShape(ge::GeShape({10, 10})); - ge::GeTensor ge_tensor(ge_td); - std::vector fake_data(12 * 12); - for (size_t i = 0; i < fake_data.size(); ++i) { - fake_data[i] = static_cast(i % std::numeric_limits::max()); - } - ge_tensor.SetData(reinterpret_cast(fake_data.data()), fake_data.size() * 2); - ge::AttrUtils::SetTensor(op_desc, "a1", ge_tensor); - op_desc->AppendIrAttrName("a1"); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - - bg::BufferPool buffer_pool; - auto ret = bg::CreateComputeNodeInfo(node, buffer_pool); - ASSERT_NE(ret, nullptr); - auto compute_node_info = reinterpret_cast(ret.get()); - auto attrs = compute_node_info->GetAttrs(); - ASSERT_NE(attrs, nullptr); - EXPECT_EQ(attrs->GetAttrNum(), 1); - - auto gert_tensor = attrs->GetAttrPointer(0); - EXPECT_EQ(attrs->GetTensor(0), gert_tensor); - ASSERT_NE(gert_tensor, nullptr); - EXPECT_EQ(gert_tensor->GetOriginShape(), gert::Shape({10, 10})); - EXPECT_EQ(gert_tensor->GetStorageShape(), gert::Shape({10, 10})); - EXPECT_EQ(gert_tensor->GetOriginFormat(), ge::FORMAT_NHWC); - EXPECT_EQ(gert_tensor->GetStorageFormat(), ge::FORMAT_NHWC); - EXPECT_EQ(gert_tensor->GetDataType(), ge::DT_FLOAT16); - auto gert_tensor_ptr = gert_tensor->GetData(); - EXPECT_NE(gert_tensor_ptr, nullptr); - for (size_t i = 0; i < 10 * 10; ++i) { - EXPECT_EQ(gert_tensor_ptr[i], static_cast(i % std::numeric_limits::max())); - } -} -TEST_F(BgIrAttrsUT, CreateDataTypeAttrBuffer) { - auto op_desc = std::make_shared("foo", "Foo"); - op_desc->AppendIrAttrName("dtype"); - EXPECT_EQ(op_desc->GetIrAttrNames().size(), 1U); - EXPECT_EQ(op_desc->GetIrAttrNames().at(0), "dtype"); - ge::AttrUtils::SetDataType(op_desc, "dtype", ge::DT_INT32); - auto node = ge::NodeUtils::CreatNodeWithoutGraph(op_desc); - size_t attr_size; - auto attr_buffer = bg::CreateAttrBuffer(node, attr_size); - auto rt_attr_def = reinterpret_cast(attr_buffer.get()); - ASSERT_NE(rt_attr_def, nullptr); - EXPECT_EQ(rt_attr_def->attr_num, 1U); - for (size_t i = 0U; i < 40U; ++i) { - EXPECT_EQ(rt_attr_def->reserved_[i], 0); - } - EXPECT_EQ(rt_attr_def->offset[0], 2 * sizeof(size_t) + sizeof(rt_attr_def->reserved_)); - auto base = reinterpret_cast(attr_buffer.get()); - EXPECT_EQ(*reinterpret_cast(base + rt_attr_def->offset[0]), ge::DT_INT32); -} - -TEST_F(BgIrAttrsUT, CreateAttrBufferSuccessOpLossAttr) { - auto op_desc = std::make_shared("foo", "Foo"); - op_desc->AppendIrAttrName("dtype"); - EXPECT_EQ(op_desc->GetIrAttrNames().size(), 1U); - EXPECT_EQ(op_desc->GetIrAttrNames().at(0), "dtype"); - auto node = ge::NodeUtils::CreatNodeWithoutGraph(op_desc); - size_t attr_size; - auto attr_buffer = bg::CreateAttrBuffer(node, attr_size); - size_t gt_attr_size = sizeof(RuntimeAttrsDef); - EXPECT_EQ(attr_size, gt_attr_size); - auto base = reinterpret_cast(attr_buffer.get()); - EXPECT_EQ(base[0], 0U); - EXPECT_EQ(base[1], 0U); // todo 原始用例,没加预留字段之前,base[1]为啥能取到值 -} - -TEST_F(BgIrAttrsUT, CreateListListIntAttrBuffer_Int64Ok) { - auto op_desc = std::make_shared("foo", "Foo"); - op_desc->AppendIrAttrName("axes"); - EXPECT_EQ(op_desc->GetIrAttrNames().size(), 1U); - EXPECT_EQ(op_desc->GetIrAttrNames().at(0), "axes"); - std::vector> value{{}, {1}, {1, 2}, {1, 2, 3}}; - ge::AttrUtils::SetListListInt(op_desc, "axes", value); - auto node = ge::NodeUtils::CreatNodeWithoutGraph(op_desc); - size_t attr_size; - auto attr_buffer = bg::CreateAttrBuffer(node, attr_size); - auto rt_attr_def = reinterpret_cast(attr_buffer.get()); - ASSERT_NE(rt_attr_def, nullptr); - EXPECT_EQ(rt_attr_def->attr_num, 1U); - for (size_t i = 0U; i < 40U; ++i) { - EXPECT_EQ(rt_attr_def->reserved_[i], 0); - } - EXPECT_EQ(rt_attr_def->offset[0], 2 * sizeof(size_t) + sizeof(rt_attr_def->reserved_)); - - auto base = reinterpret_cast(attr_buffer.get()); - auto cvv = reinterpret_cast(base + rt_attr_def->offset[0]); - ASSERT_NE(cvv, nullptr); - ASSERT_EQ(cvv->GetSize(), value.size()); - for (size_t i = 0U; i < value.size(); ++i) { - auto cv = cvv->Get(i); - ASSERT_NE(cv, nullptr); - ASSERT_EQ(cv->GetSize(), value[i].size()); - ASSERT_EQ(cv->GetSize(), cv->GetCapacity()); - auto data = reinterpret_cast(cv->GetData()); - for (size_t j = 0U; j < value[i].size(); ++j) { - EXPECT_EQ(data[j], value[i][j]); - } - } -} - -TEST_F(BgIrAttrsUT, CreateListListIntAttrBuffer_Float64) { - auto op_desc = std::make_shared("foo", "Foo"); - op_desc->AppendIrAttrName("axes"); - EXPECT_EQ(op_desc->GetIrAttrNames().size(), 1U); - EXPECT_EQ(op_desc->GetIrAttrNames().at(0), "axes"); - std::vector> value{{}, {1}, {1, 2}, {1, 2, 3}}; - ge::AttrUtils::SetListListFloat(op_desc, "axes", value); - auto node = ge::NodeUtils::CreatNodeWithoutGraph(op_desc); - size_t attr_size; - auto attr_buffer = bg::CreateAttrBuffer(node, attr_size); - auto rt_attr_def = reinterpret_cast(attr_buffer.get()); - ASSERT_NE(rt_attr_def, nullptr); - EXPECT_EQ(rt_attr_def->attr_num, 1U); - for (size_t i = 0U; i < 40U; ++i) { - EXPECT_EQ(rt_attr_def->reserved_[i], 0); - } - EXPECT_EQ(rt_attr_def->offset[0], 2 * sizeof(size_t) + sizeof(rt_attr_def->reserved_)); - - auto base = reinterpret_cast(attr_buffer.get()); - auto cvv = reinterpret_cast(base + rt_attr_def->offset[0]); - ASSERT_NE(cvv, nullptr); - ASSERT_EQ(cvv->GetSize(), value.size()); - for (size_t i = 0U; i < value.size(); ++i) { - auto cv = cvv->Get(i); - ASSERT_NE(cv, nullptr); - ASSERT_EQ(cv->GetSize(), value[i].size()); - ASSERT_EQ(cv->GetSize(), cv->GetCapacity()); - auto data = reinterpret_cast(cv->GetData()); - for (size_t j = 0U; j < value[i].size(); ++j) { - EXPECT_EQ(data[j], value[i][j]); - } - } -} -TEST_F(BgIrAttrsUT, CreateStringAttrBuffer) { - auto op_desc = std::make_shared("foo", "Foo"); - op_desc->AppendIrAttrName("demo_str"); - EXPECT_EQ(op_desc->GetIrAttrNames().size(), 1U); - EXPECT_EQ(op_desc->GetIrAttrNames().at(0), "demo_str"); - std::string str_attr = "hello"; - ge::AttrUtils::SetStr(op_desc, "demo_str", str_attr); - auto node = ge::NodeUtils::CreatNodeWithoutGraph(op_desc); - size_t attr_size; - auto attr_buffer = bg::CreateAttrBuffer(node, attr_size); - - auto base = reinterpret_cast(reinterpret_cast(attr_buffer.get())->offset + 1); - EXPECT_STREQ(base, "hello"); -} - -TEST_F(BgIrAttrsUT, CreateListStringAttrBuffer) { - auto op_desc = std::make_shared("foo", "Foo"); - op_desc->AppendIrAttrName("demo_str"); - EXPECT_EQ(op_desc->GetIrAttrNames().size(), 1U); - EXPECT_EQ(op_desc->GetIrAttrNames().at(0), "demo_str"); - std::string str_attr1 = "hello"; - std::string str_attr2 = "world"; - std::string str_attr3 = "good"; - std::string str_attr4 = "job"; - std::vector str_atts = {str_attr1, str_attr2, str_attr3, str_attr4}; - ge::AttrUtils::SetListStr(op_desc, "demo_str", str_atts); - auto node = ge::NodeUtils::CreatNodeWithoutGraph(op_desc); - size_t attr_size; - auto attr_buffer = bg::CreateAttrBuffer(node, attr_size); - auto attr_def = reinterpret_cast(attr_buffer.get()); - auto base = - reinterpret_cast(ge::PtrToPtr(attr_def) - + attr_def->offset[0]); - ASSERT_NE(base, nullptr); - size_t str_attrs_len = 0U; - for (const auto &str_attr : str_atts) { - str_attrs_len += strlen(str_attr.c_str()) + 1U; - } - size_t gt_attr_size = sizeof(ContinuousVector) + sizeof(RuntimeAttrsDef) + 1 * sizeof(size_t) + +str_attrs_len; - EXPECT_EQ(attr_size, gt_attr_size); - ASSERT_EQ(base->GetSize(), 4); - EXPECT_STREQ(reinterpret_cast(base->GetData()), "hello"); - EXPECT_STREQ(reinterpret_cast(base->GetData()) + 6, "world"); - EXPECT_STREQ(reinterpret_cast(base->GetData()) + 12, "good"); - EXPECT_STREQ(reinterpret_cast(base->GetData()) + 17, "job"); -} - -TEST_F(BgIrAttrsUT, CreateAttrBufferWithoutIrAttr) { - auto op_desc = std::make_shared("foo", "Foo"); - op_desc->AppendIrAttrName("dtype"); - ge::AttrUtils::SetDataType(op_desc, "dtype", ge::DT_INT32); - - auto node = ge::NodeUtils::CreatNodeWithoutGraph(op_desc); - size_t attr_size; - ge::AnyValue value = ge::AnyValue::CreateFrom(2); - auto attr_buffer = bg::CreateAttrBufferWithoutIr(node, {value}, attr_size); - size_t gt_attr_size = sizeof(RuntimeAttrsDef) + sizeof(size_t) + sizeof(int64_t); - EXPECT_EQ(attr_size, gt_attr_size); - auto rt_attr_def = reinterpret_cast(attr_buffer.get()); - ASSERT_NE(rt_attr_def, nullptr); - EXPECT_EQ(rt_attr_def->attr_num, 1U); - EXPECT_EQ(rt_attr_def->offset[0], sizeof(size_t) + sizeof(RuntimeAttrsDef)); - auto base = reinterpret_cast(rt_attr_def); - EXPECT_EQ(*reinterpret_cast(base + rt_attr_def->offset[0]), 2); -} -} // namespace gert diff --git a/tests/ut/exe_graph/bg_kernel_context_extend_unittest.cc b/tests/ut/exe_graph/bg_kernel_context_extend_unittest.cc deleted file mode 100644 index 116b73bc394909d1d81b20a4082a51cfa188008f..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/bg_kernel_context_extend_unittest.cc +++ /dev/null @@ -1,1247 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/lowering/bg_kernel_context_extend.h" -#include -#include -#include -#include -#include "graph/compute_graph.h" -#include "exe_graph/runtime/context_extend.h" -#include "exe_graph/runtime/continuous_vector.h" -#include "exe_graph/runtime/tensor.h" -#include "graph/debug/ge_attr_define.h" -#include "expand_dimension.h" -#include "graph/utils/graph_utils.h" -#include "runtime/runtime_attrs_def.h" -namespace { -bool isMemoryCleared(const uint8_t *ptr, size_t size) { - if (ptr == nullptr) { - return false; - } - return std::all_of(ptr, ptr + size, [](uint8_t byte) { return byte == 0; }); -} -} -namespace gert { -class BgKernelContextExtendUT : public testing::Test {}; -TEST_F(BgKernelContextExtendUT, BuildRequiredInput) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AppendIrInput("x", ge::kIrInputRequired); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - auto data_op_desc = std::make_shared("data", "Data"); - data_op_desc->AddOutputDesc("x", tensor_desc); - auto data_node = graph->AddNode(data_op_desc); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), node->GetInDataAnchor(0)); - - bg::BufferPool buffer_pool; - auto ret = bg::CreateComputeNodeInfo(node, buffer_pool); - ASSERT_NE(ret, nullptr); - auto compute_node_info = reinterpret_cast(ret.get()); - ASSERT_EQ(compute_node_info->GetInputsNum(), 1); - ASSERT_EQ(compute_node_info->GetOutputsNum(), 0); - ASSERT_EQ(compute_node_info->GetIrInputsNum(), 1); - auto td = compute_node_info->GetInputTdInfo(0); - ASSERT_NE(td, nullptr); - EXPECT_EQ(td->GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(td->GetOriginFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(td->GetStorageFormat(), ge::FORMAT_NC1HWC0); - auto expand_dims_type = td->GetExpandDimsType(); - Shape origin_shape({8, 3, 224, 224}); - Shape storage_shape; - expand_dims_type.Expand(origin_shape, storage_shape); - EXPECT_EQ(storage_shape, origin_shape); - - auto ins_info = compute_node_info->GetInputInstanceInfo(0); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 1); - EXPECT_EQ(ins_info->GetInstanceStart(), 0); - - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrNum(), 0); - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrPointer(0), nullptr); -} -TEST_F(BgKernelContextExtendUT, BuildEmptyRequiredInput) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("x1", tensor_desc); - op_desc->AppendIrInput("x1", ge::kIrInputRequired); - op_desc->AppendIrInput("x2", ge::kIrInputRequired); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - - bg::BufferPool buffer_pool; - auto ret = bg::CreateComputeNodeInfo(node, buffer_pool); - EXPECT_NE(ret, nullptr); -} -TEST_F(BgKernelContextExtendUT, UknownInputFailed) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("x2", tensor_desc); - op_desc->AppendIrInput("x1", ge::kIrInputOptional); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - - bg::BufferPool buffer_pool; - auto ret = bg::CreateComputeNodeInfo(node, buffer_pool); - EXPECT_NE(ret, nullptr); - - auto compute_node_info = reinterpret_cast(ret.get()); - auto ins_info = compute_node_info->GetInputInstanceInfo(0); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 0); -} -TEST_F(BgKernelContextExtendUT, BuildWithOptionalInputs) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AppendIrInput("x", ge::kIrInputOptional); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - auto data_op_desc = std::make_shared("data", "Data"); - data_op_desc->AddOutputDesc("x", tensor_desc); - auto data_node = graph->AddNode(data_op_desc); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), node->GetInDataAnchor(0)); - - bg::BufferPool buffer_pool; - auto ret = bg::CreateComputeNodeInfo(node, buffer_pool); - ASSERT_NE(ret, nullptr); - auto compute_node_info = reinterpret_cast(ret.get()); - ASSERT_EQ(compute_node_info->GetInputsNum(), 1); - ASSERT_EQ(compute_node_info->GetOutputsNum(), 0); - ASSERT_EQ(compute_node_info->GetIrInputsNum(), 1); - auto td = compute_node_info->GetInputTdInfo(0); - ASSERT_NE(td, nullptr); - EXPECT_EQ(td->GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(td->GetOriginFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(td->GetStorageFormat(), ge::FORMAT_NC1HWC0); - - auto ins_info = compute_node_info->GetInputInstanceInfo(0); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 1); - EXPECT_EQ(ins_info->GetInstanceStart(), 0); - - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrNum(), 0); - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrPointer(0), nullptr); -} -TEST_F(BgKernelContextExtendUT, BuildWithOptionalInputsNotExists) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AppendIrInput("y", ge::kIrInputOptional); - op_desc->AppendIrInput("x", ge::kIrInputOptional); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - auto data_op_desc = std::make_shared("data", "Data"); - data_op_desc->AddOutputDesc("x", tensor_desc); - auto data_node = graph->AddNode(data_op_desc); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), node->GetInDataAnchor(0)); - - bg::BufferPool buffer_pool; - auto ret = bg::CreateComputeNodeInfo(node, buffer_pool); - ASSERT_NE(ret, nullptr); - auto compute_node_info = reinterpret_cast(ret.get()); - ASSERT_EQ(compute_node_info->GetInputsNum(), 1); - ASSERT_EQ(compute_node_info->GetOutputsNum(), 0); - ASSERT_EQ(compute_node_info->GetIrInputsNum(), 2); - auto td = compute_node_info->GetInputTdInfo(0); - ASSERT_NE(td, nullptr); - EXPECT_EQ(td->GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(td->GetOriginFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(td->GetStorageFormat(), ge::FORMAT_NC1HWC0); - - auto ins_info = compute_node_info->GetInputInstanceInfo(0); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 0); - ins_info = compute_node_info->GetInputInstanceInfo(1); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 1); - EXPECT_EQ(ins_info->GetInstanceStart(), 0); - - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrNum(), 0); - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrPointer(0), nullptr); -} -TEST_F(BgKernelContextExtendUT, BuildWithMultipleOptionalInputsIns) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AppendIrInput("x", ge::kIrInputOptional); - op_desc->AppendIrInput("y", ge::kIrInputOptional); - op_desc->AppendIrInput("z", ge::kIrInputOptional); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - - auto data_op_desc = std::make_shared("data", "Data"); - data_op_desc->AddOutputDesc("x", tensor_desc); - data_op_desc->AddOutputDesc("y", tensor_desc); - auto data_node = graph->AddNode(data_op_desc); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), node->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(1), node->GetInDataAnchor(1)); - - bg::BufferPool buffer_pool; - auto ret = bg::CreateComputeNodeInfo(node, buffer_pool); - ASSERT_NE(ret, nullptr); - auto compute_node_info = reinterpret_cast(ret.get()); - ASSERT_EQ(compute_node_info->GetInputsNum(), 2); - ASSERT_EQ(compute_node_info->GetOutputsNum(), 0); - ASSERT_EQ(compute_node_info->GetIrInputsNum(), 3); - auto td = compute_node_info->GetInputTdInfo(0); - ASSERT_NE(td, nullptr); - EXPECT_EQ(td->GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(td->GetOriginFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(td->GetStorageFormat(), ge::FORMAT_NC1HWC0); - - auto ins_info = compute_node_info->GetInputInstanceInfo(0); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 1); - EXPECT_EQ(ins_info->GetInstanceStart(), 0); - ins_info = compute_node_info->GetInputInstanceInfo(1); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 1); - EXPECT_EQ(ins_info->GetInstanceStart(), 1); - ins_info = compute_node_info->GetInputInstanceInfo(2); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 0); - - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrNum(), 0); - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrPointer(0), nullptr); -} -TEST_F(BgKernelContextExtendUT, BuildWithDynamicInputs) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("x0", tensor_desc); - op_desc->AppendIrInput("x", ge::kIrInputDynamic); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - auto data_op_desc = std::make_shared("data", "Data"); - data_op_desc->AddOutputDesc("x", tensor_desc); - auto data_node = graph->AddNode(data_op_desc); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), node->GetInDataAnchor(0)); - - bg::BufferPool buffer_pool; - auto ret = bg::CreateComputeNodeInfo(node, buffer_pool); - ASSERT_NE(ret, nullptr); - auto compute_node_info = reinterpret_cast(ret.get()); - ASSERT_EQ(compute_node_info->GetInputsNum(), 1); - ASSERT_EQ(compute_node_info->GetOutputsNum(), 0); - ASSERT_EQ(compute_node_info->GetIrInputsNum(), 1); - auto td = compute_node_info->GetInputTdInfo(0); - ASSERT_NE(td, nullptr); - EXPECT_EQ(td->GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(td->GetOriginFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(td->GetStorageFormat(), ge::FORMAT_NC1HWC0); - - auto ins_info = compute_node_info->GetInputInstanceInfo(0); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 1); - EXPECT_EQ(ins_info->GetInstanceStart(), 0); - - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrNum(), 0); - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrPointer(0), nullptr); -} -TEST_F(BgKernelContextExtendUT, BuildWithMultiInstanceDynamicInputs) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("x0", tensor_desc); - op_desc->AddInputDesc("x1", tensor_desc); - op_desc->AddInputDesc("y0", tensor_desc); - op_desc->AddInputDesc("y1", tensor_desc); - op_desc->AddInputDesc("y2", tensor_desc); - op_desc->AppendIrInput("x", ge::kIrInputDynamic); - op_desc->AppendIrInput("y", ge::kIrInputDynamic); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - auto data_op_desc = std::make_shared("data", "Data"); - data_op_desc->AddOutputDesc("x0", tensor_desc); - data_op_desc->AddOutputDesc("x1", tensor_desc); - data_op_desc->AddOutputDesc("y0", tensor_desc); - data_op_desc->AddOutputDesc("y1", tensor_desc); - data_op_desc->AddOutputDesc("y2", tensor_desc); - auto data_node = graph->AddNode(data_op_desc); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), node->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(1), node->GetInDataAnchor(1)); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(2), node->GetInDataAnchor(2)); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(3), node->GetInDataAnchor(3)); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(4), node->GetInDataAnchor(4)); - - bg::BufferPool buffer_pool; - auto ret = bg::CreateComputeNodeInfo(node, buffer_pool); - ASSERT_NE(ret, nullptr); - auto compute_node_info = reinterpret_cast(ret.get()); - ASSERT_EQ(compute_node_info->GetInputsNum(), 5); - ASSERT_EQ(compute_node_info->GetOutputsNum(), 0); - ASSERT_EQ(compute_node_info->GetIrInputsNum(), 2); - auto td = compute_node_info->GetInputTdInfo(0); - ASSERT_NE(td, nullptr); - EXPECT_EQ(td->GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(td->GetOriginFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(td->GetStorageFormat(), ge::FORMAT_NC1HWC0); - - auto ins_info = compute_node_info->GetInputInstanceInfo(0); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 2); - EXPECT_EQ(ins_info->GetInstanceStart(), 0); - ins_info = compute_node_info->GetInputInstanceInfo(1); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 3); - EXPECT_EQ(ins_info->GetInstanceStart(), 2); - - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrNum(), 0); - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrPointer(0), nullptr); -} -TEST_F(BgKernelContextExtendUT, BuildWithEmptyDynamicInputs) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("x0", tensor_desc); - op_desc->AddInputDesc("x1", tensor_desc); - op_desc->AppendIrInput("y", ge::kIrInputDynamic); - op_desc->AppendIrInput("x", ge::kIrInputDynamic); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - auto data_op_desc = std::make_shared("data", "Data"); - data_op_desc->AddOutputDesc("x0", tensor_desc); - data_op_desc->AddOutputDesc("x1", tensor_desc); - auto data_node = graph->AddNode(data_op_desc); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), node->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(1), node->GetInDataAnchor(1)); - - bg::BufferPool buffer_pool; - auto ret = bg::CreateComputeNodeInfo(node, buffer_pool); - ASSERT_NE(ret, nullptr); - auto compute_node_info = reinterpret_cast(ret.get()); - ASSERT_EQ(compute_node_info->GetInputsNum(), 2); - ASSERT_EQ(compute_node_info->GetOutputsNum(), 0); - ASSERT_EQ(compute_node_info->GetIrInputsNum(), 2); - auto td = compute_node_info->GetInputTdInfo(0); - ASSERT_NE(td, nullptr); - EXPECT_EQ(td->GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(td->GetOriginFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(td->GetStorageFormat(), ge::FORMAT_NC1HWC0); - - auto ins_info = compute_node_info->GetInputInstanceInfo(0); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 0); - ins_info = compute_node_info->GetInputInstanceInfo(1); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 2); - EXPECT_EQ(ins_info->GetInstanceStart(), 0); - - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrNum(), 0); - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrPointer(0), nullptr); -} -TEST_F(BgKernelContextExtendUT, BuildWithOneAttr) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AppendIrInput("x", ge::kIrInputRequired); - op_desc->AddOutputDesc("y", tensor_desc); - - ge::AttrUtils::SetStr(op_desc, "a1", "Hello"); - op_desc->AppendIrAttrName("a1"); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - auto data_op_desc = std::make_shared("data", "Data"); - data_op_desc->AddOutputDesc("x", tensor_desc); - auto data_node = graph->AddNode(data_op_desc); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), node->GetInDataAnchor(0)); - - bg::BufferPool buffer_pool; - auto ret = bg::CreateComputeNodeInfo(node, buffer_pool); - ASSERT_NE(ret, nullptr); - auto compute_node_info = reinterpret_cast(ret.get()); - ASSERT_EQ(compute_node_info->GetInputsNum(), 1); - ASSERT_EQ(compute_node_info->GetOutputsNum(), 1); - ASSERT_EQ(compute_node_info->GetIrInputsNum(), 1); - auto td = compute_node_info->GetInputTdInfo(0); - ASSERT_NE(td, nullptr); - EXPECT_EQ(td->GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(td->GetOriginFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(td->GetStorageFormat(), ge::FORMAT_NC1HWC0); - - auto ins_info = compute_node_info->GetInputInstanceInfo(0); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 1); - EXPECT_EQ(ins_info->GetInstanceStart(), 0); - - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrNum(), 1); - ASSERT_NE(compute_node_info->GetAttrs()->GetAttrPointer(0), nullptr); - EXPECT_STREQ(compute_node_info->GetAttrs()->GetAttrPointer(0), "Hello"); -} -TEST_F(BgKernelContextExtendUT, BuildWithAttrs) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AppendIrInput("x", ge::kIrInputRequired); - op_desc->AddOutputDesc("y", tensor_desc); - - ge::GeTensorDesc ge_td; - ge_td.SetOriginFormat(ge::FORMAT_NHWC); - ge_td.SetFormat(ge::FORMAT_NC1HWC0); - ge_td.SetDataType(ge::DT_FLOAT16); - ge_td.SetOriginShape(ge::GeShape({8, 224, 224, 3})); - ge_td.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - ge::GeTensor ge_tensor(ge_td); - std::vector fake_data(8 * 224 * 224 * 16); - for (size_t i = 0; i < fake_data.size(); ++i) { - fake_data[i] = static_cast(rand() % std::numeric_limits::max()); - } - ge_tensor.SetData(reinterpret_cast(fake_data.data()), fake_data.size() * 2); - - ge::AttrUtils::SetStr(op_desc, "a1", "Hello"); - ge::AttrUtils::SetInt(op_desc, "a2", 10240); - ge::AttrUtils::SetBool(op_desc, "a3", true); - ge::AttrUtils::SetFloat(op_desc, "a4", 1024.0021); - ge::AttrUtils::SetListInt(op_desc, "a5", std::vector({10, 200, 3000, 4096})); - ge::AttrUtils::SetTensor(op_desc, "a6", ge_tensor); - ge::AttrUtils::SetStr(op_desc, "b1", "World"); - ge::AttrUtils::SetInt(op_desc, "b2", 1024000); - ge::AttrUtils::SetBool(op_desc, "b3", false); - ge::AttrUtils::SetFloat(op_desc, "b4", 1024.1); - ge::AttrUtils::SetListInt(op_desc, "b5", std::vector({10, 400, 3000, 8192})); - ge::AttrUtils::SetTensor(op_desc, "b6", ge_tensor); - ge::AttrUtils::SetListStr(op_desc, "c1", std::vector({"hello", "world", "world1", "hello1"})); - ge::AttrUtils::SetListDataType(op_desc, "c2", std::vector({ge::DT_FLOAT, ge::DT_STRING, ge::DT_UINT16, ge::DT_BOOL})); - - op_desc->AppendIrAttrName("b1"); - op_desc->AppendIrAttrName("b2"); - op_desc->AppendIrAttrName("b3"); - op_desc->AppendIrAttrName("b4"); - op_desc->AppendIrAttrName("b5"); - op_desc->AppendIrAttrName("b6"); - op_desc->AppendIrAttrName("a1"); - op_desc->AppendIrAttrName("a2"); - op_desc->AppendIrAttrName("a3"); - op_desc->AppendIrAttrName("a4"); - op_desc->AppendIrAttrName("a5"); - op_desc->AppendIrAttrName("a6"); - op_desc->AppendIrAttrName("c1"); - op_desc->AppendIrAttrName("c2"); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - auto data_op_desc = std::make_shared("data", "Data"); - data_op_desc->AddOutputDesc("x", tensor_desc); - auto data_node = graph->AddNode(data_op_desc); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), node->GetInDataAnchor(0)); - - bg::BufferPool buffer_pool; - auto ret = bg::CreateComputeNodeInfo(node, buffer_pool); - ASSERT_NE(ret, nullptr); - auto compute_node_info = reinterpret_cast(ret.get()); - ASSERT_EQ(compute_node_info->GetInputsNum(), 1); - ASSERT_EQ(compute_node_info->GetOutputsNum(), 1); - ASSERT_EQ(compute_node_info->GetIrInputsNum(), 1); - auto td = compute_node_info->GetInputTdInfo(0); - ASSERT_NE(td, nullptr); - EXPECT_EQ(td->GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(td->GetOriginFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(td->GetStorageFormat(), ge::FORMAT_NC1HWC0); - - auto ins_info = compute_node_info->GetInputInstanceInfo(0); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 1); - EXPECT_EQ(ins_info->GetInstanceStart(), 0); - - auto attrs = compute_node_info->GetAttrs(); - EXPECT_EQ(attrs->GetAttrNum(), 14); - - EXPECT_STREQ(attrs->GetAttrPointer(0), "World"); - EXPECT_STREQ(attrs->GetStr(0), "World"); - EXPECT_EQ(*attrs->GetAttrPointer(1), 1024000); - EXPECT_EQ(*attrs->GetInt(1), 1024000); - EXPECT_EQ(*attrs->GetAttrPointer(2), false); - EXPECT_EQ(*attrs->GetBool(2), false); - EXPECT_FLOAT_EQ(*attrs->GetAttrPointer(3), 1024.1); - EXPECT_FLOAT_EQ(*attrs->GetFloat(3), 1024.1); - auto list_int = attrs->GetAttrPointer(4); - ASSERT_NE(list_int, nullptr); - ASSERT_EQ(list_int->GetSize(), 4); - EXPECT_EQ(memcmp(list_int->GetData(), std::vector({10, 400, 3000, 8192}).data(), 4 * sizeof(int64_t)), 0); - auto typed_list_int = attrs->GetListInt(4); - ASSERT_NE(typed_list_int, nullptr); - ASSERT_EQ(typed_list_int->GetSize(), 4); - EXPECT_EQ(typed_list_int->GetData()[0], 10); - EXPECT_EQ(typed_list_int->GetData()[1], 400); - EXPECT_EQ(typed_list_int->GetData()[2], 3000); - EXPECT_EQ(typed_list_int->GetData()[3], 8192); - - auto gert_tensor = attrs->GetAttrPointer(5); - EXPECT_EQ(attrs->GetTensor(5), gert_tensor); - ASSERT_NE(gert_tensor, nullptr); - EXPECT_EQ(gert_tensor->GetOriginShape(), gert::Shape({8, 224, 224, 3})); - EXPECT_EQ(gert_tensor->GetStorageShape(), gert::Shape({8, 1, 224, 224, 16})); - EXPECT_EQ(gert_tensor->GetOriginFormat(), ge::FORMAT_NHWC); - EXPECT_EQ(gert_tensor->GetStorageFormat(), ge::FORMAT_NC1HWC0); - EXPECT_EQ(gert_tensor->GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(memcmp(gert_tensor->GetData(), fake_data.data(), fake_data.size() * sizeof(uint16_t)), 0); - - EXPECT_STREQ(attrs->GetAttrPointer(6), "Hello"); - EXPECT_EQ(*attrs->GetAttrPointer(7), 10240); - EXPECT_EQ(*attrs->GetAttrPointer(8), true); - EXPECT_FLOAT_EQ(*attrs->GetAttrPointer(9), 1024.0021); - list_int = attrs->GetAttrPointer(10); - ASSERT_NE(list_int, nullptr); - ASSERT_EQ(list_int->GetSize(), 4); - EXPECT_EQ(memcmp(list_int->GetData(), std::vector({10, 200, 3000, 4096}).data(), 4 * sizeof(int64_t)), 0); - gert_tensor = attrs->GetAttrPointer(11); - ASSERT_NE(gert_tensor, nullptr); - EXPECT_EQ(gert_tensor->GetOriginShape(), gert::Shape({8, 224, 224, 3})); - EXPECT_EQ(gert_tensor->GetStorageShape(), gert::Shape({8, 1, 224, 224, 16})); - EXPECT_EQ(gert_tensor->GetOriginFormat(), ge::FORMAT_NHWC); - EXPECT_EQ(gert_tensor->GetStorageFormat(), ge::FORMAT_NC1HWC0); - EXPECT_EQ(gert_tensor->GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(memcmp(gert_tensor->GetData(), fake_data.data(), fake_data.size() * sizeof(uint16_t)), 0); - - auto list_str = attrs->GetAttrPointer(12); - ASSERT_NE(list_str, nullptr); - ASSERT_EQ(list_str->GetSize(), 4); - EXPECT_STREQ(reinterpret_cast(list_str->GetData()), "hello"); - EXPECT_STREQ(reinterpret_cast(list_str->GetData()) + 6, "world"); - EXPECT_STREQ(reinterpret_cast(list_str->GetData()) + 12, "world1"); - EXPECT_STREQ(reinterpret_cast(list_str->GetData()) + 19, "hello1"); - - auto list_datatype = attrs->GetAttrPointer(13); - ASSERT_NE(list_datatype, nullptr); - ASSERT_EQ(list_datatype->GetSize(), 4); - EXPECT_EQ(memcmp(list_datatype->GetData(), std::vector({ge::DT_FLOAT, ge::DT_STRING, ge::DT_UINT16, ge::DT_BOOL}).data(), 4 * sizeof(ge::DataType)), 0); - auto typed_list_datatype = attrs->GetAttrPointer>(13); - ASSERT_NE(typed_list_datatype, nullptr); - ASSERT_EQ(typed_list_datatype->GetSize(), 4); - EXPECT_EQ((ge::DataType)(typed_list_datatype->GetData()[0]), ge::DT_FLOAT); - EXPECT_EQ((ge::DataType)typed_list_datatype->GetData()[1], ge::DT_STRING); - EXPECT_EQ((ge::DataType)typed_list_datatype->GetData()[2], ge::DT_UINT16); - EXPECT_EQ((ge::DataType)typed_list_datatype->GetData()[3], ge::DT_BOOL); - -} -TEST_F(BgKernelContextExtendUT, IgnoreNoneIrAttr) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AppendIrInput("x", ge::kIrInputRequired); - op_desc->AddOutputDesc("y", tensor_desc); - - ge::AttrUtils::SetStr(op_desc, "a1", "Hello"); - ge::AttrUtils::SetInt(op_desc, "a2", 10240); - ge::AttrUtils::SetBool(op_desc, "a3", true); - ge::AttrUtils::SetFloat(op_desc, "a4", 1024.0021); - ge::AttrUtils::SetStr(op_desc, "b1", "World"); - ge::AttrUtils::SetInt(op_desc, "b2", 1024000); - ge::AttrUtils::SetBool(op_desc, "b3", false); - ge::AttrUtils::SetFloat(op_desc, "b4", 1024.1); - - op_desc->AppendIrAttrName("b1"); - op_desc->AppendIrAttrName("b3"); - op_desc->AppendIrAttrName("a2"); - op_desc->AppendIrAttrName("a4"); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - auto data_op_desc = std::make_shared("data", "Data"); - data_op_desc->AddOutputDesc("x", tensor_desc); - auto data_node = graph->AddNode(data_op_desc); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), node->GetInDataAnchor(0)); - - bg::BufferPool buffer_pool; - auto ret = bg::CreateComputeNodeInfo(node, buffer_pool); - ASSERT_NE(ret, nullptr); - auto compute_node_info = reinterpret_cast(ret.get()); - ASSERT_EQ(compute_node_info->GetInputsNum(), 1); - ASSERT_EQ(compute_node_info->GetOutputsNum(), 1); - ASSERT_EQ(compute_node_info->GetIrInputsNum(), 1); - auto td = compute_node_info->GetInputTdInfo(0); - ASSERT_NE(td, nullptr); - EXPECT_EQ(td->GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(td->GetOriginFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(td->GetStorageFormat(), ge::FORMAT_NC1HWC0); - - auto ins_info = compute_node_info->GetInputInstanceInfo(0); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 1); - EXPECT_EQ(ins_info->GetInstanceStart(), 0); - - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrNum(), 4); - EXPECT_STREQ(compute_node_info->GetAttrs()->GetAttrPointer(0), "World"); - EXPECT_EQ(*compute_node_info->GetAttrs()->GetAttrPointer(1), false); - EXPECT_EQ(*compute_node_info->GetAttrs()->GetAttrPointer(2), 10240); - EXPECT_FLOAT_EQ(*compute_node_info->GetAttrs()->GetAttrPointer(3), 1024.0021); -} - -// 测试构造kernel context的时候从tensor desc上获取ATTR_NAME_RESHAPE_TYPE_MASK并设置到compile time tensor desc 上 -// 同时测试调用Expand是否能够得到正确的扩维shape -TEST_F(BgKernelContextExtendUT, BuildRequiredInputWithExpandDimsType) { - vector origin_shape = {5, 6, 7}; - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({5, 6, 7, 1})); - tensor_desc.SetOriginShape(ge::GeShape(origin_shape)); - // get reshape type 此处模拟FE调用transformer中方法获取int类型的reshape type - int64_t int_reshape_type = transformer::ExpandDimension::GenerateReshapeType(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, - origin_shape.size(), "NCH"); - (void) ge::AttrUtils::SetInt(tensor_desc, ge::ATTR_NAME_RESHAPE_TYPE_MASK, int_reshape_type); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AppendIrInput("x", ge::kIrInputRequired); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - auto data_op_desc = std::make_shared("data", "Data"); - data_op_desc->AddOutputDesc("x", tensor_desc); - auto data_node = graph->AddNode(data_op_desc); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), node->GetInDataAnchor(0)); - - bg::BufferPool buffer_pool; - auto ret = bg::CreateComputeNodeInfo(node, buffer_pool); - ASSERT_NE(ret, nullptr); - auto compute_node_info = reinterpret_cast(ret.get()); - ASSERT_EQ(compute_node_info->GetInputsNum(), 1); - ASSERT_EQ(compute_node_info->GetOutputsNum(), 0); - ASSERT_EQ(compute_node_info->GetIrInputsNum(), 1); - auto td = compute_node_info->GetInputTdInfo(0); - ASSERT_NE(td, nullptr); - EXPECT_EQ(td->GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(td->GetOriginFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(td->GetStorageFormat(), ge::FORMAT_NC1HWC0); - - auto expand_dims_type = td->GetExpandDimsType(); - auto shape = Shape{5, 6, 7}; - Shape out_shape; - expand_dims_type.Expand(shape, out_shape); - ASSERT_EQ(4, out_shape.GetDimNum()); - ASSERT_EQ(out_shape, Shape({5, 6, 7, 1})); -} - -TEST_F(BgKernelContextExtendUT, BuildWithMultiInstanceDynamicInputsWithNoMatchingNameBetweenIrAndNode) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("y1", tensor_desc); - op_desc->AddInputDesc("y2", tensor_desc); - op_desc->AddInputDesc("x1", tensor_desc); - op_desc->AddInputDesc("x2", tensor_desc); - - op_desc->AppendIrInput("x", ge::kIrInputDynamic); - op_desc->AppendIrInput("y", ge::kIrInputDynamic); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - auto data_op_desc = std::make_shared("data", "Data"); - data_op_desc->AddOutputDesc("y0", tensor_desc); - data_op_desc->AddOutputDesc("y1", tensor_desc); - data_op_desc->AddOutputDesc("x0", tensor_desc); - data_op_desc->AddOutputDesc("x1", tensor_desc); - auto data_node = graph->AddNode(data_op_desc); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), node->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(1), node->GetInDataAnchor(1)); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(2), node->GetInDataAnchor(2)); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(3), node->GetInDataAnchor(3)); - - bg::BufferPool buffer_pool; - auto ret = bg::CreateComputeNodeInfo(node, buffer_pool); - ASSERT_NE(ret, nullptr); - auto compute_node_info = reinterpret_cast(ret.get()); - ASSERT_EQ(compute_node_info->GetInputsNum(), 4); - ASSERT_EQ(compute_node_info->GetOutputsNum(), 0); - ASSERT_EQ(compute_node_info->GetIrInputsNum(), 2); - auto td = compute_node_info->GetInputTdInfo(0); - ASSERT_NE(td, nullptr); - EXPECT_EQ(td->GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(td->GetOriginFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(td->GetStorageFormat(), ge::FORMAT_NC1HWC0); - - auto ins_info = compute_node_info->GetInputInstanceInfo(0); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 0); - EXPECT_EQ(ins_info->GetInstanceStart(), 0); - ins_info = compute_node_info->GetInputInstanceInfo(1); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 0); - EXPECT_EQ(ins_info->GetInstanceStart(), 0); - - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrNum(), 0); - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrPointer(0), nullptr); -} - -TEST_F(BgKernelContextExtendUT, BuildWithRequiredInputWithNoMatchingNameBetweenIrAndNode) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - - op_desc->AddInputDesc("x1", tensor_desc); - op_desc->AppendIrInput("x", ge::kIrInputRequired); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - auto data_op_desc = std::make_shared("data", "Data"); - data_op_desc->AddOutputDesc("x1", tensor_desc); - auto data_node = graph->AddNode(data_op_desc); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), node->GetInDataAnchor(0)); - - bg::BufferPool buffer_pool; - auto ret = bg::CreateComputeNodeInfo(node, buffer_pool); - ASSERT_NE(ret, nullptr); - auto compute_node_info = reinterpret_cast(ret.get()); - ASSERT_EQ(compute_node_info->GetInputsNum(), 1); - ASSERT_EQ(compute_node_info->GetOutputsNum(), 0); - ASSERT_EQ(compute_node_info->GetIrInputsNum(), 1); - auto td = compute_node_info->GetInputTdInfo(0); - ASSERT_NE(td, nullptr); - EXPECT_EQ(td->GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(td->GetOriginFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(td->GetStorageFormat(), ge::FORMAT_NC1HWC0); - - auto ins_info = compute_node_info->GetInputInstanceInfo(0); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 1); - EXPECT_EQ(ins_info->GetInstanceStart(), 0); -} - -/* - * net_output - * | - * conv - * / \ - * x w - */ -ge::ComputeGraphPtr GetNoPeerOutputAnchorGraph() { - ge::ComputeGraphPtr graph = std::make_shared("default"); - // add x node - ge::OpDescPtr x_op_desc = std::make_shared("x", "Data"); - ge::GeTensorDesc x_td(ge::GeShape({1, 256, 256, 3}), ge::FORMAT_NHWC, ge::DT_FLOAT); - x_td.SetFormat(ge::FORMAT_NHWC); - x_td.SetOriginFormat(ge::FORMAT_NHWC); - x_op_desc->AddOutputDesc(x_td); - auto x_node = graph->AddNode(x_op_desc); - x_node->Init(); - // add w node - ge::OpDescPtr w_op_desc = std::make_shared("w", "Data"); - ge::GeTensorDesc w_td(ge::GeShape({1, 1, 1, 1}), ge::FORMAT_HWCN, ge::DT_FLOAT); - w_op_desc->AddOutputDesc(w_td); - auto w_node = graph->AddNode(w_op_desc); - w_node->Init(); - // add conv node - ge::OpDescPtr conv_op_desc = std::make_shared("conv", "Conv2D"); - conv_op_desc->AddInputDesc(x_td); - conv_op_desc->AddInputDesc(w_td); - ge::GeTensorDesc bias_td(ge::GeShape({1})); - conv_op_desc->AddInputDesc("bias", bias_td); - conv_op_desc->AppendIrInput("x", ge::kIrInputRequired); - conv_op_desc->AppendIrInput("w", ge::kIrInputRequired); - conv_op_desc->AppendIrInput("bias", ge::kIrInputOptional); - auto conv_node = graph->AddNode(conv_op_desc); - conv_node->Init(); - // add edge - ge::GraphUtils::AddEdge(x_node->GetOutDataAnchor(0), conv_node->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(w_node->GetOutDataAnchor(0), conv_node->GetInDataAnchor(1)); - return graph; -} - -TEST_F(BgKernelContextExtendUT, BuildOptionalInputWithNoPeerOutputAnchor) { - auto graph = GetNoPeerOutputAnchorGraph(); - auto conv_node = graph->FindNode("conv"); - bg::BufferPool buffer_pool; - size_t total_size = 0; - auto ret = bg::CreateComputeNodeInfo(conv_node, buffer_pool, total_size); - ASSERT_NE(ret, nullptr); - size_t buf_size = sizeof(ComputeNodeInfo) + sizeof(AnchorInstanceInfo) * 3 + - sizeof(CompileTimeTensorDesc) * 2 + sizeof(RuntimeAttrsDef); - ASSERT_EQ(total_size, buf_size); - - auto compute_node_info = reinterpret_cast(ret.get()); - ASSERT_EQ(compute_node_info->GetInputsNum(), 2); - ASSERT_EQ(compute_node_info->GetOutputsNum(), 0); - ASSERT_EQ(compute_node_info->GetIrInputsNum(), 3); - auto td = compute_node_info->GetInputTdInfo(0); - ASSERT_NE(td, nullptr); - EXPECT_EQ(td->GetDataType(), ge::DT_FLOAT); - EXPECT_EQ(td->GetOriginFormat(), ge::FORMAT_NHWC); - - auto ins_info = compute_node_info->GetInputInstanceInfo(0); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 1); - EXPECT_EQ(ins_info->GetInstanceStart(), 0); - ins_info = compute_node_info->GetInputInstanceInfo(1); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 1); - EXPECT_EQ(ins_info->GetInstanceStart(), 1); - ins_info = compute_node_info->GetInputInstanceInfo(2); -} - -TEST_F(BgKernelContextExtendUT, CheckCompileTimeTensorDescReservedMemClean) { - auto graph = GetNoPeerOutputAnchorGraph(); - auto conv_node = graph->FindNode("conv"); - bg::BufferPool buffer_pool; - size_t total_size = 0; - auto ret = bg::CreateComputeNodeInfo(conv_node, buffer_pool, total_size); - ASSERT_NE(ret, nullptr); - size_t buf_size = sizeof(ComputeNodeInfo) + sizeof(AnchorInstanceInfo) * 3 + - sizeof(CompileTimeTensorDesc) * 2 + sizeof(RuntimeAttrsDef); - ASSERT_EQ(total_size, buf_size); - - auto compute_node_info = reinterpret_cast(ret.get()); - ASSERT_EQ(compute_node_info->GetInputsNum(), 2); - ASSERT_EQ(compute_node_info->GetOutputsNum(), 0); - ASSERT_EQ(compute_node_info->GetIrInputsNum(), 3); - auto td = compute_node_info->GetInputTdInfo(0); - ASSERT_NE(td, nullptr); - EXPECT_EQ(td->GetDataType(), ge::DT_FLOAT); - EXPECT_EQ(td->GetOriginFormat(), ge::FORMAT_NHWC); - EXPECT_FALSE(isMemoryCleared(reinterpret_cast(td), sizeof(CompileTimeTensorDesc))); - auto reserved_size = sizeof(CompileTimeTensorDesc) - sizeof(ge::DataType) - sizeof(StorageFormat); - EXPECT_TRUE(isMemoryCleared(reinterpret_cast(td) + sizeof(ge::DataType) + sizeof(StorageFormat), - reserved_size)); -} - -TEST_F(BgKernelContextExtendUT, GetPrivateAttrInComputeNodeInfoOK) { - ge::OpDescPtr op_desc = std::make_shared("test0", "Test"); - const char *attr_name_1 = "private_attr1"; - const char *attr_name_2 = "private_attr2"; - constexpr int64_t attr_value_1 = 10; - const std::string attr_value_2 = "20"; - ge::AnyValue av1 = ge::AnyValue::CreateFrom(attr_value_1); - ge::AnyValue av2 = ge::AnyValue::CreateFrom(attr_value_2); - op_desc->AppendIrAttrName("ir_attr_1"); - (void)op_desc->SetAttr("ir_attr_1", av2); - (void)op_desc->SetAttr(attr_name_1, av1); - (void)op_desc->SetAttr(attr_name_2, av2); - std::vector> private_attrs; - private_attrs.emplace_back(std::make_pair(attr_name_1, av1)); - private_attrs.emplace_back(std::make_pair(attr_name_2, av2)); - bg::BufferPool buffer_pool; - size_t attr_size; - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - auto compute_node_info_holder = bg::CreateComputeNodeInfo(node, buffer_pool, private_attrs, attr_size); - ASSERT_NE(compute_node_info_holder, nullptr); - auto compute_node_info = reinterpret_cast(compute_node_info_holder.get()); - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrNum(), 3); - EXPECT_STREQ(compute_node_info->GetAttrs()->GetStr(0), "20"); - EXPECT_EQ(*(compute_node_info->GetAttrs()->GetInt(1)), 10); - EXPECT_STREQ(compute_node_info->GetAttrs()->GetStr(2), "20"); -} - -TEST_F(BgKernelContextExtendUT, GetPrivateAttrInComputeNodeInfoByDefault) { - ge::OpDescPtr op_desc = std::make_shared("test0", "Test"); - const char *attr_name_1 = "private_attr1"; - const char *attr_name_2 = "private_attr2"; - constexpr int64_t attr_value_1 = 10; - const std::string attr_value_2 = "20"; - ge::AnyValue av1 = ge::AnyValue::CreateFrom(attr_value_1); - ge::AnyValue av2 = ge::AnyValue::CreateFrom(attr_value_2); - op_desc->AppendIrAttrName("ir_attr_1"); - (void)op_desc->SetAttr("ir_attr_1", av2); - std::vector> private_attrs; - private_attrs.emplace_back(std::make_pair(attr_name_1, av1)); - private_attrs.emplace_back(std::make_pair(attr_name_2, av2)); - bg::BufferPool buffer_pool; - size_t attr_size; - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - auto compute_node_info_holder = bg::CreateComputeNodeInfo(node, buffer_pool, private_attrs, attr_size); - ASSERT_NE(compute_node_info_holder, nullptr); - auto compute_node_info = reinterpret_cast(compute_node_info_holder.get()); - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrNum(), 3); - EXPECT_STREQ(compute_node_info->GetAttrs()->GetStr(0), "20"); - EXPECT_EQ(*(compute_node_info->GetAttrs()->GetInt(1)), 10); - EXPECT_STREQ(compute_node_info->GetAttrs()->GetStr(2), "20"); -} - -TEST_F(BgKernelContextExtendUT, CreateComputeNodeInfoFailedWhenNotRegisteringPrivateAttr) { - ge::OpDescPtr op_desc = std::make_shared("test0", "Test"); - const char *attr_name_1 = "private_attr1"; - const char *attr_name_2 = "private_attr2"; - constexpr int64_t attr_value_1 = 10; - const std::string attr_value_2 = "20"; - ge::AnyValue av1 = ge::AnyValue::CreateFrom(attr_value_1); - ge::AnyValue av2 = ge::AnyValue::CreateFrom(attr_value_2); - op_desc->AppendIrAttrName("ir_attr_1"); - (void)op_desc->SetAttr("ir_attr_1", av2); - std::vector> private_attrs; - private_attrs.emplace_back(std::make_pair(attr_name_1, av1)); - private_attrs.emplace_back(std::make_pair(attr_name_2, ge::AnyValue())); - bg::BufferPool buffer_pool; - size_t attr_size; - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - auto compute_node_info_holder = bg::CreateComputeNodeInfo(node, buffer_pool, private_attrs, attr_size); - EXPECT_EQ(compute_node_info_holder, nullptr); -} - -TEST_F(BgKernelContextExtendUT, CreateComputeNodeInfoWithoutIrAttr) { - ge::OpDescPtr op_desc = std::make_shared("test0", "Test"); - const char *attr_name_1 = "private_attr1"; - const char *attr_name_2 = "private_attr2"; - constexpr int64_t attr_value_1 = 10; - const std::string attr_value_2 = "20"; - ge::AnyValue av1 = ge::AnyValue::CreateFrom(attr_value_1); - ge::AnyValue av2 = ge::AnyValue::CreateFrom(attr_value_2); - op_desc->AppendIrAttrName("ir_attr_1"); - (void)op_desc->SetAttr("ir_attr_1", av2); - std::vector> private_attrs; - private_attrs.emplace_back(std::make_pair(attr_name_1, av1)); - private_attrs.emplace_back(std::make_pair(attr_name_2, av2)); - bg::BufferPool buffer_pool; - size_t attr_size; - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - auto compute_node_info_holder = bg::CreateComputeNodeInfoWithoutIrAttr(node, buffer_pool, private_attrs, attr_size); - EXPECT_NE(compute_node_info_holder, nullptr); - auto compute_node_info = reinterpret_cast(compute_node_info_holder.get()); - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrNum(), 2); - EXPECT_EQ(*(compute_node_info->GetAttrs()->GetInt(0)), 10); - EXPECT_STREQ(compute_node_info->GetAttrs()->GetStr(1), "20"); -} - -TEST_F(BgKernelContextExtendUT, BuildRequiredOutput) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AppendIrInput("x", ge::kIrInputRequired); - op_desc->AddOutputDesc("y", tensor_desc); - op_desc->AppendIrOutput("y", ge::kIrOutputRequired); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - auto data_op_desc = std::make_shared("data", "Data"); - data_op_desc->AddOutputDesc("x", tensor_desc); - auto data_node = graph->AddNode(data_op_desc); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), node->GetInDataAnchor(0)); - - bg::BufferPool buffer_pool; - auto ret = bg::CreateComputeNodeInfo(node, buffer_pool); - ASSERT_NE(ret, nullptr); - auto compute_node_info = reinterpret_cast(ret.get()); - ASSERT_EQ(compute_node_info->GetInputsNum(), 1); - ASSERT_EQ(compute_node_info->GetOutputsNum(), 1); - ASSERT_EQ(compute_node_info->GetIrInputsNum(), 1); - ASSERT_EQ(compute_node_info->GetIrOutputsNum(), 1); - auto td = compute_node_info->GetInputTdInfo(0); - ASSERT_NE(td, nullptr); - EXPECT_EQ(td->GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(td->GetOriginFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(td->GetStorageFormat(), ge::FORMAT_NC1HWC0); - auto expand_dims_type = td->GetExpandDimsType(); - Shape origin_shape({8, 3, 224, 224}); - Shape storage_shape; - expand_dims_type.Expand(origin_shape, storage_shape); - EXPECT_EQ(storage_shape, origin_shape); - - auto ins_info = compute_node_info->GetInputInstanceInfo(0); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 1); - EXPECT_EQ(ins_info->GetInstanceStart(), 0); - - auto out_ins_info = compute_node_info->GetOutputInstanceInfo(0); - ASSERT_NE(out_ins_info, nullptr); - EXPECT_EQ(out_ins_info->GetInstanceNum(), 1); - EXPECT_EQ(out_ins_info->GetInstanceStart(), 0); - - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrNum(), 0); - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrPointer(0), nullptr); -} -TEST_F(BgKernelContextExtendUT, BuildWithDynamicOutputs) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("x0", tensor_desc); - op_desc->AppendIrInput("x", ge::kIrInputDynamic); - op_desc->AddOutputDesc("y0", tensor_desc); - op_desc->AppendIrOutput("y", ge::kIrOutputDynamic); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - auto data_op_desc = std::make_shared("data", "Data"); - data_op_desc->AddOutputDesc("x", tensor_desc); - auto data_node = graph->AddNode(data_op_desc); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), node->GetInDataAnchor(0)); - - bg::BufferPool buffer_pool; - auto ret = bg::CreateComputeNodeInfo(node, buffer_pool); - ASSERT_NE(ret, nullptr); - auto compute_node_info = reinterpret_cast(ret.get()); - ASSERT_EQ(compute_node_info->GetInputsNum(), 1); - ASSERT_EQ(compute_node_info->GetOutputsNum(), 1); - ASSERT_EQ(compute_node_info->GetIrInputsNum(), 1); - ASSERT_EQ(compute_node_info->GetIrOutputsNum(), 1); - auto td = compute_node_info->GetInputTdInfo(0); - ASSERT_NE(td, nullptr); - EXPECT_EQ(td->GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(td->GetOriginFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(td->GetStorageFormat(), ge::FORMAT_NC1HWC0); - auto td_out = compute_node_info->GetOutputTdInfo(0); - ASSERT_NE(td_out, nullptr); - EXPECT_EQ(td_out->GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(td_out->GetOriginFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(td_out->GetStorageFormat(), ge::FORMAT_NC1HWC0); - - auto ins_info = compute_node_info->GetInputInstanceInfo(0); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 1); - EXPECT_EQ(ins_info->GetInstanceStart(), 0); - - auto out_ins_info = compute_node_info->GetOutputInstanceInfo(0); - ASSERT_NE(out_ins_info, nullptr); - EXPECT_EQ(out_ins_info->GetInstanceNum(), 1); - EXPECT_EQ(out_ins_info->GetInstanceStart(), 0); - - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrNum(), 0); - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrPointer(0), nullptr); -} -TEST_F(BgKernelContextExtendUT, BuildWithMultiInstanceDynamicOutputs) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("x0", tensor_desc); - op_desc->AddInputDesc("x1", tensor_desc); - op_desc->AddInputDesc("y0", tensor_desc); - op_desc->AddInputDesc("y1", tensor_desc); - op_desc->AddInputDesc("y2", tensor_desc); - op_desc->AppendIrInput("x", ge::kIrInputDynamic); - op_desc->AppendIrInput("y", ge::kIrInputDynamic); - - op_desc->AddOutputDesc("xx0", tensor_desc); - op_desc->AddOutputDesc("xx1", tensor_desc); - op_desc->AddOutputDesc("yy0", tensor_desc); - op_desc->AddOutputDesc("yy1", tensor_desc); - op_desc->AddOutputDesc("yy2", tensor_desc); - op_desc->AppendIrOutput("xx", ge::kIrOutputDynamic); - op_desc->AppendIrOutput("yy", ge::kIrOutputDynamic); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - auto data_op_desc = std::make_shared("data", "Data"); - data_op_desc->AddOutputDesc("x0", tensor_desc); - data_op_desc->AddOutputDesc("x1", tensor_desc); - data_op_desc->AddOutputDesc("y0", tensor_desc); - data_op_desc->AddOutputDesc("y1", tensor_desc); - data_op_desc->AddOutputDesc("y2", tensor_desc); - auto data_node = graph->AddNode(data_op_desc); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), node->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(1), node->GetInDataAnchor(1)); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(2), node->GetInDataAnchor(2)); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(3), node->GetInDataAnchor(3)); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(4), node->GetInDataAnchor(4)); - - bg::BufferPool buffer_pool; - auto ret = bg::CreateComputeNodeInfo(node, buffer_pool); - ASSERT_NE(ret, nullptr); - auto compute_node_info = reinterpret_cast(ret.get()); - ASSERT_EQ(compute_node_info->GetInputsNum(), 5); - ASSERT_EQ(compute_node_info->GetOutputsNum(), 5); - ASSERT_EQ(compute_node_info->GetIrInputsNum(), 2); - ASSERT_EQ(compute_node_info->GetIrOutputsNum(), 2); - auto td = compute_node_info->GetInputTdInfo(0); - ASSERT_NE(td, nullptr); - EXPECT_EQ(td->GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(td->GetOriginFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(td->GetStorageFormat(), ge::FORMAT_NC1HWC0); - - auto ins_info = compute_node_info->GetInputInstanceInfo(0); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 2); - EXPECT_EQ(ins_info->GetInstanceStart(), 0); - ins_info = compute_node_info->GetInputInstanceInfo(1); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 3); - EXPECT_EQ(ins_info->GetInstanceStart(), 2); - - auto out_ins_info = compute_node_info->GetOutputInstanceInfo(0); - ASSERT_NE(out_ins_info, nullptr); - EXPECT_EQ(out_ins_info->GetInstanceNum(), 2); - EXPECT_EQ(out_ins_info->GetInstanceStart(), 0); - out_ins_info = compute_node_info->GetOutputInstanceInfo(1); - ASSERT_NE(out_ins_info, nullptr); - EXPECT_EQ(out_ins_info->GetInstanceNum(), 3); - EXPECT_EQ(out_ins_info->GetInstanceStart(), 2); - - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrNum(), 0); - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrPointer(0), nullptr); -} -TEST_F(BgKernelContextExtendUT, BuildWithEmptyDynamicOutputs) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("x0", tensor_desc); - op_desc->AddInputDesc("x1", tensor_desc); - op_desc->AppendIrInput("y", ge::kIrInputDynamic); - op_desc->AppendIrInput("x", ge::kIrInputDynamic); - op_desc->AddOutputDesc("xx0", tensor_desc); - op_desc->AddOutputDesc("xx1", tensor_desc); - op_desc->AppendIrOutput("yy", ge::kIrOutputDynamic); - op_desc->AppendIrOutput("xx", ge::kIrOutputDynamic); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - auto data_op_desc = std::make_shared("data", "Data"); - data_op_desc->AddOutputDesc("x0", tensor_desc); - data_op_desc->AddOutputDesc("x1", tensor_desc); - auto data_node = graph->AddNode(data_op_desc); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), node->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(1), node->GetInDataAnchor(1)); - - bg::BufferPool buffer_pool; - auto ret = bg::CreateComputeNodeInfo(node, buffer_pool); - ASSERT_NE(ret, nullptr); - auto compute_node_info = reinterpret_cast(ret.get()); - ASSERT_EQ(compute_node_info->GetInputsNum(), 2); - ASSERT_EQ(compute_node_info->GetOutputsNum(), 2); - ASSERT_EQ(compute_node_info->GetIrInputsNum(), 2); - ASSERT_EQ(compute_node_info->GetIrOutputsNum(), 2); - auto td = compute_node_info->GetInputTdInfo(0); - ASSERT_NE(td, nullptr); - EXPECT_EQ(td->GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(td->GetOriginFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(td->GetStorageFormat(), ge::FORMAT_NC1HWC0); - - auto td_out = compute_node_info->GetOutputTdInfo(0); - ASSERT_NE(td_out, nullptr); - EXPECT_EQ(td_out->GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(td_out->GetOriginFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(td_out->GetStorageFormat(), ge::FORMAT_NC1HWC0); - - auto ins_info = compute_node_info->GetInputInstanceInfo(0); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 0); - ins_info = compute_node_info->GetInputInstanceInfo(1); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 2); - EXPECT_EQ(ins_info->GetInstanceStart(), 0); - - auto out_ins_info = compute_node_info->GetOutputInstanceInfo(0); - ASSERT_NE(out_ins_info, nullptr); - EXPECT_EQ(out_ins_info->GetInstanceNum(), 0); - out_ins_info = compute_node_info->GetOutputInstanceInfo(1); - ASSERT_NE(ins_info, nullptr); - EXPECT_EQ(ins_info->GetInstanceNum(), 2); - EXPECT_EQ(ins_info->GetInstanceStart(), 0); - - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrNum(), 0); - EXPECT_EQ(compute_node_info->GetAttrs()->GetAttrPointer(0), nullptr); -} -// todo lowering时,不需要构造attr -// todo infershape、tiling utils重新看一下输入是否正确 -// todo kernel中获取attr的方式变化 -} // namespace gert diff --git a/tests/ut/exe_graph/buffer_pool_unittest.cc b/tests/ut/exe_graph/buffer_pool_unittest.cc deleted file mode 100644 index c2342e758c3d6bc4f5f6e6644a886639abfb231b..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/buffer_pool_unittest.cc +++ /dev/null @@ -1,171 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/lowering/buffer_pool.h" -#include -#include "exe_graph/runtime/continuous_buffer.h" -namespace gert { -namespace { -constexpr size_t kLargeBufSizeThreshold = 1024U * 1024U; // 1M -} -using namespace bg; -class BufferPoolUT : public testing::Test {}; -TEST_F(BufferPoolUT, IdContinuous) { - BufferPool tp; - std::string large_str(kLargeBufSizeThreshold, 'a'); - EXPECT_EQ(tp.AddStr("Hello"), 0); - EXPECT_EQ(tp.AddStr(large_str.c_str()), 1); - EXPECT_EQ(tp.AddStr("World"), 2); - - auto text_holder = tp.Serialize(); - ASSERT_NE(text_holder, nullptr); - auto text = reinterpret_cast(text_holder.get()); - EXPECT_EQ(text->GetNum(), 3); - EXPECT_STREQ(text->Get(0), "Hello"); - EXPECT_STREQ(text->Get(1), large_str.c_str()); - EXPECT_STREQ(text->Get(2), "World"); -} -TEST_F(BufferPoolUT, Deduplication) { - BufferPool tp; - std::string large_str(kLargeBufSizeThreshold, 'a'); - EXPECT_EQ(tp.AddStr("Hello"), 0); - EXPECT_EQ(tp.AddStr("World"), 1); - EXPECT_EQ(tp.AddStr(large_str.c_str()), 2); - EXPECT_EQ(tp.AddStr("Hello"), 0); - EXPECT_EQ(tp.AddStr("Zero"), 3); - EXPECT_EQ(tp.AddStr(large_str.c_str()), 4); - - auto text_holder = tp.Serialize(); - ASSERT_NE(text_holder, nullptr); - auto text = reinterpret_cast(text_holder.get()); - EXPECT_EQ(text->GetNum(), 5); - EXPECT_STREQ(text->Get(0), "Hello"); - EXPECT_STREQ(text->Get(1), "World"); - EXPECT_STREQ(text->Get(2), large_str.c_str()); - EXPECT_STREQ(text->Get(3), "Zero"); - EXPECT_STREQ(text->Get(4), large_str.c_str()); - EXPECT_EQ(text->Get(5), nullptr); -} -TEST_F(BufferPoolUT, NonString) { - BufferPool tp; - char buf[] = "Hello\0World\0Zero"; - std::string large_str(kLargeBufSizeThreshold, 'a'); - large_str[1] = '\0'; - EXPECT_EQ(tp.AddBuf(reinterpret_cast(buf), 16), 0); - EXPECT_EQ(tp.AddStr("World"), 1); - EXPECT_EQ(tp.AddBuf(reinterpret_cast(large_str.c_str()), kLargeBufSizeThreshold), 2); - EXPECT_EQ(tp.AddStr("Hello"), 3); - EXPECT_EQ(tp.AddBuf(reinterpret_cast(large_str.c_str()), kLargeBufSizeThreshold), 4); - EXPECT_EQ(tp.AddStr("Zero"), 5); - EXPECT_EQ(tp.AddBuf(reinterpret_cast(buf), 16), 0); - - auto text_holder = tp.Serialize(); - ASSERT_NE(text_holder, nullptr); - auto text = reinterpret_cast(text_holder.get()); - EXPECT_EQ(text->GetNum(), 6); - size_t size; - EXPECT_EQ(memcmp(text->Get(0, size), buf, 16), 0); - EXPECT_EQ(size, 16); - EXPECT_STREQ(text->Get(1, size), "World"); - EXPECT_EQ(size, 6); - EXPECT_EQ(memcmp(text->Get(2, size), large_str.c_str(), kLargeBufSizeThreshold), 0); - EXPECT_EQ(size, kLargeBufSizeThreshold); - EXPECT_STREQ(text->Get(3, size), "Hello"); - EXPECT_EQ(size, 6); - EXPECT_EQ(memcmp(text->Get(4, size), large_str.c_str(), kLargeBufSizeThreshold), 0); - EXPECT_EQ(size, kLargeBufSizeThreshold); - EXPECT_STREQ(text->Get(5, size), "Zero"); - EXPECT_EQ(size, 5); -} -TEST_F(BufferPoolUT, CorrectLength) { - BufferPool tp; - char buf[] = "Hello\0World\0Zero"; - std::string large_str(kLargeBufSizeThreshold, 'a'); - large_str[1] = '\0'; - EXPECT_EQ(tp.GetSize(), 0); - EXPECT_EQ(tp.AddBuf(reinterpret_cast(buf), 16), 0); - EXPECT_EQ(tp.AddStr("World"), 1); - EXPECT_EQ(tp.AddBuf(reinterpret_cast(large_str.c_str()), kLargeBufSizeThreshold), 2); - EXPECT_EQ(tp.AddStr("Hello"), 3); - EXPECT_EQ(tp.AddBuf(reinterpret_cast(large_str.c_str()), kLargeBufSizeThreshold), 4); - EXPECT_EQ(tp.AddStr("Zero"), 5); - EXPECT_EQ(tp.AddBuf(reinterpret_cast(buf), 16), 0); - - size_t total_size; - auto text_holder = tp.Serialize(total_size); - ASSERT_NE(text_holder, nullptr); - auto text = reinterpret_cast(text_holder.get()); - - auto length = text->GetTotalLength(); - EXPECT_EQ(length, total_size); - auto another = std::unique_ptr(new uint8_t[length]); - ASSERT_NE(another, nullptr); - memcpy(another.get(), text_holder.get(), length); - text = reinterpret_cast(another.get()); - - EXPECT_EQ(text->GetNum(), 6); - size_t size; - EXPECT_EQ(memcmp(text->Get(0, size), buf, 16), 0); - EXPECT_EQ(size, 16); - EXPECT_STREQ(text->Get(1, size), "World"); - EXPECT_EQ(size, 6); - EXPECT_EQ(memcmp(text->Get(2, size), large_str.c_str(), kLargeBufSizeThreshold), 0); - EXPECT_EQ(size, kLargeBufSizeThreshold); - EXPECT_STREQ(text->Get(3, size), "Hello"); - EXPECT_EQ(size, 6); - EXPECT_EQ(memcmp(text->Get(4, size), large_str.c_str(), kLargeBufSizeThreshold), 0); - EXPECT_EQ(size, kLargeBufSizeThreshold); - EXPECT_STREQ(text->Get(5, size), "Zero"); - EXPECT_EQ(size, 5); - EXPECT_EQ(text->Get(6, size), nullptr); -} - -TEST_F(BufferPoolUT, BoundaryValue) { - BufferPool tp; - std::string large_str1(kLargeBufSizeThreshold, 'a'); - large_str1[1] = '\0'; - std::string large_str2(kLargeBufSizeThreshold - 1, 'b'); - std::string large_str3(kLargeBufSizeThreshold - 2, 'c'); - EXPECT_EQ(tp.AddBuf(reinterpret_cast(large_str1.c_str()), kLargeBufSizeThreshold), 0); - EXPECT_EQ(tp.AddBuf(reinterpret_cast(large_str1.c_str()), kLargeBufSizeThreshold), 1); - EXPECT_EQ(tp.AddBuf(reinterpret_cast(large_str1.c_str()), kLargeBufSizeThreshold - 1), 2); - EXPECT_EQ(tp.AddBuf(reinterpret_cast(large_str1.c_str()), kLargeBufSizeThreshold - 1), 2); - EXPECT_EQ(tp.AddStr(large_str2.c_str()), 3); - EXPECT_EQ(tp.AddStr(large_str2.c_str()), 4); - EXPECT_EQ(tp.AddStr(large_str3.c_str()), 5); - EXPECT_EQ(tp.AddStr(large_str3.c_str()), 5); - - EXPECT_NE(tp.GetBufById(0), nullptr); - EXPECT_NE(tp.GetBufById(1), nullptr); - EXPECT_NE(tp.GetBufById(2), nullptr); - EXPECT_NE(tp.GetBufById(3), nullptr); - EXPECT_NE(tp.GetBufById(4), nullptr); - EXPECT_NE(tp.GetBufById(5), nullptr); - EXPECT_EQ(tp.GetBufById(6), nullptr); - - auto text_holder = tp.Serialize(); - ASSERT_NE(text_holder, nullptr); - auto text = reinterpret_cast(text_holder.get()); - EXPECT_EQ(text->GetNum(), 6); - size_t size; - EXPECT_EQ(memcmp(text->Get(0, size), large_str1.c_str(), kLargeBufSizeThreshold), 0); - EXPECT_EQ(size, kLargeBufSizeThreshold); - EXPECT_EQ(memcmp(text->Get(1, size), large_str1.c_str(), kLargeBufSizeThreshold), 0); - EXPECT_EQ(size, kLargeBufSizeThreshold); - EXPECT_EQ(memcmp(text->Get(2, size), large_str1.c_str(), kLargeBufSizeThreshold - 1), 0); - EXPECT_EQ(size, kLargeBufSizeThreshold - 1); - EXPECT_STREQ(text->Get(3, size), large_str2.c_str()); - EXPECT_EQ(size, kLargeBufSizeThreshold); - EXPECT_STREQ(text->Get(4, size), large_str2.c_str()); - EXPECT_EQ(size, kLargeBufSizeThreshold); - EXPECT_STREQ(text->Get(5, size), large_str3.c_str()); - EXPECT_EQ(size, kLargeBufSizeThreshold - 1); - EXPECT_EQ(text->Get(6, size), nullptr); -} -} // namespace gert diff --git a/tests/ut/exe_graph/builtin_node_types_unittest.cc b/tests/ut/exe_graph/builtin_node_types_unittest.cc deleted file mode 100644 index 148d7b9052fbd8c8e8e552f347ad2ecc2abb4346..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/builtin_node_types_unittest.cc +++ /dev/null @@ -1,18 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/lowering/builtin_node_types.h" -#include -namespace gert { -class BuiltInNodeTypesUT : public testing::Test {}; -TEST_F(BuiltInNodeTypesUT, IsTypeOutputData) { - ASSERT_TRUE(IsTypeOutputData("OutputData")); - ASSERT_FALSE(IsTypeOutputData("Data")); -} -} // namespace gert diff --git a/tests/ut/exe_graph/checker_unittest.cc b/tests/ut/exe_graph/checker_unittest.cc deleted file mode 100644 index 26f7a9f60acabc14884a81609149661debc4a35f..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/checker_unittest.cc +++ /dev/null @@ -1,294 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "common/checker.h" -#include -#include "runtime/base.h" -namespace { -template -T JustReturn(T val) { - return val; -} - -ge::graphStatus StatusFuncUseStatusFunc(ge::graphStatus val) { - GE_ASSERT_SUCCESS(JustReturn(val)); - return ge::GRAPH_SUCCESS; -} -ge::graphStatus StatusFuncUseBoolFunc(bool val) { - GE_ASSERT_TRUE(JustReturn(val)); - return ge::GRAPH_SUCCESS; -} -ge::graphStatus StatusFuncUsePointerFunc(void *val) { - GE_ASSERT_NOTNULL(JustReturn(val)); - return ge::GRAPH_SUCCESS; -} -ge::graphStatus StatusFuncUseUniquePtrFunc(std::unique_ptr val) { - GE_ASSERT_NOTNULL(JustReturn(std::move(val))); - return ge::GRAPH_SUCCESS; -} -ge::graphStatus StatusFuncUseSharedPtrFunc(const std::shared_ptr &val) { - GE_ASSERT_NOTNULL(JustReturn(val)); - return ge::GRAPH_SUCCESS; -} -ge::graphStatus StatusFuncUseEOKFunc(int val) { - GE_ASSERT_EOK(JustReturn(val)); - return ge::GRAPH_SUCCESS; -} -ge::graphStatus StatusFuncUseRtFunc(int32_t val) { - GE_ASSERT_RT_OK(JustReturn(val)); - return ge::GRAPH_SUCCESS; -} -ge::graphStatus StatusFuncUseHyperStatusFunc(gert::HyperStatus val) { - GE_ASSERT_HYPER_SUCCESS(JustReturn(std::move(val))); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus StatusFuncUseStatus(ge::graphStatus val) { - GE_ASSERT_SUCCESS(val); - return ge::GRAPH_SUCCESS; -} -ge::graphStatus StatusFuncUseBool(bool val) { - GE_ASSERT_TRUE(val); - return ge::GRAPH_SUCCESS; -} -ge::graphStatus StatusFuncUsePointer(void *val) { - GE_ASSERT_NOTNULL(val); - return ge::GRAPH_SUCCESS; -} -ge::graphStatus StatusFuncUseUniquePtr(const std::unique_ptr &val) { - GE_ASSERT_NOTNULL(val); - return ge::GRAPH_SUCCESS; -} -ge::graphStatus StatusFuncUseSharedPtr(const std::shared_ptr &val) { - GE_ASSERT_NOTNULL(val); - return ge::GRAPH_SUCCESS; -} -ge::graphStatus StatusFuncUseEOK(int val) { - GE_ASSERT_EOK(val); - return ge::GRAPH_SUCCESS; -} -ge::graphStatus StatusFuncUseRt(int32_t val) { - GE_ASSERT_RT_OK(val); - return ge::GRAPH_SUCCESS; -} -ge::graphStatus StatusFuncUseHyperStatus(gert::HyperStatus val) { - GE_ASSERT_HYPER_SUCCESS(val); - return ge::GRAPH_SUCCESS; -} - -bool BoolFuncUseStatus(ge::graphStatus val) { - GE_ASSERT_SUCCESS(val); - return true; -} -bool BoolFuncUseBool(bool val) { - GE_ASSERT_TRUE(val); - return true; -} -bool BoolFuncUsePointer(void *val) { - GE_ASSERT_NOTNULL(val); - return true; -} -bool BoolFuncUseUniquePtr(const std::unique_ptr &val) { - GE_ASSERT_NOTNULL(val); - return true; -} -bool BoolFuncUseSharedPtr(const std::shared_ptr &val) { - GE_ASSERT_NOTNULL(val); - return true; -} -bool BoolFuncUseEOK(int val) { - GE_ASSERT_EOK(val); - return true; -} -bool BoolFuncUseRt(int32_t val) { - GE_ASSERT_RT_OK(val); - return true; -} -bool BoolFuncUseHyperStatus(gert::HyperStatus val) { - GE_ASSERT_HYPER_SUCCESS(val); - return true; -} - -int64_t g_a = 0xff; -void *PointerFuncUseStatus(ge::graphStatus val) { - GE_ASSERT_SUCCESS(val); - return (void*)&g_a; -} -void *PointerFuncUseBool(bool val) { - GE_ASSERT_TRUE(val); - return (void*)&g_a; -} -void *PointerFuncUsePointer(void *val) { - GE_ASSERT_NOTNULL(val); - return (void*)&g_a; -} -void *PointerFuncUseUniquePtr(const std::unique_ptr &val) { - GE_ASSERT_NOTNULL(val); - return (void*)&g_a; -} -void *PointerFuncUseSharedPtr(const std::shared_ptr &val) { - GE_ASSERT_NOTNULL(val); - return (void*)&g_a; -} -void *PointerFuncUseEOK(int val) { - GE_ASSERT_EOK(val); - return (void*)&g_a; -} -void *PointerFuncUseRt(int32_t val) { - GE_ASSERT_RT_OK(val); - return (void*)&g_a; -} -void *PointerFuncUseHyperStatus(gert::HyperStatus val) { - GE_ASSERT_HYPER_SUCCESS(val); - return (void*)&g_a; -} -} // namespace -class CheckerUT : public testing::Test {}; -TEST_F(CheckerUT, ReturnStatusOk) { - ASSERT_NE(StatusFuncUseStatus(ge::FAILED), ge::GRAPH_SUCCESS); - ASSERT_NE(StatusFuncUseStatus(ge::GRAPH_FAILED), ge::GRAPH_SUCCESS); - ASSERT_EQ(StatusFuncUseStatus(ge::GRAPH_SUCCESS), ge::GRAPH_SUCCESS); - - ASSERT_NE(StatusFuncUseBool(false), ge::GRAPH_SUCCESS); - ASSERT_EQ(StatusFuncUseBool(true), ge::GRAPH_SUCCESS); - - int64_t a; - ASSERT_NE(StatusFuncUsePointer(nullptr), ge::GRAPH_SUCCESS); - ASSERT_EQ(StatusFuncUsePointer(&a), ge::GRAPH_SUCCESS); - - std::unique_ptr b = std::unique_ptr(new uint8_t[100]); - ASSERT_NE(StatusFuncUseUniquePtr(nullptr), ge::GRAPH_SUCCESS); - ASSERT_EQ(StatusFuncUseUniquePtr(b), ge::GRAPH_SUCCESS); - - auto c = std::shared_ptr(new uint8_t[100]); - ASSERT_NE(StatusFuncUseSharedPtr(nullptr), ge::GRAPH_SUCCESS); - ASSERT_EQ(StatusFuncUseSharedPtr(c), ge::GRAPH_SUCCESS); - - ASSERT_NE(StatusFuncUseEOK(EINVAL), ge::GRAPH_SUCCESS); - ASSERT_EQ(StatusFuncUseEOK(EOK), ge::GRAPH_SUCCESS); - - ASSERT_NE(StatusFuncUseHyperStatus(gert::HyperStatus::ErrorStatus("hello")), ge::GRAPH_SUCCESS); - ASSERT_EQ(StatusFuncUseHyperStatus(gert::HyperStatus::Success()), ge::GRAPH_SUCCESS); - - ASSERT_NE(StatusFuncUseRt(RT_EXCEPTION_DEV_RUNNING_DOWN), ge::GRAPH_SUCCESS); - ASSERT_EQ(StatusFuncUseRt(RT_ERROR_NONE), ge::GRAPH_SUCCESS); -} - -TEST_F(CheckerUT, ReturnBoolOk) { - ASSERT_NE(BoolFuncUseStatus(ge::FAILED), true); - ASSERT_NE(BoolFuncUseStatus(ge::GRAPH_FAILED), true); - ASSERT_EQ(BoolFuncUseStatus(ge::GRAPH_SUCCESS), true); - - ASSERT_NE(BoolFuncUseBool(false), true); - ASSERT_EQ(BoolFuncUseBool(true), true); - - int64_t a; - ASSERT_NE(BoolFuncUsePointer(nullptr), true); - ASSERT_EQ(BoolFuncUsePointer(&a), true); - - std::unique_ptr b = std::unique_ptr(new uint8_t[100]); - ASSERT_NE(BoolFuncUseUniquePtr(nullptr), true); - ASSERT_EQ(BoolFuncUseUniquePtr(b), true); - - auto c = std::shared_ptr(new uint8_t[100]); - ASSERT_NE(BoolFuncUseSharedPtr(nullptr), true); - ASSERT_EQ(BoolFuncUseSharedPtr(c), true); - - ASSERT_NE(BoolFuncUseEOK(EINVAL), true); - ASSERT_EQ(BoolFuncUseEOK(EOK), true); - - ASSERT_NE(BoolFuncUseHyperStatus(gert::HyperStatus::ErrorStatus("hello")), true); - ASSERT_EQ(BoolFuncUseHyperStatus(gert::HyperStatus::Success()), true); - - ASSERT_NE(BoolFuncUseRt(RT_EXCEPTION_DEV_RUNNING_DOWN), true); - ASSERT_EQ(BoolFuncUseRt(RT_ERROR_NONE), true); -} - -TEST_F(CheckerUT, ReturnPointerOk) { - ASSERT_EQ(PointerFuncUseStatus(ge::FAILED), nullptr); - ASSERT_EQ(PointerFuncUseStatus(ge::GRAPH_FAILED), nullptr); - ASSERT_NE(PointerFuncUseStatus(ge::GRAPH_SUCCESS), nullptr); - - ASSERT_EQ(PointerFuncUseBool(false), nullptr); - ASSERT_NE(PointerFuncUseBool(true), nullptr); - - int64_t a; - ASSERT_EQ(PointerFuncUsePointer(nullptr), nullptr); - ASSERT_NE(PointerFuncUsePointer(&a), nullptr); - - std::unique_ptr b = std::unique_ptr(new uint8_t[100]); - ASSERT_EQ(PointerFuncUseUniquePtr(nullptr), nullptr); - ASSERT_NE(PointerFuncUseUniquePtr(b), nullptr); - - auto c = std::shared_ptr(new uint8_t[100]); - ASSERT_EQ(PointerFuncUseSharedPtr(nullptr), nullptr); - ASSERT_NE(PointerFuncUseSharedPtr(c), nullptr); - - ASSERT_EQ(PointerFuncUseEOK(EINVAL), nullptr); - ASSERT_NE(PointerFuncUseEOK(EOK), nullptr); - - ASSERT_EQ(PointerFuncUseHyperStatus(gert::HyperStatus::ErrorStatus("hello")), nullptr); - ASSERT_NE(PointerFuncUseHyperStatus(gert::HyperStatus::Success()), nullptr); - - ASSERT_EQ(PointerFuncUseRt(RT_EXCEPTION_DEV_RUNNING_DOWN), nullptr); - ASSERT_NE(PointerFuncUseRt(RT_ERROR_NONE), nullptr); -} - -TEST_F(CheckerUT, ReturnInFunc) { - ASSERT_NE(StatusFuncUseStatusFunc(ge::FAILED), ge::GRAPH_SUCCESS); - ASSERT_NE(StatusFuncUseStatusFunc(ge::GRAPH_FAILED), ge::GRAPH_SUCCESS); - ASSERT_EQ(StatusFuncUseStatusFunc(ge::GRAPH_SUCCESS), ge::GRAPH_SUCCESS); - - ASSERT_NE(StatusFuncUseBoolFunc(false), ge::GRAPH_SUCCESS); - ASSERT_EQ(StatusFuncUseBoolFunc(true), ge::GRAPH_SUCCESS); - - int64_t a; - ASSERT_NE(StatusFuncUsePointerFunc(nullptr), ge::GRAPH_SUCCESS); - ASSERT_EQ(StatusFuncUsePointerFunc(&a), ge::GRAPH_SUCCESS); - - std::unique_ptr b = std::unique_ptr(new uint8_t[100]); - ASSERT_NE(StatusFuncUseUniquePtrFunc(nullptr), ge::GRAPH_SUCCESS); - ASSERT_EQ(StatusFuncUseUniquePtrFunc(std::move(b)), ge::GRAPH_SUCCESS); - - auto c = std::shared_ptr(new uint8_t[100]); - ASSERT_NE(StatusFuncUseSharedPtrFunc(nullptr), ge::GRAPH_SUCCESS); - ASSERT_EQ(StatusFuncUseSharedPtrFunc(c), ge::GRAPH_SUCCESS); - - ASSERT_NE(StatusFuncUseEOKFunc(EINVAL), ge::GRAPH_SUCCESS); - ASSERT_EQ(StatusFuncUseEOKFunc(EOK), ge::GRAPH_SUCCESS); - - ASSERT_NE(StatusFuncUseHyperStatusFunc(gert::HyperStatus::ErrorStatus("hello")), ge::GRAPH_SUCCESS); - ASSERT_EQ(StatusFuncUseHyperStatusFunc(gert::HyperStatus::Success()), ge::GRAPH_SUCCESS); - - ASSERT_NE(StatusFuncUseRtFunc(RT_EXCEPTION_DEV_RUNNING_DOWN), ge::GRAPH_SUCCESS); - ASSERT_EQ(StatusFuncUseRtFunc(RT_ERROR_NONE), ge::GRAPH_SUCCESS); -} - -TEST_F(CheckerUT, MicroTest) { // Keep this the last case!!! - std::string error_msg; -#ifdef GELOGE -#undef GELOGE -#endif -#define GELOGE(v, ...) error_msg = std::string(CreateErrorMsg(__VA_ARGS__).data()) - [&error_msg]() { GE_ASSERT(false); }(); - EXPECT_EQ(error_msg, "Assert false failed"); - [&error_msg]() { GE_ASSERT(false, "Something error"); }(); - EXPECT_EQ(error_msg, "Something error"); - [&error_msg]() { GE_ASSERT(false, "%s error", "Many things"); }(); - EXPECT_EQ(error_msg, "Many things error"); - - [&error_msg]() { GE_ASSERT_NOTNULL(nullptr); }(); - EXPECT_EQ(error_msg, "Assert ((nullptr) != nullptr) failed"); - [&error_msg]() { GE_ASSERT_NOTNULL(nullptr, "%s error", "Nullptr"); }(); - EXPECT_EQ(error_msg, "Nullptr error"); - - [&error_msg]()->bool { GE_ASSERT_EQ(0, 1); return true;}(); - EXPECT_EQ(error_msg, "Assert (0 == 1) failed, expect 1 actual 0"); -#undef GELOGE -} diff --git a/tests/ut/exe_graph/compute_node_info_unittest.cc b/tests/ut/exe_graph/compute_node_info_unittest.cc deleted file mode 100644 index 02fc96f9e65ed894388833dccb882b330eae4c5e..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/compute_node_info_unittest.cc +++ /dev/null @@ -1,280 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/runtime/compute_node_info.h" -#include -#include "faker/kernel_run_context_faker.h" -#include "graph/debug/ge_util.h" -namespace gert { -class ComputeNodeInfoUT : public testing::Test {}; -TEST_F(ComputeNodeInfoUT, GetInputFormatOk) { - auto context_holder = KernelRunContextFaker() - .IrInputNum(2) - .NodeIoNum(2, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z) - .Build(); - - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - auto compute_node_info = context->GetComputeNodeInfo(); - ASSERT_NE(compute_node_info, nullptr); - - auto td = compute_node_info->GetInputTdInfo(0); - ASSERT_NE(td, nullptr); - EXPECT_EQ(td->GetStorageFormat(), ge::FORMAT_NC1HWC0); - EXPECT_EQ(td->GetOriginFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(td->GetFormat().GetStorageFormat(), ge::FORMAT_NC1HWC0); - EXPECT_EQ(td->GetFormat().GetOriginFormat(), ge::FORMAT_NCHW); - - td = compute_node_info->GetInputTdInfo(1); - ASSERT_NE(td, nullptr); - EXPECT_EQ(td->GetStorageFormat(), ge::FORMAT_FRACTAL_Z); - EXPECT_EQ(td->GetOriginFormat(), ge::FORMAT_HWCN); - EXPECT_EQ(td->GetFormat().GetStorageFormat(), ge::FORMAT_FRACTAL_Z); - EXPECT_EQ(td->GetFormat().GetOriginFormat(), ge::FORMAT_HWCN); - - EXPECT_EQ(compute_node_info->GetInputTdInfo(2), nullptr); -} -TEST_F(ComputeNodeInfoUT, GetOutputFormatOk) { - auto context_holder = KernelRunContextFaker() - .IrInputNum(2) - .NodeIoNum(2, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z) - .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .Build(); - - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - auto compute_node_info = context->GetComputeNodeInfo(); - ASSERT_NE(compute_node_info, nullptr); - - auto td = compute_node_info->GetOutputTdInfo(0); - ASSERT_NE(td, nullptr); - EXPECT_EQ(td->GetStorageFormat(), ge::FORMAT_NC1HWC0); - EXPECT_EQ(td->GetOriginFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(td->GetFormat().GetStorageFormat(), ge::FORMAT_NC1HWC0); - EXPECT_EQ(td->GetFormat().GetOriginFormat(), ge::FORMAT_NCHW); - - EXPECT_EQ(compute_node_info->GetOutputTdInfo(1), nullptr); -} -TEST_F(ComputeNodeInfoUT, GetNodeNameTypeOk) { - auto context_holder = KernelRunContextFaker() - .IrInputNum(2) - .NodeIoNum(2, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z) - .Build(); - - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - auto compute_node_info = context->GetComputeNodeInfo(); - ASSERT_NE(compute_node_info, nullptr); - - EXPECT_STREQ(compute_node_info->GetNodeName(), "node"); - EXPECT_STREQ(compute_node_info->GetNodeType(), "node"); -} -TEST_F(ComputeNodeInfoUT, GetInputInfoOk) { - auto context_holder = KernelRunContextFaker() - .IrInputNum(2) - .NodeIoNum(2, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z) - .Build(); - - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - auto compute_node_info = context->GetComputeNodeInfo(); - ASSERT_NE(compute_node_info, nullptr); - - EXPECT_EQ(compute_node_info->GetIrInputsNum(), 2); - EXPECT_EQ(compute_node_info->GetInputsNum(), 2); - EXPECT_EQ(compute_node_info->GetOutputsNum(), 1); -} -TEST_F(ComputeNodeInfoUT, GetInputInstanceOk) { - auto context_holder = KernelRunContextFaker() - .IrInputNum(2) - .NodeIoNum(2, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z) - .Build(); - - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - auto compute_node_info = context->GetComputeNodeInfo(); - ASSERT_NE(compute_node_info, nullptr); - - auto ins = compute_node_info->GetInputInstanceInfo(0); - ASSERT_NE(ins, nullptr); - EXPECT_EQ(ins->GetInstanceNum(), 1); - EXPECT_EQ(ins->GetInstanceStart(), 0); - - ins = compute_node_info->GetInputInstanceInfo(1); - ASSERT_NE(ins, nullptr); - EXPECT_EQ(ins->GetInstanceNum(), 1); - EXPECT_EQ(ins->GetInstanceStart(), 1); - - EXPECT_EQ(compute_node_info->GetInputInstanceInfo(2), nullptr); -} -TEST_F(ComputeNodeInfoUT, GetDynamicInputInstanceOk) { - auto context_holder = KernelRunContextFaker() - .IrInstanceNum({2, 0, 1}) - .NodeIoNum(3, 1) - .Build(); - - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - auto compute_node_info = context->GetComputeNodeInfo(); - ASSERT_NE(compute_node_info, nullptr); - - auto ins = compute_node_info->GetInputInstanceInfo(0); - ASSERT_NE(ins, nullptr); - EXPECT_EQ(ins->GetInstanceNum(), 2); - EXPECT_EQ(ins->GetInstanceStart(), 0); - - ins = compute_node_info->GetInputInstanceInfo(1); - ASSERT_NE(ins, nullptr); - EXPECT_EQ(ins->GetInstanceNum(), 0); - EXPECT_EQ(ins->GetInstanceStart(), 2); - - ins = compute_node_info->GetInputInstanceInfo(2); - ASSERT_NE(ins, nullptr); - EXPECT_EQ(ins->GetInstanceNum(), 1); - EXPECT_EQ(ins->GetInstanceStart(), 2); - - EXPECT_EQ(compute_node_info->GetInputInstanceInfo(3), nullptr); -} -TEST_F(ComputeNodeInfoUT, GetOutputInstanceOk) { - auto context_holder = KernelRunContextFaker() - .IrInputNum(2) - .IrOutputNum(2) - .NodeIoNum(2, 2) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z) - .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z) - .Build(); - - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - auto compute_node_info = context->GetComputeNodeInfo(); - ASSERT_NE(compute_node_info, nullptr); - - auto ins = compute_node_info->GetOutputInstanceInfo(0); - ASSERT_NE(ins, nullptr); - EXPECT_EQ(ins->GetInstanceNum(), 1); - EXPECT_EQ(ins->GetInstanceStart(), 0); - - ins = compute_node_info->GetOutputInstanceInfo(1); - ASSERT_NE(ins, nullptr); - EXPECT_EQ(ins->GetInstanceNum(), 1); - EXPECT_EQ(ins->GetInstanceStart(), 1); - - EXPECT_EQ(compute_node_info->GetOutputInstanceInfo(2), nullptr); -} -TEST_F(ComputeNodeInfoUT, GetDynamicOutputInstanceOk) { - auto context_holder = KernelRunContextFaker() - .IrInstanceNum({2, 0, 1}) - .IrOutputInstanceNum({2, 1}) - .NodeIoNum(3, 3) - .Build(); - - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - auto compute_node_info = context->GetComputeNodeInfo(); - ASSERT_NE(compute_node_info, nullptr); - - auto ins = compute_node_info->GetOutputInstanceInfo(0); - ASSERT_NE(ins, nullptr); - EXPECT_EQ(ins->GetInstanceNum(), 2); - EXPECT_EQ(ins->GetInstanceStart(), 0); - - ins = compute_node_info->GetOutputInstanceInfo(1); - ASSERT_NE(ins, nullptr); - EXPECT_EQ(ins->GetInstanceNum(), 1); - EXPECT_EQ(ins->GetInstanceStart(), 2); - - EXPECT_EQ(compute_node_info->GetOutputInstanceInfo(2), nullptr); -} -TEST_F(ComputeNodeInfoUT, GetAttrsOk) { - auto context_holder = KernelRunContextFaker() - .IrInstanceNum({2, 0, 1}) - .NodeIoNum(3, 1) - .NodeAttrs({ - {"i", ge::AnyValue::CreateFrom(static_cast(10))}, - {"li", ge::AnyValue::CreateFrom(std::vector({10,20,30}))} - }) - .Build(); - - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - auto compute_node_info = context->GetComputeNodeInfo(); - ASSERT_NE(compute_node_info, nullptr); - - auto attrs = compute_node_info->GetAttrs(); - ASSERT_NE(attrs, nullptr); - - EXPECT_EQ(attrs->GetAttrNum(), 2); - ASSERT_NE(attrs->GetAttrPointer(0), nullptr); - EXPECT_EQ(*attrs->GetAttrPointer(0), 10); - - ASSERT_NE(attrs->GetAttrPointer(1), nullptr); - auto vec = attrs->GetAttrPointer(1); - EXPECT_EQ(reinterpret_cast(vec->GetData())[0], 10); - EXPECT_EQ(reinterpret_cast(vec->GetData())[1], 20); - EXPECT_EQ(reinterpret_cast(vec->GetData())[2], 30); -} - -TEST_F(ComputeNodeInfoUT, GetAttrsEmptyAxes) { -auto context_holder = KernelRunContextFaker() - .IrInstanceNum({2, 0, 1}) - .NodeIoNum(3, 1) - .NodeAttrs({ - {"i", ge::AnyValue::CreateFrom(static_cast(10))}, - {"axes", ge::AnyValue::CreateFrom(std::vector({}))} - }) - .Build(); - -auto context = context_holder.GetContext(); -ASSERT_NE(context, nullptr); -auto compute_node_info = context->GetComputeNodeInfo(); -ASSERT_NE(compute_node_info, nullptr); - -auto attrs = compute_node_info->GetAttrs(); -ASSERT_NE(attrs, nullptr); -EXPECT_EQ(attrs->GetAttrNum(), 2); -EXPECT_EQ(*attrs->GetAttrPointer(0), 10); -auto vec = attrs->GetAttrPointer(1); -EXPECT_NE(vec, nullptr); -EXPECT_EQ(vec->GetSize(), 0); -} - -TEST_F(ComputeNodeInfoUT, InitAndCalcSizeDefault) { - const size_t ir_input_num = 2U; - const size_t inputs_num = 2U; - const size_t outputs_num = 2U; - const char * node_name = "test"; - const char * node_type = "Test"; - - size_t total_size = 0U; - EXPECT_EQ(ComputeNodeInfo::CalcSize(ir_input_num, inputs_num, outputs_num, total_size), ge::SUCCESS); - - auto compute_node_info_holder = ge::ComGraphMakeUnique(total_size); - EXPECT_NE(compute_node_info_holder, nullptr); - auto compute_node_info = ge::PtrToPtr(compute_node_info_holder.get()); - compute_node_info->Init(ir_input_num, inputs_num, outputs_num, node_name, node_type); - - EXPECT_EQ(compute_node_info->GetIrInputsNum(), ir_input_num); - EXPECT_EQ(compute_node_info->GetIrOutputsNum(), 0U); - EXPECT_EQ(compute_node_info->GetInputsNum(), inputs_num); - EXPECT_EQ(compute_node_info->GetOutputsNum(), outputs_num); - EXPECT_STREQ(compute_node_info->GetNodeName(), node_name); - EXPECT_STREQ(compute_node_info->GetNodeType(), node_type); -} -} // namespace gert diff --git a/tests/ut/exe_graph/continuous_vector_unittest.cc b/tests/ut/exe_graph/continuous_vector_unittest.cc deleted file mode 100644 index 19b96b9512b003de6f91abf59da5c6368d3a60cd..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/continuous_vector_unittest.cc +++ /dev/null @@ -1,157 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/runtime/continuous_vector.h" -#include "securec.h" -#include -namespace gert { -class ContinuousVectorUT : public testing::Test {}; -TEST_F(ContinuousVectorUT, CreateOk) { - auto vec_holder = ContinuousVector::Create(16); - auto vec = reinterpret_cast(vec_holder.get()); - auto c_vec = reinterpret_cast(vec_holder.get()); - ASSERT_NE(vec, nullptr); - EXPECT_EQ(vec->GetSize(), 0); - EXPECT_EQ(vec->GetCapacity(), 16); - EXPECT_EQ(c_vec->GetSize(), 0); - EXPECT_EQ(c_vec->GetCapacity(), 16); -} -TEST_F(ContinuousVectorUT, SetSizeOk) { - auto vec_holder = ContinuousVector::Create(16); - auto vec = reinterpret_cast(vec_holder.get()); - ASSERT_NE(vec, nullptr); - EXPECT_EQ(vec->GetSize(), 0); - EXPECT_EQ(vec->SetSize(8), ge::GRAPH_SUCCESS); - EXPECT_EQ(vec->GetSize(), 8); - EXPECT_EQ(vec->SetSize(16), ge::GRAPH_SUCCESS); - EXPECT_EQ(vec->GetSize(), 16); - EXPECT_EQ(vec->SetSize(0), ge::GRAPH_SUCCESS); - EXPECT_EQ(vec->GetSize(), 0); -} -TEST_F(ContinuousVectorUT, SetSizeFailedOutOfBounds) { - auto vec_holder = ContinuousVector::Create(16); - auto vec = reinterpret_cast(vec_holder.get()); - ASSERT_NE(vec, nullptr); - EXPECT_EQ(vec->GetSize(), 0); - EXPECT_NE(vec->SetSize(17), ge::GRAPH_SUCCESS); -} -TEST_F(ContinuousVectorUT, CreateNone) { - auto vec_holder = ContinuousVector::Create(0); - auto vec = reinterpret_cast(vec_holder.get()); - ASSERT_NE(vec, nullptr); - EXPECT_EQ(vec->GetSize(), 0); - EXPECT_EQ(vec->GetCapacity(), 0); -} -TEST_F(ContinuousVectorUT, WriteOk) { - auto vec_holder = ContinuousVector::Create(2); - auto vec = reinterpret_cast(vec_holder.get()); - ASSERT_NE(vec, nullptr); - EXPECT_EQ(vec->GetSize(), 0); - EXPECT_EQ(vec->GetCapacity(), 2); - - EXPECT_EQ(vec->SetSize(2), ge::GRAPH_SUCCESS); - auto data = reinterpret_cast(vec->MutableData()); - data[0] = 1024; - data[1] = 2048; - EXPECT_EQ(vec->GetSize(), 2); - EXPECT_EQ(reinterpret_cast(vec->GetData())[0], 1024); - EXPECT_EQ(reinterpret_cast(vec->GetData())[1], 2048); -} -TEST_F(ContinuousVectorUT, TypedOk) { - auto vec_holder = ContinuousVector::Create(16); - auto vec = reinterpret_cast(vec_holder.get()); - ASSERT_NE(vec, nullptr); - EXPECT_EQ(vec->SetSize(4), ge::GRAPH_SUCCESS); - auto data = reinterpret_cast(vec->MutableData()); - data[0] = 1024; - data[1] = 2048; - data[2] = 4096; - data[3] = 8192; - - auto t_vec = reinterpret_cast *>(vec); - EXPECT_EQ(t_vec->GetSize(), 4); - EXPECT_EQ(t_vec->GetCapacity(), 16); - EXPECT_EQ(t_vec->GetData()[0], 1024); - EXPECT_EQ(t_vec->GetData()[1], 2048); - EXPECT_EQ(t_vec->GetData()[2], 4096); - EXPECT_EQ(t_vec->GetData()[3], 8192); - auto mt_vec = reinterpret_cast *>(vec); - EXPECT_EQ(mt_vec->MutableData()[0], 1024); - EXPECT_EQ(mt_vec->MutableData()[1], 2048); - EXPECT_EQ(mt_vec->MutableData()[2], 4096); - EXPECT_EQ(mt_vec->MutableData()[3], 8192); -} - -TEST_F(ContinuousVectorUT, GetOverHeadLengthOk) { - EXPECT_EQ(ContinuousVectorVector::GetOverHeadLength(0), sizeof(ContinuousVectorVector)); - EXPECT_EQ(ContinuousVectorVector::GetOverHeadLength(1), sizeof(ContinuousVectorVector)); - EXPECT_EQ(ContinuousVectorVector::GetOverHeadLength(3), sizeof(ContinuousVectorVector) + (3U - 1U) * sizeof(size_t)); -} - -TEST_F(ContinuousVectorUT, AddOk) { - std::vector> vector_vector_list{{0, 2}, {1, 1}, {2, 4, 3}, {}, {1}}; - - size_t inner_vector_num = vector_vector_list.size(); - size_t total_length = ContinuousVectorVector::GetOverHeadLength(inner_vector_num); - for (const auto &inner_vector : vector_vector_list) { - size_t inner_vector_length = sizeof(ContinuousVector) + sizeof(int64_t) * inner_vector.size(); - total_length += inner_vector_length; - } - - std::vector buf(total_length); - auto cvv = new (buf.data()) ContinuousVectorVector(); - ASSERT_NE(cvv, nullptr); - cvv->Init(inner_vector_num); - - for (const auto &inner_vector : vector_vector_list) { - auto cv = cvv->Add(inner_vector.size()); - cv->Init(inner_vector.size()); - cv->SetSize(inner_vector.size()); - if (inner_vector.empty()) { - continue; - } - auto ret = memcpy_s(cv->MutableData(), cv->GetCapacity() * sizeof(int64_t), - inner_vector.data(), inner_vector.size() * sizeof(int64_t)); - ASSERT_EQ(ret, EOK); - } - - auto new_cvv = reinterpret_cast(buf.data()); - ASSERT_EQ(new_cvv->GetSize(), vector_vector_list.size()); - for (size_t i = 0U; i < vector_vector_list.size(); ++i) { - auto new_cv = new_cvv->Get(i); - ASSERT_EQ(new_cv->GetSize(), vector_vector_list[i].size()); - ASSERT_EQ(new_cv->GetSize(), new_cv->GetCapacity()); - const int64_t *data = reinterpret_cast(new_cv->GetData()); - for (size_t j = 0U; j < new_cv->GetSize(); ++j) { - EXPECT_EQ(data[j], vector_vector_list[i][j]); - } - } -} - -TEST_F(ContinuousVectorUT, AddFailed_WithoutInit) { - std::vector buf(100); - auto cvv = new (buf.data()) ContinuousVectorVector(); - ASSERT_NE(cvv, nullptr); - auto cv = cvv->Add(2); - EXPECT_EQ(cv, nullptr); -} - -TEST_F(ContinuousVectorUT, AddFailed_OutBounds) { - std::vector buf(500); - auto cvv = new (buf.data()) ContinuousVectorVector(); - ASSERT_NE(cvv, nullptr); - cvv->Init(2); - auto cv = cvv->Add(2); - EXPECT_NE(cv, nullptr); - cv = cvv->Add(2); - EXPECT_NE(cv, nullptr); - cv = cvv->Add(2); - EXPECT_EQ(cv, nullptr); -} -} // namespace gert diff --git a/tests/ut/exe_graph/data_dependent_interpreter_unittest.cc b/tests/ut/exe_graph/data_dependent_interpreter_unittest.cc deleted file mode 100644 index 709d90c5a412ed70a9231527ab21522edaf9d2ca..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/data_dependent_interpreter_unittest.cc +++ /dev/null @@ -1,545 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/lowering/data_dependent_interpreter.h" -#include -#include "graph/node.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/graph_utils.h" -#include "register/op_impl_registry.h" -#include "faker/node_faker.h" -#include "faker/space_registry_faker.h" -#include "common/checker.h" - -namespace gert { -namespace { -// todo 把注册做成stub的庄能力,不影响其他流程 -IMPL_OP(DDIT02).InputsDataDependency({0, 2}); -IMPL_OP(DDIT1).InputsDataDependency({1}); -IMPL_OP(DDIT3).TilingInputsDataDependency({1, 2}); -IMPL_OP(DDIT4); -IMPL_OP(DDIT5).TilingInputsDataDependency({1, 2}, {TilingPlacement::TILING_ON_AICPU}); - -/* - * ub graph: - * - * NetOutput - * | - * Foo - * | - * ddit02 - * / | \ - * data0 data1 data2 - */ -ge::NodePtr FakeUbNode02() { - auto node = ComputeNodeFaker().NameAndType("UbNode", "DDIT02").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - node->GetOpDesc()->SetOpInferDepends({"x", "z"}); - - auto ub_graph = std::make_shared("ub_graph"); - auto data0 = ComputeNodeFaker(ub_graph) - .NameAndType("Data0", "Data") - .Attr(ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), 0) - .IoNum(0, 1) - .Build(); - auto data1 = ComputeNodeFaker(ub_graph) - .NameAndType("Data1", "Data") - .Attr(ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), 1) - .IoNum(0, 1) - .Build(); - auto data2 = ComputeNodeFaker(ub_graph) - .NameAndType("Data2", "Data") - .Attr(ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), 2) - .IoNum(0, 1) - .Build(); - - auto ddit02 = - ComputeNodeFaker(ub_graph).NameAndType("Test", "DDIT02").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - ddit02->GetOpDesc()->SetOpInferDepends({"x", "z"}); - - GE_ASSERT_SUCCESS(ge::GraphUtils::AddEdge(data0->GetOutDataAnchor(0), ddit02->GetInDataAnchor(0))); - GE_ASSERT_SUCCESS(ge::GraphUtils::AddEdge(data1->GetOutDataAnchor(0), ddit02->GetInDataAnchor(1))); - GE_ASSERT_SUCCESS(ge::GraphUtils::AddEdge(data2->GetOutDataAnchor(0), ddit02->GetInDataAnchor(2))); - - auto foo = ComputeNodeFaker(ub_graph).NameAndType("Foo", "Foo").IoNum(1, 1).Build(); - GE_ASSERT_SUCCESS(ge::GraphUtils::AddEdge(ddit02->GetOutDataAnchor(0), foo->GetInDataAnchor(0))); - - auto netoutput = ComputeNodeFaker(ub_graph).NameAndType("NetOutput", "NetOutput").IoNum(1, 1).Build(); - GE_ASSERT_SUCCESS(ge::GraphUtils::AddEdge(foo->GetOutDataAnchor(0), netoutput->GetInDataAnchor(0))); - - GE_DUMP(ub_graph, "TestUbGraph"); - - GE_ASSERT_TRUE(ge::AttrUtils::SetGraph(node->GetOpDesc(), "_original_fusion_graph", ub_graph)); - - return node; -} - -/* - * ub graph: - * - * NetOutput - * | - * Foo - * | - * ddit02 - * / | \ - * const0 data1 data2 - */ -ge::NodePtr FakeUbNode02OnlyHasTwoData() { - auto node = ComputeNodeFaker().NameAndType("UbNode", "DDIT02").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - node->GetOpDesc()->SetOpInferDepends({"x", "z"}); - - auto ub_graph = std::make_shared("ub_graph"); - auto data2 = ComputeNodeFaker(ub_graph) - .NameAndType("const0", "Const") - .IoNum(0, 1) - .Build(); - auto data0 = ComputeNodeFaker(ub_graph) - .NameAndType("Data1", "Data") - .Attr(ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), 0) - .IoNum(0, 1) - .Build(); - auto data1 = ComputeNodeFaker(ub_graph) - .NameAndType("Data2", "Data") - .Attr(ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), 1) - .IoNum(0, 1) - .Build(); - - auto ddit02 = - ComputeNodeFaker(ub_graph).NameAndType("Test", "DDIT02").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - ddit02->GetOpDesc()->SetOpInferDepends({"x", "z"}); - - GE_ASSERT_SUCCESS(ge::GraphUtils::AddEdge(data0->GetOutDataAnchor(0), ddit02->GetInDataAnchor(0))); - GE_ASSERT_SUCCESS(ge::GraphUtils::AddEdge(data1->GetOutDataAnchor(0), ddit02->GetInDataAnchor(1))); - GE_ASSERT_SUCCESS(ge::GraphUtils::AddEdge(data2->GetOutDataAnchor(0), ddit02->GetInDataAnchor(2))); - - auto foo = ComputeNodeFaker(ub_graph).NameAndType("Foo", "Foo").IoNum(1, 1).Build(); - GE_ASSERT_SUCCESS(ge::GraphUtils::AddEdge(ddit02->GetOutDataAnchor(0), foo->GetInDataAnchor(0))); - - auto netoutput = ComputeNodeFaker(ub_graph).NameAndType("NetOutput", "NetOutput").IoNum(1, 1).Build(); - GE_ASSERT_SUCCESS(ge::GraphUtils::AddEdge(foo->GetOutDataAnchor(0), netoutput->GetInDataAnchor(0))); - - GE_DUMP(ub_graph, "TestUbGraph"); - - GE_ASSERT_TRUE(ge::AttrUtils::SetGraph(node->GetOpDesc(), "_original_fusion_graph", ub_graph)); - - return node; -} -/* - * ub graph: - * - * NetOutput - * | - * Foo - * | - * ddit02 - * / | \ - * data0 data1 data2 - */ -ge::NodePtr FakeUbNode02DataDoesNotHasIndex() { - auto node = ComputeNodeFaker().NameAndType("UbNode", "DDIT02").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - node->GetOpDesc()->SetOpInferDepends({"x", "z"}); - - auto ub_graph = std::make_shared("ub_graph"); - auto data0 = ComputeNodeFaker(ub_graph) - .NameAndType("Data0", "Data") - .IoNum(0, 1) - .Build(); - auto data1 = ComputeNodeFaker(ub_graph) - .NameAndType("Data1", "Data") - .IoNum(0, 1) - .Build(); - auto data2 = ComputeNodeFaker(ub_graph) - .NameAndType("Data2", "Data") - .IoNum(0, 1) - .Build(); - - auto ddit02 = - ComputeNodeFaker(ub_graph).NameAndType("Test", "DDIT02").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - ddit02->GetOpDesc()->SetOpInferDepends({"x", "z"}); - - GE_ASSERT_SUCCESS(ge::GraphUtils::AddEdge(data0->GetOutDataAnchor(0), ddit02->GetInDataAnchor(0))); - GE_ASSERT_SUCCESS(ge::GraphUtils::AddEdge(data1->GetOutDataAnchor(0), ddit02->GetInDataAnchor(1))); - GE_ASSERT_SUCCESS(ge::GraphUtils::AddEdge(data2->GetOutDataAnchor(0), ddit02->GetInDataAnchor(2))); - - auto foo = ComputeNodeFaker(ub_graph).NameAndType("Foo", "Foo").IoNum(1, 1).Build(); - GE_ASSERT_SUCCESS(ge::GraphUtils::AddEdge(ddit02->GetOutDataAnchor(0), foo->GetInDataAnchor(0))); - - auto netoutput = ComputeNodeFaker(ub_graph).NameAndType("NetOutput", "NetOutput").IoNum(1, 1).Build(); - GE_ASSERT_SUCCESS(ge::GraphUtils::AddEdge(foo->GetOutDataAnchor(0), netoutput->GetInDataAnchor(0))); - - GE_DUMP(ub_graph, "TestUbGraph"); - - GE_ASSERT_TRUE(ge::AttrUtils::SetGraph(node->GetOpDesc(), "_original_fusion_graph", ub_graph)); - - return node; -} -/* - * ub graph: - * - * NetOutput - * | - * Foo - * | - * ddit02 - * / | \ - * data0 data1 data2 - */ -ge::NodePtr FakeUbNode02TypeFoo() { - auto node = ComputeNodeFaker().NameAndType("UbNode", "Foo").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - - auto ub_graph = std::make_shared("ub_graph"); - auto data0 = ComputeNodeFaker(ub_graph) - .NameAndType("Data0", "Data") - .Attr(ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), 0) - .IoNum(0, 1) - .Build(); - auto data1 = ComputeNodeFaker(ub_graph) - .NameAndType("Data1", "Data") - .Attr(ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), 1) - .IoNum(0, 1) - .Build(); - auto data2 = ComputeNodeFaker(ub_graph) - .NameAndType("Data2", "Data") - .Attr(ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), 2) - .IoNum(0, 1) - .Build(); - - auto ddit02 = - ComputeNodeFaker(ub_graph).NameAndType("Test", "DDIT02").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - ddit02->GetOpDesc()->SetOpInferDepends({"x", "z"}); - - GE_ASSERT_SUCCESS(ge::GraphUtils::AddEdge(data0->GetOutDataAnchor(0), ddit02->GetInDataAnchor(0))); - GE_ASSERT_SUCCESS(ge::GraphUtils::AddEdge(data1->GetOutDataAnchor(0), ddit02->GetInDataAnchor(1))); - GE_ASSERT_SUCCESS(ge::GraphUtils::AddEdge(data2->GetOutDataAnchor(0), ddit02->GetInDataAnchor(2))); - - auto foo = ComputeNodeFaker(ub_graph).NameAndType("Foo", "Foo").IoNum(1, 1).Build(); - GE_ASSERT_SUCCESS(ge::GraphUtils::AddEdge(ddit02->GetOutDataAnchor(0), foo->GetInDataAnchor(0))); - - auto netoutput = ComputeNodeFaker(ub_graph).NameAndType("NetOutput", "NetOutput").IoNum(1, 1).Build(); - GE_ASSERT_SUCCESS(ge::GraphUtils::AddEdge(foo->GetOutDataAnchor(0), netoutput->GetInDataAnchor(0))); - - GE_DUMP(ub_graph, "TestUbGraph"); - - GE_ASSERT_TRUE(ge::AttrUtils::SetGraph(node->GetOpDesc(), "_original_fusion_graph", ub_graph)); - - return node; -} -/* - * ub graph: - * - * NetOutput - * | - * Foo - * | - * ddit1 - * / | \ - * data0 data2 data1 - */ -ge::NodePtr FakeUbNode1() { - auto node = ComputeNodeFaker().NameAndType("UbNode", "DDIT1").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - node->GetOpDesc()->SetOpInferDepends({"y"}); - - auto ub_graph = std::make_shared("ub_graph"); - auto data0 = ComputeNodeFaker(ub_graph) - .NameAndType("Data0", "Data") - .Attr(ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), 0) - .IoNum(0, 1) - .Build(); - auto data1 = ComputeNodeFaker(ub_graph) - .NameAndType("Data1", "Data") - .Attr(ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), 1) - .IoNum(0, 1) - .Build(); - auto data2 = ComputeNodeFaker(ub_graph) - .NameAndType("Data2", "Data") - .Attr(ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), 2) - .IoNum(0, 1) - .Build(); - - auto ddit02 = ComputeNodeFaker(ub_graph).NameAndType("Test", "DDIT1").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - ddit02->GetOpDesc()->SetOpInferDepends({"y"}); - - GE_ASSERT_SUCCESS(ge::GraphUtils::AddEdge(data0->GetOutDataAnchor(0), ddit02->GetInDataAnchor(0))); - GE_ASSERT_SUCCESS(ge::GraphUtils::AddEdge(data1->GetOutDataAnchor(0), ddit02->GetInDataAnchor(2))); - GE_ASSERT_SUCCESS(ge::GraphUtils::AddEdge(data2->GetOutDataAnchor(0), ddit02->GetInDataAnchor(1))); - - auto foo = ComputeNodeFaker(ub_graph).NameAndType("Foo", "Foo").IoNum(1, 1).Build(); - GE_ASSERT_SUCCESS(ge::GraphUtils::AddEdge(ddit02->GetOutDataAnchor(0), foo->GetInDataAnchor(0))); - - auto netoutput = ComputeNodeFaker(ub_graph).NameAndType("NetOutput", "NetOutput").IoNum(1, 1).Build(); - GE_ASSERT_SUCCESS(ge::GraphUtils::AddEdge(foo->GetOutDataAnchor(0), netoutput->GetInDataAnchor(0))); - - GE_DUMP(ub_graph, "TestUbGraph"); - - GE_ASSERT_TRUE(ge::AttrUtils::SetGraph(node->GetOpDesc(), "_original_fusion_graph", ub_graph)); - - return node; -} -} // namespace -class DataDependentInterpreterUT : public testing::Test {}; -TEST_F(DataDependentInterpreterUT, SimpleNode_ReturnTrue_V2V1True) { - auto node = ComputeNodeFaker().NameAndType("Test", "DDIT02").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - ASSERT_NE(node, nullptr); - node->GetOpDesc()->SetOpInferDepends({"x", "z"}); - - auto space_registry = SpaceRegistryFaker().Build(); - OpImplSpaceRegistryArray space_registries; - space_registries[static_cast(ge::OppImplVersion::kOpp)] = space_registry; - bool ret = false; - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsDataDependent(0, ret), ge::GRAPH_SUCCESS); - ASSERT_TRUE(ret); - ret = false; - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsDataDependent(2, ret), ge::GRAPH_SUCCESS); - ASSERT_TRUE(ret); - - ASSERT_EQ(DataDependentInterpreter(node, space_registry).IsDataDependent(2, ret), ge::GRAPH_SUCCESS); // 兼容 - ASSERT_TRUE(ret); -} - -TEST_F(DataDependentInterpreterUT, SimpleNode_ReturnTrue_V2V1True_UseRegistryV2) { - auto node = ComputeNodeFaker().NameAndType("Test", "DDIT02").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - ASSERT_NE(node, nullptr); - node->GetOpDesc()->SetOpInferDepends({"x", "z"}); - - auto space_registry = SpaceRegistryFaker().BuildV2(); - OpImplSpaceRegistryV2Array space_registries; - space_registries[static_cast(ge::OppImplVersion::kOpp)] = space_registry; - bool ret = false; - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsDataDependent(0, ret), ge::GRAPH_SUCCESS); - ASSERT_TRUE(ret); - ret = false; - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsDataDependent(2, ret), ge::GRAPH_SUCCESS); - ASSERT_TRUE(ret); - - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsDataDependent(2, ret), ge::GRAPH_SUCCESS); // 兼容 - ASSERT_TRUE(ret); -} - -TEST_F(DataDependentInterpreterUT, SimpleNode_TilingDepend_ReturnTrue_V2V1False_UseRegistryV2) { - auto node = ComputeNodeFaker().NameAndType("Test", "DDIT_error").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - ASSERT_NE(node, nullptr); - auto space_registry = SpaceRegistryFaker().BuildV2(); - OpImplSpaceRegistryV2Array space_registries; - bool ret = false; - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsTilingInputDataDependent(1, ret), ge::GRAPH_SUCCESS); - ASSERT_FALSE(ret); - space_registries[static_cast(ge::OppImplVersion::kOpp)] = space_registry; - ret = false; - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsTilingInputDataDependent(1, ret), ge::GRAPH_SUCCESS); - ASSERT_FALSE(ret); - (void)ge::AttrUtils::SetInt(node->GetOpDesc(), ge::ATTR_NAME_BINARY_SOURCE, - static_cast(ge::OppImplVersion::kVersionEnd)); - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsTilingInputDataDependent(1, ret), ge::GRAPH_SUCCESS); - ASSERT_FALSE(ret); - auto node2 = ComputeNodeFaker().NameAndType("Test", "DDIT4").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - ret = false; - ASSERT_EQ(DataDependentInterpreter(node2->GetOpDesc(), space_registries).IsTilingInputDataDependent(0, ret), ge::GRAPH_SUCCESS); - ASSERT_FALSE(ret); -} - -TEST_F(DataDependentInterpreterUT, SimpleNode_TilingDepend_ReturnTrue_V2V1False) { - auto node = ComputeNodeFaker().NameAndType("Test", "DDIT_error").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - ASSERT_NE(node, nullptr); - auto space_registry = SpaceRegistryFaker().Build(); - OpImplSpaceRegistryArray space_registries; - bool ret = false; - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsTilingInputDataDependent(1, ret), ge::GRAPH_SUCCESS); - ASSERT_FALSE(ret); - space_registries[static_cast(ge::OppImplVersion::kOpp)] = space_registry; - ret = false; - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsTilingInputDataDependent(1, ret), ge::GRAPH_SUCCESS); - ASSERT_FALSE(ret); - (void)ge::AttrUtils::SetInt(node->GetOpDesc(), ge::ATTR_NAME_BINARY_SOURCE, - static_cast(ge::OppImplVersion::kVersionEnd)); - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsTilingInputDataDependent(1, ret), ge::GRAPH_SUCCESS); - ASSERT_FALSE(ret); - auto node2 = ComputeNodeFaker().NameAndType("Test", "DDIT4").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - ret = false; - ASSERT_EQ(DataDependentInterpreter(node2->GetOpDesc(), space_registries).IsTilingInputDataDependent(0, ret), ge::GRAPH_SUCCESS); - ASSERT_FALSE(ret); -} - -TEST_F(DataDependentInterpreterUT, SimpleNode_TilingDepend_ReturnTrue_V2V1) { - auto node = ComputeNodeFaker().NameAndType("Test", "DDIT3").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - ASSERT_NE(node, nullptr); - OpImplSpaceRegistryArray space_registries; - space_registries[static_cast(ge::OppImplVersion::kOpp)] = SpaceRegistryFaker().Build(); - bool ret = false; - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsTilingInputDataDependent(1UL, ret), ge::GRAPH_SUCCESS); - ASSERT_TRUE(ret); - ret = false; - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsTilingInputDataDependent(0UL, ret), ge::GRAPH_SUCCESS); - ASSERT_FALSE(ret); - ret = false; - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsTilingInputDataDependent(2UL, ret), ge::GRAPH_SUCCESS); - ASSERT_TRUE(ret); -} - -TEST_F(DataDependentInterpreterUT, SimpleNode_TilingDependPlacement_ReturnTrue_V2V1False) { - auto node = ComputeNodeFaker().NameAndType("Test", "DDIT1").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - ASSERT_NE(node, nullptr); - auto space_registry = SpaceRegistryFaker().Build(); - OpImplSpaceRegistryArray space_registries; - space_registries[static_cast(ge::OppImplVersion::kOpp)] = space_registry; - bool ret = false; - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsSupportTilingDependPlacement(0, ret), ge::GRAPH_SUCCESS); - ASSERT_FALSE(ret); - (void)ge::AttrUtils::SetInt(node->GetOpDesc(), ge::ATTR_NAME_BINARY_SOURCE, - static_cast(ge::OppImplVersion::kOppKernel)); - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsSupportTilingDependPlacement(0, ret), ge::GRAPH_SUCCESS); - ASSERT_FALSE(ret); - auto node2 = ComputeNodeFaker().NameAndType("Test", "DDIT5").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - ASSERT_NE(node2, nullptr); - // space_registries.fill() - space_registries[static_cast(ge::OppImplVersion::kOpp)] = nullptr; - ASSERT_EQ(DataDependentInterpreter(node2->GetOpDesc(), space_registries).IsSupportTilingDependPlacement(0, ret), ge::GRAPH_SUCCESS); - ASSERT_FALSE(ret); - auto node3 = ComputeNodeFaker().NameAndType("Test", "DDIT_error").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - ASSERT_NE(node3, nullptr); - space_registries[static_cast(ge::OppImplVersion::kOpp)] = space_registry; - ASSERT_EQ(DataDependentInterpreter(node3->GetOpDesc(), space_registries).IsSupportTilingDependPlacement(0, ret), ge::GRAPH_SUCCESS); - ASSERT_FALSE(ret); -} - -TEST_F(DataDependentInterpreterUT, SimpleNode_TilingDependPlacement_ReturnTrue_V2V1) { - auto node = ComputeNodeFaker().NameAndType("Test", "DDIT3").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - ASSERT_NE(node, nullptr); - OpImplSpaceRegistryArray space_registries; - space_registries[static_cast(ge::OppImplVersion::kOpp)] = SpaceRegistryFaker().Build(); - bool ret = false; - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsSupportTilingDependPlacement(0, ret), ge::GRAPH_SUCCESS); - ASSERT_TRUE(ret); - auto node2 = ComputeNodeFaker().NameAndType("Test", "DDIT5").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - ASSERT_NE(node2, nullptr); - ASSERT_EQ(DataDependentInterpreter(node2->GetOpDesc(), space_registries).IsSupportTilingDependPlacement(0, ret), ge::GRAPH_SUCCESS); - ASSERT_FALSE(ret); - ASSERT_EQ(DataDependentInterpreter(node2->GetOpDesc(), space_registries).IsSupportTilingDependPlacement(1, ret), ge::GRAPH_SUCCESS); - ASSERT_TRUE(ret); -} - -TEST_F(DataDependentInterpreterUT, SimpleNode_ReturnFalse_V2V1False) { - auto node = ComputeNodeFaker().NameAndType("Test", "DDIT02").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - ASSERT_NE(node, nullptr); - node->GetOpDesc()->SetOpInferDepends({"x", "z"}); - bool ret = true; - OpImplSpaceRegistryArray space_registries; - space_registries[static_cast(ge::OppImplVersion::kOpp)] = SpaceRegistryFaker().Build(); - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsDataDependent(1, ret), ge::GRAPH_SUCCESS); - ASSERT_FALSE(ret); -} -TEST_F(DataDependentInterpreterUT, SimpleNode_ReturnTrueAndLogWarning_V2FalseV1True) { - auto node = ComputeNodeFaker().NameAndType("Test", "DDIT02").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - ASSERT_NE(node, nullptr); - node->GetOpDesc()->SetOpInferDepends({"y"}); - bool ret = false; - OpImplSpaceRegistryArray space_registries; - space_registries[static_cast(ge::OppImplVersion::kOpp)] = SpaceRegistryFaker().Build(); - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsDataDependent(1, ret), ge::GRAPH_SUCCESS); - ASSERT_TRUE(ret); -} -TEST_F(DataDependentInterpreterUT, SimpleNode_ReturnTrue_V2TrueV1False) { - auto node = ComputeNodeFaker().NameAndType("Test", "DDIT02").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - ASSERT_NE(node, nullptr); - - OpImplSpaceRegistryArray space_registries; - space_registries[static_cast(ge::OppImplVersion::kOpp)] = SpaceRegistryFaker().Build(); - bool ret = false; - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsDataDependent(0, ret), ge::GRAPH_SUCCESS); - ASSERT_TRUE(ret); - ret = false; - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsDataDependent(2, ret), ge::GRAPH_SUCCESS); - ASSERT_TRUE(ret); -} -TEST_F(DataDependentInterpreterUT, UbGraphNode_ReturnTrue_V2V1UbGraphTrue) { - auto node = FakeUbNode02(); - ASSERT_NE(node, nullptr); - - OpImplSpaceRegistryArray space_registries; - space_registries[static_cast(ge::OppImplVersion::kOpp)] = SpaceRegistryFaker().Build(); - bool ret = false; - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsDataDependent(0, ret), ge::GRAPH_SUCCESS); - ASSERT_TRUE(ret); - ret = false; - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsDataDependent(2, ret), ge::GRAPH_SUCCESS); - ASSERT_TRUE(ret); -} -TEST_F(DataDependentInterpreterUT, UbGraphNode_ReturnFalse_V2V1UbGraphFalse) { - auto node = FakeUbNode02(); - ASSERT_NE(node, nullptr); - - OpImplSpaceRegistryArray space_registries; - space_registries[static_cast(ge::OppImplVersion::kOpp)] = SpaceRegistryFaker().Build(); - bool ret = true; - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsDataDependent(1, ret), ge::GRAPH_SUCCESS); - ASSERT_FALSE(ret); -} -TEST_F(DataDependentInterpreterUT, UbGraphNode_ReturnTrueAndLogWarning_V2V1TrueUbGraphFalse) { - auto node = FakeUbNode1(); - ASSERT_NE(node, nullptr); - - OpImplSpaceRegistryArray space_registries; - space_registries[static_cast(ge::OppImplVersion::kOpp)] = SpaceRegistryFaker().Build(); - bool ret = true; - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsDataDependent(1, ret), ge::GRAPH_SUCCESS); - ASSERT_TRUE(ret); -} -TEST_F(DataDependentInterpreterUT, UbGraphNode_ReturnTrue_V2V1FalseUbGraphTrue) { - auto node = FakeUbNode02TypeFoo(); - ASSERT_NE(node, nullptr); - - auto space_registry = SpaceRegistryFaker().Build(); - DataDependentInterpreter ddi(node, space_registry); - - bool ret = true; - ASSERT_EQ(ddi.IsDataDependent(0, ret), ge::GRAPH_SUCCESS); - ASSERT_TRUE(ret); - ret = true; - ASSERT_EQ(ddi.IsDataDependent(2, ret), ge::GRAPH_SUCCESS); - ASSERT_TRUE(ret); -} -TEST_F(DataDependentInterpreterUT, UbGraphNode_Failed_InvalidDataInUbGraph) { - auto node = FakeUbNode02DataDoesNotHasIndex(); - ASSERT_NE(node, nullptr); - bool ret; - OpImplSpaceRegistryArray space_registries; - space_registries[static_cast(ge::OppImplVersion::kOpp)] = SpaceRegistryFaker().Build(); - ASSERT_NE(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsDataDependent(0, ret), ge::GRAPH_SUCCESS); -} -TEST_F(DataDependentInterpreterUT, UbGraphNode_Failed_DataIndexMissmatch) { - auto node = FakeUbNode02OnlyHasTwoData(); - ASSERT_NE(node, nullptr); - bool ret; - OpImplSpaceRegistryArray space_registries; - space_registries[static_cast(ge::OppImplVersion::kOpp)] = SpaceRegistryFaker().Build(); - ASSERT_NE(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsDataDependent(2, ret), ge::GRAPH_SUCCESS); -} -TEST_F(DataDependentInterpreterUT, SimpleNode_With_EmptyRegistry) { - auto node = ComputeNodeFaker().NameAndType("Test", "DDIT02").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - ASSERT_NE(node, nullptr); - node->GetOpDesc()->SetOpInferDepends({"x", "z"}); - - bool ret = false; - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), OpImplSpaceRegistryArray()).IsDataDependent(0, ret), ge::GRAPH_SUCCESS); - ASSERT_TRUE(ret); -} -TEST_F(DataDependentInterpreterUT, OnlyV1Node_ReturnTrueAndLogWarning_V1True) { - auto node = ComputeNodeFaker().NameAndType("Test", "FooNotRegister").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - ASSERT_NE(node, nullptr); - node->GetOpDesc()->SetOpInferDepends({"x", "z"}); - - OpImplSpaceRegistryArray space_registries; - space_registries[static_cast(ge::OppImplVersion::kOpp)] = SpaceRegistryFaker().Build(); - bool ret = false; - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsDataDependent(0, ret), ge::GRAPH_SUCCESS); - ASSERT_TRUE(ret); - ret = false; - ASSERT_EQ(DataDependentInterpreter(node->GetOpDesc(), space_registries).IsDataDependent(2, ret), ge::GRAPH_SUCCESS); - ASSERT_TRUE(ret); -} -} // namespace gert diff --git a/tests/ut/exe_graph/execute_graph_types_unittest.cc b/tests/ut/exe_graph/execute_graph_types_unittest.cc deleted file mode 100644 index 8173f350db4e1d0d42e683eb8f1851579da9f65d..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/execute_graph_types_unittest.cc +++ /dev/null @@ -1,22 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/runtime/execute_graph_types.h" -#include -namespace gert { -class ExecuteGraphTypesUT : public testing::Test {}; -TEST_F(ExecuteGraphTypesUT, GetStr_Ok) { - EXPECT_STREQ(GetExecuteGraphTypeStr(ExecuteGraphType::kInit), "Init"); - EXPECT_STREQ(GetExecuteGraphTypeStr(ExecuteGraphType::kMain), "Main"); - EXPECT_STREQ(GetExecuteGraphTypeStr(ExecuteGraphType::kDeInit), "DeInit"); -} -TEST_F(ExecuteGraphTypesUT, GetStr_Nullptr_OutOfRange) { - EXPECT_EQ(GetExecuteGraphTypeStr(ExecuteGraphType::kNum), nullptr); -} -} // namespace gert diff --git a/tests/ut/exe_graph/expand_dims_type_unittest.cc b/tests/ut/exe_graph/expand_dims_type_unittest.cc deleted file mode 100644 index 797d5c22dfba73f39c2b124f4f45250757b769bc..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/expand_dims_type_unittest.cc +++ /dev/null @@ -1,166 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/runtime/expand_dims_type.h" -#include -namespace gert { -class ExpandDimsTypeUT : public testing::Test {}; -TEST_F(ExpandDimsTypeUT, TestDoNotExpandSpecifyAll) { - auto shape = Shape{2, 3, 5}; - ExpandDimsType edt("000"); - Shape out_shape; - edt.Expand(shape, out_shape); - ASSERT_EQ(out_shape, shape); -} -TEST_F(ExpandDimsTypeUT, TestDoNotExpandSpecifyPart1) { - auto shape = Shape{2, 3, 5}; - ExpandDimsType edt("0"); - Shape out_shape; - edt.Expand(shape, out_shape); - ASSERT_EQ(out_shape, shape); -} -TEST_F(ExpandDimsTypeUT, TestDoNotExpandSpecifyPart2) { - auto shape = Shape{2, 3, 5}; - ExpandDimsType edt("0"); - Shape out_shape; - edt.Expand(shape, out_shape); - ASSERT_EQ(out_shape, shape); -} -TEST_F(ExpandDimsTypeUT, TestDoNotExpandSpecifyNone) { - auto shape = Shape{2, 3, 5}; - ExpandDimsType edt(""); - Shape out_shape; - edt.Expand(shape, out_shape); - ASSERT_EQ(out_shape, shape); -} - -TEST_F(ExpandDimsTypeUT, ExpandAtHead) { - auto shape = Shape{2, 16, 16}; - ExpandDimsType edt("11000"); - Shape out_shape; - edt.Expand(shape, out_shape); - - ASSERT_EQ(5, out_shape.GetDimNum()); - ASSERT_EQ(out_shape, Shape({1, 1, 2, 16, 16})); -} -TEST_F(ExpandDimsTypeUT, ExpandAtHeadSpecifyPart) { - auto shape = Shape{2, 16}; - ExpandDimsType edt("110"); - Shape out_shape; - edt.Expand(shape, out_shape); - - ASSERT_EQ(4, out_shape.GetDimNum()); - ASSERT_EQ(out_shape, Shape({1, 1, 2, 16})); -} -TEST_F(ExpandDimsTypeUT, ExpandAtHeadSpecifyNone) { - auto shape = Shape{2, 16, 16}; - ExpandDimsType edt("11"); - Shape out_shape; - edt.Expand(shape, out_shape); - - ASSERT_EQ(5, out_shape.GetDimNum()); - ASSERT_EQ(out_shape, Shape({1, 1, 2, 16, 16})); -} -TEST_F(ExpandDimsTypeUT, ExpandAtHeadSpecifyNone1) { - auto shape = Shape{2, 16, 16}; - ExpandDimsType edt("1"); - Shape out_shape; - edt.Expand(shape, out_shape); - - ASSERT_EQ(4, out_shape.GetDimNum()); - ASSERT_EQ(out_shape, Shape({1, 2, 16, 16})); -} -TEST_F(ExpandDimsTypeUT, ExpandAtLast) { - auto shape = Shape{2, 16, 16}; - ExpandDimsType edt("00011"); - Shape out_shape; - edt.Expand(shape, out_shape); - - ASSERT_EQ(5, out_shape.GetDimNum()); - ASSERT_EQ(out_shape, Shape({2, 16, 16, 1, 1})); -} -TEST_F(ExpandDimsTypeUT, ExpandAtLast3Dim) { - auto shape = Shape{2}; - ExpandDimsType edt("0111"); - Shape out_shape; - edt.Expand(shape, out_shape); - - ASSERT_EQ(4, out_shape.GetDimNum()); - ASSERT_EQ(out_shape, Shape({2, 1, 1, 1})); -} -TEST_F(ExpandDimsTypeUT, ExpandHeadAndLast) { - auto shape = Shape{2, 16, 16}; - ExpandDimsType edt("10001"); - Shape out_shape; - edt.Expand(shape, out_shape); - - ASSERT_EQ(5, out_shape.GetDimNum()); - ASSERT_EQ(out_shape, Shape({1, 2, 16, 16, 1})); -} -TEST_F(ExpandDimsTypeUT, ExpandMiddle) { - auto shape = Shape{2, 16, 16}; - ExpandDimsType edt("01010"); - Shape out_shape; - edt.Expand(shape, out_shape); - - ASSERT_EQ(5, out_shape.GetDimNum()); - ASSERT_EQ(out_shape, Shape({2, 1, 16, 1, 16})); -} -TEST_F(ExpandDimsTypeUT, ExpandMiddleSpecifyPart) { - auto shape = Shape{2, 16}; - ExpandDimsType edt("011"); - Shape out_shape; - edt.Expand(shape, out_shape); - - ASSERT_EQ(4, out_shape.GetDimNum()); - ASSERT_EQ(out_shape, Shape({2, 1, 1, 16})); -} -TEST_F(ExpandDimsTypeUT, ExpandDimsMoreThanShape) { - auto shape = Shape{2, 16}; - ExpandDimsType edt("1000"); - Shape out_shape; - EXPECT_NE(edt.Expand(shape, out_shape), ge::GRAPH_SUCCESS); -} -TEST_F(ExpandDimsTypeUT, NullInput) { - auto shape = Shape{2, 16, 16}; - ExpandDimsType edt(nullptr); - Shape out_shape; - edt.Expand(shape, out_shape); - - ASSERT_EQ(out_shape, Shape({2, 16, 16})); -} -TEST_F(ExpandDimsTypeUT, Over56Limits) { - auto shape = Shape{2, 16, 16}; - std::string s; - for (size_t i = 0; i <= 56; ++i) { - s.push_back('1'); - } - ExpandDimsType edt(s.c_str()); - Shape out_shape; - edt.Expand(shape, out_shape); - - ASSERT_EQ(out_shape, Shape({2, 16, 16})); -} -TEST_F(ExpandDimsTypeUT, ExpandSpecifyPart) { - auto shape = Shape{2, 16, 16}; - ExpandDimsType edt("100"); - Shape out_shape; - edt.Expand(shape, out_shape); - - ASSERT_EQ(3, out_shape.GetDimNum()); - ASSERT_EQ(out_shape, Shape({2, 16, 16})); -} -TEST_F(ExpandDimsTypeUT, GetExpandDimMask) { - ExpandDimsType edt("011"); - ASSERT_EQ(edt.GetFullSize(), 3); - ASSERT_FALSE(edt.IsExpandIndex(0)); - ASSERT_TRUE(edt.IsExpandIndex(1)); - ASSERT_TRUE(edt.IsExpandIndex(2)); -} -} // namespace gert diff --git a/tests/ut/exe_graph/fast_dev_mem_value_holder_unittest.cc b/tests/ut/exe_graph/fast_dev_mem_value_holder_unittest.cc deleted file mode 100644 index 8168699d8cf5b004764b08b793940cd7c9b4661f..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/fast_dev_mem_value_holder_unittest.cc +++ /dev/null @@ -1,74 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/lowering/dev_mem_value_holder.h" -#include "checker/topo_checker.h" -#include "checker/bg_test.h" - -namespace gert { -namespace bg { -class FastDevMemValueHolderUt : public BgTest { -}; - -TEST_F(FastDevMemValueHolderUt, SetGetLogicStreamOk) { - int64_t logic_stream_id = 2; - auto output0 = DevMemValueHolder::CreateSingleDataOutput("test", {}, logic_stream_id); - output0->SetPlacement(1); - EXPECT_EQ(output0->GetPlacement(), 1); - EXPECT_EQ(output0->GetLogicStream(), 2); -} - -TEST_F(FastDevMemValueHolderUt, DevMemCreateErrorOk) { - auto holder = DevMemValueHolder::CreateError(0, "This is a test error information, int %d, %s", 10240, "Test msg"); - ASSERT_NE(holder, nullptr); - EXPECT_FALSE(holder->IsOk()); -} - -TEST_F(FastDevMemValueHolderUt, DevMemCreateDataOutputOk) { - ge::Format f1 = ge::FORMAT_NC1HWC0; - auto const1 = ValueHolder::CreateConst(reinterpret_cast(&f1), sizeof(f1)); - auto data1 = ValueHolder::CreateFeed(0); - ASSERT_NE(const1, nullptr); - ASSERT_NE(data1, nullptr); - - std::vector inputs = {data1, const1}; - auto holders = DevMemValueHolder::CreateDataOutput("TestNode", inputs, 3, 0); - - ASSERT_EQ(holders.size(), 3); - ASSERT_TRUE(holders[0]->IsOk()); - ASSERT_TRUE(holders[1]->IsOk()); - ASSERT_TRUE(holders[2]->IsOk()); - EXPECT_EQ(holders[0]->GetType(), ValueHolder::ValueHolderType::kOutput); - EXPECT_EQ(holders[1]->GetType(), ValueHolder::ValueHolderType::kOutput); - EXPECT_EQ(holders[2]->GetType(), ValueHolder::ValueHolderType::kOutput); -} - -TEST_F(FastDevMemValueHolderUt, DevMemCreateConstOk) { - ge::Format f1 = ge::FORMAT_NC1HWC0; - auto holder = DevMemValueHolder::CreateConst(reinterpret_cast(&f1), sizeof(f1), 0); - ASSERT_NE(holder, nullptr); - ASSERT_TRUE(holder->IsOk()); -} - -TEST_F(FastDevMemValueHolderUt, DevMemCreateMateFromNodeOk) { - ge::ExecuteGraphPtr graph = std::make_shared("graph"); - auto op_desc = std::make_shared("FakeNode", "FakeNode"); - auto node = graph->AddNode(op_desc); - auto holder = std::make_shared(2); - ASSERT_NE(holder, nullptr); - auto value_holder = holder->CreateMateFromNode(node, 0, ValueHolder::ValueHolderType::kOutput); - ASSERT_NE(value_holder, nullptr); - ASSERT_TRUE(value_holder->IsOk()); - EXPECT_EQ(value_holder->GetType(), ValueHolder::ValueHolderType::kOutput); - auto mem_holder = std::dynamic_pointer_cast(value_holder); - ASSERT_NE(mem_holder, nullptr); - EXPECT_EQ(mem_holder->GetLogicStream(), 2); -} -} // namespace bg -} // namespace gert diff --git a/tests/ut/exe_graph/fast_frame_selector_unittest.cc b/tests/ut/exe_graph/fast_frame_selector_unittest.cc deleted file mode 100644 index 0dbc67a62e2b75bf70b25d0f953ffcf2b8107054..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/fast_frame_selector_unittest.cc +++ /dev/null @@ -1,1057 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "checker/bg_test.h" -#include "checker/summary_checker.h" -#include "checker/topo_checker.h" -#include "exe_graph/lowering/frame_selector.h" -#include "exe_graph/runtime/execute_graph_types.h" -#include "graph/utils/execute_graph_utils.h" -#include -#include -namespace gert { -namespace bg { -class FastFrameSelectorUT : public BgTest { - public: - GraphFrame *root_frame = nullptr; - std::unique_ptr init_frame; - std::unique_ptr de_init_frame; - void InitTestFrames() { - root_frame = ValueHolder::GetCurrentFrame(); - auto init_node = ValueHolder::CreateVoid("Init", {}); - ValueHolder::PushGraphFrame(init_node, "Init"); - init_frame = ValueHolder::PopGraphFrame(); - - auto de_init_node = ValueHolder::CreateVoid("DeInit", {}); - ValueHolder::PushGraphFrame(de_init_node, "DeInit"); - de_init_frame = ValueHolder::PopGraphFrame(); - - auto main_node = ValueHolder::CreateVoid(GetExecuteGraphTypeStr(ExecuteGraphType::kMain), {}); - ValueHolder::PushGraphFrame(main_node, "Main"); - } -}; - -/* - * +-----------------------+ - * |FooGraph | - * | | - * | InnerNetOutput | - * | | | - * | foo2 <--+ | - * | / \ | - * | c1 \ | - * +---+-----+------+------+ - * | | | - * data0 data1 bar1 - * | - * c0 - */ -TEST_F(FastFrameSelectorUT, SelectMainRoot_CreateOnRoot_NoMainGraph) { - auto data0 = ValueHolder::CreateFeed(0); - auto data1 = ValueHolder::CreateFeed(1); - - auto foo1 = ValueHolder::CreateSingleDataOutput("Foo1", {data0, data1}); - EXPECT_NE(ValueHolder::PushGraphFrame(foo1, "FooGraph"), nullptr); - - // on FooGraph - auto c1 = ValueHolder::CreateConst("ConstData", 10, true); - - // on RootGraph - auto bars = FrameSelector::OnMainRoot([&]() -> std::vector { - auto c0 = ValueHolder::CreateConst("ConstData", 10, true); - auto bar1 = ValueHolder::CreateSingleDataOutput("Bar1", {c0}); - return {bar1}; - }); - ASSERT_EQ(bars.size(), 1); - ASSERT_NE(bars[0], nullptr); - - // on FooGraph - auto foo2 = ValueHolder::CreateSingleDataOutput("Foo2", {c1, bars[0]}); - auto foo2_graph = ValueHolder::PopGraphFrame({foo2}, {}); - ASSERT_NE(foo2_graph, nullptr); - - auto frame = ValueHolder::PopGraphFrame(); - - ASSERT_EQ( - ExeGraphSummaryChecker(frame->GetExecuteGraph().get()).StrictDirectNodeTypes( - {{"Data", 2}, {"Const", 1}, {"Foo1", 1}, {"Bar1", 1}}), "success"); - ASSERT_EQ(ExeGraphSummaryChecker(foo2_graph->GetExecuteGraph().get()) - .StrictDirectNodeTypes({{"InnerData", 1}, {"Const", 1}, {"Foo2", 1}, {"InnerNetOutput", 1}}), - "success"); - - ASSERT_EQ(FastNodeTopoChecker(bars[0]).StrictConnectTo(0, {{"Foo1", 2}}), "success"); - ASSERT_EQ(FastNodeTopoChecker(bars[0]).StrictConnectFrom({{"Const"}}), "success"); -} -TEST_F(FastFrameSelectorUT, SelectMainRoot_CreateOnMainRoot_CurrentFrameIsMainRoot) { - InitTestFrames(); - - auto data0 = ValueHolder::CreateFeed(0); - - // on RootGraph - auto bars = FrameSelector::OnMainRoot([&]() -> std::vector { - auto c0 = ValueHolder::CreateConst("ConstData", 10, true); - auto bar1 = ValueHolder::CreateSingleDataOutput("Bar1", {c0, data0}); - return {bar1}; - }); - ASSERT_EQ(bars.size(), 1); - ASSERT_NE(bars[0], nullptr); - - auto main_frame = ValueHolder::PopGraphFrame(); - auto root_frame = ValueHolder::PopGraphFrame(); - ASSERT_NE(main_frame, nullptr); - ASSERT_NE(root_frame, nullptr); - - ASSERT_EQ(ExeGraphSummaryChecker(root_frame->GetExecuteGraph().get()) - .StrictDirectNodeTypes({{"Init", 1}, {"Main", 1}, {"DeInit", 1}}), - "success"); - ASSERT_EQ(ExeGraphSummaryChecker(main_frame->GetExecuteGraph().get()) - .StrictDirectNodeTypes({{"Data", 1}, {"Const", 1}, {"Bar1", 1}}), - "success"); - - ASSERT_EQ(FastNodeTopoChecker(bars[0]).StrictConnectFrom({{"Const"}, {"Data"}}), "success"); - ASSERT_TRUE(bars[0]->GetFastNode()->GetAllOutNodes().empty()); -} - -/* - * +-----------------------+ - * |FooGraph | - * | | - * | InnerNetOutput | - * | | | - * | foo2 <--+ | - * | / \ | - * | c1 \ | - * +---+-----+------+------+ - * | | | - * data0 data1 bar1 - * | - * c0 - */ -TEST_F(FastFrameSelectorUT, SelectMainRoot_CreateOnMainRoot_CurrentFrameIsMainSubgraphs) { - InitTestFrames(); - - auto data0 = ValueHolder::CreateFeed(0); - auto data1 = ValueHolder::CreateFeed(1); - - auto foo1 = ValueHolder::CreateSingleDataOutput("Foo1", {data0, data1}); - ValueHolder::PushGraphFrame(foo1, "FooGraph"); - - // on FooGraph - auto c1 = ValueHolder::CreateConst("ConstData", 10, true); - - // on MainRootGraph - auto bars = FrameSelector::OnMainRoot([&]() -> std::vector { - auto c0 = ValueHolder::CreateConst("ConstData", 10, true); - auto bar1 = ValueHolder::CreateSingleDataOutput("Bar1", {c0}); - return {bar1}; - }); - ASSERT_EQ(bars.size(), 1); - ASSERT_NE(bars[0], nullptr); - - // on FooGraph - auto foo2 = ValueHolder::CreateSingleDataOutput("Foo2", {c1, bars[0]}); - auto foo2_graph = ValueHolder::PopGraphFrame({foo2}, {}); - ASSERT_NE(foo2_graph, nullptr); - - auto frame = ValueHolder::PopGraphFrame(); - - ASSERT_EQ( - ExeGraphSummaryChecker(frame->GetExecuteGraph().get()) - .StrictDirectNodeTypes({{"Data", 2}, {"Const", 1}, {"Foo1", 1}, {"Bar1", 1}}), - "success"); - ASSERT_EQ(ExeGraphSummaryChecker(foo2_graph->GetExecuteGraph().get()) - .StrictDirectNodeTypes({{"InnerData", 1}, {"Const", 1}, {"Foo2", 1}, {"InnerNetOutput", 1}}), - "success"); - - ASSERT_EQ(FastNodeTopoChecker(bars[0]).StrictConnectTo(0, {{"Foo1", 2}}), "success"); - ASSERT_EQ(FastNodeTopoChecker(bars[0]).StrictConnectFrom({{"Const"}}), "success"); -} -TEST_F(FastFrameSelectorUT, SelectMainRoot_Failed_BuilderIsNullptr) { - ASSERT_EQ(FrameSelector::OnMainRoot(nullptr).size(), 0); -} -TEST_F(FastFrameSelectorUT, SelectMainRoot_Failed_ConnectFromSubgraph) { - InitTestFrames(); - - auto data0 = ValueHolder::CreateFeed(0); - auto data1 = ValueHolder::CreateFeed(1); - auto foo1 = ValueHolder::CreateSingleDataOutput("Foo1", {data0, data1}); - ValueHolder::PushGraphFrame(foo1, "FooGraph"); - - // on FooGraph - auto c1 = ValueHolder::CreateConst("ConstData", 10, true); - - // on RootGraph - auto bars = FrameSelector::OnMainRoot([&]() -> std::vector { - auto c0 = ValueHolder::CreateConst("ConstData", 10, true); - auto bar1 = ValueHolder::CreateSingleDataOutput("Bar1", {c1}); - return {bar1}; - }); - ASSERT_EQ(bars.size(), 1); - ASSERT_EQ(bars[0], nullptr); -} -TEST_F(FastFrameSelectorUT, SelectInitRoot_Success_OnlyInitNodeOut) { - InitTestFrames(); - auto ret = FrameSelector::OnInitRoot([]() -> std::vector { - auto c1 = ValueHolder::CreateConst("Hello", 5); - auto c2 = ValueHolder::CreateConst("Hello", 5); - auto foo1 = ValueHolder::CreateSingleDataOutput("Foo1", {c1, c2}); - auto foo2 = ValueHolder::CreateDataOutput("Foo2", {foo1, c1}, 2); - return {foo1, foo2[0], foo2[1]}; - }); - - ASSERT_EQ(ret.size(), 3); - ASSERT_EQ(ret[0]->GetFastNode()->GetType(), "Init"); - ASSERT_EQ(ret[1]->GetFastNode()->GetType(), "Init"); - ASSERT_EQ(ret[2]->GetFastNode()->GetType(), "Init"); - ASSERT_EQ(ret[0]->GetOutIndex(), 0); - ASSERT_EQ(ret[1]->GetOutIndex(), 1); - ASSERT_EQ(ret[2]->GetOutIndex(), 2); -} -TEST_F(FastFrameSelectorUT, SelectInitRoot_Success_ReEnter) { - InitTestFrames(); - auto ret = FrameSelector::OnInitRoot([]() -> std::vector { - auto c1 = ValueHolder::CreateConst("Hello", 5); - auto c2 = ValueHolder::CreateConst("Hello", 5); - auto foo1 = ValueHolder::CreateSingleDataOutput("Foo1", {c1, c2}); - - auto out_graph_outs = FrameSelector::OnInitRoot([]() -> std::vector { - auto c1 = ValueHolder::CreateConst("Hello", 5); - auto c2 = ValueHolder::CreateConst("Hello", 5); - return {ValueHolder::CreateSingleDataOutput("Bar1", {c1, c2})}; - }); - - auto foo2 = ValueHolder::CreateDataOutput("Foo2", {HolderOnInit(out_graph_outs[0]), foo1}, 2); - return {foo1, foo2[0], foo2[1]}; - }); - ASSERT_EQ(ret.size(), 3); - - EXPECT_EQ(ExeGraphSummaryChecker(init_frame->GetExecuteGraph().get()) - .StrictDirectNodeTypes({{"Const", 4}, {"Foo1", 1}, {"Bar1", 1}, {"Foo2", 1}, {"InnerNetOutput", 1}}), - "success"); - - auto inner_netoutput_node = - ge::ExecuteGraphUtils::FindFirstNodeMatchType(init_frame->GetExecuteGraph().get(), "InnerNetOutput"); - ASSERT_NE(inner_netoutput_node, nullptr); - ASSERT_EQ(FastNodeTopoChecker(inner_netoutput_node).StrictConnectFrom({{"Bar1", 0}, {"Foo1", 0}, {"Foo2", 0}, {"Foo2", 1}}), - "success"); - - auto foo2_node = ge::ExecuteGraphUtils::FindFirstNodeMatchType(init_frame->GetExecuteGraph().get(), "Foo2"); - ASSERT_NE(foo2_node, nullptr); - ASSERT_EQ(FastNodeTopoChecker(foo2_node).StrictConnectFrom({{"Bar1", 0}, {"Foo1", 0}}), - "success"); -} -TEST_F(FastFrameSelectorUT, SelectInitRoot_TheSameWithInput_ReturnInput) { - InitTestFrames(); - auto ret = FrameSelector::OnInitRoot([]() -> std::vector { - auto c1 = ValueHolder::CreateConst("Hello", 5); - auto c2 = ValueHolder::CreateConst("Hello", 5); - auto foo1 = ValueHolder::CreateSingleDataOutput("Foo1", {c1, c2}); - auto foo2 = ValueHolder::CreateDataOutput("Foo2", {foo1, c1}, 2); - return {foo1, foo2[0], foo2[1]}; - }); - - ASSERT_EQ(ret.size(), 3); - - auto ret2 = FrameSelector::OnInitRoot([&]() -> std::vector { - return {HolderOnInit(ret[0]), - ValueHolder::CreateSingleDataOutput("Foo3", {HolderOnInit(ret[0]), ValueHolder::CreateConst("Hello", 5)}), - HolderOnInit(ret[1])}; - }); - ASSERT_EQ(ret2.size(), 3); - - ASSERT_EQ(ret2[0]->GetFastNode()->GetType(), "Init"); - ASSERT_EQ(ret2[1]->GetFastNode()->GetType(), "Init"); - ASSERT_EQ(ret2[2]->GetFastNode()->GetType(), "Init"); - ASSERT_EQ(ret2[0]->GetOutIndex(), 0); - ASSERT_EQ(ret2[1]->GetOutIndex(), 3); - ASSERT_EQ(ret2[2]->GetOutIndex(), 1); - - auto netoutput = - ge::ExecuteGraphUtils::FindFirstNodeMatchType(init_frame->GetExecuteGraph().get(), "InnerNetOutput"); - ASSERT_NE(netoutput, nullptr); - EXPECT_EQ(FastNodeTopoChecker(netoutput).StrictConnectFrom({{"Foo1", 0}, {"Foo2", 0}, {"Foo2", 1}, {"Foo3", 0}}), - "success"); -} -TEST_F(FastFrameSelectorUT, SelectInitRoot_Success_InitNodeAndInitGraphOut) { - InitTestFrames(); - std::vector init_node_outputs; - std::vector init_graph_outputs; - auto ret = FrameSelector::OnInitRoot( - []() -> std::vector { - auto c1 = ValueHolder::CreateConst("Hello", 5); - auto c2 = ValueHolder::CreateConst("Hello", 5); - auto foo1 = ValueHolder::CreateSingleDataOutput("Foo1", {c1, c2}); - auto foo2 = ValueHolder::CreateDataOutput("Foo2", {foo1, c1}, 2); - return {foo1, foo2[0], foo2[1]}; - }, - init_graph_outputs, init_node_outputs); - ASSERT_EQ(ret, ge::GRAPH_SUCCESS); - - ASSERT_EQ(init_node_outputs.size(), 3); - ASSERT_EQ(init_node_outputs[0]->GetFastNode()->GetType(), "Init"); - ASSERT_EQ(init_node_outputs[1]->GetFastNode()->GetType(), "Init"); - ASSERT_EQ(init_node_outputs[2]->GetFastNode()->GetType(), "Init"); - ASSERT_EQ(init_node_outputs[0]->GetOutIndex(), 0); - ASSERT_EQ(init_node_outputs[1]->GetOutIndex(), 1); - ASSERT_EQ(init_node_outputs[2]->GetOutIndex(), 2); - - ASSERT_EQ(init_graph_outputs.size(), 3); - ASSERT_EQ(init_graph_outputs[0]->GetFastNode()->GetType(), "Foo1"); - ASSERT_EQ(init_graph_outputs[1]->GetFastNode()->GetType(), "Foo2"); - ASSERT_EQ(init_graph_outputs[2]->GetFastNode()->GetType(), "Foo2"); - ASSERT_EQ(init_graph_outputs[0]->GetOutIndex(), 0); - ASSERT_EQ(init_graph_outputs[1]->GetOutIndex(), 0); - ASSERT_EQ(init_graph_outputs[2]->GetOutIndex(), 1); - ASSERT_EQ(init_graph_outputs[1]->GetFastNode(), init_graph_outputs[2]->GetFastNode()); - - ASSERT_EQ(FastNodeTopoChecker(init_graph_outputs[0]).StrictConnectTo(0, {{"InnerNetOutput", 0}, {"Foo2", 0}}), "success"); - ASSERT_EQ(FastNodeTopoChecker(init_graph_outputs[1]).StrictConnectTo(0, {{"InnerNetOutput", 1}}), "success"); - ASSERT_EQ(FastNodeTopoChecker(init_graph_outputs[1]).StrictConnectTo(1, {{"InnerNetOutput", 2}}), "success"); -} -TEST_F(FastFrameSelectorUT, SelectInitRoot_AllConnectToInnerNetOutput_CallTwoTimes) { - InitTestFrames(); - auto outputs1 = - FrameSelector::OnInitRoot([]() -> std::vector { return {ValueHolder::CreateConst("Hello", 5)}; }); - ASSERT_EQ(outputs1.size(), 1); - - std::vector init_node_outputs; - std::vector init_graph_outputs; - auto ret = FrameSelector::OnInitRoot( - []() -> std::vector { - auto c1 = ValueHolder::CreateConst("Hello", 5); - auto c2 = ValueHolder::CreateConst("Hello", 5); - return ValueHolder::CreateDataOutput("Foo2", {c1, c2}, 2); - }, - init_graph_outputs, init_node_outputs); - ASSERT_EQ(ret, ge::GRAPH_SUCCESS); - - ASSERT_EQ(init_graph_outputs.size(), 2); - ASSERT_EQ(ExeGraphSummaryChecker(init_graph_outputs[0]->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr()) - .StrictAllNodeTypes({{"Const", 3}, {"Foo2", 1}, {"InnerNetOutput", 1}}), - "success"); - - ASSERT_EQ(init_node_outputs[1]->GetFastNode(), init_node_outputs[1]->GetFastNode()); - ASSERT_EQ(init_node_outputs[0]->GetFastNode()->GetType(), "Init"); - ASSERT_EQ(init_node_outputs[0]->GetOutIndex(), 1); - ASSERT_EQ(init_node_outputs[1]->GetOutIndex(), 2); - - ASSERT_EQ(init_graph_outputs.size(), 2); - ASSERT_EQ(init_graph_outputs[0]->GetFastNode()->GetType(), "Foo2"); - ASSERT_EQ(init_graph_outputs[1]->GetFastNode()->GetType(), "Foo2"); - ASSERT_EQ(init_graph_outputs[0]->GetOutIndex(), 0); - ASSERT_EQ(init_graph_outputs[1]->GetOutIndex(), 1); - ASSERT_EQ(init_graph_outputs[1]->GetFastNode(), init_graph_outputs[1]->GetFastNode()); - - ASSERT_EQ(FastNodeTopoChecker(init_graph_outputs[0]).StrictConnectTo(0, {{"InnerNetOutput", 1}}), "success"); - ASSERT_EQ(FastNodeTopoChecker(init_graph_outputs[1]).StrictConnectTo(1, {{"InnerNetOutput", 2}}), "success"); -} -TEST_F(FastFrameSelectorUT, SelectInitRoot_FrameCorrect_AfterSelection) { - InitTestFrames(); - - auto d0 = ValueHolder::CreateFeed(0); - auto d1 = ValueHolder::CreateFeed(1); - auto bar1 = ValueHolder::CreateSingleDataOutput("Bar1", {d0, d1}); - ValueHolder::PushGraphFrame(bar1, "Graph"); - - auto init_outputs = FrameSelector::OnInitRoot([]() -> std::vector { - auto c1 = ValueHolder::CreateConst("Hello", 5); - auto c2 = ValueHolder::CreateConst("Hello", 5); - return ValueHolder::CreateDataOutput("Foo2", {c1, c2}, 2); - }); - ASSERT_EQ(init_outputs.size(), 2); - - auto foo1 = ValueHolder::CreateSingleDataOutput("Foo1", {d0, d1}); - ValueHolder::PopGraphFrame({foo1}, {}); - - ASSERT_NE(bar1, nullptr); - auto bar1_graph = ge::FastNodeUtils::GetSubgraphFromNode(bar1->GetFastNode(), 0); - ASSERT_NE(bar1_graph, nullptr); - ASSERT_EQ(ExeGraphSummaryChecker(bar1_graph).StrictAllNodeTypes({{"InnerData", 2}, {"Foo1", 1}, {"InnerNetOutput", 1}}), - "success"); -} -TEST_F(FastFrameSelectorUT, SelectInitRoot_ConnectFromResultOk) { - InitTestFrames(); - - auto d0 = ValueHolder::CreateFeed(0); - auto d1 = ValueHolder::CreateFeed(1); - auto bar1 = ValueHolder::CreateSingleDataOutput("Bar1", {d0, d1}); - - const ge::FastNode *foo2_node; - auto init_outputs = FrameSelector::OnInitRoot([&]() -> std::vector { - auto c1 = ValueHolder::CreateConst("Hello", 5); - auto c2 = ValueHolder::CreateConst("Hello", 5); - auto foo2 = ValueHolder::CreateDataOutput("Foo2", {c1, c2}, 2); - foo2_node = foo2[0]->GetFastNode(); - return foo2; - }); - ASSERT_EQ(init_outputs.size(), 2); - - auto foo1 = ValueHolder::CreateSingleDataOutput("Foo1", {bar1, init_outputs[0]}); - - ASSERT_NE(foo1, nullptr); - ASSERT_EQ(foo1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr()->GetParentNodeBarePtr()->GetType(), "Main"); - ASSERT_EQ(FastNodeTopoChecker(foo1).StrictConnectFrom({{"Bar1"}, {"InnerData"}}), "success"); - ConnectFromInitToMain(foo2_node, 0, foo1->GetFastNode(), 1); -} -TEST_F(FastFrameSelectorUT, SelectInitRoot_PlacementCorrect) { - InitTestFrames(); - - auto init_outputs = FrameSelector::OnInitRoot([&]() -> std::vector { - auto c1 = ValueHolder::CreateConst("Hello", 5); - auto c2 = ValueHolder::CreateConst("Hello", 5); - auto foo2 = ValueHolder::CreateDataOutput("Foo2", {c1, c2}, 2); - foo2[0]->SetPlacement(kOnDeviceHbm); - foo2[1]->SetPlacement(kOnDeviceHbm); - auto foo3 = ValueHolder::CreateSingleDataOutput("Foo3", {c1, c2}); - foo3->SetPlacement(kOnHost); - return {foo2[0], foo2[1], foo3}; - }); - ASSERT_EQ(init_outputs.size(), 3); - EXPECT_EQ(init_outputs[0]->GetPlacement(), kOnDeviceHbm); - EXPECT_EQ(init_outputs[1]->GetPlacement(), kOnDeviceHbm); - EXPECT_EQ(init_outputs[2]->GetPlacement(), kOnHost); -} -TEST_F(FastFrameSelectorUT, SelectInitRoot_GuarderInDeInit_OutputHasGuarder) { - InitTestFrames(); - - auto init_outputs = FrameSelector::OnInitRoot([&]() -> std::vector { - auto c1 = ValueHolder::CreateConst("Hello", 5); - auto c2 = ValueHolder::CreateConst("Hello", 5); - auto foo1 = ValueHolder::CreateSingleDataOutput("Foo1", {c1}); - auto foo1_guarder = ValueHolder::CreateVoidGuarder("FreeFoo1", foo1, {}); - auto foo2 = ValueHolder::CreateDataOutput("Foo2", {c1, c2, foo1}, 2); - auto foo2_guarder = ValueHolder::CreateVoidGuarder("FreeFoo2", foo2[0], {}); - return foo2; - }); - ASSERT_EQ(init_outputs.size(), 2); - - ASSERT_EQ(ExeGraphSummaryChecker(init_frame->GetExecuteGraph().get()) - .StrictAllNodeTypes({{"Const", 2}, {"Foo1", 1}, {"FreeFoo1", 1}, {"Foo2", 1}, {"InnerNetOutput", 1}}), - "success"); - auto netoutput = ge::ExecuteGraphUtils::FindFirstNodeMatchType(init_frame->GetExecuteGraph().get(), "InnerNetOutput"); - ASSERT_NE(netoutput, nullptr); - ASSERT_EQ(FastNodeTopoChecker(netoutput).StrictConnectFrom({{"Foo2", 0}, {"Foo2", 1}}), "success"); - - ASSERT_EQ(ExeGraphSummaryChecker(de_init_frame->GetExecuteGraph().get()) - .StrictAllNodeTypes({{"InnerData", 1}, {"FreeFoo2", 1}}), - "success"); - - auto foo2 = ge::ExecuteGraphUtils::FindFirstNodeMatchType(init_frame->GetExecuteGraph().get(), "Foo2"); - ASSERT_NE(foo2, nullptr); - auto free_foo2 = ge::ExecuteGraphUtils::FindFirstNodeMatchType(de_init_frame->GetExecuteGraph().get(), "FreeFoo2"); - ASSERT_NE(free_foo2, nullptr); - ConnectFromInitToDeInit(foo2, 0, free_foo2, 0); -} -TEST_F(FastFrameSelectorUT, SelectInitRoot_GuarderInDeInit_MultipleOutputHasGuarder) { - InitTestFrames(); - - auto init_outputs = FrameSelector::OnInitRoot([&]() -> std::vector { - auto c1 = ValueHolder::CreateConst("Hello", 5); - auto c2 = ValueHolder::CreateConst("Hello", 5); - auto foo1 = ValueHolder::CreateSingleDataOutput("Foo1", {c1}); - auto foo1_guarder = ValueHolder::CreateVoidGuarder("FreeFoo1", foo1, {}); - auto foo2 = ValueHolder::CreateDataOutput("Foo2", {c1, c2, foo1}, 2); - auto foo2_guarder = ValueHolder::CreateVoidGuarder("FreeFoo2", foo2[1], {}); - return {foo1, foo2[0], foo2[1]}; - }); - ASSERT_EQ(init_outputs.size(), 3); - - ASSERT_EQ(ExeGraphSummaryChecker(init_frame->GetExecuteGraph().get()) - .StrictAllNodeTypes({{"Const", 2}, {"Foo1", 1}, {"Foo2", 1}, {"InnerNetOutput", 1}}), - "success"); - auto netoutput = - ge::ExecuteGraphUtils::FindFirstNodeMatchType(init_frame->GetExecuteGraph().get(), "InnerNetOutput"); - ASSERT_NE(netoutput, nullptr); - ASSERT_EQ(FastNodeTopoChecker(netoutput).StrictConnectFrom({{"Foo1", 0}, {"Foo2", 0}, {"Foo2", 1}}), "success"); - - ASSERT_EQ(ExeGraphSummaryChecker(de_init_frame->GetExecuteGraph().get()) - .StrictAllNodeTypes({{"InnerData", 2}, {"FreeFoo2", 1}, {"FreeFoo1", 1}}), - "success"); - - auto foo2 = ge::ExecuteGraphUtils::FindFirstNodeMatchType(init_frame->GetExecuteGraph().get(), "Foo2"); - ASSERT_NE(foo2, nullptr); - auto free_foo2 = ge::ExecuteGraphUtils::FindFirstNodeMatchType(de_init_frame->GetExecuteGraph().get(), "FreeFoo2"); - ASSERT_NE(free_foo2, nullptr); - ConnectFromInitToDeInit(foo2, 1, free_foo2, 0); - - auto foo1 = ge::ExecuteGraphUtils::FindFirstNodeMatchType(init_frame->GetExecuteGraph().get(), "Foo1"); - ASSERT_NE(foo1, nullptr); - auto free_foo1 = ge::ExecuteGraphUtils::FindFirstNodeMatchType(de_init_frame->GetExecuteGraph().get(), "FreeFoo1"); - ASSERT_NE(free_foo1, nullptr); - ConnectFromInitToDeInit(foo1, 0, free_foo1, 0); -} -TEST_F(FastFrameSelectorUT, SelectInitRoot_AllGuardersInDeInit_MultipleTimes) { - InitTestFrames(); - - ge::FastNode *g1_foo2_node; - ge::FastNode *g1_free_foo2_node; - ge::FastNode *g1_foo1_node; - ge::FastNode *g1_free_foo1_node; - auto init_outputs_1 = FrameSelector::OnInitRoot([&]() -> std::vector { - auto c1 = ValueHolder::CreateConst("Hello", 5); - auto c2 = ValueHolder::CreateConst("Hello", 5); - auto foo1 = ValueHolder::CreateSingleDataOutput("Foo1", {c1}); - auto foo1_guarder = ValueHolder::CreateVoidGuarder("FreeFoo1", foo1, {}); - auto foo2 = ValueHolder::CreateDataOutput("Foo2", {c1, c2, foo1}, 2); - auto foo2_guarder = ValueHolder::CreateVoidGuarder("FreeFoo2", foo2[1], {}); - g1_foo1_node = foo1->GetFastNode(); - g1_free_foo1_node = foo1_guarder->GetFastNode(); - g1_foo2_node = foo2[0]->GetFastNode(); - g1_free_foo2_node = foo2_guarder->GetFastNode(); - return {foo1, foo2[0], foo2[1]}; - }); - ASSERT_EQ(init_outputs_1.size(), 3); - - ge::FastNode *g2_foo2_node; - ge::FastNode *g2_free_foo2_node; - ge::FastNode *g2_foo1_node; - ge::FastNode *g2_free_foo1_node; - auto init_outputs_2 = FrameSelector::OnInitRoot([&]() -> std::vector { - auto c1 = ValueHolder::CreateConst("Hello", 5); - auto c2 = ValueHolder::CreateConst("Hello", 5); - auto foo1 = ValueHolder::CreateSingleDataOutput("Foo1", {c1}); - auto foo1_guarder = ValueHolder::CreateVoidGuarder("FreeFoo1", foo1, {}); - auto foo2 = ValueHolder::CreateDataOutput("Foo2", {c1, c2, foo1}, 2); - auto foo2_guarder = ValueHolder::CreateVoidGuarder("FreeFoo2", foo2[1], {}); - g2_foo1_node = foo1->GetFastNode(); - g2_free_foo1_node = foo1_guarder->GetFastNode(); - g2_foo2_node = foo2[0]->GetFastNode(); - g2_free_foo2_node = foo2_guarder->GetFastNode(); - return {foo1, foo2[0], foo2[1]}; - }); - ASSERT_EQ(init_outputs_2.size(), 3); - - ASSERT_EQ(ExeGraphSummaryChecker(init_frame->GetExecuteGraph().get()) - .StrictAllNodeTypes({{"Const", 4}, {"Foo1", 2}, {"Foo2", 2}, {"InnerNetOutput", 1}}), - "success"); - auto netoutput = ge::ExecuteGraphUtils::FindFirstNodeMatchType(init_frame->GetExecuteGraph().get(), "InnerNetOutput"); - ASSERT_NE(netoutput, nullptr); - ASSERT_EQ(FastNodeTopoChecker(netoutput).StrictConnectFrom( - {{"Foo1", 0}, {"Foo2", 0}, {"Foo2", 1}, {"Foo1", 0}, {"Foo2", 0}, {"Foo2", 1}}), - "success"); - - ASSERT_EQ(ExeGraphSummaryChecker(de_init_frame->GetExecuteGraph().get()) - .StrictAllNodeTypes({{"InnerData", 4}, {"FreeFoo2", 2}, {"FreeFoo1", 2}}), - "success"); - - // check 1 - auto foo2 = init_frame->GetExecuteGraph()->FindNode(g1_foo2_node->GetNodeToken()); - ASSERT_NE(foo2, nullptr); - auto free_foo2 = de_init_frame->GetExecuteGraph()->FindNode(g1_free_foo2_node->GetNodeToken()); - ASSERT_NE(free_foo2, nullptr); - ConnectFromInitToDeInit(foo2, 1, free_foo2, 0); - - auto foo1 = init_frame->GetExecuteGraph()->FindNode(g1_foo1_node->GetNodeToken()); - ASSERT_NE(foo1, nullptr); - auto free_foo1 = de_init_frame->GetExecuteGraph()->FindNode(g1_free_foo1_node->GetNodeToken()); - ASSERT_NE(free_foo1, nullptr); - ConnectFromInitToDeInit(foo1, 0, free_foo1, 0); - - // check 2 - foo2 = init_frame->GetExecuteGraph()->FindNode(g2_foo2_node->GetNodeToken()); - ASSERT_NE(foo2, nullptr); - free_foo2 = de_init_frame->GetExecuteGraph()->FindNode(g2_free_foo2_node->GetNodeToken()); - ASSERT_NE(free_foo2, nullptr); - ConnectFromInitToDeInit(foo2, 1, free_foo2, 0); - - foo1 = init_frame->GetExecuteGraph()->FindNode(g2_foo1_node->GetNodeToken()); - ASSERT_NE(foo1, nullptr); - free_foo1 = de_init_frame->GetExecuteGraph()->FindNode(g2_free_foo1_node->GetNodeToken()); - ASSERT_NE(free_foo1, nullptr); - ConnectFromInitToDeInit(foo1, 0, free_foo1, 0); -} -TEST_F(FastFrameSelectorUT, HolderOnInit_GetInput_WhenInputInInit) { - InitTestFrames(); - - std::vector graph_outputs; - std::vector node_outputs; - ValueHolderPtr foo1; - auto ret = FrameSelector::OnInitRoot( - [&]() -> std::vector { - auto c1 = ValueHolder::CreateConst("Hello", 5); - auto c2 = ValueHolder::CreateConst("Hello", 5); - foo1 = ValueHolder::CreateSingleDataOutput("Foo1", {c1}); - auto foo1_guarder = ValueHolder::CreateVoidGuarder("FreeFoo1", foo1, {}); - auto foo2 = ValueHolder::CreateDataOutput("Foo2", {c1, c2, foo1}, 3); - auto foo2_guarder = ValueHolder::CreateVoidGuarder("FreeFoo2", foo2[1], {}); - return foo2; - }, - graph_outputs, node_outputs); - ASSERT_EQ(ret, ge::GRAPH_SUCCESS); - - ASSERT_NE(foo1, nullptr); - ASSERT_EQ(HolderOnInit(foo1), foo1); - ASSERT_EQ(HolderOnInit(graph_outputs[0]), graph_outputs[0]); - ASSERT_EQ(HolderOnInit(graph_outputs[1]), graph_outputs[1]); - ASSERT_EQ(HolderOnInit(graph_outputs[2]), graph_outputs[2]); -} -TEST_F(FastFrameSelectorUT, HolderOnInit_GetInitOutput_WhenInputIsInitNode) { - InitTestFrames(); - - std::vector graph_outputs; - std::vector node_outputs; - ValueHolderPtr foo1; - auto ret = FrameSelector::OnInitRoot( - [&]() -> std::vector { - auto c1 = ValueHolder::CreateConst("Hello", 5); - auto c2 = ValueHolder::CreateConst("Hello", 5); - foo1 = ValueHolder::CreateSingleDataOutput("Foo1", {c1}); - auto foo1_guarder = ValueHolder::CreateVoidGuarder("FreeFoo1", foo1, {}); - auto foo2 = ValueHolder::CreateDataOutput("Foo2", {c1, c2, foo1}, 3); - auto foo2_guarder = ValueHolder::CreateVoidGuarder("FreeFoo2", foo2[1], {}); - return foo2; - }, - graph_outputs, node_outputs); - ASSERT_EQ(ret, ge::GRAPH_SUCCESS); - - ASSERT_NE(foo1, nullptr); - ASSERT_EQ(HolderOnInit(foo1), foo1); - ASSERT_EQ(HolderOnInit(node_outputs[0])->GetFastNode(), graph_outputs[0]->GetFastNode()); - ASSERT_EQ(HolderOnInit(node_outputs[0])->GetOutIndex(), graph_outputs[0]->GetOutIndex()); - ASSERT_EQ(HolderOnInit(node_outputs[1])->GetFastNode(), graph_outputs[1]->GetFastNode()); - ASSERT_EQ(HolderOnInit(node_outputs[1])->GetOutIndex(), graph_outputs[1]->GetOutIndex()); - ASSERT_EQ(HolderOnInit(node_outputs[2])->GetFastNode(), graph_outputs[2]->GetFastNode()); - ASSERT_EQ(HolderOnInit(node_outputs[2])->GetOutIndex(), graph_outputs[2]->GetOutIndex()); -} - -TEST_F(FastFrameSelectorUT, OnMainRootLast_GetInitOutput_WhenInputIsInitNode) { - InitTestFrames(); - - std::vector graph_outputs; - std::vector node_outputs; - ValueHolderPtr foo1; - auto ret = FrameSelector::OnInitRoot( - [&]() -> std::vector { - auto c1 = ValueHolder::CreateConst("Hello", 5); - auto c2 = ValueHolder::CreateConst("Hello", 5); - foo1 = ValueHolder::CreateSingleDataOutput("Foo1", {c1}); - auto foo1_guarder = ValueHolder::CreateVoidGuarder("FreeFoo1", foo1, {}); - auto foo2 = ValueHolder::CreateDataOutput("Foo2", {c1, c2, foo1}, 3); - auto foo2_guarder = ValueHolder::CreateVoidGuarder("FreeFoo2", foo2[1], {}); - return foo2; - }, - graph_outputs, node_outputs); - ASSERT_EQ(ret, ge::GRAPH_SUCCESS); - - ASSERT_NE(foo1, nullptr); - ASSERT_EQ(HolderOnInit(foo1), foo1); - ASSERT_EQ(HolderOnInit(node_outputs[0])->GetFastNode(), graph_outputs[0]->GetFastNode()); - ASSERT_EQ(HolderOnInit(node_outputs[0])->GetOutIndex(), graph_outputs[0]->GetOutIndex()); - ASSERT_EQ(HolderOnInit(node_outputs[1])->GetFastNode(), graph_outputs[1]->GetFastNode()); - ASSERT_EQ(HolderOnInit(node_outputs[1])->GetOutIndex(), graph_outputs[1]->GetOutIndex()); - ASSERT_EQ(HolderOnInit(node_outputs[2])->GetFastNode(), graph_outputs[2]->GetFastNode()); - ASSERT_EQ(HolderOnInit(node_outputs[2])->GetOutIndex(), graph_outputs[2]->GetOutIndex()); -} -/* - * data0 c0 - * \ / - * bar1 - */ -TEST_F(FastFrameSelectorUT, OnMainRootLast_SetLastExecNode_NoNetoutput_Fail) { - InitTestFrames(); - // build exe graph on main frame - auto data0 = ValueHolder::CreateFeed(0); - auto c0 = ValueHolder::CreateConst("ConstData", 10, true); - auto bar1 = ValueHolder::CreateSingleDataOutput("Bar1", {c0, data0}); - - auto last_node_builder = [&]() -> bg::ValueHolderPtr { - return bg::ValueHolder::CreateVoid("LastExecNode", {c0}); - }; - auto last_holder = bg::FrameSelector::OnMainRootLast(last_node_builder); - EXPECT_NE(last_holder, nullptr); - auto last_exec_nodes = bg::ValueHolder::GetLastExecNodes(); - EXPECT_FALSE(last_exec_nodes.empty()); - - auto main_frame = ValueHolder::PopGraphFrame(); - ASSERT_EQ(ExeGraphSummaryChecker(main_frame->GetExecuteGraph().get()) - .StrictDirectNodeTypes({{"Data", 1}, {"Const", 1}, {"Bar1", 1}, {"LastExecNode", 1}}), - "success"); -} - -/* - * data0 c0 c0 - * \ / | - * bar1 last_node: LastExecNode - * | - * netoutput - * - */ -TEST_F(FastFrameSelectorUT, OnMainRootLast_SetLastExecNode_OneNodeInsideOneLayer_DefaultPriority) { - InitTestFrames(); - // build exe graph on main frame - auto data0 = ValueHolder::CreateFeed(0); - auto c0 = ValueHolder::CreateConst("ConstData", 10, true); - auto bar1 = ValueHolder::CreateSingleDataOutput("Bar1", {c0, data0}); - auto output = ValueHolder::CreateSingleDataOutput("NetOutput", {bar1}); - - auto last_node_builder = [&]() -> bg::ValueHolderPtr { - return bg::ValueHolder::CreateVoid("LastExecNode", {c0}); - }; - auto last_holder = bg::FrameSelector::OnMainRootLast(last_node_builder); - EXPECT_NE(last_holder, nullptr); - auto last_exec_nodes = bg::ValueHolder::GetLastExecNodes(); - EXPECT_FALSE(last_exec_nodes.empty()); - - auto main_frame = ValueHolder::PopGraphFrame(); - ASSERT_EQ(ExeGraphSummaryChecker(main_frame->GetExecuteGraph().get()) - .StrictDirectNodeTypes({{"Data", 1}, {"Const", 1}, {"Bar1", 1}, {"LastExecNode", 1}, {"NetOutput", 1}}), - "success"); - ASSERT_EQ(FastNodeTopoChecker(output).StrictConnectFrom({{"Bar1"}}), "success"); -} -/* - * data0 c0 c0 data0 c0 - * \ / / \ \ / - * bar1 last_node: L0 L1 ====> bar1 - * | / \c - * netoutput | noop - * | /c \c - * | L0 L1 - * | \c /c - * | noop - * \ /c - * netoutput - * 这里L0和L1分别有来自c0的数据输入。示意图中忽略 - */ -TEST_F(FastFrameSelectorUT, OnMainRootLast_SetLastExecNode_TwoNodeInsideOneLayer_DefaultPriority) { - InitTestFrames(); - // build exe graph on main frame - auto data0 = ValueHolder::CreateFeed(0); - auto c0 = ValueHolder::CreateConst("ConstData", 10, true); - auto bar1 = ValueHolder::CreateSingleDataOutput("Bar1", {c0, data0}); - auto output = ValueHolder::CreateSingleDataOutput("NetOutput", {bar1}); - - auto last0_node_builder = [&]() -> bg::ValueHolderPtr { - return bg::ValueHolder::CreateVoid("L0", {c0}); - }; - auto last0_holder = bg::FrameSelector::OnMainRootLast(last0_node_builder); - EXPECT_NE(last0_holder, nullptr); - auto last1_node_builder = [&]() -> bg::ValueHolderPtr { - return bg::ValueHolder::CreateVoid("L1", {c0}); - }; - auto last1_holder = bg::FrameSelector::OnMainRootLast(last1_node_builder); - EXPECT_NE(last1_holder, nullptr); - auto last_exec_nodes = bg::ValueHolder::GetLastExecNodes(); - EXPECT_EQ(last_exec_nodes.size(), 2U); - - auto main_frame = ValueHolder::PopGraphFrame(); - ASSERT_EQ( - ExeGraphSummaryChecker(main_frame->GetExecuteGraph().get()) - .StrictDirectNodeTypes({{"Data", 1}, {"Const", 1}, {"Bar1", 1}, {"L0", 1}, {"L1", 1}, {"NetOutput", 1}}), - "success"); -} -/* - * data0 c0 c0 c0 - * \ / / \ / \ - * bar1 last_priority0: L0 L1 last_priority1: L2 L3 - * | - * netoutput - * - * | - * V subgraph of PartitionCall - * +----------------------------------+ +-------------------------+ - * | data0 c0 | | InnerData | - * | \ / \ | | / \ \ \ | - * | bar1 PartitionedCall | | L0 L1 L2 L3 | - * | | | | \ / / / | - * | netoutput | | InnerNetoutput | - * +----------------------------------+ +-------------------------+ - * 这里L0和L1分别有来自c0的数据输入。示意图中忽略 - */ -TEST_F(FastFrameSelectorUT, OnMainRootLast_SetLastExecNode_TwoNodeInsideEachLayer_TwoPriority) { - InitTestFrames(); - // build exe graph on main frame - auto data0 = ValueHolder::CreateFeed(0); - auto c0 = ValueHolder::CreateConst("ConstData", 10, true); - auto bar1 = ValueHolder::CreateSingleDataOutput("Bar1", {c0, data0}); - auto bar2 = ValueHolder::CreateSingleDataOutput("Bar2", {c0}); - auto output = ValueHolder::CreateSingleDataOutput("NetOutput", {bar1, bar2}); - - for (size_t node_size = 0U; node_size < 2; ++node_size) { - auto last_node_builder = [&]() -> bg::ValueHolderPtr { return bg::ValueHolder::CreateVoid("LastExec", {c0}); }; - auto last_holder = - bg::FrameSelector::OnMainRootLast(last_node_builder); - EXPECT_NE(last_holder, nullptr); - } - for (size_t node_size = 0U; node_size < 2; ++node_size) { - auto last_node_builder = [&]() -> std::vector { return {bg::ValueHolder::CreateVoid("LastEventExec", {data0})}; }; - auto last_holders = - bg::FrameSelector::OnMainRootLastEventSync(last_node_builder); - EXPECT_FALSE(last_holders.empty()); - } - - auto main_frame = ValueHolder::PopGraphFrame(); - std::vector last_exe_nodes = main_frame->GetLastExecNodes(); - EXPECT_EQ(last_exe_nodes.size(), 2); - - auto stage_ids_2_pcalls = main_frame->GetExecuteGraph()->GetExtAttr>(kStageIdsToLastPartitionedCall); - EXPECT_NE(stage_ids_2_pcalls, nullptr); - auto last_event_sync_exe_node = stage_ids_2_pcalls->at(static_cast(OnMainRootLastExecStage::kLastEventSyncStage)); - EXPECT_NE(last_event_sync_exe_node, nullptr); - - auto sub_exe_graph = ge::FastNodeUtils::GetSubgraphFromNode(last_event_sync_exe_node->GetFastNode(), 0U); - EXPECT_NE(sub_exe_graph, nullptr); - auto last_exec_node = ge::ExecuteGraphUtils::FindFirstNodeMatchType(sub_exe_graph, "LastEventExec"); - EXPECT_NE(last_exec_node, nullptr); - ASSERT_EQ(FastNodeTopoChecker(last_exec_node).StrictConnectFrom({{"InnerData", 0}}), "success"); - ASSERT_NE(FastNodeTopoChecker(last_exec_node).StrictConnectTo(-1, {{"InnerNetoutput", 0}}), "success"); - - ASSERT_EQ(ExeGraphSummaryChecker(sub_exe_graph) - .StrictDirectNodeTypes({{"LastEventExec", 2}, {"InnerData", 1}, {"InnerNetOutput", 1}}), - "success"); - - ASSERT_EQ(ExeGraphSummaryChecker(main_frame->GetExecuteGraph().get()) - .StrictDirectNodeTypes({{"Data", 1}, - {"Const", 1}, - {"Bar1", 1}, - {"Bar2", 1}, - {"LastExec", 2}, - {"PartitionedCall", 1}, - {"NetOutput", 1}}), - "success"); - ASSERT_EQ(FastNodeTopoChecker(output).StrictConnectFrom({{"Bar1"}, {"Bar2"}}), "success"); -} - -/* - * data0 c0 c0 data0 c0 - * \ / / \ \ / - * bar1 last_node: L0 L1 ====> bar1 - * | / \c - * netoutput | noop - * | /c \c - * | L0 L1 - * | \c /c - * | noop - * \ /c - * netoutput - * 这里L0和L1分别有来自c0的数据输入。示意图中忽略 - */ -TEST_F(FastFrameSelectorUT, OnMainRootLast_SetLastExecNode_TwoNodeInsideOneLayer_LastResourceCleanStage) { - InitTestFrames(); - // build exe graph on main frame - auto data0 = ValueHolder::CreateFeed(0); - auto c0 = ValueHolder::CreateConst("ConstData", 10, true); - auto bar1 = ValueHolder::CreateSingleDataOutput("Bar1", {c0, data0}); - auto output = ValueHolder::CreateSingleDataOutput("NetOutput", {bar1}); - - auto last0_node_builder = [&]() -> std::vector { - return {bg::ValueHolder::CreateVoid("L0", {c0})}; - }; - auto last0_holder = bg::FrameSelector::OnMainRootLastResourceClean(last0_node_builder); - EXPECT_FALSE(last0_holder.empty()); - auto last1_node_builder = [&]() -> std::vector { - return {bg::ValueHolder::CreateVoid("L1", {c0})}; - }; - auto last1_holder = bg::FrameSelector::OnMainRootLastResourceClean(last1_node_builder); - EXPECT_FALSE(last1_holder.empty()); - auto last_exec_nodes = bg::ValueHolder::GetLastExecNodes(); - EXPECT_EQ(last_exec_nodes.size(), 0U); - - auto main_frame = ValueHolder::PopGraphFrame(); - ASSERT_EQ( - ExeGraphSummaryChecker(main_frame->GetExecuteGraph().get()) - .StrictDirectNodeTypes({{"Data", 1}, {"Const", 1}, {"Bar1", 1}, {"PartitionedCall", 1}, {"NetOutput", 1}}), - "success"); - - auto stage_ids_2_pcalls = main_frame->GetExecuteGraph()->GetExtAttr>(kStageIdsToLastPartitionedCall); - EXPECT_NE(stage_ids_2_pcalls, nullptr); - auto last_resource_clean_exe_node = stage_ids_2_pcalls->at(static_cast(OnMainRootLastExecStage::kLastResourceClean)); - EXPECT_NE(last_resource_clean_exe_node, nullptr); - - auto sub_exe_graph = ge::FastNodeUtils::GetSubgraphFromNode(last_resource_clean_exe_node->GetFastNode(), 0U); - EXPECT_NE(sub_exe_graph, nullptr); - auto last_exec_node = ge::ExecuteGraphUtils::FindFirstNodeMatchType(sub_exe_graph, "L0"); - EXPECT_NE(last_exec_node, nullptr); - ASSERT_EQ(FastNodeTopoChecker(last_exec_node).StrictConnectFrom({{"InnerData", 0}}), "success"); - ASSERT_NE(FastNodeTopoChecker(last_exec_node).StrictConnectTo(-1, {{"InnerNetoutput", 0}}), "success"); - - ASSERT_EQ(ExeGraphSummaryChecker(sub_exe_graph) - .StrictDirectNodeTypes({{"L0", 1}, {"L1", 1}, {"InnerData", 1}, {"InnerNetOutput", 1}}), - "success"); - ASSERT_EQ(FastNodeTopoChecker(output).StrictConnectFrom({{"Bar1"}}), "success"); -} - -/* - * data0 c0 c0 - * \ / | - * bar1 first_node: FirstExecNode - * | - * netoutput - * | - * V - * - * +----------------------------------+ +-------------------------+ - * | data0 c0 | | InnerData | - * | \ / \ | | | | - * | bar1 PartitionedCall | | FirstExecNode | - * | | | | | | - * | netoutput | | InnerNetoutput | - * +----------------------------------+ +-------------------------+ - */ -TEST_F(FastFrameSelectorUT, OnMainRootLast_SetFirstExecNode_OneNodeInsideOneLayer_OnePriority) { - dlog_setlevel(GE_MODULE_NAME, DLOG_DEBUG, 1); - InitTestFrames(); - // build exe graph on main frame - auto data0 = ValueHolder::CreateFeed(0); - auto c0 = ValueHolder::CreateConst("ConstData", 10, true); - auto bar1 = ValueHolder::CreateSingleDataOutput("Bar1", {c0, data0}); - auto output = ValueHolder::CreateSingleDataOutput("NetOutput", {bar1}); - - auto first_node_builder = [&]() -> std::vector { - return {bg::ValueHolder::CreateVoid("FirstExecNode", {c0})}; - }; - auto first_holders = bg::FrameSelector::OnMainRootFirst(first_node_builder); - EXPECT_FALSE(first_holders.empty()); - - auto main_frame = ValueHolder::PopGraphFrame(); - auto stage_ids_2_pcalls = main_frame->GetExecuteGraph()->GetExtAttr>(kStageIdsToFirstPartitionedCall); - EXPECT_NE(stage_ids_2_pcalls, nullptr); - auto first_exec_partitioncall = stage_ids_2_pcalls->at(static_cast(OnMainRootFirstExecStage::kFirstEventSyncStage)); - EXPECT_NE(first_exec_partitioncall, nullptr); - - auto sub_exe_graph = ge::FastNodeUtils::GetSubgraphFromNode(first_exec_partitioncall->GetFastNode(), 0U); - EXPECT_NE(sub_exe_graph, nullptr); - auto first_exec_node = ge::ExecuteGraphUtils::FindFirstNodeMatchType(sub_exe_graph, "FirstExecNode"); - EXPECT_NE(first_exec_node, nullptr); - ASSERT_EQ(FastNodeTopoChecker(first_exec_node).StrictConnectFrom({{"InnerData", 0}}), "success"); - ASSERT_NE(FastNodeTopoChecker(first_exec_node).StrictConnectTo(-1, {{"InnerNetoutput", 0}}), "success"); - - ASSERT_EQ(ExeGraphSummaryChecker(sub_exe_graph) - .StrictDirectNodeTypes({{"FirstExecNode", 1}, {"InnerData", 1}, {"InnerNetOutput", 1}}), - "success"); - - ASSERT_EQ(ExeGraphSummaryChecker(main_frame->GetExecuteGraph().get()) - .StrictDirectNodeTypes( - {{"Data", 1}, {"Const", 1}, {"Bar1", 1}, {"PartitionedCall", 1}, {"NetOutput", 1}}), - "success"); - ASSERT_EQ(FastNodeTopoChecker(data0).StrictConnectTo(0, {{"Bar1"}}), "success"); - ASSERT_NE(FastNodeTopoChecker(data0).StrictConnectTo(-1, {{"NoOp", -1}}), "success"); -} - -/* - * data0 c0 c0 - * / \ / / \ - * bar1 bar2 first_priority0: L0 L1 - * \ | - * netoutput - * - * | - * V - * +----------------------------------+ +-------------------------+ - * | data0 c0 | | InnerData | - * | / \ /\ | | / \ | - * | bar1 bar2 PartitionedCall | | L0 L1 | - * | \ | | | \ / | - * | netoutput | | InnerNetoutput | - * +----------------------------------+ +-------------------------+ - */ -TEST_F(FastFrameSelectorUT, OnMainRootLast_SetFirstExecNode_TwoNodeInsideEachLayer_TwoPriority) { - InitTestFrames(); - // build exe graph on main frame - auto data0 = ValueHolder::CreateFeed(0); - auto c0 = ValueHolder::CreateConst("ConstData", 10, true); - auto bar1 = ValueHolder::CreateSingleDataOutput("Bar1", {data0}); - auto bar2 = ValueHolder::CreateSingleDataOutput("Bar2", {data0, c0}); - auto output = ValueHolder::CreateSingleDataOutput("NetOutput", {bar1, bar2}); - - for (size_t node_size=0U; node_size < 2; ++node_size) { - auto first_node_builder = [&]() -> std::vector { - return {bg::ValueHolder::CreateVoid("FirstExec", {c0})}; - }; - auto first_holders = bg::FrameSelector::OnMainRootFirst(first_node_builder); - EXPECT_FALSE(first_holders.empty()); - } - - auto main_frame = ValueHolder::PopGraphFrame(); - - auto stage_ids_2_pcalls = main_frame->GetExecuteGraph()->GetExtAttr>(kStageIdsToFirstPartitionedCall); - EXPECT_NE(stage_ids_2_pcalls, nullptr); - auto first_exec_partitioncall = stage_ids_2_pcalls->at(static_cast(OnMainRootFirstExecStage::kFirstEventSyncStage)); - EXPECT_NE(first_exec_partitioncall, nullptr); - - auto sub_exe_graph = ge::FastNodeUtils::GetSubgraphFromNode(first_exec_partitioncall->GetFastNode(), 0U); - EXPECT_NE(sub_exe_graph, nullptr); - ASSERT_EQ(ExeGraphSummaryChecker(sub_exe_graph) - .StrictDirectNodeTypes({{"InnerData", 1}, - {"FirstExec", 2}, - {"InnerNetOutput", 1}}), - "success"); - auto first_exec_node = ge::ExecuteGraphUtils::FindFirstNodeMatchType(sub_exe_graph, "FirstExec"); - EXPECT_NE(first_exec_node, nullptr); - ASSERT_EQ(FastNodeTopoChecker(first_exec_node).StrictConnectFrom({{"InnerData", 0}}), "success"); - ASSERT_NE(FastNodeTopoChecker(first_exec_node).StrictConnectTo(-1, {{"InnerNetoutput", 0}}), "success"); - - - ASSERT_EQ(ExeGraphSummaryChecker(main_frame->GetExecuteGraph().get()) - .StrictDirectNodeTypes({{"Data", 1}, - {"Const", 1}, - {"Bar1", 1}, - {"Bar2", 1}, - {"PartitionedCall", 1}, - {"NetOutput", 1}}), - "success"); - ASSERT_EQ(FastNodeTopoChecker(output).StrictConnectFrom({{"Bar1"}, {"Bar2"}}), "success"); -} - -TEST_F(FastFrameSelectorUT, SelectDeInitRoot_CreateOnDeInitRoot_CurrentFrameIsRoot) { - InitTestFrames(); - auto frame = ValueHolder::PopGraphFrame(); - - // on RootGraph - const auto &bars = FrameSelector::OnDeInitRoot([]() -> std::vector { - auto bar1 = ValueHolder::CreateVoid("DavinviModelFinalize", {}); - return {bar1}; - }); - ASSERT_EQ(bars.size(), 1); - ASSERT_NE(bars[0], nullptr); - - ASSERT_EQ(ExeGraphSummaryChecker(de_init_frame->GetExecuteGraph().get()) - .StrictDirectNodeTypes({{"DavinviModelFinalize", 1}}), - "success"); - - ASSERT_TRUE(bars[0]->GetFastNode()->GetAllOutNodes().empty()); - - const auto &c1 = ValueHolder::CreateConst("ConstData", 10, true); - const auto &c2 = ValueHolder::CreateConst("ConstData", 10, true); - ValueHolder::PushGraphFrame(frame.release()); - ASSERT_EQ(bars[0]->AppendInputs({c1, c2}), 0); - ASSERT_EQ(ExeGraphSummaryChecker(de_init_frame->GetExecuteGraph().get()) - .StrictDirectNodeTypes({{"DavinviModelFinalize", 1}, {"InnerData", 2}}), - "success"); - const auto &c3 = ValueHolder::CreateConst("ConstData", 10, true); - ASSERT_EQ(c3->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr()->GetParentNodeBarePtr()->GetType(), "Main"); -} - -TEST_F(FastFrameSelectorUT, SelectDeInitRoot_CreateOnDeInitRoot_CurrentFrameIsMain) { - InitTestFrames(); - - // on RootGraph - const auto &bars = FrameSelector::OnDeInitRoot([]() -> std::vector { - auto bar1 = ValueHolder::CreateVoid("DavinviModelFinalize", {}); - return {bar1}; - }); - ASSERT_EQ(bars.size(), 1); - ASSERT_NE(bars[0], nullptr); - - ASSERT_EQ(ExeGraphSummaryChecker(de_init_frame->GetExecuteGraph().get()) - .StrictDirectNodeTypes({{"DavinviModelFinalize", 1}}), - "success"); - - ASSERT_TRUE(bars[0]->GetFastNode()->GetAllOutNodes().empty()); - - const auto &c1 = ValueHolder::CreateConst("ConstData", 10, true); - ASSERT_EQ(c1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr()->GetParentNodeBarePtr()->GetType(), "Main"); - - ASSERT_NE(bars[0]->AppendInputs({c1}), 0); -} - -TEST_F(FastFrameSelectorUT, SelectDeInitRoot_Failed_BuilderIsNullptr) { - ASSERT_EQ(FrameSelector::OnDeInitRoot(nullptr).size(), 0); -} -} // namespace bg -} // namespace gert diff --git a/tests/ut/exe_graph/fast_generate_exe_graph_unittest.cc b/tests/ut/exe_graph/fast_generate_exe_graph_unittest.cc deleted file mode 100644 index dc6b7ff50421cf9cf67ea21e0277e5910b7f35fb..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/fast_generate_exe_graph_unittest.cc +++ /dev/null @@ -1,138 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/lowering/generate_exe_graph.h" -#include -#include "exe_graph/lowering/dev_mem_value_holder.h" -#include "checker/bg_test.h" -#include "checker/topo_checker.h" -namespace gert { -using namespace bg; -namespace { -std::vector StubInferShape(const ge::NodePtr &node, - const std::vector &shapes, - LoweringGlobalData &globalData) { - return ValueHolder::CreateDataOutput("InferShape", shapes, 10); -} -std::vector StubAllocOutputMemory(TensorPlacement placement, const ge::NodePtr &node, - const std::vector &output_sizes, - LoweringGlobalData &global_data) { - return DevMemValueHolder::CreateDataOutput("AllocOutputMemory", output_sizes, output_sizes.size(), 0); -} -std::vector StubCalcTensorSize(const ge::NodePtr &node, - const std::vector &output_shapes) { - return ValueHolder::CreateDataOutput("CalcTensorSize", output_shapes, output_shapes.size()); -} -ge::NodePtr FakeNode() { - static size_t counter = 0; - static ge::ComputeGraphPtr graph = std::make_shared("graph"); - auto op_desc = std::make_shared("FakeNode_" + std::to_string(counter++), "FakeNode"); - return graph->AddNode(op_desc); -} -} // namespace -class FastGenerateExeGraphUT : public BgTest { - protected: - void SetUp() override { - BgTest::SetUp(); - bg::GenerateExeGraph::AddBuilderImplement({nullptr, nullptr, nullptr}); - } - void InitTestFramesWithStream(LoweringGlobalData &global_data) { - root_frame = bg::ValueHolder::GetCurrentFrame(); - auto init_node = bg::ValueHolder::CreateVoid("Init", {}); - bg::ValueHolder::PushGraphFrame(init_node, "Init"); - global_data.LoweringAndSplitRtStreams(1); - init_frame = bg::ValueHolder::PopGraphFrame(); - - auto de_init_node = bg::ValueHolder::CreateVoid("DeInit", {}); - bg::ValueHolder::PushGraphFrame(de_init_node, "DeInit"); - de_init_frame = bg::ValueHolder::PopGraphFrame(); - - auto main_node = bg::ValueHolder::CreateVoid(GetExecuteGraphTypeStr(ExecuteGraphType::kMain), {}); - bg::ValueHolder::PushGraphFrame(main_node, "Main"); - global_data.LoweringAndSplitRtStreams(1); - } - bg::GraphFrame *root_frame; - std::unique_ptr init_frame; - std::unique_ptr de_init_frame; -}; - -TEST_F(FastGenerateExeGraphUT, NoImpl_Failed_InferShape) { - LoweringGlobalData gd; - ASSERT_TRUE(bg::GenerateExeGraph::InferShape(nullptr, {bg::ValueHolder::CreateFeed(0)}, gd).empty()); -} -TEST_F(FastGenerateExeGraphUT, NoImpl_Failed_AllocOutputMemory) { - LoweringGlobalData gd; - ASSERT_TRUE( - bg::GenerateExeGraph::AllocOutputMemory(kOnDeviceHbm, nullptr, {bg::ValueHolder::CreateFeed(0)}, gd).empty()); -} -TEST_F(FastGenerateExeGraphUT, NoImpl_Failed_CalcTensorSize) { - LoweringGlobalData gd; - ASSERT_TRUE(bg::GenerateExeGraph::CalcTensorSize(nullptr, {bg::ValueHolder::CreateFeed(0)}).empty()); -} - -TEST_F(FastGenerateExeGraphUT, StubImpl_GraphCorrect_InferShape) { - bg::GenerateExeGraph::AddBuilderImplement({StubInferShape, nullptr, nullptr}); - auto input_shape = bg::ValueHolder::CreateFeed(0); - LoweringGlobalData gd; - auto shapes = bg::GenerateExeGraph::InferShape(nullptr, {input_shape}, gd); - ASSERT_EQ(shapes.size(), 10); - ASSERT_EQ(shapes[0]->GetFastNode()->GetType(), "InferShape"); - ASSERT_EQ(FastNodeTopoChecker(shapes[0]).StrictConnectFrom({{input_shape}}), "success"); -} -TEST_F(FastGenerateExeGraphUT, StubImpl_GraphCorrect_AllocOutputMemory) { - bg::GenerateExeGraph::AddBuilderImplement({nullptr, StubAllocOutputMemory, nullptr}); - auto input_shape0 = bg::ValueHolder::CreateFeed(0); - auto input_shape1 = bg::ValueHolder::CreateFeed(1); - - LoweringGlobalData gd; - auto shapes = bg::GenerateExeGraph::AllocOutputMemory(kOnDeviceHbm, nullptr, {input_shape0, input_shape1}, gd); - - ASSERT_EQ(shapes.size(), 2); - ASSERT_EQ(shapes[0]->GetFastNode()->GetType(), "AllocOutputMemory"); - ASSERT_EQ(FastNodeTopoChecker(shapes[0]).StrictConnectFrom({{input_shape0, input_shape1}}), "success"); -} -TEST_F(FastGenerateExeGraphUT, StubImpl_GraphCorrect_CalcTensorSize) { - bg::GenerateExeGraph::AddBuilderImplement({nullptr, nullptr, StubCalcTensorSize}); - auto input_shape0 = bg::ValueHolder::CreateFeed(0); - auto input_shape1 = bg::ValueHolder::CreateFeed(1); - - LoweringGlobalData gd; - auto shapes = bg::GenerateExeGraph::CalcTensorSize(nullptr, {input_shape0, input_shape1}); - - ASSERT_EQ(shapes.size(), 2); - ASSERT_EQ(shapes[0]->GetFastNode()->GetType(), "CalcTensorSize"); - ASSERT_EQ(FastNodeTopoChecker(shapes[0]).StrictConnectFrom({{input_shape0, input_shape1}}), "success"); -} -TEST_F(FastGenerateExeGraphUT, MakeSureTensorAtHost_Success) { - LoweringGlobalData global_data; - InitTestFramesWithStream(global_data); - auto src_addr = (void *) (0x1); - auto addr = bg::ValueHolder::CreateConst(&src_addr, sizeof(void *)); - size_t src_size = 8U; - auto size = bg::ValueHolder::CreateConst(&src_size, sizeof(size_t)); - auto node = FakeNode().get(); - node->GetOpDesc()->SetStreamId(1); - ASSERT_NE(bg::GenerateExeGraph::MakeSureTensorAtHost(node, global_data, addr, size), nullptr); -} -TEST_F(FastGenerateExeGraphUT, CalcTensorSizeFromShape_Success) { - LoweringGlobalData global_data; - InitTestFramesWithStream(global_data); - auto shape = ge::Shape(std::vector({1, 2, 3, 4})); - auto shape_holder = bg::ValueHolder::CreateConst(&shape, sizeof(ge::Shape)); - ASSERT_NE(bg::GenerateExeGraph::CalcTensorSizeFromShape(ge::DT_UINT8, shape_holder), nullptr); -} -TEST_F(FastGenerateExeGraphUT, FreeMemoryGuarder_Success) { - auto data_addr = (void *) (0x1); - auto addr = bg::ValueHolder::CreateConst(&data_addr, sizeof(void *)); - size_t data_size = 8; - auto size = ValueHolder::CreateConst(&data_size, sizeof(size_t)); - auto addr_holder = ValueHolder::CreateSingleDataOutput("CopyD2H", {addr, size}); - ASSERT_NE(bg::GenerateExeGraph::FreeMemoryGuarder(addr_holder), nullptr); -} -} // namespace gert diff --git a/tests/ut/exe_graph/fast_lowering_global_data_unittest.cc b/tests/ut/exe_graph/fast_lowering_global_data_unittest.cc deleted file mode 100644 index c97cdd6d60913f063ae5220c2d12df0dcc28c6fc..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/fast_lowering_global_data_unittest.cc +++ /dev/null @@ -1,1355 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/lowering/lowering_global_data.h" -#include "exe_graph/lowering/frame_selector.h" -#include -#include "checker/bg_test.h" -#include "exe_graph/lowering/value_holder.h" -#include "exe_graph/runtime/execute_graph_types.h" -#include "checker/summary_checker.h" -#include "checker/topo_checker.h" -#include "exe_graph/lowering/lowering_opt.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/graph_dump_utils.h" - -namespace gert { -namespace { -ge::NodePtr BuildTestNode() { - auto graph = std::make_shared("graph"); - auto op_desc = std::make_shared("node", "node"); - return graph->AddNode(op_desc); -} -} // namespace -class FastLoweringGlobalDataUT : public BgTest { - protected: - void SetUp() override { - BgTest::SetUp(); - } - - void InitTestFrames() { - root_frame = bg::ValueHolder::GetCurrentFrame(); - auto init_node = bg::ValueHolder::CreateVoid("Init", {}); - bg::ValueHolder::PushGraphFrame(init_node, "Init"); - init_frame = bg::ValueHolder::PopGraphFrame(); - - auto de_init_node = bg::ValueHolder::CreateVoid("DeInit", {}); - bg::ValueHolder::PushGraphFrame(de_init_node, "DeInit"); - de_init_frame = bg::ValueHolder::PopGraphFrame(); - - auto main_node = bg::ValueHolder::CreateVoid(GetExecuteGraphTypeStr(ExecuteGraphType::kMain), {}); - bg::ValueHolder::PushGraphFrame(main_node, "Main"); - } - void InitTestFramesWithStream(LoweringGlobalData &global_data, int64_t stream_num = 1) { - root_frame = bg::ValueHolder::GetCurrentFrame(); - auto init_node = bg::ValueHolder::CreateVoid("Init", {}); - bg::ValueHolder::PushGraphFrame(init_node, "Init"); - global_data.LoweringAndSplitRtStreams(1); - // prepare stream num in init - auto init_out = bg::FrameSelector::OnInitRoot([&stream_num, &global_data]()-> std::vector { - auto stream_num_holder = bg::ValueHolder::CreateConst(&stream_num, sizeof(stream_num)); - global_data.SetUniqueValueHolder(kGlobalDataModelStreamNum, stream_num_holder); - return {}; - }); - init_frame = bg::ValueHolder::PopGraphFrame(); - - auto de_init_node = bg::ValueHolder::CreateVoid("DeInit", {}); - bg::ValueHolder::PushGraphFrame(de_init_node, "DeInit"); - de_init_frame = bg::ValueHolder::PopGraphFrame(); - - auto main_node = bg::ValueHolder::CreateVoid(GetExecuteGraphTypeStr(ExecuteGraphType::kMain), {}); - bg::ValueHolder::PushGraphFrame(main_node, "Main"); - global_data.LoweringAndSplitRtStreams(stream_num); - } - bg::GraphFrame *root_frame; - std::unique_ptr init_frame; - std::unique_ptr de_init_frame; -}; - -TEST_F(FastLoweringGlobalDataUT, SetGetCompileResultOk) { - LoweringGlobalData gd; - - auto node = BuildTestNode(); - ASSERT_NE(node, nullptr); - - EXPECT_EQ(gd.FindCompiledResult(node), nullptr); - - gd.AddCompiledResult(node, {}); - ASSERT_NE(gd.FindCompiledResult(node), nullptr); - EXPECT_TRUE(gd.FindCompiledResult(node)->GetTaskDefs().empty()); -} - -TEST_F(FastLoweringGlobalDataUT, SetGetKnownSubgraphModel) { - LoweringGlobalData gd; - - std::string graph_name = "graph"; - - EXPECT_EQ(gd.GetGraphStaticCompiledModel(graph_name), nullptr); - - gd.AddStaticCompiledGraphModel(graph_name, reinterpret_cast(0x123)); - EXPECT_EQ(gd.GetGraphStaticCompiledModel(graph_name), reinterpret_cast(0x123)); -} - -TEST_F(FastLoweringGlobalDataUT, GetOrCreateAllocatorOk) { - LoweringGlobalData gd; - InitTestFramesWithStream(gd); - auto allocator1 = gd.GetOrCreateL1Allocator({kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}); - ASSERT_NE(allocator1, nullptr); - EXPECT_EQ(allocator1, gd.GetOrCreateL1Allocator({kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput})); -} - -TEST_F(FastLoweringGlobalDataUT, GetOrCreateL1Allocator_InitRootCreateSync1) { - InitTestFrames(); - LoweringGlobalData gd; - auto holder = gd.GetOrCreateL1Allocator({kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}); - ASSERT_NE(holder, nullptr); - - ASSERT_EQ(gd.GetL1Allocator({kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}), holder); - - std::vector on_init; - std::vector on_root; - auto ret = bg::FrameSelector::OnInitRoot( - [&]() -> std::vector { - auto allocator = gd.GetL1Allocator({kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}); - return {allocator}; - }, - on_init, on_root); - ASSERT_EQ(ret, ge::GRAPH_SUCCESS); - ASSERT_EQ(on_init.size(), 1U); - ASSERT_EQ(on_root.size(), 1U); - ASSERT_NE(on_init[0], nullptr); - ASSERT_NE(on_root[0], nullptr); -} -TEST_F(FastLoweringGlobalDataUT, GetOrCreateL1Allocator_InitRootCreateSync2) { - InitTestFrames(); - LoweringGlobalData gd; - std::vector on_init; - std::vector on_root; - auto ret = bg::FrameSelector::OnInitRoot( - [&]() -> std::vector { - return {gd.GetOrCreateL1Allocator({kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput})}; - }, - on_init, on_root); - ASSERT_EQ(ret, ge::GRAPH_SUCCESS); - ASSERT_NE(on_init[0], nullptr); - - ASSERT_NE(gd.GetL1Allocator({kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}), nullptr); -} -TEST_F(FastLoweringGlobalDataUT, GetOrCreateL1Allocator_CreateSelectAllocator_MainExternalAllocatorSet) { - LoweringGlobalData gd; - InitTestFramesWithStream(gd); - gd.SetExternalAllocator(bg::ValueHolder::CreateFeed(-2)); - - auto allocator1 = gd.GetOrCreateL1Allocator({kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}); - ASSERT_NE(allocator1, nullptr); - EXPECT_EQ(allocator1->GetFastNode()->GetType(), "SelectL1Allocator"); - EXPECT_EQ(FastNodeTopoChecker(allocator1).StrictConnectFrom( - {{"InnerData"}, {"Data"}, {"InnerData"}, {"SplitRtStreams"}}), - "success"); - auto create_allocator_node = - ge::ExecuteGraphUtils::FindFirstNodeMatchType(init_frame->GetExecuteGraph().get(), "CreateL1Allocator"); - ASSERT_NE(create_allocator_node, nullptr); - ConnectFromInitToMain(create_allocator_node, 0, allocator1->GetFastNode(), 2); - - bg::ValueHolderPtr init_allocator = nullptr; - bg::FrameSelector::OnInitRoot([&]() -> std::vector { - init_allocator = gd.GetOrCreateL1Allocator({kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}); - return {}; - }); - ASSERT_NE(init_allocator, nullptr); - EXPECT_EQ(init_allocator->GetFastNode()->GetType(), "CreateL1Allocator"); - EXPECT_EQ(FastNodeTopoChecker(init_allocator).StrictConnectFrom({{"Const"}}), "success"); -} - -TEST_F(FastLoweringGlobalDataUT, GetOrCreateL1Allocator_MainExternalAllocatorSet_P2pNotUseExternal) { - LoweringGlobalData gd; - InitTestFramesWithStream(gd); - gd.SetExternalAllocator(bg::ValueHolder::CreateFeed(-2)); - - auto allocator1 = gd.GetOrCreateL1Allocator({kOnDeviceP2p, AllocatorUsage::kAllocNodeOutput}); - ASSERT_NE(allocator1, nullptr); - EXPECT_EQ(allocator1->GetFastNode()->GetType(), "Init"); - - bg::ValueHolderPtr init_allocator = nullptr; - bg::FrameSelector::OnInitRoot([&]() -> std::vector { - init_allocator = gd.GetOrCreateL1Allocator({kOnDeviceP2p, AllocatorUsage::kAllocNodeOutput}); - return {}; - }); - ASSERT_NE(init_allocator, nullptr); - EXPECT_EQ(init_allocator->GetFastNode()->GetType(), "CreateL1Allocator"); - EXPECT_EQ(FastNodeTopoChecker(init_allocator).StrictConnectFrom({{"Const"}}), "success"); -} - -TEST_F(FastLoweringGlobalDataUT, GetOrCreateL1Allocator_CreateSelectAllocator_ExternalAllocatorSet) { - LoweringGlobalData gd; - InitTestFramesWithStream(gd); - gd.SetExternalAllocator(bg::ValueHolder::CreateFeed(-2)); - bg::FrameSelector::OnInitRoot([&]() -> std::vector { - gd.SetExternalAllocator(bg::ValueHolder::CreateFeed(-2), ExecuteGraphType::kInit); - return {}; - }); - - auto allocator1 = gd.GetOrCreateL1Allocator({kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}); - ASSERT_NE(allocator1, nullptr); - EXPECT_EQ(allocator1->GetFastNode()->GetType(), "SelectL1Allocator"); - EXPECT_EQ(FastNodeTopoChecker(allocator1).StrictConnectFrom( - {{"InnerData"}, {"Data"}, {"InnerData"}, {"SplitRtStreams"}}), - "success"); - auto create_allocator_node = - ge::ExecuteGraphUtils::FindFirstNodeMatchType(init_frame->GetExecuteGraph().get(), "CreateL1Allocator"); - ASSERT_NE(create_allocator_node, nullptr); - ConnectFromInitToMain(create_allocator_node, 0, allocator1->GetFastNode(), 2); - - bg::ValueHolderPtr init_allocator = nullptr; - bg::FrameSelector::OnInitRoot([&]() -> std::vector { - init_allocator = gd.GetOrCreateL1Allocator({kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}); - return {}; - }); - ASSERT_NE(init_allocator, nullptr); - EXPECT_EQ(init_allocator->GetFastNode()->GetType(), "SelectL1Allocator"); - EXPECT_EQ(FastNodeTopoChecker(init_allocator).StrictConnectFrom( - {{"Const"}, {"Data"}, {"CreateL1Allocator"}, {"SplitRtStreams"}}), - "success"); -} - -TEST_F(FastLoweringGlobalDataUT, GetOrCreateL1Allocator_ExternalAllocatorSet_P2pNotUseExternal) { - LoweringGlobalData gd; - InitTestFramesWithStream(gd); - gd.SetExternalAllocator(bg::ValueHolder::CreateFeed(-2)); - bg::FrameSelector::OnInitRoot([&]() -> std::vector { - gd.SetExternalAllocator(bg::ValueHolder::CreateFeed(-2), ExecuteGraphType::kInit); - return {}; - }); - - auto allocator1 = gd.GetOrCreateL1Allocator({kOnDeviceP2p, AllocatorUsage::kAllocNodeOutput}); - ASSERT_NE(allocator1, nullptr); - EXPECT_EQ(allocator1->GetFastNode()->GetType(), "Init"); - - bg::ValueHolderPtr init_allocator = nullptr; - bg::FrameSelector::OnInitRoot([&]() -> std::vector { - init_allocator = gd.GetOrCreateL1Allocator({kOnDeviceP2p, AllocatorUsage::kAllocNodeOutput}); - return {}; - }); - ASSERT_NE(init_allocator, nullptr); - EXPECT_EQ(init_allocator->GetFastNode()->GetType(), "CreateL1Allocator"); - EXPECT_EQ(FastNodeTopoChecker(init_allocator).StrictConnectFrom( - {{"Const"}}), "success"); -} - -TEST_F(FastLoweringGlobalDataUT, GetOrCreateL1Allocator_ExternalAllocatorSet_UseAlwaysExternalAllocatorOption) { - LoweringGlobalData gd; - InitTestFramesWithStream(gd); - gd.SetExternalAllocator(bg::ValueHolder::CreateFeed(-2)); - LoweringOption opt; - opt.always_external_allocator = true; - gd.SetLoweringOption(opt); - bg::FrameSelector::OnInitRoot([&]() -> std::vector { - gd.SetExternalAllocator(bg::ValueHolder::CreateFeed(-2), ExecuteGraphType::kInit); - return {}; - }); - - auto allocator1 = gd.GetOrCreateL1Allocator({kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}); - ASSERT_NE(allocator1, nullptr); - EXPECT_EQ(allocator1->GetFastNode()->GetType(), "SelectL1Allocator"); - - auto create_allocator_node = - ge::ExecuteGraphUtils::FindFirstNodeMatchType(init_frame->GetExecuteGraph().get(), "CreateL1Allocator"); - // 外置allocator后,图中就不存在CreateAllocator节点了 - ASSERT_EQ(create_allocator_node, nullptr); - - auto get_allocator_node = - ge::ExecuteGraphUtils::FindFirstNodeMatchType(init_frame->GetExecuteGraph().get(), "GetExternalL1Allocator"); - // 外置allocator后,init - ASSERT_NE(get_allocator_node, nullptr); - - bg::ValueHolderPtr init_allocator = nullptr; - bg::FrameSelector::OnInitRoot([&]() -> std::vector { - init_allocator = gd.GetOrCreateL1Allocator({kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}); - return {}; - }); - ASSERT_NE(init_allocator, nullptr); - EXPECT_EQ(init_allocator->GetFastNode()->GetType(), "GetExternalL1Allocator"); - EXPECT_EQ(FastNodeTopoChecker(init_allocator).StrictConnectFrom( - {{"Const"}, {"Data"}}), - "success"); -} - -TEST_F(FastLoweringGlobalDataUT, GetOrCreateL1Allocator_ExternalAllocatorSet_UseAlwaysExternalAllocatorOption_P2pNotUseExternal) { - LoweringGlobalData gd; - InitTestFramesWithStream(gd); - gd.SetExternalAllocator(bg::ValueHolder::CreateFeed(-2)); - LoweringOption opt; - opt.always_external_allocator = true; - gd.SetLoweringOption(opt); - bg::FrameSelector::OnInitRoot([&]() -> std::vector { - gd.SetExternalAllocator(bg::ValueHolder::CreateFeed(-2), ExecuteGraphType::kInit); - return {}; - }); - - auto allocator1 = gd.GetOrCreateL1Allocator({kOnDeviceP2p, AllocatorUsage::kAllocNodeOutput}); - ASSERT_NE(allocator1, nullptr); - EXPECT_EQ(allocator1->GetFastNode()->GetType(), "Init"); - - auto create_allocator_node = - ge::ExecuteGraphUtils::FindFirstNodeMatchType(init_frame->GetExecuteGraph().get(), "CreateL1Allocator"); - ASSERT_NE(create_allocator_node, nullptr); - - bg::ValueHolderPtr init_allocator = nullptr; - bg::FrameSelector::OnInitRoot([&]() -> std::vector { - init_allocator = gd.GetOrCreateL1Allocator({kOnDeviceP2p, AllocatorUsage::kAllocNodeOutput}); - return {}; - }); - ASSERT_NE(init_allocator, nullptr); - EXPECT_EQ(init_allocator->GetFastNode()->GetType(), "CreateL1Allocator"); - EXPECT_EQ(FastNodeTopoChecker(init_allocator).StrictConnectFrom( - {{"Const"}}), "success"); -} - -TEST_F(FastLoweringGlobalDataUT, GetOrCreateL1Allocator_AlwaysReturnOnRootFrame_CallInSubgraph) { - InitTestFrames(); - LoweringGlobalData gd; - - auto data0 = bg::ValueHolder::CreateFeed(0); - auto foo1 = bg::ValueHolder::CreateSingleDataOutput("Foo", {data0}); - - bg::ValueHolder::PushGraphFrame(foo1, "FooGraph"); - auto allocator1 = gd.GetOrCreateL1Allocator({kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}); - ASSERT_NE(allocator1, nullptr); - ASSERT_NE(bg::ValueHolder::PopGraphFrame(), nullptr); - - ASSERT_EQ(allocator1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), root_frame->GetExecuteGraph().get()); -} - -TEST_F(FastLoweringGlobalDataUT, GetOrCreateL1Allocator_AlwaysCreateOnInitFrame_CallInSubgraph) { - InitTestFrames(); - LoweringGlobalData gd; - - auto data0 = bg::ValueHolder::CreateFeed(0); - auto foo1 = bg::ValueHolder::CreateSingleDataOutput("Foo", {data0}); - - bg::ValueHolder::PushGraphFrame(foo1, "FooGraph"); - auto allocator1 = gd.GetOrCreateL1Allocator({kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}); - ASSERT_NE(allocator1, nullptr); - - ASSERT_NE(bg::ValueHolder::PopGraphFrame(), nullptr); - - ASSERT_EQ(ExeGraphSummaryChecker(init_frame->GetExecuteGraph().get()) - .StrictAllNodeTypes({{"CreateL1Allocator", 1}, {"Const", 1}, {"InnerNetOutput", 1}}), - "success"); -} -TEST_F(FastLoweringGlobalDataUT, GetOrCreateL1Allocator_ReturnOnInit_WhenGetOnInit) { - InitTestFrames(); - LoweringGlobalData gd; - - auto data0 = bg::ValueHolder::CreateFeed(0); - auto foo1 = bg::ValueHolder::CreateSingleDataOutput("Foo", {data0}); - - bg::ValueHolder::PushGraphFrame(foo1, "FooGraph"); - auto allocator1 = gd.GetOrCreateL1Allocator({kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}); - ASSERT_NE(allocator1, nullptr); - ASSERT_EQ(allocator1->GetFastNode()->GetType(), "Init"); - - std::vector graph_out; - std::vector node_out; - auto ret = bg::FrameSelector::OnInitRoot( - [&]() -> std::vector { - return {gd.GetOrCreateL1Allocator({kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput})}; - }, - graph_out, node_out); - ASSERT_EQ(ret, ge::GRAPH_SUCCESS); - ASSERT_EQ(graph_out.size(), 1); - auto init_node = graph_out[0]->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr()->GetParentNodeBarePtr(); - ASSERT_NE(init_node, nullptr); - ASSERT_EQ(init_node->GetType(), "Init"); -} -TEST_F(FastLoweringGlobalDataUT, GetOrCreateUniqueValueHolderOk) { - LoweringGlobalData gd; - auto builder = [&]() -> bg::ValueHolderPtr { - auto resource_holder = bg::FrameSelector::OnMainRoot([&]() -> std::vector { - std::string name = "aicpu_resource"; - auto name_holder = bg::ValueHolder::CreateConst(name.c_str(), name.size(), true); - auto create_container_holder = bg::ValueHolder::CreateSingleDataOutput("CreateStepContainer", {name_holder}); - bg::ValueHolder::CreateVoidGuarder("DestroyStepContainer", create_container_holder, {}); - return {create_container_holder}; - }); - return resource_holder[0]; - }; - auto holder_0 = gd.GetOrCreateUniqueValueHolder("aicpu_container_0", builder); - EXPECT_NE(holder_0, nullptr); - - auto clear_builder = [&]() -> bg::ValueHolderPtr { - return bg::ValueHolder::CreateVoid("ClearStepContainer", {holder_0}); - }; - auto clear_holder = bg::FrameSelector::OnMainRootLast(clear_builder); - EXPECT_NE(clear_holder, nullptr); - std::string create_resource_name = holder_0->GetFastNode()->GetOpDescBarePtr()->GetName(); - EXPECT_EQ(create_resource_name.find("CreateStepContainer"), 0); - - auto last_exec_nodes = bg::ValueHolder::GetLastExecNodes(); - EXPECT_EQ(last_exec_nodes.size(), 1); - EXPECT_NE(last_exec_nodes[0], nullptr); - std::string clear_resource_name = last_exec_nodes[0]->GetFastNode()->GetOpDescBarePtr()->GetName(); - EXPECT_EQ(clear_resource_name.find("ClearStepContainer"), 0); - - // use same key: aicpu_container_0, check unique - auto holder_1 = gd.GetOrCreateUniqueValueHolder("aicpu_container_0", builder); - EXPECT_EQ(last_exec_nodes.size(), 1); - last_exec_nodes.clear(); -} - -TEST_F(FastLoweringGlobalDataUT, OnMainRootLastOk) { - LoweringGlobalData gd; - uint64_t global_container_id = 0; - auto builder = [&]() -> bg::ValueHolderPtr { - uint64_t container_id = global_container_id++; - auto container_id_holder = bg::ValueHolder::CreateConst(&container_id, sizeof(uint64_t)); - uint64_t session_id = 0; - auto session_id_holder = bg::ValueHolder::CreateConst(&session_id, sizeof(uint64_t)); - auto resource_holder = bg::FrameSelector::OnMainRoot([&]() -> std::vector { - auto create_session_holder = bg::ValueHolder::CreateSingleDataOutput("CreateSession", {session_id_holder}); - bg::ValueHolder::CreateVoidGuarder("DestroySession", create_session_holder, {}); - auto clear_builder = [&]() -> bg::ValueHolderPtr { - return bg::ValueHolder::CreateVoid("ClearStepContainer", {session_id_holder, container_id_holder}); - }; - auto clear_holder = bg::FrameSelector::OnMainRootLast(clear_builder); - EXPECT_NE(clear_holder, nullptr); - return {container_id_holder}; - }); - return resource_holder[0]; - }; - auto holder_0 = gd.GetOrCreateUniqueValueHolder("aicpu_container_0", builder); - EXPECT_NE(holder_0, nullptr); - - auto last_exec_nodes = bg::ValueHolder::GetLastExecNodes(); - EXPECT_EQ(last_exec_nodes.size(), 1); - EXPECT_NE(last_exec_nodes[0], nullptr); - std::string clear_resource_name = last_exec_nodes[0]->GetFastNode()->GetOpDescBarePtr()->GetName(); - EXPECT_EQ(clear_resource_name.find("ClearStepContainer"), 0); - - // use same key: aicpu_container_0, check unique - auto holder_1 = gd.GetOrCreateUniqueValueHolder("aicpu_container_0", builder); - EXPECT_EQ(last_exec_nodes.size(), 1); - last_exec_nodes.clear(); -} - -TEST_F(FastLoweringGlobalDataUT, SinkWeightInfoTest) { - LoweringGlobalData gd; - size_t weight_info = 1; - gd.SetModelWeightSize(weight_info); - auto result = gd.GetModelWeightSize(); - EXPECT_EQ(result, weight_info); -} - -TEST_F(FastLoweringGlobalDataUT, GetValueHolersSizeTest) { - LoweringGlobalData gd; - gd.SetValueHolders("test1", nullptr); - EXPECT_EQ(gd.GetValueHoldersSize("test1"), 1); - EXPECT_EQ(gd.GetValueHoldersSize("test2"), 0); - gd.SetValueHolders("test1", nullptr); - EXPECT_EQ(gd.GetValueHoldersSize("test1"), 2); - - gd.SetUniqueValueHolder("test3", nullptr); - EXPECT_EQ(gd.GetValueHoldersSize("test3"), 1); - gd.SetUniqueValueHolder("test3", nullptr); - EXPECT_EQ(gd.GetValueHoldersSize("test3"), 1); -} - -TEST_F(FastLoweringGlobalDataUT, SetGetUniqueValueHoler) { - LoweringGlobalData gd; - gd.SetUniqueValueHolder("test1", nullptr); - EXPECT_EQ(gd.GetValueHoldersSize("test1"), 1); - EXPECT_EQ(gd.GetValueHoldersSize("test2"), 0); - EXPECT_EQ(gd.GetUniqueValueHolder("test1"), nullptr); - - gd.SetUniqueValueHolder("test1", bg::ValueHolder::CreateVoid("TEST", {})); - EXPECT_EQ(gd.GetValueHoldersSize("test1"), 1); - EXPECT_NE(gd.GetUniqueValueHolder("test1"), nullptr); -} - -TEST_F(FastLoweringGlobalDataUT, StaticModelWsSizeTest) { - LoweringGlobalData gd; - int64_t require_ws_size = 1; - gd.SetStaicModelWsSize(require_ws_size); - auto result = gd.GetStaticModelWsSize(); - EXPECT_EQ(result, require_ws_size); -} - -TEST_F(FastLoweringGlobalDataUT, FixedFeatureMemoryBaseTest) { - LoweringGlobalData gd; - gd.SetFixedFeatureMemoryBase((void *)0x355, 4); - auto fixed_feature_mem = gd.GetFixedFeatureMemoryBase(); - EXPECT_EQ((size_t)fixed_feature_mem.first, 0x355); - EXPECT_EQ(fixed_feature_mem.second, 4); -} - -/* - * init_graph: - * +--------------------------+ - * | | | - * Const(placement) Const(stream_num) Const(placement) Const(usage) | - * \ / \ / | - * CreateL2Allocator CreateL1Allocator | - * \ / | - * \ / | - * InnerNetOutput <--------------------------+ - * - * - * - * main_graph: - * * InnerData(l1 allocators) - * Data(rt_streams) - * | / - * SplitRtStreams / InnerData(l2 allocators) - * \ / / - * SelectL2Allocator - * - * 测试无外置allocator场景下的L2 allocator - */ -TEST_F(FastLoweringGlobalDataUT, GetOrCreateL2AllocatorInMain_Device_WithoutExternalAllocator) { - LoweringGlobalData global_data; - InitTestFramesWithStream(global_data, 3); - - auto l2_allocator = - global_data.GetOrCreateL2Allocator(0, {TensorPlacement::kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}); - EXPECT_NE(l2_allocator, nullptr); - - auto get_l2_allocator = - global_data.GetMainL2Allocator(0, {TensorPlacement::kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}); - EXPECT_NE(get_l2_allocator, nullptr); - EXPECT_EQ(l2_allocator->GetFastNode(), get_l2_allocator->GetFastNode()); - - auto init_exe_graph = init_frame->GetExecuteGraph().get(); - EXPECT_EQ(ExeGraphSummaryChecker(init_exe_graph) - .StrictDirectNodeTypes(std::map{{"Data", 1}, - {"SplitRtStreams", 1}, - {"Const",4}, - {"CreateL1Allocator", 1}, - {"CreateL2Allocators", 1}, - {"InnerNetOutput", 1}}), - "success"); - - auto main_frame = bg::ValueHolder::PopGraphFrame(); - auto main_exe_graph = main_frame->GetExecuteGraph().get(); - ge::DumpGraph(main_exe_graph->GetParentGraphBarePtr(), "TestL2Allocator"); - EXPECT_EQ(ExeGraphSummaryChecker(main_exe_graph) - .StrictDirectNodeTypes(std::map{ - {"Const", 2}, {"Data", 1}, {"InnerData", 2}, {"SplitRtStreams", 1}, {"SelectL2Allocator", 1}}), - "success"); - FastNodeTopoChecker checker(l2_allocator); - // Const(logic_stream_id), GetStreamById, InnerData(L1 allocator), InnerData(L2 allocators) - EXPECT_EQ(checker.StrictConnectFrom(std::vector({{"Const"}, {"SplitRtStreams"}, {"InnerData"}, {"InnerData"}})), "success"); -} - -TEST_F(FastLoweringGlobalDataUT, GetOrCreateL2AllocatorInMain_FollowingPlacement_GetHostAllocator) { - LoweringGlobalData global_data; - InitTestFramesWithStream(global_data, 3); - - auto l2_allocator = - global_data.GetOrCreateL2Allocator(0, {TensorPlacement::kFollowing, AllocatorUsage::kAllocNodeOutput}); - EXPECT_NE(l2_allocator, nullptr); - - auto get_l2_allocator = - global_data.GetMainL2Allocator(0, {TensorPlacement::kFollowing, AllocatorUsage::kAllocNodeOutput}); - EXPECT_NE(get_l2_allocator, nullptr); - EXPECT_EQ(l2_allocator->GetFastNode(), get_l2_allocator->GetFastNode()); - auto get_l2_allocator_init = - global_data.GetInitL2Allocator({TensorPlacement::kFollowing, AllocatorUsage::kAllocNodeOutput}); - EXPECT_NE(get_l2_allocator_init, nullptr); - - auto init_out = bg::FrameSelector::OnInitRoot([&]()-> std::vector { - return {global_data.GetOrCreateL2Allocator(0, {TensorPlacement::kFollowing, AllocatorUsage::kAllocNodeOutput})}; - }); - EXPECT_EQ(init_out.size(), 1); - EXPECT_NE(init_out[0], nullptr); - auto host_allocator_init = bg::HolderOnInit(init_out[0]); - EXPECT_EQ(host_allocator_init->GetFastNode(), get_l2_allocator_init->GetFastNode()); - EXPECT_EQ(host_allocator_init->GetFastNode()->GetType(), "CreateHostL2Allocator"); -} - -TEST_F(FastLoweringGlobalDataUT, GetOrCreateL2AllocatorOnInit_UnsupportedPlacement_Failed) { - LoweringGlobalData global_data; - InitTestFramesWithStream(global_data, 3); - - auto init_out = bg::FrameSelector::OnInitRoot([&]()-> std::vector { - return {global_data.GetOrCreateL2Allocator(0, {TensorPlacement::kTensorPlacementEnd, AllocatorUsage::kAllocNodeOutput})}; - }); - EXPECT_EQ(init_out.size(), 0); - - auto get_l2_allocator_init = - global_data.GetInitL2Allocator({TensorPlacement::kTensorPlacementEnd, AllocatorUsage::kAllocNodeOutput}); - EXPECT_EQ(get_l2_allocator_init, nullptr); -} -/* - * init_graph: - * - * - * Const(placement) - * / | - * | CreateL1Allocator - * | / - * \ / - * InnerNetOutput - * - * - * - * main_graph: - * * InnerData(l1 allocators) - * | - * CreateHostL2Allocator - * - * - * 测试无外置allocator场景下的host L2 allocator - */ -TEST_F(FastLoweringGlobalDataUT, GetOrCreateL2AllocatorInMain_Host_WithoutExternalAllocator) { - LoweringGlobalData global_data; - InitTestFramesWithStream(global_data, 3); - - auto l2_allocator = - global_data.GetOrCreateL2Allocator(0, {TensorPlacement::kOnHost, AllocatorUsage::kAllocNodeOutput}); - EXPECT_NE(l2_allocator, nullptr); - - auto get_l2_allocator = - global_data.GetMainL2Allocator(0, {TensorPlacement::kOnHost, AllocatorUsage::kAllocNodeOutput}); - EXPECT_NE(get_l2_allocator, nullptr); - EXPECT_EQ(l2_allocator->GetFastNode(), get_l2_allocator->GetFastNode()); - auto get_l2_allocator_init = - global_data.GetInitL2Allocator({TensorPlacement::kOnHost, AllocatorUsage::kAllocNodeOutput}); - EXPECT_NE(get_l2_allocator_init, nullptr); - - auto main_frame = bg::ValueHolder::PopGraphFrame(); - auto init_exe_graph = init_frame->GetExecuteGraph().get(); - EXPECT_EQ(ExeGraphSummaryChecker(init_exe_graph) - .StrictDirectNodeTypes(std::map{{"Data", 1}, - {"SplitRtStreams", 1}, - {"Const", 3}, - {"CreateL1Allocator", 1}, - {"InnerNetOutput", 1}, - {"CreateHostL2Allocator", 1}}), - "success"); - - auto main_exe_graph = main_frame->GetExecuteGraph().get(); - ge::DumpGraph(main_exe_graph->GetParentGraphBarePtr(), "TestHostL2Allocator"); - EXPECT_EQ(ExeGraphSummaryChecker(main_exe_graph) - .StrictDirectNodeTypes(std::map{{"Const", 1}, // stream num - {"Data", 1}, // stream - {"SplitRtStreams", 1}, - {"InnerData", 1}, - {"CreateHostL2Allocator", 1}}), - "success"); - FastNodeTopoChecker checker(l2_allocator); - // InnerData(L1 allocator) - EXPECT_EQ(checker.StrictConnectFrom(std::vector({{"InnerData"}})), "success"); -} -/* - * init_graph: - * +--------------------------+ - * | | | - * Const(placement) Const(stream_num) Const(placement) Const(usage) | - * \ / \ / | Data(exteranl_allocator) - * CreateL2Allocator CreateL1Allocator ----------+-----+ | Data(external_stream) - * \ / | \ | / - * \ / | SelectAllocator - * InnerNetOutput <-------------------------+ - * - * - * main_graph: - * SelectL1Allocator - * Data(rt_streams) / - * | / - * GetStreamById / InnerData(l2 allocators) - * \ / / - * SelectL2Allocator - * - * 测试有外置allocator场景下的L2 allocator - */ -TEST_F(FastLoweringGlobalDataUT, GetOrCreateAllocatorInMain_Device_WithExternalAllocator) { - LoweringGlobalData global_data; - InitTestFramesWithStream(global_data, 3); - - // prepare external allocator on init and main - bg::FrameSelector::OnInitRoot([&]()-> std::vector { - auto external_allocator_init = bg::ValueHolder::CreateFeed(-2); - global_data.SetExternalAllocator(static_cast(external_allocator_init), ExecuteGraphType::kInit); - return {}; - }); - auto external_allocator = bg::ValueHolder::CreateFeed(-2); - global_data.SetExternalAllocator(static_cast(external_allocator), ExecuteGraphType::kMain); - - auto l2_allocator = - global_data.GetOrCreateL2Allocator(0, {TensorPlacement::kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}); - EXPECT_NE(l2_allocator, nullptr); - - auto get_l2_allocator = - global_data.GetMainL2Allocator(0, {TensorPlacement::kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}); - EXPECT_NE(get_l2_allocator, nullptr); - EXPECT_EQ(l2_allocator->GetFastNode(), get_l2_allocator->GetFastNode()); - - auto main_frame = bg::ValueHolder::PopGraphFrame(); - auto init_exe_graph = init_frame->GetExecuteGraph().get(); - ge::DumpGraph(init_exe_graph, "l2_allocator_init"); - EXPECT_EQ(ExeGraphSummaryChecker(init_exe_graph) - .StrictDirectNodeTypes(std::map{{"Const", 4}, - {"Data", 2}, - {"SplitRtStreams", 1}, - {"CreateL1Allocator", 1}, - {"CreateL2Allocators", 1}, - {"SelectL1Allocator", 1}, - {"InnerNetOutput", 1}}), - "success"); - - auto main_exe_graph = main_frame->GetExecuteGraph().get(); - // ge::DumpGraph(main_exe_graph->GetParentGraphBarePtr(), "TestDeviceL2AllocatorWithExternalL1Allocator"); - EXPECT_EQ(ExeGraphSummaryChecker(main_exe_graph) - .StrictDirectNodeTypes(std::map{{"Const", 2}, - {"Data", 2}, - {"InnerData", 3}, - {"SplitRtStreams", 1}, - {"SelectL1Allocator", 1}, - {"SelectL2Allocator", 1}}), - "success"); - FastNodeTopoChecker checker(l2_allocator); - // Const(logic_stream_id), GetStreamById, SelectL1Allocator(L1 allocator), InnerData(L2 allocators) - EXPECT_EQ(checker.StrictConnectFrom(std::vector({{"Const"}, {"SplitRtStreams"}, {"SelectL1Allocator"}, {"InnerData"}})), "success"); -} - -/* - * init_graph: - * +--------------------------+ - * | | | - * Const(placement) Const(stream_num) Const(placement) Const(usage) | - * \ / \ / | Data(exteranl_allocator) - * CreateL2Allocator CreateL1Allocator ----------+-----+ | Data(external_stream) - * \ / | \ | / - * \ / | SelectAllocator - * InnerNetOutput <-------------------------+ - * - * - * main_graph: - * SelectL1Allocator - * Const(stream_id) Data(rt_streams) / - * \ / / - * GetStreamById / InnerData(l2 allocators) - * \ / / - * SelectL2Allocator - * / \ - * consumer00 consumer01 - * 测试多次调用有外置allocator场景下的L2 allocator - * 预期结果:同一条流上多次调用 select l2 allocator只生成1个kerenel - */ -TEST_F(FastLoweringGlobalDataUT, GetOrCreateAllocatorInMain_WithExternalAllocator_CallMultiTimes) { - LoweringGlobalData global_data; - InitTestFramesWithStream(global_data, 3); - - // prepare external allocator on init and main - bg::FrameSelector::OnInitRoot([&]()-> std::vector { - auto external_allocator_init = bg::ValueHolder::CreateFeed(-2); - global_data.SetExternalAllocator(static_cast(external_allocator_init), ExecuteGraphType::kInit); - return {}; - }); - auto external_allocator = bg::ValueHolder::CreateFeed(-2); - global_data.SetExternalAllocator(static_cast(external_allocator), ExecuteGraphType::kMain); - - // prepare rtStreams - - auto l2_allocator_00 = - global_data.GetOrCreateL2Allocator(0, {TensorPlacement::kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}); - EXPECT_NE(l2_allocator_00, nullptr); - auto consumer00 = bg::ValueHolder::CreateVoid("consumer00", {l2_allocator_00}); - - auto l2_allocator_01 = - global_data.GetOrCreateL2Allocator(0, {TensorPlacement::kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}); - EXPECT_NE(l2_allocator_01, nullptr); - EXPECT_EQ(l2_allocator_00, l2_allocator_01); - auto consumer01 = bg::ValueHolder::CreateVoid("consumer01", {l2_allocator_01}); - - auto l2_allocator_10 = - global_data.GetOrCreateL2Allocator(1, {TensorPlacement::kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}); - EXPECT_NE(l2_allocator_01, nullptr); - auto consumer10 = bg::ValueHolder::CreateVoid("consumer10", {l2_allocator_10}); - - auto main_frame = bg::ValueHolder::PopGraphFrame(); - auto init_exe_graph = init_frame->GetExecuteGraph().get(); - // ge::DumpGraph(init_exe_graph, "l2_allocator_init"); - EXPECT_EQ(ExeGraphSummaryChecker(init_exe_graph) - .StrictDirectNodeTypes(std::map{{"Const", 4}, - {"Data", 2}, - {"SplitRtStreams", 1}, - {"CreateL1Allocator", 1}, - {"CreateL2Allocators", 1}, - {"SelectL1Allocator", 1}, - {"InnerNetOutput", 1}}), - "success"); - - auto main_exe_graph = main_frame->GetExecuteGraph().get(); - // ge::DumpGraph(main_exe_graph, "l2_allocator_main"); - EXPECT_EQ(ExeGraphSummaryChecker(main_exe_graph) - .StrictDirectNodeTypes(std::map{{"Const", 3}, - {"Data", 2}, - {"InnerData", 3}, - {"SplitRtStreams", 1}, - {"SelectL1Allocator", 1}, - {"SelectL2Allocator", 2}, - {"consumer00", 1}, - {"consumer01", 1}, - {"consumer10", 1}}), - "success"); - FastNodeTopoChecker checker(l2_allocator_00); - // Const(logic_stream_id), GetStreamById, SelectL1Allocator(L1 allocator), InnerData(L2 allocators) - EXPECT_EQ(checker.StrictConnectFrom(std::vector({{"Const"}, {"SplitRtStreams"}, {"SelectL1Allocator"}, {"InnerData"}})), "success"); - EXPECT_EQ(checker.StrictConnectTo(0, std::vector({{"consumer00"},{"consumer01"}})), "success"); -} - -/* - * init_graph: - * +--------------------------+ - * | | | - * Const(placement) Const(stream_num) Const(placement) Const(usage) | - * \ / \ / | Data(exteranl_allocator) - * CreateL2Allocator CreateL1Allocator ----------+-----+ | Data(external_stream) - * \ / | \ | / - * \ / | SelectAllocator - * InnerNetOutput <-------------------------+ - * - * - * main_graph: - * SelectL1Allocator - * Const(stream_id) Data(rt_streams) / - * \ / / - * GetStreamById / InnerData(l2 allocators) - * \ / / ......... - * SelectL2Allocator CreateHostL2Allocator - * / \ / \ - * consumer00 consumer02 consumer01 consumer03 - * - * 测试多次调用有外置allocator场景下的L2 allocator - * 预期结果:同一条流上多次调用 select l2 allocator,但是placement不同,生成2个kerenel - */ -TEST_F(FastLoweringGlobalDataUT, GetOrCreateAllocatorInMain_WithExternalAllocator_CallMultiTimesWithDiffPlacement) { - LoweringGlobalData global_data; - InitTestFramesWithStream(global_data, 3); - - // prepare external allocator on init and main - bg::FrameSelector::OnInitRoot([&global_data]()-> std::vector { - auto external_allocator_init = bg::ValueHolder::CreateFeed(-2); - global_data.SetExternalAllocator(static_cast(external_allocator_init), ExecuteGraphType::kInit); - return {}; - }); - auto external_allocator = bg::ValueHolder::CreateFeed(-2); - global_data.SetExternalAllocator(static_cast(external_allocator), ExecuteGraphType::kMain); - - auto l2_allocator_00 = - global_data.GetOrCreateL2Allocator(0, {TensorPlacement::kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}); - EXPECT_NE(l2_allocator_00, nullptr); - EXPECT_EQ(l2_allocator_00->GetFastNode()->GetType(), "SelectL2Allocator"); - auto consumer00 = bg::ValueHolder::CreateVoid("consumer00", {l2_allocator_00}); - - auto l2_allocator_01 = - global_data.GetOrCreateL2Allocator(0, {TensorPlacement::kOnHost, AllocatorUsage::kAllocNodeOutput}); - EXPECT_NE(l2_allocator_01, nullptr); - EXPECT_EQ(l2_allocator_01->GetFastNode()->GetType(), "CreateHostL2Allocator"); - EXPECT_NE(l2_allocator_00, l2_allocator_01); - auto consumer01 = bg::ValueHolder::CreateVoid("consumer01", {l2_allocator_01}); - - auto l2_allocator_02 = - global_data.GetOrCreateL2Allocator(0, {TensorPlacement::kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}); - EXPECT_NE(l2_allocator_02, nullptr); - EXPECT_EQ(l2_allocator_02->GetFastNode()->GetType(), "SelectL2Allocator"); - auto consumer02 = bg::ValueHolder::CreateVoid("consumer02", {l2_allocator_02}); - - auto l2_allocator_03 = - global_data.GetOrCreateL2Allocator(0, {TensorPlacement::kOnHost, AllocatorUsage::kAllocNodeOutput}); - EXPECT_NE(l2_allocator_03, nullptr); - EXPECT_EQ(l2_allocator_03->GetFastNode()->GetType(), "CreateHostL2Allocator"); - auto consumer03 = bg::ValueHolder::CreateVoid("consumer03", {l2_allocator_03}); - - auto main_frame = bg::ValueHolder::PopGraphFrame(); - auto init_exe_graph = init_frame->GetExecuteGraph().get(); - // ge::DumpGraph(init_exe_graph, "l2_allocator_init"); - EXPECT_EQ(ExeGraphSummaryChecker(init_exe_graph) - .StrictDirectNodeTypes(std::map{{"Const", 5}, - {"Data", 2}, - {"SplitRtStreams", 1}, - {"CreateL1Allocator", 2}, - {"CreateL2Allocators", 1}, - {"SelectL1Allocator", 2}, - {"InnerNetOutput", 1}, - {"CreateHostL2Allocator", 1}}), - "success"); - - auto main_exe_graph = main_frame->GetExecuteGraph().get(); - ge::DumpGraph(main_exe_graph->GetParentGraphBarePtr(), "CallMultiTimesWithDiffPlacement"); - EXPECT_EQ(ExeGraphSummaryChecker(main_exe_graph) - .StrictDirectNodeTypes(std::map{{"Const", 2}, - {"Data", 2}, - {"InnerData", 5}, - {"SplitRtStreams", 1}, - {"SelectL1Allocator", 2}, - {"SelectL2Allocator", 1}, - {"CreateHostL2Allocator", 1}, - {"consumer00", 1}, - {"consumer01", 1}, - {"consumer02", 1}, - {"consumer03", 1}}), - "success"); - FastNodeTopoChecker checker(l2_allocator_00); - // Const(logic_stream_id), GetStreamById, SelectL1Allocator(L1 allocator), InnerData(L2 allocators) - EXPECT_EQ(checker.StrictConnectFrom(std::vector({{"Const"}, {"SplitRtStreams"}, {"SelectL1Allocator"}, {"InnerData"}})), "success"); -} - -/* - * Data Const Const(out) - * \ / | - * SplitRtStreams CreateL1Allocator(out) Const - * \ / / - * SelectL1Allocator------> CreateL2Allocators(out) - * (out) \ / - * CreateInitL2Allocator (out) - * / \ - * / consumer - * \ / - * InnerNetOutput - */ -TEST_F(FastLoweringGlobalDataUT, GetOrCreateInitL2AllocatorOnInit_Device) { - LoweringGlobalData global_data; - InitTestFramesWithStream(global_data, 3); - - // prepare external allocator on init and main - auto init_out1 = bg::FrameSelector::OnInitRoot([&global_data]()-> std::vector { - auto external_allocator_init = bg::ValueHolder::CreateFeed(-2); - global_data.SetExternalAllocator(static_cast(external_allocator_init), ExecuteGraphType::kInit); - - AllocatorDesc desc = {kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}; - auto init_l2_allocator = global_data.GetOrCreateL2Allocator(0, desc); - EXPECT_NE(init_l2_allocator, nullptr); - auto consumer = bg::ValueHolder::CreateSingleDataOutput("consumer", {init_l2_allocator}); - return {init_l2_allocator, consumer}; - }); - EXPECT_EQ(init_out1.size(), 2); - - auto device_l2_allocator_init = bg::HolderOnInit(init_out1[0]); - EXPECT_NE(device_l2_allocator_init, nullptr); - auto get_device_l2_allocator_init = - global_data.GetInitL2Allocator({kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}); - EXPECT_NE(device_l2_allocator_init, nullptr); - EXPECT_EQ(device_l2_allocator_init->GetFastNode(), get_device_l2_allocator_init->GetFastNode()); - - auto init_exe_graph = init_frame->GetExecuteGraph().get(); - ge::DumpGraph(init_exe_graph, "l2_allocator_init"); - EXPECT_EQ(ExeGraphSummaryChecker(init_exe_graph) - .StrictDirectNodeTypes(std::map{{"Const", 4}, - {"Data", 2}, - {"SplitRtStreams", 1}, - {"CreateL1Allocator", 1}, - {"CreateL2Allocators", 1}, - {"CreateInitL2Allocator", 1}, - {"SelectL1Allocator", 1}, - {"consumer", 1}, - {"InnerNetOutput", 1}}), - "success"); - auto consumer = HolderOnInit(init_out1[1]); - FastNodeTopoChecker checker(consumer); - EXPECT_EQ(checker.StrictConnectFrom(std::vector({{"CreateInitL2Allocator"}})), "success"); -} - -/* - * Data Const Const(out) - * \ / | - * SplitRtStreams CreateL1Allocator(out) Const - * \ / / - * SelectL1Allocator------> CreateL2Allocators(out) - * (out) \ / - * CreateHostL2Allocator (out) - * / - * / - * InnerNetOutput - */ - -TEST_F(FastLoweringGlobalDataUT, GetOrCreateInitL2AllocatorOnInit_Host) { - LoweringGlobalData global_data; - InitTestFramesWithStream(global_data, 3); - - // prepare external allocator on init and main - auto init_out = bg::FrameSelector::OnInitRoot([&global_data]() -> std::vector { - auto external_allocator_init = bg::ValueHolder::CreateFeed(-2); - global_data.SetExternalAllocator(static_cast(external_allocator_init), ExecuteGraphType::kInit); - - AllocatorDesc desc = {kOnHost, AllocatorUsage::kAllocNodeOutput}; - auto init_l2_allocator = global_data.GetOrCreateL2Allocator(0, desc); - EXPECT_NE(init_l2_allocator, nullptr); - auto consumer = bg::ValueHolder::CreateSingleDataOutput("consumer", {init_l2_allocator}); - return {init_l2_allocator, consumer}; - }); - EXPECT_EQ(init_out.size(), 2); - - auto host_l2_allocator_init = bg::HolderOnInit(init_out[0]); - EXPECT_NE(host_l2_allocator_init, nullptr); - auto get_host_l2_allocator_init = - global_data.GetInitL2Allocator({kOnHost, AllocatorUsage::kAllocNodeOutput}); - EXPECT_NE(get_host_l2_allocator_init, nullptr); - EXPECT_EQ(host_l2_allocator_init->GetFastNode(), get_host_l2_allocator_init->GetFastNode()); - - auto init_exe_graph = init_frame->GetExecuteGraph().get(); - ge::DumpGraph(init_exe_graph, "l2_allocator_init"); - EXPECT_EQ(ExeGraphSummaryChecker(init_exe_graph) - .StrictDirectNodeTypes(std::map{{"Const", 3}, - {"Data", 2}, - {"SplitRtStreams", 1}, - {"CreateL1Allocator", 1}, - {"SelectL1Allocator", 1}, - {"CreateHostL2Allocator", 1}, - {"consumer", 1}, - {"InnerNetOutput", 1}}), - "success"); - - auto consumer = HolderOnInit(init_out[1]); - FastNodeTopoChecker checker(consumer); - EXPECT_EQ(checker.StrictConnectFrom(std::vector({{"CreateHostL2Allocator"}})), "success"); - - // get host l2 allocator in another init lowering - auto init_out1 = bg::FrameSelector::OnInitRoot([&global_data]() -> std::vector { - AllocatorDesc desc = {kOnHost, AllocatorUsage::kAllocNodeOutput}; - auto init_l2_allocator = global_data.GetOrCreateL2Allocator(0, desc); - EXPECT_NE(init_l2_allocator, nullptr); - auto consumer1 = bg::ValueHolder::CreateSingleDataOutput("consumer1", {init_l2_allocator}); - return {init_l2_allocator, consumer1}; - }); - EXPECT_FALSE(init_out1.empty()); - - auto consumer1 = HolderOnInit(init_out[1]); - FastNodeTopoChecker checker1(consumer1); - EXPECT_EQ(checker.StrictConnectFrom(std::vector({{"CreateHostL2Allocator"}})), "success"); - EXPECT_EQ(ExeGraphSummaryChecker(init_exe_graph) - .StrictDirectNodeTypes(std::map{{"Const", 3}, - {"Data", 2}, - {"SplitRtStreams", 1}, - {"CreateL1Allocator", 1}, - {"SelectL1Allocator", 1}, - {"CreateHostL2Allocator", 1}, - {"consumer", 1}, - {"consumer1", 1}, - {"InnerNetOutput", 1}}), - "success"); -} - -TEST_F(FastLoweringGlobalDataUT, GetOrCreateAllL2Allocators_success) { - LoweringGlobalData global_data; - InitTestFramesWithStream(global_data, 3); - - global_data.LoweringAndSplitRtStreams(3); - auto l2_allocator = global_data.GetOrCreateL2Allocator(1, {kOnDeviceHbm, AllocatorUsage::kAllocNodeOutput}); - EXPECT_NE(l2_allocator, nullptr); - auto all_l2_allocators = global_data.GetOrCreateAllL2Allocators(); - EXPECT_NE(all_l2_allocators, nullptr); - EXPECT_EQ(all_l2_allocators->GetFastNode()->GetType(), "Init"); - auto consumer = bg::ValueHolder::CreateSingleDataOutput("consumer", {all_l2_allocators}); - - auto init_exe_graph = init_frame->GetExecuteGraph().get(); - ge::DumpGraph(init_exe_graph, "l2_allocator_init"); - EXPECT_EQ(ExeGraphSummaryChecker(init_exe_graph) - .StrictDirectNodeTypes(std::map{{"Const", 4}, - {"Data", 1}, - {"SplitRtStreams", 1}, - {"CreateL1Allocator", 1}, - {"CreateL2Allocators", 1}, - {"InnerNetOutput", 1}}), - "success"); - - auto all_l2_allocators_in_init = HolderOnInit(all_l2_allocators); - EXPECT_EQ(all_l2_allocators_in_init->GetFastNode()->GetType(), "CreateL2Allocators"); - - FastNodeTopoChecker checker(consumer); - EXPECT_EQ(checker.StrictConnectFrom(std::vector({{"InnerData", 0}})), "success"); -} -/* - * main_graph: - * Data(rt_streams) - * | - * SplitRtStreams - * | - * consumer - * - */ -TEST_F(FastLoweringGlobalDataUT, GetStreamById_Main_Once) { - InitTestFrames(); - LoweringGlobalData global_data; - - // prepare rtStreams - auto all_rt_streams = global_data.LoweringAndSplitRtStreams(1); - EXPECT_EQ(all_rt_streams.size(), 1); - - auto rt_stream = global_data.GetStreamById(0); - EXPECT_NE(rt_stream, nullptr); - auto consumer = bg::ValueHolder::CreateVoid("consumer", {rt_stream}); - - auto main_frame = bg::ValueHolder::PopGraphFrame(); - auto main_exe_graph = main_frame->GetExecuteGraph().get(); - EXPECT_EQ(ExeGraphSummaryChecker(main_exe_graph) - .StrictDirectNodeTypes(std::map{ - {"Data", 1}, {"Const", 1}, {"SplitRtStreams", 1}, {"consumer", 1}}), - "success"); - FastNodeTopoChecker checker(rt_stream); - // Const(logic_stream_id), Data(rt_streams) - EXPECT_EQ(checker.StrictConnectFrom(std::vector({{"Data"}, {"Const"}})), "success"); - FastNodeTopoChecker consumer_checker(consumer); - EXPECT_EQ(consumer_checker.StrictConnectFrom(std::vector({{"SplitRtStreams", 0}})), "success"); -} -/* - * main_graph: - * Data(rt_streams) - * | - * SplitRtStreams - * / \ - * consumer0 consumer1 - * - */ -TEST_F(FastLoweringGlobalDataUT, GetStreamById_Main_SameStreamCallTwice) { - InitTestFrames(); - LoweringGlobalData global_data; - - // prepare rtStreams - auto all_rt_streams = global_data.LoweringAndSplitRtStreams(1); - EXPECT_EQ(all_rt_streams.size(), 1); - EXPECT_TRUE(global_data.IsSingleStreamScene()); - - auto rt_stream = global_data.GetStreamById(0); - EXPECT_NE(rt_stream, nullptr); - auto consumer = bg::ValueHolder::CreateVoid("consumer0", {rt_stream}); - - auto rt_stream1 = global_data.GetStreamById(0); - EXPECT_NE(rt_stream1, nullptr); - auto consumer1 = bg::ValueHolder::CreateVoid("consumer1", {rt_stream1}); - - auto main_frame = bg::ValueHolder::PopGraphFrame(); - auto main_exe_graph = main_frame->GetExecuteGraph().get(); - EXPECT_EQ(ExeGraphSummaryChecker(main_exe_graph) - .StrictDirectNodeTypes(std::map{ - {"Data", 1},{"Const", 1}, {"SplitRtStreams", 1}, {"consumer0", 1}, {"consumer1", 1}}), - "success"); - FastNodeTopoChecker checker(rt_stream); - // Const(logic_stream_id), Data(rt_streams) - EXPECT_EQ(checker.StrictConnectFrom(std::vector({{"Data"}, {"Const"}})), "success"); - EXPECT_EQ(checker.StrictConnectTo(0, std::vector({{"consumer0"}, {"consumer1"}})), "success"); -} - -/* - * main_graph: - * Data(rt_streams) - * | - * SplitRtStreams - * / \ - * consumer0 consumer1 - * - */ -TEST_F(FastLoweringGlobalDataUT, GetStreamById_Main_DiffStreamCallTwice) { - InitTestFrames(); - LoweringGlobalData global_data; - - // prepare rtStreams - auto all_rt_streams = global_data.LoweringAndSplitRtStreams(2); - EXPECT_EQ(all_rt_streams.size(), 2); - EXPECT_TRUE(!global_data.IsSingleStreamScene()); - - auto rt_stream = global_data.GetStreamById(0); - EXPECT_NE(rt_stream, nullptr); - auto consumer = bg::ValueHolder::CreateVoid("consumer0", {rt_stream}); - - auto rt_stream1 = global_data.GetStreamById(1); - EXPECT_NE(rt_stream1, nullptr); - auto consumer1 = bg::ValueHolder::CreateVoid("consumer1", {rt_stream1}); - - auto main_frame = bg::ValueHolder::PopGraphFrame(); - auto main_exe_graph = main_frame->GetExecuteGraph().get(); - EXPECT_EQ(ExeGraphSummaryChecker(main_exe_graph) - .StrictDirectNodeTypes(std::map{ - {"Data", 1},{"Const", 1}, {"SplitRtStreams", 1}, {"consumer0", 1}, {"consumer1", 1}}), - "success"); - FastNodeTopoChecker checker(rt_stream); - // Const(logic_stream_id), Data(rt_streams) - EXPECT_EQ(checker.StrictConnectFrom(std::vector({{"Data"}, {"Const"}})), "success"); - EXPECT_EQ(checker.StrictConnectTo(0, std::vector({{"consumer0"}})), "success"); - EXPECT_EQ(checker.StrictConnectTo(1, std::vector({{"consumer1"}})), "success"); -} - -TEST_F(FastLoweringGlobalDataUT, GetStreamById_Main_WithOutSplitRtStreams) { - InitTestFrames(); - LoweringGlobalData global_data; - - auto rt_stream = global_data.GetStreamById(0); - EXPECT_EQ(rt_stream, nullptr); -} - -/* - * init_graph: - * Data(rt_streams) - * | - * SplitRtStreams - * | - * consumer - * - */ -TEST_F(FastLoweringGlobalDataUT, GetStreamById_Init_Once) { - InitTestFrames(); - LoweringGlobalData global_data; - auto consumer_and_stream = bg::FrameSelector::OnInitRoot([&global_data]() -> std::vector { - // prepare rtStreams - auto all_rt_streams = global_data.LoweringAndSplitRtStreams(1); - EXPECT_EQ(all_rt_streams.size(), 1); - - auto rt_stream = global_data.GetStreamById(0); - EXPECT_NE(rt_stream, nullptr); - auto consumer = bg::ValueHolder::CreateSingleDataOutput("consumer", {rt_stream}); - return {consumer, rt_stream}; - }); - EXPECT_EQ(consumer_and_stream.size(), 2); - - auto init_exe_graph = init_frame->GetExecuteGraph().get(); - EXPECT_EQ(ExeGraphSummaryChecker(init_exe_graph) - .StrictDirectNodeTypes(std::map{ - {"Data", 1}, {"Const", 1}, {"SplitRtStreams", 1}, {"consumer", 1}, {"InnerNetOutput", 1}}), - "success"); - auto split_rt_streams = HolderOnInit(consumer_and_stream[1]); - FastNodeTopoChecker checker(split_rt_streams); - // Const(logic_stream_id), Data(rt_streams) - EXPECT_EQ(checker.StrictConnectFrom(std::vector({{"Data"}, {"Const"}})), "success"); - auto consumer = HolderOnInit(consumer_and_stream[0]); - FastNodeTopoChecker consumer_checker(consumer); - EXPECT_EQ(consumer_checker.StrictConnectFrom(std::vector({{"SplitRtStreams", 0}})), "success"); -} - -/* - * init_graph: - * Data(rt_streams) - * | - * SplitRtStreams - * / \ - * consumer0 consumer1 - * - */ -TEST_F(FastLoweringGlobalDataUT, GetStreamById_Init_SameStreamCallTwice) { - InitTestFrames(); - LoweringGlobalData global_data; - - auto consumer_and_stream = bg::FrameSelector::OnInitRoot([&global_data]() -> std::vector { - // prepare rtStreams - auto all_rt_streams = global_data.LoweringAndSplitRtStreams(1); - EXPECT_EQ(all_rt_streams.size(), 1); - - auto rt_stream = global_data.GetStreamById(0); - EXPECT_NE(rt_stream, nullptr); - auto consumer = bg::ValueHolder::CreateSingleDataOutput("consumer0", {rt_stream}); - - auto rt_stream1 = global_data.GetStreamById(0); - EXPECT_NE(rt_stream1, nullptr); - auto consumer1 = bg::ValueHolder::CreateSingleDataOutput("consumer1", {rt_stream1}); - return {rt_stream, consumer, consumer1}; - }); - EXPECT_EQ(consumer_and_stream.size(), 3); - - auto init_exe_graph = init_frame->GetExecuteGraph().get(); - EXPECT_EQ(ExeGraphSummaryChecker(init_exe_graph) - .StrictDirectNodeTypes(std::map{{"Data", 1}, - {"Const", 1}, - {"SplitRtStreams", 1}, - {"consumer0", 1}, - {"consumer1", 1}, - {"InnerNetOutput", 1}}), - "success"); - auto rt_stream = HolderOnInit(consumer_and_stream[0]); - FastNodeTopoChecker checker(rt_stream); - // Const(logic_stream_id), Data(rt_streams) - EXPECT_EQ(checker.StrictConnectFrom(std::vector({{"Data"}, {"Const"}})), "success"); - EXPECT_EQ(checker.StrictConnectTo(0, std::vector({{"consumer0"}, {"consumer1"}, {"InnerNetOutput"}})), - "success"); -} - -TEST_F(FastLoweringGlobalDataUT, GetStreamById_Init_WithOutSplitRtStreams) { - InitTestFrames(); - LoweringGlobalData global_data; - auto consumer_and_stream = bg::FrameSelector::OnInitRoot([&global_data]() -> std::vector { - auto rt_stream = global_data.GetStreamById(0); - EXPECT_EQ(rt_stream, nullptr); - return {rt_stream}; - }); - EXPECT_EQ(consumer_and_stream.size(), 0); -} - -TEST_F(FastLoweringGlobalDataUT, GetStreamById_Init_StreamIdOutOfRange) { - InitTestFrames(); - LoweringGlobalData global_data; - auto consumer_and_stream = bg::FrameSelector::OnInitRoot([&global_data]() -> std::vector { - // prepare rtStreams - auto all_rt_streams = global_data.LoweringAndSplitRtStreams(1); - EXPECT_EQ(all_rt_streams.size(), 1); - - auto rt_stream = global_data.GetStreamById(2); // stream id out of range - EXPECT_EQ(rt_stream, nullptr); - return {rt_stream}; - }); - EXPECT_EQ(consumer_and_stream.size(), 0); -} - -TEST_F(FastLoweringGlobalDataUT, GetStreamById_Init_StreamNumOutOfRange) { - InitTestFrames(); - LoweringGlobalData global_data; - auto consumer_and_stream = bg::FrameSelector::OnInitRoot([&global_data]() -> std::vector { - // prepare rtStreams - auto all_rt_streams = global_data.LoweringAndSplitRtStreams(2); // stream num out of range - EXPECT_EQ(all_rt_streams.size(), 0); - - auto rt_stream = global_data.GetStreamById(0); - EXPECT_EQ(rt_stream, nullptr); - return {rt_stream}; - }); - EXPECT_EQ(consumer_and_stream.size(), 0); -} - -TEST_F(FastLoweringGlobalDataUT, GetNotifyById_Main) { - InitTestFrames(); - LoweringGlobalData global_data; - auto notify_0 = bg::ValueHolder::CreateFeed(0); - auto notify_1 = bg::ValueHolder::CreateFeed(1); - std::vector notifies{notify_0, notify_1}; - (void) bg::FrameSelector::OnMainRoot([&global_data, ¬ifies]() -> std::vector { - global_data.SetRtNotifies(notifies); - - auto rt_notify0 = global_data.GetNotifyById(0); - EXPECT_NE(rt_notify0, nullptr); - auto rt_notify1 = global_data.GetNotifyById(1); - EXPECT_NE(rt_notify1, nullptr); - auto rt_notify2 = global_data.GetNotifyById(2); - EXPECT_EQ(rt_notify2, nullptr); - return {}; - }); - - (void) bg::FrameSelector::OnInitRoot([&global_data]() -> std::vector { - auto rt_notify0 = global_data.GetNotifyById(0); - EXPECT_EQ(rt_notify0, nullptr); - return {}; - }); -} - -TEST_F(FastLoweringGlobalDataUT, GetNotifyById_Init) { - InitTestFrames(); - LoweringGlobalData global_data; - auto notify_0 = bg::ValueHolder::CreateFeed(0); - auto notify_1 = bg::ValueHolder::CreateFeed(1); - std::vector notifies{notify_0, notify_1}; - (void) bg::FrameSelector::OnInitRoot([&global_data, ¬ifies]() -> std::vector { - global_data.SetRtNotifies(notifies); - - auto rt_notify0 = global_data.GetNotifyById(0); - EXPECT_NE(rt_notify0, nullptr); - auto rt_notify1 = global_data.GetNotifyById(1); - EXPECT_NE(rt_notify1, nullptr); - auto rt_notify2 = global_data.GetNotifyById(2); - EXPECT_EQ(rt_notify2, nullptr); - return {}; - }); -} -} // namespace gert diff --git a/tests/ut/exe_graph/fast_value_holder_unittest.cc b/tests/ut/exe_graph/fast_value_holder_unittest.cc deleted file mode 100644 index b8d9db541e00d99f00e4d37fb7cbc7baf7352f98..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/fast_value_holder_unittest.cc +++ /dev/null @@ -1,1883 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/lowering/value_holder.h" -#include "exe_graph/lowering/value_holder_utils.h" -#include "exe_graph/runtime/context_extend.h" -#include -#include -#include -#include "exe_graph/lowering/exe_graph_attrs.h" -#include "checker/bg_test.h" -#include "graph/utils/graph_utils.h" -#include "checker/topo_checker.h" -#include "checker/summary_checker.h" -#include "exe_graph/lowering/extend_exe_graph.h" -#include "graph/fast_graph/execute_graph.h" -#include "graph/utils/execute_graph_utils.h" - -namespace gert { -namespace bg { -namespace { -ge::NodePtr FakeNode() { - static size_t counter = 0; - static ge::ComputeGraphPtr graph = std::make_shared("graph"); - auto op_desc = std::make_shared("FakeNode_" + std::to_string(counter++), "FakeNode"); - return graph->AddNode(op_desc); -} -size_t GetComputeNodeIndex(const ge::FastNode *node) { - int64_t index; - if (!ge::AttrUtils::GetInt(node->GetOpDescBarePtr(), kComputeNodeIndex, index)) { - return std::numeric_limits::max(); - } - return static_cast(index); -} -} -class FastValueHolderUt : public BgTest { - public: - ge::ExecuteGraph *FindFirstSubgraphForNodeType(const ge::ExecuteGraph *root_graph, - const std::string &node_type) { - for (const auto &subgraph : root_graph->GetAllSubgraphs()) { - auto parent_node = subgraph->GetParentNodeBarePtr(); - if (parent_node->GetType() == node_type) { - return subgraph; - } - } - return nullptr; - } - ge::FastNode *FindData(const ge::ExecuteGraph *graph, int32_t index) { - for (const auto &node : graph->GetDirectNode()) { - if (node->GetType() != "Data") { - continue; - } - int32_t data_index; - if (!ge::AttrUtils::GetInt(node->GetOpDescBarePtr(), "index", data_index)) { - continue; - } - if (data_index != index) { - continue; - } - return node; - } - return nullptr; - } - - void ConnectFromInnerData(const ge::FastNode *node, const std::vector &indexes) { - ASSERT_EQ(node->GetInDataNodes().size(), indexes.size()); - size_t i = 0; - for (const auto &in_node : node->GetInDataNodes()) { - ASSERT_EQ(in_node->GetType(), "InnerData"); - int32_t index; - ASSERT_TRUE(ge::AttrUtils::GetInt(in_node->GetOpDescBarePtr(), "index", index)); - ASSERT_EQ(index, indexes[i++]); - } - } - void ConnectFromOuter(ge::FastNode *node, int32_t dst_index, const ge::FastNode *outer, int32_t src_index) { - while (true) { - auto edge = node->GetInDataEdgeByIndex(dst_index); - ASSERT_NE(edge, nullptr); - auto src_node = edge->src; - ASSERT_NE(src_node, nullptr); - if (src_node == outer) { - return; - } - if (src_node->GetType() == "InnerData" || src_node->GetType() == "Data") { - int32_t parent_index; - ASSERT_TRUE(ge::AttrUtils::GetInt(src_node->GetOpDescBarePtr(), "index", parent_index)); - auto parent_graph = src_node->GetExtendInfo()->GetOwnerGraphBarePtr(); - ASSERT_NE(parent_graph, nullptr); - auto parent_node = parent_graph->GetParentNodeBarePtr(); - ASSERT_NE(parent_node, nullptr); - node = const_cast(parent_node); - dst_index = parent_index; - } - } - } - void StrictSubgraphs(const ge::FastNode *node, const std::vector &names) { - const auto &subgraph_names_to_index = node->GetOpDescBarePtr()->GetSubgraphNameIndexes(); - - ASSERT_EQ(subgraph_names_to_index.size(), names.size()); - for (const auto &name : names) { - auto iter = subgraph_names_to_index.find(name); - ASSERT_NE(iter, subgraph_names_to_index.end()); - auto ins_name = node->GetOpDescBarePtr()->GetSubgraphInstanceName(iter->second); - auto root_graph = ge::ExecuteGraphUtils::FindRootGraph(node->GetExtendInfo()->GetOwnerGraphBarePtr()); - ASSERT_NE(root_graph->GetSubGraph(ins_name), nullptr); - } - } -}; - -TEST_F(FastValueHolderUt, CreateConstOk) { - ge::Format f1 = ge::FORMAT_NC1HWC0; - auto c = ValueHolder::CreateConst(reinterpret_cast(&f1), sizeof(f1)); - EXPECT_NE(c, nullptr); - ASSERT_TRUE(c->IsOk()); - ASSERT_NE(c->GetFastNode(), nullptr); - EXPECT_EQ(c->GetType(), ValueHolder::ValueHolderType::kConst); - EXPECT_EQ(c->GetOutIndex(), 0); - auto node = c->GetFastNode(); - EXPECT_EQ(node->GetType(), "Const"); - EXPECT_EQ(node->GetDataOutNum(), 1); - EXPECT_EQ(node->GetDataInNum(), 0); - ge::Buffer buffer; - ASSERT_TRUE(ge::AttrUtils::GetZeroCopyBytes(node->GetOpDescBarePtr(), "value", buffer)); - ASSERT_EQ(buffer.GetSize(), sizeof(ge::FORMAT_NC1HWC0)); - EXPECT_EQ(*reinterpret_cast(buffer.GetData()), ge::FORMAT_NC1HWC0); -} - -TEST_F(FastValueHolderUt, CreateInnerOk) { - auto inner_data_holder = bg::ValueHolder::CreateSingleDataOutput("InnerData", {}); - EXPECT_NE(inner_data_holder, nullptr); - ASSERT_TRUE(inner_data_holder->IsOk()); - ASSERT_NE(inner_data_holder->GetFastNode(), nullptr); - EXPECT_EQ(inner_data_holder->GetType(), ValueHolder::ValueHolderType::kOutput); - EXPECT_EQ(inner_data_holder->GetOutIndex(), 0); - auto node = inner_data_holder->GetFastNode(); - EXPECT_EQ(node->GetType(), "InnerData"); - EXPECT_EQ(node->GetDataOutNum(), 1); - EXPECT_EQ(node->GetDataInNum(), 0); - EXPECT_EQ(inner_data_holder->AddInnerDataToKVMap(0).IsSuccess(), true); - ge::FastNode *data = nullptr; - bool ret = FindValFromMapExtAttr(node->GetExtendInfo()->GetOwnerGraphBarePtr(), - "_inner_data_nodes", 0, data); - EXPECT_EQ(ret, true); - EXPECT_EQ(data, node); -} - -TEST_F(FastValueHolderUt, CreateInnerFailed) { - auto inner_data_holder = bg::ValueHolder::CreateSingleDataOutput("InnerData1", {}); - EXPECT_NE(inner_data_holder, nullptr); - ASSERT_TRUE(inner_data_holder->IsOk()); - ASSERT_NE(inner_data_holder->GetFastNode(), nullptr); - EXPECT_EQ(inner_data_holder->AddInnerDataToKVMap(0).IsSuccess(), false); -} - -TEST_F(FastValueHolderUt, CreateVectorConstOk) { - int64_t shape[5] = {32, 1, 224, 224, 16}; - auto c = ValueHolder::CreateConst(reinterpret_cast(shape), sizeof(shape)); - EXPECT_NE(c, nullptr); - ASSERT_TRUE(c->IsOk()); - ASSERT_NE(c->GetFastNode(), nullptr); - EXPECT_EQ(c->GetType(), ValueHolder::ValueHolderType::kConst); - EXPECT_EQ(c->GetOutIndex(), 0); - auto node = c->GetFastNode(); - EXPECT_EQ(node->GetType(), "Const"); - ge::Buffer buffer; - ASSERT_TRUE(ge::AttrUtils::GetZeroCopyBytes(node->GetOpDescBarePtr(), "value", buffer)); - ASSERT_EQ(buffer.GetSize(), sizeof(shape)); - EXPECT_EQ(memcmp(buffer.GetData(), shape, sizeof(shape)), 0); -} -TEST_F(FastValueHolderUt, CreateFeedOk) { - auto c = ValueHolder::CreateFeed(1); - EXPECT_NE(c, nullptr); - ASSERT_TRUE(c->IsOk()); - ASSERT_NE(c->GetFastNode(), nullptr); - EXPECT_EQ(c->GetType(), ValueHolder::ValueHolderType::kFeed); - EXPECT_EQ(c->GetOutIndex(), 0); - auto node = c->GetFastNode(); - EXPECT_EQ(node->GetType(), "Data"); - EXPECT_EQ(node->GetDataOutNum(), 1); - EXPECT_EQ(node->GetDataInNum(), 0); - int32_t index; - ASSERT_TRUE(ge::AttrUtils::GetInt(node->GetOpDescBarePtr(), "index", index)); - EXPECT_EQ(index, 1); -} -TEST_F(FastValueHolderUt, CreateErrorOk) { - auto holder = ValueHolder::CreateError("This is a test error information, int %d, %s", 10240, "Test msg"); - ASSERT_NE(holder, nullptr); - EXPECT_FALSE(holder->IsOk()); -} -TEST_F(FastValueHolderUt, CreateDataOutOk) { - ge::Format f1 = ge::FORMAT_NC1HWC0; - auto const1 = ValueHolder::CreateConst(reinterpret_cast(&f1), sizeof(f1)); - auto data1 = ValueHolder::CreateFeed(0); - ASSERT_NE(const1, nullptr); - ASSERT_NE(data1, nullptr); - - std::vector inputs = {data1, const1}; - auto holders = ValueHolder::CreateDataOutput("TestNode", inputs, 3); - - ASSERT_EQ(holders.size(), 3); - ASSERT_TRUE(holders[0]->IsOk()); - ASSERT_TRUE(holders[1]->IsOk()); - ASSERT_TRUE(holders[2]->IsOk()); - EXPECT_EQ(holders[0]->GetType(), ValueHolder::ValueHolderType::kOutput); - EXPECT_EQ(holders[1]->GetType(), ValueHolder::ValueHolderType::kOutput); - EXPECT_EQ(holders[2]->GetType(), ValueHolder::ValueHolderType::kOutput); - - ASSERT_NE(const1->GetExecuteGraph(), nullptr); - ASSERT_NE(const1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - ASSERT_NE(data1->GetExecuteGraph(), nullptr); - ASSERT_NE(data1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - ASSERT_NE(holders[0]->GetFastNode(), nullptr); - ASSERT_NE(holders[0]->GetExecuteGraph(), nullptr); - ASSERT_NE(holders[0]->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - ASSERT_NE(holders[1]->GetFastNode(), nullptr); - ASSERT_NE(holders[1]->GetExecuteGraph(), nullptr); - ASSERT_NE(holders[1]->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - ASSERT_NE(holders[2]->GetFastNode(), nullptr); - ASSERT_NE(holders[2]->GetExecuteGraph(), nullptr); - ASSERT_NE(holders[2]->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - - // check node is ok - auto node = holders[0]->GetFastNode(); - ASSERT_EQ(node->GetType(), "TestNode"); - ASSERT_EQ(node->GetDataInNum(), 2); - ASSERT_EQ(node->GetDataOutNum(), 3); - - // all holders point to the same node - ASSERT_EQ(holders[0]->GetFastNode(), holders[1]->GetFastNode()); - ASSERT_EQ(holders[0]->GetFastNode(), holders[2]->GetFastNode()); - - // all holders have correct out-index - EXPECT_EQ(holders[0]->GetOutIndex(), 0); - EXPECT_EQ(holders[1]->GetOutIndex(), 1); - EXPECT_EQ(holders[2]->GetOutIndex(), 2); - - // all nodes(contains data and const) in the same graph - EXPECT_EQ(holders[0]->GetExecuteGraph(), const1->GetExecuteGraph()); - EXPECT_EQ(holders[0]->GetExecuteGraph(), data1->GetExecuteGraph()); - - // all holders holds the same graph - EXPECT_EQ(holders[0]->GetExecuteGraph(), holders[1]->GetExecuteGraph()); - EXPECT_EQ(holders[0]->GetExecuteGraph(), holders[2]->GetExecuteGraph()); - EXPECT_EQ(holders[0]->GetExecuteGraph(), const1->GetExecuteGraph()); - EXPECT_EQ(holders[0]->GetExecuteGraph(), data1->GetExecuteGraph()); - - // check graph is ok - auto graph = holders[0]->GetExecuteGraph(); - ASSERT_EQ(graph->GetAllNodes().size(), 3); - CheckGraphGenerally(graph); - auto const1_g = ge::ExecuteGraphUtils::FindFirstNodeMatchType(graph, "Const"); - auto data1_g = ge::ExecuteGraphUtils::FindFirstNodeMatchType(graph, "Data"); - auto node_g = ge::ExecuteGraphUtils::FindFirstNodeMatchType(graph, "TestNode"); - ASSERT_NE(const1_g, nullptr); - ASSERT_NE(data1_g, nullptr); - ASSERT_NE(node_g, nullptr); - - EXPECT_EQ(node_g->GetInDataEdgeByIndex(0)->src, data1_g); - EXPECT_EQ(node_g->GetInDataEdgeByIndex(1)->src, const1_g); -} - -/* - * --------> node3 - * / / \ - * node1 node2 const3 - * / \ / \ - * data1 const1 data2 const2 - */ -TEST_F(FastValueHolderUt, CreateDataOutOk2) { - ge::Format fmt = ge::FORMAT_NC1HWC0; - auto const1 = ValueHolder::CreateConst(&fmt, sizeof(fmt)); - auto data1 = ValueHolder::CreateFeed(0); - auto node1 = ValueHolder::CreateSingleDataOutput("Node1", {const1, data1}); - - auto const2 = ValueHolder::CreateConst(&fmt, sizeof(fmt)); - auto data2 = ValueHolder::CreateFeed(0); - auto node2 = ValueHolder::CreateSingleDataOutput("Node1", {const2, data2}); - - auto const3 = ValueHolder::CreateConst(&fmt, sizeof(fmt)); - auto n2_holder = ValueHolder::CreateVoid("Node2", {node1, node2, const3}); - - for (const auto &holder : {const1, data1, node1, const2, data2, node2, const3, n2_holder}) { - ASSERT_NE(holder, nullptr); - ASSERT_TRUE(holder->IsOk()); - ASSERT_NE(holder->GetFastNode(), nullptr); - ASSERT_NE(holder->GetExecuteGraph(), nullptr); - } - EXPECT_EQ(node1->GetFastNode()->GetType(), "Node1"); - EXPECT_EQ(node2->GetFastNode()->GetType(), "Node1"); - EXPECT_EQ(n2_holder->GetFastNode()->GetType(), "Node2"); - - ASSERT_EQ(node1->GetFastNode()->GetDataOutNum(), 1); - ASSERT_EQ(node1->GetFastNode()->GetDataInNum(), 2); - EXPECT_EQ(node1->GetFastNode()->GetInDataEdgeByIndex(0)->src, const1->GetFastNode()); - EXPECT_EQ(node1->GetFastNode()->GetInDataEdgeByIndex(1)->src, data1->GetFastNode()); - - ASSERT_EQ(n2_holder->GetFastNode()->GetDataOutNum(), 0); - ASSERT_EQ(n2_holder->GetFastNode()->GetDataInNum(), 3); - EXPECT_EQ(n2_holder->GetFastNode()->GetInDataEdgeByIndex(0)->src, node1->GetFastNode()); - EXPECT_EQ(n2_holder->GetFastNode()->GetInDataEdgeByIndex(1)->src, node2->GetFastNode()); - EXPECT_EQ(n2_holder->GetFastNode()->GetInDataEdgeByIndex(2)->src, const3->GetFastNode()); -} -TEST_F(FastValueHolderUt, MergeIsolateNodeToGraphOk) { - ge::Format f1 = ge::FORMAT_NC1HWC0; - auto const1 = ValueHolder::CreateConst(reinterpret_cast(&f1), sizeof(f1)); - auto data1 = ValueHolder::CreateFeed(0); - auto node1 = ValueHolder::CreateDataOutput("Node1", {data1, const1}, 2); - ASSERT_NE(const1, nullptr); - ASSERT_NE(data1, nullptr); - ASSERT_EQ(node1.size(), 2); - ASSERT_NE(node1[0], nullptr); - ASSERT_NE(node1[1], nullptr); - - ASSERT_NE(const1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - EXPECT_EQ(data1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), - const1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr()); - EXPECT_EQ(node1[0]->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), - const1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr()); - EXPECT_EQ(node1[1]->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), - const1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr()); - - ASSERT_NE(const1->GetExecuteGraph(), nullptr); - EXPECT_EQ(const1->GetExecuteGraph(), const1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr()); - EXPECT_EQ(data1->GetExecuteGraph(), const1->GetExecuteGraph()); - EXPECT_EQ(node1[0]->GetExecuteGraph(), const1->GetExecuteGraph()); - EXPECT_EQ(node1[1]->GetExecuteGraph(), const1->GetExecuteGraph()); -} - -/* - * - * node3 - * / \ | - * node1 node2 - * / \ / \ - * data1 const1 data2 const2 - */ -TEST_F(FastValueHolderUt, MergeTwoGraphOk1) { - ge::Format f1 = ge::FORMAT_NC1HWC0; - auto const1 = ValueHolder::CreateConst(reinterpret_cast(&f1), sizeof(f1)); - auto data1 = ValueHolder::CreateFeed(0); - auto node1 = ValueHolder::CreateDataOutput("Node1", {data1, const1}, 1); - - auto const2 = ValueHolder::CreateConst(reinterpret_cast(&f1), sizeof(f1)); - auto data2 = ValueHolder::CreateFeed(0); - auto node2 = ValueHolder::CreateDataOutput("Node2", {data2, const2}, 2); - - auto node3 = ValueHolder::CreateSingleDataOutput("Node3", {node1[0], node2[0]}); - ASSERT_NE(node3, nullptr); - - EXPECT_NE(data1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - EXPECT_NE(const1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - EXPECT_NE(node1[0]->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - - EXPECT_NE(data2->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - EXPECT_NE(const2->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - EXPECT_NE(node2[0]->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - EXPECT_NE(node2[1]->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - - EXPECT_NE(node3->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - - ASSERT_NE(const1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - EXPECT_EQ(const1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr()->GetAllNodes().size(), 7); - - EXPECT_EQ(const1->GetExecuteGraph(), data1->GetExecuteGraph()); - EXPECT_EQ(const1->GetExecuteGraph(), node1[0]->GetExecuteGraph()); - EXPECT_EQ(const1->GetExecuteGraph(), data2->GetExecuteGraph()); - EXPECT_EQ(const1->GetExecuteGraph(), const2->GetExecuteGraph()); - EXPECT_EQ(const1->GetExecuteGraph(), node2[0]->GetExecuteGraph()); - EXPECT_EQ(const1->GetExecuteGraph(), node2[1]->GetExecuteGraph()); - EXPECT_EQ(const1->GetExecuteGraph(), node3->GetExecuteGraph()); -} -/* - * - * node4 - * / \ - * node3 \ - * / \ | - * node1 node2 - * / \ / \ - * data1 const1 data2 const2 - */ -TEST_F(FastValueHolderUt, MergeTwoGraphOk) { - ge::Format f1 = ge::FORMAT_NC1HWC0; - auto const1 = ValueHolder::CreateConst(reinterpret_cast(&f1), sizeof(f1)); - auto data1 = ValueHolder::CreateFeed(0); - auto node1 = ValueHolder::CreateDataOutput("Node1", {data1, const1}, 1); - ASSERT_NE(const1, nullptr); - ASSERT_NE(data1, nullptr); - ASSERT_EQ(node1.size(), 1); - ASSERT_NE(node1[0], nullptr); - ASSERT_NE(const1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - - auto const2 = ValueHolder::CreateConst(reinterpret_cast(&f1), sizeof(f1)); - auto data2 = ValueHolder::CreateFeed(0); - auto node2 = ValueHolder::CreateDataOutput("Node2", {data2, const2}, 2); - ASSERT_NE(const2, nullptr); - ASSERT_NE(data2, nullptr); - ASSERT_EQ(node2.size(), 2); - ASSERT_NE(node2[0], nullptr); - ASSERT_NE(node2[1], nullptr); - - EXPECT_NE(const2->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - EXPECT_NE(node2[0]->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - EXPECT_NE(node2[1]->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - - auto node3 = ValueHolder::CreateSingleDataOutput("Node3", {node1[0], node2[0]}); - ASSERT_NE(node3, nullptr); - - auto node4 = ValueHolder::CreateVoid("Node4", {node3, node2[1]}); - ASSERT_NE(node4, nullptr); - - EXPECT_NE(data1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - EXPECT_NE(node1[0]->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - EXPECT_NE(data2->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - EXPECT_NE(const2->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - EXPECT_NE(node2[0]->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - EXPECT_NE(node2[1]->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - EXPECT_NE(node3->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - EXPECT_NE(node4->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - - ASSERT_NE(const1->GetFastNode(), nullptr); - ASSERT_NE(const1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - EXPECT_EQ(const1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr()->GetAllNodes().size(), 8); - - EXPECT_EQ(const1->GetExecuteGraph(), data1->GetExecuteGraph()); - EXPECT_EQ(const1->GetExecuteGraph(), node1[0]->GetExecuteGraph()); - EXPECT_EQ(const1->GetExecuteGraph(), data2->GetExecuteGraph()); - EXPECT_EQ(const1->GetExecuteGraph(), const2->GetExecuteGraph()); - EXPECT_EQ(const1->GetExecuteGraph(), node2[0]->GetExecuteGraph()); - EXPECT_EQ(const1->GetExecuteGraph(), node2[1]->GetExecuteGraph()); - EXPECT_EQ(const1->GetExecuteGraph(), node3->GetExecuteGraph()); - EXPECT_EQ(const1->GetExecuteGraph(), node4->GetExecuteGraph()); -} -TEST_F(FastValueHolderUt, CreateVoidOk) { - ge::Format f1 = ge::FORMAT_NC1HWC0; - auto const1 = ValueHolder::CreateConst(reinterpret_cast(&f1), sizeof(f1)); - auto data1 = ValueHolder::CreateFeed(0); - ASSERT_NE(const1, nullptr); - ASSERT_NE(data1, nullptr); - - std::vector inputs = {data1, const1}; - auto holder = ValueHolder::CreateVoid("TestNode", inputs); - - ASSERT_NE(holder, nullptr); - - ASSERT_NE(const1->GetExecuteGraph(), nullptr); - ASSERT_NE(const1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - ASSERT_NE(data1->GetExecuteGraph(), nullptr); - ASSERT_NE(data1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - ASSERT_NE(holder->GetFastNode(), nullptr); - ASSERT_NE(holder->GetExecuteGraph(), nullptr); - ASSERT_NE(holder->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - - // check graph is ok - auto graph = holder->GetExecuteGraph(); - ASSERT_EQ(graph->GetAllNodes().size(), 3); - CheckGraphGenerally(graph); - - auto const1_g = ge::ExecuteGraphUtils::FindFirstNodeMatchType(graph, "Const"); - auto data1_g = ge::ExecuteGraphUtils::FindFirstNodeMatchType(graph, "Data"); - auto node_g = ge::ExecuteGraphUtils::FindFirstNodeMatchType(graph, "TestNode"); - ASSERT_NE(const1_g, nullptr); - ASSERT_NE(data1_g, nullptr); - ASSERT_NE(node_g, nullptr); - - EXPECT_EQ(node_g->GetInDataEdgeByIndex(0)->src, data1_g); - EXPECT_EQ(node_g->GetInDataEdgeByIndex(1)->src, const1_g); -} - -TEST_F(FastValueHolderUt, AddDependencyOk) { - auto data1 = ValueHolder::CreateFeed(0); - auto data2 = ValueHolder::CreateFeed(1); - ValueHolder::AddDependency(data1, data2); - - auto node1 = ValueHolder::CreateSingleDataOutput("Node1", {data1}); - auto node2 = ValueHolder::CreateSingleDataOutput("Node1", {data1}); - ValueHolder::AddDependency(node1, node2); - - ASSERT_NE(data1, nullptr); - ASSERT_NE(data2, nullptr); - ASSERT_NE(node1, nullptr); - ASSERT_NE(node2, nullptr); - - ASSERT_NE(data1->GetFastNode(), nullptr); - ASSERT_NE(data2->GetFastNode(), nullptr); - ASSERT_NE(node1->GetFastNode(), nullptr); - ASSERT_NE(node2->GetFastNode(), nullptr); - - ASSERT_EQ(data1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), - data2->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr()); - ASSERT_EQ(data1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), - node1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr()); - ASSERT_EQ(data1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), - node2->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr()); - - ASSERT_EQ(data1->GetFastNode()->GetOutControlNodes().size(), 1); - ASSERT_EQ(data2->GetFastNode()->GetInControlNodes().size(), 1); - EXPECT_EQ(data1->GetFastNode()->GetOutControlNodes()[0], - data2->GetFastNode()); - - ASSERT_EQ(node1->GetFastNode()->GetOutControlNodes().size(), 1); - ASSERT_EQ(node2->GetFastNode()->GetInControlNodes().size(), 1); - EXPECT_EQ(node1->GetFastNode()->GetOutControlNodes()[0], - node2->GetFastNode()); - - auto data3 = ValueHolder::CreateFeed(2); - ASSERT_NE(data3, nullptr); - ValueHolder::AddDependency(data3, data3); - ASSERT_EQ(data3->GetFastNode()->GetOutControlNodes().size(), 0); - ASSERT_EQ(data3->GetFastNode()->GetInControlNodes().size(), 0); -} - -/* - * KernelLaunch - * | - * Tiling - * / \ - * InferShape CompileInfo - * / \ | - * shape1 shape2 json - */ -TEST_F(FastValueHolderUt, CurrentNodeOk) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("x1", tensor_desc); - op_desc->AppendIrInput("x1", ge::kIrInputRequired); - op_desc->AppendIrInput("x2", ge::kIrInputOptional); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - - auto shape1 = ValueHolder::CreateFeed(0); - auto shape2 = ValueHolder::CreateFeed(1); - auto json1 = ValueHolder::CreateConst("{}", 3); - - ValueHolder::SetCurrentComputeNode(node); - auto frame = ValueHolder::GetCurrentFrame(); - ASSERT_NE(frame, nullptr); - ASSERT_EQ(frame->GetCurrentComputeNode(), node); - auto shape = ValueHolder::CreateSingleDataOutput("InferShape", {shape1, shape2}); - auto compile_info = ValueHolder::CreateSingleDataOutput("TilingParse", {json1}); - auto tiling_ret = ValueHolder::CreateSingleDataOutput("Tiling", {shape, compile_info}); - auto holder = ValueHolder::CreateVoid("KernelLaunch", {tiling_ret}); - - ASSERT_NE(shape1, nullptr); - ASSERT_NE(shape2, nullptr); - ASSERT_NE(json1, nullptr); - ASSERT_NE(shape, nullptr); - ASSERT_NE(compile_info, nullptr); - ASSERT_NE(tiling_ret, nullptr); - ASSERT_NE(holder, nullptr); - - int64_t compute_node_index_none; - ASSERT_FALSE(ge::AttrUtils::GetInt(shape1->GetFastNode()->GetOpDescBarePtr(), "ComputeNodeIndex", - compute_node_index_none)); - ASSERT_FALSE(ge::AttrUtils::GetInt(shape2->GetFastNode()->GetOpDescBarePtr(), "ComputeNodeIndex", - compute_node_index_none)); - ASSERT_FALSE(ge::AttrUtils::GetInt(json1->GetFastNode()->GetOpDescBarePtr(), "ComputeNodeIndex", - compute_node_index_none)); - - int64_t compute_node_index_shape, compute_node_index_compile_ifo, compute_node_index_tiling_ret, - compute_node_index_holder; - ASSERT_TRUE(ge::AttrUtils::GetInt(shape->GetFastNode()->GetOpDescBarePtr(), "ComputeNodeIndex", - compute_node_index_shape)); - ASSERT_TRUE( - ge::AttrUtils::GetInt(compile_info->GetFastNode()->GetOpDescBarePtr(), "ComputeNodeIndex", - compute_node_index_compile_ifo)); - ASSERT_TRUE( - ge::AttrUtils::GetInt(tiling_ret->GetFastNode()->GetOpDescBarePtr(), "ComputeNodeIndex", - compute_node_index_tiling_ret)); - ASSERT_TRUE(ge::AttrUtils::GetInt(holder->GetFastNode()->GetOpDescBarePtr(), "ComputeNodeIndex", - compute_node_index_holder)); - EXPECT_EQ(compute_node_index_shape, compute_node_index_compile_ifo); - EXPECT_EQ(compute_node_index_shape, compute_node_index_tiling_ret); - EXPECT_EQ(compute_node_index_shape, compute_node_index_holder); - - size_t frame_current_node_index; - frame->GetCurrentNodeIndex(frame_current_node_index); - EXPECT_EQ(compute_node_index_shape, frame_current_node_index); -} -/* - * hello - * / \ - * data0 data1 - */ -TEST_F(FastValueHolderUt, CreateExeGraphOk) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("x1", tensor_desc); - op_desc->AppendIrInput("x1", ge::kIrInputRequired); - op_desc->AppendIrInput("x2", ge::kIrInputOptional); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - - auto data0 = ValueHolder::CreateFeed(0); - auto data1 = ValueHolder::CreateFeed(1); - - ValueHolder::SetCurrentComputeNode(node); - auto hello = ValueHolder::CreateSingleDataOutput("hello", {data0, data1}); - - ValueHolder::AddRelevantInputNode(node); - ASSERT_NE(graph, nullptr); -} -/* - * hello - * / \ - * data0 data1 - */ -TEST_F(FastValueHolderUt, CreateExeGraphWithTargetsOk) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("x1", tensor_desc); - op_desc->AppendIrInput("x1", ge::kIrInputRequired); - op_desc->AppendIrInput("x2", ge::kIrInputOptional); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - - auto data0 = ValueHolder::CreateFeed(0); - auto data1 = ValueHolder::CreateFeed(1); - - ValueHolder::SetCurrentComputeNode(node); - auto hello = ValueHolder::CreateVoid("hello", {data0, data1}); - ASSERT_NE(graph, nullptr); -} -/* - * c - * Atomic-LaunchKernel ----> LaunchKernel - * | / - * Atomic-tiling Tiling - * / \ / \ - * TilingParse InferShape TilingParse - * | / \ | - * json1 shape1 shape2 json2 - */ -TEST_F(FastValueHolderUt, ScopedCurrentNodeOk) { - auto graph = std::make_shared("graph"); - - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("x1", tensor_desc); - op_desc->AppendIrInput("x1", ge::kIrInputRequired); - op_desc->AppendIrInput("x2", ge::kIrInputOptional); - auto node = graph->AddNode(op_desc); - - auto clean_op_desc = std::make_shared("node-AtomicClean", "DynamicAtomicAddrClean"); - clean_op_desc->AddInputDesc("workspace", tensor_desc); - clean_op_desc->AddInputDesc("clean1", tensor_desc); - clean_op_desc->AddInputDesc("clean2", tensor_desc); - clean_op_desc->AppendIrInput("workspace", ge::kIrInputRequired); - clean_op_desc->AppendIrInput("clean", ge::kIrInputDynamic); - auto clean_node = graph->AddNode(clean_op_desc); - - auto shape1 = ValueHolder::CreateFeed(0); - auto shape2 = ValueHolder::CreateFeed(1); - auto json1 = ValueHolder::CreateConst("{}", 2); - auto json2 = ValueHolder::CreateConst("{}", 3); - - ValueHolder::SetCurrentComputeNode(node); - auto frame = ValueHolder::GetCurrentFrame(); - ASSERT_NE(frame, nullptr); - ASSERT_EQ(frame->GetCurrentComputeNode(), node); - auto shape = ValueHolder::CreateSingleDataOutput("InferShape", {shape1, shape2}); - - size_t node1_index; - ValueHolderPtr compile_info1, tiling_ret1, holder1; - { - auto guarder = ValueHolder::SetScopedCurrentComputeNode(clean_node); - compile_info1 = ValueHolder::CreateSingleDataOutput("TilingParse", {json1}); - tiling_ret1 = ValueHolder::CreateSingleDataOutput("Tiling", {shape, compile_info1}); - holder1 = ValueHolder::CreateVoid("AtomicKernelLaunch", {tiling_ret1}); - EXPECT_TRUE(frame->GetCurrentNodeIndex(node1_index)); - } - - auto compile_info2 = ValueHolder::CreateSingleDataOutput("TilingParse", {json2}); - auto tiling_ret2 = ValueHolder::CreateSingleDataOutput("Tiling", {shape, compile_info2}); - auto holder2 = ValueHolder::CreateVoid("KernelLaunch", {tiling_ret2}); - - ValueHolder::AddDependency(holder1, holder2); - - ASSERT_NE(shape1, nullptr); - ASSERT_NE(shape2, nullptr); - ASSERT_NE(json1, nullptr); - ASSERT_NE(shape, nullptr); - ASSERT_NE(compile_info1, nullptr); - ASSERT_NE(tiling_ret1, nullptr); - ASSERT_NE(holder1, nullptr); - ASSERT_NE(compile_info2, nullptr); - ASSERT_NE(tiling_ret2, nullptr); - ASSERT_NE(holder2, nullptr); - - int64_t compute_node_index_none; - ASSERT_FALSE(ge::AttrUtils::GetInt(shape1->GetFastNode()->GetOpDescBarePtr(), "ComputeNodeIndex", - compute_node_index_none)); - ASSERT_FALSE(ge::AttrUtils::GetInt(shape2->GetFastNode()->GetOpDescBarePtr(), "ComputeNodeIndex", - compute_node_index_none)); - ASSERT_FALSE(ge::AttrUtils::GetInt(json1->GetFastNode()->GetOpDescBarePtr(), "ComputeNodeIndex", - compute_node_index_none)); - - int64_t shape_index, compile_info1_index, tiling_ret1_index, holder1_index, compile_info2_index, tiling_ret2_index, - holder2_index; - ASSERT_TRUE(ge::AttrUtils::GetInt(shape->GetFastNode()->GetOpDescBarePtr(), "ComputeNodeIndex", shape_index)); - - ASSERT_TRUE(ge::AttrUtils::GetInt(compile_info1->GetFastNode()->GetOpDescBarePtr(), "ComputeNodeIndex", - compile_info1_index)); - ASSERT_TRUE(ge::AttrUtils::GetInt(tiling_ret1->GetFastNode()->GetOpDescBarePtr(), "ComputeNodeIndex", - tiling_ret1_index)); - ASSERT_TRUE(ge::AttrUtils::GetInt(holder1->GetFastNode()->GetOpDescBarePtr(), "ComputeNodeIndex", - holder1_index)); - - ASSERT_TRUE(ge::AttrUtils::GetInt(compile_info2->GetFastNode()->GetOpDescBarePtr(), "ComputeNodeIndex", - compile_info2_index)); - ASSERT_TRUE(ge::AttrUtils::GetInt(tiling_ret2->GetFastNode()->GetOpDescBarePtr(), "ComputeNodeIndex", - tiling_ret2_index)); - ASSERT_TRUE(ge::AttrUtils::GetInt(holder2->GetFastNode()->GetOpDescBarePtr(), "ComputeNodeIndex", - holder2_index)); - - EXPECT_EQ(shape_index, compile_info2_index); - EXPECT_EQ(shape_index, tiling_ret2_index); - EXPECT_EQ(shape_index, holder2_index); - - EXPECT_NE(shape_index, compile_info1_index); - EXPECT_EQ(compile_info1_index, tiling_ret1_index); - EXPECT_EQ(compile_info1_index, holder1_index); -} - -TEST_F(FastValueHolderUt, CreateExeGraphNoOutpus) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("x1", tensor_desc); - op_desc->AppendIrInput("x1", ge::kIrInputRequired); - op_desc->AppendIrInput("x2", ge::kIrInputOptional); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - - auto data0 = ValueHolder::CreateFeed(0); - auto data1 = ValueHolder::CreateFeed(1); - - ValueHolder::SetCurrentComputeNode(node); - auto hello = ValueHolder::CreateVoid("hello", {data0, data1}); - ASSERT_NE(hello, nullptr); - EXPECT_NE(hello->GetCurrentExecuteGraph(), nullptr); -} - -TEST_F(FastValueHolderUt, CreateExeGraphNoFrame) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("x1", tensor_desc); - op_desc->AppendIrInput("x1", ge::kIrInputRequired); - op_desc->AppendIrInput("x2", ge::kIrInputOptional); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - - auto data0 = ValueHolder::CreateFeed(0); - auto data1 = ValueHolder::CreateFeed(1); - - ValueHolder::SetCurrentComputeNode(node); - auto hello = ValueHolder::CreateVoid("hello", {data0, data1}); - ASSERT_NE(hello, nullptr); - EXPECT_NE(hello->GetCurrentExecuteGraph(), nullptr); -} -/* - * hello - * / \ - * data0 data1 - */ -TEST_F(FastValueHolderUt, GetCurrentGraphOk) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("x1", tensor_desc); - op_desc->AppendIrInput("x1", ge::kIrInputRequired); - op_desc->AppendIrInput("x2", ge::kIrInputOptional); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - - auto data0 = ValueHolder::CreateFeed(0); - auto data1 = ValueHolder::CreateFeed(1); - - ValueHolder::SetCurrentComputeNode(node); - auto hello = ValueHolder::CreateVoid("hello", {data0, data1}); - - EXPECT_NE(hello->GetCurrentFrame(), nullptr); - EXPECT_NE(hello->GetCurrentExecuteGraph(), nullptr); -} -/* - * ref - * +------+ - * | | - * launch | - * | | - * tiling | - * | | - * alloc---- - * / \ - * data0 data1 - */ -TEST_F(FastValueHolderUt, RefFromOk) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("x1", tensor_desc); - op_desc->AppendIrInput("x1", ge::kIrInputRequired); - op_desc->AppendIrInput("x2", ge::kIrInputOptional); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - - auto data0 = ValueHolder::CreateFeed(0); - auto data1 = ValueHolder::CreateFeed(1); - - ValueHolder::SetCurrentComputeNode(node); - auto alloc_outs = ValueHolder::CreateDataOutput("alloc", {data0, data1}, 3); - auto tiling_outs = ValueHolder::CreateDataOutput("tiling", {data0, data1}, 2); - tiling_outs[1]->RefFrom(alloc_outs[1]); - - auto launch = ValueHolder::CreateSingleDataOutput("launch", {tiling_outs[0], tiling_outs[1]}); - ASSERT_NE(launch, nullptr); - launch->RefFrom(alloc_outs[2]); - EXPECT_NE(launch->GetCurrentExecuteGraph(), nullptr); -} -/* - * hello - * / \ - * data0 data1 - */ -TEST_F(FastValueHolderUt, AddNullOutputs) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("x1", tensor_desc); - op_desc->AppendIrInput("x1", ge::kIrInputRequired); - op_desc->AppendIrInput("x2", ge::kIrInputOptional); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - - auto data0 = ValueHolder::CreateFeed(0); - auto data1 = ValueHolder::CreateFeed(1); - - ValueHolder::SetCurrentComputeNode(node); - auto hello = ValueHolder::CreateSingleDataOutput("hello", {data0, data1}); - - EXPECT_NE(hello->GetCurrentFrame(), nullptr); - EXPECT_NE(hello->GetCurrentExecuteGraph(), nullptr); -} -/* - * hello - * / \ - * data0 data1 - */ -TEST_F(FastValueHolderUt, AddNullTargets) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("x1", tensor_desc); - op_desc->AppendIrInput("x1", ge::kIrInputRequired); - op_desc->AppendIrInput("x2", ge::kIrInputOptional); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - - auto data0 = ValueHolder::CreateFeed(0); - auto data1 = ValueHolder::CreateFeed(1); - - ValueHolder::SetCurrentComputeNode(node); - auto hello = ValueHolder::CreateSingleDataOutput("hello", {data0, data1}); - - EXPECT_NE(hello->GetCurrentFrame(), nullptr); - EXPECT_NE(hello->GetCurrentExecuteGraph(), nullptr); -} -TEST_F(FastValueHolderUt, Guard_AddFlagToNode) { - auto data0 = ValueHolder::CreateFeed(0); - auto allocator0 = ValueHolder::CreateSingleDataOutput("CreateAllocator", {data0}); - auto allocator_destroyer = ValueHolder::CreateVoidGuarder("DestroyAllocator", allocator0, {}); - ASSERT_NE(allocator_destroyer, nullptr); - auto tmp_frame = ValueHolder::PopGraphFrame(); - auto graph = tmp_frame->GetExecuteGraph().get(); - - auto node = ge::ExecuteGraphUtils::FindFirstNodeMatchType(graph, "DestroyAllocator"); - ASSERT_NE(node, nullptr); - int64_t index; - EXPECT_TRUE(ge::AttrUtils::GetInt(node->GetOpDescBarePtr(), kReleaseResourceIndex, index)); - EXPECT_EQ(index, 0); - - const auto &allocator_node = allocator0->GetFastNode(); - ASSERT_NE(allocator_node, nullptr); - const auto &allocator_desc = allocator_node->GetOpDescBarePtr(); - ASSERT_NE(allocator_desc, nullptr); - string guarder_type; - EXPECT_TRUE(ge::AttrUtils::GetStr(allocator_desc, kGuarderNodeType, guarder_type)); - EXPECT_EQ(guarder_type, "DestroyAllocator"); -} -TEST_F(FastValueHolderUt, Guarder_AddDependencyAutomately_ConnectDataEdgeToResource) { - auto data0 = ValueHolder::CreateFeed(0); - auto allocator0 = ValueHolder::CreateSingleDataOutput("CreateAllocator", {data0}); - auto allocator_destroyer = ValueHolder::CreateVoidGuarder("DestroyAllocator", allocator0, {}); - ASSERT_NE(allocator_destroyer, nullptr); - - size_t alloc_size = 1024; - auto size = ValueHolder::CreateConst(&alloc_size, sizeof(alloc_size)); - auto alloc_mem0 = ValueHolder::CreateSingleDataOutput("AllocMemory", {allocator0, size}); - auto alloc_mem1 = ValueHolder::CreateSingleDataOutput("AllocMemory", {allocator0, size}); - auto tmp_frame = ValueHolder::PopGraphFrame(); - auto graph = tmp_frame->GetExecuteGraph().get(); - - CheckGraphGenerally(graph); - - ASSERT_NE(alloc_mem0, nullptr); - ASSERT_NE(alloc_mem1, nullptr); - HasControlEdge(graph, alloc_mem0->GetFastNode(), allocator_destroyer->GetFastNode()); - HasControlEdge(graph, alloc_mem1->GetFastNode(), allocator_destroyer->GetFastNode()); - - const auto &allocator_node = allocator0->GetFastNode(); - ASSERT_NE(allocator_node, nullptr); - const auto &allocator_desc = allocator_node->GetOpDescBarePtr(); - ASSERT_NE(allocator_desc, nullptr); - string guarder_type; - EXPECT_TRUE(ge::AttrUtils::GetStr(allocator_desc, kGuarderNodeType, guarder_type)); - EXPECT_EQ(guarder_type, "DestroyAllocator"); -} -/* -* NetOutput -* | -* Bar -c-> foo0_guarder -* / \ / -* data1 foo0 -* | -* data0 -*/ -TEST_F(FastValueHolderUt, Guarder_AddDependencyFromTheSameLevelNode_ConnectFromSrcToSubgraphNodes) { - auto data0 = ValueHolder::CreateFeed(0); - auto foo0 = ValueHolder::CreateSingleDataOutput("Foo", {data0}); - auto guarder = ValueHolder::CreateVoidGuarder("FooGuarder", foo0, {}); - ASSERT_NE(guarder, nullptr); - auto data1 = ValueHolder::CreateFeed(1); - auto bar1 = ValueHolder::CreateSingleDataOutput("Bar", {data1}); - - ValueHolder::PushGraphFrame(bar1, "BarGraph"); - auto foo1 = ValueHolder::CreateSingleDataOutput("Foo", {foo0, data1}); - auto bar_frame = ValueHolder::PopGraphFrame({foo1}, {}); - - auto frame = ValueHolder::PopGraphFrame({bar1}, {}); - ASSERT_NE(frame, nullptr); - ASSERT_NE(frame->GetExecuteGraph(), nullptr); - - EXPECT_EQ(frame->GetExecuteGraph()->TopologicalSorting(), ge::GRAPH_SUCCESS); - EXPECT_TRUE(FastNodeTopoChecker(bar1).OutChecker().CtrlToByType("FooGuarder").IsOk()); - EXPECT_EQ(FastNodeTopoChecker(bar1).StrictConnectFrom({{"Data"}, {"Foo"}}), "success"); - - const auto &foo_node = foo0->GetFastNode(); - ASSERT_NE(foo_node, nullptr); - const auto &foo_desc = foo_node->GetOpDescBarePtr(); - ASSERT_NE(foo_desc, nullptr); - string guarder_type; - EXPECT_TRUE(ge::AttrUtils::GetStr(foo_desc, kGuarderNodeType, guarder_type)); - EXPECT_EQ(guarder_type, "FooGuarder"); -} -TEST_F(FastValueHolderUt, Guarder_DoNotAddDependency_ConnectDataEdgeToNetOutput) { - auto data0 = ValueHolder::CreateFeed(0); - auto foo0 = ValueHolder::CreateSingleDataOutput("Foo", {data0}); - auto guarder = ValueHolder::CreateVoidGuarder("FooGuarder", foo0, {}); - ASSERT_NE(guarder, nullptr); - - auto bar0 = ValueHolder::CreateSingleDataOutput("Bar", {foo0}); - - auto frame = ValueHolder::PopGraphFrame({foo0}, {}); - ASSERT_NE(frame, nullptr); - auto graph = frame->GetExecuteGraph().get(); - ASSERT_NE(graph, nullptr); - - EXPECT_EQ(FastNodeTopoChecker(foo0).StrictConnectTo(0, {{"NetOutput"}, {"FooGuarder"}, {"Bar"}}), "success"); - HasControlEdge(graph, bar0->GetFastNode(), guarder->GetFastNode()); - ASSERT_EQ(guarder->GetFastNode()->GetInControlNodes().size(), 1); - - const auto &foo_node = foo0->GetFastNode(); - ASSERT_NE(foo_node, nullptr); - const auto &foo_desc = foo_node->GetOpDescBarePtr(); - ASSERT_NE(foo_desc, nullptr); - string guarder_type; - EXPECT_TRUE(ge::AttrUtils::GetStr(foo_desc, kGuarderNodeType, guarder_type)); - EXPECT_EQ(guarder_type, "FooGuarder"); -} -TEST_F(FastValueHolderUt, AddDependencyForGuard_RleaseBy) { - auto data0 = ValueHolder::CreateFeed(0); - auto allocator0 = ValueHolder::CreateSingleDataOutput("CreateAllocator", {data0}); - auto allocator_destroyer = ValueHolder::CreateVoidGuarder("DestroyAllocator", allocator0, {}); - ASSERT_NE(allocator_destroyer, nullptr); - - size_t alloc_size = 1024; - auto size = ValueHolder::CreateConst(&alloc_size, sizeof(alloc_size)); - auto alloc_mem0 = ValueHolder::CreateSingleDataOutput("AllocMemory", {allocator0, size}); - auto free_mem0 = ValueHolder::CreateVoidGuarder("FreeMemory", {alloc_mem0}, {}); - auto tmp_graph = ValueHolder::PopGraphFrame(); - auto graph = tmp_graph->GetExecuteGraph().get(); - CheckGraphGenerally(graph); - - ASSERT_NE(free_mem0, nullptr); - ASSERT_NE(alloc_mem0, nullptr); - HasControlEdge(graph, alloc_mem0->GetFastNode(), allocator_destroyer->GetFastNode()); - - allocator0->ReleaseAfter(free_mem0); - HasControlEdge(graph, free_mem0->GetFastNode(), allocator_destroyer->GetFastNode()); - - const auto &allocator_node = allocator0->GetFastNode(); - ASSERT_NE(allocator_node, nullptr); - const auto &allocator_desc = allocator_node->GetOpDescBarePtr(); - ASSERT_NE(allocator_desc, nullptr); - string guarder_type; - EXPECT_TRUE(ge::AttrUtils::GetStr(allocator_desc, kGuarderNodeType, guarder_type)); - EXPECT_EQ(guarder_type, "DestroyAllocator"); - - const auto &alloc_node = alloc_mem0->GetFastNode(); - ASSERT_NE(alloc_node, nullptr); - const auto &alloc_desc = alloc_node->GetOpDescBarePtr(); - ASSERT_NE(alloc_desc, nullptr); - EXPECT_TRUE(ge::AttrUtils::GetStr(alloc_desc, kGuarderNodeType, guarder_type)); - EXPECT_EQ(guarder_type, "FreeMemory"); -} -TEST_F(FastValueHolderUt, RleaseBy_NoGuarder) { - auto data0 = ValueHolder::CreateFeed(0); - auto allocator0 = ValueHolder::CreateSingleDataOutput("CreateAllocator", {data0}); - - size_t alloc_size = 1024; - auto size = ValueHolder::CreateConst(&alloc_size, sizeof(alloc_size)); - auto alloc_mem0 = ValueHolder::CreateSingleDataOutput("AllocMemory", {allocator0, size}); - auto tmp_graph = ValueHolder::PopGraphFrame(); - auto graph = tmp_graph->GetExecuteGraph().get(); - - CheckGraphGenerally(graph); - - ASSERT_NE(alloc_mem0, nullptr); - - allocator0->ReleaseAfter(alloc_mem0); - - EXPECT_EQ(allocator0->GetFastNode()->GetAllOutNodes().size(), 1); - EXPECT_EQ(allocator0->GetFastNode()->GetAllInNodes().size(), 1); - - EXPECT_EQ(alloc_mem0->GetFastNode()->GetAllOutNodes().size(), 0); - EXPECT_EQ(alloc_mem0->GetFastNode()->GetAllInNodes().size(), 2); -} -TEST_F(FastValueHolderUt, PushFrame_ChildFrameIsNotRoot) { - ValueHolder::PopGraphFrame(); - auto root_frame = ValueHolder::PushGraphFrame(); - EXPECT_TRUE(root_frame->IsRootFrame()); - auto feed0 = ValueHolder::CreateFeed(0); - auto child_frame = ValueHolder::PushGraphFrame(feed0, "subgraph_name0"); - ASSERT_NE(child_frame, nullptr); - EXPECT_FALSE(child_frame->IsRootFrame()); -} -TEST_F(FastValueHolderUt, PushFrame_ComputeNodeIndexTheSame) { - auto compute_node1 = FakeNode(); - auto compute_node2 = FakeNode(); - auto compute_node3 = FakeNode(); - - ValueHolder::PopGraphFrame(); - auto root_frame = ValueHolder::PushGraphFrame(); - EXPECT_TRUE(root_frame->IsRootFrame()); - ValueHolder::SetCurrentComputeNode(compute_node1); - auto feed0 = ValueHolder::CreateFeed(0); - auto feed1 = ValueHolder::CreateFeed(1); - auto foo0 = ValueHolder::CreateSingleDataOutput("Foo", {feed0, feed1}); - - ValueHolder::SetCurrentComputeNode(compute_node2); - auto bar0 = ValueHolder::CreateSingleDataOutput("Bar", {foo0}); - - ValueHolder::PushGraphFrame(foo0, "subgraph_name0"); - auto sub_bar1 = ValueHolder::CreateSingleDataOutput("Bar", {feed0}); - - EXPECT_EQ(GetComputeNodeIndex(sub_bar1->GetFastNode()), GetComputeNodeIndex(foo0->GetFastNode())); - EXPECT_NE(GetComputeNodeIndex(bar0->GetFastNode()), GetComputeNodeIndex(foo0->GetFastNode())); - - ValueHolder::PopGraphFrame(); - ValueHolder::PopGraphFrame(); -} - -TEST_F(FastValueHolderUt, PlacementDefault0) { - auto data0 = ValueHolder::CreateFeed(0); - EXPECT_EQ(data0->GetPlacement(), 0); -} -TEST_F(FastValueHolderUt, SetGetPlacementOk) { - auto data0 = ValueHolder::CreateFeed(0); - data0->SetPlacement(1); - EXPECT_EQ(data0->GetPlacement(), 1); - data0->SetPlacement(2); - EXPECT_EQ(data0->GetPlacement(), 2); -} -TEST_F(FastValueHolderUt, BuildGraphWithDataOutput) { - auto data0 = ValueHolder::CreateFeed(0); - auto data1 = ValueHolder::CreateFeed(1); - auto foo1 = ValueHolder::CreateSingleDataOutput("Foo", {data0, data1}); - auto bar1 = ValueHolder::CreateSingleDataOutput("Bar", {data0, data1}); - auto frame = ValueHolder::PopGraphFrame({foo1, bar1}, {}); - ASSERT_NE(frame, nullptr); - auto graph = frame->GetExecuteGraph().get(); - ASSERT_NE(graph, nullptr); - - CheckGraphGenerally(graph); - - EXPECT_EQ(graph->GetAllNodes().size(), 5); - - auto nodes = ge::ExecuteGraphUtils::FindNodesByTypeFromAllNodes(graph, "NetOutput"); - ASSERT_EQ(nodes.size(), 1); - auto netoutput = nodes[0]; - ASSERT_NE(netoutput, nullptr); - EXPECT_EQ(netoutput->GetAllInNodes().size(), 2); - ASSERT_EQ(netoutput->GetInDataNodes().size(), 2); - EXPECT_EQ((*netoutput->GetInDataNodes().begin())->GetType(), "Foo"); - EXPECT_EQ((*(netoutput->GetInDataNodes().begin() + 1))->GetType(), "Bar"); -} -TEST_F(FastValueHolderUt, BuildGraphWithCtrlOutput) { - auto data0 = ValueHolder::CreateFeed(0); - auto data1 = ValueHolder::CreateFeed(1); - auto foo1 = ValueHolder::CreateSingleDataOutput("Foo", {data0, data1}); - auto bar1 = ValueHolder::CreateSingleDataOutput("Bar", {data0, data1}); - auto launch = ValueHolder::CreateVoid("Launch", {foo1, bar1}); - auto frame = ValueHolder::PopGraphFrame({}, {launch}); - ASSERT_NE(frame, nullptr); - auto graph = frame->GetExecuteGraph().get(); - ASSERT_NE(graph, nullptr); - - CheckGraphGenerally(graph); - - EXPECT_EQ(graph->GetAllNodes().size(), 6); - - auto nodes = ge::ExecuteGraphUtils::FindNodesByTypeFromAllNodes(graph, "NetOutput"); - ASSERT_EQ(nodes.size(), 1); - auto netoutput = nodes[0]; - ASSERT_NE(netoutput, nullptr); - EXPECT_EQ(netoutput->GetAllInNodes().size(), 1); - ASSERT_EQ(netoutput->GetInControlNodes().size(), 1); - EXPECT_EQ((*netoutput->GetInControlNodes().begin())->GetType(), "Launch"); -} -TEST_F(FastValueHolderUt, AppendOutputOk) { - auto foo = ValueHolder::CreateVoid("Foo", {}); - EXPECT_EQ(foo->GetFastNode()->GetDataOutNum(), 0); - - auto outputs = foo->AppendOutputs(5); - EXPECT_EQ(outputs.size(), 5); - EXPECT_EQ(foo->GetFastNode()->GetDataOutNum(), 5); - - auto bar = ValueHolder::CreateSingleDataOutput("Bar", outputs); - EXPECT_NE(bar, nullptr); - EXPECT_EQ(bar->GetFastNode()->GetDataInNum(), 5); - for (int i = 0; i < 5; ++i) { - EXPECT_EQ(bar->GetFastNode()->GetInDataEdgeByIndex(i)->src_output, i); - EXPECT_EQ(bar->GetFastNode()->GetInDataEdgeByIndex(i)->src, - foo->GetFastNode()); - } -} -TEST_F(FastValueHolderUt, AppendOutputToNodeWithOutputs) { - auto foo = ValueHolder::CreateDataOutput("Foo", {}, 3)[0]; - EXPECT_EQ(foo->GetFastNode()->GetDataOutNum(), 3); - - auto outputs = foo->AppendOutputs(5); - ASSERT_EQ(outputs.size(), 5); - EXPECT_EQ(foo->GetFastNode()->GetDataOutNum(), 8); - - auto bar = ValueHolder::CreateSingleDataOutput("Bar", outputs); - EXPECT_NE(bar, nullptr); - EXPECT_EQ(bar->GetFastNode()->GetDataInNum(), 5); - for (int i = 0; i < 5; ++i) { - EXPECT_EQ(bar->GetFastNode()->GetInDataEdgeByIndex(i)->src_output, i + 3); - EXPECT_EQ(bar->GetFastNode()->GetInDataEdgeByIndex(i)->src, - foo->GetFastNode()); - } -} -TEST_F(FastValueHolderUt, AppendInputInOneGraphOk) { - const auto &foo = ValueHolder::CreateVoid("Foo", {}); - const auto &input1 = ValueHolder::CreateConst("{}", 2); - const auto &input2 = ValueHolder::CreateFeed(0); - EXPECT_EQ(foo->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), - input1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr()); - EXPECT_EQ(foo->AppendInputs({input1, input2}), 0); - const auto &input_nodes = foo->GetFastNode()->GetInDataNodes(); - EXPECT_EQ(input_nodes.size(), 2); - EXPECT_EQ(input_nodes.at(0)->GetType(), "Const"); - EXPECT_EQ(input_nodes.at(1)->GetType(), "Data"); -} -TEST_F(FastValueHolderUt, AppendInputInParentGraphOk) { - const auto &input1 = ValueHolder::CreateConst("{}", 2); - auto main_node = ValueHolder::CreateVoid("Main", {}); - ValueHolder::PushGraphFrame(main_node, "Main"); - const auto &foo = ValueHolder::CreateVoid("Foo", {}); - EXPECT_NE(foo->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), - input1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr()); - EXPECT_EQ(foo->AppendInputs({input1}), 0); - const auto &input_nodes = foo->GetFastNode()->GetInDataNodes(); - EXPECT_EQ(input_nodes.size(), 1); - EXPECT_EQ(input_nodes.at(0)->GetType(), "InnerData"); -} -TEST_F(FastValueHolderUt, AppendInputInChildGraph_noOk) { - const auto &foo = ValueHolder::CreateConst("{}", 2); - auto main_node = ValueHolder::CreateVoid("Main", {}); - ValueHolder::PushGraphFrame(main_node, "Main"); - const auto &input1 = ValueHolder::CreateVoid("Foo", {}); - EXPECT_NE(foo->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), - input1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr()); - EXPECT_NE(foo->AppendInputs({input1}), 0); - const auto &input_nodes = foo->GetFastNode()->GetInDataNodes(); - EXPECT_EQ(input_nodes.size(), 0); - const uint32_t in_data_anchor_size = foo->GetFastNode()->GetDataInNum(); - EXPECT_EQ(in_data_anchor_size, 1); -} -TEST_F(FastValueHolderUt, AppendInputInDifferentGraph_noOk) { - auto main_node = ValueHolder::CreateVoid("Main", {}); - ValueHolder::PushGraphFrame(main_node, "Main"); - const auto &input1 = ValueHolder::CreateConst("{}", 2); - const auto &main_frame = ValueHolder::PopGraphFrame(); - const auto &de_init_node = ValueHolder::CreateVoid("DeInit", {}); - ValueHolder::PushGraphFrame(de_init_node, "DeInit"); - const auto &foo = ValueHolder::CreateVoid("Foo", {}); - EXPECT_NE(foo->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), - input1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr()); - EXPECT_NE(foo->AppendInputs({input1}), 0); - const auto &input_nodes = foo->GetFastNode()->GetInDataNodes(); - EXPECT_EQ(input_nodes.size(), 0); - const uint32_t in_data_anchor_size = foo->GetFastNode()->GetDataInNum(); - EXPECT_EQ(in_data_anchor_size, 1); -} -TEST_F(FastValueHolderUt, AppendInputWithMulitSubgraph_Ok) { - ValueHolder::PopGraphFrame(); - auto root_frame = ValueHolder::PushGraphFrame(); - EXPECT_TRUE(root_frame->IsRootFrame()); - const auto root_graph = root_frame->GetExecuteGraph().get(); - ASSERT_NE(root_graph, nullptr); - const auto &input1 = ValueHolder::CreateConst("{}", 2); - - const auto &main_node = ValueHolder::CreateVoid("Main", {}); - ValueHolder::PushGraphFrame(main_node, "Main"); - const auto &if_holder = ValueHolder::CreateVoid("If", {}); - ASSERT_NE(if_holder, nullptr); - const auto &if_node = - ge::ExecuteGraphUtils::FindFirstNodeMatchType(ValueHolder::GetCurrentFrame()->GetExecuteGraph().get(), "If"); - ASSERT_NE(if_node, nullptr); - - ValueHolder::PushGraphFrame(if_holder, "then"); - const auto &then_foo = ValueHolder::CreateVoid("Foo", {}); - EXPECT_EQ(then_foo->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr()->GetParentNodeBarePtr()->GetType(), "If"); - ValueHolder::PopGraphFrame(); - - ValueHolder::PushGraphFrame(if_holder, "else"); - const auto &else_foo = ValueHolder::CreateVoid("Foo", {}); - EXPECT_EQ(else_foo->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr()->GetParentNodeBarePtr()->GetType(), "If"); - ValueHolder::PopGraphFrame(); - - EXPECT_EQ(then_foo->AppendInputs({input1}), 0); - const auto &then_inputs = then_foo->GetFastNode()->GetInDataNodes(); - EXPECT_EQ(then_inputs.size(), 1); - const auto &if_inputs1 = if_holder->GetFastNode()->GetAllInNodes(); - EXPECT_EQ(if_inputs1.size(), 1); - - EXPECT_EQ(else_foo->AppendInputs({input1}), 0); - const auto &else_inputs = else_foo->GetFastNode()->GetInDataNodes(); - EXPECT_EQ(else_inputs.size(), 1); - const auto &if_inputs2 = if_holder->GetFastNode()->GetAllInNodes(); - EXPECT_EQ(if_inputs2.size(), 1); - - EXPECT_EQ(root_graph->GetAllSubgraphs().size(), 3); - EXPECT_EQ(if_node->GetOpDescBarePtr()->GetSubgraphInstanceNames().size(), 2); -} -TEST_F(FastValueHolderUt, ConnectFromAncestor_CreateInnerData_ParentGraph) { - auto data0 = ValueHolder::CreateFeed(0); - auto data1 = ValueHolder::CreateFeed(1); - auto data2 = ValueHolder::CreateFeed(2); - auto foo = ValueHolder::CreateSingleDataOutput("Foo", {data0}); - ValueHolder::PushGraphFrame(foo, "Foo"); - auto sub_foo = ValueHolder::CreateSingleDataOutput("SubFoo", {data1, data2}); - ValueHolder::PopGraphFrame({sub_foo}, {}); - - auto frame = ValueHolder::PopGraphFrame(); - ASSERT_NE(frame, nullptr); - auto graph = frame->GetExecuteGraph().get(); - ASSERT_NE(graph, nullptr); - CheckGraphGenerally(graph); - - EXPECT_EQ(ExeGraphSummaryChecker(graph).StrictDirectNodeTypes({{"Data", 3}, {"Foo", 1}}), "success"); - EXPECT_EQ(graph->GetAllSubgraphs().size(), 1); - - auto foo_node = ge::ExecuteGraphUtils::FindFirstNodeMatchType(graph, "Foo"); - ASSERT_NE(foo_node, nullptr); - EXPECT_EQ(FastNodeTopoChecker(foo).StrictConnectFrom({data0, data1, data2}), "success"); - StrictSubgraphs(foo_node, {"Foo"}); - auto subgraph_name = foo_node->GetOpDescBarePtr()->GetSubgraphInstanceName(0); - auto subgraph = graph->GetSubGraph(subgraph_name); - ASSERT_NE(subgraph, nullptr); - auto ret = gert::ExeGraphSummaryChecker(subgraph).StrictAllNodeTypes({{"InnerData", 2}, {"SubFoo", 1}, {"InnerNetOutput", 1}}); - EXPECT_EQ(ret, "success") << ret; - - auto sub_foo_node = ge::ExecuteGraphUtils::FindFirstNodeMatchType(subgraph, "SubFoo"); - ASSERT_NE(sub_foo_node, nullptr); - ConnectFromInnerData(sub_foo_node, {1, 2}); - EXPECT_EQ(FastNodeTopoChecker(sub_foo_node).StrictConnectTo(0, {{"InnerNetOutput"}}), "success"); -} - -TEST_F(FastValueHolderUt, ConnectFromAncestor_InnerDataWithGuarderOutside) { - auto data0 = ValueHolder::CreateFeed(0); - auto data1 = ValueHolder::CreateFeed(1); - auto data2 = ValueHolder::CreateFeed(2); - auto foo = ValueHolder::CreateSingleDataOutput("Foo", {data0}); - auto data1_guarder = ValueHolder::CreateVoidGuarder("FreeMemory", data1, {}); - ValueHolder::PushGraphFrame(foo, "Foo"); - auto sub_foo = ValueHolder::CreateSingleDataOutput("SubFoo", {data1, data2}); - - auto sub_frame = ValueHolder::PopGraphFrame({sub_foo}, {}); - auto subgraph = sub_frame->GetExecuteGraph().get(); - ASSERT_NE(subgraph, nullptr); - auto innerdata_node = ge::ExecuteGraphUtils::FindFirstNodeMatchType(subgraph, "InnerData"); - std::string guarder_type_outside; - (void) ge::AttrUtils::GetStr(innerdata_node->GetOpDescBarePtr(), kNodeWithGuarderOutside, guarder_type_outside); - EXPECT_EQ(!guarder_type_outside.empty(), true); - EXPECT_EQ(guarder_type_outside, "FreeMemory"); - - const auto &data_node = data1->GetFastNode(); - ASSERT_NE(data_node, nullptr); - const auto &data_desc = data_node->GetOpDescBarePtr(); - ASSERT_NE(data_desc, nullptr); - string guarder_type; - EXPECT_TRUE(ge::AttrUtils::GetStr(data_desc, kGuarderNodeType, guarder_type)); - EXPECT_EQ(guarder_type, "FreeMemory"); -} - -TEST_F(FastValueHolderUt, ConnectFromAncestor_InnerDataWithGuarderOutside_In_Subgraph_Nesting) { - auto data0 = ValueHolder::CreateFeed(0); - auto data1 = ValueHolder::CreateFeed(1); - auto data2 = ValueHolder::CreateFeed(2); - auto foo = ValueHolder::CreateSingleDataOutput("Foo", {data0}); - auto data2_guarder = ValueHolder::CreateVoidGuarder("FreeFftsMem", data2, {}); - - ValueHolder::PushGraphFrame(foo, "Foo"); - auto sub_foo = ValueHolder::CreateSingleDataOutput("SubFoo", {data1}); - - ValueHolder::PushGraphFrame(sub_foo, "SubFoo"); - auto sub_sub_foo = ValueHolder::CreateSingleDataOutput("SubFoo", {data2}); - - auto sub_sub_frame = ValueHolder::PopGraphFrame({sub_sub_foo}, {}); - auto sub_sub_graph = sub_sub_frame->GetExecuteGraph().get(); - - auto innerdata_node = ge::ExecuteGraphUtils::FindFirstNodeMatchType(sub_sub_graph, "InnerData"); - std::string guarder_type_outside; - (void) ge::AttrUtils::GetStr(innerdata_node->GetOpDescBarePtr(), kNodeWithGuarderOutside, guarder_type_outside); - EXPECT_EQ(!guarder_type_outside.empty(), true); - EXPECT_EQ(guarder_type_outside, "FreeFftsMem"); - - const auto &data_node = data2->GetFastNode(); - ASSERT_NE(data_node, nullptr); - const auto &data_desc = data_node->GetOpDescBarePtr(); - ASSERT_NE(data_desc, nullptr); - string guarder_type; - EXPECT_TRUE(ge::AttrUtils::GetStr(data_desc, kGuarderNodeType, guarder_type)); - EXPECT_EQ(guarder_type, "FreeFftsMem"); -} - -/* - * +-----------------------------+ - * |Foo | - * | +---------------------+ | - * | |SubFoo | | - * | | NetOutput | | - * | | | | | - * | | Sub2Foo3 | | - * | | / \ | | - * | | Sub2Foo1 Sub2Foo2 | | - * | | | / | | | - * | +---+-----+-----+-----+ | - * | | | | | - * +-------+-----+-----+---------+ - * / | | | - * data0 data1 data2 data3 - */ -TEST_F(FastValueHolderUt, ConnectFromAncestor_CreateInnerDataRecursively_AncestorGraph) { - auto data0 = ValueHolder::CreateFeed(0); - auto data1 = ValueHolder::CreateFeed(1); - auto data2 = ValueHolder::CreateFeed(2); - auto data3 = ValueHolder::CreateFeed(3); - - auto foo = ValueHolder::CreateSingleDataOutput("Foo", {data0}); - ValueHolder::PushGraphFrame(foo, "Foo"); - - auto sub_foo = ValueHolder::CreateSingleDataOutput("SubFoo", {data1}); - ValueHolder::PushGraphFrame(sub_foo, "Foo"); - - auto sub2_foo1 = ValueHolder::CreateSingleDataOutput("Sub2Foo1", {data1}); - auto sub2_foo2 = ValueHolder::CreateSingleDataOutput("Sub2Foo2", {data3, data2}); - auto sub2_foo3 = ValueHolder::CreateSingleDataOutput("Sub2Foo3", {sub2_foo1, sub2_foo2}); - - ValueHolder::PopGraphFrame({sub2_foo3}, {}); - ValueHolder::PopGraphFrame({sub_foo}, {}); - auto frame = ValueHolder::PopGraphFrame(); - ASSERT_NE(frame, nullptr); - auto graph = frame->GetExecuteGraph().get(); - ASSERT_NE(graph, nullptr); - CheckGraphGenerally(graph); - - // Check elements on root graph - EXPECT_EQ(ExeGraphSummaryChecker(graph).StrictDirectNodeTypes({{"Data", 4}, {"Foo", 1}}), "success"); - EXPECT_EQ(graph->GetAllSubgraphs().size(), 2); - auto foo_node = ge::ExecuteGraphUtils::FindFirstNodeMatchType(graph, "Foo"); - - ASSERT_NE(foo_node, nullptr); - EXPECT_EQ(FastNodeTopoChecker(foo_node).StrictConnectFrom({data0, data1, data3, data2}), "success"); - StrictSubgraphs(foo_node, {"Foo"}); - ASSERT_EQ(foo_node->GetOpDescBarePtr()->GetSubgraphInstanceNames().size(), 1); - auto foo_graph = graph->GetSubGraph(foo_node->GetOpDescBarePtr()->GetSubgraphInstanceName(0)); - ASSERT_NE(foo_graph, nullptr); - - // Check elements on foo graph - ASSERT_EQ(ExeGraphSummaryChecker(foo_graph).StrictDirectNodeTypes({{"InnerData", 3}, {"SubFoo", 1}, {"InnerNetOutput", 1}}), - "success"); - auto sub_foo_node = ge::ExecuteGraphUtils::FindFirstNodeMatchType(foo_graph, "SubFoo"); - ASSERT_NE(sub_foo_node, nullptr); - ASSERT_EQ(FastNodeTopoChecker(sub_foo_node).StrictConnectFrom({{"InnerData"}, {"InnerData"}, {"InnerData"}}), - "success"); - ASSERT_EQ(FastNodeTopoChecker(sub_foo_node).StrictConnectTo(0, {{"InnerNetOutput"}}), "success"); - StrictSubgraphs(sub_foo_node, {"Foo"}); - auto subfoo_graph = graph->GetSubGraph(sub_foo_node->GetOpDescBarePtr()->GetSubgraphInstanceName(0)); - ASSERT_NE(subfoo_graph, nullptr); - - // Check elements on SubFoo graph - auto ret = gert::ExeGraphSummaryChecker(subfoo_graph) - .StrictAllNodeTypes( - {{"InnerData", 3}, {"Sub2Foo1", 1}, {"Sub2Foo2", 1}, {"Sub2Foo3", 1}, {"InnerNetOutput", 1}}); - auto sub2_foo1_node = ge::ExecuteGraphUtils::FindFirstNodeMatchType(subfoo_graph, "Sub2Foo1"); - - ASSERT_NE(sub2_foo1_node, nullptr); - EXPECT_EQ(FastNodeTopoChecker(sub2_foo1_node).StrictConnectFrom({{"InnerData"}}), "success"); - ConnectFromOuter(sub2_foo1_node, 0, FindData(graph, 1), 0); - - auto sub2_foo2_node = ge::ExecuteGraphUtils::FindFirstNodeMatchType(subfoo_graph, "Sub2Foo2"); - ASSERT_NE(sub2_foo2_node, nullptr); - EXPECT_EQ(FastNodeTopoChecker(sub2_foo2_node).StrictConnectFrom({{"InnerData"}, {"InnerData"}}), "success"); - ConnectFromOuter(sub2_foo2_node, 0, FindData(graph, 3), 0); - ConnectFromOuter(sub2_foo2_node, 1, FindData(graph, 2), 0); - - auto sub2_foo3_node = ge::ExecuteGraphUtils::FindFirstNodeMatchType(subfoo_graph, "Sub2Foo3"); - ASSERT_NE(sub2_foo3_node, nullptr); - ASSERT_EQ(FastNodeTopoChecker(sub2_foo3_node).StrictConnectFrom({sub2_foo1_node, sub2_foo2_node}), "success"); -} - -/* - * +---------------------------------------+ - * |Foo | - * | +-------------------------------+ | - * | |SubFoo | | - * | | NetOutput | | - * | | | | | - * | | Sub2Foo3 | | - * | | / | \ | | - * | | Sub2Foo1 Sub2Foo2 Sub2Foo4 | | - * | | | / \ / | | - * | +---+-----+---------+-----------+ | - * | | | | | - * +-------+-----+---------+---------------+ - * / | | | - * data0 data1 data2 data3 - */ -TEST_F(FastValueHolderUt, ConnectFromAncestor_DeDuplicate_SameSrc) { - auto data0 = ValueHolder::CreateFeed(0); - auto data1 = ValueHolder::CreateFeed(1); - auto data2 = ValueHolder::CreateFeed(2); - auto data3 = ValueHolder::CreateFeed(3); - - auto foo = ValueHolder::CreateSingleDataOutput("Foo", {data0}); - ValueHolder::PushGraphFrame(foo, "Foo"); - - auto sub_foo = ValueHolder::CreateSingleDataOutput("SubFoo", {data1}); - ValueHolder::PushGraphFrame(sub_foo, "Foo"); - - auto sub2_foo1 = ValueHolder::CreateSingleDataOutput("Sub2Foo1", {data1}); - auto sub2_foo2 = ValueHolder::CreateSingleDataOutput("Sub2Foo2", {data3, data2}); - auto sub2_foo4 = ValueHolder::CreateSingleDataOutput("Sub2Foo4", {data3}); - auto sub2_foo3 = ValueHolder::CreateSingleDataOutput("Sub2Foo3", {sub2_foo1, sub2_foo2, sub2_foo4}); - - ValueHolder::PopGraphFrame({sub2_foo3}, {}); - ValueHolder::PopGraphFrame({sub_foo}, {}); - auto frame = ValueHolder::PopGraphFrame(); - ASSERT_NE(frame, nullptr); - auto graph = frame->GetExecuteGraph().get(); - ASSERT_NE(graph, nullptr); - CheckGraphGenerally(graph); - - auto foo_node = ge::ExecuteGraphUtils::FindFirstNodeMatchType(graph, "Foo"); - ASSERT_NE(foo_node, nullptr); - ASSERT_EQ(FastNodeTopoChecker(foo_node).StrictConnectFrom({{"Data"}, {"Data"}, {"Data"}, {"Data"}}), "success"); - - auto sub_foo_node = ge::ExecuteGraphUtils::FindNodesByTypeFromAllNodes(graph, "SubFoo")[0]; - ASSERT_NE(sub_foo_node, nullptr); - ASSERT_EQ(FastNodeTopoChecker(sub_foo_node).StrictConnectFrom({{"InnerData"}, {"InnerData"}, {"InnerData"}}), - "success"); - auto sub_foo_graph = FindFirstSubgraphForNodeType(graph, "SubFoo"); - ASSERT_NE(sub_foo_graph, nullptr); - auto sub2_foo2_node = ge::ExecuteGraphUtils::FindFirstNodeMatchType(sub_foo_graph, "Sub2Foo2"); - ASSERT_NE(sub2_foo2_node, nullptr); - ASSERT_EQ(FastNodeTopoChecker(sub2_foo2_node).StrictConnectFrom({{"InnerData"}, {"InnerData"}}), "success"); - ConnectFromOuter(sub2_foo2_node, 0, FindData(graph, 3), 0); - ConnectFromOuter(sub2_foo2_node, 1, FindData(graph, 2), 0); - - auto sub2_foo4_node = ge::ExecuteGraphUtils::FindFirstNodeMatchType(sub_foo_graph, "Sub2Foo4"); - ASSERT_EQ(FastNodeTopoChecker(sub2_foo4_node).StrictConnectFrom({{"InnerData"}}), "success"); - ConnectFromOuter(sub2_foo4_node, 0, FindData(graph, 3), 0); - - auto inner_data_from_3 = sub2_foo2_node->GetInDataEdgeByIndex(0)->src; - ASSERT_EQ(FastNodeTopoChecker(inner_data_from_3).StrictConnectTo(0, {sub2_foo2_node, sub2_foo4_node}), - "success"); -} -TEST_F(FastValueHolderUt, PopFrame_CreateControlEdge_Targets) { - auto data0 = ValueHolder::CreateFeed(0); - auto data1 = ValueHolder::CreateFeed(1); - auto foo = ValueHolder::CreateSingleDataOutput("Foo", {data0, data1}); - auto frame = ValueHolder::PopGraphFrame({}, {data0, foo}); - - ASSERT_NE(frame, nullptr); - auto graph = frame->GetExecuteGraph().get(); - ASSERT_NE(graph, nullptr); - - ASSERT_EQ(ExeGraphSummaryChecker(graph).StrictAllNodeTypes({{"Data", 2}, {"Foo", 1}, {"NetOutput", 1}}), "success"); - - auto netoutput = ge::ExecuteGraphUtils::FindFirstNodeMatchType(graph, "NetOutput"); - ASSERT_NE(netoutput, nullptr); - ASSERT_EQ(FastNodeTopoChecker(netoutput).StrictConnectFrom({{"Data", -1}, {"Foo", -1}}), "success"); - EXPECT_EQ(netoutput->GetInDataNodes().size(), 0); - EXPECT_EQ(netoutput->GetInControlNodes().size(), 2); -} -TEST_F(FastValueHolderUt, PopFrame_CreateNetOuptut_PopRootGraph) { - auto data0 = ValueHolder::CreateFeed(0); - auto data1 = ValueHolder::CreateFeed(1); - auto foo = ValueHolder::CreateSingleDataOutput("Foo", {data0, data1}); - auto frame = ValueHolder::PopGraphFrame({data0, foo}, {}); - - ASSERT_NE(frame, nullptr); - auto graph = frame->GetExecuteGraph().get(); - ASSERT_NE(graph, nullptr); - - ASSERT_EQ(ExeGraphSummaryChecker(graph).StrictAllNodeTypes({{"Data", 2}, {"Foo", 1}, {"NetOutput", 1}}), "success"); - - auto netoutput = ge::ExecuteGraphUtils::FindFirstNodeMatchType(graph, "NetOutput"); - ASSERT_NE(netoutput, nullptr); - EXPECT_EQ(netoutput->GetName(), "NetOutput"); - ASSERT_EQ(FastNodeTopoChecker(netoutput).StrictConnectFrom({data0, foo}), "success"); - - auto foo_node = ge::ExecuteGraphUtils::FindFirstNodeMatchType(graph, "Foo"); - ASSERT_NE(foo_node, nullptr); - ASSERT_EQ(FastNodeTopoChecker(foo_node).StrictConnectFrom({data0, data1}), "success"); - ASSERT_EQ(FastNodeTopoChecker(foo_node).StrictConnectTo(0, {netoutput}), "success"); -} -TEST_F(FastValueHolderUt, PopFrame_CreateInnerNetOuptut_PopSubgraph) { - auto data0 = ValueHolder::CreateFeed(0); - auto data1 = ValueHolder::CreateFeed(1); - auto foo = ValueHolder::CreateSingleDataOutput("Foo", {data0, data1}); - ValueHolder::PushGraphFrame(foo, "Foo"); - auto bar1 = ValueHolder::CreateSingleDataOutput("Bar1", {data0}); - auto bar2 = ValueHolder::CreateSingleDataOutput("Bar2", {data1}); - auto frame = ValueHolder::PopGraphFrame({bar1}, {bar2}); - - ASSERT_NE(frame, nullptr); - auto graph = frame->GetExecuteGraph().get(); - ASSERT_NE(graph, nullptr); - ASSERT_EQ(ExeGraphSummaryChecker(graph).StrictAllNodeTypes( - {{"InnerData", 2}, {"Bar1", 1}, {"Bar2", 1}, {"InnerNetOutput", 1}}), "success"); - auto netoutput = ge::ExecuteGraphUtils::FindFirstNodeMatchType(graph, "InnerNetOutput"); - - ASSERT_NE(netoutput, nullptr); - ASSERT_EQ(FastNodeTopoChecker(netoutput).StrictConnectFrom({{"Bar1"}, {"Bar2"}}), "success"); -} -/* - * +-----------------------------+ - * |Foo | - * | +---------------------+ | - * | |SubFoo | | - * | | NetOutput | | - * | | / | \ | | - * | | | foo5 | | | - * | | | | | | | - * | +-+--+--------+-------+ | - * | | | | | - * | / Foo2 Foo3 | - * | | / / \ | - * +---+---------+--------+------+ - * | | | - * data0 data1 data2 - */ -TEST_F(FastValueHolderUt, PopFrame_CraeteInnerData_OutputsUseParentHolder) { - auto data0 = ValueHolder::CreateFeed(0); - auto data1 = ValueHolder::CreateFeed(1); - auto data2 = ValueHolder::CreateFeed(2); - - auto foo1 = ValueHolder::CreateSingleDataOutput("Foo", {data0, data1, data2}); - ValueHolder::PushGraphFrame(foo1, "Foo"); - - auto foo2 = ValueHolder::CreateSingleDataOutput("Foo2", {data0, data1}); - auto foo3 = ValueHolder::CreateDataOutput("Foo3", {data1, data2}, 3); - - auto foo4 = ValueHolder::CreateSingleDataOutput("Foo4", {foo2, foo3[0]}); - ValueHolder::PushGraphFrame(foo4, "Foo"); - auto foo5 = ValueHolder::CreateSingleDataOutput("Foo5", {foo2}); - - ValueHolder::PopGraphFrame({data0, foo3[1], foo5}, {}); - ValueHolder::PopGraphFrame({}, {}); - - auto frame = ValueHolder::PopGraphFrame(); - ASSERT_NE(frame, nullptr); - auto graph = frame->GetExecuteGraph().get(); - ASSERT_NE(graph, nullptr); - - auto foo3_nodes = ge::ExecuteGraphUtils::FindNodesByTypeFromAllNodes(graph, "Foo3"); - ASSERT_EQ(foo3_nodes.size(), 1); - - auto foo4_graph = FindFirstSubgraphForNodeType(graph, "Foo4"); - ASSERT_NE(foo4_graph, nullptr); - ASSERT_EQ(ExeGraphSummaryChecker(foo4_graph).StrictAllNodeTypes({{"InnerData", 3}, {"Foo5", 1}, {"InnerNetOutput", 1}}), - "success"); - auto netoutput = ge::ExecuteGraphUtils::FindFirstNodeMatchType(foo4_graph, "InnerNetOutput"); - ASSERT_EQ(FastNodeTopoChecker(netoutput).StrictConnectFrom({{"InnerData"}, {"InnerData"}, {"Foo5"}}), "success"); - ConnectFromOuter(netoutput, 0, FindData(graph, 0), 0); - ConnectFromOuter(netoutput, 1, foo3_nodes[0], 1); -} - -/* - * +--------------------------------------------------------+ - * |Foo-Node | - * | +---------------------+ +---------------------+ | - * | |Foo-Subgraph1 | |Foo-Subgraph2 | | - * | | NetOutput | | NetOutput | | - * | | | | | ERROR | | | - * | | Bar1 | | Bar1 ---> Bar2 | | - * | | / \ | | / \ | | - * | +--0------1-----------+ +---------0------1----+ | - * +------0------1--------2---------------------------------+ - * | | | - * data0 data1 data2 - */ -TEST_F(FastValueHolderUt, PopFrame_Failed_OutpusUseGraphNotAncestor) { - auto data0 = ValueHolder::CreateFeed(0); - auto data1 = ValueHolder::CreateFeed(1); - auto data2 = ValueHolder::CreateFeed(2); - - auto foo1 = ValueHolder::CreateSingleDataOutput("Foo", {data0, data1, data2}); - ValueHolder::PushGraphFrame(foo1, "Foo1"); - auto bar1 = ValueHolder::CreateDataOutput("Bar1", {data0, data1}, 3); - ValueHolder::PopGraphFrame({bar1[0]}, {}); - - ValueHolder::PushGraphFrame(foo1, "Foo2"); - auto bar2 = ValueHolder::CreateSingleDataOutput("Bar2", {bar1[0], data0, data1}); - ASSERT_EQ(bar2, nullptr); - bar2 = ValueHolder::CreateSingleDataOutput("Bar2", {bar1[1], data0, data1}); - ASSERT_EQ(bar2, nullptr); - ValueHolder::PopGraphFrame(); -} - -TEST_F(FastValueHolderUt, CreateConstDataOk) { - auto const_data1 = ValueHolder::CreateConstData(0); - auto data1 = ValueHolder::CreateFeed(0); - ASSERT_NE(const_data1, nullptr); - ASSERT_NE(data1, nullptr); - - std::vector inputs = {data1, const_data1}; - auto holder = ValueHolder::CreateVoid("TestNode", inputs); - - ASSERT_NE(holder, nullptr); - - ASSERT_NE(const_data1->GetExecuteGraph(), nullptr); - ASSERT_NE(const_data1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - ASSERT_NE(data1->GetExecuteGraph(), nullptr); - ASSERT_NE(data1->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - ASSERT_NE(holder->GetFastNode(), nullptr); - ASSERT_NE(holder->GetExecuteGraph(), nullptr); - ASSERT_NE(holder->GetFastNode()->GetExtendInfo()->GetOwnerGraphBarePtr(), nullptr); - - // check graph is ok - auto graph = holder->GetExecuteGraph(); - ASSERT_EQ(graph->GetAllNodes().size(), 3); - CheckGraphGenerally(graph); - - auto const1_g = ge::ExecuteGraphUtils::FindFirstNodeMatchType(graph, "ConstData"); - auto data1_g = ge::ExecuteGraphUtils::FindFirstNodeMatchType(graph, "Data"); - auto node_g = ge::ExecuteGraphUtils::FindFirstNodeMatchType(graph, "TestNode"); - ASSERT_NE(const1_g, nullptr); - ASSERT_NE(data1_g, nullptr); - ASSERT_NE(node_g, nullptr); - - EXPECT_EQ(node_g->GetInDataEdgeByIndex(0)->src, data1_g); - EXPECT_EQ(node_g->GetInDataEdgeByIndex(1)->src, const1_g); -} - -TEST_F(FastValueHolderUt, ValueHolderUtils_TEST_CheckNode_OK) { - auto data1 = ValueHolder::CreateFeed(0); - ASSERT_NE(data1, nullptr); - ge::Format f1 = ge::FORMAT_NC1HWC0; - auto const1 = ValueHolder::CreateConst(reinterpret_cast(&f1), sizeof(f1)); - ASSERT_NE(const1, nullptr); - auto foo1 = ValueHolder::CreateSingleDataOutput("Foo1", {data1, const1}); - auto foo2 = ValueHolder::CreateDataOutput("Foo2", {foo1, const1, data1}, 2); - ASSERT_EQ(foo2.size(), 2); - - ASSERT_FALSE(ValueHolderUtils::IsNodeValid(nullptr)); - ASSERT_TRUE(ValueHolderUtils::IsNodeValid(foo1)); - ASSERT_TRUE(ValueHolderUtils::IsNodeValid(foo2[0])); - ASSERT_TRUE(ValueHolderUtils::IsNodeValid(foo2[1])); - - ASSERT_TRUE(ValueHolderUtils::IsNodeEqual(nullptr, nullptr)); - ASSERT_FALSE(ValueHolderUtils::IsNodeEqual(foo1, nullptr)); - ASSERT_FALSE(ValueHolderUtils::IsNodeEqual(nullptr, foo1)); - ASSERT_TRUE(ValueHolderUtils::IsNodeEqual(foo2[0], foo2[1])); -} - -TEST_F(FastValueHolderUt, ValueHolderUtils_Test_BaseInfo_OK) { - ValueHolderPtr holder = nullptr; - EXPECT_STREQ(ValueHolderUtils::GetNodeName(holder).c_str(), ""); - EXPECT_STREQ(ValueHolderUtils::GetNodeNameBarePtr(holder), ""); - EXPECT_STREQ(ValueHolderUtils::GetNodeType(holder).c_str(), ""); - EXPECT_STREQ(ValueHolderUtils::GetNodeTypeBarePtr(holder), ""); - auto desc = ValueHolderUtils::GetNodeOpDesc(holder); - EXPECT_EQ(desc, nullptr); - EXPECT_EQ(desc.get(), ValueHolderUtils::GetNodeOpDescBarePtr(holder)); - - auto data1 = ValueHolder::CreateFeed(0); - ASSERT_NE(data1, nullptr); - EXPECT_STRNE(ValueHolderUtils::GetNodeName(data1).c_str(), ""); - EXPECT_STRNE(ValueHolderUtils::GetNodeNameBarePtr(data1), ""); - EXPECT_STREQ(ValueHolderUtils::GetNodeType(data1).c_str(), "Data"); - EXPECT_STREQ(ValueHolderUtils::GetNodeTypeBarePtr(data1), "Data"); - desc = ValueHolderUtils::GetNodeOpDesc(data1); - EXPECT_NE(desc, nullptr); - EXPECT_EQ(desc.get(), ValueHolderUtils::GetNodeOpDescBarePtr(data1)); - EXPECT_EQ(desc->GetAllInputsSize(), 0); - EXPECT_EQ(desc->GetAllOutputsDescSize(), 1); - - ge::Format f1 = ge::FORMAT_NC1HWC0; - auto const1 = ValueHolder::CreateConst(reinterpret_cast(&f1), sizeof(f1)); - ASSERT_NE(const1, nullptr); - EXPECT_STRNE(ValueHolderUtils::GetNodeName(const1).c_str(), ""); - EXPECT_STRNE(ValueHolderUtils::GetNodeNameBarePtr(const1), ""); - EXPECT_STREQ(ValueHolderUtils::GetNodeType(const1).c_str(), "Const"); - EXPECT_STREQ(ValueHolderUtils::GetNodeTypeBarePtr(const1), "Const"); - desc = ValueHolderUtils::GetNodeOpDesc(const1); - EXPECT_NE(desc, nullptr); - EXPECT_EQ(desc.get(), ValueHolderUtils::GetNodeOpDescBarePtr(const1)); - EXPECT_EQ(desc->GetAllInputsSize(), 0); - EXPECT_EQ(desc->GetAllOutputsDescSize(), 1); - - auto foo1 = ValueHolder::CreateSingleDataOutput("Foo1", {data1, const1}); - ASSERT_NE(foo1, nullptr); - EXPECT_STRNE(ValueHolderUtils::GetNodeName(foo1).c_str(), ""); - EXPECT_STRNE(ValueHolderUtils::GetNodeNameBarePtr(foo1), ""); - EXPECT_STREQ(ValueHolderUtils::GetNodeType(foo1).c_str(), "Foo1"); - EXPECT_STREQ(ValueHolderUtils::GetNodeTypeBarePtr(foo1), "Foo1"); - desc = ValueHolderUtils::GetNodeOpDesc(foo1); - EXPECT_NE(desc, nullptr); - EXPECT_EQ(desc.get(), ValueHolderUtils::GetNodeOpDescBarePtr(foo1)); - EXPECT_EQ(desc->GetAllInputsSize(), 2); - EXPECT_EQ(desc->GetAllOutputsDescSize(), 1); - - auto foo2 = ValueHolder::CreateDataOutput("Foo2", {foo1, const1, data1}, 2); - ASSERT_EQ(foo2.size(), 2); - ASSERT_NE(foo2[0], nullptr); - EXPECT_STRNE(ValueHolderUtils::GetNodeName(foo2[0]).c_str(), ""); - EXPECT_STRNE(ValueHolderUtils::GetNodeNameBarePtr(foo2[0]), ""); - EXPECT_STREQ(ValueHolderUtils::GetNodeType(foo2[0]).c_str(), "Foo2"); - EXPECT_STREQ(ValueHolderUtils::GetNodeTypeBarePtr(foo2[0]), "Foo2"); - desc = ValueHolderUtils::GetNodeOpDesc(foo2[0]); - EXPECT_NE(desc, nullptr); - EXPECT_EQ(desc.get(), ValueHolderUtils::GetNodeOpDescBarePtr(foo2[0])); - EXPECT_EQ(desc->GetAllInputsSize(), 3); - EXPECT_EQ(desc->GetAllOutputsDescSize(), 2); - - ASSERT_NE(foo2[1], nullptr); - EXPECT_STRNE(ValueHolderUtils::GetNodeName(foo2[1]).c_str(), ""); - EXPECT_STRNE(ValueHolderUtils::GetNodeNameBarePtr(foo2[1]), ""); - EXPECT_STREQ(ValueHolderUtils::GetNodeType(foo2[1]).c_str(), "Foo2"); - EXPECT_STREQ(ValueHolderUtils::GetNodeTypeBarePtr(foo2[1]), "Foo2"); - desc = ValueHolderUtils::GetNodeOpDesc(foo2[1]); - EXPECT_NE(desc, nullptr); - EXPECT_EQ(desc.get(), ValueHolderUtils::GetNodeOpDescBarePtr(foo2[1])); - EXPECT_EQ(desc->GetAllInputsSize(), 3); - EXPECT_EQ(desc->GetAllOutputsDescSize(), 2); -} - -/* - * hello - * / \ - * data0 data1 - */ -TEST_F(FastValueHolderUt, IsDirectlyControlled) { - auto op_desc = std::make_shared("node", "node"); - ge::GeTensorDesc tensor_desc; - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetFormat(ge::FORMAT_NC1HWC0); - tensor_desc.SetDataType(ge::DT_FLOAT16); - tensor_desc.SetOriginDataType(ge::DT_FLOAT); - tensor_desc.SetShape(ge::GeShape({8, 1, 224, 224, 16})); - tensor_desc.SetOriginShape(ge::GeShape({8, 3, 224, 224})); - op_desc->AddInputDesc("x1", tensor_desc); - op_desc->AppendIrInput("x1", ge::kIrInputRequired); - op_desc->AppendIrInput("x2", ge::kIrInputOptional); - - auto graph = std::make_shared("graph"); - auto node = graph->AddNode(op_desc); - - auto data0 = ValueHolder::CreateFeed(0); - auto data1 = ValueHolder::CreateFeed(1); - - ValueHolder::SetCurrentComputeNode(node); - auto hello = ValueHolder::CreateSingleDataOutput("hello", {data0, data1}); - - (void)bg::ValueHolder::AddDependency(hello, data0); - EXPECT_EQ(ValueHolderUtils::IsDirectlyControlled(hello, data0), true); - EXPECT_EQ(ValueHolderUtils::IsDirectlyControlled(data0, data1), false); -} - -TEST_F(FastValueHolderUt, ClearGraphFrameSucc) { - EXPECT_NE(ValueHolder::GetCurrentFrame(), nullptr); - ValueHolder::ClearGraphFrameResource(); - EXPECT_EQ(ValueHolder::GetCurrentFrame(), nullptr); - EXPECT_EQ(ValueHolder::PopGraphFrame(), nullptr); -} -} // namespace bg -} // namespace gert diff --git a/tests/ut/exe_graph/gert_tensor_data_unittest.cc b/tests/ut/exe_graph/gert_tensor_data_unittest.cc deleted file mode 100644 index 42c6a79c56d6626037b2c775f1c014879799db65..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/gert_tensor_data_unittest.cc +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "inc/exe_graph/runtime/gert_tensor_data.h" - -namespace gert { -class GertTensorDataUT : public testing::Test {}; - -TEST_F(GertTensorDataUT, Init_success) { - GertTensorData gert_tensor_data = GertTensorData(); - ASSERT_EQ(gert_tensor_data.GetSize(), 0U); - ASSERT_EQ(gert_tensor_data.GetPlacement(), kTensorPlacementEnd); - ASSERT_EQ(gert_tensor_data.GetStreamId(), -1); -} - -TEST_F(GertTensorDataUT, SetSize_SetPlacement_SetAddr_success) { - GertTensorData gert_tensor_data = GertTensorData(); - ASSERT_EQ(gert_tensor_data.GetSize(), 0U); - ASSERT_EQ(gert_tensor_data.GetPlacement(), kTensorPlacementEnd); - ASSERT_EQ(gert_tensor_data.GetStreamId(), -1); - - gert_tensor_data.SetSize(100U); - ASSERT_EQ(gert_tensor_data.GetSize(), 100U); - - gert_tensor_data.SetPlacement(kOnDeviceHbm); - ASSERT_EQ(gert_tensor_data.GetPlacement(), kOnDeviceHbm); - - gert_tensor_data.MutableTensorData().SetAddr((void *)1, nullptr); - ASSERT_EQ(gert_tensor_data.GetAddr(), (void *)1); -} - -TEST_F(GertTensorDataUT, TensorDataWithoutManager_After_Free_GetAddr_success) { - GertTensorData gert_tensor_data = GertTensorData(); - void *addr = (void*)1; - const_cast(&gert_tensor_data.GetTensorData())->SetAddr(addr, nullptr); - ASSERT_EQ(gert_tensor_data.FreeHoldAddr(), ge::GRAPH_SUCCESS); - ASSERT_EQ(gert_tensor_data.GetAddr(), addr); -} -} // namespace gert diff --git a/tests/ut/exe_graph/getcdim_unittest.cc b/tests/ut/exe_graph/getcdim_unittest.cc deleted file mode 100644 index 49d654ef231b00ee18e53262009c422ad7995df7..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/getcdim_unittest.cc +++ /dev/null @@ -1,383 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/lowering/bg_kernel_context_extend.h" -#include "exe_graph/lowering/getcdim.h" -#include -#include -#include "graph/compute_graph.h" -#include "exe_graph/runtime/context_extend.h" -#include "exe_graph/runtime/continuous_vector.h" -#include "exe_graph/runtime/tensor.h" -#include "graph/debug/ge_attr_define.h" -#include "expand_dimension.h" -#include "graph/utils/graph_utils.h" -#include "runtime/runtime_attrs_def.h" -#include "graph/compute_graph.h" -#include "exe_graph/lowering/bg_kernel_context_extend.h" -#include "exe_graph/runtime/tiling_context.h" - -#include "graph/compute_graph.h" -#include "exe_graph/lowering/bg_kernel_context_extend.h" -#include "exe_graph/runtime/tiling_context.h" -#include "lowering/kernel_run_context_builder.h" -namespace gert { - -struct CDimFakeKernelContextHolder { - template - T *GetContext() { - return reinterpret_cast(holder.context_); - } - ComputeNodeInfo *MutableComputeNodeInfo() { - return reinterpret_cast(holder.compute_node_extend_holder_.get()); - } - size_t kernel_input_num; - size_t kernel_output_num; - KernelContextHolder holder; -}; -CDimFakeKernelContextHolder CDimBuildKernelRunContext(size_t input_num, size_t output_num, int64_t reshape_type); - -class CDimKernelRunContextFaker { - public: - CDimKernelRunContextFaker() = default; - CDimKernelRunContextFaker &KernelIONum(size_t input_num, size_t output_num); - CDimKernelRunContextFaker &NodeIoNum(size_t input_num, size_t output_num); - CDimKernelRunContextFaker &IrInputNum(size_t input_num); - CDimKernelRunContextFaker &IrInstanceNum(std::vector instance_num); - CDimKernelRunContextFaker &NodeInputTd(int32_t index, ge::DataType dt, ge::Format origin_format, - ge::Format storage_format); - CDimKernelRunContextFaker &NodeOutputTd(int32_t index, ge::DataType dt, ge::Format origin_format, - ge::Format storage_format); - CDimKernelRunContextFaker &NodeAttrs(std::vector> keys_to_value); - CDimKernelRunContextFaker &Inputs(std::vector inputs); - CDimKernelRunContextFaker &Outputs(std::vector outputs); - - CDimFakeKernelContextHolder Build(int64_t reshape_type) const; - - private: - ge::OpDescPtr FakeOp(int64_t reshape_type) const; - - private: - size_t kernel_input_num_; - size_t kernel_output_num_; - size_t node_input_num_; - size_t node_output_num_; - std::vector ir_instance_num_; - std::vector node_input_tds_; - std::vector node_output_tds_; - std::vector inputs_; - std::vector outputs_; - std::vector> attrs_; -}; - - -class CDimTilingContextFaker { - public: - CDimTilingContextFaker &NodeIoNum(size_t input_num, size_t output_num); - CDimTilingContextFaker &IrInputNum(size_t input_num) { - base_faker_.IrInputNum(input_num); - return *this; - } - CDimTilingContextFaker &IrInstanceNum(std::vector instance_num) { - base_faker_.IrInstanceNum(std::move(instance_num)); - return *this; - } - CDimTilingContextFaker &NodeInputTd(int32_t index, ge::DataType dt, ge::Format origin_format, ge::Format storage_format) { - base_faker_.NodeInputTd(index, dt, origin_format, storage_format); - return *this; - } - CDimTilingContextFaker &NodeOutputTd(int32_t index, ge::DataType dt, ge::Format origin_format, - ge::Format storage_format) { - base_faker_.NodeOutputTd(index, dt, origin_format, storage_format); - return *this; - } - CDimTilingContextFaker &NodeAttrs(std::vector> keys_to_value) { - base_faker_.NodeAttrs(std::move(keys_to_value)); - return *this; - } - CDimTilingContextFaker &InputShapes(std::vector input_shapes); - CDimTilingContextFaker &OutputShapes(std::vector output_shapes); - CDimTilingContextFaker &CompileInfo(void *compile_info); - CDimTilingContextFaker &TilingData(void *tiling_data); - CDimTilingContextFaker &Workspace(ContinuousVector *workspace); - - CDimFakeKernelContextHolder Build(int64_t reshape_type) const; - - private: - void UpdateInputs(); - void UpdateOutputs(); - private: - enum InputsAppend { kInputsCompileInfo, kInputsTilingFunc, kInputsAppendEnd }; - - CDimKernelRunContextFaker base_faker_; - std::vector input_shapes_; - std::vector output_shapes_; - - void *compile_info_; -}; - - -CDimFakeKernelContextHolder CDimBuildKernelRunContext(size_t input_num, size_t output_num, int64_t reshape_type) { - return CDimKernelRunContextFaker().KernelIONum(input_num, output_num).Build(reshape_type); -} -CDimKernelRunContextFaker &CDimKernelRunContextFaker::KernelIONum(size_t input_num, size_t output_num) { - kernel_input_num_ = input_num; - kernel_output_num_ = output_num; - return *this; -} -CDimKernelRunContextFaker &CDimKernelRunContextFaker::NodeIoNum(size_t input_num, size_t output_num) { - node_input_num_ = input_num; - node_output_num_ = output_num; - node_input_tds_.resize(input_num); - node_output_tds_.resize(output_num); - return *this; -} -CDimKernelRunContextFaker &CDimKernelRunContextFaker::IrInputNum(size_t input_num) { - ir_instance_num_ = std::vector(input_num, 1); - return *this; -} -CDimKernelRunContextFaker &CDimKernelRunContextFaker::IrInstanceNum(std::vector instance_num) { - ir_instance_num_ = std::move(instance_num); - return *this; -} - -ge::OpDescPtr CDimKernelRunContextFaker::FakeOp(int64_t reshape_type) const { - auto op_desc = std::make_shared("node", "node"); - for (size_t i = 0; i < node_input_num_; ++i) { - auto prefix = "x_" + std::to_string(i) + "_"; - op_desc->AppendIrInput(prefix, ge::kIrInputRequired); - auto td = ge::GeTensorDesc(); - if (reshape_type != 0) { - (void) ge::AttrUtils::SetInt(td, ge::ATTR_NAME_RESHAPE_TYPE_MASK, reshape_type); - //td.SetExpandDimsType(ExpandDimsType(reshape_type)); - } - - td.SetOriginFormat(node_input_tds_[i].GetOriginFormat()); - td.SetFormat(node_input_tds_[i].GetStorageFormat()); - td.SetDataType(node_input_tds_[i].GetDataType()); - td.SetOriginDataType(node_input_tds_[i].GetDataType()); - op_desc->AddInputDesc(prefix, td); - } - for (size_t i = 0; i < node_output_num_; ++i) { - auto prefix = "y_" + std::to_string(i) + "_"; - auto td = ge::GeTensorDesc(); - if (reshape_type != 0) { - (void) ge::AttrUtils::SetInt(td, ge::ATTR_NAME_RESHAPE_TYPE_MASK, reshape_type); - //td.SetExpandDimsType(ExpandDimsType(reshape_type)); - } - - td.SetOriginFormat(node_output_tds_[i].GetOriginFormat()); - td.SetFormat(node_output_tds_[i].GetStorageFormat()); - td.SetDataType(node_output_tds_[i].GetDataType()); - td.SetOriginDataType(node_output_tds_[i].GetDataType()); - - op_desc->AddOutputDesc(prefix, td); - } - return op_desc; -} - -CDimFakeKernelContextHolder CDimKernelRunContextFaker::Build(int64_t reshape_type) const { - CDimFakeKernelContextHolder fake_holder; - fake_holder.kernel_input_num = kernel_input_num_; - fake_holder.kernel_output_num = kernel_output_num_; - KernelRunContextBuilder kernel_context_builder; - auto op_desc = FakeOp(reshape_type); - std::cout << "kernel_input_num = " << kernel_input_num_ << "kernel_output_num = " << kernel_output_num_ - << " input.size = " << inputs_.size() << " outputs.size = " << outputs_.size() << std::endl; - if (inputs_.size() != kernel_input_num_ || outputs_.size() != kernel_output_num_) { - std::vector inputs(kernel_input_num_, nullptr); - std::vector outputs(kernel_output_num_, nullptr); - fake_holder.holder = kernel_context_builder.Inputs(inputs).Outputs(outputs).Build(op_desc); - return fake_holder; - } - fake_holder.holder = kernel_context_builder.Inputs(inputs_).Outputs(outputs_).Build(op_desc); - return fake_holder; -} -CDimKernelRunContextFaker &CDimKernelRunContextFaker::NodeInputTd(int32_t index, ge::DataType dt, ge::Format origin_format, - ge::Format storage_format) { - node_input_tds_[index].SetDataType(dt); - node_input_tds_[index].SetOriginFormat(origin_format); - node_input_tds_[index].SetStorageFormat(storage_format); - return *this; -} -CDimKernelRunContextFaker &CDimKernelRunContextFaker::NodeOutputTd(int32_t index, ge::DataType dt, ge::Format origin_format, - ge::Format storage_format) { - node_output_tds_[index].SetDataType(dt); - node_output_tds_[index].SetOriginFormat(origin_format); - node_output_tds_[index].SetStorageFormat(storage_format); - return *this; -} -CDimKernelRunContextFaker &CDimKernelRunContextFaker::Inputs(std::vector inputs) { - inputs_ = std::move(inputs); - return *this; -} -CDimKernelRunContextFaker &CDimKernelRunContextFaker::Outputs(std::vector outputs) { - outputs_ = std::move(outputs); - return *this; -} -CDimKernelRunContextFaker & -CDimKernelRunContextFaker::NodeAttrs(std::vector> keys_to_value) { - attrs_ = std::move(keys_to_value); - return *this; -} - -CDimTilingContextFaker &CDimTilingContextFaker::NodeIoNum(size_t input_num, size_t output_num) { - base_faker_.KernelIONum(input_num, output_num); - base_faker_.NodeIoNum(input_num, output_num); - return *this; -} -CDimTilingContextFaker &CDimTilingContextFaker::InputShapes(std::vector input_shapes) { - input_shapes_ = std::move(input_shapes); - UpdateInputs(); - return *this; -} -CDimTilingContextFaker &CDimTilingContextFaker::OutputShapes(std::vector output_shapes) { - output_shapes_ = std::move(output_shapes); - UpdateOutputs(); - return *this; -} -CDimTilingContextFaker &CDimTilingContextFaker::CompileInfo(void *compile_info) { - compile_info_ = compile_info; - UpdateInputs(); - UpdateOutputs(); - return *this; -} -CDimTilingContextFaker &CDimTilingContextFaker::TilingData(void *tiling_data) { - return *this; -} -CDimTilingContextFaker &CDimTilingContextFaker::Workspace(ContinuousVector *workspace) { - return *this; -} -CDimFakeKernelContextHolder CDimTilingContextFaker::Build(int64_t reshape_type) const { - return base_faker_.Build(reshape_type); -} -void CDimTilingContextFaker::UpdateInputs() { - std::vector inputs; - for (const auto input_shape : input_shapes_) { - inputs.push_back(input_shape); - } - inputs.push_back(nullptr); // kInputsTilingFunc - base_faker_.Inputs(std::move(inputs)); -} - -void CDimTilingContextFaker::UpdateOutputs() { - std::vector outputs; - for (const auto output_shape : output_shapes_) { - outputs.push_back(output_shape); - } - base_faker_.Outputs(std::move(outputs)); -} - - -namespace { -struct CDimTestTilingData { - int64_t a; -}; -struct CDimTestCompileInfo { - int64_t a; - int64_t b; - std::vector c; -}; -} -class GetCDimTestUT : public testing::Test {}; - -// 测试构造kernel context的时候从tensor desc上获取ATTR_NAME_RESHAPE_TYPE_MASK并设置到compile time tensor desc 上 -// 同时测试调用Expand是否能够得到正确的扩维shape -TEST_F(GetCDimTestUT, BuildRequiredInputWithExpandDimsType01) { - gert::StorageShape in_shape = {{5,2,3,4}, {5, 1, 1, 1, 1}}; - gert::StorageShape out_shape = {{5,2,3,4}, {5, 1,1, 1, 1}}; - vector origin_shape = {5,2,3,4}; - int64_t int_reshape_type = 0; - // tiling data - CDimTestCompileInfo compile_info_holder = {10, 200, {10, 20, 30}}; - auto param = gert::TilingData::CreateCap(2048); - auto holder = gert::CDimTilingContextFaker() - .NodeIoNum(2, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .InputShapes({&in_shape}) - .OutputShapes({&out_shape}) - .Build(int_reshape_type); - - auto context = holder.GetContext(); - ASSERT_NE(context, nullptr); - int64_t cdim = GetInputCDim(context, 0); - EXPECT_EQ(2, cdim); - cdim = GetOutputCDim(context, 0); - EXPECT_EQ(2, cdim); -} - -TEST_F(GetCDimTestUT, BuildRequiredInputWithExpandDimsType02) { - gert::StorageShape in_shape = {{5,2,3,4, 1}, {5, 1, 1, 1, 1}}; - gert::StorageShape out_shape = {{5,2,3,4, 1}, {5, 1,1, 1, 1}}; - vector origin_shape = {5,2,3,4}; - int64_t int_reshape_type = transformer::ExpandDimension::GenerateReshapeType(ge::FORMAT_NC1HWC0, ge::FORMAT_NC1HWC0, - origin_shape.size(), "CHW"); - int_reshape_type = 0; - // tiling data - CDimTestCompileInfo compile_info_holder = {10, 200, {10, 20, 30}}; - auto param = gert::TilingData::CreateCap(2048); - auto holder = gert::CDimTilingContextFaker() - .NodeIoNum(2, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, ge::FORMAT_NC1HWC0) - .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_NC1HWC0, ge::FORMAT_NC1HWC0) - .InputShapes({&in_shape}) - .OutputShapes({&out_shape}) - .Build(int_reshape_type); - - auto context = holder.GetContext(); - ASSERT_NE(context, nullptr); - int64_t cdim = GetInputCDim(context, 0); - EXPECT_EQ(-1, cdim); - cdim = GetOutputCDim(context, 0); - EXPECT_EQ(-1, cdim); -} - -TEST_F(GetCDimTestUT, BuildRequiredInputWithExpandDimsType03) { - gert::StorageShape in_shape = {{5, 6, 7}, {5, 6, 7, 1}}; - gert::StorageShape out_shape = {{5, 6, 7}, {5, 6, 7, 1}}; - vector origin_shape = {5, 6, 7}; - int64_t int_reshape_type = transformer::ExpandDimension::GenerateReshapeType(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, - origin_shape.size(), "NCH"); - - // tiling data - CDimTestCompileInfo compile_info_holder = {10, 200, {10, 20, 30}}; - auto param = gert::TilingData::CreateCap(2048); - auto holder = gert::CDimTilingContextFaker() - .NodeIoNum(2, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .InputShapes({&in_shape}) - .OutputShapes({&out_shape}) - .Build(int_reshape_type); - - auto context = holder.GetContext(); - ASSERT_NE(context, nullptr); - int64_t cdim = GetInputCDim(context, 0); - EXPECT_EQ(6, cdim); - cdim = GetOutputCDim(context, 0); - EXPECT_EQ(6, cdim); -} - -TEST_F(GetCDimTestUT, GetAxisIndexByName01) { - int32_t c_ax_index = transformer::ExpandDimension::GetAxisIndexByName('C', ge::FORMAT_NCHW); - EXPECT_EQ(1, c_ax_index); -} - -TEST_F(GetCDimTestUT, GetReshapeAxicValueByName01) { - vector origin_shape = {5, 6, 7}; - vector storage_shape = {5, 6, 7, 1}; - ge::GeShape inshape{storage_shape}; - int64_t int_reshape_type = transformer::ExpandDimension::GenerateReshapeType(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, - origin_shape.size(), "NCH"); - int64_t c_ax_value = transformer::ExpandDimension::GetReshapeAxicValueByName(int_reshape_type, 'W', inshape, ge::FORMAT_NCHW); - EXPECT_EQ(1, c_ax_value); -} - -} // namespace gert diff --git a/tests/ut/exe_graph/hyper_status_unittest.cc b/tests/ut/exe_graph/hyper_status_unittest.cc deleted file mode 100644 index fdedc1f90117cf54f15cc7975449ca0c7a19b7cd..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/hyper_status_unittest.cc +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "common/hyper_status.h" -#include - -namespace gert { -class HyperStatusUnittest : public testing::Test {}; - -TEST_F(HyperStatusUnittest, CreateMessageNullPtr) { - va_list arg; - EXPECT_EQ(CreateMessage(nullptr, arg), nullptr); -} - -TEST_F(HyperStatusUnittest, CreateErrorStatusOk) { - auto status = HyperStatus::ErrorStatus("This is a error message %s, int %d", "Hello", 10); - ASSERT_FALSE(status.IsSuccess()); - EXPECT_EQ(strcmp(status.GetErrorMessage(), "This is a error message Hello, int 10"), 0); -} - -TEST_F(HyperStatusUnittest, CreateSuccessStatusOk) { - auto status = HyperStatus::Success(); - EXPECT_TRUE(status.IsSuccess()); -} - -TEST_F(HyperStatusUnittest, CopyAssignOk1) { - auto error = HyperStatus::ErrorStatus("This is a error message %s, int %d", "Hello", 10); - auto success = HyperStatus::Success(); - success = error; - EXPECT_FALSE(success.IsSuccess()); - EXPECT_FALSE(error.IsSuccess()); - EXPECT_EQ(strcmp(success.GetErrorMessage(), "This is a error message Hello, int 10"), 0); - EXPECT_NE(success.GetErrorMessage(), error.GetErrorMessage()); -} - -TEST_F(HyperStatusUnittest, CopyAssginOk2) { - auto error = HyperStatus::ErrorStatus("This is a error message %s, int %d", "Hello", 10); - auto success = HyperStatus::Success(); - error = success; - EXPECT_TRUE(success.IsSuccess()); - EXPECT_TRUE(error.IsSuccess()); -} - -TEST_F(HyperStatusUnittest, CopyConstructOk) { - auto error = HyperStatus::ErrorStatus("This is a error message %s, int %d", "Hello", 10); - auto success = HyperStatus::Success(); - HyperStatus e1(error); - HyperStatus s1(success); - EXPECT_FALSE(e1.IsSuccess()); - EXPECT_TRUE(s1.IsSuccess()); - EXPECT_EQ(strcmp(e1.GetErrorMessage(), "This is a error message Hello, int 10"), 0); -} - -TEST_F(HyperStatusUnittest, MoveConstructOk) { - auto error = HyperStatus::ErrorStatus("This is a error message %s, int %d", "Hello", 10); - auto success = HyperStatus::Success(); - HyperStatus e1(std::move(error)); - HyperStatus s1(std::move(success)); - EXPECT_FALSE(e1.IsSuccess()); - EXPECT_TRUE(s1.IsSuccess()); - EXPECT_EQ(strcmp(e1.GetErrorMessage(), "This is a error message Hello, int 10"), 0); -} - -TEST_F(HyperStatusUnittest, MoveAssginOk1) { - auto error = HyperStatus::ErrorStatus("This is a error message %s, int %d", "Hello", 10); - auto success = HyperStatus::Success(); - error = std::move(success); - EXPECT_TRUE(error.IsSuccess()); -} - -TEST_F(HyperStatusUnittest, MoveAssginOk2) { - auto error = HyperStatus::ErrorStatus("This is a error message %s, int %d", "Hello", 10); - auto success = HyperStatus::Success(); - success = std::move(error); - ASSERT_FALSE(success.IsSuccess()); - EXPECT_EQ(strcmp(success.GetErrorMessage(), "This is a error message Hello, int 10"), 0); -} -} diff --git a/tests/ut/exe_graph/infer_datatype_context_unittest.cc b/tests/ut/exe_graph/infer_datatype_context_unittest.cc deleted file mode 100644 index 026dc3303a4e9185d1a965f5b1033c6d4d5c81f6..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/infer_datatype_context_unittest.cc +++ /dev/null @@ -1,190 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/runtime/infer_datatype_context.h" -#include -#include "register/op_impl_registry_base.h" -#include "register/op_impl_registry.h" -#include "faker/kernel_run_context_faker.h" -#include "exe_graph/runtime/storage_shape.h" -namespace gert { -class InferDataTypeContextUT : public testing::Test {}; -TEST_F(InferDataTypeContextUT, GetInputDataTypeOk) { - ge::DataType in_datatype1 = ge::DT_INT8; - ge::DataType in_datatype2 = ge::DT_INT8; - ge::DataType out_datatype = ge::DT_FLOAT16; - auto context_holder = InferDataTypeContextFaker() - .IrInputNum(2) - .NodeIoNum(2, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z) - .InputDataTypes({&in_datatype1, &in_datatype2}) - .OutputDataTypes({&out_datatype}) - .Build(); - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - - EXPECT_EQ(context->GetInputDataType(0), in_datatype1); - EXPECT_EQ(context->GetInputDataType(1), in_datatype2); -} - -TEST_F(InferDataTypeContextUT, GetDynamicInputDataTypeOk) { - ge::DataType in_datatype1 = ge::DT_INT8; - ge::DataType in_datatype2 = ge::DT_INT4; - ge::DataType in_datatype3 = ge::DT_INT8; - ge::DataType in_datatype4 = ge::DT_INT4; - ge::DataType out_datatype = ge::DT_FLOAT16; - auto context_holder = InferDataTypeContextFaker() - .IrInstanceNum({1, 2, 0, 1}) - .NodeIoNum(4, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeInputTd(2, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeInputTd(3, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputDataTypes({&in_datatype1, &in_datatype2, &in_datatype3, &in_datatype4}) - .OutputDataTypes({&out_datatype}) - .Build(); - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - - EXPECT_EQ(context->GetOptionalInputDataType(0), in_datatype1); - EXPECT_EQ(context->GetDynamicInputDataType(1, 0), in_datatype2); - EXPECT_EQ(context->GetDynamicInputDataType(1, 1), in_datatype3); - - EXPECT_EQ(context->GetOptionalInputDataType(2), ge::DataType::DT_UNDEFINED); - - EXPECT_EQ(context->GetOptionalInputDataType(3), in_datatype4); -} - -TEST_F(InferDataTypeContextUT, GetOutDataTypeOk) { - ge::DataType in_datatype1 = ge::DT_INT4; - ge::DataType in_datatype2 = ge::DT_INT8; - ge::DataType out_datatype = ge::DT_FLOAT16; - auto context_holder = InferDataTypeContextFaker() - .IrInputNum(2) - .NodeIoNum(2, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z) - .InputDataTypes({&in_datatype1, &in_datatype2}) - .OutputDataTypes({&out_datatype}) - .Build(); - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - - EXPECT_EQ(context->GetOutputDataType(0), out_datatype); - - EXPECT_EQ(context->GetOutputDataType(1), ge::DataType::DT_UNDEFINED); -} - -TEST_F(InferDataTypeContextUT, SetOutputDataTypeOk) { - ge::DataType in_datatype1 = ge::DT_INT4; - ge::DataType in_datatype2 = ge::DT_INT8; - ge::DataType origin_out_datatype = ge::DT_FLOAT16; - auto context_holder = InferDataTypeContextFaker() - .IrInputNum(2) - .NodeIoNum(2, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z) - .InputDataTypes({&in_datatype1, &in_datatype2}) - .OutputDataTypes({&origin_out_datatype}) - .Build(); - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - - EXPECT_EQ(context->GetOutputDataType(0), origin_out_datatype); - EXPECT_EQ(context->SetOutputDataType(0, ge::DT_INT32), ge::GRAPH_SUCCESS); - EXPECT_EQ(context->GetOutputDataType(0), ge::DT_INT32); -} - -TEST_F(InferDataTypeContextUT, Retpeat_register_InferDataType_InferOutDataTypeByFirstInput_success) { - auto funcs = gert::OpImplRegistry::GetInstance().GetOpImpl("TestFoo1"); - ASSERT_EQ(funcs, nullptr); - - IMPL_OP(TestFoo1) - .InferOutDataTypeSameWithFirstInput(); - funcs = gert::OpImplRegistry::GetInstance().GetOpImpl("TestFoo1"); - ASSERT_NE(funcs, nullptr); - auto infer_func = funcs->infer_datatype; - EXPECT_NE(infer_func, nullptr); - - ge::DataType in_datatype1 = ge::DT_INT4; - ge::DataType in_datatype2 = ge::DT_INT8; - ge::DataType origin_out_datatype = ge::DT_FLOAT16; - auto context_holder = InferDataTypeContextFaker() - .IrInputNum(2) - .NodeIoNum(2, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z) - .InputDataTypes({&in_datatype1, &in_datatype2}) - .OutputDataTypes({&origin_out_datatype}) - .Build(); - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - - EXPECT_EQ(context->GetOutputDataType(0), origin_out_datatype); - EXPECT_EQ(infer_func(context), ge::GRAPH_SUCCESS); - EXPECT_EQ(context->GetOutputDataType(0), in_datatype1); -} - -TEST_F(InferDataTypeContextUT, Retpeat_register_InferDataType_InferOutDataTypeByFirstInput_failed) { - auto funcs = gert::OpImplRegistry::GetInstance().GetOpImpl("TestFoo"); - ASSERT_EQ(funcs, nullptr); - - IMPL_OP(TestFoo) - .InferOutDataTypeSameWithFirstInput(); - funcs = gert::OpImplRegistry::GetInstance().GetOpImpl("TestFoo"); - ASSERT_NE(funcs, nullptr); - auto infer_func = funcs->infer_datatype; - EXPECT_NE(infer_func, nullptr); - - ge::DataType in_datatype1 = ge::DT_UNDEFINED; - ge::DataType in_datatype2 = ge::DT_INT8; - ge::DataType origin_out_datatype = ge::DT_FLOAT16; - auto context_holder = InferDataTypeContextFaker() - .IrInputNum(2) - .NodeIoNum(2, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z) - .InputDataTypes({&in_datatype1, &in_datatype2}) - .OutputDataTypes({&origin_out_datatype}) - .Build(); - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - - EXPECT_EQ(context->GetOutputDataType(0), origin_out_datatype); - EXPECT_NE(infer_func(context), ge::GRAPH_SUCCESS); -} -TEST_F(InferDataTypeContextUT, GetRequiredInputDataTypeOk) { - ge::DataType in_datatype1 = ge::DT_INT8; - ge::DataType in_datatype2 = ge::DT_INT4; - ge::DataType in_datatype3 = ge::DT_INT8; - ge::DataType in_datatype4 = ge::DT_INT4; - ge::DataType out_datatype = ge::DT_FLOAT16; - auto context_holder = InferDataTypeContextFaker() - .IrInstanceNum({1, 2, 0, 1}) - .NodeIoNum(4, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeInputTd(2, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeInputTd(3, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputDataTypes({&in_datatype1, &in_datatype2, &in_datatype3, &in_datatype4}) - .OutputDataTypes({&out_datatype}) - .Build(); - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - - EXPECT_EQ(context->GetRequiredInputDataType(0), in_datatype1); - EXPECT_EQ(context->GetDynamicInputDataType(1, 0), in_datatype2); - EXPECT_EQ(context->GetDynamicInputDataType(1, 1), in_datatype3); - - EXPECT_EQ(context->GetOptionalInputDataType(2), ge::DataType::DT_UNDEFINED); - - EXPECT_EQ(context->GetRequiredInputDataType(3), in_datatype4); -} -} // namespace gert diff --git a/tests/ut/exe_graph/infer_shape_context_unittest.cc b/tests/ut/exe_graph/infer_shape_context_unittest.cc deleted file mode 100644 index a6c4d0bc8d56c51ddaa572535f9f811e8e3cc788..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/infer_shape_context_unittest.cc +++ /dev/null @@ -1,211 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/runtime/infer_shape_context.h" -#include "graph/ge_error_codes.h" -#include -#include "faker/kernel_run_context_faker.h" -#include "exe_graph/runtime/storage_shape.h" -#include "register/kernel_registry.h" -namespace gert { -class InferShapeContextUT : public testing::Test {}; -TEST_F(InferShapeContextUT, GetInputShapeOk) { - gert::StorageShape in_shape1 = {{8, 3, 224, 224}, {8, 1, 224, 224, 16}}; - gert::StorageShape in_shape2 = {{2, 2, 3, 8}, {8, 1, 2, 2, 16}}; - gert::StorageShape out_shape = {{8, 3, 224, 224}, {8, 1, 224, 224, 16}}; - auto context_holder = InferShapeContextFaker() - .IrInputNum(2) - .NodeIoNum(2, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z) - .InputShapes({&in_shape1, &in_shape2}) - .OutputShapes({&out_shape}) - .Build(); - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - - ASSERT_NE(context->GetInputShape(0), nullptr); - EXPECT_EQ(*context->GetInputShape(0), in_shape1.GetOriginShape()); - - ASSERT_NE(context->GetInputShape(1), nullptr); - EXPECT_EQ(*context->GetInputShape(1), in_shape2.GetOriginShape()); - - EXPECT_EQ(context->GetInputShape(2), nullptr); -} - -TEST_F(InferShapeContextUT, GetDynamicInputShapeOk) { - gert::StorageShape in_shape1 = {{8, 3, 224, 224}, {8, 1, 224, 224, 16}}; - gert::StorageShape in_shape2 = {{2, 2, 3, 8}, {2, 2, 3, 8}}; - gert::StorageShape in_shape3 = {{3, 2, 3, 8}, {3, 2, 3, 8}}; - gert::StorageShape in_shape4 = {{4, 2, 3, 8}, {4, 2, 3, 8}}; - gert::StorageShape out_shape = {{8, 3, 224, 224}, {8, 1, 224, 224, 16}}; - auto context_holder = InferShapeContextFaker() - .IrInstanceNum({1, 2, 0, 1}) - .NodeIoNum(4, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeInputTd(2, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeInputTd(3, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputShapes({&in_shape1, &in_shape2, &in_shape3, &in_shape4}) - .OutputShapes({&out_shape}) - .Build(); - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - - ASSERT_NE(context->GetOptionalInputShape(0), nullptr); - EXPECT_EQ(*context->GetOptionalInputShape(0), in_shape1.GetOriginShape()); - - ASSERT_NE(context->GetDynamicInputShape(1, 0), nullptr); - EXPECT_EQ(*context->GetDynamicInputShape(1, 0), in_shape2.GetOriginShape()); - - ASSERT_NE(context->GetDynamicInputShape(1, 1), nullptr); - EXPECT_EQ(*context->GetDynamicInputShape(1, 1), in_shape3.GetOriginShape()); - - EXPECT_EQ(context->GetOptionalInputShape(2), nullptr); - - ASSERT_NE(context->GetOptionalInputShape(3), nullptr); - EXPECT_EQ(*context->GetOptionalInputShape(3), in_shape4.GetOriginShape()); -} - -TEST_F(InferShapeContextUT, GetOutShapeOk) { - gert::StorageShape in_shape1 = {{8, 3, 224, 224}, {8, 1, 224, 224, 16}}; - gert::StorageShape in_shape2 = {{2, 2, 3, 8}, {8, 1, 2, 2, 16}}; - gert::StorageShape out_shape = {{8, 3, 224, 224}, {8, 1, 224, 224, 16}}; - auto context_holder = InferShapeContextFaker() - .IrInputNum(2) - .NodeIoNum(2, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z) - .InputShapes({&in_shape1, &in_shape2}) - .OutputShapes({&out_shape}) - .Build(); - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - - ASSERT_NE(context->GetOutputShape(0), nullptr); - EXPECT_EQ(*context->GetOutputShape(0), out_shape.GetOriginShape()); - - EXPECT_EQ(context->GetOutputShape(1), nullptr); -} - -TEST_F(InferShapeContextUT, GetOptionalInputTensorFailed_NotSetOptionalInput) { - gert::StorageShape in_shape1 = {{8, 3, 224, 224}, {8, 1, 224, 224, 16}}; - auto infer_shape_func_addr = reinterpret_cast(0x11); - - auto context_holder = KernelRunContextFaker() - .IrInputNum(2) - .IrInstanceNum({1, 0}) - .KernelIONum(2, 0) - .NodeIoNum(1, 0) - .Inputs({&in_shape1, infer_shape_func_addr}) - .Build(); - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - - ASSERT_NE(context->GetInputShape(0), nullptr); - EXPECT_EQ(*context->GetInputShape(0), in_shape1.GetOriginShape()); - - EXPECT_EQ(context->GetOptionalInputTensor(1), nullptr); - EXPECT_NE(context->GetInputTensor(1), nullptr); -} - -TEST_F(InferShapeContextUT, GetOptionalInputTensorOK_SetOptionalInput) { -gert::StorageShape in_shape1 = {{8, 3, 224, 224}, {8, 1, 224, 224, 16}}; -gert::Tensor in_tensor_2 = {{{1, 16, 256}, {1, 16, 256}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *) 0xabc}; -auto infer_shape_func_addr = reinterpret_cast(0x11); - -auto context_holder = KernelRunContextFaker() - .IrInputNum(2) - .IrInstanceNum({1, 1}) - .KernelIONum(3, 0) - .NodeIoNum(2, 0) - .Inputs({&in_shape1, &in_tensor_2, infer_shape_func_addr}) - .Build(); -auto context = context_holder.GetContext(); -ASSERT_NE(context, nullptr); - -ASSERT_NE(context->GetInputShape(0), nullptr); -EXPECT_EQ(*context->GetInputShape(0), in_shape1.GetOriginShape()); - -EXPECT_NE(context->GetOptionalInputTensor(1), nullptr); -EXPECT_NE(context->GetInputTensor(2), nullptr); -} - - -TEST_F(InferShapeContextUT, GetRequiredInputTensorOk) { - gert::Tensor in_tensor_1 = {{{8, 3, 224, 224}, {8, 1, 224, 224, 16}}, // shape - {ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x0}; - gert::Tensor in_tensor_2 = {{{2, 2, 3, 8}, {2, 2, 3, 8}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x0}; - gert::Tensor in_tensor_3 = {{{3, 2, 3, 8}, {3, 2, 3, 8}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x0}; - gert::Tensor in_tensor_4 = {{{4, 2, 3, 8}, {4, 2, 3, 8}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x12345}; - gert::Tensor out_tensor = {{{8, 3, 224, 224}, {8, 1, 224, 224, 16}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x0}; - - auto context_holder = InferShapeContextFaker() - .IrInstanceNum({1, 2, 0, 1}) - .NodeIoNum(4, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeInputTd(2, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeInputTd(3, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputShapes({&in_tensor_1, &in_tensor_2, &in_tensor_3, &in_tensor_4}) - .OutputShapes({&out_tensor}) - .Build(); - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - - ASSERT_NE(context->GetRequiredInputShape(0), nullptr); - EXPECT_EQ(*context->GetRequiredInputShape(0), in_tensor_1.GetOriginShape()); - ASSERT_NE(context->GetRequiredInputTensor(0), nullptr); - EXPECT_EQ(context->GetRequiredInputTensor(0)->GetOriginShape(), in_tensor_1.GetOriginShape()); - EXPECT_EQ(context->GetRequiredInputTensor(0)->GetAddr(), nullptr); - - ASSERT_NE(context->GetDynamicInputShape(1, 0), nullptr); - EXPECT_EQ(*context->GetDynamicInputShape(1, 0), in_tensor_2.GetOriginShape()); - ASSERT_NE(context->GetDynamicInputTensor(1, 0), nullptr); - EXPECT_EQ(context->GetDynamicInputTensor(1, 0)->GetOriginShape(), in_tensor_2.GetOriginShape()); - EXPECT_EQ(context->GetDynamicInputTensor(1, 0)->GetAddr(), nullptr); - - ASSERT_NE(context->GetDynamicInputShape(1, 1), nullptr); - EXPECT_EQ(*context->GetDynamicInputShape(1, 1), in_tensor_3.GetOriginShape()); - ASSERT_NE(context->GetDynamicInputTensor(1, 1), nullptr); - EXPECT_EQ(context->GetDynamicInputTensor(1, 1)->GetOriginShape(), in_tensor_3.GetOriginShape()); - EXPECT_EQ(context->GetDynamicInputTensor(1, 1)->GetAddr(), nullptr); - - EXPECT_EQ(context->GetOptionalInputShape(2), nullptr); - - ASSERT_NE(context->GetRequiredInputShape(3), nullptr); - EXPECT_EQ(*context->GetRequiredInputShape(3), in_tensor_4.GetOriginShape()); - ASSERT_NE(context->GetRequiredInputTensor(3), nullptr); - EXPECT_EQ(context->GetRequiredInputTensor(3)->GetOriginShape(), in_tensor_4.GetOriginShape()); - EXPECT_EQ(context->GetRequiredInputTensor(3)->GetAddr(), in_tensor_4.GetAddr()); -} -} // namespace gert diff --git a/tests/ut/exe_graph/infer_shape_range_context_unittest.cc b/tests/ut/exe_graph/infer_shape_range_context_unittest.cc deleted file mode 100644 index 7cd6df9e2536924a90ed8eebaaf2c6ea5af4e5a3..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/infer_shape_range_context_unittest.cc +++ /dev/null @@ -1,226 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/runtime/infer_shape_range_context.h" -#include -#include "faker/kernel_run_context_faker.h" -#include "exe_graph/runtime/storage_shape.h" -namespace gert { -class InferShapeRangeContextUT : public testing::Test {}; -TEST_F(InferShapeRangeContextUT, GetInputShapeRangeOk) { - Shape same_ele{8, 3, 224, 224}; - gert::Range in_shape_range1(&same_ele); - Shape min1{2, 2, 3, 8}; - Shape max1{2, -1, 3, 8}; - gert::Range in_shape_range2(&min1, &max1); - Shape out_shape1{8, 3, 224, 224}; - Shape out_shape2{8, 224, 224, 224}; - gert::Range out_shape_range(&out_shape1, &out_shape2); - auto context_holder = InferShapeRangeContextFaker() - .IrInputNum(2) - .NodeIoNum(2, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z) - .InputShapeRanges({&in_shape_range1, &in_shape_range2}) - .OutputShapeRanges({&out_shape_range}) - .Build(); - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - - ASSERT_NE(context->GetInputShapeRange(0), nullptr); - EXPECT_EQ(*context->GetInputShapeRange(0), in_shape_range1); - - ASSERT_NE(context->GetInputShapeRange(1), nullptr); - EXPECT_EQ(*context->GetInputShapeRange(1), in_shape_range2); - - EXPECT_EQ(context->GetInputShapeRange(2), nullptr); -} - -TEST_F(InferShapeRangeContextUT, GetDynamicInputShapeRangeOk) { - Shape min1{8, 3, 224, 224}; - Shape max1{-1, 3, 224, 224}; - gert::Range in_shape_range1(&min1, &max1); - Shape min2{2, 2, 3, 8}; - Shape max2{2, -1, 3, 8}; - gert::Range in_shape_range2(&min2, &max2); - Shape min3{3, 2, 3, 8}; - Shape max3{3, 2, 9, 8}; - gert::Range in_shape_range3(&min3, &max3); - Shape min4{4, 2, 3, 8}; - Shape max4{4, 2, 3, 16}; - gert::Range in_shape_range4(&min4, &max4); - Shape min5{8, 3, 224, 224}; - Shape max5{-1, 3, 224, 224}; - gert::Range out_shape_range(&min5, &max5); - auto context_holder = InferShapeRangeContextFaker() - .IrInstanceNum({1, 2, 0, 1}) - .NodeIoNum(4, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeInputTd(2, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeInputTd(3, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputShapeRanges({&in_shape_range1, &in_shape_range2, &in_shape_range3, &in_shape_range4}) - .OutputShapeRanges({&out_shape_range}) - .Build(); - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - - ASSERT_NE(context->GetOptionalInputShapeRange(0), nullptr); - EXPECT_EQ(*context->GetOptionalInputShapeRange(0), in_shape_range1); - - ASSERT_NE(context->GetDynamicInputShapeRange(1, 0), nullptr); - EXPECT_EQ(*context->GetDynamicInputShapeRange(1, 0), in_shape_range2); - - ASSERT_NE(context->GetDynamicInputShapeRange(1, 1), nullptr); - EXPECT_EQ(*context->GetDynamicInputShapeRange(1, 1), in_shape_range3); - - EXPECT_EQ(context->GetOptionalInputShapeRange(2), nullptr); - - ASSERT_NE(context->GetOptionalInputShapeRange(3), nullptr); - EXPECT_EQ(*context->GetOptionalInputShapeRange(3), in_shape_range4); -} - -TEST_F(InferShapeRangeContextUT, GetOutShapeOk) { - Shape same_ele{8, 3, 224, 224}; - gert::Range in_shape_range1(&same_ele); - Shape min2{2, 2, 3, 8}; - Shape max2{2, -1, 3, 8}; - gert::Range in_shape_range2(&min2, &max2); - Shape out_min{8, 3, 224, 224}; - Shape out_max{8, 224, 224, 224}; - gert::Range out_shape_range(&out_min, &out_max); - auto context_holder = InferShapeRangeContextFaker() - .IrInputNum(2) - .NodeIoNum(2, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z) - .InputShapeRanges({&in_shape_range1, &in_shape_range2}) - .OutputShapeRanges({&out_shape_range}) - .Build(); - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - - ASSERT_NE(context->GetOutputShapeRange(0), nullptr); - EXPECT_EQ(*context->GetOutputShapeRange(0), out_shape_range); - - EXPECT_EQ(context->GetOutputShapeRange(1), nullptr); -} - -TEST_F(InferShapeRangeContextUT, GetRequiredInputShapeRangeOk) { - gert::Tensor min_tensor_1 = {{{8, 3, 224, 224}, {8, 3, 224, 224}}, // shape - {ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x0}; - gert::Tensor max_tensor_1 = {{{-1, 3, 224, 224}, {-1, 3, 224, 224}}, // shape - {ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x0}; - gert::Range in_shape_range1(&min_tensor_1, &max_tensor_1); - gert::Tensor min_tensor_2 = {{{2, 2, 3, 8}, {2, 2, 3, 8}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x0}; - gert::Tensor max_tensor_2 = {{{2, -1, 3, 8}, {2, -1, 3, 8}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x0}; - gert::Range in_shape_range2(&min_tensor_2, &max_tensor_2); - gert::Tensor min_tensor_3 = {{{3, 2, 3, 8}, {3, 2, 3, 8}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x0}; - gert::Tensor max_tensor_3 = {{{3, 2, 9, 8}, {3, 2, 9, 8}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x0}; - gert::Range in_shape_range3(&min_tensor_3, &max_tensor_3); - gert::Tensor min_tensor_4 = {{{4, 2, 3, 8}, {4, 2, 3, 8}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x12345}; - gert::Tensor max_tensor_4 = {{{4, 2, 3, 16}, {4, 2, 3, 16}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x12345}; - gert::Range in_shape_range4(&min_tensor_4, &max_tensor_4); - gert::Tensor min_tensor_5 = {{{8, 3, 224, 224}, {8, 3, 224, 224}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x0}; - gert::Tensor max_tensor_5 = {{{-1, 3, 224, 224}, {-1, 3, 224, 224}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x0}; - gert::Range out_shape_range(&min_tensor_5, &max_tensor_5); - auto context_holder = InferShapeRangeContextFaker() - .IrInstanceNum({1, 2, 0, 1}) - .NodeIoNum(4, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeInputTd(2, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeInputTd(3, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputShapeRanges({&in_shape_range1, &in_shape_range2, &in_shape_range3, &in_shape_range4}) - .OutputShapeRanges({&out_shape_range}) - .Build(); - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - - ASSERT_NE(context->GetRequiredInputShapeRange(0), nullptr); - EXPECT_EQ(*context->GetRequiredInputShapeRange(0)->GetMin(), in_shape_range1.GetMin()->GetStorageShape()); - EXPECT_EQ(*context->GetRequiredInputShapeRange(0)->GetMax(), in_shape_range1.GetMax()->GetStorageShape()); - ASSERT_NE(context->GetRequiredInputTensorRange(0), nullptr); - EXPECT_EQ(context->GetRequiredInputTensorRange(0)->GetMin()->GetOriginShape(), in_shape_range1.GetMin()->GetOriginShape()); - EXPECT_EQ(context->GetRequiredInputTensorRange(0)->GetMin()->GetAddr(), in_shape_range1.GetMin()->GetAddr()); - EXPECT_EQ(context->GetRequiredInputTensorRange(0)->GetMax()->GetOriginShape(), in_shape_range1.GetMax()->GetOriginShape()); - EXPECT_EQ(context->GetRequiredInputTensorRange(0)->GetMax()->GetAddr(), in_shape_range1.GetMax()->GetAddr()); - - ASSERT_NE(context->GetDynamicInputShapeRange(1, 0), nullptr); - EXPECT_EQ(*context->GetDynamicInputShapeRange(1, 0)->GetMin(), in_shape_range2.GetMin()->GetStorageShape()); - EXPECT_EQ(*context->GetDynamicInputShapeRange(1, 0)->GetMax(), in_shape_range2.GetMax()->GetStorageShape()); - ASSERT_NE(context->GetDynamicInputTensorRange(1, 0), nullptr); - EXPECT_EQ(context->GetDynamicInputTensorRange(1, 0)->GetMin()->GetOriginShape(), in_shape_range2.GetMin()->GetOriginShape()); - EXPECT_EQ(context->GetDynamicInputTensorRange(1, 0)->GetMin()->GetAddr(), in_shape_range2.GetMin()->GetAddr()); - EXPECT_EQ(context->GetDynamicInputTensorRange(1, 0)->GetMax()->GetOriginShape(), in_shape_range2.GetMax()->GetOriginShape()); - EXPECT_EQ(context->GetDynamicInputTensorRange(1, 0)->GetMax()->GetAddr(), in_shape_range2.GetMax()->GetAddr()); - - ASSERT_NE(context->GetDynamicInputShapeRange(1, 1), nullptr); - EXPECT_EQ(*context->GetDynamicInputShapeRange(1, 1)->GetMin(), in_shape_range3.GetMin()->GetStorageShape()); - EXPECT_EQ(*context->GetDynamicInputShapeRange(1, 1)->GetMax(), in_shape_range3.GetMax()->GetStorageShape()); - ASSERT_NE(context->GetDynamicInputTensorRange(1, 1), nullptr); - EXPECT_EQ(context->GetDynamicInputTensorRange(1, 1)->GetMin()->GetOriginShape(), in_shape_range3.GetMin()->GetOriginShape()); - EXPECT_EQ(context->GetDynamicInputTensorRange(1, 1)->GetMin()->GetAddr(), in_shape_range3.GetMin()->GetAddr()); - EXPECT_EQ(context->GetDynamicInputTensorRange(1, 1)->GetMax()->GetOriginShape(), in_shape_range3.GetMax()->GetOriginShape()); - EXPECT_EQ(context->GetDynamicInputTensorRange(1, 1)->GetMax()->GetAddr(), in_shape_range3.GetMax()->GetAddr()); - - EXPECT_EQ(context->GetOptionalInputShapeRange(2), nullptr); - - ASSERT_NE(context->GetRequiredInputShapeRange(3), nullptr); - EXPECT_EQ(*context->GetRequiredInputShapeRange(3)->GetMin(), in_shape_range4.GetMin()->GetStorageShape()); - EXPECT_EQ(*context->GetRequiredInputShapeRange(3)->GetMax(), in_shape_range4.GetMax()->GetStorageShape()); - ASSERT_NE(context->GetRequiredInputTensorRange(3), nullptr); - EXPECT_EQ(context->GetRequiredInputTensorRange(3)->GetMin()->GetOriginShape(), in_shape_range4.GetMin()->GetOriginShape()); - EXPECT_EQ(context->GetRequiredInputTensorRange(3)->GetMin()->GetAddr(), in_shape_range4.GetMin()->GetAddr()); - EXPECT_EQ(context->GetRequiredInputTensorRange(3)->GetMax()->GetOriginShape(), in_shape_range4.GetMax()->GetOriginShape()); - EXPECT_EQ(context->GetRequiredInputTensorRange(3)->GetMax()->GetAddr(), in_shape_range4.GetMax()->GetAddr()); - - ASSERT_NE(context->GetOutputShapeRange(0), nullptr); - EXPECT_EQ(*context->GetOutputShapeRange(0)->GetMin(), out_shape_range.GetMin()->GetStorageShape()); - EXPECT_EQ(*context->GetOutputShapeRange(0)->GetMax(), out_shape_range.GetMax()->GetStorageShape()); -} -} // namespace gert diff --git a/tests/ut/exe_graph/infer_symbol_shape_context_unittest.cc b/tests/ut/exe_graph/infer_symbol_shape_context_unittest.cc deleted file mode 100644 index 34967592d02bb61c9ee388667b185c5e83a9986e..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/infer_symbol_shape_context_unittest.cc +++ /dev/null @@ -1,262 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "exe_graph/runtime/infer_symbol_shape_context.h" -#include "faker/kernel_run_context_faker.h" -#include "exe_graph/runtime/symbolic_tensor.h" - -namespace gert { -class InferSymbolShapeContextUT : public testing::Test {}; -TEST_F(InferSymbolShapeContextUT, GetInputShapeOk) { - auto in_0 = ge::Symbol(8); - auto in_1 = ge::Symbol(3); - auto in_2 = ge::Symbol(224); - auto in_3 = ge::Symbol(224); - gert::SymbolTensor in_tensor1({in_0, in_1, in_2, in_3}, {}); - gert::SymbolTensor in_tensor2({in_0, in_1, in_2, in_3}, {}); - gert::SymbolShape out_shape({in_0, in_1, in_2, in_3}); - - auto context_holder = InferSymbolShapeContextFaker() - .IrInputNum(2) - .NodeIoNum(2, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z) - .Inputs({&in_tensor1, &in_tensor2}) - .Outputs({&out_shape}) - .Build(); - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - - ASSERT_NE(context->GetInputSymbolShape(0), nullptr); - EXPECT_EQ(*context->GetInputSymbolShape(0), in_tensor1.GetOriginSymbolShape()); - - ASSERT_NE(context->GetInputSymbolTensor(1), nullptr); - EXPECT_EQ(*context->GetInputSymbolShape(1), in_tensor2.GetOriginSymbolShape()); - - EXPECT_EQ(context->GetInputSymbolTensor(2), nullptr); - - ASSERT_NE(context->GetOutputSymbolShape(0), nullptr); - EXPECT_EQ(*context->GetOutputSymbolShape(0), out_shape); -} - -TEST_F(InferSymbolShapeContextUT, GetDynamicInputShapeOk) { - auto sym_8 = ge::Symbol(8); - auto sym_3 = ge::Symbol(3); - auto sym_2 = ge::Symbol(2); - auto sym_4 = ge::Symbol(4); - auto sym_224 = ge::Symbol(224); - auto sym_128 = ge::Symbol(128); - auto sym_16 = ge::Symbol(16); - gert::SymbolTensor in_tensor1({sym_8, sym_3, sym_224, sym_224}); - gert::SymbolTensor in_tensor2({sym_2, sym_2, sym_3, sym_8}); - gert::SymbolTensor in_tensor3({sym_3, sym_2, sym_3, sym_8}); - gert::SymbolTensor in_tensor4({sym_4, sym_2, sym_3, sym_8}); - gert::SymbolShape out_shape({sym_8, sym_3, sym_224, sym_224}); - - auto context_holder = InferSymbolShapeContextFaker() - .IrInputInstanceNum({1, 2, 0, 1}) - .NodeIoNum(4, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeInputTd(2, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeInputTd(3, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .Inputs({&in_tensor1, &in_tensor2, &in_tensor3, &in_tensor4}) - .Outputs({&out_shape}) - .Build(); - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - - ASSERT_NE(context->GetOptionalInputSymbolShape(0), nullptr); - EXPECT_EQ(*context->GetOptionalInputSymbolShape(0), in_tensor1.GetOriginSymbolShape()); - - ASSERT_NE(context->GetDynamicInputSymbolShape(1, 0), nullptr); - EXPECT_EQ(*context->GetDynamicInputSymbolShape(1, 0), in_tensor2.GetOriginSymbolShape()); - - ASSERT_NE(context->GetDynamicInputSymbolShape(1, 1), nullptr); - EXPECT_EQ(*context->GetDynamicInputSymbolShape(1, 1), in_tensor3.GetOriginSymbolShape()); - - EXPECT_EQ(context->GetOptionalInputSymbolShape(2), nullptr); - - ASSERT_NE(context->GetOptionalInputSymbolShape(3), nullptr); - EXPECT_EQ(*context->GetOptionalInputSymbolShape(3), in_tensor4.GetOriginSymbolShape()); -} - -TEST_F(InferSymbolShapeContextUT, GetOutShapeOk) { - auto sym_8 = ge::Symbol(8); - auto sym_3 = ge::Symbol(3); - auto sym_2 = ge::Symbol(2); - auto sym_4 = ge::Symbol(4); - auto sym_224 = ge::Symbol(224); - auto sym_16 = ge::Symbol(16); - gert::SymbolShape in_shape1({sym_8, sym_3, sym_224, sym_224}); - gert::SymbolShape in_shape2({sym_2, sym_2, sym_3, sym_8}); - gert::SymbolShape out_shape({sym_8, sym_3, sym_224, sym_224}); - - auto context_holder = InferSymbolShapeContextFaker() - .IrInputNum(2) - .NodeIoNum(2, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z) - .Inputs({&in_shape1, &in_shape2}) - .Outputs({&out_shape}) - .Build(); - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - - ASSERT_NE(context->GetOutputSymbolShape(0), nullptr); - EXPECT_EQ(*context->GetOutputSymbolShape(0), out_shape); - - EXPECT_EQ(context->GetOutputSymbolShape(1), nullptr); -} - -TEST_F(InferSymbolShapeContextUT, GetOptionalInputTensorFailed_NotSetOptionalInput) { - auto sym_8 = ge::Symbol(8); - auto sym_3 = ge::Symbol(3); - auto sym_224 = ge::Symbol(224); - gert::SymbolShape in_shape1({sym_8, sym_3, sym_224, sym_224}); - - auto infer_shape_func_addr = reinterpret_cast(0x11); - - auto context_holder = KernelRunContextFaker() - .IrInputNum(2) - .IrInstanceNum({1, 0}) - .KernelIONum(2, 0) - .NodeIoNum(1, 0) - .Inputs({&in_shape1, infer_shape_func_addr}) - .Build(); - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - - ASSERT_NE(context->GetInputSymbolTensor(0), nullptr); - EXPECT_EQ(*context->GetInputSymbolShape(0), in_shape1); - - EXPECT_EQ(context->GetOptionalInputSymbolTensor(1), nullptr); - EXPECT_NE(context->GetInputSymbolTensor(1), nullptr); -} - -TEST_F(InferSymbolShapeContextUT, GetInputSymbolTensorValueOK) { - auto sym_8 = ge::Symbol(8); - auto sym_3 = ge::Symbol(3); - auto sym_4 = ge::Symbol(4); - auto sym_224 = ge::Symbol(224); - auto sym_16 = ge::Symbol(16); - gert::SymbolTensor in_tensor1({sym_3, sym_8, sym_16, sym_16}); - gert::SymbolTensor in_tensor2({sym_4}, {sym_3, sym_16, sym_224, sym_16}); - - auto context_holder = InferSymbolShapeContextFaker() - .IrInputNum(2) - .IrInputInstanceNum({1, 1}) - .NodeIoNum(2, 0) - .Inputs({reinterpret_cast(&in_tensor1), reinterpret_cast(&in_tensor2)}) - .Build(); - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - - ASSERT_NE(context->GetInputSymbolShape(0), nullptr); - EXPECT_EQ(*context->GetInputSymbolShape(0), in_tensor1.GetOriginSymbolShape()); - - EXPECT_NE(context->GetInputSymbolTensor(1), nullptr); - EXPECT_EQ(context->GetInputSymbolTensor(1)->GetSymbolicValue()->size(), 4); - EXPECT_EQ(context->GetInputSymbolTensor(1)->GetSymbolicValue()->at(0), sym_3); - EXPECT_EQ(context->GetInputSymbolTensor(1)->GetSymbolicValue()->at(1), sym_16); - EXPECT_EQ(context->GetInputSymbolTensor(1)->GetSymbolicValue()->at(2), sym_224); - EXPECT_EQ(context->GetInputSymbolTensor(1)->GetSymbolicValue()->at(3), sym_16); -} - -TEST_F(InferSymbolShapeContextUT, GetOptionalInputTensorOK_SetOptionalInput) { - auto sym_8 = ge::Symbol(8); - auto sym_3 = ge::Symbol(3); - auto sym_4 = ge::Symbol(4); - auto sym_224 = ge::Symbol(224); - auto sym_16 = ge::Symbol(16); - gert::SymbolTensor in_tensor1({sym_8, sym_3, sym_224, sym_224}); - gert::SymbolTensor in_tensor2({sym_4}, {sym_3, sym_16, sym_224, sym_16}); - - auto context_holder = - InferSymbolShapeContextFaker() - .IrInputNum(2) - .IrInputInstanceNum({1, 1}) - .NodeIoNum(2, 0) - .Inputs({reinterpret_cast(&in_tensor1), - reinterpret_cast(&in_tensor2)}) // 接口内部会在末尾填充一个infershape的函数地址 - .Build(); - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - - ASSERT_NE(context->GetInputSymbolShape(0), nullptr); - EXPECT_EQ(*context->GetInputSymbolShape(0), in_tensor1.GetOriginSymbolShape()); - - EXPECT_NE(context->GetOptionalInputSymbolTensor(1), nullptr); - ASSERT_EQ(context->GetOptionalInputSymbolTensor(1)->GetSymbolicValue()->size(), 4); - EXPECT_EQ(context->GetOptionalInputSymbolTensor(1)->GetSymbolicValue()->at(0), sym_3); - EXPECT_EQ(context->GetOptionalInputSymbolTensor(1)->GetSymbolicValue()->at(1), sym_16); - EXPECT_EQ(context->GetOptionalInputSymbolTensor(1)->GetSymbolicValue()->at(2), sym_224); - EXPECT_EQ(context->GetOptionalInputSymbolTensor(1)->GetSymbolicValue()->at(3), sym_16); - - EXPECT_EQ(context->GetOptionalInputSymbolTensor(2), nullptr); -} - -TEST_F(InferSymbolShapeContextUT, GetRequiredInputTensorOk) { - auto sym_8 = ge::Symbol(8); - auto sym_3 = ge::Symbol(3); - auto sym_4 = ge::Symbol(4); - auto sym_224 = ge::Symbol(224); - auto sym_16 = ge::Symbol(16); - gert::SymbolTensor in_tensor1({sym_8, sym_3, sym_224, sym_224}); - gert::SymbolTensor in_tensor2({sym_3, sym_16, sym_224, sym_16}); - gert::SymbolTensor in_tensor3({sym_4}, {sym_8, sym_3, sym_224, sym_224}); - gert::SymbolTensor in_tensor4({sym_4}, {sym_3, sym_8, sym_16, sym_16}); - - gert::SymbolShape symbol_shape; - - - auto context_holder = - InferSymbolShapeContextFaker() - .IrInputInstanceNum({1, 2, 0, 1}) - .NodeIoNum(4, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeInputTd(2, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeInputTd(3, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .Inputs({reinterpret_cast(&in_tensor1), reinterpret_cast(&in_tensor2), - reinterpret_cast(&in_tensor3), reinterpret_cast(&in_tensor4)}) - .Outputs({&symbol_shape}) - .Build(); - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - - // 第一个输入是必选输入,获取到shape - ASSERT_NE(context->GetRequiredInputSymbolShape(0), nullptr); - EXPECT_EQ(*context->GetRequiredInputSymbolShape(0), in_tensor1.GetOriginSymbolShape()); - - // 第二个输入是动态输入,有两个实例,第一个实例获取到shape - ASSERT_NE(context->GetDynamicInputSymbolShape(1, 0), nullptr); - EXPECT_EQ(*context->GetDynamicInputSymbolShape(1, 0), in_tensor2.GetOriginSymbolShape()); - - // 第二个输入是动态输入,有两个实例,第一个实例获取到value - ASSERT_NE(context->GetDynamicInputSymbolTensor(1, 1), nullptr); - ASSERT_EQ(context->GetDynamicInputSymbolTensor(1, 1)->GetSymbolicValue()->size(), 4); - EXPECT_EQ(context->GetDynamicInputSymbolTensor(1, 1)->GetSymbolicValue()->at(0), sym_8); - EXPECT_EQ(context->GetDynamicInputSymbolTensor(1, 1)->GetSymbolicValue()->at(1), sym_3); - EXPECT_EQ(context->GetDynamicInputSymbolTensor(1, 1)->GetSymbolicValue()->at(2), sym_224); - EXPECT_EQ(context->GetDynamicInputSymbolTensor(1, 1)->GetSymbolicValue()->at(3), sym_224); - - // 第三个输入是可选输入,由于可选输入未被设置,获取到shape为nullptr - EXPECT_EQ(context->GetOptionalInputSymbolShape(2), nullptr); - - // 第四个输入是必选输入 - ASSERT_NE(context->GetRequiredInputSymbolTensor(3), nullptr); - ASSERT_EQ(context->GetRequiredInputSymbolTensor(3)->GetSymbolicValue()->size(), 4); - EXPECT_EQ(context->GetRequiredInputSymbolTensor(3)->GetSymbolicValue()->at(0), sym_3); - EXPECT_EQ(context->GetRequiredInputSymbolTensor(3)->GetSymbolicValue()->at(1), sym_8); - EXPECT_EQ(context->GetRequiredInputSymbolTensor(3)->GetSymbolicValue()->at(2), sym_16); - EXPECT_EQ(context->GetRequiredInputSymbolTensor(3)->GetSymbolicValue()->at(3), sym_16); -} -} // namespace gert \ No newline at end of file diff --git a/tests/ut/exe_graph/kernel_run_context_builder_unittest.cc b/tests/ut/exe_graph/kernel_run_context_builder_unittest.cc deleted file mode 100644 index 6bed0e2ad69b5daa06988c84b5b8416a47a7e26f..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/kernel_run_context_builder_unittest.cc +++ /dev/null @@ -1,129 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "lowering/kernel_run_context_builder.h" -#include "register/op_impl_registry.h" -#include "faker/space_registry_faker.h" -namespace gert { -class KernelRunContextBuilderUT : public testing::Test {}; - -TEST_F(KernelRunContextBuilderUT, SetBufferPoolOk) { - ge::OpDescPtr op_desc = std::make_shared("test0", "test1"); - KernelRunContextBuilder builder; - ge::graphStatus ret = ge::GRAPH_FAILED; - auto holder = builder.Build(op_desc, ret); - EXPECT_EQ(ret, ge::GRAPH_SUCCESS); - auto compute_node_info = reinterpret_cast( - holder.context_->GetComputeNodeExtend()); - EXPECT_EQ(std::string(compute_node_info->GetNodeName()), "test0"); - EXPECT_EQ(std::string(compute_node_info->GetNodeType()), "test1"); -} - -TEST_F(KernelRunContextBuilderUT, SetInputsOutputsOk) { - ge::OpDescPtr op_desc = std::make_shared("test0", "test1"); - KernelRunContextBuilder builder; - gert::StorageShape shape1({1,2,3,4}, {1,2,3,4}); - gert::StorageShape shape2({2,2,3,4}, {2,2,3,4}); - gert::StorageShape shape3({3,2,3,4}, {3,2,3,4}); - ge::graphStatus ret = ge::GRAPH_FAILED; - auto holder = builder.Inputs({{&shape1, nullptr}, {&shape2, nullptr}}).Outputs({&shape3}).Build(op_desc, ret); - EXPECT_EQ(ret, ge::GRAPH_SUCCESS); - auto context = holder.context_; - EXPECT_EQ(context->GetInputNum(), 2); - EXPECT_EQ(context->GetOutputNum(), 1); - EXPECT_TRUE(context->GetInputPointer(0) == &shape1); - EXPECT_TRUE(context->GetInputPointer(1) == &shape2); - EXPECT_TRUE(context->GetOutputPointer(0) == &shape3); -} - -TEST_F(KernelRunContextBuilderUT, SetInputsOutputsDataTypeOk) { -ge::OpDescPtr op_desc = std::make_shared("test0", "test1"); -KernelRunContextBuilder builder; - -ge::DataType in_datatype_1 = ge::DT_INT4; -ge::DataType in_datatype_2 = ge::DT_INT8; -ge::DataType out_datatype = ge::DT_INT8; -ge::graphStatus ret = ge::GRAPH_FAILED; -auto holder = builder - .Inputs({{reinterpret_cast(in_datatype_1), nullptr}, - {reinterpret_cast(in_datatype_2), nullptr}}) - .Outputs({reinterpret_cast(out_datatype)}) - .Build(op_desc, ret); -EXPECT_EQ(ret, ge::GRAPH_SUCCESS); -auto context = holder.context_; -EXPECT_EQ(context->GetInputNum(), 2); -EXPECT_EQ(context->GetOutputNum(), 1); -EXPECT_TRUE(*context->GetInputPointer(0) == in_datatype_1); -EXPECT_TRUE(*context->GetInputPointer(1) == in_datatype_2); -EXPECT_TRUE(*context->GetOutputPointer(0) == out_datatype); -} - -TEST_F(KernelRunContextBuilderUT, BuildContextHolderSuccessWhenOpLossAttrs) { - ge::OpDescPtr op_desc = std::make_shared("test0", "test1"); - op_desc->AppendIrAttrName("attr1"); - KernelRunContextBuilder builder; - ge::graphStatus ret = ge::GRAPH_FAILED; - auto holder = builder.Build(op_desc, ret); - ASSERT_EQ(ret, ge::GRAPH_SUCCESS); - ASSERT_NE(holder.context_holder_, nullptr); - EXPECT_NE(holder.compute_node_extend_holder_, nullptr); -} - -TEST_F(KernelRunContextBuilderUT, Get2AttrsFromCtx_OpBothHas1IrAttrAnd1PrivateAttr) { - IMPL_OP(TestOpWithPrivateAttr1).PrivateAttr("test_private_attr", static_cast(100)); - SpaceRegistryFaker::UpdateOpImplToDefaultSpaceRegistry(); - ge::OpDescPtr op_desc = std::make_shared("testop1", "TestOpWithPrivateAttr1"); - op_desc->AppendIrAttrName("test_ir_attr"); - ge::AttrUtils::SetInt(op_desc, "test_ir_attr", 10); - ge::AttrUtils::SetInt(op_desc, "test_private_attr", 1000); - KernelRunContextBuilder builder; - ge::graphStatus ret = ge::GRAPH_FAILED; - auto holder = builder.Build(op_desc, ret); - ASSERT_EQ(ret, ge::GRAPH_SUCCESS); - ASSERT_NE(holder.context_holder_, nullptr); - ASSERT_NE(holder.compute_node_extend_holder_, nullptr); - - auto kernel_ctx = holder.GetKernelContext(); - ASSERT_NE(kernel_ctx, nullptr); - auto tiling_ctx = reinterpret_cast(kernel_ctx); - auto runtime_attrs = tiling_ctx->GetAttrs(); - ASSERT_NE(runtime_attrs, nullptr); - ASSERT_EQ(runtime_attrs->GetAttrNum(), 2U); - auto attr0 = runtime_attrs->GetAttrPointer(0); - EXPECT_EQ(*attr0, 10); - auto attr1 = runtime_attrs->GetAttrPointer(1); - EXPECT_EQ(*attr1, 1000); - SpaceRegistryFaker::SetefaultSpaceRegistryNull(); -} - -TEST_F(KernelRunContextBuilderUT, OnlyGetIrAttrsFromCtx_OpNotRegisterPrivateAttr) { - SpaceRegistryFaker::UpdateOpImplToDefaultSpaceRegistry(); - ge::OpDescPtr op_desc = std::make_shared("testop1", "TestOpWithPrivateAttr2"); - op_desc->AppendIrAttrName("test_ir_attr"); - ge::AttrUtils::SetInt(op_desc, "test_ir_attr", 10); - ge::AttrUtils::SetInt(op_desc, "test_private_attr", 1000); - KernelRunContextBuilder builder; - ge::graphStatus ret = ge::GRAPH_FAILED; - auto holder = builder.Build(op_desc, ret); - ASSERT_EQ(ret, ge::GRAPH_SUCCESS); - ASSERT_NE(holder.context_holder_, nullptr); - ASSERT_NE(holder.compute_node_extend_holder_, nullptr); - - auto kernel_ctx = holder.GetKernelContext(); - ASSERT_NE(kernel_ctx, nullptr); - auto tiling_ctx = reinterpret_cast(kernel_ctx); - auto runtime_attrs = tiling_ctx->GetAttrs(); - ASSERT_NE(runtime_attrs, nullptr); - ASSERT_EQ(runtime_attrs->GetAttrNum(), 1U); - auto attr0 = runtime_attrs->GetAttrPointer(0); - EXPECT_EQ(*attr0, 10); - SpaceRegistryFaker::SetefaultSpaceRegistryNull(); -} -} // namespace gert diff --git a/tests/ut/exe_graph/op_execute_context_unittest.cc b/tests/ut/exe_graph/op_execute_context_unittest.cc deleted file mode 100644 index ffe657cb5a7c27a1b8b60c88e4427f4d18df359c..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/op_execute_context_unittest.cc +++ /dev/null @@ -1,196 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/runtime/op_execute_context.h" -#include "graph/ge_error_codes.h" -#include -#include "faker/kernel_run_context_faker.h" -#include "faker/allocator_faker.h" -#include "exe_graph/runtime/storage_shape.h" -#include "register/kernel_registry.h" - -namespace gert { -class OpExecuteContextUT : public testing::Test {}; -TEST_F(OpExecuteContextUT, GetInputTest) { - gert::Tensor in_tensor_1 = {{{8, 3, 224, 224}, {8, 1, 224, 224, 16}}, // shape - {ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x12345}; - gert::Tensor in_tensor_2 = {{{2, 2, 3, 8}, {2, 2, 3, 8}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x234565}; - gert::Tensor in_tensor_3 = {{{3, 2, 3, 8}, {3, 2, 3, 8}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x45678}; - gert::Tensor in_tensor_4 = {{{4, 2, 3, 8}, {4, 2, 3, 8}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnHost, // placement - ge::DT_FLOAT16, // data type - (void *)0x12345}; - gert::Tensor in_tensor_5 = {{{4, 2, 3, 8}, {4, 2, 3, 8}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnHost, // placement - ge::DT_FLOAT16, // data type - (void *)0x123345}; - gert::Tensor out_tensor_1 = {{{8, 3, 224, 224}, {8, 1, 224, 224, 16}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x34586}; - gert::Tensor out_tensor_2 = {{{8, 3, 224, 224}, {8, 1, 224, 224, 16}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x342354}; - gert::Tensor out_tensor_3 = {{{8, 3, 224, 224}, {8, 1, 224, 224, 16}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x642354}; - int32_t stream_int = 1; - rtStream rt_stream = static_cast(&stream_int); - OpExecuteOptions execute_option; - execute_option.allow_hf32[0] = '1'; - execute_option.deterministic = 1; - execute_option.precision_mode = 1; - AllocatorFaker gert_allocator; - auto op_execute_holder = OpExecuteContextFaker() - .IrInstanceNum({1, 2, 1, 0, 1}) - .IrOutputInstanceNum({2, 1}) - .NodeIoNum(5, 3) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeInputTd(2, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeInputTd(3, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeInputTd(4, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeOutputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeOutputTd(2, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputTensor({&in_tensor_1, &in_tensor_2, &in_tensor_3, &in_tensor_4, &in_tensor_5}) - .OutputTensor({&out_tensor_1, &out_tensor_2, &out_tensor_3}) - .Allocate(&gert_allocator) - .Stream(rt_stream) - .ExecuteOption(&execute_option) - .Build(); - auto context = op_execute_holder.GetContext(); - - ASSERT_NE(context, nullptr); - ASSERT_NE(context->GetInputTensor(0), nullptr); - EXPECT_EQ(context->GetInputTensor(0)->GetOriginShape(), in_tensor_1.GetOriginShape()); - EXPECT_EQ(context->GetInputTensor(0)->GetAddr(), in_tensor_1.GetAddr()); - ASSERT_NE(context->GetRequiredInputTensor(0), nullptr); - EXPECT_EQ(context->GetRequiredInputTensor(0)->GetOriginShape(), in_tensor_1.GetOriginShape()); - EXPECT_EQ(context->GetRequiredInputTensor(0)->GetAddr(), in_tensor_1.GetAddr()); - - ASSERT_NE(context->GetInputTensor(1), nullptr); - EXPECT_EQ(context->GetInputTensor(1)->GetOriginShape(), in_tensor_2.GetOriginShape()); - EXPECT_EQ(context->GetInputTensor(1)->GetAddr(), in_tensor_2.GetAddr()); - ASSERT_NE(context->GetDynamicInputTensor(1, 0), nullptr); - EXPECT_EQ(context->GetDynamicInputTensor(1, 0)->GetOriginShape(), in_tensor_2.GetOriginShape()); - EXPECT_EQ(context->GetDynamicInputTensor(1, 0)->GetAddr(), in_tensor_2.GetAddr()); - - ASSERT_NE(context->GetInputTensor(2), nullptr); - EXPECT_EQ(context->GetInputTensor(2)->GetOriginShape(), in_tensor_3.GetOriginShape()); - EXPECT_EQ(context->GetInputTensor(2)->GetAddr(), in_tensor_3.GetAddr()); - ASSERT_NE(context->GetDynamicInputTensor(1, 1), nullptr); - EXPECT_EQ(context->GetDynamicInputTensor(1, 1)->GetOriginShape(), in_tensor_3.GetOriginShape()); - EXPECT_EQ(context->GetDynamicInputTensor(1, 1)->GetAddr(), in_tensor_3.GetAddr()); - - ASSERT_NE(context->GetInputTensor(3), nullptr); - EXPECT_EQ(context->GetInputTensor(3)->GetOriginShape(), in_tensor_4.GetOriginShape()); - EXPECT_EQ(context->GetInputTensor(3)->GetAddr(), in_tensor_4.GetAddr()); - ASSERT_NE(context->GetOptionalInputTensor(2), nullptr); - EXPECT_EQ(context->GetOptionalInputTensor(2)->GetOriginShape(), in_tensor_4.GetOriginShape()); - EXPECT_EQ(context->GetOptionalInputTensor(2)->GetAddr(), in_tensor_4.GetAddr()); - - EXPECT_EQ(context->GetOptionalInputTensor(3), nullptr); - - ASSERT_NE(context->GetInputTensor(4), nullptr); - EXPECT_EQ(context->GetInputTensor(4)->GetOriginShape(), in_tensor_5.GetOriginShape()); - EXPECT_EQ(context->GetInputTensor(4)->GetAddr(), in_tensor_5.GetAddr()); - ASSERT_NE(context->GetRequiredInputTensor(4), nullptr); - EXPECT_EQ(context->GetRequiredInputTensor(4)->GetOriginShape(), in_tensor_5.GetOriginShape()); - EXPECT_EQ(context->GetRequiredInputTensor(4)->GetAddr(), in_tensor_5.GetAddr()); - - ASSERT_NE(context->GetOutputTensor(0), nullptr); - EXPECT_EQ(context->GetOutputTensor(0)->GetOriginShape(), out_tensor_1.GetOriginShape()); - EXPECT_EQ(context->GetOutputTensor(0)->GetAddr(), out_tensor_1.GetAddr()); - ASSERT_NE(context->GetDynamicOutputTensor(0, 0), nullptr); - EXPECT_EQ(context->GetDynamicOutputTensor(0, 0)->GetOriginShape(), out_tensor_1.GetOriginShape()); - EXPECT_EQ(context->GetDynamicOutputTensor(0, 0)->GetAddr(), out_tensor_1.GetAddr()); - - ASSERT_NE(context->GetOutputTensor(1), nullptr); - EXPECT_EQ(context->GetOutputTensor(1)->GetOriginShape(), out_tensor_2.GetOriginShape()); - EXPECT_EQ(context->GetOutputTensor(1)->GetAddr(), out_tensor_2.GetAddr()); - ASSERT_NE(context->GetDynamicOutputTensor(0, 1), nullptr); - EXPECT_EQ(context->GetDynamicOutputTensor(0, 1)->GetOriginShape(), out_tensor_2.GetOriginShape()); - EXPECT_EQ(context->GetDynamicOutputTensor(0, 1)->GetAddr(), out_tensor_2.GetAddr()); - - ASSERT_NE(context->GetOutputTensor(2), nullptr); - EXPECT_EQ(context->GetOutputTensor(2)->GetOriginShape(), out_tensor_3.GetOriginShape()); - EXPECT_EQ(context->GetOutputTensor(2)->GetAddr(), out_tensor_3.GetAddr()); - ASSERT_NE(context->GetRequiredOutputTensor(1), nullptr); - EXPECT_EQ(context->GetRequiredOutputTensor(1)->GetOriginShape(), out_tensor_3.GetOriginShape()); - EXPECT_EQ(context->GetRequiredOutputTensor(1)->GetAddr(), out_tensor_3.GetAddr()); - EXPECT_EQ(context->GetStream(), rt_stream); - EXPECT_EQ(context->GetAllowHf32(), execute_option.allow_hf32); - EXPECT_EQ(context->GetDeterministic(), true); - EXPECT_EQ(context->GetPrecisionMode(), execute_option.precision_mode); -} - -TEST_F(OpExecuteContextUT, MallocFreeWorkSpaceOk) { - gert::Tensor in_tensor_1 = {{{8, 3, 224, 224}, {8, 1, 224, 224, 16}}, // shape - {ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x12345}; - gert::Tensor in_tensor_2 = {{{2, 2, 3, 8}, {2, 2, 3, 8}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x234565}; - gert::Tensor out_tensor_1 = {{{8, 3, 224, 224}, {8, 1, 224, 224, 16}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x34586}; - AllocatorFaker gert_allocator; - auto output_block_memory = std::make_shared>(); - ASSERT_NE(output_block_memory, nullptr); - output_block_memory->reserve(1UL); - auto op_execute_holder = OpExecuteContextFaker() - .IrInstanceNum({1, 1}) - .IrOutputInstanceNum({1}) - .NodeIoNum(2, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputTensor({&in_tensor_1, &in_tensor_2}) - .OutputTensor({&out_tensor_1}) - .OutputMem(output_block_memory) - .Allocate(&gert_allocator) - .Build(); - auto context = op_execute_holder.GetContext(); - ASSERT_NE(context, nullptr); - auto block = context->MallocWorkspace(1024); - ASSERT_NE(block, nullptr); - - auto kernel_context = reinterpret_cast(context); - auto memory_vec = kernel_context->GetOutputPointer>(0UL); - ASSERT_NE(memory_vec, nullptr); - EXPECT_EQ(memory_vec->size(), 1UL); - context->FreeWorkspace(); - EXPECT_EQ(memory_vec->size(), 0UL); -} -} diff --git a/tests/ut/exe_graph/op_execute_launch_unittest.cc b/tests/ut/exe_graph/op_execute_launch_unittest.cc deleted file mode 100644 index ba53e5f82ce51184756ac7a9f55d83aba454d9d7..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/op_execute_launch_unittest.cc +++ /dev/null @@ -1,159 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ -#include "exe_graph/runtime/op_execute_launch_context.h" -#include -#include "faker/kernel_run_context_faker.h" - -namespace gert { -class OpExecuteLaunchContextUT : public testing::Test {}; -TEST_F(OpExecuteLaunchContextUT, GetInputOutputTest) { - gert::Tensor in_tensor_1 = {{{8, 3, 224, 224}, {8, 1, 224, 224, 16}}, // shape - {ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x12345}; - gert::Tensor in_tensor_2 = {{{2, 2, 3, 8}, {2, 2, 3, 8}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x234565}; - gert::Tensor in_tensor_3 = {{{3, 2, 3, 8}, {3, 2, 3, 8}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x45678}; - gert::Tensor in_tensor_4 = {{{4, 2, 3, 8}, {4, 2, 3, 8}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnHost, // placement - ge::DT_FLOAT16, // data type - (void *)0x12345}; - gert::Tensor in_tensor_5 = {{{4, 2, 3, 8}, {4, 2, 3, 8}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnHost, // placement - ge::DT_FLOAT16, // data type - (void *)0x123345}; - gert::Tensor out_tensor_1 = {{{8, 3, 224, 224}, {8, 1, 224, 224, 16}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x34586}; - gert::Tensor out_tensor_2 = {{{8, 3, 224, 224}, {8, 1, 224, 224, 16}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x342354}; - gert::Tensor out_tensor_3 = {{{8, 3, 224, 224}, {8, 1, 224, 224, 16}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x642354}; - - void *op_api_params = (void *)0x12; - auto ws_addr_vector = ContinuousVector::Create(3); - auto *ws_addr_cont_vec = reinterpret_cast *>(ws_addr_vector.get()); - ws_addr_cont_vec->SetSize(3); - auto tensor_data_addr = static_cast(ws_addr_cont_vec->MutableData()); - auto td1 = TensorData((void *)0x34, nullptr); - auto td2 = TensorData((void *)0x35, nullptr); - auto td3 = TensorData((void *)0x36, nullptr); - tensor_data_addr[0] = &td1; - tensor_data_addr[1] = &td2; - tensor_data_addr[2] = &td3; - auto ws_size_vector = ContinuousVector::Create(1); - auto *ws_size_cont_vec = reinterpret_cast *>(ws_size_vector.get()); - ws_size_cont_vec->MutableData()[0] = 32U; - ws_size_cont_vec->SetSize(1); - rtStream stream = (void *)0x56; - auto op_execute_holder = OpExecuteLaunchContextFaker() - .IrInstanceNum({1, 2, 1, 0, 1}) - .IrOutputInstanceNum({2, 1}) - .NodeIoNum(5, 3) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeInputTd(2, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeInputTd(3, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeInputTd(4, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeOutputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeOutputTd(2, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputTensor({&in_tensor_1, &in_tensor_2, &in_tensor_3, &in_tensor_4, &in_tensor_5}) - .OutputTensor({&out_tensor_1, &out_tensor_2, &out_tensor_3}) - .OpApiParams(op_api_params) - .WorkspaceSize(ws_size_vector.get()) - .WorkspaceAddr(ws_addr_vector.get()) - .Stream(stream) - .Build(); - auto context = op_execute_holder.GetContext(); - ASSERT_NE(context, nullptr); - ASSERT_NE(context->GetInputTensor(0), nullptr); - EXPECT_EQ(context->GetInputTensor(0)->GetOriginShape(), in_tensor_1.GetOriginShape()); - EXPECT_EQ(context->GetInputTensor(0)->GetAddr(), in_tensor_1.GetAddr()); - ASSERT_NE(context->GetRequiredInputTensor(0), nullptr); - EXPECT_EQ(context->GetRequiredInputTensor(0)->GetOriginShape(), in_tensor_1.GetOriginShape()); - EXPECT_EQ(context->GetRequiredInputTensor(0)->GetAddr(), in_tensor_1.GetAddr()); - - ASSERT_NE(context->GetInputTensor(1), nullptr); - EXPECT_EQ(context->GetInputTensor(1)->GetOriginShape(), in_tensor_2.GetOriginShape()); - EXPECT_EQ(context->GetInputTensor(1)->GetAddr(), in_tensor_2.GetAddr()); - ASSERT_NE(context->GetDynamicInputTensor(1, 0), nullptr); - EXPECT_EQ(context->GetDynamicInputTensor(1, 0)->GetOriginShape(), in_tensor_2.GetOriginShape()); - EXPECT_EQ(context->GetDynamicInputTensor(1, 0)->GetAddr(), in_tensor_2.GetAddr()); - - ASSERT_NE(context->GetInputTensor(2), nullptr); - EXPECT_EQ(context->GetInputTensor(2)->GetOriginShape(), in_tensor_3.GetOriginShape()); - EXPECT_EQ(context->GetInputTensor(2)->GetAddr(), in_tensor_3.GetAddr()); - ASSERT_NE(context->GetDynamicInputTensor(1, 1), nullptr); - EXPECT_EQ(context->GetDynamicInputTensor(1, 1)->GetOriginShape(), in_tensor_3.GetOriginShape()); - EXPECT_EQ(context->GetDynamicInputTensor(1, 1)->GetAddr(), in_tensor_3.GetAddr()); - - ASSERT_NE(context->GetInputTensor(3), nullptr); - EXPECT_EQ(context->GetInputTensor(3)->GetOriginShape(), in_tensor_4.GetOriginShape()); - EXPECT_EQ(context->GetInputTensor(3)->GetAddr(), in_tensor_4.GetAddr()); - ASSERT_NE(context->GetOptionalInputTensor(2), nullptr); - EXPECT_EQ(context->GetOptionalInputTensor(2)->GetOriginShape(), in_tensor_4.GetOriginShape()); - EXPECT_EQ(context->GetOptionalInputTensor(2)->GetAddr(), in_tensor_4.GetAddr()); - - EXPECT_EQ(context->GetOptionalInputTensor(3), nullptr); - - ASSERT_NE(context->GetInputTensor(4), nullptr); - EXPECT_EQ(context->GetInputTensor(4)->GetOriginShape(), in_tensor_5.GetOriginShape()); - EXPECT_EQ(context->GetInputTensor(4)->GetAddr(), in_tensor_5.GetAddr()); - ASSERT_NE(context->GetRequiredInputTensor(4), nullptr); - EXPECT_EQ(context->GetRequiredInputTensor(4)->GetOriginShape(), in_tensor_5.GetOriginShape()); - EXPECT_EQ(context->GetRequiredInputTensor(4)->GetAddr(), in_tensor_5.GetAddr()); - - ASSERT_NE(context->GetOutputTensor(0), nullptr); - EXPECT_EQ(context->GetOutputTensor(0)->GetOriginShape(), out_tensor_1.GetOriginShape()); - EXPECT_EQ(context->GetOutputTensor(0)->GetAddr(), out_tensor_1.GetAddr()); - ASSERT_NE(context->GetDynamicOutputTensor(0, 0), nullptr); - EXPECT_EQ(context->GetDynamicOutputTensor(0, 0)->GetOriginShape(), out_tensor_1.GetOriginShape()); - EXPECT_EQ(context->GetDynamicOutputTensor(0, 0)->GetAddr(), out_tensor_1.GetAddr()); - - ASSERT_NE(context->GetOutputTensor(1), nullptr); - EXPECT_EQ(context->GetOutputTensor(1)->GetOriginShape(), out_tensor_2.GetOriginShape()); - EXPECT_EQ(context->GetOutputTensor(1)->GetAddr(), out_tensor_2.GetAddr()); - ASSERT_NE(context->GetDynamicOutputTensor(0, 1), nullptr); - EXPECT_EQ(context->GetDynamicOutputTensor(0, 1)->GetOriginShape(), out_tensor_2.GetOriginShape()); - EXPECT_EQ(context->GetDynamicOutputTensor(0, 1)->GetAddr(), out_tensor_2.GetAddr()); - - ASSERT_NE(context->GetOutputTensor(2), nullptr); - EXPECT_EQ(context->GetOutputTensor(2)->GetOriginShape(), out_tensor_3.GetOriginShape()); - EXPECT_EQ(context->GetOutputTensor(2)->GetAddr(), out_tensor_3.GetAddr()); - ASSERT_NE(context->GetRequiredOutputTensor(1), nullptr); - EXPECT_EQ(context->GetRequiredOutputTensor(1)->GetOriginShape(), out_tensor_3.GetOriginShape()); - EXPECT_EQ(context->GetRequiredOutputTensor(1)->GetAddr(), out_tensor_3.GetAddr()); - - ASSERT_EQ(context->GetOpApiParams(), op_api_params); - auto workspace_addrs = context->GetWorkspaceAddrs(); - ASSERT_EQ(workspace_addrs->GetData()[0]->GetAddr(), (void *)0x34); - auto workspace_sizes = context->GetWorkspaceSizes(); - ASSERT_EQ(workspace_sizes->GetData()[0], 32U); - ASSERT_EQ(context->GetStream(), stream); -} -} \ No newline at end of file diff --git a/tests/ut/exe_graph/op_execute_prepare_context_unittest.cc b/tests/ut/exe_graph/op_execute_prepare_context_unittest.cc deleted file mode 100644 index 858f60150186b3bf48f2da6bf6576c1ec05fc455..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/op_execute_prepare_context_unittest.cc +++ /dev/null @@ -1,167 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ - -#include "exe_graph/runtime/op_execute_prepare_context.h" -#include -#include "faker/kernel_run_context_faker.h" - -namespace gert { -class OpExecutePrepareContextUT : public testing::Test {}; -TEST_F(OpExecutePrepareContextUT, GetInputOutputTest) { - gert::Tensor in_tensor_1 = {{{8, 3, 224, 224}, {8, 1, 224, 224, 16}}, // shape - {ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x12345}; - gert::Tensor in_tensor_2 = {{{2, 2, 3, 8}, {2, 2, 3, 8}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x234565}; - gert::Tensor in_tensor_3 = {{{3, 2, 3, 8}, {3, 2, 3, 8}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x45678}; - gert::Tensor in_tensor_4 = {{{4, 2, 3, 8}, {4, 2, 3, 8}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnHost, // placement - ge::DT_FLOAT16, // data type - (void *)0x12345}; - gert::Tensor in_tensor_5 = {{{4, 2, 3, 8}, {4, 2, 3, 8}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnHost, // placement - ge::DT_FLOAT16, // data type - (void *)0x123345}; - gert::Tensor out_tensor_1 = {{{8, 3, 224, 224}, {8, 1, 224, 224, 16}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x34586}; - gert::Tensor out_tensor_2 = {{{8, 3, 224, 224}, {8, 1, 224, 224, 16}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x342354}; - gert::Tensor out_tensor_3 = {{{8, 3, 224, 224}, {8, 1, 224, 224, 16}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *)0x642354}; - OpExecuteOptions execute_option; - execute_option.allow_hf32[0] = '1'; - execute_option.deterministic = 1; - execute_option.precision_mode = 1; - - auto param = new DummyOpApiParams(); - auto param2 = new DummyOpApiParams(); - auto op_execute_holder = OpExecutePrepareContextFaker() - .IrInstanceNum({1, 2, 1, 0, 1}) - .IrOutputInstanceNum({2, 1}) - .NodeIoNum(5, 3) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeInputTd(2, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeInputTd(3, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeInputTd(4, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeOutputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .NodeOutputTd(2, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputTensor({&in_tensor_1, &in_tensor_2, &in_tensor_3, &in_tensor_4, &in_tensor_5}) - .OutputTensor({&out_tensor_1, &out_tensor_2, &out_tensor_3}) - .ExecuteOption(&execute_option) - .Build(); - auto context = op_execute_holder.GetContext(); - ASSERT_NE(context, nullptr); - ASSERT_NE(context->GetInputTensor(0), nullptr); - EXPECT_EQ(context->GetInputTensor(0)->GetOriginShape(), in_tensor_1.GetOriginShape()); - EXPECT_EQ(context->GetInputTensor(0)->GetAddr(), in_tensor_1.GetAddr()); - ASSERT_NE(context->GetRequiredInputTensor(0), nullptr); - EXPECT_EQ(context->GetRequiredInputTensor(0)->GetOriginShape(), in_tensor_1.GetOriginShape()); - EXPECT_EQ(context->GetRequiredInputTensor(0)->GetAddr(), in_tensor_1.GetAddr()); - - ASSERT_NE(context->GetInputTensor(1), nullptr); - EXPECT_EQ(context->GetInputTensor(1)->GetOriginShape(), in_tensor_2.GetOriginShape()); - EXPECT_EQ(context->GetInputTensor(1)->GetAddr(), in_tensor_2.GetAddr()); - ASSERT_NE(context->GetDynamicInputTensor(1, 0), nullptr); - EXPECT_EQ(context->GetDynamicInputTensor(1, 0)->GetOriginShape(), in_tensor_2.GetOriginShape()); - EXPECT_EQ(context->GetDynamicInputTensor(1, 0)->GetAddr(), in_tensor_2.GetAddr()); - - ASSERT_NE(context->GetInputTensor(2), nullptr); - EXPECT_EQ(context->GetInputTensor(2)->GetOriginShape(), in_tensor_3.GetOriginShape()); - EXPECT_EQ(context->GetInputTensor(2)->GetAddr(), in_tensor_3.GetAddr()); - ASSERT_NE(context->GetDynamicInputTensor(1, 1), nullptr); - EXPECT_EQ(context->GetDynamicInputTensor(1, 1)->GetOriginShape(), in_tensor_3.GetOriginShape()); - EXPECT_EQ(context->GetDynamicInputTensor(1, 1)->GetAddr(), in_tensor_3.GetAddr()); - - ASSERT_NE(context->GetInputTensor(3), nullptr); - EXPECT_EQ(context->GetInputTensor(3)->GetOriginShape(), in_tensor_4.GetOriginShape()); - EXPECT_EQ(context->GetInputTensor(3)->GetAddr(), in_tensor_4.GetAddr()); - ASSERT_NE(context->GetOptionalInputTensor(2), nullptr); - EXPECT_EQ(context->GetOptionalInputTensor(2)->GetOriginShape(), in_tensor_4.GetOriginShape()); - EXPECT_EQ(context->GetOptionalInputTensor(2)->GetAddr(), in_tensor_4.GetAddr()); - - EXPECT_EQ(context->GetOptionalInputTensor(3), nullptr); - - ASSERT_NE(context->GetInputTensor(4), nullptr); - EXPECT_EQ(context->GetInputTensor(4)->GetOriginShape(), in_tensor_5.GetOriginShape()); - EXPECT_EQ(context->GetInputTensor(4)->GetAddr(), in_tensor_5.GetAddr()); - ASSERT_NE(context->GetRequiredInputTensor(4), nullptr); - EXPECT_EQ(context->GetRequiredInputTensor(4)->GetOriginShape(), in_tensor_5.GetOriginShape()); - EXPECT_EQ(context->GetRequiredInputTensor(4)->GetAddr(), in_tensor_5.GetAddr()); - - ASSERT_NE(context->GetOutputTensor(0), nullptr); - EXPECT_EQ(context->GetOutputTensor(0)->GetOriginShape(), out_tensor_1.GetOriginShape()); - EXPECT_EQ(context->GetOutputTensor(0)->GetAddr(), out_tensor_1.GetAddr()); - ASSERT_NE(context->GetDynamicOutputTensor(0, 0), nullptr); - EXPECT_EQ(context->GetDynamicOutputTensor(0, 0)->GetOriginShape(), out_tensor_1.GetOriginShape()); - EXPECT_EQ(context->GetDynamicOutputTensor(0, 0)->GetAddr(), out_tensor_1.GetAddr()); - - ASSERT_NE(context->GetOutputTensor(1), nullptr); - EXPECT_EQ(context->GetOutputTensor(1)->GetOriginShape(), out_tensor_2.GetOriginShape()); - EXPECT_EQ(context->GetOutputTensor(1)->GetAddr(), out_tensor_2.GetAddr()); - ASSERT_NE(context->GetDynamicOutputTensor(0, 1), nullptr); - EXPECT_EQ(context->GetDynamicOutputTensor(0, 1)->GetOriginShape(), out_tensor_2.GetOriginShape()); - EXPECT_EQ(context->GetDynamicOutputTensor(0, 1)->GetAddr(), out_tensor_2.GetAddr()); - - ASSERT_NE(context->GetOutputTensor(2), nullptr); - EXPECT_EQ(context->GetOutputTensor(2)->GetOriginShape(), out_tensor_3.GetOriginShape()); - EXPECT_EQ(context->GetOutputTensor(2)->GetAddr(), out_tensor_3.GetAddr()); - ASSERT_NE(context->GetRequiredOutputTensor(1), nullptr); - EXPECT_EQ(context->GetRequiredOutputTensor(1)->GetOriginShape(), out_tensor_3.GetOriginShape()); - EXPECT_EQ(context->GetRequiredOutputTensor(1)->GetAddr(), out_tensor_3.GetAddr()); - EXPECT_EQ(context->GetAllowHf32(), execute_option.allow_hf32); - EXPECT_EQ(context->GetDeterministic(), true); - EXPECT_EQ(context->GetPrecisionMode(), execute_option.precision_mode); - - size_t kernel_out_start_num = 5 + 3 + 2; // 5 input tensor, 3 output tensor, 2 append input - auto set_status = context->SetOpApiParamsWithDefaultDeleter(param); - ASSERT_EQ(set_status, ge::GRAPH_SUCCESS); - set_status = context->SetOpApiParams(param2, nullptr); - ASSERT_EQ(set_status, ge::GRAPH_FAILED); - set_status = context->SetOpApiParams(param2, [](void * const data) { - delete reinterpret_cast(data); - }); - ASSERT_EQ(set_status, ge::GRAPH_SUCCESS); - - context->SetWorkspaceSizes({64U, 128U, 256U}); - Chain *out1 = reinterpret_cast( - (reinterpret_cast(op_execute_holder.holder.context_))->values[kernel_out_start_num + 1]); - auto ws_size_vec = out1->GetPointer>(); - ASSERT_NE(ws_size_vec, nullptr); - EXPECT_EQ(ws_size_vec->GetSize(), 3); - EXPECT_EQ(ws_size_vec->GetData()[0], 64U); - context->SetWorkspaceSizes({0U}); - EXPECT_EQ(ws_size_vec->GetSize(), 1); - EXPECT_EQ(ws_size_vec->GetData()[0], 0U); - EXPECT_NE(context->SetWorkspaceSizes({10U, 10U, 10U, 10U}), ge::GRAPH_SUCCESS); - EXPECT_EQ(ws_size_vec->GetSize(), 1); - EXPECT_EQ(ws_size_vec->GetData()[0], 0U); -} -} \ No newline at end of file diff --git a/tests/ut/exe_graph/shape_range_unittest.cc b/tests/ut/exe_graph/shape_range_unittest.cc deleted file mode 100644 index e1172c90484c81daf1ae204a873c05e112e44a1a..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/shape_range_unittest.cc +++ /dev/null @@ -1,114 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/runtime/range.h" -#include -namespace gert { -class ShapeRangeUT : public testing::Test {}; -TEST_F(ShapeRangeUT, DefaultConstructOk) { - Range sr; - EXPECT_EQ(sr.GetMax(), nullptr); - EXPECT_EQ(sr.GetMin(), nullptr); -} - -TEST_F(ShapeRangeUT, ConstructFromListOk) { - Shape min{8, 3, 224, 224}; - Shape max{-1, 10, 2240, 2240}; - Range sr(&min, &max); - auto s = sr.GetMin(); - EXPECT_EQ(s->GetDimNum(), 4); - EXPECT_EQ(s->GetDim(0), 8); - EXPECT_EQ(s->GetDim(1), 3); - EXPECT_EQ(s->GetDim(2), 224); - EXPECT_EQ(s->GetDim(3), 224); - s = sr.GetMax(); - EXPECT_EQ(s->GetDimNum(), 4); - EXPECT_EQ(s->GetDim(0), -1); - EXPECT_EQ(s->GetDim(1), 10); - EXPECT_EQ(s->GetDim(2), 2240); - EXPECT_EQ(s->GetDim(3), 2240); -} - -TEST_F(ShapeRangeUT, ConstructFromShapesOk) { - Shape s1{8, 3, 224, 224}; - Shape s2{-1, 10, 2240, 2240}; - Range sr(&s1, &s2); - auto s = sr.GetMin(); - EXPECT_EQ(s->GetDimNum(), 4); - EXPECT_EQ(s->GetDim(0), 8); - EXPECT_EQ(s->GetDim(1), 3); - EXPECT_EQ(s->GetDim(2), 224); - EXPECT_EQ(s->GetDim(3), 224); - s = sr.GetMax(); - EXPECT_EQ(s->GetDimNum(), 4); - EXPECT_EQ(s->GetDim(0), -1); - EXPECT_EQ(s->GetDim(1), 10); - EXPECT_EQ(s->GetDim(2), 2240); - EXPECT_EQ(s->GetDim(3), 2240); -} - -TEST_F(ShapeRangeUT, ConstructFromSameShapeOk) { - Shape s{8, 3, 224, 224}; - Range sr(&s); - auto min = sr.GetMin(); - EXPECT_EQ(s.GetDimNum(), 4); - EXPECT_EQ(s.GetDim(0), 8); - EXPECT_EQ(s.GetDim(1), 3); - EXPECT_EQ(s.GetDim(2), 224); - EXPECT_EQ(s.GetDim(3), 224); - auto max = sr.GetMax(); - EXPECT_EQ(min, max); -} - -TEST_F(ShapeRangeUT, EqualOk) { - Shape s1{8, 3, 224, 224}; - Shape s2{8, 3, 224, 224}; - Range sr1(&s1, &s2); - Range sr2(&s1, &s2); - EXPECT_TRUE(sr1 == sr2); -} - -TEST_F(ShapeRangeUT, SetMaxOk) { - Range sr; - Shape max{7, 3, 224, 224}; - sr.SetMax(&max); - EXPECT_EQ(*sr.GetMax(), max); -} - -TEST_F(ShapeRangeUT, SetMinOk) { - Range sr; - Shape min{7, 4, -1, -1}; - sr.SetMin(&min); - EXPECT_EQ(sr.GetMin(), &min); -} - -TEST_F(ShapeRangeUT, ModifyMaxOk) { - Shape max{7, 4, -1, -1}; - Range sr; - sr.SetMax(&max); - - auto max_pr = sr.GetMax(); - (*max_pr)[0] = 8; - (*max_pr)[1] = -1; - EXPECT_EQ(sr.GetMax()->GetDim(0), 8); - EXPECT_EQ(sr.GetMax()->GetDim(1), -1); -} - -TEST_F(ShapeRangeUT, ModifyMinOk) { - Shape min{3, 4, 255, 6}; - Range sr; - sr.SetMin(&min); - - auto min_pr = sr.GetMin(); - (*min_pr)[0] = 1; - (*min_pr)[1] = 2; - EXPECT_EQ(sr.GetMin()->GetDim(0), 1); - EXPECT_EQ(sr.GetMin()->GetDim(1), 2); -} -} // namespace gert diff --git a/tests/ut/exe_graph/shape_unittest.cc b/tests/ut/exe_graph/shape_unittest.cc deleted file mode 100644 index 1371325ff1238d425c6430f7207ba6a6c314dc4c..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/shape_unittest.cc +++ /dev/null @@ -1,191 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/runtime/shape.h" -#include -namespace gert { -class ShapeUT : public testing::Test {}; -TEST_F(ShapeUT, DefaultConstructOk) { - Shape s; - EXPECT_EQ(s.GetDimNum(), 0); -} - -TEST_F(ShapeUT, ConstructFromListOk) { - Shape s{8, 3, 224, 224}; - EXPECT_EQ(s.GetDimNum(), 4); - EXPECT_EQ(s.GetDim(0), 8); - EXPECT_EQ(s.GetDim(1), 3); - EXPECT_EQ(s.GetDim(2), 224); - EXPECT_EQ(s.GetDim(3), 224); -} - -TEST_F(ShapeUT, ConstructFromListOverMaxNum) { - Shape s{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26}; - EXPECT_EQ(s.GetDimNum(), 0); -} - -TEST_F(ShapeUT, EqualOk) { - Shape s1{8, 3, 224, 224}; - Shape s2{8, 3, 224, 224}; - EXPECT_TRUE(s1 == s2); - EXPECT_FALSE(s1 != s2); -} - -TEST_F(ShapeUT, NotEqualForSize) { - Shape s1{3, 224, 224}; - Shape s2{8, 3, 224, 224}; - EXPECT_FALSE(s1 == s2); - EXPECT_TRUE(s1 != s2); -} - -TEST_F(ShapeUT, NotEqualForDim) { - Shape s1{7, 3, 224, 224}; - Shape s2{8, 3, 224, 224}; - EXPECT_FALSE(s1 == s2); - EXPECT_TRUE(s1 != s2); -} - -TEST_F(ShapeUT, GetTensorShapeSizeOk) { - Shape s2{8, 3, 224, 224}; - EXPECT_EQ(s2.GetShapeSize(), 8 * 3 * 224 *224); -} - -TEST_F(ShapeUT, GetScalerShapeSizeOk) { - Shape s; - EXPECT_EQ(s.GetShapeSize(), 1); -} - -TEST_F(ShapeUT, Get1ShapeSizeOk) { - Shape s{1}; - EXPECT_EQ(s.GetShapeSize(), 1); -} - -TEST_F(ShapeUT, GetShapeOverflow) { - Shape s{8, 3, 224, std::numeric_limits::max()}; - EXPECT_LT(s.GetShapeSize(), 0); -} - -TEST_F(ShapeUT, GetEmptyTensorShapeSize) { - Shape s2{8, 3, 224, 0}; - EXPECT_EQ(s2.GetShapeSize(), 0); -} - -TEST_F(ShapeUT, IsScalarOk) { - Shape s1; - Shape s2{8, 3, 224, 0}; - EXPECT_TRUE(s1.IsScalar()); - EXPECT_FALSE(s2.IsScalar()); -} - -TEST_F(ShapeUT, GetDimNumOk) { - Shape s1; - Shape s2{8, 3, 224, 0}; - EXPECT_EQ(s1.GetDimNum(), 0); - EXPECT_EQ(s2.GetDimNum(), 4); -} - -TEST_F(ShapeUT, SetGetDimNumOk) { - Shape s; - EXPECT_EQ(s.GetDimNum(), 0); - s.SetDimNum(4); - EXPECT_EQ(s.GetDimNum(), 4); -} - -TEST_F(ShapeUT, GetDimOk) { - Shape s1; - Shape s2{8, 3, 224, 224}; - - EXPECT_EQ(s1.GetDim(0), 0); - EXPECT_EQ(s2.GetDim(0), 8); - EXPECT_EQ(s2.GetDim(1), 3); - EXPECT_EQ(s2.GetDim(2), 224); - EXPECT_EQ(s2.GetDim(3), 224); - - EXPECT_EQ(s1[0], 0); - EXPECT_EQ(s2[0], 8); - EXPECT_EQ(s2[1], 3); - EXPECT_EQ(s2[2], 224); - EXPECT_EQ(s2[3], 224); -} - -TEST_F(ShapeUT, ModifyDimOk) { - Shape s1; - Shape s2{8, 3, 224, 224}; - - s1[0] = 8; - s1[1] = 8; - EXPECT_EQ(s1.GetDim(0), 8); - EXPECT_EQ(s1.GetDim(1), 8); - - s2[0] = 16; - s2[1] = 16; - EXPECT_EQ(s2.GetDim(0), 16); - EXPECT_EQ(s2.GetDim(1), 16); -} - -TEST_F(ShapeUT, SetGetDimOfOutRange) { - Shape s1; - EXPECT_EQ(s1.GetDim(25), std::numeric_limits::min()); - s1.SetDim(25, 10); -} - -TEST_F(ShapeUT, SetGetDimOk) { - Shape s{1}; - EXPECT_EQ(s.GetDim(0), 1); - s.SetDim(0, 10); - EXPECT_EQ(s.GetDim(0), 10); -} - -TEST_F(ShapeUT, AppendDimOk) { - Shape s{1}; - s.AppendDim(10).AppendDim(20); - Shape expect_s{1,10,20}; - EXPECT_EQ(s, expect_s); -} - -TEST_F(ShapeUT, AppendDimOutOfBounds) { - Shape s{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}; - Shape expect_s{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25}; - s.AppendDim(10); - EXPECT_EQ(s, expect_s); -} - -TEST_F(ShapeUT, SetScalar) { - Shape s{1, 2, 3, 4}; - EXPECT_EQ(s.GetDimNum(), 4); - s.SetScalar(); - EXPECT_EQ(s.GetDimNum(), 0); -} - -TEST_F(ShapeUT, CopyConstruct) { - Shape s{1, 2, 3, 4, 5}; - Shape s_copy(s); - EXPECT_EQ(s_copy.GetDimNum(), 5); -} - - -TEST_F(ShapeUT, CopyAssign) { - Shape s{4,3,2,1}; - Shape s_copy{1, 2, 3, 4, 5}; - EXPECT_EQ(s_copy.GetDimNum(), 5); - s_copy = s; - EXPECT_EQ(s_copy.GetDimNum(), 4); - EXPECT_EQ(s_copy.GetDim(4), 5); - - Shape a{4,3,2,1}; - Shape a_copy{1, 2, 3, 4, 5}; - EXPECT_EQ(a.GetDimNum(), 4); - a = a_copy; - EXPECT_EQ(a.GetDimNum(), 5); - EXPECT_EQ(a.GetDim(4), 5); -} -} // namespace gert diff --git a/tests/ut/exe_graph/shape_utils_unittest.cc b/tests/ut/exe_graph/shape_utils_unittest.cc deleted file mode 100644 index aaba3efd16b31a93a4b5795f78918e9ff45d6e5f..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/shape_utils_unittest.cc +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/lowering/shape_utils.h" -#include -namespace gert { -class ShapeUtilsUT : public testing::Test {}; -TEST_F(ShapeUtilsUT, EnsureNotScalar_ReturnInput_NotScalarShape) { - Shape s1{1, 2, 3}; - Shape s2{1}; - Shape s3{0, 1}; - ASSERT_EQ(&EnsureNotScalar(s1), &s1); - ASSERT_EQ(&EnsureNotScalar(s2), &s2); - ASSERT_EQ(&EnsureNotScalar(s3), &s3); -} -TEST_F(ShapeUtilsUT, EnsureNotScalar_ReturnVec1_ScalarShape) { - Shape s1{}; - auto &s = EnsureNotScalar(s1); - ASSERT_FALSE(s.IsScalar()); - ASSERT_EQ(s.GetDimNum(), 1); - ASSERT_EQ(s.GetDim(0), 1); -} -TEST_F(ShapeUtilsUT, ShapeToString_Empty_Scalar) { - Shape s; - EXPECT_TRUE(ShapeToString(s).empty()); -} -TEST_F(ShapeUtilsUT, ShapeToString_NoComman_OneDim) { - Shape s{1}; - EXPECT_EQ(ShapeToString(s), "1"); -} -TEST_F(ShapeUtilsUT, ShapeToString_DefaulJoinStr) { - Shape s{1, 2, 3}; - EXPECT_EQ(ShapeToString(s), "1,2,3"); -} -TEST_F(ShapeUtilsUT, ShapeToString_SelfDefinedStr) { - Shape s{1, 2, 3}; - EXPECT_EQ(ShapeToString(s, ", "), "1, 2, 3"); -} -TEST_F(ShapeUtilsUT, ShapeToString_UseComman_NullJoinStr) { - Shape s{1, 2, 3}; - EXPECT_EQ(ShapeToString(s, nullptr), "1,2,3"); -} - -TEST_F(ShapeUtilsUT, CalcAlignedSizeByShape_success) { - Shape s{1, 2, 3}; - size_t ret_tensor_size = 0U; - CalcAlignedSizeByShape(s, ge::DataType::DT_FLOAT, ret_tensor_size); - EXPECT_EQ(ret_tensor_size, 64); - - CalcAlignedSizeByShape(s, ge::DataType::DT_FLOAT16, ret_tensor_size); - EXPECT_EQ(ret_tensor_size, 64); - - CalcAlignedSizeByShape(s, ge::DataType::DT_STRING, ret_tensor_size); - EXPECT_EQ(ret_tensor_size, 128); - - CalcAlignedSizeByShape(s, ge::DataType::DT_INT4, ret_tensor_size); - EXPECT_EQ(ret_tensor_size, 64); -} - -TEST_F(ShapeUtilsUT, CalcTotalSizeByShape_failed) { - Shape s{-1, 2, 3}; - size_t ret_tensor_size = 0U; - - EXPECT_NE(CalcAlignedSizeByShape(s, ge::DataType::DT_STRING, ret_tensor_size), ge::GRAPH_SUCCESS); - EXPECT_EQ(ret_tensor_size, 0); - - Shape s1{std::numeric_limits::max(), 2, 3}; - EXPECT_NE(CalcAlignedSizeByShape(s1, ge::DataType::DT_STRING, ret_tensor_size), ge::GRAPH_SUCCESS); - EXPECT_EQ(ret_tensor_size, 0); -} -} // namespace gert diff --git a/tests/ut/exe_graph/symbol_shape_unittest.cc b/tests/ut/exe_graph/symbol_shape_unittest.cc deleted file mode 100644 index 572f3e86ef1f91afad0549efa228a83137fd499a..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/symbol_shape_unittest.cc +++ /dev/null @@ -1,368 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "exe_graph/runtime/symbolic_tensor.h" -#include "graph/symbolizer/symbolic.h" -#include "expression/const_values.h" -#include "graph/ge_tensor.h" - -namespace gert { -class SymbolShapeUT : public testing::Test {}; -TEST_F(SymbolShapeUT, DefaultConstructOk) { - SymbolShape s; - EXPECT_EQ(s.GetDimNum(), 0); -} - -TEST_F(SymbolShapeUT, ConstructFromListOk) { - auto s0 = ge::Symbol("s0"); - auto s1 = ge::Symbol("s1"); - auto s2 = ge::Symbol("s2"); - auto s3 = ge::Symbol("s3"); - SymbolShape s{s0, s1, s2, s3}; - EXPECT_EQ(s.GetDimNum(), 4); - EXPECT_EQ(s.GetDim(0), s0); - EXPECT_EQ(s.GetDim(1), s1); - EXPECT_EQ(s.GetDim(2), s2); - EXPECT_EQ(s.GetDim(3), s3); - - EXPECT_EQ(s.GetDims()[0], s0); - EXPECT_EQ(s.GetDims()[1], s1); - EXPECT_EQ(s.GetDims()[2], s2); - EXPECT_EQ(s.GetDims()[3], s3); -} - -TEST_F(SymbolShapeUT, ConstructMaxNum) { - auto s0 = ge::Symbol("s0"); - SymbolShape s{s0, s0, s0, s0, s0, s0, s0, s0, s0, s0, - s0, s0, s0, s0, s0, s0, s0, s0, s0, s0, - s0, s0, s0, s0, s0, s0}; - EXPECT_EQ(s.GetDimNum(), 26); -} - -TEST_F(SymbolShapeUT, EqualOk) { - auto s0 = ge::Symbol("s0"); - auto s1 = ge::Symbol("s1"); - auto s2 = ge::Symbol("s2"); - auto s3 = ge::Symbol("s3"); - SymbolShape symbol_shape1{s0, s1, s2, s3}; - SymbolShape symbol_shape2{s0, s1, s2, s3}; - - EXPECT_TRUE(symbol_shape1 == symbol_shape2); - EXPECT_FALSE(symbol_shape1 != symbol_shape2); -} - -TEST_F(SymbolShapeUT, NotEqualForSize) { - auto s0 = ge::Symbol("s0"); - auto s1 = ge::Symbol("s1"); - auto s2 = ge::Symbol("s2"); - auto s3 = ge::Symbol("s3"); - SymbolShape symbol_shape1{s0, s1, s2}; - SymbolShape symbol_shape2{s0, s1, s2, s3}; - EXPECT_FALSE(symbol_shape1 == symbol_shape2); - EXPECT_TRUE(symbol_shape1 != symbol_shape2); -} - -TEST_F(SymbolShapeUT, NotEqualForDim) { - auto s0 = ge::Symbol("s0"); - auto s0_bak = ge::Symbol(9); - auto s1 = ge::Symbol("s1"); - auto s2 = ge::Symbol("s2"); - auto s3 = ge::Symbol("s3"); - SymbolShape symbol_shape1{s0, s1, s2, s3}; - SymbolShape symbol_shape2{s0_bak, s1, s2, s3}; - EXPECT_FALSE(symbol_shape1 == symbol_shape2); - EXPECT_TRUE(symbol_shape1 != symbol_shape2); -} - -TEST_F(SymbolShapeUT, GetTensorShapeSizeOk) { - auto s0 = ge::Symbol("s0"); - auto s1 = ge::Symbol("s1"); - auto s2 = ge::Symbol("s2"); - auto s3 = ge::Symbol("s3"); - SymbolShape symbol_shape1{s0, s1, s2, s3}; - auto expr = ge::sym::Mul(ge::sym::Mul(ge::sym::Mul(s0, s1), s2), s3); - EXPECT_EQ(symbol_shape1.GetSymbolShapeSize(), expr); -} - -TEST_F(SymbolShapeUT, GetScalerShapeSizeOk) { - SymbolShape s; - EXPECT_EQ(s.GetSymbolShapeSize(), ge::sym::kSymbolOne); -} - -TEST_F(SymbolShapeUT, Get1ShapeSizeOk) { - auto s0 = ge::Symbol("s0"); - SymbolShape s{s0}; - EXPECT_EQ(s.GetSymbolShapeSize(), ge::Symbol("s0")); -} - -TEST_F(SymbolShapeUT, GetEmptyTensorShapeSize) { - auto s0 = ge::Symbol(0); - auto s1 = ge::Symbol("s1"); - auto s2 = ge::Symbol("s2"); - auto s3 = ge::Symbol("s3"); - SymbolShape symbol_shape{s0, s1, s2, s3}; - EXPECT_EQ(symbol_shape.GetSymbolShapeSize(), ge::sym::kSymbolZero); -} - -TEST_F(SymbolShapeUT, IsScalarOk) { - SymbolShape symbol_shape1; - auto s0 = ge::Symbol(0); - auto s1 = ge::Symbol("s1"); - auto s2 = ge::Symbol("s2"); - auto s3 = ge::Symbol("s3"); - SymbolShape symbol_shape2{s0, s1, s2, s3}; - EXPECT_TRUE(symbol_shape1.IsScalar()); - EXPECT_FALSE(symbol_shape2.IsScalar()); -} - -TEST_F(SymbolShapeUT, GetDimNumOk) { - SymbolShape symbol_shape1; - auto s0 = ge::Symbol(0); - auto s1 = ge::Symbol("s1"); - auto s2 = ge::Symbol("s2"); - auto s3 = ge::Symbol("s3"); - SymbolShape symbol_shape2{s0, s1, s2, s3}; - EXPECT_EQ(symbol_shape1.GetDimNum(), 0); - EXPECT_EQ(symbol_shape2.GetDimNum(), 4); -} - -TEST_F(SymbolShapeUT, GetDimOk) { - SymbolShape symbol_shape1; - auto s0 = ge::Symbol("s0"); - auto s1 = ge::Symbol("s1"); - auto s2 = ge::Symbol("s2"); - auto s3 = ge::Symbol("s3"); - SymbolShape symbol_shape2{s0, s1, s2, s3}; - - EXPECT_EQ(symbol_shape1.GetDims().size(), 0); - - EXPECT_EQ(symbol_shape2.GetDims()[0], s0); - EXPECT_EQ(symbol_shape2.GetDims()[1], s1); - EXPECT_EQ(symbol_shape2.GetDims()[2], s2); - EXPECT_EQ(symbol_shape2.GetDims()[3], s3); - - EXPECT_EQ(symbol_shape2.MutableDims()[0], s0); - EXPECT_EQ(symbol_shape2.MutableDims()[1], s1); - EXPECT_EQ(symbol_shape2.MutableDims()[2], s2); - EXPECT_EQ(symbol_shape2.MutableDims()[3], s3); -} - -TEST_F(SymbolShapeUT, ModifyDimOk) { - SymbolShape symbol_shape1; - auto s0 = ge::Symbol("s0"); - auto s1 = ge::Symbol("s1"); - auto s2 = ge::Symbol("s2"); - auto s3 = ge::Symbol("s3"); - SymbolShape symbol_shape2{s0, s1, s2, s3}; - - symbol_shape1.MutableDims().emplace_back(s0); - symbol_shape1.MutableDims().emplace_back(s0); - EXPECT_EQ(symbol_shape1.GetDims()[0], s0); - EXPECT_EQ(symbol_shape1.GetDims()[1], s0); - - symbol_shape2.MutableDims()[0] = s1; - symbol_shape2.MutableDims()[1] = s2; - EXPECT_EQ(symbol_shape2.GetDims()[0], s1); - EXPECT_EQ(symbol_shape2.GetDims()[1], s2); - EXPECT_EQ(symbol_shape2.GetDims()[2], s2); - EXPECT_EQ(symbol_shape2.GetDims()[3], s3); -} - -TEST_F(SymbolShapeUT, SetGetDimOk) { - auto one = ge::Symbol(1); - SymbolShape s{one}; - EXPECT_EQ(s.GetDims()[0], one); - auto s10 = ge::Symbol(10); - s.MutableDims()[0] = s10; - EXPECT_EQ(s.GetDims()[0], ge::Symbol(10)); -} - -TEST_F(SymbolShapeUT, AppendDimOk) { - auto s0 = ge::Symbol("s0"); - auto s1 = ge::Symbol("s1"); - auto s2 = ge::Symbol(10); - SymbolShape symbol_shape1{s0}; - symbol_shape1.AppendDim(s1).AppendDim(s2); - SymbolShape expect_s{s0, s1, s2}; - EXPECT_EQ(symbol_shape1, expect_s); -} - -TEST_F(SymbolShapeUT, AppendDimOutOfBounds) { - auto s0 = ge::Symbol("s0"); - SymbolShape s{s0, s0, s0, s0, s0, s0, s0, s0, s0, - s0, s0, s0, s0, s0, s0, s0, s0, s0, - s0, s0, s0, s0, s0, s0, s0}; - SymbolShape expect_s{s0, s0, s0, s0, s0, s0, s0, s0, s0, - s0, s0, s0, s0, s0, s0, s0, s0, s0, - s0, s0, s0, s0, s0, s0, s0, s0}; - s.AppendDim(s0); - EXPECT_EQ(s, expect_s); -} - -TEST_F(SymbolShapeUT, SetScalar) { - auto s0 = ge::Symbol("s0"); - auto s1 = ge::Symbol("s1"); - auto s2 = ge::Symbol("s2"); - auto s3 = ge::Symbol("s3"); - SymbolShape symbol_shape2{s0, s1, s2, s3}; - EXPECT_EQ(symbol_shape2.GetDimNum(), 4); - symbol_shape2.SetScalar(); - EXPECT_EQ(symbol_shape2.GetDimNum(), 0); -} - -TEST_F(SymbolShapeUT, CopyConstruct) { - auto s0 = ge::Symbol("s0"); - auto s1 = ge::Symbol("s1"); - auto s2 = ge::Symbol("s2"); - auto s3 = ge::Symbol("s2"); - const SymbolShape symbol_shape1{s0, s1, s2, s3}; - const SymbolShape symbol_shape_copy(symbol_shape1); - EXPECT_EQ(symbol_shape_copy.GetDimNum(), 4); -} - -TEST_F(SymbolShapeUT, CopyAssign) { - auto s0 = ge::Symbol("s0"); - auto s1 = ge::Symbol("s1"); - auto s2 = ge::Symbol("s2"); - auto s3 = ge::Symbol("s3"); - auto s4 = ge::Symbol(128); - SymbolShape symbol_shape1{s0, s1, s2, s3}; - SymbolShape symbol_shape_copy(symbol_shape1); - - EXPECT_EQ(symbol_shape_copy.GetDimNum(), 4); - symbol_shape_copy = symbol_shape1; - EXPECT_EQ(symbol_shape_copy.GetDimNum(), 4); - - SymbolShape a{s4, s3, s2, s1}; - SymbolShape a_copy{s0, s1, s2, s3, s4}; - EXPECT_EQ(a.GetDimNum(), 4); - a = a_copy; - EXPECT_EQ(a.GetDimNum(), 5); - EXPECT_EQ(a.GetDims()[4], s4); - - std::vector res{s0, s1, s2, s3, s4}; - EXPECT_EQ(a.GetDims().size(), 5); - EXPECT_EQ(a.GetDims().at(0), s0); - EXPECT_EQ(a.GetDims().at(2), s2); - EXPECT_EQ(a.GetDims().at(4), s4); - a.MutableDims()[0] = s1; - EXPECT_EQ(a.GetDims().at(0), s1); - a.MutableDims()[2] = s1; - EXPECT_EQ(a.GetDims().at(2), s1); -} - -// SymbolShape的基础测试 -TEST_F(SymbolShapeUT, SymbolShapeTest) { - auto s0 = ge::Symbol("s0"); - auto s1 = ge::Symbol("s1"); - auto s2 = ge::Symbol("s2"); - auto s3 = ge::Symbol("s3"); - - SymbolShape symbol_shape({s0, s1, s2, s3}); - EXPECT_EQ(symbol_shape.GetDimNum(), 4); - EXPECT_EQ(symbol_shape.GetDim(0), s0); - EXPECT_EQ(symbol_shape.GetDim(1), s1); - symbol_shape.MutableDim(0) = s2; - EXPECT_EQ(symbol_shape.GetDim(0), s2); - symbol_shape.AppendDim(s3); - EXPECT_EQ(symbol_shape.GetDimNum(), 5); - EXPECT_EQ(symbol_shape.GetSymbolShapeSize(), s1 * s2 * s2 * s3 * s3); - EXPECT_EQ(symbol_shape.IsScalar(), false); - symbol_shape.Clear(); - EXPECT_EQ(symbol_shape.GetDimNum(), 0); - - // scalar测试 - SymbolShape symbol_shape_scalar; - EXPECT_EQ(symbol_shape_scalar.IsScalar(), true); - EXPECT_EQ(symbol_shape_scalar.GetDimNum(), 0); - EXPECT_EQ(symbol_shape_scalar.GetSymbolShapeSize(), 1); - - // Mutabe测试 - SymbolShape symbol_shape2({s0, s1, s2, s3}); - symbol_shape2.MutableDim(0) = s1; - EXPECT_EQ(symbol_shape2.GetDim(0), s1); - SymbolShape symbol_shape2_res({s1, s1, s2, s3}); - EXPECT_EQ(symbol_shape2, symbol_shape2_res); - - symbol_shape2.MutableDims() = {s2, s2, s2, s2}; - SymbolShape symbol_shape2_res2({s2, s2, s2, s2}); - EXPECT_EQ(symbol_shape2, symbol_shape2_res2); - EXPECT_NE(symbol_shape2, symbol_shape2_res); - - // 等于不等于测试 - SymbolShape symbol_shape3({s0, s1, s2}); - SymbolShape symbol_shape4({s0, s1, s2, s3}); - EXPECT_EQ(symbol_shape3 == symbol_shape3, true); - EXPECT_EQ(symbol_shape3 == symbol_shape4, false); - EXPECT_EQ(symbol_shape3 != symbol_shape3, false); - EXPECT_EQ(symbol_shape3 != symbol_shape4, true); - - symbol_shape3.AppendDim(s3); - EXPECT_EQ(symbol_shape3 == symbol_shape4, true); - EXPECT_EQ(symbol_shape3 != symbol_shape4, false); -} - -// SymbolShape的拷贝构造、移动构造、拷贝赋值、移动赋值测试 -TEST_F(SymbolShapeUT, SymbolShapeCopyTest) { - auto s0 = ge::Symbol("s0"); - auto s1 = ge::Symbol("s1"); - auto s2 = ge::Symbol("s2"); - auto s3 = ge::Symbol("s3"); - - SymbolShape symbol_shape({s0, s1, s2, s3}); - - SymbolShape symbol_shape_copy(symbol_shape); - EXPECT_EQ(symbol_shape_copy.GetDimNum(), 4); - EXPECT_EQ(symbol_shape_copy.GetDim(0), s0); - EXPECT_EQ(symbol_shape_copy.GetDim(1), s1); - - SymbolShape symbol_shape_move(std::move(symbol_shape_copy)); - EXPECT_EQ(symbol_shape_move.GetDimNum(), 4); - EXPECT_EQ(symbol_shape_move.GetDim(0), s0); - EXPECT_EQ(symbol_shape_move.GetDim(1), s1); - // 移动赋值后,symbol_shape_move的内容没了 - EXPECT_EQ(symbol_shape_copy.GetDimNum(), 0); - - symbol_shape_copy = symbol_shape_move; - EXPECT_EQ(symbol_shape_copy.GetDimNum(), 4); - EXPECT_EQ(symbol_shape_copy.GetDim(0), s0); - EXPECT_EQ(symbol_shape_copy.GetDim(1), s1); - // 拷贝赋值后,symbol_shape_move的内容不变 - EXPECT_EQ(symbol_shape_move.GetDimNum(), 4); - - symbol_shape_move = std::move(symbol_shape_copy); - EXPECT_EQ(symbol_shape_move.GetDimNum(), 4); - EXPECT_EQ(symbol_shape_move.GetDim(0), s0); - EXPECT_EQ(symbol_shape_move.GetDim(1), s1); -} - -TEST_F(SymbolShapeUT, SymbolShapeSizeCacheTest) { - SymbolShape symbol_shape; - EXPECT_EQ(symbol_shape.GetSymbolShapeSize(), 1); - - auto s0 = ge::Symbol("s0"); - auto s1 = ge::Symbol("s1"); - auto s2 = ge::Symbol("s2"); - auto s3 = ge::Symbol("s3"); - - SymbolShape symbol_shape1({s0, s1, s2, s3}); - auto size = symbol_shape1.GetSymbolShapeSize(); - EXPECT_EQ(size, s0 * s1 * s2 * s3); - - symbol_shape1.MutableDim(0) = s1; - EXPECT_EQ(symbol_shape1.GetSymbolShapeSize(), s1 * s1 * s2 * s3); - - symbol_shape1.MutableDims() = {s2, s2, s2, s2}; - EXPECT_EQ(symbol_shape1.GetSymbolShapeSize(), s2 * s2 * s2 * s2); - - symbol_shape1.AppendDim(s3); - EXPECT_EQ(symbol_shape1.GetSymbolShapeSize(), s2 * s2 * s2 * s2 * s3); -} -} // namespace gert diff --git a/tests/ut/exe_graph/tensor_data_unittest.cc b/tests/ut/exe_graph/tensor_data_unittest.cc deleted file mode 100644 index aef54ef8f23e259679b0f28ddef8e38d0aa8cecb..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/tensor_data_unittest.cc +++ /dev/null @@ -1,396 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/runtime/tensor_data.h" -#ifdef ONLY_COMPILE_OPEN_SRC -#include "exe_graph/runtime/tensor_data_utils.h" -#endif -#include -namespace gert { -namespace { -template -class ManagerStub { - public: - static ge::graphStatus Success(TensorAddress addr, TensorOperateType operate_type, void **out) { - operate_count[operate_type]++; - if (operate_type == kGetTensorAddress) { - *out = reinterpret_cast(N); - } - return ge::GRAPH_SUCCESS; - } - static ge::graphStatus Failed(TensorAddress addr, TensorOperateType operate_type, void **out) { - return ge::GRAPH_FAILED; - } - static ge::graphStatus FreeFailed(TensorAddress addr, TensorOperateType operate_type, void **out) { - operate_count[operate_type]++; - if (operate_type == kFreeTensor) { - return ge::GRAPH_FAILED; - } - return Success(addr, operate_type, out); - } - static void Clear() { - memset(operate_count, 0, sizeof(operate_count)); // memse函数misra告警屏蔽 - } - static size_t operate_count[kTensorOperateType]; -}; - -template -size_t ManagerStub::operate_count[kTensorOperateType] = {0}; -} // namespace - -class TensorDataUT : public testing::Test {}; - -TEST_F(TensorDataUT, TensorDataWithMangerSuccess) { - ManagerStub<8>::Clear(); - - auto addr = reinterpret_cast(0x16); - { - TensorData data(addr, ManagerStub<8>::Success); - EXPECT_EQ(reinterpret_cast(data.GetAddr()), 8); - EXPECT_EQ(data.Free(), ge::GRAPH_SUCCESS); - EXPECT_EQ(ManagerStub<8>::operate_count[kFreeTensor], 1); - - EXPECT_EQ(data.GetAddr(), nullptr); - data.SetAddr(addr, nullptr); - EXPECT_EQ(reinterpret_cast(data.GetAddr()), 0x16); - data.SetAddr(addr, ManagerStub<8>::Failed); - EXPECT_EQ(data.GetAddr(), nullptr); - } - EXPECT_EQ(ManagerStub<8>::operate_count[kFreeTensor], 1); -} - -TEST_F(TensorDataUT, TransferOwner) { - ManagerStub<8>::Clear(); - - auto addr = reinterpret_cast(0x16); - { - TensorData td0(addr, nullptr); - TensorData td1(addr, ManagerStub<8>::Success); - td0 = std::move(td1); - } - EXPECT_EQ(ManagerStub<8>::operate_count[kPlusShareCount], 0); - EXPECT_EQ(ManagerStub<8>::operate_count[kFreeTensor], 1); -} - -TEST_F(TensorDataUT, TransferOnwerFreeOld) { - ManagerStub<8>::Clear(); - ManagerStub<18>::Clear(); - - auto addr = reinterpret_cast(0x16); - { - TensorData td0(addr, ManagerStub<18>::Success); - TensorData td1(addr, ManagerStub<8>::Success); - EXPECT_EQ(ManagerStub<18>::operate_count[kFreeTensor], 0); - td0 = std::move(td1); - EXPECT_EQ(ManagerStub<18>::operate_count[kFreeTensor], 1); - - } - EXPECT_EQ(ManagerStub<8>::operate_count[kPlusShareCount], 0); - EXPECT_EQ(ManagerStub<8>::operate_count[kFreeTensor], 1); -} - -TEST_F(TensorDataUT, ConstructRightValue) { - ManagerStub<8>::Clear(); - - auto addr = reinterpret_cast(0x16); - { - TensorData td0(addr, ManagerStub<8>::Success); - TensorData td1(std::move(td0)); - } - EXPECT_EQ(ManagerStub<8>::operate_count[kPlusShareCount], 0); - EXPECT_EQ(ManagerStub<8>::operate_count[kFreeTensor], 1); -} - -TEST_F(TensorDataUT, FreeByHand) { - ManagerStub<8>::Clear(); - - auto addr = reinterpret_cast(0x16); - { - TensorData td0(addr, ManagerStub<8>::Success); - EXPECT_EQ(td0.Free(), ge::GRAPH_SUCCESS); - EXPECT_EQ(ManagerStub<8>::operate_count[kFreeTensor], 1); - } - EXPECT_EQ(ManagerStub<8>::operate_count[kPlusShareCount], 0); - EXPECT_EQ(ManagerStub<8>::operate_count[kFreeTensor], 1); -} - -TEST_F(TensorDataUT, TensorDataWithMangerFreeSuccess) { - ManagerStub<8>::Clear(); - { - auto addr = reinterpret_cast(0x16); - TensorData data(addr, ManagerStub<8>::Success); - EXPECT_EQ(reinterpret_cast(data.GetAddr()), 8); - EXPECT_EQ(data.Free(), ge::GRAPH_SUCCESS); - } - EXPECT_EQ(ManagerStub<8>::operate_count[kFreeTensor], 1); -} - -TEST_F(TensorDataUT, ShareTensorDataOk) { - TensorData td; - td.SetAddr(reinterpret_cast(10), nullptr); - - TensorData td1; - td1.SetAddr(reinterpret_cast(11), nullptr); - ASSERT_EQ(td1.GetAddr(), reinterpret_cast(11)); - - td1.ShareFrom(td); - EXPECT_EQ(td1.GetAddr(), reinterpret_cast(10)); -} - -TEST_F(TensorDataUT, ReleaseBeforeShareTensorData) { - ManagerStub<8>::Clear(); - - TensorData td; - td.SetAddr(reinterpret_cast(10), nullptr); - - TensorData td1; - td1.SetAddr(reinterpret_cast(11), ManagerStub<8>::Success); - ASSERT_EQ(td1.GetAddr(), reinterpret_cast(8)); - - ASSERT_EQ(ManagerStub<8>::operate_count[kFreeTensor], 0); - td1.ShareFrom(td); - EXPECT_EQ(td1.GetAddr(), reinterpret_cast(10)); - EXPECT_EQ(ManagerStub<8>::operate_count[kFreeTensor], 1); -} - -TEST_F(TensorDataUT, ShareManagedTensorData) { - ManagerStub<8>::Clear(); - - { - TensorData td; - td.SetAddr(reinterpret_cast(10), ManagerStub<8>::Success); - - ASSERT_EQ(ManagerStub<8>::operate_count[kPlusShareCount], 0); - ASSERT_EQ(ManagerStub<8>::operate_count[kFreeTensor], 0); - { - TensorData td1; - td1.SetAddr(reinterpret_cast(11), nullptr); - - td1.ShareFrom(td); - ASSERT_EQ(ManagerStub<8>::operate_count[kPlusShareCount], 1); - EXPECT_EQ(td1.GetAddr(), reinterpret_cast(8)); - } - ASSERT_EQ(ManagerStub<8>::operate_count[kPlusShareCount], 1); - ASSERT_EQ(ManagerStub<8>::operate_count[kFreeTensor], 1); - } - ASSERT_EQ(ManagerStub<8>::operate_count[kFreeTensor], 2); -} - -TEST_F(TensorDataUT, ReleaseBeforeShareManagedTensorData) { - ManagerStub<8>::Clear(); - ManagerStub<18>::Clear(); - - TensorData td; - EXPECT_EQ(td.SetAddr(reinterpret_cast(10), ManagerStub<8>::Success), ge::GRAPH_SUCCESS); - TensorData td1; - EXPECT_EQ(td1.SetAddr(reinterpret_cast(11), ManagerStub<18>::Success), ge::GRAPH_SUCCESS); - - ASSERT_EQ(ManagerStub<8>::operate_count[kPlusShareCount], 0); - ASSERT_EQ(ManagerStub<18>::operate_count[kFreeTensor], 0); - td1.ShareFrom(td); - ASSERT_EQ(ManagerStub<8>::operate_count[kPlusShareCount], 1); - EXPECT_EQ(ManagerStub<18>::operate_count[kFreeTensor], 1); - EXPECT_EQ(td1.GetAddr(), reinterpret_cast(8)); - - ASSERT_EQ(ManagerStub<8>::operate_count[kFreeTensor], 0); - EXPECT_EQ(td1.Free(), ge::GRAPH_SUCCESS); - EXPECT_EQ(ManagerStub<8>::operate_count[kFreeTensor], 1); -} - -TEST_F(TensorDataUT, ReleaseBeforeSet) { - ManagerStub<8>::Clear(); - - TensorData td; - EXPECT_EQ(td.SetAddr(reinterpret_cast(10), ManagerStub<8>::Success), ge::GRAPH_SUCCESS); - - ASSERT_EQ(ManagerStub<8>::operate_count[kFreeTensor], 0); - EXPECT_EQ(td.SetAddr(reinterpret_cast(100), nullptr), ge::GRAPH_SUCCESS); - EXPECT_EQ(ManagerStub<8>::operate_count[kFreeTensor], 1); -} - -TEST_F(TensorDataUT, ReleaseFailedWhenSet) { - ManagerStub<8>::Clear(); - - TensorData td; - EXPECT_EQ(td.SetAddr(reinterpret_cast(10), ManagerStub<8>::FreeFailed), ge::GRAPH_SUCCESS); - ASSERT_EQ(td.GetAddr(), reinterpret_cast(8)); - - ASSERT_EQ(ManagerStub<8>::operate_count[kFreeTensor], 0); - EXPECT_NE(td.SetAddr(reinterpret_cast(100), nullptr), ge::GRAPH_SUCCESS); - EXPECT_EQ(ManagerStub<8>::operate_count[kFreeTensor], 1); - EXPECT_EQ(td.GetAddr(), reinterpret_cast(8)); -} - -TEST_F(TensorDataUT, ReleaseFailedWhenShare) { - ManagerStub<8>::Clear(); - - TensorData td; - td.SetAddr(reinterpret_cast(10), nullptr); - - TensorData td1; - td1.SetAddr(reinterpret_cast(11), ManagerStub<8>::FreeFailed); - ASSERT_EQ(td1.GetAddr(), reinterpret_cast(8)); - - ASSERT_EQ(ManagerStub<8>::operate_count[kFreeTensor], 0); - EXPECT_NE(td1.ShareFrom(td), ge::GRAPH_SUCCESS); - EXPECT_EQ(ManagerStub<8>::operate_count[kFreeTensor], 1); - EXPECT_EQ(td1.GetAddr(), reinterpret_cast(8)); -} - -TEST_F(TensorDataUT, ShareFromSelf) { - ManagerStub<8>::Clear(); - - auto addr1 = reinterpret_cast(0x16); - auto addr2 = reinterpret_cast(0x26); - { - TensorData td0(addr1, nullptr); - TensorData td1(addr2, ManagerStub<8>::Success); - EXPECT_FALSE(td0.IsSharedWith(td1)); - td0.ShareFrom(td1); - EXPECT_TRUE(td0.IsSharedWith(td0)); - EXPECT_EQ(td0.GetAddr(), td1.GetAddr()); - EXPECT_NE(td0.GetAddr(), nullptr); - - td0.ShareFrom(td0); - EXPECT_EQ(td0.GetAddr(), td1.GetAddr()); - EXPECT_NE(td0.GetAddr(), nullptr); - } - EXPECT_EQ(ManagerStub<8>::operate_count[kPlusShareCount], 1); - EXPECT_EQ(ManagerStub<8>::operate_count[kFreeTensor], 2); -} - -TEST_F(TensorDataUT, InitValue) { - TensorData td(reinterpret_cast(10), nullptr); - EXPECT_EQ(td.GetPlacement(), kTensorPlacementEnd); - EXPECT_EQ(td.GetSize(), 0U); - - td.SetPlacement(kOnHost); - EXPECT_EQ(td.GetPlacement(), kOnHost); - - td.SetSize(10); - EXPECT_EQ(td.GetSize(), 10); - - // test move construct - TensorData td1(std::move(td)); - EXPECT_EQ(td1.GetPlacement(), kOnHost); - EXPECT_EQ(td1.GetSize(), 10); - - // test operator= - TensorData td2 = std::move(td1); - EXPECT_EQ(td2.GetPlacement(), kOnHost); - EXPECT_EQ(td2.GetSize(), 10); - - EXPECT_EQ(td.GetPlacement(), kTensorPlacementEnd); - EXPECT_EQ(td.GetSize(), 0); -} - -TEST_F(TensorDataUT, GetPlacementStr_Success) { - EXPECT_STREQ(GetPlacementStr(kOnHost), "HostDDR"); - EXPECT_STREQ(GetPlacementStr(kOnDeviceHbm), "DeviceHbm"); - EXPECT_STREQ(GetPlacementStr(kFollowing), "HostDDR"); - EXPECT_STREQ(GetPlacementStr(kOnDeviceP2p), "DeviceP2p"); - EXPECT_STREQ(GetPlacementStr(kTensorPlacementEnd), "Unknown"); -} - -TEST_F(TensorDataUT, IsPlacementSrcToDstNeedCopy_Success) { - EXPECT_FALSE(IsPlacementSrcToDstNeedCopy(kOnDeviceHbm, kOnDeviceHbm)); - EXPECT_TRUE(IsPlacementSrcToDstNeedCopy(kOnDeviceHbm, kOnHost)); - EXPECT_TRUE(IsPlacementSrcToDstNeedCopy(kOnDeviceHbm, kFollowing)); - EXPECT_TRUE(IsPlacementSrcToDstNeedCopy(kOnDeviceHbm, kOnDeviceP2p)); - EXPECT_TRUE(IsPlacementSrcToDstNeedCopy(kOnDeviceHbm, kTensorPlacementEnd)); - - EXPECT_TRUE(IsPlacementSrcToDstNeedCopy(kOnHost, kOnDeviceHbm)); - EXPECT_FALSE(IsPlacementSrcToDstNeedCopy(kOnHost, kOnHost)); - EXPECT_FALSE(IsPlacementSrcToDstNeedCopy(kOnHost, kFollowing)); - EXPECT_TRUE(IsPlacementSrcToDstNeedCopy(kOnHost, kOnDeviceP2p)); - EXPECT_TRUE(IsPlacementSrcToDstNeedCopy(kOnHost, kTensorPlacementEnd)); - - EXPECT_TRUE(IsPlacementSrcToDstNeedCopy(kFollowing, kOnDeviceHbm)); - EXPECT_FALSE(IsPlacementSrcToDstNeedCopy(kFollowing, kOnHost)); - EXPECT_FALSE(IsPlacementSrcToDstNeedCopy(kFollowing, kFollowing)); - EXPECT_TRUE(IsPlacementSrcToDstNeedCopy(kFollowing, kOnDeviceP2p)); - EXPECT_TRUE(IsPlacementSrcToDstNeedCopy(kFollowing, kTensorPlacementEnd)); - - EXPECT_FALSE(IsPlacementSrcToDstNeedCopy(kOnDeviceP2p, kOnDeviceHbm)); - EXPECT_TRUE(IsPlacementSrcToDstNeedCopy(kOnDeviceP2p, kOnHost)); - EXPECT_TRUE(IsPlacementSrcToDstNeedCopy(kOnDeviceP2p, kFollowing)); - EXPECT_FALSE(IsPlacementSrcToDstNeedCopy(kOnDeviceP2p, kOnDeviceP2p)); - EXPECT_TRUE(IsPlacementSrcToDstNeedCopy(kOnDeviceP2p, kTensorPlacementEnd)); - - EXPECT_TRUE(IsPlacementSrcToDstNeedCopy(kTensorPlacementEnd, kOnDeviceHbm)); - EXPECT_TRUE(IsPlacementSrcToDstNeedCopy(kTensorPlacementEnd, kOnHost)); - EXPECT_TRUE(IsPlacementSrcToDstNeedCopy(kTensorPlacementEnd, kFollowing)); - EXPECT_TRUE(IsPlacementSrcToDstNeedCopy(kTensorPlacementEnd, kOnDeviceP2p)); - EXPECT_TRUE(IsPlacementSrcToDstNeedCopy(kTensorPlacementEnd, kTensorPlacementEnd)); -} -TEST_F(TensorDataUT, Release_OwnershipMoved_HasOwnership) { - ManagerStub<8>::Clear(); - - TensorAddress released_addr = nullptr; - TensorAddrManager manager = nullptr; - auto addr = reinterpret_cast(0x16); - { - TensorData data(addr, ManagerStub<8>::Success); - released_addr = data.Release(manager); - EXPECT_EQ(data.GetAddr(), nullptr); - } - ASSERT_EQ(ManagerStub<8>::operate_count[kFreeTensor], 0); - ASSERT_EQ(released_addr, addr); - ASSERT_EQ(manager, ManagerStub<8>::Success); - - ASSERT_EQ(manager(released_addr, kFreeTensor, nullptr), ge::GRAPH_SUCCESS); - ASSERT_EQ(ManagerStub<8>::operate_count[kFreeTensor], 1); -} -TEST_F(TensorDataUT, Release_PointerMoved_NoOwnership) { - ManagerStub<8>::Clear(); - - TensorAddress released_addr = nullptr; - TensorAddrManager manager = nullptr; - auto addr = reinterpret_cast(0x16); - { - TensorData data(addr, nullptr); - released_addr = data.Release(manager); - EXPECT_EQ(data.GetAddr(), nullptr); - } - ASSERT_EQ(ManagerStub<8>::operate_count[kFreeTensor], 0); - ASSERT_EQ(released_addr, addr); - ASSERT_EQ(manager, nullptr); -} -TEST_F(TensorDataUT, Release_ReturnNullptr_HasNoData) { - TensorAddrManager manager = nullptr; - TensorData data; - ASSERT_EQ(data.Release(manager), nullptr); - ASSERT_EQ(manager, nullptr); -} -TEST_F(TensorDataUT, Release_FunctionCorrect_AfterRelease) { - ManagerStub<1>::Clear(); - ManagerStub<2>::Clear(); - - TensorAddress released_addr = nullptr; - TensorAddrManager manager = nullptr; - auto addr1 = reinterpret_cast(0x16); - auto addr2 = reinterpret_cast(0x32); - - { - TensorData data1(addr1, ManagerStub<1>::Success); - TensorData data2(addr2, ManagerStub<2>::Success); - - released_addr = data1.Release(manager); - data1.ShareFrom(data2); - } - ASSERT_EQ(ManagerStub<1>::operate_count[kFreeTensor], 0); - ASSERT_EQ(ManagerStub<2>::operate_count[kFreeTensor], 2); - - ASSERT_EQ(released_addr, addr1); - ASSERT_EQ(manager, ManagerStub<1>::Success); - - ASSERT_EQ(manager(released_addr, kFreeTensor, nullptr), ge::GRAPH_SUCCESS); - ASSERT_EQ(ManagerStub<1>::operate_count[kFreeTensor], 1); -} -} // namespace gert diff --git a/tests/ut/exe_graph/tensor_unittest.cc b/tests/ut/exe_graph/tensor_unittest.cc deleted file mode 100644 index 84cee4c084f8dd4a9722fdc4c13d0b6f2eff080f..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/tensor_unittest.cc +++ /dev/null @@ -1,174 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/runtime/tensor.h" -#include "graph/ge_tensor.h" -#include -namespace gert { -class TensorUT : public testing::Test {}; -TEST_F(TensorUT, ConstructOk) { - Tensor tensor{{{8, 3, 224, 224}, {16, 3, 224, 224}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, //dt - nullptr}; - const Tensor &t = tensor; - - EXPECT_EQ(t.GetOriginShape(), Shape({8, 3, 224, 224})); - EXPECT_EQ(t.GetStorageShape(), Shape({16, 3, 224, 224})); - - EXPECT_EQ(t.GetOriginFormat(), ge::FORMAT_ND); - EXPECT_EQ(t.GetStorageFormat(), ge::FORMAT_FRACTAL_NZ); - EXPECT_EQ(t.GetExpandDimsType(), ExpandDimsType{}); - - EXPECT_EQ(t.GetPlacement(), kOnDeviceHbm); - EXPECT_EQ(t.GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(t.GetAddr(), nullptr); - EXPECT_EQ(t.GetData(), nullptr); -} - -TEST_F(TensorUT, GetDataAddrFollowingOk) { - Tensor tensor{{{8, 3, 224, 224}, {16, 3, 224, 224}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kFollowing, // placement - ge::DT_FLOAT16, //dt - nullptr}; - const Tensor &t = tensor; - - EXPECT_EQ(t.GetOriginShape(), Shape({8, 3, 224, 224})); - EXPECT_EQ(t.GetStorageShape(), Shape({16, 3, 224, 224})); - - EXPECT_EQ(t.GetOriginFormat(), ge::FORMAT_ND); - EXPECT_EQ(t.GetStorageFormat(), ge::FORMAT_FRACTAL_NZ); - EXPECT_EQ(t.GetExpandDimsType(), ExpandDimsType{}); - - EXPECT_EQ(t.GetPlacement(), kFollowing); - EXPECT_EQ(t.GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(t.GetAddr(), &tensor + 1); - EXPECT_EQ(t.GetData(), reinterpret_cast(&tensor + 1)); -} - -TEST_F(TensorUT, GetCopiedDataAddrFollowingOk) { - Tensor tensor{{{8, 3, 224, 224}, {16, 3, 224, 224}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kFollowing, // placement - ge::DT_FLOAT16, //dt - nullptr}; - Tensor t = {{}, {}, {}, {}, nullptr}; - memcpy(static_cast(&t), static_cast(&tensor), sizeof(tensor)); - - EXPECT_EQ(t.GetOriginShape(), Shape({8, 3, 224, 224})); - EXPECT_EQ(t.GetStorageShape(), Shape({16, 3, 224, 224})); - - EXPECT_EQ(t.GetOriginFormat(), ge::FORMAT_ND); - EXPECT_EQ(t.GetStorageFormat(), ge::FORMAT_FRACTAL_NZ); - EXPECT_EQ(t.GetExpandDimsType(), ExpandDimsType{}); - - EXPECT_EQ(t.GetPlacement(), kFollowing); - EXPECT_EQ(t.GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(t.GetAddr(), &t + 1); - EXPECT_EQ(t.GetData(), reinterpret_cast(&t + 1)); -} - -TEST_F(TensorUT, SetGetShapeOk) { - Tensor t = {{}, {}, {}, {}, nullptr}; - const Tensor &ct = t; - t.MutableOriginShape() = Shape{8,3,224,224}; - t.MutableStorageShape() = Shape{8,1,224,224,16}; - EXPECT_EQ(t.GetOriginShape(), Shape({8,3,224,224})); - EXPECT_EQ(t.GetStorageShape(), Shape({8,1,224,224,16})); - EXPECT_EQ(ct.GetOriginShape(), Shape({8,3,224,224})); - EXPECT_EQ(ct.GetStorageShape(), Shape({8,1,224,224,16})); -} - -TEST_F(TensorUT, SetGetFormatOk) { - Tensor tensor = {{}, {}, {}, {}, nullptr}; - const Tensor &t = tensor; - tensor.SetOriginFormat(ge::FORMAT_NHWC); - tensor.SetStorageFormat(ge::FORMAT_NC1HWC0); - - EXPECT_EQ(t.GetOriginFormat(), ge::FORMAT_NHWC); - EXPECT_EQ(t.GetStorageFormat(), ge::FORMAT_NC1HWC0); - - EXPECT_EQ(t.GetFormat().GetOriginFormat(), ge::FORMAT_NHWC); - EXPECT_EQ(t.GetFormat().GetStorageFormat(), ge::FORMAT_NC1HWC0); -} - -TEST_F(TensorUT, SetGetPlacementOk) { - Tensor t = {{}, {}, {}, {}, nullptr}; - const Tensor &ct = t; - t.SetPlacement(kOnHost); - EXPECT_EQ(t.GetPlacement(), kOnHost); - EXPECT_EQ(ct.GetPlacement(), kOnHost); -} - -TEST_F(TensorUT, SetGetDataTypeOk) { - Tensor t = {{}, {}, {}, {}, nullptr}; - const Tensor &ct = t; - t.SetDataType(ge::DT_FLOAT16); - EXPECT_EQ(t.GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(ct.GetDataType(), ge::DT_FLOAT16); -} - -TEST_F(TensorUT, SetGetAddrOk) { - Tensor t = {{}, {}, {}, {}, nullptr}; - const Tensor &ct = t; - void *a = &t; - - TensorData td(a, nullptr); - t.SetData(std::move(td)); - EXPECT_EQ(t.GetAddr(), a); - EXPECT_EQ(ct.GetAddr(), a); - - EXPECT_EQ(t.GetData(), &t); - EXPECT_EQ(ct.GetData(), &t); -} - -TEST_F(TensorUT, GetTensorDataOk) { - Tensor t = {{}, {}, {}, {}, nullptr}; - auto a = reinterpret_cast(10); - t.MutableTensorData() = TensorData{a, nullptr}; - EXPECT_EQ(t.GetAddr(), a); -} - -TEST_F(TensorUT, GetTensorPlacementOk) { - Tensor t = {{}, {}, kOnHost, {}, nullptr}; - EXPECT_EQ(t.GetPlacement(), kOnHost); -} - -TEST_F(TensorUT, GetTensorSizeOk) { - StorageShape sh({1, 2, 3}, {1, 2, 3}); - Tensor t = {sh, {}, {}, ge::DT_FLOAT, nullptr}; - EXPECT_EQ(t.GetSize(), 24); -} - -TEST_F(TensorUT, CreateFollowingCheckTotalSize) { - size_t total_size; - auto ptr = Tensor::CreateFollowing(32, ge::DataType::DT_INT8, total_size); - EXPECT_NE(ptr, nullptr); - auto tensor = reinterpret_cast(ptr.get()); - EXPECT_EQ(tensor->GetSize(), 32); -} - -TEST_F(TensorUT, CreateFollowingWithTensorCheckTotalSize) { - size_t total_size; - auto ptr = Tensor::CreateFollowing(ge::DT_FLOAT, 8U, total_size); - EXPECT_NE(ptr, nullptr); - auto tensor = reinterpret_cast(ptr.get()); - EXPECT_EQ(tensor->GetSize(), 8); -} - -TEST_F(TensorUT, CreateFollowingWithTensorUseStringTypeCheckTotalSize) { - size_t total_size; - auto ptr = Tensor::CreateFollowing(ge::DT_STRING, 160U, total_size); - EXPECT_NE(ptr, nullptr); - auto tensor = reinterpret_cast(ptr.get()); - EXPECT_EQ(tensor->GetSize(), 160); -} -} // namespace gert diff --git a/tests/ut/exe_graph/tiling_context_builder_unittest.cc b/tests/ut/exe_graph/tiling_context_builder_unittest.cc deleted file mode 100644 index 8c9a617970b1d54bac5f543b6a0921ba723b7f09..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/tiling_context_builder_unittest.cc +++ /dev/null @@ -1,551 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#define protected public -#include "base/registry/op_impl_space_registry_v2.h" -#include "graph/utils/math_util.h" -#include "exe_graph/lowering/tiling_context_builder.h" -#include "exe_graph/lowering/device_tiling_context_builder.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/node_utils.h" -#include "graph/utils/graph_utils.h" -#include "exe_graph/runtime/atomic_clean_tiling_context.h" -#include "exe_graph/lowering/value_holder.h" -#include "platform/platform_infos_def.h" -#include "common/ge_common/util.h" -#include "register/op_impl_registry.h" -#include "register/op_impl_registry_base.h" -#include "faker/node_faker.h" -#include "faker/space_registry_faker.h" -#include "graph/debug/ge_attr_define.h" -#include "common/checker.h" -#undef protected - -namespace gert { -class TilingContextBuilderUT : public testing::Test {}; - -namespace { -IMPL_OP(DDIT02).InputsDataDependency({0, 2}); - -ge::Status AddDataNodeForAtomic(ge::ComputeGraphPtr &graph, ge::NodePtr &clean_node, size_t output_size) { - // add data node for workspace - auto workspace_data_op_desc = std::make_shared(clean_node->GetName() + "_Data_0", "Data"); - GE_CHECK_NOTNULL(workspace_data_op_desc); - if (workspace_data_op_desc->AddOutputDesc(ge::GeTensorDesc()) != ge::SUCCESS) { - GELOGE(ge::FAILED, "workspace_data_op_desc add output desc failed"); - return ge::FAILED; - } - auto workspace_data_node = graph->AddNode(workspace_data_op_desc); - GE_CHECK_NOTNULL(workspace_data_node); - auto ret = ge::GraphUtils::AddEdge(workspace_data_node->GetOutDataAnchor(0), clean_node->GetInDataAnchor(0)); - if (ret != ge::SUCCESS) { - GELOGE(ge::FAILED, "add edge between [%s] and [%s] failed", workspace_data_node->GetName().c_str(), - clean_node->GetName().c_str()); - return ge::FAILED; - } - - // add data node for output - for (size_t i = 0U; i < output_size; ++i) { - auto data_op_desc = std::make_shared(clean_node->GetName() + "_Data_" + std::to_string(i + 1), "Data"); - GE_CHECK_NOTNULL(data_op_desc); - if (data_op_desc->AddOutputDesc(ge::GeTensorDesc()) != ge::SUCCESS) { - GELOGE(ge::FAILED, "data_op_desc add output desc failed, i = %zu", i); - return ge::FAILED; - } - auto data_node = graph->AddNode(data_op_desc); - GE_CHECK_NOTNULL(data_node); - ret = ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), clean_node->GetInDataAnchor(i + 1)); - if (ret != ge::SUCCESS) { - GELOGE(ge::FAILED, "add edge between [%s] and [%s] failed", data_node->GetName().c_str(), - clean_node->GetName().c_str()); - return ge::FAILED; - } - } - return ge::SUCCESS; -} - -ge::NodePtr BuildAtomicNode(ge::ComputeGraphPtr &graph) { - std::vector workspace_indexes = {1, 2}; - std::vector outputs_indexes = {0, 2}; - - auto atomic_op_desc = std::make_shared("AtomicClean", "DynamicAtomicAddrClean"); - - atomic_op_desc->AppendIrInput("workspace", ge::kIrInputRequired); - atomic_op_desc->AppendIrInput("output", ge::kIrInputDynamic); - - atomic_op_desc->AddInputDesc("workspace", ge::GeTensorDesc()); - for (size_t i = 0U; i < outputs_indexes.size(); ++i) { - atomic_op_desc->AddInputDesc("output" + std::to_string(i + 1), ge::GeTensorDesc()); - } - if (!ge::AttrUtils::SetListInt(atomic_op_desc, "WorkspaceIndexes", workspace_indexes)) { - return nullptr; - } - auto clean_node = graph->AddNode(atomic_op_desc); - if (clean_node == nullptr) { - GELOGE(ge::FAILED, "add node failed"); - return nullptr; - } - if (AddDataNodeForAtomic(graph, clean_node, outputs_indexes.size()) != ge::SUCCESS) { - GELOGE(ge::FAILED, "add data node for atomic clean node failed, outputs_indexes size = %zu", - outputs_indexes.size()); - return nullptr; - } - return clean_node; -} -} // namespace - -TEST_F(TilingContextBuilderUT, CompileInfoNullptr) { - fe::PlatFormInfos platform_infos; - auto builder = TilingContextBuilder(); - auto node = ComputeNodeFaker().NameAndType("Test", "DDIT02").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - ASSERT_NE(node, nullptr); - node->GetOpDesc()->SetOpInferDepends({"x", "z"}); - auto op = ge::OpDescUtils::CreateOperatorFromNode(node->shared_from_this()); - - ge::graphStatus ret; - auto tiling_context_holder = - builder.CompileInfo(nullptr).PlatformInfo(reinterpret_cast(&platform_infos)).Build(op, ret); - EXPECT_NE(ret, ge::GRAPH_SUCCESS); - EXPECT_NE(tiling_context_holder.context_, nullptr); -} - -TEST_F(TilingContextBuilderUT, PlatformInfoNullptr) { - fe::PlatFormInfos platform_infos; - auto builder = TilingContextBuilder(); - std::string op_compile_info_json = "{}"; - - auto node = ComputeNodeFaker().NameAndType("Test", "DDIT02").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - ASSERT_NE(node, nullptr); - node->GetOpDesc()->SetOpInferDepends({"x", "z"}); - auto op = ge::OpDescUtils::CreateOperatorFromNode(node->shared_from_this()); - - ge::graphStatus ret; - auto tiling_context_holder = builder.CompileInfo(&op_compile_info_json).PlatformInfo(nullptr).Build(op, ret); - EXPECT_NE(ret, ge::GRAPH_SUCCESS); - EXPECT_NE(tiling_context_holder.context_, nullptr); -} - -TEST_F(TilingContextBuilderUT, BuildRTInputTensorsFailed) { - auto node = ComputeNodeFaker().NameAndType("UbNode", "DDIT02").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - node->GetOpDesc()->SetOpInferDepends({"x", "z"}); - - auto graph = std::make_shared("ub_graph"); - auto data0 = ComputeNodeFaker(graph) - .NameAndType("Data0", "Data") - .Attr(ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), 0) - .IoNum(0, 1) - .Build(); - auto data1 = ComputeNodeFaker(graph) - .NameAndType("Data1", "Data") - .Attr(ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), 1) - .IoNum(0, 1) - .Build(); - auto data2 = ComputeNodeFaker(graph) - .NameAndType("Data2", "Data") - .Attr(ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), 2) - .IoNum(0, 1) - .Build(); - auto node2 = ComputeNodeFaker().NameAndType("UbNode2", "DDIT02").IoNum(1, 1).InputNames({"d"}).Build(); - ge::GraphUtils::AddEdge(data0->GetOutDataAnchor(0), node->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(data1->GetOutDataAnchor(0), node->GetInDataAnchor(1)); - ge::GraphUtils::AddEdge(data2->GetOutDataAnchor(0), node->GetInDataAnchor(2)); - graph->SetParentNode(node2); - - // construct op impl registry - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - auto funcs = gert::OpImplRegistry::GetInstance().GetOpImpl("DDIT02"); - registry_holder->AddTypesToImpl("DDIT02", *funcs); - space_registry->AddRegistry(registry_holder); - DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - auto tiling_data = gert::TilingData::CreateCap(1024); - auto workspace_size = gert::ContinuousVector::Create(16); - std::string op_compile_info_json = "{}"; - fe::PlatFormInfos platform_infos; - auto builder = TilingContextBuilder(); - auto op = ge::OpDescUtils::CreateOperatorFromNode(node->shared_from_this()); - ge::graphStatus ret; - auto tiling_context_holder = builder.CompileInfo(const_cast(op_compile_info_json.c_str())) - .PlatformInfo(reinterpret_cast(&platform_infos)) - .TilingData(tiling_data.get()) - .Workspace(reinterpret_cast(workspace_size.get())) - .SpaceRegistry(space_registry) - .Build(op, ret); - EXPECT_EQ(ret, ge::GRAPH_SUCCESS); - EXPECT_NE(tiling_context_holder.context_, nullptr); - DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(nullptr); -} - -TEST_F(TilingContextBuilderUT, BuildRTInputTensorsFailed_UseRegistryV2) { - auto node = ComputeNodeFaker().NameAndType("UbNode", "DDIT02").IoNum(3, 1).InputNames({"x", "y", "z"}).Build(); - node->GetOpDesc()->SetOpInferDepends({"x", "z"}); - - auto graph = std::make_shared("ub_graph"); - auto data0 = ComputeNodeFaker(graph) - .NameAndType("Data0", "Data") - .Attr(ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), 0) - .IoNum(0, 1) - .Build(); - auto data1 = ComputeNodeFaker(graph) - .NameAndType("Data1", "Data") - .Attr(ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), 1) - .IoNum(0, 1) - .Build(); - auto data2 = ComputeNodeFaker(graph) - .NameAndType("Data2", "Data") - .Attr(ge::ATTR_NAME_PARENT_NODE_INDEX.c_str(), 2) - .IoNum(0, 1) - .Build(); - auto node2 = ComputeNodeFaker().NameAndType("UbNode2", "DDIT02").IoNum(1, 1).InputNames({"d"}).Build(); - ge::GraphUtils::AddEdge(data0->GetOutDataAnchor(0), node->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(data1->GetOutDataAnchor(0), node->GetInDataAnchor(1)); - ge::GraphUtils::AddEdge(data2->GetOutDataAnchor(0), node->GetInDataAnchor(2)); - graph->SetParentNode(node2); - - // construct op impl registry - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - auto funcs = gert::OpImplRegistry::GetInstance().GetOpImpl("DDIT02"); - registry_holder->AddTypesToImpl("DDIT02", *funcs); - space_registry->AddRegistry(registry_holder); - DefaultOpImplSpaceRegistryV2::GetInstance().SetSpaceRegistry(space_registry); - - auto tiling_data = gert::TilingData::CreateCap(1024); - auto workspace_size = gert::ContinuousVector::Create(16); - std::string op_compile_info_json = "{}"; - fe::PlatFormInfos platform_infos; - auto builder = TilingContextBuilder(); - auto op = ge::OpDescUtils::CreateOperatorFromNode(node->shared_from_this()); - ge::graphStatus ret; - auto tiling_context_holder = builder.CompileInfo(const_cast(op_compile_info_json.c_str())) - .PlatformInfo(reinterpret_cast(&platform_infos)) - .TilingData(tiling_data.get()) - .Workspace(reinterpret_cast(workspace_size.get())) - .SetSpaceRegistryV2(space_registry, OppImplVersionTag::kOpp) - .Build(op, ret); - EXPECT_EQ(ret, ge::GRAPH_SUCCESS); - EXPECT_NE(tiling_context_holder.context_, nullptr); - - auto builder_failed = TilingContextBuilder(); - auto tiling_context_holder_fail = builder_failed.CompileInfo(const_cast(op_compile_info_json.c_str())) - .PlatformInfo(reinterpret_cast(&platform_infos)) - .TilingData(tiling_data.get()) - .Workspace(reinterpret_cast(workspace_size.get())) - .SetSpaceRegistryV2(space_registry, OppImplVersionTag::kVersionEnd) - .Build(op, ret); - EXPECT_EQ(ret, ge::GRAPH_SUCCESS); - auto ctx = tiling_context_holder_fail.context_; - EXPECT_NE(ctx, nullptr); - DefaultOpImplSpaceRegistryV2::GetInstance().SetSpaceRegistry(nullptr); -} - -// 值依赖场景,输入数据来自const -TEST_F(TilingContextBuilderUT, BuildWithInputConstSuccess) { - auto tiling_data = gert::TilingData::CreateCap(1024); - auto workspace_size = gert::ContinuousVector::Create(16); - std::string op_compile_info_json = "{}"; - fe::PlatFormInfos platform_infos; - auto space_registries = DefaultOpImplSpaceRegistry::GetInstance().GetDefaultSpaceRegistries(); - auto builder = TilingContextBuilder(); - - auto foo_node = ComputeNodeFaker().NameAndType("foo", "Foo").IoNum(1, 2).InputNames({"x"}).Build(); - auto bar_node = ComputeNodeFaker().NameAndType("bar", "Bar").IoNum(2, 1).InputNames({"x", "y"}).Build(); - ge::GraphUtils::AddEdge(foo_node->GetOutDataAnchor(0), bar_node->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(foo_node->GetOutDataAnchor(1), bar_node->GetInDataAnchor(1)); - const size_t k_input_anchor = 3U; - ge::NodeUtils::AppendInputAnchor(bar_node, k_input_anchor); - EXPECT_EQ(bar_node->GetAllInDataAnchorsSize(), k_input_anchor); - ge::OpDescPtr op_desc = bar_node->GetOpDesc(); - ge::GeTensorDesc tensor_desc(ge::GeShape({1})); - op_desc->AddOutputDesc("z", tensor_desc); - op_desc->MutableInputDesc(1)->SetDataType(ge::DT_INT32); - op_desc->MutableInputDesc(1)->SetShape(ge::GeShape({1})); - op_desc->MutableInputDesc(1)->SetOriginShape(ge::GeShape({1})); - auto op = ge::OpDescUtils::CreateOperatorFromNode(bar_node->shared_from_this()); - ge::graphStatus ret; - auto tiling_context_holder = builder.CompileInfo(const_cast(op_compile_info_json.c_str())) - .PlatformInfo(reinterpret_cast(&platform_infos)) - .TilingData(tiling_data.get()) - .Deterministic(1) - .Workspace(reinterpret_cast(workspace_size.get())) - .SpaceRegistries(space_registries) - .Build(op, ret); - EXPECT_EQ(ret, ge::GRAPH_SUCCESS); - auto tiling_context = reinterpret_cast(tiling_context_holder.context_); - // check content in context - // 1.check input shape and tensor - auto input_tensor1 = tiling_context->GetInputTensor(1); - EXPECT_NE(input_tensor1, nullptr); - EXPECT_EQ(input_tensor1->GetDataType(), ge::DT_INT32); - EXPECT_EQ(input_tensor1->GetOriginShape().GetDim(0), 1); - - // deprecated later - builder.CompileInfo(const_cast(op_compile_info_json.c_str())) - .PlatformInfo(reinterpret_cast(&platform_infos)) - .TilingData(tiling_data.get()) - .Deterministic(1) - .Workspace(reinterpret_cast(workspace_size.get())) - .SpaceRegistries(space_registries) - .Build(op); - - bg::ValueHolder::PopGraphFrame(); - DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(nullptr); -} - -TEST_F(TilingContextBuilderUT, BuildAtomicCompileInfoNullptr) { - // build atomic clean node - auto tmp_graph = std::make_shared("tmp-graph"); - BuildAtomicNode(tmp_graph); - auto op = ge::OpDescUtils::CreateOperatorFromNode(tmp_graph->FindNode("AtomicClean")); - auto builder = AtomicTilingContextBuilder(); - ge::graphStatus ret; - auto tiling_context_holder = builder.CompileInfo(nullptr).Build(op, ret); - auto context = reinterpret_cast(tiling_context_holder.context_); - EXPECT_NE(context, nullptr); - EXPECT_NE(ret, ge::GRAPH_SUCCESS); - // deprecated later - builder.CompileInfo(nullptr).Build(op); -} - -TEST_F(TilingContextBuilderUT, BuildAtomicTilingContextSuccess) { - // build atomic clean node - std::vector output_clean_sizes = {44, 55}; - auto tmp_graph = std::make_shared("tmp-graph"); - BuildAtomicNode(tmp_graph); - - auto tiling_data = gert::TilingData::CreateCap(1024); - auto workspace_size = gert::ContinuousVector::Create(16); - - std::string op_compile_info_json = "{}"; - auto clean_workspace_size = gert::ContinuousVector::Create(16); - auto clean_workspace_ptr = reinterpret_cast *>(clean_workspace_size.get()); - clean_workspace_ptr->SetSize(2); - *(clean_workspace_ptr->MutableData()) = 22; - *(clean_workspace_ptr->MutableData() + 1) = 33; - - auto op = ge::OpDescUtils::CreateOperatorFromNode(tmp_graph->FindNode("AtomicClean")); - auto builder = AtomicTilingContextBuilder(); - ge::graphStatus ret; - auto tiling_context_holder = - builder.CompileInfo(const_cast(op_compile_info_json.c_str())) - .CleanWorkspaceSizes(reinterpret_cast(clean_workspace_size.get())) - .CleanOutputSizes(output_clean_sizes) - .TilingData(tiling_data.get()) - .Workspace(reinterpret_cast(workspace_size.get())) - .Build(op, ret); - EXPECT_EQ(ret, ge::GRAPH_SUCCESS); - - auto context = reinterpret_cast(tiling_context_holder.context_); - // check content in context - auto clean_workspace_size_in_context = context->GetCleanWorkspaceSizes(); - EXPECT_EQ(clean_workspace_size_in_context->GetSize(), 2); - auto ws_size_data = reinterpret_cast(clean_workspace_size_in_context->GetData()); - EXPECT_EQ(ws_size_data[0], 22); - EXPECT_EQ(ws_size_data[1], 33); - - EXPECT_EQ(context->GetCleanOutputSize(0), 44); - EXPECT_EQ(context->GetCleanOutputSize(1), 55); -} - -static ge::ComputeGraphPtr ConcatV2ConstDependencyGraph() { - // root - auto root_graph = std::make_shared("root_graph"); - auto op_desc = std::make_shared("ifa", "IncreFlashAttention"); - - // set ifa ir - op_desc->AppendIrInput("query", ge::kIrInputRequired); - op_desc->AppendIrInput("key", ge::kIrInputDynamic); - op_desc->AppendIrInput("value", ge::kIrInputDynamic); - op_desc->AppendIrInput("pse_shift", ge::kIrInputOptional); - op_desc->AppendIrInput("atten_mask", ge::kIrInputOptional); - op_desc->AppendIrInput("actual_seq_lengths", ge::kIrInputOptional); - op_desc->MutableAllInputName() = {{"query", 0}, {"key0", 1}, {"value0", 2}, {"actual_seq_lengths", 5}}; - op_desc->AppendIrOutput("attention_out", ge::kIrOutputRequired); - op_desc->MutableAllOutputName() = {{"attention_out", 0}}; - - std::vector in_shape = {1, 4, 1, 1024}; - ge::GeShape shape(in_shape); - ge::GeTensorDesc tensor_desc(shape); - tensor_desc.SetOriginShape(shape); - tensor_desc.SetShape(shape); - - ge::GeTensorDesc invalid_desc; - invalid_desc.SetDataType(ge::DT_UNDEFINED); - invalid_desc.SetFormat(ge::FORMAT_RESERVED); - - op_desc->AddInputDesc(tensor_desc); - op_desc->AddInputDesc(tensor_desc); - op_desc->AddInputDesc(tensor_desc); - op_desc->AddInputDesc(invalid_desc); - op_desc->AddInputDesc(invalid_desc); - op_desc->AddInputDesc(tensor_desc); - - op_desc->AddOutputDesc(tensor_desc); - - const auto node_id = op_desc->GetId(); - auto ifa_node = root_graph->AddNode(op_desc); - for (size_t i = 0UL; i < op_desc->GetAllInputsSize(); ++i) { - const auto input_desc = op_desc->GetInputDesc(i); - if (input_desc.IsValid() != ge::GRAPH_SUCCESS) { - GELOGD("Node: %s, input: %zu, is invalid, skip add edge.", op_desc->GetNamePtr(), i); - continue; - } - auto op_data = ge::OpDescBuilder(std::to_string(i), "Data").AddInput("x").AddOutput("y").Build(); - auto data_node = root_graph->AddNode(op_data); - ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), ifa_node->GetInDataAnchor(i)); - } - - for (size_t i = 0UL; i < op_desc->GetOutputsSize(); ++i) { - const auto input_desc = op_desc->GetOutputDesc(i); - auto out_op = ge::OpDescBuilder(std::to_string(i), "Data").AddInput("x").AddOutput("y").Build(); - auto out_node = root_graph->AddNode(out_op); - ge::GraphUtils::AddEdge(ifa_node->GetOutDataAnchor(i), out_node->GetInDataAnchor(0)); - } - - // AddNode operation may change node id to 0, which need to be recovered - op_desc->SetId(node_id); - - return root_graph; -} - -TEST_F(TilingContextBuilderUT, BuildDeviceTilingContextSuccess) { - auto graph = ConcatV2ConstDependencyGraph(); - fe::PlatFormInfos platform_infos; - auto node = graph->FindNode("ifa"); - EXPECT_NE(node, nullptr); - auto op_desc = node->GetOpDesc(); - EXPECT_NE(op_desc, nullptr); - - size_t total_plain_size{0UL}; - - size_t extend_info_size{0UL}; - bg::BufferPool buffer_pool; - auto compute_node_extend_holder = bg::CreateComputeNodeInfo(node, buffer_pool, extend_info_size); - EXPECT_NE(compute_node_extend_holder, nullptr); - - const size_t device_tiling_size = gert::DeviceTilingContextBuilder::CalcTotalTiledSize(op_desc); - size_t aligned_tiling_context_size = ge::RoundUp(static_cast(device_tiling_size), sizeof(uintptr_t)); - aligned_tiling_context_size += extend_info_size; - total_plain_size += aligned_tiling_context_size; - - // tiling - const auto aligned_tiling_size = 1024UL; - total_plain_size += aligned_tiling_size; - - const size_t workspace_addr_size = 16 * sizeof(gert::ContinuousVector); - total_plain_size += workspace_addr_size; - - auto host_pointer = std::unique_ptr(new uint8_t[total_plain_size]); - EXPECT_NE(host_pointer, nullptr); - auto device_addr = std::unique_ptr(new uint8_t[total_plain_size]); - EXPECT_NE(device_addr, nullptr); - - // copy tiling_data - uint8_t *context_host_begin = &host_pointer[aligned_tiling_size + workspace_addr_size]; - uint64_t context_dev_begin = ge::PtrToValue(device_addr.get()) + aligned_tiling_size + workspace_addr_size; - - auto space_registry = DefaultOpImplSpaceRegistry::GetInstance().GetDefaultSpaceRegistry(); - gert::Tensor host_tensor; - gert::Tensor device_tensor; - host_tensor.MutableTensorData().SetAddr(reinterpret_cast(0x120000), nullptr); - device_tensor.MutableTensorData().SetAddr(reinterpret_cast(0x120000), nullptr); - std::map index_to_tensor; - index_to_tensor[3] = {&host_tensor, ge::PtrToValue(&device_tensor)}; - - gert::TiledKernelContextHolder tiling_context_holder; - tiling_context_holder.compute_node_info_size_ = extend_info_size; - tiling_context_holder.host_compute_node_info_ = compute_node_extend_holder.get(); - - auto context_builder = gert::DeviceTilingContextBuilder(); - ge::Status ret = context_builder.PlatformInfo(reinterpret_cast(&platform_infos)) - .TilingData(device_addr.get()) - .Deterministic(0) - .CompileInfo(nullptr) - .Workspace(ge::ValueToPtr(ge::PtrToValue(device_addr.get()) + aligned_tiling_size)) - .AddrRefreshedInputTensor(index_to_tensor) - .TiledHolder(context_host_begin, context_dev_begin, - total_plain_size - aligned_tiling_size - workspace_addr_size) - .Build(node, tiling_context_holder); - // mock h2d - EXPECT_EQ(memcpy_s(device_addr.get(), total_plain_size, host_pointer.get(), total_plain_size), 0); - - auto host_context = reinterpret_cast(tiling_context_holder.host_context_); - EXPECT_NE(host_context, nullptr); - auto device_context = reinterpret_cast(tiling_context_holder.dev_context_addr_); - EXPECT_NE(device_context, nullptr); - - // checkout input chains - // input0 shape - EXPECT_NE(device_context->GetInputShape(0), nullptr); - EXPECT_EQ(device_context->GetInputShape(0)->GetStorageShape().GetDimNum(), - op_desc->GetInputDesc(0).GetShape().GetDimNum()); - for (size_t i = 0UL; i < device_context->GetInputShape(0)->GetStorageShape().GetDimNum(); ++i) { - EXPECT_EQ(device_context->GetInputShape(0)->GetStorageShape().GetDim(i), - op_desc->GetInputDesc(0).GetShape().GetDim(i)); - } - - // tiling depends tensor addr - const gert::Tensor *value_tensor = device_context->GetInputTensor(3); - EXPECT_EQ(value_tensor, &device_tensor); - EXPECT_EQ(value_tensor->GetTensorData().GetAddr(), reinterpret_cast(0x120000)); - - // checkout output chains - // tiling_data addr - const auto tiling_data_ptr = device_context->GetOutputPointer(TilingContext::kOutputTilingData); - EXPECT_EQ(ge::PtrToValue(tiling_data_ptr), ge::PtrToValue(device_addr.get())); - - // tiling_key_addr - const auto tiling_key_ptr = device_context->GetOutputPointer(TilingContext::kOutputTilingKey); - EXPECT_EQ(reinterpret_cast(tiling_key_ptr) % 128U, 0U); - EXPECT_EQ(ge::PtrToValue(tiling_key_ptr), - tiling_context_holder.output_addrs_[gert::TilingContext::TilingOutputIndex::kOutputTilingKey]); - device_context->SetTilingKey(0x123UL); - EXPECT_EQ(*tiling_key_ptr, 0x123UL); - - // block_dim addr - const auto block_dim_ptr = device_context->GetOutputPointer(TilingContext::kOutputBlockDim); - EXPECT_EQ(reinterpret_cast(block_dim_ptr) % 128U, 0U); - EXPECT_EQ(ge::PtrToValue(block_dim_ptr), - tiling_context_holder.output_addrs_[gert::TilingContext::TilingOutputIndex::kOutputBlockDim]); - device_context->SetBlockDim(40U); - EXPECT_EQ(*block_dim_ptr, 40U); - - // op type - char *op_type = reinterpret_cast(tiling_context_holder.dev_op_type_addr_); - EXPECT_EQ(op_type, op_desc->GetType()); - - EXPECT_EQ(ret, ge::SUCCESS); -} - -TEST_F(TilingContextBuilderUT, GetDependInputTensorAddr_Data_Input_Success) { - auto node = ComputeNodeFaker().NameAndType("Test", "Test").IoNum(1, 1).InputNames({"x"}).Build(); - auto builder = TilingContextBuilder(); - TensorAddress address = nullptr; - const auto ret = builder.GetDependInputTensorAddr(ge::OpDescUtils::CreateOperatorFromNode(node), 0, address); - EXPECT_EQ(ret, ge::GRAPH_SUCCESS); - EXPECT_TRUE(address == nullptr); -} - -TEST_F(TilingContextBuilderUT, GetDependInputTensorAddr_Const_Input_Success) { - auto node = ComputeNodeFaker().NameAndType("Test", "Test").IoNum(1, 1).InputNames({"x"}).Build(); - auto const1 = ComputeNodeFaker().NameAndType("Const", "Const").IoNum(1, 1).InputNames({"x"}).Build(); - ge::GraphUtils::AddEdge(const1->GetOutDataAnchor(0), node->GetInDataAnchor(0)); - int32_t weight[1] = {1}; - ge::GeTensorDesc weight_desc(ge::GeShape({1}), ge::FORMAT_NHWC, ge::DT_INT32); - ge::GeTensorPtr tensor0 = std::make_shared(weight_desc, (uint8_t *) weight, sizeof(weight)); - ge::OpDescUtils::SetWeights(const1, {tensor0}); - auto builder = TilingContextBuilder(); - TensorAddress address = 0x0; - const auto ret = builder.GetDependInputTensorAddr(ge::OpDescUtils::CreateOperatorFromNode(node), 0, address); - EXPECT_EQ(ret, ge::GRAPH_SUCCESS); - EXPECT_NE(address, nullptr); -} -} // namespace gert diff --git a/tests/ut/exe_graph/tiling_context_unittest.cc b/tests/ut/exe_graph/tiling_context_unittest.cc deleted file mode 100644 index 6051633773e18bae2b7d388cda418f03e1316fdf..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/tiling_context_unittest.cc +++ /dev/null @@ -1,669 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "exe_graph/runtime/tiling_context.h" -#include "faker/kernel_run_context_faker.h" -#include "platform/platform_infos_def.h" - -namespace gert { -class TilingContextUT : public testing::Test {}; -namespace { -struct TestTilingData { - int64_t a; -}; -struct TestCompileInfo { - int64_t a; - int64_t b; - std::vector c; -}; -} // namespace -TEST_F(TilingContextUT, GetCompileInfoOk) { - gert::StorageShape in_shape = {{1, 16, 256}, {1, 16, 256}}; - gert::StorageShape out_shape = {{1, 16, 256}, {1, 16, 1, 16, 16}}; - - // tiling data - TestCompileInfo compile_info_holder = {10, 200, {10, 20, 30}}; - fe::PlatFormInfos platform_info; - auto param = gert::TilingData::CreateCap(2048); - auto holder = gert::TilingContextFaker() - .NodeIoNum(1, 1) - .IrInputNum(1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputShapes({&in_shape}) - .OutputShapes({&out_shape}) - .CompileInfo(&compile_info_holder) - .PlatformInfo(reinterpret_cast(&platform_info)) - .TilingData(param.get()) - .Build(); - - auto context = holder.GetContext(); - ASSERT_NE(context, nullptr); - auto compile_info = reinterpret_cast(context->GetCompileInfo()); - ASSERT_NE(compile_info, nullptr); - EXPECT_EQ(compile_info->a, 10); - EXPECT_EQ(compile_info->b, 200); - EXPECT_EQ(compile_info->c, compile_info_holder.c); - EXPECT_EQ(context->GetPlatformInfo()->GetCoreNum(), 8); -} - -TEST_F(TilingContextUT, GetInputShapeOk) { - gert::StorageShape in_shape = {{1, 16, 256}, {1, 16, 256}}; - gert::StorageShape out_shape = {{1, 16, 256}, {1, 16, 1, 16, 16}}; - - // tiling data - TestCompileInfo compile_info_holder = {10, 200, {10, 20, 30}}; - auto param = gert::TilingData::CreateCap(2048); - auto holder = gert::TilingContextFaker() - .NodeIoNum(1, 1) - .IrInputNum(1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputShapes({&in_shape}) - .OutputShapes({&out_shape}) - .CompileInfo(&compile_info_holder) - .TilingData(param.get()) - .Build(); - - auto context = holder.GetContext(); - ASSERT_NE(context, nullptr); - ASSERT_NE(context->GetInputShape(0), nullptr); - EXPECT_EQ(context->GetInputShape(0)->GetOriginShape(), in_shape.GetOriginShape()); - EXPECT_EQ(context->GetInputShape(0)->GetStorageShape(), in_shape.GetStorageShape()); - ASSERT_EQ(context->GetInputShape(1), nullptr); -} - -TEST_F(TilingContextUT, GetDynamicInputShapeOk) { - gert::StorageShape in_shape0 = {{1, 16, 256}, {1, 16, 256}}; - gert::StorageShape in_shape1 = {{2, 16, 256}, {1, 16, 256}}; - gert::StorageShape in_shape2 = {{3, 16, 256}, {1, 16, 256}}; - gert::StorageShape in_shape3 = {{4, 16, 256}, {1, 16, 256}}; - gert::StorageShape in_shape4 = {{5, 16, 256}, {1, 16, 256}}; - gert::StorageShape out_shape = {{1, 16, 256}, {1, 16, 1, 16, 16}}; - - TestCompileInfo compile_info_holder = {10, 200, {10, 20, 30}}; - auto param = gert::TilingData::CreateCap(2048); - auto holder = gert::TilingContextFaker() - .NodeIoNum(5, 1) - .IrInstanceNum({1,2,1,0,1}) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeInputTd(2, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeInputTd(3, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeInputTd(4, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputShapes({&in_shape0, &in_shape1, &in_shape2, &in_shape3, &in_shape4}) - .OutputShapes({&out_shape}) - .CompileInfo(&compile_info_holder) - .TilingData(param.get()) - .Build(); - - auto context = holder.GetContext(); - ASSERT_NE(context, nullptr); - - ASSERT_NE(context->GetDynamicInputShape(0, 0), nullptr); - EXPECT_EQ(*context->GetDynamicInputShape(0, 0), in_shape0); - EXPECT_EQ(context->GetDynamicInputShape(0, 1), nullptr); - - ASSERT_NE(context->GetDynamicInputShape(1, 0), nullptr); - EXPECT_EQ(*context->GetDynamicInputShape(1, 0), in_shape1); - ASSERT_NE(context->GetDynamicInputShape(1, 1), nullptr); - EXPECT_EQ(*context->GetDynamicInputShape(1, 1), in_shape2); - EXPECT_EQ(context->GetDynamicInputShape(1, 2), nullptr); - - ASSERT_NE(context->GetDynamicInputShape(2, 0), nullptr); - EXPECT_EQ(*context->GetDynamicInputShape(2, 0), in_shape3); - EXPECT_EQ(context->GetDynamicInputShape(2, 1), nullptr); - - EXPECT_EQ(context->GetOptionalInputShape(3), nullptr); - - ASSERT_NE(context->GetOptionalInputShape(4), nullptr); - EXPECT_EQ(*context->GetOptionalInputShape(4), in_shape4); -} - -TEST_F(TilingContextUT, GetDynamicInputDescOk) { - gert::StorageShape in_shape0 = {{1, 16, 256}, {1, 16, 256}}; - gert::StorageShape in_shape1 = {{2, 16, 256}, {1, 16, 256}}; - gert::StorageShape in_shape2 = {{3, 16, 256}, {1, 16, 256}}; - gert::StorageShape in_shape3 = {{4, 16, 256}, {1, 16, 256}}; - gert::StorageShape in_shape4 = {{5, 16, 256}, {1, 16, 256}}; - gert::StorageShape out_shape = {{1, 16, 256}, {1, 16, 1, 16, 16}}; - - TestCompileInfo compile_info_holder = {10, 200, {10, 20, 30}}; - auto param = gert::TilingData::CreateCap(2048); - auto holder = gert::TilingContextFaker() - .NodeIoNum(5, 1) - .IrInstanceNum({1,2,1,0,1}) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_FRACTAL_Z, ge::FORMAT_ND) - .NodeInputTd(2, ge::DT_FLOAT16, ge::FORMAT_C1HWNC0, ge::FORMAT_ND) - .NodeInputTd(3, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeInputTd(4, ge::DT_FLOAT16, ge::FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, ge::FORMAT_ND) - .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputShapes({&in_shape0, &in_shape1, &in_shape2, &in_shape3, &in_shape4}) - .OutputShapes({&out_shape}) - .CompileInfo(&compile_info_holder) - .TilingData(param.get()) - .Build(); - - auto context = holder.GetContext(); - ASSERT_NE(context, nullptr); - - auto tensor_desc1 = context->GetDynamicInputDesc(1, 0); - ASSERT_NE(tensor_desc1, nullptr); - EXPECT_EQ(tensor_desc1->GetOriginFormat(), ge::FORMAT_FRACTAL_Z); - - auto tensor_desc2 = context->GetDynamicInputDesc(1, 1); - ASSERT_NE(tensor_desc2, nullptr); - EXPECT_EQ(tensor_desc2->GetOriginFormat(), ge::FORMAT_C1HWNC0); - - auto tensor_desc3 = context->GetDynamicInputDesc(1, 2); - ASSERT_EQ(tensor_desc3, nullptr); - - EXPECT_EQ(context->GetOptionalInputDesc(3), nullptr); - - ASSERT_NE(context->GetRequiredInputDesc(4), nullptr); - EXPECT_EQ(context->GetRequiredInputDesc(4)->GetOriginFormat(), ge::FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS); -} - -TEST_F(TilingContextUT, GetOutputShapeOk) { - gert::StorageShape in_shape0 = {{1, 16, 256}, {1, 16, 256}}; - gert::StorageShape in_shape1 = {{2, 16, 256}, {1, 16, 256}}; - gert::StorageShape in_shape2 = {{3, 16, 256}, {1, 16, 256}}; - gert::StorageShape in_shape3 = {{4, 16, 256}, {1, 16, 256}}; - gert::StorageShape in_shape4 = {{5, 16, 256}, {1, 16, 256}}; - gert::StorageShape out_shape = {{6, 16, 256}, {1, 16, 1, 16, 16}}; - - TestCompileInfo compile_info_holder = {10, 200, {10, 20, 30}}; - auto param = gert::TilingData::CreateCap(2048); - auto holder = gert::TilingContextFaker() - .NodeIoNum(5, 1) - .IrInstanceNum({1,2,1,0,1}) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeInputTd(2, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeInputTd(3, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeInputTd(4, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputShapes({&in_shape0, &in_shape1, &in_shape2, &in_shape3, &in_shape4}) - .OutputShapes({&out_shape}) - .CompileInfo(&compile_info_holder) - .TilingData(param.get()) - .Build(); - - auto context = holder.GetContext(); - ASSERT_NE(context, nullptr); - - ASSERT_NE(context->GetOutputShape(0), nullptr); - EXPECT_EQ(*context->GetOutputShape(0), out_shape); - ASSERT_EQ(context->GetOutputShape(1), nullptr); -} -TEST_F(TilingContextUT, GetInputTensorOk) { - gert::Tensor in_tensor = {{{1, 16, 256}, {1, 16, 256}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *) 0xabc}; - gert::StorageShape out_shape = {{1, 16, 256}, {1, 16, 1, 16, 16}}; - - // tiling data - TestCompileInfo compile_info_holder = {10, 200, {10, 20, 30}}; - auto param = gert::TilingData::CreateCap(2048); - auto holder = gert::TilingContextFaker() - .NodeIoNum(1, 1) - .IrInputNum(1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputShapes({reinterpret_cast(&in_tensor)}) - .OutputShapes({&out_shape}) - .CompileInfo(&compile_info_holder) - .TilingData(param.get()) - .Build(); - - auto context = holder.GetContext(); - ASSERT_NE(context, nullptr); - ASSERT_NE(context->GetInputTensor(0), nullptr); - EXPECT_EQ(context->GetInputTensor(0)->GetOriginShape(), in_tensor.GetOriginShape()); - EXPECT_EQ(context->GetInputTensor(0)->GetStorageShape(), in_tensor.GetStorageShape()); - EXPECT_EQ(context->GetInputTensor(0)->GetOriginFormat(), in_tensor.GetOriginFormat()); - EXPECT_EQ(context->GetInputTensor(0)->GetStorageFormat(), in_tensor.GetStorageFormat()); - EXPECT_EQ(context->GetInputTensor(0)->GetExpandDimsType(), in_tensor.GetExpandDimsType()); - EXPECT_EQ(context->GetInputTensor(0)->GetDataType(), in_tensor.GetDataType()); - EXPECT_EQ(context->GetInputTensor(0)->GetAddr(), in_tensor.GetAddr()); -} - -TEST_F(TilingContextUT, GetDynamicInputTensorOk) { - gert::Tensor in_tensor = {{{1, 16}, {1, 16}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *) 0x234}; - gert::Tensor in_tensor1 = {{{1, 16}, {1, 16}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *) 0x123}; - gert::Tensor in_tensor2 = {{{1, 16}, {1, 16}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *) 0x456}; - gert::StorageShape out_shape = {{1, 16}, {1, 16}}; - - // tiling data - TestCompileInfo compile_info_holder = {10, 200, {10, 20, 30}}; - auto param = gert::TilingData::CreateCap(2048); - auto holder = gert::TilingContextFaker() - .NodeIoNum(3, 1) - .IrInstanceNum({1, 0, 2}) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeInputTd(2, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputShapes({reinterpret_cast(&in_tensor), - reinterpret_cast(&in_tensor1), - reinterpret_cast(&in_tensor2)}) - .OutputShapes({&out_shape}) - .CompileInfo(&compile_info_holder) - .TilingData(param.get()) - .Build(); - - auto context = holder.GetContext(); - ASSERT_NE(context, nullptr); - - ASSERT_NE(context->GetInputTensor(0), nullptr); - EXPECT_EQ(context->GetInputTensor(0)->GetOriginShape(), in_tensor.GetOriginShape()); - EXPECT_EQ(context->GetInputTensor(0)->GetStorageShape(), in_tensor.GetStorageShape()); - EXPECT_EQ(context->GetInputTensor(0)->GetOriginFormat(), in_tensor.GetOriginFormat()); - EXPECT_EQ(context->GetInputTensor(0)->GetStorageFormat(), in_tensor.GetStorageFormat()); - EXPECT_EQ(context->GetInputTensor(0)->GetExpandDimsType(), in_tensor.GetExpandDimsType()); - EXPECT_EQ(context->GetInputTensor(0)->GetDataType(), in_tensor.GetDataType()); - EXPECT_EQ(context->GetInputTensor(0)->GetAddr(), in_tensor.GetAddr()); - - ASSERT_EQ(context->GetOptionalInputTensor(1), nullptr); - - ASSERT_NE(context->GetDynamicInputTensor(2, 0), nullptr); - EXPECT_EQ(context->GetDynamicInputTensor(2, 0)->GetOriginShape(), in_tensor1.GetOriginShape()); - EXPECT_EQ(context->GetDynamicInputTensor(2, 0)->GetStorageShape(), in_tensor1.GetStorageShape()); - EXPECT_EQ(context->GetDynamicInputTensor(2, 0)->GetOriginFormat(), in_tensor1.GetOriginFormat()); - EXPECT_EQ(context->GetDynamicInputTensor(2, 0)->GetStorageFormat(), in_tensor1.GetStorageFormat()); - EXPECT_EQ(context->GetDynamicInputTensor(2, 0)->GetExpandDimsType(), in_tensor1.GetExpandDimsType()); - EXPECT_EQ(context->GetDynamicInputTensor(2, 0)->GetDataType(), in_tensor1.GetDataType()); - EXPECT_EQ(context->GetDynamicInputTensor(2, 0)->GetAddr(), in_tensor1.GetAddr()); - - ASSERT_NE(context->GetDynamicInputTensor(2, 1), nullptr); - EXPECT_EQ(context->GetDynamicInputTensor(2, 1)->GetOriginShape(), in_tensor2.GetOriginShape()); - EXPECT_EQ(context->GetDynamicInputTensor(2, 1)->GetStorageShape(), in_tensor2.GetStorageShape()); - EXPECT_EQ(context->GetDynamicInputTensor(2, 1)->GetOriginFormat(), in_tensor2.GetOriginFormat()); - EXPECT_EQ(context->GetDynamicInputTensor(2, 1)->GetStorageFormat(), in_tensor2.GetStorageFormat()); - EXPECT_EQ(context->GetDynamicInputTensor(2, 1)->GetExpandDimsType(), in_tensor2.GetExpandDimsType()); - EXPECT_EQ(context->GetDynamicInputTensor(2, 1)->GetDataType(), in_tensor2.GetDataType()); - EXPECT_EQ(context->GetDynamicInputTensor(2, 1)->GetAddr(), in_tensor2.GetAddr()); -} - -TEST_F(TilingContextUT, SetTypedTilingDataOk) { - gert::StorageShape in_shape = {{1, 16, 256}, {1, 16, 256}}; - gert::StorageShape out_shape = {{1, 16, 256}, {1, 16, 1, 16, 16}}; - - // tiling data - auto param = gert::TilingData::CreateCap(2048); - auto holder = gert::TilingContextFaker() - .NodeIoNum(1, 1) - .IrInputNum(1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputShapes({&in_shape}) - .OutputShapes({&out_shape}) - .TilingData(param.get()) - .Build(); - - auto context = holder.GetContext(); - auto tiling_data = context->GetTilingData(); - ASSERT_NE(tiling_data, nullptr); - tiling_data->a = 10; - auto root_tiling_data = reinterpret_cast(param.get()); - - EXPECT_EQ(root_tiling_data->GetDataSize(), sizeof(TestTilingData)); - EXPECT_EQ(root_tiling_data->GetData(), tiling_data); -} - -TEST_F(TilingContextUT, SetTypedTilingDataOutOfBounds) { - gert::StorageShape in_shape = {{1, 16, 256}, {1, 16, 256}}; - gert::StorageShape out_shape = {{1, 16, 256}, {1, 16, 1, 16, 16}}; - - auto param = gert::TilingData::CreateCap(4); - auto holder = gert::TilingContextFaker() - .NodeIoNum(1, 1) - .IrInputNum(1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputShapes({&in_shape}) - .OutputShapes({&out_shape}) - .TilingData(param.get()) - .Build(); - - auto context = holder.GetContext(); - auto tiling_data = context->GetTilingData(); - EXPECT_EQ(tiling_data, nullptr); -} -TEST_F(TilingContextUT, SetAppendTilingDataOk) { - gert::StorageShape in_shape = {{1, 16, 256}, {1, 16, 256}}; - gert::StorageShape out_shape = {{1, 16, 256}, {1, 16, 1, 16, 16}}; - - // tiling data - auto param = gert::TilingData::CreateCap(2048); - auto holder = gert::TilingContextFaker() - .NodeIoNum(1, 1) - .IrInputNum(1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputShapes({&in_shape}) - .OutputShapes({&out_shape}) - .TilingData(param.get()) - .Build(); - - auto context = holder.GetContext(); - - // 算子tiling中可以用如下操作append - auto tiling_data = context->GetRawTilingData(); - ASSERT_NE(tiling_data, nullptr); - - tiling_data->Append(static_cast(10)); - tiling_data->Append(static_cast(20)); - tiling_data->Append(static_cast(30)); - tiling_data->Append(static_cast(40)); - tiling_data->Append(static_cast(50)); - tiling_data->Append(static_cast(60)); - - EXPECT_EQ(context->GetRawTilingData()->GetDataSize(), 31); // 3 * 8 + 4 + 2 + 1 -} - -TEST_F(TilingContextUT, SetTilingKeyOk) { - gert::StorageShape in_shape = {{1, 16, 256}, {1, 16, 256}}; - gert::StorageShape out_shape = {{1, 16, 256}, {1, 16, 1, 16, 16}}; - - // tiling data - auto param = gert::TilingData::CreateCap(2048); - auto holder = gert::TilingContextFaker() - .NodeIoNum(1, 1) - .IrInputNum(1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputShapes({&in_shape}) - .OutputShapes({&out_shape}) - .TilingData(param.get()) - .Build(); - - auto context = holder.GetContext(); - - context->SetTilingKey(20); - EXPECT_EQ(context->GetTilingKey(), 20); - EXPECT_EQ( - *reinterpret_cast( - &(holder.holder.value_holder_[holder.kernel_input_num + TilingContext::kOutputTilingKey].any_value_.data)), - 20); -} -TEST_F(TilingContextUT, SetBlockDimOk) { - gert::StorageShape in_shape = {{1, 16, 256}, {1, 16, 256}}; - gert::StorageShape out_shape = {{1, 16, 256}, {1, 16, 1, 16, 16}}; - - // tiling data - auto param = gert::TilingData::CreateCap(2048); - auto holder = gert::TilingContextFaker() - .NodeIoNum(1, 1) - .IrInputNum(1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputShapes({&in_shape}) - .OutputShapes({&out_shape}) - .TilingData(param.get()) - .Build(); - - auto context = holder.GetContext(); - - context->SetBlockDim(10); - EXPECT_EQ(context->GetBlockDim(), 10); - EXPECT_EQ(*reinterpret_cast(&( - holder.holder.value_holder_[holder.kernel_input_num + TilingContext::kOutputBlockDim].any_value_.data)), - 10); - EXPECT_EQ(context->SetAicpuBlockDim(4U), ge::GRAPH_SUCCESS); - EXPECT_EQ(context->GetAicpuBlockDim(), 4U); -} - -TEST_F(TilingContextUT, TestSetLocalMemorySizeFail) { - // no mem for local mem size, set fail and get invalid value. - auto holder = gert::TilingContextFaker().Build(); - auto context = holder.GetContext(); - - // set fail - EXPECT_NE(context->SetLocalMemorySize(10), ge::GRAPH_SUCCESS); - - // get invalid value - EXPECT_EQ(context->GetLocalMemorySize(), std::numeric_limits::max()); -} - -TEST_F(TilingContextUT, TestSetLocalMemorySize) { - auto holder = gert::TilingContextFaker().NodeIoNum(1, 1).Build(); - auto context = holder.GetContext(); - - // no set, get default value (0) - EXPECT_EQ(context->GetLocalMemorySize(), 0U); - - // set and get test - uint32_t test_value = 1000; - EXPECT_EQ(context->SetLocalMemorySize(test_value), ge::GRAPH_SUCCESS); - EXPECT_EQ(context->GetLocalMemorySize(), test_value); -} - -TEST_F(TilingContextUT, SetNeedAtomicOk) { - gert::StorageShape in_shape = {{1, 16, 256}, {1, 16, 256}}; - gert::StorageShape out_shape = {{1, 16, 256}, {1, 16, 1, 16, 16}}; - - // tiling data - auto param = gert::TilingData::CreateCap(2048); - auto holder = gert::TilingContextFaker() - .NodeIoNum(1, 1) - .IrInputNum(1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputShapes({&in_shape}) - .OutputShapes({&out_shape}) - .TilingData(param.get()) - .Build(); - - auto context = holder.GetContext(); - - context->SetNeedAtomic(true); - EXPECT_TRUE(context->NeedAtomic()); - EXPECT_TRUE(*reinterpret_cast( - &(holder.holder.value_holder_[holder.kernel_input_num + TilingContext::kOutputAtomicCleanFlag].any_value_.data))); - - context->SetNeedAtomic(false); - EXPECT_FALSE(context->NeedAtomic()); - EXPECT_FALSE(*reinterpret_cast( - &(holder.holder.value_holder_[holder.kernel_input_num + TilingContext::kOutputAtomicCleanFlag].any_value_.data))); -} - -TEST_F(TilingContextUT, SetWorkspaceSizesOk) { - gert::StorageShape in_shape = {{1, 16, 256}, {1, 16, 256}}; - gert::StorageShape out_shape = {{1, 16, 256}, {1, 16, 1, 16, 16}}; - - auto workspace_size_holder = ContinuousVector::Create(8); - auto ws_size = reinterpret_cast(workspace_size_holder.get()); - - // tiling data - auto param = gert::TilingData::CreateCap(2048); - auto holder = gert::TilingContextFaker() - .NodeIoNum(1, 1) - .IrInputNum(1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputShapes({&in_shape}) - .OutputShapes({&out_shape}) - .TilingData(param.get()) - .Workspace(ws_size) - .Build(); - - auto context = holder.GetContext(); - - auto ws = context->GetWorkspaceSizes(1); - ASSERT_NE(ws, nullptr); - ws[0] = 10240; - EXPECT_EQ(ws_size->GetSize(), 1); - EXPECT_EQ(context->GetWorkspaceNum(), 1U); - EXPECT_EQ(reinterpret_cast(ws_size->GetData())[0], 10240); -} - -TEST_F(TilingContextUT, SetWorkspaceSizesOutOfBounds) { - gert::StorageShape in_shape = {{1, 16, 256}, {1, 16, 256}}; - gert::StorageShape out_shape = {{1, 16, 256}, {1, 16, 1, 16, 16}}; - - auto workspace_size_holder = ContinuousVector::Create(0); - auto ws_size = reinterpret_cast(workspace_size_holder.get()); - - // tiling data - auto param = gert::TilingData::CreateCap(2048); - auto holder = gert::TilingContextFaker() - .NodeIoNum(1, 1) - .IrInputNum(1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputShapes({&in_shape}) - .OutputShapes({&out_shape}) - .TilingData(param.get()) - .Workspace(ws_size) - .Build(); - - auto context = holder.GetContext(); - - auto ws = context->GetWorkspaceSizes(9); - EXPECT_EQ(ws, nullptr); -} - -TEST_F(TilingContextUT, SetTilingCondOk) { - gert::StorageShape in_shape = {{1, 16, 256}, {1, 16, 256}}; - gert::StorageShape out_shape = {{1, 16, 256}, {1, 16, 1, 16, 16}}; - - // tiling data - auto param = gert::TilingData::CreateCap(2048); - auto holder = gert::TilingContextFaker() - .NodeIoNum(1, 1) - .IrInputNum(1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputShapes({&in_shape}) - .OutputShapes({&out_shape}) - .TilingData(param.get()) - .Build(); - - auto context = holder.GetContext(); - context->SetTilingCond(10); - EXPECT_EQ(context->GetTilingCond(), 10); -} - -TEST_F(TilingContextUT, GetRequiredInputTensorOk) { - gert::Tensor in_tensor = {{{1, 16}, {1, 16}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *) 0x0}; - gert::Tensor in_tensor1 = {{{1, 16}, {1, 16}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *) 0x0}; - gert::Tensor in_tensor2 = {{{1, 16}, {1, 16}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *) 0x0}; - gert::Tensor in_tensor3 = {{{1, 16}, {1, 16}}, // shape - {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, {}}, // format - kOnDeviceHbm, // placement - ge::DT_FLOAT16, // data type - (void *) 0x1234}; - gert::StorageShape out_shape = {{1, 16}, {1, 16}}; - - // tiling data - TestCompileInfo compile_info_holder = {10, 200, {10, 20, 30}}; - auto param = gert::TilingData::CreateCap(2048); - auto holder = gert::TilingContextFaker() - .NodeIoNum(4, 1) - .IrInstanceNum({1, 0, 2, 1}) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeInputTd(2, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeInputTd(3, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_ND) - .NodeOutputTd(0, ge::DT_FLOAT16, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ) - .InputShapes({reinterpret_cast(&in_tensor), - reinterpret_cast(&in_tensor1), - reinterpret_cast(&in_tensor2), - reinterpret_cast(&in_tensor3)}) - .OutputShapes({&out_shape}) - .CompileInfo(&compile_info_holder) - .TilingData(param.get()) - .Build(); - - auto context = holder.GetContext(); - ASSERT_NE(context, nullptr); - - ASSERT_NE(context->GetRequiredInputShape(0), nullptr); - EXPECT_EQ(*context->GetRequiredInputShape(0), in_tensor.GetShape()); - - ASSERT_NE(context->GetRequiredInputTensor(0), nullptr); - EXPECT_EQ(context->GetRequiredInputTensor(0)->GetOriginShape(), in_tensor.GetOriginShape()); - EXPECT_EQ(context->GetRequiredInputTensor(0)->GetStorageShape(), in_tensor.GetStorageShape()); - EXPECT_EQ(context->GetRequiredInputTensor(0)->GetOriginFormat(), in_tensor.GetOriginFormat()); - EXPECT_EQ(context->GetRequiredInputTensor(0)->GetStorageFormat(), in_tensor.GetStorageFormat()); - EXPECT_EQ(context->GetRequiredInputTensor(0)->GetExpandDimsType(), in_tensor.GetExpandDimsType()); - EXPECT_EQ(context->GetRequiredInputTensor(0)->GetDataType(), in_tensor.GetDataType()); - EXPECT_EQ(context->GetRequiredInputTensor(0)->GetAddr(), in_tensor.GetAddr()); - - ASSERT_EQ(context->GetOptionalInputTensor(1), nullptr); - - ASSERT_NE(context->GetDynamicInputShape(2, 0), nullptr); - EXPECT_EQ(*context->GetDynamicInputShape(2, 0), in_tensor1.GetShape()); - - ASSERT_NE(context->GetDynamicInputTensor(2, 0), nullptr); - EXPECT_EQ(context->GetDynamicInputTensor(2, 0)->GetOriginShape(), in_tensor1.GetOriginShape()); - EXPECT_EQ(context->GetDynamicInputTensor(2, 0)->GetStorageShape(), in_tensor1.GetStorageShape()); - EXPECT_EQ(context->GetDynamicInputTensor(2, 0)->GetOriginFormat(), in_tensor1.GetOriginFormat()); - EXPECT_EQ(context->GetDynamicInputTensor(2, 0)->GetStorageFormat(), in_tensor1.GetStorageFormat()); - EXPECT_EQ(context->GetDynamicInputTensor(2, 0)->GetExpandDimsType(), in_tensor1.GetExpandDimsType()); - EXPECT_EQ(context->GetDynamicInputTensor(2, 0)->GetDataType(), in_tensor1.GetDataType()); - EXPECT_EQ(context->GetDynamicInputTensor(2, 0)->GetAddr(), in_tensor1.GetAddr()); - - ASSERT_NE(context->GetDynamicInputShape(2, 1), nullptr); - EXPECT_EQ(*context->GetDynamicInputShape(2, 1), in_tensor2.GetShape()); - - ASSERT_NE(context->GetDynamicInputTensor(2, 1), nullptr); - EXPECT_EQ(context->GetDynamicInputTensor(2, 1)->GetOriginShape(), in_tensor2.GetOriginShape()); - EXPECT_EQ(context->GetDynamicInputTensor(2, 1)->GetStorageShape(), in_tensor2.GetStorageShape()); - EXPECT_EQ(context->GetDynamicInputTensor(2, 1)->GetOriginFormat(), in_tensor2.GetOriginFormat()); - EXPECT_EQ(context->GetDynamicInputTensor(2, 1)->GetStorageFormat(), in_tensor2.GetStorageFormat()); - EXPECT_EQ(context->GetDynamicInputTensor(2, 1)->GetExpandDimsType(), in_tensor2.GetExpandDimsType()); - EXPECT_EQ(context->GetDynamicInputTensor(2, 1)->GetDataType(), in_tensor2.GetDataType()); - EXPECT_EQ(context->GetDynamicInputTensor(2, 1)->GetAddr(), in_tensor2.GetAddr()); - - ASSERT_NE(context->GetRequiredInputShape(3), nullptr); - EXPECT_EQ(*context->GetRequiredInputShape(3), in_tensor3.GetShape()); - - ASSERT_NE(context->GetRequiredInputTensor(3), nullptr); - EXPECT_EQ(context->GetRequiredInputShape(3)->GetOriginShape(), in_tensor3.GetOriginShape()); - EXPECT_EQ(context->GetRequiredInputShape(3)->GetStorageShape(), in_tensor3.GetStorageShape()); - EXPECT_EQ(context->GetRequiredInputTensor(3)->GetOriginShape(), in_tensor3.GetOriginShape()); - EXPECT_EQ(context->GetRequiredInputTensor(3)->GetStorageShape(), in_tensor3.GetStorageShape()); - EXPECT_EQ(context->GetRequiredInputTensor(3)->GetOriginFormat(), in_tensor3.GetOriginFormat()); - EXPECT_EQ(context->GetRequiredInputTensor(3)->GetStorageFormat(), in_tensor3.GetStorageFormat()); - EXPECT_EQ(context->GetRequiredInputTensor(3)->GetExpandDimsType(), in_tensor3.GetExpandDimsType()); - EXPECT_EQ(context->GetRequiredInputTensor(3)->GetDataType(), in_tensor3.GetDataType()); - EXPECT_EQ(context->GetRequiredInputTensor(3)->GetAddr(), in_tensor3.GetAddr()); -} -} // namespace gert diff --git a/tests/ut/exe_graph/tiling_parse_context_builder_unittest.cc b/tests/ut/exe_graph/tiling_parse_context_builder_unittest.cc deleted file mode 100644 index c49fce3c338cf8986e10a3e26831eeec2728c25e..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/tiling_parse_context_builder_unittest.cc +++ /dev/null @@ -1,99 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "lowering/tiling_parse_context_builder.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/graph_utils.h" -#include "exe_graph/runtime/tiling_parse_context.h" -#include "platform/platform_infos_def.h" -#include "common/ge_common/util.h" -#include "register/op_impl_space_registry.h" -#include "register/op_impl_registry.h" -#include "faker/node_faker.h" - -namespace gert { -class TilingParseContextBuilderUT : public testing::Test {}; - -TEST_F(TilingParseContextBuilderUT, CompileInfoNullptr) { - fe::PlatFormInfos platform_infos; - auto builder = TilingParseContextBuilder(); - - auto node = ComputeNodeFaker().NameAndType("bar", "Bar").IoNum(2, 1).InputNames({"x", "y"}).Build(); - auto op = ge::OpDescUtils::CreateOperatorFromNode(node->shared_from_this()); - auto holder = builder.CompileJson(nullptr).PlatformInfo(&platform_infos).Build(op); - EXPECT_NE(holder.context_, nullptr); -} - -TEST_F(TilingParseContextBuilderUT, PlatformInfosNullptr) { - std::string op_compile_info_json = "{}"; - fe::PlatFormInfos platform_infos; - auto builder = TilingParseContextBuilder(); - - auto node = ComputeNodeFaker().NameAndType("bar", "Bar").IoNum(2, 1).InputNames({"x", "y"}).Build(); - auto op = ge::OpDescUtils::CreateOperatorFromNode(node->shared_from_this()); - auto holder = builder.CompileJson(op_compile_info_json.c_str()).PlatformInfo(nullptr).Build(op); - EXPECT_NE(holder.context_, nullptr); -} - -TEST_F(TilingParseContextBuilderUT, TilingFuncNullptr) { - std::string op_compile_info_json = "{}"; - fe::PlatFormInfos platform_infos; - auto builder = TilingParseContextBuilder(); - - // construct op - auto node = ComputeNodeFaker().NameAndType("bar", "Bar").IoNum(2, 1).InputNames({"x", "y"}).Build(); - ge::OpDescPtr op_desc = node->GetOpDesc(); - ge::GeTensorDesc tensor_desc(ge::GeShape({1})); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - op_desc->MutableInputDesc(1)->SetDataType(ge::DT_INT32); - op_desc->MutableInputDesc(1)->SetShape(ge::GeShape({1})); - op_desc->MutableInputDesc(1)->SetOriginShape(ge::GeShape({1})); - auto op = ge::OpDescUtils::CreateOperatorFromNode(node->shared_from_this()); - - auto holder = builder.CompileJson(op_compile_info_json.c_str()).PlatformInfo(&platform_infos).Build(op); - EXPECT_NE(holder.context_, nullptr); -} - -TEST_F(TilingParseContextBuilderUT, BuildSuccess) { - std::string op_compile_info_json = "{123}"; - fe::PlatFormInfos platform_infos; - auto builder = TilingParseContextBuilder(); - - // construct op - auto node = ComputeNodeFaker().NameAndType("bar", "Bar").IoNum(2, 1).InputNames({"x", "y"}).Build(); - ge::OpDescPtr op_desc = node->GetOpDesc(); - ge::GeTensorDesc tensor_desc(ge::GeShape({1})); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - op_desc->MutableInputDesc(1)->SetDataType(ge::DT_INT32); - op_desc->MutableInputDesc(1)->SetShape(ge::GeShape({1})); - op_desc->MutableInputDesc(1)->SetOriginShape(ge::GeShape({1})); - auto op = ge::OpDescUtils::CreateOperatorFromNode(node->shared_from_this()); - - OpImplRegisterV2::CompileInfoCreatorFunc create_func = []() -> void * { return new int32_t(0); }; - - OpImplRegisterV2::CompileInfoDeleterFunc delete_func = [](void *ptr) { delete reinterpret_cast(ptr); }; - - auto holder = builder.CompileJson(op_compile_info_json.c_str()) - .PlatformInfo(&platform_infos) - .CompileInfoCreatorFunc(create_func) - .CompileInfoDeleterFunc(delete_func) - .Build(op); - EXPECT_NE(holder.GetKernelContext(), nullptr); - auto tiling_parse_context = reinterpret_cast(holder.context_); - EXPECT_NE(tiling_parse_context->GetCompiledInfo(), nullptr); - EXPECT_NE(tiling_parse_context->GetPlatformInfo(), nullptr); - EXPECT_EQ(*tiling_parse_context->GetCompiledInfo(), 0); - EXPECT_STREQ(tiling_parse_context->GetCompiledJson(), "{123}"); -} -} // namespace gert diff --git a/tests/ut/exe_graph/tiling_parse_context_unittest.cc b/tests/ut/exe_graph/tiling_parse_context_unittest.cc deleted file mode 100644 index 565285b247edd850487e8e25909cb8b30e869878..0000000000000000000000000000000000000000 --- a/tests/ut/exe_graph/tiling_parse_context_unittest.cc +++ /dev/null @@ -1,60 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "exe_graph/runtime/tiling_parse_context.h" -#include -#include -#include "faker/kernel_run_context_faker.h" -#include "platform/platform_info.h" -namespace gert { -class TilingParseContextUT : public testing::Test {}; -struct CompiledInfo1 { - uint64_t a; - uint64_t b; -}; -struct CompileDInfo2 { - uint32_t a; - uint32_t b; -}; -struct CompiledInfo3 { - int32_t core_num; -}; - -TEST_F(TilingParseContextUT, GetIoOk) { - char *json_str = const_cast("{}"); - CompiledInfo1 ci = {10, 20}; - auto context_holder = KernelRunContextFaker().KernelIONum(1, 1).Inputs({json_str}).Outputs({&ci}).Build(); - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - EXPECT_STREQ(context->GetCompiledJson(), "{}"); - ASSERT_NE(context->GetCompiledInfo(), nullptr); - EXPECT_EQ(context->GetCompiledInfo()->a, 10); - EXPECT_EQ(context->GetCompiledInfo()->b, 20); -} -TEST_F(TilingParseContextUT, SetCompiledInfoOk) { - char *json_str = const_cast("{}"); - fe::PlatFormInfos platform_info; - auto context_holder = KernelRunContextFaker().KernelIONum(1, 1) - .Inputs({json_str, reinterpret_cast(&platform_info)}) - .Outputs({nullptr}).Build(); - auto context = context_holder.GetContext(); - EXPECT_EQ(context->GetPlatformInfo()->GetCoreNum(), 8); -} - -TEST_F(TilingParseContextUT, CompiledInfoLessThan8Bytes) { - char *json_str = const_cast("{}"); - CompiledInfo3 ci = {2}; - auto context_holder = KernelRunContextFaker().KernelIONum(1, 1).Inputs({json_str}).Outputs({&ci}).Build(); - auto context = context_holder.GetContext(); - ASSERT_NE(context, nullptr); - EXPECT_STREQ(context->GetCompiledJson(), "{}"); - ASSERT_NE(context->GetCompiledInfo(), nullptr); - EXPECT_EQ(context->GetCompiledInfo()->core_num, 2); -} -} // namespace gert diff --git a/tests/ut/expression/CMakeLists.txt b/tests/ut/expression/CMakeLists.txt deleted file mode 100644 index 8ada96ac12c5708c521802739fbbe4e3fe7465fd..0000000000000000000000000000000000000000 --- a/tests/ut/expression/CMakeLists.txt +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright (c) 2024 Huawei Technologies Co., Ltd. -# This file is a part of the CANN Open Software. -# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. -# ====================================================================================================================== - -# include directories -include_directories(${CMAKE_CURRENT_LIST_DIR}) -include_directories(${METADEF_DIR}/inc/common/util/trace_manager) -include_directories(${CMAKE_BINARY_DIR}/proto/metadef_protos) -include_directories(${CMAKE_BINARY_DIR}/proto/metadef_protos/proto) -include_directories(${METADEF_DIR}) -include_directories(${METADEF_DIR}/exe_graph) -include_directories(${METADEF_DIR}/tests/depends) -include_directories(${METADEF_DIR}/graph) -include_directories(${CMAKE_BINARY_DIR}) -include_directories(${CMAKE_BINARY_DIR}/proto/ge) -include_directories(${CMAKE_BINARY_DIR}/proto/ge/proto) - -file(GLOB_RECURSE UT_FILES CONFIGURE_DEPENDS "${METADEF_DIR}/tests/ut/expression/testcase/*.cc") -file(GLOB_RECURSE FAKER_SRCS CONFIGURE_DEPENDS "${METADEF_DIR}/tests/depends/faker/*.cc") -add_executable(ut_expression ${UT_FILES} ${FAKER_SRCS}) - -target_compile_options(ut_expression PRIVATE - -g --coverage -fprofile-arcs -ftest-coverage - -Wno-deprecated-declarations - -Wall -Wfloat-equal -Werror - -D_GLIBCXX_USE_CXX11_ABI=0 - -fno-access-control -) - -target_compile_definitions(ut_expression PRIVATE - $<$:ONLY_COMPILE_OPEN_SRC> - google=ascend_private - FUNC_VISIBILITY -) - -# intf_pub包含stdc++11 当前需要按照更高版本来编译ut -target_link_libraries(ut_expression PRIVATE - -lgcov - -Wl,--no-as-needed - platform_stub - slog_headers - metadef_headers - register - rt2_registry_static - opp_registry - slog - symengine Boost::boost aihac_symbolizer - GTest::gtest GTest::gtest_main ascend_protobuf error_manager graph_base graph slog_stub c_sec json mmpa_stub -lrt -ldl -) - -target_include_directories(ut_expression PRIVATE - ${METADEF_DIR}/tests/depends - ${METADEF_DIR}/tests/ut/ascendc_ir -) diff --git a/tests/ut/expression/testcase/attr_group_shape_env_unittest.cc b/tests/ut/expression/testcase/attr_group_shape_env_unittest.cc deleted file mode 100644 index 4665e20c7bb50ac24f809aaded9db81bdc931771..0000000000000000000000000000000000000000 --- a/tests/ut/expression/testcase/attr_group_shape_env_unittest.cc +++ /dev/null @@ -1,406 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "proto/ge_ir.pb.h" -#include "graph/attr_store.h" -#include "graph/symbolizer/guard_dfx_context.h" - -#include "attribute_group/attr_group_shape_env.h" - -#include -#include "source_stub.h" -namespace ge { -namespace { -class AttributeGroupShapeEnvUt : public testing::Test {}; - -TEST_F(AttributeGroupShapeEnvUt, ShapeEnvAttrDeserialize_Succ) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting(true, DynamicMode::kDynamic)); - Symbol s0 = shape_env.CreateSymbol(2, MakeShared(1, 2)); - Symbol s1 = shape_env.CreateSymbol(5, MakeShared(2, 0)); - - auto guard_1 = sym::Eq(s1 + Symbol(2), s0); - auto guard_2 = sym::Gt(s0 + Symbol(1), s1); - shape_env.AppendSymbolCheckInfo(guard_1); - shape_env.AppendSymbolAssertInfo(guard_2); - shape_env.AppendReplacement(s1 + Symbol(2), s0); - proto::AttrGroupDef attr_group_def; - auto ret = shape_env.Serialize(attr_group_def); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - auto shape_env1 = ShapeEnvAttr(); - ret = shape_env1.Deserialize(attr_group_def, nullptr); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(2, shape_env1.symbol_to_value_.size()); - EXPECT_EQ(2, shape_env1.value_to_symbol_.size()); - EXPECT_EQ(1, shape_env1.GetAllSymbolCheckInfos().size()); - EXPECT_EQ(1, shape_env1.GetAllSymbolAssertInfos().size()); - EXPECT_EQ(Symbol("s0"), shape_env1.value_to_symbol_[2][0]); - EXPECT_EQ(Symbol("s1"), shape_env1.value_to_symbol_[5][0]); - EXPECT_EQ(true, shape_env1.HasSymbolCheckInfo(sym::Eq(Symbol("s1") + Symbol(2), Symbol("s0")))); - EXPECT_EQ(true, shape_env1.HasSymbolAssertInfo(sym::Gt(Symbol("s0") + Symbol(1), Symbol("s1")))); -} - -TEST_F(AttributeGroupShapeEnvUt, CreateSymbolDuckMode_Succ) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting(false, DynamicMode::kDuck)); - Symbol sym = shape_env.CreateSymbol(2, MakeShared(0, 2)); - EXPECT_EQ(1, shape_env.symbol_to_value_.size()); - EXPECT_EQ(1, shape_env.value_to_symbol_.size()); - EXPECT_EQ(1, shape_env.symbol_to_source_.size()); - EXPECT_EQ(sym, shape_env.value_to_symbol_[2][0]); - EXPECT_EQ(2, shape_env.symbol_to_value_[sym]); - EXPECT_EQ(0, shape_env.symbol_to_source_[sym]->GetGlobalIndex()); - EXPECT_EQ(std::string(sym.Serialize().get()), "s0"); - Symbol sym1 = shape_env.CreateSymbol(2, MakeShared(1, 0)); - EXPECT_EQ(sym, sym1); -} - -TEST_F(AttributeGroupShapeEnvUt, CreateSymbolSpecializeZeroOne_Succ) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting(true, DynamicMode::kDynamic)); - Symbol sym = shape_env.CreateSymbol(1, MakeShared(0, 2)); - EXPECT_EQ(0, shape_env.symbol_to_value_.size()); - EXPECT_EQ(0, shape_env.value_to_symbol_.size()); - EXPECT_EQ(0, shape_env.symbol_to_source_.size()); - EXPECT_EQ(std::string(sym.Serialize().get()), "1"); - Symbol sym1 = shape_env.CreateSymbol(0, MakeShared(0, 3)); - EXPECT_EQ(0, shape_env.symbol_to_value_.size()); - EXPECT_EQ(0, shape_env.value_to_symbol_.size()); - EXPECT_EQ(0, shape_env.symbol_to_source_.size()); - EXPECT_EQ(std::string(sym1.Serialize().get()), "0"); -} - -TEST_F(AttributeGroupShapeEnvUt, CreateSymbolDynamicMode_Succ) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting(false, DynamicMode::kDynamic)); - Symbol sym = shape_env.CreateSymbol(2, MakeShared(0, 2)); - auto symbol_relation = shape_env.GetAllSym2Src(); - EXPECT_EQ(symbol_relation.size(), 1); - EXPECT_EQ(symbol_relation[0].first, sym); - EXPECT_EQ(symbol_relation[0].second->GetSourceStr(), MakeShared(0, 2)->GetSourceStr()); - EXPECT_EQ(1, shape_env.symbol_to_value_.size()); - EXPECT_EQ(1, shape_env.value_to_symbol_.size()); - EXPECT_EQ(1, shape_env.symbol_to_source_.size()); - EXPECT_EQ(1, shape_env.value_to_symbol_[2].size()); - EXPECT_EQ(sym, shape_env.value_to_symbol_[2][0]); - EXPECT_EQ(2, shape_env.symbol_to_value_[sym]); - EXPECT_EQ(0, shape_env.symbol_to_source_[sym]->GetGlobalIndex()); - EXPECT_EQ(std::string(sym.Serialize().get()), "s0"); - Symbol sym1 = shape_env.CreateSymbol(2, MakeShared(1, 0)); - EXPECT_EQ(2, shape_env.symbol_to_value_.size()); - EXPECT_EQ(1, shape_env.value_to_symbol_.size()); - EXPECT_EQ(2, shape_env.symbol_to_source_.size()); - EXPECT_EQ(2, shape_env.value_to_symbol_[2].size()); - EXPECT_EQ(sym, shape_env.value_to_symbol_[2][0]); - EXPECT_EQ(sym1, shape_env.value_to_symbol_[2][1]); - EXPECT_EQ(2, shape_env.symbol_to_value_[sym1]); - EXPECT_EQ(1, shape_env.symbol_to_source_[sym1]->GetGlobalIndex()); - EXPECT_EQ(std::string(sym1.Serialize().get()), "s1"); -} - -TEST_F(AttributeGroupShapeEnvUt, CreateSymbolStaticMode_Succ) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting(false, DynamicMode::kStatic)); - SetCurShapeEnvContext(&shape_env); - Symbol sym = shape_env.CreateSymbol(2, MakeShared(0, 2)); - EXPECT_EQ(1, shape_env.symbol_to_value_.size()); - EXPECT_EQ(1, shape_env.value_to_symbol_.size()); - EXPECT_EQ(1, shape_env.symbol_to_source_.size()); - EXPECT_EQ(1, shape_env.value_to_symbol_[2].size()); - EXPECT_EQ(sym, shape_env.value_to_symbol_[2][0]); - EXPECT_EQ(2, shape_env.symbol_to_value_[sym]); - EXPECT_EQ(0, shape_env.symbol_to_source_[sym]->GetGlobalIndex()); - EXPECT_EQ(std::string(sym.Serialize().get()), "s0"); - Symbol sym1 = shape_env.CreateSymbol(2, MakeShared(0, 2)); - EXPECT_EQ(sym, sym1); - EXPECT_EQ(true, shape_env.HasSymbolAssertInfo(sym::Eq(sym, Symbol(2)))); - SetCurShapeEnvContext(nullptr); -} - -TEST_F(AttributeGroupShapeEnvUt, GetInputShapeSourceStr) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting(false, DynamicMode::kDynamic)); - Symbol sym_shape = shape_env.CreateSymbol(2, MakeShared(0, 2)); - Symbol sym_value = shape_env.CreateSymbol(3, MakeShared(1, 2)); - - auto shape_str = shape_env.GetAllSym2Src(); - EXPECT_EQ(shape_str.size(), 2); - EXPECT_EQ(shape_str[1].second->GetSourceStr(), R"([&]() -> int64_t { - const auto *tensor = context->GetGraphInputTensor(0); - if (tensor == nullptr) { - return -1; - } - return tensor->GetOriginShape().GetDim(2); - }())"); - EXPECT_EQ(shape_str[0].second->GetSourceStr(), R"([&]() -> int64_t { - const auto *tensor = context->GetGraphInputTensor(1); - if (tensor == nullptr) { - return -1; - } - return tensor->GetOriginShape().GetDim(2); - }())"); - EXPECT_EQ(shape_str[0].second->GetGlobalIndex(), 1); - EXPECT_EQ(shape_str[1].second->GetGlobalIndex(), 0); - EXPECT_EQ(shape_str[0].second->GetGlobalIndexStr(), "context->GetInputPointer(1)"); - EXPECT_EQ(shape_str[1].second->GetGlobalIndexStr(), "context->GetInputPointer(0)"); -} - -TEST_F(AttributeGroupShapeEnvUt, GetSourceStrNew) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting(false, DynamicMode::kDynamic)); - auto source_shape = ge::MakeShared(0, ge::DT_INT32); - Symbol sym_shape = shape_env.CreateSymbol(2, source_shape); - - auto source_value = ge::MakeShared(2, ge::DT_INT32); - Symbol sym_value = shape_env.CreateSymbol(3, source_value); - - auto source_value2 = ge::MakeShared(4, ge::DT_INT64); - Symbol sym_value2 = shape_env.CreateSymbol(4, source_value2); - - auto source_value3 = ge::MakeShared(6, ge::DT_UINT32); - Symbol sym_value3 = shape_env.CreateSymbol(5, source_value3); - - auto source_value4 = ge::MakeShared(8, ge::DT_UINT64); - Symbol sym_value4 = shape_env.CreateSymbol(6, source_value4); - - auto source_value_unsupport = ge::MakeShared(8, ge::DT_FLOAT); - Symbol sym_unsupport = shape_env.CreateSymbol(7, source_value_unsupport); - - auto shape_str = shape_env.GetAllSym2Src(); - EXPECT_EQ(shape_str.size(), 6); - EXPECT_EQ(shape_str[5].second->GetSourceStr(), R"([&]() -> int64_t { - const auto* tensor = context->GetGraphInputTensor(0); - if (tensor == nullptr) { - return -1; - } - const auto* data = tensor->GetData(); - int64_t sum = 0; - for (size_t i = 0; i < tensor->GetSize() / sizeof(int32_t); ++i) { - sum += data[i]; - } - return sum; - }())"); - EXPECT_EQ(shape_str[4].second->GetSourceStr(), R"([&]() -> int64_t { - const auto* tensor = context->GetGraphInputTensor(2); - if (tensor == nullptr) { - return -1; - } - const auto* data = tensor->GetData(); - int64_t sum = 0; - for (size_t i = 0; i < tensor->GetSize() / sizeof(int32_t); ++i) { - sum += data[i]; - } - return sum; - }())"); - EXPECT_EQ(shape_str[3].second->GetSourceStr(), R"([&]() -> int64_t { - const auto* tensor = context->GetGraphInputTensor(4); - if (tensor == nullptr) { - return -1; - } - const auto* data = tensor->GetData(); - int64_t sum = 0; - for (size_t i = 0; i < tensor->GetSize() / sizeof(int64_t); ++i) { - sum += data[i]; - } - return sum; - }())"); - EXPECT_EQ(shape_str[2].second->GetSourceStr(), R"([&]() -> int64_t { - const auto* tensor = context->GetGraphInputTensor(6); - if (tensor == nullptr) { - return -1; - } - const auto* data = tensor->GetData(); - int64_t sum = 0; - for (size_t i = 0; i < tensor->GetSize() / sizeof(uint32_t); ++i) { - sum += data[i]; - } - return sum; - }())"); - EXPECT_EQ(shape_str[1].second->GetSourceStr(), R"([&]() -> int64_t { - const auto* tensor = context->GetGraphInputTensor(8); - if (tensor == nullptr) { - return -1; - } - const auto* data = tensor->GetData(); - int64_t sum = 0; - for (size_t i = 0; i < tensor->GetSize() / sizeof(uint64_t); ++i) { - sum += data[i]; - } - return sum; - }())"); - EXPECT_EQ(shape_str[0].second->GetSourceStr(), ""); -} - -TEST_F(AttributeGroupShapeEnvUt, ShapeEnvAttrClone_Succ) { - auto s = AttrStore::Create(1); - auto shape_env = s.GetOrCreateAttrsGroup(); - EXPECT_NE(shape_env, nullptr); - SetCurShapeEnvContext(shape_env); - shape_env->shape_env_setting_.specialize_zero_one = true; - shape_env->shape_env_setting_.dynamic_mode = DynamicMode::kDynamic; - auto s0 = shape_env->CreateSymbol(10, MakeShared(0, 2)); - auto s1 = shape_env->CreateSymbol(10, MakeShared(0, 2)); - auto s2 = shape_env->CreateSymbol(20, MakeShared(0, 2)); - EXPECT_SYMBOL_EQ(s0 + s1, s2); - - // 测试Clone - auto s_bak = s; - auto shape_env_bak = s_bak.GetAttrsGroup(); - EXPECT_NE(shape_env_bak, nullptr); - EXPECT_EQ(shape_env_bak->shape_env_setting_.specialize_zero_one, true); - EXPECT_EQ(shape_env_bak->shape_env_setting_.dynamic_mode, DynamicMode::kDynamic); - EXPECT_EQ(shape_env_bak->replacements_.size(), 2); - EXPECT_EQ(shape_env_bak->symbol_to_value_.size(), 3); - EXPECT_EQ(shape_env_bak->symbol_to_value_, shape_env->symbol_to_value_); - EXPECT_EQ(shape_env_bak->symbol_to_source_.size(), 3); - EXPECT_EQ(shape_env_bak->symbol_to_source_, shape_env->symbol_to_source_); - EXPECT_EQ(shape_env_bak->value_to_symbol_.size(), 2); - EXPECT_EQ(shape_env_bak->value_to_symbol_, shape_env->value_to_symbol_); - EXPECT_EQ(shape_env_bak->symbol_check_infos_.size(), 1); - EXPECT_EQ(shape_env_bak->symbol_check_infos_, shape_env->symbol_check_infos_); - EXPECT_EQ(shape_env_bak->symbol_assert_infos_.size(),0); - EXPECT_EQ(shape_env_bak->symbol_assert_infos_, shape_env->symbol_assert_infos_); - SetCurShapeEnvContext(nullptr); -} - -TEST_F(AttributeGroupShapeEnvUt, ShapeEnvAttrCopy_Succ) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting(true, DynamicMode::kDynamic)); - SetCurShapeEnvContext(&shape_env); - auto s0 = shape_env.CreateSymbol(10, MakeShared(0, 2)); - auto s1 = shape_env.CreateSymbol(10, MakeShared(0, 2)); - auto s2 = shape_env.CreateSymbol(20, MakeShared(0, 2)); - EXPECT_SYMBOL_EQ(s0 + s1, s2); - - // 测试operator= - auto shape_env_bak = ShapeEnvAttr(ShapeEnvSetting(false, DynamicMode::kDuck)); - shape_env_bak = shape_env; - EXPECT_EQ(shape_env_bak.shape_env_setting_.specialize_zero_one, true); - EXPECT_EQ(shape_env_bak.shape_env_setting_.dynamic_mode, DynamicMode::kDynamic); - EXPECT_EQ(shape_env_bak.replacements_.size(), 2); - EXPECT_EQ(shape_env_bak.symbol_to_value_.size(), 3); - EXPECT_EQ(shape_env_bak.symbol_to_value_, shape_env.symbol_to_value_); - EXPECT_EQ(shape_env_bak.symbol_to_source_.size(), 3); - EXPECT_EQ(shape_env_bak.symbol_to_source_, shape_env.symbol_to_source_); - EXPECT_EQ(shape_env_bak.value_to_symbol_.size(), 2); - EXPECT_EQ(shape_env_bak.value_to_symbol_, shape_env.value_to_symbol_); - EXPECT_EQ(shape_env_bak.symbol_check_infos_.size(), 1); - EXPECT_EQ(shape_env_bak.symbol_check_infos_, shape_env.symbol_check_infos_); - EXPECT_EQ(shape_env_bak.symbol_assert_infos_.size(),0); - EXPECT_EQ(shape_env_bak.symbol_assert_infos_, shape_env.symbol_assert_infos_); - SetCurShapeEnvContext(nullptr); -} - -TEST_F(AttributeGroupShapeEnvUt, CreateShapeEnvFromGraph_Succ) { - auto s = AttrStore::Create(1); - auto shape_env = s.CreateAttrsGroup(ShapeEnvSetting(false, DynamicMode::kDuck)); - EXPECT_NE(shape_env, nullptr); -} - -TEST_F(AttributeGroupShapeEnvUt, Get_Guard_Has_Dfx_Info_When_Set_Guard_Context) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting(true, DynamicMode::kDynamic)); - SetCurShapeEnvContext(&shape_env); - Symbol s0 = shape_env.CreateSymbol(2, MakeShared(1, 2)); - Symbol s1 = shape_env.CreateSymbol(5, MakeShared(2, 0)); - - auto guard_1 = sym::Eq(s1 + Symbol(2), s0); - auto guard_2 = sym::Gt(s0 + Symbol(1), s1); - - GuardDfxContext dfx_context("node name:Add"); - - shape_env.AppendSymbolCheckInfo(guard_1); - shape_env.AppendSymbolAssertInfo(guard_2); - - auto check_infos = shape_env.GetAllSymbolCheckInfos(); - auto assert_infos = shape_env.GetAllSymbolAssertInfos(); - EXPECT_EQ(1, check_infos.size()); - EXPECT_EQ(1, assert_infos.size()); - - EXPECT_EQ("node name:Add", check_infos[0].dfx_info); - EXPECT_EQ("node name:Add", assert_infos[0].dfx_info); - SetCurShapeEnvContext(nullptr); -} - -TEST_F(AttributeGroupShapeEnvUt, Get_Guard_Has_No_Dfx_Info_When_Clear_Guard_Context) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting(true, DynamicMode::kDynamic)); - SetCurShapeEnvContext(&shape_env); - Symbol s0 = shape_env.CreateSymbol(2, MakeShared(1, 2)); - Symbol s1 = shape_env.CreateSymbol(5, MakeShared(2, 0)); - - auto guard_1 = sym::Eq(s1 + Symbol(2), s0); - auto guard_2 = sym::Gt(s0 + Symbol(1), s1); - - { - GuardDfxContext dfx_context("node name:Add"); - shape_env.AppendSymbolCheckInfo(guard_1); - } - - auto check_infos = shape_env.GetAllSymbolCheckInfos(); - EXPECT_EQ(1, check_infos.size()); - EXPECT_EQ("node name:Add", check_infos[0].dfx_info); - - shape_env.AppendSymbolAssertInfo(guard_2); - auto assert_infos = shape_env.GetAllSymbolAssertInfos(); - EXPECT_EQ(1, assert_infos.size()); - EXPECT_EQ("", assert_infos[0].dfx_info); - SetCurShapeEnvContext(nullptr); -} - -TEST_F(AttributeGroupShapeEnvUt, Get_Guard_Has_New_Dfx_Info_When_Set_Guard_Context_Twice) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting(true, DynamicMode::kDynamic)); - SetCurShapeEnvContext(&shape_env); - Symbol s0 = shape_env.CreateSymbol(2, MakeShared(1, 2)); - Symbol s1 = shape_env.CreateSymbol(5, MakeShared(2, 0)); - - auto guard_1 = sym::Eq(s1 + Symbol(2), s0); - auto guard_2 = sym::Gt(s0 + Symbol(1), s1); - - GuardDfxContext dfx_context("node name:Add"); - shape_env.AppendSymbolCheckInfo(guard_1); - - GuardDfxContext dfx_context1("node name:Sub"); - shape_env.AppendSymbolAssertInfo(guard_2); - - auto check_infos = shape_env.GetAllSymbolCheckInfos(); - auto assert_infos = shape_env.GetAllSymbolAssertInfos(); - EXPECT_EQ(1, check_infos.size()); - EXPECT_EQ(1, assert_infos.size()); - - EXPECT_EQ("node name:Add", check_infos[0].dfx_info); - EXPECT_EQ("node name:Sub", assert_infos[0].dfx_info); - SetCurShapeEnvContext(nullptr); -} - -TEST_F(AttributeGroupShapeEnvUt, Get_Deserialize_Guard_Has_New_Dfx_Info_When_Set_Guard_Context_Serialize) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting(true, DynamicMode::kDynamic)); - SetCurShapeEnvContext(&shape_env); - Symbol s0 = shape_env.CreateSymbol(2, MakeShared(1, 2)); - Symbol s1 = shape_env.CreateSymbol(5, MakeShared(2, 0)); - - auto guard_1 = sym::Eq(s1 + Symbol(2), s0); - auto guard_2 = sym::Gt(s0 + Symbol(1), s1); - GuardDfxContext dfx_context("node name:Add"); - shape_env.AppendSymbolCheckInfo(guard_1); - shape_env.AppendSymbolAssertInfo(guard_2); - - proto::AttrGroupDef attr_group_def; - auto ret = shape_env.Serialize(attr_group_def); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - auto shape_env1 = ShapeEnvAttr(); - ret = shape_env1.Deserialize(attr_group_def, nullptr); - auto check_infos = shape_env1.GetAllSymbolCheckInfos(); - auto assert_infos = shape_env1.GetAllSymbolAssertInfos(); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(2, shape_env1.symbol_to_value_.size()); - EXPECT_EQ(2, shape_env1.value_to_symbol_.size()); - EXPECT_EQ(Symbol("s0"), shape_env1.value_to_symbol_[2][0]); - EXPECT_EQ(Symbol("s1"), shape_env1.value_to_symbol_[5][0]); - EXPECT_EQ(true, shape_env1.HasSymbolCheckInfo(sym::Eq(Symbol("s1") + Symbol(2), Symbol("s0")))); - EXPECT_EQ(true, shape_env1.HasSymbolAssertInfo(sym::Gt(Symbol("s0") + Symbol(1), Symbol("s1")))); - EXPECT_EQ("node name:Add", check_infos[0].dfx_info); - EXPECT_EQ("node name:Add", assert_infos[0].dfx_info); - SetCurShapeEnvContext(nullptr); -} -} -} // namespace ge diff --git a/tests/ut/expression/testcase/attr_group_symbolic_unittest.cc b/tests/ut/expression/testcase/attr_group_symbolic_unittest.cc deleted file mode 100644 index 0bdcc2c4ed5af4eb2e62460bfb0ad7d5a36ba851..0000000000000000000000000000000000000000 --- a/tests/ut/expression/testcase/attr_group_symbolic_unittest.cc +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include "attribute_group/attr_group_base.h" -#include "attribute_group/attr_group_shape_env.h" -#include "attribute_group/attr_group_symbolic_desc.h" -#include "proto/ge_ir.pb.h" - -namespace ge { -namespace { -class AttrGroupsBaseTest : public AttrGroupsBase { -public: - virtual std::unique_ptr Clone() {return nullptr;}; -}; - -class AttributeGroupUt : public testing::Test {}; - -TEST_F(AttributeGroupUt, Clone) { - EXPECT_NO_THROW( - auto symbolicDesc = SymbolicDescAttr().Clone(); - auto shapeEnv = ShapeEnvAttr().Clone(); - proto::AttrGroupDef attr_group_def; - auto base = AttrGroupsBaseTest(); - EXPECT_EQ(GRAPH_SUCCESS, base.Serialize(attr_group_def)); - EXPECT_EQ(GRAPH_SUCCESS, base.Deserialize(attr_group_def, nullptr)); - ); -} - -TEST_F(AttributeGroupUt, SymbolicDescAttrSerialize) { - auto symbolicDesc = SymbolicDescAttr(); - symbolicDesc.symbolic_tensor.MutableOriginSymbolShape().AppendDim(ge::Symbol("s0")); - symbolicDesc.symbolic_tensor.MutableSymbolicValue()->emplace_back(ge::Symbol("s2")); - proto::AttrGroupDef attr_group_def; - auto ret = symbolicDesc.Serialize(attr_group_def); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(AttributeGroupUt, SymbolicDescAttrDeserialize) { - auto symbolicDesc = SymbolicDescAttr(); - symbolicDesc.symbolic_tensor.MutableOriginSymbolShape().AppendDim(ge::Symbol("s0")); - symbolicDesc.symbolic_tensor.MutableSymbolicValue()->emplace_back(ge::Symbol("s2")); - proto::AttrGroupDef attr_group_def; - EXPECT_EQ(GRAPH_SUCCESS, symbolicDesc.Serialize(attr_group_def)); - - auto symbolicDesc1 = SymbolicDescAttr(); - auto ret = symbolicDesc1.Deserialize(attr_group_def, nullptr); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(1, symbolicDesc1.symbolic_tensor.origin_symbol_shape_.GetDimNum()); - EXPECT_EQ(ge::Symbol("s0"), symbolicDesc1.symbolic_tensor.origin_symbol_shape_.GetDim(0)); - EXPECT_EQ(1, symbolicDesc1.symbolic_tensor.GetSymbolicValue()->size()); - EXPECT_EQ(ge::Symbol("s2"), *symbolicDesc1.symbolic_tensor.GetSymbolicValue()->begin()); -} -} -} // namespace ge \ No newline at end of file diff --git a/tests/ut/expression/testcase/expression_unittest.cc b/tests/ut/expression/testcase/expression_unittest.cc deleted file mode 100644 index 60dffab026131f07c63b9ad202a2a656c236d86b..0000000000000000000000000000000000000000 --- a/tests/ut/expression/testcase/expression_unittest.cc +++ /dev/null @@ -1,2343 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "common/checker.h" -#include "graph/symbolizer/symbolic.h" -#include "expression/expression_impl.h" -#include "expression/expr_parser.h" - -#include -#include "exe_graph/runtime/infer_symbol_shape_context.h" -#include "faker/kernel_run_context_faker.h" -#include "attribute_group/attr_group_shape_env.h" -#include "graph/symbolizer/symbolic_utils.h" -#include "source_stub.h" -namespace ge { -class UtestExpression : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; -using namespace ge; -using namespace ge::sym; - -TEST_F(UtestExpression, TestBasic) { - auto symbol2 = Symbol(2); - EXPECT_EQ(symbol2, 2); - auto symbol4 = Symbol(4); - EXPECT_EQ(symbol4.IsValid(), true); - EXPECT_EQ(symbol4, 4); - EXPECT_EQ(symbol2 + symbol4, 6); - EXPECT_EQ(symbol4 - symbol2, 2); - EXPECT_EQ(symbol4 * symbol2, 8); - EXPECT_EQ(symbol4 / symbol2, 2); - - auto e2 = Expression(symbol2); - EXPECT_EQ(e2, 2); - EXPECT_EQ(e2.IsValid(), true); - Expression copy_e2; - copy_e2 = e2; - EXPECT_EQ(copy_e2, 2); - auto e4 = Expression(symbol4); - EXPECT_EQ(e4, 4); - EXPECT_EQ(e2 + e4, 6); - EXPECT_EQ(e4 - e2, 2); - EXPECT_EQ(e4 * e2, 8); - EXPECT_EQ(e4 / e2, 2); -} - -TEST_F(UtestExpression, GetExprType) { - auto symbol1 = Symbol(2); - EXPECT_EQ(symbol1, 2); - EXPECT_NE(symbol1, 3); - - auto symbol2 = Symbol(2.0); - EXPECT_EQ(symbol2, 2.0); - EXPECT_NE(symbol2, 2); - - EXPECT_EQ(symbol1.GetExprType(), ExprType::kExprConstantInteger); - - auto symbol_uint32 = Symbol(2); - EXPECT_EQ(symbol_uint32.GetExprType(), ExprType::kExprConstantInteger); - - auto symbol_int64 = Symbol(2l); - EXPECT_EQ(symbol_int64.GetExprType(), ExprType::kExprConstantInteger); - - auto symbol_uint64 = Symbol(2lu); - EXPECT_EQ(symbol_uint64.GetExprType(), ExprType::kExprConstantInteger); - - auto symbol_double = Symbol(2.5); - EXPECT_EQ(symbol_double.GetExprType(), ExprType::kExprConstantRealDouble); - - auto symbol_num = Symbol(2); - auto symbol_den = Symbol(3); - auto ret = Div(symbol_num, symbol_den); - EXPECT_EQ(ret.GetExprType(), ExprType::kExprConstantRation); - - ret = Mul(symbol_num, symbol_den); - EXPECT_EQ(ret.GetExprType(), ExprType::kExprConstantInteger); - - ret = Add(symbol_num, symbol_den); - EXPECT_EQ(ret.GetExprType(), ExprType::kExprConstantInteger); - - ret = Sub(symbol_num, symbol_den); - EXPECT_EQ(ret.GetExprType(), ExprType::kExprConstantInteger); - - ret = Max(symbol_num, symbol_den); - EXPECT_EQ(ret.GetExprType(), ExprType::kExprConstantInteger); - - ret = Min(symbol_num, symbol_den); - EXPECT_EQ(ret.GetExprType(), ExprType::kExprConstantInteger); - - ret = Ceiling(symbol_double); - EXPECT_EQ(ret.GetExprType(), ExprType::kExprConstantInteger); - - ret = Pow(symbol_num, symbol_den); - EXPECT_EQ(ret.GetExprType(), ExprType::kExprConstantInteger); - - ret = Mod(symbol_num, symbol_den); - EXPECT_EQ(ret.GetExprType(), ExprType::kExprConstantInteger); - - ret = Log(symbol_num, symbol_den); - EXPECT_EQ(ret.GetExprType(), ExprType::kExprOperation); - - symbol2 = Symbol("a"); - EXPECT_EQ(symbol2.GetExprType(), ExprType::kExprVariable); - - auto expr3 = Add(symbol1, symbol2); - EXPECT_EQ(expr3.GetExprType(), ExprType::kExprOperation); - EXPECT_NE(expr3, symbol2); - - expr3 = Sub(symbol1, symbol2); - EXPECT_EQ(expr3.GetExprType(), ExprType::kExprOperation); - - expr3 = Mul(symbol1, symbol2); - EXPECT_EQ(expr3.GetExprType(), ExprType::kExprOperation); - - expr3 = Div(symbol1, symbol2); - EXPECT_EQ(expr3.GetExprType(), ExprType::kExprOperation); - - expr3 = Max(symbol1, symbol2); - EXPECT_EQ(expr3.GetExprType(), ExprType::kExprOperation); - - expr3 = Min(symbol1, symbol2); - EXPECT_EQ(expr3.GetExprType(), ExprType::kExprOperation); - - expr3 = Pow(symbol1, symbol2); - EXPECT_EQ(expr3.GetExprType(), ExprType::kExprOperation); - - expr3 = Mod(symbol1, symbol2); - EXPECT_EQ(expr3.GetExprType(), ExprType::kExprOperation); - - expr3 = Abs(symbol1); // symbol1是常量 - EXPECT_EQ(expr3.GetExprType(), ExprType::kExprConstantInteger); - - expr3 = Abs(symbol2); // symbol2是符号 - EXPECT_EQ(expr3.GetExprType(), ExprType::kExprOperation); - - expr3 = Log(symbol1, symbol2); - EXPECT_EQ(expr3.GetExprType(), ExprType::kExprOperation); - - expr3 = Log(symbol2); - EXPECT_EQ(expr3.GetExprType(), ExprType::kExprOperation); - - expr3 = Ceiling(symbol2); - EXPECT_EQ(expr3.GetExprType(), ExprType::kExprOperation); - - expr3 = Rational(1, 2); - EXPECT_EQ(expr3.GetExprType(), ExprType::kExprConstantRation); - - auto symbol_x = Symbol("x"); - auto symbol_2 = Symbol(2); - auto symbol_n = Symbol(1); - auto symbol = Mul(symbol_x, symbol_2); - expr3 = Coeff(symbol, symbol_x, symbol_n); - EXPECT_EQ(expr3.GetExprType(), ExprType::kExprConstantInteger); - EXPECT_EQ(expr3, 2); - - // 3*x**2 + 2*x*y + 1 - auto expr_coeff_base = - Add(Add(Mul(Symbol(3), Pow(Symbol("x"), Symbol(2))), Mul(Mul(Symbol(2), Symbol("x")), Symbol("y"))), Symbol(1)); - expr3 = Coeff(expr_coeff_base, Symbol("x"), Symbol(2)); - EXPECT_EQ(expr3.GetExprType(), ExprType::kExprConstantInteger); - EXPECT_EQ(expr3, 3); - - expr3 = Coeff(expr_coeff_base, Symbol("x"), Symbol(1)); - EXPECT_EQ(expr3.GetExprType(), ExprType::kExprOperation); - EXPECT_EQ(std::string(expr3.Serialize().get()), "(2 * y)"); - - // -2 - expr3 = Neg(symbol1); - EXPECT_EQ(expr3.GetExprType(), ExprType::kExprConstantInteger);; - EXPECT_EQ(std::string(expr3.Serialize().get()), "-2"); - - // -x - expr3 = Neg(symbol_x); - EXPECT_EQ(expr3.GetExprType(), ExprType::kExprOperation); - EXPECT_EQ(std::string(expr3.Serialize().get()), "(-1 * x)"); - - // -(x + 2) - expr3 = Neg(Add(symbol_x, symbol1)); - EXPECT_EQ(expr3.GetExprType(), ExprType::kExprOperation); - EXPECT_EQ(std::string(expr3.Serialize().get()), "((2 + x) * -1)"); - - // -(2/x) - expr3 = Neg((symbol1 / symbol_x)); - EXPECT_EQ(expr3.GetExprType(), ExprType::kExprOperation); - EXPECT_EQ(std::string(expr3.Serialize().get()), "(-2 / (x))"); - - // a%b - expr3 = Mod(symbol1, symbol_x); - EXPECT_EQ(expr3.GetExprType(), ExprType::kExprOperation); - EXPECT_EQ(std::string(expr3.Serialize().get()), "Mod(2, x)"); -} - -TEST_F(UtestExpression, GetLogicConstExprType_Succ) { - auto symbol_const_0 = Symbol(2); - auto symbol_const_1 = Symbol(2); - auto expr_eq = Eq(symbol_const_0, symbol_const_1); - EXPECT_EQ(expr_eq.GetExprType(), ExprType::kExprConstantBoolean); - auto expr_ne = Ne(symbol_const_0, symbol_const_1); - EXPECT_EQ(expr_ne.GetExprType(), ExprType::kExprConstantBoolean); - auto expr_gt = Gt(symbol_const_0, symbol_const_1); - EXPECT_EQ(expr_gt.GetExprType(), ExprType::kExprConstantBoolean); - auto expr_ge = Ge(symbol_const_0, symbol_const_1); - EXPECT_EQ(expr_ge.GetExprType(), ExprType::kExprConstantBoolean); - auto expr_lt = Lt(symbol_const_0, symbol_const_1); - EXPECT_EQ(expr_lt.GetExprType(), ExprType::kExprConstantBoolean); - auto expr_le = Le(symbol_const_0, symbol_const_1); - EXPECT_EQ(expr_le.GetExprType(), ExprType::kExprConstantBoolean); - auto expr_not = Not(expr_le); - EXPECT_EQ(expr_not.GetExprType(), ExprType::kExprConstantBoolean); - - auto expr_log = Log(symbol_const_0, symbol_const_1); - auto expr_ne_log = Ne(expr_log, symbol_const_0); - EXPECT_EQ(expr_ne_log.GetExprType(), ExprType::kExprConstantBoolean); -} - -TEST_F(UtestExpression, GetLogicExprType_Succ) { - auto symbol_logic_0 = Symbol("s0"); - auto symbol_logic_1 = Symbol("s1"); - auto symbol_const_0 = Symbol(2); - auto symbol_const_1 = Symbol(2); - auto expr_eq = Eq(symbol_logic_0, symbol_logic_1); - EXPECT_EQ(expr_eq.GetExprType(), ExprType::kExprOperationBoolean); - - auto expr_log = Log(symbol_const_0, symbol_const_1); - auto expr_gt = Gt(expr_log, symbol_logic_0); - EXPECT_EQ(expr_gt.GetExprType(), ExprType::kExprOperationBoolean); - - auto expr_ge = Ge(symbol_logic_0 + symbol_logic_1, symbol_const_0); - EXPECT_EQ(expr_ge.GetExprType(), ExprType::kExprOperationBoolean); - - auto expr_lt = Lt(symbol_const_0 + symbol_const_1, symbol_logic_1); - EXPECT_EQ(expr_lt.GetExprType(), ExprType::kExprOperationBoolean); - - auto expr_le = Le(symbol_logic_0 * symbol_const_1, symbol_const_1); - EXPECT_EQ(expr_le.GetExprType(), ExprType::kExprOperationBoolean); - - auto expr_not = Not(expr_le); - EXPECT_EQ(expr_not.GetExprType(), ExprType::kExprOperationBoolean); -} - -TEST_F(UtestExpression, EqSerializeAndDeserialize_Succ) { - auto s0 = Symbol("s0"); - auto s1 = Symbol("s1"); - auto s2 = Symbol("s2"); - auto expr = Eq(Pow(s0, s1) + Max(s1, s2) * s1, (Min(s0, s2) + s1) * (Div(s2, s1) + s1)); - const std::string expr_str = std::string(expr.Serialize().get()); - EXPECT_EQ(expr_str, "ExpectEq((((s2 / (s1)) + s1) * (Min(s0, s2) + s1)), ((Max(s1, s2) * s1) + Pow(s0, s1)))"); - auto expr_parser = Expression::Deserialize(expr_str.c_str()); - EXPECT_EQ(expr_parser, expr); -} - -TEST_F(UtestExpression, NeSerializeAndDeserialize_Succ) { - auto s0 = Symbol("s0"); - auto s1 = Symbol("s1"); - auto s2 = Symbol("s2"); - auto expr = Ne(s0 + s1 + s2, s1 + s0 * s2); - const std::string expr_str = std::string(expr.Serialize().get()); - EXPECT_EQ(expr_str, "ExpectNe(((s0 * s2) + s1), (s0 + s1 + s2))"); - auto expr_parser = Expression::Deserialize(expr_str.c_str()); - EXPECT_EQ(expr_parser, expr); -} - -TEST_F(UtestExpression, ParseNumer_Failed) { - // out_of_range - auto expr_parser = Expression::Parse("11111111111111111111111111111111111000000000000000000000000000000+s1"); - EXPECT_EQ(expr_parser.Str().get(), nullptr); - - // invalid_argument - ge::Scanner scanner(""); - ge::ExprParser ep(scanner); - ep.currentToken_.value = "this is not a number"; - EXPECT_EQ(ep.ParserNumber(), nullptr); -} - -TEST_F(UtestExpression, GtSerializeAndDeserialize_Succ) { - auto s0 = Symbol("s0"); - auto s1 = Symbol("s1"); - auto s2 = Symbol("s2"); - auto expr = Gt(s0 + Symbol(3) * s2, (Symbol(3) + s1) * (s0 + s2)); - const std::string expr_str = std::string(expr.Serialize().get()); - // 大于可以被转化为小于 - EXPECT_EQ(expr_str, "ExpectLt(((3 + s1) * (s0 + s2)), ((3 * s2) + s0))"); - auto expr_parser = Expression::Deserialize(expr_str.c_str()); - EXPECT_EQ(expr_parser, expr); -} - -TEST_F(UtestExpression, LeSerializeAndDeserialize_Succ) { - auto expr = Le(Symbol(2), Symbol(3)); - const std::string expr_str = std::string(expr.Serialize().get()); - // 大于可以被转化为小于 - EXPECT_EQ(expr_str, "True"); - auto expr_parser = Expression::Deserialize(expr_str.c_str()); - EXPECT_EQ(expr_parser, expr); -} - -TEST_F(UtestExpression, TestExpressionInvalid) { - Expression expr = Expression::Parse("3 & 8"); - EXPECT_EQ(expr.Simplify().IsValid(), false); - EXPECT_EQ(expr.Replace({}).IsValid(), false); - EXPECT_EQ(expr.Subs({}).IsValid(), false); -} - - -TEST_F(UtestExpression, FalseConstSerializeAndDeserialize_Succ) { - auto expr = Not(Le(Symbol(2), Symbol(3))); - const std::string expr_str = std::string(expr.Serialize().get()); - // 大于可以被转化为小于 - EXPECT_EQ(expr_str, "False"); - auto expr_parser = Expression::Deserialize(expr_str.c_str()); - EXPECT_EQ(expr_parser, expr); -} - -TEST_F(UtestExpression, ConstBoolWithSymbolSerializeAndDeserialize_Succ) { - auto s0 = Symbol("s0"); - auto expr = Eq(Le(Symbol(2), Symbol(3)), Ne(s0, Symbol(2))); - const std::string expr_str = std::string(expr.Serialize().get()); - // 大于可以被转化为小于 - EXPECT_EQ(expr_str, "ExpectEq(True, ExpectNe(2, s0))"); - auto expr_parser = Expression::Deserialize(expr_str.c_str()); - EXPECT_EQ(expr_parser, expr); -} - -TEST_F(UtestExpression, GeSerializeAndDeserialize_Succ) { - auto s0 = Symbol("s0"); - auto s1 = Symbol("s1"); - auto s2 = Symbol("s2"); - auto expr = Ge(Max(s0 * s2, s1) * s2, Ceiling(Symbol(3) + s1 * s0 + s2) + s1); - const std::string expr_str = std::string(expr.Serialize().get()); - // 大于可以被转化为小于 - EXPECT_EQ(expr_str, "ExpectLe((3 + Ceiling(((s0 * s1) + s2)) + s1), (Max(s1, (s0 * s2)) * s2))"); - auto expr_parser = Expression::Deserialize(expr_str.c_str()); - EXPECT_EQ(expr_parser, expr); -} - -TEST_F(UtestExpression, LtSerializeAndDeserialize_Succ) { - auto s0 = Symbol("s0"); - auto s1 = Symbol("s1"); - auto s2 = Symbol("s2"); - auto expr = Lt(Abs(Max(s1, s0 * s2) + Symbol(3)), Symbol(4)); - const std::string expr_str = std::string(expr.Serialize().get()); - // 大于可以被转化为小于 - EXPECT_EQ(expr_str, "ExpectLt(Abs((3 + Max(s1, (s0 * s2)))), 4)"); - auto expr_parser = Expression::Deserialize(expr_str.c_str()); - EXPECT_EQ(expr_parser, expr); -} - -TEST_F(UtestExpression, NotSerializeAndDeserialize_Succ) { - auto s0 = Symbol("s0"); - auto s1 = Symbol("s1"); - auto s2 = Symbol("s2"); - auto expr = Not(Eq(Pow(s1, s0) + Max(s1, s2) * s1, (Min(s0, s2) + s1) * (Div(s2, s1) + s1))); - const std::string expr_str = std::string(expr.Serialize().get()); - EXPECT_EQ(expr_str, "ExpectNe((((s2 / (s1)) + s1) * (Min(s0, s2) + s1)), ((Max(s1, s2) * s1) + Pow(s1, s0)))"); - auto expr_parser = Expression::Deserialize(expr_str.c_str()); - EXPECT_EQ(expr_parser, expr); -} - -TEST_F(UtestExpression, DoubleNotSerializeAndDeserialize_Succ) { - auto s0 = Symbol("s0"); - auto s1 = Symbol("s1"); - auto s2 = Symbol("s2"); - auto expr = Not(Not(Eq(Pow(s0, s1) + Max(s1, s2) * s1, (Min(s0, s2) + s1) * (Div(s2, s1) + s1)))); - const std::string expr_str = std::string(expr.Serialize().get()); - EXPECT_EQ(expr_str, "ExpectEq((((s2 / (s1)) + s1) * (Min(s0, s2) + s1)), ((Max(s1, s2) * s1) + Pow(s0, s1)))"); - auto expr_parser = Expression::Deserialize(expr_str.c_str()); - EXPECT_EQ(expr_parser, expr); -} - -TEST_F(UtestExpression, NotEqualSerializeAndDeserialize_Succ) { - auto s0 = Symbol("s0"); - auto s1 = Symbol("s1"); - auto s2 = Symbol("s2"); - auto expr = Eq(Not(Ne(Symbol(2), s0)), Ge((s1 + s2), s0)); - const std::string expr_str = std::string(expr.Serialize().get()); - EXPECT_EQ(expr_str, "ExpectEq(ExpectEq(2, s0), ExpectLe(s0, (s1 + s2)))"); - auto expr_parser = Expression::Deserialize(expr_str.c_str()); - EXPECT_EQ(expr_parser, expr); -} - -TEST_F(UtestExpression, Str) { - auto var_b = Symbol("b"); - auto var_c = Symbol("c"); - auto b_div_c = Div(var_b, var_c); - EXPECT_EQ(std::string(b_div_c.Serialize().get()), "(b / (c))"); - auto one = Symbol(1); - auto b_div_one = Div(one, var_c); - EXPECT_EQ(std::string(b_div_one.Serialize().get()), "Pow(c, -1)"); - - auto b_max_c = Max(b_div_c, var_c); - EXPECT_EQ(std::string(b_max_c.Serialize().get()), "Max(c, (b / (c)))"); - - -// EXPECT_EQ(Rational(2, 3)->Serialize(), "(double)(2)/(double)(3)"); - EXPECT_EQ(std::string(Pow(var_b, Rational(1, 2)).Serialize().get()), "Sqrt(b)"); - EXPECT_EQ(std::string(Pow(var_b, Symbol(3)).Serialize().get()), "(b * b * b)"); - EXPECT_EQ(std::string(Pow(var_b, var_c).Serialize().get()), "Pow(b, c)"); - EXPECT_EQ(std::string(Mod(var_b, var_c).Serialize().get()), "Mod(b, c)"); - EXPECT_EQ(std::string(Ceiling(var_b).Serialize().get()), "Ceiling(b)"); - EXPECT_EQ(std::string(Abs(var_b).Serialize().get()), "Abs(b)"); - EXPECT_EQ(std::string(Sub(Symbol(2), Symbol(3) * var_b).Serialize().get()), "(2 - (3 * b))"); - EXPECT_EQ(std::string(Max(Max(var_b, var_c), Symbol("d")).Serialize().get()), - "Max(Max(b, c), d)"); - EXPECT_EQ(std::string(Min(Min(var_b, var_c), Symbol("d")).Serialize().get()), - "Min(Min(b, c), d)"); -} - -TEST_F(UtestExpression, Parser) { - auto var_b = Symbol("b"); - auto var_c = Symbol("c"); - auto b_div_c = Div(var_b, var_c); - EXPECT_EQ(std::string(b_div_c.Serialize().get()), "(b / (c))"); - auto b_div_parser = Expression::Parse(b_div_c.Serialize().get()); - EXPECT_EQ(b_div_c, b_div_parser); - auto b_max_c = Max(b_div_c, var_c); - EXPECT_EQ(std::string(b_max_c.Serialize().get()), "Max(c, (b / (c)))"); - - auto double_2_3 = Rational(2, 3); - EXPECT_EQ(std::string(double_2_3.Serialize().get()), "Rational(2 , 3)"); - auto double_2_3_parser = Expression::Parse(double_2_3.Serialize().get()); - // todo: 对于表达式2/3,当前的序列化为了给c++编译器编译,搞成了(double)(2)/(double)(3),是一个Rational - // 本次修改为采用Rational(a, b)表达的方式,来表达分子分母,当前只支持int类型的分子分母,给c++编译时,写一个Rational函数, - // 函数里面做double的cast转换 - EXPECT_EQ(double_2_3_parser, double_2_3); - - auto pow_3 = Pow(var_b, Symbol(3)); - EXPECT_EQ(std::string(pow_3.Serialize().get()), "(b * b * b)"); - auto power_parser_3 = Expression::Parse(pow_3.Serialize().get()); - EXPECT_EQ(power_parser_3, pow_3); - - auto pow_b_c = Pow(var_b, var_c); - EXPECT_EQ(std::string(pow_b_c.Serialize().get()), "Pow(b, c)"); - auto power_parser_b_c = Expression::Parse(pow_b_c.Serialize().get()); - EXPECT_EQ(power_parser_b_c, pow_b_c); - - auto mod_b_c = Mod(var_b, var_c); - EXPECT_EQ(std::string(mod_b_c.Serialize().get()), "Mod(b, c)"); - auto mod_parser_b_c = Expression::Parse(mod_b_c.Serialize().get()); - EXPECT_EQ(mod_parser_b_c, mod_b_c); - - auto log_b_c = Log(var_b, var_c); - EXPECT_EQ(std::string(log_b_c.Serialize().get()), "(Log(b) / (Log(c)))"); - auto log_parser_b_c = Expression::Parse(log_b_c.Serialize().get()); - EXPECT_EQ(log_parser_b_c, log_b_c); - - auto sub_mul = Sub(Symbol(2), Symbol(3) * var_b); - EXPECT_EQ(std::string(sub_mul.Serialize().get()), "(2 - (3 * b))"); - auto sub_mul_parser = Expression::Parse(sub_mul.Serialize().get()); - EXPECT_EQ(sub_mul_parser, sub_mul); - - auto add_mul = Add(Symbol(2), Symbol(3) * var_b); - EXPECT_EQ(std::string(add_mul.Serialize().get()), "((3 * b) + 2)"); // 不保证顺序 - auto add_mul_parser = Expression::Parse(add_mul.Serialize().get()); - EXPECT_EQ(add_mul_parser, add_mul); - - auto max_max = Max(Max(var_b, var_c), Symbol("d")); - EXPECT_EQ(std::string(max_max.Serialize().get()), "Max(Max(b, c), d)"); - auto max_max_parser = Expression::Parse(max_max.Serialize().get()); - EXPECT_EQ(max_max_parser, max_max); - - auto min_min = Min(Min(var_b, var_c), Symbol("d")); - EXPECT_EQ(std::string(min_min.Serialize().get()), "Min(Min(b, c), d)"); - auto min_min_parser = Expression::Parse(min_min.Serialize().get()); - EXPECT_EQ(min_min_parser, min_min); - - // 这个地方主线的序列化不保序,按照主线处理 - auto min_max = Min(Max(var_b, var_c), Symbol("d")); - EXPECT_EQ(std::string(min_max.Serialize().get()), "Min(d, Max(b, c))"); - auto min_max_parser = Expression::Parse(min_max.Serialize().get()); - EXPECT_EQ(min_max_parser, min_max); - - auto ceil = Ceiling(var_b); - EXPECT_EQ(std::string(ceil.Serialize().get()), "Ceiling(b)"); - auto ceil_parser = Expression::Parse(ceil.Serialize().get()); - EXPECT_EQ(ceil_parser, ceil); - - auto abs = Abs(var_b); - EXPECT_EQ(std::string(abs.Serialize().get()), "Abs(b)"); - auto abs_parser = Expression::Parse(abs.Serialize().get()); - EXPECT_EQ(abs_parser, abs); - - auto min_5_double = Min(Max(var_b, var_c), Symbol(5.0)); - EXPECT_EQ(std::string(min_5_double.Serialize().get()), "Min(Max(b, c), 5.0)"); - auto min_5_double_parser = Expression::Parse(min_5_double.Serialize().get()); - EXPECT_EQ(min_5_double_parser, min_5_double); - - EXPECT_EQ(std::string(var_b.GetName().get()), "b"); -} - -TEST_F(UtestExpression, Parser_Invalid) { - auto failed_parser = Expression::Parse("1 % dsfde )"); - EXPECT_EQ(failed_parser.IsValid(), false); - - failed_parser = Expression::Parse("5* (sincos(s0s1))"); - EXPECT_EQ(failed_parser.IsValid(), false); -} - -TEST_F(UtestExpression, Serialize_And_Deserialize) { - auto var_b = Symbol("b"); - auto var_c = Symbol("c"); - auto b_div_c = Div(var_b, var_c); - EXPECT_EQ(std::string(b_div_c.Serialize().get()), "(b / (c))"); - auto b_div_parser = Expression::Deserialize(b_div_c.Serialize().get()); - EXPECT_EQ(b_div_c, b_div_parser); - auto b_max_c = Max(b_div_c, var_c); - EXPECT_EQ(std::string(b_max_c.Serialize().get()), "Max(c, (b / (c)))"); - - auto double_2_3 = Rational(2, 3); - EXPECT_EQ(std::string(double_2_3.Serialize().get()), "Rational(2 , 3)"); - auto double_2_3_parser = Expression::Deserialize(double_2_3.Serialize().get()); - // todo: 对于表达式2/3,当前的序列化为了给c++编译器编译,搞成了(double)(2)/(double)(3),是一个Rational - // 本次修改为采用Rational(a, b)表达的方式,来表达分子分母,当前只支持int类型的分子分母,给c++编译时,写一个Rational函数, - // 函数里面做double的cast转换 - EXPECT_EQ(double_2_3_parser, double_2_3); - - auto pow_3 = Pow(var_b, Symbol(3)); - EXPECT_EQ(std::string(pow_3.Serialize().get()), "(b * b * b)"); - auto power_parser_3 = Expression::Deserialize(pow_3.Serialize().get()); - EXPECT_EQ(power_parser_3, pow_3); - - auto pow_b_c = Pow(var_b, var_c); - EXPECT_EQ(std::string(pow_b_c.Serialize().get()), "Pow(b, c)"); - auto power_parser_b_c = Expression::Deserialize(pow_b_c.Serialize().get()); - EXPECT_EQ(power_parser_b_c, pow_b_c); - - auto mod_b_c = Mod(var_b, var_c); - EXPECT_EQ(std::string(mod_b_c.Serialize().get()), "Mod(b, c)"); - auto mod_parser_b_c = Expression::Deserialize(mod_b_c.Serialize().get()); - EXPECT_EQ(mod_parser_b_c, mod_b_c); - - auto log_b_c = Log(var_b, var_c); - EXPECT_EQ(std::string(log_b_c.Serialize().get()), "(Log(b) / (Log(c)))"); - auto log_parser_b_c = Expression::Deserialize(log_b_c.Serialize().get()); - EXPECT_EQ(log_parser_b_c, log_b_c); - - auto sub_mul = Sub(Symbol(2), Symbol(3) * var_b); - EXPECT_EQ(std::string(sub_mul.Serialize().get()), "(2 - (3 * b))"); - auto sub_mul_parser = Expression::Deserialize(sub_mul.Serialize().get()); - EXPECT_EQ(sub_mul_parser, sub_mul); - - auto add_mul = Add(Symbol(2), Symbol(3) * var_b); - EXPECT_EQ(std::string(add_mul.Serialize().get()), "((3 * b) + 2)"); // 不保证顺序 - auto add_mul_parser = Expression::Deserialize(add_mul.Serialize().get()); - EXPECT_EQ(add_mul_parser, add_mul); - - auto max_max = Max(Max(var_b, var_c), Symbol("d")); - EXPECT_EQ(std::string(max_max.Serialize().get()), "Max(Max(b, c), d)"); - auto max_max_parser = Expression::Deserialize(max_max.Serialize().get()); - EXPECT_EQ(max_max_parser, max_max); - - auto min_min = Min(Min(var_b, var_c), Symbol("d")); - EXPECT_EQ(std::string(min_min.Serialize().get()), "Min(Min(b, c), d)"); - auto min_min_parser = Expression::Deserialize(min_min.Serialize().get()); - EXPECT_EQ(min_min_parser, min_min); - - // 这个地方主线的序列化不保序,按照主线处理 - auto min_max = Min(Max(var_b, var_c), Symbol("d")); - EXPECT_EQ(std::string(min_max.Serialize().get()), "Min(d, Max(b, c))"); - auto min_max_parser = Expression::Deserialize(min_max.Serialize().get()); - EXPECT_EQ(min_max_parser, min_max); - - auto ceil = Ceiling(var_b); - EXPECT_EQ(std::string(ceil.Serialize().get()), "Ceiling(b)"); - auto ceil_parser = Expression::Deserialize(ceil.Serialize().get()); - EXPECT_EQ(ceil_parser, ceil); - - auto abs = Abs(var_b); - EXPECT_EQ(std::string(abs.Serialize().get()), "Abs(b)"); - auto abs_parser = Expression::Deserialize(abs.Serialize().get()); - EXPECT_EQ(abs_parser, abs); - - auto min_5_double = Min(Max(var_b, var_c), Symbol(5.0)); - EXPECT_EQ(std::string(min_5_double.Serialize().get()), "Min(Max(b, c), 5.0)"); - auto min_5_double_parser = Expression::Deserialize(min_5_double.Serialize().get()); - EXPECT_EQ(min_5_double_parser, min_5_double); - - EXPECT_EQ(std::string(var_b.GetName().get()), "b"); -} - -// 如果不是按照序列化出来的字符串去进行反序列化,反序列化会失败 -TEST_F(UtestExpression, Deserialize_Invalid) { - auto failed_parser = Expression::Deserialize("s0*s1"); - EXPECT_EQ(failed_parser.IsValid(), false); - - failed_parser = Expression::Deserialize("a+2"); - EXPECT_EQ(failed_parser.IsValid(), false); -} - -TEST_F(UtestExpression, EqualAndNotEqual) { - auto var_b = Symbol("b"); - auto int_2 = Symbol(2); - auto int_3 = Symbol(3); - auto int_n_6 = Symbol(-6); - auto int_6 = Symbol(6); - - auto b_2 = Mul(var_b, int_2); - auto b_b = Add(var_b, var_b); - EXPECT_TRUE(b_2 == b_b); - - auto b_3 = Mul(var_b, int_3); - EXPECT_TRUE(b_3 != b_b); - - auto abs_1 = Abs(int_n_6); - EXPECT_TRUE(abs_1 == int_6); -} - -TEST_F(UtestExpression, SymbolCheckWithoutContext) { - auto var_a = Symbol("a"); - auto var_b = Symbol("b"); - auto ret = EXPECT_SYMBOL_EQ(var_b, var_a); - EXPECT_EQ(ret, false); - bool guard_res0 = [&var_a, &var_b] () -> bool { - ASSERT_SYMBOL_EQ(var_a, var_b); - return true; - }(); - EXPECT_EQ(guard_res0, false); - EXPECT_EQ(SymbolicUtils::StaticCheckEq(var_a, var_b), TriBool::kUnknown); -} - -TEST_F(UtestExpression, GetName) { - auto var_b = Symbol("b"); - EXPECT_EQ(std::string(var_b.GetName().get()), "b"); - auto var_c = Symbol(static_cast(1), "s0"); - EXPECT_EQ(std::string(var_c.GetName().get()), "s0"); - auto var_d = Symbol(static_cast(1), "s1"); - EXPECT_EQ(std::string(var_d.GetName().get()), "s1"); - auto var_e = Symbol(static_cast(1), "s2"); - EXPECT_EQ(std::string(var_e.GetName().get()), "s2"); - auto var_f = Symbol(static_cast(1), "s3"); - EXPECT_EQ(std::string(var_f.GetName().get()), "s3"); - auto var_g = Symbol(static_cast(1.0), "s4"); - EXPECT_EQ(std::string(var_g.GetName().get()), "s4"); - auto var_h = Symbol(static_cast(1.0)); - EXPECT_EQ(std::string(var_h.GetName().get()), "Const_0"); - auto var_i = Symbol(static_cast(5)); - EXPECT_EQ(std::string(var_i.GetName().get()), "Const_1"); -} - -TEST_F(UtestExpression, Operator) { - auto var_b = Symbol("b"); - auto var_c = Symbol("c"); - EXPECT_EQ((var_b + var_c).GetExprType(), ExprType::kExprOperation); - EXPECT_EQ((var_b - var_c).GetExprType(), ExprType::kExprOperation); - EXPECT_EQ((var_b * var_c).GetExprType(), ExprType::kExprOperation); - EXPECT_EQ((var_b / var_c).GetExprType(), ExprType::kExprOperation); -} - -TEST_F(UtestExpression, Replace) { - auto var_b = Symbol("b"); - auto var_c = Symbol("c"); - auto var_d = Symbol("d"); - auto b_div_c = Div(var_b, var_c); - std::vector> replace_vars; - replace_vars.push_back({var_b, var_d}); - auto replace_expr = b_div_c.Replace(replace_vars); - EXPECT_TRUE(replace_expr == Div(var_d, var_c)); -} - -TEST_F(UtestExpression, Subs) { - auto var_b = Symbol("b"); - auto var_c = Symbol("c"); - auto var_d = Symbol("d"); - auto b_div_c = Div(var_b, var_c); - std::vector> subs_vars; - subs_vars.push_back({var_b, var_d}); - auto subs_expr = b_div_c.Subs(subs_vars); - EXPECT_TRUE(subs_expr == Div(var_d, var_c)); -} - -TEST_F(UtestExpression, Simplify) { - auto var_b = Symbol("b"); - auto const_1 = Symbol(1); - auto const_2 = Symbol(2); - auto const_3 = Symbol(3); - EXPECT_TRUE((Add(Add(var_b, const_1), const_2).Simplify()) == Add(var_b, const_3)); -} - -TEST_F(UtestExpression, GetPrimaryArgs) { - auto const_neg_2 = Symbol(-2); - auto const_neg_3 = Symbol(-3); - auto var_b = Symbol("b"); - auto var_c = Symbol("c"); - auto var_d = Symbol("d"); - auto var_e = Symbol("e"); - std::vector args_exp = {var_b, var_c, var_d, var_e}; - auto mul_expr = Min(Max(Add(Pow(Mul(const_neg_2, var_b), Mul(var_b, const_neg_3)), var_c), var_d), var_e); - auto prim_args = mul_expr.FreeSymbols(); - EXPECT_EQ(prim_args.size(), args_exp.size()); - bool has_find = true; - for (auto &arg_get : prim_args) { - bool one_has_find = false; - for (auto &arg_exp : args_exp) { - if (arg_get == arg_exp) { - one_has_find = true; - break; - } - } - if (!one_has_find) { - has_find = false; - break; - } - } - EXPECT_EQ(has_find, true); -} - -TEST_F(UtestExpression, ContainVar) { - auto var_b = Symbol("b"); - auto const_1 = Symbol(1); - EXPECT_TRUE(Add(var_b, const_1).ContainVar(var_b)); - EXPECT_FALSE(Add(var_b, const_1).ContainVar(const_1)); -} - -TEST_F(UtestExpression, GetResult) { - auto var_b = Symbol("b"); - auto var_c = Symbol("c"); - auto b_add_c = Add(var_b, var_c); - std::vector> replace_vars; - replace_vars.emplace_back(var_b, Symbol(1)); - replace_vars.emplace_back(var_c, Symbol(2)); - double result; - auto code = b_add_c.GetResult(replace_vars, result); - EXPECT_EQ(code, ge::GRAPH_SUCCESS); - EXPECT_EQ(result, static_cast(3)); - - replace_vars.clear(); - replace_vars.emplace_back(var_b, Symbol(1.0)); - replace_vars.emplace_back(var_c, Symbol(2.0)); - code = b_add_c.GetResult(replace_vars, result); - EXPECT_EQ(code, ge::GRAPH_SUCCESS); - EXPECT_EQ(result, static_cast(3)); - - replace_vars.clear(); - replace_vars.emplace_back(var_b, Symbol(1)); - replace_vars.emplace_back(var_c, Rational(2, 3)); - code = b_add_c.GetResult(replace_vars, result); - EXPECT_EQ(code, ge::GRAPH_SUCCESS); - EXPECT_TRUE(std::abs(result - (static_cast(1) + static_cast(2) / static_cast(3))) < 0.0001); -} - -TEST_F(UtestExpression, GetBoolConstValueEq_Succ) { - auto expr = Eq(Symbol(2), Symbol(2)); - bool value; - EXPECT_EQ(expr.GetConstValue(value), true); - EXPECT_EQ(value, true); -} - -TEST_F(UtestExpression, GetBoolConstValueNot_Succ) { - auto expr = Not(Eq(Symbol(2), Symbol(2))); - bool value; - EXPECT_EQ(expr.GetConstValue(value), true); - EXPECT_EQ(value, false); -} - -TEST_F(UtestExpression, GetBoolConstValueFromVariable_Failed) { - auto s0 = Symbol("s0"); - auto s1 = Symbol("s1"); - auto expr = Not(Eq(s0, s1)); - bool value; - EXPECT_EQ(expr.GetConstValue(value), false); -} - -TEST_F(UtestExpression, GetBoolConstValueFromNonBoolExpr_Failed) { - auto s0 = Symbol(3); - bool value; - EXPECT_EQ(s0.GetConstValue(value), false); -} - -TEST_F(UtestExpression, GetDoubleConstValueFromBoolExpr_Failed) { - auto expr = Not(Eq(Symbol(2), Symbol(2))); - double value; - EXPECT_EQ(expr.GetConstValue(value), false); -} - -TEST_F(UtestExpression, GetDoubleConstValueFromIntExpr_Failed) { - auto s0 = Symbol(2); - double value; - EXPECT_EQ(s0.GetConstValue(value), false); -} - -TEST_F(UtestExpression, GetDoubleConstValueFromRelationExpr_Succ) { - auto s0 = Rational(2, 4); - double value; - EXPECT_EQ(s0.GetConstValue(value), true); - EXPECT_EQ(value, 0.5); -} - -TEST_F(UtestExpression, NotExprWithInParamNotBool_Failed) { - auto s0 = Symbol(2); - auto not_expr = Not(s0); - EXPECT_EQ(not_expr.IsValid(), false); -} - -TEST_F(UtestExpression, GetBoolConstValueWithWrongType_Failed) { - auto expr = Not(Eq(Symbol(2), Symbol(2))); - int64_t value; - EXPECT_EQ(expr.GetConstValue(value), false); -} - -TEST_F(UtestExpression, GetBoolConstValueWithAdd_Succ) { - auto expr = Eq(Symbol(2) + Symbol(2), Symbol(2) * Symbol(2)); - bool value = false; - EXPECT_EQ(expr.GetConstValue(value), true); - EXPECT_EQ(value, true); -} - -TEST_F(UtestExpression, CheckIsVariable_Succ) { - auto expr = Symbol("s0"); - EXPECT_EQ(expr.IsVariableExpr(), true); -} - -TEST_F(UtestExpression, CheckIsBoolean_Succ) { - auto expr = Eq(Symbol(2) + Symbol(2), Symbol(2) * Symbol(2)); - EXPECT_EQ(expr.IsBooleanExpr(), true); - auto expr1 = Eq(Symbol(2), Symbol("s0")); - EXPECT_EQ(expr1.IsBooleanExpr(), true); -} - -TEST_F(UtestExpression, GetConstValue) { - auto var_b = Symbol("b"); - int32_t value; - EXPECT_EQ(var_b.GetConstValue(value), false); - - int32_t value1 = 2; - auto const_1 = Symbol(value1); - int32_t res_value1; - EXPECT_EQ(const_1.GetConstValue(res_value1), true); - EXPECT_EQ(res_value1, value1); - - uint32_t value2 = 1; - auto const_2 = Symbol(value2); - uint32_t res_value2; - EXPECT_EQ(const_2.GetConstValue(res_value2), true); - EXPECT_EQ(res_value2, value2); - - double value3 = 1.0; - auto const_3 = Symbol(value3); - double res_value3; - EXPECT_EQ(const_3.GetConstValue(res_value3), true); - EXPECT_EQ(res_value3, value3); - - float value4 = 1.0; - auto const_4 = Symbol(value4); - float res_value4; - EXPECT_EQ(const_4.GetConstValue(res_value4), true); - EXPECT_EQ(res_value4, value4); - - int64_t value5 = 1; - auto const_5 = Symbol(value5); - int64_t res_value5; - EXPECT_EQ(const_5.GetConstValue(res_value5), true); - EXPECT_EQ(res_value5, value5); - - uint64_t value6 = 1; - auto const_6 = Symbol(value6); - uint64_t res_value6; - EXPECT_EQ(const_6.GetConstValue(res_value6), true); - EXPECT_EQ(res_value6, value6); - - // 常量 + 常量 - auto add = sym::Add(const_6, const_6); - uint64_t res_value7; - EXPECT_EQ(add.GetConstValue(res_value7), true); - EXPECT_EQ(res_value7, value6 + value6); - - // 常量 * 常量 - auto mul = sym::Mul(const_6, const_6); - uint64_t res_value8; - EXPECT_EQ(mul.GetConstValue(res_value8), true); - EXPECT_EQ(res_value8, value6 * value6); - - // 常量 - 常量 - auto sub = sym::Sub(const_6, const_6); - uint64_t res_value9; - EXPECT_EQ(sub.GetConstValue(res_value9), true); - EXPECT_EQ(res_value9, value6 - value6); - - // 常量 / 常量 - auto div = sym::Div(const_6, const_6); - uint64_t res_value10; - EXPECT_EQ(div.GetConstValue(res_value10), true); - EXPECT_EQ(res_value10, value6 / value6); - - // Max(常量, 常量) - auto max1 = sym::Max(const_1, const_6); - uint64_t res_value11; - EXPECT_EQ(max1.GetConstValue(res_value11), true); - EXPECT_EQ(res_value11, std::max(static_cast(value1), value6)); - - // 常量 + 变量 - auto add20 = sym::Add(const_6, var_b); - uint64_t res_value20; - EXPECT_EQ(add20.GetConstValue(res_value20), false); - - // 常量绝对值 -> 常量 - auto abs1 = sym::Abs(const_6); - uint64_t res_value21; - EXPECT_EQ(abs1.GetConstValue(res_value21), true); - - // 变量绝对值 -> 变量 - auto abs2 = sym::Abs(var_b); - uint64_t res_value22; - EXPECT_EQ(abs2.GetConstValue(res_value22), false); -} - -TEST_F(UtestExpression, TestAlgin) { - auto s0 = Symbol("s0"); - auto s1 = Symbol("s1"); - - auto expr1 = Align(s0, 32); - auto expr2 = Align(s0, 32); - EXPECT_EQ(expr1, expr2); - - auto str = expr1.Serialize(); - EXPECT_EQ(std::string(str.get()), "(32 * Ceiling((Rational(1 , 32) * s0)))"); - auto expr3 = Expression::Parse(str.get()); - EXPECT_EQ(expr1, expr3); - - auto const_16 = Symbol(32); - auto const_1 = Symbol(1); - auto expr4 = (const_16 * Ceiling((Rational(1 , 32) * s0))); - EXPECT_EQ(expr1, expr4); -} - -TEST_F(UtestExpression, AlignWithPositiveIntegerConst) { - int res = 0; - EXPECT_EQ(sym::AlignWithPositiveInteger(Symbol(10), 8).GetConstValue(res), true); - EXPECT_EQ(res, 16); - EXPECT_EQ(sym::AlignWithPositiveInteger(Symbol(15), 4).GetConstValue(res), true); - EXPECT_EQ(res, 16); - EXPECT_EQ(sym::AlignWithPositiveInteger(Symbol(7), 2).GetConstValue(res), true); - EXPECT_EQ(res, 8); - EXPECT_EQ(sym::AlignWithPositiveInteger(Symbol(8), 2).GetConstValue(res), true); - EXPECT_EQ(res, 8); -} - -TEST_F(UtestExpression, AlignWithPositiveInteger) { - auto s0 = Symbol("s0"); - - auto expr1 = sym::AlignWithPositiveInteger(s0, 8); - auto expr2 = sym::AlignWithPositiveInteger(s0, 8); - EXPECT_TRUE(expr1 == expr2); - - auto str0 = expr1.Serialize(); - EXPECT_EQ(std::string(str0.get()), "(8 * Floor(((7 + s0) * Rational(1 , 8))))"); - auto expr3 = Expression::Parse(str0.get()); - EXPECT_EQ(expr1, expr3); - - auto str1 = expr1.Str(StrType::kStrExpr); - EXPECT_EQ(std::string(str1.get()), "(8 * Floor(((7 + s0) * 1/8)))"); - auto expr4 = Expression::Parse(str1.get()); - EXPECT_EQ(expr1, expr4); -} - -TEST_F(UtestExpression, StrTypeTest) { - auto expr1 = sym::Div(Symbol("s0"), Symbol("s1")); - auto expr2 = sym::Div(Symbol("s0"), Symbol(8)); - auto expr3 = sym::Div(Symbol("s0"), Symbol(8)); - auto expr4 = sym::Div(Symbol("s0"), Symbol(8)); - EXPECT_EQ(std::string("(s0 / (s1))"), expr1.Str(StrType::kStrCpp).get()); - EXPECT_EQ(std::string("(Rational(1 , 8) * s0)"), expr2.Str(StrType::kStrCpp).get()); - EXPECT_EQ(std::string("(1/8 * s0)"), expr3.Str(StrType::kStrEnd).get()); - EXPECT_EQ(std::string("(1/8 * s0)"), expr4.Str(StrType::kStrExpr).get()); -} - -TEST_F(UtestExpression, TestAlignWithPositiveInteger_Succ) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting(false, DynamicMode::kDynamic)); - SetCurShapeEnvContext(&shape_env); - Symbol sym0 = shape_env.CreateSymbol(12, MakeShared(0, 0)); - auto expr = sym::AlignWithPositiveInteger(sym0, 8); - int64_t value_int = 0; - EXPECT_EQ(expr.GetHint(value_int), true); - EXPECT_EQ(value_int, 16); - SetCurShapeEnvContext(nullptr); -} - -TEST_F(UtestExpression, SymbolCheck_Old_Api_Succ) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting()); - SetCurShapeEnvContext(&shape_env); - Symbol sym0 = shape_env.CreateSymbol(1, MakeShared(0, 0)); - EXPECT_EQ(sym0.IsVariableExpr(), true); - Symbol sym1 = shape_env.CreateSymbol(2, MakeShared(0, 1)); - EXPECT_EQ(sym1.IsVariableExpr(), true); - const std::string file_name = "test.cc"; - bool guard_res0 = ExpectSymbolEq(sym0 + Symbol(1), sym1, file_name.c_str(), 2); - EXPECT_EQ(guard_res0, true); - bool guard_res1 = AssertSymbolEq(sym0 + Symbol(1), sym1, file_name.c_str(), 2); - EXPECT_EQ(guard_res1, true); - bool guard_res2 = ExpectSymbolBool(sym::Lt(sym0, sym1), file_name.c_str(), 2); - EXPECT_EQ(guard_res2, true); - bool guard_res3 = AssertSymbolBool(sym::Lt(sym0, sym1), file_name.c_str(), 2); - EXPECT_EQ(guard_res3, true); -} - -TEST_F(UtestExpression, SymbolCheck_With_Simplify_Guard_Succ) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting()); - SetCurShapeEnvContext(&shape_env); - Symbol s0 = shape_env.CreateSymbol(2, MakeShared(0, 0)); - Symbol s1 = shape_env.CreateSymbol(2, MakeShared(0, 1)); - Symbol s2 = shape_env.CreateSymbol(2, MakeShared(0, 2)); - EXPECT_EQ(EXPECT_SYMBOL_EQ(s0, s1), true); - EXPECT_EQ(EXPECT_SYMBOL_EQ(s0, s2), true); - EXPECT_EQ(EXPECT_SYMBOL_EQ(Symbol(2) * s1 + Symbol(64), Symbol(68)), true); - EXPECT_EQ(EXPECT_SYMBOL_LT(s1, Symbol(100)), true); - EXPECT_EQ(shape_env.GetAllSymbolCheckInfos().size(), 3); - auto expr = Symbol(2) * s0 + Symbol(64); - EXPECT_EQ(SymbolicUtils::StaticCheckEq(expr.Simplify(), Symbol(68)), TriBool::kTrue); - EXPECT_EQ(shape_env.GetAllSymbolCheckInfos().size(), 3); - // 第二次不会额外追加guard - EXPECT_EQ(SymbolicUtils::StaticCheckEq(expr.Simplify(), Symbol(68)), TriBool::kTrue); - EXPECT_EQ(shape_env.GetAllSymbolCheckInfos().size(), 3); - - EXPECT_EQ(SymbolicUtils::StaticCheckLt(s2, Symbol(100)), TriBool::kTrue); - EXPECT_EQ(shape_env.GetAllSymbolCheckInfos().size(), 3); -} - -TEST_F(UtestExpression, SymbolCheck_With_Simplify_Guard_Succ2) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting()); - SetCurShapeEnvContext(&shape_env); - Symbol s0 = shape_env.CreateSymbol(2, MakeShared(0, 0)); - Symbol s1 = shape_env.CreateSymbol(2, MakeShared(0, 1)); - Symbol s2 = shape_env.CreateSymbol(2, MakeShared(0, 2)); - EXPECT_EQ(EXPECT_SYMBOL_EQ(s0, s1), true); - EXPECT_EQ(EXPECT_SYMBOL_EQ(s0, s2), true); - EXPECT_EQ(EXPECT_SYMBOL_EQ(Symbol(2) * s1 + Symbol(64), Symbol(68)), true); - EXPECT_EQ(shape_env.GetAllSymbolCheckInfos().size(), 3); - // 化简后,应该是const的表达式,不会新增guard - EXPECT_EQ(SymbolicUtils::StaticCheckEq(s1, s2), TriBool::kTrue); - EXPECT_EQ(shape_env.GetAllSymbolCheckInfos().size(), 3); -} - -TEST_F(UtestExpression, SymbolCheck_Succ) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting()); - SetCurShapeEnvContext(&shape_env); - Symbol sym0 = shape_env.CreateSymbol(1, MakeShared(0, 0)); - EXPECT_EQ(sym0.IsVariableExpr(), true); - Symbol sym1 = shape_env.CreateSymbol(2, MakeShared(0, 1)); - EXPECT_EQ(sym1.IsVariableExpr(), true); - Symbol sym2 = shape_env.CreateSymbol(3, MakeShared(0, 2)); - EXPECT_EQ(sym2.IsVariableExpr(), true); - Symbol sym3 = shape_env.CreateSymbol(4, MakeShared(0, 3)); - EXPECT_EQ(sym3.IsVariableExpr(), true); - Symbol sym4 = shape_env.CreateSymbol(5, MakeShared(0, 4)); - EXPECT_EQ(sym4.IsVariableExpr(), true); - Symbol sym5 = shape_env.CreateSymbol(6, MakeShared(0, 5)); - EXPECT_EQ(sym5.IsVariableExpr(), true); - Symbol sym6 = shape_env.CreateSymbol(7, MakeShared(0, 6)); - EXPECT_EQ(sym6.IsVariableExpr(), true); - - bool guard_res0 = EXPECT_SYMBOL_EQ(sym0 + sym1, sym2); - EXPECT_EQ(guard_res0, true); - bool guard_res1 = EXPECT_SYMBOL_NE(sym0 * sym1, sym2); - EXPECT_EQ(guard_res1, true); - bool guard_res2 = EXPECT_SYMBOL_LT(sym3 / sym1, sym0); - EXPECT_EQ(guard_res2, false); - bool guard_res3 = EXPECT_SYMBOL_LE(sym4 - sym3, sym0); - EXPECT_EQ(guard_res3, true); - bool guard_res4 = EXPECT_SYMBOL_GT(sym::Pow(sym1, sym2), sym::Max(sym5, sym4)); - EXPECT_EQ(guard_res4, true); - bool guard_res5 = EXPECT_SYMBOL_GE(sym::Min(sym1, sym2), sym::Abs(sym6)); - EXPECT_EQ(guard_res5, false); - bool guard_res6 = EXPECT_SYMBOL_EQ(Symbol(2), Symbol(3)); - EXPECT_EQ(guard_res6, false); - bool guard_res7 = EXPECT_SYMBOL_GT(sym::Mod(sym6, sym1), sym::Max(sym5, sym4)); - EXPECT_EQ(guard_res7, false); - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Eq(sym0 + sym1, sym2)), true); - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Ne(sym2, sym0 * sym1)), false); - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Lt(sym3 / sym1, sym0)), false); - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Not(sym::Lt(sym3 / sym1, sym0))), true); - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Le(sym4 - sym3, sym0)), true); - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Lt(sym::Max(sym5, sym4), sym::Pow(sym1, sym2))), false); - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Ge(sym::Min(sym1, sym2), sym::Abs(sym6))), false); - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Not(sym::Ge(sym::Min(sym1, sym2), sym::Abs(sym6)))), false); - - EXPECT_EQ(SymbolicUtils::StaticCheckEq(sym2, sym0 + sym1), TriBool::kTrue); - EXPECT_EQ(SymbolicUtils::StaticCheckNe(sym2, sym1 * sym0), TriBool::kTrue); - auto expr1 = sym3 / sym1; - EXPECT_EQ(SymbolicUtils::StaticCheckLt(expr1, sym0), TriBool::kUnknown); - EXPECT_EQ(SymbolicUtils::StaticCheckLe(sym0, sym3 / sym1), TriBool::kTrue); - EXPECT_EQ(SymbolicUtils::StaticCheckLe((sym4 - sym3), sym0), TriBool::kTrue); - - EXPECT_EQ(SymbolicUtils::StaticCheckGt(sym::Pow(sym1, sym2), sym::Max(sym5, sym4)), TriBool::kTrue); - EXPECT_EQ(SymbolicUtils::StaticCheckGe(sym::Min(sym1, sym2), sym::Abs(sym6)), TriBool::kUnknown); - EXPECT_EQ(SymbolicUtils::StaticCheckGt(sym::Abs(sym6), sym::Min(sym1, sym2)), TriBool::kTrue); - SetCurShapeEnvContext(nullptr); -} - -Status InferAddSymbolShapeStub(gert::InferSymbolShapeContext *context) { - auto input_shape0 = context->GetInputSymbolShape(0); - GE_ASSERT_NOTNULL(input_shape0); - auto input_shape1 = context->GetInputSymbolShape(1); - GE_ASSERT_NOTNULL(input_shape1); - auto output_shape = context->GetOutputSymbolShape(0); - GE_ASSERT_NOTNULL(output_shape); - for (size_t i = 0UL; i < input_shape0->GetDimNum(); i++) { - auto s0 = input_shape0->GetDim(i); - auto s1 = input_shape1->GetDim(i); - if (EXPECT_SYMBOL_EQ(s0, s1)) { - output_shape->AppendDim(s0); - } else if (EXPECT_SYMBOL_EQ(s0, Symbol(1))) { - output_shape->AppendDim(s1); - } else if (EXPECT_SYMBOL_EQ(s1, Symbol(1))) { - output_shape->AppendDim(s0); - } else { - return GRAPH_FAILED; - } - } - return GRAPH_SUCCESS; -} - -TEST_F(UtestExpression, SymbolCheckBroadCast_Succ) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting()); - SetCurShapeEnvContext(&shape_env); - Symbol sym0 = shape_env.CreateSymbol(1, MakeShared(0, 0)); - Symbol sym1 = shape_env.CreateSymbol(2, MakeShared(0, 1)); - Symbol sym2 = shape_env.CreateSymbol(3, MakeShared(0, 2)); - Symbol sym3 = shape_env.CreateSymbol(4, MakeShared(0, 3)); - gert::SymbolShape in_shape0({(sym0 + sym2), sym2, sym0, sym1}); - gert::SymbolShape in_shape1({Symbol(4), sym0, sym3, (sym3 / sym1)}); - gert::SymbolShape out_shape({}); - - auto context_holder = gert::InferSymbolShapeContextFaker() - .IrInputNum(2) - .NodeIoNum(2, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z) - .Inputs({&in_shape0, &in_shape1}) - .Outputs({&out_shape}) - .Build(); - auto context = context_holder.GetContext(); - EXPECT_EQ(InferAddSymbolShapeStub(context), - GRAPH_SUCCESS); - auto output_shape = context->GetOutputSymbolShape(0); - EXPECT_NE(output_shape, nullptr); - EXPECT_EQ(output_shape->GetDimNum(), 4); - EXPECT_EQ(output_shape->GetDim(0), (sym0 + sym2)); - EXPECT_EQ(output_shape->GetDim(1), sym2); - EXPECT_EQ(output_shape->GetDim(2), sym3); - EXPECT_EQ(output_shape->GetDim(3), sym1); - - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Eq(sym0 + sym2, Symbol(4))), true); - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Ne(sym2, sym0)), true); - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Ne(sym2, Symbol(1))), true); - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Eq(sym0, Symbol(1))), true); - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Ne(sym3, sym0)), false); - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Eq(sym1, sym3 / sym1)), true); - SetCurShapeEnvContext(nullptr); -} - -TEST_F(UtestExpression, SymbolAssertCheck_Succ) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting()); - SetCurShapeEnvContext(&shape_env); - Symbol sym0 = shape_env.CreateSymbol(1, MakeShared(0, 0)); - Symbol sym1 = shape_env.CreateSymbol(2, MakeShared(0, 1)); - Symbol sym2 = shape_env.CreateSymbol(3, MakeShared(0, 2)); - Symbol sym3 = shape_env.CreateSymbol(4, MakeShared(0, 3)); - Symbol sym4 = shape_env.CreateSymbol(5, MakeShared(0, 4)); - Symbol sym5 = shape_env.CreateSymbol(6, MakeShared(0, 5)); - Symbol sym6 = shape_env.CreateSymbol(7, MakeShared(0, 6)); - - bool guard_res0 = [&sym1] () -> bool { - ASSERT_SYMBOL_EQ(sym::Rational(4, 2), sym1); - return true; - }(); - EXPECT_EQ(guard_res0, true); - - bool guard_res0_1 = [&sym1, &sym2] () -> bool { - ASSERT_SYMBOL_EQ(sym2, sym1); - return true; - }(); - EXPECT_EQ(guard_res0_1, false); - - bool guard_res1 = [&sym0, &sym2] () -> bool { - ASSERT_SYMBOL_NE(sym::Ceiling(sym0), sym2); - return true; - }(); - EXPECT_EQ(guard_res1, true); - - bool guard_res2 = [&sym1, &sym4] () -> bool { - ASSERT_SYMBOL_LT(sym::Log(sym1, sym1), sym4); - return true; - }(); - EXPECT_EQ(guard_res2, true); - - bool guard_res3 = [&sym4, &sym3, &sym0] () -> bool { - ASSERT_SYMBOL_LE(sym4 - sym3, sym0); - return true; - }(); - EXPECT_EQ(guard_res3, true); - - bool guard_res4 = [&sym4, &sym5, &sym1, &sym2] () -> bool { - ASSERT_SYMBOL_GT(sym::Pow(sym1, sym2), sym::Max(sym5, sym4)); - return true; - }(); - EXPECT_EQ(guard_res4, true); - - bool guard_res5 = [&sym6, &sym1, &sym2] () -> bool { - ASSERT_SYMBOL_GE(sym::Min(sym1, sym2), sym::Abs(sym6)); - return true; - }(); - EXPECT_EQ(guard_res5, false); - - bool guard_res6 = [] () -> bool { - ASSERT_SYMBOL_GE(Symbol(5), Symbol(2)); - return true; - }(); - EXPECT_EQ(guard_res6, true); - - EXPECT_EQ(shape_env.HasSymbolAssertInfo(sym::Eq(sym::Rational(4, 2), sym1)), true); - EXPECT_EQ(shape_env.HasSymbolAssertInfo(sym::Ne(sym::Ceiling(sym0), sym2)), true); - EXPECT_EQ(shape_env.HasSymbolAssertInfo(sym::Gt(sym4, sym::Log(sym1, sym1))), true); - EXPECT_EQ(shape_env.HasSymbolAssertInfo(sym::Le(sym4 - sym3, sym0)), true); - EXPECT_EQ(shape_env.HasSymbolAssertInfo(sym::Gt(sym::Pow(sym1, sym2), sym::Max(sym5, sym4))), false); - EXPECT_EQ(shape_env.HasSymbolAssertInfo(sym::Ge(sym::Min(sym1, sym2), sym::Abs(sym6))), false); - - EXPECT_EQ(SymbolicUtils::StaticCheckEq(sym::Rational(4, 2), sym1), TriBool::kTrue); - EXPECT_EQ(SymbolicUtils::StaticCheckNe(sym::Ceiling(sym0), sym2), TriBool::kTrue); - EXPECT_EQ(SymbolicUtils::StaticCheckGt(sym4, sym::Log(sym1, sym1)), TriBool::kTrue); - EXPECT_EQ(SymbolicUtils::StaticCheckLe((sym4 - sym3), sym0), TriBool::kTrue); - EXPECT_EQ(SymbolicUtils::StaticCheckGt(sym::Pow(sym1, sym2), sym::Max(sym5, sym4)), TriBool::kTrue); - EXPECT_EQ(SymbolicUtils::StaticCheckGe(sym::Min(sym1, sym2), sym::Abs(sym6)), TriBool::kUnknown); - SetCurShapeEnvContext(nullptr); -} - -Status InferMatmulSymbolShapeStub(gert::InferSymbolShapeContext *context) { - auto input_shape0 = context->GetInputSymbolShape(0); - GE_ASSERT_NOTNULL(input_shape0); - auto input_shape1 = context->GetInputSymbolShape(1); - GE_ASSERT_NOTNULL(input_shape1); - auto output_shape = context->GetOutputSymbolShape(0); - GE_ASSERT_NOTNULL(output_shape); - auto s0 = input_shape0->GetDim(1); - auto s1 = input_shape1->GetDim(0); - ASSERT_SYMBOL_EQ(s0, s1); - output_shape->AppendDim(input_shape0->GetDim(0)); - output_shape->AppendDim(input_shape1->GetDim(1)); - return GRAPH_SUCCESS; -} - -TEST_F(UtestExpression, SymbolAssertMatmul_Succ) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting()); - SetCurShapeEnvContext(&shape_env); - Symbol sym0 = shape_env.CreateSymbol(1, MakeShared(0, 0)); - Symbol sym1 = shape_env.CreateSymbol(2, MakeShared(0, 1)); - Symbol sym2 = shape_env.CreateSymbol(3, MakeShared(0, 2)); - Symbol sym3 = shape_env.CreateSymbol(4, MakeShared(0, 3)); - - gert::SymbolShape in_shape0({(sym0 + sym2), sym3}); - gert::SymbolShape in_shape1({sym::Pow(sym1, sym1), (sym2 * Symbol(2))}); - gert::SymbolShape out_shape({}); - - auto context_holder = gert::InferSymbolShapeContextFaker() - .IrInputNum(2) - .NodeIoNum(2, 1) - .NodeInputTd(0, ge::DT_FLOAT16, ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0) - .NodeInputTd(1, ge::DT_FLOAT16, ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z) - .Inputs({&in_shape0, &in_shape1}) - .Outputs({&out_shape}) - .Build(); - auto context = context_holder.GetContext(); - EXPECT_EQ(InferMatmulSymbolShapeStub(context), GRAPH_SUCCESS); - auto output_shape = context->GetOutputSymbolShape(0); - EXPECT_NE(output_shape, nullptr); - EXPECT_EQ(output_shape->GetDimNum(), 2); - EXPECT_EQ(output_shape->GetDim(0), (sym0 + sym2)); - EXPECT_EQ(output_shape->GetDim(1), sym2 * Symbol(2)); - EXPECT_EQ(shape_env.HasSymbolAssertInfo(sym::Eq(sym3, sym::Pow(sym1, sym1))), true); - SetCurShapeEnvContext(nullptr); -} - -TEST_F(UtestExpression, SimplifyVariable1_Succ) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting()); - SetCurShapeEnvContext(&shape_env); - Symbol sym0 = shape_env.CreateSymbol(1, MakeShared(0, 0)); - Symbol sym1 = shape_env.CreateSymbol(2, MakeShared(0, 1)); - Symbol sym2 = shape_env.CreateSymbol(3, MakeShared(0, 2)); - Symbol sym3 = shape_env.CreateSymbol(4, MakeShared(0, 3)); - - EXPECT_SYMBOL_EQ(sym2, sym0 + sym1); - EXPECT_SYMBOL_EQ(sym1, Symbol(2) * sym0); - EXPECT_SYMBOL_EQ(sym0, Symbol(1)); - auto expr1 = Symbol(2) * (sym0 + sym1) + sym1 * sym2 + sym3 + sym2; - auto expect_expr = Symbol(15) + sym3; - EXPECT_EQ(expr1.Simplify(), expect_expr); - SetCurShapeEnvContext(nullptr); -} - -TEST_F(UtestExpression, SimplifyVariable2_Succ) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting(false, DynamicMode::kDynamic)); - SetCurShapeEnvContext(&shape_env); - Symbol sym0 = shape_env.CreateSymbol(1, MakeShared(0, 0)); - Symbol sym1 = shape_env.CreateSymbol(3, MakeShared(0, 1)); - Symbol sym2 = shape_env.CreateSymbol(3, MakeShared(0, 2)); - Symbol sym3 = shape_env.CreateSymbol(2, MakeShared(0, 3)); - Symbol sym4 = shape_env.CreateSymbol(3, MakeShared(0, 4)); - Symbol sym5 = shape_env.CreateSymbol(3, MakeShared(0, 5)); - EXPECT_SYMBOL_EQ(sym1, sym2); - EXPECT_SYMBOL_EQ(sym5, sym4); - EXPECT_SYMBOL_EQ(sym4, sym0 + sym3); - EXPECT_SYMBOL_EQ(sym0 + sym3, sym2); - - auto expr1 = Symbol(2) * (sym0 + sym3) + sym::Max(sym5 * sym4, sym1 + sym2); - EXPECT_EQ(std::string(expr1.Simplify().Str().get()), - "((2 * s0) + (2 * s3) + Max(((2 * s0) + (2 * s3)), ((s0 + s3) * (s0 + s3))))"); - SetCurShapeEnvContext(nullptr); -} - -TEST_F(UtestExpression, SimplifyVariable3_Succ) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting(false, DynamicMode::kDynamic)); - SetCurShapeEnvContext(&shape_env); - Symbol sym0 = shape_env.CreateSymbol(1, MakeShared(0, 0)); - Symbol sym1 = shape_env.CreateSymbol(2, MakeShared(0, 1)); - Symbol sym2 = shape_env.CreateSymbol(2, MakeShared(0, 2)); - Symbol sym3 = shape_env.CreateSymbol(5, MakeShared(0, 3)); - Symbol sym4 = shape_env.CreateSymbol(3, MakeShared(0, 4)); - EXPECT_SYMBOL_EQ(sym1, sym2); - EXPECT_SYMBOL_EQ(sym3, sym1 + sym4); - EXPECT_SYMBOL_EQ(sym0 + sym2, sym4); - - auto expr1 = sym3 * sym4 + sym1 * Symbol(2) + sym2; - EXPECT_EQ(std::string(expr1.Simplify().Str().get()), - "(((s2 * s2) * 2) + (3 * s0 * s2) + (3 * s2) + (s0 * s0))"); - SetCurShapeEnvContext(nullptr); -} - -TEST_F(UtestExpression, SimplifyVariable4_Succ) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting(false, DynamicMode::kDynamic)); - SetCurShapeEnvContext(&shape_env); - Symbol sym0 = shape_env.CreateSymbol(1, MakeShared(0, 0)); - Symbol sym1 = shape_env.CreateSymbol(3, MakeShared(0, 1)); - Symbol sym2 = shape_env.CreateSymbol(3, MakeShared(0, 2)); - Symbol sym3 = shape_env.CreateSymbol(2, MakeShared(0, 3)); - - EXPECT_SYMBOL_EQ(sym1, sym2); - EXPECT_SYMBOL_EQ(Symbol(2), sym3); - EXPECT_SYMBOL_EQ(sym0 + sym3, sym2); - - auto expr1 = sym3 + sym1 * Symbol(2) + sym2; - EXPECT_EQ(std::string(expr1.Simplify().Str().get()), - "((3 * s0) + 8)"); - SetCurShapeEnvContext(nullptr); -} - -TEST_F(UtestExpression, SimplifyWithDeplicateSym_Succ) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting(false, DynamicMode::kDynamic)); - SetCurShapeEnvContext(&shape_env); - Symbol sym0 = shape_env.CreateSymbol(1, MakeShared(0, 0)); - EXPECT_SYMBOL_EQ(sym0, sym::Max(Symbol(0), sym0)); - EXPECT_EQ(sym0.Simplify(), sym::Max(Symbol(0), sym0)); - SetCurShapeEnvContext(nullptr); -} - -TEST_F(UtestExpression, SimplifyVariable5_Succ) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting(false, DynamicMode::kDynamic)); - SetCurShapeEnvContext(&shape_env); - Symbol sym0 = shape_env.CreateSymbol(2, MakeShared(0, 0)); - Symbol sym1 = shape_env.CreateSymbol(2, MakeShared(0, 1)); - Symbol sym2 = shape_env.CreateSymbol(4, MakeShared(0, 2)); - - EXPECT_SYMBOL_EQ(sym0, sym1); - EXPECT_SYMBOL_EQ(sym1, sym2 - sym0); - - auto expr1 = sym0; - EXPECT_EQ(expr1.Simplify(), sym1); - SetCurShapeEnvContext(nullptr); -} - -TEST_F(UtestExpression, SimplifyVariable6_Succ) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting(false, DynamicMode::kDynamic)); - SetCurShapeEnvContext(&shape_env); - Symbol sym0 = shape_env.CreateSymbol(1, MakeShared(0, 0)); - Symbol sym1 = shape_env.CreateSymbol(1, MakeShared(0, 1)); - Symbol sym2 = shape_env.CreateSymbol(3, MakeShared(0, 2)); - Symbol sym3 = shape_env.CreateSymbol(2, MakeShared(0, 3)); - - EXPECT_SYMBOL_EQ(sym0, sym1); - EXPECT_SYMBOL_EQ(sym3, sym::Min(Symbol(2), sym2)); - EXPECT_SYMBOL_EQ((sym0 + sym3) * sym::Ceiling(sym1), sym2); - EXPECT_SYMBOL_EQ(sym::Pow(sym0, Symbol(2)) / sym::Abs(sym1), sym0); - - auto expr1 = sym0 + sym1 - (sym2 * sym3); - EXPECT_EQ(std::string(expr1.Simplify().Serialize().get()), - "(((s0 * s0 * s0 * s0) / ((Abs(s1) * Abs(s1) * Abs(s1)))) + ((s0 * s0) / (Abs(((s0 * s0) / (Abs(s1)))))) - ((s0 * s0) * Ceiling(((s0 * s0 * s0 * s0) / ((Abs(s1) * Abs(s1) * Abs(s1))))) * Min(2, (((s0 * s0) * Ceiling(((s0 * s0 * s0 * s0) / ((Abs(s1) * Abs(s1) * Abs(s1))))) / (Abs(((s0 * s0) / (Abs(s1)))))) + (Ceiling(((s0 * s0 * s0 * s0) / ((Abs(s1) * Abs(s1) * Abs(s1))))) * s3))) / (Abs(((s0 * s0) / (Abs(s1)))))) - (Ceiling(((s0 * s0 * s0 * s0) / ((Abs(s1) * Abs(s1) * Abs(s1))))) * Min(2, (((s0 * s0) * Ceiling(((s0 * s0 * s0 * s0) / ((Abs(s1) * Abs(s1) * Abs(s1))))) / (Abs(((s0 * s0) / (Abs(s1)))))) + (Ceiling(((s0 * s0 * s0 * s0) / ((Abs(s1) * Abs(s1) * Abs(s1))))) * s3))) * Min(2, s2)))"); - SetCurShapeEnvContext(nullptr); -} - -TEST_F(UtestExpression, SimplifyVariable7_Succ) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting(false, DynamicMode::kDynamic)); - SetCurShapeEnvContext(&shape_env); - Symbol sym0 = shape_env.CreateSymbol(1, MakeShared(0, 0)); - Symbol sym1 = shape_env.CreateSymbol(1, MakeShared(0, 1)); - Symbol sym2 = shape_env.CreateSymbol(1, MakeShared(0, 2)); - Symbol sym3 = shape_env.CreateSymbol(2, MakeShared(0, 3)); - - EXPECT_SYMBOL_EQ(sym0, sym1); - EXPECT_SYMBOL_EQ(sym1, sym::Min(Symbol(8), sym2)); - EXPECT_SYMBOL_EQ(sym3, sym2 * Symbol(2)); - EXPECT_SYMBOL_EQ(sym0 * sym::Ceiling(sym1), sym2); - - auto expr1 = sym::Pow(sym3, sym0); - EXPECT_EQ(std::string(expr1.Simplify().Serialize().get()), - "Pow((2 * Ceiling(Min(8, s2)) * Min(8, s2)), Min(8, (Ceiling(Min(8, s2)) * s0)))"); - SetCurShapeEnvContext(nullptr); -} - -TEST_F(UtestExpression, TestHash) { - auto symbol2 = Symbol(2); - auto symbol2_bak = Symbol(2); - EXPECT_EQ(symbol2, symbol2_bak); - - auto s0 = Symbol("s0"); - auto s0_bak = Symbol("s0"); - auto s1 = Symbol("s1"); - EXPECT_EQ(s0.Hash(), s0_bak.Hash()); - auto expr1 = Mul(s0, s1); - auto expr2 = Mul(s0_bak, s1); - auto expr3 = Mul(s1, s0); - EXPECT_EQ(expr1, expr2); - EXPECT_EQ(expr2, expr3); - EXPECT_EQ(expr1, expr3); - EXPECT_EQ(expr1.Hash(), expr2.Hash()); - EXPECT_EQ(expr3.Hash(), expr1.Hash()); - EXPECT_EQ(expr2.Hash(), expr3.Hash()); - - expr1 = Eq(s0, s0_bak); - expr2 = Eq(s0_bak, s0); - EXPECT_EQ(expr1, expr2); - EXPECT_EQ(expr2.Hash(), expr1.Hash()); - - expr1 = Ne(s0, s1); - expr2 = Ne(s1, s0); - EXPECT_EQ(expr1, expr2); - EXPECT_EQ(expr2.Hash(), expr1.Hash()); - - expr1 = s0 + s1 + s0; - expr2 = symbol2 * s0 + s1; - EXPECT_EQ(expr1, expr2); - EXPECT_EQ(expr2.Hash(), expr1.Hash()); -} - -TEST_F(UtestExpression, TestExpressionUnorderdMap) { - using UMapExprInt = std::unordered_map; - UMapExprInt map1; - auto s0 = Symbol("s0"); - auto s0_bak = Symbol("s0"); - auto s1 = Symbol("s1"); - auto s2 = Symbol("s2"); - map1[s0] = 0; - map1[s0_bak] = 1; - map1[s1] = 2; - - EXPECT_NE(map1.find(s0), map1.end()); - EXPECT_EQ(map1.find(s0)->second, 1); // 被更新了 - EXPECT_EQ(map1.find(s1)->second, 2); - auto s3 = Symbol("s3"); - auto s4 = Symbol("s4"); - auto s5 = Symbol("s5"); - auto s6 = Symbol("s6"); - map1[s3] = 3; - map1[s4] = 4; - map1[s5] = 5; - map1[s6] = 6; -} - -TEST_F(UtestExpression, TestExpressionMap) { - using MapExprInt = std::map; - MapExprInt map1; - auto s0 = Symbol("s0"); - auto s0_bak = Symbol("s0"); - auto s1 = Symbol("s1"); - auto s2 = Symbol("s2"); - map1[s0] = 0; - map1[s0_bak] = 1; - map1[s1] = 2; - EXPECT_NE(map1.find(s0), map1.end()); - EXPECT_EQ(map1.find(s0)->second, 1); // 被更新了 - EXPECT_EQ(map1.find(s1)->second, 2); - EXPECT_EQ(map1.begin()->first, s0); -} - -TEST_F(UtestExpression, TestExpressionSet) { - using SetExpr = std::set; - SetExpr set1; - auto s0 = Symbol("s0"); - auto s0_bak = Symbol("s0"); - auto s1 = Symbol("s1"); - auto s2 = Symbol("s2"); - auto s3 = Symbol("s3"); - set1.insert(s0); - set1.insert(s0_bak); - set1.insert(s1); - set1.insert(s2); - set1.insert(s3); - EXPECT_EQ(set1.size(), 4); -} - -TEST_F(UtestExpression, TestCompare) { - auto s0 = Symbol("s0"); - auto s0_bak = Symbol("s0"); - auto s1 = Symbol("s1"); - auto s2 = Symbol("s2"); - EXPECT_EQ(s0.Compare(s0), 0); - EXPECT_EQ(s0.Compare(s1), -1); - EXPECT_EQ(s0.Compare(s2), -1); - - auto expr1 = s0 + s1; - auto expr2 = s2 + s0; - EXPECT_EQ(expr1.Compare(expr2), -1); - auto expr3 = Pow(Mul(Max(s0, s1), s2), Symbol(2)); - std::cout << expr2.Hash() << std::endl; - std::cout << expr3.Hash() << std::endl; - EXPECT_EQ(expr3.Compare(expr2), -1); -} - -TEST_F(UtestExpression, TestSimplifyCeiling_Floor) { - auto s6 = Symbol("s0"); - auto s0 = Symbol(0); - auto s192 = Symbol(192); - auto expr1 = Ceiling((Min(s192, s6)- Min(s0, s6))); - auto expr2 = (Symbol(-1) * Floor(((Min(s192, s6) - Min(s0, s6)) * Symbol(-1)))); - EXPECT_EQ(expr1.Simplify(), expr2); -} - -TEST_F(UtestExpression, TestLog_Succ) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting(false, DynamicMode::kDynamic)); - SetCurShapeEnvContext(&shape_env); - Symbol sym0 = shape_env.CreateSymbol(1, MakeShared(0, 0)); - auto expr = sym::Log(sym0); - int64_t value_int = 0; - EXPECT_EQ(expr.GetHint(value_int), true); - EXPECT_TRUE(value_int == 0); - - auto arg = Symbol(100); - auto base = Symbol(10); - - auto res = sym::Log(arg, base); - ASSERT_EQ(res, sym::Log(arg) / sym::Log(base)); - SetCurShapeEnvContext(nullptr); -} - -TEST_F(UtestExpression, TestAlignment_Succ) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting(false, DynamicMode::kDynamic)); - SetCurShapeEnvContext(&shape_env); - Symbol sym0 = shape_env.CreateSymbol(12, MakeShared(0, 0)); - auto expr = sym::Align(sym0, 8); - int64_t value_int = 0; - EXPECT_EQ(expr.GetHint(value_int), true); - EXPECT_EQ(value_int, 16); - SetCurShapeEnvContext(nullptr); -} - -TEST_F(UtestExpression, TestAlignmentZero_Failed) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting(false, DynamicMode::kDynamic)); - SetCurShapeEnvContext(&shape_env); - Symbol sym0 = shape_env.CreateSymbol(12, MakeShared(0, 0)); - auto expr = sym::Align(sym0, 0); - EXPECT_EQ(expr.IsValid(), false); - SetCurShapeEnvContext(nullptr); -} - -TEST_F(UtestExpression, TestAlignmentWithPositiveZero_Failed) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting(false, DynamicMode::kDynamic)); - SetCurShapeEnvContext(&shape_env); - Symbol sym0 = shape_env.CreateSymbol(12, MakeShared(0, 0)); - auto expr = sym::AlignWithPositiveInteger(sym0, 0); - EXPECT_EQ(expr.IsValid(), false); - SetCurShapeEnvContext(nullptr); -} - -TEST_F(UtestExpression, CoeffTest_Succ) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting(false, DynamicMode::kDynamic)); - SetCurShapeEnvContext(&shape_env); - Symbol sym0 = shape_env.CreateSymbol(1, MakeShared(0, 0)); - Symbol sym1 = shape_env.CreateSymbol(2, MakeShared(0, 0)); - int64_t const_sym1 = 1L; - uint64_t const_sym2 = 2UL; - uint32_t const_sym3 = 3UL; - int32_t const_sym4 = 4; - - auto symbol_2 = Symbol(const_sym2); - auto symbol_1 = Symbol(const_sym1); - - auto symbol = sym::Mul(sym0, symbol_2); - auto expr1 = sym::Coeff(symbol, sym0, symbol_1); - EXPECT_EQ(expr1, 2); - - // 3*x**y + 2*x*y + 2**x * 4 - auto expr_coeff_base = - Symbol(const_sym3) * sym::Pow(sym0, sym1) + Symbol(const_sym2) * sym0 * sym1 + - Symbol(const_sym4) * sym::Pow(Symbol(const_sym2), sym0); - auto expr2 = sym::Coeff(expr_coeff_base, sym0, sym1); - EXPECT_EQ(expr2, 3); - - auto expr3 = sym::Coeff(expr_coeff_base, sym1, Symbol(const_sym1)); - EXPECT_EQ(expr3, Symbol(2) * sym0); - SetCurShapeEnvContext(nullptr); -} - -TEST_F(UtestExpression, StaticCheckConst) { - EXPECT_EQ(SymbolicUtils::StaticCheckEq(sym::Rational(4, 2), sym::Log(Symbol(1))), TriBool::kFalse); -} - -TEST_F(UtestExpression, TestNotEqual) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting(false, DynamicMode::kDynamic)); - Symbol sym0 = shape_env.CreateSymbol(1, MakeShared(0, 0)); - EXPECT_EQ(Symbol(3.5f) != sym0, true); -} - -TEST_F(UtestExpression, ComputeExprHint_Succ) { - auto shape_env = ShapeEnvAttr(ShapeEnvSetting()); - SetCurShapeEnvContext(&shape_env); - Symbol sym0 = shape_env.CreateSymbol(1, MakeShared(0, 0)); - Symbol sym1 = shape_env.CreateSymbol(2, MakeShared(0, 1)); - Symbol sym2 = shape_env.CreateSymbol(3, MakeShared(0, 1)); - // int类型 - auto expr1 = Symbol(2) * (sym0 * sym2) + sym::Max(sym0 + Symbol(3) * sym1, sym2); - int64_t value_int = 0; - EXPECT_EQ(expr1.GetHint(value_int), true); - EXPECT_EQ(value_int, 13); - - // bool 类型 - auto expr2 = sym::Eq(Symbol(3) * sym0, sym::Pow(sym1, sym2) - (sym2 + Symbol(2))); - bool value_bool = false; - EXPECT_EQ(expr2.GetHint(value_bool), true); - EXPECT_EQ(value_bool, true); - - // float - auto expr3 = sym::Rational(1, 2) + sym2; - double value_double = 0.0f; - EXPECT_EQ(expr3.GetHint(value_double), true); - EXPECT_EQ(value_double, 3.5f); - SetCurShapeEnvContext(nullptr); -} - -TEST_F(UtestExpression, ComputeConstHint_Succ) { - // int类型 - auto expr1 = Symbol(2) + sym::Max(Symbol(1) + Symbol(3), Symbol(2)); - int64_t value_int = 0; - EXPECT_EQ(expr1.GetHint(value_int), true); - EXPECT_EQ(value_int, 6); - - // bool 类型 - auto expr_bool1 = sym::Eq(Symbol(3), sym::Pow(Symbol(2), Symbol(3)) - Symbol(5)); - bool value_bool = false; - EXPECT_EQ(expr_bool1.GetHint(value_bool), true); - EXPECT_EQ(value_bool, true); - - auto expr_bool2 = sym::Eq(Symbol(3), sym::Pow(Symbol(3), Symbol(2)) - Symbol(5)); - EXPECT_EQ(expr_bool2.GetHint(value_bool), true); - EXPECT_TRUE(value_bool == false); - - // float - auto expr3 = sym::Rational(1, 2) + Symbol(2); - double value_double = 0.0f; - EXPECT_EQ(expr3.GetHint(value_double), true); - EXPECT_EQ(value_double, 2.5f); -} - -TEST_F(UtestExpression, GetConstValue_Succ) { - // int类型 - auto expr1 = Symbol(2) + sym::Max(Symbol(1) + Symbol(3), Symbol(2)); - int64_t value_int = 0; - EXPECT_EQ(expr1.GetConstValue(value_int), true); - EXPECT_EQ(value_int, 6); - - // bool 类型 - auto expr2 = sym::Eq(Symbol(3), sym::Pow(Symbol(2), Symbol(3)) - Symbol(5)); - bool value_bool = false; - EXPECT_EQ(expr2.GetConstValue(value_bool), true); - EXPECT_EQ(value_bool, true); - - // float - auto expr3 = sym::Rational(1, 2) + Symbol(2); - double value_double = 0.0f; - EXPECT_EQ(expr3.GetConstValue(value_double), true); - EXPECT_EQ(value_double, 2.5f); -} - -TEST_F(UtestExpression, Abnormal_Sym_Expr) { - auto s0 = Symbol("s0"); - auto e0 = s0 + s0; - auto e1 = Expression::Deserialize("a(s0)"); - e0 = e1; - EXPECT_NE(e0.IsConstExpr(), true); - EXPECT_EQ(e0.Serialize(), nullptr); - EXPECT_EQ(e0.FreeSymbols().size(), 0); - double a; - EXPECT_NE(e0.GetResult({}, a), GRAPH_SUCCESS); - EXPECT_EQ(e0.GetConstValue(a), false); - bool c; - EXPECT_EQ(e0.GetConstValue(c), false); - EXPECT_EQ(e0 == e1, false); - EXPECT_EQ(e0.GetExprType(), ExprType::kExprNone); -} - -TEST_F(UtestExpression, Parser_Empty) { - Expression expr = Expression::Parse(nullptr); - EXPECT_EQ(expr.IsValid(), false); -} - -TEST_F(UtestExpression, Parser_Minus) { - auto s0 = Symbol("s0"); - auto neg_2 = Symbol(-2); - auto c_2 = Symbol(2); - - // Add - auto expr = sym::Add(s0, neg_2); - auto expr_str = expr.Serialize(); - EXPECT_EQ(std::string(expr_str.get()), "(-2 + s0)"); - EXPECT_EQ(Expression::Parse(expr_str.get()), expr); - - expr = sym::Add(neg_2, s0); - expr_str = expr.Serialize(); - EXPECT_EQ(std::string(expr_str.get()), "(-2 + s0)"); - EXPECT_EQ(Expression::Parse(expr_str.get()), expr); - // Sub - expr = sym::Sub(s0, neg_2); - expr_str = expr.Serialize(); - EXPECT_EQ(std::string(expr_str.get()), "(2 + s0)"); - EXPECT_EQ(Expression::Parse(expr_str.get()), expr); - - expr = sym::Sub(neg_2, s0); - expr_str = expr.Serialize(); - EXPECT_EQ(std::string(expr_str.get()), "(-2 - s0)"); - EXPECT_EQ(Expression::Parse(expr_str.get()), expr); - //Mul - expr = sym::Mul(s0, neg_2); - expr_str = expr.Serialize(); - EXPECT_EQ(std::string(expr_str.get()), "(-2 * s0)"); - EXPECT_EQ(Expression::Parse(expr_str.get()), expr); - - expr = sym::Mul(neg_2, s0); - expr_str = expr.Serialize(); - EXPECT_EQ(std::string(expr_str.get()), "(-2 * s0)"); - EXPECT_EQ(Expression::Parse(expr_str.get()), expr); - //Div - expr = sym::Div(neg_2, s0); - expr_str = expr.Serialize(); - EXPECT_EQ(std::string(expr_str.get()), "(-2 / (s0))"); - EXPECT_EQ(Expression::Parse(expr_str.get()), expr); - - expr = sym::Div(s0, neg_2); - expr_str = expr.Serialize(); - EXPECT_EQ(std::string(expr_str.get()), "(Rational(-1 , 2) * s0)"); - EXPECT_EQ(Expression::Parse(expr_str.get()), expr); - // Max - expr = sym::Max(s0, neg_2); - expr_str = expr.Serialize(); - EXPECT_EQ(std::string(expr_str.get()), "Max(s0, -2)"); - EXPECT_EQ(Expression::Parse(expr_str.get()), expr); - - expr = sym::Max(neg_2, s0); - expr_str = expr.Serialize(); - EXPECT_EQ(std::string(expr_str.get()), "Max(s0, -2)"); - EXPECT_EQ(Expression::Parse(expr_str.get()), expr); - // Min - expr = sym::Min(s0, neg_2); - expr_str = expr.Serialize(); - EXPECT_EQ(std::string(expr_str.get()), "Min(s0, -2)"); - EXPECT_EQ(Expression::Parse(expr_str.get()), expr); - - expr = sym::Min(neg_2, s0); - expr_str = expr.Serialize(); - EXPECT_EQ(std::string(expr_str.get()), "Min(s0, -2)"); - EXPECT_EQ(Expression::Parse(expr_str.get()), expr); - // Pow - expr = sym::Pow(s0, neg_2); - expr_str = expr.Serialize(); - EXPECT_EQ(std::string(expr_str.get()), "Pow(s0, -2)"); - EXPECT_EQ(Expression::Parse(expr_str.get()), expr); - - expr = sym::Pow(neg_2, s0); - expr_str = expr.Serialize(); - EXPECT_EQ(std::string(expr_str.get()), "Pow(-2, s0)"); - EXPECT_EQ(Expression::Parse(expr_str.get()), expr); - //Abs - expr = sym::Abs(neg_2); - expr_str = expr.Serialize(); - EXPECT_EQ(std::string(expr_str.get()), "2"); - EXPECT_EQ(Expression::Parse(expr_str.get()), expr); - - expr = sym::Abs(sym::Add(neg_2, s0)); - expr_str = expr.Serialize(); - EXPECT_EQ(std::string(expr_str.get()), "Abs((2 - s0))"); - EXPECT_EQ(Expression::Parse(expr_str.get()), expr); - //Ceiling - expr = sym::Ceiling(neg_2); - expr_str = expr.Serialize(); - EXPECT_EQ(std::string(expr_str.get()), "-2"); - EXPECT_EQ(Expression::Parse(expr_str.get()), expr); - - expr = sym::Ceiling(sym::Add(neg_2, s0)); - expr_str = expr.Serialize(); - EXPECT_EQ(std::string(expr_str.get()), "(-2 + Ceiling(s0))"); - EXPECT_EQ(Expression::Parse(expr_str.get()), expr); - //Floor - expr = sym::Floor(neg_2); - expr_str = expr.Serialize(); - EXPECT_EQ(std::string(expr_str.get()), "-2"); - EXPECT_EQ(Expression::Parse(expr_str.get()), expr); - - expr = sym::Floor(sym::Add(neg_2, s0)); - expr_str = expr.Serialize(); - EXPECT_EQ(std::string(expr_str.get()), "(-2 + Floor(s0))"); - EXPECT_EQ(Expression::Parse(expr_str.get()), expr); -} - -TEST_F(UtestExpression, Parser_Minus1) { - std::string str = "((-((8+s2)*s3)-(-1-((8+s2)*s3)-s3)-s3)*(8+s1)*(8+s2)*s3)"; - Expression expr = Expression::Parse(str.c_str()); - EXPECT_EQ(std::string(expr.Serialize().get()), "(( - ((8 + s2) * s3) - (-1 - ((8 + s2) * s3) - s3) - s3) * (8 + s1) * (8 + s2) * s3)"); - - str = "-(s1 + s2)"; - expr = Expression::Parse(str.c_str()); - EXPECT_EQ(std::string(expr.Serialize().get()), "((s1 + s2) * -1)"); - - str = "1 - (s1-1)"; - expr = Expression::Parse(str.c_str()); - EXPECT_EQ(std::string(expr.Serialize().get()), "(1 - (-1 + s1))"); - - str = "-s1"; - expr = Expression::Parse(str.c_str()); - EXPECT_EQ(std::string(expr.Serialize().get()), "(-1 * s1)"); - - str = "1-s0"; - expr = Expression::Parse(str.c_str()); - EXPECT_EQ(std::string(expr.Serialize().get()), "(1 - s0)"); - - str = "-1-((8+s2)*s3)-s3"; - expr = Expression::Parse(str.c_str()); - EXPECT_EQ(std::string(expr.Serialize().get()), "(-1 - ((8 + s2) * s3) - s3)"); - - str = "(s1-1)"; - expr = Expression::Parse(str.c_str()); - EXPECT_EQ(std::string(expr.Serialize().get()), "(-1 + s1)"); -} - -TEST_F(UtestExpression, CanonicalizeBoolExpr_basic) { - Expression e(nullptr); - EXPECT_EQ(e.CanonicalizeBoolExpr().Str().get(), nullptr); - - auto e0 = ExpressionImpl::CreateExpressionImpl(SymEngine::parse("s0 == 1")); - EXPECT_EQ(e0->CanonicalizeBoolExpr()->Str(), "ExpectEq(1, s0)"); - - auto e1 = ExpressionImpl::CreateExpressionImpl(SymEngine::parse("s0 != 1")); - EXPECT_EQ(e1->CanonicalizeBoolExpr()->Str(), "ExpectNe(1, s0)"); - - auto e2 = ExpressionImpl::CreateExpressionImpl(SymEngine::parse("s0 <= 1")); - EXPECT_EQ(e2->CanonicalizeBoolExpr()->Str(), "ExpectLe(s0, 1)"); - - auto e3 = ExpressionImpl::CreateExpressionImpl(SymEngine::parse("s0 >= 1")); - EXPECT_EQ(e3->CanonicalizeBoolExpr()->Str(), "ExpectLe(1, s0)"); - - auto e4 = ExpressionImpl::CreateExpressionImpl(SymEngine::parse("s0 > 1")); - EXPECT_EQ(e4->CanonicalizeBoolExpr()->Str(), "ExpectLt(1, s0)"); - - auto e5 = ExpressionImpl::CreateExpressionImpl(SymEngine::parse("s0 + s1")); - EXPECT_EQ(e5->CanonicalizeBoolExpr()->Str(), "(s0 + s1)"); - - auto e6 = ExpressionImpl::CreateExpressionImpl(SymEngine::parse("s2 < 0")); - EXPECT_EQ(e6->CanonicalizeBoolExpr()->Str(), "ExpectLt(s2, 0)"); - - auto e7 = ExpressionImpl::CreateExpressionImpl(SymEngine::parse("0 < s2")); - EXPECT_EQ(e7->CanonicalizeBoolExpr()->Str(), "ExpectLt(0, s2)"); - - auto e8 = ExpressionImpl::CreateExpressionImpl(SymEngine::parse("s2 + 1 < 0")); - EXPECT_EQ(e8->CanonicalizeBoolExpr()->Str(), "ExpectLt((1 + s2), 0)"); - - auto e9 = ExpressionImpl::CreateExpressionImpl(SymEngine::parse("s2 - 1 < 0")); - EXPECT_EQ(e9->CanonicalizeBoolExpr()->Str(), "ExpectLt(s2, 1)"); - - auto e10 = ExpressionImpl::CreateExpressionImpl(SymEngine::parse("Mod(s2,2) <= 0")); - EXPECT_EQ(e10->CanonicalizeBoolExpr()->Str(), "ExpectLe(Mod(s2, 2), 0)"); - - auto e11 = ExpressionImpl::CreateExpressionImpl(SymEngine::parse("Mod(s2,2) -1 <= 0")); - EXPECT_EQ(e11->CanonicalizeBoolExpr()->Str(), "ExpectLe(Mod(s2, 2), 1)"); -} - -TEST_F(UtestExpression, CanonicalizeBoolExpr_basic_neg) { - auto e0 = ExpressionImpl::CreateExpressionImpl(SymEngine::parse("s0 -1 == 3")); - EXPECT_EQ(e0->CanonicalizeBoolExpr()->Str(), "ExpectEq(4, s0)"); - - auto e1 = ExpressionImpl::CreateExpressionImpl(SymEngine::parse("s0 -1 != 3")); - EXPECT_EQ(e1->CanonicalizeBoolExpr()->Str(), "ExpectNe(4, s0)"); - - auto e2 = ExpressionImpl::CreateExpressionImpl(SymEngine::parse("s0 -1 <= 3")); - EXPECT_EQ(e2->CanonicalizeBoolExpr()->Str(), "ExpectLe(s0, 4)"); - - auto e3 = ExpressionImpl::CreateExpressionImpl(SymEngine::parse("s0 -1 >= 3")); - EXPECT_EQ(e3->CanonicalizeBoolExpr()->Str(), "ExpectLe(4, s0)"); - - auto e4 = ExpressionImpl::CreateExpressionImpl(SymEngine::parse("s0 -1 > 3")); - EXPECT_EQ(e4->CanonicalizeBoolExpr()->Str(), "ExpectLt(4, s0)"); -} - -TEST_F(UtestExpression, CanonicalizeBoolExpr) { - Symbol r0 = Symbol("r0"); - Symbol s1 = Symbol(1); - Symbol s2 = Symbol(4096); - Symbol s3 = Symbol(41); - Symbol x = Symbol("x"); - Symbol y = Symbol("y"); - - // x * y == 0 ---> x *y == 0 - EXPECT_EQ(std::string(sym::Eq(Mul(x, y), Symbol(0)).CanonicalizeBoolExpr().Serialize().get()), - "ExpectEq(0, (x * y))"); - - // 2*x*y + 4*x == 0 ---> x*y + 2*x == 0 - auto e = ExpressionImpl::CreateExpressionImpl(SymEngine::parse("2*x*y+4*x==0")); - EXPECT_EQ(e->CanonicalizeBoolExpr()->Str(), "ExpectEq(0, ((2 * x) + (x * y)))"); - - // 2*x*y+4*x + 2**x==0 ---> 2*x*y+4*x + 2**x==0 (pow not support) - auto e1 = ExpressionImpl::CreateExpressionImpl(SymEngine::parse("2*x*y+4*x + 2**x==0")); - EXPECT_EQ(e1->CanonicalizeBoolExpr()->Str(), "ExpectEq(0, ((2 * x * y) + (4 * x) + Pow(2, x)))"); - - // 2*x + 4y == 0 ---> x + 2*y == 0 - auto e2 = ExpressionImpl::CreateExpressionImpl(SymEngine::parse("2*x + 4y == 0")); - EXPECT_EQ(e2->CanonicalizeBoolExpr()->Str(), "ExpectEq(0, ((2 * y) + x))"); - - EXPECT_EQ(std::string(sym::Eq(Add(Symbol(0), Symbol(0)), Symbol(0)).CanonicalizeBoolExpr().Serialize().get()), - "True"); - - EXPECT_EQ(std::string(Mul(r0, s2).CanonicalizeBoolExpr().Serialize().get()), "(4096 * r0)"); - - auto expr = sym::Lt(r0, s1).CanonicalizeBoolExpr(); - EXPECT_EQ(std::string(expr.Serialize().get()), "ExpectLt(r0, 1)"); - - auto expr2 = sym::Lt(r0, s2).CanonicalizeBoolExpr(); - EXPECT_EQ(std::string(expr2.Serialize().get()), "ExpectLt(r0, 4096)"); - - auto expr3 = sym::Ge(r0, s2).CanonicalizeBoolExpr(); - EXPECT_EQ(std::string(expr3.Serialize().get()), "ExpectLe(4096, r0)"); - - auto expr4 = sym::Eq(sym::Add(r0, s3), s2).CanonicalizeBoolExpr(); - EXPECT_EQ(std::string(expr4.Serialize().get()), "ExpectEq(4055, r0)"); - - auto expr41 = sym::Eq(s2, sym::Add(r0, s3)).CanonicalizeBoolExpr(); - EXPECT_EQ(std::string(expr41.Serialize().get()), "ExpectEq(4055, r0)"); - - auto expr5 = sym::Eq(sym::Add(s3, sym::Mul(r0, Symbol(2))), s2).CanonicalizeBoolExpr(); - EXPECT_EQ(std::string(expr5.Serialize().get()), "ExpectEq(4055, (2 * r0))"); - - auto expr6 = sym::Eq(sym::Add(Symbol(42), sym::Mul(r0, Symbol(2))), Symbol(4096)).CanonicalizeBoolExpr(); - EXPECT_EQ(std::string(expr6.Serialize().get()), "ExpectEq(2027, r0)"); -} - -TEST_F(UtestExpression, EvaluateAsBoolBasic) { - // s0>0 -> 2*s0 > s0 return true - - auto e = ShapeEnvAttr(); - e.CreateSymbol(1, MakeShared(0, 1)); - - SetCurShapeEnvContext(&e); - - //s0 > 0 - ExpectSymbolBool(sym::Gt(Symbol("s0"), Symbol(0)), "xxx", 100); - - //2*s0 > s0 - EXPECT_TRUE(SymbolicUtils::StaticCheckGt(Mul(Symbol(2), Symbol("s0")), Symbol("s0")) == TriBool::kTrue); -} - -TEST_F(UtestExpression, EvaluateAsBool_case_canfuse) { - // 4*s0*s1*s2 > 4*s0*s2 ---> 1 < s1 - auto e = ExpressionImpl::CreateExpressionImpl(SymEngine::parse("4*s0*s1*s2 > 4*s0*s2")); - EXPECT_EQ(e->CanonicalizeBoolExpr()->Str(), "ExpectLt(1, s1)"); - - // 4*s0*s1*s2 < 4*s0*s2 ---> 1 > s1 - auto e1 = ExpressionImpl::CreateExpressionImpl(SymEngine::parse("4*s0*s2 > 4*s0*s1*s2")); - EXPECT_EQ(e1->CanonicalizeBoolExpr()->Str(), "ExpectLt(s1, 1)"); - - auto s = ShapeEnvAttr(); - s.CreateSymbol(1, MakeShared(0, 1)); - s.CreateSymbol(2, MakeShared(0, 2)); - s.CreateSymbol(3, MakeShared(0, 3)); - - SetCurShapeEnvContext(&s); - - ExpectSymbolBool(sym::Gt(Symbol("s0"), Symbol(0)), "xxx", 100); - ExpectSymbolBool(sym::Gt(Symbol("s1"), Symbol(1)), "xxx", 100); - ExpectSymbolBool(sym::Gt(Symbol("s2"), Symbol(0)), "xxx", 100); - - auto expr1 = sym::Mul(sym::Mul(sym::Mul(Symbol(4), Symbol("s0")), Symbol("s1")), Symbol("s2")); - auto expr2 = sym::Mul(sym::Mul(Symbol(4), Symbol("s0")), Symbol("s2")); - - EXPECT_TRUE(SymbolicUtils::StaticCheckGt(expr1, expr2) == TriBool::kTrue); -} - -TEST_F(UtestExpression, GetArgsTest) { - Expression e(nullptr); - EXPECT_EQ(e.GetArgs().size(), 0); - auto s0 = Symbol("s0"); - auto expr1 = Mul(s0, Symbol(2)); - auto s1 = Symbol("s1"); - auto expr2 = Pow(s1, Symbol(2)); - auto expr3 = Add(expr1, expr2); - - EXPECT_EQ(expr3.GetArgs().size(), 2); - EXPECT_EQ(std::string(expr3.GetArgs()[0].Serialize().get()), expr2.Serialize().get()); - EXPECT_EQ(std::string(expr3.GetArgs()[1].Serialize().get()), expr1.Serialize().get()); -} - -TEST_F(UtestExpression, TriBoolConvert) { - ge::TriBool tb = ge::TriBool::kTrue; - - EXPECT_EQ(tb, ge::TriBool::kTrue); - EXPECT_NE(tb, ge::TriBool::kFalse); - EXPECT_NE(tb, ge::TriBool::kUnknown); -} - -TEST_F(UtestExpression, LogicalTest) { - auto s0 = Symbol("s0"); - auto s1 = Symbol("s1"); - auto s2 = Symbol("s2"); - auto s3 = Symbol("s3"); - - //非布尔表达式测试 - EXPECT_EQ(LogicalAnd({s0}).Serialize().get(), nullptr); - EXPECT_EQ(LogicalOr({s0}).Serialize().get(), nullptr); - - //基础用例测试 - auto expr1 = LogicalAnd({Eq(s0, s1), Eq(s2, s3)}); - EXPECT_EQ(std::string(expr1.Serialize().get()), "LogicAnd(ExpectEq(s0, s1), ExpectEq(s2, s3))"); - - auto expr2 = LogicalOr({Eq(s0, s1), Eq(s2, s3)}); - EXPECT_EQ(std::string(expr2.Serialize().get()), "LogicOr(ExpectEq(s0, s1), ExpectEq(s2, s3))"); - - //解析测试 - auto expr3 = ExpressionImpl::Parse("LogicAnd(ExpectEq(s0, s1), ExpectEq(s2, s3))"); - EXPECT_EQ(expr3->Str(), expr1.Serialize().get()); - auto expr3_1 = ExpressionImpl::Parse("LogicOr(ExpectEq(s0, s1), ExpectEq(s2, s3))"); - EXPECT_EQ(expr3_1->Str(), expr2.Serialize().get()); - - auto expr3_2 = ExpressionImpl::Parse("LogicAnd(ExpectEq(s0, s1))"); - EXPECT_EQ(expr3_2.get(), nullptr); - auto expr3_3 = ExpressionImpl::Parse("LogicOr(ExpectEq(s0, s1))"); - EXPECT_EQ(expr3_3.get(), nullptr); - - auto expr4 = sym::LogicalAnd({}); - EXPECT_EQ(std::string(expr4.Serialize().get()), "True"); - - auto expr5 = sym::LogicalOr({}); - EXPECT_EQ(std::string(expr5.Serialize().get()), "False"); - - auto expr6 = LogicalAnd({Eq(s0, s1)}); - EXPECT_EQ(std::string(expr6.Serialize().get()), "ExpectEq(s0, s1)"); - - auto expr7 = LogicalOr({Eq(s0, s1)}); - EXPECT_EQ(std::string(expr7.Serialize().get()), "ExpectEq(s0, s1)"); - - auto expr8 = LogicalOr({Eq(Symbol(1), Symbol(1))}); - EXPECT_EQ(std::string(expr8.Serialize().get()), "True"); - - auto expr9 = LogicalOr({Eq(Symbol(1), Symbol(1)), Eq(s0, s1)}); - EXPECT_EQ(std::string(expr9.Serialize().get()), "True"); - - auto expr10 = LogicalAnd({Eq(Symbol(1), Symbol(1)), Eq(s0, s1)}); - EXPECT_EQ(std::string(expr10.Serialize().get()), "ExpectEq(s0, s1)"); - - auto expr11 = LogicalAnd({Eq(Symbol(1), Symbol(0)), Eq(s0, s1)}); - EXPECT_EQ(std::string(expr11.Serialize().get()), "False"); - - //添加到shape env测试 - auto s = ShapeEnvAttr(); - s.CreateSymbol(1, MakeShared(0, 1)); - s.CreateSymbol(2, MakeShared(0, 2)); - s.CreateSymbol(3, MakeShared(0, 3)); - - SetCurShapeEnvContext(&s); - - ExpectSymbolBool(expr1, "xxx", 100); - ExpectSymbolBool(expr2, "xxx", 100); - - EXPECT_TRUE(SymbolicUtils::StaticCheckGt(expr1, expr2) == TriBool::kUnknown); - EXPECT_TRUE(SymbolicUtils::StaticCheckEq(expr1, expr1) == TriBool::kTrue); -} - -TEST_F(UtestExpression, LogicalTestConst) { - EXPECT_TRUE(EXPECT_SYMBOL_AND()); - EXPECT_FALSE(EXPECT_SYMBOL_OR()); - - EXPECT_TRUE(EXPECT_SYMBOL_AND(Eq(Symbol(1), Symbol(1)))); - EXPECT_FALSE(EXPECT_SYMBOL_AND(Eq(Symbol(1), Symbol(1)), Eq(Symbol(1), Symbol(0)))); - - EXPECT_TRUE(EXPECT_SYMBOL_OR(Eq(Symbol(1), Symbol(1)))); - EXPECT_TRUE(EXPECT_SYMBOL_OR(Eq(Symbol(1), Symbol(1)), Eq(Symbol(1), Symbol(0)))); -} - -TEST_F(UtestExpression, LogicalOrTestGuard) { - auto s0 = Symbol("s0"); - auto s1 = Symbol("s1"); - auto s2 = Symbol("s2"); - auto s3 = Symbol("s3"); - auto s4 = Symbol("s4"); - auto s5 = Symbol("s5"); - - auto s = ShapeEnvAttr(); - SetCurShapeEnvContext(&s); - s.CreateSymbol(1, MakeShared(0, 1)); - s.CreateSymbol(2, MakeShared(0, 2)); - s.CreateSymbol(3, MakeShared(0, 3)); - s.CreateSymbol(4, MakeShared(0, 4)); - s.CreateSymbol(5, MakeShared(0, 5)); - s.CreateSymbol(6, MakeShared(0, 6)); - - auto ret = EXPECT_SYMBOL_OR(Ge(s0, s1), Le(s2, s3), Eq(s4, s5)); - EXPECT_EQ(ret, true); - ret = EXPECT_SYMBOL_OR(Ge(s0, s1), Le(s3, s2), Eq(s5, s4)); - EXPECT_EQ(ret, false); - ret = EXPECT_SYMBOL_AND(Ge(s0, s1), Le(s2, s3), Eq(s4, s5)); - EXPECT_EQ(ret, false); - ret = EXPECT_SYMBOL_AND(Ge(s1, s0), Gt(s3, s2), Gt(s5, s4)); - EXPECT_EQ(ret, true); - - const std::set expect_guard = {"LogicOr(ExpectEq(s4, s5), ExpectLe(s1, s0), ExpectLe(s2, s3))", - "LogicAnd(ExpectLt(s0, s1), ExpectLt(s2, s3), ExpectNe(s4, s5))", - "LogicAnd(ExpectLe(s0, s1), ExpectLt(s2, s3), ExpectLt(s4, s5))", - "LogicOr(ExpectLt(s0, s1), ExpectLt(s3, s2), ExpectNe(s4, s5))"}; - for (auto &iter : s.GetAllSymbolCheckInfos()) { - EXPECT_NE(expect_guard.find(std::string(iter.expr.Serialize().get())), expect_guard.end()); - } -} - -TEST_F(UtestExpression, SimplifyWithShapeEnv) { - ShapeEnvAttr shape_env; - SetCurShapeEnvContext(&shape_env); - - auto expr = sym::Ceiling(sym::Sub(sym::Ceiling(sym::Mul(sym::Rational(1,2), Symbol("s0"))), Symbol(20))); - auto expr1 = Expression::Deserialize(expr.Str().get()); - EXPECT_EQ(expr1.impl_, nullptr); - auto expr2 = Expression::Deserialize(expr.Simplify().Str().get()); - EXPECT_NE(expr2.impl_, nullptr); -} - -TEST_F(UtestExpression, ExpandSimplifyTest) { - SetCurShapeEnvContext(nullptr); - auto s0 = Symbol("s0"); - auto s1 = Symbol("s1"); - auto c1 = Symbol(1); - auto c2 = Symbol(2); - - // s0 + 2(s1 + 1) - auto expr = sym::Add(s0, sym::Mul(c2, sym::Add(s1, c1))); - EXPECT_EQ(std::string(expr.Str().get()), "(((1 + s1) * 2) + s0)"); - expr = expr.Simplify(); - EXPECT_EQ(std::string(expr.Str().get()), "((2 * s1) + 2 + s0)"); - - auto s2 = Symbol("s2"); - auto s3 = Symbol("s3"); - auto c5 = Symbol(5); - auto expr1 = c2 * s1 * s2 * s3; - auto expr2 = c5 * s2 * s3; - auto expr3 = c5 *s3; - - expr = expr1 + expr2 + expr3 + c5 - (expr1 + expr2 + expr3 + c5); - EXPECT_NE(expr, Symbol(0)); - EXPECT_EQ(expr.Simplify(), Symbol(0)); -} - -TEST_F(UtestExpression, StaticCheckTest) { - SetCurShapeEnvContext(nullptr); - auto s0 = Symbol("s0"); - auto s1 = Symbol("s1"); - auto c0 = Symbol(0); - auto c1 = Symbol(1); - auto c2 = Symbol(2); - - // s0 + 1 - auto expr1 = s0 + c1; - // s0 + 2 - (s0 + 1) = 1 - auto expr2 = s0 + c2 - expr1; - EXPECT_NE(expr2, c1); - // == - EXPECT_EQ(SymbolicUtils::StaticCheckEq(expr2, c1), TriBool::kTrue); - EXPECT_EQ(SymbolicUtils::StaticCheckEq(expr2, c2), TriBool::kFalse); - EXPECT_EQ(SymbolicUtils::StaticCheckEq(expr2, s1), TriBool::kUnknown); - // != - EXPECT_EQ(SymbolicUtils::StaticCheckNe(expr2, c2), TriBool::kTrue); - EXPECT_EQ(SymbolicUtils::StaticCheckNe(expr2, c1), TriBool::kFalse); - EXPECT_EQ(SymbolicUtils::StaticCheckNe(expr2, s1), TriBool::kUnknown); - // < - EXPECT_EQ(SymbolicUtils::StaticCheckLt(expr2, c2), TriBool::kTrue); - EXPECT_EQ(SymbolicUtils::StaticCheckLt(expr2, c1), TriBool::kFalse); - EXPECT_EQ(SymbolicUtils::StaticCheckLt(expr2, s1), TriBool::kUnknown); - // <= - EXPECT_EQ(SymbolicUtils::StaticCheckLe(expr2, c2), TriBool::kTrue); - EXPECT_EQ(SymbolicUtils::StaticCheckLe(expr2, c1), TriBool::kTrue); - EXPECT_EQ(SymbolicUtils::StaticCheckLe(expr2, c0), TriBool::kFalse); - EXPECT_EQ(SymbolicUtils::StaticCheckLe(expr2, s1), TriBool::kUnknown); - // > - EXPECT_EQ(SymbolicUtils::StaticCheckGt(expr2, c0), TriBool::kTrue); - EXPECT_EQ(SymbolicUtils::StaticCheckGt(expr2, c2), TriBool::kFalse); - EXPECT_EQ(SymbolicUtils::StaticCheckGt(expr2, s1), TriBool::kUnknown); - // >= - EXPECT_EQ(SymbolicUtils::StaticCheckGe(expr2, c1), TriBool::kTrue); - EXPECT_EQ(SymbolicUtils::StaticCheckGe(expr2, c0), TriBool::kTrue); - EXPECT_EQ(SymbolicUtils::StaticCheckGe(expr2, c2), TriBool::kFalse); - EXPECT_EQ(SymbolicUtils::StaticCheckGe(expr2, s1), TriBool::kUnknown); -} - -TEST_F(UtestExpression, Expect_Add_Replacement_And_Simplify_When_Input_Two_Var_NE_False) { - ShapeEnvAttr shape_env; - SetCurShapeEnvContext(&shape_env); - Symbol s0 = shape_env.CreateSymbol(2, MakeShared(0, 0)); - Symbol s1 = shape_env.CreateSymbol(2, MakeShared(0, 1)); - EXPECT_EQ(EXPECT_SYMBOL_NE(s0, s1), false); - EXPECT_EQ(shape_env.GetAllSymbolCheckInfos().size(), 1); - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Eq(s0, s1)), true); - - EXPECT_EQ(shape_env.Simplify(s0), s1); - EXPECT_EQ(shape_env.Simplify(s1), s1); // 当前的replace如果都为符号,且rank一样的情况下,后面的是前面replace -} - -TEST_F(UtestExpression, Expect_Add_Replacement_And_Simplify_When_Input_One_Var_One_Const_NE_False) { - ShapeEnvAttr shape_env; - SetCurShapeEnvContext(&shape_env); - Symbol s0 = shape_env.CreateSymbol(2, MakeShared(0, 0)); - Symbol s1 = Symbol(2); - EXPECT_EQ(EXPECT_SYMBOL_NE(s0, s1), false); - EXPECT_EQ(shape_env.GetAllSymbolCheckInfos().size(), 1); - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Eq(s0, s1)), true); - - auto exp = s0 + Symbol(1); - EXPECT_EQ(exp.Simplify(), Symbol(3)); -} - -TEST_F(UtestExpression, Expect_Add_Replacement_And_Simplify_When_Input_One_Var_One_Exper_NE_False) { - ShapeEnvAttr shape_env; - SetCurShapeEnvContext(&shape_env); - Symbol s0 = shape_env.CreateSymbol(2, MakeShared(0, 0)); - Symbol s1 = shape_env.CreateSymbol(1, MakeShared(0, 1)); - Symbol s2 = shape_env.CreateSymbol(1, MakeShared(0, 2)); - EXPECT_EQ(EXPECT_SYMBOL_NE(s0, s1 + s2), false); - EXPECT_EQ(shape_env.GetAllSymbolCheckInfos().size(), 1); - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Eq(s0, s1 + s2)), true); - - EXPECT_EQ(s0.Simplify(), s1 + s2); -} - -TEST_F(UtestExpression, Expect_Add_Replacement_And_Simplify_When_Input_Two_Exper_NE_False) { - ShapeEnvAttr shape_env; - SetCurShapeEnvContext(&shape_env); - Symbol s0 = shape_env.CreateSymbol(2, MakeShared(0, 0)); - Symbol s1 = shape_env.CreateSymbol(1, MakeShared(0, 1)); - EXPECT_EQ(EXPECT_SYMBOL_NE(s0 + Symbol(1), s1 + Symbol(2)), false); - EXPECT_EQ(shape_env.GetAllSymbolCheckInfos().size(), 1); - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Eq(s0 + Symbol(1), s1 + Symbol(2))), true); - - auto exp1 = s0 + Symbol(1); - auto exp2 = s1 + Symbol(2); - EXPECT_EQ(exp1.Simplify(), exp2); // 标准化后:s0 == s1 + 1 -} - -TEST_F(UtestExpression, Expect_Not_Add_Replacement_And_Simplify_When_Input_One_Exper_One_Const_NE_False) { - ShapeEnvAttr shape_env; - SetCurShapeEnvContext(&shape_env); - Symbol s0 = shape_env.CreateSymbol(2, MakeShared(0, 0)); - Symbol s1 = shape_env.CreateSymbol(1, MakeShared(0, 1)); - EXPECT_EQ(EXPECT_SYMBOL_NE(s0 + s1, Symbol(3)), false); - EXPECT_EQ(shape_env.GetAllSymbolCheckInfos().size(), 1); - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Eq(s0 + s1, Symbol(3))), true); - - auto exp = s0 + s1; - EXPECT_EQ(exp.Simplify(), exp); // 表达式与常量间不支持replace -} - -TEST_F(UtestExpression, Expect_Simplify_All_Guard_When_Input_Replacement_By_EQ) { - ShapeEnvAttr shape_env; - SetCurShapeEnvContext(&shape_env); - Symbol s0 = shape_env.CreateSymbol(2, MakeShared(0, 0)); - Symbol s1 = shape_env.CreateSymbol(1, MakeShared(0, 1)); - EXPECT_EQ(EXPECT_SYMBOL_EQ(s0 + s1, Symbol(3)), true); - EXPECT_EQ(shape_env.GetAllSymbolCheckInfos().size(), 1); - EXPECT_EQ(EXPECT_SYMBOL_EQ(s0, Symbol(2)), true); - EXPECT_EQ(shape_env.GetAllSymbolCheckInfos().size(), 3); // 新增加replacement会化简第一个guard并插入 - - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Eq(s0 + s1, Symbol(3))), true); - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Eq(s0, Symbol(2))), true); - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Eq(s1, Symbol(1))), true); -} - -TEST_F(UtestExpression, Expect_Simplify_All_Guard_When_Input_Replacement_By_NE) { - ShapeEnvAttr shape_env; - SetCurShapeEnvContext(&shape_env); - Symbol s0 = shape_env.CreateSymbol(2, MakeShared(0, 0)); - Symbol s1 = shape_env.CreateSymbol(1, MakeShared(0, 1)); - EXPECT_EQ(EXPECT_SYMBOL_NE(s0 + s1, Symbol(3)), false); - EXPECT_EQ(shape_env.GetAllSymbolCheckInfos().size(), 1); - EXPECT_EQ(EXPECT_SYMBOL_NE(s0, Symbol(2)), false); - EXPECT_EQ(shape_env.GetAllSymbolCheckInfos().size(), 3); // 新增加replacement会化简第一个guard并插入 - - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Eq(s0 + s1, Symbol(3))), true); - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Eq(s0, Symbol(2))), true); - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Eq(s1, Symbol(1))), true); -} - -TEST_F(UtestExpression, Expect_Not_Simplify_All_Guard_When_Not_Input_Replacement) { - ShapeEnvAttr shape_env; - SetCurShapeEnvContext(&shape_env); - Symbol s0 = shape_env.CreateSymbol(2, MakeShared(0, 0)); - Symbol s1 = shape_env.CreateSymbol(1, MakeShared(0, 1)); - EXPECT_EQ(EXPECT_SYMBOL_NE(s0 + s1, Symbol(3)), false); - EXPECT_EQ(shape_env.GetAllSymbolCheckInfos().size(), 1); - EXPECT_EQ(EXPECT_SYMBOL_NE(s0, Symbol(1)), true); - EXPECT_EQ(shape_env.GetAllSymbolCheckInfos().size(), 2); - - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Eq(s0 + s1, Symbol(3))), true); - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Ne(s0, Symbol(1))), true); -} - -TEST_F(UtestExpression, Expect_Static_EQ_True_When_Input_Replacement_By_EQ) { - ShapeEnvAttr shape_env; - SetCurShapeEnvContext(&shape_env); - Symbol s0 = shape_env.CreateSymbol(2, MakeShared(0, 0)); - Symbol s1 = shape_env.CreateSymbol(1, MakeShared(0, 1)); - EXPECT_EQ(EXPECT_SYMBOL_EQ(s0 + s1, Symbol(3)), true); - EXPECT_EQ(shape_env.GetAllSymbolCheckInfos().size(), 1); - EXPECT_EQ(EXPECT_SYMBOL_EQ(s0, Symbol(2)), true); - EXPECT_EQ(shape_env.GetAllSymbolCheckInfos().size(), 3); // 新增加replacement会化简第一个guard并插入 - - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Eq(s0 + s1, Symbol(3))), true); - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Eq(s0, Symbol(2))), true); - EXPECT_EQ(shape_env.HasSymbolCheckInfo(sym::Eq(s1, Symbol(1))), true); - - EXPECT_EQ(SymbolicUtils::StaticCheckEq(s1, Symbol(1)), TriBool::kTrue); - EXPECT_EQ(shape_env.GetAllSymbolCheckInfos().size(), 3); // 不会再全量化简 -} -} // namespace ge diff --git a/tests/ut/expression/testcase/source_stub.h b/tests/ut/expression/testcase/source_stub.h deleted file mode 100644 index 1047c83c0670f25db1b02d29f7e1849670d96352..0000000000000000000000000000000000000000 --- a/tests/ut/expression/testcase/source_stub.h +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ - -#ifndef SOURCE_STUB_H -#define SOURCE_STUB_H - -#include "graph/symbolizer/symbolic.h" -#include "common/checker.h" - -#include - -namespace ge { -static std::map kGeDType2CppDtype = { - {ge::DT_INT32, "int32_t"}, - {ge::DT_INT64, "int64_t"}, - {ge::DT_UINT32, "uint32_t"}, - {ge::DT_UINT64, "uint64_t"}, -}; -class GraphInputShapeSourceStub : public Source { -public: - GraphInputShapeSourceStub(int32_t input_data_idx, size_t dim_idx) - : input_data_idx_(input_data_idx), dim_idx_(dim_idx) {} - - std::string GetSourceStr() const override { - return R"([&]() -> int64_t { - const auto *tensor = context->GetGraphInputTensor()" + std::to_string(input_data_idx_) + R"(); - if (tensor == nullptr) { - return -1; - } - return tensor->GetOriginShape().GetDim()" + std::to_string(dim_idx_) + R"(); - }())"; - } - ~GraphInputShapeSourceStub() override = default; -private: - int32_t input_data_idx_; // Data的index,描述symbol来自于graph输入中第几个输入data - size_t dim_idx_; // 描述symbol来自于data中对应shape的第几个dim -}; - -class InputValueSumSourceStub : public ge::Source { -public: - InputValueSumSourceStub(int32_t input_data_idx, ge::DataType dtype) - : input_data_idx_(input_data_idx), dtype_(dtype) {} - - [[nodiscard]] std::string GetSourceStr() const override { - if (kGeDType2CppDtype.find(dtype_) == kGeDType2CppDtype.end()) { - GELOGE(FAILED, "Unsupported data type: %s", TypeUtils::DataTypeToSerialString(dtype_).c_str()); - return ""; - } - return R"([&]() -> int64_t { - const auto* tensor = context->GetGraphInputTensor()" + std::to_string(input_data_idx_) + R"(); - if (tensor == nullptr) { - return -1; - } - const auto* data = tensor->GetData<)" + kGeDType2CppDtype[dtype_] + R"(>(); - int64_t sum = 0; - for (size_t i = 0; i < tensor->GetSize() / sizeof()" + kGeDType2CppDtype[dtype_] + R"(); ++i) { - sum += data[i]; - } - return sum; - }())"; - } - ~InputValueSumSourceStub() override = default; - -private: - int32_t input_data_idx_; // Data的index,描述symbol来自于graph输入中第几个输入data - ge::DataType dtype_; // 描述value的数据类型,用于后续执行时取值 -}; -} - -#endif // SOURCE_STUB_H \ No newline at end of file diff --git a/tests/ut/graph/CMakeLists.txt b/tests/ut/graph/CMakeLists.txt deleted file mode 100644 index dd8ec135ae34f968c60f520f91a6090cbb86fae8..0000000000000000000000000000000000000000 --- a/tests/ut/graph/CMakeLists.txt +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright (c) 2024 Huawei Technologies Co., Ltd. -# This file is a part of the CANN Open Software. -# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -# Please refer to the License for details. You may not use this file except in compliance with the License. -# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -# See LICENSE in the root of the software repository for the full text of the License. -# ====================================================================================================================== - -# include directories -include_directories(${CMAKE_CURRENT_LIST_DIR}) -include_directories(${METADEF_DIR}/inc/common/util/trace_manager) -include_directories(${CMAKE_BINARY_DIR}/proto/metadef_protos) -include_directories(${CMAKE_BINARY_DIR}/proto/metadef_protos/proto) -include_directories(${METADEF_DIR}) -include_directories(${METADEF_DIR}/graph) -include_directories(${CMAKE_BINARY_DIR}) -include_directories(${CMAKE_BINARY_DIR}/proto/ge) -include_directories(${CMAKE_BINARY_DIR}/proto/ge/proto) - -file(GLOB_RECURSE UT_FILES CONFIGURE_DEPENDS "${METADEF_DIR}/tests/ut/graph/testcase/*.cc" - "${METADEF_DIR}/tests/ut/base/testcase/ascend_string_unittest.cc" - "${METADEF_DIR}/tests/ut/base/testcase/type_utils_unittest.cc") -file(GLOB_RECURSE FAKE_FILES CONFIGURE_DEPENDS "${METADEF_DIR}/tests/depends/cache_desc_stub/runtime_cache_desc.cc") -file(GLOB_RECURSE UTILS_FILES CONFIGURE_DEPENDS "${METADEF_DIR}/tests/ut/graph/common/*.cc") -file(GLOB_RECURSE FAKER_SRCS CONFIGURE_DEPENDS "${METADEF_DIR}/tests/depends/faker/kernel_run*.cc") -add_executable(ut_graph ${UT_FILES} ${FAKE_FILES} ${UTILS_FILES} ${FAKER_SRCS}) - -target_compile_options(ut_graph PRIVATE - -g --coverage -fprofile-arcs -ftest-coverage - -Wno-deprecated-declarations - -Wall -Wfloat-equal -Werror - -fno-access-control -) - -target_compile_definitions(ut_graph PRIVATE - $<$:ONLY_COMPILE_OPEN_SRC> - google=ascend_private - FUNC_VISIBILITY -) - -target_link_libraries(ut_graph PRIVATE - intf_pub -lgcov - -Wl,--no-as-needed - platform_stub - slog_headers - metadef_headers register opp_registry graph graph_base error_manager aihac_symbolizer runtime_headers - msprof_headers GTest::gtest GTest::gtest_main ascend_protobuf slog_stub c_sec json mmpa_stub -lrt -ldl -) - -target_include_directories(ut_graph PRIVATE - ${METADEF_DIR}/tests/depends - ${METADEF_DIR}/tests/ut/graph - ${METADEF_DIR}/tests/ut/graph/common -) diff --git a/tests/ut/graph/common/graph_builder_utils.cc b/tests/ut/graph/common/graph_builder_utils.cc deleted file mode 100644 index d025b6d05d2ce8d80c37e815118e71eb97a9ea78..0000000000000000000000000000000000000000 --- a/tests/ut/graph/common/graph_builder_utils.cc +++ /dev/null @@ -1,105 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph_builder_utils.h" - -#include "graph/utils/graph_utils.h" - -namespace ge { -namespace ut { - -GeTensorDescPtr GetTensorDesc(const std::string &name, const std::string &type, Format format, DataType data_type, - std::vector shape) { - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape(shape)); - tensor_desc->SetFormat(format); - tensor_desc->SetDataType(data_type); - tensor_desc->SetOriginFormat(format); - tensor_desc->SetOriginShape(GeShape(shape)); - tensor_desc->SetOriginDataType(data_type); - - return tensor_desc; -} - -NodePtr GraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, Format format, - DataType data_type, std::vector shape) { - auto tensor_desc = GetTensorDesc(name, type, format, data_type, shape); - - auto op_desc = std::make_shared(name, type); - for (int i = 0; i < in_cnt; ++i) { - op_desc->AddInputDesc(tensor_desc->Clone()); - } - for (int i = 0; i < out_cnt; ++i) { - op_desc->AddOutputDesc(tensor_desc->Clone()); - } - return graph_->AddNode(op_desc); -} - -void GraphBuilder::AddDataEdge(const NodePtr &src_node, int src_idx, const NodePtr &dst_node, int dst_idx) { - GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx)); -} -void GraphBuilder::AddControlEdge(const NodePtr &src_node, const NodePtr &dst_node) { - GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor()); -} -NodePtr GraphBuilder::AddNode(const string &name, const string &type, std::initializer_list input_names, - std::initializer_list output_names, Format format, DataType data_type, - std::vector shape) { - auto tensor_desc = GetTensorDesc(name, type, format, data_type, shape); - - auto op_desc_ptr = std::make_shared(name, type); - for (auto &input_name : input_names) { - op_desc_ptr->AddInputDesc(input_name, tensor_desc->Clone()); - } - for (auto &output_name : output_names) { - op_desc_ptr->AddOutputDesc(output_name, tensor_desc->Clone()); - } - - return graph_->AddNode(op_desc_ptr); -} - -FastNode *ExecuteGraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, - Format format, DataType data_type, std::vector shape) { - auto tensor_desc = GetTensorDesc(name, type, format, data_type, shape); - - auto op_desc_ptr = std::make_shared(name, type); - for (int i = 0; i < in_cnt; ++i) { - op_desc_ptr->AddInputDesc(tensor_desc->Clone()); - } - for (int i = 0; i < out_cnt; ++i) { - op_desc_ptr->AddOutputDesc(tensor_desc->Clone()); - } - - return graph_->AddNode(op_desc_ptr); -} - -FastEdge *ExecuteGraphBuilder::AddDataEdge(FastNode *src_node, int src_idx, FastNode *dst_node, int dst_idx) { - return graph_->AddEdge(src_node, src_idx, dst_node, dst_idx); -} -FastEdge *ExecuteGraphBuilder::AddControlEdge(FastNode *src_node, FastNode *dst_node) { - return graph_->AddEdge(src_node, -1, dst_node, -1); -} -FastNode *ExecuteGraphBuilder::AddNode(const string &name, const string &type, - std::initializer_list input_names, - std::initializer_list output_names, Format format, DataType data_type, - std::vector shape) { - auto tensor_desc = GetTensorDesc(name, type, format, data_type, shape); - - auto op_desc_ptr = std::make_shared(name, type); - for (auto &input_name : input_names) { - op_desc_ptr->AddInputDesc(input_name, tensor_desc->Clone()); - } - for (auto &output_name : output_names) { - op_desc_ptr->AddOutputDesc(output_name, tensor_desc->Clone()); - } - - return graph_->AddNode(op_desc_ptr); -} - -} // namespace ut -} // namespace ge diff --git a/tests/ut/graph/common/graph_builder_utils.h b/tests/ut/graph/common/graph_builder_utils.h deleted file mode 100644 index ef5bcebc12703be6ea5aa30f3b7b8a441837721d..0000000000000000000000000000000000000000 --- a/tests/ut/graph/common/graph_builder_utils.h +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef MAIN_LLT_FRAMEWORK_DOMI_UT_GE_TEST_GRAPH_PASSES_GRAPH_BUILDER_UTILS_H_ -#define MAIN_LLT_FRAMEWORK_DOMI_UT_GE_TEST_GRAPH_PASSES_GRAPH_BUILDER_UTILS_H_ - -#include -#include - -#include "graph/compute_graph.h" -#include "graph/graph.h" -#include "graph/node.h" -#include "graph/fast_graph/execute_graph.h" - -namespace ge { -namespace ut { -class GraphBuilder { - public: - explicit GraphBuilder(const std::string &name) { - graph_ = std::make_shared(name); - } - NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, - Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT, - std::vector shape = {1, 1, 224, 224}); - - NodePtr AddNode(const std::string &name, const std::string &type, std::initializer_list input_names, - std::initializer_list output_names, Format format = FORMAT_NCHW, - DataType data_type = DT_FLOAT, std::vector shape = {1, 1, 224, 224}); - void AddDataEdge(const NodePtr &src_node, int src_idx, const NodePtr &dst_node, int dst_idx); - void AddControlEdge(const NodePtr &src_node, const NodePtr &dst_node); - ComputeGraphPtr GetGraph() { - graph_->TopologicalSorting(); - return graph_; - } - ComputeGraphPtr GetGraphWithoutSort() { - return graph_; - } - - private: - ComputeGraphPtr graph_; -}; - -class ExecuteGraphBuilder { - public: - explicit ExecuteGraphBuilder(const std::string &name) { - graph_ = std::make_shared(name); - } - FastNode *AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, - Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT, - std::vector shape = {1, 1, 224, 224}); - - FastNode *AddNode(const std::string &name, const std::string &type, std::initializer_list input_names, - std::initializer_list output_names, Format format = FORMAT_NCHW, - DataType data_type = DT_FLOAT, std::vector shape = {1, 1, 224, 224}); - FastEdge *AddDataEdge(FastNode *src_node, int src_idx, FastNode *dst_node, int dst_idx); - FastEdge *AddControlEdge(FastNode *src_node, FastNode *dst_node); - std::shared_ptr GetGraph() { - graph_->TopologicalSorting(); - return graph_; - } - std::shared_ptr GetGraphBeforeTopo() { - return graph_; - } - - private: - std::shared_ptr graph_; -}; -} // namespace ut -} // namespace ge - -#endif // MAIN_LLT_FRAMEWORK_DOMI_UT_GE_TEST_GRAPH_PASSES_GRAPH_BUILDER_UTILS_H_ diff --git a/tests/ut/graph/common/share_graph.cc b/tests/ut/graph/common/share_graph.cc deleted file mode 100644 index 6ea84016a12b84131a06e7ec0d6e517bcd3ff856..0000000000000000000000000000000000000000 --- a/tests/ut/graph/common/share_graph.cc +++ /dev/null @@ -1,101 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "share_graph.h" - -#include "graph/debug/ge_op_types.h" -#include "graph/debug/ge_attr_define.h" - -namespace ge { -template -GraphT SharedGraph::BuildGraphWithSubGraph() { - auto root_builder = CreateBuilder("root"); - const auto &data0 = root_builder.AddNode("data0", "Data", 1, 1); - const auto &case0 = root_builder.AddNode("case0", "Case", 1, 1); - const auto &relu0 = root_builder.AddNode("relu0", "Relu", 1, 1); - const auto &relu1 = root_builder.AddNode("relu1", "Relu", 1, 1); - const auto &netoutput = root_builder.AddNode("netoutput", "NetOutput", 1, 1); - auto root_graph = root_builder.GetGraph(); - root_builder.AddDataEdge(data0, 0, case0, 0); - root_builder.AddDataEdge(case0, 0, relu0, 0); - root_builder.AddDataEdge(relu0, 0, relu1, 0); - root_builder.AddDataEdge(relu1, 0, netoutput, 0); - - auto sub_builder1 = CreateBuilder("sub1"); - (void) sub_builder1.AddNode("data1", "Data", 0, 1); - auto sub_graph1 = sub_builder1.GetGraph(); - root_graph->AddSubGraph(sub_graph1); - sub_graph1->SetParentNode(case0); - BuilderUtils::SetParentGraph(sub_graph1, root_graph); - case0->GetOpDescBarePtr()->AddSubgraphName("branch1"); - case0->GetOpDescBarePtr()->SetSubgraphInstanceName(0, "sub1"); - - auto sub_builder2 = CreateBuilder("sub2"); - (void) sub_builder2.AddNode("data2", "Data", 0, 1); - auto sub_graph2 = sub_builder2.GetGraph(); - root_graph->AddSubGraph(sub_graph2); - sub_graph2->SetParentNode(case0); - BuilderUtils::SetParentGraph(sub_graph2, root_graph); - case0->GetOpDescBarePtr()->AddSubgraphName("branch2"); - case0->GetOpDescBarePtr()->SetSubgraphInstanceName(1, "sub2"); - root_graph->TopologicalSorting(); - return root_graph; -} - -template -GraphT SharedGraph::BuildGraphWithConst() { - auto ge_tensor = std::make_shared(); - uint8_t data_buf[4096] = {0}; - data_buf[0] = 7; - data_buf[10] = 8; - ge_tensor->SetData(data_buf, 4096); - - auto builder = CreateBuilder("graph"); - auto data_node = builder.AddNode("Data", "Data", 0, 1); - auto const_node = builder.AddNode("Const", "Const", 0, 1); - AttrUtils::SetTensor(const_node->GetOpDescBarePtr(), ge::ATTR_NAME_WEIGHTS, ge_tensor); - AttrUtils::SetStr(const_node->GetOpDescBarePtr(), "fake_attr_name", "fake_attr_value"); - auto add_node = builder.AddNode("Add", "Add", 2, 1); - AttrUtils::SetStr(add_node->GetOpDescBarePtr(), "fake_attr_name", "fake_attr_value"); - AttrUtils::SetStr(add_node->GetOpDescBarePtr(), ge::ATTR_NAME_WEIGHTS, "fake_attr_value"); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - builder.AddDataEdge(data_node, 0, add_node, 0); - builder.AddDataEdge(const_node, 0, add_node, 1); - builder.AddDataEdge(add_node, 0, netoutput, 0); - return builder.GetGraph(); -} - -/** n5 - * / | \\c - * n2 n3 n4 - * \ | //c - * n1 - */ -template -GraphT SharedGraph::BuildGraphWithControlEdge() { - auto builder = CreateBuilder("graph_with_ctrl_edge"); - auto n1 = builder.AddNode("n1", "Data", 1, 1); - auto n2 = builder.AddNode("n2", "Op", 1, 1); - auto n3 = builder.AddNode("n3", "Op", 1, 1); - auto n4 = builder.AddNode("n4", "Op", 1, 1); - auto n5 = builder.AddNode("n5", "Netoutput", 3, 1); - builder.AddDataEdge(n1, 0, n2, 0); - builder.AddDataEdge(n1, 0, n3, 0); - builder.AddControlEdge(n1, n4); - builder.AddDataEdge(n2, 0, n5, 0); - builder.AddDataEdge(n3, 0, n5, 1); - builder.AddControlEdge(n4, n5); - builder.AddDataEdge(n1, 0, n4, 0); - builder.AddDataEdge(n4, 0, n5, 2); - return builder.GetGraph(); -} - -template class SharedGraph; -template class SharedGraph; -} // namespace ge diff --git a/tests/ut/graph/common/share_graph.h b/tests/ut/graph/common/share_graph.h deleted file mode 100644 index e6ea17856e148ae438df4a37e3c742d61c269bb9..0000000000000000000000000000000000000000 --- a/tests/ut/graph/common/share_graph.h +++ /dev/null @@ -1,48 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_SHARE_GRAPH_H -#define METADEF_CXX_SHARE_GRAPH_H - -#include "graph_builder_utils.h" - -namespace ge { -template -class SharedGraph { - public: - static GraphT BuildGraphWithSubGraph(); - static GraphT BuildGraphWithConst(); - static GraphT BuildGraphWithControlEdge(); - - private: - static BuilderT CreateBuilder(const std::string &name) { - return BuilderT(name); - } -}; - -template -class BuilderUtils { - public: - static void SetParentGraph(GraphT &graph, GraphT &parent_graph) { - graph->SetParentGraph(parent_graph); - } -}; - -template<> -class BuilderUtils { - public: - // ExecuteGraph requires passing a raw pointer, so explicitly specialize template - static void SetParentGraph(ExecuteGraphPtr &graph, ExecuteGraphPtr &parent_graph) { - graph->SetParentGraph(parent_graph.get()); - } -}; - -} // namespace ge - -#endif //METADEF_CXX_SHARE_GRAPH_H diff --git a/tests/ut/graph/testcase/aging_policy_lru_k_unittest.cc b/tests/ut/graph/testcase/aging_policy_lru_k_unittest.cc deleted file mode 100644 index c5531dbf3b45253c826f27afc1227609e5d7b097..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/aging_policy_lru_k_unittest.cc +++ /dev/null @@ -1,135 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/cache_policy/aging_policy_lru_k.h" -#include "cache_desc_stub/runtime_cache_desc.h" -#include "graph/cache_policy/cache_state.h" - -namespace ge { -namespace { -CacheDescPtr CreateRuntimeCacheDesc(const std::vector &shapes) { - auto cache_desc = std::make_shared(); - cache_desc->SetShapes(shapes); - return cache_desc; -} -void InsertCacheInfoQueue(CacheState &cache_state, uint16_t depth) { - for (uint16_t i = 0; i < depth; ++i) { - int64_t dim_0 = i; - gert::Shape s{dim_0, 256, 256}; - auto cache_desc = CreateRuntimeCacheDesc({s}); - auto hash_key = cache_desc->GetCacheDescHash(); - (void) cache_state.AddCache(hash_key, cache_desc); - } -} -} // namespace -class AgingPolicyLruKUT : public testing::Test {}; - -TEST_F(AgingPolicyLruKUT, IsReadyToAddCache_ReturnFalse_CacheDescNotAppear2Times) { - gert::Shape s1{256, 256}; - const std::vector shapes1{s1}; - auto cache_desc = CreateRuntimeCacheDesc(shapes1); - auto hash = cache_desc->GetCacheDescHash(); - - AgingPolicyLruK ag; - EXPECT_EQ(ag.IsReadyToAddCache(hash, cache_desc), false); -} - -TEST_F(AgingPolicyLruKUT, IsReadyToAddCache_ReturnTrue_CacheDescAppear2Times) { - gert::Shape s1{256, 256}; - const std::vector shapes1{s1}; - auto cache_desc = CreateRuntimeCacheDesc(shapes1); - auto hash = cache_desc->GetCacheDescHash(); - - AgingPolicyLruK ag; - EXPECT_EQ(ag.IsReadyToAddCache(hash, cache_desc), false); - EXPECT_EQ(ag.IsReadyToAddCache(hash, cache_desc), true); -} - -TEST_F(AgingPolicyLruKUT, IsReadyToAddCache_ReturnFalse_CacheDescNotMatched) { - gert::Shape s1{256, 256}; - gert::Shape s2{1, 256, 256}; - const std::vector shapes1{s1}; - const std::vector shapes2{s2}; - auto cache_desc1 = CreateRuntimeCacheDesc(shapes1); - auto cache_desc2 = CreateRuntimeCacheDesc(shapes2); - auto hash2 = cache_desc2->GetCacheDescHash(); - - AgingPolicyLruK ag; - EXPECT_EQ(ag.IsReadyToAddCache(hash2, cache_desc1), false); - EXPECT_EQ(ag.IsReadyToAddCache(hash2, cache_desc2), false); -} - -TEST_F(AgingPolicyLruKUT, IsReadyToAddCache_ReturnTrue_HashCollisionButCacheDescMatched) { - gert::Shape s1{256, 256}; - gert::Shape s2{1, 256, 256}; - const std::vector shapes1{s1}; - const std::vector shapes2{s2}; - auto cache_desc1 = CreateRuntimeCacheDesc(shapes1); - auto cache_desc2 = CreateRuntimeCacheDesc(shapes2); - auto hash2 = cache_desc2->GetCacheDescHash(); - - AgingPolicyLruK ag; - EXPECT_EQ(ag.IsReadyToAddCache(hash2, cache_desc1), false); - EXPECT_EQ(ag.IsReadyToAddCache(hash2, cache_desc2), false); - EXPECT_EQ(ag.IsReadyToAddCache(hash2, cache_desc2), true); -} - -TEST_F(AgingPolicyLruKUT, DoAging_NoAgingId_CacheQueueNotReachDepth) { - CacheState cache_state; - uint16_t depth = 20; - AgingPolicyLruK ag(depth); - - auto delete_ids = ag.DoAging(cache_state); - EXPECT_EQ(delete_ids.size(), 0); - - InsertCacheInfoQueue(cache_state, depth); - delete_ids = ag.DoAging(cache_state); - EXPECT_EQ(delete_ids.size(), 0); -} - -TEST_F(AgingPolicyLruKUT, DoAging_GetAgingIds_CacheQueueOverDepth) { - CacheState cache_state; - AgingPolicyLruK ag(20); - auto delete_ids = ag.DoAging(cache_state); - EXPECT_EQ(delete_ids.size(), 0); - - uint16_t depth = 21; - InsertCacheInfoQueue(cache_state, depth); - delete_ids = ag.DoAging(cache_state); - ASSERT_EQ(delete_ids.size(), 1); - EXPECT_EQ(delete_ids[0], 0); - - depth = 25; - InsertCacheInfoQueue(cache_state, depth); - delete_ids = ag.DoAging(cache_state); - ASSERT_EQ(delete_ids.size(), 1); - EXPECT_EQ(delete_ids[0], 0); -} -TEST_F(AgingPolicyLruKUT, DoAging_Aging5Times_CacheQueueDepthIs25) { - CacheState cache_state; - AgingPolicyLruK ag(20); - auto delete_ids = ag.DoAging(cache_state); - EXPECT_EQ(delete_ids.size(), 0); - - int16_t depth = 25; - InsertCacheInfoQueue(cache_state, depth); - - for (size_t i = 0U; i < static_cast(depth); ++i) { - delete_ids = ag.DoAging(cache_state); - if (i < 5U) { - ASSERT_EQ(delete_ids.size(), 1); - EXPECT_EQ(delete_ids[0], i); - } else { - EXPECT_EQ(delete_ids.size(), 0); - } - cache_state.DelCache(delete_ids); - } -} -} // namespace ge diff --git a/tests/ut/graph/testcase/amy_map_unittest.cc b/tests/ut/graph/testcase/amy_map_unittest.cc deleted file mode 100644 index 8daf92cc95413b6b6f20c79ca92e8c27390614a8..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/amy_map_unittest.cc +++ /dev/null @@ -1,68 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/detail/any_map.h" -#include - -namespace ge { -class AnyMapUt : public testing::Test {}; - -TEST_F(AnyMapUt, GetNamesSuccess) { - AnyMap am; - std::set names; - am.Names(names); - EXPECT_EQ(names.size(), 0); - am.Set("name1", 10); - am.Set("name2", 20); - - am.Names(names); - EXPECT_EQ(names, std::set({"name1", "name2"})); -} - -TEST_F(AnyMapUt, TestHasFuncOk) { - AnyMap am; - EXPECT_FALSE(am.Has("name1")); - EXPECT_FALSE(am.Has("name2")); - - am.Set("name1", 10); - am.Set("name2", 20); - - EXPECT_TRUE(am.Has("name1")); - EXPECT_TRUE(am.Has("name2")); -} - -TEST_F(AnyMapUt, TestSwapFuncOk) { - AnyMap am1, am2; - am1.Set("name1", static_cast(10)); - am2.Set("name2", std::vector({10,20,30,40})); - am1.Swap(am2); - - EXPECT_EQ(am1.Get("name1"), nullptr); - auto p = am1.Get>("name2"); - ASSERT_NE(p, nullptr); - EXPECT_EQ(*am1.Get>("name2"), std::vector({10,20,30,40})); -} - -TEST_F(AnyMapUt, GetOk) { - AnyMap am; - std::vector data = {1,2,3,4,5,6}; - std::vector ret; - - EXPECT_EQ(am.Get>("Test"), nullptr); - EXPECT_FALSE(am.Get("Test", ret)); - - am.Set("Test", data); - ASSERT_NE(am.Get>("Test"), nullptr); - EXPECT_EQ(*am.Get>("Test"), data); - - EXPECT_TRUE(am.Get("Test", ret)); - EXPECT_EQ(ret, data); - -} -} diff --git a/tests/ut/graph/testcase/anchor_unittest.cc b/tests/ut/graph/testcase/anchor_unittest.cc deleted file mode 100644 index 4bb983bb2506922948a12682bb8318edbc1890fa..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/anchor_unittest.cc +++ /dev/null @@ -1,517 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include "test_structs.h" -#include "func_counter.h" -#include "graph/anchor.h" -#include "graph/node.h" -#include "graph_builder_utils.h" -#include "utils/graph_utils.h" - -namespace ge { -namespace { -class SubInDataAnchor : public InDataAnchor -{ -public: - - SubInDataAnchor(const NodePtr& owner_node, const int32_t idx); - - virtual ~SubInDataAnchor(); - - bool EncaEq(const AnchorPtr anchor); - bool EncaIsTypeOf(const TYPE type); - void SetImpNull(); - template - static Anchor::TYPE EncaTypeOf() { - return Anchor::TypeOf(); - } - -}; - -SubInDataAnchor::SubInDataAnchor(const NodePtr &owner_node, const int32_t idx):InDataAnchor(owner_node,idx){} - -SubInDataAnchor::~SubInDataAnchor() = default; - -bool SubInDataAnchor::EncaEq(const AnchorPtr anchor){ - return Equal(anchor); -} - -bool SubInDataAnchor::EncaIsTypeOf(const TYPE type){ - return IsTypeOf(type); -} - -void SubInDataAnchor::SetImpNull(){ - impl_ = nullptr; -} - -class SubOutDataAnchor : public OutDataAnchor -{ -public: - - SubOutDataAnchor(const NodePtr& owner_node, const int32_t idx); - - virtual ~SubOutDataAnchor(); - - bool EncaEq(const AnchorPtr anchor); - bool EncaIsTypeOf(const TYPE type); - void SetImpNull(); - -}; - -SubOutDataAnchor::SubOutDataAnchor(const NodePtr &owner_node, const int32_t idx):OutDataAnchor(owner_node,idx){} - -SubOutDataAnchor::~SubOutDataAnchor() = default; - -bool SubOutDataAnchor::EncaEq(const AnchorPtr anchor){ - return Equal(anchor); -} - -bool SubOutDataAnchor::EncaIsTypeOf(const TYPE type){ - return IsTypeOf(type); -} - -void SubOutDataAnchor::SetImpNull(){ - impl_ = nullptr; -} - -class SubInControlAnchor : public InControlAnchor -{ -public: - - SubInControlAnchor(const NodePtr &owner_node); - SubInControlAnchor(const NodePtr& owner_node, const int32_t idx); - - virtual ~SubInControlAnchor(); - - bool EncaEq(const AnchorPtr anchor); - bool EncaIsTypeOf(const TYPE type); - void SetImpNull(); - -}; - -SubInControlAnchor::SubInControlAnchor(const NodePtr &owner_node):InControlAnchor(owner_node){} -SubInControlAnchor::SubInControlAnchor(const NodePtr &owner_node, const int32_t idx):InControlAnchor(owner_node,idx){} - -SubInControlAnchor::~SubInControlAnchor() = default; - -bool SubInControlAnchor::EncaEq(const AnchorPtr anchor){ - return Equal(anchor); -} - -bool SubInControlAnchor::EncaIsTypeOf(const TYPE type){ - return IsTypeOf(type); -} - -void SubInControlAnchor::SetImpNull(){ - impl_ = nullptr; -} - -class SubOutControlAnchor : public OutControlAnchor -{ -public: - - SubOutControlAnchor(const NodePtr &owner_node); - - SubOutControlAnchor(const NodePtr& owner_node, const int32_t idx); - - virtual ~SubOutControlAnchor(); - - bool EncaEq(const AnchorPtr anchor); - bool EncaIsTypeOf(const TYPE type); - void SetImpNull(); - -}; - -SubOutControlAnchor::SubOutControlAnchor(const NodePtr &owner_node):OutControlAnchor(owner_node){} -SubOutControlAnchor::SubOutControlAnchor(const NodePtr &owner_node, const int32_t idx):OutControlAnchor(owner_node,idx){} - -SubOutControlAnchor::~SubOutControlAnchor() = default; - -bool SubOutControlAnchor::EncaEq(const AnchorPtr anchor){ - return Equal(anchor); -} - -bool SubOutControlAnchor::EncaIsTypeOf(const TYPE type){ - return IsTypeOf(type); -} - -void SubOutControlAnchor::SetImpNull(){ - impl_ = nullptr; -} - -} - -using SubInDataAnchorPtr = std::shared_ptr; -using SubOutDataAnchorPtr = std::shared_ptr; -using SubInControlAnchorPtr = std::shared_ptr; -using SubOutControlAnchorPtr = std::shared_ptr; - -class AnchorUt : public testing::Test {}; - -TEST_F(AnchorUt, SubInDataAnchor) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node = builder.AddNode("Data", "Data", 1, 1); - SubInDataAnchorPtr in_anch = std::make_shared(node, 111); - in_anch->SetIdx(222); - EXPECT_EQ(in_anch->GetIdx(), 222); - EXPECT_EQ(in_anch->EncaEq(Anchor::DynamicAnchorCast(in_anch)), true); - EXPECT_EQ(in_anch->GetPeerAnchorsSize(),0); - EXPECT_EQ(in_anch->GetFirstPeerAnchor(),nullptr); - EXPECT_EQ(in_anch->GetOwnerNode(), node); - EXPECT_EQ(in_anch->GetOwnerNodeBarePtr(), node.get()); - EXPECT_EQ(in_anch->IsLinkedWith(nullptr), false); - EXPECT_EQ(in_anch->GetPeerOutAnchor(), nullptr); - EXPECT_EQ(in_anch->LinkFrom(nullptr), GRAPH_FAILED); - auto node2 = builder.AddNode("Data", "Data", 2, 2); - OutDataAnchorPtr peer = std::make_shared(node2, 22); - EXPECT_EQ(in_anch->LinkFrom(peer), GRAPH_SUCCESS); - EXPECT_EQ(in_anch->IsLinkedWith(peer), true); - EXPECT_EQ(in_anch->GetPeerAnchorsSize(),1); - EXPECT_EQ(in_anch->GetPeerAnchors().size(),1); - EXPECT_NE(in_anch->GetFirstPeerAnchor(),nullptr); - EXPECT_NE(in_anch->GetOwnerNode(),nullptr); - auto node3 = builder.AddNode("Data", "Data", 3, 3); - auto node4 = builder.AddNode("Data", "Data", 4, 4); - OutDataAnchorPtr first = std::make_shared(node4, 44); - SubInDataAnchorPtr second = std::make_shared(node3, 33); - EXPECT_EQ(in_anch->Insert(peer,first,second),GRAPH_SUCCESS); - - auto node5 = builder.AddNode("Data", "Data", 5, 5); - OutDataAnchorPtr oa5 = std::make_shared(node5, 55); - auto node6 = builder.AddNode("Data", "Data", 6, 6); - SubInDataAnchorPtr ia6 = std::make_shared(node, 66); - EXPECT_EQ(ia6->LinkFrom(oa5), GRAPH_SUCCESS); - EXPECT_EQ(ia6->Unlink(oa5), GRAPH_SUCCESS); - - EXPECT_EQ(in_anch->Unlink(nullptr),GRAPH_FAILED); - EXPECT_EQ(in_anch->EncaEq(nullptr),false); - EXPECT_EQ(in_anch->EncaIsTypeOf("nnn"),false); - EXPECT_NE(in_anch->DynamicAnchorCast(in_anch),nullptr); - EXPECT_EQ(in_anch->DynamicAnchorCast(in_anch),nullptr); -} - - -TEST_F(AnchorUt, SubOutDataAnchor) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node = builder.AddNode("Data", "Data", 1, 1); - auto attr = builder.AddNode("Attr", "Attr", 1, 1); - SubOutDataAnchorPtr out_anch = std::make_shared(node, 111); - out_anch->SetIdx(222); - EXPECT_EQ(out_anch->EncaEq(Anchor::DynamicAnchorCast(out_anch)), true); - EXPECT_EQ(out_anch->GetIdx(), 222); - EXPECT_EQ(out_anch->GetPeerAnchorsSize(),0); - EXPECT_EQ(out_anch->GetFirstPeerAnchor(),nullptr); - EXPECT_EQ(out_anch->GetOwnerNode(),node); - EXPECT_EQ(out_anch->IsLinkedWith(nullptr), false); - auto node2 = builder.AddNode("Data", "Data", 2, 2); - InDataAnchorPtr peer = std::make_shared(node2, 22); - EXPECT_EQ(out_anch->LinkTo(peer), GRAPH_SUCCESS); - - auto node3 = builder.AddNode("Data", "Data", 3, 3); - InControlAnchorPtr peerctr = std::make_shared(node3, 33); - EXPECT_EQ(out_anch->LinkTo(peerctr), GRAPH_SUCCESS); - EXPECT_EQ(peerctr->GetPeerOutDataAnchors().size(), 1); - EXPECT_EQ(out_anch->GetPeerAnchorsSize(),2); - EXPECT_EQ(out_anch->GetPeerAnchors().size(),2); - EXPECT_NE(out_anch->GetFirstPeerAnchor(),nullptr); - EXPECT_NE(out_anch->GetOwnerNode(),nullptr); - auto node22 = builder.AddNode("Data", "Data", 22, 22); - SubInDataAnchorPtr peerd2 = std::make_shared(node2, 222); - peerd2->SetImpNull(); - EXPECT_EQ(out_anch->LinkTo(peerd2), GRAPH_FAILED); - auto node33 = builder.AddNode("Data", "Data", 33, 33); - SubInControlAnchorPtr peerctr2 = std::make_shared(node3, 333); - peerctr2->SetImpNull(); - EXPECT_EQ(out_anch->LinkTo(peerctr2), GRAPH_FAILED); - - EXPECT_EQ(out_anch->Unlink(nullptr),GRAPH_FAILED); - EXPECT_EQ(out_anch->EncaEq(nullptr),false); - EXPECT_EQ(out_anch->EncaIsTypeOf("nnn"),false); - out_anch->SetImpNull(); - auto nodelast = builder.AddNode("Data", "Data", 23, 23); - SubInDataAnchorPtr peerd23 = std::make_shared(nodelast, 223); - EXPECT_EQ(out_anch->LinkTo(peerd23), GRAPH_FAILED); - - auto node4 = builder.AddNode("Data4", "Data", 1, 1); - InDataAnchorPtr peerin2 = std::make_shared(node4, 44); - out_anch->impl_ = nullptr; - EXPECT_EQ(out_anch->LinkTo(peerin2), GRAPH_FAILED); -} - - -TEST_F(AnchorUt, SubInControlAnchor) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node0 = builder.AddNode("Data", "Data", 111, 1); - SubInControlAnchorPtr in_canch0 = std::make_shared(node0); - EXPECT_NE(in_canch0, nullptr); - EXPECT_EQ(in_canch0->EncaEq(Anchor::DynamicAnchorCast(in_canch0)), true); - auto node = builder.AddNode("Data", "Data", 1, 1); - SubInControlAnchorPtr inc_anch = std::make_shared(node, 111); - inc_anch->SetIdx(222); - EXPECT_EQ(inc_anch->GetIdx(), 222); - EXPECT_EQ(inc_anch->GetPeerAnchorsSize(),0); - EXPECT_EQ(inc_anch->GetFirstPeerAnchor(),nullptr); - EXPECT_EQ(inc_anch->GetOwnerNode(),node); - EXPECT_EQ(inc_anch->IsLinkedWith(nullptr), false); - EXPECT_EQ(inc_anch->LinkFrom(nullptr), GRAPH_FAILED); - auto node2 = builder.AddNode("Data", "Data", 2, 2); - OutControlAnchorPtr peer = std::make_shared(node2, 22); - EXPECT_EQ(inc_anch->LinkFrom(peer), GRAPH_SUCCESS); - EXPECT_EQ(inc_anch->IsPeerOutAnchorsEmpty(),false); - EXPECT_EQ(inc_anch->GetPeerAnchorsSize(),1); - EXPECT_EQ(inc_anch->GetPeerAnchors().size(),1); - EXPECT_EQ(inc_anch->GetPeerAnchorsPtr().size(),1); - EXPECT_EQ(inc_anch->GetPeerOutDataAnchors().size(), 0); - EXPECT_NE(inc_anch->GetFirstPeerAnchor(),nullptr); - EXPECT_NE(inc_anch->GetOwnerNode(),nullptr); - auto node3 = builder.AddNode("Data", "Data", 3, 3); - SubInControlAnchorPtr second = std::make_shared(node3, 33); - auto node4 = builder.AddNode("Data", "Data", 4, 4); - OutControlAnchorPtr first = std::make_shared(node4, 44); - EXPECT_EQ(inc_anch->Insert(peer,first,second),GRAPH_SUCCESS); - EXPECT_EQ(inc_anch->Unlink(nullptr),GRAPH_FAILED); - EXPECT_EQ(inc_anch->EncaEq(nullptr),false); - EXPECT_EQ(inc_anch->EncaIsTypeOf("nnn"),false); - inc_anch->UnlinkAll(); - auto node24 = builder.AddNode("Data24", "Data", 2, 2); - OutControlAnchorPtr peer24 = std::make_shared(node24, 24); - inc_anch->impl_ = nullptr; - EXPECT_EQ(inc_anch->LinkFrom(peer24), GRAPH_FAILED); -} - - -TEST_F(AnchorUt, SubOutControlAnchor) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node0 = builder.AddNode("Data", "Data", 111, 1); - SubOutControlAnchorPtr out_canch0 = std::make_shared(node0); - EXPECT_NE(out_canch0, nullptr); - EXPECT_EQ(out_canch0->EncaEq(Anchor::DynamicAnchorCast(out_canch0)), true); - auto node = builder.AddNode("Data", "Data", 1, 1); - SubOutControlAnchorPtr outc_anch = std::make_shared(node, 111); - outc_anch->SetIdx(222); - EXPECT_EQ(outc_anch->GetIdx(), 222); - EXPECT_EQ(outc_anch->GetPeerAnchorsSize(),0); - EXPECT_EQ(outc_anch->GetFirstPeerAnchor(),nullptr); - EXPECT_EQ(outc_anch->GetOwnerNode(),node); - EXPECT_EQ(outc_anch->IsLinkedWith(nullptr), false); - auto node2 = builder.AddNode("Data", "Data", 2, 2); - InDataAnchorPtr peer = std::make_shared(node2, 22); - EXPECT_EQ(outc_anch->LinkTo(peer), GRAPH_SUCCESS); - auto node3 = builder.AddNode("Data", "Data", 3, 3); - InControlAnchorPtr peerctr = std::make_shared(node3, 33); - EXPECT_EQ(outc_anch->LinkTo(peerctr), GRAPH_SUCCESS); - EXPECT_EQ(outc_anch->GetPeerAnchorsSize(),2); - EXPECT_EQ(outc_anch->GetPeerAnchors().size(),2); - EXPECT_NE(outc_anch->GetFirstPeerAnchor(),nullptr); - EXPECT_NE(outc_anch->GetOwnerNode(),nullptr); - auto node22 = builder.AddNode("Data", "Data", 22, 22); - SubInDataAnchorPtr peerd2 = std::make_shared(node2, 222); - peerd2->SetImpNull(); - EXPECT_EQ(outc_anch->LinkTo(peerd2), GRAPH_FAILED); - auto node33 = builder.AddNode("Data", "Data", 33, 33); - SubInControlAnchorPtr peerctr2 = std::make_shared(node3, 333); - peerctr2->SetImpNull(); - EXPECT_EQ(outc_anch->LinkTo(peerctr2), GRAPH_FAILED); - EXPECT_NE(outc_anch->GetPeerInControlAnchors().size(), 0); - EXPECT_NE(outc_anch->GetPeerInControlAnchorsPtr().size(), 0); - EXPECT_EQ(outc_anch->GetPeerInControlAnchorsPtr().size(), outc_anch->GetPeerInControlAnchors().size()); - EXPECT_NE(outc_anch->GetPeerInDataAnchors().size(), 0); - - EXPECT_EQ(outc_anch->Unlink(nullptr),GRAPH_FAILED); - EXPECT_EQ(outc_anch->EncaEq(nullptr),false); - EXPECT_EQ(outc_anch->EncaIsTypeOf("nnn"),false); - outc_anch->SetImpNull(); - auto nodelast = builder.AddNode("Data", "Data", 23, 23); - SubInDataAnchorPtr peerd23 = std::make_shared(nodelast, 223); - EXPECT_EQ(outc_anch->LinkTo(peerd23), GRAPH_FAILED); - - auto node4 = builder.AddNode("Data4", "Data", 1, 1); - InControlAnchorPtr peerctr4 = std::make_shared(node4, 44); - outc_anch->impl_ = nullptr; - EXPECT_EQ(outc_anch->LinkTo(peerctr4), GRAPH_FAILED); -} - -// node1 is replaced by node3 -// node0 node0 -// / \ --> / | -// node1 node2 node3 node2 -TEST_F(AnchorUt, CheckReplacePeerOrder) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node0 = builder.AddNode("Data", "Data", 1, 1); - auto node1 = builder.AddNode("Data", "Data", 1, 1); - auto node2 = builder.AddNode("Data", "Data", 1, 1); - auto node3 = builder.AddNode("Data", "Data", 1, 1); - - graphStatus ret = ge::GraphUtils::AddEdge(node0->GetOutDataAnchor(0U), node1->GetInDataAnchor(0)); - EXPECT_EQ(ret, GRAPH_SUCCESS); - ret = ge::GraphUtils::AddEdge(node0->GetOutDataAnchor(0U), node2->GetInDataAnchor(0)); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - ret = node0->GetOutDataAnchor(0U)->ReplacePeer(node1->GetInDataAnchor(0), - node3->GetInDataAnchor(0)); - EXPECT_EQ(ret, GRAPH_SUCCESS); - bool check_same = node0->GetOutDataAnchor(0)->GetFirstPeerAnchor()->Equal(node3->GetInDataAnchor(0)); - EXPECT_TRUE(check_same); - EXPECT_TRUE(node1->GetInDataAnchor(0U)->GetPeerAnchors().empty()); -} - -// node1 is replaced by node4 -// node0 node0 -// / | \ --> / | | -// node1 node2 node3 node4 node2 node3 -TEST_F(AnchorUt, ControlAnchorReplacePeer) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node0 = builder.AddNode("Data", "Data", 1, 1); - auto node1 = builder.AddNode("Data", "Data", 1, 1); - auto node2 = builder.AddNode("Data", "Data", 1, 1); - auto node3 = builder.AddNode("Data", "Data", 1, 1); - auto node4 = builder.AddNode("Data", "Data", 1, 1); - (void)ge::GraphUtils::AddEdge(node0->GetOutControlAnchor(), node1->GetInControlAnchor()); - (void)ge::GraphUtils::AddEdge(node0->GetOutControlAnchor(), node2->GetInControlAnchor()); - (void) ge::GraphUtils::AddEdge(node0->GetOutControlAnchor(), node3->GetInControlAnchor()); - EXPECT_EQ(node1->GetInControlAnchor()->GetPeerOutControlAnchors().size(), - node1->GetInControlAnchor()->GetPeerOutControlAnchorsPtr().size()); - graphStatus ret = node0->GetOutControlAnchor()->ReplacePeer(node1->GetInControlAnchor(), node4->GetInControlAnchor()); - EXPECT_EQ(ret, GRAPH_SUCCESS); - bool check_same = node0->GetOutControlAnchor()->GetFirstPeerAnchor()->Equal(node4->GetInControlAnchor()); - EXPECT_TRUE(check_same); - EXPECT_TRUE(node1->GetInControlAnchor()->GetPeerAnchors().empty()); -} - -TEST_F(AnchorUt, ReplacePeerDifferentTypeFailed) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node0 = builder.AddNode("Data", "Data", 1, 1); - auto node1 = builder.AddNode("Data", "Data", 1, 1); - auto node2 = builder.AddNode("Data", "Data", 1, 1); - - graphStatus ret = ge::GraphUtils::AddEdge(node0->GetOutControlAnchor(), node1->GetInControlAnchor()); - EXPECT_EQ(ret, GRAPH_SUCCESS); - ret = node0->GetOutControlAnchor()->ReplacePeer(node1->GetInControlAnchor(), - node2->GetInDataAnchor(0U)); - EXPECT_EQ(ret, GRAPH_FAILED); - ret = node0->GetOutControlAnchor()->ReplacePeer(node1->GetInDataAnchor(0U), - node2->GetInDataAnchor(0U)); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -// node0 is replaced by node2 -// node0 node2 -// / --> / -// node1 node1 -TEST_F(AnchorUt, ReplacePeerOfOutDataAnchor) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node0 = builder.AddNode("Data", "Data", 1, 1); - auto node1 = builder.AddNode("Data", "Data", 1, 1); - auto node2 = builder.AddNode("Data", "Data", 1, 1); - - graphStatus ret = ge::GraphUtils::AddEdge(node0->GetOutDataAnchor(0U), node1->GetInDataAnchor(0)); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(node0->GetOutDataAnchor(0)->GetPeerInDataAnchors().size(), - node0->GetOutDataAnchor(0)->GetPeerInDataAnchorsPtr().size()); - ret = node1->GetInDataAnchor(0U)->ReplacePeer(node0->GetOutAnchor(0U), node2->GetOutAnchor(0U)); - EXPECT_EQ(ret, GRAPH_SUCCESS); - bool check_same = node1->GetInDataAnchor(0U)->GetFirstPeerAnchor()->Equal(node2->GetOutAnchor(0U)); - EXPECT_TRUE(check_same); -} - -TEST_F(AnchorUt, CheckReplaceNewAnchorPeerNotEmpty) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node0 = builder.AddNode("Data", "Data", 1, 1); - auto node1 = builder.AddNode("Data", "Data", 1, 1); - auto node2 = builder.AddNode("Data", "Data", 1, 1); - auto node3 = builder.AddNode("Data", "Data", 1, 1); - - graphStatus ret = ge::GraphUtils::AddEdge(node0->GetOutDataAnchor(0U), node1->GetInDataAnchor(0)); - EXPECT_EQ(ret, GRAPH_SUCCESS); - ret = ge::GraphUtils::AddEdge(node2->GetOutDataAnchor(0U), node3->GetInDataAnchor(0)); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - ret = node0->GetOutDataAnchor(0U)->ReplacePeer(node1->GetInDataAnchor(0), - node3->GetInDataAnchor(0)); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(AnchorUt, InsertNotEmptyNodeFailed) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node0 = builder.AddNode("Data", "Data", 1, 1); - auto node1 = builder.AddNode("Data", "Data", 1, 1); - auto node2 = builder.AddNode("Data", "Data", 1, 1); - auto node3 = builder.AddNode("Data", "Data", 1, 1); - auto node4 = builder.AddNode("Data", "Data", 1, 1); - - graphStatus ret = ge::GraphUtils::AddEdge(node0->GetOutDataAnchor(0U), node2->GetInDataAnchor(0)); - EXPECT_EQ(ret, GRAPH_SUCCESS); - ret = ge::GraphUtils::AddEdge(node2->GetOutDataAnchor(0U), node1->GetInDataAnchor(0)); - EXPECT_EQ(ret, GRAPH_SUCCESS); - ret = ge::GraphUtils::AddEdge(node1->GetOutDataAnchor(0U), node4->GetInDataAnchor(0)); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - ret = node0->GetOutDataAnchor(0U)->Insert(node2->GetInDataAnchor(0), - node1->GetInDataAnchor(0), - node1->GetOutDataAnchor(0)); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(AnchorUt, InsertDifferentAnchorFailed) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node0 = builder.AddNode("Data", "Data", 1, 1); - auto node1 = builder.AddNode("Data", "Data", 1, 1); - auto node2 = builder.AddNode("Data", "Data", 1, 1); - - graphStatus ret = ge::GraphUtils::AddEdge(node0->GetOutDataAnchor(0U), node2->GetInDataAnchor(0)); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - ret = node0->GetOutDataAnchor(0U)->Insert(node2->GetInDataAnchor(0), - node1->GetInControlAnchor(), - node1->GetOutDataAnchor(0)); - EXPECT_EQ(ret, GRAPH_FAILED); - ret = node0->GetOutDataAnchor(0U)->Insert(node2->GetInControlAnchor(), - node1->GetInDataAnchor(0), - node1->GetOutDataAnchor(0)); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -// node0 -- node2 -// 插入新节点 -// node0--node1-- node2 -TEST_F(AnchorUt, InsertSuccess) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node0 = builder.AddNode("Data", "Data", 1, 1); - auto node1 = builder.AddNode("Data", "Data", 1, 1); - auto node2 = builder.AddNode("Data", "Data", 1, 1); - - graphStatus ret = ge::GraphUtils::AddEdge(node0->GetOutDataAnchor(0U), node2->GetInDataAnchor(0U)); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - ret = node0->GetOutDataAnchor(0U)->Insert(node2->GetInDataAnchor(0U), - node1->GetInDataAnchor(0U), - node1->GetOutDataAnchor(0U)); - EXPECT_EQ(ret, GRAPH_SUCCESS); - bool the_same = node0->GetOutDataAnchor(0U)->GetFirstPeerAnchor()->Equal(node1->GetInDataAnchor(0U)); - EXPECT_TRUE(the_same); - the_same = node2->GetInDataAnchor(0U)->GetFirstPeerAnchor()->Equal(node1->GetOutDataAnchor(0U)); - EXPECT_TRUE(the_same); - - auto node00 = builder.AddNode("Data", "Data", 1, 1); - auto node11 = builder.AddNode("Data", "Data", 1, 1); - auto node22 = builder.AddNode("Data", "Data", 1, 1); - ret = ge::GraphUtils::AddEdge(node00->GetOutDataAnchor(0U), node22->GetInDataAnchor(0U)); - EXPECT_EQ(ret, GRAPH_SUCCESS); - ret = node22->GetInDataAnchor(0U)->Insert(node00->GetOutDataAnchor(0U), - node11->GetOutDataAnchor(0), - node11->GetInDataAnchor(0)); - EXPECT_EQ(ret, GRAPH_SUCCESS); - the_same = node00->GetOutDataAnchor(0U)->GetFirstPeerAnchor()->Equal(node11->GetInDataAnchor(0U)); - EXPECT_TRUE(the_same); - the_same = node22->GetInDataAnchor(0U)->GetFirstPeerAnchor()->Equal(node11->GetOutDataAnchor(0U)); - EXPECT_TRUE(the_same); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/anchor_utils_unittest.cc b/tests/ut/graph/testcase/anchor_utils_unittest.cc deleted file mode 100644 index 518c9e2c63ea02fa75c01b3349fedcb73ced92a3..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/anchor_utils_unittest.cc +++ /dev/null @@ -1,81 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include "test_structs.h" -#include "func_counter.h" -#include "graph/anchor.h" -#include "graph/utils/anchor_utils.h" -#include "graph/node.h" -#include "graph_builder_utils.h" - -namespace ge { -namespace { -class SubAnchor : public Anchor -{ -public: - - SubAnchor(const NodePtr& owner_node, const int32_t idx); - - virtual ~SubAnchor(); - - virtual bool Equal(const AnchorPtr anchor) const; - -}; - -SubAnchor::SubAnchor(const NodePtr &owner_node, const int32_t idx):Anchor(owner_node,idx){} - -SubAnchor::~SubAnchor() = default; - -bool SubAnchor::Equal(const AnchorPtr anchor) const{ - return true; -} -} - -using SubAnchorPtr = std::shared_ptr; - -class AnchorUtilsUt : public testing::Test {}; - -TEST_F(AnchorUtilsUt, GetStatus) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node = builder.AddNode("Data", "Data", 1, 1); - InDataAnchorPtr inanch = std::make_shared(node, 111); - EXPECT_NE(AnchorUtils::GetStatus(inanch), ANCHOR_RESERVED); - EXPECT_EQ(AnchorUtils::GetStatus(inanch), ANCHOR_SUSPEND); - EXPECT_EQ(AnchorUtils::GetStatus(nullptr), ANCHOR_RESERVED); -} - -TEST_F(AnchorUtilsUt, SetStatus) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node = builder.AddNode("Data", "Data", 1, 1); - InDataAnchorPtr inanch = std::make_shared(node, 111); - EXPECT_EQ(AnchorUtils::SetStatus(inanch, ANCHOR_DATA), GRAPH_SUCCESS); - EXPECT_EQ(AnchorUtils::SetStatus(inanch, ANCHOR_RESERVED), GRAPH_FAILED); - EXPECT_EQ(AnchorUtils::SetStatus(nullptr, ANCHOR_DATA), GRAPH_FAILED); -} - - -TEST_F(AnchorUtilsUt, GetIdx) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node = builder.AddNode("Data", "Data", 1, 1); - InDataAnchorPtr inanch = std::make_shared(node, 111); - EXPECT_EQ(AnchorUtils::GetIdx(inanch), 111); - - ut::GraphBuilder builder2 = ut::GraphBuilder("graph"); - auto node2 = builder2.AddNode("Data", "Data", 2, 2); - OutControlAnchorPtr outanch = std::make_shared(node2, 22); - EXPECT_EQ(AnchorUtils::GetIdx(outanch), 22); - - SubAnchorPtr sanch = std::make_shared(node, 444); - EXPECT_EQ(AnchorUtils::GetIdx(sanch), -1); - -} - -} // namespace ge diff --git a/tests/ut/graph/testcase/arg_desc_info_unittest.cc b/tests/ut/graph/testcase/arg_desc_info_unittest.cc deleted file mode 100644 index 0ef4c8ccdba9ff2586bfdc7e3d93cce058f77aa2..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/arg_desc_info_unittest.cc +++ /dev/null @@ -1,310 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/arg_desc_info.h" -#include "graph/utils/args_format_desc_utils.h" -#include "ge_common/ge_api_error_codes.h" - - -namespace ge { -class TestArgDescInfo : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -// 验证ArgDescInfo的反序列化能力 -TEST_F(TestArgDescInfo, ArgDescInfoDeserialize) { - std::vector args; - ArgsFormatDescUtils::InsertCustomValue(args, 0, 2); - ArgsFormatDescUtils::InsertHiddenInputs(args, 1, HiddenInputsType::HCOM, 2); - ArgsFormatDescUtils::Append(args, AddrType::INPUT, 0); - ArgsFormatDescUtils::Append(args, AddrType::INPUT_DESC, 1, true); - ArgsFormatDescUtils::Append(args, AddrType::OUTPUT, 0); - ArgsFormatDescUtils::Append(args, AddrType::OUTPUT_DESC, 1, true); - ArgsFormatDescUtils::Append(args, AddrType::WORKSPACE); - ArgsFormatDescUtils::Append(args, AddrType::TILING); - auto args_format_str = ArgsFormatDescUtils::Serialize(args); - auto args_infos = ArgsFormatSerializer::Deserialize(args_format_str.c_str()); - EXPECT_EQ(args_infos.size(), 9); - EXPECT_EQ(args_infos[0].GetType(), ArgDescType::kCustomValue); - EXPECT_EQ(args_infos[0].GetIrIndex(), -1); - EXPECT_EQ(args_infos[0].GetCustomValue(), 2); - EXPECT_EQ(args_infos[0].GetHiddenInputSubType(), HiddenInputSubType::kEnd); - EXPECT_EQ(args_infos[0].IsFolded(), false); - - EXPECT_EQ(args_infos[1].GetType(), ArgDescType::kHiddenInput); - EXPECT_EQ(args_infos[1].GetIrIndex(), -1); - EXPECT_EQ(args_infos[1].GetCustomValue(), 0); - EXPECT_EQ(args_infos[1].GetHiddenInputSubType(), HiddenInputSubType::kHcom); - EXPECT_EQ(args_infos[1].IsFolded(), false); - - EXPECT_EQ(args_infos[2].GetType(), ArgDescType::kHiddenInput); - EXPECT_EQ(args_infos[2].GetIrIndex(), -1); - EXPECT_EQ(args_infos[2].GetCustomValue(), 0); - EXPECT_EQ(args_infos[2].GetHiddenInputSubType(), HiddenInputSubType::kHcom); - EXPECT_EQ(args_infos[2].IsFolded(), false); - - EXPECT_EQ(args_infos[3].GetType(), ArgDescType::kIrInput); - EXPECT_EQ(args_infos[3].GetIrIndex(), 0); - EXPECT_EQ(args_infos[3].GetCustomValue(), 0); - EXPECT_EQ(args_infos[3].GetHiddenInputSubType(), HiddenInputSubType::kEnd); - EXPECT_EQ(args_infos[3].IsFolded(), false); - - EXPECT_EQ(args_infos[4].GetType(), ArgDescType::kIrInputDesc); - EXPECT_EQ(args_infos[4].GetIrIndex(), 1); - EXPECT_EQ(args_infos[4].GetCustomValue(), 0); - EXPECT_EQ(args_infos[4].GetHiddenInputSubType(), HiddenInputSubType::kEnd); - EXPECT_EQ(args_infos[4].IsFolded(), true); - - EXPECT_EQ(args_infos[5].GetType(), ArgDescType::kIrOutput); - EXPECT_EQ(args_infos[5].GetIrIndex(), 0); - EXPECT_EQ(args_infos[5].GetCustomValue(), 0); - EXPECT_EQ(args_infos[5].GetHiddenInputSubType(), HiddenInputSubType::kEnd); - EXPECT_EQ(args_infos[5].IsFolded(), false); - - EXPECT_EQ(args_infos[6].GetType(), ArgDescType::kIrOutputDesc); - EXPECT_EQ(args_infos[6].GetIrIndex(), 1); - EXPECT_EQ(args_infos[6].GetCustomValue(), 0); - EXPECT_EQ(args_infos[6].GetHiddenInputSubType(), HiddenInputSubType::kEnd); - EXPECT_EQ(args_infos[6].IsFolded(), true); - - EXPECT_EQ(args_infos[7].GetType(), ArgDescType::kWorkspace); - EXPECT_EQ(args_infos[7].GetIrIndex(), -1); - EXPECT_EQ(args_infos[7].GetCustomValue(), 0); - EXPECT_EQ(args_infos[7].GetHiddenInputSubType(), HiddenInputSubType::kEnd); - EXPECT_EQ(args_infos[7].IsFolded(), false); - - EXPECT_EQ(args_infos[8].GetType(), ArgDescType::kTiling); - EXPECT_EQ(args_infos[8].GetIrIndex(), -1); - EXPECT_EQ(args_infos[8].GetCustomValue(), 0); - EXPECT_EQ(args_infos[8].GetHiddenInputSubType(), HiddenInputSubType::kEnd); - EXPECT_EQ(args_infos[8].IsFolded(), false); -} - -// 验证ArgDescInfo的序列化能力 -TEST_F(TestArgDescInfo, ArgDescInfoSerialize) { - std::vector args_desc; - auto custom_value_arg = ArgDescInfo(ArgDescType::kCustomValue); - EXPECT_EQ(custom_value_arg.SetCustomValue(2), SUCCESS); - args_desc.emplace_back(custom_value_arg); - args_desc.emplace_back(ArgDescInfo::CreateCustomValue(0)); - auto hidden_input_arg = ArgDescInfo(ArgDescType::kHiddenInput, 0); - EXPECT_EQ(hidden_input_arg.SetHiddenInputSubType(HiddenInputSubType::kHcom), SUCCESS); - args_desc.emplace_back(hidden_input_arg); - args_desc.emplace_back(ArgDescInfo::CreateHiddenInput(HiddenInputSubType::kHcom)); - args_desc.emplace_back(ArgDescInfo(ArgDescType::kIrInput, 0)); - auto input_desc_arg = ArgDescInfo(ArgDescType::kIrInputDesc, 1); - input_desc_arg.SetFolded(true); - args_desc.emplace_back(input_desc_arg); - args_desc.emplace_back(ArgDescInfo(ArgDescType::kIrOutput, 0)); - args_desc.emplace_back(ArgDescInfo(ArgDescType::kIrOutputDesc, 1, true)); - args_desc.emplace_back(ArgDescInfo(ArgDescType::kWorkspace)); - args_desc.emplace_back(ArgDescInfo(ArgDescType::kTiling)); - auto args_format_str = ArgsFormatSerializer::Serialize(args_desc); - EXPECT_EQ(std::string(args_format_str.GetString()), "{#2}{#0}{hi.hcom0*}{hi.hcom1*}{i0*}{i_desc1}{o0*}{o_desc1}{ws*}{t}"); -} - -// 验证ArgDescInfo的反序列化能力,args序列中有InputInstance和OutputInstance -TEST_F(TestArgDescInfo, ArgDescInfoWithInstanceDeserialize) { - std::vector args; - ArgsFormatDescUtils::InsertCustomValue(args, 0, 2); - ArgsFormatDescUtils::InsertHiddenInputs(args, 1, HiddenInputsType::HCOM, 2); - ArgsFormatDescUtils::Append(args, AddrType::INPUT_INSTANCE, 0); - ArgsFormatDescUtils::Append(args, AddrType::INPUT_INSTANCE, 1); - ArgsFormatDescUtils::Append(args, AddrType::OUTPUT_INSTANCE, 0); - ArgsFormatDescUtils::Append(args, AddrType::OUTPUT_INSTANCE, 1); - ArgsFormatDescUtils::Append(args, AddrType::WORKSPACE); - ArgsFormatDescUtils::Append(args, AddrType::TILING); - auto args_format_str = ArgsFormatDescUtils::Serialize(args); - auto args_infos = ArgsFormatSerializer::Deserialize(args_format_str.c_str()); - EXPECT_EQ(args_infos.size(), 9); - EXPECT_EQ(args_infos[0].GetType(), ArgDescType::kCustomValue); - EXPECT_EQ(args_infos[0].GetIrIndex(), -1); - EXPECT_EQ(args_infos[0].GetCustomValue(), 2); - EXPECT_EQ(args_infos[0].GetHiddenInputSubType(), HiddenInputSubType::kEnd); - EXPECT_EQ(args_infos[0].IsFolded(), false); - - EXPECT_EQ(args_infos[1].GetType(), ArgDescType::kHiddenInput); - EXPECT_EQ(args_infos[1].GetIrIndex(), -1); - EXPECT_EQ(args_infos[1].GetCustomValue(), 0); - EXPECT_EQ(args_infos[1].GetHiddenInputSubType(), HiddenInputSubType::kHcom); - EXPECT_EQ(args_infos[1].IsFolded(), false); - - EXPECT_EQ(args_infos[2].GetType(), ArgDescType::kHiddenInput); - EXPECT_EQ(args_infos[2].GetIrIndex(), -1); - EXPECT_EQ(args_infos[2].GetCustomValue(), 0); - EXPECT_EQ(args_infos[2].GetHiddenInputSubType(), HiddenInputSubType::kHcom); - EXPECT_EQ(args_infos[2].IsFolded(), false); - - EXPECT_EQ(args_infos[3].GetType(), ArgDescType::kInputInstance); - EXPECT_EQ(args_infos[3].GetIrIndex(), -1); - EXPECT_EQ(args_infos[3].GetCustomValue(), 0); - EXPECT_EQ(args_infos[3].GetHiddenInputSubType(), HiddenInputSubType::kEnd); - EXPECT_EQ(args_infos[3].IsFolded(), false); - - EXPECT_EQ(args_infos[4].GetType(), ArgDescType::kInputInstance); - EXPECT_EQ(args_infos[4].GetIrIndex(), -1); - EXPECT_EQ(args_infos[4].GetCustomValue(), 0); - EXPECT_EQ(args_infos[4].GetHiddenInputSubType(), HiddenInputSubType::kEnd); - EXPECT_EQ(args_infos[4].IsFolded(), false); - - EXPECT_EQ(args_infos[5].GetType(), ArgDescType::kOutputInstance); - EXPECT_EQ(args_infos[5].GetIrIndex(), -1); - EXPECT_EQ(args_infos[5].GetCustomValue(), 0); - EXPECT_EQ(args_infos[5].GetHiddenInputSubType(), HiddenInputSubType::kEnd); - EXPECT_EQ(args_infos[5].IsFolded(), false); - - EXPECT_EQ(args_infos[6].GetType(), ArgDescType::kOutputInstance); - EXPECT_EQ(args_infos[6].GetIrIndex(), -1); - EXPECT_EQ(args_infos[6].GetCustomValue(), 0); - EXPECT_EQ(args_infos[6].GetHiddenInputSubType(), HiddenInputSubType::kEnd); - EXPECT_EQ(args_infos[6].IsFolded(), false); - - EXPECT_EQ(args_infos[7].GetType(), ArgDescType::kWorkspace); - EXPECT_EQ(args_infos[7].GetIrIndex(), -1); - EXPECT_EQ(args_infos[7].GetCustomValue(), 0); - EXPECT_EQ(args_infos[7].GetHiddenInputSubType(), HiddenInputSubType::kEnd); - EXPECT_EQ(args_infos[7].IsFolded(), false); - - EXPECT_EQ(args_infos[8].GetType(), ArgDescType::kTiling); - EXPECT_EQ(args_infos[8].GetIrIndex(), -1); - EXPECT_EQ(args_infos[8].GetCustomValue(), 0); - EXPECT_EQ(args_infos[8].GetHiddenInputSubType(), HiddenInputSubType::kEnd); - EXPECT_EQ(args_infos[8].IsFolded(), false); -} - -// 验证ArgDescInfo的序列化能力,args中存在InputInstance和OutputInstance -TEST_F(TestArgDescInfo, ArgDescInfoWithInstanceSerialize) { - std::vector args_desc; - auto custom_value_arg = ArgDescInfo(ArgDescType::kCustomValue); - EXPECT_EQ(custom_value_arg.SetCustomValue(2), SUCCESS); - args_desc.emplace_back(custom_value_arg); - args_desc.emplace_back(ArgDescInfo::CreateCustomValue(0)); - auto hidden_input_arg = ArgDescInfo(ArgDescType::kHiddenInput, 0); - EXPECT_EQ(hidden_input_arg.SetHiddenInputSubType(HiddenInputSubType::kHcom), SUCCESS); - args_desc.emplace_back(hidden_input_arg); - args_desc.emplace_back(ArgDescInfo::CreateHiddenInput(HiddenInputSubType::kHcom)); - args_desc.emplace_back(ArgDescInfo(ArgDescType::kInputInstance)); - args_desc.emplace_back(ArgDescInfo(ArgDescType::kInputInstance)); - args_desc.emplace_back(ArgDescInfo(ArgDescType::kOutputInstance)); - args_desc.emplace_back(ArgDescInfo(ArgDescType::kOutputInstance)); - args_desc.emplace_back(ArgDescInfo(ArgDescType::kWorkspace)); - args_desc.emplace_back(ArgDescInfo(ArgDescType::kTiling)); - auto args_format_str = ArgsFormatSerializer::Serialize(args_desc); - EXPECT_EQ(std::string(args_format_str.GetString()), - "{#2}{#0}{hi.hcom0*}{hi.hcom1*}{i_instance0*}{i_instance1*}{o_instance0*}{o_instance1*}{ws*}{t}"); -} - -// 验证ArgDescInfo的拷贝赋值函数 -TEST_F(TestArgDescInfo, ArgDescInfoCopyAssignFunc) { - ArgDescInfo args_desc_1(ArgDescType::kIrInput, 0, true); - args_desc_1 = ArgDescInfo(ArgDescType::kIrInput, 1, true); - EXPECT_EQ(args_desc_1.GetType(), ArgDescType::kIrInput); - EXPECT_EQ(args_desc_1.GetIrIndex(), 1); - EXPECT_EQ(args_desc_1.IsFolded(), true); - EXPECT_EQ(args_desc_1.GetCustomValue(), 0); - EXPECT_EQ(args_desc_1.GetHiddenInputSubType(), HiddenInputSubType::kEnd); - - args_desc_1 = ArgDescInfo::CreateCustomValue(2); - EXPECT_EQ(args_desc_1.GetType(), ArgDescType::kCustomValue); - EXPECT_EQ(args_desc_1.GetIrIndex(), -1); - EXPECT_EQ(args_desc_1.IsFolded(), false); - EXPECT_EQ(args_desc_1.GetCustomValue(), 2); - EXPECT_EQ(args_desc_1.GetHiddenInputSubType(), HiddenInputSubType::kEnd); - - args_desc_1 = ArgDescInfo::CreateHiddenInput(HiddenInputSubType::kHcom); - EXPECT_EQ(args_desc_1.GetType(), ArgDescType::kHiddenInput); - EXPECT_EQ(args_desc_1.GetIrIndex(), -1); - EXPECT_EQ(args_desc_1.IsFolded(), false); - EXPECT_EQ(args_desc_1.GetCustomValue(), 0); - EXPECT_EQ(args_desc_1.GetHiddenInputSubType(), HiddenInputSubType::kHcom); -} - -// 验证ArgDescInfo的拷贝构造函数 -TEST_F(TestArgDescInfo, ArgDescInfoCopyConstructFunc) { - ArgDescInfo args_desc_input_desc(ArgDescType::kIrInput, 1, true); - ArgDescInfo args_desc_1(args_desc_input_desc); - EXPECT_EQ(args_desc_1.GetType(), ArgDescType::kIrInput); - EXPECT_EQ(args_desc_1.GetIrIndex(), 1); - EXPECT_EQ(args_desc_1.IsFolded(), true); - EXPECT_EQ(args_desc_1.GetCustomValue(), 0); - EXPECT_EQ(args_desc_1.GetHiddenInputSubType(), HiddenInputSubType::kEnd); - ArgDescInfo args_desc_custom_value = ArgDescInfo::CreateCustomValue(2); - ArgDescInfo args_desc_2(args_desc_custom_value); - EXPECT_EQ(args_desc_2.GetType(), ArgDescType::kCustomValue); - EXPECT_EQ(args_desc_2.GetIrIndex(), -1); - EXPECT_EQ(args_desc_2.IsFolded(), false); - EXPECT_EQ(args_desc_2.GetCustomValue(), 2); - EXPECT_EQ(args_desc_2.GetHiddenInputSubType(), HiddenInputSubType::kEnd); - - auto args_desc_hidden_input = ArgDescInfo::CreateHiddenInput(HiddenInputSubType::kHcom); - ArgDescInfo args_desc_3(args_desc_hidden_input); - EXPECT_EQ(args_desc_3.GetType(), ArgDescType::kHiddenInput); - EXPECT_EQ(args_desc_3.GetIrIndex(), -1); - EXPECT_EQ(args_desc_3.IsFolded(), false); - EXPECT_EQ(args_desc_3.GetCustomValue(), 0); - EXPECT_EQ(args_desc_3.GetHiddenInputSubType(), HiddenInputSubType::kHcom); -} - -// 验证ArgDescInfo的移动赋值函数 -TEST_F(TestArgDescInfo, ArgDescInfoMoveAssignFunc) { - ArgDescInfo args_desc_1(ArgDescType::kIrInput, 0, true); - auto args_desc_input_desc = ArgDescInfo(ArgDescType::kIrInput, 1, true); - args_desc_1 = std::move(args_desc_input_desc); - EXPECT_EQ(args_desc_1.GetType(), ArgDescType::kIrInput); - EXPECT_EQ(args_desc_1.GetIrIndex(), 1); - EXPECT_EQ(args_desc_1.IsFolded(), true); - EXPECT_EQ(args_desc_1.GetCustomValue(), 0); - EXPECT_EQ(args_desc_1.GetHiddenInputSubType(), HiddenInputSubType::kEnd); - - auto args_desc_custom_value = ArgDescInfo::CreateCustomValue(2); - args_desc_1 = std::move(args_desc_custom_value); - EXPECT_EQ(args_desc_1.GetType(), ArgDescType::kCustomValue); - EXPECT_EQ(args_desc_1.GetIrIndex(), -1); - EXPECT_EQ(args_desc_1.IsFolded(), false); - EXPECT_EQ(args_desc_1.GetCustomValue(), 2); - EXPECT_EQ(args_desc_1.GetHiddenInputSubType(), HiddenInputSubType::kEnd); - - auto args_desc_hidden_input = ArgDescInfo::CreateHiddenInput(HiddenInputSubType::kHcom); - args_desc_1 = std::move(args_desc_hidden_input); - EXPECT_EQ(args_desc_1.GetType(), ArgDescType::kHiddenInput); - EXPECT_EQ(args_desc_1.GetIrIndex(), -1); - EXPECT_EQ(args_desc_1.IsFolded(), false); - EXPECT_EQ(args_desc_1.GetCustomValue(), 0); - EXPECT_EQ(args_desc_1.GetHiddenInputSubType(), HiddenInputSubType::kHcom); -} - -// 验证ArgDescInfo的移动构造函数 -TEST_F(TestArgDescInfo, ArgDescInfoMoveConstructFunc) { - ArgDescInfo args_desc_input_desc(ArgDescType::kIrInput, 1, true); - ArgDescInfo args_desc_1(std::move(args_desc_input_desc)); - EXPECT_EQ(args_desc_1.GetType(), ArgDescType::kIrInput); - EXPECT_EQ(args_desc_1.GetIrIndex(), 1); - EXPECT_EQ(args_desc_1.IsFolded(), true); - EXPECT_EQ(args_desc_1.GetCustomValue(), 0); - EXPECT_EQ(args_desc_1.GetHiddenInputSubType(), HiddenInputSubType::kEnd); - ArgDescInfo args_desc_custom_value = ArgDescInfo::CreateCustomValue(2); - ArgDescInfo args_desc_2(std::move(args_desc_custom_value)); - EXPECT_EQ(args_desc_2.GetType(), ArgDescType::kCustomValue); - EXPECT_EQ(args_desc_2.GetIrIndex(), -1); - EXPECT_EQ(args_desc_2.IsFolded(), false); - EXPECT_EQ(args_desc_2.GetCustomValue(), 2); - EXPECT_EQ(args_desc_2.GetHiddenInputSubType(), HiddenInputSubType::kEnd); - - auto args_desc_hidden_input = ArgDescInfo::CreateHiddenInput(HiddenInputSubType::kHcom); - ArgDescInfo args_desc_3(std::move(args_desc_hidden_input)); - EXPECT_EQ(args_desc_3.GetType(), ArgDescType::kHiddenInput); - EXPECT_EQ(args_desc_3.GetIrIndex(), -1); - EXPECT_EQ(args_desc_3.IsFolded(), false); - EXPECT_EQ(args_desc_3.GetCustomValue(), 0); - EXPECT_EQ(args_desc_3.GetHiddenInputSubType(), HiddenInputSubType::kHcom); -} -} \ No newline at end of file diff --git a/tests/ut/graph/testcase/args_format_desc_unittest.cc b/tests/ut/graph/testcase/args_format_desc_unittest.cc deleted file mode 100644 index 5283bc865119da3ad0578de482ed8bb82bc9a0ea..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/args_format_desc_unittest.cc +++ /dev/null @@ -1,725 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include - -#include "ge_ir.pb.h" -#include "graph/args_format_desc.h" -#include "graph/ge_attr_value.h" -#include "graph/ge_tensor.h" -#include "graph/normal_graph/ge_tensor_impl.h" -#include "graph/tensor.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/tensor_utils.h" -#include "graph/utils/type_utils.h" -#include "register/hidden_inputs_func_registry.h" -#include "external/graph/operator_factory.h" -#include "external/graph/operator_reg.h" -#include "external/register/op_impl_registry.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/op_desc_utils_ex.h" -#include "common/graph_builder_utils.h" - -using namespace std; -using namespace ge; -namespace ge { -class UtestArgsFormatDesc : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -TEST_F(UtestArgsFormatDesc, serialize_simple_args) { - auto op_desc = std::make_shared("tmp_op", "Mul"); - ArgsFormatDesc desc; - desc.Append(AddrType::FFTS_ADDR); - desc.Append(AddrType::OVERFLOW_ADDR); - desc.Append(AddrType::TILING); - desc.Append(AddrType::TILING_FFTS, 0); - desc.Append(AddrType::TILING_FFTS, 1); - std::string res = desc.ToString(); - std::string expect_res = "{ffts_addr}{overflow_addr}{t}{t_ffts.non_tail}{t_ffts.tail}"; - EXPECT_EQ(expect_res, res); - size_t args_size{0UL}; - EXPECT_EQ(desc.GetArgsSize(op_desc, args_size), ge::GRAPH_SUCCESS); - EXPECT_EQ(args_size, 40UL); - std::vector descs; - EXPECT_EQ(ArgsFormatDesc::Parse(op_desc, expect_res, descs), SUCCESS); -} - -TEST_F(UtestArgsFormatDesc, serialize_simple_args1) { - auto op_desc = std::make_shared("tmp_op", "Mul"); - GeTensorDesc desc; - op_desc->AddOutputDesc(desc); - op_desc->AddOutputDesc(desc); - op_desc->AddOutputDesc(desc); - op_desc->AddOutputDesc(desc); - op_desc->AddOutputDesc(desc); - op_desc->AddInputDesc(desc); - op_desc->AddInputDesc(desc); - op_desc->AddInputDesc(desc); - op_desc->AddInputDesc(desc); - op_desc->AddInputDesc(desc); - op_desc->AddInputDesc(desc); - - std::string expect_res = - "{i_instance0}{i_instance0*}{i_instance2*}{o_instance3}{o_instance1}{o_instance3*}{i_instance2}{o_instance4}"; - std::vector descs; - EXPECT_EQ(ArgsFormatDesc::Parse(op_desc, expect_res, descs), SUCCESS); - ArgsFormatDesc format_desc; - format_desc.Append(AddrType::INPUT_INSTANCE, 0, true); - format_desc.Append(AddrType::INPUT_INSTANCE, 0); - format_desc.Append(AddrType::INPUT_INSTANCE, 2); - format_desc.Append(AddrType::OUTPUT_INSTANCE, 3, true); - format_desc.Append(AddrType::OUTPUT_INSTANCE, 1, true); - format_desc.Append(AddrType::OUTPUT_INSTANCE, 3); - format_desc.Append(AddrType::INPUT_INSTANCE, 2, true); - format_desc.Append(AddrType::OUTPUT_INSTANCE, 4, true); - std::string res = format_desc.ToString(); - EXPECT_EQ(res, expect_res); - std::size_t arg_size{0UL}; - EXPECT_EQ(format_desc.GetArgsSize(op_desc, arg_size), SUCCESS); - EXPECT_EQ(arg_size, 104); - - std::string expect_res1 = "{i_instance*}{o_instance*}"; - std::vector descs1; - EXPECT_EQ(ArgsFormatDesc::Parse(op_desc, expect_res1, descs1), SUCCESS); - EXPECT_EQ(descs1.size(), 11UL); -} - -REG_OP(SimpleOP) - .INPUT(a, TensorType(DT_FLOAT32)) - .INPUT(b, TensorType(DT_FLOAT32)) - .INPUT(c, TensorType(DT_FLOAT32)) - .OUTPUT(x, TensorType(DT_FLOAT32)) - .OUTPUT(y, TensorType(DT_FLOAT32)) - .OUTPUT(z, TensorType(DT_FLOAT32)) - .OP_END_FACTORY_REG(SimpleOP); - -TEST_F(UtestArgsFormatDesc, common_args) { - auto op = OperatorFactory::CreateOperator("test1", "SimpleOP"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - GeShape shape({1, 40, 1024, 256}); - GeTensorDesc desc(shape); - op_desc->AddInputDesc(desc); - op_desc->AddInputDesc(desc); - op_desc->AddInputDesc(desc); - - op_desc->AddOutputDesc(desc); - op_desc->AddOutputDesc(desc); - op_desc->AddOutputDesc(desc); - - ArgsFormatDesc args_des; - args_des.Append(AddrType::INPUT, -1); - args_des.Append(AddrType::OUTPUT, -1); - args_des.Append(AddrType::WORKSPACE, -1); - std::string res = args_des.ToString(); - std::string expect_res = "{i*}{o*}{ws*}"; - EXPECT_EQ(expect_res, res); - - size_t args_size{0UL}; - EXPECT_EQ(args_des.GetArgsSize(op_desc, args_size), ge::GRAPH_SUCCESS); - EXPECT_EQ(args_size, 176UL); - std::vector descs; - EXPECT_EQ(ArgsFormatDesc::Parse(op_desc, expect_res, descs), SUCCESS); - EXPECT_EQ(descs.size(), 7UL); - EXPECT_EQ(descs[2].addr_type, AddrType::INPUT); - EXPECT_EQ(descs[2].ir_idx, 2); -} - -REG_OP(DynOP) - .INPUT(a, TensorType(DT_FLOAT32)) - .DYNAMIC_INPUT(b, TensorType(DT_FLOAT32)) - .INPUT(c, TensorType(DT_FLOAT32)) - .OUTPUT(x, TensorType(DT_FLOAT32)) - .DYNAMIC_OUTPUT(y, TensorType(DT_FLOAT32)) - .OUTPUT(z, TensorType(DT_FLOAT32)) - .OP_END_FACTORY_REG(DynOP); - -TEST_F(UtestArgsFormatDesc, common_args_dynamic_folded) { - auto op = OperatorFactory::CreateOperator("test1", "DynOP"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - GeShape shape({1, 40, 1024, 256}); - GeTensorDesc desc(shape); - op_desc->UpdateInputDesc(0, desc); - op_desc->AddDynamicInputDescByIndex("b", 1, 1); - op_desc->UpdateInputDesc(1, desc); - op_desc->UpdateInputDesc(2, desc); - op_desc->UpdateOutputDesc(0, desc); - op_desc->AddDynamicOutputDesc("y", 1, true); - op_desc->UpdateOutputDesc("y0", desc); - op_desc->UpdateOutputDesc(0, desc); - - ArgsFormatDesc args_des; - args_des.Append(AddrType::INPUT, -1); - args_des.Append(AddrType::OUTPUT, -1); - args_des.Append(AddrType::WORKSPACE, 0); - args_des.Append(AddrType::WORKSPACE, 1); - std::string res = args_des.ToString(); - std::string expect_res = "{i*}{o*}{ws0*}{ws1*}"; - EXPECT_EQ(expect_res, res); - - size_t args_size{0UL}; - EXPECT_EQ(args_des.GetArgsSize(op_desc, args_size), ge::GRAPH_SUCCESS); - EXPECT_EQ(args_size, 80UL); - std::vector descs; - EXPECT_EQ(ArgsFormatDesc::Parse(op_desc, expect_res, descs), SUCCESS); - - std::string expanded_res = "{i0}{i1}{i2}{o0}{o1}{o2}{ws0}{ws1}"; - std::vector expanded_descs; - EXPECT_EQ(ArgsFormatDesc::Parse(op_desc, expanded_res, expanded_descs), SUCCESS); - EXPECT_EQ(descs.size(), expanded_descs.size()); - for (auto i = 0UL; i < descs.size(); ++i) { - EXPECT_EQ(descs[i].addr_type, expanded_descs[i].addr_type); - EXPECT_EQ(descs[i].ir_idx, expanded_descs[i].ir_idx); - EXPECT_EQ(descs[i].folded, expanded_descs[i].folded); - } -} - -TEST_F(UtestArgsFormatDesc, common_args_size_equal) { - auto op = OperatorFactory::CreateOperator("test1", "DynOP"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - GeShape shape({1, 40, 1024, 256}); - GeTensorDesc desc(shape); - op_desc->UpdateInputDesc(0, desc); - op_desc->AddDynamicInputDescByIndex("b", 1, 1); - op_desc->UpdateInputDesc(1, desc); - op_desc->UpdateInputDesc(2, desc); - op_desc->UpdateOutputDesc(0, desc); - op_desc->AddDynamicOutputDesc("y", 1, true); - op_desc->UpdateOutputDesc("y0", desc); - op_desc->UpdateOutputDesc(0, desc); - - ArgsFormatDesc args_des; - args_des.Append(AddrType::INPUT, -1); - args_des.Append(AddrType::OUTPUT, -1); - args_des.Append(AddrType::WORKSPACE, 0); - args_des.Append(AddrType::WORKSPACE, 1); - size_t args_size{0UL}; - EXPECT_EQ(args_des.GetArgsSize(op_desc, args_size), ge::GRAPH_SUCCESS); - ArgsFormatDesc args_des1; - args_des1.Append(AddrType::INPUT, 0); - args_des1.Append(AddrType::INPUT, 1, true); - args_des1.Append(AddrType::INPUT, 2); - args_des1.Append(AddrType::OUTPUT, 0); - args_des1.Append(AddrType::OUTPUT, 1, true); - args_des1.Append(AddrType::OUTPUT, 2); - args_des1.Append(AddrType::WORKSPACE, 0); - args_des1.Append(AddrType::WORKSPACE, 1); - size_t args_size1{0UL}; - EXPECT_EQ(args_des1.GetArgsSize(op_desc, args_size1), ge::GRAPH_SUCCESS); - EXPECT_EQ(args_size, args_size1); -} - -REG_OP(IFA) - .INPUT(query, TensorType(DT_FLOAT32)) - .DYNAMIC_INPUT(key, TensorType(DT_FLOAT32)) - .DYNAMIC_INPUT(value, TensorType(DT_FLOAT32)) - .OPTIONAL_INPUT(padding_mask, TensorType(DT_FLOAT32)) - .OPTIONAL_INPUT(atten_mask, TensorType(DT_FLOAT32)) - .OPTIONAL_INPUT(actual_seq_lengths, TensorType(DT_FLOAT32)) - .DYNAMIC_OUTPUT(attention_out, TensorType(DT_FLOAT32)) - .OP_END_FACTORY_REG(IFA); - -TEST_F(UtestArgsFormatDesc, serialize_dynamic_args) { - auto op = OperatorFactory::CreateOperator("test1", "IFA"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - - ASSERT_NE(op_desc, nullptr); - GeShape shape({1, 40, 1024, 256}); - GeTensorDesc desc(shape); - GeShape scalar_shape; - GeTensorDesc scalar_desc(scalar_shape); - op_desc->UpdateInputDesc(0, desc); - op_desc->AddDynamicInputDescByIndex("key", 2, 1); - op_desc->UpdateInputDesc(1, desc); - op_desc->UpdateInputDesc(2, desc); - op_desc->AddDynamicInputDescByIndex("value", 2, 3); - op_desc->UpdateInputDesc(3, desc); - op_desc->UpdateInputDesc(4, scalar_desc); - op_desc->UpdateInputDesc("atten_mask", desc); - - op_desc->AddDynamicOutputDesc("attention_out", 2, true); - op_desc->UpdateOutputDesc("attention_out0", desc); - op_desc->UpdateOutputDesc("attention_out1", scalar_desc); - - ArgsFormatDesc args_des; - args_des.Append(AddrType::FFTS_ADDR); - args_des.Append(AddrType::INPUT, 0); - args_des.Append(AddrType::INPUT_DESC, 1, true); - args_des.Append(AddrType::INPUT_DESC, 2, true); - args_des.Append(AddrType::INPUT, 4); - args_des.Append(AddrType::OUTPUT_DESC, 0, true); - args_des.Append(AddrType::WORKSPACE, 0); - args_des.Append(AddrType::TILING_FFTS, 1); - std::string res = args_des.ToString(); - std::string expect_res = "{ffts_addr}{i0*}{i_desc1}{i_desc2}{i4*}{o_desc0}{ws0*}{t_ffts.tail}"; - EXPECT_EQ(expect_res, res); - - size_t args_size{0UL}; - EXPECT_EQ(args_des.GetArgsSize(op_desc, args_size), ge::GRAPH_SUCCESS); - EXPECT_EQ(args_size, 328UL); - std::vector descs; - EXPECT_EQ(ArgsFormatDesc::Parse(op_desc, expect_res, descs), SUCCESS); - EXPECT_EQ(descs.size(), 8UL); - EXPECT_EQ(descs[2].addr_type, AddrType::INPUT_DESC); -} - -TEST_F(UtestArgsFormatDesc, serialize_hidden_input) { - auto op_desc = std::make_shared("tmp_op", "Mul"); - ArgsFormatDesc desc; - ArgsFormatDescUtils::InsertHiddenInputs(desc.arg_descs_, -1, HiddenInputsType::HCOM); - desc.Append(AddrType::PLACEHOLDER); - EXPECT_EQ(desc.ToString(), "{hi.hcom0*}{}"); - size_t args_size{0UL}; - EXPECT_EQ(desc.GetArgsSize(op_desc, args_size), ge::GRAPH_SUCCESS); - EXPECT_EQ(args_size, 16UL); -} - -TEST_F(UtestArgsFormatDesc, serialize_ascendcpp_hidden_input) { - auto op_desc = std::make_shared("tmp_op", "Mul"); - ArgsFormatDesc desc; - ArgsFormatDescUtils::InsertHiddenInputs(desc.arg_descs_, -1, HiddenInputsType::TILEFWK); - desc.Append(AddrType::PLACEHOLDER); - EXPECT_EQ(desc.ToString(), "{hi.tilefwk0*}{}"); - size_t args_size{0UL}; - EXPECT_EQ(desc.GetArgsSize(op_desc, args_size), ge::GRAPH_SUCCESS); - EXPECT_EQ(args_size, 16UL); -} - -TEST_F(UtestArgsFormatDesc, deserialzie_hidden_input) { - auto op_desc = std::make_shared("tmp_op", "Mul"); - std::vector descs; - EXPECT_NE(ArgsFormatDesc::Parse(op_desc, "{hi.unsupported}", descs), SUCCESS); - EXPECT_NE(ArgsFormatDesc::Parse(op_desc, "{hi.hcom}", descs), SUCCESS); - EXPECT_NE(ArgsFormatDesc::Parse(op_desc, "{hi.hcom[xx]}", descs), SUCCESS); - - EXPECT_EQ(ArgsFormatDesc::Parse(op_desc, "{hi.hcom0*}", descs), SUCCESS); - EXPECT_EQ(descs.size(), 1UL); - EXPECT_EQ(descs[0UL].addr_type, AddrType::HIDDEN_INPUT); - EXPECT_EQ(descs[0UL].ir_idx, 0); - - EXPECT_EQ(ArgsFormatDesc::Parse(op_desc, "{hi.tilefwk0*}", descs), SUCCESS); - EXPECT_EQ(descs.size(), 1UL); - EXPECT_EQ(descs[0UL].addr_type, AddrType::HIDDEN_INPUT); - EXPECT_EQ(descs[0UL].ir_idx, 0); - EXPECT_EQ(*reinterpret_cast(descs[0UL].reserved), static_cast(HiddenInputsType::TILEFWK)); - - EXPECT_EQ(ArgsFormatDesc::Parse(op_desc, "{hi.hcom0*}{hi.hcom1*}", descs), SUCCESS); - EXPECT_EQ(descs.size(), 2UL); - EXPECT_EQ(descs[0UL].addr_type, AddrType::HIDDEN_INPUT); - EXPECT_EQ(descs[0UL].ir_idx, 0); - EXPECT_EQ(descs[1UL].addr_type, AddrType::HIDDEN_INPUT); - EXPECT_EQ(descs[1UL].ir_idx, 1); -} - -TEST_F(UtestArgsFormatDesc, serialize_custom_val) { - auto op_desc = std::make_shared("tmp_op", "Mul"); - ArgsFormatDesc desc; - size_t args_size = 0; - - EXPECT_EQ(ArgsFormatDescUtils::InsertCustomValue(desc.arg_descs_, -1, 0), GRAPH_SUCCESS); - EXPECT_EQ(desc.ToString(), "{#0}"); - EXPECT_EQ(desc.GetArgsSize(op_desc, args_size), ge::GRAPH_SUCCESS); - EXPECT_EQ(args_size, 8UL); -} - -TEST_F(UtestArgsFormatDesc, deserialzie_custom_val) { - auto op_desc = std::make_shared("tmp_op", "Mul"); - std::vector descs; - EXPECT_NE(ArgsFormatDesc::Parse(op_desc, "{xxx}", descs), SUCCESS); - EXPECT_NE(ArgsFormatDesc::Parse(op_desc, "{#}", descs), SUCCESS); - - EXPECT_EQ(ArgsFormatDesc::Parse(op_desc, "{#18446744073709551615}", descs), SUCCESS); // 0xFFFFFFFFFFFFFFFF - EXPECT_EQ(descs[0UL].addr_type, AddrType::CUSTOM_VALUE); - EXPECT_EQ(*(uint64_t *)descs[0UL].reserved, 0xFFFFFFFFFFFFFFFF); - - EXPECT_EQ(ArgsFormatDesc::Parse(op_desc, "{#0}", descs), SUCCESS); - EXPECT_EQ(descs[0UL].addr_type, AddrType::CUSTOM_VALUE); - EXPECT_EQ(*(uint64_t *)descs[0UL].reserved, 0); -} - -TEST_F(UtestArgsFormatDesc, deserialzie_placeholder) { - auto op_desc = std::make_shared("tmp_op", "Mul"); - std::string format1 = "{}"; - std::vector descs; - EXPECT_EQ(ArgsFormatDesc::Parse(op_desc, format1, descs), SUCCESS); - EXPECT_EQ(descs.size(), 1UL); - EXPECT_EQ(descs[0UL].addr_type, AddrType::PLACEHOLDER); -} - -TEST_F(UtestArgsFormatDesc, placeholder_serdes) { - auto op_desc = std::make_shared("op", "Dummy"); - ArgsFormatDesc desc, desc2; - size_t size = 0; - desc.Append(AddrType::PLACEHOLDER); - desc.AppendPlaceholder(); - desc.AppendPlaceholder(ArgsFormatWidth::BIT32); - desc.AppendPlaceholder(ArgsFormatWidth::BIT64); - EXPECT_EQ(desc.GetArgsSize(op_desc, size), GRAPH_SUCCESS); - EXPECT_EQ(size, 8 + 8 + 4 + 8); - auto str = desc.ToString(); - EXPECT_EQ(str, "{}{}{.32b}{}"); - - EXPECT_EQ(ArgsFormatDesc::FromString(desc2, op_desc, str), GRAPH_SUCCESS); - size_t idx = 0; - for (const auto &iter : desc) { - const auto &iter2 = desc2.arg_descs_[idx++]; - EXPECT_EQ(iter.addr_type, AddrType::PLACEHOLDER); - EXPECT_EQ(iter2.addr_type, AddrType::PLACEHOLDER); - EXPECT_EQ(iter.ir_idx, iter2.ir_idx); - } -} - -TEST_F(UtestArgsFormatDesc, custom_value_serdes) { - auto op_desc = std::make_shared("op", "Dummy"); - ArgsFormatDesc desc, desc2; - size_t size = 0; - desc.AppendCustomValue(42); - desc.AppendCustomValue(114, ArgsFormatWidth::BIT32); - desc.AppendCustomValue(514, ArgsFormatWidth::BIT64); - EXPECT_EQ(desc.GetArgsSize(op_desc, size), GRAPH_SUCCESS); - EXPECT_EQ(size, 8 + 4 + 8); - auto str = desc.ToString(); - EXPECT_EQ(str, "{#42}{#.32b114}{#514}"); - - EXPECT_EQ(ArgsFormatDesc::FromString(desc2, op_desc, str), GRAPH_SUCCESS); - size_t idx = 0; - for (const auto &iter : desc) { - const auto &iter2 = desc2.arg_descs_[idx++]; - EXPECT_EQ(iter.addr_type, AddrType::CUSTOM_VALUE); - EXPECT_EQ(iter2.addr_type, AddrType::CUSTOM_VALUE); - EXPECT_EQ(iter.ir_idx, iter2.ir_idx); - EXPECT_EQ(*reinterpret_cast(iter.reserved), - *reinterpret_cast(iter2.reserved)); - } -} - -TEST_F(UtestArgsFormatDesc, invalid_args_format_width) { - auto op_desc = std::make_shared("op", "Dummy"); - auto invalid_width = static_cast(0); - ArgsFormatDesc desc; - size_t size = 0; - desc.AppendPlaceholder(invalid_width); - EXPECT_NE(desc.GetArgsSize(op_desc, size), GRAPH_SUCCESS); - - desc.Clear(); - desc.AppendCustomValue(42, invalid_width); - EXPECT_NE(desc.GetArgsSize(op_desc, size), GRAPH_SUCCESS); -} - -TEST_F(UtestArgsFormatDesc, deserialzie_unsupported) { -auto op_desc = std::make_shared("tmp_op", "Mul"); - std::string format1 = "{hehe}"; - std::vector descs1; - EXPECT_NE(ArgsFormatDesc::Parse(op_desc, format1, descs1), SUCCESS); - - std::string format2 = "{ }"; - std::vector descs2; - EXPECT_NE(ArgsFormatDesc::Parse(op_desc, format2, descs2), SUCCESS); - - std::string format3 = "{hi.unsupported}"; - std::vector descs3; - EXPECT_NE(ArgsFormatDesc::Parse(op_desc, format3, descs3), SUCCESS); -} - -TEST_F(UtestArgsFormatDesc, deserialzie_tiling_context) { - auto op_desc = std::make_shared("tmp_op", "Mul"); - std::string format1 = - "{tiling_context}{*op_type}{tiling_context.tiling_key}{tiling_context.tiling_data}{tiling_context.block_dim}"; - std::vector descs; - EXPECT_EQ(ArgsFormatDesc::Parse(op_desc, format1, descs), SUCCESS); - EXPECT_EQ(descs.size(), 5UL); - EXPECT_EQ(descs[0UL].addr_type, AddrType::TILING_CONTEXT); - EXPECT_EQ(descs[1UL].addr_type, AddrType::OP_TYPE); - EXPECT_EQ(descs[2UL].ir_idx, static_cast(TilingContextSubType::TILING_KEY)); - EXPECT_EQ(descs[3UL].ir_idx, static_cast(TilingContextSubType::TILING_DATA)); - EXPECT_EQ(descs[4UL].ir_idx, static_cast(TilingContextSubType::BLOCK_DIM)); -} - -TEST_F(UtestArgsFormatDesc, serialzie_tiling_context) { - auto op_desc = std::make_shared("tmp_op", "Mul"); - std::string format1 = - "{tiling_context}{*op_type}{tiling_context.tiling_key}{tiling_context.tiling_data}{tiling_context.block_dim}"; - ArgsFormatDesc desc; - desc.AppendTilingContext(); - desc.Append(AddrType::OP_TYPE); - desc.AppendTilingContext(TilingContextSubType::TILING_KEY); - desc.AppendTilingContext(TilingContextSubType::TILING_DATA); - desc.AppendTilingContext(TilingContextSubType::BLOCK_DIM); - size_t target_size{0UL}; - EXPECT_EQ(desc.GetArgsSize(op_desc, target_size), SUCCESS); - EXPECT_EQ(target_size, 40UL); - std::string res = desc.ToString(); - EXPECT_EQ(format1, res); -} - -TEST_F(UtestArgsFormatDesc, single_arg_size_calc) { - auto op = OperatorFactory::CreateOperator("test1", "IFA"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - - ASSERT_NE(op_desc, nullptr); - GeShape shape({1, 40, 1024, 256}); - GeTensorDesc desc(shape); - GeShape scalar_shape; - GeTensorDesc scalar_desc(scalar_shape); - op_desc->UpdateInputDesc(0, desc); - op_desc->AddDynamicInputDescByIndex("key", 2, 1); - op_desc->UpdateInputDesc(1, desc); - op_desc->UpdateInputDesc(2, desc); - op_desc->AddDynamicInputDescByIndex("value", 2, 3); - op_desc->UpdateInputDesc(3, desc); - op_desc->UpdateInputDesc(4, scalar_desc); - op_desc->UpdateInputDesc("atten_mask", desc); - - op_desc->AddDynamicOutputDesc("attention_out", 2, true); - op_desc->UpdateOutputDesc("attention_out0", desc); - op_desc->UpdateOutputDesc("attention_out1", scalar_desc); - - ArgsFormatDesc args_des; - args_des.Append(AddrType::FFTS_ADDR); - args_des.Append(AddrType::INPUT, 0); - args_des.Append(AddrType::INPUT_DESC, 1, true); - args_des.Append(AddrType::INPUT_DESC, 2, true); - args_des.Append(AddrType::INPUT, 4); - args_des.Append(AddrType::OUTPUT_DESC, 0, true); - args_des.Append(AddrType::WORKSPACE, 0); - args_des.Append(AddrType::TILING_FFTS, 1); - args_des.Append(AddrType::HIDDEN_INPUT); - std::string res = args_des.ToString(); - std::string expect_res = "{ffts_addr}{i0*}{i_desc1}{i_desc2}{i4*}{o_desc0}{ws0*}{t_ffts.tail}{hi.hcom0*}"; - EXPECT_EQ(expect_res, res); - - std::vector descs; - EXPECT_EQ(ArgsFormatDesc::Parse(op_desc, expect_res, descs), SUCCESS); - EXPECT_EQ(descs.size(), 9UL); - - size_t arg_size{0UL}; - EXPECT_EQ(args_des.GetArgSize(op_desc, descs[1], arg_size), ge::GRAPH_SUCCESS); - EXPECT_EQ(arg_size, 8UL); - arg_size = 0UL; - EXPECT_EQ(args_des.GetArgSize(op_desc, descs[2], arg_size), ge::GRAPH_SUCCESS); - EXPECT_EQ(arg_size, 112UL); - arg_size = 0UL; - EXPECT_EQ(args_des.GetArgSize(op_desc, descs[3], arg_size), ge::GRAPH_SUCCESS); - EXPECT_EQ(arg_size, 88UL); - arg_size = 0UL; - EXPECT_EQ(args_des.GetArgSize(op_desc, descs[4], arg_size), ge::GRAPH_SUCCESS); - EXPECT_EQ(arg_size, 8UL); - arg_size = 0UL; - EXPECT_EQ(args_des.GetArgSize(op_desc, descs[5], arg_size), ge::GRAPH_SUCCESS); - EXPECT_EQ(arg_size, 88UL); - - ArgDesc unsupport_desc; - unsupport_desc.addr_type = static_cast(0xff); - EXPECT_EQ(args_des.GetArgSize(op_desc, unsupport_desc, arg_size), ge::GRAPH_PARAM_INVALID); - - args_des.Clear(); - EXPECT_EQ(args_des.ToString(), ""); -} - -/* - * netoutput - * | \ \ - * node4 node5 node6 - * | \ - * node2 node3 - * \ / - * node1 - */ -ge::ComputeGraphPtr BuildNormalGraph(const std::string &name) { - auto builder = ge::ut::GraphBuilder(name); - auto node1 = builder.AddNode("node1", "node1", 0, 2); - auto node2 = builder.AddNode("node2", "node2", 1, 1); - auto node3 = builder.AddNode("node3", "node3", 1, 1); - auto node4 = builder.AddNode("node4", "node4", 1, 1); - auto node5 = builder.AddNode("node5", "node5", 1, 1); - auto node6 = builder.AddNode("node6", "node6", 0, 1); - auto netoutput = builder.AddNode("netoutput", "netoutput", 3, 1); - - node1->GetOpDesc()->AppendIrInput("x", kIrInputRequired); - node2->GetOpDesc()->AppendIrInput("x", kIrInputRequired); - node2->GetOpDesc()->AppendIrOutput("y", kIrOutputRequired); - - builder.AddDataEdge(node1, 0, node2, 0); - builder.AddDataEdge(node1, 1, node3, 0); - builder.AddDataEdge(node2, 0, node4, 0); - builder.AddDataEdge(node3, 0, node5, 0); - builder.AddDataEdge(node4, 0, netoutput, 0); - builder.AddDataEdge(node5, 0, netoutput, 1); - builder.AddDataEdge(node6, 0, netoutput, 2); - return builder.GetGraphWithoutSort(); -} - -TEST_F(UtestArgsFormatDesc, SknArgDescTest) { - auto sub_graph = BuildNormalGraph("test"); - auto op_desc = std::make_shared("sk", "SuperKernel"); - EXPECT_NE(sub_graph, nullptr); - EXPECT_NE(op_desc, nullptr); - op_desc->SetExtAttr("_sk_sub_graph", sub_graph); - SkArgDesc sk_desc = {AddrType::SUPER_KERNEL_SUB_NODE, 1, false, AddrType::INPUT, 0}; - ArgDesc sub_desc = *reinterpret_cast(&sk_desc); - std::vector args_desc_vec; - args_desc_vec.emplace_back(sub_desc); - auto str = ArgsFormatDesc::Serialize(args_desc_vec); - EXPECT_EQ(str, "{skn1i0*}"); - - std::vector target_sub_desc_vec; - auto ret = ArgsFormatDesc::Parse(nullptr, str, target_sub_desc_vec, false); - EXPECT_NE(ret, GRAPH_SUCCESS); - - ret = ArgsFormatDesc::Parse(op_desc, str, target_sub_desc_vec, false); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(target_sub_desc_vec.size(), 1); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(target_sub_desc_vec.size(), 1); - EXPECT_EQ(target_sub_desc_vec[0].ir_idx, sub_desc.ir_idx); - EXPECT_EQ(target_sub_desc_vec[0].addr_type, sub_desc.addr_type); - EXPECT_EQ(target_sub_desc_vec[0].folded, sub_desc.folded); - - EXPECT_EQ(reinterpret_cast(&target_sub_desc_vec[0])->sub_idx, - reinterpret_cast(&sub_desc)->sub_idx); - EXPECT_EQ(reinterpret_cast(&target_sub_desc_vec[0])->sub_addr_type, - reinterpret_cast(&sub_desc)->sub_addr_type); - - ArgDesc arg_desc_normal{}; - arg_desc_normal.addr_type = AddrType::INPUT; - ArgDesc tmp_arg_desc{}; - int32_t sub_op_id = 0; - EXPECT_EQ(ArgsFormatDesc::ConvertArgDescSkToNormal(arg_desc_normal, tmp_arg_desc, sub_op_id), GRAPH_SUCCESS); -} - -TEST_F(UtestArgsFormatDesc, SknArgDescTestHiddenInput) { - auto sub_graph = BuildNormalGraph("test"); - auto op_desc = std::make_shared("sk", "SuperKernel"); - EXPECT_NE(sub_graph, nullptr); - EXPECT_NE(op_desc, nullptr); - op_desc->SetExtAttr("_sk_sub_graph", sub_graph); - SkArgDescV2 sk_desc = {AddrType::SUPER_KERNEL_SUB_NODE, 1, - static_cast(HiddenInputsType::HCOM), AddrType::HIDDEN_INPUT, 0}; - auto sub_desc = *reinterpret_cast(&sk_desc); - std::vector args_desc_vec; - args_desc_vec.emplace_back(*reinterpret_cast(&sk_desc)); - - sk_desc.reserved = static_cast(HiddenInputsType::TILEFWK); - args_desc_vec.emplace_back(*reinterpret_cast(&sk_desc)); - sk_desc.reserved = static_cast(HiddenInputsType::HCCLSUPERKERNEL); - args_desc_vec.emplace_back(*reinterpret_cast(&sk_desc)); - - auto str = ArgsFormatDesc::Serialize(args_desc_vec); - EXPECT_EQ(str, "{skn1hi.hcom0*}{skn1hi.tilefwk0*}{skn1hi.hcclsk0*}"); - - std::vector target_sub_desc_vec; - auto ret = ArgsFormatDesc::Parse(op_desc, str, target_sub_desc_vec, false); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(target_sub_desc_vec.size(), 3); - EXPECT_EQ(target_sub_desc_vec[0].ir_idx, sub_desc.ir_idx); - EXPECT_EQ(target_sub_desc_vec[0].addr_type, sub_desc.addr_type); - EXPECT_EQ(target_sub_desc_vec[0].folded, sub_desc.folded); - - EXPECT_EQ(reinterpret_cast(&target_sub_desc_vec[0])->sub_idx, - reinterpret_cast(&sub_desc)->sub_idx); - EXPECT_EQ(reinterpret_cast(&target_sub_desc_vec[0])->sub_idx, - 0); - EXPECT_EQ(reinterpret_cast(&target_sub_desc_vec[0])->sub_addr_type, - reinterpret_cast(&sub_desc)->sub_addr_type); - EXPECT_EQ(reinterpret_cast(&target_sub_desc_vec[0])->sub_addr_type, - AddrType::HIDDEN_INPUT); - EXPECT_EQ(reinterpret_cast(&target_sub_desc_vec[0])->reserved, - static_cast(HiddenInputsType::HCOM)); - - EXPECT_EQ(reinterpret_cast(&target_sub_desc_vec[1])->sub_addr_type, - AddrType::HIDDEN_INPUT); - EXPECT_EQ(reinterpret_cast(&target_sub_desc_vec[1])->reserved, - static_cast(HiddenInputsType::TILEFWK)); - - EXPECT_EQ(reinterpret_cast(&target_sub_desc_vec[2])->sub_addr_type, - AddrType::HIDDEN_INPUT); - EXPECT_EQ(reinterpret_cast(&target_sub_desc_vec[2])->reserved, - static_cast(HiddenInputsType::HCCLSUPERKERNEL)); - - ArgDesc tmp_arg_desc{}; - int32_t sub_op_id = 0; - EXPECT_EQ(ArgsFormatDesc::ConvertArgDescSkToNormal(target_sub_desc_vec[0], tmp_arg_desc, sub_op_id), GRAPH_SUCCESS); - EXPECT_EQ(*reinterpret_cast(tmp_arg_desc.reserved), static_cast(HiddenInputsType::HCOM)); - - EXPECT_EQ(ArgsFormatDesc::ConvertArgDescSkToNormal(target_sub_desc_vec[1], tmp_arg_desc, sub_op_id), GRAPH_SUCCESS); - EXPECT_EQ(*reinterpret_cast(tmp_arg_desc.reserved), static_cast(HiddenInputsType::TILEFWK)); - - EXPECT_EQ(ArgsFormatDesc::ConvertArgDescSkToNormal(target_sub_desc_vec[2], tmp_arg_desc, sub_op_id), GRAPH_SUCCESS); - EXPECT_EQ(*reinterpret_cast(tmp_arg_desc.reserved), static_cast(HiddenInputsType::HCCLSUPERKERNEL)); - - size_t arg_size = 0; - ret = ArgsFormatDesc::GetArgSize(op_desc, sub_desc, arg_size); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(arg_size, 8); -} - -TEST_F(UtestArgsFormatDesc, SknArgDesceEventAddr) { - auto sub_graph = BuildNormalGraph("test"); - auto op_desc = std::make_shared("sk", "SuperKernel"); - EXPECT_NE(sub_graph, nullptr); - EXPECT_NE(op_desc, nullptr); - op_desc->SetExtAttr("_sk_sub_graph", sub_graph); - SkArgDesc sk_desc = {AddrType::SUPER_KERNEL_SUB_NODE, 1, false, AddrType::EVENT_ADDR, 10}; - ArgDesc sub_desc = *reinterpret_cast(&sk_desc); - std::vector args_desc_vec; - args_desc_vec.emplace_back(sub_desc); - auto str = ArgsFormatDesc::Serialize(args_desc_vec); - - EXPECT_EQ(str, "{skn1event_addr10*}"); - - std::vector target_sub_desc_vec; - auto ret = ArgsFormatDesc::Parse(op_desc, str, target_sub_desc_vec, false); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(target_sub_desc_vec.size(), 1); - EXPECT_EQ(target_sub_desc_vec[0].ir_idx, sub_desc.ir_idx); - EXPECT_EQ(target_sub_desc_vec[0].addr_type, sub_desc.addr_type); - EXPECT_EQ(target_sub_desc_vec[0].folded, sub_desc.folded); - - EXPECT_EQ(reinterpret_cast(&target_sub_desc_vec[0])->sub_idx, - reinterpret_cast(&sub_desc)->sub_idx); - EXPECT_EQ(reinterpret_cast(&target_sub_desc_vec[0])->sub_idx, - 10); - EXPECT_EQ(reinterpret_cast(&target_sub_desc_vec[0])->sub_addr_type, - reinterpret_cast(&sub_desc)->sub_addr_type); - EXPECT_EQ(reinterpret_cast(&target_sub_desc_vec[0])->sub_addr_type, - AddrType::EVENT_ADDR); - size_t arg_size = 0; - ret = ArgsFormatDesc::GetArgSize(op_desc, sub_desc, arg_size); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(arg_size, 8); -} - -TEST_F(UtestArgsFormatDesc, ConvertToSuperKernelArgFormat) { - auto sub_graph = BuildNormalGraph("test"); - auto op_desc = std::make_shared("sk", "SuperKernel"); - EXPECT_NE(sub_graph, nullptr); - EXPECT_NE(op_desc, nullptr); - op_desc->SetExtAttr("_sk_sub_graph", sub_graph); - - auto sk_node = std::shared_ptr(new (std::nothrow) Node(op_desc, nullptr)); - EXPECT_NE(sk_node, nullptr); - (void)sk_node->Init(); - - NodePtr sub_node = std::shared_ptr(new (std::nothrow) Node(op_desc, nullptr));; - std::string sub_node_arg_format = "{i0*}"; - std::string sk_arg_format; - sub_node->GetOpDesc()->AppendIrInput("x", kIrInputRequired); - EXPECT_EQ(ArgsFormatDesc::ConvertToSuperKernelArgFormat( - sk_node, sub_node, sub_node_arg_format, sk_arg_format), ge::GRAPH_SUCCESS); - EXPECT_EQ(sk_arg_format, "{skn0i0*}"); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/args_format_desc_utils_unittest.cc b/tests/ut/graph/testcase/args_format_desc_utils_unittest.cc deleted file mode 100644 index e414402c65074366ef6ae9f602c025adce573dda..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/args_format_desc_utils_unittest.cc +++ /dev/null @@ -1,290 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include - -#include "ge_ir.pb.h" -#include "graph/utils/args_format_desc_utils.h" -#include "graph/ge_attr_value.h" -#include "graph/ge_tensor.h" -#include "graph/normal_graph/ge_tensor_impl.h" -#include "graph/tensor.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/tensor_utils.h" -#include "graph/utils/type_utils.h" -#include "register/hidden_inputs_func_registry.h" -#include "slog.h" -#include "external/graph/operator_factory.h" -#include "external/graph/operator_reg.h" -#include "external/register/op_impl_registry.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/op_desc_utils_ex.h" - -using namespace std; -using namespace ge; -namespace ge { -class UtestArgsFormatDescUtils : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -TEST_F(UtestArgsFormatDescUtils, serialize_simple_args) { - std::vector desc; - ArgsFormatDescUtils::Append(desc, AddrType::FFTS_ADDR); - ArgsFormatDescUtils::Append(desc, AddrType::OVERFLOW_ADDR); - ArgsFormatDescUtils::Append(desc, AddrType::TILING); - ArgsFormatDescUtils::Append(desc, AddrType::TILING_FFTS, 0); - ArgsFormatDescUtils::Append(desc, AddrType::TILING_FFTS, 1); - std::string res = ArgsFormatDescUtils::ToString(desc); - std::string expect_res = "{ffts_addr}{overflow_addr}{t}{t_ffts.non_tail}{t_ffts.tail}"; - EXPECT_EQ(expect_res, res); - - std::vector descs; - EXPECT_EQ(ArgsFormatDescUtils::Parse(expect_res, descs), SUCCESS); - EXPECT_EQ(descs.size(), 5); - EXPECT_EQ(desc[0].addr_type, AddrType::FFTS_ADDR); - EXPECT_EQ(desc[1].addr_type, AddrType::OVERFLOW_ADDR); - EXPECT_EQ(desc[2].addr_type, AddrType::TILING); - EXPECT_EQ(desc[3].addr_type, AddrType::TILING_FFTS); - EXPECT_EQ(desc[3].ir_idx, 0); - EXPECT_EQ(desc[4].addr_type, AddrType::TILING_FFTS); - EXPECT_EQ(desc[4].ir_idx, 1); -} - -TEST_F(UtestArgsFormatDescUtils, common_args) { - std::vector arg_desc; - ArgsFormatDescUtils::Append(arg_desc, AddrType::INPUT, -1); - ArgsFormatDescUtils::Append(arg_desc, AddrType::OUTPUT, -1); - ArgsFormatDescUtils::Append(arg_desc, AddrType::WORKSPACE, -1); - - std::string res = ArgsFormatDescUtils::ToString(arg_desc); - std::string expect_res = "{i*}{o*}{ws*}"; - EXPECT_EQ(expect_res, res); - - // 对外开放的接口,parser时不会处理动态输入的场景,因此: - // i* -> [AddrType::INPUT, -1, false] - // o* -> [AddrType::INPUT, -1, false] - // ws* -> [AddrType::INPUT, -1, false] - std::vector descs; - EXPECT_EQ(ArgsFormatDescUtils::Parse(expect_res, descs), SUCCESS); - EXPECT_EQ(descs.size(), 3UL); - EXPECT_EQ(descs[0].addr_type, AddrType::INPUT); - EXPECT_EQ(descs[0].ir_idx, -1); - EXPECT_EQ(descs[0].folded, false); - - EXPECT_EQ(descs[1].addr_type, AddrType::OUTPUT); - EXPECT_EQ(descs[1].ir_idx, -1); - EXPECT_EQ(descs[1].folded, false); - - EXPECT_EQ(descs[2].addr_type, AddrType::WORKSPACE); - EXPECT_EQ(descs[2].ir_idx, -1); - EXPECT_EQ(descs[2].folded, false); -} - -REG_OP(DynOP) - .INPUT(a, TensorType(DT_FLOAT32)) - .DYNAMIC_INPUT(b, TensorType(DT_FLOAT32)) - .INPUT(c, TensorType(DT_FLOAT32)) - .OUTPUT(x, TensorType(DT_FLOAT32)) - .DYNAMIC_OUTPUT(y, TensorType(DT_FLOAT32)) - .OUTPUT(z, TensorType(DT_FLOAT32)) - .OP_END_FACTORY_REG(DynOP); - -TEST_F(UtestArgsFormatDescUtils, common_args_dynamic_folded) { - std::vector args_desc; - ArgsFormatDescUtils::Append(args_desc, AddrType::INPUT, -1); - ArgsFormatDescUtils::Append(args_desc, AddrType::OUTPUT, -1); - ArgsFormatDescUtils::Append(args_desc, AddrType::WORKSPACE, 0); - ArgsFormatDescUtils::Append(args_desc, AddrType::WORKSPACE, 1); - std::string res = ArgsFormatDescUtils::ToString(args_desc); - std::string expect_res = "{i*}{o*}{ws0*}{ws1*}"; - EXPECT_EQ(expect_res, res); - - std::vector descs; - EXPECT_EQ(ArgsFormatDescUtils::Parse(expect_res, descs), SUCCESS); - EXPECT_EQ(descs[0].addr_type, AddrType::INPUT); - EXPECT_EQ(descs[0].ir_idx, -1); - EXPECT_EQ(descs[0].folded, false); - - EXPECT_EQ(descs[2].addr_type, AddrType::WORKSPACE); - EXPECT_EQ(descs[2].ir_idx, 0); - EXPECT_EQ(descs[2].folded, false); - - std::string expanded_res = "{i0}{i1}{i2}{o0}{o1}{o2}{ws0}{ws1}{i0}{i_desc0}{i_instance0}{o_instance*}"; - std::vector expanded_descs; - EXPECT_EQ(ArgsFormatDescUtils::Parse(expanded_res, expanded_descs), SUCCESS); - // 对外开放的接口,由于没有op_desc,因此不会将i* o* ws*类似的展开,因此不会相等 - EXPECT_NE(descs.size(), expanded_descs.size()); - - EXPECT_EQ(expanded_descs[0].addr_type, AddrType::INPUT); - EXPECT_EQ(expanded_descs[0].ir_idx, 0); - EXPECT_EQ(expanded_descs[0].folded, true); - -} - -TEST_F(UtestArgsFormatDescUtils, serialize_dynamic_args) { - std::vector args_des; - ArgsFormatDescUtils::Append(args_des, AddrType::FFTS_ADDR); - ArgsFormatDescUtils::Append(args_des, AddrType::INPUT, 0); - ArgsFormatDescUtils::Append(args_des, AddrType::INPUT_DESC, 1, true); - ArgsFormatDescUtils::Append(args_des, AddrType::INPUT_DESC, 2, true); - ArgsFormatDescUtils::Append(args_des, AddrType::INPUT, 4); - ArgsFormatDescUtils::Append(args_des, AddrType::OUTPUT_DESC, 0, true); - ArgsFormatDescUtils::Append(args_des, AddrType::WORKSPACE, 0); - ArgsFormatDescUtils::Append(args_des, AddrType::TILING_FFTS, 1); - std::string res = ArgsFormatDescUtils::ToString(args_des); - std::string expect_res = "{ffts_addr}{i0*}{i_desc1}{i_desc2}{i4*}{o_desc0}{ws0*}{t_ffts.tail}"; - EXPECT_EQ(expect_res, res); - - std::vector descs; - EXPECT_EQ(ArgsFormatDescUtils::Parse(expect_res, descs), SUCCESS); - EXPECT_EQ(descs.size(), 8UL); - EXPECT_EQ(descs[0].addr_type, AddrType::FFTS_ADDR); - EXPECT_EQ(descs[0].ir_idx, -1); - EXPECT_EQ(descs[0].folded, false); - - EXPECT_EQ(descs[1].addr_type, AddrType::INPUT); - EXPECT_EQ(descs[1].ir_idx, 0); - EXPECT_EQ(descs[1].folded, false); - - EXPECT_EQ(descs[2].addr_type, AddrType::INPUT_DESC); - EXPECT_EQ(descs[2].ir_idx, 1); - EXPECT_EQ(descs[2].folded, true); - - EXPECT_EQ(descs[3].addr_type, AddrType::INPUT_DESC); - EXPECT_EQ(descs[3].ir_idx, 2); - EXPECT_EQ(descs[3].folded, true); - - EXPECT_EQ(descs[4].addr_type, AddrType::INPUT); - EXPECT_EQ(descs[4].ir_idx, 4); - EXPECT_EQ(descs[4].folded, false); - - EXPECT_EQ(descs[5].addr_type, AddrType::OUTPUT_DESC); - EXPECT_EQ(descs[5].ir_idx, 0); - EXPECT_EQ(descs[5].folded, true); - - EXPECT_EQ(descs[6].addr_type, AddrType::WORKSPACE); - EXPECT_EQ(descs[6].ir_idx, 0); - EXPECT_EQ(descs[6].folded, false); - - EXPECT_EQ(descs[7].addr_type, AddrType::TILING_FFTS); - // ffts的ir_index无效,默认值-1 - EXPECT_EQ(descs[7].ir_idx, -1); - EXPECT_EQ(descs[7].folded, false); -} - -TEST_F(UtestArgsFormatDescUtils, deserialzie_placeholder) { - auto op_desc = std::make_shared("tmp_op", "Mul"); - std::string format1 = "{}"; - std::vector descs; - EXPECT_EQ(ArgsFormatDescUtils::Parse(format1, descs), SUCCESS); - EXPECT_EQ(descs.size(), 1UL); - EXPECT_EQ(descs[0UL].addr_type, AddrType::PLACEHOLDER); -} - -TEST_F(UtestArgsFormatDescUtils, deserialzie_unsupported) { -auto op_desc = std::make_shared("tmp_op", "Mul"); - std::string format1 = "{hehe}"; - std::vector descs1; - EXPECT_NE(ArgsFormatDescUtils::Parse(format1, descs1), SUCCESS); - - std::string format2 = "{ }"; - std::vector descs2; - EXPECT_NE(ArgsFormatDescUtils::Parse(format2, descs2), SUCCESS); - - std::string format3 = "{hi.unsupported}"; - std::vector descs3; - EXPECT_NE(ArgsFormatDescUtils::Parse(format3, descs3), SUCCESS); -} - -TEST_F(UtestArgsFormatDescUtils, deserialzie_tiling_context) { - auto op_desc = std::make_shared("tmp_op", "Mul"); - std::string format1 = - "{tiling_context}{*op_type}{tiling_context.tiling_key}{tiling_context.tiling_data}{tiling_context.block_dim}"; - std::vector descs; - EXPECT_EQ(ArgsFormatDescUtils::Parse(format1, descs), SUCCESS); - EXPECT_EQ(descs.size(), 5UL); - EXPECT_EQ(descs[0UL].addr_type, AddrType::TILING_CONTEXT); - EXPECT_EQ(descs[1UL].addr_type, AddrType::OP_TYPE); - EXPECT_EQ(descs[2UL].ir_idx, static_cast(TilingContextSubType::TILING_KEY)); - EXPECT_EQ(descs[3UL].ir_idx, static_cast(TilingContextSubType::TILING_DATA)); - EXPECT_EQ(descs[4UL].ir_idx, static_cast(TilingContextSubType::BLOCK_DIM)); -} - -TEST_F(UtestArgsFormatDescUtils, serialzie_tiling_context) { - auto op_desc = std::make_shared("tmp_op", "Mul"); - std::string format1 = - "{tiling_context}{*op_type}{tiling_context.tiling_key}{tiling_context.tiling_data}{tiling_context.block_dim}"; - std::vector args_desc; - ArgsFormatDescUtils::AppendTilingContext(args_desc); - ArgsFormatDescUtils::Append(args_desc, AddrType::OP_TYPE); - ArgsFormatDescUtils::AppendTilingContext(args_desc, TilingContextSubType::TILING_KEY); - ArgsFormatDescUtils::AppendTilingContext(args_desc, TilingContextSubType::TILING_DATA); - ArgsFormatDescUtils::AppendTilingContext(args_desc, TilingContextSubType::BLOCK_DIM); - - std::string res = ArgsFormatDescUtils::ToString(args_desc); - EXPECT_EQ(format1, res); -} - -TEST_F(UtestArgsFormatDescUtils, insert_hidden_input) { - auto op_desc = std::make_shared("tmp_op", "Mul"); - std::vector descs; - ArgsFormatDescUtils::Append(descs, AddrType::PLACEHOLDER); - ArgsFormatDescUtils::InsertHiddenInputs(descs, -1, HiddenInputsType::HCOM); - EXPECT_EQ(ArgsFormatDescUtils::ToString(descs), "{}{hi.hcom0*}"); - - EXPECT_EQ(ArgsFormatDescUtils::InsertHiddenInputs(descs, 1, HiddenInputsType::HCOM, 3), GRAPH_SUCCESS); - EXPECT_EQ(ArgsFormatDescUtils::ToString(descs), "{}{hi.hcom0*}{hi.hcom1*}{hi.hcom2*}{hi.hcom0*}"); - - EXPECT_EQ(ArgsFormatDescUtils::InsertHiddenInputs(descs, 5, HiddenInputsType::HCOM, 3), GRAPH_SUCCESS); - EXPECT_EQ(ArgsFormatDescUtils::ToString(descs), "{}{hi.hcom0*}{hi.hcom1*}{hi.hcom2*}{hi.hcom0*}{hi.hcom0*}{hi.hcom1*}{hi.hcom2*}"); - - EXPECT_NE(ArgsFormatDescUtils::InsertHiddenInputs(descs, 9, HiddenInputsType::HCOM, 1), GRAPH_SUCCESS); -} - -TEST_F(UtestArgsFormatDescUtils, insert_custom_val) { - auto op_desc = std::make_shared("tmp_op", "Mul"); - std::vector descs; - - EXPECT_EQ(ArgsFormatDescUtils::InsertCustomValue(descs, -1, 0), GRAPH_SUCCESS); - EXPECT_EQ(ArgsFormatDescUtils::ToString(descs), "{#0}"); - - EXPECT_EQ(ArgsFormatDescUtils::InsertCustomValue(descs, 0, UINT64_MAX), GRAPH_SUCCESS); - EXPECT_EQ(ArgsFormatDescUtils::ToString(descs), "{#18446744073709551615}{#0}"); - - EXPECT_NE(ArgsFormatDescUtils::InsertCustomValue(descs, 3, UINT64_MAX), GRAPH_SUCCESS); -} - -TEST_F(UtestArgsFormatDescUtils, test_debug_log) { - dlog_setlevel(0, 0, 0); - std::vector desc; - ArgsFormatDescUtils::Append(desc, AddrType::INPUT); - ArgsFormatDescUtils::Append(desc, AddrType::INPUT, 0); - ArgsFormatDescUtils::Append(desc, AddrType::INPUT, 0, true); - ArgsFormatDescUtils::Append(desc, AddrType::FFTS_ADDR); - ArgsFormatDescUtils::Append(desc, AddrType::OVERFLOW_ADDR); - ArgsFormatDescUtils::Append(desc, AddrType::TILING); - ArgsFormatDescUtils::Append(desc, AddrType::TILING_FFTS, 0); - ArgsFormatDescUtils::Append(desc, AddrType::TILING_FFTS, 1); - std::string res = ArgsFormatDescUtils::ToString(desc); - std::string expect_res = "{i*}{i0*}{i0}{ffts_addr}{overflow_addr}{t}{t_ffts.non_tail}{t_ffts.tail}"; - EXPECT_EQ(expect_res, res); - dlog_setlevel(0, 3, 0); -} - -TEST_F(UtestArgsFormatDescUtils, AddrTypeValue) { - // 增加新的枚举值,需要 l0 exception dump 适配 - EXPECT_EQ(static_cast(AddrType::MAX), static_cast(AddrType::EVENT_ADDR) + 1); - EXPECT_EQ(static_cast(TilingContextSubType::MAX), static_cast(TilingContextSubType::BLOCK_DIM) + 1); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/attr_serializer_registry_unittest.cc b/tests/ut/graph/testcase/attr_serializer_registry_unittest.cc deleted file mode 100644 index 7fd1ad7596050c36edcdc9af581a6af3c215396a..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/attr_serializer_registry_unittest.cc +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include "graph/serialization/attr_serializer_registry.h" -#include "graph/serialization/string_serializer.h" - -#include "proto/ge_ir.pb.h" -#include -#include -namespace ge { -class AttrSerializerRegistryUt : public testing::Test {}; - -TEST_F(AttrSerializerRegistryUt, StringReg) { - REG_GEIR_SERIALIZER(serializer_for_ut, ge::StringSerializer, GetTypeId(), proto::AttrDef::kS); - GeIrAttrSerializer *serializer = AttrSerializerRegistry::GetInstance().GetSerializer(GetTypeId()); - GeIrAttrSerializer *deserializer = AttrSerializerRegistry::GetInstance().GetDeserializer(proto::AttrDef::kS); - ASSERT_NE(serializer, nullptr); - ASSERT_NE(deserializer, nullptr); -} - - -} diff --git a/tests/ut/graph/testcase/attr_serializer_unittest.cc b/tests/ut/graph/testcase/attr_serializer_unittest.cc deleted file mode 100644 index eb9d19e6d2e671d3d1f13bf06f4168a57e167110..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/attr_serializer_unittest.cc +++ /dev/null @@ -1,611 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include -#include -#include "graph/serialization/attr_serializer_registry.h" -#include "graph/serialization/string_serializer.h" -#include "graph/serialization/int_serializer.h" -#include "graph/serialization/float_serializer.h" -#include "graph/serialization/bool_serializer.h" -#include "graph/serialization/buffer_serializer.h" -#include "graph/serialization/data_type_serializer.h" -#include "graph/serialization/tensor_serializer.h" -#include "graph/serialization/tensor_desc_serializer.h" -#include "graph/serialization/list_list_int_serializer.h" -#include "graph/serialization/list_value_serializer.h" -#include "graph/serialization/list_list_float_serializer.h" -#include "graph/serialization/graph_serializer.h" -#include "graph/serialization/named_attrs_serializer.h" -#include "graph/any_value.h" -#include "graph/utils/attr_utils.h" -#include "graph/op_desc.h" -#include "proto/ge_ir.pb.h" -#include "graph/ge_attr_value.h" -#include "graph/detail/model_serialize_imp.h" -#include "graph/model_serialize.h" -#include "graph_builder_utils.h" -#include "test_std_structs.h" - -namespace ge { -GeTensorPtr CreateTensor_1_1_224_224(float *tensor_data) { - auto tensor = std::make_shared(); - tensor->SetData(reinterpret_cast(tensor_data), 224 * 224 * sizeof(float)); - GeTensorDesc td; - td.SetShape(GeShape(std::vector({1, 1, 224, 224}))); - td.SetOriginShape(GeShape(std::vector({1, 1, 224, 224}))); - td.SetFormat(FORMAT_NCHW); - td.SetOriginFormat(FORMAT_NCHW); - td.SetDataType(DT_FLOAT); - td.SetOriginDataType(DT_FLOAT); - AttrUtils::SetStr(&td, "bcd", "Hello world"); - tensor->SetTensorDesc(td); - return tensor; -} - -ComputeGraphPtr CreateGraph_1_1_224_224(float *tensor_data) { - ut::GraphBuilder builder("graph1"); - auto data1 = builder.AddNode("data1", "Data", {}, {"y"}); - AttrUtils::SetInt(data1->GetOpDesc(), "index", 0); - auto const1 = builder.AddNode("const1", "Const", {}, {"y"}); - GeTensorDesc const1_td; - const1_td.SetShape(GeShape({1, 1, 224, 224})); - const1_td.SetOriginShape(GeShape({1, 1, 224, 224})); - const1_td.SetFormat(FORMAT_NCHW); - const1_td.SetOriginFormat(FORMAT_NCHW); - const1_td.SetDataType(DT_FLOAT); - const1_td.SetOriginDataType(DT_FLOAT); - GeTensor tensor(const1_td); - tensor.SetData(reinterpret_cast(tensor_data), sizeof(float) * 224 * 224); - AttrUtils::SetTensor(const1->GetOpDesc(), "value", tensor); - auto add1 = builder.AddNode("add1", "Add", {"x1", "x2"}, {"y"}); - auto netoutput1 = builder.AddNode("NetOutputNode", "NetOutput", {"x"}, {}); - - builder.AddDataEdge(data1, 0, add1, 0); - builder.AddDataEdge(const1, 0, add1, 1); - builder.AddDataEdge(add1, 0, netoutput1, 0); - - return builder.GetGraph(); -} - -class AttrSerializerUt : public testing::Test {}; - -TEST_F(AttrSerializerUt, StringAttr) { - REG_GEIR_SERIALIZER(str_serializer, StringSerializer, GetTypeId(), proto::AttrDef::kS); - - auto op_desc = std::make_shared(); - AttrUtils::SetStr(op_desc, "str_name", "test_string1"); - - ModelSerializeImp impl; - - proto::OpDef op_def; - impl.SerializeOpDesc(op_desc, &op_def); - - google::protobuf::Map attr_map = op_def.attr(); - - auto iter = attr_map.find("str_name"); - EXPECT_TRUE(iter != attr_map.end()); - - proto::AttrDef attr2 = iter->second; - AnyValue value2; - auto *deserializer = AttrSerializerRegistry::GetInstance().GetDeserializer(proto::AttrDef::kS); - ASSERT_NE(deserializer, nullptr); - deserializer->Deserialize(attr2, value2); - std::string res; - value2.GetValue(res); - ASSERT_EQ(res, "test_string1"); -} - -TEST_F(AttrSerializerUt, IntAttr) { - REG_GEIR_SERIALIZER(int_serializer, IntSerializer, GetTypeId(), proto::AttrDef::kI); - - auto op_desc = std::make_shared(); - - int64_t val = 12344321; - AttrUtils::SetInt(op_desc, "int_val", val); - - ModelSerializeImp impl; - - proto::OpDef op_def; - impl.SerializeOpDesc(op_desc, &op_def); - - google::protobuf::Map attr_map = op_def.attr(); - - auto iter = attr_map.find("int_val"); - EXPECT_TRUE(iter != attr_map.end()); - - proto::AttrDef attr2 = iter->second; - AnyValue value2; - auto *deserializer = AttrSerializerRegistry::GetInstance().GetDeserializer(proto::AttrDef::kI); - ASSERT_NE(deserializer, nullptr); - deserializer->Deserialize(attr2, value2); - int64_t res; - value2.GetValue(res); - ASSERT_EQ(res, val); -} - -TEST_F(AttrSerializerUt, FloatAttr) { - REG_GEIR_SERIALIZER(float_serializer, FloatSerializer, GetTypeId(), proto::AttrDef::kF); - - auto op_desc = std::make_shared(); - - float val = 123.321f; - AttrUtils::SetFloat(op_desc, "float_val", val); - - ModelSerializeImp impl; - - proto::OpDef op_def; - impl.SerializeOpDesc(op_desc, &op_def); - - google::protobuf::Map attr_map = op_def.attr(); - - auto iter = attr_map.find("float_val"); - EXPECT_TRUE(iter != attr_map.end()); - - proto::AttrDef attr2 = iter->second; - AnyValue value2; - auto *deserializer = AttrSerializerRegistry::GetInstance().GetDeserializer(proto::AttrDef::kF); - ASSERT_NE(deserializer, nullptr); - deserializer->Deserialize(attr2, value2); - float res; - value2.GetValue(res); - ASSERT_EQ(res, val); -} - -TEST_F(AttrSerializerUt, TensorDescAttr) { - REG_GEIR_SERIALIZER(tensor_desc_serializer, TensorDescSerializer, - GetTypeId(), proto::AttrDef::kTd); - - GeTensorDesc td; - td.SetShape(GeShape(std::vector({1, 1, 224, 224}))); - td.SetOriginShape(GeShape(std::vector({1, 1, 224, 224}))); - td.SetFormat(FORMAT_NCHW); - td.SetOriginFormat(FORMAT_NCHW); - td.SetDataType(DT_FLOAT); - td.SetOriginDataType(DT_FLOAT); - - auto op_desc = std::make_shared(); - - AttrUtils::SetTensorDesc(op_desc, "desc_val", td); - - ModelSerializeImp impl; - - proto::OpDef op_def; - impl.SerializeOpDesc(op_desc, &op_def); - - google::protobuf::Map attr_map = op_def.attr(); - - auto iter = attr_map.find("desc_val"); - EXPECT_TRUE(iter != attr_map.end()); - - proto::AttrDef attr2 = iter->second; - AnyValue value2; - auto *deserializer = AttrSerializerRegistry::GetInstance().GetDeserializer(proto::AttrDef::kTd); - ASSERT_NE(deserializer, nullptr); - deserializer->Deserialize(attr2, value2); - GeTensorDesc res; - ASSERT_EQ(value2.GetValue(res), GRAPH_SUCCESS); - ASSERT_EQ(res, td); -} - -TEST_F(AttrSerializerUt, BoolAttr) { - REG_GEIR_SERIALIZER(bool_serializer, BoolSerializer, GetTypeId(), proto::AttrDef::kI); - - auto op_desc = std::make_shared(); - - bool val = true; - AttrUtils::SetBool(op_desc, "bool_val", val); - - ModelSerializeImp impl; - - proto::OpDef op_def; - impl.SerializeOpDesc(op_desc, &op_def); - - google::protobuf::Map attr_map = op_def.attr(); - - auto iter = attr_map.find("bool_val"); - EXPECT_TRUE(iter != attr_map.end()); - - proto::AttrDef attr2 = iter->second; - AnyValue value2; - auto *deserializer = AttrSerializerRegistry::GetInstance().GetDeserializer(proto::AttrDef::kB); - ASSERT_NE(deserializer, nullptr); - deserializer->Deserialize(attr2, value2); - bool res; - value2.GetValue(res); - ASSERT_EQ(res, val); -} - -TEST_F(AttrSerializerUt, DataTypeAttr) { - REG_GEIR_SERIALIZER(data_type_serializer, DataTypeSerializer, - GetTypeId(), proto::AttrDef::kDt); - - auto op_desc = std::make_shared(); - DataType dt = DT_DOUBLE; - AttrUtils::SetDataType(op_desc, "val", dt); - - ModelSerializeImp impl; - - proto::OpDef op_def; - impl.SerializeOpDesc(op_desc, &op_def); - - google::protobuf::Map attr_map = op_def.attr(); - - auto iter = attr_map.find("val"); - EXPECT_TRUE(iter != attr_map.end()); - - proto::AttrDef attr2 = iter->second; - AnyValue value2; - auto *deserializer = AttrSerializerRegistry::GetInstance().GetDeserializer(proto::AttrDef::kDt); - ASSERT_NE(deserializer, nullptr); - deserializer->Deserialize(attr2, value2); - DataType res; - value2.GetValue(res); - ASSERT_EQ(res, dt); -} - -TEST_F(AttrSerializerUt, NamedAttr) { - REG_GEIR_SERIALIZER(named_attr_serializer, NamedAttrsSerializer, - GetTypeId(), proto::AttrDef::kFunc); - REG_GEIR_SERIALIZER(tesnor_serializer, TensorSerializer, GetTypeId(), proto::AttrDef::kT); - REG_GEIR_SERIALIZER(tensor_desc_serializer, TensorDescSerializer, - GetTypeId(), proto::AttrDef::kTd); - REG_GEIR_SERIALIZER(graph_desc_serializer, GraphSerializer, - GetTypeId(), proto::AttrDef::kG); - - AnyValue value; - value.SetValue(1.2f); - - ge::NamedAttrs named_attrs; - named_attrs.SetName("named_attr"); - - float data[224 * 224] = {1.0f}; - GeTensorPtr ge_tensor = CreateTensor_1_1_224_224(data); - AttrUtils::SetTensor(named_attrs, "tensor_attr", ge_tensor); - - float tensor[224 * 224] = {1.0f}; - auto compute_graph = CreateGraph_1_1_224_224(tensor); - AttrUtils::SetGraph(named_attrs, "graph_attr", compute_graph); - - auto op_desc = std::make_shared(); - AttrUtils::SetNamedAttrs(op_desc, "named_attr", named_attrs); - - ModelSerializeImp impl; - proto::OpDef op_def; - impl.SerializeOpDesc(op_desc, &op_def); - google::protobuf::Map attr_map = op_def.attr(); - - EXPECT_TRUE(attr_map.count("named_attr") > 0); - - auto res_op_desc = std::make_shared(); - EXPECT_TRUE(impl.UnserializeOpDesc(res_op_desc, op_def)); - - ge::NamedAttrs res_named_attrs; - EXPECT_TRUE(AttrUtils::GetNamedAttrs(res_op_desc, "named_attr", res_named_attrs)); -} - -TEST_F(AttrSerializerUt, OpToString) { - REG_GEIR_SERIALIZER(tensor_serializer, TensorSerializer, GetTypeId(), proto::AttrDef::kT); - REG_GEIR_SERIALIZER(tensor_desc_serializer, TensorDescSerializer, GetTypeId(), proto::AttrDef::kTd); - - auto op_desc = std::make_shared(); - float data[224 * 224] = {1.0f}; - GeTensorPtr ge_tensor = CreateTensor_1_1_224_224(data); - AttrUtils::SetTensor(op_desc, "tensor", ge_tensor); - - ModelSerializeImp impl; - proto::OpDef op_def; - impl.SerializeOpDesc(op_desc, &op_def); - - google::protobuf::Map attr_map = op_def.attr(); - std::string op_str = op_def.SerializeAsString(); - EXPECT_TRUE(attr_map.count("tensor") > 0); -} - -TEST_F(AttrSerializerUt, ListFloatAttr) { - REG_GEIR_SERIALIZER(float_serializer, ListValueSerializer, - GetTypeId>(), proto::AttrDef::kList); - - auto op_desc = std::make_shared(); - vector val = {1.2f, 1.3f, 1.4f}; - AttrUtils::SetListFloat(op_desc, "mem_size", val); - - ModelSerializeImp impl; - - proto::OpDef op_def; - impl.SerializeOpDesc(op_desc, &op_def); - - google::protobuf::Map attr_map = op_def.attr(); - - auto iter = attr_map.find("mem_size"); - EXPECT_TRUE(iter != attr_map.end()); - - proto::AttrDef attr2 = iter->second; - AnyValue value2; - auto *deserializer = AttrSerializerRegistry::GetInstance().GetDeserializer(proto::AttrDef::kList); - ASSERT_NE(deserializer, nullptr); - deserializer->Deserialize(attr2, value2); - std::vector res; - value2.GetValue(res); - ASSERT_EQ(res, val); -} - -TEST_F(AttrSerializerUt, ListIntAttr) { - REG_GEIR_SERIALIZER(list_int_serializer, ListValueSerializer, - GetTypeId>(), proto::AttrDef::kList); - - auto op_desc = std::make_shared(); - vector val = {-1, 224, 224, 224}; - AttrUtils::SetListInt(op_desc, "shapes", val); - - ModelSerializeImp impl; - proto::OpDef op_def; - impl.SerializeOpDesc(op_desc, &op_def); - - google::protobuf::Map attr_map = op_def.attr(); - - auto iter = attr_map.find("shapes"); - EXPECT_TRUE(iter != attr_map.end()); - - proto::AttrDef attr2 = iter->second; - AnyValue value2; - auto *deserializer = AttrSerializerRegistry::GetInstance().GetDeserializer(proto::AttrDef::kList); - ASSERT_NE(deserializer, nullptr); - deserializer->Deserialize(attr2, value2); - std::vector res; - value2.GetValue(res); - ASSERT_EQ(res, val); -} - -TEST_F(AttrSerializerUt, ListListFloatAttr) { - REG_GEIR_SERIALIZER(list_float_serializer, ListListFloatSerializer, GetTypeId>>(), - proto::AttrDef::kListListFloat); - - auto op_desc = std::make_shared(); - vector> val = {{1.2f, 1.3f, 1.4f}}; - AttrUtils::SetListListFloat(op_desc, "mem_size", val); - - ModelSerializeImp impl; - - proto::OpDef op_def; - impl.SerializeOpDesc(op_desc, &op_def); - - google::protobuf::Map attr_map = op_def.attr(); - - auto iter = attr_map.find("mem_size"); - EXPECT_TRUE(iter != attr_map.end()); - - proto::AttrDef attr2 = iter->second; - AnyValue value2; - auto *deserializer = AttrSerializerRegistry::GetInstance().GetDeserializer(proto::AttrDef::kListListFloat); - ASSERT_NE(deserializer, nullptr); - deserializer->Deserialize(attr2, value2); - std::vector> res; - value2.GetValue(res); - ASSERT_EQ(res, val); -} - -TEST_F(AttrSerializerUt, ListListIntAttr) { - REG_GEIR_SERIALIZER(list_int_serializer, ListListIntSerializer, GetTypeId>>(), - proto::AttrDef::kListListInt); - - auto op_desc = std::make_shared(); - vector> val = {{0, 1}, {-1, 1}}; - AttrUtils::SetListListInt(op_desc, "value_range", val); - - ModelSerializeImp impl; - - proto::OpDef op_def; - impl.SerializeOpDesc(op_desc, &op_def); - - google::protobuf::Map attr_map = op_def.attr(); - - auto iter = attr_map.find("value_range"); - EXPECT_TRUE(iter != attr_map.end()); - - proto::AttrDef attr2 = iter->second; - AnyValue value2; - auto *deserializer = AttrSerializerRegistry::GetInstance().GetDeserializer(proto::AttrDef::kListListInt); - ASSERT_NE(deserializer, nullptr); - deserializer->Deserialize(attr2, value2); - std::vector> res; - value2.GetValue(res); - ASSERT_EQ(res, val); -} - -TEST_F(AttrSerializerUt, SetAttrToComputeGraph) { - REG_GEIR_SERIALIZER(list_data_type_serializer, ListValueSerializer, - GetTypeId>(), proto::AttrDef::kList); - - float tensor[224 * 224] = {1.0f}; - auto computer_graph = CreateGraph_1_1_224_224(tensor); - - std::vector dts = {DT_DOUBLE, DT_BF16}; - std::vector strings = {"str1", "str2", "str3"}; - std::vector bools = {true, false, true}; - std::vector buffers = {Buffer(128), Buffer(512)}; - - AttrUtils::SetListDataType(computer_graph, "list_dt", dts); - AttrUtils::SetListStr(computer_graph, "list_str", strings); - AttrUtils::SetListBool(computer_graph, "list_bool", bools); - AttrUtils::SetListBytes(computer_graph, "list_buffer", buffers); - - ModelSerializeImp impl; - - proto::GraphDef graph_def; - impl.SerializeGraph(computer_graph, &graph_def); - - google::protobuf::Map attr_map = graph_def.attr(); - - EXPECT_TRUE(attr_map.count("list_dt") > 0); - EXPECT_TRUE(attr_map.count("list_str") > 0); - EXPECT_TRUE(attr_map.count("list_bool") > 0); - EXPECT_TRUE(attr_map.count("list_buffer") > 0); - - auto compute_graph_gen = std::make_shared("res_graph"); - impl.UnserializeGraph(compute_graph_gen, graph_def); - - std::map res_map = AttrUtils::GetAllAttrs(compute_graph_gen); - - EXPECT_TRUE(res_map.count("list_dt") > 0); - EXPECT_TRUE(res_map.count("list_str") > 0); - EXPECT_TRUE(res_map.count("list_bool") > 0); - EXPECT_TRUE(res_map.count("list_buffer") > 0); - - std::vector res_dts; - ASSERT_EQ(res_map["list_dt"].GetValue(res_dts), GRAPH_SUCCESS); - ASSERT_EQ(res_dts, dts); - - std::vector res_strs; - ASSERT_EQ(res_map["list_str"].GetValue(res_strs), GRAPH_SUCCESS); - ASSERT_EQ(res_strs, strings); - - std::vector res_bools; - ASSERT_EQ(res_map["list_bool"].GetValue(res_bools), GRAPH_SUCCESS); - ASSERT_EQ(res_bools, bools); - - std::vector res_bts; - ASSERT_EQ(res_map["list_buffer"].GetValue(res_bts), GRAPH_SUCCESS); - ASSERT_EQ(res_bts.size(), buffers.size()); -} - -TEST_F(AttrSerializerUt, SetListAttrToComputeGraph) { - REG_GEIR_SERIALIZER(list_named_attr_serializer, ListValueSerializer, - GetTypeId>(), proto::AttrDef::kList); - REG_GEIR_SERIALIZER(list_tensor_serializer, ListValueSerializer, - GetTypeId>(), proto::AttrDef::kList); - REG_GEIR_SERIALIZER(list_tensor_desc_serializer, ListValueSerializer, - GetTypeId>(), proto::AttrDef::kList); - REG_GEIR_SERIALIZER(list_graph_def_serializer, ListValueSerializer, - GetTypeId>(), proto::AttrDef::kList); - REG_GEIR_SERIALIZER(list_named_attr_serializer, NamedAttrsSerializer, - GetTypeId(), proto::AttrDef::kFunc); - REG_GEIR_SERIALIZER(list_tensor_serializer, TensorSerializer, GetTypeId(), proto::AttrDef::kT); - REG_GEIR_SERIALIZER(list_tensor_desc_serializer, TensorDescSerializer, - GetTypeId(), proto::AttrDef::kTd); - REG_GEIR_SERIALIZER(list_graph_def_serializer, GraphSerializer, - GetTypeId(), proto::AttrDef::kG); - - float tensor[224 * 224] = {1.0f}; - auto computer_graph = CreateGraph_1_1_224_224(tensor); - float tensor1[224 * 224] = {1.0f}; - GeTensorPtr ge_tensor = CreateTensor_1_1_224_224(tensor1); - std::vector ge_tensors = {*ge_tensor}; - - GeTensorDesc const1_td; - const1_td.SetShape(GeShape({1, 1, 224, 224})); - const1_td.SetOriginShape(GeShape({1, 1, 224, 224})); - const1_td.SetFormat(FORMAT_NCHW); - const1_td.SetOriginFormat(FORMAT_NCHW); - const1_td.SetDataType(DT_FLOAT); - const1_td.SetOriginDataType(DT_FLOAT); - std::vector tds = {const1_td}; - - ge::NamedAttrs named_attrs; - named_attrs.SetName("named_attr"); - std::vector attrs = {named_attrs}; - - float t[224 * 224] = {2.0f}; - auto graph_t = CreateGraph_1_1_224_224(t); - graph_t->SetName("graph_t"); - - std::vector graphs = {graph_t}; - - AttrUtils::SetListTensor(computer_graph, "list_t", ge_tensors); - AttrUtils::SetListTensorDesc(computer_graph, "list_td", tds); - AttrUtils::SetListNamedAttrs(computer_graph, "list_n", attrs); - AttrUtils::SetListGraph(computer_graph, "list_g", graphs); - - ModelSerializeImp impl; - - proto::GraphDef graph_def; - impl.SerializeGraph(computer_graph, &graph_def); - - google::protobuf::Map attr_map = graph_def.attr(); - - EXPECT_TRUE(attr_map.count("list_t") > 0); - EXPECT_TRUE(attr_map.count("list_g") > 0); - EXPECT_TRUE(attr_map.count("list_td") > 0); - EXPECT_TRUE(attr_map.size() == 4); - EXPECT_TRUE(attr_map.count("list_n") > 0); - - auto compute_graph_gen = std::make_shared("res_graph"); - impl.UnserializeGraph(compute_graph_gen, graph_def); - - std::map res_map = AttrUtils::GetAllAttrs(compute_graph_gen); - - EXPECT_TRUE(res_map.count("list_t") > 0); - EXPECT_TRUE(res_map.count("list_g") > 0); - EXPECT_TRUE(res_map.count("list_td") > 0); - EXPECT_TRUE(res_map.count("list_n") > 0); - -} - - -TEST_F(AttrSerializerUt, TdAttrInOpDesc) { - GeTensorDesc td = StandardTd_5d_1_1_224_224(); - - auto op_desc = std::make_shared(); - op_desc->AddInputDesc("x1", td); - op_desc->AddInputDesc("x2", td); - op_desc->AddOutputDesc("y", td); - AttrUtils::SetStr(op_desc, "padding", "SAME"); - - auto op_def = std::make_shared(); - EXPECT_NE(op_def, nullptr); - ModelSerializeImp imp; - EXPECT_TRUE(imp.SerializeOpDesc(op_desc, op_def.get())); - - EXPECT_EQ(op_def->attr().count("padding"), 1); - EXPECT_EQ(op_def->attr().at("padding").value_case(), proto::AttrDef::ValueCase::kS); - EXPECT_EQ(op_def->input_desc_size(), 2); - EXPECT_EQ(op_def->output_desc_size(), 1); - - ExpectStandardTdProto_5d_1_1_224_224(op_def->input_desc(0)); - ExpectStandardTdProto_5d_1_1_224_224(op_def->input_desc(1)); - ExpectStandardTdProto_5d_1_1_224_224(op_def->output_desc(0)); -} - -TEST_F(AttrSerializerUt, ConstSerializer) { - ut::GraphBuilder builder("graph1"); - auto data1 = builder.AddNode("data1", "Data", {}, {"y"}); - AttrUtils::SetInt(data1->GetOpDesc(), "index", 0); - auto const1 = builder.AddNode("const1", "Const", {}, {"y"}); - GeTensorDesc const1_td; - const1_td.SetShape(GeShape({1, 1, 224, 224})); - const1_td.SetOriginShape(GeShape({1, 1, 224, 224})); - const1_td.SetFormat(FORMAT_NCHW); - const1_td.SetOriginFormat(FORMAT_NCHW); - const1_td.SetDataType(DT_FLOAT); - const1_td.SetOriginDataType(DT_FLOAT); - GeTensor tensor(const1_td); - float tensor_data[224*224] = {1.010101, 2.020202, 3.030303}; - tensor.SetData(reinterpret_cast(tensor_data), sizeof(float) * 224 * 224); - AttrUtils::SetTensor(const1->GetOpDesc(), "value", tensor); - - ComputeGraphPtr graph = builder.GetGraph(); - - ModelSerializeImp impl; - proto::GraphDef graph_def; - impl.SerializeGraph(graph, &graph_def); - auto compute_graph_gen = std::make_shared("res_graph"); - - impl.UnserializeGraph(compute_graph_gen, graph_def); - - NodePtr res_node = compute_graph_gen->FindNode("const1"); - EXPECT_TRUE(res_node != nullptr); - - ConstGeTensorPtr res_tensor; - EXPECT_TRUE(AttrUtils::GetTensor(res_node->GetOpDesc(), "value", res_tensor)); -} - -} // namespace ge diff --git a/tests/ut/graph/testcase/attr_utils_unittest.cc b/tests/ut/graph/testcase/attr_utils_unittest.cc deleted file mode 100644 index e73c49c147646e71a84eefbdba35a32cdd60d8af..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/attr_utils_unittest.cc +++ /dev/null @@ -1,1320 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/utils/attr_utils.h" -#include "graph/op_desc.h" -#include "graph/compute_graph.h" -#include "graph/debug/ge_attr_define.h" -#include "graph_builder_utils.h" -#include "test_std_structs.h" - -namespace ge { -namespace { -std::unique_ptr GetRandomFloat(std::initializer_list shape) { - int64_t size = 1; - for (auto dim : shape) { - size *= dim; - } - auto data = std::unique_ptr(new float[size]); - for (int64_t i = 0; i < size; ++i) { - data.get()[i] = static_cast(rand()) / static_cast(RAND_MAX); - } - return data; -} - -GeTensorPtr CreateTensor_8_3_224_224(float *tensor_data) { - auto tensor = std::make_shared(); - tensor->SetData(reinterpret_cast(tensor_data), 8*3*224*224*sizeof(float)); - GeTensorDesc td; - td.SetShape(GeShape(std::vector({8, 3, 224, 224}))); - td.SetOriginShape(GeShape(std::vector({8, 3, 224, 224}))); - td.SetFormat(FORMAT_NCHW); - td.SetOriginFormat(FORMAT_NCHW); - td.SetDataType(DT_FLOAT); - td.SetOriginDataType(DT_FLOAT); - AttrUtils::SetStr(&td, "bcd", "Hello world"); - tensor->SetTensorDesc(td); - return tensor; -} - -void ExpectTensorEqual_8_3_224_224(ConstGeTensorPtr out_tensor, float *tensor_data) { - EXPECT_NE(const_cast(out_tensor->GetData().data()), reinterpret_cast(tensor_data)); - EXPECT_EQ(out_tensor->GetData().size(), 8*3*224*224*sizeof(float)); - for (size_t i = 0; i < 8*3*224*224; ++i) { - EXPECT_FLOAT_EQ(reinterpret_cast(out_tensor->GetData().data())[i], tensor_data[i]); - } - EXPECT_EQ(out_tensor->GetTensorDesc().GetShape().GetDims(), std::vector({8,3,224,224})); - EXPECT_EQ(out_tensor->GetTensorDesc().GetOriginShape().GetDims(), std::vector({8,3,224,224})); - EXPECT_EQ(out_tensor->GetTensorDesc().GetFormat(), FORMAT_NCHW); - EXPECT_EQ(out_tensor->GetTensorDesc().GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(out_tensor->GetTensorDesc().GetDataType(), DT_FLOAT); - EXPECT_EQ(out_tensor->GetTensorDesc().GetOriginDataType(), DT_FLOAT); - std::string s; - EXPECT_TRUE(AttrUtils::GetStr(&out_tensor->GetTensorDesc(), "bcd", s)); - EXPECT_EQ(s, "Hello world"); -} - -GeTensorPtr CreateTensor_5d_8_3_224_224(float *tensor_data) { - auto tensor = std::make_shared(); - tensor->SetData(reinterpret_cast(tensor_data), 8*1*224*224*16*sizeof(float)); - GeTensorDesc td; - td.SetShape(GeShape(std::vector({8, 1, 224, 224, 16}))); - td.SetOriginShape(GeShape(std::vector({8, 3, 224, 224}))); - td.SetFormat(FORMAT_NC1HWC0); - td.SetOriginFormat(FORMAT_NCHW); - td.SetDataType(DT_FLOAT); - td.SetOriginDataType(DT_FLOAT); - AttrUtils::SetStr(&td, "bcd", "Hello world"); - tensor->SetTensorDesc(td); - return tensor; -} - -void ExpectTensorEqual_5d_8_3_224_224(ConstGeTensorPtr out_tensor, float *tensor_data) { - EXPECT_NE(const_cast(out_tensor->GetData().data()), reinterpret_cast(tensor_data)); - EXPECT_EQ(out_tensor->GetData().size(), 8*1*224*224*16*sizeof(float)); - for (size_t i = 0; i < 8*1*224*224*16; ++i) { - EXPECT_FLOAT_EQ(reinterpret_cast(out_tensor->GetData().data())[i], tensor_data[i]); - } - EXPECT_EQ(out_tensor->GetTensorDesc().GetShape().GetDims(), std::vector({8,1,224,224,16})); - EXPECT_EQ(out_tensor->GetTensorDesc().GetOriginShape().GetDims(), std::vector({8,3,224,224})); - EXPECT_EQ(out_tensor->GetTensorDesc().GetFormat(), FORMAT_NC1HWC0); - EXPECT_EQ(out_tensor->GetTensorDesc().GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(out_tensor->GetTensorDesc().GetDataType(), DT_FLOAT); - EXPECT_EQ(out_tensor->GetTensorDesc().GetOriginDataType(), DT_FLOAT); - std::string s; - EXPECT_TRUE(AttrUtils::GetStr(&out_tensor->GetTensorDesc(), "bcd", s)); - EXPECT_EQ(s, "Hello world"); -} - -ComputeGraphPtr CreateGraph_1_1_224_224(float *tensor_data) { - ut::GraphBuilder builder("graph1"); - auto data1 = builder.AddNode("data1", "Data", {}, {"y"}); - AttrUtils::SetInt(data1->GetOpDesc(), "index", 0); - auto const1 = builder.AddNode("const1", "Const", {}, {"y"}); - GeTensorDesc const1_td; - const1_td.SetShape(GeShape({1,1,224,224})); - const1_td.SetOriginShape(GeShape({1,1,224,224})); - const1_td.SetFormat(FORMAT_NCHW); - const1_td.SetOriginFormat(FORMAT_NCHW); - const1_td.SetDataType(DT_FLOAT); - const1_td.SetOriginDataType(DT_FLOAT); - GeTensor tensor(const1_td); - tensor.SetData(reinterpret_cast(tensor_data), sizeof(float) * 224 * 224); - AttrUtils::SetTensor(const1->GetOpDesc(), "value", tensor); - auto add1 = builder.AddNode("add1", "Add", {"x1", "x2"}, {"y"}); - auto netoutput1 = builder.AddNode("NetOutputNode", "NetOutput", {"x"}, {}); - - builder.AddDataEdge(data1, 0, add1, 0); - builder.AddDataEdge(const1, 0, add1, 1); - builder.AddDataEdge(add1, 0, netoutput1, 0); - - return builder.GetGraph(); -} - -bool ExpectConnected(const NodePtr &src, int src_index, const NodePtr &dst, int dst_index) { - AnchorPtr src_anchor, dst_anchor; - if (src_index >= 0 && dst_index >= 0) { - src_anchor = src->GetOutDataAnchor(src_index); - dst_anchor = dst->GetInDataAnchor(dst_index); - } else if (src_index < 0 && dst_index < 0) { - src_anchor = src->GetOutControlAnchor(); - dst_anchor = dst->GetInControlAnchor(); - } else { - return false; - } - - for (auto &peer_anchor : dst_anchor->GetPeerAnchors()) { - if (src_anchor == peer_anchor) { - return true; - } - } - return false; -} - -void ExpectEqGraph_1_1_224_224(const ConstComputeGraphPtr &graph, float *tensor_data) { - EXPECT_EQ(graph->GetAllNodesSize(), 4); - auto data1 = graph->FindNode("data1"); - auto const1 = graph->FindNode("const1"); - auto add1 = graph->FindNode("add1"); - auto netoutput1 = graph->FindNode("NetOutputNode"); - EXPECT_NE(data1, nullptr); - EXPECT_NE(const1, nullptr); - EXPECT_NE(add1, nullptr); - EXPECT_NE(netoutput1, nullptr); - - int data_index = 10; - EXPECT_TRUE(AttrUtils::GetInt(data1->GetOpDesc(), "index", data_index)); - EXPECT_EQ(data_index, 0); - - EXPECT_EQ(data1->GetOpDesc()->GetName(), "data1"); - EXPECT_EQ(data1->GetType(), "Data"); - EXPECT_EQ(data1->GetOpDesc()->GetInputsSize(), 0); - EXPECT_EQ(data1->GetOpDesc()->GetOutputsSize(), 1); - EXPECT_EQ(data1->GetOpDesc()->GetOutputDesc("y").GetFormat(), FORMAT_NCHW); - EXPECT_EQ(data1->GetOpDesc()->GetOutputDesc("y").GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(data1->GetOpDesc()->GetOutputDesc("y").GetShape(), GeShape({1,1,224,224})); - EXPECT_EQ(data1->GetOpDesc()->GetOutputDesc("y").GetOriginShape(), GeShape({1,1,224,224})); - EXPECT_EQ(data1->GetOpDesc()->GetOutputDesc("y").GetDataType(), DT_FLOAT); - EXPECT_EQ(data1->GetOpDesc()->GetOutputDesc("y").GetOriginDataType(), DT_FLOAT); - - EXPECT_EQ(const1->GetOpDesc()->GetName(), "const1"); - EXPECT_EQ(const1->GetType(), "Const"); - EXPECT_EQ(const1->GetOpDesc()->GetInputsSize(), 0); - EXPECT_EQ(const1->GetOpDesc()->GetOutputsSize(), 1); - EXPECT_EQ(const1->GetOpDesc()->GetOutputDesc("y").GetFormat(), FORMAT_NCHW); - EXPECT_EQ(const1->GetOpDesc()->GetOutputDesc("y").GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(const1->GetOpDesc()->GetOutputDesc("y").GetShape(), GeShape({1,1,224,224})); - EXPECT_EQ(const1->GetOpDesc()->GetOutputDesc("y").GetOriginShape(), GeShape({1,1,224,224})); - EXPECT_EQ(const1->GetOpDesc()->GetOutputDesc("y").GetDataType(), DT_FLOAT); - EXPECT_EQ(const1->GetOpDesc()->GetOutputDesc("y").GetOriginDataType(), DT_FLOAT); - - ConstGeTensorPtr tensor; - EXPECT_TRUE(AttrUtils::GetTensor(const1->GetOpDesc(), "value", tensor)); - EXPECT_EQ(tensor->GetTensorDesc().GetFormat(), FORMAT_NCHW); - EXPECT_EQ(tensor->GetTensorDesc().GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(tensor->GetTensorDesc().GetShape(), GeShape({1,1,224,224})); - EXPECT_EQ(tensor->GetTensorDesc().GetOriginShape(), GeShape({1,1,224,224})); - EXPECT_EQ(tensor->GetTensorDesc().GetDataType(), DT_FLOAT); - EXPECT_EQ(tensor->GetTensorDesc().GetOriginDataType(), DT_FLOAT); - for (size_t i = 0; i < 224*224; ++i) { - EXPECT_FLOAT_EQ(reinterpret_cast(tensor->GetData().data())[i], tensor_data[i]); - } - - - EXPECT_EQ(add1->GetOpDesc()->GetName(), "add1"); - EXPECT_EQ(add1->GetType(), "Add"); - EXPECT_EQ(add1->GetOpDesc()->GetInputsSize(), 2); - EXPECT_EQ(add1->GetOpDesc()->GetOutputsSize(), 1); - EXPECT_EQ(add1->GetOpDesc()->GetInputDescPtr("x1")->GetFormat(), FORMAT_NCHW); - EXPECT_EQ(add1->GetOpDesc()->GetInputDescPtr("x1")->GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(add1->GetOpDesc()->GetInputDescPtr("x1")->GetShape(), GeShape({1,1,224,224})); - EXPECT_EQ(add1->GetOpDesc()->GetInputDescPtr("x1")->GetOriginShape(), GeShape({1,1,224,224})); - EXPECT_EQ(add1->GetOpDesc()->GetInputDescPtr("x1")->GetDataType(), DT_FLOAT); - EXPECT_EQ(add1->GetOpDesc()->GetInputDescPtr("x1")->GetOriginDataType(), DT_FLOAT); - EXPECT_EQ(add1->GetOpDesc()->GetInputDescPtr("x2")->GetFormat(), FORMAT_NCHW); - EXPECT_EQ(add1->GetOpDesc()->GetInputDescPtr("x2")->GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(add1->GetOpDesc()->GetInputDescPtr("x2")->GetShape(), GeShape({1,1,224,224})); - EXPECT_EQ(add1->GetOpDesc()->GetInputDescPtr("x2")->GetOriginShape(), GeShape({1,1,224,224})); - EXPECT_EQ(add1->GetOpDesc()->GetInputDescPtr("x2")->GetDataType(), DT_FLOAT); - EXPECT_EQ(add1->GetOpDesc()->GetInputDescPtr("x2")->GetOriginDataType(), DT_FLOAT); - EXPECT_EQ(add1->GetOpDesc()->GetOutputDesc("y").GetFormat(), FORMAT_NCHW); - EXPECT_EQ(add1->GetOpDesc()->GetOutputDesc("y").GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(add1->GetOpDesc()->GetOutputDesc("y").GetShape(), GeShape({1,1,224,224})); - EXPECT_EQ(add1->GetOpDesc()->GetOutputDesc("y").GetOriginShape(), GeShape({1,1,224,224})); - EXPECT_EQ(add1->GetOpDesc()->GetOutputDesc("y").GetDataType(), DT_FLOAT); - EXPECT_EQ(add1->GetOpDesc()->GetOutputDesc("y").GetOriginDataType(), DT_FLOAT); - - EXPECT_EQ(netoutput1->GetOpDesc()->GetName(), "NetOutputNode"); - EXPECT_EQ(netoutput1->GetType(), "NetOutput"); - EXPECT_EQ(netoutput1->GetOpDesc()->GetInputsSize(), 1); - EXPECT_EQ(netoutput1->GetOpDesc()->GetOutputsSize(), 0); - EXPECT_EQ(netoutput1->GetOpDesc()->GetInputDescPtr("x")->GetFormat(), FORMAT_NCHW); - EXPECT_EQ(netoutput1->GetOpDesc()->GetInputDescPtr("x")->GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(netoutput1->GetOpDesc()->GetInputDescPtr("x")->GetShape(), GeShape({1,1,224,224})); - EXPECT_EQ(netoutput1->GetOpDesc()->GetInputDescPtr("x")->GetOriginShape(), GeShape({1,1,224,224})); - EXPECT_EQ(netoutput1->GetOpDesc()->GetInputDescPtr("x")->GetDataType(), DT_FLOAT); - EXPECT_EQ(netoutput1->GetOpDesc()->GetInputDescPtr("x")->GetOriginDataType(), DT_FLOAT); - - EXPECT_EQ(data1->GetOutNodes().size(), 1); - EXPECT_TRUE(ExpectConnected(data1, 0, add1, 0)); - EXPECT_EQ(const1->GetOutNodes().size(), 1); - EXPECT_TRUE(ExpectConnected(const1, 0, add1, 1)); - EXPECT_EQ(add1->GetOutNodes().size(), 1); - EXPECT_TRUE(ExpectConnected(add1, 0, netoutput1, 0)); -} - -ComputeGraphPtr CreateGraph_5d_1_1_224_224(float *tensor_data) { - ut::GraphBuilder builder("graph1"); - auto data1 = builder.AddNode("data1", "Data", {}, {"y"}); - AttrUtils::SetInt(data1->GetOpDesc(), "index", 0); - data1->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NC1HWC0); - data1->GetOpDesc()->MutableOutputDesc(0)->SetShape(GeShape({1,1,224,224,16})); - - auto const1 = builder.AddNode("const1", "Const", {}, {"y"}); - const1->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NC1HWC0); - const1->GetOpDesc()->MutableOutputDesc(0)->SetShape(GeShape({1,1,224,224,16})); - GeTensorDesc const1_td; - const1_td.SetShape(GeShape({1,1,224,224})); - const1_td.SetOriginShape(GeShape({1,1,224,224})); - const1_td.SetFormat(FORMAT_NCHW); - const1_td.SetOriginFormat(FORMAT_NCHW); - const1_td.SetDataType(DT_FLOAT); - const1_td.SetOriginDataType(DT_FLOAT); - GeTensor tensor(const1_td); - tensor.SetData(reinterpret_cast(tensor_data), sizeof(float) * 224 * 224); - AttrUtils::SetTensor(const1->GetOpDesc(), "value", tensor); - - auto add1 = builder.AddNode("add1", "Add", {"x1", "x2"}, {"y"}); - add1->GetOpDesc()->MutableInputDesc(0)->SetFormat(FORMAT_NC1HWC0); - add1->GetOpDesc()->MutableInputDesc(0)->SetShape(GeShape({1,1,224,224,16})); - add1->GetOpDesc()->MutableInputDesc(1)->SetFormat(FORMAT_NC1HWC0); - add1->GetOpDesc()->MutableInputDesc(1)->SetShape(GeShape({1,1,224,224,16})); - add1->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NC1HWC0); - add1->GetOpDesc()->MutableOutputDesc(0)->SetShape(GeShape({1,1,224,224,16})); - auto netoutput1 = builder.AddNode("NetOutputNode", "NetOutput", {"x"}, {}); - netoutput1->GetOpDesc()->MutableInputDesc(0)->SetFormat(FORMAT_NC1HWC0); - netoutput1->GetOpDesc()->MutableInputDesc(0)->SetShape(GeShape({1,1,224,224,16})); - - builder.AddDataEdge(data1, 0, add1, 0); - builder.AddDataEdge(const1, 0, add1, 1); - builder.AddDataEdge(add1, 0, netoutput1, 0); - - return builder.GetGraph(); -} - -void ExpectEqGraph_5d_1_1_224_224(const ConstComputeGraphPtr &graph, float *tensor_data) { - EXPECT_EQ(graph->GetAllNodesSize(), 4); - auto data1 = graph->FindNode("data1"); - auto const1 = graph->FindNode("const1"); - auto add1 = graph->FindNode("add1"); - auto netoutput1 = graph->FindNode("NetOutputNode"); - EXPECT_NE(data1, nullptr); - EXPECT_NE(const1, nullptr); - EXPECT_NE(add1, nullptr); - EXPECT_NE(netoutput1, nullptr); - /* todo 属性当前不支持序列化,支持序列化后,放开校验 - int data_index = 10; - EXPECT_TRUE(AttrUtils::GetInt(data1->GetOpDesc(), "index", data_index)); - EXPECT_EQ(data_index, 0); -*/ - EXPECT_EQ(data1->GetOpDesc()->GetName(), "data1"); - EXPECT_EQ(data1->GetType(), "Data"); - EXPECT_EQ(data1->GetOpDesc()->GetInputsSize(), 0); - EXPECT_EQ(data1->GetOpDesc()->GetOutputsSize(), 1); - EXPECT_EQ(data1->GetOpDesc()->GetOutputDesc("y").GetFormat(), FORMAT_NC1HWC0); - EXPECT_EQ(data1->GetOpDesc()->GetOutputDesc("y").GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(data1->GetOpDesc()->GetOutputDesc("y").GetShape(), GeShape({1,1,224,224,16})); - EXPECT_EQ(data1->GetOpDesc()->GetOutputDesc("y").GetOriginShape(), GeShape({1,1,224,224})); - EXPECT_EQ(data1->GetOpDesc()->GetOutputDesc("y").GetDataType(), DT_FLOAT); - EXPECT_EQ(data1->GetOpDesc()->GetOutputDesc("y").GetOriginDataType(), DT_FLOAT); - - EXPECT_EQ(const1->GetOpDesc()->GetName(), "const1"); - EXPECT_EQ(const1->GetType(), "Const"); - EXPECT_EQ(const1->GetOpDesc()->GetInputsSize(), 0); - EXPECT_EQ(const1->GetOpDesc()->GetOutputsSize(), 1); - EXPECT_EQ(const1->GetOpDesc()->GetOutputDesc("y").GetFormat(), FORMAT_NC1HWC0); - EXPECT_EQ(const1->GetOpDesc()->GetOutputDesc("y").GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(const1->GetOpDesc()->GetOutputDesc("y").GetShape(), GeShape({1,1,224,224,16})); - EXPECT_EQ(const1->GetOpDesc()->GetOutputDesc("y").GetOriginShape(), GeShape({1,1,224,224})); - EXPECT_EQ(const1->GetOpDesc()->GetOutputDesc("y").GetDataType(), DT_FLOAT); - EXPECT_EQ(const1->GetOpDesc()->GetOutputDesc("y").GetOriginDataType(), DT_FLOAT); -// -// ConstGeTensorPtr tensor; -// EXPECT_TRUE(AttrUtils::GetTensor(const1->GetOpDesc(), "value", tensor)); -// EXPECT_EQ(tensor->GetTensorDesc().GetFormat(), FORMAT_NCHW); -// EXPECT_EQ(tensor->GetTensorDesc().GetOriginFormat(), FORMAT_NCHW); -// EXPECT_EQ(tensor->GetTensorDesc().GetShape(), GeShape({1,1,224,224})); -// EXPECT_EQ(tensor->GetTensorDesc().GetOriginShape(), GeShape({1,1,224,224})); -// EXPECT_EQ(tensor->GetTensorDesc().GetDataType(), DT_FLOAT); -// EXPECT_EQ(tensor->GetTensorDesc().GetOriginDataType(), DT_FLOAT); -// for (size_t i = 0; i < 224*224; ++i) { -// EXPECT_FLOAT_EQ(reinterpret_cast(tensor->GetData().data())[i], tensor_data[i]); -// } - - - EXPECT_EQ(add1->GetOpDesc()->GetName(), "add1"); - EXPECT_EQ(add1->GetType(), "Add"); - EXPECT_EQ(add1->GetOpDesc()->GetInputsSize(), 2); - EXPECT_EQ(add1->GetOpDesc()->GetOutputsSize(), 1); - EXPECT_EQ(add1->GetOpDesc()->GetInputDescPtr("x1")->GetFormat(), FORMAT_NC1HWC0); - EXPECT_EQ(add1->GetOpDesc()->GetInputDescPtr("x1")->GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(add1->GetOpDesc()->GetInputDescPtr("x1")->GetShape(), GeShape({1,1,224,224,16})); - EXPECT_EQ(add1->GetOpDesc()->GetInputDescPtr("x1")->GetOriginShape(), GeShape({1,1,224,224})); - EXPECT_EQ(add1->GetOpDesc()->GetInputDescPtr("x1")->GetDataType(), DT_FLOAT); - EXPECT_EQ(add1->GetOpDesc()->GetInputDescPtr("x1")->GetOriginDataType(), DT_FLOAT); - EXPECT_EQ(add1->GetOpDesc()->GetInputDescPtr("x2")->GetFormat(), FORMAT_NC1HWC0); - EXPECT_EQ(add1->GetOpDesc()->GetInputDescPtr("x2")->GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(add1->GetOpDesc()->GetInputDescPtr("x2")->GetShape(), GeShape({1,1,224,224,16})); - EXPECT_EQ(add1->GetOpDesc()->GetInputDescPtr("x2")->GetOriginShape(), GeShape({1,1,224,224})); - EXPECT_EQ(add1->GetOpDesc()->GetInputDescPtr("x2")->GetDataType(), DT_FLOAT); - EXPECT_EQ(add1->GetOpDesc()->GetInputDescPtr("x2")->GetOriginDataType(), DT_FLOAT); - EXPECT_EQ(add1->GetOpDesc()->GetOutputDesc("y").GetFormat(), FORMAT_NC1HWC0); - EXPECT_EQ(add1->GetOpDesc()->GetOutputDesc("y").GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(add1->GetOpDesc()->GetOutputDesc("y").GetShape(), GeShape({1,1,224,224,16})); - EXPECT_EQ(add1->GetOpDesc()->GetOutputDesc("y").GetOriginShape(), GeShape({1,1,224,224})); - EXPECT_EQ(add1->GetOpDesc()->GetOutputDesc("y").GetDataType(), DT_FLOAT); - EXPECT_EQ(add1->GetOpDesc()->GetOutputDesc("y").GetOriginDataType(), DT_FLOAT); - - EXPECT_EQ(netoutput1->GetOpDesc()->GetName(), "NetOutputNode"); - EXPECT_EQ(netoutput1->GetType(), "NetOutput"); - EXPECT_EQ(netoutput1->GetOpDesc()->GetInputsSize(), 1); - EXPECT_EQ(netoutput1->GetOpDesc()->GetOutputsSize(), 0); - EXPECT_EQ(netoutput1->GetOpDesc()->GetInputDescPtr("x")->GetFormat(), FORMAT_NC1HWC0); - EXPECT_EQ(netoutput1->GetOpDesc()->GetInputDescPtr("x")->GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(netoutput1->GetOpDesc()->GetInputDescPtr("x")->GetShape(), GeShape({1,1,224,224,16})); - EXPECT_EQ(netoutput1->GetOpDesc()->GetInputDescPtr("x")->GetOriginShape(), GeShape({1,1,224,224})); - EXPECT_EQ(netoutput1->GetOpDesc()->GetInputDescPtr("x")->GetDataType(), DT_FLOAT); - EXPECT_EQ(netoutput1->GetOpDesc()->GetInputDescPtr("x")->GetOriginDataType(), DT_FLOAT); - - EXPECT_EQ(data1->GetOutNodes().size(), 1); - EXPECT_TRUE(ExpectConnected(data1, 0, add1, 0)); - EXPECT_EQ(const1->GetOutNodes().size(), 1); - EXPECT_TRUE(ExpectConnected(const1, 0, add1, 1)); - EXPECT_EQ(add1->GetOutNodes().size(), 1); - EXPECT_TRUE(ExpectConnected(add1, 0, netoutput1, 0)); -} -} -class AttrUtilsUt : public testing::Test {}; - -TEST_F(AttrUtilsUt, HasAttrOk) { - auto op_desc = std::make_shared(); - EXPECT_FALSE(AttrUtils::HasAttr(op_desc, "abc")); - EXPECT_FALSE(AttrUtils::HasAttr(op_desc, "bcd")); - - EXPECT_TRUE(AttrUtils::SetInt(op_desc, "abc", 10)); - EXPECT_TRUE(AttrUtils::SetStr(op_desc, "bcd", "hello")); - - EXPECT_TRUE(AttrUtils::HasAttr(op_desc, "abc")); - EXPECT_TRUE(AttrUtils::HasAttr(op_desc, "bcd")); -} - -TEST_F(AttrUtilsUt, SetGetIntOk) { - auto op_desc = std::make_shared(); - - EXPECT_TRUE(AttrUtils::SetInt(op_desc, "abc", 10)); - EXPECT_TRUE(AttrUtils::SetInt(op_desc, "bcd", 0xffffffffffff)); - - int64_t i64; - int32_t i32; - uint32_t ui32; - - EXPECT_TRUE(AttrUtils::GetInt(op_desc, "abc", i64)); - EXPECT_TRUE(AttrUtils::GetInt(op_desc, "abc", i32)); - EXPECT_TRUE(AttrUtils::GetInt(op_desc, "abc", ui32)); - EXPECT_EQ(i64, 10); - EXPECT_EQ(i32, 10); - EXPECT_EQ(ui32, 10); - - EXPECT_TRUE(AttrUtils::GetInt(op_desc, "bcd", i64)); - EXPECT_EQ(i64, 0xffffffffffff); - EXPECT_FALSE(AttrUtils::GetInt(op_desc, "bcd", i32)); - EXPECT_FALSE(AttrUtils::GetInt(op_desc, "bcd", ui32)); -} - -TEST_F(AttrUtilsUt, SetGetInt_ExceedsLimit) { - auto op_desc = std::make_shared(); - - EXPECT_TRUE(AttrUtils::SetInt(op_desc, "bcd", 0xffffffff)); - - int64_t i64; - int32_t i32; - uint32_t ui32; - - EXPECT_TRUE(AttrUtils::GetInt(op_desc, "bcd", i64)); - EXPECT_FALSE(AttrUtils::GetInt(op_desc, "bcd", i32)); - EXPECT_TRUE(AttrUtils::GetInt(op_desc, "bcd", ui32)); - EXPECT_EQ(i64, 0xffffffff); - EXPECT_EQ(ui32, 0xffffffff); -} - -TEST_F(AttrUtilsUt, SetGetListIntOk1) { - auto op_desc = std::make_shared(); - - EXPECT_TRUE(AttrUtils::SetListInt(op_desc, "abc2", std::vector({1,2,3}))); - - std::vector i64; - std::vector i32; - std::vector ui32; - - EXPECT_TRUE(AttrUtils::GetListInt(op_desc, "abc2", i64)); - EXPECT_TRUE(AttrUtils::GetListInt(op_desc, "abc2", i32)); - EXPECT_TRUE(AttrUtils::GetListInt(op_desc, "abc2", ui32)); - EXPECT_EQ(i64, std::vector({1,2,3})); - EXPECT_EQ(i32, std::vector({1,2,3})); - EXPECT_EQ(ui32, std::vector({1,2,3})); -} - -TEST_F(AttrUtilsUt, SetGetListIntOk2) { - auto op_desc = std::make_shared(); - - EXPECT_TRUE(AttrUtils::SetListInt(op_desc, "abc2", std::vector({1,2,3}))); - - std::vector i64; - std::vector i32; - std::vector ui32; - - EXPECT_TRUE(AttrUtils::GetListInt(op_desc, "abc2", i64)); - EXPECT_TRUE(AttrUtils::GetListInt(op_desc, "abc2", i32)); - EXPECT_TRUE(AttrUtils::GetListInt(op_desc, "abc2", ui32)); - EXPECT_EQ(i64, std::vector({1,2,3})); - EXPECT_EQ(i32, std::vector({1,2,3})); - EXPECT_EQ(ui32, std::vector({1,2,3})); -} - -TEST_F(AttrUtilsUt, SetGetListIntOk3) { - auto op_desc = std::make_shared(); - - EXPECT_TRUE(AttrUtils::SetListInt(op_desc, "abc2", std::vector({1,2,3}))); - - std::vector i64; - std::vector i32; - std::vector ui32; - - EXPECT_TRUE(AttrUtils::GetListInt(op_desc, "abc2", i64)); - EXPECT_TRUE(AttrUtils::GetListInt(op_desc, "abc2", i32)); - EXPECT_TRUE(AttrUtils::GetListInt(op_desc, "abc2", ui32)); - EXPECT_EQ(i64, std::vector({1,2,3})); - EXPECT_EQ(i32, std::vector({1,2,3})); - EXPECT_EQ(ui32, std::vector({1,2,3})); -} - -TEST_F(AttrUtilsUt, SetGetListIntOk4) { - auto op_desc = std::make_shared(); - - EXPECT_TRUE(AttrUtils::SetListInt(op_desc, "abc2", {1,2,3})); - - std::vector i64; - std::vector i32; - std::vector ui32; - - EXPECT_TRUE(AttrUtils::GetListInt(op_desc, "abc2", i64)); - EXPECT_TRUE(AttrUtils::GetListInt(op_desc, "abc2", i32)); - EXPECT_TRUE(AttrUtils::GetListInt(op_desc, "abc2", ui32)); - EXPECT_EQ(i64, std::vector({1,2,3})); - EXPECT_EQ(i32, std::vector({1,2,3})); - EXPECT_EQ(ui32, std::vector({1,2,3})); -} - -TEST_F(AttrUtilsUt, SetGetListInt_ExceedsLimit1) { - auto op_desc = std::make_shared(); - - EXPECT_TRUE(AttrUtils::SetListInt(op_desc, "abc2", {1,2,3, 0xffffffffffff})); - - std::vector i64; - std::vector i32; - std::vector ui32; - - EXPECT_TRUE(AttrUtils::GetListInt(op_desc, "abc2", i64)); - EXPECT_FALSE(AttrUtils::GetListInt(op_desc, "abc2", i32)); - EXPECT_FALSE(AttrUtils::GetListInt(op_desc, "abc2", ui32)); - EXPECT_EQ(i64, std::vector({1,2,3,0xffffffffffff})); -} - -TEST_F(AttrUtilsUt, SetGetListInt_ExceedsLimit2) { - auto op_desc = std::make_shared(); - - EXPECT_TRUE(AttrUtils::SetListInt(op_desc, "abc2", {1,2,3, 0xffffffff})); - - std::vector i64; - std::vector i32; - std::vector ui32; - - EXPECT_TRUE(AttrUtils::GetListInt(op_desc, "abc2", i64)); - EXPECT_FALSE(AttrUtils::GetListInt(op_desc, "abc2", i32)); - EXPECT_TRUE(AttrUtils::GetListInt(op_desc, "abc2", ui32)); - EXPECT_EQ(i64, std::vector({1,2,3,0xffffffff})); - EXPECT_EQ(ui32, std::vector({1,2,3,0xffffffff})); -} - -TEST_F(AttrUtilsUt, SetGetListInt_ExceedsLimit3) { - auto op_desc = std::make_shared(); - - EXPECT_TRUE(AttrUtils::SetListInt(op_desc, "abc2", {1,2,3, -1})); - - std::vector i64; - std::vector i32; - //std::vector ui32; - - EXPECT_TRUE(AttrUtils::GetListInt(op_desc, "abc2", i64)); - EXPECT_TRUE(AttrUtils::GetListInt(op_desc, "abc2", i32)); - //EXPECT_FALSE(AttrUtils::GetListInt(op_desc, "abc2", ui32)); - EXPECT_EQ(i64, std::vector({1,2,3,-1})); - //EXPECT_EQ(i32, std::vector({1,2,3,-1})); -} - -TEST_F(AttrUtilsUt, SetGetFloatOk) { - auto op_desc = std::make_shared(); - - EXPECT_TRUE(AttrUtils::SetFloat(op_desc, "abc", 3.1415926)); - float f; - EXPECT_TRUE(AttrUtils::GetFloat(op_desc, "abc", f)); - EXPECT_FLOAT_EQ(f, 3.1415926); -} - -TEST_F(AttrUtilsUt, SetGetListFloatOk) { - auto op_desc = std::make_shared(); - - EXPECT_TRUE(AttrUtils::SetListFloat(op_desc, "abc", std::vector({3.1415,4.1415,5.1415926}))); - std::vector f; - EXPECT_TRUE(AttrUtils::GetListFloat(op_desc, "abc", f)); - EXPECT_EQ(f.size(), 3); - EXPECT_FLOAT_EQ(f[0], 3.1415); - EXPECT_FLOAT_EQ(f[1], 4.1415); - EXPECT_FLOAT_EQ(f[2], 5.1415926); -} - -TEST_F(AttrUtilsUt, SetGetBoolOk) { - auto op_desc = std::make_shared(); - - EXPECT_TRUE(AttrUtils::SetBool(op_desc, "abc", true)); - EXPECT_TRUE(AttrUtils::SetBool(op_desc, "bcd", false)); - bool b1 = false, b2 = true; - EXPECT_TRUE(AttrUtils::GetBool(op_desc, "abc", b1)); - EXPECT_TRUE(AttrUtils::GetBool(op_desc, "bcd", b2)); - EXPECT_TRUE(b1); - EXPECT_FALSE(b2); -} - -TEST_F(AttrUtilsUt, SetGetListBoolOk) { - auto op_desc = std::make_shared(); - - EXPECT_TRUE(AttrUtils::SetListBool(op_desc, "abc", std::vector({true,false,false,true}))); - std::vector b; - EXPECT_TRUE(AttrUtils::GetListBool(op_desc, "abc", b)); - EXPECT_EQ(b.size(), 4); - EXPECT_TRUE(b[0]); - EXPECT_FALSE(b[1]); - EXPECT_FALSE(b[2]); - EXPECT_TRUE(b[3]); -} - -TEST_F(AttrUtilsUt, SetGetStrOk) { - auto op_desc = std::make_shared(); - - EXPECT_TRUE(AttrUtils::SetStr(op_desc, "abc", "Hello")); - EXPECT_TRUE(AttrUtils::SetStr(op_desc, "bcd", "World")); - std::string s1, s2; - EXPECT_TRUE(AttrUtils::GetStr(op_desc, "abc", s1)); - EXPECT_TRUE(AttrUtils::GetStr(op_desc, "bcd", s2)); - EXPECT_EQ(s1, "Hello"); - EXPECT_EQ(s2, "World"); -} - -TEST_F(AttrUtilsUt, SetGetListStrOk) { - auto op_desc = std::make_shared(); - - EXPECT_TRUE(AttrUtils::SetListStr(op_desc, "abc", std::vector({"Hello", "world", "!"}))); - std::vector s; - EXPECT_TRUE(AttrUtils::GetListStr(op_desc, "abc", s)); - EXPECT_EQ(s, std::vector({"Hello", "world", "!"})); -} - -TEST_F(AttrUtilsUt, SetGetTensorDescOk) { - auto op_desc = std::make_shared(); - GeTensorDesc td; - td.SetShape(GeShape(std::vector({8,1,128,128,16}))); - td.SetOriginShape(GeShape(std::vector({8,3,128,128}))); - td.SetFormat(FORMAT_NC1HWC0); - td.SetOriginFormat(FORMAT_NCHW); - td.SetDataType(DT_FLOAT16); - td.SetOriginDataType(DT_FLOAT); - AttrUtils::SetStr(&td, "bcd", "Hello world"); - - EXPECT_TRUE(AttrUtils::SetTensorDesc(op_desc, "abc", td)); - - GeTensorDesc td1; - EXPECT_TRUE(AttrUtils::GetTensorDesc(op_desc, "abc", td1)); - EXPECT_EQ(td1.GetShape().GetDims(), std::vector({8,1,128,128,16})); - EXPECT_EQ(td1.GetOriginShape().GetDims(), std::vector({8,3,128,128})); - EXPECT_EQ(td1.GetFormat(), FORMAT_NC1HWC0); - EXPECT_EQ(td1.GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(td1.GetDataType(), DT_FLOAT16); - EXPECT_EQ(td1.GetOriginDataType(), DT_FLOAT); - std::string s; - EXPECT_TRUE(AttrUtils::GetStr(&td1, "bcd", s)); - EXPECT_EQ(s, "Hello world"); -} - -TEST_F(AttrUtilsUt, SetGetTensorDescOk_CopyValidation1) { - auto op_desc = std::make_shared(); - GeTensorDesc td; - td.SetShape(GeShape(std::vector({8,1,128,128,16}))); - td.SetOriginShape(GeShape(std::vector({8,3,128,128}))); - td.SetFormat(FORMAT_NC1HWC0); - td.SetOriginFormat(FORMAT_NCHW); - td.SetDataType(DT_FLOAT16); - td.SetOriginDataType(DT_FLOAT); - AttrUtils::SetStr(&td, "bcd", "Hello world"); - - EXPECT_TRUE(AttrUtils::SetTensorDesc(op_desc, "abc", td)); - td.SetShape(GeShape(std::vector({1}))); - td.SetOriginShape(GeShape(std::vector({8}))); - td.SetFormat(FORMAT_ND); - td.SetOriginFormat(FORMAT_ND); - td.SetDataType(DT_INT16); - td.SetOriginDataType(DT_INT16); - AttrUtils::SetStr(&td, "bcd", "adasdfasdf"); - - GeTensorDesc td1; - EXPECT_TRUE(AttrUtils::GetTensorDesc(op_desc, "abc", td1)); - EXPECT_EQ(td1.GetShape().GetDims(), std::vector({8,1,128,128,16})); - EXPECT_EQ(td1.GetOriginShape().GetDims(), std::vector({8,3,128,128})); - EXPECT_EQ(td1.GetFormat(), FORMAT_NC1HWC0); - EXPECT_EQ(td1.GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(td1.GetDataType(), DT_FLOAT16); - EXPECT_EQ(td1.GetOriginDataType(), DT_FLOAT); - std::string s; - EXPECT_TRUE(AttrUtils::GetStr(&td1, "bcd", s)); - EXPECT_EQ(s, "Hello world"); -} - -TEST_F(AttrUtilsUt, SetGetTensorDescOk_CopyValidation2) { - auto op_desc = std::make_shared(); - GeTensorDesc td; - td.SetShape(GeShape(std::vector({8,1,128,128,16}))); - td.SetOriginShape(GeShape(std::vector({8,3,128,128}))); - td.SetFormat(FORMAT_NC1HWC0); - td.SetOriginFormat(FORMAT_NCHW); - td.SetDataType(DT_FLOAT16); - td.SetOriginDataType(DT_FLOAT); - AttrUtils::SetStr(&td, "bcd", "Hello world"); - - EXPECT_TRUE(AttrUtils::SetTensorDesc(op_desc, "abc", td)); - - GeTensorDesc td1; - EXPECT_TRUE(AttrUtils::GetTensorDesc(op_desc, "abc", td1)); - td1.SetShape(GeShape(std::vector({1}))); - td1.SetOriginShape(GeShape(std::vector({8}))); - td1.SetFormat(FORMAT_ND); - td1.SetOriginFormat(FORMAT_ND); - td1.SetDataType(DT_INT16); - td1.SetOriginDataType(DT_INT16); - AttrUtils::SetStr(&td1, "bcd", "adasdfasdf"); - - EXPECT_EQ(td.GetShape().GetDims(), std::vector({8,1,128,128,16})); - EXPECT_EQ(td.GetOriginShape().GetDims(), std::vector({8,3,128,128})); - EXPECT_EQ(td.GetFormat(), FORMAT_NC1HWC0); - EXPECT_EQ(td.GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(td.GetDataType(), DT_FLOAT16); - EXPECT_EQ(td.GetOriginDataType(), DT_FLOAT); - std::string s; - EXPECT_TRUE(AttrUtils::GetStr(&td, "bcd", s)); - EXPECT_EQ(s, "Hello world"); -} - -TEST_F(AttrUtilsUt, SetGetListTensorDescOk) { - auto op_desc = std::make_shared(); - std::vector tds(5); - for (auto &td : tds) { - td.SetShape(GeShape(std::vector({8,1,128,128,16}))); - td.SetOriginShape(GeShape(std::vector({8,3,128,128}))); - td.SetFormat(FORMAT_NC1HWC0); - td.SetOriginFormat(FORMAT_NCHW); - td.SetDataType(DT_FLOAT16); - td.SetOriginDataType(DT_FLOAT); - AttrUtils::SetStr(&td, "bcd", "Hello world"); - } - - EXPECT_TRUE(AttrUtils::SetListTensorDesc(op_desc, "abc", tds)); - - std::vector tds1; - EXPECT_TRUE(AttrUtils::GetListTensorDesc(op_desc, "abc", tds1)); - for (auto &td1 : tds1) { - EXPECT_EQ(td1.GetShape().GetDims(), std::vector({8,1,128,128,16})); - EXPECT_EQ(td1.GetOriginShape().GetDims(), std::vector({8,3,128,128})); - EXPECT_EQ(td1.GetFormat(), FORMAT_NC1HWC0); - EXPECT_EQ(td1.GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(td1.GetDataType(), DT_FLOAT16); - EXPECT_EQ(td1.GetOriginDataType(), DT_FLOAT); - std::string s; - EXPECT_TRUE(AttrUtils::GetStr(&td1, "bcd", s)); - EXPECT_EQ(s, "Hello world"); - } -} - -TEST_F(AttrUtilsUt, SetGetListTensorDescOk_CopyValidation1) { - auto op_desc = std::make_shared(); - std::vector tds(5); - for (auto &td : tds) { - td.SetShape(GeShape(std::vector({8,1,128,128,16}))); - td.SetOriginShape(GeShape(std::vector({8,3,128,128}))); - td.SetFormat(FORMAT_NC1HWC0); - td.SetOriginFormat(FORMAT_NCHW); - td.SetDataType(DT_FLOAT16); - td.SetOriginDataType(DT_FLOAT); - AttrUtils::SetStr(&td, "bcd", "Hello world"); - } - - EXPECT_TRUE(AttrUtils::SetListTensorDesc(op_desc, "abc", tds)); - for (auto &td : tds) { - td.SetShape(GeShape(std::vector({1}))); - td.SetOriginShape(GeShape(std::vector({8}))); - td.SetFormat(FORMAT_ND); - td.SetOriginFormat(FORMAT_ND); - td.SetDataType(DT_INT16); - td.SetOriginDataType(DT_INT16); - AttrUtils::SetStr(&td, "bcd", "adasdfasdf"); - } - - std::vector tds1; - EXPECT_TRUE(AttrUtils::GetListTensorDesc(op_desc, "abc", tds1)); - for (auto &td1 : tds1) { - EXPECT_EQ(td1.GetShape().GetDims(), std::vector({8,1,128,128,16})); - EXPECT_EQ(td1.GetOriginShape().GetDims(), std::vector({8,3,128,128})); - EXPECT_EQ(td1.GetFormat(), FORMAT_NC1HWC0); - EXPECT_EQ(td1.GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(td1.GetDataType(), DT_FLOAT16); - EXPECT_EQ(td1.GetOriginDataType(), DT_FLOAT); - std::string s; - EXPECT_TRUE(AttrUtils::GetStr(&td1, "bcd", s)); - EXPECT_EQ(s, "Hello world"); - } -} - -TEST_F(AttrUtilsUt, SetGetListTensorDescOk_CopyValidation2) { - auto op_desc = std::make_shared(); - std::vector tds(5); - for (auto &td : tds) { - td.SetShape(GeShape(std::vector({8,1,128,128,16}))); - td.SetOriginShape(GeShape(std::vector({8,3,128,128}))); - td.SetFormat(FORMAT_NC1HWC0); - td.SetOriginFormat(FORMAT_NCHW); - td.SetDataType(DT_FLOAT16); - td.SetOriginDataType(DT_FLOAT); - AttrUtils::SetStr(&td, "bcd", "Hello world"); - } - - EXPECT_TRUE(AttrUtils::SetListTensorDesc(op_desc, "abc", tds)); - std::vector tds1; - EXPECT_TRUE(AttrUtils::GetListTensorDesc(op_desc, "abc", tds1)); - for (auto &td1 : tds1) { - td1.SetShape(GeShape(std::vector({1}))); - td1.SetOriginShape(GeShape(std::vector({8}))); - td1.SetFormat(FORMAT_ND); - td1.SetOriginFormat(FORMAT_ND); - td1.SetDataType(DT_INT16); - td1.SetOriginDataType(DT_INT16); - AttrUtils::SetStr(&td1, "bcd", "adasdfasdf"); - } - for (auto &td : tds) { - EXPECT_EQ(td.GetShape().GetDims(), std::vector({8,1,128,128,16})); - EXPECT_EQ(td.GetOriginShape().GetDims(), std::vector({8,3,128,128})); - EXPECT_EQ(td.GetFormat(), FORMAT_NC1HWC0); - EXPECT_EQ(td.GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(td.GetDataType(), DT_FLOAT16); - EXPECT_EQ(td.GetOriginDataType(), DT_FLOAT); - std::string s; - EXPECT_TRUE(AttrUtils::GetStr(&td, "bcd", s)); - EXPECT_EQ(s, "Hello world"); - } -} - -TEST_F(AttrUtilsUt, SetGetTensorOk1) { - auto op_desc = std::make_shared(); - auto tensor_data = GetRandomFloat({8, 3, 224, 224}); - { - auto tensor = CreateTensor_8_3_224_224(tensor_data.get()); - ConstGeTensorPtr tensor1 = tensor; - - EXPECT_TRUE(AttrUtils::SetTensor(op_desc, "abc", tensor)); - EXPECT_TRUE(AttrUtils::SetTensor(op_desc, "bcd", *tensor)); - EXPECT_TRUE(AttrUtils::SetTensor(op_desc, "cde", tensor1)); - } - - ConstGeTensorPtr out_tensor; - EXPECT_TRUE(AttrUtils::GetTensor(op_desc, "abc", out_tensor)); - EXPECT_NE(out_tensor, nullptr); - ExpectTensorEqual_8_3_224_224(out_tensor, tensor_data.get()); - - EXPECT_TRUE(AttrUtils::GetTensor(op_desc, "bcd", out_tensor)); - EXPECT_NE(out_tensor, nullptr); - ExpectTensorEqual_8_3_224_224(out_tensor, tensor_data.get()); - - EXPECT_TRUE(AttrUtils::GetTensor(op_desc, "cde", out_tensor)); - EXPECT_NE(out_tensor, nullptr); - ExpectTensorEqual_8_3_224_224(out_tensor, tensor_data.get()); -} - -TEST_F(AttrUtilsUt, SetGetTensorOk1_CopyValidation1) { - auto op_desc = std::make_shared(); - auto tensor_data = GetRandomFloat({8, 3, 224, 224}); - auto tensor_data1 = GetRandomFloat({16, 3, 224, 224}); - auto tensor = CreateTensor_8_3_224_224(tensor_data.get()); - - EXPECT_TRUE(AttrUtils::SetTensor(op_desc, "abc", tensor)); - tensor->MutableData().SetData(reinterpret_cast(tensor_data1.get()), 16*3*224*224*sizeof(float)); - tensor->MutableTensorDesc().SetShape(GeShape(std::vector({16,3,224,224}))); - tensor->MutableTensorDesc().SetOriginShape(GeShape(std::vector({16,3,224,224}))); - - ConstGeTensorPtr out_tensor; - EXPECT_TRUE(AttrUtils::GetTensor(op_desc, "abc", out_tensor)); - EXPECT_NE(out_tensor, nullptr); - ExpectTensorEqual_8_3_224_224(out_tensor, tensor_data.get()); -} - -TEST_F(AttrUtilsUt, SetGetTensorOk1_MultipleGet) { - auto op_desc = std::make_shared(); - auto tensor_data = GetRandomFloat({8, 3, 224, 224}); - auto tensor = CreateTensor_8_3_224_224(tensor_data.get()); - - EXPECT_TRUE(AttrUtils::SetTensor(op_desc, "abc", tensor)); - - auto tensor_data1 = GetRandomFloat({16, 3, 224, 224}); - GeTensorPtr out_tensor = nullptr; - EXPECT_TRUE(AttrUtils::MutableTensor(op_desc, "abc", out_tensor)); - EXPECT_NE(out_tensor, nullptr); - out_tensor->MutableData().SetData(reinterpret_cast(tensor_data1.get()), 16*3*224*224*sizeof(float)); - out_tensor->MutableTensorDesc().SetShape(GeShape(std::vector({16,3,224,224}))); - out_tensor->MutableTensorDesc().SetOriginShape(GeShape(std::vector({16,3,224,224}))); - - out_tensor = nullptr; - EXPECT_TRUE(AttrUtils::MutableTensor(op_desc, "abc", out_tensor)); - EXPECT_NE(out_tensor, nullptr); - - EXPECT_NE(const_cast(out_tensor->GetData().data()), reinterpret_cast(tensor_data1.get())); - EXPECT_EQ(out_tensor->GetData().size(), 16*3*224*224*sizeof(float)); - for (size_t i = 0; i < 16*3*224*224; ++i) { - EXPECT_FLOAT_EQ(reinterpret_cast(out_tensor->GetData().data())[i], tensor_data1.get()[i]); - } - EXPECT_EQ(out_tensor->GetTensorDesc().GetShape().GetDims(), std::vector({16,3,224,224})); - EXPECT_EQ(out_tensor->GetTensorDesc().GetOriginShape().GetDims(), std::vector({16,3,224,224})); - EXPECT_EQ(out_tensor->GetTensorDesc().GetFormat(), FORMAT_NCHW); - EXPECT_EQ(out_tensor->GetTensorDesc().GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(out_tensor->GetTensorDesc().GetDataType(), DT_FLOAT); - EXPECT_EQ(out_tensor->GetTensorDesc().GetOriginDataType(), DT_FLOAT); - std::string s; - EXPECT_TRUE(AttrUtils::GetStr(&out_tensor->GetTensorDesc(), "bcd", s)); - EXPECT_EQ(s, "Hello world"); -} - -TEST_F(AttrUtilsUt, SetGetListTensor) { - auto data1 = GetRandomFloat({8,3,224,224}); - auto data2 = GetRandomFloat({8,1,224,224,16}); - auto data3 = GetRandomFloat({8,3,224,224}); - auto tensor1 = CreateTensor_8_3_224_224(data1.get()); - auto tensor2 = CreateTensor_5d_8_3_224_224(data2.get()); - auto tensor3 = CreateTensor_8_3_224_224(data3.get()); - - auto op_desc = std::make_shared(); - EXPECT_TRUE(AttrUtils::SetListTensor(op_desc, "abc", std::vector({tensor1, tensor2, tensor3}))); - EXPECT_TRUE(AttrUtils::SetListTensor(op_desc, "abc1", std::vector({tensor1, tensor2, tensor3}))); - EXPECT_TRUE(AttrUtils::SetListTensor(op_desc, "abc2", std::vector({*tensor1, *tensor2, *tensor3}))); - EXPECT_TRUE(AttrUtils::SetListTensor(op_desc, "abc3", {tensor1, tensor2, tensor3})); - - std::vector out_tensors; - EXPECT_TRUE(AttrUtils::GetListTensor(op_desc, "abc", out_tensors)); - EXPECT_EQ(out_tensors.size(), 3); - ExpectTensorEqual_8_3_224_224(out_tensors[0], data1.get()); - ExpectTensorEqual_5d_8_3_224_224(out_tensors[1], data2.get()); - ExpectTensorEqual_8_3_224_224(out_tensors[2], data3.get()); - - EXPECT_TRUE(AttrUtils::GetListTensor(op_desc, "abc1", out_tensors)); - EXPECT_EQ(out_tensors.size(), 3); - ExpectTensorEqual_8_3_224_224(out_tensors[0], data1.get()); - ExpectTensorEqual_5d_8_3_224_224(out_tensors[1], data2.get()); - ExpectTensorEqual_8_3_224_224(out_tensors[2], data3.get()); - - EXPECT_TRUE(AttrUtils::GetListTensor(op_desc, "abc2", out_tensors)); - EXPECT_EQ(out_tensors.size(), 3); - ExpectTensorEqual_8_3_224_224(out_tensors[0], data1.get()); - ExpectTensorEqual_5d_8_3_224_224(out_tensors[1], data2.get()); - ExpectTensorEqual_8_3_224_224(out_tensors[2], data3.get()); - - EXPECT_TRUE(AttrUtils::GetListTensor(op_desc, "abc3", out_tensors)); - EXPECT_EQ(out_tensors.size(), 3); - ExpectTensorEqual_8_3_224_224(out_tensors[0], data1.get()); - ExpectTensorEqual_5d_8_3_224_224(out_tensors[1], data2.get()); - ExpectTensorEqual_8_3_224_224(out_tensors[2], data3.get()); -} - -TEST_F(AttrUtilsUt, SetGetListTensor_MutableOk) { - auto data1 = GetRandomFloat({8,3,224,224}); - auto data2 = GetRandomFloat({8,1,224,224,16}); - auto data3 = GetRandomFloat({8,3,224,224}); - auto data4 = GetRandomFloat({8,1,224,224,16}); - auto tensor1 = CreateTensor_8_3_224_224(data1.get()); - auto tensor2 = CreateTensor_5d_8_3_224_224(data2.get()); - auto tensor3 = CreateTensor_8_3_224_224(data3.get()); - - auto op_desc = std::make_shared(); - EXPECT_TRUE(AttrUtils::SetListTensor(op_desc, "abc", {tensor1, tensor2, tensor3})); - - std::vector out_tensors; - EXPECT_TRUE(AttrUtils::MutableListTensor(op_desc, "abc", out_tensors)); - out_tensors[2]->MutableData().SetData(reinterpret_cast(data4.get()), 8*1*224*224*16*sizeof(float)); - out_tensors[2]->MutableTensorDesc().SetShape(GeShape(std::vector({8,1,224,224,16}))); - out_tensors[2]->MutableTensorDesc().SetFormat(FORMAT_NC1HWC0); - out_tensors.clear(); - - EXPECT_TRUE(AttrUtils::MutableListTensor(op_desc, "abc", out_tensors)); - EXPECT_EQ(out_tensors.size(), 3); - ExpectTensorEqual_8_3_224_224(out_tensors[0], data1.get()); - ExpectTensorEqual_5d_8_3_224_224(out_tensors[1], data2.get()); - ExpectTensorEqual_5d_8_3_224_224(out_tensors[2], data4.get()); -} - -TEST_F(AttrUtilsUt, SetGetListTensor_CopyValidation) { - auto data1 = GetRandomFloat({8,3,224,224}); - auto data2 = GetRandomFloat({8,1,224,224,16}); - auto data3 = GetRandomFloat({8,3,224,224}); - auto data4 = GetRandomFloat({8,1,224,224,16}); - auto tensor1 = CreateTensor_8_3_224_224(data1.get()); - auto tensor2 = CreateTensor_5d_8_3_224_224(data2.get()); - auto tensor3 = CreateTensor_8_3_224_224(data3.get()); - - auto op_desc = std::make_shared(); - EXPECT_TRUE(AttrUtils::SetListTensor(op_desc, "abc", {tensor1, tensor2, tensor3})); - tensor3->MutableData().SetData(reinterpret_cast(data4.get()), 8*1*224*224*16*sizeof(float)); - tensor3->MutableTensorDesc().SetShape(GeShape(std::vector({8,1,224,224,16}))); - tensor3->MutableTensorDesc().SetFormat(FORMAT_NC1HWC0); - - std::vector out_tensors; - EXPECT_TRUE(AttrUtils::GetListTensor(op_desc, "abc", out_tensors)); - EXPECT_EQ(out_tensors.size(), 3); - ExpectTensorEqual_8_3_224_224(out_tensors[0], data1.get()); - ExpectTensorEqual_5d_8_3_224_224(out_tensors[1], data2.get()); - ExpectTensorEqual_8_3_224_224(out_tensors[2], data3.get()); -} - -TEST_F(AttrUtilsUt, SetGetGraphGraph) { - auto const_data = GetRandomFloat({1,1,224,224}); - auto holder = std::make_shared("holder"); - - { - auto graph = CreateGraph_1_1_224_224(const_data.get()); - EXPECT_TRUE(AttrUtils::SetGraph(holder, "abc", graph)); - } - - ComputeGraphPtr out_graph = nullptr; - EXPECT_TRUE(AttrUtils::GetGraph(holder, "abc", out_graph)); - - EXPECT_NE(out_graph, nullptr); - ExpectEqGraph_1_1_224_224(out_graph, const_data.get()); -} - -TEST_F(AttrUtilsUt, SetGraphGraph_CopyValidation) { - auto const_data = GetRandomFloat({1,1,224,224}); - auto holder = std::make_shared("holder"); - - auto graph = CreateGraph_1_1_224_224(const_data.get()); - EXPECT_TRUE(AttrUtils::SetGraph(holder, "abc", graph)); - graph->FindNode("data1")->GetOpDesc()->MutableOutputDesc(0)->SetFormat(FORMAT_NC1HWC0); - graph->FindNode("data1")->GetOpDesc()->MutableOutputDesc(0)->SetShape(GeShape({1,1,224,224,16})); - - ComputeGraphPtr out_graph = nullptr; - EXPECT_TRUE(AttrUtils::GetGraph(holder, "abc", out_graph)); - - EXPECT_NE(out_graph, nullptr); - ExpectEqGraph_1_1_224_224(out_graph, const_data.get()); -} - -TEST_F(AttrUtilsUt, SetGetListGraphGraph) { - auto const_data1 = GetRandomFloat({1,1,224,224,16}); - auto const_data2 = GetRandomFloat({1,1,224,224,16}); - auto const_data3 = GetRandomFloat({1,1,224,224}); - auto holder = std::make_shared("holder"); - - { - auto graph1 = CreateGraph_5d_1_1_224_224(const_data1.get()); - auto graph2 = CreateGraph_5d_1_1_224_224(const_data2.get()); - auto graph3 = CreateGraph_1_1_224_224(const_data3.get()); - EXPECT_TRUE(AttrUtils::SetListGraph(holder, "abc", std::vector({graph1, graph2, graph3}))); - } - - std::vector out_graphs; - EXPECT_TRUE(AttrUtils::GetListGraph(holder, "abc", out_graphs)); - - EXPECT_EQ(out_graphs.size(), 3); - ExpectEqGraph_5d_1_1_224_224(out_graphs[0], const_data1.get()); - ExpectEqGraph_5d_1_1_224_224(out_graphs[1], const_data2.get()); - ExpectEqGraph_1_1_224_224(out_graphs[2], const_data3.get()); -} - -TEST_F(AttrUtilsUt, SimpleTest) { - auto op_desc = std::make_shared(); - { - op_desc->SetAttr("Foo", GeAttrValue::CreateFrom(true)); - } - EXPECT_TRUE(AttrUtils::SetBool(op_desc, "Foo", true)); - bool val = false; - EXPECT_TRUE(AttrUtils::GetBool(op_desc, "Foo", val)); - EXPECT_TRUE(val); -} - -TEST_F(AttrUtilsUt, SetGetBytes) { - GeTensorDesc td; - auto data = GetRandomFloat({1,2,3,4}); - auto b1 = Buffer::CopyFrom(reinterpret_cast(data.get()), sizeof(float) * 1 * 2 * 3 * 4); - EXPECT_TRUE(AttrUtils::SetBytes(&td, "abc", b1)); - Buffer b2; - EXPECT_TRUE(AttrUtils::GetBytes(&td, "abc", b2)); - EXPECT_EQ(b1.size(), b2.size()); - EXPECT_EQ(memcmp(b1.data(), b2.data(), b1.size()), 0); - EXPECT_NE(b1.data(), b2.data()); -} - -TEST_F(AttrUtilsUt, SetGetBytes_ZeroCopy) { - GeTensorDesc td; - auto data = GetRandomFloat({1,2,3,4}); - auto b1 = Buffer::CopyFrom(reinterpret_cast(data.get()), sizeof(float) * 1 * 2 * 3 * 4); - auto addr = b1.data(); - EXPECT_TRUE(AttrUtils::SetZeroCopyBytes(&td, "abc", std::move(b1))); - Buffer b2; - EXPECT_TRUE(AttrUtils::GetZeroCopyBytes(&td, "abc", b2)); - EXPECT_EQ(addr, b2.data()); - EXPECT_EQ(b2.size(), sizeof(float) * 2 * 3 * 4); -} - -TEST_F(AttrUtilsUt, SetGetBytes_CopyValidation) { - GeTensorDesc td; - auto data = GetRandomFloat({1,2,3,4}); - auto b1 = Buffer::CopyFrom(reinterpret_cast(data.get()), sizeof(float) * 1 * 2 * 3 * 4); - EXPECT_TRUE(AttrUtils::SetBytes(&td, "abc", b1)); - b1.ClearBuffer(); - Buffer b2; - EXPECT_TRUE(AttrUtils::GetBytes(&td, "abc", b2)); - EXPECT_EQ(sizeof(float) * 1 * 2 * 3 * 4, b2.size()); - EXPECT_EQ(memcmp(data.get(), b2.data(), b2.size()), 0); -} - -TEST_F(AttrUtilsUt, SetGetListBytes) { - GeTensorDesc td; - auto data1 = GetRandomFloat({20}); - auto data2 = GetRandomFloat({40}); - auto data3 = GetRandomFloat({90}); - std::vector bufs = { - Buffer::CopyFrom(reinterpret_cast(data1.get()), sizeof(float) * 20), - Buffer::CopyFrom(reinterpret_cast(data2.get()), sizeof(float) * 40), - Buffer::CopyFrom(reinterpret_cast(data3.get()), sizeof(float) * 90) - }; - EXPECT_TRUE(AttrUtils::SetListBytes(&td, "abc", bufs)); - std::vector out_bufs; - EXPECT_TRUE(AttrUtils::GetListBytes(&td, "abc", out_bufs)); - EXPECT_EQ(out_bufs.size(), 3); - for (size_t i = 0; i < 3; ++i) { - EXPECT_EQ(out_bufs[i].size(), bufs[i].size()); - EXPECT_EQ(memcmp(out_bufs[i].data(), bufs[i].data(), out_bufs[i].size()), 0); - EXPECT_NE(out_bufs[i].data(), bufs[i].data()); - } -} - -TEST_F(AttrUtilsUt, SetGetListBytes_CopyValidation) { - GeTensorDesc td; - auto data1 = GetRandomFloat({20}); - auto data2 = GetRandomFloat({40}); - auto data3 = GetRandomFloat({90}); - std::vector bufs = { - Buffer::CopyFrom(reinterpret_cast(data1.get()), sizeof(float) * 20), - Buffer::CopyFrom(reinterpret_cast(data2.get()), sizeof(float) * 40), - Buffer::CopyFrom(reinterpret_cast(data3.get()), sizeof(float) * 90) - }; - - EXPECT_TRUE(AttrUtils::SetListBytes(&td, "abc", bufs)); - bufs[0].ClearBuffer(); - bufs[1].ClearBuffer(); - bufs[2].ClearBuffer(); - - std::vector out_bufs; - EXPECT_TRUE(AttrUtils::GetListBytes(&td, "abc", out_bufs)); - EXPECT_EQ(out_bufs.size(), 3); - - EXPECT_EQ(out_bufs[0].size(), 20 * sizeof(float)); - EXPECT_EQ(memcmp(out_bufs[0].data(), data1.get(), out_bufs[0].size()), 0); - EXPECT_EQ(out_bufs[1].size(), 40 * sizeof(float)); - EXPECT_EQ(memcmp(out_bufs[1].data(), data2.get(), out_bufs[1].size()), 0); - EXPECT_EQ(out_bufs[2].size(), 90 * sizeof(float)); - EXPECT_EQ(memcmp(out_bufs[2].data(), data3.get(), out_bufs[2].size()), 0); -} - -TEST_F(AttrUtilsUt, SetGetListBytes_ZeroCopy) { - GeTensorDesc td; - auto data1 = GetRandomFloat({20}); - auto data2 = GetRandomFloat({40}); - auto data3 = GetRandomFloat({90}); - std::vector bufs = { - Buffer::CopyFrom(reinterpret_cast(data1.get()), sizeof(float) * 20), - Buffer::CopyFrom(reinterpret_cast(data2.get()), sizeof(float) * 40), - Buffer::CopyFrom(reinterpret_cast(data3.get()), sizeof(float) * 90) - }; - EXPECT_TRUE(AttrUtils::SetZeroCopyListBytes(&td, "abc", bufs)); - std::vector out_bufs; - EXPECT_TRUE(AttrUtils::GetZeroCopyListBytes(&td, "abc", out_bufs)); - EXPECT_EQ(out_bufs.size(), 3); - for (size_t i = 0; i < 3; ++i) { - EXPECT_EQ(out_bufs[i].data(), bufs[i].data()); - } -} - -TEST_F(AttrUtilsUt, SetGetListListInt) { - auto op_desc = std::make_shared(); - EXPECT_TRUE(AttrUtils::SetListListInt(op_desc, - "abc", - std::vector>({{1, 2, 3}, {4, 4, 5}, {2, 2}}))); - std::vector> vec; - EXPECT_TRUE(AttrUtils::GetListListInt(op_desc, "abc", vec)); - EXPECT_EQ(vec, std::vector>({{1, 2, 3}, {4, 4, 5}, {2, 2}})); - //支持同名类型覆写 - EXPECT_TRUE(AttrUtils::SetInt(op_desc, "abc", 100)); - int64_t index; - EXPECT_TRUE(AttrUtils::GetInt(op_desc, "abc", index)); - EXPECT_EQ(index, 100); -} - -TEST_F(AttrUtilsUt, SetGetListListFloat) { - auto op_desc = std::make_shared(); - EXPECT_TRUE(AttrUtils::SetListListFloat(op_desc, "abc", std::vector>({{1.1,2.9,3.14},{4.122,43.4,5.55},{2.1,2.0}}))); - std::vector> vec; - EXPECT_TRUE(AttrUtils::GetListListFloat(op_desc, "abc", vec)); - EXPECT_EQ(vec.size(), 3); - EXPECT_EQ(vec[0].size(), 3); - EXPECT_EQ(vec[1].size(), 3); - EXPECT_EQ(vec[2].size(), 2); - EXPECT_FLOAT_EQ(vec[0][0], 1.1); - EXPECT_FLOAT_EQ(vec[1][0], 4.122); - EXPECT_FLOAT_EQ(vec[2][0], 2.1); -} - -TEST_F(AttrUtilsUt, SetGetNamedAttrs) { - auto op_desc = std::make_shared(); - NamedAttrs nas; - nas.SetName("Hello Name"); - nas.SetAttr("abc", AnyValue::CreateFrom(static_cast(10))); - nas.SetAttr("bcd", AnyValue::CreateFrom(true)); - - EXPECT_TRUE(AttrUtils::SetNamedAttrs(op_desc, "attr", nas)); - - NamedAttrs out_nas; - EXPECT_TRUE(AttrUtils::GetNamedAttrs(op_desc, "attr", out_nas)); - EXPECT_EQ(out_nas.GetName(), nas.GetName()); - AnyValue av; - EXPECT_EQ(out_nas.GetAttr("abc", av), GRAPH_SUCCESS); - EXPECT_NE(av.Get(), nullptr); - EXPECT_EQ(*av.Get(), 10); - - EXPECT_EQ(out_nas.GetAttr("bcd", av), GRAPH_SUCCESS); - EXPECT_NE(av.Get(), nullptr); - EXPECT_EQ(*av.Get(), true); -} - -TEST_F(AttrUtilsUt, SetGetNamedAttrs_CopyValidation) { - auto op_desc = std::make_shared(); - NamedAttrs nas; - nas.SetName("Hello Name"); - nas.SetAttr("abc", AnyValue::CreateFrom(static_cast(10))); - nas.SetAttr("bcd", AnyValue::CreateFrom(true)); - - EXPECT_TRUE(AttrUtils::SetNamedAttrs(op_desc, "attr", nas)); - AnyValue tmp_av; - nas.GetAttr("abc", tmp_av); - tmp_av.SetValue(static_cast(1024)); - nas.SetAttr("bcd", AnyValue::CreateFrom(1243124)); - - NamedAttrs out_nas; - EXPECT_TRUE(AttrUtils::GetNamedAttrs(op_desc, "attr", out_nas)); - EXPECT_EQ(out_nas.GetName(), nas.GetName()); - AnyValue av; - EXPECT_EQ(out_nas.GetAttr("abc", av), GRAPH_SUCCESS); - EXPECT_NE(av.Get(), nullptr); - EXPECT_EQ(*av.Get(), 10); - - EXPECT_EQ(out_nas.GetAttr("bcd", av), GRAPH_SUCCESS); - EXPECT_NE(av.Get(), nullptr); - EXPECT_EQ(*av.Get(), true); -} - -TEST_F(AttrUtilsUt, SetGetNamedAttrs_AttachedStreamInfo) { - auto op_desc = std::make_shared(); - // group id, 提前设置到op_desc上的 - AttrUtils::SetStr(op_desc, "group", "group0"); - // *********设置流程********** - NamedAttrs nas_stream_info; - // 这个名字不太重要 - nas_stream_info.SetName("nas0"); - // 下面的SetAttr方法的第一个参数涉及到跨组件的named_attr的进一步解析,所以需要写对 - // attach策略 - AttrUtils::SetStr(nas_stream_info, ge::ATTR_NAME_ATTACHED_STREAM_POLICY, "group"); - // attach流名称 - AttrUtils::SetStr(nas_stream_info, ge::ATTR_NAME_ATTACHED_STREAM_KEY, "kfc_stream"); - // nas_stream_info填充好之后,使用ATTR_NAME_ATTACHED_STREAM_INFO设置到opdesc上 - EXPECT_TRUE(AttrUtils::SetNamedAttrs(op_desc, ge::ATTR_NAME_ATTACHED_STREAM_INFO, nas_stream_info)); - - // *********解析流程********** - NamedAttrs attrs_for_assign_attached_stream; - EXPECT_TRUE(AttrUtils::GetNamedAttrs(op_desc, ge::ATTR_NAME_ATTACHED_STREAM_INFO, attrs_for_assign_attached_stream)); - EXPECT_EQ(attrs_for_assign_attached_stream.GetName(), nas_stream_info.GetName()); - std::string tmp; - EXPECT_TRUE(AttrUtils::GetStr(attrs_for_assign_attached_stream, ge::ATTR_NAME_ATTACHED_STREAM_POLICY, tmp)); - EXPECT_EQ(tmp, "group"); - EXPECT_TRUE(AttrUtils::GetStr(op_desc, "group", tmp)); - EXPECT_EQ(tmp, "group0"); - EXPECT_TRUE(AttrUtils::GetStr(attrs_for_assign_attached_stream, ge::ATTR_NAME_ATTACHED_STREAM_KEY, tmp)); - EXPECT_EQ(tmp, "kfc_stream"); -} - -TEST_F(AttrUtilsUt, SetGetNamedAttrs_AttachedNotifyInfo) { - auto op_desc = std::make_shared(); - // group id, 提前设置到op_desc上的 - AttrUtils::SetStr(op_desc, "group", "group0"); - // *********设置流程********** - NamedAttrs nas; - // 这个名字不太重要 - nas.SetName("nas0"); - // 下面的map的key涉及到跨组件的named_attr的进一步解析,所以需要按照约定字符串填写 - static const std::unordered_map - nas_infos = {{ge::ATTR_NAME_ATTACHED_NOTIFY_POLICY, "group"}, {ge::ATTR_NAME_ATTACHED_NOTIFY_KEY, "kfc_notify"}, - {ge::ATTR_NAME_ATTACHED_NOTIFY_TYPE, "on_device"}}; - for (const auto &pair: nas_infos) { - AttrUtils::SetStr(nas, pair.first, pair.second); - } - // nas填充好之后,使用ATTR_NAME_ATTACHED_NOTIFY_INFO设置到opdesc上 - EXPECT_TRUE(AttrUtils::SetNamedAttrs(op_desc, ge::ATTR_NAME_ATTACHED_NOTIFY_INFO, nas)); - - // *********解析流程********** - NamedAttrs parser_nas; - EXPECT_TRUE(AttrUtils::GetNamedAttrs(op_desc, ge::ATTR_NAME_ATTACHED_NOTIFY_INFO, parser_nas)); - EXPECT_EQ(parser_nas.GetName(), nas.GetName()); - std::string tmp; - EXPECT_TRUE(AttrUtils::GetStr(parser_nas, ge::ATTR_NAME_ATTACHED_NOTIFY_POLICY, tmp)); - EXPECT_EQ(tmp, "group"); - EXPECT_TRUE(AttrUtils::GetStr(op_desc, "group", tmp)); - EXPECT_EQ(tmp, "group0"); - EXPECT_TRUE(AttrUtils::GetStr(parser_nas, ge::ATTR_NAME_ATTACHED_NOTIFY_KEY, tmp)); - EXPECT_EQ(tmp, "kfc_notify"); - EXPECT_TRUE(AttrUtils::GetStr(parser_nas, ge::ATTR_NAME_ATTACHED_NOTIFY_TYPE, tmp)); - EXPECT_EQ(tmp, "on_device"); -} - -TEST_F(AttrUtilsUt, SetGetListNamedAttrs) { - auto op_desc = std::make_shared(); - std::vector nass(5); - for (size_t i = 0; i < nass.size(); ++i) { - auto &nas = nass[i]; - nas.SetName(std::string("name_") + std::to_string(i)); - nas.SetAttr("abc", AnyValue::CreateFrom(static_cast(rand()))); - } - - EXPECT_TRUE(AttrUtils::SetListNamedAttrs(op_desc, "attr", nass)); - - std::vector out_nass; - EXPECT_TRUE(AttrUtils::GetListNamedAttrs(op_desc, "attr", out_nass)); - EXPECT_EQ(out_nass.size(), 5); - for (size_t i = 0; i < out_nass.size(); ++i) { - auto &out_nas = out_nass[i]; - auto &nas = nass[i]; - EXPECT_EQ(out_nas.GetName(), nas.GetName()); - AnyValue out_av, av; - EXPECT_EQ(out_nas.GetAttr("abc", out_av), GRAPH_SUCCESS); - EXPECT_EQ(nas.GetAttr("abc", av), GRAPH_SUCCESS); - EXPECT_EQ(*out_av.Get(), *av.Get()); - } -} - -TEST_F(AttrUtilsUt, ClearAttrs) { - auto op_desc = std::make_shared(); - NamedAttrs nas; - nas.SetName("Hello Name"); - nas.SetAttr("abc", AnyValue::CreateFrom(static_cast(10))); - - EXPECT_TRUE(AttrUtils::SetNamedAttrs(op_desc, "attr", nas)); - EXPECT_EQ(AttrUtils::GetAllAttrs(op_desc).size(), 1); - AttrUtils::ClearAllAttrs(op_desc); - EXPECT_EQ(AttrUtils::GetAllAttrs(op_desc).size(), 0); -} - -TEST_F(AttrUtilsUt, ValueTypeToSerialString) { - EXPECT_EQ(AttrUtils::ValueTypeToSerialString(AnyValue::VT_STRING), "VT_STRING"); - EXPECT_EQ(AttrUtils::ValueTypeToSerialString(static_cast(-1)), ""); -} - -TEST_F(AttrUtilsUt, SerialStringToValueType) { - EXPECT_EQ(AttrUtils::SerialStringToValueType("VT_STRING"), AnyValue::VT_STRING); - EXPECT_EQ(AttrUtils::SerialStringToValueType("XXXXX"), AnyValue::VT_NONE); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/attribute_group_base_unittest.cc b/tests/ut/graph/testcase/attribute_group_base_unittest.cc deleted file mode 100644 index 3327ed806515ec8af4dd7d718a3f1f535af9b403..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/attribute_group_base_unittest.cc +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include "attribute_group/attr_group_base.h" -#include "attribute_group/attr_group_shape_env.h" -#include "attribute_group/attr_group_symbolic_desc.h" -#include "ge_ir.pb.h" - -namespace ge { -namespace { -class AttrGroupsBaseTest : public AttrGroupsBase { -public: - virtual std::unique_ptr Clone() {return nullptr;}; -}; - -class AttributeGroupUt : public testing::Test {}; - -TEST_F(AttributeGroupUt, TypeID) { - EXPECT_EQ(GetTypeId(), (void*)10); - EXPECT_EQ(GetTypeId(), (void*)11); - EXPECT_EQ(GetTypeId(), (void*)12); - EXPECT_EQ(GetTypeId(), (void*)13); - EXPECT_EQ(GetTypeId(), (void*)14); - EXPECT_EQ(GetTypeId(), (void*)15); - EXPECT_EQ(GetTypeId(), (void*)16); - EXPECT_EQ(GetTypeId(), (void*)17); -} - -TEST_F(AttributeGroupUt, Clone) { - EXPECT_NO_THROW( - proto::AttrGroupDef attr_group_def; - auto base = AttrGroupsBaseTest(); - EXPECT_EQ(GRAPH_SUCCESS, base.Serialize(attr_group_def)); - EXPECT_EQ(GRAPH_SUCCESS, base.Deserialize(attr_group_def, nullptr)); - ); -} -} -} // namespace ge diff --git a/tests/ut/graph/testcase/buffer_unittest.cc b/tests/ut/graph/testcase/buffer_unittest.cc deleted file mode 100644 index e0a020b4707bf6ab5e27b6e1cbde66460c838702..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/buffer_unittest.cc +++ /dev/null @@ -1,154 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include "graph/buffer.h" - -namespace ge { -class BufferUT : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -TEST_F(BufferUT, ShareFrom1) { - uint8_t first_buf[100]; - for (int i = 0; i < 100; ++i) { - first_buf[i] = i * 1024; - } - uint8_t second_buf[100]; - for (int i = 0; i < 100; ++i) { - second_buf[i] = i * 1024; - } - second_buf[50] = 10; - - Buffer buf(100); - memcpy_s(buf.GetData(), buf.GetSize(), first_buf, sizeof(first_buf)); - EXPECT_EQ(memcmp(buf.GetData(), first_buf, sizeof(first_buf)), 0); - - Buffer buf1 = BufferUtils::CreateShareFrom(buf); // The buf1 and buf are ref from the same memory now - buf1.GetData()[50] = 10; - EXPECT_EQ(memcmp(buf1.GetData(), second_buf, sizeof(second_buf)), 0); - EXPECT_EQ(memcmp(buf.GetData(), second_buf, sizeof(second_buf)), 0); - EXPECT_NE(memcmp(buf.GetData(), first_buf, sizeof(first_buf)), 0); -} - -TEST_F(BufferUT, ShareFrom2) { - uint8_t first_buf[100]; - for (int i = 0; i < 100; ++i) { - first_buf[i] = i * 1024; - } - uint8_t second_buf[100]; - for (int i = 0; i < 100; ++i) { - second_buf[i] = i * 1024; - } - second_buf[50] = 10; - - Buffer buf(100); - memcpy_s(buf.GetData(), buf.GetSize(), first_buf, sizeof(first_buf)); - EXPECT_EQ(memcmp(buf.GetData(), first_buf, sizeof(first_buf)), 0); - - Buffer buf1; - BufferUtils::ShareFrom(buf, buf1); // The buf1 and buf are ref from the same memory now - buf1.GetData()[50] = 10; - EXPECT_EQ(memcmp(buf1.GetData(), second_buf, sizeof(second_buf)), 0); - EXPECT_EQ(memcmp(buf.GetData(), second_buf, sizeof(second_buf)), 0); - EXPECT_NE(memcmp(buf.GetData(), first_buf, sizeof(first_buf)), 0); -} - -TEST_F(BufferUT, OperatorAssign) { - uint8_t first_buf[100]; - for (int i = 0; i < 100; ++i) { - first_buf[i] = i * 1024; - } - uint8_t second_buf[100]; - for (int i = 0; i < 100; ++i) { - second_buf[i] = i * 1024; - } - second_buf[50] = 10; - - Buffer buf(100); - memcpy_s(buf.GetData(), buf.GetSize(), first_buf, sizeof(first_buf)); - EXPECT_EQ(memcmp(buf.GetData(), first_buf, sizeof(first_buf)), 0); - - Buffer buf1; - buf1 = buf; // The buf1 and buf are ref from the same memory now - buf1.GetData()[50] = 10; - EXPECT_EQ(memcmp(buf1.GetData(), second_buf, sizeof(second_buf)), 0); - EXPECT_EQ(memcmp(buf.GetData(), second_buf, sizeof(second_buf)), 0); - EXPECT_NE(memcmp(buf.GetData(), first_buf, sizeof(first_buf)), 0); -} - -TEST_F(BufferUT, CreateShareFrom) { - uint8_t first_buf[100]; - for (int i = 0; i < 100; ++i) { - first_buf[i] = i * 1024; - } - uint8_t second_buf[100]; - for (int i = 0; i < 100; ++i) { - second_buf[i] = i * 1024; - } - second_buf[50] = 10; - - Buffer buf(100); - memcpy_s(buf.GetData(), buf.GetSize(), first_buf, sizeof(first_buf)); - EXPECT_EQ(memcmp(buf.GetData(), first_buf, sizeof(first_buf)), 0); - - Buffer buf1 = BufferUtils::CreateShareFrom(buf); // The buf1 and buf are ref from the same memory now - buf1.GetData()[50] = 10; - EXPECT_EQ(memcmp(buf1.GetData(), second_buf, sizeof(second_buf)), 0); - EXPECT_EQ(memcmp(buf.GetData(), second_buf, sizeof(second_buf)), 0); - EXPECT_NE(memcmp(buf.GetData(), first_buf, sizeof(first_buf)), 0); -} - -TEST_F(BufferUT, CreateCopyFrom1) { - uint8_t first_buf[100]; - for (int i = 0; i < 100; ++i) { - first_buf[i] = i * 2; - } - uint8_t second_buf[100]; - for (int i = 0; i < 100; ++i) { - second_buf[i] = i * 2; - } - second_buf[50] = 250; - - Buffer buf(100); - memcpy_s(buf.GetData(), buf.GetSize(), first_buf, sizeof(first_buf)); - EXPECT_EQ(memcmp(buf.GetData(), first_buf, sizeof(first_buf)), 0); - - Buffer buf1; - BufferUtils::CopyFrom(buf, buf1); - buf1.GetData()[50] = 250; - EXPECT_EQ(memcmp(buf1.GetData(), second_buf, sizeof(second_buf)), 0); - EXPECT_EQ(memcmp(buf.GetData(), first_buf, sizeof(first_buf)), 0); -} - -TEST_F(BufferUT, CreateCopyFrom2) { - uint8_t first_buf[100]; - for (int i = 0; i < 100; ++i) { - first_buf[i] = i * 2; - } - uint8_t second_buf[100]; - for (int i = 0; i < 100; ++i) { - second_buf[i] = i * 2; - } - second_buf[50] = 250; - - Buffer buf(100); - memcpy_s(buf.GetData(), buf.GetSize(), first_buf, sizeof(first_buf)); - EXPECT_EQ(memcmp(buf.GetData(), first_buf, sizeof(first_buf)), 0); - - Buffer buf1 = BufferUtils::CreateCopyFrom(buf); // The buf1 and buf are ref from the same memory now - buf1.GetData()[50] = 250; - EXPECT_EQ(memcmp(buf1.GetData(), second_buf, sizeof(second_buf)), 0); - EXPECT_EQ(memcmp(buf.GetData(), first_buf, sizeof(first_buf)), 0); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/compile_cache_policy_hasher_unittest.cc b/tests/ut/graph/testcase/compile_cache_policy_hasher_unittest.cc deleted file mode 100644 index 6165ed4773b97258e60dd22ecd8f31155a7fc182..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/compile_cache_policy_hasher_unittest.cc +++ /dev/null @@ -1,134 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/cache_policy/compile_cache_desc.h" - -namespace ge { -class UtestCompileCachePolicyHasher : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -TEST_F(UtestCompileCachePolicyHasher, TestBinaryHolderConstrutFromPtr) { - uint8_t data1[2] = {0U,1U}; - BinaryHolder holder1(data1, sizeof(data1)); - const uint8_t *dataPtr = holder1.GetDataPtr(); - ASSERT_NE(dataPtr, nullptr); - size_t size1 = holder1.GetDataLen(); - ASSERT_EQ(size1, sizeof(data1)); - ASSERT_EQ(dataPtr[0], 0); - ASSERT_EQ(dataPtr[1], 1); -} - -TEST_F(UtestCompileCachePolicyHasher, TestBinaryHolderCopyConstrut) { - uint8_t data1[2] = {0U,1U}; - BinaryHolder holder1(data1, sizeof(data1)); - const uint8_t *dataPtr = holder1.GetDataPtr(); - ASSERT_NE(dataPtr, nullptr); - size_t size1 = holder1.GetDataLen(); - ASSERT_EQ(size1, sizeof(data1)); - - BinaryHolder holder2 = holder1; - ASSERT_EQ((holder1 != holder2), false); - ASSERT_NE(holder1.GetDataPtr(), holder2.GetDataPtr()); - BinaryHolder holder3(holder1); - ASSERT_NE(holder1.GetDataPtr(), holder3.GetDataPtr()); - ASSERT_EQ((holder1 != holder3), false); -} - -TEST_F(UtestCompileCachePolicyHasher, TestBinaryHolderMoveConstrut) { - uint8_t data1[2] = {0U,1U}; - BinaryHolder holder1(data1, sizeof(data1)); - - BinaryHolder holder2 = std::move(holder1); - ASSERT_EQ(holder1.GetDataPtr(), nullptr); - ASSERT_NE(holder2.GetDataPtr(), nullptr); - ASSERT_EQ(holder1.GetDataLen(), 0); - ASSERT_EQ(holder2.GetDataLen(), sizeof(data1)); - - const uint8_t *dataPtr = holder2.GetDataPtr(); - ASSERT_NE(dataPtr, nullptr); - size_t size2 = holder2.GetDataLen(); - ASSERT_EQ(size2, sizeof(data1)); - ASSERT_EQ(dataPtr[0], 0); - ASSERT_EQ(dataPtr[1], 1); - - BinaryHolder holder3(std::move(holder2)); - ASSERT_EQ(holder2.GetDataPtr(), nullptr); - ASSERT_NE(holder3.GetDataPtr(), nullptr); - ASSERT_EQ(holder2.GetDataLen(), 0); - ASSERT_EQ(holder3.GetDataLen(), sizeof(data1)); -} - -TEST_F(UtestCompileCachePolicyHasher, TestBinaryHoldercreateFromUniquePtr) { - auto data_ptr = std::unique_ptr(new (std::nothrow) uint8_t[10]); - ASSERT_NE(data_ptr, nullptr); - const uint8_t *data_ptr_real = data_ptr.get(); - auto holder2 = BinaryHolder::createFrom(std::move(data_ptr), 10); - ASSERT_NE(holder2->GetDataPtr(), nullptr); - ASSERT_EQ(holder2->GetDataPtr(), data_ptr_real); - ASSERT_EQ(holder2->GetDataLen(), 10); - ASSERT_EQ(data_ptr, nullptr); -} - -TEST_F(UtestCompileCachePolicyHasher, TestBinaryHoldercreateFromUniquePtrFail) { - auto holder2 = BinaryHolder::createFrom(nullptr, 0); - ASSERT_EQ(holder2->GetDataPtr(), nullptr); - ASSERT_EQ(holder2->GetDataLen(), 0); -} - -TEST_F(UtestCompileCachePolicyHasher, TestBinaryHolderEqual) { - uint8_t data1[8] = {0U,1U,2U,3U,4U,5U,6U,7U}; - uint8_t data2[8] = {0U,1U,2U,3U,4U,5U,6U,7U}; - - BinaryHolder holder1(data1, sizeof(data1)); - const uint8_t *dataPtr = holder1.GetDataPtr(); - ASSERT_NE(dataPtr, nullptr); - size_t size1 = holder1.GetDataLen(); - ASSERT_EQ(size1, sizeof(data1)); - - BinaryHolder holder2(data2, sizeof(data2)); - ASSERT_EQ((holder1 != holder2), false); -} - -TEST_F(UtestCompileCachePolicyHasher, TestBinaryHolderDiffBecauseLength) { - uint8_t data1[8] = {0U,1U,2U,3U,4U,5U,6U,7U}; - uint8_t data2[9] = {0U,1U,2U,3U,4U,5U,7U,9U,11U}; - - BinaryHolder holder1(data1, sizeof(data1)); - ASSERT_EQ(holder1.GetDataLen(), sizeof(data1)); - BinaryHolder holder2(data2, sizeof(data2)); - ASSERT_EQ(holder2.GetDataLen(), sizeof(data2)); - ASSERT_EQ((holder1 != holder2), true); -} - -TEST_F(UtestCompileCachePolicyHasher, TestBinaryHolderDiffBecauseVaule) { - uint8_t data1[8] = {0U,1U,2U,3U,4U,5U,6U,7U}; - uint8_t data2[8] = {1U,1U,2U,3U,4U,5U,6U,7U}; - - BinaryHolder holder1(data1, sizeof(data1)); - ASSERT_EQ(holder1.GetDataLen(), sizeof(data1)); - BinaryHolder holder2(data2, sizeof(data2)); - ASSERT_EQ(holder2.GetDataLen(), sizeof(data2)); - ASSERT_EQ((holder1 != holder2), true); -} - -TEST_F(UtestCompileCachePolicyHasher, TestGetCacheDescHashWithoutShape) { - CompileCacheDescPtr cache_desc = std::make_shared(); - cache_desc->SetOpType("1111"); - TensorInfoArgs tensor_info_args(FORMAT_ND, FORMAT_ND, DT_BF16); - cache_desc->AddTensorInfo(tensor_info_args); - CacheHashKey id = cache_desc->GetCacheDescHash(); - cache_desc->SetOpType("2222"); - CacheHashKey id_another = cache_desc->GetCacheDescHash(); - ASSERT_NE(id, id_another); -} -} diff --git a/tests/ut/graph/testcase/compile_cache_policy_unittest.cc b/tests/ut/graph/testcase/compile_cache_policy_unittest.cc deleted file mode 100644 index 8b2449542905b80f40256b432a95490788e5063d..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/compile_cache_policy_unittest.cc +++ /dev/null @@ -1,729 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include "common/checker.h" -#include "exe_graph/runtime/shape.h" -#include "graph/cache_policy/cache_state.h" -#include "graph/cache_policy/match_policy_for_exactly_the_same.h" -#include "graph/cache_policy/aging_policy_lru_k.h" -#include "cache_desc_stub/runtime_cache_desc.h" -#include "graph/cache_policy/cache_policy.h" -#include "graph/cache_policy/aging_policy_lru.h" - -namespace ge { -namespace { -std::vector AddCachesByDepth(std::unique_ptr &cp, uint16_t depth) { - std::vector ids; - for (uint16_t i = 0; i < depth; ++i) { - int64_t dim_0 = i; - gert::Shape s{dim_0, 256, 256}; - auto cache_desc = std::make_shared(); - cache_desc->SetShapes({s}); - CacheItemId cache_id = cp->AddCache(cache_desc); - - if (cache_id != KInvalidCacheItemId) { - GELOGE(ge::FAILED, "AddCachesByDepth falied."); - return {}; - } - cache_id = cp->AddCache(cache_desc); - if (cache_id == KInvalidCacheItemId) { - GELOGE(ge::FAILED, "AddCachesByDepth falied."); - return {}; - } - ids.emplace_back(cache_id); - } - return ids; -} - -std::vector AddCachesByDepthForLRU(std::unique_ptr &cp, uint16_t depth) { - std::vector ids; - for (uint16_t i = 0; i < depth; ++i) { - int64_t dim_0 = i; - gert::Shape s{dim_0, 256, 256}; - auto cache_desc = std::make_shared(); - cache_desc->SetShapes({s}); - auto cache_id = cp->AddCache(cache_desc); - if (cache_id == KInvalidCacheItemId) { - GELOGE(ge::FAILED, "AddCachesByDepth falied."); - return {}; - } - ids.emplace_back(cache_id); - } - return ids; -} -} // namespace - -class UtestCompileCachePolicy : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -TEST_F(UtestCompileCachePolicy, CreateCCPSuccess_1) { - auto ccp = ge::CachePolicy::Create(ge::MatchPolicyType::MATCH_POLICY_EXACT_ONLY, ge::AgingPolicyType::AGING_POLICY_LRU); - ASSERT_NE(ccp, nullptr); -} - -TEST_F(UtestCompileCachePolicy, CreateCCPSuccess_2) { - auto mp_ptr = PolicyRegister::GetInstance().GetMatchPolicy(ge::MatchPolicyType::MATCH_POLICY_EXACT_ONLY); - auto ap_ptr = PolicyRegister::GetInstance().GetAgingPolicy(ge::AgingPolicyType::AGING_POLICY_LRU); - auto ccp = ge::CachePolicy::Create(mp_ptr, ap_ptr); - ASSERT_NE(ccp, nullptr); -} - -TEST_F(UtestCompileCachePolicy, CreateCCPFailed_1) { - auto mp_ptr = PolicyRegister::GetInstance().GetMatchPolicy(ge::MatchPolicyType::MATCH_POLICY_EXACT_ONLY); - auto ap_ptr = nullptr; - auto ccp = ge::CachePolicy::Create(mp_ptr, ap_ptr); - ASSERT_EQ(ccp, nullptr); -} - -TEST_F(UtestCompileCachePolicy, CreateCCPFailed_2) { - auto mp_ptr = nullptr; - auto ap_ptr = nullptr; - auto ccp = ge::CachePolicy::Create(mp_ptr, ap_ptr); - ASSERT_EQ(ccp, nullptr); -} - -TEST_F(UtestCompileCachePolicy, AddSameCache) { - CompileCacheDescPtr cache_desc = std::make_shared(); - cache_desc->SetOpType("test_op"); - TensorInfoArgs tensor_info(ge::FORMAT_ND, ge::FORMAT_ND, ge::DT_FLOAT16); - std::vector shape{-1,-1}; - tensor_info.SetShape(shape); - tensor_info.SetOriginShape(shape); - std::vector> ranges{{1,10}, {1,10}}; - tensor_info.SetShapeRange(ranges); - cache_desc->AddTensorInfo(tensor_info); - uint8_t value = 9; - uint8_t *data = &value; - BinaryHolder holder(data, 1); - cache_desc->AddBinary(holder); - auto ccp = ge::CachePolicy::Create(ge::MatchPolicyType::MATCH_POLICY_EXACT_ONLY, - ge::AgingPolicyType::AGING_POLICY_LRU); - CacheItemId cache_id = ccp->AddCache(cache_desc); - ASSERT_NE(cache_id, -1); - - CacheItemId cache_id_same = ccp->AddCache(cache_desc); - ASSERT_EQ(cache_id_same, cache_id); -} - -TEST_F(UtestCompileCachePolicy, AddDifferentOptypeCache) { - CompileCacheDescPtr cache_desc = std::make_shared(); - cache_desc->SetOpType("test_op"); - TensorInfoArgs tensor_info(ge::FORMAT_ND, ge::FORMAT_ND, ge::DT_FLOAT16); - std::vector shape{-1,-1}; - tensor_info.SetShape(shape); - tensor_info.SetOriginShape(shape); - std::vector> ranges{{1,10}, {1,10}}; - tensor_info.SetShapeRange(ranges); - cache_desc->AddTensorInfo(tensor_info); - uint8_t value = 9; - uint8_t *data = &value; - BinaryHolder holder(data, 1); - cache_desc->AddBinary(holder); - auto ccp = ge::CachePolicy::Create(ge::MatchPolicyType::MATCH_POLICY_EXACT_ONLY, - ge::AgingPolicyType::AGING_POLICY_LRU); - CacheItemId cache_id = ccp->AddCache(cache_desc); - ASSERT_NE(cache_id, -1); - - cache_desc->SetOpType("another_op"); - CacheItemId cache_id_another = ccp->AddCache(cache_desc); - ASSERT_NE(cache_id_another, -1); - ASSERT_NE(cache_id_another, cache_id); -} - -TEST_F(UtestCompileCachePolicy, AddDifferentUniqueIdCache) { - CompileCacheDescPtr cache_desc = std::make_shared(); - cache_desc->SetOpType("test_op"); - TensorInfoArgs tensor_info(ge::FORMAT_ND, ge::FORMAT_ND, ge::DT_FLOAT16); - std::vector shape{-1,-1}; - tensor_info.SetShape(shape); - tensor_info.SetOriginShape(shape); - std::vector> ranges{{1,10}, {1,10}}; - tensor_info.SetShapeRange(ranges); - cache_desc->AddTensorInfo(tensor_info); - uint8_t value = 9; - uint8_t *data = &value; - BinaryHolder holder(data, 1); - cache_desc->AddBinary(holder); - cache_desc->SetScopeId({1, 2}); - ASSERT_EQ(cache_desc->scope_id_.size(), 2); - ASSERT_EQ(cache_desc->scope_id_[0], 1); - ASSERT_EQ(cache_desc->scope_id_[1], 2); - auto ccp = ge::CachePolicy::Create(ge::MatchPolicyType::MATCH_POLICY_EXACT_ONLY, - ge::AgingPolicyType::AGING_POLICY_LRU); - CacheItemId cache_id = ccp->AddCache(cache_desc); - ASSERT_NE(cache_id, -1); - - CompileCacheDescPtr other_cache_desc = std::make_shared(*cache_desc.get()); - other_cache_desc->SetScopeId({1, 3}); - ASSERT_EQ(other_cache_desc->scope_id_.size(), 2); - ASSERT_EQ(other_cache_desc->scope_id_[0], 1); - ASSERT_EQ(other_cache_desc->scope_id_[1], 3); - CacheItemId cache_id_another = ccp->AddCache(other_cache_desc); - ASSERT_NE(cache_id_another, -1); - ASSERT_NE(cache_id_another, cache_id); -} - -TEST_F(UtestCompileCachePolicy, AddDifferentBinarySizeCache) { - CompileCacheDescPtr cache_desc = std::make_shared(); - cache_desc->SetOpType("test_op"); - TensorInfoArgs tensor_info(ge::FORMAT_ND, ge::FORMAT_ND, ge::DT_FLOAT16); - std::vector shape{-1,-1}; - tensor_info.SetShape(shape); - tensor_info.SetOriginShape(shape); - std::vector> ranges{{1,10}, {1,10}}; - tensor_info.SetShapeRange(ranges); - cache_desc->AddTensorInfo(tensor_info); - auto ccp = ge::CachePolicy::Create(ge::MatchPolicyType::MATCH_POLICY_EXACT_ONLY, - ge::AgingPolicyType::AGING_POLICY_LRU); - CacheItemId cache_id = ccp->AddCache(cache_desc); - ASSERT_NE(cache_id, -1); - - uint8_t value = 9; - uint8_t *data = &value; - BinaryHolder holder(data, 1); - CompileCacheDescPtr other_cache_desc = std::make_shared(*cache_desc.get()); - other_cache_desc->AddBinary(holder); - CacheItemId cache_id_another = ccp->AddCache(other_cache_desc); - ASSERT_NE(cache_id_another, -1); - ASSERT_NE(cache_id_another, cache_id); -} - -TEST_F(UtestCompileCachePolicy, AddDifferentBinaryValueCache) { - CompileCacheDescPtr cache_desc = std::make_shared(); - cache_desc->SetOpType("test_op"); - TensorInfoArgs tensor_info(ge::FORMAT_ND, ge::FORMAT_ND, ge::DT_FLOAT16); - std::vector shape{-1,-1}; - tensor_info.SetShape(shape); - tensor_info.SetOriginShape(shape); - std::vector> ranges{{1,10}, {1,10}}; - tensor_info.SetShapeRange(ranges); - cache_desc->AddTensorInfo(tensor_info); - uint8_t value = 9; - uint8_t *data = &value; - BinaryHolder holder(data, 1); - cache_desc->AddBinary(holder); - auto ccp = ge::CachePolicy::Create(ge::MatchPolicyType::MATCH_POLICY_EXACT_ONLY, - ge::AgingPolicyType::AGING_POLICY_LRU); - CacheItemId cache_id = ccp->AddCache(cache_desc); - ASSERT_NE(cache_id, -1); - - uint8_t value_another = 8; - BinaryHolder holder1(&value_another, 1); - CompileCacheDescPtr other_cache_desc = std::make_shared(*cache_desc.get()); - other_cache_desc->other_desc_.clear(); - other_cache_desc->AddBinary(holder1); - CacheItemId cache_id_another = ccp->AddCache(other_cache_desc); - ASSERT_NE(cache_id_another, -1); - ASSERT_NE(cache_id_another, cache_id); -} - -TEST_F(UtestCompileCachePolicy, AddDifferentTensorFormatCache) { - CompileCacheDescPtr cache_desc = std::make_shared(); - cache_desc->SetOpType("test_op"); - TensorInfoArgs tensor_info(ge::FORMAT_ND, ge::FORMAT_ND, ge::DT_FLOAT16); - std::vector shape{-1,-1}; - tensor_info.SetShape(shape); - tensor_info.SetOriginShape(shape); - std::vector> ranges{{1,10}, {1,10}}; - tensor_info.SetShapeRange(ranges); - cache_desc->AddTensorInfo(tensor_info); - uint8_t value = 9; - uint8_t *data = &value; - BinaryHolder holder(data, 1); - cache_desc->AddBinary(holder); - auto ccp = ge::CachePolicy::Create(ge::MatchPolicyType::MATCH_POLICY_EXACT_ONLY, - ge::AgingPolicyType::AGING_POLICY_LRU); - CacheItemId cache_id = ccp->AddCache(cache_desc); - ASSERT_NE(cache_id, -1); - - cache_desc->tensor_info_args_vec_[0].format_ = ge::FORMAT_NCHW; - CacheItemId cache_id_another = ccp->AddCache(cache_desc); - ASSERT_NE(cache_id_another, -1); - ASSERT_NE(cache_id_another, cache_id); -} - -TEST_F(UtestCompileCachePolicy, CacheFindFailBecauseRangeFirst) { - CompileCacheDescPtr cache_desc = std::make_shared(); - cache_desc->SetOpType("test_op"); - CompileCacheDescPtr cache_desc_match = cache_desc; - TensorInfoArgs tensor_info(ge::FORMAT_ND, ge::FORMAT_ND, ge::DT_FLOAT16); - SmallVector shape{-1,-1}; - tensor_info.SetShape(shape); - tensor_info.SetOriginShape(shape); - ASSERT_EQ(tensor_info.shape_.size(), 2); - ASSERT_EQ(tensor_info.shape_[0], -1); - ASSERT_EQ(tensor_info.shape_[1], -1); - std::vector> ranges{{1,10}, {1,10}}; - tensor_info.SetShapeRange(ranges); - cache_desc->AddTensorInfo(tensor_info); - auto ccp = ge::CachePolicy::Create(ge::MatchPolicyType::MATCH_POLICY_EXACT_ONLY, - ge::AgingPolicyType::AGING_POLICY_LRU); - CacheItemId cache_id = ccp->AddCache(cache_desc); - ASSERT_NE(cache_id, -1); - - std::vector shape_match{0,5}; - tensor_info.SetShape(shape_match); - cache_desc_match->AddTensorInfo(tensor_info); - CacheItemId cache_id_find = ccp->FindCache(cache_desc_match); - ASSERT_EQ(cache_id_find, KInvalidCacheItemId); -} - -TEST_F(UtestCompileCachePolicy, CacheFindFailBecauseRangeSecond) { - CompileCacheDescPtr cache_desc = std::make_shared(); - cache_desc->SetOpType("test_op"); - CompileCacheDescPtr cache_desc_match = cache_desc; - TensorInfoArgs tensor_info(ge::FORMAT_ND, ge::FORMAT_ND, ge::DT_FLOAT16); - std::vector shape{-1,-1}; - tensor_info.SetShape(shape); - tensor_info.SetOriginShape(shape); - std::vector> ranges{{1,10}, {1,10}}; - tensor_info.SetShapeRange(ranges); - cache_desc->AddTensorInfo(tensor_info); - auto ccp = ge::CachePolicy::Create(ge::MatchPolicyType::MATCH_POLICY_EXACT_ONLY, - ge::AgingPolicyType::AGING_POLICY_LRU); - CacheItemId cache_id = ccp->AddCache(cache_desc); - ASSERT_NE(cache_id, -1); - - std::vector shape_match{5,11}; - tensor_info.SetShape(shape_match); - cache_desc_match->AddTensorInfo(tensor_info); - CacheItemId cache_id_find = ccp->FindCache(cache_desc_match); - ASSERT_EQ(cache_id_find, KInvalidCacheItemId); -} - -TEST_F(UtestCompileCachePolicy, CacheFindSuccessBecauseUnknownRank) { - CompileCacheDescPtr cache_desc = std::make_shared(); - cache_desc->SetOpType("test_op"); - CompileCacheDescPtr cache_desc_match = std::make_shared(*cache_desc.get()); - TensorInfoArgs tensor_info(ge::FORMAT_ND, ge::FORMAT_ND, ge::DT_FLOAT16); - std::vector shape{-2}; - tensor_info.SetShape(shape); - tensor_info.SetOriginShape(shape); - cache_desc->AddTensorInfo(tensor_info); - auto ccp = ge::CachePolicy::Create(ge::MatchPolicyType::MATCH_POLICY_EXACT_ONLY, - ge::AgingPolicyType::AGING_POLICY_LRU); - CacheItemId cache_id = ccp->AddCache(cache_desc); - ASSERT_NE(cache_id, -1); - - std::vector shape_match{5,11}; - tensor_info.SetShape(shape_match); - cache_desc_match->AddTensorInfo(tensor_info); - CacheItemId cache_id_find = ccp->FindCache(cache_desc_match); - ASSERT_EQ(cache_id_find, cache_id); -} - - -TEST_F(UtestCompileCachePolicy, CacheFindSuccessCommonTest) { - CompileCacheDescPtr cache_desc = std::make_shared(); - cache_desc->SetOpType("test_op"); - CompileCacheDescPtr cache_desc_match = std::make_shared(*cache_desc.get()); - TensorInfoArgs tensor_info(ge::FORMAT_ND, ge::FORMAT_ND, ge::DT_FLOAT16); - std::vector shape{-1,-1}; - tensor_info.SetShape(shape); - tensor_info.SetOriginShape(shape); - std::vector> ranges{{1,10}, {1,10}}; - tensor_info.SetShapeRange(ranges); - cache_desc->AddTensorInfo(tensor_info); - auto ccp = ge::CachePolicy::Create(ge::MatchPolicyType::MATCH_POLICY_EXACT_ONLY, - ge::AgingPolicyType::AGING_POLICY_LRU); - CacheItemId cache_id = ccp->AddCache(cache_desc); - ASSERT_NE(cache_id, -1); - - std::vector shape_match{5,5}; - tensor_info.SetShape(shape_match); - cache_desc_match->AddTensorInfo(tensor_info); - CacheItemId cache_id_find = ccp->FindCache(cache_desc_match); - ASSERT_EQ(cache_id, cache_id_find); - ge::CachePolicy ccp1; - ASSERT_EQ(ccp1.FindCache(cache_desc_match), KInvalidCacheItemId); -} - -TEST_F(UtestCompileCachePolicy, CacheDelTest) { - CompileCacheDescPtr cache_desc = std::make_shared(); - cache_desc->SetOpType("test_op"); - TensorInfoArgs tensor_info(ge::FORMAT_ND, ge::FORMAT_ND, ge::DT_FLOAT16); - std::vector shape{-1,-1}; - tensor_info.SetShape(shape); - tensor_info.SetOriginShape(shape); - std::vector> ranges{{1,10}, {1,10}}; - tensor_info.SetShapeRange(ranges); - cache_desc->AddTensorInfo(tensor_info); - uint8_t value = 9; - uint8_t *data = &value; - BinaryHolder holder(data, 1); - cache_desc->AddBinary(holder); - auto ccp = ge::CachePolicy::Create(ge::MatchPolicyType::MATCH_POLICY_EXACT_ONLY, - ge::AgingPolicyType::AGING_POLICY_LRU); - CacheItemId cache_id = ccp->AddCache(cache_desc); - ASSERT_NE(cache_id, -1); - - CacheItemId cache_id_same = ccp->AddCache(cache_desc); - ASSERT_EQ(cache_id_same, cache_id); - - cache_desc->SetOpType("another_op"); - CacheItemId cache_id_another = ccp->AddCache(cache_desc); - ASSERT_NE(cache_id_another, -1); - ASSERT_NE(cache_id_another, cache_id); - - std::vector delete_item{cache_id, cache_id_another}; - std::vector delete_item_ret = ccp->DeleteCache(delete_item); - ASSERT_EQ(delete_item_ret.size(), 2); -} - -TEST_F(UtestCompileCachePolicy, AgingCacheSuccess_1) { - CompileCacheDescPtr cache_desc = std::make_shared(); - cache_desc->SetOpType("test_op"); - TensorInfoArgs tensor_info(ge::FORMAT_ND, ge::FORMAT_ND, ge::DT_FLOAT16); - cache_desc->AddTensorInfo(tensor_info); - ASSERT_EQ(cache_desc->GetTensorInfoSize(), 1); - ASSERT_EQ(cache_desc->MutableTensorInfo(1), nullptr); - ASSERT_NE(cache_desc->MutableTensorInfo(0), nullptr); - uint8_t value = 9; - uint8_t *data = &value; - BinaryHolder holder(data, 1); - BinaryHolder holder_new; - holder_new = holder; - ASSERT_NE(holder_new.GetDataPtr(), nullptr); - cache_desc->AddBinary(holder_new); - auto ccp = ge::CachePolicy::Create(ge::MatchPolicyType::MATCH_POLICY_EXACT_ONLY, - ge::AgingPolicyType::AGING_POLICY_LRU); - CacheItemId cache_id = ccp->AddCache(cache_desc); - ASSERT_NE(cache_id, -1); - - std::vector del_item = ccp->DoAging(); - ASSERT_EQ(cache_id, del_item[0]); -} - -TEST_F(UtestCompileCachePolicy, CreateCCPSuccess_RuntimeCachePolicy_1) { - auto ccp = ge::CachePolicy::Create(ge::MatchPolicyType::MATCH_POLICY_FOR_EXACTLY_THE_SAME, - ge::AgingPolicyType::AGING_POLICY_LRU_K); - ASSERT_NE(ccp, nullptr); -} - -TEST_F(UtestCompileCachePolicy, CreateCCPSuccess_RuntimeCachePolicy_2) { - auto mp_ptr = PolicyRegister::GetInstance().GetMatchPolicy(ge::MatchPolicyType::MATCH_POLICY_FOR_EXACTLY_THE_SAME); - auto ap_ptr = PolicyRegister::GetInstance().GetAgingPolicy(ge::AgingPolicyType::AGING_POLICY_LRU_K); - auto ccp = ge::CachePolicy::Create(mp_ptr, ap_ptr); - ASSERT_NE(ccp, nullptr); -} - -TEST_F(UtestCompileCachePolicy, AddCache_ReturnKInvalidCacheItemId_CacheDescNotMetAddCondition) { - auto mp = std::make_shared(); - auto ap = std::make_shared(); - auto cp = CachePolicy::Create(mp, ap); - - gert::Shape s{1, 3, 256, 256}; - std::vector shapes{s}; - - auto cache_desc = std::make_shared(); - cache_desc->SetShapes(shapes); - CacheItemId cache_id = cp->AddCache(cache_desc); - EXPECT_EQ(cache_id, KInvalidCacheItemId); -} - -TEST_F(UtestCompileCachePolicy, AddCache_ReturnValidCacheItemId_CacheDescMetAddCondition) { - auto mp = std::make_shared(); - auto ap = std::make_shared(); - auto cp = CachePolicy::Create(mp, ap); - - gert::Shape s{1, 3, 256, 256}; - std::vector shapes{s}; - - auto cache_desc = std::make_shared(); - cache_desc->SetShapes(shapes); - CacheItemId cache_id = cp->AddCache(cache_desc); - EXPECT_EQ(cache_id, KInvalidCacheItemId); - cache_id = cp->AddCache(cache_desc); - EXPECT_NE(cache_id, KInvalidCacheItemId); -} - -TEST_F(UtestCompileCachePolicy, AddCache_GetSameCacheId_AddSameCache) { - auto mp = std::make_shared(); - auto ap = std::make_shared(); - auto cp = CachePolicy::Create(mp, ap); - - gert::Shape s{1, 3, 256, 256}; - std::vector shapes{s}; - - // add cache 1 - auto cache_desc = std::make_shared(); - cache_desc->SetShapes(shapes); - CacheItemId cache_id = cp->AddCache(cache_desc); - EXPECT_EQ(cache_id, KInvalidCacheItemId); - cache_id = cp->AddCache(cache_desc); - EXPECT_NE(cache_id, KInvalidCacheItemId); - - // add cache 2 - CacheItemId cache_id_same = cp->AddCache(cache_desc); - EXPECT_TRUE(cache_id_same == cache_id); -} - -TEST_F(UtestCompileCachePolicy, AddCache_GetSameCacheId_AddAnotherSameCache) { - auto mp = std::make_shared(); - auto ap = std::make_shared(); - auto cp = CachePolicy::Create(mp, ap); - - gert::Shape s{1, 3, 256, 256}; - std::vector shapes{s}; - - // add cache 1 - auto cache_desc = std::make_shared(); - cache_desc->SetShapes(shapes); - CacheItemId cache_id = cp->AddCache(cache_desc); - EXPECT_EQ(cache_id, KInvalidCacheItemId); - cache_id = cp->AddCache(cache_desc); - EXPECT_NE(cache_id, KInvalidCacheItemId); - - // add cache 2 - auto another_cache_desc = std::make_shared(); - another_cache_desc->SetShapes(shapes); - CacheItemId cache_id_same = cp->AddCache(another_cache_desc); - EXPECT_EQ(cache_id_same, cache_id); -} - - -TEST_F(UtestCompileCachePolicy, AddCache_GetDiffCacheId_OneCacheSetDiffShapes) { - auto mp = std::make_shared(); - auto ap = std::make_shared(); - auto cp = CachePolicy::Create(mp, ap); - - gert::Shape s1{1, 3, 256, 256}; - std::vector shapes1{s1}; - gert::Shape s2{3, 256, 256}; - std::vector shapes2{s2}; - - // add cache 1 - auto cache_desc = std::make_shared(); - cache_desc->SetShapes(shapes1); - CacheItemId cache_id = cp->AddCache(cache_desc); - EXPECT_EQ(cache_id, KInvalidCacheItemId); - cache_id = cp->AddCache(cache_desc); - EXPECT_NE(cache_id, KInvalidCacheItemId); - - // add cache 2 - cache_desc->SetShapes(shapes2); - CacheItemId another_cache_id = cp->AddCache(cache_desc); - EXPECT_EQ(another_cache_id, KInvalidCacheItemId); - another_cache_id = cp->AddCache(cache_desc); - EXPECT_NE(another_cache_id, KInvalidCacheItemId); - EXPECT_NE(another_cache_id, cache_id); -} - -TEST_F(UtestCompileCachePolicy, AddCache_GetDiffCacheId_AddTwoDiffCache) { - auto mp = std::make_shared(); - auto ap = std::make_shared(); - auto cp = CachePolicy::Create(mp, ap); - - gert::Shape s1{1, 3, 256, 256}; - std::vector shapes1{s1}; - gert::Shape s2{3, 256, 256}; - std::vector shapes2{s2}; - - // add cache 1 - auto cache_desc = std::make_shared(); - cache_desc->SetShapes(shapes1); - CacheItemId cache_id = cp->AddCache(cache_desc); - EXPECT_EQ(cache_id, KInvalidCacheItemId); - cache_id = cp->AddCache(cache_desc); - EXPECT_NE(cache_id, KInvalidCacheItemId); - - // add cache 2 - auto another_cache_desc = std::make_shared(); - another_cache_desc->SetShapes(shapes2); - CacheItemId another_cache_id = cp->AddCache(another_cache_desc); - EXPECT_EQ(another_cache_id, KInvalidCacheItemId); - another_cache_id = cp->AddCache(another_cache_desc); - EXPECT_NE(another_cache_id, KInvalidCacheItemId); - EXPECT_NE(another_cache_id, cache_id); -} - -TEST_F(UtestCompileCachePolicy, FindCache_GetSameId_FindTheCacheAdded) { - auto mp = std::make_shared(); - auto ap = std::make_shared(); - auto cp = CachePolicy::Create(mp, ap); - - gert::Shape s1{1, 3, 256, 256}; - std::vector shapes1{s1}; - - // add cache 1 - auto cache_desc = std::make_shared(); - cache_desc->SetShapes(shapes1); - CacheItemId cache_id = cp->AddCache(cache_desc); - EXPECT_EQ(cache_id, KInvalidCacheItemId); - cache_id = cp->AddCache(cache_desc); - EXPECT_NE(cache_id, KInvalidCacheItemId); - - // find cache 1 - CacheItemId another_cache_id = cp->FindCache(cache_desc); - EXPECT_EQ(another_cache_id, cache_id); -} - -TEST_F(UtestCompileCachePolicy, FindCache_GetIdMatched_1HashWith2CacheDescs) { - auto mp = std::make_shared(); - auto ap = std::make_shared(); - auto cp = CachePolicy::Create(mp, ap); - - gert::Shape s1{1, 3, 256, 256}; - std::vector shapes1{s1}; - gert::Shape s2{3, 256, 256}; - std::vector shapes2{s2}; - - // add cache 1 - auto cache_desc = std::make_shared(); - cache_desc->SetShapes(shapes1); - CacheItemId cache_id = cp->AddCache(cache_desc); - EXPECT_EQ(cache_id, KInvalidCacheItemId); - cache_id = cp->AddCache(cache_desc); - EXPECT_NE(cache_id, KInvalidCacheItemId); - auto cache_desc_hash = cache_desc->GetCacheDescHash(); - - // modify cache_desc which cp saved - cache_desc->SetShapes(shapes2); - - // add cache 2 with same hash key - auto another_cache_desc = std::make_shared(); - another_cache_desc->SetShapes(shapes1); - EXPECT_EQ(another_cache_desc->GetCacheDescHash(), cache_desc_hash); - CacheItemId another_cache_id = cp->AddCache(another_cache_desc); - EXPECT_EQ(another_cache_id, KInvalidCacheItemId); - another_cache_id = cp->AddCache(another_cache_desc); - EXPECT_NE(another_cache_id, KInvalidCacheItemId); - EXPECT_NE(another_cache_id, cache_id); - - // find cache 2 - auto find_cache_desc = std::make_shared(); - find_cache_desc->SetShapes(shapes1); - EXPECT_EQ(find_cache_desc->GetCacheDescHash(), cache_desc_hash); - auto find_cache_id = cp->FindCache(find_cache_desc); - EXPECT_EQ(another_cache_id, find_cache_id); -} - -TEST_F(UtestCompileCachePolicy, FindCache_ReturnKInvalidCacheItemId_HashKeyNotMatched) { - auto mp = std::make_shared(); - auto ap = std::make_shared(); - auto cp = CachePolicy::Create(mp, ap); - - gert::Shape s1{0, 3, 256, 256}; - std::vector shapes1{s1}; - gert::Shape s2{3, 256, 256}; - std::vector shapes2{s2}; - - // add cache 1 - auto cache_desc = std::make_shared(); - cache_desc->SetShapes(shapes1); - CacheItemId cache_id = cp->AddCache(cache_desc); - EXPECT_EQ(cache_id, KInvalidCacheItemId); - cache_id = cp->AddCache(cache_desc); - EXPECT_NE(cache_id, KInvalidCacheItemId); - - // find cache 2 - auto another_cache_desc = std::make_shared(); - another_cache_desc->SetShapes(shapes2); - CacheItemId another_cache_id = cp->FindCache(another_cache_desc); - EXPECT_EQ(another_cache_id, KInvalidCacheItemId); -} - -TEST_F(UtestCompileCachePolicy, FindCache_ReturnKInvalidCacheItemId_HashMatchedButCacheDescNotMatch) { - auto mp = std::make_shared(); - auto ap = std::make_shared(); - auto cp = CachePolicy::Create(mp, ap); - - gert::Shape s1{1, 3, 256, 256}; - std::vector shapes1{s1}; - gert::Shape s2{3, 256, 256}; - std::vector shapes2{s2}; - - // add cache 1 - auto cache_desc = std::make_shared(); - cache_desc->SetShapes(shapes1); - CacheItemId cache_id = cp->AddCache(cache_desc); - EXPECT_EQ(cache_id, KInvalidCacheItemId); - cache_id = cp->AddCache(cache_desc); - EXPECT_NE(cache_id, KInvalidCacheItemId); - auto cache_desc_hash = cache_desc->GetCacheDescHash(); - - // modify cache 1 which cp saved - cache_desc->SetShapes(shapes2); - - // find diff cache with the same hash of cache 1 - auto find_cache_desc = std::make_shared(); - find_cache_desc->SetShapes(shapes1); - EXPECT_EQ(find_cache_desc->GetCacheDescHash(), cache_desc_hash); - auto find_cache_id = cp->FindCache(find_cache_desc); - EXPECT_EQ(find_cache_id, KInvalidCacheItemId); -} - -TEST_F(UtestCompileCachePolicy, DoAging_NoAgingId_CacheQueueNotReachDepth) { - auto mp = std::make_shared(); - auto ap = std::make_shared(1); - auto cp = CachePolicy::Create(mp, ap); - - uint16_t depth = 1; - auto add_cache_ids = AddCachesByDepth(cp, depth); - ASSERT_EQ(add_cache_ids.size(), depth); - - auto delete_ids = cp->DoAging(); - EXPECT_EQ(delete_ids.size(), 0); -} - -TEST_F(UtestCompileCachePolicy, DoAging_GetAgingIds_CacheQueueOverDepth) { - auto mp = std::make_shared(); - auto ap = std::make_shared(1); - auto cp = CachePolicy::Create(mp, ap); - - uint16_t depth = 2; - auto add_cache_ids = AddCachesByDepth(cp, depth); - ASSERT_EQ(add_cache_ids.size(), depth); - - auto delete_ids = cp->DoAging(); - EXPECT_EQ(delete_ids.size(), 1); - EXPECT_EQ(delete_ids[0], add_cache_ids[0]); -} - -TEST_F(UtestCompileCachePolicy, DoAging_Aging2Times_CacheQueueOverDepth) { - size_t cached_aging_depth = 1U; - auto cp = CachePolicy::Create(ge::MatchPolicyType::MATCH_POLICY_FOR_EXACTLY_THE_SAME, - ge::AgingPolicyType::AGING_POLICY_LRU_K, cached_aging_depth); - - uint16_t depth = 3; - auto add_cache_ids = AddCachesByDepth(cp, depth); - ASSERT_EQ(add_cache_ids.size(), depth); - - for (size_t i = 0U; i < 2U; ++i) { - auto delete_ids = cp->DoAging(); - ASSERT_EQ(delete_ids.size(), 1); - EXPECT_EQ(delete_ids[0], add_cache_ids[i]); - } -} - -TEST_F(UtestCompileCachePolicy, DoAging_TestSetIntervalForLRU) { - auto mp = std::make_shared(); - auto ap = std::make_shared(); - auto cp = CachePolicy::Create(mp, ap); - - uint16_t depth = 3; - auto add_cache_ids = AddCachesByDepthForLRU(cp, depth); - ASSERT_EQ(add_cache_ids.size(), depth); - - ap->SetDeleteInterval(depth); - auto delete_ids = cp->DoAging(); - EXPECT_EQ(delete_ids.size(), 0U); - - ap->SetDeleteInterval(1U); - delete_ids = cp->DoAging(); - ASSERT_EQ(delete_ids.size(), 2); - EXPECT_TRUE(delete_ids[0] == 0 || delete_ids[0] == 1); - EXPECT_TRUE(delete_ids[1] == 0 || delete_ids[1] == 1); -} -} diff --git a/tests/ut/graph/testcase/compute_graph_unittest.cc b/tests/ut/graph/testcase/compute_graph_unittest.cc deleted file mode 100644 index fa6728e0b63aa005323df14153d8b633131ae435..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/compute_graph_unittest.cc +++ /dev/null @@ -1,2028 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/ge_local_context.h" -#include "graph/ge_context.h" - -#include "graph/compute_graph.h" -#include "graph/normal_graph/compute_graph_impl.h" -#include "graph/op_desc.h" -#include "graph/normal_graph/op_desc_impl.h" -#include "graph/ge_tensor.h" -#include "graph/utils/ge_ir_utils.h" -#include "graph_builder_utils.h" -#include "common/ge_common/ge_types.h" -#include "debug/ge_op_types.h" -#include "inc/graph/debug/ge_attr_define.h" -#include "graph/utils/transformer_utils.h" -#include "graph/utils/node_utils.h" -#include "graph/utils/graph_utils_ex.h" -#include "graph/attribute_group/attr_group_shape_env.h" - -namespace { -/* - * netoutput1 - * | - * add - * / \ - * data1 data2 - */ -ge::ComputeGraphPtr BuildSubGraph(const std::string &name) { - ge::ut::GraphBuilder builder(name); - auto data1 = builder.AddNode(name + "data1", "Data", 1, 1); - auto data2 = builder.AddNode(name + "data2", "Data", 1, 1); - auto add = builder.AddNode(name + "sub", "Sub", 2, 1); - auto netoutput = builder.AddNode(name + "_netoutput", "NetOutput", 1, 1); - - ge::AttrUtils::SetInt(data1->GetOpDesc(), "_parent_node_index", static_cast(0)); - ge::AttrUtils::SetInt(data2->GetOpDesc(), "_parent_node_index", static_cast(1)); - ge::AttrUtils::SetInt(netoutput->GetOpDesc()->MutableInputDesc(0), "_parent_node_index", static_cast(0)); - - builder.AddDataEdge(data1, 0, add, 0); - builder.AddDataEdge(data2, 0, add, 1); - builder.AddDataEdge(add, 0, netoutput, 0); - - return builder.GetGraph(); -} -/* - * netoutput - * | - * if - * / \ - * data1 data2 - */ -ge::ComputeGraphPtr BuildMainGraphWithIf(const std::string &name) { - ge::ut::GraphBuilder builder(name); - auto data1 = builder.AddNode("data1", "Data", 1, 1); - auto data2 = builder.AddNode("data2", "Data", 1, 1); - auto if1 = builder.AddNode("if", "If", 2, 1); - auto netoutput1 = builder.AddNode("netoutput", "NetOutput", 1, 1); - - builder.AddDataEdge(data1, 0, if1, 0); - builder.AddDataEdge(data2, 0, if1, 1); - builder.AddDataEdge(if1, 0, netoutput1, 0); - - auto main_graph = builder.GetGraph(); - - auto sub1 = BuildSubGraph("sub1"); - sub1->SetParentGraph(main_graph); - sub1->SetParentNode(main_graph->FindNode("if")); - main_graph->FindNode("if")->GetOpDesc()->AddSubgraphName("sub1"); - main_graph->FindNode("if")->GetOpDesc()->SetSubgraphInstanceName(0, "sub1"); - main_graph->AddSubgraph("sub1", sub1); - - auto sub2 = BuildSubGraph("sub2"); - sub2->SetParentGraph(main_graph); - sub2->SetParentNode(main_graph->FindNode("if")); - main_graph->FindNode("if")->GetOpDesc()->AddSubgraphName("sub2"); - main_graph->FindNode("if")->GetOpDesc()->SetSubgraphInstanceName(1, "sub2"); - main_graph->AddSubgraph("sub2", sub2); - - return main_graph; -} -/* - * netoutput - * | \ \ - * node4 node5 node6 - * | \ - * node2 node3 - * \ / - * node1 - */ -ge::ComputeGraphPtr BuildNormalGraph(const std::string &name) { - auto builder = ge::ut::GraphBuilder(name); - const auto &node1 = builder.AddNode("node1", "node1", 0, 2); - const auto &node2 = builder.AddNode("node2", "node2", 1, 1); - const auto &node3 = builder.AddNode("node3", "node3", 1, 1); - const auto &node4 = builder.AddNode("node4", "node4", 1, 1); - const auto &node5 = builder.AddNode("node5", "node5", 1, 1); - const auto &node6 = builder.AddNode("node6", "node6", 0, 1); - const auto &netoutput = builder.AddNode("netoutput", "netoutput", 3, 1); - - builder.AddDataEdge(node1, 0, node2, 0); - builder.AddDataEdge(node1, 1, node3, 0); - builder.AddDataEdge(node2, 0, node4, 0); - builder.AddDataEdge(node3, 0, node5, 0); - builder.AddDataEdge(node4, 0, netoutput, 0); - builder.AddDataEdge(node5, 0, netoutput, 1); - builder.AddDataEdge(node6, 0, netoutput, 2); - - builder.AddControlEdge(node1, node2); - builder.AddControlEdge(node1, node3); - builder.AddControlEdge(node2, node4); - builder.AddControlEdge(node3, node5); - builder.AddControlEdge(node4, netoutput); - builder.AddControlEdge(node5, netoutput); - builder.AddControlEdge(node6, netoutput); - return builder.GetGraph(); -} - -/* - * variable data - * / \ | - * node1 node2 node3 - * | | | - * | | node4 - * \ | / - * node5 - */ -ge::ComputeGraphPtr BuildDelayTopoGraph(const std::string &name) { - auto builder = ge::ut::GraphBuilder(name); - const auto &variable = builder.AddNode("variable", ge::VARIABLE, 0, 2); - const auto &node1 = builder.AddNode("node1", "node1", 1, 1); - const auto &node2 = builder.AddNode("node2", "node2", 1, 1); - const auto &node3 = builder.AddNode("node3", "node3", 1, 1); - const auto &node4 = builder.AddNode("node4", "node4", 1, 1); - const auto &node5 = builder.AddNode("node5", "node5", 3, 0); - const auto &data = builder.AddNode("data", "DATA", 0, 1); - - builder.AddDataEdge(variable, 0, node1, 0); - builder.AddDataEdge(variable, 1, node2, 0); - builder.AddDataEdge(node1, 0, node5, 0); - builder.AddDataEdge(node2, 0, node5, 1); - builder.AddDataEdge(data, 0, node3, 0); - builder.AddDataEdge(node3, 0, node4, 0); - builder.AddDataEdge(node4, 0, node5, 2); - - builder.AddControlEdge(node2, node3); - return builder.GetGraph(); -} - -/* - * variable data - * / \ | - * | node1 node2 - * | | | - * | | node3 - * \ | / - * node4 - */ -ge::ComputeGraphPtr BuildDelayTopoGraphSkipInput(const std::string &name) { - auto builder = ge::ut::GraphBuilder(name); - const auto &variable = builder.AddNode("variable", ge::VARIABLE, 0, 2); - const auto &node1 = builder.AddNode("node1", "node1", 1, 1); - const auto &node2 = builder.AddNode("node3", "node2", 1, 1); - const auto &node3 = builder.AddNode("node4", "node3", 1, 1); - const auto &node4 = builder.AddNode("node5", "node4", 3, 0); - const auto &data = builder.AddNode("data", "DATA", 0, 1); - - builder.AddDataEdge(variable, 0, node1, 0); - builder.AddDataEdge(variable, 1, node4, 1); - builder.AddDataEdge(node1, 0, node4, 0); - builder.AddDataEdge(data, 0, node2, 0); - builder.AddDataEdge(node2, 0, node3, 0); - builder.AddDataEdge(node3, 0, node4, 2); - - builder.AddControlEdge(node1, node2); - return builder.GetGraph(); -} - -/* - * constant const variable data - * \ | / \ | - * node1 node2 node3 - * | | | - * | | node4 - * \ | / - * node5 - */ -ge::ComputeGraphPtr BuildDelayTopoGraphMultiInput(const std::string &name, bool all_is_long_life = true) { - auto builder = ge::ut::GraphBuilder(name); - const auto &constant = builder.AddNode("const", ge::CONSTANT, 0, 1); - auto type = ge::CONSTANTOP; - if (!all_is_long_life) { - type = "test"; - } - const auto &constantop = builder.AddNode("constant", type, 0, 1); - const auto &variable = builder.AddNode("variable", ge::VARIABLE, 0, 2); - const auto &node1 = builder.AddNode("node1", "node1", 3, 1); - const auto &node2 = builder.AddNode("node2", "node2", 1, 1); - const auto &node3 = builder.AddNode("node3", "node3", 1, 1); - const auto &node4 = builder.AddNode("node4", "node4", 1, 1); - const auto &node5 = builder.AddNode("node5", "node5", 3, 0); - const auto &data = builder.AddNode("data", "DATA", 0, 1); - - builder.AddDataEdge(constant, 0, node1, 0); - builder.AddDataEdge(constantop, 0, node1, 1); - builder.AddDataEdge(variable, 0, node1, 2); - builder.AddDataEdge(variable, 1, node2, 0); - builder.AddDataEdge(node1, 0, node5, 0); - builder.AddDataEdge(node2, 0, node5, 1); - builder.AddDataEdge(data, 0, node3, 0); - builder.AddDataEdge(node3, 0, node4, 0); - builder.AddDataEdge(node4, 0, node5, 2); - - builder.AddControlEdge(node2, node3); - return builder.GetGraph(); -} -} - -namespace ge -{ - class UtestComputeGraph : public testing::Test { - protected: - void SetUp() override {} - void TearDown() override {} - }; - -TEST_F(UtestComputeGraph, GetAllNodes_success) { - auto graph = std::make_shared("graph"); - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetDataType(DT_FLOAT); - tensor_desc->SetFormat(FORMAT_CHWN); - - auto op_desc = std::make_shared(); - op_desc->AddInputDesc(tensor_desc->Clone()); - graph->AddNode(op_desc); - graph->AddNode(op_desc); - - auto node_filter = [](const Node &node){ return true;}; - auto graph_filter = [](const Node &node, const char *str, const ComputeGraphPtr &graph){ return true;}; - auto out_nodes = graph->GetAllNodes(node_filter, graph_filter); - EXPECT_EQ(out_nodes.size(), 2); -} - -TEST_F(UtestComputeGraph, GetNodes_success) { - auto graph = std::make_shared("graph"); - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetDataType(DT_FLOAT); - tensor_desc->SetFormat(FORMAT_CHWN); - - auto op_desc = std::make_shared(); - op_desc->AddInputDesc(tensor_desc->Clone()); - graph->AddNode(op_desc); - graph->AddNode(op_desc); - auto node_filter = [](const Node &node){ return true;}; - auto graph_filter = [](const Node &node, const char *str, const ComputeGraphPtr &graph){ return true;}; - auto out_nodes = graph->GetNodes(true, node_filter, graph_filter); - EXPECT_EQ(out_nodes.size(), 2); -} - -TEST_F(UtestComputeGraph, AddNodeFront_success) { - auto graph = std::make_shared("graph"); - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetDataType(DT_FLOAT); - tensor_desc->SetFormat(FORMAT_CHWN); - - auto op_desc = std::make_shared("node1", "node1"); - op_desc->AddInputDesc(tensor_desc->Clone()); - auto node = graph->AddNode(op_desc); - - auto op_desc1 = std::make_shared("add_front", "add_front"); - op_desc1->AddInputDesc(tensor_desc->Clone()); - auto nodeptr = graph->AddNodeFront(node); - EXPECT_EQ(node, nodeptr); -} - -/* - Data1 Data2 Data1 Data2 - | | | | - Relu1 Relu2 Relu1 Relu2 - / \ | / \ | - Relu3 Relu4 | Relu3 Relu4 | - | | | | \ | ------- - | | | Relu5 AppendNode | - | | | | | | - Relu5 Relu6 | ---> \ Relu6 | - \ / | \ / | - \ / | Add | - Add | | | - \ / \ / - Add2 Add2 - | | - Output Output - * 在Relu4后插入一个append算子 -*/ -TEST_F(UtestComputeGraph,InsertNode_success) { - // 开启稳定GE排序 - std::map options_map; - options_map["ge.topoSortingMode"] = "3"; - GetThreadLocalContext().SetGraphOption(options_map); - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetDataType(DT_FLOAT); - tensor_desc->SetFormat(FORMAT_CHWN); - - auto data1 = builder.AddNode("Data1", "Data", 0, 1); - auto data2 = builder.AddNode("Data2", "Data", 0, 1); - auto relu1 = builder.AddNode("Relu1", "Relu", 1, 1); - auto relu2 = builder.AddNode("Relu2", "Relu", 1, 1); - auto relu3 = builder.AddNode("Relu3", "Relu", 1, 1); - auto relu4 = builder.AddNode("Relu4", "Relu", 1, 1); - auto relu5 = builder.AddNode("Relu5", "Relu", 1, 1); - auto relu6 = builder.AddNode("Relu6", "Relu", 1, 1); - auto add_node = builder.AddNode("Add", "Add", 2, 1); - auto add2_node = builder.AddNode("Add2", "Add", 2, 1); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - builder.AddDataEdge(data1, 0, relu1, 0); - builder.AddDataEdge(data2, 0, relu2, 0); - builder.AddDataEdge(relu1, 0, relu3, 0); - builder.AddDataEdge(relu1, 0, relu4, 0); - builder.AddDataEdge(relu3, 0, relu5, 0); - builder.AddDataEdge(relu4, 0, relu6, 0); - builder.AddDataEdge(relu5, 0, add_node, 0); - builder.AddDataEdge(relu6, 0, add_node, 1); - builder.AddDataEdge(relu2, 0, add2_node, 0); - builder.AddDataEdge(add_node, 0, add2_node, 1); - builder.AddDataEdge(add2_node, 0, netoutput, 0); - auto graph = builder.GetGraph(); - - std::vector expected_stable_rdfs_names = - {"Data1", "Data2", "Relu1", "Relu2", "Relu3", "Relu4", "Relu5", "Relu6", "Add", "Add2", "Netoutput"}; - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - std::vector stable_rdfs_names; - for (auto &node : graph->GetDirectNode()) { - stable_rdfs_names.push_back(node->GetName()); - } - EXPECT_EQ(stable_rdfs_names, expected_stable_rdfs_names); - - auto op_desc_new = std::make_shared("append_1", "Add"); - op_desc_new->AddInputDesc(tensor_desc->Clone()); - op_desc_new->AddInputDesc(tensor_desc->Clone()); - op_desc_new->AddOutputDesc(tensor_desc->Clone()); - auto append_node = graph->InsertNode(relu4, op_desc_new); - ASSERT_NE(append_node, nullptr); - GraphUtils::RemoveEdge(relu4->GetOutAnchor(0), relu6->GetInAnchor(0)); - builder.AddDataEdge(relu2, 0, append_node, 0); - builder.AddDataEdge(relu4, 0, append_node, 1); - builder.AddDataEdge(append_node, 0, relu6, 0); - - auto op_desc_new_2 = std::make_shared("append_2", "Rule"); - op_desc_new_2->AddInputDesc(tensor_desc->Clone()); - op_desc_new_2->AddOutputDesc(tensor_desc->Clone()); - auto append_node_2 = graph->InsertNodeBefore(add2_node, op_desc_new_2); - ASSERT_NE(append_node_2, nullptr); - GraphUtils::RemoveEdge(add_node->GetOutAnchor(0), add2_node->GetInAnchor(1)); - builder.AddDataEdge(add_node, 0, append_node_2, 0); - builder.AddDataEdge(append_node_2, 0, add2_node, 1); - expected_stable_rdfs_names = - {"Data1", "Data2", "Relu1", "Relu2", "Relu3", "Relu4", "append_1", "Relu5", "Relu6", "Add", - "append_2", "Add2", "Netoutput"}; - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - stable_rdfs_names.clear(); - for (auto &node : graph->GetDirectNode()) { - stable_rdfs_names.push_back(node->GetName()); - } - EXPECT_EQ(stable_rdfs_names, expected_stable_rdfs_names); -} - -/* - Data1 Data2 Data1 Data2 - | | | | - Relu1 Relu2 Relu1 Relu2 - / \ | / \ | - Relu3 Relu4 | Relu3 Relu4 | - | | | | \ | ------- - | | | Relu5 AppendNode | - | | | | | | - Relu5 Relu6 | ---> \ Relu6 | - \ / | \ / | - \ / | Add | - Add | | | - \ / \ / - Add2 Add2 - | | - Output Output - - * 在Relu2后面插入一个append - * 测试当存在大topo算子连边向小topo算子的改变时,topo序发生变化 - * Relu4 ->AppendNode -*/ -TEST_F(UtestComputeGraph, InsertNode_bigtopo_to_smalltopo) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetDataType(DT_FLOAT); - tensor_desc->SetFormat(FORMAT_CHWN); - - auto data1 = builder.AddNode("Data1", "Data", 0, 1); - auto data2 = builder.AddNode("Data2", "Data", 0, 1); - auto relu1 = builder.AddNode("Relu1", "Relu", 1, 1); - auto relu2 = builder.AddNode("Relu2", "Relu", 1, 1); - auto relu3 = builder.AddNode("Relu3", "Relu", 1, 1); - auto relu4 = builder.AddNode("Relu4", "Relu", 1, 1); - auto relu5 = builder.AddNode("Relu5", "Relu", 1, 1); - auto relu6 = builder.AddNode("Relu6", "Relu", 1, 1); - auto add_node = builder.AddNode("Add", "Add", 2, 1); - auto add2_node = builder.AddNode("Add2", "Add", 2, 1); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - builder.AddDataEdge(data1, 0, relu1, 0); - builder.AddDataEdge(data2, 0, relu2, 0); - builder.AddDataEdge(relu1, 0, relu3, 0); - builder.AddDataEdge(relu1, 0, relu4, 0); - builder.AddDataEdge(relu3, 0, relu5, 0); - builder.AddDataEdge(relu4, 0, relu6, 0); - builder.AddDataEdge(relu5, 0, add_node, 0); - builder.AddDataEdge(relu6, 0, add_node, 1); - builder.AddDataEdge(relu2, 0, add2_node, 0); - builder.AddDataEdge(add_node, 0, add2_node, 1); - builder.AddDataEdge(add2_node, 0, netoutput, 0); - auto graph = builder.GetGraph(); - - std::vector expected_stable_rdfs_names = - {"Data1", "Data2", "Relu1", "Relu2", "Relu3", "Relu4", "Relu5", "Relu6", "Add", "Add2", "Netoutput"}; - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - std::vector stable_rdfs_names; - for (auto &node : graph->GetDirectNode()) { - stable_rdfs_names.push_back(node->GetName()); - } - EXPECT_EQ(stable_rdfs_names, expected_stable_rdfs_names); - - auto op_desc_new = std::make_shared("append_1", "Add"); - op_desc_new->AddInputDesc(tensor_desc->Clone()); - op_desc_new->AddInputDesc(tensor_desc->Clone()); - op_desc_new->AddOutputDesc(tensor_desc->Clone()); - auto append_node = graph->InsertNode(relu2, op_desc_new); - ASSERT_NE(append_node, nullptr); - GraphUtils::RemoveEdge(relu4->GetOutAnchor(0), relu6->GetInAnchor(0)); - builder.AddDataEdge(relu2, 0, append_node, 0); - builder.AddDataEdge(relu4, 0, append_node, 1); - builder.AddDataEdge(append_node, 0, relu6, 0); - // 由于存在大topo的有连边给小topo算子,所以顺序变化 - expected_stable_rdfs_names = - {"Data1", "Data2", "Relu1", "Relu2", "Relu4", "append_1", "Relu3", "Relu5", "Relu6", "Add", "Add2", "Netoutput"}; - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - stable_rdfs_names.clear(); - for (auto &node : graph->GetDirectNode()) { - stable_rdfs_names.push_back(node->GetName()); - } - EXPECT_EQ(stable_rdfs_names, expected_stable_rdfs_names); -} - -/* - Data1 Data2 Data1 Data2 Data1 Data2 Data1 Data2 - | | | | | | | | - Relu1 Relu2 Relu1 Relu2 Relu1 Relu2 Relu1 Relu2 - / \ | / \ | / \ | / \ | - Relu3 Relu4 | fuse_node Relu4 | fuse_node Relu4 | fuse_node fuse_node2 | - | | | | \ | | \ | | \ | - | | | | | | append_node1 Relu6 | append_node1 | | - | | | | | | | | | | | | - Relu5 Relu6 | ---> \ Relu6 | --> \ append_node2 | -> \ / | - \ / | \ | | \ / | Add | - \ / | \ / | Add | | | - \ / | Add | | | \ / - Add | | | \ / Add2 - \ / \ / Add2 | - Add2 Add2 | Output - | | Output - Output Output - * 在将Relu3和Relu4融合 -*/ - -TEST_F(UtestComputeGraph, FuseNodeKeepTopo_success) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetDataType(DT_FLOAT); - tensor_desc->SetFormat(FORMAT_CHWN); - - auto data1 = builder.AddNode("Data1", "Data", 0, 1); - auto data2 = builder.AddNode("Data2", "Data", 0, 1); - auto relu1 = builder.AddNode("Relu1", "Relu", 1, 1); - auto relu2 = builder.AddNode("Relu2", "Relu", 1, 1); - auto relu3 = builder.AddNode("Relu3", "Relu", 1, 1); - AttrUtils::SetStr(relu3->GetOpDesc(), public_attr::USER_STREAM_LABEL, "test_stream"); - auto relu4 = builder.AddNode("Relu4", "Relu", 1, 1); - auto relu5 = builder.AddNode("Relu5", "Relu", 1, 1); - auto relu6 = builder.AddNode("Relu6", "Relu", 1, 1); - auto add_node = builder.AddNode("Add", "Add", 2, 1); - auto add2_node = builder.AddNode("Add2", "Add", 2, 1); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - builder.AddDataEdge(data1, 0, relu1, 0); - builder.AddDataEdge(data2, 0, relu2, 0); - builder.AddDataEdge(relu1, 0, relu3, 0); - builder.AddDataEdge(relu1, 0, relu4, 0); - builder.AddDataEdge(relu3, 0, relu5, 0); - builder.AddDataEdge(relu4, 0, relu6, 0); - builder.AddDataEdge(relu5, 0, add_node, 0); - builder.AddDataEdge(relu6, 0, add_node, 1); - builder.AddDataEdge(relu2, 0, add2_node, 0); - builder.AddDataEdge(add_node, 0, add2_node, 1); - builder.AddDataEdge(add2_node, 0, netoutput, 0); - auto graph = builder.GetGraph(); - - auto op_desc_new = std::make_shared("fuse_node", "Relu"); - op_desc_new->AddInputDesc(tensor_desc->Clone()); - op_desc_new->AddOutputDesc(tensor_desc->Clone()); - auto fuse_node_vec = graph->FuseNodeKeepTopo({relu3, relu5}, {op_desc_new}); - ASSERT_EQ(fuse_node_vec.size(), 1); - std::string inherited_stream_label; - AttrUtils::GetStr(op_desc_new, public_attr::USER_STREAM_LABEL, inherited_stream_label); - EXPECT_STREQ(inherited_stream_label.c_str(), "test_stream"); - GraphUtils::RemoveEdge(relu1->GetOutAnchor(0), relu3->GetInAnchor(0)); - GraphUtils::RemoveEdge(relu5->GetOutAnchor(0), add_node->GetInAnchor(0)); - builder.AddDataEdge(relu1, 0, fuse_node_vec.front(), 0); - - auto op_desc_append = std::make_shared("append_node", "Relu"); - op_desc_append->AddInputDesc(tensor_desc->Clone()); - op_desc_append->AddOutputDesc(tensor_desc->Clone()); - auto op_desc_append2 = std::make_shared("append_node2", "Relu"); - op_desc_append2->AddInputDesc(tensor_desc->Clone()); - op_desc_append2->AddOutputDesc(tensor_desc->Clone()); - auto append_node_vec = graph->InsertNodes(fuse_node_vec.front(), {op_desc_append, op_desc_append2}); - ASSERT_EQ(append_node_vec.size(), 2); - builder.AddDataEdge(fuse_node_vec.front(), 0, append_node_vec[0], 0); - builder.AddDataEdge(append_node_vec[0], 0, add_node, 0); - GraphUtils::RemoveEdge(relu6->GetOutAnchor(0), add_node->GetInAnchor(1)); - builder.AddDataEdge(relu6, 0, append_node_vec[1], 0); - - auto op_desc_new2 = std::make_shared("fuse_node2", "Relu"); - op_desc_new2->AddInputDesc(tensor_desc->Clone()); - op_desc_new2->AddOutputDesc(tensor_desc->Clone()); - std::string not_support_reason; - ASSERT_TRUE(graph->IsSupportFuse({relu4, relu6, append_node_vec[1]}, not_support_reason)); - auto fuse_node_vec2 = graph->FuseNodeKeepTopo({relu4, relu6, append_node_vec[1]}, {op_desc_new2}); - ASSERT_EQ(fuse_node_vec2.size(), 1); - GraphUtils::RemoveEdge(relu1->GetOutAnchor(0), relu4->GetInAnchor(0)); - builder.AddDataEdge(fuse_node_vec2.front(), 0, add_node, 1); - builder.AddDataEdge(relu1, 0, fuse_node_vec2.front(), 0); - graph->RemoveNode(relu3); - graph->RemoveNode(relu5); - graph->RemoveNode(relu4); - graph->RemoveNode(relu6); - graph->RemoveNode(append_node_vec[1]); - - // 由于存在大topo的有连边给小topo算子,所以顺序变化 - std::vector expected_stable_rdfs_names = - {"Data1", "Data2", "Relu1", "Relu2", "fuse_node", "append_node", "fuse_node2", "Add", "Add2", "Netoutput"}; - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - std::vector stable_rdfs_names; - for (auto &node : graph->GetDirectNode()) { - stable_rdfs_names.push_back(node->GetName()); - } - EXPECT_EQ(stable_rdfs_names, expected_stable_rdfs_names); - // 关闭稳定GE排序 - std::map options_map; - GetThreadLocalContext().SetGraphOption(options_map); -} - -TEST_F(UtestComputeGraph, FuseNodeKeepTopo_inherit_sk_option) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetDataType(DT_FLOAT); - tensor_desc->SetFormat(FORMAT_CHWN); - - auto data1 = builder.AddNode("Data1", "Data", 0, 1); - auto data2 = builder.AddNode("Data2", "Data", 0, 1); - auto relu1 = builder.AddNode("Relu1", "Relu", 1, 1); - auto relu2 = builder.AddNode("Relu2", "Relu", 1, 1); - auto relu3 = builder.AddNode("Relu3", "Relu", 1, 1); - auto relu4 = builder.AddNode("Relu4", "Relu", 1, 1); - auto relu5 = builder.AddNode("Relu5", "Relu", 1, 1); - auto relu6 = builder.AddNode("Relu6", "Relu", 1, 1); - auto add_node = builder.AddNode("Add", "Add", 2, 1); - auto add2_node = builder.AddNode("Add2", "Add", 2, 1); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - builder.AddDataEdge(data1, 0, relu1, 0); - builder.AddDataEdge(data2, 0, relu2, 0); - builder.AddDataEdge(relu1, 0, relu3, 0); - builder.AddDataEdge(relu1, 0, relu4, 0); - builder.AddDataEdge(relu3, 0, relu5, 0); - builder.AddDataEdge(relu4, 0, relu6, 0); - builder.AddDataEdge(relu5, 0, add_node, 0); - builder.AddDataEdge(relu6, 0, add_node, 1); - builder.AddDataEdge(relu2, 0, add2_node, 0); - builder.AddDataEdge(add_node, 0, add2_node, 1); - builder.AddDataEdge(add2_node, 0, netoutput, 0); - auto graph = builder.GetGraph(); - - auto op_desc_new = std::make_shared("fuse_node", "Relu"); - op_desc_new->AddInputDesc(tensor_desc->Clone()); - op_desc_new->AddOutputDesc(tensor_desc->Clone()); - AttrUtils::SetStr(relu3->GetOpDesc(), "_super_kernel_scope", "scope1"); - AttrUtils::SetStr(relu5->GetOpDesc(), "_super_kernel_scope", "scope2"); - auto fuse_node_vec = graph->FuseNodeKeepTopo({relu3, relu5}, {op_desc_new}); - ASSERT_EQ(fuse_node_vec.size(), 0); - - AttrUtils::SetStr(relu5->GetOpDesc(), "_super_kernel_scope", "scope1"); - fuse_node_vec = graph->FuseNodeKeepTopo({relu3, relu5}, {op_desc_new}); - ASSERT_EQ(fuse_node_vec.size(), 1); - - AttrUtils::SetStr(relu3->GetOpDesc(), "_super_kernel_options", "option1"); - AttrUtils::SetStr(relu5->GetOpDesc(), "_super_kernel_options", "option2"); - fuse_node_vec = graph->FuseNodeKeepTopo({relu3, relu5}, {op_desc_new}); - ASSERT_EQ(fuse_node_vec.size(), 0); - - AttrUtils::SetStr(relu5->GetOpDesc(), "_super_kernel_options", "option1"); - fuse_node_vec = graph->FuseNodeKeepTopo({relu3, relu5}, {op_desc_new}); - ASSERT_EQ(fuse_node_vec.size(), 1); -} - -TEST_F(UtestComputeGraph, FuseNodeKeepTopo_inherit_coreNum_option) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetDataType(DT_FLOAT); - tensor_desc->SetFormat(FORMAT_CHWN); - - auto data1 = builder.AddNode("Data1", "Data", 0, 1); - auto data2 = builder.AddNode("Data2", "Data", 0, 1); - auto relu1 = builder.AddNode("Relu1", "Relu", 1, 1); - auto relu2 = builder.AddNode("Relu2", "Relu", 1, 1); - auto relu3 = builder.AddNode("Relu3", "Relu", 1, 1); - auto relu4 = builder.AddNode("Relu4", "Relu", 1, 1); - auto relu5 = builder.AddNode("Relu5", "Relu", 1, 1); - auto relu6 = builder.AddNode("Relu6", "Relu", 1, 1); - auto add_node = builder.AddNode("Add", "Add", 2, 1); - auto add2_node = builder.AddNode("Add2", "Add", 2, 1); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - builder.AddDataEdge(data1, 0, relu1, 0); - builder.AddDataEdge(data2, 0, relu2, 0); - builder.AddDataEdge(relu1, 0, relu3, 0); - builder.AddDataEdge(relu1, 0, relu4, 0); - builder.AddDataEdge(relu3, 0, relu5, 0); - builder.AddDataEdge(relu4, 0, relu6, 0); - builder.AddDataEdge(relu5, 0, add_node, 0); - builder.AddDataEdge(relu6, 0, add_node, 1); - builder.AddDataEdge(relu2, 0, add2_node, 0); - builder.AddDataEdge(add_node, 0, add2_node, 1); - builder.AddDataEdge(add2_node, 0, netoutput, 0); - auto graph = builder.GetGraph(); - - auto op_desc_new = std::make_shared("fuse_node", "Relu"); - op_desc_new->AddInputDesc(tensor_desc->Clone()); - op_desc_new->AddOutputDesc(tensor_desc->Clone()); - AttrUtils::SetStr(relu3->GetOpDesc(), ge::public_attr::OP_AI_CORE_NUM, "5"); - AttrUtils::SetStr(relu5->GetOpDesc(), ge::public_attr::OP_AI_CORE_NUM, "6"); - auto fuse_node_vec = graph->FuseNodeKeepTopo({relu3, relu5}, {op_desc_new}); - ASSERT_EQ(fuse_node_vec.size(), 0); - - AttrUtils::SetStr(relu5->GetOpDesc(), ge::public_attr::OP_AI_CORE_NUM, "5"); - fuse_node_vec = graph->FuseNodeKeepTopo({relu3, relu5}, {op_desc_new}); - ASSERT_EQ(fuse_node_vec.size(), 1); - - AttrUtils::SetStr(relu3->GetOpDesc(), ge::public_attr::OP_VECTOR_CORE_NUM, "10"); - AttrUtils::SetStr(relu5->GetOpDesc(), ge::public_attr::OP_VECTOR_CORE_NUM, "11"); - fuse_node_vec = graph->FuseNodeKeepTopo({relu3, relu5}, {op_desc_new}); - ASSERT_EQ(fuse_node_vec.size(), 0); - - AttrUtils::SetStr(relu5->GetOpDesc(), ge::public_attr::OP_VECTOR_CORE_NUM, "10"); - fuse_node_vec = graph->FuseNodeKeepTopo({relu3, relu5}, {op_desc_new}); - ASSERT_EQ(fuse_node_vec.size(), 1); -} - -TEST_F(UtestComputeGraph, StreamLableNotSame_FuseNodeKeepTopo_failed) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetDataType(DT_FLOAT); - tensor_desc->SetFormat(FORMAT_CHWN); - - auto data1 = builder.AddNode("Data1", "Data", 0, 1); - auto data2 = builder.AddNode("Data2", "Data", 0, 1); - auto relu1 = builder.AddNode("Relu1", "Relu", 1, 1); - auto relu2 = builder.AddNode("Relu2", "Relu", 1, 1); - auto relu3 = builder.AddNode("Relu3", "Relu", 1, 1); - AttrUtils::SetStr(relu3->GetOpDesc(), public_attr::USER_STREAM_LABEL, "test_stream1"); - auto relu4 = builder.AddNode("Relu4", "Relu", 1, 1); - auto relu5 = builder.AddNode("Relu5", "Relu", 1, 1); - AttrUtils::SetStr(relu5->GetOpDesc(), public_attr::USER_STREAM_LABEL, "test_stream2"); - auto relu6 = builder.AddNode("Relu6", "Relu", 1, 1); - auto add_node = builder.AddNode("Add", "Add", 2, 1); - auto add2_node = builder.AddNode("Add2", "Add", 2, 1); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - builder.AddDataEdge(data1, 0, relu1, 0); - builder.AddDataEdge(data2, 0, relu2, 0); - builder.AddDataEdge(relu1, 0, relu3, 0); - builder.AddDataEdge(relu1, 0, relu4, 0); - builder.AddDataEdge(relu3, 0, relu5, 0); - builder.AddDataEdge(relu4, 0, relu6, 0); - builder.AddDataEdge(relu5, 0, add_node, 0); - builder.AddDataEdge(relu6, 0, add_node, 1); - builder.AddDataEdge(relu2, 0, add2_node, 0); - builder.AddDataEdge(add_node, 0, add2_node, 1); - builder.AddDataEdge(add2_node, 0, netoutput, 0); - auto graph = builder.GetGraph(); - - auto op_desc_new = std::make_shared("fuse_node", "Relu"); - op_desc_new->AddInputDesc(tensor_desc->Clone()); - op_desc_new->AddOutputDesc(tensor_desc->Clone()); - std::string not_support_reason; - EXPECT_FALSE(graph->IsSupportFuse({relu3, relu5},not_support_reason)); - EXPECT_TRUE(not_support_reason.find("test_stream1") > 0); - EXPECT_TRUE(not_support_reason.find("test_stream2") > 0); - auto fuse_node_vec = graph->FuseNodeKeepTopo({relu3, relu5}, {op_desc_new}); - EXPECT_TRUE(fuse_node_vec.empty()); - std::string inherited_stream_label; - AttrUtils::GetStr(op_desc_new, public_attr::USER_STREAM_LABEL, inherited_stream_label); - EXPECT_TRUE(inherited_stream_label.empty()); -} - -TEST_F(UtestComputeGraph, RemoveNode_success) { - auto graph = std::make_shared("graph"); - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetDataType(DT_FLOAT); - tensor_desc->SetFormat(FORMAT_CHWN); - auto op_desc = std::make_shared("node1", "node1"); - op_desc->AddInputDesc(tensor_desc->Clone()); - auto node = graph->AddNode(op_desc); - - EXPECT_EQ(graph->RemoveNode(node), GRAPH_SUCCESS); -} - -TEST_F(UtestComputeGraph, GraphMembersAreEqual_success) { - auto graph1 = std::make_shared("graph"); - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetDataType(DT_FLOAT); - tensor_desc->SetFormat(FORMAT_CHWN); - - auto op_desc = std::make_shared("node1", "node1"); - op_desc->AddInputDesc(tensor_desc->Clone()); - graph1->AddNode(op_desc); - graph1->AddNode(op_desc); - - auto graph2 = std::make_shared("graph"); - graph2->AddNode(op_desc); - EXPECT_EQ(graph1->GraphMembersAreEqual(*(graph2)), false); - graph2->AddNode(op_desc); - EXPECT_EQ(graph1->GraphMembersAreEqual(*(graph2)), true); -} - -TEST_F(UtestComputeGraph, GraphAttrsAreEqual_success) { - auto graph1 = std::make_shared("graph1"); - - int64_t val = 0; - AnyValue anyvalue; - anyvalue.SetValue(val); - graph1->SetAttr("test", anyvalue); - - auto graph2 = std::make_shared("graph2"); - EXPECT_EQ(graph1->GraphAttrsAreEqual(*(graph2)), false); - - graph2->SetAttr("test", anyvalue); - EXPECT_EQ(graph1->GraphAttrsAreEqual(*(graph2)), true); -} - -TEST_F(UtestComputeGraph, VectorInputNodePtrIsEqual_success) { - auto graph = std::make_shared("graph"); - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetDataType(DT_FLOAT); - tensor_desc->SetFormat(FORMAT_CHWN); - - auto op_desc = std::make_shared("node1", "node1"); - op_desc->AddInputDesc(tensor_desc->Clone()); - auto node = graph->AddNode(op_desc); - - std::vector leftnodes{node}; - std::vector rightnodes{node}; - EXPECT_EQ(graph->VectorInputNodePtrIsEqual(leftnodes, rightnodes), true); - rightnodes.push_back(node); - EXPECT_EQ(graph->VectorInputNodePtrIsEqual(leftnodes, rightnodes), false); -} - -TEST_F(UtestComputeGraph, RemoveConstInput_success) { - auto graph = std::make_shared("graph"); - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetDataType(DT_FLOAT); - tensor_desc->SetFormat(FORMAT_CHWN); - - auto op_desc = std::make_shared("node1", CONSTANT); - op_desc->AddInputDesc(tensor_desc->Clone()); - op_desc->AddOutputDesc(tensor_desc->Clone()); - - auto node1 = graph->AddNode(op_desc); - auto node2 = graph->AddNode(op_desc); - GraphUtils::AddEdge(node1->GetOutControlAnchor(), node2->GetInControlAnchor()); - EXPECT_EQ(graph->RemoveConstInput(node1), GRAPH_SUCCESS); -} - -TEST_F(UtestComputeGraph, RemoveSubGraph_success) { - auto rootgraph = std::make_shared("rootgraph"); - auto subgraph = std::make_shared("subgraph"); - rootgraph->AddSubGraph(subgraph); - EXPECT_EQ(rootgraph->RemoveSubGraph(subgraph), GRAPH_SUCCESS); -} - -TEST_F(UtestComputeGraph, Set_GetShareParamLayer_success) { - auto graph = std::make_shared("graph"); - std::map, std::vector> params_share_map{{{"test"},{"test"}}}; - graph->SetShareParamLayer(params_share_map); - EXPECT_EQ(graph->GetShareParamLayer().size(), 1); -} - -TEST_F(UtestComputeGraph, Set_GetGraphOutNodes_success) { - auto graph = std::make_shared("graph"); - std::map> out_nodes_map{{"test",{1}}}; - auto opdesc = std::make_shared(); - graph->SetGraphOutNodes(out_nodes_map); - EXPECT_EQ(graph->GetGraphOutNodes().size(), 1); - std::map> append_out_nodes_map{{"test2",{2}}}; - graph->AppendGraphOutNodes(append_out_nodes_map); - EXPECT_EQ(graph->GetGraphOutNodes().size(), 2); -} - -TEST_F(UtestComputeGraph, Set_GetOrigGraph_success) { - auto graph = std::make_shared("graph"); - auto origin_graph = std::make_shared("origin_graph"); - graph->SetOrigGraph(origin_graph); - EXPECT_NE(graph->GetOrigGraph(), nullptr); -} - -TEST_F(UtestComputeGraph, GetOutputSize_success) { - auto graph = std::make_shared("graph"); - auto nodes = std::make_shared(); - graph->AddOutputNode(nodes); - EXPECT_EQ(graph->GetOutputSize(), 1); -} - -TEST_F(UtestComputeGraph, GetInputSize_success) { - auto graph = std::make_shared("graph"); - auto nodes = std::make_shared(); - graph->AddInputNode(nodes); - EXPECT_EQ(graph->GetInputSize(), 1); -} - -TEST_F(UtestComputeGraph, Set_GetNeedIteration_success) { - auto graph = std::make_shared("graph"); - graph->SetNeedIteration(true); - EXPECT_EQ(graph->GetNeedIteration(), true); -} - -TEST_F(UtestComputeGraph, UpdateInputMapping_success) { - auto graph = std::make_shared("graph"); - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetFormat(FORMAT_NCHW); - tensor_desc->SetDataType(DT_FLOAT); - auto opdesc = std::make_shared(ATTR_NAME_PARENT_NODE_INDEX, DATA); - opdesc->AddInputDesc("name1", tensor_desc->Clone()); - opdesc->AddOutputDesc("name2", tensor_desc->Clone()); - auto node = graph->AddNode(opdesc); - ge::AttrUtils::SetInt(opdesc, ATTR_NAME_PARENT_NODE_INDEX, 1); - - graph->AddInputNode(node); - std::map input_mapping{{0,1}}; - EXPECT_EQ(graph->UpdateInputMapping(input_mapping), GRAPH_SUCCESS); -} - -TEST_F(UtestComputeGraph, UpdateOutputMapping_success) { - auto graph = std::make_shared("graph"); - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetFormat(FORMAT_NCHW); - tensor_desc->SetDataType(DT_FLOAT); - auto opdesc = std::make_shared(ATTR_NAME_PARENT_NODE_INDEX, NETOUTPUT); - opdesc->AddInputDesc("name1", tensor_desc->Clone()); - opdesc->AddOutputDesc("name2", tensor_desc->Clone()); - auto node = graph->AddNode(opdesc); - ge::AttrUtils::SetInt(opdesc, ATTR_NAME_PARENT_NODE_INDEX, 1); - graph->AddOutputNode(node); - std::map output_mapping{{0,1}}; - EXPECT_EQ(graph->UpdateOutputMapping(output_mapping), GRAPH_SUCCESS); -} - -TEST_F(UtestComputeGraph, ReorderEventNodes_success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node1 = builder.AddNode(ATTR_NAME_PARENT_NODE_INDEX, SEND, 1, 1); - const auto &node2 = builder.AddNode(ATTR_NAME_PARENT_NODE_INDEX, RECV, 1, 1); - const auto &node3 = builder.AddNode(ATTR_NAME_PARENT_NODE_INDEX, RECV, 1, 1); - builder.AddControlEdge(node1, node2); - builder.AddControlEdge(node3, node1); - builder.AddControlEdge(node2, node3); - auto graph = builder.GetGraph(); - - EXPECT_EQ(graph->ReorderEventNodes(), GRAPH_SUCCESS); -} - -TEST_F(UtestComputeGraph, DFSTopologicalSorting_success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node1 = builder.AddNode("node1", "node1", 1, 1); - const auto &node2 = builder.AddNode("node2", "node2", 1, 1); - const auto &node3 = builder.AddNode("node3", "node3", 1, 1); - std::vector vec_nodes{node1, node2, node3}; - - builder.AddControlEdge(node1, node2); - builder.AddControlEdge(node3, node1); - - std::vector stack{}; - auto graph = builder.GetGraph(); - std::map map_in_edge_num{}; - EXPECT_EQ(graph->DFSTopologicalSorting(vec_nodes, map_in_edge_num, stack, false), - GRAPH_SUCCESS); -} - -TEST_F(UtestComputeGraph, BFSTopologicalSorting_success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node1 = builder.AddNode("node1", "node1", 1, 1); - const auto &node2 = builder.AddNode("node2", "node2", 1, 1); - const auto &node3 = builder.AddNode("node3", "node3", 1, 1); - std::vector vec_nodes{node1, node2, node3}; - - builder.AddControlEdge(node1, node2); - builder.AddControlEdge(node3, node1); - - std::deque stack{}; - auto graph = builder.GetGraph(); - std::map map_in_edge_num{}; - EXPECT_EQ(graph->BFSTopologicalSorting(vec_nodes, map_in_edge_num, stack), GRAPH_SUCCESS); -} - -TEST_F(UtestComputeGraph, CollectBreadthOutNode_success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node1 = builder.AddNode("node1", "node1", 2, 2); - const auto &node2 = builder.AddNode("node2", "node2", 1, 1); - const auto &node3 = builder.AddNode("node3", "node3", 1, 1); - builder.AddDataEdge(node1, 0, node2, 0); - builder.AddDataEdge(node2, 0, node1, 0); - builder.AddControlEdge(node2, node1); - builder.AddControlEdge(node1, node3); - GraphUtils::AddEdge(node1->GetOutDataAnchor(0), node3->GetInControlAnchor()); - std::map map_in_edge_num{}; - map_in_edge_num.emplace(node1, 2); - map_in_edge_num.emplace(node2, 1); - map_in_edge_num.emplace(node3, 1); - std::map breadth_node_map{}; - auto graph = builder.GetGraph(); - EXPECT_EQ(graph->CollectBreadthOutNode(node1, map_in_edge_num, breadth_node_map), GRAPH_SUCCESS); -} - -TEST_F(UtestComputeGraph, RemoveConstInputSuccess) { - auto builder = ut::GraphBuilder("graph"); - const auto &node1 = builder.AddNode("node1", CONSTANT, 2, 2); - const auto &node2 = builder.AddNode("node2", "node2", 1, 1); - builder.AddDataEdge(node1, 0, node2, 0); - auto graph = builder.GetGraph(); - EXPECT_EQ(graph->RemoveConstInput(node2), GRAPH_SUCCESS); -} - -TEST_F(UtestComputeGraph, TopologicalSorting_success) { - const auto func = [](const NodePtr &node1, const NodePtr &node2) -> bool { return true; }; - auto builder = ut::GraphBuilder("graph"); - const auto &node1 = builder.AddNode("node1", "node1", 0, 0); - const auto &node2 = builder.AddNode("node2", "node2", 0, 0); - auto graph = builder.GetGraph(); - graph->TopologicalSorting(func); - EXPECT_EQ(node1->GetOpDesc()->GetId(), 1); - EXPECT_EQ(node2->GetOpDesc()->GetId(), 0); -} - -/* - * netoutput - * | \ \ - * node4 node5 node6 - * | \ - * node2 node3 - * \ / - * node1 - */ -TEST_F(UtestComputeGraph, TopologicalSortingMode_success) { - std::map options_map; - - auto graph = BuildNormalGraph("test_topo_graph"); - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - std::vector expected_bfs_names = {"node1", "node2", "node3", "node4", "node5", "node6", "netoutput"}; - std::vector expected_dfs_names = {"node1", "node3", "node5", "node2", "node4", "node6", "netoutput"}; - std::vector bfs_names; - std::vector dfs_names; - std::vector bfs_names1; - std::vector dfs_names1; - options_map.emplace("ge.topoSortingMode", "0"); - GetThreadLocalContext().SetGraphOption(options_map); - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - const auto &graph_bfs_topo = graph->GetAllNodes(); - for (auto &node : graph_bfs_topo) { - bfs_names.push_back(node->GetName()); - } - const auto &graph_bfs_topo1 = graph->GetAllNodesPtr(); - for (auto &node : graph_bfs_topo1) { - bfs_names1.push_back(node->GetName()); - } - - options_map["ge.topoSortingMode"] = "1"; - GetThreadLocalContext().SetGraphOption(options_map); - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - - const auto &graph_dfs_topo = graph->GetAllNodes(); - for (auto &node : graph_dfs_topo) { - dfs_names.push_back(node->GetName()); - } - const auto &graph_dfs_topo1 = graph->GetAllNodesPtr(); - for (auto &node : graph_dfs_topo1) { - dfs_names1.push_back(node->GetName()); - } - options_map["ge.topoSortingMode"] = "2"; - GetThreadLocalContext().SetGraphOption(options_map); - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - - EXPECT_EQ(bfs_names, expected_bfs_names); - EXPECT_EQ(dfs_names, expected_dfs_names); - EXPECT_EQ(bfs_names1, expected_bfs_names); - EXPECT_EQ(dfs_names1, expected_dfs_names); -} - -TEST_F(UtestComputeGraph, BFSTopologicalSortingInPriorityMode_success) { - std::map options_map; - - auto graph = BuildNormalGraph("test_bfs_topo_graph"); - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - std::vector expected_bfs_names = {"node1", "node2", "node3", "node4", "node5", "node6", "netoutput"}; - std::vector bfs_names; - std::vector bfs_names1; - options_map["ge.topoSortingMode"] = "0"; - options_map["ge.exec.memoryOptimizationPolicy"] = "MemoryPriority"; - GetThreadLocalContext().SetGraphOption(options_map); - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - const auto &graph_bfs_topo = graph->GetAllNodes(); - for (auto &node : graph_bfs_topo) { - bfs_names.push_back(node->GetName()); - } - const auto &graph_bfs_topo1 = graph->GetAllNodesPtr(); - for (auto &node : graph_bfs_topo1) { - bfs_names1.push_back(node->GetName()); - } - - EXPECT_EQ(bfs_names, expected_bfs_names); - EXPECT_EQ(bfs_names1, expected_bfs_names); -} - -TEST_F(UtestComputeGraph, TrainTopologicalSortingInPriorityMode_BFS_success) { - std::map options_map; - - auto graph = BuildNormalGraph("test_train_topo_graph_bfs"); - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - std::vector expected_bfs_names = {"node1", "node2", "node3", "node4", "node5", "node6", "netoutput"}; - std::vector bfs_names; - std::vector bfs_names1; - options_map["ge.graphRunMode"] = "1"; // tarin - options_map["ge.topoSortingMode"] = ""; // no topo sort mode - options_map["ge.exec.memoryOptimizationPolicy"] = "MemoryPriority"; - GetThreadLocalContext().SetGraphOption(options_map); - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - const auto &graph_bfs_topo = graph->GetAllNodes(); - for (auto &node : graph_bfs_topo) { - bfs_names.push_back(node->GetName()); - } - const auto &graph_bfs_topo1 = graph->GetAllNodesPtr(); - for (auto &node : graph_bfs_topo1) { - bfs_names1.push_back(node->GetName()); - } - - EXPECT_EQ(bfs_names, expected_bfs_names); - EXPECT_EQ(bfs_names1, expected_bfs_names); -} - -TEST_F(UtestComputeGraph, TrainAndInvalidTopologicalSortingInPriorityMode_BFS_success) { - std::map options_map; - - auto graph = BuildNormalGraph("test_train_topo_graph_with_invalid_sort_mode_bfs"); - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - std::vector expected_bfs_names = {"node1", "node2", "node3", "node4", "node5", "node6", "netoutput"}; - std::vector bfs_names; - std::vector bfs_names1; - options_map["ge.graphRunMode"] = "1"; // tarin - options_map["ge.topoSortingMode"] = "10"; // invalid topo sort mode - options_map["ge.exec.memoryOptimizationPolicy"] = "MemoryPriority"; - GetThreadLocalContext().SetGraphOption(options_map); - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - const auto &graph_bfs_topo = graph->GetAllNodes(); - for (auto &node : graph_bfs_topo) { - bfs_names.push_back(node->GetName()); - } - const auto &graph_bfs_topo1 = graph->GetAllNodesPtr(); - for (auto &node : graph_bfs_topo1) { - bfs_names1.push_back(node->GetName()); - } - - EXPECT_EQ(bfs_names, expected_bfs_names); - EXPECT_EQ(bfs_names1, expected_bfs_names); -} - -TEST_F(UtestComputeGraph, DFSTopologicalSortingInPriorityMode_success) { - std::map options_map; - - auto graph = BuildNormalGraph("test_dfs_topo_graph"); - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - std::vector expected_dfs_names = {"node6", "node1", "node2", "node3", "node4", "node5", "netoutput"}; - std::vector dfs_names; - std::vector dfs_names1; - options_map["ge.topoSortingMode"] = "1"; - options_map["ge.exec.memoryOptimizationPolicy"] = "MemoryPriority"; - GetThreadLocalContext().SetGraphOption(options_map); - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - const auto &graph_dfs_topo = graph->GetAllNodes(); - for (auto &node : graph_dfs_topo) { - dfs_names.push_back(node->GetName()); - } - const auto &graph_dfs_topo1 = graph->GetAllNodesPtr(); - for (auto &node : graph_dfs_topo1) { - dfs_names1.push_back(node->GetName()); - } - - EXPECT_EQ(dfs_names, expected_dfs_names); - EXPECT_EQ(dfs_names1, expected_dfs_names); -} - -TEST_F(UtestComputeGraph, NotTrainTopologicalSortingInPriorityMode_DFS_success) { - std::map options_map; - - auto graph = BuildNormalGraph("test_dfs_not_train_topo_graph_dfx"); - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - std::vector expected_dfs_names = {"node6", "node1", "node2", "node3", "node4", "node5", "netoutput"}; - std::vector dfs_names; - std::vector dfs_names1; - options_map["ge.graphRunMode"] = "0"; // not tarin - options_map["ge.topoSortingMode"] = ""; - options_map["ge.exec.memoryOptimizationPolicy"] = "MemoryPriority"; - GetThreadLocalContext().SetGraphOption(options_map); - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - const auto &graph_dfs_topo = graph->GetAllNodes(); - for (auto &node : graph_dfs_topo) { - dfs_names.push_back(node->GetName()); - } - const auto &graph_dfs_topo1 = graph->GetAllNodesPtr(); - for (auto &node : graph_dfs_topo1) { - dfs_names1.push_back(node->GetName()); - } - - EXPECT_EQ(dfs_names, expected_dfs_names); - EXPECT_EQ(dfs_names1, expected_dfs_names); -} - -TEST_F(UtestComputeGraph, ReverseDfsTopologicalSortingInPriorityMode_success) { - std::map options_map; - - auto graph = BuildNormalGraph("test_reverse_dfs_topo_graph"); - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - std::vector expected_dfs_names = {"node6", "node1", "node2", "node4", "node3", "node5", "netoutput"}; - std::vector dfs_names; - std::vector dfs_names1; - options_map["ge.topoSortingMode"] = "1"; - options_map["ge.exec.memoryOptimizationPolicy"] = "MemoryPriority"; - GetThreadLocalContext().SetGraphOption(options_map); - EXPECT_EQ(graph->TopologicalSortingGraph(true), GRAPH_SUCCESS); - const auto &graph_dfs_topo = graph->GetAllNodes(); - for (auto &node : graph_dfs_topo) { - dfs_names.push_back(node->GetName()); - } - const auto &graph_dfs_topo1 = graph->GetAllNodesPtr(); - for (auto &node : graph_dfs_topo1) { - dfs_names1.push_back(node->GetName()); - } - - EXPECT_EQ(dfs_names, expected_dfs_names); - EXPECT_EQ(dfs_names1, expected_dfs_names); -} - -TEST_F(UtestComputeGraph, SortNodes_success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node1 = builder.AddNode("node1", "node1", 1, 1); - const auto &node2 = builder.AddNode("node2", "node2", 1, 1); - const auto &node3 = builder.AddNode("node3", "node3", 1, 1); - const auto &node4 = builder.AddNode("node4", "node4", 1, 0); - - builder.AddControlEdge(node1, node2); - builder.AddControlEdge(node3, node1); - builder.AddControlEdge(node2, node4); - auto graph = builder.GetGraph(); - std::map map_in_edge_num{{node1, 2},{node2, 2},{node3, 2}}; - std::vector stack{}; - EXPECT_EQ(graph->SortNodes(stack, map_in_edge_num), GRAPH_SUCCESS); -} - -TEST_F(UtestComputeGraph, GetInEdgeSize_success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node1 = builder.AddNode("node1", "node1", 2, 0); - const auto &node2 = builder.AddNode("node2", "node2", 0, 1); - const auto &node3 = builder.AddNode("node3", "node3", 0, 1); - builder.AddDataEdge(node2, 0, node1, 0); - builder.AddDataEdge(node3, 0, node1, 1); - auto graph = builder.GetGraph(); - EXPECT_EQ(graph->GetInEdgeSize(node1), 2); -} - -TEST_F(UtestComputeGraph, GetOutEdgeSize_success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node1 = builder.AddNode("node1", "node1", 0, 2); - const auto &node2 = builder.AddNode("node2", "node2", 1, 0); - const auto &node3 = builder.AddNode("node3", "node3", 1, 0); - builder.AddDataEdge(node1, 0, node2, 0); - builder.AddDataEdge(node1, 1, node3, 0); - auto graph = builder.GetGraph(); - graph->Dump(); - EXPECT_EQ(graph->GetOutEdgeSize(node1), 2); -} - -TEST_F(UtestComputeGraph, IsValid_success) { - auto graph = std::make_shared("graph"); - EXPECT_EQ(graph->IsValid(), false); -} - -TEST_F(UtestComputeGraph, InValid_success) { - const auto func = [](const NodePtr &node1, const NodePtr &node2) -> bool { return true; }; - auto builder = ut::GraphBuilder("graph"); - const auto &node1 = builder.AddNode("node1", "node1", 0, 0); - const auto &node2 = builder.AddNode("node2", "node2", 0, 0); - auto graph = builder.GetGraph(); - graph->TopologicalSorting(func); - EXPECT_EQ(graph->IsValid(), true); - graph->InValid(); - EXPECT_EQ(graph->IsValid(), false); -} - -TEST_F(UtestComputeGraph, Swap_success) { - auto builder1 = ut::GraphBuilder("graph1"); - const auto &node1 = builder1.AddNode("node1", "node1", 0, 0); - auto graph1 = builder1.GetGraph(); - auto builder2 = ut::GraphBuilder("graph2"); - const auto &node2 = builder2.AddNode("node2", "node2", 0, 0); - const auto &node3 = builder2.AddNode("node3", "node3", 0, 0); - auto graph2 = builder2.GetGraph(); - - graph1->Swap(*(graph2)); - EXPECT_EQ(graph1->GetNodes(false).size(), 2); - EXPECT_EQ(graph2->GetNodes(false).size(), 1); - EXPECT_EQ(graph1->GetName(), "graph2"); - EXPECT_EQ(graph2->GetName(), "graph1"); -} - -TEST_F(UtestComputeGraph, Swap_with_subgraph_success) { - auto graph1 = BuildMainGraphWithIf("root_graph_1"); - auto graph2 = BuildMainGraphWithIf("root_graph_2"); - - graph1->Swap(*(graph2)); - auto if_node_1 = graph1->FindFirstNodeMatchType("If"); - ASSERT_NE(if_node_1, nullptr); - auto if_node_2 = graph2->FindFirstNodeMatchType("If"); - ASSERT_NE(if_node_2, nullptr); - EXPECT_EQ(graph1->GetName(), "root_graph_2"); - EXPECT_EQ(graph2->GetName(), "root_graph_1"); - EXPECT_EQ(if_node_1->GetOwnerComputeGraph()->GetName(), "root_graph_2"); - EXPECT_EQ(if_node_2->GetOwnerComputeGraph()->GetName(), "root_graph_1"); - - const auto if_1_subgraph_name = if_node_1->GetOpDesc()->GetSubgraphInstanceName(0); - const auto if_1_subgraph = graph1->GetSubgraph(if_1_subgraph_name); - ASSERT_NE(if_1_subgraph, nullptr); - EXPECT_EQ(if_1_subgraph->GetParentGraph()->GetName(), graph1->GetName()); - - const auto if_2_subgraph_name = if_node_2->GetOpDesc()->GetSubgraphInstanceName(0); - const auto if_2_subgraph = graph2->GetSubgraph(if_2_subgraph_name); - ASSERT_NE(if_2_subgraph, nullptr); - EXPECT_EQ(if_2_subgraph->GetParentGraph()->GetName(), graph2->GetName()); - EXPECT_EQ(graph1->GetAllNodesPtr().size() > graph1->GetDirectNodePtr().size(), true); -} - -TEST_F(UtestComputeGraph, InsertToNodeList_success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node1 = builder.AddNode("node1", "node1", 0, 0); - const auto &node2 = builder.AddNode("node2", "node2", 0, 0); - const auto &node3 = builder.AddNode("node3", "node1", 0, 0); - auto graph = builder.GetGraph(); - graph->impl_->InsertToNodeList(graph->impl_->nodes_.begin(), node3); - EXPECT_EQ(*(graph->impl_->nodes_.begin()), node3); -} - -TEST_F(UtestComputeGraph, PushBackToNodeList_success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node1 = builder.AddNode("node1", "node1", 0, 0); - const auto &node2 = builder.AddNode("node2", "node2", 0, 0); - const auto &node3 = builder.AddNode("node3", "node3", 0, 0); - auto graph = builder.GetGraph(); - graph->impl_->PushBackToNodeList(node1); - auto node_list = graph->GetDirectNode(); - EXPECT_EQ(*(node_list.end() - 1), node1); - auto node_list_ptr = graph->GetDirectNodePtr(); - EXPECT_EQ(*(node_list_ptr.end() - 1), node1.get()); -} - -TEST_F(UtestComputeGraph, EmplaceBackToNodeList_success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node1 = builder.AddNode("node1", "node1", 0, 0); - const auto &node2 = builder.AddNode("node2", "node2", 0, 0); - const auto &node3 = builder.AddNode("node3", "node1", 0, 0); - auto graph = builder.GetGraph(); - graph->impl_->EmplaceBackToNodeList(node1); - auto node_list = graph->GetDirectNode(); - EXPECT_EQ(*(node_list.end() - 1), node1); - auto node_list_ptr = graph->GetDirectNodePtr(); - EXPECT_EQ(*(node_list_ptr.end() - 1), node1.get()); -} - -TEST_F(UtestComputeGraph, ClearNodeList_success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node1 = builder.AddNode("node1", "node1", 0, 0); - const auto &node2 = builder.AddNode("node2", "node2", 0, 0); - const auto &node3 = builder.AddNode("node3", "node1", 0, 0); - auto graph = builder.GetGraph(); - graph->ClearNodeList(); - EXPECT_EQ(graph->GetDirectNode().size(), 0); - EXPECT_EQ(graph->GetDirectNodePtr().size(), 0); -} - -TEST_F(UtestComputeGraph, IsolateNode_success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node1 = builder.AddNode("node1", "node1", 2, 2); - const auto &node2 = builder.AddNode("node2", "node2", 0, 1); - const auto &node3 = builder.AddNode("node3", "node3", 1, 0); - const auto &node4 = builder.AddNode("node4", "node4", 0, 1); - const auto &node5 = builder.AddNode("node5", "node5", 1, 0); - builder.AddDataEdge(node2, 0, node1, 0); - builder.AddDataEdge(node1, 0, node3, 0); - builder.AddControlEdge(node1, node4); - builder.AddControlEdge(node5, node1); - auto graph = builder.GetGraph(); - EXPECT_EQ(graph->IsolateNode(node1), GRAPH_SUCCESS); -} - -TEST_F(UtestComputeGraph, RemoveExtraOutEdge_success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node1 = builder.AddNode("node1", "node1", 1, 1); - const auto &node2 = builder.AddNode("node2", "node2", 0, 1); - const auto &node3 = builder.AddNode("node3", "node3", 1, 0); - builder.AddControlEdge(node1, node2); - builder.AddControlEdge(node3, node1); - auto graph = builder.GetGraph(); - EXPECT_EQ(graph->RemoveExtraOutEdge(node1), GRAPH_SUCCESS); -} - -TEST_F(UtestComputeGraph, InferOriginFormat_success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node1 = builder.AddNode("node1", "node1", 1, 0); - const auto &node2 = builder.AddNode("node2", "node2", 0, 1); - builder.AddDataEdge(node1, 0, node2, 0); - auto graph = builder.GetGraph(); - EXPECT_EQ(GraphUtilsEx::InferOriginFormat(graph), GRAPH_SUCCESS); -} - -TEST_F(UtestComputeGraph, InferShapeInNeed_success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node1 = builder.AddNode("node1", "node1", 1, 0); - const auto &node2 = builder.AddNode("node2", "node2", 0, 1); - builder.AddDataEdge(node1, 0, node2, 0); - auto graph = builder.GetGraph(); - EXPECT_EQ(GraphUtilsEx::InferShapeInNeed(graph), GRAPH_SUCCESS); -} - -TEST_F(UtestComputeGraph, SetSessionID_success) { - auto graph = std::make_shared("graph"); - auto session_id = graph->GetSessionID() + 1; - graph->SetSessionID(session_id); - EXPECT_EQ(graph->GetSessionID(), session_id); -} - -TEST_F(UtestComputeGraph, SetGraphID_success) { - auto graph = std::make_shared("graph"); - auto graph_id = graph->GetGraphID() + 1; - graph->SetGraphID(graph_id); - EXPECT_EQ(graph->GetGraphID(), graph_id); - auto empty_graph = std::make_shared(nullptr); - EXPECT_NE(empty_graph, nullptr); -} - -TEST_F(UtestComputeGraph, SetSummaryGraph_success) { - auto graph = std::make_shared("graph"); - auto summary_flag = !graph->IsSummaryGraph(); - graph->SetSummaryFlag(summary_flag); - EXPECT_EQ(graph->IsSummaryGraph(), summary_flag); - -} -namespace { -void BuildComplexGraph(ut::GraphBuilder &builder) { - const auto &node1 = builder.AddNode("node1", "node1", 0, 3, FORMAT_NCHW, DT_FLOAT, {1, 1}); - const auto &node2 = builder.AddNode("node2", "node2", 0, 3, FORMAT_NCHW, DT_FLOAT, {1, 1}); - const auto &node3 = builder.AddNode("node3", "node3", 0, 3, FORMAT_NCHW, DT_FLOAT, {1, 1}); - const auto &node4 = builder.AddNode("node4", "node4", 0, 3, FORMAT_NCHW, DT_FLOAT, {1, 1}); - const auto &node5 = builder.AddNode("node5", "node5", 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 2}); - const auto &node6 = builder.AddNode("node6", "node6", 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 3}); - const auto &node7 = builder.AddNode("node7", "node7", 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 4}); - const auto &node8 = builder.AddNode("node8", "node8", 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 5}); - const auto &node9 = builder.AddNode("node9", "node9", 1, 1, FORMAT_NCHW, DT_FLOAT, {2, 2}); - const auto &node10 = builder.AddNode("node10", "node10", 1, 1, FORMAT_NCHW, DT_FLOAT, {2, 3}); - const auto &node11 = builder.AddNode("node11", "node11", 1, 1, FORMAT_NCHW, DT_FLOAT, {2, 3}); - const auto &node12 = builder.AddNode("node12", "node12", 1, 1, FORMAT_NCHW, DT_FLOAT, {2, 4}); - const auto &node13 = builder.AddNode("node13", "node13", 1, 1, FORMAT_NCHW, DT_FLOAT, {3, 2}); - const auto &node14 = builder.AddNode("node14", "node14", 1, 1, FORMAT_NCHW, DT_FLOAT, {3, 3}); - const auto &node15 = builder.AddNode("node15", "node15", 1, 1, FORMAT_NCHW, DT_FLOAT, {3, 4}); - const auto &node16 = builder.AddNode("node16", "node16", 1, 1, FORMAT_NCHW, DT_FLOAT, {3, 5}); - const auto &node17 = builder.AddNode("node17", "node17", 4, 1, FORMAT_NCHW, DT_FLOAT, {4, 2}); - const auto &node18 = builder.AddNode("node18", "node18", 4, 1, FORMAT_NCHW, DT_FLOAT, {4, 2}); - const auto &node19 = builder.AddNode("node19", "node19", 4, 1, FORMAT_NCHW, DT_FLOAT, {4, 2}); - const auto &node20 = builder.AddNode("node20", "node20", 3, 0); - - builder.AddDataEdge(node1, 0, node5, 0); - builder.AddDataEdge(node1, 1, node9, 0); - builder.AddDataEdge(node1, 2, node13, 0); - builder.AddDataEdge(node2, 0, node6, 0); - builder.AddDataEdge(node2, 1, node10, 0); - builder.AddDataEdge(node2, 2, node14, 0); - builder.AddDataEdge(node3, 0, node7, 0); - builder.AddDataEdge(node3, 1, node11, 0); - builder.AddDataEdge(node3, 2, node15, 0); - builder.AddDataEdge(node4, 0, node8, 0); - builder.AddDataEdge(node4, 1, node12, 0); - builder.AddDataEdge(node4, 2, node16, 0); - builder.AddDataEdge(node5, 0, node17, 0); - builder.AddDataEdge(node6, 0, node17, 1); - builder.AddDataEdge(node7, 0, node17, 2); - builder.AddDataEdge(node8, 0, node17, 3); - builder.AddDataEdge(node9, 0, node18, 0); - builder.AddDataEdge(node10, 0, node18, 1); - builder.AddDataEdge(node11, 0, node18, 2); - builder.AddDataEdge(node12, 0, node18, 3); - builder.AddDataEdge(node13, 0, node19, 0); - builder.AddDataEdge(node14, 0, node19, 1); - builder.AddDataEdge(node15, 0, node19, 2); - builder.AddDataEdge(node16, 0, node19, 3); - - builder.AddControlEdge(node17, node20); - builder.AddControlEdge(node18, node20); - builder.AddControlEdge(node19, node20); -} -void VerifyTopoSortingResult(const ComputeGraphPtr &graph) { - const std::map expected_ids = { - {"node1", 0}, {"node2", 2}, {"node3", 4}, {"node4", 6}, - {"node5", 1}, {"node6", 3}, {"node7", 5}, {"node8", 7}, - {"node9", 9}, {"node10", 10}, {"node11", 11}, {"node12", 12}, - {"node13", 14}, {"node14", 15}, {"node15", 16}, {"node16", 17}, - {"node17", 8}, {"node18", 13}, {"node19", 18}, {"node20", 19} - }; - - for (const auto &pair : expected_ids) { - const auto node = graph->FindNode(pair.first); - ASSERT_NE(node, nullptr) << "Missing node: " << pair.first; - EXPECT_EQ(node->GetOpDesc()->GetId(), pair.second) - << "ID mismatch for node: " << pair.first; - } -} - -void TestTopoSortingWithOptions(TopoSortingMode mode) { - auto builder = ut::GraphBuilder("graph_reverse_dfs"); - BuildComplexGraph(builder); - - std::map options; - options["ge.topoSortingMode"] = std::to_string(static_cast(mode)); - GetThreadLocalContext().SetGraphOption(options); - - auto graph = builder.GetGraph(); - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - VerifyTopoSortingResult(graph); -} - -void TestTopoSortingWithDirectArg(TopoSortingMode mode) { - auto builder = ut::GraphBuilder("graph_reverse_dfs"); - BuildComplexGraph(builder); - - auto graph = builder.GetGraph(); - EXPECT_EQ(graph->TopologicalSorting(mode), GRAPH_SUCCESS); - VerifyTopoSortingResult(graph); -} -} -TEST_F(UtestComputeGraph, DFSPOSTORDERTopologicalSorting_success) { - TestTopoSortingWithOptions(TopoSortingMode::kRDFS); -} - -TEST_F(UtestComputeGraph, DFSPOSTORDERTopologicalSorting_by_arg_success) { - TestTopoSortingWithDirectArg(TopoSortingMode::kRDFS); -} - -TEST_F(UtestComputeGraph, DynamicShapeGraph_DFSPOSTORDERTopologicalSorting_success) { - auto builder = ut::GraphBuilder("graph_reverse_dfs"); - const auto &node1 = builder.AddNode("node1", "node1", 1, 1, FORMAT_NCHW, DT_FLOAT, {-1, 1}); - const auto &node2 = builder.AddNode("node2", "node2", 1, 1, FORMAT_NCHW, DT_FLOAT, {-1, 1}); - const auto &node3 = builder.AddNode("node3", "node3", 1, 1, FORMAT_NCHW, DT_FLOAT, {-1, 1}); - - builder.AddDataEdge(node1, 0, node2, 0); - builder.AddDataEdge(node2, 0, node3, 0); - - GetThreadLocalContext().SetGraphOption({}); - auto graph = builder.GetGraph(); - ASSERT_TRUE(AttrUtils::SetBool(graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, true)); - std::map options_map; - options_map["ge.topoSortingMode"] = "2"; - GetThreadLocalContext().SetGraphOption(options_map); - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - - EXPECT_EQ(node1->GetOpDesc()->GetId(), 0); - EXPECT_EQ(node2->GetOpDesc()->GetId(), 1); - EXPECT_EQ(node3->GetOpDesc()->GetId(), 2); -} - -TEST_F(UtestComputeGraph, DFSPOSTORDERTopologicalSorting_ringing_fail) { - auto builder = ut::GraphBuilder("graph"); - const auto &node1 = builder.AddNode("node1", "node1", 1, 1); - const auto &node2 = builder.AddNode("node2", "node2", 1, 1); - const auto &node3 = builder.AddNode("node3", "node3", 1, 1); - - builder.AddDataEdge(node1, 0, node2, 0); - builder.AddDataEdge(node2, 0, node3, 0); - builder.AddDataEdge(node3, 0, node1, 0); - - auto graph = builder.GetGraph(); - std::map options_map; - options_map["ge.topoSortingMode"] = "2"; - options_map["ge.exec.memoryOptimizationPolicy"] = "MemoryPriority"; - GetThreadLocalContext().SetGraphOption(options_map); - EXPECT_NE(graph->TopologicalSorting(), GRAPH_SUCCESS); -} - -TEST_F(UtestComputeGraph, DelayTopologicalSorting) { - auto graph = BuildDelayTopoGraph("test_delay_topo_graph"); - std::map options_map; - options_map["ge.topoSortingMode"] = "2"; - options_map["ge.exec.memoryOptimizationPolicy"] = "MemoryPriority"; - GetThreadLocalContext().SetGraphOption(options_map); - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - std::vector expected_dfs_names = {"variable", "data", "node2", "node3", "node4", "node1", "node5"}; - std::vector dfs_names; - std::vector dfs_names1; - const auto &graph_dfs_topo = graph->GetAllNodes(); - for (auto &node : graph_dfs_topo) { - dfs_names.push_back(node->GetName()); - } - const auto &graph_dfs_topo1 = graph->GetAllNodesPtr(); - for (auto &node : graph_dfs_topo1) { - dfs_names1.push_back(node->GetName()); - } - - EXPECT_EQ(dfs_names, expected_dfs_names); - EXPECT_EQ(dfs_names1, expected_dfs_names); -} - -// 校验稳定的RDFS的顺序和原始的DFS顺序一样 -TEST_F(UtestComputeGraph, StableRDFSTopologicalSorting1) { - // set up - std::map options_map; - options_map.clear(); - GetThreadLocalContext().SetGraphOption(options_map); - // 构图的原始topo为默认的DFS - auto graph = BuildDelayTopoGraph("test_stable_rdfs_graph"); - auto origin_dfs_nodes = graph->GetDirectNode(); - options_map["ge.topoSortingMode"] = "3"; - GetThreadLocalContext().SetGraphOption(options_map); - // 调用Stable RDFS - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - std::vector expected_rdfs_names = {"variable", "node2", "node1", "data", "node3", "node4", "node5"}; - std::vector rdfs_names; - const auto &after_rdfs_topo = graph->GetDirectNode(); - EXPECT_EQ(origin_dfs_nodes.size(), after_rdfs_topo.size()); - EXPECT_EQ(origin_dfs_nodes.size(), expected_rdfs_names.size()); - // 因为之前的topo顺序是对的,所以预期跟原来的topo一样 - for (size_t i = 0; i < after_rdfs_topo.size(); ++i) { - EXPECT_EQ(origin_dfs_nodes.at(i)->GetName(), after_rdfs_topo.at(i)->GetName()); - EXPECT_EQ(after_rdfs_topo.at(i)->GetName(), expected_rdfs_names.at(i)); - } - // tear down - options_map.clear(); - GetThreadLocalContext().SetGraphOption(options_map); -} - -// 校验稳定的RDFS的顺序和原始的BFS顺序一样 -TEST_F(UtestComputeGraph, StableRDFSTopologicalSorting2) { - // set up - std::map options_map; - options_map.clear(); - GetThreadLocalContext().SetGraphOption(options_map); - // 构图的原始topo为BFS - options_map["ge.topoSortingMode"] = "0"; - GetThreadLocalContext().SetGraphOption(options_map); - - auto graph = BuildDelayTopoGraph("test_stable_rdfs_graph"); - auto origin_dfs_nodes = graph->GetDirectNode(); - options_map["ge.topoSortingMode"] = "3"; - GetThreadLocalContext().SetGraphOption(options_map); - // 调用Stable RDFS - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - std::vector expected_rdfs_names = {"variable", "node1", "node2", "data", "node3", "node4", "node5"}; - std::vector rdfs_names; - const auto &after_rdfs_topo = graph->GetDirectNode(); - EXPECT_EQ(origin_dfs_nodes.size(), after_rdfs_topo.size()); - EXPECT_EQ(origin_dfs_nodes.size(), expected_rdfs_names.size()); - // 因为之前的topo顺序是对的,所以预期跟原来的topo一样 - for (size_t i = 0; i < after_rdfs_topo.size(); ++i) { - EXPECT_EQ(origin_dfs_nodes.at(i)->GetName(), after_rdfs_topo.at(i)->GetName()); - EXPECT_EQ(after_rdfs_topo.at(i)->GetName(), expected_rdfs_names.at(i)); - } - // tear down - options_map.clear(); - GetThreadLocalContext().SetGraphOption(options_map); -} - -// 校验一种之前会报错的场景 -TEST_F(UtestComputeGraph, StableRDFSTopologicalSorting3) { - // set up - std::map options_map; - options_map.clear(); - GetThreadLocalContext().SetGraphOption(options_map); - // 构图的原始topo为BFS - options_map["ge.topoSortingMode"] = "0"; - GetThreadLocalContext().SetGraphOption(options_map); - - auto graph = BuildDelayTopoGraphSkipInput("test_stable_rdfs_graph"); - auto origin_dfs_nodes = graph->GetDirectNode(); - options_map["ge.topoSortingMode"] = "3"; - GetThreadLocalContext().SetGraphOption(options_map); - // 调用Stable RDFS - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - std::vector expected_rdfs_names = {"variable", "node1", "data", "node3", "node4", "node5"}; - std::vector rdfs_names; - const auto &after_rdfs_topo = graph->GetDirectNode(); - EXPECT_EQ(origin_dfs_nodes.size(), after_rdfs_topo.size()); - EXPECT_EQ(origin_dfs_nodes.size(), expected_rdfs_names.size()); - // 因为之前的topo顺序是对的,所以预期跟原来的topo一样 - for (size_t i = 0; i < after_rdfs_topo.size(); ++i) { - EXPECT_EQ(origin_dfs_nodes.at(i)->GetName(), after_rdfs_topo.at(i)->GetName()); - EXPECT_EQ(after_rdfs_topo.at(i)->GetName(), expected_rdfs_names.at(i)); - } - // tear down - options_map.clear(); - GetThreadLocalContext().SetGraphOption(options_map); -} - -// Topo序调整后最大程度跟之前保持一致, 体现在noop插入节点之前的顺序保持跟之前的一致 -TEST_F(UtestComputeGraph, StableRDFSTopologicalSorting1_1) { - // set up - std::map options_map; - options_map.clear(); - GetThreadLocalContext().SetGraphOption(options_map); - // 构图的原始topo为默认的DFS - auto graph = BuildDelayTopoGraph("test_stable_rdfs_graph"); - OpDescPtr insert_op1 = ge::ComGraphMakeShared("noop1", "noop"); - OpDescPtr insert_op2 = ge::ComGraphMakeShared("noop2", "noop"); - // 改了图,topo需要重排 - auto insert_node1 = graph->AddNode(insert_op1); - auto insert_node2 = graph->AddNode(insert_op2); - GraphUtils::AddEdge(graph->FindNode("data")->GetOutControlAnchor(), insert_node1->GetInControlAnchor()); - GraphUtils::AddEdge(insert_node1->GetOutControlAnchor(), graph->FindNode("node4")->GetInControlAnchor()); - GraphUtils::AddEdge(graph->FindNode("data")->GetOutControlAnchor(), insert_node2->GetInControlAnchor()); - GraphUtils::AddEdge(insert_node2->GetOutControlAnchor(), graph->FindNode("node4")->GetInControlAnchor()); - std::vector - wrong_dfs_names = {"variable", "node2", "node1", "data", "node3", "node4", "node5", "noop1", "noop2"}; - - auto origin_dfs_nodes = graph->GetDirectNode(); - for (size_t i = 0; i < origin_dfs_nodes.size(); ++i) { - EXPECT_EQ(origin_dfs_nodes.at(i)->GetName(), wrong_dfs_names.at(i)); - } - options_map["ge.topoSortingMode"] = "3"; - GetThreadLocalContext().SetGraphOption(options_map); - // 调用Stable RDFS - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - // topo调整后最大程度跟之前保持一致 - std::vector - expected_rdfs_names = {"variable", "node2", "node1", "data", "node3", "noop1", "noop2", "node4", "node5"}; - std::vector rdfs_names; - const auto &after_rdfs_topo = graph->GetDirectNode(); - EXPECT_EQ(after_rdfs_topo.size(), expected_rdfs_names.size()); - for (size_t i = 0; i < after_rdfs_topo.size(); ++i) { - EXPECT_EQ(after_rdfs_topo.at(i)->GetName(), expected_rdfs_names.at(i)); - } - // tear down - options_map.clear(); - GetThreadLocalContext().SetGraphOption(options_map); -} - - -// StableRDFSTopologicalSorting1_1用例的对照组 -TEST_F(UtestComputeGraph, StableRDFSTopologicalSorting1_2) { - // set up - std::map options_map; - options_map.clear(); - GetThreadLocalContext().SetGraphOption(options_map); - // 构图的原始topo为默认的DFS - auto graph = BuildDelayTopoGraph("test_stable_rdfs_graph"); - OpDescPtr insert_op1 = ge::ComGraphMakeShared("noop1", "noop"); - OpDescPtr insert_op2 = ge::ComGraphMakeShared("noop2", "noop"); - // 改了图,topo需要重排 - auto insert_node1 = graph->AddNode(insert_op1); - auto insert_node2 = graph->AddNode(insert_op2); - GraphUtils::AddEdge(graph->FindNode("data")->GetOutControlAnchor(), insert_node1->GetInControlAnchor()); - GraphUtils::AddEdge(insert_node1->GetOutControlAnchor(), graph->FindNode("node4")->GetInControlAnchor()); - GraphUtils::AddEdge(graph->FindNode("data")->GetOutControlAnchor(), insert_node2->GetInControlAnchor()); - GraphUtils::AddEdge(insert_node2->GetOutControlAnchor(), graph->FindNode("node4")->GetInControlAnchor()); - std::vector - wrong_dfs_names = {"variable", "node2", "node1", "data", "node3", "node4", "node5", "noop1", "noop2"}; - - auto origin_dfs_nodes = graph->GetDirectNode(); - for (size_t i = 0; i < origin_dfs_nodes.size(); ++i) { - EXPECT_EQ(origin_dfs_nodes.at(i)->GetName(), wrong_dfs_names.at(i)); - } - options_map["ge.topoSortingMode"] = "1"; - GetThreadLocalContext().SetGraphOption(options_map); - // 调用DFS - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - // noop插入之后,造成了不依赖noop的node3的顺序排在了跟noop之后了 - std::vector - expected_rdfs_names = {"variable", "node2", "node1", "data", "noop2", "noop1", "node3", "node4", "node5"}; - std::vector rdfs_names; - const auto &after_rdfs_topo = graph->GetDirectNode(); - EXPECT_EQ(after_rdfs_topo.size(), expected_rdfs_names.size()); - for (size_t i = 0; i < after_rdfs_topo.size(); ++i) { - EXPECT_EQ(after_rdfs_topo.at(i)->GetName(), expected_rdfs_names.at(i)); - } - // tear down - options_map.clear(); - GetThreadLocalContext().SetGraphOption(options_map); -} - -// 融合场景 -TEST_F(UtestComputeGraph, StableRDFSTopologicalSorting2_1) { - // set up - std::map options_map; - options_map.clear(); - GetThreadLocalContext().SetGraphOption(options_map); - // 构图的原始topo为默认的DFS - auto graph = BuildDelayTopoGraphMultiInput("test_stable_rdfs_graph"); - std::vector - origin_dfs_names = {"const", "constant", "variable", "node2", "node1", "data", "node3", "node4", "node5"}; - auto origin_dfs_nodes = graph->GetDirectNode(); - for (size_t i = 0; i < origin_dfs_nodes.size(); ++i) { - EXPECT_EQ(origin_dfs_nodes.at(i)->GetName(), origin_dfs_names.at(i)); - } - // 改图 - // node2-node3融合 - OpDescPtr fusion_2_3 = ge::ComGraphMakeShared("fusion_2_3", "fusion"); - EXPECT_FALSE(fusion_2_3 == nullptr); - fusion_2_3->AddInputDesc(GeTensorDesc()); - fusion_2_3->AddInputDesc(GeTensorDesc()); - fusion_2_3->AddOutputDesc(GeTensorDesc()); - fusion_2_3->AddOutputDesc(GeTensorDesc()); - auto fusion_2_3_node = graph->AddNode(fusion_2_3); - GraphUtils::ReplaceNodeDataAnchors(fusion_2_3_node, graph->FindNode("node2"), {0}, {0}); - GraphUtils::ReplaceNodeDataAnchors(fusion_2_3_node, graph->FindNode("node3"), {-1, 0}, {-1, 0}); - graph->RemoveNode(graph->FindNode("node2")); - graph->RemoveNode(graph->FindNode("node3")); - - options_map["ge.topoSortingMode"] = "3"; - GetThreadLocalContext().SetGraphOption(options_map); - // 调用StableRDFS,体现在"const", "constant", "variable"顺序稳定 - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - std::vector - expected_rdfs_names = {"const", "constant", "variable", "node1", "data", "fusion_2_3", "node4", "node5"}; - std::vector rdfs_names; - const auto &after_rdfs_topo = graph->GetDirectNode(); - EXPECT_EQ(after_rdfs_topo.size(), expected_rdfs_names.size()); - for (size_t i = 0; i < after_rdfs_topo.size(); ++i) { - EXPECT_EQ(after_rdfs_topo.at(i)->GetName(), expected_rdfs_names.at(i)); - } - // tear down - options_map.clear(); - GetThreadLocalContext().SetGraphOption(options_map); -} - -// 融合场景StableRDFSTopologicalSorting2_1对照组 -TEST_F(UtestComputeGraph, StableRDFSTopologicalSorting2_2) { - // set up - std::map options_map; - options_map.clear(); - GetThreadLocalContext().SetGraphOption(options_map); - // 构图的原始topo为默认的DFS - auto graph = BuildDelayTopoGraphMultiInput("test_stable_rdfs_graph"); - std::vector - origin_dfs_names = {"const", "constant", "variable", "node2", "node1", "data", "node3", "node4", "node5"}; - auto origin_dfs_nodes = graph->GetDirectNode(); - for (size_t i = 0; i < origin_dfs_nodes.size(); ++i) { - EXPECT_EQ(origin_dfs_nodes.at(i)->GetName(), origin_dfs_names.at(i)); - } - // 改图 - // node2-node3融合 - OpDescPtr fusion_2_3 = ge::ComGraphMakeShared("fusion_2_3", "fusion"); - EXPECT_FALSE(fusion_2_3 == nullptr); - fusion_2_3->AddInputDesc(GeTensorDesc()); - fusion_2_3->AddInputDesc(GeTensorDesc()); - fusion_2_3->AddOutputDesc(GeTensorDesc()); - fusion_2_3->AddOutputDesc(GeTensorDesc()); - auto fusion_2_3_node = graph->AddNode(fusion_2_3); - GraphUtils::ReplaceNodeDataAnchors(fusion_2_3_node, graph->FindNode("node2"), {0}, {0}); - GraphUtils::ReplaceNodeDataAnchors(fusion_2_3_node, graph->FindNode("node3"), {-1, 0}, {-1, 0}); - graph->RemoveNode(graph->FindNode("node2")); - graph->RemoveNode(graph->FindNode("node3")); - - options_map["ge.topoSortingMode"] = "2"; - GetThreadLocalContext().SetGraphOption(options_map); - // 调用RDFS - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - std::vector - expected_rdfs_names = {"data", "variable", "fusion_2_3", "const", "constant", "node4", "node1", "node5"}; - std::vector rdfs_names; - const auto &after_rdfs_topo = graph->GetDirectNode(); - EXPECT_EQ(after_rdfs_topo.size(), expected_rdfs_names.size()); - for (size_t i = 0; i < after_rdfs_topo.size(); ++i) { - EXPECT_EQ(after_rdfs_topo.at(i)->GetName(), expected_rdfs_names.at(i)); - } - // tear down - options_map.clear(); - GetThreadLocalContext().SetGraphOption(options_map); -} - -// 删除场景 -// 删除constant节点,node1的对应输入变成node4的输出 -/* - * constant const variable data - * \ | / \ | - * node1 node2 node3 - * | | | - * | | node4 - * \ | / - * node5 - */ - - -/* - * data - * \ - * node3 - * \ - * node4 const variable - * \ | / \ - * node1 node2 - * | | - * | | - * \ | - * node5 - */ -TEST_F(UtestComputeGraph, StableRDFSTopologicalSorting2_3) { - // set up - std::map options_map; - options_map.clear(); - GetThreadLocalContext().SetGraphOption(options_map); - // 构图的原始topo为默认的DFS - auto graph = BuildDelayTopoGraphMultiInput("test_stable_rdfs_graph"); - std::vector - origin_dfs_names = {"const", "constant", "variable", "node2", "node1", "data", "node3", "node4", "node5"}; - auto origin_dfs_nodes = graph->GetDirectNode(); - for (size_t i = 0; i < origin_dfs_nodes.size(); ++i) { - EXPECT_EQ(origin_dfs_nodes.at(i)->GetName(), origin_dfs_names.at(i)); - } - // 改图 - GraphUtils::RemoveEdge(graph->FindNode("node4")->GetOutDataAnchor(0), graph->FindNode("node5")->GetInDataAnchor(2)); - GraphUtils::RemoveEdge(graph->FindNode("constant")->GetOutDataAnchor(0), - graph->FindNode("node1")->GetInDataAnchor(1)); - GraphUtils::AddEdge(graph->FindNode("node4")->GetOutDataAnchor(0), graph->FindNode("node1")->GetInDataAnchor(1)); - graph->RemoveNode(graph->FindNode("constant")); - - options_map["ge.topoSortingMode"] = "3"; - GetThreadLocalContext().SetGraphOption(options_map); - // 调用StableRDFS, 体现在"const", "variable", "node2"顺序稳定 - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - std::vector - expected_rdfs_names = {"const", "variable", "node2", "data", "node3", "node4", "node1", "node5"}; - std::vector rdfs_names; - const auto &after_rdfs_topo = graph->GetDirectNode(); - EXPECT_EQ(after_rdfs_topo.size(), expected_rdfs_names.size()); - for (size_t i = 0; i < after_rdfs_topo.size(); ++i) { - EXPECT_EQ(after_rdfs_topo.at(i)->GetName(), expected_rdfs_names.at(i)); - } - // tear down - options_map.clear(); - GetThreadLocalContext().SetGraphOption(options_map); -} -// StableRDFSTopologicalSorting2_3对照组 -TEST_F(UtestComputeGraph, StableRDFSTopologicalSorting2_4) { - // set up - std::map options_map; - options_map.clear(); - GetThreadLocalContext().SetGraphOption(options_map); - // 构图的原始topo为默认的DFS - auto graph = BuildDelayTopoGraphMultiInput("test_stable_rdfs_graph"); - std::vector - origin_dfs_names = {"const", "constant", "variable", "node2", "node1", "data", "node3", "node4", "node5"}; - auto origin_dfs_nodes = graph->GetDirectNode(); - for (size_t i = 0; i < origin_dfs_nodes.size(); ++i) { - EXPECT_EQ(origin_dfs_nodes.at(i)->GetName(), origin_dfs_names.at(i)); - } - // 改图 - GraphUtils::RemoveEdge(graph->FindNode("node4")->GetOutDataAnchor(0), graph->FindNode("node5")->GetInDataAnchor(2)); - GraphUtils::RemoveEdge(graph->FindNode("constant")->GetOutDataAnchor(0), - graph->FindNode("node1")->GetInDataAnchor(1)); - GraphUtils::AddEdge(graph->FindNode("node4")->GetOutDataAnchor(0), graph->FindNode("node1")->GetInDataAnchor(1)); - graph->RemoveNode(graph->FindNode("constant")); - - options_map["ge.topoSortingMode"] = "2"; - GetThreadLocalContext().SetGraphOption(options_map); - // 调用RDFS - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - std::vector - expected_rdfs_names = {"const", "data", "variable", "node2", "node3", "node4", "node1", "node5"}; - std::vector rdfs_names; - const auto &after_rdfs_topo = graph->GetDirectNode(); - EXPECT_EQ(after_rdfs_topo.size(), expected_rdfs_names.size()); - for (size_t i = 0; i < after_rdfs_topo.size(); ++i) { - EXPECT_EQ(after_rdfs_topo.at(i)->GetName(), expected_rdfs_names.at(i)); - } - // tear down - options_map.clear(); - GetThreadLocalContext().SetGraphOption(options_map); -} - -TEST_F(UtestComputeGraph, NoDelayTopologicalSorting) { - auto graph = BuildDelayTopoGraph("test_delay_topo_graph"); - std::map options_map; - options_map["ge.topoSortingMode"] = "1"; - GetThreadLocalContext().SetGraphOption(options_map); - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - std::vector expected_dfs_names = {"variable", "node2", "node1", "data", "node3", "node4", "node5"}; - std::vector dfs_names; - const auto &graph_dfs_topo = graph->GetAllNodes(); - for (auto &node : graph_dfs_topo) { - dfs_names.push_back(node->GetName()); - } - std::vector dfs_names1; - const auto &graph_dfs_topo1 = graph->GetAllNodesPtr(); - for (auto &node : graph_dfs_topo1) { - dfs_names1.push_back(node->GetName()); - } - - EXPECT_EQ(dfs_names, expected_dfs_names); - EXPECT_EQ(dfs_names1, expected_dfs_names); -} - -TEST_F(UtestComputeGraph, DelayTopologicalSortingMultiInput) { - auto graph = BuildDelayTopoGraphMultiInput("test_delay_topo_graph"); - std::map options_map; - options_map["ge.topoSortingMode"] = "2"; - options_map["ge.exec.memoryOptimizationPolicy"] = "MemoryPriority"; - GetThreadLocalContext().SetGraphOption(options_map); - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - std::vector expected_dfs_names = - {"const", "constant", "variable", "data", "node2", "node3", "node4", "node1", "node5"}; - std::vector dfs_names; - const auto &graph_dfs_topo = graph->GetAllNodes(); - for (auto &node : graph_dfs_topo) { - dfs_names.push_back(node->GetName()); - } - std::vector dfs_names1; - const auto &graph_dfs_topo1 = graph->GetAllNodesPtr(); - for (auto &node : graph_dfs_topo1) { - dfs_names1.push_back(node->GetName()); - } - - EXPECT_EQ(dfs_names, expected_dfs_names); - EXPECT_EQ(dfs_names1, expected_dfs_names); -} - -TEST_F(UtestComputeGraph, NoDelayTopologicalSortingMultiInput) { - auto graph = BuildDelayTopoGraphMultiInput("test_delay_topo_graph", false); - std::map options_map; - options_map["ge.topoSortingMode"] = "2"; - options_map["ge.exec.memoryOptimizationPolicy"] = "MemoryPriority"; - GetThreadLocalContext().SetGraphOption(options_map); - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - std::vector expected_dfs_names = - {"const", "constant", "variable", "node1", "data", "node2", "node3", "node4", "node5"}; - std::vector dfs_names; - const auto &graph_dfs_topo = graph->GetAllNodes(); - for (auto &node : graph_dfs_topo) { - dfs_names.push_back(node->GetName()); - } - std::vector dfs_names1; - const auto &graph_dfs_topo1 = graph->GetAllNodesPtr(); - for (auto &node : graph_dfs_topo1) { - dfs_names1.push_back(node->GetName()); - } - - EXPECT_EQ(dfs_names, expected_dfs_names); - EXPECT_EQ(dfs_names1, expected_dfs_names); -} - -TEST_F(UtestComputeGraph, ReorderByNodeId) { - auto graph = BuildDelayTopoGraphMultiInput("test_delay_topo_graph"); - const auto &constant = graph->FindNode("const"); - const auto &constantop = graph->FindNode("constant"); - const auto &variable = graph->FindNode("variable"); - const auto &node1 = graph->FindNode("node1"); - const auto &node2 = graph->FindNode("node2"); - const auto &node3 = graph->FindNode("node3"); - const auto &node4 = graph->FindNode("node4"); - const auto &node5 = graph->FindNode("node5"); - const auto &data = graph->FindNode("data"); - int64_t seq_id = 0L; - std::vector nodes{node5, node4, node3, node2, node1, variable, data, constantop, constant}; - for (auto &node : nodes) { - node->GetOpDesc()->SetId(seq_id++); - } - graph->ReorderByNodeId(); - auto sorted_nodes = graph->GetDirectNode(); - ASSERT_TRUE(sorted_nodes.size() == nodes.size()); - int32_t id = 0; - for (auto &node : nodes) { - EXPECT_EQ(node, sorted_nodes.at(id++)); - } - auto sorted_nodes_ptr = graph->GetDirectNodePtr(); - ASSERT_TRUE(sorted_nodes_ptr.size() == nodes.size()); - id = 0; - for (auto &node : nodes) { - EXPECT_EQ(node.get(), sorted_nodes_ptr.at(id++)); - } -} - -TEST_F(UtestComputeGraph, CreateShapeEnvAttrWithArgs) { - auto graph = BuildDelayTopoGraphMultiInput("test_delay_topo_graph"); - graph->CreateAttrsGroup(ShapeEnvSetting()); - ASSERT_NE(graph->GetAttrsGroup(), nullptr); -} -} diff --git a/tests/ut/graph/testcase/constant_utils_unittest.cc b/tests/ut/graph/testcase/constant_utils_unittest.cc deleted file mode 100644 index 299e8a50cdfd80b739d2e3c5c6ec800982aca07c..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/constant_utils_unittest.cc +++ /dev/null @@ -1,220 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/ge_tensor.h" -#include "graph/tensor.h" -#include "graph/utils/file_utils.h" -#include "graph/utils/constant_utils.h" -#include "graph_builder_utils.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/op_desc_utils.h" - -namespace ge { -class UtestConstantUtils : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -TEST_F(UtestConstantUtils, TestIsConstant) { - // check node is constant - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - const auto &const1 = builder.AddNode("const1", "Constant", 0, 1); - ASSERT_EQ(ConstantUtils::IsConstant(const1), true); - - // check operator is constant - const auto &const2 = builder.AddNode("const1", "Const", 0, 1); - ASSERT_EQ(ConstantUtils::IsConstant(const2), true); - - // check normal op is not constant - const auto &cast = builder.AddNode("cast", "Cast", 1, 1); - auto cast_op = OpDescUtils::CreateOperatorFromNode(cast); - ASSERT_EQ(ConstantUtils::IsConstant(cast), false); -} - -TEST_F(UtestConstantUtils, TestNodeIsPotentialConstant) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - // check potential const - const auto &shape_node = builder.AddNode("shape", "Shape", 1, 1); - AttrUtils::SetBool(shape_node->GetOpDesc(), ATTR_NAME_POTENTIAL_CONST, true); - // new a tensor - ge::GeTensorPtr tensor = std::make_shared(); - std::vector value{1, 2, 3}; - std::vector shape{3}; - tensor->MutableTensorDesc().SetShape(GeShape(shape)); - tensor->SetData(value); - tensor->MutableTensorDesc().SetDataType(DT_UINT8); - AttrUtils::SetListInt(shape_node->GetOpDesc(), ATTR_NAME_POTENTIAL_WEIGHT_INDICES, {0}); - AttrUtils::SetListTensor(shape_node->GetOpDesc(), ATTR_NAME_POTENTIAL_WEIGHT, {tensor}); - ASSERT_EQ(ConstantUtils::IsConstant(shape_node), true); - ASSERT_EQ(ConstantUtils::IsPotentialConst(shape_node->GetOpDesc()), true); - - // check const is not potential const - const auto &const1 = builder.AddNode("const1", "Constant", 0, 1); - ASSERT_EQ(ConstantUtils::IsPotentialConst(const1->GetOpDesc()), false); - - // check normal node is not potential const - const auto &cast = builder.AddNode("cast", "Cast", 1, 1); - ASSERT_EQ(ConstantUtils::IsPotentialConst(cast->GetOpDesc()), false); -} - -TEST_F(UtestConstantUtils, TestGetWeightFromOpDesc) { - // new a tensor - ge::GeTensorPtr tensor = std::make_shared(); - std::vector value{1, 2, 3}; - std::vector shape{3}; - tensor->MutableTensorDesc().SetShape(GeShape(shape)); - tensor->SetData(value); - tensor->MutableTensorDesc().SetDataType(DT_UINT8); - - // build two const - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - const auto &shape_node = builder.AddNode("shape_node", "Shape", 1, 1); - AttrUtils::SetBool(shape_node->GetOpDesc(), ATTR_NAME_POTENTIAL_CONST, true); - AttrUtils::SetListInt(shape_node->GetOpDesc(), ATTR_NAME_POTENTIAL_WEIGHT_INDICES, {0}); - AttrUtils::SetListTensor(shape_node->GetOpDesc(), ATTR_NAME_POTENTIAL_WEIGHT, {tensor}); - const auto &const_node = builder.AddNode("const_node", "Const", 0, 1); - OpDescUtils::SetWeights(const_node, {tensor}); - - // get weight from real const - ConstGeTensorPtr weight; - ASSERT_TRUE(ConstantUtils::GetWeight(const_node->GetOpDesc(), 0, weight)); - ASSERT_EQ(weight->GetTensorDesc().GetDataType(), DT_UINT8); - ASSERT_EQ(weight->GetTensorDesc().GetShape().GetDims(), shape); - - // get weight from potential const - ConstGeTensorPtr potential_weight; - ASSERT_TRUE(ConstantUtils::GetWeight(shape_node->GetOpDesc(), 0, potential_weight)); - ASSERT_EQ(potential_weight->GetTensorDesc().GetDataType(), DT_UINT8); - ASSERT_EQ(potential_weight->GetTensorDesc().GetShape().GetDims(), shape); - - // check invalid potential const get weight - // build potential op - const auto &shape_node_2 = builder.AddNode("shape_node_2", "Shape", 1, 1); - AttrUtils::SetBool(shape_node_2->GetOpDesc(), ATTR_NAME_POTENTIAL_CONST, true); - ASSERT_FALSE(ConstantUtils::GetWeight(shape_node_2->GetOpDesc(), 0, potential_weight)); - AttrUtils::SetListInt(shape_node_2->GetOpDesc(), ATTR_NAME_POTENTIAL_WEIGHT_INDICES, {0}); - ASSERT_FALSE(ConstantUtils::GetWeight(shape_node_2->GetOpDesc(), 0, potential_weight)); - AttrUtils::SetListInt(shape_node_2->GetOpDesc(), ATTR_NAME_POTENTIAL_WEIGHT_INDICES, {0, 1}); - AttrUtils::SetListTensor(shape_node_2->GetOpDesc(), ATTR_NAME_POTENTIAL_WEIGHT, {tensor}); - ASSERT_FALSE(ConstantUtils::GetWeight(shape_node_2->GetOpDesc(), 0, potential_weight)); -} - -TEST_F(UtestConstantUtils, TestGetWeightFromOperator) { - // new a tensor - std::vector value{1, 2, 3}; - std::vector shape{3}; - TensorDesc tensor_desc(Shape(shape), FORMAT_ND, DT_UINT8); - Tensor tensor(tensor_desc, value); - - // build potential op - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - const auto &shape_node = builder.AddNode("shape_node", "Shape", 1, 1); - auto shape_op = OpDescUtils::CreateOperatorFromNode(shape_node); - shape_op.SetAttr(ATTR_NAME_POTENTIAL_CONST, true); - shape_op.SetAttr(ATTR_NAME_POTENTIAL_WEIGHT_INDICES, {0}); - vector weights = {tensor}; - shape_op.SetAttr(ATTR_NAME_POTENTIAL_WEIGHT, weights); - - // get weight from potential const - ASSERT_TRUE(ConstantUtils::IsConstant(shape_node)); -} - -TEST_F(UtestConstantUtils, TestSetWeight) { - // new a tensor - ge::GeTensorPtr tensor = std::make_shared(); - std::vector value{1, 2, 3}; - std::vector shape{3}; - tensor->MutableTensorDesc().SetShape(GeShape(shape)); - tensor->SetData(value); - tensor->MutableTensorDesc().SetDataType(DT_UINT8); - - // build two const - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - const auto &shape_node = builder.AddNode("shape_node", "Shape", 1, 1); - AttrUtils::SetBool(shape_node->GetOpDesc(), ATTR_NAME_POTENTIAL_CONST, true); - AttrUtils::SetListInt(shape_node->GetOpDesc(), ATTR_NAME_POTENTIAL_WEIGHT_INDICES, {0}); - AttrUtils::SetListTensor(shape_node->GetOpDesc(), ATTR_NAME_POTENTIAL_WEIGHT, {tensor}); - const auto &const_node = builder.AddNode("const_node", "Constant", 0, 1); - - // set weight to real const - auto ret = ConstantUtils::SetWeight(const_node->GetOpDesc(), 0, tensor); - ASSERT_EQ(ret, true); - - // set weight to potential const - ConstGeTensorPtr potential_weight; - ret = ConstantUtils::SetWeight(shape_node->GetOpDesc(), 0, tensor); - ASSERT_EQ(ret, true); - // check weight index is out of range - ret = ConstantUtils::SetWeight(shape_node->GetOpDesc(), 1, tensor); - ASSERT_EQ(ret, false); -} - -TEST_F(UtestConstantUtils, TestMarkUnmarkPotentialConst) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - const auto &shape_node = builder.AddNode("shape_node", "Shape", 1, 1); - vector indices = {0}; - // new a tensor - ge::GeTensorPtr tensor = std::make_shared(); - std::vector value{1, 2, 3}; - std::vector shape{3}; - tensor->MutableTensorDesc().SetShape(GeShape(shape)); - tensor->SetData(value); - tensor->MutableTensorDesc().SetDataType(DT_UINT8); - vector weights = {tensor}; - // test normal case mark potential const - ASSERT_TRUE(ConstantUtils::MarkPotentialConst(shape_node->GetOpDesc(), indices, weights)); - bool is_potential_const = false; - ASSERT_TRUE(AttrUtils::GetBool(shape_node->GetOpDesc(), ATTR_NAME_POTENTIAL_CONST, is_potential_const)); - ASSERT_TRUE(is_potential_const); - // test normal case unmark potential const - ASSERT_TRUE(ConstantUtils::UnMarkPotentialConst(shape_node->GetOpDesc())); - ASSERT_FALSE(AttrUtils::GetBool(shape_node->GetOpDesc(), ATTR_NAME_POTENTIAL_CONST, is_potential_const)); - - // test mark potential const : indices not match weights - const auto &shape_node2 = builder.AddNode("shape_node", "Shape", 1, 1); - ASSERT_FALSE(ConstantUtils::MarkPotentialConst(shape_node->GetOpDesc(), {0,1}, weights)); -} - -TEST_F(UtestConstantUtils, TestGetWeightFromFile) { - ge::GeTensorPtr tensor = std::make_shared(); - std::vector value{1, 2, 3, 4, 5}; - std::vector shape{5}; - tensor->MutableTensorDesc().SetShape(GeShape(shape)); - tensor->SetData(value); - tensor->MutableTensorDesc().SetDataType(DT_UINT8); - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - const auto &const_node = builder.AddNode("const_node", "Const", 0, 1); - OpDescUtils::SetWeights(const_node, {tensor}); - - ConstGeTensorPtr weight; - ASSERT_FALSE(ConstantUtils::GetWeightFromFile(const_node->GetOpDesc(), weight)); - - const auto &data = reinterpret_cast(tensor->GetData().GetData()); - const auto size = tensor->GetData().GetSize(); - ASSERT_EQ(SaveBinToFile(data, size, "./weight.bin"), GRAPH_SUCCESS); - - const auto &fileconstant = builder.AddNode("fileconstant", "FileConstant", 0, 1); - AttrUtils::SetDataType(fileconstant->GetOpDesc(), "dtype", DT_UINT8); - fileconstant->GetOpDesc()->UpdateOutputDesc(0, tensor->GetTensorDesc()); - ASSERT_FALSE(ConstantUtils::GetWeightFromFile(fileconstant->GetOpDesc(), weight)); - AttrUtils::SetStr(fileconstant->GetOpDesc(), "location", "./weight.bin"); - // Invalid size - AttrUtils::SetInt(fileconstant->GetOpDesc(), "length", size + 1024); - ASSERT_FALSE(ConstantUtils::GetWeightFromFile(fileconstant->GetOpDesc(), weight)); - // Valid size - AttrUtils::SetInt(fileconstant->GetOpDesc(), "length", size); - ASSERT_TRUE(ConstantUtils::GetWeightFromFile(fileconstant->GetOpDesc(), weight)); - ASSERT_EQ(weight->GetTensorDesc().GetDataType(), DT_UINT8); - ASSERT_EQ(weight->GetTensorDesc().GetShape().GetDims(), shape); - system("rm -rf ./weight.bin"); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/cycle_detector_unittest.cc b/tests/ut/graph/testcase/cycle_detector_unittest.cc deleted file mode 100644 index 08d83bbc808f1be82e12080d051132139365b60f..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/cycle_detector_unittest.cc +++ /dev/null @@ -1,424 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include "graph/utils/cycle_detector.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/connection_matrix.h" -#include "graph/utils/connection_matrix_impl.h" -#include "graph_builder_utils.h" -using namespace std; -using namespace ge; -namespace { -const int kContainCycle = 0; -const int kNoCycleCase1 = 1; -const int kNoCycleCase2 = 2; -const int kNoCycleCase3 = 3; - -/* -* if we want to fusion cast1 and cast2 -* it will cause a cycle between fusion_cast and transdata -* data1 -* / \ -* / \ -* cast1 \ -* | \ -* trandata---> cast2 -*/ -void BuildGraphMayCauseCycleWhenFusion(ComputeGraphPtr &graph) { - auto root_builder = ut::GraphBuilder("root"); - const auto &data1 = root_builder.AddNode("data1", "Data", 1, 1); - const auto &cast1 = root_builder.AddNode("cast1", "Cast", 1, 1); - const auto &cast2 = root_builder.AddNode("cast2", "Cast", 1, 1); - const auto &transdata = root_builder.AddNode("transdata", "TransData", 1, 1); - - root_builder.AddDataEdge(data1, 0, cast1, 0); - root_builder.AddDataEdge(data1, 0, cast2, 0); - root_builder.AddDataEdge(cast1, 0, transdata, 0); - root_builder.AddControlEdge(transdata, cast2); - graph = root_builder.GetGraph(); -} - -} // namespace -class UtestCycleDetector : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; -/* -* if we want to fusion cast1 and cast2 -* it will cause a cycle between fusion_cast and transdata -* data1 data1 -* / \ | -* / \ ===> fusion_cast <------ -* cast1 \ | | -* | \ transdata----------- -* trandata---> cast2 wrong graph, which has cycle between -* transdata and fusion_cast. -*/ -TEST_F(UtestCycleDetector, TestCycleDetection_00) { - ComputeGraphPtr graph; - BuildGraphMayCauseCycleWhenFusion(graph); - - auto cast1 = graph->FindNode("cast1"); - auto cast2 = graph->FindNode("cast2"); - CycleDetectorPtr detector = GraphUtils::CreateCycleDetector(graph); - EXPECT_NE(detector, nullptr); - - bool has_cycle = detector->HasDetectedCycle({{cast1, cast2}}); - EXPECT_TRUE(has_cycle); -} - -/* A - * / \ - * B \ - * / \ - * D------->C - * | | - * After fusion A/B/C, the graph looks like: - * <--- - * / \ - * ABC--->D - */ -static ComputeGraphPtr BuildFusionGraph01(std::vector &fusion_nodes) { - ut::GraphBuilder builder = ut::GraphBuilder("fusion_graph"); - auto a = builder.AddNode("A", "A", 0, 1); - auto b = builder.AddNode("B", "B", 1, 1); - auto c = builder.AddNode("C", "C", 2, 1); - auto d = builder.AddNode("D", "D", 1, 1); - auto netoutput = builder.AddNode("NetOutput", "NetOutput", 2, 0); - - builder.AddDataEdge(a, 0, b, 0); - builder.AddDataEdge(b, 0, d, 0); - builder.AddDataEdge(d, 0, c, 1); - - builder.AddDataEdge(a, 0, c, 0); - builder.AddDataEdge(c, 0, netoutput, 0); - builder.AddDataEdge(d, 0, netoutput, 1); - auto graph = builder.GetGraph(); - fusion_nodes = {a, b, c}; - return graph; -} - -TEST_F(UtestCycleDetector, TestCycleDetection_01) { - std::vector fusion_nodes; - auto graph = BuildFusionGraph01(fusion_nodes); - - CycleDetectorPtr detector = GraphUtils::CreateCycleDetector(graph); - EXPECT_NE(detector, nullptr); - bool has_cycle = detector->HasDetectedCycle({fusion_nodes}); - EXPECT_TRUE(has_cycle); -} - -/* A - * / \ - * B \ - * / \ - * D C - * \ / - * Netoutput - * After fusion A/B/C, the graph looks like: - * - * ABC--->D - * \ / - * Netoutput - * No cycle will be generated if fusing. */ -static ComputeGraphPtr BuildFusionGraph02(std::vector &fusion_nodes) { - ut::GraphBuilder builder = ut::GraphBuilder("fusion_graph"); - auto a = builder.AddNode("A", "A", 0, 1); - auto b = builder.AddNode("B", "B", 1, 1); - auto c = builder.AddNode("C", "C", 1, 1); - auto d = builder.AddNode("D", "D", 1, 1); - auto netoutput = builder.AddNode("NetOutput", "NetOutput", 2, 0); - - builder.AddDataEdge(a, 0, b, 0); - builder.AddDataEdge(b, 0, d, 0); - - builder.AddDataEdge(a, 0, c, 0); - builder.AddDataEdge(c, 0, netoutput, 0); - builder.AddDataEdge(d, 0, netoutput, 1); - auto graph = builder.GetGraph(); - fusion_nodes = {a, b, c}; - return graph; -} - -/* -ori connection_matrix(5x5): -1 0 0 0 0 -1 1 0 0 0 -1 0 1 0 0 -1 0 1 1 0 -1 1 1 1 1 -After update(6x6): -1 1 1 0 0 1 -1 1 1 0 0 1 -1 1 1 0 0 1 -1 1 1 1 0 1 -1 1 1 1 1 1 -1 1 1 0 0 1 -*/ - -TEST_F(UtestCycleDetector, cycle_detection_02) { - std::vector fusion_nodes; - auto graph = BuildFusionGraph02(fusion_nodes); - - CycleDetectorSharedPtr detector = GraphUtils::CreateSharedCycleDetector(graph); - EXPECT_NE(detector, nullptr); - bool has_cycle = detector->HasDetectedCycle({fusion_nodes}); - EXPECT_FALSE(has_cycle); - std::string res_ori = "1000011000101001011011111"; - std::string res_update = "111001111001111001111101111111111001"; - std::stringstream val_ori; - for (size_t i = 0; i < 5; ++i) { - auto bit_map = detector->connectivity_->impl_->bit_maps_[i]; - for (size_t j = 0; j < 5; ++j) { - val_ori << bit_map.GetBit(j); - } - } - EXPECT_EQ(val_ori.str(), res_ori); - detector->ExpandAndUpdate(fusion_nodes, "ABC"); - std::stringstream val_update; - for (size_t i = 0; i < 6; ++i) { - auto bit_map = detector->connectivity_->impl_->bit_maps_[i]; - for (size_t j = 0; j < 6; ++j) { - val_update << bit_map.GetBit(j); - } - } - EXPECT_EQ(val_update.str(), res_update); -} - -/* A--->B---->C---->D - * \-----E-------/ - * - * A, B, C, D will be fused. - * Cycle will be generated if fusing. - */ -static ComputeGraphPtr BuildFusionGraph03(std::vector &fusion_nodes) { - ut::GraphBuilder builder = ut::GraphBuilder("fusion_graph"); - auto a = builder.AddNode("A", "A", 0, 1); - auto b = builder.AddNode("B", "B", 1, 1); - auto c = builder.AddNode("C", "C", 1, 1); - auto d = builder.AddNode("D", "D", 2, 1); - auto e = builder.AddNode("E", "E", 1, 1); - auto netoutput = builder.AddNode("NetOutput", "NetOutput", 1, 0); - - builder.AddDataEdge(a, 0, b, 0); - builder.AddDataEdge(b, 0, c, 0); - - builder.AddDataEdge(c, 0, d, 0); - builder.AddDataEdge(a, 0, e, 0); - builder.AddDataEdge(e, 0, d, 1); - builder.AddDataEdge(d, 0, netoutput, 0); - - auto graph = builder.GetGraph(); - fusion_nodes = {a, b, c, d}; - return graph; -} - -TEST_F(UtestCycleDetector, cycle_detection_03) { - std::vector fusion_nodes; - auto graph = BuildFusionGraph03(fusion_nodes); - - CycleDetectorPtr detector = GraphUtils::CreateCycleDetector(graph); - EXPECT_NE(detector, nullptr); - bool has_cycle = detector->HasDetectedCycle({fusion_nodes}); - EXPECT_TRUE(has_cycle); -} - -/* A--->B---->C------->D - * \-----E---F------/ - * - * A, B, C, D will be fused. - * Cycle will be generated if fusing. - */ -static ComputeGraphPtr BuildFusionGraph04(std::vector &fusion_nodes) { - ut::GraphBuilder builder = ut::GraphBuilder("fusion_graph"); - auto a = builder.AddNode("A", "A", 0, 1); - auto b = builder.AddNode("B", "B", 1, 1); - auto c = builder.AddNode("C", "C", 1, 1); - auto d = builder.AddNode("D", "D", 2, 1); - auto e = builder.AddNode("E", "E", 1, 1); - auto f = builder.AddNode("F", "F", 1, 1); - auto netoutput = builder.AddNode("NetOutput", "NetOutput", 1, 0); - - builder.AddDataEdge(a, 0, b, 0); - builder.AddDataEdge(b, 0, c, 0); - builder.AddDataEdge(c, 0, d, 0); - builder.AddDataEdge(a, 0, e, 0); - builder.AddDataEdge(e, 0, f, 0); - builder.AddDataEdge(f, 0, d, 1); - - builder.AddDataEdge(d, 0, netoutput, 0); - auto graph = builder.GetGraph(); - fusion_nodes = {a, b, c, d}; - return graph; -} - -TEST_F(UtestCycleDetector, cycle_detection_04) { - std::vector fusion_nodes; - auto graph = BuildFusionGraph04(fusion_nodes); - - CycleDetectorPtr detector = GraphUtils::CreateCycleDetector(graph); - EXPECT_NE(detector, nullptr); - bool has_cycle = detector->HasDetectedCycle({fusion_nodes}); - EXPECT_TRUE(has_cycle); -} - -/* A--->B---->C------->D - * \-----E---F------/ - * - * B/C will be fused. - * No Cycle will be generated if fusing. - */ -static ComputeGraphPtr BuildFusionGraph05(std::vector &fusion_nodes) { - ut::GraphBuilder builder = ut::GraphBuilder("fusion_graph"); - auto a = builder.AddNode("A", "A", 0, 1); - auto b = builder.AddNode("B", "B", 1, 1); - auto c = builder.AddNode("C", "C", 1, 1); - auto d = builder.AddNode("D", "D", 2, 1); - auto e = builder.AddNode("E", "E", 1, 1); - auto f = builder.AddNode("F", "F", 1, 1); - auto netoutput = builder.AddNode("NetOutput", "NetOutput", 1, 0); - - builder.AddDataEdge(a, 0, b, 0); - builder.AddDataEdge(b, 0, c, 0); - builder.AddDataEdge(c, 0, d, 0); - builder.AddDataEdge(a, 0, e, 0); - builder.AddDataEdge(e, 0, f, 0); - builder.AddDataEdge(f, 0, d, 0); - - builder.AddDataEdge(d, 0, netoutput, 0); - auto graph = builder.GetGraph(); - fusion_nodes = {b, c}; - return graph; -} - -TEST_F(UtestCycleDetector, cycle_detection_05) { - std::vector fusion_nodes; - auto graph = BuildFusionGraph05(fusion_nodes); - - CycleDetectorPtr detector = GraphUtils::CreateCycleDetector(graph); - EXPECT_NE(detector, nullptr); - bool has_cycle = detector->HasDetectedCycle({fusion_nodes}); - EXPECT_FALSE(has_cycle); -} - -/* - * /-----H----------------\ - * /------G---------\ \ - * / /------I------\ \ - * A--->B---->C------->D---NetOutput - * \------E---F------------/ - * - * B/C will be fused. - * No Cycle will be generated if fusing. - */ -ComputeGraphPtr CreateGraph06(int case_num, std::vector &fusion_nodes) { - ut::GraphBuilder builder = ut::GraphBuilder("fusion_graph"); - auto a = builder.AddNode("A", "A", 0, 4); - auto b = builder.AddNode("B", "B", 1, 1); - auto c = builder.AddNode("C", "C", 1, 1); - auto d = builder.AddNode("D", "D", 3, 1); - auto e = builder.AddNode("E", "E", 1, 1); - auto f = builder.AddNode("F", "F", 1, 1); - auto g = builder.AddNode("G", "G", 1, 1); - auto h = builder.AddNode("H", "H", 1, 1); - auto i = builder.AddNode("I", "I", 1, 1); - auto netoutput = builder.AddNode("NetOutput", "NetOutput", 3, 0); - - builder.AddControlEdge(a, b); - builder.AddDataEdge(a, 0, e, 0); - builder.AddDataEdge(a, 1, g, 0); - builder.AddDataEdge(a, 2, h, 0); - builder.AddDataEdge(h, 0, netoutput, 0); - - builder.AddDataEdge(b, 0, c, 0); - builder.AddDataEdge(b, 0, i, 0); - builder.AddDataEdge(i, 0, d, 0); - builder.AddDataEdge(c, 0, d, 1); - builder.AddDataEdge(d, 0, netoutput, 1); - - builder.AddDataEdge(g, 0, d, 2); - - builder.AddDataEdge(e, 0, f, 0); - builder.AddDataEdge(f, 0, netoutput, 3); - - auto graph = builder.GetGraph(); - if (case_num == kNoCycleCase1) { - fusion_nodes = {a, b, e, g, h}; - } else if (case_num == kContainCycle) { - fusion_nodes = {b, c, d}; - } else if (case_num == kNoCycleCase2) { - fusion_nodes = {b, c, i}; - } else if (case_num == kNoCycleCase3) { - fusion_nodes = {b, c, d, i}; - } - return graph; -} - - -static ComputeGraphPtr BuildFusionGraph06(int case_num, - std::vector &fusion_nodes) { - auto graph = CreateGraph06(case_num, fusion_nodes); - return graph; -} - -TEST_F(UtestCycleDetector, cycle_detection_06) { - std::vector fusion_nodes; - auto graph = BuildFusionGraph06(kNoCycleCase1, fusion_nodes); - - CycleDetectorPtr detector = GraphUtils::CreateCycleDetector(graph); - EXPECT_NE(detector, nullptr); - bool has_cycle = detector->HasDetectedCycle({fusion_nodes}); - EXPECT_FALSE(has_cycle); -} - -TEST_F(UtestCycleDetector, cycle_detection_07) { - std::vector fusion_nodes; - auto graph = BuildFusionGraph06(kContainCycle, fusion_nodes); - - CycleDetectorPtr detector = GraphUtils::CreateCycleDetector(graph); - EXPECT_NE(detector, nullptr); - bool has_cycle = detector->HasDetectedCycle({fusion_nodes}); - EXPECT_TRUE(has_cycle); -} - -TEST_F(UtestCycleDetector, cycle_detection_08) { - std::vector fusion_nodes; - auto graph = BuildFusionGraph06(kNoCycleCase2, fusion_nodes); - - CycleDetectorPtr detector = GraphUtils::CreateCycleDetector(graph); - EXPECT_NE(detector, nullptr); - bool has_cycle = detector->HasDetectedCycle({fusion_nodes}); - EXPECT_FALSE(has_cycle); -} - -TEST_F(UtestCycleDetector, cycle_detection_09) { - std::vector fusion_nodes; - auto graph = BuildFusionGraph06(kNoCycleCase3, fusion_nodes); - - CycleDetectorPtr detector = GraphUtils::CreateCycleDetector(graph); - EXPECT_NE(detector, nullptr); - bool has_cycle = detector->HasDetectedCycle({fusion_nodes}); - EXPECT_FALSE(has_cycle); -} - -TEST_F(UtestCycleDetector, ConnectionMatrixCoverage_00) { - std::vector fusion_nodes; - auto graph = BuildFusionGraph06(kNoCycleCase2, fusion_nodes); - CycleDetectorPtr detector = GraphUtils::CreateCycleDetector(graph); - EXPECT_NE(detector, nullptr); - detector->Update(graph, fusion_nodes); - auto has_cycle = detector->HasDetectedCycle({fusion_nodes}); - EXPECT_FALSE(has_cycle); - detector->Update(graph, fusion_nodes); -} diff --git a/tests/ut/graph/testcase/enum_attr_utils_unittest.cc b/tests/ut/graph/testcase/enum_attr_utils_unittest.cc deleted file mode 100644 index f4f9d6a615a93ecb31f4568a4ddd78260b473258..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/enum_attr_utils_unittest.cc +++ /dev/null @@ -1,316 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/utils/enum_attr_utils.h" - -namespace ge { -class UtestEnumAttrUtils : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -TEST_F(UtestEnumAttrUtils, TestGetEnumAttrName) { - // 验证一个属性 - vector enum_attr_names = {}; - string attr_name = "name0"; - string enum_attr_name = ""; - bool is_new_attr = false; - EnumAttrUtils::GetEnumAttrName(enum_attr_names, attr_name, enum_attr_name, is_new_attr); - ASSERT_EQ(is_new_attr, true); - ASSERT_EQ(enum_attr_name.size(), 2U); - ASSERT_EQ(enum_attr_name.at(1), 1); - ASSERT_EQ(enum_attr_names.size(), 1U); - ASSERT_EQ(enum_attr_names[0], attr_name); - - // 验证两个不同的属性 - attr_name = "name1"; - enum_attr_name = ""; - is_new_attr = false; - EnumAttrUtils::GetEnumAttrName(enum_attr_names, attr_name, enum_attr_name, is_new_attr); - ASSERT_EQ(is_new_attr, true); - ASSERT_EQ(enum_attr_name.size(), 2U); - ASSERT_EQ(enum_attr_name.at(1), 2); - ASSERT_EQ(enum_attr_names.size(), 2U); - ASSERT_EQ(enum_attr_names[1], attr_name); - - // 验证kMaxValueOfEachDigit个不同属性, 边界值为kMaxValueOfEachDigit,一位数最大值 - for (uint32_t i = 2U; i < kMaxValueOfEachDigit; i++) { - attr_name = "name" + to_string(i); - enum_attr_name = ""; - is_new_attr = false; - EnumAttrUtils::GetEnumAttrName(enum_attr_names, attr_name, enum_attr_name, is_new_attr); - } - ASSERT_EQ(is_new_attr, true); - ASSERT_EQ(enum_attr_name.size(), 2U); - ASSERT_EQ(enum_attr_name.at(1), kMaxValueOfEachDigit); - ASSERT_EQ(enum_attr_names.size(), kMaxValueOfEachDigit); - ASSERT_EQ(enum_attr_names[kMaxValueOfEachDigit - 1U], attr_name); - - // 验证kMaxValueOfEachDigit + 1个不同属性, 两位数初始值 - attr_name = "name" + to_string(kMaxValueOfEachDigit); - enum_attr_name = ""; - is_new_attr = false; - EnumAttrUtils::GetEnumAttrName(enum_attr_names, attr_name, enum_attr_name, is_new_attr); - ASSERT_EQ(is_new_attr, true); - ASSERT_EQ(enum_attr_name.size(), 3U); - ASSERT_EQ(enum_attr_name.at(1), 1U); - ASSERT_EQ(enum_attr_name.at(2), 2U); - ASSERT_EQ(enum_attr_names.size(), kMaxValueOfEachDigit + 1U); - ASSERT_EQ(enum_attr_names[kMaxValueOfEachDigit], attr_name); - - // 验证kMaxValueOfEachDigit * kMaxValueOfEachDigit个不同属性, 两位数初始值的最大值 - for (uint32_t i = kMaxValueOfEachDigit + 1U; i < kMaxValueOfEachDigit * kMaxValueOfEachDigit; i++) { - attr_name = "name" + to_string(i); - enum_attr_name = ""; - is_new_attr = false; - EnumAttrUtils::GetEnumAttrName(enum_attr_names, attr_name, enum_attr_name, is_new_attr); - } - ASSERT_EQ(is_new_attr, true); - ASSERT_EQ(enum_attr_name.size(), 3U); - ASSERT_EQ(enum_attr_name.at(1), kMaxValueOfEachDigit); - ASSERT_EQ(enum_attr_name.at(2), kMaxValueOfEachDigit); - ASSERT_EQ(enum_attr_names.size(), kMaxValueOfEachDigit * kMaxValueOfEachDigit); - ASSERT_EQ(enum_attr_names[kMaxValueOfEachDigit * kMaxValueOfEachDigit - 1U], attr_name); - - // 验证kMaxValueOfEachDigit * kMaxValueOfEachDigit + 1个不同属性, 三位数初始值 - attr_name = "name" + to_string(kMaxValueOfEachDigit * kMaxValueOfEachDigit); - enum_attr_name = ""; - is_new_attr = false; - EnumAttrUtils::GetEnumAttrName(enum_attr_names, attr_name, enum_attr_name, is_new_attr); - ASSERT_EQ(is_new_attr, true); - ASSERT_EQ(enum_attr_name.size(), 4U); - ASSERT_EQ(enum_attr_name.at(1), 1U); - ASSERT_EQ(enum_attr_name.at(2), 1U); - ASSERT_EQ(enum_attr_name.at(3), 2U); - ASSERT_EQ(enum_attr_names.size(), kMaxValueOfEachDigit * kMaxValueOfEachDigit + 1U); - ASSERT_EQ(enum_attr_names[kMaxValueOfEachDigit * kMaxValueOfEachDigit], attr_name); - - // 验证属性名重复场景1 - attr_name = "name0"; - enum_attr_name = ""; - is_new_attr = true; - EnumAttrUtils::GetEnumAttrName(enum_attr_names, attr_name, enum_attr_name, is_new_attr); - ASSERT_EQ(is_new_attr, false); - ASSERT_EQ(enum_attr_name.size(), 2U); - ASSERT_EQ(enum_attr_name.at(1), 1); - ASSERT_EQ(enum_attr_names[0], attr_name); - - // 验证属性名重复场景2 - attr_name = "name" + to_string(kMaxValueOfEachDigit * kMaxValueOfEachDigit); - enum_attr_name = ""; - is_new_attr = true; - EnumAttrUtils::GetEnumAttrName(enum_attr_names, attr_name, enum_attr_name, is_new_attr); - ASSERT_EQ(is_new_attr, false); - ASSERT_EQ(enum_attr_name.size(), 4U); - ASSERT_EQ(enum_attr_name.at(1), 1U); - ASSERT_EQ(enum_attr_name.at(2), 1U); - ASSERT_EQ(enum_attr_name.at(3), 2U); - ASSERT_EQ(enum_attr_names.size(), kMaxValueOfEachDigit * kMaxValueOfEachDigit + 1U); - ASSERT_EQ(enum_attr_names[kMaxValueOfEachDigit * kMaxValueOfEachDigit], attr_name); -} - -TEST_F(UtestEnumAttrUtils, TestGetEnumAttrValue) { - // 验证一个属性值 - vector enum_attr_values = {}; - string attr_value = "value0"; - int64_t enum_attr_value = 0; - EnumAttrUtils::GetEnumAttrValue(enum_attr_values, attr_value, enum_attr_value); - ASSERT_EQ(enum_attr_value, 0); - ASSERT_EQ(enum_attr_values.size(), 1U); - ASSERT_EQ(enum_attr_values[0], attr_value); - - // 验证两个属性值 - attr_value = "value1"; - enum_attr_value = 0; - EnumAttrUtils::GetEnumAttrValue(enum_attr_values, attr_value, enum_attr_value); - ASSERT_EQ(enum_attr_value, 1); - ASSERT_EQ(enum_attr_values.size(), 2U); - ASSERT_EQ(enum_attr_values[1], attr_value); - - // 验证重复属性场景 - attr_value = "value1"; - enum_attr_value = 0; - EnumAttrUtils::GetEnumAttrValue(enum_attr_values, attr_value, enum_attr_value); - ASSERT_EQ(enum_attr_value, 1); - ASSERT_EQ(enum_attr_values.size(), 2U); - ASSERT_EQ(enum_attr_values[1], attr_value); - - // 验证重复属性场景 - attr_value = "value0"; - enum_attr_value = 0; - EnumAttrUtils::GetEnumAttrValue(enum_attr_values, attr_value, enum_attr_value); - ASSERT_EQ(enum_attr_value, 0); - ASSERT_EQ(enum_attr_values.size(), 2U); - ASSERT_EQ(enum_attr_values[0], attr_value); - - // 验证三个属性值 - attr_value = "value2"; - enum_attr_value = 0; - EnumAttrUtils::GetEnumAttrValue(enum_attr_values, attr_value, enum_attr_value); - ASSERT_EQ(enum_attr_value, 2); - ASSERT_EQ(enum_attr_values.size(), 3U); - ASSERT_EQ(enum_attr_values[2], attr_value); -} - -TEST_F(UtestEnumAttrUtils, TestGetEnumAttrValues) { - // 验证三个不同属性 - vector enum_attr_values = {}; - vector attr_values; - attr_values.emplace_back("value0"); - attr_values.emplace_back("value1"); - attr_values.emplace_back("value2"); - vector enum_values = {}; - EnumAttrUtils::GetEnumAttrValues(enum_attr_values, attr_values, enum_values); - for (size_t i = 0U; i < attr_values.size(); i++) { - ASSERT_EQ(enum_values[i], i); - ASSERT_EQ(enum_attr_values[i], "value" + to_string(i)); - } - - // 验证包含两个相同属性 - vector attr_values1; - attr_values1.emplace_back("value2"); - attr_values1.emplace_back("value0"); - vector enum_values1 = {}; - EnumAttrUtils::GetEnumAttrValues(enum_attr_values, attr_values1, enum_values1); - ASSERT_EQ(enum_values1[0], 2); - ASSERT_EQ(enum_values1[1], 0); - ASSERT_EQ(enum_attr_values.size(), 3U); - - // 验证包含一个相同属性, 一个不同属性 - vector attr_values2; - attr_values2.emplace_back("value3"); - attr_values2.emplace_back("value1"); - vector enum_values2 = {}; - EnumAttrUtils::GetEnumAttrValues(enum_attr_values, attr_values2, enum_values2); - ASSERT_EQ(enum_values2[0], 3); - ASSERT_EQ(enum_values2[1], 1); - ASSERT_EQ(enum_attr_values.size(), 4U); -} - -TEST_F(UtestEnumAttrUtils, TestGetAttrName) { - // enum_attr_name为空校验 - vector enum_attr_names = {}; - vector name_use_string_values = {}; - string enum_attr_name = ""; - string attr_name = ""; - bool is_value_string = false; - auto ret = EnumAttrUtils::GetAttrName(enum_attr_names, name_use_string_values, - enum_attr_name, attr_name, is_value_string); - ASSERT_EQ(ret, GRAPH_FAILED); - - // enum_attr_names为空校验 - char_t a1 = 1; - enum_attr_name.append(kAppendNum, prefix); - enum_attr_name.append(kAppendNum, a1); - ret = EnumAttrUtils::GetAttrName(enum_attr_names, name_use_string_values, - enum_attr_name, attr_name, is_value_string); - ASSERT_EQ(ret, GRAPH_FAILED); - - // name_use_string_values为空校验 - enum_attr_names.emplace_back("name1"); - ret = EnumAttrUtils::GetAttrName(enum_attr_names, name_use_string_values, - enum_attr_name, attr_name, is_value_string); - ASSERT_EQ(ret, GRAPH_FAILED); - - // 一个成员的正常的流程 enum化的属性名 - name_use_string_values.emplace_back(true); - ret = EnumAttrUtils::GetAttrName(enum_attr_names, name_use_string_values, - enum_attr_name, attr_name, is_value_string); - ASSERT_EQ(ret, GRAPH_SUCCESS); - ASSERT_EQ(attr_name, "name1"); - ASSERT_EQ(is_value_string, true); - - // 两个成员的正常的流程 enum化的属性名 - enum_attr_name = ""; - char_t a2 = 2; - enum_attr_name.append(kAppendNum, prefix); - enum_attr_name.append(kAppendNum, a2); - enum_attr_names.emplace_back("name2"); - name_use_string_values.emplace_back(false); - ret = EnumAttrUtils::GetAttrName(enum_attr_names, name_use_string_values, - enum_attr_name, attr_name, is_value_string); - ASSERT_EQ(ret, GRAPH_SUCCESS); - ASSERT_EQ(attr_name, "name2"); - ASSERT_EQ(is_value_string, false); - - // 127个成员的正常的流程 enum化的属性名, 一位数最大值 - enum_attr_name = ""; - char_t a127 = 127; - enum_attr_name.append(kAppendNum, prefix); - enum_attr_name.append(kAppendNum, a127); - for (int i = 3; i <= 127; i++) { - enum_attr_names.emplace_back("name" + to_string(i)); - name_use_string_values.emplace_back(true); - } - ret = EnumAttrUtils::GetAttrName(enum_attr_names, name_use_string_values, - enum_attr_name, attr_name, is_value_string); - ASSERT_EQ(ret, GRAPH_SUCCESS); - ASSERT_EQ(attr_name, "name127"); - ASSERT_EQ(is_value_string, true); - - // 128个成员的正常的流程 enum化的属性名,两位数的初始值 - enum_attr_name = ""; - char_t a128_1 = 1; - char_t a128_2 = 2; - enum_attr_name.append(kAppendNum, prefix); - enum_attr_name.append(kAppendNum, a128_1); - enum_attr_name.append(kAppendNum, a128_2); - enum_attr_names.emplace_back("name128"); - name_use_string_values.emplace_back(true); - ret = EnumAttrUtils::GetAttrName(enum_attr_names, name_use_string_values, - enum_attr_name, attr_name, is_value_string); - ASSERT_EQ(ret, GRAPH_SUCCESS); - ASSERT_EQ(attr_name, "name128"); - ASSERT_EQ(is_value_string, true); - - // 正常的流程 非enum化的属性名 - enum_attr_name = "name1"; - ret = EnumAttrUtils::GetAttrName(enum_attr_names, name_use_string_values, - enum_attr_name, attr_name, is_value_string); - ASSERT_EQ(ret, GRAPH_SUCCESS); - ASSERT_EQ(attr_name, enum_attr_name); - ASSERT_EQ(is_value_string, false); -} - -TEST_F(UtestEnumAttrUtils, TestGetAttrValue) { - // 异常场景测试 - vector enum_attr_values = {}; - int64_t enum_attr_value = 0; - string attr_value = ""; - auto ret = EnumAttrUtils::GetAttrValue(enum_attr_values, enum_attr_value, attr_value); - ASSERT_EQ(ret, GRAPH_FAILED); - - // 正常场景测试 - enum_attr_values.emplace_back("value1"); - ret = EnumAttrUtils::GetAttrValue(enum_attr_values, enum_attr_value, attr_value); - ASSERT_EQ(ret, GRAPH_SUCCESS); - ASSERT_EQ(attr_value, "value1"); -} - -TEST_F(UtestEnumAttrUtils, TestGetAttrValues) { - // 异常场景测试 - vector enum_attr_values = {}; - vector enum_values = {}; - vector attr_values = {}; - enum_values.emplace_back(1); - auto ret = EnumAttrUtils::GetAttrValues(enum_attr_values, enum_values, attr_values); - ASSERT_EQ(ret, GRAPH_FAILED); - - // 正常场景测试 - enum_attr_values.emplace_back("value1"); - enum_attr_values.emplace_back("value2"); - ret = EnumAttrUtils::GetAttrValues(enum_attr_values, enum_values, attr_values); - ASSERT_EQ(ret, GRAPH_SUCCESS); - ASSERT_EQ(attr_values.size(), 1U); - ASSERT_EQ(attr_values[0], "value2"); -} -} diff --git a/tests/ut/graph/testcase/execute_graph_adapter_unittest.cc b/tests/ut/graph/testcase/execute_graph_adapter_unittest.cc deleted file mode 100644 index b89918cd67fd5da00716021a72c5fb29e49d887e..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/execute_graph_adapter_unittest.cc +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/utils/execute_graph_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph_builder_utils.h" -#include "graph/utils/execute_graph_adapter.h" -#include "share_graph.h" -#include "mmpa/mmpa_api.h" - -namespace ge { - -using ExecuteSharedGraph = SharedGraph; - -class UtestExecuteGraphAdapter : public testing::Test { - public: - void SetUp() override {} - void TearDown() override {} -}; - -TEST_F(UtestExecuteGraphAdapter, ConvertExecuteGraphToComputeGraph_Ok_without_subgraph) { - auto exe_graph = ExecuteSharedGraph::BuildGraphWithControlEdge(); - auto compute_graph = ExecuteGraphAdapter::ConvertExecuteGraphToComputeGraph(exe_graph.get()); - EXPECT_NE(compute_graph, nullptr); - EXPECT_EQ(exe_graph->GetAllNodes().size(), compute_graph->GetAllNodes().size()); -} - -TEST_F(UtestExecuteGraphAdapter, ConvertExecuteGraphToComputeGraph_Ok_with_subgraph) { - auto exe_graph = ExecuteSharedGraph::BuildGraphWithSubGraph(); - auto attr_name = "use_execute_graph"; - auto attr_value = "1"; - AttrUtils::SetStr(exe_graph, attr_name, attr_value); - auto ext_attr_name = "FakeExtAttr"; - auto ext_attr_value = "233abc"; - exe_graph->SetExtAttr(ext_attr_name, ext_attr_value); - auto case0 = ExecuteGraphUtils::FindNodeFromAllNodes(exe_graph.get(), "case0"); - case0->GetOpDescBarePtr()->AddSubgraphName("branch3"); - case0->GetOpDescBarePtr()->SetSubgraphInstanceName(2, ""); - - auto compute_graph = ExecuteGraphAdapter::ConvertExecuteGraphToComputeGraph(exe_graph.get()); - EXPECT_NE(compute_graph, nullptr); - EXPECT_EQ(exe_graph->GetAllNodes().size(), compute_graph->GetAllNodes().size()); - auto value = AttrUtils::GetStr(compute_graph, attr_name); - EXPECT_EQ(*value, attr_value); - auto ext_value = compute_graph->TryGetExtAttr(ext_attr_name, ""); - EXPECT_EQ(ext_value, ext_attr_value); -} - -TEST_F(UtestExecuteGraphAdapter, ConvertExecuteGraphToComputeGraph_Ok_with_null_edge) { - auto exe_graph = ExecuteSharedGraph::BuildGraphWithControlEdge(); - auto n4 = ExecuteGraphUtils::FindNodeFromAllNodes(exe_graph.get(), "n4"); - auto out_edges = n4->GetOutEdgesByIndex(0); - for (auto edge : out_edges) { - exe_graph->RemoveEdge(edge); - } - auto out_ctrl_edge = n4->GetOutEdgesByIndex(-1); - for (auto edge : out_ctrl_edge) { - exe_graph->RemoveEdge(edge); - } - auto compute_graph = ExecuteGraphAdapter::ConvertExecuteGraphToComputeGraph(exe_graph.get()); - EXPECT_NE(compute_graph, nullptr); - EXPECT_EQ(exe_graph->GetAllNodes().size(), compute_graph->GetAllNodes().size()); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/execute_graph_unittest.cc b/tests/ut/graph/testcase/execute_graph_unittest.cc deleted file mode 100644 index abbdebe87230f52efc6a514a860d4ac78a52bc4d..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/execute_graph_unittest.cc +++ /dev/null @@ -1,1539 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include "graph/fast_graph/execute_graph.h" -#include "common/share_graph.h" -#include "graph/ge_local_context.h" -#include "graph_builder_utils.h" -#include "graph/fast_graph/fast_node.h" -#include "graph/debug/ge_op_types.h" -#include "fast_graph/fast_graph_impl.h" - -namespace { -using ExecuteSharedGraph = ge::SharedGraph; -std::shared_ptr BuildDelayTopoGraph(const std::string &name) { - auto builder = ge::ut::ExecuteGraphBuilder(name); - const auto &variable = builder.AddNode("variable", ge::VARIABLE, 0, 2); - const auto &node1 = builder.AddNode("node1", "node1", 1, 1); - const auto &node2 = builder.AddNode("node2", "node2", 1, 1); - const auto &node3 = builder.AddNode("node3", "node3", 1, 1); - const auto &node4 = builder.AddNode("node4", "node4", 1, 1); - const auto &node5 = builder.AddNode("node5", "node5", 3, 0); - const auto &data = builder.AddNode("data", "DATA", 0, 1); - - builder.AddDataEdge(variable, 0, node1, 0); - builder.AddDataEdge(variable, 1, node2, 0); - builder.AddDataEdge(node1, 0, node5, 0); - builder.AddDataEdge(node2, 0, node5, 1); - builder.AddDataEdge(data, 0, node3, 0); - builder.AddDataEdge(node3, 0, node4, 0); - - int dst_idx = 2; - builder.AddDataEdge(node4, 0, node5, dst_idx); - - builder.AddControlEdge(node2, node3); - return builder.GetGraph(); -} -} // namespace - -namespace ge { -class UtestExecuteGraph : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -TEST_F(UtestExecuteGraph, ExecuteGraph) { - auto compute_graph = std::make_shared("graph"); - size_t graph_id = 10; - compute_graph->SetGraphId(graph_id); - auto ret = compute_graph->GetGraphId(); - ASSERT_EQ(ret, graph_id); -} - -TEST_F(UtestExecuteGraph, AddNodeToExecuteGraph) { - auto compute_graph = std::make_shared("graph"); - { - auto op_desc = std::make_shared("op", "op"); - std::vector inputs_order; - int num = 10; - for (int i = 0; i < num; ++i) { - inputs_order.push_back("test" + std::to_string(i)); - } - compute_graph->SetInputsOrder(inputs_order); - auto td = GeTensorDesc(); - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - auto node = compute_graph->AddNode(op_desc); - ASSERT_NE(node, nullptr); - auto ret = compute_graph->RemoveJustNode(node); - ASSERT_EQ(ret, GRAPH_SUCCESS); - - node = compute_graph->AddNode(node); - ASSERT_NE(node, nullptr); - - auto node_with_idx = compute_graph->AddNode(op_desc, 0); - ASSERT_NE(node_with_idx, nullptr); - - ret = compute_graph->RemoveJustNode(node); - ASSERT_EQ(ret, GRAPH_SUCCESS); - } - - { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - auto node = compute_graph->AddNodeFront(op_desc); - ASSERT_NE(node, nullptr); - auto ret = compute_graph->RemoveJustNode(node); - ASSERT_EQ(ret, GRAPH_SUCCESS); - - node = compute_graph->AddNodeFront(node); - ASSERT_NE(node, nullptr); - } -} - -TEST_F(UtestExecuteGraph, RemoveJustNode) { - auto compute_graph = std::make_shared("graph"); - int node_num = 10; - int edge_num = 5; - FastNode *node[node_num] = {}; - for (int i = 0; i < node_num; ++i) { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - node[i] = compute_graph->AddNode(op_desc); - ASSERT_NE(node[i], nullptr); - } - - FastEdge *edge[edge_num] = {}; - for (int i = 0; i < edge_num; ++i) { - edge[i] = compute_graph->AddEdge(node[0], i, node[1], i); - ASSERT_NE(edge[i], nullptr); - } - - for (int i = 0; i < node_num; ++i) { - auto ret = compute_graph->RemoveJustNode(node[i]); - ASSERT_EQ(ret, GRAPH_SUCCESS); - } - - for (int i = 0; i < edge_num; ++i) { - auto ret = compute_graph->RemoveEdge(edge[i]); - ASSERT_EQ(ret, GRAPH_SUCCESS); - } -} - -TEST_F(UtestExecuteGraph, AddEdgeToExecuteGraph) { - auto compute_graph = std::make_shared("graph"); - int node_num = 10; - int edge_num = 5; - FastNode *node[node_num] = {}; - for (int i = 0; i < node_num; ++i) { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - node[i] = compute_graph->AddNode(op_desc); - ASSERT_NE(node[i], nullptr); - } - - FastEdge *edge[edge_num] = {}; - for (int i = 0; i < edge_num; ++i) { - edge[i] = compute_graph->AddEdge(node[0], i, node[1], i); - ASSERT_NE(edge[i], nullptr); - } - - ASSERT_EQ(compute_graph->GetAllEdges().size(), edge_num); - - for (int i = 0; i < edge_num; ++i) { - auto ret = compute_graph->RemoveEdge(edge[i]); - ASSERT_EQ(ret, GRAPH_SUCCESS); - } -} - -TEST_F(UtestExecuteGraph, AddSubGraphToExecuteGraph) { - auto root_graph = std::make_shared("root_graph"); - int node_num = 10; - int edge_num = 5; - FastNode *node[node_num] = {}; - for (int i = 0; i < node_num; ++i) { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - node[i] = root_graph->AddNode(op_desc); - ASSERT_NE(node[i], nullptr); - } - - int subgraph_num = 5; - std::shared_ptr sub_graph[subgraph_num] = {nullptr}; - for (int i = 0; i < subgraph_num; i++) { - sub_graph[i] = std::make_shared("subgraph_" + std::to_string(i)); - sub_graph[i]->SetParentGraph(root_graph.get()); - sub_graph[i]->SetParentNode(node[i]); - } - - { - /* Test Error */ - std::string name = "subgraph_" + std::to_string(0); - std::shared_ptr invalid_sub_graph = nullptr; - auto fast_graph = root_graph->AddSubGraph(invalid_sub_graph, name); - ASSERT_EQ(fast_graph, nullptr); - - sub_graph[0]->SetParentGraph(nullptr); - fast_graph = root_graph->AddSubGraph(sub_graph[0], name); - ASSERT_EQ(fast_graph, nullptr); - - sub_graph[0]->SetParentGraph(root_graph.get()); - sub_graph[0]->SetParentNode(nullptr); - fast_graph = root_graph->AddSubGraph(sub_graph[0], name); - ASSERT_EQ(fast_graph, nullptr); - - sub_graph[0]->SetParentGraph(root_graph.get()); - sub_graph[0]->SetParentNode(node[0]); - - auto ok_fast_graph = root_graph->AddSubGraph(sub_graph[0], name); - ASSERT_NE(ok_fast_graph, nullptr); - - auto bad_graph = root_graph->AddSubGraph(sub_graph[0], name); - ASSERT_EQ(bad_graph, nullptr); - - auto ret = root_graph->RemoveSubGraph(name); - ASSERT_EQ(ret, GRAPH_SUCCESS); - - auto root_graph2 = std::make_shared("root_graph2"); - sub_graph[1]->SetParentGraph(root_graph2.get()); - sub_graph[1]->SetParentNode(node[1]); - bad_graph = root_graph->AddSubGraph(sub_graph[1], name); - ASSERT_EQ(bad_graph, nullptr); - - sub_graph[1]->SetParentGraph(root_graph.get()); - sub_graph[1]->SetParentNode(node[1]); - } - - { - std::string name = "root_graph2"; - auto root_graph2 = std::make_shared(name); - root_graph2->SetParentGraph(root_graph.get()); - root_graph2->SetParentNode(node[0]); - - auto ok_fast_graph = root_graph->AddSubGraph(root_graph2, name); - ASSERT_NE(ok_fast_graph, nullptr); - - auto find_graph = root_graph2->GetSubGraph(name); - ASSERT_NE(find_graph, nullptr); - - auto ret = root_graph->RemoveSubGraph(name); - ASSERT_EQ(ret, GRAPH_SUCCESS); - } - - for (int i = 0; i < subgraph_num; i++) { - std::string name = "subgraph_" + std::to_string(i); - auto fast_graph = root_graph->AddSubGraph(sub_graph[i], name); - ASSERT_NE(fast_graph, nullptr); - - auto find_graph = root_graph->GetSubGraph(name); - ASSERT_NE(find_graph, nullptr); - - auto ret = root_graph->RemoveSubGraph(name); - ASSERT_EQ(ret, GRAPH_SUCCESS); - } - - for (int i = 0; i < subgraph_num; i++) { - std::string name = "subgraph_" + std::to_string(i); - auto fast_graph = root_graph->AddSubGraph(sub_graph[i], name); - ASSERT_NE(fast_graph, nullptr); - } - - auto find_graph = root_graph->GetSubGraph("subgraph_1"); - ASSERT_NE(find_graph, nullptr); - - auto subgraphs = root_graph->GetAllSubgraphs(); - ASSERT_EQ(subgraphs.size(), subgraph_num); - root_graph->ClearAllSubGraph(); -} - -TEST_F(UtestExecuteGraph, AddOKSubGraphToExecuteGraph) { - auto root_graph = std::make_shared("root_graph"); - int node_num = 10; - int edge_num = 5; - FastNode *node[node_num] = {}; - for (int i = 0; i < node_num; ++i) { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - node[i] = root_graph->AddNode(op_desc); - ASSERT_NE(node[i], nullptr); - } - - int subgraph_num = 5; - std::shared_ptr sub_graph[subgraph_num] = {nullptr}; - for (int i = 0; i < subgraph_num; i++) { - sub_graph[i] = std::make_shared("subgraph_" + std::to_string(i)); - sub_graph[i]->SetParentGraph(root_graph.get()); - sub_graph[i]->SetParentNode(node[i]); - } - - { - std::shared_ptr invalid_fast_graph = nullptr; - auto fast_graph = root_graph->AddSubGraph(invalid_fast_graph); - ASSERT_EQ(fast_graph, nullptr); - - auto ret = root_graph->RemoveSubGraph(nullptr); - ASSERT_EQ(ret, GRAPH_PARAM_INVALID); - } - - for (int i = 0; i < subgraph_num; i++) { - std::string name = "subgraph_" + std::to_string(i); - auto fast_graph = root_graph->AddSubGraph(sub_graph[i]); - ASSERT_NE(fast_graph, nullptr); - - auto ret = root_graph->RemoveSubGraph(fast_graph); - ASSERT_EQ(ret, GRAPH_SUCCESS); - } -} - -TEST_F(UtestExecuteGraph, TestExecuteGraphAssign) { - auto root_graph = std::make_shared("root_graph"); - - int node_num = 10; - int edge_num = 5; - FastNode *node[node_num] = {}; - for (int i = 0; i < node_num; ++i) { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - node[i] = root_graph->AddNode(op_desc); - ASSERT_NE(node[i], nullptr); - } - - FastEdge *edge[edge_num] = {}; - for (int i = 0; i < edge_num; ++i) { - edge[i] = root_graph->AddEdge(node[0], i, node[1], i); - ASSERT_NE(edge[i], nullptr); - } - - int subgraph_num = 5; - std::shared_ptr sub_graph[subgraph_num] = {nullptr}; - for (int i = 0; i < subgraph_num; i++) { - sub_graph[i] = std::make_shared("subgraph_" + std::to_string(i)); - ASSERT_NE(sub_graph[i], nullptr); - sub_graph[i]->SetParentGraph(root_graph.get()); - sub_graph[i]->SetParentNode(node[i]); - } - - for (int i = 0; i < subgraph_num; i++) { - std::string name = "subgraph_" + std::to_string(i); - auto fast_graph = root_graph->AddSubGraph(sub_graph[i], name); - ASSERT_NE(fast_graph, nullptr); - - auto ret = root_graph->RemoveSubGraph(name); - ASSERT_EQ(ret, GRAPH_SUCCESS); - } - - auto Assign_graph = std::make_shared("root_graph"); - *Assign_graph = *root_graph; -} - -TEST_F(UtestExecuteGraph, TestRecycle) { - auto compute_graph = std::make_shared("graph"); - int node_num = 10; - int edge_num = 5; - FastNode *node[node_num] = {}; - for (int i = 0; i < node_num; ++i) { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - node[i] = compute_graph->AddNode(op_desc); - ASSERT_NE(node[i], nullptr); - } - - FastEdge *edge[edge_num] = {}; - for (int i = 0; i < edge_num; ++i) { - edge[i] = compute_graph->AddEdge(node[0], i, node[1], i); - ASSERT_NE(edge[i], nullptr); - } - - FastGraphUtils::GetListElementAddr(node[0])->owner->erase(FastGraphUtils::GetListElementAddr(node[0])); - auto ret = compute_graph->RecycleQuickNode(node[0]); - ASSERT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestExecuteGraph, TestNodes) { - auto compute_graph = std::make_shared("graph"); - int node_num = 10; - int edge_num = 10; - FastNode *node[node_num] = {}; - for (int i = 0; i < node_num; ++i) { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - node[i] = compute_graph->AddNode(op_desc); - ASSERT_NE(node[i], nullptr); - } - - ASSERT_EQ(compute_graph->GetDirectNodesSize(), node_num); - ASSERT_EQ(compute_graph->GetDirectNode().size(), node_num); - - { - auto ret = compute_graph->RemoveJustNode(node[0]); - ASSERT_EQ(ret, GRAPH_SUCCESS); - - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - auto node = compute_graph->AddNode(op_desc); - ASSERT_NE(node, nullptr); - } -} - -TEST_F(UtestExecuteGraph, TestIONodes) { - auto compute_graph = std::make_shared("graph"); - int node_num = 10; - FastNode *quick_node[node_num] = {}; - - for (int i = 0; i < node_num; ++i) { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - quick_node[i] = compute_graph->AddNode(op_desc); - ASSERT_NE(quick_node[i], nullptr); - } - - { - auto add_node = compute_graph->AddNode(nullptr); - ASSERT_EQ(add_node, nullptr); - - add_node = compute_graph->AddNode(nullptr, 1); - ASSERT_EQ(add_node, nullptr); - - auto input_node = compute_graph->AddInputNode(nullptr); - ASSERT_EQ(input_node, nullptr); - - auto ret = compute_graph->RemoveInputNode(nullptr); - ASSERT_NE(ret, GRAPH_SUCCESS); - - auto node = compute_graph->AddOutputNodeByIndex(nullptr, 0); - ASSERT_EQ(node, nullptr); - - ret = compute_graph->RemoveOutputNode(nullptr); - ASSERT_NE(ret, GRAPH_SUCCESS); - } - - auto input_node = compute_graph->AddInputNode(quick_node[0]); - ASSERT_NE(input_node, nullptr); - - auto output_node = compute_graph->AddOutputNodeByIndex(quick_node[node_num - 1], 0); - ASSERT_NE(output_node, nullptr); - - auto ret = compute_graph->RemoveInputNode(quick_node[0]); - ASSERT_EQ(ret, GRAPH_SUCCESS); - ret = compute_graph->RemoveOutputNode(quick_node[node_num - 1]); - ASSERT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestExecuteGraph, TestSortsNodes) { - auto root_graph = std::make_shared("root_graph"); - int node_num = 10; - auto subgraph_num = 10; - auto subgraph_node_num = 5; - auto io_num = 5; - - std::shared_ptr op_desc[node_num] = {nullptr}; - for (int j = 0; j < node_num; j++) { - op_desc[j] = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - - for (int64_t i = 0; i < io_num; ++i) { - op_desc[j]->AddInputDesc(td); - } - for (int64_t i = 0; i < io_num; ++i) { - op_desc[j]->AddOutputDesc(td); - } - - op_desc[j]->AddSubgraphName("subgraph_" + std::to_string(j)); - op_desc[j]->SetSubgraphInstanceName(j, "subgraph_" + std::to_string(j)); - } - - FastNode *quick_node[node_num] = {}; - for (int i = 0; i < node_num; i++) { - quick_node[i] = root_graph->AddNode(op_desc[i]); - ASSERT_NE(quick_node[i], nullptr); - } - - for (int j = 1; j < node_num; j++) { - root_graph->AddEdge(quick_node[j], 1, quick_node[j - 1], 0); - } - - std::shared_ptr sub_op_desc[subgraph_num][subgraph_node_num] = {}; - for (int i = 0; i < subgraph_num; i++) { - for (int j = 0; j < subgraph_node_num; j++) { - sub_op_desc[i][j] = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - - for (int64_t x = 0; x < io_num; ++x) { - sub_op_desc[i][j]->AddInputDesc(td); - } - for (int64_t x = 0; x < io_num; ++x) { - sub_op_desc[i][j]->AddOutputDesc(td); - } - } - } - - std::shared_ptr sub_graph[subgraph_num]{}; - for (int i = 0; i < subgraph_num; i++) { - sub_graph[i] = std::make_shared("subgraph_" + std::to_string(i)); - FastNode *sub_graph_node[subgraph_node_num] = {}; - for (int j = 0; j < subgraph_node_num; j++) { - sub_graph_node[j] = sub_graph[i]->AddNode(sub_op_desc[i][j]); - } - for (int j = 1; j < subgraph_node_num; j++) { - sub_graph[i]->AddEdge(sub_graph_node[j], 1, sub_graph_node[j - 1], 0); - sub_graph[i]->SetParentGraph(root_graph.get()); - sub_graph[i]->SetParentNode(quick_node[i]); - } - } - - for (int64_t i = 0; i < subgraph_num; ++i) { - std::string name = "subgraph_" + std::to_string(i); - auto ret = root_graph->AddSubGraph(sub_graph[i], name); - ASSERT_NE(ret, nullptr); - } - - /* bfs reverse no memory priority */ - { - static const std::string kTopoSortingBfs = "0"; - auto ori_graph_options = GetThreadLocalContext().GetAllGraphOptions(); - auto graph_options = ori_graph_options; - graph_options[OPTION_TOPOSORTING_MODE] = kTopoSortingBfs; - graph_options[MEMORY_OPTIMIZATION_POLICY] = "xxx"; - graph_options[ENABLE_SINGLE_STREAM] = "true"; - GetThreadLocalContext().SetGraphOption(graph_options); - auto ret = root_graph->TopologicalSorting(); - ASSERT_EQ(ret, GRAPH_SUCCESS); - } - - /* dfs reverse no memory priority */ - { - static const std::string kTopoSortingDfs = "1"; - auto ori_graph_options = GetThreadLocalContext().GetAllGraphOptions(); - auto graph_options = ori_graph_options; - graph_options[OPTION_TOPOSORTING_MODE] = kTopoSortingDfs; - graph_options[MEMORY_OPTIMIZATION_POLICY] = "xxx"; - graph_options[ENABLE_SINGLE_STREAM] = "true"; - GetThreadLocalContext().SetGraphOption(graph_options); - auto ret = root_graph->TopologicalSorting(); - ASSERT_EQ(ret, GRAPH_SUCCESS); - } - - /* DfsPostOrder reverse no memory priority */ - { - static const std::string kTopoSortingDfsPostOrder = "2"; - auto ori_graph_options = GetThreadLocalContext().GetAllGraphOptions(); - auto graph_options = ori_graph_options; - graph_options[OPTION_TOPOSORTING_MODE] = kTopoSortingDfsPostOrder; - graph_options[MEMORY_OPTIMIZATION_POLICY] = "xxx"; - graph_options[ENABLE_SINGLE_STREAM] = "true"; - GetThreadLocalContext().SetGraphOption(graph_options); - auto ret = root_graph->TopologicalSorting(); - ASSERT_EQ(ret, GRAPH_SUCCESS); - } - - /* bfs no reverse no memory priority */ - { - static const std::string kTopoSortingBfs = "0"; - auto ori_graph_options = GetThreadLocalContext().GetAllGraphOptions(); - auto graph_options = ori_graph_options; - graph_options[OPTION_TOPOSORTING_MODE] = kTopoSortingBfs; - graph_options[MEMORY_OPTIMIZATION_POLICY] = "xxxx"; - graph_options[ENABLE_SINGLE_STREAM] = "false"; - GetThreadLocalContext().SetGraphOption(graph_options); - auto ret = root_graph->TopologicalSorting(); - ASSERT_EQ(ret, GRAPH_SUCCESS); - } - - /* dfs no reverse no memory priority */ - { - static const std::string kTopoSortingDfs = "1"; - auto ori_graph_options = GetThreadLocalContext().GetAllGraphOptions(); - auto graph_options = ori_graph_options; - graph_options[OPTION_TOPOSORTING_MODE] = kTopoSortingDfs; - graph_options[MEMORY_OPTIMIZATION_POLICY] = "xxx"; - graph_options[ENABLE_SINGLE_STREAM] = "false"; - GetThreadLocalContext().SetGraphOption(graph_options); - auto ret = root_graph->TopologicalSorting(); - ASSERT_EQ(ret, GRAPH_SUCCESS); - } - - /* DfsPostOrder no reverse no memory priority */ - { - static const std::string kTopoSortingDfsPostOrder = "2"; - auto ori_graph_options = GetThreadLocalContext().GetAllGraphOptions(); - auto graph_options = ori_graph_options; - graph_options[OPTION_TOPOSORTING_MODE] = kTopoSortingDfsPostOrder; - graph_options[MEMORY_OPTIMIZATION_POLICY] = "xx"; - graph_options[ENABLE_SINGLE_STREAM] = "false"; - GetThreadLocalContext().SetGraphOption(graph_options); - auto ret = root_graph->TopologicalSorting(); - ASSERT_EQ(ret, GRAPH_SUCCESS); - } - - /* bfs reverse with memory priority */ - { - static const std::string kTopoSortingBfs = "0"; - auto ori_graph_options = GetThreadLocalContext().GetAllGraphOptions(); - auto graph_options = ori_graph_options; - graph_options[OPTION_TOPOSORTING_MODE] = kTopoSortingBfs; - graph_options[MEMORY_OPTIMIZATION_POLICY] = "MemoryPriority"; - graph_options[ENABLE_SINGLE_STREAM] = "true"; - GetThreadLocalContext().SetGraphOption(graph_options); - auto ret = root_graph->TopologicalSorting(); - ASSERT_EQ(ret, GRAPH_SUCCESS); - } - - /* dfs reverse with memory priority */ - { - static const std::string kTopoSortingDfs = "1"; - auto ori_graph_options = GetThreadLocalContext().GetAllGraphOptions(); - auto graph_options = ori_graph_options; - graph_options[OPTION_TOPOSORTING_MODE] = kTopoSortingDfs; - graph_options[MEMORY_OPTIMIZATION_POLICY] = "MemoryPriority"; - graph_options[ENABLE_SINGLE_STREAM] = "true"; - GetThreadLocalContext().SetGraphOption(graph_options); - auto ret = root_graph->TopologicalSorting(); - ASSERT_EQ(ret, GRAPH_SUCCESS); - } - - /* DfsPostOrder reverse with memory priority */ - { - static const std::string kTopoSortingDfsPostOrder = "2"; - auto ori_graph_options = GetThreadLocalContext().GetAllGraphOptions(); - auto graph_options = ori_graph_options; - graph_options[OPTION_TOPOSORTING_MODE] = kTopoSortingDfsPostOrder; - graph_options[MEMORY_OPTIMIZATION_POLICY] = "MemoryPriority"; - graph_options[ENABLE_SINGLE_STREAM] = "true"; - GetThreadLocalContext().SetGraphOption(graph_options); - auto ret = root_graph->TopologicalSorting(); - ASSERT_EQ(ret, GRAPH_SUCCESS); - } - - /* bfs no reverse with memory priority */ - { - static const std::string kTopoSortingBfs = "0"; - auto ori_graph_options = GetThreadLocalContext().GetAllGraphOptions(); - auto graph_options = ori_graph_options; - graph_options[OPTION_TOPOSORTING_MODE] = kTopoSortingBfs; - graph_options[MEMORY_OPTIMIZATION_POLICY] = "MemoryPriority"; - graph_options[ENABLE_SINGLE_STREAM] = "false"; - GetThreadLocalContext().SetGraphOption(graph_options); - auto ret = root_graph->TopologicalSorting(); - ASSERT_EQ(ret, GRAPH_SUCCESS); - } - - /* dfs no reverse with memory priority */ - { - static const std::string kTopoSortingDfs = "1"; - auto ori_graph_options = GetThreadLocalContext().GetAllGraphOptions(); - auto graph_options = ori_graph_options; - graph_options[OPTION_TOPOSORTING_MODE] = kTopoSortingDfs; - graph_options[MEMORY_OPTIMIZATION_POLICY] = "MemoryPriority"; - graph_options[ENABLE_SINGLE_STREAM] = "false"; - GetThreadLocalContext().SetGraphOption(graph_options); - auto ret = root_graph->TopologicalSorting(); - ASSERT_EQ(ret, GRAPH_SUCCESS); - } -} - -TEST_F(UtestExecuteGraph, TestSortsNodesError) { - auto root_graph = std::make_shared("root_graph"); - int node_num = 10; - auto subgraph_num = 10; - auto subgraph_node_num = 5; - auto io_num = 5; - - std::shared_ptr op_desc[node_num] = {nullptr}; - for (int j = 0; j < node_num; j++) { - op_desc[j] = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - - for (int64_t i = 0; i < io_num; ++i) { - op_desc[j]->AddInputDesc(td); - } - for (int64_t i = 0; i < io_num; ++i) { - op_desc[j]->AddOutputDesc(td); - } - } - - FastNode *quick_node[node_num] = {}; - for (int i = 0; i < node_num; i++) { - quick_node[i] = root_graph->AddNode(op_desc[i]); - ASSERT_NE(quick_node[i], nullptr); - } - - FastEdge *edge[node_num] = {}; - for (int j = 1; j < node_num; j++) { - edge[j] = root_graph->AddEdge(quick_node[j], 1, quick_node[j - 1], 0); - } - - std::shared_ptr sub_op_desc[subgraph_num][subgraph_node_num] = {}; - for (int i = 0; i < subgraph_num; i++) { - for (int j = 0; j < subgraph_node_num; j++) { - sub_op_desc[i][j] = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - - for (int64_t x = 0; x < io_num; ++x) { - sub_op_desc[i][j]->AddInputDesc(td); - } - for (int64_t x = 0; x < io_num; ++x) { - sub_op_desc[i][j]->AddOutputDesc(td); - } - } - } - - std::shared_ptr sub_graph[subgraph_num]{}; - for (int i = 0; i < subgraph_num; i++) { - sub_graph[i] = std::make_shared("subgraph_" + std::to_string(i)); - FastNode *sub_graph_node[subgraph_node_num] = {}; - for (int j = 0; j < subgraph_node_num; j++) { - sub_graph_node[j] = sub_graph[i]->AddNode(sub_op_desc[i][j]); - sub_graph[i]->SetParentGraph(root_graph.get()); - sub_graph[i]->SetParentNode(quick_node[i]); - } - } - - for (int64_t i = 0; i < subgraph_num; ++i) { - std::string name = "subgraph_" + std::to_string(i); - auto ret = root_graph->AddSubGraph(sub_graph[i], name); - ASSERT_NE(ret, nullptr); - } - - /* ERROR */ - { - static const std::string kTopoSortingBfs = "1"; - auto ori_graph_options = GetThreadLocalContext().GetAllGraphOptions(); - auto graph_options = ori_graph_options; - graph_options[OPTION_TOPOSORTING_MODE] = kTopoSortingBfs; - graph_options[MEMORY_OPTIMIZATION_POLICY] = "MemoryPriority"; - graph_options[ENABLE_SINGLE_STREAM] = "false"; - GetThreadLocalContext().SetGraphOption(graph_options); - auto ret = root_graph->TopologicalSorting(); - ASSERT_EQ(ret, GRAPH_FAILED); - } - - for (int j = 1; j < node_num; j++) { - root_graph->RemoveEdge(edge[j]); - } - - /* ERROR */ - { - static const std::string kTopoSortingBfs = "1"; - auto ori_graph_options = GetThreadLocalContext().GetAllGraphOptions(); - auto graph_options = ori_graph_options; - graph_options[OPTION_TOPOSORTING_MODE] = kTopoSortingBfs; - graph_options[MEMORY_OPTIMIZATION_POLICY] = "MemoryPriority"; - graph_options[ENABLE_SINGLE_STREAM] = "false"; - GetThreadLocalContext().SetGraphOption(graph_options); - auto ret = root_graph->TopologicalSorting(); - ASSERT_EQ(ret, GRAPH_FAILED); - } -} - -TEST_F(UtestExecuteGraph, TestGraphName) { - auto root_graph = std::make_shared("root_graph"); - std::string str = "changetohelloworld"; - root_graph->SetName(str); - auto ret = root_graph->GetName(); - ASSERT_EQ(ret, str); -} - -TEST_F(UtestExecuteGraph, TestGraphParent) { - auto root_graph = std::make_shared("root_graph"); - auto sub_graph = std::make_shared("sub_graph"); - sub_graph->SetParentGraph(root_graph.get()); - auto ret = sub_graph->GetParentGraphBarePtr(); - ASSERT_EQ(ret, root_graph.get()); -} - -TEST_F(UtestExecuteGraph, TestRecycleNode) { - auto root_graph = std::make_shared("root_graph"); - int node_num = 10; - int io_num = 5; - std::shared_ptr op_desc[node_num] = {nullptr}; - NodePtr node[node_num] = {}; - for (int j = 0; j < 1; j++) { - op_desc[j] = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - - for (int64_t i = 0; i < io_num; ++i) { - op_desc[j]->AddInputDesc(td); - } - for (int64_t i = 0; i < io_num; ++i) { - op_desc[j]->AddOutputDesc(td); - } - } - - for (int i = 0; i < 1; i++) { - node[i] = root_graph->AddNode(op_desc[i]); - } - - auto compute_graph = std::make_shared("graph"); - { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - auto quick_node = compute_graph->AddNode(op_desc); - ASSERT_NE(quick_node, nullptr); - quick_node->SetNodePtr(node[0]); - auto ret = compute_graph->RemoveJustNode(quick_node); - ASSERT_EQ(ret, GRAPH_SUCCESS); - ret = compute_graph->RecycleQuickNode(quick_node); - ASSERT_EQ(ret, GRAPH_SUCCESS); - } - { - auto ret = compute_graph->RecycleQuickNode(nullptr); - ASSERT_NE(ret, GRAPH_SUCCESS); - - ret = compute_graph->RecycleQuickEdge(nullptr); - ASSERT_NE(ret, GRAPH_SUCCESS); - } -} - -TEST_F(UtestExecuteGraph, ClearEdge) { - auto compute_graph = std::make_shared("graph"); - int node_num = 10; - int edge_num = 5; - FastNode *node[node_num] = {}; - for (int i = 0; i < node_num; ++i) { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - node[i] = compute_graph->AddNode(op_desc); - ASSERT_NE(node[i], nullptr); - } - - FastEdge *edge[edge_num] = {}; - for (int i = 0; i < edge_num; ++i) { - edge[i] = compute_graph->AddEdge(node[0], i, node[1], i); - ASSERT_NE(edge[i], nullptr); - } - - for (int i = 0; i < edge_num; ++i) { - auto ret = compute_graph->RemoveEdge(edge[i]); - ASSERT_EQ(ret, GRAPH_SUCCESS); - } -} - -TEST_F(UtestExecuteGraph, DelayTopologicalSorting) { - auto graph = BuildDelayTopoGraph("test_delay_topo_graph"); - std::map options_map; - options_map["ge.topoSortingMode"] = "2"; - options_map["ge.exec.memoryOptimizationPolicy"] = "MemoryPriority"; - GetThreadLocalContext().SetGraphOption(options_map); - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - std::vector expected_dfs_names = {"variable", "data", "node2", "node3", "node4", "node1", "node5"}; - std::vector dfs_names; - const auto &graph_dfs_topo = graph->GetAllNodes(); - for (auto &node : graph_dfs_topo) { - dfs_names.push_back(node->GetName()); - } - - EXPECT_EQ(dfs_names, expected_dfs_names); -} - -TEST_F(UtestExecuteGraph, NoDelayTopologicalSorting) { - auto graph = BuildDelayTopoGraph("test_delay_topo_graph"); - std::map options_map; - options_map["ge.topoSortingMode"] = "1"; - GetThreadLocalContext().SetGraphOption(options_map); - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - std::vector expected_dfs_names = {"variable", "node2", "node1", "data", "node3", "node4", "node5"}; - std::vector dfs_names; - const auto &graph_dfs_topo = graph->GetAllNodes(); - for (auto &node : graph_dfs_topo) { - dfs_names.push_back(node->GetName()); - } - EXPECT_EQ(dfs_names, expected_dfs_names); - - { - /* recovery environment */ - static const std::string kTopoSortingDfs = "1"; - auto ori_graph_options = GetThreadLocalContext().GetAllGraphOptions(); - auto graph_options = ori_graph_options; - graph_options[OPTION_TOPOSORTING_MODE] = kTopoSortingDfs; - graph_options[MEMORY_OPTIMIZATION_POLICY] = "xxx"; - graph_options[ENABLE_SINGLE_STREAM] = "false"; - GetThreadLocalContext().SetGraphOption(graph_options); - } -} - -TEST_F(UtestExecuteGraph, TestNodeAttr) { - auto compute_graph = std::make_shared("graph"); - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - auto node = compute_graph->AddNode(op_desc); - ASSERT_NE(node, nullptr); - ASSERT_NE(node->GetExtendInfo(), nullptr); - auto name_ptr = node->GetNamePtr(); - ASSERT_NE(name_ptr, nullptr); - auto name = node->GetName(); - ASSERT_EQ(name, "op"); - auto type = node->GetType(); - ASSERT_EQ(name, "op"); - auto type_ptr = node->GetTypePtr(); - ASSERT_NE(type_ptr, nullptr); - auto graph = node->GetExtendInfo()->GetOwnerGraphBarePtr(); - ASSERT_EQ(graph, compute_graph.get()); - - auto ret = compute_graph->RemoveJustNode(node); - ASSERT_EQ(ret, GRAPH_SUCCESS); - - auto node1 = compute_graph->AddNode(node); - ASSERT_EQ(node1, node); - - { - auto op_desc1 = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - op_desc1->AddInputDesc(td); - op_desc1->AddOutputDesc(td); - auto new_node_1 = compute_graph->AddNode(op_desc); - new_node_1->Init(op_desc); - ASSERT_NE(*new_node_1 == *node1, true); - } -} - -TEST_F(UtestExecuteGraph, TestNodeEqual) { - auto compute_graph = std::make_shared("graph"); - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - auto node = compute_graph->AddNode(op_desc); - - auto ret = compute_graph->RemoveJustNode(node); - ASSERT_EQ(ret, GRAPH_SUCCESS); - - auto node1 = compute_graph->AddNode(node); - ASSERT_EQ(*node1 == *node, true); -} - -TEST_F(UtestExecuteGraph, TestEdgeError) { - auto compute_graph = std::make_shared("graph"); - int node_num = 10; - int edge_num = 5; - FastNode *node[node_num] = {}; - for (int i = 0; i < node_num; ++i) { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - node[i] = compute_graph->AddNode(op_desc); - ASSERT_NE(node[i], nullptr); - } - - int32_t invalid_index = -2; - { - QuickEdge *edge = new QuickEdge; - FastGraphUtils::GetEdge(edge).dst_input = invalid_index; - graphStatus ret = node[0]->RecordEdge(&FastGraphUtils::GetEdge(edge), DirectionType::kDirectionInType); - ASSERT_NE(ret, GRAPH_SUCCESS); - delete edge; - } - - { - QuickEdge *edge = new QuickEdge; - FastGraphUtils::GetEdgeSrcOutput(edge) = invalid_index; - graphStatus ret = node[0]->RecordEdge(&FastGraphUtils::GetEdge(edge), DirectionType::kDirectionOutType); - ASSERT_NE(ret, GRAPH_SUCCESS); - delete edge; - } - - { - QuickEdge *edge = new QuickEdge; - FastGraphUtils::GetEdgeSrcOutput(edge) = invalid_index; - graphStatus ret = node[0]->EraseEdge(&FastGraphUtils::GetEdge(edge), DirectionType::kDirectionOutType); - ASSERT_NE(ret, GRAPH_SUCCESS); - delete edge; - } - - { - QuickEdge *edge = new QuickEdge; - FastGraphUtils::GetEdgeDstInput(edge) = invalid_index; - graphStatus ret = node[0]->EraseEdge(&FastGraphUtils::GetEdge(edge), DirectionType::kDirectionInType); - ASSERT_NE(ret, GRAPH_SUCCESS); - delete edge; - } -} - -TEST_F(UtestExecuteGraph, TestNodeEdgeError) { - auto compute_graph = std::make_shared("graph"); - int node_num = 10; - int edge_num = 5; - FastNode *node[node_num] = {}; - for (int i = 0; i < node_num; ++i) { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - node[i] = compute_graph->AddNode(op_desc); - ASSERT_NE(node[i], nullptr); - } - { - auto ret = node[0]->GetOutDataNodesByIndex(-2); - ASSERT_EQ(ret.size(), 0); - } - - { - auto ret = node[1]->GetOutControlNodes(); - ASSERT_EQ(ret.size(), 0); - } - - { - auto ret = node[2]->GetInDataEdgeByIndex(-2); - ASSERT_EQ(ret, nullptr); - } -} - -TEST_F(UtestExecuteGraph, TestMoveEdge) { - auto compute_graph = std::make_shared("graph"); - int node_num = 10; - int edge_num = 5; - FastNode *node[node_num] = {}; - for (int i = 0; i < node_num; ++i) { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - node[i] = compute_graph->AddNode(op_desc); - ASSERT_NE(node[i], nullptr); - } - - FastEdge *edge[edge_num] = {}; - for (int i = 0; i < edge_num; ++i) { - edge[i] = compute_graph->AddEdge(node[0], i, node[1], i); - ASSERT_NE(edge[i], nullptr); - } - - for (int i = 0; i < edge_num; ++i) { - auto ret = compute_graph->RemoveEdge(edge[i]); - ASSERT_EQ(ret, GRAPH_SUCCESS); - } - - edge[0] = compute_graph->AddEdge(node[0], 0, node[1], 0); - - auto size = node[0]->GetOutEdgesSizeByIndex(0); - ASSERT_EQ(size, 1); - - auto edges = node[0]->GetOutEdgesByIndex(0); - ASSERT_EQ(edges.empty(), false); - ASSERT_EQ(edges[0], edge[0]); - - size = node[1]->GetInEdgesSizeByIndex(0); - ASSERT_EQ(size, 1); - - size = node[1]->GetInEdgesSizeByIndex(-1); - ASSERT_EQ(size, 0); - - auto data_edge = node[1]->GetInDataEdgeByIndex(0); - ASSERT_EQ(data_edge, edge[0]); - - auto ret = node[0]->MoveEdge(DirectionType::kDirectionOutType, 0, 1, 0); - ASSERT_EQ(ret, GRAPH_SUCCESS); - - size = node[0]->GetOutEdgesSizeByIndex(0); - ASSERT_EQ(size, 1); - - size = node[0]->GetOutEdgesSizeByIndex(-1); - ASSERT_EQ(size, 0); - - edges = node[0]->GetOutEdgesByIndex(0); - ASSERT_EQ(edges[0], edge[0]); - - ret = node[1]->MoveEdge(DirectionType::kDirectionInType, 0, 0, 0); - ASSERT_EQ(ret, GRAPH_SUCCESS); - - size = node[1]->GetInEdgesSizeByIndex(0); - ASSERT_EQ(size, 1); - - data_edge = node[1]->GetInDataEdgeByIndex(0); - ASSERT_EQ(data_edge, edge[0]); -} - -TEST_F(UtestExecuteGraph, GetDataEdgeInfo) { - auto compute_graph = std::make_shared("graph"); - int node_num = 10; - int edge_num = 5; - FastNode *node[node_num] = {}; - for (int i = 0; i < node_num; ++i) { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - node[i] = compute_graph->AddNode(op_desc); - ASSERT_NE(node[i], nullptr); - } - - FastEdge *edge[edge_num] = {}; - for (int i = 0; i < edge_num; ++i) { - edge[i] = compute_graph->AddEdge(node[0], i, node[1], i); - ASSERT_NE(edge[i], nullptr); - } - - auto out_size = node[0]->GetAllOutEdgesSize(); - ASSERT_EQ(out_size, edge_num); - - auto nodes = node[0]->GetOutDataNodesByIndex(0); - ASSERT_EQ(nodes.size(), 1); - - auto size = node[1]->GetAllInEdgeSize(); - ASSERT_EQ(size, edge_num); - - auto vec_size = node[1]->GetAllInDataEdgesRef(); - ASSERT_EQ(vec_size.size(), edge_num); - - vec_size = node[1]->MutableAllInDataEdges(); - ASSERT_EQ(vec_size.size(), edge_num); - - auto node_vec = node[0]->GetOutDataNodes(); - ASSERT_EQ(node_vec.size(), edge_num); - - node_vec = node[0]->GetAllInNodes(); - ASSERT_EQ(node_vec.size(), 0); - - node_vec = node[1]->GetInDataNodes(); - ASSERT_EQ(node_vec.size(), edge_num); -} - -TEST_F(UtestExecuteGraph, GetAllEdgeInfo) { - auto compute_graph = std::make_shared("graph"); - int node_num = 10; - int edge_num = 5; - FastNode *node[node_num] = {}; - for (int i = 0; i < node_num; ++i) { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - node[i] = compute_graph->AddNode(op_desc); - ASSERT_NE(node[i], nullptr); - } - - FastEdge *edge[edge_num] = {}; - FastEdge *control_edge = {}; - for (int i = 0; i < edge_num; ++i) { - edge[i] = compute_graph->AddEdge(node[0], i, node[1], i); - ASSERT_NE(edge[i], nullptr); - } - control_edge = compute_graph->AddEdge(node[0], kControlEdgeIndex, node[1], kControlEdgeIndex); - ASSERT_NE(control_edge, nullptr); - - auto out_size = node[0]->GetAllOutEdgesSize(); - ASSERT_EQ(out_size, edge_num + 1); - - auto nodes = node[0]->GetOutDataNodesByIndex(0); - ASSERT_EQ(nodes.size(), 1); - - auto size = node[1]->GetAllInEdgeSize(); - ASSERT_EQ(size, edge_num + 1); - - auto vec_size = node[1]->GetAllInDataEdgesRef(); - ASSERT_EQ(vec_size.size(), edge_num); - - vec_size = node[1]->MutableAllInDataEdges(); - ASSERT_EQ(vec_size.size(), edge_num); - - auto node_vec = node[0]->GetOutDataNodes(); - ASSERT_EQ(node_vec.size(), edge_num); - - node_vec = node[1]->GetAllInNodes(); - ASSERT_EQ(node_vec.size(), edge_num + 1); - - auto flag = node[1]->IsDirectlyControlledByNode(node[0]); - ASSERT_EQ(flag, true); - - node_vec = node[1]->GetInDataNodes(); - ASSERT_EQ(node_vec.size(), edge_num); -} - -TEST_F(UtestExecuteGraph, GetControlEdgeInfo) { - auto compute_graph = std::make_shared("graph"); - int node_num = 10; - int edge_num = 5; - FastNode *node[node_num] = {}; - for (int i = 0; i < node_num; ++i) { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - node[i] = compute_graph->AddNode(op_desc); - ASSERT_NE(node[i], nullptr); - } - - FastEdge *edge[node_num] = {}; - for (int i = 1; i < node_num; ++i) { - edge[i] = compute_graph->AddEdge(node[0], kControlEdgeIndex, node[i], kControlEdgeIndex); - ASSERT_NE(edge[i], nullptr); - } - - auto out_edges = node[0]->GetOutEdgesRefByIndex(kControlEdgeIndex); - ASSERT_EQ(out_edges.size(), node_num - 1); - - auto out_control_nodes = node[0]->GetOutControlNodes(); - ASSERT_EQ(out_control_nodes.size(), node_num - 1); - - auto nodes = node[0]->GetOutControlNodes(); - ASSERT_EQ(nodes.size(), node_num - 1); - - auto in_control_nodes = node[1]->GetInControlNodes(); - ASSERT_EQ(in_control_nodes.size(), 1); -} - -TEST_F(UtestExecuteGraph, deepcopy) { - auto node_num = 10; - auto io_num = 10; - auto subgraph_num = 10; - auto subgraph_node_num = 10; - std::shared_ptr op_desc[node_num] = {nullptr}; - for (int j = 0; j < node_num; j++) { - op_desc[j] = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - - for (int64_t i = 0; i < io_num; ++i) { - op_desc[j]->AddInputDesc(td); - } - for (int64_t i = 0; i < io_num; ++i) { - op_desc[j]->AddOutputDesc(td); - } - } - - std::shared_ptr sub_graph[subgraph_num] = {nullptr}; - FastNode *node[node_num] = {}; - FastEdge *edge[node_num] = {}; - ExecuteGraph *quick_graph[subgraph_num] = {nullptr}; - std::shared_ptr sub_op_desc[subgraph_num][subgraph_node_num] = {}; - for (int i = 0; i < subgraph_num; i++) { - for (int j = 0; j < subgraph_node_num; j++) { - sub_op_desc[i][j] = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int64_t x = 0; x < io_num; ++x) { - sub_op_desc[i][j]->AddInputDesc(td); - } - for (int64_t x = 0; x < io_num; ++x) { - sub_op_desc[i][j]->AddOutputDesc(td); - } - } - } - - auto root_graph = std::make_shared("root_graph"); - for (int i = 0; i < node_num; i++) { - node[i] = root_graph->AddNode(op_desc[i]); - ASSERT_NE(node[i], nullptr); - } - auto input_node = root_graph->AddInputNode(node[0]); - ASSERT_NE(input_node, nullptr); - - for (int i = 1; i < node_num; i++) { - edge[i] = root_graph->AddEdge(node[i], 1, node[i - 1], 0); - ASSERT_NE(edge[i], nullptr); - } - - for (int i = 0; i < subgraph_num; i++) { - std::string name = "subgraph_" + std::to_string(i); - sub_graph[i] = std::make_shared(name); - FastNode *sub_graph_node[subgraph_node_num] = {}; - for (int j = 0; j < subgraph_node_num; j++) { - sub_graph_node[j] = sub_graph[i]->AddNode(sub_op_desc[i][j]); - ASSERT_NE(sub_graph_node[j], nullptr); - } - for (int j = 1; j < subgraph_node_num; j++) { - auto ret = sub_graph[i]->AddEdge(sub_graph_node[j], 1, sub_graph_node[j - 1], 0); - ASSERT_NE(ret, nullptr); - } - } - - for (int i = 0; i < subgraph_num; ++i) { - sub_graph[i]->SetParentGraph(root_graph.get()); - sub_graph[i]->SetParentNode(node[i]); - std::string name = "subgraph_" + std::to_string(i); - quick_graph[i] = root_graph->AddSubGraph(sub_graph[i], name); - ASSERT_NE(quick_graph[i], nullptr); - } - - std::string name = "root_graph"; - auto test1_graph = std::make_shared(name); - test1_graph->CompleteCopy(*(root_graph.get())); -} - -TEST_F(UtestExecuteGraph, TopologicalSorting_ok_with_subgraph) { - auto graph = ExecuteSharedGraph::BuildGraphWithSubGraph(); - auto sub_graph1 = graph->GetSubGraph("sub1"); - EXPECT_NE(sub_graph1, nullptr); - EXPECT_EQ(sub_graph1->GetName(), "sub1"); - auto sub_graph2 = graph->GetSubGraph("sub2"); - EXPECT_NE(sub_graph2, nullptr); - EXPECT_EQ(sub_graph2->GetName(), "sub2"); -} - -TEST_F(UtestExecuteGraph, AddEdge_check) { - const auto create_op_func = []() -> OpDescPtr { - auto td = GeTensorDesc(); - auto op_desc = std::make_shared("op", "op"); - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - return op_desc; - }; - auto exe_graph0 = std::make_shared("exe_graph0"); - auto op_desc0 = create_op_func(); - auto node0 = exe_graph0->AddNode(op_desc0); - auto exe_graph1 = std::make_shared("exe_graph1"); - auto op_desc1 = create_op_func(); - auto node1 = exe_graph1->AddNode(op_desc1); - EXPECT_EQ(exe_graph0->AddEdge(node0, 0, node1, 1), nullptr); - - auto op_desc2 = create_op_func(); - auto node2 = exe_graph0->AddNode(op_desc2); - EXPECT_NE(exe_graph0->AddEdge(node0, 0, node2, 0), nullptr); - EdgeEndpointWithDirection eep0(node0, 0); - EdgeEndpointWithDirection eep1(node0, 0); - EdgeEndpointWithDirection eep2(node2, 0); - EXPECT_EQ(eep0 < eep0, false); - EXPECT_EQ(eep0 == eep2, false); - EXPECT_EQ(eep0 == eep1, true); -} - -TEST_F(UtestExecuteGraph, EdgeAndNodeOwner) { - auto compute_graph = std::make_shared("graph"); - int node_num = 10; - int edge_num = 5; - FastNode *node[node_num] = {}; - for (int i = 0; i < node_num; ++i) { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - node[i] = compute_graph->AddNode(op_desc); - ASSERT_NE(node[i], nullptr); - } - - FastEdge *edge[node_num] = {}; - FastEdge *ctrl_edge[node_num] = {}; - for (int i = 1; i < node_num; ++i) { - edge[i] = compute_graph->AddEdge(node[i - 1], 0, node[i], 0); - ASSERT_NE(edge[i], nullptr); - auto flag = compute_graph->CheckEdgeIsInGraph(edge[i]); - ASSERT_EQ(flag, true); - - ctrl_edge[i] = compute_graph->AddEdge(node[i - 1], kControlEdgeIndex, node[i], kControlEdgeIndex); - ASSERT_NE(ctrl_edge[i], nullptr); - flag = compute_graph->CheckEdgeIsInGraph(ctrl_edge[i]); - ASSERT_EQ(flag, true); - } - - auto graph1 = std::make_shared("graph1"); - for (int i = 0; i < node_num; ++i) { - auto ret = graph1->AddNode(node[i]); - ASSERT_NE(ret, nullptr); - ASSERT_NE(ret->GetExtendInfo(), nullptr); - ASSERT_EQ(ret->GetExtendInfo()->GetOwnerGraphBarePtr(), graph1.get()); - auto flag = graph1->CheckNodeIsInGraph(ret); - ASSERT_EQ(flag, true); - flag = compute_graph->CheckNodeIsInGraph(ret); - ASSERT_EQ(flag, false); - } - - for (int i = 0; i < node_num; ++i) { - auto &edges = node[i]->GetAllInDataEdgesRef(); - for (auto edge : edges) { - if (edge != nullptr) { - auto flag = compute_graph->CheckEdgeIsInGraph(edge); - ASSERT_EQ(flag, true); - } - } - - auto &ctl_edges = node[i]->GetAllInControlEdgesRef(); - for (auto edge : ctl_edges) { - if (edge != nullptr) { - auto flag = compute_graph->CheckEdgeIsInGraph(edge); - ASSERT_EQ(flag, true); - } - } - - auto &out_data_edges = node[i]->GetAllOutDataEdgesRef(); - for (auto edges : out_data_edges) { - for (auto edge : edges) { - if (edge != nullptr) { - auto flag = compute_graph->CheckEdgeIsInGraph(edge); - ASSERT_EQ(flag, true); - } - } - } - - auto &out_ctl_edges = node[i]->GetAllOutControlEdgesRef(); - for (auto edge : out_ctl_edges) { - if (edge != nullptr) { - auto flag = compute_graph->CheckEdgeIsInGraph(edge); - ASSERT_EQ(flag, true); - } - } - } - - - for (int i = 0; i < node_num; ++i) { - auto &edges = node[i]->GetAllInDataEdgesRef(); - for (auto edge : edges) { - if (edge != nullptr) { - auto ret = graph1->MoveEdgeToGraph(edge); - ASSERT_EQ(ret, GRAPH_SUCCESS); - auto flag = compute_graph->CheckEdgeIsInGraph(edge); - ASSERT_EQ(flag, false); - flag = graph1->CheckEdgeIsInGraph(edge); - ASSERT_EQ(flag, true); - } - } - - auto &ctl_edges = node[i]->GetAllInControlEdgesRef(); - for (auto edge : ctl_edges) { - if (edge != nullptr) { - auto ret = graph1->MoveEdgeToGraph(edge); - ASSERT_EQ(ret, GRAPH_SUCCESS); - auto flag = compute_graph->CheckEdgeIsInGraph(edge); - ASSERT_EQ(flag, false); - flag = graph1->CheckEdgeIsInGraph(edge); - ASSERT_EQ(flag, true); - } - } - - auto &out_data_edges = node[i]->GetAllOutDataEdgesRef(); - for (auto edges : out_data_edges) { - for (auto edge : edges) { - if (edge != nullptr) { - auto ret = graph1->MoveEdgeToGraph(edge); - ASSERT_EQ(ret, GRAPH_SUCCESS); - auto flag = compute_graph->CheckEdgeIsInGraph(edge); - ASSERT_EQ(flag, false); - flag = graph1->CheckEdgeIsInGraph(edge); - ASSERT_EQ(flag, true); - } - } - } - - auto &out_ctl_edges = node[i]->GetAllOutControlEdgesRef(); - for (auto edge : out_ctl_edges) { - if (edge != nullptr) { - auto ret = graph1->MoveEdgeToGraph(edge); - ASSERT_EQ(ret, GRAPH_SUCCESS); - auto flag = compute_graph->CheckEdgeIsInGraph(edge); - ASSERT_EQ(flag, false); - flag = graph1->CheckEdgeIsInGraph(edge); - ASSERT_EQ(flag, true); - } - } - } -} - -TEST_F(UtestExecuteGraph, SetEdgeOwnerFail) { - auto compute_graph = std::make_shared("graph"); - auto ret = compute_graph->MoveEdgeToGraph(nullptr); - ASSERT_NE(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestExecuteGraph, GetAllNodes_ok_with_filter) { - auto graph = ExecuteSharedGraph::BuildGraphWithSubGraph(); - const auto &nodes = graph->GetAllNodes(); - EXPECT_EQ(nodes.size(), 7U); - const auto &filter_func = [](const FastNode *node) ->bool { - return node->GetType() == "Data"; - }; - const auto &filted_nodes = graph->GetAllNodes(filter_func); - EXPECT_EQ(filted_nodes.size(), 3U); -} - -TEST_F(UtestExecuteGraph, SortNodes_Order_ok) { - auto builder = ge::ut::ExecuteGraphBuilder("test"); - auto data1 = builder.AddNode("data1", "Data", 0, 1); - auto node1 = builder.AddNode("node1", "Node1", 0, 1); - auto node2 = builder.AddNode("node2", "Node2", 0, 1); - auto data2 = builder.AddNode("data2", "Data", 0, 1); - auto data3 = builder.AddNode("data3", "Data", 0, 1); - auto node3 = builder.AddNode("node3", "Node3", 0, 1); - auto data4 = builder.AddNode("data4", "Data", 0, 1); - auto node4 = builder.AddNode("node4", "Node4", 4, 1); - builder.AddDataEdge(data1, 0, node4, 0); - builder.AddDataEdge(data2, 0, node4, 1); - builder.AddDataEdge(data3, 0, node4, 2); - builder.AddDataEdge(data4, 0, node4, 3); - builder.AddControlEdge(node1, node4); - builder.AddControlEdge(node2, node4); - builder.AddControlEdge(node3, node4); - - auto graph = builder.GetGraphBeforeTopo(); - ASSERT_NE(graph, nullptr); - - std::vector stack; - std::map map_in_edge_num; - auto ret = graph->SortNodes(stack, map_in_edge_num); - ASSERT_EQ(ret, GRAPH_SUCCESS); - vector expect_node_type = {"node3", "node2", "node1", "data4", "data3", "data2", "data1"}; - for (size_t i = 0UL; i < stack.size(); ++i) { - EXPECT_EQ(stack[i]->GetName(), expect_node_type[i]); - auto it = map_in_edge_num.find(stack[i]); - EXPECT_NE(it, map_in_edge_num.end()); - EXPECT_EQ(it->second, 0); - } - auto it1 = map_in_edge_num.find(node4); - EXPECT_NE(it1, map_in_edge_num.end()); - EXPECT_EQ(it1->second, 7); // 4 data edge + 3 control edge -} -} // namespace ge diff --git a/tests/ut/graph/testcase/execute_graph_utils_unittest.cc b/tests/ut/graph/testcase/execute_graph_utils_unittest.cc deleted file mode 100644 index 2e81ccb28dece88b6900016f9f0799c79b331f86..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/execute_graph_utils_unittest.cc +++ /dev/null @@ -1,488 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include "inc/graph/utils/execute_graph_utils.h" -#include "inc/graph/utils/fast_node_utils.h" -#include "graph/debug/ge_op_types.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/debug/ge_log.h" -#include "graph_builder_utils.h" -#include "share_graph.h" - -namespace ge { - -using ExecuteSharedGraph = SharedGraph; - -class UtestExecuteGraphUtils : public testing::Test { - public: - void SetUp() override {} - void TearDown() override {} -}; - -TEST_F(UtestExecuteGraphUtils, FindNodeFromAllNodes_Ok_GraphIsNull) { - const auto node = ExecuteGraphUtils::FindNodeFromAllNodes(nullptr, "graph_name"); - EXPECT_EQ(node, nullptr); -} - -TEST_F(UtestExecuteGraphUtils, FindNodeFromAllNodes_Ok_NameIsNull) { - auto builder = ut::ExecuteGraphBuilder("Test1"); - auto graph = builder.GetGraph(); - auto node = ExecuteGraphUtils::FindNodeFromAllNodes(graph.get(), nullptr); - EXPECT_EQ(node, nullptr); -} - -TEST_F(UtestExecuteGraphUtils, FindNodeFromAllNodes_Ok_TryFindNodeInSubGraph) { - const auto graph = ExecuteSharedGraph::BuildGraphWithSubGraph(); - const auto node = ExecuteGraphUtils::FindNodeFromAllNodes(graph.get(), "data1"); - ASSERT_NE(node, nullptr); - EXPECT_EQ(node->GetName(), "data1"); -} - -TEST_F(UtestExecuteGraphUtils, FindNodesByTypeFromAllNodes_Ok_FindAllDataNodes) { - const auto graph = ExecuteSharedGraph::BuildGraphWithSubGraph(); - const auto nodes = ExecuteGraphUtils::FindNodesByTypeFromAllNodes(graph.get(), "Data"); - EXPECT_EQ(nodes.size(), 3); -} - -TEST_F(UtestExecuteGraphUtils, FindFirstNodeMatchType_OK) { - const auto root_graph = ExecuteSharedGraph::BuildGraphWithSubGraph(); - auto node = ExecuteGraphUtils::FindFirstNodeMatchType(root_graph.get(), "Data"); - ASSERT_NE(node, nullptr); - EXPECT_EQ(node->GetName(), "data0"); - node = ExecuteGraphUtils::FindFirstNodeMatchType(root_graph.get(), "Case"); - ASSERT_NE(node, nullptr); - EXPECT_EQ(node->GetName(), "case0"); - node = ExecuteGraphUtils::FindFirstNodeMatchType(root_graph.get(), "Relu"); - ASSERT_NE(node, nullptr); - EXPECT_EQ(node->GetName(), "relu0"); -} - -TEST_F(UtestExecuteGraphUtils, InsertNodeAfter_Fail_NodesInDifferentGraph) { - auto graph_builder0 = ut::ExecuteGraphBuilder("test_graph0"); - const auto &node0 = graph_builder0.AddNode("data0", DATA, 1, 1); - const auto &graph0 = graph_builder0.GetGraph(); - - auto graph_builder1 = ut::ExecuteGraphBuilder("test_graph1"); - const auto &node1 = graph_builder1.AddNode("data1", DATA, 1, 1); - const auto &graph1 = graph_builder1.GetGraph(); - ASSERT_EQ(ExecuteGraphUtils::InsertNodeAfter({node0, 0}, {}, node1, 0, 0), PARAM_INVALID); -} - -TEST_F(UtestExecuteGraphUtils, InsertNodeAfter_Fail_AddEdgeFail) { - auto graph_builder0 = ut::ExecuteGraphBuilder("test_graph0"); - const auto &node0 = graph_builder0.AddNode("data0", DATA, 1, 1); - const auto &graph0 = graph_builder0.GetGraph(); - EdgeSrcEndpoint src = {node0, 0}; - std::vector dsts; - dsts.emplace_back(node0, 0); - // case1: input_index exceeds the size of in edges - int ret = ExecuteGraphUtils::InsertNodeAfter(src, dsts, node0, 1, 0); - EXPECT_NE(ret, GRAPH_SUCCESS); - - // case2: output_index exceeds the size of out edges - int ret2 = ExecuteGraphUtils::InsertNodeAfter(src, dsts, node0, 0, 1); - EXPECT_NE(ret2, GRAPH_SUCCESS); -} - -TEST_F(UtestExecuteGraphUtils, InsertNodeAfter_Ok_NodeTypeIsSwitch) { - auto graph_builder0 = ut::ExecuteGraphBuilder("test_graph0"); - const auto &node0 = graph_builder0.AddNode("data0", SWITCH, 1, 1); - const auto &graph0 = graph_builder0.GetGraph(); - EdgeSrcEndpoint src = {node0, 0}; - std::vector dsts; - dsts.emplace_back(node0, 0); - int ret = ExecuteGraphUtils::InsertNodeAfter(src, dsts, node0, 0, 0); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestExecuteGraphUtils, InsertNodeAfter_Fail_SrcOwnerGraphNotEqualDstOwnerGraph) { - auto graph_builder0 = ut::ExecuteGraphBuilder("test_graph0"); - const auto &node0 = graph_builder0.AddNode("data0", DATA, 1, 1); - const auto &graph0 = graph_builder0.GetGraph(); - - auto graph_builder1 = ut::ExecuteGraphBuilder("test_graph1"); - const auto &node1 = graph_builder1.AddNode("data1", DATA, 1, 1); - const auto &graph1 = graph_builder1.GetGraph(); - EdgeSrcEndpoint src = {node0, 0}; - std::vector dsts; - dsts.emplace_back(node1, 0); - int ret = ExecuteGraphUtils::InsertNodeAfter(src, dsts, node0, 0, 0); - EXPECT_NE(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestExecuteGraphUtils, InsertNodeAfter_Ok_GraphWithControlEdge) { - auto graph = ExecuteSharedGraph::BuildGraphWithControlEdge(); - auto src_node = ExecuteGraphUtils::FindNodeFromAllNodes(graph.get(), "n1"); - EdgeSrcEndpoint src = {src_node, 0}; - std::vector dsts; - for (const auto out_edge : src_node->GetOutEdgesByIndex(0)) { - dsts.emplace_back(FastNodeUtils::GetDstEndpoint(out_edge)); - } - auto builder = ut::ExecuteGraphBuilder("test_graph"); - auto node_to_insert = builder.AddNode("inserted_node", "Op", 1, 3); - (void) node_to_insert->GetExtendInfo()->SetOwnerGraph(graph.get(), node_to_insert); - graph->AddNode(node_to_insert); - auto original_edge_size = graph->GetAllEdges().size(); - EXPECT_EQ(ExecuteGraphUtils::InsertNodeAfter(src, dsts, node_to_insert, 0, 0), GRAPH_SUCCESS); - EXPECT_EQ(graph->GetAllEdges().size(), original_edge_size + 1); -} - -TEST_F(UtestExecuteGraphUtils, InsertNodeAfter_Ok_will_null_out_control_edge) { - auto builder = ut::ExecuteGraphBuilder("test"); - const auto &data = builder.AddNode("data", "Data", 0, 1); - const auto &n1 = builder.AddNode("n1", "Op", 1, 1); - const auto &n2 = builder.AddNode("n2", "Op", 1, 1); - const auto &insert_node = builder.AddNode("insert_node", "Op", 1, 1); - - builder.AddDataEdge(data, 0, n1, 0); - builder.AddDataEdge(n1, 0, n2, 0); - builder.AddControlEdge(n1, n2); - auto graph = builder.GetGraph(); - for (const auto edge : n2->GetAllInControlEdgesRef()) { - if (edge != nullptr) { - graph->RemoveEdge(edge); - } - } - EXPECT_EQ(ExecuteGraphUtils::InsertNodeAfter({n1, 0}, {{n2, 0}}, insert_node, 0, 0), GRAPH_SUCCESS); - EXPECT_EQ(n2->GetInDataNodes().at(0)->GetName(), "insert_node"); - EXPECT_EQ(n2->GetInControlNodes().size(), 0); -} - -TEST_F(UtestExecuteGraphUtils, InsertNodeBefore_Fail_GetOwnerComputeGraphFail) { - auto graph_builder0 = ut::ExecuteGraphBuilder("test_graph0"); - const auto &node0 = graph_builder0.AddNode("data0", DATA, 1, 1); - const auto &graph0 = graph_builder0.GetGraph(); - - auto graph_builder1 = ut::ExecuteGraphBuilder("test_graph1"); - const auto &node1 = graph_builder1.AddNode("data1", DATA, 1, 1); - const auto &graph1 = graph_builder1.GetGraph(); - - int ret = ExecuteGraphUtils::InsertNodeBefore({node0, 0}, node1, 0, 0); - EXPECT_NE(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestExecuteGraphUtils, InsertNodeBefore_Fail_OutputIndexOutOfBounds) { - auto builder = ut::ExecuteGraphBuilder("test"); - const auto &var = builder.AddNode("var", VARIABLE, 0, 1); - const auto &assign = builder.AddNode("assign", "Assign", 1, 1); - const auto &allreduce = builder.AddNode("allreduce", "HcomAllReduce", 1, 1); - const auto &atomic_clean = builder.AddNode("atomic_clean", ATOMICADDRCLEAN, 0, 0); - const auto &netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0); - const auto &identity = builder.AddNode("identity", "Identity", 1, 1); - - builder.AddDataEdge(var, 0, assign, 0); - builder.AddDataEdge(var, 0, allreduce, 0); - builder.AddDataEdge(allreduce, 0, netoutput1, 0); - builder.AddControlEdge(assign, allreduce); - builder.AddControlEdge(atomic_clean, allreduce); - auto graph = builder.GetGraph(); - - int ret = ExecuteGraphUtils::InsertNodeBefore({allreduce, 0}, identity, 0, 5); - EXPECT_NE(ret, GRAPH_SUCCESS); -} - -/* -* var var -* atomicclean | \ | \ -* c\ | assign | assign -* \ | c/ =======> | c/ -* allreduce identity atomicclean -* | | c/ -* netoutput allreduce -* | -* netoutput - */ -TEST_F(UtestExecuteGraphUtils, InsertNodeBefore_Ok_DoNotMoveCtrlEdgeFromAtomicClean) { - auto builder = ut::ExecuteGraphBuilder("test"); - const auto &var = builder.AddNode("var", VARIABLE, 0, 1); - const auto &assign = builder.AddNode("assign", "Assign", 1, 1); - const auto &allreduce = builder.AddNode("allreduce", "HcomAllReduce", 1, 1); - const auto &atomic_clean = builder.AddNode("atomic_clean", ATOMICADDRCLEAN, 0, 0); - const auto &netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0); - const auto &identity = builder.AddNode("identity", "Identity", 1, 1); - - builder.AddDataEdge(var, 0, assign, 0); - builder.AddDataEdge(var, 0, allreduce, 0); - builder.AddDataEdge(allreduce, 0, netoutput1, 0); - builder.AddControlEdge(assign, allreduce); - builder.AddControlEdge(atomic_clean, allreduce); - auto graph = builder.GetGraph(); - - EXPECT_EQ(ExecuteGraphUtils::InsertNodeBefore({allreduce, 0}, identity, 0, 0), GRAPH_SUCCESS); - EXPECT_EQ(identity->GetInControlNodes().at(0)->GetName(), "assign"); -} - -TEST_F(UtestExecuteGraphUtils, InsertNodeBefore_Ok_will_null_in_control_edge) { - auto builder = ut::ExecuteGraphBuilder("test"); - const auto &data = builder.AddNode("data", "Data", 0, 1); - const auto &n1 = builder.AddNode("n1", "Op", 1, 1); - const auto &n2 = builder.AddNode("n2", "Op", 1, 1); - const auto &insert_node = builder.AddNode("insert_node", "Op", 1, 1); - - builder.AddDataEdge(data, 0, n1, 0); - builder.AddDataEdge(n1, 0, n2, 0); - builder.AddControlEdge(n1, n2); - auto graph = builder.GetGraph(); - for (const auto edge : n2->GetAllInControlEdgesRef()) { - if (edge != nullptr) { - graph->RemoveEdge(edge); - } - } - EXPECT_EQ(ExecuteGraphUtils::InsertNodeBefore({n2, 0}, insert_node, 0, 0), GRAPH_SUCCESS); - EXPECT_EQ(n2->GetInDataNodes().at(0)->GetName(), "insert_node"); - EXPECT_EQ(n2->GetInControlNodes().size(), 0); -} - -TEST_F(UtestExecuteGraphUtils, CopyInCtrlEdges_Fail_NodeIsNull) { - auto builder = ut::ExecuteGraphBuilder("test"); - const auto &src_node = builder.AddNode("src_node", "node", 1, 1); - int ret = ExecuteGraphUtils::MoveInCtrlEdges(src_node, nullptr); - EXPECT_EQ(ret, PARAM_INVALID); -} - -TEST_F(UtestExecuteGraphUtils, CopyInCtrlEdges_Ok_SrcCtrlInNodesIsEmpty) { - auto builder = ut::ExecuteGraphBuilder("test_graph0"); - const auto &src_node = builder.AddNode("data0", "data", 1, 1); - auto dst_node = builder.AddNode("data1", "data", 1, 1); - int ret = ExecuteGraphUtils::MoveInCtrlEdges(src_node, dst_node); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestExecuteGraphUtils, CopyInCtrlEdges_Ok) { - auto builder = ut::ExecuteGraphBuilder("test"); - const auto &in_ctrl_node1 = builder.AddNode("in_ctrl_node1", "node", 1, 1); - const auto &in_ctrl_node2 = builder.AddNode("in_ctrl_node2", "node", 1, 1); - const auto &src_node = builder.AddNode("src_node", "node", 1, 1); - auto dst_node = builder.AddNode("dst_node", "node", 1, 1); - builder.AddDataEdge(src_node, 0, dst_node, 0); - builder.AddControlEdge(src_node, dst_node); - builder.AddControlEdge(in_ctrl_node1, src_node); - builder.AddControlEdge(in_ctrl_node2, dst_node); - const auto graph = builder.GetGraph(); - ASSERT_EQ(ExecuteGraphUtils::MoveInCtrlEdges(src_node, dst_node), GRAPH_SUCCESS); - EXPECT_EQ(graph->GetAllEdges().size(), 4); - EXPECT_EQ(dst_node->GetAllInControlEdgesSize(), 3); - EXPECT_EQ(dst_node->GetInControlNodes().back()->GetName(), "in_ctrl_node1"); -} - -TEST_F(UtestExecuteGraphUtils, MoveInCtrlEdges_Fail_NodeIsNull) { - int ret = ExecuteGraphUtils::MoveInCtrlEdges(nullptr, nullptr); - EXPECT_EQ(ret, PARAM_INVALID); -} - -TEST_F(UtestExecuteGraphUtils, MoveInCtrlEdges_Ok) { - auto builder = ut::ExecuteGraphBuilder("test"); - const auto &in_ctrl_node1 = builder.AddNode("in_ctrl_node1", "node", 1, 1); - const auto &in_ctrl_node2 = builder.AddNode("in_ctrl_node2", "node", 1, 1); - const auto &src_node = builder.AddNode("src_node", "node", 1, 1); - auto dst_node = builder.AddNode("dst_node", "node", 1, 1); - builder.AddDataEdge(src_node, 0, dst_node, 0); - builder.AddControlEdge(src_node, dst_node); - builder.AddControlEdge(in_ctrl_node1, src_node); - builder.AddControlEdge(in_ctrl_node2, src_node); - const auto graph = builder.GetGraph(); - const auto original_edge_size = graph->GetAllEdges().size(); - ASSERT_EQ(ExecuteGraphUtils::MoveInCtrlEdges(src_node, dst_node), GRAPH_SUCCESS); - // move control edge does not change edge size of graph - EXPECT_TRUE(graph->GetAllEdges().size() == original_edge_size); - EXPECT_EQ(dst_node->GetAllInControlEdgesSize(), 3); - EXPECT_EQ(dst_node->GetInControlNodes()[1]->GetName(), "in_ctrl_node1"); - EXPECT_EQ(dst_node->GetInControlNodes()[2]->GetName(), "in_ctrl_node2"); -} - -TEST_F(UtestExecuteGraphUtils, CopyOutCtrlEdges_Fail_NodeIsNull) { - int ret = ExecuteGraphUtils::MoveOutCtrlEdges(nullptr, nullptr); - EXPECT_EQ(ret, PARAM_INVALID); -} - -TEST_F(UtestExecuteGraphUtils, CopyOutCtrlEdges_Fail_OutCtrlNodesIsEmpty) { - auto builder = ut::ExecuteGraphBuilder("test_graph0"); - const auto &src_node = builder.AddNode("data0", "data", 1, 1); - auto dst_node = builder.AddNode("data1", "data", 1, 1); - int ret = ExecuteGraphUtils::MoveOutCtrlEdges(src_node, dst_node); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestExecuteGraphUtils, CopyOutCtrlEdges_Ok) { - auto builder = ut::ExecuteGraphBuilder("test_graph0"); - const auto &src_node = builder.AddNode("src_node", NETOUTPUT, 1, 1); - const auto &ctrl_node = builder.AddNode("ctrl_node", CONSTANT, 0, 0); - const auto &ctrl_node2 = builder.AddNode("ctrl_node2", CONSTANT, 0, 0); - auto dst_node = builder.AddNode("dst_node", NETOUTPUT, 1, 1); - auto graph = builder.GetGraph(); - // 疑问: 是否会出现自依赖的情况? - // builder.AddControlEdge(src_node, dst_node); - builder.AddControlEdge(src_node, ctrl_node); - builder.AddControlEdge(src_node, ctrl_node2); - - EXPECT_EQ(ExecuteGraphUtils::MoveOutCtrlEdges(src_node, dst_node), GRAPH_SUCCESS); - // EXPECT_EQ(dst_node->GetOutControlNodes().size(), src_node->GetOutControlNodes().size()); - EXPECT_EQ(dst_node->GetOutControlNodes().at(1U), ctrl_node2); -} - -TEST_F(UtestExecuteGraphUtils, MoveOutCtrlEdges_Fail_NodeIsNull) { - int ret = ExecuteGraphUtils::MoveOutCtrlEdges(nullptr, nullptr); - EXPECT_EQ(ret, PARAM_INVALID); -} - -TEST_F(UtestExecuteGraphUtils, MoveOutCtrlEdges_Ok) { - auto graph = ExecuteSharedGraph::BuildGraphWithControlEdge(); - auto src_node = ExecuteGraphUtils::FindNodeFromAllNodes(graph.get(), "n1"); - auto dst_node = ExecuteGraphUtils::FindNodeFromAllNodes(graph.get(), "n4"); - int ret = ExecuteGraphUtils::MoveOutCtrlEdges(src_node, dst_node); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(src_node->GetAllOutControlEdgesSize(), 0); - EXPECT_EQ(dst_node->GetAllOutControlEdgesSize(), 2); -} - -TEST_F(UtestExecuteGraphUtils, MoveNodeToGraph_Ok_MoveNodeWithSubGraph) { - auto src_graph = ExecuteSharedGraph::BuildGraphWithSubGraph(); - auto dst_graph = ExecuteSharedGraph::BuildGraphWithControlEdge(); - // find a node with subgraph - auto node_to_move = ExecuteGraphUtils::FindNodeFromAllNodes(src_graph.get(), "case0"); - auto src_direct_node_size = src_graph->GetDirectNodesSize(); - auto dst_direct_node_size = dst_graph->GetDirectNodesSize(); - EXPECT_EQ(ExecuteGraphUtils::MoveNodeToGraph(node_to_move, dst_graph.get()), GRAPH_SUCCESS); - EXPECT_EQ(dst_graph->GetDirectNodesSize(), dst_direct_node_size + 1); - EXPECT_EQ(src_graph->GetDirectNodesSize(), src_direct_node_size - 1); - EXPECT_EQ(src_graph->GetAllSubgraphs().size(), 0); -} - -TEST_F(UtestExecuteGraphUtils, ReplaceEdgeSrc_Ok) { - auto graph = ExecuteSharedGraph::BuildGraphWithControlEdge(); - auto n2 = ExecuteGraphUtils::FindNodeFromAllNodes(graph.get(), "n2"); - EdgeSrcEndpoint src = {n2, 0}; - auto n3 = ExecuteGraphUtils::FindNodeFromAllNodes(graph.get(), "n3"); - auto old_edge = n3->GetInDataEdgeByIndex(0); - EXPECT_EQ(ExecuteGraphUtils::ReplaceEdgeSrc(old_edge, src), GRAPH_SUCCESS); - auto new_edges = n3->GetInDataEdgeByIndex(0); - EXPECT_NE(new_edges, nullptr); - auto src_node = new_edges->src; - EXPECT_EQ(src_node, n2); - auto dst_node = new_edges->dst; - EXPECT_EQ(dst_node, n3); -} - -TEST_F(UtestExecuteGraphUtils, ReplaceEdgeSrc_Fail_null_node) { - auto graph = ExecuteSharedGraph::BuildGraphWithControlEdge(); - EdgeSrcEndpoint src = {nullptr, 0}; - auto n3 = ExecuteGraphUtils::FindNodeFromAllNodes(graph.get(), "n3"); - auto old_edge = n3->GetInDataEdgeByIndex(0); - EXPECT_EQ(ExecuteGraphUtils::ReplaceEdgeSrc(old_edge, src), GRAPH_FAILED); - EXPECT_EQ(ExecuteGraphUtils::ReplaceEdgeSrc(nullptr, src), PARAM_INVALID); -} - -TEST_F(UtestExecuteGraphUtils, FindRootGraph_Ok_subgraph) { - auto graph = ExecuteSharedGraph::BuildGraphWithSubGraph(); - auto sub_graph = graph->GetSubGraph("sub1"); - auto root_graph = ExecuteGraphUtils::FindRootGraph(sub_graph); - EXPECT_EQ(root_graph, graph.get()); -} - -TEST_F(UtestExecuteGraphUtils, FindRootGraph_Ok_root_graph) { - auto graph = ExecuteSharedGraph::BuildGraphWithSubGraph(); - auto root_graph = ExecuteGraphUtils::FindRootGraph(graph.get()); - EXPECT_EQ(root_graph, graph.get()); -} - -TEST_F(UtestExecuteGraphUtils, FindRootGraph_Ok_null_param) { - auto root_graph = ExecuteGraphUtils::FindRootGraph(nullptr); - EXPECT_EQ(root_graph, nullptr); -} - -TEST_F(UtestExecuteGraphUtils, ReplaceNodeEdges_Ok_replace_data_control_edge) { - auto graph = ExecuteSharedGraph::BuildGraphWithControlEdge(); - auto n4 = ExecuteGraphUtils::FindNodeFromAllNodes(graph.get(), "n4"); - auto builder = ut::ExecuteGraphBuilder("test_graph"); - auto new_node = builder.AddNode("inserted_node", "Op", 1, 1); - (void) new_node->GetExtendInfo()->SetOwnerGraph(graph.get(), new_node); - graph->AddNode(new_node); - EXPECT_EQ(ExecuteGraphUtils::ReplaceNodeEdges(new_node, n4, {0}, {0}), GRAPH_SUCCESS); - EXPECT_EQ(n4->GetAllInDataEdgesSize(), 0U); - EXPECT_EQ(n4->GetAllOutDataEdgesSize(), 0U); - EXPECT_EQ(new_node->GetAllInDataEdgesSize(), 1U); - EXPECT_EQ(new_node->GetAllInControlEdgesSize(), 1U); - EXPECT_EQ(new_node->GetAllOutDataEdgesSize(), 1U); - EXPECT_EQ(new_node->GetAllOutControlEdgesSize(), 1U); -} - -TEST_F(UtestExecuteGraphUtils, IsolateNode_Ok_relink_data_control_edge_with_iomap) { - auto graph = ExecuteSharedGraph::BuildGraphWithControlEdge(); - auto n4 = ExecuteGraphUtils::FindNodeFromAllNodes(graph.get(), "n4"); - EXPECT_EQ(ExecuteGraphUtils::IsolateNode(n4, {0}), GRAPH_SUCCESS); - EXPECT_EQ(n4->GetAllInDataEdgesSize(), 0U); - EXPECT_EQ(n4->GetAllInControlEdgesSize(), 0U); - EXPECT_EQ(n4->GetAllOutDataEdgesSize(), 0U); - EXPECT_EQ(n4->GetAllOutControlEdgesSize(), 0U); - auto n1 = ExecuteGraphUtils::FindNodeFromAllNodes(graph.get(), "n1"); - auto n5 = ExecuteGraphUtils::FindNodeFromAllNodes(graph.get(), "n5"); - auto edge = n5->GetInDataEdgeByIndex(2); - EXPECT_EQ(edge->src, n1); - EXPECT_EQ(edge->src_output, 0); -} - -TEST_F(UtestExecuteGraphUtils, IsolateNode_Ok_relink_data_control_edge_without_iomap) { - auto graph = ExecuteSharedGraph::BuildGraphWithControlEdge(); - auto n4 = ExecuteGraphUtils::FindNodeFromAllNodes(graph.get(), "n4"); - EXPECT_EQ(ExecuteGraphUtils::IsolateNode(n4, {}), GRAPH_SUCCESS); - EXPECT_EQ(n4->GetAllInDataEdgesSize(), 0U); - EXPECT_EQ(n4->GetAllInControlEdgesSize(), 0U); - EXPECT_EQ(n4->GetAllOutDataEdgesSize(), 0U); - EXPECT_EQ(n4->GetAllOutControlEdgesSize(), 0U); - auto n5 = ExecuteGraphUtils::FindNodeFromAllNodes(graph.get(), "n5"); - EXPECT_EQ(n5->GetInDataEdgeByIndex(2), nullptr); -} - -TEST_F(UtestExecuteGraphUtils, IsolateNode_Ok_connected_data_out) { - auto builder = ut::ExecuteGraphBuilder("test_graph0"); - auto n1 = builder.AddNode("n1", "Data", 1, 1); - auto n2 = builder.AddNode("n2", "Op", 1, 1); - auto n3 = builder.AddNode("n3", "Op", 1, 1); - builder.AddDataEdge(n1, 0, n2, 0); - builder.AddDataEdge(n2, 0, n3, 0); - builder.AddControlEdge(n1, n3); - EXPECT_EQ(ExecuteGraphUtils::IsolateNode(n2, {}), GRAPH_SUCCESS); - //EXPECT_EQ(n3->GetAllInDataEdgesSize(), 1); - EXPECT_EQ(n3->GetAllInControlEdgesSize(), 1); -} - -TEST_F(UtestExecuteGraphUtils, RemoveSubgraphRecursively_Fail_NodeToRemoveIsInvalid) { - auto graph = ExecuteSharedGraph::BuildGraphWithConst(); - auto builder = ut::ExecuteGraphBuilder("wild_graph"); - // case1: node has null owner graph - auto invalid_node = builder.AddNode("invalid", "Data", 1, 1); - ExecuteGraph *null_graph = nullptr; - invalid_node->GetExtendInfo()->SetOwnerGraph(null_graph, invalid_node); - EXPECT_EQ(ExecuteGraphUtils::RemoveSubgraphRecursively(graph.get(), invalid_node), GRAPH_SUCCESS); - - // case2: node is not in graph - auto wild_node = builder.AddNode("wild", "Data", 1, 1); - EXPECT_EQ(ExecuteGraphUtils::RemoveSubgraphRecursively(graph.get(), wild_node), GRAPH_SUCCESS); - - // case3: node without subgraph - auto node_without_subg = graph->GetDirectNode().front(); - EXPECT_EQ(ExecuteGraphUtils::RemoveSubgraphRecursively(graph.get(), node_without_subg), GRAPH_SUCCESS); -} - -TEST_F(UtestExecuteGraphUtils, ReplaceNodeDataEdges_Fail_GraphIsNull) { - auto builder = ut::ExecuteGraphBuilder("demo_graph"); - auto node1 = builder.AddNode("data", "Data", 1, 1); - auto node2 = builder.AddNode("data2", "Data", 1, 1); - EXPECT_EQ(ExecuteGraphUtils::ReplaceNodeDataEdges(node1, node2, {}, {}, nullptr), GRAPH_SUCCESS); -} - -TEST_F(UtestExecuteGraphUtils, GetNodeMapFromAllNodes_ok) { - const auto root_graph = ExecuteSharedGraph::BuildGraphWithSubGraph(); - auto node_name_to_nodes = ExecuteGraphUtils::GetNodeMapFromAllNodes(root_graph.get()); - EXPECT_EQ(node_name_to_nodes.size(), 7U); -} - -} // namespace ge diff --git a/tests/ut/graph/testcase/fast_graph_impl_unittest.cc b/tests/ut/graph/testcase/fast_graph_impl_unittest.cc deleted file mode 100644 index b9fcc1134e4ad7ecdf0587940e35cec00b662b7c..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/fast_graph_impl_unittest.cc +++ /dev/null @@ -1,201 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include "graph_builder_utils.h" -#include "fast_graph/fast_graph_impl.h" - -namespace ge { -class UtestFastGraphImpl : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -TEST_F(UtestFastGraphImpl, testnodes) { - auto root_graph = std::make_shared("root_graph"); - auto compute_graph = std::make_shared>("Hello World."); - int node_num = 10; - int edge_num = 5; - FastNode *node[node_num] = {}; - for (int i = 0; i < node_num; ++i) { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - node[i] = compute_graph->AddNode(op_desc); - ASSERT_NE(node[i], nullptr); - } - - ASSERT_EQ(compute_graph->GetAllNodeInfo().size(), node_num); - ASSERT_EQ(compute_graph->GetAllNodeInfoForModify().size(), node_num); - ASSERT_EQ(compute_graph->GetRawDirectNode().size(), node_num); - ASSERT_EQ(compute_graph->GetRawDirectNode().size(), node_num); -} - -TEST_F(UtestFastGraphImpl, SetNodes) { - auto root_graph = std::make_shared("root_graph"); - auto compute_graph = std::make_shared>("Hello World."); - int node_num = 10; - int edge_num = 5; - FastNode *node[node_num] = {}; - for (int i = 0; i < node_num; ++i) { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - node[i] = compute_graph->AddNode(op_desc); - auto quick_node = FastGraphUtils::GetListElementAddr(node[i]); - compute_graph->RemoveJustNode(quick_node); - ASSERT_NE(node[i], nullptr); - } - - std::vector nodes; - for (int i = 0; i < node_num; ++i) { - auto quick_node = FastGraphUtils::GetListElementAddr(node[i]); - FastGraphUtils::GetOwner(quick_node)->erase(quick_node); - nodes.push_back(node[i]); - } - auto ret = compute_graph->SetNodes(nodes); - ASSERT_EQ(ret, GRAPH_SUCCESS); - - nodes.push_back(nullptr); - ret = compute_graph->SetNodes(nodes); - ASSERT_EQ(ret, PARAM_INVALID); -} - -TEST_F(UtestFastGraphImpl, FastGraphImpl) { - auto compute_graph = std::make_shared>("Hello World."); - auto root_graph = std::make_shared("graph"); - auto sub_graph = std::make_shared("graph"); - ExecuteGraph *quick_graph = root_graph->AddSubGraph(sub_graph); - ASSERT_NE(quick_graph, nullptr); - - auto ret = root_graph->RemoveSubGraph(quick_graph); - ASSERT_EQ(ret, GRAPH_SUCCESS); - - std::vector sub_graphs{quick_graph}; - compute_graph->SetSubGraph(sub_graphs); - auto size = compute_graph->GetAllSubGraphSize(); - ASSERT_EQ(size, 1); -} - -TEST_F(UtestFastGraphImpl, TestIO) { - auto root_graph = std::make_shared("graph"); - auto compute_graph = std::make_shared>("Hello World."); - int node_num = 10; - int edge_num = 5; - FastNode *node[node_num] = {}; - std::vector> out_nodes_info; - for (int i = 0; i < node_num; ++i) { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - node[i] = compute_graph->AddNode(op_desc); - ASSERT_NE(node[i], nullptr); - } - auto input = compute_graph->AddInputNode(node[0]); - ASSERT_NE(input, nullptr); - auto output = compute_graph->AddOutputNodeByIndex(node[node_num - 1], 0); - ASSERT_NE(output, nullptr); - - auto inputs = compute_graph->GetAllInputNodeInfo(); - ASSERT_EQ(inputs.size(), 1); - ASSERT_EQ(inputs[0], input); - - auto outputs = compute_graph->GetAllOutNodeInfo(); - ASSERT_EQ(outputs.size(), 1); - ASSERT_EQ(outputs[0].first, output); - - out_nodes_info.push_back(std::make_pair(node[node_num - 2], 0)); - compute_graph->SetGraphOutNodesInfo(out_nodes_info); - - output = compute_graph->AddOutputNodeByIndex(node[node_num - 1], 0); - ASSERT_NE(output, nullptr); - - outputs = compute_graph->GetAllOutNodeInfo(); - ASSERT_EQ(outputs.size(), 2); - ASSERT_EQ(outputs[0].first, node[node_num - 2]); - - inputs = compute_graph->GetInputNodes(); - ASSERT_EQ(inputs.size(), 1); - - outputs = compute_graph->GetAllOutNodes(); - ASSERT_EQ(outputs.size(), 2); - - auto size = compute_graph->GetDirectNodesSize(); - ASSERT_EQ(size, node_num); - - bool flag = compute_graph->CheckNodeIsInGraph(node[0]); - ASSERT_EQ(flag, true); - - compute_graph->InValid(); - flag = compute_graph->IsValid(); - ASSERT_EQ(flag, false); - - auto ret = compute_graph->ClearNode([](QuickNode *quick_node) { return GRAPH_SUCCESS; }); - ASSERT_EQ(ret, GRAPH_SUCCESS); - - auto compute_graph1 = new FastGraphImpl("Hello World 1."); - compute_graph->Swap(*compute_graph1); - ASSERT_EQ(compute_graph->GetName(), "Hello World 1."); - - delete compute_graph1; -} - -TEST_F(UtestFastGraphImpl, Invalid) { - auto root_graph = std::make_shared("root_graph"); - auto compute_graph = std::make_shared>("Hello World."); - auto ret = compute_graph->RemoveInputNode(nullptr); - ASSERT_EQ(ret, GRAPH_FAILED); - - auto pointer = compute_graph->AddOutputNodeByIndex(nullptr, 0); - ASSERT_EQ(pointer, nullptr); - - ret = compute_graph->RemoveOutputNode(nullptr); - ASSERT_EQ(ret, GRAPH_FAILED); - - pointer = compute_graph->AddNode(nullptr); - ASSERT_EQ(pointer, nullptr); - - pointer = compute_graph->AddNodeFront(nullptr); - ASSERT_EQ(pointer, nullptr); - - ret = compute_graph->RemoveJustNode(nullptr); - ASSERT_EQ(ret, GRAPH_FAILED); - - ret = compute_graph->RecycleQuickNode(nullptr); - ASSERT_EQ(ret, GRAPH_FAILED); - - ret = compute_graph->RecycleQuickEdge(nullptr); - ASSERT_EQ(ret, GRAPH_FAILED); - - auto edge = compute_graph->AddEdge(nullptr, 0, nullptr, 1); - ASSERT_EQ(edge, nullptr); - - auto node1 = new FastNode(); - auto node2 = new FastNode(); - edge = compute_graph->AddEdge(node1, -1, node2, 1); - delete node1; - delete node2; - ASSERT_EQ(edge, nullptr); - - ret = compute_graph->RemoveEdge(nullptr); - ASSERT_EQ(ret, GRAPH_FAILED); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/fast_node_unittest.cc b/tests/ut/graph/testcase/fast_node_unittest.cc deleted file mode 100644 index 0f4a8c1816ee263e72e7a5fa0f4240628ae68816..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/fast_node_unittest.cc +++ /dev/null @@ -1,614 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/fast_graph/execute_graph.h" -#include "graph/ge_local_context.h" -#include "graph_builder_utils.h" -#include "fast_graph/fast_graph_impl.h" -#include "graph/fast_graph/fast_node.h" -#include "fast_graph/fast_graph_utils.h" - -namespace { -std::shared_ptr BuildDelayTopoGraphMultiInput( - const std::string &name, std::unordered_map &name_to_nodes, - bool all_is_log_life = true) { - auto builder = ge::ut::ExecuteGraphBuilder(name); - const auto &constant = builder.AddNode("const", ge::CONSTANT, 0, 1); - auto type = ge::CONSTANTOP; - if (!all_is_log_life) { - type = "test"; - } - const auto &constantop = builder.AddNode("constant", type, 0, 1); - const auto &variable = builder.AddNode("variable", ge::VARIABLE, 0, 2); - const auto &node1 = builder.AddNode("node1", "node1", 3, 1); - const auto &node2 = builder.AddNode("node2", "node2", 1, 1); - const auto &node3 = builder.AddNode("node3", "node3", 1, 1); - const auto &node4 = builder.AddNode("node4", "node4", 1, 1); - const auto &node5 = builder.AddNode("node5", "node5", 3, 0); - const auto &data = builder.AddNode("data", "DATA", 0, 1); - - name_to_nodes.insert(std::make_pair("const", constant)); - name_to_nodes.insert(std::make_pair("constant", constantop)); - name_to_nodes.insert(std::make_pair("variable", variable)); - name_to_nodes.insert(std::make_pair("node1", node1)); - name_to_nodes.insert(std::make_pair("node2", node2)); - name_to_nodes.insert(std::make_pair("node3", node3)); - name_to_nodes.insert(std::make_pair("node4", node4)); - name_to_nodes.insert(std::make_pair("node5", node5)); - name_to_nodes.insert(std::make_pair("data", data)); - - int32_t dst_idx = 2; - builder.AddDataEdge(constant, 0, node1, 0); - builder.AddDataEdge(constantop, 0, node1, 1); - builder.AddDataEdge(variable, 0, node1, dst_idx); - builder.AddDataEdge(variable, 1, node2, 0); - builder.AddDataEdge(node1, 0, node5, 0); - builder.AddDataEdge(node2, 0, node5, 1); - builder.AddDataEdge(data, 0, node3, 0); - builder.AddDataEdge(node3, 0, node4, 0); - builder.AddDataEdge(node4, 0, node5, dst_idx); - - builder.AddControlEdge(node2, node3); - return builder.GetGraph(); -} -} // namespace - -namespace ge { -class UtestFastNode : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -TEST_F(UtestFastNode, NodeToken) { - auto compute_graph = std::make_shared("graph"); - int node_num = 10; - int edge_num = 5; - FastNode *node[node_num] = {}; - for (int i = 0; i < node_num; ++i) { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - node[i] = compute_graph->AddNode(op_desc); - ASSERT_NE(node[i], nullptr); - } - - auto token = node[0]->GetNodeToken(); - ASSERT_NE(token, 0); - - auto quick_node = compute_graph->FindNode(token); - ASSERT_NE(quick_node, nullptr); -} - -TEST_F(UtestFastNode, NodeIoOper) { - auto compute_graph = std::make_shared("graph"); - int node_num = 1; - int edge_num = 1; - OpDescPtr op_desc; - for (int i = 0; i < node_num; ++i) { - op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - ASSERT_NE(op_desc, nullptr); - } - - auto node = new FastNode(); - auto res = node->Init(op_desc); - ASSERT_EQ(res, GRAPH_SUCCESS); - node->GetExtendInfo()->SetOwnerGraph(compute_graph.get(), node); - - int invalid_num = 10; - auto ret = node->GetOutEdgesByIndex(invalid_num); - ASSERT_EQ(ret.size(), 0); - - ret = node->GetAllInControlEdgesRef(); - ASSERT_EQ(ret.size(), 0); - - ret = node->GetAllInControlEdges(); - ASSERT_EQ(ret.size(), 0); - - ret = node->GetOutEdgesRefByIndex(invalid_num); - ASSERT_EQ(ret.size(), 0); - - auto size = node->GetDataOutNum(); - ASSERT_EQ(size, 1); - - size = node->GetDataInNum(); - ASSERT_EQ(size, 1); - - int new_num = 2; - node->UpdateDataInNum(new_num); - size = node->GetDataInNum(); - ASSERT_EQ(size, new_num); - - node->UpdateDataOutNum(new_num); - size = node->GetDataOutNum(); - ASSERT_EQ(size, new_num); - - delete node; -} - -TEST_F(UtestFastNode, GetEdgesOfNode) { - auto compute_graph = std::make_shared("graph"); - int node_num = 10; - int edge_num = 5; - FastNode *node[node_num] = {}; - for (int i = 0; i < node_num; ++i) { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - node[i] = compute_graph->AddNode(op_desc); - ASSERT_NE(node[i], nullptr); - } - - FastEdge *edge[node_num] = {}; - for (int i = 1; i < node_num; ++i) { - edge[i] = compute_graph->AddEdge(node[i - 1], 0, node[i], 0); - ASSERT_NE(edge[i], nullptr); - auto edge = compute_graph->AddEdge(node[i - 1], -1, node[i], -1); - ASSERT_NE(edge, nullptr); - } - - auto edges = node[1]->GetAllInDataEdges(); - auto size = node[1]->GetAllInDataEdgesSize(); - ASSERT_EQ(edges.size(), size); - - edges = node[1]->GetAllInControlEdges(); - size = node[1]->GetAllInControlEdgesSize(); - ASSERT_EQ(edges.size(), size); - - edges = node[1]->GetAllOutControlEdges(); - size = node[1]->GetAllOutControlEdgesSize(); - ASSERT_EQ(edges.size(), size); - - edges = node[1]->GetAllOutDataEdges(); - size = node[1]->GetAllOutDataEdgesSize(); - ASSERT_EQ(edges.size(), size); -} - -TEST_F(UtestFastNode, NodePtr) { - auto graph = std::make_shared("graph"); - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetDataType(DT_FLOAT); - tensor_desc->SetFormat(FORMAT_CHWN); - auto op_desc1 = std::make_shared("add_front", "add_front"); - op_desc1->AddInputDesc(tensor_desc->Clone()); - auto nodeptr = graph->AddNodeFront(op_desc1); - ASSERT_NE(nodeptr, nullptr); - - auto compute_graph = std::make_shared("graph"); - FastNode *fast_node = nullptr; - fast_node = compute_graph->AddNode(op_desc1); - ASSERT_NE(fast_node, nullptr); - - fast_node->SetNodePtr(nodeptr); - auto node = fast_node->GetNodePtr(); - ASSERT_EQ(node, nodeptr); - - auto node_ref = fast_node->GetNodePtr(); - ASSERT_EQ(node_ref, nodeptr); - - fast_node->ClearNodeBarePtr(); - auto clear_node = fast_node->GetNodeBarePtr(); - ASSERT_EQ(clear_node, nullptr); - - fast_node->ClearNodePtr(); - auto clear_node_ptr = fast_node->GetNodePtr(); - ASSERT_EQ(clear_node_ptr, nullptr); -} - -TEST_F(UtestFastNode, RemoveEdgeFunc) { - auto compute_graph = std::make_shared("graph"); - int node_num = 10; - int edge_num = 5; - FastNode *node[node_num] = {}; - for (int i = 0; i < node_num; ++i) { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - node[i] = compute_graph->AddNode(op_desc); - ASSERT_NE(node[i], nullptr); - } - - FastEdge *edge[node_num] = {}; - for (int i = 1; i < node_num; ++i) { - edge[i] = compute_graph->AddEdge(node[i - 1], 0, node[i], 0); - ASSERT_NE(edge[i], nullptr); - } - - FastGraphUtils::GetListElementAddr(node[1])->owner->erase(FastGraphUtils::GetListElementAddr(node[1])); - node[1]->RemoveAllEdge([&compute_graph](FastEdge *e) { - if (e->src != nullptr) { - e->src->EraseEdge(e, DirectionType::kDirectionOutType); - e->src = nullptr; - } - - if (e->dst != nullptr) { - e->dst->EraseEdge(e, DirectionType::kDirectionInType); - e->dst = nullptr; - } - - if (FastGraphUtils::GetListElementAddr(e)->owner != nullptr) { - FastGraphUtils::GetListElementAddr(e)->owner->erase(FastGraphUtils::GetListElementAddr(e)); - } - auto ret = compute_graph->RecycleQuickEdge(e); - if ((ret != GRAPH_SUCCESS) && (e != nullptr)) { - delete e; - } - }); - auto ret = compute_graph->RecycleQuickNode(node[1]); - if ((ret != GRAPH_SUCCESS) && (node[1] != nullptr)) { - delete node[1]; - } - ASSERT_EQ(compute_graph->GetDirectNodesSize(), node_num - 1); -} - -TEST_F(UtestFastNode, TestNodeError) { - auto compute_graph = std::make_shared("graph"); - int node_num = 10; - int edge_num = 5; - FastNode *node[node_num] = {}; - for (int i = 0; i < node_num; ++i) { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - node[i] = compute_graph->AddNode(op_desc); - ASSERT_NE(node[i], nullptr); - } - - FastEdge *edge = compute_graph->AddEdge(node[0], 0, node[1], 0); - ASSERT_NE(edge, nullptr); - - FastEdge *fail_edge = compute_graph->AddEdge(node[0], 0, node[1], 0); - ASSERT_EQ(fail_edge, nullptr); - - { - FastEdge *edge = compute_graph->AddEdge(nullptr, 0, nullptr, 0); - ASSERT_EQ(edge, nullptr); - - int invalid_num = 10; - edge = compute_graph->AddEdge(node[0], 0, node[1], invalid_num); - ASSERT_EQ(edge, nullptr); - - edge = compute_graph->AddEdge(node[0], invalid_num, node[1], 0); - ASSERT_EQ(edge, nullptr); - } - - { - auto compute_graph2 = std::make_shared("graph2"); - auto ret = compute_graph2->RemoveEdge(edge); - ASSERT_NE(ret, GRAPH_SUCCESS); - } - - { - auto ret = compute_graph->RemoveEdge(nullptr); - ASSERT_NE(ret, GRAPH_SUCCESS); - } - - { - int32_t invalid_num = 10; - auto size = node[0]->GetInEdgesSizeByIndex(invalid_num); - ASSERT_EQ(size, 0); - size = node[0]->GetOutEdgesSizeByIndex(invalid_num); - ASSERT_EQ(size, 0); - auto ret = node[0]->GetInDataEdgeByIndex(invalid_num); - ASSERT_EQ(ret, nullptr); - } - - { - auto graph = std::make_shared("graph"); - int node_num = 2; - int edge_num = 5; - FastNode *node[node_num] = {}; - { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddOutputDesc(td); - } - node[0] = graph->AddNode(op_desc); - ASSERT_NE(node[0], nullptr); - } - - { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - } - node[1] = graph->AddNode(op_desc); - ASSERT_NE(node[1], nullptr); - } - - FastEdge *edge = graph->AddEdge(node[0], 1, node[1], 1); - ASSERT_NE(edge, nullptr); - - auto size = node[0]->GetOutEdgesSizeByIndex(1); - ASSERT_EQ(size, 1); - } - - { - auto no_init_node = new FastNode(); - auto nodes = no_init_node->GetOutDataNodesByIndex(0); - ASSERT_EQ(nodes.size(), 0); - nodes = no_init_node->GetOutControlNodes(); - ASSERT_EQ(nodes.size(), 0); - delete no_init_node; - } - - { - auto ret = compute_graph->AddNodeFront(nullptr); - ASSERT_EQ(ret, nullptr); - } -} - -TEST_F(UtestFastNode, TestAddNodeWithNode) { - int node_num = 10; - int io_num = 5; - FastNode *node[node_num] = {}; - std::shared_ptr op_desc[node_num] = {nullptr}; - for (int j = 0; j < node_num; j++) { - op_desc[j] = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - - for (int64_t i = 0; i < io_num; ++i) { - op_desc[j]->AddInputDesc(td); - } - for (int64_t i = 0; i < io_num; ++i) { - op_desc[j]->AddOutputDesc(td); - } - } - - { - auto root_graph2 = std::make_shared("root_graph2"); - auto root_graph = std::make_shared("root_graph"); - - node[0] = root_graph2->AddNode(op_desc[0]); - ASSERT_NE(node[0], nullptr); - - auto ret = root_graph2->RemoveJustNode(node[0]); - ASSERT_EQ(ret, GRAPH_SUCCESS); - - node[0] = root_graph->AddNode(node[0]); - ASSERT_NE(node[0], nullptr); - - ret = root_graph->RemoveJustNode(node[0]); - ASSERT_EQ(ret, GRAPH_SUCCESS); - } - - { - auto root_graph2 = std::make_shared("root_graph2"); - auto root_graph = std::make_shared("root_graph"); - - node[1] = root_graph2->AddNode(op_desc[1]); - ASSERT_NE(node[1], nullptr); - ASSERT_NE(node[1]->GetExtendInfo(), nullptr); - - auto ret = root_graph2->RemoveJustNode(node[1]); - ASSERT_EQ(ret, GRAPH_SUCCESS); - - node[1]->GetExtendInfo()->SetOwnerGraph(root_graph.get(), node[1]); - - node[1] = root_graph->AddNode(node[1]); - ASSERT_NE(node[1], nullptr); - } - - { - auto root_graph2 = std::make_shared("root_graph2"); - auto root_graph = std::make_shared("root_graph"); - int32_t node_idx = 2; - - node[node_idx] = root_graph2->AddNode(op_desc[node_idx]); - ASSERT_NE(node[node_idx], nullptr); - ASSERT_NE(node[node_idx]->GetExtendInfo(), nullptr); - - node[node_idx] = root_graph->AddNodeFront(node[node_idx]); - ASSERT_NE(node[node_idx], nullptr); - - auto ret = root_graph2->RemoveJustNode(node[node_idx]); - ASSERT_EQ(ret, GRAPH_NOT_CHANGED); - - node[node_idx]->GetExtendInfo()->SetOwnerGraph(root_graph.get(), node[node_idx]); - } - - { - auto root_graph = std::make_shared("root_graph"); - auto ret = root_graph->RemoveJustNode(nullptr); - ASSERT_NE(ret, GRAPH_SUCCESS); - } -} - -TEST_F(UtestFastNode, ReorderByNodeId) { - std::unordered_map name_to_nodes; - auto graph = BuildDelayTopoGraphMultiInput("test_delay_topo_graph", name_to_nodes); - - auto iter = name_to_nodes.find("const"); - auto constant = iter->second; - - iter = name_to_nodes.find("constant"); - auto constantop = iter->second; - - iter = name_to_nodes.find("variable"); - auto variable = iter->second; - - iter = name_to_nodes.find("node1"); - auto node1 = iter->second; - - iter = name_to_nodes.find("node2"); - auto node2 = iter->second; - - iter = name_to_nodes.find("node3"); - auto node3 = iter->second; - - iter = name_to_nodes.find("node4"); - auto node4 = iter->second; - - iter = name_to_nodes.find("node5"); - auto node5 = iter->second; - - iter = name_to_nodes.find("data"); - auto data = iter->second; - - int64_t seq_id = 0L; - std::vector nodes{node5, node4, node3, node2, node1, variable, data, constantop, constant}; - for (auto &node : nodes) { - node->GetOpDescBarePtr()->SetId(seq_id++); - } - graph->ReorderByNodeId(); - auto sorted_nodes = graph->GetDirectNode(); - ASSERT_TRUE(sorted_nodes.size() == nodes.size()); - int32_t id = 0; - for (auto &node : nodes) { - EXPECT_EQ(node, sorted_nodes.at(id++)); - } -} - -TEST_F(UtestFastNode, TestNodeCheckAllInputParamter) { - auto compute_graph = std::make_shared("graph"); - int node_num = 10; - int edge_num = 5; - FastNode *node[node_num] = {}; - for (int i = 0; i < node_num; ++i) { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - node[i] = compute_graph->AddNode(op_desc); - ASSERT_NE(node[i], nullptr); - } - - { - int32_t idx = -2; - auto ret = node[0]->MoveEdge(DirectionType::kDirectionOutType, idx, 0, 0); - ASSERT_NE(ret, GRAPH_SUCCESS); - } - - { - int32_t idx = 10; - auto ret = node[0]->MoveEdge(DirectionType::kDirectionOutType, idx, 0, 0); - ASSERT_NE(ret, GRAPH_SUCCESS); - } - - { - int32_t idx = 0; - int invalid_curr_idx = 10; - auto ret = node[0]->MoveEdge(DirectionType::kDirectionOutType, idx, invalid_curr_idx, 0); - ASSERT_NE(ret, GRAPH_SUCCESS); - } - - { - int32_t idx = 0; - int invalid_replace_idx = 10; - auto ret = node[0]->MoveEdge(DirectionType::kDirectionOutType, idx, 0, invalid_replace_idx); - ASSERT_NE(ret, GRAPH_SUCCESS); - } - - { - FastEdge *edge = compute_graph->AddEdge(node[0], -1, node[1], -1); - ASSERT_NE(edge, nullptr); - - edge = compute_graph->AddEdge(node[0], -1, node[2], -1); - ASSERT_NE(edge, nullptr); - - auto ret = compute_graph->RemoveEdge(edge); - ASSERT_EQ(ret, GRAPH_SUCCESS); - - edge = compute_graph->AddEdge(node[1], -1, node[2], -1); - ASSERT_NE(edge, nullptr); - - ret = node[0]->MoveEdge(DirectionType::kDirectionOutType, -1, 0, 1); - ASSERT_EQ(ret, GRAPH_SUCCESS); - - ret = node[2]->MoveEdge(DirectionType::kDirectionInType, -1, 0, 1); - ASSERT_EQ(ret, GRAPH_SUCCESS); - } -} - -TEST_F(UtestFastNode, other) { - auto compute_graph = std::make_shared("graph"); - int node_num = 10; - int edge_num = 5; - FastNode *node[node_num] = {}; - for (int i = 0; i < node_num; ++i) { - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (int j = 0; j < edge_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - node[i] = compute_graph->AddNode(op_desc); - ASSERT_NE(node[i], nullptr); - ASSERT_NE(node[i]->GetExtendInfo(), nullptr); - ASSERT_NE(node[i]->GetExtendInfo()->GetHostNode(), true); - } - - node[0]->UpdateOpDesc(nullptr); - ASSERT_EQ(node[0]->GetOpDescPtr(), nullptr); - - auto op_desc_test = std::make_shared("test", "test"); - node[0]->UpdateOpDesc(op_desc_test); - ASSERT_EQ(node[0]->GetOpDescPtr(), op_desc_test); -} - -TEST_F(UtestFastNode, TestSymbol) { - auto exe_graph = std::make_shared("graph"); - size_t data_num = 2; - auto op_desc = std::make_shared("op", "op"); - auto td = GeTensorDesc(); - for (size_t j = 0; j < data_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - const auto node = exe_graph->AddNode(op_desc); - ASSERT_NE(node, nullptr); - - ASSERT_EQ(node->GetExtendInfo()->SetInputSymbol(0, 0), GRAPH_SUCCESS); - ASSERT_EQ(node->GetExtendInfo()->SetInputSymbol(1, 1), GRAPH_SUCCESS); - ASSERT_EQ(node->GetExtendInfo()->SetInputSymbol(2, 2), GRAPH_FAILED); - - ASSERT_EQ(node->GetExtendInfo()->SetOutputSymbol(0, 0), GRAPH_SUCCESS); - ASSERT_EQ(node->GetExtendInfo()->SetOutputSymbol(1, 1), GRAPH_SUCCESS); - ASSERT_EQ(node->GetExtendInfo()->SetOutputSymbol(2, 2), GRAPH_FAILED); - - ASSERT_EQ(node->GetExtendInfo()->GetInputSymbol(0), 0); - ASSERT_EQ(node->GetExtendInfo()->GetInputSymbol(1), 1); - ASSERT_EQ(node->GetExtendInfo()->GetInputSymbol(2), kInvalidSymbol); - - ASSERT_EQ(node->GetExtendInfo()->GetOutputSymbol(0), 0); - ASSERT_EQ(node->GetExtendInfo()->GetOutputSymbol(1), 1); - ASSERT_EQ(node->GetExtendInfo()->GetOutputSymbol(2), kInvalidSymbol); - - node->UpdateDataInNum(3); - node->UpdateDataOutNum(3); - ASSERT_EQ(node->GetExtendInfo()->SetInputSymbol(2, 2), GRAPH_SUCCESS); - ASSERT_EQ(node->GetExtendInfo()->SetOutputSymbol(2, 2), GRAPH_SUCCESS); - ASSERT_EQ(node->GetExtendInfo()->GetInputSymbol(2), 2); - ASSERT_EQ(node->GetExtendInfo()->GetOutputSymbol(2), 2); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/fast_node_utils_unittest.cc b/tests/ut/graph/testcase/fast_node_utils_unittest.cc deleted file mode 100644 index 1b3bfb02811de1475570ad14e4a297a901491951..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/fast_node_utils_unittest.cc +++ /dev/null @@ -1,361 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include "graph/fast_graph/execute_graph.h" -#include "inc/graph/utils/fast_node_utils.h" -#include "inc/graph/utils/execute_graph_utils.h" -#include "inc/graph/fast_graph/edge.h" -#include "graph_builder_utils.h" -#include "graph/debug/ge_op_types.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/debug/ge_log.h" - -namespace ge { -namespace { -/* - * +-----------+ +-----------+ - * | Graph 0 | | Graph 1 | - * | | | | - * | NetOutput | | NetOutput | - * NetOutput | | | | | | - * | | Shape | | Rank | - * Case <---------> | | | | | | - * / \ | Data(1) | | Data(1) | - * pred(Data) input(Data) +-----------+ +-----------+ - */ -std::shared_ptr BuildSimpleCaseGraph() { - // main case graph - auto builder = ut::ExecuteGraphBuilder("main_case_graph"); - auto main_data1 = builder.AddNode("data1", DATA, 1, 1); - auto main_data2 = builder.AddNode("data2", DATA, 1, 1); - auto case_node = builder.AddNode("case", CASE, 2, 1); - auto main_output = builder.AddNode("output1", NETOUTPUT, 1, 1); - builder.AddDataEdge(main_data1, 0, case_node, 0); - builder.AddDataEdge(main_data2, 0, case_node, 1); - builder.AddDataEdge(case_node, 0, main_output, 0); - auto main_graph = builder.GetGraph(); - - // case-1 subgraph - auto sub_builder1 = ut::ExecuteGraphBuilder("case1_graph"); - auto sub_data1 = sub_builder1.AddNode("sub_data1", DATA, 1, 1); - auto shape_node = sub_builder1.AddNode("shape", "Shape", 1, 1); - auto sub_out1 = sub_builder1.AddNode("sub_output1", NETOUTPUT, 1, 1); - sub_builder1.AddDataEdge(sub_data1, 0, shape_node, 0); - sub_builder1.AddDataEdge(shape_node, 0, sub_out1, 0); - AttrUtils::SetInt(sub_data1->GetOpDescBarePtr(), ATTR_NAME_PARENT_NODE_INDEX, 1); - auto sub_graph1 = sub_builder1.GetGraph(); - - // case-2 subgraph - auto sub_builder2 = ut::ExecuteGraphBuilder("case2_graph"); - auto sub_data2 = sub_builder1.AddNode("sub_data2", DATA, 1, 1); - auto rank_node = sub_builder1.AddNode("rank", "RANK", 1, 1); - auto sub_out2 = sub_builder1.AddNode("sub_output2", NETOUTPUT, 1, 1); - sub_builder1.AddDataEdge(sub_data2, 0, rank_node, 0); - sub_builder1.AddDataEdge(rank_node, 0, sub_out2, 0); - AttrUtils::SetInt(sub_data2->GetOpDescBarePtr(), ATTR_NAME_PARENT_NODE_INDEX, 1); - auto sub_graph2 = sub_builder2.GetGraph(); - - // add subgraph to case_node - sub_graph1->SetParentGraph(main_graph.get()); - sub_graph1->SetParentNode(case_node); - sub_graph2->SetParentGraph(main_graph.get()); - sub_graph2->SetParentNode(case_node); - auto g_name1 = sub_graph1->GetName(); - case_node->GetOpDescBarePtr()->AddSubgraphName(g_name1); - case_node->GetOpDescBarePtr()->SetSubgraphInstanceName(0, g_name1); - main_graph->AddSubGraph(sub_graph1, g_name1); - auto g_name2 = sub_graph2->GetName(); - case_node->GetOpDescBarePtr()->AddSubgraphName(g_name2); - case_node->GetOpDescBarePtr()->SetSubgraphInstanceName(1, g_name2); - main_graph->AddSubGraph(sub_graph2, g_name2); - - return main_graph; -} - -std::shared_ptr BuildSimpleLineGraph(const std::string &graph_name, const int node_num = 3, - const int io_num = 1) { - auto exe_graph = std::make_shared(graph_name); - std::vector node(node_num, nullptr); - for (int i = 0; i < node_num; ++i) { - std::string op_name = "op_" + std::to_string(i); - std::string op_type = "op_type_" + std::to_string(i); - auto op_desc = std::make_shared(op_name, op_type); - auto td = GeTensorDesc(); - for (int j = 0; j < io_num; ++j) { - op_desc->AddInputDesc(td); - op_desc->AddOutputDesc(td); - } - node[i] = exe_graph->AddNode(op_desc); - } - - std::vector edge(node_num, nullptr); - for (int i = 1; i < node_num; ++i) { - edge[i] = exe_graph->AddEdge(node[i - 1], 0, node[i], 0); - } - return exe_graph; -} -} // namespace - -class UtestFastNodeUtils : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -TEST_F(UtestFastNodeUtils, GetConstOpType_CONST) { - ut::ExecuteGraphBuilder builder = ut::ExecuteGraphBuilder("graph"); - auto data = builder.AddNode("const1", CONSTANT, 0, 1); - std::cout << data->GetType() << std::endl; - auto ret = FastNodeUtils::GetConstOpType(data); - EXPECT_EQ(ret, true); - // case: null input - ret = FastNodeUtils::GetConstOpType(nullptr); - EXPECT_EQ(ret, false); -} - -TEST_F(UtestFastNodeUtils, GetConstOpType_DATA) { - ut::ExecuteGraphBuilder builder = ut::ExecuteGraphBuilder("graph"); - auto data = builder.AddNode("Data", "Data", 0, 1); - std::cout << data->GetType() << std::endl; - std::string op_type; - auto ret = FastNodeUtils::GetConstOpType(data); - ASSERT_EQ(ret, false); -} - -TEST_F(UtestFastNodeUtils, GetConstOpType) { - ut::ExecuteGraphBuilder builder = ut::ExecuteGraphBuilder("graph"); - auto data = builder.AddNode("nt", NETOUTPUT, 0, 1); - EXPECT_EQ(FastNodeUtils::GetConstOpType(data), false); -} - -TEST_F(UtestFastNodeUtils, GetParentInput_invalid) { - auto builder = ut::ExecuteGraphBuilder("test_graph0"); - const auto &data_node = builder.AddNode("data", DATA, 0, 0); - auto graph = builder.GetGraph(); - AttrUtils::SetInt(data_node->GetOpDescPtr(), ge::ATTR_NAME_PARENT_NODE_INDEX, 1); - EXPECT_EQ(FastNodeUtils::GetParentInput(data_node), nullptr); -} - -TEST_F(UtestFastNodeUtils, GetParentInput) { - const size_t node_num = 10; - const size_t io_num = 10; - const size_t subgraph_num = 1; - const size_t subgraph_node_num = 10; - std::shared_ptr op_desc[node_num] = {nullptr}; - for (size_t j = 0; j < node_num; j++) { - if (j == 1) { - op_desc[j] = std::make_shared("op", DATA); - } else { - op_desc[j] = std::make_shared("op", "op"); - } - - auto td = GeTensorDesc(); - - for (size_t i = 0; i < io_num; ++i) { - op_desc[j]->AddInputDesc(td); - } - for (size_t i = 0; i < io_num; ++i) { - op_desc[j]->AddOutputDesc(td); - } - } - - std::shared_ptr sub_graph[subgraph_num] = {nullptr}; - FastNode *node[node_num] = {}; - FastEdge *edge[node_num] = {}; - ExecuteGraph *quick_graph[subgraph_num] = {nullptr}; - - auto root_graph = std::make_shared("root_graph"); - for (size_t i = 0; i < node_num; i++) { - node[i] = root_graph->AddNode(op_desc[i]); - ASSERT_NE(node[i], nullptr); - } - - for (size_t i = 1; i < node_num; i++) { - edge[i] = root_graph->AddEdge(node[i], 1, node[i - 1], 0); - ASSERT_NE(edge[i], nullptr); - } - - FastNode *sub_graph_node[subgraph_node_num] = {}; - std::string name = "subgraph_" + std::to_string(0); - sub_graph[0] = std::make_shared(name); - - for (size_t j = 0; j < subgraph_node_num; j++) { - sub_graph_node[j] = sub_graph[0]->AddNode(op_desc[j]); - AttrUtils::SetInt(sub_graph_node[j]->GetOpDescPtr(), ge::ATTR_NAME_PARENT_NODE_INDEX, 0); - ASSERT_NE(sub_graph_node[j], nullptr); - } - - sub_graph[0]->SetParentGraph(root_graph.get()); - sub_graph[0]->SetParentNode(node[0]); - quick_graph[0] = root_graph->AddSubGraph(sub_graph[0], name); - ASSERT_NE(quick_graph[0], nullptr); - - EXPECT_EQ(FastNodeUtils::GetParentInput(sub_graph_node[1]), nullptr); -} - -TEST_F(UtestFastNodeUtils, GetInDataNodeByIndex_Ok_NodeWithTwoInputNodes) { - auto exe_graph = BuildSimpleCaseGraph(); - const auto nodes = exe_graph->GetAllNodes(); - auto case_node = nodes[2]; - // case1: find second input of case_node - auto expect_node = nodes[1]; - auto ret_node = FastNodeUtils::GetInDataNodeByIndex(case_node, 1); - EXPECT_EQ(ret_node, expect_node); - - // case2: empty - ret_node = FastNodeUtils::GetInDataNodeByIndex(case_node, 2); - EXPECT_EQ(ret_node, nullptr); - - // case3: null input - ret_node = FastNodeUtils::GetInDataNodeByIndex(nullptr, 2); - EXPECT_EQ(ret_node, nullptr); -} - -TEST_F(UtestFastNodeUtils, AddSubgraph_Ok_MultiScenarios) { - auto exe_graph = BuildSimpleLineGraph("main_graph", 3, 1); - auto sug_graph1 = BuildSimpleLineGraph("sub_graph1", 2, 1); - auto nodes = exe_graph->GetDirectNode(); - ASSERT_EQ(nodes.size(), 3); - ASSERT_NE(nodes[1], nullptr); - // case1: add subgraph once - auto ret = FastNodeUtils::AppendSubgraphToNode(nodes[1], "sub_g1", sug_graph1); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - // case2: add graph with same name - ret = FastNodeUtils::AppendSubgraphToNode(nodes[1], "sub_g1", sug_graph1); - EXPECT_NE(ret, GRAPH_SUCCESS); - - // case3: null input - ret = FastNodeUtils::AppendSubgraphToNode(nullptr, "sub_g1", sug_graph1); - EXPECT_EQ(ret, PARAM_INVALID); - ret = FastNodeUtils::AppendSubgraphToNode(nodes[1], "sub_g1", nullptr); - EXPECT_EQ(ret, PARAM_INVALID); -} - -TEST_F(UtestFastNodeUtils, GetSubgraph_Ok_GetCaseBranchGraph) { - auto exe_graph = BuildSimpleCaseGraph(); - auto case_node = exe_graph->GetDirectNode()[2]; - auto sub_graph1 = FastNodeUtils::GetSubgraphFromNode(case_node, 0); - ASSERT_NE(sub_graph1, nullptr); - EXPECT_EQ(sub_graph1->GetName(), "case1_graph"); - EXPECT_EQ(FastNodeUtils::GetSubgraphFromNode(nullptr, 0), nullptr); -} - -TEST_F(UtestFastNodeUtils, SetSubgraph_Ok_MultiScenarios) { - auto main_graph = BuildSimpleLineGraph("main_graph", 1, 1); - auto sub_graph = BuildSimpleLineGraph("sub_graph", 2, 1); - auto par_node = main_graph->GetDirectNode()[0]; - // case1: subgraph is nullptr - auto ret = FastNodeUtils::MountSubgraphToNode(par_node, 0, nullptr); - EXPECT_EQ(ret, PARAM_INVALID); - - // case2: subgraph instance does not exist - ret = FastNodeUtils::MountSubgraphToNode(par_node, 0, sub_graph); - EXPECT_NE(ret, GRAPH_SUCCESS); - - // case3: add a subgraph name to op_desc, and then set subgraph - EXPECT_EQ(par_node->GetOpDescBarePtr()->AddSubgraphName(sub_graph->GetName()), GRAPH_SUCCESS); - ret = FastNodeUtils::MountSubgraphToNode(par_node, 0, sub_graph); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(main_graph->GetAllNodes().size(), 3); -} - -TEST_F(UtestFastNodeUtils, SetSubgraph_Fail_NullRootGraph) { - auto sub_graph = BuildSimpleLineGraph("sub_graph", 2, 1); - auto par_node = std::make_shared(); - auto op_desc = std::make_shared("op", DATA); - EXPECT_EQ(par_node->Init(op_desc), GRAPH_SUCCESS); - // case4: parent node's op has no OwnerGraph (null root graph) - auto ret = FastNodeUtils::MountSubgraphToNode(par_node.get(), 0, sub_graph); - EXPECT_NE(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestFastNodeUtils, AppendInputEdgeInfo_Ok) { - // case1: null input - EXPECT_EQ(FastNodeUtils::AppendInputEdgeInfo(nullptr, 0), PARAM_INVALID); - - // case2: append extra input edge info - ut::ExecuteGraphBuilder builder = ut::ExecuteGraphBuilder("graph"); - auto data = builder.AddNode("Data", "Data", 2, 1); - EXPECT_EQ(data->GetInEdgeSize(), 0); - EXPECT_EQ(data->GetDataInNum(), 2); - EXPECT_EQ(FastNodeUtils::AppendInputEdgeInfo(data, 11), GRAPH_SUCCESS); - EXPECT_EQ(data->GetInEdgeSize(), 0); - EXPECT_EQ(data->GetDataInNum(), 11); -} - -TEST_F(UtestFastNodeUtils, AppendOutputEdgeInfo_Ok) { - // case1: null input - EXPECT_EQ(FastNodeUtils::AppendOutputEdgeInfo(nullptr, 0), PARAM_INVALID); - - // case2: append extra output edge info - ut::ExecuteGraphBuilder builder = ut::ExecuteGraphBuilder("graph"); - auto data = builder.AddNode("Data", DATA, 2, 1); - auto net_out = builder.AddNode("out", NETOUTPUT, 1, 1); - builder.AddDataEdge(data, 0, net_out, 0); - EXPECT_EQ(data->GetOutEdgesSizeByIndex(0), 1); - EXPECT_EQ(data->GetDataInNum(), 2); - EXPECT_EQ(FastNodeUtils::AppendOutputEdgeInfo(data, 11), GRAPH_SUCCESS); - EXPECT_EQ(data->GetOutEdgesSizeByIndex(0), 1); - EXPECT_EQ(data->GetDataOutNum(), 11); -} - -TEST_F(UtestFastNodeUtils, ClearInputDesc_Fail_InvalidInput) { - EXPECT_FALSE(FastNodeUtils::ClearInputDesc(nullptr, 0)); - auto op_desc = std::make_shared(); - EXPECT_FALSE(FastNodeUtils::ClearInputDesc(op_desc.get(), 3)); -} - -TEST_F(UtestFastNodeUtils, RemoveInputEdgeInfo_Ok) { - // case1: null input - EXPECT_EQ(FastNodeUtils::RemoveInputEdgeInfo(nullptr, 0), PARAM_INVALID); - - // case2: remove input edge info until size is num - ut::ExecuteGraphBuilder builder = ut::ExecuteGraphBuilder("graph"); - auto data = builder.AddNode("Data", "Data", 1, 1); - EXPECT_EQ(data->GetOpDescBarePtr()->GetInputsSize(), 1); - EXPECT_EQ(data->GetOpDescBarePtr()->AddInputDesc(GeTensorDesc()), GRAPH_SUCCESS); - EXPECT_EQ(data->GetOpDescBarePtr()->GetInputsSize(), 2); - EXPECT_EQ(FastNodeUtils::RemoveInputEdgeInfo(data, 0), GRAPH_SUCCESS); - EXPECT_EQ(data->GetDataInNum(), 0); -} - -TEST_F(UtestFastNodeUtils, UnlinkAll_Ok_UnlinkAllAndCheckEdgeNum) { - // case1: null input, no return value - FastNodeUtils::UnlinkAll(nullptr); - - // case2: remove all the edges connecting to the second node, 2 data egde and 1 control edge - auto exe_graph = BuildSimpleLineGraph("simple_graph", 3, 1); - auto node1 = exe_graph->GetDirectNode()[0]; - auto node2 = exe_graph->GetDirectNode()[1]; - (void) exe_graph->AddEdge(node1, kControlEdgeIndex, node2, kControlEdgeIndex); - EXPECT_EQ(exe_graph->GetAllEdges().size(), 3); - FastNodeUtils::UnlinkAll(node2); - EXPECT_EQ(exe_graph->GetAllEdges().size(), 0); -} - -TEST_F(UtestFastNodeUtils, GetInEndpoint_Ok) { - auto exe_graph = BuildSimpleLineGraph("simple_graph", 2, 1); - auto in_node = exe_graph->GetDirectNode()[1]; - auto edge = exe_graph->GetAllEdges().front(); - EXPECT_EQ(FastNodeUtils::GetDstEndpoint(edge).node, in_node); - EXPECT_EQ(FastNodeUtils::GetDstEndpoint(edge).index, 0); -} - -TEST_F(UtestFastNodeUtils, GetOutEndpoint_Ok) { - auto exe_graph = BuildSimpleLineGraph("simple_graph", 2, 1); - auto out_node = exe_graph->GetDirectNode()[0]; - auto edge = exe_graph->GetAllEdges().front(); - EXPECT_EQ(FastNodeUtils::GetSrcEndpoint(edge).node, out_node); - EXPECT_EQ(FastNodeUtils::GetSrcEndpoint(edge).index, 0); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/ffts_graph_utils_unittest.cc b/tests/ut/graph/testcase/ffts_graph_utils_unittest.cc deleted file mode 100644 index 5b7c3dba21bc25596e65b5835393dc0c3ebdc29e..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/ffts_graph_utils_unittest.cc +++ /dev/null @@ -1,552 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include "graph/utils/ffts_graph_utils.h" -#include "graph/utils/node_utils.h" -#include "graph_builder_utils.h" -#include "graph/debug/ge_op_types.h" -#include "graph/debug/ge_attr_define.h" -#include "common/ge_common/ge_inner_error_codes.h" - -namespace ge { -namespace { -bool IfNodeExist(const ComputeGraphPtr &graph, std::function filter, - bool direct_node_flag = true) { - for (const auto &node : graph->GetNodes(direct_node_flag)) { - if (filter(node)) { - return true; - } - } - return false; -} - -NodePtr FindNodeWithNamePattern(const ComputeGraphPtr &graph, const std::string &pattern, - bool direct_node_flag = true) { - for (const auto &node : graph->GetNodes(direct_node_flag)) { - const auto &name = node->GetName(); - if (name.find(pattern) != string::npos) { - return node; - } - } - return nullptr; -} - -void GetSubgraphsWithFilter(const ComputeGraphPtr &graph, std::function filter, - std::vector &subgraphs) { - for (const auto &subgraph : graph->GetAllSubgraphs()) { - if (filter(subgraph)) { - subgraphs.emplace_back(subgraph); - } - } -} - -bool IsAllNodeMatch(const ComputeGraphPtr &graph, std::function filter, - bool direct_node_flag = true) { - for (const auto &node : graph->GetNodes(direct_node_flag)) { - if (!filter(node)) { - return false; - } - } - return true; -} - -/* - * data - * | - * cast1 - * | - * cast2 - * | - * cast3 - * | - * cast4 - * | - * cast5 - * | - * cast6 - * | - * netoutput - */ -void BuildGraphForSplit_without_func_node(ComputeGraphPtr &graph, ComputeGraphPtr &subgraph) { - auto sub_builder = ut::GraphBuilder("subgraph"); - const auto &data1 = sub_builder.AddNode("data1", DATA, 1, 1); - const auto &cast1 = sub_builder.AddNode("cast1", "Cast", 1, 1); - const auto &cast2 = sub_builder.AddNode("cast2", "Cast", 1, 1); - const auto &cast3 = sub_builder.AddNode("cast3", "Cast", 1, 1); - const auto &cast4 = sub_builder.AddNode("cast4", "Cast", 1, 1); - const auto &cast5 = sub_builder.AddNode("cast5", "Cast", 1, 1); - const auto &cast6 = sub_builder.AddNode("cast6", "Cast", 1, 1); - const auto &netoutput = sub_builder.AddNode("netoutput", NETOUTPUT, 1, 0); - - sub_builder.AddDataEdge(data1, 0, cast1, 0); - sub_builder.AddDataEdge(cast1, 0, cast2, 0); - sub_builder.AddDataEdge(cast2, 0, cast3, 0); - sub_builder.AddDataEdge(cast3, 0, cast4, 0); - sub_builder.AddDataEdge(cast4, 0, cast5, 0); - sub_builder.AddDataEdge(cast5, 0, cast6, 0); - sub_builder.AddDataEdge(cast6, 0, netoutput, 0); - - AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); - AttrUtils::SetInt(netoutput->GetOpDesc()->MutableInputDesc(0), ATTR_NAME_PARENT_NODE_INDEX, 0); - - auto builder = ut::GraphBuilder("root"); - const auto &input = builder.AddNode("data", DATA, 1, 1); - const auto &func_node = builder.AddNode("func_node", PARTITIONEDCALL, 1, 1); - const auto &output = builder.AddNode("netoutput", NETOUTPUT, 1, 0); - - builder.AddDataEdge(input, 0, func_node, 0); - builder.AddDataEdge(func_node, 0, output, 0); - - subgraph = sub_builder.GetGraph(); - AttrUtils::SetStr(subgraph, "_session_graph_id", "_session_graph_id"); - graph = builder.GetGraph(); - AttrUtils::SetStr(graph, "_session_graph_id", "_session_graph_id"); - - func_node->GetOpDesc()->AddSubgraphName("f"); - func_node->GetOpDesc()->SetSubgraphInstanceName(0, subgraph->GetName()); - AttrUtils::SetStr(func_node->GetOpDesc(), ATTR_NAME_FFTS_PLUS_SUB_GRAPH, "ffts_plus"); - graph->AddSubGraph(subgraph); - subgraph->SetParentNode(func_node); - subgraph->SetParentGraph(graph); - - return; -} - -/* - * ********** root ********** func ********** then_1 ********** else_1 ********** then_2 ********** else_2 ********** - * - * input var1 data0 data1 data2 data3 data4 - * | \ / | | | | - * func constant if1 cast1 cast2 cast3 cast4 - * | \ / | | | | - * output less square1 square2 square3 square4 - * | | | | | - * netoutput0 netoutput1 var2 if2 netoutput3 netoutput4 - * \ / - * netoutput2 - * - * ****************************************************************************************************************** - */ -void BuildGraphForSplit_with_func_node(ComputeGraphPtr &graph, ComputeGraphPtr &subgraph) { - auto builder = ut::GraphBuilder("root"); - const auto &input = builder.AddNode("input", DATA, 1, 1); - const auto &func = builder.AddNode("func", PARTITIONEDCALL, 1, 1); - const auto &output = builder.AddNode("netoutput", NETOUTPUT, 1, 0); - builder.AddDataEdge(input, 0, func, 0); - builder.AddDataEdge(func, 0, output, 0); - graph = builder.GetGraph(); - AttrUtils::SetStr(graph, "_session_graph_id", "_session_graph_id"); - - auto sub_builder0 = ut::GraphBuilder("func"); - const auto &data0 = sub_builder0.AddNode("data0", DATA, 1, 1); - const auto &var1 = sub_builder0.AddNode("var1", VARIABLEV2, 1, 1); - const auto &if1 = sub_builder0.AddNode("if1", "If", 2, 1); - const auto &constant = sub_builder0.AddNode("constant", CONSTANTOP, 1, 1); - const auto &less = sub_builder0.AddNode("less", "Less", 2, 1); - const auto &netoutput0 = sub_builder0.AddNode("netoutput0", NETOUTPUT, 1, 0); - sub_builder0.AddDataEdge(var1, 0, if1, 0); - sub_builder0.AddDataEdge(data0, 0, if1, 1); - sub_builder0.AddDataEdge(constant, 0, less, 0); - sub_builder0.AddDataEdge(if1, 0, less, 1); - sub_builder0.AddDataEdge(less, 0, netoutput0, 0); - AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); - AttrUtils::SetInt(netoutput0->GetOpDesc()->MutableInputDesc(0), ATTR_NAME_PARENT_NODE_INDEX, 0); - subgraph = sub_builder0.GetGraph(); - AttrUtils::SetStr(subgraph, "_session_graph_id", "_session_graph_id"); - func->GetOpDesc()->AddSubgraphName("f"); - func->GetOpDesc()->SetSubgraphInstanceName(0, subgraph->GetName()); - AttrUtils::SetStr(func->GetOpDesc(), ATTR_NAME_FFTS_PLUS_SUB_GRAPH, "ffts_plus"); - graph->AddSubGraph(subgraph); - subgraph->SetParentNode(func); - subgraph->SetParentGraph(graph); - - auto sub_builder1 = ut::GraphBuilder("then_1"); - const auto &data1 = sub_builder1.AddNode("data1", DATA, 1, 1); - const auto &cast1 = sub_builder1.AddNode("cast1", "Cast", 1, 1); - const auto &square1 = sub_builder1.AddNode("square1", "Square", 1, 1); - const auto &netoutput1 = sub_builder1.AddNode("netoutput1", NETOUTPUT, 1, 0); - sub_builder1.AddDataEdge(data1, 0, cast1, 0); - sub_builder1.AddDataEdge(cast1, 0, square1, 0); - sub_builder1.AddDataEdge(square1, 0, netoutput1, 0); - AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); - AttrUtils::SetInt(netoutput1->GetOpDesc()->MutableInputDesc(0), ATTR_NAME_PARENT_NODE_INDEX, 0); - const auto &subgraph1 = sub_builder1.GetGraph(); - AttrUtils::SetStr(subgraph1, "_session_graph_id", "_session_graph_id"); - if1->GetOpDesc()->AddSubgraphName("then_branch"); - if1->GetOpDesc()->SetSubgraphInstanceName(0, subgraph1->GetName()); - graph->AddSubGraph(subgraph1); - subgraph1->SetParentNode(if1); - subgraph1->SetParentGraph(subgraph); - - auto sub_builder2 = ut::GraphBuilder("else_1"); - const auto &data2 = sub_builder2.AddNode("data2", DATA, 1, 1); - const auto &cast2 = sub_builder2.AddNode("cast2", "Cast", 1, 1); - const auto &square2 = sub_builder2.AddNode("square2", "Square", 1, 1); - const auto &var2 = sub_builder2.AddNode("var2", VARIABLEV2, 1, 1); - const auto &if2 = sub_builder2.AddNode("if2", "If", 2, 1); - const auto &netoutput2 = sub_builder2.AddNode("netoutput2", NETOUTPUT, 1, 0); - sub_builder2.AddDataEdge(data2, 0, cast2, 0); - sub_builder2.AddDataEdge(cast2, 0, square2, 0); - sub_builder2.AddDataEdge(square2, 0, if2, 1); - sub_builder2.AddDataEdge(var2, 0, if2, 0); - sub_builder2.AddDataEdge(if2, 0, netoutput2, 0); - AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); - AttrUtils::SetInt(netoutput2->GetOpDesc()->MutableInputDesc(0), ATTR_NAME_PARENT_NODE_INDEX, 0); - const auto &subgraph2 = sub_builder2.GetGraph(); - AttrUtils::SetStr(subgraph2, "_session_graph_id", "_session_graph_id"); - if1->GetOpDesc()->AddSubgraphName("else_branch"); - if1->GetOpDesc()->SetSubgraphInstanceName(1, subgraph2->GetName()); - graph->AddSubGraph(subgraph2); - subgraph2->SetParentNode(if1); - subgraph2->SetParentGraph(subgraph); - - auto sub_builder3 = ut::GraphBuilder("then_2"); - const auto &data3 = sub_builder3.AddNode("data3", DATA, 1, 1); - const auto &cast3 = sub_builder3.AddNode("cast3", "Cast", 1, 1); - const auto &square3 = sub_builder3.AddNode("square3", "Square", 1, 1); - const auto &netoutput3 = sub_builder3.AddNode("netoutput3", NETOUTPUT, 1, 0); - sub_builder3.AddDataEdge(data3, 0, cast3, 0); - sub_builder3.AddDataEdge(cast3, 0, square3, 0); - sub_builder1.AddDataEdge(square3, 0, netoutput3, 0); - AttrUtils::SetInt(data3->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); - AttrUtils::SetInt(netoutput3->GetOpDesc()->MutableInputDesc(0), ATTR_NAME_PARENT_NODE_INDEX, 0); - const auto &subgraph3 = sub_builder3.GetGraph();(AttrUtils::SetStr(subgraph3, "_session_graph_id", "_session_graph_id")); - if2->GetOpDesc()->AddSubgraphName("then_branch"); - if2->GetOpDesc()->SetSubgraphInstanceName(0, subgraph3->GetName()); - graph->AddSubGraph(subgraph3); - subgraph3->SetParentNode(if2); - subgraph3->SetParentGraph(subgraph2); - - auto sub_builder4 = ut::GraphBuilder("else_2"); - const auto &data4 = sub_builder4.AddNode("data4", DATA, 1, 1); - const auto &cast4 = sub_builder4.AddNode("cast4", "Cast", 1, 1); - const auto &square4 = sub_builder4.AddNode("square4", "Square", 1, 1); - const auto &netoutput4 = sub_builder4.AddNode("netoutput4", NETOUTPUT, 1, 0); - sub_builder4.AddDataEdge(data4, 0, cast4, 0); - sub_builder4.AddDataEdge(cast4, 0, square4, 0); - sub_builder4.AddDataEdge(square4, 0, netoutput4, 0); - AttrUtils::SetInt(data4->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); - AttrUtils::SetInt(netoutput4->GetOpDesc()->MutableInputDesc(0), ATTR_NAME_PARENT_NODE_INDEX, 0); - const auto &subgraph4 = sub_builder4.GetGraph(); - AttrUtils::SetStr(subgraph4, "_session_graph_id", "_session_graph_id"); - if2->GetOpDesc()->AddSubgraphName("else_branch"); - if2->GetOpDesc()->SetSubgraphInstanceName(1, subgraph4->GetName()); - graph->AddSubGraph(subgraph4); - subgraph4->SetParentNode(if2); - subgraph4->SetParentGraph(subgraph2); - - return; -} -} - -class UtestFftsGraphUtils : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -TEST_F(UtestFftsGraphUtils, LimitExceedPartition_invalid_input) { - ASSERT_EQ(FftsGraphUtils::GraphPartition(*ut::GraphBuilder("root").GetGraph(), nullptr, {}), GRAPH_SUCCESS); - const auto &calc_func = [](const NodePtr &node) { - return std::vector {1}; - }; - ASSERT_EQ(FftsGraphUtils::GraphPartition(*ut::GraphBuilder("root").GetGraph(), calc_func, {}), GRAPH_SUCCESS); -} - -TEST_F(UtestFftsGraphUtils, LimitExceedPartition_no_func_node) { - ComputeGraphPtr graph; - ComputeGraphPtr subgraph; - BuildGraphForSplit_without_func_node(graph, subgraph); - ASSERT_NE(graph, nullptr); - ASSERT_NE(subgraph, nullptr); - - const auto &calc_func = [](const NodePtr &node) { - return std::vector {1}; - }; - ASSERT_EQ(FftsGraphUtils::GraphPartition(*subgraph, calc_func, {8}), GRAPH_SUCCESS); - ASSERT_EQ(FftsGraphUtils::GraphPartition(*subgraph, calc_func, {3}), GRAPH_SUCCESS); - ASSERT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - - ASSERT_EQ(graph->GetAllSubgraphs().size(), 2); - std::vector subgraphs; - GetSubgraphsWithFilter(graph, - [](const ComputeGraphPtr &graph) { - const auto &parent_node = graph->GetParentNode(); - if ((parent_node == nullptr) || (parent_node->GetOpDesc() == nullptr)) { - return false; - } - return parent_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_PLUS_SUB_GRAPH); }, - subgraphs); - ASSERT_EQ(subgraphs.size(), 2); - for (const auto &subgraph : subgraphs) { - ASSERT_TRUE(subgraph != nullptr); - ASSERT_TRUE(IsAllNodeMatch(subgraph, - [](const NodePtr &node) { - return node->GetOpDesc()->HasAttr(ATTR_NAME_THREAD_SCOPE_ID); - }, false)); - } - - const auto &cast1 = FindNodeWithNamePattern(graph, "cast1", false); - ASSERT_NE(cast1, nullptr); - const auto &cast2 = FindNodeWithNamePattern(graph, "cast2", false); - ASSERT_NE(cast2, nullptr); - const auto &cast3 = FindNodeWithNamePattern(graph, "cast3", false); - ASSERT_NE(cast3, nullptr); - const auto &cast4 = FindNodeWithNamePattern(graph, "cast4", false); - ASSERT_NE(cast4, nullptr); - const auto &cast5 = FindNodeWithNamePattern(graph, "cast5", false); - ASSERT_NE(cast5, nullptr); - const auto &cast6 = FindNodeWithNamePattern(graph, "cast6", false); - ASSERT_NE(cast6, nullptr); - const auto &func1 = subgraphs[0]->GetParentNode(); - ASSERT_NE(func1, nullptr); - const auto &func2 = subgraphs[1]->GetParentNode(); - ASSERT_NE(func2, nullptr); - ASSERT_EQ(cast1->GetOwnerComputeGraph(), cast2->GetOwnerComputeGraph()); - ASSERT_EQ(cast1->GetOwnerComputeGraph(), cast3->GetOwnerComputeGraph()); - ASSERT_EQ(cast4->GetOwnerComputeGraph(), cast5->GetOwnerComputeGraph()); - ASSERT_EQ(cast4->GetOwnerComputeGraph(), cast6->GetOwnerComputeGraph()); - ASSERT_NE(cast1->GetOwnerComputeGraph(), cast4->GetOwnerComputeGraph()); - ASSERT_NE(cast1->GetOwnerComputeGraph(), graph); - ASSERT_NE(cast4->GetOwnerComputeGraph(), graph); - ASSERT_EQ(func1->GetOwnerComputeGraph(), graph); - ASSERT_EQ(func2->GetOwnerComputeGraph(), graph); - - const auto &input = graph->FindFirstNodeMatchType(DATA); - ASSERT_NE(input, nullptr); - const auto &output = graph->FindFirstNodeMatchType(NETOUTPUT); - ASSERT_NE(output, nullptr); - ASSERT_TRUE(input->GetOutDataAnchor(0)->IsLinkedWith(func1->GetInDataAnchor(0))); - ASSERT_TRUE(func1->GetOutDataAnchor(0)->IsLinkedWith(func2->GetInDataAnchor(0))); - ASSERT_TRUE(func2->GetOutDataAnchor(0)->IsLinkedWith(output->GetInDataAnchor(0))); - const auto &data1 = subgraphs[0]->FindFirstNodeMatchType(DATA); - ASSERT_NE(data1, nullptr); - const auto &netoutput1 = subgraphs[0]->FindFirstNodeMatchType(NETOUTPUT); - ASSERT_NE(netoutput1, nullptr); - ASSERT_TRUE(data1->GetOutDataAnchor(0)->IsLinkedWith(cast1->GetInDataAnchor(0))); - ASSERT_TRUE(cast1->GetOutDataAnchor(0)->IsLinkedWith(cast2->GetInDataAnchor(0))); - ASSERT_TRUE(cast2->GetOutDataAnchor(0)->IsLinkedWith(cast3->GetInDataAnchor(0))); - ASSERT_TRUE(cast3->GetOutDataAnchor(0)->IsLinkedWith(netoutput1->GetInDataAnchor(0))); - const auto &data2 = subgraphs[1]->FindFirstNodeMatchType(DATA); - ASSERT_NE(data2, nullptr); - const auto &netoutput2 = subgraphs[1]->FindFirstNodeMatchType(NETOUTPUT); - ASSERT_NE(netoutput2, nullptr); - ASSERT_TRUE(data2->GetOutDataAnchor(0)->IsLinkedWith(cast4->GetInDataAnchor(0))); - ASSERT_TRUE(cast4->GetOutDataAnchor(0)->IsLinkedWith(cast5->GetInDataAnchor(0))); - ASSERT_TRUE(cast5->GetOutDataAnchor(0)->IsLinkedWith(cast6->GetInDataAnchor(0))); - ASSERT_TRUE(cast6->GetOutDataAnchor(0)->IsLinkedWith(netoutput2->GetInDataAnchor(0))); -} - -TEST_F(UtestFftsGraphUtils, LimitExceedPartition_with_func_node) { - ComputeGraphPtr graph; - ComputeGraphPtr subgraph; - BuildGraphForSplit_with_func_node(graph, subgraph); - ASSERT_NE(graph, nullptr); - ASSERT_NE(subgraph, nullptr); - - const auto &calc_func = [](const NodePtr &node) { - return std::vector {1}; - }; - ASSERT_EQ(FftsGraphUtils::GraphPartition(*subgraph, calc_func, {8}), GRAPH_SUCCESS); - ASSERT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - ASSERT_EQ(graph->GetAllSubgraphs().size(), 9); -} - -//TEST_F(UtestFftsGraphUtils, ClipNodesFromGraph_no_func_node) { -// ComputeGraphPtr graph; -// ComputeGraphPtr subgraph; -// BuildGraphForSplit_without_func_node(graph, subgraph); -// ASSERT_NE(graph, nullptr); -// ASSERT_NE(subgraph, nullptr); -// -// ASSERT_EQ(FftsGraphUtils::GraphPartition(*subgraph, {}), GRAPH_SUCCESS); -// -// auto data1 = FindNodeWithNamePattern(subgraph, "data1"); -// ASSERT_NE(data1.get(), nullptr); -// ASSERT_EQ(FftsGraphUtils::GraphPartition(*subgraph, {data1}), GRAPH_SUCCESS); -// -// std::set unsupported_nodes; -// auto cast1 = FindNodeWithNamePattern(subgraph, "cast1"); -// ASSERT_NE(cast1, nullptr); -// unsupported_nodes.insert(cast1); -// auto cast4 = FindNodeWithNamePattern(subgraph, "cast4"); -// ASSERT_NE(cast4, nullptr); -// unsupported_nodes.insert(cast4); -// auto cast5 = FindNodeWithNamePattern(subgraph, "cast5"); -// ASSERT_NE(cast5, nullptr); -// unsupported_nodes.insert(cast5); -// ASSERT_EQ(FftsGraphUtils::GraphPartition(*subgraph, unsupported_nodes), GRAPH_SUCCESS); -// ASSERT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); -// -// ASSERT_EQ(graph->GetAllSubgraphs().size(), 3); -// const auto &parent_node = subgraph->GetParentNode(); -// ASSERT_NE(parent_node, nullptr); -// ASSERT_FALSE(parent_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_PLUS_SUB_GRAPH)); -// ASSERT_TRUE(IsAllNodeMatch(subgraph, -// [](const NodePtr &node) { -// return !node->GetOpDesc()->HasAttr(ATTR_NAME_THREAD_SCOPE_ID); -// })); -// -// std::vector subgraphs; -// GetSubgraphsWithFilter(graph, -// [](const ComputeGraphPtr &graph) { -// const auto &parent_node = graph->GetParentNode(); -// if ((parent_node == nullptr) || (parent_node->GetOpDesc() == nullptr)) { -// return false; -// } -// return parent_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_PLUS_SUB_GRAPH); }, -// subgraphs); -// ASSERT_EQ(subgraphs.size(), 2); -// for (const auto &subgraph : subgraphs) { -// ASSERT_TRUE(subgraph != nullptr); -// ASSERT_TRUE(IsAllNodeMatch(subgraph, -// [](const NodePtr &node) { -// return node->GetOpDesc()->HasAttr(ATTR_NAME_THREAD_SCOPE_ID); -// }, false)); -// } -// -// ASSERT_TRUE(IfNodeExist(subgraph, [](const NodePtr &node) { return node->GetName() == "data1"; })); -// ASSERT_TRUE(IfNodeExist(subgraph, [](const NodePtr &node) { return node->GetName() == "cast1"; })); -// ASSERT_TRUE(IfNodeExist(subgraph, [](const NodePtr &node) { return node->GetName() == "cast4"; })); -// ASSERT_TRUE(IfNodeExist(subgraph, [](const NodePtr &node) { return node->GetName() == "cast5"; })); -// ASSERT_TRUE(IfNodeExist(subgraph, [](const NodePtr &node) { return node->GetName() == "netoutput"; })); -// ASSERT_FALSE(IfNodeExist(subgraph, [](const NodePtr &node) { return node->GetName() == "cast2"; })); -// ASSERT_FALSE(IfNodeExist(subgraph, [](const NodePtr &node) { return node->GetName() == "cast3"; })); -// ASSERT_FALSE(IfNodeExist(subgraph, [](const NodePtr &node) { return node->GetName() == "cast6"; })); -// -// const auto &cast2 = FindNodeWithNamePattern(graph, "cast2", false); -// ASSERT_NE(cast2, nullptr); -// const auto &cast3 = FindNodeWithNamePattern(graph, "cast3", false); -// ASSERT_NE(cast3, nullptr); -// const auto &cast6 = FindNodeWithNamePattern(graph, "cast6", false); -// ASSERT_NE(cast6, nullptr); -// ASSERT_EQ(cast2->GetOwnerComputeGraph(), cast3->GetOwnerComputeGraph()); -// ASSERT_NE(cast2->GetOwnerComputeGraph(), cast6->GetOwnerComputeGraph()); -//} - -TEST_F(UtestFftsGraphUtils, ClipNodesFromGraph_with_func_node) { - ComputeGraphPtr graph; - ComputeGraphPtr subgraph; - BuildGraphForSplit_with_func_node(graph, subgraph); - ASSERT_NE(graph, nullptr); - ASSERT_NE(subgraph, nullptr); - - std::set unsupported_nodes; - const auto &cast1 = FindNodeWithNamePattern(graph, "cast1", false); - ASSERT_NE(cast1, nullptr); - unsupported_nodes.insert(cast1); - const auto &cast3 = FindNodeWithNamePattern(graph, "cast3", false); - ASSERT_NE(cast3, nullptr); - unsupported_nodes.insert(cast3); - ASSERT_EQ(FftsGraphUtils::GraphPartition(*subgraph, unsupported_nodes), GRAPH_SUCCESS); - ASSERT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - - ASSERT_EQ(graph->GetAllSubgraphs().size(), 10); - const auto &parent_node = subgraph->GetParentNode(); - ASSERT_NE(parent_node, nullptr); - ASSERT_FALSE(parent_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_PLUS_SUB_GRAPH)); - ASSERT_TRUE(IsAllNodeMatch(subgraph, - [](const NodePtr &node) { - return !node->GetOpDesc()->HasAttr(ATTR_NAME_THREAD_SCOPE_ID); - })); - - std::vector subgraphs; - GetSubgraphsWithFilter(graph, - [](const ComputeGraphPtr &graph) { - const auto &parent_node = graph->GetParentNode(); - if ((parent_node == nullptr) || (parent_node->GetOpDesc() == nullptr)) { - return false; - } - return parent_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_PLUS_SUB_GRAPH); }, - subgraphs); - ASSERT_EQ(subgraphs.size(), 5); - for (const auto &subgraph : subgraphs) { - ASSERT_TRUE(subgraph != nullptr); - ASSERT_TRUE(IsAllNodeMatch(subgraph, - [](const NodePtr &node) { - return node->GetOpDesc()->HasAttr(ATTR_NAME_THREAD_SCOPE_ID); - }, false)); - } - - ASSERT_TRUE(IfNodeExist(subgraph, [](const NodePtr &node) { return node->GetName() == "var1"; })); - ASSERT_TRUE(IfNodeExist(subgraph, [](const NodePtr &node) { return node->GetName() == "data0"; })); - ASSERT_TRUE(IfNodeExist(subgraph, [](const NodePtr &node) { return node->GetName() == "constant"; })); - ASSERT_TRUE(IfNodeExist(subgraph, [](const NodePtr &node) { return node->GetName() == "if1"; })); - ASSERT_FALSE(IfNodeExist(subgraph, [](const NodePtr &node) { return node->GetName() == "less"; })); - ASSERT_TRUE(IfNodeExist(subgraph, [](const NodePtr &node) { return node->GetName() == "netoutput0"; })); - - const auto &if1 = FindNodeWithNamePattern(subgraph, "if1"); - ASSERT_NE(if1, nullptr); - std::vector if1_subgraphs; - ASSERT_EQ(NodeUtils::GetDirectSubgraphs(if1, if1_subgraphs), GRAPH_SUCCESS); - ASSERT_EQ(if1_subgraphs.size(), 2); - const auto &then1 = if1_subgraphs[0]; - ASSERT_NE(then1, nullptr); - ASSERT_TRUE(IfNodeExist(then1, [](const NodePtr &node) { return node->GetName() == "data1"; })); - ASSERT_TRUE(IfNodeExist(then1, [](const NodePtr &node) { return node->GetName() == "cast1"; })); - ASSERT_FALSE(IfNodeExist(then1, [](const NodePtr &node) { return node->GetName() == "square1"; })); - ASSERT_TRUE(IfNodeExist(then1, [](const NodePtr &node) { return node->GetName() == "netoutput1"; })); - const auto &else1 = if1_subgraphs[1]; - ASSERT_NE(else1, nullptr); - ASSERT_TRUE(IfNodeExist(else1, [](const NodePtr &node) { return node->GetName() == "data2"; })); - ASSERT_TRUE(IfNodeExist(else1, [](const NodePtr &node) { return node->GetName() == "var2"; })); - ASSERT_FALSE(IfNodeExist(else1, [](const NodePtr &node) { return node->GetName() == "cast2"; })); - ASSERT_FALSE(IfNodeExist(else1, [](const NodePtr &node) { return node->GetName() == "square2"; })); - ASSERT_TRUE(IfNodeExist(else1, [](const NodePtr &node) { return node->GetName() == "if2"; })); - ASSERT_TRUE(IfNodeExist(else1, [](const NodePtr &node) { return node->GetName() == "netoutput2"; })); - - const auto &if2 = FindNodeWithNamePattern(else1, "if2"); - ASSERT_NE(if2, nullptr); - std::vector if2_subgraphs; - ASSERT_EQ(NodeUtils::GetDirectSubgraphs(if2, if2_subgraphs), GRAPH_SUCCESS); - ASSERT_EQ(if2_subgraphs.size(), 2); - const auto &then2 = if2_subgraphs[0]; - ASSERT_NE(then2, nullptr); - ASSERT_TRUE(IfNodeExist(then2, [](const NodePtr &node) { return node->GetName() == "data3"; })); - ASSERT_TRUE(IfNodeExist(then2, [](const NodePtr &node) { return node->GetName() == "cast3"; })); - ASSERT_FALSE(IfNodeExist(then2, [](const NodePtr &node) { return node->GetName() == "square3"; })); - ASSERT_TRUE(IfNodeExist(then2, [](const NodePtr &node) { return node->GetName() == "netoutput3"; })); - const auto &else2 = if2_subgraphs[1]; - ASSERT_NE(else2, nullptr); - ASSERT_TRUE(IfNodeExist(else2, [](const NodePtr &node) { return node->GetName() == "data4"; })); - ASSERT_FALSE(IfNodeExist(else2, [](const NodePtr &node) { return node->GetName() == "cast4"; })); - ASSERT_FALSE(IfNodeExist(else2, [](const NodePtr &node) { return node->GetName() == "square4"; })); - ASSERT_TRUE(IfNodeExist(else2, [](const NodePtr &node) { return node->GetName() == "netoutput4"; })); -} - -TEST_F(UtestFftsGraphUtils, CheckRecursionDepth) { - std::map> node_value; - std::map> graph_value; - ComputeGraphPtr graph = nullptr; - ASSERT_EQ(FftsGraphUtils::Calculate(graph, nullptr, node_value, graph_value, 10), GRAPH_FAILED); - ASSERT_EQ(FftsGraphUtils::Calculate(graph, nullptr, node_value, graph_value, 9), PARAM_INVALID); - ASSERT_EQ(FftsGraphUtils::PartitionGraphWithLimit(nullptr, node_value, graph_value, {}, 10), GRAPH_FAILED); - ASSERT_EQ(FftsGraphUtils::PartitionGraphWithLimit(nullptr, node_value, graph_value, {}, 9), PARAM_INVALID); -} - -TEST_F(UtestFftsGraphUtils, SplitSubgraph_nullptr_graph) { - std::vector>> split_nodes; - split_nodes.emplace_back(std::make_pair(true, std::set{ nullptr })); - ASSERT_EQ(FftsGraphUtils::SplitSubgraph(nullptr, split_nodes), GRAPH_FAILED); -} - -TEST_F(UtestFftsGraphUtils, SetAttrForFftsPlusSubgraph_nullptr_parent_node) { - auto builder = ut::GraphBuilder(""); - ASSERT_EQ(FftsGraphUtils::SetAttrForFftsPlusSubgraph(builder.GetGraph()), GRAPH_FAILED); -} - -TEST_F(UtestFftsGraphUtils, Calculate_nullptr_node) { - NodePtr node = nullptr; - std::map> node_value; - std::map> graph_value; - ASSERT_TRUE(FftsGraphUtils::Calculate(node, nullptr, node_value, graph_value, 1).empty()); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/format_refiner_unittes.cc b/tests/ut/graph/testcase/format_refiner_unittes.cc deleted file mode 100644 index fe2861bbe7066f42abbc8c34a2cdd29590fd7e79..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/format_refiner_unittes.cc +++ /dev/null @@ -1,1033 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include -#include "graph_builder_utils.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/refiner/format_refiner.h" -#include "graph/ref_relation.h" - -namespace ge { -class UTEST_FormatRefiner : public testing::Test { - protected: - void SetUp() - { - char* which_op = getenv("WHICH_OP"); - if (which_op != nullptr) { - is_set_env = true; - return; - } - (void)setenv("WHICH_OP", "GEOP", 0); - } - - void TearDown() - { - if (!is_set_env) { - (void)unsetenv("WHICH_OP"); - } - } - -private: - bool is_set_env{false}; - -}; - -namespace { -const string kIsGraphInferred = "_is_graph_inferred"; - -void SetFirstInferFlag(ComputeGraphPtr graph, bool is_first) { - (void)AttrUtils::SetBool(graph, kIsGraphInferred, !is_first); -} - -/* - * netoutput1 - * | - * relu1 - * | - * conv1 - * / \ - * var1 var2 - */ -ut::GraphBuilder BuildGraph1() { - auto builder = ut::GraphBuilder("g1"); - auto var1 = builder.AddNode("var1", "Variable", 0, 1); - auto var2 = builder.AddNode("var2", "Variable", 0, 1); - auto conv1 = builder.AddNode("conv1", "Conv2D", 2, 1); - auto conv_data = conv1->GetOpDesc()->GetInputDesc(0); - conv_data.SetFormat(FORMAT_NCHW); - conv_data.SetShape(GeShape(std::vector({1, 3, 224, 224}))); - conv1->GetOpDesc()->UpdateInputDesc(0, conv_data); - auto weight = conv1->GetOpDesc()->GetInputDesc(1); - weight.SetFormat(FORMAT_HWCN); - weight.SetShape(GeShape(std::vector({1, 1, 3, 256}))); - conv1->GetOpDesc()->UpdateInputDesc(1, weight); - auto conv_out = conv1->GetOpDesc()->GetOutputDesc(0); - conv_out.SetFormat(FORMAT_NCHW); - conv_out.SetShape(GeShape(std::vector({1, 256, 224, 224}))); - conv1->GetOpDesc()->UpdateOutputDesc(0, conv_out); - auto relu1 = builder.AddNode("relu1", "Relu", 1, 1); - auto netoutput1 = builder.AddNode("netoutput1", "NetOutput", 1, 0); - - builder.AddDataEdge(var1, 0, conv1, 0); - builder.AddDataEdge(var2, 0, conv1, 1); - builder.AddDataEdge(conv1, 0, relu1, 0); - builder.AddDataEdge(relu1, 0, netoutput1, 0); - SetFirstInferFlag(builder.GetGraph(), true); - return builder; -} - -/* - * netoutput1 - * | - * relu1 - * | - * bn1 ----------------- - * | \ \ \ \ - * conv1 var3 var4 var5 var6 - * | \ - * var1 var2 - */ -ut::GraphBuilder BuildGraph2() { - auto builder = ut::GraphBuilder("g2"); - auto var1 = builder.AddNode("var1", "Variable", 0, 1); - auto var2 = builder.AddNode("var2", "Variable", 0, 1); - auto var3 = builder.AddNode("var3", "Variable", 0, 1); - auto var4 = builder.AddNode("var4", "Variable", 0, 1); - auto var5 = builder.AddNode("var5", "Variable", 0, 1); - auto var6 = builder.AddNode("var6", "Variable", 0, 1); - auto conv1 = builder.AddNode("conv1", "Conv2D", 2, 1); - auto conv_data = conv1->GetOpDesc()->GetInputDesc(0); - conv_data.SetFormat(FORMAT_NHWC); - conv_data.SetShape(GeShape(std::vector({1, 3, 224, 224}))); - conv1->GetOpDesc()->UpdateInputDesc(0, conv_data); - auto weight = conv1->GetOpDesc()->GetInputDesc(1); - weight.SetFormat(FORMAT_HWCN); - weight.SetShape(GeShape(std::vector({1, 1, 3, 256}))); - conv1->GetOpDesc()->UpdateInputDesc(1, weight); - auto conv_out = conv1->GetOpDesc()->GetOutputDesc(0); - conv_out.SetFormat(FORMAT_NHWC); - conv_out.SetShape(GeShape(std::vector({1, 256, 224, 224}))); - conv1->GetOpDesc()->UpdateOutputDesc(0, conv_out); - auto bn1 = builder.AddNode("bn1", "BatchNorm", 5, 1); - auto relu1 = builder.AddNode("relu1", "Relu", 1, 1); - auto netoutput1 = builder.AddNode("netoutput1", "NetOutput", 1, 0); - - builder.AddDataEdge(var1, 0, conv1, 0); - builder.AddDataEdge(var2, 0, conv1, 1); - builder.AddDataEdge(conv1, 0, bn1, 0); - builder.AddDataEdge(var3, 0, bn1, 1); - builder.AddDataEdge(var4, 0, bn1, 2); - builder.AddDataEdge(var5, 0, bn1, 3); - builder.AddDataEdge(var6, 0, bn1, 4); - builder.AddDataEdge(bn1, 0, relu1, 0); - builder.AddDataEdge(relu1, 0, netoutput1, 0); - SetFirstInferFlag(builder.GetGraph(), true); - return builder; -} - -/* - * netoutput1 - * | - * conv2 - * | \ - * relu1 var3 - * | - * conv1 - * / \ - * var1 var2 - */ -ut::GraphBuilder BuildGraph3() { - auto builder = ut::GraphBuilder("g3"); - auto var1 = builder.AddNode("var1", "Variable", 0, 1); - auto var2 = builder.AddNode("var2", "Variable", 0, 1); - auto var3 = builder.AddNode("var3", "Variable", 0, 1); - auto conv1 = builder.AddNode("conv1", "Conv2D", 2, 1); - auto conv_data = conv1->GetOpDesc()->GetInputDesc(0); - conv_data.SetFormat(FORMAT_NCHW); - conv_data.SetShape(GeShape(std::vector({1, 3, 224, 224}))); - conv1->GetOpDesc()->UpdateInputDesc(0, conv_data); - auto weight = conv1->GetOpDesc()->GetInputDesc(1); - weight.SetFormat(FORMAT_HWCN); - weight.SetShape(GeShape(std::vector({1, 1, 3, 256}))); - conv1->GetOpDesc()->UpdateInputDesc(1, weight); - auto conv_out = conv1->GetOpDesc()->GetOutputDesc(0); - conv_out.SetFormat(FORMAT_NCHW); - conv_out.SetShape(GeShape(std::vector({1, 256, 224, 224}))); - conv1->GetOpDesc()->UpdateOutputDesc(0, conv_out); - auto relu1 = builder.AddNode("relu1", "Relu", 1, 1); - auto conv2 = builder.AddNode("conv2", "Conv2D", 2, 1); - conv_data = conv2->GetOpDesc()->GetInputDesc(0); - conv_data.SetFormat(FORMAT_NHWC); - conv_data.SetShape(GeShape(std::vector({1, 3, 224, 224}))); - conv2->GetOpDesc()->UpdateInputDesc(0, conv_data); - weight = conv2->GetOpDesc()->GetInputDesc(1); - weight.SetFormat(FORMAT_HWCN); - weight.SetShape(GeShape(std::vector({1, 1, 3, 256}))); - conv2->GetOpDesc()->UpdateInputDesc(1, weight); - conv_out = conv2->GetOpDesc()->GetOutputDesc(0); - conv_out.SetFormat(FORMAT_NHWC); - conv_out.SetShape(GeShape(std::vector({1, 256, 224, 224}))); - conv2->GetOpDesc()->UpdateOutputDesc(0, conv_out); - auto netoutput1 = builder.AddNode("netoutput1", "NetOutput", 1, 0); - - builder.AddDataEdge(var1, 0, conv1, 0); - builder.AddDataEdge(var2, 0, conv1, 1); - builder.AddDataEdge(conv1, 0, relu1, 0); - builder.AddDataEdge(relu1, 0, conv2, 0); - builder.AddDataEdge(var3, 0, conv2, 1); - builder.AddDataEdge(conv2, 0, netoutput1, 0); - SetFirstInferFlag(builder.GetGraph(), true); - return builder; -} - -/* - * netoutput1 - * | - * conv2 - * | \ - * relu1 var3 - * | - * bn1 - * | - * conv1 - * / \ - * var1 var2 - */ -ut::GraphBuilder BuildGraph4() { - auto builder = ut::GraphBuilder("g4"); - auto var1 = builder.AddNode("var1", "Variable", 0, 1); - auto var2 = builder.AddNode("var2", "Variable", 0, 1); - auto var3 = builder.AddNode("var3", "Variable", 0, 1); - auto conv1 = builder.AddNode("conv1", "Conv2D", 2, 1); - auto conv_data = conv1->GetOpDesc()->GetInputDesc(0); - conv_data.SetFormat(FORMAT_NHWC); - conv_data.SetShape(GeShape(std::vector({1, 3, 224, 224}))); - conv1->GetOpDesc()->UpdateInputDesc(0, conv_data); - auto weight = conv1->GetOpDesc()->GetInputDesc(1); - weight.SetFormat(FORMAT_HWCN); - weight.SetShape(GeShape(std::vector({1, 1, 3, 256}))); - conv1->GetOpDesc()->UpdateInputDesc(1, weight); - auto conv_out = conv1->GetOpDesc()->GetOutputDesc(0); - conv_out.SetFormat(FORMAT_NHWC); - conv_out.SetShape(GeShape(std::vector({1, 256, 224, 224}))); - conv1->GetOpDesc()->UpdateOutputDesc(0, conv_out); - auto bn1 = builder.AddNode("bn1", "BatchNorm", 1, 1); - auto relu1 = builder.AddNode("relu1", "Relu", 1, 1); - auto conv2 = builder.AddNode("conv2", "Conv2D", 2, 1); - conv_data = conv2->GetOpDesc()->GetInputDesc(0); - conv_data.SetFormat(FORMAT_NHWC); - conv_data.SetShape(GeShape(std::vector({1, 3, 224, 224}))); - conv2->GetOpDesc()->UpdateInputDesc(0, conv_data); - weight = conv2->GetOpDesc()->GetInputDesc(1); - weight.SetFormat(FORMAT_HWCN); - weight.SetShape(GeShape(std::vector({1, 1, 3, 256}))); - conv2->GetOpDesc()->UpdateInputDesc(1, weight); - conv_out = conv2->GetOpDesc()->GetOutputDesc(0); - conv_out.SetFormat(FORMAT_NHWC); - conv_out.SetShape(GeShape(std::vector({1, 256, 224, 224}))); - conv2->GetOpDesc()->UpdateOutputDesc(0, conv_out); - auto netoutput1 = builder.AddNode("netoutput1", "NetOutput", 1, 0); - - builder.AddDataEdge(var1, 0, conv1, 0); - builder.AddDataEdge(var2, 0, conv1, 1); - builder.AddDataEdge(conv1, 0, bn1, 0); - builder.AddDataEdge(bn1, 0, relu1, 0); - builder.AddDataEdge(relu1, 0, conv2, 0); - builder.AddDataEdge(var3, 0, conv2, 1); - builder.AddDataEdge(conv2, 0, netoutput1, 0); - SetFirstInferFlag(builder.GetGraph(), true); - return builder; -} - -/* - * netoutput1 - * | - * apply1 - * / \ - * relug1 --> bng1 \ - * \ / | \ \ - * relu1 | | \ - * \| | | - * | | | - * bn1 | | - * \ | | - * conv1 | - * / \| - * / | - * data1 var1 - */ -ut::GraphBuilder BuilderGraph5() { - auto builder = ut::GraphBuilder("g5"); - auto data1 = builder.AddNode("data1", "Data", 0, 1); - auto var1 = builder.AddNode("var1", "Variable", 0, 1); - auto conv1 = builder.AddNode("conv1", "Conv2D", 2, 1); - auto conv_data = conv1->GetOpDesc()->GetInputDesc(0); - conv_data.SetFormat(FORMAT_NHWC); - conv_data.SetShape(GeShape(std::vector({1, 3, 224, 224}))); - conv1->GetOpDesc()->UpdateInputDesc(0, conv_data); - auto weight = conv1->GetOpDesc()->GetInputDesc(1); - weight.SetFormat(FORMAT_HWCN); - weight.SetShape(GeShape(std::vector({1, 1, 3, 256}))); - conv1->GetOpDesc()->UpdateInputDesc(1, weight); - auto conv_out = conv1->GetOpDesc()->GetOutputDesc(0); - conv_out.SetFormat(FORMAT_NHWC); - conv_out.SetShape(GeShape(std::vector({1, 256, 224, 224}))); - conv1->GetOpDesc()->UpdateOutputDesc(0, conv_out); - auto bn1 = builder.AddNode("bn1", "BatchNorm", 1, 1); - auto relu1 = builder.AddNode("relu1", "Relu", 1, 1); - auto relug1 = builder.AddNode("relug1", "ReluGrad", 1, 1); - auto bng1 = builder.AddNode("bng1", "BatchNormGrad", 4, 1); - auto apply1 = builder.AddNode("apply1", "ApplyMomentum", 2, 1); - auto netoutput1 = builder.AddNode("netoutput1", "NetOutput", 1, 0); - - builder.AddDataEdge(data1, 0, conv1, 0); - builder.AddDataEdge(var1, 0, conv1, 1); - builder.AddDataEdge(var1, 0, apply1, 1); - builder.AddDataEdge(conv1, 0, bn1, 0); - builder.AddDataEdge(conv1, 0, bng1, 3); - builder.AddDataEdge(bn1, 0, relu1, 0); - builder.AddDataEdge(bn1, 0, bng1, 2); - builder.AddDataEdge(relu1, 0, relug1, 0); - builder.AddDataEdge(relu1, 0, bng1, 1); - builder.AddDataEdge(relug1, 0, bng1, 0); - builder.AddDataEdge(bng1, 0, apply1, 0); - builder.AddDataEdge(apply1, 0, netoutput1, 0); - SetFirstInferFlag(builder.GetGraph(), true); - return builder; -} -/* - * netoutput1 - * | - * AddN - * / \ \ - * L2Loss GatherV2 Constant - * / \ - * Data1 Data2 - * - * - */ -ut::GraphBuilder BuildGraph6() { - auto builder = ut::GraphBuilder("g1"); - auto data1 = builder.AddNode("data1", "Data", 1, 1); - auto data2 = builder.AddNode("data2", "Data", 1, 1); - auto loss = builder.AddNode("loss", "L2Loss", 1, 1); - auto gather = builder.AddNode("gather", "GatherV2", 1, 1); - auto addN = builder.AddNode("addN", "AddN", 3, 1); - auto netoutput = builder.AddNode("netoutput", "NetOutput", 1, 0); - auto constant = builder.AddNode("constant", "Constant", 0, 1); - - auto data1_input = data1->GetOpDesc()->GetInputDesc(0); - data1_input.SetFormat(FORMAT_HWCN); - data1->GetOpDesc()->UpdateInputDesc(0, data1_input); - auto data1_output = data1->GetOpDesc()->GetOutputDesc(0); - data1_output.SetFormat(FORMAT_HWCN); - data1->GetOpDesc()->UpdateOutputDesc(0, data1_output); - - auto net_input = netoutput->GetOpDesc()->GetInputDesc(0); - net_input.SetFormat(FORMAT_NCHW); - netoutput->GetOpDesc()->UpdateInputDesc(0, net_input); - - auto data2_input = data2->GetOpDesc()->GetInputDesc(0); - data2_input.SetFormat(FORMAT_HWCN); - data2->GetOpDesc()->UpdateInputDesc(0, data2_input); - auto data2_output = data2->GetOpDesc()->GetOutputDesc(0); - data2_output.SetFormat(FORMAT_HWCN); - data2->GetOpDesc()->UpdateOutputDesc(0, data2_output); - - builder.AddDataEdge(data1, 0, loss, 0); - builder.AddDataEdge(data2, 0, gather, 0); - builder.AddDataEdge(loss, 0, addN, 0); - builder.AddDataEdge(gather, 0, addN, 1); - builder.AddDataEdge(constant, 0, addN, 2); - builder.AddDataEdge(addN, 0, netoutput, 0); - - SetFirstInferFlag(builder.GetGraph(), true); - - return builder; -} - -/* - * data2 - * | - * data1 relu - * | - * reshape - * \ / - * conv - * | - * netoutput - */ - -ut::GraphBuilder BuildGraph8() { - auto builder = ut::GraphBuilder("g8"); - - auto data1 = builder.AddNode("data1", "Data", 1, 1); - auto data2 = builder.AddNode("data2", "Data", 1, 1); - auto relu = builder.AddNode("relu", "Relu", 1, 1); - auto reshape = builder.AddNode("reshape", "Reshape", 1, 1); - auto conv = builder.AddNode("conv", "Conv2D", 2, 1); - auto netoutput = builder.AddNode("netoutput", "NetOutput", 1, 0); - - auto reshape_data = reshape->GetOpDesc()->GetInputDesc(0); - reshape_data.SetFormat(FORMAT_ND); - reshape_data.SetOriginFormat(FORMAT_ND); - reshape_data.SetShape(GeShape(std::vector({224, 224}))); - reshape_data.SetShape(GeShape(std::vector({224, 224}))); - reshape->GetOpDesc()->UpdateInputDesc(0, reshape_data); - reshape->GetOpDesc()->UpdateOutputDesc(0, reshape_data); - - auto conv_data = conv->GetOpDesc()->GetInputDesc(0); - conv_data.SetFormat(FORMAT_NHWC); - conv_data.SetShape(GeShape(std::vector({1, 3, 224, 224}))); - conv->GetOpDesc()->UpdateInputDesc(0, conv_data); - auto weight = conv->GetOpDesc()->GetInputDesc(1); - weight.SetFormat(FORMAT_HWCN); - weight.SetShape(GeShape(std::vector({1, 1, 3, 256}))); - conv->GetOpDesc()->UpdateInputDesc(1, weight); - auto conv_out = conv->GetOpDesc()->GetOutputDesc(0); - conv_out.SetFormat(FORMAT_NHWC); - conv_out.SetShape(GeShape(std::vector({1, 256, 224, 224}))); - conv->GetOpDesc()->UpdateOutputDesc(0, conv_out); - - builder.AddDataEdge(data1, 0, conv, 0); - builder.AddDataEdge(data2, 0, relu, 0); - builder.AddDataEdge(relu, 0, reshape, 0); - builder.AddDataEdge(reshape, 0, conv, 1); - builder.AddDataEdge(conv, 0, netoutput, 0); - SetFirstInferFlag(builder.GetGraph(), true); - return builder; -} - -/* - * netoutput1 - * | - * BiasAdd - * | - * square - * | - * var - */ -ut::GraphBuilder BuildGraph9() { - auto builder = ut::GraphBuilder("g9"); - auto var = builder.AddNode("var", "Variable", 0, 1); - auto square = builder.AddNode("square", "Square", 1, 1); - auto biasadd = builder.AddNode("biasadd", "BiasAdd", 1, 1); - auto netoutput1 = builder.AddNode("netoutput1", "NetOutput", 1, 0); - - auto biasadd_data = biasadd->GetOpDesc()->GetInputDesc(0); - biasadd_data.SetFormat(FORMAT_NHWC); - biasadd_data.SetOriginFormat(FORMAT_NHWC); - biasadd_data.SetShape(GeShape(std::vector({1, 3, 3,224, 224}))); - biasadd->GetOpDesc()->UpdateInputDesc(0, biasadd_data); - auto biasadd_out = biasadd->GetOpDesc()->GetOutputDesc(0); - biasadd_out.SetFormat(FORMAT_NHWC); - biasadd_out.SetOriginFormat(FORMAT_NHWC); - biasadd_out.SetShape(GeShape(std::vector({1, 3, 256, 224, 224}))); - biasadd->GetOpDesc()->UpdateOutputDesc(0, biasadd_out); - - - builder.AddDataEdge(var, 0, square, 0); - builder.AddDataEdge(square, 0, biasadd, 0); - builder.AddDataEdge(biasadd, 0, netoutput1, 0); - SetFirstInferFlag(builder.GetGraph(), true); - return builder; -} - -/* - * netoutput1 - * | \ - * sub variable - * / \ - * data1 data2 - */ -ComputeGraphPtr BuildSubGraphWithVariable(const std::string name,ge::Format to_be_set_format = FORMAT_ND) { - ut::GraphBuilder builder(name); - auto data1 = builder.AddNode(name + "data1", "Data", 1, 1); - auto data2 = builder.AddNode(name + "data2", "Data", 1, 1); - auto sub = builder.AddNode(name + "sub", "Sub", 2, 1, to_be_set_format); - auto variable = builder.AddNode("variable", "Variable", 0, 1); - auto netoutput = builder.AddNode(name + "netoutput", "NetOutput", 2, 2); - - AttrUtils::SetInt(data1->GetOpDesc(), "_parent_node_index", static_cast(0)); - AttrUtils::SetInt(data2->GetOpDesc(), "_parent_node_index", static_cast(1)); - AttrUtils::SetInt(netoutput->GetOpDesc()->MutableInputDesc(0), "_parent_node_index", static_cast(0)); - AttrUtils::SetInt(netoutput->GetOpDesc()->MutableInputDesc(1), "_parent_node_index", static_cast(1)); - - - builder.AddDataEdge(data1, 0, sub, 0); - builder.AddDataEdge(data2, 0, sub, 1); - builder.AddDataEdge(sub, 0, netoutput, 0); - builder.AddDataEdge(data2, 0, variable, 0); - builder.AddDataEdge(variable, 0, netoutput, 1); - - - return builder.GetGraph(); -} - -/* - * netoutput1 - * | \ - * sub relu - * / \ / - * data1 data2 - */ -ComputeGraphPtr BuildSubGraph(const std::string name,ge::Format to_be_set_format = FORMAT_ND) { - ut::GraphBuilder builder(name); - auto data1 = builder.AddNode(name + "data1", "Data", 1, 1); - auto data2 = builder.AddNode(name + "data2", "Data", 1, 1); - auto sub = builder.AddNode(name + "sub", "Sub", 2, 1, to_be_set_format); - auto relu = builder.AddNode(name + "relu", "Relu", 1, 1); - auto netoutput = builder.AddNode(name + "netoutput", "NetOutput", 2, 2); - - AttrUtils::SetInt(data1->GetOpDesc(), "_parent_node_index", static_cast(0)); - AttrUtils::SetInt(data2->GetOpDesc(), "_parent_node_index", static_cast(1)); - AttrUtils::SetInt(netoutput->GetOpDesc()->MutableInputDesc(0), "_parent_node_index", static_cast(0)); - AttrUtils::SetInt(netoutput->GetOpDesc()->MutableInputDesc(1), "_parent_node_index", static_cast(1)); - - - builder.AddDataEdge(data1, 0, sub, 0); - builder.AddDataEdge(data2, 0, sub, 1); - builder.AddDataEdge(sub, 0, netoutput, 0); - builder.AddDataEdge(data2, 0, relu, 0); - builder.AddDataEdge(relu, 0, netoutput, 1); - - - return builder.GetGraph(); -} -/* - * netoutput relu - * | / - * if - * / \ - * data1 data2 - */ -ComputeGraphPtr BuildMainGraphWithIf(string anchor_graph) { - ut::GraphBuilder builder("main_graph"); - auto to_be_set_format = FORMAT_ND; - auto to_be_set_format_of_sub = FORMAT_ND; - if (anchor_graph == "main") { - to_be_set_format = FORMAT_NHWC; - to_be_set_format_of_sub = FORMAT_ND; - } else { - to_be_set_format = FORMAT_ND; - to_be_set_format_of_sub = FORMAT_NHWC; - } - auto data1 = builder.AddNode("data1", "Data", 1, 1, to_be_set_format); - auto data2 = builder.AddNode("data2", "Data", 1, 1, to_be_set_format); - auto if1 = builder.AddNode("if", "If", 2, 2); - auto netoutput1 = builder.AddNode("netoutput", "NetOutput", 2, 2); - auto relu = builder.AddNode("relu", "Relu", 1, 1); - - builder.AddDataEdge(data1, 0, if1, 0); - builder.AddDataEdge(data2, 0, if1, 1); - builder.AddDataEdge(if1, 0, netoutput1, 0); - builder.AddDataEdge(if1, 1, relu, 0); - builder.AddDataEdge(relu, 0, netoutput1, 1); - - auto main_graph = builder.GetGraph(); - - auto sub1 = BuildSubGraph("sub1", to_be_set_format_of_sub); - sub1->SetParentGraph(main_graph); - sub1->SetParentNode(main_graph->FindNode("if")); - main_graph->FindNode("if")->GetOpDesc()->AddSubgraphName("sub1"); - main_graph->FindNode("if")->GetOpDesc()->SetSubgraphInstanceName(0, "sub1"); - main_graph->AddSubgraph("sub1", sub1); - - auto sub2 = BuildSubGraph("sub2"); - sub2->SetParentGraph(main_graph); - sub2->SetParentNode(main_graph->FindNode("if")); - main_graph->FindNode("if")->GetOpDesc()->AddSubgraphName("sub2"); - main_graph->FindNode("if")->GetOpDesc()->SetSubgraphInstanceName(1, "sub2"); - main_graph->AddSubgraph("sub2", sub2); - - return main_graph; -} - -/* - * netoutput relu - * | / - * if - * / \ - * data1 data2 - */ -ComputeGraphPtr BuildMainGraphWithIfAndVariable(string anchor_graph, const Format if_format = FORMAT_NCHW) { - ut::GraphBuilder builder("main_graph"); - auto to_be_set_format = FORMAT_ND; - auto to_be_set_format_of_sub = FORMAT_ND; - if (anchor_graph == "main") { - to_be_set_format = FORMAT_NHWC; - to_be_set_format_of_sub = FORMAT_ND; - } else { - to_be_set_format = FORMAT_ND; - to_be_set_format_of_sub = FORMAT_NHWC; - } - auto data1 = builder.AddNode("data1", "Data", 1, 1, to_be_set_format); - auto data2 = builder.AddNode("data2", "Data", 1, 1, to_be_set_format); - auto if1 = builder.AddNode("if", "If", 2, 2, if_format); - auto netoutput1 = builder.AddNode("netoutput", "NetOutput", 2, 2); - auto relu = builder.AddNode("relu", "Relu", 1, 1); - - builder.AddDataEdge(data1, 0, if1, 0); - builder.AddDataEdge(data2, 0, if1, 1); - builder.AddDataEdge(if1, 0, netoutput1, 0); - builder.AddDataEdge(if1, 1, relu, 0); - builder.AddDataEdge(relu, 0, netoutput1, 1); - - auto main_graph = builder.GetGraph(); - - auto sub1 = BuildSubGraphWithVariable("sub1", to_be_set_format_of_sub); - sub1->SetParentGraph(main_graph); - sub1->SetParentNode(main_graph->FindNode("if")); - main_graph->FindNode("if")->GetOpDesc()->AddSubgraphName("sub1"); - main_graph->FindNode("if")->GetOpDesc()->SetSubgraphInstanceName(0, "sub1"); - main_graph->AddSubgraph("sub1", sub1); - - auto sub2 = BuildSubGraphWithVariable("sub2"); - sub2->SetParentGraph(main_graph); - sub2->SetParentNode(main_graph->FindNode("if")); - main_graph->FindNode("if")->GetOpDesc()->AddSubgraphName("sub2"); - main_graph->FindNode("if")->GetOpDesc()->SetSubgraphInstanceName(1, "sub2"); - main_graph->AddSubgraph("sub2", sub2); - - return main_graph; -} - -/* - * netoutput1 - * | - * relu1 - * | - * Add1(locked)(in ND out FZ) - * / \ - * var1(NCHW) var2(NCHW) - */ -ut::GraphBuilder BuildGraphWithFormatLocked() { - auto builder = ut::GraphBuilder("glocked"); - auto var1 = builder.AddNode("var1", "Variable", 0, 1, FORMAT_NCHW, DT_INT8, {1, 1, 28, 28}); - auto var2 = builder.AddNode("var2", "Variable", 0, 1, FORMAT_NCHW, DT_INT8, {1, 1, 28, 28}); - auto add1 = builder.AddNode("add1", "Add", 2, 1, FORMAT_ND, DT_INT8, {1, 1, 28, 28}); - auto relu1 = builder.AddNode("relu1", "Relu", 1, 1, FORMAT_NCHW, DT_INT8); - auto netoutput1 = builder.AddNode("netoutput1", "NetOutput", 1, 0); - - auto add1_op_desc = add1->GetOpDesc()->GetOutputDesc(0); - add1_op_desc.SetFormat(FORMAT_FRACTAL_Z); - add1_op_desc.SetOriginFormat(FORMAT_FRACTAL_Z); - add1_op_desc.SetShape(GeShape(std::vector({1, 1, 28, 28}))); - auto add1_op = OpDescUtils::CreateOperatorFromNode(add1); - add1_op.SetAttr(ATTR_NAME_FORMAT_LOCKED, true); - add1->GetOpDesc()->UpdateOutputDesc(0, add1_op_desc); - - builder.AddDataEdge(var1, 0, add1, 0); - builder.AddDataEdge(var2, 0, add1, 1); - builder.AddDataEdge(add1, 0, relu1, 0); - builder.AddDataEdge(relu1, 0, netoutput1, 0); - SetFirstInferFlag(builder.GetGraph(), true); - - return builder; -} - -} -// Test BiasAdd special process -TEST_F(UTEST_FormatRefiner, biasadd_special_process) { - auto builder = BuildGraph9(); - auto graph = builder.GetGraph(); - SetFirstInferFlag(graph, false); - EXPECT_EQ(FormatRefiner::InferOrigineFormat(graph), GRAPH_SUCCESS); - auto square = graph->FindNode("square"); - auto biasadd = graph->FindNode("biasadd"); - EXPECT_EQ(square->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(square->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(biasadd->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NDHWC); - EXPECT_EQ(biasadd->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NDHWC); - SetFirstInferFlag(graph, true); -} -// only main graph own anchor point -TEST_F(UTEST_FormatRefiner, with_if_sub_graph_1) { - auto main_graph = BuildMainGraphWithIf("main"); - EXPECT_EQ(FormatRefiner::InferOrigineFormat(main_graph), GRAPH_SUCCESS); - // check main graph format - auto if1 = main_graph->FindNode("if"); - auto relu = main_graph->FindNode("relu"); - auto netoutput = main_graph->FindNode("netoutput"); - EXPECT_EQ(if1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(if1->GetOpDesc()->GetOutputDesc(1).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(if1->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(if1->GetOpDesc()->GetInputDesc(1).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(netoutput->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(relu->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(relu->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - // check sub graph - auto sub_graph_1 = main_graph->GetSubgraph("sub1"); - auto sub_graph_2 = main_graph->GetSubgraph("sub2"); - string prefix_1 = "sub1"; - string prefix_2 = "sub2"; - auto sub1_data_1 = sub_graph_1->FindNode(prefix_1 + "data1"); - auto sub1_data_2 = sub_graph_1->FindNode(prefix_1 + "data2"); - auto sub1_relu = sub_graph_1->FindNode(prefix_1 + "relu"); - auto sub1_sub = sub_graph_1->FindNode(prefix_1 + "sub"); - auto sub1_netoutput = sub_graph_1->FindNode(prefix_1 + "netoutput"); - auto sub2_data_1 = sub_graph_2->FindNode(prefix_2 + "data1"); - auto sub2_data_2 = sub_graph_2->FindNode(prefix_2 + "data2"); - auto sub2_relu = sub_graph_2->FindNode(prefix_2 + "relu"); - auto sub2_sub = sub_graph_2->FindNode(prefix_2 + "sub"); - auto sub2_netoutput = sub_graph_2->FindNode(prefix_2 + "netoutput"); - - EXPECT_EQ(sub1_data_1->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub1_data_1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub1_data_2->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub1_data_2->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub1_relu->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub1_relu->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub1_sub->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub1_sub->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub1_netoutput->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub1_netoutput->GetOpDesc()->GetInputDesc(1).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub2_data_1->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub2_data_1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub2_data_2->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub2_data_2->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub2_relu->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub2_relu->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub2_sub->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub2_sub->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub2_netoutput->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub2_netoutput->GetOpDesc()->GetInputDesc(1).GetOriginFormat(), FORMAT_NCHW); -} - -// If节点FORMAT_ND, 通过RefRelation对子图格式进行推导 -TEST_F(UTEST_FormatRefiner, InferOriginFormat_IfIsNDWithSubgraph_ReflectionProcessOK) { - auto main_graph = BuildMainGraphWithIfAndVariable("main", Format::FORMAT_ND); - EXPECT_EQ(FormatRefiner::InferOrigineFormat(main_graph), GRAPH_SUCCESS); - // check main graph format - auto if1 = main_graph->FindNode("if"); - auto relu = main_graph->FindNode("relu"); - auto netoutput = main_graph->FindNode("netoutput"); - EXPECT_EQ(if1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(if1->GetOpDesc()->GetOutputDesc(1).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(if1->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NHWC); - EXPECT_EQ(if1->GetOpDesc()->GetInputDesc(1).GetOriginFormat(), FORMAT_NHWC); - EXPECT_EQ(netoutput->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(relu->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(relu->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - // check sub graph - auto sub_graph_1 = main_graph->GetSubgraph("sub1"); - auto sub_graph_2 = main_graph->GetSubgraph("sub2"); - string prefix_1 = "sub1"; - string prefix_2 = "sub2"; - auto sub1_data_1 = sub_graph_1->FindNode(prefix_1 + "data1"); - auto sub1_data_2 = sub_graph_1->FindNode(prefix_1 + "data2"); - auto sub1_relu = sub_graph_1->FindNode(prefix_1 + "relu"); - auto sub1_sub = sub_graph_1->FindNode(prefix_1 + "sub"); - auto sub1_netoutput = sub_graph_1->FindNode(prefix_1 + "netoutput"); - auto sub2_data_1 = sub_graph_2->FindNode(prefix_2 + "data1"); - auto sub2_data_2 = sub_graph_2->FindNode(prefix_2 + "data2"); - auto sub2_relu = sub_graph_2->FindNode(prefix_2 + "relu"); - auto sub2_sub = sub_graph_2->FindNode(prefix_2 + "sub"); - auto sub2_netoutput = sub_graph_2->FindNode(prefix_2 + "netoutput"); - - EXPECT_EQ(sub1_data_1->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NHWC); - EXPECT_EQ(sub1_data_1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NHWC); - EXPECT_EQ(sub1_data_2->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NHWC); - EXPECT_EQ(sub1_data_2->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NHWC); - EXPECT_EQ(sub1_sub->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NHWC); - EXPECT_EQ(sub1_sub->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NHWC); - EXPECT_EQ(sub1_netoutput->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub1_netoutput->GetOpDesc()->GetInputDesc(1).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub2_data_1->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NHWC); - EXPECT_EQ(sub2_data_1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NHWC); - EXPECT_EQ(sub2_data_2->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NHWC); - EXPECT_EQ(sub2_data_2->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NHWC); - EXPECT_EQ(sub2_sub->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NHWC); - EXPECT_EQ(sub2_sub->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NHWC); - EXPECT_EQ(sub2_netoutput->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub2_netoutput->GetOpDesc()->GetInputDesc(1).GetOriginFormat(), FORMAT_NCHW); -} - -// only sub graph own anchor point -TEST_F(UTEST_FormatRefiner, with_if_sub_graph_2) { - auto main_graph = BuildMainGraphWithIf("sub"); - EXPECT_EQ(FormatRefiner::InferOrigineFormat(main_graph), GRAPH_SUCCESS); - // check main graph format - auto if1 = main_graph->FindNode("if"); - auto relu = main_graph->FindNode("relu"); - auto netoutput = main_graph->FindNode("netoutput"); - EXPECT_EQ(if1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(if1->GetOpDesc()->GetOutputDesc(1).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(if1->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(if1->GetOpDesc()->GetInputDesc(1).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(netoutput->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(relu->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(relu->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - // check sub graph - auto sub_graph_1 = main_graph->GetSubgraph("sub1"); - auto sub_graph_2 = main_graph->GetSubgraph("sub2"); - string prefix_1 = "sub1"; - string prefix_2 = "sub2"; - auto sub1_data_1 = sub_graph_1->FindNode(prefix_1 + "data1"); - auto sub1_data_2 = sub_graph_1->FindNode(prefix_1 + "data2"); - auto sub1_relu = sub_graph_1->FindNode(prefix_1 + "relu"); - auto sub1_sub = sub_graph_1->FindNode(prefix_1 + "sub"); - auto sub1_netoutput = sub_graph_1->FindNode(prefix_1 + "netoutput"); - auto sub2_data_1 = sub_graph_2->FindNode(prefix_2 + "data1"); - auto sub2_data_2 = sub_graph_2->FindNode(prefix_2 + "data2"); - auto sub2_relu = sub_graph_2->FindNode(prefix_2 + "relu"); - auto sub2_sub = sub_graph_2->FindNode(prefix_2 + "sub"); - auto sub2_netoutput = sub_graph_2->FindNode(prefix_2 + "netoutput"); - - EXPECT_EQ(sub1_data_1->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub1_data_1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub1_data_2->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub1_data_2->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub1_relu->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub1_relu->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub1_sub->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NHWC); - EXPECT_EQ(sub1_sub->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NHWC); - EXPECT_EQ(sub1_netoutput->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub1_netoutput->GetOpDesc()->GetInputDesc(1).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub2_data_1->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub2_data_1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub2_data_2->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub2_data_2->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub2_relu->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub2_relu->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub2_sub->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub2_sub->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub2_netoutput->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(sub2_netoutput->GetOpDesc()->GetInputDesc(1).GetOriginFormat(), FORMAT_NCHW); -} - -TEST_F(UTEST_FormatRefiner, data_format) { - auto builder = BuildGraph8(); - auto graph = builder.GetGraph(); - SetFirstInferFlag(graph, false); - graph->SaveDataFormat(FORMAT_NCHW); - EXPECT_EQ(FormatRefiner::InferOrigineFormat(graph), GRAPH_SUCCESS); - auto data2 = graph->FindNode("data2"); - auto relu = graph->FindNode("relu"); - EXPECT_EQ(data2->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(data2->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(relu->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(relu->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - SetFirstInferFlag(graph, true); -} - -TEST_F(UTEST_FormatRefiner, constantFail) { - auto builder = BuildGraph6(); - auto graph = builder.GetGraph(); - SetFirstInferFlag(graph, true); - EXPECT_EQ(FormatRefiner::InferOrigineFormat(graph), GRAPH_FAILED); -} -TEST_F(UTEST_FormatRefiner, scalarNodesInfer) { - auto builder = BuildGraph6(); - auto graph = builder.GetGraph(); - SetFirstInferFlag(graph, true); - auto constant = graph->FindNode("constant"); - ge::GeTensorPtr value = std::make_shared(); - AttrUtils::SetTensor(constant->GetOpDesc(), "value", value); - EXPECT_EQ(FormatRefiner::InferOrigineFormat(graph), GRAPH_SUCCESS); -} - -TEST_F(UTEST_FormatRefiner, ForwardAndDefaultInferFunc) { - auto builder = BuildGraph1(); - auto graph = builder.GetGraph(); - EXPECT_EQ(FormatRefiner::InferOrigineFormat(graph), GRAPH_SUCCESS); - auto var1 = graph->FindNode("var1"); - EXPECT_EQ(var1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - auto var2 = graph->FindNode("var2"); - EXPECT_EQ(var2->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - auto relu1 = graph->FindNode("relu1"); - EXPECT_EQ(relu1->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(relu1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - auto netoutput1 = graph->FindNode("netoutput1"); - EXPECT_EQ(netoutput1->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - auto conv1 = graph->FindNode("conv1"); - EXPECT_EQ(conv1->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(conv1->GetOpDesc()->GetInputDesc(1).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(conv1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); -} - - -TEST_F(UTEST_FormatRefiner, ForwardAndSpecifedInferFunc) { - auto builder = BuildGraph1(); - auto graph = builder.GetGraph(); - auto relu1 = graph->FindNode("relu1"); - relu1->GetOpDesc()->AddInferFormatFunc([](Operator &op) { - auto output1 = op.GetOutputDesc(0); - output1.SetOriginFormat(FORMAT_NHWC); - op.UpdateOutputDesc("0", output1); - return GRAPH_SUCCESS; - }); - - EXPECT_EQ(FormatRefiner::InferOrigineFormat(graph), GRAPH_SUCCESS); - auto var1 = graph->FindNode("var1"); - EXPECT_EQ(var1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - auto var2 = graph->FindNode("var2"); - EXPECT_EQ(var2->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(relu1->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(relu1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - auto netoutput1 = graph->FindNode("netoutput1"); - EXPECT_EQ(netoutput1->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); -} - -TEST_F(UTEST_FormatRefiner, FailedWhenInfer) { - auto builder = BuildGraph1(); - auto graph = builder.GetGraph(); - auto relu1 = graph->FindNode("relu1"); - relu1->GetOpDesc()->AddInferFormatFunc([](Operator &op) { - return GRAPH_FAILED; - }); - - EXPECT_EQ(FormatRefiner::InferOrigineFormat(graph), GRAPH_SUCCESS); -} - -TEST_F(UTEST_FormatRefiner, ForwardBackward) { - auto builder = BuildGraph2(); - auto graph = builder.GetGraph(); - - EXPECT_EQ(FormatRefiner::InferOrigineFormat(graph), GRAPH_SUCCESS); - auto bn1 = graph->FindNode("bn1"); - EXPECT_EQ(bn1->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(bn1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - for (auto name : {"var3", "var4", "var5", "var6"}) { - auto node = graph->FindNode(name); - EXPECT_EQ(node->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - } -} - -TEST_F(UTEST_FormatRefiner, FormatConflict) { - auto builder = BuildGraph3(); - auto graph = builder.GetGraph(); - EXPECT_EQ(FormatRefiner::InferOrigineFormat(graph), GRAPH_SUCCESS); -} - -TEST_F(UTEST_FormatRefiner, InferStopND) { - auto builder = BuildGraph1(); - auto graph = builder.GetGraph(); - auto relu1 = graph->FindNode("relu1"); - relu1->GetOpDesc()->AddInferFormatFunc([](Operator &op) { - return GRAPH_SUCCESS; - }); - EXPECT_EQ(FormatRefiner::InferOrigineFormat(graph), GRAPH_SUCCESS); - auto var1 = graph->FindNode("var1"); - EXPECT_EQ(var1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - auto var2 = graph->FindNode("var2"); - EXPECT_EQ(var2->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - relu1 = graph->FindNode("relu1"); - EXPECT_EQ(relu1->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(relu1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - auto netoutput1 = graph->FindNode("netoutput1"); - EXPECT_EQ(netoutput1->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - auto conv1 = graph->FindNode("conv1"); - EXPECT_EQ(conv1->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(conv1->GetOpDesc()->GetInputDesc(1).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(conv1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); -} - -TEST_F(UTEST_FormatRefiner, InferStopSameFormat) { - auto builder = BuildGraph4(); - auto graph = builder.GetGraph(); - EXPECT_EQ(FormatRefiner::InferOrigineFormat(graph), GRAPH_SUCCESS); - -} - -TEST_F(UTEST_FormatRefiner, ForwardMultiOutput) { - auto builder = BuilderGraph5(); - auto graph = builder.GetGraph(); - auto apply1 = graph->FindNode("apply1"); - apply1->GetOpDesc()->AddInferFormatFunc([](Operator &op) { - auto out = op.GetOutputDesc(0); - out.SetOriginFormat(FORMAT_NHWC); - op.UpdateOutputDesc("0", out); - auto in0 = op.GetInputDesc(0); - in0.SetOriginFormat(FORMAT_NHWC); - op.UpdateInputDesc("0", in0); - auto in1 = op.GetInputDesc(1); - in1.SetOriginFormat(FORMAT_HWCN); - op.UpdateInputDesc("1", in1); - return GRAPH_SUCCESS; - }); - - EXPECT_EQ(FormatRefiner::InferOrigineFormat(graph), GRAPH_SUCCESS); - - auto data1 = graph->FindNode("data1"); - EXPECT_EQ(data1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - auto var1 = graph->FindNode("var1"); - EXPECT_EQ(var1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - auto bn1 = graph->FindNode("bn1"); - EXPECT_EQ(bn1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(bn1->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - auto relu1 = graph->FindNode("relu1"); - EXPECT_EQ(relu1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(relu1->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - auto relug1 = graph->FindNode("relug1"); - EXPECT_EQ(relug1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(relug1->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - auto bng1 = graph->FindNode("bng1"); - EXPECT_EQ(bng1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(bng1->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(bng1->GetOpDesc()->GetInputDesc(1).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(bng1->GetOpDesc()->GetInputDesc(2).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(bng1->GetOpDesc()->GetInputDesc(3).GetOriginFormat(), FORMAT_NCHW); - - EXPECT_EQ(apply1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(apply1->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(apply1->GetOpDesc()->GetInputDesc(1).GetOriginFormat(), FORMAT_NCHW); - -} - -TEST_F(UTEST_FormatRefiner, FormatInferWithLockedNode) { - auto builder = BuildGraphWithFormatLocked(); - auto graph = builder.GetGraph(); - // before infer format - auto add1 = graph->FindNode("add1"); - EXPECT_EQ(add1->GetOpDesc()->GetInputDesc(0).GetFormat(), FORMAT_ND); - EXPECT_EQ(add1->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_ND); - EXPECT_EQ(add1->GetOpDesc()->GetInputDesc(1).GetFormat(), FORMAT_ND); - EXPECT_EQ(add1->GetOpDesc()->GetInputDesc(1).GetOriginFormat(), FORMAT_ND); - EXPECT_EQ(add1->GetOpDesc()->GetOutputDesc(0).GetFormat(), FORMAT_FRACTAL_Z); - EXPECT_EQ(add1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_FRACTAL_Z); - - EXPECT_EQ(FormatRefiner::InferOrigineFormat(graph), GRAPH_SUCCESS); - auto var1 = graph->FindNode("var1"); - EXPECT_EQ(var1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - auto var2 = graph->FindNode("var2"); - EXPECT_EQ(var2->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - auto relu1 = graph->FindNode("relu1"); - EXPECT_EQ(relu1->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(relu1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_NCHW); - auto netoutput1 = graph->FindNode("netoutput1"); - EXPECT_EQ(netoutput1->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_NCHW); - // after format infer - EXPECT_EQ(add1->GetOpDesc()->GetInputDesc(0).GetFormat(), FORMAT_ND); - EXPECT_EQ(add1->GetOpDesc()->GetInputDesc(0).GetOriginFormat(), FORMAT_ND); - EXPECT_EQ(add1->GetOpDesc()->GetInputDesc(1).GetFormat(), FORMAT_ND); - EXPECT_EQ(add1->GetOpDesc()->GetInputDesc(1).GetOriginFormat(), FORMAT_ND); - EXPECT_EQ(add1->GetOpDesc()->GetOutputDesc(0).GetFormat(), FORMAT_FRACTAL_Z); - EXPECT_EQ(add1->GetOpDesc()->GetOutputDesc(0).GetOriginFormat(), FORMAT_FRACTAL_Z); -} - -TEST_F(UTEST_FormatRefiner, InferOrigineFormatFailed) { - ge::ComputeGraphPtr graph = nullptr; - auto status = FormatRefiner::InferOrigineFormat(graph); - EXPECT_EQ(status, GRAPH_FAILED); -} -TEST_F(UTEST_FormatRefiner, SaveFormat) { - auto builder = BuildGraph6(); - auto graph = builder.GetGraph(); - SetFirstInferFlag(graph, true); - graph->SaveDataFormat(FORMAT_NHWC); - auto save_format = graph->GetDataFormat(); - EXPECT_EQ(save_format, FORMAT_NHWC); - graph->SaveDataFormat(FORMAT_ND); -} -} diff --git a/tests/ut/graph/testcase/ge_attr_define_unittest.cc b/tests/ut/graph/testcase/ge_attr_define_unittest.cc deleted file mode 100644 index cf4311d2a6ccebaa967657a75d5b19bc33e6a40d..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/ge_attr_define_unittest.cc +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. -* This file is a part of the CANN Open Software. -* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -* Please refer to the License for details. You may not use this file except in compliance with the License. -* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -* See LICENSE in the root of the software repository for the full text of the License. -* ===================================================================================================================*/ - -#include -#include -#include "graph/op_desc.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/attr_utils.h" - -namespace ge { -class UtestGeAttrDefine : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -TEST_F(UtestGeAttrDefine, GetAttachedAttrDefine) { - OpDescPtr op_desc = std::make_shared(); - NamedAttrs attr; - (void) ge::AttrUtils::SetStr(attr, ATTR_NAME_ATTACHED_RESOURCE_NAME, "tiling"); - (void) ge::AttrUtils::SetStr(attr, ATTR_NAME_ATTACHED_RESOURCE_REUSE_KEY, "tiling_key"); - (void) ge::AttrUtils::SetListInt(attr, ATTR_NAME_ATTACHED_RESOURCE_DEPEND_VALUE_LIST_INT, {0, 1, 2}); - (void) ge::AttrUtils::SetBool(attr, ATTR_NAME_ATTACHED_RESOURCE_REQUIRED_FLAG, true); - (void) ge::AttrUtils::SetInt(attr, ATTR_NAME_ATTACHED_RESOURCE_ID, 1); - (void) ge::AttrUtils::SetBool(attr, ATTR_NAME_ATTACHED_RESOURCE_IS_VALID, false); - std::vector list_name_attr_set; - list_name_attr_set.emplace_back(attr); - ge::AttrUtils::SetListNamedAttrs(op_desc, ATTR_NAME_ATTACHED_STREAM_INFO_LIST, list_name_attr_set); - - std::vector list_name_attr_get; - ge::AttrUtils::GetListNamedAttrs(op_desc, ATTR_NAME_ATTACHED_STREAM_INFO_LIST, list_name_attr_get); - ASSERT_EQ(list_name_attr_get.size(), 1U); - - std::string ret_str; - EXPECT_EQ(ge::AttrUtils::GetStr(list_name_attr_get[0], ATTR_NAME_ATTACHED_RESOURCE_NAME, ret_str), true); - EXPECT_EQ(ret_str, "tiling"); - - EXPECT_EQ(ge::AttrUtils::GetStr(list_name_attr_get[0], ATTR_NAME_ATTACHED_RESOURCE_REUSE_KEY, ret_str), true); - EXPECT_EQ(ret_str, "tiling_key"); - - std::vector ret_list; - EXPECT_EQ( - ge::AttrUtils::GetListInt(list_name_attr_get[0], ATTR_NAME_ATTACHED_RESOURCE_DEPEND_VALUE_LIST_INT, ret_list), - true); - EXPECT_EQ(ret_list.size(), 3); - EXPECT_EQ(ret_list[2], 2); - - bool ret_bool; - EXPECT_EQ(ge::AttrUtils::GetBool(list_name_attr_get[0], ATTR_NAME_ATTACHED_RESOURCE_REQUIRED_FLAG, ret_bool), true); - EXPECT_EQ(ret_bool, true); - - int64_t ret_int; - EXPECT_EQ(ge::AttrUtils::GetInt(list_name_attr_get[0], ATTR_NAME_ATTACHED_RESOURCE_ID, ret_int), true); - EXPECT_EQ(ret_int, 1); - - EXPECT_EQ(ge::AttrUtils::GetBool(list_name_attr_get[0], ATTR_NAME_ATTACHED_RESOURCE_IS_VALID, ret_bool), true); - EXPECT_EQ(ret_bool, false); - - EXPECT_EQ(ge::AttrUtils::HasAttr(list_name_attr_get[0], ATTR_NAME_ATTACHED_RESOURCE_TYPE), false); -} -} diff --git a/tests/ut/graph/testcase/ge_attr_value_unittest.cc b/tests/ut/graph/testcase/ge_attr_value_unittest.cc deleted file mode 100644 index 81bbdbe66723040e82511f307e62650020bab495..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/ge_attr_value_unittest.cc +++ /dev/null @@ -1,535 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/op_desc.h" -#include "graph/ge_attr_value.h" -#include "graph/utils/attr_utils.h" -#include "external/graph/attr_value.h" -#include "external/graph/tensor.h" -#include "external/graph/ascend_string.h" -#include "external/graph/types.h" -#include -#include -#include - -namespace ge { -class UtestGeAttrValue : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - - -TEST_F(UtestGeAttrValue, GetAttrsStrAfterRid) { - string name = "const"; - string type = "Constant"; - OpDescPtr op_desc = std::make_shared(); - EXPECT_EQ(AttrUtils::GetAttrsStrAfterRid(op_desc, {}), ""); - - std::set names ={"qazwsx", "d"}; - op_desc->SetAttr("qazwsx", GeAttrValue::CreateFrom(132)); - op_desc->SetAttr("xswzaq", GeAttrValue::CreateFrom(123)); - auto tensor = GeTensor(); - op_desc->SetAttr("value", GeAttrValue::CreateFrom(tensor)); - std::string res = AttrUtils::GetAttrsStrAfterRid(op_desc, names); - EXPECT_TRUE(res.find("qazwsx") == string::npos); - EXPECT_TRUE(res.find("xswzaq") != string::npos); -} - -TEST_F(UtestGeAttrValue, GetAllAttrsStr) { - // 属性序列化 - string name = "const"; - string type = "Constant"; - OpDescPtr op_desc = std::make_shared(name, type); - EXPECT_TRUE(op_desc); - EXPECT_EQ(AttrUtils::GetAllAttrsStr(op_desc), ""); - op_desc->SetAttr("seri_i", GeAttrValue::CreateFrom(1)); - auto tensor = GeTensor(); - op_desc->SetAttr("seri_value", GeAttrValue::CreateFrom(tensor)); - op_desc->SetAttr("seri_input_desc", GeAttrValue::CreateFrom(GeTensorDesc())); - string attr = AttrUtils::GetAllAttrsStr(op_desc); - - EXPECT_TRUE(attr.find("seri_i") != string::npos); - EXPECT_TRUE(attr.find("seri_value") != string::npos); - EXPECT_TRUE(attr.find("seri_input_desc") != string::npos); - -} -TEST_F(UtestGeAttrValue, GetAllAttrs) { - string name = "const"; - string type = "Constant"; - OpDescPtr op_desc = std::make_shared(name, type); - EXPECT_TRUE(op_desc); - op_desc->SetAttr("i", GeAttrValue::CreateFrom(100)); - op_desc->SetAttr("input_desc", GeAttrValue::CreateFrom(GeTensorDesc())); - auto attrs = AttrUtils::GetAllAttrs(op_desc); - EXPECT_EQ(attrs.size(), 2); - int64_t attr_value = 0; - EXPECT_EQ(attrs["i"].GetValue(attr_value), GRAPH_SUCCESS); - EXPECT_EQ(attr_value, 100); - -} - -TEST_F(UtestGeAttrValue, TrySetExists) { - string name = "const"; - string type = "Constant"; - OpDescPtr op_desc = std::make_shared(name, type); - EXPECT_TRUE(op_desc); - - int64_t attr_value = 0; - - EXPECT_FALSE(AttrUtils::GetInt(op_desc, "i", attr_value)); - op_desc->TrySetAttr("i", GeAttrValue::CreateFrom(100)); - EXPECT_TRUE(AttrUtils::GetInt(op_desc, "i", attr_value)); - EXPECT_EQ(attr_value, 100); - - op_desc->TrySetAttr("i", GeAttrValue::CreateFrom(102)); - attr_value = 0; - AttrUtils::GetInt(op_desc, "i", attr_value); - EXPECT_EQ(attr_value, 100); - - uint64_t uint64_val = 0U; - EXPECT_TRUE(AttrUtils::GetInt(op_desc, "i", uint64_val)); - EXPECT_EQ(uint64_val, 100U); -} - -TEST_F(UtestGeAttrValue, CloneOpDesc_check_null) { - OpDescPtr op_desc = nullptr; - auto ret = AttrUtils::CloneOpDesc(op_desc); - EXPECT_EQ(ret == nullptr, true); -} - -TEST_F(UtestGeAttrValue, CopyOpDesc_check_null) { - OpDescPtr op_desc = nullptr; - auto ret = AttrUtils::CopyOpDesc(op_desc); - EXPECT_EQ(ret == nullptr, true); -} - -TEST_F(UtestGeAttrValue, SetGetListInt) { - OpDescPtr op_desc = std::make_shared("const1", "Identity"); - EXPECT_TRUE(op_desc); - - EXPECT_TRUE(AttrUtils::SetListInt(op_desc, "li1", std::vector({1,2,3,4,5}))); - std::vector li1_out0; - EXPECT_TRUE(AttrUtils::GetListInt(op_desc, "li1", li1_out0)); - EXPECT_EQ(li1_out0, std::vector({1,2,3,4,5})); -} - -TEST_F(UtestGeAttrValue, SetListIntGetByGeAttrValue) { - OpDescPtr op_desc = std::make_shared("const1", "Identity"); - EXPECT_TRUE(op_desc); - - EXPECT_TRUE(AttrUtils::SetListInt(op_desc, "li1", std::vector({1,2,3,4,5}))); - auto names_to_value = AttrUtils::GetAllAttrs(op_desc); - auto iter = names_to_value.find("li1"); - EXPECT_NE(iter, names_to_value.end()); - - std::vector li1_out; - auto &ge_value = iter->second; - EXPECT_EQ(ge_value.GetValue(li1_out), GRAPH_SUCCESS); - EXPECT_EQ(li1_out, std::vector({1,2,3,4,5})); - - li1_out.clear(); - EXPECT_EQ(ge_value.GetValue>(li1_out), GRAPH_SUCCESS); - EXPECT_EQ(li1_out, std::vector({1,2,3,4,5})); -} - -TEST_F(UtestGeAttrValue, SetGetAttr_GeTensor) { - OpDescPtr op_desc = std::make_shared("const1", "Identity"); - GeTensorDesc td; - td.SetShape(GeShape(std::vector({1,100}))); - td.SetOriginShape(GeShape(std::vector({1,100}))); - td.SetDataType(DT_FLOAT); - td.SetFormat(FORMAT_ND); - float data[100]; - for (size_t i = 0; i < 100; ++i) { - data[i] = 1.0 * i; - } - auto tensor = std::make_shared(td, reinterpret_cast(data), sizeof(data)); - EXPECT_NE(tensor, nullptr); - - EXPECT_TRUE(AttrUtils::SetTensor(op_desc, "t", tensor)); - tensor = nullptr; - - EXPECT_TRUE(AttrUtils::MutableTensor(op_desc, "t", tensor)); - EXPECT_NE(tensor, nullptr); - - EXPECT_EQ(tensor->GetData().GetSize(), sizeof(data)); - auto attr_data = reinterpret_cast(tensor->GetData().GetData()); - for (size_t i = 0; i < 100; ++i) { - EXPECT_FLOAT_EQ(attr_data[i], data[i]); - } - tensor = nullptr; - - EXPECT_TRUE(AttrUtils::MutableTensor(op_desc, "t", tensor)); - EXPECT_NE(tensor, nullptr); - - EXPECT_EQ(tensor->GetData().GetSize(), sizeof(data)); - attr_data = reinterpret_cast(tensor->GetData().GetData()); - for (size_t i = 0; i < 100; ++i) { - EXPECT_FLOAT_EQ(attr_data[i], data[i]); - } - tensor = nullptr; -} - -TEST_F(UtestGeAttrValue, GetStr) { - OpDescPtr op_desc = std::make_shared("Add", "Add"); - EXPECT_TRUE(op_desc); - - std::string add_info = "add_info"; - AttrUtils::SetStr(op_desc, "compile_info_key", add_info); - const std::string *s2 = AttrUtils::GetStr(op_desc, "compile_info_key"); - EXPECT_NE(s2, nullptr); - EXPECT_EQ(*s2, add_info); -} - -TEST_F(UtestGeAttrValue, GetStr_for_2_name) { - OpDescPtr op_desc = std::make_shared("Add", "Add"); - EXPECT_TRUE(op_desc); - - std::string name1 = "compile_info_key1"; - std::string name2 = "compile_info_key2"; - std::string value1 = "add_info1"; - std::string value2 = "add_info2"; - std::string value; - - // name1未设置属性,name2设置属性,获取的值是name2的属性值 - AttrUtils::SetStr(op_desc, name2, value2); - EXPECT_TRUE(AttrUtils::GetStr(op_desc, name1, name2, value)); - EXPECT_EQ(value, value2); - - // name1和name2均设置属性,获取的是name1的属性值 - AttrUtils::SetStr(op_desc, name1, value1); - EXPECT_TRUE(AttrUtils::GetStr(op_desc, name1, name2, value)); - EXPECT_EQ(value, value1); - - // 异常场景 - EXPECT_FALSE(AttrUtils::GetStr(nullptr, name1, name2, value)); -} - -TEST_F(UtestGeAttrValue, SetNullObjectAttr) { - OpDescPtr op_desc(nullptr); - EXPECT_EQ(AttrUtils::SetStr(op_desc, "key", "value"), false); - EXPECT_EQ(AttrUtils::SetInt(op_desc, "key", 0), false); - EXPECT_EQ(AttrUtils::SetTensorDesc(op_desc, "key", GeTensorDesc()), false); - GeTensorPtr ge_tensor; - EXPECT_EQ(AttrUtils::SetTensor(op_desc, "key", ge_tensor), false); - ConstGeTensorPtr const_ge_tensor; - EXPECT_EQ(AttrUtils::SetTensor(op_desc, "key", const_ge_tensor), false); - EXPECT_EQ(AttrUtils::SetBool(op_desc, "key", true), false); - EXPECT_EQ(AttrUtils::SetBytes(op_desc, "key", Buffer()), false); - EXPECT_EQ(AttrUtils::SetFloat(op_desc, "key", 1.0), false); - EXPECT_EQ(AttrUtils::SetGraph(op_desc, "key", nullptr), false); - EXPECT_EQ(AttrUtils::SetDataType(op_desc, "key", DT_UINT8), false); - EXPECT_EQ(AttrUtils::SetListDataType(op_desc, "key", {DT_UINT8}), false); - EXPECT_EQ(AttrUtils::SetListListInt(op_desc, "key", {}), false); - EXPECT_EQ(AttrUtils::SetListInt(op_desc, "key", {}), false); - EXPECT_EQ(AttrUtils::SetListTensor(op_desc, "key", {}), false); - EXPECT_EQ(AttrUtils::SetListBool(op_desc, "key", {}), false); - EXPECT_EQ(AttrUtils::SetListFloat(op_desc, "key", {}), false); - EXPECT_EQ(AttrUtils::SetListBytes(op_desc, "key", {}), false); - EXPECT_EQ(AttrUtils::SetListGraph(op_desc, "key", {}), false); - EXPECT_EQ(AttrUtils::SetListListFloat(op_desc, "key", {}), false); - EXPECT_EQ(AttrUtils::SetListTensorDesc(op_desc, "key", {}), false); - EXPECT_EQ(AttrUtils::SetListStr(op_desc, "key", {}), false); - std::vector buffer; - EXPECT_EQ(AttrUtils::SetZeroCopyListBytes(op_desc, "key", buffer), false); -} -TEST_F(UtestGeAttrValue, GetNullObjectAttr) { - OpDescPtr op_desc(nullptr); - std::string value; - EXPECT_EQ(AttrUtils::GetStr(op_desc, "key", value), false); - int64_t i; - EXPECT_EQ(AttrUtils::GetInt(op_desc, "key", i), false); - GeTensorDesc ge_tensor_desc; - EXPECT_EQ(AttrUtils::GetTensorDesc(op_desc, "key", ge_tensor_desc), false); - ConstGeTensorPtr const_ge_tensor; - EXPECT_EQ(AttrUtils::GetTensor(op_desc, "key", const_ge_tensor), false); - bool flag; - EXPECT_EQ(AttrUtils::GetBool(op_desc, "key", flag), false); - Buffer buffer; - EXPECT_EQ(AttrUtils::GetBytes(op_desc, "key", buffer), false); - float j; - EXPECT_EQ(AttrUtils::GetFloat(op_desc, "key", j), false); - ComputeGraphPtr compute_graph; - EXPECT_EQ(AttrUtils::GetGraph(op_desc, "key", compute_graph), false); - DataType data_type; - EXPECT_EQ(AttrUtils::GetDataType(op_desc, "key", data_type), false); - std::vector data_types; - EXPECT_EQ(AttrUtils::GetListDataType(op_desc, "key", data_types), false); - std::vector> ints_list; - EXPECT_EQ(AttrUtils::GetListListInt(op_desc, "key", ints_list), false); - std::vector ints; - EXPECT_EQ(AttrUtils::GetListInt(op_desc, "key", ints), false); - std::vector tensors; - EXPECT_EQ(AttrUtils::GetListTensor(op_desc, "key", tensors), false); - std::vector flags; - EXPECT_EQ(AttrUtils::GetListBool(op_desc, "key", flags), false); - std::vector floats; - EXPECT_EQ(AttrUtils::GetListFloat(op_desc, "key", floats), false); - std::vector buffers; - EXPECT_EQ(AttrUtils::GetListBytes(op_desc, "key", buffers), false); - std::vector graphs; - EXPECT_EQ(AttrUtils::GetListGraph(op_desc, "key", graphs), false); - std::vector> floats_list; - EXPECT_EQ(AttrUtils::GetListListFloat(op_desc, "key", floats_list), false); - std::vector tensor_descs; - EXPECT_EQ(AttrUtils::GetListTensorDesc(op_desc, "key", tensor_descs), false); - std::vector strings; - EXPECT_EQ(AttrUtils::GetListStr(op_desc, "key", strings), false); - EXPECT_EQ(AttrUtils::GetZeroCopyListBytes(op_desc, "key", buffers), false); -} - -TEST_F(UtestGeAttrValue, SetGetAttrValue_Comprehensive) { - // 综合测试所有功能 - AttrValue attr_value; - - // 测试所有支持的类型 - std::vector>> test_cases = { - {"int64_t", [&]() { - int64_t val = 12345; - EXPECT_EQ(attr_value.SetAttrValue(val), GRAPH_SUCCESS); - int64_t get_val = 0; - EXPECT_EQ(attr_value.GetAttrValue(get_val), GRAPH_SUCCESS); - EXPECT_EQ(get_val, val); - }}, - {"float32_t", [&]() { - float32_t val = 3.14159f; - EXPECT_EQ(attr_value.SetAttrValue(val), GRAPH_SUCCESS); - float32_t get_val = 0.0f; - EXPECT_EQ(attr_value.GetAttrValue(get_val), GRAPH_SUCCESS); - EXPECT_FLOAT_EQ(get_val, val); - }}, - {"bool", [&]() { - bool val = true; - EXPECT_EQ(attr_value.SetAttrValue(val), GRAPH_SUCCESS); - bool get_val = false; - EXPECT_EQ(attr_value.GetAttrValue(get_val), GRAPH_SUCCESS); - EXPECT_EQ(get_val, val); - }}, - {"DataType", [&]() { - ge::DataType val = DT_FLOAT; - EXPECT_EQ(attr_value.SetAttrValue(val), GRAPH_SUCCESS); - ge::DataType get_val = DT_UNDEFINED; - EXPECT_EQ(attr_value.GetAttrValue(get_val), GRAPH_SUCCESS); - EXPECT_EQ(get_val, val); - }}, - {"vector", [&]() { - std::vector val = {1, 2, 3, 4, 5}; - EXPECT_EQ(attr_value.SetAttrValue(val), GRAPH_SUCCESS); - std::vector get_val; - EXPECT_EQ(attr_value.GetAttrValue(get_val), GRAPH_SUCCESS); - EXPECT_EQ(get_val, val); - }}, - {"vector", [&]() { - std::vector val = {1.1f, 2.2f, 3.3f}; - EXPECT_EQ(attr_value.SetAttrValue(val), GRAPH_SUCCESS); - std::vector get_val; - EXPECT_EQ(attr_value.GetAttrValue(get_val), GRAPH_SUCCESS); - EXPECT_EQ(get_val.size(), val.size()); - for (size_t i = 0; i < val.size(); ++i) { - EXPECT_FLOAT_EQ(get_val[i], val[i]); - } - }}, - {"vector", [&]() { - std::vector val = {true, false, true}; - EXPECT_EQ(attr_value.SetAttrValue(val), GRAPH_SUCCESS); - std::vector get_val; - EXPECT_EQ(attr_value.GetAttrValue(get_val), GRAPH_SUCCESS); - EXPECT_EQ(get_val, val); - }}, - {"vector>", [&]() { - std::vector> val = {{1, 2}, {3, 4, 5}}; - EXPECT_EQ(attr_value.SetAttrValue(val), GRAPH_SUCCESS); - std::vector> get_val; - EXPECT_EQ(attr_value.GetAttrValue(get_val), GRAPH_SUCCESS); - EXPECT_EQ(get_val, val); - }}, - {"vector", [&]() { - std::vector val = {DT_FLOAT, DT_INT32, DT_BOOL}; - EXPECT_EQ(attr_value.SetAttrValue(val), GRAPH_SUCCESS); - std::vector get_val; - EXPECT_EQ(attr_value.GetAttrValue(get_val), GRAPH_SUCCESS); - EXPECT_EQ(get_val, val); - }}, - {"AscendString", [&]() { - AscendString val("test_ascend_string"); - EXPECT_EQ(attr_value.SetAttrValue(val), GRAPH_SUCCESS); - AscendString get_val; - EXPECT_EQ(attr_value.GetAttrValue(get_val), GRAPH_SUCCESS); - EXPECT_STREQ(get_val.GetString(), val.GetString()); - }}, - {"vector", [&]() { - std::vector val = { - AscendString("str1"), - AscendString("str2"), - AscendString("str3") - }; - EXPECT_EQ(attr_value.SetAttrValue(val), GRAPH_SUCCESS); - std::vector get_val; - EXPECT_EQ(attr_value.GetAttrValue(get_val), GRAPH_SUCCESS); - EXPECT_EQ(get_val.size(), val.size()); - for (size_t i = 0; i < val.size(); ++i) { - EXPECT_STREQ(get_val[i].GetString(), val[i].GetString()); - } - }}, - {"Tensor", [&]() { - TensorDesc tensor_desc(Shape({4, 4}), FORMAT_ND, DT_FLOAT); - std::vector tensor_data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; - Tensor val(tensor_desc, tensor_data); - EXPECT_EQ(attr_value.SetAttrValue(val), GRAPH_SUCCESS); - Tensor get_val; - EXPECT_EQ(attr_value.GetAttrValue(get_val), GRAPH_SUCCESS); - EXPECT_EQ(get_val.GetSize(), val.GetSize()); - }}, - {"vector", [&]() { - TensorDesc tensor_desc(Shape({2, 2}), FORMAT_ND, DT_FLOAT); - std::vector tensor_data = {1, 2, 3, 4}; - Tensor tensor1(tensor_desc, tensor_data); - Tensor tensor2(tensor_desc, tensor_data); - std::vector val = {tensor1, tensor2}; - EXPECT_EQ(attr_value.SetAttrValue(val), GRAPH_SUCCESS); - std::vector get_val; - EXPECT_EQ(attr_value.GetAttrValue(get_val), GRAPH_SUCCESS); - EXPECT_EQ(get_val.size(), val.size()); - for (size_t i = 0; i < val.size(); ++i) { - EXPECT_EQ(get_val[i].GetSize(), val[i].GetSize()); - } - }} - }; - - // 执行所有测试用例 - for (const auto &test_case : test_cases) { - SCOPED_TRACE("Testing: " + test_case.first); - test_case.second(); - } -} -// extern "C" wrapper for AttrValue SetAttrValue methods to avoid C++ name mangling -extern "C" { -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus aclCom_AttrValue_SetAttrValue_Tensor(void *attr_value_ptr, - const void *value) { - if (attr_value_ptr == nullptr || value == nullptr) { - return GRAPH_FAILED; - } - auto *attr_value = static_cast(attr_value_ptr); - auto *tensor = static_cast(value); - return attr_value->SetAttrValue(*tensor); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus aclCom_AttrValue_SetAttrValue_Int64(void *attr_value_ptr, - int64_t value) { - if (attr_value_ptr == nullptr) { - return GRAPH_FAILED; - } - auto *attr_value = static_cast(attr_value_ptr); - return attr_value->SetAttrValue(value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus aclCom_AttrValue_SetAttrValue_String(void *attr_value_ptr, - const char_t *value) { - if (attr_value_ptr == nullptr || value == nullptr) { - return GRAPH_FAILED; - } - auto *attr_value = static_cast(attr_value_ptr); - return attr_value->SetAttrValue(ge::AscendString(value)); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus aclCom_AttrValue_SetAttrValue_Bool(void *attr_value_ptr, - bool value) { - if (attr_value_ptr == nullptr) { - return GRAPH_FAILED; - } - auto *attr_value = static_cast(attr_value_ptr); - return attr_value->SetAttrValue(value); -} -} -TEST_F(UtestGeAttrValue, ExternC_AttrValue_SetAttrValue_Tensor_Success) { - AttrValue attr_value; - Tensor tensor_value; - - // 测试成功情况 - EXPECT_EQ(aclCom_AttrValue_SetAttrValue_Tensor(&attr_value, &tensor_value), GRAPH_SUCCESS); - - // 测试nullptr参数 - EXPECT_EQ(aclCom_AttrValue_SetAttrValue_Tensor(nullptr, &tensor_value), GRAPH_FAILED); -} - -TEST_F(UtestGeAttrValue, ExternC_AttrValue_SetAttrValue_Int64_Success) { - AttrValue attr_value; - int64_t int64_value = 12345; - - // 测试成功情况 - EXPECT_EQ(aclCom_AttrValue_SetAttrValue_Int64(&attr_value, int64_value), GRAPH_SUCCESS); - - // 验证设置的值 - int64_t get_value = 0; - EXPECT_EQ(attr_value.GetAttrValue(get_value), GRAPH_SUCCESS); - EXPECT_EQ(get_value, int64_value); - - // 测试nullptr参数 - EXPECT_EQ(aclCom_AttrValue_SetAttrValue_Int64(nullptr, int64_value), GRAPH_FAILED); -} - -TEST_F(UtestGeAttrValue, ExternC_AttrValue_SetAttrValue_String_Success) { - AttrValue attr_value; - const char_t *char_value = "12345"; - - // 测试成功情况 - EXPECT_EQ(aclCom_AttrValue_SetAttrValue_String(&attr_value, char_value), GRAPH_SUCCESS); - - // 验证设置的值 - ge::AscendString get_value = "0"; - EXPECT_EQ(attr_value.GetAttrValue(get_value), GRAPH_SUCCESS); - EXPECT_EQ(get_value, char_value); - - // 测试nullptr参数 - EXPECT_EQ(aclCom_AttrValue_SetAttrValue_String(nullptr, char_value), GRAPH_FAILED); -} - -TEST_F(UtestGeAttrValue, ExternC_AttrValue_SetAttrValue_Int64_EdgeCases) { - AttrValue attr_value; - - // 测试边界值 - EXPECT_EQ(aclCom_AttrValue_SetAttrValue_Int64(&attr_value, std::numeric_limits::max()), GRAPH_SUCCESS); - EXPECT_EQ(aclCom_AttrValue_SetAttrValue_Int64(&attr_value, std::numeric_limits::min()), GRAPH_SUCCESS); - EXPECT_EQ(aclCom_AttrValue_SetAttrValue_Int64(&attr_value, 0), GRAPH_SUCCESS); - EXPECT_EQ(aclCom_AttrValue_SetAttrValue_Int64(&attr_value, -1), GRAPH_SUCCESS); -} - -TEST_F(UtestGeAttrValue, ExternC_AttrValue_SetAttrValue_Bool_Success) { - AttrValue attr_value; - bool bool_value = true; - - // 测试成功情况 - EXPECT_EQ(aclCom_AttrValue_SetAttrValue_Bool(&attr_value, bool_value), GRAPH_SUCCESS); - - // 验证设置的值 - bool get_value = false; - EXPECT_EQ(attr_value.GetAttrValue(get_value), GRAPH_SUCCESS); - EXPECT_EQ(get_value, bool_value); - - // 测试nullptr参数 - EXPECT_EQ(aclCom_AttrValue_SetAttrValue_Bool(nullptr, bool_value), GRAPH_FAILED); -} - -TEST_F(UtestGeAttrValue, ExternC_AttrValue_SetAttrValue_Bool_EdgeCases) { - AttrValue attr_value; - - // 测试true值 - EXPECT_EQ(aclCom_AttrValue_SetAttrValue_Bool(&attr_value, true), GRAPH_SUCCESS); - bool get_value = false; - EXPECT_EQ(attr_value.GetAttrValue(get_value), GRAPH_SUCCESS); - EXPECT_EQ(get_value, true); - - // 测试false值 - EXPECT_EQ(aclCom_AttrValue_SetAttrValue_Bool(&attr_value, false), GRAPH_SUCCESS); - EXPECT_EQ(attr_value.GetAttrValue(get_value), GRAPH_SUCCESS); - EXPECT_EQ(get_value, false); -} -} diff --git a/tests/ut/graph/testcase/ge_context_unittest.cc b/tests/ut/graph/testcase/ge_context_unittest.cc deleted file mode 100644 index 94193af28b323f724bcc54237c88583d03c05ac6..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/ge_context_unittest.cc +++ /dev/null @@ -1,207 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include "test_structs.h" -#include "func_counter.h" -#include "graph/ge_context.h" -#include "graph/ge_global_options.h" -#include "graph/ge_local_context.h" -#include "graph/node.h" -#include "graph_builder_utils.h" -#include "external/ge_common/ge_api_types.h" -#include "register/optimization_option_registry.h" -#include "nlohmann/json.hpp" - -namespace ge { -using Json = nlohmann::json; -class GeContextUt : public testing::Test {}; - -void SetIROptionToShowName(std::string &option_name_map, const std::string &ir_option, const std::string &show_name) { - std::string json = "\"" + ir_option + "\": \"" + show_name +"\",\n"; - option_name_map += json; - return; -} - -TEST_F(GeContextUt, All) { - ge::GEContext cont = GetContext(); - cont.Init(); - EXPECT_EQ(cont.GetHostExecFlag(), false); - EXPECT_EQ(cont.IsOverflowDetectionOpen(), false); - EXPECT_EQ(GetMutableGlobalOptions().size(), 0); - EXPECT_EQ(cont.SessionId(), 0); - EXPECT_EQ(cont.DeviceId(), 0); - EXPECT_EQ(cont.GetInputFusionSize(), 128 * 1024U); - - cont.SetSessionId(1); - cont.SetContextId(2); - cont.SetCtxDeviceId(4); - EXPECT_EQ(cont.SessionId(), 1); - EXPECT_EQ(cont.DeviceId(), 4); - - EXPECT_EQ(cont.StreamSyncTimeout(), -1); - cont.SetStreamSyncTimeout(10000); - EXPECT_EQ(cont.StreamSyncTimeout(), 10000); - - EXPECT_EQ(cont.EventSyncTimeout(), -1); - cont.SetEventSyncTimeout(20000); - EXPECT_EQ(cont.EventSyncTimeout(), 20000); -} - -TEST_F(GeContextUt, Plus) { - std::map session_option{{"ge.exec.placement", "ge.exec.placement"}}; - GetThreadLocalContext().SetSessionOption(session_option); - - std::string exec_placement; - GetThreadLocalContext().GetOption("ge.exec.placement", exec_placement); - EXPECT_EQ(exec_placement, "ge.exec.placement"); - ge::GEContext cont = GetContext(); - EXPECT_EQ(cont.GetHostExecFlag(), false); - - std::map session_option0{{"ge.graphLevelSat", "1"}}; - GetThreadLocalContext().SetSessionOption(session_option0); - EXPECT_EQ(cont.IsGraphLevelSat(), true); - - std::map session_option1{{"ge.exec.overflow", "1"}}; - GetThreadLocalContext().SetSessionOption(session_option1); - EXPECT_EQ(cont.IsOverflowDetectionOpen(), true); - - std::map session_option2{{"ge.exec.sessionId", "12345678987654321"}}; - GetThreadLocalContext().SetSessionOption(session_option2); - cont.Init(); - std::map session_option3{{"ge.exec.deviceId", "12345678987654321"}}; - GetThreadLocalContext().SetSessionOption(session_option3); - cont.Init(); - std::map session_option4{{"ge.exec.jobId", "12345"}}; - GetThreadLocalContext().SetSessionOption(session_option4); - cont.Init(); - std::map session_option5{{"ge.exec.jobId", "65536"}}; - GetThreadLocalContext().SetSessionOption(session_option5); - cont.Init(); - - // 32 * 1024 * 1024 = 33554432, max - std::map session_option6{{OPTION_EXEC_INPUT_FUSION_SIZE, "33554432"}}; - GetThreadLocalContext().SetSessionOption(session_option6); - EXPECT_EQ(cont.GetInputFusionSize(), 32 * 1024 * 1024U); - - // 32 * 1024 * 1024 + 1 = 33554433, bigger than max - std::map session_option7{{OPTION_EXEC_INPUT_FUSION_SIZE, "33554433"}}; - GetThreadLocalContext().SetSessionOption(session_option7); - EXPECT_EQ(cont.GetInputFusionSize(), 32 * 1024 * 1024U); - - // invalid value, -1 - std::map session_option8{{OPTION_EXEC_INPUT_FUSION_SIZE, "-1"}}; - GetThreadLocalContext().SetSessionOption(session_option8); - EXPECT_EQ(cont.GetInputFusionSize(), 0U); - - // value : 25600 - std::map session_option9{{OPTION_EXEC_INPUT_FUSION_SIZE, "25600"}}; - GetThreadLocalContext().SetSessionOption(session_option9); - EXPECT_EQ(cont.GetInputFusionSize(), 25600U); - - // value: 0 - std::map session_option10{{OPTION_EXEC_INPUT_FUSION_SIZE, "0"}}; - GetThreadLocalContext().SetGraphOption(session_option10); - EXPECT_EQ(cont.GetInputFusionSize(), 0U); -} - -TEST_F(GeContextUt, set_valid_SyncTimeout_from_option) { - std::map session_option{{"ge.exec.sessionId", "0"}, - {"ge.exec.deviceId", "1"}, - {"ge.exec.jobId", "2"}, - {"stream_sync_timeout", "10000"}, - {"event_sync_timeout", "20000"}}; - GetThreadLocalContext().SetSessionOption(session_option); - ge::GEContext ctx = GetContext(); - ctx.Init(); - EXPECT_EQ(ctx.StreamSyncTimeout(), 10000); - EXPECT_EQ(ctx.EventSyncTimeout(), 20000); -} - -TEST_F(GeContextUt, set_invalid_option) { - std::map session_option{{"ge.exec.sessionId", "-1"}, - {"ge.exec.deviceId", "-1"}, - {"ge.exec.jobId", "-1"}, - {"stream_sync_timeout", ""}, - {"event_sync_timeout", ""}}; - GetThreadLocalContext().SetSessionOption(session_option); - ge::GEContext ctx = GetContext(); - ctx.Init(); - EXPECT_EQ(ctx.SessionId(), 0U); - EXPECT_EQ(ctx.DeviceId(), 0U); - EXPECT_EQ(ctx.StreamSyncTimeout(), -1); - EXPECT_EQ(ctx.EventSyncTimeout(), -1); -} - -TEST_F(GeContextUt, set_OutOfRange_SyncTimeout_from_option) { - std::map session_option{{"ge.exec.sessionId", "1234567898765432112345"}, - {"ge.exec.deviceId", "1234567898765432112345"}, - {"ge.exec.jobId", "1234567898765432112345"}, - {"stream_sync_timeout", "1234567898765432112345"}, - {"event_sync_timeout", "1234567898765432112345"}}; - GetThreadLocalContext().SetSessionOption(session_option); - ge::GEContext ctx = GetContext(); - ctx.Init(); - EXPECT_EQ(ctx.SessionId(), 0U); - EXPECT_EQ(ctx.DeviceId(), 0U); - EXPECT_EQ(ctx.StreamSyncTimeout(), -1); - EXPECT_EQ(ctx.EventSyncTimeout(), -1); -} -TEST_F(GeContextUt, reset_option_name_map) { - ge::GEContext ctx = GetContext(); - Json option_name_map; - option_name_map.emplace("ge.enableSmallChannel", "enable_small_channel"); - EXPECT_EQ(ctx.SetOptionNameMap(option_name_map.dump()), ge::GRAPH_SUCCESS); - EXPECT_EQ(ctx.SetOptionNameMap(option_name_map.dump()), ge::GRAPH_SUCCESS); -} - -TEST_F(GeContextUt, set_invalid_option_name_map) { - std::string option_name_map_1 = ""; - GetThreadLocalContext() = GEThreadLocalContext(); - ge::GEContext ctx = GetContext(); - EXPECT_EQ(ctx.SetOptionNameMap(option_name_map_1), ge::GRAPH_FAILED); - std::string show_name; - show_name = ctx.GetReadableName("ge.enableSmallChannel"); - EXPECT_EQ(show_name, "ge.enableSmallChannel"); - show_name = ctx.GetReadableName("ge.exec.enable_exception_dump"); - EXPECT_EQ(show_name, "ge.exec.enable_exception_dump"); - show_name = ctx.GetReadableName("ge.exec.opWaitTimeout"); - EXPECT_EQ(show_name, "ge.exec.opWaitTimeout"); - - std::string option_name_map_2; - SetIROptionToShowName(option_name_map_2, "ge.enableSmallChannel", ""); - option_name_map_2 = "{\n" + option_name_map_2.substr(0, option_name_map_2.size() - 2) + "\n}"; - EXPECT_EQ(ctx.SetOptionNameMap(option_name_map_2), ge::GRAPH_FAILED); - show_name = ctx.GetReadableName("ge.enableSmallChannel"); - EXPECT_EQ(show_name, "ge.enableSmallChannel"); - - std::string option_name_map_3; - SetIROptionToShowName(option_name_map_3, "", "enable_small_channel"); - option_name_map_2 = "{\n" + option_name_map_3.substr(0, option_name_map_3.size() - 2) + "\n}"; - EXPECT_EQ(ctx.SetOptionNameMap(option_name_map_3), ge::GRAPH_FAILED); - show_name = ctx.GetReadableName("ge.enableSmallChannel"); - EXPECT_EQ(show_name, "ge.enableSmallChannel"); -} - -TEST_F(GeContextUt, GetOo_Ok) { - REG_OPTION("ge.oo.ctx_test_option").LEVELS(OoLevel::kO1).VISIBILITY(OoEntryPoint::kSession, OoEntryPoint::kIrBuild); - std::map session_option{ - {"ge.exec.sessionId", "-1"}, - {"ge.oo.level", "O1"}, - {"ge.oo.ctx_test_option", "false"}, - }; - auto &oopt = GetContext().GetOo(); - EXPECT_EQ(oopt.Initialize(session_option, OptionRegistry::GetInstance().GetRegisteredOptTable()), - GRAPH_SUCCESS); - std::string value; - EXPECT_EQ(oopt.GetValue("ge.oo.ctx_test_option", value), GRAPH_SUCCESS); - EXPECT_EQ(value, "false"); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/ge_graph_dumper_unittest.cc b/tests/ut/graph/testcase/ge_graph_dumper_unittest.cc deleted file mode 100644 index b95ae1525e9b53f06fc89fadf3b9a293c00df760..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/ge_graph_dumper_unittest.cc +++ /dev/null @@ -1,59 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/utils/dumper/ge_graph_dumper.h" -#include - -namespace ge { -class GeGraphDumperUt : public testing::Test {}; - -namespace { -int64_t dump_times = 0; -class TestDumper : public GeGraphDumper { - public: - void Dump(const ComputeGraphPtr &graph, const string &suffix) override { - ++dump_times; - } -}; -} - -TEST_F(GeGraphDumperUt, DefaultImpl) { - dump_times = 0; - GraphDumperRegistry::Unregister(); - GraphDumperRegistry::GetDumper().Dump(nullptr, "test"); - EXPECT_EQ(dump_times, 0); -} - -TEST_F(GeGraphDumperUt, RegisterOk) { - dump_times = 0; - TestDumper dumper; - GraphDumperRegistry::Unregister(); - GraphDumperRegistry::Register(dumper); - GraphDumperRegistry::GetDumper().Dump(nullptr, "test"); - EXPECT_EQ(dump_times, 1); - GraphDumperRegistry::GetDumper().Dump(nullptr, "test"); - EXPECT_EQ(dump_times, 2); -} - -TEST_F(GeGraphDumperUt, UnregisterOk) { - dump_times = 0; - TestDumper dumper; - GraphDumperRegistry::Register(dumper); - GraphDumperRegistry::GetDumper().Dump(nullptr, "test"); - EXPECT_EQ(dump_times, 1); - GraphDumperRegistry::GetDumper().Dump(nullptr, "test"); - EXPECT_EQ(dump_times, 2); - - GraphDumperRegistry::Unregister(); - GraphDumperRegistry::GetDumper().Dump(nullptr, "test"); - EXPECT_EQ(dump_times, 2); - GraphDumperRegistry::GetDumper().Dump(nullptr, "test"); - EXPECT_EQ(dump_times, 2); -} -} diff --git a/tests/ut/graph/testcase/ge_ir_utils_unittest.cc b/tests/ut/graph/testcase/ge_ir_utils_unittest.cc deleted file mode 100644 index b5d7d407d40bd4496f37ecc036d897829e103027..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/ge_ir_utils_unittest.cc +++ /dev/null @@ -1,338 +0,0 @@ -// /* Copyright (c) 2024 Huawei Technologies Co., Ltd. -// * This file is a part of the CANN Open Software. -// * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). -// * Please refer to the License for details. You may not use this file except in compliance with the License. -// * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, -// * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. -// * See LICENSE in the root of the software repository for the full text of the License. -// * ===================================================================================================================*/ -// -// #include -// -// #include "graph/utils/ge_ir_utils.h" -// #include "graph/utils/attr_utils.h" -// #include "graph/utils/graph_utils_ex.h" -// #include "graph/op_desc.h" -// #include "graph/compute_graph.h" -// #include "graph_builder_utils.h" -// #include "graph/node.h" -// #include "graph/normal_graph/node_impl.h" -// #include "test_std_structs.h" -// #include "graph/attribute_group/attr_group_shape_env.h" -// #include "graph/attribute_group/attr_group_symbolic_desc.h" -// -// namespace ge { -// namespace { -// ComputeGraphPtr BuildGraph0() { -// auto builder = ut::GraphBuilder("root"); -// const auto &data1 = builder.AddNode("data1", "Data", 1, 1); -// const auto &data2 = builder.AddNode("data2", "Data", 1, 1); -// const auto &add1 = builder.AddNode("add1", "AddN", 20, 1); -// builder.AddDataEdge(data1, 0, add1, 0); -// builder.AddDataEdge(data2, 0, add1, 1); -// builder.AddDataEdge(data2, 0, add1, 2); -// builder.AddDataEdge(data2, 0, add1, 3); -// builder.AddDataEdge(data2, 0, add1, 4); -// return builder.GetGraph(); -// } -// } -// static ComputeGraphPtr CreateGraph_1_1_224_224(float *tensor_data) { -// ut::GraphBuilder builder("graph1"); -// auto data1 = builder.AddNode("data1", "Data", {}, {"y"}); -// AttrUtils::SetInt(data1->GetOpDesc(), "index", 0); -// AttrUtils::SetFloat(data1->GetOpDesc(), "index2", 1.0f); -// auto const1 = builder.AddNode("const1", "Const", {}, {"y"}); -// GeTensorDesc const1_td; -// const1_td.SetShape(GeShape({1, 1, 224, 224})); -// const1_td.SetOriginShape(GeShape({1, 1, 224, 224})); -// const1_td.SetFormat(FORMAT_NCHW); -// const1_td.SetOriginFormat(FORMAT_NCHW); -// const1_td.SetDataType(DT_FLOAT); -// const1_td.SetOriginDataType(DT_FLOAT); -// GeTensor tensor(const1_td); -// tensor.SetData(reinterpret_cast(tensor_data), sizeof(float) * 224 * 224); -// AttrUtils::SetTensor(const1->GetOpDesc(), "value", tensor); -// auto add1 = builder.AddNode("add1", "Add", {"x1", "x2"}, {"y"}); -// auto netoutput1 = builder.AddNode("NetOutputNode", "NetOutput", {"x"}, {}); -// ge::AttrUtils::SetListListInt(add1->GetOpDesc()->MutableOutputDesc(0), "list_list_i", {{1, 0, 0, 0}}); -// ge::AttrUtils::SetListInt(add1->GetOpDesc(), "list_i", {1}); -// ge::AttrUtils::SetListStr(add1->GetOpDesc(), "list_s", {"1"}); -// ge::AttrUtils::SetListFloat(add1->GetOpDesc(), "list_f", {1.0}); -// ge::AttrUtils::SetListBool(add1->GetOpDesc(), "list_b", {false}); -// builder.AddDataEdge(data1, 0, add1, 0); -// builder.AddDataEdge(const1, 0, add1, 1); -// builder.AddDataEdge(add1, 0, netoutput1, 0); -// -// return builder.GetGraph(); -// } -// -// class GeIrUtilsUt : public testing::Test {}; -// -// TEST_F(GeIrUtilsUt, ModelSerialize) { -// ge::Model model1("model", ""); -// ut::GraphBuilder builder("void"); -// auto data_node = builder.AddNode("data", "Data", {}, {"y"}); -// auto add_node = builder.AddNode("add", "Add", {}, {"y"}); -// float tensor_data[224 * 224] = {1.0f}; -// ComputeGraphPtr compute_graph = CreateGraph_1_1_224_224(tensor_data); -// compute_graph->AddInputNode(data_node); -// compute_graph->AddOutputNode(add_node); -// model1.SetGraph(compute_graph); -// onnx::ModelProto model_proto; -// EXPECT_TRUE(OnnxUtils::ConvertGeModelToModelProto(model1, model_proto)); -// ge::Model model2; -// EXPECT_TRUE(ge::IsEqual(std::string("test"), std::string("test"), "tag")); -// EXPECT_FALSE(ge::IsEqual(300, 20, "tag")); -// } -// -// TEST_F(GeIrUtilsUt, ModelSerializeSetSubgraphs) { -// ge::Model model1("model", ""); -// ut::GraphBuilder builder("test0"); -// auto data_node = builder.AddNode("data", "Data", {}, {"y"}); -// auto add_node = builder.AddNode("add", "Add", {}, {"y"}); -// auto graph = builder.GetGraph(); -// -// ut::GraphBuilder sub_builder("sub1"); -// auto sub_graph_1 = sub_builder.GetGraph(); -// std::vector> subgraphs; -// subgraphs.push_back(sub_graph_1); -// -// graph->SetAllSubgraphs(subgraphs); -// model1.SetGraph(graph); -// onnx::ModelProto model_proto; -// bool ret = OnnxUtils::ConvertGeModelToModelProto(model1, model_proto); -// EXPECT_EQ(ret, true); -// } -// -// TEST_F(GeIrUtilsUt, EncodeDataTypeUndefined) { -// DataType data_type = DT_DUAL; -// int ret = OnnxUtils::EncodeDataType(data_type); -// EXPECT_EQ(ret, onnx::TensorProto_DataType_UNDEFINED); -// } -// -// TEST_F(GeIrUtilsUt, EncodeNodeDescFail) { -// NodePtr node; -// onnx::NodeProto *node_proto = nullptr; -// bool ret = OnnxUtils::EncodeNodeDesc(node, node_proto); -// EXPECT_EQ(ret, false); -// } -// -// TEST_F(GeIrUtilsUt, EncodeGraphFail) { -// ConstComputeGraphPtr graph; -// onnx::GraphProto *graph_proto = nullptr; -// bool ret = OnnxUtils::EncodeGraph(graph, graph_proto); -// EXPECT_EQ(ret, false); -// } -// -// TEST_F(GeIrUtilsUt, EncodeNodeFail) { -// NodePtr node; -// onnx::NodeProto *node_proto = nullptr; -// bool ret = OnnxUtils::EncodeNode(node, node_proto); -// EXPECT_EQ(ret, false); -// } -// -// TEST_F(GeIrUtilsUt, EncodeNodeLinkFail) { -// NodePtr node; -// onnx::NodeProto *node_proto = nullptr; -// bool ret = OnnxUtils::EncodeNodeLink(node, node_proto); -// EXPECT_EQ(ret, false); -// } -// -// TEST_F(GeIrUtilsUt, ConvertGeModelToModelProtoFail) { -// ge::Model model; -// onnx::ModelProto model_proto; -// bool ret = OnnxUtils::ConvertGeModelToModelProto(model, model_proto); -// EXPECT_EQ(ret, false); -// } -// -// TEST_F(GeIrUtilsUt, ConvertGeModelToModelProtoGraphProtoIsNull) { -// ge::Model model("model", ""); -// ComputeGraphPtr compute_graph; -// model.SetGraph(compute_graph); -// onnx::ModelProto model_proto; -// bool ret = OnnxUtils::ConvertGeModelToModelProto(model, model_proto); -// EXPECT_EQ(ret, false); -// } -// -// TEST_F(GeIrUtilsUt, DecodeNodeLinkImpFail) { -// OnnxUtils::NodeLinkInfo item; -// NodePtr node_ptr; -// bool ret = OnnxUtils::DecodeNodeLinkImp(item, node_ptr); -// EXPECT_EQ(ret, false); -// } -// -// TEST_F(GeIrUtilsUt, DecodeNodeDescFail) { -// onnx::NodeProto *node_proto = nullptr; -// OpDescPtr op_desc; -// bool ret = OnnxUtils::DecodeNodeDesc(node_proto, op_desc); -// EXPECT_EQ(ret, false); -// } -// -// TEST_F(GeIrUtilsUt, DecodeGraphFail) { -// int32_t recursion_depth = 20; -// onnx::GraphProto graph_proto; -// ComputeGraphPtr graph; -// bool ret = OnnxUtils::DecodeGraph(recursion_depth, graph_proto, graph); -// EXPECT_EQ(ret, false); -// } -// -// TEST_F(GeIrUtilsUt, DecodeNodeLinkImpGetDataAnchorFail) { -// auto builder = ut::GraphBuilder("test1"); -// const auto &node1 = builder.AddNode("node1", "node", 1, 1); -// OnnxUtils::NodeLinkInfo item("node0", 1, node1, 1, "node1"); -// bool ret = OnnxUtils::DecodeNodeLinkImp(item, node1); -// EXPECT_EQ(ret, false); -// } -// -// TEST_F(GeIrUtilsUt, DecodeNodeLinkImpGetDataAnchorTrue) { -// auto builder = ut::GraphBuilder("test1"); -// const auto &node1 = builder.AddNode("node1", "Data", 1, 1); -// const auto &node2 = builder.AddNode("node2", "NetOutput", 1, 0); -// OnnxUtils::NodeLinkInfo item("node1", 0, node2, 0, "node2"); -// bool ret = OnnxUtils::DecodeNodeLinkImp(item, node1); -// EXPECT_EQ(ret, true); -// } -// -// TEST_F(GeIrUtilsUt, DecodeNodeLinkImpGetDataAnchorImplIsNull) { -// auto builder = ut::GraphBuilder("test1"); -// const auto &node1 = builder.AddNode("node1", "Data", 1, 1); -// const auto &node2 = builder.AddNode("node2", "NetOutput", 1, 0); -// OnnxUtils::NodeLinkInfo item("node1", 0, node2, 5, "node2"); -// bool ret = OnnxUtils::DecodeNodeLinkImp(item, node1); -// EXPECT_EQ(ret, false); -// } -// -// TEST_F(GeIrUtilsUt, DecodeNodeLinkImpGetControlAnchorTrue) { -// auto builder = ut::GraphBuilder("test1"); -// const auto &node1 = builder.AddNode("node1", "Data", 1, 1); -// const auto &node2 = builder.AddNode("node2", "NetOutput", 1, 0); -// OnnxUtils::NodeLinkInfo item("node1", -1, node2, 0, "node2"); -// bool ret = OnnxUtils::DecodeNodeLinkImp(item, node1); -// EXPECT_EQ(ret, true); -// } -// -// TEST_F(GeIrUtilsUt, DecodeGraphTrue) { -// int32_t recursion_depth = 10; -// auto builder = ut::GraphBuilder("test0"); -// const auto &node1 = builder.AddNode("node1", "Data", 1, 1); -// const auto &node2 = builder.AddNode("node2", "NetOutput", 1, 0); -// auto graph = builder.GetGraph(); -// onnx::GraphProto graph_proto; -// bool ret = OnnxUtils::DecodeGraph(recursion_depth, graph_proto, graph); -// EXPECT_EQ(ret, true); -// } -// -// TEST_F(GeIrUtilsUt, AddInputAndOutputNodesForGraphAddInputNodeFail) { -// auto builder = ut::GraphBuilder("test0"); -// const auto &node1 = builder.AddNode("node1", "Data", 1, 1); -// const auto &node2 = builder.AddNode("node2", "NetOutput", 1, 0); -// auto graph = builder.GetGraph(); -// onnx::GraphProto graph_proto; -// graph_proto.add_input(); -// std::map node_map; -// node_map.insert(pair("node2", node2)); -// bool ret = OnnxUtils::AddInputAndOutputNodesForGraph(graph_proto, graph, node_map); -// EXPECT_EQ(ret, false); -// } -// -// TEST_F(GeIrUtilsUt, AddInputAndOutputNodesForGraphAddOutputNodeFail) { -// auto builder = ut::GraphBuilder("test0"); -// const auto &node1 = builder.AddNode("node1", "Data", 1, 1); -// const auto &node2 = builder.AddNode("node2", "NetOutput", 1, 0); -// auto graph = builder.GetGraph(); -// onnx::GraphProto graph_proto; -// graph_proto.add_output(); -// std::map node_map; -// node_map.insert(pair("node1", node1)); -// bool ret = OnnxUtils::AddInputAndOutputNodesForGraph(graph_proto, graph, node_map); -// EXPECT_EQ(ret, false); -// } -// -// TEST_F(GeIrUtilsUt, DecodeNodeLinkDstNodeIsNull) { -// auto builder = ut::GraphBuilder("test0"); -// const auto &node1 = builder.AddNode("node1", "Data", 1, 1); -// const auto &node2 = builder.AddNode("node2", "NetOutput", 1, 0); -// auto graph = builder.GetGraph(); -// onnx::NodeProto node_proto; -// node_proto.add_input(); -// std::vector node_proto_vector; -// node_proto_vector.push_back(node_proto); -// std::map node_map; -// node_map.insert(pair("node2", node2)); -// bool ret = OnnxUtils::DecodeNodeLink(node_proto_vector, node_map); -// EXPECT_EQ(ret, false); -// } -// -// TEST_F(GeIrUtilsUt, DecodeAttributeAttrProtoTypeIsNotStrings) { -// ge::onnx::AttributeProto attr_proto; -// std::vector strings; -// strings.push_back("node1"); -// OnnxUtils::DecodeAttribute(attr_proto, strings); -// EXPECT_EQ(strings.size(), 1); -// } -// -// TEST_F(GeIrUtilsUt, DecodeAttributeAttrProtoTypeIsNotInts) { -// ge::onnx::AttributeProto attr_proto; -// std::vector ints; -// ints.push_back(1); -// ints.push_back(2); -// OnnxUtils::DecodeAttribute(attr_proto, ints); -// EXPECT_EQ(ints.size(), 2); -// } -// -// TEST_F(GeIrUtilsUt, DecodeAttributeAttrProtoTypeIsNotInt) { -// ge::onnx::AttributeProto attr_proto; -// int64_t value = 1; -// OnnxUtils::DecodeAttribute(attr_proto, value); -// EXPECT_TRUE(value == 1); -// } -// -// TEST_F(GeIrUtilsUt, DecodeAttributeAttrProtoTypeIsNotString) { -// ge::onnx::AttributeProto attr_proto; -// std::string value = "1"; -// OnnxUtils::DecodeAttribute(attr_proto, value); -// EXPECT_EQ(value, "1"); -// } -// -// TEST_F(GeIrUtilsUt, OnnxDumpCheck) { -// auto graph = BuildGraph0(); -// auto add1 = graph->FindNode("add1"); -// auto shape_env_adder = [](auto attr_group_holder) { -// attr_group_holder->template GetOrCreateAttrsGroup(); -// }; -// auto symbolic_adder = [](auto attr_group_holder) { -// attr_group_holder->template GetOrCreateAttrsGroup(); -// }; -// auto normal_attr_adder = [](auto attr_holder) { -// AttrUtils::SetStr(attr_holder, "normal_attr", "for_test"); -// }; -// shape_env_adder(add1->GetOpDesc()); -// symbolic_adder(add1->GetOpDesc()); -// shape_env_adder(add1->GetOpDesc()->MutableInputDesc(0)); -// symbolic_adder(add1->GetOpDesc()->MutableInputDesc(0)); -// shape_env_adder(add1->GetOpDesc()->MutableInputDesc(1)); -// normal_attr_adder(add1->GetOpDesc()->MutableInputDesc(1)); -// ge::onnx::NodeProto node_proto; -// EXPECT_TRUE(OnnxUtils::EncodeNode(add1, &node_proto)); -// bool tensor_structured_show = false; -// for (const auto &attr : node_proto.attribute()) { -// if (attr.name().find("input_desc_0") == 0U) { -// tensor_structured_show = true; -// const std::string kExpected = R"PROTO(name: "input_desc_0" -// s: "{\"attr_groups\":\"attr_group_def {\\n shape_env_attr_group {\\n shape_setting {\\n }\\n }\\n}\\nattr_group_def {\\n tensor_attr_group {\\n }\\n}\\n\",\"cmps_size\":0,\"cmps_tab\":\"\",\"cmps_tab_offset\":0,\"data_offset\":0,\"device_type\":\"NPU\",\"dtype\":\"DT_FLOAT\",\"input_tensor\":0,\"layout\":\"NCHW\",\"origin_dtype\":\"DT_FLOAT\",\"origin_layout\":\"NCHW\",\"origin_shape\":[1,1,224,224],\"output_tensor\":0,\"real_dim_cnt\":0,\"reuse_input\":0,\"shape\":[1,1,224,224],\"size\":0,\"weight_size\":0}" -// type: STRING -// )PROTO"; -// EXPECT_EQ(attr.DebugString(), kExpected); -// } -// if (attr.name() == "input_desc_1") { -// tensor_structured_show = true; -// const std::string kExpected = R"PROTO(name: "input_desc_1" -// s: "{\"attr_groups\":\"attr_group_def {\\n shape_env_attr_group {\\n shape_setting {\\n }\\n }\\n}\\n\",\"cmps_size\":0,\"cmps_tab\":\"\",\"cmps_tab_offset\":0,\"data_offset\":0,\"device_type\":\"NPU\",\"dtype\":\"DT_FLOAT\",\"input_tensor\":0,\"layout\":\"NCHW\",\"normal_attr\":\"s: \\\"for_test\\\"\\n\",\"origin_dtype\":\"DT_FLOAT\",\"origin_layout\":\"NCHW\",\"origin_shape\":[1,1,224,224],\"output_tensor\":0,\"real_dim_cnt\":0,\"reuse_input\":0,\"shape\":[1,1,224,224],\"size\":0,\"weight_size\":0}" -// type: STRING -// )PROTO"; -// EXPECT_EQ(attr.DebugString(), kExpected); -// } -// } -// EXPECT_TRUE(tensor_structured_show); -// } -// } diff --git a/tests/ut/graph/testcase/ge_local_context_unittest.cc b/tests/ut/graph/testcase/ge_local_context_unittest.cc deleted file mode 100644 index faffb63af4215c5c36e43598c4b34173e324fd2b..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/ge_local_context_unittest.cc +++ /dev/null @@ -1,130 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/ge_local_context.h" - -namespace ge { -class UtestGeLocalContext : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -TEST_F(UtestGeLocalContext, GetAllGraphOptionsTest) { - GEThreadLocalContext ge_local_context; - std::map graph_maps; - std::string key1 = "333"; - std::string value1 = "cccc"; - std::string key2 = "444"; - std::string value2 = "ddd"; - graph_maps.insert(std::make_pair(key1, value1)); - graph_maps.insert(std::make_pair(key2, value2)); - ge_local_context.SetGraphOption(graph_maps); - - std::map graph_options_; - graph_options_ = ge_local_context.GetAllGraphOptions(); - std::string ret_value1 = graph_options_[key1]; - EXPECT_EQ(ret_value1, "cccc"); - std::string ret_value2 = graph_options_[key2]; - EXPECT_EQ(ret_value2, "ddd"); -} - -TEST_F(UtestGeLocalContext, GetAllOptionsTest) { - GEThreadLocalContext ge_local_context; - std::map global_maps; - std::string global_key1 = "111"; - std::string global_value1 = "aaa"; - std::string global_key2 = "222"; - std::string global_value2 = "bbb"; - global_maps.insert(std::make_pair(global_key1, global_value1)); - global_maps.insert(std::make_pair(global_key2, global_value2)); - ge_local_context.SetGlobalOption(global_maps); - - std::map session_maps; - std::string session_key1 = "333"; - std::string session_value1 = "ccc"; - std::string session_key2 = "444"; - std::string session_value2 = "ddd"; - session_maps.insert(std::make_pair(session_key1, session_value1)); - session_maps.insert(std::make_pair(session_key2, session_value2)); - ge_local_context.SetSessionOption(session_maps); - - std::map graph_maps; - std::string graph_key1 = "555"; - std::string graph_value1 = "eee"; - std::string graph_key2 = "666"; - std::string graph_value2 = "fff"; - graph_maps.insert(std::make_pair(graph_key1, graph_value1)); - graph_maps.insert(std::make_pair(graph_key2, graph_value2)); - ge_local_context.SetGraphOption(graph_maps); - - std::map options_all; - options_all = ge_local_context.GetAllOptions(); - std::string ret_value1 = options_all["222"]; - EXPECT_EQ(ret_value1, "bbb"); - std::string ret_value2 = options_all["444"]; - EXPECT_EQ(ret_value2, "ddd"); - std::string ret_value3 = options_all["555"]; - EXPECT_EQ(ret_value3, "eee"); -} - -TEST_F(UtestGeLocalContext, GetGraphOptionSuccess) { - GEThreadLocalContext ge_local_context; - std::map graph_maps; - std::string graph_key1 = "test1"; - std::string graph_value1 = "node1"; - std::string graph_key2 = "test2"; - std::string graph_value2 = "node2"; - graph_maps.insert(std::make_pair(graph_key1, graph_value1)); - graph_maps.insert(std::make_pair(graph_key2, graph_value2)); - ge_local_context.SetGraphOption(graph_maps); - - std::string find_key = "test1"; - std::string option; - int ret = ge_local_context.GetOption(find_key, option); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGeLocalContext, GetGlobalOptionSuccess) { - GEThreadLocalContext ge_local_context; - std::map graph_maps; - std::string global_key1 = "global1"; - std::string global_value1 = "node1"; - std::string global_key2 = "global2"; - std::string global_value2 = "node2"; - graph_maps.insert(std::make_pair(global_key1, global_value1)); - graph_maps.insert(std::make_pair(global_key2, global_value2)); - ge_local_context.SetGlobalOption(graph_maps); - - std::string find_key = "global1"; - std::string option; - int ret = ge_local_context.GetOption(find_key, option); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGeLocalContext, StreamSyncTimeoutIsInvalid) { - std::map global_options; - global_options.insert(std::make_pair("stream_sync_timeout", "aaaaaaaaaa")); - GetThreadLocalContext().SetGlobalOption(global_options); - EXPECT_EQ(GetThreadLocalContext().StreamSyncTimeout(), -1); -} - -TEST_F(UtestGeLocalContext, GetReableNameSuccess) { - GetThreadLocalContext() = GEThreadLocalContext(); - std::map global_options; - std::string ir_option = "ge.Test"; - std::string show_name = "--test"; - std::string json = "{\"" + ir_option + "\": \"" + show_name +"\"\n}"; - global_options["ge.optionNameMap"] = json; - GetThreadLocalContext().SetGlobalOption(global_options); - auto readable_name = GetThreadLocalContext().GetReadableName(ir_option); - EXPECT_EQ(readable_name, show_name); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/ge_tensor_unittest.cc b/tests/ut/graph/testcase/ge_tensor_unittest.cc deleted file mode 100644 index 65eaefeee0e82aee5d87aa7b23fb26f854106ec9..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/ge_tensor_unittest.cc +++ /dev/null @@ -1,500 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include - -#include "graph/ge_tensor.h" -#include "ge_ir.pb.h" -#include "graph/ge_attr_value.h" -#include "graph/tensor.h" -#include "graph/utils/tensor_utils.h" -#include "graph/utils/type_utils.h" -#include "graph/normal_graph/ge_tensor_impl.h" - -using namespace std; -using namespace ge; - -class UtestGeTensor : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -TEST_F(UtestGeTensor, origin_shape_format) { - GeTensorDesc a; - GeShape shape({1, 2, 3, 4}); - a.SetOriginShape(shape); - a.SetOriginFormat(FORMAT_NCHW); - EXPECT_EQ(a.GetOriginShape().GetShapeSize(), 24); - EXPECT_EQ(a.GetOriginFormat(), FORMAT_NCHW); -} - -TEST_F(UtestGeTensor, get_shape_size) { - vector vec2{-1, 1, 2, 4}; - Shape shape2(vec2); - shape2.GetShapeSize(); - - vector vec3{-1, 2, 4, INT64_MAX}; - Shape shape3(vec3); - shape3.GetShapeSize(); - - vector vec4{-1, 2, 4, INT64_MAX}; - Shape shape4(vec4); - shape4.GetShapeSize(); - - vector vec1{1, 2, 3, 4}; - Shape shape1(vec1); - EXPECT_EQ(shape1.GetShapeSize(), 24); -} - -TEST_F(UtestGeTensor, TestEmptyTensor) { - vector vec1{0}; - GeShape shape1(vec1); - EXPECT_EQ(shape1.IsEmptyTensor(), true); - - vector vec2{1, 2, 3, 4}; - GeShape shape2(vec2); - EXPECT_EQ(shape2.IsEmptyTensor(), false); - - vector vec3{1, 2, 3, 0}; - GeShape shape3(vec3); - EXPECT_EQ(shape3.IsEmptyTensor(), true); -} - -TEST_F(UtestGeTensor, shape) { - GeShape a; - EXPECT_EQ(a.GetDim(0), 0); - EXPECT_EQ(a.GetShapeSize(), 0); - EXPECT_EQ(a.SetDim(0, 0), GRAPH_FAILED); - - vector vec({1, 2, 3, 4}); - GeShape b(vec); - GeShape c({1, 2, 3, 4}); - EXPECT_EQ(c.GetDimNum(), 4); - EXPECT_EQ(c.GetDim(2), 3); - EXPECT_EQ(c.GetDim(5), 0); - EXPECT_EQ(c.SetDim(10, 0), GRAPH_FAILED); - - EXPECT_EQ(c.SetDim(2, 2), GRAPH_SUCCESS); - EXPECT_EQ(c.GetDim(2), 2); - vector vec1 = c.GetDims(); - EXPECT_EQ(c.GetDim(0), vec1[0]); - EXPECT_EQ(c.GetDim(1), vec1[1]); - EXPECT_EQ(c.GetDim(2), vec1[2]); - EXPECT_EQ(c.GetDim(3), vec1[3]); - - SmallVector vec2 = c.GetMutableDims(); - EXPECT_EQ(c.GetDim(0), vec2[0]); - EXPECT_EQ(c.GetDim(1), vec2[1]); - EXPECT_EQ(c.GetDim(2), vec2[2]); - EXPECT_EQ(c.GetDim(3), vec2[3]); - - EXPECT_EQ(c.GetShapeSize(), 16); -} - -TEST_F(UtestGeTensor, ge_shape_to_string1) { - GeShape shape1({1, 2, 3, 4}); - EXPECT_EQ(shape1.ToString(), "1,2,3,4"); - GeShape shape2; - EXPECT_EQ(shape2.ToString(), ""); -} - -TEST_F(UtestGeTensor, tensor_desc) { - GeTensorDesc a; - GeShape s({1, 2, 3, 4}); - GeTensorDesc b(s, FORMAT_NCHW); - GeShape s1 = b.GetShape(); - EXPECT_EQ(s1.GetDim(0), s.GetDim(0)); - b.MutableShape().SetDim(0, 2); - EXPECT_EQ(b.GetShape().GetDim(0), 2); - GeShape s2({3, 2, 3, 4}); - b.SetShape(s2); - EXPECT_EQ(b.GetShape().GetDim(0), 3); - - EXPECT_EQ(b.GetFormat(), FORMAT_NCHW); - b.SetFormat(FORMAT_RESERVED); - EXPECT_EQ(b.GetFormat(), FORMAT_RESERVED); - - EXPECT_EQ(b.GetDataType(), DT_FLOAT); - b.SetDataType(DT_INT8); - EXPECT_EQ(b.GetDataType(), DT_INT8); - - GeTensorDesc c; - c.Update(GeShape({1}), FORMAT_NCHW); - c.Update(s, FORMAT_NCHW); - uint32_t size1 = 1; - TensorUtils::SetSize(c, size1); - GeTensorDesc d; - d = c.Clone(); - GeTensorDesc e = c; - int64_t size2 = 0; - EXPECT_EQ(TensorUtils::GetSize(e, size2), GRAPH_SUCCESS); - EXPECT_EQ(size2, 1); - - GeTensorDesc f = c; - size2 = 0; - EXPECT_EQ(TensorUtils::GetSize(f, size2), GRAPH_SUCCESS); - EXPECT_EQ(size2, 1); - EXPECT_EQ(c.IsValid(), GRAPH_SUCCESS); - c.Update(GeShape(), FORMAT_RESERVED, DT_UNDEFINED); - EXPECT_EQ(c.IsValid(), GRAPH_PARAM_INVALID); -} - -TEST_F(UtestGeTensor, tensor) { - GeShape s({1, 2, 3, 4}); - GeTensorDesc tensor_desc(s); - std::vector data({1, 2, 3, 4}); - GeTensor a; - GeTensor b(tensor_desc); - GeTensor c(tensor_desc, data); - GeTensor d(tensor_desc, data.data(), data.size()); - - GeShape s1 = b.GetTensorDesc().GetShape(); - EXPECT_EQ(s1.GetDim(0), 1); - EXPECT_EQ(b.GetTensorDesc().GetDataType(), DT_FLOAT); - b.MutableTensorDesc().SetDataType(DT_INT8); - EXPECT_EQ(b.GetTensorDesc().GetDataType(), DT_INT8); - b.SetTensorDesc(tensor_desc); - - auto data1 = c.GetData(); - c.SetData(data); - c.SetData(data.data(), data.size()); - EXPECT_EQ(c.GetData()[0], uint8_t(1)); - EXPECT_EQ(c.GetData()[1], uint8_t(2)); - EXPECT_EQ(c.MutableData().GetData()[2], uint8_t(3)); - EXPECT_EQ(c.MutableData().GetData()[3], uint8_t(4)); - - GeTensor e(std::move(tensor_desc), std::move(data)); - EXPECT_EQ(e.GetData().GetSize(), data.size()); - EXPECT_EQ(e.GetData()[2], uint8_t(3)); - - GeTensor f = e.Clone(); - e.MutableData().data()[2] = 5; - EXPECT_EQ(e.GetData().data()[2], uint8_t(5)); - EXPECT_EQ(f.GetData().GetSize(), data.size()); - EXPECT_EQ(f.GetData()[2], uint8_t(3)); -} - -TEST_F(UtestGeTensor, test_shape_copy_move) { - GeShape shape(nullptr, nullptr); - EXPECT_EQ(shape.GetDimNum(), 0); - - GeShape shape2 = shape; - EXPECT_EQ(shape2.GetDimNum(), 0); - - GeShape shape3({1, 2, 3}); - shape2 = shape3; - EXPECT_EQ(shape2.GetDimNum(), 3); - EXPECT_EQ(shape3.GetDimNum(), 3); - - GeShape shape4 = std::move(shape3); - EXPECT_EQ(shape4.GetDimNum(), 3); - EXPECT_EQ(shape3.GetDimNum(), 3); - - GeShape shape5; - EXPECT_EQ(shape5.GetDimNum(), 0); - shape5 = std::move(shape4); - EXPECT_EQ(shape5.GetDimNum(), 3); - EXPECT_EQ(shape4.GetDimNum(), 3); -} - -TEST_F(UtestGeTensor, test_tensor_null_data) { - TensorData tensor_data; - EXPECT_EQ(tensor_data.SetData(nullptr, 1), GRAPH_SUCCESS); -} - -TEST_F(UtestGeTensor, test_tensor_null_proto) { - ProtoMsgOwner msg_owner; - GeTensor tensor(msg_owner, nullptr); - EXPECT_EQ(tensor.GetData().size(), 0); - EXPECT_EQ(tensor.MutableData().size(), 0); - EXPECT_EQ(tensor.SetData(Buffer(100)), GRAPH_SUCCESS); - - TensorUtils::SetWeightSize(tensor.MutableTensorDesc(), 100); - EXPECT_EQ(TensorUtils::GetWeightSize(tensor), 100); - - auto tensor_ptr = std::make_shared(msg_owner, nullptr); - TensorUtils::SetWeightSize(tensor_ptr->MutableTensorDesc(), 100); - EXPECT_EQ(TensorUtils::GetWeightSize(tensor_ptr), 100); - - GeTensor tensor1 = tensor; - EXPECT_EQ(TensorUtils::GetWeightSize(tensor1), 100); -} - -TEST_F(UtestGeTensor, test_tensor_utils_weight_size) { - GeTensor tensor; - EXPECT_EQ(tensor.GetData().size(), 0); - EXPECT_EQ(tensor.MutableData().size(), 0); - EXPECT_EQ(tensor.SetData(Buffer(100)), GRAPH_SUCCESS); - - TensorUtils::SetWeightSize(tensor.MutableTensorDesc(), 100); - EXPECT_EQ(TensorUtils::GetWeightSize(tensor), 100); - - uint8_t buffer[100]; - EXPECT_TRUE(TensorUtils::GetWeightAddr(tensor, buffer) != nullptr); - - auto tensor_ptr = std::make_shared(); - TensorUtils::SetWeightSize(tensor_ptr->MutableTensorDesc(), 100); - EXPECT_EQ(TensorUtils::GetWeightSize(tensor_ptr), 100); - // test weight size is larger than 2g - TensorUtils::SetWeightSize(tensor_ptr->MutableTensorDesc(), INT64_MAX - 100); - EXPECT_EQ(TensorUtils::GetWeightSize(tensor_ptr), INT64_MAX - 100); - EXPECT_TRUE(TensorUtils::GetWeightAddr(tensor_ptr, buffer) != nullptr); - - GeTensor tensor1 = tensor; - EXPECT_EQ(TensorUtils::GetWeightSize(tensor1), 100); - - GeTensor tensor2(GeTensorDesc(), Buffer(100)); - EXPECT_EQ(tensor2.GetData().size(), 100); - EXPECT_EQ(tensor2.MutableData().size(), 100); - - GeTensor tensor3; - tensor3 = tensor2; - EXPECT_EQ(tensor3.GetData().size(), 100); - EXPECT_EQ(tensor3.MutableData().size(), 100); - - TensorUtils::SetDataOffset(tensor3.MutableTensorDesc(), 20); - EXPECT_EQ(TensorUtils::GetWeightAddr(tensor3, buffer), buffer + 20); -} - -TEST_F(UtestGeTensor, test_tensor_valid) { - // Tensor(const TensorDesc &tensor_desc, const std::vector &data) - Shape shape({1, 1, 1}); - TensorDesc tensor_desc(shape); - std::vector data({1, 2, 3, 4}); - Tensor tensor1(tensor_desc, data); - EXPECT_EQ(tensor1.IsValid(), GRAPH_SUCCESS); - - // Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size) - TensorDesc tensor_desc2(Shape({3, 3, 3}), FORMAT_NCHW, DT_FLOAT); - uint32_t size2 = 3 * 3 * 3 * 4; - uint8_t data2[3 * 3 * 3 * 4] = {0}; - Tensor tensor2(tensor_desc2, data2, size2); - EXPECT_EQ(tensor2.IsValid(), GRAPH_SUCCESS); - - // Tensor(TensorDesc &&tensor_desc, std::vector &&data) - Tensor tensor3(std::move(tensor_desc), std::move(data)); - EXPECT_EQ(tensor3.IsValid(), GRAPH_SUCCESS); - - // DT_UNDEFINED - TensorDesc tensor_desc3(Shape({3, 3, 3}), FORMAT_NCHW, DT_UNDEFINED); - Tensor tensor4(tensor_desc3, data2, size2); - EXPECT_EQ(tensor4.IsValid(), GRAPH_SUCCESS); - - // Tensor() - Tensor tensor5; - EXPECT_EQ(tensor5.IsValid(), GRAPH_SUCCESS); - tensor5.SetTensorDesc(tensor_desc); - tensor5.SetData(data); - EXPECT_EQ(tensor5.IsValid(), GRAPH_SUCCESS); - - // scalar 1 - uint8_t data6[4] = {1, 2, 3, 4}; - Tensor tensor6; - tensor6.SetData(data6, 4); - EXPECT_EQ(tensor6.IsValid(), GRAPH_SUCCESS); - - // scalar 2 - TensorDesc tensor_desc7(Shape(), FORMAT_NCHW, DT_FLOAT); - float data7 = 2; - Tensor tensor7(tensor_desc7, (uint8_t *)&data7, sizeof(float)); - EXPECT_EQ(tensor7.IsValid(), GRAPH_SUCCESS); - - // string scalar - TensorDesc tensor_desc8(Shape(), FORMAT_NCHW, DT_STRING); - Tensor tensor8; - tensor8.SetTensorDesc(tensor_desc8); - string data8 = "A handsome boy write this code"; - EXPECT_EQ(tensor8.SetData(data8), GRAPH_SUCCESS); - EXPECT_EQ(tensor8.IsValid(), GRAPH_SUCCESS); - - // string vector - TensorDesc tensor_desc9(Shape({2}), FORMAT_NCHW, DT_STRING); - vector data9 = {"A handsome boy write this code", "very handsome"}; - Tensor tensor9(tensor_desc9); - EXPECT_EQ(tensor9.SetData(data9), GRAPH_SUCCESS); - EXPECT_EQ(tensor9.IsValid(), GRAPH_SUCCESS); - - vector empty_data9; - EXPECT_EQ(tensor9.SetData(empty_data9), GRAPH_FAILED); -} - -TEST_F(UtestGeTensor, test_tensor_invalid) { - // Tensor(const TensorDesc &tensor_desc, const std::vector &data) - Shape shape({1, 1, 1}); - TensorDesc tensor_desc(shape); - std::vector data({1, 2, 3, 4, 5}); - Tensor tensor1(tensor_desc, data); - EXPECT_EQ(tensor1.IsValid(), GRAPH_FAILED); - - // Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size) - TensorDesc tensor_desc2(Shape({3, 3, 3}), FORMAT_NCHW, DT_FLOAT); - uint32_t size2 = 3 * 3 * 3; - uint8_t data2[3 * 3 * 3] = {0}; - Tensor tensor2(tensor_desc2, data2, size2); - EXPECT_EQ(tensor2.IsValid(), GRAPH_FAILED); - - // Tensor(TensorDesc &&tensor_desc, std::vector &&data) - Tensor tensor3(std::move(tensor_desc), std::move(data)); - EXPECT_EQ(tensor3.IsValid(), GRAPH_FAILED); - - // Tensor() - Tensor tensor4; - tensor4.SetTensorDesc(tensor_desc); - EXPECT_EQ(tensor4.IsValid(), GRAPH_FAILED); - tensor4.SetData(data); - EXPECT_EQ(tensor4.IsValid(), GRAPH_FAILED); - - Tensor tensor5; - tensor5.SetData(data); - EXPECT_EQ(tensor5.IsValid(), GRAPH_FAILED); - tensor5.SetTensorDesc(tensor_desc); - EXPECT_EQ(tensor5.IsValid(), GRAPH_FAILED); - - // scalar - TensorDesc tensor_desc6(Shape(), FORMAT_NCHW, DT_FLOAT); - uint8_t data6 = 2; - Tensor tensor6(tensor_desc6, &data6, 1); - EXPECT_EQ(tensor6.IsValid(), GRAPH_FAILED); -} - -TEST_F(UtestGeTensor, NullObject) { - std::vector ints{1, 2, 3, 4}; - GeShape shape1(ints); - GeTensorSerializeUtils::GetShapeFromDescProto(nullptr, shape1); - EXPECT_EQ(shape1.GetDims(), ints); - GeTensorSerializeUtils::GetOriginShapeFromDescProto(nullptr, shape1); - EXPECT_EQ(shape1.GetDims(), ints); -} - -TEST_F(UtestGeTensor, GetFormatFromDescProto_OnlyGetPrimaryFormat_SerializeOp) { - GeShape shape({1, 2, 3, 4}); - GeTensorDesc desc(shape, FORMAT_NC1HWC0, DT_FLOAT16); - desc.SetOriginDataType(DT_INT32); - desc.SetOriginFormat(FORMAT_FRACTAL_Z); - desc.SetOriginShape(GeShape({4, 3, 2, 1})); - GeTensor tensor(desc); - proto::TensorDescriptor desc_proto; - desc_proto.set_layout(TypeUtils::FormatToSerialString(desc.GetFormat())); - // get format through opdesc - Format format_result; - GeTensorSerializeUtils::GetFormatFromDescProto(&desc_proto, format_result); - EXPECT_EQ(format_result, FORMAT_NC1HWC0); -} - -TEST_F(UtestGeTensor, GetFormatFromDescProto_GetFullFormat_SerializeOp) { - GeShape shape({1, 2, 3, 4}); - // {c0_value, bit_value}: c0_value = 2 ^ (bit_value - 1) - // {1, 1}, {2, 2}, {4, 3}, {8, 4}, {16, 5}, {32, 6}, {64, 7}, {128, 8}, {256, 9} - // 5 indicates that cube size is 16 - const Format format = static_cast(GetFormatFromSubAndC0(FORMAT_NC1HWC0, FORMAT_RESERVED, 5)); - GeTensorDesc desc(shape, FORMAT_NC1HWC0, DT_FLOAT16); - desc.SetOriginDataType(DT_INT32); - desc.SetOriginFormat(FORMAT_FRACTAL_Z); - desc.SetOriginShape(GeShape({4, 3, 2, 1})); - GeTensor tensor(desc); - proto::TensorDescriptor desc_proto; - desc_proto.set_layout(TypeUtils::FormatToSerialString(desc.GetFormat())); - - // get format through attr - ge::proto::AttrDef format_attr; - format_attr.set_i(format); - (void)desc_proto.mutable_attr()->insert({"format_for_int", format_attr}); - Format format_result; - GeTensorSerializeUtils::GetFormatFromDescProto(&desc_proto, format_result); - EXPECT_EQ(format_result, format); -} - -TEST_F(UtestGeTensor, GetOriginFormatFromDescProto_GetFullOriginFormat_SerializeOp) { - GeShape shape({1, 2, 3, 4}); - // {c0_value, bit_value}: c0_value = 2 ^ (bit_value - 1) - // {1, 1}, {2, 2}, {4, 3}, {8, 4}, {16, 5}, {32, 6}, {64, 7}, {128, 8}, {256, 9} - // 5 indicates that cube size is 16 - const Format origin_format = static_cast(GetFormatFromSubAndC0(FORMAT_FRACTAL_Z, FORMAT_RESERVED, 4)); - GeTensorDesc desc(shape, FORMAT_NC1HWC0, DT_FLOAT16); - desc.SetOriginDataType(DT_INT32); - desc.SetOriginFormat(FORMAT_FRACTAL_Z); - desc.SetOriginShape(GeShape({4, 3, 2, 1})); - GeTensor tensor(desc); - proto::TensorDescriptor desc_proto; - desc_proto.set_layout(TypeUtils::FormatToSerialString(desc.GetFormat())); - - // get format through attr - ge::proto::AttrDef ori_format_attr; - ori_format_attr.set_i(origin_format); - (void)desc_proto.mutable_attr()->insert({"origin_format_for_int", ori_format_attr}); - Format origin_format_result; - GeTensorSerializeUtils::GetOriginFormatFromDescProto(&desc_proto, origin_format_result); - EXPECT_EQ(origin_format_result, origin_format); -} -TEST_F(UtestGeTensor, tensor_desc_set_get_expand_dims_rule) { - GeTensorDesc a; - // init status - EXPECT_TRUE(a.GetExpandDimsRule().empty()); - - // test set and get - a.SetExpandDimsRule("0011"); - EXPECT_STREQ(a.GetExpandDimsRule().c_str(), "0011"); -} -TEST_F(UtestGeTensor, test_tensor_data_invalid) { - std::vector ge_tensor(2U); - for (size_t i = 0U; i < ge_tensor.size(); ++i) { - const static ge::Tensor::DeleteFunc kDoNothing = [](uint8_t *data) {}; - ge_tensor[i].SetData(nullptr, 0U, kDoNothing); - EXPECT_EQ(ge_tensor[i].IsTensorDataValid(), false); - } - - for (size_t i = 0U; i < ge_tensor.size(); ++i) { - static const uint8_t tmp[] = {0, 0, 0, 0}; - ge_tensor[i].SetData(tmp, sizeof(tmp)); - EXPECT_EQ(ge_tensor[i].IsTensorDataValid(), true); - } -} - -TEST_F(UtestGeTensor, test_ge_tensor_desc) { - GeTensorDesc a; - GeShape shape({1, 2, 3, 4, 16}); - GeShape ori_shape({1, 32, 3, 4}); - GeTensorDesc b(shape, FORMAT_NC1HWC0); - b.SetOriginShape(ori_shape); - - GeShape ret_ori = b.GetOriginShape(); - EXPECT_EQ(ret_ori.GetDimNum(), ori_shape.GetDimNum()); - for (size_t i = 0U; i < ret_ori.GetDimNum(); ++i) { - EXPECT_EQ(ret_ori.GetDim(i), ori_shape.GetDim(i)); - } - GeShape ori_shape2({3, 4}); - b.MutableOriginShape() = ori_shape2; - GeShape ret_ori2 = b.GetOriginShape(); - EXPECT_EQ(ret_ori2.GetDimNum(), ori_shape2.GetDimNum()); - for (size_t i = 0U; i < ret_ori2.GetDimNum(); ++i) { - EXPECT_EQ(ret_ori2.GetDim(i), ori_shape2.GetDim(i)); - } -} - -TEST_F(UtestGeTensor, test_is_shape_equal_unknown_rank) { - GeTensorDesc a; - GeShape src_shape({-2}); - GeShape dst_shape({1, 2, -1}); - EXPECT_EQ(TensorUtils::IsShapeEqual(src_shape, dst_shape), false); - GeShape equal_dst_shape({-2}); - EXPECT_EQ(TensorUtils::IsShapeEqual(src_shape, equal_dst_shape), true); -} - -TEST_F(UtestGeTensor, test_is_shape_equal_unknown_shape) { - GeTensorDesc a; - GeShape src_shape({1, 2, -1}); - GeShape unknown_dst_shape({1, -1, 2}); - GeShape unkown_dst_shape1({1, 2, 1024}); - EXPECT_EQ(TensorUtils::IsShapeEqual(src_shape, unknown_dst_shape), true); - EXPECT_EQ(TensorUtils::IsShapeEqual(src_shape, unkown_dst_shape1), true); -} diff --git a/tests/ut/graph/testcase/gen_task_callback_unittest.cc b/tests/ut/graph/testcase/gen_task_callback_unittest.cc deleted file mode 100644 index 8d3f37d8da7a2b6f4a2b0c011ea9d38ae8ed7158..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/gen_task_callback_unittest.cc +++ /dev/null @@ -1,630 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/kernel_launch_info.h" -#include "graph/arg_desc_info.h" -#include "runtime/rt_model.h" -#include "proto/task.pb.h" -#include "inc/external/exe_graph/runtime/exe_res_generation_context.h" -#include "inc/exe_graph/lowering/exe_res_generation_ctx_builder.h" -#include "ge/framework/common/taskdown_common.h" -#include "graph/utils/args_format_desc_utils.h" -#include "graph/compute_graph.h" -#include "graph/utils/graph_utils.h" -#include "common/checker.h" -#include "graph/debug/ge_attr_define.h" - -namespace ge { -namespace { -ComputeGraphPtr CreateMc2NodeGraph() { - ComputeGraphPtr graph = std::make_shared("test"); - OpDescPtr x1_desc = std::make_shared("x1", "Data"); - OpDescPtr x2_desc = std::make_shared("x2", "Data"); - OpDescPtr bias_desc = std::make_shared("bias", "Data"); - OpDescPtr all_gather_matmul_desc = std::make_shared("mc2", "AllGatherMatmul"); - OpDescPtr net_output_desc = std::make_shared("output", "NetOutput"); - - // add descriptor - ge::GeShape shape1({2,4}); - GeTensorDesc tensor_desc1(shape1, ge::FORMAT_ND, ge::DT_FLOAT16); - tensor_desc1.SetOriginFormat(ge::FORMAT_ND); - tensor_desc1.SetOriginDataType(ge::DT_FLOAT16); - tensor_desc1.SetOriginShape(shape1); - - ge::GeShape shape2({4,3}); - GeTensorDesc tensor_desc2(shape2, ge::FORMAT_ND, ge::DT_FLOAT16); - tensor_desc2.SetOriginFormat(ge::FORMAT_ND); - tensor_desc2.SetOriginDataType(ge::DT_FLOAT16); - tensor_desc2.SetOriginShape(shape2); - - ge::GeShape shape3({3}); - GeTensorDesc tensor_desc3(shape3, ge::FORMAT_ND, ge::DT_FLOAT16); - tensor_desc3.SetOriginFormat(ge::FORMAT_ND); - tensor_desc3.SetOriginDataType(ge::DT_FLOAT16); - tensor_desc3.SetOriginShape(shape3); - - ge::GeShape shape4({2, 3}); - GeTensorDesc tensor_desc4(shape4, ge::FORMAT_ND, ge::DT_FLOAT16); - tensor_desc4.SetOriginFormat(ge::FORMAT_ND); - tensor_desc4.SetOriginDataType(ge::DT_FLOAT16); - tensor_desc4.SetOriginShape(shape4); - - x1_desc->AddOutputDesc(tensor_desc1); - x2_desc->AddOutputDesc(tensor_desc2); - bias_desc->AddOutputDesc(tensor_desc3); - - all_gather_matmul_desc->AddInputDesc(tensor_desc1); - all_gather_matmul_desc->AddInputDesc(tensor_desc2); - all_gather_matmul_desc->AddInputDesc(tensor_desc3); - all_gather_matmul_desc->AddOutputDesc(tensor_desc4); - all_gather_matmul_desc->AddOutputDesc(tensor_desc4); - all_gather_matmul_desc->AppendIrInput("x1", ge::kIrInputRequired); - all_gather_matmul_desc->AppendIrInput("x2", ge::kIrInputRequired); - all_gather_matmul_desc->AppendIrInput("bias", ge::kIrInputOptional); - all_gather_matmul_desc->AppendIrOutput("y", ge::kIrOutputRequired); - all_gather_matmul_desc->AppendIrOutput("gather_out", ge::kIrOutputRequired); - - net_output_desc->AddInputDesc(tensor_desc4); - net_output_desc->AddInputDesc(tensor_desc4); - // create nodes - NodePtr x1_node = graph->AddNode(x1_desc); - NodePtr x2_node = graph->AddNode(x2_desc); - NodePtr bias_node = graph->AddNode(bias_desc); - NodePtr mc2_node = graph->AddNode(all_gather_matmul_desc); - NodePtr output_node = graph->AddNode(net_output_desc); - - ge::GraphUtils::AddEdge(x1_node->GetOutDataAnchor(0), mc2_node->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(x2_node->GetOutDataAnchor(0), mc2_node->GetInDataAnchor(1)); - ge::GraphUtils::AddEdge(bias_node->GetOutDataAnchor(0), mc2_node->GetInDataAnchor(2)); - ge::GraphUtils::AddEdge(mc2_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(mc2_node->GetOutDataAnchor(1), output_node->GetInDataAnchor(1)); - - all_gather_matmul_desc->SetStreamId(2); - all_gather_matmul_desc->SetId(4); - std::vector ori_work_sizes{22,33,44}; - all_gather_matmul_desc->SetWorkspaceBytes(ori_work_sizes); - return graph; -} - -gert::ExeResGenerationCtxHolderPtr CreateNodeExeResContext(const NodePtr &node) { - gert::ExeResGenerationCtxBuilder exe_ctx_builder; - auto res_ptr_holder = exe_ctx_builder.CreateOpExeContext(*node); - auto op_exe_res_ctx = reinterpret_cast(res_ptr_holder->GetKernelContext()); - std::vector stream_info_vec; - gert::StreamInfo si_1; - si_1.name = "aicpu kfc server"; - si_1.reuse_key = "kfc_stream"; - si_1.depend_value_input_indices = {}; - si_1.required = true; - stream_info_vec.emplace_back(si_1); - op_exe_res_ctx->SetAttachedStreamInfos(stream_info_vec); - std::vector stream_info_attrs; - (void)ge::AttrUtils::GetListNamedAttrs(node->GetOpDesc(), ge::ATTR_NAME_ATTACHED_STREAM_INFO_LIST, - stream_info_attrs); - (void)ge::AttrUtils::SetInt(stream_info_attrs.front(), ge::ATTR_NAME_ATTACHED_RESOURCE_ID, 4); - (void)ge::AttrUtils::SetListNamedAttrs(node->GetOpDesc(), ge::ATTR_NAME_ATTACHED_STREAM_INFO_LIST, - stream_info_attrs); - return res_ptr_holder; -} - -struct HcclCommParamDesc { - uint64_t version : 4; - uint64_t group_num : 4; - uint64_t has_ffts : 1; - uint64_t tiling_off : 7; - uint64_t is_dyn : 48; -}; - -graphStatus Mc2GenTaskCallback(const gert::ExeResGenerationContext *context, - std::vector> &tasks) { - GE_ASSERT_NOTNULL(context); - GE_ASSERT_TRUE(tasks.size() == 1UL); - auto aicore_index = 0; - // 获取attach流id - auto stream_infos = context->GetAttachedStreamInfos(); - GE_ASSERT_TRUE(!stream_infos.empty()); - const int64_t attach_stream_id = stream_infos[0].stream_id; - const int64_t stream_id = context->GetStreamId(); - // 创建WaitTask - auto wait_task = KernelLaunchInfo::CreateHcomWaitTask(context); - wait_task.SetStreamId(attach_stream_id); - tasks.insert(tasks.begin() + aicore_index, wait_task.Serialize()); - aicore_index++; - // 设置aicpu任务 - auto aicpu_task = KernelLaunchInfo::CreateAicpuKfcTask(context, - "libccl_kernel.so", "RunAicpuKfcSrvLaunch"); - size_t input_size = context->GetComputeNodeInfo()->GetIrInputsNum(); - size_t output_size = context->GetComputeNodeInfo()->GetIrOutputsNum(); - const size_t offset = 3UL; - union { - HcclCommParamDesc hccl_desc; - uint64_t custom_value; - } desc; - desc.hccl_desc.version = 1; - desc.hccl_desc.group_num = 1; - desc.hccl_desc.has_ffts = 0; - desc.hccl_desc.tiling_off = offset + input_size + output_size; - desc.hccl_desc.is_dyn = 0; - std::vector aicpu_args_format; - aicpu_args_format.emplace_back(ArgDescInfo::CreateCustomValue(desc.custom_value)); - aicpu_args_format.emplace_back(ArgDescInfo::CreateHiddenInput(HiddenInputSubType::kHcom)); - aicpu_args_format.emplace_back(ArgDescInfo(ArgDescType::kIrInput, 0)); - for (size_t i = 1; i < input_size; i++) { - aicpu_args_format.emplace_back(ArgDescInfo::CreateCustomValue(0)); - } - for (size_t i = 0; i < output_size; i++) { - aicpu_args_format.emplace_back(ArgDescInfo(ArgDescType::kIrOutput, i)); - } - aicpu_args_format.emplace_back(ArgDescInfo(ArgDescType::kWorkspace)); - aicpu_args_format.emplace_back(ArgDescInfo(ArgDescType::kTiling)); - aicpu_task.SetArgsFormat(ArgsFormatSerializer::Serialize(aicpu_args_format).GetString()); - aicpu_task.SetStreamId(attach_stream_id); - tasks.insert(tasks.begin() + aicore_index, aicpu_task.Serialize()); - aicore_index++; - // 创建RecordTask - auto record_task = KernelLaunchInfo::CreateHcomRecordTask(context); - record_task.SetStreamId(stream_id); - tasks.insert(tasks.begin() + aicore_index, record_task.Serialize()); - aicore_index++; - // 更改原AICORE任务的argsformat - auto aicore_task = KernelLaunchInfo::LoadFromData(context, tasks.back()); - auto aicore_args_format_str = aicore_task.GetArgsFormat(); - auto aicore_args_format = ArgsFormatSerializer::Deserialize(aicore_args_format_str); - size_t i = 0UL; - for (; i < aicore_args_format.size(); i++) { - if (aicore_args_format[i].GetType() == ArgDescType::kIrInput || - aicore_args_format[i].GetType() == ArgDescType::kInputInstance) { - break; - } - } - aicore_args_format.insert(aicore_args_format.begin() + i, ArgDescInfo::CreateHiddenInput(HiddenInputSubType::kHcom)); - aicore_task.SetArgsFormat(ArgsFormatSerializer::Serialize(aicore_args_format).GetString()); - tasks.back() = aicore_task.Serialize(); - return SUCCESS; -} -} -class TestGenTaskCallback : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; -// 验证使用kernel_def的mc2算子在GenTaskCallback函数中构造taskDef的功能 -TEST_F(TestGenTaskCallback, TestNormalMc2NodeGenTaskCallback) { - auto graph = CreateMc2NodeGraph(); - auto mc2_node = graph->FindNode("mc2"); - auto res_context_holder = CreateNodeExeResContext(mc2_node); - auto op_exe_res_ctx = reinterpret_cast(res_context_holder->GetKernelContext()); - domi::TaskDef aicore_task_def; - aicore_task_def.set_type(RT_MODEL_TASK_KERNEL); - aicore_task_def.set_id(op_exe_res_ctx->GetOpId()); - aicore_task_def.set_stream_id(op_exe_res_ctx->GetStreamId()); - auto kernel_def = aicore_task_def.mutable_kernel(); - kernel_def->set_block_dim(32); - kernel_def->set_schedule_mode(0); - auto kernel_context = kernel_def->mutable_context(); - kernel_context->set_kernel_type(static_cast(ccKernelType::TE_AI_CORE)); - kernel_context->set_op_index(op_exe_res_ctx->GetOpId()); - std::vector args; - size_t input_size = op_exe_res_ctx->GetComputeNodeInfo()->GetIrInputsNum(); - size_t output_size = op_exe_res_ctx->GetComputeNodeInfo()->GetIrOutputsNum(); - for (size_t i = 0UL; i < input_size; i++) { - ArgsFormatDescUtils::Append(args, AddrType::INPUT, i); - } - for (size_t i = 0UL; i < output_size; i++) { - ArgsFormatDescUtils::Append(args, AddrType::OUTPUT, i); - } - ArgsFormatDescUtils::Append(args, AddrType::WORKSPACE); - ArgsFormatDescUtils::Append(args, AddrType::TILING); - kernel_context->set_args_format(ArgsFormatDescUtils::Serialize(args)); - // 序列化 - std::vector> tasks; - auto buffer_size = aicore_task_def.ByteSizeLong(); - std::vector buffer(buffer_size, 0); - aicore_task_def.SerializeToArray(buffer.data(), buffer_size); - tasks.emplace_back(buffer); - // 执行mc2的gentaskcallback - EXPECT_EQ(Mc2GenTaskCallback(op_exe_res_ctx, tasks), SUCCESS); - EXPECT_EQ(tasks.size(), 4UL); - // 校验wait算子的结果 - domi::TaskDef wait_task; - wait_task.ParseFromArray(tasks[0].data(), tasks[0].size()); - EXPECT_EQ(wait_task.id(), 4); - EXPECT_EQ(wait_task.notify_id(), UINT32_MAX); - EXPECT_EQ(wait_task.type(), RT_MODEL_TASK_NOTIFY_WAIT); - EXPECT_EQ(wait_task.private_def(), "group"); - EXPECT_EQ(wait_task.stream_id(), 4); - // 校验aicpu算子结果 - domi::TaskDef aicpu_task; - aicpu_task.ParseFromArray(tasks[1].data(), tasks[1].size()); - EXPECT_EQ(aicpu_task.type(), RT_MODEL_TASK_KERNEL); - EXPECT_EQ(aicpu_task.stream_id(), 4); - auto aicpu_kernel_def = aicpu_task.mutable_kernel(); - EXPECT_EQ(aicpu_kernel_def->so_name(), "libccl_kernel.so"); - EXPECT_EQ(aicpu_kernel_def->kernel_name(), "RunAicpuKfcSrvLaunch"); - auto aicpu_kernel_context = aicpu_kernel_def->mutable_context(); - EXPECT_EQ(aicpu_kernel_context->kernel_type(), static_cast(ccKernelType::AI_CPU_KFC)); - EXPECT_EQ(aicpu_kernel_context->op_index(), 4); - auto aicpu_args_format = aicpu_kernel_context->args_format(); - EXPECT_EQ(aicpu_args_format, "{#4113}{hi.hcom0*}{i0*}{#0}{#0}{o0*}{o1*}{ws*}{t}"); - - // 校验record算子结果 - domi::TaskDef record_task; - record_task.ParseFromArray(tasks[2].data(), tasks[2].size()); - EXPECT_EQ(record_task.id(), 4); - EXPECT_EQ(record_task.notify_id(), UINT32_MAX); - EXPECT_EQ(record_task.type(), RT_MODEL_TASK_NOTIFY_RECORD); - EXPECT_EQ(record_task.private_def(), "group"); - EXPECT_EQ(record_task.stream_id(), 2); - // 校验aicore算子结果 - domi::TaskDef aicore_task; - aicore_task.ParseFromArray(tasks[3].data(), tasks[3].size()); - EXPECT_EQ(aicore_task.type(), RT_MODEL_TASK_KERNEL); - EXPECT_EQ(aicore_task.stream_id(), 2); - auto aicore_kernel_def = aicore_task.mutable_kernel(); - EXPECT_EQ(aicore_kernel_def->block_dim(), 32); - EXPECT_EQ(aicore_kernel_def->schedule_mode(), 0); - auto aicore_kernel_context = aicore_kernel_def->mutable_context(); - EXPECT_EQ(aicore_kernel_context->kernel_type(), static_cast(ccKernelType::TE_AI_CORE)); - EXPECT_EQ(aicore_kernel_context->op_index(), 4); - auto aicore_args_format = aicore_kernel_context->args_format(); - EXPECT_EQ(aicore_args_format, "{hi.hcom0*}{i0*}{i1*}{i2*}{o0*}{o1*}{ws*}{t}"); -} - -// 验证使用kernel_def_with_handle的mc2算子在GenTaskCallback函数中构造taskDef的功能 -TEST_F(TestGenTaskCallback, TestMc2NodeWithHandleGenTaskCallback) { - auto graph = CreateMc2NodeGraph(); - auto mc2_node = graph->FindNode("mc2"); - auto res_context_holder = CreateNodeExeResContext(mc2_node); - auto op_exe_res_ctx = reinterpret_cast(res_context_holder->GetKernelContext()); - domi::TaskDef aicore_task_def; - aicore_task_def.set_type(RT_MODEL_TASK_ALL_KERNEL); - aicore_task_def.set_id(op_exe_res_ctx->GetOpId()); - aicore_task_def.set_stream_id(op_exe_res_ctx->GetStreamId()); - auto kernel_def_with_handle = aicore_task_def.mutable_kernel_with_handle(); - kernel_def_with_handle->set_block_dim(32); - kernel_def_with_handle->set_schedule_mode(0); - auto kernel_context = kernel_def_with_handle->mutable_context(); - kernel_context->set_kernel_type(static_cast(ccKernelType::TE_AI_CORE)); - kernel_context->set_op_index(op_exe_res_ctx->GetOpId()); - std::vector args; - size_t input_size = op_exe_res_ctx->GetComputeNodeInfo()->GetIrInputsNum(); - size_t output_size = op_exe_res_ctx->GetComputeNodeInfo()->GetIrOutputsNum(); - for (size_t i = 0UL; i < input_size; i++) { - ArgsFormatDescUtils::Append(args, AddrType::INPUT, i); - } - for (size_t i = 0UL; i < output_size; i++) { - ArgsFormatDescUtils::Append(args, AddrType::OUTPUT, i); - } - ArgsFormatDescUtils::Append(args, AddrType::WORKSPACE); - ArgsFormatDescUtils::Append(args, AddrType::TILING); - kernel_context->set_args_format(ArgsFormatDescUtils::Serialize(args)); - // 序列化 - std::vector> tasks; - auto buffer_size = aicore_task_def.ByteSizeLong(); - std::vector buffer(buffer_size, 0); - aicore_task_def.SerializeToArray(buffer.data(), buffer_size); - tasks.emplace_back(buffer); - // 执行mc2的gentaskcallback - EXPECT_EQ(Mc2GenTaskCallback(op_exe_res_ctx, tasks), SUCCESS); - EXPECT_EQ(tasks.size(), 4UL); - // 校验wait算子的结果 - domi::TaskDef wait_task; - wait_task.ParseFromArray(tasks[0].data(), tasks[0].size()); - EXPECT_EQ(wait_task.id(), 4); - EXPECT_EQ(wait_task.notify_id(), UINT32_MAX); - EXPECT_EQ(wait_task.type(), RT_MODEL_TASK_NOTIFY_WAIT); - EXPECT_EQ(wait_task.private_def(), "group"); - EXPECT_EQ(wait_task.stream_id(), 4); - // 校验aicpu算子结果 - domi::TaskDef aicpu_task; - aicpu_task.ParseFromArray(tasks[1].data(), tasks[1].size()); - EXPECT_EQ(aicpu_task.type(), RT_MODEL_TASK_KERNEL); - EXPECT_EQ(aicpu_task.stream_id(), 4); - auto aicpu_kernel_def = aicpu_task.mutable_kernel(); - EXPECT_EQ(aicpu_kernel_def->so_name(), "libccl_kernel.so"); - EXPECT_EQ(aicpu_kernel_def->kernel_name(), "RunAicpuKfcSrvLaunch"); - auto aicpu_kernel_context = aicpu_kernel_def->mutable_context(); - EXPECT_EQ(aicpu_kernel_context->kernel_type(), static_cast(ccKernelType::AI_CPU_KFC)); - EXPECT_EQ(aicpu_kernel_context->op_index(), 4); - auto aicpu_args_format = aicpu_kernel_context->args_format(); - EXPECT_EQ(aicpu_args_format, "{#4113}{hi.hcom0*}{i0*}{#0}{#0}{o0*}{o1*}{ws*}{t}"); - - // 校验record算子结果 - domi::TaskDef record_task; - record_task.ParseFromArray(tasks[2].data(), tasks[2].size()); - EXPECT_EQ(record_task.id(), 4); - EXPECT_EQ(record_task.notify_id(), UINT32_MAX); - EXPECT_EQ(record_task.type(), RT_MODEL_TASK_NOTIFY_RECORD); - EXPECT_EQ(record_task.private_def(), "group"); - EXPECT_EQ(record_task.stream_id(), 2); - // 校验aicore算子结果 - domi::TaskDef aicore_task; - aicore_task.ParseFromArray(tasks[3].data(), tasks[3].size()); - EXPECT_EQ(aicore_task.type(), RT_MODEL_TASK_ALL_KERNEL); - EXPECT_EQ(aicore_task.stream_id(), 2); - auto aicore_kernel_def = aicore_task.mutable_kernel_with_handle(); - EXPECT_EQ(aicore_kernel_def->block_dim(), 32); - EXPECT_EQ(aicore_kernel_def->schedule_mode(), 0); - auto aicore_kernel_context = aicore_kernel_def->mutable_context(); - EXPECT_EQ(aicore_kernel_context->kernel_type(), static_cast(ccKernelType::TE_AI_CORE)); - EXPECT_EQ(aicore_kernel_context->op_index(), 4); - auto aicore_args_format = aicore_kernel_context->args_format(); - EXPECT_EQ(aicore_args_format, "{hi.hcom0*}{i0*}{i1*}{i2*}{o0*}{o1*}{ws*}{t}"); -} - -// 验证使用mixL2的mc2算子在GenTaskCallback函数中构造taskDef的功能 -TEST_F(TestGenTaskCallback, TestMixL2Mc2NodeGenTaskCallback) { - auto graph = CreateMc2NodeGraph(); - auto mc2_node = graph->FindNode("mc2"); - auto res_context_holder = CreateNodeExeResContext(mc2_node); - auto op_exe_res_ctx = reinterpret_cast(res_context_holder->GetKernelContext()); - domi::TaskDef aicore_task_def; - aicore_task_def.set_type(RT_MODEL_TASK_ALL_KERNEL); - aicore_task_def.set_id(op_exe_res_ctx->GetOpId()); - aicore_task_def.set_stream_id(op_exe_res_ctx->GetStreamId()); - auto kernel_def_with_handle = aicore_task_def.mutable_kernel_with_handle(); - kernel_def_with_handle->set_block_dim(32); - kernel_def_with_handle->set_schedule_mode(0); - auto kernel_context = kernel_def_with_handle->mutable_context(); - kernel_context->set_kernel_type(static_cast(ccKernelType::TE)); - kernel_context->set_op_index(op_exe_res_ctx->GetOpId()); - std::vector args; - ArgsFormatDescUtils::Append(args, AddrType::FFTS_ADDR); - size_t input_size = op_exe_res_ctx->GetComputeNodeInfo()->GetIrInputsNum(); - size_t output_size = op_exe_res_ctx->GetComputeNodeInfo()->GetIrOutputsNum(); - for (size_t i = 0UL; i < input_size; i++) { - ArgsFormatDescUtils::Append(args, AddrType::INPUT, i); - } - for (size_t i = 0UL; i < output_size; i++) { - ArgsFormatDescUtils::Append(args, AddrType::OUTPUT, i); - } - ArgsFormatDescUtils::Append(args, AddrType::WORKSPACE); - ArgsFormatDescUtils::Append(args, AddrType::TILING); - kernel_context->set_args_format(ArgsFormatDescUtils::Serialize(args)); - // 序列化 - std::vector> tasks; - auto buffer_size = aicore_task_def.ByteSizeLong(); - std::vector buffer(buffer_size, 0); - aicore_task_def.SerializeToArray(buffer.data(), buffer_size); - tasks.emplace_back(buffer); - // 执行mc2的gentaskcallback - EXPECT_EQ(Mc2GenTaskCallback(op_exe_res_ctx, tasks), SUCCESS); - EXPECT_EQ(tasks.size(), 4UL); - // 校验wait算子的结果 - domi::TaskDef wait_task; - wait_task.ParseFromArray(tasks[0].data(), tasks[0].size()); - EXPECT_EQ(wait_task.id(), 4); - EXPECT_EQ(wait_task.notify_id(), UINT32_MAX); - EXPECT_EQ(wait_task.type(), RT_MODEL_TASK_NOTIFY_WAIT); - EXPECT_EQ(wait_task.private_def(), "group"); - EXPECT_EQ(wait_task.stream_id(), 4); - // 校验aicpu算子结果 - domi::TaskDef aicpu_task; - aicpu_task.ParseFromArray(tasks[1].data(), tasks[1].size()); - EXPECT_EQ(aicpu_task.type(), RT_MODEL_TASK_KERNEL); - EXPECT_EQ(aicpu_task.stream_id(), 4); - auto aicpu_kernel_def = aicpu_task.mutable_kernel(); - EXPECT_EQ(aicpu_kernel_def->so_name(), "libccl_kernel.so"); - EXPECT_EQ(aicpu_kernel_def->kernel_name(), "RunAicpuKfcSrvLaunch"); - auto aicpu_kernel_context = aicpu_kernel_def->mutable_context(); - EXPECT_EQ(aicpu_kernel_context->kernel_type(), static_cast(ccKernelType::AI_CPU_KFC)); - EXPECT_EQ(aicpu_kernel_context->op_index(), 4); - auto aicpu_args_format = aicpu_kernel_context->args_format(); - EXPECT_EQ(aicpu_args_format, "{#4113}{hi.hcom0*}{i0*}{#0}{#0}{o0*}{o1*}{ws*}{t}"); - - // 校验record算子结果 - domi::TaskDef record_task; - record_task.ParseFromArray(tasks[2].data(), tasks[2].size()); - EXPECT_EQ(record_task.id(), 4); - EXPECT_EQ(record_task.notify_id(), UINT32_MAX); - EXPECT_EQ(record_task.type(), RT_MODEL_TASK_NOTIFY_RECORD); - EXPECT_EQ(record_task.private_def(), "group"); - EXPECT_EQ(record_task.stream_id(), 2); - // 校验aicore算子结果 - domi::TaskDef aicore_task; - aicore_task.ParseFromArray(tasks[3].data(), tasks[3].size()); - EXPECT_EQ(aicore_task.type(), RT_MODEL_TASK_ALL_KERNEL); - EXPECT_EQ(aicore_task.stream_id(), 2); - auto aicore_kernel_def = aicore_task.mutable_kernel_with_handle(); - EXPECT_EQ(aicore_kernel_def->block_dim(), 32); - EXPECT_EQ(aicore_kernel_def->schedule_mode(), 0); - auto aicore_kernel_context = aicore_kernel_def->mutable_context(); - EXPECT_EQ(aicore_kernel_context->kernel_type(), static_cast(ccKernelType::TE)); - EXPECT_EQ(aicore_kernel_context->op_index(), 4); - auto aicore_args_format = aicore_kernel_context->args_format(); - EXPECT_EQ(aicore_args_format, "{ffts_addr}{hi.hcom0*}{i0*}{i1*}{i2*}{o0*}{o1*}{ws*}{t}"); -} - - -// 验证使用带有input_instance的mc2算子在GenTaskCallback函数中构造taskDef的功能 -TEST_F(TestGenTaskCallback, TestMc2WithInputInstanceNodeGenTaskCallback) { - auto graph = CreateMc2NodeGraph(); - auto mc2_node = graph->FindNode("mc2"); - auto res_context_holder = CreateNodeExeResContext(mc2_node); - auto op_exe_res_ctx = reinterpret_cast(res_context_holder->GetKernelContext()); - domi::TaskDef aicore_task_def; - aicore_task_def.set_type(RT_MODEL_TASK_ALL_KERNEL); - aicore_task_def.set_id(op_exe_res_ctx->GetOpId()); - aicore_task_def.set_stream_id(op_exe_res_ctx->GetStreamId()); - auto kernel_def_with_handle = aicore_task_def.mutable_kernel_with_handle(); - kernel_def_with_handle->set_block_dim(32); - kernel_def_with_handle->set_schedule_mode(0); - auto kernel_context = kernel_def_with_handle->mutable_context(); - kernel_context->set_kernel_type(static_cast(ccKernelType::TE)); - kernel_context->set_op_index(op_exe_res_ctx->GetOpId()); - std::vector args; - ArgsFormatDescUtils::Append(args, AddrType::FFTS_ADDR); - size_t input_size = op_exe_res_ctx->GetComputeNodeInfo()->GetIrInputsNum(); - size_t output_size = op_exe_res_ctx->GetComputeNodeInfo()->GetIrOutputsNum(); - for (size_t i = 0UL; i < input_size; i++) { - ArgsFormatDescUtils::Append(args, AddrType::INPUT_INSTANCE, i); - } - for (size_t i = 0UL; i < output_size; i++) { - ArgsFormatDescUtils::Append(args, AddrType::OUTPUT_INSTANCE, i); - } - ArgsFormatDescUtils::Append(args, AddrType::WORKSPACE); - ArgsFormatDescUtils::Append(args, AddrType::TILING); - kernel_context->set_args_format(ArgsFormatDescUtils::Serialize(args)); - // 序列化 - std::vector> tasks; - auto buffer_size = aicore_task_def.ByteSizeLong(); - std::vector buffer(buffer_size, 0); - aicore_task_def.SerializeToArray(buffer.data(), buffer_size); - tasks.emplace_back(buffer); - // 执行mc2的gentaskcallback - EXPECT_EQ(Mc2GenTaskCallback(op_exe_res_ctx, tasks), SUCCESS); - EXPECT_EQ(tasks.size(), 4UL); - // 校验wait算子的结果 - domi::TaskDef wait_task; - wait_task.ParseFromArray(tasks[0].data(), tasks[0].size()); - EXPECT_EQ(wait_task.id(), 4); - EXPECT_EQ(wait_task.notify_id(), UINT32_MAX); - EXPECT_EQ(wait_task.type(), RT_MODEL_TASK_NOTIFY_WAIT); - EXPECT_EQ(wait_task.private_def(), "group"); - EXPECT_EQ(wait_task.stream_id(), 4); - // 校验aicpu算子结果 - domi::TaskDef aicpu_task; - aicpu_task.ParseFromArray(tasks[1].data(), tasks[1].size()); - EXPECT_EQ(aicpu_task.type(), RT_MODEL_TASK_KERNEL); - EXPECT_EQ(aicpu_task.stream_id(), 4); - auto aicpu_kernel_def = aicpu_task.mutable_kernel(); - EXPECT_EQ(aicpu_kernel_def->so_name(), "libccl_kernel.so"); - EXPECT_EQ(aicpu_kernel_def->kernel_name(), "RunAicpuKfcSrvLaunch"); - auto aicpu_kernel_context = aicpu_kernel_def->mutable_context(); - EXPECT_EQ(aicpu_kernel_context->kernel_type(), static_cast(ccKernelType::AI_CPU_KFC)); - EXPECT_EQ(aicpu_kernel_context->op_index(), 4); - auto aicpu_args_format = aicpu_kernel_context->args_format(); - EXPECT_EQ(aicpu_args_format, "{#4113}{hi.hcom0*}{i0*}{#0}{#0}{o0*}{o1*}{ws*}{t}"); - - // 校验record算子结果 - domi::TaskDef record_task; - record_task.ParseFromArray(tasks[2].data(), tasks[2].size()); - EXPECT_EQ(record_task.id(), 4); - EXPECT_EQ(record_task.notify_id(), UINT32_MAX); - EXPECT_EQ(record_task.type(), RT_MODEL_TASK_NOTIFY_RECORD); - EXPECT_EQ(record_task.private_def(), "group"); - EXPECT_EQ(record_task.stream_id(), 2); - // 校验aicore算子结果 - domi::TaskDef aicore_task; - aicore_task.ParseFromArray(tasks[3].data(), tasks[3].size()); - EXPECT_EQ(aicore_task.type(), RT_MODEL_TASK_ALL_KERNEL); - EXPECT_EQ(aicore_task.stream_id(), 2); - auto aicore_kernel_def = aicore_task.mutable_kernel_with_handle(); - EXPECT_EQ(aicore_kernel_def->block_dim(), 32); - EXPECT_EQ(aicore_kernel_def->schedule_mode(), 0); - auto aicore_kernel_context = aicore_kernel_def->mutable_context(); - EXPECT_EQ(aicore_kernel_context->kernel_type(), static_cast(ccKernelType::TE)); - EXPECT_EQ(aicore_kernel_context->op_index(), 4); - auto aicore_args_format = aicore_kernel_context->args_format(); - EXPECT_EQ(aicore_args_format, - "{ffts_addr}{hi.hcom0*}{i_instance0*}{i_instance1*}{i_instance2*}{o_instance0*}{o_instance1*}{ws*}{t}"); -} - -// 验证KernelLaunchInfo的移动构造函数和移动赋值函数 -TEST_F(TestGenTaskCallback, TestKernelLaunchInfoMoveConstruct) { - auto graph = CreateMc2NodeGraph(); - auto mc2_node = graph->FindNode("mc2"); - auto res_context_holder = CreateNodeExeResContext(mc2_node); - auto op_exe_res_ctx = reinterpret_cast(res_context_holder->GetKernelContext()); - auto aicpu_task = KernelLaunchInfo::CreateAicpuKfcTask(op_exe_res_ctx, - "libccl_kernel.so", "RunAicpuKfcSrvLaunch"); - aicpu_task.SetStreamId(2); - aicpu_task.SetBlockDim(32); - std::string args_format = "{#4113}{hi.hcom0*}{i0*}{#0}{#0}{o0*}{o1*}{ws*}{t}"; - aicpu_task.SetArgsFormat(args_format.c_str()); - // 验证移动赋值函数 - KernelLaunchInfo aicpu_task_1 = KernelLaunchInfo::CreateHcomRecordTask(op_exe_res_ctx); - aicpu_task_1 = std::move(aicpu_task); - EXPECT_EQ(std::string(aicpu_task_1.GetSoName()), "libccl_kernel.so"); - EXPECT_EQ(std::string(aicpu_task_1.GetKernelName()), "RunAicpuKfcSrvLaunch"); - EXPECT_EQ(aicpu_task_1.GetStreamId(), 2); - EXPECT_EQ(aicpu_task_1.GetBlockDim(), 32); - EXPECT_EQ(std::string(aicpu_task_1.GetArgsFormat()), "{#4113}{hi.hcom0*}{i0*}{#0}{#0}{o0*}{o1*}{ws*}{t}"); - // 验证移动构造函数 - KernelLaunchInfo aicpu_task_2(std::move(aicpu_task_1)); - EXPECT_EQ(std::string(aicpu_task_2.GetSoName()), "libccl_kernel.so"); - EXPECT_EQ(std::string(aicpu_task_2.GetKernelName()), "RunAicpuKfcSrvLaunch"); - EXPECT_EQ(aicpu_task_2.GetStreamId(), 2); - EXPECT_EQ(std::string(aicpu_task_2.GetArgsFormat()), "{#4113}{hi.hcom0*}{i0*}{#0}{#0}{o0*}{o1*}{ws*}{t}"); -} - -// 验证KernelLaunchInfo的拷贝构造函数和拷贝赋值函数 -TEST_F(TestGenTaskCallback, TestKernelLaunchInfoCopyConstruct) { - auto graph = CreateMc2NodeGraph(); - auto mc2_node = graph->FindNode("mc2"); - auto res_context_holder = CreateNodeExeResContext(mc2_node); - auto op_exe_res_ctx = reinterpret_cast(res_context_holder->GetKernelContext()); - domi::TaskDef aicore_task_def; - aicore_task_def.set_type(RT_MODEL_TASK_ALL_KERNEL); - aicore_task_def.set_id(op_exe_res_ctx->GetOpId()); - aicore_task_def.set_stream_id(op_exe_res_ctx->GetStreamId()); - auto kernel_def_with_handle = aicore_task_def.mutable_kernel_with_handle(); - kernel_def_with_handle->set_block_dim(32); - kernel_def_with_handle->set_schedule_mode(0); - auto kernel_context = kernel_def_with_handle->mutable_context(); - kernel_context->set_kernel_type(static_cast(ccKernelType::TE)); - kernel_context->set_op_index(op_exe_res_ctx->GetOpId()); - std::vector args; - ArgsFormatDescUtils::Append(args, AddrType::FFTS_ADDR); - size_t input_size = op_exe_res_ctx->GetComputeNodeInfo()->GetIrInputsNum(); - size_t output_size = op_exe_res_ctx->GetComputeNodeInfo()->GetIrOutputsNum(); - for (size_t i = 0UL; i < input_size; i++) { - ArgsFormatDescUtils::Append(args, AddrType::INPUT_INSTANCE, i); - } - for (size_t i = 0UL; i < output_size; i++) { - ArgsFormatDescUtils::Append(args, AddrType::OUTPUT_INSTANCE, i); - } - ArgsFormatDescUtils::Append(args, AddrType::WORKSPACE); - ArgsFormatDescUtils::Append(args, AddrType::TILING); - kernel_context->set_args_format(ArgsFormatDescUtils::Serialize(args)); - // 序列化 - std::vector> tasks; - auto buffer_size = aicore_task_def.ByteSizeLong(); - std::vector buffer(buffer_size, 0); - aicore_task_def.SerializeToArray(buffer.data(), buffer_size); - auto aicore_task = KernelLaunchInfo::LoadFromData(op_exe_res_ctx, buffer); - EXPECT_EQ(aicore_task.SetBlockDim(48), SUCCESS); - // 验证拷贝赋值函数 - KernelLaunchInfo copy_task = KernelLaunchInfo::CreateHcomRecordTask(op_exe_res_ctx); - copy_task = aicore_task; - EXPECT_EQ(copy_task.GetStreamId(), 2); - EXPECT_EQ(copy_task.GetBlockDim(), 48); - EXPECT_EQ(std::string(copy_task.GetArgsFormat()), "{ffts_addr}{i_instance0*}{i_instance1*}{i_instance2*}{o_instance0*}{o_instance1*}{ws*}{t}"); - - // 验证拷贝构造函数 - KernelLaunchInfo copy_task_2(copy_task); - EXPECT_EQ(copy_task_2.GetStreamId(), 2); - EXPECT_EQ(copy_task_2.GetBlockDim(), 48); - EXPECT_EQ(std::string(copy_task_2.GetArgsFormat()), "{ffts_addr}{i_instance0*}{i_instance1*}{i_instance2*}{o_instance0*}{o_instance1*}{ws*}{t}"); -} - -// 验证非aicore和aicpu算子设置blockdim场景 -TEST_F(TestGenTaskCallback, TestNonAicoreNodeSetBlockDimFailed) { - auto graph = CreateMc2NodeGraph(); - auto mc2_node = graph->FindNode("mc2"); - auto res_context_holder = CreateNodeExeResContext(mc2_node); - auto op_exe_res_ctx = reinterpret_cast(res_context_holder->GetKernelContext()); - - auto notify_task = KernelLaunchInfo::CreateHcomRecordTask(op_exe_res_ctx); - EXPECT_EQ(notify_task.SetBlockDim(48), PARAM_INVALID); - EXPECT_EQ(notify_task.GetBlockDim(), 0); -} -// 验证非aicore和aicpu算子设置argsformat场景 -TEST_F(TestGenTaskCallback, TestNonAicoreNodeSetArgsFormatFailed) { - auto graph = CreateMc2NodeGraph(); - auto mc2_node = graph->FindNode("mc2"); - auto res_context_holder = CreateNodeExeResContext(mc2_node); - auto op_exe_res_ctx = reinterpret_cast(res_context_holder->GetKernelContext()); - - auto notify_task = KernelLaunchInfo::CreateHcomRecordTask(op_exe_res_ctx); - EXPECT_EQ(notify_task.SetArgsFormat("aaaaa"), PARAM_INVALID); - EXPECT_EQ(notify_task.GetArgsFormat(), nullptr); -} -} \ No newline at end of file diff --git a/tests/ut/graph/testcase/gnode_unittest.cc b/tests/ut/graph/testcase/gnode_unittest.cc deleted file mode 100644 index c31431873e0bd64c33b4c1ae2169de6009812595..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/gnode_unittest.cc +++ /dev/null @@ -1,928 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#define protected public -#define private public - -#include "graph/gnode.h" -#include "inc/external/graph/operator_reg.h" -#include "common/ge_common/ge_inner_error_codes.h" -#include "graph/normal_graph/node_impl.h" -#include "graph/utils/node_adapter.h" -#include "graph_builder_utils.h" -#include "node_utils.h" -#include "graph/attr_value.h" - -#undef private -#undef protected - -namespace ge { -class GNodeTest : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -TEST_F(GNodeTest, GetALLSubgraphs) { - auto root_builder = ut::GraphBuilder("root"); - const auto &node = root_builder.AddNode("node", "node", 0, 0); - const auto &root_graph = root_builder.GetGraph(); - - auto sub_builder = ut::GraphBuilder("sub"); - const auto &sub_graph = sub_builder.GetGraph(); - root_graph->AddSubGraph(sub_graph); - sub_graph->SetParentNode(node); - sub_graph->SetParentGraph(root_graph); - node->GetOpDesc()->AddSubgraphName("branch1"); - node->GetOpDesc()->SetSubgraphInstanceName(0, "sub"); - - std::vector subgraphs; - ASSERT_EQ(NodeAdapter::Node2GNode(node).GetALLSubgraphs(subgraphs), GRAPH_SUCCESS); - ASSERT_EQ(subgraphs.size(), 1); -} - -TEST_F(GNodeTest, GetALLSubgraphs_nullptr_root_graph) { - auto builder = ut::GraphBuilder("graph"); - const auto &node = builder.AddNode("node", "node", 0, 0); - node->impl_->owner_graph_.reset(); - - std::vector subgraphs; - ASSERT_NE(NodeAdapter::Node2GNode(node).GetALLSubgraphs(subgraphs), GRAPH_SUCCESS); - ASSERT_TRUE(subgraphs.empty()); -} - -TEST_F(GNodeTest, GetInDataNodesAndPortIndexs_success) { - auto builder = ut::GraphBuilder("graph"); - const auto node1 = builder.AddNode("node1", "node1", 0, 1); - const auto node2 = builder.AddNode("node2", "node2", 1, 0); - builder.AddDataEdge(node1, 0, node2, 0); - GNode gnode; - ASSERT_EQ(gnode.GetInDataNodesAndPortIndexs(0).first, nullptr); - gnode.impl_ = nullptr; - ASSERT_EQ(gnode.GetInDataNodesAndPortIndexs(0).first, nullptr); - gnode = NodeAdapter::Node2GNode(node2); - ASSERT_NE(gnode.GetInDataNodesAndPortIndexs(0).first, nullptr); -} - -TEST_F(GNodeTest, GetoutDataNodesAndPortIndexs_success) { - auto builder = ut::GraphBuilder("graph"); - const auto node1 = builder.AddNode("node1", "node1", 0, 1); - const auto node2 = builder.AddNode("node2", "node2", 1, 0); - builder.AddDataEdge(node1, 0, node2, 0); - GNode gnode; - ASSERT_EQ(gnode.GetOutDataNodesAndPortIndexs(0).size(), 0); - gnode.impl_ = nullptr; - ASSERT_EQ(gnode.GetOutDataNodesAndPortIndexs(0).size(), 0); - gnode = NodeAdapter::Node2GNode(node1); - ASSERT_EQ(gnode.GetOutDataNodesAndPortIndexs(0).size(), 1); -} - -TEST_F(GNodeTest, GetInControlNodes_success) { - auto builder = ut::GraphBuilder("graph"); - const auto node1 = builder.AddNode("node1", "node1", 1, 0); - const auto node2 = builder.AddNode("node2", "node2", 0, 1); - builder.AddControlEdge(node1, node2); - GNode gnode; - vector in_contorl_nodes = {}; - ASSERT_EQ(gnode.GetInControlNodes().size(), 0); - gnode.impl_ = nullptr; - ASSERT_EQ(gnode.GetInControlNodes().size(), 0); - gnode = NodeAdapter::Node2GNode(node2); - ASSERT_EQ(gnode.GetInControlNodes().size(), 1); -} - -TEST_F(GNodeTest, GetOutControlNodes_success) { - auto builder = ut::GraphBuilder("graph"); - const auto node1 = builder.AddNode("node1", "node1", 1, 0); - const auto node2 = builder.AddNode("node2", "node2", 0, 1); - builder.AddControlEdge(node1, node2); - GNode gnode; - vector in_contorl_nodes = {}; - ASSERT_EQ(gnode.GetOutControlNodes().size(), 0); - gnode.impl_ = nullptr; - ASSERT_EQ(gnode.GetOutControlNodes().size(), 0); - gnode = NodeAdapter::Node2GNode(node1); - ASSERT_EQ(gnode.GetOutControlNodes().size(), 1); -} - -TEST_F(GNodeTest, Node2GNodePtr_success) { - auto builder = ut::GraphBuilder("graph"); - NodePtr node = nullptr; - ASSERT_EQ(NodeAdapter::Node2GNodePtr(node), nullptr); - node = builder.AddNode("node", "node", 0, 0); - ASSERT_NE(NodeAdapter::Node2GNodePtr(node), nullptr); -} - -TEST_F(GNodeTest, Node2GNode2Node_success) { - auto builder = ut::GraphBuilder("graph"); - NodePtr node = nullptr; - ASSERT_EQ(NodeAdapter::GNode2Node(NodeAdapter::Node2GNode(node)), nullptr); - node = builder.AddNode("node", "node", 0, 0); - ASSERT_EQ(NodeAdapter::GNode2Node(NodeAdapter::Node2GNode(node)), node); -} - -TEST_F(GNodeTest, GetName_Type_success) { - auto builder = ut::GraphBuilder("graph"); - const auto node = builder.AddNode("name", "type", 0, 0); - GNode gnode; - AscendString name; - AscendString type; - ASSERT_EQ(gnode.GetName(name), GRAPH_FAILED); - ASSERT_EQ(gnode.GetType(type), GRAPH_FAILED); - gnode.impl_ = nullptr; - ASSERT_EQ(gnode.GetName(name), GRAPH_FAILED); - ASSERT_EQ(gnode.GetType(type), GRAPH_FAILED); - gnode = NodeAdapter::Node2GNode(node); - gnode.GetName(name); - gnode.GetType(type); - ASSERT_EQ(name, "name"); - ASSERT_EQ(type, "type"); -} - -TEST_F(GNodeTest, GetInputConstData_success) { - auto sub_builder = ut::GraphBuilder("graph"); - const auto &sub_node = sub_builder.AddNode("sub_node", "Data", 3, 1); - const auto &sub_in_data_node = sub_builder.AddNode("sub_in_data_node", "Data", 1, 1); - const auto &sub_in_const_node = sub_builder.AddNode("sub_in_const_node", "Const", 1, 1); - const auto &sub_in_other_node = sub_builder.AddNode("sub_in_other_node", "node_1", 0, 1); - sub_builder.AddDataEdge(sub_in_const_node, 0, sub_node, 1); - sub_builder.AddDataEdge(sub_in_other_node, 0, sub_node, 2); - sub_builder.AddDataEdge(sub_in_data_node, 0, sub_node, 0); - EXPECT_TRUE(AttrUtils::SetInt(sub_in_data_node->GetOpDesc(), "_parent_node_index", 0)); - auto root_builder = ut::GraphBuilder("graph1"); - auto root_graph = root_builder.GetGraph(); - auto sub_graph = sub_builder.GetGraph(); - const auto &root_in_node = sub_builder.AddNode("root_in_node", "Const", 0, 1); - const auto &root_node = sub_builder.AddNode("root_node", "Data1", 1, 1); - root_builder.AddDataEdge(root_in_node, 0, root_node, 0); - sub_graph->SetParentGraph(root_graph); - sub_graph->SetParentNode(root_node); - GeTensor getensor; - AttrUtils::SetTensor(root_in_node->GetOpDesc(), "value", getensor); - AttrUtils::SetTensor(sub_in_const_node->GetOpDesc(), "value", getensor); - root_node->GetOpDesc()->AddSubgraphName("sub_graph"); - root_node->GetOpDesc()->SetSubgraphInstanceName(0, "sub_graph"); - root_graph->AddSubgraph("sub_graph", sub_graph); - Tensor data; - GNode gnode; - gnode.impl_ = nullptr; - ASSERT_EQ(gnode.GetInputConstData(0, data), GRAPH_FAILED); - gnode = NodeAdapter::Node2GNode(sub_node); - ASSERT_EQ(gnode.GetInputConstData(0, data), GRAPH_SUCCESS); - gnode = NodeAdapter::Node2GNode(sub_node); - ASSERT_EQ(gnode.GetInputConstData(1, data), GRAPH_SUCCESS); - gnode = NodeAdapter::Node2GNode(sub_node); - ASSERT_EQ(gnode.GetInputConstData(2, data), GRAPH_NODE_WITHOUT_CONST_INPUT); -} - -TEST_F(GNodeTest, GetInputIndexByName_success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node = builder.AddNode("node", "node", 1, 1); - const auto &in_other_node = builder.AddNode("in_other_node", "node_1", 0, 1); - builder.AddDataEdge(in_other_node, 0, node, 0); - AscendString name = nullptr; - GNode gnode; - int input_index; - ASSERT_EQ(gnode.GetInputIndexByName(name, input_index), GRAPH_PARAM_INVALID); - ASSERT_EQ(gnode.GetInputIndexByName("in_other_node", input_index), GRAPH_FAILED); - gnode.impl_ = nullptr; - ASSERT_EQ(gnode.GetInputIndexByName("in_other_node", input_index), GRAPH_FAILED); - gnode = NodeAdapter::Node2GNode(node); - ASSERT_EQ(gnode.GetInputIndexByName("in_other_node", input_index), GRAPH_SUCCESS); -} - -TEST_F(GNodeTest, GetDynamicInputIndexesByName_Failed) { - auto builder = ut::GraphBuilder("graph"); - const auto &node = builder.AddNode("node", "node", 3, 1); - const auto &in_node = builder.AddNode("in_node", "node_1", 0, 1); - builder.AddDataEdge(in_node, 0, node, 0); - AscendString name = nullptr; - GNode gnode; - std::vector input_indexes; - ASSERT_EQ(gnode.GetDynamicInputIndexesByName(name, input_indexes), GRAPH_PARAM_INVALID); - ASSERT_EQ(gnode.GetDynamicInputIndexesByName("in_node", input_indexes), GRAPH_FAILED); - gnode.impl_ = nullptr; - ASSERT_EQ(gnode.GetDynamicInputIndexesByName("in_node", input_indexes), GRAPH_FAILED); - gnode = NodeAdapter::Node2GNode(node); - ASSERT_EQ(gnode.GetDynamicInputIndexesByName("in_node", input_indexes), GRAPH_FAILED); -} - -TEST_F(GNodeTest, GetDynamicInputIndexesByName_Success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node = builder.AddNode("node", "node", 3, 1); - const auto &in_node = builder.AddNode("in_node", "node_1", 0, 1); - builder.AddDataEdge(in_node, 0, node, 0); - GNode gnode; - std::vector input_indexes; - auto op_desc = node->GetOpDesc(); - ge::GeTensorDesc tensor_desc; - op_desc->AddInputDesc("in_node0", tensor_desc); - gnode = NodeAdapter::Node2GNode(node); - ASSERT_EQ(gnode.GetDynamicInputIndexesByName("in_node", input_indexes), GRAPH_SUCCESS); - ASSERT_EQ(input_indexes.size(), 1U); - input_indexes.clear(); - // 默认的名字是__input + index - ASSERT_EQ(gnode.GetDynamicInputIndexesByName("__input", input_indexes), GRAPH_SUCCESS); - ASSERT_EQ(input_indexes.size(), 3U); -} - -TEST_F(GNodeTest, GetOutputIndexByName_success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node = builder.AddNode("node", "node", 0, 1); - const auto &in_other_node = builder.AddNode("in_other_node", "node_1", 1, 0); - builder.AddDataEdge(node, 0, in_other_node, 0); - AscendString name = nullptr; - GNode gnode; - int input_index; - ASSERT_EQ(gnode.GetOutputIndexByName(name, input_index), GRAPH_PARAM_INVALID); - ASSERT_EQ(gnode.GetOutputIndexByName("in_other_node", input_index), GRAPH_FAILED); - gnode.impl_ = nullptr; - ASSERT_EQ(gnode.GetOutputIndexByName("in_other_node", input_index), GRAPH_FAILED); - gnode = NodeAdapter::Node2GNode(node); - ASSERT_EQ(gnode.GetOutputIndexByName("in_other_node", input_index), GRAPH_SUCCESS); -} - -TEST_F(GNodeTest, GetDynamicOutputIndexesByName_Failed) { - auto builder = ut::GraphBuilder("graph"); - const auto &node = builder.AddNode("node", "node", 0, 1); - const auto &out_node = builder.AddNode("out_node", "node_1", 1, 0); - builder.AddDataEdge(node, 0, out_node, 0); - AscendString name = nullptr; - GNode gnode; - std::vector output_indexes; - ASSERT_EQ(gnode.GetDynamicOutputIndexesByName(name, output_indexes), GRAPH_PARAM_INVALID); - ASSERT_EQ(gnode.GetDynamicOutputIndexesByName("out_node", output_indexes), GRAPH_FAILED); - gnode.impl_ = nullptr; - ASSERT_EQ(gnode.GetDynamicOutputIndexesByName("out_node", output_indexes), GRAPH_FAILED); - gnode = NodeAdapter::Node2GNode(out_node); - ASSERT_EQ(gnode.GetDynamicOutputIndexesByName("out_node", output_indexes), GRAPH_FAILED); -} - -TEST_F(GNodeTest, GetDynamicOutputIndexesByName_Success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node = builder.AddNode("node", "node", 0, 1); - const auto &out_node = builder.AddNode("out_node", "node_1", 1, 0); - builder.AddDataEdge(node, 0, out_node, 0); - AscendString name = nullptr; - GNode gnode; - std::vector output_indexes; - gnode = NodeAdapter::Node2GNode(node); - ASSERT_EQ(gnode.GetDynamicOutputIndexesByName("__output", output_indexes), GRAPH_SUCCESS); - ASSERT_EQ(output_indexes.size(), 1U); -} - -TEST_F(GNodeTest, GetInputs_outputs_Size_success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node = builder.AddNode("node", "node", 1, 1); - const auto &in_node = builder.AddNode("node_in", "node", 0, 1); - const auto &out_node = builder.AddNode("node_out", "node", 1, 0); - GNode gnode; - ASSERT_EQ(gnode.GetInputsSize(), GRAPH_FAILED); - ASSERT_EQ(gnode.GetOutputsSize(), GRAPH_FAILED); - gnode.impl_ = nullptr; - ASSERT_EQ(gnode.GetInputsSize(), GRAPH_FAILED); - ASSERT_EQ(gnode.GetOutputsSize(), GRAPH_FAILED); - gnode = NodeAdapter::Node2GNode(node); - ASSERT_EQ(gnode.GetInputsSize(), 1); - ASSERT_EQ(gnode.GetOutputsSize(), 1); -} - -TEST_F(GNodeTest, GetInputDesc_success) { - auto builder = ut::GraphBuilder("graph"); - const auto node = builder.AddNode("node", "node", 1, 1); - auto opdesc = node->GetOpDesc(); - GNode gnode; - TensorDesc tensordesc; - ASSERT_EQ(gnode.GetInputDesc(1, tensordesc), GRAPH_FAILED); - ASSERT_EQ(gnode.GetInputDesc(-1, tensordesc), GRAPH_PARAM_INVALID); - gnode.impl_ = nullptr; - ASSERT_EQ(gnode.GetInputDesc(0, tensordesc), GRAPH_FAILED); - gnode = NodeAdapter::Node2GNode(node); - ASSERT_EQ(gnode.GetInputDesc(0, tensordesc), GRAPH_SUCCESS); -} - -TEST_F(GNodeTest, UpdateInputDesc_success) { - auto builder = ut::GraphBuilder("graph"); - const auto node = builder.AddNode("node", "node", 1, 1); - GNode gnode; - const TensorDesc tensordesc; - ASSERT_EQ(gnode.UpdateInputDesc(-1, tensordesc), GRAPH_PARAM_INVALID); - ASSERT_EQ(gnode.UpdateInputDesc(1, tensordesc), GRAPH_FAILED); - gnode.impl_ = nullptr; - ASSERT_EQ(gnode.UpdateInputDesc(0, tensordesc), GRAPH_FAILED); - gnode = NodeAdapter::Node2GNode(node); - ASSERT_EQ(gnode.UpdateInputDesc(0, tensordesc), GRAPH_SUCCESS); -} - -TEST_F(GNodeTest, GetOutputDesc_success) { - auto builder = ut::GraphBuilder("graph"); - const auto node = builder.AddNode("node", "node", 1, 1); - auto opdesc = node->GetOpDesc(); - GNode gnode; - TensorDesc tensordesc; - ASSERT_EQ(gnode.GetOutputDesc(1, tensordesc), GRAPH_FAILED); - ASSERT_EQ(gnode.GetOutputDesc(-1, tensordesc), GRAPH_PARAM_INVALID); - gnode.impl_ = nullptr; - ASSERT_EQ(gnode.GetOutputDesc(0, tensordesc), GRAPH_FAILED); - gnode = NodeAdapter::Node2GNode(node); - ASSERT_EQ(gnode.GetOutputDesc(0, tensordesc), GRAPH_SUCCESS); -} - -TEST_F(GNodeTest, UpdateOutputDesc_success) { - auto builder = ut::GraphBuilder("graph"); - const auto node = builder.AddNode("node", "node", 1, 1); - GNode gnode; - const TensorDesc tensordesc; - ASSERT_EQ(gnode.UpdateOutputDesc(-1, tensordesc), GRAPH_PARAM_INVALID); - ASSERT_EQ(gnode.UpdateOutputDesc(1, tensordesc), GRAPH_FAILED); - gnode.impl_ = nullptr; - ASSERT_EQ(gnode.UpdateOutputDesc(0, tensordesc), GRAPH_FAILED); - gnode = NodeAdapter::Node2GNode(node); - ASSERT_EQ(gnode.UpdateOutputDesc(0, tensordesc), GRAPH_SUCCESS); -} - -TEST_F(GNodeTest, SetAttr1_success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node = builder.AddNode("node", "node", 0, 0); - GNode gnode; - AscendString name = nullptr; - vector attr_value; - ASSERT_EQ(gnode.SetAttr(name, attr_value), GRAPH_PARAM_INVALID); - attr_value.emplace_back(name); - name = "node"; - ASSERT_EQ(gnode.SetAttr(name, attr_value), GRAPH_PARAM_INVALID); - attr_value.clear(); - attr_value.emplace_back(name); - ASSERT_EQ(gnode.SetAttr(name, attr_value), GRAPH_FAILED); - gnode.impl_ = nullptr; - ASSERT_EQ(gnode.SetAttr(name, attr_value), GRAPH_FAILED); - gnode = NodeAdapter::Node2GNode(node); - ASSERT_EQ(gnode.SetAttr(name, attr_value), GRAPH_SUCCESS); -} - -TEST_F(GNodeTest, SetAttr2_success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node = builder.AddNode("node", "node", 0, 0); - GNode gnode; - AscendString name = nullptr; - AscendString attr_value = nullptr; - ASSERT_EQ(gnode.SetAttr(name, attr_value), GRAPH_PARAM_INVALID); - name = "node"; - ASSERT_NE(gnode.SetAttr(name, attr_value), GRAPH_SUCCESS); - attr_value = "value"; - ASSERT_EQ(gnode.SetAttr(name, attr_value), GRAPH_FAILED); - gnode.impl_ = nullptr; - ASSERT_EQ(gnode.SetAttr(name, attr_value), GRAPH_FAILED); - gnode = NodeAdapter::Node2GNode(node); - ASSERT_EQ(gnode.SetAttr(name, attr_value), GRAPH_SUCCESS); -} - -TEST_F(GNodeTest, SetAttr3_success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node = builder.AddNode("node", "node", 0, 0); - GNode gnode; - AscendString name = nullptr; - AttrValue attr_value; - ASSERT_EQ(gnode.SetAttr(name, attr_value), GRAPH_PARAM_INVALID); - name = "node"; - ASSERT_EQ(gnode.SetAttr(name, attr_value), GRAPH_FAILED); - gnode.impl_ = nullptr; - ASSERT_EQ(gnode.SetAttr(name, attr_value), GRAPH_FAILED); - gnode = NodeAdapter::Node2GNode(node); - ASSERT_EQ(gnode.SetAttr(name, attr_value), GRAPH_SUCCESS); -} - -TEST_F(GNodeTest, GetAttr1_success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node = builder.AddNode("node", "node", 0, 0); - GNode gnode; - AscendString name = nullptr; - AscendString attr_value = "value"; - ASSERT_EQ(gnode.GetAttr(name, attr_value), GRAPH_PARAM_INVALID); - name = "node"; - ASSERT_EQ(gnode.GetAttr(name, attr_value), GRAPH_FAILED); - gnode.impl_ = nullptr; - ASSERT_EQ(gnode.GetAttr(name, attr_value), GRAPH_FAILED); - gnode = NodeAdapter::Node2GNode(node); - ASSERT_EQ(gnode.GetAttr("max_size", attr_value), GRAPH_FAILED); - gnode.SetAttr(name, attr_value); - ASSERT_EQ(gnode.GetAttr(name, attr_value), GRAPH_SUCCESS); -} - -TEST_F(GNodeTest, GetAttr2_success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node = builder.AddNode("node", "node", 0, 0); - GNode gnode; - AscendString name = nullptr; - vector attr_value = {"value"}; - ASSERT_EQ(gnode.GetAttr(name, attr_value), GRAPH_PARAM_INVALID); - name = "node"; - ASSERT_EQ(gnode.GetAttr(name, attr_value), GRAPH_FAILED); - gnode.impl_ = nullptr; - ASSERT_EQ(gnode.GetAttr(name, attr_value), GRAPH_FAILED); - gnode = NodeAdapter::Node2GNode(node); - ASSERT_EQ(gnode.GetAttr("max_size", attr_value), GRAPH_FAILED); - int32_t attr_info = 0; - ASSERT_EQ(gnode.GetAttr("max_size", attr_info), GRAPH_FAILED); - gnode.SetAttr(name, attr_value); - attr_value.clear(); - ASSERT_EQ(gnode.GetAttr(name, attr_value), GRAPH_SUCCESS); -} - -TEST_F(GNodeTest, HasAttr_Success) { - auto builder = ut::GraphBuilder("graph"); - const auto node = builder.AddNode("node", "node", 0, 0); - GNode gnode; - AscendString name = nullptr; - AscendString attr_value = "value"; - ASSERT_EQ(gnode.HasAttr(name), false); - name = "node"; - ASSERT_EQ(gnode.HasAttr(name), false); - gnode.impl_ = nullptr; - ASSERT_EQ(gnode.HasAttr(name), false); - gnode = NodeAdapter::Node2GNode(node); - ASSERT_EQ(gnode.HasAttr("max_size"), false); - gnode.SetAttr(name, attr_value); - ASSERT_EQ(gnode.HasAttr(name), true); -} - -TEST_F(GNodeTest, GetSubGraph_success) { - auto sub_builder = ut::GraphBuilder("sub_graph"); - auto sub_graph = sub_builder.GetGraph(); - auto root_builder = ut::GraphBuilder("root_graph"); - const auto node = root_builder.AddNode("node", "node", 1, 1); - auto root_graph = root_builder.GetGraph(); - sub_graph->SetParentGraph(root_graph); - sub_graph->SetParentNode(node); - node->GetOpDesc()->AddSubgraphName("sub_graph"); - node->GetOpDesc()->SetSubgraphInstanceName(0, "sub_graph"); - root_graph->AddSubgraph("sub_graph", sub_graph); - GNode gnode; - GraphPtr graph; - ASSERT_EQ(gnode.GetSubgraph(0U, graph), GRAPH_FAILED); - gnode.impl_ = nullptr; - ASSERT_EQ(gnode.GetSubgraph(0U, graph), GRAPH_FAILED); - gnode = NodeAdapter::Node2GNode(node); - ASSERT_EQ(gnode.GetSubgraph(0U, graph), GRAPH_SUCCESS); -} -extern "C" { -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus aclCom_GNode_GetOutputAttr(void *node_ptr, const char *name, - int32_t output_index, - void *attr_value) { - if (node_ptr == nullptr || name == nullptr || attr_value == nullptr) { - return GRAPH_FAILED; - } - auto *node = static_cast(node_ptr); - auto *value = static_cast(attr_value); - return node->GetOutputAttr(ge::AscendString(name), output_index, *value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus aclCom_GNode_SetOutputAttr(void *node_ptr, const char *name, - int32_t output_index, - const void *attr_value) { - if (node_ptr == nullptr || name == nullptr || attr_value == nullptr) { - return GRAPH_FAILED; - } - auto *node = static_cast(node_ptr); - auto *value = static_cast(attr_value); - return node->SetOutputAttr(ge::AscendString(name), output_index, *value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus aclCom_GNode_GetInputAttr(void *node_ptr, const char *name, - int32_t input_index, - void *attr_value) { - if (node_ptr == nullptr || name == nullptr || attr_value == nullptr) { - return GRAPH_FAILED; - } - auto *node = static_cast(node_ptr); - auto *value = static_cast(attr_value); - return node->GetInputAttr(ge::AscendString(name), input_index, *value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus aclCom_GNode_SetInputAttr(void *node_ptr, const char *name, - int32_t input_index, - const void *attr_value) { - if (node_ptr == nullptr || name == nullptr || attr_value == nullptr) { - return GRAPH_FAILED; - } - auto *node = static_cast(node_ptr); - auto *value = static_cast(attr_value); - return node->SetInputAttr(ge::AscendString(name), input_index, *value); -} -} -TEST_F(GNodeTest, ExternC_GNode_SetOutputAttr_Int64_Success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node = builder.AddNode("node", "node", 1, 1); - GNode gnode = NodeAdapter::Node2GNode(node); - - AttrValue attr_value; - int64_t int64_value = 12345; - attr_value.SetAttrValue(int64_value); - - // 测试成功情况 - EXPECT_EQ(aclCom_GNode_SetOutputAttr(&gnode, "test_attr", 0, &attr_value), GRAPH_SUCCESS); - - // 验证设置的值 - AttrValue get_attr_value; - EXPECT_EQ(gnode.GetOutputAttr("test_attr", 0, get_attr_value), GRAPH_SUCCESS); - int64_t get_value = 0; - EXPECT_EQ(get_attr_value.GetAttrValue(get_value), GRAPH_SUCCESS); - EXPECT_EQ(get_value, int64_value); - - // 测试nullptr参数 - EXPECT_EQ(aclCom_GNode_SetOutputAttr(nullptr, "test_attr", 0, &attr_value), GRAPH_FAILED); - EXPECT_EQ(aclCom_GNode_SetOutputAttr(&gnode, nullptr, 0, &attr_value), GRAPH_FAILED); - EXPECT_EQ(aclCom_GNode_SetOutputAttr(&gnode, "test_attr", 0, nullptr), GRAPH_FAILED); -} - -TEST_F(GNodeTest, ExternC_GNode_SetOutputAttr_String_Success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node = builder.AddNode("node", "node", 1, 1); - GNode gnode = NodeAdapter::Node2GNode(node); - - AttrValue attr_value; - const char *str_value = "test_string"; - attr_value.SetAttrValue(AscendString(str_value)); - - // 测试成功情况 - EXPECT_EQ(aclCom_GNode_SetOutputAttr(&gnode, "test_attr", 0, &attr_value), GRAPH_SUCCESS); - - // 验证设置的值 - AttrValue get_attr_value; - EXPECT_EQ(aclCom_GNode_GetOutputAttr(&gnode, "test_attr", 0, &get_attr_value), GRAPH_SUCCESS); - AscendString get_value; - EXPECT_EQ(get_attr_value.GetAttrValue(get_value), GRAPH_SUCCESS); - EXPECT_STREQ(get_value.GetString(), str_value); - - // 测试nullptr参数 - EXPECT_EQ(aclCom_GNode_SetOutputAttr(nullptr, "test_attr", 0, &attr_value), GRAPH_FAILED); - EXPECT_EQ(aclCom_GNode_SetOutputAttr(&gnode, nullptr, 0, &attr_value), GRAPH_FAILED); - EXPECT_EQ(aclCom_GNode_SetOutputAttr(&gnode, "test_attr", 0, nullptr), GRAPH_FAILED); -} - -TEST_F(GNodeTest, ExternC_GNode_SetOutputAttr_Bool_Success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node = builder.AddNode("node", "node", 1, 1); - GNode gnode = NodeAdapter::Node2GNode(node); - - AttrValue attr_value; - bool bool_value = true; - attr_value.SetAttrValue(bool_value); - - // 测试成功情况 - EXPECT_EQ(aclCom_GNode_SetOutputAttr(&gnode, "test_attr", 0, &attr_value), GRAPH_SUCCESS); - - // 验证设置的值 - AttrValue get_attr_value; - EXPECT_EQ(gnode.GetOutputAttr("test_attr", 0, get_attr_value), GRAPH_SUCCESS); - bool get_value = false; - EXPECT_EQ(get_attr_value.GetAttrValue(get_value), GRAPH_SUCCESS); - EXPECT_EQ(get_value, bool_value); - - // 测试nullptr参数 - EXPECT_EQ(aclCom_GNode_SetOutputAttr(nullptr, "test_attr", 0, &attr_value), GRAPH_FAILED); - EXPECT_EQ(aclCom_GNode_SetOutputAttr(&gnode, nullptr, 0, &attr_value), GRAPH_FAILED); - EXPECT_EQ(aclCom_GNode_SetOutputAttr(&gnode, "test_attr", 0, nullptr), GRAPH_FAILED); -} - -TEST_F(GNodeTest, ExternC_GNode_SetOutputAttr_InvalidIndex) { - auto builder = ut::GraphBuilder("graph"); - const auto &node = builder.AddNode("node", "node", 1, 1); - GNode gnode = NodeAdapter::Node2GNode(node); - - // 测试无效索引 - AttrValue attr_value; - int64_t int64_value = 12345; - attr_value.SetAttrValue(int64_value); - EXPECT_NE(aclCom_GNode_SetOutputAttr(&gnode, "test_attr", 999, &attr_value), GRAPH_SUCCESS); -} - -TEST_F(GNodeTest, ExternC_GNode_SetOutputAttr_EmptyGNode) { - GNode empty_gnode; - - // 测试空GNode - AttrValue attr_value; - int64_t int64_value = 12345; - attr_value.SetAttrValue(int64_value); - EXPECT_NE(aclCom_GNode_SetOutputAttr(&empty_gnode, "test_attr", 0, &attr_value), GRAPH_SUCCESS); -} - -TEST_F(GNodeTest, ExternC_GNode_SetInputAttr_Int64_Success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node = builder.AddNode("node", "node", 1, 1); - GNode gnode = NodeAdapter::Node2GNode(node); - - AttrValue attr_value; - int64_t int64_value = 12345; - attr_value.SetAttrValue(int64_value); - - // 测试成功情况 - EXPECT_EQ(aclCom_GNode_SetInputAttr(&gnode, "test_input_attr", 0, &attr_value), GRAPH_SUCCESS); - - // 验证设置的值 - AttrValue get_attr_value; - EXPECT_EQ(gnode.GetInputAttr("test_input_attr", 0, get_attr_value), GRAPH_SUCCESS); - int64_t get_value = 0; - EXPECT_EQ(get_attr_value.GetAttrValue(get_value), GRAPH_SUCCESS); - EXPECT_EQ(get_value, int64_value); - - // 测试nullptr参数 - EXPECT_EQ(aclCom_GNode_SetInputAttr(nullptr, "test_input_attr", 0, &attr_value), GRAPH_FAILED); - EXPECT_EQ(aclCom_GNode_SetInputAttr(&gnode, nullptr, 0, &attr_value), GRAPH_FAILED); - EXPECT_EQ(aclCom_GNode_SetInputAttr(&gnode, "test_input_attr", 0, nullptr), GRAPH_FAILED); -} - -TEST_F(GNodeTest, ExternC_GNode_SetInputAttr_String_Success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node = builder.AddNode("node", "node", 1, 1); - GNode gnode = NodeAdapter::Node2GNode(node); - - AttrValue attr_value; - const char *str_value = "test_input_string"; - attr_value.SetAttrValue(AscendString(str_value)); - - // 测试成功情况 - EXPECT_EQ(aclCom_GNode_SetInputAttr(&gnode, "test_input_attr", 0, &attr_value), GRAPH_SUCCESS); - - // 验证设置的值 - AttrValue get_attr_value; - EXPECT_EQ(aclCom_GNode_GetInputAttr(&gnode, "test_input_attr", 0, &get_attr_value), GRAPH_SUCCESS); - AscendString get_value; - EXPECT_EQ(get_attr_value.GetAttrValue(get_value), GRAPH_SUCCESS); - EXPECT_STREQ(get_value.GetString(), str_value); - - // 测试nullptr参数 - EXPECT_EQ(aclCom_GNode_SetInputAttr(nullptr, "test_input_attr", 0, &attr_value), GRAPH_FAILED); - EXPECT_EQ(aclCom_GNode_SetInputAttr(&gnode, nullptr, 0, &attr_value), GRAPH_FAILED); - EXPECT_EQ(aclCom_GNode_SetInputAttr(&gnode, "test_input_attr", 0, nullptr), GRAPH_FAILED); -} - -TEST_F(GNodeTest, ExternC_GNode_SetInputAttr_Bool_Success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node = builder.AddNode("node", "node", 1, 1); - GNode gnode = NodeAdapter::Node2GNode(node); - - AttrValue attr_value; - bool bool_value = true; - attr_value.SetAttrValue(bool_value); - - // 测试成功情况 - EXPECT_EQ(aclCom_GNode_SetInputAttr(&gnode, "test_input_attr", 0, &attr_value), GRAPH_SUCCESS); - - // 验证设置的值 - AttrValue get_attr_value; - EXPECT_EQ(gnode.GetInputAttr("test_input_attr", 0, get_attr_value), GRAPH_SUCCESS); - bool get_value = false; - EXPECT_EQ(get_attr_value.GetAttrValue(get_value), GRAPH_SUCCESS); - EXPECT_EQ(get_value, bool_value); - - // 测试nullptr参数 - EXPECT_EQ(aclCom_GNode_SetInputAttr(nullptr, "test_input_attr", 0, &attr_value), GRAPH_FAILED); - EXPECT_EQ(aclCom_GNode_SetInputAttr(&gnode, nullptr, 0, &attr_value), GRAPH_FAILED); - EXPECT_EQ(aclCom_GNode_SetInputAttr(&gnode, "test_input_attr", 0, nullptr), GRAPH_FAILED); -} - -TEST_F(GNodeTest, ExternC_GNode_SetInputAttr_VectorInt64_Success) { - auto builder = ut::GraphBuilder("graph"); - const auto &node = builder.AddNode("node", "node", 1, 1); - GNode gnode = NodeAdapter::Node2GNode(node); - - AttrValue attr_value; - std::vector vector_value = {1, 2, 3, 4, 5}; - attr_value.SetAttrValue(vector_value); - - // 测试成功情况 - EXPECT_EQ(aclCom_GNode_SetInputAttr(&gnode, "test_input_attr", 0, &attr_value), GRAPH_SUCCESS); - - // 验证设置的值 - AttrValue get_attr_value; - EXPECT_EQ(gnode.GetInputAttr("test_input_attr", 0, get_attr_value), GRAPH_SUCCESS); - std::vector get_value; - EXPECT_EQ(get_attr_value.GetAttrValue(get_value), GRAPH_SUCCESS); - EXPECT_EQ(get_value, vector_value); - - // 测试nullptr参数 - EXPECT_EQ(aclCom_GNode_SetInputAttr(nullptr, "test_input_attr", 0, &attr_value), GRAPH_FAILED); - EXPECT_EQ(aclCom_GNode_SetInputAttr(&gnode, nullptr, 0, &attr_value), GRAPH_FAILED); - EXPECT_EQ(aclCom_GNode_SetInputAttr(&gnode, "test_input_attr", 0, nullptr), GRAPH_FAILED); -} - -TEST_F(GNodeTest, ExternC_GNode_SetInputAttr_InvalidIndex) { - auto builder = ut::GraphBuilder("graph"); - const auto &node = builder.AddNode("node", "node", 1, 1); - GNode gnode = NodeAdapter::Node2GNode(node); - - // 测试无效索引 - AttrValue attr_value; - int64_t int64_value = 12345; - attr_value.SetAttrValue(int64_value); - EXPECT_NE(aclCom_GNode_SetInputAttr(&gnode, "test_input_attr", 999, &attr_value), GRAPH_SUCCESS); -} - -TEST_F(GNodeTest, ExternC_GNode_SetInputAttr_EmptyGNode) { - GNode empty_gnode; - - // 测试空GNode - AttrValue attr_value; - int64_t int64_value = 12345; - attr_value.SetAttrValue(int64_value); - EXPECT_NE(aclCom_GNode_SetInputAttr(&empty_gnode, "test_input_attr", 0, &attr_value), GRAPH_SUCCESS); -} - -REG_OP(subTest) - .INPUT(inx, TensorType::ALL()) - .OUTPUT(y, TensorType::ALL()) - .GRAPH(subgraph) - .OP_END_FACTORY_REG(subTest); - -REG_OP(dynamicSubTest) - .INPUT(inx, TensorType::ALL()) - .OUTPUT(y, TensorType::ALL()) - .DYNAMIC_GRAPH(subgraphs) - .OP_END_FACTORY_REG(dynamicSubTest); - -REG_OP(dataOp) - .OUTPUT(y, TensorType::ALL()) - .ATTR(value, Int, 0) - .OP_END_FACTORY_REG(dataOp); - -TEST_F(GNodeTest, TestAddSubGraph_success) { - auto op = op::subTest("subTest"); - Operator op_input1 = op::dataOp("op_input1"); - std::vector inputs = {op_input1}; - AscendString name = "graph"; - GraphPtr graph = Graph::ConstructFromInputs(inputs, name); - auto gnode = graph->AddNodeByOp(op); - - name = "subgraph1"; - Operator op_input2 = op::dataOp("op_input2"); - inputs = {op_input2}; - auto subgraph = Graph::ConstructFromInputs(inputs, name); - ASSERT_EQ(GRAPH_SUCCESS, gnode.SetSubgraph("subgraph", *subgraph.get())); - - auto parent_graph = subgraph->GetParentGraph(); - AscendString parent_graph_name; - parent_graph->GetName(parent_graph_name); - AscendString graph_name; - graph->GetName(graph_name); - ASSERT_TRUE(parent_graph_name == graph_name); - - AscendString find_name; - op.GetName(find_name); - auto get_node = parent_graph->FindNodeByName(find_name); - AscendString exp_name("subTest"); - ASSERT_TRUE(find_name == exp_name); - - std::vector subgraph_names{}; - ASSERT_EQ(GRAPH_SUCCESS, op.GetSubgraphNames(subgraph_names)); - ASSERT_EQ(1, subgraph_names.size()); - AscendString exp_subgraph_name("subgraph"); - ASSERT_TRUE(exp_subgraph_name == subgraph_names.at(0)); -} - -TEST_F(GNodeTest, TestAddSubGraph_failure_invalid_graph_and_node) { - auto op = op::subTest("subTest"); - Operator op_input1 = op::dataOp("op_input1"); - std::vector inputs1 = {op_input1}; - AscendString name = "graph"; - GraphPtr graph_valid = Graph::ConstructFromInputs(inputs1, name); - auto node_valid = graph_valid->AddNodeByOp(op); - - Graph graph_invalid("graph_invalid"); - ASSERT_EQ(nullptr, graph_invalid.GetParentGraph()); - Graph subgraph_invalid("subgraph_invalid"); - ASSERT_EQ(ge::PARAM_INVALID, node_valid.SetSubgraph("subgraph", subgraph_invalid)); - - Operator op_input2 = op::dataOp("op_input"); - std::vector inputs2 = {op_input2}; - GraphPtr subgraph_valid = Graph::ConstructFromInputs(inputs2, name); - ASSERT_EQ(GRAPH_SUCCESS, node_valid.SetSubgraph("subgraph", *subgraph_valid.get())); -} - -TEST_F(GNodeTest, TestAddStaticSubGraph_success_different_name) { - auto op = op::subTest("subTest"); - Operator op_input1 = op::dataOp("op_input1"); - std::vector inputs1 = {op_input1}; - AscendString name = "graph"; - GraphPtr graph_valid = Graph::ConstructFromInputs(inputs1, name); - auto node_valid = graph_valid->AddNodeByOp(op); - - Operator op_input2 = op::dataOp("op_input"); - std::vector inputs2 = {op_input2}; - GraphPtr subgraph_valid = Graph::ConstructFromInputs(inputs2, name); - ASSERT_EQ(GRAPH_SUCCESS, node_valid.SetSubgraph("subgraph_diff", *subgraph_valid.get())); -} - -TEST_F(GNodeTest, TestAddDynamicSubGraph_success) { - auto op = op::dynamicSubTest("dynamicSubTest"); - Operator op_input_parent = op::dataOp("op_input1"); - std::vector inputs = {op_input_parent}; - AscendString name = "graph02"; - GraphPtr graph = Graph::ConstructFromInputs(inputs, name); - auto gnode = graph->AddNodeByOp(op); - - name = "subgraph1"; - Operator op_input1 = op::dataOp("op_input1"); - inputs = {op_input1}; - GraphPtr subgraph1 = Graph::ConstructFromInputs(inputs, name); - name = "subgraph2"; - Operator op_input2 = op::dataOp("op_input2"); - inputs = {op_input2}; - GraphPtr subgraph2 = Graph::ConstructFromInputs(inputs, name); - name = "subgraph3"; - Operator op_input3 = op::dataOp("op_input3"); - inputs = {op_input3}; - GraphPtr subgraph3 = Graph::ConstructFromInputs(inputs, name); - std::vector dynamic_subgraphs{*subgraph1.get(), *subgraph2.get(), *subgraph3.get()}; - ASSERT_EQ(GRAPH_SUCCESS, gnode.SetSubgraphs("subgraphs", dynamic_subgraphs)); - - auto parent_graph = subgraph1->GetParentGraph(); - AscendString parent_graph_name; - parent_graph->GetName(parent_graph_name); - AscendString graph_name; - graph->GetName(graph_name); - ASSERT_TRUE(parent_graph_name == graph_name); - - AscendString find_name; - op.GetName(find_name); - auto get_node = parent_graph->FindNodeByName(find_name); - AscendString exp_name("dynamicSubTest"); - ASSERT_TRUE(find_name == exp_name); - - std::vector ir_subgraph_names(0); - ASSERT_EQ(GRAPH_SUCCESS, op.GetSubgraphNames(ir_subgraph_names)); - ASSERT_EQ(1, ir_subgraph_names.size()); - AscendString exp_subgraph_name("subgraphs"); - ASSERT_TRUE(exp_subgraph_name == ir_subgraph_names.at(0)); - - std::vector dynamic_subgraph_vec{}; - ASSERT_EQ(GRAPH_SUCCESS, gnode.GetALLSubgraphs(dynamic_subgraph_vec)); - ASSERT_EQ(dynamic_subgraphs.size(), dynamic_subgraph_vec.size()); - AscendString subgraph_name{}; - ASSERT_EQ(GRAPH_SUCCESS, dynamic_subgraph_vec.at(0)->GetName(subgraph_name)); - AscendString exp_instance_subgraphs_name("subgraph1"); - ASSERT_TRUE(exp_instance_subgraphs_name == subgraph_name); - ASSERT_EQ(GRAPH_SUCCESS, dynamic_subgraph_vec.at(1)->GetName(subgraph_name)); - exp_instance_subgraphs_name = "subgraph2"; - ASSERT_TRUE(exp_instance_subgraphs_name == subgraph_name); - ASSERT_EQ(GRAPH_SUCCESS, dynamic_subgraph_vec.at(2)->GetName(subgraph_name)); - exp_instance_subgraphs_name = "subgraph3"; - ASSERT_TRUE(exp_instance_subgraphs_name == subgraph_name); - - GNodePtr op_gnode = parent_graph.get()->FindNodeByName("dynamicSubTest"); - auto op_node = NodeAdapter::GNode2Node(*op_gnode.get()).get(); - auto op_desc = op_node->GetOpDesc(); - auto &subgraph = op_desc->GetSubgraphNameIndexes(); - ASSERT_EQ(dynamic_subgraphs.size(), subgraph.size()); - ASSERT_EQ(2, subgraph.find("subgraphs2")->second); -} - -TEST_F(GNodeTest, TestAddDynamicSubGraph_failure_invalid_graph_and_node) { - auto op = op::dynamicSubTest("dynamicSubTest"); - Operator op_input1 = op::dataOp("op_input1"); - std::vector inputs1 = {op_input1}; - AscendString name = "graph"; - GraphPtr graph_valid = Graph::ConstructFromInputs(inputs1, name); - auto node_valid = graph_valid->AddNodeByOp(op); - - Graph graph_invalid("graph_invalid"); - ASSERT_EQ(nullptr, graph_invalid.GetParentGraph()); - Graph subgraph_invalid("subgraph_invalid"); - ASSERT_EQ(ge::PARAM_INVALID, node_valid.SetSubgraph("subgraphs", subgraph_invalid)); - - Operator op_input2 = op::dataOp("op_input"); - std::vector inputs2 = {op_input2}; - GraphPtr subgraph_valid = Graph::ConstructFromInputs(inputs2, name); - ASSERT_EQ(GRAPH_SUCCESS, node_valid.SetSubgraph("branch1", *subgraph_valid.get())); - - auto empty_node = GNode(); - ASSERT_EQ(PARAM_INVALID, empty_node.SetSubgraph("subgraphs", *subgraph_valid.get())); - ASSERT_EQ(PARAM_INVALID, empty_node.SetSubgraphs("subgraphs", std::vector{*subgraph_valid.get()})); - - empty_node.impl_ = nullptr; - ASSERT_EQ(PARAM_INVALID, empty_node.SetSubgraph("subgraphs", *subgraph_valid.get())); - ASSERT_EQ(PARAM_INVALID, empty_node.SetSubgraphs("subgraphs", std::vector{*subgraph_valid.get()})); -} - -TEST_F(GNodeTest, TestAddDynamicSubGraph_success_different_name) { - auto op = op::dynamicSubTest("dynamicSubTest"); - Operator op_input1 = op::dataOp("op_input1"); - std::vector inputs1 = {op_input1}; - AscendString name = "graph"; - GraphPtr graph_valid = Graph::ConstructFromInputs(inputs1, name); - auto node_valid = graph_valid->AddNodeByOp(op); - - Operator op_input2 = op::dataOp("op_input"); - std::vector inputs2 = {op_input2}; - GraphPtr subgraph_valid = Graph::ConstructFromInputs(inputs2, name); - ASSERT_EQ(GRAPH_SUCCESS, node_valid.SetSubgraphs("branch1", std::vector{*subgraph_valid.get()})); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/graph_dump_utils_unittest.cc b/tests/ut/graph/testcase/graph_dump_utils_unittest.cc deleted file mode 100644 index 4c4857678192bedf478bbfc34a23752e79c4da9c..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/graph_dump_utils_unittest.cc +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include "graph/utils/graph_dump_utils.h" -#include "graph/ge_context.h" -#include "graph/utils/file_utils.h" -#include "graph_builder_utils.h" -#include "common/share_graph.h" -#include "mmpa/mmpa_api.h" - -namespace ge { -std::stringstream GetFilePathWhenDumpPathSet(const string &ascend_work_path) { - std::stringstream dump_file_path; - dump_file_path << ascend_work_path << "/pid_" << mmGetPid() << "_deviceid_" << GetContext().DeviceId() << "/"; - return dump_file_path; -} -namespace { -std::string GetSpecificFilePath(const std::string &file_path, const string &suffix) { - DIR *dir; - struct dirent *ent; - dir = opendir(file_path.c_str()); - if (dir == nullptr) { - return ""; - } - while ((ent = readdir(dir)) != nullptr) { - if (strstr(ent->d_name, suffix.c_str()) != nullptr) { - std::string d_name(ent->d_name); - closedir(dir); - return file_path + "/" + d_name; - } - } - closedir(dir); - return ""; -} -} // namespace -using ExecuteSharedGraph = SharedGraph; -class UtestGraphDumpUtils : public testing::Test { - public: - static void SetUpTestCase() { - dump_graph_path_ = "./test_ge_graph_path"; - mmSetEnv("DUMP_GE_GRAPH", "1", 1); - mmSetEnv("DUMP_GRAPH_LEVEL", "1", 1); - mmSetEnv("DUMP_GRAPH_PATH", dump_graph_path_.c_str(), 1); - } - static void TearDownTestCase() { - unsetenv("DUMP_GE_GRAPH"); - unsetenv("DUMP_GRAPH_LEVEL"); - unsetenv("DUMP_GRAPH_PATH"); - system(("rm -rf " + dump_graph_path_).c_str()); - } - - static std::string dump_graph_path_; -}; -std::string UtestGraphDumpUtils::dump_graph_path_; - -TEST_F(UtestGraphDumpUtils, DumpGraph_Ok_with_execute_graph) { - auto exe_graph = ExecuteSharedGraph::BuildGraphWithSubGraph(); - DumpGraph(exe_graph.get(), "exe_graph_test"); - - std::stringstream dump_file_path = GetFilePathWhenDumpPathSet(dump_graph_path_); - std::string file_path = ge::RealPath(dump_file_path.str().c_str()); - // root graph - ComputeGraphPtr compute_graph = std::make_shared("GeTestGraph"); - ASSERT_EQ(GraphUtils::LoadGEGraph(GetSpecificFilePath(file_path, "_exe_graph_test.txt").c_str(), *compute_graph), true); - ASSERT_EQ(compute_graph->GetDirectNodesSize(), 5); - ASSERT_EQ(compute_graph->GetAllSubgraphs().size(), 2); - // sub graph 0 should not dump - ComputeGraphPtr subgraph_0 = std::make_shared("subgraph_0"); - ASSERT_EQ(GraphUtils::LoadGEGraph(GetSpecificFilePath(file_path, "_exe_graph_test_sub_graph_0.txt").c_str(), *subgraph_0), false); - // sub graph 1 should not dump - ComputeGraphPtr subgraph_1 = std::make_shared("subgraph_1"); - ASSERT_EQ(GraphUtils::LoadGEGraph(GetSpecificFilePath(file_path, "_exe_graph_test_sub_graph_1.txt").c_str(), *subgraph_1), false); - - system(("rm -rf " + dump_graph_path_).c_str()); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/graph_unittest.cc b/tests/ut/graph/testcase/graph_unittest.cc deleted file mode 100644 index 9371091c815da807ab68c3e1266e1f2c6413f014..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/graph_unittest.cc +++ /dev/null @@ -1,2027 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include "graph/graph.h" -#include "graph/operator.h" -#include "graph/compute_graph.h" -#include "graph/normal_graph/compute_graph_impl.h" -#include "graph/op_desc.h" -#include "graph/node.h" -#include "graph/utils/graph_utils.h" -#include "external/graph/graph.h" -#include "graph/normal_graph/compute_graph_impl.h" -#include "inc/external/graph/operator_reg.h" -#include "inc/external/graph/operator.h" -#include "inc/external/graph/operator_factory.h" -#include "inc/external/graph/graph.h" -#include "inc/external/graph/graph_buffer.h" -#include "inc/graph/operator_factory_impl.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/graph_utils_ex.h" -#include "graph/utils/op_desc_utils_ex.h" -#include "graph_builder_utils.h" -#include "graph/ge_attr_value.h" -#include "ge_ir.pb.h" -#include "inc/common/ge_common/ge_inner_error_codes.h" -#include "inc/external/graph/tensor.h" -#include "inc/external/graph/ascend_string.h" -#include "inc/external/graph/types.h" -#include "graph/ge_context.h" - -#include -#include -#include -#include "graph/utils/file_utils.h" -#include "proto/onnx/ge_onnx.pb.h" -using namespace ge; -namespace { -std::stringstream GetFilePathWhenDumpPathSet(const string &ascend_work_path) { - std::stringstream dump_file_path; - dump_file_path << ascend_work_path << "/pid_" << mmGetPid() << "_deviceid_" << GetContext().DeviceId() << "/"; - return dump_file_path; -} -std::string GetSpecificFilePath(const std::string &file_path, const string &suffix) { - DIR *dir; - struct dirent *ent; - dir = opendir(file_path.c_str()); - if (dir == nullptr) { - return ""; - } - while ((ent = readdir(dir)) != nullptr) { - if (strstr(ent->d_name, suffix.c_str()) != nullptr) { - std::string d_name(ent->d_name); - closedir(dir); - return file_path + "/" + d_name; - } - } - closedir(dir); - return ""; -} -} // namespace -class UtestGraph : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -struct ExpectNodeInfo { - std::string name; - std::string type; - std::map> input_node_name; - std::map>> output_node_name; - std::vector control_input_node_name; - std::vector control_output_node_name; - int32_t input_desc_size; - int32_t output_desc_size; - ExpectNodeInfo(const std::string &in_name, const std::string &in_type, - const std::map> &in_input_node_name, - const std::map>> &in_output_node_name, - const std::vector &in_control_input_node_name, - const std::vector &in_control_output_node_name, - const int32_t in_input_desc_size, - const int32_t in_output_desc_size) - : name(in_name), type(in_type), input_node_name(in_input_node_name), - output_node_name(in_output_node_name), - control_input_node_name(in_control_input_node_name), - control_output_node_name(in_control_output_node_name), - input_desc_size(in_input_desc_size), output_desc_size(in_output_desc_size) {} -}; - -static ComputeGraphPtr BuildSubComputeGraph() { - ut::GraphBuilder builder = ut::GraphBuilder("subgraph"); - auto data = builder.AddNode("sub_Data", "sub_Data", 0, 1); - auto netoutput = builder.AddNode("sub_Netoutput", "sub_NetOutput", 1, 0); - builder.AddDataEdge(data, 0, netoutput, 0); - auto graph = builder.GetGraph(); - return graph; -} - -static void CheckNodeResult(const ComputeGraphPtr &compute_graph, - std::vector &expect_result) { - EXPECT_EQ(compute_graph->GetDirectNodesSize(), expect_result.size()); - size_t i = 0UL; - for (const auto &node : compute_graph->GetDirectNode()) { - std::cout << "node name: " << node->GetName() << ", expect name: " << expect_result[i].name << std::endl; - EXPECT_EQ(node->GetName(), expect_result[i].name); - EXPECT_EQ(node->GetType(), expect_result[i].type); - for (uint32_t in_index = 0UL; in_index < node->GetAllInDataAnchorsSize(); in_index++) { - const auto in_anchor = node->GetInDataAnchor(in_index); - ASSERT_NE(in_anchor, nullptr); - const auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); - const auto iter = expect_result[i].input_node_name.find(in_index); - ASSERT_EQ(peer_out_anchor == nullptr, iter == expect_result[i].input_node_name.end()); - if (iter != expect_result[i].input_node_name.end()) { - EXPECT_EQ(iter->second.first, peer_out_anchor->GetOwnerNode()->GetName()); - EXPECT_EQ(iter->second.second, peer_out_anchor->GetIdx()); - } - } - for (uint32_t out_index = 0UL; out_index < node->GetAllOutDataAnchorsSize(); out_index++) { - const auto out_anchor = node->GetOutDataAnchor(out_index); - ASSERT_NE(out_anchor, nullptr); - const auto peer_in_anchors = out_anchor->GetPeerInDataAnchors(); - const auto iter = expect_result[i].output_node_name.find(out_index); - ASSERT_EQ(peer_in_anchors.size(), iter->second.size()); - for (size_t peer_in_index = 0UL; peer_in_index < peer_in_anchors.size(); peer_in_index++) { - EXPECT_EQ(iter->second[peer_in_index].first, peer_in_anchors.at(peer_in_index)->GetOwnerNode()->GetName()); - EXPECT_EQ(iter->second[peer_in_index].second, peer_in_anchors.at(peer_in_index)->GetIdx()); - } - } - const auto in_control_anchor = node->GetInControlAnchor(); - ASSERT_NE(in_control_anchor, nullptr); - const auto peer_out_control_anchors = in_control_anchor->GetPeerOutControlAnchors(); - ASSERT_EQ(peer_out_control_anchors.size(), expect_result[i].control_input_node_name.size()); - for (size_t control_out_index = 0UL; control_out_index < peer_out_control_anchors.size(); control_out_index++) { - EXPECT_EQ(expect_result[i].control_input_node_name.at(control_out_index), - peer_out_control_anchors.at(control_out_index)->GetOwnerNode()->GetName()); - } - const auto out_control_anchor = node->GetOutControlAnchor(); - ASSERT_NE(out_control_anchor, nullptr); - const auto peer_in_control_anchors = out_control_anchor->GetPeerInControlAnchors(); - ASSERT_EQ(peer_in_control_anchors.size(), expect_result[i].control_output_node_name.size()); - for (size_t control_in_index = 0UL; control_in_index < peer_in_control_anchors.size(); control_in_index++) { - EXPECT_EQ(expect_result[i].control_output_node_name[control_in_index], - peer_in_control_anchors.at(control_in_index)->GetOwnerNode()->GetName()); - } - const auto op_desc = node->GetOpDesc(); - ASSERT_NE(op_desc, nullptr); - EXPECT_EQ(op_desc->GetAllInputsSize(), expect_result[i].input_desc_size); - EXPECT_EQ(op_desc->GetOutputsSize(), expect_result[i].output_desc_size); - i++; - } -} - -// construct graph which contains subgraph -static ComputeGraphPtr BuildComputeGraph() { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data = builder.AddNode("Data", "Data", 0, 1); - auto transdata = builder.AddNode("Transdata", "Transdata", 1, 1); - transdata->GetOpDesc()->AddSubgraphName("subgraph"); - transdata->GetOpDesc()->SetSubgraphInstanceName(0, "subgraph"); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - builder.AddDataEdge(data, 0, transdata, 0); - builder.AddDataEdge(transdata, 0, netoutput, 0); - auto graph = builder.GetGraph(); - // add subgraph - transdata->SetOwnerComputeGraph(graph); - ComputeGraphPtr subgraph = BuildSubComputeGraph(); - subgraph->SetParentGraph(graph); - subgraph->SetParentNode(transdata); - graph->AddSubgraph("subgraph", subgraph); - return graph; -} - -TEST_F(UtestGraph, copy_graph_01) { - ge::OpDescPtr add_op(new ge::OpDesc("add1", "Add")); - add_op->AddDynamicInputDesc("input", 2); - add_op->AddDynamicOutputDesc("output", 1); - std::shared_ptr compute_graph(new ge::ComputeGraph("test_graph")); - auto add_node = compute_graph->AddNode(add_op); - auto graph = ge::GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph); - ge::Graph copy_graph("copy_graph"); - ASSERT_EQ(copy_graph.CopyFrom(graph), ge::GRAPH_SUCCESS); - Graph graph2("graph2"); - ASSERT_EQ(copy_graph.CopyFrom(graph2), GRAPH_FAILED); - - auto cp_compute_graph = ge::GraphUtilsEx::GetComputeGraph(copy_graph); - ASSERT_NE(cp_compute_graph, nullptr); - ASSERT_NE(cp_compute_graph, compute_graph); - ASSERT_EQ(cp_compute_graph->GetDirectNodesSize(), 1); - auto cp_add_node = cp_compute_graph->FindNode("add1"); - ASSERT_NE(cp_add_node, nullptr); - ASSERT_NE(cp_add_node, add_node); -} - -TEST_F(UtestGraph, copy_graph_02) { - ge::OpDescPtr if_op(new ge::OpDesc("if", "If")); - if_op->AddDynamicInputDesc("input", 1); - if_op->AddDynamicOutputDesc("output", 1); - std::shared_ptr compute_graph(new ge::ComputeGraph("test_graph")); - auto if_node = compute_graph->AddNode(if_op); - auto graph = ge::GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph); - ge::Graph copy_graph("copy_graph"); - - if_op->AddSubgraphName("then_branch"); - if_op->AddSubgraphName("else_branch"); - if_op->SetSubgraphInstanceName(0, "then"); - if_op->SetSubgraphInstanceName(1, "else"); - - ge::OpDescPtr add_op1(new ge::OpDesc("add1", "Add")); - add_op1->AddDynamicInputDesc("input", 2); - add_op1->AddDynamicOutputDesc("output", 1); - std::shared_ptr then_compute_graph(new ge::ComputeGraph("then")); - auto add_node1 = then_compute_graph->AddNode(add_op1); - then_compute_graph->SetParentNode(if_node); - then_compute_graph->SetParentGraph(compute_graph); - compute_graph->AddSubgraph(then_compute_graph); - - ge::OpDescPtr add_op2(new ge::OpDesc("add2", "Add")); - add_op2->AddDynamicInputDesc("input", 2); - add_op2->AddDynamicOutputDesc("output", 1); - std::shared_ptr else_compute_graph(new ge::ComputeGraph("else")); - auto add_node2 = else_compute_graph->AddNode(add_op2); - else_compute_graph->SetParentNode(if_node); - else_compute_graph->SetParentGraph(compute_graph); - compute_graph->AddSubgraph(else_compute_graph); - - ASSERT_EQ(copy_graph.CopyFrom(graph), ge::GRAPH_SUCCESS); - - auto cp_compute_graph = ge::GraphUtilsEx::GetComputeGraph(copy_graph); - ASSERT_NE(cp_compute_graph, nullptr); - ASSERT_NE(cp_compute_graph, compute_graph); - ASSERT_EQ(cp_compute_graph->GetDirectNodesSize(), 1); - auto cp_if_node = cp_compute_graph->FindNode("if"); - ASSERT_NE(cp_if_node, nullptr); - ASSERT_NE(cp_if_node, if_node); - - auto cp_then_compute_graph = cp_compute_graph->GetSubgraph("then"); - ASSERT_NE(cp_then_compute_graph, nullptr); - ASSERT_NE(cp_then_compute_graph, then_compute_graph); - ASSERT_EQ(cp_then_compute_graph->GetDirectNodesSize(), 1); - auto cp_add_node1 = cp_then_compute_graph->FindNode("add1"); - ASSERT_NE(cp_add_node1, nullptr); - ASSERT_NE(cp_add_node1, add_node1); - - auto cp_else_compute_graph = cp_compute_graph->GetSubgraph("else"); - ASSERT_NE(cp_else_compute_graph, nullptr); - ASSERT_NE(cp_else_compute_graph, else_compute_graph); - ASSERT_EQ(cp_else_compute_graph->GetDirectNodesSize(), 1); - auto cp_add_node2 = cp_else_compute_graph->FindNode("add2"); - ASSERT_NE(cp_add_node2, nullptr); - ASSERT_NE(cp_add_node2, add_node2); -} - -REG_OP(Mul) - .OP_END_FACTORY_REG(Mul) -IMPL_INFER_VALUE_RANGE_FUNC(Mul, func){ - std::cout << "test" << std::endl; - return GRAPH_SUCCESS; -} - -REG_OP(Test2) - .OP_END_FACTORY_REG(Test2) -IMPL_INFER_VALUE_RANGE_FUNC(Test2, func2){ - std::cout << "test" << std::endl; - return GRAPH_SUCCESS; -} - -TEST_F(UtestGraph, test_infer_value_range_register_succ) { - string op_type = "Add"; - INFER_VALUE_RANGE_DEFAULT_REG(Add); - INFER_VALUE_RANGE_DEFAULT_REG(Test1); - auto para = OperatorFactoryImpl::GetInferValueRangePara(op_type); - ASSERT_EQ(para.is_initialized, true); - ASSERT_EQ(para.infer_value_func, nullptr); - - op_type = "Mul"; - INFER_VALUE_RANGE_CUSTOM_FUNC_REG(Mul, INPUT_HAS_VALUE_RANGE, func); - INFER_VALUE_RANGE_CUSTOM_FUNC_REG(Test2, INPUT_IS_DYNAMIC, func2); - para = OperatorFactoryImpl::GetInferValueRangePara(op_type); - ASSERT_EQ(para.is_initialized, true); - ASSERT_NE(para.infer_value_func, nullptr); - - op_type = "Sub"; - para = OperatorFactoryImpl::GetInferValueRangePara(op_type); - ASSERT_EQ(para.is_initialized, false); -} - -TEST_F(UtestGraph, IsRefFromRefData_HasNoAttr_ReturnFalse) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto ref_data = builder.AddNode("ref_data", "RefData", 0, 1); - auto transdata = builder.AddNode("Transdata", "Transdata", 1, 1); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - builder.AddDataEdge(ref_data, 0, transdata, 0); - builder.AddDataEdge(transdata, 0, netoutput, 0); - auto graph = builder.GetGraph(); - auto out_data_anchor = transdata->GetOutDataAnchor(0); - ASSERT_NE(out_data_anchor, nullptr); - - NodePtr node = nullptr; - bool is_ref_from_other = true; - EXPECT_EQ(GraphUtils::CheckIsRefFromOther(out_data_anchor, node, is_ref_from_other), GRAPH_SUCCESS); - EXPECT_FALSE(is_ref_from_other); -} - -TEST_F(UtestGraph, IsRefFromRefData_VarNameNotExist_ReturnFalse) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto ref_data = builder.AddNode("ref_data", "RefData", 0, 1); - auto transdata = builder.AddNode("Transdata", "Transdata", 1, 1); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - builder.AddDataEdge(ref_data, 0, transdata, 0); - builder.AddDataEdge(transdata, 0, netoutput, 0); - auto graph = builder.GetGraph(); - ge::AttrUtils::SetStr(transdata->GetOpDesc()->MutableOutputDesc(0), "ref_var_src_var_name", "not_exist"); - auto out_data_anchor = transdata->GetOutDataAnchor(0); - ASSERT_NE(out_data_anchor, nullptr); - - NodePtr node = nullptr; - bool is_ref_from_other = true; - EXPECT_EQ(GraphUtils::CheckIsRefFromOther(out_data_anchor, node, is_ref_from_other), GRAPH_SUCCESS); - EXPECT_FALSE(is_ref_from_other); -} - -TEST_F(UtestGraph, IsRefFromRefData_VarNameNodeIsNotRefData_ReturnFalse) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto ref_data = builder.AddNode("ref_data", "RefData", 0, 1); - auto transdata = builder.AddNode("Transdata", "Transdata", 1, 1); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - builder.AddDataEdge(ref_data, 0, transdata, 0); - builder.AddDataEdge(transdata, 0, netoutput, 0); - auto graph = builder.GetGraph(); - ge::AttrUtils::SetStr(transdata->GetOpDesc()->MutableOutputDesc(0), "ref_var_src_var_name", "NetOutput"); - auto out_data_anchor = transdata->GetOutDataAnchor(0); - ASSERT_NE(out_data_anchor, nullptr); - - NodePtr node = nullptr; - bool is_ref_from_other = true; - EXPECT_EQ(GraphUtils::CheckIsRefFromOther(out_data_anchor, node, is_ref_from_other), GRAPH_SUCCESS); - EXPECT_FALSE(is_ref_from_other); -} - -TEST_F(UtestGraph, IsRefFromRefData_ReturnTrue) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto ref_data = builder.AddNode("ref_data", "RefData", 0, 1); - auto transdata = builder.AddNode("Transdata", "Transdata", 1, 1); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - builder.AddDataEdge(ref_data, 0, transdata, 0); - builder.AddDataEdge(transdata, 0, netoutput, 0); - auto graph = builder.GetGraph(); - ge::AttrUtils::SetStr(transdata->GetOpDesc()->MutableOutputDesc(0), "ref_var_src_var_name", "ref_data"); - auto out_data_anchor = transdata->GetOutDataAnchor(0); - ASSERT_NE(out_data_anchor, nullptr); - - NodePtr node = nullptr; - bool is_ref_from_other = false; - EXPECT_EQ(GraphUtils::CheckIsRefFromOther(out_data_anchor, node, is_ref_from_other), GRAPH_SUCCESS); - EXPECT_TRUE(is_ref_from_other); -} - -TEST_F(UtestGraph, RefDataInSubgraph_IsRefFromInnerData_ReturnTrue) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto ref_data = builder.AddNode("ref_data", "RefData", 0, 1); - auto partitioned_call = builder.AddNode("partitionedcall", "PartitionedCall", 1, 1); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - builder.AddDataEdge(ref_data, 0, partitioned_call, 0); - builder.AddDataEdge(partitioned_call, 0, netoutput, 0); - auto graph = builder.GetGraph(); - - ut::GraphBuilder sub_builder = ut::GraphBuilder("subgraph"); - auto sub_data = sub_builder.AddNode("sub_Data", "Data", 0, 1); - AttrUtils::SetInt(sub_data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); - auto sub_refdata = sub_builder.AddNode("sub_RefData", "RefData", 0, 1); - auto sub_netoutput = sub_builder.AddNode("sub_Netoutput", "NetOutput", 1, 0); - builder.AddControlEdge(sub_data, sub_refdata); - builder.AddDataEdge(sub_refdata, 0, sub_netoutput, 0); - auto sub_graph = sub_builder.GetGraph(); - - sub_graph->SetParentGraph(graph); - sub_graph->SetParentNode(partitioned_call); - graph->AddSubgraph("subgraph", sub_graph); - - auto out_data_anchor = sub_refdata->GetOutDataAnchor(0); - ASSERT_NE(out_data_anchor, nullptr); - - NodePtr node = nullptr; - bool is_ref_from_other = false; - EXPECT_EQ(GraphUtils::CheckIsRefFromOther(out_data_anchor, node, is_ref_from_other), GRAPH_SUCCESS); - EXPECT_TRUE(is_ref_from_other); -} - -TEST_F(UtestGraph, RefDataInSubgraph_IsRefFromInnerData_PeerInCtrolNotData_InvalidGraph_ReturnFalse) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto ref_data = builder.AddNode("ref_data", "RefData", 0, 1); - auto partitioned_call = builder.AddNode("partitionedcall", "PartitionedCall", 1, 1); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - builder.AddDataEdge(ref_data, 0, partitioned_call, 0); - builder.AddDataEdge(partitioned_call, 0, netoutput, 0); - auto graph = builder.GetGraph(); - - ut::GraphBuilder sub_builder = ut::GraphBuilder("subgraph"); - auto sub_cast = sub_builder.AddNode("sub_Data", "Cast", 0, 1); - auto sub_refdata = sub_builder.AddNode("sub_RefData", "RefData", 0, 1); - auto sub_netoutput = sub_builder.AddNode("sub_Netoutput", "NetOutput", 1, 0); - builder.AddControlEdge(sub_cast, sub_refdata); - builder.AddDataEdge(sub_refdata, 0, sub_netoutput, 0); - auto sub_graph = sub_builder.GetGraph(); - - sub_graph->SetParentGraph(graph); - sub_graph->SetParentNode(partitioned_call); - graph->AddSubgraph("subgraph", sub_graph); - - auto out_data_anchor = sub_refdata->GetOutDataAnchor(0); - ASSERT_NE(out_data_anchor, nullptr); - - NodePtr node = nullptr; - bool is_ref_from_other = false; - EXPECT_NE(GraphUtils::CheckIsRefFromOther(out_data_anchor, node, is_ref_from_other), GRAPH_SUCCESS); -} - -TEST_F(UtestGraph, RefDataInSubgraph_IsRefFromInnerData_MultiPeerInCtrl_InvalidGraph_ReturnFalse) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto ref_data = builder.AddNode("ref_data", "RefData", 0, 1); - auto partitioned_call = builder.AddNode("partitionedcall", "PartitionedCall", 1, 1); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - builder.AddDataEdge(ref_data, 0, partitioned_call, 0); - builder.AddDataEdge(partitioned_call, 0, netoutput, 0); - auto graph = builder.GetGraph(); - - ut::GraphBuilder sub_builder = ut::GraphBuilder("subgraph"); - auto sub_data = sub_builder.AddNode("sub_Data", "Data", 0, 1); - AttrUtils::SetInt(sub_data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); - auto sub_cast = sub_builder.AddNode("sub_cast", "Cast", 0, 1); - auto sub_refdata = sub_builder.AddNode("sub_RefData", "RefData", 0, 1); - auto sub_netoutput = sub_builder.AddNode("sub_Netoutput", "NetOutput", 1, 0); - builder.AddControlEdge(sub_cast, sub_refdata); - builder.AddDataEdge(sub_refdata, 0, sub_netoutput, 0); - auto sub_graph = sub_builder.GetGraph(); - - sub_graph->SetParentGraph(graph); - sub_graph->SetParentNode(partitioned_call); - graph->AddSubgraph("subgraph", sub_graph); - - auto out_data_anchor = sub_refdata->GetOutDataAnchor(0); - ASSERT_NE(out_data_anchor, nullptr); - - NodePtr node = nullptr; - bool is_ref_from_other = false; - EXPECT_NE(GraphUtils::CheckIsRefFromOther(out_data_anchor, node, is_ref_from_other), GRAPH_SUCCESS); -} - -REG_OP(Shape) - .OP_END_FACTORY_REG(Shape) -IMPL_INFER_VALUE_RANGE_FUNC(Shape, ShapeValueInfer){ - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - auto output_tensor_desc = op_desc->MutableOutputDesc(0); - std::vector> in_shape_range; - op_desc->MutableInputDesc(0)->GetShapeRange(in_shape_range); - if (!in_shape_range.empty()) { - output_tensor_desc->SetValueRange(in_shape_range); - } - return GRAPH_SUCCESS; -} - -TEST_F(UtestGraph, test_value_range_infer_and_set_get) { - using std::make_pair; - string op_type = "Shape"; - INFER_VALUE_RANGE_CUSTOM_FUNC_REG(Shape, INPUT_IS_DYNAMIC, ShapeValueInfer); - auto graph = std::make_shared("test_graph"); - auto shape_op_desc = std::make_shared("node_name", op_type); - GeTensorDesc tensor_desc(GeShape({-1, -1, 4, 192}), ge::FORMAT_NCHW, DT_INT32); - std::vector> shape_range = {make_pair(1, 100), make_pair(1, 240), - make_pair(4, 4), make_pair(192, 192)}; - tensor_desc.SetShapeRange(shape_range); - shape_op_desc->AddInputDesc(tensor_desc); - GeTensorDesc out_tensor_desc(GeShape({4}), ge::FORMAT_NCHW, DT_INT32); - shape_op_desc->AddOutputDesc(out_tensor_desc); - auto shape_node = graph->AddNode(shape_op_desc); - Operator op = OpDescUtils::CreateOperatorFromNode(shape_node); - auto ret = OpDescUtilsEx::CallInferValueRangeFunc(shape_node->GetOpDesc(), op); - ASSERT_EQ(ret, GRAPH_SUCCESS); - - auto output_0_desc = shape_node->GetOpDesc()->GetOutputDesc(0); - std::vector> value_range; - output_0_desc.GetValueRange(value_range); - EXPECT_EQ(value_range.size(), 4); - - std::vector target_value_range = {1, 100, 1, 240, 4, 4, 192, 192}; - std::vector output_value_range; - for (auto pair : value_range) { - output_value_range.push_back(pair.first); - output_value_range.push_back(pair.second); - } - EXPECT_EQ(target_value_range, output_value_range); -} - -TEST_F(UtestGraph, get_all_graph_nodes) { - ComputeGraphPtr graph = BuildComputeGraph(); - auto nodes = graph->GetAllNodes(); - EXPECT_EQ(nodes.size(), 5); - - Graph graph2("Test"); - auto nodes_empty = graph2.GetAllNodes(); - EXPECT_EQ(nodes_empty.size(), 0); -} - -TEST_F(UtestGraph, SetOutputs_ops) { - Operator op1 = Operator("add"); - Operator op2 = Operator("op2"); - Operator op3 = Operator("op3"); - std::vector outputs = {op1, op2, op3}; - - Graph graph; - graph.SetOutputs(outputs); - EXPECT_EQ(graph.GetAllNodes().size(), 0); - // EXPECT_TRUE(graph.impl_->output_name_.empty()); // impl缺少头文件,找不到声明 -} - -TEST_F(UtestGraph, SetOutputs_string) { - using std::make_pair; - ge::OpDescPtr add_op(new ge::OpDesc("add_0", "add")); - add_op->AddDynamicInputDesc("input", 2); - add_op->AddDynamicOutputDesc("output", 1); - std::shared_ptr compute_graph(new ge::ComputeGraph("test_graph")); - auto add_node = compute_graph->AddNode(add_op); - Graph graph = ge::GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph); - - Operator op1 = Operator("add"); - Operator op2 = Operator("op2"); - Operator op3 = Operator("op3"); - std::string op_n1 = std::string("add"); - std::string op_n2 = std::string("op2"); - std::string op_n3 = std::string("op3"); - - std::vector> outputs = {make_pair(op1, op_n1), make_pair(op2, op_n2), - make_pair(op3, op_n3)}; - graph.SetOutputs(outputs); - EXPECT_EQ(graph.GetAllNodes().size(), 1); -} - -TEST_F(UtestGraph, SetOutputs_AscendString) { - using std::make_pair; - ge::OpDescPtr add_op(new ge::OpDesc("add_0", "add")); - add_op->AddDynamicInputDesc("input", 2); - add_op->AddDynamicOutputDesc("output", 1); - std::shared_ptr compute_graph(new ge::ComputeGraph("test_graph")); - auto add_node = compute_graph->AddNode(add_op); - Graph graph = ge::GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph); - - Operator op1 = Operator("add"); - Operator op2 = Operator("op2"); - Operator op3 = Operator("op3"); - AscendString op_n1 = AscendString("add"); - AscendString op_n2 = AscendString("op2"); - AscendString op_n3 = AscendString("op3"); - - std::vector> outputs = {make_pair(op1, op_n1), make_pair(op2, op_n2), - make_pair(op3, op_n3)}; - graph.SetOutputs(outputs); - EXPECT_EQ(graph.GetAllNodes().size(), 1); -} - -TEST_F(UtestGraph, SetOutputs_Index) { - using std::make_pair; - ge::OpDescPtr add_op(new ge::OpDesc("add_0", "add")); - add_op->AddDynamicInputDesc("input", 2); - add_op->AddDynamicOutputDesc("output", 1); - std::shared_ptr compute_graph(new ge::ComputeGraph("test_graph")); - auto add_node = compute_graph->AddNode(add_op); - Graph graph = ge::GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph); - Graph graph2; - - Operator op1 = Operator("add"); - Operator op2 = Operator("op2"); - Operator op3 = Operator("op3"); - std::vector vec_index1 = {0,1,2}; - std::vector vec_index2 = {0}; - std::vector vec_index3 = {0}; - - std::vector>> outputs = {make_pair(op1, vec_index1), - make_pair(op2, vec_index2), make_pair(op3, vec_index3)}; - graph2.SetOutputs(outputs); - graph.SetOutputs(outputs); - EXPECT_EQ(graph.GetAllNodes().size(), 1); -} - -TEST_F(UtestGraph, SetTargets) { - ge::OpDescPtr add_op(new ge::OpDesc("add_0", "add")); - add_op->AddDynamicInputDesc("input", 2); - add_op->AddDynamicOutputDesc("output", 1); - std::shared_ptr compute_graph(new ge::ComputeGraph("test_graph")); - auto add_node = compute_graph->AddNode(add_op); - Graph graph = ge::GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph); - Graph graph2; - - Operator op1 = Operator("add"); - Operator op2 = Operator("op2"); - Operator op3 = Operator("op3"); - std::vector vec_index1 = {0,1,2}; - std::vector vec_index2 = {0}; - std::vector vec_index3 = {0}; - - std::vector targets = {op1, op2, op3}; - - graph2.SetTargets(targets); - graph.SetTargets(targets); - EXPECT_EQ(graph.GetAllNodes().size(), 1); -} - -TEST_F(UtestGraph, SetNeedIteration) { - ge::OpDescPtr add_op(new ge::OpDesc("add_0", "add")); - std::shared_ptr compute_graph(new ge::ComputeGraph("test_graph")); - auto add_node = compute_graph->AddNode(add_op); - Graph graph = ge::GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph); - Graph graph2; - - graph2.SetNeedIteration(true); - graph.SetNeedIteration(false); - EXPECT_EQ(graph.GetAllNodes().size(), 1); -} - -TEST_F(UtestGraph, GetDirectNode) { - ge::OpDescPtr add_op(new ge::OpDesc("add_0", "add")); - std::shared_ptr compute_graph(new ge::ComputeGraph("test_graph")); - auto add_node = compute_graph->AddNode(add_op); - Graph graph = ge::GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph); - - ge::OpDescPtr add_op2(new ge::OpDesc("add_1", "add")); - std::shared_ptr compute_graph2 = nullptr; - Graph graph2 = ge::GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph2); - - Graph graph3; - - std::vector gnodes, gnodes2, gnodes3; - - gnodes = graph.GetDirectNode(); - gnodes2 = graph2.GetDirectNode(); - gnodes3 = graph3.GetDirectNode(); - EXPECT_EQ(gnodes.size(), 1); -} - -TEST_F(UtestGraph, RemoveNode) { - ComputeGraphPtr cgp = BuildComputeGraph(); - auto v_nodes = cgp->GetAllNodes(); - EXPECT_EQ(v_nodes.size(), 5); - - Graph graph = ge::GraphUtilsEx::CreateGraphFromComputeGraph(cgp); - - auto nodes = graph.GetAllNodes(); - graph.RemoveNode(nodes[4]); - EXPECT_EQ(graph.GetAllNodes().size(), 4); - - graph.RemoveNode(nodes[0], true); - EXPECT_EQ(graph.GetAllNodes().size(), 3); -} - -TEST_F(UtestGraph, AddRemoveEdge1) { - Operator op1 = Operator("add"); - Operator op2 = Operator("op2"); - Operator op3 = Operator("op3"); - - Graph graph("a_graph"); - Graph graph2; - - GNode node1 = graph.AddNodeByOp(op1); - GNode node2 = graph.AddNodeByOp(op2); - GNode node3 = graph.AddNodeByOp(op3); - - auto ret =graph.AddDataEdge(node1, 0, node2, 0); - EXPECT_EQ(ret, GRAPH_FAILED); - ret = graph.AddControlEdge(node2, node3); - EXPECT_EQ(ret, GRAPH_SUCCESS); - ret = graph.RemoveEdge(node1, 0, node2, 0); - EXPECT_EQ(ret, GRAPH_FAILED); - - graph2.AddNodeByOp(op1); - ret =graph2.AddDataEdge(node1, 0, node2, 0); - EXPECT_EQ(ret, GRAPH_FAILED); - ret = graph2.AddControlEdge(node2, node3); - EXPECT_EQ(ret, GRAPH_FAILED); - ret = graph2.RemoveEdge(node1, 0, node2, 0); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraph, AddRemoveEdge2) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data = builder.AddNode("Data", "Data", 0, 1); - ComputeGraphPtr cgp = builder.GetGraph(); - Graph graph = ge::GraphUtilsEx::CreateGraphFromComputeGraph(cgp); - - auto nodes = graph.GetAllNodes(); - EXPECT_EQ(nodes.size(), 1); - - GNode node1 = nodes[0]; - GNode node2; - - auto ret =graph.AddDataEdge(node1, 0, node2, 0); - EXPECT_EQ(ret, GRAPH_FAILED); - ret = graph.RemoveEdge(node1, 0, node2, 0); - EXPECT_EQ(ret, GRAPH_FAILED); - ret = graph.AddControlEdge(node1, node2); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraph, AddRemoveEdge3) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data = builder.AddNode("Data", "Data", 0, 1); - auto transdata = builder.AddNode("Transdata", "Transdata", 1, 1); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - ComputeGraphPtr cgp = builder.GetGraph(); - Graph graph = ge::GraphUtilsEx::CreateGraphFromComputeGraph(cgp); - - auto nodes = graph.GetAllNodes(); - EXPECT_EQ(nodes.size(), 3); - - GNode node1 = nodes[0]; - GNode node2 = nodes[1]; - GNode node3 = nodes[2]; - - auto ret = graph.AddDataEdge(node1, 0, node2, 0); - EXPECT_EQ(ret, GRAPH_SUCCESS); - ret = graph.AddControlEdge(node2, node3); - EXPECT_EQ(ret, GRAPH_SUCCESS); - ret = graph.RemoveEdge(node1, 0, node2, 0); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraph, ConstructFromInputs1) { - Graph graph; - Operator op1 = Operator("op1"); - Operator op2 = Operator("op2"); - Operator op3 = Operator("op3"); - std::vector inputs = {op1, op2, op3}; - AscendString name = "graph_name"; - - auto ret = graph.ConstructFromInputs({}, name); - EXPECT_EQ(ret, nullptr); - - ret = graph.ConstructFromInputs(inputs, AscendString(nullptr)); - EXPECT_EQ(ret, nullptr); - - ret = graph.ConstructFromInputs(inputs, name); - EXPECT_EQ(ret, nullptr); -} - -REG_OP(Phony0) - .OUTPUT(y, - TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, - DT_UINT64, DT_BOOL, DT_DOUBLE})) - .ATTR(value, Tensor, Tensor()) - .OP_END_FACTORY_REG(Phony0); - -REG_OP(Phony1) - .DYNAMIC_INPUT(x, TensorType::NumberType()) - .OUTPUT(y, TensorType::NumberType()) - .REQUIRED_ATTR(N, Int) - .OP_END_FACTORY_REG(Phony1); - -REG_OP(Phony2) - .INPUT(x, - TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, - DT_UINT64, DT_BOOL, DT_DOUBLE})) - .INPUT(shape, TensorType({DT_INT32, DT_INT64})) - .OUTPUT(y, - TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, - DT_UINT64, DT_BOOL, DT_DOUBLE})) - .ATTR(axis, Int, 0) - .ATTR(num_axes, Int, -1) - .OP_END_FACTORY_REG(Phony2); - -TEST_F(UtestGraph, ConstructFromInputs2) { - Graph graph; - Operator op1 = op::Phony0("op1"); - Operator op2 = op::Phony1("op2"); - Operator op3 = op::Phony2("op3"); - std::vector inputs = {op1, op2, op3}; - AscendString name = "graph_name"; - - auto ret = graph.ConstructFromInputs(inputs, name); - EXPECT_NE(ret, nullptr); -} - -TEST_F(UtestGraph, SaveLoadFile) { - system("rm -rf ./ut_graph1.txt"); - system("rm -rf ./ut_graph2.txt"); - - ComputeGraphPtr cgp = BuildComputeGraph(); - Graph graph = ge::GraphUtilsEx::CreateGraphFromComputeGraph(cgp); - - auto ret = graph.SaveToFile(nullptr); - EXPECT_EQ(ret, GRAPH_FAILED); - - ret = graph.SaveToFile("./ut_graph1.txt"); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - ret = graph.SaveToFile(std::string("./ut_graph2.txt")); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - Graph graph2; - ret = graph2.LoadFromFile(nullptr); - EXPECT_EQ(ret, GRAPH_FAILED); - - Graph graph3; - ret = graph3.LoadFromFile("./ut_graph1.txt"); - EXPECT_NE(ret, GRAPH_FAILED); - - Graph graph4; - ret = graph4.LoadFromFile(std::string("./ut_graph2.txt")); - EXPECT_NE(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraph, LoadFromSerializedModelArray_InvalidParams) { - ge::proto::ModelDef model_def; - auto *graph_def = model_def.add_graph(); - graph_def->set_name("serialized_model_array_graph"); - - Graph graph; - EXPECT_NE(graph.LoadFromSerializedModelArray(nullptr, 0), GRAPH_SUCCESS); - - std::string serialized; - EXPECT_NE(graph.LoadFromSerializedModelArray(reinterpret_cast(serialized.c_str()), 0), GRAPH_SUCCESS); - - serialized = "abc"; - EXPECT_NE(graph.LoadFromSerializedModelArray(reinterpret_cast(serialized.c_str()), serialized.size()), GRAPH_SUCCESS); -} - - -std::vector CreateOpDef(ge::proto::GraphDef *def, const std::string &type, const std::vector &inputs, - size_t num_outputs, std::vector subgraphs = {}) { - auto name = type + std::to_string(def->op_size()); - auto *op_def = def->add_op(); - op_def->set_name(name); - op_def->set_type(type); - - - auto op_desc_attr = op_def->mutable_attr(); - proto::AttrDef input_desc_name; - proto::AttrDef input_desc_index; - proto::AttrDef output_desc_name; - proto::AttrDef output_desc_index; - - for (size_t i = 0U; i < inputs.size(); ++i) { - op_def->add_input_desc(); - *op_def->add_input() = inputs[i]; - - input_desc_name.mutable_list()->add_s(std::string("x") + std::to_string(i)); - input_desc_index.mutable_list()->add_i(i); - } - std::vector outputs; - for (size_t i = 0U; i < num_outputs; ++i) { - op_def->add_output_desc(); - outputs.push_back(op_def->name() + ":" + std::to_string(i)); - - output_desc_name.mutable_list()->add_s(std::string("y") + std::to_string(i)); - output_desc_index.mutable_list()->add_i(i); - } - - (void) op_desc_attr->insert({"_input_name_key", input_desc_name}); - (void) op_desc_attr->insert({"_input_name_value", input_desc_index}); - - (void) op_desc_attr->insert({"_output_name_key", output_desc_name}); - (void) op_desc_attr->insert({"_output_name_value", output_desc_index}); - - for (auto &subgraph : subgraphs) { - op_def->add_subgraph_name(subgraph); - } - - if (num_outputs == 0) { - outputs.push_back(op_def->name()); - } - - return outputs; -} - - -std::string GetStringBeforeColon(const std::string& str) { - size_t pos = str.find(':'); - if (pos != std::string::npos) { - return str.substr(0, pos); - } else { - return str; - } -} - - -void AssertOpMatch(ge::ComputeGraphPtr &compute_graph, const std::vector &op, - const std::vector &inputs, size_t num_outputs) { - auto op_name = GetStringBeforeColon(op[0]); - auto data = compute_graph->FindNode(op_name); - ASSERT_NE(data, nullptr); - ASSERT_EQ(data->GetInDataNodesAndAnchors().size(), inputs.size()); - size_t index = 0U; - for (auto &node_and_anchor : data->GetInDataNodesAndAnchors()) { - auto input = node_and_anchor.first->GetName() + ":" + std::to_string(node_and_anchor.second->GetIdx()); - ASSERT_EQ(input, inputs[index]); - index++; - } - auto in_name_idx = data->GetOpDesc()->GetAllInputName(); - ASSERT_EQ(in_name_idx.size(), inputs.size()); - index = 0U; - for (auto &name_idx : in_name_idx) { - ASSERT_EQ(name_idx.first, "x" + std::to_string(index)); - ASSERT_EQ(name_idx.second, index); - index++; - } - auto out_name_idx = data->GetOpDesc()->GetAllOutputName(); - ASSERT_EQ(out_name_idx.size(), num_outputs); - index = 0U; - for (auto &name_idx : out_name_idx) { - ASSERT_EQ(name_idx.first, "y" + std::to_string(index)); - ASSERT_EQ(name_idx.second, index); - index++; - } -} - - -TEST_F(UtestGraph, LoadFromSerializedModelArray_NoSubGraph) { - ge::proto::ModelDef model_def; - auto *graph_def = model_def.add_graph(); - graph_def->set_name("root_graph"); - - auto data = CreateOpDef(graph_def, "Data", {}, 1); - auto abs = CreateOpDef(graph_def, "Abs", data, 1); - auto sqrt = CreateOpDef(graph_def, "Add", {data[0], abs[0]}, 1); - auto netoutput = CreateOpDef(graph_def, "NetOutput", {abs[0], sqrt[0]}, 0); - - Graph graph; - auto serialized = model_def.SerializeAsString(); - ASSERT_EQ(graph.LoadFromSerializedModelArray(serialized.c_str(), serialized.size()), GRAPH_SUCCESS); - - auto compute_graph = ge::GraphUtilsEx::GetComputeGraph(graph); - ASSERT_EQ(compute_graph->GetName(), graph_def->name()); - - AssertOpMatch(compute_graph, data, {}, 1); - AssertOpMatch(compute_graph, abs, data, 1); - AssertOpMatch(compute_graph, sqrt, {data[0], abs[0]}, 1); - AssertOpMatch(compute_graph, netoutput, {abs[0], sqrt[0]}, 0); -} - -TEST_F(UtestGraph, LoadFromSerializedModelArray_WithSubGraph) { - ge::proto::ModelDef model_def; - auto *graph_def = model_def.add_graph(); - graph_def->set_name("root_graph"); - auto func = CreateOpDef(graph_def, "FuncOp", {}, 0, {"sub_graph"}); - - auto *sub_graph = model_def.add_graph(); - sub_graph->set_name("sub_graph"); - auto data = CreateOpDef(sub_graph, "Data", {}, 1); - auto abs = CreateOpDef(sub_graph, "Abs", data, 1); - auto sqrt = CreateOpDef(sub_graph, "Add", {data[0], abs[0]}, 1); - auto netoutput = CreateOpDef(sub_graph, "NetOutput", {abs[0], sqrt[0]}, 0); - - Graph graph; - auto serialized = model_def.SerializeAsString(); - ASSERT_EQ(graph.LoadFromSerializedModelArray(serialized.c_str(), serialized.size()), GRAPH_SUCCESS); - - auto compute_graph = ge::GraphUtilsEx::GetComputeGraph(graph); - ASSERT_EQ(compute_graph->GetName(), graph_def->name()); - - ASSERT_EQ(compute_graph->GetAllSubgraphs().size(), 1U); - auto sub_compute_graph = compute_graph->GetSubgraph("sub_graph"); - ASSERT_NE(sub_compute_graph, nullptr); - ASSERT_EQ(sub_compute_graph->GetName(), "sub_graph"); - - auto func_op = compute_graph->FindNode(GetStringBeforeColon(func[0])); - ASSERT_NE(func_op, nullptr); - ASSERT_EQ(sub_compute_graph->GetParentNode(), func_op); - ASSERT_EQ(sub_compute_graph->GetParentGraph(), compute_graph); - - AssertOpMatch(sub_compute_graph, data, {}, 1); - AssertOpMatch(sub_compute_graph, abs, data, 1); - AssertOpMatch(sub_compute_graph, sqrt, {data[0], abs[0]}, 1); - AssertOpMatch(sub_compute_graph, netoutput, {abs[0], sqrt[0]}, 0); -} - -TEST_F(UtestGraph, LoadFromSerializedModelArray_WithNestedSubGraph) { - ge::proto::ModelDef model_def; - auto *graph_def = model_def.add_graph(); - graph_def->set_name("root_graph"); - auto func = CreateOpDef(graph_def, "FuncOp", {}, 0, {"sub_graph"}); - - auto *sub_graph0 = model_def.add_graph(); - sub_graph0->set_name("sub_graph"); - auto func1 = CreateOpDef(sub_graph0, "FuncOp1", {}, 0, {"sub_graph1"}); - - auto *sub_graph1 = model_def.add_graph(); - sub_graph1->set_name("sub_graph1"); - auto data = CreateOpDef(sub_graph1, "Data", {}, 1); - auto abs = CreateOpDef(sub_graph1, "Abs", data, 1); - auto sqrt = CreateOpDef(sub_graph1, "Add", {data[0], abs[0]}, 1); - auto netoutput = CreateOpDef(sub_graph1, "NetOutput", {abs[0], sqrt[0]}, 0); - - Graph graph; - auto serialized = model_def.SerializeAsString(); - ASSERT_EQ(graph.LoadFromSerializedModelArray(serialized.c_str(), serialized.size()), GRAPH_SUCCESS); - - auto compute_graph = ge::GraphUtilsEx::GetComputeGraph(graph); - ASSERT_EQ(compute_graph->GetName(), graph_def->name()); - - ASSERT_EQ(compute_graph->GetAllSubgraphs().size(), 2U); - auto sub_compute_graph = compute_graph->GetSubgraph("sub_graph"); - ASSERT_NE(sub_compute_graph, nullptr); - ASSERT_EQ(sub_compute_graph->GetName(), "sub_graph"); - - auto sub_compute_graph1 = compute_graph->GetSubgraph("sub_graph1"); - ASSERT_NE(sub_compute_graph1, nullptr); - ASSERT_EQ(sub_compute_graph1->GetName(), "sub_graph1"); - - auto func_op = compute_graph->FindNode(GetStringBeforeColon(func[0])); - ASSERT_NE(func_op, nullptr); - ASSERT_EQ(sub_compute_graph->GetParentNode(), func_op); - ASSERT_EQ(sub_compute_graph->GetParentGraph(), compute_graph); - - auto func_op1 = sub_compute_graph->FindNode(GetStringBeforeColon(func1[0])); - ASSERT_NE(func_op1, nullptr); - ASSERT_EQ(sub_compute_graph1->GetParentNode(), func_op1); - ASSERT_EQ(sub_compute_graph1->GetParentGraph(), sub_compute_graph); - - AssertOpMatch(sub_compute_graph1, data, {}, 1); - AssertOpMatch(sub_compute_graph1, abs, data, 1); - AssertOpMatch(sub_compute_graph1, sqrt, {data[0], abs[0]}, 1); - AssertOpMatch(sub_compute_graph1, netoutput, {abs[0], sqrt[0]}, 0); -} - -TEST_F(UtestGraph, SaveAndLoadMemWithBuffer) { - ComputeGraphPtr cgp = BuildComputeGraph(); - Graph graph1 = ge::GraphUtilsEx::CreateGraphFromComputeGraph(cgp); - - GraphBuffer buf1; - auto ret = graph1.SaveToMem(buf1); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - Graph graph2; - ret = graph2.LoadFromMem(buf1); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - GraphBuffer buf2; - ret = graph2.SaveToMem(buf2); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - EXPECT_EQ(buf1.GetSize(), buf2.GetSize()); - EXPECT_EQ(memcmp(buf1.GetData(), buf2.GetData(), buf1.GetSize()), 0); -} - -TEST_F(UtestGraph, SaveAndLoadMemWithData) { - ComputeGraphPtr cgp = BuildComputeGraph(); - Graph graph1 = ge::GraphUtilsEx::CreateGraphFromComputeGraph(cgp); - - GraphBuffer buf1; - auto ret = graph1.SaveToMem(buf1); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - Graph graph2; - ret = graph2.LoadFromMem(buf1.GetData(), buf1.GetSize()); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - GraphBuffer buf2; - ret = graph2.SaveToMem(buf2); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - EXPECT_EQ(buf1.GetSize(), buf2.GetSize()); - EXPECT_EQ(memcmp(buf1.GetData(), buf2.GetData(), buf1.GetSize()), 0); -} - -TEST_F(UtestGraph, LoadFromMemFailed) { - GraphBuffer buf; - Graph graph; - auto ret = graph.LoadFromMem(buf.GetData(), buf.GetSize()); - EXPECT_NE(ret, GRAPH_SUCCESS); - - ret = graph.LoadFromMem(nullptr, 0); - EXPECT_NE(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraph, ErrorCodeCheck) { - EXPECT_EQ(ge::FAILED, 4294967295); - EXPECT_EQ(ge::END_OF_SEQUENCE, 1343225863); - EXPECT_EQ(ge::GE_GRAPH_SAVE_WEIGHTS_FAILED, 1343242286); - - EXPECT_EQ(strcmp(GE_GET_ERRORNO_STR(ge::END_OF_SEQUENCE).c_str(), "End of sequence!"), 0); - EXPECT_EQ(strcmp(GE_GET_ERRORNO_STR(ge::FAILED).c_str(), "failed"), 0); - EXPECT_EQ(strcmp(GE_GET_ERRORNO_STR(ge::GE_GRAPH_SAVE_WEIGHTS_FAILED).c_str(), - "OMG Save Weights to Model failed."), 0); -} - -TEST_F(UtestGraph, GetName) { - Graph graph; - AscendString name; - auto ret = graph.GetName(name); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraph, RecoverGraphOperators) { - ComputeGraphPtr cgp = BuildComputeGraph(); - Graph graph = ge::GraphUtilsEx::CreateGraphFromComputeGraph(cgp); - auto ret = GraphUtilsEx::RecoverGraphOperators(graph); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraph, GetOpName) { - ComputeGraphPtr cgp = BuildComputeGraph(); - Graph graph = ge::GraphUtilsEx::CreateGraphFromComputeGraph(cgp); - - Operator op1("add"); - auto ret = graph.AddOp(op1); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - std::vector op_names1; - ret = graph.GetAllOpName(op_names1); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - std::vector op_names2; - ret = graph.GetAllOpName(op_names2); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraph, FindOpByName) { - Graph graph; - Operator op1 = op::Phony0("op1"); - Operator op2 = op::Phony1("op2"); - Operator op3 = op::Phony2("op3"); - std::vector inputs = {op1, op2, op3}; - AscendString name = "graph_name"; - - GraphPtr gptr = Graph::ConstructFromInputs(inputs, name); - - EXPECT_EQ(gptr->GetAllNodes().size(), 2); - - Operator op1_2; - auto ret = gptr->FindOpByName(nullptr, op1_2); - ret = gptr->FindOpByName("op1", op1_2); - EXPECT_EQ(ret, GRAPH_FAILED); - - Operator op2_2; - ret = gptr->FindOpByName(std::string("op2"), op2_2); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraph, FindOpByType) { - Graph graph; - Operator op1 = op::Phony0("op1"); - Operator op2 = op::Phony1("op2"); - Operator op3 = op::Phony2("op3"); - std::vector inputs = {op1, op2, op3}; - AscendString name = "graph_name"; - - GraphPtr gptr = Graph::ConstructFromInputs(inputs, name); - - std::vector op1_2; - auto ret = gptr->FindOpByType(nullptr, op1_2); - ret = gptr->FindOpByType("const", op1_2); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - std::vector op2_2; - ret = gptr->FindOpByType(std::string("data"), op2_2); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraph, SaveInvalidPath) { - std::vector inputs{}; - std::vector outputs{}; - Graph graph("empty_graph"); - graph.SetInputs(inputs).SetOutputs(outputs); - std::string file_name = "....1263713612~"; - EXPECT_EQ(graph.SaveToFile(file_name), GRAPH_FAILED); -} - -/* - Data Data Const Variable - | | | / | - \ | / / | - ConcatV2 -- | DATA - | \ / | - | IdentityN ------- - \ | - ---- MatmulV2 -*/ -TEST_F(UtestGraph, TestGenerateGraphWithControlEdge) { - ge::Operator data1 = ge::Operator("Data_0", "Data"); - ge::Operator data2 = ge::Operator("Data_1", "Data"); - ge::Operator const_op = ge::Operator("Constant_0", "Constant"); - ge::Operator data3 = ge::Operator("Data_2", "Data"); - ge::Operator variable = ge::Operator("Variable_0", "Variable"); - ge::Operator concat_v2 = ge::Operator("ConcatV2_0", "ConcatV2"); - ge::Operator identity_n = ge::Operator("IdentityN_0", "IdentityN"); - ge::Operator matmul_v2 = ge::Operator("MatmulV2_0", "MatmulV2"); - - data1.InputRegister("x"); - data1.OutputRegister("y"); - data2.InputRegister("x"); - data2.OutputRegister("y"); - data3.InputRegister("x"); - data3.OutputRegister("y"); - const_op.OutputRegister("y"); - variable.InputRegister("x"); - variable.OutputRegister("y"); - concat_v2.DynamicInputRegister("x", 2); - concat_v2.InputRegister("concat_dim"); - concat_v2.OutputRegister("y"); - identity_n.DynamicInputRegister("x", 3); - identity_n.DynamicOutputRegister("y", 3); - matmul_v2.InputRegister("x1"); - matmul_v2.InputRegister("x2"); - matmul_v2.OptionalInputRegister("bias"); - matmul_v2.OptionalInputRegister("offset_w"); - matmul_v2.OutputRegister("y"); - concat_v2.SetInput(0U, data1, 0U); - concat_v2.SetInput(1U, data2, 0U); - concat_v2.SetInput(2U, const_op, 0U); - identity_n.SetInput(0U, concat_v2, 0U); - identity_n.SetInput(1U, variable, 0U); - identity_n.SetInput(2U, data3, 0U); - matmul_v2.SetInput(0U, identity_n, 0U); - matmul_v2.SetInput(1U, identity_n, 1U); - matmul_v2.SetInput(2U, identity_n, 2U); - matmul_v2.AddControlInput(concat_v2); - concat_v2.AddControlInput(variable); - std::vector ops{data1, const_op, data2, variable, concat_v2, data3, identity_n, matmul_v2}; - Graph graph("stable_sort_graph"); - EXPECT_EQ(GraphUtilsEx::CreateGraphFromOperatorWithStableTopo(graph, ops), SUCCESS); - const auto compute_graph = GraphUtilsEx::GetComputeGraph(graph); - EXPECT_EQ(compute_graph->GetName(), "stable_sort_graph"); - EXPECT_EQ(compute_graph->GetDirectNodesSize(), 8); - std::vector expect_node_info; - std::map> input_node_name; - std::map>> output_node_name; - std::vector control_input_node_name; - std::vector control_output_node_name; - std::vector> temp_vector = {{"ConcatV2_0", 0}}; - output_node_name.emplace(std::make_pair(0, temp_vector)); - expect_node_info.emplace_back(ExpectNodeInfo("Data_0", "Data", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 1, 1)); - - temp_vector.clear(); - output_node_name.clear(); - temp_vector.emplace_back(std::make_pair("ConcatV2_0", 2)); - output_node_name.emplace(std::make_pair(0, temp_vector)); - expect_node_info.emplace_back(ExpectNodeInfo("Constant_0", "Constant", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 0, 1)); - - temp_vector.clear(); - output_node_name.clear(); - temp_vector.emplace_back(std::make_pair("ConcatV2_0", 1)); - output_node_name.emplace(std::make_pair(0, temp_vector)); - expect_node_info.emplace_back(ExpectNodeInfo("Data_1", "Data", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 1, 1)); - - temp_vector.clear(); - output_node_name.clear(); - temp_vector.emplace_back(std::make_pair("IdentityN_0", 1)); - output_node_name.emplace(std::make_pair(0, temp_vector)); - control_output_node_name.emplace_back("ConcatV2_0"); - expect_node_info.emplace_back(ExpectNodeInfo("Variable_0", "Variable", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 1, 1)); - - input_node_name.emplace(std::make_pair(0, std::make_pair("Data_0", 0))); - input_node_name.emplace(std::make_pair(1, std::make_pair("Data_1", 0))); - input_node_name.emplace(std::make_pair(2, std::make_pair("Constant_0", 0))); - temp_vector.clear(); - output_node_name.clear(); - temp_vector.emplace_back(std::make_pair("IdentityN_0", 0)); - output_node_name.emplace(std::make_pair(0, temp_vector)); - control_input_node_name.emplace_back("Variable_0"); - control_output_node_name.clear(); - control_output_node_name.emplace_back("MatmulV2_0"); - expect_node_info.emplace_back(ExpectNodeInfo("ConcatV2_0", "ConcatV2", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 3, 1)); - - input_node_name.clear(); - temp_vector.clear(); - output_node_name.clear(); - temp_vector.emplace_back(std::make_pair("IdentityN_0", 2)); - output_node_name.emplace(std::make_pair(0, temp_vector)); - control_input_node_name.clear(); - control_output_node_name.clear(); - expect_node_info.emplace_back(ExpectNodeInfo("Data_2", "Data", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 1, 1)); - - input_node_name.emplace(std::make_pair(0, std::make_pair("ConcatV2_0", 0))); - input_node_name.emplace(std::make_pair(1, std::make_pair("Variable_0", 0))); - input_node_name.emplace(std::make_pair(2, std::make_pair("Data_2", 0))); - temp_vector.clear(); - output_node_name.clear(); - temp_vector.emplace_back(std::make_pair("MatmulV2_0", 0)); - output_node_name.emplace(std::make_pair(0, temp_vector)); - temp_vector.clear(); - temp_vector.emplace_back(std::make_pair("MatmulV2_0", 1)); - output_node_name.emplace(std::make_pair(1, temp_vector)); - temp_vector.clear(); - temp_vector.emplace_back(std::make_pair("MatmulV2_0", 2)); - output_node_name.emplace(std::make_pair(2, temp_vector)); - control_input_node_name.clear(); - expect_node_info.emplace_back(ExpectNodeInfo("IdentityN_0", "IdentityN", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 3, 3)); - - input_node_name.clear(); - input_node_name.emplace(std::make_pair(0, std::make_pair("IdentityN_0", 0))); - input_node_name.emplace(std::make_pair(1, std::make_pair("IdentityN_0", 1))); - input_node_name.emplace(std::make_pair(2, std::make_pair("IdentityN_0", 2))); - output_node_name.clear(); - control_input_node_name.emplace_back("ConcatV2_0"); - expect_node_info.emplace_back(ExpectNodeInfo("MatmulV2_0", "MatmulV2", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 4, 1)); - CheckNodeResult(compute_graph, expect_node_info); - EXPECT_EQ(compute_graph->GetInputSize(), 3); -} - -/* - Data Data Data - | | | - \ | / branch0: branch1: branch1_0: branch1_1: - If DATA DATA DATA DATA DATA DATA - | \ / | / | | - Relu Add if Relu Relu -*/ -TEST_F(UtestGraph, TestGenerateGraphWithSubGraph) { - ge::Operator data_0 = ge::Operator("Data_0", "Data"); - ge::Operator data_1 = ge::Operator("Data_1", "Data"); - ge::Operator data_2 = ge::Operator("Data_2", "Data"); - ge::Operator if_op = ge::Operator("If_0", "If"); - ge::Operator relu_0 = ge::Operator("Relu_0", "Relu"); - data_0.InputRegister("x"); - data_0.OutputRegister("y"); - data_1.InputRegister("x"); - data_1.OutputRegister("y"); - data_2.InputRegister("x"); - data_2.OutputRegister("y"); - - if_op.InputRegister("cond"); - if_op.DynamicInputRegister("input", 2); - if_op.DynamicOutputRegister("output", 1); - if_op.SubgraphRegister("then_branch", false); - if_op.SubgraphRegister("else_branch", false); - if_op.SubgraphCountRegister("then_branch", 1); - if_op.SubgraphCountRegister("else_branch", 1); - if_op.SetSubgraphBuilder("then_branch", 0, [] ()->Graph { - ge::Operator then_branch_data_0 = ge::Operator("then_branch_data_0", "Data"); - ge::Operator then_branch_data_1 = ge::Operator("then_branch_data_1", "Data"); - ge::Operator add_0 = ge::Operator("Add_0", "Add"); - then_branch_data_0.InputRegister("x"); - then_branch_data_0.OutputRegister("y"); - then_branch_data_1.InputRegister("x"); - then_branch_data_1.OutputRegister("y"); - add_0.InputRegister("x1"); - add_0.InputRegister("x2"); - add_0.OutputRegister("y"); - add_0.SetInput(0U, then_branch_data_0, 0U); - add_0.SetInput(1U, then_branch_data_1, 0U); - std::vector then_branch_ops{then_branch_data_0, then_branch_data_1, add_0}; - Graph graph("if_op_then_branch"); - EXPECT_EQ(GraphUtilsEx::CreateGraphFromOperatorWithStableTopo(graph, then_branch_ops), SUCCESS); - return graph; - }); - if_op.SetSubgraphBuilder("else_branch", 0, [] ()->Graph { - ge::Operator else_branch_data_0 = ge::Operator("else_branch_data_0", "Data"); - ge::Operator else_branch_data_1 = ge::Operator("else_branch_data_1", "Data"); - ge::Operator if_op_1 = ge::Operator("else_branch_if", "If"); - else_branch_data_0.InputRegister("x"); - else_branch_data_0.OutputRegister("y"); - else_branch_data_1.InputRegister("x"); - else_branch_data_1.OutputRegister("y"); - if_op_1.InputRegister("cond"); - if_op_1.DynamicInputRegister("input", 1); - if_op_1.DynamicOutputRegister("output", 1); - if_op_1.SubgraphRegister("then_branch", false); - if_op_1.SubgraphRegister("else_branch", false); - if_op_1.SubgraphCountRegister("then_branch", 1); - if_op_1.SubgraphCountRegister("else_branch", 1); - if_op_1.SetSubgraphBuilder("then_branch", 0, [] ()->Graph { - ge::Operator if_1_then_branch_data_0 = ge::Operator("if_1_then_branch_data_0", "Data"); - ge::Operator if_1_then_branch_relu = ge::Operator("if_1_then_branch_relu", "Relu"); - if_1_then_branch_data_0.InputRegister("x"); - if_1_then_branch_data_0.OutputRegister("y"); - if_1_then_branch_relu.InputRegister("x"); - if_1_then_branch_relu.OutputRegister("y"); - if_1_then_branch_relu.SetInput(0U, if_1_then_branch_data_0, 0U); - std::vector if_1_then_branch_ops{if_1_then_branch_data_0, if_1_then_branch_relu}; - Graph graph("if_1_then_branch"); - EXPECT_EQ(GraphUtilsEx::CreateGraphFromOperatorWithStableTopo(graph, if_1_then_branch_ops), SUCCESS); - return graph; - }); - if_op_1.SetSubgraphBuilder("else_branch", 0, [] ()->Graph { - ge::Operator if_1_else_branch_data_0 = ge::Operator("if_1_else_branch_data_0", "Data"); - ge::Operator if_1_else_branch_relu = ge::Operator("if_1_else_branch_relu", "Relu"); - if_1_else_branch_data_0.InputRegister("x"); - if_1_else_branch_data_0.OutputRegister("y"); - if_1_else_branch_relu.InputRegister("x"); - if_1_else_branch_relu.OutputRegister("y"); - if_1_else_branch_relu.SetInput(0U, if_1_else_branch_data_0, 0U); - std::vector if_1_else_branch_ops{if_1_else_branch_data_0, if_1_else_branch_relu}; - Graph graph("if_1_else_branch"); - EXPECT_EQ(GraphUtilsEx::CreateGraphFromOperatorWithStableTopo(graph, if_1_else_branch_ops), SUCCESS); - return graph; - }); - if_op_1.SetInput(0U, else_branch_data_0, 0U); - if_op_1.SetInput(1U, else_branch_data_1, 0U); - std::vector else_branch_ops{else_branch_data_0, else_branch_data_1, if_op_1}; - Graph graph("if_op_else_branch"); - EXPECT_EQ(GraphUtilsEx::CreateGraphFromOperatorWithStableTopo(graph, else_branch_ops), SUCCESS); - return graph; - }); - relu_0.InputRegister("x"); - relu_0.OutputRegister("y"); - if_op.SetInput(0U, data_0, 0U); - if_op.SetInput(1U, data_1, 0U); - if_op.SetInput(2U, data_2, 0U); - relu_0.SetInput(0U, if_op, 0U); - - std::vector ops{data_0, data_1, data_2, if_op, relu_0}; - Graph graph("stable_sort_graph_with_subgraph"); - EXPECT_EQ(GraphUtilsEx::CreateGraphFromOperatorWithStableTopo(graph, ops), SUCCESS); - const auto compute_graph = GraphUtilsEx::GetComputeGraph(graph); - std::map> graph_expect_info; - // root_graph - std::vector expect_node_info; - std::map> input_node_name; - std::map>> output_node_name; - std::vector control_input_node_name; - std::vector control_output_node_name; - std::vector> temp_vector = {{"If_0", 0}}; - output_node_name.emplace(std::make_pair(0, temp_vector)); - expect_node_info.emplace_back(ExpectNodeInfo("Data_0", "Data", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 1, 1)); - - temp_vector.clear(); - output_node_name.clear(); - temp_vector.emplace_back(std::make_pair("If_0", 1)); - output_node_name.emplace(std::make_pair(0, temp_vector)); - expect_node_info.emplace_back(ExpectNodeInfo("Data_1", "Data", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 1, 1)); - - temp_vector.clear(); - output_node_name.clear(); - temp_vector.emplace_back(std::make_pair("If_0", 2)); - output_node_name.emplace(std::make_pair(0, temp_vector)); - expect_node_info.emplace_back(ExpectNodeInfo("Data_2", "Data", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 1, 1)); - - input_node_name.emplace(std::make_pair(0, std::make_pair("Data_0", 0))); - input_node_name.emplace(std::make_pair(1, std::make_pair("Data_1", 0))); - input_node_name.emplace(std::make_pair(2, std::make_pair("Data_2", 0))); - temp_vector.clear(); - output_node_name.clear(); - temp_vector.emplace_back(std::make_pair("Relu_0", 0)); - output_node_name.emplace(std::make_pair(0, temp_vector)); - expect_node_info.emplace_back(ExpectNodeInfo("If_0", "If", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 3, 1)); - - input_node_name.clear(); - input_node_name.emplace(std::make_pair(0, std::make_pair("If_0", 0))); - output_node_name.clear(); - expect_node_info.emplace_back(ExpectNodeInfo("Relu_0", "Relu", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 1, 1)); - graph_expect_info.emplace("stable_sort_graph_with_subgraph", expect_node_info); - - // if_0_then_branch - expect_node_info.clear(); - input_node_name.clear(); - temp_vector.clear(); - temp_vector.emplace_back(std::make_pair("Add_0", 0)); - output_node_name.emplace(std::make_pair(0, temp_vector)); - expect_node_info.emplace_back(ExpectNodeInfo("then_branch_data_0", "Data", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 1, 1)); - - temp_vector.clear(); - output_node_name.clear(); - temp_vector.emplace_back(std::make_pair("Add_0", 1)); - output_node_name.emplace(std::make_pair(0, temp_vector)); - expect_node_info.emplace_back(ExpectNodeInfo("then_branch_data_1", "Data", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 1, 1)); - - input_node_name.emplace(std::make_pair(0, std::make_pair("then_branch_data_0", 0))); - input_node_name.emplace(std::make_pair(1, std::make_pair("then_branch_data_1", 0))); - output_node_name.clear(); - expect_node_info.emplace_back(ExpectNodeInfo("Add_0", "Add", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 2, 1)); - graph_expect_info.emplace("if_op_then_branch", expect_node_info); - - // if_0_else_branch - expect_node_info.clear(); - input_node_name.clear(); - temp_vector.clear(); - output_node_name.clear(); - temp_vector.emplace_back(std::make_pair("else_branch_if", 0)); - output_node_name.emplace(std::make_pair(0, temp_vector)); - expect_node_info.emplace_back(ExpectNodeInfo("else_branch_data_0", "Data", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 1, 1)); - temp_vector.clear(); - output_node_name.clear(); - temp_vector.emplace_back(std::make_pair("else_branch_if", 1)); - output_node_name.emplace(std::make_pair(0, temp_vector)); - expect_node_info.emplace_back(ExpectNodeInfo("else_branch_data_1", "Data", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 1, 1)); - - input_node_name.emplace(std::make_pair(0, std::make_pair("else_branch_data_0", 0))); - input_node_name.emplace(std::make_pair(1, std::make_pair("else_branch_data_1", 0))); - output_node_name.clear(); - expect_node_info.emplace_back(ExpectNodeInfo("else_branch_if", "If", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 2, 1)); - graph_expect_info.emplace("if_op_else_branch", expect_node_info); - - // if_1_then_branch - expect_node_info.clear(); - input_node_name.clear(); - temp_vector.clear(); - output_node_name.clear(); - temp_vector.emplace_back(std::make_pair("if_1_then_branch_relu", 0)); - output_node_name.emplace(std::make_pair(0, temp_vector)); - expect_node_info.emplace_back(ExpectNodeInfo("if_1_then_branch_data_0", "Data", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 1, 1)); - - input_node_name.emplace(std::make_pair(0, std::make_pair("if_1_then_branch_data_0", 0))); - output_node_name.clear(); - expect_node_info.emplace_back(ExpectNodeInfo("if_1_then_branch_relu", "Relu", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 1, 1)); - graph_expect_info.emplace("if_1_then_branch", expect_node_info); - // if_1_else_branch - expect_node_info.clear(); - input_node_name.clear(); - temp_vector.clear(); - output_node_name.clear(); - temp_vector.emplace_back(std::make_pair("if_1_else_branch_relu", 0)); - output_node_name.emplace(std::make_pair(0, temp_vector)); - expect_node_info.emplace_back(ExpectNodeInfo("if_1_else_branch_data_0", "Data", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 1, 1)); - - input_node_name.emplace(std::make_pair(0, std::make_pair("if_1_else_branch_data_0", 0))); - output_node_name.clear(); - expect_node_info.emplace_back(ExpectNodeInfo("if_1_else_branch_relu", "Relu", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 1, 1)); - graph_expect_info.emplace("if_1_else_branch", expect_node_info); - - EXPECT_EQ(compute_graph->GetName(), "stable_sort_graph_with_subgraph"); - EXPECT_EQ(compute_graph->GetDirectNodesSize(), 5); - EXPECT_EQ(compute_graph->GetInputSize(), 3); - CheckNodeResult(compute_graph, graph_expect_info["stable_sort_graph_with_subgraph"]); - const auto subgraphs = compute_graph->GetAllSubgraphs(); - EXPECT_EQ(subgraphs.size(), 4); - for (const auto &subgraph : subgraphs) { - const auto iter = graph_expect_info.find(subgraph->GetName()); - ASSERT_NE(iter, graph_expect_info.end()); - CheckNodeResult(subgraph, iter->second); - } -} - -/* - Data Data Const - | | | - Relu Relu Relu - | | | \ - Cast Cast -- Add Cast - \ / - Add -*/ -TEST_F(UtestGraph, TestGenerateGraphWithOutputMultiRef) { - ge::Operator data_0 = ge::Operator("Data_0", "Data"); - ge::Operator data_1 = ge::Operator("Data_1","Data"); - ge::Operator const_op = ge::Operator("Constant_0", "Constant"); - ge::Operator relu_0 = ge::Operator("Relu_0", "Relu"); - ge::Operator relu_1 = ge::Operator("Relu_1", "Relu"); - ge::Operator relu_2 = ge::Operator("Relu_2", "Relu"); - ge::Operator cast_0 = ge::Operator("Cast_0", "Cast"); - ge::Operator cast_1 = ge::Operator("Cast_1", "Cast"); - ge::Operator cast_2 = ge::Operator("Cast_2", "Cast"); - ge::Operator add_0 = ge::Operator("Add_0", "Add"); - ge::Operator add_1 = ge::Operator("Add_1", "Add"); - - data_0.InputRegister("x"); - data_0.OutputRegister("y"); - data_1.InputRegister("x"); - data_1.OutputRegister("y"); - const_op.OutputRegister("y"); - relu_0.InputRegister("x"); - relu_0.OutputRegister("y"); - relu_1.InputRegister("x"); - relu_1.OutputRegister("y"); - relu_2.InputRegister("x"); - relu_2.OutputRegister("y"); - cast_0.InputRegister("x"); - cast_0.OutputRegister("y"); - cast_1.InputRegister("x"); - cast_1.OutputRegister("y"); - cast_2.InputRegister("x"); - cast_2.OutputRegister("y"); - add_0.InputRegister("x1"); - add_0.InputRegister("x2"); - add_0.OutputRegister("y"); - add_1.InputRegister("x1"); - add_1.InputRegister("x2"); - add_1.OutputRegister("y"); - - relu_0.SetInput(0U, data_0, 0U); - relu_1.SetInput(0U, data_1, 0U); - relu_2.SetInput(0U, const_op, 0U); - cast_0.SetInput(0U, relu_0, 0U); - cast_1.SetInput(0U, relu_1, 0U); - cast_2.SetInput(0U, relu_2, 0U); - add_0.SetInput(0U, cast_1, 0U); - add_0.SetInput(1U, relu_2, 0U); - add_1.SetInput(0U, cast_0, 0U); - add_1.SetInput(1U, cast_1, 0U); - - std::vector ops{data_0, const_op, data_1, relu_0, relu_1, relu_2, cast_0, cast_1, cast_2, add_0, add_1}; - Graph graph("stable_sort_graph_multi_output_ref"); - EXPECT_EQ(GraphUtilsEx::CreateGraphFromOperatorWithStableTopo(graph, ops), SUCCESS); - const auto compute_graph = GraphUtilsEx::GetComputeGraph(graph); - EXPECT_EQ(compute_graph->GetName(), "stable_sort_graph_multi_output_ref"); - EXPECT_EQ(compute_graph->GetDirectNodesSize(), 11); - std::vector expect_node_info; - std::map> input_node_name; - std::map>> output_node_name; - std::vector control_input_node_name; - std::vector control_output_node_name; - std::vector> temp_vector = {{"Relu_0", 0}}; - output_node_name.emplace(std::make_pair(0, temp_vector)); - expect_node_info.emplace_back(ExpectNodeInfo("Data_0", "Data", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 1, 1)); - - temp_vector.clear(); - output_node_name.clear(); - temp_vector.emplace_back(std::make_pair("Relu_2", 0)); - output_node_name.emplace(std::make_pair(0, temp_vector)); - expect_node_info.emplace_back(ExpectNodeInfo("Constant_0", "Constant", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 0, 1)); - - temp_vector.clear(); - output_node_name.clear(); - temp_vector.emplace_back(std::make_pair("Relu_1", 0)); - output_node_name.emplace(std::make_pair(0, temp_vector)); - expect_node_info.emplace_back(ExpectNodeInfo("Data_1", "Data", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 1, 1)); - - temp_vector.clear(); - output_node_name.clear(); - temp_vector.emplace_back(std::make_pair("Cast_0", 0)); - output_node_name.emplace(std::make_pair(0, temp_vector)); - input_node_name.emplace(std::make_pair(0, std::make_pair("Data_0", 0))); - expect_node_info.emplace_back(ExpectNodeInfo("Relu_0", "Relu", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 1, 1)); - - input_node_name.clear(); - input_node_name.emplace(std::make_pair(0, std::make_pair("Data_1", 0))); - temp_vector.clear(); - output_node_name.clear(); - temp_vector.emplace_back(std::make_pair("Cast_1", 0)); - output_node_name.emplace(std::make_pair(0, temp_vector)); - expect_node_info.emplace_back(ExpectNodeInfo("Relu_1", "Relu", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 1, 1)); - - input_node_name.clear(); - input_node_name.emplace(std::make_pair(0, std::make_pair("Constant_0", 0))); - temp_vector.clear(); - output_node_name.clear(); - temp_vector.emplace_back(std::make_pair("Cast_2", 0)); - temp_vector.emplace_back(std::make_pair("Add_0", 1)); - output_node_name.emplace(std::make_pair(0, temp_vector)); - expect_node_info.emplace_back(ExpectNodeInfo("Relu_2", "Relu", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 1, 1)); - - input_node_name.clear(); - input_node_name.emplace(std::make_pair(0, std::make_pair("Relu_0", 0))); - temp_vector.clear(); - output_node_name.clear(); - temp_vector.emplace_back(std::make_pair("Add_1", 0)); - output_node_name.emplace(std::make_pair(0, temp_vector)); - expect_node_info.emplace_back(ExpectNodeInfo("Cast_0", "Cast", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 1, 1)); - - input_node_name.clear(); - input_node_name.emplace(std::make_pair(0, std::make_pair("Relu_1", 0))); - temp_vector.clear(); - output_node_name.clear(); - temp_vector.emplace_back(std::make_pair("Add_0", 0)); - temp_vector.emplace_back(std::make_pair("Add_1", 1)); - output_node_name.emplace(std::make_pair(0, temp_vector)); - expect_node_info.emplace_back(ExpectNodeInfo("Cast_1", "Cast", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 1, 1)); - - input_node_name.clear(); - input_node_name.emplace(std::make_pair(0, std::make_pair("Relu_2", 0))); - output_node_name.clear(); - expect_node_info.emplace_back(ExpectNodeInfo("Cast_2", "Cast", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 1, 1)); - - input_node_name.clear(); - input_node_name.emplace(std::make_pair(0, std::make_pair("Cast_1", 0))); - input_node_name.emplace(std::make_pair(1, std::make_pair("Relu_2", 0))); - expect_node_info.emplace_back(ExpectNodeInfo("Add_0", "Add", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 2, 1)); - - input_node_name.clear(); - input_node_name.emplace(std::make_pair(0, std::make_pair("Cast_0", 0))); - input_node_name.emplace(std::make_pair(1, std::make_pair("Cast_1", 0))); - expect_node_info.emplace_back(ExpectNodeInfo("Add_1", "Add", - input_node_name, output_node_name, control_input_node_name, control_output_node_name, 2, 1)); - CheckNodeResult(compute_graph, expect_node_info); - EXPECT_EQ(compute_graph->GetInputSize(), 2); -} - -TEST_F(UtestGraph, TestSameNameNode_fail) { - std::string op_type(__FUNCTION__); - std::string op_name("the_dummy"); - OperatorFactoryImpl::RegisterOperatorCreator(op_type, [op_type](const std::string &name) -> Operator { - auto op_desc = std::make_shared(name, op_type); - op_desc->AddOutputDesc("output", {}); - return OpDescUtils::CreateOperatorFromOpDesc(op_desc); - }); - - auto node_0 = Operator(op_name, op_type); - auto node_1 = Operator(op_name, op_type); - std::vector ops_0 = { node_0, node_1 }; - EXPECT_EQ(GraphUtilsEx::CreateGraphFromOperator("graph_with_same_name_node", ops_0), nullptr); - - auto node_2 = Operator(op_name, op_type); - node_2.SubgraphRegister("sub_graph", false); - node_2.SubgraphCountRegister("sub_graph", 1); - node_2.SetSubgraphBuilder("sub_graph", 0, [op_name, op_type]() { - ut::GraphBuilder builder = ut::GraphBuilder("sub_graph_with_same_name_node"); - builder.AddNode(op_name, op_type, 0, 1); - builder.AddNode(op_name, op_type, 0, 1); - return GraphUtilsEx::CreateGraphFromComputeGraph(builder.GetGraph()); - }); - std::vector ops_1 = { node_2 }; - EXPECT_EQ(GraphUtilsEx::CreateGraphFromOperator("graph_with_same_name_node_in_subgraph", ops_1), nullptr); -} -// extern "C" wrapper functions to avoid C++ name mangling -extern "C" { -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus aclCom_Graph_SetValid(void *graph_ptr) { - if (graph_ptr == nullptr) { - return GRAPH_FAILED; - } - auto *graph = static_cast(graph_ptr); - return graph->SetValid(); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus aclCom_Graph_SetAttr_AttrValue(void *graph_ptr, - const char *name, - const void *attr_value) { - if (graph_ptr == nullptr || name == nullptr || attr_value == nullptr) { - return GRAPH_FAILED; - } - auto *graph = static_cast(graph_ptr); - auto *value = static_cast(attr_value); - return graph->SetAttr(ge::AscendString(name), *value); -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus aclCom_Graph_GetAttr_AttrValue(void *graph_ptr, - const char *name, - void *attr_value) { - if (graph_ptr == nullptr || name == nullptr || attr_value == nullptr) { - return GRAPH_FAILED; - } - auto *graph = static_cast(graph_ptr); - auto *value = static_cast(attr_value); - return graph->GetAttr(ge::AscendString(name), *value); -} -} -TEST_F(UtestGraph, TestGraphSetAttrAndGetAttr_AttrValue) { - // 创建测试图 - Graph graph("test_graph"); - auto data_op = Operator("Data", "Data"); - auto const_op = Operator("Const", "Const"); - std::vector inputs = {data_op, const_op}; - graph.SetInputs(inputs); - - // 测试AttrValue类型的SetAttr和GetAttr - AttrValue attr_value; - attr_value.SetAttrValue(static_cast(12345)); - - // 测试成功情况 - EXPECT_EQ(graph.SetAttr("test_attr", attr_value), GRAPH_SUCCESS); - AttrValue get_attr_value; - EXPECT_EQ(graph.GetAttr("test_attr", get_attr_value), GRAPH_SUCCESS); - int64_t int_value = 0; - EXPECT_EQ(get_attr_value.GetAttrValue(int_value), GRAPH_SUCCESS); - EXPECT_EQ(int_value, 12345); - - // 测试extern "C"接口 - EXPECT_EQ(aclCom_Graph_SetAttr_AttrValue(&graph, "test_attr_c", &attr_value), GRAPH_SUCCESS); - EXPECT_EQ(aclCom_Graph_GetAttr_AttrValue(&graph, "test_attr_c", &get_attr_value), GRAPH_SUCCESS); - int_value = 0; - EXPECT_EQ(get_attr_value.GetAttrValue(int_value), GRAPH_SUCCESS); - EXPECT_EQ(int_value, 12345); - - // 测试nullptr参数 - EXPECT_EQ(aclCom_Graph_SetAttr_AttrValue(nullptr, "test_attr", &attr_value), GRAPH_FAILED); - EXPECT_EQ(aclCom_Graph_SetAttr_AttrValue(&graph, nullptr, &attr_value), GRAPH_FAILED); - EXPECT_EQ(aclCom_Graph_SetAttr_AttrValue(&graph, "test_attr", nullptr), GRAPH_FAILED); - EXPECT_EQ(aclCom_Graph_GetAttr_AttrValue(nullptr, "test_attr", &get_attr_value), GRAPH_FAILED); - EXPECT_EQ(aclCom_Graph_GetAttr_AttrValue(&graph, nullptr, &get_attr_value), GRAPH_FAILED); - EXPECT_EQ(aclCom_Graph_GetAttr_AttrValue(&graph, "test_attr", nullptr), GRAPH_FAILED); -} - -TEST_F(UtestGraph, TestGraphSetAttrAndGetAttr_AttrValue_ComplexTypes) { - // 创建测试图 - Graph graph("test_graph"); - graph.SetValid(); - - // 测试复杂类型的AttrValue - AttrValue complex_attr; - std::vector vec_value = {1, 2, 3, 4, 5}; - complex_attr.SetAttrValue(vec_value); - - EXPECT_EQ(graph.SetAttr("complex_attr", complex_attr), GRAPH_SUCCESS); - AttrValue get_complex_attr; - EXPECT_EQ(graph.GetAttr("complex_attr", get_complex_attr), GRAPH_SUCCESS); - std::vector get_vec_value; - EXPECT_EQ(get_complex_attr.GetAttrValue(get_vec_value), GRAPH_SUCCESS); - EXPECT_EQ(get_vec_value, vec_value); - - // 测试extern "C"接口 - EXPECT_EQ(aclCom_Graph_SetAttr_AttrValue(&graph, "complex_attr_c", &complex_attr), GRAPH_SUCCESS); - EXPECT_EQ(aclCom_Graph_GetAttr_AttrValue(&graph, "complex_attr_c", &get_complex_attr), GRAPH_SUCCESS); - get_vec_value.clear(); - EXPECT_EQ(get_complex_attr.GetAttrValue(get_vec_value), GRAPH_SUCCESS); - EXPECT_EQ(get_vec_value, vec_value); -} - -TEST_F(UtestGraph, TestGraphSetAttrAndGetAttr_AttrValue_InvalidCases) { - // 创建测试图 - Graph graph("test_graph"); - EXPECT_EQ(graph.SetValid(), GRAPH_SUCCESS); - EXPECT_EQ(aclCom_Graph_SetValid(&graph), GRAPH_SUCCESS); - - // 测试获取不存在的属性 - AttrValue attr_value; - EXPECT_NE(graph.GetAttr("non_existent_attr", attr_value), GRAPH_SUCCESS); - - // 测试空图 - Graph empty_graph; - AttrValue test_attr; - test_attr.SetAttrValue(static_cast(12345)); - EXPECT_NE(empty_graph.SetAttr("test_attr", test_attr), GRAPH_SUCCESS); - EXPECT_NE(empty_graph.GetAttr("test_attr", attr_value), GRAPH_SUCCESS); -} - -TEST_F(UtestGraph, TestDumpOnnxGraphToFile) { - std::string ascend_work_path = "./test_ge_graph_path"; - mmSetEnv("DUMP_GRAPH_PATH", ascend_work_path.c_str(), 1); - - // 创建测试图 - auto compute_graph = BuildComputeGraph(); - auto graph = GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph); - EXPECT_EQ(graph.IsValid(), true); - - std::string suffix = "test_onnx"; - // 测试dump onnx - EXPECT_EQ(graph.DumpToFile(Graph::DumpFormat::kOnnx, suffix.c_str()), GRAPH_SUCCESS); - - // test existed dir - ComputeGraphPtr com_graph1 = std::make_shared("GeTestGraph1"); - onnx::ModelProto model_proto; - ASSERT_EQ(model_proto.ByteSize(), 0); - // static thing, so follow DumpGEGraphUserGraphNameNull_AscendWorkPathNotNull this case path - std::stringstream dump_file_path = GetFilePathWhenDumpPathSet(ascend_work_path); - std::string dump_graph_path = GetSpecificFilePath(ge::RealPath(dump_file_path.str().c_str()), suffix.c_str()); - bool state = GraphUtils::ReadProtoFromTextFile(dump_graph_path.c_str(), &model_proto); - ASSERT_EQ(state, true); - ASSERT_NE(model_proto.ByteSize(), 0); - EXPECT_STREQ(model_proto.graph().name().c_str(), "graph"); - - system(("rm -rf " + ascend_work_path).c_str()); - unsetenv("DUMP_GRAPH_PATH"); -} - -TEST_F(UtestGraph, TestDumpProtoGraphToFile) { - std::string ascend_work_path = "./test_ge_graph_path"; - mmSetEnv("DUMP_GRAPH_PATH", ascend_work_path.c_str(), 1); - - // 创建测试图 - auto compute_graph = BuildComputeGraph(); - auto graph = GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph); - EXPECT_EQ(graph.IsValid(), true); - - std::string suffix = "test_txt"; - // 测试dump onnx - EXPECT_EQ(graph.DumpToFile(Graph::DumpFormat::kTxt, suffix.c_str()), GRAPH_SUCCESS); - - // test existed dir - onnx::ModelProto model_proto; - ASSERT_EQ(model_proto.ByteSize(), 0); - // static thing, so follow DumpGEGraphUserGraphNameNull_AscendWorkPathNotNull this case path - std::stringstream dump_file_path = GetFilePathWhenDumpPathSet(ascend_work_path); - std::string dump_graph_path = GetSpecificFilePath(ge::RealPath(dump_file_path.str().c_str()), suffix.c_str()); - - ComputeGraphPtr com_graph2 = std::make_shared("GeTestGraph2"); - bool state = GraphUtils::LoadGEGraph(dump_graph_path.c_str(), *com_graph2); - EXPECT_EQ(state, true); - EXPECT_EQ(com_graph2->GetDirectNodesSize(), 3); - - system(("rm -rf " + ascend_work_path).c_str()); - unsetenv("DUMP_GRAPH_PATH"); -} - -TEST_F(UtestGraph, TestDumpProtoGraphToOstream) { - // 创建测试图 - auto compute_graph = BuildComputeGraph(); - auto graph = GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph); - EXPECT_EQ(graph.IsValid(), true); - - // 测试dump txt - std::ostringstream stream; - EXPECT_EQ(graph.Dump(Graph::DumpFormat::kTxt, stream), GRAPH_SUCCESS); - - ge::proto::ModelDef txt_model_proto; - google::protobuf::TextFormat::ParseFromString(stream.str(), &txt_model_proto); - - Model model; - EXPECT_EQ(model.Load(txt_model_proto), SUCCESS); - EXPECT_EQ(model.GetGraph()->GetDirectNodesSize(), 3); -} - -TEST_F(UtestGraph, TestDumpOnnxGraphToOstream) { - // 创建测试图 - auto compute_graph = BuildComputeGraph(); - auto graph = GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph); - EXPECT_EQ(graph.IsValid(), true); - - // 测试dump txt - std::ostringstream stream; - EXPECT_EQ(graph.Dump(Graph::DumpFormat::kOnnx, stream), GRAPH_SUCCESS); - - onnx::ModelProto onnx_model_proto; - google::protobuf::TextFormat::ParseFromString(stream.str(), &onnx_model_proto); - EXPECT_STREQ(onnx_model_proto.graph().name().c_str(), "graph"); -} -REG_OP(subTest) - .INPUT(inx, TensorType::ALL()) - .OUTPUT(y, TensorType::ALL()) - .GRAPH(subgraph) - .OP_END_FACTORY_REG(subTest); - -REG_OP(dataOp) - .OUTPUT(y, TensorType::ALL()) - .ATTR(value, Int, 0) - .OP_END_FACTORY_REG(dataOp); - -TEST_F(UtestGraph, TestGraphFindNodeByName_failure) { - Graph graph_invalid("graph_invalid"); - AscendString empty_graph_node_name = "blablabla"; - ASSERT_EQ(nullptr, graph_invalid.FindNodeByName(empty_graph_node_name)); - - auto op = op::dataOp("dataOp"); - Operator op_input1 = op::dataOp("op_input1"); - std::vector inputs = {op_input1}; - AscendString name = "graph"; - GraphPtr graph = Graph::ConstructFromInputs(inputs, name); - auto node = graph->AddNodeByOp(op); - - AscendString wrong_name("wrong_name"); - ASSERT_EQ(nullptr, graph->FindNodeByName(wrong_name)); -} - -TEST_F(UtestGraph, TestGraphFindNodeByName_success) { - auto op = op::dataOp("dataOp"); - Operator op_input1 = op::dataOp("op_input1"); - std::vector inputs = {op_input1}; - AscendString name = "graph"; - GraphPtr graph = Graph::ConstructFromInputs(inputs, name); - auto node = graph->AddNodeByOp(op); - - AscendString find_name; - op.GetName(find_name); - auto get_node = graph->FindNodeByName(find_name); - AscendString exp_name("dataOp"); - ASSERT_TRUE(find_name == exp_name); -} - -TEST_F(UtestGraph, TestGraphGetParentGraph_failure) { - Graph graph_invalid("graph_invalid"); - AscendString empty_graph_node_name = "blablabla"; - ASSERT_EQ(nullptr, graph_invalid.GetParentGraph()); -} - -TEST_F(UtestGraph, TestGraphGetParentGraph_success) { - auto op = op::subTest("subTest"); - Operator op_input1 = op::dataOp("op_input1"); - std::vector inputs = {op_input1}; - AscendString name = "graph"; - GraphPtr graph = Graph::ConstructFromInputs(inputs, name); - auto gnode = graph->AddNodeByOp(op); - - name = "subgraph1"; - Operator op_input2 = op::dataOp("op_input2"); - inputs = {op_input2}; - auto subgraph = Graph::ConstructFromInputs(inputs, name); - ASSERT_EQ(GRAPH_SUCCESS, gnode.SetSubgraph("subgraph", *subgraph.get())); - - auto subgraph_parent = subgraph->GetParentGraph(); - AscendString ret_parent_graph_name; - AscendString exp_parent_graph_name("graph"); - subgraph_parent->GetName(ret_parent_graph_name); - ASSERT_TRUE(exp_parent_graph_name == ret_parent_graph_name); -} - -TEST_F(UtestGraph, TestGraphGetParentNode_failure) { - Graph graph_invalid("graph_invalid"); - AscendString empty_graph_node_name = "blablabla"; - ASSERT_EQ(nullptr, graph_invalid.GetParentNode()); -} - -TEST_F(UtestGraph, TestGraphGetParentNode_success) { - auto op = op::subTest("subTest"); - Operator op_input1 = op::dataOp("op_input1"); - std::vector inputs = {op_input1}; - AscendString name = "graph"; - GraphPtr graph = Graph::ConstructFromInputs(inputs, name); - auto gnode = graph->AddNodeByOp(op); - - name = "subgraph1"; - Operator op_input2 = op::dataOp("op_input2"); - inputs = {op_input2}; - auto subgraph = Graph::ConstructFromInputs(inputs, name); - ASSERT_EQ(GRAPH_SUCCESS, gnode.SetSubgraph("subgraph", *subgraph.get())); - - auto subgraph_parent_node = subgraph->GetParentNode(); - AscendString ret_parent_node_name; - AscendString exp_parent_node_name("subTest"); - subgraph_parent_node->GetName(ret_parent_node_name); - ASSERT_TRUE(ret_parent_node_name == exp_parent_node_name); -} \ No newline at end of file diff --git a/tests/ut/graph/testcase/graph_utils_unittest.cc b/tests/ut/graph/testcase/graph_utils_unittest.cc deleted file mode 100644 index 03f63b19670e700cb4a1d3c52aaf233add9e87ac..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/graph_utils_unittest.cc +++ /dev/null @@ -1,5469 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include "graph/utils/graph_utils.h" -#include "graph/utils/tensor_utils.h" -#include "graph/utils/node_utils.h" -#include "graph/normal_graph/op_desc_impl.h" -#include "graph/normal_graph/node_impl.h" -#include "graph/node.h" -#include "graph/ge_local_context.h" -#include "graph/ge_context.h" -#include "graph_builder_utils.h" -#include "graph/debug/ge_op_types.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/debug/ge_util.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/graph_utils_ex.h" -#include "mmpa/mmpa_api.h" -#include "common/util/mem_utils.h" -#include "proto/onnx/ge_onnx.pb.h" -#include "graph/utils/ge_dump_graph_whitelist.h" - -namespace ge { -extern std::stringstream GetFilePathWhenDumpPathSet(const string &ascend_work_path); -namespace { -bool IfNodeExist(const ComputeGraphPtr &graph, std::function filter, - bool direct_node_flag = true) { - for (const auto &node : graph->GetNodes(direct_node_flag)) { - if (filter(node)) { - return true; - } - } - return false; -} -/* - * data1 const1 data2 const2 - * \ / \ / - * add1 add2 - * | | - * cast1 cast2 - * | | - * square1 var1 var2 square2 - * \ / | | \ / - * less1 | | less2 - * \ | | / - * mul - * | - * | - * | - * netoutput - */ -void BuildGraphForUnfold(ComputeGraphPtr &graph, ComputeGraphPtr &subgraph) { - auto builder = ut::GraphBuilder("root"); - const auto &input1 = builder.AddNode("data1", DATA, 1, 1); - const auto &var1 = builder.AddNode("var1", VARIABLEV2, 1, 1); - const auto &input2 = builder.AddNode("data2", DATA, 1, 1); - const auto &var2 = builder.AddNode("var2", VARIABLEV2, 1, 1); - const auto &func = builder.AddNode("func", PARTITIONEDCALL, 4, 1); - const auto &netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); - builder.AddDataEdge(input1, 0, func, 0); - builder.AddDataEdge(var1, 0, func, 1); - builder.AddDataEdge(input2, 0, func, 2); - builder.AddDataEdge(var2, 0, func, 3); - builder.AddDataEdge(func, 0, netoutput, 0); - - graph = builder.GetGraph(); - - auto sub_builder = ut::GraphBuilder("sub"); - const auto &data1 = sub_builder.AddNode("data1", DATA, 1, 1); - const auto &const1 = sub_builder.AddNode("const1", CONSTANTOP, 0, 1); - const auto &add1 = sub_builder.AddNode("add1", "Add", 2, 1); - const auto &cast1 = sub_builder.AddNode("cast1", "Cast", 1, 1); - const auto &func1 = sub_builder.AddNode("func1", PARTITIONEDCALL, 2, 1); - const auto &data2 = sub_builder.AddNode("data2", DATA, 1, 1); - const auto &data3 = sub_builder.AddNode("data3", DATA, 1, 1); - const auto &const2 = sub_builder.AddNode("const2", CONSTANTOP, 0, 1); - const auto &add2 = sub_builder.AddNode("add2", "Add", 2, 1); - const auto &cast2 = sub_builder.AddNode("cast2", "Cast", 1, 1); - const auto &func2 = sub_builder.AddNode("func2", PARTITIONEDCALL, 2, 1); - const auto &data4 = sub_builder.AddNode("data4", DATA, 1, 1); - const auto &mul = sub_builder.AddNode("mul", "Mul", 2, 1); - const auto &netoutput0 = sub_builder.AddNode("netoutput0", NETOUTPUT, 1, 0); - sub_builder.AddDataEdge(data1, 0, add1, 0); - sub_builder.AddDataEdge(const1, 0, add1, 1); - sub_builder.AddDataEdge(add1, 0, cast1, 0); - sub_builder.AddDataEdge(cast1, 0, func1, 0); - sub_builder.AddDataEdge(data2, 0, func1, 1); - sub_builder.AddDataEdge(data3, 0, add2, 0); - sub_builder.AddDataEdge(const2, 0, add2, 1); - sub_builder.AddDataEdge(add2, 0, cast2, 0); - sub_builder.AddDataEdge(cast2, 0, func2, 0); - sub_builder.AddDataEdge(data4, 0, func2, 1); - sub_builder.AddDataEdge(func1, 0, mul, 0); - sub_builder.AddDataEdge(func2, 0, mul, 1); - sub_builder.AddDataEdge(mul, 0, netoutput0, 0); - - subgraph = sub_builder.GetGraph(); - subgraph->SetGraphUnknownFlag(true); - AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); - AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1); - AttrUtils::SetInt(data3->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 2); - AttrUtils::SetInt(data4->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 3); - AttrUtils::SetInt(netoutput0->GetOpDesc()->MutableInputDesc(0), ATTR_NAME_PARENT_NODE_INDEX, 0); - func->GetOpDesc()->AddSubgraphName("f"); - func->GetOpDesc()->SetSubgraphInstanceName(0, subgraph->GetName()); - graph->AddSubGraph(subgraph); - subgraph->SetParentNode(func); - subgraph->SetParentGraph(graph); - - auto sub_sub_builder1 = ut::GraphBuilder("sub_sub1"); - const auto &data5 = sub_sub_builder1.AddNode("data5", DATA, 1, 1); - const auto &data6 = sub_sub_builder1.AddNode("data6", DATA, 1, 1); - const auto &square1 = sub_sub_builder1.AddNode("square1", "Square", 1, 1); - const auto &less1 = sub_sub_builder1.AddNode("less1", "Less", 2, 1); - const auto &netoutput1 = sub_sub_builder1.AddNode("netoutput1", NETOUTPUT, 1, 0); - sub_sub_builder1.AddDataEdge(data5, 0, square1, 0); - sub_sub_builder1.AddDataEdge(square1, 0, less1, 0); - sub_sub_builder1.AddDataEdge(data6, 0, less1, 1); - sub_sub_builder1.AddDataEdge(less1, 0, netoutput1, 0); - - const auto &sub_subgraph1 = sub_sub_builder1.GetGraph(); - sub_subgraph1->SetGraphUnknownFlag(true); - AttrUtils::SetInt(data5->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); - AttrUtils::SetInt(data6->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1); - AttrUtils::SetInt(netoutput1->GetOpDesc()->MutableInputDesc(0), ATTR_NAME_PARENT_NODE_INDEX, 0); - func1->GetOpDesc()->AddSubgraphName("f"); - func1->GetOpDesc()->SetSubgraphInstanceName(0, sub_subgraph1->GetName()); - graph->AddSubGraph(sub_subgraph1); - sub_subgraph1->SetParentNode(func1); - sub_subgraph1->SetParentGraph(subgraph); - - auto sub_sub_builder2 = ut::GraphBuilder("sub_sub2"); - const auto &data7 = sub_sub_builder2.AddNode("data7", DATA, 1, 1); - const auto &data8 = sub_sub_builder2.AddNode("data8", DATA, 1, 1); - const auto &square2 = sub_sub_builder2.AddNode("square2", "Square", 1, 1); - const auto &less2 = sub_sub_builder2.AddNode("less2", "Less", 2, 1); - const auto &netoutput2 = sub_sub_builder2.AddNode("netoutput2", NETOUTPUT, 1, 0); - sub_sub_builder2.AddDataEdge(data7, 0, square2, 0); - sub_sub_builder2.AddDataEdge(square2, 0, less2, 0); - sub_sub_builder2.AddDataEdge(data8, 0, less2, 1); - sub_sub_builder2.AddDataEdge(less2, 0, netoutput2, 0); - - const auto &sub_subgraph2 = sub_sub_builder2.GetGraph(); - sub_subgraph2->SetGraphUnknownFlag(false); - AttrUtils::SetInt(data7->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); - AttrUtils::SetInt(data8->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1); - AttrUtils::SetInt(netoutput2->GetOpDesc()->MutableInputDesc(0), ATTR_NAME_PARENT_NODE_INDEX, 0); - func2->GetOpDesc()->AddSubgraphName("f"); - func2->GetOpDesc()->SetSubgraphInstanceName(0, sub_subgraph2->GetName()); - graph->AddSubGraph(sub_subgraph2); - sub_subgraph2->SetParentNode(func2); - sub_subgraph2->SetParentGraph(subgraph); - - return; -} -/* -------------- - * | | - * data1 const1 data2 const2 | - * | \ / \ / | - * | add1 add2 | - * | | | | - * | cast1 cast2 | - * | | | | - * | | | | - * | \ / | - * \ ------ mul ------------------ - * \ | - * \ | - * \ | - * ------- netoutput - */ -void BuildGraphForUnfoldWithControlEdge(ComputeGraphPtr &graph, ComputeGraphPtr &subgraph) { - auto builder = ut::GraphBuilder("root"); - const auto &input1 = builder.AddNode("data1", DATA, 1, 1); - const auto &input2 = builder.AddNode("data2", DATA, 1, 1); - const auto &func = builder.AddNode("func", PARTITIONEDCALL, 4, 1); - const auto &netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); - builder.AddDataEdge(input1, 0, func, 0); - builder.AddDataEdge(input2, 0, func, 1); - builder.AddDataEdge(func, 0, netoutput, 0); - - graph = builder.GetGraph(); - - auto sub_builder = ut::GraphBuilder("sub"); - const auto &data1 = sub_builder.AddNode("data1", DATA, 1, 1); - const auto &const1 = sub_builder.AddNode("const1", CONSTANTOP, 0, 1); - const auto &add1 = sub_builder.AddNode("add1", "Add", 2, 1); - const auto &cast1 = sub_builder.AddNode("cast1", "Cast", 1, 1); - const auto &data2 = sub_builder.AddNode("data2", DATA, 1, 1); - const auto &const2 = sub_builder.AddNode("const2", CONSTANTOP, 0, 1); - const auto &add2 = sub_builder.AddNode("add2", "Add", 2, 1); - const auto &cast2 = sub_builder.AddNode("cast2", "Cast", 1, 1); - const auto &mul = sub_builder.AddNode("mul", "Mul", 2, 1); - const auto &netoutput0 = sub_builder.AddNode("netoutput0", NETOUTPUT, 1, 0); - sub_builder.AddDataEdge(data1, 0, add1, 0); - sub_builder.AddControlEdge(data1, netoutput0); - sub_builder.AddDataEdge(const1, 0, add1, 1); - sub_builder.AddDataEdge(add1, 0, cast1, 0); - sub_builder.AddDataEdge(cast1, 0, mul, 0); - sub_builder.AddControlEdge(data2, mul); - sub_builder.AddDataEdge(data2, 0, add2, 0); - sub_builder.AddDataEdge(const2, 0, add2, 1); - sub_builder.AddDataEdge(add2, 0, cast2, 0); - sub_builder.AddDataEdge(cast2, 0, mul, 1); - sub_builder.AddDataEdge(mul, 0, netoutput0, 0); - - subgraph = sub_builder.GetGraph(); - subgraph->SetGraphUnknownFlag(true); - AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); - AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1); - AttrUtils::SetInt(netoutput0->GetOpDesc()->MutableInputDesc(0), ATTR_NAME_PARENT_NODE_INDEX, 0); - func->GetOpDesc()->AddSubgraphName("f"); - func->GetOpDesc()->SetSubgraphInstanceName(0, subgraph->GetName()); - graph->AddSubGraph(subgraph); - subgraph->SetParentNode(func); - subgraph->SetParentGraph(graph); - return; -} - -void BuildGraphWithPlaceholderAndEnd(ComputeGraphPtr &graph) { - auto builder = ut::GraphBuilder("root"); - const auto &input1 = builder.AddNode("pld1", PLACEHOLDER, 1, 1); - const auto &input2 = builder.AddNode("pld2", PLACEHOLDER, 1, 1); - const auto &data1 = builder.AddNode("data1", DATA, 1, 1); - const auto &data2 = builder.AddNode("data2", DATA, 1, 1); - const auto &end = builder.AddNode("end", END, 1, 1); - const auto &add1 = builder.AddNode("add1", "Add", 2, 1); - const auto &add2 = builder.AddNode("add2", "Add", 2, 1); - const auto &add3 = builder.AddNode("add3", "Add", 2, 1); - builder.AddDataEdge(input1, 0, add1, 0); - builder.AddDataEdge(input2, 0, add1, 1); - builder.AddDataEdge(data1, 0, add2, 0); - builder.AddDataEdge(data2, 0, add2, 1); - builder.AddDataEdge(add1, 0, add3, 0); - builder.AddDataEdge(add2, 0, add3, 1); - builder.AddDataEdge(add3, 0, end, 0); - graph = builder.GetGraph(); - graph->AddOutputNode(end); -} - -ComputeGraphPtr BuildGraphWithSubGraph() { - auto root_builder = ut::GraphBuilder("root"); - const auto &data0 = root_builder.AddNode("data0", "Data", 1, 1); - const auto &case0 = root_builder.AddNode("case0", "Case", 1, 1); - const auto &relu0 = root_builder.AddNode("relu0", "Relu", 1, 1); - const auto &relu1 = root_builder.AddNode("relu1", "Relu", 1, 1); - const auto &netoutput = root_builder.AddNode("netoutput", "NetOutput", 1, 1); - const auto &root_graph = root_builder.GetGraph(); - root_builder.AddDataEdge(data0, 0, case0, 0); - root_builder.AddDataEdge(case0, 0, relu0, 0); - root_builder.AddDataEdge(relu0, 0, relu1, 0); - root_builder.AddDataEdge(relu1, 0, netoutput, 0); - - auto sub_builder1 = ut::GraphBuilder("sub1"); - const auto &data1 = sub_builder1.AddNode("data1", "Data", 0, 1); - const auto &sub_graph1 = sub_builder1.GetGraph(); - root_graph->AddSubGraph(sub_graph1); - sub_graph1->SetParentNode(case0); - sub_graph1->SetParentGraph(root_graph); - case0->GetOpDesc()->AddSubgraphName("branch1"); - case0->GetOpDesc()->SetSubgraphInstanceName(0, "sub1"); - - auto sub_builder2 = ut::GraphBuilder("sub2"); - const auto &data2 = sub_builder2.AddNode("data2", "Data", 0, 1); - const auto &sub_graph2 = sub_builder2.GetGraph(); - root_graph->AddSubGraph(sub_graph2); - sub_graph2->SetParentNode(case0); - sub_graph2->SetParentGraph(root_graph); - case0->GetOpDesc()->AddSubgraphName("branch2"); - case0->GetOpDesc()->SetSubgraphInstanceName(1, "sub2"); - root_graph->TopologicalSorting(); - return root_graph; -} - -ComputeGraphPtr BuildGraphWithConst() { - auto ge_tensor = std::make_shared(); - uint8_t data_buf[4096] = {0}; - data_buf[0] = 7; - data_buf[10] = 8; - ge_tensor->SetData(data_buf, 4096); - - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node = builder.AddNode("Data", "Data", 0, 1); - auto const_node = builder.AddNode("Const", "Const", 0, 1); - AttrUtils::SetTensor(const_node->GetOpDesc(), ge::ATTR_NAME_WEIGHTS, ge_tensor); - AttrUtils::SetStr(const_node->GetOpDesc(), "fake_attr_name", "fake_attr_value"); - auto add_node = builder.AddNode("Add", "Add", 2, 1); - AttrUtils::SetStr(add_node->GetOpDesc(), "fake_attr_name", "fake_attr_value"); - AttrUtils::SetStr(add_node->GetOpDesc(), ge::ATTR_NAME_WEIGHTS, "fake_attr_value"); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - builder.AddDataEdge(data_node, 0, add_node, 0); - builder.AddDataEdge(const_node, 0, add_node, 1); - builder.AddDataEdge(add_node, 0, netoutput, 0); - return builder.GetGraph(); -} - -std::string GetSpecificFilePath(const std::string &file_path, const string &suffix) { - DIR *dir; - struct dirent *ent; - dir = opendir(file_path.c_str()); - if (dir == nullptr) { - return ""; - } - while ((ent = readdir(dir)) != nullptr) { - if (strstr(ent->d_name, suffix.c_str()) != nullptr) { - std::string d_name(ent->d_name); - closedir(dir); - return file_path + "/" + d_name; - } - } - closedir(dir); - return ""; -} - -ComputeGraphPtr BuildGraphForIsolateNode(const int in_data_num, const int out_data_num, const int in_ctrl_num, - const int out_ctrl_num) { - auto graph_builder = ut::GraphBuilder("graph"); - - const auto &del_node = graph_builder.AddNode("del_node", "DelNode", in_data_num, out_data_num); - - for (int i = 0; i < in_data_num; ++i) { - const auto &n = graph_builder.AddNode("in_node_" + std::to_string(i), "InNode", 1, 1); - graph_builder.AddDataEdge(n, 0, del_node, i); - } - for (int i = 0; i < out_data_num; ++i) { - const auto &n = graph_builder.AddNode("out_node_" + std::to_string(i), "OutNode", 1, 1); - graph_builder.AddDataEdge(del_node, i, n, 0); - } - - for (int i = 0; i < in_ctrl_num; ++i) { - const auto &n = graph_builder.AddNode("in_ctrl_node_" + std::to_string(i), "InCtrlNode", 1, 1); - graph_builder.AddControlEdge(n, del_node); - } - for (int i = 0; i < out_ctrl_num; ++i) { - const auto &n = graph_builder.AddNode("out_ctrl_node_" + std::to_string(i), "OutCtrlNode", 1, 1); - graph_builder.AddControlEdge(del_node, n); - } - return graph_builder.GetGraph(); -} -} // namespace - -namespace { -class UtestComputeGraphBuilder : public ComputeGraphBuilder { - public: - virtual ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) { - auto graph = std::make_shared("test"); - auto op_desc = std::make_shared("node", "node"); - NodePtr node = graph->AddNode(op_desc); - std::map node_names_; - node_names_.insert(pair("node", node)); - return graph; - } - - NodePtr GetNode(const std::string &name); - std::vector GetAllNodes(); - void BuildNodes(graphStatus &error_code, std::string &error_msg); -}; - -NodePtr UtestComputeGraphBuilder::GetNode(const std::string &name) { - return ComputeGraphBuilder::GetNode(name); -} - -std::vector UtestComputeGraphBuilder::GetAllNodes() { - return ComputeGraphBuilder::GetAllNodes(); -} - -void UtestComputeGraphBuilder::BuildNodes(graphStatus &error_code, std::string &error_msg) { - return ComputeGraphBuilder::BuildNodes(error_code, error_msg); -} - -} // namespace - -class UtestGraphUtils : public testing::Test { - protected: - void SetUp() {} - - void TearDown() { - } -}; - -TEST_F(UtestGraphUtils, DumpGEGraphUserGraphNameNull_AscendWorkPathNotNull) { - auto graph = BuildGraphWithConst(); - std::string ascend_work_path = "./test_ge_graph_path"; - mmSetEnv("DUMP_GE_GRAPH", "1", 1); - mmSetEnv("DUMP_GRAPH_PATH", ascend_work_path.c_str(), 1); - GraphUtils::DumpGEGraph(graph, "", true, ""); - ComputeGraphPtr com_graph = std::make_shared("GeTestGraph"); - - // test load - std::stringstream dump_file_path = GetFilePathWhenDumpPathSet(ascend_work_path); - std::string dump_graph_path = GetSpecificFilePath(ge::RealPath(dump_file_path.str().c_str()), "_.txt"); - auto state = GraphUtils::LoadGEGraph(dump_graph_path.c_str(), *com_graph); - ASSERT_EQ(state, true); - unsetenv("DUMP_GRAPH_PATH"); - unsetenv("DUMP_GE_GRAPH"); - system(("rm -rf " + ascend_work_path).c_str()); -} - -TEST_F(UtestGraphUtils, DumpGEGraphToOnnxNotAlways) { - unsetenv("DUMP_GRAPH_PATH"); - const char_t *const kDumpGraphLevel = "DUMP_GRAPH_LEVEL"; - (void) setenv(kDumpGraphLevel, "4", 1); - ComputeGraph compute_graph("test_graph0"); - compute_graph.SetGraphID(0); - const std::string suffit = "always_dump"; - ge::GraphUtils::DumpGEGraphToOnnx(compute_graph, suffit); - unsetenv(kDumpGraphLevel); - - // test NOT existed dir - ComputeGraphPtr com_graph1 = std::make_shared("GeTestGraph1"); - onnx::ModelProto model_proto; - ASSERT_EQ(model_proto.ByteSize(), 0); - // static thing, so follow DumpGEGraphUserGraphNameNull_AscendWorkPathNotNull this case path - std::string ascend_work_path = "./test_ge_graph_path"; - std::stringstream dump_file_path = GetFilePathWhenDumpPathSet(ascend_work_path); - std::string dump_graph_path = GetSpecificFilePath(ge::RealPath(dump_file_path.str().c_str()), suffit); - bool state = GraphUtils::ReadProtoFromTextFile(dump_graph_path.c_str(), &model_proto); - ASSERT_EQ(state, false); - ASSERT_EQ(model_proto.ByteSize(), 0); -} - -TEST_F(UtestGraphUtils, DumpGEGraphToOnnxAlways) { - const char_t *const kDumpGraphLevel = "DUMP_GRAPH_LEVEL"; - (void)setenv(kDumpGraphLevel, "4", 1); - ComputeGraph compute_graph("test_graph0"); - compute_graph.SetGraphID(0); - const std::string suffit = "always_dump"; - ge::GraphUtils::DumpGEGraphToOnnx(compute_graph, suffit, true); - unsetenv(kDumpGraphLevel); - - // test existed dir - ComputeGraphPtr com_graph1 = std::make_shared("GeTestGraph1"); - onnx::ModelProto model_proto; - ASSERT_EQ(model_proto.ByteSize(), 0); - // static thing, so follow DumpGEGraphUserGraphNameNull_AscendWorkPathNotNull this case path - std::string ascend_work_path = "./test_ge_graph_path"; - std::stringstream dump_file_path = GetFilePathWhenDumpPathSet(ascend_work_path); - std::string dump_graph_path = GetSpecificFilePath(ge::RealPath(dump_file_path.str().c_str()), suffit); - bool state = GraphUtils::ReadProtoFromTextFile(dump_graph_path.c_str(), &model_proto); - ASSERT_EQ(state, true); - ASSERT_NE(model_proto.ByteSize(), 0); - system(("rm -rf " + ascend_work_path).c_str()); -} - -TEST_F(UtestGraphUtils, DumpGEGraphToOnnxByPath) { - const char_t *const kDumpGraphLevel = "DUMP_GRAPH_LEVEL"; - (void) setenv(kDumpGraphLevel, "4", 1); - ComputeGraph compute_graph("test_graph0"); - compute_graph.SetGraphID(0); - const std::string suffix = "aaa/bbb.ccc\\ddd"; - std::string dump_file_path = "/root/"; - ge::GraphUtils::DumpGrphToOnnx(compute_graph, dump_file_path, suffix); - dump_file_path = "/tmp/"; - ge::GraphUtils::DumpGrphToOnnx(compute_graph, dump_file_path, suffix); - unsetenv(kDumpGraphLevel); - - // test existed dir - ComputeGraphPtr com_graph1 = std::make_shared("GeTestGraph1"); - onnx::ModelProto model_proto; - ASSERT_EQ(model_proto.ByteSize(), 0); - const std::string safe_suffix = "aaa_bbb.ccc_ddd"; - std::string dump_graph_path = GetSpecificFilePath(dump_file_path, safe_suffix); - bool state = GraphUtils::ReadProtoFromTextFile(dump_graph_path.c_str(), &model_proto); - ASSERT_EQ(state, true); - ASSERT_NE(model_proto.ByteSize(), 0); - system(("rm -f " + dump_graph_path).c_str()); -} - -TEST_F(UtestGraphUtils, DumpGEGraphWithDumpGEGraphInvalid) { - auto graph = BuildGraphWithConst(); - std::string ascend_work_path = "./test_ge_graph_path"; - mmSetEnv("DUMP_GE_GRAPH", "0", 1); - mmSetEnv("DUMP_GRAPH_PATH", ascend_work_path.c_str(), 1); - GraphUtils::DumpGEGraph(graph, "", false, ""); - ComputeGraphPtr com_graph = std::make_shared("GeTestGraph"); - - // test load - std::stringstream dump_file_path = GetFilePathWhenDumpPathSet(ascend_work_path); - std::string dump_graph_path = ge::RealPath(dump_file_path.str().c_str()); - auto state = GraphUtils::LoadGEGraph(GetSpecificFilePath(dump_graph_path, "_.txt").c_str(), *com_graph); - ASSERT_EQ(state, false); - unsetenv("DUMP_GRAPH_PATH"); - unsetenv("DUMP_GE_GRAPH"); - system(("rm -rf " + ascend_work_path).c_str()); -} - -/* -* var var -* atomicclean | \ | \ -* \\ | assign | assign -* \\ | // =======> | // -* allreduce identity atomicclean -* | | // -* netoutput allreduce -* | -* netoutput - */ -TEST_F(UtestGraphUtils, InsertNodeBefore_DoNotMoveCtrlEdgeFromAtomicClean) { - // build test graph - auto builder = ut::GraphBuilder("test"); - const auto &var = builder.AddNode("var", VARIABLE, 0, 1); - const auto &assign = builder.AddNode("assign", "Assign", 1, 1); - const auto &allreduce = builder.AddNode("allreduce", "HcomAllReduce", 1, 1); - const auto &atomic_clean = builder.AddNode("atomic_clean", ATOMICADDRCLEAN, 0, 0); - const auto &netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0); - const auto &identity = builder.AddNode("identity", "Identity", 1, 1); - - builder.AddDataEdge(var, 0, assign, 0); - builder.AddDataEdge(var,0,allreduce,0); - builder.AddControlEdge(assign, allreduce); - builder.AddControlEdge(atomic_clean, allreduce); - auto graph = builder.GetGraph(); - - // insert identity before allreduce - GraphUtils::InsertNodeBefore(allreduce->GetInDataAnchor(0), identity, 0, 0); - - // check assign control-in on identity - ASSERT_EQ(identity->GetInControlNodes().at(0)->GetName(), "assign"); - // check atomicclean control-in still on allreuce - ASSERT_EQ(allreduce->GetInControlNodes().at(0)->GetName(), "atomic_clean"); -} - -TEST_F(UtestGraphUtils, GetSubgraphs) { - auto root_builder = ut::GraphBuilder("root"); - const auto &case0 = root_builder.AddNode("case0", "Case", 0, 0); - const auto &root_graph = root_builder.GetGraph(); - - auto sub_builder1 = ut::GraphBuilder("sub1"); - const auto &case1 = sub_builder1.AddNode("case1", "Case", 0, 0); - const auto &sub_graph1 = sub_builder1.GetGraph(); - root_graph->AddSubGraph(sub_graph1); - sub_graph1->SetParentNode(case0); - sub_graph1->SetParentGraph(root_graph); - case0->GetOpDesc()->AddSubgraphName("branch1"); - case0->GetOpDesc()->SetSubgraphInstanceName(0, "sub1"); - - auto sub_builder2 = ut::GraphBuilder("sub2"); - const auto &sub_graph2 = sub_builder2.GetGraph(); - root_graph->AddSubGraph(sub_graph2); - sub_graph2->SetParentNode(case1); - sub_graph2->SetParentGraph(sub_graph1); - case1->GetOpDesc()->AddSubgraphName("branch1"); - case1->GetOpDesc()->SetSubgraphInstanceName(0, "sub2"); - case1->GetOpDesc()->AddSubgraphName("branch2"); - case1->GetOpDesc()->SetSubgraphInstanceName(1, "not_exist"); - - std::vector subgraphs1; - ASSERT_EQ(GraphUtils::GetSubgraphsRecursively(root_graph, subgraphs1), GRAPH_SUCCESS); - ASSERT_EQ(subgraphs1.size(), 2); - - std::vector subgraphs2; - ASSERT_EQ(GraphUtils::GetSubgraphsRecursively(sub_graph1, subgraphs2), GRAPH_SUCCESS); - ASSERT_EQ(subgraphs2.size(), 1); - - std::vector subgraphs3; - ASSERT_EQ(GraphUtils::GetSubgraphsRecursively(sub_graph2, subgraphs3), GRAPH_SUCCESS); - ASSERT_TRUE(subgraphs3.empty()); -} - -TEST_F(UtestGraphUtils, GetSubgraphs_nullptr_graph) { - std::vector subgraphs; - ASSERT_NE(GraphUtils::GetSubgraphsRecursively(nullptr, subgraphs), GRAPH_SUCCESS); - ASSERT_TRUE(subgraphs.empty()); -} - -TEST_F(UtestGraphUtils, ReplaceEdgeSrc) { - auto builder = ut::GraphBuilder("root"); - const auto &node0 = builder.AddNode("node0", "node", 1, 1); - const auto &node1 = builder.AddNode("node1", "node", 1, 1); - const auto &node2 = builder.AddNode("node2", "node", 1, 1); - const auto &node3 = builder.AddNode("node3", "node", 1, 1); - builder.AddDataEdge(node0, 0, node2, 0); - ASSERT_EQ(GraphUtils::ReplaceEdgeSrc(node0->GetOutDataAnchor(0), node2->GetInDataAnchor(0), - node1->GetOutDataAnchor(0)), GRAPH_SUCCESS); - ASSERT_NE(GraphUtils::ReplaceEdgeSrc(node0->GetOutDataAnchor(0), node2->GetInDataAnchor(0), - node3->GetOutDataAnchor(0)), GRAPH_SUCCESS); - - builder.AddControlEdge(node0, node2); - ASSERT_EQ(GraphUtils::ReplaceEdgeSrc(node0->GetOutControlAnchor(), node2->GetInControlAnchor(), - node1->GetOutControlAnchor()), GRAPH_SUCCESS); - ASSERT_NE(GraphUtils::ReplaceEdgeSrc(node0->GetOutControlAnchor(), node2->GetInControlAnchor(), - node3->GetOutControlAnchor()), GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, ReplaceEdgeDst) { - auto builder = ut::GraphBuilder("root"); - const auto &node0 = builder.AddNode("node0", "node", 1, 1); - const auto &node1 = builder.AddNode("node1", "node", 1, 1); - const auto &node2 = builder.AddNode("node2", "node", 1, 1); - const auto &node3 = builder.AddNode("node3", "node", 1, 1); - builder.AddDataEdge(node0, 0, node2, 0); - ASSERT_EQ(GraphUtils::ReplaceEdgeDst(node0->GetOutDataAnchor(0), node2->GetInDataAnchor(0), - node1->GetInDataAnchor(0)), GRAPH_SUCCESS); - ASSERT_NE(GraphUtils::ReplaceEdgeDst(node0->GetOutDataAnchor(0), node2->GetInDataAnchor(0), - node3->GetInDataAnchor(0)), GRAPH_SUCCESS); - - builder.AddControlEdge(node0, node2); - ASSERT_EQ(GraphUtils::ReplaceEdgeDst(node0->GetOutControlAnchor(), node2->GetInControlAnchor(), - node1->GetInControlAnchor()), GRAPH_SUCCESS); - ASSERT_NE(GraphUtils::ReplaceEdgeDst(node0->GetOutControlAnchor(), node2->GetInControlAnchor(), - node3->GetInControlAnchor()), GRAPH_SUCCESS); -} - -/* - * data0 data1 - * \ /| - * add1 | data2 - * \ | /| - * add2 | data3 - * \ | /| - * add3 | data4 - * \ | / | \ - * add4 | cast1 - * \ | / | - * add5 | - * | \ | - * | cast2 - * | / - * netoutput - */ -TEST_F(UtestGraphUtils, BuildSubgraphWithNodes) { - auto builder = ut::GraphBuilder("root"); - const auto &data0 = builder.AddNode("data0", DATA, 1, 1); - const auto &data1 = builder.AddNode("data1", DATA, 1, 1); - const auto &data2 = builder.AddNode("data2", DATA, 1, 1); - const auto &data3 = builder.AddNode("data3", DATA, 1, 1); - const auto &data4 = builder.AddNode("data4", DATA, 1, 1); - - const auto &add1 = builder.AddNode("add1", "Add", 2, 1); - const auto &add2 = builder.AddNode("add2", "Add", 2, 1); - const auto &add3 = builder.AddNode("add3", "Add", 2, 1); - const auto &add4 = builder.AddNode("add4", "Add", 2, 1); - const auto &add5 = builder.AddNode("add5", "Add", 2, 1); - - const auto &cast1 = builder.AddNode("cast1", "Cast", 1, 1); - const auto &cast2 = builder.AddNode("cast2", "Cast", 1, 1); - const auto &netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); - - builder.AddDataEdge(data0, 0, add1, 0); - builder.AddDataEdge(data1, 0, add1, 1); - builder.AddDataEdge(add1, 0, add2, 0); - builder.AddDataEdge(data2, 0, add2, 1); - builder.AddDataEdge(add2, 0, add3, 0); - builder.AddDataEdge(data3, 0, add3, 1); - builder.AddDataEdge(add3, 0, add4, 0); - builder.AddDataEdge(data4, 0, add4, 1); - builder.AddDataEdge(data4, 0, cast1, 0); - builder.AddDataEdge(add4, 0, add5, 0); - builder.AddDataEdge(cast1, 0, add5, 1); - builder.AddDataEdge(add5, 0, cast2, 0); - builder.AddDataEdge(cast2, 0, netoutput, 0); - - builder.AddControlEdge(data1, add2); - builder.AddControlEdge(data2, add3); - builder.AddControlEdge(data3, add4); - builder.AddControlEdge(data4, add5); - builder.AddControlEdge(add5, netoutput); - builder.AddControlEdge(cast1, cast2); - - ASSERT_EQ(GraphUtils::BuildSubgraphWithNodes(nullptr, {}, "subgraph1"), nullptr); - - const auto &graph = builder.GetGraph(); - ASSERT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - ASSERT_EQ(GraphUtils::BuildSubgraphWithNodes(graph, {}, "subgraph1"), nullptr); - - std::set nodes = { data1, add2, add3, add4, add5, cast2 }; - ASSERT_EQ(GraphUtils::BuildSubgraphWithNodes(graph, nodes, "subgraph1"), nullptr); - - ASSERT_TRUE(AttrUtils::SetStr(graph, "_session_graph_id", "_session_graph_id")); - const auto &subgraph1 = GraphUtils::BuildSubgraphWithNodes(graph, nodes, "subgraph1"); - ASSERT_NE(subgraph1, nullptr); - ASSERT_EQ(subgraph1->GetParentGraph(), graph); - ASSERT_TRUE(subgraph1->HasAttr("_session_graph_id")); - ASSERT_FALSE(IfNodeExist(graph, [](const NodePtr &node) { return node->GetName() == "data1"; })); - ASSERT_FALSE(IfNodeExist(graph, [](const NodePtr &node) { return node->GetName() == "add2"; })); - ASSERT_FALSE(IfNodeExist(graph, [](const NodePtr &node) { return node->GetName() == "add3"; })); - ASSERT_FALSE(IfNodeExist(graph, [](const NodePtr &node) { return node->GetName() == "add4"; })); - ASSERT_FALSE(IfNodeExist(graph, [](const NodePtr &node) { return node->GetName() == "add5"; })); - ASSERT_FALSE(IfNodeExist(graph, [](const NodePtr &node) { return node->GetName() == "cast2"; })); - ASSERT_TRUE(IfNodeExist(graph, [](const NodePtr &node) { return node->GetType() == PARTITIONEDCALL; })); - ASSERT_EQ(graph->GetAllSubgraphs().size(), 1); - - ASSERT_EQ(GraphUtils::BuildSubgraphWithNodes(graph, {cast1}, "subgraph1"), nullptr); -} - -TEST_F(UtestGraphUtils, BuildSubgraphWithOnlyControlNodes) { - auto builder = ut::GraphBuilder("root"); - auto data1 = builder.AddNode("data1", DATA, 0, 1); - auto data2 = builder.AddNode("data2", DATA, 0, 1); - auto op1 = builder.AddNode("square1", "Square", 1, 1); - auto op2 = builder.AddNode("square2", "Square", 1, 1); - auto output = builder.AddNode("output", NETOUTPUT, 1, 0); - - builder.AddDataEdge(data1, 0, op1, 0); - builder.AddDataEdge(data2, 0, op2, 0); - builder.AddDataEdge(op2, 0, output, 0); - builder.AddControlEdge(op1, op2); - auto origin_graph = builder.GetGraph(); - ASSERT_TRUE(AttrUtils::SetStr(origin_graph, "_session_graph_id", "graph_id")); - - auto subgraph = GraphUtils::BuildSubgraphWithNodes(origin_graph, {op1}, "subgraph"); - ASSERT_NE(subgraph, nullptr); - auto subgraph_output = subgraph->FindFirstNodeMatchType(NETOUTPUT); - ASSERT_NE(subgraph_output, nullptr); - ASSERT_FALSE(subgraph_output->GetInControlNodes().empty()); - ASSERT_EQ((*subgraph_output->GetInControlNodes().begin())->GetType(), "Square"); -} - -TEST_F(UtestGraphUtils, BuildGraphFromNodes_case0) { - auto builder = ut::GraphBuilder("graph0"); - auto data1 = builder.AddNode("data1", DATA, 0, 1); - auto data2 = builder.AddNode("data2", DATA, 0, 1); - auto op1 = builder.AddNode("square1", "Square", 1, 1); - auto op2 = builder.AddNode("square2", "Square", 1, 1); - auto output = builder.AddNode("output", NETOUTPUT, 1, 0); - - builder.AddDataEdge(data1, 0, op1, 0); - builder.AddDataEdge(data2, 0, op2, 0); - builder.AddDataEdge(op2, 0, output, 0); - builder.AddControlEdge(op1, op2); - auto origin_graph = builder.GetGraph(); - - auto graph = GraphUtils::BuildGraphFromNodes({op1}, "graph1"); - ASSERT_NE(graph, nullptr); - auto graph_output = graph->FindFirstNodeMatchType(NETOUTPUT); - ASSERT_NE(graph_output, nullptr); - ASSERT_FALSE(graph_output->GetInControlNodes().empty()); - ASSERT_EQ((*graph_output->GetInControlNodes().begin())->GetType(), "Square"); -} - - -TEST_F(UtestGraphUtils, SingleOpScene) { - auto builder1 = ut::GraphBuilder("root"); - auto data1 = builder1.AddNode("data1", DATA, 0, 1); - auto graph1 = builder1.GetGraph(); - ASSERT_TRUE(AttrUtils::SetBool(graph1, ATTR_SINGLE_OP_SCENE, true)); - bool is_single_op = GraphUtils::IsSingleOpScene(graph1); - ASSERT_EQ(is_single_op, true); - - auto builder2 = ut::GraphBuilder("root"); - auto data2 = builder2.AddNode("data2", DATA, 0, 1); - AttrUtils::SetBool(data2->GetOpDesc(), ATTR_SINGLE_OP_SCENE, true); - auto graph2 = builder2.GetGraph(); - is_single_op = GraphUtils::IsSingleOpScene(graph2); - ASSERT_EQ(is_single_op, true); - - auto builder3 = ut::GraphBuilder("root"); - auto data3 = builder3.AddNode("data3", DATA, 0, 1); - auto graph3 = builder3.GetGraph(); - is_single_op = GraphUtils::IsSingleOpScene(graph3); - ASSERT_EQ(is_single_op, false); -} - -TEST_F(UtestGraphUtils, UnfoldSubgraph) { - ComputeGraphPtr graph; - ComputeGraphPtr subgraph; - BuildGraphForUnfold(graph, subgraph); - ASSERT_NE(graph, nullptr); - ASSERT_NE(subgraph, nullptr); - - const auto &filter = [](const ComputeGraphPtr &graph) { - const auto &parent_node = graph->GetParentNode(); - if (parent_node == nullptr || parent_node->GetOpDesc() == nullptr) { - return false; - } - if ((parent_node->GetType() != PARTITIONEDCALL) || - (parent_node->GetOpDesc()->GetSubgraphInstanceNames().size() != 1)) { - return false; - } - return graph->GetGraphUnknownFlag(); - }; - ASSERT_EQ(GraphUtils::UnfoldSubgraph(subgraph, filter), GRAPH_SUCCESS); - - ASSERT_EQ(graph->GetAllSubgraphs().size(), 1); - ASSERT_FALSE(graph->GetAllSubgraphs()[0]->GetGraphUnknownFlag()); -} - -TEST_F(UtestGraphUtils, UnfoldSubgraph_InnerDataHasOutControl) { - ComputeGraphPtr graph; - ComputeGraphPtr subgraph; - BuildGraphForUnfoldWithControlEdge(graph, subgraph); - ASSERT_NE(graph, nullptr); - ASSERT_NE(subgraph, nullptr); - - const auto &filter = [](const ComputeGraphPtr &graph) { - const auto &parent_node = graph->GetParentNode(); - if (parent_node == nullptr || parent_node->GetOpDesc() == nullptr) { - return false; - } - if (parent_node->GetType() == PARTITIONEDCALL) { - return true; - } - return false; - }; - ASSERT_EQ(GraphUtils::UnfoldSubgraph(subgraph, filter), GRAPH_SUCCESS); - ASSERT_EQ(graph->GetAllSubgraphs().size(), 0); - ASSERT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - -} - -TEST_F(UtestGraphUtils, UnfoldSubgraph_ForPartition) { - ComputeGraphPtr graph; - ComputeGraphPtr subgraph; - BuildGraphForUnfold(graph, subgraph); - ASSERT_NE(graph, nullptr); - ASSERT_NE(subgraph, nullptr); - std::vector inputs; - std::vector outputs; - const auto &new_graph = GraphUtils::CloneGraph(graph, "", inputs, outputs); - const auto &node_size_before_unfold = new_graph->GetDirectNode().size(); - const auto &filter = [](const ComputeGraphPtr &graph) { - const auto &parent_node = graph->GetParentNode(); - if (parent_node == nullptr || parent_node->GetOpDesc() == nullptr) { - return false; - } - if ((parent_node->GetType() != PARTITIONEDCALL) || - (parent_node->GetOpDesc()->GetSubgraphInstanceNames().size() != 1)) { - return false; - } - return graph->GetGraphUnknownFlag(); - }; - ASSERT_EQ(GraphUtils::UnfoldGraph(subgraph, new_graph, new_graph->FindNode(subgraph->GetParentNode()->GetName()), - filter), GRAPH_SUCCESS); - ASSERT_NE(node_size_before_unfold, new_graph->GetDirectNode().size()); -} - -TEST_F(UtestGraphUtils, GetIndependentCompileGraphs) { - auto root_builder = ut::GraphBuilder("root"); - const auto &partitioned_call0 = root_builder.AddNode("PartitionedCall", "PartitionedCall", 0, 0); - const auto &root_graph = root_builder.GetGraph(); - (void)AttrUtils::SetBool(*root_graph, ATTR_NAME_PIPELINE_PARTITIONED, true); - - auto sub_builder1 = ut::GraphBuilder("sub1"); - const auto &data1 = sub_builder1.AddNode("Data", "Data", 0, 0); - const auto &sub_graph1 = sub_builder1.GetGraph(); - root_graph->AddSubGraph(sub_graph1); - sub_graph1->SetParentNode(partitioned_call0); - sub_graph1->SetParentGraph(root_graph); - partitioned_call0->GetOpDesc()->AddSubgraphName("sub1"); - partitioned_call0->GetOpDesc()->SetSubgraphInstanceName(0, "sub1"); - - std::vector independent_compile_subgraphs; - ASSERT_EQ(GraphUtils::GetIndependentCompileGraphs(root_graph, independent_compile_subgraphs), GRAPH_SUCCESS); - ASSERT_EQ(independent_compile_subgraphs.size(), 1); - ASSERT_EQ(independent_compile_subgraphs[0]->GetName(), "sub1"); - - (void)AttrUtils::SetBool(*root_graph, ATTR_NAME_PIPELINE_PARTITIONED, false); - independent_compile_subgraphs.clear(); - ASSERT_EQ(GraphUtils::GetIndependentCompileGraphs(root_graph, independent_compile_subgraphs), GRAPH_SUCCESS); - ASSERT_EQ(independent_compile_subgraphs.size(), 1); - ASSERT_EQ(independent_compile_subgraphs[0]->GetName(), "root"); -} - -TEST_F(UtestGraphUtils, InsertNodeAfter) { - auto graph_builder0 = ut::GraphBuilder("test_graph0"); - const auto &node0 = graph_builder0.AddNode("data0", DATA, 1, 1); - const auto &graph0 = graph_builder0.GetGraph(); - - auto graph_builder1 = ut::GraphBuilder("test_graph1"); - const auto &node1 = graph_builder1.AddNode("data1", DATA, 1, 1); - const auto &graph1 = graph_builder1.GetGraph(); - - std::vector independent_compile_subgraphs; - ASSERT_EQ(GraphUtils::InsertNodeAfter(node0->GetOutDataAnchor(0), {}, node1, 0, 0), GRAPH_FAILED); -} - TEST_F(UtestGraphUtils, NoNeedDumpGraphBySuffixIsFalse) { - std::string suffix; - bool ret = GraphUtils::NoNeedDumpGraphBySuffix(suffix); - EXPECT_EQ(ret, true); - } - -TEST_F(UtestGraphUtils, NoNeedDumpGraphBySuffixLevel0) { - const char_t *const kDumpGraphLevel = "DUMP_GRAPH_LEVEL"; - (void)setenv(kDumpGraphLevel, "0", 1); - EXPECT_EQ(GraphUtils::NoNeedDumpGraphBySuffix("Build"), true); - EXPECT_EQ(GraphUtils::NoNeedDumpGraphBySuffix("test"), true); - unsetenv(kDumpGraphLevel); -} - - TEST_F(UtestGraphUtils, NoNeedDumpGraphBySuffixLevel1) { - const char_t *const kDumpGraphLevel = "DUMP_GRAPH_LEVEL"; - (void)setenv(kDumpGraphLevel, "1", 1); - EXPECT_EQ(GraphUtils::NoNeedDumpGraphBySuffix(""), true); - EXPECT_EQ(GraphUtils::NoNeedDumpGraphBySuffix("Build"), false); - EXPECT_EQ(GraphUtils::NoNeedDumpGraphBySuffix("test"), false); - EXPECT_EQ(GraphUtils::NoNeedDumpGraphBySuffix("OptimizeSubGraph"), true); - unsetenv(kDumpGraphLevel); - } - - TEST_F(UtestGraphUtils, NoNeedDumpGraphBySuffixLevel2) { - const char_t *const kDumpGraphLevel = "DUMP_GRAPH_LEVEL"; - (void) setenv(kDumpGraphLevel, "2", 1); - for (const auto &graph_name : kGeDumpWhitelistFullName) { - EXPECT_EQ(GraphUtils::NoNeedDumpGraphBySuffix(graph_name), false); - } - for (const auto &graph_name : kGeDumpWhitelistKeyName) { - EXPECT_EQ(GraphUtils::NoNeedDumpGraphBySuffix("_" + graph_name), false); - } - EXPECT_EQ(GraphUtils::NoNeedDumpGraphBySuffix("PreRunAfterOptimizeGraphPrepare"), true); - EXPECT_EQ(GraphUtils::NoNeedDumpGraphBySuffix("RunCustomAfterInferShape_sub_graph"), true); - unsetenv(kDumpGraphLevel); - } - - TEST_F(UtestGraphUtils, NoNeedDumpGraphBySuffixLevel3) { - const char_t *const kDumpGraphLevel = "DUMP_GRAPH_LEVEL"; - (void)setenv(kDumpGraphLevel, "3", 1); - EXPECT_EQ(GraphUtils::NoNeedDumpGraphBySuffix("Build"), false); - EXPECT_EQ(GraphUtils::NoNeedDumpGraphBySuffix("test"), true); - unsetenv(kDumpGraphLevel); - } - - TEST_F(UtestGraphUtils, NoNeedDumpGraphBySuffixLevel4) { - const char_t *const kDumpGraphLevel = "DUMP_GRAPH_LEVEL"; - (void)setenv(kDumpGraphLevel, "4", 1); - EXPECT_EQ(GraphUtils::NoNeedDumpGraphBySuffix("PreRunBegin"), false); - EXPECT_EQ(GraphUtils::NoNeedDumpGraphBySuffix("test"), true); - unsetenv(kDumpGraphLevel); - } - -TEST_F(UtestGraphUtils, NoNeedDumpGraphBySuffixWhitelist) { - const char_t *const kDumpGraphLevel = "DUMP_GRAPH_LEVEL"; - (void)setenv(kDumpGraphLevel, "RunBegin", 1); - EXPECT_EQ(GraphUtils::NoNeedDumpGraphBySuffix("PreRunBegin"), false); - EXPECT_EQ(GraphUtils::NoNeedDumpGraphBySuffix("test"), true); - - (void)setenv(kDumpGraphLevel, "RunBegin|test", 1); - EXPECT_EQ(GraphUtils::NoNeedDumpGraphBySuffix("PreRunBegin"), false); - EXPECT_EQ(GraphUtils::NoNeedDumpGraphBySuffix("test"), false); - EXPECT_EQ(GraphUtils::NoNeedDumpGraphBySuffix("other"), true); - unsetenv(kDumpGraphLevel); -} - -TEST_F(UtestGraphUtils, DumpGEGraph_aclop_success) { - std::string dump_graph_path = "./test_ge_graph_path"; - mmSetEnv("DUMP_GE_GRAPH", "1", 1); - mmSetEnv("DUMP_GRAPH_LEVEL", "1", 1); - mmSetEnv("DUMP_GRAPH_PATH", dump_graph_path.c_str(), 1); - auto graph = BuildGraphWithSubGraph(); - AttrUtils::SetBool(graph, ATTR_SINGLE_OP_SCENE, true); - GraphUtils::DumpGEGraph(graph, "test"); - std::stringstream dump_file_path = GetFilePathWhenDumpPathSet(dump_graph_path); - std::string file_path = ge::RealPath(dump_file_path.str().c_str()); - // root graph - ComputeGraphPtr compute_graph = std::make_shared("GeTestGraph"); - ASSERT_EQ(GraphUtils::LoadGEGraph(GetSpecificFilePath(file_path, "_aclop_").c_str(), *compute_graph), true); - - unsetenv("DUMP_GE_GRAPH"); - unsetenv("DUMP_GRAPH_LEVEL"); - unsetenv("DUMP_GRAPH_PATH"); - system(("rm -rf " + dump_graph_path).c_str()); -} - -TEST_F(UtestGraphUtils, DumpGEGraph) { - auto ge_tensor = std::make_shared(); - uint8_t data_buf[4096] = {0}; - data_buf[0] = 7; - data_buf[10] = 8; - ge_tensor->SetData(data_buf, 4096); - - mmSetEnv("DUMP_GE_GRAPH", "1", 1); - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node = builder.AddNode("Data", "Data", 0, 1); - auto const_node = builder.AddNode("Const", "Const", 0, 1); - AttrUtils::SetTensor(const_node->GetOpDesc(), ge::ATTR_NAME_WEIGHTS, ge_tensor); - auto add_node = builder.AddNode("Add", "Add", 2, 1); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - builder.AddDataEdge(data_node, 0, add_node, 0); - builder.AddDataEdge(const_node, 0, add_node, 1); - builder.AddDataEdge(add_node, 0, netoutput, 0); - auto graph = builder.GetGraph(); - - // test existed dir - GraphUtils::DumpGEGraph(graph, "", true, "./ge_test_graph_0001.txt"); - ComputeGraphPtr com_graph1 = std::make_shared("GeTestGraph1"); - bool state = GraphUtils::LoadGEGraph("./ge_test_graph_0001.txt", *com_graph1); - ASSERT_EQ(state, true); - ASSERT_EQ(com_graph1->GetAllNodesSize(), 4); - - // test not existed dir - GraphUtils::DumpGEGraph(graph, "", true, "./test/ge_test_graph_0002.txt"); - ComputeGraphPtr com_graph2 = std::make_shared("GeTestGraph2"); - state = GraphUtils::LoadGEGraph("./test/ge_test_graph_0002.txt", *com_graph2); - ASSERT_EQ(state, true); - - // test input para user_graph_name, without path - GraphUtils::DumpGEGraph(graph, "", true, "ge_test_graph_0003.txt"); - ComputeGraphPtr com_graph3 = std::make_shared("GeTestGraph3"); - state = GraphUtils::LoadGEGraph("./ge_test_graph_0003.txt", *com_graph3); - ASSERT_EQ(state, true); - unsetenv("DUMP_GE_GRAPH"); -} -TEST_F(UtestGraphUtils, DumpGEGraphNoOptionsSucc) { - auto ge_tensor = std::make_shared(); - uint8_t data_buf[4096] = {0}; - data_buf[0] = 7; - data_buf[10] = 8; - ge_tensor->SetData(data_buf, 4096); - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node = builder.AddNode("Data", "Data", 0, 1); - auto const_node = builder.AddNode("Const", "Const", 0, 1); - AttrUtils::SetTensor(const_node->GetOpDesc(), ge::ATTR_NAME_WEIGHTS, ge_tensor); - auto add_node = builder.AddNode("Add", "Add", 2, 1); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - builder.AddDataEdge(data_node, 0, add_node, 0); - builder.AddDataEdge(const_node, 0, add_node, 1); - builder.AddDataEdge(add_node, 0, netoutput, 0); - auto graph = builder.GetGraph(); - - GEThreadLocalContext &context = GetThreadLocalContext(); - std::map graph_maps; - std::string key1 = "pk1"; - std::string value1 = "pv1"; - std::string key2 = "pk2"; - std::string value2 = "pv2"; - graph_maps.insert(std::make_pair(key1, value1)); - graph_maps.insert(std::make_pair(key2, value2)); - context.SetGraphOption(graph_maps); - - std::map session_maps; - key1 = "sk1"; - value1 = "sv1"; - key2 = "sk2"; - value2 = "sv2"; - session_maps.insert(std::make_pair(key1, value1)); - session_maps.insert(std::make_pair(key2, value2)); - context.SetSessionOption(session_maps); - - std::map global_maps; - key1 = "gk1"; - value1 = "gv1"; - key2 = "gk2"; - value2 = "gv2"; - global_maps.insert(std::make_pair(key1, value1)); - global_maps.insert(std::make_pair(key2, value2)); - context.SetGlobalOption(global_maps); - - const char_t *const kDumpGraphLevel = "DUMP_GRAPH_LEVEL"; - (void)setenv(kDumpGraphLevel, "4", 1); - const char_t *const kDumpGeGraph = "DUMP_GE_GRAPH"; - (void)setenv(kDumpGeGraph, "3", 1); - // test existed dir - system("rm -f ./ge_test_graph_options_wt_0001.txt"); - GraphUtils::DumpGEGraph(graph, "PreRunBegin", false, "./ge_test_graph_options_wt_0001.txt"); - ComputeGraphPtr com_graph1 = std::make_shared("GeTestGraph1"); - bool state = GraphUtils::LoadGEGraph("./ge_test_graph_options_wt_0001.txt", *com_graph1); - ASSERT_EQ(state, true); - ASSERT_EQ(com_graph1->GetAllNodesSize(), 4); - //check graph option - ge::NamedAttrs graphOptions; - EXPECT_EQ(AttrUtils::GetNamedAttrs(com_graph1, "GraphOptions", graphOptions),false); - ge::NamedAttrs sessionOptions; - EXPECT_EQ(AttrUtils::GetNamedAttrs(com_graph1, "SessionOptions", sessionOptions),false); - ge::NamedAttrs globalOptions; - EXPECT_EQ(AttrUtils::GetNamedAttrs(com_graph1, "GlobalOptions", globalOptions),false); - -} - -TEST_F(UtestGraphUtils, DumpGEGraphOptionsSucc) { - auto ge_tensor = std::make_shared(); - uint8_t data_buf[4096] = {0}; - data_buf[0] = 7; - data_buf[10] = 8; - ge_tensor->SetData(data_buf, 4096); - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node = builder.AddNode("Data", "Data", 0, 1); - auto const_node = builder.AddNode("Const", "Const", 0, 1); - AttrUtils::SetTensor(const_node->GetOpDesc(), ge::ATTR_NAME_WEIGHTS, ge_tensor); - auto add_node = builder.AddNode("Add", "Add", 2, 1); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - builder.AddDataEdge(data_node, 0, add_node, 0); - builder.AddDataEdge(const_node, 0, add_node, 1); - builder.AddDataEdge(add_node, 0, netoutput, 0); - auto graph = builder.GetGraph(); - - GEThreadLocalContext &context = GetThreadLocalContext(); - std::map graph_maps; - std::string key1 = "pk1"; - std::string value1 = "pv1"; - std::string key2 = "pk2"; - std::string value2 = "pv2"; - graph_maps.insert(std::make_pair(key1, value1)); - graph_maps.insert(std::make_pair(key2, value2)); - context.SetGraphOption(graph_maps); - - std::map session_maps; - key1 = "sk1"; - value1 = "sv1"; - key2 = "sk2"; - value2 = "sv2"; - session_maps.insert(std::make_pair(key1, value1)); - session_maps.insert(std::make_pair(key2, value2)); - context.SetSessionOption(session_maps); - - std::map global_maps; - key1 = "gk1"; - value1 = "gv1"; - key2 = "gk2"; - value2 = "gv2"; - global_maps.insert(std::make_pair(key1, value1)); - global_maps.insert(std::make_pair(key2, value2)); - context.SetGlobalOption(global_maps); - - const char_t *const kDumpGraphLevel = "DUMP_GRAPH_LEVEL"; - (void)setenv(kDumpGraphLevel, "4", 1); - const char_t *const kDumpGeGraph = "DUMP_GE_GRAPH"; - (void)setenv(kDumpGeGraph, "1", 1); - // test existed dir - system("rm -f ./ge_test_graph_options_wt_0002.txt"); - GraphUtils::DumpGEGraph(graph, "PreRunBegin", false, "./ge_test_graph_options_wt_0002.txt"); - ComputeGraphPtr com_graph1 = std::make_shared("GeTestGraph1"); - bool state = GraphUtils::LoadGEGraph("./ge_test_graph_options_wt_0002.txt", *com_graph1); - ASSERT_EQ(state, true); - ASSERT_EQ(com_graph1->GetAllNodesSize(), 4); - //check graph option - ge::NamedAttrs graphOptions; - EXPECT_TRUE(AttrUtils::GetNamedAttrs(com_graph1, "GraphOptions", graphOptions)); - EXPECT_EQ(graphOptions.GetName(), "GraphOptions"); - AnyValue av; - EXPECT_EQ(graphOptions.GetAttr("pk1", av), GRAPH_SUCCESS); - EXPECT_NE(av.Get(), nullptr); - EXPECT_EQ(*av.Get(), "pv1"); - EXPECT_EQ(graphOptions.GetAttr("pk2", av), GRAPH_SUCCESS); - EXPECT_NE(av.Get(), nullptr); - EXPECT_EQ(*av.Get(), "pv2"); - //check session option - ge::NamedAttrs sessionOptions; - EXPECT_TRUE(AttrUtils::GetNamedAttrs(com_graph1, "SessionOptions", sessionOptions)); - EXPECT_EQ(sessionOptions.GetName(), "SessionOptions"); - EXPECT_EQ(sessionOptions.GetAttr("sk1", av), GRAPH_SUCCESS); - EXPECT_NE(av.Get(), nullptr); - EXPECT_EQ(*av.Get(), "sv1"); - - EXPECT_EQ(sessionOptions.GetAttr("sk2", av), GRAPH_SUCCESS); - EXPECT_NE(av.Get(), nullptr); - EXPECT_EQ(*av.Get(), "sv2"); - //check global option - ge::NamedAttrs globalOptions; - EXPECT_TRUE(AttrUtils::GetNamedAttrs(com_graph1, "GlobalOptions", globalOptions)); - EXPECT_EQ(globalOptions.GetName(), "GlobalOptions"); - EXPECT_EQ(globalOptions.GetAttr("gk1", av), GRAPH_SUCCESS); - EXPECT_NE(av.Get(), nullptr); - EXPECT_EQ(*av.Get(), "gv1"); - - EXPECT_EQ(globalOptions.GetAttr("gk2", av), GRAPH_SUCCESS); - EXPECT_NE(av.Get(), nullptr); - EXPECT_EQ(*av.Get(), "gv2"); -} -TEST_F(UtestGraphUtils, DumpGEGraphOptionsLevelNot4) { - auto ge_tensor = std::make_shared(); - uint8_t data_buf[4096] = {0}; - data_buf[0] = 7; - data_buf[10] = 8; - ge_tensor->SetData(data_buf, 4096); - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node = builder.AddNode("Data", "Data", 0, 1); - auto const_node = builder.AddNode("Const", "Const", 0, 1); - AttrUtils::SetTensor(const_node->GetOpDesc(), ge::ATTR_NAME_WEIGHTS, ge_tensor); - auto add_node = builder.AddNode("Add", "Add", 2, 1); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - builder.AddDataEdge(data_node, 0, add_node, 0); - builder.AddDataEdge(const_node, 0, add_node, 1); - builder.AddDataEdge(add_node, 0, netoutput, 0); - auto graph = builder.GetGraph(); - - GEThreadLocalContext &context = GetThreadLocalContext(); - std::map graph_maps; - std::string key1 = "pk1"; - std::string value1 = "pv1"; - std::string key2 = "pk2"; - std::string value2 = "pv2"; - graph_maps.insert(std::make_pair(key1, value1)); - graph_maps.insert(std::make_pair(key2, value2)); - context.SetGraphOption(graph_maps); - - std::map session_maps; - key1 = "sk1"; - value1 = "sv1"; - key2 = "sk2"; - value2 = "sv2"; - session_maps.insert(std::make_pair(key1, value1)); - session_maps.insert(std::make_pair(key2, value2)); - context.SetSessionOption(session_maps); - - std::map global_maps; - key1 = "gk1"; - value1 = "gv1"; - key2 = "gk2"; - value2 = "gv2"; - global_maps.insert(std::make_pair(key1, value1)); - global_maps.insert(std::make_pair(key2, value2)); - context.SetGlobalOption(global_maps); - - const char_t *const kDumpGraphLevel = "DUMP_GRAPH_LEVEL"; - (void)setenv(kDumpGraphLevel, "1", 1); - const char_t *const kDumpGeGraph = "DUMP_GE_GRAPH"; - (void)setenv(kDumpGeGraph, "1", 1); - // test existed dir - system("rm -f ./ge_test_graph_options_wt_0004.txt"); - GraphUtils::DumpGEGraph(graph, "test", false, "./ge_test_graph_options_wt_0004.txt"); - ComputeGraphPtr com_graph1 = std::make_shared("GeTestGraph1"); - bool state = GraphUtils::LoadGEGraph("./ge_test_graph_options_wt_0004.txt", *com_graph1); - ASSERT_EQ(state, true); - ASSERT_EQ(com_graph1->GetAllNodesSize(), 4); - //check graph option - ge::NamedAttrs graphOptions; - EXPECT_TRUE(AttrUtils::GetNamedAttrs(com_graph1, "GraphOptions", graphOptions)); - EXPECT_EQ(graphOptions.GetName(), "GraphOptions"); - AnyValue av; - EXPECT_EQ(graphOptions.GetAttr("pk1", av), GRAPH_SUCCESS); - EXPECT_NE(av.Get(), nullptr); - EXPECT_EQ(*av.Get(), "pv1"); - EXPECT_EQ(graphOptions.GetAttr("pk2", av), GRAPH_SUCCESS); - EXPECT_NE(av.Get(), nullptr); - EXPECT_EQ(*av.Get(), "pv2"); - //check session option - ge::NamedAttrs sessionOptions; - EXPECT_TRUE(AttrUtils::GetNamedAttrs(com_graph1, "SessionOptions", sessionOptions)); - EXPECT_EQ(sessionOptions.GetName(), "SessionOptions"); - EXPECT_EQ(sessionOptions.GetAttr("sk1", av), GRAPH_SUCCESS); - EXPECT_NE(av.Get(), nullptr); - EXPECT_EQ(*av.Get(), "sv1"); - - EXPECT_EQ(sessionOptions.GetAttr("sk2", av), GRAPH_SUCCESS); - EXPECT_NE(av.Get(), nullptr); - EXPECT_EQ(*av.Get(), "sv2"); - //check global option - ge::NamedAttrs globalOptions; - EXPECT_TRUE(AttrUtils::GetNamedAttrs(com_graph1, "GlobalOptions", globalOptions)); - EXPECT_EQ(globalOptions.GetName(), "GlobalOptions"); - EXPECT_EQ(globalOptions.GetAttr("gk1", av), GRAPH_SUCCESS); - EXPECT_NE(av.Get(), nullptr); - EXPECT_EQ(*av.Get(), "gv1"); - - EXPECT_EQ(globalOptions.GetAttr("gk2", av), GRAPH_SUCCESS); - EXPECT_NE(av.Get(), nullptr); - EXPECT_EQ(*av.Get(), "gv2"); -} - -TEST_F(UtestGraphUtils, DumpGEGraphOptionsNotPreRunBeginNoDump) { - auto ge_tensor = std::make_shared(); - uint8_t data_buf[4096] = {0}; - data_buf[0] = 7; - data_buf[10] = 8; - ge_tensor->SetData(data_buf, 4096); - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node = builder.AddNode("Data", "Data", 0, 1); - auto const_node = builder.AddNode("Const", "Const", 0, 1); - AttrUtils::SetTensor(const_node->GetOpDesc(), ge::ATTR_NAME_WEIGHTS, ge_tensor); - auto add_node = builder.AddNode("Add", "Add", 2, 1); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - builder.AddDataEdge(data_node, 0, add_node, 0); - builder.AddDataEdge(const_node, 0, add_node, 1); - builder.AddDataEdge(add_node, 0, netoutput, 0); - auto graph = builder.GetGraph(); - - GEThreadLocalContext &context = GetThreadLocalContext(); - std::map graph_maps; - std::string key1 = "pk1"; - std::string value1 = "pv1"; - std::string key2 = "pk2"; - std::string value2 = "pv2"; - graph_maps.insert(std::make_pair(key1, value1)); - graph_maps.insert(std::make_pair(key2, value2)); - context.SetGraphOption(graph_maps); - - std::map session_maps; - key1 = "sk1"; - value1 = "sv1"; - key2 = "sk2"; - value2 = "sv2"; - session_maps.insert(std::make_pair(key1, value1)); - session_maps.insert(std::make_pair(key2, value2)); - context.SetSessionOption(session_maps); - - std::map global_maps; - key1 = "gk1"; - value1 = "gv1"; - key2 = "gk2"; - value2 = "gv2"; - global_maps.insert(std::make_pair(key1, value1)); - global_maps.insert(std::make_pair(key2, value2)); - context.SetGlobalOption(global_maps); - - const char_t *const kDumpGraphLevel = "DUMP_GRAPH_LEVEL"; - (void)setenv(kDumpGraphLevel, "4", 1); - const char_t *const kDumpGeGraph = "DUMP_GE_GRAPH"; - (void)setenv(kDumpGeGraph, "1", 1); - // test existed dir - system("rm -f ./ge_test_graph_options_wt_0003.txt"); - GraphUtils::DumpGEGraph(graph, "test", false, "./ge_test_graph_options_wt_0003.txt"); - ComputeGraphPtr com_graph1 = std::make_shared("GeTestGraph1"); - bool state = GraphUtils::LoadGEGraph("./ge_test_graph_options_wt_0003.txt", *com_graph1); - ASSERT_EQ(state, false); -} - -TEST_F(UtestGraphUtils, CheckDumpGraphNum) { - std::map session_option{{"ge.maxDumpFileNum", "100"}}; - GetThreadLocalContext().SetSessionOption(session_option); - auto graph_builder0 = ut::GraphBuilder("test_graph0"); - const auto &node0 = graph_builder0.AddNode("data0", DATA, 1, 1); - const auto &graph0 = graph_builder0.GetGraph(); - EXPECT_NO_THROW( - GraphUtils::DumpGEGrph(graph0, "./", "1"); - GraphUtils::DumpGEGrph(graph0, "./", "1"); - GraphUtils::DumpGEGrph(graph0, "./", "1"); - GraphUtils::DumpGEGrph(graph0, "./", "1"); - GraphUtils::DumpGEGrph(graph0, "./", "1"); - ); -} - -TEST_F(UtestGraphUtils, CopyRootComputeGraph) { - auto graph = BuildGraphWithSubGraph(); - // check origin graph size - ASSERT_EQ(graph->GetAllNodesSize(), 7); - ComputeGraphPtr dst_compute_graph = std::make_shared(ComputeGraph("dst")); - // test copy root graph success - auto ret = GraphUtils::CopyComputeGraph(graph, dst_compute_graph); - ASSERT_EQ(ret, GRAPH_SUCCESS); - ASSERT_EQ(dst_compute_graph->GetAllNodesSize(), 7); - // test copy subgraph failed - auto sub1_graph = graph->GetSubgraph("sub1"); - ret = GraphUtils::CopyComputeGraph(sub1_graph, dst_compute_graph); - ASSERT_EQ(ret, GRAPH_FAILED); - - // test copy dst compute_graph null - ComputeGraphPtr empty_dst_compute_graph; - ret = GraphUtils::CopyComputeGraph(graph, empty_dst_compute_graph); - ASSERT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, CopyComputeGraphWithNodeAndGraphFilter) { - auto graph = BuildGraphWithSubGraph(); - // check origin graph size - ASSERT_EQ(graph->GetAllNodesSize(), 5 + 1 + 1); - ComputeGraphPtr dst_compute_graph = std::make_shared(ComputeGraph("dst")); - auto node_filter = [&graph](const Node &node) { - // no filter node which not in root graph - if (node.GetOwnerComputeGraph()->GetName() != graph->GetName()) { - return true; - } - // filter root graph node when node name == "relu1" - if (node.GetName() == "relu1") { - return false; - } - // copy other nodes in root graph - return true; - }; - - auto graph_filter = [&graph](const Node &node, const char *, const ComputeGraphPtr &sub_graph) { - // sub2 graph not copy - return sub_graph->GetName() != "sub2"; - }; - // test copy root graph success - auto ret = GraphUtils::CopyComputeGraph(graph, node_filter, graph_filter, nullptr, dst_compute_graph); - ASSERT_EQ(ret, GRAPH_SUCCESS); - ASSERT_EQ(dst_compute_graph->GetAllNodesSize(), 4 + 1 + 0); - ASSERT_EQ(dst_compute_graph->GetDirectNodesSize(), 4); - ASSERT_EQ(dst_compute_graph->GetDirectNode().size(), 4); - ASSERT_EQ(dst_compute_graph->FindNode("relu1"), nullptr); - ASSERT_NE(dst_compute_graph->FindNode("relu0"), nullptr); - auto sub1_graph = dst_compute_graph->GetSubgraph("sub1"); - ASSERT_EQ(sub1_graph->GetDirectNodesSize(), 1); - ASSERT_NE(sub1_graph->GetDirectNode().at(0U), nullptr); - ASSERT_NE(sub1_graph->GetDirectNode().at(0U)->GetOpDesc(), nullptr); - ASSERT_EQ(sub1_graph->GetDirectNode().at(0U)->GetOpDesc()->GetId(), - graph->GetSubgraph("sub1")->GetDirectNode().at(0U)->GetOpDesc()->GetId()); - ASSERT_NE(sub1_graph, nullptr); - ASSERT_EQ(dst_compute_graph->GetSubgraph("sub2"), nullptr); -} - -TEST_F(UtestGraphUtils, CopyComputeGraphWithoutSubGraphRepeat) { - auto graph = BuildGraphWithSubGraph(); - ComputeGraphPtr dst_compute_graph = std::make_shared(ComputeGraph("dst")); - auto graph_filter = [](const Node &node, const char *, const ComputeGraphPtr &sub_graph) { - // all graph not copy - return false; - }; - // test copy root graph success - auto ret = GraphUtils::CopyComputeGraph(graph, nullptr, graph_filter, nullptr, dst_compute_graph); - ASSERT_EQ(ret, GRAPH_SUCCESS); - ASSERT_EQ(dst_compute_graph->GetDirectNodesSize(), graph->GetDirectNodesSize()); - NodePtr case0_node = dst_compute_graph->FindNode("case0"); - ASSERT_NE(case0_node, nullptr); - const auto &names = case0_node->GetOpDesc()->GetSubgraphInstanceNames(); - for (const auto &name:names) { - EXPECT_EQ(name, ""); - } - ASSERT_EQ(dst_compute_graph->GetSubgraph("sub1"), nullptr); - ASSERT_EQ(dst_compute_graph->GetSubgraph("sub2"), nullptr); - ComputeGraphPtr dst_compute_graph2 = std::make_shared(ComputeGraph("dst2")); - ret = GraphUtils::CopyComputeGraph(dst_compute_graph, nullptr, graph_filter, nullptr, dst_compute_graph2); - ASSERT_EQ(ret, GRAPH_SUCCESS); - ASSERT_EQ(dst_compute_graph2->GetDirectNodesSize(), graph->GetDirectNodesSize()); -} - -TEST_F(UtestGraphUtils, CopyComputeGraphWithAttrFilter) { - auto graph = BuildGraphWithConst(); - ComputeGraphPtr dst_compute_graph = std::make_shared(ComputeGraph("dst")); - const std::string const_node_name = "Const"; - auto const_node_with_weight_src = graph->FindNode(const_node_name); - auto attr_filter = [&const_node_name](const OpDesc &op_desc, const std::string &attr_name) { - // keep all attr for nodes which is not `const` - if (op_desc.GetName() != const_node_name) { - return true; - } - // const node not copy weights - if (attr_name == ge::ATTR_NAME_WEIGHTS) { - return false; - } - return true; - }; - - // test copy graph success with attr filter - ASSERT_EQ(GraphUtils::CopyComputeGraph(graph, nullptr, nullptr, attr_filter, dst_compute_graph), GRAPH_SUCCESS); - auto const_node_with_weight_dst = dst_compute_graph->FindNode(const_node_name); - ASSERT_NE(const_node_with_weight_dst, nullptr); - ASSERT_NE(const_node_with_weight_src, const_node_with_weight_dst); - // src node keep origin weight - ConstGeTensorPtr weight = nullptr; - ASSERT_TRUE(AttrUtils::GetTensor(const_node_with_weight_src->GetOpDesc(), ATTR_NAME_WEIGHTS, weight)); - ASSERT_NE(weight, nullptr); - ASSERT_EQ(weight->GetData().GetSize(), 4096U); - const uint8_t *buff = weight->GetData().GetData(); - ASSERT_EQ((buff == nullptr), false); - ASSERT_EQ(buff[0], 7); - ASSERT_EQ(buff[10], 8); - // dst node has not weight - ASSERT_FALSE(AttrUtils::GetTensor(const_node_with_weight_dst->GetOpDesc(), ATTR_NAME_WEIGHTS, weight)); - // dst node has other attr - std::string str_value; - ASSERT_TRUE(AttrUtils::GetStr(const_node_with_weight_dst->GetOpDesc(), "fake_attr_name", str_value)); - ASSERT_EQ("fake_attr_value", str_value); - auto add_node_dst = dst_compute_graph->FindNode("Add"); - ASSERT_NE(add_node_dst, nullptr); - // other node has all attr copyed - ASSERT_TRUE(AttrUtils::GetStr(add_node_dst->GetOpDesc(), "fake_attr_name", str_value)); - ASSERT_TRUE(AttrUtils::GetStr(add_node_dst->GetOpDesc(), ATTR_NAME_WEIGHTS, str_value)); - - // test copy graph success without attr filter - ComputeGraphPtr dst_compute_graph2 = std::make_shared(ComputeGraph("dst2")); - ASSERT_EQ(GraphUtils::CopyComputeGraph(graph, nullptr, nullptr, nullptr, dst_compute_graph2), GRAPH_SUCCESS); - auto const_node_with_weight_dst2 = dst_compute_graph2->FindNode(const_node_name); - ASSERT_NE(const_node_with_weight_dst2, nullptr); - ASSERT_NE(const_node_with_weight_src, const_node_with_weight_dst2); - ConstGeTensorPtr weight2 = nullptr; - ASSERT_TRUE(AttrUtils::GetTensor(const_node_with_weight_dst2->GetOpDesc(), ATTR_NAME_WEIGHTS, weight2)); - ASSERT_NE(weight2, nullptr); - ASSERT_EQ(weight2->GetData().GetSize(), 4096U); - const uint8_t *buff2 = weight2->GetData().GetData(); - // deep copy - ASSERT_NE(buff2, buff); - ASSERT_EQ((buff2 == nullptr), false); - ASSERT_EQ(buff2[0], 7); - ASSERT_EQ(buff2[10], 8); -} - -TEST_F(UtestGraphUtils, DumpGraphByPath) { - auto ge_tensor = std::make_shared(); - uint8_t data_buf[4096] = {0}; - data_buf[0] = 7; - data_buf[10] = 8; - ge_tensor->SetData(data_buf, 4096); - - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node = builder.AddNode("Data", "Data", 0, 1); - auto const_node = builder.AddNode("Const", "Const", 0, 1); - AttrUtils::SetTensor(const_node->GetOpDesc(), ge::ATTR_NAME_WEIGHTS, ge_tensor); - auto add_node = builder.AddNode("Add", "Add", 2, 1); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - builder.AddDataEdge(data_node, 0, add_node, 0); - builder.AddDataEdge(const_node, 0, add_node, 0); - builder.AddDataEdge(add_node, 0, netoutput, 0); - auto graph = builder.GetGraph(); - - // test dump_level 0 - auto ret = GraphUtils::DumpGEGraphByPath(graph, "./not-exists-path/test_graph_0.txt", ge::DumpLevel::NO_DUMP); - EXPECT_NE(ret, ge::GRAPH_SUCCESS); - ret = GraphUtils::DumpGEGraphByPath(graph, "/", ge::DumpLevel::NO_DUMP); - ASSERT_EQ((ret != 0), true); - ret = GraphUtils::DumpGEGraphByPath(graph, "test_graph_0.txt", ge::DumpLevel::NO_DUMP); - ASSERT_EQ((ret != 0), true); - ret = GraphUtils::DumpGEGraphByPath(graph, "./test_graph_0.txt", ge::DumpLevel::NO_DUMP); - ASSERT_EQ(ret, 0); - ComputeGraphPtr com_graph0 = std::make_shared("TestGraph0"); - bool state = GraphUtils::LoadGEGraph("./test_graph_0.txt", *com_graph0); - ASSERT_EQ(state, true); - ASSERT_EQ(com_graph0->GetAllNodesSize(), 4); - for (auto &node_ptr : com_graph0->GetAllNodes()) { - ASSERT_EQ((node_ptr == nullptr), false); - if (node_ptr->GetType() == CONSTANT) { - auto op_desc = node_ptr->GetOpDesc(); - ASSERT_EQ((op_desc == nullptr), false); - ConstGeTensorPtr ge_tensor_ptr; - ASSERT_EQ(AttrUtils::GetTensor(op_desc, ATTR_NAME_WEIGHTS, ge_tensor_ptr), false); - } - } - - // test dump_level 1 - ret = GraphUtils::DumpGEGraphByPath(graph, "./test_graph_1.txt", ge::DumpLevel::DUMP_ALL); - ASSERT_EQ(ret, 0); - ComputeGraphPtr com_graph1 = std::make_shared("TestGraph1"); - state = GraphUtils::LoadGEGraph("./test_graph_1.txt", *com_graph1); - ASSERT_EQ(state, true); - ASSERT_EQ(com_graph1->GetAllNodesSize(), 4); - for (auto &node_ptr : com_graph1->GetAllNodes()) { - ASSERT_EQ((node_ptr == nullptr), false); - if (node_ptr->GetType() == CONSTANT) { - auto op_desc = node_ptr->GetOpDesc(); - ASSERT_EQ((op_desc == nullptr), false); - ConstGeTensorPtr ge_tensor_ptr; - ASSERT_EQ(AttrUtils::GetTensor(op_desc, ATTR_NAME_WEIGHTS, ge_tensor_ptr), true); - ASSERT_EQ((ge_tensor_ptr == nullptr), false); - const TensorData tensor_data = ge_tensor_ptr->GetData(); - const uint8_t *buff = tensor_data.GetData(); - ASSERT_EQ((buff == nullptr), false); - ASSERT_EQ(buff[0], 7); - ASSERT_EQ(buff[10], 8); - } - } -} - -TEST_F(UtestGraphUtils, AddEdgeAnchorPtrIsNull) { - AnchorPtr src; - AnchorPtr dst; - int ret = GraphUtils::AddEdge(src, dst); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, AddEdgeAnchorPtrSuccess) { - auto builder = ut::GraphBuilder("root"); - const auto &node0 = builder.AddNode("node0", "node", 1, 1); - const auto &node1 = builder.AddNode("node1", "node", 1, 1); - int ret = GraphUtils::AddEdge(node0->GetOutAnchor(0), node1->GetInAnchor(0)); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - int ret2 = GraphUtils::AddEdge(node0->GetOutAnchor(0), node1->GetInControlAnchor()); - EXPECT_EQ(ret2, GRAPH_SUCCESS); - - int ret3 = GraphUtils::AddEdge(node0->GetOutControlAnchor(), node1->GetInControlAnchor()); - EXPECT_EQ(ret3, GRAPH_SUCCESS); - - int ret4 = GraphUtils::AddEdge(node0->GetOutControlAnchor(), node1->GetInDataAnchor(0)); - EXPECT_EQ(ret4, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, AddEdgeControlAnchorPtrIsNull) { - OutControlAnchorPtr src; - InControlAnchorPtr dst; - int ret = GraphUtils::AddEdge(src, dst); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, AddEdgeDataAnchorPtrIsNull) { - OutDataAnchorPtr src; - InControlAnchorPtr dst; - int ret = GraphUtils::AddEdge(src, dst); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, RemoveEdgeAnchorPtrIsNull) { - AnchorPtr src; - AnchorPtr dst; - int ret = GraphUtils::RemoveEdge(src, dst); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, RemoveEdgeOutDataAnchorPtrIsNull) { - OutDataAnchorPtr src; - InControlAnchorPtr dst; - int ret = GraphUtils::RemoveEdge(src, dst); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, RemoveEdgeFail) { - auto builder = ut::GraphBuilder("root"); - const auto &node0 = builder.AddNode("node0", "node", 1, 1); - const auto &node1 = builder.AddNode("node1", "node", 1, 1); - builder.AddDataEdge(node0, 0, node1, 0); - builder.AddControlEdge(node0, node1); - int ret = GraphUtils::RemoveEdge(node0->GetOutDataAnchor(0), node1->GetInControlAnchor()); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, InsertNodeBetweenDataAnchorsSuccess) { - auto builder = ut::GraphBuilder("root"); - const auto &node0 = builder.AddNode("node0", "node", 1, 1); - const auto &node1 = builder.AddNode("node1", "node", 1, 1); - const auto &node2 = builder.AddNode("node2", "node", 1, 1); - NodePtr new_node(node1); - builder.AddDataEdge(node0, 0, node2, 0); - builder.AddControlEdge(node0, node2); - int ret = GraphUtils::InsertNodeBetweenDataAnchors(node0->GetOutDataAnchor(0), - node2->GetInDataAnchor(0), new_node); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, RemoveSubgraphRecursivelyRemoveNodeIsNull) { - ComputeGraphPtr compute_graph = std::make_shared("Test0"); - NodePtr remove_node; - int ret = GraphUtils::RemoveSubgraphRecursively(compute_graph, remove_node); - EXPECT_NE(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, RemoveSubgraphRecursivelyNodeNotInGraph) { - ComputeGraphPtr compute_graph = std::make_shared("Test0"); - auto builder = ut::GraphBuilder("root"); - const auto &node0 = builder.AddNode("node0", "node", 1, 1); - NodePtr remove_node(node0); - int ret = GraphUtils::RemoveSubgraphRecursively(compute_graph, remove_node); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, RemoveSubgraphRecursivelyNodeHasNoSubgrah) { - ComputeGraphPtr compute_graph = std::make_shared("Test0"); - auto builder = ut::GraphBuilder("root"); - const auto &node0 = builder.AddNode("node0", "node", 1, 1); - compute_graph->AddNode(node0); - node0->SetOwnerComputeGraph(compute_graph); - int ret = GraphUtils::RemoveSubgraphRecursively(compute_graph, node0); - EXPECT_EQ(ret, SUCCESS); -} - -TEST_F(UtestGraphUtils, RemoveNodeWithoutRelinkNodePtrIsNull) { - ComputeGraphPtr compute_graph = std::make_shared("Test0"); - NodePtr remove_node; - int ret = GraphUtils::RemoveNodeWithoutRelink(compute_graph, remove_node); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, RemoveNodeWithoutRelinkFail) { - ComputeGraphPtr compute_graph = std::make_shared("Test0"); - NodePtr remove_node = ComGraphMakeShared(); - OpDescPtr op_desc = ComGraphMakeShared(); - remove_node->impl_->op_ = op_desc; - compute_graph->AddNode(remove_node); - // owner graph is null - int ret = GraphUtils::RemoveNodeWithoutRelink(compute_graph, remove_node); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(compute_graph->GetDirectNodesSize(), 0U); - // owner graph is another - ComputeGraphPtr compute_graph_another = std::make_shared("Test1"); - remove_node->SetOwnerComputeGraph(compute_graph_another); - ret = GraphUtils::RemoveNodeWithoutRelink(compute_graph, remove_node); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, RemoveNodesWithoutRelinkl) { - auto builder = ut::GraphBuilder("root"); - std::unordered_set remove_nodes; - size_t node_size = 6U; - for (auto i = 0U; i < node_size; i++) { - auto node = builder.AddNode("node" + std::to_string(i), "Relu", 1, 1); - if (i == 0U) { - builder.GetGraph()->AddInputNode(node); - } - if (i == node_size - 1U) { - builder.GetGraph()->AddOutputNode(node); - } - remove_nodes.emplace(node); - } - EXPECT_TRUE(builder.GetGraph()->GetAllNodesSize() == node_size); - EXPECT_TRUE(builder.GetGraph()->GetInputNodes().size() == 1U); - EXPECT_TRUE(builder.GetGraph()->GetOutputNodes().size() == 1U); - int ret = GraphUtils::RemoveNodesWithoutRelink(builder.GetGraph(), remove_nodes); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_TRUE(builder.GetGraph()->GetAllNodesSize() == 0U); - EXPECT_TRUE(builder.GetGraph()->GetAllNodes().empty()); - EXPECT_TRUE(builder.GetGraph()->GetInputNodes().empty()); - EXPECT_TRUE(builder.GetGraph()->GetOutputNodes().empty()); -} - -TEST_F(UtestGraphUtils, InsertNodeAfterAddEdgefail) { - auto graph_builder0 = ut::GraphBuilder("test_graph0"); - const auto &node0 = graph_builder0.AddNode("data0", DATA, 1, 1); - const auto &node1 = graph_builder0.AddNode("data1", DATA, 1, 1); - const auto &node2 = graph_builder0.AddNode("data2", DATA, 1, 1); - const auto &graph0 = graph_builder0.GetGraph(); - std::vector dsts; - dsts.push_back(node1->GetInDataAnchor(0)); - int ret = GraphUtils::InsertNodeAfter(node0->GetOutDataAnchor(0), dsts, node2, 1, 0); - EXPECT_EQ(ret, GRAPH_FAILED); - int ret2 = GraphUtils::InsertNodeAfter(node0->GetOutDataAnchor(0), dsts, node2, 0, 1); - EXPECT_EQ(ret2, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, InsertNodeAfterTypeIsSwitch) { - auto graph_builder0 = ut::GraphBuilder("test_graph0"); - const auto &node0 = graph_builder0.AddNode("data0", SWITCH, 1, 1); - const auto &graph0 = graph_builder0.GetGraph(); - std::vector dsts; - dsts.push_back(node0->GetInDataAnchor(0)); - int ret = GraphUtils::InsertNodeAfter(node0->GetOutDataAnchor(0), dsts, node0, 0, 0); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, InsertNodeAfterSrcOwnerComputeGraphNotEqualDstOwnerComputeGraph) { - auto graph_builder0 = ut::GraphBuilder("test_graph0"); - const auto &node0 = graph_builder0.AddNode("data0", DATA, 1, 1); - const auto &graph0 = graph_builder0.GetGraph(); - - auto graph_builder1 = ut::GraphBuilder("test_graph1"); - const auto &node1 = graph_builder1.AddNode("data1", DATA, 1, 1); - const auto &graph1 = graph_builder1.GetGraph(); - - std::vector dsts; - dsts.push_back(node1->GetInDataAnchor(0)); - int ret = GraphUtils::InsertNodeAfter(node0->GetOutDataAnchor(0), dsts, node1, 0, 0); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, InsertNodeBeforeGetOwnerComputeGraphFail) { - auto graph_builder0 = ut::GraphBuilder("test_graph0"); - const auto &node0 = graph_builder0.AddNode("data0", DATA, 1, 1); - const auto &graph0 = graph_builder0.GetGraph(); - - auto graph_builder1 = ut::GraphBuilder("test_graph1"); - const auto &node1 = graph_builder1.AddNode("data1", DATA, 1, 1); - const auto &graph1 = graph_builder1.GetGraph(); - - int ret = GraphUtils::InsertNodeBefore(node0->GetInDataAnchor(0), node1, 0, 0); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, InsertNodeBeforeInsertCodeGetInDataAnchorFail) { - auto builder = ut::GraphBuilder("test"); - const auto &var = builder.AddNode("var", VARIABLE, 0, 1); - const auto &assign = builder.AddNode("assign", "Assign", 1, 1); - const auto &allreduce = builder.AddNode("allreduce", "HcomAllReduce", 1, 1); - const auto &atomic_clean = builder.AddNode("atomic_clean", ATOMICADDRCLEAN, 0, 0); - const auto &netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0); - const auto &identity = builder.AddNode("identity", "Identity", 1, 1); - - builder.AddDataEdge(var, 0, assign, 0); - builder.AddDataEdge(var,0,allreduce,0); - builder.AddControlEdge(assign, allreduce); - builder.AddControlEdge(atomic_clean, allreduce); - auto graph = builder.GetGraph(); - - int ret = GraphUtils::InsertNodeBefore(allreduce->GetInDataAnchor(0), identity, 0, 5); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, RemoveJustNodeNodeIsNull) { - ComputeGraph compute_graph("test_graph0"); - int ret = GraphUtils::RemoveJustNode(compute_graph, nullptr); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, RemoveJustNodeFail) { - ComputeGraphPtr compute_graph = std::make_shared("Test0"); - auto graph_builder0 = ut::GraphBuilder("Test0"); - const auto &node0 = graph_builder0.AddNode("data0", DATA, 1, 1); - int ret = GraphUtils::RemoveJustNode(compute_graph, node0); - EXPECT_EQ(ret, GRAPH_FAILED); -} - - -TEST_F(UtestGraphUtils, LoadGEGraphComputeGraphIsNull) { - char_t *file = nullptr; - ge::ComputeGraph compute_graph(""); - bool ret = GraphUtils::LoadGEGraph(file, compute_graph); - EXPECT_EQ(ret, false); -} - -TEST_F(UtestGraphUtils, LoadGEGraphFileIsNull) { - char_t *file = nullptr; - ComputeGraphPtr compute_graph = std::make_shared("Test0"); - bool ret = GraphUtils::LoadGEGraph(file, compute_graph); - EXPECT_EQ(ret, false); -} - -TEST_F(UtestGraphUtils, LoadGEGraphComputeGraphPtrSuccess) { - const char_t *file = "./test_graph_0.txt"; - ComputeGraphPtr compute_graph = std::make_shared(""); - bool ret = GraphUtils::LoadGEGraph(file, compute_graph); - EXPECT_EQ(ret, true); -} - -TEST_F(UtestGraphUtils, ReadProtoFromTextFileFileIsNull) { - google::protobuf::Message *proto = nullptr; - bool ret = GraphUtils::ReadProtoFromTextFile(nullptr, proto); - EXPECT_EQ(ret, false); -} - -TEST_F(UtestGraphUtils, DumpGEGraphToOnnxForLongName) { - EXPECT_NO_THROW( - setenv("DUMP_GE_GRAPH", "1", 1); - ComputeGraph compute_graph("test_graph0"); - const std::string suffit = "ge_proto_00000001_AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" - "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" - "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA.pbtxt"; - ge::GraphUtils::DumpGEGraphToOnnx(compute_graph, suffit); - setenv("DUMP_GE_GRAPH", "1", 1); - ); -} - -TEST_F(UtestGraphUtils, IsolateNodeNodeIsNull) { - NodePtr node; - std::vector io_map = {1, 2, 3}; - int ret = GraphUtils::IsolateNode(node, io_map); - EXPECT_EQ(ret, GRAPH_PARAM_INVALID); -} - -TEST_F(UtestGraphUtils, IsolateNodeIoMapIsNull) { - auto graph_builder0 = ut::GraphBuilder("test_graph0"); - const auto &node0 = graph_builder0.AddNode("data0", DATA, 1, 1); - std::vector io_map; - int ret = GraphUtils::IsolateNode(node0, io_map); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, IsolateNodeIoMapSizeIsGreaterThanOutDataAnchorsSize) { - auto graph_builder0 = ut::GraphBuilder("test_graph0"); - const auto &node0 = graph_builder0.AddNode("data0", DATA, 1, 1); - std::vector io_map = {1, 2, 3, 4}; - int ret = GraphUtils::IsolateNode(node0, io_map); - EXPECT_EQ(ret, GRAPH_PARAM_INVALID); -} - -TEST_F(UtestGraphUtils, IsolateNodeOutDataAnchorsIsNull) { - auto graph_builder0 = ut::GraphBuilder("test_graph0"); - const auto &node0 = graph_builder0.AddNode("data0", DATA, 1, 0); - std::vector io_map = {1}; - int ret = GraphUtils::IsolateNode(node0, io_map); - EXPECT_EQ(ret, GRAPH_PARAM_INVALID); -} - -TEST_F(UtestGraphUtils, IsolateNodeInDataAnchorsIsNull) { - auto graph_builder0 = ut::GraphBuilder("test_graph0"); - const auto &node0 = graph_builder0.AddNode("data0", DATA, 0, 1); - std::vector io_map = {1}; - int ret = GraphUtils::IsolateNode(node0, io_map); - EXPECT_EQ(ret, GRAPH_PARAM_INVALID); -} - -TEST_F(UtestGraphUtils, IsolateNodeInitializerListTest) { - auto graph_builder0 = ut::GraphBuilder("test_graph0"); - const auto &node0 = graph_builder0.AddNode("data0", DATA, 1, 1); - std::initializer_list io_map; - int ret = GraphUtils::IsolateNode(node0, io_map); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, ReplaceNodeDataAnchorsNodeIsNull) { - NodePtr new_node; - NodePtr old_node; - std::vector inputs_map = {1, 2}; - std::vector outputs_map = {1, 2}; - int ret = GraphUtils::ReplaceNodeDataAnchors(new_node, old_node, inputs_map, outputs_map); - EXPECT_EQ(ret, GRAPH_PARAM_INVALID); -} - -TEST_F(UtestGraphUtils, ReplaceNodeDataAnchorsReplaceOutDataAnchorsFail) { - auto graph_builder0 = ut::GraphBuilder("test_graph0"); - const auto &new_node = graph_builder0.AddNode("data1", DATA, 1, 1); - const auto &old_node = graph_builder0.AddNode("data0", DATA, 0, 0); - std::vector inputs_map; - std::vector outputs_map = {1, 2}; - int ret = GraphUtils::ReplaceNodeDataAnchors(new_node, old_node, inputs_map, outputs_map); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, ReplaceNodeDataAnchorsReplaceInDataAnchorsFail) { - auto graph_builder0 = ut::GraphBuilder("test_graph0"); - const auto &new_node = graph_builder0.AddNode("data1", DATA, 1, 1); - const auto &old_node = graph_builder0.AddNode("data0", DATA, 0, 0); - std::vector inputs_map = {1, 2}; - std::vector outputs_map; - int ret = GraphUtils::ReplaceNodeDataAnchors(new_node, old_node, inputs_map, outputs_map); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, ReplaceNodeDataAnchorsSuccess) { - auto graph_builder0 = ut::GraphBuilder("test_graph0"); - const auto &new_node = graph_builder0.AddNode("data1", DATA, 1, 1); - const auto &old_node = graph_builder0.AddNode("data0", DATA, 0, 0); - std::vector inputs_map; - std::vector outputs_map; - int ret = GraphUtils::ReplaceNodeDataAnchors(new_node, old_node, inputs_map, outputs_map); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, ReplaceNodesSuccess_all_data_anchors) { - auto graph_builder0 = ut::GraphBuilder("test_graph0"); - const auto &data0 = graph_builder0.AddNode("data0", DATA, 1, 1); - const auto &relu0 = graph_builder0.AddNode("relu0", "Relu", 1, 1); - const auto &abs0 = graph_builder0.AddNode("abs0", "Abs", 1, 1); - graph_builder0.AddDataEdge(data0, 0, relu0, 0); - graph_builder0.AddDataEdge(relu0, 0, abs0, 0); - const auto &data1 = graph_builder0.AddNode("data1", DATA, 1, 1); - const auto &relu1 = graph_builder0.AddNode("relu1", "Relu", 1, 1); - const auto &abs1 = graph_builder0.AddNode("abs1", "Abs", 1, 1); - graph_builder0.AddDataEdge(data1, 0, relu1, 0); - graph_builder0.AddDataEdge(relu1, 0, abs1, 0); - const auto &add = graph_builder0.AddNode("add", "Add", 2, 1); - graph_builder0.AddDataEdge(relu0, 0, add, 0); - graph_builder0.AddDataEdge(relu1, 0, add, 1); - const auto &out = graph_builder0.AddNode("out", "NetOutput", 1, 1); - graph_builder0.AddDataEdge(add, 0, out, 0); - - const auto &relu_abs_add = graph_builder0.AddNode("relu_abs_add", "ReluAbsAdd", 2, 1); - EXPECT_EQ(graph_builder0.GetGraph()->GetDirectNodesSize(), 9); - std::vector inputs_map{1, 0}; - std::vector outputs_map{4}; - int ret = - GraphUtils::ReplaceNodesDataAnchors({relu_abs_add}, {relu0, relu1, abs0, abs1, add}, inputs_map, outputs_map); - EXPECT_EQ(ret, GRAPH_SUCCESS); - // 数据关系移动 - EXPECT_EQ(relu_abs_add->GetOutDataNodesSize(), 1U); - EXPECT_EQ(*relu_abs_add->GetOutDataNodes().begin(), out); - EXPECT_EQ(relu_abs_add->GetInDataNodesSize(), 2U); - EXPECT_EQ(*relu_abs_add->GetInDataNodes().begin(), data1); - EXPECT_EQ(*(relu_abs_add->GetInDataNodes().begin() + 1), data0); - EXPECT_EQ(data0->GetOutDataNodesSize(), 1U); - EXPECT_EQ(*(data0->GetOutDataNodes().begin()), relu_abs_add); - EXPECT_EQ(data1->GetOutDataNodesSize(), 1U); - EXPECT_EQ(*(data1->GetOutDataNodes().begin()), relu_abs_add); - EXPECT_EQ(out->GetInDataNodesSize(), 1U); - EXPECT_EQ(*(out->GetInDataNodes().begin()), relu_abs_add); - EXPECT_EQ(graph_builder0.GetGraph()->GetDirectNodesSize(), 9); - const auto &noop_node = graph_builder0.GetGraph()->FindFirstNodeMatchType(NOOP); - EXPECT_TRUE(noop_node == nullptr); -} - -TEST_F(UtestGraphUtils, ReplaceNodesSuccess_all_data_anchors_and_keep_old_in_data_anchors) { - auto graph_builder0 = ut::GraphBuilder("test_graph0"); - const auto &data0 = graph_builder0.AddNode("data0", DATA, 1, 1); - const auto &relu0 = graph_builder0.AddNode("relu0", "Relu", 1, 1); - const auto &abs0 = graph_builder0.AddNode("abs0", "Abs", 1, 1); - graph_builder0.AddDataEdge(data0, 0, relu0, 0); - graph_builder0.AddDataEdge(relu0, 0, abs0, 0); - const auto &data1 = graph_builder0.AddNode("data1", DATA, 1, 1); - const auto &relu1 = graph_builder0.AddNode("relu1", "Relu", 1, 1); - const auto &abs1 = graph_builder0.AddNode("abs1", "Abs", 1, 1); - graph_builder0.AddDataEdge(data1, 0, relu1, 0); - graph_builder0.AddDataEdge(relu1, 0, abs1, 0); - const auto &add = graph_builder0.AddNode("add", "Add", 2, 1); - graph_builder0.AddDataEdge(relu0, 0, add, 0); - graph_builder0.AddDataEdge(relu1, 0, add, 1); - const auto &out = graph_builder0.AddNode("out", "NetOutput", 1, 1); - graph_builder0.AddDataEdge(add, 0, out, 0); - - const auto &relu_abs_add = graph_builder0.AddNode("relu_abs_add", "ReluAbsAdd", 2, 1); - EXPECT_EQ(graph_builder0.GetGraph()->GetDirectNodesSize(), 9); - std::vector inputs_map{1, 0}; - std::vector outputs_map{4}; - int ret = - GraphUtils::CopyNodesInDataAnchors({relu_abs_add}, {relu0, relu1, abs0, abs1, add}, inputs_map); - EXPECT_EQ(ret, GRAPH_SUCCESS); - ret = - GraphUtils::ReplaceNodesOutDataAnchors({relu_abs_add}, {relu0, relu1, abs0, abs1, add}, outputs_map); - EXPECT_EQ(ret, GRAPH_SUCCESS); - // 输出数据关系移动,输入数据关系拷贝 - EXPECT_EQ(relu_abs_add->GetOutDataNodesSize(), 1U); - EXPECT_EQ(*relu_abs_add->GetOutDataNodes().begin(), out); - EXPECT_EQ(relu_abs_add->GetInDataNodesSize(), 2U); - EXPECT_EQ(*relu_abs_add->GetInDataNodes().begin(), data1); - EXPECT_EQ(*(relu_abs_add->GetInDataNodes().begin() + 1), data0); - EXPECT_EQ(data0->GetOutDataNodesSize(), 2U); - EXPECT_EQ(*(data0->GetOutDataNodes().begin()), relu0); - EXPECT_EQ(*(data0->GetOutDataNodes().begin() + 1), relu_abs_add); - EXPECT_EQ(data1->GetOutDataNodesSize(), 2U); - EXPECT_EQ(*(data1->GetOutDataNodes().begin()), relu1); - EXPECT_EQ(*(data1->GetOutDataNodes().begin() + 1), relu_abs_add); - EXPECT_EQ(out->GetInDataNodesSize(), 1U); - EXPECT_EQ(*(out->GetInDataNodes().begin()), relu_abs_add); - EXPECT_EQ(graph_builder0.GetGraph()->GetDirectNodesSize(), 9); - const auto &noop_node = graph_builder0.GetGraph()->FindFirstNodeMatchType(NOOP); - EXPECT_TRUE(noop_node == nullptr); -} - -TEST_F(UtestGraphUtils, ReplaceNodesSuccess_with_ctrl) { - auto graph_builder0 = ut::GraphBuilder("test_graph0"); - const auto &data0 = graph_builder0.AddNode("data0", DATA, 1, 1); - const auto &relu0 = graph_builder0.AddNode("relu0", "Relu", 1, 1); - const auto &abs0 = graph_builder0.AddNode("abs0", "Abs", 1, 1); - graph_builder0.AddDataEdge(data0, 0, relu0, 0); - graph_builder0.AddDataEdge(relu0, 0, abs0, 0); - const auto &out = graph_builder0.AddNode("out", "NetOutput", 1, 1); - graph_builder0.AddDataEdge(abs0, 0, out, 0); - const auto &relu_abs = graph_builder0.AddNode("relu_abs", "ReluAbs", 1, 1); - - // 创建控制关系 - const auto &const_node0 = graph_builder0.AddNode("const0", CONSTANT, 1, 1); - const auto &const_node1 = graph_builder0.AddNode("const1", CONSTANT, 1, 1); - const auto &const_node2 = graph_builder0.AddNode("const2", CONSTANT, 1, 1); - - graph_builder0.AddControlEdge(const_node0, relu0); - graph_builder0.AddControlEdge(const_node0, abs0); - graph_builder0.AddControlEdge(const_node1, abs0); - graph_builder0.AddControlEdge(relu0, abs0); - graph_builder0.AddControlEdge(relu0, const_node2); - - std::vector inputs_map{0}; - std::vector outputs_map{1}; - int ret = GraphUtils::ReplaceNodesDataAnchors({relu_abs}, {relu0, abs0}, inputs_map, outputs_map); - EXPECT_EQ(ret, GRAPH_SUCCESS); - ret = GraphUtils::InheritExecutionOrder({relu_abs}, {relu0, abs0}, graph_builder0.GetGraph()); - EXPECT_EQ(ret, GRAPH_SUCCESS); - // 数据关系移动 - EXPECT_EQ(relu_abs->GetOutDataNodesSize(), 1U); - EXPECT_EQ(*(relu_abs->GetOutDataNodes().begin()), out); - EXPECT_EQ(relu_abs->GetInDataNodesSize(), 1U); - EXPECT_EQ(*relu_abs->GetInDataNodes().begin(), data0); - EXPECT_EQ(data0->GetOutDataNodesSize(), 1U); - EXPECT_EQ(*(data0->GetOutDataNodes().begin()), relu_abs); - EXPECT_EQ(out->GetInDataNodesSize(), 1U); - EXPECT_EQ(*(out->GetInDataNodes().begin()), relu_abs); - - // 控制关系拷贝 - const auto &noop_in = graph_builder0.GetGraph()->FindNode("noop_in_relu_abs"); - EXPECT_TRUE(noop_in != nullptr); - EXPECT_EQ(noop_in->GetInControlNodesSize(), 2U); - EXPECT_EQ(*noop_in->GetInControlNodes().begin(), const_node0); - EXPECT_EQ(relu0->GetInControlNodesSize(), 1U); - EXPECT_EQ(abs0->GetInControlNodesSize(), 3U); - EXPECT_EQ(*(noop_in->GetInControlNodes().begin() + 1), const_node1); - EXPECT_EQ(relu_abs->GetInControlNodesSize(), 1U); - EXPECT_EQ(*(relu_abs->GetInControlNodes().begin()), noop_in); - - const auto &noop_out = graph_builder0.GetGraph()->FindNode("noop_out_relu_abs"); - EXPECT_TRUE(noop_out != nullptr); - EXPECT_EQ(noop_out->GetOutControlNodesSize(), 1U); - EXPECT_EQ(*noop_out->GetOutControlNodes().begin(), const_node2); - EXPECT_EQ(relu_abs->GetOutControlNodesSize(), 1U); - EXPECT_EQ(*relu_abs->GetOutControlNodes().begin(), noop_out); -} - -TEST_F(UtestGraphUtils, ReplaceNodesSuccess_with_ctrl_and_keep_old_in_data_anchors) { - auto graph_builder0 = ut::GraphBuilder("test_graph0"); - const auto &data0 = graph_builder0.AddNode("data0", DATA, 1, 1); - const auto &relu0 = graph_builder0.AddNode("relu0", "Relu", 1, 1); - const auto &abs0 = graph_builder0.AddNode("abs0", "Abs", 1, 1); - graph_builder0.AddDataEdge(data0, 0, relu0, 0); - graph_builder0.AddDataEdge(relu0, 0, abs0, 0); - const auto &out = graph_builder0.AddNode("out", "NetOutput", 1, 1); - graph_builder0.AddDataEdge(abs0, 0, out, 0); - const auto &relu_abs = graph_builder0.AddNode("relu_abs", "ReluAbs", 1, 1); - - // 创建控制关系 - const auto &const_node0 = graph_builder0.AddNode("const0", CONSTANT, 1, 1); - const auto &const_node1 = graph_builder0.AddNode("const1", CONSTANT, 1, 1); - const auto &const_node2 = graph_builder0.AddNode("const2", CONSTANT, 1, 1); - - graph_builder0.AddControlEdge(const_node0, relu0); - graph_builder0.AddControlEdge(const_node0, abs0); - graph_builder0.AddControlEdge(const_node1, abs0); - graph_builder0.AddControlEdge(relu0, abs0); - graph_builder0.AddControlEdge(relu0, const_node2); - - std::vector inputs_map{0}; - std::vector outputs_map{1}; - int ret = GraphUtils::CopyNodesInDataAnchors({relu_abs}, {relu0, abs0}, inputs_map); - EXPECT_EQ(ret, GRAPH_SUCCESS); - ret = GraphUtils::ReplaceNodesOutDataAnchors({relu_abs}, {relu0, abs0}, outputs_map); - EXPECT_EQ(ret, GRAPH_SUCCESS); - ret = GraphUtils::InheritExecutionOrder({relu_abs}, {relu0, abs0}, graph_builder0.GetGraph()); - EXPECT_EQ(ret, GRAPH_SUCCESS); - // 输出数据关系移动,输入数据关系拷贝 - EXPECT_EQ(relu_abs->GetOutDataNodesSize(), 1U); - EXPECT_EQ(*(relu_abs->GetOutDataNodes().begin()), out); - EXPECT_EQ(relu_abs->GetInDataNodesSize(), 1U); - EXPECT_EQ(*relu_abs->GetInDataNodes().begin(), data0); - EXPECT_EQ(data0->GetOutDataNodesSize(), 2U); - EXPECT_EQ(*(data0->GetOutDataNodes().begin()), relu0); - EXPECT_EQ(*(data0->GetOutDataNodes().begin() + 1), relu_abs); - EXPECT_EQ(out->GetInDataNodesSize(), 1U); - EXPECT_EQ(*(out->GetInDataNodes().begin()), relu_abs); - - // 控制关系拷贝 - const auto &noop_in = graph_builder0.GetGraph()->FindNode("noop_in_relu_abs"); - EXPECT_TRUE(noop_in != nullptr); - EXPECT_EQ(noop_in->GetInControlNodesSize(), 2U); - EXPECT_EQ(*noop_in->GetInControlNodes().begin(), const_node0); - EXPECT_EQ(relu0->GetInControlNodesSize(), 1U); - EXPECT_EQ(abs0->GetInControlNodesSize(), 3U); - EXPECT_EQ(*(noop_in->GetInControlNodes().begin() + 1), const_node1); - EXPECT_EQ(relu_abs->GetInControlNodesSize(), 1U); - EXPECT_EQ(*(relu_abs->GetInControlNodes().begin()), noop_in); - - const auto &noop_out = graph_builder0.GetGraph()->FindNode("noop_out_relu_abs"); - EXPECT_TRUE(noop_out != nullptr); - EXPECT_EQ(noop_out->GetOutControlNodesSize(), 1U); - EXPECT_EQ(*noop_out->GetOutControlNodes().begin(), const_node2); - EXPECT_EQ(relu_abs->GetOutControlNodesSize(), 1U); - EXPECT_EQ(*relu_abs->GetOutControlNodes().begin(), noop_out); -} - -TEST_F(UtestGraphUtils, ReplaceNodesSuccess_with_data_convert_ctrl) { - auto graph_builder0 = ut::GraphBuilder("test_graph0"); - const auto &data0 = graph_builder0.AddNode("data0", DATA, 1, 1); - const auto &data1 = graph_builder0.AddNode("data1", DATA, 1, 1); - const auto &relu0 = graph_builder0.AddNode("relu0", "2In2OutReluFake", 2, 2); - const auto &relu1 = graph_builder0.AddNode("relu1", "RELU", 1, 1); - const auto &abs0 = graph_builder0.AddNode("abs0", "Abs", 1, 1); - graph_builder0.AddDataEdge(data0, 0, relu0, 0); - graph_builder0.AddDataEdge(data1, 0, relu0, 1); - graph_builder0.AddDataEdge(relu0, 0, abs0, 0); - graph_builder0.AddDataEdge(relu0, 1, relu1, 0); - const auto &out = graph_builder0.AddNode("out", "NetOutput", 1, 1); - graph_builder0.AddDataEdge(abs0, 0, out, 0); - const auto &relu_abs = graph_builder0.AddNode("relu_abs", "ReluAbs", 1, 1); - - // 创建控制关系 - const auto &const_node0 = graph_builder0.AddNode("const0", CONSTANT, 1, 1); - const auto &const_node1 = graph_builder0.AddNode("const1", CONSTANT, 1, 1); - const auto &const_node2 = graph_builder0.AddNode("const2", CONSTANT, 1, 1); - - graph_builder0.AddControlEdge(const_node0, relu0); - graph_builder0.AddControlEdge(const_node0, abs0); - graph_builder0.AddControlEdge(const_node1, abs0); - graph_builder0.AddControlEdge(relu0, abs0); - graph_builder0.AddControlEdge(relu0, const_node2); - - std::vector inputs_map{0}; - std::vector outputs_map{2}; - int ret = GraphUtils::ReplaceNodesDataAnchors({relu_abs}, {relu0, abs0}, inputs_map, outputs_map); - EXPECT_EQ(ret, GRAPH_SUCCESS); - ret = GraphUtils::InheritExecutionOrder({relu_abs}, {relu0, abs0}, graph_builder0.GetGraph(), true); - EXPECT_EQ(ret, GRAPH_SUCCESS); - // 数据关系移动 - EXPECT_EQ(relu_abs->GetOutDataNodesSize(), 1U); - EXPECT_EQ(*(relu_abs->GetOutDataNodes().begin()), out); - EXPECT_EQ(relu_abs->GetInDataNodesSize(), 1U); - EXPECT_EQ(*relu_abs->GetInDataNodes().begin(), data0); - EXPECT_EQ(data0->GetOutDataNodesSize(), 1U); - EXPECT_EQ(*(data0->GetOutDataNodes().begin()), relu_abs); - EXPECT_EQ(out->GetInDataNodesSize(), 1U); - EXPECT_EQ(*(out->GetInDataNodes().begin()), relu_abs); - - // 控制关系拷贝 - const auto &noop_in = graph_builder0.GetGraph()->FindNode("noop_in_relu_abs"); - EXPECT_TRUE(noop_in != nullptr); - EXPECT_EQ(noop_in->GetInControlNodesSize(), 3U); - EXPECT_EQ(*noop_in->GetInControlNodes().begin(), const_node0); - // 非io_map的data1数据输入也被转换为了noop_in的控制边 - EXPECT_EQ(*(noop_in->GetInControlNodes().begin() + 1), data1); - EXPECT_EQ(*(noop_in->GetInControlNodes().begin() + 2), const_node1); - EXPECT_EQ(relu0->GetInControlNodesSize(), 1U); - EXPECT_EQ(abs0->GetInControlNodesSize(), 3U); - EXPECT_EQ(relu_abs->GetInControlNodesSize(), 1U); - EXPECT_EQ(*(relu_abs->GetInControlNodes().begin()), noop_in); - - const auto &noop_out = graph_builder0.GetGraph()->FindNode("noop_out_relu_abs"); - EXPECT_TRUE(noop_out != nullptr); - EXPECT_EQ(noop_out->GetOutControlNodesSize(), 2U); - EXPECT_EQ(*noop_out->GetOutControlNodes().begin(), const_node2); - // 非io_map的relu1数据输出也被转换为了noop_out的控制边 - EXPECT_EQ(*(noop_out->GetOutControlNodes().begin() + 1), relu1); - EXPECT_EQ(relu_abs->GetOutControlNodesSize(), 1U); - EXPECT_EQ(*relu_abs->GetOutControlNodes().begin(), noop_out); -} - -TEST_F(UtestGraphUtils, ReplaceNodesFailed_as_diff_ownergraph) { - auto graph_builder0 = ut::GraphBuilder("test_graph0"); - const auto &data0 = graph_builder0.AddNode("data0", DATA, 1, 1); - const auto &relu0 = graph_builder0.AddNode("relu0", "Relu", 1, 1); - const auto &abs0 = graph_builder0.AddNode("abs0", "Abs", 1, 1); - graph_builder0.AddDataEdge(data0, 0, relu0, 0); - graph_builder0.AddDataEdge(relu0, 0, abs0, 0); - const auto &out = graph_builder0.AddNode("out", "NetOutput", 1, 1); - graph_builder0.AddDataEdge(abs0, 0, out, 0); - auto graph_builder1 = ut::GraphBuilder("test_graph1"); - const auto &relu_abs = graph_builder1.AddNode("relu_abs", "ReluAbs", 1, 1); - std::vector inputs_map{0}; - std::vector outputs_map{0}; - int ret = GraphUtils::ReplaceNodesDataAnchors({relu_abs}, {relu0, abs0}, inputs_map, outputs_map); - EXPECT_NE(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, IsolateNodeOneIONodeIsNull) { - NodePtr node; - int ret = GraphUtils::IsolateNodeOneIO(node); - EXPECT_EQ(ret, GRAPH_PARAM_INVALID); -} - -TEST_F(UtestGraphUtils, IsolateNodeOneIOInDataIs0) { - auto graph_builder0 = ut::GraphBuilder("test_graph0"); - const auto &node = graph_builder0.AddNode("data1", DATA, 0, 1); - int ret = GraphUtils::IsolateNodeOneIO(node); - EXPECT_EQ(ret, GRAPH_PARAM_INVALID); -} - -TEST_F(UtestGraphUtils, IsolateNodeOneIOOutDataIs0) { - auto graph_builder0 = ut::GraphBuilder("test_graph0"); - const auto &node = graph_builder0.AddNode("data1", DATA, 1, 0); - int ret = GraphUtils::IsolateNodeOneIO(node); - EXPECT_EQ(ret, GRAPH_PARAM_INVALID); -} - -TEST_F(UtestGraphUtils, IsolateNodeOneIOSuccess) { - auto graph_builder0 = ut::GraphBuilder("test_graph0"); - const auto &node = graph_builder0.AddNode("data1", DATA, 1, 1); - int ret = GraphUtils::IsolateNodeOneIO(node); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, ReplaceNodeAnchorsNodeIsNull) { - NodePtr new_node; - NodePtr old_node; - std::vector inputs_map = {1, 2}; - std::vector outputs_map = {1, 2}; - int ret = GraphUtils::ReplaceNodeAnchors(new_node, old_node, inputs_map, outputs_map); - EXPECT_EQ(ret, GRAPH_PARAM_INVALID); -} - -TEST_F(UtestGraphUtils, ReplaceNodeAnchorsReplaceNodeDataAnchorsFail) { - auto graph_builder0 = ut::GraphBuilder("test_graph0"); - const auto &new_node = graph_builder0.AddNode("data1", DATA, 1, 1); - const auto &old_node = graph_builder0.AddNode("data0", DATA, 0, 0); - std::vector inputs_map = {1, 2}; - std::vector outputs_map = {1, 2}; - int ret = GraphUtils::ReplaceNodeAnchors(new_node, old_node, inputs_map, outputs_map); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, ReplaceNodeAnchorsSuccess) { - auto builder = ut::GraphBuilder("test_graph0"); - const auto &new_node = builder.AddNode("data1", "node", 1, 1); - const auto &old_node = builder.AddNode("data0", "node", 1, 1); - builder.AddDataEdge(new_node, 0, old_node, 0); - builder.AddControlEdge(new_node, old_node); - std::vector inputs_map = {0}; - std::vector outputs_map = {0}; - int ret = GraphUtils::ReplaceNodeAnchors(new_node, old_node, inputs_map, outputs_map); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, ReplaceNodeAnchorsInitializerListTest) { - auto builder = ut::GraphBuilder("test_graph0"); - const auto &new_node = builder.AddNode("data1", "node", 1, 1); - const auto &old_node = builder.AddNode("data0", "node", 1, 1); - builder.AddDataEdge(new_node, 0, old_node, 0); - builder.AddControlEdge(new_node, old_node); - std::initializer_list inputs_map; - std::initializer_list outputs_map; - int ret = GraphUtils::ReplaceNodeAnchors(new_node, old_node, inputs_map, outputs_map); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, ReplaceNodeDataAnchorsInitializerListTest) { - auto builder = ut::GraphBuilder("test_graph0"); - const auto &new_node = builder.AddNode("data1", DATA, 1, 1); - const auto &old_node = builder.AddNode("data0", DATA, 1, 1); - std::initializer_list inputs_map; - std::initializer_list outputs_map; - int ret = GraphUtils::ReplaceNodeDataAnchors(new_node, old_node, inputs_map, outputs_map); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, CopyInCtrlEdgesNodeIsNull) { - NodePtr src_node; - NodePtr dst_node; - int ret = GraphUtils::CopyInCtrlEdges(src_node, dst_node); - EXPECT_EQ(ret, GRAPH_PARAM_INVALID); -} - -TEST_F(UtestGraphUtils, CopyInCtrlEdgesSrcCtrlInNodesIsEmpty) { - auto builder = ut::GraphBuilder("test_graph0"); - const auto &src_node = builder.AddNode("data0", "data", 1, 1); - NodePtr dst_node = builder.AddNode("data1", "data", 1, 1); - int ret = GraphUtils::CopyInCtrlEdges(src_node, dst_node); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, CopyInCtrlEdgesSuccess) { - auto builder = ut::GraphBuilder("test"); - const auto &src_node = builder.AddNode("src_node", "node", 1, 1); - NodePtr dst_node = builder.AddNode("dst_node", "node", 1, 1); - builder.AddDataEdge(src_node, 0, dst_node, 0); - builder.AddControlEdge(src_node, dst_node); - int ret = GraphUtils::CopyInCtrlEdges(src_node, dst_node); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, MoveInCtrlEdgesNodeIsNull) { - NodePtr src_node; - NodePtr dst_node; - int ret = GraphUtils::MoveInCtrlEdges(src_node, dst_node); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, MoveInCtrlEdgesSuccess) { - auto builder = ut::GraphBuilder("test_graph0"); - const auto &src_node = builder.AddNode("data0", "data", 1, 1); - NodePtr dst_node = builder.AddNode("data1", "data", 1, 1); - int ret = GraphUtils::MoveInCtrlEdges(src_node, dst_node); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, CopyOutCtrlEdgesNodeIsNull) { - NodePtr src_node; - NodePtr dst_node; - int ret = GraphUtils::CopyOutCtrlEdges(src_node, dst_node); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, CopyOutCtrlEdgesOutCtrlNodesIsEmpty) { - auto builder = ut::GraphBuilder("test_graph0"); - const auto &src_node = builder.AddNode("data0", "data", 1, 1); - NodePtr dst_node = builder.AddNode("data1", "data", 1, 1); - int ret = GraphUtils::CopyOutCtrlEdges(src_node, dst_node); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, CopyOutCtrlEdgesSuccess) { - auto builder = ut::GraphBuilder("test_graph0"); - const auto &src_node = builder.AddNode("src_node", NETOUTPUT, 1, 1); - NodePtr dst_node = builder.AddNode("dst_node", NETOUTPUT, 1, 1); - auto graph = builder.GetGraph(); - builder.AddControlEdge(src_node, dst_node); - - int ret = GraphUtils::CopyOutCtrlEdges(src_node, dst_node); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, CopyOutCtrlEdgesSuccess_with_filter) { - auto builder = ut::GraphBuilder("test_graph0"); - const auto &src_node = builder.AddNode("src_node", NETOUTPUT, 1, 1); - const auto &ctrl_node = builder.AddNode("ctrl_node", CONSTANT, 0, 0); - const auto &ctrl_node2 = builder.AddNode("ctrl_node2", CONSTANT, 0, 0); - NodePtr dst_node = builder.AddNode("dst_node", NETOUTPUT, 1, 1); - auto graph = builder.GetGraph(); - builder.AddControlEdge(src_node, ctrl_node); - builder.AddControlEdge(src_node, ctrl_node2); - NodeFilter node_filter = [&](const Node &node) { return node.GetName() == ctrl_node2->GetName(); }; - int ret = GraphUtils::CopyOutCtrlEdges(src_node, dst_node, node_filter); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - EXPECT_EQ(dst_node->GetOutControlNodesSize(), src_node->GetOutControlNodesSize() - 1U); - EXPECT_EQ(dst_node->GetOutControlNodesSize(), 1U); - EXPECT_EQ(dst_node->GetOutControlNodes().at(0U), ctrl_node2); -} - -TEST_F(UtestGraphUtils, MoveOutCtrlEdgesNodeIsNull) { - auto builder = ut::GraphBuilder("test_graph0"); - NodePtr src_node; - NodePtr dst_node; - int ret = GraphUtils::MoveOutCtrlEdges(src_node, dst_node); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, MoveOutCtrlEdgesSuccess) { - auto builder = ut::GraphBuilder("test_graph0"); - NodePtr src_node = builder.AddNode("src_node", NETOUTPUT, 1, 1); - NodePtr dst_node = builder.AddNode("dst_node", NETOUTPUT, 1, 1); - int ret = GraphUtils::MoveOutCtrlEdges(src_node, dst_node); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, AppendInputNodeSuccess) { - ComputeGraphPtr compute_graph = std::make_shared("Test0"); - auto builder = ut::GraphBuilder("Test1"); - const auto &node = builder.AddNode("node", "node", 1, 1); - int ret = GraphUtils::AppendInputNode(compute_graph, node); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, CopyGraphDstrGraphIsNull) { - Graph src_graph("test0"); - Graph dst_graph(""); - int ret = GraphUtilsEx::CopyGraph(src_graph, dst_graph); - EXPECT_EQ(ret, ge::PARAM_INVALID); -} - -TEST_F(UtestGraphUtils, GetUserInputDatas) { - auto builder = ut::GraphBuilder("test_graph0"); - const auto &usr_data_node0 = builder.AddNode("node0", DATA, 1, 1); - const auto &usr_data_node1 = builder.AddNode("node1", DATA, 1, 1); - const auto &usr_data_node2 = builder.AddNode("node2", REFDATA, 1, 1); - const auto &usr_data_node3 = builder.AddNode("node3", AIPPDATA, 1, 1); - const auto &multi_batch_data = builder.AddNode("ascend_mbatch_shape_data", DATA, 1, 1); - ge::AttrUtils::SetBool(multi_batch_data->GetOpDesc(), "_is_multi_batch_shape_data", true); - const auto &const_node = builder.AddNode("const_node0", CONSTANT, 0, 0); - NodePtr output_node = builder.AddNode("output", NETOUTPUT, 1, 1); - auto graph = builder.GetGraph(); - const auto input_nodes = ge::GraphUtilsEx::GetUserInputDataNodes(graph); - EXPECT_EQ(input_nodes.size(), 4); - for (const auto &node : input_nodes) { - EXPECT_NE(node->GetName(), "ascend_mbatch_shape_data"); - } -} - -TEST_F(UtestGraphUtils, CopyComputeGraphDepthGreaterThanKCopyGraphMaxRecursionDepth) { - ComputeGraphPtr src_compute_graph = std::make_shared("Test0"); - ComputeGraphPtr dst_compute_graph = std::make_shared("Test1"); - std::map node_old_2_new; - std::map op_desc_old_2_new; - int32_t depth = 20; - int ret = - GraphUtils::CopyComputeGraph(src_compute_graph, dst_compute_graph, node_old_2_new, op_desc_old_2_new, depth); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, CopyMembersSrcComputerGraphIsNull) { - ComputeGraphPtr dst_compute_graph = std::make_shared("Test1"); - std::unordered_map all_new_nodes; - int ret = - GraphUtils::CopyMembers(nullptr, dst_compute_graph, all_new_nodes); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, CopyMembersDstComputerGraphIsNull) { - ComputeGraphPtr src_compute_graph = std::make_shared("Test0"); - ComputeGraphPtr dst_compute_graph; - std::unordered_map all_new_nodes; - int ret = GraphUtils::CopyMembers(src_compute_graph, dst_compute_graph, all_new_nodes); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, CloneGraph) { - auto builder = ut::GraphBuilder("Test1"); - const auto &node0 = builder.AddNode("node0", DATA, 1, 1); - const auto &node1 = builder.AddNode("node1", NETOUTPUT, 1, 1); - auto graph = builder.GetGraph(); - (void) AttrUtils::SetStr(graph, ATTR_NAME_SESSION_GRAPH_ID, "0"); - std::string prefix; - std::vector input_nodes; - std::vector output_nodes; - std::unordered_map all_new_nodes; - ComputeGraphPtr new_compute_graph = GraphUtils::CloneGraph(graph, prefix, input_nodes, output_nodes); - EXPECT_NE(new_compute_graph, nullptr); -} - -TEST_F(UtestGraphUtils, CopyTensorAttrsDstDescIsNull) { - OpDescPtr dst_desc; - auto builder = ut::GraphBuilder("Test1"); - const auto &src_node = builder.AddNode("src_node", DATA, 1, 1); - int ret = GraphUtils::CopyTensorAttrs(dst_desc, src_node); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, CopyTensorAttrsSrcNodeIsNull) { - OpDescPtr dst_desc = std::make_shared("test", "test"); - NodePtr src_node; - int ret = GraphUtils::CopyTensorAttrs(dst_desc, src_node); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, CopyTensorAttrsFail) { - OpDescPtr dst_desc = std::make_shared(); - auto builder = ut::GraphBuilder("Test1"); - const auto &src_node = builder.AddNode("src_node", DATA, 1, 1); - int ret = GraphUtils::CopyTensorAttrs(dst_desc, src_node); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, RelinkGraphEdgesNodeIsNull) { - NodePtr node; - std::string prefix; - std::unordered_map all_nodes; - int ret = GraphUtils::RelinkGraphEdges(node, prefix, all_nodes); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, RelinkGraphEdgesAllNodesIsNull) { - auto builder = ut::GraphBuilder("Test1"); - const auto &node = builder.AddNode("node", DATA, 1, 1); - std::string prefix; - std::unordered_map all_nodes; - int ret = GraphUtils::RelinkGraphEdges(node, prefix, all_nodes); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, RelinkGraphEdgesOutCtlNotEmpty) { - auto builder = ut::GraphBuilder("Test1"); - const auto &node1 = builder.AddNode("node1", "node1", 2, 2); - const auto &node2 = builder.AddNode("node2", "node2", 1, 1); - builder.AddControlEdge(node1, node2); - std::string prefix; - std::unordered_map all_nodes; - all_nodes.emplace(node1->GetName(), node1); - int ret = GraphUtils::RelinkGraphEdges(node1, prefix, all_nodes); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, RelinkGraphEdgesFail) { - auto builder = ut::GraphBuilder("Test1"); - const auto &node1 = builder.AddNode("node1", DATA, 1, 1); - const auto &node2 = builder.AddNode("node2", DATA, 1, 1); - std::string prefix; - std::unordered_map all_nodes; - all_nodes.insert(make_pair("node2", node2)); - int ret = GraphUtils::RelinkGraphEdges(node1, prefix, all_nodes); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, GetRefMappingSuccess) { - auto builder = ut::GraphBuilder("Test1"); - auto graph = builder.GetGraph(); - SymbolToAnchors symbol_to_anchors; - AnchorToSymbol anchor_to_symbol; - int ret = GraphUtils::GetRefMapping(graph, symbol_to_anchors, anchor_to_symbol); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, FindNodeFromAllNodesGraphIsNull) { - ComputeGraphPtr graph; - std::string name; - NodePtr node = GraphUtils::FindNodeFromAllNodes(graph, name); - EXPECT_EQ(node, nullptr); -} - -TEST_F(UtestGraphUtils, FindNodeFromAllNodesSuccess) { - auto builder = ut::GraphBuilder("Test1"); - const auto &node1 = builder.AddNode("node1", DATA, 1, 1); - auto graph = builder.GetGraph(); - std::string name = "node1"; - NodePtr node = GraphUtils::FindNodeFromAllNodes(graph, name); - EXPECT_EQ(node->GetName(), "node1"); -} - -TEST_F(UtestGraphUtils, FindNodeFromAllNodesNameIsNull) { - auto builder = ut::GraphBuilder("Test1"); - auto graph = builder.GetGraph(); - std::string name; - NodePtr node = GraphUtils::FindNodeFromAllNodes(graph, name); - EXPECT_EQ(node, nullptr); -} - -TEST_F(UtestGraphUtils, HandleInAnchorMappingSuccess) { - ComputeGraphPtr graph = std::make_shared("Test0"); - auto builder = ut::GraphBuilder("Test1"); - const auto &node1 = builder.AddNode("node1", NETOUTPUT, 1, 1); - SymbolToAnchors symbol_to_anchors; - AnchorToSymbol anchor_to_symbol; - int ret = GraphUtils::HandleInAnchorMapping(graph, node1, symbol_to_anchors, anchor_to_symbol); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, HandleInAnchorMappingNodeTypeIsMERGE) { - ComputeGraphPtr graph = std::make_shared("Test0"); - auto builder = ut::GraphBuilder("Test1"); - const auto &node1 = builder.AddNode("node1", MERGE, 1, 1); - SymbolToAnchors symbol_to_anchors; - AnchorToSymbol anchor_to_symbol; - int ret = GraphUtils::HandleInAnchorMapping(graph, node1, symbol_to_anchors, anchor_to_symbol); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, HandleSubgraphInputFail) { - auto builder = ut::GraphBuilder("Test1"); - const auto &node1 = builder.AddNode("node1", DATA, 1, 1); - SymbolToAnchors symbol_to_anchors; - AnchorToSymbol anchor_to_symbol; - int ret = GraphUtils::HandleSubgraphInput(node1, symbol_to_anchors, anchor_to_symbol); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, HandleSubgraphInputUpdateRefMappingFail) { - auto builder = ut::GraphBuilder("test1"); - const auto &input1 = builder.AddNode("data1", DATA, 1, 1); - const auto &var1 = builder.AddNode("var1", VARIABLEV2, 1, 1); - const auto &func = builder.AddNode("func", PARTITIONEDCALL, 4, 1); - const auto &netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); - builder.AddDataEdge(input1, 0, func, 0); - builder.AddDataEdge(var1, 0, func, 1); - builder.AddDataEdge(func, 0, netoutput, 0); - auto graph = builder.GetGraph(); - graph->SetParentNode(func); - - AttrUtils::SetInt(input1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); - SymbolToAnchors symbol_to_anchors; - AnchorToSymbol anchor_to_symbol; - int ret = GraphUtils::HandleSubgraphInput(input1, symbol_to_anchors, anchor_to_symbol); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, HandleSubgraphInputSuccess) { - auto builder = ut::GraphBuilder("test1"); - const auto &input1 = builder.AddNode("data1", DATA, 1, 1); - const auto &var1 = builder.AddNode("var1", VARIABLEV2, 1, 1); - const auto &func = builder.AddNode("func", PARTITIONEDCALL, 4, 1); - auto graph = builder.GetGraph(); - graph->SetParentNode(func); - - AttrUtils::SetInt(input1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); - SymbolToAnchors symbol_to_anchors; - AnchorToSymbol anchor_to_symbol; - int ret = GraphUtils::HandleSubgraphInput(input1, symbol_to_anchors, anchor_to_symbol); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, HandleMergeInputPeerOutAnchorIsNull) { - auto builder = ut::GraphBuilder("test1"); - const auto &input1 = builder.AddNode("data1", DATA, 1, 1); - const auto &var1 = builder.AddNode("var1", VARIABLEV2, 1, 1); - const auto &func = builder.AddNode("func", PARTITIONEDCALL, 4, 1); - auto graph = builder.GetGraph(); - graph->SetParentNode(func); - - AttrUtils::SetStr(input1->GetOpDesc(), ATTR_NAME_NEXT_ITERATION, "data1"); - SymbolToAnchors symbol_to_anchors; - AnchorToSymbol anchor_to_symbol; - int ret = GraphUtils::HandleMergeInput(input1, symbol_to_anchors, anchor_to_symbol); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, HandleMergeInputPeerOutAnchorIsNotNull) { - auto builder = ut::GraphBuilder("test1"); - const auto &input1 = builder.AddNode("data1", DATA, 1, 1); - const auto &var1 = builder.AddNode("var1", VARIABLEV2, 1, 1); - const auto &func = builder.AddNode("func", PARTITIONEDCALL, 4, 1); - const auto &netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); - builder.AddDataEdge(input1, 0, func, 0); - builder.AddDataEdge(var1, 0, func, 1); - builder.AddDataEdge(func, 0, netoutput, 0); - auto graph = builder.GetGraph(); - - SymbolToAnchors symbol_to_anchors; - NodeIndexIO node_index_io(func, 0, kOut); - std::list symbol_list; - symbol_list.push_back(node_index_io); - symbol_to_anchors.insert(pair>("var1_out_0", symbol_list)); - - AnchorToSymbol anchor_to_symbol; - anchor_to_symbol.insert(pair("data1_out_0", "var1_out_0")); - int ret = GraphUtils::HandleMergeInput(func, symbol_to_anchors, anchor_to_symbol); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, HandleSubgraphOutput) { - auto builder = ut::GraphBuilder("test2"); - const auto &input1 = builder.AddNode("data1", DATA, 1, 1); - const auto &var1 = builder.AddNode("var1", VARIABLEV2, 1, 1); - const auto &func = builder.AddNode("func", PARTITIONEDCALL, 4, 1); - const auto &netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); - builder.AddDataEdge(input1, 0, func, 0); - builder.AddDataEdge(var1, 0, func, 1); - builder.AddDataEdge(func, 0, netoutput, 0); - auto graph = builder.GetGraph(); - - graph->SetParentNode(func); - AttrUtils::SetInt(input1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); - - SymbolToAnchors symbol_to_anchors; - NodeIndexIO node_index_io(func, 0, kOut); - std::list symbol_list; - symbol_list.push_back(node_index_io); - symbol_to_anchors.insert(pair>("var1_out_0", symbol_list)); - - AnchorToSymbol anchor_to_symbol; - anchor_to_symbol.insert(pair("data1_out_0", "var1_out_0")); - int ret = GraphUtils::HandleSubgraphOutput(func, symbol_to_anchors, anchor_to_symbol); - EXPECT_EQ(ret, ge::PARAM_INVALID); -} - -TEST_F(UtestGraphUtils, UnionSymbolMappingSuccess) { - auto builder = ut::GraphBuilder("test1"); - const auto &input1 = builder.AddNode("data1", DATA, 1, 1); - const auto &var1 = builder.AddNode("var1", VARIABLEV2, 1, 1); - const auto &input2 = builder.AddNode("data2", DATA, 1, 1); - const auto &var2 = builder.AddNode("var2", VARIABLEV2, 1, 1); - const auto &func = builder.AddNode("func", PARTITIONEDCALL, 4, 1); - const auto &netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); - builder.AddDataEdge(input1, 0, func, 0); - builder.AddDataEdge(var1, 0, func, 1); - builder.AddDataEdge(input2, 0, func, 2); - builder.AddDataEdge(var2, 0, func, 3); - builder.AddDataEdge(func, 0, netoutput, 0); - auto graph = builder.GetGraph(); - - graph->SetParentNode(func); - AttrUtils::SetInt(input1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); - AttrUtils::SetInt(input2->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1); - - SymbolToAnchors symbol_to_anchors; - NodeIndexIO node_index1(input1, 0, kOut); - NodeIndexIO node_index2(input2, 0, kOut); - std::list symbol_list; - symbol_list.push_back(node_index1); - symbol_list.push_back(node_index2); - symbol_to_anchors.insert(pair>("var1_out_0", symbol_list)); - symbol_to_anchors.insert(pair>("var2_out_0", symbol_list)); - - AnchorToSymbol anchor_to_symbol; - anchor_to_symbol.insert(pair("data1_out_0", "var1_out_0")); - anchor_to_symbol.insert(pair("data2_out_0", "var2_out_0")); - - std::string symbol; - int ret = GraphUtils::UnionSymbolMapping(node_index1, node_index2, symbol_to_anchors, anchor_to_symbol, symbol); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -/* - * data1 var1 data2 var2 - * \ | | / - * \ | / / - * func - * | - * netoutput - */ -TEST_F(UtestGraphUtils, UpdateRefMappingFailed) { - auto builder = ut::GraphBuilder("test1"); - const auto &input1 = builder.AddNode("data1", DATA, 1, 1); - const auto &var1 = builder.AddNode("var1", VARIABLEV2, 1, 1); - const auto &input2 = builder.AddNode("data2", DATA, 1, 1); - const auto &var2 = builder.AddNode("var2", VARIABLEV2, 1, 1); - const auto &func = builder.AddNode("func", PARTITIONEDCALL, 4, 1); - const auto &netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); - builder.AddDataEdge(input1, 0, func, 0); - builder.AddDataEdge(var1, 0, func, 1); - builder.AddDataEdge(input2, 0, func, 2); - builder.AddDataEdge(var2, 0, func, 3); - builder.AddDataEdge(func, 0, netoutput, 0); - auto graph = builder.GetGraph(); - - graph->SetParentNode(func); - AttrUtils::SetInt(input1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); - AttrUtils::SetInt(input2->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1); - - SymbolToAnchors symbol_to_anchors; - NodeIndexIO cur_node_info(input1, 0, kOut); - NodeIndexIO exist_node_info(input2, 0, kOut); - std::list symbol_list; - symbol_list.push_back(cur_node_info); - symbol_list.push_back(exist_node_info); - symbol_to_anchors.insert(pair>("var1_out_0", symbol_list)); - symbol_to_anchors.insert(pair>("var2_out_0", symbol_list)); - - AnchorToSymbol anchor_to_symbol; - anchor_to_symbol.insert(pair("data1_out_0", "var1_out_0")); - anchor_to_symbol.insert(pair("data2_out_0", "var2_out_0")); - - std::string symbol; - int ret = GraphUtils::UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol); - EXPECT_NE(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, UpdateRefMappingSymbolToAnchorsIsNull) { - auto builder = ut::GraphBuilder("test1"); - const auto &input1 = builder.AddNode("data1", DATA, 1, 1); - const auto &var1 = builder.AddNode("var1", VARIABLEV2, 1, 1); - const auto &input2 = builder.AddNode("data2", DATA, 1, 1); - const auto &var2 = builder.AddNode("var2", VARIABLEV2, 1, 1); - const auto &func = builder.AddNode("func", PARTITIONEDCALL, 4, 1); - const auto &netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); - builder.AddDataEdge(input1, 0, func, 0); - builder.AddDataEdge(var1, 0, func, 1); - builder.AddDataEdge(input2, 0, func, 2); - builder.AddDataEdge(var2, 0, func, 3); - builder.AddDataEdge(func, 0, netoutput, 0); - auto graph = builder.GetGraph(); - - graph->SetParentNode(func); - AttrUtils::SetInt(input1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); - AttrUtils::SetInt(input2->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1); - - NodeIndexIO cur_node_info(input1, 0, kOut); - NodeIndexIO exist_node_info(input2, 0, kOut); - SymbolToAnchors symbol_to_anchors; - AnchorToSymbol anchor_to_symbol; - anchor_to_symbol.insert(pair("data1_out_0", "var1_out_0")); - anchor_to_symbol.insert(pair("data2_out_0", "var2_out_0")); - - std::string symbol; - int ret = GraphUtils::UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, IsRefFromInputOutDataAnchorPtrIsNull) { - OutDataAnchorPtr out_data_anchor; - int32_t reuse_in_index; - bool ret = GraphUtils::IsRefFromInput(out_data_anchor, reuse_in_index); - EXPECT_EQ(ret, false); -} - -TEST_F(UtestGraphUtils, IsRefFromInputFail) { - auto builder = ut::GraphBuilder("test0"); - const auto &node0 = builder.AddNode("node0", "node", 1, 1); - int32_t reuse_in_index; - bool ret = GraphUtils::IsRefFromInput(node0->GetOutDataAnchor(0), reuse_in_index); - EXPECT_EQ(ret, false); -} - -TEST_F(UtestGraphUtils, IsRefFromInputPassThroughOK) { - auto builder = ut::GraphBuilder("test0"); - const auto &node0 = builder.AddNode("node0", NETOUTPUT, 1, 1); - int32_t reuse_in_index; - bool ret = GraphUtils::IsRefFromInput(node0->GetOutDataAnchor(0), reuse_in_index); - EXPECT_EQ(ret, true); -} - -TEST_F(UtestGraphUtils, IsRefFromInputTypeIsMergeSuccess) { - auto builder = ut::GraphBuilder("test0"); - const auto &node0 = builder.AddNode("node0", MERGE, 1, 1); - int32_t reuse_in_index; - bool ret = GraphUtils::IsRefFromInput(node0->GetOutDataAnchor(0), reuse_in_index); - EXPECT_EQ(ret, true); -} - -TEST_F(UtestGraphUtils, IsRefFromInputTypeIsReshapeSuccess) { - auto builder = ut::GraphBuilder("test0"); - const auto &node0 = builder.AddNode("node0", RESHAPE, 1, 1); - int32_t reuse_in_index; - bool ret = GraphUtils::IsRefFromInput(node0->GetOutDataAnchor(0), reuse_in_index); - EXPECT_EQ(ret, true); - EXPECT_EQ(reuse_in_index, 0); -} - -TEST_F(UtestGraphUtils, IsRefFromInputRefOpFail) { - auto builder = ut::GraphBuilder("test0"); - const auto &node1 = builder.AddNode("node", "node", 1, 1); - AttrUtils::SetBool(node1->GetOpDesc(), ATTR_NAME_REFERENCE, true); - - int32_t reuse_in_index; - bool ret = GraphUtils::IsRefFromInput(node1->GetOutDataAnchor(0), reuse_in_index); - EXPECT_EQ(ret, false); -} - -TEST_F(UtestGraphUtils, IsNoPaddingRefFromInputSuccess) { - auto builder = ut::GraphBuilder("test0"); - const auto &node1 = builder.AddNode("node", "node", 1, 1); - AttrUtils::SetBool(node1->GetOpDesc(), ATTR_NAME_NOPADDING_CONTINUOUS_INPUT, true); - AttrUtils::SetBool(node1->GetOpDesc(), ATTR_NAME_NOPADDING_CONTINUOUS_OUTPUT, true); - AttrUtils::SetBool(node1->GetOpDesc(), ATTR_NAME_OUTPUT_REUSE_INPUT, true); - - int32_t reuse_in_index; - bool ret = GraphUtils::IsNoPaddingRefFromInput(node1->GetOutDataAnchor(0), reuse_in_index); - EXPECT_EQ(ret, true); -} - -TEST_F(UtestGraphUtils, IsNodeInGraphRecursivelySuccess) { - ComputeGraphPtr graph = std::make_shared("test0"); - Node node; - node.SetOwnerComputeGraph(graph); - - bool ret = GraphUtils::IsNodeInGraphRecursively(graph, node); - EXPECT_EQ(ret, true); -} - -TEST_F(UtestGraphUtils, IsNodeInGraphRecursivelyFail) { - auto builder = ut::GraphBuilder("test0"); - Node node; - node.SetOwnerComputeGraph(builder.GetGraph()); - ComputeGraphPtr graph = std::make_shared("test1"); - bool ret = GraphUtils::IsNodeInGraphRecursively(graph, node); - EXPECT_EQ(ret, false); -} - -TEST_F(UtestGraphUtils, IsUnknownShapeGraphFail) { - ComputeGraphPtr graph = std::make_shared("test1"); - bool ret = GraphUtils::IsUnknownShapeGraph(graph); - EXPECT_EQ(ret, false); -} - -TEST_F(UtestGraphUtils, IsUnknownShapeGraphGraphIsNull) { - ComputeGraphPtr graph; - bool ret = GraphUtils::IsUnknownShapeGraph(graph); - EXPECT_EQ(ret, false); -} - -TEST_F(UtestGraphUtils, IsUnknownShapeGraphSuccess) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data = builder.AddNode("add", "Add", 2, 1, FORMAT_NHWC, DT_FLOAT, {16, 228, 228, 3}); - auto graph = builder.GetGraph(); - - auto add_node = graph->FindNode("add"); - auto out_desc = add_node->GetOpDesc()->MutableOutputDesc(0); - out_desc->SetShape(GeShape({-1, 228, 228, 3})); - - bool ret = GraphUtils::IsUnknownShapeGraph(graph); - EXPECT_EQ(ret, true); -} - -TEST_F(UtestGraphUtils, UnfoldSubgraphSuccess) { - ut::GraphBuilder builder = ut::GraphBuilder("test0"); - auto graph = builder.GetGraph(); - std::function filter; - int ret = GraphUtils::UnfoldSubgraph(graph, filter); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, MergeInputNodesFail) { - auto builder = ut::GraphBuilder("test0"); - const auto &node1 = builder.AddNode("node", DATA, 1, 1); - auto graph = builder.GetGraph(); - graph->SetParentNode(node1); - - int ret = GraphUtils::MergeInputNodes(graph, node1); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, MergeInputNodesWithIndex) { - auto builder = ut::GraphBuilder("test0"); - const auto &node1 = builder.AddNode("node", DATA, 1, 1); - AttrUtils::SetInt(node1->GetOpDesc(), "index", 0); - const auto &abs = builder.AddNode("abs_in_subgraph", "abs", 1, 0); - builder.AddDataEdge(node1, 0, abs, 0); - auto graph = builder.GetGraph(); - auto builder1 = ut::GraphBuilder("parent_of_test0"); - const auto &data0 = builder1.AddNode("data", DATA, 1, 1); - const auto &relu = builder1.AddNode("target_node", "relu", 1, 1); - builder1.AddDataEdge(data0, 0, relu, 0); - graph->SetParentNode(relu); - - int ret = GraphUtils::MergeInputNodes(graph, relu); - EXPECT_EQ(ret, SUCCESS); - EXPECT_TRUE(data0->GetOutDataNodes().size() == 1U); - EXPECT_TRUE((*data0->GetOutDataNodes().begin())->GetName() == "abs_in_subgraph"); -} - -TEST_F(UtestGraphUtils, MergeNetOutputNodeSuccess) { - auto builder = ut::GraphBuilder("test2"); - const auto &node1 = builder.AddNode("node", DATA, 1, 1); - auto graph = builder.GetGraph(); - graph->SetParentNode(node1); - - int ret = GraphUtils::MergeNetOutputNode(graph, node1); - EXPECT_EQ(ret, SUCCESS); -} - -TEST_F(UtestGraphUtils, RemoveJustNodeGraphImplIsNull) { - ComputeGraph compute_graph(""); - compute_graph.impl_ = nullptr; - auto graph_builder0 = ut::GraphBuilder("Test0"); - const auto &node0 = graph_builder0.AddNode("data0", DATA, 1, 1); - int ret = GraphUtils::RemoveJustNode(compute_graph, node0); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestGraphUtils, RemoveJustNodes) { - auto graph_builder0 = ut::GraphBuilder("Test0"); - const auto &node0 = graph_builder0.AddNode("data0", DATA, 1, 1); - const auto &node1 = graph_builder0.AddNode("data1", DATA, 1, 1); - const auto &node2 = graph_builder0.AddNode("data2", DATA, 1, 1); - EXPECT_EQ(graph_builder0.GetGraph()->GetDirectNodesSize(), 3U); - std::unordered_set remove_nodes; - remove_nodes.insert(node0); - remove_nodes.insert(node1); - EXPECT_EQ(GraphUtils::RemoveJustNodes(graph_builder0.GetGraph(), remove_nodes), GRAPH_SUCCESS); - EXPECT_EQ(graph_builder0.GetGraph()->GetDirectNodesSize(), 1U); - // remove nodes not in graph, also return success - EXPECT_EQ(GraphUtils::RemoveJustNodes(graph_builder0.GetGraph(), remove_nodes), GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, GetNodeFail) { - UtestComputeGraphBuilder graph; - NodePtr node_ptr = graph.GetNode("node1"); - EXPECT_EQ(node_ptr, nullptr); -} - -TEST_F(UtestGraphUtils, GetAllNodeNodeSizeIs0) { - UtestComputeGraphBuilder graph; - std::vector node_ptr = graph.GetAllNodes(); - EXPECT_EQ(node_ptr.size(), 0); -} - -TEST_F(UtestGraphUtils, BuildExistNodesTest) { - PartialGraphBuilder builder; - graphStatus err = GRAPH_SUCCESS; - std::string msg = ""; - builder.BuildExistNodes(err, msg); - EXPECT_TRUE(err == GRAPH_SUCCESS); - EXPECT_EQ(msg, ""); - - builder.exist_nodes_.push_back(nullptr); - builder.BuildExistNodes(err, msg); - EXPECT_EQ(err, GRAPH_FAILED); - EXPECT_NE(msg, ""); - - builder.exist_nodes_.clear(); - auto gbuilder = ut::GraphBuilder("test2"); - auto node = gbuilder.AddNode("node", DATA, 1, 1); - auto opdsc = std::make_shared("node1", "node"); - builder.AddExistNode(node); - builder.AddNode(opdsc); - EXPECT_EQ(builder.exist_nodes_.size(), 1); - builder.BuildExistNodes(err, msg); - EXPECT_TRUE(err == GRAPH_FAILED); - EXPECT_NE(msg, ""); - - err = GRAPH_SUCCESS; - msg = ""; - builder.owner_graph_ = node->GetOwnerComputeGraph(); - builder.BuildExistNodes(err, msg); - EXPECT_TRUE(err == GRAPH_SUCCESS); - EXPECT_EQ(msg, ""); -} - -TEST_F(UtestGraphUtils, PartialGraphBuilderBuildTest) { - PartialGraphBuilder par_graph_builder; - graphStatus err = GRAPH_SUCCESS; - std::string msg = ""; - ComputeGraphPtr computer_graph; - computer_graph = par_graph_builder.Build(err, msg); - EXPECT_EQ(err, GRAPH_FAILED); - EXPECT_EQ(msg, "graph is NULL."); - EXPECT_EQ(computer_graph, nullptr); - - auto builder = ut::GraphBuilder("test1"); - auto node = builder.AddNode("node", DATA, 1, 1); - par_graph_builder.SetOwnerGraph(node->GetOwnerComputeGraph()); - computer_graph = par_graph_builder.Build(err, msg); - EXPECT_EQ(err, GRAPH_FAILED); - EXPECT_EQ(msg, "graph is NULL."); - EXPECT_EQ(computer_graph, nullptr); -} - -TEST_F(UtestGraphUtils, CompleteGraphBuilderBuilder) { - CompleteGraphBuilder complete_builder(""); - graphStatus err = GRAPH_SUCCESS; - std::string msg = ""; - - complete_builder.Build(err, msg); - EXPECT_TRUE(err == GRAPH_SUCCESS); - EXPECT_EQ(msg, ""); -} - -TEST_F(UtestGraphUtils, CompleteGraphBuilderBuildGraphTargets) { - CompleteGraphBuilder complete_builder("test1"); - graphStatus err = GRAPH_SUCCESS; - std::string msg = ""; - - //node_names_ is null - complete_builder.AddTarget("Data_1"); - complete_builder.BuildGraphTargets(err, msg); - EXPECT_EQ(err, GRAPH_FAILED); - EXPECT_NE(msg, ""); -} - -TEST_F(UtestGraphUtils, BuildNetOutputNodeWithLinkTest) { - CompleteGraphBuilder complete_builder("test1"); - graphStatus err = GRAPH_SUCCESS; - std::string msg = ""; - auto builder = ut::GraphBuilder("test2"); - auto node = builder.AddNode("node", DATA, 1, 1); - auto node2 = builder.AddNode("node2", NETOUTPUT, 1, 0); - complete_builder.owner_graph_ = node->GetOwnerComputeGraph(); - - OpDescPtr net_output_desc; - std::vector peer_out_anchors; - complete_builder.BuildNetOutputNodeWithLink(net_output_desc, peer_out_anchors, err, msg); - EXPECT_EQ(err, GRAPH_FAILED); - EXPECT_NE(msg, ""); - - err = GRAPH_SUCCESS; - msg = ""; - net_output_desc = std::make_shared("test", "test"); - complete_builder.AddTarget("Data_1"); - complete_builder.BuildNetOutputNodeWithLink(net_output_desc, peer_out_anchors, err, msg); - EXPECT_EQ(err, GRAPH_FAILED); - EXPECT_NE(msg, ""); - - err = GRAPH_SUCCESS; - msg = ""; - uint32_t index = 1; - complete_builder.input_mapping_.insert(pair(1, 0)); - auto ret_node = complete_builder.AddDataNode(index, err, msg); - EXPECT_EQ(ret_node, complete_builder.node_names_["Data_1"]); - complete_builder.BuildNetOutputNodeWithLink(net_output_desc, peer_out_anchors, err, msg); - EXPECT_TRUE(err == GRAPH_SUCCESS); - EXPECT_EQ(msg, ""); -} - -TEST_F(UtestGraphUtils, AddDataNodeTest) { - CompleteGraphBuilder complete_builder("test1"); - graphStatus err = GRAPH_SUCCESS; - std::string msg = ""; - - auto builder = ut::GraphBuilder("test2"); - auto node = builder.AddNode("node", DATA, 1, 1); - - uint32_t index = 1; - complete_builder.input_mapping_.insert(pair(1, 1)); - complete_builder.owner_graph_ = node->GetOwnerComputeGraph(); - - auto ret_node = complete_builder.AddDataNode(index, err, msg); - EXPECT_TRUE(err == GRAPH_SUCCESS); - EXPECT_EQ(msg, ""); - EXPECT_EQ(ret_node, complete_builder.node_names_["Data_1"]); -} - -TEST_F(UtestGraphUtils, AddNetOutputNodeTest) { - CompleteGraphBuilder complete_builder("test1"); - graphStatus err = GRAPH_SUCCESS; - std::string msg = ""; - - // graph_outputs_ and graph_targets_ is null - complete_builder.AddNetOutputNode(err, msg); - EXPECT_TRUE(err == GRAPH_SUCCESS); - EXPECT_EQ(msg, ""); - - // node_names_ is null - complete_builder.AddTarget("out"); - complete_builder.graph_outputs_.push_back(pair("out", 0)); - complete_builder.AddNetOutputNode(err, msg); - EXPECT_EQ(err, GRAPH_FAILED); - EXPECT_NE(msg, ""); - - // node is nullptr - err = GRAPH_SUCCESS; - msg = ""; - complete_builder.node_names_.insert(pair("out", nullptr)); - complete_builder.AddNetOutputNode(err, msg); - EXPECT_EQ(err, GRAPH_FAILED); - EXPECT_EQ(msg, "AddNetOutputNode failed: node is NULL."); - err = GRAPH_SUCCESS; - msg = ""; - auto compute_graph = ge::ComGraphMakeSharedAndThrow("test"); - complete_builder.owner_graph_ = compute_graph; - auto data_node = compute_graph->AddNode(OpDescBuilder("out", "Relu").AddInput("x").AddOutput("y").Build()); - complete_builder.node_names_["out"] = data_node; - complete_builder.output_mapping_.emplace(0, 0); - complete_builder.AddNetOutputNode(err, msg); - EXPECT_EQ(err, GRAPH_SUCCESS); - EXPECT_EQ(msg, ""); -} - -TEST_F(UtestGraphUtils, AddRetValNodesTest) { - CompleteGraphBuilder complete_builder("test1"); - graphStatus err = GRAPH_SUCCESS; - std::string msg = ""; - - //node_names_ is null - complete_builder.graph_outputs_.push_back(pair("Data_1", 0)); - complete_builder.AddRetValNodes(err, msg); - EXPECT_EQ(err, GRAPH_FAILED); - EXPECT_EQ(msg, "AddRetValNode failed: node Data_1 does not exist in graph."); - - //node_names_ node is nullptr - err = GRAPH_SUCCESS; - msg = ""; - complete_builder.node_names_.insert(pair("Data_1", nullptr)); - complete_builder.AddRetValNodes(err, msg); - EXPECT_EQ(err, GRAPH_FAILED); - EXPECT_EQ(msg, "AddRetValNode failed: node is NULL."); - - //node_names_ node is not nullptr - auto builder = ut::GraphBuilder("test2"); - auto node = builder.AddNode("node", DATA, 1, 0); - complete_builder.owner_graph_ = node->GetOwnerComputeGraph(); - - complete_builder.node_names_.clear(); - complete_builder.node_names_.insert(pair("Data_1", node)); - complete_builder.output_mapping_.insert(pair(0, 0)); - err = GRAPH_SUCCESS; - msg = ""; - complete_builder.AddRetValNodes(err, msg); - EXPECT_EQ(err, GRAPH_FAILED); - EXPECT_NE(msg, ""); -} - -TEST_F(UtestGraphUtils, BuildCtrlLinksTest) { - PartialGraphBuilder par_builder; - graphStatus err = GRAPH_SUCCESS; - std::string msg = ""; - - auto builder = ut::GraphBuilder("test1"); - auto node = builder.AddNode("node_input", DATA, 1, 1); - auto node2 = builder.AddNode("node_output", NETOUTPUT, 1, 1); - par_builder.SetOwnerGraph(node->GetOwnerComputeGraph()); - - par_builder.AddControlLink("node_input", "node_output"); - ComputeGraphPtr graph; - graph = par_builder.Build(err, msg); - EXPECT_EQ(err, GRAPH_FAILED); - EXPECT_NE(msg, ""); - EXPECT_EQ(graph, nullptr); - - par_builder.node_names_.insert(pair("node_input", nullptr)); - par_builder.node_names_.insert(pair("node_output", nullptr)); - err = GRAPH_SUCCESS; - msg = ""; - graph = par_builder.Build(err, msg); - EXPECT_EQ(err, GRAPH_FAILED); - EXPECT_NE(msg, ""); - EXPECT_EQ(graph, nullptr); - - par_builder.node_names_.clear(); - par_builder.node_names_.insert(pair("node_input", node)); - par_builder.node_names_.insert(pair("node_output", node2)); - err = GRAPH_SUCCESS; - msg = ""; - graph = par_builder.Build(err, msg); - EXPECT_TRUE(err == GRAPH_SUCCESS); - EXPECT_EQ(msg, ""); - EXPECT_EQ(graph, node->GetOwnerComputeGraph()); -} - -TEST_F(UtestGraphUtils, BuildDataLinksTest) { - PartialGraphBuilder par_builder; - graphStatus err = GRAPH_SUCCESS; - std::string msg = ""; - - auto builder = ut::GraphBuilder("test1"); - auto node = builder.AddNode("node_input", DATA, 1, 1); - auto node2 = builder.AddNode("node_output", NETOUTPUT, 1, 1); - par_builder.SetOwnerGraph(node->GetOwnerComputeGraph()); - - par_builder.AddDataLink("node_input", 1, "node_output", 1); - ComputeGraphPtr graph; - graph = par_builder.Build(err, msg); - EXPECT_EQ(err, GRAPH_FAILED); - EXPECT_NE(msg, ""); - EXPECT_EQ(graph, nullptr); - - par_builder.node_names_.insert(pair("node_input", nullptr)); - par_builder.node_names_.insert(pair("node_output", nullptr)); - err = GRAPH_SUCCESS; - msg = ""; - graph = par_builder.Build(err, msg); - EXPECT_EQ(err, GRAPH_FAILED); - EXPECT_NE(msg, ""); - EXPECT_EQ(graph, nullptr); -} - -TEST_F(UtestGraphUtils, PostProcessTest) { - CompleteGraphBuilder complete_builder("test1"); - graphStatus err = GRAPH_SUCCESS; - std::string msg = ""; - - auto builder = ut::GraphBuilder("test2"); - auto node1 = builder.AddNode("node1", DATA, 1, 1); - auto owner_graph = node1->GetOwnerComputeGraph(); - complete_builder.owner_graph_ = owner_graph; - - auto builder2 = ut::GraphBuilder("test3"); - auto node2 = builder2.AddNode("node", "node", 1, 1); - complete_builder.parent_node_ = node2; - auto parent_graph = complete_builder.parent_node_->GetOwnerComputeGraph(); - - std::string graph_id; - AttrUtils::SetStr(parent_graph, ATTR_NAME_SESSION_GRAPH_ID, graph_id); - - AnyValue any_value; - any_value.SetValue(1); - complete_builder.parent_node_->GetOwnerComputeGraph()->SetAttr(ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, any_value); - AttrUtils::SetBool(node1->GetOpDesc(), ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, true); - - complete_builder.PostProcess(err, msg); - EXPECT_EQ(err, GRAPH_FAILED); - EXPECT_EQ(msg, "Copy attr _dynamic_shape_partitioned failed."); -} - - -TEST_F(UtestGraphUtils, GetRefMappingTest) { - ComputeGraphPtr graph = std::make_shared("test0"); - auto op_desc = std::make_shared("node1", "node1"); - graph->AddNode(op_desc); - SymbolToAnchors symbol_to_anchors; - AnchorToSymbol anchor_to_symbol; - int ret = GraphUtils::GetRefMapping(graph, symbol_to_anchors, anchor_to_symbol); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestGraphUtils, ComputeGraphBuilderBuildNodesTest) { - UtestComputeGraphBuilder utest_graph_builder; - graphStatus err = GRAPH_SUCCESS; - std::string msg = ""; - - //owner_graph_ is null - utest_graph_builder.BuildNodes(err, msg); - EXPECT_EQ(err, GRAPH_FAILED); - EXPECT_EQ(msg, "graph is NULL."); - - //nodes_ is null - auto builder = ut::GraphBuilder("test1"); - auto node1 = builder.AddNode("node1", DATA, 1, 1); - auto owner_graph = node1->GetOwnerComputeGraph(); - utest_graph_builder.owner_graph_ = owner_graph; - err = GRAPH_SUCCESS; - msg = ""; - utest_graph_builder.nodes_.push_back(nullptr); - utest_graph_builder.BuildNodes(err, msg); - EXPECT_EQ(err, GRAPH_FAILED); - EXPECT_EQ(msg, "op_desc is NULL."); -} - - -TEST_F(UtestGraphUtils, FindNodeByTypeFromAllGraphs) { - auto graph = BuildGraphWithSubGraph(); - ASSERT_NE(graph, nullptr); - auto nodes = GraphUtils::FindNodesByTypeFromAllNodes(graph, "Data"); - EXPECT_EQ(nodes.size(), 3); - const auto &bare_nodes = GraphUtils::FindBareNodesByTypeFromAllNodes(graph, "Data"); - EXPECT_EQ(bare_nodes.size(), 3); - EXPECT_EQ(nodes.at(0).get(), bare_nodes.at(0)); - EXPECT_EQ(nodes.at(1).get(), bare_nodes.at(1)); - EXPECT_EQ(nodes.at(2).get(), bare_nodes.at(2)); -} - -TEST_F(UtestGraphUtils, RemoveNodesByTypeWithoutRelinkPlaceholder) { - ComputeGraphPtr graph = std::make_shared("test_placeholder"); - BuildGraphWithPlaceholderAndEnd(graph); - ASSERT_NE(graph, nullptr); - auto ret = GraphUtils::RemoveNodesByTypeWithoutRelink(graph, "PlaceHolder"); - EXPECT_EQ(ret, GRAPH_SUCCESS); - auto nodes = GraphUtils::FindNodesByTypeFromAllNodes(graph, "PlaceHolder"); - EXPECT_EQ(nodes.size(), 0); - const auto &bare_nodes = GraphUtils::FindBareNodesByTypeFromAllNodes(graph, "PlaceHolder"); - EXPECT_EQ(bare_nodes.size(), 0); -} - -TEST_F(UtestGraphUtils, RemoveNodesByTypeWithoutRelinkEnd) { - ComputeGraphPtr graph = std::make_shared("test_end"); - BuildGraphWithPlaceholderAndEnd(graph); - ASSERT_NE(graph, nullptr); - auto ret = GraphUtils::RemoveNodesByTypeWithoutRelink(graph, "End"); - EXPECT_EQ(ret, GRAPH_SUCCESS); - auto nodes = GraphUtils::FindNodesByTypeFromAllNodes(graph, "End"); - EXPECT_EQ(nodes.size(), 0); - const auto &bare_nodes = GraphUtils::FindBareNodesByTypeFromAllNodes(graph, "End"); - EXPECT_EQ(bare_nodes.size(), 0); -} - -TEST_F(UtestGraphUtils, RemoveNodesByTypeWithoutRelinkAdd) { - ComputeGraphPtr graph = std::make_shared("test_end"); - BuildGraphWithPlaceholderAndEnd(graph); - ASSERT_NE(graph, nullptr); - auto ret = GraphUtils::RemoveNodesByTypeWithoutRelink(graph, "Add"); - EXPECT_EQ(ret, GRAPH_SUCCESS); - auto nodes = GraphUtils::FindNodesByTypeFromAllNodes(graph, "Add"); - EXPECT_EQ(nodes.size(), 0); - const auto &bare_nodes = GraphUtils::FindBareNodesByTypeFromAllNodes(graph, "Add"); - EXPECT_EQ(bare_nodes.size(), 0); -} - -TEST_F(UtestGraphUtils, RemoveNodesByTypeWithoutRelinkData) { - ComputeGraphPtr graph = std::make_shared("test_end"); - BuildGraphWithPlaceholderAndEnd(graph); - ASSERT_NE(graph, nullptr); - auto ret = GraphUtils::RemoveNodesByTypeWithoutRelink(graph, DATA); - EXPECT_EQ(ret, GRAPH_SUCCESS); - auto nodes = GraphUtils::FindNodesByTypeFromAllNodes(graph, DATA); - EXPECT_EQ(nodes.size(), 0); - const auto &bare_nodes = GraphUtils::FindBareNodesByTypeFromAllNodes(graph, "Data"); - EXPECT_EQ(bare_nodes.size(), 0); -} - -TEST_F(UtestGraphUtils, FindNodeByTypeFromAllGraphsNullInput) { - ComputeGraphPtr graph = nullptr; - auto nodes = GraphUtils::FindNodesByTypeFromAllNodes(graph, "Data"); - EXPECT_EQ(nodes.size(), 0); - const auto &bare_nodes = GraphUtils::FindBareNodesByTypeFromAllNodes(graph, "Data"); - EXPECT_EQ(bare_nodes.size(), 0); -} -namespace { -void CheckAnchor(const std::list &all_anchors_of_symbol, - const std::unordered_set &expect_anchors) { - for (auto iter_e = all_anchors_of_symbol.begin(); iter_e != all_anchors_of_symbol.end(); ++iter_e) { - EXPECT_EQ(expect_anchors.count((*iter_e).ToString()), 1); - } -} - -void PrintAnchors(const SymbolToAnchors &symbol_to_anchors) { - std::stringstream ss; - for (const auto &pair : symbol_to_anchors) { - ss << pair.first << " : "; - ss << "[ "; - for (const auto &anchor : pair.second) { - ss << anchor.ToString() << "|"; - } - ss << " ]"; - } - std::cout << ss.str() << std::endl; -} -} // namespace -/* - refdata(a) const(b) - \ / - assign - |(a) - | - transdata - |(a) - | - netoutput -*/ -TEST_F(UtestGraphUtils, GetRefMappingWithRefData) { - auto builder = ut::GraphBuilder("test1"); - const auto &refdata = builder.AddNode("refdata", REFDATA, 1, 1); - const auto &const1 = builder.AddNode("const1", CONSTANT, 0, 1); - const auto &assign = builder.AddNode("assign", "Assign", 2, 1); - const auto &transdata = builder.AddNode("transdata", "TransData", 1, 1); - AttrUtils::SetStr(transdata->GetOpDesc()->MutableOutputDesc(0), REF_VAR_SRC_VAR_NAME, "refdata"); - const auto &netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); - builder.AddDataEdge(refdata, 0, assign, 0); - builder.AddDataEdge(const1, 0, assign, 1); - builder.AddDataEdge(assign, 0, transdata, 0); - builder.AddDataEdge(transdata, 0, netoutput, 0); - auto graph = builder.GetGraph(); - - SymbolToAnchors symbol_to_anchors; - AnchorToSymbol anchor_to_symbol; - int ret = GraphUtils::GetRefMapping(graph, symbol_to_anchors, anchor_to_symbol); - EXPECT_EQ(ret, GRAPH_SUCCESS); - // 当前图共5个symbol - EXPECT_EQ(symbol_to_anchors.size(), 5); - PrintAnchors(symbol_to_anchors); - // 校验transdata输出和refdata共享一个symbol - NodeIndexIO transdata_out_info(transdata, 0, kOut); - auto iter = anchor_to_symbol.find(transdata_out_info.ToString()); - EXPECT_NE(iter, anchor_to_symbol.end()); - std::string symbol_transdata = iter->second; - - NodeIndexIO refdata_info(refdata, 0, kOut); - iter = anchor_to_symbol.find(refdata_info.ToString()); - EXPECT_NE(iter, anchor_to_symbol.end()); - std::string symbol_ref_data = iter->second; - - EXPECT_STREQ(symbol_transdata.c_str(), symbol_ref_data.c_str()); - - // 校验图中refdata的symbol, 有4个tensor共享 - auto iter_a = symbol_to_anchors.find(symbol_transdata); - EXPECT_NE(iter_a, symbol_to_anchors.end()); - EXPECT_EQ(iter_a->second.size(), 4); - - NodeIndexIO assing_in_0_info(assign, 0, kIn); - NodeIndexIO netoutput_in_0_info(netoutput, 0, kIn); - std::unordered_set expect_anchors_set{refdata_info.ToString(), transdata_out_info.ToString(), - assing_in_0_info.ToString(), netoutput_in_0_info.ToString()}; - CheckAnchor(iter_a->second, expect_anchors_set); -} -/* - refdata(a) - | - identity const(b) - \ / - assign - |(a) - | - transdata - | - | - netoutput - 如果refdata和assign中间插入了identity, HandleOutAnchorMapping中,assign上带有ref_var_src_var_name值是refdata,但是同时 - assign是输出引用输入,导致符号建立错误。 -*/ -TEST_F(UtestGraphUtils, GetRefMappingWithRefData_Failed_BecauseInsertIdentity) { - auto builder = ut::GraphBuilder("test1"); - const auto &refdata = builder.AddNode("refdata", REFDATA, 0, 1); - const auto &identity = builder.AddNode("identity", IDENTITY, 1, 1); - const auto &const1 = builder.AddNode("const1", CONSTANT, 0, 1); - const auto &assign = builder.AddNode("assign", "Assign", 2, 1); - const auto &transdata = builder.AddNode("transdata", "TransData", 1, 1); - AttrUtils::SetStr(assign->GetOpDesc()->MutableOutputDesc(0), REF_VAR_SRC_VAR_NAME, "refdata"); - AttrUtils::SetBool(assign->GetOpDesc(), ATTR_NAME_REFERENCE, true); - assign->GetOpDesc()->MutableAllInputName() = {{"x", 0}}; - assign->GetOpDesc()->MutableAllOutputName() = {{"x", 0}}; - const auto &netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); - builder.AddDataEdge(refdata, 0, identity, 0); - builder.AddDataEdge(identity, 0, assign, 0); - builder.AddDataEdge(const1, 0, assign, 1); - builder.AddDataEdge(assign, 0, transdata, 0); - builder.AddDataEdge(transdata, 0, netoutput, 0); - auto graph = builder.GetGraph(); - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - - SymbolToAnchors symbol_to_anchors; - AnchorToSymbol anchor_to_symbol; - int ret = GraphUtils::GetRefMapping(graph, symbol_to_anchors, anchor_to_symbol); - EXPECT_NE(ret, GRAPH_SUCCESS); -} -/* - refdata(a) const(b) - \ / - assign - |(a) - | - transdata - | - | - netoutput -*/ -TEST_F(UtestGraphUtils, GetRefMappingWithRefData_Success) { - auto builder = ut::GraphBuilder("test1"); - const auto &refdata = builder.AddNode("refdata", REFDATA, 0, 1); - const auto &const1 = builder.AddNode("const1", CONSTANT, 0, 1); - const auto &assign = builder.AddNode("assign", "Assign", 2, 1); - const auto &transdata = builder.AddNode("transdata", "TransData", 1, 1); - AttrUtils::SetStr(assign->GetOpDesc()->MutableOutputDesc(0), REF_VAR_SRC_VAR_NAME, "refdata"); - AttrUtils::SetBool(assign->GetOpDesc(), ATTR_NAME_REFERENCE, true); - assign->GetOpDesc()->MutableAllInputName() = {{"x", 0}}; - assign->GetOpDesc()->MutableAllOutputName() = {{"x", 0}}; - const auto &netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0); - builder.AddDataEdge(refdata, 0, assign, 0); - builder.AddDataEdge(const1, 0, assign, 1); - builder.AddDataEdge(assign, 0, transdata, 0); - builder.AddDataEdge(transdata, 0, netoutput, 0); - auto graph = builder.GetGraph(); - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - - SymbolToAnchors symbol_to_anchors; - AnchorToSymbol anchor_to_symbol; - int ret = GraphUtils::GetRefMapping(graph, symbol_to_anchors, anchor_to_symbol); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} -/* - data data - \ / - merge - | - | - cast - | - | - netoutput -*/ -TEST_F(UtestGraphUtils, GetRefMappingWithMergeOp) { - auto builder = ut::GraphBuilder("test1"); - const auto &input1 = builder.AddNode("data1", DATA, 0, 1); - const auto &input2 = builder.AddNode("data2", DATA, 0, 1); - const auto &merge = builder.AddNode("merge", MERGE, 2, 1); - const auto &cast = builder.AddNode("cast", "CAST", 1, 1); - const auto &out = builder.AddNode("out", NETOUTPUT, 1, 1); - builder.AddDataEdge(input1, 0, merge, 0); - builder.AddDataEdge(input2, 0, merge, 1); - builder.AddDataEdge(merge, 0, cast, 0); - builder.AddDataEdge(cast, 0, out, 0); - auto graph = builder.GetGraph(); - - SymbolToAnchors symbol_to_anchors; - AnchorToSymbol anchor_to_symbol; - int ret = GraphUtils::GetRefMapping(graph, symbol_to_anchors, anchor_to_symbol); - EXPECT_EQ(ret, GRAPH_SUCCESS); - // 当前图共2个symbol - EXPECT_EQ(symbol_to_anchors.size(), 2); - PrintAnchors(symbol_to_anchors); - // 校验merge输出和input1,input2共享一个symbol - NodeIndexIO merge_out(merge, 0, kOut); - auto iter = anchor_to_symbol.find(merge_out.ToString()); - EXPECT_NE(iter, anchor_to_symbol.end()); - std::string symbol_merge = iter->second; - - NodeIndexIO input1_info(input1, 0, kOut); - iter = anchor_to_symbol.find(input1_info.ToString()); - EXPECT_NE(iter, anchor_to_symbol.end()); - std::string symbol_input1 = iter->second; - EXPECT_STREQ(symbol_merge.c_str(), symbol_input1.c_str()); - - NodeIndexIO input2_info(input2, 0, kOut); - iter = anchor_to_symbol.find(input2_info.ToString()); - EXPECT_NE(iter, anchor_to_symbol.end()); - std::string symbol_input2 = iter->second; - EXPECT_STREQ(symbol_merge.c_str(), symbol_input2.c_str()); -} - -TEST_F(UtestGraphUtils, GetRefMappingWithSubgraphOp) { - auto root_builder = ut::GraphBuilder("root"); - const auto &data = root_builder.AddNode("data", DATA, 0, 1); - const auto &partitioncall_0 = root_builder.AddNode("partitioncall_0", PARTITIONEDCALL, 1, 1); - const auto &out = root_builder.AddNode("out", NETOUTPUT, 1, 1); - root_builder.AddDataEdge(data, 0, partitioncall_0, 0); - root_builder.AddDataEdge(partitioncall_0, 0, out, 0); - const auto &root_graph = root_builder.GetGraph(); - - int64_t index = 0; - auto sub_builder = ut::GraphBuilder("partitioncall_0_sub"); - const auto &partitioncall_0_data = sub_builder.AddNode("partitioncall_0_data", DATA, 1, 1); - AttrUtils::SetInt(partitioncall_0_data->GetOpDesc(), "_parent_node_index", index); - const auto &partitioncall_0_cast = sub_builder.AddNode("partitioncall_0_cast", "Cast", 1, 1); - const auto &partitioncall_0_netoutput = sub_builder.AddNode("partitioncall_0_netoutput", NETOUTPUT, 1, 1); - AttrUtils::SetInt(partitioncall_0_netoutput->GetOpDesc()->MutableInputDesc(0), "_parent_node_index", index); - sub_builder.AddDataEdge(partitioncall_0_data, 0, partitioncall_0_cast, 0); - sub_builder.AddDataEdge(partitioncall_0_cast, 0, partitioncall_0_netoutput, 0); - const auto &sub_graph = sub_builder.GetGraph(); - sub_graph->SetParentNode(partitioncall_0); - sub_graph->SetParentGraph(root_graph); - root_graph->AddSubgraph("partitioncall_0_sub", sub_graph); - partitioncall_0->GetOpDesc()->AddSubgraphName("partitioncall_0_sub"); - partitioncall_0->GetOpDesc()->SetSubgraphInstanceName(0, "partitioncall_0_sub"); - NodePtr node = GraphUtils::FindNodeFromAllNodes(const_cast(root_graph), "partitioncall_0_cast"); - EXPECT_NE(node, nullptr); - node = GraphUtils::FindNodeFromAllNodes(const_cast(sub_graph), "partitioncall_0_cast"); - EXPECT_NE(node, nullptr); - SymbolToAnchors symbol_to_anchors; - AnchorToSymbol anchor_to_symbol; - int ret = GraphUtils::GetRefMapping(root_graph, symbol_to_anchors, anchor_to_symbol); - EXPECT_EQ(ret, GRAPH_SUCCESS); - // 当前图共2个symbol - EXPECT_EQ(symbol_to_anchors.size(), 2); - PrintAnchors(symbol_to_anchors); - // 校验partitioncall_0输出和partitioncall_0_cast输出,partitioncall_0_netoutput的输入输出,out的输入输出共享一个symbol - NodeIndexIO partitioncall_0_out(partitioncall_0, 0, kOut); - auto iter = anchor_to_symbol.find(partitioncall_0_out.ToString()); - EXPECT_NE(iter, anchor_to_symbol.end()); - std::string symbol_partitioncall_0_out = iter->second; - auto iter_a = symbol_to_anchors.find(symbol_partitioncall_0_out); - EXPECT_NE(iter_a, symbol_to_anchors.end()); - EXPECT_EQ(iter_a->second.size(), 6U); - std::unordered_set expect_anchors{partitioncall_0_out.ToString()}; - NodeIndexIO partitioncall_0_cast_out(partitioncall_0_cast, 0, kOut); - expect_anchors.emplace(partitioncall_0_cast_out.ToString()); - NodeIndexIO partitioncall_0_netoutput_out(partitioncall_0_netoutput, 0, kOut); - expect_anchors.emplace(partitioncall_0_netoutput_out.ToString()); - NodeIndexIO partitioncall_0_netoutput_in(partitioncall_0_netoutput, 0, kIn); - expect_anchors.emplace(partitioncall_0_netoutput_in.ToString()); - NodeIndexIO out_in(out, 0, kIn); - expect_anchors.emplace(out_in.ToString()); - NodeIndexIO out_out(out, 0, kOut); - expect_anchors.emplace(out_out.ToString()); - CheckAnchor(iter_a->second, expect_anchors); -} - -TEST_F(UtestGraphUtils, InfershapeIfNeedOk) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data = builder.AddNode("data", "Data", 1, 1, FORMAT_NHWC, DT_FLOAT, {16, 228, 228, 3}); - auto cast = builder.AddNode("cast", "Cast", 1, 1, FORMAT_NHWC, DT_FLOAT, {16, 228, 228, 3}); - auto netoutput = builder.AddNode("netoutput", "NetOutput", 1, 1, FORMAT_NHWC, DT_FLOAT, {5, 228, 228, 3}); - AttrUtils::SetBool(cast->GetOpDesc(), "isNeedInfer", true); - const auto stub_func = [](Operator &op) { return GRAPH_SUCCESS; }; - cast->GetOpDesc()->AddInferFunc(stub_func); - AttrUtils::SetInt(netoutput->GetOpDesc()->MutableInputDesc(0), ATTR_NAME_PARENT_NODE_INDEX, 0); - builder.AddDataEdge(data, 0, cast, 0); - builder.AddDataEdge(cast, 0, netoutput, 0); - auto graph = builder.GetGraph(); - - EXPECT_EQ(GraphUtilsEx::InferShapeInNeed(graph), GRAPH_SUCCESS); - std::vector expect_shape = {16, 228, 228, 3}; - EXPECT_EQ(netoutput->GetOpDesc()->GetInputDesc(0).GetShape().GetDims(), expect_shape); - int64_t parent_node_index = -1; - AttrUtils::GetInt(netoutput->GetOpDesc()->MutableInputDesc(0), ATTR_NAME_PARENT_NODE_INDEX, parent_node_index); - EXPECT_EQ(parent_node_index, 0); -} - -TEST_F(UtestGraphUtils, LoadGraph_parse_fail) { - const std::string file_name = "./test.txt"; - system(("touch " + file_name).c_str()); - ComputeGraphPtr com_graph1 = std::make_shared("GeTestGraph1"); - bool state = GraphUtils::LoadGEGraph(file_name.c_str(), *com_graph1); - ASSERT_EQ(state, false); - state = GraphUtils::LoadGEGraph(file_name.c_str(), com_graph1); - ASSERT_EQ(state, false); - state = GraphUtils::LoadGEGraph(nullptr, *com_graph1); - ASSERT_EQ(state, false); - state = GraphUtils::LoadGEGraph(nullptr, com_graph1); - ASSERT_EQ(state, false); - system(("rm -f " + file_name).c_str()); -} - -TEST_F(UtestGraphUtils, InsertNodeBeforeOpdesc) { - // build test graph - auto builder = ut::GraphBuilder("test"); - const auto &var = builder.AddNode("var", VARIABLE, 0, 1); - const auto &assign = builder.AddNode("assign", "Assign", 1, 1); - const auto &allreduce = builder.AddNode("allreduce", "HcomAllReduce", 1, 1); - const auto &atomic_clean = builder.AddNode("atomic_clean", ATOMICADDRCLEAN, 0, 0); - const auto &netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0); - // const auto &identity = builder.AddNode("identity", "Identity", 1, 1); - builder.AddDataEdge(var, 0, assign, 0); - builder.AddDataEdge(var, 0, allreduce,0); - builder.AddControlEdge(assign, allreduce); - builder.AddControlEdge(atomic_clean, allreduce); - auto graph = builder.GetGraph(); - auto identity_op_desc = ge::MakeShared("temp", "Identity"); - ASSERT_NE(identity_op_desc, nullptr); - identity_op_desc->AddInputDesc(ge::GeTensorDesc()); - identity_op_desc->AddOutputDesc(ge::GeTensorDesc()); - - std::string expect_super_kernel_scope = "_test"; - (void)AttrUtils::SetStr(allreduce->GetOpDesc(), ATTR_NAME_SUPER_KERNEL_SCOPE, expect_super_kernel_scope); - // insert identity before allreduce - auto identity_node = GraphUtils::InsertNodeBefore(allreduce->GetInDataAnchor(0), identity_op_desc, 0, 0); - ASSERT_NE(identity_node, nullptr); - std::string super_kernel_scope; - (void)AttrUtils::GetStr(identity_node->GetOpDesc(), ATTR_NAME_SUPER_KERNEL_SCOPE, super_kernel_scope); - EXPECT_EQ(super_kernel_scope, expect_super_kernel_scope); - // check assign control-in on identity - ASSERT_EQ(identity_node->GetInControlNodes().at(0)->GetName(), "assign"); - ASSERT_EQ(identity_node->GetInDataNodes().at(0)->GetName(), "var"); - // check atomicclean control-in still on allreuce - ASSERT_EQ(allreduce->GetInControlNodes().at(0)->GetName(), "atomic_clean"); - ASSERT_EQ(allreduce->GetInDataNodes().at(0)->GetName(), "temp"); -} - -TEST_F(UtestGraphUtils, InsertNodeAfterTypeIsSwitchOpDesc) { - auto graph_builder0 = ut::GraphBuilder("test_graph0"); - const auto &node0 = graph_builder0.AddNode("data0", SWITCH, 1, 1); - const auto &node1 = graph_builder0.AddNode("all_reduce", "HcomAllReduce", 1, 1); - const auto &graph0 = graph_builder0.GetGraph(); - graph_builder0.AddDataEdge(node0, 0, node1, 0); - std::vector dsts; - dsts.push_back(node1->GetInDataAnchor(0)); - auto identity_op_desc = ge::MakeShared("temp", "Identity"); - ASSERT_NE(identity_op_desc, nullptr); - identity_op_desc->AddInputDesc(ge::GeTensorDesc()); - identity_op_desc->AddOutputDesc(ge::GeTensorDesc()); - std::string expect_usr_stream_label = "_test"; - (void)AttrUtils::SetStr(node0->GetOpDesc(), public_attr::USER_STREAM_LABEL, expect_usr_stream_label); - auto identity_node = GraphUtils::InsertNodeAfter(node0->GetOutDataAnchor(0), dsts, identity_op_desc, 0, 0); - ASSERT_NE(identity_node, nullptr); - std::string usr_stream_label; - (void)AttrUtils::GetStr(identity_node->GetOpDesc(), public_attr::USER_STREAM_LABEL, usr_stream_label); - EXPECT_EQ(usr_stream_label, expect_usr_stream_label); - ASSERT_EQ(node1->GetInDataNodes().at(0)->GetName(), "temp"); - ASSERT_EQ(identity_node->GetInDataNodes().at(0)->GetName(), "data0"); -} - -TEST_F(UtestGraphUtils, Single_output_2_multi_inputs) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data1 = builder.AddNode("Data1", "Data", 0, 1); - auto data2 = builder.AddNode("Data2", "Data", 0, 1); - auto add_node = builder.AddNode("Add", "Add", 2, 1); - auto relu1 = builder.AddNode("Relu1", "Relu", 1, 1); - auto relu2 = builder.AddNode("Relu2", "Relu", 1, 1); - auto relu3 = builder.AddNode("Relu3", "Relu", 1, 1); - auto relu4 = builder.AddNode("Relu4", "Relu", 1, 1); - auto relu5 = builder.AddNode("Relu5", "Relu", 1, 1); - auto relu6 = builder.AddNode("Relu6", "Relu", 1, 1); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 3, 0); - builder.AddDataEdge(data1, 0, add_node, 0); - builder.AddDataEdge(data2, 0, add_node, 1); - builder.AddDataEdge(add_node, 0, relu1, 0); - builder.AddDataEdge(add_node, 0, relu2, 0); - builder.AddDataEdge(add_node, 0, relu3, 0); - builder.AddDataEdge(relu1, 0, netoutput, 0); - builder.AddDataEdge(relu2, 0, netoutput, 0); - builder.AddDataEdge(relu3, 0, netoutput, 0); - builder.AddControlEdge(add_node, relu4); - builder.AddControlEdge(add_node, relu5); - builder.AddControlEdge(add_node, relu6); - builder.AddControlEdge(relu4, netoutput); - builder.AddControlEdge(relu5, netoutput); - builder.AddControlEdge(relu6, netoutput); - auto graph = builder.GetGraph(); - - std::vector expected_dfs_names = - {"Data1", "Data2", "Add", "Relu6", "Relu5", "Relu4", "Relu3", "Relu2", "Relu1", "Netoutput"}; - EXPECT_EQ(graph->TopologicalSorting(), GRAPH_SUCCESS); - std::vector dfs_names; - for (auto &node : graph->GetAllNodes()) { - dfs_names.push_back(node->GetName()); - } - EXPECT_EQ(dfs_names, expected_dfs_names); - - const char_t *const kDumpGraphLevel = "DUMP_GRAPH_LEVEL"; - (void)setenv(kDumpGraphLevel, "1", 1); - const char_t *const kDumpGeGraph = "DUMP_GE_GRAPH"; - (void)setenv(kDumpGeGraph, "2", 1); - - GraphUtils::DumpGEGraph(graph, "", true, "./ge_test_graph_single_output_2_multi_inputs.txt"); - ComputeGraphPtr com_graph1 = std::make_shared("GeTestGraph1"); - bool state = GraphUtils::LoadGEGraph("./ge_test_graph_single_output_2_multi_inputs.txt", com_graph1); - EXPECT_EQ(state, true); - EXPECT_EQ(com_graph1->TopologicalSorting(), GRAPH_SUCCESS); - dfs_names.clear(); - for (auto &node : graph->GetAllNodes()) { - dfs_names.push_back(node->GetName()); - } - EXPECT_EQ(dfs_names, expected_dfs_names); - system("rm -f ./ge_test*.txt"); -} - -TEST_F(UtestGraphUtils, CanReplace_no_attr) { - auto builder = ut::GraphBuilder("test1"); - const auto &input1 = builder.AddNode("data1", DATA, 0, 1); - const auto &input2 = builder.AddNode("data2", DATA, 0, 1); - const auto &merge = builder.AddNode("merge", MERGE, 2, 1); - - auto op_desc = merge->GetOpDesc(); - op_desc->AppendIrInput("x", kIrInputRequired); - op_desc->AppendIrInput("y", kIrInputRequired); - op_desc->AppendIrOutput("z", kIrOutputRequired); - - const auto &cast = builder.AddNode("cast", "CAST", 1, 1); - const auto &out = builder.AddNode("out", NETOUTPUT, 1, 1); - builder.AddDataEdge(input1, 0, merge, 0); - builder.AddDataEdge(input2, 0, merge, 1); - builder.AddDataEdge(merge, 0, cast, 0); - builder.AddDataEdge(cast, 0, out, 0); - auto graph = builder.GetGraph(); - - // 调用GetSupportInplaceOutput接口 - std::map> inplace_index_list; - auto ret = GraphUtils::GetSupportInplaceOutput(merge, inplace_index_list); - EXPECT_EQ(ret, SUCCESS); - EXPECT_EQ(inplace_index_list.size(), 0U); -} - -TEST_F(UtestGraphUtils, CanReplace_invalid_attr) { - auto builder = ut::GraphBuilder("test1"); - const auto &input1 = builder.AddNode("data1", DATA, 0, 1); - const auto &input2 = builder.AddNode("data2", DATA, 0, 1); - const auto &merge = builder.AddNode("merge", MERGE, 2, 1); - - AttrUtils::SetListListInt(merge->GetOpDesc(), ATTR_NAME_OUTPUT_INPLACE_ABILITY, {{0, 0, 0}}); - auto op_desc = merge->GetOpDesc(); - op_desc->AppendIrInput("x", kIrInputRequired); - op_desc->AppendIrInput("y", kIrInputRequired); - op_desc->AppendIrOutput("z", kIrOutputRequired); - - const auto &cast = builder.AddNode("cast", "CAST", 1, 1); - const auto &out = builder.AddNode("out", NETOUTPUT, 1, 1); - builder.AddDataEdge(input1, 0, merge, 0); - builder.AddDataEdge(input2, 0, merge, 1); - builder.AddDataEdge(merge, 0, cast, 0); - builder.AddDataEdge(cast, 0, out, 0); - auto graph = builder.GetGraph(); - - // 调用GetSupportInplaceOutput接口 - std::map> inplace_index_list; - auto ret = GraphUtils::GetSupportInplaceOutput(merge, inplace_index_list); - EXPECT_EQ(ret, FAILED); - EXPECT_EQ(inplace_index_list.size(), 0U); -} - -TEST_F(UtestGraphUtils, CanReplace_Invalid_input_index) { - auto builder = ut::GraphBuilder("test1"); - const auto &input1 = builder.AddNode("data1", DATA, 0, 1); - const auto &input2 = builder.AddNode("data2", DATA, 0, 1); - const auto &merge = builder.AddNode("merge", MERGE, 2, 1); - // 设置属性 - AttrUtils::SetListListInt(merge->GetOpDesc(), ATTR_NAME_OUTPUT_INPLACE_ABILITY, {{0, 0}, {0, 2}}); - - const auto &cast = builder.AddNode("cast", "CAST", 1, 1); - const auto &out = builder.AddNode("out", NETOUTPUT, 1, 1); - builder.AddDataEdge(input1, 0, merge, 0); - builder.AddDataEdge(input2, 0, merge, 1); - builder.AddDataEdge(merge, 0, cast, 0); - builder.AddDataEdge(cast, 0, out, 0); - auto graph = builder.GetGraph(); - - // 调用GetSupportInplaceOutput接口 - std::map> inplace_index_list; - auto ret = GraphUtils::GetSupportInplaceOutput(merge, inplace_index_list); - EXPECT_NE(ret, SUCCESS); -} - -TEST_F(UtestGraphUtils, CanReplace_Invalid_output_index) { - auto builder = ut::GraphBuilder("test1"); - const auto &input1 = builder.AddNode("data1", DATA, 0, 1); - const auto &input2 = builder.AddNode("data2", DATA, 0, 1); - const auto &merge = builder.AddNode("merge", MERGE, 2, 1); - // 设置属性 - AttrUtils::SetListListInt(merge->GetOpDesc(), ATTR_NAME_OUTPUT_INPLACE_ABILITY, {{0, 0}, {1, 0}}); - - const auto &cast = builder.AddNode("cast", "CAST", 1, 1); - const auto &out = builder.AddNode("out", NETOUTPUT, 1, 1); - builder.AddDataEdge(input1, 0, merge, 0); - builder.AddDataEdge(input2, 0, merge, 1); - builder.AddDataEdge(merge, 0, cast, 0); - builder.AddDataEdge(cast, 0, out, 0); - auto graph = builder.GetGraph(); - - // 调用GetSupportInplaceOutput接口 - std::map> inplace_index_list; - auto ret = GraphUtils::GetSupportInplaceOutput(merge, inplace_index_list); - EXPECT_NE(ret, SUCCESS); -} - -TEST_F(UtestGraphUtils, CanReplace_Diff_streamid) { - auto builder = ut::GraphBuilder("test1"); - const auto &input1 = builder.AddNode("data1", DATA, 0, 2); - const auto &input2 = builder.AddNode("data2", DATA, 0, 1); - const auto &merge = builder.AddNode("merge", MERGE, 2, 1); - const auto &atomicclean = builder.AddNode("atomicclean", ATOMICADDRCLEAN, 1, 0); - atomicclean->GetOpDesc()->SetStreamId(10); - // 设置属性 - AttrUtils::SetListListInt(merge->GetOpDesc(), ATTR_NAME_OUTPUT_INPLACE_ABILITY, {{0, 0}, {0, 1}}); - auto op_desc = merge->GetOpDesc(); - op_desc->AppendIrInput("x", kIrInputRequired); - op_desc->AppendIrInput("y", kIrInputRequired); - op_desc->AppendIrOutput("z", kIrOutputRequired); - op_desc->SetStreamId(1); - - const auto &cast = builder.AddNode("cast", "CAST", 1, 1); - const auto &out = builder.AddNode("out", NETOUTPUT, 1, 1); - builder.AddDataEdge(input1, 0, merge, 0); - builder.AddDataEdge(input1, 1, atomicclean, 0); - builder.AddDataEdge(input2, 0, merge, 1); - builder.AddDataEdge(merge, 0, cast, 0); - builder.AddDataEdge(cast, 0, out, 0); - auto graph = builder.GetGraph(); - - // 调用GetSupportInplaceOutput接口 - std::map> inplace_index_list; - auto ret = GraphUtils::GetSupportInplaceOutput(merge, inplace_index_list); - EXPECT_EQ(ret, SUCCESS); - EXPECT_EQ(inplace_index_list.size(), 0U); -} - -TEST_F(UtestGraphUtils, CanReplace_Invalid_streamid) { - auto builder = ut::GraphBuilder("test1"); - const auto &input1 = builder.AddNode("data1", DATA, 0, 2); - const auto &input2 = builder.AddNode("data2", DATA, 0, 1); - const auto &merge = builder.AddNode("merge", MERGE, 2, 1); - const auto &atomicclean = builder.AddNode("atomicclean", ATOMICADDRCLEAN, 1, 0); - atomicclean->GetOpDesc()->SetStreamId(-1); - // 设置属性 - AttrUtils::SetListListInt(merge->GetOpDesc(), ATTR_NAME_OUTPUT_INPLACE_ABILITY, {{0, 0}, {0, 1}}); - auto op_desc = merge->GetOpDesc(); - op_desc->AppendIrInput("x", kIrInputRequired); - op_desc->AppendIrInput("y", kIrInputRequired); - op_desc->AppendIrOutput("z", kIrOutputRequired); - op_desc->SetStreamId(1); - - const auto &cast = builder.AddNode("cast", "CAST", 1, 1); - const auto &out = builder.AddNode("out", NETOUTPUT, 1, 1); - builder.AddDataEdge(input1, 0, merge, 0); - builder.AddDataEdge(input1, 1, atomicclean, 0); - builder.AddDataEdge(input2, 0, merge, 1); - builder.AddDataEdge(merge, 0, cast, 0); - builder.AddDataEdge(cast, 0, out, 0); - auto graph = builder.GetGraph(); - - // 调用GetSupportInplaceOutput接口 - std::map> inplace_index_list; - auto ret = GraphUtils::GetSupportInplaceOutput(merge, inplace_index_list); - EXPECT_EQ(ret, SUCCESS); - EXPECT_EQ(inplace_index_list.size(), 0U); -} - -TEST_F(UtestGraphUtils, CanReplace_special_streamid) { - auto builder = ut::GraphBuilder("test1"); - const auto &input1 = builder.AddNode("data1", DATA, 0, 2); - const auto &input2 = builder.AddNode("data2", DATA, 0, 1); - const auto &merge = builder.AddNode("merge", MERGE, 2, 1); - const auto &atomicclean = builder.AddNode("atomicclean", ATOMICADDRCLEAN, 1, 0); - input1->GetOpDesc()->SetStreamId(-1); - input2->GetOpDesc()->SetStreamId(-1); - atomicclean->GetOpDesc()->SetStreamId(-1); - // 设置属性 - AttrUtils::SetListListInt(merge->GetOpDesc(), ATTR_NAME_OUTPUT_INPLACE_ABILITY, {{0, 0}, {0, 1}}); - auto op_desc = merge->GetOpDesc(); - op_desc->AppendIrInput("x", kIrInputRequired); - op_desc->AppendIrInput("y", kIrInputRequired); - op_desc->AppendIrOutput("z", kIrOutputRequired); - op_desc->SetStreamId(-1); - - const auto &cast = builder.AddNode("cast", "CAST", 1, 1); - const auto &out = builder.AddNode("out", NETOUTPUT, 1, 1); - builder.AddDataEdge(input1, 0, merge, 0); - builder.AddDataEdge(input1, 1, atomicclean, 0); - builder.AddDataEdge(input2, 0, merge, 1); - builder.AddDataEdge(merge, 0, cast, 0); - builder.AddDataEdge(cast, 0, out, 0); - auto graph = builder.GetGraph(); - - // 调用GetSupportInplaceOutput接口 - std::map> inplace_index_list; - auto ret = GraphUtils::GetSupportInplaceOutput(merge, inplace_index_list); - EXPECT_EQ(ret, SUCCESS); - EXPECT_EQ(inplace_index_list.size(), 1U); - EXPECT_EQ(inplace_index_list[0].size(), 2U); -} - -TEST_F(UtestGraphUtils, CanReplace_Not_max_topid) { - auto builder = ut::GraphBuilder("test1"); - const auto &input1 = builder.AddNode("data1", DATA, 0, 2); - const auto &input2 = builder.AddNode("data2", DATA, 0, 1); - const auto &merge = builder.AddNode("merge", MERGE, 2, 1); - const auto &atomicclean = builder.AddNode("atomicclean", ATOMICADDRCLEAN, 1, 1); - - auto op_desc = atomicclean->GetOpDesc(); - AttrUtils::SetListListInt(op_desc, ATTR_NAME_OUTPUT_INPLACE_ABILITY, {{0, 0}}); - op_desc->AppendIrInput("x", kIrInputRequired); - op_desc->AppendIrOutput("z", kIrOutputRequired); - - const auto &cast = builder.AddNode("cast", "CAST", 1, 1); - const auto &out = builder.AddNode("out", NETOUTPUT, 1, 1); - builder.AddDataEdge(input1, 0, merge, 0); - builder.AddDataEdge(input1, 1, atomicclean, 0); - builder.AddDataEdge(input2, 0, merge, 1); - builder.AddDataEdge(merge, 0, cast, 0); - builder.AddDataEdge(cast, 0, out, 0); - auto graph = builder.GetGraph(); - - // 调用GetSupportInplaceOutput接口 - std::map> inplace_index_list; - auto ret = GraphUtils::GetSupportInplaceOutput(atomicclean, inplace_index_list); - EXPECT_EQ(ret, SUCCESS); - EXPECT_EQ(inplace_index_list.size(), 0U); -} - -TEST_F(UtestGraphUtils, CanReplace_Out_node_has_ref_attr) { - auto builder = ut::GraphBuilder("test1"); - const auto &input1 = builder.AddNode("data1", DATA, 0, 2); - const auto &merge = builder.AddNode("merge", MERGE, 2, 1); - const auto &atomicclean = builder.AddNode("atomicclean", ATOMICADDRCLEAN, 1, 1); - - AttrUtils::SetBool(atomicclean->GetOpDesc(), ATTR_NAME_REFERENCE, true); - - // 设置属性 - auto op_desc = merge->GetOpDesc(); - AttrUtils::SetListListInt(op_desc, ATTR_NAME_OUTPUT_INPLACE_ABILITY, {{0, 0}}); - op_desc->AppendIrInput("x", kIrInputRequired); - op_desc->AppendIrOutput("z", kIrOutputRequired); - - const auto &cast = builder.AddNode("cast", "CAST", 1, 1); - const auto &out = builder.AddNode("out", NETOUTPUT, 1, 1); - builder.AddDataEdge(input1, 0, merge, 0); - builder.AddDataEdge(input1, 1, atomicclean, 0); - builder.AddDataEdge(merge, 0, cast, 0); - builder.AddDataEdge(cast, 0, out, 0); - auto graph = builder.GetGraph(); - - // 调用GetSupportInplaceOutput接口 - std::map> inplace_index_list; - auto ret = GraphUtils::GetSupportInplaceOutput(merge, inplace_index_list); - EXPECT_EQ(ret, SUCCESS); - EXPECT_EQ(inplace_index_list.size(), 0U); -} - -TEST_F(UtestGraphUtils, CanReplace_fusion_node) { - auto builder = ut::GraphBuilder("test1"); - const auto &input1 = builder.AddNode("data1", DATA, 0, 1); - const auto &input2 = builder.AddNode("data2", DATA, 0, 1); - const auto &merge = builder.AddNode("merge", MERGE, 2, 1); - // 设置属性 - AttrUtils::SetListListInt(merge->GetOpDesc(), ATTR_NAME_OUTPUT_INPLACE_ABILITY, {{0, 0}}); - AttrUtils::SetGraph(merge->GetOpDesc(), "_original_fusion_graph", builder.GetGraph()); - - const auto &cast = builder.AddNode("cast", "CAST", 1, 1); - const auto &out = builder.AddNode("out", NETOUTPUT, 1, 1); - builder.AddDataEdge(input1, 0, merge, 0); - builder.AddDataEdge(input2, 0, merge, 1); - builder.AddDataEdge(merge, 0, cast, 0); - builder.AddDataEdge(cast, 0, out, 0); - auto graph = builder.GetGraph(); - - // 调用GetSupportInplaceOutput接口 - std::map> inplace_index_list; - auto ret = GraphUtils::GetSupportInplaceOutput(merge, inplace_index_list); - EXPECT_EQ(ret, SUCCESS); - EXPECT_EQ(inplace_index_list.size(), 1U); - EXPECT_EQ(inplace_index_list[0].size(), 1U); - EXPECT_EQ(inplace_index_list[0][0], 0U); -} - -/* - data data - \ / - merge - | - | - cast - | - | - netoutput -*/ -TEST_F(UtestGraphUtils, CanReplace_Success) { - auto builder = ut::GraphBuilder("test1"); - const auto &input1 = builder.AddNode("data1", DATA, 0, 1); - const auto &input2 = builder.AddNode("data2", DATA, 0, 1); - const auto &merge = builder.AddNode("merge", MERGE, 2, 1); - // 设置属性 - AttrUtils::SetListListInt(merge->GetOpDesc(), ATTR_NAME_OUTPUT_INPLACE_ABILITY, {{0, 0}}); - auto op_desc = merge->GetOpDesc(); - op_desc->AppendIrInput("x", kIrInputRequired); - op_desc->AppendIrInput("y", kIrInputRequired); - op_desc->AppendIrOutput("z", kIrOutputRequired); - - const auto &cast = builder.AddNode("cast", "CAST", 1, 1); - const auto &out = builder.AddNode("out", NETOUTPUT, 1, 1); - builder.AddDataEdge(input1, 0, merge, 0); - builder.AddDataEdge(input2, 0, merge, 1); - builder.AddDataEdge(merge, 0, cast, 0); - builder.AddDataEdge(cast, 0, out, 0); - auto graph = builder.GetGraph(); - - // 调用GetSupportInplaceOutput接口 - std::map> inplace_index_list; - auto ret = GraphUtils::GetSupportInplaceOutput(merge, inplace_index_list); - EXPECT_EQ(ret, SUCCESS); - EXPECT_EQ(inplace_index_list.size(), 1U); - EXPECT_EQ(inplace_index_list[0].size(), 1U); - EXPECT_EQ(inplace_index_list[0][0], 0U); -} - -TEST_F(UtestGraphUtils, CanReplace_2inputs_Success) { - auto builder = ut::GraphBuilder("test1"); - const auto &input1 = builder.AddNode("data1", DATA, 0, 1); - const auto &input2 = builder.AddNode("data2", DATA, 0, 1); - const auto &merge = builder.AddNode("merge", MERGE, 2, 1); - // 设置属性 - AttrUtils::SetListListInt(merge->GetOpDesc(), ATTR_NAME_OUTPUT_INPLACE_ABILITY, {{0, 0}, {0, 1}}); - auto op_desc = merge->GetOpDesc(); - op_desc->AppendIrInput("x", kIrInputRequired); - op_desc->AppendIrInput("y", kIrInputRequired); - op_desc->AppendIrOutput("z", kIrOutputRequired); - - const auto &cast = builder.AddNode("cast", "CAST", 1, 1); - const auto &out = builder.AddNode("out", NETOUTPUT, 1, 1); - builder.AddDataEdge(input1, 0, merge, 0); - builder.AddDataEdge(input2, 0, merge, 1); - builder.AddDataEdge(merge, 0, cast, 0); - builder.AddDataEdge(cast, 0, out, 0); - auto graph = builder.GetGraph(); - - // 调用GetSupportInplaceOutput接口 - std::map> inplace_index_list; - auto ret = GraphUtils::GetSupportInplaceOutput(merge, inplace_index_list); - EXPECT_EQ(ret, SUCCESS); - EXPECT_EQ(inplace_index_list.size(), 1U); - EXPECT_EQ(inplace_index_list[0].size(), 2U); - EXPECT_EQ(inplace_index_list[0][0], 0U); - EXPECT_EQ(inplace_index_list[0][1], 1U); -} - -/* - Data Data - | | - - Relu Relu Data Data - | | | | / | - | Cast0 Cast Add0 --> Add Relu - | \ / \ / - |----> Add ---- Relu Add - / \ | - Cast <- Add1 ---- - -*/ -TEST_F(UtestGraphUtils, TestExpandNodeWithGraphControlEdge) { - auto builder = ut::GraphBuilder("test_expand_node_with_graph"); - const auto &data0 = builder.AddNode("data0", DATA, 0, 1); - const auto &data1 = builder.AddNode("data1", DATA, 0, 1); - const auto &relu0 = builder.AddNode("relu0", "Relu", 1, 1); - const auto &relu1 = builder.AddNode("relu1", "Relu", 1, 1); - const auto &cast0 = builder.AddNode("cast0", "Cast", 1, 1); - const auto &cast1 = builder.AddNode("cast1", "Cast", 1, 1); - const auto &add0 = builder.AddNode("add0", "Add", 2, 1); - const auto &relu2 = builder.AddNode("relu2", "Relu", 1, 1); - const auto &cast2 = builder.AddNode("cast2", "Cast", 1, 1); - const auto &add1 = builder.AddNode("add1", "Add", 2, 1); - // 设置属性 - AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_INDEX, 0); - AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1); - - builder.AddDataEdge(data0, 0, relu0, 0); - builder.AddDataEdge(data1, 0, relu1, 0); - builder.AddDataEdge(relu0, 0, cast0, 0); - builder.AddDataEdge(relu1, 0, cast1, 0); - builder.AddDataEdge(cast0, 0, add0, 0); - builder.AddDataEdge(cast1, 0, add0, 1); - builder.AddDataEdge(add0, 0, relu2, 0); - builder.AddDataEdge(add0, 0, add1, 0); - builder.AddDataEdge(relu2, 0, add1, 1); - builder.AddControlEdge(relu0, add0); - builder.AddControlEdge(add0, cast2); - auto graph = builder.GetGraph(); - std::vector> output_nodes{{cast2, 0}, {add1, 0}}; - graph->SetOutputSize(2U); - graph->SetGraphOutNodesInfo(output_nodes); - - std::vector origin_node_sort; - for (const auto &node : graph->GetDirectNode()) { - std::cout << "origin node: " << node->GetName() << std::endl; - origin_node_sort.emplace_back(node->GetName()); - } - - auto sub_builder = ut::GraphBuilder("subgraph"); - const auto &sub_data0 = sub_builder.AddNode("sub_data0", DATA, 0, 1); - const auto &sub_data1 = sub_builder.AddNode("sub_data1", DATA, 0, 1); - const auto &sub_add0 = sub_builder.AddNode("sub_add0", "Add", 2, 1); - const auto &sub_relu0 = sub_builder.AddNode("sub_relu0", "Relu", 1, 1); - const auto &sub_add1 = sub_builder.AddNode("sub_add1", "Add", 2, 1); - // 设置属性 - AttrUtils::SetInt(sub_data0->GetOpDesc(), ATTR_NAME_INDEX, 0); - AttrUtils::SetInt(sub_data1->GetOpDesc(), ATTR_NAME_INDEX, 1); - sub_builder.AddDataEdge(sub_data0, 0, sub_add0, 0); - sub_builder.AddDataEdge(sub_data1, 0, sub_add0, 1); - sub_builder.AddDataEdge(sub_data1, 0, sub_relu0, 0); - sub_builder.AddDataEdge(sub_add0, 0, sub_add1, 0); - sub_builder.AddDataEdge(sub_relu0, 0, sub_add1, 1); - auto sub_graph = sub_builder.GetGraph(); - std::vector> sub_output_nodes{{sub_add1, 0}}; - sub_graph->SetOutputSize(1U); - sub_graph->SetGraphOutNodesInfo(sub_output_nodes); - - EXPECT_EQ(GraphUtils::ExpandNodeWithGraph(add0, sub_graph), SUCCESS); - const auto add0_node = graph->FindNode("add0"); - EXPECT_EQ(add0_node, nullptr); - const auto subgraph_data0_node = graph->FindNode("sub_data0"); - EXPECT_EQ(subgraph_data0_node, nullptr); - const auto subgraph_data1_node = graph->FindNode("sub_data1"); - EXPECT_EQ(subgraph_data1_node, nullptr); - - const auto sub_add0_node = graph->FindNode("sub_add0"); - EXPECT_EQ(sub_add0_node, sub_add0); - const auto add_in_data_anchors = sub_add0_node->GetAllInDataAnchors(); - const auto add_in_data_anchor_0 = add_in_data_anchors.at(0); - const auto peer_out_add_in_data_anchor_0 = add_in_data_anchor_0->GetPeerOutAnchor(); - ASSERT_NE(peer_out_add_in_data_anchor_0, nullptr); - EXPECT_EQ(peer_out_add_in_data_anchor_0->GetOwnerNode()->GetName(), "cast0"); - const auto add_in_data_anchor_1 = add_in_data_anchors.at(1); - const auto peer_out_add_in_data_anchor_1 = add_in_data_anchor_1->GetPeerOutAnchor(); - ASSERT_NE(peer_out_add_in_data_anchor_1, nullptr); - EXPECT_EQ(peer_out_add_in_data_anchor_1->GetOwnerNode()->GetName(), "cast1"); - - const auto add_in_control_anchor = sub_add0_node->GetInControlAnchor(); - ASSERT_NE(add_in_control_anchor, nullptr); - const auto peer_out_add_in_control_anchors = add_in_control_anchor->GetPeerOutControlAnchors(); - EXPECT_EQ(peer_out_add_in_control_anchors.size(), 2U); - EXPECT_EQ(peer_out_add_in_control_anchors.at(0)->GetOwnerNode()->GetName(), "relu0"); - EXPECT_EQ(peer_out_add_in_control_anchors.at(1)->GetOwnerNode()->GetName(), "relu0"); - const auto add_out_data_anchor = sub_add0_node->GetOutDataAnchor(0); - ASSERT_NE(add_out_data_anchor, nullptr); - const auto peer_in_add_out_data_anchors = add_out_data_anchor->GetPeerInDataAnchors(); - EXPECT_EQ(peer_in_add_out_data_anchors.size(), 1); - EXPECT_EQ(peer_in_add_out_data_anchors.at(0)->GetOwnerNode()->GetName(), "sub_add1"); - const auto add_out_control_anchor = sub_add0_node->GetOutControlAnchor(); - ASSERT_NE(add_out_control_anchor, nullptr); - EXPECT_EQ(add_out_control_anchor->GetPeerInControlAnchors().size(), 0U); - - const auto sub_relu0_node = graph->FindNode("sub_relu0"); - EXPECT_EQ(sub_relu0_node, sub_relu0); - const auto relu0_in_data_anchor = sub_relu0_node->GetInDataAnchor(0); - const auto peer_out_relu0_in_data_anchor = relu0_in_data_anchor->GetPeerOutAnchor(); - ASSERT_NE(peer_out_relu0_in_data_anchor, nullptr); - EXPECT_EQ(peer_out_relu0_in_data_anchor->GetOwnerNode()->GetName(), "cast1"); - - const auto relu0_in_control_anchor = sub_relu0_node->GetInControlAnchor(); - ASSERT_NE(relu0_in_control_anchor, nullptr); - const auto peer_out_relu0_in_control_anchors = relu0_in_control_anchor->GetPeerOutControlAnchors(); - EXPECT_EQ(peer_out_relu0_in_control_anchors.size(), 1U); - EXPECT_EQ(peer_out_relu0_in_control_anchors.at(0)->GetOwnerNode()->GetName(), "relu0"); - const auto relu0_out_data_anchor = sub_relu0_node->GetOutDataAnchor(0); - ASSERT_NE(relu0_out_data_anchor, nullptr); - const auto peer_in_relu0_out_data_anchors = relu0_out_data_anchor->GetPeerInDataAnchors(); - EXPECT_EQ(peer_in_relu0_out_data_anchors.size(), 1); - EXPECT_EQ(peer_in_relu0_out_data_anchors.at(0)->GetOwnerNode()->GetName(), "sub_add1"); - const auto relu0_out_control_anchor = sub_relu0_node->GetOutControlAnchor(); - ASSERT_NE(relu0_out_control_anchor, nullptr); - EXPECT_EQ(relu0_out_control_anchor->GetPeerInControlAnchors().size(), 0U); - - const auto sub_add1_node = graph->FindNode("sub_add1"); - EXPECT_EQ(sub_add1_node, sub_add1); - const auto add1_in_data_anchors = sub_add1_node->GetAllInDataAnchors(); - const auto add1_in_data_anchor_0 = add1_in_data_anchors.at(0); - const auto peer_out_add1_in_data_anchor_0 = add1_in_data_anchor_0->GetPeerOutAnchor(); - ASSERT_NE(peer_out_add1_in_data_anchor_0, nullptr); - EXPECT_EQ(peer_out_add1_in_data_anchor_0->GetOwnerNode()->GetName(), "sub_add0"); - const auto add1_in_data_anchor_1 = add1_in_data_anchors.at(1); - const auto peer_out_add1_in_data_anchor_1 = add1_in_data_anchor_1->GetPeerOutAnchor(); - ASSERT_NE(peer_out_add1_in_data_anchor_1, nullptr); - EXPECT_EQ(peer_out_add1_in_data_anchor_1->GetOwnerNode()->GetName(), "sub_relu0"); - - const auto add1_in_control_anchor = sub_add1_node->GetInControlAnchor(); - ASSERT_NE(add1_in_control_anchor, nullptr); - const auto peer_out_add1_in_control_anchors = add1_in_control_anchor->GetPeerOutControlAnchors(); - EXPECT_EQ(peer_out_add1_in_control_anchors.size(), 0U); - - const auto add1_out_data_anchor = sub_add1_node->GetOutDataAnchor(0); - ASSERT_NE(add1_out_data_anchor, nullptr); - const auto peer_in_add1_out_data_anchors = add1_out_data_anchor->GetPeerInDataAnchors(); - EXPECT_EQ(peer_in_add1_out_data_anchors.size(), 2); - EXPECT_EQ(peer_in_add1_out_data_anchors.at(0)->GetOwnerNode()->GetName(), "relu2"); - EXPECT_EQ(peer_in_add1_out_data_anchors.at(1)->GetOwnerNode()->GetName(), "add1"); - const auto add1_out_control_anchor = sub_add1_node->GetOutControlAnchor(); - ASSERT_NE(add1_out_control_anchor, nullptr); - EXPECT_EQ(add1_out_control_anchor->GetPeerInControlAnchors().size(), 1U); - EXPECT_EQ(add1_out_control_anchor->GetPeerInControlAnchors().at(0)->GetOwnerNode()->GetName(), "cast2"); - - // 验证topo序 - std::vector expect_sort; - for (const auto &origin_node_name : origin_node_sort) { - if (origin_node_name == "add0") { - for (const auto &subgraph_node : sub_graph->GetDirectNode()) { - if (subgraph_node->GetType() != "Data") { - expect_sort.emplace_back(subgraph_node->GetName()); - } - } - continue; - } - expect_sort.emplace_back(origin_node_name); - } - size_t index = 0UL; - for (const auto &node : graph->GetDirectNode()) { - EXPECT_EQ(node->GetName(), expect_sort[index]); - index++; - } -} - -/* - Data Data - | | - - Relu Relu Data Data then: Data else: Data - | | | | | | | - | Cast0 Cast Add1 --> | | Relu Cast - | \ / \ / - |----> Add ---- Relu If -> - / \ | - Cast <- Add1 ---- - -*/ -TEST_F(UtestGraphUtils, TestExpandNodeWithGraphWithSubGraph) { - auto builder = ut::GraphBuilder("test_expand_node_with_graph"); - const auto &data0 = builder.AddNode("data0", DATA, 0, 1); - const auto &data1 = builder.AddNode("data1", DATA, 0, 1); - const auto &relu0 = builder.AddNode("relu0", "Relu", 1, 1); - const auto &relu1 = builder.AddNode("relu1", "Relu", 1, 1); - const auto &cast0 = builder.AddNode("cast0", "Cast", 1, 1); - const auto &cast1 = builder.AddNode("cast1", "Cast", 1, 1); - const auto &add0 = builder.AddNode("add0", "Add", 2, 1); - const auto &relu2 = builder.AddNode("relu2", "Relu", 1, 1); - const auto &cast2 = builder.AddNode("cast2", "Cast", 1, 1); - const auto &add1 = builder.AddNode("add1", "Add", 2, 1); - // 设置属性 - AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_INDEX, 0); - AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1); - - builder.AddDataEdge(data0, 0, relu0, 0); - builder.AddDataEdge(data1, 0, relu1, 0); - builder.AddDataEdge(relu0, 0, cast0, 0); - builder.AddDataEdge(relu1, 0, cast1, 0); - builder.AddDataEdge(cast0, 0, add0, 0); - builder.AddDataEdge(cast1, 0, add0, 1); - builder.AddDataEdge(add0, 0, relu2, 0); - builder.AddDataEdge(add0, 0, add1, 0); - builder.AddDataEdge(relu2, 0, add1, 1); - builder.AddControlEdge(relu0, add0); - builder.AddControlEdge(add0, cast2); - auto graph = builder.GetGraph(); - std::vector> output_nodes{{cast2, 0}, {add1, 0}}; - graph->SetOutputSize(2U); - graph->SetGraphOutNodesInfo(output_nodes); - - std::vector origin_node_sort; - for (const auto &node : graph->GetDirectNode()) { - std::cout << "origin node: " << node->GetName() << std::endl; - origin_node_sort.emplace_back(node->GetName()); - } - - auto sub_builder = ut::GraphBuilder("subgraph"); - const auto &sub_data0 = sub_builder.AddNode("sub_data0", DATA, 0, 1); - const auto &sub_data1 = sub_builder.AddNode("sub_data1", DATA, 0, 1); - const auto &sub_if = sub_builder.AddNode("sub_if", "If", 2, 1); - // 设置属性 - AttrUtils::SetInt(sub_data0->GetOpDesc(), ATTR_NAME_INDEX, 0); - AttrUtils::SetInt(sub_data1->GetOpDesc(), ATTR_NAME_INDEX, 1); - sub_builder.AddDataEdge(sub_data0, 0, sub_if, 0); - sub_builder.AddDataEdge(sub_data1, 0, sub_if, 1); - - auto then_graph_builder = ut::GraphBuilder("then_graph"); - const auto &then_graph_data0 = then_graph_builder.AddNode("then_graph_data0", DATA, 0, 1); - const auto &then_graph_relu0 = then_graph_builder.AddNode("then_graph_relu0", "Relu", 1, 1); - then_graph_builder.AddDataEdge(then_graph_data0, 0, then_graph_relu0, 0); - AttrUtils::SetInt(then_graph_data0->GetOpDesc(), ATTR_NAME_INDEX, 0); - AttrUtils::SetInt(then_graph_data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1); - auto then_graph = then_graph_builder.GetGraph(); - std::vector> then_graph_output_nodes{{then_graph_relu0, 0}}; - then_graph->SetOutputSize(1U); - then_graph->SetGraphOutNodesInfo(then_graph_output_nodes); - - auto else_graph_builder = ut::GraphBuilder("else_graph"); - const auto &else_graph_data0 = else_graph_builder.AddNode("else_graph_data0", DATA, 0, 1); - const auto &else_graph_cast0 = else_graph_builder.AddNode("else_graph_cast0", "Cast", 1, 1); - else_graph_builder.AddDataEdge(else_graph_data0, 0, else_graph_cast0, 0); - AttrUtils::SetInt(else_graph_data0->GetOpDesc(), ATTR_NAME_INDEX, 0); - AttrUtils::SetInt(else_graph_data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1); - auto else_graph = else_graph_builder.GetGraph(); - std::vector> else_graph_output_nodes{{else_graph_cast0, 0}}; - else_graph->SetOutputSize(1U); - else_graph->SetGraphOutNodesInfo(else_graph_output_nodes); - - auto sub_graph = sub_builder.GetGraph(); - std::vector> sub_output_nodes{{sub_if, 0}}; - sub_graph->SetOutputSize(1U); - sub_graph->SetGraphOutNodesInfo(sub_output_nodes); - - then_graph->SetGraphUnknownFlag(false); - sub_if->GetOpDesc()->AddSubgraphName(then_graph->GetName()); - sub_if->GetOpDesc()->SetSubgraphInstanceName(0, then_graph->GetName()); - sub_graph->AddSubGraph(then_graph); - then_graph->SetParentNode(sub_if); - then_graph->SetParentGraph(sub_graph); - - else_graph->SetGraphUnknownFlag(false); - sub_if->GetOpDesc()->AddSubgraphName(else_graph->GetName()); - sub_if->GetOpDesc()->SetSubgraphInstanceName(1, else_graph->GetName()); - sub_graph->AddSubGraph(else_graph); - else_graph->SetParentNode(sub_if); - else_graph->SetParentGraph(sub_graph); - - EXPECT_EQ(GraphUtils::ExpandNodeWithGraph(add0, sub_graph), SUCCESS); - const auto subgraphs = graph->GetAllSubgraphs(); - EXPECT_EQ(subgraphs.size(), 2); - const auto add0_node = graph->FindNode("add0"); - EXPECT_EQ(add0_node, nullptr); - const auto subgraph_data0_node = graph->FindNode("sub_data0"); - EXPECT_EQ(subgraph_data0_node, nullptr); - const auto subgraph_data1_node = graph->FindNode("sub_data1"); - EXPECT_EQ(subgraph_data1_node, nullptr); - const auto sub_if_node = graph->FindNode("sub_if"); - EXPECT_EQ(sub_if_node, sub_if); - const auto if_in_data_anchors = sub_if_node->GetAllInDataAnchors(); - const auto if_in_data_anchor_0 = if_in_data_anchors.at(0); - const auto peer_out_if_in_data_anchor_0 = if_in_data_anchor_0->GetPeerOutAnchor(); - ASSERT_NE(peer_out_if_in_data_anchor_0, nullptr); - EXPECT_EQ(peer_out_if_in_data_anchor_0->GetOwnerNode()->GetName(), "cast0"); - const auto if_in_data_anchor_1 = if_in_data_anchors.at(1); - const auto peer_out_if_in_data_anchor_1 = if_in_data_anchor_1->GetPeerOutAnchor(); - ASSERT_NE(peer_out_if_in_data_anchor_1, nullptr); - EXPECT_EQ(peer_out_if_in_data_anchor_1->GetOwnerNode()->GetName(), "cast1"); - - const auto if_in_control_anchor = sub_if_node->GetInControlAnchor(); - ASSERT_NE(if_in_control_anchor, nullptr); - const auto peer_out_if_in_control_anchors = if_in_control_anchor->GetPeerOutControlAnchors(); - EXPECT_EQ(peer_out_if_in_control_anchors.size(), 2U); - EXPECT_EQ(peer_out_if_in_control_anchors.at(0)->GetOwnerNode()->GetName(), "relu0"); - EXPECT_EQ(peer_out_if_in_control_anchors.at(1)->GetOwnerNode()->GetName(), "relu0"); - const auto if_out_data_anchor = sub_if_node->GetOutDataAnchor(0); - ASSERT_NE(if_out_data_anchor, nullptr); - const auto peer_in_if_out_data_anchors = if_out_data_anchor->GetPeerInDataAnchors(); - EXPECT_EQ(peer_in_if_out_data_anchors.size(), 2); - EXPECT_EQ(peer_in_if_out_data_anchors.at(0)->GetOwnerNode()->GetName(), "relu2"); - EXPECT_EQ(peer_in_if_out_data_anchors.at(1)->GetOwnerNode()->GetName(), "add1"); - const auto if_out_control_anchor = sub_if_node->GetOutControlAnchor(); - ASSERT_NE(if_out_control_anchor, nullptr); - EXPECT_EQ(if_out_control_anchor->GetPeerInControlAnchors().size(), 1U); - EXPECT_EQ(if_out_control_anchor->GetPeerInControlAnchors().at(0)->GetOwnerNode()->GetName(), "cast2"); - auto subgraph_names_index = sub_if_node->GetOpDesc()->GetSubgraphNameIndexes(); - EXPECT_EQ(subgraph_names_index.size(), 2); - const auto subgraph_then_graph = subgraphs[subgraph_names_index["then_graph"]]; - ASSERT_NE(subgraph_then_graph, nullptr); - EXPECT_EQ(subgraph_then_graph->GetParentGraph(), graph); - EXPECT_EQ(subgraph_then_graph->GetParentNode(), sub_if_node); - const auto subgraph_else_graph = subgraphs[subgraph_names_index["else_graph"]]; - ASSERT_NE(subgraph_else_graph, nullptr); - EXPECT_EQ(subgraph_else_graph->GetParentGraph(), graph); - EXPECT_EQ(subgraph_else_graph->GetParentNode(), sub_if_node); - // 验证topo序 - std::vector expect_sort; - for (const auto &origin_node_name : origin_node_sort) { - if (origin_node_name == "add0") { - for (const auto &subgraph_node : sub_graph->GetDirectNode()) { - if (subgraph_node->GetType() != "Data") { - expect_sort.emplace_back(subgraph_node->GetName()); - } - } - continue; - } - expect_sort.emplace_back(origin_node_name); - } - size_t index = 0UL; - for (const auto &node : graph->GetDirectNode()) { - EXPECT_EQ(node->GetName(), expect_sort[index]); - index++; - } -} - -/* - Data Data Data Data Data - \ / | | --> / \ - If Cast Relu Relu Relu - \ / - Add -*/ -TEST_F(UtestGraphUtils, TestExpandNodeWithGraphNodeInSubGraph) { - auto builder = ut::GraphBuilder("root_graph"); - const auto &data0 = builder.AddNode("data0", DATA, 0, 1); - const auto &data1 = builder.AddNode("data1", DATA, 0, 1); - const auto &if_op = builder.AddNode("if_op", "If", 2, 1); - // 设置属性 - AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_INDEX, 0); - AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1); - builder.AddDataEdge(data0, 0, if_op, 0); - builder.AddDataEdge(data1, 0, if_op, 1); - - auto then_graph_builder = ut::GraphBuilder("then_graph"); - const auto &then_graph_data0 = then_graph_builder.AddNode("then_graph_data0", DATA, 0, 1); - const auto &then_graph_relu0 = then_graph_builder.AddNode("then_graph_relu0", "Relu", 1, 1); - then_graph_builder.AddDataEdge(then_graph_data0, 0, then_graph_relu0, 0); - AttrUtils::SetInt(then_graph_data0->GetOpDesc(), ATTR_NAME_INDEX, 0); - AttrUtils::SetInt(then_graph_data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1); - auto then_graph = then_graph_builder.GetGraph(); - std::vector> then_graph_output_nodes{{then_graph_relu0, 0}}; - then_graph->SetOutputSize(1U); - then_graph->SetGraphOutNodesInfo(then_graph_output_nodes); - - auto else_graph_builder = ut::GraphBuilder("else_graph"); - const auto &else_graph_data0 = else_graph_builder.AddNode("else_graph_data0", DATA, 0, 1); - const auto &else_graph_cast0 = else_graph_builder.AddNode("else_graph_cast0", "Cast", 1, 1); - else_graph_builder.AddDataEdge(else_graph_data0, 0, else_graph_cast0, 0); - AttrUtils::SetInt(else_graph_data0->GetOpDesc(), ATTR_NAME_INDEX, 0); - AttrUtils::SetInt(else_graph_data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1); - auto else_graph = else_graph_builder.GetGraph(); - std::vector> else_graph_output_nodes{{else_graph_cast0, 0}}; - else_graph->SetOutputSize(1U); - else_graph->SetGraphOutNodesInfo(else_graph_output_nodes); - - auto graph = builder.GetGraph(); - std::vector> output_nodes{{if_op, 0}}; - graph->SetOutputSize(1U); - graph->SetGraphOutNodesInfo(output_nodes); - - then_graph->SetGraphUnknownFlag(false); - if_op->GetOpDesc()->AddSubgraphName(then_graph->GetName()); - if_op->GetOpDesc()->SetSubgraphInstanceName(0, then_graph->GetName()); - graph->AddSubGraph(then_graph); - then_graph->SetParentNode(if_op); - then_graph->SetParentGraph(graph); - - else_graph->SetGraphUnknownFlag(false); - if_op->GetOpDesc()->AddSubgraphName(else_graph->GetName()); - if_op->GetOpDesc()->SetSubgraphInstanceName(1, else_graph->GetName()); - graph->AddSubGraph(else_graph); - else_graph->SetParentNode(if_op); - else_graph->SetParentGraph(graph); - - auto sub_builder = ut::GraphBuilder("subgraph"); - const auto &sub_data0 = sub_builder.AddNode("sub_data0", DATA, 0, 1); - const auto &sub_relu0 = sub_builder.AddNode("sub_relu0", "Relu", 1, 1); - const auto &sub_relu1 = sub_builder.AddNode("sub_relu1", "Relu", 1, 1); - const auto &sub_add = sub_builder.AddNode("sub_add", "Add", 2, 1); - sub_builder.AddDataEdge(sub_data0, 0, sub_relu0, 0); - sub_builder.AddDataEdge(sub_data0, 0, sub_relu1, 0); - sub_builder.AddDataEdge(sub_relu0, 0, sub_add, 0); - sub_builder.AddDataEdge(sub_relu1, 0, sub_add, 1); - AttrUtils::SetInt(sub_data0->GetOpDesc(), ATTR_NAME_INDEX, 0); - auto sub_graph = sub_builder.GetGraph(); - std::vector> sub_output_nodes{{sub_add, 0}}; - sub_graph->SetOutputSize(1U); - sub_graph->SetGraphOutNodesInfo(sub_output_nodes); - - EXPECT_EQ(GraphUtils::ExpandNodeWithGraph(then_graph_relu0, sub_graph), SUCCESS); - const auto subgraphs = graph->GetAllSubgraphs(); - EXPECT_EQ(subgraphs.size(), 2); - EXPECT_EQ(then_graph->FindNode("then_graph_relu0"), nullptr); - EXPECT_EQ(then_graph->FindNode("sub_data0"), nullptr); - - const auto sub_relu0_node = then_graph->FindNode("sub_relu0"); - EXPECT_EQ(sub_relu0_node, sub_relu0); - const auto relu0_in_data_anchor = sub_relu0_node->GetInDataAnchor(0); - const auto peer_out_relu0_in_data_anchor = relu0_in_data_anchor->GetPeerOutAnchor(); - ASSERT_NE(peer_out_relu0_in_data_anchor, nullptr); - EXPECT_EQ(peer_out_relu0_in_data_anchor->GetOwnerNode()->GetName(), "then_graph_data0"); - const auto relu0_out_data_anchor = sub_relu0_node->GetOutDataAnchor(0); - ASSERT_NE(relu0_out_data_anchor, nullptr); - const auto peer_in_relu0_out_data_anchors = relu0_out_data_anchor->GetPeerInDataAnchors(); - EXPECT_EQ(peer_in_relu0_out_data_anchors.size(), 1); - EXPECT_EQ(peer_in_relu0_out_data_anchors.at(0)->GetOwnerNode()->GetName(), "sub_add"); - - const auto sub_relu1_node = then_graph->FindNode("sub_relu1"); - EXPECT_EQ(sub_relu1_node, sub_relu1); - const auto relu1_in_data_anchor = sub_relu1_node->GetInDataAnchor(0); - const auto peer_out_relu1_in_data_anchor = relu1_in_data_anchor->GetPeerOutAnchor(); - ASSERT_NE(peer_out_relu1_in_data_anchor, nullptr); - EXPECT_EQ(peer_out_relu1_in_data_anchor->GetOwnerNode()->GetName(), "then_graph_data0"); - const auto relu1_out_data_anchor = sub_relu1_node->GetOutDataAnchor(0); - ASSERT_NE(relu1_out_data_anchor, nullptr); - const auto peer_in_relu1_out_data_anchors = relu1_out_data_anchor->GetPeerInDataAnchors(); - EXPECT_EQ(peer_in_relu1_out_data_anchors.size(), 1); - EXPECT_EQ(peer_in_relu1_out_data_anchors.at(0)->GetOwnerNode()->GetName(), "sub_add"); - - const auto sub_add_node = then_graph->FindNode("sub_add"); - EXPECT_EQ(sub_add_node, sub_add); - const auto add_in_data_anchor0 = sub_add_node->GetInDataAnchor(0); - const auto peer_out_add_in_data_anchor0 = add_in_data_anchor0->GetPeerOutAnchor(); - ASSERT_NE(peer_out_add_in_data_anchor0, nullptr); - EXPECT_EQ(peer_out_add_in_data_anchor0->GetOwnerNode()->GetName(), "sub_relu0"); - const auto add_in_data_anchor1 = sub_add_node->GetInDataAnchor(1); - const auto peer_out_add_in_data_anchor1 = add_in_data_anchor1->GetPeerOutAnchor(); - ASSERT_NE(peer_out_add_in_data_anchor1, nullptr); - EXPECT_EQ(peer_out_add_in_data_anchor1->GetOwnerNode()->GetName(), "sub_relu1"); - const auto add_out_data_anchor = sub_add_node->GetOutDataAnchor(0); - ASSERT_NE(add_out_data_anchor, nullptr); - const auto peer_in_add_out_data_anchors = add_out_data_anchor->GetPeerInDataAnchors(); - EXPECT_EQ(peer_in_add_out_data_anchors.size(), 0); - const auto output_node_info = then_graph->GetGraphOutNodesInfo(); - EXPECT_EQ(output_node_info.size(), 1); - EXPECT_EQ(output_node_info[0].first, sub_add_node); - EXPECT_EQ(output_node_info[0].second, 0); -} - -/* - Data Data - | | - Relu Relu Data Data - | | | | - Cast0 Cast Data Clip --> | | - \ / | \ / - Clip ------ Clip - | \ - | ---- Relu - / \ | - Cast <- Add1 ---- - -*/ -TEST_F(UtestGraphUtils, TestExpandNodeWithGraphInputNotMatch) { - auto builder = ut::GraphBuilder("test_expand_node_with_graph"); - const auto &data0 = builder.AddNode("data0", DATA, 1, 1); - const auto &data1 = builder.AddNode("data1", DATA, 1, 1); - const auto &data2 = builder.AddNode("data2", DATA, 1, 1); - const auto &relu0 = builder.AddNode("relu0", "Relu", 1, 1); - const auto &relu1 = builder.AddNode("relu1", "Relu", 1, 1); - const auto &cast0 = builder.AddNode("cast0", "Cast", 1, 1); - const auto &cast1 = builder.AddNode("cast1", "Cast", 1, 1); - const auto &clip0 = builder.AddNode("clip0", "ClipByValue", 3, 1); - const auto &relu2 = builder.AddNode("relu2", "Relu", 1, 1); - const auto &cast2 = builder.AddNode("cast2", "Cast", 1, 1); - const auto &add1 = builder.AddNode("add1", "Add", 2, 1); - // 设置属性 - AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_INDEX, 0); - AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1); - - builder.AddDataEdge(data0, 0, relu0, 0); - builder.AddDataEdge(data1, 0, relu1, 0); - builder.AddDataEdge(relu0, 0, cast0, 0); - builder.AddDataEdge(relu1, 0, cast1, 0); - builder.AddDataEdge(cast0, 0, clip0, 0); - builder.AddDataEdge(cast1, 0, clip0, 1); - builder.AddDataEdge(data2, 0, clip0, 2); - builder.AddDataEdge(clip0, 0, relu2, 0); - builder.AddDataEdge(clip0, 0, add1, 0); - builder.AddDataEdge(relu2, 0, add1, 1); - builder.AddControlEdge(relu1, clip0); - builder.AddControlEdge(clip0, cast2); - auto graph = builder.GetGraph(); - std::vector> output_nodes{{cast2, 0}, {add1, 0}}; - graph->SetOutputSize(2U); - graph->SetGraphOutNodesInfo(output_nodes); - - std::vector origin_node_sort; - for (const auto &node : graph->GetDirectNode()) { - std::cout << "origin node: " << node->GetName() << std::endl; - origin_node_sort.emplace_back(node->GetName()); - } - - auto sub_builder = ut::GraphBuilder("subgraph"); - const auto &sub_data0 = sub_builder.AddNode("sub_data0", DATA, 1, 1); - const auto &sub_data2 = sub_builder.AddNode("sub_data2", DATA, 1, 1); - const auto &sub_clip0 = sub_builder.AddNode("sub_clip0", "ClipByValue", 3, 1); - // 设置属性 - AttrUtils::SetInt(sub_data0->GetOpDesc(), ATTR_NAME_INDEX, 0); - AttrUtils::SetInt(sub_data2->GetOpDesc(), ATTR_NAME_INDEX, 2); - sub_builder.AddDataEdge(sub_data0, 0, sub_clip0, 0); - sub_builder.AddDataEdge(sub_data2, 0, sub_clip0, 2); - auto sub_graph = sub_builder.GetGraph(); - std::vector> sub_output_nodes{{sub_clip0, 0}}; - sub_graph->SetOutputSize(1U); - sub_graph->SetGraphOutNodesInfo(sub_output_nodes); - - EXPECT_EQ(GraphUtils::ExpandNodeWithGraph(clip0, sub_graph), SUCCESS); - const auto clip0_node = graph->FindNode("clip0"); - EXPECT_EQ(clip0_node, nullptr); - const auto subgraph_data0_node = graph->FindNode("sub_data0"); - EXPECT_EQ(subgraph_data0_node, nullptr); - const auto subgraph_data2_node = graph->FindNode("sub_data2"); - EXPECT_EQ(subgraph_data2_node, nullptr); - - const auto sub_clip0_node = graph->FindNode("sub_clip0"); - EXPECT_EQ(sub_clip0_node, sub_clip0); - const auto clip0_in_data_anchors = sub_clip0_node->GetAllInDataAnchors(); - const auto clip0_in_data_anchor_0 = clip0_in_data_anchors.at(0); - const auto peer_out_clip0_in_data_anchor_0 = clip0_in_data_anchor_0->GetPeerOutAnchor(); - ASSERT_NE(peer_out_clip0_in_data_anchor_0, nullptr); - EXPECT_EQ(peer_out_clip0_in_data_anchor_0->GetOwnerNode()->GetName(), "cast0"); - - const auto clip0_in_data_anchor_1 = clip0_in_data_anchors.at(1); - const auto peer_out_clip0_in_data_anchor_1 = clip0_in_data_anchor_1->GetPeerOutAnchor(); - ASSERT_EQ(peer_out_clip0_in_data_anchor_1, nullptr); - - const auto clip0_in_data_anchor_2 = clip0_in_data_anchors.at(2); - const auto peer_out_clip0_in_data_anchor_2 = clip0_in_data_anchor_2->GetPeerOutAnchor(); - ASSERT_NE(peer_out_clip0_in_data_anchor_2, nullptr); - EXPECT_EQ(peer_out_clip0_in_data_anchor_2->GetOwnerNode()->GetName(), "data2"); - - const auto clip0_in_control_anchor = sub_clip0_node->GetInControlAnchor(); - ASSERT_NE(clip0_in_control_anchor, nullptr); - const auto peer_out_clip0_in_control_anchors = clip0_in_control_anchor->GetPeerOutControlAnchors(); - EXPECT_EQ(peer_out_clip0_in_control_anchors.size(), 2U); - EXPECT_EQ(peer_out_clip0_in_control_anchors.at(0)->GetOwnerNode()->GetName(), "relu1"); - EXPECT_EQ(peer_out_clip0_in_control_anchors.at(1)->GetOwnerNode()->GetName(), "relu1"); - const auto clip0_out_data_anchor = sub_clip0_node->GetOutDataAnchor(0); - ASSERT_NE(clip0_out_data_anchor, nullptr); - const auto peer_in_clip0_out_data_anchors = clip0_out_data_anchor->GetPeerInDataAnchors(); - EXPECT_EQ(peer_in_clip0_out_data_anchors.size(), 2); - EXPECT_EQ(peer_in_clip0_out_data_anchors.at(0)->GetOwnerNode()->GetName(), "relu2"); - EXPECT_EQ(peer_in_clip0_out_data_anchors.at(1)->GetOwnerNode()->GetName(), "add1"); - const auto clip0_out_control_anchor = sub_clip0_node->GetOutControlAnchor(); - ASSERT_NE(clip0_out_control_anchor, nullptr); - EXPECT_EQ(clip0_out_control_anchor->GetPeerInControlAnchors().size(), 1U); - EXPECT_EQ(clip0_out_control_anchor->GetPeerInControlAnchors().at(0)->GetOwnerNode()->GetName(), "cast2"); - - // 验证topo序 - std::vector expect_sort; - for (const auto &origin_node_name : origin_node_sort) { - if (origin_node_name == "clip0") { - for (const auto &subgraph_node : sub_graph->GetDirectNode()) { - if (subgraph_node->GetType() != "Data") { - expect_sort.emplace_back(subgraph_node->GetName()); - } - } - continue; - } - expect_sort.emplace_back(origin_node_name); - } - size_t index = 0UL; - for (const auto &node : graph->GetDirectNode()) { - EXPECT_EQ(node->GetName(), expect_sort[index]); - index++; - } -} - -/* - Data Data - | | - - Relu Relu Data Data - | | | | / | - | Cast0 Cast Add0 --> Add Relu - | \ / \ / | - |----> Add ---- Relu Add | - / \ | NetOutput - Cast <- Add1 ---- - -*/ -TEST_F(UtestGraphUtils, TestExpandNodeWithNetOutput) { - auto builder = ut::GraphBuilder("test_expand_node_with_graph"); - const auto &data0 = builder.AddNode("data0", DATA, 0, 1); - const auto &data1 = builder.AddNode("data1", DATA, 0, 1); - const auto &relu0 = builder.AddNode("relu0", "Relu", 1, 1); - const auto &relu1 = builder.AddNode("relu1", "Relu", 1, 1); - const auto &cast0 = builder.AddNode("cast0", "Cast", 1, 1); - const auto &cast1 = builder.AddNode("cast1", "Cast", 1, 1); - const auto &add0 = builder.AddNode("add0", "Add", 2, 1); - const auto &relu2 = builder.AddNode("relu2", "Relu", 1, 1); - const auto &cast2 = builder.AddNode("cast2", "Cast", 1, 1); - const auto &add1 = builder.AddNode("add1", "Add", 2, 1); - // 设置属性 - AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_INDEX, 0); - AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1); - - builder.AddDataEdge(data0, 0, relu0, 0); - builder.AddDataEdge(data1, 0, relu1, 0); - builder.AddDataEdge(relu0, 0, cast0, 0); - builder.AddDataEdge(relu1, 0, cast1, 0); - builder.AddDataEdge(cast0, 0, add0, 0); - builder.AddDataEdge(cast1, 0, add0, 1); - builder.AddDataEdge(add0, 0, relu2, 0); - builder.AddDataEdge(add0, 0, add1, 0); - builder.AddDataEdge(relu2, 0, add1, 1); - builder.AddControlEdge(relu0, add0); - builder.AddControlEdge(add0, cast2); - auto graph = builder.GetGraph(); - std::vector> output_nodes{{cast2, 0}, {add1, 0}}; - graph->SetOutputSize(2U); - graph->SetGraphOutNodesInfo(output_nodes); - - std::vector origin_node_sort; - for (const auto &node : graph->GetDirectNode()) { - std::cout << "origin node: " << node->GetName() << std::endl; - origin_node_sort.emplace_back(node->GetName()); - } - - auto sub_builder = ut::GraphBuilder("subgraph"); - const auto &sub_data0 = sub_builder.AddNode("sub_data0", DATA, 0, 1); - const auto &sub_data1 = sub_builder.AddNode("sub_data1", DATA, 0, 1); - const auto &sub_add0 = sub_builder.AddNode("sub_add0", "Add", 2, 1); - const auto &sub_relu0 = sub_builder.AddNode("sub_relu0", "Relu", 1, 1); - const auto &sub_add1 = sub_builder.AddNode("sub_add1", "Add", 2, 1); - const auto &sub_output = sub_builder.AddNode("sub_output", "NetOutput", 1, 1); - // 设置属性 - AttrUtils::SetInt(sub_data0->GetOpDesc(), ATTR_NAME_INDEX, 0); - AttrUtils::SetInt(sub_data1->GetOpDesc(), ATTR_NAME_INDEX, 1); - sub_builder.AddDataEdge(sub_data0, 0, sub_add0, 0); - sub_builder.AddDataEdge(sub_data1, 0, sub_add0, 1); - sub_builder.AddDataEdge(sub_data1, 0, sub_relu0, 0); - sub_builder.AddDataEdge(sub_add0, 0, sub_add1, 0); - sub_builder.AddDataEdge(sub_relu0, 0, sub_add1, 1); - sub_builder.AddDataEdge(sub_relu0, 0, sub_output, 0); - auto sub_graph = sub_builder.GetGraph(); - std::vector> sub_output_nodes{{sub_relu0, 0}}; - sub_graph->SetOutputSize(1U); - sub_graph->SetGraphOutNodesInfo(sub_output_nodes); - - EXPECT_EQ(GraphUtils::ExpandNodeWithGraph(add0, sub_graph), SUCCESS); - const auto add0_node = graph->FindNode("add0"); - EXPECT_EQ(add0_node, nullptr); - const auto subgraph_data0_node = graph->FindNode("sub_data0"); - EXPECT_EQ(subgraph_data0_node, nullptr); - const auto subgraph_data1_node = graph->FindNode("sub_data1"); - EXPECT_EQ(subgraph_data1_node, nullptr); - const auto output_node = graph->FindNode("sub_output"); - EXPECT_EQ(output_node, nullptr); - const auto sub_add0_node = graph->FindNode("sub_add0"); - EXPECT_EQ(sub_add0_node, sub_add0); - const auto add_in_data_anchors = sub_add0_node->GetAllInDataAnchors(); - const auto add_in_data_anchor_0 = add_in_data_anchors.at(0); - const auto peer_out_add_in_data_anchor_0 = add_in_data_anchor_0->GetPeerOutAnchor(); - ASSERT_NE(peer_out_add_in_data_anchor_0, nullptr); - EXPECT_EQ(peer_out_add_in_data_anchor_0->GetOwnerNode()->GetName(), "cast0"); - const auto add_in_data_anchor_1 = add_in_data_anchors.at(1); - const auto peer_out_add_in_data_anchor_1 = add_in_data_anchor_1->GetPeerOutAnchor(); - ASSERT_NE(peer_out_add_in_data_anchor_1, nullptr); - EXPECT_EQ(peer_out_add_in_data_anchor_1->GetOwnerNode()->GetName(), "cast1"); - - const auto add_in_control_anchor = sub_add0_node->GetInControlAnchor(); - ASSERT_NE(add_in_control_anchor, nullptr); - const auto peer_out_add_in_control_anchors = add_in_control_anchor->GetPeerOutControlAnchors(); - EXPECT_EQ(peer_out_add_in_control_anchors.size(), 2U); - EXPECT_EQ(peer_out_add_in_control_anchors.at(0)->GetOwnerNode()->GetName(), "relu0"); - EXPECT_EQ(peer_out_add_in_control_anchors.at(1)->GetOwnerNode()->GetName(), "relu0"); - const auto add_out_data_anchor = sub_add0_node->GetOutDataAnchor(0); - ASSERT_NE(add_out_data_anchor, nullptr); - const auto peer_in_add_out_data_anchors = add_out_data_anchor->GetPeerInDataAnchors(); - EXPECT_EQ(peer_in_add_out_data_anchors.size(), 1); - EXPECT_EQ(peer_in_add_out_data_anchors.at(0)->GetOwnerNode()->GetName(), "sub_add1"); - const auto add_out_control_anchor = sub_add0_node->GetOutControlAnchor(); - ASSERT_NE(add_out_control_anchor, nullptr); - EXPECT_EQ(add_out_control_anchor->GetPeerInControlAnchors().size(), 0U); - - const auto sub_relu0_node = graph->FindNode("sub_relu0"); - EXPECT_EQ(sub_relu0_node, sub_relu0); - const auto relu0_in_data_anchor = sub_relu0_node->GetInDataAnchor(0); - const auto peer_out_relu0_in_data_anchor = relu0_in_data_anchor->GetPeerOutAnchor(); - ASSERT_NE(peer_out_relu0_in_data_anchor, nullptr); - EXPECT_EQ(peer_out_relu0_in_data_anchor->GetOwnerNode()->GetName(), "cast1"); - - const auto relu0_in_control_anchor = sub_relu0_node->GetInControlAnchor(); - ASSERT_NE(relu0_in_control_anchor, nullptr); - const auto peer_out_relu0_in_control_anchors = relu0_in_control_anchor->GetPeerOutControlAnchors(); - EXPECT_EQ(peer_out_relu0_in_control_anchors.size(), 1U); - EXPECT_EQ(peer_out_relu0_in_control_anchors.at(0)->GetOwnerNode()->GetName(), "relu0"); - const auto relu0_out_data_anchor = sub_relu0_node->GetOutDataAnchor(0); - ASSERT_NE(relu0_out_data_anchor, nullptr); - const auto peer_in_relu0_out_data_anchors = relu0_out_data_anchor->GetPeerInDataAnchors(); - EXPECT_EQ(peer_in_relu0_out_data_anchors.size(), 3); - EXPECT_EQ(peer_in_relu0_out_data_anchors.at(0)->GetOwnerNode()->GetName(), "sub_add1"); - EXPECT_EQ(peer_in_relu0_out_data_anchors.at(1)->GetOwnerNode()->GetName(), "relu2"); - EXPECT_EQ(peer_in_relu0_out_data_anchors.at(2)->GetOwnerNode()->GetName(), "add1"); - const auto relu0_out_control_anchor = sub_relu0_node->GetOutControlAnchor(); - ASSERT_NE(relu0_out_control_anchor, nullptr); - EXPECT_EQ(relu0_out_control_anchor->GetPeerInControlAnchors().size(), 1U); - EXPECT_EQ(relu0_out_control_anchor->GetPeerInControlAnchors().at(0)->GetOwnerNode()->GetName(), "cast2"); - - const auto sub_add1_node = graph->FindNode("sub_add1"); - EXPECT_EQ(sub_add1_node, sub_add1); - const auto add1_in_data_anchors = sub_add1_node->GetAllInDataAnchors(); - const auto add1_in_data_anchor_0 = add1_in_data_anchors.at(0); - const auto peer_out_add1_in_data_anchor_0 = add1_in_data_anchor_0->GetPeerOutAnchor(); - ASSERT_NE(peer_out_add1_in_data_anchor_0, nullptr); - EXPECT_EQ(peer_out_add1_in_data_anchor_0->GetOwnerNode()->GetName(), "sub_add0"); - const auto add1_in_data_anchor_1 = add1_in_data_anchors.at(1); - const auto peer_out_add1_in_data_anchor_1 = add1_in_data_anchor_1->GetPeerOutAnchor(); - ASSERT_NE(peer_out_add1_in_data_anchor_1, nullptr); - EXPECT_EQ(peer_out_add1_in_data_anchor_1->GetOwnerNode()->GetName(), "sub_relu0"); - - const auto add1_in_control_anchor = sub_add1_node->GetInControlAnchor(); - ASSERT_NE(add1_in_control_anchor, nullptr); - const auto peer_out_add1_in_control_anchors = add1_in_control_anchor->GetPeerOutControlAnchors(); - EXPECT_EQ(peer_out_add1_in_control_anchors.size(), 0U); - - const auto add1_out_data_anchor = sub_add1_node->GetOutDataAnchor(0); - ASSERT_NE(add1_out_data_anchor, nullptr); - const auto peer_in_add1_out_data_anchors = add1_out_data_anchor->GetPeerInDataAnchors(); - EXPECT_EQ(peer_in_add1_out_data_anchors.size(), 0); - const auto add1_out_control_anchor = sub_add1_node->GetOutControlAnchor(); - ASSERT_NE(add1_out_control_anchor, nullptr); - EXPECT_EQ(add1_out_control_anchor->GetPeerInControlAnchors().size(), 0U); - - // 验证topo序 - std::vector expect_sort; - for (const auto &origin_node_name : origin_node_sort) { - if (origin_node_name == "add0") { - for (const auto &subgraph_node : sub_graph->GetDirectNode()) { - if ((subgraph_node->GetType() != "Data") && (subgraph_node->GetType() != "NetOutput")) { - expect_sort.emplace_back(subgraph_node->GetName()); - } - } - continue; - } - expect_sort.emplace_back(origin_node_name); - } - size_t index = 0UL; - for (const auto &node : graph->GetDirectNode()) { - EXPECT_EQ(node->GetName(), expect_sort[index]); - index++; - } -} - -/* - Data Data - | | - - Relu Relu Data Data - | | | / | / | - | Cast0 Cast Add0 --> | Add Relu - | \ / \ \ / - |----> Add ---- Relu -- Add - / \ | - Cast <- Add1 ---- - -*/ -TEST_F(UtestGraphUtils, TestExpandNodeWithDataControl) { - auto builder = ut::GraphBuilder("test_expand_node_with_graph"); - const auto &data0 = builder.AddNode("data0", DATA, 0, 1); - const auto &data1 = builder.AddNode("data1", DATA, 0, 1); - const auto &relu0 = builder.AddNode("relu0", "Relu", 1, 1); - const auto &relu1 = builder.AddNode("relu1", "Relu", 1, 1); - const auto &cast0 = builder.AddNode("cast0", "Cast", 1, 1); - const auto &cast1 = builder.AddNode("cast1", "Cast", 1, 1); - const auto &add0 = builder.AddNode("add0", "Add", 2, 1); - const auto &relu2 = builder.AddNode("relu2", "Relu", 1, 1); - const auto &cast2 = builder.AddNode("cast2", "Cast", 1, 1); - const auto &add1 = builder.AddNode("add1", "Add", 2, 1); - // 设置属性 - AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_INDEX, 0); - AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1); - - builder.AddDataEdge(data0, 0, relu0, 0); - builder.AddDataEdge(data1, 0, relu1, 0); - builder.AddDataEdge(relu0, 0, cast0, 0); - builder.AddDataEdge(relu1, 0, cast1, 0); - builder.AddDataEdge(cast0, 0, add0, 0); - builder.AddDataEdge(cast1, 0, add0, 1); - builder.AddDataEdge(add0, 0, relu2, 0); - builder.AddDataEdge(add0, 0, add1, 0); - builder.AddDataEdge(relu2, 0, add1, 1); - builder.AddControlEdge(relu0, add0); - builder.AddControlEdge(add0, cast2); - auto graph = builder.GetGraph(); - std::vector> output_nodes{{cast2, 0}, {add1, 0}}; - graph->SetOutputSize(2U); - graph->SetGraphOutNodesInfo(output_nodes); - - std::vector origin_node_sort; - for (const auto &node : graph->GetDirectNode()) { - std::cout << "origin node: " << node->GetName() << std::endl; - origin_node_sort.emplace_back(node->GetName()); - } - - auto sub_builder = ut::GraphBuilder("subgraph"); - const auto &sub_data0 = sub_builder.AddNode("sub_data0", DATA, 0, 1); - const auto &sub_data1 = sub_builder.AddNode("sub_data1", DATA, 0, 1); - const auto &sub_add0 = sub_builder.AddNode("sub_add0", "Add", 2, 1); - const auto &sub_relu0 = sub_builder.AddNode("sub_relu0", "Relu", 1, 1); - const auto &sub_add1 = sub_builder.AddNode("sub_add1", "Add", 2, 1); - // 设置属性 - AttrUtils::SetInt(sub_data0->GetOpDesc(), ATTR_NAME_INDEX, 0); - AttrUtils::SetInt(sub_data1->GetOpDesc(), ATTR_NAME_INDEX, 1); - sub_builder.AddDataEdge(sub_data0, 0, sub_add0, 0); - sub_builder.AddDataEdge(sub_data1, 0, sub_add0, 1); - sub_builder.AddDataEdge(sub_data1, 0, sub_relu0, 0); - sub_builder.AddDataEdge(sub_add0, 0, sub_add1, 0); - sub_builder.AddDataEdge(sub_relu0, 0, sub_add1, 1); - sub_builder.AddControlEdge(sub_data0, sub_add1); - auto sub_graph = sub_builder.GetGraph(); - std::vector> sub_output_nodes{{sub_add1, 0}}; - sub_graph->SetOutputSize(1U); - sub_graph->SetGraphOutNodesInfo(sub_output_nodes); - - EXPECT_EQ(GraphUtils::ExpandNodeWithGraph(add0, sub_graph), SUCCESS); - const auto add0_node = graph->FindNode("add0"); - EXPECT_EQ(add0_node, nullptr); - const auto subgraph_data0_node = graph->FindNode("sub_data0"); - EXPECT_EQ(subgraph_data0_node, nullptr); - const auto subgraph_data1_node = graph->FindNode("sub_data1"); - EXPECT_EQ(subgraph_data1_node, nullptr); - - const auto sub_add0_node = graph->FindNode("sub_add0"); - EXPECT_EQ(sub_add0_node, sub_add0); - const auto add_in_data_anchors = sub_add0_node->GetAllInDataAnchors(); - const auto add_in_data_anchor_0 = add_in_data_anchors.at(0); - const auto peer_out_add_in_data_anchor_0 = add_in_data_anchor_0->GetPeerOutAnchor(); - ASSERT_NE(peer_out_add_in_data_anchor_0, nullptr); - EXPECT_EQ(peer_out_add_in_data_anchor_0->GetOwnerNode()->GetName(), "cast0"); - const auto add_in_data_anchor_1 = add_in_data_anchors.at(1); - const auto peer_out_add_in_data_anchor_1 = add_in_data_anchor_1->GetPeerOutAnchor(); - ASSERT_NE(peer_out_add_in_data_anchor_1, nullptr); - EXPECT_EQ(peer_out_add_in_data_anchor_1->GetOwnerNode()->GetName(), "cast1"); - - const auto add_in_control_anchor = sub_add0_node->GetInControlAnchor(); - ASSERT_NE(add_in_control_anchor, nullptr); - const auto peer_out_add_in_control_anchors = add_in_control_anchor->GetPeerOutControlAnchors(); - EXPECT_EQ(peer_out_add_in_control_anchors.size(), 2U); - EXPECT_EQ(peer_out_add_in_control_anchors.at(0)->GetOwnerNode()->GetName(), "relu0"); - EXPECT_EQ(peer_out_add_in_control_anchors.at(1)->GetOwnerNode()->GetName(), "relu0"); - const auto add_out_data_anchor = sub_add0_node->GetOutDataAnchor(0); - ASSERT_NE(add_out_data_anchor, nullptr); - const auto peer_in_add_out_data_anchors = add_out_data_anchor->GetPeerInDataAnchors(); - EXPECT_EQ(peer_in_add_out_data_anchors.size(), 1); - EXPECT_EQ(peer_in_add_out_data_anchors.at(0)->GetOwnerNode()->GetName(), "sub_add1"); - const auto add_out_control_anchor = sub_add0_node->GetOutControlAnchor(); - ASSERT_NE(add_out_control_anchor, nullptr); - EXPECT_EQ(add_out_control_anchor->GetPeerInControlAnchors().size(), 0U); - - const auto sub_relu0_node = graph->FindNode("sub_relu0"); - EXPECT_EQ(sub_relu0_node, sub_relu0); - const auto relu0_in_data_anchor = sub_relu0_node->GetInDataAnchor(0); - const auto peer_out_relu0_in_data_anchor = relu0_in_data_anchor->GetPeerOutAnchor(); - ASSERT_NE(peer_out_relu0_in_data_anchor, nullptr); - EXPECT_EQ(peer_out_relu0_in_data_anchor->GetOwnerNode()->GetName(), "cast1"); - - const auto relu0_in_control_anchor = sub_relu0_node->GetInControlAnchor(); - ASSERT_NE(relu0_in_control_anchor, nullptr); - const auto peer_out_relu0_in_control_anchors = relu0_in_control_anchor->GetPeerOutControlAnchors(); - EXPECT_EQ(peer_out_relu0_in_control_anchors.size(), 1U); - EXPECT_EQ(peer_out_relu0_in_control_anchors.at(0)->GetOwnerNode()->GetName(), "relu0"); - const auto relu0_out_data_anchor = sub_relu0_node->GetOutDataAnchor(0); - ASSERT_NE(relu0_out_data_anchor, nullptr); - const auto peer_in_relu0_out_data_anchors = relu0_out_data_anchor->GetPeerInDataAnchors(); - EXPECT_EQ(peer_in_relu0_out_data_anchors.size(), 1); - EXPECT_EQ(peer_in_relu0_out_data_anchors.at(0)->GetOwnerNode()->GetName(), "sub_add1"); - const auto relu0_out_control_anchor = sub_relu0_node->GetOutControlAnchor(); - ASSERT_NE(relu0_out_control_anchor, nullptr); - EXPECT_EQ(relu0_out_control_anchor->GetPeerInControlAnchors().size(), 0U); - - const auto sub_add1_node = graph->FindNode("sub_add1"); - EXPECT_EQ(sub_add1_node, sub_add1); - const auto add1_in_data_anchors = sub_add1_node->GetAllInDataAnchors(); - const auto add1_in_data_anchor_0 = add1_in_data_anchors.at(0); - const auto peer_out_add1_in_data_anchor_0 = add1_in_data_anchor_0->GetPeerOutAnchor(); - ASSERT_NE(peer_out_add1_in_data_anchor_0, nullptr); - EXPECT_EQ(peer_out_add1_in_data_anchor_0->GetOwnerNode()->GetName(), "sub_add0"); - const auto add1_in_data_anchor_1 = add1_in_data_anchors.at(1); - const auto peer_out_add1_in_data_anchor_1 = add1_in_data_anchor_1->GetPeerOutAnchor(); - ASSERT_NE(peer_out_add1_in_data_anchor_1, nullptr); - EXPECT_EQ(peer_out_add1_in_data_anchor_1->GetOwnerNode()->GetName(), "sub_relu0"); - - const auto add1_in_control_anchor = sub_add1_node->GetInControlAnchor(); - ASSERT_NE(add1_in_control_anchor, nullptr); - const auto peer_out_add1_in_control_anchors = add1_in_control_anchor->GetPeerOutControlAnchors(); - EXPECT_EQ(peer_out_add1_in_control_anchors.size(), 0U); - - const auto add1_out_data_anchor = sub_add1_node->GetOutDataAnchor(0); - ASSERT_NE(add1_out_data_anchor, nullptr); - const auto peer_in_add1_out_data_anchors = add1_out_data_anchor->GetPeerInDataAnchors(); - EXPECT_EQ(peer_in_add1_out_data_anchors.size(), 2); - EXPECT_EQ(peer_in_add1_out_data_anchors.at(0)->GetOwnerNode()->GetName(), "relu2"); - EXPECT_EQ(peer_in_add1_out_data_anchors.at(1)->GetOwnerNode()->GetName(), "add1"); - const auto add1_out_control_anchor = sub_add1_node->GetOutControlAnchor(); - ASSERT_NE(add1_out_control_anchor, nullptr); - EXPECT_EQ(add1_out_control_anchor->GetPeerInControlAnchors().size(), 1U); - EXPECT_EQ(add1_out_control_anchor->GetPeerInControlAnchors().at(0)->GetOwnerNode()->GetName(), "cast2"); - - // 验证topo序 - std::vector expect_sort; - for (const auto &origin_node_name : origin_node_sort) { - if (origin_node_name == "add0") { - for (const auto &subgraph_node : sub_graph->GetDirectNode()) { - if (subgraph_node->GetType() != "Data") { - expect_sort.emplace_back(subgraph_node->GetName()); - } - } - continue; - } - expect_sort.emplace_back(origin_node_name); - } - size_t index = 0UL; - for (const auto &node : graph->GetDirectNode()) { - EXPECT_EQ(node->GetName(), expect_sort[index]); - index++; - } -} - -/* - Data Data Data - | | | - \ | / Data Data Data - identity ---> \ / \ / - / \ Add0 Add1 - Relu0 Relu1 -*/ -TEST_F(UtestGraphUtils, TestExpandNodeWithOutputNotMatch) { - auto builder = ut::GraphBuilder("test_expand_node_output_not_match"); - const auto &data0 = builder.AddNode("data0", DATA, 0, 1); - const auto &data1 = builder.AddNode("data1", DATA, 0, 1); - const auto &data2 = builder.AddNode("data2", DATA, 0, 1); - const auto &identityN = builder.AddNode("identityn0", "IdentityN", 3, 3); - const auto &relu0 = builder.AddNode("relu0", "Relu", 1, 1); - const auto &relu1 = builder.AddNode("relu1", "Relu", 1, 1); - - // 设置属性 - AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_INDEX, 0); - AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1); - AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_INDEX, 2); - - builder.AddDataEdge(data0, 0, identityN, 0); - builder.AddDataEdge(data1, 0, identityN, 1); - builder.AddDataEdge(data2, 0, identityN, 2); - builder.AddDataEdge(identityN, 0, relu0, 0); - builder.AddDataEdge(identityN, 1, relu1, 0); - auto graph = builder.GetGraph(); - std::vector> output_nodes{{relu0, 0}, {relu1, 0}}; - graph->SetOutputSize(2U); - graph->SetGraphOutNodesInfo(output_nodes); - - std::vector origin_node_sort; - for (const auto &node : graph->GetDirectNode()) { - origin_node_sort.emplace_back(node->GetName()); - } - - auto sub_builder = ut::GraphBuilder("subgraph"); - const auto &sub_data0 = sub_builder.AddNode("sub_data0", DATA, 0, 1); - const auto &sub_data1 = sub_builder.AddNode("sub_data1", DATA, 0, 1); - const auto &sub_data2 = sub_builder.AddNode("sub_data2", DATA, 0, 1); - const auto &sub_add0 = sub_builder.AddNode("sub_add0", "Add", 2, 1); - const auto &sub_add1 = sub_builder.AddNode("sub_add1", "Add", 2, 1); - // 设置属性 - AttrUtils::SetInt(sub_data0->GetOpDesc(), ATTR_NAME_INDEX, 0); - AttrUtils::SetInt(sub_data1->GetOpDesc(), ATTR_NAME_INDEX, 1); - AttrUtils::SetInt(sub_data2->GetOpDesc(), ATTR_NAME_INDEX, 2); - sub_builder.AddDataEdge(sub_data0, 0, sub_add0, 0); - sub_builder.AddDataEdge(sub_data1, 0, sub_add0, 1); - sub_builder.AddDataEdge(sub_data1, 0, sub_add1, 0); - sub_builder.AddDataEdge(sub_data2, 0, sub_add1, 1); - auto sub_graph = sub_builder.GetGraph(); - std::vector> sub_output_nodes{{sub_add0, 0}, {sub_add1, 0}}; - sub_graph->SetOutputSize(2U); - sub_graph->SetGraphOutNodesInfo(sub_output_nodes); - - EXPECT_EQ(GraphUtils::ExpandNodeWithGraph(identityN, sub_graph), SUCCESS); - const auto identityn_node = graph->FindNode("identityn0"); - EXPECT_EQ(identityn_node, nullptr); - const auto subgraph_data0_node = graph->FindNode("sub_data0"); - EXPECT_EQ(subgraph_data0_node, nullptr); - const auto subgraph_data1_node = graph->FindNode("sub_data1"); - EXPECT_EQ(subgraph_data1_node, nullptr); - const auto subgraph_data2_node = graph->FindNode("sub_data2"); - EXPECT_EQ(subgraph_data2_node, nullptr); - - const auto sub_add0_node = graph->FindNode("sub_add0"); - EXPECT_EQ(sub_add0_node, sub_add0); - const auto add0_in_data_anchors = sub_add0_node->GetAllInDataAnchors(); - const auto add0_in_data_anchor_0 = add0_in_data_anchors.at(0); - const auto peer_out_add0_in_data_anchor_0 = add0_in_data_anchor_0->GetPeerOutAnchor(); - ASSERT_NE(peer_out_add0_in_data_anchor_0, nullptr); - EXPECT_EQ(peer_out_add0_in_data_anchor_0->GetOwnerNode()->GetName(), "data0"); - const auto add0_in_data_anchor_1 = add0_in_data_anchors.at(1); - const auto peer_out_add0_in_data_anchor_1 = add0_in_data_anchor_1->GetPeerOutAnchor(); - ASSERT_NE(peer_out_add0_in_data_anchor_1, nullptr); - EXPECT_EQ(peer_out_add0_in_data_anchor_1->GetOwnerNode()->GetName(), "data1"); - - const auto add0_out_data_anchor = sub_add0_node->GetOutDataAnchor(0); - ASSERT_NE(add0_out_data_anchor, nullptr); - const auto peer_in_add0_out_data_anchors = add0_out_data_anchor->GetPeerInDataAnchors(); - EXPECT_EQ(peer_in_add0_out_data_anchors.size(), 1); - EXPECT_EQ(peer_in_add0_out_data_anchors.at(0)->GetOwnerNode()->GetName(), "relu0"); - - const auto sub_add1_node = graph->FindNode("sub_add1"); - EXPECT_EQ(sub_add1_node, sub_add1); - const auto add1_in_data_anchors = sub_add1_node->GetAllInDataAnchors(); - const auto add1_in_data_anchor_0 = add1_in_data_anchors.at(0); - const auto peer_out_add1_in_data_anchor_0 = add1_in_data_anchor_0->GetPeerOutAnchor(); - ASSERT_NE(peer_out_add1_in_data_anchor_0, nullptr); - EXPECT_EQ(peer_out_add1_in_data_anchor_0->GetOwnerNode()->GetName(), "data1"); - const auto add1_in_data_anchor_1 = add1_in_data_anchors.at(1); - const auto peer_out_add1_in_data_anchor_1 = add1_in_data_anchor_1->GetPeerOutAnchor(); - ASSERT_NE(peer_out_add1_in_data_anchor_1, nullptr); - EXPECT_EQ(peer_out_add1_in_data_anchor_1->GetOwnerNode()->GetName(), "data2"); - - const auto add1_out_data_anchor = sub_add1_node->GetOutDataAnchor(0); - ASSERT_NE(add1_out_data_anchor, nullptr); - const auto peer_in_add1_out_data_anchors = add1_out_data_anchor->GetPeerInDataAnchors(); - EXPECT_EQ(peer_in_add1_out_data_anchors.size(), 1); - EXPECT_EQ(peer_in_add1_out_data_anchors.at(0)->GetOwnerNode()->GetName(), "relu1"); - - // 验证topo序 - std::vector expect_sort; - for (const auto &origin_node_name : origin_node_sort) { - if (origin_node_name == "identityn0") { - for (const auto &subgraph_node : sub_graph->GetDirectNode()) { - if (subgraph_node->GetType() != "Data") { - expect_sort.emplace_back(subgraph_node->GetName()); - } - } - continue; - } - expect_sort.emplace_back(origin_node_name); - } - size_t index = 0UL; - for (const auto &node : graph->GetDirectNode()) { - EXPECT_EQ(node->GetName(), expect_sort[index]); - index++; - } -} - -/* - Data Data Data - | | | - \ | / Relu Data Data Data - identity | ---> \ / \ / - \ / Add0 Add1 - Add -*/ -TEST_F(UtestGraphUtils, TestExpandNodeWithOutputWithOutNodeInfos) { - auto builder = ut::GraphBuilder("test_expand_node_output_not_match"); - const auto &data0 = builder.AddNode("data0", DATA, 0, 1); - const auto &data1 = builder.AddNode("data1", DATA, 0, 1); - const auto &data2 = builder.AddNode("data2", DATA, 0, 1); - const auto &identityN = builder.AddNode("identityn0", "IdentityN", 3, 3); - const auto &relu0 = builder.AddNode("relu0", "Relu", 1, 1); - const auto &add0 = builder.AddNode("add0", "Add", 2, 1); - - // 设置属性 - AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_INDEX, 0); - AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1); - AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_INDEX, 2); - - builder.AddDataEdge(data0, 0, identityN, 0); - builder.AddDataEdge(data1, 0, identityN, 1); - builder.AddDataEdge(data2, 0, identityN, 2); - builder.AddDataEdge(data2, 0, relu0, 0); - builder.AddDataEdge(identityN, 1, add0, 0); - builder.AddDataEdge(relu0, 0, add0, 1); - auto graph = builder.GetGraph(); - std::vector> output_nodes{{identityN, 0}, {identityN, 1}, {add0, 0}}; - graph->SetOutputSize(3U); - graph->SetGraphOutNodesInfo(output_nodes); - - std::vector origin_node_sort; - for (const auto &node : graph->GetDirectNode()) { - origin_node_sort.emplace_back(node->GetName()); - } - - auto sub_builder = ut::GraphBuilder("subgraph"); - const auto &sub_data0 = sub_builder.AddNode("sub_data0", DATA, 0, 1); - const auto &sub_data1 = sub_builder.AddNode("sub_data1", DATA, 0, 1); - const auto &sub_data2 = sub_builder.AddNode("sub_data2", DATA, 0, 1); - const auto &sub_add0 = sub_builder.AddNode("sub_add0", "Add", 2, 1); - const auto &sub_add1 = sub_builder.AddNode("sub_add1", "Add", 2, 1); - // 设置属性 - AttrUtils::SetInt(sub_data0->GetOpDesc(), ATTR_NAME_INDEX, 0); - AttrUtils::SetInt(sub_data1->GetOpDesc(), ATTR_NAME_INDEX, 1); - AttrUtils::SetInt(sub_data2->GetOpDesc(), ATTR_NAME_INDEX, 2); - sub_builder.AddDataEdge(sub_data0, 0, sub_add0, 0); - sub_builder.AddDataEdge(sub_data1, 0, sub_add0, 1); - sub_builder.AddDataEdge(sub_data1, 0, sub_add1, 0); - sub_builder.AddDataEdge(sub_data2, 0, sub_add1, 1); - auto sub_graph = sub_builder.GetGraph(); - std::vector> sub_output_nodes{{sub_add0, 0}, {sub_add1, 0}}; - sub_graph->SetOutputSize(2U); - sub_graph->SetGraphOutNodesInfo(sub_output_nodes); - - EXPECT_EQ(GraphUtils::ExpandNodeWithGraph(identityN, sub_graph), SUCCESS); - const auto identityn_node = graph->FindNode("identityn0"); - EXPECT_EQ(identityn_node, nullptr); - const auto subgraph_data0_node = graph->FindNode("sub_data0"); - EXPECT_EQ(subgraph_data0_node, nullptr); - const auto subgraph_data1_node = graph->FindNode("sub_data1"); - EXPECT_EQ(subgraph_data1_node, nullptr); - const auto subgraph_data2_node = graph->FindNode("sub_data2"); - EXPECT_EQ(subgraph_data2_node, nullptr); - - const auto sub_add0_node = graph->FindNode("sub_add0"); - EXPECT_EQ(sub_add0_node, sub_add0); - const auto add0_in_data_anchors = sub_add0_node->GetAllInDataAnchors(); - const auto add0_in_data_anchor_0 = add0_in_data_anchors.at(0); - const auto peer_out_add0_in_data_anchor_0 = add0_in_data_anchor_0->GetPeerOutAnchor(); - ASSERT_NE(peer_out_add0_in_data_anchor_0, nullptr); - EXPECT_EQ(peer_out_add0_in_data_anchor_0->GetOwnerNode()->GetName(), "data0"); - const auto add0_in_data_anchor_1 = add0_in_data_anchors.at(1); - const auto peer_out_add0_in_data_anchor_1 = add0_in_data_anchor_1->GetPeerOutAnchor(); - ASSERT_NE(peer_out_add0_in_data_anchor_1, nullptr); - EXPECT_EQ(peer_out_add0_in_data_anchor_1->GetOwnerNode()->GetName(), "data1"); - - const auto add0_out_data_anchor = sub_add0_node->GetOutDataAnchor(0); - ASSERT_NE(add0_out_data_anchor, nullptr); - const auto peer_in_add0_out_data_anchors = add0_out_data_anchor->GetPeerInDataAnchors(); - EXPECT_EQ(peer_in_add0_out_data_anchors.size(), 0UL); - - const auto sub_add1_node = graph->FindNode("sub_add1"); - EXPECT_EQ(sub_add1_node, sub_add1); - const auto add1_in_data_anchors = sub_add1_node->GetAllInDataAnchors(); - const auto add1_in_data_anchor_0 = add1_in_data_anchors.at(0); - const auto peer_out_add1_in_data_anchor_0 = add1_in_data_anchor_0->GetPeerOutAnchor(); - ASSERT_NE(peer_out_add1_in_data_anchor_0, nullptr); - EXPECT_EQ(peer_out_add1_in_data_anchor_0->GetOwnerNode()->GetName(), "data1"); - const auto add1_in_data_anchor_1 = add1_in_data_anchors.at(1); - const auto peer_out_add1_in_data_anchor_1 = add1_in_data_anchor_1->GetPeerOutAnchor(); - ASSERT_NE(peer_out_add1_in_data_anchor_1, nullptr); - EXPECT_EQ(peer_out_add1_in_data_anchor_1->GetOwnerNode()->GetName(), "data2"); - - const auto add1_out_data_anchor = sub_add1_node->GetOutDataAnchor(0); - ASSERT_NE(add1_out_data_anchor, nullptr); - const auto peer_in_add1_out_data_anchors = add1_out_data_anchor->GetPeerInDataAnchors(); - EXPECT_EQ(peer_in_add1_out_data_anchors.size(), 1); - EXPECT_EQ(peer_in_add1_out_data_anchors.at(0)->GetOwnerNode()->GetName(), "add0"); - - // 验证输出NodeInfo - const auto output_node_infos = graph->GetGraphOutNodesInfo(); - std::vector> expect_node = {{"sub_add0", 0}, {"sub_add1", 0}, {"add0", 0}}; - ASSERT_EQ(output_node_infos.size(), expect_node.size()); - for (size_t i = 0UL; i < output_node_infos.size(); i++) { - EXPECT_EQ(output_node_infos[i].first->GetName(), expect_node[i].first); - EXPECT_EQ(output_node_infos[i].second, expect_node[i].second); - } - // 验证topo序 - std::vector expect_sort; - for (const auto &origin_node_name : origin_node_sort) { - if (origin_node_name == "identityn0") { - for (const auto &subgraph_node : sub_graph->GetDirectNode()) { - if (subgraph_node->GetType() != "Data") { - expect_sort.emplace_back(subgraph_node->GetName()); - } - } - continue; - } - expect_sort.emplace_back(origin_node_name); - } - size_t index = 0UL; - for (const auto &node : graph->GetDirectNode()) { - EXPECT_EQ(node->GetName(), expect_sort[index]); - index++; - } -} -TEST_F(UtestGraphUtils, CreateGraphPtrFromComputeGraphOk) { - auto compute_graph = std::make_shared("test_graph"); - auto graph = GraphUtilsEx::CreateGraphPtrFromComputeGraph(compute_graph); - ASSERT_NE(graph, nullptr); - EXPECT_EQ(graph->GetName(), "test_graph"); - auto cg2 = GraphUtilsEx::GetComputeGraph(*graph); - ASSERT_EQ(compute_graph.get(), cg2.get()); -} -TEST_F(UtestGraphUtils, CreateGraphPtrFromComputeGraph_nullptr) { - auto graph = GraphUtilsEx::CreateGraphPtrFromComputeGraph(nullptr); - ASSERT_EQ(graph, nullptr); -} - -TEST_F(UtestGraphUtils, ConvertInDataEdgesToInCtrlEdges_nullptr) { - EXPECT_EQ(GRAPH_PARAM_INVALID, GraphUtils::ConvertInDataEdgesToInCtrlEdges(nullptr, nullptr, nullptr)); -} - -TEST_F(UtestGraphUtils, ConvertOutDataEdgesToOutCtrlEdges_nullptr) { - EXPECT_EQ(GRAPH_PARAM_INVALID, GraphUtils::ConvertOutDataEdgesToOutCtrlEdges(nullptr, nullptr, nullptr)); -} - -TEST_F(UtestGraphUtils, IsolateNodeNodeWithNoOpOptimize) { - auto graph = BuildGraphForIsolateNode(30, 30, 10, 10); - auto node = graph->FindNode("del_node"); - // 断掉所有连边 - std::vector io_map = {}; - auto ret = GraphUtils::IsolateNode(node, io_map); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - auto noop = graph->FindFirstNodeMatchType(NOOP); - EXPECT_NE(noop, nullptr); - EXPECT_EQ(noop->GetInControlNodesSize(), 40UL); - EXPECT_EQ(noop->GetOutControlNodesSize(), 40UL); -} - -TEST_F(UtestGraphUtils, IsolateNodeNodeWithNoOpOptimize_AllCtrlEdge) { - // 达到阈值触发NoOp优化 - int io_num = 32; - auto graph = BuildGraphForIsolateNode(0, 0, io_num, io_num); - auto node = graph->FindNode("del_node"); - std::vector io_map{}; - auto ret = GraphUtils::IsolateNode(node, io_map); - GraphUtils::DumpGEGraphToOnnx(*graph, "TestIsolateNode_After"); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - auto noop = graph->FindFirstNodeMatchType(NOOP); - EXPECT_NE(noop, nullptr); - EXPECT_EQ(noop->GetInControlNodesSize(), io_num); - EXPECT_EQ(noop->GetOutControlNodesSize(), io_num); - - auto in_node = graph->FindNode("in_ctrl_node_0"); - auto out_node = graph->FindNode("out_ctrl_node_0"); - EXPECT_NE(in_node, nullptr); - EXPECT_NE(out_node, nullptr); - EXPECT_EQ(in_node->GetOutControlNodesSize(), 1); - EXPECT_EQ(out_node->GetInControlNodesSize(), 1); -} - -TEST_F(UtestGraphUtils, IsolateNodeNodeWithOutNoOpOptimize_AllCtrlEdge) { - // 未达到阈值,未触发NoOp优化 - int io_num = 30; - auto graph = BuildGraphForIsolateNode(0, 0, io_num, io_num); - auto node = graph->FindNode("del_node"); - std::vector io_map{}; - auto ret = GraphUtils::IsolateNode(node, io_map); - GraphUtils::DumpGEGraphToOnnx(*graph, "TestIsolateNode_After"); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - auto noop = graph->FindFirstNodeMatchType(NOOP); - EXPECT_EQ(noop, nullptr); - - auto in_node = graph->FindNode("in_ctrl_node_0"); - auto out_node = graph->FindNode("out_ctrl_node_0"); - EXPECT_NE(in_node, nullptr); - EXPECT_NE(out_node, nullptr); - EXPECT_EQ(in_node->GetOutControlNodesSize(), io_num); - EXPECT_EQ(out_node->GetInControlNodesSize(), io_num); -} - -TEST_F(UtestGraphUtils, IsolateNodeNodeWithNoOpOptimize_SetIoMap_NoInAnchr) { - int io_num = 32; - auto graph = BuildGraphForIsolateNode(io_num, io_num, 0, 0); - auto node = graph->FindNode("del_node"); - // 指定输入0和所有输出建立数据边,因此该输入节点不需要连控制边到NoOp - std::vector io_map(io_num, 0); - auto ret = GraphUtils::IsolateNode(node, io_map); - GraphUtils::DumpGEGraphToOnnx(*graph, "TestIsolateNode_After"); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - auto noop = graph->FindFirstNodeMatchType(NOOP); - EXPECT_NE(noop, nullptr); - EXPECT_EQ(noop->GetInControlNodesSize(), io_num - 1); - EXPECT_EQ(noop->GetOutControlNodesSize(), io_num); - - auto in_node = graph->FindNode("in_node_0"); - auto out_node = graph->FindNode("out_node_0"); - EXPECT_NE(in_node, nullptr); - EXPECT_NE(out_node, nullptr); - EXPECT_NE(out_node->GetInDataAnchor(0), nullptr); - EXPECT_NE(out_node->GetInDataAnchor(0)->GetPeerOutAnchor(), nullptr); - EXPECT_EQ(out_node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(), in_node); -} - -TEST_F(UtestGraphUtils, IsolateNodeNodeWithNoOpOptimize_SetIoMap_NoOutAnchr) { - auto graph_builder = ut::GraphBuilder("graph"); - const auto &del_node = graph_builder.AddNode("del_node", "DelNode", 2, 510); - - const auto &n1 = graph_builder.AddNode("in_node_1", "InNode", 1, 1); - graph_builder.AddDataEdge(n1, 0, del_node, 0); - - const auto &n2 = graph_builder.AddNode("in_node_2", "InNode", 1, 1); - graph_builder.AddDataEdge(n2, 0, del_node, 1); - - const auto &n3 = graph_builder.AddNode("out_node_0", "OutNode", 2, 1); - graph_builder.AddDataEdge(del_node, 0, n3, 0); - graph_builder.AddDataEdge(del_node, 1, n3, 1); - - for (int i = 2; i < 510; ++i) { - const auto &n = graph_builder.AddNode("out_node" + std::to_string(i), "OutNode", 1, 1); - graph_builder.AddDataEdge(del_node, i, n, 0); - } - auto graph = graph_builder.GetGraph(); - - - auto node = graph->FindNode("del_node"); - setenv("DUMP_GE_GRAPH", "1", 1); - setenv("DUMP_GRAPH_LEVEL", "1", 1); - setenv("DUMP_GRAPH_PATH", "/home/yangyongqiang/code/dump_graph", 1); - dlog_setlevel(0, 0, 0); - GraphUtils::DumpGEGraphToOnnx(*graph, "TestIsolateNode_Before"); - - // 指定输出0和所有输入建立数据边,因此该输出节点不需要连控制边到NoOp - std::vector io_map = {0, 1}; - auto ret = GraphUtils::IsolateNode(node, io_map); - GraphUtils::DumpGEGraphToOnnx(*graph, "TestIsolateNode_After"); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - auto noop = graph->FindFirstNodeMatchType(NOOP); - EXPECT_NE(noop, nullptr); - EXPECT_EQ(noop->GetInControlNodesSize(), 2); - EXPECT_EQ(noop->GetOutControlNodesSize(), 508); - - auto in_node_1 = graph->FindNode("in_node_1"); - auto in_node_2 = graph->FindNode("in_node_2"); - auto out_node_0 = graph->FindNode("out_node_0"); - EXPECT_NE(in_node_1, nullptr); - EXPECT_NE(in_node_2, nullptr); - EXPECT_NE(out_node_0, nullptr); - EXPECT_NE(out_node_0->GetInDataAnchor(0), nullptr); - EXPECT_NE(out_node_0->GetInDataAnchor(0)->GetPeerOutAnchor(), nullptr); - EXPECT_EQ(out_node_0->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(), in_node_1); - EXPECT_NE(out_node_0->GetInDataAnchor(1), nullptr); - EXPECT_NE(out_node_0->GetInDataAnchor(1)->GetPeerOutAnchor(), nullptr); - EXPECT_EQ(out_node_0->GetInDataAnchor(1)->GetPeerOutAnchor()->GetOwnerNode(), in_node_2); -} - -TEST_F(UtestGraphUtils, DumpGEGraphToOnnxForLongName_Then_CutoffTheFileName) { - setenv("DUMP_GE_GRAPH", "1", 1); - ComputeGraph compute_graph("test_graph0"); - const std::string suffix = std::string(255, 'a'); - ge::GraphUtils::DumpGEGraphToOnnx(compute_graph, suffix); - unsetenv("DUMP_GE_GRAPH"); - system("rm -rf ./ge_onnx*"); -} - -TEST_F(UtestGraphUtils, DumpGrphToOnnxForLongName_Then_CutoffTheFileName) { - setenv("DUMP_GE_GRAPH", "1", 1); - ComputeGraph compute_graph("test_graph0"); - const std::string suffix = std::string(255, 'a'); - const std::string path = "./"; - ge::GraphUtils::DumpGrphToOnnx(compute_graph, path, suffix); - unsetenv("DUMP_GE_GRAPH"); - system("rm -rf ./ge_onnx*"); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/hcom_topo_info_unittest.cc b/tests/ut/graph/testcase/hcom_topo_info_unittest.cc deleted file mode 100644 index 1e21c3679e513f53e8bd5247586db4828d29fc11..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/hcom_topo_info_unittest.cc +++ /dev/null @@ -1,125 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "external/hcom/hcom_topo_info.h" -namespace ge { -class UtestHcomTopoInfo : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; -TEST_F(UtestHcomTopoInfo, SetGroupTopoInfo) { - HcomTopoInfo::TopoInfo topo_info; - topo_info.rank_size = 8; - const std::string group = "group0"; - - // add invalid - EXPECT_EQ(HcomTopoInfo::Instance().SetGroupTopoInfo(nullptr, topo_info), GRAPH_FAILED); - - EXPECT_EQ(HcomTopoInfo::Instance().SetGroupTopoInfo(group.c_str(), topo_info), GRAPH_SUCCESS); - // add repeated, over write - topo_info.notify_handle = reinterpret_cast(0x8000); - EXPECT_EQ(HcomTopoInfo::Instance().SetGroupTopoInfo(group.c_str(), topo_info), GRAPH_SUCCESS); - void *handle = nullptr; - EXPECT_EQ(HcomTopoInfo::Instance().GetGroupNotifyHandle(group.c_str(), handle), GRAPH_SUCCESS); - EXPECT_EQ(handle, reinterpret_cast(0x8000)); - HcomTopoInfo::TopoInfo topo_info_existed; - EXPECT_TRUE(HcomTopoInfo::Instance().TryGetGroupTopoInfo(group.c_str(), topo_info_existed)); - EXPECT_TRUE(HcomTopoInfo::Instance().TopoInfoHasBeenSet(group.c_str())); - EXPECT_EQ(topo_info_existed.notify_handle, reinterpret_cast(0x8000)); - EXPECT_EQ(topo_info_existed.rank_size, 8); -} - -TEST_F(UtestHcomTopoInfo, GetAndUnsetGroupTopoInfo) { - int64_t rank_size = -1; - EXPECT_EQ(HcomTopoInfo::Instance().GetGroupRankSize("group0", rank_size), GRAPH_SUCCESS); - EXPECT_EQ(rank_size, 8); - // not added, get failed - EXPECT_EQ(HcomTopoInfo::Instance().GetGroupRankSize("group1", rank_size), GRAPH_FAILED); - EXPECT_EQ(HcomTopoInfo::Instance().GetGroupTopoDesc("group1"), nullptr); - void *handle = nullptr; - EXPECT_EQ(HcomTopoInfo::Instance().GetGroupNotifyHandle("group1", handle), GRAPH_FAILED); - HcomTopoInfo::Instance().UnsetGroupTopoInfo("group1"); - HcomTopoInfo::Instance().UnsetGroupTopoInfo("group0"); - // removed - EXPECT_EQ(HcomTopoInfo::Instance().GetGroupRankSize("group0", rank_size), GRAPH_FAILED); - - // construct topo info - HcomTopoInfo::TopoInfo topo_info; - HcomTopoInfo::TopoLevelDesc t0 = {2, 2}; - HcomTopoInfo::TopoLevelDesc t1 = {6, 6}; - EXPECT_EQ(sizeof(topo_info.topo_level_descs) / sizeof(HcomTopoInfo::TopoLevelDesc), - static_cast(HcomTopoInfo::TopoLevel::MAX)); - topo_info.rank_size = 16; - topo_info.topo_level_descs[0] = t0; - topo_info.topo_level_descs[1] = t1; - // add again - EXPECT_EQ(HcomTopoInfo::Instance().SetGroupTopoInfo("group0", topo_info), GRAPH_SUCCESS); - EXPECT_EQ(HcomTopoInfo::Instance().GetGroupRankSize("group0", rank_size), GRAPH_SUCCESS); - EXPECT_EQ(rank_size, 16); - EXPECT_EQ(HcomTopoInfo::Instance().GetGroupRankSize("", rank_size), GRAPH_FAILED); - // check - auto topo_desc = HcomTopoInfo::Instance().GetGroupTopoDesc("group0"); - EXPECT_NE(topo_desc, nullptr); - EXPECT_EQ(sizeof(*topo_desc) / sizeof(HcomTopoInfo::TopoLevelDesc), - static_cast(HcomTopoInfo::TopoLevel::MAX)); - - HcomTopoInfo::TopoDescs topo_desc_to_check = {t0, t1}; - EXPECT_EQ((*topo_desc)[0].comm_sets, topo_desc_to_check[0].comm_sets); - EXPECT_EQ((*topo_desc)[0].rank_size, topo_desc_to_check[0].rank_size); - EXPECT_EQ((*topo_desc)[1].comm_sets, topo_desc_to_check[1].comm_sets); - EXPECT_EQ((*topo_desc)[1].rank_size, topo_desc_to_check[1].rank_size); -} - -TEST_F(UtestHcomTopoInfo, SetAndGetAndUnsetGroupOrderedStreamWithDeviceId) { - const std::string group0 = "group0"; - const std::string group1 = "group1"; - const std::string group = "group"; - void *stream0= (void *)1; - void *stream1= (void *)2; - void *stream = nullptr; - int32_t device0 = 0; - int32_t device1 = 1; - - // set group nullptr - EXPECT_EQ(HcomTopoInfo::Instance().SetGroupOrderedStream(device0, nullptr, stream0), GRAPH_FAILED); - - // set and get - EXPECT_EQ(HcomTopoInfo::Instance().SetGroupOrderedStream(device0, group0.c_str(), stream0), GRAPH_SUCCESS); - EXPECT_EQ(HcomTopoInfo::Instance().GetGroupOrderedStream(device0, group0.c_str(), stream), GRAPH_SUCCESS); - EXPECT_EQ(stream, stream0); - - EXPECT_EQ(HcomTopoInfo::Instance().SetGroupOrderedStream(device0, group1.c_str(), stream1), GRAPH_SUCCESS); - EXPECT_EQ(HcomTopoInfo::Instance().GetGroupOrderedStream(device0, group1.c_str(), stream), GRAPH_SUCCESS); - EXPECT_EQ(stream, stream1); - - // override - EXPECT_EQ(HcomTopoInfo::Instance().SetGroupOrderedStream(device0, group0.c_str(), stream1), GRAPH_SUCCESS); - EXPECT_EQ(HcomTopoInfo::Instance().GetGroupOrderedStream(device0, group0.c_str(), stream), GRAPH_SUCCESS); - EXPECT_EQ(stream, stream1); - - // get no device - EXPECT_EQ(HcomTopoInfo::Instance().GetGroupOrderedStream(device1, group.c_str(), stream), GRAPH_FAILED); - - // get no group - EXPECT_EQ(HcomTopoInfo::Instance().GetGroupOrderedStream(device0, group.c_str(), stream), GRAPH_FAILED); - - // unset group - HcomTopoInfo::Instance().UnsetGroupOrderedStream(device0, group0.c_str()); - EXPECT_EQ(HcomTopoInfo::Instance().GetGroupOrderedStream(device0, group0.c_str(), stream), GRAPH_FAILED); - HcomTopoInfo::Instance().UnsetGroupOrderedStream(device0, group1.c_str()); - EXPECT_EQ(HcomTopoInfo::Instance().GetGroupOrderedStream(device0, group1.c_str(), stream), GRAPH_FAILED); - - // repeat unset - HcomTopoInfo::Instance().UnsetGroupOrderedStream(device0, group1.c_str()); - EXPECT_EQ(HcomTopoInfo::Instance().GetGroupOrderedStream(device0, group1.c_str(), stream), GRAPH_FAILED); -} - -} diff --git a/tests/ut/graph/testcase/infer_datatype_unittest.cc b/tests/ut/graph/testcase/infer_datatype_unittest.cc deleted file mode 100644 index 997364a1f6790fdf68a2a1bffd1dd3da3d7af7c2..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/infer_datatype_unittest.cc +++ /dev/null @@ -1,599 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include - -#include "graph/op_desc.h" -#include "graph/normal_graph/op_desc_impl.h" -#include "graph_builder_utils.h" -#include "graph/operator_reg.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/type_utils.h" -#include "graph/operator_factory_impl.h" -#include "graph/compute_graph.h" -#include "graph/utils/recover_ir_utils.h" -#include "toolchain/slog.h" - -namespace ge { -class UTInferDataType : public testing::Test { - protected: - void SetUp() { - dlog_setlevel(0, 0, 0); - } - - void TearDown() { - dlog_setlevel(0, 3, 0); - } -}; - -class OpDtypeInfer { - public: - struct TypeOrTypes { - explicit TypeOrTypes(const DataType &expect_type) : dynamic(false), types({expect_type}) {} - explicit TypeOrTypes(const std::vector &expect_types) : dynamic(true), types(expect_types) {} - - bool dynamic = false; - std::vector types; - }; - - explicit OpDtypeInfer(const std::string &type) { - auto op = OperatorFactory::CreateOperator(type, type); - desc_ = OpDescUtils::GetOpDescFromOperator(op); - } - - explicit OpDtypeInfer(const OpDescPtr &desc) : desc_(desc) {} - - OpDtypeInfer &Input(const DataType &type) { - int32_t ir_index = ++index_; - auto format = FORMAT_ND; - if (type == DT_UNDEFINED) { - format = FORMAT_RESERVED; - } - desc_->UpdateInputDesc("input" + std::to_string(ir_index), GeTensorDesc(GeShape(), format, type)); - return *this; - } - - OpDtypeInfer &Input(const std::initializer_list &raw_types) { - std::vector types(raw_types); - int32_t ir_index = ++index_; - desc_->AddDynamicInputDesc("input" + std::to_string(ir_index), types.size()); - for (size_t i = 0U; i < types.size(); ++i) { - desc_->UpdateInputDesc("input" + std::to_string(ir_index) + std::to_string(i), - GeTensorDesc(GeShape(), FORMAT_ND, types[i])); - } - return *this; - } - - OpDtypeInfer &Attr(const std::string &attr, const std::vector &types) { - AttrUtils::SetListDataType(desc_, attr, types); - return *this; - } - - OpDtypeInfer &Attr(const std::string &attr, const std::vector &types) { - AttrUtils::SetListInt(desc_, attr, types); - return *this; - } - - OpDtypeInfer &Attr(const std::string &attr, int32_t type) { - AttrUtils::SetInt(desc_, attr, type); - return *this; - } - - OpDtypeInfer &Attr(const std::string &attr, DataType type) { - AttrUtils::SetDataType(desc_, attr, type); - return *this; - } - - OpDtypeInfer &Expect(const DataType &type) { - expect_dtypes_.emplace_back(type); - return *this; - } - - OpDtypeInfer &Expect(const std::vector &types) { - desc_->AddDynamicOutputDesc("output" + std::to_string(expect_dtypes_.size() + 1U), types.size()); - expect_dtypes_.emplace_back(types); - return *this; - } - - void AssertSucceed() { - ASSERT_EQ(desc_->SymbolicInferDataType(), GRAPH_SUCCESS); - for (size_t i = 0U; i < expect_dtypes_.size(); ++i) { - std::string ir_output = "output" + std::to_string(i + 1); - if (!expect_dtypes_[i].dynamic) { - ASSERT_EQ(TypeUtils::DataTypeToSerialString(desc_->GetOutputDesc(ir_output).GetDataType()), - TypeUtils::DataTypeToSerialString(expect_dtypes_[i].types[0])); - } else { - for (size_t j = 0U; j < expect_dtypes_[i].types.size(); ++j) { - std::string ir_output_index = ir_output + std::to_string(j); - ASSERT_EQ(TypeUtils::DataTypeToSerialString(desc_->GetOutputDesc(ir_output_index).GetDataType()), - TypeUtils::DataTypeToSerialString(expect_dtypes_[i].types[j])); - } - } - } - } - - void AssertFailed() { - ASSERT_NE(desc_->SymbolicInferDataType(), GRAPH_SUCCESS); - } - - private: - std::vector expect_dtypes_; - int32_t index_ = 0; - OpDescPtr desc_; -}; - -/* ---------- 基于符号进行推导的基础用例 ---------- */ -REG_OP(Op1) - .OPTIONAL_INPUT(input1, "T") - .INPUT(input2, "T") - .DYNAMIC_INPUT(input3, "T") - .OUTPUT(output1, "T") - .DYNAMIC_OUTPUT(output2, "T") - .DATATYPE(T, TensorType({DT_FLOAT16, DT_FLOAT})) - .OP_END_FACTORY_REG(Op1); -TEST_F(UTInferDataType, sym_infer_from_regular_input_succeed) { - OpDtypeInfer("Op1") // T全部全部传入 - .Input(DT_FLOAT16) - .Input(DT_FLOAT16) - .Input({DT_FLOAT16, DT_FLOAT16}) - .Expect(DT_FLOAT16) - .Expect({DT_FLOAT16, DT_FLOAT16}) - .AssertSucceed(); -} -TEST_F(UTInferDataType, sym_infer_from_regular_input_unfed_opt) { - OpDtypeInfer("Op1") // 可选输入不传入,根据其他输入推导 - .Input(DT_UNDEFINED) - .Input(DT_FLOAT16) - .Input({DT_FLOAT16, DT_FLOAT16}) - .Expect(DT_FLOAT16) - .Expect({DT_FLOAT16, DT_FLOAT16}) - .AssertSucceed(); -} -TEST_F(UTInferDataType, sym_infer_from_regular_input_only_require) { - OpDtypeInfer("Op1") // 可选和动态都不传入,根据其他输入推导 - .Input(DT_UNDEFINED) - .Input(DT_FLOAT16) - .Input({}) - .Expect(DT_FLOAT16) - .Expect({DT_FLOAT16, DT_FLOAT16}) - .AssertSucceed(); -} -TEST_F(UTInferDataType, sym_infer_from_regular_input_dtype_out_of_range) { - OpDtypeInfer("Op1") // 类型不在可选范围内 - .Input(DT_INT32) - .Input(DT_INT32) - .Input({DT_INT32, DT_INT32}) - .AssertFailed(); -} -TEST_F(UTInferDataType, sym_infer_from_regular_input_dtype_mismatch_between_ir_inputs) { - OpDtypeInfer("Op1") // 两个IR输入类型不一致 - .Input(DT_FLOAT16) - .Input(DT_FLOAT) - .Input({DT_FLOAT16, DT_FLOAT16}) - .AssertFailed(); -} -TEST_F(UTInferDataType, sym_infer_from_regular_input_dtype_mismatch_in_dyn) { - OpDtypeInfer("Op1") // 动态输入中的多个类型不一致 - .Input(DT_FLOAT16) - .Input(DT_FLOAT16) - .Input({DT_FLOAT16, DT_FLOAT}) - .AssertFailed(); -} - -/* ---------- 基于可选输入进行推导 ---------- */ -REG_OP(Op2) - .OPTIONAL_INPUT(input1, "T") - .OUTPUT(output1, "T") - .DYNAMIC_OUTPUT(output2, "T") - .DATATYPE(T, TensorType({DT_FLOAT16})) - .OP_END_FACTORY_REG(Op2); -TEST_F(UTInferDataType, sym_infer_from_optional_input_succeed) { - OpDtypeInfer("Op2") // 可选输入传入 - .Input(DT_FLOAT16) - .Expect(DT_FLOAT16) - .Expect({DT_FLOAT16, DT_FLOAT16}) - .AssertSucceed(); -} -TEST_F(UTInferDataType, sym_infer_from_optional_input_unfed_opt) { - EXPECT_NO_THROW( - OpDtypeInfer("Op2") // 可选不传入,无法推导 - .Input(DT_UNDEFINED) - .AssertFailed(); - ); -} - -/* ---------- 基于动态输入进行推导 ---------- */ -REG_OP(Op3) - .DYNAMIC_INPUT(input1, "T") - .OUTPUT(output1, "T") - .DYNAMIC_OUTPUT(output2, "T") - .DATATYPE(T, TensorType({DT_FLOAT16})) - .OP_END_FACTORY_REG(Op3); -TEST_F(UTInferDataType, sym_infer_from_dynamic_input_succeed) { - OpDtypeInfer("Op3") // 动态输入不为空 - .Input({DT_FLOAT16}) - .Expect(DT_FLOAT16) - .Expect({DT_FLOAT16, DT_FLOAT16}) - .AssertSucceed(); -} -TEST_F(UTInferDataType, sym_infer_from_dynamic_input_unfed_dyn) { - OpDtypeInfer("Op3") // 可选不传入,无法推导 - .Input({}) - .AssertFailed(); -} - -/* ---------- 基于属性进行推导 ---------- */ -REG_OP(Op4) - .REQUIRED_ATTR(dtype1, Int) - .REQUIRED_ATTR(dtype2, Type) - .REQUIRED_ATTR(dtype3, ListInt) - .REQUIRED_ATTR(dtype4, ListType) - .OUTPUT(output1, "dtype1") - .OUTPUT(output2, "dtype2") - .DYNAMIC_OUTPUT(output3, "dtype3") // 动态输出被List属性指定 - .DYNAMIC_OUTPUT(output4, "dtype4") - .DYNAMIC_OUTPUT(output5, "dtype1") // 动态输出被单个属性指定 - .DYNAMIC_OUTPUT(output6, "dtype2") - .DATATYPE(dtype1, TensorType({DT_FLOAT16, DT_FLOAT})) - .DATATYPE(dtype2, TensorType({DT_FLOAT16, DT_FLOAT})) - .DATATYPE(dtype3, ListTensorType({DT_FLOAT16, DT_FLOAT})) - .DATATYPE(dtype4, ListTensorType({DT_FLOAT16, DT_FLOAT})) - .OP_END_FACTORY_REG(Op4); -TEST_F(UTInferDataType, sym_infer_from_attr_succeed) { - OpDtypeInfer("Op4") // 根据属性进行推导 - .Attr("dtype1", int32_t(DT_FLOAT16)) - .Attr("dtype2", DT_FLOAT16) - .Attr("dtype3", std::vector{DT_FLOAT16, DT_FLOAT}) - .Attr("dtype4", std::vector{DT_FLOAT16, DT_FLOAT}) - .Expect(DT_FLOAT16) - .Expect(DT_FLOAT16) - .Expect({DT_FLOAT16, DT_FLOAT}) // 动态输出被List属性指定 - .Expect({DT_FLOAT16, DT_FLOAT}) - .Expect({DT_FLOAT16, DT_FLOAT16}) // 动态输出被单个属性指定 - .Expect({DT_FLOAT16, DT_FLOAT16}) - .AssertSucceed(); -} -TEST_F(UTInferDataType, sym_infer_from_attr_dtype_out_of_range) { - OpDtypeInfer("Op4") // 属性不在允许范围内 - .Attr("dtype1", int32_t(DT_FLOAT16)) - .Attr("dtype2", DT_INT32) // 非法输入类型 - .Attr("dtype3", std::vector{DT_FLOAT16, DT_FLOAT}) - .Attr("dtype4", std::vector{DT_FLOAT16, DT_FLOAT}) - .AssertFailed(); -} -TEST_F(UTInferDataType, sym_infer_from_attr_list_dtype_out_of_range) { - OpDtypeInfer("Op4") // 属性不在允许范围内 - .Attr("dtype1", int32_t(DT_FLOAT16)) - .Attr("dtype2", DT_FLOAT16) - .Attr("dtype3", std::vector{DT_FLOAT16, DT_FLOAT}) - .Attr("dtype4", std::vector{DT_FLOAT16, DT_INT32}) // 非法输入类型 - .AssertFailed(); -} - -TEST_F(UTInferDataType, sym_infer_from_attr_but_type_mismatch_1) { - OpDtypeInfer("Op4") - .Attr("dtype1", int32_t(DT_FLOAT16)) - .Attr("dtype2", std::vector{DT_FLOAT16, DT_FLOAT}) // 需要单个类型,但是传入List - .Attr("dtype3", std::vector{DT_FLOAT16, DT_FLOAT}) - .Attr("dtype4", std::vector{DT_FLOAT16, DT_FLOAT}) - .AssertFailed(); -} - -TEST_F(UTInferDataType, sym_infer_from_attr_but_type_mismatch_2) { - OpDtypeInfer("Op4") - .Attr("dtype1", int32_t(DT_FLOAT16)) - .Attr("dtype2", DT_FLOAT16) - .Attr("dtype3", DT_FLOAT16) // 需要List类型,但是传入单个类型 - .Attr("dtype4", std::vector{DT_FLOAT16, DT_FLOAT}) - .AssertFailed(); -} - -/* ---------- 输出类型唯一场景,支持推导(类似Equal算子固定输出bool) ---------- */ -REG_OP(Op6) - .OUTPUT(output1, "T") - .DYNAMIC_OUTPUT(output2, "T") - .DATATYPE(T, TensorType({DT_FLOAT16})) - .OP_END_FACTORY_REG(Op6); -TEST_F(UTInferDataType, sym_infer_for_const_output_dtype) { // 老旧方式注册,但是输出类型唯一 - OpDtypeInfer("Op6").Expect(DT_FLOAT16).Expect({DT_FLOAT16, DT_FLOAT16}).AssertSucceed(); -} - -/* ---------- ListTensorType的类型推导 ---------- */ -REG_OP(Op7) - .DYNAMIC_INPUT(input1, "T") - .DYNAMIC_INPUT(input2, "T") - .DYNAMIC_OUTPUT(output1, "T") - .DATATYPE(T, ListTensorType({DT_FLOAT16, DT_FLOAT})) - .OP_END_FACTORY_REG(Op7); -TEST_F(UTInferDataType, sym_infer_for_list_dtype_succeed) { - OpDtypeInfer("Op7") // 正常推导 - .Input({DT_FLOAT16, DT_FLOAT, DT_FLOAT, DT_FLOAT16}) - .Input({DT_FLOAT16, DT_FLOAT, DT_FLOAT, DT_FLOAT16}) - .Expect({DT_FLOAT16, DT_FLOAT, DT_FLOAT, DT_FLOAT16}) - .AssertSucceed(); -} -TEST_F(UTInferDataType, sym_infer_for_list_dtype_dtype_mismatch_between_dyn) { - OpDtypeInfer("Op7") // 对应同一个ListType sym的两个输入,类型合法但是不一致 - .Input({DT_FLOAT16, DT_FLOAT, DT_FLOAT, DT_FLOAT16}) - .Input({DT_FLOAT16, DT_FLOAT, DT_FLOAT16, DT_FLOAT16}) // 第三个输入类型不一致 - .Expect({DT_FLOAT16, DT_FLOAT, DT_FLOAT, DT_FLOAT16}) - .AssertFailed(); -} -TEST_F(UTInferDataType, sym_infer_for_list_dtype_dtype_out_of_range) { - OpDtypeInfer("Op7") // 数据类型不在范围内 - .Input({DT_FLOAT16, DT_INT32}) - .Input({DT_FLOAT16, DT_INT32}) - .AssertFailed(); -} - -/* ---------- 类型提升方式推导 ---------- */ -// 基础类型间提升 -REG_OP(Op8) - .INPUT(input1, "T1") - .DYNAMIC_INPUT(input2, "T2") - .INPUT(input3, "T3") - .OUTPUT(output1, "T4") - .DYNAMIC_OUTPUT(output2, "T5") - .DATATYPE(T1, TensorType({DT_INT32, DT_FLOAT})) - .DATATYPE(T2, TensorType({DT_INT64, DT_FLOAT})) - .DATATYPE(T3, TensorType({DT_FLOAT, DT_FLOAT16})) - .DATATYPE(T4, Promote({"T1", "T2"})) - .DATATYPE(T5, Promote({"T1", "T2", "T3"})) - .OP_END_FACTORY_REG(Op8); -TEST_F(UTInferDataType, sym_infer_for_dtype_promotion_succeed) { - OpDtypeInfer("Op8") - .Input(DT_INT32) - .Input({DT_INT64, DT_INT64}) - .Input(DT_FLOAT) - .Expect(DT_INT64) // T1和T2间提升为DT_INT64 - .Expect({DT_FLOAT, DT_FLOAT}) // T1,T2和T3间提升为DT_FLOAT - .AssertSucceed(); -} - -// ListTensorType间提升 -REG_OP(Op9) - .DYNAMIC_INPUT(input1, "T1") - .DYNAMIC_INPUT(input2, "T2") - .DYNAMIC_INPUT(input3, "T3") - .DYNAMIC_OUTPUT(output1, "T4") - .DYNAMIC_OUTPUT(output2, "T5") - .DATATYPE(T1, ListTensorType({DT_INT32, DT_FLOAT})) - .DATATYPE(T2, ListTensorType({DT_INT64, DT_FLOAT})) - .DATATYPE(T3, ListTensorType({DT_FLOAT, DT_FLOAT16})) - .DATATYPE(T4, Promote({"T1", "T2"})) - .DATATYPE(T5, Promote({"T1", "T2", "T3"})) - .OP_END_FACTORY_REG(Op9); -TEST_F(UTInferDataType, sym_infer_for_dtype_promotion_list_type_succeed) { - OpDtypeInfer("Op9") - .Input({DT_INT32, DT_FLOAT}) // T1 - .Input({DT_INT64, DT_INT64}) // T2 - .Input({DT_FLOAT, DT_FLOAT16}) // T3 - .Expect({DT_INT64, DT_FLOAT}) // T1和T2间逐个提升 - .Expect({DT_FLOAT, DT_FLOAT}) // T1,T2和T3间逐个提升 - .AssertSucceed(); -} -TEST_F(UTInferDataType, sym_infer_for_dtype_promotion_list_type_dtype_size_mismatch) { - OpDtypeInfer("Op9") // 提升失败,数量不一致 - .Input({DT_INT32, DT_FLOAT, DT_FLOAT}) // T1 - .Input({DT_INT64, DT_INT64}) // T2 - .Input({DT_FLOAT, DT_FLOAT16}) // T3 - .AssertFailed(); -} - -// 试图在无提升规则的单类型间提升 -REG_OP(Op14) - .INPUT(input1, "T1") - .INPUT(input2, "T2") - .OUTPUT(output1, "T3") - .DATATYPE(T1, ListTensorType({DT_INT32, DT_FLOAT})) - .DATATYPE(T2, ListTensorType({DT_INT64, DT_FLOAT})) - .DATATYPE(T3, Promote({"T1", "T2"})) - .OP_END_FACTORY_REG(Op14); -TEST_F(UTInferDataType, sym_infer_for_dtype_promotion_unpromotable_types) { - EXPECT_NO_THROW( - OpDtypeInfer("Op14") // 提升失败,无提升规则 - .Input(DT_VARIANT) // T1 - .Input(DT_RESOURCE) // T2 - .AssertFailed(); - ); -} -// 试图在无提升规则的List类型间提升 -REG_OP(Op15) - .DYNAMIC_INPUT(input1, "T1") - .DYNAMIC_INPUT(input2, "T2") - .OUTPUT(output1, "T3") - .DATATYPE(T1, ListTensorType({DT_FLOAT16, DT_VARIANT})) - .DATATYPE(T2, ListTensorType({DT_FLOAT, DT_RESOURCE})) - .DATATYPE(T3, Promote({"T1", "T2"})) - .OP_END_FACTORY_REG(Op15); -TEST_F(UTInferDataType, sym_infer_for_dtype_promotion_unpromotable_list_types) { - OpDtypeInfer("Op15") // 提升失败,ListType中的某个无提升规则 - .Input({DT_FLOAT16, DT_VARIANT}) // T1 - .Input({DT_FLOAT, DT_RESOURCE}) // T2 - .AssertFailed(); -} - -/* ---------- 异常IR注册校验能力 ---------- */ -REG_OP(Op10) - .INPUT(input1, "T1") - .OUTPUT(output1, "T2") // 异常IR注册,PromoteDtype中只有一个类型 - .DATATYPE(T1, TensorType({DT_FLOAT16, DT_FLOAT})) - .DATATYPE(T2, Promote({"T1"})) - .OP_END_FACTORY_REG(Op10); -TEST_F(UTInferDataType, sym_infer_for_dtype_promotion_3) { - EXPECT_NO_THROW(OpDtypeInfer("Op10").Input(DT_FLOAT16).AssertFailed()); -} - -REG_OP(Op11) - .INPUT(input1, "T1") - .DYNAMIC_INPUT(input2, "T2") - .OUTPUT(output1, "T3") // 异常IR注册,TensorType和ListTensorType间试图提升 - .DATATYPE(T1, TensorType({DT_FLOAT16, DT_FLOAT})) - .DATATYPE(T2, ListTensorType({DT_INT64, DT_FLOAT})) - .DATATYPE(T3, Promote({"T1", "T2"})) - .OP_END_FACTORY_REG(Op11); -TEST_F(UTInferDataType, sym_infer_for_dtype_promotion_4) { - OpDtypeInfer("Op11").Input(DT_FLOAT16).Input({DT_INT64, DT_FLOAT}).AssertFailed(); -} - -REG_OP(Op12) - .OUTPUT(output1, ListTensorType({DT_FLOAT16, DT_FLOAT})) // 输出为老旧方式注册的ListTensorType - .OP_END_FACTORY_REG(Op12); -TEST_F(UTInferDataType, sym_infer_for_legacy_list_type_output) { - EXPECT_NO_THROW(OpDtypeInfer("Op12").AssertFailed()); -} - -/* ---------- 符号方式注册时的类型校验能力 ---------- */ -// 符号不用于任何类型推导 -REG_OP(Op13) - .INPUT(input1, "T1") - .OPTIONAL_INPUT(input2, "T2") - .DYNAMIC_INPUT(input3, "T3") - .DYNAMIC_INPUT(input4, "T4") - .DATATYPE(T1, TensorType({DT_FLOAT16})) - .DATATYPE(T2, TensorType({DT_FLOAT16})) - .DATATYPE(T3, TensorType({DT_FLOAT16, DT_FLOAT})) - .DATATYPE(T4, ListTensorType({DT_FLOAT16, DT_FLOAT})) - .OP_END_FACTORY_REG(Op13); -TEST_F(UTInferDataType, sym_infer_for_type_check_for_unused_sym_succeed) { - OpDtypeInfer("Op13") - .Input(DT_FLOAT16) - .Input(DT_FLOAT16) - .Input({DT_FLOAT16, DT_FLOAT16}) - .Input({DT_FLOAT16, DT_FLOAT}) - .AssertSucceed(); -} -TEST_F(UTInferDataType, sym_infer_for_type_check_for_unused_sym_required_input_dtype_out_of_range) { - OpDtypeInfer("Op13") - .Input(DT_FLOAT) - .Input(DT_FLOAT16) - .Input({DT_FLOAT16, DT_FLOAT16}) - .Input({DT_FLOAT16, DT_FLOAT}) - .AssertFailed(); -} -TEST_F(UTInferDataType, sym_infer_for_type_check_for_unused_sym_opt_input_dtype_out_of_range) { - OpDtypeInfer("Op13") - .Input(DT_FLOAT16) - .Input(DT_FLOAT) - .Input({DT_FLOAT16, DT_FLOAT16}) - .Input({DT_FLOAT16, DT_FLOAT}) - .AssertFailed(); -} -TEST_F(UTInferDataType, sym_infer_for_type_check_for_unused_sym_dyn_input_dtype_out_of_range) { - OpDtypeInfer("Op13") - .Input(DT_FLOAT16) - .Input(DT_FLOAT16) - .Input({DT_FLOAT16, DT_INT32}) - .Input({DT_FLOAT16, DT_FLOAT}) - .AssertFailed(); -} -TEST_F(UTInferDataType, sym_infer_for_type_check_for_unused_sym_dyn_input_dtype_mismatch) { - OpDtypeInfer("Op13") - .Input(DT_FLOAT16) - .Input(DT_FLOAT16) - .Input({DT_FLOAT16, DT_FLOAT}) - .Input({DT_FLOAT16, DT_FLOAT}) - .AssertFailed(); -} -TEST_F(UTInferDataType, sym_infer_for_type_check_for_unused_sym_dyn_list_input_dtype_out_of_range) { - OpDtypeInfer("Op13") - .Input(DT_FLOAT16) - .Input(DT_FLOAT16) - .Input({DT_FLOAT16, DT_FLOAT16}) - .Input({DT_FLOAT16, DT_INT32}) - .AssertFailed(); -} - -/* 测试IR改造后,基于老旧头文件编译的IR能基于CANN中的新IR使用符号化推导的能力 - OpLegacy的IR为老旧IR,不支持类型推导, - CompatOpCurrent模拟当前版本的cann包,其中OpLegacy的IR改造为支持类型推导 - CompatOpFeature模拟未来版本的cann包,其中OpLegacy的IR改造为支持类型推导及新增类型DT_INT32支持 -*/ -namespace { -REG_OP(CompatOpCurrent) - .INPUT(input1, "T") - .OUTPUT(output1, "T") - .DATATYPE(T, TensorType({DT_FLOAT16, DT_FLOAT})) - .OP_END_FACTORY_REG(CompatOpCurrent); - -REG_OP(CompatOpFeature) - .INPUT(input1, "T") - .OUTPUT(output1, "T") - .DATATYPE(T, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) - .OP_END_FACTORY_REG(CompatOpFeature); - -void MockLoadOpsProtoCurrent() { - OperatorFactoryImpl::SetRegisterOverridable(true); - static const OperatorCreatorRegister g_register_compat_feature( - "OpLegacy", [](const AscendString &name) { return op::CompatOpCurrent(name); }); - OperatorFactoryImpl::SetRegisterOverridable(false); -} - -void MockLoadOpsProtoFeature() { - OperatorFactoryImpl::SetRegisterOverridable(true); - static const OperatorCreatorRegister g_register_compat_feature( - "OpLegacy", [](const AscendString &name) { return op::CompatOpFeature(name); }); - OperatorFactoryImpl::SetRegisterOverridable(false); -} -} // namespace - -REG_OP(OpLegacy) - .INPUT(input1, TensorType({DT_FLOAT16, DT_FLOAT})) - .OUTPUT(output1, TensorType({DT_FLOAT16, DT_FLOAT})) - .OP_END_FACTORY_REG(OpLegacy); - -TEST_F(UTInferDataType, sym_infer_for_compat_with_legacy_ir) { - EXPECT_NO_THROW( - // 原始OpLegacy不支持类型推导 - OpDtypeInfer("OpLegacy").Input(DT_FLOAT16).AssertFailed(); - OpDtypeInfer("OpLegacy").Input(DT_INT32).AssertFailed(); - - // 模拟在老的app中加载新的ops proto,其中的IR相较于Legacy支持了类型推导 - MockLoadOpsProtoCurrent(); - - // 验证新创建的OpLegacy能正常类型推导,但是不支持新增类型DT_INT32 - OpDtypeInfer("OpLegacy").Input(DT_FLOAT16).Expect(DT_FLOAT16).AssertSucceed(); - OpDtypeInfer("OpLegacy").Input(DT_INT32).AssertFailed(); // 此时仍不支持DT_INT32 - - // 模拟将来已经支持符号推导编译后,加载未来版本ops proto, IR新增支持类型场景,其中的IR相较于Legacy支持了类型推导及新增类型DT_INT32支持 - MockLoadOpsProtoFeature(); - - // 验证新创建的OpLegacy能正常类型推导,同时支持新增类型DT_INT32 - OpDtypeInfer("OpLegacy").Input(DT_FLOAT16).Expect(DT_FLOAT16).AssertSucceed(); - OpDtypeInfer("OpLegacy").Input(DT_INT32).Expect(DT_INT32).AssertSucceed(); - ); -} - -REG_OP(OpRecover) - .INPUT(input1, "T") - .OUTPUT(output1, "T") - .DATATYPE(T, TensorType({DT_FLOAT16, DT_FLOAT})) - .OP_END_FACTORY_REG(OpRecover); - -TEST_F(UTInferDataType, sym_infer_after_recover) { - EXPECT_NO_THROW( - auto desc = std::make_shared(); - desc->SetType("OpRecover"); - desc->AddInputDesc("input1", GeTensorDesc(GeShape(), FORMAT_ND, DT_FLOAT16)); - desc->AddOutputDesc("output1", GeTensorDesc(GeShape(), FORMAT_ND, DT_FLOAT16)); - desc->AppendIrInput("input1", IrInputType::kIrInputRequired); - desc->AppendIrOutput("output1", IrOutputType::kIrOutputRequired); - - OpDtypeInfer(desc).Input(DT_FLOAT16).AssertFailed(); - - auto graph = std::make_shared("test"); - graph->AddNode(desc); - RecoverIrUtils::RecoverIrDefinitions(graph); - - OpDtypeInfer(desc).Input(DT_FLOAT16).Expect(DT_FLOAT16).AssertSucceed(); - ); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/inference_context_unittest.cc b/tests/ut/graph/testcase/inference_context_unittest.cc deleted file mode 100644 index d7376a8721128604e0f76ad3e29059a70fd8a809..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/inference_context_unittest.cc +++ /dev/null @@ -1,176 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/ge_error_codes.h" -#include "graph/inference_context.h" -#include "graph/resource_context_mgr.h" -#include "graph/node.h" -#include "graph_builder_utils.h" -#include "graph/utils/transformer_utils.h" -#include "external/graph/types.h" - -namespace ge { -namespace { -struct TestResourceContext : ResourceContext { - std::vector shapes; - std::string resource_type; -}; -} -class TestInferenceConext : public testing::Test { - protected: - ComputeGraphPtr graph_; - void SetUp() { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - builder.AddNode("TensorArrayWrite", "TensorArrayWrite", 1, 1); - builder.AddNode("TensorArrayRead", "TensorArrayRead", 1, 1); - graph_ = builder.GetGraph(); - } - - void TearDown() {} -}; - -TEST_F(TestInferenceConext, TestSetAndGetResourceContext) { - ResourceContextMgr resource_context_mgr; - InferenceContextPtr write_inference_context = std::shared_ptr(InferenceContext::Create(&resource_context_mgr)); - InferenceContextPtr read_inference_context = std::shared_ptr(InferenceContext::Create(&resource_context_mgr)); - - // simulate write op - const char* resource_key = "123"; - std::vector resource_shapes = {GeShape({1,1,2,3})}; - TestResourceContext *resource_context = new TestResourceContext(); - resource_context->shapes = resource_shapes; - resource_context->resource_type = "normal"; - // test resource key empty, return fail - auto ret = write_inference_context->SetResourceContext(AscendString(nullptr), resource_context); - ASSERT_EQ(ret, GRAPH_PARAM_INVALID); - - write_inference_context->SetResourceContext(AscendString(resource_key), resource_context); - - // simulate read op - TestResourceContext *test_reousce_context = - dynamic_cast(read_inference_context->GetResourceContext(resource_key)); - - // check result - auto ret_shape = test_reousce_context->shapes.at(0); - auto ret_type = test_reousce_context->resource_type; - ASSERT_EQ(ret_shape.GetDims(), resource_context->shapes.at(0).GetDims()); - ASSERT_EQ(ret_type, resource_context->resource_type); -} - -TEST_F(TestInferenceConext, TestRegisterAndGetReiledOnResource) { - InferenceContextPtr read_inference_context = std::shared_ptr(InferenceContext::Create()); - - // simulate read_op register relied resource - const char* resource_key = "456"; - read_inference_context->RegisterReliedOnResourceKey(AscendString(resource_key)); - - // simulate read_op register empty relied resource - auto ret = read_inference_context->RegisterReliedOnResourceKey(AscendString(nullptr)); - ASSERT_EQ(ret, GRAPH_PARAM_INVALID); - - auto reiled_keys = read_inference_context->GetReliedOnResourceKeys(); - // check result - ASSERT_EQ(reiled_keys.empty(), false); - ASSERT_EQ(*reiled_keys.begin(), resource_key); -} - -TEST_F(TestInferenceConext, TestAddChangeResourceAndGet) { - InferenceContextPtr write_inference_context = std::shared_ptr(InferenceContext::Create()); - - // simulate write node add changed resource - const char* resource_key = "789"; - write_inference_context->AddChangedResourceKey(AscendString(resource_key)); - - // simulate write node add empty changed resource - auto ret = write_inference_context->AddChangedResourceKey(AscendString(nullptr)); - ASSERT_EQ(ret, GRAPH_PARAM_INVALID); - - auto changed_keys = write_inference_context->GetChangedResourceKeys(); - // check result - ASSERT_EQ(changed_keys.empty(), false); - ASSERT_EQ(*(changed_keys.begin()), resource_key); - - // clear changed_key - write_inference_context->ClearChangedResourceKeys(); - changed_keys = write_inference_context->GetChangedResourceKeys(); - // check result - ASSERT_EQ(changed_keys.empty(), true); -} - -TEST_F(TestInferenceConext, transformer_util) { - OpDescPtr op_desc = std::make_shared("tmp", "tmp"); - GeTensorDesc tensor_desc(GeShape(), ge::FORMAT_NCHW, DT_FLOAT16); - tensor_desc.SetShape(GeShape(std::vector{1, 1})); - tensor_desc.SetOriginShape(GeShape(std::vector{1, 1, 1, 1})); - tensor_desc.SetFormat(ge::FORMAT_NCHW); - tensor_desc.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc.SetDataType(DT_FLOAT16); - op_desc->AddInputDesc(tensor_desc); - op_desc->AddOutputDesc(tensor_desc); - - std::unique_ptr transformer(new (std::nothrow) NodeShapeTransUtils(op_desc)); - transformer->Init(); - ASSERT_EQ(transformer->CatchFormatAndShape(), true); - ASSERT_EQ(transformer->UpdateFormatAndShape(), true); -} - -TEST_F(TestInferenceConext, ShapeAndType) { - EXPECT_NO_THROW( - ShapeAndType SAndT; - - Shape shape; - DataType data_type; - - shape = SAndT.GetShape(); - //ASSERT_NE(shape, NULL); - data_type = SAndT.GetDataType(); - //ASSERT_NE(data_type, NULL); - - ShapeAndType SAndT2(shape, data_type); - - SAndT2.SetShape(shape); - SAndT2.SetType(data_type); - ); -} - -TEST_F(TestInferenceConext, SetGetInputHandleShapesAndTypes) { - InferenceContextPtr write_inference_context = std::shared_ptr(InferenceContext::Create()); - - std::vector> input_handle_shapes_and_types; - std::vector> input_handle_shapes_and_types_2; - - write_inference_context->SetInputHandleShapesAndTypes(std::move(input_handle_shapes_and_types)); - input_handle_shapes_and_types_2 = write_inference_context->GetInputHandleShapesAndTypes(); - ASSERT_EQ(input_handle_shapes_and_types_2.empty(), true); -} - -TEST_F(TestInferenceConext, SetGetOutputHandleShapesAndTypes) { - InferenceContextPtr write_inference_context = std::shared_ptr(InferenceContext::Create()); - - std::vector> output_handle_shapes_and_types; - std::vector> output_handle_shapes_and_types_2; - - write_inference_context->SetOutputHandleShapesAndTypes(output_handle_shapes_and_types); - write_inference_context->SetOutputHandleShapesAndTypes(std::move(output_handle_shapes_and_types)); - output_handle_shapes_and_types_2 = write_inference_context->GetOutputHandleShapesAndTypes(); - ASSERT_EQ(output_handle_shapes_and_types_2.empty(), true); -} - -TEST_F(TestInferenceConext, SetGetMarks) { - InferenceContextPtr write_inference_context = std::shared_ptr(InferenceContext::Create()); - - const std::vector marks; - std::vector marks_2; - write_inference_context->SetMarks(marks); - write_inference_context->GetMarks(marks_2); - ASSERT_EQ(marks, marks_2); -} - -} // namespace ge diff --git a/tests/ut/graph/testcase/inference_rule_unittest.cc b/tests/ut/graph/testcase/inference_rule_unittest.cc deleted file mode 100644 index ef49b643525aa65a0c98092cda1ec781ac02657f..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/inference_rule_unittest.cc +++ /dev/null @@ -1,953 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "op_desc_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/operator_reg.h" -#include "graph/debug/ge_log.h" -#include "utils/inference_rule.h" -#include "register/op_impl_space_registry.h" -#include "register/op_impl_registry_base.h" -#include "tests/depends/faker/kernel_run_context_faker.h" - -using Json = nlohmann::json; -using namespace gert; - -namespace ge { -REG_OP(RuleInferOp) - .DYNAMIC_INPUT(x, TensorType::ALL()) - .DYNAMIC_OUTPUT(y, TensorType::ALL()) - .OP_END_FACTORY_REG(RuleInferOp); -} // namespace ge - -namespace { -class CtxMaker { - public: - CtxMaker() : compile_holder(), runtime_holder(), dtypes_holder() { - json["shape"]["inputs"] = Json::array(); - json["shape"]["outputs"] = Json::array(); - json["dtype"] = Json::array(); - } - - CtxMaker &Input(const Json::array_t &input, const std::initializer_list runtime_input) { - json["shape"]["inputs"].push_back(input); - compile_inputs.emplace_back(NewShape()); - runtime_inputs.emplace_back(NewShape(runtime_input)); - auto &compile_input = compile_inputs.back()->MutableOriginShape(); - compile_input.SetDimNum(runtime_input.size()); - for (size_t i = 0; i < runtime_input.size(); ++i) { - const auto &dim = input[i]; - if (dim.is_string()) { - compile_input.SetDim(i, -1); - } else if (dim.is_number_integer()) { - const int64_t dim_value = dim.get(); - compile_input.SetDim(i, dim_value); - } else { - compile_input.SetDim(i, -3); - } - } - return *this; - } - - CtxMaker &ValueInput(const Json::array_t &input, const std::initializer_list runtime_input, - ge::DataType dtype) { - json["shape"]["inputs"].push_back(input); - compile_inputs.emplace_back(NewTensor(runtime_input, dtype)); - runtime_inputs.emplace_back(NewTensor(runtime_input, dtype)); - return *this; - } - - CtxMaker &NullInput() { - json["shape"]["inputs"].push_back(nullptr); - compile_inputs.emplace_back(nullptr); - runtime_inputs.emplace_back(nullptr); - return *this; - } - - CtxMaker &Output(const Json::array_t &output) { - json["shape"]["outputs"].push_back(output); - compile_outputs.emplace_back(NewShape()); - runtime_outputs.emplace_back(NewShape()); - return *this; - } - - CtxMaker &Dtypes(const Json::array_t &dtypes) { - json["dtype"] = dtypes; - output_dtypes.resize(dtypes.size(), ge::DataType::DT_UNDEFINED); - for (auto &output_dtype : output_dtypes) { - ctx_dtypes.emplace_back(&output_dtype); - } - return *this; - } - - std::string Str() const { - return json.dump(); - } - - void Build(bool with_rule = true) { - const auto rule_op = std::make_shared("op"); - rule_op->create_dynamic_input_x(compile_inputs.size()); - rule_op->create_dynamic_output_y(compile_outputs.size()); - for (size_t i = 0; i < compile_inputs.size(); ++i) { - if (compile_inputs[i] == nullptr) { - rule_op->UpdateDynamicInputDesc("x", i, ge::TensorDesc()); - continue; - } - auto &storage_shape = compile_inputs[i]->MutableOriginShape(); - std::vector dims; - dims.reserve(storage_shape.GetDimNum()); - for (size_t j = 0; j < storage_shape.GetDimNum(); ++j) { - dims.push_back(storage_shape.GetDim(j)); - } - rule_op->UpdateDynamicInputDesc("x", i, ge::TensorDesc(ge::Shape(dims), ge::FORMAT_ND, ge::DT_FLOAT16)); - } - desc = ge::OpDescUtils::GetOpDescFromOperator(*rule_op); - if (with_rule) { - ge::AttrUtils::SetStr(desc, "_inference_rule", Str()); - } - op = rule_op; - - std::vector inputs; - std::vector outputs; - inputs.reserve(compile_inputs.size()); - for (auto &input : compile_inputs) { - inputs.emplace_back(input); - } - outputs.reserve(compile_outputs.size()); - for (auto &output : compile_outputs) { - outputs.emplace_back(output); - } - - compile_holder = InferShapeContextFaker() - .IrInputNum(inputs.size()) - .NodeIoNum(inputs.size(), outputs.size()) - .InputShapes(inputs) - .OutputShapes(outputs) - .Build(); - - std::vector rt_inputs; - std::vector rt_outputs; - rt_inputs.reserve(runtime_inputs.size()); - for (auto &input : runtime_inputs) { - rt_inputs.emplace_back(input); - } - rt_outputs.reserve(runtime_outputs.size()); - for (auto &output : runtime_outputs) { - rt_outputs.emplace_back(output); - } - - runtime_holder = InferShapeContextFaker() - .IrInputNum(rt_inputs.size()) - .NodeIoNum(rt_inputs.size(), rt_outputs.size()) - .InputShapes(rt_inputs) - .OutputShapes(rt_outputs) - .Build(); - - dtypes_holder = InferDataTypeContextFaker() - .IrInputNum(rt_inputs.size()) - .NodeIoNum(rt_inputs.size(), rt_outputs.size()) - .OutputDataTypes(ctx_dtypes) - .Build(); - } - - InferShapeContext *CompileCtx() { - return compile_holder.GetContext(); - } - - InferShapeContext *RuntimeCtx() { - return runtime_holder.GetContext(); - } - - InferDataTypeContext *DtypeCtx() { - return dtypes_holder.GetContext(); - } - - ge::OpDescPtr OpDesc() const { - return desc; - } - - ge::Operator &Operator() const { - return *op; - } - - StorageShape *NewShape() { - holders.emplace_back(std::make_shared()); - return holders.back().get(); - } - - StorageShape *NewTensor(const std::initializer_list &runtime_input, ge::DataType dtype) { - values.emplace_back(std::shared_ptr(malloc(sizeof(int64_t) * runtime_input.size()), std::free)); - auto shape = StorageShape({static_cast(runtime_input.size())}, {static_cast(runtime_input.size())}); - tensor_holders.emplace_back(std::make_shared(shape, StorageFormat(), kOnHost, dtype, values.back().get())); - if (dtype == ge::DT_INT32) { - const auto data = tensor_holders.back()->GetData(); - size_t i = 0; - for (const auto dim : runtime_input) { - data[i++] = static_cast(dim); - } - } else if (dtype == ge::DT_INT64) { - const auto data = tensor_holders.back()->GetData(); - size_t i = 0; - for (const auto dim : runtime_input) { - data[i++] = dim; - } - } else if (dtype == ge::DT_UINT32) { - const auto data = tensor_holders.back()->GetData(); - size_t i = 0; - for (const auto dim : runtime_input) { - data[i++] = static_cast(dim); - } - } - return reinterpret_cast(tensor_holders.back().get()); - } - - StorageShape *NewShape(const std::initializer_list &runtime_input) { - holders.emplace_back(std::make_shared(runtime_input, runtime_input)); - return holders.back().get(); - } - - Json json; - std::vector compile_inputs; - std::vector runtime_inputs; - std::vector compile_outputs; - std::vector runtime_outputs; - - std::vector> holders; - FakeKernelContextHolder compile_holder; - FakeKernelContextHolder runtime_holder; - FakeKernelContextHolder dtypes_holder; - - std::vector> values; - std::vector> tensor_holders; - - std::vector ctx_dtypes; - std::vector output_dtypes; - - std::shared_ptr op = nullptr; - ge::OpDescPtr desc = nullptr; -}; -} // namespace - -class InferenceRuleUtest : public testing::Test { - protected: - void SetUp() override { - // construct op impl registry - const auto space_registry = std::make_shared(); - const auto registry_holder = std::make_shared(); - const auto funcs = gert::OpImplRegistry::GetInstance().CreateOrGetOpImpl("RuleInferOp"); - registry_holder->AddTypesToImpl("RuleInferOp", funcs); - space_registry->AddRegistry(registry_holder); - DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - } - - void TearDown() override {} - - static std::string ShapeEqual(Shape *shape, std::initializer_list dims) { - std::stringstream ss; - if (shape == nullptr) { - return "shape == nullptr"; - } - if (shape->GetDimNum() != dims.size()) { - ss << "dim num not equal, expect " << dims.size() << ", got " << shape->GetDimNum(); - return ss.str(); - } - for (size_t i = 0; i < dims.size(); ++i) { - if (shape->GetDim(i) != *(dims.begin() + i)) { - ss << "dim[" << i << "] not equal, expect " << *(dims.begin() + i) << ", got " << shape->GetDim(i); - return ss.str(); - } - } - return ""; - } - - static std::string ShapeEqual(const ge::GeShape &shape, std::initializer_list dims) { - std::stringstream ss; - if (shape.GetDimNum() != dims.size()) { - ss << "dim num not equal, expect " << dims.size() << ", got " << shape.GetDimNum(); - return ss.str(); - } - for (size_t i = 0; i < dims.size(); ++i) { - if (shape.GetDim(i) != *(dims.begin() + i)) { - ss << "dim[" << i << "] not equal, expect " << *(dims.begin() + i) << ", got " << shape.GetDim(i); - return ss.str(); - } - } - return ""; - } -}; - -TEST_F(InferenceRuleUtest, BasicDimSymbol) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0"}).Build(); - - const auto handle = ge::ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {32}), ""); -} - -TEST_F(InferenceRuleUtest, MultiDimSymbol) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0", "s1"}, {32, 64}).Output({"s1", "s0"}).Build(); - - const auto handle = ge::ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1, -1}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {64, 32}), ""); -} - -TEST_F(InferenceRuleUtest, DimSymbolWithFunctionVertical) { - CtxMaker ctx_maker; - int64_t s0 = 32; - int64_t s1 = 64; - // "+", "-", "*", "Div", "Floor", "Ceil", "Pow", "Mod" - ctx_maker.Input({"s0", "s1"}, {s0, s1}) - .Output({"s1+s0"}) - .Output({"s1-s0"}) - .Output({"s1*s0"}) - .Output({"Div(s1,s0)"}) - .Output({"Floor(Div(s1,3))"}) - .Output({"Ceil(Div(s1,3))"}) - .Output({"Pow(s0,2)"}) - .Output({"Mod(s1,7)"}) - .Build(); - - const auto handle = ge::ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1}), ""); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(1), {-1}), ""); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(2), {-1}), ""); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(3), {-1}), ""); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(4), {-1}), ""); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(5), {-1}), ""); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(6), {-1}), ""); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(7), {-1}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {s1 + s0}), ""); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(1), {s1 - s0}), ""); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(2), {s1 * s0}), ""); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(3), {s1 / s0}), ""); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(4), {s1 / 3}), ""); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(5), {(s1 + 2) / 3}), ""); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(6), {s0 * s0}), ""); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(7), {s1 % 7}), ""); -} - -TEST_F(InferenceRuleUtest, DimSymbolWithFunctionHorizontal) { - CtxMaker ctx_maker; - int64_t s0 = 32; - int64_t s1 = 64; - // "+", "-", "*", "Div", "Floor", "Ceil", "Pow", "Mod" - ctx_maker.Input({"s0", "s1"}, {s0, s1}) - .Output( - {"s1+s0", "s1-s0", "s1*s0", "Div(s1,s0)", "Floor(Div(s1,3))", "Ceil(Div(s1,3))", "Pow(s0,2)", "Mod(s1,7)"}) - .Build(); - - const auto handle = ge::ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1, -1, -1, -1, -1, -1, -1, -1}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), - {s1 + s0, s1 - s0, s1 * s0, s1 / s0, s1 / 3, (s1 + 2) / 3, s0 * s0, s1 % 7}), - ""); -} - -TEST_F(InferenceRuleUtest, StaticDimSymbol) { - CtxMaker ctx_maker; - ctx_maker.Output({"128", "32+24"}).Build(); - - const auto handle = ge::ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {128, 56}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {128, 56}), ""); -} - -TEST_F(InferenceRuleUtest, NullDimSymbol) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0", nullptr, "s1"}, {32, 20, 24}).Output({"s0+s1"}).Build(); - - const auto handle = ge::ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {56}), ""); -} - -TEST_F(InferenceRuleUtest, RepeatDimSymbol) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0", "s0"}, {32, 32}).Input({"s1"}, {24}).Input({"s1"}, {24}).Output({"s0+s1"}).Build(); - - const auto handle = ge::ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {56}), ""); -} - -TEST_F(InferenceRuleUtest, SymbolMixStrAndIntAndNull) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0", 128, "s1", nullptr, "s3", "24"}, {4, 128, 8, 0, 16, 24}) - .Output({"s1", "128", 32, "128+32"}) - .Build(); - - const auto handle = ge::ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1, 128, 32, 160}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {8, 128, 32, 160}), ""); -} - -TEST_F(InferenceRuleUtest, SymbolWithNullInput) { - CtxMaker ctx_maker; - ctx_maker.NullInput().Input({"s0"}, {32}).Output({"s0"}).Build(); - - const auto handle = ge::ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {32}), ""); -} - -TEST_F(InferenceRuleUtest, ValueSymbolBasic) { - CtxMaker ctx_maker; - ctx_maker.ValueInput({"v0"}, {32}, ge::DT_INT32).Output({"v0+3"}).Build(); - - const auto handle = ge::ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {35}), ""); -} - -TEST_F(InferenceRuleUtest, ValueSymbolMultiDtype) { - CtxMaker ctx_maker; - ctx_maker.ValueInput({"v0"}, {32}, ge::DT_INT32) - .ValueInput({"v1"}, {24}, ge::DT_UINT32) - .ValueInput({"v2"}, {8}, ge::DT_INT64) - .Output({"v0+v1+v2"}) - .Build(); - - const auto handle = ge::ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {32 + 24 + 8}), ""); -} - -TEST_F(InferenceRuleUtest, MultiValueSymbol) { - CtxMaker ctx_maker; - ctx_maker.ValueInput({"v0", "v2", "v1"}, {32, 2, 6}, ge::DT_INT32).Output({"v0+v1+v2"}).Build(); - - const auto handle = ge::ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {32 + 2 + 6}), ""); -} - -TEST_F(InferenceRuleUtest, ValueSymbolMixNull) { - CtxMaker ctx_maker; - ctx_maker.ValueInput({"v0", nullptr, "v1"}, {32, 2, 6}, ge::DT_INT32).Output({"v0+v1"}).Build(); - - const auto handle = ge::ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {32 + 6}), ""); -} - -TEST_F(InferenceRuleUtest, ValueSymbolMixDimSymbol) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0", "s1"}, {3, 4}) - .ValueInput({"v0", nullptr, "v1"}, {32, 2, 6}, ge::DT_INT32) - .Output({"v0+s0", "v1+s1"}) - .Build(); - - const auto handle = ge::ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1, -1}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {32 + 3, 6 + 4}), ""); -} - -TEST_F(InferenceRuleUtest, CompileAndLoadSucceed) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0"}).Build(); - - std::vector binary; - ASSERT_EQ(ge::ShapeInferenceRule::CompileJsonString(ctx_maker.Str(), binary), ge::GRAPH_SUCCESS); - const auto handle = ge::ShapeInferenceRule::FromCompiledBinary(binary); - ASSERT_EQ(handle.Error(), ""); - - const auto compile_ctx = ctx_maker.CompileCtx(); - ASSERT_EQ(handle.InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(compile_ctx->GetOutputShape(0), {-1}), ""); - - const auto runtime_ctx = ctx_maker.RuntimeCtx(); - ASSERT_EQ(handle.InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(runtime_ctx->GetOutputShape(0), {32}), ""); -} - -TEST_F(InferenceRuleUtest, OutputWithUndefinedSymbol) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s1"}).Build(); - - const auto handle = ge::ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Error parsing output tensors: Symbol 's1' used in output but not defined in inputs"); -} - -TEST_F(InferenceRuleUtest, InputIsNotRawSymbol) { - CtxMaker ctx_maker; - ctx_maker.Input({"t0"}, {32}).Output({"t1"}).Build(); - - const auto handle = ge::ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), - "Error parsing input symbols: Invalid input[0].size(0): t0, symbol dimension must start with 's' or 'v' " - "and follow with a number"); -} - -TEST_F(InferenceRuleUtest, InputIsNotSymbol) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0+2"}, {32}).Output({"s0+2"}).Build(); - - const auto handle = ge::ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), - "Error parsing input symbols: Invalid input[0].size(0): s0+2, symbol dimension must start with 's' or 'v' " - "and follow with a number"); -} - -TEST_F(InferenceRuleUtest, NoShapeFiled) { - const auto handle = ge::ShapeInferenceRule::FromJsonString("{}"); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Missing 'shape' field in rule json."); -} - -TEST_F(InferenceRuleUtest, InputsFormatError) { - { - Json json; - json["shape"]["inputs"] = 3; - const auto handle = ge::ShapeInferenceRule::FromJsonString(json.dump()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Invalid 'shape.inputs' field: 3 field must be an array or null."); - } - - { - Json json; - json["shape"]["inputs"] = {3}; - const auto handle = ge::ShapeInferenceRule::FromJsonString(json.dump()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Invalid 'shape.inputs' field: [3] element must be an array of dimension expressions."); - } - - { - Json json; - json["shape"]["inputs"] = {{2.5}}; - const auto handle = ge::ShapeInferenceRule::FromJsonString(json.dump()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), - "Invalid 'shape.inputs' field: [[2.5]] dimension expression must be a string or integer."); - } -} - -TEST_F(InferenceRuleUtest, OutputsFormatError) { - { - Json json; - json["shape"]["outputs"] = 3; - const auto handle = ge::ShapeInferenceRule::FromJsonString(json.dump()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Invalid 'shape.outputs' field: 3 field must be an array or null."); - } - - { - Json json; - json["shape"]["outputs"] = {3}; - const auto handle = ge::ShapeInferenceRule::FromJsonString(json.dump()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Invalid 'shape.outputs' field: [3] element must be an array of dimension expressions."); - } - - { - Json json; - json["shape"]["outputs"] = {{2.5}}; - const auto handle = ge::ShapeInferenceRule::FromJsonString(json.dump()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), - "Invalid 'shape.outputs' field: [[2.5]] dimension expression must be a string or integer."); - } - - { - Json json; - json["shape"]["outputs"] = {{nullptr}}; - const auto handle = ge::ShapeInferenceRule::FromJsonString(json.dump()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Error parsing output tensors: Invalid output[0].size(0): empty dimension"); - } - - { - Json json; - json["shape"]["outputs"] = {{""}}; - const auto handle = ge::ShapeInferenceRule::FromJsonString(json.dump()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Error parsing output tensors: Invalid output[0].size(0): empty dimension"); - } -} - -TEST_F(InferenceRuleUtest, UnsupportedFunction) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"Abc(s0)"}).Build(); - - const auto handle = ge::ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), - "Error parsing output tensors: Invalid dim expr 'Abc(s0)': Invalid function: Abc, supported [Div, Floor, " - "Ceil, Pow, Mod]"); -} - -TEST_F(InferenceRuleUtest, UnsupportedOperator) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0 / 3"}).Build(); - - const auto handle = ge::ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), - "Error parsing output tensors: Invalid dim expr 's0 / 3': Expression contains invalid characters"); -} - -TEST_F(InferenceRuleUtest, IllegalExpression_UnmatchedRightParenthesis) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0)"}).Build(); - - const auto handle = ge::ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Error parsing output tensors: Invalid dim expr 's0)': Unmatched ')'"); -} - -TEST_F(InferenceRuleUtest, IllegalExpression_UnmatchedLeftParenthesis) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"(s0"}).Build(); - - const auto handle = ge::ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Error parsing output tensors: Invalid dim expr '(s0': Unmatched '('"); -} - -TEST_F(InferenceRuleUtest, IllegalExpression_InvalidSymbol) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"2s0)"}).Build(); - - const auto handle = ge::ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), - "Error parsing output tensors: Invalid dim expr '2s0)': Invalid identifier: '2s0', expected start with 's' " - "or 'v' and follow with a number"); -} - -TEST_F(InferenceRuleUtest, IllegalExpression_SyntaxError) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0 ++ 2"}).Build(); - - const auto handle = ge::ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), - "Failed to compile C++ code to shared object:\nextern \"C\" {bool infer_shape(Ctx *ctx) {\n " - "GET_SYMBOL_DIM(s0, 0, 0);\n\n SET_OUTPUT_RANK(0, 1);\n SET_OUTPUT_DIM(0, 0, static_cast(s0 " - "++ 2));\n\n return true;\n}\nbool infer_shape_on_compile(Ctx *ctx) {\n SET_OUTPUT_RANK(0, 1);\n " - "SET_OUTPUT_DIM(0, 0, -1);\n\n return true;\n}}\nError: syntax error"); -} - -TEST_F(InferenceRuleUtest, BasicDtypeInfer) { - CtxMaker ctx_maker; - ctx_maker.Output({128}).Dtypes({ge::DataType::DT_BF16}).Build(); - - const auto handle = ge::DtypeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - const auto dtype_ctx = ctx_maker.DtypeCtx(); - - ASSERT_EQ(handle->InferDtype(dtype_ctx), ge::GRAPH_SUCCESS); - ASSERT_EQ(dtype_ctx->GetOutputDataType(0), ge::DataType::DT_BF16); -} - -TEST_F(InferenceRuleUtest, InvalidDtype1) { - CtxMaker ctx_maker; - ctx_maker.Output({128}).Dtypes({ge::DataType::DT_UNDEFINED}).Build(); - - const auto handle = ge::DtypeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), - "Element 28 in 'dtype' field is out of range [0,42(DT_MAX)) and cannot be 28(DT_UNDEFINED)."); -} - -TEST_F(InferenceRuleUtest, InvalidDtype2) { - CtxMaker ctx_maker; - ctx_maker.Output({128}).Dtypes({ge::DataType::DT_MAX}).Build(); - - const auto handle = ge::DtypeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), - "Element 42 in 'dtype' field is out of range [0,42(DT_MAX)) and cannot be 28(DT_UNDEFINED)."); -} - -TEST_F(InferenceRuleUtest, InvalidDtype3) { - CtxMaker ctx_maker; - ctx_maker.Output({128}).Dtypes({-1}).Build(); - - const auto handle = ge::DtypeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), - "Element -1 in 'dtype' field is out of range [0,42(DT_MAX)) and cannot be 28(DT_UNDEFINED)."); -} - -TEST_F(InferenceRuleUtest, DtypesFormatError) { - { - Json json; - const auto handle = ge::DtypeInferenceRule::FromJsonString(json.dump()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Missing 'dtype' field in rule json."); - } - - { - Json json; - json["dtype"] = 3; - const auto handle = ge::DtypeInferenceRule::FromJsonString(json.dump()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Field 'dtype' must be an array."); - } - - { - Json json; - json["dtype"] = {nullptr}; - const auto handle = ge::DtypeInferenceRule::FromJsonString(json.dump()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Element in 'dtype' field must not be null."); - } - - { - Json json; - json["dtype"] = {2.5}; - const auto handle = ge::DtypeInferenceRule::FromJsonString(json.dump()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Element in 'dtype' field must be an integer."); - } - - { - Json json; - json["dtype"] = nullptr; - const auto handle = ge::DtypeInferenceRule::FromJsonString(json.dump()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), "Filed 'dtype' must not be null."); - } -} - -TEST_F(InferenceRuleUtest, JsonFormatError) { - Json json; - const auto handle = ge::DtypeInferenceRule::FromJsonString("{"); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), - "Error parsing json: [json.exception.parse_error.101] parse error at line 1, column 2: syntax error while " - "parsing object key - unexpected end of input; expected string literal"); -} - -TEST_F(InferenceRuleUtest, CalledByInvalidDimCtx) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0"}).Build(); - - const auto handle = ge::ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - { - CtxMaker ctx_bug; - ctx_bug.Build(); - - const auto compile_ctx = ctx_bug.CompileCtx(); - ASSERT_NE(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - - const auto runtime_ctx = ctx_bug.RuntimeCtx(); - ASSERT_NE(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - } - - { - CtxMaker ctx_bug; - ctx_bug.Input({"s0"}, {32}).Build(); - - const auto compile_ctx = ctx_bug.CompileCtx(); - ASSERT_NE(handle->InferOnCompile(compile_ctx), ge::GRAPH_SUCCESS); - - const auto runtime_ctx = ctx_bug.RuntimeCtx(); - ASSERT_NE(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - } -} - -TEST_F(InferenceRuleUtest, CalledByInvalidValueCtx) { - CtxMaker ctx_maker; - ctx_maker.ValueInput({"v0", "v1"}, {32, 24}, ge::DT_INT32).Output({"v1"}).Build(); - - const auto handle = ge::ShapeInferenceRule::FromJsonString(ctx_maker.Str()); - ASSERT_NE(handle, nullptr); - ASSERT_EQ(handle->Error(), ""); - - { - CtxMaker ctx_bug; - ctx_bug.Build(); - const auto runtime_ctx = ctx_bug.RuntimeCtx(); - ASSERT_NE(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - } - - { - CtxMaker ctx_bug; - ctx_bug.ValueInput({"v0"}, {32}, ge::DT_INT32).Output({"v0"}).Build(); - const auto runtime_ctx = ctx_bug.RuntimeCtx(); - ASSERT_NE(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - } - - { - CtxMaker ctx_bug; - ctx_bug.ValueInput({"v0, v1"}, {32, 24}, ge::DT_INT16).Output({"v1"}).Build(); - const auto runtime_ctx = ctx_bug.RuntimeCtx(); - ASSERT_NE(handle->InferOnRuntime(runtime_ctx), ge::GRAPH_SUCCESS); - } -} - -TEST_F(InferenceRuleUtest, CompileInvalidJsonStrOrCode) { - std::vector binary; - ASSERT_NE(ge::ShapeInferenceRule::CompileJsonString("{", binary), ge::GRAPH_SUCCESS); - - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0 ++ 2"}).Build(); - ASSERT_NE(ge::ShapeInferenceRule::CompileJsonString(ctx_maker.Str(), binary), ge::GRAPH_SUCCESS); -} - -TEST_F(InferenceRuleUtest, CallInvalidRule) { - { - const auto rule = ge::ShapeInferenceRule::FromJsonString("{"); - - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0"}).Build(); - ASSERT_NE(rule->InferOnCompile(ctx_maker.CompileCtx()), ge::GRAPH_SUCCESS); - ASSERT_NE(rule->InferOnRuntime(ctx_maker.RuntimeCtx()), ge::GRAPH_SUCCESS); - } - - { - const auto rule = ge::DtypeInferenceRule::FromJsonString("{"); - - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0"}).Build(); - ASSERT_NE(rule->InferDtype(ctx_maker.DtypeCtx()), ge::GRAPH_SUCCESS); - } -} - -TEST_F(InferenceRuleUtest, JustForCoverage) { - auto handle = ge::ShapeInferenceRule::FromCompiledBinary({}); - ASSERT_NE(handle.Error(), ""); - - ASSERT_TRUE(ge::ShapeInferenceRule::GetInferenceRule(nullptr).empty()); -} \ No newline at end of file diff --git a/tests/ut/graph/testcase/ir_definitions_recover_test.cc b/tests/ut/graph/testcase/ir_definitions_recover_test.cc deleted file mode 100644 index 2ffbcbaa3ee1ea1e6f1ffaad986aeeb050d411e0..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/ir_definitions_recover_test.cc +++ /dev/null @@ -1,344 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include "graph/op_desc.h" -#include "graph/normal_graph/op_desc_impl.h" -#include "graph/ir_definitions_recover.h" -#include "graph/utils/recover_ir_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "external/graph/operator_reg.h" -#include "slog.h" - -using namespace ge; - -namespace gert { -class IrDefinitionsRecoverUT : public testing::Test {}; - -REG_OP(MatMulUt) - .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .ATTR(transpose_x1, Bool, false) - .ATTR(transpose_x2, Bool, false) - .REQUIRED_ATTR(loss_attr, Bool) - .OP_END_FACTORY_REG(MatMulUt) - -REG_OP(ConcatV2DUt) - .DYNAMIC_INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .DYNAMIC_OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .REQUIRED_ATTR(concat_dim, Int) - .ATTR(N, Int, 1) - .OP_END_FACTORY_REG(ConcatV2DUt) - -REG_OP(BNInferenceDUt) - .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF_16})) - .INPUT(mean, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF_16})) - .INPUT(variance, TensorType({DT_FLOAT, DT_FLOAT16, DT_BF_16})) - .OPTIONAL_INPUT(scale, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .OPTIONAL_INPUT(b, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .ATTR(momentum, Float, 0.9f) - .ATTR(epsilon, Float, 1e-5f) - .ATTR(use_global_stats, Bool, true) - .ATTR(mode, Int, 1) - .OP_END_FACTORY_REG(BNInferenceDUt) - -TEST_F(IrDefinitionsRecoverUT, RecoverIrDefinitions_ir_inputs_not_match_failed) { - auto op_desc = std::make_shared("matmul", "MatMulUt"); - ASSERT_NE(op_desc, nullptr); - auto computeGraph = std::make_shared("graph_name"); - ASSERT_NE(computeGraph, nullptr); - ASSERT_NE(computeGraph->AddNode(op_desc), nullptr); - - auto op = ge::OperatorFactory::CreateOperator("MatMulUt", "MatMulUt"); - auto op_desc_origin = ge::OpDescUtils::GetOpDescFromOperator(op); - - op_desc->impl_->meta_data_.ir_meta_.ir_attr_names_ = op_desc_origin->GetIrAttrNames(); - op_desc->impl_->meta_data_.ir_meta_.ir_inputs_.ir_inputs = op_desc_origin->GetIrInputs(); - ASSERT_FALSE(op_desc->impl_->meta_data_.ir_meta_.ir_inputs_.ir_inputs.empty()); - ASSERT_FALSE(op_desc->impl_->meta_data_.ir_meta_.ir_attr_names_.empty()); - op_desc->impl_->meta_data_.ir_meta_.ir_inputs_.ir_inputs[0].first = "fake"; - auto ret = RecoverIrUtils::RecoverIrDefinitions(computeGraph); - EXPECT_NE(ret, ge::GRAPH_SUCCESS); - EXPECT_TRUE(op_desc->impl_->meta_data_.ir_meta_.ir_inputs_.ir_inputs[0].first == "fake"); -} - -TEST_F(IrDefinitionsRecoverUT, RecoverIrDefinitions_ir_inputs_num_check_failed) { - auto op_desc = std::make_shared("matmul", "MatMulUt"); - ASSERT_NE(op_desc, nullptr); - auto computeGraph = std::make_shared("graph_name"); - ASSERT_NE(computeGraph, nullptr); - ASSERT_NE(computeGraph->AddNode(op_desc), nullptr); - - auto op = ge::OperatorFactory::CreateOperator("MatMulUt", "MatMulUt"); - auto op_desc_origin = ge::OpDescUtils::GetOpDescFromOperator(op); - op_desc->impl_->meta_data_.ir_meta_.ir_inputs_.ir_inputs.emplace_back(std::pair("fake", kIrInputRequired)); - auto ret = RecoverIrUtils::RecoverIrDefinitions(computeGraph); - EXPECT_NE(ret, ge::GRAPH_SUCCESS); - EXPECT_EQ(op_desc->impl_->meta_data_.ir_meta_.ir_inputs_.ir_inputs[0].first, "fake"); -} - -TEST_F(IrDefinitionsRecoverUT, RecoverIrDefinitions_ir_attr_name_not_match_failed) { - auto op_desc = std::make_shared("matmul", "MatMulUt"); - ASSERT_NE(op_desc, nullptr); - auto computeGraph = std::make_shared("graph_name"); - ASSERT_NE(computeGraph, nullptr); - ASSERT_NE(computeGraph->AddNode(op_desc), nullptr); - - auto op = ge::OperatorFactory::CreateOperator("MatMulUt", "MatMulUt"); - auto op_desc_origin = ge::OpDescUtils::GetOpDescFromOperator(op); - - op_desc->impl_->meta_data_.ir_meta_.ir_attr_names_.emplace_back("fake"); - auto ret = RecoverIrUtils::RecoverIrDefinitions(computeGraph); - EXPECT_NE(ret, ge::GRAPH_SUCCESS); - EXPECT_EQ(op_desc->impl_->meta_data_.ir_meta_.ir_attr_names_[0], "fake"); -} - -TEST_F(IrDefinitionsRecoverUT, RecoverIrDefinitions_ir_attr_name_num_check_failed) { - auto op_desc = std::make_shared("matmul", "MatMulUt"); - ASSERT_NE(op_desc, nullptr); - auto computeGraph = std::make_shared("graph_name"); - ASSERT_NE(computeGraph, nullptr); - ASSERT_NE(computeGraph->AddNode(op_desc), nullptr); - - auto op = ge::OperatorFactory::CreateOperator("MatMulUt", "MatMulUt"); - auto op_desc_origin = ge::OpDescUtils::GetOpDescFromOperator(op); - - op_desc->impl_->meta_data_.ir_meta_.ir_attr_names_.emplace_back("fake"); - auto ret = RecoverIrUtils::RecoverIrDefinitions(computeGraph); - EXPECT_NE(ret, ge::GRAPH_SUCCESS); - EXPECT_EQ(op_desc->impl_->meta_data_.ir_meta_.ir_attr_names_.back(), "fake"); -} - -TEST_F(IrDefinitionsRecoverUT, RecoverIrDefinitions_empty_success) { - auto op_desc = std::make_shared("matmul", "MatMulUt"); - ASSERT_NE(op_desc, nullptr); - auto computeGraph = std::make_shared("graph_name"); - ASSERT_NE(computeGraph, nullptr); - ASSERT_NE(computeGraph->AddNode(op_desc), nullptr); - - auto op = ge::OperatorFactory::CreateOperator("MatMulUt", "MatMulUt"); - auto op_desc_origin = ge::OpDescUtils::GetOpDescFromOperator(op); - - // recover success - auto ret = RecoverIrUtils::RecoverIrDefinitions(computeGraph); - EXPECT_EQ(ret, ge::GRAPH_SUCCESS); - EXPECT_EQ(op_desc->GetIrAttrNames().size(), op_desc_origin->GetIrAttrNames().size()); - EXPECT_EQ(op_desc->GetIrInputs().size(), op_desc_origin->GetIrInputs().size()); - EXPECT_EQ(op_desc->GetIrOutputs().size(), op_desc_origin->GetIrOutputs().size()); -} - -TEST_F(IrDefinitionsRecoverUT, RecoverIrDefinitions_partial_success) { - auto op_desc = std::make_shared("matmul", "MatMulUt"); - ASSERT_NE(op_desc, nullptr); - auto computeGraph = std::make_shared("graph_name"); - ASSERT_NE(computeGraph, nullptr); - ASSERT_NE(computeGraph->AddNode(op_desc), nullptr); - - auto op = ge::OperatorFactory::CreateOperator("MatMulUt", "MatMulUt"); - auto op_desc_origin = ge::OpDescUtils::GetOpDescFromOperator(op); - - op_desc->AppendIrAttrName(op_desc_origin->GetIrAttrNames().at(0)); - auto &pair = op_desc_origin->GetIrInputs().at(0); - op_desc->AppendIrInput(pair.first, pair.second); - - // recover success - auto ret = RecoverIrUtils::RecoverIrDefinitions(computeGraph); - EXPECT_EQ(ret, ge::GRAPH_SUCCESS); - EXPECT_EQ(op_desc->GetIrAttrNames().size(), op_desc_origin->GetIrAttrNames().size()); - EXPECT_EQ(op_desc->GetIrInputs().size(), op_desc_origin->GetIrInputs().size()); - EXPECT_EQ(op_desc->GetIrOutputs().size(), op_desc_origin->GetIrOutputs().size()); -} - -TEST_F(IrDefinitionsRecoverUT, RecoverIrDefinitions_same_success) { - auto op_desc = std::make_shared("matmul", "MatMulUt"); - ASSERT_NE(op_desc, nullptr); - auto computeGraph = std::make_shared("graph_name"); - ASSERT_NE(computeGraph, nullptr); - ASSERT_NE(computeGraph->AddNode(op_desc), nullptr); - - auto op = ge::OperatorFactory::CreateOperator("MatMulUt", "MatMulUt"); - auto op_desc_origin = ge::OpDescUtils::GetOpDescFromOperator(op); - - for (const auto &attr : op_desc_origin->GetIrAttrNames()) { - op_desc->AppendIrAttrName(attr); - } - for (const auto &pair : op_desc_origin->GetIrInputs()) { - op_desc->AppendIrInput(pair.first, pair.second); - } - // recover success - auto ret = RecoverIrUtils::RecoverIrDefinitions(computeGraph); - EXPECT_EQ(ret, ge::GRAPH_SUCCESS); - EXPECT_EQ(op_desc->GetIrAttrNames().size(), op_desc_origin->GetIrAttrNames().size()); - EXPECT_EQ(op_desc->GetIrInputs().size(), op_desc_origin->GetIrInputs().size()); - EXPECT_EQ(op_desc->GetIrOutputs().size(), op_desc_origin->GetIrOutputs().size()); -} - -TEST_F(IrDefinitionsRecoverUT, RecoverIrDefinitions_frameworkop_success) { - auto op_desc = std::make_shared("matmul", "FrameworkOp"); - AttrUtils::SetStr(op_desc, "original_type", "MatMulUt"); - ASSERT_NE(op_desc, nullptr); - auto computeGraph = std::make_shared("graph_name"); - ASSERT_NE(computeGraph, nullptr); - ASSERT_NE(computeGraph->AddNode(op_desc), nullptr); - - auto op = ge::OperatorFactory::CreateOperator("MatMul", "MatMulUt"); - auto op_desc_origin = ge::OpDescUtils::GetOpDescFromOperator(op); - - // recover success - auto ret = RecoverIrUtils::RecoverIrDefinitions(computeGraph); - EXPECT_EQ(ret, ge::GRAPH_SUCCESS); - EXPECT_EQ(op_desc->GetIrAttrNames().size(), op_desc_origin->GetIrAttrNames().size()); - EXPECT_EQ(op_desc->GetIrInputs().size(), op_desc_origin->GetIrInputs().size()); - EXPECT_EQ(op_desc->GetIrOutputs().size(), op_desc_origin->GetIrOutputs().size()); - -} - -TEST_F(IrDefinitionsRecoverUT, RecoverIrDefinitions_op_loss_not_has_default_value) { - auto op_desc = std::make_shared("matmul", "MatMulUt"); - ASSERT_NE(op_desc, nullptr); - auto computeGraph = std::make_shared("graph_name"); - ASSERT_NE(computeGraph, nullptr); - ASSERT_NE(computeGraph->AddNode(op_desc), nullptr); - - auto op = ge::OperatorFactory::CreateOperator("MatMulUt", "MatMulUt"); - auto op_desc_origin = ge::OpDescUtils::GetOpDescFromOperator(op); - - // recover success - auto ret = RecoverIrUtils::RecoverIrDefinitions(computeGraph); - EXPECT_EQ(ret, ge::GRAPH_SUCCESS); - EXPECT_FALSE(ge::AttrUtils::HasAttr(op_desc, "loss_attr")); - EXPECT_TRUE(ge::AttrUtils::HasAttr(op_desc, "transpose_x1")); -} - -TEST_F(IrDefinitionsRecoverUT, RecoverIrDefinitions_ir_outputs_not_match_failed) { - auto op_desc = std::make_shared("matmul", "MatMulUt"); - ASSERT_NE(op_desc, nullptr); - auto computeGraph = std::make_shared("graph_name"); - ASSERT_NE(computeGraph, nullptr); - ASSERT_NE(computeGraph->AddNode(op_desc), nullptr); - - auto op = ge::OperatorFactory::CreateOperator("MatMulUt", "MatMulUt"); - auto op_desc_origin = ge::OpDescUtils::GetOpDescFromOperator(op); - - op_desc->impl_->meta_data_.ir_meta_.ir_attr_names_ = op_desc_origin->GetIrAttrNames(); - op_desc->impl_->meta_data_.ir_meta_.ir_outputs_.ir_outputs = op_desc_origin->GetIrOutputs(); - ASSERT_FALSE(op_desc->impl_->meta_data_.ir_meta_.ir_outputs_.ir_outputs.empty()); - ASSERT_FALSE(op_desc->impl_->meta_data_.ir_meta_.ir_attr_names_.empty()); - op_desc->impl_->meta_data_.ir_meta_.ir_outputs_.ir_outputs[0].first = "fake"; - auto ret = RecoverIrUtils::RecoverIrDefinitions(computeGraph); - EXPECT_NE(ret, ge::GRAPH_SUCCESS); - EXPECT_TRUE(op_desc->impl_->meta_data_.ir_meta_.ir_outputs_.ir_outputs[0].first == "fake"); -} - -TEST_F(IrDefinitionsRecoverUT, RecoverIrDefinitions_ir_outputs_num_check_failed) { - auto op_desc = std::make_shared("matmul", "MatMulUt"); - ASSERT_NE(op_desc, nullptr); - auto computeGraph = std::make_shared("graph_name"); - ASSERT_NE(computeGraph, nullptr); - ASSERT_NE(computeGraph->AddNode(op_desc), nullptr); - - auto op = ge::OperatorFactory::CreateOperator("MatMulUt", "MatMulUt"); - auto op_desc_origin = ge::OpDescUtils::GetOpDescFromOperator(op); - op_desc->impl_->meta_data_.ir_meta_.ir_outputs_.ir_outputs.emplace_back(std::pair("fake", kIrOutputRequired)); - auto ret = RecoverIrUtils::RecoverIrDefinitions(computeGraph); - EXPECT_NE(ret, ge::GRAPH_SUCCESS); - EXPECT_EQ(op_desc->impl_->meta_data_.ir_meta_.ir_outputs_.ir_outputs[0].first, "fake"); -} - -// TODO if all depended is replace, this 2 function will be deleted -TEST_F(IrDefinitionsRecoverUT, RecoverIrDefinitions_wrapper_empty_success) { - auto op_desc = std::make_shared("matmul", "MatMulUt"); - ASSERT_NE(op_desc, nullptr); - auto computeGraph = std::make_shared("graph_name"); - ASSERT_NE(computeGraph, nullptr); - ASSERT_NE(computeGraph->AddNode(op_desc), nullptr); - - auto op = ge::OperatorFactory::CreateOperator("MatMulUt", "MatMulUt"); - auto op_desc_origin = ge::OpDescUtils::GetOpDescFromOperator(op); - - // recover success - auto ret = RecoverIrDefinitions(computeGraph); - EXPECT_EQ(ret, ge::GRAPH_SUCCESS); - EXPECT_EQ(op_desc->GetIrAttrNames().size(), op_desc_origin->GetIrAttrNames().size()); - EXPECT_EQ(op_desc->GetIrInputs().size(), op_desc_origin->GetIrInputs().size()); - EXPECT_EQ(op_desc->GetIrOutputs().size(), op_desc_origin->GetIrOutputs().size()); -} - -// TODO if all depended is replace, this 2 function will be deleted -TEST_F(IrDefinitionsRecoverUT, RecoverOpDescIrDefinition_wrapper_empty_success) { - auto op_desc = std::make_shared("matmul", "MatMulUt"); - ASSERT_NE(op_desc, nullptr); - auto computeGraph = std::make_shared("graph_name"); - ASSERT_NE(computeGraph, nullptr); - ASSERT_NE(computeGraph->AddNode(op_desc), nullptr); - - auto op = ge::OperatorFactory::CreateOperator("MatMulUt", "MatMulUt"); - auto op_desc_origin = ge::OpDescUtils::GetOpDescFromOperator(op); - - // recover success - auto ret = RecoverOpDescIrDefinition(op_desc); - EXPECT_EQ(ret, ge::GRAPH_SUCCESS); - EXPECT_EQ(op_desc->GetIrAttrNames().size(), op_desc_origin->GetIrAttrNames().size()); - EXPECT_EQ(op_desc->GetIrInputs().size(), op_desc_origin->GetIrInputs().size()); - EXPECT_EQ(op_desc->GetIrOutputs().size(), op_desc_origin->GetIrOutputs().size()); -} - -TEST_F(IrDefinitionsRecoverUT, CheckIrSpe_ir_input_num_check_failed) { - auto op_desc = std::make_shared("matmul", "MatMulUt"); - ASSERT_NE(op_desc, nullptr); - auto ret = CheckIrSpec(op_desc); - EXPECT_EQ(ret, false); -} - -TEST_F(IrDefinitionsRecoverUT, CheckIrSpe_ir_input_dynamic_skip_check) { - auto op_desc = std::make_shared("concatv2d", "ConcatV2DUt"); - ASSERT_NE(op_desc, nullptr); - auto ret = CheckIrSpec(op_desc); - EXPECT_EQ(ret, false); -} - -TEST_F(IrDefinitionsRecoverUT, CheckIrSpe_ir_input_optional_skip_check) { - auto op_desc = std::make_shared("BNInferenceDUt", "BNInferenceDUt"); - ASSERT_NE(op_desc, nullptr); - auto ret = CheckIrSpec(op_desc); - EXPECT_EQ(ret, false); -} - -TEST_F(IrDefinitionsRecoverUT, CheckIrSpec_ir_attr_not_match_failed) { - dlog_setlevel(0, 0, 0); - auto op_desc = std::make_shared("matmul", "MatMulUt"); - ASSERT_NE(op_desc, nullptr); - - auto op = ge::OperatorFactory::CreateOperator("MatMulUt", "MatMulUt"); - auto op_desc_origin = ge::OpDescUtils::GetOpDescFromOperator(op); - - op_desc->impl_->meta_data_.ir_meta_.ir_outputs_.ir_outputs = op_desc_origin->GetIrOutputs(); - op_desc->impl_->meta_data_.ir_meta_.ir_inputs_.ir_inputs = op_desc_origin->GetIrInputs(); - op_desc->impl_->meta_data_.ir_meta_.ir_attr_names_.emplace_back("fake"); - vector dim(4, 4); - GeShape shape(dim); - GeTensorDesc out_desc(shape); - - op_desc->AddOutputDesc(out_desc); - op_desc->AddInputDesc(0, out_desc); - op_desc->AddInputDesc(1, out_desc); - op_desc->AddInputDesc(2, out_desc); - (void)AttrUtils::SetBool(op_desc, "transpose_x1", true); - ASSERT_FALSE(op_desc->impl_->meta_data_.ir_meta_.ir_outputs_.ir_outputs.empty()); - ASSERT_FALSE(op_desc->impl_->meta_data_.ir_meta_.ir_attr_names_.empty()); - auto ret = CheckIrSpec(op_desc); - EXPECT_EQ(ret, false); - dlog_setlevel(0, 3, 0); -} -} // namespace gert diff --git a/tests/ut/graph/testcase/match_policy_for_exactly_the_same_unittest.cc b/tests/ut/graph/testcase/match_policy_for_exactly_the_same_unittest.cc deleted file mode 100644 index e32b0077dcc1b534378c62736875208e097e1335..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/match_policy_for_exactly_the_same_unittest.cc +++ /dev/null @@ -1,99 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/cache_policy/match_policy_for_exactly_the_same.h" -#include "cache_desc_stub/runtime_cache_desc.h" -#include "graph/cache_policy/cache_state.h" - -namespace ge { -namespace { -CacheDescPtr CreateRuntimeCacheDesc(const std::vector &shapes) { - auto cache_desc = std::make_shared(); - cache_desc->SetShapes(shapes); - return cache_desc; -} -CacheInfo CreateCacheInfo(const uint64_t time_count, const CacheItemId item_id, - const std::vector &shapes) { - auto cache_desc = CreateRuntimeCacheDesc(shapes); - CacheInfo cache_info{time_count, item_id, cache_desc}; - return cache_info; -} -} -class MatchPolicyForExactlyTheSameUT : public testing::Test {}; - -TEST_F(MatchPolicyForExactlyTheSameUT, GetCacheItemId_KInvalidCacheItemId_CannotMatchHashKey) { - gert::Shape s1{256, 256}; - gert::Shape s2{1, 256, 256}; - const std::vector shapes1{s1}; - const std::vector shapes2{s2}; - - CCStatType hash_2_cache_infos; - auto cache_info = CreateCacheInfo(1, 1, shapes1); - auto hash = CreateRuntimeCacheDesc(shapes1)->GetCacheDescHash(); - hash_2_cache_infos[hash] = {cache_info}; - - auto find_cache_desc = CreateRuntimeCacheDesc(shapes2); - MatchPolicyForExactlyTheSame mp; - auto find_id = mp.GetCacheItemId(hash_2_cache_infos, find_cache_desc); - EXPECT_EQ(find_id, KInvalidCacheItemId); -} - -TEST_F(MatchPolicyForExactlyTheSameUT, GetCacheItemId_KInvalidCacheItemId_CannotMatchShapes) { - gert::Shape s1{256, 256}; - gert::Shape s2{1, 256, 256}; - const std::vector shapes1{s1}; - const std::vector shapes2{s2}; - - CCStatType hash_2_cache_infos; - auto cache_info = CreateCacheInfo(1, 1, shapes1); - auto hash = CreateRuntimeCacheDesc(shapes2)->GetCacheDescHash(); - hash_2_cache_infos[hash] = {cache_info}; - - auto find_cache_desc = CreateRuntimeCacheDesc(shapes2); - MatchPolicyForExactlyTheSame mp; - auto find_id = mp.GetCacheItemId(hash_2_cache_infos, find_cache_desc); - EXPECT_EQ(find_id, KInvalidCacheItemId); -} - -TEST_F(MatchPolicyForExactlyTheSameUT, GetCacheItemId_ShapesAndHashMatched) { - gert::Shape s1{256, 256}; - const std::vector shapes1{s1}; - const std::vector shapes2{s1}; - - CCStatType hash_2_cache_infos; - auto cache_info = CreateCacheInfo(1, 1, shapes1); - auto hash = CreateRuntimeCacheDesc(shapes1)->GetCacheDescHash(); - hash_2_cache_infos[hash] = {cache_info}; - - auto find_cache_desc = CreateRuntimeCacheDesc(shapes1); - MatchPolicyForExactlyTheSame mp; - auto find_id = mp.GetCacheItemId(hash_2_cache_infos, find_cache_desc); - EXPECT_EQ(find_id, cache_info.GetItemId()); -} - -TEST_F(MatchPolicyForExactlyTheSameUT, GetCacheItemId_LogHashNotExist_KeyExistButVectorEmpty) { - dlog_setlevel(0, 0, 0); - gert::Shape s1{256, 256}; - const std::vector shapes1{s1}; - - CCStatType hash_2_cache_infos; - auto cache_info = CreateCacheInfo(1, 1, shapes1); - auto find_cache_desc = CreateRuntimeCacheDesc(shapes1); - auto hash = find_cache_desc->GetCacheDescHash(); - hash_2_cache_infos[hash] = {cache_info}; - hash_2_cache_infos[hash].erase(hash_2_cache_infos[hash].begin()); // key exist but vector value empty - - MatchPolicyForExactlyTheSame mp; - auto find_id = mp.GetCacheItemId(hash_2_cache_infos, find_cache_desc); - EXPECT_EQ(find_id, KInvalidCacheItemId); - dlog_setlevel(0, 3, 0); -} - -} // namespace ge diff --git a/tests/ut/graph/testcase/math_util_unittest.cc b/tests/ut/graph/testcase/math_util_unittest.cc deleted file mode 100644 index b092232d6bdeff7dfd1fb9d584cc962c4aedc30d..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/math_util_unittest.cc +++ /dev/null @@ -1,250 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/utils/math_util.h" -#include -namespace ge { -class MathUtilUT : public testing::Test {}; -TEST_F(MathUtilUT, AddOverflow_NotOverflow) { - size_t i = 0; - size_t j = 0; - size_t ret; - EXPECT_FALSE(AddOverflow(i, j, ret)); - EXPECT_EQ(ret, 0); - - i = 100; - j = 200; - EXPECT_FALSE(AddOverflow(i, j, ret)); - EXPECT_EQ(ret, 300); - - i = 0xFFFFFFFFFFFFFFFF; - j = 0; - EXPECT_FALSE(AddOverflow(i, j, ret)); - EXPECT_EQ(ret, 0xFFFFFFFFFFFFFFFF); - - i = 0x7FFFFFFFFFFFFFFF; - j = 0x8000000000000000; - EXPECT_FALSE(AddOverflow(i, j, ret)); - EXPECT_EQ(ret, 0xFFFFFFFFFFFFFFFF); -} -TEST_F(MathUtilUT, AddOverflow_Overflow) { - - size_t i = 0xFFFFFFFFFFFFFFFF; - size_t j = 1; - size_t ret; - EXPECT_TRUE(AddOverflow(i, j, ret)); - - i = 0x7FFFFFFFFFFFFFFF; - j = 0x8000000000000001; - EXPECT_TRUE(AddOverflow(i, j, ret)); -} -TEST_F(MathUtilUT, AddOverflow_OverflowUint8) { - uint8_t i = 255; - uint8_t j = 0; - uint8_t ret; - EXPECT_FALSE(AddOverflow(i, j, ret)); - - i = 255; - j = 1; - EXPECT_TRUE(AddOverflow(i, j, ret)); - - i = 2; - j = 254; - EXPECT_TRUE(AddOverflow(i, j, ret)); -} - -TEST_F(MathUtilUT, AddOverflow_OverflowDiffType) { - uint16_t i = 255; - uint8_t j = 0; - uint8_t ret; - EXPECT_FALSE(AddOverflow(i, j, ret)); - EXPECT_FALSE(AddOverflow(j, i, ret)); - - i = 256; - j = 0; - EXPECT_TRUE(AddOverflow(i, j, ret)); - EXPECT_TRUE(AddOverflow(j, i, ret)); - - i = 100; - j = 156; - EXPECT_TRUE(AddOverflow(i, j, ret)); - EXPECT_TRUE(AddOverflow(j, i, ret)); -} - -TEST_F(MathUtilUT, AddOverflow_IntUnderflow) { - int8_t i = -128; - int8_t j = 0; - int8_t ret; - EXPECT_FALSE(AddOverflow(i, j, ret)); - EXPECT_FALSE(AddOverflow(j, i, ret)); - - i = -128; - j = -1; - EXPECT_TRUE(AddOverflow(i, j, ret)); - EXPECT_TRUE(AddOverflow(j, i, ret)); -} - -TEST_F(MathUtilUT, AddOverflow_IntDiffTypeUnderflow) { - int16_t i = -128; - int8_t j = 0; - int8_t ret; - EXPECT_FALSE(AddOverflow(i, j, ret)); - EXPECT_FALSE(AddOverflow(j, i, ret)); - - i = -129; - j = 0; - EXPECT_TRUE(AddOverflow(i, j, ret)); - EXPECT_TRUE(AddOverflow(j, i, ret)); - - i = -128; - j = -1; - EXPECT_TRUE(AddOverflow(i, j, ret)); - EXPECT_TRUE(AddOverflow(j, i, ret)); -} - -TEST_F(MathUtilUT, RoundUp) { - EXPECT_EQ(RoundUp(10, 8), 16); - EXPECT_EQ(RoundUp(10, 3), 12); - EXPECT_EQ(RoundUp(10, 2), 10); - EXPECT_EQ(RoundUp(10, 1), 10); - // fail - EXPECT_EQ(RoundUp(std::numeric_limits::max(), 10), 0); -} - -TEST_F(MathUtilUT, CeilDiv16) { - EXPECT_EQ(CeilDiv16(0), 0); - EXPECT_EQ(CeilDiv16(1), 1); - EXPECT_EQ(CeilDiv16(15), 1); - EXPECT_EQ(CeilDiv16(16), 1); - EXPECT_EQ(CeilDiv16(17), 2); - EXPECT_EQ(CeilDiv16(32), 2); - EXPECT_EQ(CeilDiv16(33), 3); -} - -TEST_F(MathUtilUT, CeilDiv32) { - EXPECT_EQ(CeilDiv32(0), 0); - EXPECT_EQ(CeilDiv32(1), 1); - EXPECT_EQ(CeilDiv32(31), 1); - EXPECT_EQ(CeilDiv32(32), 1); - EXPECT_EQ(CeilDiv32(33), 2); - EXPECT_EQ(CeilDiv32(63), 2); - EXPECT_EQ(CeilDiv32(64), 2); - EXPECT_EQ(CeilDiv32(65), 3); -} - -TEST_F(MathUtilUT, MulOverflow_NotOverflow) { - int32_t i; - EXPECT_FALSE(MulOverflow(10, 20, i)); - EXPECT_EQ(i, 200); - - EXPECT_FALSE(MulOverflow(-10, -20, i)); - EXPECT_EQ(i, 200); - - EXPECT_FALSE(MulOverflow(-10, 20, i)); - EXPECT_EQ(i, -200); - - EXPECT_FALSE(MulOverflow(0, 0, i)); - EXPECT_EQ(i, 0); -} - -TEST_F(MathUtilUT, MulOverflow_Overflow) { - int32_t i; - EXPECT_TRUE(MulOverflow(std::numeric_limits::max(), 2, i)); - EXPECT_TRUE(MulOverflow(std::numeric_limits::min(), 2, i)); - EXPECT_TRUE(MulOverflow(std::numeric_limits::min(), -1, i)); - EXPECT_TRUE(MulOverflow(2, std::numeric_limits::max(), i)); - EXPECT_TRUE(MulOverflow(2, std::numeric_limits::min(), i)); - EXPECT_TRUE(MulOverflow(-1, std::numeric_limits::min(), i)); - EXPECT_TRUE(MulOverflow(std::numeric_limits::max() / 2 + 1, std::numeric_limits::max() / 2 + 1, i)); - EXPECT_TRUE(MulOverflow(std::numeric_limits::min() / 2 - 1, std::numeric_limits::min() / 2 - 1, i)); -} - -TEST_F(MathUtilUT, MulOverflow_OverflowUint8) { - uint8_t i; - EXPECT_TRUE(MulOverflow(static_cast(255), static_cast(2), i)); - EXPECT_TRUE(MulOverflow(static_cast(2), static_cast(255), i)); -} - -TEST_F(MathUtilUT, MulOverflow_OverflowDiffType) { - uint8_t i; - EXPECT_TRUE(MulOverflow(300, 1, i)); - EXPECT_TRUE(MulOverflow(1, 300, i)); -} - -TEST_F(MathUtilUT, RoundUpOverflow_Overflow_Int8) { - int8_t value = 127; - int8_t v1; - EXPECT_TRUE(RoundUpOverflow(value, static_cast(4), v1)); -} -TEST_F(MathUtilUT, RoundUpOverflow_Overflow_Int64) { - int64_t value = std::numeric_limits::max() - 2; - int64_t v1; - EXPECT_TRUE(RoundUpOverflow(value, static_cast(4), v1)); -} -TEST_F(MathUtilUT, RoundUpOverflow_Overflow_RetValueSmall) { - int32_t value = 1024; - int8_t v1; - EXPECT_TRUE(RoundUpOverflow(value, static_cast(4), v1)); -} -TEST_F(MathUtilUT, RoundUpOverflow_Overflow_MaxUint32) { - for (uint32_t i = 0U; i < 7; ++i) { - uint32_t value = std::numeric_limits::max() - i; - uint32_t v1; - EXPECT_TRUE(RoundUpOverflow(value, static_cast(8), v1)); - } -} -TEST_F(MathUtilUT, RoundUpOverflow_Overflow_Inplace) { - int64_t value = std::numeric_limits::max() - 2; - EXPECT_TRUE(RoundUpOverflow(value, static_cast(4), value)); -} -TEST_F(MathUtilUT, RoundUpOverflow_NotOverflow_MaxUint32) { - uint32_t value = std::numeric_limits::max() - 7U; - uint32_t v1; - EXPECT_FALSE(RoundUpOverflow(value, static_cast(8), v1)); - EXPECT_EQ(v1, std::numeric_limits::max() - 7U); -} -TEST_F(MathUtilUT, RoundUpOverflow_NotOverflow_EvenlyDivInt8) { - int8_t value = 64; - int8_t v1; - EXPECT_FALSE(RoundUpOverflow(value, static_cast(4), v1)); - EXPECT_EQ(v1, 64); -} -TEST_F(MathUtilUT, RoundUpOverflow_NotOverflow_EvenlyDivInt32) { - int32_t value = 2048; - int32_t v1; - EXPECT_FALSE(RoundUpOverflow(value, static_cast(32), v1)); - EXPECT_EQ(v1, 2048); -} -TEST_F(MathUtilUT, RoundUpOverflow_NotOverflow_NotEvenlyDivInt32) { - for (int32_t i = 0; i < 4; ++i) { - int32_t value = 2047 - i; - int32_t v1; - EXPECT_FALSE(RoundUpOverflow(value, static_cast(32), v1)); - EXPECT_EQ(v1, 2048); - } -} -TEST_F(MathUtilUT, RoundUpOverflow_NotOverflow_Inplace) { - int32_t value = 2048; - EXPECT_FALSE(RoundUpOverflow(value, static_cast(32), value)); - EXPECT_TRUE(value == 2048); - - value = 2040; - EXPECT_FALSE(RoundUpOverflow(value, static_cast(32), value)); - EXPECT_TRUE(value == 2048); - - value = 32; - EXPECT_FALSE(RoundUpOverflow(value,value, value)); - EXPECT_TRUE(value == 32); -} -TEST_F(MathUtilUT, RoundUpOverflow_Failed_MultipleOfZero) { - int8_t value = 10; - int8_t v1; - EXPECT_TRUE(RoundUpOverflow(value, static_cast(0), v1)); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/model_serialize_unittest.cc b/tests/ut/graph/testcase/model_serialize_unittest.cc deleted file mode 100644 index 7b7a59ef1efd0fb42e69d21f8ebe8870b0350aea..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/model_serialize_unittest.cc +++ /dev/null @@ -1,1694 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "graph/ge_attr_value.h" -#include "graph/debug/ge_op_types.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/model_serialize.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/graph_utils_ex.h" -#include "graph/utils/tensor_utils.h" -#include "graph/detail/model_serialize_imp.h" -#include "graph/utils/ge_ir_utils.h" -#include "graph/utils/attr_utils.h" -#include "graph/op_desc.h" -#include "graph/compute_graph.h" -#include "graph_builder_utils.h" -#include "graph/node.h" -#include "graph/normal_graph/node_impl.h" -#include "test_std_structs.h" -#include "external/graph/operator_factory.h" -#include "graph/utils/op_desc_utils.h" -#include "external/graph/operator_reg.h" - -#include "graph/utils/node_utils.h" -#include "proto/ge_ir.pb.h" -#include "proto/om.pb.h" -#include "attribute_group/attr_group_base.h" -#include "test_structs.h" - -using namespace ge; -using namespace std; - -using std::vector; -using std::string; -namespace { -static ComputeGraphPtr BuildSubSubComputeGraph() { - ut::GraphBuilder builder = ut::GraphBuilder("subsubgraph"); - auto data = builder.AddNode("sub_sub_Data", "sub_sub_Data", 0, 1); - auto netoutput = builder.AddNode("sub_sub_Netoutput", "sub_sub_NetOutput", 1, 0); - builder.AddDataEdge(data, 0, netoutput, 0); - auto graph = builder.GetGraph(); - return graph; -} - -static ComputeGraphPtr BuildSubComputeGraph() { - ut::GraphBuilder builder = ut::GraphBuilder("subgraph"); - auto data = builder.AddNode("sub_Data", "sub_Data", 0, 1); - auto partitioned_call = builder.AddNode("PartitionedCall", "PartitionedCall", 1, 1); - partitioned_call->GetOpDesc()->AddSubgraphName("subsubgraph"); - partitioned_call->GetOpDesc()->SetSubgraphInstanceName(0, "subsubgraph"); - auto netoutput = builder.AddNode("sub_Netoutput", "sub_NetOutput", 1, 0); - builder.AddDataEdge(data, 0, partitioned_call, 0); - builder.AddDataEdge(partitioned_call, 0, netoutput, 0); - auto subgraph = builder.GetGraph(); - partitioned_call->SetOwnerComputeGraph(subgraph); - ComputeGraphPtr subsubgraph = BuildSubSubComputeGraph(); - subsubgraph->SetParentGraph(subgraph); - subsubgraph->SetParentNode(partitioned_call); - return subgraph; -} -// construct graph which contains subgraph -static ComputeGraphPtr BuildComputeGraph() { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data = builder.AddNode("Data", "Data", 0, 1); - auto transdata = builder.AddNode("Transdata", "Transdata", 1, 1); - transdata->GetOpDesc()->AddSubgraphName("subgraph"); - transdata->GetOpDesc()->SetSubgraphInstanceName(0, "subgraph"); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - builder.AddDataEdge(data, 0, transdata, 0); - builder.AddDataEdge(transdata, 0, netoutput, 0); - auto graph = builder.GetGraph(); - // add subgraph - transdata->SetOwnerComputeGraph(graph); - ComputeGraphPtr subgraph = BuildSubComputeGraph(); - subgraph->SetParentGraph(graph); - subgraph->SetParentNode(transdata); - - auto partitioned_call = subgraph->FindNode("PartitionedCall"); - auto sub_sub_graph = ge::NodeUtils::GetSubgraph(*partitioned_call, 0U); - - graph->AddSubgraph("subgraph", subgraph); - graph->AddSubgraph("partitioned_call", sub_sub_graph); - return graph; -} - -TEST(UTEST_ge_model_serialize, GetAllSubGraphsRecursivelySuccess) -{ - Model model("model_name", "custom version3.0"); - ComputeGraphPtr cgp = BuildComputeGraph(); - model.SetGraph(cgp); - - ModelSerialize serialize; - auto buffer = serialize.SerializeModel(model); - ASSERT_NE(buffer.GetSize(), 0); -} -} - -bool LinkEdge(NodePtr srcNode, int32_t srcIndex, NodePtr dstNode, int32_t dstIndex) -{ - if (srcIndex >= 0) { - auto srcAnchor = srcNode->GetOutDataAnchor(srcIndex); - auto dstAnchor = dstNode->GetInDataAnchor(dstIndex); - srcAnchor->LinkTo(dstAnchor); - } else { - auto srcAnchor = srcNode->GetOutControlAnchor(); - auto dstAnchor = dstNode->GetInControlAnchor(); - srcAnchor->LinkTo(dstAnchor); - } - return true; -} - -NodePtr CreateNode(OpDescPtr op, ComputeGraphPtr ownerGraph) -{ - return ownerGraph->AddNode(op); -} - -void CompareShape(const vector& shape1, const vector& shape2) -{ - EXPECT_EQ(shape1.size(), shape2.size()); - if (shape1.size() == shape2.size()) { - for (size_t i = 0; i < shape1.size(); i++) { - EXPECT_EQ(shape1[i], shape2[i]); - } - } -} - -template -void CompareList(const vector& val1, const vector& val2) -{ - EXPECT_EQ(val1.size(), val2.size()); - if (val1.size() == val2.size()) { - for (size_t i = 0; i < val1.size(); i++) { - EXPECT_EQ(val1[i], val2[i]); - } - } -} - -static bool NamedAttrsSimpleCmp(const GeAttrValue& left, const GeAttrValue& right) -{ - GeAttrValue::NamedAttrs val1, val2; - left.GetValue(val1); - right.GetValue(val2); - if (val1.GetName() != val2.GetName()) { - return false; - } - auto attrs1 = val1.GetAllAttrs(); - auto attrs2 = val2.GetAllAttrs(); - if (attrs1.size() != attrs1.size()) { - return false; - } - - for (auto it: attrs1) { - auto it2 = attrs2.find(it.first); - if (it2 == attrs2.end()) { // simple check - return false; - } - if(it.second.GetValueType() != it2->second.GetValueType()){ - return false; - } - switch (it.second.GetValueType()){ - case GeAttrValue::VT_INT:{ - int64_t i1 = 0, i2 = 0; - it.second.GetValue(i1); - it2->second.GetValue(i2); - if(i1 != i2){ - return false; - } - } - case GeAttrValue::VT_FLOAT:{ - GeAttrValue::FLOAT i1 = 0, i2 = 0; - it.second.GetValue(i1); - it2->second.GetValue(i2); - if(fabs(i1 - i2) > FLT_EPSILON){ - return false; - } - } - case GeAttrValue::VT_STRING:{ - string i1, i2; - it.second.GetValue(i1); - it2->second.GetValue(i2); - if(i1 != i2){ - return false; - } - } - case GeAttrValue::VT_BOOL:{ - bool i1 = false, i2 = false; - it.second.GetValue(i1); - it2->second.GetValue(i2); - if(i1 != i2){ - return false; - } - } - default: { - continue; - } - } - } - return true; -} - -static GeAttrValue::NamedAttrs CreateNamedAttrs(const string& name, std::map map) -{ - GeAttrValue::NamedAttrs namedAttrs; - namedAttrs.SetName(name); - for(auto it :map){ - namedAttrs.SetAttr(it.first, it.second); - } - return namedAttrs; -} - -static ComputeGraphPtr CreateGraph_1_1_224_224(float *tensor_data) { - ut::GraphBuilder builder("graph1"); - auto data1 = builder.AddNode("data1", "Data", {}, {"y"}); - AttrUtils::SetInt(data1->GetOpDesc(), "index", 0); - auto const1 = builder.AddNode("const1", "Const", {}, {"y"}); - GeTensorDesc const1_td; - const1_td.SetShape(GeShape({1, 1, 224, 224})); - const1_td.SetOriginShape(GeShape({1, 1, 224, 224})); - const1_td.SetFormat(FORMAT_NCHW); - const1_td.SetOriginFormat(FORMAT_NCHW); - const1_td.SetDataType(DT_FLOAT); - const1_td.SetOriginDataType(DT_FLOAT); - GeTensor tensor(const1_td); - tensor.SetData(reinterpret_cast(tensor_data), sizeof(float) * 224 * 224); - AttrUtils::SetTensor(const1->GetOpDesc(), "value", tensor); - auto add1 = builder.AddNode("add1", "Add", {"x1", "x2"}, {"y"}); - auto netoutput1 = builder.AddNode("NetOutputNode", "NetOutput", {"x"}, {}); - ge::AttrUtils::SetListListInt(add1->GetOpDesc()->MutableOutputDesc(0), "list_list_i", {{1, 0, 0, 0}}); - ge::AttrUtils::SetListInt(add1->GetOpDesc(), "list_i", {1}); - ge::AttrUtils::SetListStr(add1->GetOpDesc(), "list_s", {"1"}); - ge::AttrUtils::SetListFloat(add1->GetOpDesc(), "list_f", {1.0}); - ge::AttrUtils::SetListBool(add1->GetOpDesc(), "list_b", {false}); - builder.AddDataEdge(data1, 0, add1, 0); - builder.AddDataEdge(const1, 0, add1, 1); - builder.AddDataEdge(add1, 0, netoutput1, 0); - - return builder.GetGraph(); -} - -TEST(UTEST_ge_model_serialize, simple) -{ - Model model("model_name", "custom version3.0"); - model.SetAttr("model_key1", GeAttrValue::CreateFrom(123)); - model.SetAttr("model_key2", GeAttrValue::CreateFrom(456.78f)); - model.SetAttr("model_key3", GeAttrValue::CreateFrom("abcd")); - model.SetAttr("model_key4", GeAttrValue::CreateFrom({123, 456})); - model.SetAttr("model_key5", GeAttrValue::CreateFrom({456.78f, 998.90f})); - model.SetAttr("model_key6", GeAttrValue::CreateFrom({"abcd", "happy"})); - model.SetAttr("model_key7", GeAttrValue::CreateFrom(false)); - model.SetAttr("model_key8", GeAttrValue::CreateFrom({true, false})); - - auto computeGraph = std::make_shared("graph_name"); - - // input - auto inputOp = std::make_shared("input", "Input"); - inputOp->AddOutputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - auto input = CreateNode(inputOp, computeGraph); - // w1 - auto w1Op = std::make_shared("w1", "ConstOp"); - w1Op->AddOutputDesc(GeTensorDesc(GeShape({12, 2, 64, 64, 16}), FORMAT_NC1HWC0, DT_FLOAT16)); - auto w1 = CreateNode(w1Op, computeGraph); - - // node1 - auto node1Op = std::make_shared("node1", "Conv2D"); - node1Op->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - node1Op->AddInputDesc(GeTensorDesc(GeShape({12, 2, 64, 64, 16}), FORMAT_NC1HWC0, DT_FLOAT16)); - node1Op->AddOutputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - auto node1 = CreateNode(node1Op, computeGraph); - - // Attr set - node1Op->SetAttr("node_key1", GeAttrValue::CreateFrom(Buffer(10))); - node1Op->SetAttr("node_key2", GeAttrValue::CreateFrom({Buffer(20), Buffer(30)})); - auto namedAttrs1 = GeAttrValue::CreateFrom( - CreateNamedAttrs("my_name", {{"int_val", - GeAttrValue::CreateFrom( - 123)}, - {"str_val", - GeAttrValue::CreateFrom( - "abc")}, - {"float_val", - GeAttrValue::CreateFrom( - 345.345)}})); - - node1Op->SetAttr("node_key3", std::move(namedAttrs1)); - auto listNamedAttrs = GeAttrValue::CreateFrom( - {CreateNamedAttrs("my_name", - {{"int_val", - GeAttrValue::CreateFrom( - 123)}, - {"float_val", - GeAttrValue::CreateFrom( - 345.345)}}), - CreateNamedAttrs("my_name2", - {{"str_val", - GeAttrValue::CreateFrom( - "abc")}, - {"float_val", - GeAttrValue::CreateFrom( - 345.345)}})}); - node1Op->SetAttr("node_key4", std::move(listNamedAttrs)); - // tensor - auto tensorData1 = "qwertyui"; - auto tensor1 = std::make_shared(GeTensorDesc(GeShape({2, 2, 2}), FORMAT_NCHW, DT_INT8), (uint8_t*) tensorData1, - 8); - auto tensorData2 = "asdfqwertyui"; - auto tensor2 = std::make_shared(GeTensorDesc(GeShape({3, 2, 2}), FORMAT_ND, DT_UINT8), (uint8_t*) tensorData2, - 12); - auto tensorData3 = "ghjkasdfqwertyui"; - auto tensor3 = std::make_shared(GeTensorDesc(GeShape({4, 2, 2}), FORMAT_ND, DT_UINT16), (uint8_t*) tensorData3, - 16); - node1Op->SetAttr("node_key5", GeAttrValue::CreateFrom(tensor1)); - node1Op->SetAttr("node_key6", GeAttrValue::CreateFrom({tensor2, tensor3})); - - auto tensorDesc = GeTensorDesc(GeShape({2, 2, 2}), FORMAT_NCHW, DT_INT16); - TensorUtils::SetSize(tensorDesc, 100); - node1Op->SetAttr("node_key7", GeAttrValue::CreateFrom(tensorDesc)); - node1Op->SetAttr("node_key8", GeAttrValue::CreateFrom( - {GeTensorDesc(GeShape({2, 2, 2}), FORMAT_NCHW, DT_INT32), GeTensorDesc(GeShape({2, 2, 2}), FORMAT_NCHW, DT_UINT32), - GeTensorDesc(GeShape({2, 2, 2}), FORMAT_NCHW, DT_INT64), GeTensorDesc(GeShape({2, 2, 2}), FORMAT_NCHW, DT_UINT64), - GeTensorDesc(GeShape({2, 2, 2}), FORMAT_NCHW, DT_BOOL), - GeTensorDesc(GeShape({2, 2, 2}), FORMAT_NCHW, DT_DOUBLE)})); - - LinkEdge(input, 0, node1, 0); - LinkEdge(w1, 0, node1, 1); - - model.SetGraph(computeGraph); - - Buffer buffer; - ASSERT_EQ(model.Save(buffer), GRAPH_SUCCESS); - EXPECT_TRUE(buffer.GetData() != nullptr); - - Model model2; - ASSERT_EQ(Model::Load(buffer.GetData(), buffer.GetSize(), model2), GRAPH_SUCCESS); - EXPECT_EQ(model2.GetName(), "model_name"); - GeAttrValue::INT modelVal1; - AttrUtils::GetInt(&model2, "model_key1", modelVal1); - EXPECT_EQ(modelVal1, 123); - - GeAttrValue::FLOAT modelVal2; - AttrUtils::GetFloat(&model2, "model_key2", modelVal2); - EXPECT_EQ(modelVal2, (float) 456.78f); - - GeAttrValue::STR modelVal3; - AttrUtils::GetStr(&model2, "model_key3", modelVal3); - EXPECT_EQ(modelVal3, "abcd"); - - GeAttrValue::LIST_INT modelVal4; - AttrUtils::GetListInt(&model2, "model_key4", modelVal4); - CompareList(modelVal4, {123, 456}); - - GeAttrValue::LIST_FLOAT modelVal5; - AttrUtils::GetListFloat(&model2, "model_key5", modelVal5); - CompareList(modelVal5, {456.78f, 998.90f}); - - GeAttrValue::LIST_STR modelVal6; - AttrUtils::GetListStr(&model2, "model_key6", modelVal6); - CompareList(modelVal6, {"abcd", "happy"}); - - GeAttrValue::BOOL modelVal7; - EXPECT_EQ(AttrUtils::GetBool(&model2, "model_key7", modelVal7), true); - EXPECT_EQ(modelVal7, false); - - GeAttrValue::LIST_BOOL modelVal8; - AttrUtils::GetListBool(&model2, "model_key8", modelVal8); - CompareList(modelVal8, {true, false}); - - const auto& s_graph = model2.GetGraph(); - ASSERT_TRUE(s_graph != nullptr); - auto s_nodes = s_graph->GetDirectNode(); - ASSERT_EQ(3, s_nodes.size()); - - auto s_input = s_nodes.at(0); - auto s_w1 = s_nodes.at(1); - auto s_nod1 = s_nodes.at(2); - { - auto s_op = s_input->GetOpDesc(); - EXPECT_EQ(s_op->GetName(), "input"); - EXPECT_EQ(s_op->GetType(), "Input"); - auto s_input_descs = s_op->GetAllInputsDesc(); - ASSERT_EQ(s_input_descs.size(), 0); - auto s_output_descs = s_op->GetAllOutputsDesc(); - ASSERT_EQ(s_output_descs.size(), 1); - auto desc1 = s_output_descs.at(0); - EXPECT_EQ(desc1.GetFormat(), FORMAT_NCHW); - EXPECT_EQ(desc1.GetDataType(), DT_FLOAT); - CompareShape(desc1.GetShape().GetDims(), vector{12, 32, 64, 64}); - - auto outAnchor = s_input->GetOutDataAnchor(0); - auto peerAnchors = outAnchor->GetPeerInDataAnchors(); - ASSERT_EQ(peerAnchors.size(), 1); - auto peerAnchor = peerAnchors.at(0); - ASSERT_EQ(peerAnchor->GetIdx(), 0); - ASSERT_EQ(peerAnchor->GetOwnerNode(), s_nod1); - } - - { - auto s_op = s_w1->GetOpDesc(); - EXPECT_EQ(s_op->GetName(), "w1"); - EXPECT_EQ(s_op->GetType(), "ConstOp"); - auto s_input_descs = s_op->GetAllInputsDesc(); - ASSERT_EQ(s_input_descs.size(), 0); - auto s_output_descs = s_op->GetAllOutputsDesc(); - ASSERT_EQ(s_output_descs.size(), 1); - auto desc1 = s_output_descs.at(0); - EXPECT_EQ(desc1.GetFormat(), FORMAT_NC1HWC0); - EXPECT_EQ(desc1.GetDataType(), DT_FLOAT16); - CompareShape(desc1.GetShape().GetDims(), vector{12, 2, 64, 64, 16}); - - auto outAnchor = s_w1->GetOutDataAnchor(0); - auto peerAnchors = outAnchor->GetPeerInDataAnchors(); - ASSERT_EQ(peerAnchors.size(), 1); - auto peerAnchor = peerAnchors.at(0); - ASSERT_EQ(peerAnchor->GetIdx(), 1); - ASSERT_EQ(peerAnchor->GetOwnerNode(), s_nod1); - } - { - auto s_op = s_nod1->GetOpDesc(); - EXPECT_EQ(s_op->GetName(), "node1"); - EXPECT_EQ(s_op->GetType(), "Conv2D"); - auto s_input_descs = s_op->GetAllInputsDesc(); - ASSERT_EQ(s_input_descs.size(), 2); - - auto desc1 = s_input_descs.at(0); - EXPECT_EQ(desc1.GetFormat(), FORMAT_NCHW); - EXPECT_EQ(desc1.GetDataType(), DT_FLOAT); - CompareShape(desc1.GetShape().GetDims(), vector{12, 32, 64, 64}); - - auto desc2 = s_input_descs.at(1); - EXPECT_EQ(desc2.GetFormat(), FORMAT_NC1HWC0); - EXPECT_EQ(desc2.GetDataType(), DT_FLOAT16); - CompareShape(desc2.GetShape().GetDims(), vector{12, 2, 64, 64, 16}); - - auto s_output_descs = s_op->GetAllOutputsDesc(); - ASSERT_EQ(s_output_descs.size(), 1); - auto desc3 = s_output_descs.at(0); - EXPECT_EQ(desc3.GetFormat(), FORMAT_NCHW); - EXPECT_EQ(desc3.GetDataType(), DT_FLOAT); - CompareShape(desc3.GetShape().GetDims(), vector{12, 32, 64, 64}); - - auto outAnchor = s_nod1->GetOutDataAnchor(0); - auto peerAnchors = outAnchor->GetPeerInDataAnchors(); - ASSERT_EQ(peerAnchors.size(), 0); - - // node attrs - GeAttrValue::BYTES nodeVal1; - AttrUtils::GetBytes(s_op, "node_key1", nodeVal1); - ASSERT_EQ(nodeVal1.GetSize(), 10); - - GeAttrValue::LIST_BYTES nodeVal2; - AttrUtils::GetListBytes(s_op, "node_key2", nodeVal2); - ASSERT_EQ(nodeVal2.size(), 2); - ASSERT_EQ(nodeVal2[0].GetSize(), 20); - ASSERT_EQ(nodeVal2[1].GetSize(), 30); - - GeAttrValue s_namedAttrs; - s_op->GetAttr("node_key3", s_namedAttrs); - EXPECT_TRUE(NamedAttrsSimpleCmp(s_namedAttrs, namedAttrs1)); - - GeAttrValue s_listNamedAttrs; - s_op->GetAttr("node_key4", s_listNamedAttrs); - EXPECT_TRUE(NamedAttrsSimpleCmp(s_listNamedAttrs, listNamedAttrs)); - - ConstGeTensorPtr s_tensor; - AttrUtils::GetTensor(s_op, "node_key5", s_tensor); - ASSERT_TRUE(s_tensor == nullptr); - - GeTensorDesc s_tensorDesc; - AttrUtils::GetTensorDesc(s_op, "node_key7", s_tensorDesc); - EXPECT_EQ(s_tensorDesc.GetFormat(), FORMAT_ND); - EXPECT_EQ(s_tensorDesc.GetDataType(), DT_FLOAT); - int64_t size = -1; - TensorUtils::GetSize(s_tensorDesc, size); - EXPECT_EQ(size, 0); - } -} - -TEST(UTEST_ge_model_serialize, OpDescAsAttrValue) -{ - // node1Op - auto node1Op = std::make_shared("node1", "Conv2D"); - node1Op->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - node1Op->AddInputDesc(GeTensorDesc(GeShape({12, 2, 64, 64, 16}), FORMAT_NC1HWC0, DT_FLOAT16)); - node1Op->AddOutputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - - // Attr set - node1Op->SetAttr("node_key1", GeAttrValue::CreateFrom(Buffer(10))); - node1Op->SetAttr("node_key2", GeAttrValue::CreateFrom({Buffer(20), Buffer(30)})); - auto namedAttrs1 = GeAttrValue::CreateFrom( - CreateNamedAttrs("my_name", {{"int_val", - GeAttrValue::CreateFrom( - 123)}, - {"str_val", - GeAttrValue::CreateFrom( - "abc")}, - {"float_val", - GeAttrValue::CreateFrom( - 345.345)}})); - - node1Op->SetAttr("node_key3", std::move(namedAttrs1)); - auto listNamedAttrs = GeAttrValue::CreateFrom( - {CreateNamedAttrs("my_name", - {{"int_val", - GeAttrValue::CreateFrom( - 123)}, - {"float_val", - GeAttrValue::CreateFrom( - 345.345)}}), - CreateNamedAttrs("my_name2", - {{"str_val", - GeAttrValue::CreateFrom( - "abc")}, - {"float_val", - GeAttrValue::CreateFrom( - 345.345)}})}); - node1Op->SetAttr("node_key4", std::move(listNamedAttrs)); - - - Model model; - EXPECT_TRUE(AttrUtils::SetListInt(&model, "my_key2",{123})); - EXPECT_TRUE(AttrUtils::SetListBytes(&model, "my_key3",{Buffer(100)})); -} - -TEST(UTEST_ge_model_serialize, test_subGraph) -{ - Model model("model_name", "custom version3.0"); - { - auto computeGraph = std::make_shared("graph_name"); - // input - auto inputOp = std::make_shared("test", "TestOp"); - inputOp->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - auto input = CreateNode(inputOp, computeGraph); - model.SetGraph(computeGraph); - - auto subComputeGraph = std::make_shared("sub_graph"); - // input - auto subGraphInputOp = std::make_shared("sub_graph_test", "TestOp2"); - subGraphInputOp->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - auto subGraphInput = CreateNode(subGraphInputOp, subComputeGraph); - - AttrUtils::SetGraph(inputOp, "sub_graph", subComputeGraph); - } - - ModelSerialize serialize; - auto buffer = serialize.SerializeModel(model); - ASSERT_GE(buffer.GetSize(), 0); -} - -REG_OP(MatMul) - .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .ATTR(transpose_x1, Bool, false) - .ATTR(transpose_x2, Bool, false) - .OP_END_FACTORY_REG(MatMul) - -TEST(UTEST_ge_model_serialize, test_ir_definitions) -{ - Model model("model_name", "custom version3.0"); - auto op = ge::OperatorFactory::CreateOperator("MatMul", "MatMul"); - auto op_desc_origin = ge::OpDescUtils::GetOpDescFromOperator(op); - EXPECT_NE(op_desc_origin, nullptr); - auto computeGraph = std::make_shared("graph_name"); - CreateNode(op_desc_origin, computeGraph); - model.SetGraph(computeGraph); - auto graph = GraphUtilsEx::CreateGraphFromComputeGraph(computeGraph); - auto node = graph.GetDirectNode(); - - ModelSerialize serialize; - proto::ModelDef model_def; - ASSERT_EQ(serialize.SerializeModel(model, false, model_def), SUCCESS); //success - GraphUtils::WriteProtoToTextFile(model_def, "./ir_definitions.txt"); - ComputeGraphPtr com_graph1 = std::make_shared("TestGraph1"); - auto state = GraphUtils::LoadGEGraph("./ir_definitions.txt", *com_graph1); - ASSERT_EQ(state, true); - ASSERT_EQ(com_graph1->GetAllNodesSize(), 1); - for (auto &node_ptr : com_graph1->GetAllNodes()) { - ASSERT_EQ((node_ptr == nullptr), false); - if (node_ptr->GetType() == "MatMul") { - auto op_desc = node_ptr->GetOpDesc(); - ASSERT_EQ((op_desc == nullptr), false); - EXPECT_FALSE(op_desc->GetIrAttrNames().empty()); - EXPECT_FALSE(op_desc->GetIrInputs().empty()); - EXPECT_EQ(op_desc->GetIrAttrNames().size(), op_desc_origin->GetIrAttrNames().size()); - EXPECT_EQ(op_desc->GetIrInputs().size(), op_desc_origin->GetIrInputs().size()); - } - } - system("rm -rf ./ir_definitions.txt"); -} - -TEST(UTEST_ge_model_serialize, test_large_model) -{ - Model model("model_name/main_model", "custom version3.0"); - { - auto compute_graph = std::make_shared("graph_name/main_graph"); - // input - auto input_op = std::make_shared("test/const", CONSTANT); - input_op->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - auto input = CreateNode(input_op, compute_graph); - GeTensor ge_tensor; - auto aligned_ptr = std::make_shared(4294967296U); // 4g - auto ptr = aligned_ptr->MutableGet(); - *ptr = 7; - *(ptr + 10) = 8; - *(ptr + 4294967295) = 9; - ge_tensor.SetData(aligned_ptr, 4294967296); - AttrUtils::SetTensor(input_op, ATTR_NAME_WEIGHTS, ge_tensor); - model.SetGraph(compute_graph); - - auto sub_compute_graph = std::make_shared("graph_name/sub_graph"); - auto sub_graph_input_op = std::make_shared("sub_graph_test", "TestOp2"); - sub_graph_input_op->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - auto sub_graph_input = CreateNode(sub_graph_input_op, sub_compute_graph); - - std::string sub_graph = "graph_name/sub_graph"; - input_op->AddSubgraphName(sub_graph); - input_op->SetSubgraphInstanceName(0, sub_graph); - sub_compute_graph->SetParentNode(input); - sub_compute_graph->SetParentGraph(compute_graph); - compute_graph->AddSubgraph(sub_compute_graph); - } - - ModelSerialize serialize; - Model model_back; - auto buffer = serialize.SerializeModel(model); - ASSERT_NE(buffer.GetSize(), 0);// failed - proto::ModelDef model_def; - ASSERT_EQ(serialize.UnserializeModel(buffer.GetData(), buffer.GetSize(), model_back), true); - ComputeGraphPtr com_graph1 = std::make_shared("TestGraph1"); - com_graph1 = model_back.GetGraph(); - ASSERT_EQ(com_graph1->GetAllNodesSize(), 2); - for (auto &node_ptr : com_graph1->GetAllNodes()) { - ASSERT_EQ((node_ptr == nullptr), false); - if (node_ptr->GetType() == CONSTANT) { - auto op_desc = node_ptr->GetOpDesc(); - ASSERT_EQ((op_desc == nullptr), false); - ConstGeTensorPtr ge_tensor_ptr; - ASSERT_EQ(AttrUtils::GetTensor(op_desc, ATTR_NAME_WEIGHTS, ge_tensor_ptr), true); - ASSERT_EQ((ge_tensor_ptr == nullptr), false); - const TensorData tensor_data = ge_tensor_ptr->GetData(); - const uint8_t *buff = tensor_data.GetData(); - ASSERT_EQ((buff == nullptr), false); - ASSERT_EQ(buff[0], 7); - ASSERT_EQ(buff[10], 8); - ASSERT_EQ(buff[4294967295], 9); // value is ok for def serialize - } - } - system("rm -rf ./air_weight"); -} - -TEST(UTEST_ge_model_serialize, test_large_model_with_16_subgraph_multi_thread) -{ - Model model("model_name/main_model", "custom version3.0"); - { - auto compute_graph = std::make_shared("main_graph"); - // input - auto input1 = std::make_shared("input1", "Data"); - auto input2 = std::make_shared("input2", "Data"); - auto input3 = std::make_shared("input3", "Data"); - input1->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - input1->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - input2->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - input2->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - input3->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - input3->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - auto input_data1 = CreateNode(input1, compute_graph); - auto input_data2 = CreateNode(input2, compute_graph); - auto input_data3 = CreateNode(input3, compute_graph); - auto case_op = std::make_shared("case", "Case"); - case_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - case_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - case_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - case_op->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - auto case_node = CreateNode(case_op, compute_graph); - auto output_op = std::make_shared("output", "NetOutput"); - output_op->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - auto output_node = CreateNode(output_op, compute_graph); - ge::GraphUtils::AddEdge(input_data1->GetOutAnchor(0), case_node->GetInAnchor(0)); - ge::GraphUtils::AddEdge(input_data2->GetOutAnchor(0), case_node->GetInAnchor(1)); - ge::GraphUtils::AddEdge(input_data3->GetOutAnchor(0), case_node->GetInAnchor(2)); - ge::GraphUtils::AddEdge(case_node->GetOutAnchor(0), output_node->GetInAnchor(0)); - model.SetGraph(compute_graph); - for (auto i = 0UL; i < 16; i++) { - std::string subgraph_name = "subgraph" + std::to_string(i); - auto sub_compute_graph = std::make_shared(subgraph_name); - auto sub_graph_input_op = std::make_shared("data1", "DATA"); - sub_graph_input_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - sub_graph_input_op->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - auto sub_graph_input_op1 = std::make_shared("data2", "DATA"); - sub_graph_input_op1->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - sub_graph_input_op1->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - auto sub_input_node = CreateNode(sub_graph_input_op, sub_compute_graph); - auto sub_input_node1 = CreateNode(sub_graph_input_op1, sub_compute_graph); - auto sub_graph_add_op = std::make_shared("add_sub", "ADD"); - sub_graph_add_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - sub_graph_add_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - sub_graph_add_op->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - auto sub_add_node = CreateNode(sub_graph_add_op, sub_compute_graph); - auto sub_output_op = std::make_shared("sub_output", "NetOutput"); - sub_output_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - auto sub_output_node = CreateNode(sub_output_op, sub_compute_graph); - ge::GraphUtils::AddEdge(sub_input_node->GetOutAnchor(0), sub_add_node->GetInAnchor(0)); - ge::GraphUtils::AddEdge(sub_input_node1->GetOutAnchor(0), sub_add_node->GetInAnchor(1)); - ge::GraphUtils::AddEdge(sub_add_node->GetOutAnchor(0), sub_output_node->GetInAnchor(0)); - case_op->AddSubgraphName(subgraph_name); - case_op->SetSubgraphInstanceName(i, subgraph_name); - sub_compute_graph->SetParentNode(case_node); - sub_compute_graph->SetParentGraph(compute_graph); - compute_graph->AddSubgraph(sub_compute_graph); - } - } - - ModelSerialize serialize; - auto buffer = serialize.SerializeModel(model); - ASSERT_NE(buffer.GetSize(), 0);// failed - Model model_back; - ASSERT_EQ(Model::LoadWithMultiThread(buffer.GetData(), buffer.GetSize(), model_back), SUCCESS); - ComputeGraphPtr com_graph1 = std::make_shared("TestGraph1"); - com_graph1 = model_back.GetGraph(); - ASSERT_EQ(com_graph1->GetAllSubgraphs().size(), 16); - for (auto &sub_graph : com_graph1->GetAllSubgraphs()) { - ASSERT_EQ((sub_graph == nullptr), false); - EXPECT_EQ(sub_graph->GetDirectNodesSize(), 4); - for (auto &nodes : sub_graph->GetDirectNode()) { - ASSERT_EQ((nodes == nullptr), false); - } - } -} - -TEST(UTEST_ge_model_serialize, test_large_model_with_16_subgraph_single_thread) -{ - Model model("model_name/main_model", "custom version3.0"); - { - auto compute_graph = std::make_shared("main_graph"); - // input - auto input1 = std::make_shared("input1", "Data"); - auto input2 = std::make_shared("input2", "Data"); - auto input3 = std::make_shared("input3", "Data"); - input1->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - input1->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - input2->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - input2->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - input3->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - input3->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - auto input_data1 = CreateNode(input1, compute_graph); - auto input_data2 = CreateNode(input2, compute_graph); - auto input_data3 = CreateNode(input3, compute_graph); - auto case_op = std::make_shared("case", "Case"); - case_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - case_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - case_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - case_op->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - auto case_node = CreateNode(case_op, compute_graph); - auto output_op = std::make_shared("output", "NetOutput"); - output_op->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - auto output_node = CreateNode(output_op, compute_graph); - ge::GraphUtils::AddEdge(input_data1->GetOutAnchor(0), case_node->GetInAnchor(0)); - ge::GraphUtils::AddEdge(input_data2->GetOutAnchor(0), case_node->GetInAnchor(1)); - ge::GraphUtils::AddEdge(input_data3->GetOutAnchor(0), case_node->GetInAnchor(2)); - ge::GraphUtils::AddEdge(case_node->GetOutAnchor(0), output_node->GetInAnchor(0)); - model.SetGraph(compute_graph); - for (auto i = 0UL; i < 16; i++) { - std::string subgraph_name = "subgraph" + std::to_string(i); - auto sub_compute_graph = std::make_shared(subgraph_name); - auto sub_graph_input_op = std::make_shared("data1", "DATA"); - sub_graph_input_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - sub_graph_input_op->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - auto sub_graph_input_op1 = std::make_shared("data2", "DATA"); - sub_graph_input_op1->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - sub_graph_input_op1->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - auto sub_input_node = CreateNode(sub_graph_input_op, sub_compute_graph); - auto sub_input_node1 = CreateNode(sub_graph_input_op1, sub_compute_graph); - auto sub_graph_add_op = std::make_shared("add_sub", "ADD"); - sub_graph_add_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - sub_graph_add_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - sub_graph_add_op->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - auto sub_add_node = CreateNode(sub_graph_add_op, sub_compute_graph); - auto sub_output_op = std::make_shared("sub_output", "NetOutput"); - sub_output_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - auto sub_output_node = CreateNode(sub_output_op, sub_compute_graph); - ge::GraphUtils::AddEdge(sub_input_node->GetOutAnchor(0), sub_add_node->GetInAnchor(0)); - ge::GraphUtils::AddEdge(sub_input_node1->GetOutAnchor(0), sub_add_node->GetInAnchor(1)); - ge::GraphUtils::AddEdge(sub_add_node->GetOutAnchor(0), sub_output_node->GetInAnchor(0)); - case_op->AddSubgraphName(subgraph_name); - case_op->SetSubgraphInstanceName(i, subgraph_name); - sub_compute_graph->SetParentNode(case_node); - sub_compute_graph->SetParentGraph(compute_graph); - compute_graph->AddSubgraph(sub_compute_graph); - } - } - - ModelSerialize serialize; - auto buffer = serialize.SerializeModel(model); - ASSERT_NE(buffer.GetSize(), 0);// failed - Model model_back; - ASSERT_EQ(Model::Load(buffer.GetData(), buffer.GetSize(), model_back), SUCCESS); - ComputeGraphPtr com_graph1 = std::make_shared("TestGraph1"); - com_graph1 = model_back.GetGraph(); - ASSERT_EQ(com_graph1->GetAllSubgraphs().size(), 16); - for (auto &sub_graph : com_graph1->GetAllSubgraphs()) { - ASSERT_EQ((sub_graph == nullptr), false); - EXPECT_EQ(sub_graph->GetDirectNodesSize(), 4); - for (auto &nodes : sub_graph->GetDirectNode()) { - ASSERT_EQ((nodes == nullptr), false); - } - } -} - -TEST(UTEST_ge_model_serialize, test_large_model_with_30_subgraph) -{ - Model model("model_name/main_model", "custom version3.0"); - { - auto compute_graph = std::make_shared("main_graph"); - // input - auto input1 = std::make_shared("input1", "Data"); - auto input2 = std::make_shared("input2", "Data"); - auto input3 = std::make_shared("input3", "Data"); - input1->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - input1->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - input2->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - input2->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - input3->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - input3->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - auto input_data1 = CreateNode(input1, compute_graph); - auto input_data2 = CreateNode(input2, compute_graph); - auto input_data3 = CreateNode(input3, compute_graph); - auto case_op = std::make_shared("case", "Case"); - case_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - case_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - case_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - case_op->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - auto case_node = CreateNode(case_op, compute_graph); - auto output_op = std::make_shared("output", "NetOutput"); - output_op->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - auto output_node = CreateNode(output_op, compute_graph); - ge::GraphUtils::AddEdge(input_data1->GetOutAnchor(0), case_node->GetInAnchor(0)); - ge::GraphUtils::AddEdge(input_data2->GetOutAnchor(0), case_node->GetInAnchor(1)); - ge::GraphUtils::AddEdge(input_data3->GetOutAnchor(0), case_node->GetInAnchor(2)); - ge::GraphUtils::AddEdge(case_node->GetOutAnchor(0), output_node->GetInAnchor(0)); - model.SetGraph(compute_graph); - for (auto i = 0UL; i < 30; i++) { - std::string subgraph_name = "subgraph" + std::to_string(i); - auto sub_compute_graph = std::make_shared(subgraph_name); - auto sub_graph_input_op = std::make_shared("data1", "DATA"); - sub_graph_input_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - sub_graph_input_op->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - auto sub_graph_input_op1 = std::make_shared("data2", "DATA"); - sub_graph_input_op1->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - sub_graph_input_op1->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - auto sub_input_node = CreateNode(sub_graph_input_op, sub_compute_graph); - auto sub_input_node1 = CreateNode(sub_graph_input_op1, sub_compute_graph); - auto sub_graph_add_op = std::make_shared("add_sub", "ADD"); - sub_graph_add_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - sub_graph_add_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - sub_graph_add_op->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - auto sub_add_node = CreateNode(sub_graph_add_op, sub_compute_graph); - auto sub_output_op = std::make_shared("sub_output", "NetOutput"); - sub_output_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - auto sub_output_node = CreateNode(sub_output_op, sub_compute_graph); - ge::GraphUtils::AddEdge(sub_input_node->GetOutAnchor(0), sub_add_node->GetInAnchor(0)); - ge::GraphUtils::AddEdge(sub_input_node1->GetOutAnchor(0), sub_add_node->GetInAnchor(1)); - ge::GraphUtils::AddEdge(sub_add_node->GetOutAnchor(0), sub_output_node->GetInAnchor(0)); - case_op->AddSubgraphName(subgraph_name); - case_op->SetSubgraphInstanceName(i, subgraph_name); - sub_compute_graph->SetParentNode(case_node); - sub_compute_graph->SetParentGraph(compute_graph); - compute_graph->AddSubgraph(sub_compute_graph); - } - } - - ModelSerialize serialize; - Model model_back; - auto buffer = serialize.SerializeModel(model); - ASSERT_NE(buffer.GetSize(), 0);// failed - proto::ModelDef model_def; - ASSERT_EQ(serialize.UnserializeModel(buffer.GetData(), buffer.GetSize(), model_back), true); - ComputeGraphPtr com_graph1 = std::make_shared("TestGraph1"); - com_graph1 = model_back.GetGraph(); - ASSERT_EQ(com_graph1->GetAllSubgraphs().size(), 30); - for (auto &sub_graph : com_graph1->GetAllSubgraphs()) { - ASSERT_EQ((sub_graph == nullptr), false); - EXPECT_EQ(sub_graph->GetDirectNodesSize(), 4); - for (auto &nodes : sub_graph->GetDirectNode()) { - ASSERT_EQ((nodes == nullptr), false); - } - } -} - -TEST(UTEST_ge_model_serialize, test_large_model_with_subgraph_error) -{ - Model model("model_name/main_model", "custom version3.0"); - { - auto compute_graph = std::make_shared("main_graph"); - // input - auto input1 = std::make_shared("input1", "Data"); - auto input2 = std::make_shared("input2", "Data"); - auto input3 = std::make_shared("input3", "Data"); - input1->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - input1->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - input2->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - input2->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - input3->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - input3->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - auto input_data1 = CreateNode(input1, compute_graph); - auto input_data2 = CreateNode(input2, compute_graph); - auto input_data3 = CreateNode(input3, compute_graph); - auto case_op = std::make_shared("case", "Case"); - case_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - case_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - case_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - case_op->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - auto case_node = CreateNode(case_op, compute_graph); - auto output_op = std::make_shared("output", "NetOutput"); - output_op->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - auto output_node = CreateNode(output_op, compute_graph); - ge::GraphUtils::AddEdge(input_data1->GetOutAnchor(0), case_node->GetInAnchor(0)); - ge::GraphUtils::AddEdge(input_data2->GetOutAnchor(0), case_node->GetInAnchor(1)); - ge::GraphUtils::AddEdge(input_data3->GetOutAnchor(0), case_node->GetInAnchor(2)); - ge::GraphUtils::AddEdge(case_node->GetOutAnchor(0), output_node->GetInAnchor(0)); - model.SetGraph(compute_graph); - for (auto i = 0UL; i < 3; i++) { - std::string subgraph_name = "subgraph" + std::to_string(i); - auto sub_compute_graph = std::make_shared(subgraph_name); - auto sub_graph_input_op = std::make_shared("data1", "DATA"); - sub_graph_input_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - sub_graph_input_op->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - auto sub_graph_input_op1 = std::make_shared("subgraph/const1", CONSTANT); - sub_graph_input_op1->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - sub_graph_input_op1->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - GeTensor ge_tensor; - auto aligned_ptr = std::make_shared(64); // 500m - auto ptr = aligned_ptr->MutableGet(); - *ptr = 7; - *(ptr + 10) = 8; - ge_tensor.SetData(aligned_ptr, 64); - AttrUtils::SetTensor(sub_graph_input_op1, ATTR_NAME_WEIGHTS, ge_tensor); - AttrUtils::SetStr(sub_graph_input_op1, ATTR_NAME_LOCATION, "file_path"); - AttrUtils::SetInt(sub_graph_input_op1, ATTR_NAME_LENGTH, 20); - - auto sub_input_node = CreateNode(sub_graph_input_op, sub_compute_graph); - auto sub_input_node1 = CreateNode(sub_graph_input_op1, sub_compute_graph); - auto sub_graph_add_op = std::make_shared("add_sub", "ADD"); - sub_graph_add_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - sub_graph_add_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - sub_graph_add_op->AddOutputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - auto sub_add_node = CreateNode(sub_graph_add_op, sub_compute_graph); - auto sub_output_op = std::make_shared("sub_output", "NetOutput"); - sub_output_op->AddInputDesc(GeTensorDesc(GeShape({1, 3, 2, 2}), FORMAT_NCHW, DT_FLOAT)); - auto sub_output_node = CreateNode(sub_output_op, sub_compute_graph); - ge::GraphUtils::AddEdge(sub_input_node->GetOutAnchor(0), sub_add_node->GetInAnchor(0)); - ge::GraphUtils::AddEdge(sub_input_node1->GetOutAnchor(0), sub_add_node->GetInAnchor(1)); - ge::GraphUtils::AddEdge(sub_add_node->GetOutAnchor(0), sub_output_node->GetInAnchor(0)); - case_op->AddSubgraphName(subgraph_name); - case_op->SetSubgraphInstanceName(i, subgraph_name); - sub_compute_graph->SetParentNode(case_node); - sub_compute_graph->SetParentGraph(compute_graph); - compute_graph->AddSubgraph(sub_compute_graph); - } - } - - ModelSerialize serialize; - Model model_back; - auto buffer = serialize.SerializeModel(model); - ASSERT_NE(buffer.GetSize(), 0);// failed - proto::ModelDef model_def; - ASSERT_EQ(serialize.UnserializeModel(buffer.GetData(), buffer.GetSize(), model_back), false); -} - -TEST(UTEST_ge_model_serialize, test_large_model_lots_const) -{ - Model model("model_name/main_model", "custom version3.0"); - { - auto compute_graph = std::make_shared("graph_name/main_graph"); - // input - for (int i = 0; i < 4; i++) { - std::string inpu_node_name = "test/const" + std::to_string(i); - auto input_op = std::make_shared(inpu_node_name, CONSTANT); - input_op->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - auto input = CreateNode(input_op, compute_graph); - GeTensor ge_tensor; - auto aligned_ptr = std::make_shared(536870912U); // 500m - auto ptr = aligned_ptr->MutableGet(); - *ptr = 7; - *(ptr + 10) = 8; - *(ptr + 536870910) = 9; - ge_tensor.SetData(aligned_ptr, 536870912); - AttrUtils::SetTensor(input_op, ATTR_NAME_WEIGHTS, ge_tensor); - } - - model.SetGraph(compute_graph); - - auto sub_compute_graph = std::make_shared("sub_graph"); - auto sub_graph_input_op = std::make_shared("sub_graph_test", "TestOp2"); - sub_graph_input_op->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - auto sub_graph_input = CreateNode(sub_graph_input_op, sub_compute_graph); - - auto parent_input_op = std::make_shared("parenttest", "TestOp2"); - parent_input_op->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - auto parent_input = CreateNode(parent_input_op, compute_graph); - for (int i = 0; i < 4; i++) { - std::string inpu_node_name = "subgraph/const" + std::to_string(i); - auto sub_input_op = std::make_shared(inpu_node_name, CONSTANT); - sub_input_op->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - auto sub_input = CreateNode(sub_input_op, sub_compute_graph); - GeTensor ge_tensor; - auto aligned_ptr = std::make_shared(536870912U); // 500m - auto ptr = aligned_ptr->MutableGet(); - *ptr = 7; - *(ptr + 10) = 8; - *(ptr + 536870910) = 9; - ge_tensor.SetData(aligned_ptr, 536870912); - AttrUtils::SetTensor(sub_input_op, ATTR_NAME_WEIGHTS, ge_tensor); - } - std::string sub_graph = "sub_graph"; - parent_input_op->AddSubgraphName(sub_graph); - parent_input_op->SetSubgraphInstanceName(0, sub_graph); - sub_compute_graph->SetParentNode(parent_input); - sub_compute_graph->SetParentGraph(compute_graph); - compute_graph->AddSubgraph(sub_compute_graph); - } - ModelSerialize serialize; - Model model_back; - auto buffer = serialize.SerializeModel(model); - ASSERT_NE(buffer.GetSize(), 0);// failed - proto::ModelDef model_def; - ASSERT_EQ(serialize.UnserializeModel(buffer.GetData(), buffer.GetSize(), model_back), true); - - ComputeGraphPtr com_graph1 = std::make_shared("TestGraph1"); - com_graph1 = model_back.GetGraph(); - ASSERT_EQ(com_graph1->GetAllNodesSize(), 10); - for (auto &node_ptr : com_graph1->GetAllNodes()) { - ASSERT_EQ((node_ptr == nullptr), false); - if (node_ptr->GetType() == CONSTANT) { - auto op_desc = node_ptr->GetOpDesc(); - ASSERT_EQ((op_desc == nullptr), false); - ConstGeTensorPtr ge_tensor_ptr; - ASSERT_EQ(AttrUtils::GetTensor(op_desc, ATTR_NAME_WEIGHTS, ge_tensor_ptr), true); - ASSERT_EQ((ge_tensor_ptr == nullptr), false); - const TensorData tensor_data = ge_tensor_ptr->GetData(); - const uint8_t *buff = tensor_data.GetData(); - ASSERT_EQ((buff == nullptr), false); - ASSERT_EQ(buff[0], 7); - ASSERT_EQ(buff[10], 8); - ASSERT_EQ(buff[536870910], 9); // value is ok for def serialize - } - } - auto sub_graph = com_graph1->GetSubgraph("sub_graph"); - ASSERT_EQ((sub_graph == nullptr), false); - ASSERT_EQ(sub_graph->GetAllNodesSize(), 5); - system("rm -rf ./air_weight"); -} - -TEST(UTEST_ge_model_serialize, test_listSubGraph) -{ - Model model("model_name", "custom version3.0"); - { - auto computeGraph = std::make_shared("graph_name"); - // input - auto inputOp = std::make_shared("test", "TestOp"); - inputOp->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - auto input = CreateNode(inputOp, computeGraph); - model.SetGraph(computeGraph); - - auto subComputeGraph1 = std::make_shared("sub_graph1"); - // input - auto subGraphInputOp1 = std::make_shared("sub_graph_test1", "TestOp2"); - subGraphInputOp1->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - auto subGraphInput1 = CreateNode(subGraphInputOp1, subComputeGraph1); - - auto subComputeGraph2 = std::make_shared("sub_graph2"); - // input - auto subGraphInputOp2 = std::make_shared("sub_graph_test2", "TestOp2"); - subGraphInputOp2->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - auto subGraphInput2 = CreateNode(subGraphInputOp2, subComputeGraph2); - - AttrUtils::SetListGraph(inputOp, "sub_graph", vector{subComputeGraph1, subComputeGraph2}); - } - - ModelSerialize serialize; - auto buffer = serialize.SerializeModel(model); - ASSERT_GE(buffer.GetSize(), 0); -} - -TEST(UTEST_ge_model_serialize, test_Format) -{ - Model model("model_name", "custom version3.0"); - { - auto computeGraph = std::make_shared("graph_name"); - // input - auto inputOp = std::make_shared("test", "TestOp"); - inputOp->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - inputOp->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NHWC, DT_FLOAT)); - inputOp->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_ND, DT_FLOAT)); - inputOp->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NC1HWC0, DT_FLOAT)); - inputOp->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_FRACTAL_Z, DT_FLOAT)); - inputOp->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NC1C0HWPAD, DT_FLOAT)); - inputOp->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NHWC1C0, DT_FLOAT)); - inputOp->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_FSR_NCHW, DT_FLOAT)); - inputOp->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_FRACTAL_DECONV, DT_FLOAT)); - inputOp->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_BN_WEIGHT, DT_FLOAT)); - inputOp->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_CHWN, DT_FLOAT)); - inputOp->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_FILTER_HWCK, DT_FLOAT)); - inputOp->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_FRACTAL_Z_C04, DT_FLOAT)); - auto input = CreateNode(inputOp, computeGraph); - model.SetGraph(computeGraph); - } - ModelSerialize serialize; - auto buffer = serialize.SerializeModel(model); - ASSERT_GE(buffer.GetSize(), 0); -} - -TEST(UTEST_ge_model_serialize, test_ControlEdge) -{ - Model model("model_name", "custom version3.0"); - { - auto computeGraph = std::make_shared("graph_name"); - // input - auto inputOp = std::make_shared("test", "TestOp"); - inputOp->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - auto input = CreateNode(inputOp, computeGraph); - // sink - auto sinkOp = std::make_shared("test2", "Sink"); - auto sink = CreateNode(sinkOp, computeGraph); - LinkEdge(sink, -1, input, -1); - - // sink2 - auto sinkOp2 = std::make_shared("test3", "Sink"); - auto sink2 = CreateNode(sinkOp2, computeGraph); - LinkEdge(sink2, -1, input, -1); - - // dest - auto destOp = std::make_shared("test4", "Dest"); - auto dest = CreateNode(destOp, computeGraph); - LinkEdge(input, -1, dest, -1); - - computeGraph->AddInputNode(sink); - computeGraph->AddInputNode(sink2); - computeGraph->AddOutputNode(dest); - - model.SetGraph(computeGraph); - } - ModelSerialize serialize; - auto buffer = serialize.SerializeModel(model); - EXPECT_GE(buffer.GetSize(), 0); -} - -TEST(UTEST_ge_model_serialize, test_invalid_Model) -{ - {// empty graph - Model model("model_name", "custom version3.0"); - auto computeGraph = std::make_shared("graph_name"); - - ModelSerialize serialize; - auto buffer = serialize.SerializeModel(model); - EXPECT_EQ(buffer.GetSize(), 0); - } -} - -TEST(UTEST_ge_model_serialize, test_invalid_Attrs) -{ - {// valid test - Model model("model_name", "custom version3.0"); - auto computeGraph = std::make_shared("graph_name"); - - // input - auto inputOp = std::make_shared("test", "TestOp"); - inputOp->AddOutputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - - GeAttrValue::NamedAttrs namedAttrs; - namedAttrs.SetAttr("key1", GeAttrValue::CreateFrom(10)); - AttrUtils::SetNamedAttrs(inputOp, "key", namedAttrs); - - auto input = CreateNode(inputOp, computeGraph); - model.SetGraph(computeGraph); - - ModelSerialize serialize; - auto buffer = serialize.SerializeModel(model); - EXPECT_GE(buffer.GetSize(), 0); - } - {// none type - Model model("model_name", "custom version3.0"); - auto computeGraph = std::make_shared("graph_name"); - - // input - auto inputOp = std::make_shared("test", "TestOp"); - inputOp->AddOutputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - - GeAttrValue::NamedAttrs namedAttrs; - EXPECT_EQ(namedAttrs.SetAttr("key1", GeAttrValue()), GRAPH_FAILED); - } - {// bytes attr len is 0 - Model model("model_name", "custom version3.0"); - auto computeGraph = std::make_shared("graph_name"); - - // input - auto inputOp = std::make_shared("test", "TestOp"); - inputOp->AddOutputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - - GeAttrValue::NamedAttrs namedAttrs; - namedAttrs.SetAttr("key1", GeAttrValue::CreateFrom(GeAttrValue::BYTES(0))); - AttrUtils::SetNamedAttrs(inputOp, "key", namedAttrs); - - auto input = CreateNode(inputOp, computeGraph); - model.SetGraph(computeGraph); - - ModelSerialize serialize; - auto buffer = serialize.SerializeModel(model); - EXPECT_GE(buffer.GetSize(), 0); - } - {// invalid list bytes attr - Model model("model_name", "custom version3.0"); - auto computeGraph = std::make_shared("graph_name"); - - // input - auto inputOp = std::make_shared("test", "TestOp"); - inputOp->AddOutputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - - GeAttrValue::NamedAttrs namedAttrs; - namedAttrs.SetAttr("key1", GeAttrValue::CreateFrom({GeAttrValue::BYTES(0)})); - AttrUtils::SetNamedAttrs(inputOp, "key", namedAttrs); - - auto input = CreateNode(inputOp, computeGraph); - model.SetGraph(computeGraph); - - ModelSerialize serialize; - auto buffer = serialize.SerializeModel(model); - EXPECT_GE(buffer.GetSize(), 0); - } - {// invalid graph attr - Model model("model_name", "custom version3.0"); - auto computeGraph = std::make_shared("graph_name"); - - // input - auto inputOp = std::make_shared("test", "TestOp"); - inputOp->AddOutputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - - GeAttrValue::NamedAttrs namedAttrs; - EXPECT_EQ(namedAttrs.SetAttr("key1", GeAttrValue::CreateFrom(nullptr)), GRAPH_SUCCESS); - GeAttrValue value; - EXPECT_EQ(namedAttrs.GetAttr("key1", value), GRAPH_SUCCESS); - EXPECT_FALSE(value.IsEmpty()); - } - {// invalid list graph attr - Model model("model_name", "custom version3.0"); - auto computeGraph = std::make_shared("graph_name"); - - // input - auto inputOp = std::make_shared("test", "TestOp"); - inputOp->AddOutputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - - GeAttrValue::NamedAttrs namedAttrs; - EXPECT_EQ(namedAttrs.SetAttr("key1", GeAttrValue::CreateFrom({nullptr})), GRAPH_SUCCESS); - GeAttrValue value; - EXPECT_EQ(namedAttrs.GetAttr("key1", value), GRAPH_SUCCESS); - EXPECT_FALSE(value.IsEmpty()); - } -} - -TEST(UTEST_ge_model_serialize, test_ModelSerializeImp_Invalid_Param) -{ - ModelSerializeImp imp; - EXPECT_FALSE(imp.SerializeModel(Model(), nullptr)); - EXPECT_FALSE(imp.SerializeNode(nullptr, nullptr)); - - auto graph = std::make_shared("test_graph"); - auto node = graph->AddNode(std::make_shared()); - node->GetOpDesc() = nullptr; - proto::ModelDef modelDef; - Model model; - model.SetGraph(graph); - EXPECT_TRUE(imp.SerializeModel(model, &modelDef)); -} - -TEST(UTEST_ge_model_serialize, test_parse_node_false) -{ - ModelSerializeImp imp; - string node_index = "invalid_index"; - string node_name = "name"; - int32_t index = 1; - EXPECT_EQ(imp.ParseNodeIndex(node_index, node_name, index), false); -} - -TEST(UTEST_ge_model_unserialize, test_invalid_Attr) -{ - { // invalid graph - proto::ModelDef modeDeff; - auto attrs = modeDeff.add_graph()->add_op()->mutable_attr(); // node attr - - proto::AttrDef* attrDef = &(*attrs)["key1"]; - auto graphAttr = attrDef->mutable_g(); - auto attrsOfGraph = graphAttr->mutable_attr(); - auto tensorVal = (*attrsOfGraph)["key2"].mutable_td(); - tensorVal->set_dtype(proto::DT_INT8); - tensorVal->set_layout("invalidLayout"); - - ModelSerializeImp imp; - Model model; - EXPECT_TRUE(imp.UnserializeModel(model, modeDeff)); - auto graph = model.GetGraph(); - ASSERT_TRUE(graph != nullptr); - auto nodes = graph->GetAllNodes(); - ASSERT_EQ(nodes.size(), 1); - ComputeGraphPtr graphAttrNew; - EXPECT_TRUE(AttrUtils::GetGraph(nodes.at(0)->GetOpDesc(), "key1", graphAttrNew)); - ASSERT_TRUE(graphAttrNew != nullptr); - GeTensorDesc tensorDesc1; - EXPECT_TRUE(AttrUtils::GetTensorDesc(graphAttrNew, "key2", tensorDesc1)); - EXPECT_EQ(tensorDesc1.GetFormat(), FORMAT_RESERVED); - EXPECT_EQ(tensorDesc1.GetDataType(), DT_INT8); - } - { // invalid list graph - proto::ModelDef modeDeff; - auto attrs = modeDeff.add_graph()->add_op()->mutable_attr(); // node attr - - proto::AttrDef* attrDef = &(*attrs)["key1"]; - attrDef->mutable_list()->set_val_type(ge::proto::AttrDef_ListValue_ListValueType_VT_LIST_GRAPH); - auto graphAttr = attrDef->mutable_list()->add_g(); - auto attrsOfGraph = graphAttr->mutable_attr(); - auto tensorVal = (*attrsOfGraph)["key2"].mutable_td(); - tensorVal->set_dtype(proto::DT_INT8); - tensorVal->set_layout("invalidLayout"); - - ModelSerializeImp imp; - Model model; - EXPECT_TRUE(imp.UnserializeModel(model, modeDeff)); - auto graph = model.GetGraph(); - ASSERT_TRUE(graph != nullptr); - auto nodes = graph->GetAllNodes(); - ASSERT_EQ(nodes.size(), 1); - vector graphListAttr; - EXPECT_TRUE(AttrUtils::GetListGraph(nodes.at(0)->GetOpDesc(), "key1", graphListAttr)); - ASSERT_EQ(graphListAttr.size(), 1); - ASSERT_TRUE(graphListAttr[0] != nullptr); - GeTensorDesc tensorDesc1; - EXPECT_TRUE(AttrUtils::GetTensorDesc(graphListAttr[0], "key2", tensorDesc1)); - EXPECT_EQ(tensorDesc1.GetFormat(), FORMAT_RESERVED); - EXPECT_EQ(tensorDesc1.GetDataType(), DT_INT8); - } - { // invalid namedAttrs - proto::ModelDef modeDeff; - auto attrs = modeDeff.add_graph()->add_op()->mutable_attr(); // node attr - - proto::AttrDef* attrDef = &(*attrs)["key1"]; - auto graphAttr = attrDef->mutable_func(); - auto attrsOfGraph = graphAttr->mutable_attr(); - auto tensorVal = (*attrsOfGraph)["key2"].mutable_td(); - tensorVal->set_dtype(proto::DT_INT8); - tensorVal->set_layout("invalidLayout"); - - ModelSerializeImp imp; - Model model; - EXPECT_TRUE(imp.UnserializeModel(model, modeDeff)); - auto graph = model.GetGraph(); - ASSERT_TRUE(graph != nullptr); - auto nodes = graph->GetAllNodes(); - ASSERT_EQ(nodes.size(), 1); - GeAttrValue::NAMED_ATTRS namedAttrs; - EXPECT_TRUE(AttrUtils::GetNamedAttrs(nodes.at(0)->GetOpDesc(), "key1", namedAttrs)); - GeTensorDesc tensorDesc1; - EXPECT_TRUE(AttrUtils::GetTensorDesc(namedAttrs, "key2", tensorDesc1)); - EXPECT_EQ(tensorDesc1.GetFormat(), FORMAT_RESERVED); - EXPECT_EQ(tensorDesc1.GetDataType(), DT_INT8); - } - { // invalid list namedAttrs - proto::ModelDef modeDeff; - auto attrs = modeDeff.add_graph()->add_op()->mutable_attr(); // node attr - - proto::AttrDef* attrDef = &(*attrs)["key1"]; - attrDef->mutable_list()->set_val_type(ge::proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS); - auto graphAttr = attrDef->mutable_list()->add_na(); - auto attrsOfGraph = graphAttr->mutable_attr(); - auto tensorVal = (*attrsOfGraph)["key2"].mutable_td(); - tensorVal->set_dtype(proto::DT_INT8); - tensorVal->set_layout("invalidLayout"); - - ModelSerializeImp imp; - Model model; - EXPECT_TRUE(imp.UnserializeModel(model, modeDeff)); - auto graph = model.GetGraph(); - ASSERT_TRUE(graph != nullptr); - auto nodes = graph->GetAllNodes(); - ASSERT_EQ(nodes.size(), 1); - GeAttrValue::LIST_NAMED_ATTRS namedAttrs; - EXPECT_TRUE(AttrUtils::GetListNamedAttrs(nodes.at(0)->GetOpDesc(), "key1", namedAttrs)); - ASSERT_EQ(namedAttrs.size(), 1); - GeTensorDesc tensorDesc1; - EXPECT_TRUE(AttrUtils::GetTensorDesc(namedAttrs.at(0), "key2", tensorDesc1)); - EXPECT_EQ(tensorDesc1.GetFormat(), FORMAT_RESERVED); - EXPECT_EQ(tensorDesc1.GetDataType(), DT_INT8); - } - { // invalid tensorDesc - proto::ModelDef modeDeff; - auto attrs = modeDeff.add_graph()->add_op()->mutable_attr(); // node attr - - proto::AttrDef* attrDef = &(*attrs)["key1"]; - auto graphAttr = attrDef->mutable_td(); - auto attrsOfGraph = graphAttr->mutable_attr(); - auto tensorVal = (*attrsOfGraph)["key2"].mutable_td(); - tensorVal->set_dtype(proto::DT_INT8); - tensorVal->set_layout("invalidLayout"); - - ModelSerializeImp imp; - Model model; - EXPECT_TRUE(imp.UnserializeModel(model, modeDeff)); - auto graph = model.GetGraph(); - ASSERT_TRUE(graph != nullptr); - auto nodes = graph->GetAllNodes(); - ASSERT_EQ(nodes.size(), 1); - GeTensorDesc tensorDesc; - EXPECT_TRUE(AttrUtils::GetTensorDesc(nodes.at(0)->GetOpDesc(), "key1", tensorDesc)); - GeTensorDesc tensorDesc1; - EXPECT_TRUE(AttrUtils::GetTensorDesc(tensorDesc, "key2", tensorDesc1)); - EXPECT_EQ(tensorDesc1.GetFormat(), FORMAT_RESERVED); - EXPECT_EQ(tensorDesc1.GetDataType(), DT_INT8); - } - { // invalid list tensorDesc - proto::ModelDef modeDeff; - auto attrs = modeDeff.add_graph()->add_op()->mutable_attr(); // node attr - - proto::AttrDef* attrDef = &(*attrs)["key1"]; - attrDef->mutable_list()->set_val_type(ge::proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR_DESC); - auto graphAttr = attrDef->mutable_list()->add_td(); - auto attrsOfGraph = graphAttr->mutable_attr(); - auto tensorVal = (*attrsOfGraph)["key2"].mutable_td(); - tensorVal->set_dtype(proto::DT_INT8); - tensorVal->set_layout("invalidLayout"); - - ModelSerializeImp imp; - Model model; - EXPECT_TRUE(imp.UnserializeModel(model, modeDeff)); - auto graph = model.GetGraph(); - ASSERT_TRUE(graph != nullptr); - auto nodes = graph->GetAllNodes(); - ASSERT_EQ(nodes.size(), 1); - vector tensorDesc; - EXPECT_TRUE(AttrUtils::GetListTensorDesc(nodes.at(0)->GetOpDesc(), "key1", tensorDesc)); - ASSERT_EQ(tensorDesc.size(), 1); - GeTensorDesc tensorDesc1; - EXPECT_TRUE(AttrUtils::GetTensorDesc(tensorDesc.at(0), "key2", tensorDesc1)); - EXPECT_EQ(tensorDesc1.GetFormat(), FORMAT_RESERVED); - EXPECT_EQ(tensorDesc1.GetDataType(), DT_INT8); - } - { // invalid tensor - proto::ModelDef modeDeff; - auto attrs = modeDeff.add_graph()->add_op()->mutable_attr(); // node attr - - proto::AttrDef* attrDef = &(*attrs)["key1"]; - auto graphAttr = attrDef->mutable_t()->mutable_desc(); - auto attrsOfGraph = graphAttr->mutable_attr(); - auto tensorVal = (*attrsOfGraph)["key2"].mutable_td(); - tensorVal->set_dtype(proto::DT_INT8); - tensorVal->set_layout("invalidLayout"); - - ModelSerializeImp imp; - Model model; - EXPECT_TRUE(imp.UnserializeModel(model, modeDeff)); - auto graph = model.GetGraph(); - ASSERT_TRUE(graph != nullptr); - auto nodes = graph->GetAllNodes(); - ASSERT_EQ(nodes.size(), 1); - ConstGeTensorPtr tensor; - EXPECT_TRUE(AttrUtils::GetTensor(nodes.at(0)->GetOpDesc(), "key1", tensor)); - GeTensorDesc tensorDesc1; - EXPECT_TRUE(AttrUtils::GetTensorDesc(tensor->GetTensorDesc(), "key2", tensorDesc1)); - EXPECT_EQ(tensorDesc1.GetFormat(), FORMAT_RESERVED); - EXPECT_EQ(tensorDesc1.GetDataType(), DT_INT8); - } - { // invalid list tensor - proto::ModelDef modeDeff; - auto attrs = modeDeff.add_graph()->add_op()->mutable_attr(); // node attr - - proto::AttrDef* attrDef = &(*attrs)["key1"]; - attrDef->mutable_list()->set_val_type(ge::proto::AttrDef_ListValue_ListValueType_VT_LIST_TENSOR); - auto graphAttr = attrDef->mutable_list()->add_t()->mutable_desc(); - auto attrsOfGraph = graphAttr->mutable_attr(); - auto tensorVal = (*attrsOfGraph)["key2"].mutable_td(); - tensorVal->set_dtype(proto::DT_INT8); - tensorVal->set_layout("invalidLayout"); - - ModelSerializeImp imp; - Model model; - EXPECT_TRUE(imp.UnserializeModel(model, modeDeff)); - auto graph = model.GetGraph(); - ASSERT_TRUE(graph != nullptr); - auto nodes = graph->GetAllNodes(); - ASSERT_EQ(nodes.size(), 1); - vector tensor; - EXPECT_TRUE(AttrUtils::GetListTensor(nodes.at(0)->GetOpDesc(), "key1", tensor)); - ASSERT_EQ(tensor.size(), 1); - GeTensorDesc tensorDesc1; - EXPECT_TRUE(AttrUtils::GetTensorDesc(tensor.at(0)->GetTensorDesc(), "key2", tensorDesc1)); - EXPECT_EQ(tensorDesc1.GetFormat(), FORMAT_RESERVED); - EXPECT_EQ(tensorDesc1.GetDataType(), DT_INT8); - } -} -TEST(UTEST_ge_model_unserialize, RebuildOwnershipTest) -{ - ge::ModelSerializeImp serialize_imp; - float tensor_data[224 * 224] = {1.0f}; - ComputeGraphPtr compute_graph = CreateGraph_1_1_224_224(tensor_data); - std::map subgraphs; - bool ret = serialize_imp.RebuildOwnership(compute_graph, subgraphs); - EXPECT_EQ(ret, true); -} - -TEST(UTEST_ge_model_unserialize, UnserializeModelTest) -{ - ge::ModelSerialize serialize; - ge::proto::ModelDef model_def; - Model model; - bool ret = serialize.UnserializeModel(model_def, model); - EXPECT_EQ(ret, false); -} - -TEST(UTEST_ge_model_unserialize, SerializeGraphGraphIsNull) -{ - ge::ModelSerializeImp model_serialize_imp; - ConstComputeGraphPtr graph; - proto::GraphDef *graph_proto = nullptr; - bool is_dump = true; - bool ret = model_serialize_imp.SerializeGraph(graph, graph_proto, is_dump); - EXPECT_EQ(ret, false); -} - -TEST(UTEST_ge_model_unserialize, SerializeAllAttrsFromAnyMapMutableAttrIsNull) -{ - ge::ModelSerializeImp model_serialize_imp; - std::map attr_map; - google::protobuf::Map *mutable_attr = nullptr; - bool ret = model_serialize_imp.SerializeAllAttrsFromAnyMap(attr_map, mutable_attr); - EXPECT_EQ(ret, false); -} - -TEST(UTEST_ge_model_unserialize, DeserializeAllAttrsToAttrHolderHolderIsNull) -{ - ge::ModelSerializeImp model_serialize_imp; - google::protobuf::Map proto_attr_map; - AttrHolder *attr_holder = nullptr; - bool ret = model_serialize_imp.DeserializeAllAttrsToAttrHolder(proto_attr_map, attr_holder); - EXPECT_EQ(ret, false); -} - -TEST(UTEST_ge_model_unserialize, UnserializeModelDataIsNull) -{ - ge::ModelSerialize serialize; - uint8_t *data = nullptr; - size_t len = 1; - Model model; - bool ret = serialize.UnserializeModel(data, len, model); - EXPECT_EQ(ret, false); -} - -TEST(UTEST_ge_model_unserialize, HandleNodeNameEdgesSrcNodeIsNull) -{ - ge::ModelSerializeImp model_impl; - auto builder = ut::GraphBuilder("test1"); - auto dst_node = builder.AddNode("dst_node", "NetOutput", 1, 0); - NodeNameNodeReq node_req("src_node", 1, -1, dst_node, 0, "dst_node"); - model_impl.node_input_node_names_.push_back(node_req); - bool ret = model_impl.HandleNodeNameRef(); - EXPECT_EQ(ret, false); -} - -TEST(UTEST_ge_model_unserialize, HandleNodeNameEdgesSrcAnchorIsNull) -{ - ge::ModelSerializeImp model_impl; - auto builder = ut::GraphBuilder("test1"); - auto dst_node = builder.AddNode("dst_node", "NetOutput", 1, 0); - NodeNameNodeReq node_req("src_node", 1, -1, dst_node, 0, "dst_node"); - model_impl.node_input_node_names_.push_back(node_req); - model_impl.node_map_.insert(pair("src_node", dst_node)); - bool ret = model_impl.HandleNodeNameRef(); - EXPECT_EQ(ret, false); -} - -TEST(UTEST_ge_model_unserialize, HandleNodeNameControlEdgeSuccess) -{ - ge::ModelSerializeImp model_impl; - auto builder = ut::GraphBuilder("test1"); - auto dst_node = builder.AddNode("dst_node", "NetOutput", 1, 0); - NodeNameNodeReq node_req("src_node", -1, -1, dst_node, 0, "dst_node"); - model_impl.node_input_node_names_.push_back(node_req); - model_impl.node_map_.insert(pair("src_node", dst_node)); - bool ret = model_impl.HandleNodeNameRef(); - EXPECT_EQ(ret, true); -} - -TEST(UTEST_ge_model_unserialize, HandleNodeNameGraphInputNodeMapIsNull) -{ - ge::ModelSerializeImp model_impl; - auto builder = ut::GraphBuilder("test"); - auto graph = builder.GetGraph(); - NodeNameGraphReq graph_req("node1", 1, graph); - model_impl.graph_input_node_names_.push_back(graph_req); - bool ret = model_impl.HandleNodeNameRef(); - EXPECT_EQ(ret, false); -} - -TEST(UTEST_ge_model_unserialize, HandleNodeNameGraphInputFail) -{ - ge::ModelSerializeImp model_impl; - auto builder = ut::GraphBuilder("test"); - auto graph = builder.GetGraph(); - auto node1 = builder.AddNode("node1", "NetOutput", 1, 0); - NodeNameGraphReq graph_req("node1", 1, graph); - model_impl.graph_input_node_names_.push_back(graph_req); - model_impl.node_map_.insert(pair("node1", nullptr)); - bool ret = model_impl.HandleNodeNameRef(); - EXPECT_EQ(ret, false); -} - -TEST(UTEST_ge_model_unserialize, HandleNodeNameGraphOutputNodeMapIsNull) -{ - ge::ModelSerializeImp model_impl; - auto builder = ut::GraphBuilder("test"); - auto graph = builder.GetGraph(); - NodeNameGraphReq graph_req("node1", 1, graph); - model_impl.graph_output_node_names_.push_back(graph_req); - bool ret = model_impl.HandleNodeNameRef(); - EXPECT_EQ(ret, false); -} - -TEST(UTEST_ge_model_unserialize, HandleNodeNameGraphOutputSuccess) -{ - ge::ModelSerializeImp model_impl; - auto builder = ut::GraphBuilder("test"); - auto graph = builder.GetGraph(); - auto node1 = builder.AddNode("node1", "Data", 1, 0); - NodeNameGraphReq graph_req("node1", 1, graph); - model_impl.graph_output_node_names_.push_back(graph_req); - model_impl.node_map_.insert(pair("node1", node1)); - bool ret = model_impl.HandleNodeNameRef(); - EXPECT_EQ(ret, true); -} - -TEST(UTEST_ge_model_unserialize, UnserializeGraphFail) -{ - ge::ModelSerializeImp model_impl; - auto builder = ut::GraphBuilder("test"); - auto graph = builder.GetGraph(); - NodeNameGraphReq graph_req("node1", 1, graph); - model_impl.graph_input_node_names_.push_back(graph_req); - proto::GraphDef graph_proto; - bool ret = model_impl.UnserializeGraph(graph, graph_proto); - EXPECT_EQ(ret, false); -} - -TEST(UTEST_ge_model_unserialize, SerializeAttrGroupFailed) -{ - ge::ModelSerializeImp model_serialize_imp; - std::shared_ptr graph = std::make_shared("test_graph"); - graph->GetOrCreateAttrsGroup()->status = GRAPH_FAILED; - proto::GraphDef graph_proto; - bool is_dump = true; - bool ret = model_serialize_imp.SerializeGraph(graph, &graph_proto, is_dump); - EXPECT_EQ(ret, false); - Model model; - model.GetOrCreateAttrsGroup()->status = GRAPH_FAILED; - proto::ModelDef model_proto; - model_serialize_imp.SerializeModel(model, &model_proto); - EXPECT_EQ(ret, false); - - GeShape shape({64, 32, 16, 64}); - GeTensorDesc desc(shape, FORMAT_NCHW, DT_FLOAT16); - desc.GetOrCreateAttrsGroup()->status = GRAPH_FAILED; - - proto::TensorDescriptor desc_proto; - // 异常场景无返回值供校验 - GeTensorSerializeUtils::GeTensorDescAsProto(desc, &desc_proto); - - auto op_desc = std::make_shared(); - op_desc->GetOrCreateAttrsGroup()->status = GRAPH_FAILED; - proto::OpDef op_def; - EXPECT_EQ(model_serialize_imp.SerializeOpDesc(op_desc, &op_def), true); -} \ No newline at end of file diff --git a/tests/ut/graph/testcase/model_unittest.cc b/tests/ut/graph/testcase/model_unittest.cc deleted file mode 100644 index f83615868ab14e62e260092b29c650e87736cf0e..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/model_unittest.cc +++ /dev/null @@ -1,442 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include -#include "test_structs.h" -#include "func_counter.h" -#include "graph/buffer.h" -#include "graph/attr_store.h" -#include "graph/model.h" -#include "graph/node.h" -#include "graph_builder_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/graph_utils_ex.h" -#include "graph/utils/node_utils.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/file_utils.h" -#include "mmpa/mmpa_api.h" - -namespace ge { -namespace { -constexpr size_t kSmallBufferSize = 32UL; -constexpr size_t kLargeBufferSize = 536870912U; -class SubModel : public Model -{ -public: - - SubModel(); - SubModel(const std::string &name, const std::string &custom_version); - - virtual ~SubModel(); - -}; - -SubModel::SubModel(){} -SubModel::SubModel(const std::string &name, const std::string &custom_version):Model(name,custom_version){} - -SubModel::~SubModel() = default; - -} - -static Model BuildModelWithConst(bool large_weight) { - Model model("model_name/main_model", "custom version3.0"); - auto compute_graph = std::make_shared("graph_name/main_graph"); - size_t buffer_size = large_weight ? (kLargeBufferSize) : kSmallBufferSize; - // input - for (int i = 0; i < 4; i++) { - std::string inpu_node_name = "test/const" + std::to_string(i); - auto input_op = std::make_shared(inpu_node_name, "Const"); - input_op->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - auto input = compute_graph->AddNode(input_op); - GeTensor ge_tensor; - auto aligned_ptr = std::make_shared(buffer_size); // 500m - auto ptr = aligned_ptr->MutableGet(); - *ptr = 7; - *(ptr + 10) = 8; - *(ptr + buffer_size - 2) = 9; - ge_tensor.SetData(aligned_ptr, buffer_size); - AttrUtils::SetTensor(input_op, ATTR_NAME_WEIGHTS, ge_tensor); - } - model.SetGraph(compute_graph); - auto sub_compute_graph = std::make_shared("sub_graph"); - auto sub_graph_input_op = std::make_shared("sub_graph_test", "TestOp2"); - sub_graph_input_op->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - auto sub_graph_input = sub_compute_graph->AddNode(sub_graph_input_op); - - auto parent_input_op = std::make_shared("parenttest", "TestOp2"); - parent_input_op->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - auto parent_input = compute_graph->AddNode(parent_input_op); - for (int i = 0; i < 4; i++) { - std::string inpu_node_name = "subgraph/const" + std::to_string(i); - auto sub_input_op = std::make_shared(inpu_node_name, "Const"); - sub_input_op->AddInputDesc(GeTensorDesc(GeShape({12, 32, 64, 64}), FORMAT_NCHW, DT_FLOAT)); - auto sub_input = sub_compute_graph->AddNode(sub_input_op); - GeTensor ge_tensor; - auto aligned_ptr = std::make_shared(buffer_size); // 500m - auto ptr = aligned_ptr->MutableGet(); - *ptr = 7; - *(ptr + 10) = 8; - *(ptr + buffer_size - 2) = 9; - ge_tensor.SetData(aligned_ptr, buffer_size); - AttrUtils::SetTensor(sub_input_op, ATTR_NAME_WEIGHTS, ge_tensor); - } - std::string sub_graph = "sub_graph"; - parent_input_op->AddSubgraphName(sub_graph); - parent_input_op->SetSubgraphInstanceName(0, sub_graph); - sub_compute_graph->SetParentNode(parent_input); - sub_compute_graph->SetParentGraph(compute_graph); - compute_graph->AddSubgraph(sub_compute_graph); - compute_graph->TopologicalSorting(); - return model; -} - -static Model BuildModelWithLargeConst() { - return BuildModelWithConst(true); -} - -static Graph BuildGraph() { - ge::OpDescPtr add_op(new ge::OpDesc("add1", "Add")); - add_op->AddDynamicInputDesc("input", 2); - add_op->AddDynamicOutputDesc("output", 1); - std::shared_ptr compute_graph(new ge::ComputeGraph("test_graph")); - auto add_node = compute_graph->AddNode(add_op); - auto graph = ge::GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph); - return graph; -} - -class ModelUt : public testing::Test {}; -TEST_F(ModelUt, SetGet) { - auto md = SubModel(); - auto md2 = SubModel("md2", "test"); - EXPECT_EQ(md.GetName(),""); - md.SetName("tt"); - EXPECT_EQ(md.GetName(),"tt"); - EXPECT_EQ(md2.GetName(),"md2"); - md2.SetName("md2tt"); - EXPECT_EQ(md2.GetName(),"md2tt"); - EXPECT_EQ(md.GetVersion(),0); - EXPECT_EQ(md2.GetVersion(),0); - EXPECT_EQ(md2.GetPlatformVersion(),"test"); - - auto graph = BuildGraph(); - EXPECT_EQ(graph.IsValid(),true); - md2.SetGraph(GraphUtilsEx::GetComputeGraph(graph)); - auto g = md2.GetGraph(); - EXPECT_NE(&g, nullptr); - Buffer buf = Buffer(1024); - EXPECT_EQ(buf.GetSize(),1024); - EXPECT_EQ(md2.IsValid(),true); - ProtoAttrMap attr = AttrStore::Create(512); - AttrId id = 1; - int val = 100; - attr.Set(id, val); - const int* v = attr.Get(id); - EXPECT_EQ(*v,val); - md2.SetAttr(attr); - EXPECT_EQ(md2.Save(buf,true), GRAPH_SUCCESS); -} - -TEST_F(ModelUt, Load) { - auto md = SubModel("md2", "test"); - auto graph = BuildGraph(); - md.SetGraph(GraphUtilsEx::GetComputeGraph(graph)); - uint8_t b[5]; - memset(b,1,5); - EXPECT_EQ(md.Load((const uint8_t*)b, 5, md),GRAPH_FAILED); - - std::string msg = "package lm;\nmessage helloworld{\nrequired int32 id = 1;\nrequired string str = 2;\noptional int32 opt = 3;}"; - std::ofstream outfile; - outfile.open("./test_load.proto"); - outfile << msg; - EXPECT_EQ(md.LoadFromFile("./test_load.proto"), GRAPH_SUCCESS); - system("rm -rf ./test_load.proto"); - outfile.close(); - -} - -TEST_F(ModelUt, Save) { - EXPECT_NO_THROW( - auto md = SubModel("md2", "test"); - auto graph = BuildGraph(); - md.SetGraph(GraphUtilsEx::GetComputeGraph(graph)); - std::stringstream ss; - ss << "./test_save.proto"; - md.SaveToFile(ss.str()); - std::string cmd = "rm -rf " + ss.str(); - system(cmd.c_str()); - ); -} - -TEST_F(ModelUt, Save_Failure) { - auto md = SubModel("md2", "test"); - auto graph = BuildGraph(); - md.SetGraph(GraphUtilsEx::GetComputeGraph(graph)); - std::stringstream fn; - fn << "/tmp/"; - for (int i = 0; i < 4096; i++){ - fn << "a"; - } - fn << ".proto"; - md.SaveToFile(fn.str()); - EXPECT_EQ(md.SaveToFile("/proc/non.proto"), GRAPH_FAILED); -} - -TEST_F(ModelUt, Load_Longname) { - auto md = SubModel("md2", "test"); - std::stringstream fn; - fn << "/tmp/"; - for (int i = 0; i < 4096; i++){ - fn << "a"; - } - fn << ".proto"; - EXPECT_EQ(md.LoadFromFile(fn.str()),GRAPH_FAILED); -} - -TEST_F(ModelUt, Load_Nonfilename) { - auto md = SubModel("md2", "test"); - EXPECT_EQ(md.LoadFromFile("/tmp/non-exsit"),GRAPH_FAILED); -} - -TEST_F(ModelUt, SaveLargeModelWithoutSeparate) { - auto md = BuildModelWithLargeConst(); - Buffer buf = Buffer(1024); - EXPECT_EQ(buf.GetSize(),1024); - EXPECT_EQ(md.IsValid(),true); - EXPECT_EQ(md.SaveWithoutSeparate(buf), GRAPH_FAILED); -} - -TEST_F(ModelUt, SaveLargeModelWithRealPath) { - auto md = BuildModelWithLargeConst(); - std::string tmp_file_name = "./"; - - std::string real_path = ge::RealPath(tmp_file_name.c_str()) + "/model.air"; - std::string clear_cmd = "rm -rf " + real_path; - system(clear_cmd.c_str()); - EXPECT_EQ(md.SaveToFile(real_path), GRAPH_SUCCESS); - Model model_back(nullptr, nullptr); - EXPECT_EQ(model_back.LoadFromFile(real_path), GRAPH_SUCCESS); - ComputeGraphPtr com_graph1 = std::make_shared("TestGraph1"); - com_graph1 = model_back.GetGraph(); - ASSERT_EQ(com_graph1->GetAllNodesSize(), 10); - for (auto &node_ptr : com_graph1->GetAllNodes()) { - ASSERT_EQ((node_ptr == nullptr), false); - if (node_ptr->GetType() == "Const") { - auto op_desc = node_ptr->GetOpDesc(); - ASSERT_EQ((op_desc == nullptr), false); - ConstGeTensorPtr ge_tensor_ptr; - ASSERT_EQ(AttrUtils::GetTensor(op_desc, ATTR_NAME_WEIGHTS, ge_tensor_ptr), true); - ASSERT_EQ((ge_tensor_ptr == nullptr), false); - const TensorData tensor_data = ge_tensor_ptr->GetData(); - const uint8_t *buff = tensor_data.GetData(); - ASSERT_EQ((buff == nullptr), false); - ASSERT_EQ(buff[0], 7); - ASSERT_EQ(buff[10], 8); - ASSERT_EQ(buff[kLargeBufferSize - 2], 9); // value is ok for def serialize - } - } - auto sub_graph = com_graph1->GetSubgraph("sub_graph"); - ASSERT_EQ((sub_graph == nullptr), false); - ASSERT_EQ(sub_graph->GetAllNodesSize(), 5); - system("rm -rf /tmp/test/air_weight"); - system(clear_cmd.c_str()); -} - -TEST_F(ModelUt, SaveLargeModelWithRelatedPath) { - auto md = BuildModelWithLargeConst(); - std::string file_name = "./temp/model.air"; - std::string clear_cmd = "rm -rf " + file_name; - system(clear_cmd.c_str()); - EXPECT_EQ(md.SaveToFile(file_name), GRAPH_SUCCESS); - Model model_back; - EXPECT_EQ(model_back.LoadFromFile(file_name), GRAPH_SUCCESS); - ComputeGraphPtr com_graph1 = std::make_shared("TestGraph1"); - com_graph1 = model_back.GetGraph(); - ASSERT_EQ(com_graph1->GetAllNodesSize(), 10); - for (auto &node_ptr : com_graph1->GetAllNodes()) { - ASSERT_EQ((node_ptr == nullptr), false); - if (node_ptr->GetType() == "Const") { - auto op_desc = node_ptr->GetOpDesc(); - ASSERT_EQ((op_desc == nullptr), false); - ConstGeTensorPtr ge_tensor_ptr; - ASSERT_EQ(AttrUtils::GetTensor(op_desc, ATTR_NAME_WEIGHTS, ge_tensor_ptr), true); - ASSERT_EQ((ge_tensor_ptr == nullptr), false); - const TensorData tensor_data = ge_tensor_ptr->GetData(); - const uint8_t *buff = tensor_data.GetData(); - ASSERT_EQ((buff == nullptr), false); - ASSERT_EQ(buff[0], 7); - ASSERT_EQ(buff[10], 8); - ASSERT_EQ(buff[kLargeBufferSize - 2], 9); // value is ok for def serialize - } - } - auto sub_graph = com_graph1->GetSubgraph("sub_graph"); - ASSERT_EQ((sub_graph == nullptr), false); - ASSERT_EQ(sub_graph->GetAllNodesSize(), 5); - system("rm -rf ./temp/air_weight"); - system(clear_cmd.c_str()); -} - -TEST_F(ModelUt, SaveLargeModelSeparateWithRelatedPath) { - auto md = BuildModelWithConst(false); - auto graph = md.GetGraph(); - // const node0 reuse const node1 weight - auto const_node0 = graph->FindNode("test/const0"); - ASSERT_NE(const_node0, nullptr); - auto const_node1 = graph->FindNode("test/const1"); - ASSERT_NE(const_node1, nullptr); - ASSERT_TRUE(AttrUtils::SetBool(const_node0->GetOpDesc(), ATTR_NAME_IS_REUSE_EXTERNAL_WEIGHT, true)); - std::string op_tag = const_node1->GetType() + "_" + graph->GetName() + "_" + const_node1->GetName() + "_file"; - std::string regulated_op_tag = ge::GetRegulatedName(op_tag); - std::string reuse_offset_path = "air_weight/model_name_main_model/" + regulated_op_tag; - ASSERT_EQ(AttrUtils::SetInt(const_node0->GetOpDesc(), ge::ATTR_NAME_LENGTH, kSmallBufferSize), true); - ASSERT_EQ(AttrUtils::SetStr(const_node0->GetOpDesc(), ge::ATTR_NAME_LOCATION, reuse_offset_path), true); - std::string file_name = "./temp/model.air"; - system("rm -rf ./temp/model.air"); - EXPECT_EQ(md.SaveToFile(file_name, true), GRAPH_SUCCESS); - Model model_back; - EXPECT_EQ(model_back.LoadFromFile("./temp/model.air"), GRAPH_SUCCESS); - ComputeGraphPtr com_graph1 = std::make_shared("TestGraph1"); - com_graph1 = model_back.GetGraph(); - ASSERT_EQ(com_graph1->GetAllNodesSize(), 10); - for (auto &node_ptr : com_graph1->GetAllNodes()) { - ASSERT_EQ((node_ptr == nullptr), false); - if (node_ptr->GetType() == "Const") { - auto op_desc = node_ptr->GetOpDesc(); - ASSERT_EQ((op_desc == nullptr), false); - ConstGeTensorPtr ge_tensor_ptr; - ASSERT_EQ(AttrUtils::GetTensor(op_desc, ATTR_NAME_WEIGHTS, ge_tensor_ptr), true); - ASSERT_EQ((ge_tensor_ptr == nullptr), false); - const TensorData tensor_data = ge_tensor_ptr->GetData(); - const uint8_t *buff = tensor_data.GetData(); - ASSERT_EQ((buff == nullptr), false); - ASSERT_EQ(buff[0], 7); - ASSERT_EQ(buff[10], 8); - ASSERT_EQ(buff[kSmallBufferSize - 2], 9); // value is ok for def serialize - } - } - auto sub_graph = com_graph1->GetSubgraph("sub_graph"); - ASSERT_EQ((sub_graph == nullptr), false); - ASSERT_EQ(sub_graph->GetAllNodesSize(), 5); - system("rm -rf ./temp/air_weight"); - system("rm -rf ./temp/model.air"); -} - -TEST_F(ModelUt, SaveLargeModelWithRelatedPath2) { - auto md = BuildModelWithLargeConst(); - std::string tmp_file_name = "./"; - - std::string real_path = ge::RealPath(tmp_file_name.c_str()) + "/model.air"; - std::string clear_cmd = "rm -rf " + real_path; - system(clear_cmd.c_str()); - EXPECT_EQ(md.SaveToFile(real_path), GRAPH_SUCCESS); - Model model_back; - EXPECT_EQ(model_back.LoadFromFile(real_path), GRAPH_SUCCESS); - ComputeGraphPtr com_graph1 = std::make_shared("TestGraph1"); - com_graph1 = model_back.GetGraph(); - ASSERT_EQ(com_graph1->GetAllNodesSize(), 10); - for (auto &node_ptr : com_graph1->GetAllNodes()) { - ASSERT_EQ((node_ptr == nullptr), false); - if (node_ptr->GetType() == "Const") { - auto op_desc = node_ptr->GetOpDesc(); - ASSERT_EQ((op_desc == nullptr), false); - ConstGeTensorPtr ge_tensor_ptr; - ASSERT_EQ(AttrUtils::GetTensor(op_desc, ATTR_NAME_WEIGHTS, ge_tensor_ptr), true); - ASSERT_EQ((ge_tensor_ptr == nullptr), false); - const TensorData tensor_data = ge_tensor_ptr->GetData(); - const uint8_t *buff = tensor_data.GetData(); - ASSERT_EQ((buff == nullptr), false); - ASSERT_EQ(buff[0], 7); - ASSERT_EQ(buff[10], 8); - ASSERT_EQ(buff[kLargeBufferSize - 2], 9); // value is ok for def serialize - } - } - auto sub_graph = com_graph1->GetSubgraph("sub_graph"); - ASSERT_EQ((sub_graph == nullptr), false); - ASSERT_EQ(sub_graph->GetAllNodesSize(), 5); - system("rm -rf ./air_weight"); - system(clear_cmd.c_str()); -} - -TEST_F(ModelUt, SaveLargeModelWithRelatedPath3) { - auto md = BuildModelWithLargeConst(); - std::string file_name = "model.air"; - std::string clear_cmd = "rm -rf " + file_name; - system(clear_cmd.c_str()); - EXPECT_EQ(md.SaveToFile(file_name), GRAPH_SUCCESS); - Model model_back; - EXPECT_EQ(model_back.LoadFromFile("model.air"), GRAPH_SUCCESS); - ComputeGraphPtr com_graph1 = std::make_shared("TestGraph1"); - com_graph1 = model_back.GetGraph(); - ASSERT_EQ(com_graph1->GetAllNodesSize(), 10); - for (auto &node_ptr : com_graph1->GetAllNodes()) { - ASSERT_EQ((node_ptr == nullptr), false); - if (node_ptr->GetType() == "Const") { - auto op_desc = node_ptr->GetOpDesc(); - ASSERT_EQ((op_desc == nullptr), false); - ConstGeTensorPtr ge_tensor_ptr; - ASSERT_EQ(AttrUtils::GetTensor(op_desc, ATTR_NAME_WEIGHTS, ge_tensor_ptr), true); - ASSERT_EQ((ge_tensor_ptr == nullptr), false); - const TensorData tensor_data = ge_tensor_ptr->GetData(); - const uint8_t *buff = tensor_data.GetData(); - ASSERT_EQ((buff == nullptr), false); - ASSERT_EQ(buff[0], 7); - ASSERT_EQ(buff[10], 8); - ASSERT_EQ(buff[kLargeBufferSize - 2], 9); // value is ok for def serialize - } - } - auto sub_graph = com_graph1->GetSubgraph("sub_graph"); - ASSERT_EQ((sub_graph == nullptr), false); - ASSERT_EQ(sub_graph->GetAllNodesSize(), 5); - system("rm -rf ./air_weight"); - system(clear_cmd.c_str()); -} - -TEST_F(ModelUt, LoadLargeModelWithWrongWeight) { - auto md = BuildModelWithLargeConst(); - auto graph = md.GetGraph(); - graph->SetName("graph"); - std::string file_name = "./model.air"; - std::string clear_cmd = "rm -rf " + file_name; - system(clear_cmd.c_str()); - EXPECT_EQ(md.SaveToFile(file_name), GRAPH_SUCCESS); - std::string weight_path = "./air_weight/model_name_main_model/Const_graph_test_const0_file"; - char real_path[1280]; - EXPECT_NE(realpath(weight_path.c_str(), real_path), nullptr); - std::ofstream ofs(real_path, std::ios::out | std::ofstream::app); - const char_t *data = "a"; - if (ofs.is_open()) { - ofs << data << std::endl; - ofs.close(); - } - - Model model_back; - EXPECT_NE(model_back.LoadFromFile("./model.air"), GRAPH_SUCCESS); - system("rm -rf ./air_weight"); - system(clear_cmd.c_str()); -} - -TEST_F(ModelUt, SaveModelWithAscendWorkPath) { - ge::char_t current_path[MMPA_MAX_PATH] = {'\0'}; - getcwd(current_path, MMPA_MAX_PATH); - mmSetEnv("ASCEND_WORK_PATH", current_path, 1); - auto md = BuildModelWithLargeConst(); - std::string file_name = "model.air"; - std::string clear_cmd = "rm -rf " + file_name; - system(clear_cmd.c_str()); - EXPECT_EQ(md.SaveToFile(file_name), GRAPH_SUCCESS); - Model model_back; - std::string file_path = current_path; - file_path += "/" + file_name; - EXPECT_EQ(model_back.LoadFromFile(file_path), GRAPH_SUCCESS); - unsetenv("ASCEND_WORK_PATH"); - system(clear_cmd.c_str()); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/node_unittest.cc b/tests/ut/graph/testcase/node_unittest.cc deleted file mode 100644 index c9f32a87dcb5797d7c9b3dfab92009d493c27b74..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/node_unittest.cc +++ /dev/null @@ -1,402 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - - -#include "graph/node.h" -#include "graph/normal_graph/node_impl.h" -#include "graph/any_value.h" -#include "graph/anchor.h" -#include "graph/op_desc.h" -#include "graph/utils/graph_utils.h" -#include "graph/normal_graph/op_desc_impl.h" -#include "graph_builder_utils.h" -#include "graph/operator_factory_impl.h" -#include "graph/utils/node_utils_ex.h" - -#include - -namespace ge { -class UtestNode : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -template -std::shared_ptr MakeNullptr(){ - return nullptr; -} - -Operator CreateOp(const AscendString& name){ - return Operator(); -} - - -TEST_F(UtestNode, GetInDataAnchor) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data = builder.AddNode("Data", "Data", 1, 1); - auto graph = builder.GetGraph(); - - auto data_node = graph->FindNode("Data"); - auto in_data_anchor0 = data_node->GetInDataAnchor(0); - EXPECT_NE(in_data_anchor0, nullptr); - - auto in_data_anchor1 = data_node->GetInDataAnchor(1); - EXPECT_EQ(in_data_anchor1, nullptr); - - auto in_data_anchor2 = data_node->GetInDataAnchor(-1); - EXPECT_EQ(in_data_anchor2, nullptr); -} -TEST_F(UtestNode, GetInAnchor) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data = builder.AddNode("Data", "Data", 1, 1); - auto graph = builder.GetGraph(); - - auto data_node = graph->FindNode("Data"); - auto in_anchor0 = data_node->GetInAnchor(-2); - EXPECT_EQ(in_anchor0, nullptr); -} -TEST_F(UtestNode, GetOutAnchor) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data = builder.AddNode("Data", "Data", 1, 1); - auto graph = builder.GetGraph(); - - auto data_node = graph->FindNode("Data"); - auto out_anchor0 = data_node->GetOutAnchor(-2); - EXPECT_EQ(out_anchor0, nullptr); -} - -TEST_F(UtestNode, NodeInputAndOutCheck) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto attr_node = builder.AddNode("Attr", "Attr", 2, 2); - { - auto data_node = builder.AddNode("Data", "Data", 1, 1); - auto graph = builder.GetGraph(); - EXPECT_TRUE(data_node->Init() == data_node->Init()); - EXPECT_EQ(data_node->SetOwnerComputeGraph(nullptr), GRAPH_PARAM_INVALID); - EXPECT_EQ(data_node->ClearOwnerGraph(graph), GRAPH_SUCCESS); - EXPECT_EQ(data_node->GetAllInDataAnchors().size(), 1); - EXPECT_EQ(data_node->GetAllOutDataAnchors().size(), 1); - EXPECT_EQ(data_node->NodeMembersAreEqual(*data_node), true); - EXPECT_EQ(data_node->AddLinkFromForParse(attr_node), GRAPH_PARAM_INVALID); - EXPECT_EQ(data_node->AddLinkFrom(attr_node), GRAPH_PARAM_INVALID); - EXPECT_EQ(data_node->AddLinkFrom(2, attr_node), GRAPH_PARAM_INVALID); - EXPECT_EQ(data_node->AddLinkFrom("Attr", attr_node), GRAPH_PARAM_INVALID); - InDataAnchorPtr in_anch = std::make_shared(data_node, 111); - OutDataAnchorPtr out_anch = std::make_shared(data_node, 222); - EXPECT_EQ(data_node->NodeAnchorIsEqual(nullptr, in_anch, 1), false); - EXPECT_EQ(data_node->NodeAnchorIsEqual(in_anch, nullptr, 1), false); - EXPECT_EQ(data_node->NodeAnchorIsEqual(in_anch, out_anch, 1), true); - auto node3 = builder.AddNode("Data3", "Data3", 3, 3); - InControlAnchorPtr inc_anch = std::make_shared(node3, 33); - EXPECT_EQ(out_anch->LinkTo(inc_anch), GRAPH_SUCCESS); - EXPECT_EQ(data_node->NodeAnchorIsEqual(out_anch, inc_anch, 1), false); - EXPECT_EQ(attr_node->AddLinkFrom(data_node), GRAPH_SUCCESS); - EXPECT_EQ(attr_node->AddLinkFromForParse(data_node), GRAPH_SUCCESS); - EXPECT_EQ(attr_node->AddLinkFrom(2, data_node), GRAPH_SUCCESS); - EXPECT_EQ(attr_node->AddLinkFrom("Attr", data_node), GRAPH_SUCCESS); - EXPECT_EQ(data_node->GetOutNodes().size(), 3U); - EXPECT_EQ(data_node->GetOutNodesPtr().size(), 3U); - EXPECT_EQ(data_node->GetOutDataNodes().size(), 3U); - EXPECT_EQ(data_node->GetOutDataNodesSize(), 3U); - EXPECT_EQ(attr_node->GetInNodes().size(), 3U); - EXPECT_EQ(attr_node->GetInNodesPtr().size(), 3U); - EXPECT_EQ(attr_node->GetInNodesSize(), 3U); - EXPECT_EQ(attr_node->GetInDataNodesSize(), 3U); - EXPECT_EQ(attr_node->GetInDataNodes().size(), 3U); - EXPECT_EQ(attr_node->GetInControlNodesSize(), 0U); - EXPECT_EQ(attr_node->GetInControlNodes().size(), 0U); - builder.AddControlEdge(data_node, attr_node); - EXPECT_EQ(attr_node->GetInNodes().size(), 4U); - EXPECT_EQ(attr_node->GetInNodesPtr().size(), 4U); - EXPECT_EQ(attr_node->GetInNodesSize(), 4U); - EXPECT_EQ(attr_node->GetInDataNodesSize(), 3U); - EXPECT_EQ(attr_node->GetInDataNodes().size(), 3U); - EXPECT_EQ(attr_node->GetInControlNodesSize(), 1U); - EXPECT_EQ(attr_node->GetInControlNodes().size(), 1U); - EXPECT_EQ(data_node->GetOutNodes().size(), 4U); - EXPECT_EQ(data_node->GetOutNodesPtr().size(), 4U); - EXPECT_EQ(data_node->GetOutNodesSize(), 4U); - EXPECT_EQ(data_node->GetOutControlNodesSize(), 1U); - EXPECT_EQ(data_node->GetOutDataNodesSize(), 3U); - data_node->impl_->out_data_anchors_.push_back(nullptr); - EXPECT_EQ(data_node->GetOutNodesPtr().size(), 4U); - EXPECT_EQ(data_node->GetOutNodes().size(), 4U); - EXPECT_EQ(GraphUtils::RemoveNodeWithoutRelink(builder.GetGraph(), data_node), GRAPH_SUCCESS); - } - EXPECT_EQ(attr_node->GetInNodes().size(), 0U); - EXPECT_EQ(attr_node->GetInNodesSize(), 0U); - EXPECT_EQ(attr_node->GetInDataNodesSize(), 0U); - EXPECT_EQ(attr_node->GetInDataNodes().size(), 0U); - EXPECT_EQ(attr_node->GetInControlNodesSize(), 0U); - EXPECT_EQ(attr_node->GetInControlNodes().size(), 0U); -} - -TEST_F(UtestNode, NodeInputAndOutBarePtrCheck) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node = builder.AddNode("DataNode", "Data", 1, 1); - auto attr_node = builder.AddNode("AttrNode", "Attr", 2, 2); - data_node->GetOutControlAnchor()->LinkTo(attr_node->GetInControlAnchor()); - data_node->GetOutDataAnchor(0)->LinkTo(attr_node->GetInDataAnchor(0)); - auto graph = builder.GetGraph(); - - EXPECT_EQ(data_node->GetInNodes().size(), 0U); - EXPECT_EQ(data_node->GetInNodesSize(), 0U); - EXPECT_EQ(data_node->GetAllInAnchors().size(), 1U); - EXPECT_EQ(data_node->GetAllInAnchorsPtr().size(), 2U); - EXPECT_EQ(data_node->GetOutNodes().size(), 2U); - EXPECT_EQ(data_node->GetOutNodesSize(), 2U); - EXPECT_EQ(data_node->GetAllOutAnchors().size(), 2U); - EXPECT_EQ(data_node->GetAllOutAnchorsPtr().size(), 2U); - EXPECT_EQ(data_node->GetOutDataNodes().size(), 1U); - EXPECT_EQ(data_node->GetOutDataNodesPtr().size(), 1U); - EXPECT_EQ(data_node->GetOutDataNodesSize(), 1U); - - EXPECT_EQ(attr_node->GetInNodes().size(), 2U); - EXPECT_EQ(attr_node->GetInNodesSize(), 2U); - EXPECT_EQ(attr_node->GetAllInAnchors().size(), 3U); - EXPECT_EQ(attr_node->GetAllInAnchorsPtr().size(), 3U); - EXPECT_EQ(attr_node->GetInDataNodes().size(), 1U); - EXPECT_EQ(attr_node->GetInDataNodesSize(), 1U); - EXPECT_EQ(attr_node->GetOutNodes().size(), 0U); - EXPECT_EQ(attr_node->GetOutNodesSize(), 0U); - EXPECT_EQ(attr_node->GetAllOutAnchors().size(), 2U); - EXPECT_EQ(attr_node->GetAllOutAnchorsPtr().size(), 3U); - EXPECT_EQ(attr_node->GetOutDataNodesPtr().size(), 0U); -} - -TEST_F(UtestNode, GetCase) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node = builder.AddNode("DataNode", "Data", 1, 1); - auto attr_node = builder.AddNode("AttrNode", "Attr", 2, 2); - auto graph = builder.GetGraph(); - EXPECT_EQ(data_node->GetName(), "DataNode"); - EXPECT_EQ(data_node->GetType(), "Data"); - EXPECT_EQ(std::string(attr_node->GetNamePtr()), "AttrNode"); - EXPECT_EQ(std::string(attr_node->GetTypePtr()), "Attr"); - EXPECT_EQ(data_node->GetAllInAnchors().size(), 1); - EXPECT_EQ(attr_node->GetAllOutAnchors().size(), 2); - EXPECT_EQ(data_node->GetInNodes().size(), 0); - EXPECT_EQ(data_node->GetInNodesPtr().size(), 0); - EXPECT_EQ(attr_node->GetOutNodes().size(), 0); - EXPECT_EQ(attr_node->GetOutDataNodes().size(), 0); - EXPECT_EQ(NodeUtilsEx::InferShapeAndType(attr_node), GRAPH_PARAM_INVALID); - EXPECT_EQ(attr_node->GetOutDataNodesAndAnchors().size(), 0); - EXPECT_EQ(data_node->NodeInConnectsAreEqual(*attr_node), false); - EXPECT_EQ(data_node->NodeOutConnectsAreEqual(*attr_node), false); - EXPECT_EQ(attr_node->NodeInConnectsAreEqual(*data_node), false); - EXPECT_EQ(attr_node->NodeOutConnectsAreEqual(*data_node), false); - EXPECT_EQ((*data_node)==(*attr_node), false); - std::unordered_set us; - us.insert(data_node.get()); - EXPECT_EQ(attr_node->IsAllInNodesSeen(us), true); - data_node->AddSendEventId(10); - data_node->AddRecvEventId(20); - EXPECT_EQ(data_node->GetSendEventIdList().size(), 1); - EXPECT_EQ(data_node->GetRecvEventIdList().size(), 1); - kFusionDataFlowVec_t fusion_input_list; - data_node->GetFusionInputFlowList(fusion_input_list); - data_node->SetFusionInputFlowList(fusion_input_list); - EXPECT_EQ(fusion_input_list.size(), 0); - kFusionDataFlowVec_t fusion_output_list; - data_node->GetFusionOutputFlowList(fusion_output_list); - data_node->SetFusionOutputFlowList(fusion_output_list); - EXPECT_EQ(fusion_output_list.size(), 0); - EXPECT_EQ(data_node->GetHostNode(), false); - data_node->SetOrigNode(attr_node); - EXPECT_NE(data_node->GetOrigNode(), nullptr); - OpDescPtr opd = std::make_shared("Opdesc","OpdType"); - EXPECT_EQ(data_node->UpdateOpDesc(opd), GRAPH_PARAM_INVALID); -} - -TEST_F(UtestNode, IsAllInNodesSeenSuccess) { - auto builder = ut::GraphBuilder("graph"); - const auto &node1 = builder.AddNode("node1", "node1", 2, 2); - const auto &node2 = builder.AddNode("node2", "node2", 1, 1); - builder.AddDataEdge(node1, 0, node2, 0); - builder.AddControlEdge(node1, node2); - auto graph = builder.GetGraph(); - std::unordered_set us; - us.insert(node2.get()); - us.insert(node1.get()); - EXPECT_EQ(node2->IsAllInNodesSeen(us), true); -} - -TEST_F(UtestNode, NodeInConnectsAreEqual) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node = builder.AddNode("Data", "Data", 1, 1); - auto attr_node = builder.AddNode("Attr", "Attr", 2, 2); - InDataAnchorPtr in_anch = std::make_shared(data_node, 111); - EXPECT_EQ(data_node->NodeAnchorIsEqual(nullptr, in_anch, 1), false); - data_node->impl_->in_data_anchors_.push_back(in_anch); - EXPECT_EQ(data_node->GetAllInDataAnchors().size(), 2); - EXPECT_EQ(data_node->GetAllInDataAnchorsPtr().size(), 2); - EXPECT_EQ(attr_node->GetAllInDataAnchors().size(), 2); - EXPECT_EQ(data_node->NodeInConnectsAreEqual(*attr_node), true); -} - -TEST_F(UtestNode, NodeOutConnectsAreEqual) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node = builder.AddNode("Data", "Data", 1, 1); - auto attr_node = builder.AddNode("Attr", "Attr", 2, 2); - OutDataAnchorPtr out_anch = std::make_shared(data_node, 111); - EXPECT_EQ(data_node->NodeAnchorIsEqual(nullptr, out_anch, 1), false); - data_node->impl_->out_data_anchors_.push_back(out_anch); - EXPECT_EQ(data_node->GetAllOutDataAnchors().size(), 2); - EXPECT_EQ(data_node->GetAllOutDataAnchorsPtr().size(), 2); - EXPECT_EQ(attr_node->GetAllOutDataAnchors().size(), 2); - EXPECT_EQ(attr_node->GetAllOutDataAnchorsPtr().size(), 2); - EXPECT_EQ(data_node->NodeOutConnectsAreEqual(*attr_node), true); -} - -TEST_F(UtestNode, NodeAnchorIsEqual) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node = builder.AddNode("Data", "Data", 1, 1); - auto attr_node = builder.AddNode("Attr", "Attr", 2, 2); - InDataAnchorPtr in_anch1 = std::make_shared(data_node, 111); - InDataAnchorPtr in_anch2 = std::make_shared(attr_node, 222); - OutDataAnchorPtr out_anch1 = std::make_shared(data_node, 333); - EXPECT_EQ(in_anch1->LinkFrom(out_anch1), GRAPH_SUCCESS); - EXPECT_EQ(data_node->NodeAnchorIsEqual(in_anch1, in_anch2, 2), false); - OutDataAnchorPtr out_anch2 = std::make_shared(nullptr, 444); - EXPECT_EQ(in_anch2->LinkFrom(out_anch2), GRAPH_SUCCESS); - EXPECT_EQ(data_node->NodeAnchorIsEqual(in_anch1, in_anch2, 2), false); -} - -TEST_F(UtestNode, AddLink) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node = builder.AddNode("Data", "Data", 1, 1); - auto attr_node = builder.AddNode("Attr", "Attr", 2, 2); - EXPECT_EQ(attr_node->AddLinkFrom(data_node), GRAPH_SUCCESS); - data_node->impl_->op_->impl_->input_name_idx_["input_name"] = 10; - data_node->impl_->op_->impl_->outputs_desc_.push_back(MakeNullptr()); - auto odesc = data_node->GetOpDesc()->GetOutputDesc(0); - auto another_desc = data_node->GetOpDescBarePtr()->GetOutputDesc(0); - EXPECT_EQ(odesc, another_desc); - attr_node->impl_->op_->impl_->input_name_idx_["__input3"] = 20; - EXPECT_NE(attr_node->impl_->op_->impl_->input_name_idx_.find("__input3"), attr_node->impl_->op_->impl_->input_name_idx_.end()); - EXPECT_EQ(attr_node->impl_->op_->impl_->inputs_desc_.size(), 3); - EXPECT_EQ(attr_node->AddLinkFrom(data_node), GRAPH_FAILED); -} - -TEST_F(UtestNode, AddLinkByIndex) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node = builder.AddNode("Data", "Data", 1, 1); - auto attr_node = builder.AddNode("Attr", "Attr", 2, 2); - InDataAnchorPtr in_anch = std::make_shared(data_node, 111); - OutDataAnchorPtr out_anch = std::make_shared(data_node, 222); - EXPECT_EQ(data_node->NodeAnchorIsEqual(nullptr, in_anch, 1), false); - EXPECT_EQ(data_node->NodeAnchorIsEqual(in_anch, nullptr, 1), false); - EXPECT_EQ(data_node->NodeAnchorIsEqual(in_anch, out_anch, 1), true); - auto node3 = builder.AddNode("Data3", "Data3", 3, 3); - InControlAnchorPtr inc_anch = std::make_shared(node3, 33); - EXPECT_EQ(out_anch->LinkTo(inc_anch), GRAPH_SUCCESS); - EXPECT_EQ(data_node->NodeAnchorIsEqual(out_anch, inc_anch, 1),false); - EXPECT_EQ(attr_node->AddLinkFrom(data_node), GRAPH_SUCCESS); - data_node->impl_->op_->impl_->input_name_idx_["input_name"] = 10; - data_node->impl_->op_->impl_->outputs_desc_.push_back(MakeNullptr()); - auto odesc = data_node->GetOpDesc()->GetOutputDesc(0); - attr_node->impl_->op_->impl_->input_name_idx_["__input3"] = 20; - EXPECT_NE(attr_node->impl_->op_->impl_->input_name_idx_.find("__input3"), attr_node->impl_->op_->impl_->input_name_idx_.end()); - EXPECT_EQ(attr_node->impl_->op_->impl_->inputs_desc_.size(), 3); - EXPECT_EQ(attr_node->AddLinkFrom(11, data_node), GRAPH_FAILED); -} - -TEST_F(UtestNode, AddLinkByString) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node = builder.AddNode("Data", "Data", 1, 1); - auto attr_node = builder.AddNode("Attr", "Attr", 2, 2); - InDataAnchorPtr in_anch = std::make_shared(data_node, 111); - OutDataAnchorPtr out_anch = std::make_shared(data_node, 222); - EXPECT_EQ(data_node->NodeAnchorIsEqual(nullptr, in_anch, 1), false); - EXPECT_EQ(data_node->NodeAnchorIsEqual(in_anch, nullptr, 1), false); - EXPECT_EQ(data_node->NodeAnchorIsEqual(in_anch, out_anch, 1), true); - auto node3 = builder.AddNode("Data3", "Data3", 3, 3); - InControlAnchorPtr inc_anch = std::make_shared(node3, 33); - EXPECT_EQ(out_anch->LinkTo(inc_anch), GRAPH_SUCCESS); - EXPECT_EQ(data_node->NodeAnchorIsEqual(out_anch, inc_anch, 1),false); - EXPECT_EQ(attr_node->AddLinkFrom(data_node), GRAPH_SUCCESS); - data_node->impl_->op_->impl_->input_name_idx_["input_name"] = 10; - data_node->impl_->op_->impl_->outputs_desc_.push_back(MakeNullptr()); - auto odesc = data_node->GetOpDesc()->GetOutputDesc(0); - attr_node->impl_->op_->impl_->input_name_idx_["__input3"] = 20; - EXPECT_NE(attr_node->impl_->op_->impl_->input_name_idx_.find("__input3"), attr_node->impl_->op_->impl_->input_name_idx_.end()); - EXPECT_EQ(attr_node->impl_->op_->impl_->inputs_desc_.size(), 3); - EXPECT_EQ(attr_node->AddLinkFrom("__input3", data_node), GRAPH_FAILED); - attr_node->impl_->op_->impl_->input_name_idx_["__input_succ"] = 5; - EXPECT_EQ(attr_node->impl_->op_->impl_->inputs_desc_.size(), 3); - EXPECT_NE(attr_node->impl_->op_->impl_->input_name_idx_.find("__input_succ"), attr_node->impl_->op_->impl_->input_name_idx_.end()); - EXPECT_EQ(attr_node->AddLinkFrom("__input_succ", data_node), GRAPH_FAILED); -} - -TEST_F(UtestNode, AddLinkByStringInputDescFailure) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node = builder.AddNode("Data", "Data", 1, 1); - auto attr_node = builder.AddNode("Attr", "Attr", 2, 2); - InDataAnchorPtr in_anch = std::make_shared(data_node, 111); - OutDataAnchorPtr out_anch = std::make_shared(data_node, 222); - EXPECT_EQ(data_node->NodeAnchorIsEqual(nullptr, in_anch, 1), false); - EXPECT_EQ(data_node->NodeAnchorIsEqual(in_anch, nullptr, 1), false); - EXPECT_EQ(data_node->NodeAnchorIsEqual(in_anch, out_anch, 1), true); - auto node3 = builder.AddNode("Data3", "Data3", 3, 3); - InControlAnchorPtr inc_anch = std::make_shared(node3, 33); - EXPECT_EQ(out_anch->LinkTo(inc_anch), GRAPH_SUCCESS); - EXPECT_EQ(data_node->NodeAnchorIsEqual(out_anch, inc_anch, 1),false); - EXPECT_EQ(attr_node->AddLinkFrom(data_node), GRAPH_SUCCESS); - data_node->impl_->op_->impl_->input_name_idx_["input_name"] = 10; - data_node->impl_->op_->impl_->outputs_desc_.push_back(nullptr); - auto odesc = data_node->GetOpDesc()->GetOutputDesc(0); - attr_node->impl_->op_->impl_->input_name_idx_["__input5"] = -1; - auto it = attr_node->impl_->op_->impl_->input_name_idx_.find("__input5"); - EXPECT_NE(it, attr_node->impl_->op_->impl_->input_name_idx_.end()); - EXPECT_EQ(it->second, -1); - EXPECT_EQ(attr_node->impl_->op_->impl_->inputs_desc_.size(), 3); - EXPECT_EQ(attr_node->impl_->op_->impl_->AddInputDesc("__input5", odesc), GRAPH_FAILED); - EXPECT_EQ(attr_node->AddLinkFrom("__input5", data_node), GRAPH_FAILED); -} - -TEST_F(UtestNode, Verify) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node = builder.AddNode("Data", "Data", 1, 1); - data_node->impl_->in_data_anchors_.push_back(nullptr); - EXPECT_EQ(NodeUtilsEx::Verify(data_node), GRAPH_SUCCESS); - auto node_op = ge::OperatorFactoryImpl::CreateOperator("node_op", data_node->impl_->op_->GetType()); - EXPECT_NE(OperatorFactoryImpl::operator_creators_v2_, nullptr); - std::map mapv2; - mapv2 = *OperatorFactoryImpl::operator_creators_v2_; - mapv2["Data"] = CreateOp; - EXPECT_EQ(data_node->impl_->op_->GetType(), "Data"); - EXPECT_EQ(node_op.IsEmpty(), true); - auto node_op2 = ge::OperatorFactoryImpl::CreateOperator("node_op", data_node->impl_->op_->GetType()); - EXPECT_EQ(node_op2.IsEmpty(), true); -} - -TEST_F(UtestNode, InferShapeAndType_failed) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node = builder.AddNode("Data", "Data", 1, 1); - data_node->impl_->in_data_anchors_.push_back(nullptr); - EXPECT_EQ(NodeUtilsEx::InferShapeAndType(data_node), GRAPH_PARAM_INVALID); -} - -TEST_F(UtestNode, GetOutControlNodes) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node1 = builder.AddNode("Data1", "Data", 1, 1); - auto data_node2 = builder.AddNode("Data2", "Data", 1, 1); - EXPECT_EQ(data_node1->GetOutControlAnchor()->LinkTo(data_node2->GetInControlAnchor()), GRAPH_SUCCESS); - EXPECT_EQ(data_node1->GetOutControlNodes().size(), 1); - EXPECT_EQ(data_node1->GetOutControlNodesSize(), 1); - EXPECT_EQ(data_node1->GetOutNodesSize(), 1); - EXPECT_EQ(data_node1->GetOutDataAnchor(0)->LinkTo(data_node2->GetInDataAnchor(0)), GRAPH_SUCCESS); - EXPECT_EQ(data_node1->GetOutDataNodesSize(), 1); - EXPECT_EQ(data_node1->GetOutNodesSize(), 2); -} -} diff --git a/tests/ut/graph/testcase/node_utils_unittest.cc b/tests/ut/graph/testcase/node_utils_unittest.cc deleted file mode 100644 index 2b884181ea60e14a8048e65583c03842238e1bf9..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/node_utils_unittest.cc +++ /dev/null @@ -1,1174 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include "graph/utils/node_utils.h" -#include "graph/utils/node_utils_ex.h" -#include "graph/normal_graph/node_impl.h" -#include "graph/normal_graph/op_desc_impl.h" -#include "graph_builder_utils.h" -#include "graph/debug/ge_op_types.h" -#include "graph/operator_factory_impl.h" -#include "graph/normal_graph/compute_graph_impl.h" -#include "graph/anchor.h" -#include "graph/debug/ge_attr_define.h" - -#include "auto_mapping_util.h" -#include "graph/operator_reg.h" - -#include - -namespace ge { -class UtestNodeUtils : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -namespace { -REG_OP(FakeData) -.INPUT(x, TensorType::NumberType()) -.OUTPUT(y, TensorType::NumberType()) -.OP_END_FACTORY_REG(FakeData); - -REG_OP(FakeOpNoOutput) -.INPUT(x, TensorType::NumberType()) -.OP_END_FACTORY_REG(FakeOpNoOutput); - -template -std::shared_ptr MakeNullptr(){ - return nullptr; -} - -/* ------------------------- -* | partitioncall_0_const1* | -* partitioncall_0--------------| | | -* | | netoutput | -* | -------------------------- -* | ------------------ ------------- -* | | data | | data | -* | | | | | | | -* partitioncall_1--------------| case -----|-------| squeeze* | -* | | | | | | -* | netoutput | | netoutput | -* ------------------ ------------- -*/ -ComputeGraphPtr BuildGraphPartitionCall() { - auto root_builder = ut::GraphBuilder("root"); - const auto &partitioncall_0 = root_builder.AddNode("partitioncall_0", PARTITIONEDCALL, 0, 1); - const auto &partitioncall_1 = root_builder.AddNode("partitioncall_1", PARTITIONEDCALL, 1, 1); - root_builder.AddDataEdge(partitioncall_0, 0, partitioncall_1, 0); - const auto &root_graph = root_builder.GetGraph(); - - // 1.build partitioncall_0 sub graph - auto p1_sub_builder = ut::GraphBuilder("partitioncall_0_sub"); - const auto &partitioncall_0_const1 = p1_sub_builder.AddNode("partitioncall_0_const1", CONSTANT, 0, 1); - const auto &partitioncall_0_netoutput = p1_sub_builder.AddNode("partitioncall_0_netoutput", NETOUTPUT, 1, 1); - AttrUtils::SetInt(partitioncall_0_netoutput->GetOpDesc()->MutableInputDesc(0), "_parent_node_index", 0); - p1_sub_builder.AddDataEdge(partitioncall_0_const1, 0, partitioncall_0_netoutput, 0); - const auto &sub_graph = p1_sub_builder.GetGraph(); - sub_graph->SetParentNode(partitioncall_0); - sub_graph->SetParentGraph(root_graph); - partitioncall_0->GetOpDesc()->AddSubgraphName("f"); - partitioncall_0->GetOpDesc()->SetSubgraphInstanceName(0, "partitioncall_0_sub"); - - // 2.build partitioncall_1 sub graph - auto p2_sub_builder = ut::GraphBuilder("partitioncall_1_sub"); - const auto &partitioncall_1_data = p2_sub_builder.AddNode("partitioncall_1_data", DATA, 0, 1); - AttrUtils::SetInt(partitioncall_1_data->GetOpDesc(), "_parent_node_index", 0); - const auto &partitioncall_1_case = p2_sub_builder.AddNode("partitioncall_1_case", "Case", 1, 1); - const auto &partitioncall_1_netoutput = p2_sub_builder.AddNode("partitioncall_1_netoutput", NETOUTPUT, 1, 1); - p2_sub_builder.AddDataEdge(partitioncall_1_data, 0, partitioncall_1_case, 0); - p2_sub_builder.AddDataEdge(partitioncall_1_case, 0, partitioncall_1_netoutput, 0); - const auto &sub_graph2 = p2_sub_builder.GetGraph(); - sub_graph2->SetParentNode(partitioncall_1); - sub_graph2->SetParentGraph(root_graph); - partitioncall_1->GetOpDesc()->AddSubgraphName("f"); - partitioncall_1->GetOpDesc()->SetSubgraphInstanceName(0, "partitioncall_1_sub"); - - // 2.1 build case sub graph - auto case_sub_builder = ut::GraphBuilder("case_sub"); - const auto &case_data = case_sub_builder.AddNode("case_data", DATA, 0, 1); - AttrUtils::SetInt(case_data->GetOpDesc(), "_parent_node_index", 0); - const auto &case_squeeze = case_sub_builder.AddNode("case_squeeze", SQUEEZE, 1, 1); - const auto &case_netoutput = case_sub_builder.AddNode("case_netoutput", NETOUTPUT, 1, 1); - case_sub_builder.AddDataEdge(case_data, 0, case_squeeze, 0); - case_sub_builder.AddDataEdge(case_squeeze, 0, case_netoutput, 0); - const auto &case_sub_graph = case_sub_builder.GetGraph(); - case_sub_graph->SetParentNode(partitioncall_1_case); - case_sub_graph->SetParentGraph(sub_graph2); - partitioncall_1_case->GetOpDesc()->AddSubgraphName("branches"); - partitioncall_1_case->GetOpDesc()->SetSubgraphInstanceName(0, "case_sub"); - - root_graph->AddSubgraph(case_sub_graph->GetName(), case_sub_graph); - root_graph->AddSubgraph(sub_graph->GetName(), sub_graph); - root_graph->AddSubgraph(sub_graph2->GetName(), sub_graph2); - return root_graph; -} -/* ------------------------- -* | data0 | -* data | | | -* | | cast | -* partitioncall_0--------------| | | -* | | netoutput | -* | -------------------------- -* | ------------------ -* | | data1 | -* | | | | -* partitioncall_1--------------| squeeze | -* | | | -* | netoutput | -* ------------------ -*/ -ComputeGraphPtr BuildGraphPartitionCall2() { - auto root_builder = ut::GraphBuilder("root"); - const auto &data = root_builder.AddNode("data", DATA, 1, 1); - const auto &partitioncall_0 = root_builder.AddNode("partitioncall_0", PARTITIONEDCALL, 3, 3); - const auto &partitioncall_1 = root_builder.AddNode("partitioncall_1", PARTITIONEDCALL, 1, 1); - root_builder.AddDataEdge(data, 0, partitioncall_0, 1); - root_builder.AddDataEdge(partitioncall_0, 1, partitioncall_1, 0); - const auto &root_graph = root_builder.GetGraph(); - - // 1.build partitioncall_0 sub graph - auto p1_sub_builder = ut::GraphBuilder("partitioncall_0_sub"); - const auto &partitioncall_0_data = p1_sub_builder.AddNode("partitioncall_0_data", DATA, 0, 1); - AttrUtils::SetInt(partitioncall_0_data->GetOpDesc(), "_parent_node_index", 1); - const auto &partitioncall_0_cast = p1_sub_builder.AddNode("partitioncall_0_cast", "Cast", 1, 1); - const auto &partitioncall_0_netoutput = p1_sub_builder.AddNode("partitioncall_0_netoutput", NETOUTPUT, 3, 3); - AttrUtils::SetInt(partitioncall_0_netoutput->GetOpDesc()->MutableInputDesc(0), "_parent_node_index", 0); - AttrUtils::SetInt(partitioncall_0_netoutput->GetOpDesc()->MutableInputDesc(1), "_parent_node_index", 1); - AttrUtils::SetInt(partitioncall_0_netoutput->GetOpDesc()->MutableInputDesc(2), "_parent_node_index", 2); - p1_sub_builder.AddDataEdge(partitioncall_0_data, 0, partitioncall_0_cast, 0); - p1_sub_builder.AddDataEdge(partitioncall_0_cast, 0, partitioncall_0_netoutput, 1); - const auto &sub_graph = p1_sub_builder.GetGraph(); - sub_graph->SetParentNode(partitioncall_0); - sub_graph->SetParentGraph(root_graph); - partitioncall_0->GetOpDesc()->AddSubgraphName("f"); - partitioncall_0->GetOpDesc()->SetSubgraphInstanceName(0, "partitioncall_0_sub"); - - // 2.build partitioncall_1 sub graph - auto p2_sub_builder = ut::GraphBuilder("partitioncall_1_sub"); - const auto &partitioncall_1_data = p2_sub_builder.AddNode("partitioncall_1_data", DATA, 0, 1); - AttrUtils::SetInt(partitioncall_1_data->GetOpDesc(), "_parent_node_index", 0); - const auto &partitioncall_1_squeeze = p2_sub_builder.AddNode("partitioncall_1_squeeze", SQUEEZE, 1, 1); - const auto &partitioncall_1_netoutput = p2_sub_builder.AddNode("partitioncall_1_netoutput", NETOUTPUT, 1, 1); - p2_sub_builder.AddDataEdge(partitioncall_1_data, 0, partitioncall_1_squeeze, 0); - p2_sub_builder.AddDataEdge(partitioncall_1_squeeze, 0, partitioncall_1_netoutput, 0); - const auto &sub_graph2 = p2_sub_builder.GetGraph(); - sub_graph2->SetParentNode(partitioncall_1); - sub_graph2->SetParentGraph(root_graph); - partitioncall_1->GetOpDesc()->AddSubgraphName("f"); - partitioncall_1->GetOpDesc()->SetSubgraphInstanceName(0, "partitioncall_1_sub"); - - root_graph->AddSubgraph(sub_graph->GetName(), sub_graph); - root_graph->AddSubgraph(sub_graph2->GetName(), sub_graph2); - return root_graph; -} -/* ------------------------- ------------------- -* | partitioncall_2---------------------| Mul | -* partitioncall_0--------------| | | | | | -* | | netoutput | | netoutput | -* | -------------------------- ------------------ -* | ------------- -* | | data | -* | | | | -* partitioncall_1--------------| squeeze* | -* | | | -* | netoutput | -* ------------- -*/ -ComputeGraphPtr BuildGraphPartitionCall3() { - auto root_builder = ut::GraphBuilder("root"); - const auto &partitioncall_0 = root_builder.AddNode("partitioncall_0", PARTITIONEDCALL, 1, 1); - const auto &partitioncall_1 = root_builder.AddNode("partitioncall_1", PARTITIONEDCALL, 1, 1); - root_builder.AddDataEdge(partitioncall_0, 0, partitioncall_1, 0); - const auto &root_graph = root_builder.GetGraph(); - - // 1.build partitioncall_0 sub graph - auto p1_sub_builder = ut::GraphBuilder("partitioncall_0_sub"); - const auto &partitioncall_2 = p1_sub_builder.AddNode("partitioncall_2", PARTITIONEDCALL, 0, 1); - const auto &partitioncall_0_netoutput = p1_sub_builder.AddNode("partitioncall_0_netoutput", NETOUTPUT, 1, 1); - AttrUtils::SetInt(partitioncall_0_netoutput->GetOpDesc()->MutableInputDesc(0), "_parent_node_index", 0); - p1_sub_builder.AddDataEdge(partitioncall_2, 0, partitioncall_0_netoutput, 0); - const auto &sub_graph = p1_sub_builder.GetGraph(); - sub_graph->SetParentNode(partitioncall_0); - sub_graph->SetParentGraph(root_graph); - partitioncall_0->GetOpDesc()->AddSubgraphName("sub0"); - partitioncall_0->GetOpDesc()->SetSubgraphInstanceName(0, "partitioncall_0_sub"); - - // 2.build partitioncall_1 sub graph - auto p2_sub_builder = ut::GraphBuilder("partitioncall_1_sub"); - const auto &partitioncall_1_data = p2_sub_builder.AddNode("partitioncall_1_data", DATA, 0, 1); - AttrUtils::SetInt(partitioncall_1_data->GetOpDesc(), "_parent_node_index", 0); - const auto &partitioncall_1_squeeze = p2_sub_builder.AddNode("partitioncall_1_squeeze", SQUEEZE, 1, 1); - const auto &partitioncall_1_netoutput = p2_sub_builder.AddNode("partitioncall_1_netoutput", NETOUTPUT, 1, 1); - p2_sub_builder.AddDataEdge(partitioncall_1_data, 0, partitioncall_1_squeeze, 0); - p2_sub_builder.AddDataEdge(partitioncall_1_squeeze, 0, partitioncall_1_netoutput, 0); - const auto &sub_graph2 = p2_sub_builder.GetGraph(); - sub_graph2->SetParentNode(partitioncall_1); - sub_graph2->SetParentGraph(root_graph); - partitioncall_1->GetOpDesc()->AddSubgraphName("sub1"); - partitioncall_1->GetOpDesc()->SetSubgraphInstanceName(0, "partitioncall_1_sub"); - - // 3 build partitioncall_2 sub graph - auto p3_sub_builder = ut::GraphBuilder("partitioncall_2_sub"); - const auto &partitioncall_2_mul = p3_sub_builder.AddNode("partitioncall_2_mul", "Mul", 0, 1); - const auto &partitioncall_2_netoutput = p3_sub_builder.AddNode("partitioncall_2_netoutput", NETOUTPUT, 1, 1); - AttrUtils::SetInt(partitioncall_2_netoutput->GetOpDesc()->MutableInputDesc(0), "_parent_node_index", 0); - p3_sub_builder.AddDataEdge(partitioncall_2_mul, 0, partitioncall_2_netoutput, 0); - const auto &sub_graph3 = p3_sub_builder.GetGraph(); - sub_graph3->SetParentNode(partitioncall_2); - sub_graph3->SetParentGraph(sub_graph); - partitioncall_2->GetOpDesc()->AddSubgraphName("sub2"); - partitioncall_2->GetOpDesc()->SetSubgraphInstanceName(0, "partitioncall_2_sub"); - - root_graph->AddSubgraph(sub_graph->GetName(), sub_graph); - root_graph->AddSubgraph(sub_graph2->GetName(), sub_graph2); - root_graph->AddSubgraph(sub_graph3->GetName(), sub_graph3); - return root_graph; -} - -/* -* | constant | -* data partitioncall_0--------------| | | -* \ / | netoutput | -* concat -*/ -ComputeGraphPtr BuildGraphPartitionCall4() { - auto root_builder = ut::GraphBuilder("root"); - const auto &data = root_builder.AddNode("data", DATA, 1, 1); - const auto &partitioncall_0 = root_builder.AddNode("partitioncall_0", PARTITIONEDCALL, 3, 3); - const auto &concat = root_builder.AddNode("concat", "Concat", 2, 1); - root_builder.AddDataEdge(data, 0, concat, 0); - root_builder.AddDataEdge(partitioncall_0, 0, concat, 1); - const auto &root_graph = root_builder.GetGraph(); - - // 1.build partitioncall_0 sub graph - auto p1_sub_builder = ut::GraphBuilder("partitioncall_0_sub"); - const auto &partitioncall_0_const = p1_sub_builder.AddNode("partitioncall_0_constant", CONSTANTOP, 0, 1); - const auto &partitioncall_0_netoutput = p1_sub_builder.AddNode("partitioncall_0_netoutput", NETOUTPUT, 1, 1); - AttrUtils::SetInt(partitioncall_0_netoutput->GetOpDesc()->MutableInputDesc(0), "_parent_node_index", 0); - p1_sub_builder.AddDataEdge(partitioncall_0_const, 0, partitioncall_0_netoutput, 0); - const auto &sub_graph = p1_sub_builder.GetGraph(); - sub_graph->SetParentNode(partitioncall_0); - sub_graph->SetParentGraph(root_graph); - partitioncall_0->GetOpDesc()->AddSubgraphName("f"); - partitioncall_0->GetOpDesc()->SetSubgraphInstanceName(0, "partitioncall_0_sub"); - root_graph->AddSubgraph(sub_graph->GetName(), sub_graph); - return root_graph; -} - -/** - * var---->identity--->cast--->netoutput - * \ || - * \ || - * \ \/ - * const->assgin - */ -ComputeGraphPtr BuildGraphWithUsefulIdentity() { - auto builder = ut::GraphBuilder("test"); - // id1 is useful - auto id1 = builder.AddNode("id1", IDENTITY, 1, 1); - auto var0 = builder.AddNode("var0", VARIABLE, 1, 1); - auto const0 = builder.AddNode("const0", CONSTANT, 1, 1); - auto cast = builder.AddNode("cast", "CAST", 1, 1); - auto ref_node = builder.AddNode("ref_node", ASSIGN, 2, 1); - ref_node->GetOpDesc()->UpdateInputName({{"ref", 0}, {"value", 1}}); - ref_node->GetOpDesc()->UpdateOutputName({{"ref", 0}}); - auto netoutput_node = builder.AddNode("netoutput", NETOUTPUT, 1, 1); - - builder.AddDataEdge(var0, 0, id1, 0); - builder.AddDataEdge(id1, 0, cast, 0); - builder.AddDataEdge(cast, 0, netoutput_node, 0); - builder.AddDataEdge(var0, 0, ref_node, 0); - builder.AddDataEdge(const0, 0, ref_node, 1); - builder.AddControlEdge(id1, ref_node); - return builder.GetGraph(); -} -} - -TEST_F(UtestNodeUtils, UpdateOriginShapeAndShape) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data1 = builder.AddNode("Data1", "Data", 1, 1); - auto data2 = builder.AddNode("Data2", "Data", 1, 1); - - vector dims = {1, 2}; - GeShape data_shape(dims); - ASSERT_EQ(NodeUtils::UpdateInputOriginalShapeAndShape(*data1, 0, data_shape), GRAPH_SUCCESS); - ASSERT_EQ(NodeUtils::UpdateOutputOriginalShapeAndShape(*data1, 0, data_shape), GRAPH_SUCCESS); - ASSERT_EQ(NodeUtils::UpdateInputOriginalShapeAndShape(*data2, 0, data_shape), GRAPH_SUCCESS); - ASSERT_EQ(NodeUtils::UpdateOutputOriginalShapeAndShape(*data2, 0, data_shape), GRAPH_SUCCESS); - ASSERT_EQ(data1->GetOpDesc()->GetInputDesc(0).GetShape() == data1->GetOpDesc()->GetInputDesc(0).GetShape(), true); - ASSERT_EQ(data1->GetOpDesc()->GetInputDesc(0).IsOriginShapeInitialized(), true); -} - -TEST_F(UtestNodeUtils, GetSubgraphs) { - auto root_builder = ut::GraphBuilder("root"); - const auto &case0 = root_builder.AddNode("case0", "Case", 0, 0); - const auto &root_graph = root_builder.GetGraph(); - - auto sub_builder1 = ut::GraphBuilder("sub1"); - const auto &case1 = sub_builder1.AddNode("case1", "Case", 0, 0); - const auto &sub_graph1 = sub_builder1.GetGraph(); - root_graph->AddSubGraph(sub_graph1); - sub_graph1->SetParentNode(case0); - sub_graph1->SetParentGraph(root_graph); - case0->GetOpDesc()->AddSubgraphName("branch1"); - case0->GetOpDesc()->SetSubgraphInstanceName(0, "sub1"); - - std::vector subgraphs0; - ASSERT_EQ(NodeUtils::GetDirectSubgraphs(case0, subgraphs0), GRAPH_SUCCESS); - ASSERT_EQ(subgraphs0.size(), 1); - std::vector subgraphs1; - ASSERT_EQ(NodeUtils::GetDirectSubgraphs(case1, subgraphs1), GRAPH_SUCCESS); - ASSERT_TRUE(subgraphs1.empty()); -} - -TEST_F(UtestNodeUtils, GetSubgraphs_nullptr_node) { - std::vector subgraphs; - ASSERT_NE(NodeUtils::GetDirectSubgraphs(nullptr, subgraphs), GRAPH_SUCCESS); - ASSERT_TRUE(subgraphs.empty()); -} - -TEST_F(UtestNodeUtils, GetSubgraphs_nullptr_root_graph) { - auto builder = ut::GraphBuilder("graph"); - const auto &node = builder.AddNode("node", "node", 0, 0); - node->impl_->owner_graph_.reset(); - - std::vector subgraphs; - ASSERT_NE(NodeUtils::GetDirectSubgraphs(node, subgraphs), GRAPH_SUCCESS); - ASSERT_TRUE(subgraphs.empty()); -} - -TEST_F(UtestNodeUtils, GetSubgraphs_nullptr_sub_graph) { - auto root_builder = ut::GraphBuilder("root"); - const auto &node = root_builder.AddNode("node", "node", 0, 0); - const auto &root_graph = root_builder.GetGraph(); - - auto sub_builder = ut::GraphBuilder("sub"); - const auto &sub_graph = sub_builder.GetGraph(); - sub_graph->SetParentNode(node); - sub_graph->SetParentGraph(root_graph); - node->GetOpDesc()->AddSubgraphName("branch1"); - node->GetOpDesc()->SetSubgraphInstanceName(0, "sub"); - - std::vector subgraphs; - ASSERT_EQ(NodeUtils::GetDirectSubgraphs(node, subgraphs), GRAPH_SUCCESS); - ASSERT_TRUE(subgraphs.empty()); -} - -TEST_F(UtestNodeUtils, GetNodeUnknownShapeStatus_success) { - auto root_builder = ut::GraphBuilder("root"); - const auto &case0 = root_builder.AddNode("case0", "Case", 0, 0); - const auto &root_graph = root_builder.GetGraph(); - auto sub_builder1 = ut::GraphBuilder("sub1"); - const auto &case1 = sub_builder1.AddNode("case1", "Case", 0, 0); - const auto &sub_graph1 = sub_builder1.GetGraph(); - root_graph->AddSubGraph(sub_graph1); - sub_graph1->SetParentNode(case0); - sub_graph1->SetParentGraph(root_graph); - case0->GetOpDesc()->AddSubgraphName("branch1"); - case0->GetOpDesc()->SetSubgraphInstanceName(0, "sub1"); - bool is_known = false; - ASSERT_EQ(NodeUtils::GetNodeUnknownShapeStatus(*case0, is_known), GRAPH_SUCCESS); -} - - -TEST_F(UtestNodeUtils, GetInNodeCrossPartionedCallNode_cross_one_subgraph) { - auto graph = BuildGraphPartitionCall4(); - NodePtr expect_peer_node; - NodePtr concat_node; - for (auto &node : graph->GetAllNodes()) { - if (node->GetType() == "Concat") { - concat_node = node; - } - } - auto ret = NodeUtils::GetInNodeCrossPartionedCallNode(concat_node, 1 , expect_peer_node); - ASSERT_EQ(ret, GRAPH_SUCCESS); - ASSERT_NE(expect_peer_node, nullptr); - ASSERT_EQ(expect_peer_node->GetName(), "partitioncall_0_constant"); -} - -TEST_F(UtestNodeUtils, GetInNodeCrossPartionedCallNode_subgraph_in_partitioncall) { - auto graph = BuildGraphPartitionCall(); - NodePtr expect_peer_node; - NodePtr squeeze_node; - for (auto &node : graph->GetAllNodes()) { - if (node->GetType() == SQUEEZE) { - squeeze_node = node; - } - } - auto ret = NodeUtils::GetInNodeCrossPartionedCallNode(squeeze_node, 0 , expect_peer_node); - ASSERT_EQ(ret, GRAPH_SUCCESS); - ASSERT_NE(expect_peer_node, nullptr); - ASSERT_EQ(expect_peer_node->GetName(), "partitioncall_0_const1"); -} - - /// A(PartionedCall_0)->B(PartionedCall_1) - /// PartionedCall_0's subgraph: Data->A->Netoutput - /// PartionedCall_1's subgraph: Data1->B->Netoutput - /// If it is called like GetInNodeCrossPartionCallNode(B,0,peer_node)or(Data1,0,peer_node), peer_node is A -TEST_F(UtestNodeUtils, GetInNodeCrossPartionedCallNode_paritioncall_link_partitioncall) { - auto graph = BuildGraphPartitionCall2(); - NodePtr expect_peer_node = nullptr; - NodePtr squeeze_node; - NodePtr data_in_root; - NodePtr data_in_partition0; - NodePtr data_in_partition1; - NodePtr partitioncall_0; - for (auto &node : graph->GetAllNodes()) { - if (node->GetType() == SQUEEZE) { - squeeze_node = node; - } - if (node->GetName() == "data") { - data_in_root = node; - } - if (node->GetName() == "partitioncall_0_data") { - data_in_partition0 = node; - } - if (node->GetName() == "partitioncall_1_data") { - data_in_partition1 = node; - } - if (node->GetName() == "partitioncall_0") { - partitioncall_0 = node; - } - } - ASSERT_EQ(NodeUtils::GetInNodeCrossSubgraph(data_in_partition0), data_in_root); - ASSERT_EQ(NodeUtils::GetInNodeCrossSubgraph(data_in_partition1), partitioncall_0); - - // test with src node - auto ret = NodeUtils::GetInNodeCrossPartionedCallNode(squeeze_node, 0 , expect_peer_node); - ASSERT_EQ(ret, GRAPH_SUCCESS); - ASSERT_NE(expect_peer_node, nullptr); - ASSERT_EQ(expect_peer_node->GetName(), "partitioncall_0_cast"); - - // test subgraph_data node - ret = NodeUtils::GetInNodeCrossPartionedCallNode(data_in_partition1, 0 , expect_peer_node); - ASSERT_EQ(ret, GRAPH_SUCCESS); - ASSERT_NE(expect_peer_node, nullptr); - ASSERT_EQ(expect_peer_node->GetName(), "partitioncall_0_cast"); - - // test peer_node is root_data node - ret = NodeUtils::GetInNodeCrossPartionedCallNode(partitioncall_0, 1 , expect_peer_node); - ASSERT_EQ(ret, GRAPH_SUCCESS); - ASSERT_EQ(expect_peer_node, nullptr); -} - -TEST_F(UtestNodeUtils, GetInNodeCrossPartionedCallNode_multi_partitioncall) { - auto graph = BuildGraphPartitionCall3(); - NodePtr expect_peer_node; - NodePtr squeeze_node; - for (auto &node : graph->GetAllNodes()) { - if (node->GetType() == SQUEEZE) { - squeeze_node = node; - } - } - auto ret = NodeUtils::GetInNodeCrossPartionedCallNode(squeeze_node, 0 , expect_peer_node); - ASSERT_EQ(ret, GRAPH_SUCCESS); - ASSERT_NE(expect_peer_node, nullptr); - ASSERT_EQ(expect_peer_node->GetName(), "partitioncall_2_mul"); -} - -TEST_F(UtestNodeUtils, GetInNodeCrossPartionedCallNode_temp_test_return_success_when_peer_node_null) { - auto graph = BuildGraphPartitionCall3(); - NodePtr expect_peer_node; - NodePtr partition_node; - for (auto &node : graph->GetAllNodes()) { - if (node->GetName() == "partitioncall_0") { - partition_node = node; - } - } - auto ret = NodeUtils::GetInNodeCrossPartionedCallNode(partition_node, 0 , expect_peer_node); - ASSERT_EQ(ret, GRAPH_SUCCESS); - ASSERT_EQ(expect_peer_node, nullptr); -} - -TEST_F(UtestNodeUtils, GetConstOpType_CONST) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data = builder.AddNode("const1", CONSTANT, 0, 1); - std::cout << data->GetType() << std::endl; - std::string op_type; - auto ret = NodeUtils::GetConstOpType(data, op_type); - ASSERT_EQ(ret, true); - ASSERT_EQ(op_type, "Const"); -} - -TEST_F(UtestNodeUtils, GetConstOpType_DATA) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data = builder.AddNode("Data", "Data", 0, 1); - std::cout << data->GetType() << std::endl; - std::string op_type; - auto ret = NodeUtils::GetConstOpType(data, op_type); - ASSERT_EQ(ret, false); -} - -TEST_F(UtestNodeUtils, GetNodeUnknownShapeStatus) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data = builder.AddNode("add", "Add", 2, 1, FORMAT_NHWC, DT_FLOAT, {16, 228, 228, 3}); - auto graph = builder.GetGraph(); - - auto add_node = graph->FindNode("add"); - ASSERT_NE(add_node, nullptr); - bool is_unknown = false; - auto ret = NodeUtils::GetNodeUnknownShapeStatus(*add_node, is_unknown); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_FALSE(is_unknown); - - ASSERT_NE(add_node->GetOpDesc(), nullptr); - auto out_desc = add_node->GetOpDesc()->MutableOutputDesc(0); - ASSERT_NE(out_desc, nullptr); - out_desc->SetShape(GeShape({-1, 228, 228, 3})); - is_unknown = false; - (void)NodeUtils::GetNodeUnknownShapeStatus(*add_node, is_unknown); - EXPECT_EQ(is_unknown, true); - out_desc->SetShape(GeShape({-2})); - is_unknown = false; - (void)NodeUtils::GetNodeUnknownShapeStatus(*add_node, is_unknown); - EXPECT_EQ(is_unknown, true); - - auto in_desc = add_node->GetOpDesc()->MutableInputDesc(0); - ASSERT_NE(in_desc, nullptr); - in_desc->SetShape(GeShape({-1, 228, 228, 3})); - is_unknown = false; - (void)NodeUtils::GetNodeUnknownShapeStatus(*add_node, is_unknown); - EXPECT_EQ(is_unknown, true); - in_desc->SetShape(GeShape({-2})); - is_unknown = false; - (void)NodeUtils::GetNodeUnknownShapeStatus(*add_node, is_unknown); - EXPECT_EQ(is_unknown, true); -} - -TEST_F(UtestNodeUtils, ClearInDataAnchor) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data = builder.AddNode("Data", "Data", 1, 1); - InDataAnchorPtr in_anch = std::make_shared(data, 111); - EXPECT_EQ(NodeUtils::ClearInDataAnchor(data, in_anch), GRAPH_FAILED); - - auto const1 = builder.AddNode("const1", "Const", 1, 1); - auto const2 = builder.AddNode("const2", "Const", 1, 1); - InDataAnchorPtr in_anch1 = std::make_shared(const1, 111); - InDataAnchorPtr in_anch2 = std::make_shared(const2, 111); - OutDataAnchorPtr out_anch = std::make_shared(const2, 222); - EXPECT_EQ(const1->AddLinkFrom(const2), GRAPH_SUCCESS); - EXPECT_EQ(const2->GetOutDataNodes().size(), 1); - EXPECT_EQ(const1->impl_->in_data_anchors_.size(), 2); - auto anch = const1->impl_->in_data_anchors_.at(0); - EXPECT_EQ(NodeUtils::ClearInDataAnchor(const1, in_anch2), GRAPH_FAILED); - EXPECT_EQ(NodeUtils::ClearInDataAnchor(const1, anch), GRAPH_SUCCESS); -} - -TEST_F(UtestNodeUtils, SetStatus) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data = builder.AddNode("Data", "Data", 0, 1); - EXPECT_EQ(NodeUtils::SetAllAnchorStatus(data), GRAPH_SUCCESS); - EXPECT_EQ(NodeUtils::SetAllAnchorStatus(*data), GRAPH_SUCCESS); - EXPECT_EQ(NodeUtils::IsAnchorStatusSet(data), true); - EXPECT_EQ(NodeUtils::IsAnchorStatusSet(nullptr), false); - EXPECT_EQ(NodeUtils::IsAnchorStatusSet(*data), true); - data->impl_ = nullptr; - EXPECT_EQ(NodeUtils::SetAllAnchorStatus(*data), GRAPH_FAILED); - EXPECT_EQ(NodeUtils::IsAnchorStatusSet(*data), false); -} - -TEST_F(UtestNodeUtils, MoveOutputEdges) { - EXPECT_EQ(NodeUtils::MoveOutputEdges(nullptr, nullptr), GRAPH_FAILED); - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data = builder.AddNode("Data", "Data", 0, 1); - auto dest = builder.AddNode("Dest", "Dest", 11, 22); - auto attr = builder.AddNode("Attr", "Attr", 1, 1); - auto node2 = builder.AddNode("Data2", "Data2", 2, 1); - InDataAnchorPtr peer = std::make_shared(node2, 22); - EXPECT_EQ(NodeUtils::MoveOutputEdges(data, dest), GRAPH_FAILED); - EXPECT_EQ(dest->GetAllOutDataAnchors().size(), 22); - EXPECT_EQ(NodeUtils::MoveOutputEdges(data, attr), GRAPH_SUCCESS); - - auto const1 = builder.AddNode("const1", "Const", 1, 1); - auto const2 = builder.AddNode("const2", "Const", 1, 1); - EXPECT_EQ(const1->AddLinkFrom(const2), GRAPH_SUCCESS); - EXPECT_EQ(NodeUtils::MoveOutputEdges(const2, const1), GRAPH_SUCCESS); -} - -TEST_F(UtestNodeUtils, MoveOutputEdges_Link) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data0 = builder.AddNode("Data0", DATA, 1, 1); - auto data1 = builder.AddNode("Data1", DATA, 1, 1); - auto data2 = builder.AddNode("Data2", DATA, 1, 1); - auto data3 = builder.AddNode("Data3", DATA, 1, 1); - auto data4 = builder.AddNode("Data4", DATA, 1, 1); - EXPECT_EQ(data0->GetAllOutDataAnchors().at(0)->LinkTo(data2->GetInControlAnchor()), GRAPH_SUCCESS); - EXPECT_EQ(data0->GetOutControlAnchor()->LinkTo(data3->GetInControlAnchor()), GRAPH_SUCCESS); - EXPECT_EQ(data0->GetOutControlAnchor()->LinkTo(data4->GetInDataAnchor(0)), GRAPH_SUCCESS); - EXPECT_EQ(NodeUtils::MoveOutputEdges(data0, data1), GRAPH_SUCCESS); -} - -TEST_F(UtestNodeUtils, UpdateIsInputConst_Normal) { - EXPECT_NO_THROW( - NodeUtils::UpdateIsInputConst(nullptr); - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data = builder.AddNode("Data", "Data", 0, 1); - NodeUtils::UpdateIsInputConst(data); - NodeUtils::UpdateIsInputConst(*data); - data->impl_->op_ = nullptr; - NodeUtils::UpdateIsInputConst(data); - ); -} - -TEST_F(UtestNodeUtils, UpdateIsInputConst_Nullptr) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data1 = builder.AddNode("Data1", DATA, 2, 2); - auto data2 = builder.AddNode("Data2", DATA, 1, 1); - data1->impl_->in_data_anchors_.at(1) = nullptr; - EXPECT_EQ(data1->GetInDataAnchor(0)->LinkFrom(data2->GetOutDataAnchor(0)), GRAPH_SUCCESS); - auto node = data2->GetOutDataAnchor(0)->GetOwnerNode(); - NodeUtils::UpdateIsInputConst(data1); -} - - -TEST_F(UtestNodeUtils, UpdateIsInputConst_OutDataAnchorNullptr) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node = builder.AddNode("Data", "Data", 1, 1); - auto attr_node = builder.AddNode("Attr", "Attr", 2, 2); - InDataAnchorPtr in_anch = std::make_shared(data_node, 111); - OutDataAnchorPtr out_anch = std::make_shared(data_node, 222); - EXPECT_EQ(out_anch->LinkTo(in_anch), GRAPH_SUCCESS); - EXPECT_EQ(attr_node->AddLinkFrom(data_node), GRAPH_SUCCESS); - EXPECT_EQ(attr_node->GetAllInDataAnchors().size(), 3); - NodeUtils::UpdateIsInputConst(attr_node); -} - - -TEST_F(UtestNodeUtils, UnlinkAll) { - EXPECT_NO_THROW( - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data = builder.AddNode("Data", "Data", 0, 1); - NodeUtils::UnlinkAll(*data); - ); -} - -TEST_F(UtestNodeUtils, AppendRemoveAnchor) { - EXPECT_EQ(NodeUtils::AppendInputAnchor(nullptr, 0), GRAPH_FAILED); - EXPECT_EQ(NodeUtils::RemoveInputAnchor(nullptr, 0), GRAPH_FAILED); - EXPECT_EQ(NodeUtils::AppendOutputAnchor(nullptr, 0), GRAPH_FAILED); - EXPECT_EQ(NodeUtils::RemoveOutputAnchor(nullptr, 0), GRAPH_FAILED); - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data = builder.AddNode("Data", "Data", 0, 1); - EXPECT_EQ(NodeUtils::AppendInputAnchor(data, 11), GRAPH_SUCCESS); - EXPECT_EQ(NodeUtils::RemoveInputAnchor(data, 11), GRAPH_SUCCESS); - EXPECT_EQ(NodeUtils::AppendOutputAnchor(data, 22), GRAPH_SUCCESS); - EXPECT_EQ(NodeUtils::RemoveOutputAnchor(data, 22), GRAPH_SUCCESS); -} - -TEST_F(UtestNodeUtils, RemoveInputAnchor) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data = builder.AddNode("Data", "Data", 1, 1); - EXPECT_EQ(data->GetOpDesc()->GetInputsSize(), 1); - EXPECT_EQ(data->GetOpDesc()->AddInputDesc(GeTensorDesc()), GRAPH_SUCCESS); - EXPECT_EQ(data->GetOpDesc()->GetInputsSize(), 2); - EXPECT_EQ(NodeUtils::RemoveInputAnchor(data, 0), GRAPH_SUCCESS); -} - -TEST_F(UtestNodeUtils, RemoveOutputAnchor) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data = builder.AddNode("Data", "Data", 1, 1); - EXPECT_EQ(data->GetOpDesc()->GetOutputsSize(), 1); - EXPECT_EQ(NodeUtils::RemoveOutputAnchor(data, 0), GRAPH_SUCCESS); -} - -TEST_F(UtestNodeUtils, SetSubgraph) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data = builder.AddNode("Data", "Data", 0, 1); - auto gragh = builder.GetGraph(); - EXPECT_EQ(NodeUtils::SetSubgraph(*data, 0, nullptr), GRAPH_PARAM_INVALID); - EXPECT_EQ(NodeUtils::SetSubgraph(*data, 0, gragh), GRAPH_PARAM_INVALID); - data->impl_->op_->AddSubgraphName("g1"); - data->impl_->op_->AddSubgraphName("g2"); - - auto sub_graph1 = std::make_shared("g1"); - auto sub_graph2 = std::make_shared("g2"); - EXPECT_EQ(NodeUtils::SetSubgraph(*data, 0, sub_graph1), GRAPH_SUCCESS); - EXPECT_EQ(NodeUtils::SetSubgraph(*data, 1, sub_graph2), GRAPH_SUCCESS); -} - -TEST_F(UtestNodeUtils, IsSubgraphInput) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node = builder.AddNode("Node", "Node", 11, 22); - EXPECT_EQ(NodeUtils::IsSubgraphInput(node), false); - auto root_builder = ut::GraphBuilder("root"); - const auto &partitioncall_0 = root_builder.AddNode("partitioncall_0", PARTITIONEDCALL, 0, 1); - const auto &partitioncall_1 = root_builder.AddNode("partitioncall_1", PARTITIONEDCALL, 1, 1); - root_builder.AddDataEdge(partitioncall_0, 0, partitioncall_1, 0); - const auto &root_graph = root_builder.GetGraph(); - auto p1_sub_builder = ut::GraphBuilder("partitioncall_0_sub"); - const auto &partitioncall_0_const1 = p1_sub_builder.AddNode("partitioncall_0_const1", CONSTANT, 0, 1); - const auto &partitioncall_0_netoutput = p1_sub_builder.AddNode("partitioncall_0_netoutput", NETOUTPUT, 1, 1); - AttrUtils::SetInt(partitioncall_0_netoutput->GetOpDesc()->MutableInputDesc(0), "_parent_node_index", 0); - p1_sub_builder.AddDataEdge(partitioncall_0_const1, 0, partitioncall_0_netoutput, 0); - const auto &sub_graph = p1_sub_builder.GetGraph(); - sub_graph->SetParentNode(partitioncall_0); - sub_graph->SetParentGraph(root_graph); - partitioncall_0->GetOpDesc()->AddSubgraphName("f"); - partitioncall_0->GetOpDesc()->SetSubgraphInstanceName(0, "partitioncall_0_sub"); - EXPECT_EQ(NodeUtils::IsSubgraphInput(partitioncall_0_const1), false); -} - -TEST_F(UtestNodeUtils, IsSubgraphInput_WithATTR_NAME_IS_UNKNOWN_SHAPE) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto root_builder = ut::GraphBuilder("root"); - const auto &while_node = root_builder.AddNode("while", WHILE, 1, 1); - bool is_known = true; - AttrUtils::SetBool(while_node->GetOpDesc(), ATTR_NAME_IS_UNKNOWN_SHAPE, is_known); - const auto &root_graph = root_builder.GetGraph(); - - auto p1_sub_builder = ut::GraphBuilder("sub"); - const auto &data_node = p1_sub_builder.AddNode("data", DATA, 0, 1); - AttrUtils::SetInt(data_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); - - const auto &sub_graph = p1_sub_builder.GetGraph(); - sub_graph->SetParentNode(while_node); - sub_graph->SetParentGraph(root_graph); - while_node->GetOpDesc()->AddSubgraphName("sub"); - while_node->GetOpDesc()->SetSubgraphInstanceName(0, "sub"); - EXPECT_FALSE(NodeUtils::IsSubgraphInput(data_node)); - - is_known = false; - AttrUtils::SetBool(while_node->GetOpDesc(), ATTR_NAME_IS_UNKNOWN_SHAPE, is_known); - EXPECT_TRUE(NodeUtils::IsSubgraphInput(data_node)); -} - -TEST_F(UtestNodeUtils, IsSubgraphOutput_WithATTR_NAME_IS_UNKNOWN_SHAPE) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto root_builder = ut::GraphBuilder("root"); - const auto &while_node = root_builder.AddNode("while", WHILE, 1, 1); - bool is_known = true; - AttrUtils::SetBool(while_node->GetOpDesc(), ATTR_NAME_IS_UNKNOWN_SHAPE, is_known); - const auto &root_graph = root_builder.GetGraph(); - - auto p1_sub_builder = ut::GraphBuilder("sub"); - const auto &partitioncall_0_const1 = p1_sub_builder.AddNode("partitioncall_0_const1", CONSTANT, 0, 1); - const auto &partitioncall_0_netoutput = p1_sub_builder.AddNode("partitioncall_0_netoutput", NETOUTPUT, 1, 1); - AttrUtils::SetInt(partitioncall_0_netoutput->GetOpDesc()->MutableInputDesc(0), "_parent_node_index", 0); - p1_sub_builder.AddDataEdge(partitioncall_0_const1, 0, partitioncall_0_netoutput, 0); - - const auto &sub_graph = p1_sub_builder.GetGraph(); - sub_graph->SetParentNode(while_node); - sub_graph->SetParentGraph(root_graph); - while_node->GetOpDesc()->AddSubgraphName("sub"); - while_node->GetOpDesc()->SetSubgraphInstanceName(0, "sub"); - EXPECT_FALSE(NodeUtils::IsSubgraphOutput(partitioncall_0_netoutput)); - - is_known = false; - AttrUtils::SetBool(while_node->GetOpDesc(), ATTR_NAME_IS_UNKNOWN_SHAPE, is_known); - EXPECT_TRUE(NodeUtils::IsSubgraphOutput(partitioncall_0_netoutput)); -} - -TEST_F(UtestNodeUtils, IsDynamicShape) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node = builder.AddNode("Node", "Node", 11, 22); - EXPECT_EQ(NodeUtils::IsDynamicShape(*node), false); -} - -TEST_F(UtestNodeUtils, IsWhileVaryingInput) { - auto root_builder = ut::GraphBuilder("root"); - const auto &partitioncall_0 = root_builder.AddNode("partitioncall_0", PARTITIONEDCALL, 0, 1); - const auto &partitioncall_1 = root_builder.AddNode("partitioncall_1", PARTITIONEDCALL, 1, 1); - root_builder.AddDataEdge(partitioncall_0, 0, partitioncall_1, 0); - const auto &root_graph = root_builder.GetGraph(); - auto p1_sub_builder = ut::GraphBuilder("partitioncall_0_sub"); - const auto &partitioncall_0_const1 = p1_sub_builder.AddNode("partitioncall_0_const1", CONSTANT, 0, 1); - const auto &partitioncall_0_netoutput = p1_sub_builder.AddNode("partitioncall_0_netoutput", NETOUTPUT, 1, 1); - AttrUtils::SetInt(partitioncall_0_netoutput->GetOpDesc()->MutableInputDesc(0), "_parent_node_index", 0); - p1_sub_builder.AddDataEdge(partitioncall_0_const1, 0, partitioncall_0_netoutput, 0); - const auto &sub_graph = p1_sub_builder.GetGraph(); - sub_graph->SetParentNode(partitioncall_0); - sub_graph->SetParentGraph(root_graph); - partitioncall_0->GetOpDesc()->AddSubgraphName("f"); - partitioncall_0->GetOpDesc()->SetSubgraphInstanceName(0, "partitioncall_0_sub"); - auto sub2_builder = ut::GraphBuilder("partitioncall_0_sub2"); - const auto &data11 = root_builder.AddNode("data11", DATA, 0, 1); - data11->SetOwnerComputeGraph(sub_graph); - EXPECT_EQ(NodeUtils::IsWhileVaryingInput(data11), false); -} - -TEST_F(UtestNodeUtils, IsWhileVaryingInput_While) { - auto root_builder = ut::GraphBuilder("root"); - const auto &while1 = root_builder.AddNode("while1", "While", 1, 1); - const auto &root_graph = root_builder.GetGraph(); - auto sub_builder = ut::GraphBuilder("sub"); - const auto &const1 = sub_builder.AddNode("const1", CONSTANT, 1, 1); - const auto &netoutput = sub_builder.AddNode("netoutput", NETOUTPUT, 1, 1); - const auto &data0 = sub_builder.AddNode("data0", DATA, 1, 1); - const auto &sub_graph = sub_builder.GetGraph(); - sub_graph->SetParentNode(while1); - sub_graph->SetParentGraph(root_graph); - EXPECT_EQ(NodeUtils::IsWhileVaryingInput(data0), false); - //EXPECT_EQ(NodeUtils::IsWhileVaryingInput(while1), false); - EXPECT_EQ(AttrUtils::SetInt(data0->GetOpDesc(), "_parent_node_index", 1), true); - EXPECT_EQ(NodeUtils::IsWhileVaryingInput(data0), true); -} - -TEST_F(UtestNodeUtils, RemoveSubgraphsOnNode) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node = builder.AddNode("Node", "Node", 11, 22); - EXPECT_EQ(NodeUtils::RemoveSubgraphsOnNode(node), GRAPH_SUCCESS); - node->impl_->op_->SetSubgraphInstanceName(0, "name"); - node->impl_->op_->SetSubgraphInstanceName(1, "name1"); - EXPECT_EQ(NodeUtils::RemoveSubgraphsOnNode(node), GRAPH_SUCCESS); -} - -TEST_F(UtestNodeUtils, GetSubgraphDataNodesByIndex) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node = builder.AddNode("Node", "Const", 1, 1); - EXPECT_EQ(NodeUtils::GetSubgraphDataNodesByIndex(*node, 0).size(), 0); -} - -TEST_F(UtestNodeUtils, GetSubgraphDataAndNetoutput) { - auto root_builder = ut::GraphBuilder("root"); - const auto &data = root_builder.AddNode("data", DATA, 1, 1); - const auto &partitioncall_0 = root_builder.AddNode("partitioncall_0", PARTITIONEDCALL, 3, 3); - const auto &partitioncall_1 = root_builder.AddNode("partitioncall_1", PARTITIONEDCALL, 1, 1); - root_builder.AddDataEdge(data, 0, partitioncall_0, 1); - root_builder.AddDataEdge(partitioncall_0, 1, partitioncall_1, 0); - const auto &root_graph = root_builder.GetGraph(); - - int64_t index = 0; - // 1.build partitioncall_0 sub graph - auto p1_sub_builder = ut::GraphBuilder("partitioncall_0_sub"); - const auto &partitioncall_0_data = p1_sub_builder.AddNode("partitioncall_0_data", DATA, 0, 1); - AttrUtils::SetInt(partitioncall_0_data->GetOpDesc(), "_parent_node_index", index); - const auto &partitioncall_0_cast = p1_sub_builder.AddNode("partitioncall_0_cast", "Cast", 1, 1); - const auto &partitioncall_0_netoutput = p1_sub_builder.AddNode("partitioncall_0_netoutput", NETOUTPUT, 3, 3); - AttrUtils::SetInt(partitioncall_0_netoutput->GetOpDesc()->MutableInputDesc(0), "_parent_node_index", 0); - AttrUtils::SetInt(partitioncall_0_netoutput->GetOpDesc()->MutableInputDesc(1), "_parent_node_index", 1); - AttrUtils::SetInt(partitioncall_0_netoutput->GetOpDesc()->MutableInputDesc(2), "_parent_node_index", 2); - p1_sub_builder.AddDataEdge(partitioncall_0_data, 0, partitioncall_0_cast, 0); - p1_sub_builder.AddDataEdge(partitioncall_0_cast, 0, partitioncall_0_netoutput, 1); - const auto &sub_graph = p1_sub_builder.GetGraph(); - sub_graph->SetParentNode(partitioncall_0); - sub_graph->SetParentGraph(root_graph); - root_graph->AddSubgraph("partitioncall_0_sub", sub_graph); - partitioncall_0->GetOpDesc()->AddSubgraphName("f"); - partitioncall_0->GetOpDesc()->SetSubgraphInstanceName(0, "partitioncall_0_sub"); - auto compute_graph = partitioncall_0->GetOwnerComputeGraph(); - EXPECT_NE(compute_graph, nullptr); - EXPECT_EQ(NodeUtils::GetSubgraphOutputNodes(*partitioncall_0).size(), 1); - EXPECT_EQ(NodeUtils::GetSubgraphDataNodesByIndex(*partitioncall_0, index).size(), 1); - EXPECT_NE(sub_graph->GetOrUpdateNetOutputNode(), nullptr); - EXPECT_EQ(sub_graph->GetOrUpdateNetOutputNode()->GetType(), NETOUTPUT); -} - -TEST_F(UtestNodeUtils, GetSubgraphOutputNodes) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node = builder.AddNode("Node", "Node", 11, 22); - auto op_desc = node->GetOpDesc(); - op_desc->impl_->subgraph_instance_names_.clear(); - auto subgraph_names = op_desc->GetSubgraphInstanceNames(); - EXPECT_EQ(NodeUtils::GetSubgraphOutputNodes(*node).size(), 0); - auto compute_graph = node->GetOwnerComputeGraph(); - compute_graph->impl_->parent_graph_ = MakeNullptr(); - EXPECT_EQ(NodeUtils::GetSubgraphOutputNodes(*node).size(), 0); -} - -TEST_F(UtestNodeUtils, GetOutDataNodesWithAnchorByIndex) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node = builder.AddNode("Node", "Node", 11, 22); - EXPECT_EQ(NodeUtils::GetOutDataNodesWithAnchorByIndex(*node, 0).size(), 0); - EXPECT_EQ(NodeUtils::GetOutDataNodesWithAnchorByIndex(*node, -1).size(), 0); -} - -TEST_F(UtestNodeUtils, GetNodeFromOperator) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node = builder.AddNode("Node", "Node", 11, 22); - Operator op = OperatorFactoryImpl::CreateOperator("opname", "optp"); - EXPECT_EQ(NodeUtilsEx::GetNodeFromOperator(op), nullptr); -} - -TEST_F(UtestNodeUtils, GetInConstNodeTypeCrossSubgraph) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node = builder.AddNode("Node", "Node", 11, 22); - EXPECT_EQ(NodeUtils::GetInConstNodeTypeCrossSubgraph(node), "Node"); - EXPECT_EQ(NodeUtils::GetInConstNodeTypeCrossSubgraph(nullptr), ""); -} - -TEST_F(UtestNodeUtils, CreatNodeWithoutGraph) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - OpDescPtr od = std::make_shared("name", "type"); - EXPECT_NE(NodeUtils::CreatNodeWithoutGraph(od), nullptr); - EXPECT_EQ(NodeUtils::CreatNodeWithoutGraph(nullptr), nullptr); -} - -TEST_F(UtestNodeUtils, SetNodeParallelGroup) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node = builder.AddNode("Node", "Node", 11, 22); - EXPECT_EQ(NodeUtils::SetNodeParallelGroup(*node, nullptr), GRAPH_FAILED); - EXPECT_NE(NodeUtils::SetNodeParallelGroup(*node, "node_group"), GRAPH_FAILED); - auto amap = node->GetOpDesc()->GetAttrMap(); - amap.SetByName("_parallel_group", "_parallel_group_value"); - node->impl_->op_->impl_->attrs_ = amap; - EXPECT_EQ(NodeUtils::SetNodeParallelGroup(*node, "node_group"), GRAPH_FAILED); - EXPECT_EQ(NodeUtils::SetNodeParallelGroup(*node, "_parallel_group_value"), GRAPH_SUCCESS); - -} - -TEST_F(UtestNodeUtils, GetSubgraph) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node = builder.AddNode("Node", "Node", 11, 22); - EXPECT_EQ(NodeUtils::GetSubgraph(*node, 0), nullptr); - node->impl_->op_ = nullptr; - EXPECT_EQ(NodeUtils::GetSubgraph(*node, 0), nullptr); -} - -TEST_F(UtestNodeUtils, GetNodeType) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node = builder.AddNode("Node", "Node", 11, 22); - node->impl_->op_->impl_->meta_data_.type_ = "FrameworkOp"; - EXPECT_EQ(NodeUtils::GetNodeType(*node), ""); -} - -TEST_F(UtestNodeUtils, UpdateInOutputOriginalShapeAndShapeFailure) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node = builder.AddNode("Node", "Node", 11, 22); - auto desc = node->GetOpDesc(); - EXPECT_EQ(NodeUtils::UpdateInputOriginalShapeAndShape(*node, 100, GeShape()), GRAPH_PARAM_INVALID); - EXPECT_EQ(NodeUtils::UpdateOutputOriginalShapeAndShape(*node, 100, GeShape()), GRAPH_PARAM_INVALID); - node->impl_->op_ = nullptr; - EXPECT_EQ(NodeUtils::UpdateInputOriginalShapeAndShape(*node, 100, GeShape()), GRAPH_PARAM_INVALID); - EXPECT_EQ(NodeUtils::UpdateOutputOriginalShapeAndShape(*node, 100, GeShape()), GRAPH_PARAM_INVALID); -} - -TEST_F(UtestNodeUtils, GetInNodeCrossPartionedCallNode){ - auto graph = BuildGraphPartitionCall(); - NodePtr expect_peer_node; - NodePtr sq_node; - NodePtr nt_node; - - for (auto &node : graph->GetAllNodes()) { - if (node->GetType() == SQUEEZE) { - sq_node = node; - } - if (node->GetType() == NETOUTPUT) { - nt_node = node; - } - } - EXPECT_NE(sq_node, nullptr); - EXPECT_NE(sq_node, nullptr); - EXPECT_NE(sq_node->GetType(), DATA); - EXPECT_EQ(sq_node->GetType(), SQUEEZE); - EXPECT_NE(sq_node->GetType(), PARTITIONEDCALL); - EXPECT_EQ(sq_node->GetOpDesc()->GetSubgraphInstanceNames().empty(), true); - EXPECT_EQ(NodeUtils::GetInNodeCrossPartionedCallNode(sq_node, 0 , expect_peer_node), GRAPH_SUCCESS); - EXPECT_EQ(NodeUtils::GetInNodeCrossPartionedCallNode(sq_node, 1000 , expect_peer_node), GRAPH_FAILED); - auto peer = NodeUtils::GetInDataNodeByIndex(*sq_node, 0); - EXPECT_NE(NodeUtils::GetInDataNodeByIndex(*sq_node, 0), nullptr); - EXPECT_NE(nt_node->GetType(), DATA); - EXPECT_EQ(nt_node->GetType(), NETOUTPUT); - EXPECT_NE(nt_node->GetType(), PARTITIONEDCALL); - EXPECT_EQ(nt_node->GetOpDesc()->GetSubgraphInstanceNames().empty(), true); - EXPECT_EQ(NodeUtils::GetInNodeCrossPartionedCallNode(nt_node, 0 , expect_peer_node), GRAPH_SUCCESS); -} - - -TEST_F(UtestNodeUtils, GetInNodeCrossSubgraph){ - auto graph = BuildGraphPartitionCall(); - NodePtr expect_peer_node; - NodePtr dt_node; - for (auto &node : graph->GetAllNodes()) { - if (node->GetType() == DATA) { - dt_node = node; - } - } - EXPECT_NE(dt_node, nullptr); - EXPECT_NE(NodeUtils::GetInNodeCrossSubgraph(dt_node), nullptr); - auto owner_graph = dt_node->GetOwnerComputeGraph(); - owner_graph->impl_->parent_node_ = MakeNullptr(); - EXPECT_NE(NodeUtils::GetInNodeCrossSubgraph(dt_node), nullptr); -} - -TEST_F(UtestNodeUtils, GetConstOpType) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data = builder.AddNode("nt", NETOUTPUT, 0, 1); - std::string op_type; - EXPECT_EQ(NodeUtils::GetConstOpType(data, op_type), false); -} - -TEST_F(UtestNodeUtils, IsWhileVaryingInputFalse) { - EXPECT_EQ(NodeUtils::IsWhileVaryingInput(nullptr), false); -} - -TEST_F(UtestNodeUtils, IsSubgraphOutput) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node = builder.AddNode("Node", "Node", 11, 22); - EXPECT_EQ(NodeUtils::IsSubgraphOutput(node), false); - auto root_builder = ut::GraphBuilder("root"); - const auto &partitioncall_0 = root_builder.AddNode("partitioncall_0", "_is_unknown_shape", 0, 1); - const auto &partitioncall_1 = root_builder.AddNode("partitioncall_1", "_is_unknown_shape", 1, 1); - root_builder.AddDataEdge(partitioncall_0, 0, partitioncall_1, 0); - const auto &root_graph = root_builder.GetGraph(); - auto p1_sub_builder = ut::GraphBuilder("partitioncall_0_sub"); - const auto &partitioncall_0_const1 = p1_sub_builder.AddNode("partitioncall_0_const1", CONSTANT, 0, 1); - const auto &partitioncall_0_netoutput = p1_sub_builder.AddNode("partitioncall_0_netoutput", NETOUTPUT, 1, 1); - AttrUtils::SetInt(partitioncall_0_netoutput->GetOpDesc()->MutableInputDesc(0), "_parent_node_index", 0); - p1_sub_builder.AddDataEdge(partitioncall_0_const1, 0, partitioncall_0_netoutput, 0); - const auto &sub_graph = p1_sub_builder.GetGraph(); - sub_graph->SetParentNode(partitioncall_0); - sub_graph->SetParentGraph(root_graph); - partitioncall_0->GetOpDesc()->AddSubgraphName("f"); - partitioncall_0->GetOpDesc()->SetSubgraphInstanceName(0, "partitioncall_0_sub"); - EXPECT_EQ(NodeUtils::IsSubgraphOutput(partitioncall_0_const1), false); -} - -TEST_F(UtestNodeUtils, IsDynamicShape_Null) { - EXPECT_EQ(NodeUtils::IsDynamicShape(nullptr), false); -} - -TEST_F(UtestNodeUtils, GetInDataNodeAndAnchorByIndex_InAnchorOutOfRange) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node = builder.AddNode("Node", "Node", 1, 1); - EXPECT_EQ(NodeUtils::GetInDataNodeAndAnchorByIndex(*node, 1).first, nullptr); -} -TEST_F(UtestNodeUtils, GetInDataNodeAndAnchorByIndex_NoPeerOutAnchor) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node = builder.AddNode("Node", "Node", 1, 1); - EXPECT_EQ(NodeUtils::GetInDataNodeAndAnchorByIndex(*node, 0).first, nullptr); -} -TEST_F(UtestNodeUtils, GetInDataNodeAndAnchorByIndex_Success) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node1 = builder.AddNode("Node1", "Node1", 1, 1); - auto node2 = builder.AddNode("Node2", "Node2", 1, 1); - builder.AddDataEdge(node1, 0, node2, 0); - EXPECT_EQ(NodeUtils::GetInDataNodeAndAnchorByIndex(*node2, 0).first, node1); - EXPECT_EQ(NodeUtils::GetInDataNodeAndAnchorByIndex(*node2, 0).second, node1->GetOutDataAnchor(0)); -} -TEST_F(UtestNodeUtils, IsDtResourceNode_Success) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node1 = builder.AddNode("Node1", "Node1", 1, 1); - auto in_desc1 = node1->GetOpDesc()->MutableInputDesc(0); - in_desc1->SetDataType(DT_RESOURCE); - EXPECT_EQ(NodeUtils::IsDtResourceNode(node1), true); - auto node2 = builder.AddNode("Node2", "Node2", 1, 1); - auto out_desc2 = node2->GetOpDesc()->MutableOutputDesc(0); - out_desc2->SetDataType(DT_RESOURCE); - EXPECT_EQ(NodeUtils::IsDtResourceNode(node2), true); -} -TEST_F(UtestNodeUtils, IsIdentityUsefulForRWControl) { - ComputeGraphPtr graph = BuildGraphWithUsefulIdentity(); - auto node = graph->FindNode("id1"); - EXPECT_NE(node, nullptr); - // id1 is useful, not remove - EXPECT_EQ(NodeUtils::IsIdentityUsefulForRWControl(node), true); -} -TEST_F(UtestNodeUtils, FindRootGraph) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node = builder.AddNode("Node", "Node", 1, 1); - EXPECT_NE(node, nullptr); - EXPECT_EQ(NodeUtils::FindRootGraph(*node), node->GetOwnerComputeGraph()); -} -TEST_F(UtestNodeUtils, GetOutControlNodes) { - auto builder = ut::GraphBuilder("test_graph0"); - const auto &src_node = builder.AddNode("src_node", DATA, 1, 1); - const auto &ctrl_node = builder.AddNode("ctrl_node", CONSTANT, 0, 0); - const auto &ctrl_node2 = builder.AddNode("ctrl_node2", CONSTANT, 0, 0); - auto graph = builder.GetGraph(); - builder.AddControlEdge(src_node, ctrl_node); - builder.AddControlEdge(src_node, ctrl_node2); - EXPECT_EQ(NodeUtils::GetOutControlNodes(*src_node, nullptr).size(), 2U); - NodeFilter node_filter = [&](const Node &node) { return node.GetName() == ctrl_node2->GetName(); }; - EXPECT_EQ(NodeUtils::GetOutControlNodes(*src_node, node_filter).size(), 1U); - EXPECT_EQ(NodeUtils::GetOutControlNodes(*src_node, node_filter).front(), ctrl_node2); -} -TEST_F(UtestNodeUtils, GetInControlNodes) { - auto builder = ut::GraphBuilder("test_graph0"); - const auto &ctrl_node = builder.AddNode("ctrl_node", CONSTANT, 0, 0); - const auto &ctrl_node2 = builder.AddNode("ctrl_node2", CONSTANT, 0, 0); - const auto &dst_node = builder.AddNode("dst_node", NETOUTPUT, 0, 0); - - auto graph = builder.GetGraph(); - builder.AddControlEdge(ctrl_node, dst_node); - builder.AddControlEdge(ctrl_node2, dst_node); - EXPECT_EQ(NodeUtils::GetInControlNodes(*dst_node, nullptr).size(), 2U); - NodeFilter node_filter = [&](const Node &node) { return node.GetName() == ctrl_node2->GetName(); }; - EXPECT_EQ(NodeUtils::GetInControlNodes(*dst_node, node_filter).size(), 1U); - EXPECT_EQ(NodeUtils::GetInControlNodes(*dst_node, node_filter).front(), ctrl_node2); -} - -TEST_F(UtestNodeUtils, TryGetWeightByPlaceHolderNode_invalid) { - auto node = std::make_shared(); - auto ge_tensor = std::make_shared(); - EXPECT_NE(NodeUtils::TryGetWeightByPlaceHolderNode(node, ge_tensor), GRAPH_SUCCESS); -} - -TEST_F(UtestNodeUtils, TryGetWeightByDataNode_invalid) { - auto node = std::make_shared(); - auto ge_tensor = std::make_shared(); - EXPECT_NE(NodeUtils::TryGetWeightByDataNode(node, ge_tensor), GRAPH_SUCCESS); -} - -TEST_F(UtestNodeUtils, GetParentInput_invalid) { - auto builder = ut::GraphBuilder("test_graph0"); - const auto &data_node = builder.AddNode("data", DATA, 0, 0); - auto graph = builder.GetGraph(); - auto node = graph->FindNode("data"); - AttrUtils::SetInt(node->GetOpDesc(), ge::ATTR_NAME_PARENT_NODE_INDEX,1); - EXPECT_EQ(NodeUtils::GetParentInput(node), nullptr); -} - -TEST_F(UtestNodeUtils, TryGetWeightByPlaceHolderNode_fail) { - auto builder = ut::GraphBuilder("test_graph0"); - const auto &pld = builder.AddNode("pld", PLACEHOLDER, 1, 1); - ConstGeTensorPtr ge_tensor = nullptr; - EXPECT_EQ(NodeUtils::TryGetWeightByPlaceHolderNode(pld, ge_tensor), GRAPH_SUCCESS); - EXPECT_TRUE(ge_tensor == nullptr); - const auto &parent_node = builder.AddNode("fake", "fake", 1, 1); - EXPECT_TRUE(pld->GetOpDesc()->SetExtAttr("parentNode", parent_node)); - EXPECT_EQ(NodeUtils::TryGetWeightByPlaceHolderNode(pld, ge_tensor), GRAPH_SUCCESS); - EXPECT_TRUE(ge_tensor == nullptr); -} - -TEST_F(UtestNodeUtils, Verify_update_output_name_with_default_name) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node = builder.AddNode("Data", "FakeData", 0, 1); - EXPECT_EQ(data_node->GetOpDesc()->GetAllOutputName().cbegin()->first, "__output0"); - EXPECT_EQ(NodeUtilsEx::Verify(data_node), GRAPH_SUCCESS); - EXPECT_EQ(data_node->GetOpDesc()->GetAllOutputName().cbegin()->first, "y"); -} -TEST_F(UtestNodeUtils, Verify_update_output_name_with_empty_output) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node = builder.AddNode("Data", "FakeData", 0, 1); - std::map output_name_idx; - data_node->GetOpDesc()->UpdateOutputName(output_name_idx); - int32_t event_level; - int32_t old_level = dlog_getlevel(GE_MODULE_NAME, &event_level); - dlog_setlevel(GE_MODULE_NAME, DLOG_DEBUG, event_level); - EXPECT_EQ(NodeUtilsEx::Verify(data_node), GRAPH_SUCCESS); - dlog_setlevel(GE_MODULE_NAME, old_level, event_level); - EXPECT_EQ(data_node->GetOpDesc()->GetAllOutputName().cbegin()->first, "y"); -} - -TEST_F(UtestNodeUtils, Verify_no_need_update_output_name_already_has) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node = builder.AddNode("Data", "FakeData", 0, 1); - data_node->GetOpDesc()->UpdateOutputName({{"y", 0}}); - EXPECT_EQ(NodeUtilsEx::Verify(data_node), GRAPH_SUCCESS); - EXPECT_EQ(data_node->GetOpDesc()->GetAllOutputName().cbegin()->first, "y"); -} - -TEST_F(UtestNodeUtils, Verify_noneed_update_output_name) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node = builder.AddNode("Data", "FakeData", 1, 1); - data_node->impl_->in_data_anchors_.at(0) = nullptr; - data_node->GetOpDesc()->UpdateInputName({{"xx", 0}}); - data_node->GetOpDesc()->UpdateOutputName({{"yy", 0}}); - EXPECT_EQ(data_node->GetOpDesc()->GetAllOutputName().cbegin()->first, "yy"); - - EXPECT_EQ(NodeUtilsEx::Verify(data_node), GRAPH_SUCCESS); - - EXPECT_EQ(data_node->GetOpDesc()->GetAllOutputName().cbegin()->first, "yy"); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/object_pool_unittest.cc b/tests/ut/graph/testcase/object_pool_unittest.cc deleted file mode 100644 index e32b1cc3c26c744a736b6f21a5e9bde37a943586..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/object_pool_unittest.cc +++ /dev/null @@ -1,100 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include -#include - -#include "graph/utils/object_pool.h" -#include "graph/ge_tensor.h" - -using std::vector; -namespace ge { -class UTObjectPool : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -TEST_F(UTObjectPool, Add) { - ObjectPool object_pool_; - ASSERT_TRUE(object_pool_.IsEmpty()); - - auto ge_tensor = object_pool_.Acquire(); - GeTensorDesc tensor_desc(GeShape({10})); - ge_tensor->SetTensorDesc(tensor_desc); - - float dt[10] = {1.0f}; - auto deleter = [](const uint8_t *ptr) { - - }; - ge_tensor->SetData((uint8_t *)&dt, sizeof(dt), deleter); - object_pool_.Release(std::move(ge_tensor)); - ASSERT_EQ(object_pool_.handlers_.size(), 1); -} - -TEST_F(UTObjectPool, UniqueToShared) { - ObjectPool object_pool_; - auto ge_tensor = object_pool_.Acquire(); - GeTensorDesc tensor_desc(GeShape({10})); - ge_tensor->SetTensorDesc(tensor_desc); - - float dt[10] = {1.0f}; - auto deleter = [](const uint8_t *ptr) { - - }; - ge_tensor->SetData((uint8_t *)&dt, sizeof(dt), deleter); - - { - std::shared_ptr shared_tensor(ge_tensor.get(), [](GeTensor *){}); - } - ASSERT_NE(ge_tensor, nullptr); - object_pool_.Release(std::move(ge_tensor)); - ASSERT_EQ(object_pool_.handlers_.size(), 1); -} - -TEST_F(UTObjectPool, GetFromFull) { - ObjectPool object_pool_; - - auto ge_tensor = object_pool_.Acquire(); - GeTensorDesc tensor_desc(GeShape({10})); - ge_tensor->SetTensorDesc(tensor_desc); - float dt[10] = {1.0f}; - auto deleter = [](const uint8_t *ptr) { - }; - ge_tensor->SetData((uint8_t *)&dt, sizeof(dt), deleter); - object_pool_.Release(std::move(ge_tensor)); - - ASSERT_TRUE(object_pool_.IsFull()); - auto tmp = object_pool_.Acquire(); - ASSERT_TRUE(object_pool_.IsEmpty()); -} - - -TEST_F(UTObjectPool, AutoRelease) { - ObjectPool object_pool_; - auto ge_tensor = object_pool_.Acquire(); - GeTensorDesc tensor_desc(GeShape({10})); - ge_tensor->SetTensorDesc(tensor_desc); - - float dt[10] = {1.0f}; - auto deleter = [](const uint8_t *ptr) { - }; - ge_tensor->SetData((uint8_t *)&dt, sizeof(dt), deleter); - { - std::queue> shared_tensors_; - shared_tensors_.push(std::move(ge_tensor)); - - } - ASSERT_EQ(ge_tensor, nullptr); - object_pool_.Release(std::move(ge_tensor)); - ASSERT_TRUE(object_pool_.IsEmpty()); -} -} diff --git a/tests/ut/graph/testcase/op_desc_unittest.cc b/tests/ut/graph/testcase/op_desc_unittest.cc deleted file mode 100644 index d9865576ceccfd070b7c2720e9722468b00b6cc8..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/op_desc_unittest.cc +++ /dev/null @@ -1,1140 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include "graph/op_desc.h" -#include "graph/normal_graph/op_desc_impl.h" -#include "graph/ge_tensor.h" -#include "graph/utils/ge_ir_utils.h" -#include "graph/utils/transformer_utils.h" -#include "graph/common_error_codes.h" -#include "graph/operator_factory_impl.h" -#include "register/op_tiling_registry.h" -#include "external/graph/operator_factory.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/op_desc_utils_ex.h" -#include "external/graph/operator_reg.h" -#include "external/register/op_impl_registry.h" -#include "graph/debug/ge_attr_define.h" - -#include -#include "mmpa/mmpa_api.h" - -namespace ge { -class UtestOpDesc : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -TEST_F(UtestOpDesc, TestCommonVerifyOnDummyShape) { - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({-3})); - tensor_desc->SetFormat(FORMAT_NCHW); - tensor_desc->SetDataType(DT_FLOAT); - auto op_desc = std::make_shared("test", "Identity"); - op_desc->AddInputDesc(tensor_desc->Clone()); - op_desc->AddOutputDesc(tensor_desc->Clone()); - - EXPECT_EQ(GRAPH_SUCCESS, op_desc->CommonVerify()); -} - -TEST_F(UtestOpDesc, TestOpDescGetSetTensorDesc) { - GeTensorDesc desc(GeShape(), FORMAT_NCHW, DT_INT32); - OpDesc op_desc("foo", "Foo"); - EXPECT_EQ(GRAPH_SUCCESS, op_desc.AddInputDesc("x", desc)); - EXPECT_EQ(GRAPH_SUCCESS, op_desc.AddOutputDesc("y", desc)); - - EXPECT_EQ(op_desc.GetInputDesc("x"), desc); - EXPECT_EQ(op_desc.GetOutputDesc("y"), desc); -} - -TEST_F(UtestOpDesc, TestNodeShapeTransUtils) { - - NodeShapeTransUtils transformer1(nullptr); - EXPECT_NE(transformer1.Init(), true); - - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1, 1, 16, 16})); - tensor_desc->SetFormat(FORMAT_FRACTAL_NZ); - tensor_desc->SetDataType(DT_FLOAT); - tensor_desc->SetOriginFormat(FORMAT_ND); - - auto op_desc = std::make_shared("test", "Identity"); - op_desc->AddInputDesc(tensor_desc->Clone()); - op_desc->AddInputDesc(tensor_desc->Clone()); - op_desc->AddInputDesc(tensor_desc->Clone()); - op_desc->AddOutputDesc(tensor_desc->Clone()); - NodeShapeTransUtils transformer2(op_desc); - EXPECT_EQ(transformer2.Init(), true); - EXPECT_EQ(transformer2.CatchFormatAndShape(), true); - EXPECT_EQ(transformer2.UpdateFormatAndShape(), true); - - - op_desc->AddInputDesc(tensor_desc->Clone()); - op_desc->AddInputDesc(tensor_desc->Clone()); - op_desc->AddInputDesc(tensor_desc->Clone()); - op_desc->AddInputDesc(tensor_desc->Clone()); - op_desc->AddInputDesc(tensor_desc->Clone()); - op_desc->AddInputDesc(tensor_desc->Clone()); - op_desc->AddOutputDesc(tensor_desc->Clone()); - - NodeShapeTransUtils transformer3(op_desc); - EXPECT_EQ(transformer3.Init(), true); - EXPECT_EQ(transformer3.CatchFormatAndShape(), true); - EXPECT_EQ(transformer3.UpdateFormatAndShape(), true); - - - EXPECT_EQ(GRAPH_SUCCESS, op_desc->CommonVerify()); -} - -TEST_F(UtestOpDesc, SetNamePtr) { - auto op_desc = std::make_shared("test", "Identity"); - op_desc->SetNamePtr("abc"); - EXPECT_EQ(op_desc->GetName(), "abc"); -} - -TEST_F(UtestOpDesc, IndexOutOfRange) { - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetFormat(FORMAT_NCHW); - tensor_desc->SetDataType(DT_FLOAT); - auto op_desc = std::make_shared("test", "Identity"); - op_desc->AddInputDesc(tensor_desc->Clone()); - - EXPECT_NE(nullptr, op_desc->MutableInputDesc(0)); - EXPECT_EQ(nullptr, op_desc->MutableInputDesc(1)); - EXPECT_EQ(nullptr, op_desc->MutableInputDesc(999)); -} - -TEST_F(UtestOpDesc, SerializeMetadata) { - OpDescImpl impl; - impl.meta_data_.inputs_.emplace_back("input"); - impl.meta_data_.input_names_.emplace_back("names"); - impl.meta_data_.src_names_.push_back("src"); - impl.meta_data_.dst_names_.push_back("dst"); - impl.meta_data_.dst_indexes_.push_back(2); - impl.meta_data_.src_indexes_.push_back(2); - impl.meta_data_.input_offsets_.push_back(987654321); - impl.meta_data_.output_offsets_.push_back(987654321); - impl.meta_data_.workspaces.push_back(222); - impl.meta_data_.workspace_bytes_list_.push_back(111); - impl.meta_data_.is_input_consts_.push_back(false); - - proto::OpDef def; - impl.SerializeMetaDataToOpDef(&def); - EXPECT_EQ(def.input(0), "input"); - EXPECT_EQ(def.input_name(0), "names"); - EXPECT_EQ(def.src_name(0), "src"); - EXPECT_EQ(def.dst_name(0), "dst"); - EXPECT_EQ(def.dst_index(0), 2); - EXPECT_EQ(def.src_index(0), 2); - EXPECT_EQ(def.input_i(0), 987654321); - EXPECT_EQ(def.output_i(0), 987654321); - EXPECT_EQ(def.workspace(0), 222); - EXPECT_EQ(def.workspace_bytes(0), 111); - EXPECT_EQ(def.is_input_const(0), false); -} - -TEST_F(UtestOpDesc, DeSerializeMetadata) { - proto::OpDef def; - def.add_input("input"); - def.add_input_name("names"); - def.add_src_name("src"); - def.add_dst_name("dst"); - def.add_dst_index(2); - def.add_src_index(2); - def.add_input_i(987654321); - def.add_output_i(987654321); - def.add_workspace(222); - def.add_workspace_bytes(222); - def.add_is_input_const(false); - OpDescImpl impl; - impl.DeSerializeOpDefToMetaData(def); - EXPECT_EQ(impl.meta_data_.inputs_.size(), 1); - EXPECT_EQ(impl.meta_data_.inputs_[0], "input"); - EXPECT_EQ(impl.meta_data_.input_names_.size(), 1); - EXPECT_EQ(impl.meta_data_.input_names_[0], "names"); - EXPECT_EQ(impl.meta_data_.src_names_.size(), 1); - EXPECT_EQ(impl.meta_data_.src_names_[0], "src"); - EXPECT_EQ(impl.meta_data_.dst_names_.size(), 1); - EXPECT_EQ(impl.meta_data_.dst_names_[0], "dst"); - EXPECT_EQ(impl.meta_data_.dst_indexes_.size(), 1); - EXPECT_EQ(impl.meta_data_.dst_indexes_[0], 2); - EXPECT_EQ(impl.meta_data_.src_indexes_.size(), 1); - EXPECT_EQ(impl.meta_data_.src_indexes_[0], 2); - EXPECT_EQ(impl.meta_data_.input_offsets_.size(), 1); - EXPECT_EQ(impl.meta_data_.input_offsets_[0], 987654321); - EXPECT_EQ(impl.meta_data_.output_offsets_.size(), 1); - EXPECT_EQ(impl.meta_data_.output_offsets_[0], 987654321); - EXPECT_EQ(impl.meta_data_.workspaces.size(), 1); - EXPECT_EQ(impl.meta_data_.workspaces[0], 222); - EXPECT_EQ(impl.meta_data_.workspace_bytes_list_.size(), 1); - EXPECT_EQ(impl.meta_data_.workspace_bytes_list_[0], 222); - EXPECT_EQ(impl.meta_data_.is_input_consts_.size(), 1); - EXPECT_EQ(impl.meta_data_.is_input_consts_[0], false); - - OpDescImpl impl1; - impl1.DeSerializeOpDefToMetaData(def); - EXPECT_TRUE(impl1.OpDescAttrsAreEqual(impl)); -} - -TEST_F(UtestOpDesc, AddDescForward) { - GeTensorDesc desc(GeShape(), FORMAT_NCHW, DT_INT32); - OpDesc op_desc("foo", "Foo"); - EXPECT_EQ(GRAPH_SUCCESS, op_desc.AddOutputDesc("x", desc)); - EXPECT_EQ(GRAPH_SUCCESS, op_desc.AddOutputDesc("y", desc)); - EXPECT_EQ(GRAPH_SUCCESS, op_desc.AddOutputDesc("z", desc)); - EXPECT_EQ(GRAPH_SUCCESS, op_desc.AddOutputDescForward("t", 2)); - - EXPECT_EQ(5, op_desc.GetOutputsSize()); -} - -TEST_F(UtestOpDesc, AddInputDesc1_success) { - auto op_desc = std::make_shared(); - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetFormat(FORMAT_NCHW); - tensor_desc->SetDataType(DT_FLOAT); - - EXPECT_EQ(op_desc->AddInputDesc(0, tensor_desc->Clone()), GRAPH_SUCCESS); - EXPECT_EQ(op_desc->AddInputDesc(0, tensor_desc->Clone()), GRAPH_SUCCESS); -} - -TEST_F(UtestOpDesc, AddInputDesc2_success) { - auto op_desc = std::make_shared(); - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetFormat(FORMAT_NCHW); - tensor_desc->SetDataType(DT_FLOAT); - - EXPECT_EQ(op_desc->AddInputDesc("input_desc1", tensor_desc->Clone()), GRAPH_SUCCESS); - EXPECT_EQ(op_desc->AddInputDesc("input_desc1", tensor_desc->Clone()), GRAPH_SUCCESS); -} - -TEST_F(UtestOpDesc, AddInputDescMiddle_success) { - auto op_desc = std::make_shared(); - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetFormat(FORMAT_NCHW); - tensor_desc->SetDataType(DT_FLOAT); - op_desc->AddInputDesc("input_desc1", tensor_desc->Clone()); - op_desc->AddInputDesc("input_desc2", tensor_desc->Clone()); - - EXPECT_EQ(op_desc->AddInputDescMiddle("x", 2, 1), GRAPH_SUCCESS); - auto name_idx = op_desc->GetAllInputName(); - ASSERT_EQ(name_idx.size(), 4U); - EXPECT_EQ(name_idx["x0"], 1); - EXPECT_EQ(name_idx["x1"], 2); - EXPECT_EQ(name_idx["input_desc2"], 3); -} - -TEST_F(UtestOpDesc, AddOutputDescMiddle_success) { - auto op_desc = std::make_shared(); - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetFormat(FORMAT_NCHW); - tensor_desc->SetDataType(DT_FLOAT); - op_desc->AddOutputDesc("output_desc1", tensor_desc->Clone()); - op_desc->AddOutputDesc("output_desc2", tensor_desc->Clone()); - - EXPECT_EQ(op_desc->AddOutputDescMiddle("y", 2, 1), GRAPH_SUCCESS); - EXPECT_EQ(op_desc->AddOutputDescMiddle("output_desc4", 1, 5), GRAPH_FAILED); - auto name_idx = op_desc->GetAllOutputName(); - ASSERT_EQ(name_idx.size(), 4U); - EXPECT_EQ(name_idx["y0"], 1); - EXPECT_EQ(name_idx["y1"], 2); - EXPECT_EQ(name_idx["output_desc2"], 3); -} - -TEST_F(UtestOpDesc, UpdateInputDesc_success) { - auto op_desc = std::make_shared(); - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetFormat(FORMAT_NCHW); - tensor_desc->SetDataType(DT_FLOAT); - op_desc->AddInputDesc("input_desc1", tensor_desc->Clone()); - op_desc->AddInputDesc("input_desc2", tensor_desc->Clone()); - - EXPECT_EQ(op_desc->UpdateInputDesc(1, tensor_desc->Clone()), GRAPH_SUCCESS); - EXPECT_EQ(op_desc->UpdateInputDesc(4, tensor_desc->Clone()), GRAPH_FAILED); -} - -TEST_F(UtestOpDesc, UpdateInputDescForward_success) { - auto op_desc = std::make_shared(); - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetFormat(FORMAT_NCHW); - tensor_desc->SetDataType(DT_FLOAT); - op_desc->AddInputDesc("input1", tensor_desc->Clone()); - EXPECT_EQ(op_desc->AddDynamicInputDesc("x", 2, false), GRAPH_SUCCESS); - auto input_name_idx = op_desc->GetAllInputName(); - ASSERT_EQ(input_name_idx.size(), 3U); - EXPECT_EQ(input_name_idx["x0"], 0); - EXPECT_EQ(input_name_idx["x1"], 1); - EXPECT_EQ(input_name_idx["input1"], 2); -} - -TEST_F(UtestOpDesc, AddOutputDescForward_success) { - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetFormat(FORMAT_NCHW); - tensor_desc->SetDataType(DT_FLOAT); - auto op_desc = std::make_shared(); - op_desc->AddOutputDesc(tensor_desc->Clone()); - EXPECT_EQ(op_desc->AddOutputDescForward("y", 2), GRAPH_SUCCESS); - - auto output_name_idx = op_desc->GetAllOutputName(); - ASSERT_EQ(output_name_idx.size(), 3U); - EXPECT_EQ(output_name_idx["y0"], 0); - EXPECT_EQ(output_name_idx["y1"], 1); - EXPECT_EQ(output_name_idx["__output0"], 2); -} - -TEST_F(UtestOpDesc, AddOptionalInputDesc_success) { - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetFormat(FORMAT_NCHW); - tensor_desc->SetDataType(DT_FLOAT); - auto op_desc = std::make_shared(); - EXPECT_EQ(op_desc->AddOptionalInputDesc("test", tensor_desc->Clone()), GRAPH_SUCCESS); -} - -TEST_F(UtestOpDesc, OpDescMembersAreEqual_success) { - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetFormat(FORMAT_NCHW); - tensor_desc->SetDataType(DT_FLOAT); - - auto op_desc1 = std::make_shared(); - op_desc1->AddInputDesc("input_desc", tensor_desc->Clone()); - op_desc1->AddOutputDesc("output_desc", tensor_desc->Clone()); - op_desc1->AddOptionalInputDesc("optional_input", tensor_desc->Clone()); - op_desc1->SetOpEngineName("DNN_VM_HOST_CPU"); - op_desc1->SetOpKernelLibName("DNN_VM_RTS_OP_STORE"); - - auto op_desc2 = std::make_shared(); - op_desc1->AddInputDesc("input_desc_diff", tensor_desc->Clone()); - op_desc1->AddOutputDesc("output_desc", tensor_desc->Clone()); - op_desc1->AddOptionalInputDesc("optional_input", tensor_desc->Clone()); - op_desc1->SetOpEngineName("DNN_VM_HOST_CPU"); - op_desc1->SetOpKernelLibName("DNN_VM_RTS_OP_STORE"); - - auto op_desc3 = op_desc1; - - EXPECT_EQ(op_desc1->OpDescMembersAreEqual(*(op_desc3)), true); - EXPECT_EQ(op_desc1->OpDescMembersAreEqual(*(op_desc2)), false); -} - -TEST_F(UtestOpDesc, OpDescGenTensorDescsAreEqual_success) { - auto tensor_desc1 = std::make_shared(); - tensor_desc1->SetShape(GeShape({1})); - tensor_desc1->SetFormat(FORMAT_NCHW); - tensor_desc1->SetDataType(DT_FLOAT); - - auto tensor_desc2 = std::make_shared(); - tensor_desc2->SetShape(GeShape({-1})); - tensor_desc2->SetFormat(FORMAT_NHWC); - tensor_desc2->SetDataType(DT_INT32); - - auto op_desc1 = std::make_shared(); - op_desc1->AddInputDesc(tensor_desc1->Clone()); - auto op_desc2 = std::make_shared(); - EXPECT_EQ(op_desc1->OpDescGenTensorDescsAreEqual(*(op_desc2)), false); - op_desc2->AddInputDesc(tensor_desc2->Clone()); - op_desc1->AddOutputDesc(tensor_desc1->Clone()); - EXPECT_EQ(op_desc1->OpDescGenTensorDescsAreEqual(*(op_desc2)), false); - op_desc2->AddOutputDesc(tensor_desc2->Clone()); - auto op_desc3 = std::make_shared(); - EXPECT_EQ(op_desc1->OpDescGenTensorDescsAreEqual(*(op_desc2)), false); - op_desc3->AddInputDesc(tensor_desc1->Clone()); - op_desc3->AddOutputDesc(tensor_desc2->Clone()); - EXPECT_EQ(op_desc1->OpDescGenTensorDescsAreEqual(*(op_desc3)), false); - EXPECT_EQ(op_desc1->OpDescGenTensorDescsAreEqual(*(op_desc1)), true); -} - -TEST_F(UtestOpDesc, InputIsSet_success) { - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetFormat(FORMAT_NCHW); - tensor_desc->SetDataType(DT_FLOAT); - - auto op_desc = std::make_shared(); - EXPECT_EQ(op_desc->InputIsSet("input_test"), false); - op_desc->AddInputDesc("input_test",tensor_desc->Clone()); - EXPECT_EQ(op_desc->InputIsSet("input_test"), true); -} - -TEST_F(UtestOpDesc, MutableInputDesc_success) { - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetFormat(FORMAT_NCHW); - tensor_desc->SetDataType(DT_FLOAT); - - auto op_desc = std::make_shared(); - op_desc->AddInputDesc("input_test1",tensor_desc->Clone()); - EXPECT_EQ(op_desc->MutableInputDesc("input_test"), nullptr); - EXPECT_NE(op_desc->MutableInputDesc("input_test1"), nullptr); -} - -TEST_F(UtestOpDesc, Get_SetOpKernelLibName_success) { - auto op_desc = std::make_shared(); - op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE"); - EXPECT_EQ(op_desc->GetOpKernelLibName(), "DNN_VM_RTS_OP_STORE"); -} - -TEST_F(UtestOpDesc, Get_SetOpEngineName_success) { - auto op_desc = std::make_shared(); - op_desc->SetOpEngineName("DNN_VM_HOST_CPU"); - EXPECT_EQ(op_desc->GetOpEngineName(), "DNN_VM_HOST_CPU"); -} - -TEST_F(UtestOpDesc, GetAllOutputsDescSize_sucess) { - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetFormat(FORMAT_NCHW); - tensor_desc->SetDataType(DT_FLOAT); - - auto op_desc = std::make_shared(); - op_desc->AddOutputDesc(tensor_desc->Clone()); - op_desc->AddOutputDesc(tensor_desc->Clone()); - EXPECT_EQ(op_desc->GetAllOutputsDescSize(), 2); -} - -TEST_F(UtestOpDesc, AddDynamicInputDescByIndex_success) { - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetFormat(FORMAT_NCHW); - tensor_desc->SetDataType(DT_FLOAT); - - auto op_desc = std::make_shared(); - op_desc->AddInputDesc("input_test1",tensor_desc->Clone()); - op_desc->AddInputDesc("input_test2",tensor_desc->Clone()); - EXPECT_EQ(op_desc->AddDynamicInputDescByIndex("input_test2", 1, 1), GRAPH_SUCCESS); -} - -TEST_F(UtestOpDesc, IsOptionalInput_success) { - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetFormat(FORMAT_NCHW); - tensor_desc->SetDataType(DT_FLOAT); - - auto op_desc = std::make_shared(); - op_desc->AddOptionalInputDesc("optional_test", tensor_desc->Clone()); - op_desc->AddInputDesc("input_test", tensor_desc->Clone()); - EXPECT_EQ(op_desc->IsOptionalInput("input_test"), false); - EXPECT_EQ(op_desc->IsOptionalInput("optional_test"), true); -} - -TEST_F(UtestOpDesc, GetAllOutputName_success) { - auto op_desc = std::make_shared(); - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetFormat(FORMAT_NCHW); - tensor_desc->SetDataType(DT_FLOAT); - - op_desc->AddOutputDesc("output1", tensor_desc->Clone()); - op_desc->AddOutputDesc("output2", tensor_desc->Clone()); - std::map all_output; - all_output = op_desc->GetAllOutputName(); - EXPECT_EQ(all_output.size(), 2); - EXPECT_EQ(all_output["output1"], 0); - EXPECT_EQ(all_output["output2"], 1); -} - -TEST_F(UtestOpDesc, UpdateInputName_success) { - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetFormat(FORMAT_NCHW); - tensor_desc->SetDataType(DT_FLOAT); - auto op_desc = std::make_shared(); - - op_desc->AddInputDesc("name1", tensor_desc->Clone()); - op_desc->AddInputDesc("name2", tensor_desc->Clone()); - - std::map input_name_idx; - input_name_idx.insert(pair("update_name1", 0)); - EXPECT_EQ(op_desc->UpdateInputName(input_name_idx), false); - input_name_idx.insert(pair("update_name2", 1)); - EXPECT_EQ(op_desc->UpdateInputName(input_name_idx), true); - auto all_input_name = op_desc->GetAllInputName(); - EXPECT_EQ(input_name_idx, all_input_name); - input_name_idx.insert(pair("update_name3", 2)); - EXPECT_EQ(op_desc->UpdateInputName(input_name_idx), true); -} - -TEST_F(UtestOpDesc, UpdateInputOutName_with_dynamic_failed) { - auto op_desc = std::make_shared(); - op_desc->AppendIrInput("query", IrInputType::kIrInputRequired); - op_desc->AppendIrInput("k", IrInputType::kIrInputDynamic); - op_desc->AppendIrInput("value", IrInputType::kIrInputDynamic); - op_desc->AppendIrInput("padding_mask", IrInputType::kIrInputOptional); - op_desc->AppendIrInput("attention_mask", IrInputType::kIrInputOptional); - op_desc->AppendIrInput("seq_lens", IrInputType::kIrInputOptional); - op_desc->AppendIrOutput("attention_out", IrOutputType::kIrOutputDynamic); - op_desc->AppendIrOutput("fake_out", IrOutputType::kIrOutputRequired); - - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetFormat(FORMAT_NCHW); - tensor_desc->SetDataType(DT_FLOAT); - - op_desc->AddInputDesc("query", tensor_desc->Clone()); - op_desc->AddDynamicInputDescByIndex("k", 1, 1); - op_desc->UpdateInputDesc(1, tensor_desc->Clone()); - op_desc->AddDynamicInputDescByIndex("value", 1, 2); - op_desc->UpdateInputDesc(2, tensor_desc->Clone()); - - std::map input_name_idx{{"query", 0}, - {"padding_mask", 1}, - {"attention_mask", 2}, - {"seq_lens", 3}}; - EXPECT_EQ(op_desc->UpdateInputName(input_name_idx), false); -} - -TEST_F(UtestOpDesc, UpdateOutputName_success) { - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetFormat(FORMAT_NCHW); - tensor_desc->SetDataType(DT_FLOAT); - auto op_desc = std::make_shared(); - - op_desc->AddOutputDesc("name1", tensor_desc->Clone()); - op_desc->AddOutputDesc("name2", tensor_desc->Clone()); - - std::map output_name_idx; - output_name_idx.insert(pair("update_name1", 0)); - EXPECT_EQ(op_desc->UpdateOutputName(output_name_idx), false); - output_name_idx.insert(pair("update_name2", 1)); - EXPECT_EQ(op_desc->UpdateOutputName(output_name_idx), true); - auto all_output_name = op_desc->GetAllOutputName(); - EXPECT_EQ(output_name_idx, all_output_name); - output_name_idx.insert(pair("update_name3", 2)); - EXPECT_EQ(op_desc->UpdateOutputName(output_name_idx), true); -} - -TEST_F(UtestOpDesc, GetInferFunc_success) { - auto op_desc = std::make_shared(); - const auto add_func = [](Operator &op) { - return GRAPH_SUCCESS; - }; - op_desc->AddInferFunc(add_func); - - Operator op; - auto func = op_desc->GetInferFunc(); - EXPECT_EQ(func == nullptr, false); - EXPECT_EQ(func(op), GRAPH_SUCCESS); -} - -// infer from output -REG_OP(FixIOOp_OutputIsFix) - .INPUT(fix_input1, "T") - .INPUT(fix_input2, "T") - .OUTPUT(fix_output, "T2") - .DATATYPE(T2, TensorType({DT_BOOL})) - .OP_END_FACTORY_REG(FixIOOp_OutputIsFix); -TEST_F(UtestOpDesc, CallInferV2Func_success) { - auto op = OperatorFactory::CreateOperator("test1", "FixIOOp_OutputIsFix"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - GeShape shape({1,1,1,1}); - GeTensorDesc tensor_desc(shape, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc.SetOriginShape(shape); - tensor_desc.SetOriginDataType(DT_FLOAT16); - std::vector> range = {{0, 10000}}; - tensor_desc.SetOriginShapeRange(range); - op_desc->UpdateInputDesc(0, tensor_desc); - op_desc->UpdateInputDesc(1, tensor_desc); - op_desc->impl_->infer_func_ = nullptr; - auto infer_shape_func = [](const ge::Operator &op, const OpDescPtr &op_desc) -> uint32_t { - const ge::GeTensorDesc &input_desc = op_desc->GetInputDesc(0UL); - return op_desc->UpdateOutputDesc(0UL, input_desc); - }; - auto infer_shape_range_func = [](const ge::Operator &op, const OpDescPtr &op_desc) -> uint32_t { - return GRAPH_SUCCESS; - }; - auto infer_data_type_func = [](const OpDescPtr &op) -> uint32_t { - return GRAPH_SUCCESS; - }; - ge::OperatorFactoryImpl::operator_infer_shape_v2_func_ = nullptr; - ge::OperatorFactoryImpl::operator_infer_datatype_func_ = nullptr; - ge::OperatorFactoryImpl::operator_infer_shape_range_func_ = nullptr; - (void) ge::OperatorFactoryImpl::RegisterInferShapeV2Func(infer_shape_func); - (void) ge::OperatorFactoryImpl::RegisterInferShapeRangeFunc(infer_shape_range_func); - (void) ge::OperatorFactoryImpl::RegisterInferDataTypeFunc(infer_data_type_func); - auto status = OpDescUtilsEx::CallInferFunc(op_desc, op); - ASSERT_EQ(status, GRAPH_SUCCESS); - ASSERT_EQ(op_desc->GetOutputDesc(0U).GetDataType(), DT_FLOAT16); - ASSERT_EQ(op_desc->GetOutputDesc(0U).GetShape().GetDimNum(), 4); - ASSERT_EQ(op_desc->GetOutputDesc(0U).GetShape().GetDim(0), 1); - ge::OperatorFactoryImpl::operator_infer_shape_v2_func_ = nullptr; - ge::OperatorFactoryImpl::operator_infer_datatype_func_ = nullptr; - ge::OperatorFactoryImpl::operator_infer_shape_range_func_ = nullptr; -} - -TEST_F(UtestOpDesc, CallInferFunc_by_shape_value_success) { - auto op = OperatorFactory::CreateOperator("test1", "FixIOOp_OutputIsFix"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - std::vector shape_values = {1, 2, 3}; - GeShape shape(shape_values); - GeTensorDesc tensor_desc(shape, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc.SetOriginShape(shape); - tensor_desc.SetOriginDataType(DT_FLOAT16); - std::vector> range = {{0, 10000}}; - tensor_desc.SetOriginShapeRange(range); - op_desc->UpdateInputDesc(0, tensor_desc); - op_desc->UpdateInputDesc(1, tensor_desc); - op_desc->impl_->infer_func_ = nullptr; - auto infer_shape_func = [](const ge::Operator &op, const OpDescPtr &op_desc) -> uint32_t { - const ge::GeTensorDesc &input_desc = op_desc->GetInputDesc(0UL); - return op_desc->UpdateOutputDesc(0UL, input_desc); - }; - auto infer_shape_range_func = [](const ge::Operator &op, const OpDescPtr &op_desc) -> uint32_t { - return GRAPH_SUCCESS; - }; - auto infer_data_type_func = [](const OpDescPtr &op) -> uint32_t { - return GRAPH_SUCCESS; - }; - ge::OperatorFactoryImpl::operator_infer_shape_v2_func_ = nullptr; - ge::OperatorFactoryImpl::operator_infer_datatype_func_ = nullptr; - ge::OperatorFactoryImpl::operator_infer_shape_range_func_ = nullptr; - (void) ge::OperatorFactoryImpl::RegisterInferShapeV2Func(infer_shape_func); - (void) ge::OperatorFactoryImpl::RegisterInferShapeRangeFunc(infer_shape_range_func); - (void) ge::OperatorFactoryImpl::RegisterInferDataTypeFunc(infer_data_type_func); - EXPECT_EQ(AttrUtils::SetListInt(op_desc, "_output_shapes", shape_values), true); - auto status = OpDescUtilsEx::CallInferFunc(op_desc, op); - ASSERT_EQ(status, GRAPH_SUCCESS); - ASSERT_EQ(op_desc->GetOutputDesc(0U).GetDataType(), DT_FLOAT16); - constexpr int32_t true_dim_num = 3; - ASSERT_EQ(op_desc->GetOutputDesc(0U).GetShape().GetDimNum(), true_dim_num); - ASSERT_EQ(op_desc->GetOutputDesc(0U).GetShape().GetDim(0), 1); - ge::OperatorFactoryImpl::operator_infer_shape_v2_func_ = nullptr; - ge::OperatorFactoryImpl::operator_infer_datatype_func_ = nullptr; - ge::OperatorFactoryImpl::operator_infer_shape_range_func_ = nullptr; - EXPECT_EQ(op_desc->MutableOutputDesc(0)->GetShape().GetDims(), shape_values); -} - -TEST_F(UtestOpDesc, CallInferFunc_by_shape_value_unknown_shape_success) { - auto op = OperatorFactory::CreateOperator("test1", "FixIOOp_OutputIsFix"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - GeShape shape({1, 2, -1}); - GeTensorDesc tensor_desc(shape, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc.SetOriginShape(shape); - tensor_desc.SetOriginDataType(DT_FLOAT16); - std::vector> range = {{0, 10000}}; - tensor_desc.SetOriginShapeRange(range); - op_desc->UpdateInputDesc(0, tensor_desc); - op_desc->UpdateInputDesc(1, tensor_desc); - op_desc->impl_->infer_func_ = nullptr; - auto infer_shape_func = [](const ge::Operator &op, const OpDescPtr &op_desc) -> uint32_t { - const ge::GeTensorDesc &input_desc = op_desc->GetInputDesc(0UL); - return op_desc->UpdateOutputDesc(0UL, input_desc); - }; - auto infer_shape_range_func = [](const ge::Operator &op, const OpDescPtr &op_desc) -> uint32_t { - return GRAPH_SUCCESS; - }; - auto infer_data_type_func = [](const OpDescPtr &op) -> uint32_t { - return GRAPH_SUCCESS; - }; - ge::OperatorFactoryImpl::operator_infer_shape_v2_func_ = nullptr; - ge::OperatorFactoryImpl::operator_infer_datatype_func_ = nullptr; - ge::OperatorFactoryImpl::operator_infer_shape_range_func_ = nullptr; - (void) ge::OperatorFactoryImpl::RegisterInferShapeV2Func(infer_shape_func); - (void) ge::OperatorFactoryImpl::RegisterInferShapeRangeFunc(infer_shape_range_func); - (void) ge::OperatorFactoryImpl::RegisterInferDataTypeFunc(infer_data_type_func); - std::vector> shape_values = {{1, 2, 3}}; - EXPECT_EQ(AttrUtils::SetListListInt(op_desc, "_preset_output_shapes", shape_values), true); - auto status = OpDescUtilsEx::CallInferFunc(op_desc, op); - constexpr int32_t true_dim_num = 3; - ASSERT_EQ(status, GRAPH_SUCCESS); - ASSERT_EQ(op_desc->GetOutputDesc(0U).GetDataType(), DT_FLOAT16); - ASSERT_EQ(op_desc->GetOutputDesc(0U).GetShape().GetDimNum(), true_dim_num); - ASSERT_EQ(op_desc->GetOutputDesc(0U).GetShape().GetDim(0), 1); - ge::OperatorFactoryImpl::operator_infer_shape_v2_func_ = nullptr; - ge::OperatorFactoryImpl::operator_infer_datatype_func_ = nullptr; - ge::OperatorFactoryImpl::operator_infer_shape_range_func_ = nullptr; - EXPECT_EQ(op_desc->MutableOutputDesc(0)->GetShape().GetDims(), shape_values.at(0)); -} - -TEST_F(UtestOpDesc, CallInferV2Func_no_inferfunc_failed) { - auto op = OperatorFactory::CreateOperator("test1", "FixIOOp_OutputIsFix"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - GeShape shape({1,1,1,1}); - GeTensorDesc tensor_desc(shape, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc.SetOriginShape(shape); - tensor_desc.SetOriginDataType(DT_FLOAT16); - std::vector> range = {{0, 10000}}; - tensor_desc.SetOriginShapeRange(range); - op_desc->UpdateInputDesc(0, tensor_desc); - op_desc->UpdateInputDesc(1, tensor_desc); - op_desc->impl_->infer_func_ = nullptr; // make v1 is null - - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.infer_shape = nullptr; // make v2 is null - op_impl_func.infer_datatype = nullptr; - op_impl_func.infer_shape_range = nullptr; - registry_holder->AddTypesToImpl("FixIOOp_OutputIsFix", op_impl_func); - space_registry->AddRegistry(registry_holder); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - auto status = OpDescUtilsEx::CallInferFunc(op_desc, op); - ASSERT_EQ(status, GRAPH_FAILED); -} - -TEST_F(UtestOpDesc, CallInferV2Func_failed) { - auto op = OperatorFactory::CreateOperator("test1", "FixIOOp_OutputIsFix"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - GeShape shape({1,1,1,1}); - GeTensorDesc tensor_desc(shape, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc.SetOriginShape(shape); - tensor_desc.SetOriginDataType(DT_FLOAT16); - std::vector> range = {{0, 10000}}; - tensor_desc.SetOriginShapeRange(range); - op_desc->UpdateInputDesc(0, tensor_desc); - op_desc->UpdateInputDesc(1, tensor_desc); - op_desc->impl_->infer_func_ = nullptr; - - auto infer_shape_func = [](const ge::Operator &op, const OpDescPtr &op_desc) -> uint32_t { - return GRAPH_FAILED; - }; - auto infer_shape_range_func = [](const ge::Operator &op, const OpDescPtr &op_desc) -> uint32_t { - return GRAPH_SUCCESS; - }; - auto infer_data_type_func = [](const OpDescPtr &op) -> uint32_t { - return GRAPH_SUCCESS; - }; - (void) ge::OperatorFactoryImpl::RegisterInferShapeV2Func(infer_shape_func); - (void) ge::OperatorFactoryImpl::RegisterInferShapeRangeFunc(infer_shape_range_func); - (void) ge::OperatorFactoryImpl::RegisterInferDataTypeFunc(infer_data_type_func); - - auto status = OpDescUtilsEx::CallInferFunc(op_desc, op); - ASSERT_EQ(status, PARAM_INVALID); - ge::OperatorFactoryImpl::operator_infer_shape_v2_func_ = nullptr; - ge::OperatorFactoryImpl::operator_infer_datatype_func_ = nullptr; - ge::OperatorFactoryImpl::operator_infer_shape_range_func_ = nullptr; -} - -TEST_F(UtestOpDesc, CallInferFunc_failed) { - OpDescImpl op_desc_impl; - Operator op; - OpDescPtr op_desc; - OpDescUtilsEx::CallInferFunc(op_desc, op); - const auto func = [](Operator &op) { return GRAPH_SUCCESS; }; - op_desc_impl.infer_func_ = func; - auto status = OpDescUtilsEx::CallInferFunc(op_desc, op); - ASSERT_EQ(status, PARAM_INVALID); - const auto infer_data_slice_func = [](Operator &op) { - return GRAPH_SUCCESS; - }; - - OpDescPtr odp = std::make_shared("name", "type"); - op_desc_impl.infer_func_ = infer_data_slice_func; - status = OpDescUtilsEx::CallInferFunc(odp, op); - ASSERT_NE(status, GRAPH_SUCCESS); //todo: check testcase - - const auto error_infer_shape_func = [](Operator &op) { - return GRAPH_FAILED; - }; - odp->AddInputDesc(GeTensorDesc()); - odp->AddInferFunc(error_infer_shape_func); - status = OpDescUtilsEx::CallInferFunc(odp, op); - ASSERT_EQ(status, GRAPH_FAILED); -} - -TEST_F(UtestOpDesc, InferDataSlice_success) { - auto op_desc = std::make_shared(); - EXPECT_EQ(OpDescUtilsEx::InferDataSlice(op_desc), NO_DEPENDENCE_FUNC); - const auto infer_data_slice_func = [](Operator &op) { - return GRAPH_SUCCESS; - }; - auto op = std::make_shared(); - op_desc->SetType("test"); - OperatorFactoryImpl::RegisterInferDataSliceFunc("test",infer_data_slice_func); - EXPECT_EQ(OpDescUtilsEx::InferDataSlice(op_desc), GRAPH_SUCCESS); -} - -REG_OP(MatMulUt) - .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .ATTR(transpose_x1, Bool, false) - .ATTR(transpose_x2, Bool, false) - .OP_END_FACTORY_REG(MatMulUt) - -REG_OP(AddUt) - .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .OP_END_FACTORY_REG(AddUt) - -TEST_F(UtestOpDesc, SetTypeModifyIrAttrName_type_change) { - auto op = ge::OperatorFactory::CreateOperator("MatMul", "MatMulUt"); - auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(op); - EXPECT_NE(op_desc, nullptr); - EXPECT_FALSE(op_desc->GetIrAttrNames().empty()); - EXPECT_FALSE(op_desc->GetIrInputs().empty()); - op_desc->SetType("AddUt"); - - auto add_op = ge::OperatorFactory::CreateOperator("add", "AddUt"); - auto add_op_desc = ge::OpDescUtils::GetOpDescFromOperator(op); - EXPECT_TRUE(op_desc->GetIrAttrNames() == add_op_desc->GetIrAttrNames()); - EXPECT_TRUE(op_desc->GetIrInputs() == add_op_desc->GetIrInputs()); -} - -TEST_F(UtestOpDesc, SetTypeModifyIrAttrName_type_not_exist_clear) { - auto op = ge::OperatorFactory::CreateOperator("MatMul", "MatMul"); - auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(op); - EXPECT_NE(op_desc, nullptr); - EXPECT_FALSE(op_desc->GetIrAttrNames().empty()); - EXPECT_FALSE(op_desc->GetIrInputs().empty()); - - OpDescUtilsEx::SetType(op_desc, "NotExist"); - EXPECT_TRUE(op_desc->GetIrAttrNames().empty()); - EXPECT_TRUE(op_desc->GetIrInputs().empty()); -} - -TEST_F(UtestOpDesc, SetTypeModifyIrAttrName_type_not_change) { - auto op = ge::OperatorFactory::CreateOperator("MatMul", "MatMulUt"); - auto op_desc = ge::OpDescUtils::GetOpDescFromOperator(op); - EXPECT_NE(op_desc, nullptr); - auto &check_ir_attr = op_desc->GetIrAttrNames(); - auto &check_ir_inputs = op_desc->GetIrInputs(); - EXPECT_FALSE(op_desc->GetIrAttrNames().empty()); - EXPECT_FALSE(op_desc->GetIrInputs().empty()); - - op_desc->SetType("MatMulUt"); - EXPECT_TRUE(op_desc->GetIrAttrNames() == check_ir_attr); - EXPECT_TRUE(op_desc->GetIrInputs() == check_ir_inputs); -} - -TEST_F(UtestOpDesc, InferShapeAndType_success) { - auto op_desc = std::make_shared(); - EXPECT_EQ(OpDescUtilsEx::InferShapeAndType(op_desc), GRAPH_SUCCESS); - const auto add_func = [](Operator &op) { - return GRAPH_SUCCESS; - }; - op_desc->AddInferFunc(add_func); - EXPECT_EQ(OpDescUtilsEx::InferShapeAndType(op_desc), GRAPH_SUCCESS); -} - -TEST_F(UtestOpDesc, OpVerify_success) { - auto op_desc = std::make_shared(); - EXPECT_EQ(OpDescUtilsEx::OpVerify(op_desc), GRAPH_SUCCESS); - const auto verify_func = [](Operator &op) { - return GRAPH_SUCCESS; - }; - op_desc->AddVerifierFunc(verify_func); - EXPECT_EQ(OpDescUtilsEx::OpVerify(op_desc), GRAPH_SUCCESS); -} - -TEST_F(UtestOpDesc, GetValidInputNameByIndex_success) { - auto op_desc = std::make_shared("verify", "Rule"); - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1})); - tensor_desc->SetFormat(FORMAT_NCHW); - tensor_desc->SetDataType(DT_FLOAT); - - op_desc->AddInputDesc("name1", tensor_desc->Clone()); - op_desc->AddInputDesc("name2", tensor_desc->Clone()); - - EXPECT_EQ(op_desc->GetValidInputNameByIndex(0), "name1"); - EXPECT_EQ(op_desc->GetValidInputNameByIndex(1), "name2"); -} - -TEST_F(UtestOpDesc, GetStreamId_success) { - auto op_desc = std::make_shared(); - op_desc->SetStreamId(1); - EXPECT_EQ(op_desc->GetStreamId(), 1); -} - -TEST_F(UtestOpDesc, AttachedStreamId) { - OpDescPtr op_desc_null = nullptr; - op_desc_null->SetAttachedStreamId(2); - - auto op_desc = std::make_shared(); - EXPECT_EQ(op_desc->GetAttachedStreamId(), -1); // default is -1 - EXPECT_FALSE(op_desc->HasValidAttachedStreamId()); - - op_desc->SetAttachedStreamId(2); - EXPECT_EQ(op_desc->GetAttachedStreamId(), 2); - EXPECT_TRUE(op_desc->HasValidAttachedStreamId()); - op_desc->SetAttachedStreamId(-1); // reset to invalid - EXPECT_FALSE(op_desc->HasValidAttachedStreamId()); -} - -TEST_F(UtestOpDesc, AttachedStreamIds) { - OpDescPtr op_desc_null = nullptr; - op_desc_null->SetAttachedStreamIds({2}); - - auto op_desc = std::make_shared(); - EXPECT_EQ(op_desc->GetAttachedStreamIds().size(), 0); // default size is 0 - EXPECT_FALSE(op_desc->HasValidAttachedStreamId()); - - op_desc->SetAttachedStreamIds({2, 3}); - EXPECT_EQ(op_desc->GetAttachedStreamIds().size(), 0); - - std::vector attached_stream_infos(3); - AttrUtils::SetListNamedAttrs(op_desc, ATTR_NAME_ATTACHED_STREAM_INFO_LIST, attached_stream_infos); - EXPECT_EQ(op_desc->GetAttachedStreamIds().size(), 3); - EXPECT_FALSE(op_desc->HasValidAttachedStreamId()); - op_desc->SetAttachedStreamIds({2, 3, 4}); - EXPECT_EQ(op_desc->GetAttachedStreamIds().size(), 3); - EXPECT_EQ(op_desc->GetAttachedStreamIds()[2], 4); - EXPECT_TRUE(op_desc->HasValidAttachedStreamId()); - - op_desc->SetAttachedStreamIds({-1}); // 设置失败,所以下一行校验会成功 - EXPECT_TRUE(op_desc->HasValidAttachedStreamId()); - - op_desc->SetAttachedStreamIds({-1, -1, -1}); - EXPECT_FALSE(op_desc->HasValidAttachedStreamId()); -} - -TEST_F(UtestOpDesc, Set_GetInputName_success) { - auto op_desc = std::make_shared(); - std::vector input_name {"name1", "name2"}; - op_desc->SetInputName(input_name); - auto get_input_name = op_desc->GetInputName(); - EXPECT_EQ(get_input_name.size(), 2); - EXPECT_EQ(get_input_name[0], "name1"); - EXPECT_EQ(get_input_name[1], "name2"); -} - -TEST_F(UtestOpDesc, GetSrcName_success) { - auto op_desc = std::make_shared(); - std::vector src_name {"src"}; - op_desc->SetSrcName(src_name); - auto get_src_name = op_desc->GetSrcName(); - EXPECT_EQ(get_src_name.size(), 1); - EXPECT_EQ(get_src_name[0], "src"); -} - -TEST_F(UtestOpDesc, GetSrcIndex_success) { - auto op_desc = std::make_shared(); - std::vector src_index{2}; - op_desc->SetSrcIndex(src_index); - auto get_src_index = op_desc->GetSrcIndex(); - EXPECT_EQ(get_src_index.size(), 1); - EXPECT_EQ(get_src_index[0], 2); -} - -TEST_F(UtestOpDesc, GetInputOffset_success) { - auto op_desc = std::make_shared(); - std::vector input_offset{987654321}; - op_desc->SetInputOffset(input_offset); - auto get_input_offset = op_desc->GetInputOffset(); - EXPECT_EQ(get_input_offset.size(), 1); - EXPECT_EQ(get_input_offset[0], 987654321); -} - -TEST_F(UtestOpDesc, GetOutputOffset_success) { - auto op_desc = std::make_shared(); - std::vector output_offset{987654321}; - op_desc->SetOutputOffset(output_offset); - auto get_output_offset = op_desc->GetOutputOffset(); - EXPECT_EQ(get_output_offset.size(), 1); - EXPECT_EQ(get_output_offset[0], 987654321); -} - -TEST_F(UtestOpDesc, GetDstName_success) { - auto op_desc = std::make_shared(); - std::vector dst_name{"dst"}; - op_desc->SetDstName(dst_name); - auto get_dst_name = op_desc->GetDstName(); - EXPECT_EQ(get_dst_name.size(), 1); - EXPECT_EQ(get_dst_name[0], "dst"); -} - -TEST_F(UtestOpDesc, Set_GetOpInferDepends_success) { - auto op_desc = std::make_shared("verify", "Rule"); - std::vector depend_names {"depend_name1", "depend_name2"}; - op_desc->SetOpInferDepends(depend_names); - auto get_depend_names = op_desc->GetOpInferDepends(); - EXPECT_EQ(get_depend_names.size(), 2); - EXPECT_EQ(get_depend_names[0], "depend_name1"); - EXPECT_EQ(get_depend_names[1], "depend_name2"); -} - -TEST_F(UtestOpDesc, GetWorkspace_success) { - auto op_desc = std::make_shared(); - std::vector workspace{222}; - op_desc->SetWorkspace(workspace); - auto get_workspace = op_desc->GetWorkspace(); - EXPECT_EQ(get_workspace.size(), 1); - EXPECT_EQ(get_workspace[0], 222); -} - -TEST_F(UtestOpDesc, GetSubgraphNameByInstanceName_success) { - auto op_desc = std::make_shared(); - op_desc->AddSubgraphName("subgraph"); - op_desc->SetSubgraphInstanceName(0, "subgraph"); - std::string subname(""); - EXPECT_EQ(op_desc->GetSubgraphNameByInstanceName("subgraph", subname), GRAPH_SUCCESS); - EXPECT_EQ(subname, "subgraph"); - - auto op_desc1 = std::make_shared(); - op_desc1->AddSubgraphName("subgraph1"); - op_desc1->SetSubgraphInstanceName(0, "sub"); - EXPECT_EQ(op_desc1->GetSubgraphNameByInstanceName("sub", subname), GRAPH_SUCCESS); - EXPECT_EQ(subname, "subgraph1"); -} - -TEST_F(UtestOpDesc, GetTilingInfo) { - auto op_desc = std::make_shared(); - EXPECT_NE(op_desc, nullptr); - EXPECT_EQ(op_desc->GetTilingFuncInfo(), nullptr); - EXPECT_EQ(op_desc->GetAtomicTilingFuncInfo(), nullptr); - - ::optiling::OpTilingFuncInfo tiling_info, atomic_tiling_info; - op_desc->SetTilingFuncInfo(&tiling_info); - op_desc->SetAtomicTilingFuncInfo(&atomic_tiling_info); - EXPECT_EQ(op_desc->GetTilingFuncInfo(), &tiling_info); - EXPECT_EQ(op_desc->GetAtomicTilingFuncInfo(), &atomic_tiling_info); -} - -TEST_F(UtestOpDesc, CopyAssignTest) { - auto op_desc = std::make_shared(); - EXPECT_NE(op_desc, nullptr); - op_desc->SetType("Test"); - OpDescImpl op_desc_impl; - op_desc_impl = *(op_desc->impl_); - EXPECT_EQ(op_desc_impl.GetType(), op_desc->GetType()); - // same object - auto fake = &op_desc_impl; - op_desc_impl = *fake; - EXPECT_EQ(op_desc_impl.GetType(), op_desc->GetType()); -} - -TEST_F(UtestOpDesc, GetDynamicInputIndexesByName_Failed) { - auto op = OperatorFactory::CreateOperator("test1", "FixIOOp_OutputIsFix"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - op_desc->impl_->input_name_idx_ = {{"query0", 0}, {"query1", 10}, {"query2", 2}, {"query3", 3}}; - - std::vector indexes; - EXPECT_EQ(op_desc->GetDynamicInputIndexesByName("query", indexes), GRAPH_FAILED); - EXPECT_EQ(indexes.size(), 1); - EXPECT_EQ(indexes[0], 0); -} - -TEST_F(UtestOpDesc, GetDynamicInputIndexesByName_success) { - auto op = OperatorFactory::CreateOperator("test1", "FixIOOp_OutputIsFix"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - op_desc->impl_->input_name_idx_ = {{"query0", 0}, {"query1", 1}, {"query2", 2}, {"query3", 3}}; - - std::vector indexes; - EXPECT_EQ(op_desc->GetDynamicInputIndexesByName("query", indexes), GRAPH_SUCCESS); - EXPECT_EQ(indexes.size(), 4); - EXPECT_EQ(indexes[0], 0); - EXPECT_EQ(indexes[1], 1); - EXPECT_EQ(indexes[2], 2); - EXPECT_EQ(indexes[3], 3); -} - -TEST_F(UtestOpDesc, GetDynamicOutputIndexesByName_Failed) { - auto op = OperatorFactory::CreateOperator("test1", "FixIOOp_OutputIsFix"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - op_desc->impl_->output_name_idx_ = {{"query0", 0}, {"query1", 10}, {"query2", 2}, {"query3", 3}}; - - std::vector indexes; - EXPECT_EQ(op_desc->GetDynamicOutputIndexesByName("query", indexes), GRAPH_FAILED); - EXPECT_EQ(indexes.size(), 1); - EXPECT_EQ(indexes[0], 0); -} - -TEST_F(UtestOpDesc, GetDynamicOutputIndexesByName_success) { - auto op = OperatorFactory::CreateOperator("test1", "FixIOOp_OutputIsFix"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - op_desc->impl_->output_name_idx_ = {{"query0", 0}, {"query1", 1}, {"query2", 2}, {"query3", 3}}; - - std::vector indexes; - EXPECT_EQ(op_desc->GetDynamicOutputIndexesByName("query", indexes), GRAPH_SUCCESS); - EXPECT_EQ(indexes.size(), 4); - EXPECT_EQ(indexes[0], 0); - EXPECT_EQ(indexes[1], 1); - EXPECT_EQ(indexes[2], 2); - EXPECT_EQ(indexes[3], 3); -} -TEST_F(UtestOpDesc, CallInferFunc_frameworkop_skip_infer) { - auto op = OperatorFactory::CreateOperator("test1", "FixIOOp_OutputIsFix"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - op_desc->SetType("FrameworkOp"); - EXPECT_EQ(OpDescUtilsEx::CallInferFunc(op_desc, op), GRAPH_PARAM_INVALID); -} - -TEST_F(UtestOpDesc, GetAllOutputIndexToName_success) { - auto op = OperatorFactory::CreateOperator("test1", "FixIOOp_OutputIsFix"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - op_desc->MutableAllOutputName().clear(); - op_desc->MutableAllOutputName().emplace("test1", 0); - op_desc->MutableAllOutputName().emplace("test2", 1); - op_desc->MutableAllOutputName().emplace("test3", 2); - std::map expect_map{{0, "test1"}, {1, "test2"}, {2, "test3"}}; - auto map1 = op_desc->GetAllOutputIndexToName(); - EXPECT_EQ(expect_map, map1); -} - -TEST_F(UtestOpDesc, TestNodeShapeTransUtils_UpdateFormatAndShape) { - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape({1, 1, 16, 16})); - tensor_desc->SetFormat(FORMAT_ND); - tensor_desc->SetDataType(DT_FLOAT); - tensor_desc->SetOriginFormat(FORMAT_NCHW); - - auto op_desc = std::make_shared("test", "Identity"); - op_desc->AddOutputDesc(tensor_desc->Clone()); - NodeShapeTransUtils transformer2(op_desc); - EXPECT_EQ(transformer2.Init(), true); - EXPECT_EQ(transformer2.CatchFormatAndShape(), true); - tensor_desc->SetFormat(FORMAT_NCHW); - op_desc->UpdateOutputDesc(0, tensor_desc->Clone()); - EXPECT_EQ(transformer2.UpdateFormatAndShape(), true); -} - -REG_OP(phony_op_with_subgraphs) - .INPUT(x, "T") - .DYNAMIC_OUTPUT(output, TensorType::ALL()) - .GRAPH(static_graph) - .DYNAMIC_GRAPH(dynamic_graph) - .OP_END_FACTORY_REG(phony_op_with_subgraphs); - -TEST_F(UtestOpDesc, TestGetOrderedSubgraphs) { - auto op = OperatorFactory::CreateOperator("test_get_ordered_subgraph_name", "phony_op_with_subgraphs"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - auto order_subgraphs = op_desc->GetOrderedSubgraphIrNames(); - - EXPECT_EQ("static_graph", order_subgraphs[0].first); - EXPECT_EQ(kStatic, order_subgraphs[0].second); - EXPECT_EQ("dynamic_graph", order_subgraphs[1].first); - EXPECT_EQ(kDynamic, order_subgraphs[1].second); - - auto subgraphs = op_desc->GetSubgraphIrNames(); - std::pair subgraph_pair[2]; - int64_t idx = 0; - for (const auto &subgraph : subgraphs) { - subgraph_pair[idx++] = subgraph; - } - EXPECT_EQ("dynamic_graph", subgraph_pair[0].first); - EXPECT_EQ(kDynamic, subgraph_pair[0].second); - EXPECT_EQ("static_graph", subgraph_pair[1].first); - EXPECT_EQ(kStatic, subgraph_pair[1].second); -} -} diff --git a/tests/ut/graph/testcase/op_desc_utils_unittest.cc b/tests/ut/graph/testcase/op_desc_utils_unittest.cc deleted file mode 100644 index e4d0ca46b1f07386360f07e5816c80378aee8624..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/op_desc_utils_unittest.cc +++ /dev/null @@ -1,1397 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include "graph/utils/op_desc_utils.h" -#include "graph_builder_utils.h" -#include "graph/utils/constant_utils.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/normal_graph/op_desc_impl.h" -#include "graph/normal_graph/node_impl.h" -#include "graph/debug/ge_op_types.h" -#include "graph/runtime_inference_context.h" -#include "graph/utils/node_utils.h" -#include "graph/utils/anchor_utils.h" -#include "test_std_structs.h" -#include "external/graph/operator_reg.h" -#include "common/ge_common/debug/ge_log.h" -#include "common/util/mem_utils.h" - -namespace ge { -class UtestOpDescUtils : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - - -namespace { -/// Data const1 -/// \ / -/// addn -/// -ComputeGraphPtr BuildGraph1() { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data = builder.AddNode("Data", "Data", 1, 1); - auto const1 = builder.AddNode("const1", "Const", 1, 1); - auto addn = builder.AddNode("addn", "AddN", 2, 1); - - int32_t weight[1] = {1}; - GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32); - GeTensorPtr tensor0 = std::make_shared(weight_desc, (uint8_t *)weight, sizeof(weight)); - OpDescUtils::SetWeights(const1, {tensor0}); - - builder.AddDataEdge(data, 0, addn, 0); - builder.AddDataEdge(const1, 0, addn, 1); - return builder.GetGraph(); -} -/// (p_const)addn const1 -/// / \ / -/// cast mul -/// -ComputeGraphPtr BuildGraph2() { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto addn = builder.AddNode("addn", "AddN", 0, 2); - auto const1 = builder.AddNode("const1", "Const", 0, 1); - auto cast = builder.AddNode("cast", "Cast", 1, 1); - auto mul = builder.AddNode("mul", "Mul", 2, 1); - - int32_t weight[1] = {1}; - GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32); - GeTensorPtr tensor0 = std::make_shared(weight_desc, (uint8_t *)weight, sizeof(weight)); - AttrUtils::SetBool(addn->GetOpDesc(), ATTR_NAME_POTENTIAL_CONST, true); - AttrUtils::SetListInt(addn->GetOpDesc(), ATTR_NAME_POTENTIAL_WEIGHT_INDICES, {0,1}); - AttrUtils::SetListTensor(addn->GetOpDesc(), ATTR_NAME_POTENTIAL_WEIGHT, {tensor0, tensor0}); - OpDescUtils::SetWeights(const1, {tensor0}); - - builder.AddDataEdge(addn, 0, cast, 0); - builder.AddDataEdge(addn, 1, mul, 0); - builder.AddDataEdge(const1, 0, mul, 1); - return builder.GetGraph(); -} -/// (p_const)addn const1 -/// / \ / -/// enter mul -/// | -/// cast -ComputeGraphPtr BuildGraph3() { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto addn = builder.AddNode("addn", "AddN", 0, 2); - auto const1 = builder.AddNode("const1", "Const", 0, 1); - auto enter = builder.AddNode("enter", "Enter", 1, 1); - auto cast = builder.AddNode("cast", "Cast", 1, 1); - auto mul = builder.AddNode("mul", "Mul", 2, 1); - - int32_t weight[1] = {1}; - GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32); - GeTensorPtr tensor0 = std::make_shared(weight_desc, (uint8_t *)weight, sizeof(weight)); - AttrUtils::SetBool(addn->GetOpDesc(), ATTR_NAME_POTENTIAL_CONST, true); - AttrUtils::SetListInt(addn->GetOpDesc(), ATTR_NAME_POTENTIAL_WEIGHT_INDICES, {0,1}); - AttrUtils::SetListTensor(addn->GetOpDesc(), ATTR_NAME_POTENTIAL_WEIGHT, {tensor0, tensor0}); - OpDescUtils::SetWeights(const1, {tensor0}); - - AttrUtils::SetBool(enter->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, true); - - builder.AddDataEdge(addn, 0, enter, 0); - builder.AddDataEdge(addn, 1, mul, 0); - builder.AddDataEdge(const1, 0, mul, 1); - builder.AddDataEdge(enter, 0, cast, 0); - return builder.GetGraph(); -} - -/// x0 a bias b -/// \ \ / / -/// DynamicOpUt -/// -ComputeGraphPtr BuildGraph4(size_t dynamic_input_num, bool has_optional_input) { - size_t optional_input_num = has_optional_input ? 1u : 0U; - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - - auto data2 = builder.AddNode("a", "Data", 1, 1); - auto data4 = builder.AddNode("b", "Data", 1, 1); - auto dynamic_op_ut = builder.AddNode("dynamic_op_ut", "DynamicOpUt", - 2 + dynamic_input_num + optional_input_num, 1); - - size_t dst_index = 0; - // dynamic input - for (size_t i = 0U; i < dynamic_input_num; ++i) { - auto data1 = builder.AddNode("x", "Data", 1, 1); - builder.AddDataEdge(data1, 0, dynamic_op_ut, dst_index++); - } - - // required input - builder.AddDataEdge(data2, 0, dynamic_op_ut, dst_index++); - - // optional input - for (size_t i = 0U; i < optional_input_num; ++i) { - auto data3 = builder.AddNode("bias", "Data", 1, 1); - builder.AddDataEdge(data3, 0, dynamic_op_ut, dst_index++); - } - - // required input - builder.AddDataEdge(data4, 0, dynamic_op_ut, dst_index++); - - auto graph = builder.GetGraph(); - auto dynamic_op_ut_node = graph->FindNode("dynamic_op_ut"); - auto op_desc = dynamic_op_ut_node->GetOpDesc(); - op_desc->AppendIrInput("x", kIrInputDynamic); - op_desc->AppendIrInput("a", kIrInputRequired); - op_desc->AppendIrInput("bias", kIrInputOptional); - op_desc->AppendIrInput("b", kIrInputRequired); - return graph; -} -/// Data -/// | -/// | ctrl_edge -/// noop -/// -ComputeGraphPtr BuildGraph5() { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data = builder.AddNode("Data", "Data", 1, 1); - auto noop = builder.AddNode("noop", "NoOp", 1, 0); - - builder.AddControlEdge(data, noop); - return builder.GetGraph(); -} -} -TEST_F(UtestOpDescUtils, SetWeight) { - auto graph = BuildGraph1(); - - auto addn_node = graph->FindNode("addn"); - ge::GeTensorPtr tensor = std::make_shared(); - std::vector value{1, 2, 3}; - std::vector shape{3}; - tensor->MutableTensorDesc().SetShape(GeShape(shape)); - tensor->SetData(value); - tensor->MutableTensorDesc().SetDataType(DT_UINT8); - - map weight0; - weight0[-1] = tensor; - auto ret = ge::OpDescUtils::SetWeights(*addn_node, weight0); - EXPECT_NE(ret, 0); - - map weight1; - weight1[1] = tensor; - ret = ge::OpDescUtils::SetWeights(*addn_node, weight1); - EXPECT_EQ(ret, 0); - auto const_node = graph->FindNode("const1"); - auto const_tensor = OpDescUtils::MutableWeights(const_node); - EXPECT_EQ(const_tensor[0]->MutableData().size(), 3); - auto in_nodes = addn_node->GetInAllNodes(); - EXPECT_EQ(in_nodes.size(), 2); - - map weight2; - weight2[2] = tensor; - ret = ge::OpDescUtils::SetWeights(*addn_node, weight2); - EXPECT_EQ(ret, 0); - auto in_nodes1 = addn_node->GetInAllNodes(); - EXPECT_EQ(in_nodes1.size(), 3); -} - -TEST_F(UtestOpDescUtils, GetRealConstInputNodeAndAnchor) { - auto graph = BuildGraph1(); - auto add_node = graph->FindNode("addn"); - auto nodes_2_out_anchor = OpDescUtils::GetConstInputNodeAndAnchor(*add_node); - EXPECT_EQ(nodes_2_out_anchor.size(), 1); - EXPECT_EQ(nodes_2_out_anchor[0].first->GetName(), "const1"); - EXPECT_EQ(nodes_2_out_anchor[0].second->GetIdx(), 0); -} -TEST_F(UtestOpDescUtils, GetMixConstInputNodeAndAnchor) { - auto graph = BuildGraph2(); - auto mul_node = graph->FindNode("mul"); - auto nodes_2_out_anchor = OpDescUtils::GetConstInputNodeAndAnchor(*mul_node); - EXPECT_EQ(nodes_2_out_anchor.size(), 2); - EXPECT_EQ(nodes_2_out_anchor[0].first->GetName(), "addn"); - EXPECT_EQ(nodes_2_out_anchor[0].second->GetIdx(), 1); - EXPECT_EQ(nodes_2_out_anchor[1].first->GetName(), "const1"); - EXPECT_EQ(nodes_2_out_anchor[1].second->GetIdx(), 0); -} -TEST_F(UtestOpDescUtils, GetInputDataByIndexForMixInputConst) { - auto graph = BuildGraph2(); - auto mul_node = graph->FindNode("mul"); - auto nodes_2_out_anchor = OpDescUtils::GetConstInputNodeAndAnchor(*mul_node); - EXPECT_EQ(nodes_2_out_anchor.size(), 2); - EXPECT_EQ(nodes_2_out_anchor[0].first->GetName(), "addn"); - EXPECT_EQ(nodes_2_out_anchor[0].second->GetIdx(), 1); - EXPECT_EQ(nodes_2_out_anchor[1].first->GetName(), "const1"); - EXPECT_EQ(nodes_2_out_anchor[1].second->GetIdx(), 0); - - auto weights = OpDescUtils::GetWeightsFromNodes(nodes_2_out_anchor); - EXPECT_EQ(weights.size(), 2); - EXPECT_EQ(weights[0]->GetTensorDesc().GetDataType(), DT_INT32); - EXPECT_EQ(weights[1]->GetTensorDesc().GetDataType(), DT_INT32); -} -TEST_F(UtestOpDescUtils, GetPotentailWeightByIndexAccrossEnter) { - auto graph = BuildGraph3(); - auto cast_node = graph->FindNode("cast"); - auto nodes_2_out_anchor = OpDescUtils::GetConstInputNodeAndAnchor(*cast_node); - EXPECT_EQ(nodes_2_out_anchor.size(), 1); - EXPECT_EQ(nodes_2_out_anchor[0].first->GetName(), "addn"); - EXPECT_EQ(nodes_2_out_anchor[0].second->GetIdx(), 0); - - auto weights = OpDescUtils::GetWeightsFromNodes(nodes_2_out_anchor); - EXPECT_EQ(weights.size(), 1); - EXPECT_EQ(weights[0]->GetTensorDesc().GetDataType(), DT_INT32); -} - -TEST_F(UtestOpDescUtils, GetInputConstDataByIndex_01) { - uint8_t data_buf[4096] = {0}; - data_buf[0] = 23; - data_buf[10] = 32; - auto ge_tensor = std::make_shared(); - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto const_node = builder.AddNode("Const", "Const", 0, 1); - AttrUtils::SetTensor(const_node->GetOpDesc(), "value", ge_tensor); - auto case_node = builder.AddNode("Case", "Case", 1, 1); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - builder.AddDataEdge(const_node, 0, case_node, 0); - builder.AddDataEdge(case_node, 0, netoutput, 0); - auto parent_graph = builder.GetGraph(); - - ut::GraphBuilder sub_builder = ut::GraphBuilder("subgraph_graph"); - auto sub_data = sub_builder.AddNode("sub_data", "Data", 0, 1); - auto sub_const = sub_builder.AddNode("sub_const", "Const", 0, 1); - AttrUtils::SetTensor(sub_const->GetOpDesc(), "value", ge_tensor); - auto add = sub_builder.AddNode("Add", "Add", 2, 1); - auto sub_netoutput = sub_builder.AddNode("sub_netoutput", "NetOutput", 1, 0); - sub_builder.AddDataEdge(sub_data, 0, add, 0); - sub_builder.AddDataEdge(sub_const, 0, add, 1); - sub_builder.AddDataEdge(add, 0, sub_netoutput, 0); - - auto subgraph = sub_builder.GetGraph(); - subgraph->SetParentNode(case_node); - subgraph->SetParentGraph(parent_graph); - parent_graph->AddSubgraph(subgraph->GetName(), subgraph); - AttrUtils::SetInt(sub_data->GetOpDesc(), "_parent_node_index", 0); - - auto op_desc = add->GetOpDesc(); - op_desc->impl_->input_name_idx_["sub_data"] = 0; - op_desc->impl_->input_name_idx_["sub_const"] = 1; - auto op = OpDescUtils::CreateOperatorFromNode(add); - RuntimeInferenceContext runtime_ctx; - // define callback - OpDescUtils::GetConstInputOnRuntimeFun func_get_input_const = - [&runtime_ctx](const ConstNodePtr &node, const size_t index, ge::GeTensorPtr &dst_tensor) { - // from runtime context - const auto in_data_anchor = node->GetInDataAnchor(static_cast(index)); - const auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); - auto peer_node = out_data_anchor->GetOwnerNode(); - GeTensorPtr tensor_value = nullptr; - if (runtime_ctx.GetTensor(peer_node->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), tensor_value) == - GRAPH_SUCCESS) { - dst_tensor = tensor_value; - return GRAPH_SUCCESS; - } - return ge::GRAPH_SUCCESS; - }; - OpDescUtils::SetCallbackGetConstInputFuncToOperator(op, func_get_input_const); - GeTensorDesc desc; - GeTensorPtr tensor = std::make_shared(desc); - tensor->SetData(data_buf, 4096); - - int64_t node_id = 1; - int output_id = 0; - runtime_ctx.SetTensor(node_id, output_id, std::move(tensor)); - ConstGeTensorBarePtr ge_tensor_res = nullptr; - ge_tensor_res = OpDescUtils::GetInputConstData(op, 1); - - ASSERT_TRUE(ge_tensor_res != nullptr); - const TensorData tmp(ge_tensor_res->GetData()); - const uint8_t* res_buf = tmp.GetData(); - ASSERT_EQ(res_buf[0], 23); - ASSERT_EQ(res_buf[10], 32); -} - -TEST_F(UtestOpDescUtils, GetInputConstDataByIndex_02) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data = builder.AddNode("Data", "Data", 0, 1); - auto data2 = builder.AddNode("Data2", "Data", 0, 1); - auto enter = builder.AddNode("Enter", "Enter", 1, 1); - auto transdata = builder.AddNode("Transdata", "Transdata", 2, 1); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - builder.AddDataEdge(data2, 0, enter, 0); - builder.AddDataEdge(data, 0, transdata, 0); - builder.AddDataEdge(enter, 0, transdata, 1); - builder.AddDataEdge(transdata, 0, netoutput, 0); - auto graph = builder.GetGraph(); - - auto ge_tensor = std::make_shared(); - uint8_t data_buf[4096] = {0}; - data_buf[0] = 23; - data_buf[10] = 32; - ge_tensor->SetData(data_buf, 4096); - - auto op_desc = transdata->GetOpDesc(); - op_desc->impl_->input_name_idx_["Data"] = 0; - op_desc->impl_->input_name_idx_["Enter"] = 1; - auto tensor_desc = op_desc->MutableInputDesc(0); - AttrUtils::SetTensor(tensor_desc, "_value", ge_tensor); - - auto op = OpDescUtils::CreateOperatorFromNode(transdata); - ConstGeTensorBarePtr ge_tensor_res = nullptr; - ConstGeTensorBarePtr ge_tensor_res2 = nullptr; - ge_tensor_res = OpDescUtils::GetInputConstData(op, 0); - ge_tensor_res2 = OpDescUtils::GetInputConstData(op, 1); - ASSERT_TRUE(ge_tensor_res != nullptr); - ASSERT_TRUE(ge_tensor_res2 == nullptr); - const TensorData tmp(ge_tensor_res->GetData()); - const uint8_t* res_buf = tmp.GetData(); - ASSERT_EQ(res_buf[0], 23); - ASSERT_EQ(res_buf[10], 32); -} - -// for partiton graph get const -TEST_F(UtestOpDescUtils, GetInputConstDataByIndex_03) { - ut::GraphBuilder builder = ut::GraphBuilder("partiton_graph0"); - auto pld = builder.AddNode(PLACEHOLDER, PLACEHOLDER, 0, 1); - auto transdata = builder.AddNode("Transdata", "Transdata", 1, 1); - auto netoutput = builder.AddNode(NETOUTPUT, NETOUTPUT, 1, 0); - builder.AddDataEdge(pld, 0, transdata, 0); - builder.AddDataEdge(transdata, 0, netoutput, 0); - auto op_desc = transdata->GetOpDesc(); - - ut::GraphBuilder builder1 = ut::GraphBuilder("partiton_graph1"); - auto const_node = builder1.AddNode(CONSTANT, CONSTANT, 0, 1); - auto end = builder1.AddNode(END, END, 1, 0); - builder.AddDataEdge(const_node, 0, end, 0); - auto ge_tensor = std::make_shared(); - uint8_t data_buf[4096U] = {0}; - data_buf[0] = 23U; - data_buf[10] = 32U; - ge_tensor->SetData(data_buf, 4096U); - AttrUtils::SetTensor(const_node->GetOpDesc(), ATTR_NAME_WEIGHTS, ge_tensor); - - pld->GetOpDesc()->SetExtAttr("parentNode", const_node); - auto op = OpDescUtils::CreateOperatorFromNode(transdata); - ConstGeTensorBarePtr ge_tensor_res = nullptr; - // case 0 - ge_tensor_res = OpDescUtils::GetInputConstData(op, 0U); - ASSERT_TRUE(ge_tensor_res != nullptr); - const TensorData tmp(ge_tensor_res->GetData()); - const uint8_t *res_buf = tmp.GetData(); - ASSERT_EQ(res_buf[0], 23U); - ASSERT_EQ(res_buf[10], 32U); - - // case 1 - op_desc->impl_->input_name_idx_[PLACEHOLDER] = 0U; - Tensor tensor; - ASSERT_EQ(op.GetInputConstData(PLACEHOLDER, tensor), GRAPH_SUCCESS); - const uint8_t *buf = tensor.GetData(); - ASSERT_EQ(buf[0], 23U); - ASSERT_EQ(buf[10], 32U); -} - -TEST_F(UtestOpDescUtils, DefaultInferFormat) { - auto tensor_desc = std::make_shared(); - tensor_desc->SetShape(GeShape()); - tensor_desc->SetFormat(FORMAT_ND); - tensor_desc->SetDataType(DT_FLOAT); - auto op_desc = std::make_shared("test", "Identity"); - op_desc->AddInputDesc(tensor_desc->Clone()); - op_desc->AddOutputDesc(tensor_desc->Clone()); - - EXPECT_EQ(op_desc->DefaultInferFormat(), 0); - auto input_desc = op_desc->MutableInputDesc(0); - EXPECT_EQ(input_desc->GetFormat(), FORMAT_ND); - auto output_desc = op_desc->MutableOutputDesc(0); - EXPECT_EQ(output_desc->GetFormat(), FORMAT_ND); -} - - -TEST_F(UtestOpDescUtils, OpDescBuilder) { - OpDescBuilder builder("name", "type"); - builder.AddDynamicInput("AddDy", 1); - EXPECT_NE(&builder, nullptr); - const GeTensorDesc ten = GeTensorDesc(GeShape()); - builder.AddDynamicInput(std::string("AddDy2"), 2, ten); - EXPECT_NE(&builder, nullptr); - builder.AddDynamicOutput("AddDyOut", 3); - EXPECT_NE(&builder, nullptr); - builder.AddDynamicOutput(std::string("AddDyOut2"), 4, ten); - EXPECT_NE(&builder, nullptr); -} - -TEST_F(UtestOpDescUtils, OpDescUtils) { - OpDescPtr odp = std::make_shared("name", "type"); - EXPECT_EQ(OpDescUtils::SetSubgraphInstanceName("subgraph_name", "subgraph_instance_name", odp), GRAPH_PARAM_INVALID); - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node = builder.AddNode("Data", "Data", 1, 1); - InDataAnchorPtr in_anch = std::make_shared(data_node, 111); - GeTensorPtr tp = std::make_shared(); - OpDescPtr odp1 = std::make_shared("name1", "type1"); - EXPECT_EQ(OpDescUtils::MutableWeights(odp1), nullptr); - EXPECT_EQ(OpDescUtils::ClearWeights(data_node), GRAPH_SUCCESS); - NodePtr np = std::make_shared(); - EXPECT_EQ(OpDescUtils::ClearWeights(np), GRAPH_PARAM_INVALID); - EXPECT_EQ(OpDescUtils::ClearInputDesc(data_node), true); - odp->AddInputDesc(GeTensorDesc()); - EXPECT_EQ(OpDescUtils::GetWeights(data_node).size(), 0); - EXPECT_EQ(OpDescUtils::GetWeights(nullptr).size(), 0); - EXPECT_EQ(OpDescUtils::GetConstInputNode(*data_node).size(), 0); - EXPECT_EQ(OpDescUtils::SetWeights(*odp, nullptr), GRAPH_FAILED); - EXPECT_EQ(OpDescUtils::ClearInputDesc(odp, 0), true); - EXPECT_EQ(OpDescUtils::ClearInputDesc(odp, 1), false); - EXPECT_EQ(odp->impl_->inputs_desc_.size(), 0); - EXPECT_EQ(OpDescUtils::HasQuantizeFactorParams(odp), false); - EXPECT_EQ(OpDescUtils::ClearOutputDesc(data_node), true); - EXPECT_EQ(OpDescUtils::ClearOutputDesc(odp, 0), false); - EXPECT_EQ(OpDescUtils::HasQuantizeFactorParams(*odp), false); - EXPECT_EQ(OpDescUtils::IsNonConstInput(*data_node, 1), false); - EXPECT_EQ(OpDescUtils::IsNonConstInput(data_node, 1), false); -} - -TEST_F(UtestOpDescUtils, OpDescUtilsSupply) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node = builder.AddNode("Data", "Data", 1, 1); - auto attr_node = builder.AddNode("Attr", "Attr", 2, 2); - auto one_node = builder.AddNode("One", "One", 3, 3); - InDataAnchorPtr in_anch = std::make_shared(data_node, 111); - OutDataAnchorPtr out_anch = std::make_shared(data_node, 222); - auto node3 = builder.AddNode("Data3", "Data3", 3, 3); - InControlAnchorPtr inc_anch = std::make_shared(node3, 33); - EXPECT_EQ(attr_node->AddLinkFrom(data_node), GRAPH_SUCCESS); - EXPECT_EQ(OpDescUtils::GetConstInputNode(*attr_node).size(), 0); - std::vector node_v; - node_v.push_back(data_node); - node_v.push_back(attr_node); - EXPECT_EQ(OpDescUtils::GetInputData(node_v).size(), 0); - EXPECT_EQ(OpDescUtils::GetNonConstInputsSize(*attr_node), 1); - EXPECT_EQ(OpDescUtils::GetNonConstInputsSize(attr_node), 1); - EXPECT_EQ(OpDescUtils::GetNonConstInputTensorDesc(*attr_node, 1), GeTensorDesc()); - EXPECT_EQ(OpDescUtils::GetNonConstInputTensorDesc(attr_node, 1), GeTensorDesc()); - size_t st = 0; - EXPECT_EQ(OpDescUtils::GetNonConstInputIndex(attr_node, 1, st), false); - EXPECT_EQ(OpDescUtils::GetConstInputs(nullptr).size(), 0); - EXPECT_EQ(OpDescUtils::GetNonConstTensorDesc(attr_node).size(), 1); - Operator op("name", "type"); - op.operator_impl_ = nullptr; - EXPECT_EQ(OpDescUtils::GetInputConstData(op, 0), nullptr); -} - -TEST_F(UtestOpDescUtils, ClearInputDesc_Nullptr) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node = builder.AddNode("Data", "Data", 1, 1); - EXPECT_EQ(data_node->GetAllInDataAnchors().size(), 1); - data_node->impl_->op_->impl_ = nullptr; - EXPECT_EQ(OpDescUtils::ClearInputDesc(data_node), false); -} - -TEST_F(UtestOpDescUtils, ClearOutputDesc_Nullptr) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data_node = builder.AddNode("Data", "Data", 1, 1); - EXPECT_EQ(data_node->GetAllInDataAnchors().size(), 1); - data_node->impl_->op_->impl_ = nullptr; - EXPECT_EQ(OpDescUtils::ClearOutputDesc(data_node), false); -} - -TEST_F(UtestOpDescUtils, ClearOutputDesc_Normal) { - OpDescPtr odp = std::make_shared("name", "type"); - EXPECT_NE(odp, nullptr); - EXPECT_NE(odp->impl_, nullptr); - EXPECT_EQ(odp->impl_->outputs_desc_.size(), 0); - odp->impl_->outputs_desc_.push_back(std::make_shared()); - EXPECT_EQ(OpDescUtils::ClearOutputDesc(odp, 0), true); -} - -TEST_F(UtestOpDescUtils, GetWeightsFromNodes) { - auto graph = BuildGraph3(); - auto cast_node = graph->FindNode("cast"); - auto enter_node = graph->FindNode("enter"); - auto in_nodes_and_anchors = cast_node->GetInDataNodesAndAnchors(); - EXPECT_EQ(in_nodes_and_anchors.size(), 1); - EXPECT_EQ(in_nodes_and_anchors.begin()->first->GetName(), "enter"); - EXPECT_EQ(in_nodes_and_anchors.begin()->second->GetIdx(), 0); - auto opdsc1 = in_nodes_and_anchors.begin()->first->GetOpDesc(); - bool is_potential_const1 = false; - auto has_attr1 = AttrUtils::GetBool(opdsc1, ATTR_NAME_POTENTIAL_CONST, is_potential_const1); - EXPECT_EQ(has_attr1, false); - - EXPECT_EQ(in_nodes_and_anchors.size(), 1); - auto nodes_2_out_anchor = OpDescUtils::GetConstInputNodeAndAnchor(*cast_node); - EXPECT_EQ(nodes_2_out_anchor.size(), 1); - EXPECT_EQ(nodes_2_out_anchor[0].first->GetName(), "addn"); - EXPECT_EQ(nodes_2_out_anchor[0].second->GetIdx(), 0); - - auto opdsc = nodes_2_out_anchor[0].first->GetOpDesc(); - bool is_potential_const = false; - auto has_attr = AttrUtils::GetBool(opdsc, ATTR_NAME_POTENTIAL_CONST, is_potential_const); - EXPECT_EQ(has_attr, true); - auto weights = OpDescUtils::GetWeightsFromNodes(nodes_2_out_anchor); -} - -TEST_F(UtestOpDescUtils, GetConstInputNode_Const_Enter_Other) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto const1 = builder.AddNode("const1", "Const", 1, 1); - auto const2 = builder.AddNode("const2", "Const", 1, 1); - EXPECT_EQ(const1->AddLinkFrom(const2), GRAPH_SUCCESS); - EXPECT_EQ(OpDescUtils::GetConstInputNode(*const1).size(), 1); - - auto enter1 = builder.AddNode("enter1", "Enter", 1, 1); - auto enter2 = builder.AddNode("enter2", "Enter", 1, 1); - EXPECT_EQ(enter1->AddLinkFrom(enter2), GRAPH_SUCCESS); - EXPECT_EQ(OpDescUtils::GetConstInputNode(*enter1).size(), 0); - - auto other1 = builder.AddNode("other1", "Enter", 1, 1); - auto other2 = builder.AddNode("other2", "other", 1, 1); - EXPECT_EQ(other1->AddLinkFrom(other2), GRAPH_SUCCESS); - EXPECT_EQ(OpDescUtils::GetConstInputNode(*other1).size(), 0); -} - -TEST_F(UtestOpDescUtils, GetInputData_Weight) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto const1 = builder.AddNode("const1", "Const", 1, 1); - auto const2 = builder.AddNode("const2", "Const", 1, 1); - EXPECT_EQ(const1->AddLinkFrom(const2), GRAPH_SUCCESS); - - int32_t weight[1] = {1}; - GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32); - GeTensorPtr tensor0 = std::make_shared(weight_desc, (uint8_t *)weight, sizeof(weight)); - OpDescUtils::SetWeights(const1, {tensor0}); - - std::vector vec; - vec.push_back(const1); - EXPECT_EQ(OpDescUtils::GetInputData(vec).size(), 1); -} - -TEST_F(UtestOpDescUtils, GetNonConstInputsSize) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto const1 = builder.AddNode("const1", "Const", 1, 1); - auto const2 = builder.AddNode("const2", "Const", 1, 1); - - EXPECT_EQ(OpDescUtils::GetNonConstInputsSize(nullptr), 0); - EXPECT_EQ(NodeUtils::SetAllAnchorStatus(*const1), GRAPH_SUCCESS); - EXPECT_EQ(OpDescUtils::GetNonConstInputsSize(*const1), 0); - EXPECT_EQ(const1->GetAllInDataAnchors().size(), 1); - auto in_anch = const1->GetAllInDataAnchors().at(0); - EXPECT_EQ(AnchorUtils::SetStatus(in_anch, ANCHOR_DATA), GRAPH_SUCCESS); - EXPECT_EQ(OpDescUtils::GetNonConstInputsSize(*const1), 1); -} - -TEST_F(UtestOpDescUtils, AddConstOpToAnchor) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto const1 = builder.AddNode("const1", "Const", 1, 1); - EXPECT_EQ(const1->GetAllInDataAnchors().size(), 1); - auto in_anch = const1->GetAllInDataAnchors().at(0); - int32_t weight[1] = {1}; - GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32); - GeTensorPtr tensor0 = std::make_shared(weight_desc, (uint8_t *)weight, sizeof(weight)); - - EXPECT_EQ(OpDescUtils::AddConstOpToAnchor(in_anch, tensor0), GRAPH_SUCCESS); -} - -TEST_F(UtestOpDescUtils, GetNonConstInputIndex) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto attr_node = builder.AddNode("Attr", "Attr", 2, 2); - - EXPECT_EQ(NodeUtils::SetAllAnchorStatus(*attr_node), GRAPH_SUCCESS); - size_t st = 0; - EXPECT_EQ(OpDescUtils::GetNonConstInputIndex(attr_node, 1, st), false); -} - -TEST_F(UtestOpDescUtils, GetNonConstInputTensorDesc) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto attr_node = builder.AddNode("Attr", "Attr", 2, 2); - EXPECT_EQ(attr_node->GetAllInDataAnchors().size(), 2); - auto in_anch = attr_node->GetAllInDataAnchors().at(0); - EXPECT_NE(in_anch, nullptr); - - EXPECT_EQ(AnchorUtils::SetStatus(in_anch, ANCHOR_DATA), GRAPH_SUCCESS); - EXPECT_EQ(OpDescUtils::GetNonConstInputTensorDesc(attr_node, 1), GeTensorDesc()); -} - -TEST_F(UtestOpDescUtils, GetNonConstInputTensorDesc_SetStatus) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto attr_node = builder.AddNode("Attr", "Attr", 2, 2); - EXPECT_EQ(attr_node->GetAllInDataAnchors().size(), 2); - auto in_anch = attr_node->GetAllInDataAnchors().at(0); - EXPECT_NE(in_anch, nullptr); - - EXPECT_EQ(NodeUtils::SetAllAnchorStatus(attr_node), GRAPH_SUCCESS); - EXPECT_EQ(OpDescUtils::GetNonConstInputTensorDesc(attr_node, 1), GeTensorDesc()); -} - -TEST_F(UtestOpDescUtils, IsNonConstInput) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto attr_node = builder.AddNode("Attr", "Attr", 2, 2); - EXPECT_EQ(OpDescUtils::IsNonConstInput(attr_node, 1), false); - - - EXPECT_EQ(NodeUtils::SetAllAnchorStatus(*attr_node), GRAPH_SUCCESS); - EXPECT_EQ(OpDescUtils::IsNonConstInput(attr_node, 1), false); - - auto const1 = builder.AddNode("const1", "Const", 1, 1); - auto const2 = builder.AddNode("const2", "Const", 1, 1); - EXPECT_EQ(const1->AddLinkFrom(const2), GRAPH_SUCCESS); - EXPECT_EQ(OpDescUtils::IsNonConstInput(const1, 0), false); - EXPECT_EQ(OpDescUtils::IsNonConstInput(const2, 0), false); -} - -TEST_F(UtestOpDescUtils, GetNonConstTensorDesc) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto attr_node = builder.AddNode("Attr", "Attr", 2, 2); - EXPECT_EQ(OpDescUtils::GetNonConstTensorDesc(nullptr).size(), 0); - EXPECT_EQ(NodeUtils::SetAllAnchorStatus(*attr_node), GRAPH_SUCCESS); - EXPECT_EQ(OpDescUtils::GetNonConstTensorDesc(attr_node).size(), 0); -} - -TEST_F(UtestOpDescUtils, GetConstInputs_Const) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto const1 = builder.AddNode("const1", "Const", 1, 1); - auto const2 = builder.AddNode("const2", "Const", 1, 1); - EXPECT_EQ(const1->AddLinkFrom(const2), GRAPH_SUCCESS); - EXPECT_EQ(const1->GetType(), "Const"); - EXPECT_EQ(OpDescUtils::GetConstInputs(*const1).size(), 1); -} - -TEST_F(UtestOpDescUtils, GetConstInputs_Switch) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto sw1 = builder.AddNode("sw1", "Switch", 1, 1); - auto mm1 = builder.AddNode("mm1", "MatMul", 1, 1); - EXPECT_EQ(sw1->AddLinkFrom(mm1), GRAPH_SUCCESS); - EXPECT_EQ(sw1->GetType(), "Switch"); - EXPECT_EQ(mm1->GetType(), "MatMul"); - EXPECT_EQ(OpDescUtils::GetConstInputs(*sw1).size(), 0); -} - -TEST_F(UtestOpDescUtils, MutableWeights) { - auto node = std::make_shared(); - node = nullptr; - EXPECT_EQ(OpDescUtils::MutableWeights(node).size(), 0); - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto ph = builder.AddNode("ph", "PlaceHolder", 1, 1); - EXPECT_EQ(OpDescUtils::MutableWeights(*ph).size(), 0); -} - -TEST_F(UtestOpDescUtils, MutableWeights_Nullptr) { - OpDescPtr odp = std::make_shared(); - odp = nullptr; - EXPECT_EQ(OpDescUtils::MutableWeights(odp), nullptr); -} - -TEST_F(UtestOpDescUtils, SetWeights) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto const1 = builder.AddNode("const1", "Const", 1, 1); - std::map weights_map; - weights_map[1] = std::make_shared(); - EXPECT_EQ(OpDescUtils::SetWeights(*const1, weights_map), GRAPH_SUCCESS); - - auto non1 = builder.AddNode("nonconst1", "NonConst", 1, 1); - int32_t weight[1] = {1}; - GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32); - GeTensorPtr tensor0 = std::make_shared(weight_desc, (uint8_t *)weight, sizeof(weight)); - EXPECT_EQ(OpDescUtils::SetWeights(non1, {tensor0}), GRAPH_SUCCESS); - - weights_map[2] = tensor0; - EXPECT_EQ(OpDescUtils::SetWeights(*const1, weights_map), GRAPH_PARAM_INVALID); -} - -TEST_F(UtestOpDescUtils, CopyConstructOpdesc) { - GeTensorDesc td; - td.SetShape(GeShape(std::vector({1, 1, 224, 224}))); - td.SetOriginShape(GeShape(std::vector({1, 1, 224, 224}))); - td.SetFormat(FORMAT_NCHW); - td.SetOriginFormat(FORMAT_NCHW); - td.SetDataType(DT_FLOAT); - td.SetOriginDataType(DT_FLOAT); - vector input_size = {12}; - AttrUtils::SetListInt(td, "input_size", input_size); - - auto op_desc = std::make_shared(); - op_desc->AddInputDesc("x1", td); - op_desc->AddInputDesc("x2", td); - op_desc->AddOptionalInputDesc("x3", td); - op_desc->AddOutputDesc("y", td); - AttrUtils::SetStr(op_desc, "padding", "SAME"); - - OpDescPtr new_desc = std::make_shared(*op_desc); - EXPECT_TRUE(new_desc->OpDescMembersAreEqual(*op_desc)); - EXPECT_TRUE(new_desc->OpDescAttrsAreEqual(*op_desc)); - EXPECT_TRUE(new_desc->OpDescGenTensorDescsAreEqual(*op_desc)); - std::string padding; - EXPECT_TRUE(AttrUtils::GetStr(new_desc, "padding", padding)); - EXPECT_EQ(padding, "SAME"); - - EXPECT_EQ(new_desc->GetInputsSize(), 3); - EXPECT_EQ(new_desc->GetOutputsSize(), 1); - - EXPECT_EQ(new_desc->GetInputDescPtr("x1"), new_desc->GetInputDescPtr(0)); - EXPECT_EQ(new_desc->GetInputDescPtr("x2"), new_desc->GetInputDescPtr(1)); - EXPECT_EQ(new_desc->MutableOutputDesc("y"), new_desc->MutableOutputDesc(0)); - - EXPECT_EQ(new_desc->GetInputDescPtr(0)->GetDataType(), DT_FLOAT); - EXPECT_EQ(new_desc->GetInputDescPtr(0)->GetOriginDataType(), DT_FLOAT); - EXPECT_EQ(new_desc->GetInputDescPtr(0)->GetFormat(), FORMAT_NCHW); - EXPECT_EQ(new_desc->GetInputDescPtr(0)->GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(new_desc->GetInputDescPtr(0)->GetShape().GetDims(), std::vector({1, 1, 224, 224})); - EXPECT_EQ(new_desc->GetInputDescPtr(0)->GetOriginShape().GetDims(), std::vector({1, 1, 224, 224})); - vector new_input_size; - EXPECT_TRUE(AttrUtils::GetListInt(new_desc->GetInputDescPtr(0), "input_size", new_input_size)); - EXPECT_EQ(new_input_size, std::vector({12})); - - EXPECT_EQ(new_desc->GetInputDescPtr(1)->GetDataType(), DT_FLOAT); - EXPECT_EQ(new_desc->GetInputDescPtr(1)->GetOriginDataType(), DT_FLOAT); - EXPECT_EQ(new_desc->GetInputDescPtr(1)->GetFormat(), FORMAT_NCHW); - EXPECT_EQ(new_desc->GetInputDescPtr(1)->GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(new_desc->GetInputDescPtr(1)->GetShape().GetDims(), std::vector({1, 1, 224, 224})); - EXPECT_EQ(new_desc->GetInputDescPtr(1)->GetOriginShape().GetDims(), std::vector({1, 1, 224, 224})); - new_input_size.clear(); - auto new_input_desc = new_desc->GetInputDescPtr(1); - EXPECT_TRUE(AttrUtils::GetListInt(new_input_desc, "input_size", new_input_size)); - EXPECT_EQ(new_input_size, std::vector({12})); - - EXPECT_EQ(new_desc->GetOutputDescPtr(0)->GetDataType(), DT_FLOAT); - EXPECT_EQ(new_desc->GetOutputDescPtr(0)->GetOriginDataType(), DT_FLOAT); - EXPECT_EQ(new_desc->GetOutputDescPtr(0)->GetFormat(), FORMAT_NCHW); - EXPECT_EQ(new_desc->GetOutputDescPtr(0)->GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(new_desc->GetOutputDescPtr(0)->GetShape().GetDims(), std::vector({1, 1, 224, 224})); - EXPECT_EQ(new_desc->GetOutputDescPtr(0)->GetOriginShape().GetDims(), std::vector({1, 1, 224, 224})); - new_input_size.clear(); - EXPECT_TRUE(AttrUtils::GetListInt(new_desc->GetInputDescPtr(0), "input_size", new_input_size)); - EXPECT_EQ(new_input_size, std::vector({12})); - op_desc->MutableInputDesc(0)->SetFormat(FORMAT_NC1HWC0_C04); - EXPECT_EQ(new_desc->GetInputDescPtr(0)->GetFormat(), FORMAT_NCHW); - EXPECT_FALSE(new_desc->OpDescGenTensorDescsAreEqual(*op_desc)); -} - -TEST_F(UtestOpDescUtils, CopyOpdesc) { - GeTensorDesc td; - td.SetShape(GeShape(std::vector({1, 1, 224, 224}))); - td.SetOriginShape(GeShape(std::vector({1, 1, 224, 224}))); - td.SetFormat(FORMAT_NCHW); - td.SetOriginFormat(FORMAT_NCHW); - td.SetDataType(DT_FLOAT); - td.SetOriginDataType(DT_FLOAT); - vector input_size = {12}; - AttrUtils::SetListInt(td, "input_size", input_size); - - auto op_desc = std::make_shared(); - op_desc->AddInputDesc("x1", td); - op_desc->AddInputDesc("x2", td); - op_desc->AddOptionalInputDesc("x3", td); - op_desc->AddOutputDesc("y", td); - AttrUtils::SetStr(op_desc, "padding", "SAME"); - - auto new_desc = OpDescUtils::CopyOpDesc(op_desc); - - std::string padding; - EXPECT_TRUE(AttrUtils::GetStr(new_desc, "padding", padding)); - EXPECT_EQ(padding, "SAME"); - - EXPECT_EQ(new_desc->GetInputsSize(), 3); - EXPECT_EQ(new_desc->GetOutputsSize(), 1); - - EXPECT_EQ(new_desc->GetInputDescPtr("x1"), new_desc->GetInputDescPtr(0)); - EXPECT_EQ(new_desc->GetInputDescPtr("x2"), new_desc->GetInputDescPtr(1)); - EXPECT_EQ(new_desc->MutableOutputDesc("y"), new_desc->MutableOutputDesc(0)); - - EXPECT_EQ(new_desc->GetInputDescPtr(0)->GetDataType(), DT_FLOAT); - EXPECT_EQ(new_desc->GetInputDescPtr(0)->GetOriginDataType(), DT_FLOAT); - EXPECT_EQ(new_desc->GetInputDescPtr(0)->GetFormat(), FORMAT_NCHW); - EXPECT_EQ(new_desc->GetInputDescPtr(0)->GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(new_desc->GetInputDescPtr(0)->GetShape().GetDims(), std::vector({1, 1, 224, 224})); - EXPECT_EQ(new_desc->GetInputDescPtr(0)->GetOriginShape().GetDims(), std::vector({1, 1, 224, 224})); - vector new_input_size; - EXPECT_TRUE(AttrUtils::GetListInt(new_desc->GetInputDescPtr(0), "input_size", new_input_size)); - EXPECT_EQ(new_input_size, std::vector({12})); - - EXPECT_EQ(new_desc->GetInputDescPtr(1)->GetDataType(), DT_FLOAT); - EXPECT_EQ(new_desc->GetInputDescPtr(1)->GetOriginDataType(), DT_FLOAT); - EXPECT_EQ(new_desc->GetInputDescPtr(1)->GetFormat(), FORMAT_NCHW); - EXPECT_EQ(new_desc->GetInputDescPtr(1)->GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(new_desc->GetInputDescPtr(1)->GetShape().GetDims(), std::vector({1, 1, 224, 224})); - EXPECT_EQ(new_desc->GetInputDescPtr(1)->GetOriginShape().GetDims(), std::vector({1, 1, 224, 224})); - new_input_size.clear(); - auto new_input_desc = new_desc->GetInputDescPtr(1); - EXPECT_TRUE(AttrUtils::GetListInt(new_input_desc, "input_size", new_input_size)); - EXPECT_EQ(new_input_size, std::vector({12})); - - EXPECT_EQ(new_desc->GetOutputDescPtr(0)->GetDataType(), DT_FLOAT); - EXPECT_EQ(new_desc->GetOutputDescPtr(0)->GetOriginDataType(), DT_FLOAT); - EXPECT_EQ(new_desc->GetOutputDescPtr(0)->GetFormat(), FORMAT_NCHW); - EXPECT_EQ(new_desc->GetOutputDescPtr(0)->GetOriginFormat(), FORMAT_NCHW); - EXPECT_EQ(new_desc->GetOutputDescPtr(0)->GetShape().GetDims(), std::vector({1, 1, 224, 224})); - EXPECT_EQ(new_desc->GetOutputDescPtr(0)->GetOriginShape().GetDims(), std::vector({1, 1, 224, 224})); - new_input_size.clear(); - EXPECT_TRUE(AttrUtils::GetListInt(new_desc->GetInputDescPtr(0), "input_size", new_input_size)); - EXPECT_EQ(new_input_size, std::vector({12})); -} - - -TEST_F(UtestOpDescUtils, CopyOpdesc2) { - GeTensorDesc td = StandardTd_5d_1_1_224_224(); - - auto op_desc = std::make_shared(); - op_desc->AddInputDesc("x1", td); - op_desc->AddInputDesc("x2", td); - op_desc->AddOutputDesc("y", td); - AttrUtils::SetStr(op_desc, "padding", "VALID"); - - auto new_desc1 = OpDescUtils::CopyOpDesc(op_desc); - - std::string padding; - EXPECT_TRUE(AttrUtils::GetStr(new_desc1, "padding", padding)); - EXPECT_EQ(padding, "VALID"); - - AttrUtils::SetStr(new_desc1, "padding", "SAME"); - padding.clear(); - EXPECT_TRUE(AttrUtils::GetStr(new_desc1, "padding", padding)); - EXPECT_EQ(padding, "SAME"); - - auto new_desc2 = OpDescUtils::CopyOpDesc(new_desc1); - padding.clear(); - EXPECT_TRUE(AttrUtils::GetStr(new_desc2, "padding", padding)); - EXPECT_EQ(padding, "SAME"); -} - -TEST_F(UtestOpDescUtils, CloneOpdesc) { - GeTensorDesc td = StandardTd_5d_1_1_224_224(); - - auto op_desc = std::make_shared(); - op_desc->AddInputDesc("x1", td); - op_desc->AddInputDesc("x2", td); - op_desc->AddOutputDesc("y", td); - AttrUtils::SetStr(op_desc, "padding", "VALID"); - - auto new_desc1 = OpDescUtils::CloneOpDesc(op_desc); - - std::string padding; - EXPECT_TRUE(AttrUtils::GetStr(new_desc1, "padding", padding)); - EXPECT_EQ(padding, "VALID"); - - AttrUtils::SetStr(new_desc1, "padding", "SAME"); - padding.clear(); - EXPECT_TRUE(AttrUtils::GetStr(new_desc1, "padding", padding)); - EXPECT_EQ(padding, "SAME"); - - auto new_desc2 = OpDescUtils::CloneOpDesc(new_desc1); - padding.clear(); - EXPECT_TRUE(AttrUtils::GetStr(new_desc2, "padding", padding)); - EXPECT_EQ(padding, "SAME"); -} - -REG_OP(DynamicOpUt) - .DYNAMIC_INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .INPUT(a, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .INPUT(b, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .ATTR(transpose_x1, Bool, false) - .ATTR(transpose_x2, Bool, false) - .OP_END_FACTORY_REG(DynamicOpUt) - -TEST_F(UtestOpDescUtils, GetInputIrIndexes2InstanceIndexesPairMap_NullOpDescFailed) { - auto ir_index_to_instance_index_pair_map = OpDescUtils::GetInputIrIndexes2InstanceIndexesPairMap(nullptr); - ASSERT_TRUE(ir_index_to_instance_index_pair_map.empty()); -} - -TEST_F(UtestOpDescUtils, GetOutputIrIndexes2InstanceIndexesPairMap_NullOpDescFailed) { - auto ir_index_to_instance_index_pair_map = OpDescUtils::GetOutputIrIndexes2InstanceIndexesPairMap(nullptr); - ASSERT_TRUE(ir_index_to_instance_index_pair_map.empty()); -} - -void IrIndexAndInstanceIndexCheck(size_t dynamic_input_num, bool has_optional_input) { - size_t optional_input_num = has_optional_input ? 1U : 0U; - auto graph = BuildGraph4(dynamic_input_num, has_optional_input); - auto dynamic_op_ut_node = graph->FindNode("dynamic_op_ut"); - auto op_desc = dynamic_op_ut_node->GetOpDesc(); - - size_t index = 0; - auto &name_index = op_desc->MutableAllInputName(); - name_index.clear(); - for (size_t i = 0U; i < dynamic_input_num; ++i) { - name_index["x" + std::to_string(i)] = index++; - } - name_index["a"] = index++; - if (optional_input_num == 1) { - name_index["bias"] = index++; - } - name_index["b"] = index++; - - auto ir_index_to_instance_index_pair_map = OpDescUtils::GetInputIrIndexes2InstanceIndexesPairMap(op_desc); - ASSERT_FALSE(ir_index_to_instance_index_pair_map.empty()); - - std::map> expect_map; - expect_map[0] = std::pair(0, dynamic_input_num); - expect_map[1] = std::pair(dynamic_input_num, 1); - expect_map[2] = std::pair(dynamic_input_num + 1, optional_input_num); - expect_map[3] = std::pair(dynamic_input_num + 1 + optional_input_num, 1); - EXPECT_EQ(ir_index_to_instance_index_pair_map, expect_map); -} - -TEST_F(UtestOpDescUtils, GetInputIrIndexes2InstanceIndexesPairMap_Success) { - IrIndexAndInstanceIndexCheck(0, true); - IrIndexAndInstanceIndexCheck(0, false); - IrIndexAndInstanceIndexCheck(1, true); - IrIndexAndInstanceIndexCheck(1, false); - IrIndexAndInstanceIndexCheck(3, true); - IrIndexAndInstanceIndexCheck(3, false); -} - -TEST_F(UtestOpDescUtils, GetInputIrIndexes2InstanceIndexesPairMap_DynamicInputNameNotMatch_Failed) { - size_t dynamic_input_num = 1; - size_t has_optional_input = true; - size_t optional_input_num = has_optional_input ? 1U : 0U; - auto graph = BuildGraph4(dynamic_input_num, has_optional_input); - auto dynamic_op_ut_node = graph->FindNode("dynamic_op_ut"); - auto op_desc = dynamic_op_ut_node->GetOpDesc(); - - size_t index = 0; - auto &name_index = op_desc->MutableAllInputName(); - name_index.clear(); - - name_index["x0"] = index++; - name_index["x2"] = index++; // error name - - name_index["a"] = index++; - if (optional_input_num == 1) { - name_index["bias"] = index++; - } - name_index["b"] = index++; - - auto ir_index_to_instance_index_pair_map = OpDescUtils::GetInputIrIndexes2InstanceIndexesPairMap(op_desc); - ASSERT_TRUE(ir_index_to_instance_index_pair_map.empty()); -} - -void GetIrIndexCheck(size_t dynamic_input_num, bool has_optional_input) { - size_t optional_input_num = has_optional_input ? 1U : 0U; - auto graph = BuildGraph4(dynamic_input_num, has_optional_input); - auto dynamic_op_ut_node = graph->FindNode("dynamic_op_ut"); - auto op_desc = dynamic_op_ut_node->GetOpDesc(); - - size_t index = 0; - auto &name_index = op_desc->MutableAllInputName(); - name_index.clear(); - for (size_t i = 0U; i < dynamic_input_num; ++i) { - name_index["x" + std::to_string(i)] = index++; - } - name_index["a"] = index++; - if (optional_input_num == 1) { - name_index["bias"] = index++; - } - name_index["b"] = index++; - - index = 0U; - std::map expect_instance_index_to_ir_index_map; - for (size_t i = 0U; i < dynamic_input_num; ++i) { - expect_instance_index_to_ir_index_map[index++] = 0; - } - expect_instance_index_to_ir_index_map[index++] = 1; - if (has_optional_input) { - expect_instance_index_to_ir_index_map[index++] = 2; - } - expect_instance_index_to_ir_index_map[index++] = 3; - for (auto &instance_index_to_ir_index : expect_instance_index_to_ir_index_map) { - auto input_index = instance_index_to_ir_index.first; - size_t ir_index; - auto ret = OpDescUtils::GetInputIrIndexByInstanceIndex(op_desc, input_index, ir_index); - ASSERT_EQ(ret, GRAPH_SUCCESS); - ASSERT_EQ(ir_index, instance_index_to_ir_index.second); - } -} - -TEST_F(UtestOpDescUtils, GetInputIrIndexeByInstanceIndexe_Success) { - GetIrIndexCheck(0, true); - GetIrIndexCheck(0, false); - GetIrIndexCheck(1, true); - GetIrIndexCheck(1, false); - GetIrIndexCheck(3, true); - GetIrIndexCheck(3, false); -} - -TEST_F(UtestOpDescUtils, GetInputIrIndexeByInstanceIndexe_DynamicNameNotmatch_Failed) { - size_t dynamic_input_num = 1; - size_t has_optional_input = true; - size_t optional_input_num = has_optional_input ? 1U : 0U; - auto graph = BuildGraph4(dynamic_input_num, has_optional_input); - auto dynamic_op_ut_node = graph->FindNode("dynamic_op_ut"); - auto op_desc = dynamic_op_ut_node->GetOpDesc(); - - size_t index = 0; - auto &name_index = op_desc->MutableAllInputName(); - name_index.clear(); - - name_index["x0"] = index++; - name_index["x2"] = index++; // error name - - name_index["a"] = index++; - if (optional_input_num == 1) { - name_index["bias"] = index++; - } - name_index["b"] = index++; - size_t ir_index; - auto ret = OpDescUtils::GetInputIrIndexByInstanceIndex(op_desc, 2, ir_index); - ASSERT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestOpDescUtils, GetInputIrIndexeByInstanceIndexe_ActualInputsIsMoreThanIrInputsNum_Success) { - size_t dynamic_input_num = 1; - size_t has_optional_input = true; - size_t optional_input_num = has_optional_input ? 1U : 0U; - auto graph = BuildGraph4(dynamic_input_num, has_optional_input); - auto dynamic_op_ut_node = graph->FindNode("dynamic_op_ut"); - auto op_desc = dynamic_op_ut_node->GetOpDesc(); - - size_t index = 0; - auto &name_index = op_desc->MutableAllInputName(); - name_index.clear(); - - name_index["x0"] = index++; - name_index["x1"] = index++; // error name - - name_index["a"] = index++; - if (optional_input_num == 1) { - name_index["bias"] = index++; - } - name_index["b"] = index++; - name_index["assist_matrix"] = index++; - size_t ir_index; - - int32_t event_level; - int32_t old_level = dlog_getlevel(GE_MODULE_NAME, &event_level); - dlog_setlevel(GE_MODULE_NAME, DLOG_INFO, event_level); - auto ret = OpDescUtils::GetInputIrIndexByInstanceIndex(op_desc, 5, ir_index); - dlog_setlevel(GE_MODULE_NAME, old_level, event_level); - ASSERT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(ir_index, std::numeric_limits::max()); -} -TEST_F(UtestOpDescUtils, GetOutputIrIndexeByInstanceIndexe_NoOutput_Success) { - auto graph = BuildGraph5(); - auto node_without_outputs = graph->FindNode("noop"); - auto op_desc = node_without_outputs->GetOpDesc(); - - auto ir_index_to_instance_index_pair_map = OpDescUtils::GetOutputIrIndexes2InstanceIndexesPairMap(op_desc); - ASSERT_TRUE(ir_index_to_instance_index_pair_map.empty()); -} - -TEST_F(UtestOpDescUtils, GetOutputIrIndexeByInstanceIndexe_UnknownOutputIrType_Failed) { - auto graph = BuildGraph4(2, false); - auto dynamic_op_ut_node = graph->FindNode("dynamic_op_ut"); - auto op_desc = dynamic_op_ut_node->GetOpDesc(); - op_desc->AppendIrOutput("y", kIrOutputTypeEnd);// invalid IrType - - auto ir_index_to_instance_index_pair_map = OpDescUtils::GetOutputIrIndexes2InstanceIndexesPairMap(op_desc); - ASSERT_TRUE(ir_index_to_instance_index_pair_map.empty()); -} - -#define CHECK_IR_RANGE(Idx, Start, Num) \ - EXPECT_EQ(ir_ranges[Idx].first, Start); \ - EXPECT_EQ(ir_ranges[Idx].second, Num) - -REG_OP(DescUtilTestDynamicFirst) - .DYNAMIC_INPUT(input0, "T") - .INPUT(input1, "T") - .INPUT(input2, "T") - .DATATYPE(T, TensorType({DT_FLOAT})) - .OP_END_FACTORY_REG(DescUtilTestDynamicFirst); -TEST_F(UtestOpDescUtils, get_input_desc_range_for_dynamic_first_ir_desc_end) { - auto op = op::DescUtilTestDynamicFirst(); - op.create_dynamic_input_input0(2); // Dynamic desc出现在尾部 - auto desc = OpDescUtils::GetOpDescFromOperator(op); - auto ir_ranges = OpDescUtils::GetInputIrIndexes2InstanceIndexesPairMap(desc); - - EXPECT_EQ(ir_ranges.size(), 3); - // |0 |1 |2 |3 - // [input1, input2, input0:0, input0:1] - // dynamic input1 - CHECK_IR_RANGE(0, 2, 2); - // static input2 - CHECK_IR_RANGE(1, 0, 1); - // static input2 - CHECK_IR_RANGE(2, 1, 1); -} -TEST_F(UtestOpDescUtils, get_input_desc_range_for_dynamic_first_ir_desc_begin) { - auto op = op::DescUtilTestDynamicFirst(); - op.create_dynamic_input_input0(2, false); // Dynamic desc出现在头部 - auto desc = OpDescUtils::GetOpDescFromOperator(op); - auto ir_ranges = OpDescUtils::GetInputIrIndexes2InstanceIndexesPairMap(desc); - - EXPECT_EQ(ir_ranges.size(), 3); - // |0 |1 |2 |3 - // [input0:0, input0:1, input1, input2] - // dynamic input1 - CHECK_IR_RANGE(0, 0, 2); - // static input2 - CHECK_IR_RANGE(1, 2, 1); - // static input2 - CHECK_IR_RANGE(2, 3, 1); -} -TEST_F(UtestOpDescUtils, get_input_desc_range_for_dynamic_first_ir_desc_middle) { - auto op = op::DescUtilTestDynamicFirst(); - op.create_dynamic_input_byindex_input0(2, 1); // Dynamic desc出现在中间 - auto desc = OpDescUtils::GetOpDescFromOperator(op); - auto ir_ranges = OpDescUtils::GetInputIrIndexes2InstanceIndexesPairMap(desc); - - EXPECT_EQ(ir_ranges.size(), 3); - // |0 |1 |2 |3 - // [input1, input0:0, input0:1, input2] - // dynamic input1 - CHECK_IR_RANGE(0, 1, 2); - // static input2 - CHECK_IR_RANGE(1, 0, 1); - // static input2 - CHECK_IR_RANGE(2, 3, 1); -} - -REG_OP(DescUtilTestMultiDynamic) - .DYNAMIC_INPUT(input0, "T") - .DYNAMIC_INPUT(input1, "T") - .DATATYPE(T, TensorType({DT_FLOAT})) - .OP_END_FACTORY_REG(DescUtilTestMultiDynamic); -TEST_F(UtestOpDescUtils, get_input_desc_range_for_mulit_dynamic) { - auto op = op::DescUtilTestMultiDynamic(); - op.create_dynamic_input_input0(2); - op.create_dynamic_input_input1(2); - auto desc = OpDescUtils::GetOpDescFromOperator(op); - auto ir_ranges = OpDescUtils::GetInputIrIndexes2InstanceIndexesPairMap(desc); - - EXPECT_EQ(ir_ranges.size(), 2); - // |0 |1 |2 |3 - // [input0:0, input0:1, input1:0, input1:1] - // dynamic input0 - CHECK_IR_RANGE(0, 0, 2); - // dynamic input1 - CHECK_IR_RANGE(1, 2, 2); -} - -TEST_F(UtestOpDescUtils, get_input_desc_range_for_mulit_dynamic_mis_order) { - auto op = op::DescUtilTestMultiDynamic(); - op.create_dynamic_input_input1(2); // 首先创建input2 - op.create_dynamic_input_input0(2); - auto desc = OpDescUtils::GetOpDescFromOperator(op); - auto ir_ranges = OpDescUtils::GetInputIrIndexes2InstanceIndexesPairMap(desc); - - EXPECT_EQ(ir_ranges.size(), 2); - // |0 |1 |2 |3 - // [input1:0, input1:1, input0:0, input0:1] - // dynamic input0 - CHECK_IR_RANGE(0, 2, 2); - // dynamic input1 - CHECK_IR_RANGE(1, 0, 2); -} - -REG_OP(DescUtilTestUnfedOptional) - .OPTIONAL_INPUT(input0, "T") - .OPTIONAL_INPUT(input1, "T") - .OPTIONAL_INPUT(input2, "T") - .DATATYPE(T, TensorType({DT_FLOAT})) - .OP_END_FACTORY_REG(DescUtilTestUnfedOptional); -TEST_F(UtestOpDescUtils, get_input_desc_instance_range_for_unfed_optional) { - auto op = op::DescUtilTestUnfedOptional(); - auto desc = OpDescUtils::GetOpDescFromOperator(op); // 全部为optional且未feed - auto ir_ranges = OpDescUtils::GetInputIrIndexes2InstanceIndexesPairMap(desc); - - EXPECT_EQ(ir_ranges.size(), 3); - CHECK_IR_RANGE(0, 0, 0); - CHECK_IR_RANGE(1, 0, 0); - CHECK_IR_RANGE(2, 0, 0); -} -TEST_F(UtestOpDescUtils, get_input_desc_raw_range_for_unfed_optional) { - auto op = op::DescUtilTestUnfedOptional(); - auto desc = OpDescUtils::GetOpDescFromOperator(op); // 全部为optional且未feed - std::map> ir_ranges; - ASSERT_EQ(OpDescUtils::GetIrInputRawDescRange(desc, ir_ranges), GRAPH_SUCCESS); - - EXPECT_EQ(ir_ranges.size(), 3); - CHECK_IR_RANGE(0, 0, 0); // Raw range会存储其desc在数据上的位置 - CHECK_IR_RANGE(1, 1, 0); - CHECK_IR_RANGE(2, 2, 0); -} -TEST_F(UtestOpDescUtils, get_input_desc_instance_range_for_unfed_optional_begin) { - auto op = op::DescUtilTestUnfedOptional(); - auto desc = OpDescUtils::GetOpDescFromOperator(op); - desc->UpdateInputDesc("input1", GeTensorDesc()); - desc->UpdateInputDesc("input2", GeTensorDesc()); - auto ir_ranges = OpDescUtils::GetInputIrIndexes2InstanceIndexesPairMap(desc); - - EXPECT_EQ(ir_ranges.size(), 3); - CHECK_IR_RANGE(0, 0, 0); - CHECK_IR_RANGE(1, 0, 1); - CHECK_IR_RANGE(2, 1, 1); -} -TEST_F(UtestOpDescUtils, get_input_desc_raw_range_for_unfed_optional_begin) { - auto op = op::DescUtilTestUnfedOptional(); - auto desc = OpDescUtils::GetOpDescFromOperator(op); - desc->UpdateInputDesc("input1", GeTensorDesc()); - desc->UpdateInputDesc("input2", GeTensorDesc()); - std::map> ir_ranges; - ASSERT_EQ(OpDescUtils::GetIrInputRawDescRange(desc, ir_ranges), GRAPH_SUCCESS); - - EXPECT_EQ(ir_ranges.size(), 3); - CHECK_IR_RANGE(0, 0, 0); - CHECK_IR_RANGE(1, 1, 1); - CHECK_IR_RANGE(2, 2, 1); -} -TEST_F(UtestOpDescUtils, get_input_desc_instance_range_for_unfed_optional_middle) { - auto op = op::DescUtilTestUnfedOptional(); - auto desc = OpDescUtils::GetOpDescFromOperator(op); - desc->UpdateInputDesc("input0", GeTensorDesc()); - desc->UpdateInputDesc("input2", GeTensorDesc()); - auto ir_ranges = OpDescUtils::GetInputIrIndexes2InstanceIndexesPairMap(desc); - - EXPECT_EQ(ir_ranges.size(), 3); - CHECK_IR_RANGE(0, 0, 1); - CHECK_IR_RANGE(1, 1, 0); - CHECK_IR_RANGE(2, 1, 1); -} -TEST_F(UtestOpDescUtils, get_input_desc_raw_range_for_unfed_optional_middle) { - auto op = op::DescUtilTestUnfedOptional(); - auto desc = OpDescUtils::GetOpDescFromOperator(op); - desc->UpdateInputDesc("input0", GeTensorDesc()); - desc->UpdateInputDesc("input2", GeTensorDesc()); - std::map> ir_ranges; - ASSERT_EQ(OpDescUtils::GetIrInputRawDescRange(desc, ir_ranges), GRAPH_SUCCESS); - - EXPECT_EQ(ir_ranges.size(), 3); - CHECK_IR_RANGE(0, 0, 1); - CHECK_IR_RANGE(1, 1, 0); - CHECK_IR_RANGE(2, 2, 1); -} -TEST_F(UtestOpDescUtils, get_input_desc_instance_range_for_unfed_optional_end) { - auto op = op::DescUtilTestUnfedOptional(); - auto desc = OpDescUtils::GetOpDescFromOperator(op); - desc->UpdateInputDesc("input0", GeTensorDesc()); - desc->UpdateInputDesc("input1", GeTensorDesc()); - auto ir_ranges = OpDescUtils::GetInputIrIndexes2InstanceIndexesPairMap(desc); - - EXPECT_EQ(ir_ranges.size(), 3); - CHECK_IR_RANGE(0, 0, 1); - CHECK_IR_RANGE(1, 1, 1); - CHECK_IR_RANGE(2, 2, 0); -} -TEST_F(UtestOpDescUtils, get_input_desc_raw_range_for_unfed_optional_end) { - auto op = op::DescUtilTestUnfedOptional(); - auto desc = OpDescUtils::GetOpDescFromOperator(op); - desc->UpdateInputDesc("input0", GeTensorDesc()); - desc->UpdateInputDesc("input1", GeTensorDesc()); - std::map> ir_ranges; - ASSERT_EQ(OpDescUtils::GetIrInputRawDescRange(desc, ir_ranges), GRAPH_SUCCESS); - - EXPECT_EQ(ir_ranges.size(), 3); - CHECK_IR_RANGE(0, 0, 1); - CHECK_IR_RANGE(1, 1, 1); - CHECK_IR_RANGE(2, 2, 0); -} - -REG_OP(OpTesGetPromoteInputList1) - .INPUT(input1, "T1") - .DYNAMIC_INPUT(input2, "T2") - .OUTPUT(output1, "T3") - .DATATYPE(T1, TensorType({DT_INT32, DT_FLOAT})) - .DATATYPE(T2, TensorType({DT_INT64, DT_FLOAT})) - .DATATYPE(T3, Promote({"T1", "T2"})) - .OP_END_FACTORY_REG(OpTesGetPromoteInputList1); - -TEST_F(UtestOpDescUtils, get_promote_input_list_one_output) { - auto op = op::OpTesGetPromoteInputList1(); - op.create_dynamic_input_input2(2); - auto desc = OpDescUtils::GetOpDescFromOperator(op); - std::vector> ir_input_list; - OpDescUtils::GetPromoteIrInputList(desc, ir_input_list); - EXPECT_EQ(ir_input_list.size(), 1); - EXPECT_EQ(ir_input_list[0].size(), 2); - EXPECT_EQ(ir_input_list[0][0], 0); - EXPECT_EQ(ir_input_list[0][1], 1); - - std::vector> instance_input_list; - OpDescUtils::GetPromoteInstanceInputList(desc, instance_input_list); - EXPECT_EQ(instance_input_list.size(), 1); - EXPECT_EQ(instance_input_list[0].size(), 3); - EXPECT_EQ(instance_input_list[0][0], 0); - EXPECT_EQ(instance_input_list[0][1], 1); - EXPECT_EQ(instance_input_list[0][2], 2); -} - -REG_OP(OpTesGetPromoteInputList2) - .INPUT(input1, "T1") - .OPTIONAL_INPUT(input2, "T2") - .INPUT(input3, "T3") - .OPTIONAL_INPUT(input4, "T4") - .OUTPUT(output1, "T5") - .OUTPUT(output2, "T6") - .DATATYPE(T1, TensorType({DT_INT32, DT_FLOAT})) - .DATATYPE(T2, TensorType({DT_INT64, DT_FLOAT})) - .DATATYPE(T3, TensorType({DT_INT32, DT_FLOAT})) - .DATATYPE(T4, TensorType({DT_INT64, DT_FLOAT})) - .DATATYPE(T5, Promote({"T1", "T2"})) - .DATATYPE(T6, Promote({"T3", "T4"})) - .OP_END_FACTORY_REG(OpTesGetPromoteInputList2); - -TEST_F(UtestOpDescUtils, get_promote_input_list_outputs) { - auto op = op::OpTesGetPromoteInputList2(); - auto desc = OpDescUtils::GetOpDescFromOperator(op); - desc->UpdateInputDesc("input2", GeTensorDesc()); - - std::vector> ir_input_list; - OpDescUtils::GetPromoteIrInputList(desc, ir_input_list); - EXPECT_EQ(ir_input_list.size(), 2); - EXPECT_EQ(ir_input_list[0].size(), 2); - EXPECT_EQ(ir_input_list[1].size(), 2); - EXPECT_EQ(ir_input_list[0][0], 0); - EXPECT_EQ(ir_input_list[0][1], 1); - EXPECT_EQ(ir_input_list[1][0], 2); - EXPECT_EQ(ir_input_list[1][1], 3); - - std::vector> instance_input_list; - OpDescUtils::GetPromoteInstanceInputList(desc, instance_input_list); - EXPECT_EQ(instance_input_list.size(), 2); - EXPECT_EQ(instance_input_list[0].size(), 2); - EXPECT_EQ(instance_input_list[1].size(), 1); - EXPECT_EQ(instance_input_list[0][0], 0); - EXPECT_EQ(instance_input_list[0][1], 1); - EXPECT_EQ(instance_input_list[1][0], 2); -} - -REG_OP(OpTesGetPromoteInputList3) - .INPUT(input1, "T") - .INPUT(input2, "T") - .OUTPUT(output1, "T") - .DATATYPE(T, TensorType({DT_INT32})) - .OP_END_FACTORY_REG(OpTesGetPromoteInputList3); - -TEST_F(UtestOpDescUtils, get_promote_input_list_none_output) { - auto op = op::OpTesGetPromoteInputList3(); - auto desc = OpDescUtils::GetOpDescFromOperator(op); - - std::vector> ir_input_list; - OpDescUtils::GetPromoteIrInputList(desc, ir_input_list); - EXPECT_TRUE(ir_input_list.empty()); - std::vector> instance_input_list; - OpDescUtils::GetPromoteInstanceInputList(desc, instance_input_list); - EXPECT_TRUE(instance_input_list.empty()); -} - -TEST_F(UtestOpDescUtils, CreateConstOpWithOutCopy) { - ge::GeTensorDesc ge_tensor(GeShape({8,8,8}), FORMAT_ND, DT_FLOAT16); - ge_tensor.SetName("test"); - ge::GeTensorPtr const_tensor_ptr = ge::MakeShared(ge_tensor); - ge::OpDescPtr const_op_desc = ge::OpDescUtils::CreateConstOpZeroCopy(const_tensor_ptr); - ge::Operator const_op = ge::OpDescUtils::CreateOperatorFromOpDesc(const_op_desc); - (void) const_op.SetInput("test", const_op); - ConstGeTensorPtr weight; - EXPECT_TRUE(ConstantUtils::GetWeight(const_op_desc, 0UL, weight) == true); - EXPECT_TRUE(weight->GetData().GetData() == const_tensor_ptr->GetData().GetData()); -} - -TEST_F(UtestOpDescUtils, CreateConstOpWithCopy) { - ge::GeTensorDesc ge_tensor(GeShape({8,8,8}), FORMAT_ND, DT_FLOAT16); - ge_tensor.SetName("test"); - ge::GeTensorPtr const_tensor_ptr = ge::MakeShared(ge_tensor); - constexpr int32_t kAlignedSize = 256; - auto aligned_ptr = std::make_shared(kAlignedSize); - const_tensor_ptr->SetData(aligned_ptr, 256); - ge::OpDescPtr const_op_desc = ge::OpDescUtils::CreateConstOp(const_tensor_ptr); - ge::Operator const_op = ge::OpDescUtils::CreateOperatorFromOpDesc(const_op_desc); - (void) const_op.SetInput("test", const_op); - ConstGeTensorPtr weight; - EXPECT_TRUE(ConstantUtils::GetWeight(const_op_desc, 0UL, weight) == true); - EXPECT_TRUE(weight->GetData().GetData() != const_tensor_ptr->GetData().GetData()); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/op_imp_unittest.cc b/tests/ut/graph/testcase/op_imp_unittest.cc deleted file mode 100644 index 6fdbdd23c29f12e9173f71a3cdc88436ae80cb38..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/op_imp_unittest.cc +++ /dev/null @@ -1,93 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/operator_reg.h" -#include -#include - -namespace ge { -class BroadCastInferUt : public testing::Test {}; - -TEST_F(BroadCastInferUt, Scalar1) { - std::vector ret_shape; - auto ret = BroadCastInfer( - []() { return std::vector({1, 2, 3});}, - []() { return std::vector({}); }, - [&ret_shape](const std::vector &out_shape) { ret_shape = out_shape; }); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(ret_shape, std::vector({1, 2, 3})); - - ret_shape.clear(); - ret = BroadCastInfer( - []() { return std::vector({});}, - []() { return std::vector({1, 2, 3}); }, - [&ret_shape](const std::vector &out_shape) { ret_shape = out_shape; }); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(ret_shape, std::vector({1, 2, 3})); -} - -TEST_F(BroadCastInferUt, SameShape) { - std::vector ret_shape; - auto ret = BroadCastInfer( - []() { return std::vector({1, 2, 3});}, - []() { return std::vector({1, 2, 3}); }, - [&ret_shape](const std::vector &out_shape) { ret_shape = out_shape; }); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(ret_shape, std::vector({1, 2, 3})); -} - -TEST_F(BroadCastInferUt, BroadCastDim1) { - std::vector ret_shape; - auto ret = BroadCastInfer( - []() { return std::vector({3, 2, 1});}, - []() { return std::vector({1, 2, 3}); }, - [&ret_shape](const std::vector &out_shape) { ret_shape = out_shape; }); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(ret_shape, std::vector({3, 2, 3})); -} - -TEST_F(BroadCastInferUt, BroadCastRank) { - std::vector ret_shape; - auto ret = BroadCastInfer( - []() { return std::vector({1, 2, 3, 4});}, - []() { return std::vector({3, 4}); }, - [&ret_shape](const std::vector &out_shape) { ret_shape = out_shape; }); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(ret_shape, std::vector({1, 2, 3, 4})); -} - -TEST_F(BroadCastInferUt, BroadCastRankAndDim1) { - std::vector ret_shape; - auto ret = BroadCastInfer( - []() { return std::vector({1, 2, 1, 4});}, - []() { return std::vector({5, 4}); }, - [&ret_shape](const std::vector &out_shape) { ret_shape = out_shape; }); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(ret_shape, std::vector({1, 2, 5, 4})); -} - -TEST_F(BroadCastInferUt, BroadCastFailed_DimDiff) { - std::vector ret_shape; - auto ret = BroadCastInfer( - []() { return std::vector({1, 2, 3, 4});}, - []() { return std::vector({5, 4}); }, - [&ret_shape](const std::vector &out_shape) { ret_shape = out_shape; }); - EXPECT_NE(ret, GRAPH_SUCCESS); -} - -TEST_F(BroadCastInferUt, BroadCastRankAndDim1_1) { - std::vector ret_shape; - auto ret = BroadCastInfer( - []() { return std::vector({5, 4});}, - []() { return std::vector({1, 2, 1, 4}); }, - [&ret_shape](const std::vector &out_shape) { ret_shape = out_shape; }); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(ret_shape, std::vector({1, 2, 5, 4})); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/op_type_utils_unittest.cc b/tests/ut/graph/testcase/op_type_utils_unittest.cc deleted file mode 100644 index bebd8e15815d5bd7ca8852741944582d17edc7dd..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/op_type_utils_unittest.cc +++ /dev/null @@ -1,140 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/utils/op_type_utils.h" -#include "graph/debug/ge_op_types.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/debug/ge_util.h" -#include "graph/compute_graph.h" - -namespace ge { -class UtestOpTypeUtils : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -TEST_F(UtestOpTypeUtils, TestDataNodeType) { - std::string test_node_type = "Data"; - EXPECT_TRUE(OpTypeUtils::IsDataNode(test_node_type)); - EXPECT_FALSE(OpTypeUtils::IsVariableNode(test_node_type)); - EXPECT_FALSE(OpTypeUtils::IsVarLikeNode(test_node_type)); - - test_node_type = "AnnData"; - EXPECT_TRUE(OpTypeUtils::IsDataNode(test_node_type)); - - test_node_type = "AippData"; - EXPECT_TRUE(OpTypeUtils::IsDataNode(test_node_type)); - - test_node_type = "RefData"; - EXPECT_TRUE(OpTypeUtils::IsDataNode(test_node_type)); -} - -TEST_F(UtestOpTypeUtils, TestVariableNodeType) { - std::string test_node_type = "Variable"; - EXPECT_TRUE(OpTypeUtils::IsVariableNode(test_node_type)); - EXPECT_TRUE(OpTypeUtils::IsVarLikeNode(test_node_type)); - - test_node_type = "VariableV2"; - EXPECT_TRUE(OpTypeUtils::IsVariableNode(test_node_type)); - EXPECT_TRUE(OpTypeUtils::IsVarLikeNode(test_node_type)); -} - -TEST_F(UtestOpTypeUtils, TestVariableLikeNodeType) { - std::string test_node_type = "RefData"; - EXPECT_FALSE(OpTypeUtils::IsVariableNode(test_node_type)); - EXPECT_TRUE(OpTypeUtils::IsVarLikeNode(test_node_type)); -} - -TEST_F(UtestOpTypeUtils, TestGetOriginalTypeFailed) { - ge::OpDescPtr op_desc = std::make_shared("A", FRAMEWORKOP); - std::shared_ptr graph = std::make_shared("test1"); - ge::NodePtr node = graph->AddNode(op_desc); - - std::string original_type; - EXPECT_EQ(OpTypeUtils::GetOriginalType(node->GetOpDesc(), original_type), INTERNAL_ERROR); -} - -TEST_F(UtestOpTypeUtils, TestGetOriginalTypeSuccess) { - ge::OpDescPtr op_desc = std::make_shared("A", FRAMEWORKOP); - std::shared_ptr graph = std::make_shared("test1"); - ge::NodePtr node = graph->AddNode(op_desc); - std::string type = "GetNext"; - node->GetOpDesc()->SetType(FRAMEWORKOP); - ge::AttrUtils::SetStr(node->GetOpDesc(), ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); - std::string original_type; - EXPECT_EQ(OpTypeUtils::GetOriginalType(node->GetOpDesc(), original_type), GRAPH_SUCCESS); - EXPECT_EQ(original_type, type); -} - -TEST_F(UtestOpTypeUtils, TestIsInputRefData) { - ge::OpDescPtr ref_data_op_desc = std::make_shared("RefData", REFDATA); - ge::OpDescPtr data_op_desc = std::make_shared("Data", DATA); - EXPECT_EQ(OpTypeUtils::IsInputRefData(ref_data_op_desc), true); - (void) AttrUtils::SetStr(ref_data_op_desc, REF_VAR_SRC_VAR_NAME, "1"); - EXPECT_EQ(OpTypeUtils::IsInputRefData(ref_data_op_desc), false); - EXPECT_EQ(OpTypeUtils::IsInputRefData(data_op_desc), false); -} - -TEST_F(UtestOpTypeUtils, TestIsAutofuseNode) { - ge::OpDescPtr asc_bc = std::make_shared("asc_bc", ASC_BC); - ge::OpDescPtr fuse_asc_bc = std::make_shared("fuse_asc_bc", FUSE_ASC_BC); - ge::OpDescPtr empty_asc_bc = std::make_shared("empty_asc_bc", EMPTY_ASC_BC); - ge::OpDescPtr unknown = std::make_shared("unknown", "UNKNOWN"); - EXPECT_EQ(OpTypeUtils::IsAutofuseNode(asc_bc), true); - EXPECT_EQ(OpTypeUtils::IsAutofuseNode(fuse_asc_bc), true); - EXPECT_EQ(OpTypeUtils::IsAutofuseNode(empty_asc_bc), true); - EXPECT_EQ(OpTypeUtils::IsEmptyAutofuseNode(empty_asc_bc->GetType()), true); - EXPECT_EQ(OpTypeUtils::IsAutofuseNode(unknown), false); -} -TEST_F(UtestOpTypeUtils, TestIsAutofuseNodeWithType) { - EXPECT_EQ(OpTypeUtils::IsAutofuseNode(ASC_BC), true); - EXPECT_EQ(OpTypeUtils::IsAutofuseNode(FUSE_ASC_BC), true); - EXPECT_EQ(OpTypeUtils::IsAutofuseNode(EMPTY_ASC_BC), true); - EXPECT_EQ(OpTypeUtils::IsAutofuseNode("UNKNOWN"), false); -} - -TEST_F(UtestOpTypeUtils, TestIsConstNode) { - ge::OpDescPtr constant = std::make_shared("constant", CONSTANT); - ge::OpDescPtr constant_op = std::make_shared("constant_op", CONSTANTOP); - ge::OpDescPtr unknown = std::make_shared("unknown", "UNKNOWN"); - EXPECT_EQ(OpTypeUtils::IsConstNode(constant->GetType()), true); - EXPECT_EQ(OpTypeUtils::IsConstNode(constant_op->GetType()), true); - EXPECT_EQ(OpTypeUtils::IsConstNode(unknown->GetType()), false); -} - -TEST_F(UtestOpTypeUtils, TestIsGraphInput) { - ge::OpDescPtr data = std::make_shared("data", DATA); - ge::OpDescPtr variable = std::make_shared("variable", VARIABLE); - ge::OpDescPtr variable_v2 = std::make_shared("variable", VARIABLEV2); - ge::OpDescPtr ref_data = std::make_shared("ref_data", REFDATA); - ge::OpDescPtr constant = std::make_shared("constant", CONSTANT); - ge::OpDescPtr constant_op = std::make_shared("constant_op", CONSTANTOP); - ge::OpDescPtr unknown = std::make_shared("unknown", "UNKNOWN"); - EXPECT_EQ(OpTypeUtils::IsGraphInputNode(data->GetType()), true); - EXPECT_EQ(OpTypeUtils::IsGraphInputNode(variable->GetType()), true); - EXPECT_EQ(OpTypeUtils::IsGraphInputNode(variable_v2->GetType()), true); - EXPECT_EQ(OpTypeUtils::IsGraphInputNode(ref_data->GetType()), true); - EXPECT_EQ(OpTypeUtils::IsGraphInputNode(constant->GetType()), true); - EXPECT_EQ(OpTypeUtils::IsGraphInputNode(constant_op->GetType()), true); - EXPECT_EQ(OpTypeUtils::IsGraphInputNode(unknown->GetType()), false); -} - -TEST_F(UtestOpTypeUtils, TestIsGraphOutput) { - ge::OpDescPtr data = std::make_shared("data", DATA); - ge::OpDescPtr net_output = std::make_shared("net_output", NETOUTPUT); - ge::OpDescPtr unknown = std::make_shared("unknown", "UNKNOWN"); - EXPECT_EQ(OpTypeUtils::IsGraphOutputNode(data->GetType()), false); - EXPECT_EQ(OpTypeUtils::IsGraphOutputNode(net_output->GetType()), true); - EXPECT_EQ(OpTypeUtils::IsGraphOutputNode(unknown->GetType()), false); -} - -} // namespace ge diff --git a/tests/ut/graph/testcase/op_utils_unittest.cc b/tests/ut/graph/testcase/op_utils_unittest.cc deleted file mode 100644 index 6ee3b27b8bf8f96521ccbf71d6cf36fda2f927e2..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/op_utils_unittest.cc +++ /dev/null @@ -1,136 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include "graph/op_desc.h" -#include "graph/normal_graph/op_desc_impl.h" -#include "graph/ge_tensor.h" -#include "graph/utils/ge_ir_utils.h" -#include "graph/utils/op_desc_utils_ex.h" -#include "graph/utils/transformer_utils.h" -#include "graph/common_error_codes.h" -#include "graph/operator_factory_impl.h" -#include "register/op_tiling_registry.h" -#include "external/graph/operator_factory.h" -#include "graph/utils/op_desc_utils.h" -#include "external/graph/operator_reg.h" -#include "external/register/op_impl_registry.h" -#include "register/op_impl_space_registry.h" -#include "register/op_impl_kernel_registry.h" -#include "external/graph/infer_format_context.h" - -namespace ge { -class UtestOpDescUtilsEx : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -TEST_F(UtestOpDescUtilsEx, OpVerify_success) { - auto op_desc = std::make_shared(); - EXPECT_EQ(OpDescUtilsEx::OpVerify(op_desc), GRAPH_SUCCESS); - const auto verify_func = [](Operator &op) { - return GRAPH_SUCCESS; - }; - op_desc->AddVerifierFunc(verify_func); - EXPECT_EQ(OpDescUtilsEx::OpVerify(op_desc), GRAPH_SUCCESS); -} - -TEST_F(UtestOpDescUtilsEx, InferShapeAndType_success) { - auto op_desc = std::make_shared(); - EXPECT_EQ(OpDescUtilsEx::InferShapeAndType(op_desc), GRAPH_SUCCESS); - const auto add_func = [](Operator &op) { - return GRAPH_SUCCESS; - }; - op_desc->AddInferFunc(add_func); - EXPECT_EQ(OpDescUtilsEx::InferShapeAndType(op_desc), GRAPH_SUCCESS); -} - -TEST_F(UtestOpDescUtilsEx, InferDataSlice_success) { - auto op_desc = std::make_shared(); - EXPECT_EQ(OpDescUtilsEx::InferDataSlice(op_desc), NO_DEPENDENCE_FUNC); - const auto infer_data_slice_func = [](Operator &op) { - return GRAPH_SUCCESS; - }; - auto op = std::make_shared(); - op_desc->SetType("test"); - OperatorFactoryImpl::RegisterInferDataSliceFunc("test",infer_data_slice_func); - EXPECT_EQ(OpDescUtilsEx::InferDataSlice(op_desc), GRAPH_SUCCESS); -} - -REG_OP(FixInfer_OutputIsFix) - .INPUT(fix_input1, "T") - .INPUT(fix_input2, "T") - .OUTPUT(fix_output, "T2") - .DATATYPE(T2, TensorType({DT_BOOL})) - .OP_END_FACTORY_REG(FixInfer_OutputIsFix); -TEST_F(UtestOpDescUtilsEx, CallInferFormatFunc_success) { - auto op = OperatorFactory::CreateOperator("test", "FixInfer_OutputIsFix"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - op_desc->SetType("test"); - const auto infer_format_func = [](Operator &op) { - return GRAPH_SUCCESS; - }; - OperatorFactoryImpl::RegisterInferFormatFunc("test", infer_format_func); - EXPECT_EQ(OpDescUtilsEx::CallInferFormatFunc(op_desc, op), GRAPH_SUCCESS); -} - -TEST_F(UtestOpDescUtilsEx, SetType_success) { - auto op_desc = std::make_shared(); - string type = "tmp"; - OpDescUtilsEx::SetType(op_desc, type); - EXPECT_EQ(op_desc->GetType(), type); -} - -TEST_F(UtestOpDescUtilsEx, SetTypeAndResetFuncHandle_success) { - auto op_desc = std::make_shared(); - string type = "tmp"; - OpDescUtilsEx::SetTypeAndResetFuncHandle(op_desc, type); - EXPECT_EQ(op_desc->GetType(), type); - EXPECT_EQ(op_desc->GetInferFunc(), nullptr); - EXPECT_EQ(op_desc->GetInferFormatFunc(), nullptr); - EXPECT_EQ(op_desc->GetInferValueRangeFunc(), nullptr); - EXPECT_EQ(op_desc->GetVerifyFunc(), nullptr); - EXPECT_EQ(op_desc->GetInferDataSliceFunc(), nullptr); -} - -TEST_F(UtestOpDescUtilsEx, CallInferFormatFunc_v2_success) { - auto op = OperatorFactory::CreateOperator("test", "FixInfer_OutputIsFix"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - const auto infer_format_func = [](Operator &op) { - return GRAPH_SUCCESS; - }; - OperatorFactoryImpl::RegisterInferFormatFunc("FixInfer_OutputIsFix", infer_format_func); - - const auto infer_format_func_v2 = [](gert::InferFormatContext *context) -> UINT32 { - auto output_1 = context->GetRequiredOutputFormat(0U); - output_1->SetOriginFormat(Format::FORMAT_NCHW); - output_1->SetStorageFormat(Format::FORMAT_NCHW); - return GRAPH_SUCCESS; - }; - IMPL_OP(FixInfer_OutputIsFix).InferFormat(infer_format_func_v2); - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctionsV2 op_impl_func; - op_impl_func.infer_format_func = infer_format_func_v2; - registry_holder->AddTypesToImpl("FixInfer_OutputIsFix", op_impl_func); - space_registry->AddRegistry(registry_holder); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - EXPECT_EQ(OpDescUtilsEx::CallInferFormatFunc(op_desc, op), GRAPH_SUCCESS); - const auto &output_0 = op_desc->GetOutputDesc(0U); - EXPECT_EQ(output_0.GetOriginFormat(), Format::FORMAT_NCHW); - EXPECT_EQ(output_0.GetFormat(), Format::FORMAT_NCHW); - -} -} \ No newline at end of file diff --git a/tests/ut/graph/testcase/operator_constuct_graph_unittest.cc b/tests/ut/graph/testcase/operator_constuct_graph_unittest.cc deleted file mode 100644 index 35318123a4598a441a488b8111a539894e8c2991..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/operator_constuct_graph_unittest.cc +++ /dev/null @@ -1,451 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/operator_reg.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/graph_utils_ex.h" -#include "graph/attr_value.h" -#include "external/graph/operator_factory.h" -#include "graph/operator_factory_impl.h" - -namespace ge { -REG_OP(Const) - .OUTPUT(y, - TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, - DT_UINT64, DT_BOOL, DT_DOUBLE})) - .ATTR(value, Tensor, Tensor()) - - .OP_END_FACTORY_REG(Const); - -REG_OP(OCG2) - .DYNAMIC_INPUT(x, TensorType::NumberType()) - .OUTPUT(y, TensorType::NumberType()) - .REQUIRED_ATTR(N, Int) - .OP_END_FACTORY_REG(OCG2); - -REG_OP(OCG3) - .INPUT(x, - TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, - DT_UINT64, DT_BOOL, DT_DOUBLE})) - .INPUT(shape, TensorType({DT_INT32, DT_INT64})) - .OUTPUT(y, - TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, - DT_UINT64, DT_BOOL, DT_DOUBLE})) - .ATTR(axis, Int, 0) - .ATTR(num_axes, Int, -1) - .OP_END_FACTORY_REG(OCG3); - -REG_OP(OCG4) - .INPUT(cond, TensorType::ALL()) - .DYNAMIC_INPUT(input, TensorType::ALL()) - .DYNAMIC_OUTPUT(output, TensorType::ALL()) - .GRAPH(then_branch) - .GRAPH(else_branch) - .OP_END_FACTORY_REG(OCG4); - -REG_OP(OCG5) - .INPUT(branch_index, TensorType({DT_INT32})) - .DYNAMIC_INPUT(input, TensorType::ALL()) - .DYNAMIC_OUTPUT(output, TensorType::ALL()) - .DYNAMIC_GRAPH(branches) - .OP_END_FACTORY_REG(OCG5); - -class OperatorConstructGraphUt : public testing::Test {}; - -/** - * c - * ocg2 ----> const - * | - * ocg2 - * | - * ocg3 - * / \ - * const const - */ -Graph BuildGraph1() { - auto o1_1 = op::Const("o1_1"); - auto o1_2 = op::Const("o1_2"); - auto o3 = op::OCG3("o3"); - auto o2_1 = op::OCG2("o2_1"); - auto o2_2 = op::OCG2("o2_2"); - auto o1_3 = op::Const("o1_3"); - - TensorDesc td{Shape(std::vector({8, 3, 224, 224})), FORMAT_NCHW, DT_UINT8}; - Tensor tensor(td); - tensor.SetData(std::vector(8 * 3 * 224 * 224)); - - o1_1.set_attr_value(tensor); - o1_2.set_attr_value(tensor); - o1_3.set_attr_value(tensor); - o3.set_input_x(o1_1).set_input_shape_by_name(o1_2, "y"); - o2_1.create_dynamic_input_x(1, true).set_dynamic_input_x(0, o3); - o2_2.create_dynamic_input_x(1, true).set_dynamic_input_x(0, o2_1, "y"); - o1_3.AddControlInput(o2_2); - - Graph g{"name"}; - g.SetInputs(std::vector({o1_1, o1_2})).SetOutputs(std::vector({o2_2, o1_3})); - return g; -} - -Graph BuildGraph1ByGnode() { - Graph g{"name"}; - auto o1_1 = g.AddNodeByOp(op::Const("o1_1")); - auto o1_2 = g.AddNodeByOp(op::Const("o1_2")); - auto o3 = g.AddNodeByOp(op::OCG3("o3")); - auto o2_1 = g.AddNodeByOp(op::OCG2("o2_1").create_dynamic_input_x(1, true)); - auto o2_2 = g.AddNodeByOp(op::OCG2("o2_2").create_dynamic_input_x(1, true)); - auto o1_3 = g.AddNodeByOp(op::Const("o1_3")); - - TensorDesc td{Shape(std::vector({8, 3, 224, 224})), FORMAT_NCHW, DT_UINT8}; - Tensor tensor(td); - tensor.SetData(std::vector(8 * 3 * 224 * 224)); - - o1_1.SetAttr("value", tensor); - o1_2.SetAttr("value", tensor); - o1_3.SetAttr("value", tensor); - g.AddDataEdge(o1_1, 0, o3, 0); - g.AddDataEdge(o1_2, 0, o3, 1); - g.AddDataEdge(o3,0,o2_1, 0); - g.AddDataEdge(o2_1,0,o2_2, 0); - g.AddControlEdge(o2_2, o1_3); - auto o1_4 = op::Const("o1_1"); - g.SetInputs({o1_4}); // 前面已经使用了AddNodeByOp的方式构图,此时调用设置失败 - return g; -} -Graph BuildGraph1ByIndex() { - auto o1_1 = op::Const("o1_1"); - auto o1_2 = op::Const("o1_2"); - auto o3 = op::OCG3("o3"); - auto o2_1 = op::OCG2("o2_1"); - auto o2_2 = op::OCG2("o2_2"); - auto o1_3 = op::Const("o1_3"); - - TensorDesc td{Shape(std::vector({8, 3, 224, 224})), FORMAT_NCHW, DT_UINT8}; - Tensor tensor(td); - tensor.SetData(std::vector(8 * 3 * 224 * 224)); - - o1_1.set_attr_value(tensor); - o1_2.set_attr_value(tensor); - o1_3.set_attr_value(tensor); - o3.set_input_x(o1_1, 0).set_input_shape(o1_2, 0); - o2_1.create_dynamic_input_x(1, true).set_dynamic_input_x(0, o3); - o2_2.create_dynamic_input_x(1, true).set_dynamic_input_x(0, o2_1, "y"); - o1_3.AddControlInput(o2_2); - - Graph g{"name"}; - g.SetInputs(std::vector({o1_1, o1_2})).SetOutputs(std::vector({o2_2, o1_3})); - return g; -} - -void CheckGraph1(Graph &g) { - auto cg = GraphUtilsEx::GetComputeGraph(g); - EXPECT_NE(cg, nullptr); - - EXPECT_EQ(cg->GetAllNodesSize(), 6); - auto node_o1_1 = cg->FindNode("o1_1"); - auto node_o1_2 = cg->FindNode("o1_2"); - auto node_o1_3 = cg->FindNode("o1_3"); - auto node_o2_1 = cg->FindNode("o2_1"); - auto node_o2_2 = cg->FindNode("o2_2"); - auto node_o3 = cg->FindNode("o3"); - EXPECT_NE(node_o1_1, nullptr); - EXPECT_NE(node_o1_2, nullptr); - EXPECT_NE(node_o1_3, nullptr); - EXPECT_NE(node_o2_1, nullptr); - EXPECT_NE(node_o2_2, nullptr); - EXPECT_NE(node_o3, nullptr); - EXPECT_EQ(node_o1_1->GetType(), "Const"); - EXPECT_EQ(node_o1_2->GetType(), "Const"); - EXPECT_EQ(node_o1_3->GetType(), "Const"); - EXPECT_EQ(node_o2_1->GetType(), "OCG2"); - EXPECT_EQ(node_o2_2->GetType(), "OCG2"); - EXPECT_EQ(node_o3->GetType(), "OCG3"); - EXPECT_EQ(node_o1_1->GetOpDesc()->HasAttr("value"), true); - EXPECT_EQ(node_o1_2->GetOpDesc()->HasAttr("value"), true); - EXPECT_EQ(node_o1_3->GetOpDesc()->HasAttr("value"), true); - - EXPECT_EQ(node_o1_1->GetName(), "o1_1"); - EXPECT_EQ(node_o1_2->GetName(), "o1_2"); - EXPECT_EQ(node_o1_3->GetName(), "o1_3"); - EXPECT_EQ(node_o2_1->GetName(), "o2_1"); - EXPECT_EQ(node_o2_2->GetName(), "o2_2"); - EXPECT_EQ(node_o3->GetName(), "o3"); - - EXPECT_EQ(node_o1_1->GetOutDataAnchor(0)->GetPeerInDataAnchors().size(), 1); - EXPECT_EQ(node_o1_1->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetIdx(), 0); - EXPECT_EQ(node_o1_1->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetOwnerNode()->GetName(), "o3"); - - EXPECT_EQ(node_o1_2->GetOutDataAnchor(0)->GetPeerInDataAnchors().size(), 1); - EXPECT_EQ(node_o1_2->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetIdx(), 1); - EXPECT_EQ(node_o1_2->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetOwnerNode()->GetName(), "o3"); - - EXPECT_EQ(node_o3->GetOutDataAnchor(0)->GetPeerInDataAnchors().size(), 1); - EXPECT_EQ(node_o3->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetIdx(), 0); - EXPECT_EQ(node_o3->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetOwnerNode()->GetName(), "o2_1"); - - EXPECT_EQ(node_o2_1->GetOutDataAnchor(0)->GetPeerInDataAnchors().size(), 1); - EXPECT_EQ(node_o2_1->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetIdx(), 0); - EXPECT_EQ(node_o2_1->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetOwnerNode()->GetName(), "o2_2"); - - EXPECT_EQ(node_o2_2->GetOutControlNodes().size(), 1); - EXPECT_EQ(node_o2_2->GetOutControlNodes().at(0)->GetName(), "o1_3"); -} - -TEST_F(OperatorConstructGraphUt, ConstructGraph1) { - auto g = BuildGraph1(); - CheckGraph1(g); -} - -TEST_F(OperatorConstructGraphUt, ConstructWithIndex) { - auto g = BuildGraph1ByIndex(); - CheckGraph1(g); -} -TEST_F(OperatorConstructGraphUt, ConstructGraph1ByGnode) { - auto g = BuildGraph1ByGnode(); - CheckGraph1(g); -} - -TEST_F(OperatorConstructGraphUt, GetInputConstData1) { - auto o1_1 = op::Const("o1_1"); - auto o1_2 = op::Const("o1_2"); - auto o3 = op::OCG3("o3"); - auto o2_1 = op::OCG2("o2_1"); - auto o2_2 = op::OCG2("o2_2"); - auto o1_3 = op::Const("o1_3"); - - TensorDesc td{Shape(std::vector({8, 3, 224, 224})), FORMAT_NCHW, DT_UINT8}; - Tensor tensor(td); - tensor.SetData(std::vector(8 * 3 * 224 * 224)); - - o1_1.set_attr_value(tensor); - o1_2.set_attr_value(tensor); - o1_3.set_attr_value(tensor); - o3.set_input_x(o1_1).set_input_shape_by_name(o1_2, "y"); - o2_1.create_dynamic_input_x(1, true).set_dynamic_input_x(0, o3); - o2_2.create_dynamic_input_x(1, true).set_dynamic_input_x(0, o2_1, "y"); - o1_3.AddControlInput(o2_2); - - Graph g{"name"}; - g.SetInputs(std::vector({o1_1, o1_2})).SetOutputs(std::vector({o2_2, o1_3})); - - Tensor t1; - EXPECT_NE(o2_1.GetInputConstData("x1", t1), GRAPH_SUCCESS); - EXPECT_EQ(o3.GetInputConstData("x", t1), GRAPH_SUCCESS); - EXPECT_EQ(t1.GetTensorDesc().GetFormat(), FORMAT_NCHW); -} - -TEST_F(OperatorConstructGraphUt, SetGetAttrOk) { - auto op = OperatorFactory::CreateOperator("op", "OCG3"); - int64_t value = 0; - EXPECT_NE(op.GetAttr("Hello", value), GRAPH_SUCCESS); - op.SetAttr("Hello", 10); - EXPECT_EQ(op.GetAttr("Hello", value), GRAPH_SUCCESS); - EXPECT_EQ(value, 10); - op = OperatorFactory::CreateOperator(nullptr, "OCG3"); - AscendString invalid_type(""); - EXPECT_EQ(op.GetOpType(invalid_type), GRAPH_FAILED); // op无效 - EXPECT_EQ(invalid_type, ""); -} - -TEST_F(OperatorConstructGraphUt, SetGetAttrByAnyValueOk) { - auto op = OperatorFactory::CreateOperator("op", "OCG3"); - int64_t value = 0; - op.SetAttr("Foo", AttrValue::CreateFrom(10)); - EXPECT_EQ(op.GetAttr("Foo", value), GRAPH_SUCCESS); - EXPECT_EQ(value, 10); - - AttrValue attr_value; - EXPECT_NE(op.GetAttr("Bar", attr_value), GRAPH_SUCCESS); - EXPECT_EQ(op.GetAttr("Foo", attr_value), GRAPH_SUCCESS); - value = 0; - attr_value.GetValue(value); - EXPECT_EQ(value, 10); -} - -TEST_F(OperatorConstructGraphUt, GetIntputOutputSizeOk) { - auto op = OperatorFactory::CreateOperator("op", "OCG3"); - EXPECT_EQ(op.GetInputsSize(), 2); - EXPECT_EQ(op.GetOutputsSize(), 1); -} - -TEST_F(OperatorConstructGraphUt, UpdateInputOutputOk) { - TensorDesc td; - td.SetFormat(FORMAT_NC1HWC0); - td.SetOriginFormat(FORMAT_NHWC); - td.SetShape(Shape(std::vector({8, 1, 224, 224, 16}))); - td.SetOriginShape(Shape(std::vector({8, 224, 224, 3}))); - - auto op = OperatorFactory::CreateOperator("op", "OCG3"); - EXPECT_EQ(op.UpdateInputDesc("x", td), GRAPH_SUCCESS); - EXPECT_EQ(op.UpdateInputDesc("shape", td), GRAPH_SUCCESS); - EXPECT_EQ(op.UpdateOutputDesc("y", td), GRAPH_SUCCESS); - EXPECT_NE(op.UpdateInputDesc("xx", td), GRAPH_SUCCESS); - EXPECT_NE(op.UpdateOutputDesc("yy", td), GRAPH_SUCCESS); - - EXPECT_EQ(op.GetInputDesc(0).GetFormat(), FORMAT_NC1HWC0); - EXPECT_EQ(op.GetInputDesc(0).GetOriginFormat(), FORMAT_NHWC); - EXPECT_EQ(op.GetInputDescByName("x").GetFormat(), FORMAT_NC1HWC0); - EXPECT_EQ(op.GetInputDescByName("x").GetOriginFormat(), FORMAT_NHWC); - EXPECT_EQ(op.GetOutputDesc(0).GetFormat(), FORMAT_NC1HWC0); - EXPECT_EQ(op.GetOutputDesc(0).GetOriginFormat(), FORMAT_NHWC); - EXPECT_EQ(op.GetOutputDescByName("y").GetFormat(), FORMAT_NC1HWC0); - EXPECT_EQ(op.GetOutputDescByName("y").GetOriginFormat(), FORMAT_NHWC); -} - -TEST_F(OperatorConstructGraphUt, UpdateOutputOk_AutoUpdatedToPeer) { - auto o1_1 = op::Const("o1_1"); - auto o1_2 = op::Const("o1_2"); - auto o3 = op::OCG3("o3"); - - TensorDesc td{Shape(std::vector({8, 3, 224, 224})), FORMAT_NCHW, DT_UINT8}; - Tensor tensor(td); - tensor.SetData(std::vector(8 * 3 * 224 * 224)); - - o1_1.set_attr_value(tensor); - o1_2.set_attr_value(tensor); - o3.set_input_x(o1_1).set_input_shape_by_name(o1_2, "y"); - - Graph g{"name"}; - g.SetInputs(std::vector({o1_1, o1_2})).SetOutputs(std::vector({o3})); - - TensorDesc td1; - td1.SetFormat(FORMAT_NC1HWC0); - td1.SetOriginFormat(FORMAT_NHWC); - td1.SetShape(Shape(std::vector({8, 1, 224, 224, 16}))); - td1.SetOriginShape(Shape(std::vector({8, 224, 224, 3}))); - - o1_1.UpdateOutputDesc("y", td1); - o1_2.UpdateOutputDesc("y", td1); - - EXPECT_EQ(o1_1.GetOutputDesc(0).GetFormat(), FORMAT_NC1HWC0); - EXPECT_EQ(o1_1.GetOutputDesc(0).GetOriginFormat(), FORMAT_NHWC); - EXPECT_EQ(o1_2.GetOutputDesc(0).GetFormat(), FORMAT_NC1HWC0); - EXPECT_EQ(o1_2.GetOutputDesc(0).GetOriginFormat(), FORMAT_NHWC); - EXPECT_EQ(o3.GetInputDesc(0).GetFormat(), FORMAT_NHWC); - EXPECT_EQ(o3.GetInputDesc(0).GetOriginFormat(), FORMAT_NHWC); -} - -TEST_F(OperatorConstructGraphUt, SubgraphIrDefOk) { - auto if_op = op::OCG4("if"); - std::vector names; - EXPECT_EQ(if_op.GetSubgraphNamesCount(), 2); - // 此接口获取的names是无序的 - EXPECT_EQ(if_op.GetSubgraphNames(names), GRAPH_SUCCESS); - std::set names_set; - for (const auto &name : names) { - names_set.insert(name.GetString()); - } - EXPECT_EQ(names_set, std::set({"then_branch", "else_branch"})); - - auto case_op = op::OCG5("case"); - EXPECT_EQ(case_op.GetSubgraphNamesCount(), 1); - names.clear(); - EXPECT_EQ(case_op.GetSubgraphNames(names), GRAPH_SUCCESS); - EXPECT_EQ(names.size(), 1); - EXPECT_EQ(strcmp(names[0].GetString(), "branches"), 0); -} - -TEST_F(OperatorConstructGraphUt, SetGetSubgraphBuilderOk1) { - auto if_op = op::OCG4("if"); - if_op.set_subgraph_builder_else_branch([]() { return Graph("FromSubgraphBuilder1"); }); - if_op.set_subgraph_builder_then_branch([]() { return Graph("FromSubgraphBuilder2"); }); - - auto else_graph = if_op.get_subgraph_builder_else_branch()(); - AscendString as; - EXPECT_EQ(else_graph.GetName(as), GRAPH_SUCCESS); - EXPECT_EQ(strcmp(as.GetString(), "FromSubgraphBuilder1"), 0); - - auto then_graph = if_op.get_subgraph_builder_then_branch()(); - EXPECT_EQ(then_graph.GetName(as), GRAPH_SUCCESS); - EXPECT_EQ(strcmp(as.GetString(), "FromSubgraphBuilder2"), 0); - EXPECT_EQ(if_op.GetSubgraphBuilder("Hello"), nullptr); -} - -TEST_F(OperatorConstructGraphUt, SetGetSubgraphBuilderOk2) { - auto case_op = op::OCG5("case"); - case_op.create_dynamic_subgraph_branches(3); - case_op.set_dynamic_subgraph_builder_branches(0, []() { return Graph("case1"); }); - case_op.set_dynamic_subgraph_builder_branches(1, []() { return Graph("case2"); }); - case_op.set_dynamic_subgraph_builder_branches(2, []() { return Graph("case3"); }); - - auto case1 = case_op.get_dynamic_subgraph_builder_branches(0)(); - AscendString as; - EXPECT_EQ(case1.GetName(as), GRAPH_SUCCESS); - EXPECT_EQ(strcmp(as.GetString(), "case1"), 0); - - auto case2 = case_op.get_dynamic_subgraph_builder_branches(1)(); - EXPECT_EQ(case2.GetName(as), GRAPH_SUCCESS); - EXPECT_EQ(strcmp(as.GetString(), "case2"), 0); - - auto case3 = case_op.get_dynamic_subgraph_builder_branches(2)(); - EXPECT_EQ(case3.GetName(as), GRAPH_SUCCESS); - EXPECT_EQ(strcmp(as.GetString(), "case3"), 0); -} - -TEST_F(OperatorConstructGraphUt, GetOpsTypeList) { - std::vector all_ops; - auto ret = OperatorFactory::GetOpsTypeList(all_ops); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_TRUE(all_ops.size() > 0); - - std::vector all_ops2; - ret = OperatorFactory::GetOpsTypeList(all_ops2); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(all_ops2.size(), all_ops.size()); -} - -static graphStatus stub_func(Operator &op) -{ - return GRAPH_SUCCESS; -} - -TEST_F(OperatorConstructGraphUt, InferFuncRegister) { - InferShapeFunc infer_shape_func = stub_func; - InferShapeFuncRegister(nullptr, infer_shape_func); - InferShapeFuncRegister("OCG3", infer_shape_func); - InferShapeFuncRegister(std::string("OCG3"), infer_shape_func); - EXPECT_NE(OperatorFactoryImpl::GetInferShapeFunc("OCG3"), nullptr); - - InferFormatFunc infer_format_func = stub_func; - InferFormatFuncRegister(nullptr, infer_format_func); - InferFormatFuncRegister("OCG3", infer_format_func); - InferFormatFuncRegister(std::string("OCG3"), infer_format_func); - EXPECT_NE(OperatorFactoryImpl::GetInferFormatFunc("OCG3"), nullptr); - - VerifyFunc verify_func = stub_func; - VerifyFuncRegister(nullptr, verify_func); - VerifyFuncRegister("OCG3", verify_func); - VerifyFuncRegister(std::string("OCG3"), verify_func); - EXPECT_NE(OperatorFactoryImpl::GetVerifyFunc("OCG3"), nullptr); -} - -TEST_F(OperatorConstructGraphUt, IsExistOp) { - Operator op; - op = OperatorFactory::CreateOperator("op", "OCG3"); - - OpCreator op_creator; - OperatorCreatorRegister( std::string("OCG3"), op_creator); - - bool ret = OperatorFactory::IsExistOp(nullptr); - EXPECT_FALSE(ret); - ret = OperatorFactory::IsExistOp("add"); - EXPECT_FALSE(ret); - - std::string op_type = "add"; - ret = OperatorFactory::IsExistOp(op_type); - EXPECT_FALSE(ret); - - ret = OperatorFactory::IsExistOp("OCG3"); - EXPECT_TRUE(ret); - - ret = OperatorFactory::IsExistOp(std::string("OCG3")); - EXPECT_TRUE(ret); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/operator_unittest.cc b/tests/ut/graph/testcase/operator_unittest.cc deleted file mode 100644 index ddc35aa5554b8aa44211bbd287cb6815ab142390..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/operator_unittest.cc +++ /dev/null @@ -1,2830 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include "external/graph/operator.h" -#include "external/graph/operator_factory.h" -#include "external/graph/attr_value.h" -#include "graph/ge_attr_value.h" -#include "graph/normal_graph/operator_impl.h" -#include "external/graph/tensor.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/normal_graph/op_desc_impl.h" -#include "graph/type/tensor_type_impl.h" -#include "graph_builder_utils.h" -#include -#include "graph/utils/tensor_utils.h" -#include "graph/normal_graph/compute_graph_impl.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/file_utils.h" -#include "graph/utils/graph_utils_ex.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/node_utils.h" -#include "inc/external/graph/graph.h" -#include "external/graph/operator_reg.h" -#include "checker/summary_checker.h" -#include "checker/topo_checker.h" -#include "graph/any_value.h" - -namespace ge { -namespace { -REG_OP(Foo01).OUTPUT(y, TensorType::NumberType()).OP_END_FACTORY_REG(Foo01); -REG_OP(Foo11).INPUT(x, TensorType::NumberType()).OUTPUT(y, TensorType::NumberType()).OP_END_FACTORY_REG(Foo11); -REG_OP(Foo02).OUTPUT(x, TensorType::NumberType()).OUTPUT(y, TensorType::NumberType()).OP_END_FACTORY_REG(Foo02); -REG_OP(Foo22) - .INPUT(m, TensorType::NumberType()) - .INPUT(n, TensorType::NumberType()) - .OUTPUT(x, TensorType::NumberType()) - .OUTPUT(y, TensorType::NumberType()) - .OP_END_FACTORY_REG(Foo22); -REG_OP(DFoo22) - .INPUT(m, TensorType::NumberType()) - .DYNAMIC_INPUT(n, TensorType::NumberType()) - .OUTPUT(x, TensorType::NumberType()) - .OUTPUT(y, TensorType::NumberType()) - .OP_END_FACTORY_REG(DFoo22); - -/* - * ┌──────────────────────┐ - * │ cond │ - * data const │ const │ - * │ │ │ | │ - * └──┬──┘ ┌───┤ data──add──netoutput │ - * │ │ │ │ - * while ─┤ ├──────────────────────┤ - * │ │ ├──────────────────────┴─┐ - * │ └───┤ body │ - * netoutput │ data──reshape──netoutpu│ - * │ | │ - * │ const │ - * └────────────────────────┘ - */ -ComputeGraphPtr BuildWhileGraphWithConstInput() { - ut::GraphBuilder builder("main_graph"); - auto data_1 = builder.AddNode("data_1", "Data", 1, 1); - auto const_1 = builder.AddNode("const_1", "Const", 1, 1); - auto while_1 = builder.AddNode("while_1", "While", 2, 2); - auto netoutput_1 = builder.AddNode("netoutput_1", "NetOutput", 1, 1); - builder.AddDataEdge(data_1, 0, while_1, 0); - builder.AddDataEdge(const_1, 0, while_1, 1); - builder.AddDataEdge(const_1, 0, netoutput_1, 0); - auto main_graph = builder.GetGraph(); - - ut::GraphBuilder cond_builder("cond_graph"); - auto cond_data_1 = cond_builder.AddNode("cond_data_1", "Data", 1, 1); - auto cond_const_1 = cond_builder.AddNode("cond_const_1", "Const", 1, 1); - auto cond_add_1 = cond_builder.AddNode("cond_add_1", "Add", 2, 1); - auto cond_netoutput_1 = cond_builder.AddNode("cond_netoutput_1", "NetOutput", 1, 1); - cond_builder.AddDataEdge(cond_data_1, 0, cond_add_1, 0); - cond_builder.AddDataEdge(cond_const_1, 0, cond_add_1, 1); - cond_builder.AddDataEdge(cond_add_1, 0, cond_netoutput_1, 0); - auto cond_graph = cond_builder.GetGraph(); - AttrUtils::SetInt(cond_data_1->GetOpDesc(), "_parent_node_index", static_cast(0)); - cond_graph->SetParentGraph(main_graph); - cond_graph->SetParentNode(main_graph->FindNode("while_1")); - main_graph->FindNode("while_1")->GetOpDesc()->AddSubgraphName("cond_graph"); - main_graph->FindNode("while_1")->GetOpDesc()->SetSubgraphInstanceName(0, "cond_graph"); - main_graph->AddSubgraph("cond_graph", cond_graph); - - ut::GraphBuilder body_builder("body_graph"); - auto body_data_1 = body_builder.AddNode("body_data_1", "Data", 1, 1); - auto body_const_1 = body_builder.AddNode("body_const_1", "Const", 1, 1); - auto body_reshape_1 = body_builder.AddNode("body_reshape_1", "Reshape", 2, 1); - auto body_netoutput_1 = body_builder.AddNode("body_netoutput_1", "NetOutput", 1, 1); - body_builder.AddDataEdge(body_data_1, 0, body_reshape_1, 0); - body_builder.AddDataEdge(body_const_1, 0, body_reshape_1, 1); - body_builder.AddDataEdge(body_reshape_1, 0, body_netoutput_1, 0); - auto body_graph = body_builder.GetGraph(); - AttrUtils::SetInt(cond_data_1->GetOpDesc(), "_parent_node_index", static_cast(1)); - body_graph->SetParentGraph(main_graph); - body_graph->SetParentNode(main_graph->FindNode("while_1")); - main_graph->FindNode("while_1")->GetOpDesc()->AddSubgraphName("body_graph"); - main_graph->FindNode("while_1")->GetOpDesc()->SetSubgraphInstanceName(1, "body_graph"); - main_graph->AddSubgraph("body_graph", body_graph); - ge::GeTensorPtr tensor = std::make_shared(); - std::vector value{1, 2, 3}; - std::vector shape{3}; - tensor->MutableTensorDesc().SetShape(GeShape(shape)); - tensor->SetData(value); - tensor->MutableTensorDesc().SetDataType(DT_UINT8); - AttrUtils::SetTensor(body_const_1->GetOpDesc(), "value", tensor); - auto op_desc = body_reshape_1->GetOpDesc(); - op_desc->impl_->input_name_idx_["x"] = 0; - op_desc->impl_->input_name_idx_["shape"] = 1; - - return main_graph; -} -} // namespace -class UtestOperater : public testing::Test { - public: - /* - * Foo11 - * | - * Foo01 - */ - void CheckTopoGraph1(const Graph &graph) { - auto compute_graph = GraphUtilsEx::GetComputeGraph(graph); - ASSERT_NE(compute_graph, nullptr); - ASSERT_EQ(gert::SummaryChecker(compute_graph).StrictAllNodeTypes({{"Foo01", 1}, {"Foo11", 1}}), "success"); - auto foo11_node = compute_graph->FindNode("foo11"); - ASSERT_NE(foo11_node, nullptr); - ASSERT_EQ(gert::NodeTopoChecker(foo11_node).StrictConnectFrom({{"Foo01"}}), "success"); - auto foo01_node = compute_graph->FindNode("foo01"); - ASSERT_NE(foo01_node, nullptr); - ASSERT_EQ(gert::NodeTopoChecker(foo01_node).StrictConnectTo(0, {{"Foo11"}}), "success"); - } - - /* - * Foo22 - * / | - * Foo11 | - * \ | - * Foo02 - */ - void CheckTopoGraph2(const Graph &graph) { - auto compute_graph = GraphUtilsEx::GetComputeGraph(graph); - ASSERT_NE(compute_graph, nullptr); - ASSERT_EQ(gert::SummaryChecker(compute_graph).StrictAllNodeTypes({{"Foo02", 1}, {"Foo11", 1}, {"Foo22", 1}}), - "success"); - auto foo01_node = compute_graph->FindNode("foo02"); - ASSERT_NE(foo01_node, nullptr); - ASSERT_EQ(gert::NodeTopoChecker(foo01_node).StrictConnectTo(0, {{"Foo11"}}), "success"); - ASSERT_EQ(gert::NodeTopoChecker(foo01_node).StrictConnectTo(1, {{"Foo22"}}), "success"); - auto foo11_node = compute_graph->FindNode("foo11"); - ASSERT_NE(foo11_node, nullptr); - ASSERT_EQ(gert::NodeTopoChecker(foo11_node).StrictConnectFrom({{"Foo02"}}), "success"); - ASSERT_EQ(gert::NodeTopoChecker(foo11_node).StrictConnectTo(0, {{"Foo22"}}), "success"); - auto foo22_node = compute_graph->FindNode("foo22"); - ASSERT_NE(foo22_node, nullptr); - ASSERT_EQ(gert::NodeTopoChecker(foo22_node).StrictConnectFrom({{"Foo11"}, {"Foo02"}}), "success"); - } - /* - * Foo22 - * / |d0 \d1 - * Foo11 | | - * \0 |0 /1 - * Foo02 - */ - void CheckTopoGraph3(const Graph &graph) { - auto compute_graph = GraphUtilsEx::GetComputeGraph(graph); - ASSERT_NE(compute_graph, nullptr); - ASSERT_EQ(gert::SummaryChecker(compute_graph).StrictAllNodeTypes({{"Foo02", 1}, {"Foo11", 1}, {"DFoo22", 1}}), - "success"); - auto foo02_node = compute_graph->FindNode("foo02"); - ASSERT_NE(foo02_node, nullptr); - ASSERT_EQ(gert::NodeTopoChecker(foo02_node).StrictConnectTo(0, {{"Foo11"}, {"DFoo22"}}), "success"); - ASSERT_EQ(gert::NodeTopoChecker(foo02_node).StrictConnectTo(1, {{"DFoo22"}}), "success"); - auto foo11_node = compute_graph->FindNode("foo11"); - ASSERT_NE(foo11_node, nullptr); - ASSERT_EQ(gert::NodeTopoChecker(foo11_node).StrictConnectFrom({{"Foo02"}}), "success"); - ASSERT_EQ(gert::NodeTopoChecker(foo11_node).StrictConnectTo(0, {{"DFoo22"}}), "success"); - auto foo22_node = compute_graph->FindNode("foo22"); - ASSERT_NE(foo22_node, nullptr); - ASSERT_EQ(gert::NodeTopoChecker(foo22_node).StrictConnectFrom({{"Foo11"}, {"Foo02"}, {"Foo02"}}), "success"); - } -}; - -TEST_F(UtestOperater, GetInputConstData) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto data = builder.AddNode("Data", "Data", 0, 1); - auto data2 = builder.AddNode("Data2", "Data", 0, 1); - auto enter = builder.AddNode("Enter", "Enter", 1, 1); - auto transdata = builder.AddNode("Transdata", "Transdata", 2, 1); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - builder.AddDataEdge(data2, 0, enter, 0); - builder.AddDataEdge(data, 0, transdata, 0); - builder.AddDataEdge(enter, 0, transdata, 1); - builder.AddDataEdge(transdata, 0, netoutput, 0); - auto graph = builder.GetGraph(); - - auto ge_tensor = std::make_shared(); - auto op_desc = transdata->GetOpDesc(); - op_desc->impl_->input_name_idx_["Data"] = 0; - op_desc->impl_->input_name_idx_["Enter"] = 1; - auto tensor_desc = op_desc->MutableInputDesc(0); - AttrUtils::SetTensor(tensor_desc, "_value", ge_tensor); - - Tensor tensor; - auto op = OpDescUtils::CreateOperatorFromNode(transdata); - ASSERT_EQ(op.GetInputConstData("Data", tensor), GRAPH_SUCCESS); - ASSERT_EQ(op.GetInputConstData("Enter", tensor), GRAPH_FAILED); -} -/** -------------------------- - * const | sub_data sub_const | - * | | \ / | - * case-----------------------| Add | - * | | | | - * netoutput | sub_netoutput | - * --------------------------- - */ -TEST_F(UtestOperater, GetInputConstData_subgraph) { - auto ge_tensor = std::make_shared(); - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto const_node = builder.AddNode("Const", "Const", 0, 1); - AttrUtils::SetTensor(const_node->GetOpDesc(), "value", ge_tensor); - auto case_node = builder.AddNode("Case", "Case", 1, 1); - auto netoutput = builder.AddNode("Netoutput", "NetOutput", 1, 0); - builder.AddDataEdge(const_node, 0, case_node, 0); - builder.AddDataEdge(case_node, 0, netoutput, 0); - auto parent_graph = builder.GetGraph(); - - ut::GraphBuilder sub_builder = ut::GraphBuilder("subgraph_graph"); - auto sub_data = sub_builder.AddNode("sub_data", "Data", 0, 1); - auto sub_const = sub_builder.AddNode("sub_const", "Const", 0, 1); - AttrUtils::SetTensor(sub_const->GetOpDesc(), "value", ge_tensor); - auto add = sub_builder.AddNode("Add", "Add", 2, 1); - auto sub_netoutput = sub_builder.AddNode("sub_netoutput", "NetOutput", 1, 0); - sub_builder.AddDataEdge(sub_data, 0, add, 0); - sub_builder.AddDataEdge(sub_const, 0, add, 1); - sub_builder.AddDataEdge(add, 0, sub_netoutput, 0); - - auto subgraph = sub_builder.GetGraph(); - subgraph->SetParentNode(case_node); - subgraph->SetParentGraph(parent_graph); - parent_graph->AddSubgraph(subgraph->GetName(), subgraph); - AttrUtils::SetInt(sub_data->GetOpDesc(), "_parent_node_index", 0); - - auto op_desc = add->GetOpDesc(); - op_desc->impl_->input_name_idx_["sub_data"] = 0; - op_desc->impl_->input_name_idx_["sub_const"] = 1; - - Tensor tensor; - auto op = OpDescUtils::CreateOperatorFromNode(add); - ASSERT_EQ(op.GetInputConstData("sub_const", tensor), GRAPH_SUCCESS); - ASSERT_EQ(op.GetInputConstData("sub_data", tensor), GRAPH_SUCCESS); -} - - -/* ------------------------- -* | partitioncall_0_const1* | -* partitioncall_0--------------| | | -* | | netoutput | -* | -------------------------- -* | ------------------ ------------- -* | | data | | Pld | -* | | | | | | | -* partitioncall_1--------------| FftsSub |------->| squeeze* | -* | | | | | | -* | netoutput | | netoutput | -* ------------------ ------------- -*/ -TEST_F(UtestOperater, GetInputConstData_cross_subgraph) { - auto root_builder = ut::GraphBuilder("root"); - const auto &partitioncall_0 = root_builder.AddNode("partitioncall_0", "PartitionedCall", 0, 1); - const auto &partitioncall_1 = root_builder.AddNode("partitioncall_1", "PartitionedCall", 1, 1); - root_builder.AddDataEdge(partitioncall_0, 0, partitioncall_1, 0); - const auto &root_graph = root_builder.GetGraph(); - - // 1.build partitioncall_0 sub graph - auto p1_sub_builder = ut::GraphBuilder("partitioncall_0_sub"); - const auto &partitioncall_0_const1 = p1_sub_builder.AddNode("partitioncall_0_const1", "Const", 0, 1); - auto ge_tensor = std::make_shared(); - ASSERT_TRUE(AttrUtils::SetTensor(partitioncall_0_const1->GetOpDesc(), "value", ge_tensor)); - - const auto &partitioncall_0_netoutput = p1_sub_builder.AddNode("partitioncall_0_netoutput", "NetOutput", 1, 1); - AttrUtils::SetInt(partitioncall_0_netoutput->GetOpDesc()->MutableInputDesc(0), "_parent_node_index", 0); - p1_sub_builder.AddDataEdge(partitioncall_0_const1, 0, partitioncall_0_netoutput, 0); - const auto &sub_graph = p1_sub_builder.GetGraph(); - sub_graph->SetParentNode(partitioncall_0); - sub_graph->SetParentGraph(root_graph); - partitioncall_0->GetOpDesc()->AddSubgraphName("f"); - partitioncall_0->GetOpDesc()->SetSubgraphInstanceName(0, "partitioncall_0_sub"); - - // 2.build partitioncall_1 sub graph - auto p2_sub_builder = ut::GraphBuilder("partitioncall_1_sub"); - const auto &partitioncall_1_data = p2_sub_builder.AddNode("partitioncall_1_data", "Data", 0, 1); - AttrUtils::SetInt(partitioncall_1_data->GetOpDesc(), "_parent_node_index", 0); - const auto &partitioncall_1_ffts_sub = p2_sub_builder.AddNode("FftsSub", "PartitionedCall", 1, 1); - const auto &partitioncall_1_netoutput = p2_sub_builder.AddNode("partitioncall_1_netoutput", "NetOutput", 1, 1); - p2_sub_builder.AddDataEdge(partitioncall_1_data, 0, partitioncall_1_ffts_sub, 0); - p2_sub_builder.AddDataEdge(partitioncall_1_ffts_sub, 0, partitioncall_1_netoutput, 0); - const auto &sub_graph2 = p2_sub_builder.GetGraph(); - sub_graph2->SetParentNode(partitioncall_1); - sub_graph2->SetParentGraph(root_graph); - partitioncall_1->GetOpDesc()->AddSubgraphName("f"); - partitioncall_1->GetOpDesc()->SetSubgraphInstanceName(0, "partitioncall_1_sub"); - - - // 2.1 build sgt sub graph - auto sgt_sub_builder = ut::GraphBuilder("sgt_sub"); - const auto &sgt_pld = sgt_sub_builder.AddNode("sgt_plt", "PlaceHolder", 0, 1); - const auto &sgt_squeeze = sgt_sub_builder.AddNode("sgt_squeeze", "Squeeze", 1, 1); - sgt_squeeze->GetOpDesc()->impl_->input_name_idx_["sub_data"] = 0; - const auto &sgt_netoutput = sgt_sub_builder.AddNode("sgt_netoutput", "NetOutput", 1, 1); - sgt_sub_builder.AddDataEdge(sgt_pld, 0, sgt_squeeze, 0); - sgt_sub_builder.AddDataEdge(sgt_squeeze, 0, sgt_netoutput, 0); - const auto &sgt_sub_graph = sgt_sub_builder.GetGraph(); - sgt_sub_graph->SetParentNode(partitioncall_1_ffts_sub); - sgt_sub_graph->SetParentGraph(sub_graph2); - partitioncall_1_ffts_sub->GetOpDesc()->AddSubgraphName("sgt_sub"); - partitioncall_1_ffts_sub->GetOpDesc()->SetSubgraphInstanceName(0, "sgt_sub"); - - - sgt_pld->GetOpDesc()->SetExtAttr("parentNode", partitioncall_1_data); - - - root_graph->AddSubgraph(sgt_sub_graph->GetName(), sgt_sub_graph); - root_graph->AddSubgraph(sub_graph->GetName(), sub_graph); - root_graph->AddSubgraph(sub_graph2->GetName(), sub_graph2); - - auto op = OpDescUtils::CreateOperatorFromNode(sgt_squeeze); - Tensor res; - ASSERT_EQ(op.GetInputConstData("sub_data", res), GRAPH_SUCCESS); -} - -TEST_F(UtestOperater, TestOperatorSetInputs) { - ge::Operator dst_op = ge::Operator("Mul"); - ge::Operator src_op = ge::Operator("Add"); - dst_op.InputRegister("x1"); - dst_op.InputRegister("x2"); - dst_op.OutputRegister("y"); - - src_op.InputRegister("x1"); - src_op.InputRegister("x2"); - src_op.OutputRegister("y"); - - ASSERT_EQ(src_op.GetInputsSize(), 2U); - ASSERT_EQ(dst_op.GetInputsSize(), 2U); - // src_index is illegal - (void) dst_op.SetInput(0U, src_op, 3U); - ASSERT_EQ(src_op.GetInputsSize(), 2U); - // dst_index is illegal - (void) dst_op.SetInput(3U, src_op, 0U); - ASSERT_EQ(src_op.GetInputsSize(), 2U); - - (void) dst_op.SetInput(1U, src_op, 0U); - ASSERT_EQ(src_op.GetInputsSize(), 2U); - - ge::Operator null_op; - (void) null_op.SetInput(1U, src_op, 0U); - ASSERT_EQ(null_op.GetInputsSize(), 0U); - - std::string dst_name = "x1"; - (void) dst_op.SetInput(dst_name, src_op, 0U); - ASSERT_EQ(dst_op.GetInputsSize(), 2U); -} - -TEST_F(UtestOperater, AttrRegister_Float) { - auto op = Operator("Data"); - std::string attr = "attr"; - float value = 1.0; - op.AttrRegister(attr, value); - float ret = 0; - op.GetAttr(attr.c_str(), ret); - ASSERT_FLOAT_EQ(value, ret); - op.AttrRegister(nullptr, value); -} - -TEST_F(UtestOperater, AttrRegister_ListFloat) { - auto op = Operator("Data"); - std::string attr = "attr"; - std::vector value = {1.0, 2.0}; - op.AttrRegister(attr, value); - std::vector ret; - op.GetAttr(attr.c_str(), ret); - ASSERT_FLOAT_EQ(value[0], ret[0]); - ASSERT_FLOAT_EQ(value[1], ret[1]); - op.AttrRegister(nullptr, value); -} - -TEST_F(UtestOperater, AttrRegister_Int) { - auto op = Operator("Data"); - std::string attr = "attr"; - int64_t value = 1; - op.AttrRegister(attr, value); - int64_t ret = 0; - op.GetAttr(attr.c_str(), ret); - ASSERT_EQ(value, ret); - op.AttrRegister(nullptr, value); -} - -TEST_F(UtestOperater, AttrRegister_ListInt) { - auto op = Operator("Data"); - std::string attr = "attr"; - std::vector value = {1, 2}; - op.AttrRegister(attr, value); - std::vector ret; - op.GetAttr(attr.c_str(), ret); - ASSERT_EQ(value[0], ret[0]); - ASSERT_EQ(value[1], ret[1]); - op.AttrRegister(nullptr, value); -} - -TEST_F(UtestOperater, AttrRegister_String) { - auto op = Operator("Data"); - std::string attr = "attr"; - std::string value = "on"; - op.AttrRegister(attr.c_str(), value.c_str()); - std::string ret; - op.GetAttr(attr, ret); - ASSERT_EQ(value, ret); - op.AttrRegister(nullptr, value.c_str()); -} - -TEST_F(UtestOperater, AttrRegister_Bool) { - auto op = Operator("Data"); - std::string attr = "attr"; - bool value = true; - op.AttrRegister(attr, value); - bool ret = false; - op.GetAttr(attr.c_str(), ret); - ASSERT_EQ(value, ret); - op.AttrRegister(nullptr, value); -} - -TEST_F(UtestOperater, AttrRegister_ListBool) { - auto op = Operator("Data"); - std::string attr = "attr"; - std::vector value = {false, true}; - op.AttrRegister(attr, value); - std::vector ret; - op.GetAttr(attr.c_str(), ret); - ASSERT_EQ(value[0], ret[0]); - ASSERT_EQ(value[1], ret[1]); - op.AttrRegister(nullptr, value); -} - -TEST_F(UtestOperater, AttrRegister_Tensor) { - EXPECT_NO_THROW( - auto op = Operator("Data"); - auto value = Tensor(); - std::string attr = "attr"; - op.AttrRegister(attr, value); - op.AttrRegister(nullptr, value); - ); -} - -TEST_F(UtestOperater, AttrRegister_ListTensor) { - EXPECT_NO_THROW( - auto op = Operator("Data"); - std::vector value = {Tensor()}; - op.AttrRegister("attr", value); - op.AttrRegister(nullptr, value); - ); -} - -TEST_F(UtestOperater, AttrRegister_OpBytes) { - auto op = Operator("Data"); - std::string attr = "attr"; - auto value = OpBytes{1, 2, 3}; - op.AttrRegister(attr, value); - OpBytes ret; - op.GetAttr(attr.c_str(), ret); - ASSERT_EQ(value[0], ret[0]); - ASSERT_EQ(value[1], ret[1]); - ASSERT_EQ(value[2], ret[2]); - op.AttrRegister(nullptr, value); -} - -TEST_F(UtestOperater, AttrRegister_ListListInt) { - auto op = Operator("Data"); - std::string attr = "attr"; - std::vector> value = {{1, 2}, {3}}; - op.AttrRegister(attr, value); - std::vector> ret; - op.GetAttr(attr.c_str(), ret); - ASSERT_EQ(value[0][0], ret[0][0]); - ASSERT_EQ(value[0][1], ret[0][1]); - ASSERT_EQ(value[1][0], ret[1][0]); - op.AttrRegister(nullptr, value); -} - -TEST_F(UtestOperater, AttrRegister_ListDataType) { - auto op = Operator("Data"); - std::string attr = "attr"; - std::vector value = {DataType::DT_FLOAT, DataType::DT_INT64}; - op.AttrRegister(attr, value); - std::vector ret; - op.GetAttr(attr.c_str(), ret); - ASSERT_EQ(value[0], ret[0]); - ASSERT_EQ(value[1], ret[1]); - op.AttrRegister(nullptr, value); -} - -TEST_F(UtestOperater, AttrRegister_DataType) { - auto op = Operator("Data"); - std::string attr = "attr"; - auto value = DataType::DT_FLOAT; - op.AttrRegister(attr, value); - DataType ret; - op.GetAttr(attr.c_str(), ret); - ASSERT_EQ(value, ret); - op.AttrRegister(nullptr, value); -} - -TEST_F(UtestOperater, AttrRegister_NamedAttrs) { - auto op = Operator("Data"); - std::string attr = "attr"; - auto value = NamedAttrs(); - value.SetName("name"); - op.AttrRegister(attr, value); - NamedAttrs ret; - op.GetAttr(attr.c_str(), ret); - ASSERT_EQ(value.GetName(), ret.GetName()); - op.AttrRegister(nullptr, value); -} - -TEST_F(UtestOperater, AttrRegister_ListNamedAttrs) { - auto op = Operator("Data"); - std::string attr = "attr"; - std::vector value = {NamedAttrs()}; - value[0].SetName("name"); - op.AttrRegister(attr, value); - std::vector ret; - op.GetAttr(attr.c_str(), ret); - ASSERT_EQ(value[0].GetName(), ret[0].GetName()); - op.AttrRegister(nullptr, value); -} - -TEST_F(UtestOperater, AttrRegister_AscendString) { - auto op = Operator("Data"); - std::string attr = "attr"; - auto value = AscendString("1"); - op.AttrRegister(attr, value); - AscendString ret; - op.GetAttr(attr.c_str(), ret); - ASSERT_EQ(std::string(value.GetString()), std::string(ret.GetString())); - op.AttrRegister(nullptr, value); -} - -TEST_F(UtestOperater, AttrRegister_AscendString2) { - auto op = Operator("Data"); - std::string attr = "attr"; - auto value = AscendString("1"); - op.AttrRegister(attr, value); - AscendString ret; - op.GetAttr(attr.c_str(), ret); - ASSERT_EQ(std::string(value.GetString()), std::string(ret.GetString())); - op.AttrRegister(attr, AscendString("")); -} - -TEST_F(UtestOperater, AttrRegister_ListAscendString) { - auto op = Operator("Data"); - std::string attr = "attr"; - std::vector value = {AscendString("1")}; - op.AttrRegister(attr, value); - std::vector ret; - op.GetAttr(attr.c_str(), ret); - ASSERT_EQ(std::string(value[0].GetString()), std::string(ret[0].GetString())); - op.AttrRegister(nullptr, value); - op.operator_impl_ = nullptr; - op.AttrRegister(attr, value); - value[0].name_ = nullptr; - op.AttrRegister(attr, value); -} - -TEST_F(UtestOperater, AttrRegister_ListString) { - auto op = Operator("Data"); - std::string attr = "attr"; - std::vector value; - op.AttrRegister(attr, value); - std::vector ret; - op.GetAttr(attr, ret); - ASSERT_EQ(ret.size(), 0); -} - -TEST_F(UtestOperater, RequiredAttrRegister_Success) { - auto op = Operator("Data"); - op.RequiredAttrRegister("x"); - op.RequiredAttrRegister(nullptr); - op.RequiredAttrRegister(std::string("y")); - - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - ASSERT_EQ(op_desc->GetIrAttrNames(), std::vector({"x", "y"})); -} - -TEST_F(UtestOperater, RequiredAttrWithTypeRegister_Success) { - auto op = Operator("Cast"); - op.RequiredAttrWithTypeRegister("dst_type", "Int"); - op.AttrRegister("fake_ir_attr", true); - op.RequiredAttrWithTypeRegister("fake_list_type", "ListType"); - op.RequiredAttrWithTypeRegister(nullptr, nullptr); // invalid case - - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - AttrUtils::SetBool(op_desc, "fake_custom_attr", false); - ASSERT_EQ(op_desc->GetIrAttrNames(), std::vector({"dst_type", "fake_ir_attr", "fake_list_type"})); - std::map ir_attr_name_types; - ASSERT_EQ(op.GetAllIrAttrNamesAndTypes(ir_attr_name_types), GRAPH_SUCCESS); - std::map ir_attr_name_types_expected - {{"dst_type", "VT_INT"}, {"fake_ir_attr", "VT_BOOL"}, {"fake_list_type", "VT_LIST_DATA_TYPE"}}; - ASSERT_EQ(ir_attr_name_types, ir_attr_name_types_expected); - ASSERT_TRUE(op_desc->HasAttr("dst_type")); - ASSERT_TRUE(op_desc->HasAttr("fake_ir_attr")); - ASSERT_TRUE(op_desc->HasAttr("fake_custom_attr")); - ASSERT_TRUE(op_desc->HasRequiredAttr("dst_type")); - ASSERT_TRUE(op_desc->HasRequiredAttr("fake_list_type")); - ASSERT_FALSE(op_desc->HasRequiredAttr("fake_ir_attr")); - ASSERT_FALSE(op_desc->HasRequiredAttr("fake_custom_attr")); - std::map setted_attr_name_types; - ASSERT_EQ(op.GetAllAttrNamesAndTypes(setted_attr_name_types), GRAPH_SUCCESS); - for (auto pair : setted_attr_name_types) { - std::cout << pair.first.GetString() << "|" << pair.second.GetString() << std::endl; - } - std::map setted_attr_name_types_expected{{"fake_ir_attr", "VT_BOOL"}, - {"fake_custom_attr", "VT_BOOL"}}; - ASSERT_EQ(setted_attr_name_types, setted_attr_name_types_expected); -} - -TEST_F(UtestOperater, SubgraphRegister) { - EXPECT_NO_THROW( - std::string name = "add"; - auto op = Operator("Add"); - bool dynamic = true; - op.SubgraphRegister(name, dynamic); - op.SubgraphRegister(nullptr, dynamic); - ); -} - -TEST_F(UtestOperater, SubgraphCountRegister) { - EXPECT_NO_THROW( - std::string name = "add"; - auto op = Operator("Add"); - uint32_t count = 1; - op.SubgraphCountRegister(name, count); - op.SubgraphCountRegister(nullptr, count); - ); -} - -TEST_F(UtestOperater, SetSubgraphBuilder) { - std::string name = "add"; - auto op = Operator("Add"); - uint32_t index = 1; - SubgraphBuilder builder = []() { return Graph(); }; - op.SetSubgraphBuilder(name, index, builder); - op.SetSubgraphBuilder(nullptr, index, builder); - - SubgraphBuilder builder2; - builder2 = op.GetSubgraphBuilder(name); - - SubgraphBuilder builder3; - builder3 = op.GetDynamicSubgraphBuilder(nullptr, 0); - builder3 = op.GetDynamicSubgraphBuilder("add", 0); - - std::vector vec_name; - vec_name = op.GetSubgraphNames(); - EXPECT_EQ(vec_name.size(), 0); - - op.GetSubgraph(nullptr); - Graph graph = op.GetSubgraph(name); - EXPECT_EQ(graph.GetName(), ""); - - graph = op.GetSubgraph("add"); - EXPECT_EQ(graph.GetName(), ""); - - op.GetDynamicSubgraph(nullptr, 0); - graph = op.GetDynamicSubgraph(name, 0); - EXPECT_EQ(graph.GetName(), ""); - - graph = op.GetDynamicSubgraph("add", 0); - EXPECT_EQ(graph.GetName(), ""); -} - -TEST_F(UtestOperater, GetSubgraphImpl) { - EXPECT_NO_THROW( - std::string name = "add"; - auto op = Operator("Add"); - op.GetSubgraphImpl(name); - op.GetSubgraphImpl(nullptr); - ); -} - -TEST_F(UtestOperater, SetInput_Handler) { - EXPECT_NO_THROW( - std::string name = "add"; - std::string type = "Add"; - auto op = Operator(type); - auto handler = OutHandler(nullptr); - op.SetInput(name.c_str(), handler); - op.SetInput(nullptr, handler); - ); -} - -TEST_F(UtestOperater, GetOutput) { - EXPECT_NO_THROW( - std::string name = "add"; - auto op = Operator("Add"); - op.GetOutput(name.c_str()); - op.GetOutput(nullptr); - ); -} - -TEST_F(UtestOperater, GetInputConstDataOut) { - ge::GeTensorPtr tensor = std::make_shared(); - std::vector value{1}; - std::vector shape{1}; - tensor->MutableTensorDesc().SetShape(GeShape(shape)); - tensor->SetData(value); - tensor->MutableTensorDesc().SetDataType(DT_UINT8); - const auto &data = reinterpret_cast(tensor->GetData().GetData()); - const auto size = tensor->GetData().GetSize(); - ASSERT_EQ(SaveBinToFile(data, size, "./weight.bin"), GRAPH_SUCCESS); - - std::string name = "fileconstant"; - ge::OpDescPtr op_desc = std::make_shared(name, "FileConstant"); - AttrUtils::SetDataType(op_desc, "dtype", DT_UINT8); - op_desc->AddOutputDesc(tensor->GetTensorDesc()); - AttrUtils::SetStr(op_desc, "location", "./weight.bin"); - AttrUtils::SetInt(op_desc, "length", size); - auto input_op = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - auto add_op = Operator("Add"); - ge::OpIO out_handle(name, 0, input_op.GetOperatorImplPtr()); - add_op.GetOperatorImplPtr()->input_link_.insert({name, out_handle}); - add_op.GetOperatorImplPtr()->GetOpDescImpl()->impl_->input_name_idx_.insert({name, 0U}); - Tensor a = Tensor(); - ASSERT_EQ(add_op.GetInputConstDataOut(name.c_str(), a), GRAPH_SUCCESS); - ConstGeTensorBarePtr b = nullptr; - b = OpDescUtils::GetInputConstData(add_op, 0); - ASSERT_NE(b, nullptr); - system("rm -rf ./weight.bin"); -} - -TEST_F(UtestOperater, testTensorType) { - DataType dt(DT_INT16); - TensorType tt1(dt); - EXPECT_EQ(*(tt1.tensor_type_impl_->dt_set_.cbegin()), DT_INT16); - - const std::initializer_list types = {DT_INT8, DT_UINT8, DT_INT16}; - TensorType tt2(types); - EXPECT_EQ(tt2.tensor_type_impl_->dt_set_.size(), 3); -} - -TEST_F(UtestOperater, CreateOperator) { - Operator op; - OpDescPtr op_desc; - - op = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - EXPECT_FALSE(op.IsEmpty()); -} - -TEST_F(UtestOperater, testGetName) { - AscendString name; - Operator op("one_op", "add"); - op.GetName(name); - - const char *str = name.GetString(); - EXPECT_EQ(strcmp(str, "one_op"), 0); -} - -TEST_F(UtestOperater, GetInputConstData2) { - Operator op; - OpDescPtr op_desc; - op = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - - std::string dst_name("dst_name"); - Tensor td; - - EXPECT_NE(op.GetInputConstData(dst_name, td), GRAPH_SUCCESS); -} - -TEST_F(UtestOperater, GetNode) { - Operator op; - OpDescPtr op_desc; - op = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - - EXPECT_EQ(op.GetNode(), nullptr); -} - -TEST_F(UtestOperater, GetInputDesc) { - Operator op; - OpDescPtr op_desc; - op = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - - std::string str_name = "input_desc_name"; - TensorDesc td = op.GetInputDesc(str_name); - - EXPECT_EQ(td.GetName().length(), 0); -} - -TEST_F(UtestOperater, TryGetInputDesc) { - Operator op; - OpDescPtr op_desc_1; - op = OpDescUtils::CreateOperatorFromOpDesc(op_desc_1); - - TensorDesc td; - auto ret = op.TryGetInputDesc("input_name_1", td); - EXPECT_EQ(ret, GRAPH_FAILED); - - string str = "input_name_2"; - ret = op.TryGetInputDesc(str, td); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestOperater, UpdateInputDesc) { - Operator op; - OpDescPtr op_desc_1; - op = OpDescUtils::CreateOperatorFromOpDesc(op_desc_1); - - TensorDesc td; - - string str = "input_name"; - auto ret = op.UpdateInputDesc(str, td); - EXPECT_EQ(ret, GRAPH_FAILED); - - ret = op.UpdateInputDesc("input_name", td); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestOperater, GetOutputDesc) { - Operator op; - OpDescPtr op_desc_1; - op = OpDescUtils::CreateOperatorFromOpDesc(op_desc_1); - - string str = "output_name"; - TensorDesc td = op.GetOutputDesc(str); - EXPECT_EQ(td.GetName().length(), 0); -} - -TEST_F(UtestOperater, UpdateOutputDesc) { - Operator op; - OpDescPtr op_desc_1; - op = OpDescUtils::CreateOperatorFromOpDesc(op_desc_1); - - TensorDesc td; - string str = "output_name"; - auto ret = op.UpdateOutputDesc(str, td); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestOperater, GetDynamicInputDesc) { - Operator op; - OpDescPtr op_desc_1; - op = OpDescUtils::CreateOperatorFromOpDesc(op_desc_1); - - string str = "input_name"; - TensorDesc td_1 = op.GetDynamicInputDesc(str, 0); - TensorDesc td_2 = op.GetDynamicInputDesc("input_name", 0); - EXPECT_EQ(td_1.GetName().length(), 0); - EXPECT_EQ(td_2.GetName().length(), 0); -} - -TEST_F(UtestOperater, UpdateDynamicInputDesc) { - Operator op; - OpDescPtr op_desc_1; - op = OpDescUtils::CreateOperatorFromOpDesc(op_desc_1); - - TensorDesc td_1; - string str = "input_name"; - auto ret = op.UpdateDynamicInputDesc(str, 0, td_1); - EXPECT_EQ(ret, GRAPH_FAILED); - ret = op.UpdateDynamicInputDesc("input_name", 0, td_1); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestOperater, GetDynamicOutputDesc) { - Operator op; - OpDescPtr op_desc_1; - op = OpDescUtils::CreateOperatorFromOpDesc(op_desc_1); - - string str = "output_name"; - TensorDesc td_1 = op.GetDynamicOutputDesc(str, 0); - TensorDesc td_2 = op.GetDynamicOutputDesc("output_name", 0); - EXPECT_EQ(td_1.GetName().length(), 0); - EXPECT_EQ(td_2.GetName().length(), 0); -} - -TEST_F(UtestOperater, UpdateDynamicOutputDesc) { - Operator op; - OpDescPtr op_desc_1; - op = OpDescUtils::CreateOperatorFromOpDesc(op_desc_1); - - TensorDesc td_1; - string str = "output_name"; - auto ret = op.UpdateDynamicOutputDesc(str, 0, td_1); - EXPECT_EQ(ret, GRAPH_FAILED); - ret = op.UpdateDynamicOutputDesc("output_name", 0, td_1); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestOperater, InferShapeAndType) { - Operator op; - OpDescPtr op_desc_1; - op = OpDescUtils::CreateOperatorFromOpDesc(op_desc_1); - - auto ret = op.InferShapeAndType(); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestOperater, InferShapeAndType_param_invalid) { - Operator op; - op.operator_impl_ = std::make_shared("name", "type"); - - auto ret = op.InferShapeAndType(); - EXPECT_NE(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestOperater, VerifyAllAttr) { - Operator op; - OpDescPtr op_desc_1; - op = OpDescUtils::CreateOperatorFromOpDesc(op_desc_1); - - auto ret = op.VerifyAllAttr(true); - EXPECT_EQ(ret, GRAPH_FAILED); - - ret = op.VerifyAllAttr(false); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestOperater, VerifyAllAttr_success) { - Operator op; - op.operator_impl_ = std::make_shared("name", "type"); - - auto ret = op.VerifyAllAttr(true); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestOperater, GetAllAttrNamesAndTypes) { - Operator op; - OpDescPtr op_desc_1; - op = OpDescUtils::CreateOperatorFromOpDesc(op_desc_1); - - auto ret = op.GetAllAttrNamesAndTypes(); - EXPECT_EQ(ret.size(), 0); - - std::map attr_name_types; - auto ret_2 = op.GetAllAttrNamesAndTypes(attr_name_types); - EXPECT_EQ(ret_2, GRAPH_FAILED); -} - -TEST_F(UtestOperater, GetAllAttrs) { - Operator op("name", "type"); - const std::string name = "name"; - std::string value("value"); - op.SetAttr(name, value); - std::map attr_name_types; - auto ret = op.GetAllAttrNamesAndTypes(attr_name_types); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - auto attr_types = op.GetAllAttrNamesAndTypes(); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestOperater, FuncRegister) { - Operator op; - OpDescPtr op_desc_1; - op = OpDescUtils::CreateOperatorFromOpDesc(op_desc_1); - - std::function func; - - op.InferFuncRegister(func); - - if (op.operator_impl_->GetOpDescImpl() != nullptr) { - printf("FuncRegister GetOpDescImpl is not null!\n"); - //auto ret1 = op.operator_impl_->GetOpDescImpl()->GetInferFunc(); - //EXPECT_EQ(ret1, nullptr); - } else { - printf("FuncRegister GetOpDescImpl is null!\n"); - } - - ASSERT_NE(op.operator_impl_, nullptr); -} - -TEST_F(UtestOperater, FuncRegister2) { - Operator op; - OpDescPtr op_desc_1; - op = OpDescUtils::CreateOperatorFromOpDesc(op_desc_1); - std::function func; - - op.InferFormatFuncRegister(func); - op.VerifierFuncRegister(func); - - ASSERT_NE(op.operator_impl_, nullptr); -} - -TEST_F(UtestOperater, GetDynamicInputNum) { - Operator op; - OpDescPtr op_desc_1; - op = OpDescUtils::CreateOperatorFromOpDesc(op_desc_1); - - int num1 = op.GetDynamicInputNum("input_name"); - EXPECT_EQ(num1, 0); - - int num2 = op.GetDynamicInputNum(std::string("input_name")); - EXPECT_EQ(num2, 0); -} - -TEST_F(UtestOperater, GetDynamicOutputNum) { - Operator op; - OpDescPtr op_desc_1; - op = OpDescUtils::CreateOperatorFromOpDesc(op_desc_1); - - int num1 = op.GetDynamicOutputNum("output_name"); - EXPECT_EQ(num1, 0); - - int num2 = op.GetDynamicOutputNum(std::string("output_name")); - EXPECT_EQ(num2, 0); -} - -TEST_F(UtestOperater, VerifyAll) { - Operator op; - OpDescPtr op_desc_1; - op = OpDescUtils::CreateOperatorFromOpDesc(op_desc_1); - - auto ret = op.VerifyAll(); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestOperater, GetOperatorImplPtr) { - Operator op; - OpDescPtr op_desc_1; - op = OpDescUtils::CreateOperatorFromOpDesc(op_desc_1); - - auto ret = op.GetOperatorImplPtr(); - EXPECT_NE(ret, nullptr); -} - -TEST_F(UtestOperater, AddControlInput_Exception) { - Operator op1; - Operator op2; - OpDescPtr op_desc_1; - op2 = OpDescUtils::CreateOperatorFromOpDesc(op_desc_1); - - auto ret = op1.AddControlInput(op2); - EXPECT_EQ(op1.IsEmpty(), ret.IsEmpty()); -} - -TEST_F(UtestOperater, AddMultiControlInput) { - auto o1_1_1 = op::Foo01("o1_1"); - auto o1_1_2 = op::Foo01("o1_2"); - auto o1_1_3 = op::Foo01("o1_3"); - auto o1_1_4 = op::Foo11("o1_4"); - - (void) o1_1_4.AddControlInput(o1_1_1); - (void) o1_1_4.AddControlInput(o1_1_3); - (void) o1_1_4.AddControlInput(o1_1_2); - Graph g{"g"}; - g.SetInputs(std::vector({o1_1_1, o1_1_2, o1_1_3})); - auto compute_graph = GraphUtilsEx::GetComputeGraph(g); - EXPECT_NE(compute_graph, nullptr); - - auto n_4 = compute_graph->FindNode("o1_4"); - EXPECT_NE(n_4, nullptr); - auto nodes = n_4->GetInControlNodes(); - EXPECT_EQ(nodes.size(), 3U); - // 顺序已经按照名字排序了 - EXPECT_EQ(nodes.at(0)->GetName(), "o1_1"); - EXPECT_EQ(nodes.at(1)->GetName(), "o1_2"); - EXPECT_EQ(nodes.at(2)->GetName(), "o1_3"); -} - -TEST_F(UtestOperater, SetAttr_char_array) { - Operator op1; - Operator op2; - - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("", optype_str); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - - const char_t *name = "data name"; - const char_t *attr_value = "abc"; - - op2 = op1.SetAttr(name, attr_value); - std::string value1; - - op1.GetAttr(name, value1); - printf("c_str1 = %s\n", value1.c_str()); - - std::string value2; - op2.GetAttr(name, value2); - printf("c_str2 = %s\n", value2.c_str()); - EXPECT_EQ(value2, std::string("abc")); - - op1.SetAttr(nullptr, nullptr); -} - -TEST_F(UtestOperater, SetAttr_AscendString) { - Operator op1; - Operator op2; - - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("", optype_str); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - - const char_t *name = "data name"; - AscendString attr_value = "abc"; - - op1.SetAttr(nullptr, attr_value); - op2 = op1.SetAttr(name, attr_value); - - std::string value2; - op2.GetAttr(name, value2); - EXPECT_EQ(value2, std::string("abc")); - - AscendString value3; - EXPECT_EQ(op2.GetAttr(nullptr, value3), GRAPH_FAILED); -} - -TEST_F(UtestOperater, SetAttr_vector_AscendString) { - Operator op1; - Operator op2; - - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("", optype_str); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - - const char_t *name = "data name"; - std::vector attr_value = {AscendString("abc"), AscendString("def")}; - - op2 = op1.SetAttr(name, attr_value); - - std::vector value2; - op2.GetAttr(name, value2); - - EXPECT_TRUE(value2.size() > 1); - EXPECT_EQ(value2[1].GetString(), std::string("def")); - - op1.SetAttr(nullptr, attr_value); - EXPECT_EQ(op2.GetAttr(nullptr, value2), GRAPH_FAILED); -} - -TEST_F(UtestOperater, SetAttr_vector_AscendString2) { - Operator op1; - Operator op2; - - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("", optype_str); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - - const char_t *name = "data name"; - std::vector attr_value = {AscendString("abc"), AscendString("def")}; - - op2 = op1.SetAttr(name, attr_value); - - std::vector value2; - op2.GetAttr(name, value2); - - EXPECT_EQ(value2[1].GetString(), std::string("def")); - - op2 = op1.SetAttr(nullptr, attr_value); - EXPECT_EQ(op2.GetAttr(nullptr, value2), GRAPH_FAILED); -} - -TEST_F(UtestOperater, SetInputAttrByNameOfChar_tWithNull) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddOutputDesc("y", tensor_desc); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - op1.SetInputAttr("x", nullptr, "FIFO"); - AscendString enqueue_policy; - const auto ret = op1.GetInputAttr("x", nullptr, enqueue_policy); - EXPECT_EQ(ret, GRAPH_FAILED); - EXPECT_EQ(enqueue_policy != "FIFO", true); -} - -TEST_F(UtestOperater, SetInputAttrByNameOfChar_tWithNullOp) { - Operator op1; - op1.SetInputAttr("x", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), "FIFO"); - AscendString enqueue_policy; - const auto ret = op1.GetInputAttr("x", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), enqueue_policy); - EXPECT_EQ(ret, GRAPH_FAILED); - EXPECT_EQ(enqueue_policy != "FIFO", true); -} - -TEST_F(UtestOperater, SetInputAttrByNameOfChar_tWithNullTensor) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - op1.SetInputAttr("x", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), "FIFO"); - AscendString enqueue_policy; - const auto ret = op1.GetInputAttr("x", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), enqueue_policy); - EXPECT_EQ(ret, GRAPH_FAILED); - EXPECT_EQ(enqueue_policy != "FIFO", true); -} - -TEST_F(UtestOperater, SetInputAttrByIndexOfChar_tWithNull) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddOutputDesc("y", tensor_desc); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - op1.SetInputAttr(0, nullptr, "FIFO"); - AscendString enqueue_policy; - const auto ret = op1.GetInputAttr(0, nullptr, enqueue_policy); - EXPECT_EQ(ret, GRAPH_FAILED); - EXPECT_EQ(enqueue_policy != "FIFO", true); -} - -TEST_F(UtestOperater, SetInputAttrByIndexOfChar_tWithNullOp) { - Operator op1; - op1.SetInputAttr(0, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), "FIFO"); - AscendString enqueue_policy; - const auto ret = op1.GetInputAttr(0, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), enqueue_policy); - EXPECT_EQ(ret, GRAPH_FAILED); - EXPECT_EQ(enqueue_policy != "FIFO", true); -} - -TEST_F(UtestOperater, SetInputAttrByIndexOfChar_tWithNullTensor) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - op1.SetInputAttr(0, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), "FIFO"); - AscendString enqueue_policy; - const auto ret = op1.GetInputAttr(0, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), enqueue_policy); - EXPECT_EQ(ret, GRAPH_FAILED); - EXPECT_EQ(enqueue_policy != "FIFO", true); -} - -TEST_F(UtestOperater, SetInputAttrByIndexOfAscendString_tWithNull) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddOutputDesc("y", tensor_desc); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - AscendString policy("FIFO"); - op1.SetInputAttr(0, nullptr, policy); - AscendString enqueue_policy; - const auto ret = op1.GetInputAttr(0, nullptr, enqueue_policy); - EXPECT_EQ(ret, GRAPH_FAILED); - EXPECT_EQ(enqueue_policy != "FIFO", true); -} - -TEST_F(UtestOperater, SetInputAttrByIndexOfAscendString_tWithNullOp) { - Operator op1; - AscendString policy("FIFO"); - op1.SetInputAttr(0, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), policy); - AscendString enqueue_policy; - const auto ret = op1.GetInputAttr(0, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), enqueue_policy); - EXPECT_EQ(ret, GRAPH_FAILED); - EXPECT_EQ(enqueue_policy != "FIFO", true); -} - -TEST_F(UtestOperater, SetInputAttrByIndexOfAscendString_tWithNullTensor) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - AscendString policy("FIFO"); - op1.SetInputAttr(0, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), policy); - AscendString enqueue_policy; - const auto ret = op1.GetInputAttr(0, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), enqueue_policy); - EXPECT_EQ(ret, GRAPH_FAILED); - EXPECT_EQ(enqueue_policy != "FIFO", true); -} - -TEST_F(UtestOperater, SetInputAttrByNameOfAscendString_tWithNull) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddOutputDesc("y", tensor_desc); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - AscendString policy("FIFO"); - op1.SetInputAttr("x", nullptr, policy); - AscendString enqueue_policy; - const auto ret = op1.GetInputAttr("x", nullptr, enqueue_policy); - EXPECT_EQ(ret, GRAPH_FAILED); - EXPECT_EQ(enqueue_policy != "FIFO", true); -} - -TEST_F(UtestOperater, SetInputAttrByNameOfAscendString_tWithNullOp) { - Operator op1; - AscendString policy("FIFO"); - op1.SetInputAttr("x", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), policy); - AscendString enqueue_policy; - const auto ret = op1.GetInputAttr("x", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), enqueue_policy); - EXPECT_EQ(ret, GRAPH_FAILED); - EXPECT_EQ(enqueue_policy != "FIFO", true); -} - -TEST_F(UtestOperater, SetInputAttrByNameOfAscendString_tWithNullTensor) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - AscendString policy("FIFO"); - op1.SetInputAttr("x", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), policy); - AscendString enqueue_policy; - const auto ret = op1.GetInputAttr("x", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), enqueue_policy); - EXPECT_EQ(ret, GRAPH_FAILED); - EXPECT_EQ(enqueue_policy != "FIFO", true); -} - -TEST_F(UtestOperater, SetOutputAttrByNameOfChar_tWithNull) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddOutputDesc("y", tensor_desc); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - op1.SetOutputAttr("y", nullptr, "FIFO"); - AscendString enqueue_policy; - const auto ret = op1.GetOutputAttr("y", nullptr, enqueue_policy); - EXPECT_EQ(ret, GRAPH_FAILED); - EXPECT_EQ(enqueue_policy != "FIFO", true); -} - -TEST_F(UtestOperater, SetOutputAttrByNameOfChar_tWithNullOp) { - Operator op1; - op1.SetOutputAttr("x", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), "FIFO"); - AscendString enqueue_policy; - const auto ret = op1.GetOutputAttr("x", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), enqueue_policy); - EXPECT_EQ(ret, GRAPH_FAILED); - EXPECT_EQ(enqueue_policy != "FIFO", true); -} - -TEST_F(UtestOperater, SetOutputAttrByNameOfChar_tWithNullTensor) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - op1.SetOutputAttr("x", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), "FIFO"); - AscendString enqueue_policy; - const auto ret = op1.GetOutputAttr("x", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), enqueue_policy); - EXPECT_EQ(ret, GRAPH_FAILED); - EXPECT_EQ(enqueue_policy != "FIFO", true); -} - -TEST_F(UtestOperater, SetOutputAttrByIndexOfChar_tWithNull) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddOutputDesc("y", tensor_desc); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - op1.SetOutputAttr(0, nullptr, "FIFO"); - AscendString enqueue_policy; - const auto ret = op1.GetOutputAttr(0, nullptr, enqueue_policy); - EXPECT_EQ(ret, GRAPH_FAILED); - EXPECT_EQ(enqueue_policy != "FIFO", true); -} - -TEST_F(UtestOperater, SetOutputAttrByIndexOfChar_tWithNullOp) { - Operator op1; - op1.SetOutputAttr(0, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), "FIFO"); - AscendString enqueue_policy; - const auto ret = op1.GetOutputAttr(0, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), enqueue_policy); - EXPECT_EQ(ret, GRAPH_FAILED); - EXPECT_EQ(enqueue_policy != "FIFO", true); -} - -TEST_F(UtestOperater, SetOutputAttrByIndexOfChar_tWithNullTensor) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - op1.SetOutputAttr(0, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), "FIFO"); - AscendString enqueue_policy; - const auto ret = op1.GetOutputAttr(0, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), enqueue_policy); - EXPECT_EQ(ret, GRAPH_FAILED); - EXPECT_EQ(enqueue_policy != "FIFO", true); -} - -TEST_F(UtestOperater, SetOutputAttrByIndexOfAscendString_tWithNull) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddOutputDesc("y", tensor_desc); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - AscendString policy("FIFO"); - op1.SetOutputAttr(0, nullptr, policy); - AscendString enqueue_policy; - const auto ret = op1.GetOutputAttr(0, nullptr, enqueue_policy); - EXPECT_EQ(ret, GRAPH_FAILED); - EXPECT_EQ(enqueue_policy != "FIFO", true); -} - -TEST_F(UtestOperater, SetOutputAttrByIndexOfAscendStringWithNullOp) { - Operator op1; - AscendString policy("FIFO"); - op1.SetOutputAttr(0, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), "FIFO"); - AscendString enqueue_policy; - const auto ret = op1.GetOutputAttr(0, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), enqueue_policy); - EXPECT_EQ(ret, GRAPH_FAILED); - EXPECT_EQ(enqueue_policy != "FIFO", true); -} - -TEST_F(UtestOperater, SetOutputAttrByIndexOfAscendStringWithNullTensor) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - AscendString policy("FIFO"); - op1.SetOutputAttr(0, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), "FIFO"); - AscendString enqueue_policy; - const auto ret = op1.GetOutputAttr(0, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), enqueue_policy); - EXPECT_EQ(ret, GRAPH_FAILED); - EXPECT_EQ(enqueue_policy != "FIFO", true); -} - -TEST_F(UtestOperater, SetOutputAttrByNameOfAscendString_tWithNull) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddOutputDesc("y", tensor_desc); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - AscendString policy("FIFO"); - op1.SetOutputAttr("y", nullptr, policy); - AscendString enqueue_policy; - const auto ret = op1.GetOutputAttr("y", nullptr, enqueue_policy); - EXPECT_EQ(ret, GRAPH_FAILED); - EXPECT_EQ(enqueue_policy != "FIFO", true); -} - -TEST_F(UtestOperater, SetOutputAttrByNameOfAscendStringWithNullOp) { - Operator op1; - AscendString policy("FIFO"); - op1.SetOutputAttr("x", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), "FIFO"); - AscendString enqueue_policy; - const auto ret = op1.GetOutputAttr("x", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), enqueue_policy); - EXPECT_EQ(ret, GRAPH_FAILED); - EXPECT_EQ(enqueue_policy != "FIFO", true); -} - -TEST_F(UtestOperater, SetOutputAttrByNameOfAscendStringWithNullTensor) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - AscendString policy("FIFO"); - op1.SetOutputAttr("x", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), "FIFO"); - AscendString enqueue_policy; - const auto ret = op1.GetOutputAttr("x", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), enqueue_policy); - EXPECT_EQ(ret, GRAPH_FAILED); - EXPECT_EQ(enqueue_policy != "FIFO", true); -} - -TEST_F(UtestOperater, SetOutputAttrByIndexOfAscendString) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddOutputDesc("y", tensor_desc); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - op1.SetOutputAttr(0, ATTR_NAME_FLOW_ATTR.c_str(), true); - op1.SetOutputAttr(0, ATTR_NAME_FLOW_ATTR_DEPTH.c_str(), static_cast(8)); - AscendString policy("FIFO"); - op1.SetOutputAttr(0, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), policy); - - bool has_flow_attr = false; - op1.GetOutputAttr(0, ATTR_NAME_FLOW_ATTR.c_str(), has_flow_attr); - int32_t depth = 0; - AscendString enqueue_policy; - auto ret = op1.GetOutputAttr(0, ATTR_NAME_FLOW_ATTR_DEPTH.c_str(), depth); - EXPECT_EQ(ret, GRAPH_SUCCESS); - ret = op1.GetOutputAttr("y", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), enqueue_policy); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(has_flow_attr, true); - EXPECT_EQ(depth, 8); - EXPECT_EQ(enqueue_policy, policy); -} - -TEST_F(UtestOperater, SetOutputAttrByNameOfAscendString) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddOutputDesc("y", tensor_desc); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - op1.SetOutputAttr("y", ATTR_NAME_FLOW_ATTR.c_str(), true); - op1.SetOutputAttr("y", ATTR_NAME_FLOW_ATTR_DEPTH.c_str(), static_cast(8)); - AscendString policy("FIFO"); - op1.SetOutputAttr("y", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), policy); - - bool has_flow_attr = false; - op1.GetOutputAttr("y", ATTR_NAME_FLOW_ATTR.c_str(), has_flow_attr); - int32_t depth = 0; - AscendString enqueue_policy; - auto ret = op1.GetOutputAttr("y", ATTR_NAME_FLOW_ATTR_DEPTH.c_str(), depth); - EXPECT_EQ(ret, GRAPH_SUCCESS); - ret = op1.GetOutputAttr("y", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), enqueue_policy); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(has_flow_attr, true); - EXPECT_EQ(depth, 8); - EXPECT_EQ(enqueue_policy, policy); -} - -TEST_F(UtestOperater, SetInputAttrByIndexOfAscendString) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddOutputDesc("y", tensor_desc); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - op1.SetInputAttr(0, ATTR_NAME_FLOW_ATTR.c_str(), true); - op1.SetInputAttr(0, ATTR_NAME_FLOW_ATTR_DEPTH.c_str(), static_cast(8)); - AscendString policy("FIFO"); - op1.SetInputAttr(0, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), policy); - - bool has_flow_attr = false; - op1.GetInputAttr(0, ATTR_NAME_FLOW_ATTR.c_str(), has_flow_attr); - int32_t depth = 0; - AscendString enqueue_policy; - auto ret = op1.GetInputAttr(0, ATTR_NAME_FLOW_ATTR_DEPTH.c_str(), depth); - EXPECT_EQ(ret, GRAPH_SUCCESS); - ret = op1.GetInputAttr(0, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), enqueue_policy); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(enqueue_policy, policy); -} - -TEST_F(UtestOperater, SetInputAttrByNameOfAscendString) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddOutputDesc("y", tensor_desc); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - op1.SetInputAttr("x", ATTR_NAME_FLOW_ATTR.c_str(), true); - op1.SetInputAttr("x", ATTR_NAME_FLOW_ATTR_DEPTH.c_str(), static_cast(8)); - AscendString policy("FIFO"); - op1.SetInputAttr("x", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), policy); - - bool has_flow_attr = false; - auto ret = op1.GetInputAttr("x", ATTR_NAME_FLOW_ATTR.c_str(), has_flow_attr); - EXPECT_EQ(ret, GRAPH_SUCCESS); - int32_t depth = 0; - AscendString enqueue_policy; - ret = op1.GetInputAttr("x", ATTR_NAME_FLOW_ATTR_DEPTH.c_str(), depth); - EXPECT_EQ(ret, GRAPH_SUCCESS); - ret = op1.GetInputAttr("x", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), enqueue_policy); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(enqueue_policy, policy); -} - -TEST_F(UtestOperater, SetInputAttrByNameOfChar_t) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddOutputDesc("y", tensor_desc); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - op1.SetInputAttr("x", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), "FIFO"); - AscendString enqueue_policy; - auto ret = op1.GetInputAttr("x", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), enqueue_policy); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(enqueue_policy, "FIFO"); -} - -TEST_F(UtestOperater, SetInputAttrByIndexOfChar_t) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddOutputDesc("y", tensor_desc); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - op1.SetInputAttr(0, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), "FIFO"); - AscendString enqueue_policy; - auto ret = op1.GetInputAttr(0, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), enqueue_policy); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(enqueue_policy, "FIFO"); -} - -TEST_F(UtestOperater, SetOutputAttrByNameOfChar_t) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddOutputDesc("y", tensor_desc); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - op1.SetOutputAttr("y", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), "FIFO"); - AscendString enqueue_policy; - auto ret = op1.GetOutputAttr("y", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), enqueue_policy); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(enqueue_policy == "FIFO", true); -} - -TEST_F(UtestOperater, SetOutputAttrByIndexOfChar_t) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddOutputDesc("y", tensor_desc); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - op1.SetOutputAttr(0, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), "FIFO"); - AscendString enqueue_policy; - auto ret = op1.GetOutputAttr(0, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), enqueue_policy); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(enqueue_policy, "FIFO"); -} - -TEST_F(UtestOperater, SetOutputAttrByIndexOfListAscendString) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddOutputDesc("y", tensor_desc); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - std::vector attr_value{"1", "2"}; - op1.SetOutputAttr(0, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), attr_value); - std::vector attr_value_got; - EXPECT_EQ(op1.GetOutputAttr(0, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), attr_value_got), GRAPH_SUCCESS); - ASSERT_EQ(attr_value_got.size(), attr_value.size()); - for (int32_t i = 0; i < static_cast(attr_value.size()); i++) { - EXPECT_EQ(attr_value_got[i], attr_value[i]); - } -} - -TEST_F(UtestOperater, SetOutputAttrByIndexOfListAscendString_InvalidIndex) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddOutputDesc("y", tensor_desc); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - std::vector attr_value_got; - Operator empty; - empty.SetOutputAttr(0, "", attr_value_got); - EXPECT_NE(empty.GetOutputAttr(1, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), attr_value_got), GRAPH_SUCCESS); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - std::vector attr_value{"1", "2"}; - op1.SetOutputAttr(1, nullptr, attr_value); - op1.SetOutputAttr(1, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), attr_value); - EXPECT_NE(op1.GetOutputAttr(0, nullptr, attr_value_got), GRAPH_SUCCESS); - EXPECT_NE(op1.GetOutputAttr(1, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), attr_value_got), GRAPH_SUCCESS); -} - -TEST_F(UtestOperater, SetInputAttrByIndexOfListAscendString) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddOutputDesc("y", tensor_desc); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - std::vector attr_value{"1", "2"}; - op1.SetInputAttr(0, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), attr_value); - std::vector attr_value_got; - EXPECT_EQ(op1.GetInputAttr(0, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), attr_value_got), GRAPH_SUCCESS); - ASSERT_EQ(attr_value_got.size(), attr_value.size()); - for (int32_t i = 0; i < static_cast(attr_value.size()); i++) { - EXPECT_EQ(attr_value_got[i], attr_value[i]); - } -} - -TEST_F(UtestOperater, SetInputAttrByIndexOfListAscendString_InvalidIndex) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddOutputDesc("y", tensor_desc); - std::vector attr_value_got; - Operator empty; - empty.SetInputAttr(0, "", attr_value_got); - EXPECT_NE(empty.GetInputAttr(1, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), attr_value_got), GRAPH_SUCCESS); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - std::vector attr_value{"1", "2"}; - op1.SetInputAttr(1, nullptr, attr_value); - op1.SetInputAttr(1, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), attr_value); - EXPECT_NE(op1.GetInputAttr(0, nullptr, attr_value_got), GRAPH_SUCCESS); - EXPECT_NE(op1.GetInputAttr(1, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), attr_value_got), GRAPH_SUCCESS); -} - -TEST_F(UtestOperater, SetOutputAttrByDstNameOfListAscendString) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddOutputDesc("y", tensor_desc); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - std::vector attr_value{"1", "2"}; - op1.SetOutputAttr("y", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), attr_value); - std::vector attr_value_got; - EXPECT_EQ(op1.GetOutputAttr("y", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), attr_value_got), GRAPH_SUCCESS); - ASSERT_EQ(attr_value_got.size(), attr_value.size()); - for (int32_t i = 0; i < static_cast(attr_value.size()); i++) { - EXPECT_EQ(attr_value_got[i], attr_value[i]); - } -} - -TEST_F(UtestOperater, SetOutputAttrByDstNameOfListAscendString_InvalidDstName) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddOutputDesc("y", tensor_desc); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - std::vector attr_value{"1", "2"}; - op1.SetOutputAttr(nullptr, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), attr_value); - op1.SetOutputAttr("x", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), attr_value); - std::vector attr_value_got; - EXPECT_NE(op1.GetOutputAttr(nullptr, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), attr_value_got), GRAPH_SUCCESS); - EXPECT_NE(op1.GetOutputAttr("x", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), attr_value_got), GRAPH_SUCCESS); -} - -TEST_F(UtestOperater, SetInputAttrByDstNameOfListAscendString) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddOutputDesc("y", tensor_desc); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - std::vector attr_value{"1", "2"}; - op1.SetInputAttr("x", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), attr_value); - std::vector attr_value_got; - EXPECT_EQ(op1.GetInputAttr("x", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), attr_value_got), GRAPH_SUCCESS); - ASSERT_EQ(attr_value_got.size(), attr_value.size()); - for (int32_t i = 0; i < static_cast(attr_value.size()); i++) { - EXPECT_EQ(attr_value_got[i], attr_value[i]); - } -} - -TEST_F(UtestOperater, SetInputAttrByDstNameOfListAscendString_InvalidDstName) { - Operator op1; - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("op1", optype_str); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddOutputDesc("y", tensor_desc); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - std::vector attr_value{"1", "2"}; - op1.SetInputAttr(nullptr, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), attr_value); - op1.SetInputAttr("y", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), attr_value); - std::vector attr_value_got; - EXPECT_NE(op1.GetInputAttr(nullptr, ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), attr_value_got), GRAPH_SUCCESS); - EXPECT_NE(op1.GetInputAttr("y", ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY.c_str(), attr_value_got), GRAPH_SUCCESS); -} - -TEST_F(UtestOperater, SetAttr_Tensor) { - Operator op1; - Operator op2; - - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("", optype_str); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - - const char_t *name = "data name"; - TensorDesc tensor_desc; - std::vector data = {1, 2, 3}; - Tensor attr_value(tensor_desc, data); - - op2 = op1.SetAttr(name, attr_value); - - Tensor value2; - op2.GetAttr(name, value2); - - EXPECT_EQ(value2.GetSize(), attr_value.GetSize()); -} - -TEST_F(UtestOperater, SetAttr_Tensor2) { - Operator op1; - Operator op2; - - std::string optype_str = "optype"; - ge::OpDescPtr op_desc = std::make_shared("", optype_str); - op1 = OpDescUtils::CreateOperatorFromOpDesc(op_desc); - - std::string name = "data name"; - TensorDesc tensor_desc; - std::vector data = {1, 2, 3}; - Tensor attr_value(tensor_desc, data); - - op2 = op1.SetAttr(name, attr_value); - - Tensor value2; - op2.GetAttr(name, value2); - - EXPECT_EQ(value2.GetSize(), attr_value.GetSize()); -} - -TEST_F(UtestOperater, SetAttr_vector_Tensor) { - Operator op1; - Operator op2; - - op1 = Operator("Data"); - std::vector attr_value = {Tensor()}; - - std::string name = "data name"; - op2 = op1.SetAttr(name, attr_value); - - std::vector value2; - op2.GetAttr(name, value2); - - EXPECT_EQ(value2.size(), attr_value.size()); -} - -TEST_F(UtestOperater, SetAttr_vector_Tensor2) { - Operator op1; - Operator op2; - - op1 = Operator("Data"); - std::vector attr_value = {Tensor()}; - - op1.SetAttr(nullptr, attr_value); - - const char_t *name = "data name"; - op2 = op1.SetAttr(name, attr_value); - - std::vector value2; - op2.GetAttr(nullptr, value2); - op2.GetAttr(name, value2); - - EXPECT_EQ(value2.size(), attr_value.size()); -} - -TEST_F(UtestOperater, SetAttr_OpBytes) { - Operator op1; - Operator op2; - - op1 = Operator("Data"); - auto attr_value = OpBytes{1, 2, 3}; - - op1.SetAttr(nullptr, attr_value); - - const char_t *name = "data name"; - op2 = op1.SetAttr(name, attr_value); - - OpBytes value2; - op2.GetAttr(nullptr, value2); - op2.GetAttr(name, value2); - - EXPECT_EQ(value2.size(), attr_value.size()); -} - -TEST_F(UtestOperater, SetAttr_OpBytes2) { - Operator op1; - Operator op2; - - op1 = Operator("Data"); - auto attr_value = OpBytes{1, 2, 3}; - - std::string name = "data name"; - op2 = op1.SetAttr(name, attr_value); - - OpBytes value2; - op2.GetAttr(name, value2); - EXPECT_EQ(value2.size(), attr_value.size()); -} - -TEST_F(UtestOperater, SetAttr_AttrValue) { - Operator op; - op = Operator("Data"); - AttrValue attr_value; - op.SetAttr(nullptr, std::move(attr_value)); - - const char_t *name = "data name"; - op.SetAttr(name, 10); - AttrValue attr_value2; - - EXPECT_EQ(op.GetAttr(name, attr_value2), GRAPH_SUCCESS); - int64_t value = 0; - attr_value2.GetValue(value); - EXPECT_EQ(value, 10); - - const char_t *name2 = "foo"; - op.SetAttr(name2, std::move(attr_value2)); - AttrValue attr_value3; - op.GetAttr(name2, attr_value3); - attr_value3.GetValue(value); - EXPECT_EQ(value, 10); - - AttrValue attr_value4; - op.GetAttr(std::string(name2), attr_value4); - attr_value4.GetValue(value); - EXPECT_EQ(value, 10); - - AttrValue::FLOAT f_value = 0.0; - attr_value4.GetValue(f_value); - EXPECT_EQ(value, 10.0); - - AscendString str("1234"); - op.SetAttr("asc_str", str); - AttrValue attr_value5; - op.GetAttr("asc_str", attr_value5); - AscendString asc_str_value; - attr_value5.GetValue(asc_str_value); - EXPECT_EQ(asc_str_value, "1234"); - - AttrValue::STR str_value; - attr_value5.GetValue(str_value); - EXPECT_EQ(str_value, "1234"); -} - -TEST_F(UtestOperater, SetAttr_vector_DataType) { - Operator op; - op = Operator("Data"); - - const char_t *name = "data name"; - std::vector attr_value = {DT_INT8, DT_INT16, DT_INT32}; - - op.SetAttr(nullptr, attr_value); - op.SetAttr(name, attr_value); - - std::vector attr_value_out; - - op.GetAttr(nullptr, attr_value_out); - op.GetAttr(name, attr_value_out); - - EXPECT_TRUE(attr_value_out.size() > 2); - EXPECT_EQ(attr_value_out[2], DT_INT32); -} - -TEST_F(UtestOperater, SetAttr_vector_DataType2) { - Operator op; - op = Operator("Data"); - - std::string name = "data name"; - std::vector attr_value = {DT_INT8, DT_INT16, DT_INT32}; - - op.SetAttr(name, attr_value); - - std::vector attr_value_out; - - op.GetAttr(name, attr_value_out); - - EXPECT_EQ(attr_value_out[1], DT_INT16); -} - -TEST_F(UtestOperater, SetAttr_DataType) { - Operator op; - op = Operator("Data"); - - const char_t *name = "data name"; - ge::DataType attr_value = DT_INT16; - - op.SetAttr(nullptr, attr_value); - op.SetAttr(name, attr_value); - - ge::DataType attr_value_out; - - op.GetAttr(nullptr, attr_value_out); - op.GetAttr(name, attr_value_out); - - EXPECT_EQ(attr_value_out, DT_INT16); -} - -TEST_F(UtestOperater, SetAttr_DataType2) { - Operator op; - op = Operator("Data"); - - std::string name = "data name"; - ge::DataType attr_value = DT_INT16; - - op.SetAttr(name, attr_value); - - ge::DataType attr_value_out; - - op.GetAttr(name, attr_value_out); - - EXPECT_EQ(attr_value_out, DT_INT16); -} - -TEST_F(UtestOperater, CopyOperators1) { - - ge::OpDescPtr add_op(new ge::OpDesc("add_0", "add")); - std::shared_ptr compute_graph(new ge::ComputeGraph("test_graph")); - auto add_node = compute_graph->AddNode(add_op); - Graph graph = ge::GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph); - - ge::OpDescPtr add_op_2(new ge::OpDesc("add_2", "add")); - std::shared_ptr compute_graph_2(new ge::ComputeGraph("test_graph_2")); - auto add_node_2 = compute_graph->AddNode(add_op_2); - Graph graph2 = ge::GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph_2); - - Operator op1("op1"); - Operator op2("op2"); - Operator op3("op3"); - graph.AddOp(op1); - graph.AddOp(op2); - graph.AddOp(op3); - - auto ret = GraphUtilsEx::CopyGraph(graph, graph2); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestOperater, CopyOperators2) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto transdata = builder.AddNode("Transdata", "Transdata", 2, 1); - auto op_desc = transdata->GetOpDesc(); - op_desc->impl_->input_name_idx_["Data"] = 0; - op_desc->impl_->input_name_idx_["Enter"] = 1; - auto data = builder.AddNode("Data", "Data", 0, 1); - auto data2 = builder.AddNode("Data2", "Data", 0, 1); - - Operator op1 = OpDescUtils::CreateOperatorFromNode(transdata); - Operator op2 = OpDescUtils::CreateOperatorFromNode(data); - Operator op3 = OpDescUtils::CreateOperatorFromNode(data2); - - ComputeGraphPtr compt_graph = builder.GetGraph(); - Graph graph = GraphUtilsEx::CreateGraphFromComputeGraph(compt_graph); - graph.AddOp(op1); - graph.AddOp(op2); - graph.AddOp(op3); - - ut::GraphBuilder builder2 = ut::GraphBuilder("graph2"); - auto data3 = builder2.AddNode("Data3", "Data", 0, 1); - ComputeGraphPtr compt_graph2 = builder2.GetGraph(); - Graph graph2 = GraphUtilsEx::CreateGraphFromComputeGraph(compt_graph2); - - auto ret = GraphUtilsEx::CopyGraph(graph, graph2); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestOperater, CopyOperators3) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto transdata = builder.AddNode("Transdata", "Transdata", 2, 1); - auto op_desc = transdata->GetOpDesc(); - op_desc->impl_->input_name_idx_["Data"] = 0; - op_desc->impl_->input_name_idx_["Enter"] = 1; - auto data = builder.AddNode("Data", "Data", 0, 1); - auto data2 = builder.AddNode("Data2", "Data", 0, 1); - - Operator op1 = OpDescUtils::CreateOperatorFromNode(transdata); - Operator op2 = OpDescUtils::CreateOperatorFromNode(data); - Operator op3 = OpDescUtils::CreateOperatorFromNode(data2); - - ComputeGraphPtr compt_graph = builder.GetGraph(); - Graph src_graph = GraphUtilsEx::CreateGraphFromComputeGraph(compt_graph); - src_graph.AddOp(op1); - src_graph.AddOp(op2); - src_graph.AddOp(op3); - - ut::GraphBuilder builder2 = ut::GraphBuilder("graph2"); - auto data3 = builder2.AddNode("Data3", "Data", 0, 1); - ComputeGraphPtr dst_compute_graph = builder2.GetGraph(); - Graph dst_graph = GraphUtilsEx::CreateGraphFromComputeGraph(dst_compute_graph); - - std::map src_op_list = {{string("op1"), op1}, {string("op2"), op2}, {string("op3"), op3}}; - std::map dst_op_list; - - std::map node_old_2_new; - std::map op_desc_old_2_new; - - auto ret = OpDescUtils::CopyOperators(dst_compute_graph, node_old_2_new, op_desc_old_2_new, src_op_list, dst_op_list); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestOperater, TestCallbackToGetConstInputWithRuntimeInferenceContext) { - // new a tensor - ge::GeTensorPtr tensor = std::make_shared(); - std::vector value{1, 2, 3}; - std::vector shape{3}; - tensor->MutableTensorDesc().SetShape(GeShape(shape)); - tensor->SetData(value); - tensor->MutableTensorDesc().SetDataType(DT_UINT8); - - // define callback - RuntimeInferenceContext runtime_ctx; - OperatorImpl::GetConstInputOnRuntimeFun func_get_input_const = - [&runtime_ctx](const ConstNodePtr &node, const size_t index, ge::GeTensorPtr &dst_tensor) { - // from runtime context - const auto in_data_anchor = node->GetInDataAnchor(static_cast(index)); - const auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); - auto peer_node = out_data_anchor->GetOwnerNode(); - GeTensorPtr tensor_value = nullptr; - if (runtime_ctx.GetTensor(peer_node->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), tensor_value) == - GRAPH_SUCCESS) { - dst_tensor = tensor_value; - return GRAPH_SUCCESS; - } - return ge::GRAPH_SUCCESS; - }; - - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto transdata = builder.AddNode("Transdata", "Transdata", 2, 1); - auto op_desc = transdata->GetOpDesc(); - op_desc->impl_->input_name_idx_["Data"] = 0; - op_desc->impl_->input_name_idx_["Enter"] = 1; - auto data = builder.AddNode("Data", "Data", 0, 1); - auto data2 = builder.AddNode("Data2", "Data", 0, 1); - GraphUtils::AddEdge(data->GetOutDataAnchor(0), transdata->GetInDataAnchor(0)); - GraphUtils::AddEdge(data2->GetOutDataAnchor(0), transdata->GetInDataAnchor(1)); - Operator op1 = OpDescUtils::CreateOperatorFromNode(transdata); - - OpDescUtils::SetCallbackGetConstInputFuncToOperator(op1, func_get_input_const); - - int output_id = 0; - runtime_ctx.SetTensor(data->GetOpDesc()->GetId(), output_id, std::move(tensor)); - - Tensor test_tensor; - std::string input_name = "Data"; - EXPECT_EQ(op1.GetInputConstData(input_name.c_str(), test_tensor), GRAPH_SUCCESS); - EXPECT_EQ(test_tensor.GetSize(), value.size()); // 3 item in tensor - auto const_data = reinterpret_cast(test_tensor.GetData()); - for (size_t i = 0; i < 3; ++i) { - EXPECT_EQ(const_data[i], value[i]); - } -} -/* - * Foo11 - * | - * Foo01 - */ -TEST_F(UtestOperater, SetInput_Success_SingleIOByStrName) { - auto foo01 = op::Foo01("foo01"); - auto foo11 = op::Foo11("foo11"); - foo11.SetInput(std::string("x"), foo01); - Graph graph("graph"); - graph.SetInputs({foo01}); - CheckTopoGraph1(graph); -} -/* - * Foo11 - * | - * Foo01 - */ -TEST_F(UtestOperater, SetInput_Success_SingleIOByCharName) { - auto foo01 = op::Foo01("foo01"); - auto foo11 = op::Foo11("foo11"); - foo11.SetInput("x", foo01); - Graph graph("graph"); - graph.SetInputs({foo01}); - CheckTopoGraph1(graph); -} - -TEST_F(UtestOperater, SetInput_Failed_NullName) { - auto foo01 = op::Foo01("foo01"); - auto foo11 = op::Foo11("foo11"); - foo11.SetInput(nullptr, foo01); - Graph graph("graph"); - graph.SetInputs({foo01}); - auto compute_graph = GraphUtilsEx::GetComputeGraph(graph); - ASSERT_NE(compute_graph, nullptr); - ASSERT_EQ(gert::SummaryChecker(compute_graph).StrictAllNodeTypes({{"Foo01", 1}}), "success"); -} - -TEST_F(UtestOperater, SetInput_Opdesc_Null) { - auto invalid_op = ge::Operator("invalid"); - invalid_op.operator_impl_->op_desc_ = nullptr; - auto foo11 = op::Foo11("foo11"); - foo11.SetInput("x", invalid_op); - auto op_io = ge::OpIO("op_io", 0U, invalid_op.operator_impl_); - ASSERT_EQ(foo11.operator_impl_->GetInputImpl("x", op_io), GRAPH_FAILED); -} - -TEST_F(UtestOperater, SetInput_OutputSize_Zero) { - auto invalid_op = ge::Operator("invalid"); - ASSERT_EQ(invalid_op.GetOutputsSize(), 0U); - auto foo11 = op::Foo11("foo11"); - foo11.SetInput("x", invalid_op); - auto op_io = ge::OpIO("op_io", 0U, invalid_op.operator_impl_); - ASSERT_EQ(foo11.operator_impl_->GetInputImpl("x", op_io), GRAPH_FAILED); -} - -TEST_F(UtestOperater, GetOutput_Failed_NullName) { - auto foo01 = op::Foo01("foo01"); - auto foo11 = op::Foo11("foo11"); - foo11.SetInput("", foo01.GetOutput(nullptr)); - Graph graph("graph"); - graph.SetInputs({foo01}); - auto compute_graph = GraphUtilsEx::GetComputeGraph(graph); - ASSERT_NE(compute_graph, nullptr); - ASSERT_EQ(gert::SummaryChecker(compute_graph).StrictAllNodeTypes({{"Foo01", 1}}), "success"); -} -/* - * Foo22 - * / | - * Foo11 | - * \ | - * Foo02 - */ -TEST_F(UtestOperater, SetInput_Success_TwoByStrName) { - auto foo02 = op::Foo02("foo02"); - auto foo11 = op::Foo11("foo11").SetInput(std::string("x"), foo02, std::string("x")); - auto foo22 = op::Foo22("foo22") - .SetInput(std::string("m"), foo11, std::string("y")) - .SetInput(std::string("n"), foo02, std::string("y")); - - Graph graph("graph"); - graph.SetInputs({foo02}); - - CheckTopoGraph2(graph); -} - -TEST_F(UtestOperater, SetInput_Success_TwoByCharName) { - auto foo02 = op::Foo02("foo02"); - auto foo11 = op::Foo11("foo11").SetInput("x", foo02, "x"); - auto foo22 = op::Foo22("foo22").SetInput("m", foo11, "y").SetInput("n", foo02, "y"); - - Graph graph("graph"); - graph.SetInputs({foo02}); - - CheckTopoGraph2(graph); -} - -TEST_F(UtestOperater, SetInput_Success_TwoByIndex) { - auto foo02 = op::Foo02("foo02"); - auto foo11 = op::Foo11("foo11").SetInput("x", foo02, 0U); - auto foo22 = op::Foo22("foo22").SetInput("m", foo11, 0U).SetInput("n", foo02, 1U); - - Graph graph("graph"); - graph.SetInputs({foo02}); - - CheckTopoGraph2(graph); -} - -TEST_F(UtestOperater, SetInput_Success_TwoByStrNameIndexedHandler) { - auto foo02 = op::Foo02("foo02"); - auto foo11 = op::Foo11("foo11").SetInput(std::string("x"), foo02.GetOutput(std::string("x"))); - auto foo22 = op::Foo22("foo22") - .SetInput(std::string("m"), foo11.GetOutput(std::string("y"))) - .SetInput(std::string("n"), foo02.GetOutput(std::string("y"))); - - Graph graph("graph"); - graph.SetInputs({foo02}); - - CheckTopoGraph2(graph); -} - -TEST_F(UtestOperater, SetInput_Success_TwoByCharNameIndexedHandler) { - auto foo02 = op::Foo02("foo02"); - auto foo11 = op::Foo11("foo11").SetInput("x", foo02.GetOutput("x")); - auto foo22 = op::Foo22("foo22").SetInput("m", foo11.GetOutput("y")).SetInput("n", foo02.GetOutput("y")); - - Graph graph("graph"); - graph.SetInputs({foo02}); - - CheckTopoGraph2(graph); -} - -/* - * Foo22 - * / |d0 \d1 - * Foo11 | | - * \0 |0 /1 - * Foo02 - */ -TEST_F(UtestOperater, SetDynamicInput_Success_TwoByStrName) { - auto foo02 = op::Foo02("foo02"); - auto foo11 = op::Foo11("foo11").SetInput(std::string("x"), foo02, std::string("x")); - auto foo22 = op::DFoo22("foo22") - .create_dynamic_input_n(2) - .SetInput(std::string("m"), foo11, std::string("y")) - .SetInput(std::string("n"), 0, foo02, std::string("x")) - .SetInput(std::string("n"), 1, foo02, std::string("y")); - - Graph graph("graph"); - graph.SetInputs({foo02}); - - CheckTopoGraph3(graph); -} -TEST_F(UtestOperater, SetDynamicInput_Success_TwoByCharName) { - auto foo02 = op::Foo02("foo02"); - auto foo11 = op::Foo11("foo11").SetInput("x", foo02, "x"); - auto foo22 = op::DFoo22("foo22") - .create_dynamic_input_n(2) - .SetInput("m", foo11, "y") - .SetInput("n", 0, foo02, "x") - .SetInput("n", 1, foo02, "y"); - - Graph graph("graph"); - graph.SetInputs({foo02}); - - CheckTopoGraph3(graph); -} -/* - * DFoo22 - * / | \ - * Foo11 Foo11 Foo11 - */ -TEST_F(UtestOperater, SetDynamicInput_Success_SingleOutput) { - auto foo01_0 = op::Foo01("foo01_0"); - auto foo01_1 = op::Foo01("foo01_1"); - auto foo01_2 = op::Foo01("foo01_2"); - auto foo22 = op::DFoo22("foo22") - .create_dynamic_input_n(2) - .SetInput("m", foo01_0) - .SetInput(std::string("n"), 0, foo01_1) - .SetInput("n", 1, foo01_2); - - Graph graph("graph"); - graph.SetInputs({foo01_0, foo01_1, foo01_2}); - - auto compute_graph = GraphUtilsEx::GetComputeGraph(graph); - ASSERT_NE(compute_graph, nullptr); - ASSERT_EQ(gert::SummaryChecker(compute_graph).StrictAllNodeTypes({{"Foo01", 3}, {"DFoo22", 1}}), "success"); - auto foo22_node = compute_graph->FindNode("foo22"); - ASSERT_NE(foo22_node, nullptr); - ASSERT_EQ(gert::NodeTopoChecker(foo22_node).StrictConnectFrom({{"Foo01"}, {"Foo01"}, {"Foo01"}}), "success"); -} -TEST_F(UtestOperater, InputRegister_Success_ByString) { - Operator op("Op", "Op"); - op.InputRegister(std::string("x")); - op.InputRegister(std::string("y")); - op.OptionalInputRegister(std::string("o")); - op.DynamicInputRegister(std::string("d"), 0, true); - - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - std::vector> expected{{"x", kIrInputRequired}, - {"y", kIrInputRequired}, - {"o", kIrInputOptional}, - {"d", kIrInputDynamic}}; - ASSERT_EQ(op_desc->GetIrInputs(), expected); -} -TEST_F(UtestOperater, InputRegister_Success_ByChar) { - Operator op("Op", "Op"); - op.InputRegister("x"); - op.InputRegister("y"); - op.OptionalInputRegister("o"); - op.DynamicInputRegister("d", 0, true); - - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - std::vector> expected{{"x", kIrInputRequired}, - {"y", kIrInputRequired}, - {"o", kIrInputOptional}, - {"d", kIrInputDynamic}}; - ASSERT_EQ(op_desc->GetIrInputs(), expected); -} -TEST_F(UtestOperater, InputRegister_Failed_NullptrChar) { - Operator op("Op", "Op"); - op.InputRegister(nullptr); - op.InputRegister(nullptr); - op.OptionalInputRegister(nullptr); - op.DynamicInputRegister(nullptr, 0, true); - - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - ASSERT_TRUE(op_desc->GetIrInputs().empty()); -} -TEST_F(UtestOperater, OutputRegister_Success) { - Operator op("Op", "Op"); - op.OutputRegister(std::string("x")); - op.OutputRegister("y"); - op.OutputRegister(std::string("m")); - op.OutputRegister("n"); - - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - ASSERT_EQ(op_desc->GetAllOutputsDescSize(), 4); - EXPECT_EQ(op_desc->GetOutputIndexByName("x"), 0); - EXPECT_EQ(op_desc->GetOutputIndexByName("y"), 1); - EXPECT_EQ(op_desc->GetOutputIndexByName("m"), 2); - EXPECT_EQ(op_desc->GetOutputIndexByName("n"), 3); -} -TEST_F(UtestOperater, DynamicInputRegister_Success_InsertCharDynamicInput) { - Operator op("Op", "Op"); - op.InputRegister("x"); - op.InputRegister("y"); - op.InputRegister("z"); - op.DynamicInputRegisterByIndex("d", 2, 1); - - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - ASSERT_EQ(op_desc->GetAllInputsDesc().size(), 5); - EXPECT_EQ(op_desc->GetInputIndexByName("x"), 0); - EXPECT_EQ(op_desc->GetInputIndexByName("d0"), 1); - EXPECT_EQ(op_desc->GetInputIndexByName("d1"), 2); - EXPECT_EQ(op_desc->GetInputIndexByName("y"), 3); - EXPECT_EQ(op_desc->GetInputIndexByName("z"), 4); - std::vector input_indexes; - EXPECT_EQ(op_desc->GetDynamicInputIndexesByName("d", input_indexes), GRAPH_SUCCESS); - std::vector expect_indexes{1, 2}; - EXPECT_EQ(input_indexes, expect_indexes); -} -TEST_F(UtestOperater, DynamicInputRegister_Failed_Nullptr) { - Operator op("Op", "Op"); - op.InputRegister("x"); - op.InputRegister("y"); - op.InputRegister("z"); - op.DynamicInputRegisterByIndex(nullptr, 2, 1); - - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - ASSERT_EQ(op_desc->GetAllInputsDesc().size(), 3); - EXPECT_EQ(op_desc->GetInputIndexByName("x"), 0); - EXPECT_EQ(op_desc->GetInputIndexByName("y"), 1); - EXPECT_EQ(op_desc->GetInputIndexByName("z"), 2); -} -TEST_F(UtestOperater, DynamicInputRegister_Success_InsertStrDynamicInput) { - Operator op("Op", "Op"); - op.InputRegister(std::string("x")); - op.InputRegister(std::string("y")); - op.InputRegister(std::string("z")); - op.DynamicInputRegisterByIndex(std::string("d"), 2, 1); - - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - ASSERT_EQ(op_desc->GetAllInputsDesc().size(), 5); - EXPECT_EQ(op_desc->GetInputIndexByName("x"), 0); - EXPECT_EQ(op_desc->GetInputIndexByName("d0"), 1); - EXPECT_EQ(op_desc->GetInputIndexByName("d1"), 2); - std::vector input_indexes; - EXPECT_EQ(op_desc->GetDynamicInputIndexesByName("d", input_indexes), GRAPH_SUCCESS); - std::vector expect_indexes{1, 2}; - EXPECT_EQ(input_indexes, expect_indexes); - EXPECT_EQ(op_desc->GetInputIndexByName("y"), 3); - EXPECT_EQ(op_desc->GetInputIndexByName("z"), 4); -} -TEST_F(UtestOperater, DynamicInputRegister_Success_DuplicateIrInput) { - Operator op("Op", "Op"); - op.InputRegister(std::string("x")); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - ASSERT_EQ(op_desc->GetIrInputsSize(), 1); - - op.DynamicInputRegister(std::string("x"), 0, true); - op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_EQ(op_desc->GetIrInputsSize(), 1); -} -TEST_F(UtestOperater, GetDynamicInputNum_Success) { - Operator op("Op", "Op"); - op.DynamicInputRegister("x", 5); - op.DynamicInputRegister("y", 4); - EXPECT_EQ(op.GetDynamicInputNum("x"), 5); - EXPECT_EQ(op.GetDynamicInputNum("y"), 4); - EXPECT_EQ(op.GetDynamicInputNum("z"), 0); - EXPECT_EQ(op.GetDynamicInputNum(std::string("x")), 5); - EXPECT_EQ(op.GetDynamicInputNum(std::string("y")), 4); - EXPECT_EQ(op.GetDynamicInputNum(std::string("z")), 0); - EXPECT_EQ(op.GetDynamicInputNum(nullptr), 0); -} -TEST_F(UtestOperater, DynamicOutputRegister_Success) { - Operator op("Op", "Op"); - op.DynamicOutputRegister(std::string("x"), 2, true); - op.DynamicOutputRegister("y", 3, true); - - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - ASSERT_EQ(op_desc->GetAllOutputsDescSize(), 5); - EXPECT_EQ(op_desc->GetOutputIndexByName("x0"), 0); - EXPECT_EQ(op_desc->GetOutputIndexByName("x1"), 1); - std::vector indexes1; - EXPECT_EQ(op_desc->GetDynamicOutputIndexesByName("x", indexes1), GRAPH_SUCCESS); - std::vector expect_indexes1{0, 1}; - EXPECT_EQ(indexes1, expect_indexes1); - EXPECT_EQ(op_desc->GetOutputIndexByName("y0"), 2); - EXPECT_EQ(op_desc->GetOutputIndexByName("y1"), 3); - EXPECT_EQ(op_desc->GetOutputIndexByName("y2"), 4); - std::vector indexes2; - EXPECT_EQ(op_desc->GetDynamicOutputIndexesByName("y", indexes2), GRAPH_SUCCESS); - std::vector expect_indexes2{2, 3, 4}; - EXPECT_EQ(indexes2, expect_indexes2); -} -TEST_F(UtestOperater, DynamicOutputRegister_duplicate_ir_output_Success) { - Operator op("Op", "Op"); - op.DynamicOutputRegister("y", 0, true); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - ASSERT_EQ(op_desc->GetIrOutputs().size(), 1); - - // register duplicated - op.DynamicOutputRegister("y", 0, true); - op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - ASSERT_EQ(op_desc->GetIrOutputs().size(), 1); -} -TEST_F(UtestOperater, GetDynamicOutputNum_Success) { - Operator op("Op", "Op"); - op.DynamicOutputRegister("x", 5); - op.DynamicOutputRegister("y", 4); - EXPECT_EQ(op.GetDynamicOutputNum("x"), 5); - EXPECT_EQ(op.GetDynamicOutputNum("y"), 4); - EXPECT_EQ(op.GetDynamicOutputNum("z"), 0); - EXPECT_EQ(op.GetDynamicOutputNum(std::string("x")), 5); - EXPECT_EQ(op.GetDynamicOutputNum(std::string("y")), 4); - EXPECT_EQ(op.GetDynamicOutputNum(std::string("z")), 0); - EXPECT_EQ(op.GetDynamicOutputNum(nullptr), 0); -} - -/* - * Foo11 - * | - * Foo01 - */ -TEST_F(UtestOperater, SetInput_Success_DoNotPassTensorAttrs) { - auto foo01 = op::Foo01("foo01"); - auto foo11 = op::Foo11("foo11"); - foo01.SetOutputAttr(0, "foo01_output_attr", 1); - foo11.SetInputAttr(0, "foo11_input_attr", 1); - foo11.SetInput("x", foo01); - Graph graph("graph"); - graph.SetInputs({foo01}); - CheckTopoGraph1(graph); - int64_t value = 0; - EXPECT_EQ(foo11.GetInputAttr(0, "foo01_output_attr", value), GRAPH_SUCCESS); - EXPECT_TRUE(value == 0); - EXPECT_EQ(foo11.GetInputAttr(0, "foo11_input_attr", value), GRAPH_SUCCESS); - EXPECT_TRUE(value == 1); -} - -TEST_F(UtestOperater, GetInputConstData_While_fail) { - auto graph = BuildWhileGraphWithConstInput(); - auto nodes = graph->GetAllNodes(); - NodePtr reshape = nullptr; - for (const auto &n : nodes) { - if (n->GetType() == "Reshape") { - reshape = n; - break; - } - } - ASSERT_TRUE(reshape != nullptr); - Tensor tensor; - auto op = OpDescUtils::CreateOperatorFromNode(reshape); - ASSERT_EQ(op.GetInputConstData("x", tensor), GRAPH_FAILED); - ASSERT_EQ(op.GetInputConstData("shape", tensor), GRAPH_SUCCESS); - ASSERT_EQ(tensor.GetSize(), 3); -} -TEST_F(UtestOperater, WeakLife) { - OperatorKeeper::GetInstance().ClearInvalidOp(); - auto old_size = OperatorKeeper::GetInstance().operators_.size(); - { auto op = ge::OperatorFactory::CreateOperator("test", "Const"); } - EXPECT_EQ(OperatorKeeper::GetInstance().operators_.size(), old_size + 1U); - // op的生命周期结束后,因keeper单例对其是弱引用,所以weak转shared是空被清理 - OperatorKeeper::GetInstance().ClearInvalidOp(); - EXPECT_TRUE(OperatorKeeper::GetInstance().operators_.size() == old_size); -} - -TEST_F(UtestOperater, UpdateInputOutDesc_Failed_Uninitialized) { - // not init - Operator op; - TensorDesc td; - auto ret = op.UpdateInputDesc(0U, td); - EXPECT_EQ(ret, GRAPH_FAILED); - ret = op.UpdateOutputDesc(0U, td); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestOperater, UpdateInputOutDesc_Failed_OutOfRange) { - // index out of range - Operator op("Test"); - TensorDesc td; - AscendString type; - op.GetOpType(type); - EXPECT_EQ(std::strcmp(type.GetString(), "Test"), 0); - auto ret = op.UpdateInputDesc(0U, td); - EXPECT_EQ(ret, GRAPH_FAILED); - ret = op.UpdateOutputDesc(0U, td); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(UtestOperater, UpdateInputOutDesc_Success) { - // update ok - Operator op("Test"); - auto dims = std::vector{1, 2, 3, 4}; - TensorDesc td(Shape(dims), FORMAT_NCHW); - op.InputRegister("x0"); - op.InputRegister("x1"); - op.OutputRegister("y"); - - auto ret = op.UpdateInputDesc(1U, td); - EXPECT_EQ(ret, GRAPH_SUCCESS); - auto input_desc_1 = op.GetInputDesc(1U); - EXPECT_EQ(input_desc_1.GetShape().GetDims(), dims); - EXPECT_EQ(input_desc_1.GetFormat(), FORMAT_NCHW); - - ret = op.UpdateOutputDesc(0U, td); - EXPECT_EQ(ret, GRAPH_SUCCESS); - auto output_desc_0 = op.GetOutputDesc(0U); - EXPECT_EQ(output_desc_0.GetShape().GetDims(), dims); - EXPECT_EQ(output_desc_0.GetFormat(), FORMAT_NCHW); -} - -TEST_F(UtestOperater, AttrRegister_WithAttrValue_Success) { - Operator op("Test"); - - // 创建 AttrValue 对象 - AttrValue attr_value; - int64_t int_val = 12345; - EXPECT_EQ(attr_value.SetAttrValue(int_val), GRAPH_SUCCESS); - - // 测试正常情况 - op.AttrRegister("test_attr", attr_value); - - // 验证属性是否设置成功 - int64_t get_val = 0; - EXPECT_EQ(op.GetAttr("test_attr", get_val), GRAPH_SUCCESS); - EXPECT_EQ(get_val, int_val); -} - -TEST_F(UtestOperater, AttrRegister_WithAttrValue_NullName) { - Operator op("Test"); - - AttrValue attr_value; - int64_t int_val = 12345; - EXPECT_EQ(attr_value.SetAttrValue(int_val), GRAPH_SUCCESS); - - // 测试空指针情况 - op.AttrRegister(nullptr, attr_value); - - // 验证属性没有设置成功 - int64_t get_val = 0; - EXPECT_NE(op.GetAttr("test_attr", get_val), GRAPH_SUCCESS); -} - -TEST_F(UtestOperater, AttrRegister_WithAttrValue_NullImpl) { - Operator op; - - AttrValue attr_value; - int64_t int_val = 12345; - EXPECT_EQ(attr_value.SetAttrValue(int_val), GRAPH_SUCCESS); - - // 测试空实现情况 - op.AttrRegister("test_attr", attr_value); - - // 验证属性没有设置成功 - int64_t get_val = 0; - EXPECT_NE(op.GetAttr("test_attr", get_val), GRAPH_SUCCESS); -} - -TEST_F(UtestOperater, SetSubgraphInstanceName_Failded) { - Operator op("Test"); - - EXPECT_NE(op.SetSubgraphInstanceName(0, "subgraph_0"), GRAPH_SUCCESS); - EXPECT_NE(op.SetSubgraphInstanceName(1, "subgraph_1"), GRAPH_SUCCESS); - EXPECT_NE(op.SetSubgraphInstanceName(2, "subgraph_2"), GRAPH_SUCCESS); -} - -TEST_F(UtestOperater, SetSubgraphInstanceName_Success) { - Operator op("Test"); - op.SubgraphRegister("static", false); - op.SubgraphCountRegister("static", 1U); - EXPECT_EQ(op.SetSubgraphInstanceName(0, "subgraph_0"), GRAPH_SUCCESS); - // 测试异常情况 - EXPECT_NE(op.SetSubgraphInstanceName(1, "subgraph_1"), GRAPH_SUCCESS); -} - -TEST_F(UtestOperater, OperatorImpl_SetAttr_WithConstAnyValue_MultipleAttributes) { - auto op_impl = std::make_shared("Test", "Test"); - - // 设置多个不同类型的属性 - AnyValue int_any_value; - int64_t int_val = 12345; - int_any_value.SetValue(int_val); - EXPECT_EQ(op_impl->SetAttr("int_attr", int_any_value), GRAPH_SUCCESS); - - AnyValue str_any_value; - std::string str_val = "test_string"; - str_any_value.SetValue(str_val); - EXPECT_EQ(op_impl->SetAttr("str_attr", str_any_value), GRAPH_SUCCESS); - - AnyValue float_any_value; - float32_t float_val = 3.14159f; - float_any_value.SetValue(float_val); - EXPECT_EQ(op_impl->SetAttr("float_attr", float_any_value), GRAPH_SUCCESS); - - // 验证所有属性都设置成功 - AnyValue get_int_any_value, get_str_any_value, get_float_any_value; - EXPECT_EQ(op_impl->GetAttr("int_attr", get_int_any_value), GRAPH_SUCCESS); - EXPECT_EQ(op_impl->GetAttr("str_attr", get_str_any_value), GRAPH_SUCCESS); - EXPECT_EQ(op_impl->GetAttr("float_attr", get_float_any_value), GRAPH_SUCCESS); - - int64_t get_int_val = 0; - std::string get_str_val; - float32_t get_float_val = 0.0f; - - EXPECT_EQ(get_int_any_value.GetValue(get_int_val), GRAPH_SUCCESS); - EXPECT_EQ(get_str_any_value.GetValue(get_str_val), GRAPH_SUCCESS); - EXPECT_EQ(get_float_any_value.GetValue(get_float_val), GRAPH_SUCCESS); - - EXPECT_EQ(get_int_val, int_val); - EXPECT_EQ(get_str_val, str_val); - EXPECT_FLOAT_EQ(get_float_val, float_val); -} -// extern "C" wrapper for Operator methods to avoid C++ name mangling -extern "C" { -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus aclCom_Operator_AttrRegister(void *op_ptr, const char *name, - const void *attr_value) { - if (op_ptr == nullptr || name == nullptr || attr_value == nullptr) { - return GRAPH_FAILED; - } - auto *op = static_cast(op_ptr); - auto *value = static_cast(attr_value); - op->AttrRegister(name, *value); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus aclCom_Operator_SetAttr(void *op_ptr, const char *name, - const void *attr_value) { - if (op_ptr == nullptr || name == nullptr || attr_value == nullptr) { - return GRAPH_FAILED; - } - auto *op = static_cast(op_ptr); - auto *value = static_cast(attr_value); - - *op = op->SetAttr(name, *value); - return GRAPH_SUCCESS; -} - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus aclCom_Operator_SetSubgraphInstanceName(void *op_ptr, - uint32_t index, - const char *name) { - if (op_ptr == nullptr || name == nullptr) { - return GRAPH_FAILED; - } - auto *op = static_cast(op_ptr); - return op->SetSubgraphInstanceName(index, name); -} -} -TEST_F(UtestOperater, ExternC_Operator_AttrRegister_Success) { - Operator op("test_op", "TestOp"); - AttrValue attr_value; - attr_value.SetAttrValue(static_cast(12345)); - - // 测试成功情况 - EXPECT_EQ(aclCom_Operator_AttrRegister(&op, "test_attr", &attr_value), GRAPH_SUCCESS); - - // 验证注册的属性 - AttrValue get_value; - EXPECT_EQ(op.GetAttr("test_attr", get_value), GRAPH_SUCCESS); - int64_t int_value = 0; - EXPECT_EQ(get_value.GetAttrValue(int_value), GRAPH_SUCCESS); - EXPECT_EQ(int_value, 12345); - - // 测试nullptr参数 - EXPECT_EQ(aclCom_Operator_AttrRegister(nullptr, "test_attr", &attr_value), GRAPH_FAILED); - EXPECT_EQ(aclCom_Operator_AttrRegister(&op, nullptr, &attr_value), GRAPH_FAILED); -} - -TEST_F(UtestOperater, ExternC_Operator_SetAttr_Success) { - Operator op("test_op", "TestOp"); - AttrValue attr_value; - attr_value.SetAttrValue(static_cast(12345)); - - // 测试成功情况 - EXPECT_EQ(aclCom_Operator_SetAttr(&op, "test_attr", &attr_value), GRAPH_SUCCESS); - - // 验证设置的属性 - AttrValue get_value; - EXPECT_EQ(op.GetAttr("test_attr", get_value), GRAPH_SUCCESS); - int64_t int_value = 0; - EXPECT_EQ(get_value.GetAttrValue(int_value), GRAPH_SUCCESS); - EXPECT_EQ(int_value, 12345); - - // 测试nullptr参数 - EXPECT_EQ(aclCom_Operator_SetAttr(nullptr, "test_attr", &attr_value), GRAPH_FAILED); - EXPECT_EQ(aclCom_Operator_SetAttr(&op, nullptr, &attr_value), GRAPH_FAILED); -} - -TEST_F(UtestOperater, ExternC_Operator_SetSubgraphInstanceName_Success) { - Operator op("test_op", "TestOp"); - op.SubgraphRegister("subgraph", true); - op.SubgraphCountRegister("subgraph", 2); - // 测试成功情况 - EXPECT_EQ(aclCom_Operator_SetSubgraphInstanceName(&op, 0, "subgraph_0"), GRAPH_SUCCESS); - EXPECT_EQ(aclCom_Operator_SetSubgraphInstanceName(&op, 1, "subgraph_1"), GRAPH_SUCCESS); - - // 测试nullptr参数 - EXPECT_EQ(aclCom_Operator_SetSubgraphInstanceName(nullptr, 0, "subgraph_0"), GRAPH_FAILED); - EXPECT_EQ(aclCom_Operator_SetSubgraphInstanceName(&op, 0, nullptr), GRAPH_FAILED); -} - -TEST_F(UtestOperater, ExternC_Operator_ComplexAttrValue) { - Operator op("test_op", "TestOp"); - - // 测试复杂类型的AttrValue - AttrValue complex_attr; - std::vector vec_value = {1, 2, 3, 4, 5}; - complex_attr.SetAttrValue(vec_value); - - EXPECT_EQ(aclCom_Operator_AttrRegister(&op, "complex_attr", &complex_attr), GRAPH_SUCCESS); - - // 验证复杂属性 - AttrValue get_complex_attr; - EXPECT_EQ(op.GetAttr("complex_attr", get_complex_attr), GRAPH_SUCCESS); - std::vector get_vec_value; - EXPECT_EQ(get_complex_attr.GetAttrValue(get_vec_value), GRAPH_SUCCESS); - EXPECT_EQ(get_vec_value, vec_value); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/opp_package_utils_unittest.cc b/tests/ut/graph/testcase/opp_package_utils_unittest.cc deleted file mode 100644 index a101512717a75e6d0d868d1c82430baae981b6d6..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/opp_package_utils_unittest.cc +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "base/registry/opp_package_utils.h" - -namespace gert { -class UtestOppPackageUtils : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -TEST_F(UtestOppPackageUtils, OppSoDescPathsTestSuccess) { - std::vector opp_path_vector = {"/path/opp.so"}; - OppSoDesc so_desc(opp_path_vector, "pkg_name"); - EXPECT_EQ(so_desc.GetSoPaths()[0], "/path/opp.so"); - EXPECT_EQ(so_desc.GetPackageName(), "pkg_name"); - - auto so_desc2 = so_desc; - EXPECT_EQ(so_desc2.GetSoPaths()[0], "/path/opp.so"); - EXPECT_EQ(so_desc2.GetPackageName(), "pkg_name"); - - OppSoDesc so_desc3(std::move(so_desc2)); - EXPECT_EQ(so_desc3.GetSoPaths()[0], "/path/opp.so"); - EXPECT_EQ(so_desc3.GetPackageName(), "pkg_name"); - - so_desc3 = std::move(so_desc); - EXPECT_EQ(so_desc3.GetSoPaths()[0], "/path/opp.so"); - EXPECT_EQ(so_desc3.GetPackageName(), "pkg_name"); - - // 代码覆盖 - so_desc = so_desc; - so_desc = std::move(so_desc); - - OppSoDesc so_desc4(so_desc3); - EXPECT_EQ(so_desc4.GetSoPaths()[0], "/path/opp.so"); - EXPECT_EQ(so_desc4.GetPackageName(), "pkg_name"); - - so_desc = so_desc4; - EXPECT_EQ(so_desc.GetSoPaths()[0], "/path/opp.so"); - EXPECT_EQ(so_desc.GetPackageName(), "pkg_name"); -} - -} // namespace gert diff --git a/tests/ut/graph/testcase/opsproto_manager_unittest.cc b/tests/ut/graph/testcase/opsproto_manager_unittest.cc deleted file mode 100644 index c3816ada3ccea96eed0c690c943e6f6711f2b702..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/opsproto_manager_unittest.cc +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include - -#include "graph/opsproto_manager.h" -#include "tests/depends/mmpa/src/mmpa_stub.h" - -using namespace std; -using namespace testing; - -namespace ge { -namespace { -class MockMmpa : public ge::MmpaStubApi { - public: - void *DlOpen(const char *fileName, int32_t mode) override { - return (void *) 0xffffffff; - } -}; -} -class OpsprotoManagerUt : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -TEST_F(OpsprotoManagerUt, Instance_Initialize_Finalize) { - OpsProtoManager *opspm = nullptr; - opspm = OpsProtoManager::Instance(); - EXPECT_NE(opspm, nullptr); - - const std::map options = { - {"ge.mockLibPath", "./gtest_build-prefix/src/gtest_build-build/googlemock/"}, - {"ge.opsProtoLibPath", "./protobuf_build-prefix/src/protobuf_build-build/"}}; - - auto ret = opspm->Initialize(options); - EXPECT_TRUE(ret); - - opspm->Finalize(); - EXPECT_EQ(opspm->handles_.size(), 0); - EXPECT_EQ(opspm->is_init_, false); -} - -TEST_F(OpsprotoManagerUt, LoadOpsProtoPluginSo) { - OpsProtoManager *opspm = OpsProtoManager::Instance(); - opspm->LoadOpsProtoPluginSo(""); - opspm->LoadOpsProtoPluginSo("./protobuf_build-prefix/src/protobuf_build-build/"); - - EXPECT_EQ(opspm->handles_.size(), 0); -} - -TEST_F(OpsprotoManagerUt, LoadOpsProtoPluginSo_Exclude_rt) { - std::string path = __FILE__; - path = path.substr(0, path.rfind("/") + 1) + "opp_test/"; - system(("mkdir -p " + path + "/lib").c_str()); - system(("touch " + path + "/lib/libopsproto.so").c_str()); - system(("touch " + path + "/lib/libopsproto_rt2.0.so").c_str()); - system(("touch " + path + "/lib/libopsproto_rt.so").c_str()); - - ge::MmpaStub::GetInstance().SetImpl(std::make_shared()); - OpsProtoManager::Instance()->handles_.clear(); - OpsProtoManager::Instance()->LoadOpsProtoPluginSo(path); - ASSERT_EQ(OpsProtoManager::Instance()->handles_.size(), 1); - system(("rm -rf " + path).c_str()); - MmpaStub::GetInstance().Reset(); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/optimization_option_unittest.cc b/tests/ut/graph/testcase/optimization_option_unittest.cc deleted file mode 100644 index 3bb570620727c41bd747378ef8420a75911a25a2..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/optimization_option_unittest.cc +++ /dev/null @@ -1,273 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include "graph/option/optimization_option.h" -#include "register/optimization_option_registry.h" -#include "register/pass_option_utils.h" -#include "ge_common/ge_api_types.h" -#include "graph/debug/ge_log.h" -#include "graph/ge_local_context.h" - -namespace { -bool ThresholdCheckerFunc(const std::string &opt_value) { - std::string tmp_opt_value = opt_value; - std::stringstream ss(ge::StringUtils::Trim(tmp_opt_value)); - int64_t opt_convert; - ss >> opt_convert; - if (ss.fail() || !ss.eof()) { - return false; - } - return true; -} - -bool CustomCheckerFunc(const std::string &opt_value) { - if (opt_value.empty() || (opt_value == "disable") || (opt_value == "enable")) { - return true; - } - return false; -} -} // namespace -namespace ge { -class OptimizationOptionUT : public testing::Test { - protected: - void SetUp() override { - oopt.Initialize({}, {}); - dlog_setlevel(0, 0, 0); - } - void TearDown() override { - dlog_setlevel(0, 3, 0); - } - - OptimizationOption &oopt = GetThreadLocalContext().GetOo(); - - const std::unordered_map ®istered_opt_table = - OptionRegistry::GetInstance().GetRegisteredOptTable(); - - static void CheckOptionValue(const OptimizationOption &oo, const std::string &opt_name, const std::string &expect_value) { - std::string value; - EXPECT_EQ(oo.GetValue(opt_name, value), GRAPH_SUCCESS); - EXPECT_EQ(value, expect_value); - } - static void CheckNotConfiguredOption(const OptimizationOption &oo, const std::string &opt_name) { - std::string value; - EXPECT_NE(oo.GetValue(opt_name, value), GRAPH_SUCCESS); - } -}; - -REG_PASS_OPTION("OoUtFunctionalPass1").LEVELS(OoLevel::kO0); -REG_PASS_OPTION("OoUtFunctionalPass2").LEVELS(OoLevel::kO1); -REG_PASS_OPTION("OoUtFunctionalPass3").LEVELS(OoLevel::kO2); -REG_PASS_OPTION("OoUtFunctionalPass4").LEVELS(OoLevel::kO3); - -REG_OPTION("ge.oo.test_dead_code_elimination") - .LEVELS(OoLevel::kO1) - .CHECKER(OoInfoUtils::IsSwitchOptValueValid) - .VISIBILITY(OoEntryPoint::kAtc, OoEntryPoint::kSession, OoEntryPoint::kIrBuild) - .SHOW_NAME(OoEntryPoint::kAtc, "oo_test_dead_code_elimination", OoCategory::kModelTuning); -REG_PASS_OPTION("OoUtDeadCodeEliminationPass").SWITCH_OPT("ge.oo.test_dead_code_elimination"); - -REG_OPTION("ge.oo.test_constant_folding") - .LEVELS(OoLevel::kO1) - .CHECKER(OoInfoUtils::IsSwitchOptValueValid) - .VISIBILITY(OoEntryPoint::kAtc, OoEntryPoint::kSession, OoEntryPoint::kIrBuild) - .SHOW_NAME(OoEntryPoint::kAtc, "oo_test_constant_folding", OoCategory::kModelTuning); -REG_OPTION("ge.oo.test_constant_folding_max_expand") - .LEVELS(OoLevel::kO1) - .CHECKER(ThresholdCheckerFunc) - .DEFAULT_VALUES({{OoLevel::kO1, "400"}, {OoLevel::kO2, "600"}, {OoLevel::kO3, "800"}}) - .VISIBILITY(OoEntryPoint::kAtc, OoEntryPoint::kSession, OoEntryPoint::kIrBuild) - .SHOW_NAME(OoEntryPoint::kAtc, "oo_test_constant_folding_max_expand", OoCategory::kModelTuning); -REG_PASS_OPTION("OoUtConstantFoldingPass").SWITCH_OPT("ge.oo.test_constant_folding"); - -REG_OPTION("ge.oo.test_graph_fusion") - .LEVELS(OoLevel::kO3) - .CHECKER(OoInfoUtils::IsSwitchOptValueValid) - .VISIBILITY(OoEntryPoint::kAtc, OoEntryPoint::kSession, OoEntryPoint::kIrBuild) - .SHOW_NAME(OoEntryPoint::kAtc, "oo_test_graph_fusion", OoCategory::kModelTuning); -REG_OPTION("ge.oo.test_graph_fusion_add_relu") - .LEVELS(OoLevel::kO3) - .CHECKER(OoInfoUtils::IsSwitchOptValueValid) - .VISIBILITY(OoEntryPoint::kAtc, OoEntryPoint::kSession, OoEntryPoint::kIrBuild) - .SHOW_NAME(OoEntryPoint::kAtc, "oo_test_graph_fusion_add_relu", OoCategory::kModelTuning); -REG_PASS_OPTION("OoUtGraphFusionAddReluPass") - .SWITCH_OPT("ge.oo.test_graph_fusion") - .SWITCH_OPT("ge.oo.test_graph_fusion_add_relu", OoHierarchy::kH2); - -REG_OPTION("ge.oo.test_graph_fusion_conv_relu") - .LEVELS(OoLevel::kO3) - .CHECKER(OoInfoUtils::IsSwitchOptValueValid) - .VISIBILITY(OoEntryPoint::kAtc, OoEntryPoint::kSession, OoEntryPoint::kIrBuild) - .SHOW_NAME(OoEntryPoint::kAtc, "oo_test_graph_fusion_conv_relu", OoCategory::kModelTuning); -REG_PASS_OPTION("OoUtGraphFusionConvReluPass") - .SWITCH_OPT("ge.oo.test_graph_fusion") - .SWITCH_OPT("ge.oo.test_graph_fusion_conv_relu", OoHierarchy::kH2); - -REG_OPTION("ge.oo.test_other_type_switch") - .LEVELS(OoLevel::kO3) - .DEFAULT_VALUES({{OoLevel::kO3, "enable"}}) - .VISIBILITY(OoEntryPoint::kAtc, OoEntryPoint::kSession, OoEntryPoint::kIrBuild) - .SHOW_NAME(OoEntryPoint::kAtc, "oo_test_other_type_switch", OoCategory::kModelTuning) - .CHECKER(CustomCheckerFunc) - .HELP("The switch of another feature"); - -TEST_F(OptimizationOptionUT, Initialize_Failed_InvalidOoLevel) { - std::map ge_options = {{ge::OO_LEVEL, ""}}; - EXPECT_EQ(oopt.Initialize(ge_options, registered_opt_table), GRAPH_PARAM_INVALID); - std::string opt_value; - EXPECT_NE(oopt.GetValue("ge.oo.test_other_type_switch", opt_value), GRAPH_SUCCESS); - - ge_options = {{ge::OO_LEVEL, "OvO"}}; - EXPECT_EQ(oopt.Initialize(ge_options, registered_opt_table), GRAPH_PARAM_INVALID); -} - -TEST_F(OptimizationOptionUT, Initialize_Failed_InvalidOptionValue) { - std::map ge_options = {{ge::OO_LEVEL, "O1"}, {"ge.oo.test_graph_fusion_conv_relu", "TRUE"}}; - EXPECT_NE(oopt.Initialize(ge_options, registered_opt_table), GRAPH_SUCCESS); - - ge_options = {{"ge.oo.test_constant_folding_max_expand", "2.33"}}; - EXPECT_NE(oopt.Initialize(ge_options, registered_opt_table), GRAPH_SUCCESS); - - // threshold exceeds the maximum value of int64_t - ge_options = {{"ge.oo.test_constant_folding_max_expand", "9223372036854775807000"}}; - EXPECT_NE(oopt.Initialize(ge_options, registered_opt_table), GRAPH_SUCCESS); -} - -TEST_F(OptimizationOptionUT, Initialize_Ok_OoLevelIsNotSpecified) { - std::string opt_value; - EXPECT_EQ(oopt.Initialize({}, registered_opt_table), GRAPH_SUCCESS); - CheckOptionValue(oopt, "ge.oo.test_other_type_switch", "enable"); - - std::map ge_options = {{"ge.oo.test_constant_folding", "false"}, - {"ge.oo.test_constant_folding_max_expand", "233"}}; - EXPECT_EQ(oopt.Initialize(ge_options, registered_opt_table), GRAPH_SUCCESS); - CheckOptionValue(oopt, "ge.oo.test_graph_fusion", ""); - CheckOptionValue(oopt, "ge.oo.test_graph_fusion_conv_relu", ""); - CheckOptionValue(oopt, "ge.oo.test_graph_fusion_add_relu", ""); - CheckOptionValue(oopt, "ge.oo.test_constant_folding", "false"); - CheckOptionValue(oopt, "ge.oo.test_constant_folding_max_expand", "233"); - CheckOptionValue(oopt, "ge.oo.test_dead_code_elimination", ""); - CheckOptionValue(oopt, "ge.oo.test_other_type_switch", "enable"); -} - -TEST_F(OptimizationOptionUT, Initialize_Ok_OoLevelIsSpecified) { - std::map ge_options = {{ge::OO_LEVEL, "O1"}, - {"ge.constLifecycle", "graph"}, - {"ge.oo.test_graph_fusion", "true"}, - {"ge.oo.test_graph_fusion_add_relu", "false"}}; - EXPECT_EQ(oopt.Initialize(ge_options, registered_opt_table), GRAPH_SUCCESS); - CheckOptionValue(oopt, "ge.oo.test_graph_fusion", "true"); - CheckOptionValue(oopt, "ge.oo.test_constant_folding", ""); - CheckOptionValue(oopt, "ge.oo.test_constant_folding_max_expand", "400"); - CheckOptionValue(oopt, "ge.oo.test_dead_code_elimination", ""); - CheckOptionValue(oopt, "ge.oo.test_graph_fusion_add_relu", "false"); - CheckNotConfiguredOption(oopt, "ge.oo.test_graph_fusion_conv_relu"); - CheckNotConfiguredOption(oopt, "ge.oo.test_other_type_switch"); - - bool is_enabled = false; - EXPECT_EQ(PassOptionUtils::CheckIsPassEnabled("OoUtFunctionalPass1", is_enabled), GRAPH_SUCCESS); - EXPECT_EQ(is_enabled, true); - EXPECT_EQ(PassOptionUtils::CheckIsPassEnabled("OoUtFunctionalPass2", is_enabled), GRAPH_SUCCESS); - EXPECT_EQ(is_enabled, true); - EXPECT_EQ(PassOptionUtils::CheckIsPassEnabled("OoUtFunctionalPass3", is_enabled), GRAPH_SUCCESS); - EXPECT_TRUE(is_enabled == false); - EXPECT_EQ(PassOptionUtils::CheckIsPassEnabled("OoUtFunctionalPass4", is_enabled), GRAPH_SUCCESS); - EXPECT_TRUE(is_enabled == false); - EXPECT_EQ(PassOptionUtils::CheckIsPassEnabled("OoUtDeadCodeEliminationPass", is_enabled), GRAPH_SUCCESS); - EXPECT_EQ(is_enabled, true); - EXPECT_EQ(PassOptionUtils::CheckIsPassEnabled("OoUtConstantFoldingPass", is_enabled), GRAPH_SUCCESS); - EXPECT_EQ(is_enabled, true); - EXPECT_EQ(PassOptionUtils::CheckIsPassEnabled("OoUtGraphFusionAddReluPass", is_enabled), GRAPH_SUCCESS); - EXPECT_TRUE(is_enabled == false); - EXPECT_EQ(PassOptionUtils::CheckIsPassEnabled("OoUtGraphFusionConvReluPass", is_enabled), GRAPH_SUCCESS); - EXPECT_EQ(is_enabled, true); -} - -TEST_F(OptimizationOptionUT, IsPassEnable_Failed_PassIsNotRegistered) { - REG_PASS_OPTION("NoOptionRegisteredPass").SWITCH_OPT("unknown_option"); - std::map ge_options = {{ge::OO_LEVEL, "O1"}, - {"ge.constLifecycle", "graph"}, - {"ge.oo.test_graph_fusion", "true"}, - {"ge.oo.test_graph_fusion_add_relu", "false"}}; - EXPECT_EQ(oopt.Initialize(ge_options, registered_opt_table), GRAPH_SUCCESS); - bool is_enabled = false; - EXPECT_EQ(PassOptionUtils::CheckIsPassEnabled("OoUtUnknownPass", is_enabled), GRAPH_FAILED); - EXPECT_EQ(PassOptionUtils::CheckIsPassEnabled("NoOptionRegisteredPass", is_enabled), GRAPH_FAILED); -} - -TEST_F(OptimizationOptionUT, Initialize_With_OptimizationSwitch) { - std::map ge_options = {{ge::OPTIMIZATION_SWITCH, "pass1:on;pass2;pass3:;:on;pass5:of;pass6:on"}}; - const std::unordered_set ge_option_set = {"pass6"}; - - EXPECT_EQ(oopt.Initialize(ge_options, registered_opt_table, ge_option_set), GRAPH_SUCCESS); - std::string opt_value; - EXPECT_EQ(oopt.GetValue("ge.oo.test_other_type_switch", opt_value), GRAPH_SUCCESS); - - bool is_enabled = false; - EXPECT_EQ(PassOptionUtils::CheckIsPassEnabled("pass1", is_enabled), GRAPH_SUCCESS); - EXPECT_EQ(is_enabled, true); - - EXPECT_NE(PassOptionUtils::CheckIsPassEnabled("pass2", is_enabled), GRAPH_SUCCESS); - - EXPECT_NE(PassOptionUtils::CheckIsPassEnabled("pass3", is_enabled), GRAPH_SUCCESS); - - EXPECT_NE(PassOptionUtils::CheckIsPassEnabled("pass5", is_enabled), GRAPH_SUCCESS); - - EXPECT_NE(PassOptionUtils::CheckIsPassEnabled("pass6", is_enabled), GRAPH_SUCCESS); -} - -TEST_F(OptimizationOptionUT, Initialize_With_FusionConfigStr_Have_OptimizationSwitch) { - std::map options_map; - options_map[ge::OPTIMIZATION_SWITCH] = "pass1:on;pass2;pass3:;:on;pass5:of;pass6:on"; - GetThreadLocalContext().SetGraphOption(options_map); - - std::map ge_options = {{ge::OPTIMIZATION_SWITCH, "pass1:on;pass2;pass3:;:on;pass5:of;pass6:on"}}; - EXPECT_EQ(oopt.Initialize(ge_options, registered_opt_table, {}), GRAPH_SUCCESS); - - std::string fusion_config_str = "pass6:on;pass7:on;pass8;pass9:;:on;pass11:of"; - - EXPECT_EQ(oopt.RefreshPassSwitch(fusion_config_str), GRAPH_SUCCESS); - - bool is_enabled = false; - - EXPECT_EQ(PassOptionUtils::CheckIsPassEnabled("pass6", is_enabled), GRAPH_SUCCESS); - EXPECT_EQ(is_enabled, true); - - EXPECT_EQ(PassOptionUtils::CheckIsPassEnabled("pass7", is_enabled), GRAPH_SUCCESS); - EXPECT_EQ(is_enabled, true); - - EXPECT_NE(PassOptionUtils::CheckIsPassEnabled("pass8", is_enabled), GRAPH_SUCCESS); - - EXPECT_NE(PassOptionUtils::CheckIsPassEnabled("pass9", is_enabled), GRAPH_SUCCESS); - - EXPECT_NE(PassOptionUtils::CheckIsPassEnabled("pass11", is_enabled), GRAPH_SUCCESS); - - options_map.clear(); - GetThreadLocalContext().SetGraphOption(options_map); -} - -TEST_F(OptimizationOptionUT, Initialize_With_FusionConfigStr_No_OptimizationSwitch) { - EXPECT_EQ(oopt.Initialize({}, registered_opt_table, {}), GRAPH_SUCCESS); - - std::string fusion_config_str = "pass13:on;pass14;pass15:;:on;pass17:of"; - - EXPECT_EQ(oopt.RefreshPassSwitch(fusion_config_str), GRAPH_SUCCESS); - - bool is_enabled = false; - EXPECT_EQ(PassOptionUtils::CheckIsPassEnabled("pass13", is_enabled), GRAPH_SUCCESS); - EXPECT_EQ(is_enabled, true); - - EXPECT_NE(PassOptionUtils::CheckIsPassEnabled("pass14", is_enabled), GRAPH_SUCCESS); - - EXPECT_NE(PassOptionUtils::CheckIsPassEnabled("pass15", is_enabled), GRAPH_SUCCESS); - - EXPECT_NE(PassOptionUtils::CheckIsPassEnabled("pass17", is_enabled), GRAPH_SUCCESS); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/profiler_unittest.cc b/tests/ut/graph/testcase/profiler_unittest.cc deleted file mode 100644 index f72116237219e9d45b8734cd2fa2b956a96c2002..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/profiler_unittest.cc +++ /dev/null @@ -1,306 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/utils/profiler.h" -namespace ge { -namespace profiling { -namespace { - -std::string RandomStr(const int len) { - std::string res(len,' '); - for(int i = 0; i < len; ++i) { - res[i] = 'A' + rand() % 26; - } - return res; -} - - -std::string FindNext(const std::string &s, size_t &pos) { - std::stringstream ss; - for (; pos < s.size(); ++pos) { - if (s[pos] == '\r') { - ++pos; - if (pos + 1 < s.size() && s[pos + 1] == '\n') { - ++pos; - } - return ss.str(); - } - if (s[pos] == '\n') { - ++pos; - if (pos + 1 < s.size() && s[pos + 1] == '\r') { - ++pos; - } - return ss.str(); - } - ss << s[pos]; - } - return ss.str(); -} -std::vector SplitLines(const std::string &s) { - std::vector strings; - size_t i = 0; - while (i < s.size()) { - strings.emplace_back(FindNext(s, i)); - } - return strings; -} -std::vector Split(const std::string &s, std::string spliter) { - std::vector strings; - size_t i = 0; - while (i < s.size()) { - auto pos = s.find_first_of(spliter, i); - if (pos == std::string::npos) { - strings.emplace_back(s, i); - break; - } else { - strings.emplace_back(s, i, pos - i); - i = pos + spliter.size(); - } - } - - return strings; -} -} -class ProfilerUt : public testing::Test {}; - -TEST_F(ProfilerUt, OneRecord) { - auto p = Profiler::Create(); - p->Record(0, 1, 2, EventType::kEventStart, std::chrono::system_clock::now()); - EXPECT_EQ(p->GetRecordNum(), 1); - - std::stringstream ss; - p->Dump(ss); - auto lines = SplitLines(ss.str()); - EXPECT_EQ(lines.size(), 3); - auto elements = Split(lines[1], " "); - EXPECT_EQ(elements.size(), 5); - EXPECT_EQ(elements[1], "1"); - EXPECT_EQ(elements[2], "UNKNOWN(0)"); - EXPECT_EQ(elements[3], "UNKNOWN(2)"); - EXPECT_EQ(elements[4], "Start"); -} - -TEST_F(ProfilerUt, TimeStampRecord) { - auto p = Profiler::Create(); - p->Record(0, 1, 2, EventType::kEventTimestamp, std::chrono::system_clock::now()); - - std::stringstream ss; - p->Dump(ss); - auto lines = SplitLines(ss.str()); - EXPECT_EQ(lines.size(), 3); - auto elements = Split(lines[1], " "); - EXPECT_EQ(elements.size(), 4); - EXPECT_EQ(elements[1], "1"); - EXPECT_EQ(elements[2], "UNKNOWN(0)"); - EXPECT_EQ(elements[3], "UNKNOWN(2)"); -} - -TEST_F(ProfilerUt, MultipleRecords) { - auto p = Profiler::Create(); - p->Record(0, 1, 2, EventType::kEventStart, std::chrono::system_clock::now()); - p->Record(0, 1, 2, EventType::kEventEnd, std::chrono::system_clock::now()); - EXPECT_EQ(p->GetRecordNum(), 2); - - std::stringstream ss; - p->Dump(ss); - auto lines = SplitLines(ss.str()); - EXPECT_EQ(lines.size(), 4); - - auto elements = Split(lines[1], " "); - EXPECT_EQ(elements.size(), 5); - EXPECT_EQ(elements[1], "1"); - EXPECT_EQ(elements[2], "UNKNOWN(0)"); - EXPECT_EQ(elements[3], "UNKNOWN(2)"); - EXPECT_EQ(elements[4], "Start"); - - elements = Split(lines[2], " "); - EXPECT_EQ(elements.size(), 5); - EXPECT_EQ(elements[1], "1"); - EXPECT_EQ(elements[2], "UNKNOWN(0)"); - EXPECT_EQ(elements[3], "UNKNOWN(2)"); - EXPECT_EQ(elements[4], "End"); -} - -TEST_F(ProfilerUt, RecordStr) { - auto p = Profiler::Create(); - p->RegisterString(0, "Node1"); - p->RegisterString(2, "InferShape"); - p->Record(0, 1, 2, EventType::kEventStart, std::chrono::system_clock::now()); - - std::stringstream ss; - p->Dump(ss); - auto lines = SplitLines(ss.str()); - EXPECT_EQ(lines.size(), 3); - auto elements = Split(lines[1], " "); - EXPECT_EQ(elements.size(), 5); - EXPECT_EQ(elements[1], "1"); - EXPECT_EQ(elements[2], "[Node1]"); - EXPECT_EQ(elements[3], "[InferShape]"); - EXPECT_EQ(elements[4], "Start"); -} - -TEST_F(ProfilerUt, RecordCurrentThread) { - auto p = Profiler::Create(); - p->RecordCurrentThread(0, 2, EventType::kEventStart, std::chrono::system_clock::now()); - p->RecordCurrentThread(0, 2, EventType::kEventEnd); - - std::stringstream ss; - p->Dump(ss); - auto lines = SplitLines(ss.str()); - EXPECT_EQ(lines.size(), 4); - - auto elements = Split(lines[1], " "); - EXPECT_EQ(elements.size(), 5); - EXPECT_EQ(elements[2], "UNKNOWN(0)"); - EXPECT_EQ(elements[3], "UNKNOWN(2)"); - EXPECT_EQ(elements[4], "Start"); - - elements = Split(lines[2], " "); - EXPECT_EQ(elements.size(), 5); - EXPECT_EQ(elements[2], "UNKNOWN(0)"); - EXPECT_EQ(elements[3], "UNKNOWN(2)"); - EXPECT_EQ(elements[4], "End"); -} - -TEST_F(ProfilerUt, Reset) { - auto p = Profiler::Create(); - p->RegisterString(0, "Node1"); - p->RegisterString(2, "InferShape"); - p->Record(0, 1, 2, EventType::kEventStart, std::chrono::system_clock::now()); - p->Reset(); - std::stringstream ss; - p->Dump(ss); - auto lines = SplitLines(ss.str()); - EXPECT_EQ(lines.size(), 0); -} - -TEST_F(ProfilerUt, ResetRemainsRegisteredString) { - auto p = Profiler::Create(); - p->RegisterString(0, "Node1"); - p->RegisterString(2, "InferShape"); - p->Record(0, 1, 2, EventType::kEventStart, std::chrono::system_clock::now()); - p->Reset(); - std::stringstream ss; - p->Dump(ss); - auto lines = SplitLines(ss.str()); - EXPECT_EQ(lines.size(), 0); - - - p->Record(0, 1, 2, EventType::kEventStart, std::chrono::system_clock::now()); - ss = std::stringstream(); - p->Dump(ss); - lines = SplitLines(ss.str()); - EXPECT_EQ(lines.size(), 3); - auto elements = Split(lines[1], " "); - EXPECT_EQ(elements.size(), 5); - EXPECT_EQ(elements[1], "1"); - EXPECT_EQ(elements[2], "[Node1]"); - EXPECT_EQ(elements[3], "[InferShape]"); - EXPECT_EQ(elements[4], "Start"); -} - -TEST_F(ProfilerUt, RegisterStringBeyondMaxSize) { - auto p = Profiler::Create(); - p->RegisterString(2, "InferShape"); - p->RegisterString(kMaxStrIndex, "[Node1]"); - p->Record(kMaxStrIndex, 1, 2, EventType::kEventStart, std::chrono::system_clock::now()); - - std::stringstream ss; - p->Dump(ss); - auto lines = SplitLines(ss.str()); - EXPECT_EQ(lines.size(), 3); - auto elements = Split(lines[1], " "); - EXPECT_EQ(elements.size(), 5); - EXPECT_EQ(elements[1], "1"); - EXPECT_EQ(elements[2], "UNKNOWN(" + std::to_string(kMaxStrIndex) + ")"); - EXPECT_EQ(elements[3], "[InferShape]"); - EXPECT_EQ(elements[4], "Start"); -} - -TEST_F(ProfilerUt, EventTypeBeyondRange) { - auto p = Profiler::Create(); - p->Record(0, 1, 2, EventType::kEventTypeEnd, std::chrono::system_clock::now()); - - std::stringstream ss; - p->Dump(ss); - auto lines = SplitLines(ss.str()); - EXPECT_EQ(lines.size(), 3); - auto elements = Split(lines[1], " "); - EXPECT_EQ(elements.size(), 5); - EXPECT_EQ(elements[1], "1"); - EXPECT_EQ(elements[2], "UNKNOWN(0)"); - EXPECT_EQ(elements[3], "UNKNOWN(2)"); - EXPECT_EQ(elements[4], "UNKNOWN(3)"); -} - -TEST_F(ProfilerUt, GetRecords) { - auto p = Profiler::Create(); - p->Record(0, 1, 2, EventType::kEventTypeEnd, std::chrono::system_clock::now()); - auto rec = p->GetRecords(); - EXPECT_EQ(rec->element, 0); - EXPECT_EQ(rec->thread, 1); - EXPECT_EQ(rec->event, 2); - EXPECT_EQ(rec->et, EventType::kEventTypeEnd); -} - -TEST_F(ProfilerUt, GetStringHashes) { - auto p = Profiler::Create(); - p->RegisterString(0, "Node1"); - p->RegisterString(2, "InferShape"); - auto s = p->GetStringHashes(); - EXPECT_EQ(strcmp(s[0].str, "Node1"), 0); - EXPECT_EQ(strcmp(s[2].str, "InferShape"), 0); - p->RegisterStringHash(3, 0x55, "Node2"); - p->RegisterStringHash(4, 0xaa, "Tiling"); - EXPECT_EQ(s[3].hash == 0x55, true); - EXPECT_EQ(s[4].hash == 0xaa, true); - p->UpdateHashByIndex(0, 0x5a); - p->UpdateHashByIndex(2, 0xa5); - EXPECT_EQ(s[0].hash == 0x5a, true); - EXPECT_EQ(s[2].hash == 0xa5, true); -} - -TEST_F(ProfilerUt, RegisterTooLongString) { - auto p = Profiler::Create(); - std::string input = RandomStr(300); - std::string gt_res = input.substr(0,255); - p->RegisterString(0, input); - p->RegisterString(2, "InferShape"); - auto s = p->GetStringHashes(); - EXPECT_EQ(strcmp(s[0].str, gt_res.c_str()), 0); - EXPECT_EQ(strcmp(s[2].str, "InferShape"), 0); -} - -TEST_F(ProfilerUt, ModifyStrings) { - auto p = Profiler::Create(); - p->RegisterString(0, "AbcdefghijklmnopqrstuvwxyzAbcdefghijklmnopqrstuvwxyzAbcdefghijklmnopqrstuvwxyz"); - p->RegisterString(2, "InferShape"); - auto s = p->GetStringHashes(); - strcpy(s[2].str, "Tiling"); - EXPECT_EQ(strcmp(s[0].str, "AbcdefghijklmnopqrstuvwxyzAbcdefghijklmnopqrstuvwxyzAbcdefghijklmnopqrstuvwxyz"), 0); - EXPECT_EQ(strcmp(p->GetStringHashes()[2].str, "Tiling"), 0); -} - -/* takes very long time -TEST_F(ProfilerUt, BeyondMaxRecordsNum) { - auto p = Profiler::Create(); - for (int64_t i = 0; i < profiling::kMaxRecordNum; ++i) { - p->Record(0, 1, i, kEventStart); - p->Record(0, 1, i, kEventEnd); - } - - std::stringstream ss; - p->Dump(ss); - auto lines = SplitLines(ss.str()); - EXPECT_EQ(lines.size(), profiling::kMaxRecordNum + 3); -} -*/ -} -} diff --git a/tests/ut/graph/testcase/quick_list_unittest.cc b/tests/ut/graph/testcase/quick_list_unittest.cc deleted file mode 100644 index c4fc904cd713adbd29f17a81674bd8a6f0d7a796..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/quick_list_unittest.cc +++ /dev/null @@ -1,238 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "fast_graph/quick_list.h" - -namespace ge { -class UtestQuickList : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -TEST_F(UtestQuickList, TestPushBack) { - QuickList test2; - int32_t test_loop = 100; - ListElement *list[test_loop] = {}; - for (int32_t i = 0; i < test_loop; i++) { - list[i] = new ListElement; - list[i]->data = 1; - test2.push_back(list[i], ListMode::kWorkMode); - } - ASSERT_EQ(test2.size(), test_loop); - test2.clear(); - for (int32_t i = 0; i < test_loop; i++) { - if (list[i] != nullptr) { - delete list[i]; - } - } -} - -TEST_F(UtestQuickList, TestInsert) { - int32_t test_loop = 100; - ListElement *list[test_loop] = {}; - for (int32_t i = 0; i < test_loop; i++) { - list[i] = new ListElement; - list[i]->data = 1; - } - - // test insert begin() - QuickList test3; - for (int32_t i = 0; i < test_loop; i++) { - test3.insert(test3.begin(), list[i], ListMode::kWorkMode); - } - ASSERT_EQ(test3.size(), test_loop); - test3.clear(); - - // test insert end() - QuickList test5; - int insert_begin_num = 3; - for (int32_t i = 0; i < insert_begin_num; i++) { - test5.insert(test5.begin(), list[i], ListMode::kWorkMode); - } - for (int32_t i = insert_begin_num; i < test_loop; i++) { - test5.insert(test5.end(), list[i], ListMode::kWorkMode); - } - ASSERT_EQ(test5.size(), test_loop); - auto pos = test5.end(); - --pos; - // check the end node is correct. - ASSERT_EQ(*(pos), list[test_loop - 1]); - test5.clear(); - - QuickList test6; - int inner_num = 3; - for (int32_t i = 0; i < inner_num; i++) { - test6.insert(test6.begin(), list[i], ListMode::kWorkMode); - } - auto inner_pos = test6.begin(); - ++inner_pos; - for (int32_t i = inner_num; i < test_loop; i++) { - test6.insert(inner_pos, list[i], ListMode::kWorkMode); - } - ASSERT_EQ(test6.size(), test_loop); - // check the end node is correct. - auto end_pos = test6.end(); - --end_pos; - ASSERT_EQ(*(end_pos), list[0]); - test6.clear(); - - for (int32_t i = 0; i < test_loop; i++) { - if (list[i] != nullptr) { - delete list[i]; - } - } -} - -TEST_F(UtestQuickList, TestPushFront) { - QuickList test2; - int32_t test_loop = 100; - ListElement *list[test_loop] = {}; - for (int32_t i = 0; i < test_loop; i++) { - list[i] = new ListElement; - list[i]->data = 1; - test2.push_front(list[i], ListMode::kWorkMode); - } - ASSERT_EQ(test2.size(), test_loop); - auto begin_pos = test2.begin(); - ASSERT_EQ(*begin_pos, list[test_loop - 1]); - auto end_pos = test2.end(); - --end_pos; - ASSERT_EQ(*end_pos, list[0]); - test2.clear(); - - for (int32_t i = 0; i < test_loop; i++) { - if (list[i] != nullptr) { - delete list[i]; - } - } -} - -TEST_F(UtestQuickList, TestMove) { - int32_t test_loop = 100; - int dst_relative_pose = 2; - { - QuickList test1; - ListElement *list[test_loop] = {}; - for (int32_t i = 0; i < test_loop; i++) { - list[i] = new ListElement; - list[i]->data = 1; - test1.push_back(list[i], ListMode::kWorkMode); - } - test1.move(list[test_loop - 1], list[test_loop - dst_relative_pose], true); - auto pos = test1.end(); - --pos; - // check the end node is correct. - ASSERT_EQ(*(pos), list[test_loop - dst_relative_pose]); - test1.clear(); - for (int32_t i = 0; i < test_loop; i++) { - if (list[i] != nullptr) { - delete list[i]; - } - } - } - - { - QuickList test1; - ListElement *list[test_loop] = {}; - for (int32_t i = 0; i < test_loop; i++) { - list[i] = new ListElement; - list[i]->data = 1; - test1.push_back(list[i], ListMode::kWorkMode); - } - test1.move(list[test_loop - 1], list[test_loop - dst_relative_pose], false); - auto pos = test1.end(); - --pos; - // check the end node is correct. - ASSERT_EQ(*(pos), list[test_loop - 1]); - test1.clear(); - for (int32_t i = 0; i < test_loop; i++) { - if (list[i] != nullptr) { - delete list[i]; - } - } - } - - { - QuickList test1; - ListElement *list[test_loop] = {}; - for (int32_t i = 0; i < test_loop; i++) { - list[i] = new ListElement; - list[i]->data = 1; - test1.push_back(list[i], ListMode::kWorkMode); - } - test1.move(list[0], list[test_loop - 1], false); - auto pos = test1.end(); - --pos; - // check the end node is correct. - ASSERT_EQ(*(pos), list[0]); - test1.clear(); - for (int32_t i = 0; i < test_loop; i++) { - if (list[i] != nullptr) { - delete list[i]; - } - } - } - - { - QuickList test1; - ListElement *list[test_loop] = {}; - for (int32_t i = 0; i < test_loop; i++) { - list[i] = new ListElement; - list[i]->data = 1; - test1.push_back(list[i], ListMode::kWorkMode); - } - test1.move(list[0], list[test_loop - 1], true); - auto pos = test1.end(); - --pos; - // check the end node is correct. - ASSERT_EQ(*(pos), list[test_loop - 1]); - test1.clear(); - for (int32_t i = 0; i < test_loop; i++) { - if (list[i] != nullptr) { - delete list[i]; - } - } - } -} - -TEST_F(UtestQuickList, TestErase) { - QuickList test2; - int real_loop = 100; - int32_t test_loop = real_loop; - ListElement *list[test_loop] = {}; - for (int32_t i = 0; i < real_loop; i++) { - list[i] = new ListElement; - list[i]->data = 1; - test2.push_back(list[i], ListMode::kWorkMode); - } - ASSERT_EQ(test2.size(), test_loop); - test2.erase(list[test_loop - 1]); - test_loop--; - - auto iter = test2.begin(); - test2.erase(iter); - test_loop--; - ASSERT_EQ(test2.size(), test_loop); - - ASSERT_EQ(*(test2.begin()), list[1]); - auto end_iter = test2.end(); - --end_iter; - ASSERT_EQ(*(end_iter), list[test_loop]); - test2.clear(); - for (int32_t i = 0; i < real_loop; i++) { - if (list[i] != nullptr) { - delete list[i]; - } - } -} - -} // namespace ge diff --git a/tests/ut/graph/testcase/ref_relation_unittes.cc b/tests/ut/graph/testcase/ref_relation_unittes.cc deleted file mode 100644 index fe4554fd588f07fa54a84be05f9d40f6ea2c4b0c..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/ref_relation_unittes.cc +++ /dev/null @@ -1,1095 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/ref_relation.h" - -#include -#include -#include "graph_builder_utils.h" -#include "graph/node.h" -#include "graph/operator_factory.h" -#include "graph/compute_graph.h" -#include "graph/operator.h" -#include "graph/operator_reg.h" - -using namespace ge; -using namespace std; -namespace ge { -class UTTEST_RefRelations : public testing::Test { - protected: - - void SetUp() { - } - - void TearDown() { - } -}; - -namespace { - -/* - * netoutput1 - * | - * add - * / \ - * data1 data2 - */ -ComputeGraphPtr BuildSubGraph(const std::string name) { - ut::GraphBuilder builder(name); - auto data1 = builder.AddNode(name + "data1", "Data", 1, 1); - auto data2 = builder.AddNode(name + "data2", "Data", 1, 1); - auto add = builder.AddNode(name + "sub", "Sub", 2, 1); - auto netoutput = builder.AddNode(name + "netoutput", "NetOutput", 1, 1); - - AttrUtils::SetInt(data1->GetOpDesc(), "_parent_node_index", static_cast(0)); - AttrUtils::SetInt(data2->GetOpDesc(), "_parent_node_index", static_cast(1)); - AttrUtils::SetInt(netoutput->GetOpDesc()->MutableInputDesc(0), "_parent_node_index", static_cast(0)); - - builder.AddDataEdge(data1, 0, add, 0); - builder.AddDataEdge(data2, 0, add, 1); - builder.AddDataEdge(add, 0, netoutput, 0); - - return builder.GetGraph(); -} -/* - * netoutput - * | - * if - * / \ - * data1 data2 - */ -ComputeGraphPtr BuildMainGraphWithIf() { - ut::GraphBuilder builder("main_graph"); - auto data1 = builder.AddNode("data1", "Data", 1, 1); - auto data2 = builder.AddNode("data2", "Data", 1, 1); - auto if1 = builder.AddNode("if", "If", 2, 1); - auto netoutput1 = builder.AddNode("netoutput", "NetOutput", 1, 1); - - builder.AddDataEdge(data1, 0, if1, 0); - builder.AddDataEdge(data2, 0, if1, 1); - builder.AddDataEdge(if1, 0, netoutput1, 0); - - auto main_graph = builder.GetGraph(); - - auto sub1 = BuildSubGraph("sub1"); - sub1->SetParentGraph(main_graph); - sub1->SetParentNode(main_graph->FindNode("if")); - main_graph->FindNode("if")->GetOpDesc()->AddSubgraphName("sub1"); - main_graph->FindNode("if")->GetOpDesc()->SetSubgraphInstanceName(0, "sub1"); - main_graph->AddSubgraph("sub1", sub1); - - auto sub2 = BuildSubGraph("sub2"); - sub2->SetParentGraph(main_graph); - sub2->SetParentNode(main_graph->FindNode("if")); - main_graph->FindNode("if")->GetOpDesc()->AddSubgraphName("sub2"); - main_graph->FindNode("if")->GetOpDesc()->SetSubgraphInstanceName(1, "sub2"); - main_graph->AddSubgraph("sub2", sub2); - - return main_graph; -} - -/* - * netoutput - * | - * if - * / \ - * data1 data2 - */ -ComputeGraphPtr BuildMainGraphWithIfButWithNoSubgraph() { - ut::GraphBuilder builder("main_graph"); - auto data1 = builder.AddNode("data1", "Data", 1, 1); - auto data2 = builder.AddNode("data2", "Data", 1, 1); - auto if1 = builder.AddNode("if", "If", 2, 1); - auto netoutput1 = builder.AddNode("netoutput", "NetOutput", 1, 1); - - builder.AddDataEdge(data1, 0, if1, 0); - builder.AddDataEdge(data2, 0, if1, 1); - builder.AddDataEdge(if1, 0, netoutput1, 0); - - auto main_graph = builder.GetGraph(); - - auto sub1 = BuildSubGraph("sub1"); - sub1->SetParentGraph(main_graph); - sub1->SetParentNode(main_graph->FindNode("if")); - main_graph->FindNode("if")->GetOpDesc()->AddSubgraphName("sub1"); - main_graph->FindNode("if")->GetOpDesc()->SetSubgraphInstanceName(0, "sub1"); - - auto sub2 = BuildSubGraph("sub2"); - sub2->SetParentGraph(main_graph); - sub2->SetParentNode(main_graph->FindNode("if")); - main_graph->FindNode("if")->GetOpDesc()->AddSubgraphName("sub2"); - main_graph->FindNode("if")->GetOpDesc()->SetSubgraphInstanceName(1, "sub2"); - - return main_graph; -} - -/* - * netoutput1 - * | - * add - * / \ \ - * data1 data2 data3 - */ -ComputeGraphPtr BuildSubGraph3(const std::string name) { - ut::GraphBuilder builder(name); - auto data1 = builder.AddNode(name + "data1", "Data", 1, 1); - auto data2 = builder.AddNode(name + "data2", "Data", 1, 1); - auto data3 = builder.AddNode(name + "data3", "Data", 1, 1); - - auto add = builder.AddNode(name + "sub", "Sub", 3, 1); - auto netoutput = builder.AddNode(name + "netoutput", "NetOutput", 1, 1); - - AttrUtils::SetInt(data1->GetOpDesc(), "_parent_node_index", static_cast(0)); - AttrUtils::SetInt(data2->GetOpDesc(), "_parent_node_index", static_cast(1)); - AttrUtils::SetInt(data3->GetOpDesc(), "_parent_node_index", static_cast(2)); - AttrUtils::SetInt(netoutput->GetOpDesc()->MutableInputDesc(0), "_parent_node_index", static_cast(0)); - - builder.AddDataEdge(data1, 0, add, 0); - builder.AddDataEdge(data2, 0, add, 1); - builder.AddDataEdge(data3, 0, add, 2); - builder.AddDataEdge(add, 0, netoutput, 0); - - return builder.GetGraph(); -} -/* - * netoutput - * | - * if - * / \ \ - * data1 data2 data3 - */ -ComputeGraphPtr BuildMainGraphWithIf3() { - ut::GraphBuilder builder("main_graph"); - auto data1 = builder.AddNode("data1", "Data", 1, 1); - auto data2 = builder.AddNode("data2", "Data", 1, 1); - auto data3 = builder.AddNode("data3", "Data", 1, 1); - auto if1 = builder.AddNode("if", "If", 3, 1); - auto netoutput1 = builder.AddNode("netoutput", "NetOutput", 1, 1); - - builder.AddDataEdge(data1, 0, if1, 0); - builder.AddDataEdge(data2, 0, if1, 1); - builder.AddDataEdge(data3, 0, if1, 2); - builder.AddDataEdge(if1, 0, netoutput1, 0); - - auto main_graph = builder.GetGraph(); - - auto sub1 = BuildSubGraph3("sub1"); - sub1->SetParentGraph(main_graph); - sub1->SetParentNode(if1); - if1->GetOpDesc()->AddSubgraphName("sub1"); - if1->GetOpDesc()->SetSubgraphInstanceName(0, "sub1"); - main_graph->AddSubgraph("sub1", sub1); - - auto sub2 = BuildSubGraph3("sub2"); - sub2->SetParentGraph(main_graph); - sub2->SetParentNode(if1); - if1->GetOpDesc()->AddSubgraphName("sub2"); - if1->GetOpDesc()->SetSubgraphInstanceName(1, "sub2"); - main_graph->AddSubgraph("sub2", sub2); - - return main_graph; -} - -/* - * netoutput1 - * | \ - * sub relu - * / \ / - * data1 data2 - */ -ComputeGraphPtr BuildSubGraph2(const std::string name) { - ut::GraphBuilder builder(name); - auto data1 = builder.AddNode(name + "data1", "Data", 1, 1); - auto data2 = builder.AddNode(name + "data2", "Data", 1, 1); - auto sub = builder.AddNode(name + "sub", "Sub", 2, 1); - auto relu = builder.AddNode(name + "relu", "Relu", 1, 1); - auto netoutput = builder.AddNode(name + "netoutput", "NetOutput", 2, 2); - - AttrUtils::SetInt(data1->GetOpDesc(), "_parent_node_index", static_cast(0)); - AttrUtils::SetInt(data2->GetOpDesc(), "_parent_node_index", static_cast(1)); - AttrUtils::SetInt(netoutput->GetOpDesc()->MutableInputDesc(0), "_parent_node_index", static_cast(0)); - AttrUtils::SetInt(netoutput->GetOpDesc()->MutableInputDesc(1), "_parent_node_index", static_cast(1)); - - - builder.AddDataEdge(data1, 0, sub, 0); - builder.AddDataEdge(data2, 0, sub, 1); - builder.AddDataEdge(sub, 0, netoutput, 0); - builder.AddDataEdge(data2, 0, relu, 0); - builder.AddDataEdge(relu, 0, netoutput, 1); - - - return builder.GetGraph(); -} -/* - * netoutput relu - * | / - * if - * / \ - * data1 data2 - */ -ComputeGraphPtr BuildMainGraphWithIf2() { - ut::GraphBuilder builder("main_graph"); - auto data1 = builder.AddNode("data1", "Data", 1, 1); - auto data2 = builder.AddNode("data2", "Data", 1, 1); - auto if1 = builder.AddNode("if", "If", 2, 2); - auto netoutput1 = builder.AddNode("netoutput", "NetOutput", 2, 2); - auto relu = builder.AddNode("relu", "Relu", 1, 1); - - builder.AddDataEdge(data1, 0, if1, 0); - builder.AddDataEdge(data2, 0, if1, 1); - builder.AddDataEdge(if1, 0, netoutput1, 0); - builder.AddDataEdge(if1, 1, relu, 0); - builder.AddDataEdge(relu, 0, netoutput1, 1); - - auto main_graph = builder.GetGraph(); - - auto sub1 = BuildSubGraph2("sub1"); - sub1->SetParentGraph(main_graph); - sub1->SetParentNode(main_graph->FindNode("if")); - main_graph->FindNode("if")->GetOpDesc()->AddSubgraphName("sub1"); - main_graph->FindNode("if")->GetOpDesc()->SetSubgraphInstanceName(0, "sub1"); - main_graph->AddSubgraph("sub1", sub1); - - auto sub2 = BuildSubGraph2("sub2"); - sub2->SetParentGraph(main_graph); - sub2->SetParentNode(main_graph->FindNode("if")); - main_graph->FindNode("if")->GetOpDesc()->AddSubgraphName("sub2"); - main_graph->FindNode("if")->GetOpDesc()->SetSubgraphInstanceName(1, "sub2"); - main_graph->AddSubgraph("sub2", sub2); - - return main_graph; -} -/* - * netoutput - * | \ - * sub relu \ - * / \ / - * data1 data2 data3 - */ -ComputeGraphPtr BuildWhileBodySubGraph(const std::string name) { - ut::GraphBuilder builder(name); - auto data1 = builder.AddNode(name + "data1", "Data", 1, 1); - auto data2 = builder.AddNode(name + "data2", "Data", 1, 1); - auto data3 = builder.AddNode(name + "data3", "Data", 1, 1); - auto sub = builder.AddNode(name + "sub", "Sub", 2, 1); - auto relu = builder.AddNode(name + "relu", "Relu", 1, 1); - auto netoutput = builder.AddNode(name + "netoutput", "NetOutput", 3, 3); - - AttrUtils::SetInt(data1->GetOpDesc(), "_parent_node_index", static_cast(0)); - AttrUtils::SetInt(data2->GetOpDesc(), "_parent_node_index", static_cast(1)); - AttrUtils::SetInt(data3->GetOpDesc(), "_parent_node_index", static_cast(2)); - AttrUtils::SetInt(netoutput->GetOpDesc()->MutableInputDesc(0), "_parent_node_index", static_cast(0)); - AttrUtils::SetInt(netoutput->GetOpDesc()->MutableInputDesc(1), "_parent_node_index", static_cast(1)); - AttrUtils::SetInt(netoutput->GetOpDesc()->MutableInputDesc(2), "_parent_node_index", static_cast(2)); - - - builder.AddDataEdge(data1, 0, sub, 0); - builder.AddDataEdge(data2, 0, sub, 1); - builder.AddDataEdge(sub, 0, netoutput, 0); - builder.AddDataEdge(data2, 0, relu, 0); - builder.AddDataEdge(relu, 0, netoutput, 1); - builder.AddDataEdge(data3, 0, netoutput, 2); - - - return builder.GetGraph(); -} -/* - * netoutput1 - * | - * mul - * / \ \ - * data1 data2 data3 - */ -ComputeGraphPtr BuildWhileCondSubGraph(const std::string name) { - ut::GraphBuilder builder(name); - auto data1 = builder.AddNode(name + "data1", "Data", 1, 1); - auto data2 = builder.AddNode(name + "data2", "Data", 1, 1); - auto data3 = builder.AddNode(name + "data3", "Data", 1, 1); - auto mul = builder.AddNode(name + "mul", "Mul", 3, 1); - auto netoutput = builder.AddNode(name + "netoutput", "NetOutput", 1, 1); - - AttrUtils::SetInt(data1->GetOpDesc(), "_parent_node_index", static_cast(0)); - AttrUtils::SetInt(data2->GetOpDesc(), "_parent_node_index", static_cast(1)); - AttrUtils::SetInt(data3->GetOpDesc(), "_parent_node_index", static_cast(2)); - - builder.AddDataEdge(data1, 0, mul, 0); - builder.AddDataEdge(data2, 0, mul, 1); - builder.AddDataEdge(data3, 0, mul, 2); - builder.AddDataEdge(mul, 0, netoutput, 0); - - return builder.GetGraph(); -} -/* - * netoutput relu - * | / - * while - * / \ \ - * data1 data2 const - */ -ComputeGraphPtr BuildMainGraphWithWhile() { - ut::GraphBuilder builder("main_graph"); - auto data1 = builder.AddNode("data1", "Data", 1, 1); - auto data2 = builder.AddNode("data2", "Data", 1, 1); - auto n = builder.AddNode("n", "Const", 1, 1); - auto while1 = builder.AddNode("while1", "While", 3, 3); - auto netoutput1 = builder.AddNode("netoutput", "NetOutput", 2, 2); - auto relu = builder.AddNode("relu", "Relu", 1, 1); - - builder.AddDataEdge(data1, 0, while1, 0); - builder.AddDataEdge(data2, 0, while1, 1); - builder.AddDataEdge(n, 0, while1, 2); - builder.AddDataEdge(while1, 0, netoutput1, 0); - builder.AddDataEdge(while1, 1, relu, 0); - builder.AddDataEdge(relu, 0, netoutput1, 1); - - auto main_graph = builder.GetGraph(); - - auto sub1 = BuildWhileCondSubGraph("sub1"); - sub1->SetParentGraph(main_graph); - sub1->SetParentNode(main_graph->FindNode("while1")); - main_graph->FindNode("while1")->GetOpDesc()->AddSubgraphName("sub1"); - main_graph->FindNode("while1")->GetOpDesc()->SetSubgraphInstanceName(0, "sub1"); - main_graph->AddSubgraph("sub1", sub1); - - auto sub2 = BuildWhileBodySubGraph("sub2"); - sub2->SetParentGraph(main_graph); - sub2->SetParentNode(main_graph->FindNode("while1")); - main_graph->FindNode("while1")->GetOpDesc()->AddSubgraphName("sub2"); - main_graph->FindNode("while1")->GetOpDesc()->SetSubgraphInstanceName(1, "sub2"); - main_graph->AddSubgraph("sub2", sub2); - - return main_graph; -} -/* - * netoutput - * | \ - * sub relu \ - * / \ / - * data1 data2 data3 - */ -ComputeGraphPtr BuildWhileBodySubGraph2(const std::string name) { - ut::GraphBuilder builder(name); - auto data1 = builder.AddNode(name + "data1", "Data", 1, 1); - auto data2 = builder.AddNode(name + "data2", "Data", 1, 1); - auto data3 = builder.AddNode(name + "data3", "Data", 1, 1); - auto sub = builder.AddNode(name + "sub", "Sub", 2, 1); - auto relu = builder.AddNode(name + "relu", "Relu", 1, 1); - auto netoutput = builder.AddNode(name + "netoutput", "NetOutput", 3, 3); - - AttrUtils::SetInt(data1->GetOpDesc(), "_parent_node_index", static_cast(0)); - AttrUtils::SetInt(data2->GetOpDesc(), "_parent_node_index", static_cast(1)); - AttrUtils::SetInt(data3->GetOpDesc(), "_parent_node_index", static_cast(2)); - AttrUtils::SetInt(netoutput->GetOpDesc()->MutableInputDesc(0), "_parent_node_index", static_cast(1)); - AttrUtils::SetInt(netoutput->GetOpDesc()->MutableInputDesc(1), "_parent_node_index", static_cast(0)); - AttrUtils::SetInt(netoutput->GetOpDesc()->MutableInputDesc(2), "_parent_node_index", static_cast(2)); - - - builder.AddDataEdge(data1, 0, sub, 0); - builder.AddDataEdge(data2, 0, sub, 1); - builder.AddDataEdge(sub, 0, netoutput, 0); - builder.AddDataEdge(data2, 0, relu, 0); - builder.AddDataEdge(relu, 0, netoutput, 1); - builder.AddDataEdge(data3, 0, netoutput, 2); - - - return builder.GetGraph(); -} -/* - * netoutput relu - * | / - * while - * / \ \ - * data1 data2 const - */ -ComputeGraphPtr BuildMainGraphWithWhile2() { - ut::GraphBuilder builder("main_graph"); - auto data1 = builder.AddNode("data1", "Data", 1, 1); - auto data2 = builder.AddNode("data2", "Data", 1, 1); - auto n = builder.AddNode("n", "Const", 1, 1); - auto while1 = builder.AddNode("while1", "While", 3, 3); - auto netoutput1 = builder.AddNode("netoutput", "NetOutput", 2, 2); - auto relu = builder.AddNode("relu", "Relu", 1, 1); - - builder.AddDataEdge(data1, 0, while1, 0); - builder.AddDataEdge(data2, 0, while1, 1); - builder.AddDataEdge(n, 0, while1, 2); - builder.AddDataEdge(while1, 0, netoutput1, 0); - builder.AddDataEdge(while1, 1, relu, 0); - builder.AddDataEdge(relu, 0, netoutput1, 1); - - auto main_graph = builder.GetGraph(); - - auto sub1 = BuildWhileCondSubGraph("sub1"); - sub1->SetParentGraph(main_graph); - sub1->SetParentNode(main_graph->FindNode("while1")); - main_graph->FindNode("while1")->GetOpDesc()->AddSubgraphName("sub1"); - main_graph->FindNode("while1")->GetOpDesc()->SetSubgraphInstanceName(0, "sub1"); - main_graph->AddSubgraph("sub1", sub1); - - auto sub2 = BuildWhileBodySubGraph2("sub2"); - sub2->SetParentGraph(main_graph); - sub2->SetParentNode(main_graph->FindNode("while1")); - main_graph->FindNode("while1")->GetOpDesc()->AddSubgraphName("sub2"); - main_graph->FindNode("while1")->GetOpDesc()->SetSubgraphInstanceName(1, "sub2"); - main_graph->AddSubgraph("sub2", sub2); - - return main_graph; -} - -/* - * netoutput - * | \ \ - * sub \ \ - * / \ \ \ - * data1 const data2 data3 - */ -ComputeGraphPtr BuildWhileBodySubGraph3(const std::string name) { - ut::GraphBuilder builder(name); - auto data0 = builder.AddNode(name + "data0", "Data", 1, 1); - auto data1 = builder.AddNode(name + "data1", "Data", 1, 1); - auto data2 = builder.AddNode(name + "data2", "Data", 1, 1); - auto sub = builder.AddNode(name + "sub", "Sub", 2, 1); - auto const1 = builder.AddNode(name + "const1", "Const", 0, 1); - auto netoutput = builder.AddNode(name + "netoutput", "NetOutput", 4, 4); - - AttrUtils::SetInt(data0->GetOpDesc(), "_parent_node_index", static_cast(2)); - AttrUtils::SetInt(data1->GetOpDesc(), "_parent_node_index", static_cast(0)); - AttrUtils::SetInt(data2->GetOpDesc(), "_parent_node_index", static_cast(1)); - AttrUtils::SetInt(netoutput->GetOpDesc()->MutableInputDesc(2), "_parent_node_index", static_cast(2)); - AttrUtils::SetInt(netoutput->GetOpDesc()->MutableInputDesc(1), "_parent_node_index", static_cast(1)); - AttrUtils::SetInt(netoutput->GetOpDesc()->MutableInputDesc(0), "_parent_node_index", static_cast(0)); - - - builder.AddDataEdge(data0, 0, sub, 0); - builder.AddDataEdge(const1, 0, sub, 1); - builder.AddDataEdge(const1, 0, netoutput, 3); - builder.AddDataEdge(sub, 0, netoutput, 2); - builder.AddDataEdge(data1, 0, netoutput, 1); - builder.AddDataEdge(data2, 0, netoutput, 0); - - - return builder.GetGraph(); -} -/* - * netoutput1 - * | - * mul - * / \ \ - * data1 data2 data3 - */ -ComputeGraphPtr BuildWhileCondSubGraph3(const std::string name) { - ut::GraphBuilder builder(name); - auto data0 = builder.AddNode(name + "data0", "Data", 1, 1); - auto data1 = builder.AddNode(name + "data1", "Data", 1, 1); - auto data2 = builder.AddNode(name + "data2", "Data", 1, 1); - auto mul = builder.AddNode(name + "mul", "Mul", 3, 1); - auto netoutput = builder.AddNode(name + "netoutput", "NetOutput", 1, 1); - - AttrUtils::SetInt(data0->GetOpDesc(), "_parent_node_index", static_cast(2)); - AttrUtils::SetInt(data1->GetOpDesc(), "_parent_node_index", static_cast(0)); - AttrUtils::SetInt(data2->GetOpDesc(), "_parent_node_index", static_cast(1)); - - builder.AddDataEdge(data0, 0, mul, 0); - builder.AddDataEdge(data1, 0, mul, 1); - builder.AddDataEdge(data2, 0, mul, 2); - builder.AddDataEdge(mul, 0, netoutput, 0); - - return builder.GetGraph(); -} -/* - * netoutput relu - * | / - * while - * / \ \ - * data1 data2 const - */ -ComputeGraphPtr BuildMainGraphWithWhile3() { - ut::GraphBuilder builder("main_graph"); - auto data0 = builder.AddNode("data0", "Data", 1, 1); - auto data1 = builder.AddNode("data1", "Data", 1, 1); - auto n = builder.AddNode("n", "Const", 1, 1); - auto while1 = builder.AddNode("while1", "While", 3, 3); - auto netoutput1 = builder.AddNode("netoutput", "NetOutput", 2, 2); - auto relu = builder.AddNode("relu", "Relu", 1, 1); - - builder.AddDataEdge(data0, 0, while1, 0); - builder.AddDataEdge(data1, 0, while1, 1); - builder.AddDataEdge(n, 0, while1, 2); - builder.AddDataEdge(while1, 0, netoutput1, 0); - builder.AddDataEdge(while1, 1, relu, 0); - builder.AddDataEdge(relu, 0, netoutput1, 1); - - auto main_graph = builder.GetGraph(); - - auto sub1 = BuildWhileCondSubGraph3("sub1"); - sub1->SetParentGraph(main_graph); - sub1->SetParentNode(main_graph->FindNode("while1")); - main_graph->FindNode("while1")->GetOpDesc()->AddSubgraphName("sub1"); - main_graph->FindNode("while1")->GetOpDesc()->SetSubgraphInstanceName(0, "sub1"); - main_graph->AddSubgraph("sub1", sub1); - - auto sub2 = BuildWhileBodySubGraph3("sub2"); - sub2->SetParentGraph(main_graph); - sub2->SetParentNode(main_graph->FindNode("while1")); - main_graph->FindNode("while1")->GetOpDesc()->AddSubgraphName("sub2"); - main_graph->FindNode("while1")->GetOpDesc()->SetSubgraphInstanceName(1, "sub2"); - main_graph->AddSubgraph("sub2", sub2); - - return main_graph; -} -// Check result -void CheckResult(RefRelations &ref_builder, vector &keys, unordered_set &values) { - for (const auto &key : keys) { - std::unordered_set result; - auto status = ref_builder.LookUpRefRelations(key, result); - EXPECT_EQ(status, GRAPH_SUCCESS); - for (const auto &it : result) { - string res = it.node_name + std::to_string(it.in_out) + std::to_string(it.in_out_idx) + - std::to_string((unsigned long) it.node.get()); - auto iter = values.find(res); - bool is_exist = (iter == values.end()) ? false : true; - EXPECT_EQ(is_exist, true); - } - } -} -} - -TEST_F(UTTEST_RefRelations, Pass_if_1) { - auto main_graph = BuildMainGraphWithIf(); - - auto sub1 = main_graph->GetSubgraph("sub1"); - auto sub2 = main_graph->GetSubgraph("sub2"); - auto if1 = main_graph->FindNode("if"); - auto sub1data1 = sub1->FindNode("sub1data1"); - auto sub1data2 = sub1->FindNode("sub1data2"); - auto sub2data1 = sub2->FindNode("sub2data1"); - auto sub2data2 = sub2->FindNode("sub2data2"); - auto sub1netoutput = sub1->FindNode("sub1netoutput"); - auto sub2netoutput = sub2->FindNode("sub2netoutput"); - - string if1_s = std::to_string((unsigned long)if1.get()); - string sub1data1_s = std::to_string((unsigned long)sub1data1.get()); - string sub1data2_s = std::to_string((unsigned long)sub1data2.get()); - string sub2data1_s = std::to_string((unsigned long)sub2data1.get()); - string sub2data2_s = std::to_string((unsigned long)sub2data2.get()); - string sub1netoutput_s = std::to_string((unsigned long)sub1netoutput.get()); - string sub2netoutput_s = std::to_string((unsigned long)sub2netoutput.get()); - - RefRelations ref_builder; - auto status = ref_builder.BuildRefRelations(*main_graph); - EXPECT_EQ(status, GRAPH_SUCCESS); - - - vector keys_1 = { - RefCell("sub1data1",sub1data1, NODE_IN,0), - RefCell("sub1data1",sub1data1, NODE_OUT,0), - RefCell("sub2data1",sub2data1, NODE_IN,0), - RefCell("sub2data1",sub2data1, NODE_OUT,0), - RefCell("if", if1, NODE_IN, 0), - }; - - unordered_set values_1 = { - string("sub1data100") + sub1data1_s, - string("sub1data110") + sub1data1_s, - string("sub2data100") + sub2data1_s, - string("sub2data110") + sub2data1_s, - string("if00") + if1_s - }; - - vector keys_2 = { - RefCell("sub1data2", sub1data2, NODE_IN,0), - RefCell("sub1data2", sub1data2, NODE_OUT,0), - RefCell("sub2data2", sub2data2, NODE_IN,0), - RefCell("sub2data2", sub2data2, NODE_OUT,0), - RefCell("if", if1, NODE_IN, 1), - }; - unordered_set values_2 = { - string("sub1data200") + sub1data2_s, - string("sub1data210") + sub1data2_s, - string("sub2data200") + sub2data2_s, - string("sub2data210") + sub2data2_s, - string("if01") + if1_s - }; - - vector keys_3 = { - RefCell("sub1netoutput",sub1netoutput, NODE_IN,0), - RefCell("sub2netoutput",sub2netoutput, NODE_IN,0), - RefCell("if", if1, NODE_OUT, 0), - }; - - unordered_set values_3 = { - string("sub1netoutput00") + sub1netoutput_s, - string("sub2netoutput00") + sub2netoutput_s, - string("if10") + if1_s - }; - - CheckResult(ref_builder, keys_1, values_1); - CheckResult(ref_builder, keys_2, values_2); - CheckResult(ref_builder, keys_3, values_3); - -} - -TEST_F(UTTEST_RefRelations, Pass_if_2) { - auto main_graph = BuildMainGraphWithIf2(); - - auto sub1 = main_graph->GetSubgraph("sub1"); - auto sub2 = main_graph->GetSubgraph("sub2"); - auto if1 = main_graph->FindNode("if"); - auto sub1data1 = sub1->FindNode("sub1data1"); - auto sub1data2 = sub1->FindNode("sub1data2"); - auto sub2data1 = sub2->FindNode("sub2data1"); - auto sub2data2 = sub2->FindNode("sub2data2"); - auto sub1netoutput = sub1->FindNode("sub1netoutput"); - auto sub2netoutput = sub2->FindNode("sub2netoutput"); - - string if1_s = std::to_string((unsigned long)if1.get()); - string sub1data1_s = std::to_string((unsigned long)sub1data1.get()); - string sub1data2_s = std::to_string((unsigned long)sub1data2.get()); - string sub2data1_s = std::to_string((unsigned long)sub2data1.get()); - string sub2data2_s = std::to_string((unsigned long)sub2data2.get()); - string sub1netoutput_s = std::to_string((unsigned long)sub1netoutput.get()); - string sub2netoutput_s = std::to_string((unsigned long)sub2netoutput.get()); - - RefRelations ref_builder; - auto status = ref_builder.BuildRefRelations(*main_graph); - EXPECT_EQ(status, GRAPH_SUCCESS); - - vector keys_1 = { - RefCell("sub1data1", sub1data1, NODE_IN,0), - RefCell("sub1data1", sub1data1, NODE_OUT,0), - RefCell("sub2data1", sub2data1, NODE_IN,0), - RefCell("sub2data1", sub2data1, NODE_OUT,0), - RefCell("if", if1, NODE_IN, 0), - }; - unordered_set values_1 = { - string("sub1data100") + sub1data1_s, - string("sub1data110") + sub1data1_s, - string("sub2data100") + sub2data1_s, - string("sub2data110") + sub2data1_s, - string("if00") + if1_s - }; - - vector keys_2 = { - RefCell("sub1data2", sub1data2 ,NODE_IN,0), - RefCell("sub1data2", sub1data2 ,NODE_OUT,0), - RefCell("sub2data2", sub2data2 ,NODE_IN,0), - RefCell("sub2data2", sub2data2 ,NODE_OUT,0), - RefCell("if", if1, NODE_IN, 1), - }; - unordered_set values_2 = { - string("sub1data200") + sub1data2_s, - string("sub1data210") + sub1data2_s, - string("sub2data200") + sub2data2_s, - string("sub2data210") + sub2data2_s, - string("if01") + if1_s - }; - - vector keys_3 = { - RefCell("sub1netoutput", sub1netoutput,NODE_IN,0), - RefCell("sub2netoutput", sub2netoutput,NODE_IN,0), - RefCell("if", if1, NODE_OUT, 0), - }; - - unordered_set values_3 = { - string("sub1netoutput00") + sub1netoutput_s, - string("sub2netoutput00") + sub2netoutput_s, - string("if10") + if1_s - }; - - vector keys_4 = { - RefCell("sub1netoutput",sub1netoutput, NODE_IN,1), - RefCell("sub2netoutput",sub2netoutput, NODE_IN,1), - RefCell("if", if1, NODE_OUT, 1), - }; - - unordered_set values_4 = { - string("sub1netoutput01") + sub1netoutput_s, - string("sub2netoutput01") + sub2netoutput_s, - string("if11") + if1_s - }; - - CheckResult(ref_builder, keys_1, values_1); - CheckResult(ref_builder, keys_2, values_2); - CheckResult(ref_builder, keys_3, values_3); - CheckResult(ref_builder, keys_4, values_4); - -} - -TEST_F(UTTEST_RefRelations, Pass_if_3) { - auto main_graph = BuildMainGraphWithIf3(); - - auto sub1 = main_graph->GetSubgraph("sub1"); - auto sub2 = main_graph->GetSubgraph("sub2"); - auto if1 = main_graph->FindNode("if"); - auto sub1data1 = sub1->FindNode("sub1data1"); - auto sub1data2 = sub1->FindNode("sub1data2"); - auto sub1data3 = sub1->FindNode("sub1data3"); - auto sub2data1 = sub2->FindNode("sub2data1"); - auto sub2data2 = sub2->FindNode("sub2data2"); - auto sub2data3 = sub2->FindNode("sub2data3"); - auto sub1netoutput = sub1->FindNode("sub1netoutput"); - auto sub2netoutput = sub2->FindNode("sub2netoutput"); - - string if1_s = std::to_string((unsigned long)if1.get()); - string sub1data1_s = std::to_string((unsigned long)sub1data1.get()); - string sub1data2_s = std::to_string((unsigned long)sub1data2.get()); - string sub1data3_s = std::to_string((unsigned long)sub1data3.get()); - string sub2data1_s = std::to_string((unsigned long)sub2data1.get()); - string sub2data2_s = std::to_string((unsigned long)sub2data2.get()); - string sub2data3_s = std::to_string((unsigned long)sub2data3.get()); - string sub1netoutput_s = std::to_string((unsigned long)sub1netoutput.get()); - string sub2netoutput_s = std::to_string((unsigned long)sub2netoutput.get()); - - RefRelations ref_builder; - auto status = ref_builder.BuildRefRelations(*main_graph); - EXPECT_EQ(status, GRAPH_SUCCESS); - - - vector keys_1 = { - RefCell("sub1data1",sub1data1, NODE_IN,0), - RefCell("sub1data1",sub1data1, NODE_OUT,0), - RefCell("sub2data1",sub2data1, NODE_IN,0), - RefCell("sub2data1",sub2data1, NODE_OUT,0), - RefCell("if", if1, NODE_IN, 0), - }; - - unordered_set values_1 = { - string("sub1data100") + sub1data1_s, - string("sub1data110") + sub1data1_s, - string("sub2data100") + sub2data1_s, - string("sub2data110") + sub2data1_s, - string("if00") + if1_s - }; - - vector keys_2 = { - RefCell("sub1data2", sub1data2, NODE_IN,0), - RefCell("sub1data2", sub1data2, NODE_OUT,0), - RefCell("sub2data2", sub2data2, NODE_IN,0), - RefCell("sub2data2", sub2data2, NODE_OUT,0), - RefCell("if", if1, NODE_IN, 1), - }; - unordered_set values_2 = { - string("sub1data200") + sub1data2_s, - string("sub1data210") + sub1data2_s, - string("sub2data200") + sub2data2_s, - string("sub2data210") + sub2data2_s, - string("if01") + if1_s - }; - - vector keys_4 = { - RefCell("sub1data3", sub1data3, NODE_IN,0), - RefCell("sub1data3", sub1data3, NODE_OUT,0), - RefCell("sub2data3", sub2data3, NODE_IN,0), - RefCell("sub2data3", sub2data3, NODE_OUT,0), - RefCell("if", if1, NODE_IN, 1), - }; - unordered_set values_4 = { - string("sub1data300") + sub1data3_s, - string("sub1data310") + sub1data3_s, - string("sub2data300") + sub2data3_s, - string("sub2data310") + sub2data3_s, - string("if01") + if1_s - }; - - vector keys_3 = { - RefCell("sub1netoutput",sub1netoutput, NODE_IN,0), - RefCell("sub2netoutput",sub2netoutput, NODE_IN,0), - RefCell("if", if1, NODE_OUT, 0), - }; - - unordered_set values_3 = { - string("sub1netoutput00") + sub1netoutput_s, - string("sub2netoutput00") + sub2netoutput_s, - string("if10") + if1_s - }; - - CheckResult(ref_builder, keys_1, values_1); - CheckResult(ref_builder, keys_2, values_2); - CheckResult(ref_builder, keys_3, values_3); - -} - -TEST_F(UTTEST_RefRelations, Pass_while) { - auto main_graph = BuildMainGraphWithWhile(); - - auto sub1 = main_graph->GetSubgraph("sub1"); - auto sub2 = main_graph->GetSubgraph("sub2"); - auto while1 = main_graph->FindNode("while1"); - auto sub1data1 = sub1->FindNode("sub1data1"); - auto sub1data2 = sub1->FindNode("sub1data2"); - auto sub1data3 = sub1->FindNode("sub1data3"); - auto sub2data1 = sub2->FindNode("sub2data1"); - auto sub2data2 = sub2->FindNode("sub2data2"); - auto sub2data3 = sub2->FindNode("sub2data3"); - auto sub1netoutput = sub1->FindNode("sub1netoutput"); - auto sub2netoutput = sub2->FindNode("sub2netoutput"); - - string while1_s = std::to_string((unsigned long)while1.get()); - string sub1data1_s = std::to_string((unsigned long)sub1data1.get()); - string sub1data2_s = std::to_string((unsigned long)sub1data2.get()); - string sub1data3_s = std::to_string((unsigned long)sub1data3.get()); - string sub2data1_s = std::to_string((unsigned long)sub2data1.get()); - string sub2data2_s = std::to_string((unsigned long)sub2data2.get()); - string sub2data3_s = std::to_string((unsigned long)sub2data3.get()); - string sub1netoutput_s = std::to_string((unsigned long)sub1netoutput.get()); - string sub2netoutput_s = std::to_string((unsigned long)sub2netoutput.get()); - - RefRelations ref_builder; - auto status = ref_builder.BuildRefRelations(*main_graph); - EXPECT_EQ(status, GRAPH_SUCCESS); - - vector keys_1 = { - RefCell("sub1data1", sub1data1, NODE_IN,0), - RefCell("sub1data1", sub1data1, NODE_OUT,0), - RefCell("sub2data1", sub2data1, NODE_IN,0), - RefCell("sub2data1", sub2data1, NODE_OUT,0), - RefCell("sub2netoutput", sub2netoutput, NODE_IN,0), - RefCell("while1", while1, NODE_IN, 0), - RefCell("while1", while1, NODE_OUT, 0), - }; - unordered_set values_1 = { - string("sub1data100") + sub1data1_s, - string("sub1data110") + sub1data1_s, - string("sub2data100") + sub2data1_s, - string("sub2data110") + sub2data1_s, - string("sub2netoutput00") + sub2netoutput_s, - string("while100") + while1_s, - string("while110") + while1_s - }; - - vector keys_2 = { - RefCell("sub1data2", sub1data2, NODE_IN,0), - RefCell("sub1data2", sub1data2, NODE_OUT,0), - RefCell("sub2data2", sub2data2, NODE_IN,0), - RefCell("sub2data2", sub2data2, NODE_OUT,0), - RefCell("sub2netoutput", sub2netoutput, NODE_IN,1), - RefCell("while1", while1, NODE_IN, 1), - RefCell("while1", while1, NODE_OUT, 1), - }; - unordered_set values_2 = { - string("sub1data200")+ sub1data2_s, - string("sub1data210")+ sub1data2_s, - string("sub2data200")+ sub2data2_s, - string("sub2data210")+ sub2data2_s, - string("sub2netoutput01")+ sub2netoutput_s, - string("while101")+ while1_s, - string("while111")+ while1_s - }; - - vector keys_3 = { - RefCell("sub1data3", sub1data3,NODE_IN,0), - RefCell("sub1data3", sub1data3,NODE_OUT,0), - RefCell("sub2data3", sub2data3,NODE_IN,0), - RefCell("sub2data3", sub2data3,NODE_OUT,0), - RefCell("sub2netoutput", sub2netoutput,NODE_IN,2), - RefCell("while1", while1,NODE_IN, 2), - RefCell("while1", while1,NODE_OUT, 2), - }; - unordered_set values_3 = { - string("sub1data300")+ sub1data3_s, - string("sub1data310")+ sub1data3_s, - string("sub2data300")+ sub2data3_s, - string("sub2data310")+ sub2data3_s, - string("sub2netoutput02")+ sub2netoutput_s, - string("while102")+ while1_s, - string("while112")+ while1_s - }; - CheckResult(ref_builder, keys_1, values_1); - CheckResult(ref_builder, keys_2, values_2); - CheckResult(ref_builder, keys_3, values_3); -} - -TEST_F(UTTEST_RefRelations, Pass_while_2) { - auto main_graph = BuildMainGraphWithWhile2(); - - auto sub1 = main_graph->GetSubgraph("sub1"); - auto sub2 = main_graph->GetSubgraph("sub2"); - auto while1 = main_graph->FindNode("while1"); - auto sub1data1 = sub1->FindNode("sub1data1"); - auto sub1data2 = sub1->FindNode("sub1data2"); - auto sub1data3 = sub1->FindNode("sub1data3"); - auto sub2data1 = sub2->FindNode("sub2data1"); - auto sub2data2 = sub2->FindNode("sub2data2"); - auto sub2data3 = sub2->FindNode("sub2data3"); - auto sub1netoutput = sub1->FindNode("sub1netoutput"); - auto sub2netoutput = sub2->FindNode("sub2netoutput"); - - string while1_s = std::to_string((unsigned long)while1.get()); - string sub1data1_s = std::to_string((unsigned long)sub1data1.get()); - string sub1data2_s = std::to_string((unsigned long)sub1data2.get()); - string sub1data3_s = std::to_string((unsigned long)sub1data3.get()); - string sub2data1_s = std::to_string((unsigned long)sub2data1.get()); - string sub2data2_s = std::to_string((unsigned long)sub2data2.get()); - string sub2data3_s = std::to_string((unsigned long)sub2data3.get()); - string sub1netoutput_s = std::to_string((unsigned long)sub1netoutput.get()); - string sub2netoutput_s = std::to_string((unsigned long)sub2netoutput.get()); - - RefRelations ref_builder; - ref_builder.Clear(); - auto status = ref_builder.BuildRefRelations(*main_graph); - EXPECT_EQ(status, GRAPH_SUCCESS); - - vector keys_1 = { - RefCell("sub1data1",sub1data1, NODE_IN,0), - RefCell("sub1data1",sub1data1 ,NODE_OUT,0), - RefCell("sub2data1",sub2data1 ,NODE_IN,0), - RefCell("sub2data1",sub2data1 ,NODE_OUT,0), - RefCell("sub2netoutput",sub2netoutput ,NODE_IN,1), - RefCell("while1",while1 ,NODE_IN, 0), - RefCell("while1",while1 ,NODE_OUT, 0), - }; - unordered_set values_1 = { - string("sub1data100") + sub1data1_s, - string("sub1data110") + sub1data1_s, - string("sub2data100") + sub2data1_s, - string("sub2data110") + sub2data1_s, - string("sub2netoutput01") + sub2netoutput_s, - string("while100") + while1_s, - string("while110") + while1_s - }; - - // vector keys_2 = { - // RefCell("sub1data2",sub1data2 ,NODE_IN,0), - // RefCell("sub1data2",sub1data2 ,NODE_OUT,0), - // RefCell("sub2data2",sub2data2 ,NODE_IN,0), - // RefCell("sub2data2",sub2data2 ,NODE_OUT,0), - // RefCell("sub2netoutput",sub2netoutput,NODE_IN,0), - // RefCell("sub2netoutput",sub2netoutput ,NODE_OUT,0), - // RefCell("while1",while1 ,NODE_IN, 1), - // RefCell("while1",while1 ,NODE_OUT, 1), - // }; - // unordered_set values_2 = { - // string("sub1data200")+ sub1data2_s, - // string("sub1data210")+ sub1data2_s, - // string("sub2data200")+ sub2data2_s, - // string("sub2data210")+ sub2data2_s, - // string("sub2netoutput00")+ sub2netoutput_s, - // string("sub2netoutput10")+ sub2netoutput_s, - // string("while101")+ while1_s, - // string("while111")+ while1_s - // }; - - - // vector keys_3 = { - // RefCell("sub1data3",sub1data3 ,NODE_IN,0), - // RefCell("sub1data3",sub1data3 ,NODE_OUT,0), - // RefCell("sub2data3",sub2data3 ,NODE_IN,0), - // RefCell("sub2data3",sub2data3 ,NODE_OUT,0), - // RefCell("sub2netoutput",sub2netoutput ,NODE_IN,2), - // RefCell("sub2netoutput",sub2netoutput ,NODE_OUT,2), - // RefCell("while1",while1 ,NODE_IN, 2), - // RefCell("while1",while1 ,NODE_OUT, 2), - // }; - // unordered_set values_3 = { - // string("sub1data300")+ sub1data3_s, - // string("sub1data310")+ sub1data3_s, - // string("sub2data300")+ sub2data3_s, - // string("sub2data310")+ sub2data3_s, - // string("sub2netoutput02")+ sub2netoutput_s, - // string("sub2netoutput12")+ sub2netoutput_s, - // string("while102")+ while1_s, - // string("while112")+ while1_s - // }; - - CheckResult(ref_builder, keys_1, values_1); - // CheckResult(ref_builder, keys_2, values_2); - // CheckResult(ref_builder, keys_3, values_3); -} - -TEST_F(UTTEST_RefRelations, Pass_while3) { - auto main_graph = BuildMainGraphWithWhile3(); - - auto sub1 = main_graph->GetSubgraph("sub1"); - auto sub2 = main_graph->GetSubgraph("sub2"); - auto while1 = main_graph->FindNode("while1"); - auto sub1data0 = sub1->FindNode("sub1data0"); - auto sub1data1 = sub1->FindNode("sub1data1"); - auto sub1data2 = sub1->FindNode("sub1data2"); - auto sub2data0 = sub2->FindNode("sub2data0"); - auto sub2data1 = sub2->FindNode("sub2data1"); - auto sub2data2 = sub2->FindNode("sub2data2"); - auto sub1netoutput = sub1->FindNode("sub1netoutput"); - auto sub2netoutput = sub2->FindNode("sub2netoutput"); - - string while1_s = std::to_string((unsigned long)while1.get()); - string sub1data0_s = std::to_string((unsigned long)sub1data0.get()); - string sub1data1_s = std::to_string((unsigned long)sub1data1.get()); - string sub1data2_s = std::to_string((unsigned long)sub1data2.get()); - string sub2data0_s = std::to_string((unsigned long)sub2data0.get()); - string sub2data1_s = std::to_string((unsigned long)sub2data1.get()); - string sub2data2_s = std::to_string((unsigned long)sub2data2.get()); - string sub1netoutput_s = std::to_string((unsigned long)sub1netoutput.get()); - string sub2netoutput_s = std::to_string((unsigned long)sub2netoutput.get()); - - RefRelations ref_builder; - auto status = ref_builder.BuildRefRelations(*main_graph); - EXPECT_NE(status, GRAPH_SUCCESS); - - vector keys_0 = { - RefCell("sub1data0", sub1data0, NODE_IN,0), - RefCell("sub1data0", sub1data0, NODE_OUT,0), - RefCell("sub2data0", sub2data0, NODE_IN,0), - RefCell("sub2data0", sub2data0, NODE_OUT,0), - RefCell("sub2netoutput", sub2netoutput, NODE_IN,2), - RefCell("while1", while1, NODE_IN, 2), - RefCell("while1", while1, NODE_OUT, 2), - }; - unordered_set values_0 = { - string("sub1data000") + sub1data0_s, - string("sub1data010") + sub1data0_s, - string("sub2data000") + sub2data0_s, - string("sub2data010") + sub2data0_s, - string("sub2netoutput02") + sub2netoutput_s, - string("while102") + while1_s, - string("while112") + while1_s - }; - - vector keys_1 = { - RefCell("sub1data1", sub1data1, NODE_IN,0), - RefCell("sub1data1", sub1data1, NODE_OUT,0), - RefCell("sub2data1", sub2data1, NODE_IN,0), - RefCell("sub2data1", sub2data1, NODE_OUT,0), - RefCell("sub1data2", sub1data2, NODE_IN,0), - RefCell("sub1data2", sub1data2, NODE_OUT,0), - RefCell("sub2data2", sub2data2, NODE_IN,0), - RefCell("sub2data2", sub2data2, NODE_OUT,0), - RefCell("sub2netoutput", sub2netoutput, NODE_IN,0), - RefCell("sub2netoutput", sub2netoutput, NODE_IN,1), - RefCell("while1", while1, NODE_IN, 0), - RefCell("while1", while1, NODE_OUT, 0), - RefCell("while1", while1, NODE_IN, 1), - RefCell("while1", while1, NODE_OUT, 1), - }; - unordered_set values_1 = { - string("sub1data100")+ sub1data1_s, - string("sub1data110")+ sub1data1_s, - string("sub2data100")+ sub2data1_s, - string("sub2data110")+ sub2data1_s, - string("sub1data200")+ sub1data2_s, - string("sub1data210")+ sub1data2_s, - string("sub2data200")+ sub2data2_s, - string("sub2data210")+ sub2data2_s, - string("sub2netoutput01")+ sub2netoutput_s, - string("sub2netoutput00")+ sub2netoutput_s, - string("while100")+ while1_s, - string("while110")+ while1_s, - string("while101")+ while1_s, - string("while111")+ while1_s - }; - - // CheckResult(ref_builder, keys_0, values_0); - CheckResult(ref_builder, keys_1, values_1); -} -TEST_F(UTTEST_RefRelations, Failed_if_1) { - auto main_graph = BuildMainGraphWithIfButWithNoSubgraph(); - RefRelations ref_builder; - auto status = ref_builder.BuildRefRelations(*main_graph); - EXPECT_EQ(status, GRAPH_SUCCESS); -} -} diff --git a/tests/ut/graph/testcase/ref_relation_unittest.cc b/tests/ut/graph/testcase/ref_relation_unittest.cc deleted file mode 100644 index bae3373c97ac13acec4d61a1a4a122246a3be9b9..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/ref_relation_unittest.cc +++ /dev/null @@ -1,141 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include "graph/ref_relation.h" -#include "graph/compute_graph.h" -#include "common/util/mem_utils.h" -#include "graph/debug/ge_log.h" -#include "graph/debug/ge_op_types.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/tensor_utils.h" - -namespace ge { -class UtestRefRelation : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -static ge::OpDescPtr CreateOpDesc(string name = "", string type = "", int in_num = 0, int out_num = 0) { - auto op_desc = std::make_shared(name, type); - op_desc->SetStreamId(0); - static int32_t index = 0; - op_desc->SetId(index++); - - GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64); - TensorUtils::SetSize(tensor, 64); - vector input_offset; - for (int i = 0; i < in_num; ++i) { - op_desc->AddInputDesc(tensor); - input_offset.emplace_back(index * 64 + i * 64); - } - op_desc->SetInputOffset(input_offset); - - vector output_offset; - for (int i = 0; i < out_num; ++i) { - op_desc->AddOutputDesc(tensor); - output_offset.emplace_back(index * 64 + in_num * 64 + i * 64); - } - op_desc->SetOutputOffset(output_offset); - - op_desc->SetWorkspace({}); - op_desc->SetWorkspaceBytes({}); - - ge::AttrUtils::SetStr(op_desc, ge::TVM_ATTR_NAME_MAGIC, "RT_DEV_BINARY_MAGIC_ELF_AIVEC"); - bool support_dynamic = true; - ge::AttrUtils::GetBool(op_desc, "support_dynamicshape", support_dynamic); - return op_desc; -} - -TEST_F(UtestRefRelation, build_ref_relations_fail) { - ComputeGraphPtr root_graph = std::make_shared("root_graph"); - auto partitioned_call_op_desc = CreateOpDesc("partitioned_call", PARTITIONEDCALL, 3, 1); - auto partitioned_call_node = root_graph->AddNode(partitioned_call_op_desc); - partitioned_call_op_desc->AddSubgraphName("f"); - partitioned_call_op_desc->SetSubgraphInstanceName(0, "sub_graph"); - - ComputeGraphPtr sub_sub_graph1 = std::make_shared("while_cond"); - { - OpDescPtr sub_sub_graph_while_cond_data_op_desc = CreateOpDesc("cond_data", DATA); - NodePtr sub_sub_graph_while_cond_data_node = sub_sub_graph1->AddNode(sub_sub_graph_while_cond_data_op_desc); - sub_sub_graph1->SetParentGraph(root_graph); - root_graph->AddSubGraph(sub_sub_graph1); - } - - ComputeGraphPtr sub_sub_graph2 = std::make_shared("while_body"); - { - OpDescPtr sub_sub_graph_while_body_data_op_desc = CreateOpDesc("body_data", DATA); - NodePtr sub_sub_graph_while_body_data_node = sub_sub_graph2->AddNode(sub_sub_graph_while_body_data_op_desc); - sub_sub_graph2->SetGraphUnknownFlag(true); - sub_sub_graph2->SetParentGraph(root_graph); - root_graph->AddSubGraph(sub_sub_graph2); - } - - // Will unfold to merged_graph. - ComputeGraphPtr sub_graph = std::make_shared("sub_graph"); - { - OpDescPtr sub_graph_data1_op_desc = CreateOpDesc("data1", DATA, 1, 1); - OpDescPtr sub_graph_data2_op_desc = CreateOpDesc("data2", DATA, 1, 1); - OpDescPtr sub_graph_data3_op_desc = CreateOpDesc("data3", DATA, 1, 1); - NodePtr sub_graph_data1_node = sub_graph->AddNode(sub_graph_data1_op_desc); - NodePtr sub_graph_data2_node = sub_graph->AddNode(sub_graph_data2_op_desc); - NodePtr sub_graph_data3_node = sub_graph->AddNode(sub_graph_data3_op_desc); - - AttrUtils::SetInt(sub_graph_data1_op_desc, ATTR_NAME_PARENT_NODE_INDEX, 0); - AttrUtils::SetInt(sub_graph_data2_op_desc, ATTR_NAME_PARENT_NODE_INDEX, 1); - AttrUtils::SetInt(sub_graph_data3_op_desc, ATTR_NAME_PARENT_NODE_INDEX, 2); - - OpDescPtr sub_graph_while_op_desc = CreateOpDesc("while", WHILE, 2, 2); - NodePtr sub_graph_while_node = sub_graph->AddNode(sub_graph_while_op_desc); - sub_sub_graph1->SetParentNode(sub_graph_while_node); - sub_sub_graph2->SetParentNode(sub_graph_while_node); - sub_graph_while_op_desc->AddSubgraphName("while_cond"); - sub_graph_while_op_desc->SetSubgraphInstanceName(0, "while_cond"); - sub_graph_while_op_desc->AddSubgraphName("while_body"); - sub_graph_while_op_desc->SetSubgraphInstanceName(1, "while_body"); - - OpDescPtr sub_graph_matmul_op_desc = CreateOpDesc("matmul", MATMUL, 2, 1); - NodePtr sub_graph_matmul_node = sub_graph->AddNode(sub_graph_matmul_op_desc); - - OpDescPtr sub_graph_output_op_desc = CreateOpDesc("output", NETOUTPUT, 1, 1); - NodePtr sub_graph_output_node = sub_graph->AddNode(sub_graph_output_op_desc); - - GraphUtils::AddEdge(sub_graph_data1_node->GetOutDataAnchor(0), sub_graph_while_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(sub_graph_data2_node->GetOutDataAnchor(0), sub_graph_while_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(sub_graph_data3_node->GetOutDataAnchor(0), sub_graph_matmul_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(sub_graph_while_node->GetOutDataAnchor(0), sub_graph_matmul_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(sub_graph_matmul_node->GetOutDataAnchor(0), sub_graph_output_node->GetInDataAnchor(0)); - - sub_graph->SetGraphUnknownFlag(true); - sub_graph->SetParentNode(partitioned_call_node); - sub_graph->SetParentGraph(root_graph); - root_graph->AddSubGraph(sub_graph); - } - - OpDescPtr graph_data1_op_desc = CreateOpDesc("data1", DATA, 1, 1); - OpDescPtr graph_data2_op_desc = CreateOpDesc("data2", DATA, 1, 1); - OpDescPtr graph_data3_op_desc = CreateOpDesc("data3", DATA, 1, 1); - NodePtr graph_data1_node = root_graph->AddNode(graph_data1_op_desc); - NodePtr graph_data2_node = root_graph->AddNode(graph_data2_op_desc); - NodePtr graph_data3_node = root_graph->AddNode(graph_data3_op_desc); - AttrUtils::SetInt(graph_data1_op_desc, ATTR_NAME_INDEX, 0); - AttrUtils::SetInt(graph_data2_op_desc, ATTR_NAME_INDEX, 1); - AttrUtils::SetInt(graph_data3_op_desc, ATTR_NAME_INDEX, 2); - GraphUtils::AddEdge(graph_data1_node->GetOutDataAnchor(0), partitioned_call_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(graph_data2_node->GetOutDataAnchor(0), partitioned_call_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(graph_data3_node->GetOutDataAnchor(0), partitioned_call_node->GetInDataAnchor(2)); - - RefRelations ref; - EXPECT_NE(ref.BuildRefRelations(*root_graph.get()), ge::SUCCESS); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/resource_context_mgr_unittest.cc b/tests/ut/graph/testcase/resource_context_mgr_unittest.cc deleted file mode 100644 index ea59fbc71175d37e276c5f7f2cb2a37fd457a435..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/resource_context_mgr_unittest.cc +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/resource_context_mgr.h" -#include "graph_builder_utils.h" - -namespace ge { -namespace { - struct TestResourceContext : ResourceContext { - std::vector shapes; - std::string resource_type; - }; -} - -class ResourceInferenceContextMgrTest : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -TEST_F(ResourceInferenceContextMgrTest, TestSetAndGetResourceContext) { - // prepare resource_context - string resource_key = "123"; - std::vector resource_shapes = {GeShape({1,1,2,3})}; - TestResourceContext *resource_context = new TestResourceContext(); - resource_context->shapes = resource_shapes; - resource_context->resource_type = "normal"; - - // set resource_context to mgr - ResourceContextMgr resource_context_mgr; - resource_context_mgr.SetResourceContext(resource_key, resource_context); - - TestResourceContext *test_resource_context = - dynamic_cast(resource_context_mgr.GetResourceContext(resource_key)); - // check result - auto ret_shape = test_resource_context->shapes.at(0); - auto ret_type = test_resource_context->resource_type; - ASSERT_EQ(ret_shape.GetDims(), resource_context->shapes.at(0).GetDims()); - ASSERT_EQ(ret_type, resource_context->resource_type); -} - -TEST_F(ResourceInferenceContextMgrTest, TestRegsiterNodesReliedOnResource) { - string resource_key = "123"; - auto builder = ut::GraphBuilder("g"); - - auto read_node_1 = builder.AddNode("stackpop", "stackPop", 1, 1); - auto read_node_2 = builder.AddNode("tensorAarrayRead", "TensorArrayRead", 1, 1); - ResourceContextMgr resource_context_mgr; - // register one node - resource_context_mgr.RegisterNodeReliedOnResource(resource_key, read_node_1); - auto read_nodes = resource_context_mgr.MutableNodesReliedOnResource(resource_key); - ASSERT_EQ(read_nodes.size(), 1); - - // register second node - resource_context_mgr.RegisterNodeReliedOnResource(resource_key, read_node_2); - read_nodes = resource_context_mgr.MutableNodesReliedOnResource(resource_key); - ASSERT_EQ(read_nodes.size(), 2); - vector expect_read_nodes = {read_node_1, read_node_2}; - for (const auto &expect_node : expect_read_nodes) { - ASSERT_TRUE(read_nodes.count(expect_node) > 0); - } -} - -TEST_F(ResourceInferenceContextMgrTest, TestRegsiterDuplicateNodeReliedOnResource) { - string resource_key = "123"; - auto builder = ut::GraphBuilder("g"); - - auto read_node = builder.AddNode("stack", "stack", 1, 1); - ResourceContextMgr resource_context_mgr; - resource_context_mgr.RegisterNodeReliedOnResource(resource_key, read_node); - // check add same node to context - resource_context_mgr.RegisterNodeReliedOnResource(resource_key, read_node); - auto read_nodes = resource_context_mgr.MutableNodesReliedOnResource(resource_key); - ASSERT_EQ(read_nodes.size(), 1); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/runtime_inference_context_unittest.cc b/tests/ut/graph/testcase/runtime_inference_context_unittest.cc deleted file mode 100644 index 00d6320aaf2acf1809d9782da1b68fa5a3250ad5..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/runtime_inference_context_unittest.cc +++ /dev/null @@ -1,37 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include "graph/runtime_inference_context.h" - -namespace ge { -class RuntimeInferenceContextTest : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - - -TEST_F(RuntimeInferenceContextTest, TestSetGetTensor) { - RuntimeInferenceContext ctx; - GeTensorDesc desc; - GeTensorPtr ge_tensor = std::make_shared(desc); - ASSERT_EQ(ctx.SetTensor(1, 3, ge_tensor), GRAPH_SUCCESS); - GeTensorPtr new_tensor; - ASSERT_EQ(ctx.GetTensor(1, 3, new_tensor), GRAPH_SUCCESS); - ASSERT_NE(ctx.GetTensor(2, 0, new_tensor), GRAPH_SUCCESS); - ASSERT_NE(ctx.GetTensor(2, -1, new_tensor), GRAPH_SUCCESS); - ASSERT_NE(ctx.GetTensor(1, 4, new_tensor), GRAPH_SUCCESS); - ASSERT_NE(ctx.GetTensor(1, 0, new_tensor), GRAPH_SUCCESS); - ctx.Release(); - ASSERT_NE(ctx.GetTensor(1, 3, new_tensor), GRAPH_SUCCESS); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/screen_printer_unittest.cc b/tests/ut/graph/testcase/screen_printer_unittest.cc deleted file mode 100644 index e7a7439b47d4cea62a53d4252efd5969373b09f7..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/screen_printer_unittest.cc +++ /dev/null @@ -1,208 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "common/screen_printer.h" - -#include -#include -#include "graph/ge_context.h" -#include "graph/ge_local_context.h" -#include "mmpa/mmpa_api.h" -#include "tests/depends/mmpa/src/mmpa_stub.h" - -namespace ge { -namespace { -constexpr const char *kFormatTime = "[2023-08-08-20:08:00.001.001]"; - -int32_t system_time_ret = 0; -int32_t time_of_day_ret = 0; -class MockMmpa : public ge::MmpaStubApi { - public: - INT32 mmGetSystemTime(mmSystemTime_t *sysTime) override{ - if (system_time_ret == -1) { - return EN_ERR; - } - sysTime->wYear = 2023; - sysTime->wMonth = 8; - sysTime->wDay = 8; - sysTime->wHour = 20; - sysTime->wMinute = 8; - sysTime->wSecond = 0; - return EN_OK; - } - INT32 mmGetTimeOfDay(mmTimeval *timeVal, mmTimezone *timeZone) override { - if (time_of_day_ret == -1) { - return EN_ERR; - } - timeVal->tv_usec = 1001; - timeVal->tv_sec = 1001; - return EN_OK; - } -}; -} -class UtestScreenPrinter : public testing::Test { - protected: - void SetUp() { - MmpaStub::GetInstance().SetImpl(std::make_shared()); - } - - void TearDown() { - MmpaStub::GetInstance().Reset(); - } -}; - -TEST_F(UtestScreenPrinter, log_ok) { - std::stringstream ss; - std::streambuf *coutbuf = std::cout.rdbuf(); - std::cout.rdbuf(ss.rdbuf()); - - std::string option = "input_shape_range"; - SCREEN_LOG("Option %s is deprecated", option.c_str()); - - std::cout.rdbuf(coutbuf); - std::string out_log = ss.str(); - std::string expect_log = kFormatTime + std::to_string(mmGetTid()) + " Option input_shape_range is deprecated" + "\n"; - EXPECT_EQ(out_log, expect_log); -} - -TEST_F(UtestScreenPrinter, multi_thread_log_ok) { - std::stringstream ss; - std::streambuf *coutbuf = std::cout.rdbuf(); - std::cout.rdbuf(ss.rdbuf()); - - auto func = [](char c) -> void { - std::string option(100, c); - SCREEN_LOG("%s", option.c_str()); - }; - - char a = 'a'; - char b = 'b'; - char c = 'c'; - std::thread t1(func, a); - std::thread t2(func, b); - std::thread t3(func, c); - t1.join(); - t2.join(); - t3.join(); - std::cout.rdbuf(coutbuf); - std::string out_log = ss.str(); - - std::string expect_a(100, a); - std::string expect_b(100, b); - std::string expect_c(100, c); - std::unordered_set expect_set; - std::string tmp; - while(getline(ss, tmp)) { - if (tmp.find(expect_a) != std::string::npos) { - expect_set.emplace(expect_a); - } else if (tmp.find(expect_b) != std::string::npos) { - expect_set.emplace(expect_b); - } else if (tmp.find(expect_c) != std::string::npos) { - expect_set.emplace(expect_c); - } - } - EXPECT_EQ(expect_set.size(), 3); -} - -TEST_F(UtestScreenPrinter, log_len_ok) { - std::stringstream ss; - std::streambuf *coutbuf = std::cout.rdbuf(); - std::cout.rdbuf(ss.rdbuf()); - - std::string option(1024, 'a'); - SCREEN_LOG("%s", option.c_str()); - std::cout.rdbuf(coutbuf); - std::string expect_log = kFormatTime + std::to_string(mmGetTid()) + " " + option + "\n"; - EXPECT_EQ(ss.str(), expect_log); -} - -TEST_F(UtestScreenPrinter, log_len_over) { - std::stringstream ss; - std::streambuf *coutbuf = std::cout.rdbuf(); - std::cout.rdbuf(ss.rdbuf()); - - std::string option(1025, 'a'); - SCREEN_LOG("%s", option.c_str()); - std::cout.rdbuf(coutbuf); - std::string expect_log = ""; - EXPECT_EQ(ss.str(), expect_log); -} - -TEST_F(UtestScreenPrinter, fmt_nullptr) { - std::stringstream ss; - std::streambuf *coutbuf = std::cout.rdbuf(); - std::cout.rdbuf(ss.rdbuf()); - - SCREEN_LOG(nullptr); - - std::cout.rdbuf(coutbuf); - std::string out_log = ss.str(); - std::string expect_log = ""; - EXPECT_EQ(out_log, expect_log); -} - -TEST_F(UtestScreenPrinter, log_time_err) { - std::stringstream ss; - std::streambuf *coutbuf = std::cout.rdbuf(); - std::cout.rdbuf(ss.rdbuf()); - - system_time_ret = -1; - std::string option = "input_shape_range"; - SCREEN_LOG("Option %s is deprecated", option.c_str()); - system_time_ret = 0; - - time_of_day_ret = -1; - SCREEN_LOG("Option %s is deprecated", option.c_str()); - time_of_day_ret = 0; - - std::cout.rdbuf(coutbuf); - std::string out_log = ss.str(); - std::string expect_log = ""; - expect_log += expect_log; - EXPECT_EQ(out_log, expect_log); -} - -TEST_F(UtestScreenPrinter, log_disable) { - std::stringstream ss; - std::streambuf *coutbuf = std::cout.rdbuf(); - std::cout.rdbuf(ss.rdbuf()); - - std::map options; - options.emplace("ge.screen_print_mode", "disable"); - ScreenPrinter::GetInstance().Init(options["ge.screen_print_mode"]); - - std::string option = "input_shape_range"; - SCREEN_LOG("Option %s is deprecated", option.c_str()); - - std::cout.rdbuf(coutbuf); - std::string out_log = ss.str(); - std::string expect_log = ""; - EXPECT_EQ(out_log, expect_log); - GetThreadLocalContext().SetGlobalOption(std::map{}); -} - -TEST_F(UtestScreenPrinter, log_ensable) { - std::stringstream ss; - std::streambuf *coutbuf = std::cout.rdbuf(); - std::cout.rdbuf(ss.rdbuf()); - - std::map options; - options.emplace("ge.screen_print_mode", "enable"); - ScreenPrinter::GetInstance().Init(options["ge.screen_print_mode"]); - - std::string option = "input_shape_range"; - SCREEN_LOG("Option %s is deprecated", option.c_str()); - - std::cout.rdbuf(coutbuf); - std::string out_log = ss.str(); - std::string expect_log = kFormatTime + std::to_string(mmGetTid()) + " Option input_shape_range is deprecated" + "\n"; - EXPECT_EQ(out_log, expect_log); - GetThreadLocalContext().SetGlobalOption(std::map{}); -} -} diff --git a/tests/ut/graph/testcase/serialization_util_unittest.cc b/tests/ut/graph/testcase/serialization_util_unittest.cc deleted file mode 100644 index 5f3f5b57f49145621461a9e457862a0951ef2a44..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/serialization_util_unittest.cc +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include "graph/serialization/utils/serialization_util.h" - -namespace ge { -class SerializationUtilUTest : public testing::Test { - public: - proto::DataType proto_data_type_; - DataType ge_data_type_; - - protected: - void SetUp() {} - void TearDown() {} -}; - -TEST_F(SerializationUtilUTest, GetComplex32ProtoDataType) { - SerializationUtil::GeDataTypeToProto(DT_COMPLEX32, proto_data_type_); - EXPECT_EQ(proto_data_type_, proto::DT_COMPLEX32); - SerializationUtil::ProtoDataTypeToGe(proto::DT_COMPLEX32, ge_data_type_); - EXPECT_EQ(ge_data_type_, DT_COMPLEX32); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/shape_refiner_unittest.cc b/tests/ut/graph/testcase/shape_refiner_unittest.cc deleted file mode 100644 index 2dfd7d7af3ab3348fbf161dd3e053b5bebb80b5d..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/shape_refiner_unittest.cc +++ /dev/null @@ -1,477 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include "graph/compute_graph.h" -#include "graph/shape_refiner.h" -#include "graph/operator_factory_impl.h" -#include "graph/utils/tensor_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "graph_builder_utils.h" -#include "graph/debug/ge_op_types.h" -#include "graph/debug/ge_attr_define.h" - -#include -#include -#include - -namespace ge { -namespace { -static NodePtr CreateNode(const ComputeGraphPtr &graph, const string &name, const string &type, int in_num, int out_num) { - OpDescPtr op_desc = std::make_shared(name, type); - op_desc->SetStreamId(0); - static int32_t index = 0; - op_desc->SetId(index++); - - GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT); - tensor.SetOriginFormat(FORMAT_NCHW); - tensor.SetOriginDataType(DT_FLOAT); - TensorUtils::SetSize(tensor, 512); - vector input_offset; - for (int i = 0; i < in_num; i++) { - op_desc->AddInputDesc(tensor); - input_offset.emplace_back(1024); - } - op_desc->SetInputOffset(input_offset); - - vector output_offset; - for (int i = 0; i < out_num; i++) { - op_desc->AddOutputDesc(tensor); - output_offset.emplace_back(1024); - } - op_desc->SetOutputOffset(output_offset); - - op_desc->SetWorkspace({}); - op_desc->SetWorkspaceBytes({}); - op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE"); - - const auto stub_func = [](Operator &op) { return GRAPH_SUCCESS; }; - op_desc->AddInferFunc(stub_func); - op_desc->AddInferFormatFunc(stub_func); - op_desc->AddVerifierFunc(stub_func); - - return graph->AddNode(op_desc); -} - -/* - * Data1 - * sub_data1 | sub_data2 sub_data3 - * | PartitionedCall2 ===> | | - * relu1 | PartitionedCall3 ===> relu2 - * | <=== PartitionedCall1 | | - * sub_output1 | sub_output2 sub_output3 - * netoutput - */ -ComputeGraphPtr CreateGraphWithMultiSubgraph() { - ut::GraphBuilder builder = ut::GraphBuilder("root_graph"); - auto data = builder.AddNode("Data1", "Data", 1, 1); - auto partcall1 = builder.AddNode("partcall1", "PartitionedCall", 1, 1); - auto partcall2 = builder.AddNode("partcall2", "PartitionedCall", 1, 1); - auto netoutput = builder.AddNode("netoutput", "NetOutput", 1, 0); - - builder.AddDataEdge(data, 0, partcall2, 0); - builder.AddDataEdge(partcall2, 0, partcall1, 0); - builder.AddDataEdge(partcall1, 0, netoutput, 0); - auto root_graph = builder.GetGraph(); - - ut::GraphBuilder sub_builder1 = ut::GraphBuilder("sub_graph1"); - auto sub_data1 = sub_builder1.AddNode("sub_data1", "Data", 1, 1); - auto data1_desc = sub_data1->GetOpDesc(); - AttrUtils::SetInt(data1_desc, "_parent_node_index", 0); - auto sub_relu1 = sub_builder1.AddNode("sub_relu1", "Relu", 1, 1); - auto sub_output1 = sub_builder1.AddNode("sub_output1", "NetOutput", 1, 0); - sub_builder1.AddDataEdge(sub_data1, 0, sub_relu1, 0); - sub_builder1.AddDataEdge(sub_relu1, 0, sub_output1, 0); - auto subgraph1 = sub_builder1.GetGraph(); - - ut::GraphBuilder sub_builder2 = ut::GraphBuilder("sub_graph2"); - auto sub_data2 = sub_builder2.AddNode("sub_data2", "Data", 1, 1); - auto partcall3 = sub_builder2.AddNode("partcall3", "PartitionedCall", 1, 1); - auto sub_output2 = sub_builder2.AddNode("sub_output2", "NetOutput", 1, 0); - auto output2_desc = sub_output2->GetOpDesc(); - auto output2_desc_in = output2_desc->MutableInputDesc(0); - AttrUtils::SetInt(output2_desc_in, "_parent_node_index", 0); - sub_builder2.AddDataEdge(sub_data2, 0, partcall3, 0); - sub_builder2.AddDataEdge(partcall3, 0, sub_output2, 0); - auto subgraph2 = sub_builder2.GetGraph(); - - ut::GraphBuilder sub_builder3 = ut::GraphBuilder("sub_graph3"); - auto sub_data3 = sub_builder3.AddNode("sub_data3", "Data", 1, 1); - auto sub_relu2 = sub_builder3.AddNode("sub_relu2", "Relu", 1, 1); - auto sub_output3 = sub_builder3.AddNode("sub_output3", "NetOutput", 1, 0); - auto output3_desc = sub_output3->GetOpDesc(); - auto output3_desc_in = output3_desc->MutableInputDesc(0); - AttrUtils::SetInt(output3_desc_in, "_parent_node_index", 0); - sub_builder3.AddDataEdge(sub_data3, 0, sub_relu2, 0); - sub_builder3.AddDataEdge(sub_relu2, 0, sub_output3, 0); - auto subgraph3 = sub_builder3.GetGraph(); - - auto part_node1 = root_graph->FindNode("partcall1"); - auto part_desc1 = part_node1->GetOpDesc(); - part_desc1->AddSubgraphName("sub_graph1"); - part_desc1->SetSubgraphInstanceName(0, "sub_graph1"); - - subgraph1->SetParentNode(part_node1); - subgraph1->SetParentGraph(root_graph); - root_graph->AddSubgraph("sub_graph1", subgraph1); - - auto part_node2 = root_graph->FindNode("partcall2"); - auto part_desc2 = part_node2->GetOpDesc(); - part_desc2->AddSubgraphName("sub_graph2"); - part_desc2->SetSubgraphInstanceName(0, "sub_graph2"); - - subgraph2->SetParentNode(part_node2); - subgraph2->SetParentGraph(root_graph); - root_graph->AddSubgraph("sub_graph2", subgraph2); - - auto part_node3 = subgraph2->FindNode("partcall3"); - auto part_desc3 = part_node3->GetOpDesc(); - part_desc3->AddSubgraphName("sub_graph3"); - part_desc3->SetSubgraphInstanceName(0, "sub_graph3"); - - subgraph3->SetParentNode(part_node3); - subgraph3->SetParentGraph(subgraph2); - root_graph->AddSubgraph(subgraph3); - - return root_graph; -} - -/* - * Data1 - * | - * relu1 sub_data0 - * | | - * PartitionedCall0 ===> sub_output0 - * | sub_data1 - * PartitionedCall1 ===> | - * | sub_output1 - * relu2 - * | - * netoutput - */ -ComputeGraphPtr CreateGraphWithSubgraphDataToNetoutput() { - ut::GraphBuilder builder = ut::GraphBuilder("root_graph"); - auto data = builder.AddNode("Data1", "Data", 1, 1); - auto relu1 = builder.AddNode("relu1", "Relu", 1, 1); - auto partcall0 = builder.AddNode("partcall0", "PartitionedCall", 1, 1); - auto partcall1 = builder.AddNode("partcall1", "PartitionedCall", 1, 1); - auto relu2 = builder.AddNode("relu2", "Relu", 1, 1); - auto netoutput = builder.AddNode("netoutput", "NetOutput", 1, 0); - - builder.AddDataEdge(data, 0, relu1, 0); - builder.AddDataEdge(relu1, 0, partcall0, 0); - builder.AddDataEdge(partcall0, 0, partcall1, 0); - builder.AddDataEdge(partcall1, 0, relu2, 0); - builder.AddDataEdge(relu2, 0, netoutput, 0); - auto root_graph = builder.GetGraph(); - - ut::GraphBuilder sub_builder1 = ut::GraphBuilder("sub_graph1"); - auto sub_data1 = sub_builder1.AddNode("sub_data1", "Data", 1, 1); - auto data1_desc = sub_data1->GetOpDesc(); - AttrUtils::SetInt(data1_desc, "_parent_node_index", 0); - auto sub_output1 = sub_builder1.AddNode("sub_output1", "NetOutput", 1, 0); - auto output1_desc = sub_output1->GetOpDesc(); - auto output1_desc_in = output1_desc->MutableInputDesc(0); - AttrUtils::SetInt(output1_desc_in, "_parent_node_index", 0); - sub_builder1.AddDataEdge(sub_data1, 0, sub_output1, 0); - auto subgraph1 = sub_builder1.GetGraph(); - - auto part_node1 = root_graph->FindNode("partcall1"); - auto part_desc1 = part_node1->GetOpDesc(); - part_desc1->AddSubgraphName("sub_graph1"); - part_desc1->SetSubgraphInstanceName(0, "sub_graph1"); - - subgraph1->SetParentNode(part_node1); - subgraph1->SetParentGraph(root_graph); - root_graph->AddSubgraph("sub_graph1", subgraph1); - - ut::GraphBuilder sub_builder0 = ut::GraphBuilder("sub_graph0"); - auto sub_data0 = sub_builder0.AddNode("sub_data0", "Data", 1, 1); - auto data0_desc = sub_data0->GetOpDesc(); - AttrUtils::SetInt(data0_desc, "_parent_node_index", 0); - auto sub_output0 = sub_builder0.AddNode("sub_output0", "NetOutput", 1, 0); - auto output0_desc = sub_output0->GetOpDesc(); - auto output0_desc_in = output0_desc->MutableInputDesc(0); - AttrUtils::SetInt(output0_desc_in, "_parent_node_index", 0); - sub_builder0.AddDataEdge(sub_data0, 0, sub_output0, 0); - auto subgraph0 = sub_builder0.GetGraph(); - - auto part_node0 = root_graph->FindNode("partcall0"); - auto part_desc0 = part_node0->GetOpDesc(); - part_desc0->AddSubgraphName("sub_graph0"); - part_desc0->SetSubgraphInstanceName(0, "sub_graph0"); - - subgraph0->SetParentNode(part_node0); - subgraph0->SetParentGraph(root_graph); - root_graph->AddSubgraph("sub_graph0", subgraph0); - return root_graph; -} - -/* cond_graph and body_graph share the same input tensor - * +-------------+ +-------------+ - * | Cond Graph | | Body Graph | - * | NetOutput | | NetOutput | - * | | | | | | - * NetOutput | LessThan_5 | | Add_1 | - * | | | | | | | - * while -----+ input | + input | - * | +-------------+ +-------------+ - * input - */ - -ComputeGraphPtr BuildSimpleWhileGraph2() { - const auto stub_func = [](Operator &op) { return GRAPH_SUCCESS; }; - const std::vector shape{-1, -1, 224, 224}; - // build main graph - ut::GraphBuilder main_builder("main_graph"); - auto data_1 = main_builder.AddNode("data_1", "Data", 1, 1); - auto data_2 = main_builder.AddNode("data_2", "Data", 1, 1); - auto while_1 = main_builder.AddNode("while_1", "While", 1, 1); - auto output_1 = main_builder.AddNode("output_1", "NetOutput", 1, 1); - main_builder.AddDataEdge(data_1, 0, while_1, 0); - main_builder.AddDataEdge(while_1, 0, output_1, 0); - while_1->GetOpDesc()->AddInferFunc(stub_func); - auto main_graph = main_builder.GetGraph(); - AttrUtils::SetInt(data_1->GetOpDesc(), ATTR_NAME_INDEX, 0); - output_1->GetOpDesc()->SetSrcName({"while_1"}); - output_1->GetOpDesc()->SetSrcIndex({0, 1}); - - // build condition graph - ut::GraphBuilder cond_builder("cond_graph"); - auto cond_data_1 = cond_builder.AddNode("cond_data_1", "Data", 1, 1); - auto cond_less_1 = cond_builder.AddNode("foo", "LessThan_5", 1, 1); - auto cond_output_1 = cond_builder.AddNode("cond_output_1", "NetOutput", 1, 1); - cond_builder.AddDataEdge(cond_data_1, 0, cond_less_1, 0); - cond_builder.AddDataEdge(cond_less_1, 0, cond_output_1, 0); - auto cond_graph = cond_builder.GetGraph(); - cond_output_1->GetOpDesc()->SetSrcName({"foo"}); - cond_output_1->GetOpDesc()->SetSrcIndex({0}); - cond_output_1->GetOpDesc()->UpdateInputDesc(0, GeTensorDesc(GeShape(), FORMAT_ND, DT_BOOL)); - AttrUtils::SetInt(cond_data_1->GetOpDesc(), ATTR_NAME_INDEX, 0); - AttrUtils::SetInt(cond_data_1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); - - // build body graph - ut::GraphBuilder body_builder("body_graph"); - auto body_data_1 = body_builder.AddNode("body_data_1", "Data", 1, 1); - auto body_add_1 = body_builder.AddNode("bar", "Add_1", 1, 1); - // out_shape contains unknown dims (-1) - auto body_output_1 = body_builder.AddNode("body_output_1", "NetOutput", 1, 1, FORMAT_NCHW, DT_FLOAT, shape); - body_builder.AddDataEdge(body_data_1, 0, body_add_1, 0); - body_builder.AddDataEdge(body_add_1, 0, body_output_1, 0); - auto body_graph = body_builder.GetGraph(); - AttrUtils::SetInt(body_data_1->GetOpDesc(), ATTR_NAME_INDEX, 0); - AttrUtils::SetInt(body_data_1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0); - AttrUtils::SetInt(body_output_1->GetOpDesc()->MutableInputDesc(0), ATTR_NAME_PARENT_NODE_INDEX, 0); - - // setup parent graph and sub-graphs - cond_graph->SetParentGraph(main_graph); - cond_graph->SetParentNode(main_graph->FindNode("while_1")); - body_graph->SetParentGraph(main_graph); - body_graph->SetParentNode(main_graph->FindNode("while_1")); - main_graph->FindNode("while_1")->GetOpDesc()->AddSubgraphName("cond"); - main_graph->FindNode("while_1")->GetOpDesc()->SetSubgraphInstanceName(0, cond_graph->GetName()); - main_graph->AddSubgraph("cond_graph", cond_graph); - main_graph->FindNode("while_1")->GetOpDesc()->AddSubgraphName("body"); - main_graph->FindNode("while_1")->GetOpDesc()->SetSubgraphInstanceName(1, body_graph->GetName()); - main_graph->AddSubgraph("body_graph", body_graph); - - main_graph->SetGraphUnknownFlag(true); - for (auto &subgraph : main_graph->GetAllSubgraphs()) { - subgraph->SetGraphUnknownFlag(true); - } - - return main_graph; -} -} // namespace - -class UtestShapeRefiner : public testing::Test { - protected: - void SetUp() { - dlog_setlevel(GE_MODULE_NAME, 0, 1); - operator_infershape_funcs_bak_ = OperatorFactoryImpl::operator_infershape_funcs_; - OperatorFactoryImpl::operator_infershape_funcs_.reset(new (std::nothrow) std::map()); - } - void TearDown() { - OperatorFactoryImpl::operator_infershape_funcs_ = operator_infershape_funcs_bak_; - } -private: - shared_ptr> operator_infershape_funcs_bak_; -}; - -TEST_F(UtestShapeRefiner, InferShapeAndTypeForRunning_Success) { - OperatorFactoryImpl::operator_infershape_funcs_->emplace("Merge", [](Operator &op) { return GRAPH_SUCCESS; }); - OperatorFactoryImpl::operator_infershape_funcs_->emplace("Enter", [](Operator &op) { return GRAPH_SUCCESS; }); - - const auto graph = std::make_shared("test_infer_shape"); - auto enter1 = CreateNode(graph, "enter", "Enter", 1, 1); - auto op_enter = OpDescUtils::CreateOperatorFromNode(enter1); - EXPECT_EQ(ShapeRefiner::InferShapeAndTypeForRunning(enter1, op_enter, true), GRAPH_SUCCESS); - - auto merge1 = CreateNode(graph, "merge1", "StreamMerge", 2, 2); - auto op = OpDescUtils::CreateOperatorFromNode(merge1); - merge1->GetOpDesc()->AddInferFunc(nullptr); - EXPECT_EQ(ShapeRefiner::InferShapeAndTypeForRunning(merge1, op, true), GRAPH_SUCCESS); -} - -TEST_F(UtestShapeRefiner, InferShapeAndTypeForRunning_Failure_NullInferFunc) { - const auto graph = std::make_shared("test_infer_shape"); - - OperatorFactoryImpl::operator_infershape_funcs_.reset(new (std::nothrow) std::map()); - auto merge1 = CreateNode(graph, "merge1", "StreamMerge", 2, 2); - auto op = OpDescUtils::CreateOperatorFromNode(merge1); - merge1->GetOpDesc()->AddInferFunc(nullptr); - EXPECT_EQ(ShapeRefiner::InferShapeAndTypeForRunning(merge1, op, true), GRAPH_FAILED); -} - -TEST_F(UtestShapeRefiner, CreateInferenceContext_Success_CrossSubgraph) { - OperatorFactoryImpl::operator_infershape_funcs_->emplace("Relu", [](Operator &op) { return GRAPH_SUCCESS; }); - auto graph = CreateGraphWithMultiSubgraph(); - graph->SetGraphUnknownFlag(false); - auto subgraph = graph->GetSubgraph("sub_graph1"); - auto relu = subgraph->FindNode("sub_relu1"); - - EXPECT_EQ(ShapeRefiner::InferShapeAndType(relu, false), GRAPH_SUCCESS); - auto in_data_node = relu->GetInDataNodes().at(0); - int32_t out_idx = 0; - std::map nodes_idx; - auto ret = ShapeRefiner::GetRealInNodesAndIndex(in_data_node, out_idx, nodes_idx); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(nodes_idx.size(), 1); - for (const auto &node_idx : nodes_idx) { - EXPECT_EQ(node_idx.first->GetName(), "sub_relu2"); - } -} - -TEST_F(UtestShapeRefiner, CreateInferenceContext_Success_CrossSubgraphDataToNetoutput) { - OperatorFactoryImpl::operator_infershape_funcs_->emplace("Relu", [](Operator &op) { return GRAPH_SUCCESS; }); - auto graph = CreateGraphWithSubgraphDataToNetoutput(); - auto relu = graph->FindNode("relu2"); - - EXPECT_EQ(ShapeRefiner::InferShapeAndType(relu, false), GRAPH_SUCCESS); - auto in_data_node = relu->GetInDataNodes().at(0); - int32_t out_idx = 0; - std::map nodes_idx; - auto ret = ShapeRefiner::GetRealInNodesAndIndex(in_data_node, out_idx, nodes_idx); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(nodes_idx.size(), 1); - for (const auto &node_idx : nodes_idx) { - EXPECT_EQ(node_idx.first->GetName(), "relu1"); - } -} - -TEST_F(UtestShapeRefiner, InferShapeAndType_Failure_InvalidNode) { - const auto graph = std::make_shared("test_infer_shape"); - auto enter1 = CreateNode(graph, "enter", "Enter", 1, 1); - - EXPECT_EQ(ShapeRefiner::InferShapeAndType(enter1, true), GRAPH_FAILED); -} - -// 看起来是无效ut -TEST_F(UtestShapeRefiner, UpdateOutputForMultiBatch) { - auto graph = CreateGraphWithMultiSubgraph(); - graph->SetGraphUnknownFlag(false); - auto subgraph = graph->GetSubgraph("sub_graph1"); - auto relu = subgraph->FindNode("sub_relu1"); - - auto op = OpDescUtils::CreateOperatorFromNode(relu); - auto ret = ShapeRefiner::InferShapeAndType(relu, op, false); - EXPECT_EQ(ret, GRAPH_PARAM_INVALID); -} - -TEST_F(UtestShapeRefiner, InferShapeAndType_Success_WithMultiSubgraphs) { - OperatorFactoryImpl::operator_infershape_funcs_->emplace("Relu", [](Operator &op) { return GRAPH_SUCCESS; }); - auto graph = CreateGraphWithMultiSubgraph(); - graph->SetGraphUnknownFlag(false); - auto subgraph = graph->GetSubgraph("sub_graph1"); - auto relu = subgraph->FindNode("sub_relu1"); - - ShapeRefiner::ClearContextMap(); - - auto subgraph3 = graph->GetSubgraph("sub_graph3"); - auto relu2 = subgraph3->FindNode("sub_relu2"); - - InferenceContextPtr inference_context; - ShapeRefiner::CreateInferenceContext(relu2, inference_context); - ShapeRefiner::PushToContextMap(relu2, inference_context); - - auto ret = ShapeRefiner::InferShapeAndType(relu); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestShapeRefiner, InferShapeAndType_Success_SingleNode) { - auto graph = std::make_shared("test_infer_shape"); - auto node = CreateNode(graph, "enter", "Enter", 1, 1); - auto op = OpDescUtils::CreateOperatorFromNode(node); - bool before_subgraph = false; - - auto ret = ShapeRefiner::InferShapeAndType(node, op, before_subgraph); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestShapeRefiner, InferShapeAndType_Success_WithEmptySubgraph) { - auto root_graph = std::make_shared("test_infer_shape"); - auto root_node = CreateNode(root_graph, "enter", "Enter", 1, 1); - auto op_desc = root_node->GetOpDesc(); - op_desc->AddSubgraphName("sub_graph"); - op_desc->SetSubgraphInstanceName(0, "sub_graph"); - - auto subgraph = std::make_shared("sub_graph"); - subgraph->SetParentNode(root_node); - subgraph->SetParentGraph(root_graph); - root_graph->AddSubgraph("sub_graph", subgraph); - - Operator op = OpDescUtils::CreateOperatorFromNode(root_node); - - auto ret = ShapeRefiner::InferShapeAndType(root_node, op, false); - EXPECT_NE(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestShapeRefiner, InferShapeAndType_Success_WithSubgraph) { - auto root_graph = std::make_shared("test_infer_shape"); - NodePtr root_node = CreateNode(root_graph, "enter", "Enter", 1, 1); - auto op_desc = root_node->GetOpDesc(); - op_desc->AddSubgraphName("sub_graph"); - op_desc->SetSubgraphInstanceName(0, "sub_graph"); - - auto subgraph = std::make_shared("sub_graph"); - NodePtr sub_node = CreateNode(subgraph, "netoutput", "Netoutput", 1, 1); - auto sub_op_desc = sub_node->GetOpDesc(); - sub_op_desc->SetType(NETOUTPUT); - - subgraph->SetParentNode(root_node); - subgraph->SetParentGraph(root_graph); - root_graph->AddSubgraph("sub_graph", subgraph); - - Operator op = OpDescUtils::CreateOperatorFromNode(root_node); - - auto ret = ShapeRefiner::InferShapeAndType(root_node, op, false); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(UtestShapeRefiner, InferShapeAndType_Success_CheckRangeForWhile) { - auto root_graph = BuildSimpleWhileGraph2(); - auto while_node = root_graph->FindNode("while_1"); - auto op = OpDescUtils::CreateOperatorFromNode(while_node); - EXPECT_EQ(ShapeRefiner::InferShapeAndType(while_node, op, false), GRAPH_SUCCESS); - // verify the shape ranges of While's output tensor - std::vector> x_range; - std::vector> expected_range = {{0, -1}, {0, -1}, {224, 224}, {224, 224}}; - while_node->GetOpDesc()->MutableOutputDesc(0)->GetShapeRange(x_range); - EXPECT_EQ(x_range, expected_range); -} - -TEST_F(UtestShapeRefiner, InferShapeAndType_UpdateSubGraphDataNodes) { - auto graph = CreateGraphWithMultiSubgraph(); - auto p1_node = graph->FindNode("partcall1"); - EXPECT_NE(p1_node, nullptr); - (void)AttrUtils::SetBool(p1_node->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, true); - EXPECT_EQ(ShapeRefiner::InferShapeAndType(p1_node, true), GRAPH_SUCCESS); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/small_vector_ut.cc b/tests/ut/graph/testcase/small_vector_ut.cc deleted file mode 100644 index 2381e576d32efde2c4aee30f5f1dfd6c05a764a1..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/small_vector_ut.cc +++ /dev/null @@ -1,1355 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/small_vector.h" -#include "test_structs.h" -#include "func_counter.h" -namespace ge { -class SmallVectorUt : public testing::Test {}; - -TEST_F(SmallVectorUt, ConstructAndFree) { - auto test = []() { - std::vector v1; - std::vector v2(std::move(v1)); - SmallVector vec1; - SmallVector vec2(2); - SmallVector vec3(5); - SmallVector vec4(6); - SmallVector vec5(10); - SmallVector vec6({InlineStructB(), InlineStructB()}); - }; - EXPECT_NO_THROW(test();); -} - -TEST_F(SmallVectorUt, Construct_CallCorrectConstructor1) { - FuncCounter::Clear(); - SmallVector vec1; - EXPECT_TRUE(FuncCounter::AllTimesZero()); - - FuncCounter fc; - FuncCounter::Clear(); - SmallVector vec2(10, fc); - EXPECT_EQ(FuncCounter::GetClearCopyConstructTimes(), 10); - EXPECT_TRUE(FuncCounter::AllTimesZero()); - - FuncCounter::Clear(); - SmallVector vec3(vec2); - EXPECT_EQ(FuncCounter::GetClearCopyConstructTimes(), 10); - EXPECT_TRUE(FuncCounter::AllTimesZero()); - - FuncCounter::Clear(); - SmallVector vec4(std::move(vec2)); // 直接挪指针 - EXPECT_TRUE(FuncCounter::AllTimesZero()); -} - -TEST_F(SmallVectorUt, Construct_CallCorrectConstructor2) { - FuncCounter fc; - SmallVector vec2(3, fc); - - FuncCounter::Clear(); - SmallVector vec3(std::move(vec2)); - EXPECT_EQ(FuncCounter::GetClearMoveConstructTimes(), 3); - EXPECT_EQ(FuncCounter::GetClearDestructTimes(), 3); - EXPECT_TRUE(FuncCounter::AllTimesZero()); -} - -TEST_F(SmallVectorUt, CopyConstructorAndFree) { - auto test = []() { - SmallVector vec1; - SmallVector vec2(2); - SmallVector vec3(5); - SmallVector vec4(6); - SmallVector vec5(10); - - SmallVector vec2_1(vec1); - SmallVector vec2_2(vec2); - SmallVector vec2_3(vec3); - SmallVector vec2_4(vec4); - SmallVector vec2_5(vec5); - }; - EXPECT_NO_THROW(test();); -} - -TEST_F(SmallVectorUt, ConstructorCorrectSizeAndCap) { - std::vector veccccc; - SmallVector vec1; - EXPECT_EQ(vec1.size(), 0); - EXPECT_EQ(vec1.capacity(), 5); - - SmallVector vec2(2); - EXPECT_EQ(vec2.size(), 2); - EXPECT_EQ(vec2.capacity(), 5); - - SmallVector vec2_10(2, 10); - EXPECT_EQ(vec2.size(), 2); - EXPECT_EQ(vec2.capacity(), 5); - int32_t expect_vec2_10[2] = {10, 10}; - EXPECT_EQ(memcmp(vec2_10.data(), expect_vec2_10, sizeof(expect_vec2_10)), 0); - - SmallVector vec4(6, 10); - EXPECT_EQ(vec4.size(), 6); - EXPECT_GE(vec4.capacity(), 6); - int32_t expect_vec4_10[6] = {10, 10, 10, 10, 10, 10}; - EXPECT_EQ(memcmp(vec4.data(), expect_vec4_10, sizeof(expect_vec4_10)), 0); - - SmallVector vec6({1,2,3}); - EXPECT_EQ(vec6.size(), 3); - EXPECT_EQ(vec6.capacity(), 5); - int32_t expect_vec6_10[] = {1,2,3}; - EXPECT_EQ(memcmp(vec6.data(), expect_vec6_10, sizeof(expect_vec6_10)), 0); - - SmallVector vec7({1,2,3,4,5,6}); - EXPECT_EQ(vec7.size(), 6); - EXPECT_GE(vec7.capacity(), 6); - int32_t expect_vec7_10[] = {1,2,3,4,5,6}; - EXPECT_EQ(memcmp(vec7.data(), expect_vec7_10, sizeof(expect_vec7_10)), 0); -} - -TEST_F(SmallVectorUt, MoveConstructor) { - auto test = []() { - SmallVector vec1; - SmallVector vec2(2); - SmallVector vec3(5); - SmallVector vec4(6); - SmallVector vec5(10); - - SmallVector vec2_1(std::move(vec1)); - SmallVector vec2_2(std::move(vec2)); - SmallVector vec2_3(std::move(vec3)); - SmallVector vec2_4(std::move(vec4)); - SmallVector vec2_5(std::move(vec5)); - }; - EXPECT_NO_THROW(test();); -} - -TEST_F(SmallVectorUt, MoveAssign) { - auto test = []() { - SmallVector vec1; - SmallVector vec2(2); - SmallVector vec3(5); - SmallVector vec4(6); - SmallVector vec5(10); - - SmallVector vec2_1; - SmallVector vec2_6; - - vec2_1 = std::move(vec1); - vec2_1 = std::move(vec3); - vec2_1 = std::move(vec2); - - vec2_6 = SmallVector(20); - vec2_6 = std::move(vec4); - vec2_6 = std::move(vec5); - vec2_6 = SmallVector(3); - }; - EXPECT_NO_THROW(test();); -} - -TEST_F(SmallVectorUt, CopyAssignInlineCap) { - SmallVector vec1; - SmallVector vec2(2); - SmallVector vec3(5); - SmallVector vec4(6); - SmallVector vec5(10); - - SmallVector vec6; - - // 只要dst size不大于5,不论size变大还是变小,那么vec6的cap不会变化 - vec6 = vec1; - EXPECT_EQ(vec6.capacity(), 5); - EXPECT_EQ(vec6.size(), 0); - vec6 = vec3; - EXPECT_EQ(vec6.capacity(), 5); - EXPECT_EQ(vec6.size(), 5); - vec6 = vec2; - EXPECT_EQ(vec6.capacity(), 5); - EXPECT_EQ(vec6.size(), 2); -} - -TEST_F(SmallVectorUt, CopyAssignInlineToAlloc) { - SmallVector vec1; - SmallVector vec2(2); - SmallVector vec3(5); - SmallVector vec4(6); - SmallVector vec5(10); - - SmallVector vec6(3); - vec6 = vec5; - EXPECT_EQ(vec6.capacity(), 10); - EXPECT_EQ(vec6.size(), 10); -} - -TEST_F(SmallVectorUt, CopyAssignAllocToInline) { - SmallVector vec1; - SmallVector vec2(2); - SmallVector vec3(5); - SmallVector vec4(6); - SmallVector vec5(10); - - SmallVector vec6(10); - - // 为了减少内存申请和释放的次数,即使vec6使用allocated_storage,size降回到N以下时,也不会用回inline_storage了 - vec6 = vec2; - EXPECT_EQ(vec6.capacity(), 10); - EXPECT_EQ(vec6.size(), 2); - - vec6 = vec4; - EXPECT_EQ(vec6.capacity(), 10); - EXPECT_EQ(vec6.size(), 6); -} - -TEST_F(SmallVectorUt, CopyAssignAllocExpand) { - SmallVector vec3(8); - SmallVector vec4(10); - SmallVector vec5(20); - - SmallVector vec6(9); - - vec6 = vec3; - EXPECT_EQ(vec6.capacity(), 9); - EXPECT_EQ(vec6.size(), 8); - - vec6 = vec5; - EXPECT_EQ(vec6.capacity(), 20); - EXPECT_EQ(vec6.size(), 20); - - vec6 = vec4; - EXPECT_EQ(vec6.capacity(), 20); - EXPECT_EQ(vec6.size(), 10); -} - -TEST_F(SmallVectorUt, CopyAssignOk1) { - SmallVector, 100> sv1; - sv1.emplace_back(10, 100); - sv1.emplace_back(10, 200); - - SmallVector, 100> sv2 = sv1; - sv1[0].push_back(100); - EXPECT_EQ(sv2.size(), sv1.size()); - EXPECT_NE(sv2[0], sv1[0]); - EXPECT_EQ(sv2[1], sv1[1]); -} - -TEST_F(SmallVectorUt, Assign_CallCorrectConstructor) { - SmallVector vec1(5); - SmallVector vec2; - SmallVector vec3; - - FuncCounter::Clear(); - vec2 = vec1; - EXPECT_EQ(FuncCounter::GetClearCopyConstructTimes(), 5); - EXPECT_TRUE(FuncCounter::AllTimesZero()); - - FuncCounter::Clear(); - vec3 = std::move(vec1); - EXPECT_EQ(FuncCounter::GetClearMoveConstructTimes(), 5); - EXPECT_EQ(FuncCounter::GetClearDestructTimes(), 5); - EXPECT_TRUE(FuncCounter::AllTimesZero()); - - FuncCounter::Clear(); - vec2 = vec3; - EXPECT_EQ(FuncCounter::GetClearDestructTimes(), 5); - EXPECT_EQ(FuncCounter::GetClearCopyConstructTimes(), 5); - EXPECT_TRUE(FuncCounter::AllTimesZero()); - - FuncCounter::Clear(); - vec2 = std::move(vec3); - EXPECT_EQ(FuncCounter::GetClearDestructTimes(), 10); - EXPECT_EQ(FuncCounter::GetClearMoveConstructTimes(), 5); - EXPECT_TRUE(FuncCounter::AllTimesZero()); -} - -TEST_F(SmallVectorUt, Clear_CallCorrectConstructor) { - SmallVector vec1(5); - SmallVector vec2; - - FuncCounter::Clear(); - vec1.clear(); - EXPECT_EQ(FuncCounter::destruct_times, 5); - FuncCounter::Clear(); - vec2.clear(); - EXPECT_EQ(FuncCounter::destruct_times, 0); -} - -TEST_F(SmallVectorUt, Insert_CallCorrectConstructor) { - SmallVector vec1(5); - - FuncCounter fc; - FuncCounter::Clear(); - vec1.insert(vec1.end(), fc); - EXPECT_EQ(FuncCounter::GetClearCopyConstructTimes(), 1); - EXPECT_TRUE(FuncCounter::AllTimesZero()); - - FuncCounter::Clear(); - vec1.insert(vec1.begin(), fc); - EXPECT_EQ(FuncCounter::GetClearCopyConstructTimes(), 1); - EXPECT_EQ(FuncCounter::GetClearMoveConstructTimes(), 6); - EXPECT_EQ(FuncCounter::GetClearDestructTimes(), 6); - EXPECT_TRUE(FuncCounter::AllTimesZero()); -} - -TEST_F(SmallVectorUt, InsertExpand_CallCorrectConstructor) { - SmallVector vec1(5); - - FuncCounter fc; - FuncCounter::Clear(); - vec1.insert(vec1.end(), fc); - EXPECT_EQ(FuncCounter::GetClearMoveConstructTimes(), 5); - EXPECT_EQ(FuncCounter::GetClearDestructTimes(), 5); - EXPECT_EQ(FuncCounter::GetClearCopyConstructTimes(), 1); - EXPECT_TRUE(FuncCounter::AllTimesZero()); -} - -TEST_F(SmallVectorUt, MoveInsert_CallCorrectConstructor) { - SmallVector vec1(5); - - FuncCounter fc1; - FuncCounter::Clear(); - vec1.insert(vec1.end(), std::move(fc1)); - EXPECT_EQ(FuncCounter::GetClearMoveConstructTimes(), 1); - EXPECT_TRUE(FuncCounter::AllTimesZero()); - - FuncCounter fc2; - FuncCounter::Clear(); - vec1.insert(vec1.begin(), std::move(fc2)); - EXPECT_EQ(FuncCounter::GetClearMoveConstructTimes(), 7); - EXPECT_EQ(FuncCounter::GetClearDestructTimes(), 6); - EXPECT_TRUE(FuncCounter::AllTimesZero()); -} - -TEST_F(SmallVectorUt, InsertMultiple_CallCorrectConstructor) { - SmallVector vec1(5); - - FuncCounter fc; - FuncCounter::Clear(); - vec1.insert(vec1.end(), 3, fc); - EXPECT_EQ(FuncCounter::GetClearCopyConstructTimes(), 3); - EXPECT_TRUE(FuncCounter::AllTimesZero()); - - FuncCounter::Clear(); - vec1.insert(vec1.begin(), 3, fc); - EXPECT_EQ(FuncCounter::GetClearCopyConstructTimes(), 3); - EXPECT_EQ(FuncCounter::GetClearMoveConstructTimes(), 8); - EXPECT_EQ(FuncCounter::GetClearDestructTimes(), 8); - EXPECT_TRUE(FuncCounter::AllTimesZero()); -} -TEST_F(SmallVectorUt, ClearOk) { - SmallVector vec1; - SmallVector vec2(2); - SmallVector vec3(5); - SmallVector vec4(6); - SmallVector vec5(10); - - vec1.clear(); - EXPECT_EQ(vec1.size(), 0); - EXPECT_EQ(vec1.capacity(), 5); - - vec2.clear(); - EXPECT_EQ(vec2.size(), 0); - EXPECT_EQ(vec2.capacity(), 5); - - vec3.clear(); - EXPECT_EQ(vec3.size(), 0); - EXPECT_EQ(vec3.capacity(), 5); - - vec4.clear(); - EXPECT_EQ(vec4.size(), 0); - EXPECT_EQ(vec4.capacity(), 5); - - vec5.clear(); - EXPECT_EQ(vec5.size(), 0); - EXPECT_EQ(vec5.capacity(), 5); -} - -TEST_F(SmallVectorUt, At) { - SmallVector vec2(2); - SmallVector vec3(3); - - for (int32_t i = 0; i < 10; ++i) { - vec2.at(0).Set(i, i); - vec2.at(1).Set(i, i * 10 + 1); - - vec3.at(0).Set(i, i * 100 + 10); - vec3.at(1).Set(i, i * 100 + 11); - vec3.at(2).Set(i, i * 100 + 12); - } - - const SmallVector &read_vec2 = vec2; - const SmallVector &read_vec3 = vec3; - - for (int32_t i = 0; i < 10; ++i) { - EXPECT_EQ(read_vec2.at(0).Get(i), i); - EXPECT_EQ(read_vec2.at(1).Get(i), i * 10 + 1); - - EXPECT_EQ(read_vec3.at(0).Get(i), i * 100 + 10); - EXPECT_EQ(read_vec3.at(1).Get(i), i * 100 + 11); - EXPECT_EQ(read_vec3.at(2).Get(i), i * 100 + 12); - } -} - -TEST_F(SmallVectorUt, BeginAndEnd) { - SmallVector vec0; - SmallVector vec1(1); - SmallVector vec2(2); - SmallVector vec3(3); - - size_t iter_count = 0; - for (auto iter = vec3.begin(); iter != vec3.end(); ++iter) { - EXPECT_EQ(&*iter, &vec3.at(iter_count)); - iter_count++; - } - EXPECT_EQ(iter_count, 3); - - iter_count = 0; - for (auto iter = vec2.begin(); iter != vec2.end(); ++iter) { - EXPECT_EQ(&*iter, &vec2.at(iter_count)); - iter_count++; - } - EXPECT_EQ(iter_count, 2); - - iter_count = 0; - for (auto iter = vec1.begin(); iter != vec1.end(); ++iter) { - EXPECT_EQ(&*iter, &vec1.at(iter_count)); - iter_count++; - } - EXPECT_EQ(iter_count, 1); - - iter_count = 0; - for (auto iter = vec0.begin(); iter != vec0.end(); ++iter) { - iter_count++; - } - EXPECT_TRUE(iter_count == 0); -} - -TEST_F(SmallVectorUt, CBeginAndCEnd) { - SmallVector vec0; - SmallVector vec1(1); - SmallVector vec2(2); - SmallVector vec3(3); - - size_t iter_count = 0; - for (auto iter = vec3.cbegin(); iter != vec3.cend(); ++iter) { - EXPECT_EQ(&*iter, &vec3.at(iter_count)); - iter_count++; - } - EXPECT_EQ(iter_count, 3); - - iter_count = 0; - for (auto iter = vec2.cbegin(); iter != vec2.cend(); ++iter) { - EXPECT_EQ(&*iter, &vec2.at(iter_count)); - iter_count++; - } - EXPECT_EQ(iter_count, 2); - - iter_count = 0; - for (auto iter = vec1.cbegin(); iter != vec1.cend(); ++iter) { - EXPECT_EQ(&*iter, &vec1.at(iter_count)); - iter_count++; - } - EXPECT_EQ(iter_count, 1); - - iter_count = 0; - for (auto iter = vec0.cbegin(); iter != vec0.cend(); ++iter) { - iter_count++; - } - EXPECT_TRUE(iter_count == 0); -} - -TEST_F(SmallVectorUt, BeginAndEnd_Const) { - const SmallVector vec0; - const SmallVector vec1(1); - const SmallVector vec2(2); - const SmallVector vec3(3); - - size_t iter_count = 0; - for (auto iter = vec3.begin(); iter != vec3.end(); ++iter) { - EXPECT_EQ(&*iter, &vec3.at(iter_count)); - iter_count++; - } - EXPECT_EQ(iter_count, 3); - - iter_count = 0; - for (auto iter = vec2.begin(); iter != vec2.end(); ++iter) { - EXPECT_EQ(&*iter, &vec2.at(iter_count)); - iter_count++; - } - EXPECT_EQ(iter_count, 2); - - iter_count = 0; - for (auto iter = vec1.begin(); iter != vec1.end(); ++iter) { - EXPECT_EQ(&*iter, &vec1.at(iter_count)); - iter_count++; - } - EXPECT_EQ(iter_count, 1); - - iter_count = 0; - for (auto iter = vec0.begin(); iter != vec0.end(); ++iter) { - iter_count++; - } - EXPECT_TRUE(iter_count == 0); -} - -TEST_F(SmallVectorUt, CBeginAndCEnd_Const) { - const SmallVector vec0; - const SmallVector vec1(1); - const SmallVector vec2(2); - const SmallVector vec3(3); - - size_t iter_count = 0; - for (auto iter = vec3.cbegin(); iter != vec3.cend(); ++iter) { - EXPECT_EQ(&*iter, &vec3.at(iter_count)); - iter_count++; - } - EXPECT_EQ(iter_count, 3); - - iter_count = 0; - for (auto iter = vec2.cbegin(); iter != vec2.cend(); ++iter) { - EXPECT_EQ(&*iter, &vec2.at(iter_count)); - iter_count++; - } - EXPECT_EQ(iter_count, 2); - - iter_count = 0; - for (auto iter = vec1.cbegin(); iter != vec1.cend(); ++iter) { - EXPECT_EQ(&*iter, &vec1.at(iter_count)); - iter_count++; - } - EXPECT_EQ(iter_count, 1); - - iter_count = 0; - for (auto iter = vec0.cbegin(); iter != vec0.cend(); ++iter) { - iter_count++; - } - EXPECT_TRUE(iter_count == 0); -} - -TEST_F(SmallVectorUt, Front1) { - const SmallVector vec1(1); - const SmallVector vec2(4); - const SmallVector vec3(8); - EXPECT_EQ(&vec1.front(), &vec1.at(0)); - EXPECT_EQ(&vec2.front(), &vec2.at(0)); - EXPECT_EQ(&vec3.front(), &vec3.at(0)); - - SmallVector vec1_1(1); - SmallVector vec1_2(4); - SmallVector vec1_3(8); - EXPECT_EQ(&vec1_1.front(), &vec1_1.at(0)); - EXPECT_EQ(&vec1_2.front(), &vec1_2.at(0)); - EXPECT_EQ(&vec1_3.front(), &vec1_3.at(0)); -} - -TEST_F(SmallVectorUt, Back1) { - const SmallVector vec1(1); - const SmallVector vec2(4); - const SmallVector vec3(8); - EXPECT_EQ(&vec1.back(), &vec1.at(0)); - EXPECT_EQ(&vec2.back(), &vec2.at(3)); - EXPECT_EQ(&vec3.back(), &vec3.at(7)); - - SmallVector vec1_1(1); - SmallVector vec1_2(4); - SmallVector vec1_3(8); - EXPECT_EQ(&vec1_1.back(), &vec1_1.at(0)); - EXPECT_EQ(&vec1_2.back(), &vec1_2.at(3)); - EXPECT_EQ(&vec1_3.back(), &vec1_3.at(7)); -} - -TEST_F(SmallVectorUt, FrontAndBack) { - const SmallVector vec1(1); - const SmallVector vec2(4); - const SmallVector vec3(8); - EXPECT_EQ(&vec1.front(), &vec1.back()); - EXPECT_NE(&vec2.front(), &vec2.back()); - EXPECT_NE(&vec3.front(), &vec3.back()); - - SmallVector vec1_1(1); - SmallVector vec1_2(4); - SmallVector vec1_3(8); - EXPECT_EQ(&vec1_1.front(), &vec1_1.back()); - EXPECT_NE(&vec1_2.front(), &vec1_2.back()); - EXPECT_NE(&vec1_3.front(), &vec1_3.back()); -} - -TEST_F(SmallVectorUt, RCIter_Const) { - const SmallVector vec0; - const SmallVector vec1(1); - const SmallVector vec2(2); - const SmallVector vec3(3); - - size_t iter_count = 3; - for (auto iter = vec3.crbegin(); iter != vec3.crend(); ++iter) { - EXPECT_EQ(&*iter, &vec3.at(--iter_count)); - } - EXPECT_EQ(iter_count, 0); - - iter_count = 2; - for (auto iter = vec2.crbegin(); iter != vec2.crend(); ++iter) { - EXPECT_EQ(&*iter, &vec2.at(--iter_count)); - } - EXPECT_EQ(iter_count, 0); - - iter_count = 1; - for (auto iter = vec1.crbegin(); iter != vec1.crend(); ++iter) { - EXPECT_EQ(&*iter, &vec1.at(--iter_count)); - } - EXPECT_EQ(iter_count, 0); - - iter_count = 0; - for (auto iter = vec0.crbegin(); iter != vec0.crend(); ++iter) { - iter_count++; - } - EXPECT_TRUE(iter_count == 0); -} - -TEST_F(SmallVectorUt, RCIter_Const1) { - const SmallVector vec0; - const SmallVector vec1(1); - const SmallVector vec2(2); - const SmallVector vec3(3); - - size_t iter_count = 3; - for (auto iter = vec3.rbegin(); iter != vec3.rend(); ++iter) { - EXPECT_EQ(&*iter, &vec3.at(--iter_count)); - } - EXPECT_EQ(iter_count, 0); - - iter_count = 2; - for (auto iter = vec2.rbegin(); iter != vec2.rend(); ++iter) { - EXPECT_EQ(&*iter, &vec2.at(--iter_count)); - } - EXPECT_EQ(iter_count, 0); - - iter_count = 1; - for (auto iter = vec1.rbegin(); iter != vec1.rend(); ++iter) { - EXPECT_EQ(&*iter, &vec1.at(--iter_count)); - } - EXPECT_EQ(iter_count, 0); - - iter_count = 0; - for (auto iter = vec0.rbegin(); iter != vec0.rend(); ++iter) { - iter_count++; - } - EXPECT_TRUE(iter_count == 0); -} - -TEST_F(SmallVectorUt, RCIter1) { - SmallVector vec0; - SmallVector vec1(1); - SmallVector vec2(2); - SmallVector vec3(3); - - size_t iter_count = 3; - for (auto iter = vec3.rbegin(); iter != vec3.rend(); ++iter) { - EXPECT_EQ(&*iter, &vec3.at(--iter_count)); - } - EXPECT_EQ(iter_count, 0); - - iter_count = 2; - for (auto iter = vec2.rbegin(); iter != vec2.rend(); ++iter) { - EXPECT_EQ(&*iter, &vec2.at(--iter_count)); - } - EXPECT_EQ(iter_count, 0); - - iter_count = 1; - for (auto iter = vec1.rbegin(); iter != vec1.rend(); ++iter) { - EXPECT_EQ(&*iter, &vec1.at(--iter_count)); - } - EXPECT_EQ(iter_count, 0); - - iter_count = 0; - for (auto iter = vec0.rbegin(); iter != vec0.rend(); ++iter) { - iter_count++; - } - EXPECT_TRUE(iter_count == 0); -} - -TEST_F(SmallVectorUt, EmptyOk) { - const SmallVector vec0; - const SmallVector vec1(1); - const SmallVector vec2(2); - const SmallVector vec3(3); - - EXPECT_TRUE(vec0.empty()); - EXPECT_FALSE(vec1.empty()); - EXPECT_FALSE(vec2.empty()); - EXPECT_FALSE(vec3.empty()); -} - -TEST_F(SmallVectorUt, SizeOk) { - const SmallVector vec0; - const SmallVector vec1(1); - const SmallVector vec2(2); - const SmallVector vec3(3); - - EXPECT_EQ(vec0.size(), 0); - EXPECT_EQ(vec1.size(), 1); - EXPECT_EQ(vec2.size(), 2); - EXPECT_EQ(vec3.size(), 3); -} - -void RandomB(InlineStructB &b) { - for (int32_t i = 0; i < 10; ++i) { - b.Set(i, rand()); - } -} - -TEST_F(SmallVectorUt, InsertFront) { - SmallVector vec0; - SmallVector vec1(1); - SmallVector vec2(2); - SmallVector vec3(3); - - InlineStructB b; - for (int32_t i = 0; i < 10; ++i) { - b.Set(i, i * 10); - } - InlineStructB b1; - RandomB(b1); - InlineStructB b2; - RandomB(b2); - - auto iter = vec0.insert(vec0.cbegin(), b); - EXPECT_EQ(*iter, b); - EXPECT_EQ(vec0.size(), 1); - - vec1[0] = b1; - iter = vec1.insert(vec1.cbegin(), b); - EXPECT_EQ(*iter, b); - EXPECT_EQ(vec1[0], b); - EXPECT_EQ(vec1[1], b1); - EXPECT_EQ(vec1.size(), 2); - - vec2[0] = b1; - vec2[1] = b2; - iter = vec2.insert(vec2.cbegin(), b); - EXPECT_EQ(*iter, b); - EXPECT_EQ(vec2[0], b); - EXPECT_EQ(vec2[1], b1); - EXPECT_EQ(vec2[2], b2); - EXPECT_EQ(vec2.size(), 3); - - vec3[1] = b1; - vec3[2] = b2; - iter = vec3.insert(vec3.cbegin(), b); - EXPECT_EQ(*iter, b); - EXPECT_EQ(vec3[0], b); - EXPECT_EQ(vec3[2], b1); - EXPECT_EQ(vec3[3], b2); - EXPECT_EQ(vec3.size(), 4); -} - -TEST_F(SmallVectorUt, InsertEnd) { - SmallVector vec0; - SmallVector vec1(1); - SmallVector vec2(2); - SmallVector vec3(3); - - InlineStructB b; - for (int32_t i = 0; i < 10; ++i) { - b.Set(i, i * 10); - } - InlineStructB b1; - RandomB(b1); - InlineStructB b2; - RandomB(b2); - - auto iter = vec0.insert(vec0.end(), b); - EXPECT_EQ(*iter, b); - EXPECT_EQ(vec0.size(), 1); - - vec1[0] = b1; - iter = vec1.insert(vec1.end(), b); - EXPECT_EQ(*iter, b); - EXPECT_EQ(vec1[1], b); - EXPECT_EQ(vec1[0], b1); - EXPECT_EQ(vec1.size(), 2); - - vec2[0] = b1; - vec2[1] = b2; - iter = vec2.insert(vec2.end(), b); - EXPECT_EQ(*iter, b); - EXPECT_EQ(vec2[2], b); - EXPECT_EQ(vec2[0], b1); - EXPECT_EQ(vec2[1], b2); - EXPECT_EQ(vec2.size(), 3); - - vec3[1] = b1; - vec3[2] = b2; - iter = vec3.insert(vec3.end(), b); - EXPECT_EQ(*iter, b); - EXPECT_EQ(vec3[3], b); - EXPECT_EQ(vec3[1], b1); - EXPECT_EQ(vec3[2], b2); - EXPECT_EQ(vec3.size(), 4); -} - -TEST_F(SmallVectorUt, InsertMid) { - SmallVector vec2(2); - SmallVector vec3(3); - - InlineStructB b; - for (int32_t i = 0; i < 10; ++i) { - b.Set(i, i * 10); - } - InlineStructB b1; - RandomB(b1); - InlineStructB b2; - RandomB(b2); - - vec2[0] = b1; - vec2[1] = b2; - auto iter = vec2.insert(vec2.begin() + 1, b); // b1, b, b2 - EXPECT_EQ(*iter, b); - EXPECT_EQ(vec2[1], b); - EXPECT_EQ(vec2[0], b1); - EXPECT_EQ(vec2[2], b2); - EXPECT_EQ(vec2.size(), 3); - - vec3[1] = b1; - vec3[2] = b2; - iter = vec3.insert(vec3.begin() + 1, b); // xx, b, b1, b2 - EXPECT_EQ(*iter, b); - EXPECT_EQ(vec3[1], b); - EXPECT_EQ(vec3[2], b1); - EXPECT_EQ(vec3[3], b2); - EXPECT_EQ(vec3.size(), 4); -} - -TEST_F(SmallVectorUt, InsertMid_Move) { - SmallVector vec3(3); - - InlineStructB b; - for (int32_t i = 0; i < 10; ++i) { - b.Set(i, i * 10); - } - InlineStructB b1; - RandomB(b1); - InlineStructB b2; - RandomB(b2); - InlineStructB b_back = b; - auto b_back_p = b.GetP(); - - vec3[1] = b1; - vec3[2] = b2; - auto iter = vec3.insert(vec3.begin() + 1, std::move(b)); // xx, b, b1, b2 - EXPECT_EQ(vec3.size(), 4); - EXPECT_EQ(*iter, b_back); - EXPECT_EQ(iter->GetP(), b_back_p); - EXPECT_EQ(vec3[1], b_back); - EXPECT_EQ(vec3[2], b1); - EXPECT_EQ(vec3[3], b2); -} - -TEST_F(SmallVectorUt, InsertMid_Multiple) { - SmallVector, 3> vec2(2); - SmallVector, 3> vec3(3); - SmallVector, 3> vec4(4); - - std::vector b{1, 2, 3, 4, 5}; - std::vector b1{6, 7, 8, 9, 10}; - std::vector b2{11, 12, 13, 14, 15}; - - vec2[0] = b1; - vec2[1] = b2; - auto iter = vec2.insert(vec2.begin() + 1, 3, b); // b1, b, b, b, b2 - EXPECT_EQ(vec2.size(), 5); - EXPECT_EQ(*iter, b); - EXPECT_EQ(vec2[0], b1); - EXPECT_EQ(vec2[1], b); - EXPECT_EQ(vec2[2], b); - EXPECT_EQ(vec2[3], b); - EXPECT_EQ(vec2[4], b2); - - vec3[1] = b1; - vec3[2] = b2; - iter = vec3.insert(vec3.begin() + 1, 3, b); // xx, b, b, b, b1, b2 - EXPECT_EQ(vec3.size(), 6); - EXPECT_EQ(*iter, b); - EXPECT_EQ(vec3[1], b); - EXPECT_EQ(vec3[2], b); - EXPECT_EQ(vec3[3], b); - EXPECT_EQ(vec3[4], b1); - EXPECT_EQ(vec3[5], b2); - - vec4[0] = b1; - vec4[1] = b2; - vec4[2] = b1; - vec4[3] = b2; - iter = vec4.insert(vec4.begin() + 2, 1, b); // b1,b2,b,b1,b2 - EXPECT_EQ(vec4.size(), 5); - EXPECT_EQ(*iter, b); - EXPECT_EQ(vec4[0], b1); - EXPECT_EQ(vec4[1], b2); - EXPECT_EQ(vec4[2], b); - EXPECT_EQ(vec4[3], b1); - EXPECT_EQ(vec4[4], b2); -} - -TEST_F(SmallVectorUt, InsertMid_Multiple1) { - SmallVector vec2(2); - SmallVector vec3(3); - SmallVector vec4(4); - - InlineStructB b; - for (int32_t i = 0; i < 10; ++i) { - b.Set(i, i * 10); - } - InlineStructB b1; - RandomB(b1); - InlineStructB b2; - RandomB(b2); - - vec2[0] = b1; - vec2[1] = b2; - auto iter = vec2.insert(vec2.begin() + 1, 3, b); // b1, b, b, b, b2 - EXPECT_EQ(vec2.size(), 5); - EXPECT_EQ(*iter, b); - EXPECT_EQ(vec2[0], b1); - EXPECT_EQ(vec2[1], b); - EXPECT_EQ(vec2[2], b); - EXPECT_EQ(vec2[3], b); - EXPECT_EQ(vec2[4], b2); - - vec3[1] = b1; - vec3[2] = b2; - iter = vec3.insert(vec3.begin() + 1, 3, b); // xx, b, b, b, b1, b2 - EXPECT_EQ(vec3.size(), 6); - EXPECT_EQ(*iter, b); - EXPECT_EQ(vec3[1], b); - EXPECT_EQ(vec3[2], b); - EXPECT_EQ(vec3[3], b); - EXPECT_EQ(vec3[4], b1); - EXPECT_EQ(vec3[5], b2); - - vec4[0] = b1; - vec4[1] = b2; - vec4[2] = b1; - vec4[3] = b2; - iter = vec4.insert(vec4.begin() + 2, 1, b); // b1,b2,b,b1,b2 - EXPECT_EQ(vec4.size(), 5); - EXPECT_EQ(*iter, b); - EXPECT_EQ(vec4[0], b1); - EXPECT_EQ(vec4[1], b2); - EXPECT_EQ(vec4[2], b); - EXPECT_EQ(vec4[3], b1); - EXPECT_EQ(vec4[4], b2); -} - -TEST_F(SmallVectorUt, InsertMid_Multiple2) { - SmallVector, 3> vec2(2); - - std::vector b{1, 2, 3, 4, 5}; - std::vector b1{6, 7, 8, 9, 10}; - std::vector b2{11, 12, 13, 14, 15}; - - vec2[0] = b1; - vec2[1] = b2; - auto iter = vec2.insert(vec2.begin() + 1, 8, b); // b1, b[8], b2 - EXPECT_EQ(vec2.size(), 10); - EXPECT_GE(vec2.capacity(), 10); - EXPECT_EQ(*iter, b); - EXPECT_EQ(vec2[0], b1); - for (auto i = 0; i < 8; ++i) { - EXPECT_EQ(vec2[i + 1], b); - } - EXPECT_EQ(vec2[9], b2); -} - -TEST_F(SmallVectorUt, InsertMid_Multiple_NotExpandCap) { - SmallVector, 8> vec2(4); - - std::vector b{1, 2, 3, 4, 5}; - std::vector b1{6, 7, 8, 9, 10}; - std::vector b2{11, 12, 13, 14, 15}; - std::vector b3{16, 17, 18, 19, 20}; - std::vector b4{21, 22, 13, 14, 15}; - vec2[0] = b1; - vec2[1] = b2; - vec2[2] = b3; - vec2[3] = b4; - auto iter = vec2.insert(vec2.begin() + 1, 4, b); // b1, b[4], b2 - EXPECT_EQ(vec2.size(), 8); - EXPECT_EQ(vec2.capacity(), 8); - EXPECT_EQ(*iter, b); - EXPECT_EQ(vec2[0], b1); - EXPECT_EQ(vec2[1], b); - EXPECT_EQ(vec2[2], b); - EXPECT_EQ(vec2[3], b); - EXPECT_EQ(vec2[4], b); - EXPECT_EQ(vec2[5], b2); - EXPECT_EQ(vec2[6], b3); - EXPECT_EQ(vec2[7], b4); -} - -TEST_F(SmallVectorUt, InsertListOk) { - SmallVector vec1; - vec1.insert(vec1.end(), {1, 2, 3, 4, 5}); - int64_t expect1[] = {1,2,3,4,5}; - EXPECT_EQ(vec1.size(), 5); - EXPECT_EQ(memcmp(vec1.data(), expect1, sizeof(expect1)), 0); - - vec1.insert(vec1.begin(), {10, 20, 30}); - int64_t expect2[] = {10,20,30,1,2,3,4,5}; - EXPECT_EQ(vec1.size(), 8); - EXPECT_EQ(memcmp(vec1.data(), expect2, sizeof(expect2)), 0); - - // expand - vec1.insert(vec1.begin() + 3, {100, 20, 30}); - int64_t expect3[] = {10,20,30,100,20,30,1,2,3,4,5}; - EXPECT_EQ(vec1.size(), 11); - EXPECT_GE(vec1.capacity(), 11); - EXPECT_EQ(memcmp(vec1.data(), expect3, sizeof(expect3)), 0); -} - -TEST_F(SmallVectorUt, EmplaceOk) { - SmallVector, 8> vec2(4); - - std::vector b_1{1, 1, 1, 1, 1}; - std::vector b1{6, 7, 8, 9, 10}; - std::vector b2{11, 12, 13, 14, 15}; - std::vector b3{16, 17, 18, 19, 20}; - std::vector b4{21, 22, 13, 14, 15}; - vec2[0] = b1; - vec2[1] = b2; - vec2[2] = b3; - vec2[3] = b4; - auto iter = vec2.emplace(vec2.begin() + 1, 5, 1); // b1, b[4], b2 - EXPECT_EQ(vec2.size(), 5); - EXPECT_EQ(vec2.capacity(), 8); - EXPECT_EQ(*iter, b_1); - EXPECT_EQ(vec2[0], b1); - EXPECT_EQ(vec2[1], b_1); - EXPECT_EQ(vec2[2], b2); - EXPECT_EQ(vec2[3], b3); - EXPECT_EQ(vec2[4], b4); -} - -TEST_F(SmallVectorUt, EraseOk) { - SmallVector vec1{1,2,3,4,5,6,7}; - - vec1.erase(vec1.begin()); - EXPECT_EQ(vec1.size(), 6); - int64_t vec1_expect_1[] = {2,3,4,5,6,7}; - EXPECT_EQ(memcmp(vec1.data(), vec1_expect_1, sizeof(vec1_expect_1)), 0); - - vec1.erase(vec1.begin() + 2); - EXPECT_EQ(vec1.size(), 5); - int64_t vec1_expect_2[] = {2,3,5,6,7}; - EXPECT_EQ(memcmp(vec1.data(), vec1_expect_2, sizeof(vec1_expect_2)), 0); -} - -TEST_F(SmallVectorUt, EraseAllOk) { - SmallVector vec1{1,2,3,4,5,6,7}; - - vec1.erase(vec1.begin(), vec1.end()); - EXPECT_EQ(vec1.size(), 0); -} - -TEST_F(SmallVectorUt, EraseEmptyOk) { - SmallVector vec1{1,2,3,4,5,6,7}; - - vec1.erase(vec1.begin(), vec1.begin()); - EXPECT_EQ(vec1.size(), 7); - int64_t vec1_expect_1[] = {1,2,3,4,5,6,7}; - EXPECT_EQ(memcmp(vec1.data(), vec1_expect_1, sizeof(vec1_expect_1)), 0); -} - -TEST_F(SmallVectorUt, Erase_CallCorrectConstructor) { - SmallVector vec1(11); - - FuncCounter::Clear(); - vec1.erase(vec1.begin()); - EXPECT_EQ(FuncCounter::GetClearDestructTimes(), 11); - EXPECT_EQ(FuncCounter::GetClearMoveConstructTimes(), 10); - EXPECT_TRUE(FuncCounter::AllTimesZero()); -} - -TEST_F(SmallVectorUt, PopBackOk) { - SmallVector vec1{1,2,3,4,5,6,7}; - - vec1.pop_back(); - EXPECT_EQ(vec1.size(), 6); - int64_t expect[] = {1,2,3,4,5,6}; - EXPECT_EQ(memcmp(vec1.data(), expect, sizeof(expect)), 0); -} - -TEST_F(SmallVectorUt, PopBack_CallDestructor) { - SmallVector vec1(7); - - FuncCounter::Clear(); - vec1.pop_back(); - EXPECT_EQ(vec1.size(), 6); - EXPECT_EQ(FuncCounter::GetClearDestructTimes(), 1); - EXPECT_TRUE(FuncCounter::AllTimesZero()); -} - -TEST_F(SmallVectorUt, ResizeOk1) { - SmallVector vec1{1,2,3,4,5,6,7}; - - vec1.resize(9); - int64_t expect_1[] = {1,2,3,4,5,6,7,0,0}; - EXPECT_EQ(vec1.size(), 9); - EXPECT_EQ(memcmp(vec1.data(), expect_1, sizeof(expect_1)), 0); - - // expand - vec1.resize(11); - int64_t expect_2[] = {1,2,3,4,5,6,7,0,0,0,0}; - EXPECT_EQ(vec1.size(), 11); - EXPECT_EQ(memcmp(vec1.data(), expect_2, sizeof(expect_2)), 0); - - // expand again - auto next_size = vec1.capacity() + 1; - vec1.resize(next_size); - auto expect_3 = std::unique_ptr(new int64_t[next_size]()); - for (int64_t i = 0; i < 7; ++i) { - expect_3[i] = i + 1; - } - EXPECT_EQ(vec1.size(), next_size); - EXPECT_EQ(memcmp(vec1.data(), expect_3.get(), sizeof(int64_t) * next_size), 0); -} - -TEST_F(SmallVectorUt, ResizeOk2) { - SmallVector vec1{1,2,3,4,5,6,7}; - - vec1.resize(5); - int64_t expect_1[] = {1,2,3,4,5}; - EXPECT_EQ(vec1.size(), 5); - EXPECT_EQ(memcmp(vec1.data(), expect_1, sizeof(expect_1)), 0); - - vec1.resize(7); - int64_t expect_2[] = {1,2,3,4,5,0,0}; - EXPECT_EQ(vec1.size(), 7); - EXPECT_EQ(memcmp(vec1.data(), expect_2, sizeof(expect_2)), 0); -} - -TEST_F(SmallVectorUt, ResizeOk1_CallCorrectConstructor) { - SmallVector vec1(7); - - FuncCounter::Clear(); - vec1.resize(9); - EXPECT_EQ(FuncCounter::GetClearConstructTimes(), 2); - EXPECT_TRUE(FuncCounter::AllTimesZero()); - - // expand - vec1.resize(11); - EXPECT_EQ(FuncCounter::GetClearConstructTimes(), 2); - EXPECT_EQ(FuncCounter::GetClearMoveConstructTimes(), 9); - EXPECT_EQ(FuncCounter::GetClearDestructTimes(), 9); - EXPECT_TRUE(FuncCounter::AllTimesZero()); - - // expand again - auto next_size = vec1.capacity() + 1; - vec1.resize(next_size); - auto expect_3 = std::unique_ptr(new int64_t[next_size]()); - for (int64_t i = 0; i < 7; ++i) { - expect_3[i] = i + 1; - } - EXPECT_EQ(FuncCounter::GetClearConstructTimes(), next_size - 11); - EXPECT_EQ(FuncCounter::GetClearMoveConstructTimes(), 11); - EXPECT_EQ(FuncCounter::GetClearDestructTimes(), 11); - EXPECT_TRUE(FuncCounter::AllTimesZero()); -} - -TEST_F(SmallVectorUt, PushBackOk1) { - SmallVector vec1({1,2,3}); - vec1.push_back(10); - EXPECT_EQ(vec1.size(), 4); - int64_t expect_1[] = {1,2,3,10}; - EXPECT_EQ(memcmp(vec1.data(), expect_1, sizeof(expect_1)), 0); -} - -TEST_F(SmallVectorUt, PushBackExpandOk2) { - SmallVector vec1({1,2,3,4,5}); - vec1.push_back(10); - EXPECT_EQ(vec1.size(), 6); - int64_t expect_1[] = {1,2,3,4,5,10}; - EXPECT_EQ(memcmp(vec1.data(), expect_1, sizeof(expect_1)), 0); -} - -TEST_F(SmallVectorUt, PushBackExpandOk3) { - SmallVector vec1({1,2,3,4,5,6}); - vec1.push_back(10); - EXPECT_EQ(vec1.size(), 7); - int64_t expect_1[] = {1,2,3,4,5,6,10}; - EXPECT_EQ(memcmp(vec1.data(), expect_1, sizeof(expect_1)), 0); -} - -TEST_F(SmallVectorUt, PushBack_CallCorrectConstructor) { - SmallVector vec1(4); - FuncCounter fc; - - FuncCounter::Clear(); - vec1.push_back(fc); - EXPECT_EQ(FuncCounter::GetClearCopyConstructTimes(), 1); - EXPECT_TRUE(FuncCounter::AllTimesZero()); - - FuncCounter::Clear(); - vec1.push_back(fc); - EXPECT_EQ(FuncCounter::GetClearCopyConstructTimes(), 1); - EXPECT_EQ(FuncCounter::GetClearMoveConstructTimes(), 5); - EXPECT_EQ(FuncCounter::GetClearDestructTimes(), 5); - EXPECT_TRUE(FuncCounter::AllTimesZero()); - - FuncCounter::Clear(); - vec1.push_back(std::move(fc)); - EXPECT_EQ(FuncCounter::GetClearMoveConstructTimes(), 1); - EXPECT_TRUE(FuncCounter::AllTimesZero()); -} - -TEST_F(SmallVectorUt, SwapOk1) { - SmallVector vec1{1,2,3,4,5000}; - SmallVector vec2{6,7,8,9000}; - - vec1.swap(vec2); - - EXPECT_EQ(vec1.size(), 4); - int32_t expect_2[] = {6,7,8,9000}; - EXPECT_EQ(memcmp(vec1.data(), expect_2, sizeof(expect_2)), 0); - - EXPECT_EQ(vec2.size(), 5); - int32_t expect_1[] = {1,2,3,4,5000}; - EXPECT_EQ(memcmp(vec2.data(), expect_1, sizeof(expect_1)), 0); -} - -TEST_F(SmallVectorUt, SwapOk2) { - SmallVector vec1{1,2,3,4,5000}; - SmallVector vec2{6,7,8,9000}; - - vec1.swap(vec2); - - EXPECT_EQ(vec1.size(), 4); - int32_t expect_2[] = {6,7,8,9000}; - EXPECT_EQ(memcmp(vec1.data(), expect_2, sizeof(expect_2)), 0); - - EXPECT_EQ(vec2.size(), 5); - int32_t expect_1[] = {1,2,3,4,5000}; - EXPECT_EQ(memcmp(vec2.data(), expect_1, sizeof(expect_1)), 0); -} - -TEST_F(SmallVectorUt, SwapOk3) { - SmallVector vec1{1,2,3,4,5000}; - SmallVector vec2{6,7,8,9000}; - - vec1.swap(vec2); - - EXPECT_EQ(vec1.size(), 4); - int32_t expect_2[] = {6,7,8,9000}; - EXPECT_EQ(memcmp(vec1.data(), expect_2, sizeof(expect_2)), 0); - - EXPECT_EQ(vec2.size(), 5); - int32_t expect_1[] = {1,2,3,4,5000}; - EXPECT_EQ(memcmp(vec2.data(), expect_1, sizeof(expect_1)), 0); -} - -TEST_F(SmallVectorUt, SwapOk3_ConstructTimes) { - SmallVector vec1(5); - SmallVector vec2(4); - - FuncCounter::Clear(); - vec1.swap(vec2); - // vec1的指针被转移到vec2,vec1的对象没有操作 - // vec2的对象移动到vec1,vec2的原对象析构 - EXPECT_EQ(FuncCounter::GetClearMoveConstructTimes(), 4); - EXPECT_EQ(FuncCounter::GetClearDestructTimes(), 4); - EXPECT_TRUE(FuncCounter::AllTimesZero()); - - std::swap(vec1, vec2); - EXPECT_EQ(FuncCounter::GetClearMoveConstructTimes(), 4); - EXPECT_EQ(FuncCounter::GetClearDestructTimes(), 4); - EXPECT_TRUE(FuncCounter::AllTimesZero()); -} - -TEST_F(SmallVectorUt, CompareOperator) { - SmallVector vec1({1,2,3,4}); - SmallVector vec2({1,2,3,4}); - SmallVector vec3({1,2,3,4,5}); - SmallVector vec4({1,2,3,5}); - SmallVector vec5({1,2,3,5}); - SmallVector vec6({1,2,3,4,5}); - - EXPECT_TRUE(vec1 == vec2); - EXPECT_FALSE(vec1 == vec3); - EXPECT_FALSE(vec1 == vec4); - EXPECT_FALSE(vec1 == vec5); - EXPECT_FALSE(vec1 == vec6); - EXPECT_TRUE(vec1 <= vec2); - EXPECT_TRUE(vec1 >= vec2); - EXPECT_FALSE(vec1 != vec2); - EXPECT_TRUE(vec1 != vec3); - EXPECT_FALSE(vec1 == vec3); - - EXPECT_TRUE(vec1 < vec3); - EXPECT_TRUE(vec1 <= vec3); - EXPECT_TRUE(vec3 < vec4); - EXPECT_TRUE(vec3 <= vec4); -} - -TEST_F(SmallVectorUt, ReserveOk1) { - SmallVector vec1; - EXPECT_EQ(vec1.size(), 0); - EXPECT_EQ(vec1.capacity(), 4); - - vec1.reserve(3); - EXPECT_EQ(vec1.size(), 0); - EXPECT_EQ(vec1.capacity(), 4); - - vec1.reserve(5); - EXPECT_EQ(vec1.size(), 0); - EXPECT_GE(vec1.capacity(), 5); -} - -TEST_F(SmallVectorUt, ReserveOk2) { - SmallVector vec1{1,2,3}; - EXPECT_EQ(vec1.size(), 3); - EXPECT_EQ(vec1.capacity(), 4); - - vec1.reserve(3); - EXPECT_EQ(vec1.size(), 3); - EXPECT_EQ(vec1.capacity(), 4); - - vec1.reserve(5); - EXPECT_EQ(vec1.size(), 3); - EXPECT_GE(vec1.capacity(), 5); - EXPECT_EQ(vec1[0], 1); - EXPECT_EQ(vec1[1], 2); - EXPECT_EQ(vec1[2], 3); -} -} // namespace ge - -namespace test_open { - using ge::SmallVector; - class OpenSmallVectorUt : public testing::Test {}; - - TEST_F(OpenSmallVectorUt, CompareOperator) { - SmallVector vec1({1,2,3,4}); - SmallVector vec2({1,2,3,4}); - SmallVector vec3({1,2,3,4,5}); - SmallVector vec4({1,2,3,5}); - SmallVector vec5({1,2,3,5}); - SmallVector vec6({1,2,3,4,5}); - - EXPECT_TRUE(vec1 == vec2); - EXPECT_FALSE(vec1 == vec3); - EXPECT_FALSE(vec1 == vec4); - EXPECT_FALSE(vec1 == vec5); - EXPECT_FALSE(vec1 == vec6); - EXPECT_TRUE(vec1 <= vec2); - EXPECT_TRUE(vec1 >= vec2); - EXPECT_FALSE(vec1 != vec2); - EXPECT_TRUE(vec1 != vec3); - EXPECT_FALSE(vec1 == vec3); - - EXPECT_TRUE(vec1 < vec3); - EXPECT_TRUE(vec1 <= vec3); - EXPECT_TRUE(vec3 < vec4); - EXPECT_TRUE(vec3 <= vec4); -} -} // namespace test_open diff --git a/tests/ut/graph/testcase/tensor_parallel_attrs_unittest.cc b/tests/ut/graph/testcase/tensor_parallel_attrs_unittest.cc deleted file mode 100644 index e4399914c8b8ba2f6ae5705773472cdf2be11718..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/tensor_parallel_attrs_unittest.cc +++ /dev/null @@ -1,953 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "nlohmann/json.hpp" - -#include "graph/parallelism/tensor_parallel_attrs.h" -#include "external/ge_common/ge_api_error_codes.h" -#include "common/ge_common/ge_inner_error_codes.h" - -using namespace testing; -namespace ge { -namespace tp { -using Json = nlohmann::json; - -class TensorParallelAttrsTest : public testing::Test { - protected: - static void TestToAndFromJson(const CommTask &comm_task, CommTask &out_comm_task) { - ReshardAttr reshard_attr; - OutputReshardRes output_reshard_res; - CommStep comm_step; - comm_step.id = 0; - comm_step.comm_task = comm_task; - output_reshard_res.comm_steps.emplace_back(comm_step); - reshard_attr.reshard_infos.emplace_back(std::vector{output_reshard_res}); - const auto &json_str = TensorParallelAttrs::ToJson(reshard_attr); - ASSERT_TRUE(!json_str.empty()); - std::cout << json_str << std::endl; - ReshardAttr reshard_attr_from_json; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, reshard_attr_from_json), SUCCESS); - out_comm_task = reshard_attr_from_json.reshard_infos[0][0].comm_steps[0].comm_task; - } -}; - -TEST_F(TensorParallelAttrsTest, ParseFailed_InvalidJsonStr) { - std::string json_str = "invalid"; - DeviceIndex device_index; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, device_index), PARAM_INVALID); -} - -TEST_F(TensorParallelAttrsTest, ParseFailed_FieldMismatches) { - std::string json_str = R"( - {"engine_type": "NPU", "index": 0} -)"; - DeviceIndex device_index; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, device_index), PARAM_INVALID); -} - -TEST_F(TensorParallelAttrsTest, DeviceIndex_ToAndFromJsonStr) { - DeviceIndex device_index; - device_index.indices = {0, 1, 2}; - device_index.engine_type = "MyEngine"; - std::string json_str = TensorParallelAttrs::ToJson(device_index); - DeviceIndex another_device_index; - TensorParallelAttrs::FromJson(json_str, another_device_index); - EXPECT_EQ(device_index, another_device_index); -} - -TEST_F(TensorParallelAttrsTest, ModelIndex_ToAndFromJsonStr) { - ModelIndex model_index; - model_index.stage_id = 0; - model_index.virtual_stage_id = 0; - model_index.device_index.indices = {0, 1, 2}; - model_index.device_index.engine_type = "MyEngine"; - std::string json_str = TensorParallelAttrs::ToJson(model_index); - ModelIndex another_model_index; - TensorParallelAttrs::FromJson(json_str, another_model_index); - EXPECT_TRUE(model_index.DebugString() == "MyEngine[0, 1, 2][S0, V0]"); - EXPECT_EQ(model_index, another_model_index); -} - -TEST_F(TensorParallelAttrsTest, ModelIndex_NotEqual) { - ModelIndex model_index; - model_index.stage_id = 0; - model_index.virtual_stage_id = 0; - model_index.device_index.indices = {0, 1, 2}; - model_index.device_index.engine_type = "MyEngine"; - std::string json_str = TensorParallelAttrs::ToJson(model_index); - ModelIndex another_model_index(model_index); - another_model_index.virtual_stage_id = 1; - EXPECT_TRUE(model_index != another_model_index); -} - -TEST_F(TensorParallelAttrsTest, ModelIndex_LessBigger) { - ModelIndex model_index; - model_index.stage_id = 0; - model_index.virtual_stage_id = 1; - model_index.device_index.indices = {0, 1, 2}; - model_index.device_index.engine_type = "MyEngine"; - std::string json_str = TensorParallelAttrs::ToJson(model_index); - ModelIndex another_model_index(model_index); - another_model_index.virtual_stage_id = 2; - EXPECT_TRUE(model_index < another_model_index); - EXPECT_FALSE(another_model_index < model_index); - another_model_index.virtual_stage_id = 1; - EXPECT_FALSE(another_model_index < model_index); -} - -TEST_F(TensorParallelAttrsTest, PipelineConfig_ToAndFromJsonStr) { - PipelineConfig pipeline_config; - pipeline_config.micro_batch = 1; - pipeline_config.stage_id = 0; - pipeline_config.virtual_stage_id = {0, 1}; - std::string json_str = TensorParallelAttrs::ToJson(pipeline_config); - PipelineConfig another_pipeline_config; - TensorParallelAttrs::FromJson(json_str, another_pipeline_config); - EXPECT_EQ(pipeline_config.micro_batch, another_pipeline_config.micro_batch); - EXPECT_EQ(pipeline_config.stage_id, another_pipeline_config.stage_id); - EXPECT_EQ(pipeline_config.virtual_stage_id, another_pipeline_config.virtual_stage_id); -} - -TEST_F(TensorParallelAttrsTest, DeviceIndex_operators) { - std::map device_index_to_value; - DeviceIndex device_index; - device_index.indices = {0, 1, 2}; - device_index.engine_type = "MyEngine"; - - DeviceIndex another_device_index; - device_index.indices = {0, 1, 2}; - device_index.engine_type = "CPU"; - ASSERT_TRUE(device_index_to_value.emplace(device_index, 0).second); - ASSERT_FALSE(device_index_to_value.emplace(device_index, 0).second); - ASSERT_FALSE(device_index.DebugString().empty()); - ASSERT_EQ(device_index_to_value.count(another_device_index), 0U); - ASSERT_NE(device_index, another_device_index); - ASSERT_EQ(device_index, device_index); -} - -TEST_F(TensorParallelAttrsTest, NodeDeployment_ToAndFromJsonStr) { - DeviceIndex device_index_0; - device_index_0.indices = {0, 0, 1}; - device_index_0.engine_type = "CPU"; - - DeviceIndex device_index_1; - device_index_1.indices = {0, 0, 2}; - device_index_1.engine_type = "NPU"; - - NodeDeployment node_deployment; - node_deployment.devices = {device_index_0, device_index_1}; - std::string json_str = TensorParallelAttrs::ToJson(node_deployment); - NodeDeployment another_node_deployment; - TensorParallelAttrs::FromJson(json_str, another_node_deployment); - EXPECT_EQ(node_deployment.devices, another_node_deployment.devices); -} - -TEST_F(TensorParallelAttrsTest, ParseSendRecvTaskInfo) { - const std::string &json_str = - R"( -{ - "task_type": "SendReceive", - "comm_pairs": [ - { - "src_device_index": {"engine_type": "NPU", "index": [0, 0, 1]}, - "src_virtual_stage_id": 0, - "dst_device_index": {"engine_type": "NPU", "index": [0, 0, 2]}, - "dst_virtual_stage_id": 0 - } - ], - "comm_type": "Queue", - "flow_attr": { - "depth":128, - "enqueue_policy":"FIFO" - } -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - ASSERT_TRUE(comm_task.send_recv_reshard_task != nullptr); - ASSERT_EQ(comm_task.send_recv_reshard_task->comm_pairs.size(), 1U); - EXPECT_EQ(comm_task.send_recv_reshard_task->comm_pairs[0].src_device_index.indices, (std::vector{0, 0, 1})); - EXPECT_EQ(comm_task.send_recv_reshard_task->comm_pairs[0].dst_device_index.indices, (std::vector{0, 0, 2})); - EXPECT_EQ(comm_task.send_recv_reshard_task->comm_type, kSendRecvCommTypeQueue); - EXPECT_EQ(comm_task.send_recv_reshard_task->flow_attr.depth, 128); - EXPECT_EQ(comm_task.send_recv_reshard_task->flow_attr.enqueue_policy, kFlowAttrEnqueuePolicyFifo); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); -} - -TEST_F(TensorParallelAttrsTest, ParseAllGatherCommTask) { - const std::string &json_str = - R"( -{ - "task_type": "HcomAllGather", - "parallel_group": "-1", - "output_allocator": "BufferPool", - "comm_groups": [ - [ - {"engine_type": "NPU", "index": [0, 0, 0]}, - {"engine_type": "NPU", "index": [0, 0, 1]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 2]}, - {"engine_type": "NPU", "index": [0, 0, 3]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 4]}, - {"engine_type": "NPU", "index": [0, 0, 5]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 6]}, - {"engine_type": "NPU", "index": [0, 0, 7]} - ] - ], - "axis": 0 -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - ASSERT_TRUE(out_comm_task.all_gather_reshard_task != nullptr); - EXPECT_EQ(out_comm_task.all_gather_reshard_task->comm_groups.size(), 4); - EXPECT_EQ(out_comm_task.all_gather_reshard_task->parallel_group, "-1"); - EXPECT_EQ(out_comm_task.all_gather_reshard_task->output_allocator, "BufferPool"); -} - -TEST_F(TensorParallelAttrsTest, ParseAllReduceCommTask) { - const std::string &json_str = - R"( -{ - "task_type": "HcomAllReduce", - "reduction": "sum", - "comm_groups": [ - [ - {"engine_type": "NPU", "index": [0, 0, 0]}, - {"engine_type": "NPU", "index": [0, 0, 1]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 2]}, - {"engine_type": "NPU", "index": [0, 0, 3]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 4]}, - {"engine_type": "NPU", "index": [0, 0, 5]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 6]}, - {"engine_type": "NPU", "index": [0, 0, 7]} - ] - ] -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - ASSERT_TRUE(comm_task.all_reduce_reshard_task != nullptr); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - EXPECT_EQ(out_comm_task.all_reduce_reshard_task->reduction, "sum"); - EXPECT_EQ(out_comm_task.all_reduce_reshard_task->comm_groups.size(), 4); -} - -TEST_F(TensorParallelAttrsTest, ParseAllReduceMeanCommTask) { - const std::string &json_str = - R"( -{ - "task_type": "HcomAllReduceMean", - "comm_groups": [ - [ - {"engine_type": "NPU", "index": [0, 0, 0]}, - {"engine_type": "NPU", "index": [0, 0, 1]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 2]}, - {"engine_type": "NPU", "index": [0, 0, 3]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 4]}, - {"engine_type": "NPU", "index": [0, 0, 5]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 6]}, - {"engine_type": "NPU", "index": [0, 0, 7]} - ] - ], - "axis": 0, - "value": 2 -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - ASSERT_TRUE(comm_task.all_reduce_mean_reshard_task != nullptr); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - EXPECT_EQ(out_comm_task.all_reduce_mean_reshard_task->comm_groups.size(), 4); -} - -TEST_F(TensorParallelAttrsTest, ParseReduceScatterCommTask) { - const std::string &json_str = - R"( -{ - "task_type": "HcomReduceScatter", - "reduction": "sum", - "comm_groups": [ - [ - {"engine_type": "NPU", "index": [0, 0, 0]}, - {"engine_type": "NPU", "index": [0, 0, 1]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 2]}, - {"engine_type": "NPU", "index": [0, 0, 3]} - ] - ] -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - ASSERT_TRUE(comm_task.reduce_scatter_reshard_task != nullptr); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - EXPECT_EQ(out_comm_task.reduce_scatter_reshard_task->reduction, "sum"); - EXPECT_EQ(out_comm_task.reduce_scatter_reshard_task->comm_groups.size(), 2); -} - -TEST_F(TensorParallelAttrsTest, ParseAllToAllCommTask) { - const std::string &json_str = - R"( -{ - "task_type": "HcomAllToAll", - "comm_groups": [ - [ - {"engine_type": "NPU", "index": [0, 0, 0]}, - {"engine_type": "NPU", "index": [0, 0, 1]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 2]}, - {"engine_type": "NPU", "index": [0, 0, 3]} - ] - ] -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - ASSERT_TRUE(out_comm_task.all_to_all_reshard_task != nullptr); - EXPECT_EQ(out_comm_task.all_to_all_reshard_task->comm_groups.size(), 2); -} - -TEST_F(TensorParallelAttrsTest, ParseSliceCommTask) { - const std::string &json_str = - R"( -{ - "task_type": "Slice", - "offsets": [2, 4], - "size": [4, 8], - "device_index":{"engine_type": "NPU", "index": [0]} -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - ASSERT_TRUE(out_comm_task.slice_reshard_task != nullptr); - ASSERT_EQ(out_comm_task.slice_reshard_task->offsets, (std::vector{2, 4})); - ASSERT_EQ(out_comm_task.slice_reshard_task->sizes, (std::vector{4, 8})); - ASSERT_EQ(out_comm_task.slice_reshard_task->device_index.engine_type, "NPU"); - ASSERT_EQ(out_comm_task.slice_reshard_task->device_index.indices, (std::vector{0})); -} - -TEST_F(TensorParallelAttrsTest, ParseSliceByAxisCommTask) { - CommTask comm_task; - comm_task.task_type = "SliceByAxis"; -// ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - comm_task.slice_by_axis_reshard_task = std::make_shared(); - auto &axis_to_slice_deployments = comm_task.slice_by_axis_reshard_task->axis_to_slice_deployments; - std::vector dim_0_slice_0_deployments{DeviceIndex{"NPU", {0, 0, 0}}, DeviceIndex{"NPU", {0, 0, 1}}}; - std::vector dim_0_slice_1_deployments{DeviceIndex{"NPU", {0, 0, 2}}, DeviceIndex{"NPU", {0, 0, 3}}}; - std::vector dim_1_slice_0_deployments{DeviceIndex{"NPU", {0, 1, 0}}, DeviceIndex{"NPU", {0, 1, 1}}}; - std::vector dim_1_slice_1_deployments{DeviceIndex{"NPU", {0, 1, 2}}, DeviceIndex{"NPU", {0, 1, 3}}}; - - axis_to_slice_deployments[0].emplace_back(dim_0_slice_0_deployments); - axis_to_slice_deployments[0].emplace_back(dim_0_slice_1_deployments); - axis_to_slice_deployments[1].emplace_back(dim_1_slice_0_deployments); - axis_to_slice_deployments[2].emplace_back(dim_1_slice_1_deployments); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - ASSERT_TRUE(out_comm_task.slice_by_axis_reshard_task != nullptr); -} - -TEST_F(TensorParallelAttrsTest, ParseSplitCommTask) { - const std::string &json_str = - R"( -{ - "task_type": "Split", - "split_dim": 1, - "num_split": 2 -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - ASSERT_TRUE(out_comm_task.split_reshard_task != nullptr); - ASSERT_EQ(out_comm_task.split_reshard_task->split_dim, 1); - ASSERT_EQ(out_comm_task.split_reshard_task->num_split, 2); -} - -TEST_F(TensorParallelAttrsTest, ParseConcatCommTask) { - const std::string &json_str = - R"( -{ - "task_type": "Concat", - "concat_dim": 1 -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - ASSERT_TRUE(out_comm_task.concat_reshard_task != nullptr); - ASSERT_EQ(out_comm_task.concat_reshard_task->concat_dim, 1); -} - -TEST_F(TensorParallelAttrsTest, ParseUniqueConcatCommTask) { - const std::string &json_str = - R"( -{ - "task_type": "UniqueConcat", - "unique_id": "0:1", - "concat_dim": 1, - "src_device_indices": [ - {"engine_type": "NPU", "index": [0, 0, 0]}, - {"engine_type": "NPU", "index": [0, 0, 1]}, - {"engine_type": "NPU", "index": [0, 0, 2]}, - {"engine_type": "NPU", "index": [0, 0, 3]} - ], - "dst_device_index": {"engine_type": "HOST_CPU", "index": [0, 0, 1]} -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - ASSERT_TRUE(out_comm_task.unique_concat_reshard_task != nullptr); - ASSERT_EQ(out_comm_task.unique_concat_reshard_task->concat_dim, 1); - ASSERT_EQ(out_comm_task.unique_concat_reshard_task->src_device_indices.size(), 4); - DeviceIndex device_index{"HOST_CPU", {0, 0, 1}}; - ASSERT_EQ(out_comm_task.unique_concat_reshard_task->dst_device_index, device_index); -} - -TEST_F(TensorParallelAttrsTest, ParseTransposeTaskInfo) { - const std::string &json_str = - R"( -{ - "task_type": "Transpose", - "perm": [1, 0, 2, 3] -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - ASSERT_TRUE(out_comm_task.transpose_reshard_task != nullptr); - ASSERT_EQ(out_comm_task.transpose_reshard_task->perm, (std::vector{1, 0, 2, 3})); -} - -TEST_F(TensorParallelAttrsTest, ParseReshapeTaskInfo) { - const std::string &json_str = - R"( -{ - "task_type": "Reshape", - "shape": [1, 1, 2, 3] -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - ASSERT_TRUE(out_comm_task.reshape_reshard_task != nullptr); - ASSERT_EQ(out_comm_task.reshape_reshard_task->shape, (std::vector{1, 1, 2, 3})); -} - -TEST_F(TensorParallelAttrsTest, ParseCastTaskInfo) { - const std::string &json_str = - R"( -{ - "task_type": "Cast", - "dst_type": 1 -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - ASSERT_TRUE(out_comm_task.cast_reshard_task != nullptr); - ASSERT_EQ(out_comm_task.cast_reshard_task->dst_type, DT_FLOAT16); -} - -TEST_F(TensorParallelAttrsTest, ParseModifyValueCommTask) { - const std::string &json_str = R"( -{ - "task_type": "ModifyValue", - "op_type": "Mul", - "value": [1, 2] -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - ASSERT_TRUE(out_comm_task.modify_value_reshard_task != nullptr); - ASSERT_EQ(out_comm_task.modify_value_reshard_task->op_type, "Mul"); - ASSERT_EQ(out_comm_task.modify_value_reshard_task->value, (std::vector{1, 2})); -} - -TEST_F(TensorParallelAttrsTest, ParseBroadcastCommTask) { - const std::string &json_str = - R"( -{ - "task_type": "HcomBroadcast", - "roots": [ - {"engine_type": "NPU", "index": [0, 0, 0]}, - {"engine_type": "NPU", "index": [0, 0, 2]}, - {"engine_type": "NPU", "index": [0, 0, 4]}, - {"engine_type": "NPU", "index": [0, 0, 6]} - ], - "comm_groups": [ - [ - {"engine_type": "NPU", "index": [0, 0, 0]}, - {"engine_type": "NPU", "index": [0, 0, 1]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 2]}, - {"engine_type": "NPU", "index": [0, 0, 3]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 4]}, - {"engine_type": "NPU", "index": [0, 0, 5]} - ], - [ - {"engine_type": "NPU", "index": [0, 0, 6]}, - {"engine_type": "NPU", "index": [0, 0, 7]} - ] - ] -} -)"; - CommTask comm_task; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_task), SUCCESS); - - CommTask out_comm_task; - TestToAndFromJson(comm_task, out_comm_task); - ASSERT_TRUE(comm_task.broadcast_reshard_task != nullptr); - std::vector root_device_index; - root_device_index.emplace_back(DeviceIndex{"NPU", {0, 0, 0}}); - root_device_index.emplace_back(DeviceIndex{"NPU", {0, 0, 2}}); - root_device_index.emplace_back(DeviceIndex{"NPU", {0, 0, 4}}); - root_device_index.emplace_back(DeviceIndex{"NPU", {0, 0, 6}}); - EXPECT_EQ(out_comm_task.broadcast_reshard_task->root_device_indices, root_device_index); - EXPECT_EQ(out_comm_task.broadcast_reshard_task->comm_groups.size(), 4); -} - -TEST_F(TensorParallelAttrsTest, ParseParseCommStep) { - const std::string &json_str = - R"( -{ - "id": 2, - "input_ids": [[0, 0], [1, 0]], - "comm_task": { - "task_type": "Split", - "num_split": 2, - "split_dim": 1 - } -} - )"; - CommStep comm_step; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, comm_step), SUCCESS); - EXPECT_TRUE(comm_step.comm_task.split_reshard_task != nullptr); - EXPECT_EQ(comm_step.id, 2); -} - -TEST_F(TensorParallelAttrsTest, ParseTensorReshardInfo) { - const std::string &json_str = - R"( -{ - "output_index": 1, - "device_list": [ - {"engine_type": "NPU", "index": [0, 0, 0]}, - {"engine_type": "NPU", "index": [0, 0, 2]}, - {"engine_type": "NPU", "index": [0, 0, 3]}, - {"engine_type": "NPU", "index": [0, 0, 1]} - ], - "comm_steps": [ - { - "id": 1, - "input_ids": [], - "comm_task": { - "task_type": "SplitVD", - "size_splits": [2, 4], - "split_dim": 1 - } - }, - { - "id": 2, - "input_ids": [[1, 0]], - "comm_task": { - "task_type": "SplitVD", - "size_splits": [2, 4], - "split_dim": 1 - } - } - ], - "peer_inputs": [ - {"step_id": 1, "node_name": "dst_node", "input_index": 0, "stage_id": 0, "virtual_stage_id": 0}, - {"step_id": 1, "node_name": "dst_node_1", "input_index": 1, "stage_id": 0, "virtual_stage_id": 0} - ], - "stage_id":1, - "virtual_stage_id":2 -} -)"; - OutputReshardRes tensor_reshard_info; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, tensor_reshard_info), SUCCESS); - EXPECT_EQ(tensor_reshard_info.comm_steps.size(), 2); - ASSERT_EQ(tensor_reshard_info.peer_inputs.size(), 2); - EXPECT_EQ(tensor_reshard_info.peer_inputs[0].step_id, 1); - EXPECT_EQ(tensor_reshard_info.peer_inputs[0].node_name, "dst_node"); - EXPECT_EQ(tensor_reshard_info.peer_inputs[0].input_index, 0); - EXPECT_EQ(tensor_reshard_info.peer_inputs[1].node_name, "dst_node_1"); - EXPECT_EQ(tensor_reshard_info.peer_inputs[1].step_id, 1); - EXPECT_EQ(tensor_reshard_info.peer_inputs[1].input_index, 1); - EXPECT_EQ(tensor_reshard_info.stage_id, 1); - EXPECT_EQ(tensor_reshard_info.virtual_stage_id, 2); -} - -TEST_F(TensorParallelAttrsTest, ReshardAttrToAndFromJson) { - const std::string &json_str = - R"( -[ - [ - { - "device_list": [ - {"engine_type": "NPU", "index": [0, 0, 0]}, - {"engine_type": "NPU", "index": [0, 0, 2]}, - {"engine_type": "NPU", "index": [0, 0, 3]}, - {"engine_type": "NPU", "index": [0, 0, 1]} - ], - "comm_steps": [ - { - "id": 1, - "input_ids": [], - "comm_task": { - "task_type": "Split", - "num_split": 3, - "split_dim": 1 - } - }, - { - "id": 2, - "input_ids": [[1, 0]], - "comm_task": { - "task_type": "Split", - "num_split": 4, - "split_dim": 1 - } - } - ], - "peer_inputs": [ - {"step_id": 1, "node_name": "dst_node", "input_index": 0, "stage_id": 0, "virtual_stage_id": 0}, - {"step_id": 1, "node_name": "dst_node_1", "input_index": 1, "stage_id": 0, "virtual_stage_id": 0} - ] - } - ] -] -)"; - ReshardAttr reshard_attr; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, reshard_attr), SUCCESS); - ASSERT_EQ(reshard_attr.reshard_infos.size(), 1); - auto str = TensorParallelAttrs::ToJson(reshard_attr); - ASSERT_EQ(TensorParallelAttrs::FromJson(str, reshard_attr), SUCCESS); -} - -TEST_F(TensorParallelAttrsTest, TensorDeploymentToAndFromJson) { - const std::string &json_str = - R"( -{ - "shard_deployment": { - "device_indices_each_slice": [ - [{"engine_type": "NPU", "index": [0, 0, 0]}], - [{"engine_type": "NPU", "index": [0, 0, 1]}], - [{"engine_type": "NPU", "index": [0, 0, 2]}], - [{"engine_type": "NPU", "index": [0, 0, 3]}] - ], - "axis_slices": [ - [[0, 2], [2, 4]], - [[0, 4], [4, 8]] - ] - }, - "verbose" : "verbose_val" -} -)"; - TensorDeployment tensor_deployment; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, tensor_deployment), SUCCESS); - const auto str = TensorParallelAttrs::ToJson(tensor_deployment); - - TensorDeployment tensor_deployment_from_json; - ASSERT_EQ(TensorParallelAttrs::FromJson(str, tensor_deployment_from_json), SUCCESS); - const auto &tensor_slice_deployment = tensor_deployment_from_json.shard_deployment; - EXPECT_EQ(tensor_slice_deployment.device_indices_each_slice.size(), 4); - EXPECT_EQ(tensor_slice_deployment.axis_slices.size(), 2); -} - -TEST_F(TensorParallelAttrsTest, TensorDeploymentsToAndFromJson) { - const std::string &json_str = - R"( -{ - "deployments": [ - [1, { - "shard_deployment": { - "axis_slices": [ - [[0, 2], [2, 4]], - [[0, 4], [4, 8]] - ], - "device_indices_each_slice": [ - [{"engine_type": "NPU", "index": [0, 0, 0]}], - [{"engine_type": "NPU", "index": [0, 0, 1]}], - [{"engine_type": "NPU", "index": [0, 0, 2]}], - [{"engine_type": "NPU", "index": [0, 0, 3]}] - ] - }, - "verbose": "verbose_val" - }], - [2, { - "shard_deployment": { - "axis_slices": [ - [[0, 2], [2, 4]], - [[0, 4], [4, 8]] - ], - "device_indices_each_slice": [ - [{"engine_type": "NPU", "index": [0, 1, 0]}], - [{"engine_type": "NPU", "index": [0, 1, 1]}], - [{"engine_type": "NPU", "index": [0, 1, 2]}], - [{"engine_type": "NPU", "index": [0, 1, 3]}] - ] - }, - "verbose": "verbose_val" - }] - ] -} -)"; - TensorDeployments tensor_deployments; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, tensor_deployments), SUCCESS); - const auto str = TensorParallelAttrs::ToJson(tensor_deployments); - ASSERT_EQ(tensor_deployments.deployments.size(), 2); - - TensorDeployments tensor_deployments_from_json; - ASSERT_EQ(TensorParallelAttrs::FromJson(str, tensor_deployments_from_json), SUCCESS); - const auto &tensor_slice_deployment = tensor_deployments_from_json.deployments[1].shard_deployment; - EXPECT_EQ(tensor_slice_deployment.device_indices_each_slice.size(), 4); - EXPECT_EQ(tensor_slice_deployment.axis_slices.size(), 2); -} - -TEST_F(TensorParallelAttrsTest, NodeDeploymentsToAndFromJson) { - const std::string &json_str = - R"( -{ - "deployments": [ - [1, { - "devices": [{ - "engine_type": "CPU", - "index": [0, 0, 1] - }, { - "engine_type": "NPU", - "index": [0, 0, 3] - }], - "pipeline_config": { - "micro_batch": 1, - "stage_id": 0, - "virtual_stage_id": [] - } - }], - [2, { - "devices": [{ - "engine_type": "", - "index": [] - }], - "pipeline_config": { - "micro_batch": 1, - "stage_id": 0, - "virtual_stage_id": [] - } - }] - ] -} -)"; - NodeDeployments node_deployments; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, node_deployments), SUCCESS); - const auto str = TensorParallelAttrs::ToJson(node_deployments); - ASSERT_EQ(node_deployments.deployments.size(), 2); - - NodeDeployments node_deployments_from_json; - ASSERT_EQ(TensorParallelAttrs::FromJson(str, node_deployments_from_json), SUCCESS); - const auto &devices = node_deployments_from_json.deployments[1].devices; - EXPECT_EQ(devices.size(), 2); - EXPECT_TRUE(devices[0].engine_type == "CPU"); - EXPECT_TRUE(devices[1].engine_type == "NPU"); -} - -TEST_F(TensorParallelAttrsTest, ShardGraphExtAttrsToAndFromJson) { - const std::string &json_str = - R"( -{ - "dev_index_to_logic_dev_id": [ - [{ - "engine_type": "NPU", - "index": [0, 0, 0] - }, - [1, 0, 0] - ], - [{ - "engine_type": "NPU", - "index": [0, 0, 1] - }, - [1, 0, 1] - ] - ], - "graph_name_to_endpoints": { - "test_graph1": { - "endpoint1": ["SerializedString1"], - "endpoint2": ["SerializedString2"] - }, - "test_graph2": { - "endpoint1": ["SerializedString3"], - "endpoint2": ["SerializedString4"] - } - }, - "group_name_to_dev_ids": { - "group1": ["0:0:0:0", "0:0:1:0"], - "group2": ["0:0:0:1", "0:0:1:1"] - } -} -)"; - ShardGraphExtAttrs shard_graph_ext_attrs; - ASSERT_EQ(TensorParallelAttrs::FromJson(json_str, shard_graph_ext_attrs), SUCCESS); - const auto str = TensorParallelAttrs::ToJson(shard_graph_ext_attrs); - ShardGraphExtAttrs shard_graph_ext_attrs_from_json; - ASSERT_EQ(TensorParallelAttrs::FromJson(str, shard_graph_ext_attrs_from_json), SUCCESS); - EXPECT_EQ(shard_graph_ext_attrs_from_json.graph_name_to_endpoints, shard_graph_ext_attrs.graph_name_to_endpoints); - EXPECT_EQ(shard_graph_ext_attrs_from_json.dev_index_to_logic_dev_id, - shard_graph_ext_attrs.dev_index_to_logic_dev_id); - EXPECT_EQ(shard_graph_ext_attrs_from_json.group_name_to_dev_ids, shard_graph_ext_attrs.group_name_to_dev_ids); -} - -TEST_F(TensorParallelAttrsTest, StructCmp) { - SrcNodeInfo src_node_info; - src_node_info.inserted_node_id = 0; - src_node_info.output_index = 0; - SrcNodeInfo src_node_info1; - src_node_info1.inserted_node_id = 0; - src_node_info1.output_index = 0; - EXPECT_EQ(src_node_info == src_node_info1, true); - SrcNodeInfo src_node_info2; - src_node_info2.inserted_node_id = 0; - src_node_info2.output_index = 1; - EXPECT_EQ(src_node_info < src_node_info2, true); - SrcNodeInfo src_node_info3; - src_node_info3.inserted_node_id = 1; - src_node_info3.output_index = 1; - EXPECT_EQ(src_node_info < src_node_info3, true); - EXPECT_EQ(src_node_info3 < src_node_info1, false); - - OrigNodeInfo orig_node_info; - orig_node_info.node_name = "node"; - orig_node_info.sliced_id = 0; - DstNodeInfo dst_node_info; - dst_node_info.orig_node_info = orig_node_info; - dst_node_info.input_indexes = {0}; - InsertedNodeInput inserted_node_input; - inserted_node_input.orig_node_info = orig_node_info; - inserted_node_input.input_info = src_node_info; - PeerOutNodeInfo peer_out_node_info; - peer_out_node_info.input_info = src_node_info; - peer_out_node_info.node_info = dst_node_info; - - OrigNodeInfo orig_node_info1; - orig_node_info1.node_name = "node"; - orig_node_info1.sliced_id = 0; - DstNodeInfo dst_node_info1; - dst_node_info1.orig_node_info = orig_node_info1; - dst_node_info1.input_indexes = {0}; - InsertedNodeInput inserted_node_input1; - inserted_node_input1.orig_node_info = orig_node_info1; - inserted_node_input1.input_info = src_node_info1; - PeerOutNodeInfo peer_out_node_info1; - peer_out_node_info1.input_info = src_node_info1; - peer_out_node_info1.node_info = dst_node_info1; - - EXPECT_EQ(orig_node_info1 == orig_node_info, true); - EXPECT_EQ(dst_node_info == dst_node_info1, true); - EXPECT_EQ(inserted_node_input == inserted_node_input1, true); - EXPECT_EQ(peer_out_node_info == peer_out_node_info1, true); - - OrigNodeInfo orig_node_info2; - orig_node_info2.node_name = "node1"; - orig_node_info2.sliced_id = 0; - DstNodeInfo dst_node_info2; - dst_node_info2.orig_node_info = orig_node_info2; - dst_node_info2.input_indexes = {0}; - InsertedNodeInput inserted_node_input2; - inserted_node_input2.orig_node_info = orig_node_info2; - inserted_node_input2.input_info = src_node_info2; - PeerOutNodeInfo peer_out_node_info2; - peer_out_node_info2.input_info = src_node_info2; - peer_out_node_info2.node_info = dst_node_info2; - - EXPECT_EQ(orig_node_info < orig_node_info2, true); - EXPECT_EQ(orig_node_info2 < orig_node_info, false); - EXPECT_EQ(dst_node_info < dst_node_info2, true); - EXPECT_EQ(dst_node_info2 < dst_node_info, false); - EXPECT_EQ(inserted_node_input < inserted_node_input2, true); - EXPECT_EQ(inserted_node_input2 < inserted_node_input, false); - EXPECT_EQ(peer_out_node_info < peer_out_node_info2, true); - EXPECT_EQ(peer_out_node_info2 < peer_out_node_info, false); - - OrigNodeInfo orig_node_info3; - orig_node_info3.node_name = "node"; - orig_node_info3.sliced_id = 1; - DstNodeInfo dst_node_info3; - dst_node_info3.orig_node_info = orig_node_info3; - dst_node_info3.input_indexes = {0, 1}; - InsertedNodeInput inserted_node_input3; - inserted_node_input3.orig_node_info = orig_node_info3; - inserted_node_input3.input_info = src_node_info3; - PeerOutNodeInfo peer_out_node_info3; - peer_out_node_info3.input_info = src_node_info3; - peer_out_node_info3.node_info = dst_node_info3; - - EXPECT_EQ(orig_node_info < orig_node_info3, true); - EXPECT_EQ(dst_node_info < dst_node_info3, true); - EXPECT_EQ(inserted_node_input < inserted_node_input3, true); - EXPECT_EQ(peer_out_node_info < peer_out_node_info3, true); -} -} // namespace tp -} // namespace ge diff --git a/tests/ut/graph/testcase/tensor_unittest.cc b/tests/ut/graph/testcase/tensor_unittest.cc deleted file mode 100644 index b9870ca5187be6fd81b68279f8c7dddaa363d080..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/tensor_unittest.cc +++ /dev/null @@ -1,441 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include "graph/ge_tensor.h" -#include "ge_ir.pb.h" -#include "graph/utils/tensor_utils.h" -#include "graph/normal_graph/ge_tensor_impl.h" -#include "external/graph/tensor.h" -#include -#include "graph/utils/tensor_adapter.h" - -namespace ge { -class TensorUtilsUT : public testing::Test { - protected: - void SetUp() { - } - void TearDown() {} -}; - -TEST_F(TensorUtilsUT, CopyConstruct1_NullTensorDef) { - GeTensor t1; - std::vector vec; - for (uint8_t i = 0; i < 100; ++i) { - vec.push_back(i * 2); - } - std::cout << "test1" << std::endl; - t1.SetData(vec); - GeTensor t2 = TensorUtils::CreateShareTensor(t1); - t1.impl_->tensor_def_.GetProtoOwner(); -// The copy construct share tensor_data_, do not share tensor_desc - ASSERT_EQ(t1.impl_->tensor_def_.GetProtoOwner(), nullptr); - ASSERT_EQ(t1.impl_->tensor_def_.GetProtoMsg(), nullptr); - ASSERT_EQ(t1.impl_->tensor_data_.impl_->tensor_descriptor_, t1.impl_->desc_.impl_); - ASSERT_EQ(t2.impl_->tensor_data_.impl_->tensor_descriptor_, t2.impl_->desc_.impl_); - ASSERT_EQ(t1.impl_->tensor_data_.GetData(), t2.impl_->tensor_data_.GetData()); - - t1.MutableTensorDesc().SetFormat(FORMAT_NCHW); - t2.MutableTensorDesc().SetFormat(FORMAT_NHWC); - ASSERT_EQ(t1.GetTensorDesc().GetFormat(), FORMAT_NCHW); - ASSERT_EQ(t2.GetTensorDesc().GetFormat(), FORMAT_NHWC); - - ASSERT_EQ(memcmp(t1.GetData().GetData(), vec.data(), vec.size()), 0); - ASSERT_EQ(t1.GetData().GetData(), t2.GetData().GetData()); -} - -TEST_F(TensorUtilsUT, CopyConstruct2_WithTensorDef) { - GeIrProtoHelper helper; - helper.InitDefault(); - helper.GetProtoMsg()->mutable_data()->resize(100); - GeTensor t1(helper.GetProtoOwner(), helper.GetProtoMsg()); - - std::vector vec; - for (uint8_t i = 0; i < 100; ++i) { - vec.push_back(i * 2); - } - t1.SetData(vec); - GeTensor t2 = TensorUtils::CreateShareTensor(t1); - - // The copy construct share tensor_data_ and tensor_desc - ASSERT_NE(t1.impl_->tensor_def_.GetProtoOwner(), nullptr); - ASSERT_NE(t1.impl_->tensor_def_.GetProtoMsg(), nullptr); - ASSERT_EQ(t1.impl_->tensor_data_.impl_->tensor_descriptor_, t1.impl_->desc_.impl_); - ASSERT_EQ(t2.impl_->tensor_data_.impl_->tensor_descriptor_, t2.impl_->desc_.impl_); - ASSERT_EQ(t1.impl_->tensor_data_.GetData(), t2.impl_->tensor_data_.GetData()); - - t1.MutableTensorDesc().SetFormat(FORMAT_NCHW); - ASSERT_EQ(t1.GetTensorDesc().GetFormat(), FORMAT_NCHW); - ASSERT_EQ(t2.GetTensorDesc().GetFormat(), FORMAT_NCHW); - t2.MutableTensorDesc().SetFormat(FORMAT_NHWC); - ASSERT_EQ(t1.GetTensorDesc().GetFormat(), FORMAT_NHWC); - ASSERT_EQ(t2.GetTensorDesc().GetFormat(), FORMAT_NHWC); - - ASSERT_EQ(memcmp(t1.GetData().GetData(), vec.data(), vec.size()), 0); - ASSERT_EQ(t1.GetData().GetData(), t2.GetData().GetData()); -} - -TEST_F(TensorUtilsUT, SetData_CreateShareTensorWithTensorDef) { - GeIrProtoHelper helper; - helper.InitDefault(); - helper.GetProtoMsg()->mutable_data()->resize(100); - GeTensor t1(helper.GetProtoOwner(), helper.GetProtoMsg()); - - std::vector vec; - for (uint8_t i = 0; i < 100; ++i) { - vec.push_back(i * 2); - } - t1.SetData(vec); - GeTensor t2 = TensorUtils::CreateShareTensor(t1); - - std::vector vec2; - for (uint8_t i = 0; i < 100; ++i) { - vec2.push_back(i); - } - t2.SetData(vec2); - ASSERT_EQ(memcmp(t2.GetData().GetData(), vec2.data(), vec2.size()), 0); - // todo 这里存在bug,但是从目前来看,并没有被触发,因此暂时不修复了,重构后一起修复。 - // 触发bug的场景为:如果tensor1是通过tensor_def_持有TensorData,然后通过拷贝构造、拷贝赋值的方式,从tensor1构造了tensor2。 - // 那么通过tensor2.SetData后,会导致tensor1的GetData接口失效(获取到野指针) - // 触发的表现就是,如下两条ASSERT_EQ并不成立 - // ASSERT_EQ(t1.GetData().GetData(), t2.GetData().GetData()); - // ASSERT_EQ(memcmp(t1.GetData().GetData(), vec2.data(), vec2.size()), 0); -} - -TEST_F(TensorUtilsUT, SetData_CreateShareTensorWithoutTensorDef) { - GeTensor t1; - - std::vector vec; - for (uint8_t i = 0; i < 100; ++i) { - vec.push_back(i * 2); - } - t1.SetData(vec); - GeTensor t2 = TensorUtils::CreateShareTensor(t1); - - std::vector vec3; - for (uint8_t i = 0; i < 100; ++i) { - vec3.push_back(i); - } - t2.SetData(vec3); - ASSERT_EQ(t2.GetData().size(), vec3.size()); - ASSERT_EQ(memcmp(t2.GetData().GetData(), vec3.data(), vec3.size()), 0); - ASSERT_EQ(t1.GetData().size(), vec3.size()); - ASSERT_EQ(memcmp(t1.GetData().GetData(), vec3.data(), vec3.size()), 0); - ASSERT_EQ(t1.GetData().GetData(), t2.GetData().GetData()); - - std::vector vec2; - for (uint8_t i = 0; i < 105; ++i) { - vec2.push_back(i); - } - t2.SetData(vec2); - ASSERT_EQ(t2.GetData().size(), vec2.size()); - ASSERT_EQ(memcmp(t2.GetData().GetData(), vec2.data(), vec2.size()), 0); - // after modify the data of t2 using a different size buffer, the t1 will not be modified - ASSERT_EQ(t1.GetData().size(), vec3.size()); - ASSERT_EQ(memcmp(t1.GetData().GetData(), vec3.data(), vec3.size()), 0); - ASSERT_NE(t1.GetData().GetData(), t2.GetData().GetData()); -} - -TEST_F(TensorUtilsUT, CreateShareTensorFromSharedPtr) { - auto ap = std::make_shared(100); - for (uint8_t i = 0; i < 100; ++i) { - ap->MutableGet()[i] = i; - } - GeTensorDesc td; - GeTensor t1 = TensorUtils::CreateShareTensor(td, ap, 100); - ASSERT_EQ(t1.GetData().GetData(), ap->MutableGet()); - ASSERT_EQ(t1.GetData().size(), 100); - - GeTensor t2(td, ap, 100); - ASSERT_EQ(t2.GetData().GetData(), ap->MutableGet()); - ASSERT_EQ(t2.GetData().size(), 100); -} - -TEST_F(TensorUtilsUT, ShareTensorData) { - auto ap = std::make_shared(100); - for (uint8_t i = 0; i < 100; ++i) { - ap->MutableGet()[i] = i; - } - GeTensorDesc td; - - GeTensor t1(td); - t1.SetData(ap, 100); - ASSERT_EQ(t1.GetData().GetData(), ap->MutableGet()); - ASSERT_EQ(t1.GetData().size(), 100); - - GeTensor t2(td); - TensorUtils::ShareAlignedPtr(ap, 100, t2); - ASSERT_EQ(t2.GetData().GetData(), ap->MutableGet()); - ASSERT_EQ(t2.GetData().size(), 100); -} - -TEST_F(TensorUtilsUT, CopyAssign_NullTensorDef) { - GeTensor t1; - - std::vector vec; - for (uint8_t i = 0; i < 100; ++i) { - vec.push_back(i * 2); - } - t1.SetData(vec); - GeTensor t2; - TensorUtils::ShareTensor(t1, t2); - - // The copy construct share tensor_data_, do not share tensor_desc - ASSERT_EQ(t1.impl_->tensor_def_.GetProtoOwner(), nullptr); - ASSERT_EQ(t1.impl_->tensor_def_.GetProtoMsg(), nullptr); - ASSERT_EQ(t1.impl_->tensor_data_.impl_->tensor_descriptor_, t1.impl_->desc_.impl_); - ASSERT_EQ(t2.impl_->tensor_data_.impl_->tensor_descriptor_, t2.impl_->desc_.impl_); - ASSERT_EQ(t1.impl_->tensor_data_.GetData(), t2.impl_->tensor_data_.GetData()); - - t1.MutableTensorDesc().SetFormat(FORMAT_NCHW); - t2.MutableTensorDesc().SetFormat(FORMAT_NHWC); - ASSERT_EQ(t1.GetTensorDesc().GetFormat(), FORMAT_NCHW); - ASSERT_EQ(t2.GetTensorDesc().GetFormat(), FORMAT_NHWC); - - ASSERT_EQ(memcmp(t1.GetData().GetData(), vec.data(), vec.size()), 0); - ASSERT_EQ(t1.GetData().GetData(), t2.GetData().GetData()); -} - -TEST_F(TensorUtilsUT, CopyConstruct3_TensorData) { - std::vector vec; - for (uint8_t i = 0; i < 200; ++i) { - vec.push_back(i); - } - TensorData td1; - td1.SetData(vec); - - TensorData td2(td1); - ASSERT_EQ(td1.GetData(), td2.GetData()); - ASSERT_EQ(td1.GetSize(), td2.GetSize()); - ASSERT_EQ(td1.GetSize(), 200); - - TensorData td3 = TensorUtils::CreateShareTensorData(td1); - ASSERT_EQ(td1.GetData(), td3.GetData()); - ASSERT_EQ(td1.GetSize(), td3.GetSize()); - ASSERT_EQ(td1.GetSize(), 200); -} - -TEST_F(TensorUtilsUT, CopyAssign_TensorData) { - std::vector vec; - for (uint8_t i = 0; i < 200; ++i) { - vec.push_back(i); - } - TensorData td1; - td1.SetData(vec); - - TensorData td2 = td1; - ASSERT_EQ(td1.GetData(), td2.GetData()); - ASSERT_EQ(td1.GetSize(), td2.GetSize()); - ASSERT_EQ(td1.GetSize(), 200); - - TensorData td3; - TensorUtils::ShareTensorData(td1, td3); - ASSERT_EQ(td1.GetData(), td3.GetData()); - ASSERT_EQ(td1.GetSize(), td3.GetSize()); - ASSERT_EQ(td1.GetSize(), 200); -} - -TEST_F(TensorUtilsUT, SetData_ShareAlignedPtr_TensorData) { - std::vector vec; - for (uint8_t i = 0; i < 200; ++i) { - vec.push_back(i); - } - auto ap = std::make_shared(vec.size()); - memcpy_s(ap->MutableGet(), vec.size(), vec.data(), vec.size()); - - TensorData td1; - td1.SetData(ap, vec.size()); - ASSERT_EQ(td1.GetData(), ap->MutableGet()); - ASSERT_EQ(td1.GetSize(), 200); - - TensorData td2; - TensorUtils::ShareAlignedPtr(ap, vec.size(), td2); - ASSERT_EQ(td2.GetData(), ap->MutableGet()); - ASSERT_EQ(td2.GetSize(), 200); -} - -TEST_F(TensorUtilsUT, ShareTheSame) { - std::vector vec; - for (uint8_t i = 0; i < 200; ++i) { - vec.push_back(i); - } - TensorData td1; - td1.SetData(vec); - TensorUtils::ShareTensorData(td1, td1); - ASSERT_EQ(memcmp(td1.GetData(), vec.data(), vec.size()), 0); - ASSERT_EQ(td1.GetSize(), 200); - - GeTensorDesc tensor_desc; - GeTensor t1(tensor_desc, vec); - TensorUtils::ShareTensor(t1, t1); - ASSERT_EQ(memcmp(t1.GetData().GetData(), vec.data(), vec.size()), 0); -} - -TEST_F(TensorUtilsUT, ConstData) { - std::unique_ptr const_data = std::unique_ptr(new (std::nothrow) uint8_t[10]); - ASSERT_NE(const_data, nullptr); - TensorDesc tensor_desc; - tensor_desc.SetFormat(FORMAT_NCHW); - tensor_desc.SetConstData(std::move(const_data), sizeof(int)); - uint8_t *ret = nullptr; - size_t len = 0; - tensor_desc.GetConstData(&ret, len); - printf("GetConstData1====%p\n", ret); - ASSERT_NE(ret, nullptr); - ASSERT_EQ(sizeof(int), len); - - // operator= - tensor_desc = tensor_desc; - TensorDesc tensor_desc1; - tensor_desc1 = tensor_desc; - uint8_t *ret1 = nullptr; - size_t len1 = 0; - tensor_desc1.GetConstData(&ret1, len1); - printf("GetConstData2 ====%p\n", ret1); - ASSERT_NE(ret1, nullptr); - ASSERT_NE(ret1, ret); - ASSERT_EQ(sizeof(int), len1); - ASSERT_EQ(tensor_desc1.GetFormat(), FORMAT_NCHW); - - // copy - std::size_t big_size = SECUREC_MEM_MAX_LEN+1; - std::unique_ptr big_data = std::unique_ptr(new (std::nothrow) uint8_t[big_size]); - TensorDesc tensor_desc2; - tensor_desc2.SetFormat(FORMAT_NCHW); - tensor_desc2.SetConstData(std::move(big_data), big_size); - TensorDesc tensor_desc3(tensor_desc2); - uint8_t *ret2 = nullptr; - size_t len2 = 0; - tensor_desc3.GetConstData(&ret2, len2); - printf("GetConstData3 ====%p\n", ret2); - ASSERT_NE(ret2, nullptr); - ASSERT_NE(ret2, ret); - ASSERT_EQ(big_size, len2); - ASSERT_EQ(tensor_desc3.GetFormat(), FORMAT_NCHW); -} -TEST_F(TensorUtilsUT, GetShapeSize_Ok_VectorMax) { - Shape shape({std::numeric_limits::max()}); - EXPECT_EQ(shape.GetShapeSize(), std::numeric_limits::max()); -} -TEST_F(TensorUtilsUT, GetShapeSize_ReturnZero_Overflow) { - Shape shape({2, std::numeric_limits::max() - 1}); - EXPECT_EQ(shape.GetShapeSize(), 0); -} -TEST_F(TensorUtilsUT, TensorConstruct_IsValid_Overflow) { - Shape shape({std::numeric_limits::max()}); - TensorDesc td; - td.SetDataType(DT_FLOAT); - td.SetShape(shape); - td.SetOriginShape(shape); - Tensor tensor(td, {}); - - // todo 这个行为挺奇怪的,即使发生了overflow,仍然返回success,不过历史实现一直是这样,不敢修改这个行为 - ASSERT_EQ(tensor.IsValid(), ge::GRAPH_SUCCESS); -} - -TEST_F(TensorUtilsUT, TensorSetAndGetMetaInfoGeneral) { - Tensor tensor; - tensor.SetOriginFormat(ge::FORMAT_NCHW); - EXPECT_EQ(tensor.GetOriginFormat(), ge::FORMAT_NCHW); - - tensor.SetFormat(ge::FORMAT_NC1HWC0); - EXPECT_EQ(tensor.GetFormat(), ge::FORMAT_NC1HWC0); - - tensor.SetDataType(ge::DT_BF16); - EXPECT_EQ(tensor.GetDataType(), ge::DT_BF16); - - tensor.SetOriginShapeDimNum(4); - EXPECT_EQ(tensor.GetOriginShapeDimNum(), 4); - for (size_t i = 0U; i < tensor.GetOriginShapeDimNum(); ++i) { - tensor.SetOriginShapeDim(i, i); - } - for (size_t i = 0U; i < tensor.GetOriginShapeDimNum(); ++i) { - EXPECT_EQ(tensor.GetOriginShapeDim(i), i); - } - tensor.SetShapeDimNum(5); - EXPECT_EQ(tensor.GetShapeDimNum(), 5); - for (size_t i = 0U; i < tensor.GetShapeDimNum(); ++i) { - tensor.SetShapeDim(i, i); - } - for (size_t i = 0U; i < tensor.GetShapeDimNum(); ++i) { - EXPECT_EQ(tensor.GetShapeDim(i), i); - } - - EXPECT_EQ(tensor.SetPlacement(Placement::kPlacementDevice), ge::GRAPH_SUCCESS); - EXPECT_EQ(tensor.GetPlacement(), Placement::kPlacementDevice); - - EXPECT_EQ(tensor.SetExpandDimsRule("0011"), ge::GRAPH_SUCCESS); - AscendString str; - EXPECT_EQ(tensor.GetExpandDimsRule(str), ge::GRAPH_SUCCESS); - EXPECT_EQ(str, "0011"); -} - -TEST_F(TensorUtilsUT, TensorSetAndResetData) { - Tensor tensor; - EXPECT_EQ(tensor.ResetData(nullptr, 0UL, [](uint8_t *ptr) {delete[] ptr;}), ge::GRAPH_SUCCESS); - - uint8_t *data_ptr = new uint8_t[10]; - EXPECT_EQ(tensor.ResetData(data_ptr, 10,[](uint8_t *ptr) {delete[] ptr;}), ge::GRAPH_SUCCESS); - EXPECT_EQ(tensor.GetData(), data_ptr); - - uint8_t *data_ptr2 = new uint8_t[20]; - EXPECT_EQ(tensor.ResetData(data_ptr2, 20, [](uint8_t *ptr) { delete[] ptr; }), ge::GRAPH_SUCCESS); - EXPECT_EQ(tensor.GetData(), data_ptr2); - tensor.ResetData().reset(); -} - -TEST_F(TensorUtilsUT, TensorSetAndGetMetaInfoAbnormal) { - Tensor tensor; - tensor.impl = nullptr; - - EXPECT_EQ(tensor.SetOriginFormat(ge::FORMAT_NCHW), ge::GRAPH_FAILED); - EXPECT_EQ(tensor.GetOriginFormat(), ge::FORMAT_RESERVED); - - EXPECT_EQ(tensor.SetFormat(ge::FORMAT_NC1HWC0), ge::GRAPH_FAILED); - EXPECT_EQ(tensor.GetFormat(), ge::FORMAT_RESERVED); - - EXPECT_EQ(tensor.SetDataType(ge::DT_BF16), ge::GRAPH_FAILED); - EXPECT_EQ(tensor.GetDataType(), ge::DT_UNDEFINED); - - EXPECT_EQ(tensor.SetOriginShapeDimNum(4), ge::GRAPH_FAILED); - EXPECT_EQ(tensor.GetOriginShapeDimNum(), 0); - EXPECT_EQ(tensor.SetOriginShapeDim(0, 0), ge::GRAPH_FAILED); - EXPECT_EQ(tensor.GetOriginShapeDim(0), 0); - EXPECT_EQ(tensor.SetShapeDimNum(5), ge::GRAPH_FAILED); - EXPECT_EQ(tensor.GetShapeDimNum(), 0); - EXPECT_EQ(tensor.SetShapeDim(0, 0), ge::GRAPH_FAILED); - EXPECT_EQ(tensor.GetShapeDim(0), 0); - - EXPECT_EQ(tensor.SetPlacement(Placement::kPlacementDevice), ge::GRAPH_FAILED); - EXPECT_EQ(tensor.GetPlacement(), Placement::kPlacementEnd); - - EXPECT_EQ(tensor.SetExpandDimsRule("0011"), ge::GRAPH_FAILED); - AscendString str; - EXPECT_EQ(tensor.GetExpandDimsRule(str), ge::GRAPH_FAILED); - EXPECT_NE(str, "0011"); - - uint8_t *data_ptr = new uint8_t[20]; - EXPECT_NE(tensor.ResetData(data_ptr, 20, [](uint8_t *ptr) { delete[] ptr; }), ge::GRAPH_SUCCESS); - delete[] data_ptr; -} - -TEST_F(TensorUtilsUT, SetReuseInputIndex) { - TensorDesc tensor_desc; - tensor_desc.SetReuseInputIndex(1); - auto ge_tensor_desc = TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc); - bool reuse_flag = false; - uint32_t reuse_index = 0; - TensorUtils::GetReuseInput(ge_tensor_desc, reuse_flag); - TensorUtils::GetReuseInputIndex(ge_tensor_desc, reuse_index); - EXPECT_EQ(reuse_flag, true); - EXPECT_EQ(reuse_index, 1); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/tensor_ut.cc b/tests/ut/graph/testcase/tensor_ut.cc deleted file mode 100644 index 148502b602644f4041d24a9e9d92ca354c79a051..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/tensor_ut.cc +++ /dev/null @@ -1,794 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include -#include "graph/ge_tensor.h" -#include "ge_ir.pb.h" -#include "graph/debug/ge_util.h" -#include "graph/normal_graph/ge_tensor_impl.h" -#include "graph/utils/tensor_adapter.h" -#include "graph/utils/tensor_utils.h" -#include "graph/utils/attr_utils.h" -#include "graph/debug/ge_attr_define.h" - -namespace ge { -class TensorUT : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -TEST_F(TensorUT, SetData1NoShare) { - GeTensor t1; - std::vector vec; - for (uint8_t i = 0; i < 150; ++i) { - vec.push_back(i); - } - ASSERT_EQ(t1.SetData(vec), GRAPH_SUCCESS); - ASSERT_EQ(t1.GetData().GetSize(), vec.size()); - ASSERT_EQ(memcmp(t1.GetData().GetData(), vec.data(), vec.size()), 0); - t1.MutableData().GetData()[10] = 250; - ASSERT_NE(memcmp(t1.GetData().GetData(), vec.data(), vec.size()), 0); - - std::vector vec2; - for (uint8_t i = 0; i < 105; ++i) { - vec2.push_back(i * 2); - } - vec = vec2; - ASSERT_EQ(t1.SetData(std::move(vec2)), GRAPH_SUCCESS); - ASSERT_EQ(t1.GetData().GetSize(), vec.size()); - ASSERT_EQ(memcmp(t1.GetData().GetData(), vec.data(), vec.size()), 0); - - vec.clear(); - for (uint8_t i = 0; i < 100; ++i) { - vec.push_back(100 - i); - } - Buffer buffer = Buffer::CopyFrom(vec.data(), vec.size()); - ASSERT_EQ(t1.SetData(buffer), GRAPH_SUCCESS); - ASSERT_EQ(t1.GetData().GetSize(), vec.size()); - ASSERT_EQ(memcmp(t1.GetData().GetData(), vec.data(), vec.size()), 0); - - vec.clear(); - for (uint8_t i = 0; i < 150; ++i) { - vec.push_back(i); - } - ASSERT_EQ(t1.SetData(vec.data(), vec.size()), GRAPH_SUCCESS); - ASSERT_EQ(t1.GetData().GetSize(), vec.size()); - ASSERT_EQ(memcmp(t1.GetData().GetData(), vec.data(), vec.size()), 0); - - vec.clear(); - for (uint8_t i = 0; i < 200; ++i) { - vec.push_back(200 - i); - } - TensorData td; - td.SetData(vec); - ASSERT_EQ(memcmp(td.GetData(), vec.data(), vec.size()), 0); - ASSERT_EQ(t1.SetData(td), GRAPH_SUCCESS); - ASSERT_EQ(t1.GetData().GetSize(), vec.size()); - ASSERT_EQ(memcmp(t1.GetData().GetData(), vec.data(), vec.size()), 0); -} - -TEST_F(TensorUT, Construct1_General) { - GeTensor t1; - ASSERT_EQ(t1.impl_->desc_.impl_, t1.GetData().impl_->tensor_descriptor_); - - GeTensorDesc td; - - GeIrProtoHelper helper; - helper.InitDefault(); - helper.GetProtoMsg()->mutable_data()->resize(200); - GeTensor t2(helper.GetProtoOwner(), helper.GetProtoMsg()); - ASSERT_NE(t2.impl_->tensor_def_.GetProtoOwner(), nullptr); - ASSERT_NE(t2.impl_->tensor_def_.GetProtoMsg(), nullptr); - ASSERT_EQ(t2.impl_->tensor_data_.impl_->tensor_descriptor_, t2.impl_->desc_.impl_); - ASSERT_EQ(reinterpret_cast(t2.impl_->tensor_data_.GetData()), - t2.impl_->tensor_def_.GetProtoMsg()->data().data()); -} -TEST_F(TensorUT, Construct2_CopyDesc) { - EXPECT_NO_THROW( - GeTensorDesc desc; - GeTensor t1(desc); - ); -} -TEST_F(TensorUT, Construct3_ExceptionalScenes) { - GeIrProtoHelper helper; - helper.InitDefault(); - GeTensor t1(nullptr, helper.GetProtoMsg()); - GeTensor t2(helper.GetProtoOwner(), nullptr); - GeTensor t3(nullptr, nullptr); - - ASSERT_EQ(t1.impl_->tensor_def_.GetProtoMsg(), helper.GetProtoMsg()); - ASSERT_EQ(t1.impl_->tensor_def_.GetProtoOwner(), nullptr); - ASSERT_EQ(t1.impl_->tensor_data_.impl_->tensor_descriptor_, t1.impl_->desc_.impl_); - - ASSERT_EQ(t2.impl_->tensor_def_.GetProtoMsg(), nullptr); - ASSERT_EQ(t2.impl_->tensor_def_.GetProtoOwner(), helper.GetProtoOwner()); - ASSERT_EQ(t2.impl_->tensor_data_.impl_->tensor_descriptor_, t2.impl_->desc_.impl_); - - ASSERT_EQ(t3.impl_->tensor_def_.GetProtoMsg(), nullptr); - ASSERT_EQ(t3.impl_->tensor_def_.GetProtoOwner(), nullptr); - ASSERT_EQ(t3.impl_->tensor_data_.impl_->tensor_descriptor_, t3.impl_->desc_.impl_); -} -TEST_F(TensorUT, CopyConstruct1_NullTensorDef) { - GeTensor t1; - - std::vector vec; - for (uint8_t i = 0; i < 100; ++i) { - vec.push_back(i * 2); - } - t1.SetData(vec); - GeTensor t2(t1); - - // The copy construct share tensor_data_, do not share tensor_desc - ASSERT_EQ(t1.impl_->tensor_def_.GetProtoOwner(), nullptr); - ASSERT_EQ(t1.impl_->tensor_def_.GetProtoMsg(), nullptr); - ASSERT_EQ(t1.impl_->tensor_data_.impl_->tensor_descriptor_, t1.impl_->desc_.impl_); - ASSERT_EQ(t2.impl_->tensor_data_.impl_->tensor_descriptor_, t2.impl_->desc_.impl_); - ASSERT_EQ(t1.impl_->tensor_data_.GetData(), t2.impl_->tensor_data_.GetData()); - - t1.MutableTensorDesc().SetFormat(FORMAT_NCHW); - t2.MutableTensorDesc().SetFormat(FORMAT_NHWC); - ASSERT_EQ(t1.GetTensorDesc().GetFormat(), FORMAT_NCHW); - ASSERT_EQ(t2.GetTensorDesc().GetFormat(), FORMAT_NHWC); - - ASSERT_EQ(memcmp(t1.GetData().GetData(), vec.data(), vec.size()), 0); - ASSERT_EQ(t1.GetData().GetData(), t2.GetData().GetData()); -} - -TEST_F(TensorUT, CopyConstruct2_WithTensorDef) { - GeIrProtoHelper helper; - helper.InitDefault(); - helper.GetProtoMsg()->mutable_data()->resize(100); - GeTensor t1(helper.GetProtoOwner(), helper.GetProtoMsg()); - - std::vector vec; - for (uint8_t i = 0; i < 100; ++i) { - vec.push_back(i * 2); - } - t1.SetData(vec); - GeTensor t2(t1); - - // Copy construct should share tensordata only - ASSERT_NE(t1.impl_->tensor_def_.GetProtoOwner(), nullptr); - ASSERT_NE(t1.impl_->tensor_def_.GetProtoMsg(), nullptr); - ASSERT_EQ(t1.impl_->tensor_data_.impl_->tensor_descriptor_, t1.impl_->desc_.impl_); - ASSERT_EQ(t2.impl_->tensor_data_.impl_->tensor_descriptor_, t2.impl_->desc_.impl_); - ASSERT_EQ(t1.impl_->tensor_data_.GetData(), t2.impl_->tensor_data_.GetData()); - - t1.MutableTensorDesc().SetFormat(FORMAT_NCHW); - ASSERT_EQ(t1.GetTensorDesc().GetFormat(), FORMAT_NCHW); - t2.MutableTensorDesc().SetFormat(FORMAT_NHWC); - ASSERT_EQ(t1.GetTensorDesc().GetFormat(), FORMAT_NCHW); - ASSERT_EQ(t2.GetTensorDesc().GetFormat(), FORMAT_NHWC); - - ASSERT_EQ(memcmp(t1.GetData().GetData(), vec.data(), vec.size()), 0); - ASSERT_EQ(t1.GetData().GetData(), t2.GetData().GetData()); -} - -TEST_F(TensorUT, SetData_SharedWithTensorDef) { - GeIrProtoHelper helper; - helper.InitDefault(); - helper.GetProtoMsg()->mutable_data()->resize(100); - GeTensor t1(helper.GetProtoOwner(), helper.GetProtoMsg()); - - std::vector vec; - for (uint8_t i = 0; i < 100; ++i) { - vec.push_back(i * 2); - } - t1.SetData(vec); - GeTensor t2(t1); - - std::vector vec2; - for (uint8_t i = 0; i < 100; ++i) { - vec2.push_back(i); - } - t2.SetData(vec2); - ASSERT_EQ(memcmp(t2.GetData().GetData(), vec2.data(), vec2.size()), 0); - // todo 这里存在bug,但是从目前来看,并没有被触发,因此暂时不修复了,重构后一起修复。 - // 触发bug的场景为:如果tensor1是通过tensor_def_持有TensorData,然后通过拷贝构造、拷贝赋值的方式,从tensor1构造了tensor2。 - // 那么通过tensor2.SetData后,会导致tensor1的GetData接口失效(获取到野指针) - // 触发的表现就是,如下两条ASSERT_EQ并不成立 - // ASSERT_EQ(t1.GetData().GetData(), t2.GetData().GetData()); - // ASSERT_EQ(memcmp(t1.GetData().GetData(), vec2.data(), vec2.size()), 0); -} - -TEST_F(TensorUT, SetData_SharedWithoutTensorDef) { - GeTensor t1; - - std::vector vec; - for (uint8_t i = 0; i < 100; ++i) { - vec.push_back(i * 2); - } - t1.SetData(vec); - GeTensor t2(t1); - - std::vector vec3; - for (uint8_t i = 0; i < 100; ++i) { - vec3.push_back(i); - } - t2.SetData(vec3); - ASSERT_EQ(t2.GetData().size(), vec3.size()); - ASSERT_EQ(memcmp(t2.GetData().GetData(), vec3.data(), vec3.size()), 0); - ASSERT_EQ(t1.GetData().size(), vec3.size()); - ASSERT_EQ(memcmp(t1.GetData().GetData(), vec3.data(), vec3.size()), 0); - ASSERT_EQ(t1.GetData().GetData(), t2.GetData().GetData()); - - std::vector vec2; - for (uint8_t i = 0; i < 105; ++i) { - vec2.push_back(i); - } - t2.SetData(vec2); - ASSERT_EQ(t2.GetData().size(), vec2.size()); - ASSERT_EQ(memcmp(t2.GetData().GetData(), vec2.data(), vec2.size()), 0); - // after modify the data of t2 using a different size buffer, the t1 will not be modified - ASSERT_EQ(t1.GetData().size(), vec3.size()); - ASSERT_EQ(memcmp(t1.GetData().GetData(), vec3.data(), vec3.size()), 0); - ASSERT_NE(t1.GetData().GetData(), t2.GetData().GetData()); -} - -TEST_F(TensorUT, SetDataDelete_success) { - auto deleter = [](uint8_t *ptr) { - delete[] ptr; - ptr = nullptr; - }; - uint8_t *data_ptr = new uint8_t[10]; - GeTensor ge_tensor; - ge_tensor.SetData(data_ptr, 10, deleter); - auto length = ge_tensor.GetData().GetSize(); - ASSERT_EQ(length, 10); -} - -TEST_F(TensorUT, TensorSetDataDelete_success) { - auto deleter = [](uint8_t *ptr) { - delete[] ptr; - ptr = nullptr; - }; - uint8_t *data_ptr = new uint8_t[10]; - Tensor tensor; - EXPECT_EQ(tensor.SetData(data_ptr, 10, deleter), GRAPH_SUCCESS); - EXPECT_EQ(tensor.GetSize(), 10); -} - -TEST_F(TensorUT, TransTensorDescWithoutOriginShape2GeTensorDesc) { - TensorDesc desc(Shape({1, 2, 3, 4}), FORMAT_NCHW); - GeTensorDesc ge_desc = TensorAdapter::TensorDesc2GeTensorDesc(desc); - ASSERT_EQ(desc.GetFormat(), ge_desc.GetFormat()); - ASSERT_EQ(desc.GetShape().GetDims().size(), ge_desc.GetShape().GetDims().size()); - for (size_t i = 0; i < desc.GetShape().GetDims().size(); i++) { - ASSERT_EQ(desc.GetShape().GetDim(i), ge_desc.GetShape().GetDim(i)); - } - bool origin_format_is_set = false; - EXPECT_FALSE(AttrUtils::GetBool(ge_desc, ATTR_NAME_ORIGIN_FORMAT_IS_SET, origin_format_is_set)); -} - -TEST_F(TensorUT, TransTensorDescWithOriginShape2GeTensorDesc) { - TensorDesc desc(Shape({1, 2, 3, 4}), FORMAT_NCHW); - desc.SetOriginFormat(FORMAT_NHWC); - desc.SetOriginShape(Shape({1, 3, 4, 2})); - GeTensorDesc ge_desc = TensorAdapter::TensorDesc2GeTensorDesc(desc); - - ASSERT_EQ(desc.GetFormat(), ge_desc.GetFormat()); - ASSERT_EQ(desc.GetShape().GetDims().size(), ge_desc.GetShape().GetDims().size()); - for (size_t i = 0; i < desc.GetShape().GetDims().size(); i++) { - ASSERT_EQ(desc.GetShape().GetDim(i), ge_desc.GetShape().GetDim(i)); - } - - ASSERT_EQ(desc.GetOriginFormat(), ge_desc.GetOriginFormat()); - ASSERT_EQ(desc.GetOriginShape().GetDims().size(), ge_desc.GetOriginShape().GetDims().size()); - for (size_t i = 0; i < desc.GetOriginShape().GetDims().size(); i++) { - ASSERT_EQ(desc.GetOriginShape().GetDim(i), ge_desc.GetOriginShape().GetDim(i)); - } - bool origin_format_is_set = false; - EXPECT_TRUE(AttrUtils::GetBool(ge_desc, ATTR_NAME_ORIGIN_FORMAT_IS_SET, origin_format_is_set)); - EXPECT_TRUE(origin_format_is_set); -} - -TEST_F(TensorUT, NormalizeGeTensorWithOriginShape) { - TensorDesc desc(Shape({1, 2, 3, 4}), FORMAT_NCHW); - desc.SetOriginFormat(FORMAT_NHWC); - desc.SetOriginShape(Shape({1, 3, 4, 2})); - Tensor tensor(desc); - auto ge_tensor = TensorAdapter::AsGeTensor(tensor); - auto &ge_desc = ge_tensor.MutableTensorDesc(); - - bool origin_format_is_set = false; - EXPECT_TRUE(AttrUtils::GetBool(ge_desc, ATTR_NAME_ORIGIN_FORMAT_IS_SET, origin_format_is_set)); - EXPECT_TRUE(origin_format_is_set); - - auto normalized_ge_tensor = TensorAdapter::NormalizeGeTensor(ge_tensor); - auto &normalized_ge_desc = normalized_ge_tensor.MutableTensorDesc(); - - EXPECT_TRUE(AttrUtils::GetBool(normalized_ge_desc, ATTR_NAME_ORIGIN_FORMAT_IS_SET, origin_format_is_set)); - EXPECT_FALSE(origin_format_is_set); - - auto storage_format = static_cast(FORMAT_MAX); - EXPECT_TRUE(AttrUtils::GetInt(normalized_ge_desc, ATTR_NAME_STORAGE_FORMAT, storage_format)); - EXPECT_EQ(storage_format, static_cast(ge_desc.GetFormat())); - - std::vector storage_dims; - EXPECT_TRUE(AttrUtils::GetListInt(normalized_ge_desc, ATTR_NAME_STORAGE_SHAPE, storage_dims)); - EXPECT_EQ(storage_dims.size(), ge_desc.GetShape().GetDims().size()); - for (size_t i = 0; i < storage_dims.size(); i++) { - ASSERT_EQ(ge_desc.GetShape().GetDim(i), storage_dims[i]); - } - - EXPECT_EQ(ge_desc.GetOriginFormat(), normalized_ge_desc.GetFormat()); - ASSERT_EQ(ge_desc.GetOriginShape().GetDims().size(), normalized_ge_desc.GetShape().GetDims().size()); - for (size_t i = 0; i < ge_desc.GetOriginShape().GetDims().size(); i++) { - ASSERT_EQ(ge_desc.GetOriginShape().GetDim(i), normalized_ge_desc.GetShape().GetDim(i)); - } -} - -TEST_F(TensorUT, GeShapeSetDimNum) { - ge::GeShape shape; - EXPECT_EQ(shape.GetDimNum(), 0); - shape.SetDimNum(2); // Normal dim nums - EXPECT_EQ(shape.GetDimNum(), 2); - EXPECT_EQ(shape.GetDim(0), ge::UNKNOWN_DIM); - EXPECT_EQ(shape.GetDim(1), ge::UNKNOWN_DIM); - shape.SetDimNum(0); // Scalar dim nums - EXPECT_EQ(shape.GetDimNum(), 0); - shape.SetDimNum(20); // Big dim nums - EXPECT_EQ(shape.GetDimNum(), 20); - for (int i = 0; i < 20; i++) { - EXPECT_EQ(shape.GetDim(i), ge::UNKNOWN_DIM); - } -} - -TEST_F(TensorUT, GeShapeIsUnknownDimNum) { - ge::GeShape shape; - EXPECT_FALSE(shape.IsUnknownDimNum()); - shape.SetDimNum(2); - EXPECT_FALSE(shape.IsUnknownDimNum()); - shape.SetIsUnknownDimNum(); - EXPECT_TRUE(shape.IsUnknownDimNum()); - shape.SetDimNum(2); - EXPECT_FALSE(shape.IsUnknownDimNum()); -} - -TEST_F(TensorUT, GeShapeAppendDim) { - ge::GeShape shape; - EXPECT_EQ(shape.GetDimNum(), 0); - shape.AppendDim(1); - EXPECT_EQ(shape.GetDimNum(), 1); - EXPECT_EQ(shape.GetDim(0), 1); - shape.AppendDim(2); - EXPECT_EQ(shape.GetDimNum(), 2); - EXPECT_EQ(shape.GetDim(0), 1); - EXPECT_EQ(shape.GetDim(1), 2); - shape.SetIsUnknownDimNum(); - EXPECT_TRUE(shape.IsUnknownDimNum()); - shape.AppendDim(1); - EXPECT_FALSE(shape.IsUnknownDimNum()); -} - -TEST_F(TensorUT, GeTensorDescGetShape) { - ge::GeTensorDesc desc(ge::GeShape(std::vector({1, 2}))); - auto &shape = desc.GetShape(); - EXPECT_EQ(shape.GetDim(0), 1); - EXPECT_EQ(shape.GetDim(1), 2); - const_cast(&shape)->SetDim(0, 10); - const_cast(&shape)->SetDim(1, 20); - auto &shape2 = desc.GetShape(); - EXPECT_EQ(shape2.GetDim(0), 10); - EXPECT_EQ(shape2.GetDim(1), 20); -} - -TEST_F(TensorUT, GeTensorSerializeUtils_GeShape) { - GeShape shape({1, 2, 3, 4}); - proto::ShapeDef shape_proto; - GeTensorSerializeUtils::GeShapeAsProto(shape, &shape_proto); - GeShape shape_from_proto; - GeTensorSerializeUtils::AssembleGeShapeFromProto(&shape_proto, shape_from_proto); - EXPECT_EQ(shape, shape_from_proto); -} - -TEST_F(TensorUT, GeTensorSerializeUtils_GeTensorDesc) { - GeShape shape({1, 2, 3, 4}); - GeTensorDesc desc(shape, FORMAT_NC1HWC0, DT_FLOAT16); - desc.SetOriginDataType(DT_INT32); - desc.SetOriginFormat(FORMAT_NHWC1C0); - desc.SetOriginShape(GeShape({4, 3, 2, 1})); - proto::TensorDescriptor desc_proto; - GeTensorSerializeUtils::GeTensorDescAsProto(desc, &desc_proto); - GeTensorDesc desc_from_proto; - GeTensorSerializeUtils::AssembleGeTensorDescFromProto(&desc_proto, desc_from_proto); - bool res = false; - EXPECT_TRUE(AttrUtils::GetBool(desc_from_proto, "origin_shape_initialized", res)); - EXPECT_TRUE(res); - EXPECT_EQ(desc, desc_from_proto); -} - -TEST_F(TensorUT, GeTensorSerializeUtils_Dtype) { - proto::TensorDescriptor desc_proto; - ge::proto::AttrDef custom_dtype; - custom_dtype.set_i(13); - (void)desc_proto.mutable_attr()->insert({"__tensor_desc_data_type__", custom_dtype}); - ge::DataType dtype; - GeTensorSerializeUtils::GetDtypeFromDescProto(&desc_proto, dtype); - EXPECT_EQ(dtype, ge::DT_DUAL); -} - - -TEST_F(TensorUT, GeTensorSerializeUtils_GeTensor) { - GeShape shape({1, 2, 3, 4}); - GeTensorDesc desc(shape, FORMAT_NC1HWC0, DT_FLOAT16); - desc.SetOriginDataType(DT_INT32); - desc.SetOriginFormat(FORMAT_NHWC1C0); - desc.SetOriginShape(GeShape({4, 3, 2, 1})); - GeTensor tensor(desc); - proto::TensorDef tensor_proto; - GeTensorSerializeUtils::GeTensorAsProto(tensor, &tensor_proto); - GeTensor tensor_from_proto; - GeTensorSerializeUtils::AssembleGeTensorFromProto(&tensor_proto, tensor_from_proto); - EXPECT_EQ(tensor.GetTensorDesc(), desc); - EXPECT_EQ(tensor.GetTensorDesc(), tensor_from_proto.GetTensorDesc()); -} - -TEST_F(TensorUT, GeShape_ModifyDimNum) { - GeShape shape({1, 2, 3, 4}); - EXPECT_EQ(shape.GetShapeSize(), 24); - EXPECT_EQ(shape.GetDimNum(), 4); - shape.SetDimNum(2); - EXPECT_EQ(shape.GetDimNum(), 2); - EXPECT_FALSE(shape.IsUnknownDimNum()); - shape.SetIsUnknownDimNum(); - EXPECT_TRUE(shape.IsUnknownDimNum()); - EXPECT_EQ(shape.GetShapeSize(), -1); - shape.SetDimNum(2); - EXPECT_EQ(shape.GetDimNum(), 2); - EXPECT_FALSE(shape.IsUnknownDimNum()); - shape.SetDim(0, 2); - shape.SetDim(1, 2); - EXPECT_EQ(shape.GetShapeSize(), 4); - shape.SetDim(0, INT64_MAX); - shape.SetDim(1, 2); - EXPECT_EQ(shape.GetShapeSize(), -1); -} - -TEST_F(TensorUT, GeShape_Unknown) { - GeShape shape({-2}); - EXPECT_TRUE(shape.IsUnknownShape()); - EXPECT_TRUE(shape.IsUnknownDimNum()); - EXPECT_FALSE(shape.IsScalar()); - EXPECT_EQ(shape.GetDimNum(), 0U); - EXPECT_EQ(shape.GetDims().size(), 1U); -} - -TEST_F(TensorUT, Shape_Unknown) { - Shape shape({-2}); - EXPECT_EQ(shape.GetDimNum(), 0U); - EXPECT_EQ(shape.GetDims().size(), 1U); -} - -TEST_F(TensorUT, GeTensorDesc_Update) { - GeShape shape({1, 2, 3, 4}); - GeTensorDesc desc(shape, FORMAT_NC1HWC0, DT_FLOAT16); - EXPECT_EQ(desc.GetShape(), shape); - EXPECT_EQ(desc.GetFormat(), FORMAT_NC1HWC0); - EXPECT_EQ(desc.GetDataType(), DT_FLOAT16); - GeShape shape2({4, 3, 2, 1}); - desc.Update(shape2, FORMAT_NHWC, DT_INT32); - EXPECT_EQ(desc.GetShape(), shape2); - EXPECT_EQ(desc.GetFormat(), FORMAT_NHWC); - EXPECT_EQ(desc.GetDataType(), DT_INT32); -} - -TEST_F(TensorUT, AttrUtils_SetGeTensorDesc) { - GeShape shape({1, 2, 3, 4}); - GeTensorDesc desc(shape, FORMAT_NC1HWC0, DT_FLOAT16); - GeTensorDesc obj; - ge::AttrUtils::SetTensorDesc(obj, "attr_tensor", desc); - GeTensorDesc desc_from_attr; - ge::AttrUtils::GetTensorDesc(obj, "attr_tensor", desc_from_attr); - EXPECT_EQ(desc, desc_from_attr); -} - -TEST_F(TensorUT, AttrUtils_SetListGeTensorDesc) { - GeShape shape({1, 2, 3, 4}); - std::vector descs; - descs.emplace_back(GeTensorDesc(GeShape({1, 2, 3, 4}), FORMAT_NC1HWC0, DT_FLOAT16)); - descs.emplace_back(GeTensorDesc(GeShape({4, 3, 2, 1}), FORMAT_NCHW, DT_INT32)); - GeTensorDesc obj; - ge::AttrUtils::SetListTensorDesc(obj, "attr_tensors", descs); - std::vector descs_from_attr; - ge::AttrUtils::GetListTensorDesc(obj, "attr_tensors", descs_from_attr); - EXPECT_EQ(descs.size(), descs_from_attr.size()); - for (size_t i = 0; i < descs.size(); i++) { - EXPECT_EQ(descs[i], descs_from_attr[i]); - } -} - -class AscendStringUT : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -TEST_F(AscendStringUT, Hash) { - ge::AscendString ascend_string("ABC"); - EXPECT_EQ(std::hash()(ascend_string), ascend_string.Hash()); - EXPECT_EQ(std::hash()(ascend_string.GetString()), ascend_string.Hash()); - EXPECT_EQ(std::hash()("ABC"), ascend_string.Hash()); - - ge::AscendString empty_ascend_string; - EXPECT_EQ(std::hash()(empty_ascend_string), empty_ascend_string.Hash()); - EXPECT_EQ(std::hash()(""), empty_ascend_string.Hash()); -} - -TEST_F(AscendStringUT, EmptyValueCompare) { - ge::AscendString ascend_string; - EXPECT_NE(ascend_string.GetString(), ""); - EXPECT_EQ(ascend_string.GetString(), std::string("")); - EXPECT_TRUE(std::string(ascend_string.GetString()).empty()); -} - -TEST_F(TensorUT, TensorUtils_GetSteExtMeta) { - GeTensorDesc desc; - -#define TEST_EXT_META_INNER(NAME, TYPE, V, V1) \ - do { \ - TYPE v = V; \ - TYPE v1 = V1; \ - TensorUtils::Set##NAME(desc, v); \ - TensorUtils::Get##NAME(desc, v1); \ - EXPECT_EQ(v, v1); \ - } while (false) - -#define TEST_EXT_META_INT64(NAME) TEST_EXT_META_INNER(NAME, int64_t, 0, -1); -#define TEST_EXT_META_BOOL(NAME) TEST_EXT_META_INNER(NAME, bool, true, false); -#define TEST_EXT_META_UINT32(NAME) TEST_EXT_META_INNER(NAME, uint32_t, 0, 1); - - TEST_EXT_META_INT64(Size); - TEST_EXT_META_INT64(DataOffset); - - TEST_EXT_META_UINT32(RealDimCnt); - TEST_EXT_META_UINT32(ReuseInputIndex); - - TEST_EXT_META_BOOL(InputTensor); - TEST_EXT_META_BOOL(OutputTensor); - TEST_EXT_META_BOOL(ReuseInput); - - desc.SetName("foo"); - EXPECT_EQ(desc.GetName(), "foo"); - - TensorUtils::SetWeightSize(desc, 2021); - EXPECT_EQ(TensorUtils::GetWeightSize(desc), 2021); -} - -TEST_F(TensorUT, Tensor_Construct3) { - std::vector shape{4}; - uint8_t *data = new uint8_t[4]{1, 2, 3, 4}; - size_t size = 4; - TensorDesc tensor_desc(Shape(shape), FORMAT_ND, DT_UINT8); - Tensor tensor(tensor_desc, data, size); - EXPECT_EQ(tensor.GetSize(), 4); - delete[] data; -} - -TEST_F(TensorUT, Tensor_Construct4) { - std::vector value{1, 2, 3}; - std::vector shape{3}; - TensorDesc tensor_desc(Shape(shape), FORMAT_ND, DT_UINT8); - Tensor tensor(std::move(tensor_desc), std::move(value)); - EXPECT_EQ(tensor.GetSize(), 3); -} - -TEST_F(TensorUT, Tensor_SetData) { - Tensor t1; - std::vector vec; - for (uint8_t i = 0; i < 10; ++i) { - vec.push_back(i); - } - EXPECT_EQ(t1.SetData(vec), GRAPH_SUCCESS); - - Tensor t2; - std::string str1 = "abc"; - EXPECT_EQ(t2.SetData(str1), GRAPH_SUCCESS); - - Tensor t3; - std::vector vec_str; - EXPECT_EQ(t3.SetData(vec_str), GRAPH_FAILED); - for (uint8_t i = 0; i < 10; ++i) { - vec_str.push_back(std::to_string(i)); - } - EXPECT_EQ(t3.SetData(vec_str), GRAPH_SUCCESS); - - Tensor t4; - const char *str2 = "def"; - EXPECT_EQ(t4.SetData(str2), GRAPH_SUCCESS); - - Tensor t5; - const char * str3[3] = {"123", "456", "789"}; - std::vector vec_asc_str; - for (uint8_t i = 0; i < 3; ++i) { - vec_asc_str.push_back(AscendString(str3[i])); - } - EXPECT_EQ(t5.SetData(vec_asc_str), GRAPH_SUCCESS); -} - -TEST_F(TensorUT, Shape_SetDim) { - size_t idx = 1; - int64_t value = 2; - - Shape shape1; - EXPECT_EQ(shape1.SetDim(idx, value), GRAPH_FAILED); - - std::vector dims; - for(int64_t i = 0; i < 3; i++) { - dims.push_back(i); - } - - Shape shape2(dims); - EXPECT_EQ(shape2.SetDim(idx, value), GRAPH_SUCCESS); -} - -TEST_F(TensorUT, TensorDesc_Construct1) { - std::vector shape{3}; - TensorDesc tensor_desc1(Shape(shape), FORMAT_ND, DT_UINT8); - TensorDesc tensor_desc2(std::move(tensor_desc1)); - - TensorDesc tensor_desc3(Shape(shape), FORMAT_ND, DT_UINT8); - TensorDesc tensor_desc4 = std::move(tensor_desc3); - - tensor_desc4.Update(Shape(shape), FORMAT_ND, DT_UINT16); - EXPECT_EQ(tensor_desc4.GetDataType(), DT_UINT16); - - TensorDesc tensor_desc5; - EXPECT_EQ(tensor_desc5.GetShape().GetShapeSize(), 0); -} - -TEST_F(TensorUT, TensorDesc_GetSetShape) { - std::vector> range; - TensorDesc tensor_desc1; - tensor_desc1.GetShape(); - tensor_desc1.GetOriginShape(); - - EXPECT_EQ(tensor_desc1.GetShapeRange(range), GRAPH_SUCCESS); - EXPECT_EQ(tensor_desc1.SetShapeRange(range), GRAPH_SUCCESS); - - EXPECT_EQ(tensor_desc1.SetUnknownDimNumShape(), GRAPH_SUCCESS); - - std::vector shape{3}; - TensorDesc tensor_desc2(Shape(shape), FORMAT_ND, DT_UINT8); - EXPECT_EQ(tensor_desc2.SetUnknownDimNumShape(), GRAPH_SUCCESS); -} - -TEST_F(TensorUT, TensorDesc_SetDataType) { - EXPECT_NO_THROW( - std::vector shape{3}; - TensorDesc tensor_desc1(Shape(shape), FORMAT_ND, DT_UINT8); - tensor_desc1.SetDataType(DT_UINT16); - ); -} - -TEST_F(TensorUT, TensorDesc_GetSetName) { - std::vector shape{3}; - TensorDesc tensor_desc1(Shape(shape), FORMAT_ND, DT_UINT8); - tensor_desc1.SetName("abc"); - - AscendString name; - tensor_desc1.GetName(name); - EXPECT_EQ(name, AscendString("abc")); -} - - -TEST_F(TensorUT, TensorDesc_get_set_expand_dims_rule) { - TensorDesc td; - // init status - AscendString expand_dims_rule; - td.GetExpandDimsRule(expand_dims_rule); - EXPECT_EQ(expand_dims_rule.GetLength(), 0); - - // test set and get - expand_dims_rule = AscendString("1100"); - td.SetExpandDimsRule(expand_dims_rule); - td.GetExpandDimsRule(expand_dims_rule); - EXPECT_STREQ(expand_dims_rule.GetString(), "1100"); -} - -TEST_F(TensorUT, Tensor_SetTensorDesc_GetData) { - std::vector shape{3}; - TensorDesc tensor_desc1(Shape(shape), FORMAT_ND, DT_UINT8); - - Tensor t1; - auto ret = t1.SetTensorDesc(tensor_desc1); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - uint8_t *data1 = NULL; - data1 = t1.GetData(); - EXPECT_NE(data1, nullptr); - - const uint8_t *data2 = NULL; - data2 = t1.GetData(); - EXPECT_NE(data2, nullptr); -} - -TEST_F(TensorUT, Tensor_AsGeTensorImpl) { - std::vector shape{3}; - TensorDesc tensor_desc1(Shape(shape), FORMAT_ND, DT_UINT8); - - Tensor t1; - auto ret = t1.SetTensorDesc(tensor_desc1); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - const GeTensor* gt_impl = TensorAdapter::AsBareGeTensorPtr(t1); - EXPECT_NE(gt_impl, nullptr); - - t1.impl = nullptr; - const GeTensor* gt_impl_2 = TensorAdapter::AsBareGeTensorPtr(t1); - EXPECT_EQ(gt_impl_2, nullptr); -} - -TEST_F(TensorUT, unique_ptr_Tensor_ResetData) { - Tensor t1; - std::unique_ptr pt; - EXPECT_NO_THROW(pt = t1.ResetData()); -} - -TEST_F(TensorUT, Tensor_IsValid_Clone) { - Tensor t1; - Tensor t2; - - std::vector shape{3}; - TensorDesc tensor_desc1(Shape(shape), FORMAT_ND, DT_UINT8); - t1.SetTensorDesc(tensor_desc1); - - EXPECT_EQ(t1.IsValid(), GRAPH_FAILED); - - t2 = t1.Clone(); -} - -TEST_F(TensorUT, TensorAdapter_GetGeTensorFromTensor) { - Tensor t1; - GeTensor gt = TensorAdapter::AsGeTensorShared(t1); - ConstGeTensorPtr cgtptr = TensorAdapter::AsGeTensorPtr(t1); - EXPECT_NE(cgtptr, nullptr); -} - -TEST_F(TensorUT, TensorAdapter_AsTensor) { - GeTensor gt1; - std::vector vec; - for (uint8_t i = 0; i < 100; ++i) { - vec.push_back(i * 2); - } - gt1.SetData(vec); - - Tensor t1; - t1 = TensorAdapter::AsTensor(gt1); - EXPECT_EQ(t1.GetSize(), gt1.GetData().GetSize()); - const GeTensor gt2; - const Tensor t2 = TensorAdapter::AsTensor(gt2); - EXPECT_EQ(t2.GetSize(), gt2.GetData().GetSize()); -} - -TEST_F(TensorUT, TensorDesc2GeTensorDesc_expand_dims_rule) { - TensorDesc td; - // test set and get - td.SetExpandDimsRule(AscendString("0011")); - - auto ge_tensor_desc = TensorAdapter::TensorDesc2GeTensorDesc(td); - EXPECT_STREQ(ge_tensor_desc.GetExpandDimsRule().c_str(), "0011"); -} - -TEST_F(TensorUT, GeTensorDesc2TensorDesc_expand_dims_rule) { - GeTensorDesc ge_tensor_desc; - // test set and get - ge_tensor_desc.SetExpandDimsRule("0011"); - - auto tensor_desc = TensorAdapter::GeTensorDesc2TensorDesc(ge_tensor_desc); - AscendString expand_dims_rule; - tensor_desc.GetExpandDimsRule(expand_dims_rule); - EXPECT_STREQ(expand_dims_rule.GetString(), "0011"); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/tensor_utils_unittest.cc b/tests/ut/graph/testcase/tensor_utils_unittest.cc deleted file mode 100644 index e8fc120d88df91b8587f8dfe4fabcc2ff32a9f16..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/tensor_utils_unittest.cc +++ /dev/null @@ -1,667 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/ge_tensor.h" -#include "graph/utils/tensor_utils.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/attr_utils.h" -#include "graph/debug/ge_log.h" -#include "graph/yuv_subformat.h" -#include -using namespace std; - -namespace ge { - -class ge_test_tensor_utils : public testing::Test { - protected: - void SetUp() { - } - - void TearDown() { - } -}; - -TEST_F(ge_test_tensor_utils, shape) { - GeTensorDesc tensorDesc; - - int64_t s1 = 1; - int64_t s2 = 0; - TensorUtils::SetSize(tensorDesc, s1); - EXPECT_EQ(TensorUtils::GetSize(tensorDesc, s2), GRAPH_SUCCESS); - EXPECT_EQ(s2, 1); - TensorUtils::SetSize(tensorDesc, 2); - EXPECT_EQ(TensorUtils::GetSize(tensorDesc, s2), GRAPH_SUCCESS); - EXPECT_EQ(s2, 2); - - bool f1(true); - bool f2(false); - TensorUtils::SetReuseInput(tensorDesc, f1); - EXPECT_EQ(TensorUtils::GetReuseInput(tensorDesc, f2), GRAPH_SUCCESS); - EXPECT_EQ(f2, true); - - f1 = true; - f2 = false; - TensorUtils::SetOutputTensor(tensorDesc, f1); - EXPECT_EQ(TensorUtils::GetOutputTensor(tensorDesc, f2), GRAPH_SUCCESS); - EXPECT_EQ(f2, true); - - DeviceType d1(DeviceType::CPU); - DeviceType d2(DeviceType::NPU); - TensorUtils::SetDeviceType(tensorDesc, d1); - EXPECT_EQ(TensorUtils::GetDeviceType(tensorDesc, d2), GRAPH_SUCCESS); - EXPECT_EQ(d2, true); - - f1 = true; - f2 = false; - TensorUtils::SetInputTensor(tensorDesc, f1); - EXPECT_EQ(TensorUtils::GetInputTensor(tensorDesc, f2), GRAPH_SUCCESS); - EXPECT_EQ(f2, true); - - uint32_t s5 = 1; - uint32_t s6 = 0; - TensorUtils::SetRealDimCnt(tensorDesc, s5); - EXPECT_EQ(TensorUtils::GetRealDimCnt(tensorDesc, s6), GRAPH_SUCCESS); - EXPECT_EQ(s6, 1); - - s5 = 1; - s6 = 0; - TensorUtils::SetReuseInputIndex(tensorDesc, s5); - EXPECT_EQ(TensorUtils::GetReuseInputIndex(tensorDesc, s6), GRAPH_SUCCESS); - EXPECT_EQ(s6, 1); - - int64_t s3(1); - int64_t s4(0); - TensorUtils::SetDataOffset(tensorDesc, s3); - EXPECT_EQ(TensorUtils::GetDataOffset(tensorDesc, s4), GRAPH_SUCCESS); - EXPECT_EQ(s4, 1); - - s5 = 1; - s6 = 0; - TensorUtils::SetRC(tensorDesc, s5); - EXPECT_EQ(TensorUtils::GetRC(tensorDesc, s6), GRAPH_SUCCESS); - EXPECT_EQ(s6, 1); - TensorUtils::SetRC(tensorDesc, 2); - EXPECT_EQ(TensorUtils::GetRC(tensorDesc, s6), GRAPH_SUCCESS); - EXPECT_EQ(s6, 2); -} - -TEST_F(ge_test_tensor_utils, CalcTensorMemSize_failed_datatype_notsupport) { - vector dims({2, 3, 4, 5}); - GeShape ge_shape(dims); - Format format = FORMAT_NCHW; - DataType data_type = DT_UNDEFINED; - - int64_t mem_size = 0; - graphStatus ret = - TensorUtils::CalcTensorMemSize(ge_shape, format, data_type, mem_size); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(ge_test_tensor_utils, CalcTensorMemSize_failed_format_notsupport) { - vector dims({2, 3, 4, 5}); - GeShape ge_shape(dims); - Format format = FORMAT_RESERVED; - DataType data_type = DT_FLOAT16; - - int64_t mem_size = 0; - graphStatus ret = - TensorUtils::CalcTensorMemSize(ge_shape, format, data_type, mem_size); - EXPECT_NE(ret, GRAPH_SUCCESS); -} - -// not 4 calc by nd -TEST_F(ge_test_tensor_utils, CalcTensorMemSize_NCHW_shape_not_4) { - vector dims({2, 3, 4, 5, 6}); - GeShape ge_shape(dims); - Format format = FORMAT_NCHW; - DataType data_type = DT_FLOAT16; - int64_t expect_mem_size = sizeof(uint16_t); - for (int64_t dim:dims) { - expect_mem_size *= dim; - } - int64_t mem_size = 0; - graphStatus ret = - TensorUtils::CalcTensorMemSize(ge_shape, format, data_type, mem_size); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(mem_size, expect_mem_size); -} - -TEST_F(ge_test_tensor_utils, CalcTensorMemSize_NCHW_SUCCESS) { - vector dims({2, 3, 4, 5}); - GeShape ge_shape(dims); - Format format = FORMAT_NCHW; - DataType data_type = DT_FLOAT16; - int64_t expect_mem_size = sizeof(uint16_t); - for (int64_t dim:dims) { - expect_mem_size *= dim; - } - int64_t mem_size = 0; - graphStatus ret = - TensorUtils::CalcTensorMemSize(ge_shape, format, data_type, mem_size); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(mem_size, expect_mem_size); -} - -// not 4 calc by nd -TEST_F(ge_test_tensor_utils, CalcTensorMemSize_NHWC_shape_not_4) { - vector dims({2, 3, 4}); - GeShape ge_shape(dims); - Format format = FORMAT_NHWC; - DataType data_type = DT_FLOAT16; - int64_t expect_mem_size = sizeof(uint16_t); - for (int64_t dim:dims) { - expect_mem_size *= dim; - } - int64_t mem_size = 0; - graphStatus ret = - TensorUtils::CalcTensorMemSize(ge_shape, format, data_type, mem_size); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(mem_size, expect_mem_size); -} - -TEST_F(ge_test_tensor_utils, CalcTensorMemSize_NHWC_SUCCESS) { - vector dims({2, 3, 4, 5}); - GeShape ge_shape(dims); - Format format = FORMAT_NHWC; - DataType data_type = DT_FLOAT; - int64_t expect_mem_size = sizeof(float); - for (int64_t dim:dims) { - expect_mem_size *= dim; - } - int64_t mem_size = 0; - graphStatus ret = - TensorUtils::CalcTensorMemSize(ge_shape, format, data_type, mem_size); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(mem_size, expect_mem_size); -} - -TEST_F(ge_test_tensor_utils, CalcTensorMemSize_ND_FAILED_overflow_with_type) { - vector dims({1024 * 1024, 1024 * 1024, 1024 * 1024}); - GeShape ge_shape(dims); - Format format = FORMAT_ND; - DataType data_type = DT_UINT64; - int64_t mem_size = 0; - graphStatus ret = - TensorUtils::CalcTensorMemSize(ge_shape, format, data_type, mem_size); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(ge_test_tensor_utils, CalcTensorMemSize_ND_FAILED_overflow) { - vector dims({1024 * 1024, 1024 * 1024, 1024 * 1024, 1024 * 1024}); - GeShape ge_shape(dims); - Format format = FORMAT_ND; - DataType data_type = DT_UINT64; - int64_t mem_size = 0; - graphStatus ret = - TensorUtils::CalcTensorMemSize(ge_shape, format, data_type, mem_size); - EXPECT_NE(ret, GRAPH_SUCCESS); -} - -TEST_F(ge_test_tensor_utils, CalcTensorMemSize_DT_STRING_FAILED_overflow) { - vector dims({4, 1024 * 1024, 1024 * 1024, 1024 * 1024}); - GeShape ge_shape(dims); - Format format = FORMAT_ND; - DataType data_type = DT_STRING; - int64_t mem_size = 0; - graphStatus ret = - TensorUtils::CalcTensorMemSize(ge_shape, format, data_type, mem_size); - EXPECT_NE(ret, GRAPH_SUCCESS); -} - -TEST_F(ge_test_tensor_utils, CalcTensorMemSize_ND_SUCCESS) { - vector dims({10, 2, 3, 4, 5}); - GeShape ge_shape(dims); - Format format = FORMAT_ND; - DataType data_type = DT_UINT64; - int64_t expect_mem_size = sizeof(uint64_t); - for (int64_t dim:dims) { - expect_mem_size *= dim; - } - int64_t mem_size = 0; - graphStatus ret = - TensorUtils::CalcTensorMemSize(ge_shape, format, data_type, mem_size); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(mem_size, expect_mem_size); -} - -TEST_F(ge_test_tensor_utils, CalcTensorMemSize_MD_SUCCESS) { - vector dims({10, 20, 3, 4, 5}); - GeShape ge_shape(dims); - Format format = FORMAT_MD; - DataType data_type = DT_UINT32; - int64_t expect_mem_size = sizeof(uint32_t); - for (int64_t dim:dims) { - expect_mem_size *= dim; - } - int64_t mem_size = 0; - graphStatus ret = - TensorUtils::CalcTensorMemSize(ge_shape, format, data_type, mem_size); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(mem_size, expect_mem_size); -} - -TEST_F(ge_test_tensor_utils, CalcTensorMemSize_NC1HWC0_SUCCESS_NONEEDPAD) { - vector dims({10, 32, 3, 5}); - GeShape ge_shape(dims); - Format format = FORMAT_NC1HWC0; - DataType data_type = DT_FLOAT16; - int64_t expect_mem_size = sizeof(uint16_t); - for (int64_t dim:dims) { - expect_mem_size *= dim; - } - int64_t mem_size = 0; - graphStatus ret = - TensorUtils::CalcTensorMemSize(ge_shape, format, data_type, mem_size); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(mem_size, expect_mem_size); -} - -TEST_F(ge_test_tensor_utils, CalcTensorMemSize_NC1HWC0_SUCCESS_5D) { - vector dims({10, 2, 3, 5, 16}); - GeShape ge_shape(dims); - Format format = FORMAT_NC1HWC0; - DataType data_type = DT_FLOAT16; - int64_t expect_mem_size = sizeof(uint16_t); - for (int64_t dim:dims) { - expect_mem_size *= dim; - } - int64_t mem_size = 0; - graphStatus ret = - TensorUtils::CalcTensorMemSize(ge_shape, format, data_type, mem_size); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(mem_size, expect_mem_size); -} - -TEST_F(ge_test_tensor_utils, CalcTensorMemSize_ND_SUCCESS_size_0) { - vector dims({10, 0, 3, 5, 16}); - GeShape ge_shape(dims); - Format format = FORMAT_NC1HWC0; - DataType data_type = DT_FLOAT16; - int64_t expect_mem_size = 0; - int64_t mem_size = 0; - graphStatus ret = - TensorUtils::CalcTensorMemSize(ge_shape, format, data_type, mem_size); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_TRUE(mem_size == expect_mem_size); -} - -TEST_F(ge_test_tensor_utils, CalcTensorMemSize_ND_SUCCESS_size_unknown) { - vector dims({10, -1, 3, 5, 16}); - GeShape ge_shape(dims); - Format format = FORMAT_NC1HWC0; - DataType data_type = DT_FLOAT16; - int64_t expect_mem_size = -1; - int64_t mem_size = 0; - graphStatus ret = - TensorUtils::CalcTensorMemSize(ge_shape, format, data_type, mem_size); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(mem_size, expect_mem_size); -} - -TEST_F(ge_test_tensor_utils, CalcTensorMemSize_C1HWNCoC0_SUCCESS) { - vector dims({10, 2, 3, 5, 8, 16}); - GeShape ge_shape(dims); - Format format = FORMAT_C1HWNCoC0; - DataType data_type = DT_FLOAT16; - int64_t expect_mem_size = sizeof(uint16_t); - for (int64_t dim:dims) { - expect_mem_size *= dim; - } - int64_t mem_size = 0; - graphStatus ret = - TensorUtils::CalcTensorMemSize(ge_shape, format, data_type, mem_size); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(mem_size, expect_mem_size); -} - -TEST_F(ge_test_tensor_utils, CalcTensorMemSize_CCE_FractalZ_shape_error) { - setenv("PARSER_PRIORITY", "cce", 0); - vector dims({2, 3, 4, 5}); - GeShape ge_shape(dims); - Format format = FORMAT_FRACTAL_Z; - DataType data_type = DT_UINT8; - int64_t mem_size = 0; - graphStatus ret = - TensorUtils::CalcTensorMemSize(ge_shape, format, data_type, mem_size); - unsetenv("PARSER_PRIORITY"); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(ge_test_tensor_utils, CalcTensorMemSize_CCE_FractalZ_SUCCESS) { - setenv("PARSER_PRIORITY", "cce", 0); - vector dims({16, 16, 6, 7}); - GeShape ge_shape(dims); - Format format = FORMAT_FRACTAL_Z; - DataType data_type = DT_FLOAT16; - int64_t expect_mem_size = sizeof(uint16_t); - for (int64_t dim:dims) { - expect_mem_size *= dim; - } - int64_t mem_size = 0; - graphStatus ret = - TensorUtils::CalcTensorMemSize(ge_shape, format, data_type, mem_size); - unsetenv("PARSER_PRIORITY"); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(ge_test_tensor_utils, CalcTensorMemSize_TBE_FractalZ_OverFlow_FAIL) { - vector dims({1024 * 1024, 1024 * 1024, 1024 * 1024, 1024 * 1024}); - GeShape ge_shape(dims); - Format format = FORMAT_FRACTAL_Z; - DataType data_type = DT_FLOAT16; - - int64_t mem_size = 0; - graphStatus ret = - TensorUtils::CalcTensorMemSize(ge_shape, format, data_type, mem_size); - //cout<<"mem_size:"< dims({2, 3, 4, 5}); - GeShape ge_shape(dims); - Format format = FORMAT_RESERVED; - DataType data_type = DT_MAX; - GeTensorDesc tensorDesc(ge_shape, format, data_type); - int64_t size; -// MOCKER(TensorUtils::GetTensorSizeInBytes).stubs().will(returnValue(GRAPH_FAILED)); - graphStatus ret = TensorUtils::GetTensorMemorySizeInBytes(tensorDesc, size); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(ge_test_tensor_utils, GetTensorSizeInBytes_SUCCESS) { - GeTensorDesc tensorDesc; - int64_t size; -// MOCKER(TensorUtils::CalcTensorMemSize).stubs().with(any(),any(),any(),outBound(memSize)).will(returnValue(GRAPH_SUCCESS)); - graphStatus ret = TensorUtils::GetTensorSizeInBytes(tensorDesc, size); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(ge_test_tensor_utils, GetTensorSizeInBytes_FAILED) { - GeTensorDesc tensorDesc; - int64_t size; -// MOCKER(TensorUtils::CalcTensorMemSize).stubs().will(returnValue(GRAPH_FAILED)); - graphStatus ret = TensorUtils::GetTensorSizeInBytes(tensorDesc, size); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(ge_test_tensor_utils, GetTensorSizeInBytes_NoTiling_SUCCESS) { - GeTensorDesc tensorDesc(GeShape({1, -1})); - tensorDesc.SetShapeRange({{1, 1}, {1, 10}}); - int64_t size; - (void)AttrUtils::SetBool(&tensorDesc, ATTR_NAME_TENSOR_NO_TILING_MEM_TYPE, true); - graphStatus ret = TensorUtils::GetTensorSizeInBytes(tensorDesc, size); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(ge_test_tensor_utils, CalcTensorMemSizeFilterHwckTest) { - vector dims({2, 3, 4, 5}); - GeShape ge_shape(dims); - Format format = FORMAT_FILTER_HWCK; - DataType data_type = DT_STRING; - int64_t mem_size = 0; - graphStatus ret = - TensorUtils::CalcTensorMemSize(ge_shape, format, data_type, mem_size); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(ge_test_tensor_utils, CalcTensorMemSizeFractalZnRnn) { - vector dims({2, 3, 4, 5}); - GeShape ge_shape(dims); - Format format = FORMAT_FRACTAL_ZN_RNN; - DataType data_type = DT_STRING; - int64_t mem_size = 0; - graphStatus ret = - TensorUtils::CalcTensorMemSize(ge_shape, format, data_type, mem_size); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(ge_test_tensor_utils, CalcTensorMemSizeFractalZWino) { - vector dims({2, 3, 4, 5}); - GeShape ge_shape(dims); - Format format = FORMAT_FRACTAL_Z_WINO; - DataType data_type = DT_STRING; - int64_t mem_size = 0; - graphStatus ret = - TensorUtils::CalcTensorMemSize(ge_shape, format, data_type, mem_size); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(ge_test_tensor_utils, CheckShapeByShapeRangeShapeRangeIsNull) { - vector dims({2, 3, 4, 5, 6, 7, 8}); - GeShape ge_shape(dims); - std::vector> shape_range; - graphStatus ret = - TensorUtils::CheckShapeByShapeRange(ge_shape, shape_range); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(ge_test_tensor_utils, CheckShapeByShapeRangeFailTest) { - vector dims({2, 3, 4, 5, 6, 7, 8}); - GeShape ge_shape(dims); - std::vector> shape_range; - shape_range.push_back(std::make_pair(1, 1)); - graphStatus ret = - TensorUtils::CheckShapeByShapeRange(ge_shape, shape_range); - EXPECT_EQ(ret, PARAM_INVALID); -} - -TEST_F(ge_test_tensor_utils, CheckShapeByShapeRangeLeftRangeLessThan0) { - vector dims({2, 3, 4, 5}); - GeShape ge_shape(dims); - std::vector> shape_range; - shape_range.push_back(std::make_pair(-1, 1)); - shape_range.push_back(std::make_pair(2, 2)); - shape_range.push_back(std::make_pair(3, 3)); - shape_range.push_back(std::make_pair(4, 4)); - graphStatus ret = - TensorUtils::CheckShapeByShapeRange(ge_shape, shape_range); - EXPECT_EQ(ret, PARAM_INVALID); -} - -TEST_F(ge_test_tensor_utils, CheckShapeByShapeRangeCurDimIsUnknownDim) { - vector dims({-1, -1}); - GeShape ge_shape(dims); - std::vector> shape_range; - shape_range.push_back(std::make_pair(1, 1)); - shape_range.push_back(std::make_pair(2, 2)); - graphStatus ret = - TensorUtils::CheckShapeByShapeRange(ge_shape, shape_range); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(ge_test_tensor_utils, CheckShapeByShapeRangeCurDimLessThanLeftRange) { - vector dims({1, 2}); - GeShape ge_shape(dims); - std::vector> shape_range; - shape_range.push_back(std::make_pair(3, 3)); - shape_range.push_back(std::make_pair(4, 4)); - graphStatus ret = - TensorUtils::CheckShapeByShapeRange(ge_shape, shape_range); - EXPECT_EQ(ret, PARAM_INVALID); -} - -TEST_F(ge_test_tensor_utils, CheckShapeByShapeRangeRightRangeLessThan0) { - vector dims({3, 4}); - GeShape ge_shape(dims); - std::vector> shape_range; - shape_range.push_back(std::make_pair(3, -3)); - shape_range.push_back(std::make_pair(4, -4)); - graphStatus ret = - TensorUtils::CheckShapeByShapeRange(ge_shape, shape_range); - EXPECT_EQ(ret, PARAM_INVALID); -} - -TEST_F(ge_test_tensor_utils, CheckShapeByShapeRangeCurDimGreaterThanRightRange) { - vector dims({5, 6}); - GeShape ge_shape(dims); - std::vector> shape_range; - shape_range.push_back(std::make_pair(3, 3)); - shape_range.push_back(std::make_pair(4, 4)); - graphStatus ret = - TensorUtils::CheckShapeByShapeRange(ge_shape, shape_range); - EXPECT_EQ(ret, PARAM_INVALID); -} - -TEST_F(ge_test_tensor_utils, CalcTensorMemSizeForNoTilingSuccess) { - GeTensorDesc tensor; - Format format = FORMAT_NCHW; - DataType data_type = DT_STRING; - int64_t mem_size = 0; - graphStatus ret = - TensorUtils::CalcTensorMemSizeForNoTiling(tensor, format, data_type, mem_size); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(ge_test_tensor_utils, CalcTensorMemSizeForNoTilingDimsSizeIs0) { - vector dims({}); - GeShape ge_shape(dims); - Format format = FORMAT_FRACTAL_Z; - DataType data_type = DT_FLOAT; - int64_t mem_size = 0; - GeTensorDesc tensor(ge_shape, format, data_type); - graphStatus ret = - TensorUtils::CalcTensorMemSizeForNoTiling(tensor, format, data_type, mem_size); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(ge_test_tensor_utils, CalcTensorMemSizeForNoTilingFail) { - vector dims({0, -1}); - GeShape ge_shape(dims); - Format format = FORMAT_MAX; - DataType data_type = DT_MAX; - int64_t mem_size = 0; - GeTensorDesc tensor(ge_shape); - graphStatus ret = - TensorUtils::CalcTensorMemSizeForNoTiling(tensor, format, data_type, mem_size); - EXPECT_EQ(ret, PARAM_INVALID); -} - -TEST_F(ge_test_tensor_utils, GetMaxShapeDimsFromNoTilingTensorFail) { - vector dims({0, -1}); - GeShape ge_shape(dims); - Format format = FORMAT_MAX; - DataType data_type = DT_MAX; - int64_t mem_size = 0; - GeTensorDesc tensor(ge_shape); - std::vector max_shape_list; - max_shape_list.push_back(1); - AttrUtils::SetListInt(tensor, ATTR_NAME_TENSOR_MAX_SHAPE, max_shape_list); - graphStatus ret = - TensorUtils::CalcTensorMemSizeForNoTiling(tensor, format, data_type, mem_size); - EXPECT_EQ(ret, PARAM_INVALID); -} - -TEST_F(ge_test_tensor_utils, GetMaxShapeDimsFromNoTilingTensorSuccess) { - vector dims({0, -1}); - GeShape ge_shape(dims); - Format format = FORMAT_ND; - DataType data_type = DT_FLOAT; - int64_t mem_size = 0; - GeTensorDesc tensor(ge_shape); - std::vector max_shape_list; - max_shape_list.push_back(1); - max_shape_list.push_back(2); - AttrUtils::SetListInt(tensor, ATTR_NAME_TENSOR_MAX_SHAPE, max_shape_list); - graphStatus ret = - TensorUtils::CalcTensorMemSizeForNoTiling(tensor, format, data_type, mem_size); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(ge_test_tensor_utils, GetMaxShapeDimsFromNoTilingTensorGetShapeRangeFail) { - vector dims({0, -1}); - GeShape ge_shape(dims); - Format format = FORMAT_MAX; - DataType data_type = DT_MAX; - int64_t mem_size = 0; - GeTensorDesc tensor(ge_shape); - - std::vector> range; - range.push_back(std::vector(1)); - AttrUtils::SetListListInt(tensor, "shape_range", range); - graphStatus ret = - TensorUtils::CalcTensorMemSizeForNoTiling(tensor, format, data_type, mem_size); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(ge_test_tensor_utils, GetTensorSizeInBytesOutputMemSizeLessThan0) { - GeTensorDesc tensorDesc(GeShape({1, -1})); - int64_t size; - (void)AttrUtils::SetBool(&tensorDesc, ATTR_NAME_TENSOR_NO_TILING_MEM_TYPE, false); - graphStatus ret = TensorUtils::GetTensorSizeInBytes(tensorDesc, size); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(ge_test_tensor_utils, CalcTensorMemSizeDataTypeIsDtStringRef) { - vector dims({0, 0}); - GeShape ge_shape(dims); - Format format = FORMAT_MAX; - DataType data_type = DT_STRING_REF; - int64_t mem_size; - graphStatus ret = - TensorUtils::CalcTensorMemSize(ge_shape, format, data_type, mem_size); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(ge_test_tensor_utils, CalcTensorMemSizeNYUV) { - vector dims({2, 3, 4, 5}); - GeShape ge_shape(dims); - Format format = FORMAT_NYUV; - DataType data_type = DT_FLOAT; - int64_t mem_size = 0; - graphStatus ret = TensorUtils::CalcTensorMemSize(ge_shape, format, data_type, mem_size); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - int format_sub = (format | YVU420_SP << 8); - ret = TensorUtils::CalcTensorMemSize(ge_shape, static_cast(format_sub), data_type, mem_size); - EXPECT_EQ(mem_size, 2 * 3 * 4 * 5 * 4); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - format_sub = (format | YUV422_SP << 8); - ret = TensorUtils::CalcTensorMemSize(ge_shape, static_cast(format_sub), data_type, mem_size); - EXPECT_EQ(mem_size, 2 * 3 * 4 * 5 * 4); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - format_sub = (format | YUV400 << 8); - ret = TensorUtils::CalcTensorMemSize(ge_shape, static_cast(format_sub), data_type, mem_size); - EXPECT_EQ(mem_size, 2 * 3 * 4 * 5 * 4); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} -TEST_F(ge_test_tensor_utils, CalcTensorMemSizeNCL) { - vector dims({2, 3, 4}); - GeShape ge_shape(dims); - Format format = FORMAT_NCL; - DataType data_type = DT_FLOAT; - int64_t mem_size = 0; - graphStatus ret = TensorUtils::CalcTensorMemSize(ge_shape, format, data_type, mem_size); - EXPECT_EQ(mem_size, 2 * 3 * 4 * 4); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(ge_test_tensor_utils, CalcTensorMemSizeC1HWC0) { - vector dims({1, 2, 3, 4}); - GeShape ge_shape(dims); - Format format = FORMAT_C1HWC0; - DataType data_type = DT_INT8; - int64_t mem_size = 0; - graphStatus ret = TensorUtils::CalcTensorMemSize(ge_shape, format, data_type, mem_size); - EXPECT_EQ(mem_size, 1 * 2 * 3 * 4 * 1); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} -} diff --git a/tests/ut/graph/testcase/test_std_structs.cc b/tests/ut/graph/testcase/test_std_structs.cc deleted file mode 100644 index f18e2438153b71df333854030d6b77038ecc9595..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/test_std_structs.cc +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "test_std_structs.h" - -#include - -#include "graph/ge_tensor.h" -#include "graph/op_desc.h" -#include "graph/utils/attr_utils.h" - -namespace ge { - -GeTensorDesc StandardTd_5d_1_1_224_224() { - GeTensorDesc td; - td.SetShape(GeShape(std::vector({1, 1, 224, 224, 16}))); - td.SetOriginShape(GeShape(std::vector({1, 1, 224, 224}))); - td.SetFormat(FORMAT_NC1HWC0); - td.SetOriginFormat(FORMAT_NCHW); - td.SetDataType(DT_FLOAT16); - td.SetOriginDataType(DT_FLOAT); - - vector input_size = {12}; - AttrUtils::SetListInt(td, "input_size", input_size); - - return td; -} - -void ExpectStandardTdProto_5d_1_1_224_224(const proto::TensorDescriptor &input_td) { - // shape - EXPECT_EQ(input_td.shape().dim_size(), 5); - EXPECT_EQ(input_td.shape().dim(0), 1); - EXPECT_EQ(input_td.shape().dim(1), 1); - EXPECT_EQ(input_td.shape().dim(2), 224); - EXPECT_EQ(input_td.shape().dim(3), 224); - EXPECT_EQ(input_td.shape().dim(4), 16); - - // origin shape, origin shape is set - EXPECT_EQ(input_td.attr().count("origin_shape"), 1); - EXPECT_EQ(input_td.attr().at("origin_shape").value_case(), proto::AttrDef::ValueCase::kList); - EXPECT_EQ(input_td.attr().at("origin_shape").list().val_type(), proto::AttrDef_ListValue_ListValueType_VT_LIST_INT); - EXPECT_EQ(input_td.attr().at("origin_shape").list().i_size(), 4); - EXPECT_EQ(input_td.attr().at("origin_shape").list().i(0), 1); - EXPECT_EQ(input_td.attr().at("origin_shape").list().i(1), 1); - EXPECT_EQ(input_td.attr().at("origin_shape").list().i(2), 224); - EXPECT_EQ(input_td.attr().at("origin_shape").list().i(3), 224); - EXPECT_EQ(input_td.attr().count("origin_shape_initialized"), 1); - EXPECT_EQ(input_td.attr().at("origin_shape_initialized").value_case(), proto::AttrDef::ValueCase::kB); - EXPECT_EQ(input_td.attr().at("origin_shape_initialized").b(), true); - - // format, origin format - EXPECT_EQ(input_td.attr().count("origin_format"), 1); - EXPECT_EQ(input_td.attr().at("origin_format").s(), "NCHW"); - EXPECT_EQ(input_td.layout(), "NC1HWC0"); - - // data_tpye, origin data_type - EXPECT_EQ(input_td.dtype(), proto::DT_FLOAT16); - EXPECT_EQ(input_td.attr().count("origin_data_type"), 1); - EXPECT_EQ(input_td.attr().at("origin_data_type").s(), "DT_FLOAT"); - - EXPECT_EQ(input_td.attr().count("input_size"), 1); - EXPECT_EQ(input_td.attr().at("input_size").value_case(), proto::AttrDef::ValueCase::kList); - EXPECT_EQ(input_td.attr().at("input_size").list().val_type(), proto::AttrDef_ListValue_ListValueType_VT_LIST_INT); - EXPECT_EQ(input_td.attr().at("input_size").list().i_size(), 1); - EXPECT_EQ(input_td.attr().at("input_size").list().i(0), 12); -} -} diff --git a/tests/ut/graph/testcase/test_std_structs.h b/tests/ut/graph/testcase/test_std_structs.h deleted file mode 100644 index 52fe2f9949b0d99984f04d3ff69bf3643005f15e..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/test_std_structs.h +++ /dev/null @@ -1,21 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef METADEF_CXX_TEST_STD_STRUCTS_H -#define METADEF_CXX_TEST_STD_STRUCTS_H -#include "proto/ge_ir.pb.h" -#include "graph/ge_tensor.h" -#include "graph/op_desc.h" -#include "graph/utils/attr_utils.h" - -namespace ge { -GeTensorDesc StandardTd_5d_1_1_224_224(); -void ExpectStandardTdProto_5d_1_1_224_224(const proto::TensorDescriptor &input_td); -} -#endif //METADEF_CXX_TEST_STD_STRUCTS_H diff --git a/tests/ut/graph/testcase/trace_manager_unittest.cc b/tests/ut/graph/testcase/trace_manager_unittest.cc deleted file mode 100644 index 2f2d4629d76aa7a0ff21bc1c62022a9ac6bdc53d..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/trace_manager_unittest.cc +++ /dev/null @@ -1,78 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include -#include -#include "mmpa/mmpa_api.h" -#include "inc/common/util/trace_manager/trace_manager.h" - -using namespace std; - -namespace { -size_t GetFileLinesNum(const std::string fn) { - std::ifstream f; - f.open(fn, std::ios::in); - size_t num = 0U; - if (f.fail()) { - return 0U; - } else { - std::string s; - while (std::getline(f, s)) { - num++; - } - f.close(); - } - return num; -} -} // namespace - -namespace ge { -class UtestTraceManager : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -TEST_F(UtestTraceManager, add_trace_basic_0) { - auto &instance = TraceManager::GetInstance(); - instance.ClearTraceOwner(); - instance.SetTraceOwner("a", "b", "c"); - instance.trace_index_ = 0; - EXPECT_EQ(instance.Initialize("."), SUCCESS); - instance.enabled_ = true; - instance.ClearTraceOwner(); - instance.SetTraceOwner("a", "b", "c"); - EXPECT_EQ(instance.trace_header_, "a:b"); - EXPECT_EQ(instance.graph_name_, "c"); - instance.AddTrace("0"); - EXPECT_EQ(instance.trace_index_, 1); - for (int i = 0; i < 10000; i++) { - instance.AddTrace(std::to_string(i + 1)); - } - instance.Finalize(); - EXPECT_EQ(instance.current_file_saved_nums_, 10001); - EXPECT_EQ(GetFileLinesNum(instance.current_saving_file_name_), 10001); - std::string pre_file_name = instance.current_saving_file_name_; - instance.stopped_ = false; - instance.current_file_saved_nums_ = 2000000U + 1U; - EXPECT_EQ(instance.Initialize("."), SUCCESS); - - for (int i = 0; i < 100; i++) { - instance.AddTrace(std::to_string(i + 1)); - } - instance.Finalize(); - - EXPECT_NE(pre_file_name, instance.current_saving_file_name_); - remove(instance.current_saving_file_name_.c_str()); - remove(pre_file_name.c_str()); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/transformer_expand_dims_ut.cc b/tests/ut/graph/testcase/transformer_expand_dims_ut.cc deleted file mode 100644 index 07ee555a431fa71b0668ded93b10123967f418ce..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/transformer_expand_dims_ut.cc +++ /dev/null @@ -1,644 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include "graph/ge_tensor.h" -#include "graph/utils/attr_utils.h" -#include "graph/debug/ge_attr_define.h" -#include "axis_util.h" -#include "expand_dimension.h" - -namespace transformer { -class TransformerExpandDimsUT : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} - - void EXPECT_RunExpandDimsCase(const ge::Format &origin_format, const ge::Format &format, const string &reshape_type, - const vector &dims, const vector &expect_dims) { - std::cout << "EXPECT_RunExpandDimsCase: origin_format=" << origin_format << ", format=" << format - << ", reahpe type=" << reshape_type << ", dim size=" << dims.size() << std::endl; - string op_type = "Relu"; - uint32_t tensor_index = 0; - ge::GeShape shape(dims); - bool ret = ExpandDimension(op_type, origin_format, format, tensor_index, reshape_type, shape); - EXPECT_EQ(ret, true); - EXPECT_EQ(shape.GetDims(), expect_dims); - - ge::GeShape new_shape(dims); - int64_t int_reshape_type = ExpandDimension::GenerateReshapeType(origin_format, format, new_shape.GetDimNum(), - reshape_type); - if (int_reshape_type != 0) { - size_t full_size = static_cast(int_reshape_type >> 56); - size_t expect_full_size = 0; - ExpandDimension::GetFormatFullSize(origin_format, expect_full_size); - EXPECT_EQ(full_size, expect_full_size); - } - - ExpandDimension::ExpandDims(int_reshape_type, new_shape); - EXPECT_EQ(new_shape.GetDims(), expect_dims); - - ge::GeShape shape_1(dims); - ge::GeShape shape_2(dims); - ExpandDimension::ExpandDims(int_reshape_type, shape_1, shape_2); - EXPECT_EQ(shape_2.GetDims(), expect_dims); - } - - void EXPECT_RunNewExpandDimsCase(const ge::Format &origin_format, const ge::Format &format, const string &reshape_type, - const vector &dims, const vector &expect_dims) { - std::cout << "origin_format=" << origin_format << ", format=" << format - << ", reahpe type=" << reshape_type << ", dim size=" << dims.size() << std::endl; - ge::GeShape new_shape(dims); - int64_t int_reshape_type = ExpandDimension::GenerateReshapeType(origin_format, format, new_shape.GetDimNum(), - reshape_type); - if (int_reshape_type != 0) { - size_t full_size = static_cast(int_reshape_type >> 56); - size_t expect_full_size = 0; - ExpandDimension::GetFormatFullSize(origin_format, expect_full_size); - EXPECT_EQ(full_size, expect_full_size); - } - ExpandDimension::ExpandDims(int_reshape_type, new_shape); - EXPECT_EQ(new_shape.GetDims(), expect_dims); - } - - void RunReshapeTypeCase(const ge::Format &format, const size_t &dims_size, const std::string &reshape_type) { - std::cout << "RunReshapeTypeCase: origin format=" << format << ", dims size=" << dims_size << ", reshape_type=" << reshape_type << std::endl; - int64_t reshape_mask = transformer::ExpandDimension::GenerateReshapeType(format, ge::FORMAT_NC1HWC0, dims_size, reshape_type); - std::string ret_shape_type; - std::string fail_reason; - bool ret = transformer::ExpandDimension::GenerateReshapeTypeByMask(format, dims_size, reshape_mask, ret_shape_type, fail_reason); - EXPECT_EQ(ret, true); - EXPECT_EQ(fail_reason.empty(), true); - EXPECT_EQ(reshape_type, ret_shape_type); - } -}; - -TEST_F(TransformerExpandDimsUT, all_expand_dims_cases_1) { - int64_t max_reshape_type = 0xff; - vector full_size_vec = {4, 5}; - vector> dim_vecs = {{}, {5}, {5, 6}, {5, 6, 7}, {5, 6, 7, 8}, {5, 6, 7, 8, 9}}; - for (const size_t &full_size : full_size_vec) { - for (const vector &dims : dim_vecs) { - for (int64_t i = 0; i <= max_reshape_type; i++) { - ge::GeShape shape(dims); - int64_t reshape_type = i | (full_size << 56); - std::cout << "reshape_type = " << std::bitset<8>(reshape_type) << ", shape = " << shape.ToString(); - EXPECT_NO_THROW(ExpandDimension::ExpandDims(reshape_type, shape)); - std::cout << ", after expand dims shape = " << shape.ToString() << std::endl; - } - } - } -} - -TEST_F(TransformerExpandDimsUT, all_expand_dims_cases_2) { - int64_t max_reshape_type = 0xff; - vector full_size_vec = {4, 5}; - vector> dim_vecs = {{}, {5}, {5, 6}, {5, 6, 7}, {5, 6, 7, 8}, {5, 6, 7, 8, 9}}; - for (const size_t &full_size : full_size_vec) { - for (const vector &dims : dim_vecs) { - for (int64_t i = 0; i <= max_reshape_type; i++) { - ge::GeShape shape_1(dims); - ge::GeShape shape_2(dims); - int64_t reshape_type = i | (full_size << 56); - std::cout << "reshape_type = " << std::bitset<8>(reshape_type) << ", shape = " << shape_1.ToString(); - EXPECT_NO_THROW(ExpandDimension::ExpandDims(reshape_type, shape_1, shape_2)); - std::cout << ", after expand dims shape = " << shape_2.ToString() << std::endl; - } - } - } -} - -TEST_F(TransformerExpandDimsUT, not_expand_cases) { - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "FORBIDDEN", {8, 9}, {8, 9}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_NZ, "HW", {8, 9}, {8, 9}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "", {6, 7, 8, 9}, {6, 7, 8, 9}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "", {5, 6, 7, 8, 9}, {5, 6, 7, 8, 9}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NC1HWC0, "", {5, 6, 7, 8, 9}, {5, 6, 7, 8, 9}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NC1HWC0, "", {4, 5, 6, 7, 8, 9}, {4, 5, 6, 7, 8, 9}); - EXPECT_RunExpandDimsCase(ge::FORMAT_ND, ge::FORMAT_FRACTAL_Z, "CN", {8, 9}, {8, 9}); - - EXPECT_RunNewExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_ND, "HW", {8, 9}, {8, 9}); - EXPECT_RunNewExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_ND_RNN_BIAS, "HW", {8, 9}, {8, 9}); - EXPECT_RunNewExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_ZN_RNN, "HW", {8, 9}, {8, 9}); -} - -TEST_F(TransformerExpandDimsUT, default_reshape_type_cases) { - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "WN", {}, {1, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "CN", {}, {1, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "NH", {}, {1, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "NC", {}, {1, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NC1HWC0, "CN", {}, {1, 1, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NC1HWC0, "WN", {}, {1, 1, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NC1HWC0, "ND", {}, {1, 1, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWNC, ge::FORMAT_NC1HWC0, "CD", {}, {1, 1, 1, 1, 1}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "", {5}, {1, 5, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "", {5}, {1, 1, 1, 5}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "", {5}, {1, 1, 5, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "", {5}, {5, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NC1HWC0, "", {5}, {1, 1, 1, 1, 5}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NC1HWC0, "", {5}, {1, 5, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NC1HWC0, "", {5}, {1, 1, 1, 5, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWNC, ge::FORMAT_NC1HWC0, "", {5}, {1, 1, 1, 1, 5}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "WN", {5}, {1, 5, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "CWN", {5}, {1, 1, 1, 5}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "NH", {5}, {1, 1, 5, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "NCHW", {5}, {5, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NC1HWC0, "CN", {5}, {1, 1, 1, 1, 5}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NC1HWC0, "WNCD", {5}, {1, 5, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NC1HWC0, "ND", {5}, {1, 1, 1, 5, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWNC, ge::FORMAT_NC1HWC0, "CD", {5}, {1, 1, 1, 1, 5}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "WN", {5, 6}, {1, 5, 6, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "CN", {5, 6}, {1, 5, 6, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "NH", {5, 6}, {1, 1, 5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "NC", {5, 6}, {1, 1, 5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NC1HWC0, "CN", {5, 6}, {1, 1, 1, 5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NC1HWC0, "WN", {5, 6}, {1, 1, 1, 5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NC1HWC0, "ND", {5, 6}, {1, 1, 1, 5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWNC, ge::FORMAT_NC1HWC0, "CD", {5, 6}, {1, 1, 1, 5, 6}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "WHN", {5, 6, 7}, {1, 5, 6, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "CWN", {5, 6, 7}, {1, 5, 6, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "NHW", {5, 6, 7}, {1, 5, 6, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "CNW", {5, 6, 7}, {1, 5, 6, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NC1HWC0, "CND", {5, 6, 7}, {1, 1, 5, 6, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NC1HWC0, "WDN", {5, 6, 7}, {1, 1, 5, 6, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NC1HWC0, "WCND", {5, 6, 7}, {1, 1, 5, 6, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWNC, ge::FORMAT_NC1HWC0, "NCD", {5, 6, 7}, {1, 1, 5, 6, 7}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NC1HWC0, "CNWD", {5, 6, 7, 8}, {1, 5, 6, 7, 8}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NC1HWC0, "NDHWC", {5, 6, 7, 8}, {1, 5, 6, 7, 8}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NC1HWC0, "NCHW", {5, 6, 7, 8}, {1, 5, 6, 7, 8}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWNC, ge::FORMAT_NC1HWC0, "NCDH", {5, 6, 7, 8}, {1, 5, 6, 7, 8}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_DHWNC, "N", {5}, {5, 1, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NCDHW, "D", {5}, {1, 5, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDHWC, "H", {5}, {1, 1, 5, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_DHWCN, "W", {5}, {1, 1, 1, 5, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDHWC, "ND", {6, 5}, {6, 5, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NCDHW, "HW", {6, 5}, {1, 1, 6, 5, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_DHWCN, "NC", {6, 5}, {6, 1, 1, 1, 5}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_DHWNC, "DC", {6, 5}, {1, 6, 1, 1, 5}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDHWC, "NDH", {7, 6, 5}, {7, 6, 5, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NCDHW, "HWC", {7, 6, 5}, {1, 1, 7, 6, 5}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDHWC, "NDHW", {8, 7, 6, 5}, {8, 7, 6, 5, 1}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_DHWNC, "N", {5}, {5, 1, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NCDHW, "C", {5}, {1, 5, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDHWC, "D", {5}, {1, 1, 5, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_DHWCN, "H", {5}, {1, 1, 1, 5, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDHWC, "NC", {6, 5}, {6, 5, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NCDHW, "DH", {6, 5}, {1, 1, 6, 5, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_DHWCN, "NW", {6, 5}, {6, 1, 1, 1, 5}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_DHWNC, "CW", {6, 5}, {1, 6, 1, 1, 5}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDHWC, "NCD", {7, 6, 5}, {7, 6, 5, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NCDHW, "DHW", {7, 6, 5}, {1, 1, 7, 6, 5}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDHWC, "NCDH", {8, 7, 6, 5}, {8, 7, 6, 5, 1}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_DHWNC, "D", {5}, {5, 1, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NCDHW, "H", {5}, {1, 5, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDHWC, "W", {5}, {1, 1, 5, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_DHWCN, "C", {5}, {1, 1, 1, 5, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDHWC, "DH", {6, 5}, {6, 5, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NCDHW, "WC", {6, 5}, {1, 1, 6, 5, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_DHWCN, "DN", {6, 5}, {6, 1, 1, 1, 5}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_DHWNC, "HN", {6, 5}, {1, 6, 1, 1, 5}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDHWC, "DHW", {7, 6, 5}, {7, 6, 5, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NCDHW, "WCN", {7, 6, 5}, {1, 1, 7, 6, 5}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDHWC, "DHWC", {8, 7, 6, 5}, {8, 7, 6, 5, 1}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWNC, ge::FORMAT_DHWNC, "D", {5}, {5, 1, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWNC, ge::FORMAT_NCDHW, "H", {5}, {1, 5, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWNC, ge::FORMAT_NDHWC, "W", {5}, {1, 1, 5, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWNC, ge::FORMAT_DHWCN, "N", {5}, {1, 1, 1, 5, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWNC, ge::FORMAT_NDHWC, "DH", {6, 5}, {6, 5, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWNC, ge::FORMAT_NCDHW, "WN", {6, 5}, {1, 1, 6, 5, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWNC, ge::FORMAT_DHWCN, "DC", {6, 5}, {6, 1, 1, 1, 5}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWNC, ge::FORMAT_DHWNC, "HC", {6, 5}, {1, 6, 1, 1, 5}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWNC, ge::FORMAT_NDHWC, "DHW", {7, 6, 5}, {7, 6, 5, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWNC, ge::FORMAT_NCDHW, "WNC", {7, 6, 5}, {1, 1, 7, 6, 5}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWNC, ge::FORMAT_NDHWC, "DHWN", {8, 7, 6, 5}, {8, 7, 6, 5, 1}); -} - -TEST_F(TransformerExpandDimsUT, nchw_reshape_type) { - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "", {}, {1, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "N", {}, {1, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "HW", {}, {1, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "HCW", {}, {1, 1, 1, 1}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "N", {5}, {5, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "C", {5}, {1, 5, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "H", {5}, {1, 1, 5, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "W", {5}, {1, 1, 1, 5}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "NCHW", {5}, {5, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "CHW", {5}, {1, 5, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "HW", {5}, {1, 1, 5, 1}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "N", {5, 6}, {5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "C", {5, 6}, {5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "H", {5, 6}, {5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "W", {5, 6}, {5, 6}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "NC", {5, 6}, {5, 6, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "NCH", {5, 6}, {5, 6, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "NCHW", {5, 6}, {5, 6, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "NCW", {5, 6}, {5, 6, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "NH", {5, 6}, {5, 1, 6, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "NHW", {5, 6}, {5, 1, 6, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "NW", {5, 6}, {5, 1, 1, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "CH", {5, 6}, {1, 5, 6, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "CHW", {5, 6}, {1, 5, 6, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "CW", {5, 6}, {1, 5, 1, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "HW", {5, 6}, {1, 1, 5, 6}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "NCH", {5, 6, 7}, {5, 6, 7, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "NCHW", {5, 6, 7}, {5, 6, 7, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "NCW", {5, 6, 7}, {5, 6, 1, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "NHW", {5, 6, 7}, {5, 1, 6, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "CHW", {5, 6, 7}, {1, 5, 6, 7}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "NCHW", {5, 6, 7, 8}, {5, 6, 7, 8}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "NC", {5, 6, 7, 8}, {5, 6, 7, 8}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "NCHW", {5, 6, 7, 8, 9}, {5, 6, 7, 8, 9}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, "HW", {5, 6, 7, 8, 9}, {5, 6, 7, 8, 9}); -} - -TEST_F(TransformerExpandDimsUT, nhwc_reshape_type) { - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "", {}, {1, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "N", {}, {1, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "NH", {}, {1, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "HWC", {}, {1, 1, 1, 1}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "N", {5}, {5, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "H", {5}, {1, 5, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "W", {5}, {1, 1, 5, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "C", {5}, {1, 1, 1, 5}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "NHWC", {5}, {5, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "HWC", {5}, {1, 5, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "WC", {5}, {1, 1, 5, 1}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "N", {5, 6}, {5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "H", {5, 6}, {5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "W", {5, 6}, {5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "C", {5, 6}, {5, 6}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "NH", {5, 6}, {5, 6, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "NHW", {5, 6}, {5, 6, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "NHWC", {5, 6}, {5, 6, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "NHC", {5, 6}, {5, 6, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "NW", {5, 6}, {5, 1, 6, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "NWC", {5, 6}, {5, 1, 6, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "NC", {5, 6}, {5, 1, 1, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "HW", {5, 6}, {1, 5, 6, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "HWC", {5, 6}, {1, 5, 6, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "HC", {5, 6}, {1, 5, 1, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "WC", {5, 6}, {1, 1, 5, 6}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "NHW", {5, 6, 7}, {5, 6, 7, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "NHWC", {5, 6, 7}, {5, 6, 7, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "NHC", {5, 6, 7}, {5, 6, 1, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "NWC", {5, 6, 7}, {5, 1, 6, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "HWC", {5, 6, 7}, {1, 5, 6, 7}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "NHWC", {5, 6, 7, 8}, {5, 6, 7, 8}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "NC", {5, 6, 7, 8}, {5, 6, 7, 8}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "NHWC", {5, 6, 7, 8, 9}, {5, 6, 7, 8, 9}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NHWC, ge::FORMAT_NC1HWC0, "HW", {5, 6, 7, 8, 9}, {5, 6, 7, 8, 9}); -} - -TEST_F(TransformerExpandDimsUT, hwcn_reshape_type) { - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "", {}, {1, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "N", {}, {1, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "HW", {}, {1, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "HWN", {}, {1, 1, 1, 1}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "H", {5}, {5, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "W", {5}, {1, 5, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "C", {5}, {1, 1, 5, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "N", {5}, {1, 1, 1, 5}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "HWCN", {5}, {5, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "WCN", {5}, {1, 5, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "CN", {5}, {1, 1, 5, 1}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "H", {5, 6}, {5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "W", {5, 6}, {5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "C", {5, 6}, {5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "N", {5, 6}, {5, 6}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "HW", {5, 6}, {5, 6, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "HWC", {5, 6}, {5, 6, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "HWCN", {5, 6}, {5, 6, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "HWN", {5, 6}, {5, 6, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "HC", {5, 6}, {5, 1, 6, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "HCN", {5, 6}, {5, 1, 6, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "HN", {5, 6}, {5, 1, 1, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "WC", {5, 6}, {1, 5, 6, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "WCN", {5, 6}, {1, 5, 6, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "WN", {5, 6}, {1, 5, 1, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "CN", {5, 6}, {1, 1, 5, 6}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "HWC", {5, 6, 7}, {5, 6, 7, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "HWCN", {5, 6, 7}, {5, 6, 7, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "HWN", {5, 6, 7}, {5, 6, 1, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "HCN", {5, 6, 7}, {5, 1, 6, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "WCN", {5, 6, 7}, {1, 5, 6, 7}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "HWCN", {5, 6, 7, 8}, {5, 6, 7, 8}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "HW", {5, 6, 7, 8}, {5, 6, 7, 8}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "HWCN", {5, 6, 7, 8, 9}, {5, 6, 7, 8, 9}); - EXPECT_RunExpandDimsCase(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, "CN", {5, 6, 7, 8, 9}, {5, 6, 7, 8, 9}); -} - -TEST_F(TransformerExpandDimsUT, chwn_reshape_type) { - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "", {}, {1, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "C", {}, {1, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "HW", {}, {1, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "HWN", {}, {1, 1, 1, 1}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "C", {5}, {5, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "H", {5}, {1, 5, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "W", {5}, {1, 1, 5, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "N", {5}, {1, 1, 1, 5}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "CHWN", {5}, {5, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "HWN", {5}, {1, 5, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "WN", {5}, {1, 1, 5, 1}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "C", {5, 6}, {5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "H", {5, 6}, {5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "W", {5, 6}, {5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "N", {5, 6}, {5, 6}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "CH", {5, 6}, {5, 6, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "CHW", {5, 6}, {5, 6, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "CHWN", {5, 6}, {5, 6, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "CHN", {5, 6}, {5, 6, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "CW", {5, 6}, {5, 1, 6, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "CWN", {5, 6}, {5, 1, 6, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "CN", {5, 6}, {5, 1, 1, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "HW", {5, 6}, {1, 5, 6, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "HWN", {5, 6}, {1, 5, 6, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "HN", {5, 6}, {1, 5, 1, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "WN", {5, 6}, {1, 1, 5, 6}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "CHW", {5, 6, 7}, {5, 6, 7, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "CHWN", {5, 6, 7}, {5, 6, 7, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "CHN", {5, 6, 7}, {5, 6, 1, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "CWN", {5, 6, 7}, {5, 1, 6, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "HWN", {5, 6, 7}, {1, 5, 6, 7}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "CHWN", {5, 6, 7, 8}, {5, 6, 7, 8}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "HW", {5, 6, 7, 8}, {5, 6, 7, 8}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "CHWN", {5, 6, 7, 8, 9}, {5, 6, 7, 8, 9}); - EXPECT_RunExpandDimsCase(ge::FORMAT_CHWN, ge::FORMAT_NC1HWC0, "CN", {5, 6, 7, 8, 9}, {5, 6, 7, 8, 9}); -} - -TEST_F(TransformerExpandDimsUT, ndhwc_reshape_type) { - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "", {}, {1, 1, 1, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "C", {}, {1, 1, 1, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "HW", {}, {1, 1, 1, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "NHW", {}, {1, 1, 1, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "NDWC", {}, {1, 1, 1, 1 ,1}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "N", {5}, {5, 1, 1, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "D", {5}, {1, 5, 1, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "H", {5}, {1, 1, 5, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "W", {5}, {1, 1, 1, 5 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "C", {5}, {1, 1, 1, 1 ,5}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "NDHWC", {5}, {5, 1, 1, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "DHWC", {5}, {1, 5, 1, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "HWC", {5}, {1, 1, 5, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "WC", {5}, {1, 1, 1, 5 ,1}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "N", {5, 6}, {5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "D", {5, 6}, {5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "H", {5, 6}, {5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "W", {5, 6}, {5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "C", {5, 6}, {5, 6}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "ND", {5, 6}, {5, 6, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "NDH", {5, 6}, {5, 6, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "NDW", {5, 6}, {5, 6, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "NDC", {5, 6}, {5, 6, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "NDHW", {5, 6}, {5, 6, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "NDHC", {5, 6}, {5, 6, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "NDHWC", {5, 6}, {5, 6, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "NH", {5, 6}, {5, 1, 6, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "NW", {5, 6}, {5, 1, 1, 6, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "NC", {5, 6}, {5, 1, 1, 1, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "DH", {5, 6}, {1, 5, 6, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "DW", {5, 6}, {1, 5, 1, 6, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "DC", {5, 6}, {1, 5, 1, 1, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "HW", {5, 6}, {1, 1, 5, 6, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "HC", {5, 6}, {1, 1, 5, 1, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "WC", {5, 6}, {1, 1, 1, 5, 6}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "NDH", {5, 6, 7}, {5, 6, 7, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "NDW", {5, 6, 7}, {5, 6, 1, 7, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "NDC", {5, 6, 7}, {5, 6, 1, 1, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "NHW", {5, 6, 7}, {5, 1, 6, 7, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "NHC", {5, 6, 7}, {5, 1, 6, 1, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "NWC", {5, 6, 7}, {5, 1, 1, 6, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "DHW", {5, 6, 7}, {1, 5, 6, 7, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "DHC", {5, 6, 7}, {1, 5, 6, 1, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "DWC", {5, 6, 7}, {1, 5, 1, 6, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "HWC", {5, 6, 7}, {1, 1, 5, 6, 7}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "NDHW", {5, 6, 7, 8}, {5, 6, 7, 8, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "NDHWC", {5, 6, 7, 8}, {5, 6, 7, 8, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "NDHC", {5, 6, 7, 8}, {5, 6, 7, 1, 8}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "NDWC", {5, 6, 7, 8}, {5, 6, 1, 7, 8}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "NHWC", {5, 6, 7, 8}, {5, 1, 6, 7, 8}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "DHWC", {5, 6, 7, 8}, {1, 5, 6, 7, 8}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "NHWC", {5, 6, 7, 8, 9}, {5, 6, 7, 8, 9}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "NDHWC", {5, 6, 7, 8, 9}, {5, 6, 7, 8, 9}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "NDWC", {5, 6, 7, 8, 9, 7}, {5, 6, 7, 8, 9, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NDHWC, ge::FORMAT_NDC1HWC0, "NDHWC", {5, 6, 7, 8, 9, 7}, {5, 6, 7, 8, 9, 7}); -} - -TEST_F(TransformerExpandDimsUT, ncdhw_reshape_type) { - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "", {}, {1, 1, 1, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "C", {}, {1, 1, 1, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "HW", {}, {1, 1, 1, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "NHW", {}, {1, 1, 1, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "NDHW", {}, {1, 1, 1, 1 ,1}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "N", {5}, {5, 1, 1, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "C", {5}, {1, 5, 1, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "D", {5}, {1, 1, 5, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "H", {5}, {1, 1, 1, 5 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "W", {5}, {1, 1, 1, 1 ,5}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "NCDHW", {5}, {5, 1, 1, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "CDHW", {5}, {1, 5, 1, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "DHW", {5}, {1, 1, 5, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "HW", {5}, {1, 1, 1, 5 ,1}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "N", {5, 6}, {5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "C", {5, 6}, {5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "D", {5, 6}, {5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "H", {5, 6}, {5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "W", {5, 6}, {5, 6}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "NC", {5, 6}, {5, 6, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "NCD", {5, 6}, {5, 6, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "NCH", {5, 6}, {5, 6, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "NCW", {5, 6}, {5, 6, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "NCDH", {5, 6}, {5, 6, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "NCDW", {5, 6}, {5, 6, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "NCHW", {5, 6}, {5, 6, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "NCDHW", {5, 6}, {5, 6, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "ND", {5, 6}, {5, 1, 6, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "NH", {5, 6}, {5, 1, 1, 6, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "NW", {5, 6}, {5, 1, 1, 1, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "CD", {5, 6}, {1, 5, 6, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "CH", {5, 6}, {1, 5, 1, 6, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "CW", {5, 6}, {1, 5, 1, 1, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "DH", {5, 6}, {1, 1, 5, 6, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "DW", {5, 6}, {1, 1, 5, 1, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "HW", {5, 6}, {1, 1, 1, 5, 6}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "NCD", {5, 6, 7}, {5, 6, 7, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "NCH", {5, 6, 7}, {5, 6, 1, 7, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "NCW", {5, 6, 7}, {5, 6, 1, 1, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "NDH", {5, 6, 7}, {5, 1, 6, 7, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "NDW", {5, 6, 7}, {5, 1, 6, 1, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "NHW", {5, 6, 7}, {5, 1, 1, 6, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "CDH", {5, 6, 7}, {1, 5, 6, 7, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "CDW", {5, 6, 7}, {1, 5, 6, 1, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "CHW", {5, 6, 7}, {1, 5, 1, 6, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "DHW", {5, 6, 7}, {1, 1, 5, 6, 7}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "NCDH", {5, 6, 7, 8}, {5, 6, 7, 8, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "NCDHW", {5, 6, 7, 8}, {5, 6, 7, 8, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "NCDW", {5, 6, 7, 8}, {5, 6, 7, 1, 8}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "NCHW", {5, 6, 7, 8}, {5, 6, 1, 7, 8}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "NDHW", {5, 6, 7, 8}, {5, 1, 6, 7, 8}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "CDHW", {5, 6, 7, 8}, {1, 5, 6, 7, 8}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "NCHW", {5, 6, 7, 8, 9}, {5, 6, 7, 8, 9}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "NCDHW", {5, 6, 7, 8, 9}, {5, 6, 7, 8, 9}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "CDHW", {5, 6, 7, 8, 9, 7}, {5, 6, 7, 8, 9, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, "NCDHW", {5, 6, 7, 8, 9, 7}, {5, 6, 7, 8, 9, 7}); -} - -TEST_F(TransformerExpandDimsUT, dhwcn_reshape_type) { - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "", {}, {1, 1, 1, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "C", {}, {1, 1, 1, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "HW", {}, {1, 1, 1, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "HWC", {}, {1, 1, 1, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "DHWN", {}, {1, 1, 1, 1 ,1}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "D", {5}, {5, 1, 1, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "H", {5}, {1, 5, 1, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "W", {5}, {1, 1, 5, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "C", {5}, {1, 1, 1, 5 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "N", {5}, {1, 1, 1, 1 ,5}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "DHWCN", {5}, {5, 1, 1, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "HWCN", {5}, {1, 5, 1, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "WCN", {5}, {1, 1, 5, 1 ,1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "CN", {5}, {1, 1, 1, 5 ,1}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "D", {5, 6}, {5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "H", {5, 6}, {5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "W", {5, 6}, {5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "C", {5, 6}, {5, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "N", {5, 6}, {5, 6}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "DH", {5, 6}, {5, 6, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "DHW", {5, 6}, {5, 6, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "DHC", {5, 6}, {5, 6, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "DHN", {5, 6}, {5, 6, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "DHWC", {5, 6}, {5, 6, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "DHWN", {5, 6}, {5, 6, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "DHCN", {5, 6}, {5, 6, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "DHWCN", {5, 6}, {5, 6, 1, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "DW", {5, 6}, {5, 1, 6, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "DC", {5, 6}, {5, 1, 1, 6, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "DN", {5, 6}, {5, 1, 1, 1, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "HW", {5, 6}, {1, 5, 6, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "HC", {5, 6}, {1, 5, 1, 6, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "HN", {5, 6}, {1, 5, 1, 1, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "WC", {5, 6}, {1, 1, 5, 6, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "WN", {5, 6}, {1, 1, 5, 1, 6}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "CN", {5, 6}, {1, 1, 1, 5, 6}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "DHW", {5, 6, 7}, {5, 6, 7, 1, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "DHC", {5, 6, 7}, {5, 6, 1, 7, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "DHN", {5, 6, 7}, {5, 6, 1, 1, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "DWC", {5, 6, 7}, {5, 1, 6, 7, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "DWN", {5, 6, 7}, {5, 1, 6, 1, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "DCN", {5, 6, 7}, {5, 1, 1, 6, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "HWC", {5, 6, 7}, {1, 5, 6, 7, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "HWN", {5, 6, 7}, {1, 5, 6, 1, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "HCN", {5, 6, 7}, {1, 5, 1, 6, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "WCN", {5, 6, 7}, {1, 1, 5, 6, 7}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "DHWC", {5, 6, 7, 8}, {5, 6, 7, 8, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "DHWCN", {5, 6, 7, 8}, {5, 6, 7, 8, 1}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "DHWN", {5, 6, 7, 8}, {5, 6, 7, 1, 8}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "DHCN", {5, 6, 7, 8}, {5, 6, 1, 7, 8}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "DWCN", {5, 6, 7, 8}, {5, 1, 6, 7, 8}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "HWCN", {5, 6, 7, 8}, {1, 5, 6, 7, 8}); - - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "DHWC", {5, 6, 7, 8, 9}, {5, 6, 7, 8, 9}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "DHWCN", {5, 6, 7, 8, 9}, {5, 6, 7, 8, 9}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "HWCN", {5, 6, 7, 8, 9, 7}, {5, 6, 7, 8, 9, 7}); - EXPECT_RunExpandDimsCase(ge::FORMAT_DHWCN, ge::FORMAT_NDC1HWC0, "DHWCN", {5, 6, 7, 8, 9, 7}, {5, 6, 7, 8, 9, 7}); -} - -TEST_F(TransformerExpandDimsUT, reshape_type_case1) { - RunReshapeTypeCase(ge::FORMAT_NCHW, 1, "N"); - RunReshapeTypeCase(ge::FORMAT_NCHW, 2, "NC"); - RunReshapeTypeCase(ge::FORMAT_NCHW, 2, "HW"); - RunReshapeTypeCase(ge::FORMAT_NCHW, 3, "NHW"); - RunReshapeTypeCase(ge::FORMAT_NCHW, 3, "CHW"); - RunReshapeTypeCase(ge::FORMAT_NCHW, 4, "NCHW"); - - RunReshapeTypeCase(ge::FORMAT_NDHWC, 1, "D"); - RunReshapeTypeCase(ge::FORMAT_NDHWC, 2, "NC"); - RunReshapeTypeCase(ge::FORMAT_NDHWC, 2, "DC"); - RunReshapeTypeCase(ge::FORMAT_NDHWC, 2, "HW"); - RunReshapeTypeCase(ge::FORMAT_NDHWC, 3, "NHC"); - RunReshapeTypeCase(ge::FORMAT_NDHWC, 3, "HWC"); - RunReshapeTypeCase(ge::FORMAT_NDHWC, 3, "NHW"); - RunReshapeTypeCase(ge::FORMAT_NDHWC, 4, "DHWC"); - RunReshapeTypeCase(ge::FORMAT_NDHWC, 4, "NHWC"); - RunReshapeTypeCase(ge::FORMAT_NDHWC, 5, "NDHWC"); -} - -TEST_F(TransformerExpandDimsUT, reshape_type_case2) { - std::string reshape_type; - std::string failed_reason; - bool ret = transformer::ExpandDimension::GenerateReshapeTypeByMask(ge::FORMAT_ND, 2, 0, reshape_type, failed_reason); - EXPECT_EQ(ret, true); - EXPECT_EQ(reshape_type.empty(), true); - - ret = transformer::ExpandDimension::GenerateReshapeTypeByMask(ge::FORMAT_ND, 2, 1, reshape_type, failed_reason); - EXPECT_EQ(ret, false); - EXPECT_EQ(reshape_type.empty(), true); - - ret = transformer::ExpandDimension::GenerateReshapeTypeByMask(ge::FORMAT_NC1HWC0, 2, 1, reshape_type, failed_reason); - EXPECT_EQ(ret, false); - EXPECT_EQ(reshape_type.empty(), true); - - int64_t reshape_mask = 3; - ret = transformer::ExpandDimension::GenerateReshapeTypeByMask(ge::FORMAT_NHWC, 2, reshape_mask, reshape_type, failed_reason); - EXPECT_EQ(ret, false); - EXPECT_EQ(reshape_type.empty(), true); - - reshape_mask = 3; - ret = transformer::ExpandDimension::GenerateReshapeTypeByMask(ge::FORMAT_NHWC, 3, reshape_mask, reshape_type, failed_reason); - EXPECT_EQ(ret, false); - EXPECT_EQ(reshape_type.empty(), true); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/transformer_transfer_shape_ut.cc b/tests/ut/graph/testcase/transformer_transfer_shape_ut.cc deleted file mode 100644 index a2502b5512adf135e5cf52a2c955888a9f59a10c..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/transformer_transfer_shape_ut.cc +++ /dev/null @@ -1,1076 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/ge_tensor.h" -#include "graph/utils/attr_utils.h" -#include "graph/debug/ge_attr_define.h" -#include "axis_util.h" -#include "expand_dimension.h" -#include "transfer_shape_according_to_format.h" -#include "transfer_range_according_to_format.h" -#include "platform/platform_info.h" -#include "transfer_def.h" -#include "transfer_shape_utils.h" - -using namespace ge; - -namespace transformer { -class TransformerTransferShapeUT : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} - - void EXPECT_RunTransferShape(const ge::Format &origin_format, const ge::Format &format, const ge::DataType &dtype, - const bool &expect_ret, const vector &dims, const vector &expect_dim, - bool only_test_first_interface = false) { - std::cout << "EXPECT_RunTransferShape: origin_format=" << origin_format << ", format=" << format << ", dtype=" << dtype - << ", dim size=" << dims.size() << std::endl; - ge::GeShape shape(dims); - ShapeAndFormat shape_and_format_info {shape, origin_format, format, dtype}; - ShapeTransferAccordingToFormat shape_transfer; - bool ret = shape_transfer.GetShapeAccordingToFormat(shape_and_format_info); - EXPECT_EQ(ret, expect_ret); - if (ret) { - EXPECT_EQ(shape.GetDims(), expect_dim); - } - if (only_test_first_interface) { - return; - } - gert::Shape current_shape; - for (const int64_t &d : dims) { - current_shape.AppendDim(d); - } - gert::Shape ret_shape; - ret = shape_transfer.TransferShape(origin_format, format, dtype, current_shape, ret_shape); - if (ret && dims != expect_dim) { - vector new_dim; - for (size_t i = 0; i < ret_shape.GetDimNum(); ++i) { - new_dim.push_back(ret_shape.GetDim(i)); - } - EXPECT_EQ(new_dim, expect_dim); - } - - ret = shape_transfer.TransferShape(origin_format, format, dtype, current_shape); - EXPECT_EQ(ret, expect_ret); - if (ret) { - vector new_dim; - for (size_t i = 0; i < current_shape.GetDimNum(); ++i) { - new_dim.push_back(current_shape.GetDim(i)); - } - EXPECT_EQ(new_dim, expect_dim); - } - ExtAxisValue ext_axis; - shape_transfer.InitExtAxisValue(nullptr, ext_axis); - ge::GeShape src_shape(dims); - ge::GeShape dst_shape; - ret = shape_transfer.TransferShape(origin_format, format, dtype, ext_axis, src_shape, dst_shape); - EXPECT_EQ(ret, expect_ret); - if (ret && dims != expect_dim) { - EXPECT_EQ(dst_shape.GetDims(), expect_dim); - } - - ret = shape_transfer.TransferShape(origin_format, format, dtype, ext_axis, src_shape); - EXPECT_EQ(ret, expect_ret); - if (ret) { - EXPECT_EQ(src_shape.GetDims(), expect_dim); - } - } - - void EXPECT_RunTransferShape(const ge::OpDescPtr &op_desc, const ge::Format &origin_format, const ge::Format &format, - const ge::DataType &dtype, const bool &expect_ret, const vector &dims, - const vector &expect_dim, bool only_test_first_interface = false, - const int64_t &m0_val = 16) { - std::cout << "EXPECT_RunTransferShape: origin_format=" << origin_format << ", format=" << format << ", dtype=" << dtype - << ", dim size=" << dims.size() << std::endl; - ge::GeShape shape(dims); - ShapeAndFormat shape_and_format_info {shape, origin_format, format, dtype}; - ShapeTransferAccordingToFormat shape_transfer; - bool ret = shape_transfer.GetShapeAccordingToFormat(op_desc, shape_and_format_info); - EXPECT_EQ(ret, expect_ret); - if (ret) { - EXPECT_EQ(shape.GetDims(), expect_dim); - } - if (only_test_first_interface) { - return; - } - - gert::Shape current_shape; - for (const int64_t &d : dims) { - current_shape.AppendDim(d); - } - - gert::Shape ret_shape; - ret = shape_transfer.TransferShape(origin_format, format, dtype, current_shape, ret_shape, op_desc); - if (ret && dims != expect_dim) { - vector new_dim; - for (size_t i = 0; i < ret_shape.GetDimNum(); ++i) { - new_dim.push_back(ret_shape.GetDim(i)); - } - EXPECT_EQ(new_dim, expect_dim); - } - - ret = shape_transfer.TransferShape(origin_format, format, dtype, current_shape, op_desc); - EXPECT_EQ(ret, expect_ret); - if (ret) { - vector new_dim; - for (size_t i = 0; i < current_shape.GetDimNum(); ++i) { - new_dim.push_back(current_shape.GetDim(i)); - } - EXPECT_EQ(new_dim, expect_dim); - } - ExtAxisValue ext_axis; - shape_transfer.InitExtAxisValue(op_desc, ext_axis); - ext_axis[3] = m0_val; - ge::GeShape src_shape(dims); - ge::GeShape dst_shape; - ret = shape_transfer.TransferShape(origin_format, format, dtype, ext_axis, src_shape, dst_shape); - EXPECT_EQ(ret, expect_ret); - if (ret && dims != expect_dim) { - EXPECT_EQ(dst_shape.GetDims(), expect_dim); - } - - ret = shape_transfer.TransferShape(origin_format, format, dtype, ext_axis, src_shape); - EXPECT_EQ(ret, expect_ret); - if (ret) { - EXPECT_EQ(src_shape.GetDims(), expect_dim); - } - } - - void RunTransferShapeWithExtAxis(const ge::OpDescPtr &op_desc, const ge::Format &origin_format, const ge::Format &format, - const ge::DataType &dtype, const bool &expect_ret, const vector &dims, - const vector &expect_dim, const int64_t &m0_val = 16) { - std::cout << "EXPECT_RunTransferShape: origin_format=" << origin_format << ", format=" << format << ", dtype=" << dtype - << ", dim size=" << dims.size() << ", m0 value=" << m0_val << std::endl; - - ShapeTransferAccordingToFormat shape_transfer; - ExtAxisValue ext_axis; - shape_transfer.InitExtAxisValue(op_desc, ext_axis); - ext_axis[3] = m0_val; - ge::GeShape src_shape(dims); - ge::GeShape dst_shape; - bool ret = shape_transfer.TransferShape(origin_format, format, dtype, ext_axis, src_shape, dst_shape); - EXPECT_EQ(ret, expect_ret); - if (ret && dims != expect_dim) { - EXPECT_EQ(dst_shape.GetDims(), expect_dim); - } - - ret = shape_transfer.TransferShape(origin_format, format, dtype, ext_axis, src_shape); - EXPECT_EQ(ret, expect_ret); - if (ret) { - EXPECT_EQ(src_shape.GetDims(), expect_dim); - } - } -}; - -TEST_F(TransformerTransferShapeUT, transfer_shape_verify_param) { - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NHWC, DT_FLOAT16, true, {}, {}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_ND, DT_INT8, true, {3, 4, 5, 6}, {3, 4, 5, 6}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_NHWC, DT_FLOAT, true, {3, 4, 5, 6}, {3, 4, 5, 6}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_NCDHW, DT_INT32, true, {3, 4, 5, 6}, {3, 4, 5, 6}); - - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NHWC, DT_UNDEFINED, false, {3, 4, 5, 6}, {3, 4, 5, 6}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NHWC, DT_MAX, false, {3, 4, 5, 6}, {3, 4, 5, 6}); - EXPECT_RunTransferShape(ge::FORMAT_RESERVED, ge::FORMAT_NHWC, DT_FLOAT16, false, {3, 4, 5, 6}, {3, 4, 5, 6}); - EXPECT_RunTransferShape(ge::FORMAT_END, ge::FORMAT_NHWC, DT_FLOAT16, false, {3, 4, 5, 6}, {3, 4, 5, 6}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_RESERVED, DT_FLOAT16, false, {3, 4, 5, 6}, {3, 4, 5, 6}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_END, DT_FLOAT16, false, {3, 4, 5, 6}, {3, 4, 5, 6}); -} - -TEST_F(TransformerTransferShapeUT, transfer_shape_from_nchw) { - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NCHW, DT_FLOAT16, true, {}, {}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NHWC, DT_FLOAT16, true, {}, {}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NHWC, DT_FLOAT, true, {5}, {5}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NCHW, DT_INT64, true, {5, 6}, {5, 6}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NCHW, DT_UINT8, true, {5, 6, 7}, {5, 6, 7}); - - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NCHW, DT_UINT8, true, {5, 6, 7, 8}, {5, 6, 7, 8}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NHWC, DT_INT8, true, {5, 6, 7, 8}, {5, 7, 8, 6}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_HWCN, DT_UINT16, true, {5, 6, 7, 8}, {7, 8, 6, 5}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_CHWN, DT_INT16, true, {5, 6, 7, 8}, {6, 7, 8, 5}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NCHW, DT_UINT32, true, {5, 6, 7, 8}, {5, 6, 7, 8}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NHWC, DT_INT32, true, {5, 6, 7, 8}, {5, 7, 8, 6}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_HWCN, DT_FLOAT, true, {5, 6, 7, 8}, {7, 8, 6, 5}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_CHWN, DT_FLOAT16, true, {5, 6, 7, 8}, {6, 7, 8, 5}); - - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, DT_UINT8, true, {8, 512, 5, 5}, {8, 16, 5, 5, 32}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, DT_INT8, true, {8, 512, 5, 5}, {8, 16, 5, 5, 32}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, DT_UINT16, true, {8, 512, 5, 5}, {8, 32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, DT_INT16, true, {8, 512, 5, 5}, {8, 32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, DT_UINT32, true, {8, 512, 5, 5}, {8, 32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, DT_INT32, true, {8, 512, 5, 5}, {8, 32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, DT_FLOAT, true, {8, 512, 5, 5}, {8, 32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, DT_FLOAT16, true, {8, 512, 5, 5}, {8, 32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, DT_UINT1, true, {8, 512, 5, 5}, {8, 2, 5, 5, 256}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, DT_UINT2, true, {8, 512, 5, 5}, {8, 4, 5, 5, 128}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, DT_INT2, true, {8, 512, 5, 5}, {8, 4, 5, 5, 128}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, DT_INT4, true, {8, 512, 5, 5}, {8, 8, 5, 5, 64}); - - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_C1HWC0, DT_UINT8, true, {512, 1, 5, 5}, {16, 5, 5, 32}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_C1HWC0, DT_INT8, true, {512, 1, 5, 5}, {16, 5, 5, 32}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_C1HWC0, DT_UINT16, true, {512, 1, 5, 5}, {32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_C1HWC0, DT_INT16, true, {512, 1, 5, 5}, {32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_C1HWC0, DT_UINT32, true, {512, 1, 5, 5}, {32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_C1HWC0, DT_INT32, true, {512, 1, 5, 5}, {32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_C1HWC0, DT_FLOAT, true, {512, 1, 5, 5}, {32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_C1HWC0, DT_FLOAT16, true, {512, 1, 5, 5}, {32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_C1HWC0, DT_UINT1, true, {512, 1, 5, 5}, {2, 5, 5, 256}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_C1HWC0, DT_UINT2, true, {512, 1, 5, 5}, {4, 5, 5, 128}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_C1HWC0, DT_INT2, true, {512, 1, 5, 5}, {4, 5, 5, 128}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_C1HWC0, DT_INT4, true, {512, 1, 5, 5}, {8, 5, 5, 64}); - - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_C1HWNCoC0, DT_UINT8, true, {18, 512, 5, 5}, {16, 5, 5, 2, 32, 32}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_C1HWNCoC0, DT_INT8, true, {18, 512, 5, 5}, {16, 5, 5, 2, 32, 32}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_C1HWNCoC0, DT_UINT16, true, {18, 512, 5, 5}, {32, 5, 5, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_C1HWNCoC0, DT_INT16, true, {18, 512, 5, 5}, {32, 5, 5, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_C1HWNCoC0, DT_UINT32, true, {18, 512, 5, 5}, {32, 5, 5, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_C1HWNCoC0, DT_INT32, true, {18, 512, 5, 5}, {32, 5, 5, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_C1HWNCoC0, DT_FLOAT, true, {18, 512, 5, 5}, {32, 5, 5, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_C1HWNCoC0, DT_FLOAT16, true, {18, 512, 5, 5}, {32, 5, 5, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_C1HWNCoC0, DT_UINT1, true, {18, 512, 5, 5}, {2, 5, 5, 2, 256, 256}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_C1HWNCoC0, DT_UINT2, true, {18, 512, 5, 5}, {4, 5, 5, 2, 128, 128}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_C1HWNCoC0, DT_INT2, true, {18, 512, 5, 5}, {4, 5, 5, 2, 128, 128}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_C1HWNCoC0, DT_INT4, true, {18, 512, 5, 5}, {8, 5, 5, 2, 64, 64}); - - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0_C04, DT_UINT8, true, {8, 512, 5, 5}, {8, 128, 5, 5, 4}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0_C04, DT_INT8, true, {8, 512, 5, 5}, {8, 128, 5, 5, 4}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0_C04, DT_UINT16, true, {8, 512, 5, 5}, {8, 128, 5, 5, 4}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0_C04, DT_INT16, true, {8, 512, 5, 5}, {8, 128, 5, 5, 4}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0_C04, DT_UINT32, true, {8, 512, 5, 5}, {8, 128, 5, 5, 4}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0_C04, DT_INT32, true, {8, 512, 5, 5}, {8, 128, 5, 5, 4}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0_C04, DT_FLOAT, true, {8, 512, 5, 5}, {8, 128, 5, 5, 4}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0_C04, DT_FLOAT16, true, {8, 512, 5, 5}, {8, 128, 5, 5, 4}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0_C04, DT_UINT1, true, {8, 512, 5, 5}, {8, 128, 5, 5, 4}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0_C04, DT_UINT2, true, {8, 512, 5, 5}, {8, 128, 5, 5, 4}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0_C04, DT_INT2, true, {8, 512, 5, 5}, {8, 128, 5, 5, 4}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0_C04, DT_INT4, true, {8, 512, 5, 5}, {8, 128, 5, 5, 4}); - - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z, DT_UINT8, true, {48, 512, 5, 5}, {400, 3, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z, DT_INT8, true, {48, 512, 5, 5}, {400, 3, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z, DT_UINT16, true, {48, 512, 5, 5}, {800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z, DT_INT16, true, {48, 512, 5, 5}, {800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z, DT_UINT32, true, {48, 512, 5, 5}, {800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z, DT_INT32, true, {48, 512, 5, 5}, {800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z, DT_FLOAT, true, {48, 512, 5, 5}, {800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z, DT_FLOAT16, true, {48, 512, 5, 5}, {800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z, DT_UINT1, true, {48, 512, 5, 5}, {50, 3, 16, 256}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z, DT_UINT2, true, {48, 512, 5, 5}, {100, 3, 16, 128}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z, DT_INT2, true, {48, 512, 5, 5}, {100, 3, 16, 128}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z, DT_INT4, true, {48, 512, 5, 5}, {200, 3, 16, 64}); - - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z_C04, DT_UINT8, true, {48, 3, 5, 5}, {4, 3, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z_C04, DT_INT8, true, {48, 3, 5, 5}, {4, 3, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z_C04, DT_UINT16, true, {48, 3, 5, 5}, {7, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z_C04, DT_INT16, true, {48, 3, 5, 5}, {7, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z_C04, DT_UINT32, true, {48, 3, 5, 5}, {7, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z_C04, DT_INT32, true, {48, 3, 5, 5}, {7, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z_C04, DT_FLOAT, true, {48, 3, 5, 5}, {7, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z_C04, DT_FLOAT16, true, {48, 3, 5, 5}, {7, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z_C04, DT_UINT1, true, {48, 3, 5, 5}, {1, 3, 16, 256}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z_C04, DT_UINT2, true, {48, 3, 5, 5}, {1, 3, 16, 128}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z_C04, DT_INT2, true, {48, 3, 5, 5}, {1, 3, 16, 128}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z_C04, DT_INT4, true, {48, 3, 5, 5}, {2, 3, 16, 64}); - - int32_t group = 16; - ge::Format target_format = static_cast(GetFormatFromSub(static_cast(ge::FORMAT_FRACTAL_Z), group)); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, target_format, DT_UINT8, true, {48, 512, 5, 5}, {6400, 3, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, target_format, DT_INT8, true, {48, 512, 5, 5}, {6400, 3, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, target_format, DT_UINT16, true, {48, 512, 5, 5}, {12800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, target_format, DT_INT16, true, {48, 512, 5, 5}, {12800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, target_format, DT_UINT32, true, {48, 512, 5, 5}, {12800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, target_format, DT_INT32, true, {48, 512, 5, 5}, {12800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, target_format, DT_FLOAT, true, {48, 512, 5, 5}, {12800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, target_format, DT_FLOAT16, true, {48, 512, 5, 5}, {12800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, target_format, DT_UINT1, true, {48, 512, 5, 5}, {800, 3, 16, 256}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, target_format, DT_UINT2, true, {48, 512, 5, 5}, {1600, 3, 16, 128}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, target_format, DT_INT2, true, {48, 512, 5, 5}, {1600, 3, 16, 128}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, target_format, DT_INT4, true, {48, 512, 5, 5}, {3200, 3, 16, 64}); -} - -TEST_F(TransformerTransferShapeUT, transfer_shape_from_nchw_unknow_shape) { - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, DT_INT16, true, {-1, 512, 5, 5}, {-1, 32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, DT_INT16, true, {-1, 512, -1, 5}, {-1, 32, -1, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0, DT_INT16, true, {8, -1, 5, 5}, {8, -1, 5, 5, 16}); - - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0_C04, DT_INT16, true, {-1, 33, 5, 5}, {-1, 9, 5, 5, 4}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0_C04, DT_INT16, true, {-1, 33, -1, 5}, {-1, 9, -1, 5, 4}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NC1HWC0_C04, DT_INT16, true, {8, -1, 5, 5}, {8, -1, 5, 5, 4}); - - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z, DT_FLOAT16, true, {-1, 33, 5, 5}, {75, -1, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z, DT_FLOAT16, true, {48, -1, 5, 5}, {-1, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z, DT_FLOAT16, true, {48, -1, -1, 5}, {-1, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z, DT_FLOAT16, true, {48, -1, 5, -1}, {-1, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z, DT_FLOAT16, true, {48, 512, -1, 5}, {-1, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z, DT_FLOAT16, true, {48, 512, 5, -1}, {-1, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z, DT_FLOAT16, true, {48, 512, -1, -1}, {-1, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z, DT_FLOAT16, true, {48, -1, -1, -1}, {-1, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z, DT_FLOAT16, true, {-1, -1, -1, -1}, {-1, -1, 16, 16}); - - int32_t group = 16; - ge::Format target_format = static_cast(GetFormatFromSub(static_cast(ge::FORMAT_FRACTAL_Z), group)); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, target_format, DT_FLOAT16, true, {48, 512, -1, 5}, {-1, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, target_format, DT_FLOAT16, true, {-1, 512, 5, 5}, {800, -1, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, target_format, DT_FLOAT16, true, {48, -1, 5, 5}, {-1, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, target_format, DT_FLOAT16, true, {-1, -1, 5, 5}, {-1, -1, 16, 16}); - - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z_C04, DT_FLOAT16, true, {-1, 3, 5, 5}, {7, -1, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z_C04, DT_FLOAT16, true, {48, -1, 5, 5}, {7, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z_C04, DT_FLOAT16, true, {48, -1, -1, 5}, {-1, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z_C04, DT_FLOAT16, true, {48, -1, 5, -1}, {-1, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z_C04, DT_FLOAT16, true, {48, 3, -1, 5}, {-1, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z_C04, DT_FLOAT16, true, {48, 3, 5, -1}, {-1, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z_C04, DT_FLOAT16, true, {48, 3, -1, -1}, {-1, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z_C04, DT_FLOAT16, true, {48, -1, -1, -1}, {-1, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_Z_C04, DT_FLOAT16, true, {-1, -1, -1, -1}, {-1, -1, 16, 16}); -} - -TEST_F(TransformerTransferShapeUT, transfer_shape_from_hwcn) { - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NCHW, DT_FLOAT16, true, {}, {}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_HWCN, DT_FLOAT16, true, {}, {}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NHWC, DT_FLOAT, true, {5}, {5}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NCHW, DT_INT64, true, {5, 6}, {5, 6}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NCHW, DT_UINT8, true, {5, 6, 7}, {5, 6, 7}); - - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NCHW, DT_UINT8, true, {7, 8, 6, 5}, {5, 6, 7, 8}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NHWC, DT_INT8, true, {7, 8, 6, 5}, {5, 7, 8, 6}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_HWCN, DT_UINT16, true, {7, 8, 6, 5}, {7, 8, 6, 5}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_CHWN, DT_INT16, true, {7, 8, 6, 5}, {6, 7, 8, 5}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NCHW, DT_UINT32, true, {7, 8, 6, 5}, {5, 6, 7, 8}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NHWC, DT_INT32, true, {7, 8, 6, 5}, {5, 7, 8, 6}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_HWCN, DT_FLOAT, true, {7, 8, 6, 5}, {7, 8, 6, 5}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_CHWN, DT_FLOAT16, true, {7, 8, 6, 5}, {6, 7, 8, 5}); - - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, DT_UINT8, true, {5, 5, 512, 8}, {8, 16, 5, 5, 32}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, DT_INT8, true, {5, 5, 512, 8}, {8, 16, 5, 5, 32}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, DT_UINT16, true, {5, 5, 512, 8}, {8, 32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, DT_INT16, true, {5, 5, 512, 8}, {8, 32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, DT_UINT32, true, {5, 5, 512, 8}, {8, 32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, DT_INT32, true, {5, 5, 512, 8}, {8, 32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, DT_FLOAT, true, {5, 5, 512, 8}, {8, 32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, DT_FLOAT16, true, {5, 5, 512, 8}, {8, 32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, DT_UINT1, true, {5, 5, 512, 8}, {8, 2, 5, 5, 256}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, DT_UINT2, true, {5, 5, 512, 8}, {8, 4, 5, 5, 128}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, DT_INT2, true, {5, 5, 512, 8}, {8, 4, 5, 5, 128}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0, DT_INT4, true, {5, 5, 512, 8}, {8, 8, 5, 5, 64}); - - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_C1HWC0, DT_UINT8, true, {5, 5, 1, 512}, {16, 5, 5, 32}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_C1HWC0, DT_INT8, true, {5, 5, 1, 512}, {16, 5, 5, 32}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_C1HWC0, DT_UINT16, true, {5, 5, 1, 512}, {32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_C1HWC0, DT_INT16, true, {5, 5, 1, 512}, {32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_C1HWC0, DT_UINT32, true, {5, 5, 1, 512}, {32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_C1HWC0, DT_INT32, true, {5, 5, 1, 512}, {32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_C1HWC0, DT_FLOAT, true, {5, 5, 1, 512}, {32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_C1HWC0, DT_FLOAT16, true, {5, 5, 1, 512}, {32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_C1HWC0, DT_UINT1, true, {5, 5, 1, 512}, {2, 5, 5, 256}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_C1HWC0, DT_UINT2, true, {5, 5, 1, 512}, {4, 5, 5, 128}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_C1HWC0, DT_INT2, true, {5, 5, 1, 512}, {4, 5, 5, 128}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_C1HWC0, DT_INT4, true, {5, 5, 1, 512}, {8, 5, 5, 64}); - - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_C1HWNCoC0, DT_UINT8, true, {5, 5, 512, 18}, {16, 5, 5, 2, 32, 32}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_C1HWNCoC0, DT_INT8, true, {5, 5, 512, 18}, {16, 5, 5, 2, 32, 32}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_C1HWNCoC0, DT_UINT16, true, {5, 5, 512, 18}, {32, 5, 5, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_C1HWNCoC0, DT_INT16, true, {5, 5, 512, 18}, {32, 5, 5, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_C1HWNCoC0, DT_UINT32, true, {5, 5, 512, 18}, {32, 5, 5, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_C1HWNCoC0, DT_INT32, true, {5, 5, 512, 18}, {32, 5, 5, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_C1HWNCoC0, DT_FLOAT, true, {5, 5, 512, 18}, {32, 5, 5, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_C1HWNCoC0, DT_FLOAT16, true, {5, 5, 512, 18}, {32, 5, 5, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_C1HWNCoC0, DT_UINT1, true, {5, 5, 512, 18}, {2, 5, 5, 2, 256, 256}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_C1HWNCoC0, DT_UINT2, true, {5, 5, 512, 18}, {4, 5, 5, 2, 128, 128}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_C1HWNCoC0, DT_INT2, true, {5, 5, 512, 18}, {4, 5, 5, 2, 128, 128}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_C1HWNCoC0, DT_INT4, true, {5, 5, 512, 18}, {8, 5, 5, 2, 64, 64}); - - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0_C04, DT_UINT8, true, {5, 5, 512, 8}, {8, 128, 5, 5, 4}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0_C04, DT_INT8, true, {5, 5, 512, 8}, {8, 128, 5, 5, 4}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0_C04, DT_UINT16, true, {5, 5, 512, 8}, {8, 128, 5, 5, 4}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0_C04, DT_INT16, true, {5, 5, 512, 8}, {8, 128, 5, 5, 4}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0_C04, DT_UINT32, true, {5, 5, 512, 8}, {8, 128, 5, 5, 4}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0_C04, DT_INT32, true, {5, 5, 512, 8}, {8, 128, 5, 5, 4}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0_C04, DT_FLOAT, true, {5, 5, 512, 8}, {8, 128, 5, 5, 4}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0_C04, DT_FLOAT16, true, {5, 5, 512, 8}, {8, 128, 5, 5, 4}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0_C04, DT_UINT1, true, {5, 5, 512, 8}, {8, 128, 5, 5, 4}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0_C04, DT_UINT2, true, {5, 5, 512, 8}, {8, 128, 5, 5, 4}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0_C04, DT_INT2, true, {5, 5, 512, 8}, {8, 128, 5, 5, 4}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NC1HWC0_C04, DT_INT4, true, {5, 5, 512, 8}, {8, 128, 5, 5, 4}); - - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z, DT_UINT8, true, {5, 5, 512, 48}, {400, 3, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z, DT_INT8, true, {5, 5, 512, 48}, {400, 3, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z, DT_UINT16, true, {5, 5, 512, 48}, {800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z, DT_INT16, true, {5, 5, 512, 48}, {800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z, DT_UINT32, true, {5, 5, 512, 48}, {800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z, DT_INT32, true, {5, 5, 512, 48}, {800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z, DT_FLOAT, true, {5, 5, 512, 48}, {800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z, DT_FLOAT16, true, {5, 5, 512, 48}, {800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z, DT_UINT1, true, {5, 5, 512, 48}, {50, 3, 16, 256}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z, DT_UINT2, true, {5, 5, 512, 48}, {100, 3, 16, 128}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z, DT_INT2, true, {5, 5, 512, 48}, {100, 3, 16, 128}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_FRACTAL_Z, DT_INT4, true, {5, 5, 512, 48}, {200, 3, 16, 64}); - - int32_t group = 16; - ge::Format target_format = static_cast(GetFormatFromSub(static_cast(ge::FORMAT_FRACTAL_Z), group)); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, target_format, DT_UINT8, true, {5, 5, 512, 48}, {6400, 3, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, target_format, DT_INT8, true, {5, 5, 512, 48}, {6400, 3, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, target_format, DT_UINT16, true, {5, 5, 512, 48}, {12800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, target_format, DT_INT16, true, {5, 5, 512, 48}, {12800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, target_format, DT_UINT32, true, {5, 5, 512, 48}, {12800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, target_format, DT_INT32, true, {5, 5, 512, 48}, {12800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, target_format, DT_FLOAT, true, {5, 5, 512, 48}, {12800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, target_format, DT_FLOAT16, true, {5, 5, 512, 48}, {12800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, target_format, DT_UINT1, true, {5, 5, 512, 48}, {800, 3, 16, 256}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, target_format, DT_UINT2, true, {5, 5, 512, 48}, {1600, 3, 16, 128}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, target_format, DT_INT2, true, {5, 5, 512, 48}, {1600, 3, 16, 128}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, target_format, DT_INT4, true, {5, 5, 512, 48}, {3200, 3, 16, 64}); -} - -TEST_F(TransformerTransferShapeUT, transfer_shape_from_ncdhw) { - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, DT_FLOAT16, true, {}, {}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, DT_FLOAT16, true, {}, {}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, DT_FLOAT, true, {5}, {5}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, DT_INT64, true, {5, 6}, {5, 6}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, DT_UINT8, true, {5, 6, 7}, {5, 6, 7}); - - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, DT_UINT8, true, {8, 512, 9, 5, 5}, {8, 9, 16, 5, 5, 32}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, DT_INT8, true, {8, 512, 9, 5, 5}, {8, 9, 16, 5, 5, 32}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, DT_UINT16, true, {8, 512, 9, 5, 5}, {8, 9, 32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, DT_INT16, true, {8, 512, 9, 5, 5}, {8, 9, 32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, DT_UINT32, true, {8, 512, 9, 5, 5}, {8, 9, 32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, DT_INT32, true, {8, 512, 9, 5, 5}, {8, 9, 32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, DT_FLOAT, true, {8, 512, 9, 5, 5}, {8, 9, 32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, DT_FLOAT16, true, {8, 512, 9, 5, 5}, {8, 9, 32, 5, 5, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, DT_UINT1, true, {8, 512, 9, 5, 5}, {8, 9, 2, 5, 5, 256}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, DT_UINT2, true, {8, 512, 9, 5, 5}, {8, 9, 4, 5, 5, 128}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, DT_INT2, true, {8, 512, 9, 5, 5}, {8, 9, 4, 5, 5, 128}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_NDC1HWC0, DT_INT4, true, {8, 512, 9, 5, 5}, {8, 9, 8, 5, 5, 64}); - - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_NDHWC, DT_FLOAT16, true, {8, 512, 9, 5, 15}, {8, 9, 5, 15, 512}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_DHWCN, DT_FLOAT16, true, {8, 512, 9, 5, 15}, {9, 5, 15, 512, 8}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_DHWNC, DT_FLOAT16, true, {8, 512, 9, 5, 15}, {9, 5, 15, 8, 512}); - - EXPECT_RunTransferShape(ge::FORMAT_NDHWC, ge::FORMAT_NCDHW, DT_FLOAT16, true, {8, 512, 9, 5, 15}, {8, 15, 512, 9, 5}); - EXPECT_RunTransferShape(ge::FORMAT_NDHWC, ge::FORMAT_DHWCN, DT_FLOAT16, true, {8, 512, 9, 5, 15}, {512, 9, 5, 15, 8}); - EXPECT_RunTransferShape(ge::FORMAT_NDHWC, ge::FORMAT_DHWNC, DT_FLOAT16, true, {8, 512, 9, 5, 15}, {512, 9, 5, 8, 15}); - - EXPECT_RunTransferShape(ge::FORMAT_DHWCN, ge::FORMAT_NCDHW, DT_FLOAT16, true, {8, 512, 9, 5, 15}, {15, 5, 8, 512, 9}); - EXPECT_RunTransferShape(ge::FORMAT_DHWCN, ge::FORMAT_NDHWC, DT_FLOAT16, true, {8, 512, 9, 5, 15}, {15, 8, 512, 9, 5}); - EXPECT_RunTransferShape(ge::FORMAT_DHWCN, ge::FORMAT_DHWNC, DT_FLOAT16, true, {8, 512, 9, 5, 15}, {8, 512, 9, 15, 5}); - - EXPECT_RunTransferShape(ge::FORMAT_DHWNC, ge::FORMAT_NCDHW, DT_FLOAT16, true, {8, 512, 9, 5, 15}, {5, 15, 8, 512, 9}); - EXPECT_RunTransferShape(ge::FORMAT_DHWNC, ge::FORMAT_NDHWC, DT_FLOAT16, true, {8, 512, 9, 5, 15}, {5, 8, 512, 9, 15}); - EXPECT_RunTransferShape(ge::FORMAT_DHWNC, ge::FORMAT_DHWCN, DT_FLOAT16, true, {8, 512, 9, 5, 15}, {8, 512, 9, 15, 5}); - - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_FRACTAL_Z_3D, DT_UINT8, true, {48, 512, 3, 5, 5}, {1200, 3, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_FRACTAL_Z_3D, DT_INT8, true, {48, 512, 3, 5, 5}, {1200, 3, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_FRACTAL_Z_3D, DT_UINT16, true, {48, 512, 3, 5, 5}, {2400, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_FRACTAL_Z_3D, DT_INT16, true, {48, 512, 3, 5, 5}, {2400, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_FRACTAL_Z_3D, DT_UINT32, true, {48, 512, 3, 5, 5}, {2400, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_FRACTAL_Z_3D, DT_INT32, true, {48, 512, 3, 5, 5}, {2400, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_FRACTAL_Z_3D, DT_FLOAT, true, {48, 512, 3, 5, 5}, {2400, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_FRACTAL_Z_3D, DT_FLOAT16, true, {48, 512, 3, 5, 5}, {2400, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_FRACTAL_Z_3D, DT_UINT1, true, {48, 512, 3, 5, 5}, {150, 3, 16, 256}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_FRACTAL_Z_3D, DT_UINT2, true, {48, 512, 3, 5, 5}, {300, 3, 16, 128}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_FRACTAL_Z_3D, DT_INT2, true, {48, 512, 3, 5, 5}, {300, 3, 16, 128}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_FRACTAL_Z_3D, DT_INT4, true, {48, 512, 3, 5, 5}, {600, 3, 16, 64}); - - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_FRACTAL_Z_3D_TRANSPOSE, DT_UINT8, true, {90, 512, 3, 5, 5}, {450, 16, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_FRACTAL_Z_3D_TRANSPOSE, DT_INT8, true, {90, 512, 3, 5, 5}, {450, 16, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_FRACTAL_Z_3D_TRANSPOSE, DT_UINT16, true, {90, 512, 3, 5, 5}, {450, 32, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_FRACTAL_Z_3D_TRANSPOSE, DT_INT16, true, {90, 512, 3, 5, 5}, {450, 32, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_FRACTAL_Z_3D_TRANSPOSE, DT_UINT32, true, {90, 512, 3, 5, 5}, {450, 32, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_FRACTAL_Z_3D_TRANSPOSE, DT_INT32, true, {90, 512, 3, 5, 5}, {450, 32, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_FRACTAL_Z_3D_TRANSPOSE, DT_FLOAT, true, {90, 512, 3, 5, 5}, {450, 32, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_FRACTAL_Z_3D_TRANSPOSE, DT_FLOAT16, true, {90, 512, 3, 5, 5}, {450, 32, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_FRACTAL_Z_3D_TRANSPOSE, DT_UINT1, true, {90, 512, 3, 5, 5}, {450, 2, 16, 256}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_FRACTAL_Z_3D_TRANSPOSE, DT_UINT2, true, {90, 512, 3, 5, 5}, {450, 4, 16, 128}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_FRACTAL_Z_3D_TRANSPOSE, DT_INT2, true, {90, 512, 3, 5, 5}, {450, 4, 16, 128}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, ge::FORMAT_FRACTAL_Z_3D_TRANSPOSE, DT_INT4, true, {90, 512, 3, 5, 5}, {450, 8, 16, 64}); - - int32_t group = 16; - ge::Format target_format = static_cast(GetFormatFromSub(static_cast(ge::FORMAT_FRACTAL_Z_3D), group)); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, target_format, DT_UINT8, true, {48, 512, 3, 5, 5}, {19200, 3, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, target_format, DT_INT8, true, {48, 512, 3, 5, 5}, {19200, 3, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, target_format, DT_UINT16, true, {48, 512, 3, 5, 5}, {38400, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, target_format, DT_INT16, true, {48, 512, 3, 5, 5}, {38400, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, target_format, DT_UINT32, true, {48, 512, 3, 5, 5}, {38400, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, target_format, DT_INT32, true, {48, 512, 3, 5, 5}, {38400, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, target_format, DT_FLOAT, true, {48, 512, 3, 5, 5}, {38400, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, target_format, DT_FLOAT16, true, {48, 512, 3, 5, 5}, {38400, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, target_format, DT_UINT1, true, {48, 512, 3, 5, 5}, {2400, 3, 16, 256}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, target_format, DT_UINT2, true, {48, 512, 3, 5, 5}, {4800, 3, 16, 128}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, target_format, DT_INT2, true, {48, 512, 3, 5, 5}, {4800, 3, 16, 128}); - EXPECT_RunTransferShape(ge::FORMAT_NCDHW, target_format, DT_INT4, true, {48, 512, 3, 5, 5}, {9600, 3, 16, 64}); -} - -TEST_F(TransformerTransferShapeUT, transfer_shape_from_4d_to_6hd) { - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_NDC1HWC0, DT_FLOAT16, true, {4, 33, 7, 7}, {4, 1, 3, 7, 7, 16}); - EXPECT_RunTransferShape(ge::FORMAT_NHWC, ge::FORMAT_NDC1HWC0, DT_FLOAT16, true, {4, 7, 7, 33}, {4, 1, 3, 7, 7, 16}); - EXPECT_RunTransferShape(ge::FORMAT_HWCN, ge::FORMAT_NDC1HWC0, DT_FLOAT16, true, {7, 7, 33, 4}, {4, 1, 3, 7, 7, 16}); -} - -TEST_F(TransformerTransferShapeUT, transfer_shape_from_nd_to_nz) { - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_FLOAT16, true, {34}, {1, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_FLOAT16, true, {34, 1}, {1, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_FLOAT, true, {18, 34}, {3, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_UINT8, true, {1, 18, 34}, {1, 2, 2, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_INT8, true, {1, 18, 34}, {1, 2, 2, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_UINT16, true, {1, 18, 34}, {1, 3, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_INT16, true, {1, 18, 34}, {1, 3, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_UINT32, true, {1, 18, 34}, {1, 3, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_INT32, true, {1, 18, 34}, {1, 3, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_FLOAT, true, {1, 18, 34}, {1, 3, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_FLOAT16, true, {1, 18, 34}, {1, 3, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_UINT1, true, {1, 18, 134}, {1, 1, 2, 16, 256}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_UINT2, true, {1, 18, 134}, {1, 2, 2, 16, 128}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_INT2, true, {1, 18, 134}, {1, 2, 2, 16, 128}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_INT4, true, {1, 18, 134}, {1, 3, 2, 16, 64}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_FLOAT16, true, {-2}, {-2}, true); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_NZ, DT_FLOAT16, true, {8, 1000}, {63, 1, 16, 16}); - - transformer::TransferShapeUtils::m0_list_.fill(1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_FLOAT16, true, {34}, {1, 34, 1, 16}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_FLOAT16, true, {34, 1}, {1, 34, 1, 16}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_FLOAT, true, {18, 34}, {3, 18, 1, 16}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_UINT8, true, {1, 18, 34}, {1, 2, 18, 1, 32}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_INT8, true, {1, 18, 34}, {1, 2, 18, 1, 32}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_UINT16, true, {1, 18, 34}, {1, 3, 18, 1, 16}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_INT16, true, {1, 18, 34}, {1, 3, 18, 1, 16}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_UINT32, true, {1, 18, 34}, {1, 3, 18, 1, 16}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_INT32, true, {1, 18, 34}, {1, 3, 18, 1, 16}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_FLOAT, true, {1, 18, 34}, {1, 3, 18, 1, 16}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_FLOAT16, true, {1, 18, 34}, {1, 3, 18, 1, 16}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_UINT1, true, {1, 18, 134}, {1, 1, 18, 1, 256}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_UINT2, true, {1, 18, 134}, {1, 2, 18, 1, 128}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_INT2, true, {1, 18, 134}, {1, 2, 18, 1, 128}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, DT_INT4, true, {1, 18, 134}, {1, 3, 18, 1, 64}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_NZ, DT_FLOAT16, true, {8, 1000}, {63, 8, 1, 16}, 1); - transformer::TransferShapeUtils::m0_list_.fill(16); -} - -TEST_F(TransformerTransferShapeUT, transfer_shape_from_nd_to_nz_C0_16) { - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_FLOAT16, true, {34}, {1, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_FLOAT16, true, {34, 1}, {1, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_FLOAT, true, {18, 34}, {3, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_UINT8, true, {1, 18, 34}, {1, 3, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_INT8, true, {1, 18, 34}, {1, 3, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_UINT16, true, - {1, 18, 34}, {1, 3, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_INT16, true, {1, 18, 34}, {1, 3, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_UINT32, true, - {1, 18, 34}, {1, 3, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_INT32, true, {1, 18, 34}, {1, 3, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_FLOAT, true, {1, 18, 34}, {1, 3, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_FLOAT16, true, - {1, 18, 34}, {1, 3, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_UINT1, true, - {1, 18, 134}, {1, 9, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_UINT2, true, - {1, 18, 134}, {1, 9, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_INT2, true, - {1, 18, 134}, {1, 9, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_INT4, true, {1, 18, 134}, {1, 9, 2, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_FLOAT16, true, {-2}, {-2}, true); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_NZ_C0_16, DT_FLOAT16, true, {8, 1000}, {63, 1, 16, 16}); - - transformer::TransferShapeUtils::m0_list_.fill(1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_FLOAT16, true, - {34}, {1, 34, 1, 16}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_FLOAT16, true, - {34, 1}, {1, 34, 1, 16}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_FLOAT, true, - {18, 34}, {3, 18, 1, 16}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_UINT8, true, - {1, 18, 34}, {1, 3, 18, 1, 16}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_INT8, true, - {1, 18, 34}, {1, 3, 18, 1, 16}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_UINT16, true, - {1, 18, 34}, {1, 3, 18, 1, 16}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_INT16, true, - {1, 18, 34}, {1, 3, 18, 1, 16}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_UINT32, true, - {1, 18, 34}, {1, 3, 18, 1, 16}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_INT32, true, - {1, 18, 34}, {1, 3, 18, 1, 16}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_FLOAT, true, - {1, 18, 34}, {1, 3, 18, 1, 16}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_FLOAT16, true, - {1, 18, 34}, {1, 3, 18, 1, 16}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_UINT1, true, - {1, 18, 134}, {1, 9, 18, 1, 16}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_UINT2, true, - {1, 18, 134}, {1, 9, 18, 1, 16}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_INT2, true, - {1, 18, 134}, {1, 9, 18, 1, 16}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_16, DT_INT4, true, - {1, 18, 134}, {1, 9, 18, 1, 16}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_NZ_C0_16, DT_FLOAT16, true, - {8, 1000}, {63, 8, 1, 16}, 1); - transformer::TransferShapeUtils::m0_list_.fill(16); -} - -TEST_F(TransformerTransferShapeUT, transfer_shape_from_nd_to_nz_C0_32) { - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_FLOAT16, true, {34}, {1, 3, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_FLOAT16, true, {34, 1}, {1, 3, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_FLOAT, true, {18, 34}, {2, 2, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_UINT8, true, {1, 18, 34}, {1, 2, 2, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_INT8, true, {1, 18, 34}, {1, 2, 2, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_UINT16, true, - {1, 18, 34}, {1, 2, 2, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_INT16, true, {1, 18, 34}, {1, 2, 2, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_UINT32, true, - {1, 18, 34}, {1, 2, 2, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_INT32, true, {1, 18, 34}, {1, 2, 2, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_FLOAT, true, {1, 18, 34}, {1, 2, 2, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_FLOAT16, true, - {1, 18, 34}, {1, 2, 2, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_UINT1, true, - {1, 18, 134}, {1, 5, 2, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_UINT2, true, - {1, 18, 134}, {1, 5, 2, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_INT2, true, - {1, 18, 134}, {1, 5, 2, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_INT4, true, {1, 18, 134}, {1, 5, 2, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_FLOAT16, true, {-2}, {-2}, true); - EXPECT_RunTransferShape(ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_NZ_C0_32, DT_FLOAT16, true, {8, 1000}, {32, 1, 16, 32}); - - transformer::TransferShapeUtils::m0_list_.fill(1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_FLOAT16, true, - {34}, {1, 34, 1, 32}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_FLOAT16, true, - {34, 1}, {1, 34, 1, 32}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_FLOAT, true, - {18, 34}, {2, 18, 1, 32}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_UINT8, true, - {1, 18, 34}, {1, 2, 18, 1, 32}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_INT8, true, - {1, 18, 34}, {1, 2, 18, 1, 32}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_UINT16, true, - {1, 18, 34}, {1, 2, 18, 1, 32}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_INT16, true, - {1, 18, 34}, {1, 2, 18, 1, 32}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_UINT32, true, - {1, 18, 34}, {1, 2, 18, 1, 32}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_INT32, true, - {1, 18, 34}, {1, 2, 18, 1, 32}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_FLOAT, true, - {1, 18, 34}, {1, 2, 18, 1, 32}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_FLOAT16, true, - {1, 18, 34}, {1, 2, 18, 1, 32}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_UINT1, true, - {1, 18, 134}, {1, 5, 18, 1, 32}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_UINT2, true, - {1, 18, 134}, {1, 5, 18, 1, 32}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_INT2, true, - {1, 18, 134}, {1, 5, 18, 1, 32}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ_C0_32, DT_INT4, true, - {1, 18, 134}, {1, 5, 18, 1, 32}, 1); - RunTransferShapeWithExtAxis(nullptr, ge::FORMAT_NCHW, ge::FORMAT_FRACTAL_NZ_C0_32, DT_FLOAT16, true, - {8, 1000}, {32, 8, 1, 32}, 1); - transformer::TransferShapeUtils::m0_list_.fill(16); -} - -TEST_F(TransformerTransferShapeUT, transfer_shape_from_nd_to_fz) { - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_Z, DT_UINT8, true, {18, 34}, {1, 3, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_Z, DT_INT8, true, {18, 34}, {1, 3, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_Z, DT_UINT16, true, {18, 34}, {2, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_Z, DT_INT16, true, {18, 34}, {2, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_Z, DT_UINT32, true, {18, 34}, {2, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_Z, DT_INT32, true, {18, 34}, {2, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_Z, DT_FLOAT, true, {18, 34}, {2, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_Z, DT_FLOAT16, true, {18, 34}, {2, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_Z, DT_UINT1, true, {188, 23}, {1, 2, 16, 256}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_Z, DT_UINT2, true, {188, 23}, {2, 2, 16, 128}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_Z, DT_INT2, true, {188, 23}, {2, 2, 16, 128}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_Z, DT_INT4, true, {188, 23}, {3, 2, 16, 64}); - - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_Z, DT_UINT8, true, {48, 512, 5, 5}, {400, 3, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_Z, DT_INT8, true, {48, 512, 5, 5}, {400, 3, 16, 32}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_Z, DT_UINT16, true, {48, 512, 5, 5}, {800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_Z, DT_INT16, true, {48, 512, 5, 5}, {800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_Z, DT_UINT32, true, {48, 512, 5, 5}, {800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_Z, DT_INT32, true, {48, 512, 5, 5}, {800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_Z, DT_FLOAT, true, {48, 512, 5, 5}, {800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_Z, DT_FLOAT16, true, {48, 512, 5, 5}, {800, 3, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_Z, DT_UINT1, true, {48, 512, 5, 5}, {50, 3, 16, 256}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_Z, DT_UINT2, true, {48, 512, 5, 5}, {100, 3, 16, 128}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_Z, DT_INT2, true, {48, 512, 5, 5}, {100, 3, 16, 128}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_Z, DT_INT4, true, {48, 512, 5, 5}, {200, 3, 16, 64}); - - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_LSTM, DT_UINT8, true, {48, 80, 5, 5}, {6, 4, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_LSTM, DT_INT8, true, {48, 80, 5, 5}, {6, 4, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_LSTM, DT_UINT16, true, {48, 80, 5, 5}, {6, 4, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_LSTM, DT_INT16, true, {48, 80, 5, 5}, {6, 4, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_LSTM, DT_UINT32, true, {48, 80, 5, 5}, {6, 4, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_LSTM, DT_INT32, true, {48, 80, 5, 5}, {6, 4, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_LSTM, DT_FLOAT, true, {48, 80, 5, 5}, {6, 4, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_LSTM, DT_FLOAT16, true, {48, 80, 5, 5}, {6, 4, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_LSTM, DT_UINT1, true, {48, 80, 5, 5}, {6, 4, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_LSTM, DT_UINT2, true, {48, 80, 5, 5}, {6, 4, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_LSTM, DT_INT2, true, {48, 80, 5, 5}, {6, 4, 16, 16}); - EXPECT_RunTransferShape(ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_LSTM, DT_INT4, true, {48, 80, 5, 5}, {6, 4, 16, 16}); -} - -TEST_F(TransformerTransferShapeUT, transfer_shape_from_nd_to_zn_rnn) { - ge::OpDescPtr op_desc = std::make_shared("test", "test"); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_FLOAT16, true, {128}, {128}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_FLOAT16, true, {65, 128}, {65, 128}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_UINT1, true, {65, 128}, {65, 128}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_UINT2, true, {65, 128}, {65, 128}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_INT4, true, {65, 128}, {65, 128}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_INT8, true, {65, 128}, {65, 128}); - - (void)ge::AttrUtils::SetInt(op_desc, "input_size", 30); - (void)ge::AttrUtils::SetInt(op_desc, "hidden_size", 40); - (void)ge::AttrUtils::SetInt(op_desc, "state_size", -1); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_FLOAT16, true, {128}, {128}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_FLOAT16, true, {70, 128}, {5, 9, 16, 16}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_UINT1, true, {70, 128}, {5, 3, 16, 256}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_UINT2, true, {70, 128}, {5, 3, 16, 128}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_INT4, true, {70, 128}, {5, 3, 16, 64}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_INT8, true, {70, 128}, {5, 6, 16, 32}); - - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_FLOAT16, true, {9, 40, 128}, - {9, 3, 9, 16, 16}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_UINT1, true, {9, 40, 128}, {9, 3, 3, 16, 256}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_UINT2, true, {9, 40, 128}, {9, 3, 3, 16, 128}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_INT4, true, {9, 40, 128}, {9, 3, 3, 16, 64}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_INT8, true, {9, 40, 128}, {9, 3, 6, 16, 32}); - - (void)ge::AttrUtils::SetInt(op_desc, "state_size", 70); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_FLOAT16, true, {128}, {128}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_FLOAT16, true, {70, 128}, {5, 9, 16, 16}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_UINT1, true, {70, 128}, {5, 3, 16, 256}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_UINT2, true, {70, 128}, {5, 3, 16, 128}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_INT4, true, {70, 128}, {5, 3, 16, 64}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_INT8, true, {70, 128}, {5, 6, 16, 32}); - - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_FLOAT16, true, {9, 100, 128}, - {9, 7, 9, 16, 16}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_UINT1, true, {9, 100, 128}, - {9, 7, 3, 16, 256}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_UINT2, true, {9, 100, 128}, - {9, 7, 3, 16, 128}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_INT4, true, {9, 100, 128}, {9, 7, 3, 16, 64}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_FRACTAL_ZN_RNN, DT_INT8, true, {9, 100, 128}, {9, 7, 6, 16, 32}); -} - -TEST_F(TransformerTransferShapeUT, transfer_shape_from_nd_to_nd_rnn_bias) { - ge::OpDescPtr op_desc = std::make_shared("test", "test"); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_ND_RNN_BIAS, DT_FLOAT16, true, {}, {}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_ND_RNN_BIAS, DT_FLOAT16, true, {150}, {2400}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_ND_RNN_BIAS, DT_FLOAT16, true, {18, 80}, {18, 1280}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_ND_RNN_BIAS, DT_UINT1, true, {18, 80}, {18, 20480}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_ND_RNN_BIAS, DT_UINT2, true, {18, 80}, {18, 10240}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_ND_RNN_BIAS, DT_INT4, true, {18, 80}, {18, 5120}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_ND_RNN_BIAS, DT_INT8, true, {18, 80}, {18, 2560}); - - (void)ge::AttrUtils::SetInt(op_desc, "hidden_size", 64); - (void)ge::AttrUtils::SetInt(op_desc, "input_size", 1); - (void)ge::AttrUtils::SetInt(op_desc, "state_size", 1); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_ND_RNN_BIAS, DT_FLOAT16, true, {}, {}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_ND_RNN_BIAS, DT_FLOAT16, true, {150}, {128}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_ND_RNN_BIAS, DT_FLOAT16, true, {18, 80}, {18, 64}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_ND_RNN_BIAS, DT_UINT1, true, {18, 80}, {18, 256}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_ND_RNN_BIAS, DT_UINT2, true, {18, 80}, {18, 128}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_ND_RNN_BIAS, DT_INT4, true, {18, 80}, {18, 64}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_ND_RNN_BIAS, DT_INT8, true, {18, 80}, {18, 64}); - - (void)ge::AttrUtils::SetInt(op_desc, "hidden_size", 0); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_ND_RNN_BIAS, DT_FLOAT16, true, {150}, {150}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_ND_RNN_BIAS, DT_INT8, true, {18, 80}, {18, 80}); - EXPECT_RunTransferShape(op_desc, ge::FORMAT_ND, ge::FORMAT_ND_RNN_BIAS, DT_FLOAT16, true, {-2}, {-2}, true); -} - -TEST_F(TransformerTransferShapeUT, transfer_shape_from_nyuva) { - ShapeTransferAccordingToFormat shape_transfer; - gert::Shape current_shape; - vector dims = {42, 63, 3}; - vector expect_dim = {48, 64, 3}; - for (const int64_t &d : dims) { - current_shape.AppendDim(d); - } - bool ret = shape_transfer.TransferShape(ge::FORMAT_NYUV, ge::FORMAT_NYUV_A, DT_INT8, current_shape); - EXPECT_EQ(ret, true); - if (ret) { - vector new_dim; - for (size_t i = 0; i < current_shape.GetDimNum(); ++i) { - new_dim.push_back(current_shape.GetDim(i)); - } - EXPECT_EQ(new_dim, expect_dim); - } -} - -TEST_F(TransformerTransferShapeUT, get_aligned_shape_and_transfer_dim_1) { - ge::Format src_format = ge::FORMAT_ND; - gert::Shape src_shape({1}); - ge::Format dst_format = ge::FORMAT_ND; - ge::DataType data_type = ge::DT_FLOAT16; - int64_t reshape_type_mask = 0; - gert::Shape aligned_shape; - transformer::AlignShapeInfo align_shape_info = {src_format, dst_format, src_shape, data_type, reshape_type_mask}; - bool ret = transformer::TransferShapeUtils::GetAlignedShape(align_shape_info, aligned_shape); - EXPECT_EQ(ret, true); - EXPECT_EQ(aligned_shape, gert::Shape({1})); - transformer::TransferDimsInfo transfor_dims_info = {src_format, dst_format, src_shape, reshape_type_mask}; - transformer::AxisIndexMapping axis_index_mapping; - ret = transformer::TransferShapeUtils::TransferDims(transfor_dims_info, axis_index_mapping); - EXPECT_EQ(ret, true); - std::vector> tmp1 = {{0}}; - EXPECT_EQ(axis_index_mapping.src_to_dst_transfer_dims, tmp1); - EXPECT_EQ(axis_index_mapping.dst_to_src_transfer_dims, tmp1); -} - -TEST_F(TransformerTransferShapeUT, get_aligned_shape_and_transfer_dim_2) { - ge::Format src_format = ge::FORMAT_ND; - gert::Shape src_shape({1, 16}); - ge::Format dst_format = ge::FORMAT_FRACTAL_NZ; - ge::DataType data_type = ge::DT_FLOAT16; - int64_t reshape_type_mask = 0; - gert::Shape aligned_shape; - transformer::AlignShapeInfo align_shape_info = {src_format, dst_format, src_shape, data_type, reshape_type_mask}; - bool ret = transformer::TransferShapeUtils::GetAlignedShape(align_shape_info, aligned_shape); - EXPECT_EQ(ret, true); - EXPECT_EQ(aligned_shape, gert::Shape({16, 16})); - transformer::TransferDimsInfo transfor_dims_info = {src_format, dst_format, src_shape, reshape_type_mask}; - transformer::AxisIndexMapping axis_index_mapping; - ret = transformer::TransferShapeUtils::TransferDims(transfor_dims_info, axis_index_mapping); - EXPECT_EQ(ret, true); - std::vector> tmp1 = {{1}, {0}}; - std::vector> tmp2 = {{1}, {0}, {-1}, {-1}}; - EXPECT_EQ(axis_index_mapping.src_to_dst_transfer_dims, tmp1); - EXPECT_EQ(axis_index_mapping.dst_to_src_transfer_dims, tmp2); -} - -TEST_F(TransformerTransferShapeUT, get_aligned_shape_and_transfer_dim_3) { - ge::Format src_format = ge::FORMAT_ND; - gert::Shape src_shape({1, 16, 64}); - ge::Format dst_format = ge::FORMAT_FRACTAL_NZ; - ge::DataType data_type = ge::DT_FLOAT16; - int64_t reshape_type_mask = 0; - gert::Shape aligned_shape; - transformer::AlignShapeInfo align_shape_info = {src_format, dst_format, src_shape, data_type, reshape_type_mask}; - bool ret = transformer::TransferShapeUtils::GetAlignedShape(align_shape_info, aligned_shape); - EXPECT_EQ(ret, true); - EXPECT_EQ(aligned_shape, gert::Shape({1, 16, 16})); - transformer::TransferDimsInfo transfor_dims_info = {src_format, dst_format, src_shape, reshape_type_mask}; - transformer::AxisIndexMapping axis_index_mapping; - ret = transformer::TransferShapeUtils::TransferDims(transfor_dims_info, axis_index_mapping); - EXPECT_EQ(ret, true); - std::vector> tmp1 = {{0}, {2}, {1}}; - std::vector> tmp2 = {{0}, {2}, {1}, {-1}, {-1}}; - EXPECT_EQ(axis_index_mapping.src_to_dst_transfer_dims, tmp1); - EXPECT_EQ(axis_index_mapping.dst_to_src_transfer_dims, tmp2); -} - -TEST_F(TransformerTransferShapeUT, get_aligned_shape_and_transfer_dim_4) { - ge::Format src_format = ge::FORMAT_ND; - gert::Shape src_shape({1, 16, 64}); - ge::Format dst_format = ge::FORMAT_FRACTAL_Z; - ge::DataType data_type = ge::DT_FLOAT16; - int64_t reshape_type_mask = 0; - gert::Shape aligned_shape; - transformer::AlignShapeInfo align_shape_info = {src_format, dst_format, src_shape, data_type, reshape_type_mask}; - bool ret = transformer::TransferShapeUtils::GetAlignedShape(align_shape_info, aligned_shape); - EXPECT_EQ(ret, false); - EXPECT_EQ(aligned_shape, gert::Shape({})); - transformer::TransferDimsInfo transfor_dims_info = {src_format, dst_format, src_shape, reshape_type_mask}; - transformer::AxisIndexMapping axis_index_mapping; - ret = transformer::TransferShapeUtils::TransferDims(transfor_dims_info, axis_index_mapping); - EXPECT_EQ(ret, false); - std::vector> tmp1 = {}; - std::vector> tmp2 = {}; - EXPECT_EQ(axis_index_mapping.src_to_dst_transfer_dims, tmp1); - EXPECT_EQ(axis_index_mapping.dst_to_src_transfer_dims, tmp2); -} - -TEST_F(TransformerTransferShapeUT, get_aligned_shape_and_transfer_dim_5) { - ge::Format src_format = ge::FORMAT_NCHW; - gert::Shape src_shape({1, 16, 64}); - ge::Format dst_format = ge::FORMAT_NCHW; - ge::DataType data_type = ge::DT_FLOAT16; - int64_t reshape_type_mask = 0; - gert::Shape aligned_shape; - transformer::AlignShapeInfo align_shape_info = {src_format, dst_format, src_shape, data_type, reshape_type_mask}; - bool ret = transformer::TransferShapeUtils::GetAlignedShape(align_shape_info, aligned_shape); - EXPECT_EQ(ret, true); - EXPECT_EQ(aligned_shape, gert::Shape({1, 1, 1})); - transformer::TransferDimsInfo transfor_dims_info = {src_format, dst_format, src_shape, reshape_type_mask}; - transformer::AxisIndexMapping axis_index_mapping; - ret = transformer::TransferShapeUtils::TransferDims(transfor_dims_info, axis_index_mapping); - EXPECT_EQ(ret, true); - std::vector> tmp1 = {{0}, {1}, {2}}; - EXPECT_EQ(axis_index_mapping.src_to_dst_transfer_dims, tmp1); - EXPECT_EQ(axis_index_mapping.dst_to_src_transfer_dims, tmp1); -} - -TEST_F(TransformerTransferShapeUT, get_aligned_shape_and_transfer_dim_6) { - ge::Format src_format = ge::FORMAT_NCHW; - gert::Shape src_shape({1, 16, 64, 128}); - ge::Format dst_format = ge::FORMAT_HWCN; - ge::DataType data_type = ge::DT_FLOAT16; - int64_t reshape_type_mask = 0; - gert::Shape aligned_shape; - transformer::AlignShapeInfo align_shape_info = {src_format, dst_format, src_shape, data_type, reshape_type_mask}; - bool ret = transformer::TransferShapeUtils::GetAlignedShape(align_shape_info, aligned_shape); - EXPECT_EQ(ret, true); - EXPECT_EQ(aligned_shape, gert::Shape({1, 1, 1, 1})); - transformer::TransferDimsInfo transfor_dims_info = {src_format, dst_format, src_shape, reshape_type_mask}; - transformer::AxisIndexMapping axis_index_mapping; - ret = transformer::TransferShapeUtils::TransferDims(transfor_dims_info, axis_index_mapping); - EXPECT_EQ(ret, true); - std::vector> tmp1 = {{3}, {2}, {0}, {1}}; - std::vector> tmp2 = {{2}, {3}, {1}, {0}}; - EXPECT_EQ(axis_index_mapping.src_to_dst_transfer_dims, tmp1); - EXPECT_EQ(axis_index_mapping.dst_to_src_transfer_dims, tmp2); -} - -TEST_F(TransformerTransferShapeUT, get_aligned_shape_and_transfer_dim_7) { - ge::Format src_format = ge::FORMAT_NCHW; - gert::Shape src_shape({1, 16, 64, 128}); - ge::Format dst_format = ge::FORMAT_NC1HWC0; - ge::DataType data_type = ge::DT_FLOAT16; - int64_t reshape_type_mask = 0; - gert::Shape aligned_shape; - transformer::AlignShapeInfo align_shape_info = {src_format, dst_format, src_shape, data_type, reshape_type_mask}; - bool ret = transformer::TransferShapeUtils::GetAlignedShape(align_shape_info, aligned_shape); - EXPECT_EQ(ret, true); - EXPECT_EQ(aligned_shape, gert::Shape({1, 16, 1, 1})); - transformer::TransferDimsInfo transfor_dims_info = {src_format, dst_format, src_shape, reshape_type_mask}; - transformer::AxisIndexMapping axis_index_mapping; - ret = transformer::TransferShapeUtils::TransferDims(transfor_dims_info, axis_index_mapping); - EXPECT_EQ(ret, true); - std::vector> tmp1 = {{0}, {1, 4}, {2}, {3}}; - std::vector> tmp2 = {{0}, {1}, {2}, {3}, {1}}; - EXPECT_EQ(axis_index_mapping.src_to_dst_transfer_dims, tmp1); - EXPECT_EQ(axis_index_mapping.dst_to_src_transfer_dims, tmp2); -} - -TEST_F(TransformerTransferShapeUT, get_aligned_shape_and_transfer_dim_8) { - ge::Format src_format = ge::FORMAT_NCHW; - gert::Shape src_shape({1, 16}); - ge::Format dst_format = ge::FORMAT_NC1HWC0; - ge::DataType data_type = ge::DT_FLOAT16; - int64_t reshape_type_mask = 9; - gert::Shape aligned_shape; - transformer::AlignShapeInfo align_shape_info = {src_format, dst_format, src_shape, data_type, reshape_type_mask}; - bool ret = transformer::TransferShapeUtils::GetAlignedShape(align_shape_info, aligned_shape); - EXPECT_EQ(ret, true); - EXPECT_EQ(aligned_shape, gert::Shape({16, 1})); - transformer::TransferDimsInfo transfor_dims_info = {src_format, dst_format, src_shape, reshape_type_mask}; - transformer::AxisIndexMapping axis_index_mapping; - ret = transformer::TransferShapeUtils::TransferDims(transfor_dims_info, axis_index_mapping); - EXPECT_EQ(ret, true); - std::vector> tmp1 = {{1, 4}, {2}}; - std::vector> tmp2 = {{-1}, {0}, {1}, {-1}, {0}}; - EXPECT_EQ(axis_index_mapping.src_to_dst_transfer_dims, tmp1); - EXPECT_EQ(axis_index_mapping.dst_to_src_transfer_dims, tmp2); -} - -TEST_F(TransformerTransferShapeUT, get_aligned_shape_and_transfer_dim_9) { - ge::Format src_format = ge::FORMAT_NCHW; - gert::Shape src_shape({1, 16}); - ge::Format dst_format = ge::FORMAT_NC1HWC0; - ge::DataType data_type = ge::DT_FLOAT16; - int64_t reshape_type_mask = 3; - gert::Shape aligned_shape; - transformer::AlignShapeInfo align_shape_info = {src_format, dst_format, src_shape, data_type, reshape_type_mask}; - bool ret = transformer::TransferShapeUtils::GetAlignedShape(align_shape_info, aligned_shape); - EXPECT_EQ(ret, true); - EXPECT_EQ(aligned_shape, gert::Shape({1, 1})); - transformer::TransferDimsInfo transfor_dims_info = {src_format, dst_format, src_shape, reshape_type_mask}; - transformer::AxisIndexMapping axis_index_mapping; - ret = transformer::TransferShapeUtils::TransferDims(transfor_dims_info, axis_index_mapping); - EXPECT_EQ(ret, true); - std::vector> tmp1 = {{2}, {3}}; - std::vector> tmp2 = {{-1}, {-1}, {0}, {1}, {-1}}; - EXPECT_EQ(axis_index_mapping.src_to_dst_transfer_dims, tmp1); - EXPECT_EQ(axis_index_mapping.dst_to_src_transfer_dims, tmp2); -} -TEST_F(TransformerTransferShapeUT, get_aligned_shape_and_transfer_dim_10) { - ge::Format src_format = ge::FORMAT_NDHWC; - gert::Shape src_shape({1, 16, 32, 64, 128}); - ge::Format dst_format = ge::FORMAT_NCDHW; - ge::DataType data_type = ge::DT_FLOAT16; - int64_t reshape_type_mask = 0; - gert::Shape aligned_shape; - transformer::AlignShapeInfo align_shape_info = {src_format, dst_format, src_shape, data_type, reshape_type_mask}; - bool ret = transformer::TransferShapeUtils::GetAlignedShape(align_shape_info, aligned_shape); - EXPECT_EQ(ret, true); - EXPECT_EQ(aligned_shape, gert::Shape({1, 1, 1, 1, 1})); - transformer::TransferDimsInfo transfor_dims_info = {src_format, dst_format, src_shape, reshape_type_mask}; - transformer::AxisIndexMapping axis_index_mapping; - ret = transformer::TransferShapeUtils::TransferDims(transfor_dims_info, axis_index_mapping); - EXPECT_EQ(ret, true); - std::vector> tmp1 = {{0}, {2}, {3}, {4}, {1}}; - std::vector> tmp2 = {{0}, {4}, {1}, {2}, {3}}; - EXPECT_EQ(axis_index_mapping.src_to_dst_transfer_dims, tmp1); - EXPECT_EQ(axis_index_mapping.dst_to_src_transfer_dims, tmp2); -} - -TEST_F(TransformerTransferShapeUT, get_aligned_shape_and_transfer_dim_11) { - ge::Format src_format = ge::FORMAT_NDHWC; - gert::Shape src_shape({1, 16}); - ge::Format dst_format = ge::FORMAT_NDC1HWC0; - ge::DataType data_type = ge::DT_FLOAT16; - int64_t reshape_type_mask = 7; - gert::Shape aligned_shape; - transformer::AlignShapeInfo align_shape_info = {src_format, dst_format, src_shape, data_type, reshape_type_mask}; - bool ret = transformer::TransferShapeUtils::GetAlignedShape(align_shape_info, aligned_shape); - EXPECT_EQ(ret, true); - EXPECT_EQ(aligned_shape, gert::Shape({1, 16})); - transformer::TransferDimsInfo transfor_dims_info = {src_format, dst_format, src_shape, reshape_type_mask}; - transformer::AxisIndexMapping axis_index_mapping; - ret = transformer::TransferShapeUtils::TransferDims(transfor_dims_info, axis_index_mapping); - EXPECT_EQ(ret, true); - std::vector> tmp1 = {{4}, {2, 5}}; - std::vector> tmp2 = {{-1}, {-1}, {1}, {-1}, {0}, {1}}; - EXPECT_EQ(axis_index_mapping.src_to_dst_transfer_dims, tmp1); - EXPECT_EQ(axis_index_mapping.dst_to_src_transfer_dims, tmp2); -} - -TEST_F(TransformerTransferShapeUT, get_aligned_shape_and_transfer_dim_12) { - ge::Format src_format = ge::FORMAT_HWCN; - gert::Shape src_shape({1, 2, 3, 16}); - ge::Format dst_format = ge::FORMAT_FRACTAL_Z; - ge::DataType data_type = ge::DT_FLOAT16; - int64_t reshape_type_mask = 0; - gert::Shape aligned_shape; - transformer::AlignShapeInfo align_shape_info = {src_format, dst_format, src_shape, data_type, reshape_type_mask}; - bool ret = transformer::TransferShapeUtils::GetAlignedShape(align_shape_info, aligned_shape); - EXPECT_EQ(ret, true); - EXPECT_EQ(aligned_shape, gert::Shape({1, 1, 16, 16})); - transformer::TransferDimsInfo transfor_dims_info = {src_format, dst_format, src_shape, reshape_type_mask}; - transformer::AxisIndexMapping axis_index_mapping; - ret = transformer::TransferShapeUtils::TransferDims(transfor_dims_info, axis_index_mapping); - EXPECT_EQ(ret, true); - std::vector> tmp1 = {{0}, {0}, {0, 3}, {1, 2}}; - std::vector> tmp2 = {{2, 0, 1}, {3}, {3}, {2}}; - EXPECT_EQ(axis_index_mapping.src_to_dst_transfer_dims, tmp1); - EXPECT_EQ(axis_index_mapping.dst_to_src_transfer_dims, tmp2); -} - -TEST_F(TransformerTransferShapeUT, get_aligned_shape_and_transfer_dim_13) { - ge::Format src_format = ge::FORMAT_HWCN; - gert::Shape src_shape({1, 16}); - ge::Format dst_format = ge::FORMAT_FRACTAL_Z; - ge::DataType data_type = ge::DT_FLOAT16; - int64_t reshape_type_mask = 6; - gert::Shape aligned_shape; - transformer::AlignShapeInfo align_shape_info = {src_format, dst_format, src_shape, data_type, reshape_type_mask}; - bool ret = transformer::TransferShapeUtils::GetAlignedShape(align_shape_info, aligned_shape); - EXPECT_EQ(ret, true); - EXPECT_EQ(aligned_shape, gert::Shape({1, 16})); - transformer::TransferDimsInfo transfor_dims_info = {src_format, dst_format, src_shape, reshape_type_mask}; - transformer::AxisIndexMapping axis_index_mapping; - ret = transformer::TransferShapeUtils::TransferDims(transfor_dims_info, axis_index_mapping); - EXPECT_EQ(ret, true); - std::vector> tmp1 = {{0},{1, 2}}; - std::vector> tmp2 = {{-1, 0, -1}, {1}, {1}, {-1}}; - EXPECT_EQ(axis_index_mapping.src_to_dst_transfer_dims, tmp1); - EXPECT_EQ(axis_index_mapping.dst_to_src_transfer_dims, tmp2); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/tuning_utils_unittest.cc b/tests/ut/graph/testcase/tuning_utils_unittest.cc deleted file mode 100644 index 41393c5293198cf808e61205eb41ed40fbe7751b..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/tuning_utils_unittest.cc +++ /dev/null @@ -1,633 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include "graph/tuning_utils.h" -#include "graph/utils/node_utils.h" -#include "graph/normal_graph/node_impl.h" -#include "graph/normal_graph/op_desc_impl.h" -#include "graph_builder_utils.h" -#include "graph/debug/ge_op_types.h" -#include "graph/operator_factory_impl.h" -#include "graph/normal_graph/compute_graph_impl.h" -#include "graph/anchor.h" - -using namespace std; -using namespace testing; - -namespace ge { - -class UtestTuningUtils : public testing::Test { - protected: - void SetUp() { - } - void TearDown() { - TuningUtils::netoutput_nodes_.clear(); - TuningUtils::data_node_2_end_node_.clear(); - } -}; - -/* ------------------------- -* | partitioncall_0_const1* | -* partitioncall_0--------------| | | -* | | netoutput | -* | -------------------------- -* | ------------------ ------------- -* | | data | | data | -* | | | | | | | -* partitioncall_1--------------| case -----|-------| squeeze* | -* | | | | | | -* | netoutput | | netoutput | -* ------------------ ------------- -*/ -ComputeGraphPtr BuildGraphPartitionCall1() { - auto root_builder = ut::GraphBuilder("root"); - const auto &partitioncall_0 = root_builder.AddNode("partitioncall_0", PARTITIONEDCALL, 0, 1); - const auto &partitioncall_1 = root_builder.AddNode("partitioncall_1", PARTITIONEDCALL, 1, 1); - root_builder.AddDataEdge(partitioncall_0, 0, partitioncall_1, 0); - const auto &root_graph = root_builder.GetGraph(); - - // 1.build partitioncall_0 sub graph - auto p1_sub_builder = ut::GraphBuilder("partitioncall_0_sub"); - const auto &partitioncall_0_const1 = p1_sub_builder.AddNode("partitioncall_0_const1", CONSTANT, 0, 1); - const auto &partitioncall_0_netoutput = p1_sub_builder.AddNode("partitioncall_0_netoutput", NETOUTPUT, 1, 1); - AttrUtils::SetInt(partitioncall_0_netoutput->GetOpDesc()->MutableInputDesc(0), "_parent_node_index", 0); - p1_sub_builder.AddDataEdge(partitioncall_0_const1, 0, partitioncall_0_netoutput, 0); - const auto &sub_graph = p1_sub_builder.GetGraph(); - sub_graph->SetParentNode(partitioncall_0); - sub_graph->SetParentGraph(root_graph); - partitioncall_0->GetOpDesc()->AddSubgraphName("f"); - partitioncall_0->GetOpDesc()->SetSubgraphInstanceName(0, "partitioncall_0_sub"); - - // 2.build partitioncall_1 sub graph - auto p2_sub_builder = ut::GraphBuilder("partitioncall_1_sub"); - const auto &partitioncall_1_data = p2_sub_builder.AddNode("partitioncall_1_data", DATA, 0, 1); - AttrUtils::SetInt(partitioncall_1_data->GetOpDesc(), "_parent_node_index", 0); - const auto &partitioncall_1_case = p2_sub_builder.AddNode("partitioncall_1_case", "Case", 1, 1); - const auto &partitioncall_1_netoutput = p2_sub_builder.AddNode("partitioncall_1_netoutput", NETOUTPUT, 1, 1); - p2_sub_builder.AddDataEdge(partitioncall_1_data, 0, partitioncall_1_case, 0); - p2_sub_builder.AddDataEdge(partitioncall_1_case, 0, partitioncall_1_netoutput, 0); - const auto &sub_graph2 = p2_sub_builder.GetGraph(); - sub_graph2->SetParentNode(partitioncall_1); - sub_graph2->SetParentGraph(root_graph); - partitioncall_1->GetOpDesc()->AddSubgraphName("f"); - partitioncall_1->GetOpDesc()->SetSubgraphInstanceName(0, "partitioncall_1_sub"); - - // 2.1 build case sub graph - auto case_sub_builder = ut::GraphBuilder("case_sub"); - const auto &case_data = case_sub_builder.AddNode("case_data", DATA, 0, 1); - AttrUtils::SetInt(case_data->GetOpDesc(), "_parent_node_index", 0); - const auto &case_squeeze = case_sub_builder.AddNode("case_squeeze", SQUEEZE, 1, 1); - const auto &case_netoutput = case_sub_builder.AddNode("case_netoutput", NETOUTPUT, 1, 1); - case_sub_builder.AddDataEdge(case_data, 0, case_squeeze, 0); - case_sub_builder.AddDataEdge(case_squeeze, 0, case_netoutput, 0); - const auto &case_sub_graph = case_sub_builder.GetGraph(); - case_sub_graph->SetParentNode(partitioncall_1_case); - case_sub_graph->SetParentGraph(sub_graph2); - partitioncall_1_case->GetOpDesc()->AddSubgraphName("branches"); - partitioncall_1_case->GetOpDesc()->SetSubgraphInstanceName(0, "case_sub"); - - root_graph->AddSubgraph(case_sub_graph->GetName(), case_sub_graph); - root_graph->AddSubgraph(sub_graph->GetName(), sub_graph); - root_graph->AddSubgraph(sub_graph2->GetName(), sub_graph2); - return root_graph; -} - -TEST_F(UtestTuningUtils, ConvertGraphToFile) { - std::vector tuning_subgraphs; - std::vector non_tuning_subgraphs; - auto builder = ut::GraphBuilder("non_tun"); - const auto data0 = builder.AddNode("data_0", DATA, 0, 1); - const auto data1 = builder.AddNode("data_1", DATA, 1, 1); - auto nongraph = builder.GetGraph(); - tuning_subgraphs.push_back(BuildGraphPartitionCall1()); - non_tuning_subgraphs.push_back(nongraph); - EXPECT_EQ(TuningUtils::ConvertGraphToFile(tuning_subgraphs, non_tuning_subgraphs), GRAPH_SUCCESS); - auto nonnodes = non_tuning_subgraphs.at(0)->GetDirectNode(); - auto nonfirst = nonnodes.at(0); - nonfirst->impl_->op_ = nullptr; - EXPECT_EQ(TuningUtils::ConvertGraphToFile(tuning_subgraphs, non_tuning_subgraphs), GRAPH_FAILED); - auto nodes = tuning_subgraphs.at(0)->GetDirectNode(); - auto first = nodes.at(0); - first->impl_->op_ = nullptr; - EXPECT_EQ(TuningUtils::ConvertGraphToFile(tuning_subgraphs, non_tuning_subgraphs), GRAPH_FAILED); -} - -TEST_F(UtestTuningUtils, ConvertGraphToFile_HelpInfo) { - std::vector tuning_subgraphs; - std::vector non_tuning_subgraphs; - auto builder = ut::GraphBuilder("non_tun"); - const auto data0 = builder.AddNode("data_0", DATA, 0, 1); - const auto data1 = builder.AddNode("data_1", DATA, 1, 1); - auto nongraph = builder.GetGraph(); - tuning_subgraphs.push_back(BuildGraphPartitionCall1()); - non_tuning_subgraphs.push_back(nongraph); - EXPECT_EQ(TuningUtils::ConvertGraphToFile(tuning_subgraphs, non_tuning_subgraphs, true, "path", "user_path"), GRAPH_SUCCESS); -} - -TEST_F(UtestTuningUtils, ConvertGraphToFile_Placehodler) { - std::vector tuning_subgraphs; - std::vector non_tuning_subgraphs; - auto builder = ut::GraphBuilder("non_tun"); - const auto plhd0 = builder.AddNode("placeholder_0", PLACEHOLDER, 1, 1); - auto nongraph = builder.GetGraph(); - tuning_subgraphs.push_back(BuildGraphPartitionCall1()); - non_tuning_subgraphs.push_back(nongraph); - EXPECT_EQ(TuningUtils::ConvertGraphToFile(tuning_subgraphs, non_tuning_subgraphs), GRAPH_FAILED); -} - -TEST_F(UtestTuningUtils, ConvertGraphToFile_End) { - std::vector tuning_subgraphs; - std::vector non_tuning_subgraphs; - auto builder = ut::GraphBuilder("non_tun"); - const auto end0 = builder.AddNode("end_0", END, 1, 1); - auto nongraph = builder.GetGraph(); - tuning_subgraphs.push_back(BuildGraphPartitionCall1()); - non_tuning_subgraphs.push_back(nongraph); - EXPECT_EQ(TuningUtils::ConvertGraphToFile(tuning_subgraphs, non_tuning_subgraphs), GRAPH_FAILED); -} - -TEST_F(UtestTuningUtils, PrintCheckLog) { - EXPECT_NE(TuningUtils::PrintCheckLog(), ""); - TuningUtils::data_2_end_["data"] = "end"; - EXPECT_NE(TuningUtils::PrintCheckLog(), ""); - auto builder = ut::GraphBuilder("graph"); - const auto data0 = builder.AddNode("data_0", DATA, 0, 1); - const auto data1 = builder.AddNode("data_1", DATA, 1, 1); - TuningUtils::netoutput_nodes_.push_back(data0); - TuningUtils::netoutput_nodes_.push_back(data1); - EXPECT_NE(TuningUtils::PrintCheckLog(), ""); -} - -TEST_F(UtestTuningUtils, GetNodeNameByAnchor) { - EXPECT_EQ(TuningUtils::GetNodeNameByAnchor(nullptr), "Null"); - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node = builder.AddNode("Data", "Data", 1, 1); - InDataAnchorPtr in_anch = std::make_shared(node, 111); - EXPECT_EQ(TuningUtils::GetNodeNameByAnchor(in_anch.get()), "Data"); -} - -TEST_F(UtestTuningUtils, CreateDataNode) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node = builder.AddNode("Data", "Data", 1, 1); - NodePtr data_node; - EXPECT_EQ(TuningUtils::CreateDataNode(node, "", data_node), GRAPH_SUCCESS); - ge::GeTensorPtr tensor = std::make_shared(); - std::vector value{1, 2, 3}; - std::vector shape{3}; - tensor->MutableTensorDesc().SetShape(GeShape(shape)); - tensor->SetData(value); - tensor->MutableTensorDesc().SetDataType(DT_UINT8); - map weight1; - weight1[1] = tensor; - EXPECT_EQ(ge::OpDescUtils::SetWeights(*node, weight1), 0); - auto node_tensor = OpDescUtils::MutableWeights(node); - EXPECT_EQ(TuningUtils::CreateDataNode(node, "", data_node), GRAPH_SUCCESS); - - - auto sub_builder = ut::GraphBuilder("sub"); - const auto &partitioncall_0_const1 = sub_builder.AddNode("partitioncall_0_const1", CONSTANT, 0, 1); - const auto &partitioncall_0_netoutput = sub_builder.AddNode("partitioncall_0_netoutput", NETOUTPUT, 1, 1); - AttrUtils::SetInt(partitioncall_0_netoutput->GetOpDesc()->MutableInputDesc(0), "_parent_node_index", 0); - sub_builder.AddDataEdge(partitioncall_0_const1, 0, partitioncall_0_netoutput, 0); - const auto &sub_graph = sub_builder.GetGraph(); - sub_graph->SetParentNode(node); - EXPECT_EQ(TuningUtils::CreateDataNode(node, "", data_node), GRAPH_SUCCESS); -} - -TEST_F(UtestTuningUtils, CreateDataNode_Weight) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node = builder.AddNode("Data", DATA, 1, 1); - auto node1 = builder.AddNode("Data1", "Data", 1, 1); - NodePtr data_node; - node->GetOpDesc()->SetExtAttr("parentNode", node1); - EXPECT_EQ(TuningUtils::CreateDataNode(node, "", data_node), GRAPH_SUCCESS); - - auto pld = builder.AddNode("pld", PLACEHOLDER, 0, 1); - uint8_t val = 1; - auto const_tensor = std::make_shared(GeTensorDesc(), &val, sizeof(val)); - ASSERT_NE(pld->GetOpDesc(), nullptr); - EXPECT_EQ(ge::AttrUtils::SetTensor(pld->GetOpDesc(), "value", const_tensor), true); - EXPECT_EQ(ge::AttrUtils::SetStr(pld->GetOpDesc(), "_parentNodeName", "src_const"), true); - EXPECT_EQ(TuningUtils::CreateDataNode(pld, "", data_node), GRAPH_SUCCESS); - std::string parent_node_name; - EXPECT_EQ(ge::AttrUtils::GetStr(data_node->GetOpDesc(), ATTR_NAME_SRC_CONST_NAME, parent_node_name), true); - EXPECT_EQ(parent_node_name, "src_const"); -} - -TEST_F(UtestTuningUtils, AddAttrToDataNodeForMergeGraph) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node0 = builder.AddNode("Data0", "Data", 1, 1); - auto node1 = builder.AddNode("Data1", "Data", 1, 1); - EXPECT_EQ(TuningUtils::AddAttrToDataNodeForMergeGraph(node0, node1), FAILED); - AttrUtils::SetStr(node0->GetOpDesc(), "parentOpType", "Hello world"); - EXPECT_EQ(TuningUtils::AddAttrToDataNodeForMergeGraph(node0, node1), FAILED); - AttrUtils::SetStr(node0->GetOpDesc(), "_parentNodeName", "Hello world0"); - EXPECT_EQ(TuningUtils::AddAttrToDataNodeForMergeGraph(node0, node1), FAILED); - AttrUtils::SetInt(node0->GetOpDesc(), "anchorIndex", 1); - EXPECT_EQ(TuningUtils::AddAttrToDataNodeForMergeGraph(node0, node1), FAILED); - AttrUtils::SetStr(node0->GetOpDesc(), "_peerNodeName", "Hello world0"); - EXPECT_EQ(TuningUtils::AddAttrToDataNodeForMergeGraph(node0, node1), SUCCESS); -} - -TEST_F(UtestTuningUtils, ChangePld2Data) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node0 = builder.AddNode("Data0", "Data", 1, 1); - auto node1 = builder.AddNode("Data1", "Data", 1, 1); - EXPECT_EQ(TuningUtils::ChangePld2Data(node0, node1), FAILED); - auto node2 = builder.AddNode("placeholder2", PLACEHOLDER, 1, 1); - auto node3 = builder.AddNode("data3", DATA, 1, 1); - EXPECT_EQ(TuningUtils::ChangePld2Data(node2, node3), SUCCESS); - node3->impl_->out_data_anchors_.push_back(nullptr); - EXPECT_EQ(TuningUtils::ChangePld2Data(node2, node3), FAILED); -} - -TEST_F(UtestTuningUtils, HandlePld) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node0 = builder.AddNode("Data0", "Data", 1, 1); - EXPECT_EQ(TuningUtils::HandlePld(node0, ""), FAILED); - AttrUtils::SetStr(node0->GetOpDesc(), "parentOpType", "Hello world"); - AttrUtils::SetStr(node0->GetOpDesc(), "_parentNodeName", "Hello world0"); - AttrUtils::SetInt(node0->GetOpDesc(), "anchorIndex", 1); - AttrUtils::SetStr(node0->GetOpDesc(), "_peerNodeName", "Hello world0"); - EXPECT_EQ(TuningUtils::HandlePld(node0, ""), FAILED); - auto node2 = builder.AddNode("placeholder2", PLACEHOLDER, 1, 1); - auto node3 = builder.AddNode("data3", DATA, 1, 1); - AttrUtils::SetStr(node2->GetOpDesc(), "parentOpType", "Hello world"); - AttrUtils::SetStr(node2->GetOpDesc(), "_parentNodeName", "Hello world0"); - AttrUtils::SetInt(node2->GetOpDesc(), "anchorIndex", 1); - AttrUtils::SetStr(node2->GetOpDesc(), "_peerNodeName", "Hello world0"); - EXPECT_EQ(TuningUtils::HandlePld(node2, ""), SUCCESS); -} - -TEST_F(UtestTuningUtils, CreateNetOutput) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node0 = builder.AddNode("Data0", "Data", 1, 1); - NodePtr node1; - auto graph = builder.GetGraph(); - EXPECT_EQ(TuningUtils::CreateNetOutput(node0, node1), FAILED); - TuningUtils::create_output_[graph] = node0; - EXPECT_EQ(TuningUtils::CreateNetOutput(node0, node1), SUCCESS); - TuningUtils::create_output_[graph] = nullptr; - EXPECT_EQ(TuningUtils::CreateNetOutput(node0, node1), SUCCESS); -} - -TEST_F(UtestTuningUtils, AddAttrToNetOutputForMergeGraph) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node0 = builder.AddNode("Data0", "Data", 1, 1); - auto node1 = builder.AddNode("Data1", "Data", 1, 1); - EXPECT_EQ(TuningUtils::AddAttrToNetOutputForMergeGraph(node0, node1, 0), SUCCESS); -} - -TEST_F(UtestTuningUtils, LinkEnd2NetOutput) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node0 = builder.AddNode("Data0", "Data", 1, 1); - auto node1 = builder.AddNode("Data1", "Data", 1, 1); - EXPECT_EQ(TuningUtils::LinkEnd2NetOutput(node0, node1), PARAM_INVALID); - auto node2 = builder.AddNode("Data2", "Data", 0, 1); - auto node3 = builder.AddNode("Data3", "Data", 1, 1); - EXPECT_EQ(node2->AddLinkFrom(node3), GRAPH_SUCCESS); - EXPECT_EQ(TuningUtils::LinkEnd2NetOutput(node2, node3), SUCCESS); -} - -TEST_F(UtestTuningUtils, LinkEnd2NetOutput_OutControlAnchor) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node2 = builder.AddNode("Data2", "Data", 1, 1); - auto node3 = builder.AddNode("Data3", "Data", 1, 1); - auto node4 = builder.AddNode("Data4", "Data", 1, 1); - EXPECT_EQ(node2->GetAllInDataAnchors().size(), 1); - EXPECT_EQ(node2->GetInDataAnchor(0)->GetFirstPeerAnchor(), nullptr); - EXPECT_EQ(node2->GetInControlAnchor()->LinkFrom(node4->GetOutControlAnchor()), GRAPH_SUCCESS); - EXPECT_EQ(node2->AddLinkFrom(node3), GRAPH_SUCCESS); - EXPECT_EQ(TuningUtils::LinkEnd2NetOutput(node2, node3), SUCCESS); -} - - -TEST_F(UtestTuningUtils, ChangeEnd2NetOutput) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node0 = builder.AddNode("Data0", "Data", 1, 1); - auto node1 = builder.AddNode("Data1", "Data", 1, 1); - EXPECT_EQ(TuningUtils::ChangeEnd2NetOutput(node0, node1), FAILED); - auto node2 = builder.AddNode("Node2", END, 1, 1); - auto node3 = builder.AddNode("Node3", NETOUTPUT, 1, 1); - EXPECT_EQ(TuningUtils::ChangeEnd2NetOutput(node2, node3), FAILED); - auto node4 = builder.AddNode("Node4", END, 0, 1); - auto node5 = builder.AddNode("Node5", NETOUTPUT, 1, 1); - EXPECT_EQ(node4->AddLinkFrom(node5), GRAPH_SUCCESS); - EXPECT_EQ(TuningUtils::ChangeEnd2NetOutput(node4, node5), SUCCESS); - auto graph = node4->GetOwnerComputeGraph(); - graph->impl_ = nullptr; - EXPECT_EQ(TuningUtils::ChangeEnd2NetOutput(node4, node5), FAILED); -} - -TEST_F(UtestTuningUtils, HandleEnd) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node0 = builder.AddNode("Data0", DATA, 0, 1); - auto graph = builder.GetGraph(); - EXPECT_EQ(TuningUtils::HandleEnd(node0), FAILED); - TuningUtils::create_output_[graph] = node0; - EXPECT_EQ(TuningUtils::HandleEnd(node0), FAILED); - auto node4 = builder.AddNode("Node4", END, 0, 1); - auto node5 = builder.AddNode("Node5", NETOUTPUT, 1, 1); - EXPECT_EQ(node4->AddLinkFrom(node5), GRAPH_SUCCESS); - TuningUtils::create_output_[graph] = node5; - EXPECT_EQ(TuningUtils::HandleEnd(node4), SUCCESS); -} - -TEST_F(UtestTuningUtils, ConvertFileToGraph) { - // build root graph - auto root_graph_builder = ut::GraphBuilder("root_graph"); - const auto &data_0 = root_graph_builder.AddNode("data_0", DATA, 0, 1); - AttrUtils::SetInt(data_0->GetOpDesc(), "_parent_node_index", 0); - const auto &case_0 = root_graph_builder.AddNode("case_0", "Case", 1, 1); - const auto &netoutput_0 = root_graph_builder.AddNode("netoutput_0", NETOUTPUT, 1, 1); - root_graph_builder.AddDataEdge(data_0, 0, case_0, 0); - root_graph_builder.AddDataEdge(case_0, 0, netoutput_0, 0); - case_0->GetOpDesc()->AddSubgraphName("branches"); - case_0->GetOpDesc()->SetSubgraphInstanceName(0, "case_sub"); - const auto &root_graph = root_graph_builder.GetGraph(); - EXPECT_EQ(AttrUtils::SetBool(root_graph, ATTR_NAME_IS_ROOT_GRAPH, true), true); - EXPECT_EQ(AttrUtils::SetStr(root_graph, ATTR_NAME_PARENT_GRAPH_NAME, root_graph->GetName()), true); - auto ret = GraphUtils::DumpGEGraphByPath(root_graph, "./subgraph_0.txt", ge::DumpLevel::NO_DUMP); - ASSERT_EQ(ret, 0); - - // build case sub graph - auto case_sub_builder = ut::GraphBuilder("case_sub"); - const auto &case_data = case_sub_builder.AddNode("case_data", DATA, 0, 1); - AttrUtils::SetInt(case_data->GetOpDesc(), "_parent_node_index", 0); - const auto &case_squeeze = case_sub_builder.AddNode("case_squeeze", SQUEEZE, 1, 1); - const auto &case_netoutput = case_sub_builder.AddNode("case_netoutput", NETOUTPUT, 1, 1); - case_sub_builder.AddDataEdge(case_data, 0, case_squeeze, 0); - case_sub_builder.AddDataEdge(case_squeeze, 0, case_netoutput, 0); - const auto &case_sub_graph = case_sub_builder.GetGraph(); - case_sub_graph->SetParentNode(case_0); - case_sub_graph->SetParentGraph(root_graph); - EXPECT_EQ(AttrUtils::SetStr(case_sub_graph, ATTR_NAME_PARENT_GRAPH_NAME, case_sub_graph->GetName()), true); - ret = GraphUtils::DumpGEGraphByPath(case_sub_graph, "./subgraph_1.txt", ge::DumpLevel::NO_DUMP); - ASSERT_EQ(ret, 0); - - ComputeGraphPtr com_graph0 = std::make_shared("TestGraph0"); - ComputeGraphPtr com_graph1 = std::make_shared("TestGraph1"); - ASSERT_EQ(GraphUtils::LoadGEGraph("./subgraph_0.txt", *com_graph0), true); - ASSERT_EQ(GraphUtils::LoadGEGraph("./subgraph_1.txt", *com_graph1), true); - - std::map options; - options.emplace(0, "./subgraph_0.txt"); - options.emplace(1, "./subgraph_1.txt"); - Graph g; - EXPECT_EQ(TuningUtils::ConvertFileToGraph(options, g), SUCCESS); - - options.clear(); - EXPECT_EQ(TuningUtils::ConvertFileToGraph(options, g), FAILED); -} - -TEST_F(UtestTuningUtils, MergeSubGraph) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node0 = builder.AddNode("Data0", "Data", 1, 1); - auto graph = builder.GetGraph(); - EXPECT_EQ(TuningUtils::MergeSubGraph(graph), SUCCESS); -} - -TEST_F(UtestTuningUtils, MergeSubGraph_End) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node0 = builder.AddNode("end", END, 1, 1); - auto graph = builder.GetGraph(); - EXPECT_EQ(TuningUtils::MergeSubGraph(graph), FAILED); -} - -TEST_F(UtestTuningUtils, MergeSubGraph_Valid) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node0 = builder.AddNode("data", DATA, 1, 1); - auto graph = builder.GetGraph(); - AttrUtils::SetStr(node0->GetOpDesc(), "_peerNodeName", "Hello world"); - EXPECT_EQ(TuningUtils::MergeSubGraph(graph), SUCCESS); -} - -TEST_F(UtestTuningUtils, MergeSubGraph_Netoutput) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node0 = builder.AddNode("net", NETOUTPUT, 1, 1); - auto graph = builder.GetGraph(); - std::vector val; - val.push_back("Hello world"); - AttrUtils::SetListStr(node0->GetOpDesc(), "_aliasName", val); - EXPECT_EQ(TuningUtils::MergeSubGraph(graph), SUCCESS); -} - -TEST_F(UtestTuningUtils, FindNode) { - int64_t in_index; - EXPECT_EQ(TuningUtils::FindNode("Data0", in_index), nullptr); - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node0 = builder.AddNode("Data0", "Data", 1, 1); - auto graph = builder.GetGraph(); - TuningUtils::netoutput_nodes_.push_back(nullptr); - TuningUtils::netoutput_nodes_.push_back(node0); - EXPECT_EQ(TuningUtils::FindNode("Data0", in_index), nullptr); - AttrUtils::SetListStr(node0->GetOpDesc(), "_aliasName", {"Data0", "str2", "str3"}); - AttrUtils::SetListInt(node0->GetOpDesc(), "_aliasIndexes", {0, 1, 2}); - EXPECT_NE(TuningUtils::FindNode("Data0", in_index), nullptr); -} - -TEST_F(UtestTuningUtils, ConvertConstToWeightAttr) { - auto builder = ut::GraphBuilder("root"); - const auto &placeholder_0 = builder.AddNode("placeholder_0", PLACEHOLDER, 0, 1); - const auto &placeholder_1 = builder.AddNode("placeholder_1", PLACEHOLDER, 1, 1); - std::map tmap; - tmap[10] = std::make_shared(); - OpDescUtils::SetWeights(*placeholder_0, tmap); - builder.AddDataEdge(placeholder_0, 0, placeholder_1, 0); - const auto &graph = builder.GetGraph(); - EXPECT_EQ(TuningUtils::ConvertConstToWeightAttr(graph), SUCCESS); - EXPECT_EQ(OpDescUtils::SetWeights(placeholder_0->GetOpDesc(), std::make_shared()), GRAPH_SUCCESS); - auto weight = OpDescUtils::MutableWeights(placeholder_0); - EXPECT_EQ(weight.empty(), false); - EXPECT_EQ(TuningUtils::ConvertConstToWeightAttr(graph), SUCCESS); -} - -TEST_F(UtestTuningUtils, DumpGraphToPath) { - EXPECT_NO_THROW( - auto builder = ut::GraphBuilder("root"); - const auto &placeholder_0 = builder.AddNode("placeholder_0", PLACEHOLDER, 0, 1); - const auto &placeholder_1 = builder.AddNode("placeholder_1", PLACEHOLDER, 1, 1); - builder.AddDataEdge(placeholder_0, 0, placeholder_1, 0); - const auto &graph = builder.GetGraph(); - TuningUtils::DumpGraphToPath(graph, 1, true, "path"); - TuningUtils::DumpGraphToPath(graph, 1, false, "path"); - ); -} - -TEST_F(UtestTuningUtils, RemoveDataNetoutputEdge) { - auto builder = ut::GraphBuilder("root"); - const auto &placeholder_0 = builder.AddNode("placeholder_0", PLACEHOLDER, 0, 1); - const auto &placeholder_1 = builder.AddNode("placeholder_1", PLACEHOLDER, 1, 1); - const auto netoutput_1 = builder.AddNode("netoutput_1", NETOUTPUT, 1, 1); - const auto noopnode = builder.AddNode("netoutput_1NoOp", NOOP, 1, 1); - const auto netoutput_2 = builder.AddNode("netoutput_2", NETOUTPUT, 1, 1); - builder.AddDataEdge(placeholder_0, 0, placeholder_1, 0); - auto graph = builder.GetGraph(); - TuningUtils::data_node_2_end_node_[placeholder_0] = "placeholder_0"; - TuningUtils::data_node_2_end_node_[placeholder_1] = "placeholder_1"; - EXPECT_EQ(TuningUtils::RemoveDataNetoutputEdge(graph), PARAM_INVALID); - TuningUtils::data_node_2_end_node_.clear(); - TuningUtils::data_node_2_end_node_[netoutput_1] = "netoutput_1"; - std::vector out_alias_name; - out_alias_name.push_back("netoutput_1"); - AttrUtils::SetListStr(netoutput_1->GetOpDesc(), "_aliasName", out_alias_name); - std::vector alias_indexes; - alias_indexes.push_back(-1); - AttrUtils::SetListInt(netoutput_1->GetOpDesc(), "_aliasIndexes", alias_indexes); - TuningUtils::netoutput_nodes_.push_back(netoutput_1); - int64_t index = 0; - auto n = TuningUtils::FindNode("netoutput_1", index); - EXPECT_EQ(index, -1); - EXPECT_EQ(netoutput_1->GetInControlAnchor()->LinkFrom(noopnode->GetOutControlAnchor()), GRAPH_SUCCESS); - EXPECT_EQ(noopnode->GetInControlAnchor()->LinkFrom(netoutput_2->GetOutControlAnchor()), GRAPH_SUCCESS); - EXPECT_EQ(TuningUtils::RemoveDataNetoutputEdge(graph), GRAPH_SUCCESS); -} - -TEST_F(UtestTuningUtils, RemoveDataNetoutputEdge_FindNode) { - auto builder = ut::GraphBuilder("root"); - const auto &placeholder_0 = builder.AddNode("placeholder_0", PLACEHOLDER, 0, 1); - const auto &placeholder_1 = builder.AddNode("placeholder_1", PLACEHOLDER, 1, 1); - builder.AddDataEdge(placeholder_0, 0, placeholder_1, 0); - TuningUtils::data_node_2_end_node_[placeholder_0] = "placeholder_0"; - TuningUtils::netoutput_nodes_.push_back(placeholder_0); - AttrUtils::SetListStr(placeholder_0->GetOpDesc(), "_aliasName", {"placeholder_0"}); - AttrUtils::SetListInt(placeholder_0->GetOpDesc(), "_aliasIndexes", {0}); - int64_t in_index0; - EXPECT_NE(TuningUtils::FindNode("placeholder_0", in_index0), nullptr); - EXPECT_EQ(placeholder_0->AddLinkFrom(placeholder_1), GRAPH_SUCCESS); - auto graph = builder.GetGraph(); - EXPECT_EQ(TuningUtils::RemoveDataNetoutputEdge(graph), SUCCESS); -} - -TEST_F(UtestTuningUtils, MergeAllSubGraph) { - auto builder0 = ut::GraphBuilder("sub0"); - const auto &placeholder_0 = builder0.AddNode("placeholder_0", PLACEHOLDER, 0, 1); - const auto &placeholder_1 = builder0.AddNode("placeholder_1", PLACEHOLDER, 1, 1); - auto graph0 = builder0.GetGraph(); - auto builder1 = ut::GraphBuilder("sub1"); - const auto &placeholder_2 = builder1.AddNode("placeholder_2", PLACEHOLDER, 0, 1); - const auto &placeholder_3 = builder1.AddNode("placeholder_3", PLACEHOLDER, 1, 1); - auto graph1 = builder1.GetGraph(); - std::vector vec; - vec.push_back(graph0); - vec.push_back(graph1); - auto output_builder = ut::GraphBuilder("output"); - auto output_graph = output_builder.GetGraph(); - TuningUtils::merged_graph_nodes_.push_back(placeholder_0); - TuningUtils::merged_graph_nodes_.push_back(placeholder_1); - TuningUtils::merged_graph_nodes_.push_back(placeholder_2); - TuningUtils::merged_graph_nodes_.push_back(placeholder_3); - EXPECT_EQ(TuningUtils::MergeAllSubGraph(vec, output_graph), GRAPH_FAILED); - vec.clear(); - EXPECT_EQ(TuningUtils::MergeAllSubGraph(vec, output_graph), SUCCESS); - std::vector vals; - vals.push_back("1"); - vals.push_back("2"); - AttrUtils::SetListStr(placeholder_0->GetOpDesc(), ATTR_NAME_NEED_RECOVER_ATTR, vals); - EXPECT_EQ(TuningUtils::MergeAllSubGraph(vec, output_graph), SUCCESS); -} - -TEST_F(UtestTuningUtils, WeightExternalizationAndRecover) { - std::vector tuning_subgraphs; - auto builder_tune = ut::GraphBuilder("tune_graph"); - const auto pld0 = builder_tune.AddNode("pld0", PLACEHOLDER, 0, 1); - AttrUtils::SetStr(pld0->GetOpDesc(), "_parentNodeName", "Const0"); - AttrUtils::SetStr(pld0->GetOpDesc(), "_peerNodeName", "Const0"); - AttrUtils::SetStr(pld0->GetOpDesc(), "parentOpType", CONSTANT); - AttrUtils::SetInt(pld0->GetOpDesc(), "anchorIndex", 0); - const auto pld1 = builder_tune.AddNode("pld1", PLACEHOLDER, 0, 1); - AttrUtils::SetStr(pld1->GetOpDesc(), "_parentNodeName", "Const0"); - AttrUtils::SetStr(pld1->GetOpDesc(), "_peerNodeName", "Const0"); - AttrUtils::SetStr(pld1->GetOpDesc(), "parentOpType", CONSTANT); - AttrUtils::SetInt(pld1->GetOpDesc(), "anchorIndex", 0); - const auto pld2 = builder_tune.AddNode("pld2", PLACEHOLDER, 0, 1); - AttrUtils::SetStr(pld2->GetOpDesc(), "_parentNodeName", "Const1"); - AttrUtils::SetStr(pld2->GetOpDesc(), "_peerNodeName", "Const1"); - AttrUtils::SetStr(pld2->GetOpDesc(), "parentOpType", CONSTANT); - AttrUtils::SetInt(pld2->GetOpDesc(), "anchorIndex", 0); - const auto pld3 = builder_tune.AddNode("pld3", PLACEHOLDER, 0, 1); - AttrUtils::SetStr(pld3->GetOpDesc(), "_parentNodeName", "Const2"); - AttrUtils::SetStr(pld3->GetOpDesc(), "_peerNodeName", "Const2"); - AttrUtils::SetStr(pld3->GetOpDesc(), "parentOpType", CONSTANT); - AttrUtils::SetInt(pld3->GetOpDesc(), "anchorIndex", 0); - const auto addn = builder_tune.AddNode("addn", "AddN", 4, 1); - const auto end0 = builder_tune.AddNode("end0", END, 1, 0); - builder_tune.AddDataEdge(pld0, 0, addn, 0); - builder_tune.AddDataEdge(pld1, 0, addn, 1); - builder_tune.AddDataEdge(pld2, 0, addn, 2); - builder_tune.AddDataEdge(pld3, 0, addn, 3); - builder_tune.AddDataEdge(addn, 0, end0, 0); - auto tune_graph = builder_tune.GetGraph(); - tuning_subgraphs.push_back(tune_graph); - - std::vector non_tuning_subgraphs; - auto builder_const = ut::GraphBuilder("const_graph"); - const auto const0 = builder_const.AddNode("Const0", CONSTANT, 0, 1); - const auto const1 = builder_const.AddNode("Const1", CONSTANT, 0, 1); - const auto const2 = builder_const.AddNode("Const2", CONSTANT, 0, 1); - ge::GeTensorPtr tensor = std::make_shared(); - std::vector value{1, 2, 3}; - std::vector shape{3}; - tensor->MutableTensorDesc().SetShape(GeShape(shape)); - tensor->SetData(value); - tensor->MutableTensorDesc().SetDataType(DT_UINT8); - EXPECT_EQ(ge::OpDescUtils::SetWeights(const0, {tensor}), 0); - EXPECT_EQ(ge::OpDescUtils::SetWeights(const1, {tensor}), 0); - - ge::GeTensorPtr empty_tensor = std::make_shared(); - std::vector empty_value; - std::vector empty_shape{0}; - empty_tensor->MutableTensorDesc().SetShape(GeShape(empty_shape)); - empty_tensor->SetData(empty_value); - empty_tensor->MutableTensorDesc().SetDataType(DT_UINT8); - EXPECT_EQ(ge::OpDescUtils::SetWeights(const2, {empty_tensor}), 0); - - const auto netoutput = builder_const.AddNode("netoutput0", NETOUTPUT, 3, 0); - builder_const.AddDataEdge(const0, 0, netoutput, 0); - builder_const.AddDataEdge(const1, 0, netoutput, 1); - builder_const.AddDataEdge(const2, 0, netoutput, 2); - auto nongraph = builder_const.GetGraph(); - non_tuning_subgraphs.push_back(nongraph); - pld0->GetOpDesc()->SetExtAttr("parentNode", const0); - pld1->GetOpDesc()->SetExtAttr("parentNode", const0); - pld2->GetOpDesc()->SetExtAttr("parentNode", const1); - pld3->GetOpDesc()->SetExtAttr("parentNode", const2); - - EXPECT_EQ(TuningUtils::ConvertGraphToFile(tuning_subgraphs, non_tuning_subgraphs, true, "./"), GRAPH_SUCCESS); - EXPECT_EQ(TuningUtils::reusable_weight_files_.size(), 1U); - EXPECT_EQ(TuningUtils::hash_to_files_.size(), 1U); - const auto file_const1 = tune_graph->FindFirstNodeMatchType(FILECONSTANT); - EXPECT_NE(file_const1, nullptr); - const auto real_const1 = tune_graph->FindFirstNodeMatchType(CONSTANT); - EXPECT_NE(real_const1, nullptr); - const auto file_const2 = nongraph->FindFirstNodeMatchType(FILECONSTANT); - EXPECT_NE(file_const2, nullptr); - const auto real_const2 = nongraph->FindFirstNodeMatchType(CONSTANT); - EXPECT_NE(real_const2, nullptr); - - ComputeGraphPtr load_aicore = std::make_shared("TestGraph0"); - ComputeGraphPtr load_const = std::make_shared("TestGraph1"); - EXPECT_EQ(GraphUtils::LoadGEGraph("./aicore_subgraph_0.txt", load_aicore), true); - EXPECT_EQ(GraphUtils::LoadGEGraph("./subgraph_0.txt", load_const), true); - const auto recover_const1 = load_aicore->FindFirstNodeMatchType(CONSTANT); - EXPECT_NE(recover_const1, nullptr); - const auto file_const3 = load_aicore->FindFirstNodeMatchType(FILECONSTANT); - EXPECT_EQ(file_const3, nullptr); - const auto recover_const2 = load_const->FindFirstNodeMatchType(CONSTANT); - EXPECT_NE(recover_const2, nullptr); - const auto file_const4 = load_const->FindFirstNodeMatchType(FILECONSTANT); - EXPECT_EQ(file_const4, nullptr); - - system("rm -rf ./tmp_weight_*"); - system("rm -rf ./aicore_subgraph_*"); - system("rm -rf ./subgraph_*"); -} -} // namespace ge diff --git a/tests/ut/graph/testcase/type_utils_inner_unittest.cc b/tests/ut/graph/testcase/type_utils_inner_unittest.cc deleted file mode 100644 index 73a964e7a07c3459a22c5ffee6b5af2948d32a1a..0000000000000000000000000000000000000000 --- a/tests/ut/graph/testcase/type_utils_inner_unittest.cc +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/utils/type_utils_inner.h" -#include -#include -#include "graph/debug/ge_util.h" - -namespace ge { -class UtestTypeUtilsInner : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -TEST_F(UtestTypeUtilsInner, IsInternalFormat) { - ASSERT_TRUE(TypeUtilsInner::IsInternalFormat(FORMAT_FRACTAL_Z)); - ASSERT_FALSE(TypeUtilsInner::IsInternalFormat(FORMAT_RESERVED)); -} - -TEST_F(UtestTypeUtilsInner, ImplyTypeToSSerialString) { - ASSERT_EQ(TypeUtilsInner::ImplyTypeToSerialString(domi::ImplyType::BUILDIN), "buildin"); - ASSERT_EQ(TypeUtilsInner::ImplyTypeToSerialString(static_cast(30)), "UNDEFINED"); -} - -TEST_F(UtestTypeUtilsInner, DomiFormatToFormat) { - ASSERT_EQ(TypeUtilsInner::DomiFormatToFormat(domi::domiTensorFormat_t::DOMI_TENSOR_NDHWC), FORMAT_NDHWC); - ASSERT_EQ(TypeUtilsInner::DomiFormatToFormat(static_cast(30)), FORMAT_RESERVED); -} - -TEST_F(UtestTypeUtilsInner, FmkTypeToSerialString) { - ASSERT_EQ(TypeUtilsInner::FmkTypeToSerialString(domi::FrameworkType::CAFFE), "caffe"); -} - -TEST_F(UtestTypeUtilsInner, ImplyTypeToSerialString) { - ASSERT_EQ(TypeUtilsInner::ImplyTypeToSerialString(domi::ImplyType::BUILDIN), "buildin"); -} - -TEST_F(UtestTypeUtilsInner, DomiFormatToFormat2) { - ASSERT_EQ(TypeUtilsInner::DomiFormatToFormat(domi::DOMI_TENSOR_NCHW), FORMAT_NCHW); - ASSERT_EQ(TypeUtilsInner::DomiFormatToFormat(domi::DOMI_TENSOR_RESERVED), FORMAT_RESERVED); -} - -TEST_F(UtestTypeUtilsInner, FmkTypeToSerialString2) { - ASSERT_EQ(TypeUtilsInner::FmkTypeToSerialString(domi::CAFFE), "caffe"); - ASSERT_EQ(TypeUtilsInner::FmkTypeToSerialString(static_cast(domi::FRAMEWORK_RESERVED + 1)), ""); -} -} diff --git a/tests/ut/register/CMakeLists.txt b/tests/ut/register/CMakeLists.txt index 97364944d379fe4740ddb062c28361077375e34b..8b488a8fca19292400b014b5fb9f64a839a0e5f9 100644 --- a/tests/ut/register/CMakeLists.txt +++ b/tests/ut/register/CMakeLists.txt @@ -16,10 +16,9 @@ include_directories(${METADEF_DIR}) include_directories(${METADEF_DIR}/register) file(GLOB_RECURSE REGISTER_UT_FILES CONFIGURE_DEPENDS "${METADEF_DIR}/tests/ut/register/*.cc" ) -file(GLOB_RECURSE UTILS_FILES CONFIGURE_DEPENDS "${METADEF_DIR}/tests/ut/graph/common/*.cc" ) add_executable(ut_register - ${REGISTER_UT_FILES} ${UTILS_FILES} + ${REGISTER_UT_FILES} ) add_compile_definitions(CMAKE_BINARY_DIR=\"${CMAKE_BINARY_DIR}\") target_compile_options(ut_register PRIVATE @@ -40,7 +39,7 @@ target_link_libraries(ut_register runtime_headers slog_headers msprof_headers - exe_graph lowering register graph graph_base mmpa + metadef exe_graph mmpa opp_registry GTest::gtest GTest::gtest_main @@ -60,6 +59,5 @@ target_link_libraries(ut_register ) target_include_directories(ut_register PRIVATE - ${METADEF_DIR}/tests/ut/graph/common -) - + ${METADEF_DIR}/base/registry + ) diff --git a/tests/ut/register/testcase/abi_compatibility_for_register_unittest.cc b/tests/ut/register/testcase/abi_compatibility_for_register_unittest.cc deleted file mode 100644 index b94c23b2313be05d52222608edf56e2bb87960d1..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/abi_compatibility_for_register_unittest.cc +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/any_value.h" -#include "register/op_impl_registry.h" -#include "register/op_impl_registry_base.h" -#include "register/kernel_registry_impl.h" - -namespace gert { -namespace { -constexpr const size_t kUint8Size = 1U; -constexpr const size_t kPointerSize = 8U; -constexpr const size_t kVectorSize = 24U; -constexpr const size_t kUnorderedSetSize = 56U; -constexpr const size_t kMapSize = 48U; -constexpr const size_t kVirtualTableSize = 8U; -constexpr const size_t kReservedFieldSize = 8U; -constexpr const size_t kReservedFieldSize2 = 40U; -constexpr const size_t kOpImplFunctionSize = 216U; -constexpr const size_t kOpImplReservedFieldSize = 502UL * sizeof(void*); -constexpr const size_t kOpImplSize = 4304UL; - -constexpr const size_t kOpImplFunctionsSize = 200U; -constexpr const size_t kOpImplRegistrySize = 88U + kVirtualTableSize; -constexpr const size_t kOpImplRegisterSize = 56U; -} // namespace - -constexpr size_t OpImplKernelRegistry::OpImplFunctions::kByteBitCount; -class AbiCompatibilityForRegisterUT : public testing::Test {}; -TEST_F(AbiCompatibilityForRegisterUT, OpImplRegistry_CheckMemLayoutNotChanged) { - OpImplRegistry r; - ASSERT_EQ(sizeof(r), kOpImplRegistrySize); - ASSERT_EQ(reinterpret_cast(&r.types_to_impl_) - reinterpret_cast(&r), kVirtualTableSize); - - EXPECT_EQ(reinterpret_cast(&r.reserved_) - reinterpret_cast(&r.types_to_impl_), - kMapSize); - EXPECT_EQ(sizeof(r.reserved_), kReservedFieldSize2); -} - -TEST_F(AbiCompatibilityForRegisterUT, OpImplFunctionsV2_CheckMemLayoutNotChanged) { - gert::OpImplKernelRegistry::OpImplFunctionsV2 r; - ASSERT_EQ(sizeof(r), kOpImplSize); - EXPECT_EQ(sizeof(r.reserved_), kOpImplReservedFieldSize); -} -} // namespace gert diff --git a/tests/ut/register/testcase/auto_mapping_util_unittest.cc b/tests/ut/register/testcase/auto_mapping_util_unittest.cc deleted file mode 100644 index fac1ca0d15e83a9d9a46caeeba864c3615d3cbff..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/auto_mapping_util_unittest.cc +++ /dev/null @@ -1,389 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include "graph_builder_utils.h" -#include "external/register/register.h" -#include -#include "graph/debug/ge_util.h" -#include "graph/debug/ge_op_types.h" -#include "common/ge_common/debug/ge_log.h" -#include "graph/debug/ge_log.h" -#include "graph/debug/ge_util.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/type_utils.h" -#include "register/op_registry.h" -#include "graph/graph.h" -#include "graph/utils/attr_utils.h" -#include "proto/tensorflow/node_def.pb.h" -#include "register/auto_mapping_util.h" -#include "external/register/scope/scope_fusion_pass_register.h" -#include "register/scope/scope_graph_impl.h" - -using namespace ge; -using namespace domi; -class AutoMappingUtils : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -void CreateTFGraphDef(domi::tensorflow::GraphDef &graph_def) { - // 1. add node - auto placeholder0 = graph_def.add_node(); - auto placeholder1 = graph_def.add_node(); - auto add0 = graph_def.add_node(); - auto add1 = graph_def.add_node(); - auto mul0 = graph_def.add_node(); - auto mul1 = graph_def.add_node(); - auto add2 = graph_def.add_node(); - auto retval0 = graph_def.add_node(); - auto retval1 = graph_def.add_node(); - - // 2. set info - placeholder0->set_name("placeholder0"); - placeholder0->set_op("PlaceHolder"); - placeholder1->set_name("placeholder1"); - placeholder1->set_op("PlaceHolder"); - - add0->set_name("add0"); - add0->set_op("Add"); - add1->set_name("add1"); - add1->set_op("Add"); - add2->set_name("add2"); - add2->set_op("Add"); - - mul0->set_name("mul0"); - mul0->set_op("Mul"); - mul1->set_name("mul1"); - mul1->set_op("Mul"); - - retval0->set_name("retval0"); - retval0->set_op("_RetVal"); - retval1->set_name("retval1"); - retval1->set_op("_RetVal"); - - // 3. add edges - add0->add_input("placeholder0"); - add0->add_input("placeholder1"); - - mul0->add_input("placeholder0"); - mul0->add_input("placeholder1"); - - mul1->add_input("placeholder0"); - mul1->add_input("add0"); - mul1->add_input("^mul0"); - - add1->add_input("mul0"); - add1->add_input("placeholder1"); - - add2->add_input("mul1"); - add2->add_input("mul0"); - - retval0->add_input("add2:0"); - retval1->add_input("add1:0"); -} - -TEST_F(AutoMappingUtils, FindAttrValueFalse) { - domi::tensorflow::GraphDef graph_def; - domi::tensorflow::AttrValue attr_num; - CreateTFGraphDef(graph_def); - bool ret; - domi::tensorflow::NodeDef *node0 = nullptr; - ret = ge::AutoMappingUtil::FindAttrValue(node0, string(""), attr_num); - EXPECT_FALSE(ret); - - domi::tensorflow::NodeDef node1; - ret = ge::AutoMappingUtil::FindAttrValue(&node1, string(""), attr_num); - EXPECT_FALSE(ret); - - const domi::tensorflow::NodeDef *node2 = graph_def.mutable_node(0); - ret = ge::AutoMappingUtil::FindAttrValue(node2, node2->name(), attr_num); - EXPECT_FALSE(ret); -} - -TEST_F(AutoMappingUtils, ConvertShape) { - domi::tensorflow::TensorShapeProto shape; - vector shape_dims; - - shape.set_unknown_rank(true); - ge::AutoMappingUtil::ConvertShape(shape, shape_dims); - EXPECT_EQ(shape_dims, ge::UNKNOWN_SHAPE); - - shape.set_unknown_rank(false); - shape.add_dim(); - ge::AutoMappingUtil::ConvertShape(shape, shape_dims); - EXPECT_NE(shape_dims, ge::UNKNOWN_SHAPE); -} - -TEST_F(AutoMappingUtils, ConvertTensor) { - ge::graphStatus ret; - domi::tensorflow::TensorProto tensor; - ge::GeTensorPtr weight; - - tensor.set_dtype(domi::tensorflow::DataType_INT_MAX_SENTINEL_DO_NOT_USE_); - ret = ge::AutoMappingUtil::ConvertTensor(tensor, weight); - EXPECT_EQ(ret, GRAPH_FAILED); - - tensor.set_dtype(domi::tensorflow::DT_UINT16_REF); - ret = ge::AutoMappingUtil::ConvertTensor(tensor, weight); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - tensor.set_dtype(domi::tensorflow::DT_UINT8); - ret = ge::AutoMappingUtil::ConvertTensor(tensor, weight); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(AutoMappingUtils, ConvertTensorList) { - domi::tensorflow::AttrValue_ListValue list; - std::vector vec; - - list.add_tensor(); - ge::AutoMappingUtil::ConvertTensorList(list, vec); - EXPECT_EQ(vec.empty(), true); -} - -TEST_F(AutoMappingUtils, ConvertFunc) { - EXPECT_NO_THROW( - domi::tensorflow::NameAttrList tf_func; - ge::NamedAttrs ge_func; - const int32_t kInvalidFuncRecursiveDepth = 31; - - tf_func.set_name("test_fun"); - ge::AutoMappingUtil::ConvertFunc(tf_func, ge_func); - ge::AutoMappingUtil::ConvertFunc(tf_func, ge_func, kInvalidFuncRecursiveDepth); - ); -} - -TEST_F(AutoMappingUtils, ConvertDataTypeList) { - domi::tensorflow::AttrValue_ListValue list; - std::vector vec; - - list.add_type(domi::tensorflow::DT_INT16); - ge::AutoMappingUtil::ConvertDataTypeList(list, vec); - EXPECT_EQ(vec.empty(), false); -} - -TEST_F(AutoMappingUtils, ConvertShapeList) { - domi::tensorflow::AttrValue_ListValue list; - std::vector> vec; - - list.add_shape(); - ge::AutoMappingUtil::ConvertShapeList(list, vec); - EXPECT_EQ(vec.empty(), false); -} - -TEST_F(AutoMappingUtils, ConvertFuncList) { - domi::tensorflow::AttrValue_ListValue list; - std::vector vec; - const int32_t kInvalidFuncRecursiveDepth = 31; - - list.add_func(); - ge::AutoMappingUtil::ConvertFuncList(list, vec, kInvalidFuncRecursiveDepth); - EXPECT_EQ(vec.empty(), true); - - ge::AutoMappingUtil::ConvertFuncList(list, vec); - EXPECT_EQ(vec.empty(), false); -} - -const float FLOAT_TEST_NUM = 3.14; -const double DOUBLE_TEST_NUM = 3.1415; -const int INT_TEST_NUM = 0; -const unsigned int UNSIGNED_INT_TEST_NUM = 0; - -TEST_F(AutoMappingUtils, CopyAttrValueInputTest) { - ut::GraphBuilder builder("graph"); - NodePtr node_src = builder.AddNode("ParseSingleNode", "ParseSingleType", 3, 0, FORMAT_ALL); - - ge::Operator op_src = OpDescUtils::CreateOperatorFromNode(node_src); - ge::Operator op_dst = ge::Operator("ParseSingleExample"); - std::shared_ptr op_desc_dst = ge::OpDescUtils::GetOpDescFromOperator(op_dst); - std::vector value; - EXPECT_EQ(op_dst.GetInputsSize(), 0); - - const char_t *testName1 = "type_str"; - const char_t *testPort1 = "port_str"; - AttrUtils::SetStr(node_src->GetOpDesc(), testName1, "str_shapes"); - op_desc_dst->AddRegisterInputName(testPort1); - domi::DynamicInputOutputInfo input1(kInput, testPort1, strlen(testPort1), testName1, strlen(testName1)); - value.push_back(input1); - - const char_t *testName2 = "type_int"; - const char_t *testPort2 = "port_int"; - AttrUtils::SetInt(node_src->GetOpDesc(), testName2, INT_TEST_NUM); - op_desc_dst->AddRegisterInputName(testPort2); - domi::DynamicInputOutputInfo input2(kInput, testPort2, strlen(testPort2), testName2, strlen(testName2)); - value.push_back(input2); - - const char_t *testName3 = "type_float"; - const char_t *testPort3 = "port_float"; - AttrUtils::SetFloat(node_src->GetOpDesc(), testName3, FLOAT_TEST_NUM); - op_desc_dst->AddRegisterInputName(testPort3); - domi::DynamicInputOutputInfo input3(kInput, testPort3, strlen(testPort3), testName3, strlen(testName3)); - value.push_back(input3); - - const char_t *listName = "Name_inputlist1"; - const char_t *listPort = "port_inputlist1"; - vector InListDataType = {DT_STRING, DT_INT32, DT_FLOAT}; - AttrUtils::SetListDataType(node_src->GetOpDesc(), listName, InListDataType); - op_desc_dst->AddRegisterInputName(listPort); - domi::DynamicInputOutputInfo input4(kInput, listPort, strlen(listPort), listName, strlen(listName)); - value.push_back(input4); - - auto ret = domi::AutoMappingByOpFnDynamic(op_src, op_dst, value); - EXPECT_EQ(ret, domi::SUCCESS); - EXPECT_EQ(op_dst.GetInputsSize(), 3); -} - -TEST_F(AutoMappingUtils, CopyAttrValueInputListTest) { - ut::GraphBuilder builder("graph"); - NodePtr node_src = builder.AddNode("ParseSingleNode", "ParseSingleType", 6, 0, FORMAT_ALL); - - ge::Operator op_src = OpDescUtils::CreateOperatorFromNode(node_src); - ge::Operator op_dst = ge::Operator("ParseSingleExample"); - std::shared_ptr op_desc_dst = ge::OpDescUtils::GetOpDescFromOperator(op_dst); - std::vector value; - EXPECT_EQ(op_dst.GetInputsSize(), 0); - - const char_t *testlistName1 = "listName_str"; - const char_t *testlistPort1 = "listport_str"; - vector attrStrList = {"image/class/lable","image/encode", "image/format"}; - AttrUtils::SetListStr(node_src->GetOpDesc(), testlistName1, attrStrList); - op_desc_dst->AddRegisterInputName(testlistPort1); - domi::DynamicInputOutputInfo input1(kInput, testlistPort1, strlen(testlistPort1), testlistName1, strlen(testlistName1)); - value.push_back(input1); - - const char_t *testlistName2 = "listName_Int"; - const char_t *testlistPort2 = "listport_Int"; - vector attrIntList = {0, 1, 2, 3}; - AttrUtils::SetListInt(node_src->GetOpDesc(), testlistName2, attrIntList); - op_desc_dst->AddRegisterInputName(testlistPort2); - domi::DynamicInputOutputInfo input2(kInput, testlistPort2, strlen(testlistPort2), testlistName2, strlen(testlistName2)); - value.push_back(input2); - - const char_t *testlistName3 = "listName_Float"; - const char_t *testlistPort3 = "listport_Float"; - vector attrFloatList = {0.0, 0.1, 0.2, 0.3}; - AttrUtils::SetListFloat(node_src->GetOpDesc(), testlistName3, attrFloatList); - op_desc_dst->AddRegisterInputName(testlistPort3); - domi::DynamicInputOutputInfo input3(kInput, testlistPort3, strlen(testlistPort3), testlistName3, strlen(testlistName3)); - value.push_back(input3); - - const char_t *testlistName4 = "listName_Bool"; - const char_t *testlistPort4 = "listport_Bool"; - vector attrBoolList = {true, false, false, true}; - AttrUtils::SetListBool(node_src->GetOpDesc(), testlistName4, attrBoolList); - op_desc_dst->AddRegisterInputName(testlistPort4); - domi::DynamicInputOutputInfo input4(kInput, testlistPort4, strlen(testlistPort4), testlistName4, strlen(testlistName4)); - value.push_back(input4); - - const char_t *testlistName5 = "listName_NamedAttrs"; - const char_t *testlistPort5 = "listport_NamedAttrs"; - NamedAttrs name1; NamedAttrs name2; - vector attrNamedAttrsList = {name1, name2}; - AttrUtils::SetListNamedAttrs(node_src->GetOpDesc(), testlistName5, attrNamedAttrsList); - op_desc_dst->AddRegisterInputName(testlistPort5); - domi::DynamicInputOutputInfo input5(kInput, testlistPort5, strlen(testlistPort5), testlistName5, strlen(testlistName5)); - value.push_back(input5); - - const char_t *testlistName6 = "listName_Int"; - const char_t *testlistPort6 = "listport_Int"; - vector> attrIntListList = {attrIntList, attrIntList}; - AttrUtils::SetListListInt(node_src->GetOpDesc(), testlistName6, attrIntListList); - op_desc_dst->AddRegisterInputName(testlistPort6); - domi::DynamicInputOutputInfo input6(kInput, testlistPort6, strlen(testlistPort6), testlistName6, strlen(testlistName6)); - value.push_back(input6); - - const char_t *testlist_ListName = "listName_ListData"; - const char_t *testlist_ListPort = "listport_ListData"; - vector InListDataType = {DT_STRING, DT_INT32, DT_FLOAT, DT_BOOL, DT_UNDEFINED, DT_UNDEFINED}; - AttrUtils::SetListDataType(node_src->GetOpDesc(), testlist_ListName, InListDataType); - op_desc_dst->AddRegisterInputName(testlist_ListPort); - domi::DynamicInputOutputInfo input7(kInput, testlist_ListPort, strlen(testlist_ListPort), testlist_ListName, strlen(testlist_ListName)); - value.push_back(input7); - - auto ret = domi::AutoMappingByOpFnDynamic(op_src, op_dst, value); - EXPECT_EQ(ret, domi::SUCCESS); - EXPECT_EQ(op_dst.GetInputsSize(), 6); -} - -TEST_F(AutoMappingUtils, CopyAttrValueOutputTest) { - ut::GraphBuilder builder("graph"); - NodePtr node_src = builder.AddNode("ParseSingleNode", "ParseSingleType", 0, 4, FORMAT_ALL); - - ge::Operator op_src = OpDescUtils::CreateOperatorFromNode(node_src); - ge::Operator op_dst = ge::Operator("ParseSingleExample"); - std::shared_ptr op_desc_dst = ge::OpDescUtils::GetOpDescFromOperator(op_dst); - std::vector value; - EXPECT_EQ(op_dst.GetInputsSize(), 0); - - const char_t *testName1 = "Name_attrBool"; - const char_t *testPort1 = "port_attrBool"; - AttrUtils::SetBool(node_src->GetOpDesc(), testName1, true); - op_desc_dst->AddRegisterOutputName(testPort1); - domi::DynamicInputOutputInfo output1(kOutput, testPort1, strlen(testPort1), testName1, strlen(testName1)); - value.push_back(output1); - - const char_t *testName2 = "Name_attrName"; - const char_t *testPort2 = "port_attrName"; - NamedAttrs NamedAttr; NamedAttr.SetName("NamedAttr"); - AttrUtils::SetNamedAttrs(node_src->GetOpDesc(), testName2, NamedAttr); - op_desc_dst->AddRegisterOutputName(testPort2); - domi::DynamicInputOutputInfo output2(kOutput, testPort2, strlen(testPort2), testName2, strlen(testName2)); - value.push_back(output2); - - const char_t *testName3 = "Name_attrDataType"; - const char_t *testPort3 = "port_attrDataType"; - AttrUtils::SetDataType(node_src->GetOpDesc(), testName3, DT_INT16); - op_desc_dst->AddRegisterOutputName(testPort3); - domi::DynamicInputOutputInfo output3(domi::kOutput, testPort3, strlen(testPort3), testName3, strlen(testName3)); - value.push_back(output3); - - const char_t *testName4 = "Name_attrGraph"; - const char_t *testPort4 = "port_attrGraph"; - ComputeGraphPtr graph = builder.GetGraph(); - AttrUtils::SetGraph(node_src->GetOpDesc(), testName4, graph); - op_desc_dst->AddRegisterOutputName(testPort4); - domi::DynamicInputOutputInfo output4(kOutput, testPort4, strlen(testPort4), testName4, strlen(testName4)); - value.push_back(output4); - - const char_t *testName5 = "Name_attrDataTypeList"; - const char_t *testPort5 = "port_attrDataTypeList"; - vector OutListDataType = {DT_BOOL, DT_STRING, DT_INT16, DT_RESOURCE}; - AttrUtils::SetListDataType(node_src->GetOpDesc(), testName5, OutListDataType); - op_desc_dst->AddRegisterOutputName(testPort5); - domi::DynamicInputOutputInfo output5(kOutput, testPort5, strlen(testPort5), testName5, strlen(testName5)); - value.push_back(output5); - - auto ret = domi::AutoMappingByOpFnDynamic(op_src, op_dst, value); - EXPECT_EQ(ret, domi::SUCCESS); - EXPECT_EQ(op_dst.GetOutputsSize(), 4); -} - -TEST_F(AutoMappingUtils, ConvertValueTest) { - ge::NamedAttrs ge_func; - std::string convertName = "convertName"; - domi::tensorflow::AttrValue value; - - value.set_s(std::string("valueString")); - value.set_has_s(); - auto op_desc = std::make_shared(); - ge::AutoMappingUtil::ConvertValue(convertName, value, op_desc, 0); - std::string valueStr; - ge::AttrUtils::GetStr(op_desc, convertName, valueStr); - EXPECT_EQ(valueStr=="valueString", true); - - ge::AutoMappingUtil::ConvertValue(convertName, value, ge_func, 0); - ge::AttrUtils::GetStr(ge_func, convertName, valueStr); - EXPECT_EQ(valueStr=="valueString", true); -} - diff --git a/tests/ut/register/testcase/custom_pass/register_custom_pass_unittest.cc b/tests/ut/register/testcase/custom_pass/register_custom_pass_unittest.cc deleted file mode 100644 index c38bc3e20d09ec289ef6d1aea90d97b045f072b5..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/custom_pass/register_custom_pass_unittest.cc +++ /dev/null @@ -1,303 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include - -#include "inc/external/register/register_custom_pass.h" -#include "register/custom_pass_context_impl.h" -#include "graph/debug/ge_log.h" -#include "register/custom_pass_helper.h" -#include "tests/depends/mmpa/src/mmpa_stub.h" - -namespace ge { -namespace { - const char *const kEnvName = "ASCEND_OPP_PATH"; -} -class UtestRegisterPass : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} - - void CreateSharedLibrary(const std::string &path) { - std::ofstream ofs(path + ".cpp"); - ofs << R"( - #include - extern "C" void hello() { - std::cout << "Hello, world!" << std::endl; - } - )"; - ofs.close(); - std::string cmd = "g++ -shared -fPIC -o " + path + ".so " + path + ".cpp"; - system(cmd.c_str()); - std::remove((path + ".cpp").c_str()); - } - - static Status MyCustomPass(ge::GraphPtr &graph, CustomPassContext &context) { - if (graph->GetName() == "test") { - context.SetErrorMessage("graph name is invalid"); - return FAILED; - } - return SUCCESS; - } - - static Status FooConstGraphCustomPass(const ConstGraphPtr &graph, CustomPassContext &context) { - if (graph->GetName() == "error_graph") { - context.SetErrorMessage("graph name is invalid"); - return FAILED; - } - return SUCCESS; - } -}; - -TEST_F(UtestRegisterPass, GetPassNameTest) { - ge::PassRegistrationData pass_data("registry"); - std::string name = pass_data.GetPassName(); - EXPECT_EQ(name, "registry"); - - pass_data.impl_ = nullptr; - name = pass_data.GetPassName(); - EXPECT_EQ(name, ""); -} - -TEST_F(UtestRegisterPass, CustomPassFnTest) { - CustomPassFunc custom_pass_fn = nullptr; - ge::PassRegistrationData pass_data("registry"); - pass_data.CustomPassFn(custom_pass_fn); - auto ret = pass_data.GetCustomPassFn(); - EXPECT_EQ(ret, nullptr); - - custom_pass_fn = std::function(); - pass_data.impl_ = nullptr; - pass_data.CustomPassFn(custom_pass_fn); - ret = pass_data.GetCustomPassFn(); - EXPECT_EQ(ret, nullptr); -} - -TEST_F(UtestRegisterPass, CustomPassHelperRunTest) { - PassRegistrationData pass_data("registry"); - ge::PassReceiver pass_receiver(pass_data); - CustomPassHelper cust_helper; - auto graph = std::make_shared("test"); - CustomPassContext custom_pass_context; - bool ret = cust_helper.Run(graph, custom_pass_context); - EXPECT_EQ(ret, SUCCESS); - - // not register pass func - PassRegistrationData pass_data2("registry2"); - cust_helper.registration_datas_.emplace_back(pass_data2); - auto graph2 = std::make_shared("test2"); - ret = cust_helper.Run(graph2, custom_pass_context); - EXPECT_EQ(ret, SUCCESS); -} - -TEST_F(UtestRegisterPass, CustomPassHelperRunTest_Failed) { - CustomPassHelper cust_helper; - CustomPassContext custom_pass_context; - PassRegistrationData pass_data2("registry2"); - pass_data2.CustomPassFn(MyCustomPass); - cust_helper.registration_datas_.emplace_back(pass_data2); - auto graph = std::make_shared("test"); - auto ret = cust_helper.Run(graph, custom_pass_context); - EXPECT_NE(ret, SUCCESS); -} - -TEST_F(UtestRegisterPass, CustomPassHelperRunTest_Success) { - CustomPassHelper cust_helper; - CustomPassContext custom_pass_context; - PassRegistrationData pass_data2("registry2"); - pass_data2.CustomPassFn(MyCustomPass); - cust_helper.registration_datas_.emplace_back(pass_data2); - auto graph = std::make_shared("test2"); - auto ret = cust_helper.Run(graph, custom_pass_context); - EXPECT_EQ(ret, SUCCESS); -} - -TEST_F(UtestRegisterPass, LoadCustomPassLibsTest_Failed) { - CustomPassHelper cust_helper; - ge::Status status = cust_helper.Load(); - EXPECT_EQ(status, ge::SUCCESS); - status = cust_helper.Unload(); - EXPECT_EQ(status, ge::SUCCESS); -} - -TEST_F(UtestRegisterPass, LoadCustomPassLibsTest_Failed_Invalid_Lib) { - std::string path = __FILE__; - path = path.substr(0, path.rfind("/") + 1) + "opp"; - mmSetEnv(kEnvName, path.c_str(), 1); - system(("mkdir -p " + path).c_str()); - - std::string custom_path = path + "/vendors/1/custom_fusion_passes"; - system(("mkdir -p " + custom_path).c_str()); - system(("touch " + custom_path + "/concat_pass.so").c_str()); - system(("touch " + custom_path + "/tile_pass.so").c_str()); - system(("touch " + custom_path + "/add_pass.so").c_str()); - - CustomPassHelper cust_helper; - ge::Status status = cust_helper.Load(); - EXPECT_EQ(status, ge::FAILED); - status = cust_helper.Unload(); - EXPECT_EQ(status, ge::SUCCESS); - - system(("rm -rf " + path).c_str()); -} - -TEST_F(UtestRegisterPass, LoadCustomPassLibsTest_MissingDependencies) { - std::string path = __FILE__; - path = path.substr(0, path.rfind("/") + 1) + "opp"; - mmSetEnv(kEnvName, path.c_str(), 1); - system(("mkdir -p " + path).c_str()); - - std::string custom_path = path + "/vendors/1/custom_fusion_passes"; - system(("mkdir -p " + custom_path).c_str()); - - // Create a shared library that depends on a dummy library - std::ofstream dummy_lib(custom_path + "/libdummy.cpp"); - dummy_lib << R"( - #include - extern "C" void dummy() { - std::cout << "Dummy function" << std::endl; - } - )"; - dummy_lib.close(); - std::string dummy_cmd = "g++ -shared -fPIC -o " + custom_path + "/libdummy.so " + custom_path + "/libdummy.cpp"; - system(dummy_cmd.c_str()); - std::remove((custom_path + "/libdummy.cpp").c_str()); - - // Create the main shared library that depends on the dummy library - std::ofstream main_lib(custom_path + "/libcustom_pass.cpp"); - main_lib << R"( - #include - extern void dummy(); - extern "C" void hello() { - dummy(); - std::cout << "Hello, world!" << std::endl; - } - )"; - main_lib.close(); - std::string main_cmd = "g++ -shared -fPIC -o " + custom_path + "/libcustom_pass.so " + custom_path + "/libcustom_pass.cpp -L" + custom_path + " -ldummy"; - system(main_cmd.c_str()); - std::remove((custom_path + "/libcustom_pass.cpp").c_str()); - - // Ensure the shared library is created - struct stat buffer; - ASSERT_EQ(stat((custom_path + "/libcustom_pass.so").c_str(), &buffer), 0); - - // Remove the dummy library to simulate missing dependency - system(("rm " + custom_path + "/libdummy.so").c_str()); - - // Call the function under test - CustomPassHelper cust_helper; - ge::Status status = cust_helper.Load(); - EXPECT_EQ(status, ge::FAILED); - - system(("rm -rf " + path).c_str()); -} - -TEST_F(UtestRegisterPass, LoadCustomPassLibsTest_Success) { - std::string path = __FILE__; - path = path.substr(0, path.rfind("/") + 1) + "opp"; - mmSetEnv(kEnvName, path.c_str(), 1); - system(("mkdir -p " + path).c_str()); - - std::string custom_path = path + "/vendors/1/custom_fusion_passes/add"; - system(("mkdir -p " + custom_path).c_str()); - - CreateSharedLibrary(custom_path); - - // Call the function under test - CustomPassHelper cust_helper; - ge::Status status = cust_helper.Load(); - EXPECT_EQ(status, ge::SUCCESS); - status = cust_helper.Unload(); - EXPECT_EQ(status, ge::SUCCESS); - - system(("rm -rf " + path).c_str()); -} - -TEST_F(UtestRegisterPass, CustomPassStage_Success) { - PassRegistrationData pass_reg_data("custom_pass"); - pass_reg_data.Stage(CustomPassStage::kAfterInferShape); - EXPECT_EQ(pass_reg_data.GetStage(), CustomPassStage::kAfterInferShape); -} - -TEST_F(UtestRegisterPass, CustomPassStage_AndRun_Success) { - PassRegistrationData pass_reg_data("custom_pass"); - pass_reg_data.CustomPassFn(MyCustomPass).Stage(CustomPassStage::kAfterInferShape); - CustomPassHelper::Instance().Unload(); - CustomPassHelper::Instance().Insert(pass_reg_data); - auto graph = std::make_shared("test2"); - CustomPassContext custom_pass_context; - EXPECT_EQ(CustomPassHelper::Instance().Run(graph, custom_pass_context), SUCCESS); - EXPECT_EQ(pass_reg_data.GetStage(), CustomPassStage::kAfterInferShape); -} - -TEST_F(UtestRegisterPass, CustomPassStage_Failed) { - PassRegistrationData pass_reg_data; - pass_reg_data.Stage(CustomPassStage::kAfterInferShape); - EXPECT_EQ(pass_reg_data.GetStage(), CustomPassStage::kInvalid); -} - -TEST_F(UtestRegisterPass, ConstGraphCustomPass_AndRun_SUCCESS) { - PassRegistrationData pass_reg_data("custom_pass"); - pass_reg_data.CustomPassFn(FooConstGraphCustomPass).Stage(CustomPassStage::kAfterAssignLogicStream); - CustomPassHelper::Instance().Unload(); - CustomPassHelper::Instance().Insert(pass_reg_data); - auto graph = std::make_shared("test2"); - CustomPassContext custom_pass_context; - EXPECT_NE(CustomPassHelper::Instance().Run(graph, custom_pass_context, CustomPassStage::kAfterAssignLogicStream), SUCCESS); - EXPECT_EQ(pass_reg_data.GetStage(), CustomPassStage::kAfterAssignLogicStream); -} - -TEST_F(UtestRegisterPass, ConstGraphCustomPass_AndRun_Failed_RegisterWrongFunc) { - PassRegistrationData pass_reg_data("custom_pass"); - // wrong func in kAfterAssignLogicStream stage - pass_reg_data.CustomPassFn(MyCustomPass).Stage(CustomPassStage::kAfterAssignLogicStream); - CustomPassHelper::Instance().Unload(); - CustomPassHelper::Instance().Insert(pass_reg_data); - auto graph = std::make_shared("test2"); - CustomPassContext custom_pass_context; - EXPECT_NE(CustomPassHelper::Instance().Run(graph, custom_pass_context, CustomPassStage::kAfterAssignLogicStream), SUCCESS); - EXPECT_EQ(pass_reg_data.GetStage(), CustomPassStage::kAfterAssignLogicStream); -} - -TEST_F(UtestRegisterPass, ConstGraphCustomPass_AndRun_Failed_FuncReturnError) { - PassRegistrationData pass_reg_data("custom_pass"); - pass_reg_data.CustomPassFn(FooConstGraphCustomPass).Stage(CustomPassStage::kAfterAssignLogicStream); - CustomPassHelper::Instance().Unload(); - CustomPassHelper::Instance().Insert(pass_reg_data); - auto graph = std::make_shared("error_graph"); - CustomPassContext custom_pass_context; - EXPECT_NE(CustomPassHelper::Instance().Run(graph, custom_pass_context, CustomPassStage::kAfterAssignLogicStream), SUCCESS); - EXPECT_EQ(pass_reg_data.GetStage(), CustomPassStage::kAfterAssignLogicStream); -} - -TEST_F(UtestRegisterPass, ConstGraph_AfterBuiltinFusionCustomPass_AndRun_SUCCESS) { - PassRegistrationData pass_reg_data("custom_pass"); - pass_reg_data.CustomPassFn(FooConstGraphCustomPass).Stage(CustomPassStage::kAfterBuiltinFusionPass); - CustomPassHelper::Instance().Unload(); - CustomPassHelper::Instance().Insert(pass_reg_data); - auto graph = std::make_shared("test2"); - CustomPassContext custom_pass_context; - EXPECT_EQ(CustomPassHelper::Instance().Run(graph, custom_pass_context, CustomPassStage::kAfterBuiltinFusionPass), SUCCESS); - EXPECT_EQ(pass_reg_data.GetStage(), CustomPassStage::kAfterBuiltinFusionPass); -} - -TEST_F(UtestRegisterPass, ConstGraph_AfterBuiltinFusionCustomPass_AndRun_Failed_FuncReturnError) { - PassRegistrationData pass_reg_data("custom_pass"); - pass_reg_data.CustomPassFn(FooConstGraphCustomPass).Stage(CustomPassStage::kAfterBuiltinFusionPass); - CustomPassHelper::Instance().Unload(); - CustomPassHelper::Instance().Insert(pass_reg_data); - auto graph = std::make_shared("error_graph"); - CustomPassContext custom_pass_context; - EXPECT_NE(CustomPassHelper::Instance().Run(graph, custom_pass_context, CustomPassStage::kAfterBuiltinFusionPass), SUCCESS); - EXPECT_EQ(pass_reg_data.GetStage(), CustomPassStage::kAfterBuiltinFusionPass); -} -} // namespace ge diff --git a/tests/ut/register/testcase/custom_pass/register_custom_stream_pass_unittest.cc b/tests/ut/register/testcase/custom_pass/register_custom_stream_pass_unittest.cc deleted file mode 100644 index 88c80f2d6ef208d5d13d6b2fe1e287229e03aa00..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/custom_pass/register_custom_stream_pass_unittest.cc +++ /dev/null @@ -1,258 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include - -#include -#include - -#include "inc/external/register/register_custom_pass.h" -#include "register/custom_pass_context_impl.h" -#include "graph/debug/ge_log.h" -#include "register/custom_pass_helper.h" -#include "tests/depends/mmpa/src/mmpa_stub.h" -#include "tests/ut/graph/common/share_graph.h" - -namespace ge { -class UtestRegisterStreamPass : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} - - static Status AssignStreamIdByTopoIdPass(const ConstGraphPtr &graph, StreamPassContext &context) { - if (graph->GetName() == "error_graph") { - context.SetErrorMessage("graph name is invalid"); - return FAILED; - } - - for (const auto &node : graph->GetAllNodes()) { - GE_ASSERT_SUCCESS(context.SetStreamId(node, context.AllocateNextStreamId())); - } - return SUCCESS; - } - - static Status AssignNewStreamIdPass(const ConstGraphPtr &graph, StreamPassContext &context) { - for (const auto &node : graph->GetAllNodes()) { - GE_ASSERT_SUCCESS(context.SetStreamId(node, context.AllocateNextStreamId())); - } - return SUCCESS; - } - static Status AssignStreamIdOutOfRangePass(const ConstGraphPtr &graph, StreamPassContext &context) { - for (const auto &node : graph->GetAllNodes()) { - if (context.SetStreamId(node, context.GetCurrMaxStreamId() + 1) != SUCCESS) { - AscendString name; - node.GetName(name); - auto error_msg = AscendString("Failed to set stream id for node"); - context.SetErrorMessage(error_msg); - return FAILED; - } - } - return SUCCESS; - } - - static Status AssignStreamIdOutOfRange2Pass(const ConstGraphPtr &graph, StreamPassContext &context) { - for (const auto &node : graph->GetAllNodes()) { - if (context.SetStreamId(node, -1) != SUCCESS) { - AscendString name; - node.GetName(name); - auto error_msg = AscendString("Failed to set stream id for node"); - context.SetErrorMessage(error_msg); - return FAILED; - } - } - return SUCCESS; - } - - static Status FooNonConstGraphCustomPass(GraphPtr &graph, CustomPassContext &context) { - return SUCCESS; - } -}; - -TEST_F(UtestRegisterStreamPass, AsssignStreamIdByTopoIdPass_SUCCESS) { - int64_t default_stream_id = 0u; - PassRegistrationData pass_reg_data("custom_pass"); - pass_reg_data.CustomAllocateStreamPassFn(AssignStreamIdByTopoIdPass).Stage(CustomPassStage::kAfterAssignLogicStream); - CustomPassHelper::Instance().Unload(); - CustomPassHelper::Instance().Insert(pass_reg_data); - - // prepare graph - auto compute_graph = SharedGraph::BuildGraphWithControlEdge(); - compute_graph->TopologicalSorting(); - // init stream id to 0 - for (const auto &node : compute_graph->GetAllNodes()) { - node->GetOpDesc()->SetStreamId(default_stream_id); - } - auto graph = GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph); - auto const_graph = std::make_shared(graph); - auto graph_ptr = std::make_shared(graph); - // init stream pass context - StreamPassContext custom_pass_context(-1); - - EXPECT_EQ(CustomPassHelper::Instance().Run(graph_ptr, custom_pass_context, CustomPassStage::kAfterAssignLogicStream), SUCCESS); - EXPECT_EQ(pass_reg_data.GetStage(), CustomPassStage::kAfterAssignLogicStream); - - // check stream id is equal with topo id - for (const auto &node : const_graph->GetAllNodes()) { - int64_t stream_id = custom_pass_context.GetStreamId(node); - auto topo_id = NodeAdapter::GNode2Node(node)->GetOpDescBarePtr()->GetId(); - EXPECT_EQ(stream_id, topo_id); - } -} - -TEST_F(UtestRegisterStreamPass, AssignNewStreamIdPass_SUCCESS) { - int64_t default_stream_id = 0u; - PassRegistrationData pass_reg_data("custom_pass"); - pass_reg_data.CustomAllocateStreamPassFn(AssignNewStreamIdPass).Stage(CustomPassStage::kAfterAssignLogicStream); - CustomPassHelper::Instance().Unload(); - CustomPassHelper::Instance().Insert(pass_reg_data); - - // prepare graph - auto compute_graph = SharedGraph::BuildGraphWithControlEdge(); - compute_graph->TopologicalSorting(); - // init stream id to 0 - for (const auto &node : compute_graph->GetAllNodes()) { - node->GetOpDesc()->SetStreamId(default_stream_id); - } - auto graph = GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph); - auto const_graph = std::make_shared(graph); - auto graph_ptr = std::make_shared(graph); - // init stream pass context - StreamPassContext custom_pass_context(0); - - EXPECT_EQ(CustomPassHelper::Instance().Run(graph_ptr, custom_pass_context, CustomPassStage::kAfterAssignLogicStream), SUCCESS); - EXPECT_EQ(pass_reg_data.GetStage(), CustomPassStage::kAfterAssignLogicStream); - - size_t expect_stream_id = 1; - for (const auto &node : const_graph->GetAllNodes()) { - EXPECT_EQ(custom_pass_context.GetStreamId(node), expect_stream_id++); - } - EXPECT_EQ(custom_pass_context.GetCurrMaxStreamId(), 5); -} - -TEST_F(UtestRegisterStreamPass, AsssignStreamIdByTopoIdPass_PassRunFailed_Failed) { - int64_t default_stream_id = 0u; - PassRegistrationData pass_reg_data("custom_pass"); - pass_reg_data.CustomAllocateStreamPassFn(AssignStreamIdByTopoIdPass).Stage(CustomPassStage::kAfterAssignLogicStream); - CustomPassHelper::Instance().Unload(); - CustomPassHelper::Instance().Insert(pass_reg_data); - - // prepare graph - auto compute_graph = SharedGraph::BuildGraphWithControlEdge(); - compute_graph->TopologicalSorting(); - compute_graph->SetName("error_graph"); - // init stream id to 0 - for (const auto &node : compute_graph->GetAllNodes()) { - node->GetOpDesc()->SetStreamId(default_stream_id); - } - auto graph = GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph); - auto const_graph = std::make_shared(graph); - auto graph_ptr = std::make_shared(graph); - // init stream pass context - StreamPassContext custom_pass_context(0); - for (const auto &node : const_graph->GetAllNodes()) { - custom_pass_context.SetStreamId(node, 0u); - } - - EXPECT_NE(CustomPassHelper::Instance().Run(graph_ptr, custom_pass_context, CustomPassStage::kAfterAssignLogicStream), SUCCESS); - auto error_msg = custom_pass_context.GetErrorMessage().GetString(); - EXPECT_STREQ(error_msg, "graph name is invalid"); -} - -TEST_F(UtestRegisterStreamPass, AssignStreamIdOutOfRangePass_PassRunFailed_Failed) { - int64_t default_stream_id = 0u; - PassRegistrationData pass_reg_data("custom_pass"); - pass_reg_data.CustomAllocateStreamPassFn(AssignStreamIdOutOfRangePass).Stage(CustomPassStage::kAfterAssignLogicStream); - CustomPassHelper::Instance().Unload(); - CustomPassHelper::Instance().Insert(pass_reg_data); - - // prepare graph - auto compute_graph = SharedGraph::BuildGraphWithControlEdge(); - compute_graph->TopologicalSorting(); - // init stream id to 0 - for (const auto &node : compute_graph->GetAllNodes()) { - node->GetOpDesc()->SetStreamId(default_stream_id); - } - auto graph = GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph); - auto const_graph = std::make_shared(graph); - auto graph_ptr = std::make_shared(graph); - // init stream pass context - StreamPassContext custom_pass_context(0); - for (const auto &node : const_graph->GetAllNodes()) { - custom_pass_context.SetStreamId(node, 0u); - } - - EXPECT_NE(CustomPassHelper::Instance().Run(graph_ptr, custom_pass_context, CustomPassStage::kAfterAssignLogicStream), SUCCESS); - auto error_msg = custom_pass_context.GetErrorMessage().GetString(); - EXPECT_STREQ(error_msg, "Failed to set stream id for node"); -} -TEST_F(UtestRegisterStreamPass, AssignStreamIdOutOfRange2Pass_PassRunFailed_Failed) { - int64_t default_stream_id = 0u; - PassRegistrationData pass_reg_data("custom_pass"); - pass_reg_data.CustomAllocateStreamPassFn(AssignStreamIdOutOfRange2Pass).Stage(CustomPassStage::kAfterAssignLogicStream); - CustomPassHelper::Instance().Unload(); - CustomPassHelper::Instance().Insert(pass_reg_data); - - // prepare graph - auto compute_graph = SharedGraph::BuildGraphWithControlEdge(); - compute_graph->TopologicalSorting(); - // init stream id to 0 - for (const auto &node : compute_graph->GetAllNodes()) { - node->GetOpDesc()->SetStreamId(default_stream_id); - } - auto graph = GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph); - auto const_graph = std::make_shared(graph); - auto graph_ptr = std::make_shared(graph); - // init stream pass context - StreamPassContext custom_pass_context(0); - for (const auto &node : const_graph->GetAllNodes()) { - custom_pass_context.SetStreamId(node, 0u); - } - - EXPECT_NE(CustomPassHelper::Instance().Run(graph_ptr, custom_pass_context, CustomPassStage::kAfterAssignLogicStream), SUCCESS); - auto error_msg = custom_pass_context.GetErrorMessage().GetString(); - EXPECT_STREQ(error_msg, "Failed to set stream id for node"); -} -TEST_F(UtestRegisterStreamPass, RegNormalGraphPass_RegFailed_Failed) { - PassRegistrationData pass_reg_data("custom_pass"); - pass_reg_data.CustomPassFn(FooNonConstGraphCustomPass).Stage(CustomPassStage::kAfterAssignLogicStream); - CustomPassHelper::Instance().Unload(); - CustomPassHelper::Instance().Insert(pass_reg_data); - - // prepare graph - auto compute_graph = SharedGraph::BuildGraphWithControlEdge(); - compute_graph->TopologicalSorting(); - - auto graph = GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph); - auto const_graph = std::make_shared(graph); - auto graph_ptr = std::make_shared(graph); - // init stream pass context - StreamPassContext custom_pass_context(0); - for (const auto &node : const_graph->GetAllNodes()) { - custom_pass_context.SetStreamId(node, 0u); - } - - EXPECT_NE(CustomPassHelper::Instance().Run(graph_ptr, custom_pass_context, CustomPassStage::kAfterAssignLogicStream), SUCCESS); -} -TEST_F(UtestRegisterStreamPass, StreamPassContext_ImplNull_GetStreamId_failed) { - auto compute_graph = SharedGraph::BuildGraphWithControlEdge(); - auto graph = GraphUtilsEx::CreateGraphFromComputeGraph(compute_graph); - auto const_graph = std::make_shared(graph); - auto graph_ptr = std::make_shared(graph); - - // init stream pass context - StreamPassContext custom_pass_context(0); - custom_pass_context.impl_ = nullptr; - - for (const auto &node : const_graph->GetAllNodes()) { - EXPECT_EQ(custom_pass_context.GetStreamId(node), INVALID_STREAM_ID); - } -} -} // namespace ge diff --git a/tests/ut/register/testcase/ffts_node_registry_unittest.cc b/tests/ut/register/testcase/ffts_node_registry_unittest.cc deleted file mode 100644 index a02e20526d8946f07c372f93b757bba0afec989e..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/ffts_node_registry_unittest.cc +++ /dev/null @@ -1,84 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/ffts_node_calculater_registry.h" -#include "register/ffts_node_converter_registry.h" -#include "register/op_ext_gentask_registry.h" -#include "proto/task.pb.h" -#include - -class FFTSNodeRegistryUnittest : public testing::Test {}; - -namespace TestFFTSNodeRegistry { -gert::LowerResult TestFftsLowFunc(const ge::NodePtr &node, const gert::FFTSLowerInput &lower_input) { - return {}; -} - -ge::graphStatus TestFftsCalcFunc(const ge::NodePtr &node, const gert::LoweringGlobalData *global_data, - size_t &total_size, size_t &pre_data_size, std::unique_ptr &pre_data_ptr) { - return ge::GRAPH_SUCCESS; -} - -TEST_F(FFTSNodeRegistryUnittest, ConverterRegisterSuccess_Test) { - EXPECT_EQ(gert::FFTSNodeConverterRegistry::GetInstance().FindNodeConverter("RegisterSuccess1"), nullptr); - FFTS_REGISTER_NODE_CONVERTER("RegisterSuccess1", TestFftsLowFunc); - EXPECT_EQ(gert::FFTSNodeConverterRegistry::GetInstance().FindNodeConverter("RegisterSuccess1"), TestFftsLowFunc); -} - -TEST_F(FFTSNodeRegistryUnittest, CalculaterRegisterSuccess_Test) { - EXPECT_EQ(gert::FFTSNodeCalculaterRegistry::GetInstance().FindNodeCalculater("RegisterSuccess2"), nullptr); - FFTS_REGISTER_NODE_CALCULATER("RegisterSuccess2", TestFftsCalcFunc); - EXPECT_EQ(gert::FFTSNodeCalculaterRegistry::GetInstance().FindNodeCalculater("RegisterSuccess2"), TestFftsCalcFunc); -} - -TEST_F(FFTSNodeRegistryUnittest, SkipCtxRecord_Test) { - gert::SkipCtxRecord skip_record; - uint32_t ctx_id = 0; - uint32_t ctx_type = 1; - EXPECT_EQ(skip_record.SetSkipCtx(ctx_id, ctx_type), false); - skip_record.Init(); - skip_record.SetSkipCtx(1, 2); - skip_record.SetSkipCtx(2, 3); - EXPECT_EQ(skip_record.GetCtxNum(), 2); - skip_record.GetSkipCtx(1, ctx_id, ctx_type); - EXPECT_EQ(ctx_id, 2); - EXPECT_EQ(ctx_type, 3); - skip_record.ClearRecord(); - EXPECT_EQ(skip_record.GetCtxNum(), 0); -} - -ge::Status TestOpExtGenTask(const ge::Node &node, ge::RunContext &context, std::vector &tasks) { - return ge::SUCCESS; -} - -TEST_F(FFTSNodeRegistryUnittest, OpExtGenTask_test1) { - EXPECT_EQ(fe::OpExtGenTaskRegistry::GetInstance().FindRegisterFunc("Conv2D"), nullptr); - REGISTER_NODE_EXT_GENTASK("Conv2D", TestOpExtGenTask); - auto func = fe::OpExtGenTaskRegistry::GetInstance().FindRegisterFunc("Conv2D"); - EXPECT_EQ(func, TestOpExtGenTask); -} - -ge::Status TestSKExtGenTaskFunc(const ge::Node &node, std::vector> &subTasks, - const std::vector &sub_nodes, std::vector &tasks) { - return ge::SUCCESS; -} - -TEST_F(FFTSNodeRegistryUnittest, OpExtGenTask_test2) { - EXPECT_EQ(fe::OpExtGenTaskRegistry::GetInstance().FindSKRegisterFunc("Conv2D"), nullptr); - REGISTER_SK_EXT_GENTASK("Conv2D", TestSKExtGenTaskFunc); - auto func = fe::OpExtGenTaskRegistry::GetInstance().FindSKRegisterFunc("Conv2D"); - EXPECT_EQ(func, TestSKExtGenTaskFunc); -} - -TEST_F(FFTSNodeRegistryUnittest, ExtTaskTypeReg_test) { - REGISTER_EXT_TASK_TYPE(MoeDistributeCombine, fe::ExtTaskType::kAicoreTask); - fe::ExtTaskType taskType = fe::OpExtGenTaskRegistry::GetInstance().GetExtTaskType("MoeDistributeCombine"); - EXPECT_EQ(taskType, fe::ExtTaskType::kAicoreTask); -} -} diff --git a/tests/ut/register/testcase/ffts_plus_task_update_unittest.cc b/tests/ut/register/testcase/ffts_plus_task_update_unittest.cc deleted file mode 100644 index e000b03f939b020a6f95295b5fe253c0bd972354..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/ffts_plus_task_update_unittest.cc +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "gtest/gtest.h" - -#include "register/ffts_plus_update_manager.h" -#include "common/plugin/plugin_manager.h" - -namespace ge { -class FFTSPlusTaskUpdateStub : public FFTSPlusTaskUpdate { - public: - Status GetAutoThreadParam(const NodePtr &node, const std::vector &op_run_info, - AutoThreadParam &auto_thread_param) override { - return SUCCESS; - } - - Status UpdateSubTaskAndCache(const NodePtr &node, const AutoThreadSubTaskFlush &sub_task_flush, - rtFftsPlusTaskInfo_t &ffts_plus_task_info) override { - return SUCCESS; - } - - Status UpdateCommonCtx(const ComputeGraphPtr &sgt_graph, rtFftsPlusTaskInfo_t &task_info) override { - return SUCCESS; - } -}; - -class UtestFftsPlusUpdate : public testing::Test { - protected: - void SetUp() { - const std::string kCoreTypeTest = "FFTS_TEST"; // FftsPlusUpdateManager::FftsPlusUpdateRegistrar - REGISTER_FFTS_PLUS_CTX_UPDATER(kCoreTypeTest, FFTSPlusTaskUpdateStub); - } - - void TearDown() { - FftsPlusUpdateManager::Instance().creators_.clear(); - FftsPlusUpdateManager::Instance().plugin_manager_.reset(); - } -}; - -TEST_F(UtestFftsPlusUpdate, GetUpdater) { - EXPECT_EQ(FftsPlusUpdateManager::Instance().GetUpdater("AIC_AIV"), nullptr); - EXPECT_NE(FftsPlusUpdateManager::Instance().GetUpdater("FFTS_TEST"), nullptr); -} -} diff --git a/tests/ut/register/testcase/fusion_quant_utils_unittest.cc b/tests/ut/register/testcase/fusion_quant_utils_unittest.cc deleted file mode 100644 index acd68cda24c6b52724f3145b9b2eacf0b3a69d93..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/fusion_quant_utils_unittest.cc +++ /dev/null @@ -1,320 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/ge_tensor.h" -#include "graph/utils/attr_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/op_desc.h" -#include "graph/compute_graph.h" -#include "graph_optimizer/fusion_common/graph_pass_util.h" -#include "register/graph_optimizer/graph_fusion/fusion_quant_util.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/tensor_utils.h" -using namespace std; -using namespace ge; - -namespace fe { - -class FusionQuantUtilUT : public testing::Test { -protected: - void SetUp() {} - - void TearDown() {} - - static ge::ComputeGraphPtr CreateTestGraphWithOffset() { - ComputeGraphPtr graph = std::make_shared("test"); - OpDescPtr x = std::make_shared("x", "Data"); - OpDescPtr weight = std::make_shared("weight", "Const"); - OpDescPtr atquant_scale = std::make_shared("atquant_scale", "Const"); - OpDescPtr quant_scale = std::make_shared("quant_scale", "Const"); - OpDescPtr quant_offset = std::make_shared("quant_offset", "Const"); - OpDescPtr mm = std::make_shared("mm", "WeightQuantBatchMatmulV2"); - OpDescPtr y = std::make_shared("y", "NetOutput"); - - // add descriptor - ge::GeShape shape1({2,4,9,16}); - GeTensorDesc tensor_desc1(shape1, ge::FORMAT_NCHW, ge::DT_FLOAT16); - tensor_desc1.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc1.SetOriginDataType(ge::DT_FLOAT16); - tensor_desc1.SetOriginShape(shape1); - - GeTensorDesc tensor_desc2(shape1, ge::FORMAT_NCHW, ge::DT_INT8); - tensor_desc2.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc2.SetOriginDataType(ge::DT_INT8); - tensor_desc2.SetOriginShape(shape1); - - ge::GeShape shape2({1, 16}); - GeTensorDesc tensor_desc3(shape2, ge::FORMAT_ND, ge::DT_FLOAT); - tensor_desc3.SetOriginFormat(ge::FORMAT_ND); - tensor_desc3.SetOriginDataType(ge::DT_FLOAT); - tensor_desc3.SetOriginShape(shape2); - - x->AddOutputDesc(tensor_desc1); - weight->AddOutputDesc(tensor_desc2); - atquant_scale->AddOutputDesc(tensor_desc1); - quant_scale->AddOutputDesc(tensor_desc3); - quant_offset->AddOutputDesc(tensor_desc3); - - mm->AddInputDesc(tensor_desc1); - mm->AddInputDesc(tensor_desc2); - mm->AddInputDesc(tensor_desc1); - mm->AddInputDesc(tensor_desc3); - mm->AddInputDesc(tensor_desc3); - mm->AddOutputDesc(tensor_desc2); - y->AddInputDesc(tensor_desc2); - - // create nodes - NodePtr x_node = graph->AddNode(x); - NodePtr weight_node = graph->AddNode(weight); - NodePtr atquant_scale_node = graph->AddNode(atquant_scale); - NodePtr quant_scale_node = graph->AddNode(quant_scale); - NodePtr quant_offset_node = graph->AddNode(quant_offset); - NodePtr mm_node = graph->AddNode(mm); - NodePtr y_node = graph->AddNode(y); - - // link edge - ge::GraphUtils::AddEdge(x_node->GetOutDataAnchor(0), - mm_node->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(weight_node->GetOutDataAnchor(0), - mm_node->GetInDataAnchor(1)); - ge::GraphUtils::AddEdge(atquant_scale_node->GetOutDataAnchor(0), - mm_node->GetInDataAnchor(2)); - ge::GraphUtils::AddEdge(quant_scale_node->GetOutDataAnchor(0), - mm_node->GetInDataAnchor(3)); - ge::GraphUtils::AddEdge(quant_offset_node->GetOutDataAnchor(0), - mm_node->GetInDataAnchor(4)); - ge::GraphUtils::AddEdge(mm_node->GetOutDataAnchor(0), - y_node->GetInDataAnchor(0)); - return graph; - } - - static ge::ComputeGraphPtr CreateTestGraphWithOffset2() { - ComputeGraphPtr graph = std::make_shared("test"); - OpDescPtr x = std::make_shared("x", "Data"); - OpDescPtr weight = std::make_shared("weight", "Const"); - OpDescPtr atquant_scale = std::make_shared("atquant_scale", "Const"); - OpDescPtr quant_scale = std::make_shared("quant_scale", "Const"); - OpDescPtr quant_offset = std::make_shared("quant_offset", "Const"); - OpDescPtr mm = std::make_shared("mm", "WeightQuantBatchMatmulV2"); - OpDescPtr y = std::make_shared("y", "NetOutput"); - - // add descriptor - ge::GeShape shape1({2,4,9,16}); - GeTensorDesc tensor_desc1(shape1, ge::FORMAT_NCHW, ge::DT_FLOAT16); - tensor_desc1.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc1.SetOriginDataType(ge::DT_FLOAT16); - tensor_desc1.SetOriginShape(shape1); - - GeTensorDesc tensor_desc2(shape1, ge::FORMAT_NCHW, ge::DT_INT8); - tensor_desc2.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc2.SetOriginDataType(ge::DT_INT8); - tensor_desc2.SetOriginShape(shape1); - - ge::GeShape shape2({16}); - GeTensorDesc tensor_desc3(shape2, ge::FORMAT_ND, ge::DT_FLOAT); - tensor_desc3.SetOriginFormat(ge::FORMAT_ND); - tensor_desc3.SetOriginDataType(ge::DT_FLOAT); - tensor_desc3.SetOriginShape(shape2); - - x->AddOutputDesc(tensor_desc1); - weight->AddOutputDesc(tensor_desc2); - atquant_scale->AddOutputDesc(tensor_desc1); - quant_scale->AddOutputDesc(tensor_desc3); - quant_offset->AddOutputDesc(tensor_desc3); - - mm->AddInputDesc(tensor_desc1); - mm->AddInputDesc(tensor_desc2); - mm->AddInputDesc(tensor_desc1); - mm->AddInputDesc(tensor_desc3); - mm->AddInputDesc(tensor_desc3); - mm->AddOutputDesc(tensor_desc2); - y->AddInputDesc(tensor_desc2); - - // create nodes - NodePtr x_node = graph->AddNode(x); - NodePtr weight_node = graph->AddNode(weight); - NodePtr atquant_scale_node = graph->AddNode(atquant_scale); - NodePtr quant_scale_node = graph->AddNode(quant_scale); - NodePtr quant_offset_node = graph->AddNode(quant_offset); - NodePtr mm_node = graph->AddNode(mm); - NodePtr y_node = graph->AddNode(y); - - // link edge - ge::GraphUtils::AddEdge(x_node->GetOutDataAnchor(0), - mm_node->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(weight_node->GetOutDataAnchor(0), - mm_node->GetInDataAnchor(1)); - ge::GraphUtils::AddEdge(atquant_scale_node->GetOutDataAnchor(0), - mm_node->GetInDataAnchor(2)); - ge::GraphUtils::AddEdge(quant_scale_node->GetOutDataAnchor(0), - mm_node->GetInDataAnchor(3)); - ge::GraphUtils::AddEdge(quant_offset_node->GetOutDataAnchor(0), - mm_node->GetInDataAnchor(4)); - ge::GraphUtils::AddEdge(mm_node->GetOutDataAnchor(0), - y_node->GetInDataAnchor(0)); - return graph; - } - - static ge::ComputeGraphPtr CreateTestGraphWithoutOffset() { - ComputeGraphPtr graph = std::make_shared("test"); - OpDescPtr x = std::make_shared("x", "Data"); - OpDescPtr weight = std::make_shared("weight", "Const"); - OpDescPtr atquant_scale = std::make_shared("atquant_scale", "Const"); - OpDescPtr quant_scale = std::make_shared("quant_scale", "Const"); - OpDescPtr mm = std::make_shared("mm", "WeightQuantBatchMatmulV2"); - OpDescPtr y = std::make_shared("y", "NetOutput"); - - // add descriptor - ge::GeShape shape1({2,4,9,16}); - GeTensorDesc tensor_desc1(shape1, ge::FORMAT_NCHW, ge::DT_FLOAT16); - tensor_desc1.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc1.SetOriginDataType(ge::DT_FLOAT16); - tensor_desc1.SetOriginShape(shape1); - - GeTensorDesc tensor_desc2(shape1, ge::FORMAT_NCHW, ge::DT_INT8); - tensor_desc2.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc2.SetOriginDataType(ge::DT_INT8); - tensor_desc2.SetOriginShape(shape1); - - ge::GeShape shape2({1, 16}); - GeTensorDesc tensor_desc3(shape2, ge::FORMAT_ND, ge::DT_FLOAT); - tensor_desc3.SetOriginFormat(ge::FORMAT_ND); - tensor_desc3.SetOriginDataType(ge::DT_FLOAT); - tensor_desc3.SetOriginShape(shape2); - - x->AddOutputDesc(tensor_desc1); - weight->AddOutputDesc(tensor_desc2); - atquant_scale->AddOutputDesc(tensor_desc1); - quant_scale->AddOutputDesc(tensor_desc3); - - mm->AddInputDesc(tensor_desc1); - mm->AddInputDesc(tensor_desc2); - mm->AddInputDesc(tensor_desc1); - mm->AddInputDesc(tensor_desc3); - mm->AddOutputDesc(tensor_desc2); - y->AddInputDesc(tensor_desc2); - - // create nodes - NodePtr x_node = graph->AddNode(x); - NodePtr weight_node = graph->AddNode(weight); - NodePtr atquant_scale_node = graph->AddNode(atquant_scale); - NodePtr quant_scale_node = graph->AddNode(quant_scale); - NodePtr mm_node = graph->AddNode(mm); - NodePtr y_node = graph->AddNode(y); - - // link edge - ge::GraphUtils::AddEdge(x_node->GetOutDataAnchor(0), - mm_node->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(weight_node->GetOutDataAnchor(0), - mm_node->GetInDataAnchor(1)); - ge::GraphUtils::AddEdge(atquant_scale_node->GetOutDataAnchor(0), - mm_node->GetInDataAnchor(2)); - ge::GraphUtils::AddEdge(quant_scale_node->GetOutDataAnchor(0), - mm_node->GetInDataAnchor(3)); - ge::GraphUtils::AddEdge(mm_node->GetOutDataAnchor(0), - y_node->GetInDataAnchor(0)); - return graph; - } - - static void FillWeightValue2(const ge::ComputeGraphPtr &graph) { - for (const ge::NodePtr &node : graph->GetDirectNode()) { - if (node == nullptr) { - continue; - } - if (node->GetType() != "Const") { - continue; - } - std::vector weights = ge::OpDescUtils::GetWeights(node); - if (weights.empty()) { - ge::ConstGeTensorDescPtr out_tensor = node->GetOpDesc()->GetOutputDescPtr(0); - int64_t shape_size = out_tensor->GetShape().GetShapeSize(); - if (shape_size <= 0) { - continue; - } - ge::GeTensorPtr weight = std::make_shared(*out_tensor); - if (node->GetName() == "quant_scale") { - vector data_vec = - {-2.7065194, -4.7495637, 2.5856478, 2.533566 , -2.7307642, 0.08650689, 1.2195834, -4.520703, - -4.902806, -4.9793777 , -3.8038466 , 4.6814585, -0.8230759, 1.4473673, 4.71265, 2.3249402}; - weight->SetData(reinterpret_cast(data_vec.data()), shape_size * sizeof(float)); - ge::OpDescUtils::SetWeights(node->GetOpDesc(), weight); - continue; - } - if (node->GetName() == "quant_offset") { - std::cout << "mmm quant_offset" << std::endl; - vector data_vec = - {1.7815902, -0.83771265, 3.8743427, -1.129952, 3.348905, 4.898297, 2.8627427, -4.685532, - -1.0928544, 0.0128879, 3.988301, -4.4012594, -0.15809901, 1.5274582, 3.3731332, -0.75769955}; - weight->SetData(reinterpret_cast(data_vec.data()), shape_size * sizeof(float)); - ge::OpDescUtils::SetWeights(node->GetOpDesc(), weight); - continue; - } - if (out_tensor->GetDataType() == ge::DT_UINT32 || out_tensor->GetDataType() == ge::DT_INT32 || - out_tensor->GetDataType() == ge::DT_FLOAT) { - vector data_vec(shape_size, 1); - weight->SetData(reinterpret_cast(data_vec.data()), shape_size * sizeof(int32_t)); - } - if (out_tensor->GetDataType() == ge::DT_UINT64 || out_tensor->GetDataType() == ge::DT_INT64 || - out_tensor->GetDataType() == ge::DT_DOUBLE) { - vector data_vec(shape_size, 1); - weight->SetData(reinterpret_cast(data_vec.data()), shape_size * sizeof(int64_t)); - } - if (out_tensor->GetDataType() == ge::DT_UINT16 || out_tensor->GetDataType() == ge::DT_INT16 || - out_tensor->GetDataType() == ge::DT_FLOAT16) { - vector data_vec(shape_size, 1); - weight->SetData(reinterpret_cast(data_vec.data()), shape_size * sizeof(int16_t)); - } - if (out_tensor->GetDataType() == ge::DT_UINT8 || out_tensor->GetDataType() == ge::DT_INT8) { - vector data_vec(shape_size, 1); - weight->SetData(reinterpret_cast(data_vec.data()), shape_size * sizeof(int8_t)); - } - ge::OpDescUtils::SetWeights(node->GetOpDesc(), weight); - } - } - } -}; - -TEST_F(FusionQuantUtilUT, insert_quant_op_succ) { - ComputeGraphPtr graph = CreateTestGraphWithOffset(); - FillWeightValue2(graph); - ge::NodePtr mm_node = graph->FindNode("mm"); - InDataAnchorPtr cuba_bias = mm_node->GetInDataAnchor(1); - InDataAnchorPtr quant_scale = mm_node->GetInDataAnchor(3); - InDataAnchorPtr quant_offset = mm_node->GetInDataAnchor(4); - std::vector fusion_nodes; - Status ret = QuantUtil::InsertQuantScaleConvert(quant_scale, quant_offset, fusion_nodes); - EXPECT_EQ(ret, SUCCESS); -} - -TEST_F(FusionQuantUtilUT, insert_quant_op_succ2) { - ComputeGraphPtr graph = CreateTestGraphWithoutOffset(); - FillWeightValue2(graph); - ge::NodePtr mm_node = graph->FindNode("mm"); - InDataAnchorPtr cuba_bias = mm_node->GetInDataAnchor(1); - InDataAnchorPtr quant_scale = mm_node->GetInDataAnchor(3); - InDataAnchorPtr quant_offset = mm_node->GetInDataAnchor(4); - std::vector fusion_nodes; - Status ret = QuantUtil::InsertQuantScaleConvert(quant_scale, quant_offset, fusion_nodes); - EXPECT_EQ(ret, SUCCESS); -} - -TEST_F(FusionQuantUtilUT, insert_requant_op_succ) { - ComputeGraphPtr graph = CreateTestGraphWithOffset2(); - FillWeightValue2(graph); - ge::NodePtr mm_node = graph->FindNode("mm"); - InDataAnchorPtr cuba_bias = mm_node->GetInDataAnchor(1); - InDataAnchorPtr quant_scale = mm_node->GetInDataAnchor(3); - InDataAnchorPtr quant_offset = mm_node->GetInDataAnchor(4); - std::vector fusion_nodes; - Status ret = QuantUtil::InsertRequantScaleConvert(quant_scale, quant_offset, cuba_bias, fusion_nodes); - EXPECT_EQ(ret, SUCCESS); -} -} diff --git a/tests/ut/register/testcase/fusion_statistics.cc b/tests/ut/register/testcase/fusion_statistics.cc deleted file mode 100644 index 90db447bce812efa46a0e102e4026af49cbdd672..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/fusion_statistics.cc +++ /dev/null @@ -1,96 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include "graph/graph.h" -#include "graph/compute_graph.h" -#include "graph/utils/graph_utils.h" -#include "inc/graph/operator_factory_impl.h" -#include "graph/utils/op_desc_utils.h" -#include "graph_builder_utils.h" -#include "register/graph_optimizer/fusion_common/pattern_fusion_base_pass.h" -#include "register/graph_optimizer/graph_fusion/fusion_pattern.h" -#include "register/graph_optimizer/fusion_common/fusion_statistic_recorder.h" - -using namespace ge; -class UtestFusionStatistics : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - - - -TEST_F(UtestFusionStatistics, test_01) { - auto &fs_instance = fe::FusionStatisticRecorder::Instance(); - fe::FusionInfo fusion_info(0, "", "test_pass"); - - fs_instance.UpdateGraphFusionMatchTimes(fusion_info); - - fs_instance.UpdateGraphFusionEffectTimes(fusion_info); - - fs_instance.UpdateBufferFusionMatchTimes(fusion_info); - - string session_graph_id = "0_1"; - std::map graph_fusion_info_map; - std::map buffer_fusion_info_map; - fs_instance.GetAndClearFusionInfo(session_graph_id, graph_fusion_info_map, - buffer_fusion_info_map); - - fs_instance.GetFusionInfo(session_graph_id, graph_fusion_info_map, - buffer_fusion_info_map); - - fs_instance.ClearFusionInfo(session_graph_id); - - std::vector session_graph_id_vec = {session_graph_id}; - EXPECT_NO_THROW(fs_instance.GetAllSessionAndGraphIdList(session_graph_id_vec)); -} - -TEST_F(UtestFusionStatistics, test_02) { - auto &fs_instance = fe::FusionStatisticRecorder::Instance(); - fe::FusionInfo fusion_info(0, "", "test_pass"); - fusion_info.SetEffectTimes(2); - fusion_info.SetMatchTimes(2); - fusion_info.AddEffectTimes(1); - fusion_info.AddMatchTimes(1); - fusion_info.GetEffectTimes(); - fusion_info.GetMatchTimes(); - fusion_info.GetGraphId(); - fusion_info.GetPassName(); - fusion_info.GetSessionId(); - fusion_info.SetRepoHitTimes(5); - fusion_info.GetRepoHitTimes(); - - fs_instance.UpdateGraphFusionMatchTimes(fusion_info); - - fs_instance.UpdateGraphFusionEffectTimes(fusion_info); - - fs_instance.UpdateBufferFusionMatchTimes(fusion_info); - - fs_instance.UpdateBufferFusionEffectTimes(fusion_info); - string session_graph_id = "0_1"; - std::map graph_fusion_info_map; - std::map buffer_fusion_info_map; - fs_instance.GetAndClearFusionInfo(session_graph_id, graph_fusion_info_map, - buffer_fusion_info_map); - - fs_instance.GetFusionInfo(session_graph_id, graph_fusion_info_map, - buffer_fusion_info_map); - - fs_instance.ClearFusionInfo(session_graph_id); - - std::vector session_graph_id_vec = {session_graph_id}; - EXPECT_NO_THROW(fs_instance.GetAllSessionAndGraphIdList(session_graph_id_vec)); -} - - - - diff --git a/tests/ut/register/testcase/graph_fusion_clycle_detection_unittest.cc b/tests/ut/register/testcase/graph_fusion_clycle_detection_unittest.cc deleted file mode 100644 index de211b8a167e57f2003bd216044a5599e90dfb97..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/graph_fusion_clycle_detection_unittest.cc +++ /dev/null @@ -1,454 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include "graph/graph.h" -#include "graph/compute_graph.h" -#include "graph/utils/graph_utils.h" -#include "inc/graph/operator_factory_impl.h" -#include "graph/utils/op_desc_utils.h" -#include "graph_builder_utils.h" -#include "register/graph_optimizer/fusion_common/pattern_fusion_base_pass.h" -#include "register/graph_optimizer/graph_fusion/fusion_pattern.h" -#include "graph/utils/connection_matrix.h" -#include "graph/utils/connection_matrix_impl.h" - -using namespace ge; -class UtestCycleDetection : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -class FusionTestPass : public fe::PatternFusionBasePass { - public: - FusionTestPass() {}; - ~FusionTestPass() override {}; - vector DefinePatterns() override {return {nullptr};}; - fe::Status Fusion(ComputeGraph &graph, Mapping &mapping, - vector &new_nodes) override { - - vector> fusion_nodes; - bool ret = CycleDetection(graph, fusion_nodes); - EXPECT_EQ(ret, false); - - vector scope_nodes; - for (auto &node : graph.GetDirectNode()) { - if (std::find(new_nodes.begin(), new_nodes.end(), node) != new_nodes.end()) { - scope_nodes.emplace_back(node); - } - } - fusion_nodes.emplace_back(scope_nodes); - - ret = CycleDetection(graph, fusion_nodes); - if (ret) { - return fe::NOT_CHANGED; - } else { - return fe::SUCCESS; - } - } -}; - - -/* A - * / \ - * B \ - * / \ - * D------->C - * | | - * After fusion A/B/C, the graph looks like: - * <--- - * / \ - * ABC--->D */ -static ComputeGraphPtr BuildFusionGraph01(std::vector &fusion_nodes) { - ut::GraphBuilder builder = ut::GraphBuilder("fusion_graph"); - auto a = builder.AddNode("A", "A", 0, 1); - auto b = builder.AddNode("B", "B", 1, 1); - auto c = builder.AddNode("C", "C", 2, 1); - auto d = builder.AddNode("D", "D", 1, 1); - auto netoutput = builder.AddNode("NetOutput", "NetOutput", 2, 0); - - builder.AddDataEdge(a, 0, b, 0); - builder.AddDataEdge(b, 0, d, 0); - builder.AddDataEdge(d, 0, c, 1); - - builder.AddDataEdge(a, 0, c, 0); - builder.AddDataEdge(c, 0, netoutput, 0); - builder.AddDataEdge(d, 0, netoutput, 1); - auto graph = builder.GetGraph(); - fusion_nodes = {a, b, c}; - return graph; -} -using Mapping = std::map, std::vector, fe::CmpKey>; -TEST_F(UtestCycleDetection, cycle_detection_01) { - std::vector fusion_nodes; - auto graph = BuildFusionGraph01(fusion_nodes); - FusionTestPass pass; - Mapping mapping; - - fe::Status ret = pass.Fusion(*graph, mapping, fusion_nodes); - EXPECT_EQ(ret, fe::NOT_CHANGED); -} - - -/* A - * / \ - * B \ - * / \ - * D C - * \ / - * Netoutput - * After fusion A/B/C, the graph looks like: - * - * ABC--->D - * \ / - * Netoutput - * No cycle will be generated if fusing. */ -static ComputeGraphPtr BuildFusionGraph02(std::vector &fusion_nodes) { - ut::GraphBuilder builder = ut::GraphBuilder("fusion_graph"); - auto a = builder.AddNode("A", "A", 0, 1); - auto b = builder.AddNode("B", "B", 1, 1); - auto c = builder.AddNode("C", "C", 1, 1); - auto d = builder.AddNode("D", "D", 1, 1); - auto netoutput = builder.AddNode("NetOutput", "NetOutput", 2, 0); - - builder.AddDataEdge(a, 0, b, 0); - builder.AddDataEdge(b, 0, d, 0); - - builder.AddDataEdge(a, 0, c, 0); - builder.AddDataEdge(c, 0, netoutput, 0); - builder.AddDataEdge(d, 0, netoutput, 1); - auto graph = builder.GetGraph(); - fusion_nodes = {a, b, c}; - return graph; -} - -TEST_F(UtestCycleDetection, cycle_detection_02) { - std::vector fusion_nodes; - auto graph = BuildFusionGraph02(fusion_nodes); - FusionTestPass pass; - Mapping mapping; - - fe::Status ret = pass.Fusion(*graph, mapping, fusion_nodes); - EXPECT_EQ(ret, fe::SUCCESS); -} - -/* A--->B---->C---->D - * \-----E-------/ - * - * A, B, C, D will be fused. - * Cycle will be generated if fusing. - */ -static ComputeGraphPtr BuildFusionGraph03(std::vector &fusion_nodes) { - ut::GraphBuilder builder = ut::GraphBuilder("fusion_graph"); - auto a = builder.AddNode("A", "A", 0, 1); - auto b = builder.AddNode("B", "B", 1, 1); - auto c = builder.AddNode("C", "C", 1, 1); - auto d = builder.AddNode("D", "D", 2, 1); - auto e = builder.AddNode("E", "E", 1, 1); - auto netoutput = builder.AddNode("NetOutput", "NetOutput", 1, 0); - - builder.AddDataEdge(a, 0, b, 0); - builder.AddDataEdge(b, 0, c, 0); - - builder.AddDataEdge(c, 0, d, 0); - builder.AddDataEdge(a, 0, e, 0); - builder.AddDataEdge(e, 0, d, 1); - builder.AddDataEdge(d, 0, netoutput, 0); - - auto graph = builder.GetGraph(); - fusion_nodes = {a, b, c, d}; - return graph; -} - -TEST_F(UtestCycleDetection, cycle_detection_03) { - std::vector fusion_nodes; - auto graph = BuildFusionGraph03(fusion_nodes); - FusionTestPass pass; - Mapping mapping; - - fe::Status ret = pass.Fusion(*graph, mapping, fusion_nodes); - EXPECT_EQ(ret, fe::NOT_CHANGED); -} - -/* A--->B---->C------->D - * \-----E---F------/ - * - * A, B, C, D will be fused. - * Cycle will be generated if fusing. - */ -static ComputeGraphPtr BuildFusionGraph04(std::vector &fusion_nodes) { - ut::GraphBuilder builder = ut::GraphBuilder("fusion_graph"); - auto a = builder.AddNode("A", "A", 0, 1); - auto b = builder.AddNode("B", "B", 1, 1); - auto c = builder.AddNode("C", "C", 1, 1); - auto d = builder.AddNode("D", "D", 2, 1); - auto e = builder.AddNode("E", "E", 1, 1); - auto f = builder.AddNode("F", "F", 1, 1); - auto netoutput = builder.AddNode("NetOutput", "NetOutput", 1, 0); - - builder.AddDataEdge(a, 0, b, 0); - builder.AddDataEdge(b, 0, c, 0); - builder.AddDataEdge(c, 0, d, 0); - builder.AddDataEdge(a, 0, e, 0); - builder.AddDataEdge(e, 0, f, 0); - builder.AddDataEdge(f, 0, d, 1); - - builder.AddDataEdge(d, 0, netoutput, 0); - auto graph = builder.GetGraph(); - fusion_nodes = {a, b, c, d}; - return graph; -} - -TEST_F(UtestCycleDetection, cycle_detection_04) { - std::vector fusion_nodes; - auto graph = BuildFusionGraph04(fusion_nodes); - FusionTestPass pass; - Mapping mapping; - - fe::Status ret = pass.Fusion(*graph, mapping, fusion_nodes); - EXPECT_EQ(ret, fe::NOT_CHANGED); -} - -/* A--->B---->C------->D - * \-----E---F------/ - * - * B/C will be fused. - * No Cycle will be generated if fusing. - */ -static ComputeGraphPtr BuildFusionGraph05(std::vector &fusion_nodes) { - ut::GraphBuilder builder = ut::GraphBuilder("fusion_graph"); - auto a = builder.AddNode("A", "A", 0, 1); - auto b = builder.AddNode("B", "B", 1, 1); - auto c = builder.AddNode("C", "C", 1, 1); - auto d = builder.AddNode("D", "D", 2, 1); - auto e = builder.AddNode("E", "E", 1, 1); - auto f = builder.AddNode("F", "F", 1, 1); - auto netoutput = builder.AddNode("NetOutput", "NetOutput", 1, 0); - - builder.AddDataEdge(a, 0, b, 0); - builder.AddDataEdge(b, 0, c, 0); - builder.AddDataEdge(c, 0, d, 0); - builder.AddDataEdge(a, 0, e, 0); - builder.AddDataEdge(e, 0, f, 0); - builder.AddDataEdge(f, 0, d, 0); - - builder.AddDataEdge(d, 0, netoutput, 0); - auto graph = builder.GetGraph(); - fusion_nodes = {b, c}; - return graph; -} - -TEST_F(UtestCycleDetection, cycle_detection_05) { - std::vector fusion_nodes; - auto graph = BuildFusionGraph05(fusion_nodes); - FusionTestPass pass; - Mapping mapping; - - fe::Status ret = pass.Fusion(*graph, mapping, fusion_nodes); - EXPECT_EQ(ret, fe::SUCCESS); -} - -const int kContainCycle = 0; -const int kNoCycleCase1 = 1; -const int kNoCycleCase2 = 2; -const int kNoCycleCase3 = 3; -/* - * /-----H----------------\ - * /------G---------\ \ - * / /------I------\ \ - * A--->B---->C------->D---NetOutput - * \------E---F------------/ - * - * B/C will be fused. - * No Cycle will be generated if fusing. - */ -ComputeGraphPtr CreateGraph06(int case_num, std::vector &fusion_nodes) { - ut::GraphBuilder builder = ut::GraphBuilder("fusion_graph"); - auto a = builder.AddNode("A", "A", 0, 4); - auto b = builder.AddNode("B", "B", 1, 1); - auto c = builder.AddNode("C", "C", 1, 1); - auto d = builder.AddNode("D", "D", 3, 1); - auto e = builder.AddNode("E", "E", 1, 1); - auto f = builder.AddNode("F", "F", 1, 1); - auto g = builder.AddNode("G", "G", 1, 1); - auto h = builder.AddNode("H", "H", 1, 1); - auto i = builder.AddNode("I", "I", 1, 1); - auto netoutput = builder.AddNode("NetOutput", "NetOutput", 3, 0); - - builder.AddControlEdge(a, b); - builder.AddDataEdge(a, 0, e, 0); - builder.AddDataEdge(a, 1, g, 0); - builder.AddDataEdge(a, 2, h, 0); - builder.AddDataEdge(h, 0, netoutput, 0); - - builder.AddDataEdge(b, 0, c, 0); - builder.AddDataEdge(b, 0, i, 0); - builder.AddDataEdge(i, 0, d, 0); - builder.AddDataEdge(c, 0, d, 1); - builder.AddDataEdge(d, 0, netoutput, 1); - - builder.AddDataEdge(g, 0, d, 2); - - builder.AddDataEdge(e, 0, f, 0); - builder.AddDataEdge(f, 0, netoutput, 3); - - auto graph = builder.GetGraph(); - if (case_num == kNoCycleCase1) { - fusion_nodes = {a, b, e, g, h}; - } else if (case_num == kContainCycle) { - fusion_nodes = {b, c, d}; - } else if (case_num == kNoCycleCase2) { - fusion_nodes = {b, c, i}; - } else if (case_num == kNoCycleCase3) { - fusion_nodes = {b, c, d, i}; - } - return graph; -} - - -static ComputeGraphPtr BuildFusionGraph06(int case_num, - std::vector &fusion_nodes) { - auto graph = CreateGraph06(case_num, fusion_nodes); - return graph; -} - -TEST_F(UtestCycleDetection, cycle_detection_06) { - std::vector fusion_nodes; - auto graph = BuildFusionGraph06(kNoCycleCase1, fusion_nodes); - FusionTestPass pass; - Mapping mapping; - - fe::Status ret = pass.Fusion(*graph, mapping, fusion_nodes); - EXPECT_EQ(ret, fe::SUCCESS); -} - -TEST_F(UtestCycleDetection, cycle_detection_07) { - std::vector fusion_nodes; - auto graph = BuildFusionGraph06(kContainCycle, fusion_nodes); - FusionTestPass pass; - Mapping mapping; - - fe::Status ret = pass.Fusion(*graph, mapping, fusion_nodes); - EXPECT_EQ(ret, fe::NOT_CHANGED); -} - -TEST_F(UtestCycleDetection, cycle_detection_08) { - std::vector fusion_nodes; - auto graph = BuildFusionGraph06(kNoCycleCase2, fusion_nodes); - FusionTestPass pass; - Mapping mapping; - - fe::Status ret = pass.Fusion(*graph, mapping, fusion_nodes); - EXPECT_EQ(ret, fe::SUCCESS); -} - -TEST_F(UtestCycleDetection, cycle_detection_09) { - std::vector fusion_nodes; - auto graph = BuildFusionGraph06(kNoCycleCase2, fusion_nodes); - FusionTestPass pass; - Mapping mapping; - - fe::Status ret = pass.Fusion(*graph, mapping, fusion_nodes); - EXPECT_EQ(ret, fe::SUCCESS); -} - -TEST_F(UtestCycleDetection, Coverage_01) { - ge::LargeBitmap a(5); - a.SetValues(1); - - ge::LargeBitmap b(5); - b.SetValues(2); - - a.Or(b); - - ge::LargeBitmap c(5); - c.SetValues(3); - EXPECT_EQ(a == c, true); -} - -TEST_F(UtestCycleDetection, Coverage_02) { - ge::LargeBitmap a(5); - a.SetValues(1); - - ge::LargeBitmap b(5); - b.SetValues(2); - - a.And(b); - - ge::LargeBitmap c(5); - c.SetValues(0); - EXPECT_EQ(a == c, true); - - EXPECT_EQ(a != c, false); -} - -TEST_F(UtestCycleDetection, Coverage_03) { - ge::LargeBitmap a(6); - a.SetValues(1); - a.SetBit(10000); - a.GetBit(10000); - - ge::LargeBitmap b(5); - b.SetValues(2); - - a.And(b); - a.Or(b); - - ge::LargeBitmap c(5); - c.SetValues(0); - EXPECT_EQ(a == c, false); - - EXPECT_EQ(a != c, true); -} - -TEST_F(UtestCycleDetection, Coverage_04) { - EXPECT_NO_THROW( - std::vector fusion_nodes; - auto graph = BuildFusionGraph06(kNoCycleCase2, fusion_nodes); - auto connectivity = std::shared_ptr(new(std::nothrow) fe::ConnectionMatrix(*graph)); - connectivity->Generate(*graph); - connectivity->Update(*graph, fusion_nodes); - ); -} - -TEST_F(UtestCycleDetection, Coverage_05) { - EXPECT_NO_THROW( - FusionTestPass pass; - std::unique_ptr connection_matrix; - pass.GetConnectionMatrix(connection_matrix); - pass.SetConnectionMatrix(connection_matrix); - ); -} - -TEST_F(UtestCycleDetection, Coverage_06) { - EXPECT_NO_THROW( - auto graph = std::make_shared("test"); - auto connection_matrix = std::shared_ptr(new(std::nothrow) fe::ConnectionMatrix(*graph)); - - std::vector fusion_nodes; - BuildFusionGraph06(kNoCycleCase2, fusion_nodes); - connection_matrix->GetIndex(fusion_nodes[0]); - ); -} - -TEST_F(UtestCycleDetection, Coverage_07) { - EXPECT_NO_THROW( - auto graph = std::make_shared("test"); - auto connection_matrix = std::shared_ptr(new(std::nothrow) - ge::ConnectionMatrixImpl(graph)); - - std::vector fusion_nodes; - BuildFusionGraph06(kNoCycleCase2, fusion_nodes); - connection_matrix->GetIndex(fusion_nodes[0]); - ); -} - diff --git a/tests/ut/register/testcase/graph_fusion_turbo_unittest.cc b/tests/ut/register/testcase/graph_fusion_turbo_unittest.cc deleted file mode 100644 index a59d08b8d04d84e0406bbb11ecb4ddf0d204a5d5..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/graph_fusion_turbo_unittest.cc +++ /dev/null @@ -1,1940 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include "graph/compute_graph.h" -#include "graph/graph.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/node_utils.h" - -#include "register/graph_optimizer/graph_fusion/fusion_pass_manager/fusion_pass_registry.h" -#include "register/graph_optimizer/graph_fusion/graph_fusion_pass_base.h" -#include "register/graph_optimizer/fusion_common/pattern_fusion_base_pass.h" -#include "register/graph_optimizer/fusion_common/fusion_turbo.h" - -#include "external/graph/operator_factory.h" -#include "external/graph/operator_reg.h" -#include "graph/operator_factory_impl.h" -#include "graph/debug/ge_log.h" - -using namespace testing; -using namespace ge; -using namespace fe; - -namespace fe { -REG_OP(Data) - .INPUT(x, TensorType::ALL()) - .OUTPUT(y, TensorType::ALL()) - .ATTR(index, Int, 0) - .OP_END_FACTORY_REG(Data) - - -REG_OP(Const) - .OUTPUT(y, - TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, - DT_UINT64, DT_BOOL, DT_DOUBLE})) - .ATTR(value, Tensor, Tensor()) - .OP_END_FACTORY_REG(Const); - -REG_OP(Transpose) - .INPUT(x, - TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, - DT_UINT64, DT_BOOL, DT_DOUBLE})) - .INPUT(shape, TensorType({DT_INT32, DT_INT64})) - .OUTPUT(y, - TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, - DT_UINT64, DT_BOOL, DT_DOUBLE})) - .ATTR(axis, Int, 0) - .ATTR(num_axes, Int, -1) - .OP_END_FACTORY_REG(Transpose); - -REG_OP(Add) - .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, - DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, - DT_COMPLEX64, DT_STRING})) - .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, - DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, - DT_COMPLEX64, DT_STRING})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, - DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, - DT_COMPLEX64, DT_STRING})) - .OP_END_FACTORY_REG(Add) - -REG_OP(Relu) - .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE, - DT_INT8, DT_INT32, DT_INT16, DT_INT64, - DT_UINT8, DT_UINT16, DT_QINT8})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE, - DT_INT8, DT_INT32, DT_INT16, DT_INT64, - DT_UINT8, DT_UINT16, DT_QINT8})) - .OP_END_FACTORY_REG(Relu) - -REG_OP(Split) - .INPUT(split_dim, TensorType({DT_INT32})) - .INPUT(x, TensorType::BasicType()) - .DYNAMIC_OUTPUT(y, TensorType::BasicType()) - .REQUIRED_ATTR(num_split, Int) - .OP_END_FACTORY_REG(Split) - -REG_OP(Concat) - .DYNAMIC_INPUT(x, TensorType::BasicType()) - .INPUT(concat_dim, TensorType::IndexNumberType()) - .OUTPUT(y, TensorType::BasicType()) - .ATTR(N, Int, 1) - .OP_END_FACTORY_REG(Concat) - -REG_OP(RaggedTensorFromVariant) - .INPUT(encoded_ragged, TensorType({DT_VARIANT})) - .DYNAMIC_OUTPUT(output_nested_splits, TensorType({DT_INT32, DT_INT64})) - .OUTPUT(output_dense_values, TensorType::BasicType()) - .REQUIRED_ATTR(input_ragged_rank, Int) - .REQUIRED_ATTR(output_ragged_rank, Int) - .REQUIRED_ATTR(Tvalues, Type) - .ATTR(Tsplits, Type, DT_INT64) - .OP_END_FACTORY_REG(RaggedTensorFromVariant) - -class UTestFusionTurbo : public testing::Test { - public: - - protected: - - - void SetUp() { - } - - void TearDown() { - } - - ge::NodePtr GetNode(ComputeGraphPtr &graph, const string &name) { - for (auto &node : graph->GetDirectNode()) { - if (node->GetName() == name) { - return node; - } - } - return nullptr; - } - - ComputeGraphPtr CreateGraphSingleInAndOut() { - ComputeGraphPtr graph = std::make_shared("test1"); - OpDescPtr op_desc_cast1 = std::make_shared("cast1", "Cast"); - OpDescPtr op_desc_relu = std::make_shared("relu", "Relu"); - OpDescPtr op_desc_cast2 = std::make_shared("cast2", "Cast"); - OpDescPtr op_desc_output = std::make_shared("output", "NetOutput"); - OpDescPtr op_desc_input = std::make_shared("other", "Other"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT16); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - vector dim_b = {1, 4, 64, 64}; - GeShape shape_b(dim_b); - GeTensorDesc tensor_desc_b(shape_b); - tensor_desc_b.SetFormat(FORMAT_NCHW); - tensor_desc_b.SetOriginFormat(FORMAT_NCHW); - tensor_desc_b.SetDataType(DT_FLOAT); - tensor_desc_b.SetOriginDataType(DT_FLOAT); - - vector dim_c = {1, 4, 64, 64}; - GeShape shape_c(dim_c); - GeTensorDesc tensor_desc_c(shape_c); - tensor_desc_c.SetFormat(FORMAT_NCHW); - tensor_desc_c.SetOriginFormat(FORMAT_NCHW); - tensor_desc_c.SetDataType(DT_FLOAT); - tensor_desc_c.SetOriginDataType(DT_FLOAT); - - //vector dim_d; - GeShape shape_d(dim_a); - GeTensorDesc tensor_desc_d(shape_d); - tensor_desc_d.SetFormat(FORMAT_NCHW); - tensor_desc_d.SetOriginFormat(FORMAT_NCHW); - tensor_desc_d.SetDataType(DT_FLOAT16); - tensor_desc_d.SetOriginDataType(DT_FLOAT); - - op_desc_input->AddOutputDesc(tensor_desc_a); - - op_desc_cast1->AddInputDesc(tensor_desc_a); - op_desc_cast1->AddOutputDesc(tensor_desc_b); - - op_desc_relu->AddInputDesc(tensor_desc_b); - op_desc_relu->AddOutputDesc(tensor_desc_c); - - op_desc_cast2->AddInputDesc(tensor_desc_c); - op_desc_cast2->AddOutputDesc(tensor_desc_d); - - op_desc_output->AddInputDesc(tensor_desc_d); - - ge::AttrUtils::SetStr(op_desc_relu, "_op_compile_strategy", "{}"); - ge::AttrUtils::SetInt(op_desc_relu, "_keep_dtype", 1); - - NodePtr node_cast1 = graph->AddNode(op_desc_cast1); - NodePtr node_relu = graph->AddNode(op_desc_relu); - NodePtr node_cast2 = graph->AddNode(op_desc_cast2); - NodePtr node_netoutput = graph->AddNode(op_desc_output); - NodePtr node_other = graph->AddNode(op_desc_input); - - GraphUtils::AddEdge(node_other->GetOutDataAnchor(0), node_cast1->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_cast1->GetOutDataAnchor(0), node_relu->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_relu->GetOutDataAnchor(0), node_cast2->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_cast2->GetOutDataAnchor(0), node_netoutput->GetInDataAnchor(0)); - - return graph; - } - - ComputeGraphPtr CreateGraphParentAndSub() { - - ComputeGraphPtr graph = std::make_shared("test"); - OpDescPtr op_desc_cast1 = std::make_shared("cast1", "Cast"); - OpDescPtr op_desc_add1 = std::make_shared("add1", "Add"); - OpDescPtr op_desc_partcall = std::make_shared("partioncall", "PartionCall"); - OpDescPtr op_desc_partout = std::make_shared("partout", "PartionOut"); - OpDescPtr op_desc_add2 = std::make_shared("add2", "Add"); - OpDescPtr op_desc_output = std::make_shared("output", "NetOutput"); - OpDescPtr op_desc_output1 = std::make_shared("output1", "NetOutput"); - OpDescPtr op_desc_input = std::make_shared("other", "Other"); - OpDescPtr op_desc_input1 = std::make_shared("other1", "Other"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT16); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - vector dim_b = {1, 4, 64, 64}; - GeShape shape_b(dim_b); - GeTensorDesc tensor_desc_b(shape_b); - tensor_desc_b.SetFormat(FORMAT_NCHW); - tensor_desc_b.SetOriginFormat(FORMAT_NCHW); - tensor_desc_b.SetDataType(DT_FLOAT); - tensor_desc_b.SetOriginDataType(DT_FLOAT); - - vector dim_c = {1, 4, 64, 64}; - GeShape shape_c(dim_c); - GeTensorDesc tensor_desc_c(shape_c); - tensor_desc_c.SetFormat(FORMAT_NCHW); - tensor_desc_c.SetOriginFormat(FORMAT_NCHW); - tensor_desc_c.SetDataType(DT_FLOAT); - tensor_desc_c.SetOriginDataType(DT_FLOAT); - - //vector dim_d; - GeShape shape_d(dim_a); - GeTensorDesc tensor_desc_d(shape_d); - tensor_desc_d.SetFormat(FORMAT_NCHW); - tensor_desc_d.SetOriginFormat(FORMAT_NCHW); - tensor_desc_d.SetDataType(DT_FLOAT16); - tensor_desc_d.SetOriginDataType(DT_FLOAT); - - GeShape shape_e(dim_a); - GeTensorDesc tensor_desc_e(shape_e); - tensor_desc_e.SetFormat(FORMAT_NCHW); - tensor_desc_e.SetOriginFormat(FORMAT_NCHW); - tensor_desc_e.SetDataType(DT_FLOAT16); - tensor_desc_e.SetOriginDataType(DT_FLOAT); - - - op_desc_input->AddOutputDesc(tensor_desc_a); - - op_desc_cast1->AddInputDesc(tensor_desc_a); - op_desc_cast1->AddOutputDesc(tensor_desc_b); - - op_desc_add1->AddInputDesc(tensor_desc_b); - op_desc_add1->AddInputDesc(tensor_desc_b); - op_desc_add1->AddOutputDesc(tensor_desc_c); - - op_desc_partcall->AddInputDesc(tensor_desc_c); - op_desc_partcall->AddOutputDesc(tensor_desc_d); - op_desc_partcall->AddOutputDesc(tensor_desc_d); - op_desc_partout->AddInputDesc(tensor_desc_d); - - op_desc_add2->AddInputDesc(tensor_desc_d); - op_desc_add2->AddInputDesc(tensor_desc_d); - op_desc_add2->AddOutputDesc(tensor_desc_e); - - op_desc_input1->AddOutputDesc(tensor_desc_d); - - op_desc_output->AddInputDesc(tensor_desc_e); - op_desc_output1->AddInputDesc(tensor_desc_e); - - NodePtr node_cast1 = graph->AddNode(op_desc_cast1); - NodePtr node_add1 = graph->AddNode(op_desc_add1); - NodePtr node_add2 = graph->AddNode(op_desc_add2); - NodePtr node_netoutput = graph->AddNode(op_desc_output); - NodePtr node_netoutput1 = graph->AddNode(op_desc_output1); - NodePtr node_other = graph->AddNode(op_desc_input); - NodePtr node_partcall = graph->AddNode(op_desc_partcall); - NodePtr node_partout = graph->AddNode(op_desc_partout); - NodePtr node_other1 = graph->AddNode(op_desc_input1); - - GraphUtils::AddEdge(node_other->GetOutDataAnchor(0), node_cast1->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_cast1->GetOutDataAnchor(0), node_add1->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_other->GetOutDataAnchor(0), node_add1->GetInDataAnchor(1)); - GraphUtils::AddEdge(node_add1->GetOutDataAnchor(0), node_partcall->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_partcall->GetOutDataAnchor(0), node_partout->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_partcall->GetOutDataAnchor(1), node_add2->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_other1->GetOutDataAnchor(0), node_add2->GetInDataAnchor(1)); - GraphUtils::AddEdge(node_add2->GetOutDataAnchor(0), node_netoutput->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_add2->GetOutDataAnchor(0), node_netoutput1->GetInDataAnchor(0)); - - // subgraph - ComputeGraphPtr subgraph = std::make_shared("subgraph"); - OpDescPtr op_desc_sub_data1 = std::make_shared("data1", "Data"); - OpDescPtr op_desc_sub_input = std::make_shared("other", "Other"); - OpDescPtr op_desc_sub_add = std::make_shared("sub_add", "Add"); - OpDescPtr op_desc_net_in = std::make_shared("net_in", "NetOutInput"); - OpDescPtr op_desc_sub_output = std::make_shared("output", "NetOutput"); - - op_desc_sub_data1->AddInputDesc(tensor_desc_c); - op_desc_sub_data1->AddOutputDesc(tensor_desc_c); - op_desc_sub_input->AddOutputDesc(tensor_desc_c); - - op_desc_sub_add->AddInputDesc(tensor_desc_c); - op_desc_sub_add->AddInputDesc(tensor_desc_c); - op_desc_sub_add->AddOutputDesc(tensor_desc_d); - - op_desc_net_in->AddOutputDesc(tensor_desc_d); - op_desc_sub_output->AddInputDesc(tensor_desc_d); - op_desc_sub_output->AddInputDesc(tensor_desc_d); - - NodePtr sub_data_node1 = subgraph->AddNode(op_desc_sub_data1); - NodePtr sub_input_node = subgraph->AddNode(op_desc_sub_input); - NodePtr sub_add_node = subgraph->AddNode(op_desc_sub_add); - NodePtr sub_netin_node = subgraph->AddNode(op_desc_net_in); - NodePtr sub_sub_output_node = subgraph->AddNode(op_desc_sub_output); - - GraphUtils::AddEdge(sub_data_node1->GetOutDataAnchor(0), sub_add_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(sub_input_node->GetOutDataAnchor(0), sub_add_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(sub_add_node->GetOutDataAnchor(0), sub_sub_output_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(sub_netin_node->GetOutDataAnchor(0), sub_sub_output_node->GetInDataAnchor(1)); - ge::AttrUtils::SetInt(sub_data_node1->GetOpDesc(), ge::ATTR_NAME_PARENT_NODE_INDEX, 0); - ge::AttrUtils::SetInt(sub_sub_output_node->GetOpDesc()->MutableInputDesc(0), ge::ATTR_NAME_PARENT_NODE_INDEX, 0); - ge::AttrUtils::SetInt(sub_sub_output_node->GetOpDesc()->MutableInputDesc(1), ge::ATTR_NAME_PARENT_NODE_INDEX, 1); - node_partcall->GetOpDesc()->AddSubgraphName("subgraph1"); - ge::NodeUtils::SetSubgraph(*node_partcall, 0, subgraph); - subgraph->SetParentNode(node_partcall); - return graph; - } - - static void DumpGraph(const ge::ComputeGraphPtr graph, string graph_name) { - printf("start to dump graph %s...\n", graph_name.c_str()); - for (ge::NodePtr node : graph->GetAllNodes()) { - printf("node name = %s.\n", node->GetName().c_str()); - for (ge::OutDataAnchorPtr anchor : node->GetAllOutDataAnchors()) { - for (ge::InDataAnchorPtr peer_in_anchor : anchor->GetPeerInDataAnchors()) { - printf(" node name = %s[%d], out data node name = %s[%d].\n", - node->GetName().c_str(), - anchor->GetIdx(), - peer_in_anchor->GetOwnerNode()->GetName().c_str(), - peer_in_anchor->GetIdx()); - } - } - if (node->GetOutControlAnchor() != nullptr) { - for (ge::InControlAnchorPtr peer_in_anchor : node->GetOutControlAnchor()->GetPeerInControlAnchors()) { - printf(" node name = %s, out control node name = %s.\n", node->GetName().c_str(), - peer_in_anchor->GetOwnerNode()->GetName().c_str()); - } - } - } - } - -}; - -TEST_F(UTestFusionTurbo, test_case_01) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - - auto cast2 = GetNode(graph, "cast2"); - auto node = acc.InsertNodeBefore(name, type, cast2, 0); - ASSERT_NE(node, nullptr); - EXPECT_EQ(node->GetName(), name); - EXPECT_EQ(node->GetType(), type); - vector dims = {1, 4, 64, 64}; - auto input = node->GetOpDesc()->MutableInputDesc(0); - EXPECT_EQ(input->GetShape().GetDims(), dims); - EXPECT_EQ(input->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(input->GetDataType(), ge::DT_FLOAT); - - auto output = node->GetOpDesc()->MutableOutputDesc(0); - EXPECT_EQ(output->GetShape().GetDims(), dims); - EXPECT_EQ(output->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(output->GetDataType(), ge::DT_FLOAT); - - EXPECT_EQ(node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode()->GetName(), "relu"); - EXPECT_EQ(node->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetOwnerNode()->GetName(), "cast2"); -} - - -TEST_F(UTestFusionTurbo, test_case_01_1) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - - auto cast2 = GetNode(graph, "cast2"); - acc.BreakInput(cast2, {0}); - - auto node = acc.InsertNodeBefore(name, type, cast2, 0); - ASSERT_NE(node, nullptr); - EXPECT_EQ(node->GetName(), name); - EXPECT_EQ(node->GetType(), type); - vector dims_null = {}; - vector dims = {1, 4, 64, 64}; - auto input = node->GetOpDesc()->MutableInputDesc(0); - EXPECT_EQ(input->GetShape().GetDims(), dims_null); - EXPECT_EQ(input->GetFormat(), ge::FORMAT_ND); - EXPECT_EQ(input->GetDataType(), ge::DT_FLOAT); - - auto output = node->GetOpDesc()->MutableOutputDesc(0); - EXPECT_EQ(output->GetShape().GetDims(), dims); - EXPECT_EQ(output->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(output->GetDataType(), ge::DT_FLOAT); - - EXPECT_EQ(node->GetInDataAnchor(0)->GetPeerOutAnchor(), nullptr); - EXPECT_EQ(node->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetOwnerNode()->GetName(), "cast2"); -} - -TEST_F(UTestFusionTurbo, test_case_02) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - - auto relu = GetNode(graph, "relu"); - auto node = acc.InsertNodeAfter(name, type, relu, 0); - - ASSERT_NE(node, nullptr); - EXPECT_EQ(node->GetName(), name); - EXPECT_EQ(node->GetType(), type); - vector dims = {1, 4, 64, 64}; - auto input = node->GetOpDesc()->MutableInputDesc(0); - EXPECT_EQ(input->GetShape().GetDims(), dims); - EXPECT_EQ(input->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(input->GetDataType(), ge::DT_FLOAT); - - auto output = node->GetOpDesc()->MutableOutputDesc(0); - EXPECT_EQ(output->GetShape().GetDims(), dims); - EXPECT_EQ(output->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(output->GetDataType(), ge::DT_FLOAT); - - EXPECT_EQ(node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode()->GetName(), "relu"); - EXPECT_EQ(node->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetOwnerNode()->GetName(), "cast2"); -} - - -TEST_F(UTestFusionTurbo, test_case_02_1) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - - auto relu = GetNode(graph, "relu"); - acc.BreakAllOutput(relu); - auto node = acc.InsertNodeAfter(name, type, relu, 0); - - ASSERT_NE(node, nullptr); - EXPECT_EQ(node->GetName(), name); - EXPECT_EQ(node->GetType(), type); - vector dims = {1, 4, 64, 64}; - vector dims_null = {}; - auto input = node->GetOpDesc()->MutableInputDesc(0); - EXPECT_EQ(input->GetShape().GetDims(), dims); - EXPECT_EQ(input->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(input->GetDataType(), ge::DT_FLOAT); - - auto output = node->GetOpDesc()->MutableOutputDesc(0); - EXPECT_EQ(output->GetShape().GetDims(), dims_null); - EXPECT_EQ(output->GetFormat(), ge::FORMAT_ND); - EXPECT_EQ(output->GetDataType(), ge::DT_FLOAT); - - EXPECT_EQ(node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode()->GetName(), "relu"); - EXPECT_EQ(node->GetOutDataAnchor(0)->GetPeerInDataAnchors().size(), 0); -} - -TEST_F(UTestFusionTurbo, test_concat_and_split_node) { - auto graph = std::make_shared("test1"); - FusionTurbo acc(graph); - string name1 = "split"; - string type1 = "Split"; - auto split = acc.AddNodeOnly(name1, type1, 32); - - string name2 = "concat"; - string type2 = "Concat"; - auto concat = acc.AddNodeOnly(name2, type2, 32); - ASSERT_NE(split, nullptr); - ASSERT_NE(concat, nullptr); - - auto split_input_size = split->GetOpDesc()->GetAllInputsSize(); - EXPECT_EQ(split_input_size, 2); - - auto split_output_size = split->GetOpDesc()->GetOutputsSize(); - EXPECT_EQ(split_output_size, 32); - - auto concat_input_size = concat->GetOpDesc()->GetAllInputsSize(); - EXPECT_EQ(concat_input_size, 33); - - auto concat_output_size = concat->GetOpDesc()->GetOutputsSize(); - EXPECT_EQ(concat_output_size, 1); - - Relations output_relation; - for (size_t i = 0; i < 32; ++i) { - output_relation.Add(i, {concat, static_cast(i)}); - } - acc.LinkOutput(output_relation, split); - auto out_data_nodes = split->GetOutDataNodes(); - ASSERT_EQ(out_data_nodes.size(), 32); - for (size_t i = 0; i < 32; ++i) { - EXPECT_EQ(out_data_nodes.at(i)->GetName(), "concat"); - } - - auto split_input_0_name = split->GetOpDesc()->GetInputNameByIndex(0); - EXPECT_EQ(split_input_0_name, "split_dim"); - - auto split_input_1_name = split->GetOpDesc()->GetInputNameByIndex(1); - EXPECT_EQ(split_input_1_name, "x"); - - for (size_t i = 0; i < 32; ++i) { - auto name = concat->GetOpDesc()->GetInputNameByIndex(i); - EXPECT_EQ(name, "x" + std::to_string(i)); - } - - for (size_t i = 0; i < 32; ++i) { - auto name = split->GetOpDesc()->GetOutputNameByIndex(i); - EXPECT_EQ(name, "y" + std::to_string(i)); - } -} - -TEST_F(UTestFusionTurbo, test_concat_and_ragged_node) { - auto graph = std::make_shared("test1"); - FusionTurbo acc(graph); - string name1 = "ragged"; - string type1 = "RaggedTensorFromVariant"; - auto ragged = acc.AddNodeOnly(name1, type1, 32); - - string name2 = "concat"; - string type2 = "Concat"; - auto concat = acc.AddNodeOnly(name2, type2, 32); - ASSERT_NE(ragged, nullptr); - ASSERT_NE(concat, nullptr); - - auto split_input_size = ragged->GetOpDesc()->GetAllInputsSize(); - EXPECT_EQ(split_input_size, 1); - - auto split_output_size = ragged->GetOpDesc()->GetOutputsSize(); - EXPECT_EQ(split_output_size, 33); - - auto concat_input_size = concat->GetOpDesc()->GetAllInputsSize(); - EXPECT_EQ(concat_input_size, 33); - - auto concat_output_size = concat->GetOpDesc()->GetOutputsSize(); - EXPECT_EQ(concat_output_size, 1); - - Relations output_relation; - for (size_t i = 1; i < 33; ++i) { - output_relation.Add(i, {concat, static_cast(i)}); - } - acc.LinkOutput(output_relation, ragged); - auto out_data_nodes = ragged->GetOutDataNodes(); - ASSERT_EQ(out_data_nodes.size(), 32); - for (size_t i = 0; i < 32; ++i) { - EXPECT_EQ(out_data_nodes.at(i)->GetName(), "concat"); - } - - for (size_t i = 0; i < 32; ++i) { - auto name = concat->GetOpDesc()->GetInputNameByIndex(i); - EXPECT_EQ(name, "x" + std::to_string(i)); - } - - for (size_t i = 0; i < 32; ++i) { - auto name = ragged->GetOpDesc()->GetOutputNameByIndex(i); - EXPECT_EQ(name, "output_nested_splits" + std::to_string(i)); - } - auto output_32_name = ragged->GetOpDesc()->GetOutputNameByIndex(32); - EXPECT_EQ(output_32_name, "output_dense_values"); -} - -TEST_F(UTestFusionTurbo, test_case_03) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - auto node1 = acc.AddNodeOnly(name, type); - auto node2 = acc.AddNodeOnly(name, type); - ASSERT_NE(node1, nullptr); - ASSERT_NE(node2, nullptr); -} - -/* cast2 already has input so Transpose will not have peer out. */ -TEST_F(UTestFusionTurbo, test_case_04) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - auto relu = GetNode(graph, "relu"); - Relations src_list = {{relu, 0}}; - acc.LinkInput(src_list, node, UPDATE_THIS); - - auto cast2 = GetNode(graph, "cast2"); - Relations dst_list = {{cast2, 0}}; - Status ret = acc.LinkOutput(dst_list, node, UPDATE_THIS); - EXPECT_EQ(ret, SUCCESS); - EXPECT_EQ(node->GetName(), name); - EXPECT_EQ(node->GetType(), type); - vector dims = {1, 4, 64, 64}; - auto input = node->GetOpDesc()->MutableInputDesc(0); - EXPECT_EQ(input->GetShape().GetDims(), dims); - EXPECT_EQ(input->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(input->GetDataType(), ge::DT_FLOAT); - - auto output = node->GetOpDesc()->MutableOutputDesc(0); - EXPECT_EQ(output->GetShape().GetDims(), dims); - EXPECT_EQ(output->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(output->GetDataType(), ge::DT_FLOAT); - - EXPECT_EQ(node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode()->GetName(), "relu"); - EXPECT_EQ(node->GetOutDataAnchor(0)->GetPeerInDataAnchors().size(), 1); -} - -TEST_F(UTestFusionTurbo, test_case_04_1) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - auto relu = GetNode(graph, "relu"); - Relations src_list; - src_list.Add(0, {relu, 0}); - acc.LinkInput(src_list, node, UPDATE_THIS); - - auto cast2 = GetNode(graph, "cast2"); - Relations dst_list = {{cast2, 0}}; - dst_list.Add(0, {cast2, 0}); - Status ret = acc.LinkOutput(dst_list, node, UPDATE_THIS); - EXPECT_EQ(ret, SUCCESS); - EXPECT_EQ(node->GetName(), name); - EXPECT_EQ(node->GetType(), type); - vector dims = {1, 4, 64, 64}; - auto input = node->GetOpDesc()->MutableInputDesc(0); - EXPECT_EQ(input->GetShape().GetDims(), dims); - EXPECT_EQ(input->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(input->GetDataType(), ge::DT_FLOAT); - - auto output = node->GetOpDesc()->MutableOutputDesc(0); - EXPECT_EQ(output->GetShape().GetDims(), dims); - EXPECT_EQ(output->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(output->GetDataType(), ge::DT_FLOAT); - - EXPECT_EQ(node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode()->GetName(), "relu"); - EXPECT_EQ(node->GetOutDataAnchor(0)->GetPeerInDataAnchors().size(), 1); -} - -TEST_F(UTestFusionTurbo, test_case_04_2) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - auto relu = GetNode(graph, "relu"); - Relations src_list; - src_list.Add(0, {{relu, 0}}); - acc.LinkInput(src_list, node); - - auto cast2 = GetNode(graph, "cast2"); - Relations dst_list; - dst_list.Add(0, {{cast2, 0}}); - Status ret = acc.LinkOutput(dst_list, node); - EXPECT_EQ(ret, SUCCESS); - EXPECT_EQ(node->GetName(), name); - EXPECT_EQ(node->GetType(), type); - vector dims = {1, 4, 64, 64}; - auto input = node->GetOpDesc()->MutableInputDesc(0); - EXPECT_EQ(input->GetShape().GetDims(), dims); - EXPECT_EQ(input->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(input->GetDataType(), ge::DT_FLOAT); - - auto output = node->GetOpDesc()->MutableOutputDesc(0); - EXPECT_EQ(output->GetShape().GetDims(), dims); - EXPECT_EQ(output->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(output->GetDataType(), ge::DT_FLOAT); - - EXPECT_EQ(node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode()->GetName(), "relu"); - EXPECT_EQ(node->GetOutDataAnchor(0)->GetPeerInDataAnchors().size(), 1); -} - - -TEST_F(UTestFusionTurbo, test_case_04_3) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - auto relu = GetNode(graph, "relu"); - Relations src_list = {{relu, 0}}; - acc.LinkInput(src_list, node); - - auto cast2 = GetNode(graph, "cast2"); - auto output_node = GetNode(graph, "output"); - Relations dst_list = {{cast2, 0}, - {output_node, 0}}; - Status ret = FusionTurbo::LinkOutput(dst_list, node); - EXPECT_EQ(ret, SUCCESS); - EXPECT_EQ(node->GetName(), name); - EXPECT_EQ(node->GetType(), type); - vector dims = {1, 4, 64, 64}; - auto input = node->GetOpDesc()->MutableInputDesc(0); - EXPECT_EQ(input->GetShape().GetDims(), dims); - EXPECT_EQ(input->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(input->GetDataType(), ge::DT_FLOAT); - - auto output = node->GetOpDesc()->MutableOutputDesc(0); - EXPECT_EQ(output->GetShape().GetDims(), dims); - EXPECT_EQ(output->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(output->GetDataType(), ge::DT_FLOAT); - - EXPECT_EQ(node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode()->GetName(), "relu"); - EXPECT_EQ(node->GetOutDataAnchor(0)->GetPeerInDataAnchors().size(), 1); -} - -TEST_F(UTestFusionTurbo, test_case_04_4) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - auto input = node->GetOpDesc()->MutableInputDesc(0); - input->SetFormat(ge::FORMAT_ND); - input->SetDataType(ge::DT_UNDEFINED); - auto output = node->GetOpDesc()->MutableOutputDesc(0); - output->SetFormat(ge::FORMAT_ND); - output->SetDataType(ge::DT_UNDEFINED); - - auto relu = GetNode(graph, "relu"); - Relations src_list = {{relu, 0}}; - acc.LinkInput(src_list, node, UPDATE_PEER); - - auto cast2 = GetNode(graph, "cast2"); - Relations dst_list = {{cast2, 0}}; - Status ret = acc.LinkOutput(dst_list, node, UPDATE_PEER); - EXPECT_EQ(ret, SUCCESS); - EXPECT_EQ(node->GetName(), name); - EXPECT_EQ(node->GetType(), type); - vector dims = {1, 4, 64, 64}; - vector dims_null = {}; - - EXPECT_EQ(input->GetShape().GetDims(), dims_null); - EXPECT_EQ(input->GetFormat(), ge::FORMAT_ND); - EXPECT_EQ(input->GetDataType(), ge::DT_UNDEFINED); - - EXPECT_EQ(output->GetShape().GetDims(), dims_null); - EXPECT_EQ(output->GetFormat(), ge::FORMAT_ND); - EXPECT_EQ(output->GetDataType(), ge::DT_UNDEFINED); - - auto relu_input = relu->GetOpDesc()->MutableInputDesc(0); - EXPECT_EQ(relu_input->GetShape().GetDims(), dims); - EXPECT_EQ(relu_input->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(relu_input->GetDataType(), ge::DT_FLOAT); - - auto relu_output = relu->GetOpDesc()->MutableOutputDesc(0); - EXPECT_EQ(relu_output->GetShape().GetDims(), dims_null); - EXPECT_EQ(relu_output->GetFormat(), ge::FORMAT_ND); - EXPECT_EQ(relu_output->GetDataType(), ge::DT_UNDEFINED); - - auto cast_input = cast2->GetOpDesc()->MutableInputDesc(0); - EXPECT_EQ(cast_input->GetShape().GetDims(), dims_null); - EXPECT_EQ(cast_input->GetFormat(), ge::FORMAT_ND); - EXPECT_EQ(cast_input->GetDataType(), ge::DT_UNDEFINED); - - vector dims_cast = {8, 4, 16, 16}; - auto cast_output = cast2->GetOpDesc()->MutableOutputDesc(0); - EXPECT_EQ(cast_output->GetShape().GetDims(), dims_cast); - EXPECT_EQ(cast_output->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(cast_output->GetDataType(), ge::DT_FLOAT16); - - EXPECT_EQ(node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode()->GetName(), "relu"); - EXPECT_EQ(node->GetOutDataAnchor(0)->GetPeerInDataAnchors().size(), 1); -} - -TEST_F(UTestFusionTurbo, test_case_05) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - auto relu = GetNode(graph, "relu"); - Relations src_list = {{relu, 0}}; - acc.LinkInput(src_list, node, UPDATE_THIS); - - auto cast2 = GetNode(graph, "cast2"); - Relations dst_list = {{cast2, 0}}; - acc.BreakInput(cast2, {0}); - Status ret = acc.LinkOutput(dst_list, node, UPDATE_THIS); - EXPECT_EQ(ret, SUCCESS); - - EXPECT_EQ(node->GetName(), name); - EXPECT_EQ(node->GetType(), type); - vector dims = {1, 4, 64, 64}; - auto input = node->GetOpDesc()->MutableInputDesc(0); - EXPECT_EQ(input->GetShape().GetDims(), dims); - EXPECT_EQ(input->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(input->GetDataType(), ge::DT_FLOAT); - - auto output = node->GetOpDesc()->MutableOutputDesc(0); - EXPECT_EQ(output->GetShape().GetDims(), dims); - EXPECT_EQ(output->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(output->GetDataType(), ge::DT_FLOAT); - - EXPECT_EQ(node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode()->GetName(), "relu"); - ASSERT_EQ(node->GetOutDataAnchor(0)->GetPeerInDataAnchors().size(), 1); - EXPECT_EQ(node->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetOwnerNode()->GetName(), "cast2"); -} - -TEST_F(UTestFusionTurbo, test_case_05_1) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - auto relu = GetNode(graph, "relu"); - Relations src_list; - src_list.Add(0, {relu, 0}); - acc.LinkInput(src_list, node, UPDATE_THIS); - - auto cast2 = GetNode(graph, "cast2"); - Relations dst_list; - dst_list.Add(0, {cast2, 0}); - acc.BreakInput(cast2, {0}); - Status ret = acc.LinkOutput(dst_list, node, UPDATE_THIS); - EXPECT_EQ(ret, SUCCESS); - - EXPECT_EQ(node->GetName(), name); - EXPECT_EQ(node->GetType(), type); - vector dims = {1, 4, 64, 64}; - auto input = node->GetOpDesc()->MutableInputDesc(0); - EXPECT_EQ(input->GetShape().GetDims(), dims); - EXPECT_EQ(input->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(input->GetDataType(), ge::DT_FLOAT); - - auto output = node->GetOpDesc()->MutableOutputDesc(0); - EXPECT_EQ(output->GetShape().GetDims(), dims); - EXPECT_EQ(output->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(output->GetDataType(), ge::DT_FLOAT); - - EXPECT_EQ(node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode()->GetName(), "relu"); - ASSERT_EQ(node->GetOutDataAnchor(0)->GetPeerInDataAnchors().size(), 1); - EXPECT_EQ(node->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetOwnerNode()->GetName(), "cast2"); -} - -TEST_F(UTestFusionTurbo, test_case_05_2) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - auto op_desc = node->GetOpDesc(); - auto input0 = op_desc->GetInputDescPtr(0); - auto input1 = op_desc->GetInputDescPtr(1); - EXPECT_EQ(input0->GetDataType(), ge::DT_FLOAT); - EXPECT_EQ(input0->GetShape().GetDimNum(), 0); - EXPECT_EQ(input0->GetFormat(), ge::FORMAT_ND); - - auto relu = GetNode(graph, "relu"); - Relations src_list; - src_list.Add(0, {relu, 0}); - src_list.Add(0, {{relu, 0}, {relu, 0}}); - Relations src_list_1 = src_list; - - Relations dst_list; - dst_list.Add(0, {relu, 0, PEER}); - dst_list.Add(0, {{relu, 0, PEER}, {relu, 0, PEER}}); - - acc.LinkInput(src_list_1, node, UPDATE_THIS); - auto cast2 = GetNode(graph, "cast2"); - acc.BreakInput(cast2, {0}); - Status ret = acc.LinkOutput(dst_list, node, UPDATE_THIS); - EXPECT_EQ(ret, SUCCESS); - - EXPECT_EQ(node->GetName(), name); - EXPECT_EQ(node->GetType(), type); - vector dims = {1, 4, 64, 64}; - auto input = node->GetOpDesc()->MutableInputDesc(0); - EXPECT_EQ(input->GetShape().GetDims(), dims); - EXPECT_EQ(input->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(input->GetDataType(), ge::DT_FLOAT); - - auto output = node->GetOpDesc()->MutableOutputDesc(0); - EXPECT_EQ(output->GetShape().GetDims(), dims); - EXPECT_EQ(output->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(output->GetDataType(), ge::DT_FLOAT); - - EXPECT_EQ(node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode()->GetName(), "relu"); - ASSERT_EQ(node->GetOutDataAnchor(0)->GetPeerInDataAnchors().size(), 1); - EXPECT_EQ(node->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetOwnerNode()->GetName(), "cast2"); -} - -TEST_F(UTestFusionTurbo, test_case_06) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - auto node = acc.AddNodeOnly(name, type); - - ASSERT_NE(node, nullptr); - - auto relu = GetNode(graph, "relu"); - Relations src_list = {{relu, 0}}; - acc.LinkInput(src_list, node, UPDATE_NONE); - - EXPECT_EQ(node->GetName(), name); - EXPECT_EQ(node->GetType(), type); - vector dims = {}; - auto input = node->GetOpDesc()->MutableInputDesc(0); - EXPECT_EQ(input->GetShape().GetDims(), dims); - EXPECT_EQ(input->GetFormat(), ge::FORMAT_ND); - EXPECT_EQ(input->GetDataType(), ge::DT_FLOAT); - - EXPECT_EQ(node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode()->GetName(), "relu"); -} - -TEST_F(UTestFusionTurbo, test_case_07) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - auto cast2 = GetNode(graph, "cast2"); - Relations dst_list = {{cast2, 0}}; - acc.BreakAllInput(cast2); - Status ret = acc.LinkOutput(dst_list, node, UPDATE_NONE); - EXPECT_EQ(ret, SUCCESS); - - EXPECT_EQ(node->GetName(), name); - EXPECT_EQ(node->GetType(), type); - - vector dims = {}; - auto output = node->GetOpDesc()->MutableOutputDesc(0); - EXPECT_EQ(output->GetShape().GetDims(), dims); - EXPECT_EQ(output->GetFormat(), ge::FORMAT_ND); - EXPECT_EQ(output->GetDataType(), ge::DT_FLOAT); - EXPECT_EQ(node->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetOwnerNode()->GetName(), "cast2"); -} - -TEST_F(UTestFusionTurbo, test_case_08) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - - auto cast2 = GetNode(graph, "cast2"); - acc.BreakOutput(cast2, {0}); - acc.BreakOutput(cast2, {1}); - EXPECT_EQ(acc.RemoveNodeWithRelink(cast2, {0}), SUCCESS); -} - -TEST_F(UTestFusionTurbo, test_case_09) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - - auto cast2 = GetNode(graph, "cast2"); - auto cast1 = GetNode(graph, "cast1"); - EXPECT_EQ(acc.RemoveNodeOnly(cast2), SUCCESS); - EXPECT_EQ(acc.RemoveNodeWithRelink(cast1, {0}), SUCCESS); -} - -TEST_F(UTestFusionTurbo, test_case_10) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - - auto cast2 = GetNode(graph, "cast2"); - cast2->GetOpDesc()->MutableOutputDesc(0)->SetShape(ge::GeShape({-1})); - cast2->GetOpDesc()->MutableInputDesc(0)->SetOriginShape(ge::GeShape({-1})); - EXPECT_EQ(false, acc.IsUnknownShape(cast2, 0)); - EXPECT_EQ(false, acc.IsUnknownShape(cast2, 0, true)); - EXPECT_EQ(true, acc.IsUnknownShape(cast2, 0, false)); - - EXPECT_EQ(true, acc.IsUnknownOriShape(cast2, 0)); - EXPECT_EQ(true, acc.IsUnknownOriShape(cast2, 0, true)); - EXPECT_EQ(false, acc.IsUnknownOriShape(cast2, 0, false)); -} - -TEST_F(UTestFusionTurbo, test_case_11) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - auto relu = GetNode(graph, "relu"); - Relations src_list; - src_list.Add(0, {relu, 0}); - src_list.Add(0, {}); - EXPECT_EQ(SUCCESS, acc.LinkInput(src_list, node, UPDATE_NONE)); - - auto cast2 = GetNode(graph, "cast2"); - - Relations dst_list; - dst_list.Add(0, {cast2, 0}); - dst_list.Add(0, {}); - EXPECT_EQ(SUCCESS, acc.LinkOutput(dst_list, node, UPDATE_NONE)); - auto input = node->GetOpDesc()->MutableInputDesc(0); - - // Update input desc - vector dims = {}; - vector dims_new = {1, 4, 64, 64}; - EXPECT_EQ(SUCCESS, acc.UpdateInputByPeer(node, 0, relu, 0)); - EXPECT_EQ(input->GetShape().GetDims(), dims_new); - EXPECT_EQ(input->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(input->GetDataType(), ge::DT_FLOAT); - EXPECT_EQ(input->GetOriginFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(input->GetOriginDataType(), ge::DT_FLOAT); - - input->SetShape(ge::GeShape(dims)); - input->SetFormat(ge::FORMAT_ND); - - acc.UpdateInputByPeer(node, 0, relu, 0); - EXPECT_EQ(input->GetShape().GetDims(), dims_new); - EXPECT_EQ(input->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(input->GetDataType(), ge::DT_FLOAT); - EXPECT_EQ(input->GetOriginFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(input->GetOriginDataType(), ge::DT_FLOAT); - - input->SetShape(ge::GeShape(dims)); - input->SetFormat(ge::FORMAT_ND); - - acc.UpdateInputByPeer(node, 0 , relu, 0); - EXPECT_EQ(input->GetShape().GetDims(), dims_new); - EXPECT_EQ(input->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(input->GetDataType(), ge::DT_FLOAT); - EXPECT_EQ(input->GetOriginFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(input->GetOriginDataType(), ge::DT_FLOAT); - - input->SetShape(ge::GeShape(dims)); - input->SetFormat(ge::FORMAT_ND); -} - -TEST_F(UTestFusionTurbo, test_case_12) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - auto relu = GetNode(graph, "relu"); - Relations src_list; - NodeIndex pi = {relu, 0}; - src_list.Add(0, pi); - src_list.Add(0, pi); - acc.LinkInput(src_list, node, UPDATE_NONE); - - auto cast2 = GetNode(graph, "cast2"); - Relations dst_list; - dst_list.Add(0, {{cast2, 0}}); - acc.LinkOutput(dst_list, node, UPDATE_NONE); - - auto output = node->GetOpDesc()->MutableOutputDesc(0); - auto cast1 = GetNode(graph, "cast1"); - // Update output desc - vector dims = {}; - vector dims_new = {8, 4, 16, 16}; - EXPECT_EQ(SUCCESS, acc.UpdateOutputByPeer(node, 0, cast1, 0)); - EXPECT_EQ(output->GetShape().GetDims(), dims_new); - EXPECT_EQ(output->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(output->GetOriginFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(output->GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(output->GetOriginDataType(), ge::DT_FLOAT); - output->SetShape(ge::GeShape(dims)); - output->SetFormat(ge::FORMAT_ND); - - acc.UpdateOutputByPeer(node, 0, cast1, 0); - EXPECT_EQ(output->GetShape().GetDims(), dims_new); - EXPECT_EQ(output->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(output->GetOriginFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(output->GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(output->GetOriginDataType(), ge::DT_FLOAT); - - output->SetShape(ge::GeShape(dims)); - output->SetFormat(ge::FORMAT_ND); - - acc.UpdateOutputByPeer(node, 0, cast1, 0); - EXPECT_EQ(output->GetShape().GetDims(), dims_new); - EXPECT_EQ(output->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(output->GetOriginFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(output->GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(output->GetOriginDataType(), ge::DT_FLOAT); - - output->SetShape(ge::GeShape(dims)); - output->SetFormat(ge::FORMAT_ND); -} - -TEST_F(UTestFusionTurbo, test_case_13) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - auto relu = GetNode(graph, "relu"); - Relations src_list; - src_list.Add(0, {relu, 0}); - src_list.Add(0, {relu, 0}); - acc.LinkInput(src_list, node, UPDATE_THIS); - - auto cast2 = GetNode(graph, "cast2"); - Relations dst_list; - NodeIndex pi = {cast2, 0}; - dst_list.Add(0, pi); - dst_list.Add(0, pi); - acc.BreakInput(cast2, {0}); - Status ret = acc.LinkOutput(dst_list, node, UPDATE_THIS); - EXPECT_EQ(ret, SUCCESS); - - unique_ptr value(new(std::nothrow) int32_t[24]); - auto data_ptr = (uint8_t *) (value.get()); - for (size_t i = 0; i < 96; i++) { - data_ptr[i] = i; - } - WeightInfo w = {ge::GeShape({1, 2, 3, 4}), ge::DT_INT32, ge::FORMAT_NCHW, value.get()}; - /* coverage code */ - auto shape = ge::GeShape({1, 2, 3, 4}); - WeightInfo w1 = {shape, ge::DT_INT32, ge::FORMAT_NCHW, value.get()}; - WeightInfo w2 = {ge::GeShape({1, 2, 3, 4}), ge::GeShape({1, 2, 3, 4}), - ge::DT_INT32, ge::DT_INT32, ge::FORMAT_NCHW, ge::FORMAT_NCHW, - value.get()}; - - ASSERT_NE(nullptr, acc.AddWeight(node, w)); - ASSERT_EQ(node->GetAllInDataAnchorsSize(), 3); - EXPECT_EQ(node->GetInDataAnchor(2)->GetPeerOutAnchor()->GetOwnerNode()->GetType(), "Const"); - auto new_weight = FusionTurbo::MutableWeight(node, 2); - ASSERT_NE(nullptr, new_weight); - - EXPECT_EQ(new_weight->GetTensorDesc().GetShape(), ge::GeShape({1, 2, 3, 4})); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginShape(), ge::GeShape({1, 2, 3, 4})); - EXPECT_EQ(new_weight->GetTensorDesc().GetDataType(), ge::DT_INT32); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginDataType(), ge::DT_INT32); - EXPECT_EQ(new_weight->GetTensorDesc().GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginFormat(), ge::FORMAT_NCHW); - const uint8_t *data = new_weight->GetData().GetData(); - auto data_size = new_weight->GetData().size(); - EXPECT_EQ(data_size, 96); - for (size_t i = 0; i < 96; i++) { - EXPECT_EQ(data[i], i); - } -} - - -TEST_F(UTestFusionTurbo, test_case_13_1) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - auto relu = GetNode(graph, "relu"); - Relations src_list = {{relu, 0}}; - acc.LinkInput(src_list, node, UPDATE_THIS); - - auto cast2 = GetNode(graph, "cast2"); - Relations dst_list = {{cast2, 0}}; - acc.BreakInput(cast2, {0}); - Status ret = acc.LinkOutput(dst_list, node, UPDATE_THIS); - EXPECT_EQ(ret, SUCCESS); - - unique_ptr value(new(std::nothrow) int32_t[24]); - auto data_ptr = (uint8_t *) (value.get()); - for (size_t i = 0; i < 96; i++) { - data_ptr[i] = i; - } - WeightInfo w = {ge::GeShape({1, 2, 3, 4}), ge::DT_INT32, ge::FORMAT_NCHW, value.get()}; - - ASSERT_NE(nullptr, acc.AddWeight(node, 3, w)); - ASSERT_EQ(node->GetAllInDataAnchorsSize(), 3); - EXPECT_EQ(node->GetInDataAnchor(2)->GetPeerOutAnchor()->GetOwnerNode()->GetType(), "Const"); - auto new_weight = FusionTurbo::MutableWeight(node, 2); - ASSERT_NE(nullptr, new_weight); - - EXPECT_EQ(new_weight->GetTensorDesc().GetShape(), ge::GeShape({1, 2, 3, 4})); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginShape(), ge::GeShape({1, 2, 3, 4})); - EXPECT_EQ(new_weight->GetTensorDesc().GetDataType(), ge::DT_INT32); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginDataType(), ge::DT_INT32); - EXPECT_EQ(new_weight->GetTensorDesc().GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginFormat(), ge::FORMAT_NCHW); - const uint8_t *data = new_weight->GetData().GetData(); - auto data_size = new_weight->GetData().size(); - EXPECT_EQ(data_size, 96); - for (size_t i = 0; i < 96; i++) { - EXPECT_EQ(data[i], i); - } -} - -TEST_F(UTestFusionTurbo, test_case_14) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - unique_ptr value(new(std::nothrow) int32_t[24]); - auto data_ptr = (uint8_t *) (value.get()); - for (size_t i = 0; i < 96; i++) { - data_ptr[i] = i; - } - - auto input1 = node->GetOpDesc()->MutableInputDesc(1); - input1->SetDataType(ge::DT_INT32); - input1->SetOriginDataType(ge::DT_INT16); - input1->SetFormat(ge::FORMAT_NCHW); - input1->SetOriginFormat(ge::FORMAT_NHWC); - input1->SetShape(ge::GeShape({1, 2, 3, 4})); - input1->SetOriginShape(ge::GeShape({1, 3, 4, 2})); - - WeightInfo w(node, 1, value.get()); - - ASSERT_NE(nullptr, acc.AddWeight(node, 1, w)); - ASSERT_EQ(node->GetAllInDataAnchorsSize(), 2); - EXPECT_EQ(node->GetInDataAnchor(1)->GetPeerOutAnchor()->GetOwnerNode()->GetType(), "Const"); - auto new_weight = FusionTurbo::MutableWeight(node, 1); - ASSERT_NE(nullptr, new_weight); - - EXPECT_EQ(new_weight->GetTensorDesc().GetShape(), ge::GeShape({1, 2, 3, 4})); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginShape(), ge::GeShape({1, 3, 4, 2})); - EXPECT_EQ(new_weight->GetTensorDesc().GetDataType(), ge::DT_INT32); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginDataType(), ge::DT_INT16); - EXPECT_EQ(new_weight->GetTensorDesc().GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginFormat(), ge::FORMAT_NHWC); - const uint8_t *data = new_weight->GetData().GetData(); - auto data_size = new_weight->GetData().size(); - EXPECT_EQ(data_size, 96); - for (size_t i = 0; i < 96; i++) { - EXPECT_EQ(data[i], i); - } - - w.shape = ge::GeShape({5, 2, 7, 1}); - w.ori_shape = ge::GeShape({5, 1, 2, 7}); - w.datatype = ge::DT_FLOAT16; - w.ori_datatype = ge::DT_FLOAT; - w.format = ge::FORMAT_NHWC; - w.ori_format = ge::FORMAT_NCHW; - - w.total_data_size = 140; - unique_ptr value1(new(std::nothrow) int32_t[140]); - auto data_ptr1 = (uint8_t *) (value1.get()); - for (size_t i = 0; i < 140; i++) { - data_ptr1[i] = i + 1; - } - w.data = (uint8_t *) value1.get(); - /* Update const value and tensor when const node and weight both exist. */ - ASSERT_NE(nullptr, acc.AddWeight(node, 1, w)); - ASSERT_EQ(node->GetAllInDataAnchorsSize(), 2); - EXPECT_EQ(node->GetInDataAnchor(1)->GetPeerOutAnchor()->GetOwnerNode()->GetType(), "Const"); - new_weight = FusionTurbo::MutableWeight(node, 1); - ASSERT_NE(nullptr, new_weight); - - EXPECT_EQ(new_weight->GetTensorDesc().GetShape(), ge::GeShape({5, 2, 7, 1})); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginShape(), ge::GeShape({5, 1, 2, 7})); - EXPECT_EQ(new_weight->GetTensorDesc().GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginDataType(), ge::DT_FLOAT); - EXPECT_EQ(new_weight->GetTensorDesc().GetFormat(), ge::FORMAT_NHWC); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginFormat(), ge::FORMAT_NCHW); - data = new_weight->GetData().GetData(); - data_size = new_weight->GetData().size(); - EXPECT_EQ(data_size, 140); - for (size_t i = 0; i < 140; i++) { - EXPECT_EQ(data[i], i + 1); - } -} - -TEST_F(UTestFusionTurbo, test_case_14_1) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - unique_ptr value(new(std::nothrow) int32_t[24]); - auto data_ptr = (uint8_t *) (value.get()); - for (size_t i = 0; i < 96; i++) { - data_ptr[i] = i; - } - - auto input1 = node->GetOpDesc()->MutableInputDesc(1); - input1->SetDataType(ge::DT_INT32); - input1->SetOriginDataType(ge::DT_INT16); - input1->SetFormat(ge::FORMAT_NCHW); - input1->SetOriginFormat(ge::FORMAT_NHWC); - input1->SetShape(ge::GeShape({1, 2, 3, 4})); - input1->SetOriginShape(ge::GeShape({1, 3, 4, 2})); - - WeightInfo w(node, 1, value.get()); - - ASSERT_NE(nullptr, acc.AddWeight(node, "shape", w)); - ASSERT_EQ(node->GetAllInDataAnchorsSize(), 2); - EXPECT_EQ(node->GetInDataAnchor(1)->GetPeerOutAnchor()->GetOwnerNode()->GetType(), "Const"); - auto new_weight = FusionTurbo::MutableWeight(node, 1); - ASSERT_NE(nullptr, new_weight); - - EXPECT_EQ(new_weight->GetTensorDesc().GetShape(), ge::GeShape({1, 2, 3, 4})); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginShape(), ge::GeShape({1, 3, 4, 2})); - EXPECT_EQ(new_weight->GetTensorDesc().GetDataType(), ge::DT_INT32); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginDataType(), ge::DT_INT16); - EXPECT_EQ(new_weight->GetTensorDesc().GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginFormat(), ge::FORMAT_NHWC); - const uint8_t *data = new_weight->GetData().GetData(); - auto data_size = new_weight->GetData().size(); - EXPECT_EQ(data_size, 96); - for (size_t i = 0; i < 96; i++) { - EXPECT_EQ(data[i], i); - } - - w.shape = ge::GeShape({5, 2, 7, 1}); - w.ori_shape = ge::GeShape({5, 1, 2, 7}); - w.datatype = ge::DT_FLOAT16; - w.ori_datatype = ge::DT_FLOAT; - w.format = ge::FORMAT_NHWC; - w.ori_format = ge::FORMAT_NCHW; - - w.total_data_size = 140; - unique_ptr value1(new(std::nothrow) int32_t[140]); - auto data_ptr1 = (uint8_t *) (value1.get()); - for (size_t i = 0; i < 140; i++) { - data_ptr1[i] = i + 1; - } - w.data = (uint8_t *) value1.get(); - /* Update const value and tensor when const node and weight both exist. */ - ASSERT_NE(nullptr, acc.AddWeight(node, "shape", w)); - ASSERT_EQ(node->GetAllInDataAnchorsSize(), 2); - EXPECT_EQ(node->GetInDataAnchor(1)->GetPeerOutAnchor()->GetOwnerNode()->GetType(), "Const"); - new_weight = FusionTurbo::MutableWeight(node, 1); - ASSERT_NE(nullptr, new_weight); - - EXPECT_EQ(new_weight->GetTensorDesc().GetShape(), ge::GeShape({5, 2, 7, 1})); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginShape(), ge::GeShape({5, 1, 2, 7})); - EXPECT_EQ(new_weight->GetTensorDesc().GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginDataType(), ge::DT_FLOAT); - EXPECT_EQ(new_weight->GetTensorDesc().GetFormat(), ge::FORMAT_NHWC); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginFormat(), ge::FORMAT_NCHW); - data = new_weight->GetData().GetData(); - data_size = new_weight->GetData().size(); - EXPECT_EQ(data_size, 140); - for (size_t i = 0; i < 140; i++) { - EXPECT_EQ(data[i], i + 1); - } -} - - -TEST_F(UTestFusionTurbo, test_case_14_1_1) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - unique_ptr value(new(std::nothrow) int32_t[24]); - auto data_ptr = (uint8_t *) (value.get()); - for (size_t i = 0; i < 96; i++) { - data_ptr[i] = i; - } - - auto input1 = node->GetOpDesc()->MutableInputDesc(1); - input1->SetDataType(ge::DT_INT32); - input1->SetOriginDataType(ge::DT_INT16); - input1->SetFormat(ge::FORMAT_NCHW); - input1->SetOriginFormat(ge::FORMAT_NHWC); - input1->SetShape(ge::GeShape({1, 2, 3, 4})); - input1->SetOriginShape(ge::GeShape({1, 3, 4, 2})); - - WeightInfo w(node, 1, value.get()); - - ASSERT_NE(nullptr, acc.AddWeight(node, "shape", w)); - ASSERT_EQ(node->GetAllInDataAnchorsSize(), 2); - EXPECT_EQ(node->GetInDataAnchor(1)->GetPeerOutAnchor()->GetOwnerNode()->GetType(), "Const"); - auto new_weight = FusionTurbo::MutableWeight(node, 1); - ASSERT_NE(nullptr, new_weight); - - EXPECT_EQ(new_weight->GetTensorDesc().GetShape(), ge::GeShape({1, 2, 3, 4})); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginShape(), ge::GeShape({1, 3, 4, 2})); - EXPECT_EQ(new_weight->GetTensorDesc().GetDataType(), ge::DT_INT32); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginDataType(), ge::DT_INT16); - EXPECT_EQ(new_weight->GetTensorDesc().GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginFormat(), ge::FORMAT_NHWC); - const uint8_t *data = new_weight->GetData().GetData(); - auto data_size = new_weight->GetData().size(); - EXPECT_EQ(data_size, 96); - for (size_t i = 0; i < 96; i++) { - EXPECT_EQ(data[i], i); - } - - w.shape = ge::GeShape({5, 2, 7, 1}); - w.ori_shape = ge::GeShape({5, 1, 2, 7}); - w.datatype = ge::DT_FLOAT16; - w.ori_datatype = ge::DT_FLOAT; - w.format = ge::FORMAT_NHWC; - w.ori_format = ge::FORMAT_NCHW; - - w.total_data_size = 140; - unique_ptr value1(new(std::nothrow) int32_t[140]); - auto data_ptr1 = (uint8_t *) (value1.get()); - for (size_t i = 0; i < 140; i++) { - data_ptr1[i] = i + 1; - } - w.data = (uint8_t *) value1.get(); - /* Update const value and tensor when const node and weight both exist. */ - ASSERT_EQ(nullptr, acc.AddWeight(node, "xxxx", w)); -} - -TEST_F(UTestFusionTurbo, test_case_14_2) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - unique_ptr value(new(std::nothrow) int32_t[24]); - auto data_ptr = (uint8_t *) (value.get()); - for (size_t i = 0; i < 96; i++) { - data_ptr[i] = i; - } - auto const_node = acc.InsertNodeBefore("const_1", "Const", node, 1); - ASSERT_NE(nullptr, const_node); - auto const_out_desc = const_node->GetOpDesc()->MutableOutputDesc(0); - EXPECT_EQ(const_out_desc->GetShape(), ge::GeShape()); - EXPECT_EQ(const_out_desc->GetOriginShape(), ge::GeShape()); - EXPECT_EQ(const_out_desc->GetDataType(), ge::DT_FLOAT); - EXPECT_EQ(const_out_desc->GetOriginDataType(), ge::DT_UNDEFINED); - EXPECT_EQ(const_out_desc->GetFormat(), ge::FORMAT_ND); - EXPECT_EQ(const_out_desc->GetOriginFormat(), ge::FORMAT_ND); - - auto weight = FusionTurbo::MutableWeight(node, 1); - ASSERT_NE(nullptr, weight); - auto &weight_tensor = weight->GetTensorDesc(); - EXPECT_EQ(weight_tensor.GetShape(), ge::GeShape()); - EXPECT_EQ(weight_tensor.GetOriginShape(), ge::GeShape()); - EXPECT_EQ(weight_tensor.GetDataType(), ge::DT_FLOAT); - EXPECT_EQ(weight_tensor.GetOriginDataType(), ge::DT_UNDEFINED); - EXPECT_EQ(weight_tensor.GetFormat(), ge::FORMAT_ND); - EXPECT_EQ(weight_tensor.GetOriginFormat(), ge::FORMAT_ND); - const uint8_t *data = weight->GetData().GetData(); - auto data_size = weight->GetData().size(); - EXPECT_EQ(data_size, 0); - EXPECT_NE(data, nullptr); - - - auto input1 = node->GetOpDesc()->MutableInputDesc(1); - input1->SetDataType(ge::DT_INT32); - input1->SetOriginDataType(ge::DT_INT16); - input1->SetFormat(ge::FORMAT_NCHW); - input1->SetOriginFormat(ge::FORMAT_NHWC); - input1->SetShape(ge::GeShape({1, 2, 3, 4})); - input1->SetOriginShape(ge::GeShape({1, 3, 4, 2})); - - WeightInfo w(node, 1, value.get()); - ASSERT_NE(nullptr, acc.AddWeight(node, 1, w)); - ASSERT_EQ(node->GetAllInDataAnchorsSize(), 2); - EXPECT_EQ(node->GetInDataAnchor(1)->GetPeerOutAnchor()->GetOwnerNode()->GetType(), "Const"); - auto new_weight = FusionTurbo::MutableWeight(node, 1); - ASSERT_NE(nullptr, new_weight); - - auto &new_weight_tensor = new_weight->GetTensorDesc(); - EXPECT_EQ(new_weight_tensor.GetShape(), ge::GeShape({1, 2, 3, 4})); - EXPECT_EQ(new_weight_tensor.GetOriginShape(), ge::GeShape({1, 3, 4, 2})); - EXPECT_EQ(new_weight_tensor.GetDataType(), ge::DT_INT32); - EXPECT_EQ(new_weight_tensor.GetOriginDataType(), ge::DT_INT16); - EXPECT_EQ(new_weight_tensor.GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(new_weight_tensor.GetOriginFormat(), ge::FORMAT_NHWC); - data = new_weight->GetData().GetData(); - data_size = new_weight->GetData().size(); - EXPECT_EQ(data_size, 96); - for (size_t i = 0; i < 96; i++) { - EXPECT_EQ(data[i], i); - } -} - -TEST_F(UTestFusionTurbo, test_case_15) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - auto relu = GetNode(graph, "relu"); - Relations src_list = {{relu, 0}}; - acc.LinkInput(src_list, node, UPDATE_THIS); - - auto cast2 = GetNode(graph, "cast2"); - Relations dst_list = {{cast2, 0}}; - acc.BreakInput(cast2, {0}); - Status ret = acc.LinkOutput(dst_list, node, UPDATE_THIS); - EXPECT_EQ(ret, SUCCESS); - - unique_ptr value(new(std::nothrow) int32_t[24]); - auto data_ptr = (uint8_t *) (value.get()); - for (size_t i = 0; i < 96; i++) { - data_ptr[i] = i; - } - WeightInfo w = {ge::GeShape({1, 2, 3, 4}), ge::DT_INT32, ge::FORMAT_NCHW, value.get()}; - WeightInfo w2 = {ge::GeShape({1, 3, 2, 4}), ge::DT_INT32, ge::FORMAT_NCHW, value.get()}; - WeightInfo w3 = {ge::GeShape({4, 1, 3, 2}), ge::DT_INT32, ge::FORMAT_NCHW, value.get()}; - std::vector weight_all = {std::move(w), std::move(w2), std::move(w3)}; - - auto const_nodes = acc.AddWeights(node, weight_all); - ASSERT_EQ(const_nodes.size(), 3); - - ASSERT_EQ(node->GetAllInDataAnchorsSize(), 5); - EXPECT_EQ(node->GetInDataAnchor(2)->GetPeerOutAnchor()->GetOwnerNode()->GetType(), "Const"); - EXPECT_EQ(node->GetInDataAnchor(3)->GetPeerOutAnchor()->GetOwnerNode()->GetType(), "Const"); - EXPECT_EQ(node->GetInDataAnchor(4)->GetPeerOutAnchor()->GetOwnerNode()->GetType(), "Const"); - - for (size_t i = 2; i < 5; i++) { - auto new_weight = FusionTurbo::MutableWeight(node, i); - ASSERT_NE(nullptr, new_weight); - if (i == 2) { - EXPECT_EQ(new_weight->GetTensorDesc().GetShape(), ge::GeShape({1, 2, 3, 4})); - } else if (i == 3) { - EXPECT_EQ(new_weight->GetTensorDesc().GetShape(), ge::GeShape({1, 3, 2, 4})); - } else { - EXPECT_EQ(new_weight->GetTensorDesc().GetShape(), ge::GeShape({4, 1, 3, 2})); - } - - EXPECT_EQ(new_weight->GetTensorDesc().GetDataType(), ge::DT_INT32); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginDataType(), ge::DT_INT32); - EXPECT_EQ(new_weight->GetTensorDesc().GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginFormat(), ge::FORMAT_NCHW); - const uint8_t *data = new_weight->GetData().GetData(); - auto data_size = new_weight->GetData().size(); - EXPECT_EQ(data_size, 96); - for (size_t j = 0; j < 96; j++) { - EXPECT_EQ(data[j], j); - } - } -} - - -TEST_F(UTestFusionTurbo, test_case_15_1) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - auto relu = GetNode(graph, "relu"); - Relations src_list = {{relu, 0}}; - acc.LinkInput(src_list, node, UPDATE_THIS); - - auto cast2 = GetNode(graph, "cast2"); - Relations dst_list = {{cast2, 0}}; - acc.BreakInput(cast2, {0}); - Status ret = acc.LinkOutput(dst_list, node, UPDATE_THIS); - EXPECT_EQ(ret, SUCCESS); - - unique_ptr value(new(std::nothrow) int32_t[24]); - auto data_ptr = (uint8_t *) (value.get()); - for (size_t i = 0; i < 96; i++) { - data_ptr[i] = i; - } - WeightInfo w = {ge::GeShape({1, 2, 3, 4}), ge::DT_INT32, ge::FORMAT_NCHW, value.get()}; - WeightInfo w2 = {ge::GeShape({1, 3, 2, 4}), ge::DT_INT32, ge::FORMAT_NCHW, value.get()}; - WeightInfo w3 = {ge::GeShape({4, 1, 3, 2}), ge::DT_INT32, ge::FORMAT_NCHW, value.get()}; - acc.AddWeight(node, 1, w); - std::vector weight_all = {std::move(w), std::move(w2), std::move(w3)}; - - auto const_nodes = acc.AddWeights(node, weight_all); - ASSERT_EQ(const_nodes.size(), 3); - - ASSERT_EQ(node->GetAllInDataAnchorsSize(), 5); - EXPECT_EQ(node->GetInDataAnchor(2)->GetPeerOutAnchor()->GetOwnerNode()->GetType(), "Const"); - EXPECT_EQ(node->GetInDataAnchor(3)->GetPeerOutAnchor()->GetOwnerNode()->GetType(), "Const"); - EXPECT_EQ(node->GetInDataAnchor(4)->GetPeerOutAnchor()->GetOwnerNode()->GetType(), "Const"); - auto all_weights = ge::OpDescUtils::MutableWeights(node); - ASSERT_EQ(all_weights.size(), 4); - - for (size_t i = 1; i < 5; i++) { - auto new_weight = FusionTurbo::MutableWeight(node, i); - ASSERT_NE(nullptr, new_weight); - if (i == 1 || i == 2) { - EXPECT_EQ(new_weight->GetTensorDesc().GetShape(), ge::GeShape({1, 2, 3, 4})); - } else if (i == 3) { - EXPECT_EQ(new_weight->GetTensorDesc().GetShape(), ge::GeShape({1, 3, 2, 4})); - } else { - EXPECT_EQ(new_weight->GetTensorDesc().GetShape(), ge::GeShape({4, 1, 3, 2})); - } - - EXPECT_EQ(new_weight->GetTensorDesc().GetDataType(), ge::DT_INT32); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginDataType(), ge::DT_INT32); - EXPECT_EQ(new_weight->GetTensorDesc().GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginFormat(), ge::FORMAT_NCHW); - const uint8_t *data = new_weight->GetData().GetData(); - auto data_size = new_weight->GetData().size(); - EXPECT_EQ(data_size, 96); - for (size_t j = 0; j < 96; j++) { - EXPECT_EQ(data[j], j); - } - } -} - -TEST_F(UTestFusionTurbo, test_case_15_3) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - auto relu = GetNode(graph, "relu"); - Relations src_list = {{relu, 0}}; - acc.LinkInput(src_list, node, UPDATE_THIS); - - auto cast2 = GetNode(graph, "cast2"); - Relations dst_list = {{cast2, 0}}; - acc.BreakInput(cast2, {0}); - Status ret = acc.LinkOutput(dst_list, node, UPDATE_THIS); - EXPECT_EQ(ret, SUCCESS); - - unique_ptr value(new(std::nothrow) int32_t[24]); - auto data_ptr = (uint8_t *) (value.get()); - for (size_t i = 0; i < 96; i++) { - data_ptr[i] = i; - } - - auto node_input_1 = node->GetOpDesc()->MutableInputDesc(1); - auto weight_shape = ge::GeShape({1, 2, 3, 4}); - node_input_1->SetOriginShape(weight_shape); - node_input_1->SetShape(weight_shape); - node_input_1->SetDataType(ge::DT_INT32); - node_input_1->SetOriginDataType(ge::DT_INT32); - node_input_1->SetFormat(ge::FORMAT_NCHW); - node_input_1->SetOriginFormat(ge::FORMAT_NCHW); - - WeightInfo w = {*node_input_1, value.get()}; - WeightInfo w2 = {ge::GeShape({1, 3, 2, 4}), ge::GeShape({1, 3, 2, 4}), - ge::DT_INT32, ge::DT_INT32, ge::FORMAT_NCHW, ge::FORMAT_NCHW, value.get()}; - WeightInfo w3 = {ge::GeShape({4, 1, 3, 2}), ge::DT_INT32, ge::FORMAT_NCHW, value.get()}; - acc.AddWeight(node, 1, w); - std::vector weight_all = {std::move(w), std::move(w2), std::move(w3)}; - - auto const_nodes = acc.AddWeights(node, weight_all); - ASSERT_EQ(const_nodes.size(), 3); - - ASSERT_EQ(node->GetAllInDataAnchorsSize(), 5); - EXPECT_EQ(node->GetInDataAnchor(2)->GetPeerOutAnchor()->GetOwnerNode()->GetType(), "Const"); - EXPECT_EQ(node->GetInDataAnchor(3)->GetPeerOutAnchor()->GetOwnerNode()->GetType(), "Const"); - EXPECT_EQ(node->GetInDataAnchor(4)->GetPeerOutAnchor()->GetOwnerNode()->GetType(), "Const"); - auto all_weights = ge::OpDescUtils::MutableWeights(node); - ASSERT_EQ(all_weights.size(), 4); - - for (size_t i = 1; i < 5; i++) { - auto new_weight = FusionTurbo::MutableWeight(node, i); - ASSERT_NE(nullptr, new_weight); - if (i == 1 || i == 2) { - EXPECT_EQ(new_weight->GetTensorDesc().GetShape(), ge::GeShape({1, 2, 3, 4})); - } else if (i == 3) { - EXPECT_EQ(new_weight->GetTensorDesc().GetShape(), ge::GeShape({1, 3, 2, 4})); - } else { - EXPECT_EQ(new_weight->GetTensorDesc().GetShape(), ge::GeShape({4, 1, 3, 2})); - } - - EXPECT_EQ(new_weight->GetTensorDesc().GetDataType(), ge::DT_INT32); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginDataType(), ge::DT_INT32); - EXPECT_EQ(new_weight->GetTensorDesc().GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginFormat(), ge::FORMAT_NCHW); - const uint8_t *data = new_weight->GetData().GetData(); - auto data_size = new_weight->GetData().size(); - EXPECT_EQ(data_size, 96); - for (size_t j = 0; j < 96; j++) { - EXPECT_EQ(data[j], j); - } - } -} - -TEST_F(UTestFusionTurbo, test_case_15_4) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - auto relu = GetNode(graph, "relu"); - Relations src_list = {{relu, 0}}; - acc.LinkInput(src_list, node, UPDATE_THIS); - - auto cast2 = GetNode(graph, "cast2"); - Relations dst_list = {{cast2, 0}}; - acc.BreakInput(cast2, {0}); - Status ret = acc.LinkOutput(dst_list, node, UPDATE_THIS); - EXPECT_EQ(ret, SUCCESS); - - unique_ptr value(new(std::nothrow) int32_t[24]); - auto data_ptr = (uint8_t *) (value.get()); - for (size_t i = 0; i < 96; i++) { - data_ptr[i] = i; - } - - auto node_input_1 = node->GetOpDesc()->MutableInputDesc(1); - auto weight_shape = ge::GeShape({1, 2, 3, 4}); - node_input_1->SetOriginShape(weight_shape); - node_input_1->SetShape(weight_shape); - node_input_1->SetDataType(ge::DT_INT32); - node_input_1->SetOriginDataType(ge::DT_INT32); - node_input_1->SetFormat(ge::FORMAT_NCHW); - node_input_1->SetOriginFormat(ge::FORMAT_NCHW); - - WeightInfo w = {*node_input_1, value.get()}; - WeightInfo w2 = {ge::GeShape({1, 3, 2, 4}), ge::GeShape({1, 3, 2, 4}), - ge::DT_INT32, ge::DT_INT32, ge::FORMAT_NCHW, ge::FORMAT_NCHW, value.get()}; - WeightInfo w3 = {ge::GeShape({4, 1, 3, 2}), ge::DT_INT32, ge::FORMAT_NCHW, value.get()}; - acc.AddWeight(node, 1, w); - std::vector weight_all = {std::move(w), std::move(w2), std::move(w3)}; - - auto const_nodes = acc.AddWeights(node, weight_all); - ASSERT_EQ(const_nodes.size(), 3); - - ASSERT_EQ(node->GetAllInDataAnchorsSize(), 5); - auto const_2 = FusionTurbo::GetPeerOutNode(node, 2); - auto const_2_peer_in = FusionTurbo::GetPeerInNodes(const_2, 0); - ASSERT_EQ(const_2_peer_in.size(), 1); - auto node_temp = const_2_peer_in.at(0); - - EXPECT_EQ(node_temp, node); - EXPECT_EQ(FusionTurbo::CheckConnected(const_2, node), true); - EXPECT_EQ(FusionTurbo::CheckConnected(const_2, node, 0), true); - EXPECT_EQ(const_2->GetType(), "Const"); - EXPECT_EQ(FusionTurbo::GetPeerOutNode(node, 3)->GetType(), "Const"); - EXPECT_EQ(FusionTurbo::GetPeerOutNode(node, 4)->GetType(), "Const"); - auto all_weights = ge::OpDescUtils::MutableWeights(node); - ASSERT_EQ(all_weights.size(), 4); - - for (size_t i = 1; i < 5; i++) { - auto new_weight = FusionTurbo::MutableWeight(node, i); - ASSERT_NE(nullptr, new_weight); - if (i == 1 || i == 2) { - EXPECT_EQ(new_weight->GetTensorDesc().GetShape(), ge::GeShape({1, 2, 3, 4})); - } else if (i == 3) { - EXPECT_EQ(new_weight->GetTensorDesc().GetShape(), ge::GeShape({1, 3, 2, 4})); - } else { - EXPECT_EQ(new_weight->GetTensorDesc().GetShape(), ge::GeShape({4, 1, 3, 2})); - } - - EXPECT_EQ(new_weight->GetTensorDesc().GetDataType(), ge::DT_INT32); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginDataType(), ge::DT_INT32); - EXPECT_EQ(new_weight->GetTensorDesc().GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(new_weight->GetTensorDesc().GetOriginFormat(), ge::FORMAT_NCHW); - const uint8_t *data = new_weight->GetData().GetData(); - auto data_size = new_weight->GetData().size(); - EXPECT_EQ(data_size, 96); - for (size_t j = 0; j < 96; j++) { - EXPECT_EQ(data[j], j); - } - } -} - -TEST_F(UTestFusionTurbo, test_case_16_1) { - /* - * input data input - * | \ \ / - * cast \ add input1 - * | / | / - * add1 netoutput - * | - * partioncall input2 - * / \ / - * partionout add2 - * / \ - * netout netout1 - */ - auto graph = CreateGraphParentAndSub(); - FusionTurbo acc(graph); - auto movnode = graph->FindNode("add2"); - acc.GraphNodeUpMigration(movnode, 0); - /* - * input data input input1 data1 - * | \ \ / \ / - * cast \ add add2 - * | / \ / - * add1 input2 \ / - * | / \ / - * partioncall netouput - * / \ \ - * partionout netout netout1 - */ - auto aftermovnode = graph->FindNode("add2"); - EXPECT_EQ(aftermovnode, nullptr); - - auto partioncall_node = graph->FindNode("partioncall"); - EXPECT_NE(partioncall_node, nullptr); - auto sub_graph = ge::NodeUtils::GetSubgraph(*partioncall_node, 0); - - auto add2 = sub_graph->FindNode("add2"); - EXPECT_NE(add2, nullptr); - - auto in_nodes = add2->GetInDataNodes(); - ASSERT_EQ(in_nodes.size(), 2); - EXPECT_EQ(in_nodes.at(0)->GetType(), "NetOutInput"); - EXPECT_EQ(in_nodes.at(1)->GetType(), "Data"); - - auto out_nodes = add2->GetOutDataNodes(); - EXPECT_EQ(out_nodes.at(0)->GetType(), "NetOutput"); - - auto data1 = in_nodes.at(1); - int64_t index; - ge::AttrUtils::GetInt(data1->GetOpDesc(), ge::ATTR_NAME_PARENT_NODE_INDEX, index); - EXPECT_EQ(index, 1); - EXPECT_EQ(partioncall_node->GetInDataAnchor(1)->GetPeerOutAnchor()->GetOwnerNode()->GetName(), "other1"); -} - -TEST_F(UTestFusionTurbo, test_case_16_2) { - /* - * input data input - * | \ \ / - * cast \ add input1 - * | / | / - * add1 netoutput - * | - * partioncall input2 - * / \ / - * partionout add2 - * / \ - * netout netout1 - */ - auto graph = CreateGraphParentAndSub(); - FusionTurbo acc(graph); - auto movnode = graph->FindNode("add1"); - acc.GraphNodeDownMigration(movnode, 0); - /* - * input data data1 - * | \ \ / - * cast \ add1 input - * | \ \ / - * | / add input1 - * | / | / - * partioncall input2 netout - * / \ / - * partionout add2 - * / \ - * netout netout1 - */ - auto aftermovnode = graph->FindNode("add1"); - EXPECT_EQ(aftermovnode, nullptr); - auto partioncall_node = graph->FindNode("partioncall"); - EXPECT_NE(partioncall_node, nullptr); - auto sub_graph = ge::NodeUtils::GetSubgraph(*partioncall_node, 0); - - auto add1 = sub_graph->FindNode("add1"); - EXPECT_NE(add1, nullptr); - - auto in_nodes = add1->GetInDataNodes(); - EXPECT_EQ(in_nodes.at(0)->GetType(), "Data"); - EXPECT_EQ(in_nodes.at(1)->GetType(), "Data"); - - auto out_nodes = add1->GetOutDataNodes(); - EXPECT_EQ(out_nodes.at(0)->GetType(), "Add"); -} - -TEST_F(UTestFusionTurbo, test_case_17) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - auto relu = GetNode(graph, "relu"); - auto cast2 = GetNode(graph, "cast2"); - Relations dst_list = {{relu, 0, PEER}}; - Relations src_list = {{relu, 0}}; - - acc.LinkInput(src_list, node); - - Status ret = acc.LinkOutput(dst_list, node, UPDATE_THIS); - EXPECT_EQ(ret, SUCCESS); - EXPECT_EQ(node->GetName(), name); - EXPECT_EQ(node->GetType(), type); - vector dims = {1, 4, 64, 64}; - auto input = node->GetOpDesc()->MutableInputDesc(0); - EXPECT_EQ(input->GetShape().GetDims(), dims); - EXPECT_EQ(input->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(input->GetDataType(), ge::DT_FLOAT); - - auto output = node->GetOpDesc()->MutableOutputDesc(0); - EXPECT_EQ(output->GetShape().GetDims(), dims); - EXPECT_EQ(output->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(output->GetDataType(), ge::DT_FLOAT); - - EXPECT_EQ(node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode()->GetName(), "relu"); - EXPECT_EQ(node->GetOutDataAnchor(0)->GetPeerInDataAnchors().size(), 1); - EXPECT_EQ(node->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetOwnerNode()->GetName(), "cast2"); -} - -TEST_F(UTestFusionTurbo, test_case_17_1) { - auto graph = CreateGraphSingleInAndOut(); - FusionTurbo acc(graph); - string name = "transpose"; - string type = "Transpose"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - auto relu = GetNode(graph, "relu"); - auto cast2 = GetNode(graph, "cast2"); - Relations dst_list = {{relu, 0, PEER}}; - Relations src_list = {{relu, 0}}; - - Status ret = acc.LinkOutput(dst_list, node, UPDATE_THIS); - - acc.LinkInput(src_list, node); - - - EXPECT_EQ(ret, SUCCESS); - EXPECT_EQ(node->GetName(), name); - EXPECT_EQ(node->GetType(), type); - vector dims = {1, 4, 64, 64}; - auto input = node->GetOpDesc()->MutableInputDesc(0); - EXPECT_EQ(input->GetShape().GetDims(), dims); - EXPECT_EQ(input->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(input->GetDataType(), ge::DT_FLOAT); - - auto output = node->GetOpDesc()->MutableOutputDesc(0); - EXPECT_EQ(output->GetShape().GetDims(), dims); - EXPECT_EQ(output->GetFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(output->GetDataType(), ge::DT_FLOAT); - EXPECT_EQ(node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode()->GetName(), "relu"); - EXPECT_EQ(node->GetOutDataAnchor(0)->GetPeerInDataAnchors().size(), 1); - EXPECT_EQ(node->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetOwnerNode()->GetName(), "cast2"); -} -} diff --git a/tests/ut/register/testcase/graph_fusion_turbo_unittest2.cc b/tests/ut/register/testcase/graph_fusion_turbo_unittest2.cc deleted file mode 100644 index 7d26b3b3c4bf0d9b95b3c76e217bd0d5a2e1e555..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/graph_fusion_turbo_unittest2.cc +++ /dev/null @@ -1,1009 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include "graph/compute_graph.h" -#include "graph/graph.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/graph_utils.h" - -#include "register/graph_optimizer/graph_fusion/fusion_pass_manager/fusion_pass_registry.h" -#include "register/graph_optimizer/graph_fusion/graph_fusion_pass_base.h" -#include "register/graph_optimizer/fusion_common/pattern_fusion_base_pass.h" -#include "register/graph_optimizer/fusion_common/fusion_turbo.h" - -#include "external/graph/operator_factory.h" -#include "external/graph/operator_reg.h" -#include "graph/operator_factory_impl.h" -#include "graph/debug/ge_log.h" - -using namespace testing; -using namespace ge; -using namespace fe; - -namespace fe { -REG_OP(Const) - .OUTPUT(y, - TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, - DT_UINT64, DT_BOOL, DT_DOUBLE})) - .ATTR(value, Tensor, Tensor()) - .OP_END_FACTORY_REG(Const); - -REG_OP(Transpose) - .INPUT(x, - TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, - DT_UINT64, DT_BOOL, DT_DOUBLE})) - .INPUT(shape, TensorType({DT_INT32, DT_INT64})) - .OUTPUT(y, - TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, - DT_UINT64, DT_BOOL, DT_DOUBLE})) - .ATTR(axis, Int, 0) - .ATTR(num_axes, Int, -1) - .OP_END_FACTORY_REG(Transpose); - -REG_OP(Add) - .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, - DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, - DT_COMPLEX64, DT_STRING})) - .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, - DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, - DT_COMPLEX64, DT_STRING})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, - DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, - DT_COMPLEX64, DT_STRING})) - .OP_END_FACTORY_REG(Add) - -REG_OP(MultiAdd) - .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, - DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, - DT_COMPLEX64, DT_STRING})) - .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, - DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, - DT_COMPLEX64, DT_STRING})) - .INPUT(x3, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, - DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, - DT_COMPLEX64, DT_STRING})) - .INPUT(x4, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, - DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, - DT_COMPLEX64, DT_STRING})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, - DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, - DT_COMPLEX64, DT_STRING})) - .OP_END_FACTORY_REG(MultiAdd) - -REG_OP(Relu) - .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE, - DT_INT8, DT_INT32, DT_INT16, DT_INT64, - DT_UINT8, DT_UINT16, DT_QINT8})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE, - DT_INT8, DT_INT32, DT_INT16, DT_INT64, - DT_UINT8, DT_UINT16, DT_QINT8})) - .OP_END_FACTORY_REG(Relu) - -REG_OP(End) - .INPUT(x, TensorType::ALL()) - .OUTPUT(y, TensorType::ALL()) - .ATTR(peerIndex, Int, 0) - .ATTR(parentOpType, String, "") - .OP_END_FACTORY_REG(End) - -REG_OP(LarsV2Update) - .INPUT(w, TensorType(DT_FLOAT)) - .INPUT(g, TensorType(DT_FLOAT)) - .INPUT(w_square_sum, TensorType(DT_FLOAT)) - .INPUT(g_square_sum, TensorType(DT_FLOAT)) - .INPUT(weight_decay, TensorType(DT_FLOAT)) - .INPUT(learning_rate, TensorType(DT_FLOAT)) - .OUTPUT(g_new, TensorType(DT_FLOAT)) - .ATTR(hyperpara, Float, 0.001) - .ATTR(epsilon, Float, 0.00001) - .ATTR(use_clip, Bool, false) - .OP_END_FACTORY_REG(LarsV2Update) - -REG_OP(SquareSumAll) - .INPUT(x1, TensorType({DT_FLOAT})) - .INPUT(x2, TensorType({DT_FLOAT})) - .OUTPUT(y1, TensorType({DT_FLOAT})) - .OUTPUT(y2, TensorType({DT_FLOAT})) - .OP_END_FACTORY_REG(SquareSumAll) - -REG_OP(LarsV2) - .INPUT(w, TensorType(DT_FLOAT)) - .INPUT(g, TensorType(DT_FLOAT)) - .INPUT(weight_decay, TensorType(DT_FLOAT)) - .INPUT(learning_rate, TensorType(DT_FLOAT)) - .OUTPUT(g_new, TensorType(DT_FLOAT)) - .ATTR(hyperpara, Float, 0.001) - .ATTR(epsilon, Float, 0.00001) - .ATTR(use_clip, Bool, false) - .OP_END_FACTORY_REG(LarsV2) - -class UTestFusionTurbo2 : public testing::Test { - public: - - protected: - - - void SetUp() { - } - - void TearDown() { - } - - ge::NodePtr GetNode(ComputeGraphPtr &graph, const string &name) { - for (auto &node : graph->GetDirectNode()) { - if (node->GetName() == name) { - return node; - } - } - return nullptr; - } - - ComputeGraphPtr CreateComplexGraph() { - ComputeGraphPtr graph = std::make_shared("test1"); - - OpDescPtr op_desc_relu1 = std::make_shared("relu1", "Relu"); - OpDescPtr op_desc_relu2 = std::make_shared("relu2", "Relu"); - OpDescPtr op_desc_output = std::make_shared("output", "NetOutput"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT16); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - vector dim_b = {1, 4, 64, 64}; - GeShape shape_b(dim_b); - GeTensorDesc tensor_desc_b(shape_b); - tensor_desc_b.SetFormat(FORMAT_NCHW); - tensor_desc_b.SetOriginFormat(FORMAT_NCHW); - tensor_desc_b.SetDataType(DT_FLOAT); - tensor_desc_b.SetOriginDataType(DT_FLOAT); - - op_desc_relu1->AddInputDesc(tensor_desc_a); - op_desc_relu1->AddOutputDesc(tensor_desc_b); - - op_desc_relu2->AddInputDesc(tensor_desc_a); - op_desc_relu2->AddOutputDesc(tensor_desc_b); - - op_desc_output->AddInputDesc(tensor_desc_b); - op_desc_output->AddInputDesc(tensor_desc_b); - - NodePtr node_relu1 = graph->AddNode(op_desc_relu1); - NodePtr node_relu2 = graph->AddNode(op_desc_relu2); - NodePtr node_netoutput = graph->AddNode(op_desc_output); - - GraphUtils::AddEdge(node_relu2->GetOutDataAnchor(0), node_netoutput->GetInDataAnchor(1)); - - FusionTurbo acc(graph); - auto node_add = acc.InsertNodeAfter("add", "Add", node_relu2, 0, 1); - EXPECT_NE(node_add, nullptr); - Relations rl(0, {node_relu1, 0}); - acc.LinkInput(rl, node_add); - - unique_ptr data(new(std::nothrow) int32_t[4096]); - WeightInfo w(tensor_desc_a, data.get()); - acc.AddWeight(node_relu1, 0, w); - acc.AddWeight(node_relu2, 0, w); - return graph; - } - - ComputeGraphPtr CreateComplexGraph2() { - ComputeGraphPtr graph = std::make_shared("test2"); - - OpDescPtr op_desc_relu1 = std::make_shared("relu1", "Relu"); - OpDescPtr op_desc_relu2 = std::make_shared("relu2", "Relu"); - OpDescPtr op_desc_output = std::make_shared("output", "NetOutput"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT16); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - vector dim_b = {1, 4, 64, 64}; - GeShape shape_b(dim_b); - GeTensorDesc tensor_desc_b(shape_b); - tensor_desc_b.SetFormat(FORMAT_NCHW); - tensor_desc_b.SetOriginFormat(FORMAT_NCHW); - tensor_desc_b.SetDataType(DT_FLOAT); - tensor_desc_b.SetOriginDataType(DT_FLOAT); - - op_desc_relu1->AddInputDesc(tensor_desc_a); - op_desc_relu1->AddOutputDesc(tensor_desc_b); - - op_desc_relu2->AddInputDesc(tensor_desc_a); - op_desc_relu2->AddOutputDesc(tensor_desc_b); - - op_desc_output->AddInputDesc(tensor_desc_b); - op_desc_output->AddInputDesc(tensor_desc_b); - - NodePtr node_relu1 = graph->AddNode(op_desc_relu1); - NodePtr node_relu2 = graph->AddNode(op_desc_relu2); - NodePtr node_netoutput = graph->AddNode(op_desc_output); - - GraphUtils::AddEdge(node_relu2->GetOutDataAnchor(0), node_netoutput->GetInDataAnchor(1)); - - FusionTurbo acc(graph); - auto node_add = acc.InsertNodeAfter("add", "Add", node_relu2, 0, 0); - EXPECT_NE(node_add, nullptr); - Relations rl(1, {node_relu1, 0}); - acc.LinkInput(rl, node_add); - - auto relu1_front = acc.InsertNodeBefore("relu1_front", "Relu", node_relu1, 0); - - auto relu2_front = acc.InsertNodeBefore("relu2_front", "Relu", node_relu2, 0); - - unique_ptr data(new(std::nothrow) int32_t[4096]); - WeightInfo w(tensor_desc_a, data.get()); - acc.AddWeight(relu1_front, 0, w); - acc.AddWeight(relu2_front, 0, w); - return graph; - } - - static void DumpGraph(const ge::ComputeGraphPtr graph, string graph_name) { - printf("start to dump graph %s...\n", graph_name.c_str()); - for (ge::NodePtr node : graph->GetAllNodes()) { - printf("node name = %s.\n", node->GetName().c_str()); - for (ge::OutDataAnchorPtr anchor : node->GetAllOutDataAnchors()) { - for (ge::InDataAnchorPtr peer_in_anchor : anchor->GetPeerInDataAnchors()) { - printf(" node name = %s[%d], out data node name = %s[%d].\n", - node->GetName().c_str(), - anchor->GetIdx(), - peer_in_anchor->GetOwnerNode()->GetName().c_str(), - peer_in_anchor->GetIdx()); - } - } - if (node->GetOutControlAnchor() != nullptr) { - for (ge::InControlAnchorPtr peer_in_anchor : node->GetOutControlAnchor()->GetPeerInControlAnchors()) { - printf(" node name = %s, out control node name = %s.\n", node->GetName().c_str(), - peer_in_anchor->GetOwnerNode()->GetName().c_str()); - } - } - } - } - -}; - -TEST_F(UTestFusionTurbo2, test_case_01) { - auto graph = CreateComplexGraph(); - FusionTurbo acc(graph); - string name = "add_new"; - string type = "Add"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - auto relu1 = GetNode(graph, "relu1"); - auto relu2 = GetNode(graph, "relu2"); - auto add = GetNode(graph, "add"); - auto out = GetNode(graph, "output"); - - auto relu1_input = relu1->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - auto relu2_input = relu2->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - - Relations input_relations = {{0, {relu1_input, 0}}, - {1, {relu2_input, 0}}}; - Relations output_relations = {0, {out, 0}}; - Status ret = acc.MultiInOne(node, input_relations, output_relations, {relu1, relu2, add}, true); - EXPECT_EQ(ret, SUCCESS); - - auto relu1_out_nodes = relu1_input->GetOutDataNodes(); - auto relu2_out_nodes = relu2_input->GetOutDataNodes(); - ASSERT_EQ(relu1_out_nodes.size(), 1); - ASSERT_EQ(relu2_out_nodes.size(), 1); - EXPECT_EQ(relu1_out_nodes.at(0)->GetName(), "add_new"); - - EXPECT_EQ(relu2_out_nodes.at(0)->GetName(), "add_new"); - auto out_in_nodes = out->GetInDataNodes(); - EXPECT_EQ(out_in_nodes.size(), 1); - EXPECT_EQ(out_in_nodes.at(0)->GetName(), "add_new"); - EXPECT_EQ(graph->GetDirectNodesSize(), 4); -} - -TEST_F(UTestFusionTurbo2, test_case_01_1) { - auto graph = CreateComplexGraph(); - FusionTurbo acc(graph); - string name = "add_new"; - string type = "Add"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - auto relu1 = GetNode(graph, "relu1"); - auto relu2 = GetNode(graph, "relu2"); - auto add = GetNode(graph, "add"); - auto out = GetNode(graph, "output"); - - auto relu1_input = relu1->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - auto relu2_input = relu2->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - - Relations input_relations = {{0, {relu1_input, 0}}, - {1, {relu2_input, 0}}}; - Relations output_relations = {0, {out, 0}}; - Status ret = acc.MultiInOne(node, input_relations, output_relations, {}, false); - EXPECT_EQ(ret, SUCCESS); - - auto relu1_out_nodes = relu1_input->GetOutDataNodes(); - auto relu2_out_nodes = relu2_input->GetOutDataNodes(); - ASSERT_EQ(relu1_out_nodes.size(), 2); - ASSERT_EQ(relu2_out_nodes.size(), 2); - EXPECT_EQ(relu1_out_nodes.at(0)->GetName(), "relu1"); - EXPECT_EQ(relu1_out_nodes.at(1)->GetName(), "add_new"); - - EXPECT_EQ(relu2_out_nodes.at(0)->GetName(), "relu2"); - EXPECT_EQ(relu2_out_nodes.at(1)->GetName(), "add_new"); - - auto out_in_nodes = out->GetInDataNodes(); - EXPECT_EQ(out_in_nodes.size(), 2); - EXPECT_EQ(out_in_nodes.at(0)->GetName(), "add_new"); - EXPECT_EQ(out_in_nodes.at(1)->GetName(), "add"); - - auto add_in_nodes = add->GetInDataNodes(); - EXPECT_EQ(add_in_nodes.size(), 2); - EXPECT_EQ(add_in_nodes.at(0)->GetName(), "relu1"); - EXPECT_EQ(add_in_nodes.at(1)->GetName(), "relu2"); -} - -TEST_F(UTestFusionTurbo2, test_case_01_2) { - auto graph = CreateComplexGraph(); - FusionTurbo acc(graph); - string name = "add_new"; - string type = "Add"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - auto relu1 = GetNode(graph, "relu1"); - auto relu2 = GetNode(graph, "relu2"); - auto add = GetNode(graph, "add"); - auto out = GetNode(graph, "output"); - - auto relu1_input = relu1->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - auto relu2_input = relu2->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - - Relations input_relations = {{0, {relu1_input, 0}}, - {0, {relu1_input, 0}}, - {1, {relu2_input, 0}}}; - Relations output_relations = {0, {out, 0}}; - Status ret = acc.MultiInOne(node, input_relations, output_relations, {}, false); - EXPECT_EQ(ret, SUCCESS); - - auto relu1_out_nodes = relu1_input->GetOutDataNodes(); - auto relu2_out_nodes = relu2_input->GetOutDataNodes(); - ASSERT_EQ(relu1_out_nodes.size(), 2); - ASSERT_EQ(relu2_out_nodes.size(), 2); - EXPECT_EQ(relu1_out_nodes.at(0)->GetName(), "relu1"); - EXPECT_EQ(relu1_out_nodes.at(1)->GetName(), "add_new"); - - EXPECT_EQ(relu2_out_nodes.at(0)->GetName(), "relu2"); - EXPECT_EQ(relu2_out_nodes.at(1)->GetName(), "add_new"); - - auto out_in_nodes = out->GetInDataNodes(); - EXPECT_EQ(out_in_nodes.size(), 2); - EXPECT_EQ(out_in_nodes.at(0)->GetName(), "add_new"); - EXPECT_EQ(out_in_nodes.at(1)->GetName(), "add"); - - auto add_in_nodes = add->GetInDataNodes(); - EXPECT_EQ(add_in_nodes.size(), 2); - EXPECT_EQ(add_in_nodes.at(0)->GetName(), "relu1"); - EXPECT_EQ(add_in_nodes.at(1)->GetName(), "relu2"); -} - -TEST_F(UTestFusionTurbo2, test_case_01_3) { - auto graph = CreateComplexGraph(); - FusionTurbo acc(graph); - string name = "add_new"; - string type = "Add"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - auto relu1 = GetNode(graph, "relu1"); - auto relu2 = GetNode(graph, "relu2"); - auto add = GetNode(graph, "add"); - auto out = GetNode(graph, "output"); - - auto relu1_input = relu1->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - auto relu2_input = relu2->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - - Relations input_relations = {{0, {relu1_input, 0}}, - {2, {relu1_input, 0}}, - {1, {relu2_input, 0}}}; - Relations output_relations = {0, {out, 0}}; - Status ret = acc.MultiInOne(node, input_relations, output_relations, {}, false); - EXPECT_EQ(ret, FAILED); - - auto relu1_out_nodes = relu1_input->GetOutDataNodes(); - auto relu2_out_nodes = relu2_input->GetOutDataNodes(); - ASSERT_EQ(relu1_out_nodes.size(), 1); - ASSERT_EQ(relu2_out_nodes.size(), 1); - EXPECT_EQ(relu1_out_nodes.at(0)->GetName(), "relu1"); - - EXPECT_EQ(relu2_out_nodes.at(0)->GetName(), "relu2"); - - auto out_in_nodes = out->GetInDataNodes(); - ASSERT_EQ(out_in_nodes.size(), 1); - EXPECT_EQ(out_in_nodes.at(0)->GetName(), "add"); - - auto add_in_nodes = add->GetInDataNodes(); - ASSERT_EQ(add_in_nodes.size(), 2); - EXPECT_EQ(add_in_nodes.at(0)->GetName(), "relu1"); - EXPECT_EQ(add_in_nodes.at(1)->GetName(), "relu2"); -} - - -TEST_F(UTestFusionTurbo2, test_case_01_4) { - auto graph = CreateComplexGraph(); - FusionTurbo acc(graph); - string name = "add_new"; - string type = "Add"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - auto relu1 = GetNode(graph, "relu1"); - auto relu2 = GetNode(graph, "relu2"); - auto add = GetNode(graph, "add"); - auto out = GetNode(graph, "output"); - - auto relu1_input = relu1->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - auto relu2_input = relu2->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - - Relations input_relations = {{0, {relu1_input, 0}}, - {1, {relu2_input, 0}}}; - Relations output_relations = {{0, {out, 0}}, - {1, {out, 1}}}; - Status ret = acc.MultiInOne(node, input_relations, output_relations, {}, false); - EXPECT_EQ(ret, FAILED); - - auto relu1_out_nodes = relu1_input->GetOutDataNodes(); - auto relu2_out_nodes = relu2_input->GetOutDataNodes(); - ASSERT_EQ(relu1_out_nodes.size(), 1); - ASSERT_EQ(relu2_out_nodes.size(), 1); - EXPECT_EQ(relu1_out_nodes.at(0)->GetName(), "relu1"); - - EXPECT_EQ(relu2_out_nodes.at(0)->GetName(), "relu2"); - - auto out_in_nodes = out->GetInDataNodes(); - ASSERT_EQ(out_in_nodes.size(), 1); - EXPECT_EQ(out_in_nodes.at(0)->GetName(), "add"); - - auto add_in_nodes = add->GetInDataNodes(); - ASSERT_EQ(add_in_nodes.size(), 2); - EXPECT_EQ(add_in_nodes.at(0)->GetName(), "relu1"); - EXPECT_EQ(add_in_nodes.at(1)->GetName(), "relu2"); - EXPECT_EQ(graph->GetDirectNodesSize(), 7); -} - - -TEST_F(UTestFusionTurbo2, test_case_01_5) { - auto graph = CreateComplexGraph(); - FusionTurbo acc(graph); - string name = "add_new"; - string type = "Add"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - auto relu1 = GetNode(graph, "relu1"); - auto relu2 = GetNode(graph, "relu2"); - auto add = GetNode(graph, "add"); - auto out = GetNode(graph, "output"); - - auto relu1_input = relu1->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - auto relu2_input = relu2->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - - Relations input_relations = {{0, {relu1_input, 0}}, - {1, {relu2_input, 0}}}; - Relations output_relations = {0, {out, 0}}; - ge::GraphUtils::AddEdge(relu1->GetOutControlAnchor(), add->GetInControlAnchor()); - ge::GraphUtils::AddEdge(node->GetOutControlAnchor(), relu1->GetInControlAnchor()); - // This is a very special case! node is in old_nodes!! - Status ret = acc.MultiInOne(node, input_relations, output_relations, {relu1, relu2, add, node}, false); - EXPECT_EQ(ret, SUCCESS); - - auto relu1_out_nodes = relu1_input->GetOutDataNodes(); - auto relu2_out_nodes = relu2_input->GetOutDataNodes(); - ASSERT_EQ(relu1_out_nodes.size(), 2); - ASSERT_EQ(relu2_out_nodes.size(), 2); - EXPECT_EQ(relu1_out_nodes.at(0)->GetName(), "relu1"); - EXPECT_EQ(relu1_out_nodes.at(1)->GetName(), "add_new"); - - EXPECT_EQ(relu1_out_nodes.at(0)->GetOutControlNodes().size(), 3); - EXPECT_EQ(relu1_out_nodes.at(1)->GetOutControlNodes().size(), 4); - - EXPECT_EQ(relu2_out_nodes.at(0)->GetName(), "relu2"); - EXPECT_EQ(relu2_out_nodes.at(1)->GetName(), "add_new"); - auto out_in_nodes = out->GetInDataNodes(); - EXPECT_EQ(out_in_nodes.size(), 2); - EXPECT_EQ(out_in_nodes.at(0)->GetName(), "add_new"); - EXPECT_EQ(graph->GetDirectNodesSize(), 7); -} - -TEST_F(UTestFusionTurbo2, test_case_2) { - auto graph = CreateComplexGraph2(); - FusionTurbo acc(graph); - string name = "add_new"; - string type = "Add"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - auto relu1 = GetNode(graph, "relu1"); - auto relu2 = GetNode(graph, "relu2"); - auto relu2_front = GetNode(graph, "relu2_front"); - auto add = GetNode(graph, "add"); - auto out = GetNode(graph, "output"); - - auto relu1_input = relu1->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - auto relu2_front_input = relu2_front->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - - Relations input_relations = {{0, {relu1_input, 0}}, - {1, {relu2_front_input, 0}}}; - Relations output_relations = {{0, {add, 0}}, - {0, {relu2, 0}}}; - Status ret = acc.MultiInOne(node, input_relations, output_relations, {relu1, relu2_front}, true); - EXPECT_EQ(ret, SUCCESS); - - auto out_nodes1 = relu1_input->GetOutDataNodes(); - auto out_nodes2 = relu2_front_input->GetOutDataNodes(); - ASSERT_EQ(out_nodes1.size(), 1); - ASSERT_EQ(out_nodes2.size(), 1); - - EXPECT_EQ(out_nodes1.at(0)->GetName(), "add_new"); - EXPECT_EQ(out_nodes2.at(0)->GetName(), "add_new"); - auto out_in_nodes = out->GetInDataNodes(); - EXPECT_EQ(out_in_nodes.size(), 1); - EXPECT_EQ(out_in_nodes.at(0)->GetName(), "add"); - - auto add_new_out_nodes = node->GetOutDataNodes(); - EXPECT_EQ(add_new_out_nodes.size(), 2); - EXPECT_EQ(add_new_out_nodes.at(0)->GetName(), "add"); - EXPECT_EQ(add_new_out_nodes.at(1)->GetName(), "relu2"); -} - - -TEST_F(UTestFusionTurbo2, test_case_3) { - auto graph = CreateComplexGraph2(); - FusionTurbo acc(graph); - string name = "add_new"; - string type = "MultiAdd"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - auto relu1 = GetNode(graph, "relu1"); - auto relu1_front = GetNode(graph, "relu1_front"); - auto relu2 = GetNode(graph, "relu2"); - auto relu2_front = GetNode(graph, "relu2_front"); - auto add = GetNode(graph, "add"); - auto out = GetNode(graph, "output"); - - auto relu1_input = relu1->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - auto relu1_front_input = relu1_front->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - auto relu2_input = relu2->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - auto relu2_front_input = relu2_front->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - - Relations input_relations = {{0, {relu1_input, 0}}, - {1, {relu1_front_input, 0}}, - {2, {relu2_input, 0}}, - {3, {relu2_front_input, 0}}}; - Relations output_relations = {{0, {add, 0}}, - {0, {add, 1}}}; - Status ret = acc.MultiInOne(node, input_relations, output_relations, - {relu1, relu1_front, relu2, relu2_front}, true); - EXPECT_EQ(ret, SUCCESS); - - auto out_nodes1 = relu1_front_input->GetOutDataNodes(); - auto out_nodes2 = relu2_front_input->GetOutDataNodes(); - ASSERT_EQ(out_nodes1.size(), 1); - ASSERT_EQ(out_nodes2.size(), 1); - - EXPECT_EQ(out_nodes1.at(0)->GetName(), "add_new"); - EXPECT_EQ(out_nodes2.at(0)->GetName(), "add_new"); - auto out_in_nodes = out->GetInDataNodes(); - EXPECT_EQ(out_in_nodes.size(), 1); - EXPECT_EQ(out_in_nodes.at(0)->GetName(), "add"); - - auto add_new_out_nodes = node->GetOutDataNodes(); - EXPECT_EQ(add_new_out_nodes.size(), 2); - EXPECT_EQ(add_new_out_nodes.at(0)->GetName(), "add"); - EXPECT_EQ(add_new_out_nodes.at(1)->GetName(), "add"); - EXPECT_EQ(graph->GetDirectNodesSize(), 5); -} - -TEST_F(UTestFusionTurbo2, test_case_4) { - auto graph = CreateComplexGraph2(); - FusionTurbo acc(graph); - string name = "add_new"; - string type = "MultiAdd"; - - auto relu1 = GetNode(graph, "relu1"); - auto relu1_front = GetNode(graph, "relu1_front"); - auto relu2 = GetNode(graph, "relu2"); - auto relu2_front = GetNode(graph, "relu2_front"); - auto add = GetNode(graph, "add"); - auto out = GetNode(graph, "output"); - - auto relu1_input = relu1->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - auto relu1_front_input = relu1_front->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - auto relu2_input = relu2->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - auto relu2_front_input = relu2_front->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - - Relations input_relations = {{0, {relu1_input, 0}}, - {1, {relu1_front_input, 0}}, - {2, {relu2_input, 0}}, - {3, {relu2_front_input, 0}}}; - Relations output_relations = {{0, {add, 0}}, - {0, {add, 1}}}; - auto node = acc.MultiInOne(name, type, input_relations, output_relations, - {relu1, relu1_front, relu2, relu2_front}, true); - EXPECT_NE(node, nullptr); - - auto out_nodes1 = relu1_front_input->GetOutDataNodes(); - auto out_nodes2 = relu2_front_input->GetOutDataNodes(); - ASSERT_EQ(out_nodes1.size(), 1); - ASSERT_EQ(out_nodes2.size(), 1); - - EXPECT_EQ(out_nodes1.at(0)->GetName(), "add_new"); - EXPECT_EQ(out_nodes2.at(0)->GetName(), "add_new"); - auto out_in_nodes = out->GetInDataNodes(); - EXPECT_EQ(out_in_nodes.size(), 1); - EXPECT_EQ(out_in_nodes.at(0)->GetName(), "add"); - - auto add_new_out_nodes = node->GetOutDataNodes(); - EXPECT_EQ(add_new_out_nodes.size(), 2); - EXPECT_EQ(add_new_out_nodes.at(0)->GetName(), "add"); - EXPECT_EQ(add_new_out_nodes.at(1)->GetName(), "add"); - EXPECT_EQ(graph->GetDirectNodesSize(), 5); -} - -TEST_F(UTestFusionTurbo2, test_case_4_1) { - auto graph = CreateComplexGraph2(); - FusionTurbo acc(graph); - string name = "add_new"; - string type = "MultiAdd"; - - auto relu1 = GetNode(graph, "relu1"); - auto relu1_front = GetNode(graph, "relu1_front"); - auto relu2 = GetNode(graph, "relu2"); - auto relu2_front = GetNode(graph, "relu2_front"); - auto add = GetNode(graph, "add"); - auto out = GetNode(graph, "output"); - - auto relu1_front_input = relu1_front->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - auto relu2_front_input = relu2_front->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - - Relations input_relations = {{0, {relu1, 0, PEER}}, - {1, {relu1_front, 0, PEER}}, - {2, {relu2, 0, PEER}}, - {3, {relu2_front, 0, PEER}}}; - Relations output_relations = {{0, {add, 1}}, - {0, {relu2, 0, PEER}}}; - auto node = acc.MultiInOne(name, type, input_relations, output_relations, - {relu1, relu1_front, relu2, relu2_front}, true); - EXPECT_NE(node, nullptr); - - auto out_nodes1 = relu1_front_input->GetOutDataNodes(); - auto out_nodes2 = relu2_front_input->GetOutDataNodes(); - ASSERT_EQ(out_nodes1.size(), 1); - ASSERT_EQ(out_nodes2.size(), 1); - - EXPECT_EQ(out_nodes1.at(0)->GetName(), "add_new"); - EXPECT_EQ(out_nodes2.at(0)->GetName(), "add_new"); - auto out_in_nodes = out->GetInDataNodes(); - EXPECT_EQ(out_in_nodes.size(), 1); - EXPECT_EQ(out_in_nodes.at(0)->GetName(), "add"); - - auto add_new_out_nodes = node->GetOutDataNodes(); - EXPECT_EQ(add_new_out_nodes.size(), 2); - EXPECT_EQ(add_new_out_nodes.at(0)->GetName(), "add"); - EXPECT_EQ(add_new_out_nodes.at(1)->GetName(), "add"); - EXPECT_EQ(graph->GetDirectNodesSize(), 5); -} - - -TEST_F(UTestFusionTurbo2, test_case_4_2) { - auto graph = CreateComplexGraph2(); - FusionTurbo acc(graph); - string name = "add_new"; - string type = "MultiAdd"; - - auto relu1 = GetNode(graph, "relu1"); - auto relu1_front = GetNode(graph, "relu1_front"); - auto relu2 = GetNode(graph, "relu2"); - auto relu2_front = GetNode(graph, "relu2_front"); - auto add = GetNode(graph, "add"); - auto out = GetNode(graph, "output"); - - auto relu1_front_input = relu1_front->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - auto relu2_front_input = relu2_front->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - Relations input_relations = {{0, {relu1, 0, PEER_SINGLE}}, - {1, {relu1_front, 0, PEER_SINGLE}}, - {2, {relu2, 0, PEER_SINGLE}}, - {3, {relu2_front, 0, PEER_SINGLE}}}; - Relations output_relations = {{0, {add, 1, PEER}}, - {0, {relu2, 0, PEER}}}; - auto node = acc.MultiInOne(name, type, input_relations, output_relations, - {relu1, relu1_front, relu2, relu2_front}, true); - EXPECT_NE(node, nullptr); - - auto input_nodes = node->GetInDataNodes(); - ASSERT_EQ(input_nodes.size(), 2); - - auto out_nodes1 = relu1_front_input->GetOutDataNodes(); - auto out_nodes2 = relu2_front_input->GetOutDataNodes(); - ASSERT_EQ(out_nodes1.size(), 1); - ASSERT_EQ(out_nodes2.size(), 1); - - EXPECT_EQ(out_nodes1.at(0)->GetName(), "add_new"); - EXPECT_EQ(out_nodes2.at(0)->GetName(), "add_new"); - auto out_in_nodes = out->GetInDataNodes(); - EXPECT_EQ(out_in_nodes.size(), 1); - EXPECT_EQ(out_in_nodes.at(0)->GetName(), "add"); - - auto add_new_out_nodes = node->GetOutDataNodes(); - EXPECT_EQ(add_new_out_nodes.size(), 1); - EXPECT_EQ(add_new_out_nodes.at(0)->GetName(), "add"); - EXPECT_EQ(graph->GetDirectNodesSize(), 5); -} - - -TEST_F(UTestFusionTurbo2, test_case_4_3) { - auto graph = CreateComplexGraph2(); - FusionTurbo acc(graph); - string name = "add_new"; - string type = "MultiAdd"; - - auto relu1 = GetNode(graph, "relu1"); - auto relu1_front = GetNode(graph, "relu1_front"); - auto relu2 = GetNode(graph, "relu2"); - auto relu2_front = GetNode(graph, "relu2_front"); - auto add = GetNode(graph, "add"); - auto out = GetNode(graph, "output"); - - auto relu1_front_input = relu1_front->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - auto relu2_front_input = relu2_front->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - Relations input_relations = {{0, {relu1, 0, PEER_SINGLE}}, - {1, {relu1_front, 0, PEER_SINGLE}}, - {2, {relu2, 0, PEER_SINGLE}}, - {3, {relu2_front, 0, PEER_SINGLE}}}; - Relations output_relations = {{0, {add, 1, PEER_SINGLE}}, - {0, {relu2, 0, PEER_SINGLE}}}; - auto node = acc.MultiInOne(name, type, input_relations, output_relations, - {relu1, relu1_front, relu2, relu2_front}, true); - EXPECT_NE(node, nullptr); - - auto out_nodes1 = relu1_front_input->GetOutDataNodes(); - auto out_nodes2 = relu2_front_input->GetOutDataNodes(); - ASSERT_EQ(out_nodes1.size(), 1); - ASSERT_EQ(out_nodes2.size(), 1); - - EXPECT_EQ(out_nodes1.at(0)->GetName(), "add_new"); - EXPECT_EQ(out_nodes2.at(0)->GetName(), "add_new"); - auto out_in_nodes = out->GetInDataNodes(); - EXPECT_EQ(out_in_nodes.size(), 1); - EXPECT_EQ(out_in_nodes.at(0)->GetName(), "add"); - - auto add_new_out_nodes = node->GetOutDataNodes(); - EXPECT_EQ(add_new_out_nodes.size(), 1); - EXPECT_EQ(add_new_out_nodes.at(0)->GetName(), "add"); - EXPECT_EQ(graph->GetDirectNodesSize(), 5); -} - -/* Test RemoveMultiNodesOnly. */ -TEST_F(UTestFusionTurbo2, test_case_4_4) { - auto graph = CreateComplexGraph2(); - FusionTurbo acc(graph); - string name = "add_new"; - string type = "MultiAdd"; - - auto relu1 = GetNode(graph, "relu1"); - auto relu1_front = GetNode(graph, "relu1_front"); - auto relu2 = GetNode(graph, "relu2"); - auto relu2_front = GetNode(graph, "relu2_front"); - auto add = GetNode(graph, "add"); - auto out = GetNode(graph, "output"); - FusionTurbo::GetPeerInFirstPair(add, 0); - FusionTurbo::GetPeerOutPair(add, 0); - auto relu1_front_input = relu1_front->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - auto relu2_front_input = relu2_front->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - Relations input_relations = {{0, {relu1, 0, PEER_SINGLE}}, - {1, {relu1_front, 0, PEER_SINGLE}}, - {2, {relu2, 0, PEER_SINGLE}}, - {3, {relu2_front, 0, PEER_SINGLE}}}; - Relations output_relations = {{0, {add, 1, PEER_SINGLE}}, - {0, {relu2, 0, PEER_SINGLE}}}; - auto node = acc.MultiInOne(name, type, input_relations, output_relations); - EXPECT_NE(node, nullptr); - acc.RemoveMultiNodesOnly({nullptr}); - acc.RemoveMultiNodesOnly({relu1, relu1_front, relu2, relu2_front}); - auto out_nodes1 = relu1_front_input->GetOutDataNodes(); - auto out_nodes2 = relu2_front_input->GetOutDataNodes(); - ASSERT_EQ(out_nodes1.size(), 1); - ASSERT_EQ(out_nodes2.size(), 1); - - EXPECT_EQ(out_nodes1.at(0)->GetName(), "add_new"); - EXPECT_EQ(out_nodes2.at(0)->GetName(), "add_new"); - auto out_in_nodes = out->GetInDataNodes(); - EXPECT_EQ(out_in_nodes.size(), 1); - EXPECT_EQ(out_in_nodes.at(0)->GetName(), "add"); - - auto add_new_out_nodes = node->GetOutDataNodes(); - EXPECT_EQ(add_new_out_nodes.size(), 1); - EXPECT_EQ(add_new_out_nodes.at(0)->GetName(), "add"); - EXPECT_EQ(graph->GetDirectNodesSize(), 5); -} - -void LarsV2UpdateFusion(const ComputeGraphPtr &graph) { - FusionTurbo ft(graph); - NodePtr fused_node = graph->FindFirstNodeMatchType("LarsV2"); - - OpDescPtr fused_desc = fused_node->GetOpDesc(); - - // new the square_sum_all node - NodePtr square_sum_all_node = ft.AddNodeOnly(fused_desc->GetName() + "/SquareSumAll", "SquareSumAll"); - NodePtr lars_v2_update_node = ft.AddNodeOnly(fused_desc->GetName() + "/LarsUpdate", "LarsV2Update"); - - auto square_sum_all_op_desc = square_sum_all_node->GetOpDesc(); - - Relations square_sum_input_relation; - square_sum_input_relation.Add(0, {fused_node, 0, PEER}) - .Add(1, {fused_node, 1, PEER}); - - Relations lars_v2_input_relation = {{2, {square_sum_all_node, 0}}, {3, {square_sum_all_node, 1}}, - {0, {fused_node, 0, PEER}}, {1, {fused_node, 1, PEER}}, - {4, {fused_node, 2, PEER}}, {5, {fused_node, 3, PEER}}}; - - Relations lars_v2_output_relation = {0, {fused_node, 0, PEER}}; - FusionTurbo::LinkInput(square_sum_input_relation, square_sum_all_node); - FusionTurbo::LinkOutput(lars_v2_output_relation, lars_v2_update_node); - FusionTurbo::LinkInput(lars_v2_input_relation, lars_v2_update_node); - FusionTurbo::TransferInCtrlEdges({fused_node}, lars_v2_update_node); - FusionTurbo::TransferOutCtrlEdges({fused_node}, lars_v2_update_node); - - ft.RemoveNodeOnly(fused_node); -} - -TEST_F(UTestFusionTurbo2, test_case_4_5) { - ComputeGraphPtr graph = std::make_shared("test1"); - fe::FusionTurbo ft(graph); - auto data0 = ft.AddNodeOnly("data0", "Data"); - auto data0_output = data0->GetOpDesc()->MutableOutputDesc(0); - vector data_0_output_shape = {1, 2, 3, 4}; - data0_output->SetShape(ge::GeShape(data_0_output_shape)); - - auto data1 = ft.AddNodeOnly("data1", "Data"); - auto data1_output = data1->GetOpDesc()->MutableOutputDesc(0); - vector data_1_output_shape = {2, 4, 6, 8}; - data1_output->SetShape(ge::GeShape(data_1_output_shape)); - - auto data2 = ft.AddNodeOnly("data2", "Data"); - auto data2_output = data2->GetOpDesc()->MutableOutputDesc(0); - vector data_2_output_shape = {3, 6, 9, 12}; - data2_output->SetShape(ge::GeShape(data_2_output_shape)); - - auto data3 = ft.AddNodeOnly("data3", "Data"); - auto data3_output = data3->GetOpDesc()->MutableOutputDesc(0); - vector data_3_output_shape = {4, 8, 12, 16}; - data3_output->SetShape(ge::GeShape(data_3_output_shape)); - - auto end = ft.AddNodeOnly("end", "End"); - auto end_input = end->GetOpDesc()->MutableInputDesc(0); - vector end_input_shape = {100}; - end_input->SetShape(ge::GeShape(end_input_shape)); - - auto lars_v2 = ft.AddNodeOnly("lars_v2", "LarsV2"); - fe::Relations input_relation({{0, {data0, 0}}, {1, {data1, 0}}, - {2, {data2, 0}}, {3, {data3, 0}}}); - fe::Relations output_relation; - output_relation.Add(0, {end, 0}); - fe::FusionTurbo::LinkInput(input_relation, lars_v2); - fe::FusionTurbo::LinkOutput(output_relation, lars_v2); - - LarsV2UpdateFusion(graph); - - EXPECT_EQ(graph->GetDirectNodesSize(), 7); - size_t expected_op = 0; - ge::NodePtr square_sum_all; - for (const auto &node: graph->GetDirectNode()) { - if (node->GetType() == "SquareSumAll") { - expected_op++; - auto op_desc = node->GetOpDesc(); - square_sum_all = node; - ASSERT_EQ(op_desc->GetInputsSize(), 2); - auto input0 = op_desc->MutableInputDesc(0); - EXPECT_EQ(input0->GetShape().GetDims(), data_0_output_shape); - - auto input1 = op_desc->MutableInputDesc(1); - EXPECT_EQ(input1->GetShape().GetDims(), data_1_output_shape); - - ASSERT_EQ(op_desc->GetOutputsSize(), 2); - auto output0 = op_desc->MutableOutputDesc(0); - EXPECT_EQ(output0->GetShape(), ge::GeShape()); - - auto output1 = op_desc->MutableOutputDesc(1); - EXPECT_EQ(output1->GetShape(), ge::GeShape()); - - auto input_nodes = node->GetInDataNodes(); - EXPECT_EQ(input_nodes.size(), 2); - EXPECT_EQ(input_nodes.at(0), data0); - EXPECT_EQ(input_nodes.at(1), data1); - } - } - - for (const auto &node: graph->GetDirectNode()) { - if (node->GetType() == "LarsV2Update") { - expected_op++; - auto op_desc = node->GetOpDesc(); - ASSERT_EQ(op_desc->GetInputsSize(), 6); - auto input0 = op_desc->MutableInputDesc(0); - EXPECT_EQ(input0->GetShape().GetDims(), data_0_output_shape); - - auto input1 = op_desc->MutableInputDesc(1); - EXPECT_EQ(input1->GetShape().GetDims(), data_1_output_shape); - - auto input2 = op_desc->MutableInputDesc(2); - EXPECT_EQ(input2->GetShape(), ge::GeShape()); - - auto input3 = op_desc->MutableInputDesc(3); - EXPECT_EQ(input3->GetShape(), ge::GeShape()); - - auto input4 = op_desc->MutableInputDesc(4); - EXPECT_EQ(input4->GetShape().GetDims(), data_2_output_shape); - - auto input5 = op_desc->MutableInputDesc(5); - EXPECT_EQ(input5->GetShape().GetDims(), data_3_output_shape); - - ASSERT_EQ(op_desc->GetOutputsSize(), 1); - auto output = op_desc->MutableOutputDesc(0); - EXPECT_EQ(output->GetShape().GetDims(), end_input_shape); - - auto input_nodes = node->GetInDataNodes(); - EXPECT_EQ(input_nodes.size(), 6); - EXPECT_EQ(input_nodes.at(0), data0); - EXPECT_EQ(input_nodes.at(1), data1); - - EXPECT_EQ(input_nodes.at(2), square_sum_all); - EXPECT_EQ(input_nodes.at(3), square_sum_all); - - EXPECT_EQ(input_nodes.at(4), data2); - EXPECT_EQ(input_nodes.at(5), data3); - - auto output_nodes = node->GetOutDataNodes(); - EXPECT_EQ(output_nodes.size(), 1); - EXPECT_EQ(output_nodes.at(0), end); - } - } - EXPECT_EQ(expected_op, 2); -} - -TEST_F(UTestFusionTurbo2, test_case_multiinone_out_relaitons_empty) { - auto graph = CreateComplexGraph(); - FusionTurbo acc(graph); - string name = "add_new"; - string type = "Add"; - auto node = acc.AddNodeOnly(name, type); - ASSERT_NE(node, nullptr); - - auto relu1 = GetNode(graph, "relu1"); - auto relu2 = GetNode(graph, "relu2"); - auto add = GetNode(graph, "add"); - auto out = GetNode(graph, "output"); - - auto relu1_input = relu1->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - auto relu2_input = relu2->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - - Relations input_relations = {{0, {relu1_input, 0}}, - {1, {relu2_input, 0}}}; - Relations output_relations = {}; - Status ret = acc.MultiInOne(node, input_relations, output_relations, {}, false); - ASSERT_EQ(ret, SUCCESS); -} -} diff --git a/tests/ut/register/testcase/graph_fusion_turbo_unittest3.cc b/tests/ut/register/testcase/graph_fusion_turbo_unittest3.cc deleted file mode 100644 index 0475477bd9ef8f19b87a97479d784a8cf0fc33fc..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/graph_fusion_turbo_unittest3.cc +++ /dev/null @@ -1,321 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include "graph/compute_graph.h" -#include "graph/graph.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/graph_utils.h" - -#include "register/graph_optimizer/graph_fusion/fusion_pass_manager/fusion_pass_registry.h" -#include "register/graph_optimizer/graph_fusion/graph_fusion_pass_base.h" -#include "register/graph_optimizer/fusion_common/pattern_fusion_base_pass.h" -#include "register/graph_optimizer/fusion_common/fusion_turbo.h" - -#include "external/graph/operator_factory.h" -#include "external/graph/operator_reg.h" -#include "graph/operator_factory_impl.h" -#include "graph/debug/ge_log.h" - -using namespace testing; -using namespace ge; -using namespace fe; - -namespace fe { -REG_OP(Const) - .OUTPUT(y, - TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, - DT_UINT64, DT_BOOL, DT_DOUBLE})) - .ATTR(value, Tensor, Tensor()) - .OP_END_FACTORY_REG(Const); - -REG_OP(Transpose) - .INPUT(x, - TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, - DT_UINT64, DT_BOOL, DT_DOUBLE})) - .INPUT(shape, TensorType({DT_INT32, DT_INT64})) - .OUTPUT(y, - TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, - DT_UINT64, DT_BOOL, DT_DOUBLE})) - .ATTR(axis, Int, 0) - .ATTR(num_axes, Int, -1) - .OP_END_FACTORY_REG(Transpose); - -REG_OP(Add) - .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, - DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, - DT_COMPLEX64, DT_STRING})) - .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, - DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, - DT_COMPLEX64, DT_STRING})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, - DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, - DT_COMPLEX64, DT_STRING})) - .OP_END_FACTORY_REG(Add) - -REG_OP(MultiAdd) - .INPUT(x1, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, - DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, - DT_COMPLEX64, DT_STRING})) - .INPUT(x2, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, - DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, - DT_COMPLEX64, DT_STRING})) - .INPUT(x3, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, - DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, - DT_COMPLEX64, DT_STRING})) - .INPUT(x4, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, - DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, - DT_COMPLEX64, DT_STRING})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_INT32, DT_INT64, DT_FLOAT16, DT_INT16, - DT_INT8, DT_UINT8, DT_DOUBLE, DT_COMPLEX128, - DT_COMPLEX64, DT_STRING})) - .OP_END_FACTORY_REG(MultiAdd) - -REG_OP(Relu) - .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE, - DT_INT8, DT_INT32, DT_INT16, DT_INT64, - DT_UINT8, DT_UINT16, DT_QINT8})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_DOUBLE, - DT_INT8, DT_INT32, DT_INT16, DT_INT64, - DT_UINT8, DT_UINT16, DT_QINT8})) - .OP_END_FACTORY_REG(Relu) - -REG_OP(End) - .INPUT(x, TensorType::ALL()) - .OUTPUT(y, TensorType::ALL()) - .ATTR(peerIndex, Int, 0) - .ATTR(parentOpType, String, "") - .OP_END_FACTORY_REG(End) - -REG_OP(LarsV2Update) - .INPUT(w, TensorType(DT_FLOAT)) - .INPUT(g, TensorType(DT_FLOAT)) - .INPUT(w_square_sum, TensorType(DT_FLOAT)) - .INPUT(g_square_sum, TensorType(DT_FLOAT)) - .INPUT(weight_decay, TensorType(DT_FLOAT)) - .INPUT(learning_rate, TensorType(DT_FLOAT)) - .OUTPUT(g_new, TensorType(DT_FLOAT)) - .ATTR(hyperpara, Float, 0.001) - .ATTR(epsilon, Float, 0.00001) - .ATTR(use_clip, Bool, false) - .OP_END_FACTORY_REG(LarsV2Update) - -REG_OP(SquareSumAll) - .INPUT(x1, TensorType({DT_FLOAT})) - .INPUT(x2, TensorType({DT_FLOAT})) - .OUTPUT(y1, TensorType({DT_FLOAT})) - .OUTPUT(y2, TensorType({DT_FLOAT})) - .OP_END_FACTORY_REG(SquareSumAll) - -REG_OP(LarsV2) - .INPUT(w, TensorType(DT_FLOAT)) - .INPUT(g, TensorType(DT_FLOAT)) - .INPUT(weight_decay, TensorType(DT_FLOAT)) - .INPUT(learning_rate, TensorType(DT_FLOAT)) - .OUTPUT(g_new, TensorType(DT_FLOAT)) - .ATTR(hyperpara, Float, 0.001) - .ATTR(epsilon, Float, 0.00001) - .ATTR(use_clip, Bool, false) - .OP_END_FACTORY_REG(LarsV2) - -class UTestFusionTurbo3 : public testing::Test { - public: - - protected: - - - void SetUp() { - } - - void TearDown() { - } - - ge::NodePtr GetNode(ComputeGraphPtr &graph, const string &name) { - for (auto &node : graph->GetDirectNode()) { - if (node->GetName() == name) { - return node; - } - } - return nullptr; - } - - ComputeGraphPtr CreateComplexGraph() { - ComputeGraphPtr graph = std::make_shared("test1"); - - OpDescPtr op_desc_relu1 = std::make_shared("relu1", "Relu"); - OpDescPtr op_desc_relu2 = std::make_shared("relu2", "Relu"); - OpDescPtr op_desc_output = std::make_shared("output", "NetOutput"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT16); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - vector dim_b = {1, 4, 64, 64}; - GeShape shape_b(dim_b); - GeTensorDesc tensor_desc_b(shape_b); - tensor_desc_b.SetFormat(FORMAT_NCHW); - tensor_desc_b.SetOriginFormat(FORMAT_NCHW); - tensor_desc_b.SetDataType(DT_FLOAT); - tensor_desc_b.SetOriginDataType(DT_FLOAT); - - op_desc_relu1->AddInputDesc(tensor_desc_a); - op_desc_relu1->AddOutputDesc(tensor_desc_b); - - op_desc_relu2->AddInputDesc(tensor_desc_a); - op_desc_relu2->AddOutputDesc(tensor_desc_b); - - op_desc_output->AddInputDesc(tensor_desc_b); - op_desc_output->AddInputDesc(tensor_desc_b); - - NodePtr node_relu1 = graph->AddNode(op_desc_relu1); - NodePtr node_relu2 = graph->AddNode(op_desc_relu2); - NodePtr node_netoutput = graph->AddNode(op_desc_output); - - GraphUtils::AddEdge(node_relu2->GetOutDataAnchor(0), node_netoutput->GetInDataAnchor(1)); - - FusionTurbo acc(graph); - auto node_add = acc.InsertNodeAfter("add", "Add", node_relu2, 0, 1); - EXPECT_NE(node_add, nullptr); - Relations rl(0, {node_relu1, 0}); - acc.LinkInput(rl, node_add); - - unique_ptr data(new(std::nothrow) int32_t[4096]); - WeightInfo w(tensor_desc_a, data.get()); - acc.AddWeight(node_relu1, 0, w); - acc.AddWeight(node_relu2, 0, w); - return graph; - } - - ComputeGraphPtr CreateComplexGraph2() { - ComputeGraphPtr graph = std::make_shared("test2"); - - OpDescPtr op_desc_relu1 = std::make_shared("relu1", "Relu"); - OpDescPtr op_desc_relu2 = std::make_shared("relu2", "Relu"); - OpDescPtr op_desc_output = std::make_shared("output", "NetOutput"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT16); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - vector dim_b = {1, 4, 64, 64}; - GeShape shape_b(dim_b); - GeTensorDesc tensor_desc_b(shape_b); - tensor_desc_b.SetFormat(FORMAT_NCHW); - tensor_desc_b.SetOriginFormat(FORMAT_NCHW); - tensor_desc_b.SetDataType(DT_FLOAT); - tensor_desc_b.SetOriginDataType(DT_FLOAT); - - op_desc_relu1->AddInputDesc(tensor_desc_a); - op_desc_relu1->AddOutputDesc(tensor_desc_b); - - op_desc_relu2->AddInputDesc(tensor_desc_a); - op_desc_relu2->AddOutputDesc(tensor_desc_b); - - op_desc_output->AddInputDesc(tensor_desc_b); - op_desc_output->AddInputDesc(tensor_desc_b); - - NodePtr node_relu1 = graph->AddNode(op_desc_relu1); - NodePtr node_relu2 = graph->AddNode(op_desc_relu2); - NodePtr node_netoutput = graph->AddNode(op_desc_output); - - GraphUtils::AddEdge(node_relu2->GetOutDataAnchor(0), node_netoutput->GetInDataAnchor(1)); - - FusionTurbo acc(graph); - auto node_add = acc.InsertNodeAfter("add", "Add", node_relu2, 0, 0); - EXPECT_NE(node_add, nullptr); - Relations rl(1, {node_relu1, 0}); - acc.LinkInput(rl, node_add); - - auto relu1_front = acc.InsertNodeBefore("relu1_front", "Relu", node_relu1, 0); - - auto relu2_front = acc.InsertNodeBefore("relu2_front", "Relu", node_relu2, 0); - - auto relu_top = acc.AddNodeOnly("relu_top", "Relu"); - Relations output_relation(0, {{relu1_front, 0}, - {relu2_front, 0}}); - acc.LinkOutput(output_relation, relu_top); - return graph; - } -}; - -TEST_F(UTestFusionTurbo3, test_case_01) { - auto graph = CreateComplexGraph(); - auto relu_node = graph->FindFirstNodeMatchType("Relu"); - bool has_data_out = FusionTurbo::HasOutData(relu_node); - EXPECT_EQ(has_data_out, true); - has_data_out = FusionTurbo::HasOutData(nullptr); - EXPECT_EQ(has_data_out, false); - auto net_out_node = graph->FindFirstNodeMatchType("NetOutput"); - ASSERT_NE(net_out_node, nullptr); - has_data_out = FusionTurbo::HasOutData(net_out_node); - EXPECT_EQ(has_data_out, false); -} - -TEST_F(UTestFusionTurbo3, test_case_02) { - auto graph = CreateComplexGraph(); - FusionTurbo ft(graph); - auto relu_node = graph->FindFirstNodeMatchType("Relu"); - Status ret = ft.RemoveDanglingNode(relu_node); - EXPECT_EQ(ret, FAILED); - - auto net_out_node = graph->FindFirstNodeMatchType("NetOutput"); - ASSERT_NE(net_out_node, nullptr); - ge::GraphUtils::AddEdge(net_out_node->GetOutControlAnchor(), relu_node->GetInControlAnchor()); - ret = ft.RemoveDanglingNode(net_out_node); - EXPECT_EQ(ret, FAILED); - auto remain_net_out_node = graph->FindFirstNodeMatchType("NetOutput"); - EXPECT_TRUE(remain_net_out_node == net_out_node); - - ret = ft.RemoveDanglingNode(net_out_node, true); - EXPECT_EQ(ret, SUCCESS); - remain_net_out_node = graph->FindFirstNodeMatchType("NetOutput"); - EXPECT_EQ(remain_net_out_node, nullptr); -} - -TEST_F(UTestFusionTurbo3, test_case_03) { - auto graph = CreateComplexGraph2(); - ge::NodePtr relu_top = nullptr; - for (const auto &node : graph->GetDirectNode()) { - if (node->GetName() == "relu_top") { - auto out_data_nodes = node->GetOutNodes(); - ASSERT_EQ(out_data_nodes.size(), 2); - EXPECT_EQ(out_data_nodes.at(0)->GetName(), "relu1_front"); - EXPECT_EQ(out_data_nodes.at(1)->GetName(), "relu2_front"); - relu_top = node; - } - } - EXPECT_NE(relu_top, nullptr); - - FusionTurbo ft(graph); - - auto &tensor_desc = relu_top->GetOpDesc()->GetOutputDesc(0); - unique_ptr data(new(std::nothrow) int32_t[4096]); - WeightInfo w(tensor_desc, data.get()); - auto const_node = ft.AddWeightAfter(relu_top, 0, w); - ASSERT_NE(const_node, nullptr); - auto const_out_data_nodes = const_node->GetOutNodes(); - ASSERT_EQ(const_out_data_nodes.size(), 2); - EXPECT_EQ(const_out_data_nodes.at(0)->GetName(), "relu1_front"); - EXPECT_EQ(const_out_data_nodes.at(1)->GetName(), "relu2_front"); - - auto relu_top_out_data_nodes = relu_top->GetOutNodes(); - ASSERT_EQ(relu_top_out_data_nodes.size(), 0); -} -} diff --git a/tests/ut/register/testcase/graph_pass_util_ut.cc b/tests/ut/register/testcase/graph_pass_util_ut.cc deleted file mode 100644 index 48e088a021f5c2e5ddb89bbbd3d7030ed5ce8ae4..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/graph_pass_util_ut.cc +++ /dev/null @@ -1,789 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/ge_tensor.h" -#include "graph/utils/attr_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/op_desc.h" -#include "graph/compute_graph.h" -#include "graph_optimizer/fusion_common/graph_pass_util.h" - -using namespace std; -using namespace ge; - -namespace fe { -class GraphPassUtilUT : public testing::Test { -protected: - void SetUp() {} - - void TearDown() {} -}; - -bool CheckOriginAttr(const std::vector &nodes, std::string pass_name, - GraphPassUtil::OriginOpAttrsVec origin_attrs) { - for (auto node : nodes) { - auto op_desc = node->GetOpDesc(); - std::shared_ptr op_attrs_maps_tmp = - std::make_shared(); - op_attrs_maps_tmp = op_desc->TryGetExtAttr(ge::ATTR_NAME_ORIGIN_OP_ATTRS_MAP, op_attrs_maps_tmp); - if (op_attrs_maps_tmp->find(pass_name) == op_attrs_maps_tmp->cend()) { - return false; - } - auto attrs_in_vec = (*op_attrs_maps_tmp)[pass_name]; - if (attrs_in_vec.size() != origin_attrs.size()) { - return false; - } - for (const auto &origin_attr : origin_attrs) { - bool is_in_vec = false; - for (const auto &attr_in_vec : attrs_in_vec) { - if (origin_attr == attr_in_vec) { - is_in_vec = true; - break; - } - } - if (is_in_vec == false) { - return false; - } - } - } - return true; -} - -TEST_F(GraphPassUtilUT, set_output_desc_attr_case1) { - EXPECT_NO_THROW( - NodePtr origin_node = nullptr; - NodePtr fusion_node = nullptr; - GraphPassUtil::SetOutputDescAttr(0, 0, origin_node, fusion_node); - ); -} - -TEST_F(GraphPassUtilUT, set_output_desc_attr_case2) { - OpDescPtr relu1 = std::make_shared("relu1", "Relu"); - OpDescPtr relu2 = std::make_shared("relu2", "Relu"); - vector dim = {4, 4, 1, 4}; - GeShape shape(dim); - GeTensorDesc tenosr_desc(shape, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16); - tenosr_desc.SetOriginFormat(FORMAT_NCHW); - tenosr_desc.SetOriginDataType(DT_FLOAT); - relu1->AddInputDesc(tenosr_desc); - relu1->AddOutputDesc(tenosr_desc); - relu2->AddInputDesc(tenosr_desc); - relu2->AddOutputDesc(tenosr_desc); - ComputeGraphPtr graph = std::make_shared("test"); - NodePtr relu1_node = graph->AddNode(relu1); - NodePtr relu2_node = graph->AddNode(relu2); - GraphPassUtil::SetOutputDescAttr(1, 0, relu1_node, relu2_node); - EXPECT_EQ(relu2_node->GetOpDesc()->GetOutputDescPtr(0)->HasAttr(ge::ATTR_NAME_DATA_DUMP_ORIGIN_NAME), false); - GraphPassUtil::SetOutputDescAttr(0, 1, relu1_node, relu2_node); - EXPECT_EQ(relu2_node->GetOpDesc()->GetOutputDescPtr(0)->HasAttr(ge::ATTR_NAME_DATA_DUMP_ORIGIN_NAME), false); - GraphPassUtil::SetOutputDescAttr(0, 0, relu1_node, relu2_node); - EXPECT_EQ(relu2_node->GetOpDesc()->GetOutputDescPtr(0)->HasAttr(ge::ATTR_NAME_DATA_DUMP_ORIGIN_NAME), true); - string origin_name; - AttrUtils::GetStr(relu2->GetOutputDescPtr(0), ge::ATTR_NAME_DATA_DUMP_ORIGIN_NAME, origin_name); - EXPECT_EQ(origin_name, "relu1"); - string origin_dtype; - AttrUtils::GetStr(relu2->GetOutputDescPtr(0), ge::ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE, origin_dtype); - EXPECT_EQ(origin_dtype, "DT_FLOAT"); - string origin_format; - AttrUtils::GetStr(relu2->GetOutputDescPtr(0), ge::ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT, origin_format); - EXPECT_EQ(origin_format, "NCHW"); -} - -TEST_F(GraphPassUtilUT, set_output_desc_attr_case3) { - OpDescPtr relu1 = std::make_shared("relu1", "Relu"); - OpDescPtr relu2 = std::make_shared("relu2", "Relu"); - vector dim = {4, 4, 1, 4}; - GeShape shape(dim); - GeTensorDesc tenosr_desc(shape, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16); - tenosr_desc.SetOriginFormat(FORMAT_NCHW); - tenosr_desc.SetOriginDataType(DT_FLOAT); - relu1->AddInputDesc(tenosr_desc); - relu1->AddOutputDesc(tenosr_desc); - AttrUtils::SetStr(relu1->MutableOutputDesc(0), ge::ATTR_NAME_DATA_DUMP_ORIGIN_NAME, "origin_relu1"); - AttrUtils::SetStr(relu1->MutableOutputDesc(0), ge::ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE, "DT_DOUBLE"); - AttrUtils::SetStr(relu1->MutableOutputDesc(0), ge::ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT, "ND"); - relu2->AddInputDesc(tenosr_desc); - relu2->AddOutputDesc(tenosr_desc); - ComputeGraphPtr graph = std::make_shared("test"); - NodePtr relu1_node = graph->AddNode(relu1); - NodePtr relu2_node = graph->AddNode(relu2); - - GraphPassUtil::SetOutputDescAttr(0, 0, relu1_node, relu2_node); - EXPECT_EQ(relu2_node->GetOpDesc()->GetOutputDescPtr(0)->HasAttr(ge::ATTR_NAME_DATA_DUMP_ORIGIN_NAME), true); - string origin_name; - AttrUtils::GetStr(relu2->GetOutputDescPtr(0), ge::ATTR_NAME_DATA_DUMP_ORIGIN_NAME, origin_name); - EXPECT_EQ(origin_name, "origin_relu1"); - string origin_dtype; - AttrUtils::GetStr(relu2->GetOutputDescPtr(0), ge::ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE, origin_dtype); - EXPECT_EQ(origin_dtype, "DT_DOUBLE"); - string origin_format; - AttrUtils::GetStr(relu2->GetOutputDescPtr(0), ge::ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT, origin_format); - EXPECT_EQ(origin_format, "ND"); -} - -TEST_F(GraphPassUtilUT, set_output_desc_attr_case4) { - OpDescPtr relu1 = std::make_shared("relu1", "Relu"); - OpDescPtr relu2 = std::make_shared("relu2", "Relu"); - vector dim = {4, 4, 1, 4}; - GeShape shape(dim); - GeTensorDesc tenosr_desc(shape, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16); - tenosr_desc.SetOriginFormat(FORMAT_NCHW); - tenosr_desc.SetOriginDataType(DT_FLOAT); - relu1->AddInputDesc(tenosr_desc); - relu1->AddOutputDesc(tenosr_desc); - vector names = {"ori_rule1"}; - AttrUtils::SetListStr(relu1, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, names); - AttrUtils::SetStr(relu1->MutableOutputDesc(0), ge::ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE, "RESERVED"); - AttrUtils::SetStr(relu1->MutableOutputDesc(0), ge::ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT, "RESERVED"); - relu2->AddInputDesc(tenosr_desc); - relu2->AddOutputDesc(tenosr_desc); - ComputeGraphPtr graph = std::make_shared("test"); - NodePtr relu1_node = graph->AddNode(relu1); - NodePtr relu2_node = graph->AddNode(relu2); - - GraphPassUtil::SetOutputDescAttr(0, 0, relu1_node, relu2_node); - EXPECT_EQ(relu2_node->GetOpDesc()->GetOutputDescPtr(0)->HasAttr(ge::ATTR_NAME_DATA_DUMP_ORIGIN_NAME), true); - string origin_name; - AttrUtils::GetStr(relu2->GetOutputDescPtr(0), ge::ATTR_NAME_DATA_DUMP_ORIGIN_NAME, origin_name); - EXPECT_EQ(origin_name, "ori_rule1"); - string origin_dtype; - AttrUtils::GetStr(relu2->GetOutputDescPtr(0), ge::ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE, origin_dtype); - EXPECT_EQ(origin_dtype, "DT_FLOAT"); - string origin_format; - AttrUtils::GetStr(relu2->GetOutputDescPtr(0), ge::ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT, origin_format); - EXPECT_EQ(origin_format, "NCHW"); -} - -TEST_F(GraphPassUtilUT, set_output_desc_attr_case5) { - vector dims = {1,2,3,4}; - std::string origin_data_type_str = "RESERVED"; - GeShape shape(dims); - GeTensorDescPtr tensor_desc_ptr = std::make_shared(shape, FORMAT_NCHW, DT_FLOAT); - tensor_desc_ptr->SetDataType((ge::DataType)24); - (void)ge::AttrUtils::SetStr(tensor_desc_ptr, ge::ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE, "DT_DOUBLE"); - ge::DataType origin_dtype; - origin_dtype = GraphPassUtil::GetDataDumpOriginDataType(tensor_desc_ptr); - EXPECT_EQ(origin_dtype, (ge::DataType)11); -} - -TEST_F(GraphPassUtilUT, set_output_desc_attr_case6) { - ge::ComputeGraphPtr graph = std::make_shared("test"); - ge::OpDescPtr op = std::make_shared("test_op", "TestOp"); - auto node = graph->AddNode(op); - - std::map inner_map; - inner_map["test"] = node; - std::unordered_map> node_map; - node_map["test"] = inner_map; - - NodeTypeMapPtr node_type_map = std::make_shared(node_map); - EXPECT_NO_THROW(GraphPassUtil::AddNodeToNodeTypeMap(node_type_map, "test", node)); -} - -TEST_F(GraphPassUtilUT, set_output_desc_attr_case7) { - ge::ComputeGraphPtr graph = std::make_shared("test"); - ge::OpDescPtr op = std::make_shared("test_op", "TestOp"); - auto node = graph->AddNode(op); - - std::map inner_map; - inner_map["test"] = node; - std::unordered_map> node_map; - node_map["test"] = inner_map; - - NodeTypeMapPtr node_type_map = std::make_shared(node_map); - EXPECT_NO_THROW(GraphPassUtil::RemoveNodeFromNodeTypeMap(node_type_map, "test", node)); -} - -TEST_F(GraphPassUtilUT, set_output_desc_attr_case8) { - ge::ComputeGraphPtr graph = std::make_shared("test"); - ge::OpDescPtr op = std::make_shared("test_op", "TestOp"); - auto node = graph->AddNode(op); - - std::map inner_map; - inner_map["test"] = node; - std::unordered_map> node_map; - node_map["test"] = inner_map; - - NodeTypeMapPtr node_type_map = std::make_shared(node_map); - vector nodes; - EXPECT_NO_THROW(GraphPassUtil::GetNodesFromNodeTypeMap(node_type_map, "test", nodes)); -} - -TEST_F(GraphPassUtilUT, set_output_desc_attr_case9) { - vector dims = {1,2,3,4}; - std::string origin_data_type_str = "RESERVED"; - GeShape shape(dims); - GeTensorDescPtr tensor_desc_ptr = std::make_shared(shape, FORMAT_NCHW, DT_FLOAT); - tensor_desc_ptr->SetDataType((ge::DataType)24); - (void)ge::AttrUtils::SetStr(tensor_desc_ptr, ge::ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT, "NCHW"); - ge::Format origin_format; - origin_format = GraphPassUtil::GetDataDumpOriginFormat(tensor_desc_ptr); - EXPECT_EQ(origin_format, (ge::Format)0); -} - -TEST_F(GraphPassUtilUT, set_output_desc_attr_case10) { - putenv(const_cast("DUMP_GE_GRAPH=2")); - OpDescPtr relu1 = std::make_shared("relu1", "Relu"); - OpDescPtr relu2 = std::make_shared("relu2", "Relu"); - vector dim = {4, 4, 1, 4}; - GeShape shape(dim); - GeTensorDesc tenosr_desc(shape, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16); - tenosr_desc.SetOriginFormat(FORMAT_NCHW); - tenosr_desc.SetOriginDataType(DT_FLOAT); - relu1->AddInputDesc(tenosr_desc); - relu1->AddOutputDesc(tenosr_desc); - vector names = {"ori_rule1"}; - - relu2->AddInputDesc(tenosr_desc); - relu2->AddOutputDesc(tenosr_desc); - ComputeGraphPtr graph = std::make_shared("test"); - NodePtr relu1_node = graph->AddNode(relu1); - NodePtr relu2_node = graph->AddNode(relu2); - std::vector original_nodes = {relu1_node}; - std::vector fus_nodes = {relu2_node}; - - GraphPassUtil::RecordPassnameAndOriginalAttrs(original_nodes, fus_nodes, "passA"); - GraphPassUtil::OriginOpAttrsVec origin_op_attrs_to_check = {{"relu1", "Relu"}}; - bool oringin_attr_check = false; - oringin_attr_check = CheckOriginAttr(fus_nodes, "passA", origin_op_attrs_to_check); - EXPECT_EQ(oringin_attr_check, true); -} - -TEST_F(GraphPassUtilUT, set_original_op_names_and_types) { - putenv(const_cast("DUMP_GE_GRAPH=2")); - OpDescPtr relu1 = std::make_shared("relu1", "Relu"); - OpDescPtr relu2 = std::make_shared("relu2", "Relu"); - vector dim = {4, 4, 1, 4}; - GeShape shape(dim); - GeTensorDesc tenosr_desc(shape, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16); - tenosr_desc.SetOriginFormat(FORMAT_NCHW); - tenosr_desc.SetOriginDataType(DT_FLOAT); - relu1->AddInputDesc(tenosr_desc); - relu1->AddOutputDesc(tenosr_desc); - vector names = {"ori_rule1"}; - - relu2->AddInputDesc(tenosr_desc); - relu2->AddOutputDesc(tenosr_desc); - ComputeGraphPtr graph = std::make_shared("test"); - NodePtr relu1_node = graph->AddNode(relu1); - NodePtr relu2_node = graph->AddNode(relu2); - vector names_tmp = {"A", "B"}; - vector types_tmp = {"typeA", "typeB"}; - const ge::OpDescPtr relu1_node_op_desc_ptr = relu1_node->GetOpDesc(); - const ge::OpDescPtr relu2_node_op_desc_ptr = relu2_node->GetOpDesc(); - (void)ge::AttrUtils::SetListStr(relu1_node_op_desc_ptr, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, names_tmp); - (void)ge::AttrUtils::SetListStr(relu1_node_op_desc_ptr, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_TYPES, types_tmp); - - std::vector original_nodes = {relu1_node}; - - GraphPassUtil::RecordOriginalNames(original_nodes, relu2_node); - vector original_names; - vector original_types; - ge::AttrUtils::GetListStr(relu2_node_op_desc_ptr, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_names); - ge::AttrUtils::GetListStr(relu2_node_op_desc_ptr, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_TYPES, original_types); - - vector original_names_check = {"A", "B"}; - vector original_types_check = {"typeA", "typeB"}; - EXPECT_EQ(original_names_check, original_names); - EXPECT_EQ(original_types_check, original_types); -} - -TEST_F(GraphPassUtilUT, set_output_desc_attr_case11) { - putenv(const_cast("DUMP_GE_GRAPH=2")); - OpDescPtr relu1 = std::make_shared("relu1", "Relu"); - OpDescPtr relu2 = std::make_shared("relu2", "Relu"); - vector dim = {4, 4, 1, 4}; - GeShape shape(dim); - GeTensorDesc tenosr_desc(shape, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16); - tenosr_desc.SetOriginFormat(FORMAT_NCHW); - tenosr_desc.SetOriginDataType(DT_FLOAT); - relu1->AddInputDesc(tenosr_desc); - relu1->AddOutputDesc(tenosr_desc); - vector names = {"ori_rule1"}; - std::shared_ptr op_attrs_maps_tmp = - std::make_shared(); - GraphPassUtil::OriginOpAttrsVec origin_op_attrs_vec = {{"nodeA", "typeA"}, {"nodeB", "typeB"}}; - op_attrs_maps_tmp->insert(std::pair("pass_test", origin_op_attrs_vec)); - (void)relu1->SetExtAttr(ge::ATTR_NAME_ORIGIN_OP_ATTRS_MAP, op_attrs_maps_tmp); - vector pass_names = {"pass_test"}; - (void)AttrUtils::SetListStr(relu1, "pass_name", pass_names); - - relu2->AddInputDesc(tenosr_desc); - relu2->AddOutputDesc(tenosr_desc); - ComputeGraphPtr graph = std::make_shared("test"); - NodePtr relu1_node = graph->AddNode(relu1); - NodePtr relu2_node = graph->AddNode(relu2); - std::vector original_nodes = {relu1_node}; - std::vector fus_nodes = {relu2_node}; - - GraphPassUtil::RecordPassnameAndOriginalAttrs(original_nodes, fus_nodes, "passA"); - GraphPassUtil::OriginOpAttrsVec origin_op_attrs_to_check = {{"nodeA", "typeA"}, {"nodeB", "typeB"}}; - bool oringin_attr_check = false; - oringin_attr_check = CheckOriginAttr(fus_nodes, "passA", origin_op_attrs_to_check); - EXPECT_EQ(oringin_attr_check, true); -} - - -void CreateGraph(ComputeGraphPtr &graph, std::vector &original_nodes, - std::vector &fus_nodes) { - OpDescPtr relu1 = std::make_shared("relu1", "Relu"); - OpDescPtr relu2 = std::make_shared("relu2", "Relu"); - OpDescPtr relu3 = std::make_shared("relu3", "Relu"); - OpDescPtr fusion_op = std::make_shared("fusion", "Fusion"); - vector dim = {4, 4, 1, 4}; - GeShape shape(dim); - GeTensorDesc tenosr_desc(shape, ge::FORMAT_NC1HWC0, ge::DT_FLOAT16); - tenosr_desc.SetOriginFormat(FORMAT_NCHW); - tenosr_desc.SetOriginDataType(DT_FLOAT); - relu1->AddInputDesc(tenosr_desc); - relu1->AddOutputDesc(tenosr_desc); - - - relu2->AddInputDesc(tenosr_desc); - relu2->AddOutputDesc(tenosr_desc); - - relu3->AddInputDesc(tenosr_desc); - relu3->AddOutputDesc(tenosr_desc); - - fusion_op->AddInputDesc(tenosr_desc); - fusion_op->AddOutputDesc(tenosr_desc); - - - NodePtr relu1_node = graph->AddNode(relu1); - NodePtr relu2_node = graph->AddNode(relu2); - NodePtr relu3_node = graph->AddNode(relu3); - NodePtr fusion_node = graph->AddNode(fusion_op); - ge::GraphUtils::AddEdge(relu1_node->GetOutDataAnchor(0), relu2_node->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(relu2_node->GetOutDataAnchor(0), relu3_node->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(relu3_node->GetOutDataAnchor(0), fusion_node->GetInDataAnchor(0)); - - original_nodes = {relu2_node, relu3_node}; - fus_nodes = {fusion_node}; -} - -TEST_F(GraphPassUtilUT, test_get_back_ward_attr_01) { - ComputeGraphPtr graph = std::make_shared("test"); - std::vector original_nodes; - std::vector fus_nodes; - CreateGraph(graph, original_nodes, fus_nodes); - bool backward = false; - GraphPassUtil::GetBackWardAttr(original_nodes, backward, BackWardInheritMode::kInheritTrue); - EXPECT_EQ(backward, true); -} - -TEST_F(GraphPassUtilUT, test_get_back_ward_attr_02) { - ComputeGraphPtr graph = std::make_shared("test"); - std::vector original_nodes; - std::vector fus_nodes; - CreateGraph(graph, original_nodes, fus_nodes); - auto ori_node0 = original_nodes[0]; - auto ori_node1 = original_nodes[1]; - ge::AttrUtils::SetBool(ori_node0->GetOpDesc(), "_backward", true); - ge::AttrUtils::SetBool(ori_node1->GetOpDesc(), "_backward", true); - bool backward = false; - GraphPassUtil::GetBackWardAttr(original_nodes, backward, BackWardInheritMode::kFusedNode); - EXPECT_EQ(backward, true); -} - -TEST_F(GraphPassUtilUT, test_get_back_ward_attr_03) { - ComputeGraphPtr graph = std::make_shared("test"); - std::vector original_nodes; - std::vector fus_nodes; - CreateGraph(graph, original_nodes, fus_nodes); - auto ori_node0 = original_nodes[0]; - auto ori_node1 = original_nodes[1]; - ge::AttrUtils::SetBool(ori_node1->GetOpDesc(), "_backward", true); - bool backward = false; - GraphPassUtil::GetBackWardAttr(original_nodes, backward, BackWardInheritMode::kInsertNode); - EXPECT_EQ(backward, true); -} - -TEST_F(GraphPassUtilUT, test_get_back_ward_attr_03_1) { - ComputeGraphPtr graph = std::make_shared("test"); - std::vector original_nodes; - std::vector fus_nodes; - CreateGraph(graph, original_nodes, fus_nodes); - auto ori_node0 = original_nodes[0]; - auto ori_node1 = original_nodes[1]; - ge::AttrUtils::SetBool(ori_node1->GetOpDesc(), "_backward", true); - bool backward = false; - GraphPassUtil::GetBackWardAttr(original_nodes, backward, BackWardInheritMode::kFusedNode); - EXPECT_FALSE(backward); -} - -TEST_F(GraphPassUtilUT, test_get_back_ward_attr_04) { - ComputeGraphPtr graph = std::make_shared("test"); - std::vector original_nodes; - std::vector fus_nodes; - CreateGraph(graph, original_nodes, fus_nodes); - auto ori_node0 = original_nodes[0]; - auto ori_node1 = original_nodes[1]; - ge::AttrUtils::SetBool(ori_node0->GetOpDesc(), "_backward", true); - ge::AttrUtils::SetBool(ori_node1->GetOpDesc(), "_backward", true); - bool backward = false; - GraphPassUtil::GetBackWardAttr(original_nodes, backward, BackWardInheritMode::kDoNotInherit); - EXPECT_FALSE(backward); -} - -TEST_F(GraphPassUtilUT, test_inherit_attrs_01) { - ComputeGraphPtr graph = std::make_shared("test"); - std::vector original_nodes; - std::vector fus_nodes; - CreateGraph(graph, original_nodes, fus_nodes); - auto ori_node0 = original_nodes[0]; - auto ori_node1 = original_nodes[1]; - ge::AttrUtils::SetBool(ori_node0->GetOpDesc(), "_backward", true); - ge::AttrUtils::SetBool(ori_node1->GetOpDesc(), "_backward", true); - ge::AttrUtils::SetBool(ori_node1->GetOpDesc(), "_recompute", true); - ge::AttrUtils::SetBool(ori_node1->GetOpDesc(), "_optimizer", true); - ge::AttrUtils::SetInt(ori_node1->GetOpDesc(), ge::ATTR_NAME_KEEP_DTYPE, 1); - ge::AttrUtils::SetStr(ori_node1->GetOpDesc(), ge::ATTR_NAME_OP_COMPILE_STRATEGY, "test"); - - GraphPassUtil::InheritAttrFromOriNodes(original_nodes, fus_nodes, BackWardInheritMode::kFusedNode); - auto fus_op = fus_nodes.at(0)->GetOpDesc(); - - bool backward = false; - ge::AttrUtils::GetBool(fus_op, "_backward", backward); - EXPECT_EQ(backward, true); - - bool recompute = 0; - ge::AttrUtils::GetBool(fus_op, "_recompute", recompute); - EXPECT_EQ(recompute, true); - - bool optimizer = 0; - ge::AttrUtils::GetBool(fus_op, "_optimizer", optimizer); - EXPECT_EQ(optimizer, true); - - int64_t keep_dtype = 0; - ge::AttrUtils::GetInt(fus_op, ge::ATTR_NAME_KEEP_DTYPE, keep_dtype); - EXPECT_EQ(keep_dtype, 1); - - string strategy = ""; - ge::AttrUtils::GetStr(fus_op, ge::ATTR_NAME_OP_COMPILE_STRATEGY, strategy); - EXPECT_EQ(strategy, "test"); -} - - -TEST_F(GraphPassUtilUT, test_inherit_attrs_02) { - ComputeGraphPtr graph = std::make_shared("test"); - std::vector original_nodes; - std::vector fus_nodes; - CreateGraph(graph, original_nodes, fus_nodes); - auto ori_node0 = original_nodes[0]; - auto ori_node1 = original_nodes[1]; - ge::AttrUtils::SetBool(ori_node1->GetOpDesc(), "_backward", true); - - GraphPassUtil::InheritAttrFromOriNodes(original_nodes, fus_nodes, BackWardInheritMode::kFusedNode); - auto fus_op = fus_nodes.at(0)->GetOpDesc(); - - bool backward = false; - ge::AttrUtils::GetBool(fus_op, "_backward", backward); - EXPECT_TRUE(backward == false); - - - EXPECT_EQ(fus_op->HasAttr("_recompute"), false); - EXPECT_EQ(fus_op->HasAttr("_optimizer"), false); - EXPECT_EQ(fus_op->HasAttr(ge::ATTR_NAME_KEEP_DTYPE), false); - EXPECT_EQ(fus_op->HasAttr(ge::ATTR_NAME_OP_COMPILE_STRATEGY), false); -} - - -TEST_F(GraphPassUtilUT, test_inherit_attrs_03) { - ComputeGraphPtr graph = std::make_shared("test"); - std::vector original_nodes; - std::vector fus_nodes; - CreateGraph(graph, original_nodes, fus_nodes); - auto ori_node0 = original_nodes[0]; - auto ori_node1 = original_nodes[1]; - ge::AttrUtils::SetBool(ori_node0->GetOpDesc(), "_backward", true); - ge::AttrUtils::SetBool(ori_node1->GetOpDesc(), "_backward", true); - ge::AttrUtils::SetBool(ori_node1->GetOpDesc(), "_recompute", true); - ge::AttrUtils::SetBool(ori_node1->GetOpDesc(), "_optimizer", true); - - auto original_nodes_reverse_order = {ori_node1, ori_node0}; - GraphPassUtil::InheritAttrFromOriNodes(original_nodes_reverse_order, fus_nodes, BackWardInheritMode::kFusedNode); - auto fus_op = fus_nodes.at(0)->GetOpDesc(); - - bool backward = false; - ge::AttrUtils::GetBool(fus_op, "_backward", backward); - EXPECT_EQ(backward, true); - EXPECT_EQ(fus_op->HasAttr("_recompute"), true); - EXPECT_EQ(fus_op->HasAttr("_optimizer"), true); - - bool recompute = false; - ge::AttrUtils::GetBool(fus_op, "_recompute", recompute); - EXPECT_EQ(recompute, true); - - bool optimizer = false; - ge::AttrUtils::GetBool(fus_op, "_optimizer", optimizer); - EXPECT_EQ(optimizer, true); - - EXPECT_EQ(fus_op->HasAttr(ge::ATTR_NAME_KEEP_DTYPE), false); - EXPECT_EQ(fus_op->HasAttr(ge::ATTR_NAME_OP_COMPILE_STRATEGY), false); -} - - // N -> 1 -TEST_F(GraphPassUtilUT, test_inherit_attrs_04) { - auto graph = std::make_shared("test"); - ge::OpDescPtr ori_op_desc1 = std::make_shared("node1", "Relu"); - ge::OpDescPtr ori_op_desc2 = std::make_shared("node2", "Relu"); - ge::OpDescPtr fus_op_desc = std::make_shared("fusion", "Fusion"); - - ge::NodePtr ori_node1 = graph->AddNode(ori_op_desc1); - ge::NodePtr ori_node2 = graph->AddNode(ori_op_desc2); - ge::NodePtr fus_node = graph->AddNode(fus_op_desc); - - std::vector ori_nodes = {ori_node1, ori_node2}; - std::vector fus_nodes = {fus_node}; - - ge::AttrUtils::SetInt(ori_op_desc1, "_op_custom_impl_mode_enum", 0x20); - ge::AttrUtils::SetInt(ori_op_desc2, "_op_custom_impl_mode_enum", 0x40); - - GraphPassUtil::InheritAttrFromOriNodes(ori_nodes, fus_nodes, BackWardInheritMode::kFusedNode); - - int64_t op_impl_mode = -1; - ge::AttrUtils::GetInt(fus_op_desc, "_op_custom_impl_mode_enum", op_impl_mode); - EXPECT_EQ(op_impl_mode, 0x40); -} - - -// 1 -> N -TEST_F(GraphPassUtilUT, test_inherit_attrs_05) { - auto graph = std::make_shared("test"); - ge::OpDescPtr ori_op_desc = std::make_shared("node", "Relu"); - ge::OpDescPtr fus_op_desc1 = std::make_shared("node2", "Relu"); - ge::OpDescPtr fus_op_desc2 = std::make_shared("node1", "Relu"); - - ge::NodePtr ori_node = graph->AddNode(ori_op_desc); - ge::NodePtr fus_node1 = graph->AddNode(fus_op_desc1); - ge::NodePtr fus_node2 = graph->AddNode(fus_op_desc2); - - std::vector ori_nodes = {ori_node}; - std::vector fus_nodes = {fus_node1, fus_node2}; - - ge::AttrUtils::SetInt(ori_op_desc, "_op_custom_impl_mode_enum", 0x10); - - GraphPassUtil::InheritAttrFromOriNodes(ori_nodes, fus_nodes, BackWardInheritMode::kFusedNode); - - int64_t op_impl_mode = -1; - ge::AttrUtils::GetInt(fus_op_desc1, "_op_custom_impl_mode_enum", op_impl_mode); - EXPECT_EQ(op_impl_mode, 0x10); - ge::AttrUtils::GetInt(fus_op_desc2, "_op_custom_impl_mode_enum", op_impl_mode); - EXPECT_EQ(op_impl_mode, 0x10); -} - -// N -> N -TEST_F(GraphPassUtilUT, test_inherit_attrs_06) { - auto graph = std::make_shared("test"); - ge::OpDescPtr ori_op_desc1 = std::make_shared("node1", "Relu"); - ge::OpDescPtr ori_op_desc2 = std::make_shared("node2", "Relu"); - ge::OpDescPtr fus_op_desc1 = std::make_shared("node2", "Fusion"); - ge::OpDescPtr fus_op_desc2 = std::make_shared("node1", "Fusion"); - - ge::NodePtr ori_node1 = graph->AddNode(ori_op_desc1); - ge::NodePtr ori_node2 = graph->AddNode(ori_op_desc2); - ge::NodePtr fus_node1 = graph->AddNode(fus_op_desc1); - ge::NodePtr fus_node2 = graph->AddNode(fus_op_desc2); - - std::vector ori_nodes = {ori_node1, ori_node2}; - std::vector fus_nodes = {fus_node1, fus_node2}; - - ge::AttrUtils::SetInt(ori_op_desc1, "_op_custom_impl_mode_enum", 0x4); - ge::AttrUtils::SetInt(ori_op_desc2, "_op_custom_impl_mode_enum", 0x2); - - GraphPassUtil::InheritAttrFromOriNodes(ori_nodes, fus_nodes, BackWardInheritMode::kFusedNode); - - int64_t op_impl_mode = -1; - ge::AttrUtils::GetInt(fus_op_desc1, "_op_custom_impl_mode_enum", op_impl_mode); - EXPECT_EQ(op_impl_mode, 0x2); - ge::AttrUtils::GetInt(fus_op_desc2, "_op_custom_impl_mode_enum", op_impl_mode); - EXPECT_EQ(op_impl_mode, 0x4); -} - -// N1 -> N2 -TEST_F(GraphPassUtilUT, test_inherit_attrs_07) { - auto graph = std::make_shared("test"); - ge::OpDescPtr ori_op_desc1 = std::make_shared("node1", "Relu"); - ge::OpDescPtr ori_op_desc2 = std::make_shared("node2", "Relu"); - ge::OpDescPtr ori_op_desc3 = std::make_shared("node3", "Relu"); - ge::OpDescPtr fus_op_desc1 = std::make_shared("node1", "Relu"); - ge::OpDescPtr fus_op_desc2 = std::make_shared("Fusion", "Fusion"); - - ge::NodePtr ori_node1 = graph->AddNode(ori_op_desc1); - ge::NodePtr ori_node2 = graph->AddNode(ori_op_desc2); - ge::NodePtr ori_node3 = graph->AddNode(ori_op_desc3); - ge::NodePtr fus_node1 = graph->AddNode(fus_op_desc1); - ge::NodePtr fus_node2 = graph->AddNode(fus_op_desc2); - - std::vector ori_nodes = {ori_node1, ori_node2, ori_node3}; - std::vector fus_nodes = {fus_node1, fus_node2}; - - ge::AttrUtils::SetInt(ori_op_desc1, "_op_custom_impl_mode_enum", 0x8); - ge::AttrUtils::SetInt(ori_op_desc2, "_op_custom_impl_mode_enum", 0x4); - ge::AttrUtils::SetInt(ori_op_desc3, "_op_custom_impl_mode_enum", 0x2); - - GraphPassUtil::InheritAttrFromOriNodes(ori_nodes, fus_nodes, BackWardInheritMode::kFusedNode); - - int64_t op_impl_mode = -1; - ge::AttrUtils::GetInt(fus_op_desc1, "_op_custom_impl_mode_enum", op_impl_mode); - EXPECT_EQ(op_impl_mode, 0x8); - ge::AttrUtils::GetInt(fus_op_desc2, "_op_custom_impl_mode_enum", op_impl_mode); - EXPECT_EQ(op_impl_mode, 0x4); -} - -// N -> 1 -TEST_F(GraphPassUtilUT, test_inherit_attrs_08) { - auto graph = std::make_shared("test"); - ge::OpDescPtr ori_op_desc1 = std::make_shared("node1", "Relu"); - ge::OpDescPtr ori_op_desc2 = std::make_shared("node2", "Relu"); - ge::OpDescPtr fus_op_desc = std::make_shared("fusion", "Fusion"); - - ge::NodePtr ori_node1 = graph->AddNode(ori_op_desc1); - ge::NodePtr ori_node2 = graph->AddNode(ori_op_desc2); - ge::NodePtr fus_node = graph->AddNode(fus_op_desc); - - std::vector ori_nodes = {ori_node1, ori_node2}; - std::vector fus_nodes = {fus_node}; - - ge::AttrUtils::SetInt(ori_op_desc1, "_op_impl_mode_enum", 0x20); - ge::AttrUtils::SetInt(ori_op_desc2, "_op_impl_mode_enum", 0x40); - - GraphPassUtil::InheritAttrFromOriNodes(ori_nodes, fus_nodes, BackWardInheritMode::kFusedNode); - - int64_t op_impl_mode = -1; - ge::AttrUtils::GetInt(fus_op_desc, "_op_impl_mode_enum", op_impl_mode); - EXPECT_TRUE(op_impl_mode == -1); -} - -TEST_F(GraphPassUtilUT, test_set_pair_tensor_attr) { - auto graph = std::make_shared("test"); - ge::OpDescPtr op_desc1 = std::make_shared("node1", "Relu"); - ge::OpDescPtr op_desc2 = std::make_shared("node2", "Relu"); - ge::OpDescPtr op_desc3 = std::make_shared("node3", "Relu"); - ge::OpDescPtr op_desc4 = std::make_shared("node4", "Relu"); - - vector dim = {4, 4, 1, 4}; - GeShape shape(dim); - GeTensorDesc tenosr_desc0(shape, ge::FORMAT_NCHW, ge::DT_FLOAT16); - tenosr_desc0.SetOriginFormat(FORMAT_NCHW); - tenosr_desc0.SetOriginDataType(DT_FLOAT); - - vector dim2 = {16, 1, 1, 2}; - GeShape shape2(dim2); - GeTensorDesc tenosr_desc1(shape2, ge::FORMAT_NCHW, ge::DT_FLOAT16); - tenosr_desc1.SetOriginFormat(FORMAT_NCHW); - tenosr_desc1.SetOriginDataType(DT_FLOAT); - - GeTensorDesc tenosr_desc2(tenosr_desc1); - - op_desc1->AddInputDesc(tenosr_desc2); - op_desc1->AddOutputDesc(tenosr_desc1); - - op_desc2->AddInputDesc(tenosr_desc0); - op_desc2->AddOutputDesc(tenosr_desc2); - - op_desc3->AddInputDesc(tenosr_desc2); - op_desc3->AddOutputDesc(tenosr_desc1); - - op_desc4->AddInputDesc(tenosr_desc2); - op_desc4->AddOutputDesc(tenosr_desc1); - - ge::NodePtr node1 = graph->AddNode(op_desc1); - ge::NodePtr node2 = graph->AddNode(op_desc2); - ge::NodePtr node3 = graph->AddNode(op_desc3); - ge::NodePtr node4 = graph->AddNode(op_desc4); - - ge::GraphUtils::AddEdge(node1->GetOutDataAnchor(0), node2->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node2->GetOutDataAnchor(0), node3->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node1->GetOutDataAnchor(0), node4->GetInDataAnchor(0)); - std::map attr_val; - attr_val["_tensor_memory_scope"] = 2; - GraphPassUtil::SetPairTensorAttr(node2, 0, attr_val); - GraphPassUtil::SetPairTensorAttr(node2, 0, attr_val, false); - int64_t scope_0 = -1; - int64_t scope_1 = -1; - int64_t scope_2 = -1; - int64_t scope_3 = -1; - int64_t scope_4 = -1; - (void)ge::AttrUtils::GetInt(op_desc1->MutableOutputDesc(0), "_tensor_memory_scope", scope_0); - (void)ge::AttrUtils::GetInt(op_desc2->MutableInputDesc(0), "_tensor_memory_scope", scope_1); - (void)ge::AttrUtils::GetInt(op_desc2->MutableOutputDesc(0), "_tensor_memory_scope", scope_2); - (void)ge::AttrUtils::GetInt(op_desc3->MutableInputDesc(0), "_tensor_memory_scope", scope_3); - (void)ge::AttrUtils::GetInt(op_desc4->MutableInputDesc(0), "_tensor_memory_scope", scope_4); - EXPECT_EQ(scope_0, 2); - EXPECT_EQ(scope_1, 2); - EXPECT_EQ(scope_2, 2); - EXPECT_EQ(scope_3, 2); - EXPECT_EQ(scope_4, 2); -} - -TEST_F(GraphPassUtilUT, test_set_pair_tensor_attr_with_ge_local) { - auto graph = std::make_shared("test"); - ge::OpDescPtr op_desc1 = std::make_shared("node1", "Relu"); - ge::OpDescPtr op_desc2 = std::make_shared("node2", "Reshape"); - ge::OpDescPtr op_desc3 = std::make_shared("node3", "Relu"); - ge::OpDescPtr op_desc4 = std::make_shared("node4", "Relu"); - - vector dim = {4, 4, 1, 4}; - GeShape shape(dim); - GeTensorDesc tenosr_desc0(shape, ge::FORMAT_NCHW, ge::DT_FLOAT16); - tenosr_desc0.SetOriginFormat(FORMAT_NCHW); - tenosr_desc0.SetOriginDataType(DT_FLOAT); - - vector dim2 = {16, 1, 1, 2}; - GeShape shape2(dim2); - GeTensorDesc tenosr_desc1(shape2, ge::FORMAT_NCHW, ge::DT_FLOAT16); - tenosr_desc1.SetOriginFormat(FORMAT_NCHW); - tenosr_desc1.SetOriginDataType(DT_FLOAT); - - GeTensorDesc tenosr_desc2(tenosr_desc1); - - op_desc1->AddInputDesc(tenosr_desc2); - op_desc1->AddOutputDesc(tenosr_desc1); - - op_desc2->AddInputDesc(tenosr_desc0); - op_desc2->AddOutputDesc(tenosr_desc2); - - op_desc3->AddInputDesc(tenosr_desc2); - op_desc3->AddOutputDesc(tenosr_desc1); - - op_desc4->AddInputDesc(tenosr_desc2); - op_desc4->AddOutputDesc(tenosr_desc1); - - ge::NodePtr node1 = graph->AddNode(op_desc1); - ge::NodePtr node2 = graph->AddNode(op_desc2); - ge::NodePtr node3 = graph->AddNode(op_desc3); - ge::NodePtr node4 = graph->AddNode(op_desc4); - - ge::GraphUtils::AddEdge(node1->GetOutDataAnchor(0), node2->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node2->GetOutDataAnchor(0), node3->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node1->GetOutDataAnchor(0), node4->GetInDataAnchor(0)); - std::map attr_val; - attr_val["_tensor_memory_scope"] = 2; - GraphPassUtil::SetPairTensorAttr(node3, 0, attr_val); - int64_t scope_0 = -1; - int64_t scope_1 = -1; - int64_t scope_2 = -1; - int64_t scope_3 = -1; - int64_t scope_4 = -1; - (void)ge::AttrUtils::GetInt(op_desc1->MutableOutputDesc(0), "_tensor_memory_scope", scope_0); - (void)ge::AttrUtils::GetInt(op_desc2->MutableInputDesc(0), "_tensor_memory_scope", scope_1); - (void)ge::AttrUtils::GetInt(op_desc2->MutableOutputDesc(0), "_tensor_memory_scope", scope_2); - (void)ge::AttrUtils::GetInt(op_desc3->MutableInputDesc(0), "_tensor_memory_scope", scope_3); - (void)ge::AttrUtils::GetInt(op_desc4->MutableInputDesc(0), "_tensor_memory_scope", scope_4); - EXPECT_EQ(scope_0, 2); - EXPECT_TRUE(scope_1 == -1); - EXPECT_TRUE(scope_2 == -1); - EXPECT_EQ(scope_3, 2); - EXPECT_EQ(scope_4, 2); -} -} diff --git a/tests/ut/register/testcase/hidden_input_registry_unittest.cc b/tests/ut/register/testcase/hidden_input_registry_unittest.cc deleted file mode 100644 index f61f45788aa5575cb5cb910966b1cb5dc165f2f6..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/hidden_input_registry_unittest.cc +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/op_desc.h" -#include "register/hidden_input_func_registry.h" -#include -#include -namespace ge { -ge::graphStatus HcomHiddenInputFunc(const ge::OpDescPtr &op_desc, void *&addr) { - addr = reinterpret_cast(0xf1); - return ge::GRAPH_SUCCESS; -} -class HiddenInputFuncRegistryUnittest : public testing::Test {}; - -TEST_F(HiddenInputFuncRegistryUnittest, HcomHiddenFuncRegisterSuccess_Test) { - EXPECT_EQ(ge::HiddenInputFuncRegistry::GetInstance().FindHiddenInputFunc(ge::HiddenInputType::HCOM), nullptr); - REG_HIDDEN_INPUT_FUNC(ge::HiddenInputType::HCOM, HcomHiddenInputFunc); - ge::GetHiddenAddr func = nullptr; - func = ge::HiddenInputFuncRegistry::GetInstance().FindHiddenInputFunc(ge::HiddenInputType::HCOM); - EXPECT_EQ(func, HcomHiddenInputFunc); - const ge::OpDescPtr op_desc = std::make_shared(); - void *res = nullptr; - EXPECT_EQ(func(op_desc, res), ge::GRAPH_SUCCESS); - EXPECT_EQ(reinterpret_cast(res), 0xf1); -} -} // namespace ge diff --git a/tests/ut/register/testcase/hidden_inputs_registry_unittest.cc b/tests/ut/register/testcase/hidden_inputs_registry_unittest.cc deleted file mode 100644 index 0fabcef420b10d4b931773d494b7faa8017483c4..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/hidden_inputs_registry_unittest.cc +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/op_desc.h" -#include "register/hidden_inputs_func_registry.h" -#include -#include -namespace ge { -ge::graphStatus HcomHiddenInputsFunc(const ge::OpDescPtr &op_desc, std::vector &addr) { - addr.push_back(reinterpret_cast(0xf1)); - return ge::GRAPH_SUCCESS; -} -class HiddenInputsFuncRegistryUnittest : public testing::Test {}; - -TEST_F(HiddenInputsFuncRegistryUnittest, HcomHiddenFuncRegisterSuccess_Test) { - EXPECT_EQ(ge::HiddenInputsFuncRegistry::GetInstance().FindHiddenInputsFunc(ge::HiddenInputsType::HCOM), nullptr); - REG_HIDDEN_INPUTS_FUNC(ge::HiddenInputsType::HCOM, HcomHiddenInputsFunc); - ge::GetHiddenAddrs func = nullptr; - func = ge::HiddenInputsFuncRegistry::GetInstance().FindHiddenInputsFunc(ge::HiddenInputsType::HCOM); - EXPECT_EQ(func, HcomHiddenInputsFunc); - const ge::OpDescPtr op_desc = std::make_shared(); - std::vector res; - EXPECT_EQ(func(op_desc, res), ge::GRAPH_SUCCESS); - EXPECT_EQ(reinterpret_cast(res[0U]), 0xf1); -} -} // namespace ge diff --git a/tests/ut/register/testcase/infer_axis_slice_registry_unittest.cc b/tests/ut/register/testcase/infer_axis_slice_registry_unittest.cc deleted file mode 100644 index ad213ab5f1c284c49be71a20d3dc81dd6cfe3933..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/infer_axis_slice_registry_unittest.cc +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "inc/register/infer_axis_slice_registry.h" -#include "graph/operator_factory_impl.h" -#include "inc/external/graph/operator_reg.h" - -namespace ge { -REG_OP(test) - .OP_END_FACTORY_REG(test) -} //namespace ge - -class UtestInferAxisSliceRegister : public testing::Test { -protected: - void SetUp() {} - void TearDown() {} -}; - -IMPLEMT_COMMON_INFER_AXIS_SLICE(InferAxisSliceFunc) { - return ge::GRAPH_SUCCESS; -} -IMPLEMT_COMMON_INFER_AXIS_TYPE_INFO(InferAxisTypeInfoFunc) { - return ge::GRAPH_SUCCESS; -} - -TEST_F(UtestInferAxisSliceRegister, InferAxisSliceFuncRegister_success) { - INFER_AXIS_TYPE_INFO_REG(test, InferAxisTypeInfoFunc); - EXPECT_NE(ge::OperatorFactoryImpl::operator_infer_axis_type_info_funcs_->find("test"), - ge::OperatorFactoryImpl::operator_infer_axis_type_info_funcs_->end()); - - INFER_AXIS_SLICE_FUNC_REG(test, InferAxisSliceFunc); - EXPECT_NE(ge::OperatorFactoryImpl::operator_infer_axis_slice_funcs_->find("test"), - ge::OperatorFactoryImpl::operator_infer_axis_slice_funcs_->end()); - - ge::InferAxisTypeInfoFunc infer_axis_type_info_func = ge::OperatorFactoryImpl::GetInferAxisTypeInfoFunc("test"); - EXPECT_NE(infer_axis_type_info_func, nullptr); - ge::InferAxisSliceFunc infer_axis_slice_func = ge::OperatorFactoryImpl::GetInferAxisSliceFunc("test"); - EXPECT_NE(infer_axis_slice_func, nullptr); -} - -TEST_F(UtestInferAxisSliceRegister, AxisTypeInfo_success) { - ge::AxisTypeInfo axis_type_info; - ge::CutInfo output_cut_info({0U, {0}}); - axis_type_info.AddOutputCutInfo(output_cut_info); - ge::CutInfo input_cut_info({0U, {0}}); - axis_type_info.AddInputCutInfo(input_cut_info); - - ge::CutInfo output_cut_dim; - ge::CutInfo input_cut_dim; - EXPECT_EQ(axis_type_info.GetInputCutInfo(0U, input_cut_dim), ge::GRAPH_SUCCESS); - EXPECT_EQ(axis_type_info.GetOutputCutInfo(0U, output_cut_dim), ge::GRAPH_SUCCESS); - EXPECT_EQ(input_cut_dim.first, input_cut_info.first); - EXPECT_EQ(output_cut_dim.first, output_cut_info.first); -} diff --git a/tests/ut/register/testcase/infer_data_slice_registry_unittest.cc b/tests/ut/register/testcase/infer_data_slice_registry_unittest.cc deleted file mode 100644 index d72d8f07e99a2e12ab1f0234d6a14508dbd39b10..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/infer_data_slice_registry_unittest.cc +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "inc/register/infer_data_slice_registry.h" -#include "graph/operator_factory_impl.h" -namespace ge { - class UtestInferDataSliceRegister : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} - }; - -TEST_F(UtestInferDataSliceRegister, InferDataSliceFuncRegister_success) { - InferDataSliceFunc infer_data_slice_func; - ge::InferDataSliceFuncRegister("test", infer_data_slice_func); - EXPECT_NE(OperatorFactoryImpl::operator_infer_data_slice_funcs_->find("test"), OperatorFactoryImpl::operator_infer_data_slice_funcs_->end()); -} -} diff --git a/tests/ut/register/testcase/inference_rule_unittest.cc b/tests/ut/register/testcase/inference_rule_unittest.cc deleted file mode 100644 index d583947ec375c11b1299f82cad0a50e7a826713f..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/inference_rule_unittest.cc +++ /dev/null @@ -1,308 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "op_desc_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/operator_reg.h" -#include "graph/debug/ge_log.h" -#include "register/shape_inference.h" -#include "register/op_impl_space_registry.h" -#include "register/op_impl_registry_base.h" - -using Json = nlohmann::json; -using namespace gert; - -namespace ge { -REG_OP(RuleInferOp) - .DYNAMIC_INPUT(x, TensorType::ALL()) - .DYNAMIC_OUTPUT(y, TensorType::ALL()) - .OP_END_FACTORY_REG(RuleInferOp); -} // namespace ge - -namespace { -class CtxMaker { - public: - CtxMaker() { - json["shape"]["inputs"] = Json::array(); - json["shape"]["outputs"] = Json::array(); - json["dtype"] = Json::array(); - } - - CtxMaker &Input(const Json::array_t &input, const std::initializer_list runtime_input) { - json["shape"]["inputs"].push_back(input); - compile_inputs.emplace_back(NewShape()); - runtime_inputs.emplace_back(NewShape(runtime_input)); - auto &compile_input = compile_inputs.back()->MutableOriginShape(); - compile_input.SetDimNum(runtime_input.size()); - for (size_t i = 0; i < runtime_input.size(); ++i) { - const auto &dim = input[i]; - if (dim.is_string()) { - compile_input.SetDim(i, -1); - } else if (dim.is_number_integer()) { - const int64_t dim_value = dim.get(); - compile_input.SetDim(i, dim_value); - } else { - compile_input.SetDim(i, -3); - } - } - return *this; - } - - CtxMaker &ValueInput(const Json::array_t &input, const std::initializer_list runtime_input, - ge::DataType dtype) { - json["shape"]["inputs"].push_back(input); - compile_inputs.emplace_back(NewTensor(runtime_input, dtype)); - runtime_inputs.emplace_back(NewTensor(runtime_input, dtype)); - return *this; - } - - CtxMaker &NullInput() { - json["shape"]["inputs"].push_back(nullptr); - compile_inputs.emplace_back(nullptr); - runtime_inputs.emplace_back(nullptr); - return *this; - } - - CtxMaker &Output(const Json::array_t &output) { - json["shape"]["outputs"].push_back(output); - compile_outputs.emplace_back(NewShape()); - runtime_outputs.emplace_back(NewShape()); - return *this; - } - - CtxMaker &Dtypes(const Json::array_t &dtypes) { - json["dtype"] = dtypes; - output_dtypes.resize(dtypes.size(), ge::DataType::DT_UNDEFINED); - for (auto &output_dtype : output_dtypes) { - ctx_dtypes.emplace_back(&output_dtype); - } - return *this; - } - - std::string Str() const { - return json.dump(); - } - - void Build(bool with_rule = true) { - const auto rule_op = std::make_shared("op"); - rule_op->create_dynamic_input_x(compile_inputs.size()); - rule_op->create_dynamic_output_y(compile_outputs.size()); - for (size_t i = 0; i < compile_inputs.size(); ++i) { - if (compile_inputs[i] == nullptr) { - rule_op->UpdateDynamicInputDesc("x", i, ge::TensorDesc()); - continue; - } - auto &storage_shape = compile_inputs[i]->MutableOriginShape(); - std::vector dims; - dims.reserve(storage_shape.GetDimNum()); - for (size_t j = 0; j < storage_shape.GetDimNum(); ++j) { - dims.push_back(storage_shape.GetDim(j)); - } - rule_op->UpdateDynamicInputDesc("x", i, ge::TensorDesc(ge::Shape(dims), ge::FORMAT_ND, ge::DT_FLOAT16)); - } - desc = ge::OpDescUtils::GetOpDescFromOperator(*rule_op); - if (with_rule) { - ge::AttrUtils::SetStr(desc, "_inference_rule", Str()); - } - op = rule_op; - - std::vector inputs; - std::vector outputs; - inputs.reserve(compile_inputs.size()); - for (auto &input : compile_inputs) { - inputs.emplace_back(input); - } - outputs.reserve(compile_outputs.size()); - for (auto &output : compile_outputs) { - outputs.emplace_back(output); - } - - std::vector rt_inputs; - std::vector rt_outputs; - rt_inputs.reserve(runtime_inputs.size()); - for (auto &input : runtime_inputs) { - rt_inputs.emplace_back(input); - } - rt_outputs.reserve(runtime_outputs.size()); - for (auto &output : runtime_outputs) { - rt_outputs.emplace_back(output); - } - } - - ge::OpDescPtr OpDesc() const { - return desc; - } - - ge::Operator &Operator() const { - return *op; - } - - StorageShape *NewShape() { - holders.emplace_back(std::make_shared()); - return holders.back().get(); - } - - StorageShape *NewTensor(const std::initializer_list &runtime_input, ge::DataType dtype) { - values.emplace_back(std::shared_ptr(malloc(sizeof(int64_t) * runtime_input.size()), std::free)); - auto shape = StorageShape({static_cast(runtime_input.size())}, {static_cast(runtime_input.size())}); - tensor_holders.emplace_back(std::make_shared(shape, StorageFormat(), kOnHost, dtype, values.back().get())); - if (dtype == ge::DT_INT32) { - const auto data = tensor_holders.back()->GetData(); - size_t i = 0; - for (const auto dim : runtime_input) { - data[i++] = static_cast(dim); - } - } else if (dtype == ge::DT_INT64) { - const auto data = tensor_holders.back()->GetData(); - size_t i = 0; - for (const auto dim : runtime_input) { - data[i++] = dim; - } - } else if (dtype == ge::DT_UINT32) { - const auto data = tensor_holders.back()->GetData(); - size_t i = 0; - for (const auto dim : runtime_input) { - data[i++] = static_cast(dim); - } - } - return reinterpret_cast(tensor_holders.back().get()); - } - - StorageShape *NewShape(const std::initializer_list &runtime_input) { - holders.emplace_back(std::make_shared(runtime_input, runtime_input)); - return holders.back().get(); - } - - Json json; - std::vector compile_inputs; - std::vector runtime_inputs; - std::vector compile_outputs; - std::vector runtime_outputs; - - std::vector> holders; - - std::vector> values; - std::vector> tensor_holders; - - std::vector ctx_dtypes; - std::vector output_dtypes; - - std::shared_ptr op = nullptr; - ge::OpDescPtr desc = nullptr; -}; -} // namespace - -class InferenceRuleUtest : public testing::Test { - protected: - void SetUp() override { - // construct op impl registry - const auto space_registry = std::make_shared(); - const auto registry_holder = std::make_shared(); - const auto funcs = gert::OpImplRegistry::GetInstance().CreateOrGetOpImpl("RuleInferOp"); - registry_holder->AddTypesToImpl("RuleInferOp", funcs); - space_registry->AddRegistry(registry_holder); - DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - } - - void TearDown() override {} - - static std::string ShapeEqual(Shape *shape, std::initializer_list dims) { - std::stringstream ss; - if (shape == nullptr) { - return "shape == nullptr"; - } - if (shape->GetDimNum() != dims.size()) { - ss << "dim num not equal, expect " << dims.size() << ", got " << shape->GetDimNum(); - return ss.str(); - } - for (size_t i = 0; i < dims.size(); ++i) { - if (shape->GetDim(i) != *(dims.begin() + i)) { - ss << "dim[" << i << "] not equal, expect " << *(dims.begin() + i) << ", got " << shape->GetDim(i); - return ss.str(); - } - } - return ""; - } - - static std::string ShapeEqual(const ge::GeShape &shape, std::initializer_list dims) { - std::stringstream ss; - if (shape.GetDimNum() != dims.size()) { - ss << "dim num not equal, expect " << dims.size() << ", got " << shape.GetDimNum(); - return ss.str(); - } - for (size_t i = 0; i < dims.size(); ++i) { - if (shape.GetDim(i) != *(dims.begin() + i)) { - ss << "dim[" << i << "] not equal, expect " << *(dims.begin() + i) << ", got " << shape.GetDim(i); - return ss.str(); - } - } - return ""; - } -}; - -TEST_F(InferenceRuleUtest, CalledByInferShapeOnCompile) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0"}).Build(); - - const auto desc = ctx_maker.OpDesc(); - ASSERT_EQ(InferShapeOnCompile(ctx_maker.Operator(), desc), ge::GRAPH_SUCCESS); - ASSERT_EQ(ShapeEqual(desc->GetOutputDesc(0).GetShape(), {-1}), ""); -} - -TEST_F(InferenceRuleUtest, CalledByInferShapeOnCompileNoRule) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0"}).Build(false); - - const auto desc = ctx_maker.OpDesc(); - ASSERT_NE(InferShapeOnCompile(ctx_maker.Operator(), desc), ge::GRAPH_SUCCESS); -} - -TEST_F(InferenceRuleUtest, CalledByInferShapeOnCompileInvalidRule) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0+s4"}).Build(); - - const auto desc = ctx_maker.OpDesc(); - ASSERT_NE(InferShapeOnCompile(ctx_maker.Operator(), desc), ge::GRAPH_SUCCESS); -} - -TEST_F(InferenceRuleUtest, CalledByInferDtypeOnCompile) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0"}).Dtypes({ge::DT_FLOAT16}).Build(); - - const auto desc = ctx_maker.OpDesc(); - ASSERT_EQ(InferDataTypeOnCompile(desc), ge::GRAPH_SUCCESS); - ASSERT_EQ(desc->GetOutputDesc(0).GetDataType(), ge::DT_FLOAT16); -} - -TEST_F(InferenceRuleUtest, CalledByInferDtypeOnCompileNoRule) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0"}).Dtypes({ge::DT_FLOAT16}).Build(false); - - const auto desc = ctx_maker.OpDesc(); - ASSERT_NE(InferDataTypeOnCompile(desc), ge::GRAPH_SUCCESS); -} - -TEST_F(InferenceRuleUtest, CalledByInferDtypeOnCompileInvalidRule) { - CtxMaker ctx_maker; - ctx_maker.Input({"s0"}, {32}).Output({"s0"}).Dtypes({ge::DT_UNDEFINED}).Build(); - - const auto desc = ctx_maker.OpDesc(); - ASSERT_NE(InferDataTypeOnCompile(desc), ge::GRAPH_SUCCESS); -} diff --git a/tests/ut/register/testcase/kernel_registry_impl_unittest.cc b/tests/ut/register/testcase/kernel_registry_impl_unittest.cc deleted file mode 100644 index 7b0135937e1e56a5bcf31cdc489421e94c3b2385..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/kernel_registry_impl_unittest.cc +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/kernel_registry_impl.h" -#include -namespace gert { -namespace { -ge::graphStatus TestOutputCreator(const ge::FastNode *, KernelContext *) { - return ge::GRAPH_SUCCESS; -} -KernelStatus TestFunc(KernelContext *) { - return 0; -} -std::vector TestTraceFunc(const gert::KernelContext *) { - return {""}; -} -} -class KernelRegistryImplUT : public testing::Test {}; -TEST_F(KernelRegistryImplUT, RegisterAndFind_Ok_AllFuncRegistered) { - KernelRegistryImpl registry; - registry.RegisterKernel("Foo", {{TestFunc, TestOutputCreator, TestTraceFunc}, ""}); - ASSERT_NE(registry.FindKernelFuncs("Foo"), nullptr); - ASSERT_EQ(registry.FindKernelFuncs("Foo")->run_func, &TestFunc); - ASSERT_EQ(registry.FindKernelFuncs("Foo")->outputs_creator, &TestOutputCreator); - ASSERT_EQ(registry.FindKernelFuncs("Foo")->trace_printer, &TestTraceFunc); -} -TEST_F(KernelRegistryImplUT, RegisterAndFind_Ok_OnlyRegisterRunFunc) { - KernelRegistryImpl registry; - registry.RegisterKernel("Foo", {{TestFunc, nullptr, nullptr}, ""}); - ASSERT_NE(registry.FindKernelFuncs("Foo"), nullptr); - ASSERT_EQ(registry.FindKernelFuncs("Foo")->run_func, &TestFunc); - ASSERT_EQ(registry.FindKernelFuncs("Foo")->outputs_creator, nullptr); - ASSERT_EQ(registry.FindKernelFuncs("Foo")->trace_printer, nullptr); -} -TEST_F(KernelRegistryImplUT, FailedToFindWhenNotRegister) { - KernelRegistryImpl registry; - ASSERT_EQ(registry.FindKernelFuncs("Foo"), nullptr); - ASSERT_EQ(registry.FindKernelInfo("Foo"), nullptr); -} -TEST_F(KernelRegistryImplUT, GetAll_Ok) { - KernelRegistryImpl registry; - registry.RegisterKernel("Foo", {{TestFunc, nullptr, nullptr}, "memory"}); - std::unordered_map expect_kernel_infos = { - {"Foo", {{TestFunc, nullptr, nullptr}, "memory"}}, - {"Bar", {{TestFunc, TestOutputCreator, TestTraceFunc}, "memory"}} - }; - registry.RegisterKernel("Foo", {{TestFunc, nullptr, nullptr}, "memory"}); - registry.RegisterKernel("Bar", {{TestFunc, TestOutputCreator, TestTraceFunc}, "memory"}); - ASSERT_EQ(registry.GetAll().size(), expect_kernel_infos.size()); - for (const auto &key_to_infos : registry.GetAll()) { - ASSERT_TRUE(expect_kernel_infos.count(key_to_infos.first) > 0); - ASSERT_EQ(key_to_infos.second.func.run_func, expect_kernel_infos[key_to_infos.first].func.run_func); - ASSERT_EQ(key_to_infos.second.critical_section, expect_kernel_infos[key_to_infos.first].critical_section); - } -} -} // namespace gert diff --git a/tests/ut/register/testcase/kernel_registry_unittest.cc b/tests/ut/register/testcase/kernel_registry_unittest.cc deleted file mode 100644 index 8b55dcba7182e747b165fca8a7fa0d7d708e9ff5..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/kernel_registry_unittest.cc +++ /dev/null @@ -1,195 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/kernel_registry.h" -#include "register/kernel_registry_impl.h" -#include - -namespace test_gert { -namespace { -ge::graphStatus TestOutputsCreator(const ge::FastNode *, gert::KernelContext *) { - return ge::GRAPH_SUCCESS; -} -std::vector TestTraceFunc(const gert::KernelContext *) { - return {""}; -} -KernelStatus TestFunc1(gert::KernelContext *context) { - return 0; -} -ge::graphStatus TestOutputsCreator2(const ge::FastNode *, gert::KernelContext *) { - return ge::GRAPH_SUCCESS; -} -std::vector TestTraceFunc2(const gert::KernelContext *) { - return {""}; -} -KernelStatus TestFunc2(gert::KernelContext *) { - return 0; -} - -ge::graphStatus ProfilingInfoFillerTest(const gert::KernelContext *, gert::ProfilingInfoWrapper &) { return 0; } - -ge::graphStatus DataDumpInfoFillerTest(const gert::KernelContext *, gert::DataDumpInfoWrapper &) { return 0; } - -ge::graphStatus ExceptionDumpInfoFillerTest(const gert::KernelContext *, gert::ExceptionDumpInfoWrapper &) { return 0; } - -class FakeRegistry : public gert::KernelRegistry { -public: - const KernelFuncs *FindKernelFuncs(const std::string &kernel_type) const override { - static KernelFuncs funcs{nullptr, nullptr, nullptr}; - return &funcs; - } -}; -} // namespace -class KernelRegistryTest : public testing::Test { -protected: - void SetUp() override { - Test::SetUp(); - gert::KernelRegistry::ReplaceKernelRegistry(std::make_shared()); - } - void TearDown() override { - gert::KernelRegistry::ReplaceKernelRegistry(nullptr); - } -}; - -TEST_F(KernelRegistryTest, RegisterKernel_RegisterSuccess_OnlyRegisterRunFunc) { - REGISTER_KERNEL(KernelRegistryTest1).RunFunc(TestFunc1); - auto funcs = gert::KernelRegistry::GetInstance().FindKernelFuncs("KernelRegistryTest1"); - ASSERT_NE(funcs, nullptr); - EXPECT_EQ(funcs->run_func, &TestFunc1); -} - -TEST_F(KernelRegistryTest, RegisterKernel_DefaultFuncOk_OnlyRegisterRunFunc) { - REGISTER_KERNEL(KernelRegistryTest1).RunFunc(TestFunc1); - auto funcs = gert::KernelRegistry::GetInstance().FindKernelFuncs("KernelRegistryTest1"); - ASSERT_NE(funcs, nullptr); - - // output creator 默认函数是啥都不干,直接返回成功 - ASSERT_NE(funcs->outputs_creator, nullptr); - EXPECT_EQ(funcs->outputs_creator(nullptr, nullptr), ge::GRAPH_SUCCESS); - - // trace printer默认值是nullptr - EXPECT_EQ(funcs->trace_printer, nullptr); -} -TEST_F(KernelRegistryTest, RegisterKernel_Success_OutputCreator) { - REGISTER_KERNEL(KernelRegistryTest2) - .OutputsCreator(TestOutputsCreator); - auto funcs = gert::KernelRegistry::GetInstance().FindKernelFuncs("KernelRegistryTest2"); - ASSERT_NE(funcs, nullptr); - EXPECT_EQ(funcs->outputs_creator, &TestOutputsCreator); -} -TEST_F(KernelRegistryTest, RegisterKernel_Success_TraceFunc) { - REGISTER_KERNEL(KernelRegistryTest1).TracePrinter(TestTraceFunc); - auto funcs = gert::KernelRegistry::GetInstance().FindKernelFuncs("KernelRegistryTest1"); - ASSERT_NE(funcs, nullptr); - EXPECT_EQ(funcs->trace_printer, &TestTraceFunc); -} -TEST_F(KernelRegistryTest, RegisterKernel_Success_Register_Multiple) { - REGISTER_KERNEL(KernelRegistryTest1) - .RunFunc(TestFunc1) - .OutputsCreator(TestOutputsCreator) - .TracePrinter(TestTraceFunc); - - REGISTER_KERNEL(KernelRegistryTest2) - .RunFunc(TestFunc2) - .OutputsCreator(TestOutputsCreator2) - .TracePrinter(TestTraceFunc2) - .ProfilingInfoFiller(ProfilingInfoFillerTest) - .DataDumpInfoFiller(DataDumpInfoFillerTest) - .ExceptionDumpInfoFiller(ExceptionDumpInfoFillerTest); - - auto funcs = gert::KernelRegistry::GetInstance().FindKernelFuncs("KernelRegistryTest1"); - ASSERT_NE(funcs, nullptr); - EXPECT_EQ(funcs->run_func, &TestFunc1); - EXPECT_EQ(funcs->outputs_creator, &TestOutputsCreator); - EXPECT_EQ(funcs->trace_printer, &TestTraceFunc); - EXPECT_EQ(funcs->profiling_info_filler, nullptr); - EXPECT_EQ(funcs->data_dump_info_filler, nullptr); - EXPECT_EQ(funcs->exception_dump_info_filler, nullptr); - - funcs = gert::KernelRegistry::GetInstance().FindKernelFuncs("KernelRegistryTest2"); - ASSERT_NE(funcs, nullptr); - EXPECT_EQ(funcs->run_func, &TestFunc2); - EXPECT_EQ(funcs->outputs_creator, &TestOutputsCreator2); - EXPECT_EQ(funcs->trace_printer, &TestTraceFunc2); - EXPECT_EQ(funcs->profiling_info_filler, &ProfilingInfoFillerTest); - EXPECT_EQ(funcs->data_dump_info_filler, &DataDumpInfoFillerTest); - EXPECT_EQ(funcs->exception_dump_info_filler, &ExceptionDumpInfoFillerTest); -} -TEST_F(KernelRegistryTest, RegisterKernel_RegisterOk_SelfDefinedRegistry) { - // SetUp 中已经是SelfDefinedRegistry了 - REGISTER_KERNEL(KernelRegistryTest1) - .RunFunc(TestFunc1) - .OutputsCreator(TestOutputsCreator) - .TracePrinter(TestTraceFunc); - auto funcs = gert::KernelRegistry::GetInstance().FindKernelFuncs("KernelRegistryTest1"); - ASSERT_NE(funcs, nullptr); - EXPECT_EQ(funcs->run_func, &TestFunc1); - EXPECT_EQ(funcs->outputs_creator, &TestOutputsCreator); - EXPECT_EQ(funcs->trace_printer, &TestTraceFunc); -} -TEST_F(KernelRegistryTest, SelfDefinedRegistry_RecoveryOk) { - // 还原为原始的registry - gert::KernelRegistry::ReplaceKernelRegistry(nullptr); - - // 向原始registry注册 - REGISTER_KERNEL(KernelRegistryTest123) - .RunFunc(TestFunc1) - .OutputsCreator(TestOutputsCreator) - .TracePrinter(TestTraceFunc); - - // replace为自己的实现 - gert::KernelRegistry::ReplaceKernelRegistry(std::make_shared()); - REGISTER_KERNEL(KernelRegistryTest123) - .RunFunc(TestFunc2) - .OutputsCreator(TestOutputsCreator2) - .TracePrinter(TestTraceFunc2); - - auto funcs = gert::KernelRegistry::GetInstance().FindKernelFuncs("KernelRegistryTest123"); - ASSERT_NE(funcs, nullptr); - EXPECT_EQ(funcs->run_func, &TestFunc2); - EXPECT_EQ(funcs->outputs_creator, &TestOutputsCreator2); - EXPECT_EQ(funcs->trace_printer, &TestTraceFunc2); - - // 还原为原始的registry - gert::KernelRegistry::ReplaceKernelRegistry(nullptr); - - // 原始注册的func还原成功 - funcs = gert::KernelRegistry::GetInstance().FindKernelFuncs("KernelRegistryTest123"); - ASSERT_NE(funcs, nullptr); - EXPECT_EQ(funcs->run_func, &TestFunc1); - EXPECT_EQ(funcs->outputs_creator, &TestOutputsCreator); - EXPECT_EQ(funcs->trace_printer, &TestTraceFunc); -} -TEST_F(KernelRegistryTest, RegisterKernel_NoEffect_RegDeprectedFunc) { - // SetUp 中已经是SelfDefinedRegistry了 - REGISTER_KERNEL(KernelRegistryTest1) - .OutputsCreator(TestOutputsCreator); - auto funcs = gert::KernelRegistry::GetInstance().FindKernelFuncs("KernelRegistryTest1"); - ASSERT_NE(funcs, nullptr); - ASSERT_NE(funcs->outputs_creator, nullptr); - EXPECT_EQ(funcs->outputs_creator(nullptr, nullptr), ge::GRAPH_SUCCESS); -} -TEST_F(KernelRegistryTest, RegisterKernel_RegisterSuccess_OnlyRegisterCriticalSection) { - REGISTER_KERNEL(KernelRegistryTest1).ConcurrentCriticalSectionKey("memory"); - auto kernel_info = gert::KernelRegistry::GetInstance().FindKernelInfo("KernelRegistryTest1"); - ASSERT_NE(kernel_info, nullptr); - std::string critical_section = kernel_info->critical_section; - EXPECT_EQ(critical_section, "memory"); -} -TEST_F(KernelRegistryTest, RegisterKernel_RegisterSuccess_NotRegisterCriticalSection) { - REGISTER_KERNEL(KernelRegistryTest1).RunFunc(TestFunc1); - auto kernel_info =gert::KernelRegistry::GetInstance().FindKernelInfo("KernelRegistryTest1"); - ASSERT_NE(kernel_info, nullptr); - std::string critical_section = kernel_info->critical_section; - EXPECT_EQ(critical_section, ""); -} -TEST_F(KernelRegistryTest, RegisterKernel_NotRegister_NotFindKernelInfos) { - EXPECT_EQ(gert::KernelRegistry::GetInstance().FindKernelInfo("KernelRegistryTest1"), nullptr); -} -} // namespace test_gert diff --git a/tests/ut/register/testcase/node_converter_registry_unittest.cc b/tests/ut/register/testcase/node_converter_registry_unittest.cc deleted file mode 100644 index 06d697d7ac3e572c4e846e2dc11281b9df709065..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/node_converter_registry_unittest.cc +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/node_converter_registry.h" -#include - -class NodeConverterRegistryUnittest : public testing::Test {}; - -namespace TestNodeConverterRegistry { -gert::LowerResult TestFunc(const ge::NodePtr &node, const gert::LowerInput &lower_input) { - return {}; -} -gert::LowerResult TestFunc2(const ge::NodePtr &node, const gert::LowerInput &lower_input) { - return {}; -} - -TEST_F(NodeConverterRegistryUnittest, RegisterSuccess_DefaultPlacement) { - EXPECT_EQ(gert::NodeConverterRegistry::GetInstance().FindNodeConverter("RegisterSuccess1"), nullptr); - REGISTER_NODE_CONVERTER("RegisterSuccess1", TestFunc); - EXPECT_EQ(gert::NodeConverterRegistry::GetInstance().FindNodeConverter("RegisterSuccess1"), TestFunc); - auto reg_data1 = gert::NodeConverterRegistry::GetInstance().FindRegisterData("RegisterSuccess1"); - ASSERT_NE(reg_data1, nullptr); - EXPECT_EQ(reg_data1->converter, TestFunc); - EXPECT_EQ(reg_data1->require_placement, -1); -} - -TEST_F(NodeConverterRegistryUnittest, RegisterSuccess_WithPlacement) { - EXPECT_EQ(gert::NodeConverterRegistry::GetInstance().FindNodeConverter("RegisterSuccess2"), nullptr); - REGISTER_NODE_CONVERTER_PLACEMENT("RegisterSuccess2", 10, TestFunc2); - EXPECT_EQ(gert::NodeConverterRegistry::GetInstance().FindNodeConverter("RegisterSuccess2"), TestFunc2); - auto reg_data1 = gert::NodeConverterRegistry::GetInstance().FindRegisterData("RegisterSuccess2"); - ASSERT_NE(reg_data1, nullptr); - EXPECT_EQ(reg_data1->converter, TestFunc2); - EXPECT_EQ(reg_data1->require_placement, 10); -} -} diff --git a/tests/ut/register/testcase/op_binary_resource_manager_unittest.cc b/tests/ut/register/testcase/op_binary_resource_manager_unittest.cc deleted file mode 100644 index b6e8fb339d569540cb3815bcb02d94202a229276..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/op_binary_resource_manager_unittest.cc +++ /dev/null @@ -1,199 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include "register/op_binary_resource_manager.h" - -class OpBinaryResourceManagerUT : public testing::Test { -protected: - void SetUp() {} - void TearDown() {} - - nnopbase::OpBinaryResourceManager &manager = nnopbase::OpBinaryResourceManager::GetInstance(); -}; - -TEST_F(OpBinaryResourceManagerUT, SaveFunc) { - int i; - manager.AddOpFuncHandle("AddTik2", {(void *)&i}); - EXPECT_EQ(manager.resourceHandle_.size(), 1); - EXPECT_EQ(manager.resourceHandle_["AddTik2"][0], &i); -} - -std::string AddTik2Json = "{" - " \"binList\": [" - " {" - " \"simplifiedKey\": [" - " \"AddTik2/d=0,p=0/1,2/1,2/1,2\"," - " \"AddTik2/d=1,p=0/1,2/1,2/1,2\"" - " ]," - " \"binInfo\": {" - " \"jsonFilePath\": \"ascend910/add_tik2/Add_Tik2_01.json\"" - " }" - " }," - " {" - " \"simplifiedKey\": [" - " \"AddTik2/d=0,p=0/1,2/0,2/0,2\"," - " \"AddTik2/d=1,p=0/1,2/0,2/0,2\"" - " ]," - " \"binInfo\": {" - " \"jsonFilePath\": \"ascend910/add_tik2/Add_Tik2_02.json\"" - " }" - " }" - " ]" - "}"; -std::string AddTik201Json = "{" - " \"filePath\": \"ascend910/add_tik2/Add_Tik2_01.json\"," - " \"supportInfo\": {" - " \"simplifiedKey\": [" - " \"AddTik2/d=0,p=0/1,2/1,2/1,2\"," - " \"AddTik2/d=1,p=0/1,2/1,2/1,2\"" - " ]" - " }" - "}"; -std::string AddTik201Bin = "01"; -std::string AddTik202Json = "{" - " \"filePath\": \"ascend910/add_tik2/Add_Tik2_02.json\"," - " \"supportInfo\": {" - " \"simplifiedKey\": [" - " \"AddTik2/d=0,p=0/1,2/0,2/0,2\"," - " \"AddTik2/d=1,p=0/1,2/0,2/0,2\"" - " ]" - " }" - "}"; -std::string AddTik202Bin = "02"; - -std::vector> addTik2OpBinary( - {{(const uint8_t*)AddTik2Json.c_str(), (const uint8_t*)AddTik2Json.c_str() + AddTik2Json.size()}, - {(const uint8_t*)AddTik201Json.c_str(), (const uint8_t*)AddTik201Json.c_str() + AddTik201Json.size()}, - {(const uint8_t*)AddTik201Bin.c_str(), (const uint8_t*)AddTik201Bin.c_str() + AddTik201Bin.size()}, - {(const uint8_t*)AddTik202Json.c_str(), (const uint8_t*)AddTik202Json.c_str() + AddTik202Json.size()}, - {(const uint8_t*)AddTik202Bin.c_str(), (const uint8_t*)AddTik202Bin.c_str() + AddTik202Bin.size()}}); - -TEST_F(OpBinaryResourceManagerUT, SaveBinary) { - EXPECT_EQ(manager.AddBinary("AddTik2", addTik2OpBinary), ge::GRAPH_SUCCESS); - EXPECT_EQ(manager.opBinaryDesc_.size(), 1); - auto it = manager.opBinaryDesc_.find("AddTik2"); - ASSERT_NE(it, manager.opBinaryDesc_.end()); - auto list = it->second; - EXPECT_EQ(list.size(), 1); - - auto binIter = manager.pathToBinary_.find("ascend910/add_tik2/Add_Tik2_01.json"); - ASSERT_NE(binIter, manager.pathToBinary_.end()); - auto binJson = std::get<0U>(binIter->second); - auto bin = std::get<1U>(binIter->second); - auto filePath = binJson["filePath"].get(); - EXPECT_EQ(filePath, "ascend910/add_tik2/Add_Tik2_01.json"); - EXPECT_EQ(bin.content, (const uint8_t*)AddTik201Bin.c_str()); - EXPECT_EQ(bin.len, AddTik201Bin.size()); - - EXPECT_EQ(manager.keyToPath_["AddTik2/d=0,p=0/1,2/1,2/1,2"], "ascend910/add_tik2/Add_Tik2_01.json"); - EXPECT_EQ(manager.keyToPath_["AddTik2/d=1,p=0/1,2/1,2/1,2"], "ascend910/add_tik2/Add_Tik2_01.json"); - EXPECT_EQ(manager.keyToPath_["AddTik2/d=0,p=0/1,2/0,2/0,2"], "ascend910/add_tik2/Add_Tik2_02.json"); - EXPECT_EQ(manager.keyToPath_["AddTik2/d=1,p=0/1,2/0,2/0,2"], "ascend910/add_tik2/Add_Tik2_02.json"); -} - -TEST_F(OpBinaryResourceManagerUT, BinaryForJson) { - nlohmann::json binDesc; - EXPECT_EQ(manager.GetOpBinaryDesc("AddTik2", binDesc), ge::GRAPH_SUCCESS); - auto keys = binDesc["binList"][0]["simplifiedKey"].get>(); - EXPECT_EQ(keys[0], "AddTik2/d=0,p=0/1,2/1,2/1,2"); - EXPECT_EQ(keys[1], "AddTik2/d=1,p=0/1,2/1,2/1,2"); - auto jsonFilePath = binDesc["binList"][0]["binInfo"]["jsonFilePath"].get(); - EXPECT_EQ(jsonFilePath, "ascend910/add_tik2/Add_Tik2_01.json"); - - keys = binDesc["binList"][1]["simplifiedKey"].get>(); - EXPECT_EQ(keys[0], "AddTik2/d=0,p=0/1,2/0,2/0,2"); - EXPECT_EQ(keys[1], "AddTik2/d=1,p=0/1,2/0,2/0,2"); - jsonFilePath = binDesc["binList"][1]["binInfo"]["jsonFilePath"].get(); - EXPECT_EQ(jsonFilePath, "ascend910/add_tik2/Add_Tik2_02.json"); -} - -TEST_F(OpBinaryResourceManagerUT, KeyToBinary) { - std::tuple binInfo; - EXPECT_EQ(manager.GetOpBinaryDescByKey("AddTik2/d=1,p=0/1,2/1,2/1,2", binInfo), ge::GRAPH_SUCCESS); - auto binJson = std::get<0U>(binInfo); - auto bin = std::get<1U>(binInfo); - auto filePath = binJson["filePath"].get(); - EXPECT_EQ(filePath, "ascend910/add_tik2/Add_Tik2_01.json"); - EXPECT_EQ(bin.content, (const uint8_t*)AddTik201Bin.c_str()); - EXPECT_EQ(bin.len, AddTik201Bin.size()); - - EXPECT_EQ(manager.GetOpBinaryDescByKey("AddTik2/d=1,p=0/1,2/0,2/0,2", binInfo), ge::GRAPH_SUCCESS); - binJson = std::get<0U>(binInfo); - bin = std::get<1U>(binInfo); - filePath = binJson["filePath"].get(); - EXPECT_EQ(filePath, "ascend910/add_tik2/Add_Tik2_02.json"); - EXPECT_EQ(bin.content, (const uint8_t*)AddTik202Bin.c_str()); - EXPECT_EQ(bin.len, AddTik202Bin.size()); -} - -TEST_F(OpBinaryResourceManagerUT, PathToBinary) { - std::tuple binInfo; - EXPECT_EQ(manager.GetOpBinaryDescByPath("ascend910/add_tik2/Add_Tik2_01.json", binInfo), ge::GRAPH_SUCCESS); - auto binJson = std::get<0U>(binInfo); - auto bin = std::get<1U>(binInfo); - auto filePath = binJson["filePath"].get(); - EXPECT_EQ(filePath, "ascend910/add_tik2/Add_Tik2_01.json"); - EXPECT_EQ(bin.content, (const uint8_t*)AddTik201Bin.c_str()); - EXPECT_EQ(bin.len, AddTik201Bin.size()); - - EXPECT_EQ(manager.GetOpBinaryDescByPath("ascend910/add_tik2/Add_Tik2_02.json", binInfo), ge::GRAPH_SUCCESS); - binJson = std::get<0U>(binInfo); - bin = std::get<1U>(binInfo); - filePath = binJson["filePath"].get(); - EXPECT_EQ(filePath, "ascend910/add_tik2/Add_Tik2_02.json"); - EXPECT_EQ(bin.content, (const uint8_t*)AddTik202Bin.c_str()); - EXPECT_EQ(bin.len, AddTik202Bin.size()); -} - -TEST_F(OpBinaryResourceManagerUT, BinaryAllDesc) { - auto &map = manager.GetAllOpBinaryDesc(); - EXPECT_EQ(map.size(), 1); - auto it = map.find("AddTik2"); - ASSERT_NE(it, map.end()); - - auto keys = (it->second)["binList"][0]["simplifiedKey"].get>(); - EXPECT_EQ(keys[0], "AddTik2/d=0,p=0/1,2/1,2/1,2"); - EXPECT_EQ(keys[1], "AddTik2/d=1,p=0/1,2/1,2/1,2"); - auto jsonFilePath = (it->second)["binList"][0]["binInfo"]["jsonFilePath"].get(); - EXPECT_EQ(jsonFilePath, "ascend910/add_tik2/Add_Tik2_01.json"); - - keys = (it->second)["binList"][1]["simplifiedKey"].get>(); - EXPECT_EQ(keys[0], "AddTik2/d=0,p=0/1,2/0,2/0,2"); - EXPECT_EQ(keys[1], "AddTik2/d=1,p=0/1,2/0,2/0,2"); - jsonFilePath = (it->second)["binList"][1]["binInfo"]["jsonFilePath"].get(); - EXPECT_EQ(jsonFilePath, "ascend910/add_tik2/Add_Tik2_02.json"); -} - -std::string AddTik2KbRuntime = "1234"; -std::vector> addTik2RuntimeKb( - {{(const uint8_t*)AddTik2KbRuntime.c_str(), (const uint8_t*)AddTik2KbRuntime.c_str() + AddTik2KbRuntime.size()}}); - -TEST_F(OpBinaryResourceManagerUT, RuntimeKB) { - EXPECT_EQ(manager.AddRuntimeKB("AddTik2", addTik2RuntimeKb), ge::GRAPH_SUCCESS); - // 重复添加正常 - EXPECT_EQ(manager.AddRuntimeKB("AddTik2", addTik2RuntimeKb), ge::GRAPH_SUCCESS); - std::vector kbList; - EXPECT_EQ(manager.GetOpRuntimeKB("AddTik2", kbList), ge::GRAPH_SUCCESS); - EXPECT_EQ(kbList.size(), 1); - EXPECT_EQ(AddTik2KbRuntime, kbList[0].GetString()); -} - -TEST_F(OpBinaryResourceManagerUT, Error) { - nlohmann::json binDesc; - EXPECT_EQ(manager.GetOpBinaryDesc("AddTik2Invalid",binDesc), ge::GRAPH_PARAM_INVALID); - - std::tuple binInfo; - EXPECT_EQ(manager.GetOpBinaryDescByPath("AddTik2Invalid", binInfo), ge::GRAPH_PARAM_INVALID); - EXPECT_EQ(manager.GetOpBinaryDescByKey("AddTik2Invalid", binInfo), ge::GRAPH_PARAM_INVALID); - - std::vector kbList; - EXPECT_EQ(manager.GetOpRuntimeKB("AddTik2Invalid", kbList), ge::GRAPH_PARAM_INVALID); -} diff --git a/tests/ut/register/testcase/op_calc_param_unittest.cc b/tests/ut/register/testcase/op_calc_param_unittest.cc deleted file mode 100644 index 56eb88bcc87f756094de9d1942aaec0d192e5fad..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/op_calc_param_unittest.cc +++ /dev/null @@ -1,26 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/op_ext_calc_param_registry.h" -#include "proto/task.pb.h" -#include - -class OpExtCalcParamRegistryUnittest : public testing::Test {}; - -namespace TestOpExtCalcParamRegistry { -ge::Status OpExtCalcParam(const ge::Node &node) { - return 0; -} - -TEST_F(OpExtCalcParamRegistryUnittest, OpExtCalcParamRegisterSuccess_Test) { - EXPECT_EQ(fe::OpExtCalcParamRegistry::GetInstance().FindRegisterFunc("test"), nullptr); - REGISTER_NODE_EXT_CALC_PARAM("test", OpExtCalcParam); - EXPECT_EQ(fe::OpExtCalcParamRegistry::GetInstance().FindRegisterFunc("test"), OpExtCalcParam); -} -} diff --git a/tests/ut/register/testcase/op_check_unittest.cc b/tests/ut/register/testcase/op_check_unittest.cc deleted file mode 100644 index a305d98020c0241f71ef4c6011ac2ff5402bf6fd..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/op_check_unittest.cc +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include "register/op_check_register.h" - -bool testFunc(const ge::Operator &op, ge::AscendString &result) { - return true; -} - -namespace { - -class OpCheckAPIUT : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -TEST_F(OpCheckAPIUT, APITest) { - ge::AscendString example("test"); - optiling::GEN_SIMPLIFIEDKEY_FUNC pFunc; - pFunc = testFunc; - optiling::OpCheckFuncRegistry::RegisterGenSimplifiedKeyFunc(example, pFunc); - - ge::AscendString errName("notExisted"); - auto nullFunc = optiling::OpCheckFuncRegistry::GetGenSimplifiedKeyFun(errName); - EXPECT_EQ(nullFunc, nullptr); - - auto func = optiling::OpCheckFuncRegistry::GetGenSimplifiedKeyFun(example); - EXPECT_NE(func, nullptr); -} - -} // namespace diff --git a/tests/ut/register/testcase/op_ct_impl_registry_unittest.cc b/tests/ut/register/testcase/op_ct_impl_registry_unittest.cc deleted file mode 100644 index 8b54fc3405c7b7e9c72451249d168d2712134dcb..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/op_ct_impl_registry_unittest.cc +++ /dev/null @@ -1,140 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/op_ct_impl_registry.h" -#include "register/op_impl_registry_base.h" -#include -#include "exe_graph/runtime/kernel_context.h" -#include "graph/any_value.h" -#include "register/op_ct_impl_registry_api.h" -#include "exe_graph/runtime/exe_res_generation_context.h" -#include "register/op_ct_impl_kernel_registry.h" - -namespace gert_test { -namespace { -ge::graphStatus CalcParamKernelFunc(gert::ExeResGenerationContext *context) { - return ge::GRAPH_SUCCESS; -} -ge::graphStatus GenTaskKernelFunc(const gert::ExeResGenerationContext *context, std::vector> &tasks) { - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus CheckSupportFunc(const gert::OpCheckContext *context, ge::AscendString &result) { - return ge::GRAPH_SUCCESS; -} -ge::graphStatus OpSelectFormatFunc(const gert::OpCheckContext *context, ge::AscendString &result) { - return ge::GRAPH_SUCCESS; -} -ge::graphStatus GetOpSpecificInfoFunc(const gert::OpCheckContext *context, ge::AscendString &result) { - return ge::GRAPH_SUCCESS; -} -} // namespace -class OpCtImplRegistryUT : public testing::Test { - - protected: - virtual void TearDown() { - gert::OpCtImplRegistry::GetInstance().GetAllTypesToImpl().clear(); - } -}; - -TEST_F(OpCtImplRegistryUT, Register_Success_RegisterAll) { - auto funcs = gert::OpCtImplRegistry::GetInstance().GetOpImpl("TestFoo"); - ASSERT_EQ(funcs, nullptr); - - IMPL_OP_CT(TestFoo) - .CalcOpParam(CalcParamKernelFunc) - .GenerateTask(GenTaskKernelFunc); - auto impl_num = GetRegisteredOpCtNum(); - EXPECT_EQ(impl_num, 1); - funcs = gert::OpCtImplRegistry::GetInstance().GetOpImpl("TestFoo"); - ASSERT_NE(funcs, nullptr); - EXPECT_EQ(funcs->calc_op_param, &CalcParamKernelFunc); - EXPECT_EQ(funcs->gen_task, &GenTaskKernelFunc); - - IMPL_OP_CT(TestConv2D).CalcOpParam(CalcParamKernelFunc).GenerateTask(GenTaskKernelFunc); - impl_num = GetRegisteredOpCtNum(); - auto impl_funcs = std::unique_ptr(new(std::nothrow) TypesToCtImpl[impl_num]); - EXPECT_EQ(impl_num, 2); - auto ret = GetOpCtImplFunctions(reinterpret_cast(impl_funcs.get()), impl_num); - EXPECT_NE(ret, ge::GRAPH_FAILED); - for (size_t i = 0; i < impl_num; ++i) { - EXPECT_EQ(impl_funcs[i].funcs.calc_op_param, &CalcParamKernelFunc); - EXPECT_EQ(impl_funcs[i].funcs.gen_task, &GenTaskKernelFunc); - } -} - -TEST_F(OpCtImplRegistryUT, Register_Ct_version_test) { - auto funcs = gert::OpCtImplRegistry::GetInstance().GetOpImpl("TestFoo"); - ASSERT_EQ(funcs, nullptr); - - IMPL_OP_CT(TestFoo) - .CalcOpParam(CalcParamKernelFunc) - .GenerateTask(GenTaskKernelFunc); - auto impl_num = GetRegisteredOpCtNum(); - EXPECT_EQ(impl_num, 1); - funcs = gert::OpCtImplRegistry::GetInstance().GetOpImpl("TestFoo"); - ASSERT_NE(funcs, nullptr); - EXPECT_EQ(funcs->calc_op_param, &CalcParamKernelFunc); - EXPECT_EQ(funcs->gen_task, &GenTaskKernelFunc); - - IMPL_OP_CT(TestConv2D).CalcOpParam(CalcParamKernelFunc).GenerateTask(GenTaskKernelFunc); - impl_num = GetRegisteredOpCtNum(); - size_t real_size = sizeof(gert::OpCtImplKernelRegistry::OpCtImplFunctions) + 8; - size_t offset = real_size + sizeof(char*); - auto mem_ptr = std::unique_ptr(new(std::nothrow) uint8_t [impl_num * offset]); - EXPECT_EQ(impl_num, 2); - for (size_t i = 0; i < impl_num; ++i) { - auto tmp_impl = reinterpret_cast(mem_ptr.get() + offset * i); - tmp_impl->funcs.version = 2; - tmp_impl->funcs.st_size = real_size; - } - auto ret = GetOpCtImplFunctions(reinterpret_cast(mem_ptr.get()), impl_num); - EXPECT_NE(ret, ge::GRAPH_FAILED); - for (size_t i = 0; i < impl_num; ++i) { - auto tmp_impl = reinterpret_cast(mem_ptr.get() + offset * i); - EXPECT_EQ(tmp_impl->funcs.version, 1); - EXPECT_EQ(tmp_impl->funcs.st_size, sizeof(gert::OpCtImplKernelRegistry::OpCtImplFunctions)); - EXPECT_EQ(tmp_impl->funcs.calc_op_param, &CalcParamKernelFunc); - EXPECT_EQ(tmp_impl->funcs.gen_task, &GenTaskKernelFunc); - } -} - -TEST_F(OpCtImplRegistryUT, register_all_success) { - auto funcs = gert::OpCtImplRegistry::GetInstance().GetOpImpl("TestFoo"); - ASSERT_EQ(funcs, nullptr); - - IMPL_OP_CT(TestFoo) - .CheckSupport(CheckSupportFunc) - .OpSelectFormat(OpSelectFormatFunc) - .GetOpSpecificInfo(GetOpSpecificInfoFunc); - auto impl_num = GetRegisteredOpCtNum(); - EXPECT_EQ(impl_num, 1); - funcs = gert::OpCtImplRegistry::GetInstance().GetOpImpl("TestFoo"); - ASSERT_NE(funcs, nullptr); - EXPECT_EQ(funcs->check_support, &CheckSupportFunc); - EXPECT_EQ(funcs->op_select_format, &OpSelectFormatFunc); - EXPECT_EQ(funcs->get_op_specific_info, &GetOpSpecificInfoFunc); - - IMPL_OP_CT(TestConv2D) - .CheckSupport(CheckSupportFunc) - .OpSelectFormat(OpSelectFormatFunc) - .GetOpSpecificInfo(GetOpSpecificInfoFunc); - impl_num = GetRegisteredOpCtNum(); - auto impl_funcs = std::unique_ptr(new(std::nothrow) TypesToCtImpl[impl_num]); - EXPECT_EQ(impl_num, 2); - auto ret = GetOpCtImplFunctions(reinterpret_cast(impl_funcs.get()), impl_num); - EXPECT_NE(ret, ge::GRAPH_FAILED); - for (size_t i = 0; i < impl_num; ++i) { - EXPECT_EQ(impl_funcs[i].funcs.check_support, &CheckSupportFunc); - EXPECT_EQ(impl_funcs[i].funcs.op_select_format, &OpSelectFormatFunc); - EXPECT_EQ(impl_funcs[i].funcs.get_op_specific_info, &GetOpSpecificInfoFunc); - } -} - -} // namespace gert_test diff --git a/tests/ut/register/testcase/op_def_aicore_unittest.cc b/tests/ut/register/testcase/op_def_aicore_unittest.cc deleted file mode 100644 index 143c8077adc50579338ec4049a98aa2dd9f7f271..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/op_def_aicore_unittest.cc +++ /dev/null @@ -1,83 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include "register/op_def_registry.h" - -namespace ops { - -namespace { - -class OpDefAICoreUT : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -TEST_F(OpDefAICoreUT, AICoreTest) { - OpAICoreDef aicoreDef; - OpAICoreConfig config; - config.DynamicCompileStaticFlag(true) - .DynamicFormatFlag(true) - .DynamicRankSupportFlag(true) - .DynamicShapeSupportFlag(true) - .NeedCheckSupportFlag(true) - .PrecisionReduceFlag(true); - std::map cfgs = config.GetCfgInfo(); - EXPECT_EQ(cfgs["dynamicCompileStatic.flag"], "true"); - EXPECT_EQ(cfgs["dynamicFormat.flag"], "true"); - EXPECT_EQ(cfgs["dynamicRankSupport.flag"], "true"); - EXPECT_EQ(cfgs["dynamicShapeSupport.flag"], "true"); - EXPECT_EQ(cfgs["needCheckSupport.flag"], "true"); - EXPECT_EQ(cfgs["precision_reduce.flag"], "true"); - config.DynamicCompileStaticFlag(false) - .DynamicFormatFlag(false) - .DynamicRankSupportFlag(false) - .DynamicShapeSupportFlag(false) - .NeedCheckSupportFlag(false) - .PrecisionReduceFlag(false); - cfgs = config.GetCfgInfo(); - EXPECT_EQ(cfgs["dynamicCompileStatic.flag"], "false"); - EXPECT_EQ(cfgs["dynamicFormat.flag"], "false"); - EXPECT_EQ(cfgs["dynamicRankSupport.flag"], "false"); - EXPECT_EQ(cfgs["dynamicShapeSupport.flag"], "false"); - EXPECT_EQ(cfgs["needCheckSupport.flag"], "false"); - EXPECT_EQ(cfgs["precision_reduce.flag"], "false"); - aicoreDef.AddConfig("ascend310p", config); - aicoreDef.AddConfig("ascend910", config); - aicoreDef.AddConfig("ascend310p", config); - std::map aicfgs = aicoreDef.GetAICoreConfigs(); - EXPECT_TRUE(aicfgs.find("ascend310p") != aicfgs.end()); - EXPECT_EQ(aicfgs.size(), 2); - aicoreDef.AddConfig("ascend310p"); - aicfgs = aicoreDef.GetAICoreConfigs(); - config = aicfgs["ascend310p"]; - cfgs = config.GetCfgInfo(); - EXPECT_EQ(cfgs["dynamicCompileStatic.flag"], "true"); - EXPECT_EQ(cfgs["dynamicFormat.flag"], "true"); - EXPECT_EQ(cfgs["dynamicRankSupport.flag"], "true"); - EXPECT_EQ(cfgs["dynamicShapeSupport.flag"], "true"); - EXPECT_EQ(cfgs["needCheckSupport.flag"], "false"); - EXPECT_EQ(cfgs["precision_reduce.flag"], "true"); - - // test default aicore config - OpAICoreConfig configDefault("ascend310p"); - cfgs = configDefault.GetCfgInfo(); - EXPECT_EQ(cfgs["dynamicCompileStatic.flag"], "true"); - EXPECT_EQ(cfgs["dynamicFormat.flag"], "true"); - EXPECT_EQ(cfgs["dynamicRankSupport.flag"], "true"); - EXPECT_EQ(cfgs["dynamicShapeSupport.flag"], "true"); - EXPECT_EQ(cfgs["needCheckSupport.flag"], "false"); - EXPECT_EQ(cfgs["precision_reduce.flag"], "true"); -} - -} // namespace -} // namespace ops diff --git a/tests/ut/register/testcase/op_def_api_unittest.cc b/tests/ut/register/testcase/op_def_api_unittest.cc deleted file mode 100644 index 4f9a1eeaf80d61442f8b2836c675d50b0e7f985c..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/op_def_api_unittest.cc +++ /dev/null @@ -1,94 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include "register/op_def_registry.h" - -namespace ge { - -static ge::graphStatus InferShape4AddAscendC(gert::InferShapeContext *context) { - return GRAPH_SUCCESS; -} - -static ge::graphStatus InferShapeRange4AddAscendC(gert::InferShapeRangeContext *context) { - return GRAPH_SUCCESS; -} - -static ge::graphStatus InferDataType4AddAscendC(gert::InferDataTypeContext *context) { - return GRAPH_SUCCESS; -} - -} // namespace ge - -namespace optiling { - -static ge::graphStatus TilingAscendCAdd(gert::TilingContext *context) { - return ge::GRAPH_SUCCESS; -} - -static ge::graphStatus check_op_support(const ge::Operator &op, ge::AscendString &result) { - return ge::GRAPH_SUCCESS; -} - -static ge::graphStatus get_op_support(const ge::Operator &op, ge::AscendString &result) { - return ge::GRAPH_SUCCESS; -} - -static ge::graphStatus op_select_format(const ge::Operator &op, ge::AscendString &result) { - return ge::GRAPH_SUCCESS; -} - -static ge::graphStatus get_op_specific_info(const ge::Operator &op, ge::AscendString &result) { - return ge::GRAPH_SUCCESS; -} - -static ge::graphStatus generalize_config(const ge::Operator &op, const ge::AscendString &generalize_config, - ge::AscendString &generalize_para) { - return ge::GRAPH_SUCCESS; -} - -} // namespace optiling - -namespace ops { - -namespace { - -class OpDefAPIUT : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -TEST_F(OpDefAPIUT, APITest) { - OpDef opDef("Test"); - opDef.SetInferShape(ge::InferShape4AddAscendC); - opDef.SetInferShapeRange(ge::InferShapeRange4AddAscendC); - opDef.SetInferDataType(ge::InferDataType4AddAscendC); - opDef.AICore().SetTiling(optiling::TilingAscendCAdd); - opDef.AICore() - .SetCheckSupport(optiling::check_op_support) - .SetOpSelectFormat(optiling::op_select_format) - .SetOpSupportInfo(optiling::get_op_support) - .SetOpSpecInfo(optiling::get_op_specific_info) - .SetParamGeneralize(optiling::generalize_config); - EXPECT_EQ(opDef.GetInferShape(), ge::InferShape4AddAscendC); - EXPECT_EQ(opDef.GetInferShapeRange(), ge::InferShapeRange4AddAscendC); - EXPECT_EQ(opDef.GetInferDataType(), ge::InferDataType4AddAscendC); - EXPECT_EQ(opDef.AICore().GetTiling(), optiling::TilingAscendCAdd); - EXPECT_EQ(opDef.AICore().GetCheckSupport(), optiling::check_op_support); - EXPECT_EQ(opDef.AICore().GetOpSelectFormat(), optiling::op_select_format); - EXPECT_EQ(opDef.AICore().GetOpSupportInfo(), optiling::get_op_support); - EXPECT_EQ(opDef.AICore().GetOpSpecInfo(), optiling::get_op_specific_info); - EXPECT_EQ(opDef.AICore().GetParamGeneralize(), optiling::generalize_config); -} - -} // namespace -} // namespace ops diff --git a/tests/ut/register/testcase/op_def_attr_unittest.cc b/tests/ut/register/testcase/op_def_attr_unittest.cc deleted file mode 100644 index 20daffad477900d4d3a0f24d75aae872333e2f17..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/op_def_attr_unittest.cc +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include "register/op_def_registry.h" - -namespace ops { - -namespace { - -class OpAttrDefUT : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -TEST_F(OpAttrDefUT, AttrTest) { - OpDef opDef("Test"); - OpAttrDef attr("Test"); - OpAttrDef attr2("Test"); - OpAttrDef attr3("Test1"); - EXPECT_EQ(attr == attr2, true); - EXPECT_EQ(attr == attr3, false); - attr.Bool(); - EXPECT_EQ(attr.GetCfgDataType(), "bool"); - EXPECT_EQ(attr.GetProtoDataType(), "Bool"); - opDef.Attr("Test"); - EXPECT_EQ(opDef.GetAttrs().size(), 1); - opDef.Attr("Test"); - EXPECT_EQ(opDef.GetAttrs().size(), 1); - opDef.Attr("Test1"); - EXPECT_EQ(opDef.GetAttrs().size(), 2); - attr.AttrType(Option::OPTIONAL).Bool(true); - EXPECT_EQ(attr.GetAttrDefaultVal("[]"), "true"); - attr.AttrType(Option::OPTIONAL).Int(10); - EXPECT_EQ(attr.GetAttrDefaultVal("[]"), "10"); - attr.AttrType(Option::OPTIONAL).String("test"); - EXPECT_EQ(attr.GetAttrDefaultVal("[]"), "test"); - attr.AttrType(Option::OPTIONAL).Float(0.1); - EXPECT_EQ(attr.GetAttrDefaultVal("[]"), "0.1"); - attr.AttrType(Option::OPTIONAL).ListBool({true, false}); - EXPECT_EQ(attr.GetAttrDefaultVal("[]"), "[true,false]"); - attr.AttrType(Option::OPTIONAL).ListFloat({0.1, 0.1}); - EXPECT_EQ(attr.GetAttrDefaultVal("[]"), "[0.1,0.1]"); - attr.AttrType(Option::OPTIONAL).ListInt({1, 2}); - EXPECT_EQ(attr.GetAttrDefaultVal("[]"), "[1,2]"); - attr.AttrType(Option::OPTIONAL).ListListInt({{1, 2}, {3, 4}}); - EXPECT_EQ(attr.GetAttrDefaultVal("[]"), "[[1,2],[3,4]]"); - attr.Version(1); - EXPECT_EQ(attr.GetVersion(), 1); -} -TEST_F(OpAttrDefUT, CommentSingleTest) { - OpAttrDef attr("Test"); - attr.Comment("") - .Comment("comment of Attr Test"); - EXPECT_EQ(attr.GetComment(), "comment of Attr Test"); -} -TEST_F(OpAttrDefUT, CommentCombineTest) { - OpDef opDef("Test"); - opDef.Attr("Test") - .Comment("") - .Comment("comment of Attr Test"); - EXPECT_EQ(opDef.GetAttrs().size(), 1); - EXPECT_EQ(opDef.GetAttrs().at(0).GetComment(), "comment of Attr Test"); -} -} // namespace -} // namespace ops diff --git a/tests/ut/register/testcase/op_def_factory_unittest.cc b/tests/ut/register/testcase/op_def_factory_unittest.cc deleted file mode 100644 index 43bc26488fe3d666b1431ef53d870e003555f347..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/op_def_factory_unittest.cc +++ /dev/null @@ -1,146 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include "register/op_def_registry.h" -#include "register/op_config_registry.h" -#include "register/device_op_impl_registry.h" -#include "register/opdef/op_config_registry_impl.h" - -namespace ops { -namespace { -class OpDefFactoryUT : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -class AddAscendC : public OpDef { - public: - AddAscendC(const char *name) : OpDef(name) {} -}; - -OP_ADD(AddAscendC, None); - -class AddCustomRegMacro : public OpDef { - public: - AddCustomRegMacro(const char *name) : OpDef(name) {} -}; - -REGISTER_OP_AICORE_CONFIG(AddCustomRegMacro, ascendxxxy, []() { - ops::OpAICoreConfig config("ascendxxxy"); - return config; -}); - -OP_ADD(AddCustomRegMacro); - -class AddCustomAddConfigWithRegMacro : public OpDef { - public: - AddCustomAddConfigWithRegMacro(const char *name) : OpDef(name) { - OpAICoreConfig aicoreConfig; - aicoreConfig.DynamicCompileStaticFlag(false) - .DynamicFormatFlag(false) - .DynamicRankSupportFlag(false) - .DynamicShapeSupportFlag(false) - .NeedCheckSupportFlag(false) - .PrecisionReduceFlag(false); - this->AICore().AddConfig("ascend111y", aicoreConfig); - } -}; - -REGISTER_OP_AICORE_CONFIG(AddCustomAddConfigWithRegMacro, ascend111y, []() { - ops::OpAICoreConfig config; - config.DynamicCompileStaticFlag(true) - .DynamicFormatFlag(true) - .DynamicRankSupportFlag(true) - .DynamicShapeSupportFlag(true) - .NeedCheckSupportFlag(true) - .PrecisionReduceFlag(true); - return config; -}); - -OP_ADD(AddCustomAddConfigWithRegMacro); - - -TEST_F(OpDefFactoryUT, OpDefFactoryTest) { - auto &ops = OpDefFactory::GetAllOp(); - EXPECT_EQ(ops.size(), 3); - EXPECT_EQ(std::string(ops[0].GetString()), "AddAscendC"); - EXPECT_EQ(std::string(ops[1].GetString()), "AddCustomRegMacro"); - EXPECT_EQ(std::string(ops[2].GetString()), "AddCustomAddConfigWithRegMacro"); - if (std::string(ops[0].GetString()) == "AddAscendC") { - OpDef opDef = OpDefFactory::OpDefCreate(ops[0].GetString()); - EXPECT_EQ(opDef.GetOpType(), "AddAscendC"); - } -} - -TEST_F(OpDefFactoryUT, DeviceOpImplRegisterUt) { - auto op_device_register_tmp = optiling::DeviceOpImplRegister("AddAscendC"); - optiling::DeviceOpImplRegister op_device_register_tmp2 = optiling::DeviceOpImplRegister("AddAscendC"); - optiling::DeviceOpImplRegister op_device_register_move(std::move(op_device_register_tmp)); - optiling::DeviceOpImplRegister op_device_register = op_device_register_move; - op_device_register.Tiling((optiling::SinkTilingFunc)nullptr); - EXPECT_EQ(ops::OpDefFactory::OpIsTilingSink("AddAscendC"), true); -} - -TEST_F(OpDefFactoryUT, RegisterOpAICoreConfigTest) { - auto regConfigs = GetOpAllAICoreConfig("AddCustomRegMacro"); - EXPECT_EQ(regConfigs.size(), 1); - auto it = regConfigs.cbegin(); - EXPECT_EQ(std::string((it->first).GetString()), "ascendxxxy"); - EXPECT_NE(it->second, nullptr); - - OpDef opDef = OpDefFactory::OpDefCreate("AddCustomRegMacro"); - std::map aicfgs = opDef.AICore().GetAICoreConfigs(); - EXPECT_TRUE(aicfgs.find("ascendxxxy") != aicfgs.end()); - EXPECT_EQ(aicfgs.size(), 1); - OpAICoreConfig config = aicfgs["ascendxxxy"]; - std::map cfgs = config.GetCfgInfo(); - // shoud be default config - EXPECT_EQ(cfgs["dynamicCompileStatic.flag"], "true"); - EXPECT_EQ(cfgs["dynamicFormat.flag"], "true"); - EXPECT_EQ(cfgs["dynamicRankSupport.flag"], "true"); - EXPECT_EQ(cfgs["dynamicShapeSupport.flag"], "true"); - EXPECT_EQ(cfgs["needCheckSupport.flag"], "false"); - EXPECT_EQ(cfgs["precision_reduce.flag"], "true"); - - ops::OpConfigRegistry configRegistry; - configRegistry.RegisterOpAICoreConfig(nullptr, nullptr, nullptr); - configRegistry.RegisterOpAICoreConfig("AddCustomNullptr", nullptr, nullptr); - - OpConfigRegistryImpl::GetInstance().AddAICoreConfig(nullptr, nullptr, nullptr); - OpConfigRegistryImpl::GetInstance().AddAICoreConfig("AddCustomNullptr", nullptr, nullptr); - OpConfigRegistryImpl::GetInstance().GetOpAllAICoreConfig(nullptr); -} - -TEST_F(OpDefFactoryUT, AddConfigWithRegMacroTest) { - auto regConfigs = GetOpAllAICoreConfig("AddCustomAddConfigWithRegMacro"); - EXPECT_EQ(regConfigs.size(), 1); - auto it = regConfigs.cbegin(); - EXPECT_EQ(std::string((it->first).GetString()), "ascend111y"); - EXPECT_NE(it->second, nullptr); - - OpDef opDef = OpDefFactory::OpDefCreate("AddCustomAddConfigWithRegMacro"); - std::map aicfgs = opDef.AICore().GetAICoreConfigs(); - EXPECT_TRUE(aicfgs.find("ascend111y") != aicfgs.end()); - EXPECT_EQ(aicfgs.size(), 1); - OpAICoreConfig config = aicfgs["ascend111y"]; - std::map cfgs = config.GetCfgInfo(); - // custom config set by opdef addconfig should overwrite the config set by REGISTER_OP_AICORE_CONFIG - EXPECT_EQ(cfgs["dynamicCompileStatic.flag"], "false"); - EXPECT_EQ(cfgs["dynamicFormat.flag"], "false"); - EXPECT_EQ(cfgs["dynamicRankSupport.flag"], "false"); - EXPECT_EQ(cfgs["dynamicShapeSupport.flag"], "false"); - EXPECT_EQ(cfgs["needCheckSupport.flag"], "false"); - EXPECT_EQ(cfgs["precision_reduce.flag"], "false"); -} -} // namespace -} // namespace ops diff --git a/tests/ut/register/testcase/op_def_param_unittest.cc b/tests/ut/register/testcase/op_def_param_unittest.cc deleted file mode 100644 index 858dee353cca968a3eb543e3bef738f63189906c..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/op_def_param_unittest.cc +++ /dev/null @@ -1,261 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include "register/op_def_registry.h" -#include "register/opdef/op_def_impl.h" - -namespace ops { - -namespace { - -class OpDefParamUT : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -TEST_F(OpDefParamUT, ParamTest) { - OpParamDef param("test"); - OpParamDef param2("test"); - OpParamDef param3("test3"); - EXPECT_EQ(param == param2, true); - EXPECT_EQ(param == param3, false); - OpParamTrunk desc; - desc.Input("x1") - .ParamType(Option::OPTIONAL) - .DataType({ge::DT_FLOAT16}) - .Format({ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_NCHW}) - .ValueDepend(Option::REQUIRED) - .IgnoreContiguous() - .AutoContiguous(); - desc.Input("x2") - .ParamType(Option::OPTIONAL) - .DataType({ge::DT_FLOAT16}) - .Format({ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND}) - .ValueDepend(Option::REQUIRED); - desc.Input("x2") - .ParamType(Option::OPTIONAL) - .DataType({ge::DT_FLOAT16}) - .Format({ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND}) - .ValueDepend(Option::OPTIONAL); - desc.Input("x3") - .ParamType(Option::OPTIONAL) - .DataType({ge::DT_FLOAT16}) - .Format({ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND}) - .Scalar(); - desc.Input("x4") - .ParamType(Option::OPTIONAL) - .DataType({ge::DT_FLOAT16}) - .Format({ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND}) - .ScalarList(); - desc.Output("y") - .ParamType(Option::OPTIONAL) - .DataType({ge::DT_FLOAT16}) - .Format({ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND}) - .ValueDepend(Option::REQUIRED) - .IgnoreContiguous() - .AutoContiguous() - .OutputShapeDependOnCompute(); - EXPECT_EQ(desc.Input("x1").GetParamName(), "x1"); - EXPECT_EQ(desc.Input("x1").GetParamType(), Option::OPTIONAL); - EXPECT_EQ(desc.Input("x1").GetDataTypes().size(), 1); - EXPECT_EQ(desc.Input("x1").GetFormats().size(), 1); - EXPECT_EQ(desc.Input("x1").GetUnknownShapeFormats().size(), 1); - EXPECT_EQ(desc.Input("x1").GetUnknownShapeFormats()[0], ge::FORMAT_NCHW); - EXPECT_EQ(desc.Input("x1").GetValueDepend(), "required"); - EXPECT_EQ(desc.Input("x1").GetIgnoreContiguous(), true); - EXPECT_EQ(desc.Input("x1").GetAutoContiguous(), true); - EXPECT_EQ(desc.Input("x2").GetValueDepend(), "optional"); - EXPECT_EQ(desc.Input("x2").GetIgnoreContiguous(), false); - EXPECT_EQ(desc.Input("x2").GetAutoContiguous(), false); - EXPECT_EQ(desc.Output("y").GetIgnoreContiguous(), true); - EXPECT_EQ(desc.Output("y").GetAutoContiguous(), true); - EXPECT_EQ(desc.GetInputs().size(), 4); - EXPECT_EQ(desc.GetOutputs().size(), 1); - EXPECT_EQ(desc.Input("x1").IsScalar(), false); - EXPECT_EQ(desc.Input("x1").IsScalarList(), false); - EXPECT_EQ(desc.Input("x3").IsScalar(), true); - EXPECT_EQ(desc.Input("x3").IsScalarList(), false); - EXPECT_EQ(desc.Input("x4").IsScalar(), false); - EXPECT_EQ(desc.Input("x4").IsScalarList(), true); - EXPECT_EQ(desc.Output("y").IsOutputShapeDependOnCompute(), true); -} - -TEST_F(OpDefParamUT, DependParamTest) { - OpParamTrunk desc; - desc.Input("x1") - .ParamType(Option::OPTIONAL) - .DataType({ge::DT_FLOAT16}) - .Format({ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_NCHW}) - .ValueDepend(Option::REQUIRED) - .IgnoreContiguous() - .AutoContiguous(); - desc.Input("x2") - .ParamType(Option::OPTIONAL) - .DataType({ge::DT_FLOAT16}) - .Format({ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND}) - .ValueDepend(Option::REQUIRED); - desc.Input("x2") - .ParamType(Option::OPTIONAL) - .DataType({ge::DT_FLOAT16}) - .Format({ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND}) - .ValueDepend(Option::VIRTUAL, DependScope::TILING) - .ValueDepend(Option::OPTIONAL, (DependScope)5) - .ValueDepend(Option::OPTIONAL, DependScope::TILING); - desc.Output("y") - .ParamType(Option::OPTIONAL) - .DataType({ge::DT_FLOAT16}) - .Format({ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND}) - .IgnoreContiguous() - .AutoContiguous() - .OutputShapeDependOnCompute(); - EXPECT_EQ(desc.Input("x1").GetParamName(), "x1"); - EXPECT_EQ(desc.Input("x1").GetParamType(), Option::OPTIONAL); - EXPECT_EQ(desc.Input("x1").GetDataTypes().size(), 1); - EXPECT_EQ(desc.Input("x1").GetFormats().size(), 1); - EXPECT_EQ(desc.Input("x1").GetUnknownShapeFormats().size(), 1); - EXPECT_EQ(desc.Input("x1").GetUnknownShapeFormats()[0], ge::FORMAT_NCHW); - EXPECT_EQ(desc.Input("x1").GetValueDepend(), "required"); - EXPECT_EQ(desc.Input("x1").GetDependScope(), DependScope::ALL); - EXPECT_EQ(desc.Input("x1").GetIgnoreContiguous(), true); - EXPECT_EQ(desc.Input("x1").GetAutoContiguous(), true); - EXPECT_EQ(desc.Input("x2").GetValueDepend(), "optional"); - EXPECT_EQ(desc.Input("x2").GetDependScope(), DependScope::TILING); - EXPECT_EQ(desc.Input("x2").GetIgnoreContiguous(), false); - EXPECT_EQ(desc.Input("x2").GetAutoContiguous(), false); - EXPECT_EQ(desc.Output("y").GetIgnoreContiguous(), true); - EXPECT_EQ(desc.Output("y").GetAutoContiguous(), true); - EXPECT_EQ(desc.GetInputs().size(), 2); - EXPECT_EQ(desc.GetOutputs().size(), 1); - EXPECT_EQ(desc.Input("x1").IsScalar(), false); - EXPECT_EQ(desc.Input("x1").IsScalarList(), false); - EXPECT_EQ(desc.Output("y").IsOutputShapeDependOnCompute(), true); -} -TEST_F(OpDefParamUT, FollowParamTest) { - OpParamTrunk desc; - desc.Input("x1") - .ParamType(Option::OPTIONAL) - .DataType({ge::DT_FLOAT16}) - .Format({ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_NCHW}) - .ValueDepend(Option::REQUIRED) - .IgnoreContiguous() - .AutoContiguous(); - desc.Input("x2") - .ParamType(Option::OPTIONAL) - .Follow("x1") - .ValueDepend(Option::REQUIRED); - desc.Output("y") - .ParamType(Option::OPTIONAL) - .Follow("x2") - .IgnoreContiguous() - .AutoContiguous() - .OutputShapeDependOnCompute(); - desc.FollowDataImpl(); - EXPECT_EQ(desc.Output("y").GetIgnoreContiguous(), true); - EXPECT_EQ(desc.Output("y").GetAutoContiguous(), true); - EXPECT_EQ(desc.GetInputs().size(), 2); - EXPECT_EQ(desc.GetOutputs().size(), 1); - EXPECT_EQ(desc.Output("y").IsOutputShapeDependOnCompute(), true); - EXPECT_EQ(desc.Output("y").GetFollowName(), "x1"); - EXPECT_EQ(desc.Output("y").GetFollowType(), FollowType::ALL); -} -TEST_F(OpDefParamUT, FollowListParamTest) { - OpParamTrunk desc; - desc.Input("x1") - .ParamType(Option::OPTIONAL) - .DataTypeList({ge::DT_FLOAT16}) - .FormatList({ge::FORMAT_ND}); - desc.Input("x2") - .ParamType(Option::OPTIONAL) - .Follow("x1", FollowType::DTYPE) - .FormatList({ge::FORMAT_ND}) - .ValueDepend(Option::REQUIRED); - desc.Output("y") - .ParamType(Option::OPTIONAL) - .Follow("x2", (FollowType)5) - .Follow("x2", FollowType::DTYPE); - desc.Output("x1") - .ParamType(Option::OPTIONAL) - .Follow("x1"); - - desc.FollowDataImpl(); - auto flwMap = desc.GetFollowMap(); - auto shpMap = desc.GetShapeMap(); - auto dtpMap = desc.GetDtypeMap(); - EXPECT_EQ(desc.Output("y").GetFollowType(), FollowType::DTYPE); - EXPECT_EQ(desc.GetParamDef("y", OpDef::PortStat::OUT).GetFollowType(), FollowType::DTYPE); - EXPECT_EQ(desc.Output("x1").GetFollowType(), FollowType::ALL); -} -TEST_F(OpDefParamUT, CommentTest) { - OpParamTrunk desc; - desc.Input("x1") - .ParamType(Option::OPTIONAL) - .DataTypeList({ge::DT_FLOAT16}) - .Comment("comment of param x1") - .FormatList({ge::FORMAT_ND}); - desc.Input("x2") - .ParamType(Option::OPTIONAL) - .DataTypeList({ge::DT_FLOAT16}) - .Comment("") - .Comment("comment of param x2") - .FormatList({ge::FORMAT_ND}); - desc.Output("y") - .ParamType(Option::OPTIONAL) - .DataTypeList({ge::DT_FLOAT16}) - .Comment("comment of param y") - .FormatList({ge::FORMAT_ND}); - - EXPECT_EQ(desc.Input("x1").GetComment(), "comment of param x1"); - EXPECT_EQ(desc.Input("x2").GetComment(), "comment of param x2"); - EXPECT_EQ(desc.Output("y").GetComment(), "comment of param y"); -} -TEST_F(OpDefParamUT, ForBinQueryTest) { - OpParamTrunk desc; - desc.Input("x1") - .ParamType(Option::OPTIONAL) - .DataTypeList({ge::DT_FLOAT16, ge::DT_FLOAT}) - .DataTypeForBinQuery({}) - .DataTypeForBinQuery({ge::DT_FLOAT}) - .DataTypeForBinQuery({ge::DT_FLOAT, ge::DT_FLOAT}) - .FormatList({ge::FORMAT_NC, ge::FORMAT_ND}) - .FormatForBinQuery({}) - .FormatForBinQuery({ge::FORMAT_ND}) - .FormatForBinQuery({ge::FORMAT_ND, ge::FORMAT_ND}); - desc.Input("x2") - .ParamType(Option::OPTIONAL) - .DataType({ge::DT_FLOAT16, ge::DT_FLOAT}) - .DataTypeForBinQuery({ge::DT_FLOAT, ge::DT_FLOAT}) - .Format({ge::FORMAT_NC, ge::FORMAT_ND}) - .FormatForBinQuery({ge::FORMAT_ND, ge::FORMAT_ND}); - desc.Output("y") - .ParamType(Option::OPTIONAL) - .DataType({ge::DT_FLOAT16, ge::DT_FLOAT}) - .DataTypeForBinQuery({ge::DT_FLOAT, ge::DT_FLOAT}) - .Format({ge::FORMAT_NC, ge::FORMAT_ND}) - .FormatForBinQuery({ge::FORMAT_ND, ge::FORMAT_ND}); - EXPECT_EQ(desc.Input("x1").GetDataTypesForBin()[0], ge::DT_FLOAT); - EXPECT_EQ(desc.Input("x2").GetFormatsForBin()[0], ge::FORMAT_ND); - EXPECT_EQ(desc.Output("y").GetFormatsForBin().size(), 2); -} -} // namespace -} // namespace ops diff --git a/tests/ut/register/testcase/op_def_unittest.cc b/tests/ut/register/testcase/op_def_unittest.cc deleted file mode 100644 index f0b792d59ba0ab86f2a0669ff88f1270031015ed..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/op_def_unittest.cc +++ /dev/null @@ -1,305 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include "register/op_def_registry.h" - -namespace ge { - -static ge::graphStatus InferShape4AddAscendC(gert::InferShapeContext *context) { - return GRAPH_SUCCESS; -} - -static ge::graphStatus InferShapeRange4AddAscendC(gert::InferShapeRangeContext *context) { - return GRAPH_SUCCESS; -} - -static ge::graphStatus InferDataType4AddAscendC(gert::InferDataTypeContext *context) { - return GRAPH_SUCCESS; -} - -} // namespace ge - -namespace ops { - -namespace { - -class OpDefUT : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -TEST_F(OpDefUT, Construct) { - OpDef opDef("Test"); - opDef.Input("x1").DataType({ge::DT_FLOAT16}).InitValue({ScalarType::UINT64, 1u}) - .InitValue({{ScalarType::FLOAT16, 1.1}}); - opDef.Input("x2").DataType({ge::DT_FLOAT16}).Scalar().To("x3"); - opDef.Input("x3").DataType({ge::DT_FLOAT}).Version(1).InitValue({{ScalarType::FLOAT32, 1.1}}) - .InitValue({ScalarType::UINT32, 1u}); - opDef.Input("x4").DataType({ge::DT_FLOAT}).ScalarList().To(ge::DT_INT32); - opDef.Output("y").DataType({ge::DT_FLOAT16}).InitValue({ScalarType::INT64, 1}); - opDef.SetInferShape(ge::InferShape4AddAscendC); - opDef.SetInferShapeRange(ge::InferShapeRange4AddAscendC); - opDef.SetInferDataType(ge::InferDataType4AddAscendC); - OpAICoreConfig aicConfig; - aicConfig.Input("x1") - .ParamType(Option::OPTIONAL) - .DataType({ge::DT_FLOAT}) - .Format({ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND}) - .ValueDepend(Option::REQUIRED) - .InitValue(0); - opDef.AICore().AddConfig("ascend310p", aicConfig); - aicConfig.ExtendCfgInfo("rangeLimit.value", "limited"); - EXPECT_EQ(ge::AscendString("Test"), opDef.GetOpType()); - std::vector inputs = opDef.GetMergeInputs(aicConfig); - EXPECT_EQ(inputs.size(), 4); - OpParamDef param = inputs[0]; - auto initValueList = param.GetInitValueList(); - auto originTypes = param.GetOriginDataTypes(); - EXPECT_EQ(param.GetParamName(), ge::AscendString("x1")); - EXPECT_EQ(param.GetParamType(), Option::OPTIONAL); - EXPECT_EQ(param.GetDataTypes()[0], ge::DT_FLOAT); - EXPECT_EQ(param.GetFormats()[0], ge::FORMAT_ND); - EXPECT_EQ(param.GetUnknownShapeFormats()[0], ge::FORMAT_ND); - EXPECT_EQ(param.GetValueDepend(), ge::AscendString("required")); - EXPECT_EQ(param.GetInitValue().value_u64, 0); - EXPECT_EQ(param.GetInitValueType(), InitValueType::INIT_VALUE_UINT64_T); - EXPECT_EQ(inputs[1].IsScalar(), true); - EXPECT_EQ(inputs[3].IsScalarList(), true); - EXPECT_EQ(inputs[1].GetDataTypes()[0], ge::DT_FLOAT); - EXPECT_EQ(inputs[3].GetDataTypes()[0], ge::DT_INT32); - EXPECT_EQ(inputs[3].GetScalarType(), ge::DT_INT32); - std::vector outputs = opDef.GetMergeOutputs(aicConfig); - EXPECT_EQ(outputs.size(), 1); - OpParamDef paramOut = outputs[0]; - EXPECT_EQ(paramOut.GetParamType(), Option::REQUIRED); - EXPECT_EQ(paramOut.GetDataTypes()[0], ge::DT_FLOAT16); - EXPECT_EQ(paramOut.GetFormats()[0], ge::FORMAT_ND); - aicConfig.Input("x1") - .DataType({ge::DT_FLOAT}) - .Format({ge::FORMAT_NCHW}); - inputs = opDef.GetMergeInputs(aicConfig); - EXPECT_EQ(inputs.size(), 4); - param = inputs[0]; - EXPECT_EQ(param.GetDataTypes().size(), 1); - EXPECT_EQ(param.GetFormats().size(), 1); - EXPECT_EQ(inputs[2].GetVersion(), 1); -} - -TEST_F(OpDefUT, ListParamTest) { - OpDef opDef("Test"); - opDef.Input("x1").DataTypeList({ge::DT_FLOAT16, ge::DT_FLOAT}).FormatList({ge::FORMAT_ND}); - opDef.Input("x2").DataTypeList({ge::DT_FLOAT16, ge::DT_FLOAT}).FormatList({ge::FORMAT_ND}); - opDef.Input("x3").DataTypeList({ge::DT_FLOAT16, ge::DT_FLOAT}).Version(1).FormatList({ge::FORMAT_ND}); - opDef.Input("x4").DataType({ge::DT_FLOAT16, ge::DT_FLOAT}).Scalar().FormatList({ge::FORMAT_ND}).To(ge::DT_FLOAT); - opDef.Output("y").DataTypeList({ge::DT_FLOAT16, ge::DT_FLOAT}).FormatList({ge::FORMAT_ND}); - opDef.SetInferShape(ge::InferShape4AddAscendC); - opDef.SetInferShapeRange(ge::InferShapeRange4AddAscendC); - opDef.SetInferDataType(ge::InferDataType4AddAscendC); - - OpAICoreConfig aicConfig; - aicConfig.Input("x1") - .ParamType(Option::OPTIONAL) - .DataTypeList({ge::DT_FLOAT16, ge::DT_FLOAT}) - .FormatList({ge::FORMAT_ND}) - .ValueDepend(Option::REQUIRED) - .InitValue(0); - opDef.AICore().AddConfig("ascend310p", aicConfig); - aicConfig.ExtendCfgInfo("rangeLimit.value", "limited"); - EXPECT_EQ(ge::AscendString("Test"), opDef.GetOpType()); - std::vector inputs = opDef.GetMergeInputs(aicConfig); - for (size_t i = 0; i < inputs.size(); ++i) { - EXPECT_EQ(inputs[i].GetDataTypes().size(), 16); - EXPECT_EQ(inputs[i].GetFormats().size(), 16); - for (size_t j = 0; j < inputs[i].GetFormats().size(); ++j) { - EXPECT_EQ(inputs[i].GetFormats()[j], ge::FORMAT_ND); - } - } - for (uint32_t i = 0; i < 8; ++i) { - EXPECT_EQ(inputs[0].GetDataTypes()[i], ge::DT_FLOAT16); - } - for (uint32_t i = 8; i < 16; ++i) { - EXPECT_EQ(inputs[0].GetDataTypes()[i], ge::DT_FLOAT); - } - std::vector outputs = opDef.GetMergeOutputs(aicConfig); - for (size_t i = 0; i < outputs.size(); ++i) { - EXPECT_EQ(outputs[i].GetDataTypes().size(), 16); - EXPECT_EQ(outputs[i].GetFormats().size(), 16); - for (size_t j = 0; j < outputs[i].GetFormats().size(); ++j) { - EXPECT_EQ(outputs[i].GetFormats()[j], ge::FORMAT_ND); - } - } -} - -TEST_F(OpDefUT, ListParamTest1) { - OpDef opDef("Test"); - opDef.Input("x1").DataTypeList({ge::DT_FLOAT16, ge::DT_FLOAT}); - opDef.Input("x2").DataTypeList({ge::DT_FLOAT16, ge::DT_FLOAT}).ScalarList().To("x1"); - opDef.Input("x3").DataTypeList({ge::DT_FLOAT16, ge::DT_FLOAT}).Version(1); - opDef.Input("x4").DataType({ge::DT_FLOAT16, ge::DT_FLOAT}).Scalar().To(ge::DT_FLOAT); - opDef.Output("y").DataTypeList({ge::DT_FLOAT16, ge::DT_FLOAT}); - opDef.SetInferShape(ge::InferShape4AddAscendC); - opDef.SetInferShapeRange(ge::InferShapeRange4AddAscendC); - opDef.SetInferDataType(ge::InferDataType4AddAscendC); - - OpAICoreConfig aicConfig; - aicConfig.Input("x1") - .ParamType(Option::OPTIONAL) - .DataTypeList({ge::DT_FLOAT16, ge::DT_FLOAT}) - .ValueDepend(Option::REQUIRED) - .InitValue(0); - opDef.AICore().AddConfig("ascend310p", aicConfig); - aicConfig.ExtendCfgInfo("rangeLimit.value", "limited"); - EXPECT_EQ(ge::AscendString("Test"), opDef.GetOpType()); - std::vector inputs = opDef.GetMergeInputs(aicConfig); - for (size_t i = 0; i < inputs.size(); ++i) { - EXPECT_EQ(inputs[i].GetDataTypes().size(), 8); - EXPECT_EQ(inputs[i].GetFormats().size(), 8); - for (size_t j = 0; j < inputs[i].GetFormats().size(); ++j) { - EXPECT_EQ(inputs[i].GetFormats()[j], ge::FORMAT_ND); - } - } - for (uint32_t i = 0; i < 4; ++i) { - EXPECT_EQ(inputs[0].GetDataTypes()[i], ge::DT_FLOAT16); - } - for (uint32_t i = 8; i < 8; ++i) { - EXPECT_EQ(inputs[0].GetDataTypes()[i], ge::DT_FLOAT); - } - std::vector outputs = opDef.GetMergeOutputs(aicConfig); - for (size_t i = 0; i < outputs.size(); ++i) { - EXPECT_EQ(outputs[i].GetDataTypes().size(), 8); - EXPECT_EQ(outputs[i].GetFormats().size(), 8); - for (size_t j = 0; j < outputs[i].GetFormats().size(); ++j) { - EXPECT_EQ(outputs[i].GetFormats()[j], ge::FORMAT_ND); - } - } -} - -TEST_F(OpDefUT, MC2Test) { - OpDef opDef("Test"); - opDef.Input("x1").DataType({ge::DT_FLOAT16}); - opDef.Output("y").DataType({ge::DT_FLOAT16}); - opDef.Attr("group1").AttrType(REQUIRED).String(); - opDef.Attr("group1").AttrType(REQUIRED).String(); - opDef.MC2().HcclGroup("group1"); - std::vector groups = opDef.MC2().GetHcclGroups(); - EXPECT_EQ(groups.size(), 1); - opDef.MC2().HcclGroup({"group1", "group2"}); - groups = opDef.MC2().GetHcclGroups(); - EXPECT_EQ(groups.size(), 2); - opDef.MC2().HcclGroup("group2"); - groups = opDef.MC2().GetHcclGroups(); - EXPECT_EQ(groups.size(), 2); - - EXPECT_EQ(opDef.MC2().GetHcclServerType(), HcclServerType::MAX); - opDef.MC2().HcclServerType(HcclServerType::AICPU, "ascend910b"); - EXPECT_EQ(opDef.MC2().GetHcclServerType("ascend910c"), HcclServerType::MAX); - EXPECT_EQ(opDef.MC2().GetHcclServerType("ascend910b"), HcclServerType::AICPU); - EXPECT_EQ(opDef.MC2().GetHcclServerType(), HcclServerType::AICPU); - opDef.MC2().HcclServerType(HcclServerType::AICORE); - EXPECT_EQ(opDef.MC2().GetHcclServerType("ascend910c"), HcclServerType::AICORE); - EXPECT_EQ(opDef.MC2().GetHcclServerType("ascend910b"), HcclServerType::AICPU); -} - -TEST_F(OpDefUT, CommentTest) { - OpDef opDef("Test"); - opDef.Comment(CommentSection::BRIEF, "") - .Comment(CommentSection::CATEGORY, "ca tgg") - .Comment(CommentSection::CATEGORY, "catg") - .Comment(CommentSection::BRIEF, "Brie\nf cmt") - .Comment(CommentSection::BRIEF, "Brief cmt") - .Comment(CommentSection::CONSTRAINTS, "Constr\naints cmt 1") - .Comment(CommentSection::CONSTRAINTS, "Constraints cmt 2") - .Comment(CommentSection::RESTRICTIONS, "Restrictions cmt") - .Comment(CommentSection::RESTRICTIONS, "Restriction\ns cmt") - .Comment(CommentSection::THIRDPARTYFWKCOMPAT, "ThirdParn\nyFwkCopat cmt") - .Comment(CommentSection::THIRDPARTYFWKCOMPAT, "ThirdPartyFwkCopat cmt") - .Comment(CommentSection::SEE, "See cmt") - .Comment(CommentSection::SEE, "Seen\n cmt") - .Comment(CommentSection::SECTION_MAX, "Seen\n cmt"); - EXPECT_EQ(opDef.GetCateGory(), "catg"); - EXPECT_EQ(opDef.GetBrief().size(), 2); - EXPECT_EQ(opDef.GetConstraints().size(), 2); - EXPECT_EQ(opDef.GetRestrictions().size(), 2); - EXPECT_EQ(opDef.GetSee().size(), 2); - EXPECT_EQ(opDef.GetThirdPartyFwkCopat().size(), 2); - EXPECT_EQ(opDef.GetConstraints().at(1), "Constraints cmt 2"); -} - -TEST_F(OpDefUT, ForBinQueryTest) { - OpDef opDef("Test"); - opDef.Input("x") - .ParamType(Option::REQUIRED) - .DataTypeList({ge::DT_FLOAT16, ge::DT_FLOAT}) - .DataTypeForBinQuery({ge::DT_FLOAT, ge::DT_FLOAT}) - .FormatList({ge::FORMAT_NC, ge::FORMAT_ND}) - .FormatForBinQuery({ge::FORMAT_ND, ge::FORMAT_ND}); - opDef.Input("y") - .ParamType(Option::REQUIRED) - .DataType({ge::DT_FLOAT16, ge::DT_FLOAT}) - .DataTypeForBinQuery({ge::DT_FLOAT, ge::DT_FLOAT}) - .Format({ge::FORMAT_NC, ge::FORMAT_ND}) - .FormatForBinQuery({ge::FORMAT_ND, ge::FORMAT_ND}); - opDef.Output("z") - .ParamType(Option::REQUIRED) - .DataType({ge::DT_FLOAT16, ge::DT_FLOAT}) - .DataTypeForBinQuery({ge::DT_FLOAT, ge::DT_FLOAT}) - .Format({ge::FORMAT_NC, ge::FORMAT_ND}) - .FormatForBinQuery({ge::FORMAT_ND, ge::FORMAT_ND}) - .Follow("x"); - opDef.AICore() - .AddConfig("ascend910"); - - auto aicoreMap = opDef.AICore().GetAICoreConfigs(); - auto aicore = aicoreMap["ascend910"]; - std::vector inputs = opDef.GetMergeInputs(aicore); - std::vector outputs = opDef.GetMergeOutputs(aicore); - EXPECT_EQ(inputs[0].GetDataTypesForBin().size(), 8); - EXPECT_EQ(inputs[0].GetFormatsForBin().size(), 8); -} - -TEST_F(OpDefUT, ParamUnalignTest) { - OpDef opDef("Test"); - opDef.Input("x") - .ParamType(Option::REQUIRED) - .DataType({ge::DT_FLOAT16, ge::DT_FLOAT}) - .Format({ge::FORMAT_NC}); - opDef.Output("z") - .ParamType(Option::REQUIRED) - .Follow("x"); - opDef.AICore() - .AddConfig("ascend910"); - - auto aicoreMap = opDef.AICore().GetAICoreConfigs(); - auto aicore = aicoreMap["ascend910"]; - std::vector inputs = opDef.GetMergeInputs(aicore); - std::vector outputs = opDef.GetMergeOutputs(aicore); - EXPECT_EQ(inputs.size(), 1); - EXPECT_EQ(inputs[0].GetDataTypes().size(), 0); -} - -TEST_F(OpDefUT, TestFormatCheckAndEnableCallBack) { - OpDef opDef("Test"); - opDef.Input("x").DataType({ge::DT_FLOAT16}); - opDef.Output("y").DataType({ge::DT_FLOAT16}); - opDef.FormatMatchMode(ops::FormatCheckOption::DEFAULT); - EXPECT_EQ(opDef.GetFormatMatchMode(), ops::FormatCheckOption::DEFAULT); - opDef.FormatMatchMode(ops::FormatCheckOption::STRICT); - EXPECT_EQ(opDef.GetFormatMatchMode(), ops::FormatCheckOption::STRICT); - EXPECT_EQ(opDef.IsEnableFallBack(), false); - opDef.EnableFallBack(); - EXPECT_EQ(opDef.IsEnableFallBack(), true); -} - -} // namespace -} // namespace ops diff --git a/tests/ut/register/testcase/op_exe_res_unittest.cc b/tests/ut/register/testcase/op_exe_res_unittest.cc deleted file mode 100644 index 72ea32b4bf3992a89cbe470ef747f3f3da1d6c8d..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/op_exe_res_unittest.cc +++ /dev/null @@ -1,375 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include "graph/ge_tensor.h" -#include "graph/utils/attr_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/op_desc.h" -#include "graph/compute_graph.h" -#include "graph_optimizer/fusion_common/graph_pass_util.h" -#include "register/graph_optimizer/graph_fusion/fusion_quant_util.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/tensor_utils.h" -#include "graph/debug/ge_log.h" -#define protected public -#include "inc/external/exe_graph/runtime/exe_res_generation_context.h" -#include "inc/exe_graph/lowering/exe_res_generation_ctx_builder.h" -#undef protected - -using namespace std; -using namespace ge; -using namespace fe; -using namespace gert; -namespace { -class OpExeResTest : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -bool IsShapeEqual(gert::Shape shape1, ge::GeShape shape2) { - if (shape1.GetDimNum() != shape2.GetDimNum()) { - GELOGD("Dim[%zu] vs [%zu].", shape1.GetDimNum(), shape2.GetDimNum()); - return false; - } - for (size_t i = 0; i < shape1.GetDimNum(); ++i) { - GELOGD("Dim val [%ld] vs [%ld].", shape1.GetDim(i), shape2.GetDim(i)); - if (shape1.GetDim(i) != shape2.GetDim(i)) { - return false; - } - } - return true; -} - -TEST_F(OpExeResTest, OpResAPITest) { - ComputeGraphPtr graph = std::make_shared("test"); - OpDescPtr x = std::make_shared("x", "Data"); - OpDescPtr weight = std::make_shared("weight", "Const"); - OpDescPtr atquant_scale = std::make_shared("atquant_scale", "Const"); - OpDescPtr quant_scale = std::make_shared("quant_scale", "Const"); - OpDescPtr quant_offset = std::make_shared("quant_offset", "Const"); - OpDescPtr mm = std::make_shared("mm", "WeightQuantBatchMatmulV2"); - OpDescPtr y = std::make_shared("y", "NetOutput"); - - // add descriptor - ge::GeShape shape1({2,4,9,16}); - GeTensorDesc tensor_desc1(shape1, ge::FORMAT_NCHW, ge::DT_FLOAT16); - tensor_desc1.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc1.SetOriginDataType(ge::DT_FLOAT16); - tensor_desc1.SetOriginShape(shape1); - - GeTensorDesc tensor_desc2(shape1, ge::FORMAT_NCHW, ge::DT_INT8); - tensor_desc2.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc2.SetOriginDataType(ge::DT_INT8); - tensor_desc2.SetOriginShape(shape1); - - ge::GeShape shape2({1, 16}); - GeTensorDesc tensor_desc3(shape2, ge::FORMAT_ND, ge::DT_FLOAT); - tensor_desc3.SetOriginFormat(ge::FORMAT_ND); - tensor_desc3.SetOriginDataType(ge::DT_FLOAT); - tensor_desc3.SetOriginShape(shape2); - - x->AddOutputDesc(tensor_desc1); - weight->AddOutputDesc(tensor_desc2); - atquant_scale->AddOutputDesc(tensor_desc1); - quant_scale->AddOutputDesc(tensor_desc3); - quant_offset->AddOutputDesc(tensor_desc3); - - mm->AddInputDesc(tensor_desc1); - mm->AddInputDesc(tensor_desc2); - mm->AddInputDesc(tensor_desc1); - mm->AddInputDesc(tensor_desc3); - mm->AddInputDesc(tensor_desc3); - mm->AddOutputDesc(tensor_desc2); - y->AddInputDesc(tensor_desc2); - - // create nodes - NodePtr x_node = graph->AddNode(x); - NodePtr weight_node = graph->AddNode(weight); - NodePtr atquant_scale_node = graph->AddNode(atquant_scale); - NodePtr quant_scale_node = graph->AddNode(quant_scale); - NodePtr quant_offset_node = graph->AddNode(quant_offset); - NodePtr mm_node = graph->AddNode(mm); - NodePtr y_node = graph->AddNode(y); - - ge::GraphUtils::AddEdge(x_node->GetOutDataAnchor(0), mm_node->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(weight_node->GetOutDataAnchor(0), mm_node->GetInDataAnchor(1)); - ge::GraphUtils::AddEdge(atquant_scale_node->GetOutDataAnchor(0), mm_node->GetInDataAnchor(2)); - ge::GraphUtils::AddEdge(quant_scale_node->GetOutDataAnchor(0), mm_node->GetInDataAnchor(3)); - - ge::GraphUtils::AddEdge(mm_node->GetOutDataAnchor(0), y_node->GetInDataAnchor(0)); - - mm->SetStreamId(12); - mm->SetId(34); - std::vector ori_work_sizes{22,33,44}; - mm->SetWorkspaceBytes(ori_work_sizes); - - ExeResGenerationCtxBuilder exe_ctx_builder; - auto res_ptr_holder = exe_ctx_builder.CreateOpExeContext(*mm_node); - EXPECT_NE(res_ptr_holder, nullptr); - auto op_exe_res_ctx = reinterpret_cast(res_ptr_holder->context_); - auto node_ptr = op_exe_res_ctx->MutableInputPointer(0); - EXPECT_NE(node_ptr, nullptr); - auto stream_id = op_exe_res_ctx->GetStreamId(); - EXPECT_EQ(stream_id, 12); - auto op_id = op_exe_res_ctx->GetOpId(); - EXPECT_EQ(op_id, 34); - - auto work_sizes = op_exe_res_ctx->GetWorkspaceBytes(); - EXPECT_EQ(work_sizes, ori_work_sizes); - std::vector n_work_sizes{02,03,04}; - op_exe_res_ctx->SetWorkspaceBytes(n_work_sizes); - work_sizes = op_exe_res_ctx->GetWorkspaceBytes(); - EXPECT_EQ(work_sizes, n_work_sizes); - - // test shape - auto in_3_shape = op_exe_res_ctx->GetInputShape(3); - EXPECT_NE(in_3_shape, nullptr); - EXPECT_EQ(IsShapeEqual(in_3_shape->GetOriginShape(), shape2), true); - EXPECT_EQ(IsShapeEqual(in_3_shape->GetStorageShape(), shape2), true); - - auto in_5_shape = op_exe_res_ctx->GetInputShape(4); - EXPECT_EQ(in_5_shape, nullptr); - - auto out_0_shape = op_exe_res_ctx->GetOutputShape(0); - EXPECT_NE(out_0_shape, nullptr); - EXPECT_EQ(IsShapeEqual(out_0_shape->GetOriginShape(), shape1), true); - EXPECT_EQ(IsShapeEqual(out_0_shape->GetStorageShape(), shape1), true); - - graph->SetGraphUnknownFlag(true); - auto mode = op_exe_res_ctx->GetExecuteMode(); - EXPECT_EQ(mode, ExecuteMode::kDynamicExecute); - - ge::GraphUtils::AddEdge(weight_node->GetOutDataAnchor(0), mm_node->GetInDataAnchor(0)); - ge::AscendString ir_name = "__input0"; - auto is_const = op_exe_res_ctx->IsConstInput(ir_name); - EXPECT_EQ(is_const, false); - - ge::AscendString ir_name1 = "__input_invalid"; - is_const = op_exe_res_ctx->IsConstInput(ir_name1); - EXPECT_EQ(is_const, false); - - std::vector stream_info_vec; - StreamInfo si_1; - si_1.name = "tiling"; - si_1.reuse_key = "tiling_key"; - si_1.depend_value_input_indices = {1, 2}; - stream_info_vec.emplace_back(si_1); - si_1.name = "tiling1"; - si_1.reuse_key = "tiling_key1"; - si_1.depend_value_input_indices = {0}; - stream_info_vec.emplace_back(si_1); - auto ret = op_exe_res_ctx->SetAttachedStreamInfos(stream_info_vec); - EXPECT_EQ(ret, ge::GRAPH_SUCCESS); - std::vector stream_info_attrs; - (void)ge::AttrUtils::GetListNamedAttrs(mm_node->GetOpDesc(), ge::ATTR_NAME_ATTACHED_STREAM_INFO_LIST, - stream_info_attrs); - EXPECT_EQ(stream_info_attrs.size(), stream_info_vec.size()); - stream_info_vec.clear(); - stream_info_vec = op_exe_res_ctx->GetAttachedStreamInfos(); - EXPECT_EQ(stream_info_vec.size(), stream_info_attrs.size()); - - std::vector sync_info_vec; - SyncResInfo sync_info; - sync_info.type = SyncResType::SYNC_RES_EVENT; - sync_info.name = "tiling"; - sync_info.reuse_key = "tiling_key"; - sync_info_vec.emplace_back(sync_info); - ret = op_exe_res_ctx->SetSyncResInfos(sync_info_vec); - EXPECT_EQ(ret, ge::GRAPH_SUCCESS); - std::vector sync_info_attrs; - (void)ge::AttrUtils::GetListNamedAttrs(node_ptr->GetOpDesc(), ge::ATTR_NAME_ATTACHED_SYNC_RES_INFO_LIST, - sync_info_attrs); - sync_info_vec.clear(); - sync_info_vec = op_exe_res_ctx->GetSyncResInfos(); - EXPECT_EQ(sync_info_vec.size(), sync_info_attrs.size()); - - std::string test_key = "test_key"; - std::vector test_list = {"test_key"}; - ret = op_exe_res_ctx->SetListStr(test_key, test_list); - EXPECT_EQ(ret, ge::GRAPH_SUCCESS); -} - -TEST_F(OpExeResTest, OpResAPITest_1) { - ComputeGraphPtr graph = std::make_shared("test"); - OpDescPtr x = std::make_shared("x", "Data"); - OpDescPtr weight = std::make_shared("weight", "Const"); - OpDescPtr atquant_scale = std::make_shared("atquant_scale", "Const"); - OpDescPtr quant_scale = std::make_shared("quant_scale", "Const"); - OpDescPtr quant_offset = std::make_shared("quant_offset", "Const"); - OpDescPtr mm = std::make_shared("mm", "WeightQuantBatchMatmulV2"); - OpDescPtr y = std::make_shared("y", "NetOutput"); - - // add descriptor - ge::GeShape shape0({2,4,9,16}); - ge::GeShape ori_shape0({1,16}); - GeTensorDesc tensor_desc1(shape0, ge::FORMAT_NCHW, ge::DT_FLOAT16); - tensor_desc1.SetOriginFormat(ge::FORMAT_ND); - tensor_desc1.SetOriginDataType(ge::DT_FLOAT16); - tensor_desc1.SetOriginShape(ori_shape0); - - ge::GeShape shape1({2,4,9,16}); - GeTensorDesc tensor_desc2(shape1, ge::FORMAT_NCHW, ge::DT_INT8); - tensor_desc2.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc2.SetOriginDataType(ge::DT_INT8); - tensor_desc2.SetOriginShape(shape1); - - ge::GeShape shape2({1, 16}); - GeTensorDesc tensor_desc3(shape2, ge::FORMAT_ND, ge::DT_FLOAT); - tensor_desc3.SetOriginFormat(ge::FORMAT_ND); - tensor_desc3.SetOriginDataType(ge::DT_FLOAT); - tensor_desc3.SetOriginShape(shape2); - - x->AddOutputDesc(tensor_desc1); - weight->AddOutputDesc(tensor_desc2); - atquant_scale->AddOutputDesc(tensor_desc1); - quant_scale->AddOutputDesc(tensor_desc3); - quant_offset->AddOutputDesc(tensor_desc3); - - mm->AddInputDesc(tensor_desc1); - mm->AddInputDesc(tensor_desc2); - mm->AddInputDesc(tensor_desc1); - mm->AddInputDesc(tensor_desc3); - mm->AddOutputDesc(tensor_desc2); - y->AddInputDesc(tensor_desc2); - - // create nodes - NodePtr x_node = graph->AddNode(x); - NodePtr weight_node = graph->AddNode(weight); - NodePtr atquant_scale_node = graph->AddNode(atquant_scale); - NodePtr quant_scale_node = graph->AddNode(quant_scale); - NodePtr quant_offset_node = graph->AddNode(quant_offset); - NodePtr mm_node = graph->AddNode(mm); - NodePtr y_node = graph->AddNode(y); - - ge::GraphUtils::AddEdge(x_node->GetOutDataAnchor(0), mm_node->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(weight_node->GetOutDataAnchor(0), mm_node->GetInDataAnchor(1)); - ge::GraphUtils::AddEdge(atquant_scale_node->GetOutDataAnchor(0), mm_node->GetInDataAnchor(2)); - ge::GraphUtils::AddEdge(quant_scale_node->GetOutDataAnchor(0), mm_node->GetInDataAnchor(3)); - ge::GraphUtils::AddEdge(mm_node->GetOutDataAnchor(0), y_node->GetInDataAnchor(0)); - - ExeResGenerationCtxBuilder exe_ctx_builder; - auto res_ptr_holder = exe_ctx_builder.CreateOpCheckContext(*mm_node); - EXPECT_NE(res_ptr_holder, nullptr); - auto op_check_ctx = reinterpret_cast(res_ptr_holder->context_); - auto node_ptr = op_check_ctx->MutableInputPointer(0); - EXPECT_NE(node_ptr, nullptr); - - // test dtype&format - auto in_0_desc = op_check_ctx->GetInputDesc(0); - EXPECT_EQ(in_0_desc->GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(in_0_desc->GetStorageFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(in_0_desc->GetOriginFormat(), ge::FORMAT_ND); - - auto in_1_desc = op_check_ctx->GetInputDesc(1); - EXPECT_EQ(in_1_desc->GetDataType(), ge::DT_INT8); - EXPECT_EQ(in_1_desc->GetStorageFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(in_1_desc->GetOriginFormat(), ge::FORMAT_NCHW); - - auto in_2_desc = op_check_ctx->GetInputDesc(2); - EXPECT_EQ(in_2_desc->GetDataType(), ge::DT_FLOAT16); - EXPECT_EQ(in_2_desc->GetStorageFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(in_2_desc->GetOriginFormat(), ge::FORMAT_ND); - - auto in_3_desc = op_check_ctx->GetInputDesc(3); - EXPECT_EQ(in_3_desc->GetDataType(), ge::DT_FLOAT); - EXPECT_EQ(in_3_desc->GetStorageFormat(), ge::FORMAT_ND); - EXPECT_EQ(in_3_desc->GetOriginFormat(), ge::FORMAT_ND); - - auto out_0_desc = op_check_ctx->GetOutputDesc(0); - EXPECT_EQ(out_0_desc->GetDataType(), ge::DT_INT8); - EXPECT_EQ(out_0_desc->GetStorageFormat(), ge::FORMAT_NCHW); - EXPECT_EQ(out_0_desc->GetOriginFormat(), ge::FORMAT_NCHW); - - // test shape - auto in_0_shape = op_check_ctx->GetInputShape(0); - EXPECT_NE(in_0_shape, nullptr); - EXPECT_EQ(IsShapeEqual(in_0_shape->GetOriginShape(), ori_shape0), true); - EXPECT_EQ(IsShapeEqual(in_0_shape->GetStorageShape(), shape0), true); - - auto in_1_shape = op_check_ctx->GetInputShape(1); - EXPECT_NE(in_1_shape, nullptr); - EXPECT_EQ(IsShapeEqual(in_1_shape->GetOriginShape(), shape1), true); - EXPECT_EQ(IsShapeEqual(in_1_shape->GetStorageShape(), shape1), true); - - auto in_2_shape = op_check_ctx->GetInputShape(2); - EXPECT_NE(in_2_shape, nullptr); - EXPECT_EQ(IsShapeEqual(in_2_shape->GetOriginShape(), ori_shape0), true); - EXPECT_EQ(IsShapeEqual(in_2_shape->GetStorageShape(), shape0), true); - - auto in_3_shape = op_check_ctx->GetInputShape(3); - EXPECT_NE(in_3_shape, nullptr); - EXPECT_EQ(IsShapeEqual(in_3_shape->GetOriginShape(), shape2), true); - EXPECT_EQ(IsShapeEqual(in_3_shape->GetStorageShape(), shape2), true); - - auto out_0_shape = op_check_ctx->GetOutputShape(0); - EXPECT_NE(out_0_shape, nullptr); - EXPECT_EQ(IsShapeEqual(out_0_shape->GetOriginShape(), shape1), true); - EXPECT_EQ(IsShapeEqual(out_0_shape->GetStorageShape(), shape1), true); -} - -TEST_F(OpExeResTest, SK_Test) { - ComputeGraphPtr graph = std::make_shared("test"); - OpDescPtr mm = std::make_shared("mm", "WeightQuantBatchMatmulV2"); - // add descriptor - ge::GeShape shape0({2,4,9,16}); - ge::GeShape ori_shape0({1,16}); - GeTensorDesc tensor_desc1(shape0, ge::FORMAT_NCHW, ge::DT_FLOAT16); - tensor_desc1.SetOriginFormat(ge::FORMAT_ND); - tensor_desc1.SetOriginDataType(ge::DT_FLOAT16); - tensor_desc1.SetOriginShape(ori_shape0); - - ge::GeShape shape1({2,4,9,16}); - GeTensorDesc tensor_desc2(shape1, ge::FORMAT_NCHW, ge::DT_INT8); - tensor_desc2.SetOriginFormat(ge::FORMAT_NCHW); - tensor_desc2.SetOriginDataType(ge::DT_INT8); - tensor_desc2.SetOriginShape(shape1); - - ge::GeShape shape2({1, 16}); - GeTensorDesc tensor_desc3(shape2, ge::FORMAT_ND, ge::DT_FLOAT); - tensor_desc3.SetOriginFormat(ge::FORMAT_ND); - tensor_desc3.SetOriginDataType(ge::DT_FLOAT); - tensor_desc3.SetOriginShape(shape2); - mm->AddInputDesc(tensor_desc1); - mm->AddInputDesc(tensor_desc2); - mm->AddInputDesc(tensor_desc1); - mm->AddInputDesc(tensor_desc3); - mm->AddOutputDesc(tensor_desc2); - NodePtr mm_node = graph->AddNode(mm); - std::string sk_scope = "_ascendc_super_kernel_scope"; - std::string scope_val = "mla"; - ge::AttrUtils::SetStr(mm, sk_scope, scope_val); - ExeResGenerationCtxBuilder exe_ctx_builder; - auto res_ptr_holder = exe_ctx_builder.CreateOpCheckContext(*mm_node); - EXPECT_NE(res_ptr_holder, nullptr); - auto op_exe_res_ctx = reinterpret_cast(res_ptr_holder->context_); - EXPECT_NE(op_exe_res_ctx, nullptr); - ge::AscendString str_ret; - op_exe_res_ctx->GetStrAttrVal(sk_scope.c_str(), str_ret); - EXPECT_EQ(*(str_ret.GetString()), *(scope_val.c_str())); - std::string scope_val2 = "mlp"; - op_exe_res_ctx->SetStrAttrVal(sk_scope.c_str(), scope_val2.c_str()); - std::string str_tmp; - ge::AttrUtils::GetStr(mm, sk_scope, str_tmp); - EXPECT_EQ(str_tmp, scope_val2); - std::string sk_id = "sub_id"; - ge::AttrUtils::SetInt(mm, sk_id, 1); - int64_t int_val = 0; - op_exe_res_ctx->GetIntAttrVal(sk_id.c_str(), int_val); - EXPECT_EQ(int_val, 1); - op_exe_res_ctx->SetIntAttrVal(sk_id.c_str(), 3); - int64_t int_ret = 0; - ge::AttrUtils::GetInt(mm, sk_id, int_ret); - EXPECT_EQ(int_ret, 3); -} -} // namespace diff --git a/tests/ut/register/testcase/op_impl_registry_holder_manager_unittest.cc b/tests/ut/register/testcase/op_impl_registry_holder_manager_unittest.cc index 985d829376ea6b2004be0463bf9200b2efd3d762..c56ca200646b78eb8aadd7e262c57a2b2589a2fb 100644 --- a/tests/ut/register/testcase/op_impl_registry_holder_manager_unittest.cc +++ b/tests/ut/register/testcase/op_impl_registry_holder_manager_unittest.cc @@ -16,6 +16,11 @@ #include "tests/depends/mmpa/src/mmpa_stub.h" #include #include +#include "graph/operator_factory_impl.h" + +namespace ge { +void ge::OperatorFactoryImpl::ReleaseRegInfo() {} +} namespace gert_test { namespace { diff --git a/tests/ut/register/testcase/op_impl_registry_unittest.cc b/tests/ut/register/testcase/op_impl_registry_unittest.cc index cba264da190c86bb429ac1cedd947ff4b4f206f3..f87fb946c5f02228b9de621050af118a354aebdc 100644 --- a/tests/ut/register/testcase/op_impl_registry_unittest.cc +++ b/tests/ut/register/testcase/op_impl_registry_unittest.cc @@ -13,7 +13,6 @@ #include "graph/any_value.h" #include "register/op_impl_registry_api.h" #include "base/registry/op_impl_register_v2_impl.h" -#include "exe_graph/runtime/infer_symbol_shape_context.h" namespace gert_test { namespace { diff --git a/tests/ut/register/testcase/op_kernel_registry_unittest.cc b/tests/ut/register/testcase/op_kernel_registry_unittest.cc deleted file mode 100644 index ecb423508db15d5250c5bfb6dd4f8631849110df..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/op_kernel_registry_unittest.cc +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include "register/op_kernel_registry.h" -#include "register/host_cpu_context.h" -#include "graph/debug/ge_log.h" - -namespace ge { -class UtestOpKernelRegistry : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -TEST_F(UtestOpKernelRegistry, IsRegisteredTest) { - OpKernelRegistry &op_registry = OpKernelRegistry::GetInstance(); - std::string op_type = "registry"; - bool ret = op_registry.IsRegistered(op_type); - EXPECT_EQ(ret, false); -} - -TEST_F(UtestOpKernelRegistry, HostCpuOpTest) { - OpKernelRegistry op_registry; - std::string op_type = "registry"; - OpKernelRegistry::CreateFn fn = nullptr; - op_registry.RegisterHostCpuOp(op_type, fn); - std::unique_ptr host_cpu = op_registry.CreateHostCpuOp(op_type); - EXPECT_EQ(host_cpu, nullptr); -} - -TEST_F(UtestOpKernelRegistry, HostCpuOpRegistrarTest) { - OpKernelRegistry op_registry; - HostCpuOpRegistrar host_strar(nullptr, []()->::ge::HostCpuOp* {return nullptr;}); - std::string op_type = "registry"; - std::unique_ptr host_cpu = op_registry.CreateHostCpuOp(op_type); - EXPECT_EQ(host_cpu, nullptr); -} - -} // namespace ge diff --git a/tests/ut/register/testcase/op_lib_register_unittest.cc b/tests/ut/register/testcase/op_lib_register_unittest.cc deleted file mode 100644 index f911a9fb0fb729d44feb27872316bdd148b8d5fa..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/op_lib_register_unittest.cc +++ /dev/null @@ -1,180 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "register/op_lib_register_impl.h" -#include "external/register/register_base.h" - -#include -#include "graph/debug/ge_log.h" -#include "tests/depends/mmpa/src/mmpa_stub.h" - -namespace ge { -namespace { -size_t g_lib_register_cnt = 1; -std::vector init_func_vec; -const std::string custom_op_name = "libcust_opapi.so"; -const std::string tmp_test_lib_dir = "./test_op_lib_register/"; - -uint32_t FakeFunc(ge::AscendString& path) { - init_func_vec.emplace_back(g_lib_register_cnt); - path = AscendString(to_string(g_lib_register_cnt).c_str()); - ++g_lib_register_cnt; - return 0; -} - -void ClearCache() { - g_lib_register_cnt = 1; - init_func_vec.clear(); - OpLibRegistry::GetInstance().vendor_funcs_.clear(); - OpLibRegistry::GetInstance().vendor_names_set_.clear(); - OpLibRegistry::GetInstance().op_lib_paths_ = ""; - OpLibRegistry::GetInstance().ClearHandles(); - OpLibRegistry::GetInstance().is_processed_ = false; -} - -void CreateVendorSoPath(const std::string &vendor_dir) { - system(("mkdir -p " + vendor_dir).c_str()); - system(("touch " + vendor_dir + custom_op_name).c_str()); -} - -void CreateVendorOldRunbagDir(const std::string &vendor_dir) { - system(("mkdir -p " + vendor_dir + "/op_proto/").c_str()); -} - -void DelVendorSoDir(const std::string &vendor_dir) { - system(("rm -rf " + vendor_dir).c_str()); -} - -class MockMmpaForOpLib : public ge::MmpaStubApi { - public: - void *DlOpen(const char *fileName, int32_t mode) override { - auto tmp_register = ge::OpLibRegister(fileName).RegOpLibInit(FakeFunc); - return (void *) fileName; - } - int32_t DlClose(void *handle) override { - return 0L; - } -}; -} - -class OpLibRegisterUT : public testing::Test { - protected: - void SetUp() { - system("pwd"); - system(("mkdir -p " + tmp_test_lib_dir).c_str()); - } - - void TearDown() { - system(("rm -rf " + tmp_test_lib_dir).c_str()); - unsetenv("ASCEND_CUSTOM_OPP_PATH"); - ClearCache(); - ge::MmpaStub::GetInstance().Reset(); - } -}; - -TEST_F(OpLibRegisterUT, register_construct) { - OpLibRegister tmp1("vendor1"); - auto tmp2 = tmp1; - auto tmp3 = OpLibRegister(std::move(tmp1)); - EXPECT_NE(tmp1.impl_.get(), nullptr); - EXPECT_EQ(tmp2.impl_.get(), nullptr); - EXPECT_EQ(tmp3.impl_.get(), nullptr); -} - -TEST_F(OpLibRegisterUT, register_same_vendor) { - ClearCache(); - EXPECT_EQ(OpLibRegistry::GetInstance().vendor_funcs_.size(), 0); - EXPECT_EQ(OpLibRegistry::GetInstance().vendor_names_set_.size(), 0); - - REGISTER_OP_LIB(vendor_1).RegOpLibInit(FakeFunc); - EXPECT_EQ(OpLibRegistry::GetInstance().vendor_funcs_.size(), 1); - EXPECT_EQ(OpLibRegistry::GetInstance().vendor_names_set_.size(), 1); - - REGISTER_OP_LIB(vendor_1).RegOpLibInit(FakeFunc); - EXPECT_EQ(OpLibRegistry::GetInstance().vendor_funcs_.size(), 1); - EXPECT_EQ(OpLibRegistry::GetInstance().vendor_names_set_.size(), 1); -} - -TEST_F(OpLibRegisterUT, register_direct_link) { - ClearCache(); - REGISTER_OP_LIB(vendor_1).RegOpLibInit(FakeFunc); - REGISTER_OP_LIB(vendor_2).RegOpLibInit(FakeFunc); - EXPECT_EQ(OpLibRegistry::GetInstance().vendor_funcs_.size(), 2); - EXPECT_EQ(OpLibRegistry::GetInstance().vendor_names_set_.size(), 2); - unsetenv("ASCEND_CUSTOM_OPP_PATH"); - auto ret = OpLibRegistry::GetInstance().PreProcessForCustomOp(); - EXPECT_EQ(ret, GRAPH_SUCCESS); - std::vector expect_vec{1, 2}; - EXPECT_EQ(init_func_vec, expect_vec); - std::string custom_path = aclGetCustomOpLibPath(); - EXPECT_EQ("1:2", custom_path); -} - -TEST_F(OpLibRegisterUT, register_direct_link_and_env_priority) { - ClearCache(); - REGISTER_OP_LIB(vendor_1).RegOpLibInit(FakeFunc); - EXPECT_EQ(OpLibRegistry::GetInstance().vendor_funcs_.size(), 1); - - ge::MmpaStub::GetInstance().SetImpl(std::make_shared()); - std::string vendor_2_dir = tmp_test_lib_dir + "/vendor_2/"; - CreateVendorSoPath(vendor_2_dir); - std::string vendor_3_dir = tmp_test_lib_dir + "/vendor_3/"; - CreateVendorSoPath(vendor_3_dir); - std::string env_val = vendor_2_dir + ":" + vendor_3_dir; - mmSetEnv("ASCEND_CUSTOM_OPP_PATH", env_val.c_str(), 1); - - auto ret = OpLibRegistry::GetInstance().PreProcessForCustomOp(); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(OpLibRegistry::GetInstance().vendor_funcs_.size(), 3); - std::vector expect_vec{1, 2, 3}; - EXPECT_EQ(init_func_vec, expect_vec); - EXPECT_EQ(OpLibRegistry::GetInstance().handles_.size(), 2); - EXPECT_EQ(OpLibRegistry::GetInstance().is_processed_, true); - ret = OpLibRegistry::GetInstance().PreProcessForCustomOp(); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - DelVendorSoDir(vendor_2_dir); - DelVendorSoDir(vendor_3_dir); -} - -TEST_F(OpLibRegisterUT, register_coexistence_direct_link) { - ClearCache(); - REGISTER_OP_LIB(vendor_1).RegOpLibInit(FakeFunc); - EXPECT_EQ(OpLibRegistry::GetInstance().vendor_funcs_.size(), 1); - std::string old_dir = tmp_test_lib_dir + "/vendor_2/"; - mmSetEnv("ASCEND_CUSTOM_OPP_PATH", old_dir.c_str(), 1); - CreateVendorOldRunbagDir(old_dir); - auto ret = OpLibRegistry::GetInstance().PreProcessForCustomOp(); - EXPECT_EQ(ret, SUCCESS); - std::string custom_path = aclGetCustomOpLibPath(); - std::string expect_path = "1:" + old_dir; - EXPECT_EQ(expect_path, custom_path); - DelVendorSoDir(old_dir); -} - -TEST_F(OpLibRegisterUT, register_coexistence_env) { - ClearCache(); - std::string old_dir = tmp_test_lib_dir + "/vendor_old/"; - CreateVendorOldRunbagDir(old_dir); - - ge::MmpaStub::GetInstance().SetImpl(std::make_shared()); - std::string vendor_2_dir = tmp_test_lib_dir + "/vendor_new/"; - CreateVendorSoPath(vendor_2_dir); - std::string env_val = old_dir + ":" + vendor_2_dir + ":"; - mmSetEnv("ASCEND_CUSTOM_OPP_PATH", env_val.c_str(), 1); - - auto ret = OpLibRegistry::GetInstance().PreProcessForCustomOp(); - EXPECT_EQ(ret, SUCCESS); - std::string custom_path = aclGetCustomOpLibPath(); - std::string expect_path = "1:" + env_val; - EXPECT_EQ(expect_path, custom_path); - DelVendorSoDir(old_dir); - DelVendorSoDir(vendor_2_dir); -} -} // namespace ge \ No newline at end of file diff --git a/tests/ut/register/testcase/op_tiling_attr_utils_test.cc b/tests/ut/register/testcase/op_tiling_attr_utils_test.cc deleted file mode 100644 index e5a184ef21354761881d7049998208f732b47441..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/op_tiling_attr_utils_test.cc +++ /dev/null @@ -1,401 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "register/op_tiling_attr_utils.h" - -using namespace std; -using namespace ge; - -namespace optiling { - -const uint16_t kFp16ExpBias = 15; -const uint32_t kFp32ExpBias = 127; -const uint16_t kFp16ManLen = 10; -const uint32_t kFp32ManLen = 23; -const uint32_t kFp32SignIndex = 31; -const uint16_t kFp16ManMask = 0x03FF; -const uint16_t kFp16ManHideBit = 0x0400; -const uint16_t kFp16MaxExp = 0x001F; -const uint32_t kFp32MaxMan = 0x7FFFFF; - -float Uint16ToFloat(const uint16_t &intVal) { - float ret; - - uint16_t hfSign = (intVal >> 15) & 1; - int16_t hfExp = (intVal >> kFp16ManLen) & kFp16MaxExp; - uint16_t hfMan = ((intVal >> 0) & 0x3FF) | ((((intVal >> 10) & 0x1F) > 0 ? 1 : 0) * 0x400); - if (hfExp == 0) { - hfExp = 1; - } - - while (hfMan && !(hfMan & kFp16ManHideBit)) { - hfMan <<= 1; - hfExp--; - } - - uint32_t sRet, eRet, mRet, fVal; - - sRet = hfSign; - if (!hfMan) { - eRet = 0; - mRet = 0; - } else { - eRet = hfExp - kFp16ExpBias + kFp32ExpBias; - mRet = hfMan & kFp16ManMask; - mRet = mRet << (kFp32ManLen - kFp16ManLen); - } - fVal = ((sRet) << kFp32SignIndex) | ((eRet) << kFp32ManLen) | ((mRet) & kFp32MaxMan); - ret = *(reinterpret_cast(&fVal)); - - return ret; -} - -class OpTilingAttrUtilsTest : public testing::Test { -protected: - void SetUp() {} - - void TearDown() {} -}; - -TEST_F(OpTilingAttrUtilsTest, get_bool_attr_success) { - Operator op("relu", "Relu"); - op.SetAttr("attr_bool", true); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_bool", "bool", attr_data_ptr); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(attr_data_ptr->GetSize(), 1); -} - -TEST_F(OpTilingAttrUtilsTest, get_bool_attr_fail_1) { - Operator op("relu", "Relu"); - op.SetAttr("attr_bool", true); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, nullptr, "bool", attr_data_ptr); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(OpTilingAttrUtilsTest, get_bool_attr_fail_2) { - Operator op("relu", "Relu"); - op.SetAttr("attr_bool", true); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_bool", "booooooool", attr_data_ptr); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(OpTilingAttrUtilsTest, get_bool_attr_fail_3) { - Operator op("relu", "Relu"); - op.SetAttr("attr_bool", true); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_bool", "booooooool", attr_data_ptr); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(OpTilingAttrUtilsTest, get_str_attr_success_1) { - Operator op("relu", "Relu"); - op.SetAttr("attr_str", "12345"); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_str", "str", attr_data_ptr); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(attr_data_ptr->GetSize(), 5); -} - -TEST_F(OpTilingAttrUtilsTest, get_str_attr_success_2) { - Operator op("relu", "Relu"); - op.SetAttr("attr_str", "12345"); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_str", "str", attr_data_ptr, "string"); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(attr_data_ptr->GetSize(), 5); -} - -TEST_F(OpTilingAttrUtilsTest, get_str_attr_fail_1) { - Operator op("relu", "Relu"); - op.SetAttr("attr_str", "12345"); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_str", "str", attr_data_ptr, "stttttr"); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(OpTilingAttrUtilsTest, get_str_attr_fail_2) { - Operator op("relu", "Relu"); - op.SetAttr("attr_str", "12345"); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_str", "str", attr_data_ptr, "int"); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(OpTilingAttrUtilsTest, get_str_attr_fail_3) { - Operator op("relu", "Relu"); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_str", "str", attr_data_ptr, "int"); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(OpTilingAttrUtilsTest, get_int_attr_success_1) { - Operator op("relu", "Relu"); - int32_t attr = -123; - op.SetAttr("attr_int", attr); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_int", "int", attr_data_ptr); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(attr_data_ptr->GetSize(), 4); - const int32_t *data = (const int32_t *)attr_data_ptr->GetData(); - EXPECT_EQ(*data, attr); -} - -TEST_F(OpTilingAttrUtilsTest, get_int_attr_success_2) { - Operator op("relu", "Relu"); - int32_t attr = -123; - op.SetAttr("attr_int", attr); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_int", "int32", attr_data_ptr); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(attr_data_ptr->GetSize(), 4); - const int32_t *data = (const int32_t *)attr_data_ptr->GetData(); - EXPECT_EQ(*data, attr); -} - -TEST_F(OpTilingAttrUtilsTest, get_int_attr_to_uint32_success_1) { - Operator op("relu", "Relu"); - int32_t attr = 123; - op.SetAttr("attr_int", attr); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_int", "int", attr_data_ptr, "uint32"); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(attr_data_ptr->GetSize(), 4); - const uint32_t *data = (const uint32_t *)attr_data_ptr->GetData(); - EXPECT_EQ(*data, attr); -} - -TEST_F(OpTilingAttrUtilsTest, get_int_attr_to_uint32_success_2) { - Operator op("relu", "Relu"); - int32_t attr = 123; - op.SetAttr("attr_int", attr); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_int", "int32", attr_data_ptr, "uint"); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(attr_data_ptr->GetSize(), 4); - const uint32_t *data = (const uint32_t *)attr_data_ptr->GetData(); - EXPECT_EQ(*data, attr); -} - -TEST_F(OpTilingAttrUtilsTest, get_list_int_attr_success_1) { - Operator op("relu", "Relu"); - vector attr = {-123, 456}; - op.SetAttr("attr_list_int", attr); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_list_int", "list_int", attr_data_ptr); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(attr_data_ptr->GetSize(), 8); - const int32_t *data = (const int32_t *)attr_data_ptr->GetData(); - EXPECT_EQ(*data, -123); - EXPECT_EQ(*(data+1), 456); -} - -TEST_F(OpTilingAttrUtilsTest, get_list_int_attr_success_2) { - Operator op("relu", "Relu"); - vector attr = {-123, 456}; - op.SetAttr("attr_list_int", attr); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_list_int", "list_int32", attr_data_ptr); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(attr_data_ptr->GetSize(), 8); - const int32_t *data = (const int32_t *)attr_data_ptr->GetData(); - EXPECT_EQ(*data, -123); - EXPECT_EQ(*(data+1), 456); -} - -TEST_F(OpTilingAttrUtilsTest, get_list_int_attr_to_list_uint32_success_1) { - Operator op("relu", "Relu"); - vector attr = {-123, 456}; - op.SetAttr("attr_list_int", attr); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_list_int", "list_int", attr_data_ptr, "list_uint32"); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(attr_data_ptr->GetSize(), 8); - const int32_t *data = (const int32_t *)attr_data_ptr->GetData(); - EXPECT_EQ(*data, -123); - EXPECT_EQ(*(data+1), 456); -} - -TEST_F(OpTilingAttrUtilsTest, get_list_int_attr_to_list_uint32_success_2) { - Operator op("relu", "Relu"); - vector attr = {-123, 456}; - op.SetAttr("attr_list_int", attr); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_list_int", "list_int32", attr_data_ptr, "list_uint"); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(attr_data_ptr->GetSize(), 8); - const int32_t *data = (const int32_t *)attr_data_ptr->GetData(); - EXPECT_EQ(*data, -123); - EXPECT_EQ(*(data+1), 456); -} - -TEST_F(OpTilingAttrUtilsTest, get_float_attr_success_1) { - Operator op("relu", "Relu"); - float attr = 1.23; - op.SetAttr("attr_float", attr); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_float", "float", attr_data_ptr); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(attr_data_ptr->GetSize(), 4); - const float *data = (const float *)attr_data_ptr->GetData(); - EXPECT_EQ(*data, attr); -} - -TEST_F(OpTilingAttrUtilsTest, get_float_attr_success_2) { - Operator op("relu", "Relu"); - float attr = 1.23; - op.SetAttr("attr_float", attr); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_float", "float32", attr_data_ptr); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(attr_data_ptr->GetSize(), 4); - const float *data = (const float *)attr_data_ptr->GetData(); - EXPECT_EQ(*data, attr); -} - -TEST_F(OpTilingAttrUtilsTest, get_list_float_attr_success_1) { - Operator op("relu", "Relu"); - vector attr = {1.23, 2.34}; - op.SetAttr("attr_list_float", attr); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_list_float", "list_float", attr_data_ptr); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(attr_data_ptr->GetSize(), 8); - const float *data = (const float *)attr_data_ptr->GetData(); - cout << *data << endl; - cout << *(data+1) << endl; -} - -TEST_F(OpTilingAttrUtilsTest, get_list_float_attr_success_2) { - Operator op("relu", "Relu"); - vector attr = {1.23, 2.34}; - op.SetAttr("attr_list_float", attr); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_list_float", "list_float32", attr_data_ptr); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(attr_data_ptr->GetSize(), 8); - const float *data = (const float *)attr_data_ptr->GetData(); - cout << *data << endl; - cout << *(data+1) << endl; -} - -TEST_F(OpTilingAttrUtilsTest, get_float_attr_to_float16_success_1) { - Operator op("relu", "Relu"); - float attr = 1.23; - op.SetAttr("attr_float", attr); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_float", "float", attr_data_ptr, "float16"); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(attr_data_ptr->GetSize(), 2); - const uint16_t *data = (const uint16_t *)attr_data_ptr->GetData(); - cout << Uint16ToFloat(*data) << endl; -} - -TEST_F(OpTilingAttrUtilsTest, get_float_attr_to_float16_success_2) { - Operator op("relu", "Relu"); - float attr = 1.23; - op.SetAttr("attr_float", attr); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_float", "float32", attr_data_ptr, "float16"); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(attr_data_ptr->GetSize(), 2); - const uint16_t *data = (const uint16_t *)attr_data_ptr->GetData(); - cout << Uint16ToFloat(*data) << endl; -} - -TEST_F(OpTilingAttrUtilsTest, get_list_float_attr_to_float16_success_1) { - Operator op("relu", "Relu"); - vector attr = {1.23, -2.34}; - op.SetAttr("attr_list_float", attr); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_list_float", "list_float", attr_data_ptr, "list_float16"); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(attr_data_ptr->GetSize(), 4); - const uint16_t *data = (const uint16_t *)attr_data_ptr->GetData(); - cout << Uint16ToFloat(*data) << endl; - cout << Uint16ToFloat(*(data+1)) << endl; -} - -TEST_F(OpTilingAttrUtilsTest, get_list_float_attr_to_float16_success_2) { - Operator op("relu", "Relu"); - vector attr = {1.23, -2.34}; - op.SetAttr("attr_list_float", attr); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_list_float", "list_float32", attr_data_ptr, "list_float16"); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(attr_data_ptr->GetSize(), 4); - const uint16_t *data = (const uint16_t *)attr_data_ptr->GetData(); - cout << Uint16ToFloat(*data) << endl; - cout << Uint16ToFloat(*(data+1)) << endl; -} - -TEST_F(OpTilingAttrUtilsTest, get_float_attr_to_int32_success_1) { - Operator op("relu", "Relu"); - float attr = 3.56; - op.SetAttr("attr_float", attr); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_float", "float", attr_data_ptr, "int32"); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(attr_data_ptr->GetSize(), 4); - const int32_t *data = (const int32_t *)attr_data_ptr->GetData(); - EXPECT_EQ(*data, 3); -} - -TEST_F(OpTilingAttrUtilsTest, get_float_attr_to_int32_success_2) { - Operator op("relu", "Relu"); - float attr = -3.56; - op.SetAttr("attr_float", attr); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_float", "float", attr_data_ptr, "int"); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(attr_data_ptr->GetSize(), 4); - const int32_t *data = (const int32_t *)attr_data_ptr->GetData(); - EXPECT_EQ(*data, -3); -} - -TEST_F(OpTilingAttrUtilsTest, get_float_attr_to_int32_fail_1) { - Operator op("relu", "Relu"); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_float", "float", attr_data_ptr, "int"); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(OpTilingAttrUtilsTest, get_list_float_attr_to_int32_success_1) { - Operator op("relu", "Relu"); - vector attr = {1.63, -2.34}; - op.SetAttr("attr_list_float", attr); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_list_float", "list_float32", attr_data_ptr, "list_int"); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(attr_data_ptr->GetSize(), 8); - const int32_t *data = (const int32_t *)attr_data_ptr->GetData(); - EXPECT_EQ(*data, 1); - EXPECT_EQ(*(data+1), -2); -} - -TEST_F(OpTilingAttrUtilsTest, get_list_float_attr_to_int32_fail_1) { - Operator op("relu", "Relu"); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_list_float", "list_float32", attr_data_ptr, "list_int"); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(OpTilingAttrUtilsTest, get_list_float_attr_to_int32_fail_2) { - Operator op("relu", "Relu"); - vector attr; - op.SetAttr("attr_list_float", attr); - AttrDataPtr attr_data_ptr = nullptr; - graphStatus ret = GetOperatorAttrValue(op, "attr_list_float", "list_float32", attr_data_ptr, "list_int"); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -} diff --git a/tests/ut/register/testcase/ops_kernel_builder_registry_unittest.cc b/tests/ut/register/testcase/ops_kernel_builder_registry_unittest.cc deleted file mode 100644 index 47f913f415bc9fbb5b8dbf38371319a800579543..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/ops_kernel_builder_registry_unittest.cc +++ /dev/null @@ -1,97 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include "register/ops_kernel_builder_registry.h" -#include "graph/debug/ge_log.h" - -namespace ge { -class UtestOpsKernelBuilderRegistry : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -TEST_F(UtestOpsKernelBuilderRegistry, GetAllKernelBuildersTest) { - ge::OpsKernelBuilderRegistry ops_registry; - OpsKernelBuilderPtr opsptr = std::shared_ptr(); - ops_registry.kernel_builders_.insert(pair("ops1", opsptr)); - std::map ops_map; - ops_map = ops_registry.GetAll(); - EXPECT_EQ(ops_map.size(), 1); - EXPECT_EQ(ops_map["ops1"], opsptr); -} - -TEST_F(UtestOpsKernelBuilderRegistry, RegisterTest) { - ge::OpsKernelBuilderRegistry ops_registry; - std::string name = "register1"; - OpsKernelBuilderPtr instance = std::shared_ptr(); - ops_registry.Register(name, instance); - std::map kernel_builders_; - kernel_builders_ = ops_registry.GetAll(); - EXPECT_EQ(kernel_builders_.size(), 1); - EXPECT_EQ(kernel_builders_["register1"], instance); -} - -TEST_F(UtestOpsKernelBuilderRegistry, UnregisterTest) { - ge::OpsKernelBuilderRegistry ops_registry; - std::string name1 = "register1"; - OpsKernelBuilderPtr opsPtr1 = std::shared_ptr(); - ops_registry.Register(name1, opsPtr1); - - std::string name2 = "register2"; - OpsKernelBuilderPtr opsPtr2 = std::shared_ptr(); - ops_registry.Register(name2, opsPtr2); - - std::map ops_map; - ops_map = ops_registry.GetAll(); - EXPECT_EQ(ops_map.size(), 2); - EXPECT_EQ(ops_map["register1"], opsPtr1); - EXPECT_EQ(ops_map["register2"], opsPtr2); - - ops_registry.Unregister("register1"); - ops_map = ops_registry.GetAll(); - EXPECT_EQ(ops_map.size(), 1); - EXPECT_EQ(ops_map.count("register1"), 0); -} - -TEST_F(UtestOpsKernelBuilderRegistry, UnregisterAllTest) { - ge::OpsKernelBuilderRegistry ops_registry; - std::string name1 = "register1"; - OpsKernelBuilderPtr opsPtr1 = std::shared_ptr(); - ops_registry.Register(name1, opsPtr1); - - std::string name2 = "register2"; - OpsKernelBuilderPtr opsPtr2 = std::shared_ptr(); - ops_registry.Register(name2, opsPtr2); - - std::map ops_map; - ops_map = ops_registry.GetAll(); - EXPECT_EQ(ops_map.size(), 2); - EXPECT_EQ(ops_map["register1"], opsPtr1); - EXPECT_EQ(ops_map["register2"], opsPtr2); - - ops_registry.UnregisterAll(); - ops_map = ops_registry.GetAll(); - EXPECT_EQ(ops_map.size(), 0); - EXPECT_EQ(ops_map.count("register1"), 0); - EXPECT_EQ(ops_map.count("register2"), 0); -} - -TEST_F(UtestOpsKernelBuilderRegistry, OpsKernelBuilderRegistrarTest) { - std::string name = "register"; - ge::OpsKernelBuilderRegistrar::CreateFn fn = nullptr; - OpsKernelBuilderRegistrar ops_rar(name, fn); - std::map ops_map; - ops_map = OpsKernelBuilderRegistry::GetInstance().GetAll(); - EXPECT_EQ(ops_map.size(), 1); -} - -} // namespace ge diff --git a/tests/ut/register/testcase/optimization_option_registry_unittest.cc b/tests/ut/register/testcase/optimization_option_registry_unittest.cc deleted file mode 100644 index 7bbb153dc32be3e5522f9ef665ac6cd9b98d6b7d..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/optimization_option_registry_unittest.cc +++ /dev/null @@ -1,217 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include "register/optimization_option_registry.h" -#include "ge_common/ge_api_types.h" -#include "graph/debug/ge_log.h" -#include "common/ge_common/string_util.h" - -namespace { -bool ThresholdCheckerFunc(const std::string &opt_value) { - std::string tmp_opt_value = opt_value; - std::stringstream ss(ge::StringUtils::Trim(tmp_opt_value)); - int64_t opt_convert; - ss >> opt_convert; - if (ss.fail() || !ss.eof()) { - return false; - } - return true; -} -} // namespace -namespace ge { -class OptimizationOptRegistryUT : public testing::Test { - protected: - void SetUp() override { - dlog_setlevel(0, 0, 0); - } - void TearDown() override { - dlog_setlevel(0, 3, 0); - } - OptionRegistry &opt_registry = OptionRegistry::GetInstance(); - const PassOptionRegistry &pass_opt_registry = PassOptionRegistry::GetInstance(); - const char_t *const kInvalidOptionName = "ge.oo.invalid_option_name"; - const char_t *const kUtOptionName = "ge.ut_test_option"; - - std::unordered_map GetRegisteredOptionsByLevel(OoLevel level) { - std::unordered_map options; - for (const auto &opt_info : opt_registry.GetRegisteredOptTable()) { - if (OoInfoUtils::IsBitSet(opt_info.second.levels, static_cast(level))) { - const auto value_str = OoInfoUtils::GetDefaultValue(opt_info.second, level); - options.emplace(opt_info.first, opt_info.second); - } - } - return options; - } -}; - -TEST_F(OptimizationOptRegistryUT, RegisterPassOption_Fail_InvalidParams) { - std::vector option_names; - // pass name is empty - REG_PASS_OPTION("").LEVELS(OoLevel::kO1); - EXPECT_EQ(pass_opt_registry.FindOptionNamesByPassName("", option_names), GRAPH_FAILED); - // option levels is invalid - REG_PASS_OPTION("InvalidTestPass").LEVELS(OoLevel::kEnd); - EXPECT_EQ(pass_opt_registry.FindOptionNamesByPassName("InvalidTestPass", option_names), GRAPH_FAILED); - // option name is invalid - REG_PASS_OPTION("InvalidTestPass").LEVELS(OoLevel::kO1).SWITCH_OPT(""); - EXPECT_EQ(pass_opt_registry.FindOptionNamesByPassName("InvalidTestPass", option_names), GRAPH_FAILED); - // hierarchy is invalid - REG_PASS_OPTION("InvalidTestPass") - .LEVELS(OoLevel::kO1) - .SWITCH_OPT("test_option_name1", OoHierarchy::kH1) - .SWITCH_OPT("test_option_name", OoHierarchy::kEnd); - EXPECT_EQ(pass_opt_registry.FindOptionNamesByPassName("InvalidTestPass", option_names), GRAPH_FAILED); - // invalid switch options - REG_PASS_OPTION("UtFakePass").SWITCH_OPT("ut.fake_option_secondary", OoHierarchy::kH2); - EXPECT_EQ(pass_opt_registry.FindOptionNamesByPassName("InvalidTestPass", option_names), GRAPH_FAILED); - REG_PASS_OPTION("UtFakePass") - .SWITCH_OPT("ut.fake_option", OoHierarchy::kH2) - .SWITCH_OPT("ut.fake_option_secondary", OoHierarchy::kH2); - EXPECT_EQ(pass_opt_registry.FindOptionNamesByPassName("UtFakePass", option_names), GRAPH_FAILED); - REG_PASS_OPTION("UtFakePass") - .SWITCH_OPT("ut.fake_option", OoHierarchy::kH1) - .SWITCH_OPT("ut.fake_option_secondary", OoHierarchy::kH1); - EXPECT_EQ(pass_opt_registry.FindOptionNamesByPassName("UtFakePass", option_names), GRAPH_SUCCESS); - EXPECT_EQ(option_names.front(), "ut.fake_option"); -} - -TEST_F(OptimizationOptRegistryUT, RegisterOption_Fail_InvalidParams) { - // option name is empty - REG_OPTION("").LEVELS(OoLevel::kO1); - EXPECT_EQ(opt_registry.FindOptInfo(""), nullptr); - // option level is invalid - REG_OPTION(kUtOptionName).LEVELS(OoLevel::kEnd); - EXPECT_EQ(opt_registry.FindOptInfo("ut.test_option"), nullptr); - // option level is not set - REG_OPTION(kUtOptionName).DEFAULT_VALUES({OoLevel::kO1, "true"}); - EXPECT_EQ(opt_registry.FindOptInfo(kUtOptionName), nullptr); - // option hierarchy is invalid - REG_OPTION(kUtOptionName, OoHierarchy::kEnd).LEVELS(OoLevel::kO1); - EXPECT_EQ(opt_registry.FindOptInfo(kUtOptionName), nullptr); -} - -TEST_F(OptimizationOptRegistryUT, RegisterOption_Ok_RegisterOptionMoreThanOnce) { - // Set OoLevel and more than once will not overwrite - REG_OPTION("ge.oo.fake_more_than_once1") - .LEVELS(OoLevel::kO3) - .LEVELS(OoLevel::kO2) - .VISIBILITY(OoEntryPoint::kAtc, OoEntryPoint::kSession, OoEntryPoint::kIrBuild) - .DEFAULT_VALUES({{OoLevel::kO1, "true"}}); - auto opt1 = opt_registry.FindOptInfo("ge.oo.fake_more_than_once1"); - EXPECT_NE(opt1, nullptr); - EXPECT_EQ(OoInfoUtils::IsBitSet(opt1->levels, static_cast(OoLevel::kO2)), false); - - EXPECT_EQ(GetRegisteredOptionsByLevel(OoLevel::kO2).count("ge.oo.fake_more_than_once1"), 0UL); - EXPECT_EQ(GetRegisteredOptionsByLevel(OoLevel::kO3).count("ge.oo.fake_more_than_once1"), 1UL); - - // cannot register same options more than once - REG_OPTION("ge.oo.fake_more_than_once1").LEVELS(OoLevel::kO2); - REG_PASS_OPTION("UtMoreThanOncePass").SWITCH_OPT("ge.oo.fake_more_than_once1"); - const auto option_ptr = opt_registry.FindOptInfo("ge.oo.fake_more_than_once2"); - ASSERT_EQ(option_ptr, nullptr); - EXPECT_EQ(GetRegisteredOptionsByLevel(OoLevel::kO2).count("ge.oo.fake_more_than_once1"), 0UL); - EXPECT_EQ(GetRegisteredOptionsByLevel(OoLevel::kO3).count("ge.oo.fake_more_than_once1"), 1UL); - std::vector opt_names; - EXPECT_EQ(pass_opt_registry.FindOptionNamesByPassName("UtMoreThanOncePass", opt_names), GRAPH_SUCCESS); -} - -TEST_F(OptimizationOptRegistryUT, RegisterOption_Ok_FunctionalPassWithoutVisibleOption) { - REG_PASS_OPTION("FakeFunctionalPass0").LEVELS(OoLevel::kO0); - REG_PASS_OPTION("FakeFunctionalPass1").LEVELS(OoLevel::kO1); - REG_PASS_OPTION("FakeFunctionalPass2").LEVELS(OoLevel::kO2); - REG_PASS_OPTION("FakeFunctionalPass3").LEVELS(OoLevel::kO3); - // repeated registration will not take effect - REG_PASS_OPTION("FakeFunctionalPass3").LEVELS(OoLevel::kO3); - // check pass name ot option names - auto check_pass_option = [this](const std::string &pass_name) { - std::vector opt_names; - const auto ret = pass_opt_registry.FindOptionNamesByPassName(pass_name, opt_names); - ASSERT_EQ(ret, GRAPH_SUCCESS); - ASSERT_EQ(opt_names.size(), 1UL); - EXPECT_EQ(opt_names[static_cast(OoHierarchy::kH1)], pass_name); - ASSERT_NE(opt_registry.FindOptInfo(pass_name), nullptr); - EXPECT_EQ(opt_registry.FindOptInfo(pass_name)->visibility, 0UL); - }; - check_pass_option("FakeFunctionalPass0"); - check_pass_option("FakeFunctionalPass1"); - check_pass_option("FakeFunctionalPass2"); - check_pass_option("FakeFunctionalPass3"); -} - -TEST_F(OptimizationOptRegistryUT, RegisterOption_Ok_MultiHierarchicalOptions) { - // multiple switch options - REG_OPTION("ge.oo.fake_graph_fusion") - .LEVELS(OoLevel::kO3) - .VISIBILITY(OoEntryPoint::kAtc, OoEntryPoint::kSession, OoEntryPoint::kIrBuild) - .SHOW_NAME(OoEntryPoint::kAtc, "oo_fake_graph_fusion", OoCategory::kModelTuning) - .CHECKER(OoInfoUtils::IsSwitchOptValueValid) - .HELP("The switch of fake graph fusion"); - REG_OPTION("ge.oo.fake_graph_fusion_add_relu", OoHierarchy::kH2) - .LEVELS(OoLevel::kO3) - .VISIBILITY(OoEntryPoint::kAtc, OoEntryPoint::kSession, OoEntryPoint::kIrBuild) - .SHOW_NAME(OoEntryPoint::kAtc, "oo_fake_graph_fusion_add_relu", OoCategory::kModelTuning) - .CHECKER(OoInfoUtils::IsSwitchOptValueValid) - .HELP("The secondary switch of fake graph fusion for add-relu"); - REG_PASS_OPTION("GraphFusionAddReluPass") - .SWITCH_OPT("ge.oo.fake_graph_fusion") - .SWITCH_OPT("ge.oo.fake_graph_fusion_add_relu", OoHierarchy::kH2); - const auto option_ptr = opt_registry.FindOptInfo("ge.oo.fake_graph_fusion"); - const auto option_ptr2 = opt_registry.FindOptInfo("ge.oo.fake_graph_fusion_add_relu"); - EXPECT_NE(option_ptr, nullptr); - EXPECT_NE(option_ptr2, nullptr); - EXPECT_EQ(option_ptr->show_infos.at(OoEntryPoint::kAtc).show_name, "oo_fake_graph_fusion"); - EXPECT_EQ(option_ptr2->show_infos.at(OoEntryPoint::kAtc).show_name, "oo_fake_graph_fusion_add_relu"); - // bind option to another pass - REG_OPTION("ge.oo.fake_conv_relu_thres") - .LEVELS(OoLevel::kO3) - .DEFAULT_VALUES({{OoLevel::kO3, "800"}}) - .VISIBILITY(OoEntryPoint::kAtc, OoEntryPoint::kSession, OoEntryPoint::kIrBuild) - .SHOW_NAME(OoEntryPoint::kAtc, "oo_fake_conv_relu_thres", OoCategory::kModelTuning) - .CHECKER(ThresholdCheckerFunc) - .HELP("The threshold of conv relu"); - REG_PASS_OPTION("FakeConvReluFusionPass").SWITCH_OPT("ge.oo.fake_graph_fusion"); - const auto option_ptr3 = opt_registry.FindOptInfo("ge.oo.fake_graph_fusion"); - const auto option_ptr4 = opt_registry.FindOptInfo("ge.oo.fake_conv_relu_thres"); - EXPECT_NE(option_ptr3, nullptr); - EXPECT_NE(option_ptr4, nullptr); - EXPECT_EQ(option_ptr3->checker("TRUE"), false); - EXPECT_EQ(option_ptr4->checker("233"), true); - const auto option_infos = GetRegisteredOptionsByLevel(OoLevel::kO3); - ASSERT_EQ(option_infos.count("ge.oo.fake_graph_fusion"), 1UL); - ASSERT_EQ(option_infos.count("ge.oo.fake_conv_relu_thres"), 1UL); - EXPECT_EQ(OoInfoUtils::GetDefaultValue(option_infos.at("ge.oo.fake_conv_relu_thres"), OoLevel::kO3), "800"); -} - -TEST_F(OptimizationOptRegistryUT, GetCommandLineOptions_Ok) { - REG_OPTION("ge.oo.fake_option") - .LEVELS(OoLevel::kO3) - .DEFAULT_VALUES({{OoLevel::kO3, "800"}}) - .VISIBILITY(OoEntryPoint::kAtc, OoEntryPoint::kSession, OoEntryPoint::kIrBuild) - .SHOW_NAME(OoEntryPoint::kAtc, "oo_fake_option", OoCategory::kModelTuning) - .CHECKER(ThresholdCheckerFunc) - .HELP("fake option"); - REG_OPTION("ge.oo.fake_option_a", OoHierarchy::kH2) - .LEVELS(OoLevel::kO3) - .VISIBILITY(OoEntryPoint::kAtc, OoEntryPoint::kSession, OoEntryPoint::kIrBuild) - .SHOW_NAME(OoEntryPoint::kAtc, "oo_fake_option_a", OoCategory::kModelTuning) - .CHECKER(OoInfoUtils::IsSwitchOptValueValid) - .HELP("The secondary switch of fake graph fusion for add-relu"); - REG_PASS_OPTION("FakePass0").LEVELS(OoLevel::kO0); - REG_PASS_OPTION("FakePass1").LEVELS(OoLevel::kO1); - REG_PASS_OPTION("FakePass2").LEVELS(OoLevel::kO2); - REG_PASS_OPTION("FakePass3").LEVELS(OoLevel::kO3); - const auto cmd_options = opt_registry.GetVisibleOptions(OoEntryPoint::kAtc); - for (const auto &cmd : cmd_options) { - EXPECT_EQ(cmd.first.empty(), false); - std::cout << cmd.first << " " << cmd.second.help_text << std::endl; - } -} -} // namespace ge \ No newline at end of file diff --git a/tests/ut/register/testcase/reg_op_unittest.cc b/tests/ut/register/testcase/reg_op_unittest.cc deleted file mode 100644 index 70cff3c296e2782d391f511fde2913abe17d5cd2..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/reg_op_unittest.cc +++ /dev/null @@ -1,148 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/operator_reg.h" -#include - -#include "graph/utils/op_desc_utils.h" -namespace { -const std::string kStr1 = "abc"; -const std::string kStr2 = "abcd"; -const std::vector kStrs = {"a", "bc"}; -const std::vector kAscendStrs = {ge::AscendString("a"), ge::AscendString("b")}; -} - -namespace ge { -class RegisterOpUnittest : public testing::Test {}; - -REG_OP(AttrIrNameRegSuccess1) - .ATTR(AttrInt, Int, 0) - .ATTR(AttrFloat, Float, 0.0) - .ATTR(AttrBool, Bool, true) - .ATTR(AttrTensor, Tensor, Tensor()) - .ATTR(AttrType, Type, DT_INT32) - .ATTR(AttrString, String, "") - .ATTR(AttrString1, String, kStr1) - .ATTR(AttrString2, String, "ab") - .ATTR(AttrAscendString, AscendString, "") - .ATTR(AttrAscendString1, AscendString, AscendString("abc")) - .ATTR(AttrListInt, ListInt, {}) - .ATTR(AttrListFloat, ListFloat, {}) - .ATTR(AttrListBool, ListBool, {}) - .ATTR(AttrListTensor, ListTensor, {}) - .ATTR(AttrListType, ListType, {}) - .ATTR(AttrListString, ListString, {}) - .ATTR(AttrListString1, ListString, {"", ""}) - .ATTR(AttrListString2, ListString, {kStr1, kStr2}) - .ATTR(AttrListString3, ListString, kStrs) - .ATTR(AttrListAscendString, ListAscendString, {}) - .ATTR(AttrListAscendString1, ListAscendString, kAscendStrs) - .ATTR(AttrBytes, Bytes, {}) - .ATTR(AttrListListInt, ListListInt, {}) - - .REQUIRED_ATTR(ReqAttrInt, Int) - .REQUIRED_ATTR(ReqAttrFloat, Float) - .REQUIRED_ATTR(ReqAttrBool, Bool) - .REQUIRED_ATTR(ReqAttrTensor, Tensor) - .REQUIRED_ATTR(ReqAttrType, Type) - .REQUIRED_ATTR(ReqAttrString, String) - .REQUIRED_ATTR(ReqAttrAscendString, AscendString) - - .REQUIRED_ATTR(ReqAttrListInt, ListInt) - .REQUIRED_ATTR(ReqAttrListFloat, ListFloat) - .REQUIRED_ATTR(ReqAttrListBool, ListBool) - .REQUIRED_ATTR(ReqAttrListTensor, ListTensor) - .REQUIRED_ATTR(ReqAttrListType, ListType) - .REQUIRED_ATTR(ReqAttrListString, ListString) - .REQUIRED_ATTR(ReqAttrListAscendString, ListAscendString) - .REQUIRED_ATTR(ReqAttrBytes, Bytes) - .REQUIRED_ATTR(ReqAttrListListInt, ListListInt) - - //.ATTR(AttrNamedAttrs, NamedAttrs, NamedAttrs()) - //.ATTR(AttrListNamedAttrs, ListNamedAttrs, {}) - .OP_END_FACTORY_REG(AttrIrNameRegSuccess1); - -REG_OP(InputIrNameRegSuccess1) - .INPUT(fix_input1, TensorType({DT_INT32, DT_INT64})) - .INPUT(fix_input2, TensorType({DT_INT32, DT_INT64})) - .OPTIONAL_INPUT(opi1, TensorType({DT_INT32, DT_INT64})) - .OPTIONAL_INPUT(opi2, TensorType({DT_INT32, DT_INT64})) - .DYNAMIC_INPUT(dyi1, TensorType({DT_INT32, DT_INT64})) - .DYNAMIC_INPUT(dyi2, TensorType({DT_INT32, DT_INT64})) - .OP_END_FACTORY_REG(InputIrNameRegSuccess1); - -TEST_F(RegisterOpUnittest, AttrIrNameRegSuccess) { - auto op = OperatorFactory::CreateOperator("AttrIrNameRegSuccess1Op", "AttrIrNameRegSuccess1"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - const auto &ir_names = op_desc->GetIrAttrNames(); - EXPECT_EQ(ir_names, - std::vector({"AttrInt", - "AttrFloat", - "AttrBool", - "AttrTensor", - "AttrType", - "AttrString", - "AttrString1", - "AttrString2", - "AttrAscendString", - "AttrAscendString1", - "AttrListInt", - "AttrListFloat", - "AttrListBool", - "AttrListTensor", - "AttrListType", - "AttrListString", - "AttrListString1", - "AttrListString2", - "AttrListString3", - "AttrListAscendString", - "AttrListAscendString1", - "AttrBytes", - "AttrListListInt", - "ReqAttrInt", - "ReqAttrFloat", - "ReqAttrBool", - "ReqAttrTensor", - "ReqAttrType", - "ReqAttrString", - "ReqAttrAscendString", - "ReqAttrListInt", - "ReqAttrListFloat", - "ReqAttrListBool", - "ReqAttrListTensor", - "ReqAttrListType", - "ReqAttrListString", - "ReqAttrListAscendString", - "ReqAttrBytes", - "ReqAttrListListInt"})); -} - -TEST_F(RegisterOpUnittest, InputIrNameRegSuccess) { - auto op = OperatorFactory::CreateOperator("InputIrNameRegSuccess1", "InputIrNameRegSuccess1"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - const auto &inputs = op_desc->GetIrInputs(); - std::vector> expect_inputs({{"fix_input1", kIrInputRequired}, - {"fix_input2", kIrInputRequired}, - {"opi1", kIrInputOptional}, - {"opi2", kIrInputOptional}, - {"dyi1", kIrInputDynamic}, - {"dyi2", kIrInputDynamic}}); - EXPECT_EQ(inputs, expect_inputs); - - EXPECT_EQ(op_desc->GetValidInputNameByIndex(0), "fix_input1"); - EXPECT_EQ(op_desc->GetValidInputNameByIndex(1), "fix_input2"); - EXPECT_EQ(op_desc->GetValidInputNameByIndex(2), ""); - EXPECT_EQ(op_desc->GetValidInputNameByIndex(3), ""); - EXPECT_EQ(op_desc->GetValidInputNameByIndex(4), ""); - EXPECT_EQ(op_desc->GetValidInputNameByIndex(5), ""); -} - -} // namespace ge diff --git a/tests/ut/register/testcase/register_bank_key_unittest.cc b/tests/ut/register/testcase/register_bank_key_unittest.cc deleted file mode 100644 index 11ec733c9fd6e8417c9039cd3a7bf2c13211016a..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/register_bank_key_unittest.cc +++ /dev/null @@ -1,116 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include "nlohmann/json.hpp" -#include "graph/ascend_string.h" -#include "register/tuning_bank_key_registry.h" - -namespace tuningtiling { -struct DynamicRnnInputArgsV2 { - int64_t batch; - int32_t dims; -}; -bool ConvertTilingContext(const gert::TilingContext* context, - std::shared_ptr &input_args, size_t &size) { - if (context == nullptr) { - auto rnn = std::make_shared(); - rnn->batch = 0; - rnn->dims = 1; - size = sizeof(DynamicRnnInputArgsV2); - input_args = rnn; - return false; - } - return true; -} - -// v1 -DECLARE_STRUCT_RELATE_WITH_OP(DynamicRNN, DynamicRnnInputArgsV2, batch, dims); -REGISTER_OP_BANK_KEY_CONVERT_FUN(DynamicRNN, ConvertTilingContext); - -// new api test -// DECLARE_STRUCT_RELATE_WITH_OP_V2(DynamicRNN, DynamicRnnInputArgsV2, -// batch, dims); -// REGISTER_OP_BANK_KEY_CONVERT_FUN_V2(DynamicRNN, ConvertTilingContext); -class RegisterOPBankKeyUT : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -extern "C" void _ZN12tuningtiling21OpBankKeyFuncRegistryC1ERKN2ge12AscendStringERKSt8functionIFbRKSt10shared_ptrIvEmRN15ascend_nlohmann10basic_jsonISt3mapSt6vectorSsblmdSaNSA_14adl_serializerESD_IhSaIhEEEEEERKS5_IFbRS7_RmRKSH_EE(); - -extern "C" void _ZN12tuningtiling21OpBankKeyFuncRegistryC1ERKN2ge12AscendStringERKSt8functionIFbRKSt10shared_ptrIvEmRN15ascend_nlohmann16json_abi_v3_11_210basic_jsonISt3mapSt6vectorSsblmdSaNSB_14adl_serializerESE_IhSaIhEEEEEERKS5_IFbRS7_RmRKSI_EE(); - -TEST_F(RegisterOPBankKeyUT, convert_tiling_context) { - _ZN12tuningtiling21OpBankKeyFuncRegistryC1ERKN2ge12AscendStringERKSt8functionIFbRKSt10shared_ptrIvEmRN15ascend_nlohmann10basic_jsonISt3mapSt6vectorSsblmdSaNSA_14adl_serializerESD_IhSaIhEEEEEERKS5_IFbRS7_RmRKSH_EE(); - _ZN12tuningtiling21OpBankKeyFuncRegistryC1ERKN2ge12AscendStringERKSt8functionIFbRKSt10shared_ptrIvEmRN15ascend_nlohmann16json_abi_v3_11_210basic_jsonISt3mapSt6vectorSsblmdSaNSB_14adl_serializerESE_IhSaIhEEEEEERKS5_IFbRS7_RmRKSI_EE(); - auto& func = OpBankKeyFuncRegistry::RegisteredOpFuncInfo(); - auto iter = func.find("DynamicRNN"); - nlohmann::json test; - test["batch"] = 12; - test["dims"] = 2; - std::string dump_str; - dump_str = test.dump(); - ge::AscendString test_str; - - const OpBankLoadFun& load_func = iter->second.GetBankKeyLoadFunc(); - std::shared_ptr ld = nullptr; - size_t len = 0; - load_func(ld, len, test_str); - - const auto &parse_func = iter->second.GetBankKeyParseFunc(); - ge::AscendString test2_str; - parse_func(ld, len, test2_str); - - const auto &convert_func = iter->second.GetBankKeyConvertFunc(); - std::shared_ptr op_key = nullptr; - size_t s = 0U; - convert_func(nullptr, op_key, s); - auto rnn_ky = std::static_pointer_cast(op_key); - -} - -// TEST_F(RegisterOPBankKeyUT, convert_tiling_contextV2) { -// auto& func = OpBankKeyFuncRegistryV2::RegisteredOpFuncInfoV2(); -// auto iter = func.find("DynamicRNN"); -// nlohmann::json test; -// test["batch"] = 12; -// test["dims"] = 2; -// std::string dump_str; -// dump_str = test.dump(); -// ge::AscendString test_str; -// test_str = ge::AscendString(dump_str.c_str()); -// ASSERT_TRUE(iter != func.cend()); - -// const OpBankLoadFunV2& load_funcV2 = iter->second.GetBankKeyLoadFuncV2(); -// std::shared_ptr ld = nullptr; -// size_t len = 0; -// EXPECT_TRUE(load_funcV2(ld, len, test_str)); -// EXPECT_TRUE(ld != nullptr); - -// const auto &parse_funcV2 = iter->second.GetBankKeyParseFuncV2(); -// ge::AscendString test2; -// EXPECT_TRUE(parse_funcV2(ld, len, test2)); -// EXPECT_EQ(test_str, test2); - -// const auto &convert_funcV2 = iter->second.GetBankKeyConvertFuncV2(); -// std::shared_ptr op_key = nullptr; -// size_t s = 0U; -// EXPECT_FALSE(convert_funcV2(nullptr, op_key, s)); -// EXPECT_TRUE(s !=0); -// EXPECT_TRUE(op_key != nullptr); -// auto rnn_ky = std::static_pointer_cast(op_key); -// EXPECT_EQ(rnn_ky->batch, 0); - -// } -} // namespace tuningtiling diff --git a/tests/ut/register/testcase/register_buffer_fusion.cc b/tests/ut/register/testcase/register_buffer_fusion.cc deleted file mode 100644 index 230b2ba5ad222193a9bfbe3815826e5fffd3fc48..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/register_buffer_fusion.cc +++ /dev/null @@ -1,2210 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include "gtest/gtest.h" - -#include "graph/compute_graph.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/ge_tensor.h" -#include "graph/op_desc.h" -#include "graph/op_kernel_bin.h" -#include "graph/utils/attr_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/tensor_utils.h" -#include "register/graph_optimizer/buffer_fusion/buffer_fusion_pass_base.h" -#include "register/graph_optimizer/buffer_fusion/buffer_fusion_pass_registry.h" -#include "register/graph_optimizer/buffer_fusion/buffer_fusion_pattern.h" -#include "graph/debug/ge_log.h" -#include "register/graph_optimizer/graph_fusion/connection_matrix.h" -#include "register/graph_optimizer/fusion_common/op_slice_info.h" -#include "runtime/kernel.h" - -using namespace std; -using namespace domi; -using namespace fe; -using namespace ge; - -static const string STREAM_LABEL = "_stream_label"; -const std::string FE_IMPLY_TYPE = "_fe_imply_type"; -namespace fe{ -static const uint32_t L2_MAXDATANUM = 8; -using L2FusionData_t = struct tag_l2_fusion_data { - uint32_t l2Index; - uint64_t l2Addr; - uint64_t l2PageNum; -}; -using L2FusionDataMap_t = std::map; - -using fe_sm_desc_t = struct tag_fe_sm_desc { - rtL2Ctrl_t l2ctrl; - std::string node_name[L2_MAXDATANUM]; - uint8_t output_index[L2_MAXDATANUM]; -}; - -using TaskL2FusionInfo_t = struct TagTaskL2FusionInfo { - std::string node_name; - fe_sm_desc_t l2_info; - L2FusionDataMap_t input; - L2FusionDataMap_t output; - uint32_t is_used; -}; -using L2FusionInfoPtr = std::shared_ptr; -} - -class TbeCommonRules2FusionPass : public BufferFusionPassBase { - public: - explicit TbeCommonRules2FusionPass() = default; - - ~TbeCommonRules2FusionPass() override = default; - - protected: - /* - * @brief: define a common ub fusion pattern: - * (StrideRead) -> Convolution -> (Dequant) -> Elewise*N -> Quant -> (StrideWrite) - * - * pattern limits: - * 1. StrideRead, StrideWrite, Dequant are optional, Conv2D and Quant are required. - * 2. Elewise supports LeakyRelu, Vadd, Relu, Relu6, Prelu, Add, Mul. The number of Elewise can be 0 to 5. - * 3. There are two outputs from Dequant or Elewise, one is int8 or int4, the other is fp16. - * - * - * fusion node: (StrideRead), Convolution, (AscendDequant), Elewise, AscendQuant, - * - * @return BufferFusionPattern: return all valid patterns. - */ - vector DefinePatterns() override; - - /* - * @brief: parse nodes matched in mapping and call DoFusion - * @param [in] graph: original graph - * @param [out] mapping: nodes matched by pattern - * @return bool: fusion status ok or not. - */ - Status GetFusionNodes(const BufferFusionMapping &mapping, vector &fusion_nodes) override; - - private: - static int CountOtherOutput(vector dequant_nodes, vector elem_wise_nodes); - - static bool JudgeElemShapeInScopeLessThanOutScope(const vector &pre_elemwise_nodes, - const vector &elemwise_nodes); -}; - -namespace { -const string PATTERN_STRIDEREAD = "strideRead"; // NOLINT -const string PATTERN_CONVOLUTION = "convolution"; // NOLINT -const string PATTERN_DEPTHWISECONV = "depthwiseconv"; // NOLINT -const string PATTERN_DEQUANT = "dequant"; // NOLINT -const string PATTERN_ELEMWISE = "elemWise"; // NOLINT -const string PATTERN_QUANT = "quant"; // NOLINT -const string PATTERN_STRIDEWRITE = "strideWrite"; // NOLINT -const string PATTERN_OTHER_INPUT = "otherInput"; // NOLINT -const string PATTERN_OUTPUT = "output"; // NOLINT - -const vector ELEM_WISE_WHITE_LIST = {"Eltwise", "LeakyRelu", "Vadd", "Relu", - "Relu6", "Relu6D", "PRelu", - "Add", "Mul", "Softplus", "Sigmoid", "Mish", - "Minimum", "Tanh", "Swish"}; // NOLINT - -const int MAX_OP_COUNT = 20; -const int MAX_ELEMWISE_COUNT = 5; -const int INPUT_MAX_SIZE = 2; -const int kConvOutputMaxSize = 2; -} - -#define UT_CHECK(cond, log_func, return_expr) \ - do { \ - if (cond) { \ - log_func; \ - return_expr; \ - } \ - } while (0) - -#define UT_CHECK_NOTNULL(val) \ - do { \ - if ((val) == nullptr) { \ - GE_LOGE("Parameter[%s] must not be null.", #val); \ - return fe::PARAM_INVALID; \ - } \ - } while (0)3 - -/* -* @brief: define a common ub fusion pattern: -* (StrideRead) -> Convolution -> (Dequant) -> Elewise*N -> Quant -> (StrideWrite) -* -* pattern limits: -* 1. StrideRead, StrideWrite, Dequant are optional, Conv2D and Quant are required. -* 2. Elewise supports LeakyRelu, Vadd, Relu, Relu6, Prelu, Add, Mul. The number of Elewise can be 0 to 5. -* 3. There are two outputs from Dequant or Elewise, one is int8 or int4, the other is fp16. -* -* -* fusion node: (StrideRead), Convolution, (AscendDequant), Elewise, AscendQuant, -* -* @return BufferFusionPattern: return all valid patterns. -*/ -vector TbeCommonRules2FusionPass::DefinePatterns() { - vector patterns; - string pass_name = "TbeCommonRules2FusionPass"; - auto *pattern = new (std::nothrow) BufferFusionPattern(pass_name, MAX_OP_COUNT); - UT_CHECK((pattern == nullptr), - GE_LOGE("[SubGraphOpt][CommonRules2Fus][DefPtn] New an object failed."), return patterns); - GELOGD("Start to define %s pass pattern.", pass_name.c_str()); - pattern->AddOpDesc(PATTERN_STRIDEREAD, {OP_PATTERN_STRIDED_READ}, TBE_PATTERN_NUM_NONE, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_CONVOLUTION, {OP_PATTERN_CONV}, TBE_PATTERN_NUM_NONE, TBE_PATTERN_NUM_DEFAULT, - TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_DEPTHWISECONV, {OP_PATTERN_DEPTHWISE_CONV}, TBE_PATTERN_NUM_NONE, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_DEQUANT, {OP_PATTERN_DEQUANT}, TBE_PATTERN_NUM_NONE, TBE_PATTERN_NUM_DEFAULT, - TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_OTHER_INPUT, {TBE_PATTERN_INPUT_NODE}, TBE_PATTERN_NUM_NONE, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_ELEMWISE, {OP_PATTERN_ELEMWISE, OP_PATTERN_BROAD_CAST}, TBE_PATTERN_NUM_NONE, - MAX_ELEMWISE_COUNT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_QUANT, {OP_PATTERN_QUANT}, TBE_PATTERN_NUM_DEFAULT, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_STRIDEWRITE, {OP_PATTERN_STRIDED_WRITE}, TBE_PATTERN_NUM_NONE, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .SetHead({PATTERN_STRIDEREAD, PATTERN_CONVOLUTION, PATTERN_DEPTHWISECONV}) - .SetOutputs(PATTERN_STRIDEREAD, {PATTERN_CONVOLUTION, PATTERN_DEPTHWISECONV}) - .SetOutputs(PATTERN_CONVOLUTION, {PATTERN_DEQUANT}, TBE_OUTPUT_BRANCH_SINGLE, true, true) - .SetOutputs(PATTERN_DEPTHWISECONV, {PATTERN_DEQUANT}, TBE_OUTPUT_BRANCH_SINGLE, true, true) - .SetOutputs(PATTERN_DEQUANT, {PATTERN_ELEMWISE}, TBE_OUTPUT_BRANCH_SINGLE, true, true) - .SetOutputs(PATTERN_OTHER_INPUT, {PATTERN_DEQUANT}) - .SetOutputs(PATTERN_ELEMWISE, {PATTERN_QUANT}, TBE_OUTPUT_BRANCH_SINGLE, true, true) - .SetOutputs(PATTERN_QUANT, {PATTERN_STRIDEWRITE}, TBE_OUTPUT_BRANCH_SINGLE, false, true); - patterns.push_back(pattern); - GELOGD("End to define %s pass pattern.", pass_name.c_str()); - - return patterns; -} - -int TbeCommonRules2FusionPass::CountOtherOutput(vector dequant_nodes, - vector elem_wise_nodes) { - int other_out_count = 0; - // count EltWise op other output - for (const auto &elem_wise_node : elem_wise_nodes) { - if (elem_wise_node->GetOutDataNodes().empty()) { - continue; - } - int other_elt_wise_out = (int)(elem_wise_node->GetOutDataNodes().size() - 1); - other_out_count += other_elt_wise_out; - } - - // count Dequant op other output - if (!dequant_nodes.empty()) { - int other_dequant_out = 0; - if (dequant_nodes[0]->GetOutDataNodes().empty()) { - other_dequant_out = 0; - } else { - other_dequant_out = static_cast(dequant_nodes[0]->GetOutDataNodes().size() - 1); - } - other_out_count += other_dequant_out; - } - return other_out_count; -} - -bool TbeCommonRules2FusionPass::JudgeElemShapeInScopeLessThanOutScope(const vector &pre_elemwise_nodes, - const vector &elemwise_nodes) { - if (pre_elemwise_nodes.empty()) { - return false; - } - ge::NodePtr cur_node = pre_elemwise_nodes[0]; - for (auto &elemwise_node: elemwise_nodes) { - ge::NodePtr pre_node = cur_node; - cur_node = elemwise_node; - if (cur_node->GetOpDesc()->GetInputsSize() != INPUT_MAX_SIZE) { - continue; - } - - if ((cur_node->GetInDataAnchor(0)->GetPeerOutAnchor() == nullptr) || - (cur_node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode() == nullptr)) { - return false; - } - auto cur_node_input0 = cur_node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - - vector in_scope_dims; - vector out_scope_dims; - if (cur_node_input0->GetName() == pre_node->GetOpDesc()->GetName()) { - in_scope_dims = cur_node->GetOpDesc()->MutableInputDesc(0)->MutableShape().GetDims(); - out_scope_dims = cur_node->GetOpDesc()->MutableInputDesc(1)->MutableShape().GetDims(); - } else { - in_scope_dims = cur_node->GetOpDesc()->MutableInputDesc(1)->MutableShape().GetDims(); - out_scope_dims = cur_node->GetOpDesc()->MutableInputDesc(0)->MutableShape().GetDims(); - } - if (in_scope_dims.size() != out_scope_dims.size()) { - GELOGD("Elem_wise[node: %s] : the number of input's dims is not equal. in_scope_dims: %zu, out_scope_dims: %zu", - cur_node->GetName().c_str(), in_scope_dims.size(), out_scope_dims.size()); - return false; - } else { - for (size_t i = 0; i < in_scope_dims.size(); i++) { - if (in_scope_dims[i] < out_scope_dims[i]) { - GELOGD("Elem_wise[node: %s] dims[%zu]: the value of in_scope is less than out_scope. in_scope : %ld," - " out_scope : %ld", cur_node->GetName().c_str(), i, in_scope_dims[i], out_scope_dims[i]); - return true; - } - } - } - } - return false; -} - -/* -* @brief: parse nodes matched in mapping and call DoFusion -* @param [in] graph: original graph -* @param [out] mapping: nodes matched by pattern -* @return bool: fusion status ok or not. -*/ -Status TbeCommonRules2FusionPass::GetFusionNodes(const BufferFusionMapping &mapping, - vector &fusion_nodes) { - fusion_nodes = GetMatchedNodes(mapping); - vector output_nodes = GetMatchedNodesByDescName(TBE_PATTERN_OUTPUT_NODE, mapping); - vector conv_nodes = GetMatchedNodesByDescName(PATTERN_CONVOLUTION, mapping); - vector depthwise_nodes = GetMatchedNodesByDescName(PATTERN_DEPTHWISECONV, mapping); - vector elem_wise_nodes = GetMatchedNodesByDescName(PATTERN_ELEMWISE, mapping); - vector dequant_nodes = GetMatchedNodesByDescName(PATTERN_DEQUANT, mapping); - vector quant_nodes = GetMatchedNodesByDescName(PATTERN_QUANT, mapping); - vector stride_write_nodes = GetMatchedNodesByDescName(PATTERN_STRIDEWRITE, mapping); - - bool conv_depth_size = conv_nodes.size() == 1 || depthwise_nodes.size() == 1; - if (!conv_depth_size) { - GELOGD("There is no conv and depthwise in TbeCommonRules2FusionPass"); - fusion_nodes.clear(); - return ge::GRAPH_SUCCESS; - } - vector conv_depthwise_nodes = conv_nodes.size() == 1 ? conv_nodes : depthwise_nodes; - - size_t conv_output_size = conv_depthwise_nodes[0]->GetOutDataNodes().size(); - // conv outputs size is more than 2, skip fused - if (conv_output_size > kConvOutputMaxSize) { - GELOGD("node: %s, outputs is more than 2, size is: %zu.", - conv_depthwise_nodes[0]->GetName().c_str(), conv_output_size); - fusion_nodes.clear(); - return ge::GRAPH_SUCCESS; - } - - // the output_data can't be fused - for (const auto &outputnode : output_nodes) { - auto node_ptr = find(fusion_nodes.begin(), fusion_nodes.end(), outputnode); - if (node_ptr != fusion_nodes.end()) { - fusion_nodes.erase(node_ptr); - } - } - - // this pattern only support one other output from dequant node or elem_wise node - int other_out_count = CountOtherOutput(dequant_nodes, elem_wise_nodes); - bool cond_other_out_count = (conv_output_size == 1 && other_out_count != 1) || - (conv_output_size == kConvOutputMaxSize && other_out_count != 0); - if (cond_other_out_count) { - GELOGD("The number of other output from EltWise or Dequant is %d, skip fusion.", other_out_count); - fusion_nodes.clear(); - return ge::GRAPH_SUCCESS; - } - - // if elewise has 2 input and inscope's shape less than outscope's shape, skip fusion - bool dequant_flag = !dequant_nodes.empty() && - JudgeElemShapeInScopeLessThanOutScope(dequant_nodes, elem_wise_nodes); - if (dequant_flag) { - GELOGD("dequant_nodes exist, Elemwise node has 2 inputs and in scope shape is less than outscope"); - fusion_nodes.clear(); - return ge::GRAPH_SUCCESS; - } - bool no_dequant_flag = dequant_nodes.empty() && - JudgeElemShapeInScopeLessThanOutScope(conv_depthwise_nodes, elem_wise_nodes); - if (no_dequant_flag) { - GELOGD("no dequant_nodes, Elemwise node has 2 inputs and in scope shape is less than outscope"); - fusion_nodes.clear(); - return ge::GRAPH_SUCCESS; - } - - // check whether the EltWise op is in the whitelist or inputsizes less then 3(only support single or double in) - for (const auto &elem_wise_node : elem_wise_nodes) { - bool support_flag = find(ELEM_WISE_WHITE_LIST.begin(), ELEM_WISE_WHITE_LIST.end(), elem_wise_node->GetType()) == - ELEM_WISE_WHITE_LIST.end() || - elem_wise_node->GetOpDesc()->GetInputsSize() > INPUT_MAX_SIZE; - if (support_flag) { - fusion_nodes.clear(); - GELOGD("Eltwise op[%s] type[%s] is not supported for this ub fusion pass, skip fusion.", - elem_wise_node->GetName().c_str(), elem_wise_node->GetType().c_str()); - return ge::GRAPH_SUCCESS; - } - } - - // if stride_write is the last node, check whether quant node has multi outputs - bool quant_node_flag = quant_nodes[0]->GetOutDataNodes().size() > 1 && !stride_write_nodes.empty(); - if (quant_node_flag) { - auto node_ptr = find(fusion_nodes.begin(), fusion_nodes.end(), stride_write_nodes[0]); - if (node_ptr != fusion_nodes.end()) { - fusion_nodes.erase(node_ptr); - } - GELOGD("Quant is not the last node of the matched pattern, \ - but has multi outpts, erase last node stride_write."); - } - return ge::GRAPH_SUCCESS; -} - -static const char PATTERN_STRIDED_READ[] = "stridedread"; -static const char PATTERN_CONV[] = "convolution"; -static const char PATTERN_STRIDED_WRITE[] = "stridedwrite"; -static const int FUSION_OP_NUM_MAX = 10; - -class ConveragePass : public BufferFusionPassBase { - public: - explicit ConveragePass() {} - - ~ConveragePass() override {} - - protected: - - /* - * @brief: define common rules0 ops fusion pattern - * - * (StrideRead) + conv2_d + (dequant) + ele-wise*N + (quant) + (StrideWrite) - * restriction: 1.each node must be single output and single reference - * 2.the range of N is 0 to 5 - * 3.allow multiple input, but only one input can be fusion - * - * @return BufferFusionPattern: return all valid patterns. - */ - vector DefinePatterns() override { - vector patterns; - string pass_name = "ConveragePass"; - BufferFusionPattern *pattern = new (std::nothrow) BufferFusionPattern(pass_name, 10); - UT_CHECK((pattern == nullptr), - GE_LOGE("[SubGraphOpt][CommonRules0Fus][DefPtn] New an object failed."), - return patterns); - GELOGD("Start to define %s pass pattern.", pass_name.c_str()); - // define pattern rules - pattern->AddOpDesc("", {OP_PATTERN_STRIDED_READ}, TBE_PATTERN_NUM_NONE, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE); - - pattern->AddOpDesc("test", {OP_PATTERN_STRIDED_READ}, TBE_PATTERN_NUM_DEFAULT, - TBE_PATTERN_NUM_NONE, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE); - - pattern->AddOpDesc("test", {OP_PATTERN_STRIDED_READ}, TBE_PATTERN_NUM_NONE, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE); - pattern->AddOpDesc("test", {OP_PATTERN_STRIDED_READ}, TBE_PATTERN_NUM_NONE, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE); - pattern->AddOpDesc("head1", {OP_PATTERN_STRIDED_READ}, 2, - 3, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE); - pattern->AddOpDesc("head2", {OP_PATTERN_STRIDED_READ}, 1, - 1, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE); - - pattern->AddOpDesc(PATTERN_CONV, {OP_PATTERN_CONV}, TBE_PATTERN_NUM_NONE, TBE_PATTERN_NUM_DEFAULT, - TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_DEPTHWISECONV, {OP_PATTERN_DEPTHWISE_CONV}, TBE_PATTERN_NUM_NONE, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE); - pattern->AddOpDesc(PATTERN_STRIDED_READ, {OP_PATTERN_STRIDED_READ}, TBE_PATTERN_NUM_NONE, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE); - - pattern->SetOutputs("", {PATTERN_CONV, PATTERN_DEPTHWISECONV}); - pattern->SetOutputs("1", {PATTERN_CONV, PATTERN_DEPTHWISECONV}); - pattern->SetOutputs(PATTERN_STRIDED_READ, {"1", PATTERN_DEPTHWISECONV}); - pattern->SetOutputs(PATTERN_STRIDED_READ, {PATTERN_STRIDED_READ}); - pattern->SetOutputs(PATTERN_STRIDED_READ, {PATTERN_CONV, PATTERN_DEPTHWISECONV}); - pattern->SetOutputs(PATTERN_STRIDED_READ, {PATTERN_CONV, PATTERN_DEPTHWISECONV}); - - - vector heads; - pattern->SetHead(heads); - - heads = {""}; - pattern->SetHead(heads); - - heads = {"head1"}; - pattern->SetHead(heads); - - heads = {PATTERN_CONV}; - pattern->SetHead(heads); - - heads = {PATTERN_CONV, "head2"}; - pattern->SetHead(heads); - - auto conv_desc = pattern->GetOpDesc(PATTERN_CONV); - pattern->UpdateSkipStatus(conv_desc); - - pattern->GetOpDescs(); - patterns.push_back(pattern); - return patterns; //todo need to check - } -}; - -class TbeCommonRules0FusionPass : public BufferFusionPassBase { - public: - explicit TbeCommonRules0FusionPass() {} - - ~TbeCommonRules0FusionPass() override {} - - protected: - /* - * @brief: define common rules0 ops fusion pattern - * - * (StrideRead) + conv2_d + (dequant) + ele-wise*N + (quant) + (StrideWrite) - * restriction: 1.each node must be single output and single reference - * 2.the range of N is 0 to 5 - * 3.allow multiple input, but only one input can be fusion - * - * @return BufferFusionPattern: return all valid patterns. - */ - vector DefinePatterns() override; - - /* - * @brief: parse nodes matched in mapping and call DoFusion - * @param [in] graph: original graph - * @param [out] mapping: nodes matched by pattern - * @return bool: fusion status ok or not. - */ - Status GetFusionNodes(const BufferFusionMapping &mapping, vector &fusion_nodes) override; - - private: - static bool DealWithSameInAndOutScopeDimSize(const vector &in_scope_dims, - const vector &out_scope_dims, - const vector &elemwise_nodes, - const ge::NodePtr &cur_node, const size_t &i, - vector &fusion_node); - - static bool JudgeElemShapeInScopeLessThanOutScope(const vector &pre_elemwise_nodes, - const vector &elemwise_nodes, - vector &fusion_nodes); - static bool IsInBlackListOfOpPatternElemwise(vector &elemwise_nodes, ge::NodePtr &node_ptr); -}; - -namespace { - -// white list of OP_PATTERN_ELEMWISE -static const vector WHITELIST_OF_OP_PATTERN_ELEMWISE = { - "Eltwise", "LeakyRelu", "Vadd", "Relu", "Relu6", "Relu6D", - "PRelu", "Add", "Mul", "Softplus", "Sigmoid", "Mish","Minimum", - "Tanh", "Swish"}; -// black list of OP_PATTERN_ELEMWISE -static const vector BLACKLIST_OF_OP_PATTERN_ELEMWISE = { - "ReluGradV2"}; -} - -/* - * @brief: define common rules0 ops fusion pattern - * - * (StrideRead) + conv2_d + (dequant) + ele-wise*N + (quant) + (StrideWrite) - * restriction: 1.each node must be single output and single reference - * 2.the range of N is 0 to 5 - * 3.allow multiple input, but only one input can be fusion - * - * @return BufferFusionPattern: return all valid patterns. - */ -vector TbeCommonRules0FusionPass::DefinePatterns() { - vector patterns; - string pass_name = "TbeCommonRules0FusionPass"; - BufferFusionPattern *pattern = new (std::nothrow) BufferFusionPattern(pass_name, FUSION_OP_NUM_MAX); - UT_CHECK((pattern == nullptr), - GE_LOGE("[SubGraphOpt][CommonRules0Fus][DefPtn] New an object failed."), - return patterns); - GELOGD("Start to define %s pass pattern.", pass_name.c_str()); - // define pattern rules - pattern->AddOpDesc(PATTERN_STRIDED_READ, {OP_PATTERN_STRIDED_READ}, TBE_PATTERN_NUM_NONE, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_CONV, {OP_PATTERN_CONV}, TBE_PATTERN_NUM_NONE, TBE_PATTERN_NUM_DEFAULT, - TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_DEPTHWISECONV, {OP_PATTERN_DEPTHWISE_CONV}, TBE_PATTERN_NUM_NONE, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_DEQUANT, {OP_PATTERN_DEQUANT}, TBE_PATTERN_NUM_NONE, TBE_PATTERN_NUM_DEFAULT, - TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDescTypeRules(PATTERN_ELEMWISE, {OP_PATTERN_ELEMWISE, OP_PATTERN_BROAD_CAST}, TBE_PATTERN_NUM_NONE, TBE_PATTERN_NUM_MAX, - TBE_PATTERN_GROUPID_INVALID, {IGNORE_SHAPE_TYPE, ONLY_SUPPORT_STATIC}) - .AddOpDesc(PATTERN_QUANT, {OP_PATTERN_QUANT}, TBE_PATTERN_NUM_NONE, TBE_PATTERN_NUM_DEFAULT, - TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_STRIDED_WRITE, {OP_PATTERN_STRIDED_WRITE}, TBE_PATTERN_NUM_NONE, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_OTHER_INPUT, {TBE_PATTERN_INPUT_NODE}, TBE_PATTERN_NUM_DEFAULT, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .SetHead({PATTERN_STRIDED_READ, PATTERN_CONV, PATTERN_DEPTHWISECONV}) - .SetOutputs(PATTERN_STRIDED_READ, {PATTERN_CONV, PATTERN_DEPTHWISECONV}) - .SetOutputs(PATTERN_CONV, {PATTERN_DEQUANT}, TBE_OUTPUT_BRANCH_SINGLE, true) - .SetOutputs(PATTERN_DEPTHWISECONV, {PATTERN_DEQUANT}, TBE_OUTPUT_BRANCH_SINGLE, true) - .SetOutputs(PATTERN_OTHER_INPUT, {PATTERN_DEQUANT}) - .SetOutputs(PATTERN_DEQUANT, {PATTERN_ELEMWISE}, TBE_OUTPUT_BRANCH_SINGLE, true) - .SetOutputs(PATTERN_ELEMWISE, {PATTERN_QUANT}, TBE_OUTPUT_BRANCH_SINGLE, true) - .SetOutputs(PATTERN_QUANT, {PATTERN_STRIDED_WRITE}); - - patterns.push_back(pattern); - - GELOGD("End to define %s pass pattern.", pass_name.c_str()); - return patterns; -} - -static void DelNotMatchNodesFromFusionNodes(ge::NodePtr node_ptr, vector &fusion_nodes) { - auto node = find(fusion_nodes.begin(), fusion_nodes.end(), node_ptr); - if (node != fusion_nodes.end()) { - fusion_nodes.erase(node); - } else { - return; - } - - auto curr_nodes = node_ptr->GetOutDataNodes(); - if (curr_nodes.size() != 1) { - return; - } else { - DelNotMatchNodesFromFusionNodes(curr_nodes.at(0), fusion_nodes); - } - return; -} - -static bool IsInWhiteListOfOpPatternElemwise(vector &elemwise_nodes, ge::NodePtr &node_ptr) { - for (auto &elemwise_node : elemwise_nodes) { - string elemwise_type = elemwise_node->GetType(); - auto op_type = - find(WHITELIST_OF_OP_PATTERN_ELEMWISE.begin(), WHITELIST_OF_OP_PATTERN_ELEMWISE.end(), elemwise_type); - if (op_type == WHITELIST_OF_OP_PATTERN_ELEMWISE.end()) { - GELOGD("node:%s[type:%s] not in elemwise white_list.", - elemwise_node->GetName().c_str(), elemwise_type.c_str()); - node_ptr = elemwise_node; - return false; - } - } - return true; -} - -bool TbeCommonRules0FusionPass::IsInBlackListOfOpPatternElemwise(vector &elemwise_nodes, - ge::NodePtr &node_ptr) { - for (auto &elemwise_node : elemwise_nodes) { - string elemwise_type = elemwise_node->GetType(); - auto op_type = - find(BLACKLIST_OF_OP_PATTERN_ELEMWISE.begin(), BLACKLIST_OF_OP_PATTERN_ELEMWISE.end(), elemwise_type); - if (op_type != BLACKLIST_OF_OP_PATTERN_ELEMWISE.end()) { - GELOGD("node:%s[type:%s] in elemwise black_list.", elemwise_node->GetName().c_str(), elemwise_type.c_str()); - node_ptr = elemwise_node; - return true; - } - } - return false; -} - -static void CheckElewiseInputSize(vector &elemwise_nodes, vector &fusion_nodes) { - for (auto elemwise_node : elemwise_nodes) { - if (elemwise_node->GetOpDesc()->GetInputsSize() > INPUT_MAX_SIZE) { - DelNotMatchNodesFromFusionNodes(elemwise_node, fusion_nodes); - return; - } - } -} - -bool TbeCommonRules0FusionPass::DealWithSameInAndOutScopeDimSize(const vector &in_scope_dims, - const vector &out_scope_dims, - const vector &elemwise_nodes, - const ge::NodePtr &cur_node, const size_t &i, - vector &fusion_nodes) { - for (size_t j = 0; j < in_scope_dims.size(); j++) { - if (in_scope_dims[j] < out_scope_dims[j]) { - GELOGD("Elem_wise[node: %s] dims[%zu] : the value of in_scope is less than out_scope. in_scope : %ld," - " out_scope : %ld", cur_node->GetName().c_str(), j, in_scope_dims[j], out_scope_dims[j]); - vector new_elemwise_nodes; - for (size_t z = i; z < elemwise_nodes.size(); z++) { - new_elemwise_nodes.push_back(elemwise_nodes[z]); - } - for (auto new_elemwise_node : new_elemwise_nodes) { - DelNotMatchNodesFromFusionNodes(new_elemwise_node, fusion_nodes); - } - return true; - } - } - return false; -} - -bool TbeCommonRules0FusionPass::JudgeElemShapeInScopeLessThanOutScope(const vector &pre_elemwise_nodes, - const vector &elemwise_nodes, - vector &fusion_nodes) { - if (pre_elemwise_nodes.empty()) { - return false; - } - ge::NodePtr cur_node = pre_elemwise_nodes[0]; - for (size_t i = 0; i < elemwise_nodes.size(); i++) { - ge::NodePtr elemwise_node = elemwise_nodes[i]; - ge::NodePtr pre_node = cur_node; - cur_node = elemwise_node; - if (cur_node->GetOpDesc()->GetInputsSize() != INPUT_MAX_SIZE) { - continue; - } - auto peerOutAnchor = cur_node->GetInDataAnchor(0)->GetPeerOutAnchor(); - if (peerOutAnchor == nullptr) { - GELOGD("node[%s]'s first peer in anchor is null", cur_node->GetName().c_str()); - continue; - } - auto cur_node_input0 = peerOutAnchor->GetOwnerNode(); - vector in_scope_dims; - vector out_scope_dims; - if (cur_node_input0->GetName() == pre_node->GetOpDesc()->GetName()) { - in_scope_dims = cur_node->GetOpDesc()->MutableInputDesc(0)->MutableShape().GetDims(); - out_scope_dims = cur_node->GetOpDesc()->MutableInputDesc(1)->MutableShape().GetDims(); - } else { - in_scope_dims = cur_node->GetOpDesc()->MutableInputDesc(1)->MutableShape().GetDims(); - out_scope_dims = cur_node->GetOpDesc()->MutableInputDesc(0)->MutableShape().GetDims(); - } - if (in_scope_dims.size() != out_scope_dims.size()) { - GELOGD("Elem_wise[node: %s] : the number of input's dims is not equal. in_scope : %zu, out_scope : %zu", - cur_node->GetName().c_str(), in_scope_dims.size(), out_scope_dims.size()); - return false; - } else { - if (DealWithSameInAndOutScopeDimSize(in_scope_dims, out_scope_dims, elemwise_nodes, cur_node, i, fusion_nodes)) { - return true; - } - } - } - return false; -} - -static void DelNotMatchNodes(vector& elemwise_nodes, vector &fusion_nodes) { - if (!elemwise_nodes.empty()) { - ge::NodePtr node = nullptr; - if (!IsInWhiteListOfOpPatternElemwise(elemwise_nodes, node)) { - DelNotMatchNodesFromFusionNodes(node, fusion_nodes); - } - } -} -/* - * @brief: parse nodes matched in mapping and call DoFusion - * @param [in] graph: original graph - * @param [out] mapping: nodes matched by pattern - * @return bool: fusion status ok or not. - */ -Status TbeCommonRules0FusionPass::GetFusionNodes(const BufferFusionMapping &mapping, - vector &fusion_nodes) { - GELOGD("Begin to do TbeCommonRules0FusionPass!"); - fusion_nodes = GetMatchedNodes(mapping); - - vector elemwise_nodes = GetMatchedNodesByDescName(PATTERN_ELEMWISE, mapping); - // elewise only support single in or double in - if (!elemwise_nodes.empty()) { - CheckElewiseInputSize(elemwise_nodes, fusion_nodes); - } - - vector conv_nodes = GetMatchedNodesByDescName(PATTERN_CONV, mapping); - vector depthwise_nodes = GetMatchedNodesByDescName(PATTERN_DEPTHWISECONV, mapping); - bool conv_depth_size = conv_nodes.size() == 1 || depthwise_nodes.size() == 1; - if (!conv_depth_size) { - GELOGD("There is no conv and depthwise in TbeCommonRules0FusionPass"); - fusion_nodes.clear(); - return ge::GRAPH_SUCCESS; - } - vector conv_depthwise_nodes = conv_nodes.size() == 1 ? conv_nodes : depthwise_nodes; - vector dequant_nodes = GetMatchedNodesByDescName(PATTERN_DEQUANT, mapping); - - // if elewise has 2 input and inscope's shape less than outscope's shape, skip fusion - if (!dequant_nodes.empty()) { - if (JudgeElemShapeInScopeLessThanOutScope(dequant_nodes, elemwise_nodes, fusion_nodes)) { - GELOGD("dequant_nodes exist, Elemwise node has 2 inputs and in scope shape is less than outscope, try to fuse" - " before elemwise nodes"); - return ge::GRAPH_SUCCESS; - } - } else { - if (JudgeElemShapeInScopeLessThanOutScope(conv_depthwise_nodes, elemwise_nodes, fusion_nodes)) { - GELOGD("no dequant_nodes, Elemwise node has 2 inputs and in scope shape is less than outscope, try to fuse" - " before elemwise nodes"); - return ge::GRAPH_SUCCESS; - } - } - // elewise is in the blacklist, skip fusion - if (!elemwise_nodes.empty()) { - ge::NodePtr node = nullptr; - if (IsInBlackListOfOpPatternElemwise(elemwise_nodes, node)) { - GELOGD("node is in elemwise black_list, skip ub fusion!"); - fusion_nodes.clear(); - return ge::GRAPH_SUCCESS; - } - } - - // in conv2_d+elewise(1~3) pattern, elewise has no restrictions, - // if nums of elewise more then 3 and either one is not in the whitelist, skip fusion - bool ret = (fusion_nodes.size() == (elemwise_nodes.size() + conv_depthwise_nodes.size())) && - (conv_depthwise_nodes.size() == 1) && !elemwise_nodes.empty(); - if (ret) { - if (elemwise_nodes.size() <= 3) { - return ge::GRAPH_SUCCESS; - } else { - ge::NodePtr node = nullptr; - if (!IsInWhiteListOfOpPatternElemwise(elemwise_nodes, node)) { - fusion_nodes.clear(); - } - return ge::GRAPH_SUCCESS; - } - } - - DelNotMatchNodes(elemwise_nodes, fusion_nodes); - - if (fusion_nodes.size() == 1) { - fusion_nodes.clear(); - } - GELOGD("End to do TbeCommonRules0FusionPass!"); - return ge::GRAPH_SUCCESS; -} - -BufferFusionPassType type = BUFFER_FUSION_PASS_TYPE_RESERVED; -REGISTER_BUFFER_FUSION_PASS("MetadefBufferFusionPassTest", type, TbeCommonRules0FusionPass); - -REGISTER_BUFFER_FUSION_PASS("", BUILT_IN_AI_CORE_BUFFER_FUSION_PASS, - TbeCommonRules0FusionPass); - -REGISTER_BUFFER_FUSION_PASS("MetadefBufferFusionPassTest", BUILT_IN_AI_CORE_BUFFER_FUSION_PASS, - TbeCommonRules0FusionPass); -REGISTER_BUFFER_FUSION_PASS("MetadefBufferFusionPassTest", BUILT_IN_AI_CORE_BUFFER_FUSION_PASS, - TbeCommonRules0FusionPass); - -Status Run(ge::ComputeGraph &graph, std::shared_ptr &pass) { - // 1. get pattern info - auto patterns = pass->DefinePatterns(); - - fe::ConnectionMatrix a; - a.Generate(graph); - - // 2. for all patterns - for (BufferFusionPattern *pattern : patterns) { - if (pattern == nullptr) { - continue; - } - string pattern_name = pattern->GetName(); - pattern->GetErrorCnt(); - pattern->GetHead(); - pattern->GetOpDesc("test"); - auto conv = pattern->GetOpDesc(PATTERN_CONVOLUTION); - pattern->GetOpMaxCount(); - BufferFusionOpDesc *op_desc = nullptr; - std::vector outputs; - pattern->GetOutputs(op_desc, outputs); - - pattern->GetOutputs(conv, outputs); - pattern->GetName(); - pattern->UpdateSkipStatus(conv); - BufferFusionMapping mapping; - auto dequant = pattern->GetOpDesc(PATTERN_DEQUANT); - auto elmw = pattern->GetOpDesc(PATTERN_ELEMWISE); - auto quant = pattern->GetOpDesc(PATTERN_QUANT); - - std::vector buffer_fusion_op_vec = - {conv, dequant, elmw, quant}; - int i = 0; - for (auto node : graph.GetDirectNode()) { - if (i < 4) { - std::vector node_vec = {node}; - mapping.emplace(std::make_pair(buffer_fusion_op_vec[i], node_vec)); - } - } - pass->GetName(); - - std::vector fusion_nodes; - EXPECT_EQ(fe::SUCCESS, pass->GetFusionNodes(mapping, fusion_nodes)); - EXPECT_EQ(fe::NOT_CHANGED, pass->GetMixl2FusionNodes(mapping, fusion_nodes)); - -// OpCalcInfo op_slice_info; -// pass->CalcFusionOpSliceInfo(fusion_nodes, op_slice_info); - - pass->GetMatchedHeadNode(fusion_nodes); - - pass->SetName("test"); - - pass->GetMatchedNodes(mapping); - - pass->GetMatchedNodesByDescName(PATTERN_ELEMWISE, mapping); - } - - for (const auto &pattern : patterns) { - delete pattern; - } - - - return ge::GRAPH_SUCCESS; -} - -Status RunPass(ge::ComputeGraph &graph) { - std::shared_ptr common0 = std::make_shared(); - Run(graph, common0); - return ge::GRAPH_SUCCESS; -} - -class UB_FUSION_UT_CONV_ELT_RELU : public testing::Test { - - protected: - static void SetUpTestCase() { std::cout << "UB fusion SetUp" << std::endl; } - static void TearDownTestCase() { - std::cout << "UB fusion TearDown" << std::endl; - } - - virtual void SetUp() { - } - - virtual void TearDown() {} - void SetPattern(ge::OpDescPtr opdef, string optype) { - auto key_pattern = opdef->GetName() + "_pattern"; - ge::AttrUtils::SetStr(opdef, key_pattern, optype); - } - void SetTvmType(ge::OpDescPtr opdef) { - ge::AttrUtils::SetInt(opdef, ge::ATTR_NAME_IMPLY_TYPE,static_cast(domi::ImplyType::TVM)); - } - void BuildGraph(ComputeGraphPtr graph, int32_t reluflag) { - OpDescPtr data = std::make_shared("DATA0", "Data"); - OpDescPtr data1 = std::make_shared("DATA1", "Data"); - OpDescPtr data2 = std::make_shared("DATA2", "Data"); - OpDescPtr conv = std::make_shared("conv", "Convolution"); - OpDescPtr elemwise = std::make_shared("elem", "Eltwise"); - OpDescPtr relu = std::make_shared("relu", "Relu"); - OpDescPtr relu1 = std::make_shared("relu1", "ReLU"); - - SetPattern(conv, "Convolution"); - SetPattern(relu, "ElemWise"); - SetPattern(elemwise, "ElemWise"); - SetTvmType(conv); - SetTvmType(elemwise); - SetTvmType(relu); - // add descriptor - vector dim(4, 4); - GeShape shape(dim); - GeTensorDesc out_desc(shape); - - data->AddOutputDesc(out_desc); - data1->AddOutputDesc(out_desc); - data2->AddOutputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddOutputDesc(out_desc); - elemwise->AddInputDesc(out_desc); - elemwise->AddInputDesc(out_desc); - elemwise->AddOutputDesc(out_desc); - relu->AddInputDesc(out_desc); - relu->AddOutputDesc(out_desc); - relu1->AddInputDesc(out_desc); - relu1->AddOutputDesc(out_desc); - AttrUtils::SetInt(conv, FE_IMPLY_TYPE, 6); - ge::AttrUtils::SetStr(conv, ge::ATTR_NAME_SESSION_GRAPH_ID, "_0_1_2_3"); - std::vector params = {0, 0, 0, 0, 0, 1, 0, 1}; - AttrUtils::SetListInt(conv, "ub_atomic_params", params); - AttrUtils::SetBool(conv, "Aipp_Conv_Flag", true); - conv->SetWorkspaceBytes({0}); - AttrUtils::SetInt(elemwise, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(relu, FE_IMPLY_TYPE, 6); - - NodePtr data_node = graph->AddNode(data); - NodePtr data1_node = graph->AddNode(data1); - NodePtr data2_node = graph->AddNode(data2); - NodePtr conv_node = graph->AddNode(conv); - NodePtr elemwise_node = graph->AddNode(elemwise); - NodePtr relu_node = graph->AddNode(relu); - NodePtr relu1_node = graph->AddNode(relu1); - const char tbe_bin[] = "tbe_bin"; - vector buffer(tbe_bin, tbe_bin + strlen(tbe_bin)); - ge::OpKernelBinPtr tbe_kernel_ptr = std::make_shared( - conv_node->GetName(), std::move(buffer)); - conv_node->GetOpDesc()->SetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel_ptr); - - GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(data1_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), - elemwise_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(data2_node->GetOutDataAnchor(0), - elemwise_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(elemwise_node->GetOutDataAnchor(0), - relu_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(relu_node->GetOutDataAnchor(0), - relu1_node->GetInDataAnchor(0)); - } - - void BuildGraphForL2Fusion(ComputeGraphPtr graph, int32_t reluflag) { - OpDescPtr data = std::make_shared("DATA0", "Data"); - OpDescPtr data1 = std::make_shared("DATA1", "Data"); - OpDescPtr data2 = std::make_shared("DATA2", "Data"); - OpDescPtr conv = std::make_shared("conv", "Convolution"); - OpDescPtr elemwise = std::make_shared("elem", "Eltwise"); - OpDescPtr relu = std::make_shared("relu", "ReLU"); - OpDescPtr relu1 = std::make_shared("relu1", "ReLU"); - SetPattern(conv, "Convolution"); - SetPattern(relu, "ElemWise"); - SetPattern(elemwise, "ElemWise"); - SetTvmType(conv); - SetTvmType(elemwise); - SetTvmType(relu); - // add descriptor - vector dim(4, 4); - GeShape shape(dim); - GeTensorDesc out_desc(shape); - - data->AddOutputDesc(out_desc); - data1->AddOutputDesc(out_desc); - data2->AddOutputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddOutputDesc(out_desc); - elemwise->AddInputDesc(out_desc); - elemwise->AddInputDesc(out_desc); - elemwise->AddOutputDesc(out_desc); - relu->AddInputDesc(out_desc); - relu->AddOutputDesc(out_desc); - relu1->AddInputDesc(out_desc); - relu1->AddOutputDesc(out_desc); - AttrUtils::SetInt(conv, FE_IMPLY_TYPE, 6); - std::vector params = {0, 0, 0, 0, 0, 1, 0, 1}; - AttrUtils::SetListInt(conv, "ub_atomic_params", params); - AttrUtils::SetBool(conv, "Aipp_Conv_Flag", true); - conv->SetWorkspaceBytes({0}); - AttrUtils::SetInt(elemwise, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(relu, FE_IMPLY_TYPE, 6); - - NodePtr data_node = graph->AddNode(data); - NodePtr data1_node = graph->AddNode(data1); - NodePtr data2_node = graph->AddNode(data2); - NodePtr conv_node = graph->AddNode(conv); - NodePtr elemwise_node = graph->AddNode(elemwise); - NodePtr relu_node = graph->AddNode(relu); - NodePtr relu1_node = graph->AddNode(relu1); - const char tbe_bin[] = "tbe_bin"; - vector buffer(tbe_bin, tbe_bin + strlen(tbe_bin)); - ge::OpKernelBinPtr tbe_kernel_ptr = std::make_shared( - conv_node->GetName(), std::move(buffer)); - conv_node->GetOpDesc()->SetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel_ptr); - - GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(data1_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), - elemwise_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(data2_node->GetOutDataAnchor(0), - elemwise_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(elemwise_node->GetOutDataAnchor(0), - relu_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(relu_node->GetOutDataAnchor(0), - relu1_node->GetInDataAnchor(0)); - - //elementwise l2 Info - L2FusionInfoPtr elementwise_l2_info_ptr = std::make_shared(); - //data - uint64_t L2_mirror_addr=0; // preload or swap source address - uint32_t L2_data_section_size=123; // every data size - uint8_t L2_preload=0; // 1 - preload from mirror_addr, 0 - no preload - uint8_t modified=1; // 1 - data will be modified by kernel, 0 - no modified - uint8_t priority=1; // data priority - int8_t prev_L2_page_offset_base=-1; // remap source section offset - uint8_t L2_page_offset_base=0; // remap destination section offset - uint8_t L2_load_to_ddr=0; // 1 - need load out, 0 - no need - rtSmData_t tmp_data={L2_mirror_addr,L2_data_section_size,L2_preload,modified,priority,prev_L2_page_offset_base,L2_page_offset_base,L2_page_offset_base,L2_load_to_ddr}; - tmp_data.reserved[2]={0}; - - elementwise_l2_info_ptr->l2_info.l2ctrl.data[0]=tmp_data; - elementwise_l2_info_ptr->l2_info.l2ctrl.size=60; - elementwise_l2_info_ptr->node_name="elem"; - L2FusionData_t elem_output={0,123,2}; - elementwise_l2_info_ptr->output[0]=elem_output; - - (void)ge::AttrUtils::SetBool(conv_node->GetOpDesc(), "need_re_precompile", true); - elemwise_node->GetOpDesc()->SetExtAttr( - "task_l2_fusion_info_extend_content", elementwise_l2_info_ptr); - - //relu l2 Info - L2FusionInfoPtr relu_l2_info_ptr = std::make_shared(); - //data - tmp_data.reserved[2]={0}; - - relu_l2_info_ptr->l2_info.l2ctrl.data[0]=tmp_data; - relu_l2_info_ptr->l2_info.l2ctrl.size=60; - relu_l2_info_ptr->node_name="relu"; - L2FusionData_t relu_output={1,234,2}; - relu_l2_info_ptr->output[0]=relu_output; - - (void)ge::AttrUtils::SetBool(relu_node->GetOpDesc(), "need_re_precompile", true); - relu_node->GetOpDesc()->SetExtAttr( - "task_l2_fusion_info_extend_content", relu_l2_info_ptr); - - - //relu2 Info - L2FusionInfoPtr relu2_l2_info_ptr = std::make_shared(); - //data - tmp_data.reserved[2]={0}; - - relu2_l2_info_ptr->l2_info.l2ctrl.data[0]=tmp_data; - relu2_l2_info_ptr->l2_info.l2ctrl.size=60; - relu2_l2_info_ptr->node_name="relu1"; - L2FusionData_t conv_output={2,567,2}; - relu2_l2_info_ptr->output[0]=conv_output; - - (void)ge::AttrUtils::SetBool(relu1_node->GetOpDesc(), "need_re_precompile", true); - relu1_node->GetOpDesc()->SetExtAttr( - "task_l2_fusion_info_extend_content", relu2_l2_info_ptr); - - } - void BuildGraphForL2Fusion1(ComputeGraphPtr graph, int32_t reluflag) { - OpDescPtr data = std::make_shared("DATA0", "Data"); - OpDescPtr data1 = std::make_shared("DATA1", "Data"); - OpDescPtr data2 = std::make_shared("DATA2", "Data"); - OpDescPtr conv = std::make_shared("conv", "Convolution"); - OpDescPtr elemwise = std::make_shared("elem", "Eltwise"); - OpDescPtr elemwise1 = std::make_shared("elem1", "Eltwise"); - OpDescPtr relu = std::make_shared("relu", "ReLU"); - - SetPattern(conv, "Convolution"); - SetPattern(elemwise1, "ElemWise"); - SetPattern(elemwise, "ElemWise"); - SetPattern(relu, "ElemWise"); - SetTvmType(conv); - SetTvmType(elemwise); - SetTvmType(elemwise1); - SetTvmType(relu); - // add descriptor - vector dim(4, 4); - GeShape shape(dim); - GeTensorDesc out_desc(shape); - - data->AddOutputDesc(out_desc); - data1->AddOutputDesc(out_desc); - data2->AddOutputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddOutputDesc(out_desc); - elemwise->AddInputDesc(out_desc); - elemwise->AddInputDesc(out_desc); - elemwise->AddOutputDesc(out_desc); - elemwise1->AddInputDesc(out_desc); - elemwise1->AddOutputDesc(out_desc); - relu->AddInputDesc(out_desc); - relu->AddOutputDesc(out_desc); - - AttrUtils::SetInt(conv, FE_IMPLY_TYPE, 6); - std::vector params = {0, 0, 0, 0, 0, 1, 0, 1}; - AttrUtils::SetListInt(conv, "ub_atomic_params", params); - AttrUtils::SetBool(conv, "Aipp_Conv_Flag", true); - conv->SetWorkspaceBytes({0}); - AttrUtils::SetInt(elemwise, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(elemwise1, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(relu, FE_IMPLY_TYPE, 6); - - NodePtr data_node = graph->AddNode(data); - NodePtr data1_node = graph->AddNode(data1); - NodePtr data2_node = graph->AddNode(data2); - NodePtr conv_node = graph->AddNode(conv); - NodePtr elemwise_node = graph->AddNode(elemwise); - NodePtr elemwise1_node = graph->AddNode(elemwise1); - NodePtr relu_node = graph->AddNode(relu); - - const char tbe_bin[] = "tbe_bin"; - vector buffer(tbe_bin, tbe_bin + strlen(tbe_bin)); - ge::OpKernelBinPtr tbe_kernel_ptr = std::make_shared( - conv_node->GetName(), std::move(buffer)); - conv_node->GetOpDesc()->SetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel_ptr); - - GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(data1_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), - elemwise_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(data2_node->GetOutDataAnchor(0), - elemwise_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(elemwise_node->GetOutDataAnchor(0), - elemwise1_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(elemwise1_node->GetOutDataAnchor(0), - relu_node->GetInDataAnchor(0)); - - //elementwise l2 Info - L2FusionInfoPtr elementwise_l2_info_ptr = std::make_shared(); - //data - uint64_t L2_mirror_addr=0; // preload or swap source address - uint32_t L2_data_section_size=123; // every data size - uint8_t L2_preload=0; // 1 - preload from mirror_addr, 0 - no preload - uint8_t modified=1; // 1 - data will be modified by kernel, 0 - no modified - uint8_t priority=1; // data priority - int8_t prev_L2_page_offset_base=-1; // remap source section offset - uint8_t L2_page_offset_base=0; // remap destination section offset - uint8_t L2_load_to_ddr=0; // 1 - need load out, 0 - no need - rtSmData_t tmp_data={L2_mirror_addr,L2_data_section_size,L2_preload,modified,priority,prev_L2_page_offset_base,L2_page_offset_base,L2_page_offset_base,L2_load_to_ddr}; - tmp_data.reserved[2]={0}; - - elementwise_l2_info_ptr->l2_info.l2ctrl.data[0]=tmp_data; - elementwise_l2_info_ptr->l2_info.l2ctrl.size=60; - elementwise_l2_info_ptr->node_name="elem"; - L2FusionData_t elem_output={0,123,2}; - elementwise_l2_info_ptr->output[0]=elem_output; - - (void)ge::AttrUtils::SetBool(elemwise1_node->GetOpDesc(), "need_re_precompile", true); - elemwise1_node->GetOpDesc()->SetExtAttr( - "task_l2_fusion_info_extend_content", elementwise_l2_info_ptr); - - } - void BuildGraph2(ComputeGraphPtr graph, int32_t reluflag) { - OpDescPtr data = std::make_shared("DATA0", "Data"); - OpDescPtr data1 = std::make_shared("DATA1", "Data"); - OpDescPtr data2 = std::make_shared("DATA2", "Data"); - OpDescPtr conv = std::make_shared("conv", "Convolution"); - OpDescPtr elemwise = std::make_shared("elem", "Eltwise"); - OpDescPtr relu = std::make_shared("relu", "Relu"); - OpDescPtr relu1 = std::make_shared("relu1", "Relu"); - OpDescPtr conv1 = std::make_shared("conv1", "Convolution"); - OpDescPtr netout_op = std::make_shared("netoutput", "NetOutput"); - - SetPattern(conv, "Convolution"); - SetPattern(conv1, "Convolution"); - SetPattern(relu, "ElemWise"); - SetPattern(relu1, "ElemWise"); - SetPattern(elemwise, "ElemWise"); - SetTvmType(conv); - SetTvmType(conv1); - SetTvmType(elemwise); - SetTvmType(relu); - SetTvmType(relu1); - // add descriptor - vector dim(4, 4); - GeShape shape(dim); - GeTensorDesc out_desc(shape); - - data->AddOutputDesc(out_desc); - data1->AddOutputDesc(out_desc); - data2->AddOutputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddOutputDesc(out_desc); - conv1->AddInputDesc(out_desc); - conv1->AddInputDesc(out_desc); - conv1->AddOutputDesc(out_desc); - elemwise->AddInputDesc(out_desc); - elemwise->AddInputDesc(out_desc); - elemwise->AddOutputDesc(out_desc); - relu->AddInputDesc(out_desc); - relu->AddOutputDesc(out_desc); - relu1->AddInputDesc(out_desc); - relu1->AddOutputDesc(out_desc); - netout_op->AddInputDesc(out_desc); - AttrUtils::SetInt(conv, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(conv1, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(elemwise, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(relu, FE_IMPLY_TYPE, 6); - - NodePtr data_node = graph->AddNode(data); - NodePtr data1_node = graph->AddNode(data1); - NodePtr data2_node = graph->AddNode(data2); - NodePtr conv_node = graph->AddNode(conv); - NodePtr conv1_node = graph->AddNode(conv1); - NodePtr elemwise_node = graph->AddNode(elemwise); - NodePtr relu_node = graph->AddNode(relu); - NodePtr relu1_node = graph->AddNode(relu1); - NodePtr netout_node = graph->AddNode(netout_op); - const char tbe_bin[] = "tbe_bin"; - vector buffer(tbe_bin, tbe_bin + strlen(tbe_bin)); - ge::OpKernelBinPtr tbe_kernel_ptr = std::make_shared( - conv_node->GetName(), std::move(buffer)); - conv_node->GetOpDesc()->SetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel_ptr); - - GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(data1_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(conv1_node->GetOutDataAnchor(0), - elemwise_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), - elemwise_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(elemwise_node->GetOutDataAnchor(0), - relu_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(relu_node->GetOutDataAnchor(0), - relu1_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(relu_node->GetOutDataAnchor(0), - netout_node->GetInDataAnchor(0)); - } - - void BuildGraph3(ComputeGraphPtr graph, int32_t reluflag) { - OpDescPtr data = std::make_shared("DATA0", "Data"); - OpDescPtr data1 = std::make_shared("DATA1", "Data"); - OpDescPtr data2 = std::make_shared("DATA2", "Data"); - OpDescPtr conv = std::make_shared("conv", "Convolution"); - OpDescPtr elemwise = std::make_shared("elem", "Eltwise"); - OpDescPtr relu = std::make_shared("relu", "Relu"); - OpDescPtr relu1 = std::make_shared("relu1", "ReLU"); - OpDescPtr conv1 = std::make_shared("conv1", "Convolution"); - - SetPattern(conv, "Convolution"); - SetPattern(conv1, "Convolution"); - SetPattern(relu, "ElemWise"); - SetPattern(elemwise, "ElemWise"); - SetTvmType(conv); - SetTvmType(conv1); - SetTvmType(elemwise); - SetTvmType(relu); - // add descriptor - vector dim(4, 4); - GeShape shape(dim); - GeTensorDesc out_desc(shape); - - data->AddOutputDesc(out_desc); - data1->AddOutputDesc(out_desc); - data2->AddOutputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddOutputDesc(out_desc); - conv1->AddInputDesc(out_desc); - conv1->AddInputDesc(out_desc); - conv1->AddOutputDesc(out_desc); - elemwise->AddInputDesc(out_desc); - elemwise->AddInputDesc(out_desc); - elemwise->AddOutputDesc(out_desc); - relu->AddInputDesc(out_desc); - relu->AddOutputDesc(out_desc); - relu1->AddInputDesc(out_desc); - relu1->AddOutputDesc(out_desc); - AttrUtils::SetInt(conv, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(conv1, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(elemwise, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(relu, FE_IMPLY_TYPE, 6); - ge::AttrUtils::SetStr(conv, STREAM_LABEL, "stream1"); - ge::AttrUtils::SetStr(conv1, STREAM_LABEL, "stream1"); - ge::AttrUtils::SetStr(elemwise, STREAM_LABEL, "stream1"); - ge::AttrUtils::SetStr(relu, STREAM_LABEL, "stream1"); - ge::AttrUtils::SetStr(relu1, STREAM_LABEL, "stream1"); - - NodePtr data_node = graph->AddNode(data); - NodePtr data1_node = graph->AddNode(data1); - NodePtr data2_node = graph->AddNode(data2); - NodePtr conv_node = graph->AddNode(conv); - NodePtr conv1_node = graph->AddNode(conv1); - NodePtr elemwise_node = graph->AddNode(elemwise); - NodePtr relu_node = graph->AddNode(relu); - NodePtr relu1_node = graph->AddNode(relu1); - const char tbe_bin[] = "tbe_bin"; - vector buffer(tbe_bin, tbe_bin + strlen(tbe_bin)); - ge::OpKernelBinPtr tbe_kernel_ptr = std::make_shared( - conv_node->GetName(), std::move(buffer)); - conv_node->GetOpDesc()->SetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel_ptr); - - GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(data1_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(conv1_node->GetOutDataAnchor(0), - elemwise_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), - elemwise_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(elemwise_node->GetOutDataAnchor(0), - relu_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(relu_node->GetOutDataAnchor(0), - relu1_node->GetInDataAnchor(0)); - } - - void BuildGraph4(ComputeGraphPtr graph, int32_t reluflag) { - OpDescPtr data = std::make_shared("DATA0", "Data"); - OpDescPtr data1 = std::make_shared("DATA1", "Data"); - OpDescPtr data2 = std::make_shared("DATA2", "Data"); - OpDescPtr conv = std::make_shared("conv", "Convolution"); - OpDescPtr elemwise = std::make_shared("elem", "Eltwise"); - OpDescPtr relu = std::make_shared("relu", "ReLU"); - OpDescPtr relu1 = std::make_shared("relu1", "ReLU"); - OpDescPtr conv1 = std::make_shared("conv1", "Convolution"); - - SetPattern(conv, "Convolution"); - SetPattern(conv1, "Convolution"); - SetPattern(relu, "ElemWise"); - SetPattern(elemwise, "ElemWise"); - SetTvmType(conv); - SetTvmType(conv1); - SetTvmType(elemwise); - SetTvmType(relu); - // add descriptor - vector dim(4, 4); - GeShape shape(dim); - GeTensorDesc out_desc(shape); - - data->AddOutputDesc(out_desc); - data1->AddOutputDesc(out_desc); - data2->AddOutputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddOutputDesc(out_desc); - conv1->AddInputDesc(out_desc); - conv1->AddInputDesc(out_desc); - conv1->AddOutputDesc(out_desc); - elemwise->AddInputDesc(out_desc); - elemwise->AddInputDesc(out_desc); - elemwise->AddOutputDesc(out_desc); - relu->AddInputDesc(out_desc); - relu->AddOutputDesc(out_desc); - relu1->AddInputDesc(out_desc); - relu1->AddOutputDesc(out_desc); - AttrUtils::SetInt(conv, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(conv1, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(elemwise, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(relu, FE_IMPLY_TYPE, 6); - ge::AttrUtils::SetStr(conv, STREAM_LABEL, "stream1"); - ge::AttrUtils::SetStr(conv1, STREAM_LABEL, "stream1"); - ge::AttrUtils::SetStr(elemwise, STREAM_LABEL, "stream2"); - ge::AttrUtils::SetStr(relu, STREAM_LABEL, "stream1"); - - NodePtr data_node = graph->AddNode(data); - NodePtr data1_node = graph->AddNode(data1); - NodePtr data2_node = graph->AddNode(data2); - NodePtr conv_node = graph->AddNode(conv); - NodePtr conv1_node = graph->AddNode(conv1); - NodePtr elemwise_node = graph->AddNode(elemwise); - NodePtr relu_node = graph->AddNode(relu); - NodePtr relu1_node = graph->AddNode(relu1); - const char tbe_bin[] = "tbe_bin"; - vector buffer(tbe_bin, tbe_bin + strlen(tbe_bin)); - ge::OpKernelBinPtr tbe_kernel_ptr = std::make_shared( - conv_node->GetName(), std::move(buffer)); - conv_node->GetOpDesc()->SetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel_ptr); - - GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(data1_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(conv1_node->GetOutDataAnchor(0), - elemwise_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), - elemwise_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(elemwise_node->GetOutDataAnchor(0), - relu_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(relu_node->GetOutDataAnchor(0), - relu1_node->GetInDataAnchor(0)); - } - - void BuildGraph5(ComputeGraphPtr graph, int32_t reluflag) { - OpDescPtr data = std::make_shared("DATA0", "Data"); - OpDescPtr data1 = std::make_shared("DATA1", "Data"); - OpDescPtr data2 = std::make_shared("DATA2", "Data"); - OpDescPtr conv = std::make_shared("conv", "Convolution"); - OpDescPtr elemwise = std::make_shared("elem", "Eltwise"); - OpDescPtr relu = std::make_shared("relu", "ReLU"); - OpDescPtr relu1 = std::make_shared("relu1", "ReLU"); - OpDescPtr conv1 = std::make_shared("conv1", "Convolution"); - - SetPattern(conv, "Convolution"); - SetPattern(conv1, "Convolution"); - SetPattern(relu, "ElemWise"); - SetPattern(elemwise, "ElemWise"); - SetTvmType(conv); - SetTvmType(conv1); - SetTvmType(elemwise); - SetTvmType(relu); - // add descriptor - vector dim(4, 4); - GeShape shape(dim); - GeTensorDesc out_desc(shape); - - data->AddOutputDesc(out_desc); - data1->AddOutputDesc(out_desc); - data2->AddOutputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddOutputDesc(out_desc); - conv1->AddInputDesc(out_desc); - conv1->AddInputDesc(out_desc); - conv1->AddOutputDesc(out_desc); - elemwise->AddInputDesc(out_desc); - elemwise->AddInputDesc(out_desc); - elemwise->AddOutputDesc(out_desc); - relu->AddInputDesc(out_desc); - relu->AddOutputDesc(out_desc); - relu1->AddInputDesc(out_desc); - relu1->AddOutputDesc(out_desc); - AttrUtils::SetInt(conv, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(conv1, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(elemwise, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(relu, FE_IMPLY_TYPE, 6); - ge::AttrUtils::SetStr(relu, STREAM_LABEL, "stream1"); - - NodePtr data_node = graph->AddNode(data); - NodePtr data1_node = graph->AddNode(data1); - NodePtr data2_node = graph->AddNode(data2); - NodePtr conv_node = graph->AddNode(conv); - NodePtr conv1_node = graph->AddNode(conv1); - NodePtr elemwise_node = graph->AddNode(elemwise); - NodePtr relu_node = graph->AddNode(relu); - NodePtr relu1_node = graph->AddNode(relu1); - const char tbe_bin[] = "tbe_bin"; - vector buffer(tbe_bin, tbe_bin + strlen(tbe_bin)); - ge::OpKernelBinPtr tbe_kernel_ptr = std::make_shared( - conv_node->GetName(), std::move(buffer)); - conv_node->GetOpDesc()->SetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel_ptr); - - GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(data1_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(conv1_node->GetOutDataAnchor(0), - elemwise_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), - elemwise_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(elemwise_node->GetOutDataAnchor(0), - relu_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(relu_node->GetOutDataAnchor(0), - relu1_node->GetInDataAnchor(0)); - } - - void BuildGraph6(ComputeGraphPtr graph, int32_t reluflag) { - OpDescPtr data = std::make_shared("DATA0", "Data"); - OpDescPtr data1 = std::make_shared("DATA1", "Data"); - OpDescPtr data2 = std::make_shared("DATA2", "Data"); - OpDescPtr conv = std::make_shared("conv", "Convolution"); - OpDescPtr elemwise = std::make_shared("elem", "Eltwise"); - OpDescPtr relu = std::make_shared("relu", "ReLU"); - OpDescPtr relu1 = std::make_shared("relu1", "ReLU"); - OpDescPtr conv1 = std::make_shared("conv1", "Convolution"); - - SetPattern(conv, "Convolution"); - SetPattern(conv1, "Convolution"); - SetPattern(relu, "ElemWise"); - SetPattern(elemwise, "ElemWise"); - SetTvmType(conv); - SetTvmType(conv1); - SetTvmType(elemwise); - SetTvmType(relu); - // add descriptor - vector dim(4, 4); - GeShape shape(dim); - GeTensorDesc out_desc(shape); - - data->AddOutputDesc(out_desc); - data1->AddOutputDesc(out_desc); - data2->AddOutputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddOutputDesc(out_desc); - conv1->AddInputDesc(out_desc); - conv1->AddInputDesc(out_desc); - conv1->AddOutputDesc(out_desc); - elemwise->AddInputDesc(out_desc); - elemwise->AddInputDesc(out_desc); - elemwise->AddOutputDesc(out_desc); - relu->AddInputDesc(out_desc); - relu->AddOutputDesc(out_desc); - relu1->AddInputDesc(out_desc); - relu1->AddOutputDesc(out_desc); - AttrUtils::SetInt(conv, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(conv1, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(elemwise, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(relu, FE_IMPLY_TYPE, 6); - ge::AttrUtils::SetStr(conv, STREAM_LABEL, "stream1"); - - NodePtr data_node = graph->AddNode(data); - NodePtr data1_node = graph->AddNode(data1); - NodePtr data2_node = graph->AddNode(data2); - NodePtr conv_node = graph->AddNode(conv); - NodePtr conv1_node = graph->AddNode(conv1); - NodePtr elemwise_node = graph->AddNode(elemwise); - NodePtr relu_node = graph->AddNode(relu); - NodePtr relu1_node = graph->AddNode(relu1); - const char tbe_bin[] = "tbe_bin"; - vector buffer(tbe_bin, tbe_bin + strlen(tbe_bin)); - ge::OpKernelBinPtr tbe_kernel_ptr = std::make_shared( - conv_node->GetName(), std::move(buffer)); - conv_node->GetOpDesc()->SetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel_ptr); - - GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(data1_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(conv1_node->GetOutDataAnchor(0), - elemwise_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), - elemwise_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(elemwise_node->GetOutDataAnchor(0), - relu_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(relu_node->GetOutDataAnchor(0), - relu1_node->GetInDataAnchor(0)); - } - - void BuildGraphConvReluQuant(ComputeGraphPtr graph, int32_t reluflag) { - OpDescPtr data = std::make_shared("DATA0", "Data"); - OpDescPtr data1 = std::make_shared("DATA1", "Data"); - OpDescPtr data2 = std::make_shared("DATA2", "Data"); - OpDescPtr conv = std::make_shared("conv", "Convolution"); - OpDescPtr relu = std::make_shared("relu", "Relu"); - OpDescPtr quant = std::make_shared("quant", "Quant"); - - SetPattern(conv, "Convolution"); - SetPattern(relu, "ElemWise"); - SetPattern(quant, "quant"); - SetTvmType(conv); - SetTvmType(relu); - SetTvmType(quant); - // add descriptor - vector dim(4, 4); - GeShape shape(dim); - GeTensorDesc out_desc(shape); - - data->AddOutputDesc(out_desc); - data1->AddOutputDesc(out_desc); - data2->AddOutputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddOutputDesc(out_desc); - relu->AddInputDesc(out_desc); - relu->AddOutputDesc(out_desc); - quant->AddInputDesc(out_desc); - quant->AddOutputDesc(out_desc); - AttrUtils::SetInt(conv, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(relu, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(quant, FE_IMPLY_TYPE, 6); - - NodePtr data_node = graph->AddNode(data); - NodePtr data1_node = graph->AddNode(data1); - NodePtr data2_node = graph->AddNode(data2); - NodePtr conv_node = graph->AddNode(conv); - NodePtr relu_node = graph->AddNode(relu); - NodePtr quant_node = graph->AddNode(quant); - const char tbe_bin[] = "tbe_bin"; - vector buffer(tbe_bin, tbe_bin + strlen(tbe_bin)); - ge::OpKernelBinPtr tbe_kernel_ptr = std::make_shared( - conv_node->GetName(), std::move(buffer)); - conv_node->GetOpDesc()->SetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel_ptr); - - GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(data1_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(data2_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(2)); - GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), - relu_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(relu_node->GetOutDataAnchor(0), - quant_node->GetInDataAnchor(0)); - } - void BuildGraphConvLeakyReluQuant1(ComputeGraphPtr graph, int32_t reluflag) { - OpDescPtr data = std::make_shared("DATA0", "Data"); - OpDescPtr data1 = std::make_shared("DATA1", "Data"); - OpDescPtr data2 = std::make_shared("DATA2", "Data"); - OpDescPtr conv = std::make_shared("conv", "Convolution"); - OpDescPtr relu = std::make_shared("relu", "LeakyRelu"); - OpDescPtr quant = std::make_shared("quant", "Quant"); - - SetPattern(conv, "Convolution"); - SetPattern(relu, "ElemWise"); - SetPattern(quant, "quant"); - SetTvmType(conv); - SetTvmType(relu); - SetTvmType(quant); - // add descriptor - vector dim(4, 4); - GeShape shape(dim); - GeTensorDesc out_desc(shape); - - data->AddOutputDesc(out_desc); - data1->AddOutputDesc(out_desc); - data2->AddOutputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddOutputDesc(out_desc); - relu->AddInputDesc(out_desc); - relu->AddOutputDesc(out_desc); - quant->AddInputDesc(out_desc); - quant->AddOutputDesc(out_desc); - AttrUtils::SetInt(conv, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(relu, FE_IMPLY_TYPE, 6); - AttrUtils::SetFloat(relu, "negative_slope", 0); - AttrUtils::SetInt(quant, FE_IMPLY_TYPE, 6); - - NodePtr data_node = graph->AddNode(data); - NodePtr data1_node = graph->AddNode(data1); - NodePtr data2_node = graph->AddNode(data2); - NodePtr conv_node = graph->AddNode(conv); - NodePtr relu_node = graph->AddNode(relu); - NodePtr quant_node = graph->AddNode(quant); - const char tbe_bin[] = "tbe_bin"; - vector buffer(tbe_bin, tbe_bin + strlen(tbe_bin)); - ge::OpKernelBinPtr tbe_kernel_ptr = std::make_shared( - conv_node->GetName(), std::move(buffer)); - conv_node->GetOpDesc()->SetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel_ptr); - - GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(data1_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(data2_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(2)); - GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), - relu_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(relu_node->GetOutDataAnchor(0), - quant_node->GetInDataAnchor(0)); - } - void BuildGraphConvLeakyReluQuant2(ComputeGraphPtr graph, int32_t reluflag) { - OpDescPtr data = std::make_shared("DATA0", "Data"); - OpDescPtr data1 = std::make_shared("DATA1", "Data"); - OpDescPtr data2 = std::make_shared("DATA2", "Data"); - OpDescPtr conv = std::make_shared("conv", "Convolution"); - OpDescPtr relu = std::make_shared("relu", "LeakyRelu"); - OpDescPtr quant = std::make_shared("quant", "Quant"); - - SetPattern(conv, "Convolution"); - SetPattern(relu, "ElemWise"); - SetPattern(quant, "quant"); - SetTvmType(conv); - SetTvmType(relu); - SetTvmType(quant); - // add descriptor - vector dim(4, 4); - GeShape shape(dim); - GeTensorDesc out_desc(shape); - - data->AddOutputDesc(out_desc); - data1->AddOutputDesc(out_desc); - data2->AddOutputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddOutputDesc(out_desc); - relu->AddInputDesc(out_desc); - relu->AddOutputDesc(out_desc); - quant->AddInputDesc(out_desc); - quant->AddOutputDesc(out_desc); - AttrUtils::SetInt(conv, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(relu, FE_IMPLY_TYPE, 6); - AttrUtils::SetFloat(relu, "negative_slope", 0.1); - AttrUtils::SetInt(quant, FE_IMPLY_TYPE, 6); - - NodePtr data_node = graph->AddNode(data); - NodePtr data1_node = graph->AddNode(data1); - NodePtr data2_node = graph->AddNode(data2); - NodePtr conv_node = graph->AddNode(conv); - NodePtr relu_node = graph->AddNode(relu); - NodePtr quant_node = graph->AddNode(quant); - const char tbe_bin[] = "tbe_bin"; - vector buffer(tbe_bin, tbe_bin + strlen(tbe_bin)); - ge::OpKernelBinPtr tbe_kernel_ptr = std::make_shared( - conv_node->GetName(), std::move(buffer)); - conv_node->GetOpDesc()->SetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel_ptr); - - GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(data1_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(data2_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(2)); - GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), - relu_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(relu_node->GetOutDataAnchor(0), - quant_node->GetInDataAnchor(0)); - } - void BuildGraphConvEltReluQuant1(ComputeGraphPtr graph, int32_t reluflag) { - OpDescPtr data = std::make_shared("DATA0", "Data"); - OpDescPtr data1 = std::make_shared("DATA1", "Data"); - OpDescPtr data2 = std::make_shared("DATA2", "Data"); - OpDescPtr conv = std::make_shared("conv", "Convolution"); - OpDescPtr eltwise = std::make_shared("eltwise", "Eltwise"); - OpDescPtr relu = std::make_shared("relu", "Relu"); - OpDescPtr quant = std::make_shared("quant", "Quant"); - - SetPattern(conv, "Convolution"); - SetPattern(eltwise, "ElemWise"); - SetPattern(relu, "ElemWise"); - SetPattern(quant, "quant"); - SetTvmType(conv); - SetTvmType(eltwise); - SetTvmType(relu); - SetTvmType(quant); - // add descriptor - vector dim(4, 4); - GeShape shape(dim); - GeTensorDesc out_desc(shape); - - data->AddOutputDesc(out_desc); - data1->AddOutputDesc(out_desc); - data2->AddOutputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddOutputDesc(out_desc); - eltwise->AddInputDesc(out_desc); - eltwise->AddInputDesc(out_desc); - eltwise->AddOutputDesc(out_desc); - relu->AddInputDesc(out_desc); - relu->AddOutputDesc(out_desc); - quant->AddInputDesc(out_desc); - quant->AddOutputDesc(out_desc); - AttrUtils::SetInt(conv, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(eltwise, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(relu, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(quant, FE_IMPLY_TYPE, 6); - - NodePtr data_node = graph->AddNode(data); - NodePtr data1_node = graph->AddNode(data1); - NodePtr data2_node = graph->AddNode(data2); - NodePtr conv_node = graph->AddNode(conv); - NodePtr eltwise_node = graph->AddNode(eltwise); - NodePtr relu_node = graph->AddNode(relu); - NodePtr quant_node = graph->AddNode(quant); - const char tbe_bin[] = "tbe_bin"; - vector buffer(tbe_bin, tbe_bin + strlen(tbe_bin)); - ge::OpKernelBinPtr tbe_kernel_ptr = std::make_shared( - conv_node->GetName(), std::move(buffer)); - conv_node->GetOpDesc()->SetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel_ptr); - - GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(data1_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), - eltwise_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(data2_node->GetOutDataAnchor(0), - eltwise_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(eltwise_node->GetOutDataAnchor(0), - relu_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(relu_node->GetOutDataAnchor(0), - quant_node->GetInDataAnchor(0)); - } - void BuildGraphConvEltReluQuant2(ComputeGraphPtr graph, int32_t reluflag) { - OpDescPtr data = std::make_shared("DATA0", "Data"); - OpDescPtr data1 = std::make_shared("DATA1", "Data"); - OpDescPtr data2 = std::make_shared("DATA2", "Data"); - OpDescPtr conv = std::make_shared("conv", "Convolution"); - OpDescPtr eltwise = std::make_shared("eltwise", "EltwiseNoFusion"); - OpDescPtr relu = std::make_shared("relu", "Relu"); - OpDescPtr quant = std::make_shared("quant", "Quant"); - - SetPattern(conv, "Convolution"); - SetPattern(eltwise, "ElemWise"); - SetPattern(relu, "ElemWise"); - SetPattern(quant, "quant"); - SetTvmType(conv); - SetTvmType(eltwise); - SetTvmType(relu); - SetTvmType(quant); - // add descriptor - vector dim(4, 4); - GeShape shape(dim); - GeTensorDesc out_desc(shape); - - data->AddOutputDesc(out_desc); - data1->AddOutputDesc(out_desc); - data2->AddOutputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddOutputDesc(out_desc); - eltwise->AddInputDesc(out_desc); - eltwise->AddInputDesc(out_desc); - eltwise->AddOutputDesc(out_desc); - relu->AddInputDesc(out_desc); - relu->AddOutputDesc(out_desc); - quant->AddInputDesc(out_desc); - quant->AddOutputDesc(out_desc); - AttrUtils::SetInt(conv, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(eltwise, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(relu, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(quant, FE_IMPLY_TYPE, 6); - AttrUtils::SetBool(quant, "_is_op_dynamic_impl", true); - AttrUtils::SetBool(eltwise, "_is_op_dynamic_impl", true); - - NodePtr data_node = graph->AddNode(data); - NodePtr data1_node = graph->AddNode(data1); - NodePtr data2_node = graph->AddNode(data2); - NodePtr conv_node = graph->AddNode(conv); - NodePtr eltwise_node = graph->AddNode(eltwise); - NodePtr relu_node = graph->AddNode(relu); - NodePtr quant_node = graph->AddNode(quant); - const char tbe_bin[] = "tbe_bin"; - vector buffer(tbe_bin, tbe_bin + strlen(tbe_bin)); - ge::OpKernelBinPtr tbe_kernel_ptr = std::make_shared( - conv_node->GetName(), std::move(buffer)); - conv_node->GetOpDesc()->SetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel_ptr); - - GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(data1_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), - eltwise_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(data2_node->GetOutDataAnchor(0), - eltwise_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(eltwise_node->GetOutDataAnchor(0), - relu_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(relu_node->GetOutDataAnchor(0), - quant_node->GetInDataAnchor(0)); - } - - void BuildGraphdoubleConvEltElt(ComputeGraphPtr graph, int32_t reluflag) { - OpDescPtr data = std::make_shared("DATA0", "Data"); - OpDescPtr data1 = std::make_shared("DATA1", "Data"); - OpDescPtr data2 = std::make_shared("DATA2", "Data"); - OpDescPtr conv = std::make_shared("conv", "Conv2D"); - OpDescPtr conv1 = std::make_shared("conv1", "Conv2D"); - OpDescPtr elemwise = std::make_shared("elem", "Eltwise"); - OpDescPtr relu = std::make_shared("relu", "Relu"); - OpDescPtr relu1 = std::make_shared("relu1", "Relu"); - SetPattern(conv, "Convolution"); - SetPattern(conv1, "Convolution"); - SetPattern(relu, "ElemWise"); - SetPattern(elemwise, "ElemWise"); - SetTvmType(conv); - SetTvmType(conv1); - SetTvmType(elemwise); - SetTvmType(relu); - // add descriptor - vector dim = {4, 4, 4, 4}; - GeShape shape(dim); - GeTensorDesc out_desc(shape); - out_desc.SetOriginFormat(ge::FORMAT_NCHW); - vector dim1 = {8, 8, 8, 8}; - GeShape shape1(dim1); - GeTensorDesc out_desc1(shape1); - out_desc1.SetOriginFormat(ge::FORMAT_NCHW); - data->AddOutputDesc(out_desc); - data1->AddOutputDesc(out_desc); - data2->AddOutputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddOutputDesc(out_desc); - conv1->AddInputDesc(out_desc1); - conv1->AddInputDesc(out_desc1); - conv1->AddOutputDesc(out_desc); - elemwise->AddInputDesc(out_desc); - elemwise->AddInputDesc(out_desc); - elemwise->AddOutputDesc(out_desc); - relu->AddInputDesc(out_desc); - relu->AddOutputDesc(out_desc); - relu1->AddInputDesc(out_desc); - relu1->AddOutputDesc(out_desc); - AttrUtils::SetInt(conv, FE_IMPLY_TYPE, 6); - ge::AttrUtils::SetStr(conv, ge::ATTR_NAME_SESSION_GRAPH_ID, "_0_1_2_3"); - std::vector params = {0, 0, 0, 0, 0, 1, 0, 1}; - AttrUtils::SetListInt(conv, "ub_atomic_params", params); - AttrUtils::SetBool(conv, "Aipp_Conv_Flag", true); - conv->SetWorkspaceBytes({0}); - AttrUtils::SetInt(conv1, FE_IMPLY_TYPE, 6); - ge::AttrUtils::SetStr(conv1, ge::ATTR_NAME_SESSION_GRAPH_ID, "_0_1_2_3"); - AttrUtils::SetListInt(conv1, "ub_atomic_params", params); - AttrUtils::SetBool(conv1, "Aipp_Conv_Flag", true); - conv1->SetWorkspaceBytes({0}); - AttrUtils::SetInt(elemwise, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(relu, FE_IMPLY_TYPE, 6); - - NodePtr data_node = graph->AddNode(data); - NodePtr data1_node = graph->AddNode(data1); - NodePtr data2_node = graph->AddNode(data2); - NodePtr conv_node = graph->AddNode(conv); - NodePtr conv_node1 = graph->AddNode(conv1); - NodePtr elemwise_node = graph->AddNode(elemwise); - NodePtr relu_node = graph->AddNode(relu); - NodePtr relu1_node = graph->AddNode(relu1); - const char tbe_bin[] = "tbe_bin"; - vector buffer(tbe_bin, tbe_bin + strlen(tbe_bin)); - ge::OpKernelBinPtr tbe_kernel_ptr = std::make_shared( - conv_node->GetName(), std::move(buffer)); - conv_node->GetOpDesc()->SetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel_ptr); - ge::OpKernelBinPtr tbe_kernel_ptr1 = std::make_shared( - conv_node1->GetName(), std::move(buffer)); - conv_node1->GetOpDesc()->SetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel_ptr1); - - GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(data1_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), - conv_node1->GetInDataAnchor(0)); - GraphUtils::AddEdge(data1_node->GetOutDataAnchor(0), - conv_node1->GetInDataAnchor(1)); - GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), - elemwise_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(conv_node1->GetOutDataAnchor(0), - elemwise_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(elemwise_node->GetOutDataAnchor(0), - relu_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(relu_node->GetOutDataAnchor(0), - relu1_node->GetInDataAnchor(0)); - } - - void BuildGraphdoubleConvEltElt_1(ComputeGraphPtr graph, int32_t reluflag) { - OpDescPtr data = std::make_shared("DATA0", "Data"); - OpDescPtr data1 = std::make_shared("DATA1", "Data"); - OpDescPtr data2 = std::make_shared("DATA2", "Data"); - OpDescPtr conv = std::make_shared("conv", "Conv2D"); - OpDescPtr conv1 = std::make_shared("conv1", "Conv2D"); - OpDescPtr elemwise = std::make_shared("elem", "Eltwise"); - OpDescPtr relu = std::make_shared("relu", "Relu"); - OpDescPtr relu1 = std::make_shared("relu1", "Relu"); - SetPattern(conv, "Convolution"); - SetPattern(conv1, "Convolution"); - SetPattern(relu, "ElemWise"); - SetPattern(elemwise, "ElemWise"); - SetTvmType(conv); - SetTvmType(conv1); - SetTvmType(elemwise); - SetTvmType(relu); - // add descriptor - vector dim = {4, 4, 4, 4}; - GeShape shape(dim); - GeTensorDesc out_desc(shape); - out_desc.SetOriginFormat(ge::FORMAT_NHWC); - vector dim1 = {8, 8, 8, 8}; - GeShape shape1(dim1); - GeTensorDesc out_desc1(shape1); - out_desc1.SetOriginFormat(ge::FORMAT_HWCN); - data->AddOutputDesc(out_desc); - data1->AddOutputDesc(out_desc); - data2->AddOutputDesc(out_desc); - conv->AddInputDesc(out_desc1); - conv->AddInputDesc(out_desc1); - conv->AddOutputDesc(out_desc); - conv1->AddInputDesc(out_desc); - conv1->AddInputDesc(out_desc); - conv1->AddOutputDesc(out_desc); - elemwise->AddInputDesc(out_desc); - elemwise->AddInputDesc(out_desc); - elemwise->AddOutputDesc(out_desc); - relu->AddInputDesc(out_desc); - relu->AddOutputDesc(out_desc); - relu1->AddInputDesc(out_desc); - relu1->AddOutputDesc(out_desc); - AttrUtils::SetInt(conv, FE_IMPLY_TYPE, 6); - ge::AttrUtils::SetStr(conv, ge::ATTR_NAME_SESSION_GRAPH_ID, "_0_1_2_3"); - std::vector params = {0, 0, 0, 0, 0, 1, 0, 1}; - AttrUtils::SetListInt(conv, "ub_atomic_params", params); - AttrUtils::SetBool(conv, "Aipp_Conv_Flag", true); - conv->SetWorkspaceBytes({0}); - AttrUtils::SetInt(conv1, FE_IMPLY_TYPE, 6); - ge::AttrUtils::SetStr(conv1, ge::ATTR_NAME_SESSION_GRAPH_ID, "_0_1_2_3"); - AttrUtils::SetListInt(conv1, "ub_atomic_params", params); - AttrUtils::SetBool(conv1, "Aipp_Conv_Flag", true); - conv1->SetWorkspaceBytes({0}); - AttrUtils::SetInt(elemwise, FE_IMPLY_TYPE, 6); - AttrUtils::SetInt(relu, FE_IMPLY_TYPE, 6); - - NodePtr data_node = graph->AddNode(data); - NodePtr data1_node = graph->AddNode(data1); - NodePtr data2_node = graph->AddNode(data2); - NodePtr conv_node = graph->AddNode(conv); - NodePtr conv_node1 = graph->AddNode(conv1); - NodePtr elemwise_node = graph->AddNode(elemwise); - NodePtr relu_node = graph->AddNode(relu); - NodePtr relu1_node = graph->AddNode(relu1); - const char tbe_bin[] = "tbe_bin"; - vector buffer(tbe_bin, tbe_bin + strlen(tbe_bin)); - ge::OpKernelBinPtr tbe_kernel_ptr = std::make_shared( - conv_node->GetName(), std::move(buffer)); - conv_node->GetOpDesc()->SetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel_ptr); - ge::OpKernelBinPtr tbe_kernel_ptr1 = std::make_shared( - conv_node1->GetName(), std::move(buffer)); - conv_node1->GetOpDesc()->SetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel_ptr1); - - GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(data1_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), - conv_node1->GetInDataAnchor(0)); - GraphUtils::AddEdge(data1_node->GetOutDataAnchor(0), - conv_node1->GetInDataAnchor(1)); - GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), - elemwise_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(conv_node1->GetOutDataAnchor(0), - elemwise_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(elemwise_node->GetOutDataAnchor(0), - relu_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(relu_node->GetOutDataAnchor(0), - relu1_node->GetInDataAnchor(0)); - } - - void BuildGraphConvElt(ComputeGraphPtr graph) { - OpDescPtr data = std::make_shared("DATA0", "Data"); - OpDescPtr data1 = std::make_shared("DATA1", "Data"); - OpDescPtr data2 = std::make_shared("DATA2", "Data"); - OpDescPtr conv = std::make_shared("conv", "Conv2D"); - OpDescPtr elemwise = std::make_shared("elem", "Eltwise"); - - SetPattern(conv, "Convolution"); - SetPattern(elemwise, "ElemWise"); - SetTvmType(conv); - SetTvmType(elemwise); - // add descriptor - vector dim(4, 4); - GeShape in_shape(dim); - GeTensorDesc out_desc(in_shape); - - data->AddOutputDesc(out_desc); - data1->AddOutputDesc(out_desc); - data2->AddOutputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddInputDesc(out_desc); - conv->AddOutputDesc(out_desc); - elemwise->AddInputDesc(out_desc); - elemwise->AddInputDesc(out_desc); - elemwise->AddOutputDesc(out_desc); - AttrUtils::SetInt(conv, FE_IMPLY_TYPE, 6); - ge::AttrUtils::SetStr(conv, ge::ATTR_NAME_SESSION_GRAPH_ID, "_0_1_2_3"); - std::vector params = {0, 0, 0, 0, 0, 1, 0, 1}; - AttrUtils::SetListInt(conv, "ub_atomic_params", params); - // AttrUtils::SetBool(conv, "Aipp_Conv_Flag", true); - conv->SetWorkspaceBytes({0}); - AttrUtils::SetInt(elemwise, FE_IMPLY_TYPE, 6); - - NodePtr data_node = graph->AddNode(data); - NodePtr data1_node = graph->AddNode(data1); - NodePtr data2_node = graph->AddNode(data2); - NodePtr conv_node = graph->AddNode(conv); - NodePtr elemwise_node = graph->AddNode(elemwise); - const char tbe_bin[] = "tbe_bin"; - vector buffer(tbe_bin, tbe_bin + strlen(tbe_bin)); - ge::OpKernelBinPtr tbe_kernel_ptr = std::make_shared( - conv_node->GetName(), std::move(buffer)); - conv_node->GetOpDesc()->SetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel_ptr); - - GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(data1_node->GetOutDataAnchor(0), - conv_node->GetInDataAnchor(1)); - GraphUtils::AddEdge(conv_node->GetOutDataAnchor(0), - elemwise_node->GetInDataAnchor(0)); - GraphUtils::AddEdge(data2_node->GetOutDataAnchor(0), - elemwise_node->GetInDataAnchor(1)); - } -}; - -/************************ - * - * op conv - * | | - * eltiwse - * | - * op - * - ************************* - *conv eltw ubfusion - *************************/ -TEST_F(UB_FUSION_UT_CONV_ELT_RELU, conv_data_eltwise_relu) { - ComputeGraphPtr graph = std::make_shared("test"); - BuildGraph(graph, 1); - graph->TopologicalSorting(); - RunPass(*graph); -} - -TEST_F(UB_FUSION_UT_CONV_ELT_RELU, conv_data_eltwise_relu_l2_fusion) { - ComputeGraphPtr graph = std::make_shared("test"); - BuildGraphForL2Fusion1(graph, 1); - graph->TopologicalSorting(); - RunPass(*graph); -} -TEST_F(UB_FUSION_UT_CONV_ELT_RELU, conv_conv_eltwise_relu) { - ComputeGraphPtr graph = std::make_shared("test"); - BuildGraph2(graph, 1); - graph->TopologicalSorting(); - RunPass(*graph); -} - -TEST_F(UB_FUSION_UT_CONV_ELT_RELU, conv_conv_eltwise_relu3) { - ComputeGraphPtr graph = std::make_shared("test"); - BuildGraph3(graph, 1); - graph->TopologicalSorting(); - RunPass(*graph); -} - -TEST_F(UB_FUSION_UT_CONV_ELT_RELU, conv_conv_eltwise_relu4) { - ComputeGraphPtr graph = std::make_shared("test"); - BuildGraph4(graph, 1); - graph->TopologicalSorting(); - RunPass(*graph); -} - -TEST_F(UB_FUSION_UT_CONV_ELT_RELU, conv_conv_eltwise_relu5) { - ComputeGraphPtr graph = std::make_shared("test"); - BuildGraph5(graph, 1); - graph->TopologicalSorting(); - - RunPass(*graph); -} - -TEST_F(UB_FUSION_UT_CONV_ELT_RELU, conv_conv_eltwise_relu6) { - ComputeGraphPtr graph = std::make_shared("test"); - BuildGraph6(graph, 1); - graph->TopologicalSorting(); - RunPass(*graph); -} - -TEST_F(UB_FUSION_UT_CONV_ELT_RELU, conv_relu_quant_fusion_pass) { - - ComputeGraphPtr graph = std::make_shared("test"); - BuildGraphConvReluQuant(graph, 1); - graph->TopologicalSorting(); - RunPass(*graph); -} - -TEST_F(UB_FUSION_UT_CONV_ELT_RELU, conv_leakyrelu_quant_fusion_pass) { - - ComputeGraphPtr graph = std::make_shared("test"); - BuildGraphConvLeakyReluQuant1(graph, 1); - graph->TopologicalSorting(); - RunPass(*graph); -} - -TEST_F(UB_FUSION_UT_CONV_ELT_RELU, conv_leakyrelu_quant_fusion_pass_no_fusion) { - - ComputeGraphPtr graph = std::make_shared("test"); - BuildGraphConvLeakyReluQuant2(graph, 1); - graph->TopologicalSorting(); - RunPass(*graph); -} - -TEST_F(UB_FUSION_UT_CONV_ELT_RELU, conv_eltwise_relu_quant_fusion_pass) { - - ComputeGraphPtr graph = std::make_shared("test"); - BuildGraphConvEltReluQuant1(graph, 1); - - RunPass(*graph); -} - -TEST_F(UB_FUSION_UT_CONV_ELT_RELU, coverage_01) { - - ComputeGraphPtr graph = std::make_shared("test"); - BuildGraphConvEltReluQuant1(graph, 1); - - std::shared_ptr pass = std::make_shared(); - std::vector fusion_nodes; - BufferFusionMapping mapping; - EXPECT_EQ(fe::SUCCESS, pass->GetFusionNodes(mapping, fusion_nodes)); - - std::vector matched_nodes; - for (auto node : graph->GetDirectNode()) { - matched_nodes.emplace_back(node); - } - - pass->GetMatchedHeadNode(matched_nodes); - - std::vector patterns= pass->DefinePatterns(); - for (auto &pattern : patterns) { - delete pattern; - } -} - -TEST_F(UB_FUSION_UT_CONV_ELT_RELU, conv_eltwise_relu_quant_fusion_pass_no_fusion) { - std::map create_fns = - BufferFusionPassRegistry::GetInstance().GetCreateFnByType(BUILT_IN_AI_CORE_BUFFER_FUSION_PASS); - ComputeGraphPtr graph = std::make_shared("test"); - BuildGraphConvEltReluQuant2(graph, 1); - RunPass(*graph); -} - -TEST_F(UB_FUSION_UT_CONV_ELT_RELU, dyn_static_check) { - std::map create_fns = - BufferFusionPassRegistry::GetInstance().GetCreateFnByType(BUILT_IN_AI_CORE_BUFFER_FUSION_PASS); - ComputeGraphPtr graph = std::make_shared("test"); - BuildGraphConvEltReluQuant2(graph, 1); - std::shared_ptr common0 = std::make_shared(); - std::vector patterns = common0->DefinePatterns(); - BufferFusionMapping mapping; - auto pattern = patterns[0]; - pattern->SetGraphModType(1); - EXPECT_EQ(1, pattern->GetGraphModType()); - auto conv = pattern->GetOpDesc(PATTERN_CONV); - auto eltwise = pattern->GetOpDesc(PATTERN_ELEMWISE); - auto quant = pattern->GetOpDesc(PATTERN_QUANT); - auto conv_vec = {graph->FindNode("conv")}; - auto eltwise_vec = {graph->FindNode("eltwise")}; - auto quant_vec = {graph->FindNode("quant")}; - mapping.emplace(std::make_pair(conv, conv_vec)); - mapping.emplace(std::make_pair(eltwise, eltwise_vec)); - mapping.emplace(std::make_pair(quant, quant_vec)); - EXPECT_FALSE(BufferFusionPassBase::CheckNodesImplConsistent(mapping)); - EXPECT_FALSE(BufferFusionPassBase::CheckNodesIncDynamicShape(mapping)); - EXPECT_TRUE(BufferFusionPassBase::CheckNodeIsDynamicImpl(graph->FindNode("eltwise"))); - EXPECT_EQ(common0->PostFusion(nullptr), fe::SUCCESS); - for (auto &pattern : patterns) { - delete pattern; - } -} diff --git a/tests/ut/register/testcase/register_buffer_fusion_v2_utest.cc b/tests/ut/register/testcase/register_buffer_fusion_v2_utest.cc deleted file mode 100644 index 08965dcd19382fac8903c8aeab5b7285a805abac..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/register_buffer_fusion_v2_utest.cc +++ /dev/null @@ -1,894 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include "gtest/gtest.h" - -#include "graph/compute_graph.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/ge_tensor.h" -#include "graph/op_desc.h" -#include "graph/op_kernel_bin.h" -#include "graph/utils/attr_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/tensor_utils.h" -#include "register/graph_optimizer/buffer_fusion/buffer_fusion_pass_base.h" -#include "register/graph_optimizer/buffer_fusion/buffer_fusion_pattern.h" -#include "register/graph_optimizer/buffer_fusion/buffer_fusion_pass_registry.h" -#include "graph/debug/ge_log.h" -#include "register/graph_optimizer/graph_fusion/connection_matrix.h" -#include "register/graph_optimizer/fusion_common/op_slice_info.h" -#include "runtime/kernel.h" - -using namespace std; -using namespace domi; -using namespace fe; -using namespace ge; - - -namespace fe { -namespace buffer_fusion_reg_v2 { -static const string STREAM_LABEL = "_stream_label"; -const std::string FE_IMPLY_TYPE = "_fe_imply_type"; -static const uint32_t L2_MAXDATANUM = 8; -using L2FusionData_t = struct tag_l2_fusion_data { - uint32_t l2Index; - uint64_t l2Addr; - uint64_t l2PageNum; -}; -using L2FusionDataMap_t = std::map; - -using fe_sm_desc_t = struct tag_fe_sm_desc { - rtL2Ctrl_t l2ctrl; - std::string node_name[L2_MAXDATANUM]; - uint8_t output_index[L2_MAXDATANUM]; -}; - -using TaskL2FusionInfo_t = struct TagTaskL2FusionInfo { - std::string node_name; - fe_sm_desc_t l2_info; - L2FusionDataMap_t input; - L2FusionDataMap_t output; - uint32_t is_used; -}; -using L2FusionInfoPtr = std::shared_ptr; -} - -class TbeCommonRules2FusionPass : public BufferFusionPassBase { - public: - explicit TbeCommonRules2FusionPass() = default; - - ~TbeCommonRules2FusionPass() override = default; - - protected: - - /* - * @brief: define a common ub fusion pattern: - * (StrideRead) -> Convolution -> (Dequant) -> Elewise*N -> Quant -> (StrideWrite) - * - * pattern limits: - * 1. StrideRead, StrideWrite, Dequant are optional, Conv2D and Quant are required. - * 2. Elewise supports LeakyRelu, Vadd, Relu, Relu6, Prelu, Add, Mul. The number of Elewise can be 0 to 5. - * 3. There are two outputs from Dequant or Elewise, one is int8 or int4, the other is fp16. - * - * - * fusion node: (StrideRead), Convolution, (AscendDequant), Elewise, AscendQuant, - * - * @return BufferFusionPattern: return all valid patterns. - */ - vector DefinePatterns() override; - - /* - * @brief: parse nodes matched in mapping and call DoFusion - * @param [in] graph: original graph - * @param [out] mapping: nodes matched by pattern - * @return bool: fusion status ok or not. - */ - Status GetFusionNodes(const BufferFusionMapping &mapping, vector &fusion_nodes) override; - - - Status PostFusion(const ge::NodePtr &fused_node) override { - return FAILED; - } - - private: - - static int CountOtherOutput(vector dequant_nodes, vector elem_wise_nodes); - - static bool JudgeElemShapeInScopeLessThanOutScope(const vector &pre_elemwise_nodes, - const vector &elemwise_nodes); -}; - -namespace { -const string PATTERN_STRIDEREAD = "strideRead"; // NOLINT -const string PATTERN_CONVOLUTION = "convolution"; // NOLINT -const string PATTERN_DEPTHWISECONV = "depthwiseconv"; // NOLINT -const string PATTERN_DEQUANT = "dequant"; // NOLINT -const string PATTERN_ELEMWISE = "elemWise"; // NOLINT -const string PATTERN_QUANT = "quant"; // NOLINT -const string PATTERN_STRIDEWRITE = "strideWrite"; // NOLINT -const string PATTERN_OTHER_INPUT = "otherInput"; // NOLINT -const string PATTERN_OUTPUT = "output"; // NOLINT - -const vector ELEM_WISE_WHITE_LIST = {"Eltwise", "LeakyRelu", "Vadd", "Relu", - "Relu6", "Relu6D", "PRelu", - "Add", "Mul", "Softplus", "Sigmoid", "Mish", - "Minimum", "Tanh", "Swish"}; // NOLINT - -const int MAX_OP_COUNT = 20; -const int MAX_ELEMWISE_COUNT = 5; -const int INPUT_MAX_SIZE = 2; -const int kConvOutputMaxSize = 2; -} - -#define UT_CHECK(cond, log_func, return_expr) \ - do { \ - if (cond) { \ - log_func; \ - return_expr; \ - } \ - } while (0) - -#define UT_CHECK_NOTNULL(val) \ - do { \ - if ((val) == nullptr) { \ - GE_LOGE("Parameter[%s] must not be null.", #val); \ - return fe::PARAM_INVALID; \ - } \ - } while (0)3 - -/* -* @brief: define a common ub fusion pattern: -* (StrideRead) -> Convolution -> (Dequant) -> Elewise*N -> Quant -> (StrideWrite) -* -* pattern limits: -* 1. StrideRead, StrideWrite, Dequant are optional, Conv2D and Quant are required. -* 2. Elewise supports LeakyRelu, Vadd, Relu, Relu6, Prelu, Add, Mul. The number of Elewise can be 0 to 5. -* 3. There are two outputs from Dequant or Elewise, one is int8 or int4, the other is fp16. -* -* -* fusion node: (StrideRead), Convolution, (AscendDequant), Elewise, AscendQuant, -* -* @return BufferFusionPattern: return all valid patterns. -*/ -vector TbeCommonRules2FusionPass::DefinePatterns() { - vector patterns; - string pass_name = "TbeCommonRules2FusionPass"; - auto *pattern = new(std::nothrow) BufferFusionPattern(pass_name, MAX_OP_COUNT); - UT_CHECK((pattern == nullptr), - GE_LOGE("[SubGraphOpt][CommonRules2Fus][DefPtn] New an object failed."), return patterns); - GELOGD("Start to define %s pass pattern.", pass_name.c_str()); - pattern->AddOpDesc(PATTERN_STRIDEREAD, {OP_PATTERN_STRIDED_READ}, TBE_PATTERN_NUM_NONE, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_CONVOLUTION, {OP_PATTERN_CONV}, TBE_PATTERN_NUM_NONE, TBE_PATTERN_NUM_DEFAULT, - TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_DEPTHWISECONV, {OP_PATTERN_DEPTHWISE_CONV}, TBE_PATTERN_NUM_NONE, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_DEQUANT, {OP_PATTERN_DEQUANT}, TBE_PATTERN_NUM_NONE, TBE_PATTERN_NUM_DEFAULT, - TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_OTHER_INPUT, {TBE_PATTERN_INPUT_NODE}, TBE_PATTERN_NUM_NONE, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_ELEMWISE, {OP_PATTERN_ELEMWISE}, TBE_PATTERN_NUM_NONE, - MAX_ELEMWISE_COUNT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_QUANT, {OP_PATTERN_QUANT}, TBE_PATTERN_NUM_DEFAULT, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_STRIDEWRITE, {OP_PATTERN_STRIDED_WRITE}, TBE_PATTERN_NUM_NONE, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .SetHead({PATTERN_STRIDEREAD, PATTERN_CONVOLUTION, PATTERN_DEPTHWISECONV}) - .SetOutputs(PATTERN_STRIDEREAD, {PATTERN_CONVOLUTION, PATTERN_DEPTHWISECONV}) - .SetOutputs(PATTERN_CONVOLUTION, {PATTERN_DEQUANT}, TBE_OUTPUT_BRANCH_SINGLE, true, true) - .SetOutputs(PATTERN_DEPTHWISECONV, {PATTERN_DEQUANT}, TBE_OUTPUT_BRANCH_SINGLE, true, true) - .SetOutputs(PATTERN_DEQUANT, {PATTERN_ELEMWISE}, TBE_OUTPUT_BRANCH_SINGLE, true, true) - .SetOutputs(PATTERN_OTHER_INPUT, {PATTERN_DEQUANT}) - .SetOutputs(PATTERN_ELEMWISE, {PATTERN_QUANT}, TBE_OUTPUT_BRANCH_SINGLE, true, true) - .SetOutputs(PATTERN_QUANT, {PATTERN_STRIDEWRITE}, TBE_OUTPUT_BRANCH_SINGLE, false, true); - patterns.push_back(pattern); - GELOGD("End to define %s pass pattern.", pass_name.c_str()); - - return patterns; -} - -int TbeCommonRules2FusionPass::CountOtherOutput(vector dequant_nodes, - vector elem_wise_nodes) { - int other_out_count = 0; - // count EltWise op other output - for (const auto &elem_wise_node : elem_wise_nodes) { - if (elem_wise_node->GetOutDataNodes().empty()) { - continue; - } - int other_elt_wise_out = (int) (elem_wise_node->GetOutDataNodes().size() - 1); - other_out_count += other_elt_wise_out; - } - - // count Dequant op other output - if (!dequant_nodes.empty()) { - int other_dequant_out = 0; - if (dequant_nodes[0]->GetOutDataNodes().empty()) { - other_dequant_out = 0; - } else { - other_dequant_out = static_cast(dequant_nodes[0]->GetOutDataNodes().size() - 1); - } - other_out_count += other_dequant_out; - } - return other_out_count; -} - -bool TbeCommonRules2FusionPass::JudgeElemShapeInScopeLessThanOutScope(const vector &pre_elemwise_nodes, - const vector &elemwise_nodes) { - if (pre_elemwise_nodes.empty()) { - return false; - } - ge::NodePtr cur_node = pre_elemwise_nodes[0]; - for (auto &elemwise_node: elemwise_nodes) { - ge::NodePtr pre_node = cur_node; - cur_node = elemwise_node; - if (cur_node->GetOpDesc()->GetInputsSize() != INPUT_MAX_SIZE) { - continue; - } - - if ((cur_node->GetInDataAnchor(0)->GetPeerOutAnchor() == nullptr) || - (cur_node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode() == nullptr)) { - return false; - } - auto cur_node_input0 = cur_node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); - - vector in_scope_dims; - vector out_scope_dims; - if (cur_node_input0->GetName() == pre_node->GetOpDesc()->GetName()) { - in_scope_dims = cur_node->GetOpDesc()->MutableInputDesc(0)->MutableShape().GetDims(); - out_scope_dims = cur_node->GetOpDesc()->MutableInputDesc(1)->MutableShape().GetDims(); - } else { - in_scope_dims = cur_node->GetOpDesc()->MutableInputDesc(1)->MutableShape().GetDims(); - out_scope_dims = cur_node->GetOpDesc()->MutableInputDesc(0)->MutableShape().GetDims(); - } - if (in_scope_dims.size() != out_scope_dims.size()) { - GELOGD("Elem_wise[node: %s] : the number of input's dims is not equal. in_scope_dims: %zu, out_scope_dims: %zu", - cur_node->GetName().c_str(), in_scope_dims.size(), out_scope_dims.size()); - return false; - } else { - for (size_t i = 0; i < in_scope_dims.size(); i++) { - if (in_scope_dims[i] < out_scope_dims[i]) { - GELOGD("Elem_wise[node: %s] dims[%zu]: the value of in_scope is less than out_scope. in_scope : %ld," - " out_scope : %ld", cur_node->GetName().c_str(), i, in_scope_dims[i], out_scope_dims[i]); - return true; - } - } - } - } - return false; -} - -/* -* @brief: parse nodes matched in mapping and call DoFusion -* @param [in] graph: original graph -* @param [out] mapping: nodes matched by pattern -* @return bool: fusion status ok or not. -*/ -Status TbeCommonRules2FusionPass::GetFusionNodes(const BufferFusionMapping &mapping, - vector &fusion_nodes) { - fusion_nodes = GetMatchedNodes(mapping); - vector output_nodes = GetMatchedNodesByDescName(TBE_PATTERN_OUTPUT_NODE, mapping); - vector conv_nodes = GetMatchedNodesByDescName(PATTERN_CONVOLUTION, mapping); - vector depthwise_nodes = GetMatchedNodesByDescName(PATTERN_DEPTHWISECONV, mapping); - vector elem_wise_nodes = GetMatchedNodesByDescName(PATTERN_ELEMWISE, mapping); - vector dequant_nodes = GetMatchedNodesByDescName(PATTERN_DEQUANT, mapping); - vector quant_nodes = GetMatchedNodesByDescName(PATTERN_QUANT, mapping); - vector stride_write_nodes = GetMatchedNodesByDescName(PATTERN_STRIDEWRITE, mapping); - - bool conv_depth_size = conv_nodes.size() == 1 || depthwise_nodes.size() == 1; - if (!conv_depth_size) { - GELOGD("There is no conv and depthwise in TbeCommonRules2FusionPass"); - fusion_nodes.clear(); - return ge::GRAPH_SUCCESS; - } - vector conv_depthwise_nodes = conv_nodes.size() == 1 ? conv_nodes : depthwise_nodes; - - size_t conv_output_size = conv_depthwise_nodes[0]->GetOutDataNodes().size(); - // conv outputs size is more than 2, skip fused - if (conv_output_size > kConvOutputMaxSize) { - GELOGD("node: %s, outputs is more than 2, size is: %zu.", - conv_depthwise_nodes[0]->GetName().c_str(), conv_output_size); - fusion_nodes.clear(); - return ge::GRAPH_SUCCESS; - } - - // the output_data can't be fused - for (const auto &outputnode : output_nodes) { - auto node_ptr = find(fusion_nodes.begin(), fusion_nodes.end(), outputnode); - if (node_ptr != fusion_nodes.end()) { - fusion_nodes.erase(node_ptr); - } - } - - // this pattern only support one other output from dequant node or elem_wise node - int other_out_count = CountOtherOutput(dequant_nodes, elem_wise_nodes); - bool cond_other_out_count = (conv_output_size == 1 && other_out_count != 1) || - (conv_output_size == kConvOutputMaxSize && other_out_count != 0); - if (cond_other_out_count) { - GELOGD("The number of other output from EltWise or Dequant is %d, skip fusion.", other_out_count); - fusion_nodes.clear(); - return ge::GRAPH_SUCCESS; - } - - // if elewise has 2 input and inscope's shape less than outscope's shape, skip fusion - bool dequant_flag = !dequant_nodes.empty() && - JudgeElemShapeInScopeLessThanOutScope(dequant_nodes, elem_wise_nodes); - if (dequant_flag) { - GELOGD("dequant_nodes exist, Elemwise node has 2 inputs and in scope shape is less than outscope"); - fusion_nodes.clear(); - return ge::GRAPH_SUCCESS; - } - bool no_dequant_flag = dequant_nodes.empty() && - JudgeElemShapeInScopeLessThanOutScope(conv_depthwise_nodes, elem_wise_nodes); - if (no_dequant_flag) { - GELOGD("no dequant_nodes, Elemwise node has 2 inputs and in scope shape is less than outscope"); - fusion_nodes.clear(); - return ge::GRAPH_SUCCESS; - } - - // check whether the EltWise op is in the whitelist or inputsizes less then 3(only support single or double in) - for (const auto &elem_wise_node : elem_wise_nodes) { - bool support_flag = find(ELEM_WISE_WHITE_LIST.begin(), ELEM_WISE_WHITE_LIST.end(), elem_wise_node->GetType()) == - ELEM_WISE_WHITE_LIST.end() || - elem_wise_node->GetOpDesc()->GetInputsSize() > INPUT_MAX_SIZE; - if (support_flag) { - fusion_nodes.clear(); - GELOGD("Eltwise op[%s] type[%s] is not supported for this ub fusion pass, skip fusion.", - elem_wise_node->GetName().c_str(), elem_wise_node->GetType().c_str()); - return ge::GRAPH_SUCCESS; - } - } - - // if stride_write is the last node, check whether quant node has multi outputs - bool quant_node_flag = quant_nodes[0]->GetOutDataNodes().size() > 1 && !stride_write_nodes.empty(); - if (quant_node_flag) { - auto node_ptr = find(fusion_nodes.begin(), fusion_nodes.end(), stride_write_nodes[0]); - if (node_ptr != fusion_nodes.end()) { - fusion_nodes.erase(node_ptr); - } - GELOGD("Quant is not the last node of the matched pattern, \ - but has multi outpts, erase last node stride_write."); - } - return ge::GRAPH_SUCCESS; -} - -static const char PATTERN_STRIDED_READ[] = "stridedread"; -static const char PATTERN_CONV[] = "convolution"; -static const char PATTERN_STRIDED_WRITE[] = "stridedwrite"; -static const int FUSION_OP_NUM_MAX = 10; - -class ConveragePass : public BufferFusionPassBase { - public: - explicit ConveragePass() {} - - ~ConveragePass() override {} - - protected: - - /* - * @brief: define common rules0 ops fusion pattern - * - * (StrideRead) + conv2_d + (dequant) + ele-wise*N + (quant) + (StrideWrite) - * restriction: 1.each node must be single output and single reference - * 2.the range of N is 0 to 5 - * 3.allow multiple input, but only one input can be fusion - * - * @return BufferFusionPattern: return all valid patterns. - */ - vector DefinePatterns() override { - vector patterns; - string pass_name = "ConveragePass"; - BufferFusionPattern *pattern = new(std::nothrow) BufferFusionPattern(pass_name, 10); - UT_CHECK((pattern == nullptr), - GE_LOGE("[SubGraphOpt][CommonRules0Fus][DefPtn] New an object failed."), - return patterns); - GELOGD("Start to define %s pass pattern.", pass_name.c_str()); - // define pattern rules - pattern->AddOpDesc("", {OP_PATTERN_STRIDED_READ}, TBE_PATTERN_NUM_NONE, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE); - - pattern->AddOpDesc("test", {OP_PATTERN_STRIDED_READ}, TBE_PATTERN_NUM_DEFAULT, - TBE_PATTERN_NUM_NONE, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE); - - pattern->AddOpDesc("test", {OP_PATTERN_STRIDED_READ}, TBE_PATTERN_NUM_NONE, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE); - pattern->AddOpDesc("test", {OP_PATTERN_STRIDED_READ}, TBE_PATTERN_NUM_NONE, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE); - pattern->AddOpDesc("head1", {OP_PATTERN_STRIDED_READ}, 2, - 3, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE); - pattern->AddOpDesc("head2", {OP_PATTERN_STRIDED_READ}, 1, - 1, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE); - - pattern->AddOpDesc(PATTERN_CONV, {OP_PATTERN_CONV}, TBE_PATTERN_NUM_NONE, TBE_PATTERN_NUM_DEFAULT, - TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_DEPTHWISECONV, {OP_PATTERN_DEPTHWISE_CONV}, TBE_PATTERN_NUM_NONE, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE); - pattern->AddOpDesc(PATTERN_STRIDED_READ, {OP_PATTERN_STRIDED_READ}, TBE_PATTERN_NUM_NONE, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE); - - pattern->SetOutputs("", {PATTERN_CONV, PATTERN_DEPTHWISECONV}); - pattern->SetOutputs("1", {PATTERN_CONV, PATTERN_DEPTHWISECONV}); - pattern->SetOutputs(PATTERN_STRIDED_READ, {"1", PATTERN_DEPTHWISECONV}); - pattern->SetOutputs(PATTERN_STRIDED_READ, {PATTERN_STRIDED_READ}); - pattern->SetOutputs(PATTERN_STRIDED_READ, {PATTERN_CONV, PATTERN_DEPTHWISECONV}); - pattern->SetOutputs(PATTERN_STRIDED_READ, {PATTERN_CONV, PATTERN_DEPTHWISECONV}); - - - vector heads; - pattern->SetHead(heads); - - heads = {""}; - pattern->SetHead(heads); - - heads = {"head1"}; - pattern->SetHead(heads); - - heads = {PATTERN_CONV}; - pattern->SetHead(heads); - - heads = {PATTERN_CONV, "head2"}; - pattern->SetHead(heads); - - auto conv_desc = pattern->GetOpDesc(PATTERN_CONV); - pattern->UpdateSkipStatus(conv_desc); - - pattern->GetOpDescs(); - } -}; - -class TbeCommonRules0FusionPass : public BufferFusionPassBase { - public: - explicit TbeCommonRules0FusionPass() {} - - ~TbeCommonRules0FusionPass() override {} - - protected: - - /* - * @brief: define common rules0 ops fusion pattern - * - * (StrideRead) + conv2_d + (dequant) + ele-wise*N + (quant) + (StrideWrite) - * restriction: 1.each node must be single output and single reference - * 2.the range of N is 0 to 5 - * 3.allow multiple input, but only one input can be fusion - * - * @return BufferFusionPattern: return all valid patterns. - */ - vector DefinePatterns() override; - - /* - * @brief: parse nodes matched in mapping and call DoFusion - * @param [in] graph: original graph - * @param [out] mapping: nodes matched by pattern - * @return bool: fusion status ok or not. - */ - Status GetFusionNodes(const BufferFusionMapping &mapping, vector &fusion_nodes) override; - - private: - - static bool DealWithSameInAndOutScopeDimSize(const vector &in_scope_dims, - const vector &out_scope_dims, - const vector &elemwise_nodes, - const ge::NodePtr &cur_node, const size_t &i, - vector &fusion_node); - - static bool JudgeElemShapeInScopeLessThanOutScope(const vector &pre_elemwise_nodes, - const vector &elemwise_nodes, - vector &fusion_nodes); - - static bool IsInBlackListOfOpPatternElemwise(vector &elemwise_nodes, ge::NodePtr &node_ptr); -}; - -namespace { - -// white list of OP_PATTERN_ELEMWISE -static const vector WHITELIST_OF_OP_PATTERN_ELEMWISE = { - "Eltwise", "LeakyRelu", "Vadd", "Relu", "Relu6", "Relu6D", - "PRelu", "Add", "Mul", "Softplus", "Sigmoid", "Mish", "Minimum", - "Tanh", "Swish"}; -// black list of OP_PATTERN_ELEMWISE -static const vector BLACKLIST_OF_OP_PATTERN_ELEMWISE = { - "ReluGradV2"}; -} - -/* - * @brief: define common rules0 ops fusion pattern - * - * (StrideRead) + conv2_d + (dequant) + ele-wise*N + (quant) + (StrideWrite) - * restriction: 1.each node must be single output and single reference - * 2.the range of N is 0 to 5 - * 3.allow multiple input, but only one input can be fusion - * - * @return BufferFusionPattern: return all valid patterns. - */ -vector TbeCommonRules0FusionPass::DefinePatterns() { - vector patterns; - string pass_name = "TbeCommonRules0FusionPass"; - BufferFusionPattern *pattern = new(std::nothrow) BufferFusionPattern(pass_name, FUSION_OP_NUM_MAX); - UT_CHECK((pattern == nullptr), - GE_LOGE("[SubGraphOpt][CommonRules0Fus][DefPtn] New an object failed."), - return patterns); - GELOGD("Start to define %s pass pattern.", pass_name.c_str()); - // define pattern rules - pattern->AddOpDesc(PATTERN_STRIDED_READ, {OP_PATTERN_STRIDED_READ}, TBE_PATTERN_NUM_NONE, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_CONV, {OP_PATTERN_CONV}, TBE_PATTERN_NUM_NONE, TBE_PATTERN_NUM_DEFAULT, - TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_DEPTHWISECONV, {OP_PATTERN_DEPTHWISE_CONV}, TBE_PATTERN_NUM_NONE, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_DEQUANT, {OP_PATTERN_DEQUANT}, TBE_PATTERN_NUM_NONE, TBE_PATTERN_NUM_DEFAULT, - TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_ELEMWISE, {OP_PATTERN_ELEMWISE}, TBE_PATTERN_NUM_NONE, TBE_PATTERN_NUM_MAX, - TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_QUANT, {OP_PATTERN_QUANT}, TBE_PATTERN_NUM_NONE, TBE_PATTERN_NUM_DEFAULT, - TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_STRIDED_WRITE, {OP_PATTERN_STRIDED_WRITE}, TBE_PATTERN_NUM_NONE, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .AddOpDesc(PATTERN_OTHER_INPUT, {TBE_PATTERN_INPUT_NODE}, TBE_PATTERN_NUM_DEFAULT, - TBE_PATTERN_NUM_DEFAULT, TBE_PATTERN_GROUPID_INVALID, IGNORE_SHAPE_TYPE) - .SetHead({PATTERN_STRIDED_READ, PATTERN_CONV, PATTERN_DEPTHWISECONV}) - .SetOutputs(PATTERN_STRIDED_READ, {PATTERN_CONV, PATTERN_DEPTHWISECONV}) - .SetOutputs(PATTERN_CONV, {PATTERN_DEQUANT}, TBE_OUTPUT_BRANCH_SINGLE, true) - .SetOutputs(PATTERN_DEPTHWISECONV, {PATTERN_DEQUANT}, TBE_OUTPUT_BRANCH_SINGLE, true) - .SetOutputs(PATTERN_OTHER_INPUT, {PATTERN_DEQUANT}) - .SetOutputs(PATTERN_DEQUANT, {PATTERN_ELEMWISE}, TBE_OUTPUT_BRANCH_SINGLE, true) - .SetOutputs(PATTERN_ELEMWISE, {PATTERN_QUANT}, TBE_OUTPUT_BRANCH_SINGLE, true) - .SetOutputs(PATTERN_QUANT, {PATTERN_STRIDED_WRITE}); - - patterns.push_back(pattern); - - GELOGD("End to define %s pass pattern.", pass_name.c_str()); - return patterns; -} - -static void DelNotMatchNodesFromFusionNodes(ge::NodePtr node_ptr, vector &fusion_nodes) { - auto node = find(fusion_nodes.begin(), fusion_nodes.end(), node_ptr); - if (node != fusion_nodes.end()) { - fusion_nodes.erase(node); - } else { - return; - } - - auto curr_nodes = node_ptr->GetOutDataNodes(); - if (curr_nodes.size() != 1) { - return; - } else { - DelNotMatchNodesFromFusionNodes(curr_nodes.at(0), fusion_nodes); - } - return; -} - -static bool IsInWhiteListOfOpPatternElemwise(vector &elemwise_nodes, ge::NodePtr &node_ptr) { - for (auto &elemwise_node : elemwise_nodes) { - string elemwise_type = elemwise_node->GetType(); - auto op_type = - find(WHITELIST_OF_OP_PATTERN_ELEMWISE.begin(), WHITELIST_OF_OP_PATTERN_ELEMWISE.end(), elemwise_type); - if (op_type == WHITELIST_OF_OP_PATTERN_ELEMWISE.end()) { - GELOGD("node:%s[type:%s] not in elemwise white_list.", - elemwise_node->GetName().c_str(), elemwise_type.c_str()); - node_ptr = elemwise_node; - return false; - } - } - return true; -} - -bool TbeCommonRules0FusionPass::IsInBlackListOfOpPatternElemwise(vector &elemwise_nodes, - ge::NodePtr &node_ptr) { - for (auto &elemwise_node : elemwise_nodes) { - string elemwise_type = elemwise_node->GetType(); - auto op_type = - find(BLACKLIST_OF_OP_PATTERN_ELEMWISE.begin(), BLACKLIST_OF_OP_PATTERN_ELEMWISE.end(), elemwise_type); - if (op_type != BLACKLIST_OF_OP_PATTERN_ELEMWISE.end()) { - GELOGD("node:%s[type:%s] in elemwise black_list.", elemwise_node->GetName().c_str(), elemwise_type.c_str()); - node_ptr = elemwise_node; - return true; - } - } - return false; -} - -static void CheckElewiseInputSize(vector &elemwise_nodes, vector &fusion_nodes) { - for (auto elemwise_node : elemwise_nodes) { - if (elemwise_node->GetOpDesc()->GetInputsSize() > INPUT_MAX_SIZE) { - DelNotMatchNodesFromFusionNodes(elemwise_node, fusion_nodes); - return; - } - } -} - -bool TbeCommonRules0FusionPass::DealWithSameInAndOutScopeDimSize(const vector &in_scope_dims, - const vector &out_scope_dims, - const vector &elemwise_nodes, - const ge::NodePtr &cur_node, const size_t &i, - vector &fusion_nodes) { - for (size_t j = 0; j < in_scope_dims.size(); j++) { - if (in_scope_dims[j] < out_scope_dims[j]) { - GELOGD("Elem_wise[node: %s] dims[%zu] : the value of in_scope is less than out_scope. in_scope : %ld," - " out_scope : %ld", cur_node->GetName().c_str(), j, in_scope_dims[j], out_scope_dims[j]); - vector new_elemwise_nodes; - for (size_t z = i; z < elemwise_nodes.size(); z++) { - new_elemwise_nodes.push_back(elemwise_nodes[z]); - } - for (auto new_elemwise_node : new_elemwise_nodes) { - DelNotMatchNodesFromFusionNodes(new_elemwise_node, fusion_nodes); - } - return true; - } - } - return false; -} - -bool TbeCommonRules0FusionPass::JudgeElemShapeInScopeLessThanOutScope(const vector &pre_elemwise_nodes, - const vector &elemwise_nodes, - vector &fusion_nodes) { - if (pre_elemwise_nodes.empty()) { - return false; - } - ge::NodePtr cur_node = pre_elemwise_nodes[0]; - for (size_t i = 0; i < elemwise_nodes.size(); i++) { - ge::NodePtr elemwise_node = elemwise_nodes[i]; - ge::NodePtr pre_node = cur_node; - cur_node = elemwise_node; - if (cur_node->GetOpDesc()->GetInputsSize() != INPUT_MAX_SIZE) { - continue; - } - auto peerOutAnchor = cur_node->GetInDataAnchor(0)->GetPeerOutAnchor(); - if (peerOutAnchor == nullptr) { - GELOGD("node[%s]'s first peer in anchor is null", cur_node->GetName().c_str()); - continue; - } - auto cur_node_input0 = peerOutAnchor->GetOwnerNode(); - vector in_scope_dims; - vector out_scope_dims; - if (cur_node_input0->GetName() == pre_node->GetOpDesc()->GetName()) { - in_scope_dims = cur_node->GetOpDesc()->MutableInputDesc(0)->MutableShape().GetDims(); - out_scope_dims = cur_node->GetOpDesc()->MutableInputDesc(1)->MutableShape().GetDims(); - } else { - in_scope_dims = cur_node->GetOpDesc()->MutableInputDesc(1)->MutableShape().GetDims(); - out_scope_dims = cur_node->GetOpDesc()->MutableInputDesc(0)->MutableShape().GetDims(); - } - if (in_scope_dims.size() != out_scope_dims.size()) { - GELOGD("Elem_wise[node: %s] : the number of input's dims is not equal. in_scope : %zu, out_scope : %zu", - cur_node->GetName().c_str(), in_scope_dims.size(), out_scope_dims.size()); - return false; - } else { - if (DealWithSameInAndOutScopeDimSize(in_scope_dims, out_scope_dims, elemwise_nodes, cur_node, i, fusion_nodes)) { - return true; - } - } - } - return false; -} - -static void DelNotMatchNodes(vector &elemwise_nodes, vector &fusion_nodes) { - if (!elemwise_nodes.empty()) { - ge::NodePtr node = nullptr; - if (!IsInWhiteListOfOpPatternElemwise(elemwise_nodes, node)) { - DelNotMatchNodesFromFusionNodes(node, fusion_nodes); - } - } -} - -/* - * @brief: parse nodes matched in mapping and call DoFusion - * @param [in] graph: original graph - * @param [out] mapping: nodes matched by pattern - * @return bool: fusion status ok or not. - */ -Status TbeCommonRules0FusionPass::GetFusionNodes(const BufferFusionMapping &mapping, - vector &fusion_nodes) { - GELOGD("Begin to do TbeCommonRules0FusionPass!"); - fusion_nodes = GetMatchedNodes(mapping); - - vector elemwise_nodes = GetMatchedNodesByDescName(PATTERN_ELEMWISE, mapping); - // elewise only support single in or double in - if (!elemwise_nodes.empty()) { - CheckElewiseInputSize(elemwise_nodes, fusion_nodes); - } - - vector conv_nodes = GetMatchedNodesByDescName(PATTERN_CONV, mapping); - vector depthwise_nodes = GetMatchedNodesByDescName(PATTERN_DEPTHWISECONV, mapping); - bool conv_depth_size = conv_nodes.size() == 1 || depthwise_nodes.size() == 1; - if (!conv_depth_size) { - GELOGD("There is no conv and depthwise in TbeCommonRules0FusionPass"); - fusion_nodes.clear(); - return ge::GRAPH_SUCCESS; - } - vector conv_depthwise_nodes = conv_nodes.size() == 1 ? conv_nodes : depthwise_nodes; - vector dequant_nodes = GetMatchedNodesByDescName(PATTERN_DEQUANT, mapping); - - // if elewise has 2 input and inscope's shape less than outscope's shape, skip fusion - if (!dequant_nodes.empty()) { - if (JudgeElemShapeInScopeLessThanOutScope(dequant_nodes, elemwise_nodes, fusion_nodes)) { - GELOGD("dequant_nodes exist, Elemwise node has 2 inputs and in scope shape is less than outscope, try to fuse" - " before elemwise nodes"); - return ge::GRAPH_SUCCESS; - } - } else { - if (JudgeElemShapeInScopeLessThanOutScope(conv_depthwise_nodes, elemwise_nodes, fusion_nodes)) { - GELOGD("no dequant_nodes, Elemwise node has 2 inputs and in scope shape is less than outscope, try to fuse" - " before elemwise nodes"); - return ge::GRAPH_SUCCESS; - } - } - // elewise is in the blacklist, skip fusion - if (!elemwise_nodes.empty()) { - ge::NodePtr node = nullptr; - if (IsInBlackListOfOpPatternElemwise(elemwise_nodes, node)) { - GELOGD("node is in elemwise black_list, skip ub fusion!"); - fusion_nodes.clear(); - return ge::GRAPH_SUCCESS; - } - } - - // in conv2_d+elewise(1~3) pattern, elewise has no restrictions, - // if nums of elewise more then 3 and either one is not in the whitelist, skip fusion - bool ret = (fusion_nodes.size() == (elemwise_nodes.size() + conv_depthwise_nodes.size())) && - (conv_depthwise_nodes.size() == 1) && !elemwise_nodes.empty(); - if (ret) { - if (elemwise_nodes.size() <= 3) { - return ge::GRAPH_SUCCESS; - } else { - ge::NodePtr node = nullptr; - if (!IsInWhiteListOfOpPatternElemwise(elemwise_nodes, node)) { - fusion_nodes.clear(); - } - return ge::GRAPH_SUCCESS; - } - } - - DelNotMatchNodes(elemwise_nodes, fusion_nodes); - - if (fusion_nodes.size() == 1) { - fusion_nodes.clear(); - } - GELOGD("End to do TbeCommonRules0FusionPass!"); - return ge::GRAPH_SUCCESS; -} - -class UTestBufferFusionPassReg : public testing::Test { - public: - - protected: - - void SetUp() { - - } - - void TearDown() { - } -}; - -using BufferFusionFn = BufferFusionPassBase *(*)(); - -class TestBufferFusionPass : public BufferFusionPassBase { - protected: - - vector DefinePatterns() override { - return {}; - }; -}; - -fe::BufferFusionPassBase *BufferFunc1() { - auto ret = new(std::nothrow) TestBufferFusionPass(); - ret->SetName("1"); - return ret; -} - -fe::BufferFusionPassBase *BufferFunc2() { - auto ret = new(std::nothrow) TestBufferFusionPass(); - ret->SetName("2"); - return ret; -} - -fe::BufferFusionPassBase *BufferFunc3() { - auto ret = new(std::nothrow) TestBufferFusionPass(); - ret->SetName("3"); - return ret; -} - -TEST_F(UTestBufferFusionPassReg, test_case_01) { - auto pass_desc = BufferFusionPassRegistry::GetInstance().GetPassDesc(CUSTOM_AI_CORE_BUFFER_FUSION_PASS); - auto init_size = pass_desc.size(); - BufferFusionPassRegistry::GetInstance().RegisterPass( - CUSTOM_AI_CORE_BUFFER_FUSION_PASS, "CUSTOM_PASS1", BufferFunc1, 0); - BufferFusionPassRegistry::GetInstance().RegisterPass( - CUSTOM_AI_CORE_BUFFER_FUSION_PASS, "CUSTOM_PASS2", BufferFunc2, 1); - BufferFusionPassRegistry::GetInstance().RegisterPass( - CUSTOM_AI_CORE_BUFFER_FUSION_PASS, "CUSTOM_PASS3", BufferFunc3, 0xffffffff); - - pass_desc = BufferFusionPassRegistry::GetInstance().GetPassDesc(CUSTOM_AI_CORE_BUFFER_FUSION_PASS); - EXPECT_EQ(pass_desc.size(), init_size + 3); - EXPECT_EQ(pass_desc["CUSTOM_PASS1"].attr, 0); - EXPECT_EQ(pass_desc["CUSTOM_PASS2"].attr, 1); - EXPECT_EQ(pass_desc["CUSTOM_PASS3"].attr, 0xffffffff); - - auto create_fn1 = pass_desc["CUSTOM_PASS1"].create_fn; - auto buffer_fusion_pass_base_ptr1 = std::unique_ptr(create_fn1()); - EXPECT_EQ(buffer_fusion_pass_base_ptr1->GetName(), "1"); - - auto create_fn2 = pass_desc["CUSTOM_PASS2"].create_fn; - auto buffer_fusion_pass_base_ptr2 = std::unique_ptr(create_fn2()); - EXPECT_EQ(buffer_fusion_pass_base_ptr2->GetName(), "2"); - - auto create_fn3 = pass_desc["CUSTOM_PASS3"].create_fn; - auto buffer_fusion_pass_base_ptr3 = std::unique_ptr(create_fn3()); - EXPECT_EQ(buffer_fusion_pass_base_ptr3->GetName(), "3"); - - pass_desc = BufferFusionPassRegistry::GetInstance().GetPassDesc(BUILT_IN_AI_CORE_BUFFER_FUSION_PASS); - init_size = pass_desc.size(); - BufferFusionPassRegistry::GetInstance().RegisterPass( - BUILT_IN_AI_CORE_BUFFER_FUSION_PASS, "BUILT_IN_PASS1", BufferFunc1, 0); - BufferFusionPassRegistry::GetInstance().RegisterPass( - BUILT_IN_AI_CORE_BUFFER_FUSION_PASS, "BUILT_IN_PASS2", BufferFunc2, 1); - BufferFusionPassRegistry::GetInstance().RegisterPass( - BUILT_IN_AI_CORE_BUFFER_FUSION_PASS, "BUILT_IN_PASS3", BufferFunc3, 0xffffffff); - - auto pass_desc2 = BufferFusionPassRegistry::GetInstance().GetPassDesc(BUILT_IN_AI_CORE_BUFFER_FUSION_PASS); - EXPECT_EQ(pass_desc2.size(), init_size + 3); - EXPECT_EQ(pass_desc2["BUILT_IN_PASS1"].attr, 0); - EXPECT_EQ(pass_desc2["BUILT_IN_PASS2"].attr, 1); - EXPECT_EQ(pass_desc2["BUILT_IN_PASS3"].attr, 0xffffffff); - - create_fn1 = pass_desc2["BUILT_IN_PASS1"].create_fn; - auto buffer_fusion_pass_base_ptr4 = std::unique_ptr(create_fn1()); - EXPECT_EQ(buffer_fusion_pass_base_ptr4->GetName(), "1"); - - create_fn2 = pass_desc2["BUILT_IN_PASS2"].create_fn; - auto buffer_fusion_pass_base_ptr5 = std::unique_ptr(create_fn2()); - EXPECT_EQ(buffer_fusion_pass_base_ptr5->GetName(), "2"); - - create_fn3 = pass_desc2["BUILT_IN_PASS3"].create_fn; - auto buffer_fusion_pass_base_ptr6 = std::unique_ptr(create_fn3()); - EXPECT_EQ(buffer_fusion_pass_base_ptr6->GetName(), "3"); -} - - -TEST_F(UTestBufferFusionPassReg, test_case_02) { - REG_BUFFER_FUSION_PASS("", BUILT_IN_AI_CORE_BUFFER_FUSION_PASS, - TbeCommonRules0FusionPass, 0); - - REG_BUFFER_FUSION_PASS("MetadefBufferFusionPassTest", BUILT_IN_AI_CORE_BUFFER_FUSION_PASS, - TbeCommonRules0FusionPass, 1); - REG_BUFFER_FUSION_PASS("MetadefBufferFusionPassTest", BUILT_IN_AI_CORE_BUFFER_FUSION_PASS, - TbeCommonRules0FusionPass, 2); - - REG_BUFFER_FUSION_PASS("MetadefBufferFusionPassTest1", BUILT_IN_VECTOR_CORE_BUFFER_FUSION_PASS, - TbeCommonRules0FusionPass, 0xffff); - - auto pass_desc = BufferFusionPassRegistry::GetInstance().GetPassDesc(BUILT_IN_AI_CORE_BUFFER_FUSION_PASS); - EXPECT_EQ(pass_desc["MetadefBufferFusionPassTest"].attr, 2); - auto create_fn = pass_desc["MetadefBufferFusionPassTest"].create_fn; - auto buffer_fusion_pass_base_ptr = std::unique_ptr(create_fn()); - auto def_patterns = buffer_fusion_pass_base_ptr->DefinePatterns(); - EXPECT_EQ(def_patterns.size(), 1); - for (const auto &def_pattern : def_patterns) { - delete def_pattern; - } - - pass_desc = BufferFusionPassRegistry::GetInstance().GetPassDesc(BUILT_IN_VECTOR_CORE_BUFFER_FUSION_PASS); - EXPECT_EQ(pass_desc["MetadefBufferFusionPassTest1"].attr, 0xffff); - create_fn = pass_desc["MetadefBufferFusionPassTest1"].create_fn; - buffer_fusion_pass_base_ptr = std::unique_ptr(create_fn()); - def_patterns = buffer_fusion_pass_base_ptr->DefinePatterns(); - EXPECT_EQ(def_patterns.size(), 1); - for (const auto &def_pattern : def_patterns) { - delete def_pattern; - } -} - -TEST_F(UTestBufferFusionPassReg, test_post_fusion) { - std::shared_ptr common2 = std::make_shared(); - EXPECT_EQ(common2->PostFusion(nullptr), FAILED); -} -} diff --git a/tests/ut/register/testcase/register_graph_fusion.cc b/tests/ut/register/testcase/register_graph_fusion.cc deleted file mode 100644 index 0163548fe4a5fdc9c6729ab93d03fb7732eb55b5..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/register_graph_fusion.cc +++ /dev/null @@ -1,785 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include "graph/compute_graph.h" -#include "graph/graph.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/graph_utils.h" - -#include "register/graph_optimizer/graph_fusion/fusion_pass_manager/fusion_pass_registry.h" -#include "register/graph_optimizer/graph_fusion/graph_fusion_pass_base.h" -#include "register/graph_optimizer/fusion_common/pattern_fusion_base_pass.h" - -#include "graph/debug/ge_log.h" - -using namespace testing; -using namespace ge; -using namespace fe; - -namespace fe{ - - -class UTESTGraphFusionPass : public testing::Test { - public: - - protected: - void SetUp() { - - } - - void TearDown() { - } - - static ComputeGraphPtr CreateCastReluCastGraph1() { - ComputeGraphPtr graph = std::make_shared("test1"); - OpDescPtr op_desc_cast1 = std::make_shared("cast1", "Cast"); - OpDescPtr op_desc_relu = std::make_shared("relu", "Relu"); - OpDescPtr op_desc_cast2 = std::make_shared("cast2", "Cast"); - OpDescPtr op_desc_output = std::make_shared("output", "NetOutput"); - OpDescPtr op_desc_input = std::make_shared("other", "Other"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT16); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - vector dim_b = {1, 4, 64, 64}; - GeShape shape_b(dim_b); - GeTensorDesc tensor_desc_b(shape_b); - tensor_desc_b.SetFormat(FORMAT_NCHW); - tensor_desc_b.SetOriginFormat(FORMAT_NCHW); - tensor_desc_b.SetDataType(DT_FLOAT); - tensor_desc_b.SetOriginDataType(DT_FLOAT); - - vector dim_c = {1, 4, 64, 64}; - GeShape shape_c(dim_c); - GeTensorDesc tensor_desc_c(shape_c); - tensor_desc_c.SetFormat(FORMAT_NCHW); - tensor_desc_c.SetOriginFormat(FORMAT_NCHW); - tensor_desc_c.SetDataType(DT_FLOAT); - tensor_desc_c.SetOriginDataType(DT_FLOAT); - - //vector dim_d; - GeShape shape_d(dim_a); - GeTensorDesc tensor_desc_d(shape_d); - tensor_desc_d.SetFormat(FORMAT_NCHW); - tensor_desc_d.SetOriginFormat(FORMAT_NCHW); - tensor_desc_d.SetDataType(DT_FLOAT16); - tensor_desc_d.SetOriginDataType(DT_FLOAT); - - op_desc_input->AddOutputDesc(tensor_desc_a); - - op_desc_cast1->AddInputDesc(tensor_desc_a); - op_desc_cast1->AddOutputDesc(tensor_desc_b); - - op_desc_relu->AddInputDesc(tensor_desc_b); - op_desc_relu->AddOutputDesc(tensor_desc_c); - - op_desc_cast2->AddInputDesc(tensor_desc_c); - op_desc_cast2->AddOutputDesc(tensor_desc_d); - - op_desc_output->AddInputDesc(tensor_desc_d); - - ge::AttrUtils::SetStr(op_desc_relu, "_op_compile_strategy", "{}"); - ge::AttrUtils::SetInt(op_desc_relu, "_keep_dtype", 1); - - NodePtr node_cast1 = graph->AddNode(op_desc_cast1); - NodePtr node_relu = graph->AddNode(op_desc_relu); - NodePtr node_cast2 = graph->AddNode(op_desc_cast2); - NodePtr node_netoutput = graph->AddNode(op_desc_output); - NodePtr node_other = graph->AddNode(op_desc_input); - - GraphUtils::AddEdge(node_other->GetOutDataAnchor(0), node_cast1->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_cast1->GetOutDataAnchor(0), node_relu->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_relu->GetOutDataAnchor(0), node_cast2->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_cast2->GetOutDataAnchor(0), node_netoutput->GetInDataAnchor(0)); - - return graph; - } - static ComputeGraphPtr CreateCastReluCastGraph2() { - ComputeGraphPtr graph = std::make_shared("test1"); - OpDescPtr op_desc_cast1 = std::make_shared("cast1", "Cast"); - OpDescPtr op_desc_relu = std::make_shared("relu", "Relu"); - OpDescPtr op_desc_cast2 = std::make_shared("cast2", "Cast"); - OpDescPtr op_desc_output = std::make_shared("output", "NetOutput"); - OpDescPtr op_desc_input = std::make_shared("other", "Other"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT16); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - vector dim_b = {1, 4, 64, 64}; - GeShape shape_b(dim_b); - GeTensorDesc tensor_desc_b(shape_b); - tensor_desc_b.SetFormat(FORMAT_NCHW); - tensor_desc_b.SetOriginFormat(FORMAT_NCHW); - tensor_desc_b.SetDataType(DT_FLOAT); - tensor_desc_b.SetOriginDataType(DT_FLOAT); - - vector dim_c = {1, 4, 64, 64}; - GeShape shape_c(dim_c); - GeTensorDesc tensor_desc_c(shape_c); - tensor_desc_c.SetFormat(FORMAT_NCHW); - tensor_desc_c.SetOriginFormat(FORMAT_NCHW); - tensor_desc_c.SetDataType(DT_FLOAT); - tensor_desc_c.SetOriginDataType(DT_FLOAT); - - //vector dim_d; - GeShape shape_d(dim_a); - GeTensorDesc tensor_desc_d(shape_d); - tensor_desc_d.SetFormat(FORMAT_NCHW); - tensor_desc_d.SetOriginFormat(FORMAT_NCHW); - tensor_desc_d.SetDataType(DT_FLOAT); - tensor_desc_d.SetOriginDataType(DT_FLOAT); - - op_desc_input->AddOutputDesc(tensor_desc_a); - - op_desc_cast1->AddInputDesc(tensor_desc_a); - op_desc_cast1->AddOutputDesc(tensor_desc_b); - - op_desc_relu->AddInputDesc(tensor_desc_b); - op_desc_relu->AddOutputDesc(tensor_desc_c); - - op_desc_cast2->AddInputDesc(tensor_desc_c); - op_desc_cast2->AddOutputDesc(tensor_desc_d); - - op_desc_output->AddInputDesc(tensor_desc_d); - - ge::AttrUtils::SetStr(op_desc_relu, "_op_compile_strategy", "{}"); - ge::AttrUtils::SetInt(op_desc_relu, "_keep_dtype", 1); - - NodePtr node_cast1 = graph->AddNode(op_desc_cast1); - NodePtr node_relu = graph->AddNode(op_desc_relu); - NodePtr node_cast2 = graph->AddNode(op_desc_cast2); - NodePtr node_netoutput = graph->AddNode(op_desc_output); - NodePtr node_other = graph->AddNode(op_desc_input); - - GraphUtils::AddEdge(node_other->GetOutDataAnchor(0), node_cast1->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_cast1->GetOutDataAnchor(0), node_relu->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_relu->GetOutDataAnchor(0), node_cast2->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_cast2->GetOutDataAnchor(0), node_netoutput->GetInDataAnchor(0)); - - return graph; - } - static ComputeGraphPtr CreateCastReluCastGraph3() { - ComputeGraphPtr graph = std::make_shared("test1"); - OpDescPtr op_desc_cast1 = std::make_shared("cast1", "Cast"); - OpDescPtr op_desc_relu = std::make_shared("relu", "Relu"); - OpDescPtr op_desc_output = std::make_shared("output", "NetOutput"); - OpDescPtr op_desc_input = std::make_shared("other", "Other"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - vector dim_b = {1, 4, 64, 64}; - GeShape shape_b(dim_b); - GeTensorDesc tensor_desc_b(shape_b); - tensor_desc_b.SetFormat(FORMAT_NCHW); - tensor_desc_b.SetOriginFormat(FORMAT_NCHW); - tensor_desc_b.SetDataType(DT_FLOAT16); - tensor_desc_b.SetOriginDataType(DT_FLOAT); - - vector dim_c = {1, 4, 64, 64}; - GeShape shape_c(dim_c); - GeTensorDesc tensor_desc_c(shape_c); - tensor_desc_c.SetFormat(FORMAT_NCHW); - tensor_desc_c.SetOriginFormat(FORMAT_NCHW); - tensor_desc_c.SetDataType(DT_FLOAT16); - tensor_desc_c.SetOriginDataType(DT_FLOAT); - - op_desc_input->AddOutputDesc(tensor_desc_a); - - op_desc_cast1->AddInputDesc(tensor_desc_a); - op_desc_cast1->AddOutputDesc(tensor_desc_b); - - op_desc_relu->AddInputDesc(tensor_desc_b); - op_desc_relu->AddOutputDesc(tensor_desc_c); - - op_desc_output->AddInputDesc(tensor_desc_c); - - NodePtr node_cast1 = graph->AddNode(op_desc_cast1); - NodePtr node_relu = graph->AddNode(op_desc_relu); - NodePtr node_netoutput = graph->AddNode(op_desc_output); - NodePtr node_other = graph->AddNode(op_desc_input); - - GraphUtils::AddEdge(node_other->GetOutDataAnchor(0), node_cast1->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_cast1->GetOutDataAnchor(0), node_relu->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_relu->GetOutDataAnchor(0), node_netoutput->GetInDataAnchor(0)); - - return graph; - } - static ComputeGraphPtr CreateCastReluCastGraph4() { - ComputeGraphPtr graph = std::make_shared("test1"); - OpDescPtr op_desc_cast1 = std::make_shared("cast1", "Cast"); - OpDescPtr op_desc_cast3 = std::make_shared("cast3", "Cast"); - OpDescPtr op_desc_relu = std::make_shared("relu", "Relu"); - OpDescPtr op_desc_cast2 = std::make_shared("cast2", "Cast"); - OpDescPtr op_desc_output = std::make_shared("output", "NetOutput"); - OpDescPtr op_desc_input = std::make_shared("other", "Other"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT16); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - vector dim_b = {1, 4, 64, 64}; - GeShape shape_b(dim_b); - GeTensorDesc tensor_desc_b(shape_b); - tensor_desc_b.SetFormat(FORMAT_NCHW); - tensor_desc_b.SetOriginFormat(FORMAT_NCHW); - tensor_desc_b.SetDataType(DT_FLOAT); - tensor_desc_b.SetOriginDataType(DT_FLOAT); - - vector dim_c = {1, 4, 64, 64}; - GeShape shape_c(dim_c); - GeTensorDesc tensor_desc_c(shape_c); - tensor_desc_c.SetFormat(FORMAT_NCHW); - tensor_desc_c.SetOriginFormat(FORMAT_NCHW); - tensor_desc_c.SetDataType(DT_FLOAT); - tensor_desc_c.SetOriginDataType(DT_FLOAT); - - //vector dim_d; - GeShape shape_d(dim_a); - GeTensorDesc tensor_desc_d(shape_d); - tensor_desc_d.SetFormat(FORMAT_NCHW); - tensor_desc_d.SetOriginFormat(FORMAT_NCHW); - tensor_desc_d.SetDataType(DT_FLOAT16); - tensor_desc_d.SetOriginDataType(DT_FLOAT); - - op_desc_input->AddOutputDesc(tensor_desc_a); - - op_desc_cast1->AddInputDesc(tensor_desc_a); - op_desc_cast1->AddOutputDesc(tensor_desc_b); - - op_desc_cast3->AddInputDesc(tensor_desc_b); - op_desc_cast3->AddOutputDesc(tensor_desc_d); - - op_desc_relu->AddInputDesc(tensor_desc_b); - op_desc_relu->AddOutputDesc(tensor_desc_c); - - op_desc_cast2->AddInputDesc(tensor_desc_c); - op_desc_cast2->AddOutputDesc(tensor_desc_d); - - op_desc_output->AddInputDesc(tensor_desc_d); - op_desc_output->AddInputDesc(tensor_desc_d); - - NodePtr node_cast1 = graph->AddNode(op_desc_cast1); - NodePtr node_cast3 = graph->AddNode(op_desc_cast3); - NodePtr node_relu = graph->AddNode(op_desc_relu); - NodePtr node_cast2 = graph->AddNode(op_desc_cast2); - NodePtr node_netoutput = graph->AddNode(op_desc_output); - NodePtr node_other = graph->AddNode(op_desc_input); - - GraphUtils::AddEdge(node_other->GetOutDataAnchor(0), node_cast1->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_cast1->GetOutDataAnchor(0), node_relu->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_relu->GetOutDataAnchor(0), node_cast2->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_cast2->GetOutDataAnchor(0), node_netoutput->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_cast1->GetOutDataAnchor(0), node_cast3->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_cast3->GetOutDataAnchor(0), node_netoutput->GetInDataAnchor(1)); - - return graph; - } - static ComputeGraphPtr CreateCastReluCastGraph5() { - ComputeGraphPtr graph = std::make_shared("test1"); - OpDescPtr op_desc_cast1 = std::make_shared("cast1", "Cast"); - OpDescPtr op_desc_cast3 = std::make_shared("cast3", "Cast"); - OpDescPtr op_desc_cast4 = std::make_shared("cast4", "Cast"); - OpDescPtr op_desc_relu = std::make_shared("relu", "Relu"); - OpDescPtr op_desc_cast2 = std::make_shared("cast2", "Cast"); - OpDescPtr op_desc_output = std::make_shared("output", "NetOutput"); - OpDescPtr op_desc_input = std::make_shared("other", "Other"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT16); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - vector dim_b = {1, 4, 64, 64}; - GeShape shape_b(dim_b); - GeTensorDesc tensor_desc_b(shape_b); - tensor_desc_b.SetFormat(FORMAT_NCHW); - tensor_desc_b.SetOriginFormat(FORMAT_NCHW); - tensor_desc_b.SetDataType(DT_FLOAT); - tensor_desc_b.SetOriginDataType(DT_FLOAT); - - vector dim_c = {1, 4, 64, 64}; - GeShape shape_c(dim_c); - GeTensorDesc tensor_desc_c(shape_c); - tensor_desc_c.SetFormat(FORMAT_NCHW); - tensor_desc_c.SetOriginFormat(FORMAT_NCHW); - tensor_desc_c.SetDataType(DT_FLOAT); - tensor_desc_c.SetOriginDataType(DT_FLOAT); - - //vector dim_d; - GeShape shape_d(dim_a); - GeTensorDesc tensor_desc_d(shape_d); - tensor_desc_d.SetFormat(FORMAT_NCHW); - tensor_desc_d.SetOriginFormat(FORMAT_NCHW); - tensor_desc_d.SetDataType(DT_FLOAT16); - tensor_desc_d.SetOriginDataType(DT_FLOAT); - - op_desc_input->AddOutputDesc(tensor_desc_a); - - op_desc_cast1->AddInputDesc(tensor_desc_a); - op_desc_cast1->AddOutputDesc(tensor_desc_b); - - op_desc_relu->AddInputDesc(tensor_desc_b); - op_desc_relu->AddOutputDesc(tensor_desc_c); - - op_desc_cast2->AddInputDesc(tensor_desc_c); - op_desc_cast2->AddOutputDesc(tensor_desc_d); - - op_desc_cast3->AddInputDesc(tensor_desc_c); - op_desc_cast3->AddOutputDesc(tensor_desc_d); - - op_desc_cast4->AddInputDesc(tensor_desc_c); - op_desc_cast4->AddOutputDesc(tensor_desc_d); - - op_desc_output->AddInputDesc(tensor_desc_d); - op_desc_output->AddInputDesc(tensor_desc_d); - op_desc_output->AddInputDesc(tensor_desc_d); - - NodePtr node_cast1 = graph->AddNode(op_desc_cast1); - NodePtr node_cast3 = graph->AddNode(op_desc_cast3); - NodePtr node_cast4 = graph->AddNode(op_desc_cast4); - NodePtr node_relu = graph->AddNode(op_desc_relu); - NodePtr node_cast2 = graph->AddNode(op_desc_cast2); - NodePtr node_netoutput = graph->AddNode(op_desc_output); - NodePtr node_other = graph->AddNode(op_desc_input); - - GraphUtils::AddEdge(node_other->GetOutDataAnchor(0), node_cast1->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_cast1->GetOutDataAnchor(0), node_relu->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_relu->GetOutDataAnchor(0), node_cast2->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_relu->GetOutDataAnchor(0), node_cast3->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_relu->GetOutDataAnchor(0), node_cast4->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_cast2->GetOutDataAnchor(0), node_netoutput->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_cast3->GetOutDataAnchor(0), node_netoutput->GetInDataAnchor(1)); - GraphUtils::AddEdge(node_cast4->GetOutDataAnchor(0), node_netoutput->GetInDataAnchor(2)); - - return graph; - } - static ComputeGraphPtr CreateCastReluCastGraph6() { - ComputeGraphPtr graph = std::make_shared("test1"); - OpDescPtr op_desc_cast1 = std::make_shared("cast1", "Cast"); - OpDescPtr op_desc_cast3 = std::make_shared("cast3", "Cast"); - OpDescPtr op_desc_cast4 = std::make_shared("cast4", "Cast"); - OpDescPtr op_desc_relu = std::make_shared("relu", "Relu"); - OpDescPtr op_desc_cast2 = std::make_shared("cast2", "Cast"); - OpDescPtr op_desc_output = std::make_shared("output", "NetOutput"); - OpDescPtr op_desc_input = std::make_shared("other", "Other"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT16); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - vector dim_b = {1, 4, 64, 64}; - GeShape shape_b(dim_b); - GeTensorDesc tensor_desc_b(shape_b); - tensor_desc_b.SetFormat(FORMAT_NCHW); - tensor_desc_b.SetOriginFormat(FORMAT_NCHW); - tensor_desc_b.SetDataType(DT_FLOAT); - tensor_desc_b.SetOriginDataType(DT_FLOAT); - - vector dim_c = {1, 4, 64, 64}; - GeShape shape_c(dim_c); - GeTensorDesc tensor_desc_c(shape_c); - tensor_desc_c.SetFormat(FORMAT_NCHW); - tensor_desc_c.SetOriginFormat(FORMAT_NCHW); - tensor_desc_c.SetDataType(DT_FLOAT); - tensor_desc_c.SetOriginDataType(DT_FLOAT); - - //vector dim_d; - GeShape shape_d(dim_a); - GeTensorDesc tensor_desc_d(shape_d); - tensor_desc_d.SetFormat(FORMAT_NCHW); - tensor_desc_d.SetOriginFormat(FORMAT_NCHW); - tensor_desc_d.SetDataType(DT_FLOAT16); - tensor_desc_d.SetOriginDataType(DT_FLOAT); - - op_desc_input->AddOutputDesc(tensor_desc_a); - - op_desc_cast1->AddInputDesc(tensor_desc_a); - op_desc_cast1->AddOutputDesc(tensor_desc_b); - - op_desc_cast3->AddInputDesc(tensor_desc_c); - op_desc_cast3->AddOutputDesc(tensor_desc_d); - - op_desc_cast4->AddInputDesc(tensor_desc_c); - op_desc_cast4->AddOutputDesc(tensor_desc_c); - - op_desc_relu->AddInputDesc(tensor_desc_b); - op_desc_relu->AddOutputDesc(tensor_desc_c); - - op_desc_cast2->AddInputDesc(tensor_desc_c); - op_desc_cast2->AddOutputDesc(tensor_desc_d); - - op_desc_output->AddInputDesc(tensor_desc_d); - op_desc_output->AddInputDesc(tensor_desc_d); - op_desc_output->AddInputDesc(tensor_desc_c); - - NodePtr node_cast1 = graph->AddNode(op_desc_cast1); - NodePtr node_cast3 = graph->AddNode(op_desc_cast3); - NodePtr node_cast4 = graph->AddNode(op_desc_cast4); - NodePtr node_relu = graph->AddNode(op_desc_relu); - NodePtr node_cast2 = graph->AddNode(op_desc_cast2); - NodePtr node_netoutput = graph->AddNode(op_desc_output); - NodePtr node_other = graph->AddNode(op_desc_input); - - GraphUtils::AddEdge(node_other->GetOutDataAnchor(0), node_cast1->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_cast1->GetOutDataAnchor(0), node_relu->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_relu->GetOutDataAnchor(0), node_cast2->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_relu->GetOutDataAnchor(0), node_cast3->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_relu->GetOutDataAnchor(0), node_cast4->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_cast2->GetOutDataAnchor(0), node_netoutput->GetInDataAnchor(0)); - GraphUtils::AddEdge(node_cast3->GetOutDataAnchor(0), node_netoutput->GetInDataAnchor(1)); - GraphUtils::AddEdge(node_cast4->GetOutDataAnchor(0), node_netoutput->GetInDataAnchor(2)); - - return graph; - } - static void DumpGraph(const ge::ComputeGraphPtr graph, string graph_name) { - printf("start to dump graph %s...\n", graph_name.c_str()); - for (ge::NodePtr node : graph->GetAllNodes()) { - printf("node name = %s.\n", node->GetName().c_str()); - for (ge::OutDataAnchorPtr anchor : node->GetAllOutDataAnchors()) { - for (ge::InDataAnchorPtr peer_in_anchor : anchor->GetPeerInDataAnchors()) { - printf(" node name = %s[%d], out data node name = %s[%d].\n", - node->GetName().c_str(), - anchor->GetIdx(), - peer_in_anchor->GetOwnerNode()->GetName().c_str(), - peer_in_anchor->GetIdx()); - } - } - if (node->GetOutControlAnchor() != nullptr) { - for (ge::InControlAnchorPtr peer_in_anchor : node->GetOutControlAnchor()->GetPeerInControlAnchors()) { - printf(" node name = %s, out control node name = %s.\n", node->GetName().c_str(), - peer_in_anchor->GetOwnerNode()->GetName().c_str()); - } - } - } - - return; - } - -}; - - -const char *kOpTypeCast = "Cast"; -const char *kOpTypeRelu = "Relu"; - -const char *kPatternCast0 = "Cast0"; -const char *kPatternCast1 = "Cast1"; -const char *kPatternRelu = "Relu"; -#define UT_CHECK(cond, log_func, return_expr) \ - do { \ - if (cond) { \ - log_func; \ - return_expr; \ - } \ - } while (0) - -#define UT_CHECK_NOTNULL(val) \ - do { \ - if ((val) == nullptr) { \ - GELOGD("Parameter[%s] must not be null.", #val); \ - return fe::PARAM_INVALID; \ - } \ - } while (0) - -string pass_name_test = "CastCastFusionPass"; -class TestPass : public fe::PatternFusionBasePass { - using Mapping = std::map, std::vector, fe::CmpKey>; - protected: - vector DefinePatterns() override { - vector patterns; - - FusionPattern *pattern = new(std::nothrow) FusionPattern("CastCastFusionPass"); - UT_CHECK(pattern == nullptr, GELOGD(" Fail to create a new pattern object."), - return patterns); - pattern->AddOpDesc(kPatternCast0, {kOpTypeCast}) - .AddOpDesc(kPatternRelu, {kOpTypeRelu}) - .AddOpDesc(kPatternCast1, {kOpTypeCast}) - .SetInputs(kPatternRelu, {kPatternCast0}) - .SetInputs(kPatternCast1, {kPatternRelu}) - .SetOutput(kPatternCast1); - - patterns.push_back(pattern); - - return patterns; - } - - Status Fusion(ge::ComputeGraph &graph, Mapping &mapping, vector &new_nodes) override { - FusionPattern pattern("CastCastFusionPass"); - DumpMapping(pattern, mapping); - - CheckGraphCycle(graph); - ge::NodePtr cast_Node0 = GetNodeFromMapping(kPatternCast0, mapping); - CheckOpSupported(cast_Node0); - CheckOpSupported(cast_Node0->GetOpDesc()); - CheckAccuracySupported(cast_Node0); - - UT_CHECK(cast_Node0 == nullptr, GELOGD("cast_Node0 is null,fusion failed."), - return NOT_CHANGED); - ge::OpDescPtr cast_desc0 = cast_Node0->GetOpDesc(); - UT_CHECK(cast_desc0 == nullptr, GELOGD("cast_Node0's Desc is null,fusion failed."), - return NOT_CHANGED); - - ge::NodePtr relu_Node = GetNodeFromMapping(kPatternRelu, mapping); - UT_CHECK(relu_Node == nullptr, GELOGD("relu_Node is null,fusion failed."), - return NOT_CHANGED); - ge::OpDescPtr relu_desc = relu_Node->GetOpDesc(); - UT_CHECK(cast_desc0 == nullptr, GELOGD("relu_Node's Desc is null,fusion failed."), - return NOT_CHANGED); - - auto relu_input = relu_desc->MutableInputDesc(0); - UT_CHECK_NOTNULL(relu_input); - auto relu_input_desc_dtype = relu_input->GetDataType(); - - auto relu_output = relu_desc->MutableOutputDesc(0); - UT_CHECK_NOTNULL(relu_output); - auto relu_output_desc_dtype = relu_output->GetDataType(); - if (relu_input_desc_dtype != DT_FLOAT || relu_output_desc_dtype != DT_FLOAT) { - GELOGD("Relu node [%s]'s input dtype or output dtype is unsuitable", relu_desc->GetName().c_str()); - return NOT_CHANGED; - } - - ge::NodePtr cast_Node1 = GetNodeFromMapping(kPatternCast1, mapping); - UT_CHECK(cast_Node1 == nullptr, GELOGD("cast_Node1 is null,fusion failed."), - return NOT_CHANGED); - ge::OpDescPtr cast_desc1 = cast_Node1->GetOpDesc(); - UT_CHECK(cast_desc0 == nullptr, GELOGD("cast_Node1's Desc is null,fusion failed."), - return NOT_CHANGED); - - auto cast0_input = cast_desc0->MutableInputDesc(0); - UT_CHECK_NOTNULL(cast0_input); - DataType cast0_in_d_type = cast0_input->GetDataType(); - auto cast1_output = cast_desc1->MutableOutputDesc(0); - UT_CHECK_NOTNULL(cast1_output); - DataType cast1_out_d_type = cast1_output->GetDataType(); - if (cast0_in_d_type != cast1_out_d_type) { - GELOGD("Cast Node0 [%s] input data type is not equal to Cast Node1 [%s] output data type ", - cast_Node0->GetName().c_str(), cast_Node1->GetName().c_str()); - return NOT_CHANGED; - } - - auto cast0_out_data_anchor = cast_Node0->GetOutDataAnchor(0); - UT_CHECK_NOTNULL(cast0_out_data_anchor); - if (cast0_out_data_anchor->GetPeerInDataAnchors().size() > 1) { - GELOGD("The first output anchor of Cast node[%s] has more than one peer in anchor.", - cast_Node0->GetName().c_str()); - return NOT_CHANGED; - } - - auto relu_out_data_anchor = relu_Node->GetOutDataAnchor(0); - UT_CHECK_NOTNULL(relu_out_data_anchor); - if (relu_out_data_anchor->GetPeerInDataAnchors().size() > 1) { - for (auto node : relu_Node->GetOutAllNodes()) { - if (node->GetType() != "Cast") { - GELOGD("The output anchor of Relu node has not Cast node,name is [%s] Type is [%s].", - node->GetName().c_str(), node->GetType().c_str()); - return NOT_CHANGED; - } - auto node_desc = node->GetOpDesc(); - UT_CHECK_NOTNULL(node_desc); - auto in_dtype = node_desc->MutableInputDesc(0)->GetDataType(); - auto out_dtype = node_desc->MutableOutputDesc(0)->GetDataType(); - if (in_dtype != DT_FLOAT || out_dtype != DT_FLOAT16) { - GELOGD("The Cast node [%s]'s indatatype is not equal to DT_FLOAT or outdatatype is not equal to DT_FLOAT16.", - node->GetName().c_str()); - return NOT_CHANGED; - } - } - } - - ge::ComputeGraphPtr graphPtr = relu_Node->GetOwnerComputeGraph(); - UT_CHECK_NOTNULL(graphPtr); - if (GraphUtils::IsolateNode(cast_Node0, {0}) != GRAPH_SUCCESS) { - GELOGD("Isolate op:%s(%s) failed", cast_Node0->GetName().c_str(), cast_Node0->GetType().c_str()); - return FAILED; - } - if (GraphUtils::RemoveNodeWithoutRelink(graphPtr, cast_Node0) != GRAPH_SUCCESS) { - GELOGD("[Remove][Node] %s, type:%s without relink in graph:%s failed", - cast_Node0->GetName().c_str(), cast_Node0->GetType().c_str(), graph.GetName().c_str()); - return FAILED; - } - for (auto inAnchor : relu_out_data_anchor->GetPeerInDataAnchors()) { - auto node = inAnchor->GetOwnerNode(); - UT_CHECK_NOTNULL(node); - if (GraphUtils::IsolateNode(node, {0}) != GRAPH_SUCCESS) { - GELOGD("Isolate op:%s(%s) failed", node->GetName().c_str(), node->GetType().c_str()); - return FAILED; - } - if (GraphUtils::RemoveNodeWithoutRelink(graphPtr, node) != GRAPH_SUCCESS) { - GELOGD("[Remove][Node] %s, type:%s without relink in graph:%s failed", - node->GetName().c_str(), node->GetType().c_str(), graph.GetName().c_str()); - return FAILED; - } - } - relu_desc->MutableInputDesc(0)->SetDataType(cast0_in_d_type); - relu_desc->MutableOutputDesc(0)->SetDataType(cast1_out_d_type); - new_nodes.push_back(relu_Node); - return SUCCESS; - } -}; - -TEST_F(UTESTGraphFusionPass, cast_relu_cast_01) -{ - ComputeGraphPtr graph = CreateCastReluCastGraph1(); - TestPass pass; - DumpGraph(graph, "test1"); - fe::Status status = pass.Run(*graph, nullptr); - EXPECT_EQ(fe::SUCCESS, status); - DumpGraph(graph, "test1"); - - vector dim_a = {8, 4, 16, 16}; - vector dim_b = {1, 4, 64, 64}; - - for(auto node : graph->GetDirectNode()) { - OpDescPtr op_desc = node->GetOpDesc(); - if (op_desc->GetType() == "Relu") { - EXPECT_EQ(op_desc->MutableInputDesc(0)->GetDataType(), DT_FLOAT16); - EXPECT_EQ(op_desc->MutableInputDesc(0)->MutableShape().GetDims(), dim_b); - - EXPECT_EQ(op_desc->MutableOutputDesc(0)->GetDataType(), DT_FLOAT16); - EXPECT_EQ(op_desc->MutableOutputDesc(0)->MutableShape().GetDims(), dim_b); - NodePtr node_out=node->GetOutDataNodes().at(0); - EXPECT_EQ(node_out->GetOpDesc()->GetType(),"NetOutput"); - EXPECT_EQ(node_out->GetOpDesc()->MutableInputDesc(0)->GetDataType(),DT_FLOAT16); - EXPECT_EQ(node_out->GetOpDesc()->MutableInputDesc(0)->MutableShape().GetDims(),dim_a); - NodePtr node0=node->GetInDataNodes().at(0); - EXPECT_EQ(node0->GetOpDesc()->GetType(),"Other"); - EXPECT_EQ(node0->GetOpDesc()->MutableOutputDesc(0)->GetDataType(),DT_FLOAT16); - EXPECT_EQ(node0->GetOpDesc()->MutableOutputDesc(0)->MutableShape().GetDims(), dim_a); - - } - } -} - -class UtOpsKernel : public OpsKernelInfoStore { - // initialize opsKernelInfoStore - Status Initialize(const std::map &options) override { - return SUCCESS; - } - - // close opsKernelInfoStore - Status Finalize() override { - return SUCCESS; - } - - // get all opsKernelInfo - void GetAllOpsKernelInfo(std::map &infos) const override {} - - // whether the opsKernelInfoStore is supported based on the operator attribute - bool CheckSupported(const OpDescPtr &opDescPtr, std::string &un_supported_reason) const override { - return true; - } -}; - -TEST_F(UTESTGraphFusionPass, cast_relu_cast_02) -{ - ComputeGraphPtr graph = CreateCastReluCastGraph2(); - TestPass pass; - std::shared_ptr ops_kernel = std::make_shared(); - - std::shared_ptr base = ops_kernel; - fe::Status status = pass.Run(*graph, base); - EXPECT_EQ(fe::NOT_CHANGED, status); - -} -TEST_F(UTESTGraphFusionPass, cast_relu_cast_03) -{ - ComputeGraphPtr graph = CreateCastReluCastGraph3(); - - TestPass pass; - std::shared_ptr ops_kernel = std::make_shared(); - - std::shared_ptr base = ops_kernel; - fe::Status status = pass.Run(*graph, base); - EXPECT_EQ(fe::NOT_CHANGED, status); - -} -TEST_F(UTESTGraphFusionPass, cast_relu_cast_04) -{ - ComputeGraphPtr graph = CreateCastReluCastGraph4(); - TestPass pass; - std::shared_ptr ops_kernel = std::make_shared(); - - std::shared_ptr base = ops_kernel; - fe::Status status = pass.Run(*graph, base); - EXPECT_EQ(fe::NOT_CHANGED, status); - -} -TEST_F(UTESTGraphFusionPass, cast_relu_cast_05) -{ - ComputeGraphPtr graph = CreateCastReluCastGraph5(); - TestPass pass; - DumpGraph(graph, "test1"); - std::shared_ptr ops_kernel = std::make_shared(); - - std::shared_ptr base = ops_kernel; - fe::Status status = pass.Run(*graph, base); - EXPECT_EQ(fe::SUCCESS, status); - DumpGraph(graph, "test1"); - -} -TEST_F(UTESTGraphFusionPass, cast_relu_cast_06) -{ - - ComputeGraphPtr graph = CreateCastReluCastGraph6(); - TestPass pass; - std::shared_ptr ops_kernel = std::make_shared(); - - std::shared_ptr base = ops_kernel; - fe::Status status = pass.Run(*graph, base); - EXPECT_EQ(fe::NOT_CHANGED, status); - -} - -TEST_F(UTESTGraphFusionPass, coverage_01) { - REGISTER_PASS(pass_name_test, GRAPH_FUSION_PASS_TYPE_RESERVED, TestPass); - REGISTER_PASS("", BUILT_IN_GRAPH_PASS, TestPass); - REGISTER_PASS(pass_name_test, BUILT_IN_GRAPH_PASS, TestPass); - REGISTER_PASS(pass_name_test, BUILT_IN_GRAPH_PASS, TestPass); - std::map create_fns = - FusionPassRegistry::GetInstance().GetCreateFnByType(SECOND_ROUND_BUILT_IN_GRAPH_PASS); - EXPECT_NO_THROW( - create_fns = - FusionPassRegistry::GetInstance().GetCreateFnByType(BUILT_IN_GRAPH_PASS); - ); -} -} diff --git a/tests/ut/register/testcase/register_graph_fusion_2.cc b/tests/ut/register/testcase/register_graph_fusion_2.cc deleted file mode 100644 index 4816aac5d0fa338215dc13a0f6124e2adfc943b8..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/register_graph_fusion_2.cc +++ /dev/null @@ -1,862 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include "graph/compute_graph.h" -#include "graph/graph.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/graph_utils.h" -#include "register/graph_optimizer/graph_fusion/fusion_pass_manager/fusion_pass_registry.h" -#include "register/graph_optimizer/graph_fusion/graph_fusion_pass_base.h" -#include "register/graph_optimizer/fusion_common/pattern_fusion_base_pass.h" -#include "graph/debug/ge_log.h" -/** - * Input Input - * | | - * switch Switch - * / \ / \ - * A B A B - * \ / -> | | - * Merge Cast Cast - * | \ / - * Cast Merge - * | | - * NetOutput NetOutPut - */ -namespace fe { -using Mapping = std::map, std::vector, fe::CmpKey>; -class SwapMergeCastFusionTestPass : public PatternFusionBasePass { - protected: - vector DefinePatterns() override; - - vector DefineInnerPatterns() override; - - Status Fusion(ge::ComputeGraph &graph, Mapping &mapping, vector &new_nodes) override; - - private: - Status VerifyNodes(const ge::NodePtr &merge_node, ge::NodePtr &cast_node, ge::NodePtr &netout_node) const; - - Status RelinkMergeNode(const ge::NodePtr &merge_node, const ge::NodePtr &cast_node, const ge::NodePtr &netout_node); - - Status AddCastNodeBeforeMergeNode(const ge::NodePtr &merge_node, ge::OpDescPtr &cast_op_desc, - ge::ComputeGraph &graph); -}; - -static const string SWAPMERGECAST_PASS_NAME = "SwapMergeCastFusionPass"; -static const string PATTERN_MERGE = "Pattern_Merge"; -static const string PATTERN_CAST = "Pattern_Cast"; -static const string PATTERN_RELU = "Pattern_Relu"; -static const string OP_TYPE_MERGE = "Merge"; -static const string OP_TYPE_CAST = "Cast"; -static const string OP_TYPE_RELU = "Relu"; -static const string OP_TYPE_NETOUTPUT = "NetOutput"; - -vector SwapMergeCastFusionTestPass::DefinePatterns() { - vector patterns; - - FusionPattern *pattern = new(std::nothrow) FusionPattern("SwapMergeCastFusionPattern"); - - pattern->AddOpDesc(PATTERN_MERGE, {OP_TYPE_MERGE}) - .AddOpDesc(PATTERN_CAST, {OP_TYPE_CAST}) - .SetInputs(PATTERN_CAST, {PATTERN_MERGE}) - .SetOutput(PATTERN_CAST); - - patterns.push_back(pattern); - - return patterns; -} - -vector SwapMergeCastFusionTestPass::DefineInnerPatterns() { - vector patterns; - - FusionPattern *pattern = new(std::nothrow) FusionPattern("SwapMergeCastFusionInnerPattern"); - - pattern->AddOpDesc(PATTERN_RELU, {OP_TYPE_RELU}) - .AddOpDesc(PATTERN_CAST, {OP_TYPE_CAST}) - .AddOpDesc(PATTERN_MERGE, {OP_TYPE_MERGE}) - .SetInputs(PATTERN_CAST, {PATTERN_RELU}) - .SetInputs(PATTERN_MERGE, {PATTERN_CAST}) - .SetOutput(PATTERN_MERGE); - - patterns.push_back(pattern); - return patterns; -} - -Status SwapMergeCastFusionTestPass::Fusion(ge::ComputeGraph &graph, Mapping &mapping, vector &new_nodes) { - ge::NodePtr merge_node = GetNodeFromMapping(PATTERN_MERGE, mapping); - ge::NodePtr cast_node = GetNodeFromMapping(PATTERN_CAST, mapping); - CheckOpSupported(merge_node); - ge::NodePtr netout_node = nullptr; - Status verify_status = VerifyNodes(merge_node, cast_node, netout_node); - if (verify_status != SUCCESS) { - return verify_status; - } - - // unlink cast node and link merge node to netoutput node - Status status = RelinkMergeNode(merge_node, cast_node, netout_node); - if (status != SUCCESS) { - return status; - } - - // add cast node for each input data anchor of merge node - ge::OpDescPtr cast_op_desc = cast_node->GetOpDesc(); - status = AddCastNodeBeforeMergeNode(merge_node, cast_op_desc, graph); - if (status != SUCCESS) { - return status; - } - - if (graph.RemoveNode(cast_node) != ge::GRAPH_SUCCESS) { - return FAILED; - } - - return SUCCESS; -} - -#define UT_CHECK(cond, log_func, return_expr) \ - do { \ - if (cond) { \ - log_func; \ - return_expr; \ - } \ - } while (0) - -#define UT_CHECK_NOTNULL(val) \ - do { \ - if ((val) == nullptr) { \ - GE_LOGE("Parameter[%s] must not be null.", #val); \ - return fe::PARAM_INVALID; \ - } \ - } while (0) - - -Status -SwapMergeCastFusionTestPass::AddCastNodeBeforeMergeNode(const ge::NodePtr &merge_node, - ge::OpDescPtr &cast_op_desc, - ge::ComputeGraph &graph) { - ge::OpDescPtr merge_op_desc = merge_node->GetOpDesc(); - ge::DataType cast_out_d_type = cast_op_desc->MutableOutputDesc(0)->GetDataType(); - merge_op_desc->MutableOutputDesc(0)->SetDataType(cast_out_d_type); - - size_t input_size = merge_op_desc->GetAllInputsSize(); - for (size_t i = 0; i < input_size; i++) { - ge::InDataAnchorPtr in_data_anchor = merge_node->GetInDataAnchor(i); - if (in_data_anchor == nullptr || in_data_anchor->GetPeerOutAnchor() == nullptr) { - GELOGD("InData Anchor[%zu] of merge node[%s] is not linked.", i, merge_node->GetName().c_str()); - continue; - } - - // update data Type of each input tensor desc of merge node - ge::GeTensorDescPtr in_data_desc = merge_op_desc->MutableInputDesc(i); - if (in_data_desc == nullptr) { - GELOGD("In data desc[%zu] is null.", i); - continue; - } - in_data_desc->SetDataType(cast_out_d_type); - - // copy cast op desc and update the shape of input and output - ge::OpDescPtr new_cast_op_desc = ge::OpDescUtils::CopyOpDesc(cast_op_desc); - UT_CHECK(new_cast_op_desc == nullptr, - GE_LOGE("[GraphOpt][SwapMrgCastFus][AddCastNd] Fail to copy op desc for cast node[%s].", - cast_op_desc->GetName().c_str()), - return FAILED); - - new_cast_op_desc->SetName(cast_op_desc->GetName() + std::to_string(i)); - new_cast_op_desc->MutableInputDesc(0)->SetShape(in_data_desc->GetShape()); - new_cast_op_desc->MutableOutputDesc(0)->SetShape(in_data_desc->GetShape()); - - ge::NodePtr new_cast_node = graph.AddNode(new_cast_op_desc); - UT_CHECK(new_cast_node == nullptr, - GE_LOGE("[GraphOpt][SwapMrgCastFus][AddCastNd] Fail to add cast node[%s] to graph.", - new_cast_op_desc->GetName().c_str()), - return FAILED); - - ge::OutDataAnchorPtr out_data_anchor = in_data_anchor->GetPeerOutAnchor(); - UT_CHECK_NOTNULL(out_data_anchor); - // unlink the indata anchor of merge node - in_data_anchor->UnlinkAll(); - if (ge::GraphUtils::AddEdge(out_data_anchor, new_cast_node->GetInDataAnchor(0)) != ge::GRAPH_SUCCESS) { - GE_LOGE("[GraphOpt][SwapMrgCastFus][AddCastNd] Fail to link in_data_anchor of cast node[%s].", - new_cast_node->GetName().c_str()); - return FAILED; - } - if (ge::GraphUtils::AddEdge(new_cast_node->GetOutDataAnchor(0), in_data_anchor) != ge::GRAPH_SUCCESS) { - GE_LOGE( - "[GraphOpt][SwapMrgCastFus][AddCastNd] Fail to link in_data_anchor[%zu] of merge node[%s]" - " with cast node.", - i, merge_node->GetName().c_str()); - return FAILED; - } - } - - return SUCCESS; -} - -Status SwapMergeCastFusionTestPass::RelinkMergeNode(const ge::NodePtr &merge_node, const ge::NodePtr &cast_node, - const ge::NodePtr &netout_node) { - ge::InDataAnchorPtr netout_in_data_anchor = cast_node->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0); - cast_node->GetInDataAnchor(0)->UnlinkAll(); - cast_node->GetOutDataAnchor(0)->UnlinkAll(); - - // if cast node has in control anchors, then link them to netoutputnode - if (cast_node->GetInControlAnchor() != nullptr) { - if (cast_node->GetInControlAnchor()->GetPeerOutControlAnchors().size() > 0 && - netout_node->GetInControlAnchor() != nullptr) { - for (ge::OutControlAnchorPtr out_control_anchor : cast_node->GetInControlAnchor()->GetPeerOutControlAnchors()) { - if (ge::GraphUtils::AddEdge(out_control_anchor, netout_node->GetInControlAnchor()) != ge::GRAPH_SUCCESS) { - GE_LOGE( - "[GraphOpt][SwapMrgCastFus][RelkMrgNd] Fail to link control edge between netoutput node[%s]" - " and peer out control anchor of cast node[%s].", - netout_node->GetName().c_str(), cast_node->GetName().c_str()); - return FAILED; - } - } - } - cast_node->GetInControlAnchor()->UnlinkAll(); - } - - // usually cast node do not have any output control anchor - // if cast node has output control anchors, unlink them - if (cast_node->GetOutControlAnchor() != nullptr) { - cast_node->GetOutControlAnchor()->UnlinkAll(); - } - - if (ge::GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), netout_in_data_anchor) != ge::GRAPH_SUCCESS) { - GE_LOGE("[GraphOpt][SwapMrgCastFus][RelkMrgNd] Fail to link the output data anchor of merge node[%s].", - merge_node->GetName().c_str()); - return FAILED; - } - - return SUCCESS; -} - -Status SwapMergeCastFusionTestPass::VerifyNodes(const ge::NodePtr &merge_node, - ge::NodePtr &cast_node, ge::NodePtr &netout_node) const { - UT_CHECK(merge_node == nullptr, GE_LOGE("[GraphOpt][SwapMrgCastFus][VerifyNd] Merge node is nullptr."), - return PARAM_INVALID); - - UT_CHECK(cast_node == nullptr, GE_LOGE("[GraphOpt][SwapMrgCastFus][VerifyNd] Cast node is nullptr."), - return PARAM_INVALID); - - // merge node has two outputs, first output must has only one peer in anchor - if (merge_node->GetOutDataAnchor(0)->GetPeerInDataAnchors().size() > 1) { - GELOGD( - "The first output anchor of Merge node[%s]" - " has more than one peer in anchor.", - merge_node->GetName().c_str()); - return NOT_CHANGED; - } - - // cast node must have only one output node - if (cast_node->GetOutDataNodesSize() != 1) { - GELOGD("Cast node[%s] has more than one out data nodes.", cast_node->GetName().c_str()); - return NOT_CHANGED; - } - - netout_node = cast_node->GetOutDataNodes().at(0); - UT_CHECK_NOTNULL(netout_node); - if (netout_node->GetType() != OP_TYPE_NETOUTPUT) { - GELOGD("The next node of cast node[%s] is not NetOutput.", cast_node->GetName().c_str()); - return NOT_CHANGED; - } - - return SUCCESS; -} - -using namespace ge; -using namespace fe; - -class UTESTGraphFusionPass2 : public testing::Test { - protected: - void SetUp() { - } - - void TearDown() { - - } - - protected: - static ComputeGraphPtr CreateSwapMergeCastGraph1() { - ComputeGraphPtr graph = std::make_shared("test1"); - ge::OpDescPtr op_desc_switch = std::make_shared("switch", "Switch"); - ge::OpDescPtr op_desc_relu1 = std::make_shared("relu1", "Relu"); - ge::OpDescPtr op_desc_relu2 = std::make_shared("relu2", "Relu"); - ge::OpDescPtr op_desc_merge = std::make_shared("merge", "Merge"); - ge::OpDescPtr op_desc_cast = std::make_shared("cast", "Cast"); - ge::OpDescPtr op_desc_netoutput = std::make_shared("netoutput", "NetOutput"); - ge::OpDescPtr op_desc_other = std::make_shared("other", "Other"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT16); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - vector dim_b = {1, 4, 64, 64}; - GeShape shape_b(dim_b); - GeTensorDesc tensor_desc_b(shape_b); - tensor_desc_b.SetFormat(FORMAT_NCHW); - tensor_desc_b.SetOriginFormat(FORMAT_NCHW); - tensor_desc_b.SetDataType(DT_FLOAT16); - tensor_desc_b.SetOriginDataType(DT_FLOAT); - - vector dim_c = {8, 4, 64, 64}; - GeShape shape_c(dim_c); - GeTensorDesc tensor_desc_c(shape_c); - tensor_desc_c.SetFormat(FORMAT_NCHW); - tensor_desc_c.SetOriginFormat(FORMAT_NCHW); - tensor_desc_c.SetDataType(DT_FLOAT16); - tensor_desc_c.SetOriginDataType(DT_FLOAT); - - GeTensorDesc tensor_desc_d(shape_c); - tensor_desc_d.SetFormat(FORMAT_NCHW); - tensor_desc_d.SetOriginFormat(FORMAT_NCHW); - tensor_desc_d.SetDataType(DT_FLOAT); - tensor_desc_d.SetOriginDataType(DT_FLOAT); - - vector dim_e; - GeShape shape_e(dim_e); - GeTensorDesc tensor_desc_e(shape_e); - tensor_desc_e.SetFormat(FORMAT_ND); - tensor_desc_e.SetOriginFormat(FORMAT_ND); - tensor_desc_e.SetDataType(DT_INT32); - tensor_desc_e.SetOriginDataType(DT_INT32); - - op_desc_switch->AddOutputDesc(tensor_desc_a); - op_desc_switch->AddOutputDesc(tensor_desc_b); - - op_desc_relu1->AddInputDesc(tensor_desc_a); - op_desc_relu1->AddOutputDesc(tensor_desc_a); - - op_desc_relu2->AddInputDesc(tensor_desc_b); - op_desc_relu2->AddOutputDesc(tensor_desc_b); - - op_desc_merge->AddInputDesc(tensor_desc_a); - op_desc_merge->AddInputDesc(tensor_desc_b); - op_desc_merge->AddOutputDesc(tensor_desc_c); - op_desc_merge->AddOutputDesc(tensor_desc_e); - - op_desc_other->AddInputDesc(tensor_desc_e); - - op_desc_cast->AddInputDesc(tensor_desc_c); - op_desc_cast->AddOutputDesc(tensor_desc_d); - - op_desc_netoutput->AddInputDesc(tensor_desc_d); - - ge::NodePtr node_switch = graph->AddNode(op_desc_switch); - ge::NodePtr node_relu1 = graph->AddNode(op_desc_relu1); - ge::NodePtr node_relu2 = graph->AddNode(op_desc_relu2); - ge::NodePtr node_merge = graph->AddNode(op_desc_merge); - ge::NodePtr node_cast = graph->AddNode(op_desc_cast); - ge::NodePtr node_netoutput = graph->AddNode(op_desc_netoutput); - ge::NodePtr node_other = graph->AddNode(op_desc_other); - - ge::GraphUtils::AddEdge(node_switch->GetOutDataAnchor(0), node_relu1->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_switch->GetOutDataAnchor(1), node_relu2->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_relu1->GetOutDataAnchor(0), node_merge->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_relu2->GetOutDataAnchor(0), node_merge->GetInDataAnchor(1)); - ge::GraphUtils::AddEdge(node_merge->GetOutDataAnchor(0), node_cast->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_merge->GetOutDataAnchor(1), node_other->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_cast->GetOutDataAnchor(0), node_netoutput->GetInDataAnchor(0)); - - return graph; - } - - static ComputeGraphPtr CreateSwapMergeCastGraph2() { - ComputeGraphPtr graph = CreateSwapMergeCastGraph1(); - ge::OpDescPtr op_desc_some = std::make_shared("some_node", "Some"); - vector dim = {8, 4, 64, 64}; - GeShape shape(dim); - GeTensorDesc tensor_desc(shape); - tensor_desc.SetFormat(FORMAT_NCHW); - tensor_desc.SetOriginFormat(FORMAT_NCHW); - tensor_desc.SetDataType(DT_FLOAT16); - tensor_desc.SetOriginDataType(DT_FLOAT); - op_desc_some->AddInputDesc(tensor_desc); - op_desc_some->AddOutputDesc(tensor_desc); - - ge::NodePtr node_some = graph->AddNode(op_desc_some); - - for (ge::NodePtr node : graph->GetDirectNode()) { - if (node->GetType() == "Merge") { - ge::GraphUtils::AddEdge(node->GetOutDataAnchor(0), node_some->GetInDataAnchor(0)); - } - } - return graph; - } - - static ComputeGraphPtr CreateSwapMergeCastGraph3() { - ComputeGraphPtr graph = CreateSwapMergeCastGraph1(); - ge::OpDescPtr op_desc_some = std::make_shared("some_node", "Some"); - vector dim = {8, 4, 64, 64}; - GeShape shape(dim); - GeTensorDesc tensor_desc(shape); - tensor_desc.SetFormat(FORMAT_NCHW); - tensor_desc.SetOriginFormat(FORMAT_NCHW); - tensor_desc.SetDataType(DT_FLOAT); - tensor_desc.SetOriginDataType(DT_FLOAT); - op_desc_some->AddInputDesc(tensor_desc); - op_desc_some->AddOutputDesc(tensor_desc); - - ge::NodePtr node_some = graph->AddNode(op_desc_some); - - for (ge::NodePtr node : graph->GetDirectNode()) { - if (node->GetType() == "Cast") { - ge::GraphUtils::AddEdge(node->GetOutDataAnchor(0), node_some->GetInDataAnchor(0)); - } - } - return graph; - } - - static ComputeGraphPtr CreateSwapMergeCastGraph4() { - ComputeGraphPtr graph = CreateSwapMergeCastGraph1(); - - for (ge::NodePtr node : graph->GetDirectNode()) { - if (node->GetType() == "NetOutput") { - node->GetOpDesc()->SetType("NetOut"); - } - } - return graph; - } - - static ComputeGraphPtr CreateSwapMergeCastGraph5() { - ComputeGraphPtr graph = std::make_shared("test1"); - ge::OpDescPtr op_desc_switch = std::make_shared("switch", "Switch"); - ge::OpDescPtr op_desc_relu1 = std::make_shared("relu1", "Relu"); - ge::OpDescPtr op_desc_relu2 = std::make_shared("relu2", "Relu"); - ge::OpDescPtr op_desc_merge = std::make_shared("merge", "Merge"); - ge::OpDescPtr op_desc_cast = std::make_shared("cast", "Cast"); - ge::OpDescPtr op_desc_netoutput = std::make_shared("netoutput", "NetOutput"); - ge::OpDescPtr op_desc_other = std::make_shared("other", "Other"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT16); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - vector dim_b = {1, 4, 64, 64}; - GeShape shape_b(dim_b); - GeTensorDesc tensor_desc_b(shape_b); - tensor_desc_b.SetFormat(FORMAT_NCHW); - tensor_desc_b.SetOriginFormat(FORMAT_NCHW); - tensor_desc_b.SetDataType(DT_FLOAT16); - tensor_desc_b.SetOriginDataType(DT_FLOAT); - - vector dim_c = {8, 4, 64, 64}; - GeShape shape_c(dim_c); - GeTensorDesc tensor_desc_c(shape_c); - tensor_desc_c.SetFormat(FORMAT_NCHW); - tensor_desc_c.SetOriginFormat(FORMAT_NCHW); - tensor_desc_c.SetDataType(DT_FLOAT16); - tensor_desc_c.SetOriginDataType(DT_FLOAT); - - GeTensorDesc tensor_desc_d(shape_c); - tensor_desc_d.SetFormat(FORMAT_NCHW); - tensor_desc_d.SetOriginFormat(FORMAT_NCHW); - tensor_desc_d.SetDataType(DT_FLOAT); - tensor_desc_d.SetOriginDataType(DT_FLOAT); - - vector dim_e; - GeShape shape_e(dim_e); - GeTensorDesc tensor_desc_e(shape_e); - tensor_desc_e.SetFormat(FORMAT_ND); - tensor_desc_e.SetOriginFormat(FORMAT_ND); - tensor_desc_e.SetDataType(DT_INT32); - tensor_desc_e.SetOriginDataType(DT_INT32); - - op_desc_switch->AddOutputDesc(tensor_desc_a); - op_desc_switch->AddOutputDesc(tensor_desc_b); - - op_desc_relu1->AddInputDesc(tensor_desc_a); - op_desc_relu1->AddOutputDesc(tensor_desc_a); - - op_desc_relu2->AddInputDesc(tensor_desc_b); - op_desc_relu2->AddOutputDesc(tensor_desc_b); - - op_desc_merge->AddInputDesc(tensor_desc_a); - op_desc_merge->AddInputDesc(tensor_desc_b); - op_desc_merge->AddOutputDesc(tensor_desc_c); - op_desc_merge->AddOutputDesc(tensor_desc_e); - - op_desc_other->AddInputDesc(tensor_desc_e); - - op_desc_cast->AddInputDesc(tensor_desc_c); - op_desc_cast->AddOutputDesc(tensor_desc_d); - - op_desc_netoutput->AddInputDesc(tensor_desc_d); - - ge::NodePtr node_switch = graph->AddNode(op_desc_switch); - ge::NodePtr node_relu1 = graph->AddNode(op_desc_relu1); - ge::NodePtr node_relu2 = graph->AddNode(op_desc_relu2); - ge::NodePtr node_merge = graph->AddNode(op_desc_merge); - ge::NodePtr node_cast = graph->AddNode(op_desc_cast); - ge::NodePtr node_netoutput = graph->AddNode(op_desc_netoutput); - ge::NodePtr node_other = graph->AddNode(op_desc_other); - - ge::GraphUtils::AddEdge(node_switch->GetOutDataAnchor(0), node_relu1->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_switch->GetOutDataAnchor(1), node_relu2->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_relu1->GetOutDataAnchor(0), node_merge->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_relu2->GetOutDataAnchor(0), node_merge->GetInDataAnchor(1)); - ge::GraphUtils::AddEdge(node_merge->GetOutDataAnchor(0), node_cast->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_merge->GetOutDataAnchor(1), node_other->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_cast->GetOutDataAnchor(0), node_netoutput->GetInDataAnchor(0)); - - ge::GraphUtils::AddEdge(node_relu1->GetOutControlAnchor(), node_cast->GetInControlAnchor()); - ge::GraphUtils::AddEdge(node_relu2->GetOutControlAnchor(), node_cast->GetInControlAnchor()); - return graph; - } - - static ComputeGraphPtr CreateSwapMergeCastGraph6() { - ComputeGraphPtr graph = std::make_shared("test1"); - ge::OpDescPtr op_desc_switch1 = std::make_shared("switch1", "Switch"); - ge::OpDescPtr op_desc_relu1 = std::make_shared("relu1", "Relu"); - ge::OpDescPtr op_desc_relu2 = std::make_shared("relu2", "Relu"); - ge::OpDescPtr op_desc_merge1 = std::make_shared("merge1", "Merge"); - ge::OpDescPtr op_desc_cast1 = std::make_shared("cast1", "Cast"); - ge::OpDescPtr op_desc_other1 = std::make_shared("other1", "Other"); - - ge::OpDescPtr op_desc_switch2 = std::make_shared("switch2", "Switch"); - ge::OpDescPtr op_desc_relu3 = std::make_shared("relu3", "Relu"); - ge::OpDescPtr op_desc_relu4 = std::make_shared("relu4", "Relu"); - ge::OpDescPtr op_desc_merge2 = std::make_shared("merge2", "Merge"); - ge::OpDescPtr op_desc_cast2 = std::make_shared("cast2", "Cast"); - ge::OpDescPtr op_desc_other2 = std::make_shared("other2", "Other"); - - ge::OpDescPtr op_desc_netoutput = std::make_shared("netoutput", "NetOutput"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT16); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - vector dim_b = {1, 4, 64, 64}; - GeShape shape_b(dim_b); - GeTensorDesc tensor_desc_b(shape_b); - tensor_desc_b.SetFormat(FORMAT_NCHW); - tensor_desc_b.SetOriginFormat(FORMAT_NCHW); - tensor_desc_b.SetDataType(DT_FLOAT16); - tensor_desc_b.SetOriginDataType(DT_FLOAT); - - vector dim_c = {8, 4, 64, 64}; - GeShape shape_c(dim_c); - GeTensorDesc tensor_desc_c(shape_c); - tensor_desc_c.SetFormat(FORMAT_NCHW); - tensor_desc_c.SetOriginFormat(FORMAT_NCHW); - tensor_desc_c.SetDataType(DT_FLOAT16); - tensor_desc_c.SetOriginDataType(DT_FLOAT); - - GeTensorDesc tensor_desc_d(shape_c); - tensor_desc_d.SetFormat(FORMAT_NCHW); - tensor_desc_d.SetOriginFormat(FORMAT_NCHW); - tensor_desc_d.SetDataType(DT_FLOAT); - tensor_desc_d.SetOriginDataType(DT_FLOAT); - - vector dim_e; - GeShape shape_e(dim_e); - GeTensorDesc tensor_desc_e(shape_e); - tensor_desc_e.SetFormat(FORMAT_ND); - tensor_desc_e.SetOriginFormat(FORMAT_ND); - tensor_desc_e.SetDataType(DT_INT32); - tensor_desc_e.SetOriginDataType(DT_INT32); - - op_desc_switch1->AddOutputDesc(tensor_desc_a); - op_desc_switch1->AddOutputDesc(tensor_desc_b); - - op_desc_switch2->AddOutputDesc(tensor_desc_a); - op_desc_switch2->AddOutputDesc(tensor_desc_b); - - op_desc_relu1->AddInputDesc(tensor_desc_a); - op_desc_relu1->AddOutputDesc(tensor_desc_a); - op_desc_relu2->AddInputDesc(tensor_desc_b); - op_desc_relu2->AddOutputDesc(tensor_desc_b); - - op_desc_relu3->AddInputDesc(tensor_desc_a); - op_desc_relu3->AddOutputDesc(tensor_desc_a); - op_desc_relu4->AddInputDesc(tensor_desc_b); - op_desc_relu4->AddOutputDesc(tensor_desc_b); - - op_desc_merge1->AddInputDesc(tensor_desc_a); - op_desc_merge1->AddInputDesc(tensor_desc_b); - op_desc_merge1->AddOutputDesc(tensor_desc_c); - op_desc_merge1->AddOutputDesc(tensor_desc_e); - - op_desc_merge2->AddInputDesc(tensor_desc_a); - op_desc_merge2->AddInputDesc(tensor_desc_b); - op_desc_merge2->AddOutputDesc(tensor_desc_c); - op_desc_merge2->AddOutputDesc(tensor_desc_e); - - op_desc_other1->AddInputDesc(tensor_desc_e); - - op_desc_other2->AddInputDesc(tensor_desc_e); - - op_desc_cast1->AddInputDesc(tensor_desc_c); - op_desc_cast1->AddOutputDesc(tensor_desc_d); - - op_desc_cast2->AddInputDesc(tensor_desc_c); - op_desc_cast2->AddOutputDesc(tensor_desc_d); - - op_desc_netoutput->AddInputDesc(tensor_desc_d); - op_desc_netoutput->AddInputDesc(tensor_desc_d); - - ge::NodePtr node_switch1 = graph->AddNode(op_desc_switch1); - ge::NodePtr node_relu1 = graph->AddNode(op_desc_relu1); - ge::NodePtr node_relu2 = graph->AddNode(op_desc_relu2); - ge::NodePtr node_merge1 = graph->AddNode(op_desc_merge1); - ge::NodePtr node_cast1 = graph->AddNode(op_desc_cast1); - ge::NodePtr node_other1 = graph->AddNode(op_desc_other1); - - ge::NodePtr node_switch2 = graph->AddNode(op_desc_switch2); - ge::NodePtr node_relu3 = graph->AddNode(op_desc_relu3); - ge::NodePtr node_relu4 = graph->AddNode(op_desc_relu4); - ge::NodePtr node_merge2 = graph->AddNode(op_desc_merge2); - ge::NodePtr node_cast2 = graph->AddNode(op_desc_cast2); - ge::NodePtr node_other2 = graph->AddNode(op_desc_other2); - - ge::NodePtr node_netoutput = graph->AddNode(op_desc_netoutput); - - ge::GraphUtils::AddEdge(node_switch1->GetOutDataAnchor(0), node_relu1->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_switch1->GetOutDataAnchor(1), node_relu2->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_relu1->GetOutDataAnchor(0), node_merge1->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_relu2->GetOutDataAnchor(0), node_merge1->GetInDataAnchor(1)); - ge::GraphUtils::AddEdge(node_merge1->GetOutDataAnchor(0), node_cast1->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_merge1->GetOutDataAnchor(1), node_other1->GetInDataAnchor(0)); - - ge::GraphUtils::AddEdge(node_switch2->GetOutDataAnchor(0), node_relu3->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_switch2->GetOutDataAnchor(1), node_relu4->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_relu3->GetOutDataAnchor(0), node_merge2->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_relu4->GetOutDataAnchor(0), node_merge2->GetInDataAnchor(1)); - ge::GraphUtils::AddEdge(node_merge2->GetOutDataAnchor(0), node_cast2->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_merge2->GetOutDataAnchor(1), node_other2->GetInDataAnchor(0)); - - ge::GraphUtils::AddEdge(node_cast1->GetOutDataAnchor(0), node_netoutput->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_cast2->GetOutDataAnchor(0), node_netoutput->GetInDataAnchor(1)); - - return graph; - } -}; - -TEST_F(UTESTGraphFusionPass2, UTESTGraphFusionPass2_1) { - ComputeGraphPtr graph = CreateSwapMergeCastGraph1(); - - SwapMergeCastFusionTestPass pass; - Status status = fe::FAILED; - - - pass.SetName("test"); - status = pass.Run(*graph); - - EXPECT_EQ(fe::SUCCESS, status); - vector dim_a = {8, 4, 16, 16}; - vector dim_b = {1, 4, 64, 64}; - vector dim_c = {8, 4, 64, 64}; - for (ge::NodePtr node : graph->GetDirectNode()) { - ge::OpDescPtr op_desc = node->GetOpDesc(); - if (op_desc->GetType() == "Merge") { - EXPECT_EQ(op_desc->MutableInputDesc(0)->GetDataType(), DT_FLOAT); - EXPECT_EQ(op_desc->MutableInputDesc(0)->MutableShape().GetDims(), dim_a); - - EXPECT_EQ(op_desc->MutableInputDesc(1)->GetDataType(), DT_FLOAT); - EXPECT_EQ(op_desc->MutableInputDesc(1)->MutableShape().GetDims(), dim_b); - - EXPECT_EQ(op_desc->MutableOutputDesc(0)->GetDataType(), DT_FLOAT); - EXPECT_EQ(op_desc->MutableOutputDesc(0)->MutableShape().GetDims(), dim_c); - } - if (op_desc->GetType() == "Cast") { - EXPECT_EQ(op_desc->MutableInputDesc(0)->GetDataType(), DT_FLOAT16); - EXPECT_EQ(op_desc->MutableOutputDesc(0)->GetDataType(), DT_FLOAT); - } - if (op_desc->GetName() == "relu1") { - ge::NodePtr node_cast = node->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetOwnerNode(); - EXPECT_EQ(node_cast->GetType(), "Cast"); - ge::OpDescPtr op_desc_cast = node_cast->GetOpDesc(); - EXPECT_EQ(op_desc_cast->MutableInputDesc(0)->MutableShape().GetDims(), dim_a); - EXPECT_EQ(op_desc_cast->MutableOutputDesc(0)->MutableShape().GetDims(), dim_a); - } - if (op_desc->GetName() == "relu2") { - ge::NodePtr node_cast = node->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetOwnerNode(); - EXPECT_EQ(node_cast->GetType(), "Cast"); - ge::OpDescPtr op_desc_cast = node_cast->GetOpDesc(); - EXPECT_EQ(op_desc_cast->MutableInputDesc(0)->MutableShape().GetDims(), dim_b); - EXPECT_EQ(op_desc_cast->MutableOutputDesc(0)->MutableShape().GetDims(), dim_b); - } - } -} - -TEST_F(UTESTGraphFusionPass2, UTESTGraphFusionPass2_2) { - ComputeGraphPtr graph = CreateSwapMergeCastGraph2(); - - Status status = fe::FAILED; - SwapMergeCastFusionTestPass pass; - - pass.SetName("test"); - status = pass.Run(*graph); - - - EXPECT_EQ(fe::NOT_CHANGED, status); -} - -TEST_F(UTESTGraphFusionPass2, UTESTGraphFusionPass2_3) { - ComputeGraphPtr graph = CreateSwapMergeCastGraph3(); - - Status status = fe::FAILED; - - SwapMergeCastFusionTestPass pass; - - pass.SetName("test"); - status = pass.Run(*graph); - - EXPECT_EQ(fe::NOT_CHANGED, status); -} - -TEST_F(UTESTGraphFusionPass2, UTESTGraphFusionPass2_4) { - ComputeGraphPtr graph = CreateSwapMergeCastGraph4(); - - Status status = fe::FAILED; - SwapMergeCastFusionTestPass pass; - - pass.SetName("test"); - status = pass.Run(*graph); - EXPECT_EQ(fe::NOT_CHANGED, status); -} - -TEST_F(UTESTGraphFusionPass2, UTESTGraphFusionPass2_5) { - ComputeGraphPtr graph = CreateSwapMergeCastGraph5(); - Status status = fe::FAILED; - - SwapMergeCastFusionTestPass pass; - - pass.SetName("test"); - status = pass.Run(*graph); - - - EXPECT_EQ(fe::SUCCESS, status); - vector dim_a = {8, 4, 16, 16}; - vector dim_b = {1, 4, 64, 64}; - vector dim_c = {8, 4, 64, 64}; - for (ge::NodePtr node : graph->GetDirectNode()) { - ge::OpDescPtr op_desc = node->GetOpDesc(); - if (op_desc->GetType() == "Merge") { - EXPECT_EQ(op_desc->MutableInputDesc(0)->GetDataType(), DT_FLOAT); - EXPECT_EQ(op_desc->MutableInputDesc(0)->MutableShape().GetDims(), dim_a); - - EXPECT_EQ(op_desc->MutableInputDesc(1)->GetDataType(), DT_FLOAT); - EXPECT_EQ(op_desc->MutableInputDesc(1)->MutableShape().GetDims(), dim_b); - - EXPECT_EQ(op_desc->MutableOutputDesc(0)->GetDataType(), DT_FLOAT); - EXPECT_EQ(op_desc->MutableOutputDesc(0)->MutableShape().GetDims(), dim_c); - } - if (op_desc->GetType() == "Cast") { - EXPECT_EQ(op_desc->MutableInputDesc(0)->GetDataType(), DT_FLOAT16); - EXPECT_EQ(op_desc->MutableOutputDesc(0)->GetDataType(), DT_FLOAT); - } - if (op_desc->GetName() == "relu1") { - ge::NodePtr node_cast = node->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetOwnerNode(); - EXPECT_EQ(node_cast->GetType(), "Cast"); - ge::OpDescPtr op_desc_cast = node_cast->GetOpDesc(); - EXPECT_EQ(op_desc_cast->MutableInputDesc(0)->MutableShape().GetDims(), dim_a); - EXPECT_EQ(op_desc_cast->MutableOutputDesc(0)->MutableShape().GetDims(), dim_a); - EXPECT_EQ(node->GetOutControlAnchor()->GetPeerInControlAnchors().at(0)->GetOwnerNode()->GetType(), "NetOutput"); - } - if (op_desc->GetName() == "relu2") { - ge::NodePtr node_cast = node->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetOwnerNode(); - EXPECT_EQ(node_cast->GetType(), "Cast"); - ge::OpDescPtr op_desc_cast = node_cast->GetOpDesc(); - EXPECT_EQ(op_desc_cast->MutableInputDesc(0)->MutableShape().GetDims(), dim_b); - EXPECT_EQ(op_desc_cast->MutableOutputDesc(0)->MutableShape().GetDims(), dim_b); - EXPECT_EQ(node->GetOutControlAnchor()->GetPeerInControlAnchors().at(0)->GetOwnerNode()->GetType(), "NetOutput"); - } - } -} - - -TEST_F(UTESTGraphFusionPass2, UTESTGraphFusionPass2_6) { - ComputeGraphPtr graph = CreateSwapMergeCastGraph6(); - - Status status = fe::FAILED; - - SwapMergeCastFusionTestPass pass; - - pass.SetName("test"); - status = pass.Run(*graph); - - - EXPECT_EQ(fe::SUCCESS, status); - vector dim_a = {8, 4, 16, 16}; - vector dim_b = {1, 4, 64, 64}; - vector dim_c = {8, 4, 64, 64}; - for (ge::NodePtr node : graph->GetDirectNode()) { - ge::OpDescPtr op_desc = node->GetOpDesc(); - if (op_desc->GetType() == "Merge") { - EXPECT_EQ(op_desc->MutableInputDesc(0)->GetDataType(), DT_FLOAT); - EXPECT_EQ(op_desc->MutableInputDesc(0)->MutableShape().GetDims(), dim_a); - - EXPECT_EQ(op_desc->MutableInputDesc(1)->GetDataType(), DT_FLOAT); - EXPECT_EQ(op_desc->MutableInputDesc(1)->MutableShape().GetDims(), dim_b); - - EXPECT_EQ(op_desc->MutableOutputDesc(0)->GetDataType(), DT_FLOAT); - EXPECT_EQ(op_desc->MutableOutputDesc(0)->MutableShape().GetDims(), dim_c); - - EXPECT_EQ(node->GetOutDataNodes().at(0)->GetType(), "NetOutput"); - } - if (op_desc->GetType() == "Cast") { - EXPECT_EQ(op_desc->MutableInputDesc(0)->GetDataType(), DT_FLOAT16); - EXPECT_EQ(op_desc->MutableOutputDesc(0)->GetDataType(), DT_FLOAT); - } - if (op_desc->GetName() == "relu1" || op_desc->GetName() == "relu3") { - ge::NodePtr node_cast = node->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetOwnerNode(); - EXPECT_EQ(node_cast->GetType(), "Cast"); - ge::OpDescPtr op_desc_cast = node_cast->GetOpDesc(); - EXPECT_EQ(op_desc_cast->MutableInputDesc(0)->MutableShape().GetDims(), dim_a); - EXPECT_EQ(op_desc_cast->MutableOutputDesc(0)->MutableShape().GetDims(), dim_a); - } - if (op_desc->GetName() == "relu2" || op_desc->GetName() == "relu4") { - ge::NodePtr node_cast = node->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetOwnerNode(); - EXPECT_EQ(node_cast->GetType(), "Cast"); - ge::OpDescPtr op_desc_cast = node_cast->GetOpDesc(); - EXPECT_EQ(op_desc_cast->MutableInputDesc(0)->MutableShape().GetDims(), dim_b); - EXPECT_EQ(op_desc_cast->MutableOutputDesc(0)->MutableShape().GetDims(), dim_b); - } - } -} - -TEST_F(UTESTGraphFusionPass2, get_and_match_inner_pattern) { - ComputeGraphPtr graph = CreateSwapMergeCastGraph6(); - Status status = fe::FAILED; - SwapMergeCastFusionTestPass pass; - pass.SetName("test"); - status = pass.Run(*graph); - EXPECT_EQ(fe::SUCCESS, status); - - const auto &inner_patterns = pass.GetInnerPatterns(); - EXPECT_EQ(inner_patterns.size(), 1); - - bool inner_pattern_check_flag = false; - for (ge::NodePtr node : graph->GetDirectNode()) { - ge::OpDescPtr op_desc = node->GetOpDesc(); - if (op_desc->GetType() == "Merge") { - std::shared_ptr output_op_desc = inner_patterns.at(0)->GetOutput(); - ASSERT_NE(output_op_desc, nullptr); - fe::PatternFusionBasePass::Mapping mapping; - inner_pattern_check_flag = pass.MatchFromOutput(node, output_op_desc, mapping); - EXPECT_EQ(inner_pattern_check_flag, true); - } - } -} -} diff --git a/tests/ut/register/testcase/register_graph_fusion_3.cc b/tests/ut/register/testcase/register_graph_fusion_3.cc deleted file mode 100644 index 49604145e84257a5cfbeb28ac3c75999d5703901..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/register_graph_fusion_3.cc +++ /dev/null @@ -1,838 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include "graph/compute_graph.h" -#include "graph/graph.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/graph_utils.h" -#include "register/graph_optimizer/graph_fusion/fusion_pass_manager/fusion_pass_registry.h" -#include "register/graph_optimizer/graph_fusion/graph_fusion_pass_base.h" -#include "register/graph_optimizer/fusion_common/pattern_fusion_base_pass.h" -#include "register/graph_optimizer/fusion_common/graph_pass_util.h" -#include "graph/debug/ge_log.h" -/** - * Input Input - * | | - * switch Switch - * / \ / \ - * A B A B - * \ / -> | | - * Merge Cast Cast - * | \ / - * Cast Merge - * | | - * NetOutput NetOutPut - */ -namespace fe { -using Mapping = std::map, std::vector, fe::CmpKey>; -class SwapMergeCastFusionTestPass3 : public GraphFusionPassBase { - protected: - vector DefinePatterns() override; - - Status Fusion(ge::ComputeGraph &graph, Mapping &mapping, vector &new_nodes) override; - - private: - Status VerifyNodes(const ge::NodePtr &merge_node, ge::NodePtr &cast_node, ge::NodePtr &netout_node) const; - - Status RelinkMergeNode(const ge::NodePtr &merge_node, const ge::NodePtr &cast_node, const ge::NodePtr &netout_node); - - Status AddCastNodeBeforeMergeNode(const ge::NodePtr &merge_node, ge::OpDescPtr &cast_op_desc, - ge::ComputeGraph &graph); -}; - -static const string SWAPMERGECAST_PASS_NAME = "SwapMergeCastFusionPass"; -static const string PATTERN_MERGE = "Pattern_Merge"; -static const string PATTERN_CAST = "Pattern_Cast"; -static const string OP_TYPE_MERGE = "Merge"; -static const string OP_TYPE_CAST = "Cast"; -static const string OP_TYPE_NETOUTPUT = "NetOutput"; - -vector SwapMergeCastFusionTestPass3::DefinePatterns() { - vector patterns; - - FusionPattern *pattern = new(std::nothrow) FusionPattern("SwapMergeCastFusionPattern"); - - pattern->AddOpDesc(PATTERN_MERGE, {OP_TYPE_MERGE}) - .AddOpDesc(PATTERN_CAST, {OP_TYPE_CAST}) - .SetInputs(PATTERN_CAST, {PATTERN_MERGE}) - .SetOutput(PATTERN_CAST); - - patterns.push_back(pattern); - - return patterns; -} - -Status SwapMergeCastFusionTestPass3::Fusion(ge::ComputeGraph &graph, Mapping &mapping, vector &new_nodes) { - ge::NodePtr merge_node = GetNodeFromMapping(PATTERN_MERGE, mapping); - ge::NodePtr cast_node = GetNodeFromMapping(PATTERN_CAST, mapping); - ge::NodePtr netout_node = nullptr; - Status verify_status = VerifyNodes(merge_node, cast_node, netout_node); - if (verify_status != SUCCESS) { - return verify_status; - } - - // unlink cast node and link merge node to netoutput node - Status status = RelinkMergeNode(merge_node, cast_node, netout_node); - if (status != SUCCESS) { - return status; - } - - // add cast node for each input data anchor of merge node - ge::OpDescPtr cast_op_desc = cast_node->GetOpDesc(); - status = AddCastNodeBeforeMergeNode(merge_node, cast_op_desc, graph); - if (status != SUCCESS) { - return status; - } - - if (graph.RemoveNode(cast_node) != ge::GRAPH_SUCCESS) { - return FAILED; - } - - return SUCCESS; -} - -#define UT_CHECK(cond, log_func, return_expr) \ - do { \ - if (cond) { \ - log_func; \ - return_expr; \ - } \ - } while (0) - -#define UT_CHECK_NOTNULL(val) \ - do { \ - if ((val) == nullptr) { \ - GE_LOGE("Parameter[%s] must not be null.", #val); \ - return fe::PARAM_INVALID; \ - } \ - } while (0) - - -Status -SwapMergeCastFusionTestPass3::AddCastNodeBeforeMergeNode(const ge::NodePtr &merge_node, - ge::OpDescPtr &cast_op_desc, - ge::ComputeGraph &graph) { - ge::OpDescPtr merge_op_desc = merge_node->GetOpDesc(); - ge::DataType cast_out_d_type = cast_op_desc->MutableOutputDesc(0)->GetDataType(); - merge_op_desc->MutableOutputDesc(0)->SetDataType(cast_out_d_type); - - size_t input_size = merge_op_desc->GetAllInputsSize(); - for (size_t i = 0; i < input_size; i++) { - ge::InDataAnchorPtr in_data_anchor = merge_node->GetInDataAnchor(i); - if (in_data_anchor == nullptr || in_data_anchor->GetPeerOutAnchor() == nullptr) { - GELOGD("InData Anchor[%zu] of merge node[%s] is not linked.", i, merge_node->GetName().c_str()); - continue; - } - - // update data Type of each input tensor desc of merge node - ge::GeTensorDescPtr in_data_desc = merge_op_desc->MutableInputDesc(i); - if (in_data_desc == nullptr) { - GELOGD("In data desc[%zu] is null.", i); - continue; - } - in_data_desc->SetDataType(cast_out_d_type); - - // copy cast op desc and update the shape of input and output - ge::OpDescPtr new_cast_op_desc = ge::OpDescUtils::CopyOpDesc(cast_op_desc); - UT_CHECK(new_cast_op_desc == nullptr, - GE_LOGE("[GraphOpt][SwapMrgCastFus][AddCastNd] Fail to copy op desc for cast node[%s].", - cast_op_desc->GetName().c_str()), - return FAILED); - - new_cast_op_desc->SetName(cast_op_desc->GetName() + std::to_string(i)); - new_cast_op_desc->MutableInputDesc(0)->SetShape(in_data_desc->GetShape()); - new_cast_op_desc->MutableOutputDesc(0)->SetShape(in_data_desc->GetShape()); - - ge::NodePtr new_cast_node = graph.AddNode(new_cast_op_desc); - UT_CHECK(new_cast_node == nullptr, - GE_LOGE("[GraphOpt][SwapMrgCastFus][AddCastNd] Fail to add cast node[%s] to graph.", - new_cast_op_desc->GetName().c_str()), - return FAILED); - - ge::OutDataAnchorPtr out_data_anchor = in_data_anchor->GetPeerOutAnchor(); - UT_CHECK_NOTNULL(out_data_anchor); - // unlink the indata anchor of merge node - in_data_anchor->UnlinkAll(); - if (ge::GraphUtils::AddEdge(out_data_anchor, new_cast_node->GetInDataAnchor(0)) != ge::GRAPH_SUCCESS) { - GE_LOGE("[GraphOpt][SwapMrgCastFus][AddCastNd] Fail to link in_data_anchor of cast node[%s].", - new_cast_node->GetName().c_str()); - return FAILED; - } - if (ge::GraphUtils::AddEdge(new_cast_node->GetOutDataAnchor(0), in_data_anchor) != ge::GRAPH_SUCCESS) { - GE_LOGE( - "[GraphOpt][SwapMrgCastFus][AddCastNd] Fail to link in_data_anchor[%zu] of merge node[%s]" - " with cast node.", - i, merge_node->GetName().c_str()); - return FAILED; - } - } - - return SUCCESS; -} - -Status SwapMergeCastFusionTestPass3::RelinkMergeNode(const ge::NodePtr &merge_node, const ge::NodePtr &cast_node, - const ge::NodePtr &netout_node) { - ge::InDataAnchorPtr netout_in_data_anchor = cast_node->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0); - cast_node->GetInDataAnchor(0)->UnlinkAll(); - cast_node->GetOutDataAnchor(0)->UnlinkAll(); - - // if cast node has in control anchors, then link them to netoutputnode - if (cast_node->GetInControlAnchor() != nullptr) { - if (cast_node->GetInControlAnchor()->GetPeerOutControlAnchors().size() > 0 && - netout_node->GetInControlAnchor() != nullptr) { - for (ge::OutControlAnchorPtr out_control_anchor : cast_node->GetInControlAnchor()->GetPeerOutControlAnchors()) { - if (ge::GraphUtils::AddEdge(out_control_anchor, netout_node->GetInControlAnchor()) != ge::GRAPH_SUCCESS) { - GE_LOGE( - "[GraphOpt][SwapMrgCastFus][RelkMrgNd] Fail to link control edge between netoutput node[%s]" - " and peer out control anchor of cast node[%s].", - netout_node->GetName().c_str(), cast_node->GetName().c_str()); - return FAILED; - } - } - } - cast_node->GetInControlAnchor()->UnlinkAll(); - } - - // usually cast node do not have any output control anchor - // if cast node has output control anchors, unlink them - if (cast_node->GetOutControlAnchor() != nullptr) { - cast_node->GetOutControlAnchor()->UnlinkAll(); - } - - if (ge::GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), netout_in_data_anchor) != ge::GRAPH_SUCCESS) { - GE_LOGE("[GraphOpt][SwapMrgCastFus][RelkMrgNd] Fail to link the output data anchor of merge node[%s].", - merge_node->GetName().c_str()); - return FAILED; - } - - return SUCCESS; -} - -Status SwapMergeCastFusionTestPass3::VerifyNodes(const ge::NodePtr &merge_node, - ge::NodePtr &cast_node, ge::NodePtr &netout_node) const { - UT_CHECK(merge_node == nullptr, GE_LOGE("[GraphOpt][SwapMrgCastFus][VerifyNd] Merge node is nullptr."), - return PARAM_INVALID); - - UT_CHECK(cast_node == nullptr, GE_LOGE("[GraphOpt][SwapMrgCastFus][VerifyNd] Cast node is nullptr."), - return PARAM_INVALID); - - // merge node has two outputs, first output must has only one peer in anchor - if (merge_node->GetOutDataAnchor(0)->GetPeerInDataAnchors().size() > 1) { - GELOGD( - "The first output anchor of Merge node[%s]" - " has more than one peer in anchor.", - merge_node->GetName().c_str()); - return NOT_CHANGED; - } - - // cast node must have only one output node - if (cast_node->GetOutDataNodesSize() != 1) { - GELOGD("Cast node[%s] has more than one out data nodes.", cast_node->GetName().c_str()); - return NOT_CHANGED; - } - - netout_node = cast_node->GetOutDataNodes().at(0); - UT_CHECK_NOTNULL(netout_node); - if (netout_node->GetType() != OP_TYPE_NETOUTPUT) { - GELOGD("The next node of cast node[%s] is not NetOutput.", cast_node->GetName().c_str()); - return NOT_CHANGED; - } - - return SUCCESS; -} - -using namespace ge; -using namespace fe; - -class UTESTGraphFusionPass3 : public testing::Test { - protected: - void SetUp() { - } - - void TearDown() { - - } - - protected: - static ComputeGraphPtr CreateSwapMergeCastGraph1() { - ComputeGraphPtr graph = std::make_shared("test1"); - ge::OpDescPtr op_desc_switch = std::make_shared("switch", "Switch"); - ge::OpDescPtr op_desc_relu1 = std::make_shared("relu1", "Relu"); - ge::OpDescPtr op_desc_relu2 = std::make_shared("relu2", "Relu"); - ge::OpDescPtr op_desc_merge = std::make_shared("merge", "Merge"); - ge::OpDescPtr op_desc_cast = std::make_shared("cast", "Cast"); - ge::OpDescPtr op_desc_netoutput = std::make_shared("netoutput", "NetOutput"); - ge::OpDescPtr op_desc_other = std::make_shared("other", "Other"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT16); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - vector dim_b = {1, 4, 64, 64}; - GeShape shape_b(dim_b); - GeTensorDesc tensor_desc_b(shape_b); - tensor_desc_b.SetFormat(FORMAT_NCHW); - tensor_desc_b.SetOriginFormat(FORMAT_NCHW); - tensor_desc_b.SetDataType(DT_FLOAT16); - tensor_desc_b.SetOriginDataType(DT_FLOAT); - - vector dim_c = {8, 4, 64, 64}; - GeShape shape_c(dim_c); - GeTensorDesc tensor_desc_c(shape_c); - tensor_desc_c.SetFormat(FORMAT_NCHW); - tensor_desc_c.SetOriginFormat(FORMAT_NCHW); - tensor_desc_c.SetDataType(DT_FLOAT16); - tensor_desc_c.SetOriginDataType(DT_FLOAT); - - GeTensorDesc tensor_desc_d(shape_c); - tensor_desc_d.SetFormat(FORMAT_NCHW); - tensor_desc_d.SetOriginFormat(FORMAT_NCHW); - tensor_desc_d.SetDataType(DT_FLOAT); - tensor_desc_d.SetOriginDataType(DT_FLOAT); - - vector dim_e; - GeShape shape_e(dim_e); - GeTensorDesc tensor_desc_e(shape_e); - tensor_desc_e.SetFormat(FORMAT_ND); - tensor_desc_e.SetOriginFormat(FORMAT_ND); - tensor_desc_e.SetDataType(DT_INT32); - tensor_desc_e.SetOriginDataType(DT_INT32); - - op_desc_switch->AddOutputDesc(tensor_desc_a); - op_desc_switch->AddOutputDesc(tensor_desc_b); - - op_desc_relu1->AddInputDesc(tensor_desc_a); - op_desc_relu1->AddOutputDesc(tensor_desc_a); - - op_desc_relu2->AddInputDesc(tensor_desc_b); - op_desc_relu2->AddOutputDesc(tensor_desc_b); - - op_desc_merge->AddInputDesc(tensor_desc_a); - op_desc_merge->AddInputDesc(tensor_desc_b); - op_desc_merge->AddOutputDesc(tensor_desc_c); - op_desc_merge->AddOutputDesc(tensor_desc_e); - - op_desc_other->AddInputDesc(tensor_desc_e); - - op_desc_cast->AddInputDesc(tensor_desc_c); - op_desc_cast->AddOutputDesc(tensor_desc_d); - - op_desc_netoutput->AddInputDesc(tensor_desc_d); - - ge::NodePtr node_switch = graph->AddNode(op_desc_switch); - ge::NodePtr node_relu1 = graph->AddNode(op_desc_relu1); - ge::NodePtr node_relu2 = graph->AddNode(op_desc_relu2); - ge::NodePtr node_merge = graph->AddNode(op_desc_merge); - ge::NodePtr node_cast = graph->AddNode(op_desc_cast); - ge::NodePtr node_netoutput = graph->AddNode(op_desc_netoutput); - ge::NodePtr node_other = graph->AddNode(op_desc_other); - - ge::GraphUtils::AddEdge(node_switch->GetOutDataAnchor(0), node_relu1->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_switch->GetOutDataAnchor(1), node_relu2->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_relu1->GetOutDataAnchor(0), node_merge->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_relu2->GetOutDataAnchor(0), node_merge->GetInDataAnchor(1)); - ge::GraphUtils::AddEdge(node_merge->GetOutDataAnchor(0), node_cast->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_merge->GetOutDataAnchor(1), node_other->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_cast->GetOutDataAnchor(0), node_netoutput->GetInDataAnchor(0)); - - return graph; - } - - static ComputeGraphPtr CreateSwapMergeCastGraph2() { - ComputeGraphPtr graph = CreateSwapMergeCastGraph1(); - ge::OpDescPtr op_desc_some = std::make_shared("some_node", "Some"); - vector dim = {8, 4, 64, 64}; - GeShape shape(dim); - GeTensorDesc tensor_desc(shape); - tensor_desc.SetFormat(FORMAT_NCHW); - tensor_desc.SetOriginFormat(FORMAT_NCHW); - tensor_desc.SetDataType(DT_FLOAT16); - tensor_desc.SetOriginDataType(DT_FLOAT); - op_desc_some->AddInputDesc(tensor_desc); - op_desc_some->AddOutputDesc(tensor_desc); - - ge::NodePtr node_some = graph->AddNode(op_desc_some); - - for (ge::NodePtr node : graph->GetDirectNode()) { - if (node->GetType() == "Merge") { - ge::GraphUtils::AddEdge(node->GetOutDataAnchor(0), node_some->GetInDataAnchor(0)); - } - } - return graph; - } - - static ComputeGraphPtr CreateSwapMergeCastGraph3() { - ComputeGraphPtr graph = CreateSwapMergeCastGraph1(); - ge::OpDescPtr op_desc_some = std::make_shared("some_node", "Some"); - vector dim = {8, 4, 64, 64}; - GeShape shape(dim); - GeTensorDesc tensor_desc(shape); - tensor_desc.SetFormat(FORMAT_NCHW); - tensor_desc.SetOriginFormat(FORMAT_NCHW); - tensor_desc.SetDataType(DT_FLOAT); - tensor_desc.SetOriginDataType(DT_FLOAT); - op_desc_some->AddInputDesc(tensor_desc); - op_desc_some->AddOutputDesc(tensor_desc); - - ge::NodePtr node_some = graph->AddNode(op_desc_some); - - for (ge::NodePtr node : graph->GetDirectNode()) { - if (node->GetType() == "Cast") { - ge::GraphUtils::AddEdge(node->GetOutDataAnchor(0), node_some->GetInDataAnchor(0)); - } - } - return graph; - } - - static ComputeGraphPtr CreateSwapMergeCastGraph4() { - ComputeGraphPtr graph = CreateSwapMergeCastGraph1(); - - for (ge::NodePtr node : graph->GetDirectNode()) { - if (node->GetType() == "NetOutput") { - node->GetOpDesc()->SetType("NetOut"); - } - } - return graph; - } - - static ComputeGraphPtr CreateSwapMergeCastGraph5() { - ComputeGraphPtr graph = std::make_shared("test1"); - ge::OpDescPtr op_desc_switch = std::make_shared("switch", "Switch"); - ge::OpDescPtr op_desc_relu1 = std::make_shared("relu1", "Relu"); - ge::OpDescPtr op_desc_relu2 = std::make_shared("relu2", "Relu"); - ge::OpDescPtr op_desc_merge = std::make_shared("merge", "Merge"); - ge::OpDescPtr op_desc_cast = std::make_shared("cast", "Cast"); - ge::OpDescPtr op_desc_netoutput = std::make_shared("netoutput", "NetOutput"); - ge::OpDescPtr op_desc_other = std::make_shared("other", "Other"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT16); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - vector dim_b = {1, 4, 64, 64}; - GeShape shape_b(dim_b); - GeTensorDesc tensor_desc_b(shape_b); - tensor_desc_b.SetFormat(FORMAT_NCHW); - tensor_desc_b.SetOriginFormat(FORMAT_NCHW); - tensor_desc_b.SetDataType(DT_FLOAT16); - tensor_desc_b.SetOriginDataType(DT_FLOAT); - - vector dim_c = {8, 4, 64, 64}; - GeShape shape_c(dim_c); - GeTensorDesc tensor_desc_c(shape_c); - tensor_desc_c.SetFormat(FORMAT_NCHW); - tensor_desc_c.SetOriginFormat(FORMAT_NCHW); - tensor_desc_c.SetDataType(DT_FLOAT16); - tensor_desc_c.SetOriginDataType(DT_FLOAT); - - GeTensorDesc tensor_desc_d(shape_c); - tensor_desc_d.SetFormat(FORMAT_NCHW); - tensor_desc_d.SetOriginFormat(FORMAT_NCHW); - tensor_desc_d.SetDataType(DT_FLOAT); - tensor_desc_d.SetOriginDataType(DT_FLOAT); - - vector dim_e; - GeShape shape_e(dim_e); - GeTensorDesc tensor_desc_e(shape_e); - tensor_desc_e.SetFormat(FORMAT_ND); - tensor_desc_e.SetOriginFormat(FORMAT_ND); - tensor_desc_e.SetDataType(DT_INT32); - tensor_desc_e.SetOriginDataType(DT_INT32); - - op_desc_switch->AddOutputDesc(tensor_desc_a); - op_desc_switch->AddOutputDesc(tensor_desc_b); - - op_desc_relu1->AddInputDesc(tensor_desc_a); - op_desc_relu1->AddOutputDesc(tensor_desc_a); - - op_desc_relu2->AddInputDesc(tensor_desc_b); - op_desc_relu2->AddOutputDesc(tensor_desc_b); - - op_desc_merge->AddInputDesc(tensor_desc_a); - op_desc_merge->AddInputDesc(tensor_desc_b); - op_desc_merge->AddOutputDesc(tensor_desc_c); - op_desc_merge->AddOutputDesc(tensor_desc_e); - - op_desc_other->AddInputDesc(tensor_desc_e); - - op_desc_cast->AddInputDesc(tensor_desc_c); - op_desc_cast->AddOutputDesc(tensor_desc_d); - - op_desc_netoutput->AddInputDesc(tensor_desc_d); - - ge::NodePtr node_switch = graph->AddNode(op_desc_switch); - ge::NodePtr node_relu1 = graph->AddNode(op_desc_relu1); - ge::NodePtr node_relu2 = graph->AddNode(op_desc_relu2); - ge::NodePtr node_merge = graph->AddNode(op_desc_merge); - ge::NodePtr node_cast = graph->AddNode(op_desc_cast); - ge::NodePtr node_netoutput = graph->AddNode(op_desc_netoutput); - ge::NodePtr node_other = graph->AddNode(op_desc_other); - - ge::GraphUtils::AddEdge(node_switch->GetOutDataAnchor(0), node_relu1->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_switch->GetOutDataAnchor(1), node_relu2->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_relu1->GetOutDataAnchor(0), node_merge->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_relu2->GetOutDataAnchor(0), node_merge->GetInDataAnchor(1)); - ge::GraphUtils::AddEdge(node_merge->GetOutDataAnchor(0), node_cast->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_merge->GetOutDataAnchor(1), node_other->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_cast->GetOutDataAnchor(0), node_netoutput->GetInDataAnchor(0)); - - ge::GraphUtils::AddEdge(node_relu1->GetOutControlAnchor(), node_cast->GetInControlAnchor()); - ge::GraphUtils::AddEdge(node_relu2->GetOutControlAnchor(), node_cast->GetInControlAnchor()); - return graph; - } - - static ComputeGraphPtr CreateSwapMergeCastGraph6() { - ComputeGraphPtr graph = std::make_shared("test1"); - ge::OpDescPtr op_desc_switch1 = std::make_shared("switch1", "Switch"); - ge::OpDescPtr op_desc_relu1 = std::make_shared("relu1", "Relu"); - ge::OpDescPtr op_desc_relu2 = std::make_shared("relu2", "Relu"); - ge::OpDescPtr op_desc_merge1 = std::make_shared("merge1", "Merge"); - ge::OpDescPtr op_desc_cast1 = std::make_shared("cast1", "Cast"); - ge::OpDescPtr op_desc_other1 = std::make_shared("other1", "Other"); - - ge::OpDescPtr op_desc_switch2 = std::make_shared("switch2", "Switch"); - ge::OpDescPtr op_desc_relu3 = std::make_shared("relu3", "Relu"); - ge::OpDescPtr op_desc_relu4 = std::make_shared("relu4", "Relu"); - ge::OpDescPtr op_desc_merge2 = std::make_shared("merge2", "Merge"); - ge::OpDescPtr op_desc_cast2 = std::make_shared("cast2", "Cast"); - ge::OpDescPtr op_desc_other2 = std::make_shared("other2", "Other"); - - ge::OpDescPtr op_desc_netoutput = std::make_shared("netoutput", "NetOutput"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT16); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - vector dim_b = {1, 4, 64, 64}; - GeShape shape_b(dim_b); - GeTensorDesc tensor_desc_b(shape_b); - tensor_desc_b.SetFormat(FORMAT_NCHW); - tensor_desc_b.SetOriginFormat(FORMAT_NCHW); - tensor_desc_b.SetDataType(DT_FLOAT16); - tensor_desc_b.SetOriginDataType(DT_FLOAT); - - vector dim_c = {8, 4, 64, 64}; - GeShape shape_c(dim_c); - GeTensorDesc tensor_desc_c(shape_c); - tensor_desc_c.SetFormat(FORMAT_NCHW); - tensor_desc_c.SetOriginFormat(FORMAT_NCHW); - tensor_desc_c.SetDataType(DT_FLOAT16); - tensor_desc_c.SetOriginDataType(DT_FLOAT); - - GeTensorDesc tensor_desc_d(shape_c); - tensor_desc_d.SetFormat(FORMAT_NCHW); - tensor_desc_d.SetOriginFormat(FORMAT_NCHW); - tensor_desc_d.SetDataType(DT_FLOAT); - tensor_desc_d.SetOriginDataType(DT_FLOAT); - - vector dim_e; - GeShape shape_e(dim_e); - GeTensorDesc tensor_desc_e(shape_e); - tensor_desc_e.SetFormat(FORMAT_ND); - tensor_desc_e.SetOriginFormat(FORMAT_ND); - tensor_desc_e.SetDataType(DT_INT32); - tensor_desc_e.SetOriginDataType(DT_INT32); - - op_desc_switch1->AddOutputDesc(tensor_desc_a); - op_desc_switch1->AddOutputDesc(tensor_desc_b); - - op_desc_switch2->AddOutputDesc(tensor_desc_a); - op_desc_switch2->AddOutputDesc(tensor_desc_b); - - op_desc_relu1->AddInputDesc(tensor_desc_a); - op_desc_relu1->AddOutputDesc(tensor_desc_a); - op_desc_relu2->AddInputDesc(tensor_desc_b); - op_desc_relu2->AddOutputDesc(tensor_desc_b); - - op_desc_relu3->AddInputDesc(tensor_desc_a); - op_desc_relu3->AddOutputDesc(tensor_desc_a); - op_desc_relu4->AddInputDesc(tensor_desc_b); - op_desc_relu4->AddOutputDesc(tensor_desc_b); - - op_desc_merge1->AddInputDesc(tensor_desc_a); - op_desc_merge1->AddInputDesc(tensor_desc_b); - op_desc_merge1->AddOutputDesc(tensor_desc_c); - op_desc_merge1->AddOutputDesc(tensor_desc_e); - - op_desc_merge2->AddInputDesc(tensor_desc_a); - op_desc_merge2->AddInputDesc(tensor_desc_b); - op_desc_merge2->AddOutputDesc(tensor_desc_c); - op_desc_merge2->AddOutputDesc(tensor_desc_e); - - op_desc_other1->AddInputDesc(tensor_desc_e); - - op_desc_other2->AddInputDesc(tensor_desc_e); - - op_desc_cast1->AddInputDesc(tensor_desc_c); - op_desc_cast1->AddOutputDesc(tensor_desc_d); - - op_desc_cast2->AddInputDesc(tensor_desc_c); - op_desc_cast2->AddOutputDesc(tensor_desc_d); - - op_desc_netoutput->AddInputDesc(tensor_desc_d); - op_desc_netoutput->AddInputDesc(tensor_desc_d); - - ge::NodePtr node_switch1 = graph->AddNode(op_desc_switch1); - ge::NodePtr node_relu1 = graph->AddNode(op_desc_relu1); - ge::NodePtr node_relu2 = graph->AddNode(op_desc_relu2); - ge::NodePtr node_merge1 = graph->AddNode(op_desc_merge1); - ge::NodePtr node_cast1 = graph->AddNode(op_desc_cast1); - ge::NodePtr node_other1 = graph->AddNode(op_desc_other1); - - ge::NodePtr node_switch2 = graph->AddNode(op_desc_switch2); - ge::NodePtr node_relu3 = graph->AddNode(op_desc_relu3); - ge::NodePtr node_relu4 = graph->AddNode(op_desc_relu4); - ge::NodePtr node_merge2 = graph->AddNode(op_desc_merge2); - ge::NodePtr node_cast2 = graph->AddNode(op_desc_cast2); - ge::NodePtr node_other2 = graph->AddNode(op_desc_other2); - - ge::NodePtr node_netoutput = graph->AddNode(op_desc_netoutput); - - std::shared_ptr node_map_info = - std::make_shared(); - node_map_info->node_type_map = std::make_shared(); - - std::map map_relu; - map_relu.emplace(std::make_pair("relu1", node_relu1)); - - std::map map_cast; - map_cast.emplace(std::make_pair("cast1", node_cast1)); - map_cast.emplace(std::make_pair("cast2", node_cast2)); - - std::map map_merge; - map_merge.emplace(std::make_pair("merge1", node_merge1)); - map_merge.emplace(std::make_pair("merge2", node_merge2)); - - node_map_info->node_type_map->emplace("Relu", map_relu); - node_map_info->node_type_map->emplace("Cast", map_cast); - node_map_info->node_type_map->emplace("Merge", map_merge); - - graph->SetExtAttr("NodeMapInfo", node_map_info); - - ge::GraphUtils::AddEdge(node_switch1->GetOutDataAnchor(0), node_relu1->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_switch1->GetOutDataAnchor(1), node_relu2->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_relu1->GetOutDataAnchor(0), node_merge1->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_relu2->GetOutDataAnchor(0), node_merge1->GetInDataAnchor(1)); - ge::GraphUtils::AddEdge(node_merge1->GetOutDataAnchor(0), node_cast1->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_merge1->GetOutDataAnchor(1), node_other1->GetInDataAnchor(0)); - - ge::GraphUtils::AddEdge(node_switch2->GetOutDataAnchor(0), node_relu3->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_switch2->GetOutDataAnchor(1), node_relu4->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_relu3->GetOutDataAnchor(0), node_merge2->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_relu4->GetOutDataAnchor(0), node_merge2->GetInDataAnchor(1)); - ge::GraphUtils::AddEdge(node_merge2->GetOutDataAnchor(0), node_cast2->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_merge2->GetOutDataAnchor(1), node_other2->GetInDataAnchor(0)); - - ge::GraphUtils::AddEdge(node_cast1->GetOutDataAnchor(0), node_netoutput->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_cast2->GetOutDataAnchor(0), node_netoutput->GetInDataAnchor(1)); - - return graph; - } -}; - -TEST_F(UTESTGraphFusionPass3, UTESTGraphFusionPass3_1) { - ComputeGraphPtr graph = CreateSwapMergeCastGraph1(); - - SwapMergeCastFusionTestPass3 pass; - Status status = fe::FAILED; - - pass.SetName("test"); - status = pass.Run(*graph); - - EXPECT_EQ(fe::SUCCESS, status); - vector dim_a = {8, 4, 16, 16}; - vector dim_b = {1, 4, 64, 64}; - vector dim_c = {8, 4, 64, 64}; - for (ge::NodePtr node : graph->GetDirectNode()) { - ge::OpDescPtr op_desc = node->GetOpDesc(); - if (op_desc->GetType() == "Merge") { - EXPECT_EQ(op_desc->MutableInputDesc(0)->GetDataType(), DT_FLOAT); - EXPECT_EQ(op_desc->MutableInputDesc(0)->MutableShape().GetDims(), dim_a); - - EXPECT_EQ(op_desc->MutableInputDesc(1)->GetDataType(), DT_FLOAT); - EXPECT_EQ(op_desc->MutableInputDesc(1)->MutableShape().GetDims(), dim_b); - - EXPECT_EQ(op_desc->MutableOutputDesc(0)->GetDataType(), DT_FLOAT); - EXPECT_EQ(op_desc->MutableOutputDesc(0)->MutableShape().GetDims(), dim_c); - } - if (op_desc->GetType() == "Cast") { - EXPECT_EQ(op_desc->MutableInputDesc(0)->GetDataType(), DT_FLOAT16); - EXPECT_EQ(op_desc->MutableOutputDesc(0)->GetDataType(), DT_FLOAT); - } - if (op_desc->GetName() == "relu1") { - ge::NodePtr node_cast = node->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetOwnerNode(); - EXPECT_EQ(node_cast->GetType(), "Cast"); - ge::OpDescPtr op_desc_cast = node_cast->GetOpDesc(); - EXPECT_EQ(op_desc_cast->MutableInputDesc(0)->MutableShape().GetDims(), dim_a); - EXPECT_EQ(op_desc_cast->MutableOutputDesc(0)->MutableShape().GetDims(), dim_a); - } - if (op_desc->GetName() == "relu2") { - ge::NodePtr node_cast = node->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetOwnerNode(); - EXPECT_EQ(node_cast->GetType(), "Cast"); - ge::OpDescPtr op_desc_cast = node_cast->GetOpDesc(); - EXPECT_EQ(op_desc_cast->MutableInputDesc(0)->MutableShape().GetDims(), dim_b); - EXPECT_EQ(op_desc_cast->MutableOutputDesc(0)->MutableShape().GetDims(), dim_b); - } - } -} - -TEST_F(UTESTGraphFusionPass3, UTESTGraphFusionPass3_2) { - ComputeGraphPtr graph = CreateSwapMergeCastGraph2(); - - Status status = fe::FAILED; - SwapMergeCastFusionTestPass3 pass; - - pass.SetName("test"); - status = pass.Run(*graph); - - - EXPECT_EQ(fe::NOT_CHANGED, status); -} - -TEST_F(UTESTGraphFusionPass3, UTESTGraphFusionPass3_3) { - ComputeGraphPtr graph = CreateSwapMergeCastGraph3(); - - Status status = fe::FAILED; - - SwapMergeCastFusionTestPass3 pass; - - pass.SetName("test"); - status = pass.Run(*graph); - - EXPECT_EQ(fe::NOT_CHANGED, status); -} - -TEST_F(UTESTGraphFusionPass3, UTESTGraphFusionPass3_4) { - ComputeGraphPtr graph = CreateSwapMergeCastGraph4(); - - Status status = fe::FAILED; - SwapMergeCastFusionTestPass3 pass; - - pass.SetName("test"); - status = pass.Run(*graph); - EXPECT_EQ(fe::NOT_CHANGED, status); -} - -TEST_F(UTESTGraphFusionPass3, UTESTGraphFusionPass3_5) { - ComputeGraphPtr graph = CreateSwapMergeCastGraph5(); - Status status = fe::FAILED; - - SwapMergeCastFusionTestPass3 pass; - - pass.SetName("test"); - status = pass.Run(*graph); - - - EXPECT_EQ(fe::SUCCESS, status); - vector dim_a = {8, 4, 16, 16}; - vector dim_b = {1, 4, 64, 64}; - vector dim_c = {8, 4, 64, 64}; - for (ge::NodePtr node : graph->GetDirectNode()) { - ge::OpDescPtr op_desc = node->GetOpDesc(); - if (op_desc->GetType() == "Merge") { - EXPECT_EQ(op_desc->MutableInputDesc(0)->GetDataType(), DT_FLOAT); - EXPECT_EQ(op_desc->MutableInputDesc(0)->MutableShape().GetDims(), dim_a); - - EXPECT_EQ(op_desc->MutableInputDesc(1)->GetDataType(), DT_FLOAT); - EXPECT_EQ(op_desc->MutableInputDesc(1)->MutableShape().GetDims(), dim_b); - - EXPECT_EQ(op_desc->MutableOutputDesc(0)->GetDataType(), DT_FLOAT); - EXPECT_EQ(op_desc->MutableOutputDesc(0)->MutableShape().GetDims(), dim_c); - } - if (op_desc->GetType() == "Cast") { - EXPECT_EQ(op_desc->MutableInputDesc(0)->GetDataType(), DT_FLOAT16); - EXPECT_EQ(op_desc->MutableOutputDesc(0)->GetDataType(), DT_FLOAT); - } - if (op_desc->GetName() == "relu1") { - ge::NodePtr node_cast = node->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetOwnerNode(); - EXPECT_EQ(node_cast->GetType(), "Cast"); - ge::OpDescPtr op_desc_cast = node_cast->GetOpDesc(); - EXPECT_EQ(op_desc_cast->MutableInputDesc(0)->MutableShape().GetDims(), dim_a); - EXPECT_EQ(op_desc_cast->MutableOutputDesc(0)->MutableShape().GetDims(), dim_a); - EXPECT_EQ(node->GetOutControlAnchor()->GetPeerInControlAnchors().at(0)->GetOwnerNode()->GetType(), "NetOutput"); - } - if (op_desc->GetName() == "relu2") { - ge::NodePtr node_cast = node->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetOwnerNode(); - EXPECT_EQ(node_cast->GetType(), "Cast"); - ge::OpDescPtr op_desc_cast = node_cast->GetOpDesc(); - EXPECT_EQ(op_desc_cast->MutableInputDesc(0)->MutableShape().GetDims(), dim_b); - EXPECT_EQ(op_desc_cast->MutableOutputDesc(0)->MutableShape().GetDims(), dim_b); - EXPECT_EQ(node->GetOutControlAnchor()->GetPeerInControlAnchors().at(0)->GetOwnerNode()->GetType(), "NetOutput"); - } - } -} - - -TEST_F(UTESTGraphFusionPass3, UTESTGraphFusionPass3_6) { - ComputeGraphPtr graph = CreateSwapMergeCastGraph6(); - - Status status = fe::FAILED; - - SwapMergeCastFusionTestPass3 pass; - - pass.SetName("test"); - status = pass.Run(*graph); - - - EXPECT_EQ(fe::SUCCESS, status); - vector dim_a = {8, 4, 16, 16}; - vector dim_b = {1, 4, 64, 64}; - vector dim_c = {8, 4, 64, 64}; - for (ge::NodePtr node : graph->GetDirectNode()) { - ge::OpDescPtr op_desc = node->GetOpDesc(); - if (op_desc->GetType() == "Merge") { - EXPECT_EQ(op_desc->MutableInputDesc(0)->GetDataType(), DT_FLOAT); - EXPECT_EQ(op_desc->MutableInputDesc(0)->MutableShape().GetDims(), dim_a); - - EXPECT_EQ(op_desc->MutableInputDesc(1)->GetDataType(), DT_FLOAT); - EXPECT_EQ(op_desc->MutableInputDesc(1)->MutableShape().GetDims(), dim_b); - - EXPECT_EQ(op_desc->MutableOutputDesc(0)->GetDataType(), DT_FLOAT); - EXPECT_EQ(op_desc->MutableOutputDesc(0)->MutableShape().GetDims(), dim_c); - - EXPECT_EQ(node->GetOutDataNodes().at(0)->GetType(), "NetOutput"); - } - if (op_desc->GetType() == "Cast") { - EXPECT_EQ(op_desc->MutableInputDesc(0)->GetDataType(), DT_FLOAT16); - EXPECT_EQ(op_desc->MutableOutputDesc(0)->GetDataType(), DT_FLOAT); - } - if (op_desc->GetName() == "relu1" || op_desc->GetName() == "relu3") { - ge::NodePtr node_cast = node->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetOwnerNode(); - EXPECT_EQ(node_cast->GetType(), "Cast"); - ge::OpDescPtr op_desc_cast = node_cast->GetOpDesc(); - EXPECT_EQ(op_desc_cast->MutableInputDesc(0)->MutableShape().GetDims(), dim_a); - EXPECT_EQ(op_desc_cast->MutableOutputDesc(0)->MutableShape().GetDims(), dim_a); - } - if (op_desc->GetName() == "relu2" || op_desc->GetName() == "relu4") { - ge::NodePtr node_cast = node->GetOutDataAnchor(0)->GetPeerInDataAnchors().at(0)->GetOwnerNode(); - EXPECT_EQ(node_cast->GetType(), "Cast"); - ge::OpDescPtr op_desc_cast = node_cast->GetOpDesc(); - EXPECT_EQ(op_desc_cast->MutableInputDesc(0)->MutableShape().GetDims(), dim_b); - EXPECT_EQ(op_desc_cast->MutableOutputDesc(0)->MutableShape().GetDims(), dim_b); - } - } -} -} diff --git a/tests/ut/register/testcase/register_graph_fusion_4.cc b/tests/ut/register/testcase/register_graph_fusion_4.cc deleted file mode 100644 index 4a31737cbc152c355515f507b2a95f151c213b1b..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/register_graph_fusion_4.cc +++ /dev/null @@ -1,972 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include "graph/compute_graph.h" -#include "graph/graph.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/graph_utils.h" -#include "register/graph_optimizer/graph_fusion/fusion_pass_manager/fusion_pass_registry.h" -#include "register/graph_optimizer/graph_fusion/graph_fusion_pass_base.h" -#include "register/graph_optimizer/fusion_common/pattern_fusion_base_pass.h" -#include "register/graph_optimizer/fusion_common/graph_pass_util.h" -#include "graph/debug/ge_log.h" -/** - * Input - * | - * A - * | - * B - * | \ \ - * C D E - * | - * NetOutput - */ -namespace fe { -using Mapping = std::map, std::vector, fe::CmpKey>; -using namespace ge; -class TestSetOutputsPass1 : public GraphFusionPassBase { - protected: - vector DefinePatterns() override; - Status Fusion(ge::ComputeGraph &graph, Mapping &mapping, vector &new_nodes) override; -}; - -/** - * Input - * | - * A - * | - * B - * | \ - * C D - * | - * NetOutput - */ - -class TestSetOutputsPass2 : public GraphFusionPassBase { - protected: - vector DefinePatterns() override; - Status Fusion(ge::ComputeGraph &graph, Mapping &mapping, vector &new_nodes) override; -}; - -class TestSetOutputsPass3 : public GraphFusionPassBase { - protected: - vector DefinePatterns() override; - Status Fusion(ge::ComputeGraph &graph, Mapping &mapping, vector &new_nodes) override; -}; - -class TestSetOutputsPassFuzzy : public GraphFusionPassBase { - protected: - vector DefinePatterns() override; - Status Fusion(ge::ComputeGraph &graph, Mapping &mapping, vector &new_nodes) override; -}; - -class TestSetOutputsPassFuzzy2 : public GraphFusionPassBase { - protected: - vector DefinePatterns() override; - Status Fusion(ge::ComputeGraph &graph, Mapping &mapping, vector &new_nodes) override; -}; - -class TestSetOutputsPassFuzzy3 : public GraphFusionPassBase { - protected: - vector DefinePatterns() override; - Status Fusion(ge::ComputeGraph &graph, Mapping &mapping, vector &new_nodes) override; -}; - -class TestSetOutputsPassFuzzy5 : public GraphFusionPassBase { - protected: - vector DefinePatterns() override; - Status Fusion(ge::ComputeGraph &graph, Mapping &mapping, vector &new_nodes) override; -}; - -class TestSetOutputsPassFuzzy6 : public GraphFusionPassBase { - protected: - vector DefinePatterns() override; - Status Fusion(ge::ComputeGraph &graph, Mapping &mapping, vector &new_nodes) override; -}; - -static const string TEST_SETOUPT_PASS_NAME = "TestSetOutputsFusionPass"; -static const string OP_A = "A"; -static const string OP_B = "B"; -static const string OP_C = "C"; -static const string OP_D = "D"; -static const string OP_E = "E"; -static const string OP_F = "F"; -static const string OP_G = "G"; -static const string TYPE_A = "TypeA"; -static const string TYPE_B = "TypeB"; -static const string TYPE_C = "TypeC"; -static const string TYPE_D = "TypeD"; -static const string TYPE_E = "TypeE"; -static const string TYPE_F = "TypeF"; -static const string TYPE_G = "TypeG"; - -vector TestSetOutputsPass1::DefinePatterns() { - vector patterns; - - FusionPattern *pattern = new(std::nothrow) FusionPattern("TestSetOutputsFusionPattern"); - - pattern->AddOpDesc(OP_A, {TYPE_A}) - .AddOpDesc(OP_B, {TYPE_B}) - .AddOpDesc(OP_C, {TYPE_C}) - .AddOpDesc(OP_D, {TYPE_D}) - .AddOpDesc(OP_E, {TYPE_E}) - .SetInputs(OP_B, {OP_A}) - .SetOutputs(OP_B, {{0, {OP_C.c_str()}}, {1, {OP_D.c_str(), OP_E.c_str()}}}) - .SetInputs(OP_C, {OP_B}) - .SetOutput(OP_C); - - patterns.push_back(pattern); - - return patterns; -} - -Status TestSetOutputsPass1::Fusion(ge::ComputeGraph &graph, Mapping &mapping, vector &new_nodes) { - return fe::SUCCESS; -} - -vector TestSetOutputsPass2::DefinePatterns() { - vector patterns; - - FusionPattern *pattern = new(std::nothrow) FusionPattern("TestSetOutputsFusionPattern"); - - pattern->AddOpDesc(OP_A, {TYPE_A}) - .AddOpDesc(OP_B, {TYPE_B}) - .AddOpDesc(OP_C, {TYPE_C}) - .AddOpDesc(OP_D, {TYPE_D}) - .SetInputs(OP_B, {OP_A}) - .SetOutputs(OP_B, {{0, OP_C.c_str()}, {1, OP_D.c_str()}}, false) - .SetInputs(OP_C, {OP_B}) - .SetOutput(OP_C); - - patterns.push_back(pattern); - - return patterns; -} - -Status TestSetOutputsPass2::Fusion(ge::ComputeGraph &graph, Mapping &mapping, vector &new_nodes) { - return fe::SUCCESS; -} - -vector TestSetOutputsPass3::DefinePatterns() { - vector patterns; - - FusionPattern *pattern = new(std::nothrow) FusionPattern("TestSetOutputsFusionPattern"); - - pattern->AddOpDesc(OP_A, {TYPE_A}) - .AddOpDesc(OP_B, {TYPE_B}) - .AddOpDesc(OP_C, {TYPE_C}) - .AddOpDesc(OP_D, {TYPE_D}) - .AddOpDesc(OP_F, {TYPE_F}) - .SetOutputs(OP_A, {{0, OP_B.c_str()}, {1, OP_F.c_str()}}) - .SetInputs(OP_B, {OP_A}) - .SetOutputs(OP_B, {{0, OP_C.c_str()}, {1, OP_D.c_str()}}, false) - .SetInputs(OP_C, {OP_B}) - .SetOutput(OP_C); - - patterns.push_back(pattern); - - return patterns; -} - -Status TestSetOutputsPass3::Fusion(ge::ComputeGraph &graph, Mapping &mapping, vector &new_nodes) { - return fe::SUCCESS; -} - -vector TestSetOutputsPassFuzzy::DefinePatterns() { - vector patterns; - - FusionPattern *pattern = new(std::nothrow) FusionPattern("TestSetOutputsFusionPattern"); - - pattern->AddOpDesc(OP_A, {TYPE_A}) - .AddOpDesc(OP_B, {TYPE_B}) - .AddOpDesc(OP_C, {TYPE_C}) - .AddOpDesc(OP_D, {TYPE_D}) - .AddOpDesc(OP_E, {TYPE_E}) - .SetInputs(OP_B, {OP_A}) - .SetOutputs(OP_B, {{kFuzzyOutIndex, {OP_C.c_str()}}, {kFuzzyOutIndex, {OP_D.c_str(), OP_E.c_str()}}}) - .SetInputs(OP_C, {OP_B}) - .SetOutput(OP_C); - - patterns.push_back(pattern); - - return patterns; -} - -Status TestSetOutputsPassFuzzy::Fusion(ge::ComputeGraph &graph, Mapping &mapping, vector &new_nodes) { - return fe::SUCCESS; -} - -vector TestSetOutputsPassFuzzy2::DefinePatterns() { - vector patterns; - - FusionPattern *pattern = new(std::nothrow) FusionPattern("TestSetOutputsFusionPattern"); - - pattern->AddOpDesc(OP_A, {TYPE_A}) - .AddOpDesc(OP_B, {TYPE_B}) - .AddOpDesc(OP_C, {TYPE_C}) - .AddOpDesc(OP_D, {TYPE_D}) - .AddOpDesc(OP_E, {TYPE_E}) - .SetInputs(OP_B, {OP_A}) - .SetOutputs(OP_B, {{kFuzzyOutIndex, {OP_C.c_str()}}, {kFuzzyOutIndex, {OP_D.c_str(), OP_E.c_str()}}, - {0, {OP_C.c_str()}}, - {2, {OP_E.c_str()}}}, true) - .SetInputs(OP_C, {OP_B}) - .SetOutput(OP_C); - - patterns.push_back(pattern); - - return patterns; -} - -Status TestSetOutputsPassFuzzy2::Fusion(ge::ComputeGraph &graph, Mapping &mapping, vector &new_nodes) { - return fe::SUCCESS; -} - -vector TestSetOutputsPassFuzzy3::DefinePatterns() { - vector patterns; - - FusionPattern *pattern = new(std::nothrow) FusionPattern("TestSetOutputsFusionPattern"); - - pattern->AddOpDesc(OP_A, {TYPE_A}) - .AddOpDesc(OP_B, {TYPE_B}) - .AddOpDesc(OP_C, {TYPE_C}) - .AddOpDesc(OP_F, {TYPE_F}) - .AddOpDesc(OP_D, {TYPE_D}) - .AddOpDesc(OP_E, {TYPE_E}) - .SetInputs(OP_B, {OP_A}) - .SetOutputs(OP_B, {{kFuzzyOutIndex, {OP_C.c_str()}}, {kFuzzyOutIndex, {OP_D.c_str(), OP_E.c_str()}}, - {2, {OP_F.c_str()}}}) - .SetInputs(OP_C, {OP_B}) - .SetOutput(OP_C); - - patterns.push_back(pattern); - - return patterns; -} - -Status TestSetOutputsPassFuzzy3::Fusion(ge::ComputeGraph &graph, Mapping &mapping, vector &new_nodes) { - return fe::SUCCESS; -} - - -vector TestSetOutputsPassFuzzy5::DefinePatterns() { - vector patterns; - - FusionPattern *pattern = new(std::nothrow) FusionPattern("TestSetOutputsFusionPattern"); - - pattern->AddOpDesc(OP_A, {TYPE_A}) - .AddOpDesc(OP_B, {}) - .AddOpDesc(OP_C, {TYPE_C}) - .AddOpDesc(OP_F, {TYPE_F}) - .AddOpDesc(OP_D, {TYPE_D}) - .AddOpDesc(OP_E, {TYPE_E}) - .SetInputs(OP_B, {OP_A}) - .SetOutputs(OP_B, {{kFuzzyOutIndex, {OP_D.c_str(), OP_E.c_str()}}}, false) - .SetInputs(OP_C, {OP_B}) - .SetOutput(OP_C); - - patterns.push_back(pattern); - - return patterns; -} - -Status TestSetOutputsPassFuzzy5::Fusion(ge::ComputeGraph &graph, Mapping &mapping, vector &new_nodes) { - return fe::SUCCESS; -} - -vector TestSetOutputsPassFuzzy6::DefinePatterns() { - vector patterns; - - FusionPattern *pattern = new(std::nothrow) FusionPattern("TestSetOutputsFusionPattern"); - - pattern->AddOpDesc(OP_A, {TYPE_A}) - .AddOpDesc(OP_B, {}) - .AddOpDesc(OP_C, {TYPE_C}) - .AddOpDesc(OP_F, {TYPE_F}) - .AddOpDesc(OP_D, {TYPE_D}) - .AddOpDesc(OP_E, {TYPE_E}) - .AddOpDesc(OP_G, {TYPE_G}) - .SetInputs(OP_B, {OP_A}) - .SetOutputs(OP_B, {{kFuzzyOutIndex, {OP_D.c_str(), OP_E.c_str()}}}, false) - .SetOutputs(OP_D, {{kFuzzyOutIndex, OP_F.c_str()}}, false) - .SetOutputs(OP_F, {{kFuzzyOutIndex, OP_G.c_str()}}, false) - .SetInputs(OP_C, {OP_B}) - .SetOutput(OP_C); - - patterns.push_back(pattern); - - return patterns; -} - -Status TestSetOutputsPassFuzzy6::Fusion(ge::ComputeGraph &graph, Mapping &mapping, vector &new_nodes) { - return fe::SUCCESS; -} - -class UTESTGraphFusionPass4 : public testing::Test { - protected: - void SetUp() { - } - - void TearDown() { - - } - - protected: - static ge::ComputeGraphPtr CreateTestOutputGraph1() { - ge::ComputeGraphPtr graph = std::make_shared("test1"); - ge::OpDescPtr op_desc_a = std::make_shared("A", TYPE_A); - ge::OpDescPtr op_desc_b = std::make_shared("B", TYPE_B); - ge::OpDescPtr op_desc_c = std::make_shared("C", TYPE_C); - ge::OpDescPtr op_desc_out = std::make_shared("NetOut", "NetOut"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT16); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - op_desc_a->AddOutputDesc(tensor_desc_a); - - op_desc_b->AddInputDesc(tensor_desc_a); - op_desc_b->AddOutputDesc(tensor_desc_a); - - op_desc_c->AddInputDesc(tensor_desc_a); - op_desc_c->AddOutputDesc(tensor_desc_a); - - op_desc_out->AddInputDesc(tensor_desc_a); - - ge::NodePtr node_a = graph->AddNode(op_desc_a); - ge::NodePtr node_b = graph->AddNode(op_desc_b); - ge::NodePtr node_c = graph->AddNode(op_desc_c); - ge::NodePtr node_out = graph->AddNode(op_desc_out); - - ge::GraphUtils::AddEdge(node_a->GetOutDataAnchor(0), node_b->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(0), node_c->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_c->GetOutDataAnchor(0), node_out->GetInDataAnchor(0)); - return graph; - } - - static ge::ComputeGraphPtr CreateTestOutputGraph2() { - ge::ComputeGraphPtr graph = std::make_shared("test1"); - ge::OpDescPtr op_desc_a = std::make_shared("A", TYPE_A); - ge::OpDescPtr op_desc_b = std::make_shared("B", TYPE_B); - ge::OpDescPtr op_desc_c = std::make_shared("C", TYPE_C); - ge::OpDescPtr op_desc_d = std::make_shared("D", TYPE_D); - ge::OpDescPtr op_desc_e = std::make_shared("E", TYPE_E); - ge::OpDescPtr op_desc_out = std::make_shared("NetOut", "NetOut"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT16); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - op_desc_a->AddOutputDesc(tensor_desc_a); - - op_desc_b->AddInputDesc(tensor_desc_a); - op_desc_b->AddOutputDesc(tensor_desc_a); - op_desc_b->AddOutputDesc(tensor_desc_a); - - op_desc_c->AddInputDesc(tensor_desc_a); - op_desc_c->AddOutputDesc(tensor_desc_a); - - op_desc_d->AddInputDesc(tensor_desc_a); - op_desc_e->AddInputDesc(tensor_desc_a); - - op_desc_out->AddInputDesc(tensor_desc_a); - - ge::NodePtr node_a = graph->AddNode(op_desc_a); - ge::NodePtr node_b = graph->AddNode(op_desc_b); - ge::NodePtr node_c = graph->AddNode(op_desc_c); - ge::NodePtr node_d = graph->AddNode(op_desc_d); - ge::NodePtr node_e = graph->AddNode(op_desc_e); - ge::NodePtr node_out = graph->AddNode(op_desc_out); - - ge::GraphUtils::AddEdge(node_a->GetOutDataAnchor(0), node_b->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(0), node_c->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(1), node_d->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(1), node_e->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_c->GetOutDataAnchor(0), node_out->GetInDataAnchor(0)); - - return graph; - } - - static ge::ComputeGraphPtr CreateTestOutputGraph2_1() { - ge::ComputeGraphPtr graph = std::make_shared("test1"); - ge::OpDescPtr op_desc_a = std::make_shared("A", TYPE_A); - ge::OpDescPtr op_desc_b = std::make_shared("B", TYPE_B); - ge::OpDescPtr op_desc_c = std::make_shared("C", TYPE_C); - ge::OpDescPtr op_desc_d = std::make_shared("D", TYPE_D); - ge::OpDescPtr op_desc_out = std::make_shared("NetOut", "NetOut"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT16); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - op_desc_a->AddOutputDesc(tensor_desc_a); - - op_desc_b->AddInputDesc(tensor_desc_a); - op_desc_b->AddOutputDesc(tensor_desc_a); - op_desc_b->AddOutputDesc(tensor_desc_a); - - op_desc_c->AddInputDesc(tensor_desc_a); - op_desc_c->AddOutputDesc(tensor_desc_a); - - op_desc_d->AddInputDesc(tensor_desc_a); - op_desc_out->AddInputDesc(tensor_desc_a); - - ge::NodePtr node_a = graph->AddNode(op_desc_a); - ge::NodePtr node_b = graph->AddNode(op_desc_b); - ge::NodePtr node_c = graph->AddNode(op_desc_c); - ge::NodePtr node_d = graph->AddNode(op_desc_d); - ge::NodePtr node_out = graph->AddNode(op_desc_out); - - ge::GraphUtils::AddEdge(node_a->GetOutDataAnchor(0), node_b->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(0), node_c->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(1), node_d->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_c->GetOutDataAnchor(0), node_out->GetInDataAnchor(0)); - - return graph; - } - - static ge::ComputeGraphPtr CreateTestOutputGraph2_2() { - ge::ComputeGraphPtr graph = std::make_shared("test1"); - ge::OpDescPtr op_desc_a = std::make_shared("A", TYPE_A); - ge::OpDescPtr op_desc_b = std::make_shared("B", TYPE_B); - ge::OpDescPtr op_desc_c = std::make_shared("C", TYPE_C); - ge::OpDescPtr op_desc_d = std::make_shared("D", TYPE_D); - ge::OpDescPtr op_desc_e = std::make_shared("E", TYPE_E); - ge::OpDescPtr op_desc_e2 = std::make_shared("E2", TYPE_E); - ge::OpDescPtr op_desc_out = std::make_shared("NetOut", "NetOut"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT16); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - op_desc_a->AddOutputDesc(tensor_desc_a); - - op_desc_b->AddInputDesc(tensor_desc_a); - op_desc_b->AddOutputDesc(tensor_desc_a); - op_desc_b->AddOutputDesc(tensor_desc_a); - op_desc_b->AddOutputDesc(tensor_desc_a); - - op_desc_c->AddInputDesc(tensor_desc_a); - op_desc_c->AddOutputDesc(tensor_desc_a); - - op_desc_d->AddInputDesc(tensor_desc_a); - op_desc_e->AddInputDesc(tensor_desc_a); - op_desc_e2->AddInputDesc(tensor_desc_a); - - op_desc_out->AddInputDesc(tensor_desc_a); - - ge::NodePtr node_a = graph->AddNode(op_desc_a); - ge::NodePtr node_b = graph->AddNode(op_desc_b); - ge::NodePtr node_c = graph->AddNode(op_desc_c); - ge::NodePtr node_d = graph->AddNode(op_desc_d); - ge::NodePtr node_e = graph->AddNode(op_desc_e); - ge::NodePtr node_e2 = graph->AddNode(op_desc_e2); - ge::NodePtr node_out = graph->AddNode(op_desc_out); - - ge::GraphUtils::AddEdge(node_a->GetOutDataAnchor(0), node_b->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(0), node_c->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(1), node_d->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(1), node_e->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(2), node_e2->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_c->GetOutDataAnchor(0), node_out->GetInDataAnchor(0)); - - return graph; - } - - static ge::ComputeGraphPtr CreateTestOutputGraph2_3() { - ge::ComputeGraphPtr graph = std::make_shared("test1"); - ge::OpDescPtr op_desc_a = std::make_shared("A", TYPE_A); - ge::OpDescPtr op_desc_b = std::make_shared("B", TYPE_B); - ge::OpDescPtr op_desc_c = std::make_shared("C", TYPE_C); - ge::OpDescPtr op_desc_d = std::make_shared("D", TYPE_D); - ge::OpDescPtr op_desc_e = std::make_shared("E", TYPE_E); - ge::OpDescPtr op_desc_out = std::make_shared("NetOut", "NetOut"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT16); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - op_desc_a->AddOutputDesc(tensor_desc_a); - - op_desc_b->AddInputDesc(tensor_desc_a); - op_desc_b->AddOutputDesc(tensor_desc_a); - op_desc_b->AddOutputDesc(tensor_desc_a); - op_desc_b->AddOutputDesc(tensor_desc_a); - - op_desc_c->AddInputDesc(tensor_desc_a); - op_desc_c->AddOutputDesc(tensor_desc_a); - - op_desc_d->AddInputDesc(tensor_desc_a); - op_desc_e->AddInputDesc(tensor_desc_a); - - op_desc_out->AddInputDesc(tensor_desc_a); - - ge::NodePtr node_a = graph->AddNode(op_desc_a); - ge::NodePtr node_b = graph->AddNode(op_desc_b); - ge::NodePtr node_c = graph->AddNode(op_desc_c); - ge::NodePtr node_d = graph->AddNode(op_desc_d); - ge::NodePtr node_e = graph->AddNode(op_desc_e); - ge::NodePtr node_out = graph->AddNode(op_desc_out); - - ge::GraphUtils::AddEdge(node_a->GetOutDataAnchor(0), node_b->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(0), node_c->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(1), node_d->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(1), node_e->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_c->GetOutDataAnchor(0), node_out->GetInDataAnchor(0)); - - return graph; - } - - static ge::ComputeGraphPtr CreateTestOutputGraph2_4() { - ge::ComputeGraphPtr graph = std::make_shared("test1"); - ge::OpDescPtr op_desc_a = std::make_shared("A", TYPE_A); - ge::OpDescPtr op_desc_b = std::make_shared("B", TYPE_B); - ge::OpDescPtr op_desc_c = std::make_shared("C", TYPE_C); - ge::OpDescPtr op_desc_c2 = std::make_shared("C2", TYPE_C); - ge::OpDescPtr op_desc_d = std::make_shared("D", TYPE_D); - ge::OpDescPtr op_desc_e = std::make_shared("E", TYPE_E); - ge::OpDescPtr op_desc_e2 = std::make_shared("E2", TYPE_E); - ge::OpDescPtr op_desc_out = std::make_shared("NetOut", "NetOut"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT16); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - op_desc_a->AddOutputDesc(tensor_desc_a); - - op_desc_b->AddInputDesc(tensor_desc_a); - op_desc_b->AddOutputDesc(tensor_desc_a); - op_desc_b->AddOutputDesc(tensor_desc_a); - op_desc_b->AddOutputDesc(tensor_desc_a); - - op_desc_c->AddInputDesc(tensor_desc_a); - op_desc_c->AddOutputDesc(tensor_desc_a); - op_desc_c2->AddInputDesc(tensor_desc_a); - op_desc_c2->AddOutputDesc(tensor_desc_a); - - op_desc_d->AddInputDesc(tensor_desc_a); - op_desc_e->AddInputDesc(tensor_desc_a); - op_desc_e2->AddInputDesc(tensor_desc_a); - - op_desc_out->AddInputDesc(tensor_desc_a); - - ge::NodePtr node_a = graph->AddNode(op_desc_a); - ge::NodePtr node_b = graph->AddNode(op_desc_b); - ge::NodePtr node_c = graph->AddNode(op_desc_c); - ge::NodePtr node_c2 = graph->AddNode(op_desc_c2); - ge::NodePtr node_d = graph->AddNode(op_desc_d); - ge::NodePtr node_e = graph->AddNode(op_desc_e); - ge::NodePtr node_e2 = graph->AddNode(op_desc_e2); - ge::NodePtr node_out = graph->AddNode(op_desc_out); - - ge::GraphUtils::AddEdge(node_a->GetOutDataAnchor(0), node_b->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(0), node_c->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(1), node_d->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(1), node_c2->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(1), node_e->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(2), node_e2->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_c->GetOutDataAnchor(0), node_out->GetInDataAnchor(0)); - - return graph; - } - - - static ge::ComputeGraphPtr CreateTestOutputGraph2_5() { - ge::ComputeGraphPtr graph = std::make_shared("test1"); - ge::OpDescPtr op_desc_a = std::make_shared("A", TYPE_A); - ge::OpDescPtr op_desc_b = std::make_shared("B", TYPE_B); - ge::OpDescPtr op_desc_c = std::make_shared("C", TYPE_C); - ge::OpDescPtr op_desc_d = std::make_shared("D", TYPE_D); - ge::OpDescPtr op_desc_e = std::make_shared("E", TYPE_E); - ge::OpDescPtr op_desc_f = std::make_shared("F", TYPE_F); - ge::OpDescPtr op_desc_out = std::make_shared("NetOut", "NetOut"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT16); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - op_desc_a->AddOutputDesc(tensor_desc_a); - - op_desc_b->AddInputDesc(tensor_desc_a); - op_desc_b->AddOutputDesc(tensor_desc_a); - op_desc_b->AddOutputDesc(tensor_desc_a); - op_desc_b->AddOutputDesc(tensor_desc_a); - - op_desc_c->AddInputDesc(tensor_desc_a); - op_desc_c->AddOutputDesc(tensor_desc_a); - - op_desc_d->AddInputDesc(tensor_desc_a); - op_desc_e->AddInputDesc(tensor_desc_a); - op_desc_f->AddInputDesc(tensor_desc_a); - - op_desc_out->AddInputDesc(tensor_desc_a); - - ge::NodePtr node_a = graph->AddNode(op_desc_a); - ge::NodePtr node_b = graph->AddNode(op_desc_b); - ge::NodePtr node_c = graph->AddNode(op_desc_c); - ge::NodePtr node_d = graph->AddNode(op_desc_d); - ge::NodePtr node_e = graph->AddNode(op_desc_e); - ge::NodePtr node_f = graph->AddNode(op_desc_f); - ge::NodePtr node_out = graph->AddNode(op_desc_out); - - ge::GraphUtils::AddEdge(node_a->GetOutDataAnchor(0), node_b->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(0), node_c->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(1), node_d->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(1), node_e->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(2), node_f->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_c->GetOutDataAnchor(0), node_out->GetInDataAnchor(0)); - - return graph; - } - - static ge::ComputeGraphPtr CreateTestOutputGraph2_6() { - ge::ComputeGraphPtr graph = std::make_shared("test1"); - ge::OpDescPtr op_desc_a = std::make_shared("A", TYPE_A); - ge::OpDescPtr op_desc_b = std::make_shared("B", TYPE_B); - ge::OpDescPtr op_desc_c = std::make_shared("C", TYPE_C); - ge::OpDescPtr op_desc_d = std::make_shared("D", TYPE_D); - ge::OpDescPtr op_desc_e = std::make_shared("E", TYPE_E); - ge::OpDescPtr op_desc_f = std::make_shared("F", TYPE_F); - ge::OpDescPtr op_desc_out = std::make_shared("NetOut", "NetOut"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT16); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - op_desc_a->AddOutputDesc(tensor_desc_a); - - op_desc_b->AddInputDesc(tensor_desc_a); - op_desc_b->AddOutputDesc(tensor_desc_a); - op_desc_b->AddOutputDesc(tensor_desc_a); - op_desc_b->AddOutputDesc(tensor_desc_a); - - op_desc_c->AddInputDesc(tensor_desc_a); - op_desc_c->AddInputDesc(tensor_desc_a); - op_desc_c->AddOutputDesc(tensor_desc_a); - - op_desc_d->AddInputDesc(tensor_desc_a); - op_desc_e->AddInputDesc(tensor_desc_a); - op_desc_f->AddInputDesc(tensor_desc_a); - - op_desc_out->AddInputDesc(tensor_desc_a); - - ge::NodePtr node_a = graph->AddNode(op_desc_a); - ge::NodePtr node_b = graph->AddNode(op_desc_b); - ge::NodePtr node_c = graph->AddNode(op_desc_c); - ge::NodePtr node_d = graph->AddNode(op_desc_d); - ge::NodePtr node_e = graph->AddNode(op_desc_e); - ge::NodePtr node_f = graph->AddNode(op_desc_f); - ge::NodePtr node_out = graph->AddNode(op_desc_out); - - ge::GraphUtils::AddEdge(node_a->GetOutDataAnchor(0), node_b->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_f->GetOutDataAnchor(0), node_c->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(0), node_c->GetInDataAnchor(1)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(1), node_d->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(2), node_e->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_c->GetOutDataAnchor(0), node_out->GetInDataAnchor(0)); - - return graph; - } - - static ge::ComputeGraphPtr CreateTestOutputGraph2_7(bool success) { - ge::ComputeGraphPtr graph = std::make_shared("test1"); - ge::OpDescPtr op_desc_a = std::make_shared("A", TYPE_A); - ge::OpDescPtr op_desc_b = std::make_shared("B", TYPE_B); - ge::OpDescPtr op_desc_c = std::make_shared("C", TYPE_C); - ge::OpDescPtr op_desc_d = std::make_shared("D", TYPE_D); - ge::OpDescPtr op_desc_e = std::make_shared("E", TYPE_E); - ge::OpDescPtr op_desc_f = std::make_shared("F", TYPE_F); - ge::OpDescPtr op_desc_g = std::make_shared("G", TYPE_G); - ge::OpDescPtr op_desc_out = std::make_shared("NetOut", "NetOut"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT16); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - op_desc_a->AddOutputDesc(tensor_desc_a); - - op_desc_b->AddInputDesc(tensor_desc_a); - op_desc_b->AddOutputDesc(tensor_desc_a); - op_desc_b->AddOutputDesc(tensor_desc_a); - op_desc_b->AddOutputDesc(tensor_desc_a); - - op_desc_c->AddInputDesc(tensor_desc_a); - op_desc_c->AddInputDesc(tensor_desc_a); - op_desc_c->AddOutputDesc(tensor_desc_a); - - op_desc_d->AddInputDesc(tensor_desc_a); - op_desc_d->AddOutputDesc(tensor_desc_a); - - op_desc_e->AddInputDesc(tensor_desc_a); - op_desc_e->AddOutputDesc(tensor_desc_a); - - op_desc_f->AddInputDesc(tensor_desc_a); - op_desc_f->AddOutputDesc(tensor_desc_a); - - op_desc_g->AddInputDesc(tensor_desc_a); - op_desc_g->AddOutputDesc(tensor_desc_a); - - op_desc_out->AddInputDesc(tensor_desc_a); - - ge::NodePtr node_a = graph->AddNode(op_desc_a); - ge::NodePtr node_b = graph->AddNode(op_desc_b); - ge::NodePtr node_c = graph->AddNode(op_desc_c); - ge::NodePtr node_d = graph->AddNode(op_desc_d); - ge::NodePtr node_e = graph->AddNode(op_desc_e); - ge::NodePtr node_f = graph->AddNode(op_desc_f); - ge::NodePtr node_g = graph->AddNode(op_desc_g); - ge::NodePtr node_out = graph->AddNode(op_desc_out); - - ge::GraphUtils::AddEdge(node_a->GetOutDataAnchor(0), node_b->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(0), node_c->GetInDataAnchor(1)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(1), node_d->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_d->GetOutDataAnchor(0), node_f->GetInDataAnchor(0)); - if (success) { - ge::GraphUtils::AddEdge(node_f->GetOutDataAnchor(0), node_g->GetInDataAnchor(0)); - } - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(2), node_e->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_c->GetOutDataAnchor(0), node_out->GetInDataAnchor(0)); - - return graph; - } - - static ge::ComputeGraphPtr CreateTestOutputGraph3() { - ge::ComputeGraphPtr graph = std::make_shared("test1"); - ge::OpDescPtr op_desc_a = std::make_shared("A", TYPE_A); - ge::OpDescPtr op_desc_b = std::make_shared("B", TYPE_B); - ge::OpDescPtr op_desc_c = std::make_shared("C", TYPE_C); - ge::OpDescPtr op_desc_d = std::make_shared("D", TYPE_D); - ge::OpDescPtr op_desc_e = std::make_shared("E", TYPE_E); - ge::OpDescPtr op_desc_f = std::make_shared("F", TYPE_F); - ge::OpDescPtr op_desc_out = std::make_shared("NetOut", "NetOut"); - - //add descriptor - vector dim_a = {8, 4, 16, 16}; - GeShape shape_a(dim_a); - GeTensorDesc tensor_desc_a(shape_a); - tensor_desc_a.SetFormat(FORMAT_NCHW); - tensor_desc_a.SetOriginFormat(FORMAT_NCHW); - tensor_desc_a.SetDataType(DT_FLOAT16); - tensor_desc_a.SetOriginDataType(DT_FLOAT); - - op_desc_a->AddOutputDesc(tensor_desc_a); - op_desc_a->AddOutputDesc(tensor_desc_a); - - op_desc_f->AddInputDesc(tensor_desc_a); - - op_desc_b->AddInputDesc(tensor_desc_a); - op_desc_b->AddOutputDesc(tensor_desc_a); - op_desc_b->AddOutputDesc(tensor_desc_a); - - op_desc_c->AddInputDesc(tensor_desc_a); - op_desc_c->AddOutputDesc(tensor_desc_a); - - op_desc_d->AddInputDesc(tensor_desc_a); - op_desc_e->AddInputDesc(tensor_desc_a); - - op_desc_out->AddInputDesc(tensor_desc_a); - - ge::NodePtr node_a = graph->AddNode(op_desc_a); - ge::NodePtr node_b = graph->AddNode(op_desc_b); - ge::NodePtr node_c = graph->AddNode(op_desc_c); - ge::NodePtr node_d = graph->AddNode(op_desc_d); - ge::NodePtr node_e = graph->AddNode(op_desc_e); - ge::NodePtr node_f = graph->AddNode(op_desc_f); - ge::NodePtr node_out = graph->AddNode(op_desc_out); - - ge::GraphUtils::AddEdge(node_a->GetOutDataAnchor(0), node_b->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_a->GetOutDataAnchor(1), node_f->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(0), node_c->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(1), node_d->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_b->GetOutDataAnchor(1), node_e->GetInDataAnchor(0)); - ge::GraphUtils::AddEdge(node_c->GetOutDataAnchor(0), node_out->GetInDataAnchor(0)); - - return graph; - } - - }; - -TEST_F(UTESTGraphFusionPass4, UTESTGraphFusionPass4_1) { - ge::ComputeGraphPtr graph = CreateTestOutputGraph1(); - TestSetOutputsPass1 pass; - Status status = fe::FAILED; - - pass.SetName("test"); - status = pass.Run(*graph); - EXPECT_EQ(status, fe::NOT_CHANGED); -} - -TEST_F(UTESTGraphFusionPass4, UTESTGraphFusionPass4_2) { - ge::ComputeGraphPtr graph = CreateTestOutputGraph2(); - TestSetOutputsPass1 pass; - Status status = fe::FAILED; - - pass.SetName("test"); - status = pass.Run(*graph); - EXPECT_EQ(status, fe::SUCCESS); -} - -TEST_F(UTESTGraphFusionPass4, UTESTGraphFusionPass4_3) { - ge::ComputeGraphPtr graph = CreateTestOutputGraph2(); - TestSetOutputsPass2 pass; - Status status = fe::FAILED; - - pass.SetName("test"); - status = pass.Run(*graph); - EXPECT_EQ(status, fe::SUCCESS); -} - -TEST_F(UTESTGraphFusionPass4, UTESTGraphFusionPass4_4) { - ge::ComputeGraphPtr graph = CreateTestOutputGraph3(); - TestSetOutputsPass3 pass; - Status status = fe::FAILED; - - pass.SetName("test"); - status = pass.Run(*graph); - EXPECT_EQ(status, fe::SUCCESS); -} - -TEST_F(UTESTGraphFusionPass4, UTESTGraphFusionPassFuzzy) { - ge::ComputeGraphPtr graph = CreateTestOutputGraph2(); - TestSetOutputsPassFuzzy pass; - Status status = fe::FAILED; - - pass.SetName("test"); - status = pass.Run(*graph); - EXPECT_EQ(status, fe::SUCCESS); -} - -TEST_F(UTESTGraphFusionPass4, UTESTGraphFusionPassFuzzy_1) { - ge::ComputeGraphPtr graph = CreateTestOutputGraph2_1(); - TestSetOutputsPassFuzzy pass; - Status status = fe::FAILED; - - pass.SetName("test"); - status = pass.Run(*graph); - EXPECT_EQ(status, fe::NOT_CHANGED); -} - -TEST_F(UTESTGraphFusionPass4, UTESTGraphFusionPassFuzzy_2) { - ge::ComputeGraphPtr graph = CreateTestOutputGraph2_2(); - TestSetOutputsPassFuzzy2 pass; - Status status = fe::FAILED; - - pass.SetName("test"); - status = pass.Run(*graph); - EXPECT_EQ(status, fe::NOT_CHANGED); -} - -TEST_F(UTESTGraphFusionPass4, UTESTGraphFusionPassFuzzy_3) { - ge::ComputeGraphPtr graph = CreateTestOutputGraph2_3(); - TestSetOutputsPassFuzzy2 pass; - Status status = fe::FAILED; - - pass.SetName("test"); - status = pass.Run(*graph); - EXPECT_EQ(status, fe::NOT_CHANGED); -} - -TEST_F(UTESTGraphFusionPass4, UTESTGraphFusionPassFuzzy_4) { - ge::ComputeGraphPtr graph = CreateTestOutputGraph2_4(); - TestSetOutputsPassFuzzy2 pass; - Status status = fe::FAILED; - - pass.SetName("test"); - status = pass.Run(*graph); - EXPECT_EQ(status, fe::SUCCESS); -} - - -TEST_F(UTESTGraphFusionPass4, UTESTGraphFusionPassFuzzy_5) { - ge::ComputeGraphPtr graph = CreateTestOutputGraph2_5(); - TestSetOutputsPassFuzzy3 pass; - Status status = fe::FAILED; - - pass.SetName("test"); - status = pass.Run(*graph); - EXPECT_EQ(status, fe::SUCCESS); -} - -TEST_F(UTESTGraphFusionPass4, UTESTGraphFusionPassFuzzy_6) { - ge::ComputeGraphPtr graph = CreateTestOutputGraph2_6(); - TestSetOutputsPassFuzzy5 pass; - Status status = fe::FAILED; - - pass.SetName("test"); - status = pass.Run(*graph); - EXPECT_EQ(status, fe::SUCCESS); -} - -TEST_F(UTESTGraphFusionPass4, UTESTGraphFusionPassFuzzy_7) { - ge::ComputeGraphPtr graph = CreateTestOutputGraph2_7(false); - TestSetOutputsPassFuzzy6 pass; - Status status = fe::FAILED; - - pass.SetName("test"); - status = pass.Run(*graph); - EXPECT_EQ(status, fe::NOT_CHANGED); -} - -TEST_F(UTESTGraphFusionPass4, UTESTGraphFusionPassFuzzy_8) { - ge::ComputeGraphPtr graph = CreateTestOutputGraph2_7(true); - TestSetOutputsPassFuzzy6 pass; - Status status = fe::FAILED; - - pass.SetName("test"); - status = pass.Run(*graph); - EXPECT_EQ(status, fe::SUCCESS); -} -} diff --git a/tests/ut/register/testcase/register_graph_fusion_v2_utest.cc b/tests/ut/register/testcase/register_graph_fusion_v2_utest.cc deleted file mode 100644 index 9ecf36f949548af713407fc6ea1f4c528ac402a0..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/register_graph_fusion_v2_utest.cc +++ /dev/null @@ -1,271 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include "graph/compute_graph.h" -#include "graph/graph.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/graph_utils.h" - -#include "register/graph_optimizer/graph_fusion/fusion_pass_manager/fusion_pass_registry.h" -#include "register/graph_optimizer/graph_fusion/graph_fusion_pass_base.h" -#include "register/graph_optimizer/fusion_common/pattern_fusion_base_pass.h" -#include "register/pass_option_utils.h" -#include "graph/debug/ge_log.h" -#include "graph/ge_local_context.h" -#include "register/optimization_option_registry.h" - -using namespace testing; -using namespace ge; -using namespace fe; - -namespace fe { -namespace graph_fusion_reg_v2 { - -class UTestGraphFusionPassReg : public testing::Test { - public: - - protected: - - void SetUp() { - oopt.Initialize({}, {}); - } - - void TearDown() { - } - - ge::OptimizationOption &oopt = ge::GetThreadLocalContext().GetOo(); - - const std::unordered_map ®istered_opt_table = - ge::OptionRegistry::GetInstance().GetRegisteredOptTable(); -}; - - -const char *kOpTypeCast = "Cast"; -const char *kOpTypeRelu = "Relu"; - -const char *kPatternCast0 = "Cast0"; -const char *kPatternCast1 = "Cast1"; -const char *kPatternRelu = "Relu"; -#define UT_CHECK(cond, log_func, return_expr) \ - do { \ - if (cond) { \ - log_func; \ - return_expr; \ - } \ - } while (0) - -#define UT_CHECK_NOTNULL(val) \ - do { \ - if ((val) == nullptr) { \ - GELOGD("Parameter[%s] must not be null.", #val); \ - return fe::PARAM_INVALID; \ - } \ - } while (0) - -string pass_name_test = "CastCastFusionPass"; -string pass_name_test1 = "CastCastFusionPass1"; -class TestPass : public fe::PatternFusionBasePass { - using Mapping = std::map, std::vector, fe::CmpKey>; - protected: - - vector DefinePatterns() override { - vector patterns; - - FusionPattern *pattern = new(std::nothrow) FusionPattern("CastCastFusionPass"); - UT_CHECK(pattern == nullptr, GELOGD(" Fail to create a new pattern object."), - return patterns); - pattern->AddOpDesc(kPatternCast0, {kOpTypeCast}) - .AddOpDesc(kPatternRelu, {kOpTypeRelu}) - .AddOpDesc(kPatternCast1, {kOpTypeCast}) - .SetInputs(kPatternRelu, {kPatternCast0}) - .SetInputs(kPatternCast1, {kPatternRelu}) - .SetOutput(kPatternCast1); - - patterns.push_back(pattern); - - return patterns; - } - - Status Fusion(ge::ComputeGraph &graph, Mapping &mapping, vector &new_nodes) override { - FusionPattern pattern("CastCastFusionPass"); - DumpMapping(pattern, mapping); - - CheckGraphCycle(graph); - ge::NodePtr cast_Node0 = GetNodeFromMapping(kPatternCast0, mapping); - CheckOpSupported(cast_Node0); - CheckOpSupported(cast_Node0->GetOpDesc()); - CheckAccuracySupported(cast_Node0); - - UT_CHECK(cast_Node0 == nullptr, GELOGD("cast_Node0 is null,fusion failed."), - return NOT_CHANGED); - ge::OpDescPtr cast_desc0 = cast_Node0->GetOpDesc(); - UT_CHECK(cast_desc0 == nullptr, GELOGD("cast_Node0's Desc is null,fusion failed."), - return NOT_CHANGED); - - ge::NodePtr relu_Node = GetNodeFromMapping(kPatternRelu, mapping); - UT_CHECK(relu_Node == nullptr, GELOGD("relu_Node is null,fusion failed."), - return NOT_CHANGED); - ge::OpDescPtr relu_desc = relu_Node->GetOpDesc(); - UT_CHECK(cast_desc0 == nullptr, GELOGD("relu_Node's Desc is null,fusion failed."), - return NOT_CHANGED); - - auto relu_input = relu_desc->MutableInputDesc(0); - UT_CHECK_NOTNULL(relu_input); - auto relu_input_desc_dtype = relu_input->GetDataType(); - - auto relu_output = relu_desc->MutableOutputDesc(0); - UT_CHECK_NOTNULL(relu_output); - auto relu_output_desc_dtype = relu_output->GetDataType(); - if (relu_input_desc_dtype != DT_FLOAT || relu_output_desc_dtype != DT_FLOAT) { - GELOGD("Relu node [%s]'s input dtype or output dtype is unsuitable", relu_desc->GetName().c_str()); - return NOT_CHANGED; - } - - ge::NodePtr cast_Node1 = GetNodeFromMapping(kPatternCast1, mapping); - UT_CHECK(cast_Node1 == nullptr, GELOGD("cast_Node1 is null,fusion failed."), - return NOT_CHANGED); - ge::OpDescPtr cast_desc1 = cast_Node1->GetOpDesc(); - UT_CHECK(cast_desc0 == nullptr, GELOGD("cast_Node1's Desc is null,fusion failed."), - return NOT_CHANGED); - - auto cast0_input = cast_desc0->MutableInputDesc(0); - UT_CHECK_NOTNULL(cast0_input); - DataType cast0_in_d_type = cast0_input->GetDataType(); - auto cast1_output = cast_desc1->MutableOutputDesc(0); - UT_CHECK_NOTNULL(cast1_output); - DataType cast1_out_d_type = cast1_output->GetDataType(); - if (cast0_in_d_type != cast1_out_d_type) { - GELOGD("Cast Node0 [%s] input data type is not equal to Cast Node1 [%s] output data type ", - cast_Node0->GetName().c_str(), cast_Node1->GetName().c_str()); - return NOT_CHANGED; - } - - auto cast0_out_data_anchor = cast_Node0->GetOutDataAnchor(0); - UT_CHECK_NOTNULL(cast0_out_data_anchor); - if (cast0_out_data_anchor->GetPeerInDataAnchors().size() > 1) { - GELOGD("The first output anchor of Cast node[%s] has more than one peer in anchor.", - cast_Node0->GetName().c_str()); - return NOT_CHANGED; - } - - auto relu_out_data_anchor = relu_Node->GetOutDataAnchor(0); - UT_CHECK_NOTNULL(relu_out_data_anchor); - if (relu_out_data_anchor->GetPeerInDataAnchors().size() > 1) { - for (auto node : relu_Node->GetOutAllNodes()) { - if (node->GetType() != "Cast") { - GELOGD("The output anchor of Relu node has not Cast node,name is [%s] Type is [%s].", - node->GetName().c_str(), node->GetType().c_str()); - return NOT_CHANGED; - } - auto node_desc = node->GetOpDesc(); - UT_CHECK_NOTNULL(node_desc); - auto in_dtype = node_desc->MutableInputDesc(0)->GetDataType(); - auto out_dtype = node_desc->MutableOutputDesc(0)->GetDataType(); - if (in_dtype != DT_FLOAT || out_dtype != DT_FLOAT16) { - GELOGD("The Cast node [%s]'s indatatype is not equal to DT_FLOAT or outdatatype is not equal to DT_FLOAT16.", - node->GetName().c_str()); - return NOT_CHANGED; - } - } - } - - ge::ComputeGraphPtr graphPtr = relu_Node->GetOwnerComputeGraph(); - UT_CHECK_NOTNULL(graphPtr); - if (GraphUtils::IsolateNode(cast_Node0, {0}) != GRAPH_SUCCESS) { - GELOGD("Isolate op:%s(%s) failed", cast_Node0->GetName().c_str(), cast_Node0->GetType().c_str()); - return FAILED; - } - if (GraphUtils::RemoveNodeWithoutRelink(graphPtr, cast_Node0) != GRAPH_SUCCESS) { - GELOGD("[Remove][Node] %s, type:%s without relink in graph:%s failed", - cast_Node0->GetName().c_str(), cast_Node0->GetType().c_str(), graph.GetName().c_str()); - return FAILED; - } - for (auto inAnchor : relu_out_data_anchor->GetPeerInDataAnchors()) { - auto node = inAnchor->GetOwnerNode(); - UT_CHECK_NOTNULL(node); - if (GraphUtils::IsolateNode(node, {0}) != GRAPH_SUCCESS) { - GELOGD("Isolate op:%s(%s) failed", node->GetName().c_str(), node->GetType().c_str()); - return FAILED; - } - if (GraphUtils::RemoveNodeWithoutRelink(graphPtr, node) != GRAPH_SUCCESS) { - GELOGD("[Remove][Node] %s, type:%s without relink in graph:%s failed", - node->GetName().c_str(), node->GetType().c_str(), graph.GetName().c_str()); - return FAILED; - } - } - relu_desc->MutableInputDesc(0)->SetDataType(cast0_in_d_type); - relu_desc->MutableOutputDesc(0)->SetDataType(cast1_out_d_type); - new_nodes.push_back(relu_Node); - return SUCCESS; - } -}; - -TEST_F(UTestGraphFusionPassReg, test_case_01) { - REG_PASS(pass_name_test, GRAPH_FUSION_PASS_TYPE_RESERVED, TestPass, 0); - REG_PASS("", BUILT_IN_GRAPH_PASS, TestPass, 1); - REG_PASS(pass_name_test, BUILT_IN_GRAPH_PASS, TestPass, 2); - REG_PASS(pass_name_test, BUILT_IN_GRAPH_PASS, TestPass, 3); - std::map pass_desc = - FusionPassRegistry::GetInstance().GetPassDesc(SECOND_ROUND_BUILT_IN_GRAPH_PASS); - EXPECT_EQ(pass_desc.size(), 0); - pass_desc = - FusionPassRegistry::GetInstance().GetPassDesc(BUILT_IN_GRAPH_PASS); - EXPECT_EQ(pass_desc.size(), 1); - - EXPECT_EQ(pass_desc[pass_name_test].attr, 3); - - auto create_fn1 = pass_desc[pass_name_test].create_fn; - auto pattern_fusion_base_pass_ptr = std::unique_ptr( - dynamic_cast(create_fn1())); - auto patterns = pattern_fusion_base_pass_ptr->DefinePatterns(); - EXPECT_EQ(patterns.size(), 1); - for (auto pattern : patterns) { - if (pattern != nullptr) { - delete pattern; - pattern = nullptr; - } - } -} - -TEST_F(UTestGraphFusionPassReg, test_compile_level_case_01) { - auto pass_desc = FusionPassRegistry::GetInstance().GetPassDesc(BUILT_IN_GRAPH_PASS); - EXPECT_EQ(pass_desc.size(), 1); - - bool enable_flag = true; - auto ret = ge::PassOptionUtils::CheckIsPassEnabled(pass_name_test, enable_flag); - EXPECT_NE(ret, ge::GRAPH_SUCCESS); - - REG_PASS(pass_name_test, BUILT_IN_GRAPH_PASS, TestPass, COMPILE_LEVEL_O0 | COMPILE_LEVEL_O1 | COMPILE_LEVEL_O2); - REG_PASS(pass_name_test1, BUILT_IN_GRAPH_PASS, TestPass, COMPILE_LEVEL_O3); - - std::map ge_options = {{ge::OO_LEVEL, "O1"}, - {"ge.constLifecycle", "graph"}, - {"ge.oo.test_graph_fusion", "true"}, - {"ge.oo.test_graph_fusion_add_relu", "false"}}; - EXPECT_EQ(oopt.Initialize(ge_options, registered_opt_table), GRAPH_SUCCESS); - - pass_desc = FusionPassRegistry::GetInstance().GetPassDesc(BUILT_IN_GRAPH_PASS); - EXPECT_EQ(pass_desc.size(), 2); - enable_flag = true; - ret = ge::PassOptionUtils::CheckIsPassEnabled(pass_name_test, enable_flag); - EXPECT_EQ(ret, ge::GRAPH_SUCCESS); - EXPECT_TRUE(enable_flag == true); - - pass_desc = FusionPassRegistry::GetInstance().GetPassDesc(BUILT_IN_GRAPH_PASS); - EXPECT_EQ(pass_desc.size(), 2); - enable_flag = true; - ret = ge::PassOptionUtils::CheckIsPassEnabled(pass_name_test1, enable_flag); - EXPECT_EQ(ret, ge::GRAPH_SUCCESS); - EXPECT_EQ(enable_flag, false); -} - -} -} diff --git a/tests/ut/register/testcase/register_op_tiling_v1_ut.cc b/tests/ut/register/testcase/register_op_tiling_v1_ut.cc deleted file mode 100644 index 06d2ed8a516fe0cd2addf037557a647e7cf0b512..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/register_op_tiling_v1_ut.cc +++ /dev/null @@ -1,305 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/ge_tensor.h" -#include "graph/utils/attr_utils.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/op_desc.h" -#include "graph/operator.h" -#include "graph/compute_graph.h" -#include "graph/utils/attr_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "register/op_tiling_registry.h" -#include "op_tiling/op_tiling_utils.h" -#include "op_tiling/op_tiling_constants.h" -#include "op_tiling.h" - -using namespace std; -using namespace ge; - -namespace optiling { -class RegisterOpTilingV1UT : public testing::Test { -protected: - void SetUp() {} - - void TearDown() {} -}; -bool op_tiling_stub_v1(const TeOpParas &op_paras, const OpCompileInfo &compile_info, OpRunInfo &run_info) { - return true; -} - -static string parse_int(const std::stringstream& tiling_data) { - auto data = tiling_data.str(); - string result; - int32_t tmp = 0; - for (size_t i = 0; i < data.length(); i += sizeof(int32_t)) { - memcpy(&tmp, data.c_str() + i, sizeof(tmp)); - result += std::to_string(tmp); - result += " "; - } - - return result; -} - -static string parse_int(void *const addr_base, const uint64_t size) { - char *data = reinterpret_cast(addr_base); - string result; - int32_t tmp = 0; - for (size_t i = 0; i < size; i += sizeof(int32_t)) { - memcpy(&tmp, data + i, sizeof(tmp)); - result += std::to_string(tmp); - result += " "; - } - - return result; -} - -REGISTER_OP_TILING(ReluV1, op_tiling_stub_v1); -//REGISTER_OP_TILING(DynamicAtomicAddrClean, op_tiling_stub_v1); - -TEST_F(RegisterOpTilingV1UT, op_para_calculate_v1_1) { - OpDescPtr op_desc = make_shared("relu", "ReluV1"); - GeShape shape({4,3,14,14}); - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "compile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_json"; - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_JSON, compile_info_json); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - EXPECT_EQ(op_desc->GetTilingFuncInfo(), nullptr); - auto op = OpDescUtils::CreateOperatorFromNode(node); - utils::OpRunInfo run_info; - graphStatus ret = OpParaCalculateV2(op, run_info); - EXPECT_EQ(ret, GRAPH_SUCCESS); - auto &op_func_map = OpTilingFuncRegistry::RegisteredOpFuncInfo(); - auto iter = op_func_map.find("ReluV1"); - EXPECT_NE(iter, op_func_map.end()); - EXPECT_NE(&(iter->second), nullptr); - EXPECT_EQ(op_desc->GetTilingFuncInfo(), reinterpret_cast(&(iter->second))); - EXPECT_EQ(OpParaCalculateV2(op, run_info), GRAPH_SUCCESS); -} - -TEST_F(RegisterOpTilingV1UT, op_para_calculate_v1_2) { - OpDescPtr op_desc = make_shared("relu", "ReluVVV"); - GeShape shape({4,3,14,14}); - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "compile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_json"; - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_JSON, compile_info_json); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - utils::OpRunInfo run_info; - auto op = OpDescUtils::CreateOperatorFromNode(node); - graphStatus ret = OpParaCalculateV2(op, run_info); - EXPECT_EQ(ret, GRAPH_FAILED); - - OpTilingFuncInfo op_func_info(OP_TYPE_AUTO_TILING); - op_func_info.tiling_func_ = op_tiling_stub_v1; - std::unordered_map &tiling_func_map = OpTilingFuncRegistry::RegisteredOpFuncInfo(); - tiling_func_map.emplace(OP_TYPE_AUTO_TILING, op_func_info); - EXPECT_EQ(op_desc->GetTilingFuncInfo(), nullptr); - ret = OpParaCalculateV2(op, run_info); - EXPECT_EQ(ret, GRAPH_SUCCESS); - auto &op_func_map = OpTilingFuncRegistry::RegisteredOpFuncInfo(); - auto iter = op_func_map.find(OP_TYPE_AUTO_TILING); - EXPECT_NE(iter, op_func_map.end()); - EXPECT_NE(&(iter->second), nullptr); - EXPECT_EQ(op_desc->GetTilingFuncInfo(), reinterpret_cast(&(iter->second))); - tiling_func_map.erase(OP_TYPE_AUTO_TILING); -} - -TEST_F(RegisterOpTilingV1UT, op_para_calculate_v1_3) { - OpDescPtr op_desc = make_shared("relu", "ReluV1"); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "compile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_json"; - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_JSON, compile_info_json); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - utils::OpRunInfo run_info; - auto op = OpDescUtils::CreateOperatorFromNode(node); - graphStatus ret = OpParaCalculateV2(op, run_info); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(RegisterOpTilingV1UT, op_para_calculate_v1_4) { - OpDescPtr op_desc = make_shared("relu", "ReluV1"); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "compile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_json"; - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_JSON, compile_info_json); - - vector depend_names = {"x"}; - AttrUtils::SetListStr(op_desc, "_op_infer_depends", depend_names); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - utils::OpRunInfo run_info; - auto op = OpDescUtils::CreateOperatorFromNode(node); - graphStatus ret = OpParaCalculateV2(op, run_info); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(RegisterOpTilingV1UT, op_atomic_calculate_v1_1) { - OpDescPtr op_desc = make_shared("relu", OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - GeShape shape; - GeTensorDesc tensor_desc(shape); - tensor_desc.SetOriginShape(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "{\"_workspace_size_list\":[]}"; - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_JSON, compile_info_json); - std::vector atomic_output_indices = {0}; - (void) ge::AttrUtils::SetListInt(op_desc, ge::ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_indices); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - - std::unordered_map &tiling_func_map = OpTilingFuncRegistry::RegisteredOpFuncInfo(); - OpTilingFuncInfo op_func_info(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - op_func_info.tiling_func_ = op_tiling_stub_v1; - tiling_func_map.emplace(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN, op_func_info); - utils::OpRunInfo run_info; - EXPECT_EQ(op_desc->GetAtomicTilingFuncInfo(), nullptr); - graphStatus ret = OpAtomicCalculateV2(*node, run_info); - EXPECT_EQ(ret, GRAPH_SUCCESS); - auto &op_func_map = OpTilingFuncRegistry::RegisteredOpFuncInfo(); - auto iter = op_func_map.find(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - EXPECT_NE(iter, op_func_map.end()); - EXPECT_NE(&(iter->second), nullptr); - EXPECT_EQ(op_desc->GetAtomicTilingFuncInfo(), reinterpret_cast(&(iter->second))); - tiling_func_map.erase(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); -} - -TEST_F(RegisterOpTilingV1UT, op_atomic_calculate_v1_2) { - OpDescPtr op_desc = make_shared("relu", OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "{\"_workspace_size_list\":[]}"; - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_JSON, compile_info_json); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - - std::unordered_map &tiling_func_map = OpTilingFuncRegistry::RegisteredOpFuncInfo(); - OpTilingFuncInfo op_func_info(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - op_func_info.tiling_func_ = op_tiling_stub_v1; - tiling_func_map.emplace(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN, op_func_info); - utils::OpRunInfo run_info; - graphStatus ret = OpAtomicCalculateV2(*node, run_info); - EXPECT_EQ(ret, GRAPH_FAILED); - tiling_func_map.erase(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); -} - -TEST_F(RegisterOpTilingV1UT, op_atomic_calculate_v1_3) { - OpDescPtr op_desc = make_shared("relu", OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "{\"_workspace_size_list\":[]}"; - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_JSON, compile_info_json); - std::vector atomic_output_indices = {1}; - (void) ge::AttrUtils::SetListInt(op_desc, ge::ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_indices); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - - std::unordered_map &tiling_func_map = OpTilingFuncRegistry::RegisteredOpFuncInfo(); - OpTilingFuncInfo op_func_info(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - op_func_info.tiling_func_ = op_tiling_stub_v1; - tiling_func_map.emplace(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN, op_func_info); - utils::OpRunInfo run_info; - graphStatus ret = OpAtomicCalculateV2(*node, run_info); - EXPECT_EQ(ret, GRAPH_FAILED); - tiling_func_map.erase(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); -} - -TEST_F(RegisterOpTilingV1UT, add_tiling_data) { - utils::OpRunInfo run_info; - - int32_t data1 = 123; - int32_t data2 = 456; - run_info.AddTilingData(data1); - run_info.AddTilingData(data2); - run_info.SetClearAtomic(false); - run_info.SetTilingKey(100); - run_info.SetTilingCond(-1); - std::vector workspace{1, 2, 3}; - run_info.SetWorkspaces(workspace); - std::string get_data = parse_int(run_info.GetAllTilingData()); - EXPECT_EQ(get_data, "123 456 "); - EXPECT_EQ(run_info.GetClearAtomic(), false); - EXPECT_EQ(run_info.GetTilingKey(), 100); - EXPECT_EQ(run_info.GetWorkspaceNum(), 3); - EXPECT_EQ(run_info.GetAllWorkspaces(), workspace); - EXPECT_EQ(run_info.GetTilingCond(), -1); - - run_info.GetAllTilingData().str(""); - run_info.AddTilingData(reinterpret_cast(&data2), sizeof(int)); - get_data = parse_int(run_info.GetAllTilingData()); - EXPECT_EQ(get_data, "456 "); - - char arg_base[20] = {0}; - run_info.ResetWorkspace(); - run_info.ResetAddrBase(arg_base, sizeof(arg_base)); - data1 = 78; - data2 = 90; - run_info.AddTilingData(data1); - run_info.AddTilingData(data2); - run_info.AddWorkspace(4); - run_info.AddWorkspace(5); - get_data = parse_int(arg_base, sizeof(int) * 2); - EXPECT_EQ(get_data, "78 90 "); - workspace = std::vector{4, 5}; - EXPECT_EQ(run_info.GetAllWorkspaces(), workspace); - - run_info.ResetWorkspace(); - run_info.ResetAddrBase(arg_base, sizeof(arg_base)); - run_info.AddTilingData(reinterpret_cast(&data2), sizeof(int)); - run_info.AddWorkspace(6); - get_data = parse_int(arg_base, sizeof(int)); - EXPECT_EQ(get_data, "90 "); - workspace = std::vector{6}; - EXPECT_EQ(run_info.GetAllWorkspaces(), workspace); -} -} diff --git a/tests/ut/register/testcase/register_op_tiling_v2_ut.cc b/tests/ut/register/testcase/register_op_tiling_v2_ut.cc deleted file mode 100644 index 3abc1df691238c2cf75329c8683fd4eb6263edb8..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/register_op_tiling_v2_ut.cc +++ /dev/null @@ -1,260 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/ge_tensor.h" -#include "graph/utils/attr_utils.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/op_desc.h" -#include "graph/operator.h" -#include "graph/compute_graph.h" -#include "graph/utils/attr_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "register/op_tiling_registry.h" -#include "op_tiling/op_tiling_utils.h" -#include "op_tiling/op_tiling_constants.h" -#include "op_tiling.h" - -using namespace std; -using namespace ge; - -namespace optiling { -class RegisterOpTilingV2UT : public testing::Test { -protected: - void SetUp() {} - - void TearDown() {} -}; - -bool op_tiling_stub_v2(const Operator &op, const utils::OpCompileInfo &compile_info, utils::OpRunInfo &run_info) { - return true; -} - -REGISTER_OP_TILING_V2(ReluV2, op_tiling_stub_v2); - -TEST_F(RegisterOpTilingV2UT, replace_and_recovery_tensor_1) { - OpDescPtr op_desc = make_shared("relu", "ReluV2"); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - - std::vector indexes; - ReplaceEmptyShapeOfTensorDesc(op_desc, indexes); - std::vector expect_indexes = {0, 1, -1}; - EXPECT_EQ(indexes, expect_indexes); - EXPECT_EQ(op_desc->MutableInputDesc(0)->MutableShape().GetDimNum(), 1); - EXPECT_EQ(op_desc->MutableInputDesc(1)->MutableShape().GetDimNum(), 1); - EXPECT_EQ(op_desc->MutableOutputDesc(0)->MutableShape().GetDimNum(), 1); - EXPECT_EQ(op_desc->MutableInputDesc(0)->MutableShape().GetDim(0), 1); - EXPECT_EQ(op_desc->MutableInputDesc(1)->MutableShape().GetDim(0), 1); - EXPECT_EQ(op_desc->MutableOutputDesc(0)->MutableShape().GetDim(0), 1); - - RecoveryEmptyShapeOfTensorDesc(op_desc, indexes); - EXPECT_EQ(op_desc->MutableInputDesc(0)->MutableShape().GetDimNum(), 0); - EXPECT_EQ(op_desc->MutableInputDesc(1)->MutableShape().GetDimNum(), 0); - EXPECT_EQ(op_desc->MutableOutputDesc(0)->MutableShape().GetDimNum(), 0); -} - -TEST_F(RegisterOpTilingV2UT, replace_and_recovery_tensor_2) { - OpDescPtr op_desc = make_shared("relu", "ReluV2"); - GeShape shape({4,3,16,16}); - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - - std::vector indexes; - ReplaceEmptyShapeOfTensorDesc(op_desc, indexes); - EXPECT_EQ(indexes.size(), 0); - EXPECT_EQ(op_desc->MutableInputDesc(0)->MutableShape().GetDimNum(), 4); - EXPECT_EQ(op_desc->MutableInputDesc(1)->MutableShape().GetDimNum(), 4); - EXPECT_EQ(op_desc->MutableOutputDesc(0)->MutableShape().GetDimNum(), 4); - - RecoveryEmptyShapeOfTensorDesc(op_desc, indexes); - EXPECT_EQ(op_desc->MutableInputDesc(0)->MutableShape().GetDimNum(), 4); - EXPECT_EQ(op_desc->MutableInputDesc(1)->MutableShape().GetDimNum(), 4); - EXPECT_EQ(op_desc->MutableOutputDesc(0)->MutableShape().GetDimNum(), 4); -} - -TEST_F(RegisterOpTilingV2UT, op_para_calculate_v2_1) { - OpDescPtr op_desc = make_shared("relu", "ReluV2"); - GeShape shape({4,3,16,16}); - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "compile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_json"; - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_JSON, compile_info_json); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - utils::OpRunInfo run_info; - auto op = OpDescUtils::CreateOperatorFromNode(node); - graphStatus ret = OpParaCalculateV2(op, run_info); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(RegisterOpTilingV2UT, op_para_calculate_v2_2) { - OpDescPtr op_desc = make_shared("relu", "ReluVV"); - GeShape shape({4,3,16,16}); - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "compile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_json"; - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_JSON, compile_info_json); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - utils::OpRunInfo run_info; - auto op = OpDescUtils::CreateOperatorFromNode(node); - graphStatus ret = OpParaCalculateV2(op, run_info); - EXPECT_EQ(ret, GRAPH_FAILED); - - OpTilingFuncInfo op_func_info(OP_TYPE_AUTO_TILING); - op_func_info.tiling_func_v2_ = op_tiling_stub_v2; - std::unordered_map &tiling_func_map = OpTilingFuncRegistry::RegisteredOpFuncInfo(); - tiling_func_map.emplace(OP_TYPE_AUTO_TILING, op_func_info); - ret = OpParaCalculateV2(op, run_info); - EXPECT_EQ(ret, GRAPH_SUCCESS); - tiling_func_map.erase(OP_TYPE_AUTO_TILING); -} - -TEST_F(RegisterOpTilingV2UT, op_para_calculate_v2_3) { - OpDescPtr op_desc = make_shared("relu", "ReluV2"); - GeShape shape({4,3,16,16}); - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "compile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_json"; - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_JSON, compile_info_json); - - vector depend_names = {"x"}; - AttrUtils::SetListStr(op_desc, "_op_infer_depends", depend_names); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - utils::OpRunInfo run_info; - auto op = OpDescUtils::CreateOperatorFromNode(node); - graphStatus ret = OpParaCalculateV2(op, run_info); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(RegisterOpTilingV2UT, op_para_calculate_v2_4) { - OpDescPtr op_desc = make_shared("relu", "ReluV2"); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "compile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_json"; - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_JSON, compile_info_json); - - vector depend_names = {"x"}; - AttrUtils::SetListStr(op_desc, "_op_infer_depends", depend_names); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - utils::OpRunInfo run_info; - auto op = OpDescUtils::CreateOperatorFromNode(node); - graphStatus ret = OpParaCalculateV2(op, run_info); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(RegisterOpTilingV2UT, op_atomic_calculate_v2_1) { - OpDescPtr op_desc = make_shared("relu", OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "{\"_workspace_size_list\":[]}"; - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_JSON, compile_info_json); - std::vector atomic_output_indices = {0}; - (void) ge::AttrUtils::SetListInt(op_desc, ge::ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_indices); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - - std::unordered_map &tiling_func_map = OpTilingFuncRegistry::RegisteredOpFuncInfo(); - OpTilingFuncInfo op_func_info(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - op_func_info.tiling_func_v2_ = op_tiling_stub_v2; - tiling_func_map.emplace(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN, op_func_info); - - utils::OpRunInfo run_info; - graphStatus ret = OpAtomicCalculateV2(*node, run_info); - EXPECT_EQ(ret, GRAPH_SUCCESS); - tiling_func_map.erase(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); -} - -TEST_F(RegisterOpTilingV2UT, op_atomic_calculate_v2_2) { - OpDescPtr op_desc = make_shared("relu", OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "{\"_workspace_size_list\":[]}"; - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_JSON, compile_info_json); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - - std::unordered_map &tiling_func_map = OpTilingFuncRegistry::RegisteredOpFuncInfo(); - OpTilingFuncInfo op_func_info(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - op_func_info.tiling_func_v2_ = op_tiling_stub_v2; - tiling_func_map.emplace(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN, op_func_info); - utils::OpRunInfo run_info; - graphStatus ret = OpAtomicCalculateV2(*node, run_info); - EXPECT_EQ(ret, GRAPH_FAILED); - tiling_func_map.erase(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); -} - -TEST_F(RegisterOpTilingV2UT, op_atomic_calculate_v2_3) { - OpDescPtr op_desc = make_shared("relu", OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "{\"_workspace_size_list\":[]}"; - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_JSON, compile_info_json); - std::vector atomic_output_indices = {1}; - (void) ge::AttrUtils::SetListInt(op_desc, ge::ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_indices); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - - std::unordered_map &tiling_func_map = OpTilingFuncRegistry::RegisteredOpFuncInfo(); - OpTilingFuncInfo op_func_info(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - op_func_info.tiling_func_v2_ = op_tiling_stub_v2; - tiling_func_map.emplace(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN, op_func_info); - utils::OpRunInfo run_info; - graphStatus ret = OpAtomicCalculateV2(*node, run_info); - EXPECT_EQ(ret, GRAPH_FAILED); - tiling_func_map.erase(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); -} -} diff --git a/tests/ut/register/testcase/register_op_tiling_v3_ut.cc b/tests/ut/register/testcase/register_op_tiling_v3_ut.cc deleted file mode 100644 index efa1c255b54781be5747d397c5729b06e170e885..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/register_op_tiling_v3_ut.cc +++ /dev/null @@ -1,332 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/ge_tensor.h" -#include "graph/utils/attr_utils.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/op_desc.h" -#include "graph/operator.h" -#include "graph/compute_graph.h" -#include "graph/utils/attr_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "register/op_tiling_registry.h" -#include "op_tiling/op_tiling_utils.h" -#include "op_tiling/op_tiling_constants.h" -#include "op_tiling/op_compile_info_manager.h" -#include "op_tiling.h" - -using namespace std; -using namespace ge; - -namespace optiling { -class RegisterOpTilingV3UT : public testing::Test { -protected: - void SetUp() {} - - void TearDown() {} -}; -bool op_tiling_stub_v3(const Operator &op, const void* value, OpRunInfoV2 &run_info) { - return true; -} - -void* op_parse_stub_v3(const Operator &op, const ge::AscendString &compile_info_json) { -// static void *p = new int(3); - static int x = 1024; - void *p = &x; - return p; -} - -void* op_parse_stub_null_v3(const Operator &op, const ge::AscendString &compile_info_json) { - return nullptr; -} - -void* op_parse_stub_dt1_v3(const Operator &op, const ge::AscendString &compile_info_json) { -// static void *p = new int(3); - static int x = 1024; - void *p = &x; - return p; -} - -void* op_parse_stub_dt2_v3(const Operator &op, const ge::AscendString &compile_info_json) { -// static void *p = new int(3); - static int x = 1024; - void *p = &x; - return p; -} - -void* op_parse_stub_dt3_v3(const Operator &op, const ge::AscendString &compile_info_json) { -// static void *p = new int(3); - static int x = 1024; - void *p = &x; - return p; -} - -void* op_parse_stub_dt4_v3(const Operator &op, const ge::AscendString &compile_info_json) { -// static void *p = new int(3); - static int x = 1024; - void *p = &x; - return p; -} - -REGISTER_OP_TILING_V3(ReluV3, op_tiling_stub_v3, op_parse_stub_v3); -REGISTER_OP_TILING_V3(ReluNullV3, op_tiling_stub_v3, op_parse_stub_null_v3); - -TEST_F(RegisterOpTilingV3UT, op_para_calculate_v3_1) { - OpDescPtr op_desc = make_shared("relu", "ReluV3"); - GeShape shape({4,3,14,14}); - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "compile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_json"; - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_JSON, compile_info_json); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - utils::OpRunInfo run_info; - auto op = OpDescUtils::CreateOperatorFromNode(node); - graphStatus ret = OpParaCalculateV2(op, run_info); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - bool flag = CompileInfoCache::Instance().HasCompileInfo("compile_info_key"); - EXPECT_EQ(flag, true); - - ret = OpParaCalculateV2(op, run_info); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(RegisterOpTilingV3UT, op_para_calculate_v3_2) { - OpDescPtr op_desc = make_shared("relu", "ReluVV3"); - GeShape shape({4,3,14,14}); - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "compile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_json"; - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_JSON, compile_info_json); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - utils::OpRunInfo run_info; - auto op = OpDescUtils::CreateOperatorFromNode(node); - graphStatus ret = OpParaCalculateV2(op, run_info); - EXPECT_EQ(ret, GRAPH_FAILED); - - OpTilingFuncInfo op_func_info(OP_TYPE_AUTO_TILING); - op_func_info.tiling_func_v3_ = op_tiling_stub_v3; - op_func_info.parse_func_v3_ = op_parse_stub_v3; - std::unordered_map &tiling_func_map = OpTilingFuncRegistry::RegisteredOpFuncInfo(); - tiling_func_map.emplace(OP_TYPE_AUTO_TILING, op_func_info); - ret = OpParaCalculateV2(op, run_info); - EXPECT_EQ(ret, GRAPH_SUCCESS); - tiling_func_map.erase(OP_TYPE_AUTO_TILING); -} - -TEST_F(RegisterOpTilingV3UT, op_para_calculate_v3_3) { - OpDescPtr op_desc = make_shared("relu", "ReluV3"); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "compile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_json"; - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_JSON, compile_info_json); - - vector depend_names = {"x"}; - AttrUtils::SetListStr(op_desc, "_op_infer_depends", depend_names); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - utils::OpRunInfo run_info; - auto op = OpDescUtils::CreateOperatorFromNode(node); - graphStatus ret = OpParaCalculateV2(op, run_info); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - ret = OpParaCalculateV2(op, run_info); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(RegisterOpTilingV3UT, op_para_calculate_v3_4) { - OpDescPtr op_desc = make_shared("relu", "ReluNullV3"); - GeShape shape({4,3,14,14}); - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key_null"; - string compile_info_json = "compile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_json"; - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_JSON, compile_info_json); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - utils::OpRunInfo run_info; - auto op = OpDescUtils::CreateOperatorFromNode(node); - graphStatus ret = OpParaCalculateV2(op, run_info); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(RegisterOpTilingV3UT, op_atomic_calculate_v3_1) { - OpDescPtr op_desc = make_shared("relu", OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "{\"_workspace_size_list\":[]}"; - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_JSON, compile_info_json); - std::vector atomic_output_indices = {0}; - (void) ge::AttrUtils::SetListInt(op_desc, ge::ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_indices); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - - std::unordered_map &tiling_func_map = OpTilingFuncRegistry::RegisteredOpFuncInfo(); - OpTilingFuncInfo op_func_info(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - op_func_info.tiling_func_v3_ = op_tiling_stub_v3; - op_func_info.parse_func_v3_ = op_parse_stub_dt1_v3; - tiling_func_map.emplace(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN, op_func_info); - - utils::OpRunInfo run_info; - graphStatus ret = OpAtomicCalculateV2(*node, run_info); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - ret = OpAtomicCalculateV2(*node, run_info); - EXPECT_EQ(ret, GRAPH_SUCCESS); - tiling_func_map.erase(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); -} - -TEST_F(RegisterOpTilingV3UT, op_atomic_calculate_v3_2) { - OpDescPtr op_desc = make_shared("relu", OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "{\"_workspace_size_list\":[]}"; - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_JSON, compile_info_json); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - - std::unordered_map &tiling_func_map = OpTilingFuncRegistry::RegisteredOpFuncInfo(); - OpTilingFuncInfo op_func_info(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - op_func_info.tiling_func_v3_ = op_tiling_stub_v3; - op_func_info.parse_func_v3_ = op_parse_stub_dt2_v3; - tiling_func_map.emplace(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN, op_func_info); - - utils::OpRunInfo run_info; - graphStatus ret = OpAtomicCalculateV2(*node, run_info); - EXPECT_EQ(ret, GRAPH_FAILED); - tiling_func_map.erase(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); -} - -TEST_F(RegisterOpTilingV3UT, op_atomic_calculate_v3_3) { - OpDescPtr op_desc = make_shared("relu", OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "{\"_workspace_size_list\":[]}"; - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_JSON, compile_info_json); - std::vector atomic_output_indices = {1}; - (void) ge::AttrUtils::SetListInt(op_desc, ge::ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_indices); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - - std::unordered_map &tiling_func_map = OpTilingFuncRegistry::RegisteredOpFuncInfo(); - OpTilingFuncInfo op_func_info(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - op_func_info.tiling_func_v3_ = op_tiling_stub_v3; - op_func_info.parse_func_v3_ = op_parse_stub_dt3_v3; - tiling_func_map.emplace(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN, op_func_info); - - utils::OpRunInfo run_info; - graphStatus ret = OpAtomicCalculateV2(*node, run_info); - EXPECT_EQ(ret, GRAPH_FAILED); - tiling_func_map.erase(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); -} - -TEST_F(RegisterOpTilingV3UT, op_atomic_calculate_v3_4) { - OpDescPtr op_desc = make_shared("relu", OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "{\"_workspace_size_list\":[]}"; - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_JSON, compile_info_json); - std::map temp_map = {{1,1}, {2,2}, {3,3}}; - std::map> atomic_workspace_info = {{"qwer", temp_map}}; - op_desc->SetExtAttr(ge::EXT_ATTR_ATOMIC_WORKSPACE_INFO, atomic_workspace_info); - std::vector workspace_bytes = {1,2,3,4}; - op_desc->SetWorkspaceBytes(workspace_bytes); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - - std::unordered_map &tiling_func_map = OpTilingFuncRegistry::RegisteredOpFuncInfo(); - OpTilingFuncInfo op_func_info(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - op_func_info.tiling_func_v3_ = op_tiling_stub_v3; - op_func_info.parse_func_v3_ = op_parse_stub_dt4_v3; - tiling_func_map.emplace(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN, op_func_info); - - utils::OpRunInfo run_info; - graphStatus ret = OpAtomicCalculateV2(*node, run_info); - EXPECT_EQ(ret, GRAPH_SUCCESS); - tiling_func_map.erase(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); -} - -TEST_F(RegisterOpTilingV3UT, op_atomic_calculate_v3_5) { - OpDescPtr op_desc = make_shared("relu", OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key_null"; - string compile_info_json = "{\"_workspace_size_list\":[]}"; - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_JSON, compile_info_json); - std::vector atomic_output_indices = {0}; - (void) ge::AttrUtils::SetListInt(op_desc, ge::ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_indices); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - - std::unordered_map &tiling_func_map = OpTilingFuncRegistry::RegisteredOpFuncInfo(); - OpTilingFuncInfo op_func_info(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - op_func_info.tiling_func_v3_ = op_tiling_stub_v3; - op_func_info.parse_func_v3_ = op_parse_stub_null_v3; - tiling_func_map.emplace(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN, op_func_info); - - utils::OpRunInfo run_info; - graphStatus ret = OpAtomicCalculateV2(*node, run_info); - EXPECT_EQ(ret, GRAPH_FAILED); - - tiling_func_map.erase(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); -} - -} diff --git a/tests/ut/register/testcase/register_op_tiling_v4_ut.cc b/tests/ut/register/testcase/register_op_tiling_v4_ut.cc deleted file mode 100644 index 318b3a20f867c4cbcddf669b68b90d07ec622c93..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/register_op_tiling_v4_ut.cc +++ /dev/null @@ -1,311 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include "graph/ge_tensor.h" -#include "graph/utils/attr_utils.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/op_desc.h" -#include "graph/operator.h" -#include "graph/compute_graph.h" -#include "graph/utils/attr_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "register/op_tiling_registry.h" -#include "op_tiling/op_tiling_utils.h" -#include "op_tiling/op_tiling_constants.h" -#include "op_tiling/op_compile_info_manager.h" -#include "op_tiling.h" - -using namespace std; -using namespace ge; - -namespace optiling { -class CompileInfoJson : public CompileInfoBase { -public: - CompileInfoJson(const std::string &json) : json_str_(json) {} - ~CompileInfoJson() {} -private: - std::string json_str_; -}; - -class RegisterOpTilingV4UT : public testing::Test { -protected: - void SetUp() {} - - void TearDown() {} -}; -bool op_tiling_stub_v4(const Operator &op, const CompileInfoPtr value, OpRunInfoV2 &run_info) { - return true; -} - -CompileInfoPtr op_parse_stub_v4(const Operator &op, const ge::AscendString &compile_info_json) { -// static void *p = new int(3); - CompileInfoPtr info = std::make_shared("qwer"); - return info; -} - -CompileInfoPtr op_parse_stub_null_v4(const Operator &op, const ge::AscendString &compile_info_json) { - return nullptr; -} - -REGISTER_OP_TILING_V4(ReluV4, op_tiling_stub_v4, op_parse_stub_v4); -REGISTER_OP_TILING_V4(ReluNullV4, op_tiling_stub_v4, op_parse_stub_null_v4); - -TEST_F(RegisterOpTilingV4UT, op_para_calculate_v4_1) { - OpDescPtr op_desc = make_shared("relu", "ReluV4"); - GeShape shape({4,3,14,14}); - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "compile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_json"; - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_JSON, compile_info_json); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - utils::OpRunInfo run_info; - auto op = OpDescUtils::CreateOperatorFromNode(node); - graphStatus ret = OpParaCalculateV2(op, run_info); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - bool flag = CompileInfoManager::Instance().HasCompileInfo("compile_info_key"); - EXPECT_EQ(flag, true); - - ret = OpParaCalculateV2(op, run_info); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(RegisterOpTilingV4UT, op_para_calculate_v4_2) { - OpDescPtr op_desc = make_shared("relu", "ReluVV3"); - GeShape shape({4,3,14,14}); - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "compile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_json"; - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_JSON, compile_info_json); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - utils::OpRunInfo run_info; - auto op = OpDescUtils::CreateOperatorFromNode(node); - graphStatus ret = OpParaCalculateV2(op, run_info); - EXPECT_EQ(ret, GRAPH_FAILED); - - OpTilingFuncInfo op_func_info(OP_TYPE_AUTO_TILING); - op_func_info.tiling_func_v4_ = op_tiling_stub_v4; - op_func_info.parse_func_v4_ = op_parse_stub_v4; - std::unordered_map &tiling_func_map = OpTilingFuncRegistry::RegisteredOpFuncInfo(); - tiling_func_map.emplace(OP_TYPE_AUTO_TILING, op_func_info); - ret = OpParaCalculateV2(op, run_info); - EXPECT_EQ(ret, GRAPH_SUCCESS); - tiling_func_map.erase(OP_TYPE_AUTO_TILING); -} - -TEST_F(RegisterOpTilingV4UT, op_para_calculate_v4_3) { - OpDescPtr op_desc = make_shared("relu", "ReluV4"); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "compile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_json"; - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_JSON, compile_info_json); - - vector depend_names = {"x"}; - AttrUtils::SetListStr(op_desc, "_op_infer_depends", depend_names); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - utils::OpRunInfo run_info; - auto op = OpDescUtils::CreateOperatorFromNode(node); - graphStatus ret = OpParaCalculateV2(op, run_info); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - ret = OpParaCalculateV2(op, run_info); - EXPECT_EQ(ret, GRAPH_SUCCESS); -} - -TEST_F(RegisterOpTilingV4UT, op_para_calculate_v4_4) { - OpDescPtr op_desc = make_shared("relu", "ReluNullV4"); - GeShape shape({4,3,14,14}); - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key_null"; - string compile_info_json = "compile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_jsoncompile_info_json"; - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_JSON, compile_info_json); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - utils::OpRunInfo run_info; - auto op = OpDescUtils::CreateOperatorFromNode(node); - graphStatus ret = OpParaCalculateV2(op, run_info); - EXPECT_EQ(ret, GRAPH_FAILED); -} - -TEST_F(RegisterOpTilingV4UT, op_atomic_calculate_v4_1) { - OpDescPtr op_desc = make_shared("relu", OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "{\"_workspace_size_list\":[]}"; - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_JSON, compile_info_json); - std::vector atomic_output_indices = {0}; - (void) ge::AttrUtils::SetListInt(op_desc, ge::ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_indices); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - - std::unordered_map &tiling_func_map = OpTilingFuncRegistry::RegisteredOpFuncInfo(); - OpTilingFuncInfo op_func_info(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - op_func_info.tiling_func_v4_ = op_tiling_stub_v4; - op_func_info.parse_func_v4_ = op_parse_stub_v4; - tiling_func_map.emplace(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN, op_func_info); - - utils::OpRunInfo run_info; - graphStatus ret = OpAtomicCalculateV2(*node, run_info); - EXPECT_EQ(ret, GRAPH_SUCCESS); - - ret = OpAtomicCalculateV2(*node, run_info); - EXPECT_EQ(ret, GRAPH_SUCCESS); - tiling_func_map.erase(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); -} - -TEST_F(RegisterOpTilingV4UT, op_atomic_calculate_v4_2) { - OpDescPtr op_desc = make_shared("relu", OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "{\"_workspace_size_list\":[]}"; - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_JSON, compile_info_json); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - - std::unordered_map &tiling_func_map = OpTilingFuncRegistry::RegisteredOpFuncInfo(); - OpTilingFuncInfo op_func_info(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - op_func_info.tiling_func_v4_ = op_tiling_stub_v4; - op_func_info.parse_func_v4_ = op_parse_stub_v4; - tiling_func_map.emplace(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN, op_func_info); - - utils::OpRunInfo run_info; - graphStatus ret = OpAtomicCalculateV2(*node, run_info); - EXPECT_EQ(ret, GRAPH_FAILED); - tiling_func_map.erase(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); -} - -TEST_F(RegisterOpTilingV4UT, op_atomic_calculate_v4_3) { - OpDescPtr op_desc = make_shared("relu", OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "{\"_workspace_size_list\":[]}"; - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_JSON, compile_info_json); - std::vector atomic_output_indices = {1}; - (void) ge::AttrUtils::SetListInt(op_desc, ge::ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_indices); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - - std::unordered_map &tiling_func_map = OpTilingFuncRegistry::RegisteredOpFuncInfo(); - OpTilingFuncInfo op_func_info(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - op_func_info.tiling_func_v4_ = op_tiling_stub_v4; - op_func_info.parse_func_v4_ = op_parse_stub_v4; - tiling_func_map.emplace(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN, op_func_info); - - utils::OpRunInfo run_info; - graphStatus ret = OpAtomicCalculateV2(*node, run_info); - EXPECT_EQ(ret, GRAPH_FAILED); - tiling_func_map.erase(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); -} - -TEST_F(RegisterOpTilingV4UT, op_atomic_calculate_v4_4) { - OpDescPtr op_desc = make_shared("relu", OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "{\"_workspace_size_list\":[]}"; - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_JSON, compile_info_json); - std::map temp_map = {{1,1}, {2,2}, {3,3}}; - std::map> atomic_workspace_info = {{"qwer", temp_map}}; - op_desc->SetExtAttr(ge::EXT_ATTR_ATOMIC_WORKSPACE_INFO, atomic_workspace_info); - std::vector workspace_bytes = {1,2,3,4}; - op_desc->SetWorkspaceBytes(workspace_bytes); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - - std::unordered_map &tiling_func_map = OpTilingFuncRegistry::RegisteredOpFuncInfo(); - OpTilingFuncInfo op_func_info(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - op_func_info.tiling_func_v4_ = op_tiling_stub_v4; - op_func_info.parse_func_v4_ = op_parse_stub_v4; - tiling_func_map.emplace(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN, op_func_info); - - utils::OpRunInfo run_info; - graphStatus ret = OpAtomicCalculateV2(*node, run_info); - EXPECT_EQ(ret, GRAPH_SUCCESS); - tiling_func_map.erase(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); -} - -TEST_F(RegisterOpTilingV4UT, op_atomic_calculate_v4_5) { - OpDescPtr op_desc = make_shared("relu", OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - GeShape shape; - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - string compile_info_key = "compile_info_key_null"; - string compile_info_json = "{\"_workspace_size_list\":[]}"; - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, ATOMIC_COMPILE_INFO_JSON, compile_info_json); - std::vector atomic_output_indices = {0}; - (void) ge::AttrUtils::SetListInt(op_desc, ge::ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_indices); - - ComputeGraphPtr graph = make_shared("test"); - NodePtr node = graph->AddNode(op_desc); - - std::unordered_map &tiling_func_map = OpTilingFuncRegistry::RegisteredOpFuncInfo(); - OpTilingFuncInfo op_func_info(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - op_func_info.tiling_func_v4_ = op_tiling_stub_v4; - op_func_info.parse_func_v4_ = op_parse_stub_null_v4; - tiling_func_map.emplace(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN, op_func_info); - - utils::OpRunInfo run_info; - graphStatus ret = OpAtomicCalculateV2(*node, run_info); - EXPECT_EQ(ret, GRAPH_FAILED); - - tiling_func_map.erase(OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); -} - -} diff --git a/tests/ut/register/testcase/register_optiling_unittest.cc b/tests/ut/register/testcase/register_optiling_unittest.cc deleted file mode 100644 index 80504a954fb89873035d49c189949a961079ebf6..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/register_optiling_unittest.cc +++ /dev/null @@ -1,352 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include "register/op_tiling_registry.h" -#include "op_tiling/op_tiling.cc" -#include "common/sgt_slice_type.h" -#include "graph_builder_utils.h" -using namespace std; -using namespace ge; -using namespace ffts; -namespace optiling { -using ByteBuffer = std::stringstream; -class RegisterOpTilingUT : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -TEST_F(RegisterOpTilingUT, byte_buffer_test) { - EXPECT_NO_THROW( - ByteBuffer stream; - char *dest = nullptr; - size_t size = ByteBufferGetAll(stream, dest, 2); - cout << size << endl; - ); -} - -TEST_F(RegisterOpTilingUT, op_run_info_test) { - std::shared_ptr run_info = make_shared(8, true, 64); - int64_t work_space; - graphStatus ret = run_info->GetWorkspace(0, work_space); - EXPECT_EQ(ret, GRAPH_FAILED); - vector work_space_vec = {10, 20, 30, 40}; - run_info->SetWorkspaces(work_space_vec); - ret = run_info->GetWorkspace(1, work_space); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(work_space, 20); - EXPECT_EQ(run_info->GetWorkspaceNum(), 4); - string str = "test"; - run_info->AddTilingData(str); - - - std::shared_ptr run_info_2 = make_shared(*run_info); - ret = run_info_2->GetWorkspace(2, work_space); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(work_space, 30); - - utils::OpRunInfo run_info_3 = *run_info; - ret = run_info_3.GetWorkspace(3, work_space); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(work_space, 40); - - utils::OpRunInfo &run_info_4 = *run_info; - ret = run_info_4.GetWorkspace(0, work_space); - EXPECT_EQ(ret, GRAPH_SUCCESS); - EXPECT_EQ(work_space, 10); -} - -TEST_F(RegisterOpTilingUT, op_compile_info_test) { - std::shared_ptr compile_info = make_shared(); - string str_key = "key"; - string str_value = "value"; - AscendString key(str_key.c_str()); - AscendString value(str_value.c_str()); - compile_info->SetKey(key); - compile_info->SetValue(value); - - std::shared_ptr compile_info_2 = make_shared(key, value); - EXPECT_EQ(compile_info_2->GetKey() == key, true); - EXPECT_EQ(compile_info_2->GetValue() == value, true); - - std::shared_ptr compile_info_3 = make_shared(str_key, str_value); - EXPECT_EQ(compile_info_3->GetKey() == key, true); - EXPECT_EQ(compile_info_3->GetValue() == value, true); - - std::shared_ptr compile_info_4 = make_shared(*compile_info); - EXPECT_EQ(compile_info_4->GetKey() == key, true); - EXPECT_EQ(compile_info_4->GetValue() == value, true); - - utils::OpCompileInfo compile_info_5 = *compile_info; - EXPECT_EQ(compile_info_5.GetKey() == key, true); - EXPECT_EQ(compile_info_5.GetValue() == value, true); - - utils::OpCompileInfo &compile_info_6 = *compile_info; - EXPECT_EQ(compile_info_6.GetKey() == key, true); - EXPECT_EQ(compile_info_6.GetValue() == value, true); -} - -TEST_F(RegisterOpTilingUT, te_op_paras_test) { - OpDescPtr op_desc = make_shared("relu", OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN); - GeShape shape({1,4,1,1}); - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - int32_t attr_value = 1024; - AttrUtils::SetInt(op_desc, "some_int_attr", attr_value); - vector attr_vec = {11, 22, 33, 44}; - AttrUtils::SetListInt(op_desc, "some_int_vec", attr_vec); - TeOpParas op_param; - op_param.op_type = op_desc->GetType(); - VarAttrHelper::InitTeOpVarAttr(op_desc, op_param.var_attrs); - size_t size = 0; - EXPECT_NO_THROW( - op_param.var_attrs.GetData("some_int_attr", "xxx", size); - op_param.var_attrs.GetData("some_int_attr", "Int32", size); - op_param.var_attrs.GetData("some_int_vec", "ListInt32", size); - ); -} - -bool op_tiling_stub(const Operator &op, const utils::OpCompileInfo &compile_info, utils::OpRunInfo &run_info) { - return true; -} - -REGISTER_OP_TILING_V2(ReluV2, op_tiling_stub); - -TEST_F(RegisterOpTilingUT, OpFftsPlusCalculate_1) { - auto root_builder = ut::GraphBuilder("root"); - const auto &node = root_builder.AddNode("relu", "ReluV2", 1, 1); - const auto &op_desc = node->GetOpDesc(); - const Operator op = OpDescUtils::CreateOperatorFromNode(node); - - ThreadSliceMapDyPtr slice_info_ptr = std::make_shared(); - vector vec_1; - vec_1.push_back(1); - vector> vec_2; - vec_2.push_back(vec_1); - vec_2.push_back(vec_1); - slice_info_ptr->parallel_window_size = 2; - slice_info_ptr->slice_instance_num = 2; - slice_info_ptr->input_tensor_slice.push_back(vec_2); - slice_info_ptr->input_tensor_slice.push_back(vec_2); - slice_info_ptr->output_tensor_slice.push_back(vec_2); - slice_info_ptr->output_tensor_slice.push_back(vec_2); - - (void)op_desc->SetExtAttr(ffts::kAttrSgtStructInfoDy, slice_info_ptr); - GeShape shape({4,1,3,4,16}); - GeTensorDesc tensor_desc(shape, ge::FORMAT_NCHW, ge::DT_FLOAT); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddOutputDesc("y", tensor_desc); - std::vector op_run_info; - EXPECT_EQ(OpFftsPlusCalculate(op, op_run_info), ge::GRAPH_FAILED); - - string compile_info_key = "compile_info_key"; - string compile_info_json = "compile_info_json"; - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_JSON, compile_info_json); - auto dstAnchor = node->GetInDataAnchor(0); - ge::AnchorUtils::SetStatus(dstAnchor, ge::ANCHOR_DATA); - EXPECT_EQ(OpFftsPlusCalculate(op, op_run_info), ge::GRAPH_SUCCESS); -} - -// slice instance over -TEST_F(RegisterOpTilingUT, OpFftsPlusCalculate_2) { - auto root_builder = ut::GraphBuilder("root"); - const auto &node = root_builder.AddNode("relu", "ReluV2", 1, 1); - const auto &op_desc = node->GetOpDesc(); - const Operator op = OpDescUtils::CreateOperatorFromNode(node); - - ThreadSliceMapDyPtr slice_info_ptr = std::make_shared(); - vector vec_1; - vec_1.push_back(1); - vector> vec_2; - vec_2.push_back(vec_1); - vec_2.push_back(vec_1); - slice_info_ptr->parallel_window_size = 2; - slice_info_ptr->slice_instance_num = 4; - slice_info_ptr->input_tensor_slice.push_back(vec_2); - slice_info_ptr->input_tensor_slice.push_back(vec_2); - slice_info_ptr->output_tensor_slice.push_back(vec_2); - slice_info_ptr->output_tensor_slice.push_back(vec_2); - slice_info_ptr->input_tensor_indexes.push_back(0); - slice_info_ptr->output_tensor_indexes.push_back(0); - (void)op_desc->SetExtAttr(ffts::kAttrSgtStructInfoDy, slice_info_ptr); - GeShape shape({4,1,3,4,16}); - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddOutputDesc("y", tensor_desc); - string compile_info_key = "compile_info_key"; - string compile_info_json = "compile_info_json"; - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_KEY, compile_info_key); - (void)ge::AttrUtils::SetStr(op_desc, COMPILE_INFO_JSON, compile_info_json); - std::vector op_run_info; - EXPECT_EQ(OpFftsPlusCalculate(op, op_run_info), ge::GRAPH_FAILED); -} - -TEST_F(RegisterOpTilingUT, PostProcCalculateV2_SUCCESS) { - auto root_builder = ut::GraphBuilder("root"); - const auto &node = root_builder.AddNode("relu", "ReluV2", 1, 1); - Operator op = OpDescUtils::CreateOperatorFromNode(node); - OpDescPtr op_desc = node->GetOpDesc(); - (void)ge::AttrUtils::SetStr(op_desc, "_alias_engine_name", "TEST"); - std::vector workspaces = { 1, 2, 3}; - OpRunInfoV2 run_info; - run_info.SetWorkspaces(workspaces); - workspaces.emplace_back(5); - op_desc->SetWorkspaceBytes(workspaces); - ge::graphStatus ret = PostProcCalculateV2(op, run_info); - EXPECT_EQ(ret, ge::GRAPH_SUCCESS); -} - -TEST_F(RegisterOpTilingUT, PostProcMemoryCheck1) { - auto root_builder = ut::GraphBuilder("root"); - const auto &node = root_builder.AddNode("relu", "ReluV2", 2, 1); - GeShape shape({3,4,2,1}); - GeTensorDesc tensor_desc(shape); - OpDescPtr op_desc = node->GetOpDesc(); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddInputDesc("y", tensor_desc); - op_desc->AddOutputDesc("z", tensor_desc); - Operator op = OpDescUtils::CreateOperatorFromNode(node); - std::vector workspaces = { 1, 2, 3}; - OpRunInfoV2 run_info; - run_info.SetWorkspaces(workspaces); - (void)ge::AttrUtils::SetBool(op_desc, kMemoryCheck, false); - (void)PostProcMemoryCheck(op, run_info); - ByteBuffer &data = run_info.GetAllTilingData(); - cout << "TEST" << data.str() << endl; - EXPECT_EQ(data.str().empty(), true); - (void)ge::AttrUtils::SetBool(op_desc, kMemoryCheck, true); - (void)ge::AttrUtils::SetInt(op_desc, kOriOpParaSize, 64); - (void)PostProcMemoryCheck(op, run_info); - ByteBuffer &data1 = run_info.GetAllTilingData(); - cout << "TEST1" << data1.str().c_str() << endl; - EXPECT_EQ(data1.str().empty(), true); - run_info.ResetAddrBase(nullptr, 1024); - (void)PostProcMemoryCheck(op, run_info); - ByteBuffer &data2 = run_info.GetAllTilingData(); - cout << "TEST2" << data2.str().c_str() << endl; - EXPECT_EQ(data2.str().empty(), false); -} - -TEST_F(RegisterOpTilingUT, UpDateNodeShapeBySliceInfo1) { - auto root_builder = ut::GraphBuilder("root"); - const auto &node = root_builder.AddNode("relu", "ReluV2", 1, 1); - OpDescPtr op_desc = node->GetOpDesc(); - ThreadSliceMapDyPtr slice_info_ptr; - slice_info_ptr = std::make_shared(); - vector vec_1; - vec_1.push_back(1); - vector> vec_2; - vector> vec_3; - vec_2.push_back(vec_1); - vec_2.push_back(vec_1); - vec_3.push_back(vec_1); - slice_info_ptr->parallel_window_size = 2; - slice_info_ptr->slice_instance_num = 2; - slice_info_ptr->input_tensor_slice.push_back(vec_2); - slice_info_ptr->input_tensor_slice.push_back(vec_2); - slice_info_ptr->output_tensor_slice.push_back(vec_3); - slice_info_ptr->input_tensor_indexes.push_back(0); - slice_info_ptr->output_tensor_indexes.push_back(0); - (void)node->GetOpDesc()->SetExtAttr(ffts::kAttrSgtStructInfo, slice_info_ptr); - GeShape shape({4,1,3,4,16}); - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - vector ori_shape; - bool same_shape = false; - auto ret = UpDateNodeShapeBySliceInfo(slice_info_ptr, op_desc, 2, ori_shape, same_shape); - EXPECT_EQ(ret, ge::GRAPH_FAILED); - op_desc->AddOutputDesc("y", tensor_desc); - ret = UpDateNodeShapeBySliceInfo(slice_info_ptr, op_desc, 0, ori_shape, same_shape); - EXPECT_EQ(ret, ge::GRAPH_SUCCESS); -} - -TEST_F(RegisterOpTilingUT, UpDateNodeShapeBySliceInfo2) { - auto root_builder = ut::GraphBuilder("root"); - const auto &node = root_builder.AddNode("relu", "ReluV2", 1, 1); - OpDescPtr op_desc = node->GetOpDesc(); - ThreadSliceMapDyPtr slice_info_ptr; - slice_info_ptr = std::make_shared(); - vector vec_1; - vec_1.push_back(1); - vector> vec_2; - vec_2.push_back(vec_1); - vec_2.push_back(vec_1); - slice_info_ptr->parallel_window_size = 2; - slice_info_ptr->slice_instance_num = 2; - slice_info_ptr->input_tensor_slice.push_back(vec_2); - slice_info_ptr->input_tensor_slice.push_back(vec_2); - slice_info_ptr->output_tensor_slice.push_back(vec_2); - slice_info_ptr->output_tensor_slice.push_back(vec_2); - slice_info_ptr->input_tensor_indexes.push_back(0); - slice_info_ptr->input_tensor_indexes.push_back(1); - slice_info_ptr->input_tensor_indexes.push_back(2); - slice_info_ptr->output_tensor_indexes.push_back(0); - slice_info_ptr->output_tensor_indexes.push_back(2); - GeShape shape({4,1,3,4,16}); - GeTensorDesc tensor_desc(shape); - op_desc->AddInputDesc("x", tensor_desc); - op_desc->AddOutputDesc("y", tensor_desc); - vector ori_shape; - bool same_shape = false; - auto ret = UpDateNodeShapeBySliceInfo(slice_info_ptr, op_desc, 0, ori_shape, same_shape); - EXPECT_EQ(ret, ge::PARAM_INVALID); - slice_info_ptr->input_tensor_indexes.push_back(0); - slice_info_ptr->input_tensor_indexes.push_back(1); - slice_info_ptr->input_tensor_indexes.push_back(2); - slice_info_ptr->output_tensor_indexes.push_back(0); - slice_info_ptr->output_tensor_indexes.push_back(2); - ret = UpDateNodeShapeBySliceInfo(slice_info_ptr, op_desc, 1, ori_shape, same_shape); - EXPECT_EQ(ret, ge::PARAM_INVALID); - ret = UpDateNodeShapeBack(op_desc, slice_info_ptr, ori_shape); - EXPECT_EQ(ret, ge::GRAPH_FAILED); -} - -TEST_F(RegisterOpTilingUT, op_run_info_test_new_tiling_interface1) { - utils::OpRunInfo run_info; - uint64_t max_size = 0; - void * base = run_info.GetAddrBase(max_size); - run_info.SetAddrBaseOffset(10); - EXPECT_TRUE(base == NULL); -} - -TEST_F(RegisterOpTilingUT, op_run_info_test_new_tiling_interface2) { - EXPECT_NO_THROW( - utils::OpRunInfo run_info; - int v1 = 1; - int64_t v2 = 2; - run_info << v1; - run_info << v2; - ); -} - -TEST_F(RegisterOpTilingUT, op_run_info_test_local_memory_size) { - utils::OpRunInfo run_info; - uint32_t local_memory_size = run_info.GetLocalMemorySize(); - EXPECT_EQ(local_memory_size, 0U); // default value - - const uint32_t test_val = 100U; - run_info.SetLocalMemorySize(test_val); - local_memory_size = run_info.GetLocalMemorySize(); - EXPECT_EQ(local_memory_size, test_val); // set value - - utils::OpRunInfo run_info2 = run_info; // copy constructor - local_memory_size = run_info2.GetLocalMemorySize(); - EXPECT_EQ(local_memory_size, test_val); - - utils::OpRunInfo run_info3(1, 2, 3); - local_memory_size = run_info3.GetLocalMemorySize(); - EXPECT_EQ(local_memory_size, 0U); // default value -} -} // namespace ge diff --git a/tests/ut/register/testcase/register_prototype_unittest.cc b/tests/ut/register/testcase/register_prototype_unittest.cc deleted file mode 100644 index 25858fc7b2720cfd8c5efa56cb3ce37694387d75..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/register_prototype_unittest.cc +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include - -#include "register/prototype_pass_registry.h" -namespace ge { -class UtestProtoTypeRegister : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -class RegisterPass : public ProtoTypeBasePass { - public: - Status Run(google::protobuf::Message *message) { return SUCCESS; } -}; - -class RegisterFail : public ProtoTypeBasePass { - public: - Status Run(google::protobuf::Message *message) { return FAILED; } -}; - -REGISTER_PROTOTYPE_PASS("RegisterPass", RegisterPass, domi::CAFFE); -REGISTER_PROTOTYPE_PASS("RegisterPass", RegisterPass, domi::CAFFE); - -TEST_F(UtestProtoTypeRegister, register_test) { - auto pass_vec = ProtoTypePassRegistry::GetInstance().GetCreateFnByType(domi::CAFFE); - EXPECT_EQ(pass_vec.size(), 1); -} - -TEST_F(UtestProtoTypeRegister, register_test_fail) { - REGISTER_PROTOTYPE_PASS(nullptr, RegisterPass, domi::CAFFE); - REGISTER_PROTOTYPE_PASS("RegisterFail", RegisterFail, domi::CAFFE); - - ProtoTypePassRegistry::GetInstance().RegisterProtoTypePass(nullptr, nullptr, domi::CAFFE); - auto pass_vec = ProtoTypePassRegistry::GetInstance().GetCreateFnByType(domi::CAFFE); - EXPECT_NE(pass_vec.size(), 1); -} -} // namespace ge diff --git a/tests/ut/register/testcase/register_scope_graph_unittest.cc b/tests/ut/register/testcase/register_scope_graph_unittest.cc deleted file mode 100644 index e330e8d87e7c2b387a8efd46aa25993891264a4f..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/register_scope_graph_unittest.cc +++ /dev/null @@ -1,682 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include "graph/debug/ge_attr_define.h" -#include "external/register/scope/scope_fusion_pass_register.h" -#include "register/scope/scope_graph_impl.h" -#include "register/scope/scope_pass_impl.h" -#include "register/scope/scope_pass_registry_impl.h" - -using namespace ge; -class UtestScopeGraph : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -/* --- sub0, sub1 only for UT test --- -* placeholder0 placeholder1 -* | /\ /\ | -* | / \/ \ | -* | / /\ \ | -* | | / \ | | -* | add0 mul0 | -* | /\ c/|\ | -* | / sub0 / | \ | -* mul1 ---- / | add1 -* \ | |\ -* \ ---- add2 | sub1 -* | | -* retval0 retval1 -*/ - -void CreateGraphDef(domi::tensorflow::GraphDef &graph_def) { - // 1. add node - auto placeholder0 = graph_def.add_node(); - auto placeholder1 = graph_def.add_node(); - auto add0 = graph_def.add_node(); - auto add1 = graph_def.add_node(); - auto sub0 = graph_def.add_node(); - auto mul0 = graph_def.add_node(); - auto mul1 = graph_def.add_node(); - auto add2 = graph_def.add_node(); - auto retval0 = graph_def.add_node(); - auto retval1 = graph_def.add_node(); - - // 2. set info - placeholder0->set_name("placeholder0"); - placeholder0->set_op("PlaceHolder"); - placeholder1->set_name("placeholder1"); - placeholder1->set_op("PlaceHolder"); - - add0->set_name("add0"); - add0->set_op("Add"); - add1->set_name("add1"); - add1->set_op("Add"); - add2->set_name("add2"); - add2->set_op("Add"); - sub0->set_name("add0/sub0"); - sub0->set_op("Sub"); - - mul0->set_name("mul0"); - mul0->set_op("Mul"); - mul1->set_name("mul1"); - mul1->set_op("Mul"); - - retval0->set_name("retval0"); - retval0->set_op("_RetVal"); - retval1->set_name("retval1"); - retval1->set_op("_RetVal"); - - // 3. add edges - add0->add_input("placeholder0"); - add0->add_input("placeholder1"); - sub0->add_input("placeholder0"); - sub0->add_input("placeholder1"); - - mul0->add_input("placeholder0"); - mul0->add_input("placeholder1"); - - mul1->add_input("placeholder0"); - mul1->add_input("add0"); - mul1->add_input("^mul0"); - - add1->add_input("mul0"); - add1->add_input("placeholder1"); - - add2->add_input("mul1"); - add2->add_input("mul0"); - - retval0->add_input("add2:0"); - retval1->add_input("add1:0"); -} - -TEST_F(UtestScopeGraph, test_build_scope_graph_succ) { - domi::tensorflow::GraphDef graph_def; - - CreateGraphDef(graph_def); - std::shared_ptr scope_graph = std::make_shared(); - ASSERT_NE(scope_graph, nullptr); - Status ret = scope_graph->Init(); - ASSERT_EQ(ret, SUCCESS); - auto &impl = scope_graph->impl_; - impl->BuildScopeGraph(&graph_def); - auto nodes_map = impl->GetNodesMap(); - EXPECT_EQ(nodes_map.size(), 10); - - // checkpoint 1 - auto mul0_iter = nodes_map.find("mul0"); - ASSERT_NE(mul0_iter, nodes_map.end()); - std::vector mul0_inputs; - std::vector mul0_outputs; - mul0_iter->second->GetAttr(ATTR_NAME_ORIGIN_GRAPH_NODE_INPUTS, mul0_inputs); - mul0_iter->second->GetAttr(ATTR_NAME_ORIGIN_GRAPH_NODE_OUTPUTS, mul0_outputs); - ASSERT_EQ(mul0_inputs.size(), 2); - EXPECT_EQ(mul0_inputs.at(0), "0:placeholder0:0"); - EXPECT_EQ(mul0_inputs.at(1), "1:placeholder1:0"); - ASSERT_EQ(mul0_outputs.size(), 3); - EXPECT_EQ(mul0_outputs.at(0), "-1:mul1:-1"); - EXPECT_EQ(mul0_outputs.at(1), "0:add1:0"); - EXPECT_EQ(mul0_outputs.at(2), "0:add2:1"); - - // checkpoint 2 - auto mul1_iter = nodes_map.find("mul1"); - ASSERT_NE(mul1_iter, nodes_map.end()); - std::vector mul1_inputs; - std::vector mul1_outputs; - mul1_iter->second->GetAttr(ATTR_NAME_ORIGIN_GRAPH_NODE_INPUTS, mul1_inputs); - mul1_iter->second->GetAttr(ATTR_NAME_ORIGIN_GRAPH_NODE_OUTPUTS, mul1_outputs); - ASSERT_EQ(mul1_inputs.size(), 3); - EXPECT_EQ(mul1_inputs.at(0), "-1:mul0:-1"); - EXPECT_EQ(mul1_inputs.at(1), "0:placeholder0:0"); - EXPECT_EQ(mul1_inputs.at(2), "1:add0:0"); - ASSERT_EQ(mul1_outputs.size(), 1); - EXPECT_EQ(mul1_outputs.at(0), "0:add2:0"); -} - -TEST_F(UtestScopeGraph, test_build_scope_graph_node_without_inout) { - domi::tensorflow::GraphDef graph_def; - auto no_op = graph_def.add_node(); - no_op->set_name("test_scope/no_op"); - no_op->set_op("NoOp"); - - std::shared_ptr scope_graph = std::make_shared(); - ASSERT_NE(scope_graph, nullptr); - Status ret = scope_graph->Init(); - ASSERT_EQ(ret, SUCCESS); - auto &impl = scope_graph->impl_; - impl->BuildScopeGraph(&graph_def); - - auto nodes_map = impl->GetNodesMap(); - EXPECT_EQ(nodes_map.size(), 1); - auto iter = nodes_map.find("test_scope/no_op"); - ASSERT_NE(iter, nodes_map.end()); - std::vector inputs; - std::vector outputs; - graphStatus get_input_attr = iter->second->GetAttr(ATTR_NAME_ORIGIN_GRAPH_NODE_INPUTS, inputs); - graphStatus get_output_attr = iter->second->GetAttr(ATTR_NAME_ORIGIN_GRAPH_NODE_OUTPUTS, outputs); - ASSERT_EQ(get_input_attr, GRAPH_SUCCESS); - ASSERT_EQ(get_output_attr, GRAPH_SUCCESS); - EXPECT_EQ(inputs.size(), 0); - EXPECT_EQ(outputs.size(), 0); -} - -TEST_F(UtestScopeGraph, test_build_scope_graph_failed) { - domi::tensorflow::GraphDef graph_def; - auto placeholder0 = graph_def.add_node(); - auto placeholder1 = graph_def.add_node(); - auto add0 = graph_def.add_node(); - - placeholder0->set_name("placeholder0"); - placeholder0->set_op("PlaceHolder"); - placeholder1->set_name("placeholder1"); - placeholder1->set_op("PlaceHolder"); - - add0->set_name("add0"); - add0->set_op("Add"); - add0->add_input("placeholder0"); - add0->add_input("placeholder1"); - - std::shared_ptr scope_graph = std::make_shared(); - ASSERT_NE(scope_graph, nullptr); - Status ret = scope_graph->Init(); - ASSERT_EQ(ret, SUCCESS); - auto &impl = scope_graph->impl_; - - // 1. input name is invalied - add0->set_input(0, "placeholder0:invalid:input"); - impl->BuildScopeGraph(&graph_def); - auto nodes_map = impl->GetNodesMap(); - EXPECT_EQ(nodes_map.size(), 0); - - // 2. index is invalid - add0->set_input(0, "placeholder0:s1"); - impl->BuildScopeGraph(&graph_def); - nodes_map = impl->GetNodesMap(); - EXPECT_EQ(nodes_map.size(), 0); - - // 3. index is out of range - add0->set_input(0, "placeholder0:12356890666666"); - impl->BuildScopeGraph(&graph_def); - nodes_map = impl->GetNodesMap(); - EXPECT_EQ(nodes_map.size(), 0); - - // index is negative - add0->set_input(0, "placeholder0:-1"); - impl->BuildScopeGraph(&graph_def); - nodes_map = impl->GetNodesMap(); - EXPECT_EQ(nodes_map.size(), 0); -} - -TEST_F(UtestScopeGraph, IsFusionOpTest) { - bool retBool; - Status stat; - FusionScopesResult *retFusnScopRst; - - domi::tensorflow::GraphDef graph_def; - std::shared_ptr scope_graph = std::make_shared(); - ASSERT_NE(scope_graph, nullptr); - Status ret = scope_graph->Init(); - ASSERT_EQ(ret, SUCCESS); - auto &impl = scope_graph->impl_; - impl->BuildScopeGraph(&graph_def); - - // AddFusionScopesResult - FusionScopesResult *fusionResult = new (std::nothrow) FusionScopesResult(); - ASSERT_NE(fusionResult, nullptr); - stat = fusionResult->Init(); - EXPECT_EQ(stat, SUCCESS); - impl->AddFusionScopesResult(nullptr); - impl->AddFusionScopesResult(fusionResult); - - // add nodes for check FusionOp - std::vector fusionRstNodes; - OperatorPtr op1(new (std::nothrow) ge::Operator("addTest", "Add")); - fusionRstNodes.push_back(op1); - OperatorPtr op2(new (std::nothrow) ge::Operator("Sub", "Sub")); - fusionRstNodes.push_back(op2); - OperatorPtr op3(new (std::nothrow) ge::Operator("Mul", "Mul")); - fusionRstNodes.push_back(op3); - fusionResult->impl_->AddNodes(fusionRstNodes); - - // IsFusionOp - domi::tensorflow::NodeDef *emptyNode = nullptr; - retBool = impl->IsFusionOp(emptyNode); - EXPECT_EQ(retBool, false); - retFusnScopRst = impl->GetFusionScopesResults(emptyNode); - EXPECT_EQ(retFusnScopRst, nullptr); - - domi::tensorflow::NodeDef *tmpNode = graph_def.add_node(); - tmpNode->set_name("div"); - tmpNode->set_op("Div"); - tmpNode->add_input("placeholder0"); - tmpNode->add_input("placeholder1"); - retBool = impl->IsFusionOp(tmpNode); - EXPECT_EQ(retBool, false); - retFusnScopRst = impl->GetFusionScopesResults(tmpNode); - EXPECT_EQ(retFusnScopRst, nullptr); - retFusnScopRst = impl->GetFusionScopesResults(std::string("div")); - EXPECT_EQ(retFusnScopRst, nullptr); - - // IsFusionOpChild - std::vector info_list; - retBool = impl->IsFusionOpChild(std::string("nodeName"), info_list); - EXPECT_EQ(retBool, false); - retBool = impl->IsFusionOpChild(std::string("addTest"), info_list); - EXPECT_EQ(retBool, true); - retBool = impl->IsFusionOpChild(std::string("Sub"), info_list); - EXPECT_EQ(retBool, true); - retBool = impl->IsFusionOpChild(std::string("Mul"), info_list); - EXPECT_EQ(retBool, true); - - // FusionOpChildIgnore - retBool = impl->FusionOpChildIgnore(info_list.front()); - EXPECT_EQ(retBool, true); - - std::vector index_map = {1, 2}; - fusionResult->InsertInputs("Sub", index_map); - fusionResult->InsertOutputs("Mul", index_map); - retBool = impl->FusionOpChildIgnore(info_list.back()); - EXPECT_EQ(retBool, false); - - // GetInputOrOutputIndex - int32_t old_index = -1; - int32_t new_index = 2; - stat = impl->GetInputOrOutputIndex(info_list.front(), old_index, true, new_index); - EXPECT_EQ(new_index, -1); - EXPECT_EQ(stat, SUCCESS); - - old_index = 666; - stat = impl->GetInputOrOutputIndex(info_list.front(), old_index, true, new_index); - EXPECT_EQ(new_index, kFusionDisableIndex); - EXPECT_EQ(stat, SUCCESS); - - old_index = 1; - stat = impl->GetInputOrOutputIndex(info_list.back(), old_index, true, new_index); - EXPECT_EQ(new_index, kFusionDisableIndex); - EXPECT_EQ(stat, SUCCESS); -} - -TEST_F(UtestScopeGraph, IsFusionOpTest_Fail) { - // AddFusionScopesResult - FusionScopesResult *fusionResult = new (std::nothrow) FusionScopesResult(); - ASSERT_NE(fusionResult, nullptr); - std::string name = "name"; - fusionResult->SetName(name); - fusionResult->SetName(name.c_str()); - fusionResult->SetType(name); - fusionResult->SetType(name.c_str()); - fusionResult->SetDescription(name); - fusionResult->SetDescription(name.c_str()); - EXPECT_EQ(fusionResult->Name(), ""); - AscendString as; - EXPECT_EQ(fusionResult->Name(as), ge::GRAPH_PARAM_INVALID); - EXPECT_EQ(fusionResult->Nodes().empty(), true); - fusionResult->InsertInputs(name, {}); - fusionResult->InsertInputs(name.c_str(), {}); - fusionResult->InsertOutputs(name, {}); - delete fusionResult; -} - -TEST_F(UtestScopeGraph, ScopeImplAddNodesTest) { - Status ret; - domi::tensorflow::GraphDef graph_def; - std::shared_ptr scope_graph = std::make_shared(); - ASSERT_NE(scope_graph, nullptr); - ret = scope_graph->Init(); - ASSERT_EQ(ret, SUCCESS); - auto &impl = scope_graph->impl_; - impl->BuildScopeGraph(&graph_def); - const ScopeTree *scopeTree = scope_graph->GetScopeTree(); - - OperatorPtr nodeDef1 = nullptr; - scopeTree->GetAllScopes().front()->impl_->AddNode(nodeDef1); - scopeTree->impl_->AddNodeToScope(nodeDef1); - - OperatorPtr nodeDef2(new (std::nothrow) ge::Operator("add0/sub0", "Add")); - scopeTree->GetAllScopes().front()->impl_->AddNode(nodeDef2); - scopeTree->impl_->AddNodeToScope(nodeDef2); - EXPECT_EQ(scopeTree->GetAllScopes().empty(), false); - - std::unordered_map node_map; - scopeTree->GetAllScopes().front()->AllNodesMap(); - ret = scopeTree->GetAllScopes().front()->AllNodesMap(node_map); - EXPECT_EQ(ret, ge::SUCCESS); - - std::vector scopes = scopeTree->GetAllScopes().front()->impl_->GetAllSubScopes(); - EXPECT_EQ(scopes.empty(), false); -} - -TEST_F(UtestScopeGraph, GetScopeLastName) { - Status ret; - domi::tensorflow::GraphDef graph_def; - std::shared_ptr scope_graph = std::make_shared(); - ASSERT_NE(scope_graph, nullptr); - ret = scope_graph->Init(); - ASSERT_EQ(ret, SUCCESS); - auto &impl = scope_graph->impl_; - impl->BuildScopeGraph(&graph_def); - const ScopeTree *scopeTree = scope_graph->GetScopeTree(); - scopeTree->GetAllScopes().front()->LastName(); - AscendString name; - ret = scopeTree->GetAllScopes().front()->LastName(name); - EXPECT_EQ(ret, ge::SUCCESS); -} - -TEST_F(UtestScopeGraph, TrimScopeIndex) { - domi::tensorflow::GraphDef graph_def; - std::shared_ptr scope_graph = std::make_shared(); - ASSERT_NE(scope_graph, nullptr); - Status ret = scope_graph->Init(); - ASSERT_EQ(ret, SUCCESS); - auto &impl = scope_graph->impl_; - impl->BuildScopeGraph(&graph_def); - const ScopeTree *scopeTree = scope_graph->GetScopeTree(); - - std::string scope_str = "scope_str_2"; - std::string retStr1 = scopeTree->GetAllScopes().front()->impl_->TrimScopeIndex(scope_str); - EXPECT_EQ(retStr1 == scope_str, false); - - scope_str = "scope_str_9223372036854775807"; - std::string retStr2 = scopeTree->GetAllScopes().front()->impl_->TrimScopeIndex(scope_str); - EXPECT_EQ(retStr2 == scope_str, true); - - scope_str = "scope_str_"; - std::string retStr3 = scopeTree->GetAllScopes().front()->impl_->TrimScopeIndex(scope_str); - EXPECT_EQ(retStr3 == scope_str, true); - - scope_str = "scope_66666"; - std::string retStr4 = scopeTree->GetAllScopes().front()->impl_->TrimScopeIndex(scope_str); - EXPECT_EQ(retStr4 == scope_str, false); -} - -TEST_F(UtestScopeGraph, ScopeImplOpTypeTest) { - int retInt; - domi::tensorflow::GraphDef graph_def; - std::shared_ptr scope_graph = std::make_shared(); - ASSERT_NE(scope_graph, nullptr); - Status ret = scope_graph->Init(); - ASSERT_EQ(ret, SUCCESS); - auto &impl = scope_graph->impl_; - impl->BuildScopeGraph(&graph_def); - const ScopeTree *scopeTree = scope_graph->GetScopeTree(); - - const std::string op_type1 = "Add"; - const std::string op_type2 = "Mul666"; - retInt = scopeTree->GetAllScopes().front()->impl_->GetOpTypeNum(op_type1); - EXPECT_EQ(retInt, -1); - retInt = scopeTree->GetAllScopes().front()->impl_->GetOpTypeNum(std::string("type1")); - EXPECT_EQ(retInt, -1); - - scopeTree->GetAllScopes().front()->impl_->OpsNumInc(op_type1); - scopeTree->GetAllScopes().front()->impl_->OpsNumInc(op_type1); - scopeTree->GetAllScopes().front()->impl_->OpsNumInc(op_type2); - retInt = scopeTree->GetAllScopes().front()->impl_->GetOpTypeNum(op_type1); - EXPECT_EQ(retInt, 2); - retInt = scopeTree->GetAllScopes().front()->impl_->GetOpTypeNum(op_type2); - EXPECT_EQ(retInt, 1); - - scopeTree->GetAllScopes().front()->impl_->ClearTypeAndSubType(); - const std::vector &sub_scopes = scopeTree->GetAllScopes().front()->impl_->GetAllSubScopes(); - for (auto &sub_scope : sub_scopes) { - std::string type = sub_scope->SubType(); - EXPECT_EQ(type == "", true); - } -} - -TEST_F(UtestScopeGraph, scopeGraphInit) { - Status ret; - domi::tensorflow::GraphDef graph_def; - std::shared_ptr scope_graph = std::make_shared(); - ASSERT_NE(scope_graph, nullptr); - ret = scope_graph->Init(); - ASSERT_EQ(ret, SUCCESS); - auto &impl = scope_graph->impl_; - impl->BuildScopeGraph(&graph_def); - const ScopeTree *scopeTree = scope_graph->GetScopeTree(); - - // init - const char_t *name = "init_name"; - const char_t *sub_type = "sub_type"; - scopeTree->GetAllScopes().front()->Init(name, sub_type, nullptr); - - // Name - AscendString a_name; - ret = scopeTree->GetAllScopes().front()->Name(a_name); - EXPECT_EQ(ret, ge::SUCCESS); - - // SubType - scopeTree->GetAllScopes().front()->SubType(); - AscendString type; - ret = scopeTree->GetAllScopes().front()->SubType(type); - EXPECT_EQ(ret, ge::SUCCESS); - - // GetScope - const Scope *scope1 = scopeTree->GetAllScopes().front()->GetSubScope(std::string("Add")); - EXPECT_EQ(scope1, nullptr); - // Used to test function overloading - const Scope *scope2 = scopeTree->GetAllScopes().front()->GetSubScope("Add"); - EXPECT_EQ(scope2, nullptr); - - const Scope *scope3 = scopeTree->GetAllScopes().front()->GetFatherScope(); - EXPECT_EQ(scope3, nullptr); - - std::vector scopes = scopeTree->GetAllScopes().front()->GetAllSubScopes(); - EXPECT_EQ(scopes.empty(), true); - - // GetNodesMap - std::unordered_map nodes_map; - scope_graph->GetNodesMap(); - ret = scope_graph->GetNodesMap(nodes_map); - EXPECT_EQ(ret, ge::SUCCESS); -} - -class UtestFusionScope : public testing::Test { - public: - domi::tensorflow::GraphDef graph_def; - std::shared_ptr scope_graph; - FusionScopesResult *fusion_rlt0; - FusionScopesResult *fusion_rlt; - - protected: - void SetUp() { - Status ret; - CreateGraphDef(graph_def); - - scope_graph = std::make_shared(); - ASSERT_NE(scope_graph, nullptr); - ret = scope_graph->Init(); - ASSERT_EQ(ret, ge::SUCCESS); - - fusion_rlt0 = new (std::nothrow) FusionScopesResult(); - fusion_rlt = new (std::nothrow) FusionScopesResult(); - ASSERT_NE(fusion_rlt, nullptr); - ret = fusion_rlt->Init(); - ASSERT_EQ(ret, ge::SUCCESS); - } - void TearDown() { - delete fusion_rlt0; - delete fusion_rlt; - } -}; - -TEST_F(UtestFusionScope, FusionScopesResultSetInfo) { - fusion_rlt->SetName(std::string("fusionRstName1")); - EXPECT_EQ(fusion_rlt->Name() == "fusionRstName1", true); - - fusion_rlt->SetName("fusionRstName2"); - const std::string fsnName1 = fusion_rlt->Name(); - AscendString fsnName2; - fusion_rlt->Name(fsnName2); - EXPECT_EQ(strncmp(fsnName2.GetString(), "fusionRstName2", strlen("fusionRstName2")), false); - - fusion_rlt->SetType(std::string("fusionRstype1")); - fusion_rlt->SetType("fusionRstype2"); - - fusion_rlt->SetDescription(std::string("fusionRstDesc1")); - fusion_rlt->SetDescription("fusionRstDesc2"); - EXPECT_EQ(fusion_rlt->impl_->Description() == "fusionRstDesc2", true); -} - -TEST_F(UtestFusionScope, FusionScopesResultInnerNodeInfo) { - fusion_rlt->SetName("fusionRstName"); - fusion_rlt->SetType("fusionRstype"); - fusion_rlt->SetDescription("fusionRstDesc"); - - std::vector index_map(6); - index_map.push_back(1); - index_map.push_back(2); - index_map.push_back(5); - index_map.push_back(6); - - fusion_rlt->InsertInputs(std::string("innerIOpName1"), index_map); - fusion_rlt->InsertInputs("innerIOpName2", index_map); - - fusion_rlt->InsertOutputs(std::string("innerOOpName1"), index_map); - fusion_rlt->InsertOutputs("innerOOpName2", index_map); - - FusionScopesResult::InnerNodeInfo *InnerNode1; - fusion_rlt0->AddInnerNode(std::string("InnerNodeName1"), std::string("InnerNodeType1")); - InnerNode1 = fusion_rlt->AddInnerNode("InnerNodeName1", "InnerNodeType1"); - (void)InnerNode1; - FusionScopesResult::InnerNodeInfo *InnerNode2; - fusion_rlt0->AddInnerNode(std::string("InnerNodeName2"), std::string("InnerNodeType2")); - InnerNode2 = fusion_rlt->AddInnerNode("InnerNodeName2", "InnerNodeType2"); - InnerNode2->InsertOutput(kOutputToFusionScope, 0); - FusionScopesResult::InnerNodeInfo *retInnerNodeInfo; - retInnerNodeInfo = fusion_rlt0->MutableRecentInnerNode(); - EXPECT_EQ(retInnerNodeInfo, nullptr); - retInnerNodeInfo = fusion_rlt->MutableRecentInnerNode(); - EXPECT_NE(retInnerNodeInfo, nullptr); - - retInnerNodeInfo = fusion_rlt0->MutableInnerNode(1); - EXPECT_EQ(retInnerNodeInfo, nullptr); - retInnerNodeInfo = fusion_rlt->MutableInnerNode(1); - EXPECT_NE(retInnerNodeInfo, nullptr); - ge::graphStatus retGraphStat; - retGraphStat = fusion_rlt0->CheckInnerNodesInfo(); - EXPECT_EQ(retGraphStat, ge::GRAPH_PARAM_INVALID); - retGraphStat = fusion_rlt->CheckInnerNodesInfo(); - EXPECT_EQ(retGraphStat, ge::GRAPH_PARAM_INVALID); - - FusionInnerNodesInfo nodes_info = fusion_rlt->impl_->GetInnerNodesInfo(); - EXPECT_EQ(nodes_info.empty(), false); -} - -TEST_F(UtestFusionScope, FusionScopesResultCheckInnerNodesInfo) { - fusion_rlt->SetName("CheckInnerNodesInfo"); - fusion_rlt->SetType("CheckInnerNodesInfo"); - fusion_rlt->SetDescription("CheckInnerNodesInfo"); - - std::vector index_map{0}; - const std::string input_name("inputop"); - const std::string output_name("outputop"); - fusion_rlt->InsertInputs(input_name, index_map); - fusion_rlt->InsertOutputs(output_name, index_map); - auto InnerNode = fusion_rlt->AddInnerNode("InnerNodeName", "InnerNodeType"); - InnerNode->InsertInput(kInputFromFusionScope, 0); - InnerNode->InsertOutput(kOutputToFusionScope, 0); - // check - EXPECT_EQ(fusion_rlt->CheckInnerNodesInfo(), ge::GRAPH_SUCCESS); - - FusionInnerNodesInfo nodes_info = fusion_rlt->impl_->GetInnerNodesInfo(); - //check - EXPECT_EQ(nodes_info.size(), 1U); - - const auto name = std::get<0>(nodes_info[0U]); - const auto type = std::get<1>(nodes_info[0U]); - const auto inputs = std::get<2>(nodes_info[0U]); - const auto outputs = std::get<3>(nodes_info[0U]); - const auto op = std::get<4>(nodes_info[0U]); - - EXPECT_EQ(name, "CheckInnerNodesInfo/InnerNodeName"); - EXPECT_EQ(type, "InnerNodeType"); - EXPECT_EQ(inputs.size(), 1U); - EXPECT_EQ(inputs[0U].first, kInputFromFusionScope); - EXPECT_EQ(inputs[0U].second, 0); - EXPECT_EQ(outputs.size(), 1U); - EXPECT_EQ(outputs[0U].first, kOutputToFusionScope); - EXPECT_EQ(outputs[0U].second, 0); - EXPECT_NE(op, nullptr); -} - -TEST_F(UtestFusionScope, InnerNodeInit) { - FusionScopesResult::InnerNodeInfo InnerNode1(std::string("FusionNode1")); - InnerNode1.SetName(std::string("InnerAdd")); - InnerNode1.SetType(std::string("Add")); - InnerNode1.InsertInput(std::string("Input1"), 1); - InnerNode1.InsertOutput(std::string("Output1"), 11); - - FusionScopesResult::InnerNodeInfo InnerNode2("FusionNode2"); - InnerNode2.SetName("InnerSub"); - InnerNode2.SetType("Sub"); - InnerNode2.InsertInput("Input2", 2); - InnerNode2.InsertOutput("Output2", 22); - - FusionScopesResult::InnerNodeInfo InnerNode3("FusionNodeName3", "NodeName3", "NodeType3"); - - EXPECT_NE(InnerNode1.BuildInnerNode(), ge::GRAPH_PARAM_INVALID); - EXPECT_NE(InnerNode1.MutableOperator(), nullptr); - - std::string InnerNodeInfoStr; - AscendString InnerNodeInfoAscendStr; - //name - InnerNode1.GetName(); - graphStatus status = InnerNode1.GetName(InnerNodeInfoAscendStr); - ASSERT_EQ(status, GRAPH_SUCCESS); - // type - InnerNode1.GetType(); - status = InnerNode1.GetType(InnerNodeInfoAscendStr); - ASSERT_EQ(status, GRAPH_SUCCESS); - - std::vector> pairList; - std::vector> pairAscendList; - pairList = InnerNode1.GetInputs(); - status = InnerNode1.GetInputs(pairAscendList); - ASSERT_EQ(status, GRAPH_SUCCESS); - - pairList = InnerNode1.GetOutputs(); - status = InnerNode1.GetOutputs(pairAscendList); - ASSERT_EQ(status, GRAPH_SUCCESS); -} - -TEST_F(UtestFusionScope, InnerNodeSetIOFormat) { - graphStatus retGraphStat; - - FusionScopesResult::InnerNodeInfo InnerNode(std::string("FusionNode")); - InnerNode.SetName(std::string("InnerAdd")); - InnerNode.SetType(std::string("Add")); - InnerNode.InsertInput(std::string("Input"), 1); - EXPECT_NE(InnerNode.BuildInnerNode(), ge::GRAPH_PARAM_INVALID); - EXPECT_NE(InnerNode.MutableOperator(), nullptr); - // Used to test function overloading - retGraphStat = InnerNode.SetInputFormat(std::string("InputName1"), std::string("InputFormat1")); - EXPECT_NE(retGraphStat, ge::GRAPH_PARAM_INVALID); - retGraphStat = InnerNode.SetInputFormat("InputName2", "InputFormat2"); - EXPECT_NE(retGraphStat, ge::GRAPH_PARAM_INVALID); - - retGraphStat = InnerNode.SetOutputFormat(std::string("OutputName1"), std::string("OutputFormat1")); - EXPECT_NE(retGraphStat, ge::GRAPH_PARAM_INVALID); - retGraphStat = InnerNode.SetOutputFormat("OutputName2", "OutputFormat2"); - EXPECT_NE(retGraphStat, ge::GRAPH_PARAM_INVALID); - - retGraphStat = - InnerNode.SetDynamicInputFormat(std::string("DynamicInputName1"), 0, std::string("DynamicInputFormat1")); - EXPECT_NE(retGraphStat, ge::GRAPH_PARAM_INVALID); - retGraphStat = InnerNode.SetDynamicInputFormat("DynamicInputName2", 1, "DynamicInputFormat2"); - EXPECT_NE(retGraphStat, ge::GRAPH_PARAM_INVALID); - - retGraphStat = - InnerNode.SetDynamicOutputFormat(std::string("DynamicOutputName1"), 0, std::string("DynamicOutputFormat1")); - EXPECT_NE(retGraphStat, ge::GRAPH_PARAM_INVALID); - retGraphStat = InnerNode.SetDynamicOutputFormat("DynamicOutputName2", 1, "DynamicOutputFormat2"); - EXPECT_NE(retGraphStat, ge::GRAPH_PARAM_INVALID); -} diff --git a/tests/ut/register/testcase/register_scope_pass_registry_unittest.cc b/tests/ut/register/testcase/register_scope_pass_registry_unittest.cc deleted file mode 100644 index d606decec1348de12410d524d55a476f1b92e4ca..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/register_scope_pass_registry_unittest.cc +++ /dev/null @@ -1,116 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include - -#include "external/register/scope/scope_fusion_pass_register.h" -#include "register/scope/scope_graph_impl.h" -#include "register/scope/scope_pass_impl.h" -#include "register/scope/scope_pass_registry_impl.h" - -using namespace ge; -namespace { -class ScopeBasePassChild : public ScopeBasePass { -protected: - std::vector DefinePatterns() { - std::vector ret; - return ret; - } - std::string PassName() { - return std::string("passName"); - } - Status LastMatchScopesAndOPs(std::shared_ptr &scope_graph, - std::vector &results) { - return SUCCESS; - } - void GenerateFusionResult(const std::vector &scopes, - FusionScopesResult *fusion_rlt) { - return; - } -}; -} - -class UtestScopePassRegistry : public testing::Test { -public: - std::unique_ptr scoPassPtr1; - std::unique_ptr scoPassPtr2; - -protected: - void SetUp() { - const char *regName1 = "regScoFusnPass1"; - REGISTER_SCOPE_FUSION_PASS(regName1, ScopeBasePassChild, true); - scoPassPtr1 = - ScopeFusionPassRegistry::GetInstance().impl_->CreateScopeFusionPass(std::string(regName1)); - - const char *regName2 = "regScoFusnPass2"; - REGISTER_SCOPE_FUSION_PASS(regName2, ScopeBasePassChild, true); - scoPassPtr2 = - ScopeFusionPassRegistry::GetInstance().impl_->CreateScopeFusionPass(std::string(regName2)); - } - void TearDown() {} -}; - -TEST_F(UtestScopePassRegistry, ScopeFusionPassRegistryRecreat) { - std::shared_ptr scope_graph = std::make_shared(); - scope_graph->Init(); - std::vector results; - EXPECT_EQ(scoPassPtr2->LastMatchScopesAndOPs(scope_graph, results), SUCCESS); - - REGISTER_SCOPE_FUSION_PASS("regScoFusnPass2", ScopeBasePassChild, true); - EXPECT_EQ(scoPassPtr2->LastMatchScopesAndOPs(scope_graph, results), SUCCESS); -} - -TEST_F(UtestScopePassRegistry, ScopeFusionPassRegistryRegister) { - std::shared_ptr scope_graph = std::make_shared(); - scope_graph->Init(); - std::vector results; - EXPECT_EQ(scoPassPtr1->LastMatchScopesAndOPs(scope_graph, results), SUCCESS); - EXPECT_EQ(scoPassPtr2->LastMatchScopesAndOPs(scope_graph, results), SUCCESS); - - std::vector nameList = ScopeFusionPassRegistry::GetInstance().impl_->GetAllRegisteredPasses(); - EXPECT_EQ(nameList[0]=="regScoFusnPass1", true); - EXPECT_EQ(nameList[1]=="regScoFusnPass2", true); -} - -TEST_F(UtestScopePassRegistry, SetPassEnableFlag) { - bool retBool; - std::unique_ptr scoPassPtr; - const char *regName = "regScoFusnPass"; - REGISTER_SCOPE_FUSION_PASS(nullptr, ScopeBasePassChild, true); // test for fail - REGISTER_SCOPE_FUSION_PASS(regName, ScopeBasePassChild, true); - scoPassPtr = ScopeFusionPassRegistry::GetInstance().impl_->CreateScopeFusionPass(std::string(regName)); - - retBool = ScopeFusionPassRegistry::GetInstance().impl_->SetPassEnableFlag(std::string(regName), false); - EXPECT_EQ(retBool, true); - retBool = ScopeFusionPassRegistry::GetInstance().impl_->SetPassEnableFlag(std::string("test"), false); - EXPECT_EQ(retBool, false); -} - -TEST_F(UtestScopePassRegistry, GetCreateFnWithDisableFlag) { - bool retBool; - const char *regName = "regScoFusnPass1"; - - std::shared_ptr scope_graph = std::make_shared(); - scope_graph->Init(); - std::vector results; - - std::unique_ptr scoPassPtr1 = - ScopeFusionPassRegistry::GetInstance().impl_->CreateScopeFusionPass(std::string(regName)); - EXPECT_EQ(scoPassPtr1->LastMatchScopesAndOPs(scope_graph, results), SUCCESS); - - retBool = ScopeFusionPassRegistry::GetInstance().impl_->SetPassEnableFlag(std::string(regName), false); - ASSERT_EQ(retBool, true); - - std::unique_ptr scoPassPtr2 = - ScopeFusionPassRegistry::GetInstance().impl_->CreateScopeFusionPass(std::string(regName)); - EXPECT_EQ(scoPassPtr2, nullptr); -} diff --git a/tests/ut/register/testcase/register_scope_pass_unittest.cc b/tests/ut/register/testcase/register_scope_pass_unittest.cc deleted file mode 100644 index 51e45eb31fc53dda52948c24fe651e1f418e2ac5..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/register_scope_pass_unittest.cc +++ /dev/null @@ -1,496 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include -#include -#include "graph/debug/ge_util.h" - -#include "external/register/scope/scope_fusion_pass_register.h" -#include "register/scope/scope_graph_impl.h" -#include "register/scope/scope_pass_impl.h" -#include "register/scope/scope_pass_registry_impl.h" - -using namespace ge; -class UtestScopePass : public testing::Test { -protected: - void SetUp() {} - void TearDown() {} -}; - -TEST_F(UtestScopePass, ScopesResultImplFail) { - ScopesResult scopeRstOri; - scopeRstOri.impl_.reset(); - - std::vector nodes; - scopeRstOri.SetNodes(nodes); - - std::vector scopes; - scopeRstOri.SetScopes(scopes); - - ScopesResult scopeRst1(scopeRstOri); - EXPECT_EQ(scopeRst1.impl_->GetScopes().empty(), true); - - ScopesResult scopeRst2; - scopeRst2 = scopeRst2; - EXPECT_EQ(scopeRst2.impl_->GetScopes().empty(), true); - - ScopesResult scopeRst3; - scopeRst3 = scopeRstOri; - EXPECT_EQ(scopeRst3.impl_->GetScopes().empty(), true); -} - -TEST_F(UtestScopePass, ScopesResultRegister) { - ScopesResult scopeRstOri; - std::vector scopes; - std::vector nodes; - // test for initialized - std::vector ScopeList = scopeRstOri.impl_->GetScopes(); - EXPECT_EQ(ScopeList.empty(), true); - std::vector NodeList = scopeRstOri.impl_->GetNodes(); - EXPECT_EQ(NodeList.empty(), true); - - // add scope - Scope scope1; - scope1.Init("scope1", "type1"); - scopes.push_back(&scope1); - Scope scope2; - scope2.Init("scope2", "type2"); - scopes.push_back(&scope2); - scopeRstOri.SetScopes(scopes); - // add node - OperatorPtr node1(new (std::nothrow) ge::Operator("add", "Add")); - nodes.push_back(node1); - OperatorPtr node2(new (std::nothrow) ge::Operator("sub", "Sub")); - nodes.push_back(node2); - OperatorPtr node3(new (std::nothrow) ge::Operator("mul", "Mul")); - nodes.push_back(node3); - scopeRstOri.SetNodes(nodes); - - ScopesResult scopeRst1(scopeRstOri); - ScopeList = scopeRst1.impl_->GetScopes(); - EXPECT_EQ(ScopeList.empty(), false); - NodeList = scopeRst1.impl_->GetNodes(); - EXPECT_EQ(NodeList.empty(), false); - - ScopesResult scopeRst2; - scopeRst2 = scopeRst1; - ScopeList = scopeRst2.impl_->GetScopes(); - EXPECT_EQ(ScopeList.empty(), false); - NodeList = scopeRst2.impl_->GetNodes(); - EXPECT_EQ(NodeList.empty(), false); -} - -namespace { -class ScopePass1 : public ScopeBasePass { -public: - ScopePattern *scoPattern1; - ScopePattern *scoPattern2; - -protected: - std::vector DefinePatterns() { - std::vector>> scoPattern; - std::vector> scoPatternSub; - std::vector scoPatternSubSub1; - std::vector scoPatternSubSub2; - - scoPattern1 = new ScopePattern(); - scoPattern2 = new ScopePattern(); - scoPatternSubSub1.push_back(scoPattern1); - scoPatternSubSub2.push_back(scoPattern2); - - scoPatternSub.push_back(scoPatternSubSub1); - scoPatternSub.push_back(scoPatternSubSub2); - - scoPattern.push_back(scoPatternSub); - return scoPattern; - } - std::string PassName() { - return std::string("passName1"); - } - Status LastMatchScopesAndOPs(std::shared_ptr &scope_graph, - std::vector &results) { - return SUCCESS; - } - void GenerateFusionResult(const std::vector &scopes, - FusionScopesResult *fusion_rlt) { - return; - } -}; - -class ScopePass2 : public ScopeBasePass { -protected: - std::vector DefinePatterns() { - std::vector>> scoPattern; - return scoPattern; - } - std::string PassName() { - return std::string("passName2"); - } - Status LastMatchScopesAndOPs(std::shared_ptr &scope_graph, - std::vector &results) { - return FAILED; - } - void GenerateFusionResult(const std::vector &scopes, - FusionScopesResult *fusion_rlt) { - return; - } -}; - -class ScopePass3 : public ScopeBasePass { -public: - ScopePattern *scoPattern1; - ScopePattern *scoPattern2; - -protected: - std::vector DefinePatterns() { - std::vector>> scoPattern; - std::vector> scoPatternSub; - std::vector scoPatternSubSub1; - std::vector scoPatternSubSub2; - - scoPattern1 = new ScopePattern(); - scoPattern2 = new ScopePattern(); - scoPatternSubSub1.push_back(scoPattern1); - scoPatternSubSub2.push_back(scoPattern2); - - scoPatternSub.push_back(scoPatternSubSub1); - scoPatternSub.push_back(scoPatternSubSub2); - - scoPattern.push_back(scoPatternSub); - return scoPattern; - } - std::string PassName() { - return std::string("passName1"); - } - Status LastMatchScopesAndOPs(std::shared_ptr &scope_graph, - std::vector &results) { - return FAILED; - } - void GenerateFusionResult(const std::vector &scopes, - FusionScopesResult *fusion_rlt) { - return; - } -}; - -class ScopePass4 : public ScopeBasePass { -public: - ScopePattern *scoPattern1; - ScopePattern *scoPattern2; - -protected: - std::vector DefinePatterns() { - std::vector>> scoPattern; - std::vector> scoPatternSub; - std::vector scoPatternSubSub1; - std::vector scoPatternSubSub2; - - scoPattern1 = new ScopePattern(); - scoPattern2 = new ScopePattern(); - scoPatternSubSub1.push_back(scoPattern1); - scoPatternSubSub2.push_back(scoPattern2); - - scoPatternSub.push_back(scoPatternSubSub1); - scoPatternSub.push_back(scoPatternSubSub2); - - scoPattern.push_back(scoPatternSub); - return scoPattern; - } - std::string PassName() { - return std::string("passName1"); - } - Status LastMatchScopesAndOPs(std::shared_ptr &scope_graph, - std::vector &results) { - ScopesResult Scope1; - results.push_back(Scope1); - return SUCCESS; - } - void GenerateFusionResult(const std::vector &scopes, - FusionScopesResult *fusion_rlt) { - fusion_rlt->impl_->SetType(kScopeInvalidType); - return; - } -}; - - -void CreateGraph(domi::tensorflow::GraphDef &graph_def) { - // 1. add node - auto placeholder0 = graph_def.add_node(); - auto placeholder1 = graph_def.add_node(); - auto add0 = graph_def.add_node(); - auto add1 = graph_def.add_node(); - auto mul0 = graph_def.add_node(); - auto mul1 = graph_def.add_node(); - auto mul2 = graph_def.add_node(); - auto add2 = graph_def.add_node(); - auto retval0 = graph_def.add_node(); - auto retval1 = graph_def.add_node(); - auto retval2 = graph_def.add_node(); - - // 2. set info - placeholder0->set_name("placeholder0"); - placeholder0->set_op("PlaceHolder"); - placeholder1->set_name("placeholder1"); - placeholder1->set_op("PlaceHolder"); - - add0->set_name("add0"); - add0->set_op("Add"); - add1->set_name("add1"); - add1->set_op("Add"); - add2->set_name("add2"); - add2->set_op("Add"); - - mul0->set_name("mul0"); - mul0->set_op("Mul"); - mul1->set_name("mul1"); - mul1->set_op("Mul"); - mul2->set_name("mul1/mul2"); - mul2->set_op("Mul"); - - retval0->set_name("retval0"); - retval0->set_op("_RetVal"); - retval1->set_name("retval1"); - retval1->set_op("_RetVal"); - retval2->set_name("retval2"); - retval2->set_op("_RetVal"); - - // 3. add edges - add0->add_input("placeholder0"); - add0->add_input("placeholder1"); - - mul0->add_input("placeholder0"); - mul0->add_input("placeholder1"); - - mul1->add_input("placeholder0"); - mul1->add_input("add0"); - mul1->add_input("^mul0"); - - mul2->add_input("mul0"); - mul2->add_input("add0"); - - add1->add_input("mul0"); - add1->add_input("placeholder1"); - - add2->add_input("mul1"); - add2->add_input("mul0"); - - retval0->add_input("add2:0"); - retval1->add_input("add1:0"); - retval2->add_input("mul2:0"); -} -} - -TEST_F(UtestScopePass, ScopePassRun1) { - // no scope match - Status retStatus; - domi::tensorflow::GraphDef graph_def; - CreateGraph(graph_def); - - std::shared_ptr scope_graph = std::make_shared(); - ASSERT_NE(scope_graph, nullptr); - retStatus = scope_graph->Init(); - ASSERT_EQ(retStatus, SUCCESS); - auto &impl = scope_graph->impl_; - impl->BuildScopeGraph(&graph_def); - - ScopePass1 scoBasePass; - retStatus = scoBasePass.impl_->Run(scope_graph); - EXPECT_EQ(retStatus, SUCCESS); -} - -TEST_F(UtestScopePass, ScopePassRun4) { - // no scope match - Status retStatus; - domi::tensorflow::GraphDef graph_def; - CreateGraph(graph_def); - - std::shared_ptr scope_graph = std::make_shared(); - ASSERT_NE(scope_graph, nullptr); - retStatus = scope_graph->Init(); - ASSERT_EQ(retStatus, SUCCESS); - auto &impl = scope_graph->impl_; - impl->BuildScopeGraph(&graph_def); - - ScopePass4 scoBasePass; - retStatus = scoBasePass.impl_->Run(scope_graph); - EXPECT_EQ(retStatus, domi::SCOPE_NOT_CHANGED); -} - -TEST_F(UtestScopePass, ScopePassRun2) { - // MatchAllBatches failed - Status retStatus; - domi::tensorflow::GraphDef graph_def; - CreateGraph(graph_def); - - std::shared_ptr scope_graph = std::make_shared(); - ASSERT_NE(scope_graph, nullptr); - retStatus = scope_graph->Init(); - ASSERT_EQ(retStatus, SUCCESS); - auto &impl = scope_graph->impl_; - impl->BuildScopeGraph(&graph_def); - - ScopePass2 scoBasePass; - retStatus = scoBasePass.impl_->Run(scope_graph); - EXPECT_EQ(retStatus, domi::SCOPE_NOT_CHANGED); -} - -TEST_F(UtestScopePass, ScopePassRun3) { - // LastMatchScopesAndOPs failed - Status retStatus; - domi::tensorflow::GraphDef graph_def; - CreateGraph(graph_def); - - std::shared_ptr scope_graph = std::make_shared(); - ASSERT_NE(scope_graph, nullptr); - retStatus = scope_graph->Init(); - ASSERT_EQ(retStatus, SUCCESS); - auto &impl = scope_graph->impl_; - impl->BuildScopeGraph(&graph_def); - - ScopePass3 scoBasePass; - retStatus = scoBasePass.impl_->Run(scope_graph); - EXPECT_EQ(retStatus, domi::SCOPE_NOT_CHANGED); -} - -TEST_F(UtestScopePass, AddFusionScopesResultToScopeGraph1) { - Status retStatus; - std::vector scope_results; - domi::tensorflow::GraphDef graph_def; - CreateGraph(graph_def); - - std::shared_ptr scope_graph = std::make_shared(); - ASSERT_NE(scope_graph, nullptr); - retStatus = scope_graph->Init(); - ASSERT_EQ(retStatus, SUCCESS); - auto &impl = scope_graph->impl_; - impl->BuildScopeGraph(&graph_def); - - ScopePass1 scoBasePass; - ScopesResult scopeRst; - std::vector scopes; - std::vector nodes; - // add scope - Scope scope1; - scope1.Init("scope1", "type1"); - scopes.push_back(&scope1); - Scope scope2; - scope2.Init("scope2", "type2"); - scopes.push_back(&scope2); - scopeRst.SetScopes(scopes); - // add node - OperatorPtr node1(new (std::nothrow) ge::Operator("add", "Add")); - nodes.push_back(node1); - OperatorPtr node2(new (std::nothrow) ge::Operator("sub", "Sub")); - nodes.push_back(node2); - OperatorPtr node3(new (std::nothrow) ge::Operator("mul", "Mul")); - nodes.push_back(node3); - scopeRst.SetNodes(nodes); - // add scope result - scope_results.push_back(scopeRst); - retStatus = scoBasePass.impl_->AddFusionScopesResultToScopeGraph(scope_graph, scope_results); - EXPECT_EQ(retStatus, SUCCESS); -} - -TEST_F(UtestScopePass, ScopePassWithWrongInput) { - ScopePass1 scoBasePass; - const std::vector patternlist; - std::vector results; - bool retBool; - Status retStatus; - - retBool = scoBasePass.impl_->MatchOneBatch(nullptr, patternlist, results); - EXPECT_EQ(retBool, false); - - retBool = scoBasePass.impl_->MatchOneScope(nullptr, nullptr, results); - EXPECT_EQ(retBool, false); - - retBool = scoBasePass.impl_->MatchAllBatches(nullptr, results); - EXPECT_EQ(retBool, false); - - std::shared_ptr scope_graph; - scope_graph.reset(); - retStatus = scoBasePass.impl_->PrintFusionScopeInfo(scope_graph); - EXPECT_EQ(retStatus, PARAM_INVALID); -} - -TEST_F(UtestScopePass, MatchOneScope) { - bool retBool; - Status retStatus; - domi::tensorflow::GraphDef graph_def; - CreateGraph(graph_def); - - std::shared_ptr scope_graph = std::make_shared(); - ASSERT_NE(scope_graph, nullptr); - retStatus = scope_graph->Init(); - ASSERT_EQ(retStatus, SUCCESS); - auto &impl = scope_graph->impl_; - impl->BuildScopeGraph(&graph_def); - - const ScopeTree *scopeTree = scope_graph->GetScopeTree(); - std::vector scopes = scopeTree->impl_->scopes_; - - ScopePattern scoPattern; - NodeOpTypeFeature feature("nodeType", 0); - scoPattern.AddNodeOpTypeFeature(feature); - - std::vector results; - ScopePass1 scoBasePass; - for (auto scope : scopes) - { - retBool = scoBasePass.impl_->MatchOneScope(&scoPattern, scope, results); - EXPECT_EQ(retBool, false); - } -} - -TEST_F(UtestScopePass, PrintFusionScopeInfo) { - Status retStatus; - domi::tensorflow::GraphDef graph_def; - CreateGraph(graph_def); - - std::shared_ptr scope_graph = std::make_shared(); - ASSERT_NE(scope_graph, nullptr); - retStatus = scope_graph->Init(); - ASSERT_EQ(retStatus, SUCCESS); - auto &impl = scope_graph->impl_; - impl->BuildScopeGraph(&graph_def); - - FusionScopesResult *fusionResult = new (std::nothrow) FusionScopesResult(); - ASSERT_NE(fusionResult, nullptr); - retStatus = fusionResult->Init(); - EXPECT_EQ(retStatus, SUCCESS); - // init - fusionResult->SetName("fusionRstName"); - fusionResult->SetType("fusionRstype"); - fusionResult->SetDescription("fusionRstDesc"); - // add nodes for check fusionOp - std::vector fusionRstNodes; - OperatorPtr op1(new (std::nothrow) ge::Operator("Sub", "Sub")); - fusionRstNodes.push_back(op1); - OperatorPtr op2(new (std::nothrow) ge::Operator("Mul", "Mul")); - fusionRstNodes.push_back(op2); - fusionResult->impl_->AddNodes(fusionRstNodes); - // insert inputs outputs - std::vector index_map = {1, 2}; - fusionResult->InsertInputs("Sub", index_map); - fusionResult->InsertOutputs("Mul", index_map); - // add scopes - Scope scope0; - scope0.Init("scope0", "type0"); - Scope scope1; - scope1.Init("scope1", "type1"); - std::vector scopes = {&scope0, &scope1}; - fusionResult->impl_->AddScopes(scopes); - - impl->AddFusionScopesResult(fusionResult); - - ScopePass1 scoBasePass; - retStatus = scoBasePass.impl_->PrintFusionScopeInfo(scope_graph); - EXPECT_EQ(retStatus, SUCCESS); -} diff --git a/tests/ut/register/testcase/register_tuningtiling_unittest.cc b/tests/ut/register/testcase/register_tuningtiling_unittest.cc deleted file mode 100644 index 86f4a9cb4787c1943938c7365a8b8787c4a798bc..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/register_tuningtiling_unittest.cc +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include "nlohmann/json.hpp" -#include "register/tuning_tiling_registry.h" - -namespace tuningtiling { -BEGIN_TUNING_TILING_DEF(TestMatmul) -TUNING_TILING_DATA_FIELD_DEF(uint32_t, batchdim); -END_TUNING_TILING_DEF - -DECLARE_SCHEMA(TestMatmul, FIELD(TestMatmul, batchdim)); - -BEGIN_TUNING_TILING_DEF(TestDynamic) -TUNING_TILING_DATA_FIELD_DEF(uint32_t, scheduleId); -TUNING_TILING_DATA_FIELD_DEF(TestMatmul, mmtiling); -END_TUNING_TILING_DEF - -DECLARE_SCHEMA(TestDynamic, FIELD(TestDynamic, scheduleId), FIELD(TestDynamic, mmtiling)); - -REGISTER_TUNING_TILING_CLASS(DynamicRnn, TestDynamic); - -class RegisterTuningTilingUT : public testing::Test { - protected: - void SetUp() {} - - void TearDown() {} -}; - -TEST_F(RegisterTuningTilingUT, from_json_ut) { - TestDynamic testdyn; - TestMatmul mm; - mm.batchdim = 1; - testdyn.scheduleId = 10; - testdyn.mmtiling = mm; - nlohmann::json jsonval; - testdyn.ToJson(jsonval); - std::cout << "ori json:" << jsonval.dump() << std::endl; - TuningTilingDefPtr tuingdef = TuningTilingClassFactory::CreateTilingDataInstance(ge::AscendString("unknow")); - EXPECT_EQ(tuingdef == nullptr, true); - tuingdef = TuningTilingClassFactory::CreateTilingDataInstance(ge::AscendString("DynamicRnn")); - EXPECT_EQ(tuingdef != nullptr, true); - auto struct_name = tuingdef->GetClassName(); - EXPECT_EQ(strcmp(struct_name.GetString(), "TestDynamic"), 0); - auto fields = tuingdef->GetItemInfo(); - EXPECT_EQ(fields.size(), 2); - std::cout << struct_name.GetString() << std::endl; - tuingdef->FromJson(jsonval); - nlohmann::json res; - tuingdef->ToJson(res); - EXPECT_EQ(res, jsonval); - std::cout << "expected json:" << res.dump() << std::endl; -} -} // namespace tuningtiling diff --git a/tests/ut/register/testcase/register_unittest.cc b/tests/ut/register/testcase/register_unittest.cc deleted file mode 100644 index 96bc052ef04d9f41086e2966becbb78375b850b3..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/register_unittest.cc +++ /dev/null @@ -1,2587 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "external/register/scope/scope_fusion_pass_register.h" -#include "register/scope/scope_graph_impl.h" -#include "external/graph/operator.h" -#include "graph/utils/node_utils.h" -#include "graph/utils/graph_utils_ex.h" -#include "graph/normal_graph/node_impl.h" -#include "graph/normal_graph/op_desc_impl.h" -#include "graph_builder_utils.h" -#include "graph/debug/ge_op_types.h" -#include "graph/normal_graph/compute_graph_impl.h" -#include "external/register/register.h" -#include "base/registry/op_impl_space_registry_v2.h" -#include "base/registry/op_impl_space_registry_v2_impl.h" -#include "register/op_impl_space_registry.h" -#include "common/util/error_manager/error_manager.h" - -#include -#include -#include -#include - -#include "common/ge_common/debug/ge_log.h" -#include "register/op_registry.h" -#include "op_tiling/op_tiling_utils.h" -#include "register/op_tiling_registry.h" -#include "op_tiling/op_tiling_utils.h" -#include "op_tiling/op_tiling_constants.h" -#include "register/op_compile_info_base.h" -#include "op_tiling.h" -#include "base/registry/op_impl_space_registry_v2_impl.h" -#include "register/op_check_register.h" -#include "external/register/op_check.h" -#include "register/tilingdata_base.h" - -#include "graph/graph.h" -#include "graph/utils/attr_utils.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/graph_utils.h" -#include "graph/utils/type_utils.h" -#include "graph/attr_value.h" - -#include "graph/debug/ge_util.h" -#include "graph/debug/ge_log.h" -#include "graph/debug/ge_attr_define.h" - -#include "proto/tensorflow/attr_value.pb.h" -#include "proto/tensorflow/node_def.pb.h" -#include "exe_graph/lowering/kernel_run_context_builder.h" -#include "external/register/op_impl_registry.h" -#include "exe_graph/runtime/continuous_vector.h" -#include "common/util/tiling_utils.h" -#include "inc/external/hcom/hcom_topo_info.h" -#include - -using namespace domi; -using namespace ge; -using namespace optiling; -namespace ge { -void to_json(nlohmann::json &j, const HcomTopoInfo::TopoLevelDesc &desc); -void from_json(const nlohmann::json &j, HcomTopoInfo::TopoLevelDesc &desc); -void to_json(nlohmann::json &j, const HcomTopoInfo::TopoInfo &info); -void from_json(const nlohmann::json &j, HcomTopoInfo::TopoInfo &info); -} -namespace { -REG_OP(AddUt) - .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) - .OP_END_FACTORY_REG(AddUt); -// infer from output -REG_OP(FixIOOp_OutputIsFix) - .INPUT(fix_input1, "T") - .INPUT(fix_input2, "T") - .OUTPUT(fix_output, "T2") - .DATATYPE(T2, TensorType({DT_BOOL})) - .OP_END_FACTORY_REG(FixIOOp_OutputIsFix); -} -class CompileInfoJson : public CompileInfoBase { - public: - CompileInfoJson(const std::string &json) : json_str_(json) {} - ~CompileInfoJson() {} - - private: - std::string json_str_; -}; - -namespace { -struct StubCompileInfo : public CompileInfoBase { - int64_t stub_ = 2; -}; - -void *CreateCompileInfo() { - return new StubCompileInfo(); -} - -void DeleteCompileInfo(void *compile_info) { - delete reinterpret_cast(compile_info); -} - -UINT32 OpTilingStubNew(gert::TilingContext *kernel_context) { - auto tensor_without_data = kernel_context->GetInputTensor(1); - EXPECT_EQ(tensor_without_data->GetAddr(), nullptr); - EXPECT_EQ(tensor_without_data->GetStorageShape(), gert::Shape({5, 5, 5, 5})); - EXPECT_EQ(tensor_without_data->GetOriginShape(), gert::Shape({5, 5, 5, 5})); - auto tensor = kernel_context->GetInputTensor(0); - EXPECT_EQ(tensor->GetShape().GetStorageShape().GetDimNum(), 4); - gert::Shape expect_shape({4, 4, 4, 4}); - EXPECT_EQ(tensor->GetShape().GetStorageShape(), expect_shape); - EXPECT_EQ(tensor->GetDataType(), DT_INT8); - EXPECT_EQ((tensor->GetData())[3], 4); - EXPECT_EQ((tensor->GetData())[2], 3); - EXPECT_EQ((tensor->GetData())[1], 2); - EXPECT_EQ((tensor->GetData())[0], 1); - EXPECT_EQ(tensor->GetFormat().GetStorageFormat(), FORMAT_ND); - gert::Shape expect_shape2({9, 9, 9, 9}); - EXPECT_TRUE(kernel_context->GetOutputShape(0)->GetStorageShape() == expect_shape2); - auto shape = kernel_context->GetInputShape(1); - EXPECT_TRUE(*shape == gert::StorageShape({5, 5, 5, 5}, {5, 5, 5, 5})); - auto ci = kernel_context->GetCompileInfo(); - EXPECT_EQ(reinterpret_cast(ci)->stub_, 1); - - EXPECT_EQ(kernel_context->GetAttrs()->GetAttrNum(), 4); - std::vector expect_attr = {1, 2, 3, 4}; - for (size_t i = 0UL; i < 4UL; ++i) { - EXPECT_EQ(reinterpret_cast( - kernel_context->GetAttrs()->GetAttrPointer(0)->GetData())[i], - expect_attr[i]); - } - EXPECT_EQ(*kernel_context->GetAttrs()->GetAttrPointer(1), 99); - kernel_context->SetBlockDim(2); - kernel_context->SetAicpuBlockDim(4); - kernel_context->SetNeedAtomic(true); - kernel_context->SetTilingKey(78); - *kernel_context->GetWorkspaceSizes(1) = 12; - kernel_context->GetRawTilingData()->Append(6); - kernel_context->GetRawTilingData()->Append(7); - kernel_context->GetRawTilingData()->Append(8); - kernel_context->GetRawTilingData()->Append(9); - kernel_context->GetRawTilingData()->Append(10); - return ge::GRAPH_SUCCESS; -} - -UINT32 OpTilingParseStubNew(gert::KernelContext *kernel_context) { - auto ci = kernel_context->GetOutputPointer(0); - ci->stub_ = 1; - return ge::GRAPH_SUCCESS; -} - -UINT32 OpTilingStubNewWithName(gert::TilingContext *kernel_context) { - std::string node_name = kernel_context->GetNodeName(); - EXPECT_EQ(node_name, "test"); - return ge::GRAPH_SUCCESS; -} - -UINT32 OpTilingStubV5(gert::TilingContext *kernel_context) { - auto tensor = kernel_context->GetInputTensor(0); - std::vector real_data = {1.1, 2.1, 3.1, 4.1}; - for (size_t i = 0UL; i < 4UL; ++i) { - EXPECT_EQ((tensor->GetData())[i], optiling::Float32ToFloat16(real_data[i])); - } - return ge::GRAPH_SUCCESS; -} - -UINT32 OpTilingStubBf16(gert::TilingContext *kernel_context) { - auto tensor = kernel_context->GetInputTensor(0); - std::vector real_data = {1.1, 2.1, 3.1, 4.1}; - for (size_t i = 0UL; i < 4UL; ++i) { - EXPECT_EQ((tensor->GetData())[i], optiling::Float32ToBfloat16(real_data[i])); - } - (void)kernel_context->SetNeedAtomic(true); - return ge::GRAPH_SUCCESS; -} - -UINT32 OpTilingStubNewWithDynamicInput(gert::TilingContext *kernel_context) { - auto shape = kernel_context->GetDynamicInputShape(0, 0); - EXPECT_EQ(*shape, gert::StorageShape( {4, 256, 200, 336}, {4, 16, 200, 336, 16})); - auto shape0_1 = kernel_context->GetDynamicInputShape(0, 1); - EXPECT_EQ(*shape0_1, gert::StorageShape( {4, 256, 100, 168}, {4, 16, 100, 168, 16})); - auto shape_optional_1 = kernel_context->GetOptionalInputShape(1); - EXPECT_EQ(shape_optional_1, nullptr); - auto shape_optional_2 = kernel_context->GetOptionalInputShape(2); - EXPECT_EQ(shape_optional_2, nullptr); - auto shape_2 = kernel_context->GetDynamicInputShape(3, 0); - EXPECT_EQ(*shape_2, gert::StorageShape({100, 5}, {100, 5})); - auto input_1 = kernel_context->GetInputShape(4); - EXPECT_EQ(*input_1, gert::StorageShape({100, 5}, {100, 5})); - auto output = kernel_context->GetOutputShape(0); - EXPECT_EQ(*output, gert::StorageShape({9, 9, 9, 9}, {9, 9, 9, 9})); - EXPECT_EQ(kernel_context->GetComputeNodeInputNum(), 5); - EXPECT_EQ(kernel_context->GetComputeNodeOutputNum(), 1); - return ge::GRAPH_SUCCESS; -} - -UINT32 DefaultOptilingStub(gert::TilingContext *kernel_context) { - (void)kernel_context->SetNeedAtomic(true); - return ge::GRAPH_SUCCESS; -} - -UINT32 OpTilingParseStubV5(gert::KernelContext *kernel_context) { - auto av = kernel_context->GetOutput(0); - av->Set(CreateCompileInfo(), DeleteCompileInfo); - return ge::GRAPH_SUCCESS; -} - -UINT32 OpTilingStubV6(gert::TilingContext *kernel_context) { - auto input_desc = kernel_context->GetComputeNodeInfo()->GetInputTdInfo(0); - EXPECT_EQ(input_desc->GetFormat().GetStorageFormat(), - ge::GetFormatFromSub(static_cast(Format::FORMAT_FRACTAL_Z), 32)); - auto tensor = kernel_context->GetInputTensor(0); - std::vector real_data = {1.1, 2.1, 3.1, 4.1}; - for (size_t i = 0UL; i < 4UL; ++i) { - EXPECT_EQ((tensor->GetData())[i], optiling::Float32ToFloat16(real_data[i])); - } - return ge::GRAPH_SUCCESS; -} - -bool op_tiling_stub_failed(const Operator &op, const utils::OpCompileInfo &compile_info, utils::OpRunInfo &run_info) { - EXPECT_EQ(true, false); - return true; -} - -void StubError() { - auto &instance = ErrorManager::GetInstance(); - EXPECT_EQ(instance.GetErrorMessage(), ""); - std::vector vec; - vec.push_back(ErrorManager::ErrorItem()); - vec[0].error_id = "E19999"; - vec[0].error_message = "Get peer anchor failed"; - vec.push_back(ErrorManager::ErrorItem()); - vec[1].error_id = "E80001"; - std::map args_map; - args_map.emplace("config_name", "os"); - args_map.emplace("file_name", "scene.info"); - vec[1].args_map = args_map; - instance.error_message_per_work_id_[0] = vec; - instance.error_context_.work_stream_id = 0; -} - -bool CheckErrorRetFormat(const std::string &ret_json_str) { - EXPECT_NE(ret_json_str, ""); - nlohmann::json ret_json; - try { - ret_json = nlohmann::json::parse(ret_json_str); - EXPECT_TRUE(ret_json.contains("ret_code")); - EXPECT_EQ(ret_json.at("ret_code"), 1); - EXPECT_TRUE(ret_json.contains("error_messages")); - EXPECT_EQ(ret_json.at("error_messages").size(), 2); - EXPECT_EQ(ret_json.at("error_messages")[0].at("type"), 2); - EXPECT_EQ(ret_json.at("error_messages")[0].at("errorcode"), "E19999"); - EXPECT_EQ(ret_json.at("error_messages")[0].at("errormsg"), "Get peer anchor failed"); - EXPECT_EQ(ret_json.at("error_messages")[1].at("type"), 1); - EXPECT_EQ(ret_json.at("error_messages")[1].at("errorcode"), "E80001"); - std::map args_map; - args_map.emplace("config_name", "os"); - args_map.emplace("file_name", "scene.info"); - EXPECT_EQ(ret_json.at("error_messages")[1].at("errormsg"), args_map); - } catch (const nlohmann::json::exception &e) { - std::cout << "parse failed, reason: " << e.what() << std::endl; - return false; - } - return true; -} - -void ReInitErrorManager() { - ErrorManager::GetInstance().is_init_ = false; - ErrorManager::GetInstance().compile_failed_msg_map_.clear(); - ErrorManager::GetInstance().compile_failed_msg_map_.clear(); - ErrorManager::GetInstance().error_message_per_work_id_.clear(); - ErrorManager::GetInstance().warning_messages_per_work_id_.clear(); -} -} // namespace - -class UtestRegister : public testing::Test { - protected: - void SetUp() {} - - void TearDown() { - ErrorManager::GetInstance().is_init_ = false; - ErrorManager::GetInstance().compile_failed_msg_map_.clear(); - ErrorManager::GetInstance().compile_failed_msg_map_.clear(); - ErrorManager::GetInstance().error_message_per_work_id_.clear(); - ErrorManager::GetInstance().warning_messages_per_work_id_.clear(); - } -}; - -extern "C" int TbeOpTilingPyInterfaceEx2(const char *optype, const char *compile_info, const char *inputs, - const char *outputs, char *run_info_json, size_t run_info_len, - const char *compile_info_hash, uint64_t *elapse); - -extern "C" int TbeOpTilingPyInterface(const char *optype, const char *compile_info, const char *compile_info_hash, - const char *inputs, const char *outputs, const char *attrs, char *run_info_json, - size_t run_info_len, uint64_t *elapse); - -extern "C" const char *DoOpTilingForCompile(const char *optype, const char *compile_info, const char *compile_info_hash, - const char *inputs, - const char *outputs, - const char *attrs, - char *run_info_json, - size_t run_info_len, - uint64_t *elapse, - const char *extra_info); -bool op_tiling_stub_v2(const Operator &op, const utils::OpCompileInfo &compile_info, utils::OpRunInfo &run_info) { - return true; -} - -bool op_tiling_stub_v3(const Operator &op, const void *value, OpRunInfoV2 &run_info) { - return true; -} - -void *op_parse_stub_v3(const Operator &op, const ge::AscendString &compile_info_json) { - // static void *p = new int(3); - static int x = 1024; - void *p = &x; - return p; -} - -bool op_tiling_stub_v4(const Operator &op, const CompileInfoPtr value, OpRunInfoV2 &run_info) { - return true; -} - -CompileInfoPtr op_parse_stub_v4(const Operator &op, const ge::AscendString &compile_info_json) { - // static void *p = new int(3); - CompileInfoPtr info = std::make_shared("qwer"); - return info; -} - - -UINT32 OpTilingStubNewWithNullDesc(gert::TilingContext *kernel_context) { - auto tensor_without_data = kernel_context->GetInputTensor(1); - EXPECT_EQ(tensor_without_data->GetAddr(), nullptr); - EXPECT_EQ(tensor_without_data->GetStorageShape(), gert::Shape({5, 5, 5, 5})); - EXPECT_EQ(tensor_without_data->GetOriginShape(), gert::Shape({5, 5, 5, 5})); - auto tensor = kernel_context->GetInputTensor(0); - EXPECT_EQ(tensor->GetShape().GetStorageShape().GetDimNum(), 4); - gert::Shape expect_shape({4, 4, 4, 4}); - EXPECT_EQ(tensor->GetShape().GetStorageShape(), expect_shape); - EXPECT_EQ(tensor->GetDataType(), DT_FLOAT); - EXPECT_EQ((tensor->GetData())[3], -std::numeric_limits::infinity()); - EXPECT_EQ(std::isnan((tensor->GetData())[2]), true); - EXPECT_EQ((tensor->GetData())[1], 2.0); - EXPECT_EQ((tensor->GetData())[0], std::numeric_limits::infinity()); - EXPECT_EQ(tensor->GetFormat().GetStorageFormat(), FORMAT_ND); - gert::Shape expect_shape2({9, 9, 9, 9}); - EXPECT_TRUE(kernel_context->GetOutputShape(0)->GetStorageShape() == expect_shape2); - auto shape = kernel_context->GetInputShape(1); - EXPECT_TRUE(*shape == gert::StorageShape({5, 5, 5, 5}, {5, 5, 5, 5})); - auto ci = kernel_context->GetCompileInfo(); - EXPECT_EQ(reinterpret_cast(ci)->stub_, 1); - - EXPECT_EQ(kernel_context->GetAttrs()->GetAttrNum(), 4); - std::vector expect_attr = {std::numeric_limits::infinity(), 2.0, - std::numeric_limits::quiet_NaN(), -std::numeric_limits::infinity()}; - for (size_t i = 0UL; i < 4UL; ++i) { - if (i == 2U) { - EXPECT_EQ(std::isnan(reinterpret_cast( - kernel_context->GetAttrs()->GetAttrPointer(0)->GetData())[i]), true); - continue; - } - EXPECT_EQ(reinterpret_cast( - kernel_context->GetAttrs()->GetAttrPointer(0)->GetData())[i], - expect_attr[i]); - } - EXPECT_EQ(*kernel_context->GetAttrs()->GetAttrPointer(1), std::numeric_limits::infinity()); - kernel_context->SetBlockDim(2); - kernel_context->SetAicpuBlockDim(4); - kernel_context->SetNeedAtomic(true); - kernel_context->SetTilingKey(78); - *kernel_context->GetWorkspaceSizes(1) = 12; - kernel_context->GetRawTilingData()->Append(6); - kernel_context->GetRawTilingData()->Append(7); - kernel_context->GetRawTilingData()->Append(8); - kernel_context->GetRawTilingData()->Append(9); - kernel_context->GetRawTilingData()->Append(10); - return ge::GRAPH_SUCCESS; -} - -void SupportInfNanWithNullDescInvalidTestCase(const nlohmann::json &input, const nlohmann::json &output, - const nlohmann::json &attrs) { - std::string input_str = input.dump(); - std::string output_str = output.dump(); - std::string attrs_str = attrs.dump(); - const char *op_type = "TestReluV2"; - const char *cmp_info = ""; - std::string runinfo(100, 'a'); - size_t size = 100; - const char *cmp_info_hash = ""; - uint64_t *elapse = nullptr; - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.tiling = OpTilingStubNewWithNullDesc; - op_impl_func.tiling_parse = OpTilingParseStubNew; - op_impl_func.compile_info_creator = CreateCompileInfo; - op_impl_func.compile_info_deleter = DeleteCompileInfo; - op_impl_func.max_tiling_data_size = 50; - registry_holder->AddTypesToImpl(op_type, op_impl_func); - space_registry->AddRegistry(registry_holder); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - EXPECT_EQ(TbeOpTilingPyInterface(op_type, cmp_info, cmp_info_hash, input_str.c_str(), output_str.c_str(), - attrs_str.c_str(), const_cast(runinfo.c_str()), size, elapse), - 0); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(nullptr); -} - -REGISTER_OP_TILING_V2(ReluV2, op_tiling_stub_v2); -REGISTER_OP_TILING_V3(ReluV3, op_tiling_stub_v3, op_parse_stub_v3); -REGISTER_OP_TILING_V4(ReluV4, op_tiling_stub_v4, op_parse_stub_v4); - -TEST_F(UtestRegister, test_register_dynamic_outputs_op_only_has_partial_output) { - ut::GraphBuilder builder = ut::GraphBuilder("graph"); - auto node_src = builder.AddNode("ParseSingleExample", "ParseSingleExample", - {"serialized", "dense_defaults_0", "dense_defaults_1", "dense_defaults_2"}, - {"dense_values_0", "dense_values_1", "dense_values_2"}); - // build op_src attrs - vector dense_keys = {"image/class/lable", "image/encode", "image/format"}; - vector t_dense = {DT_INT64, DT_STRING, DT_STRING}; - AttrUtils::SetListStr(node_src->GetOpDesc(), "dense_keys", dense_keys); - AttrUtils::SetListStr(node_src->GetOpDesc(), "dense_shapes", {}); - AttrUtils::SetInt(node_src->GetOpDesc(), "num_sparse", 0); - AttrUtils::SetListStr(node_src->GetOpDesc(), "sparse_keys", {}); - AttrUtils::SetListStr(node_src->GetOpDesc(), "sparse_types", {}); - AttrUtils::SetListDataType(node_src->GetOpDesc(), "Tdense", t_dense); - auto graph = builder.GetGraph(); - - // get op_src - ge::Operator op_src = OpDescUtils::CreateOperatorFromNode(node_src); - ge::Operator op_dst = ge::Operator("ParseSingleExample"); - std::shared_ptr op_desc_dst = ge::OpDescUtils::GetOpDescFromOperator(op_dst); - op_desc_dst->AddRegisterInputName("dense_defaults"); - op_desc_dst->AddRegisterOutputName("sparse_indices"); - op_desc_dst->AddRegisterOutputName("sparse_values"); - op_desc_dst->AddRegisterOutputName("sparse_shapes"); - op_desc_dst->AddRegisterOutputName("dense_values"); - - // simulate parse_single_example plugin - std::vector value; - DynamicInputOutputInfo input(kInput, "dense_defaults", 14, "Tdense", 6); - value.push_back(input); - DynamicInputOutputInfo output(kOutput, "sparse_indices", 14, "num_sparse", 10); - value.push_back(output); - DynamicInputOutputInfo output1(kOutput, "sparse_values", 13, "sparse_types", 12); - value.push_back(output1); - DynamicInputOutputInfo output2(kOutput, "sparse_shapes", 13, "num_sparse", 10); - value.push_back(output2); - DynamicInputOutputInfo output3(kOutput, "dense_values", 12, "Tdense", 6); - value.push_back(output3); - DynamicInputOutputInfo invalidput(kInvalid, "Invalid", 7, "Invalid", 7); - value.push_back(invalidput); - - // pre_check - EXPECT_EQ(op_dst.GetOutputsSize(), 0); - auto ret = AutoMappingByOpFnDynamic(op_src, op_dst, value); - - // check add 3 output to op_dst - EXPECT_EQ(ret, domi::SUCCESS); - EXPECT_EQ(op_dst.GetOutputsSize(), 3); - - // for AutoMappingByOpFnDynamic failed test - ge::Operator op_src_fail(nullptr); - ret = AutoMappingByOpFnDynamic(op_src_fail, op_dst, value); - EXPECT_EQ(ret, domi::FAILED); - - std::vector value_fail; - ret = AutoMappingByOpFnDynamic(op_src, op_dst, value_fail); - DynamicInputOutputInfo input_fail(kInput, "", 0, "", 0); - value_fail.push_back(input_fail); - ret = AutoMappingByOpFnDynamic(op_src, op_dst, value_fail); - EXPECT_EQ(ret, domi::FAILED); -} - -void GraphInit(domi::tensorflow::GraphDef &graph_def) { - // add node, set info - domi::tensorflow::NodeDef *placeholder0 = graph_def.add_node(); - placeholder0->set_name("placeholder0"); - placeholder0->set_op("PlaceHolder"); - - // add node, set info, add edges - domi::tensorflow::NodeDef *add0 = graph_def.add_node(); - add0->set_name("add0"); - add0->set_op("Add"); - add0->add_input("placeholder0"); - add0->add_input("placeholder1"); - - // 1. add node - auto placeholder1 = graph_def.add_node(); - auto add1 = graph_def.add_node(); - auto mul0 = graph_def.add_node(); - auto mul1 = graph_def.add_node(); - auto add2 = graph_def.add_node(); - auto retval0 = graph_def.add_node(); - auto retval1 = graph_def.add_node(); - - // 2. set info - placeholder1->set_name("placeholder1"); - placeholder1->set_op("PlaceHolder"); - add1->set_name("add1"); - add1->set_op("Add"); - add2->set_name("add2"); - add2->set_op("Add"); - mul0->set_name("mul0"); - mul0->set_op("Mul"); - mul1->set_name("mul1"); - mul1->set_op("Mul"); - retval0->set_name("retval0"); - retval0->set_op("_RetVal"); - retval1->set_name("retval1"); - retval1->set_op("_RetVal"); - - // 3. add edges - mul0->add_input("placeholder0"); - mul0->add_input("placeholder1"); - mul1->add_input("placeholder0"); - mul1->add_input("add0"); - mul1->add_input("^mul0"); - add1->add_input("mul0"); - add1->add_input("placeholder1"); - add2->add_input("mul1"); - add2->add_input("mul0"); - retval0->add_input("add2:0"); - retval1->add_input("add1:0"); -} - -int32_t AutoMappingSubgraphIndexInput(int32_t data_index) { - return 0; -} -int32_t AutoMappingSubgraphIndexOutput(int32_t netoutput_index) { - return 0; -} -Status AutoMappingSubgraphIndexInput2(int32_t data_index, int32_t &parent_input_index) { - return domi::SUCCESS; -} -Status AutoMappingSubgraphIndexOutput2(int32_t netoutput_index, int32_t &parent_output_index) { - parent_output_index++; - return domi::SUCCESS; -} -Status AutoMappingSubgraphIndexOutput2Failed(int32_t netoutput_index, int32_t &parent_output_index) { - return domi::FAILED; -} - -TEST_F(UtestRegister, AutoMappingSubgraphIndex) { - Status stat; - auto builder = ut::GraphBuilder("root"); - auto output = builder.AddNode("netoutput", NETOUTPUT, 1, 0); - auto input = builder.AddNode("data", DATA, 1, 1); - input->impl_->op_->impl_->meta_data_.type_ = "Data"; - auto func_node = builder.AddNode("func_node", FRAMEWORKOP, 1, 1); - func_node->impl_->op_->impl_->meta_data_.type_ = "FrameworkOp"; - builder.AddDataEdge(input, 0, func_node, 0); - builder.AddDataEdge(func_node, 0, output, 0); - - auto computeGraph = builder.GetGraph(); - Graph graph = GraphUtilsEx::CreateGraphFromComputeGraph(computeGraph); - stat = AutoMappingSubgraphIndex(graph, AutoMappingSubgraphIndexInput, AutoMappingSubgraphIndexOutput); - EXPECT_EQ(stat, domi::FAILED); -} - -TEST_F(UtestRegister, AutoMappingSubgraphIndexByDataNode) { - Status stat; - auto builder = ut::GraphBuilder("root"); - auto output = builder.AddNode("netoutput", NETOUTPUT, 1, 0); - auto func_node = builder.AddNode("func_node", PARTITIONEDCALL, 1, 1); - builder.AddDataEdge(func_node, 0, output, 0); - - auto computeGraph = builder.GetGraph(); - Graph graph = GraphUtilsEx::CreateGraphFromComputeGraph(computeGraph); - stat = AutoMappingSubgraphIndex(graph, AutoMappingSubgraphIndexInput2, AutoMappingSubgraphIndexOutput2); - EXPECT_EQ(stat, domi::SUCCESS); - - auto input = builder.AddNode("Retval", DATA, 1, 1); - input->impl_->op_->impl_->meta_data_.type_ = "_Retval"; - AttrUtils::SetInt(input->GetOpDesc(), "retval_index", 0); - builder.AddDataEdge(input, 0, func_node, 0); - stat = AutoMappingSubgraphIndex(graph, AutoMappingSubgraphIndexInput2, AutoMappingSubgraphIndexOutput2); - EXPECT_EQ(stat, domi::SUCCESS); -} - -TEST_F(UtestRegister, AutoMappingSubgraphIndexByDataNode2) { - Status stat; - auto builder = ut::GraphBuilder("root"); - auto input = builder.AddNode("index", DATA, 1, 1); - input->impl_->op_->impl_->meta_data_.type_ = "Data"; - AttrUtils::SetInt(input->GetOpDesc(), "index", 0); - auto output = builder.AddNode("netoutput", NETOUTPUT, 1, 0); - auto func_node = builder.AddNode("func_node", PARTITIONEDCALL, 1, 1); - builder.AddDataEdge(input, 0, func_node, 0); - builder.AddDataEdge(func_node, 0, output, 0); - - auto computeGraph = builder.GetGraph(); - Graph graph = GraphUtilsEx::CreateGraphFromComputeGraph(computeGraph); - stat = AutoMappingSubgraphIndex(graph, AutoMappingSubgraphIndexInput2, AutoMappingSubgraphIndexOutput2); - EXPECT_EQ(stat, domi::SUCCESS); -} - -TEST_F(UtestRegister, AutoMappingSubgraphOutputFail) { - Status stat; - auto builder = ut::GraphBuilder("root"); - auto output = builder.AddNode("netoutput", NETOUTPUT, 1, 0); - auto input = builder.AddNode("data", DATA, 1, 1); - input->impl_->op_->impl_->meta_data_.type_ = "Data"; - auto func_node = builder.AddNode("func_node", FRAMEWORKOP, 1, 1); - func_node->impl_->op_->impl_->meta_data_.type_ = "FrameworkOp"; - builder.AddDataEdge(input, 0, func_node, 0); - builder.AddDataEdge(func_node, 0, output, 0); - - auto computeGraph = builder.GetGraph(); - Graph graph = GraphUtilsEx::CreateGraphFromComputeGraph(computeGraph); - - stat = AutoMappingSubgraphIndex(graph, AutoMappingSubgraphIndexInput2, AutoMappingSubgraphIndexOutput2Failed); - EXPECT_EQ(stat, domi::FAILED); -} - -TEST_F(UtestRegister, AutoMappingFnDynamicInputTest) { - Status stat; - domi::tensorflow::GraphDef graph_def; - GraphInit(graph_def); - map> name_attr_value; - name_attr_value.insert(make_pair(std::string("in"), make_pair(std::string("dynamicName1"), std::string("int")))); - name_attr_value.insert(make_pair(std::string("out"), make_pair(std::string("dynamicName2"), std::string("float")))); - - ge::Operator op_dst = ge::Operator("Add", "int"); - const domi::tensorflow::NodeDef *node; - - int32_t node_size = graph_def.node_size(); - for (int i = 0; i < node_size; i++) { - node = graph_def.mutable_node(i); - stat = AutoMappingFnDynamic(node, op_dst, name_attr_value, 1, 1); - EXPECT_EQ(stat, domi::SUCCESS); - } -} - -TEST_F(UtestRegister, AutoMappingFnDynamicInput) { - Status retStat; - domi::tensorflow::GraphDef graph_def; - GraphInit(graph_def); - - ge::Operator op_dst = ge::Operator("Add", "int"); - domi::tensorflow::NodeDef *node = graph_def.mutable_node(0); - - // test for add attrs - map> name_attrs; - domi::tensorflow::AttrValue inValue; - inValue.set_s(std::string("stringValue")); - inValue.set_i(66); - node->mutable_attr()->insert({"inVal", inValue}); - name_attrs.insert(make_pair(std::string("in"), make_pair(std::string("inName1"), std::string("inVal")))); - retStat = AutoMappingFnDynamic(node, op_dst, name_attrs, 1, 1); - EXPECT_EQ(retStat, domi::SUCCESS); -} - -TEST_F(UtestRegister, AutoMappingFnDynamicOutput) { - Status retStat; - domi::tensorflow::GraphDef graph_def; - GraphInit(graph_def); - - ge::Operator op_dst = ge::Operator("Add", "int"); - domi::tensorflow::NodeDef *node = graph_def.mutable_node(0); - - // test for add attrs - map> name_attrs; - domi::tensorflow::AttrValue outValue; - outValue.set_b(true); - outValue.set_i(88); - node->mutable_attr()->insert({"outVal", outValue}); - name_attrs.insert(make_pair(std::string("out"), make_pair(std::string("outName1"), std::string("outVal")))); - retStat = AutoMappingFnDynamic(node, op_dst, name_attrs, 1, 1); - EXPECT_EQ(retStat, domi::SUCCESS); -} - -TEST_F(UtestRegister, AutoMappingFunctionkFunc) { - Status retStat; - domi::tensorflow::GraphDef graph_def; - GraphInit(graph_def); - - ge::Operator op_dst = ge::Operator("Add", "int"); - op_dst.SubgraphRegister("subVal", true); - op_dst.SubgraphCountRegister("subVal", 6); - - // test for add attrs - domi::tensorflow::NodeDef *node = graph_def.mutable_node(0); - map> name_attrs; - domi::tensorflow::AttrValue attrValue; - attrValue.set_i(88); - domi::tensorflow::NameAttrList *nameAttrList = new domi::tensorflow::NameAttrList(); - nameAttrList->set_name("nameAttrList"); - attrValue.unsafe_arena_set_allocated_func(nameAttrList); - node->mutable_attr()->insert({"subVal", attrValue}); - name_attrs.insert(make_pair(std::string("out"), make_pair(std::string("outName1"), std::string("subVal")))); - retStat = AutoMappingFnDynamic(node, op_dst, name_attrs, 1, 1); - EXPECT_EQ(retStat, domi::FAILED); -} - -TEST_F(UtestRegister, AutoMappingFunctionkList) { - Status retStat; - domi::tensorflow::GraphDef graph_def; - GraphInit(graph_def); - - ge::Operator op_dst = ge::Operator("Add", "int"); - op_dst.SubgraphRegister("subVal", true); - - // test for add attrs - domi::tensorflow::NodeDef *node = graph_def.mutable_node(0); - map> name_attrs; - domi::tensorflow::AttrValue attrValue; - attrValue.set_i(88); - domi::tensorflow::AttrValue_ListValue *attrValListVal = new domi::tensorflow::AttrValue_ListValue(); - attrValListVal->add_s("list0"); - attrValListVal->add_s("list1"); - attrValue.unsafe_arena_set_allocated_list(attrValListVal); - // list.func - domi::tensorflow::NameAttrList *nameAttrList = new domi::tensorflow::NameAttrList(); - nameAttrList->set_name("nameAttrList"); - attrValListVal->add_func(); - - node->mutable_attr()->insert({"subVal", attrValue}); - name_attrs.insert(make_pair(std::string("out"), make_pair(std::string("outName1"), std::string("subVal")))); - retStat = AutoMappingFnDynamic(node, op_dst, name_attrs, 1, 1); - EXPECT_EQ(retStat, domi::SUCCESS); - delete nameAttrList; -} - -domi::Status inputFunc(int32_t data_index, int32_t &parent_input_index) { - parent_input_index++; - return (parent_input_index < 0) ? domi::FAILED : domi::SUCCESS; -} - -domi::Status outputFunc(int32_t netoutput_index, int32_t &parent_output_index) { - parent_output_index++; - return (parent_output_index < 2) ? domi::FAILED : domi::SUCCESS; -} - -domi::Status AutoMappingSubgraphIOIndexFuncCB( - const ge::Graph &graph, const std::function &input, - const std::function &output) { - static int test_idx = -2; - - switch (test_idx) { - case -2: - return input(0, test_idx); - case -1: - return input(0, test_idx); - case 0: - return output(0, test_idx); - case 1: - return output(0, test_idx); - } - return domi::SUCCESS; -} - -TEST_F(UtestRegister, FrameworkRegistryTest) { - auto TENSORFLOW = domi::TENSORFLOW; - REGISTER_AUTOMAPPING_SUBGRAPH_IO_INDEX_FUNC(TENSORFLOW, AutoMappingSubgraphIOIndexFuncCB); - - FrameworkRegistry &cur = FrameworkRegistry::Instance(); - cur.AddAutoMappingSubgraphIOIndexFunc(domi::CAFFE, AutoMappingSubgraphIOIndexFuncCB); - - const ge::Graph graph("graph_test"); - AutoMappingSubgraphIOIndexFunc func = cur.GetAutoMappingSubgraphIOIndexFunc(domi::CAFFE); - EXPECT_EQ(func(graph, inputFunc, outputFunc), domi::FAILED); - EXPECT_EQ(func(graph, inputFunc, outputFunc), domi::SUCCESS); - EXPECT_EQ(func(graph, inputFunc, outputFunc), domi::FAILED); - EXPECT_EQ(func(graph, inputFunc, outputFunc), domi::SUCCESS); -} - -TEST_F(UtestRegister, OpRegistrationDataWithNoImpl) { - OpRegistrationData opRegData(std::string("OmOptype")); - opRegData.impl_.reset(); - - EXPECT_EQ(opRegData.GetOmOptype() == "", true); - EXPECT_EQ(opRegData.GetFrameworkType(), domi::FRAMEWORK_RESERVED); - EXPECT_EQ(opRegData.GetOriginOpTypeSet().empty(), true); - EXPECT_EQ(opRegData.GetParseParamFn(), nullptr); - EXPECT_EQ(opRegData.GetParseParamByOperatorFn(), nullptr); - EXPECT_EQ(opRegData.GetFusionParseParamFn(), nullptr); - EXPECT_EQ(opRegData.GetFusionParseParamByOpFn(), nullptr); - EXPECT_EQ(opRegData.GetImplyType(), domi::ImplyType::BUILDIN); - EXPECT_EQ(opRegData.GetParseSubgraphPostFn(), nullptr); - EXPECT_EQ(opRegData.GetParseOpToGraphFn(), nullptr); - ParseSubgraphFuncV2 func; - EXPECT_EQ(opRegData.GetParseSubgraphPostFn(func), domi::FAILED); -} - -TEST_F(UtestRegister, OmOptypeTest) { - OpRegistrationData opRegData(std::string("OmOptype")); - OpReceiver oprcver(opRegData); - opRegData.GetOmOptype(); - - AscendString OmOptype; - Status stat = opRegData.GetOmOptype(OmOptype); - EXPECT_EQ(stat, domi::SUCCESS); -} - -TEST_F(UtestRegister, FrameworkTest) { - OpRegistrationData opRegData(std::string("OmOptype")); - - opRegData.FrameworkType(domi::MINDSPORE); - EXPECT_EQ(opRegData.GetFrameworkType(), domi::MINDSPORE); -} - -TEST_F(UtestRegister, OriOpTypeTest) { - OpRegistrationData opRegData(std::string("OmOptype")); - OpRegistrationData opRegData2("OmOptype2"); - - std::initializer_list OptypeList1{std::string("Add"), std::string("Sub")}; - opRegData.OriginOpType(OptypeList1); - std::vector OptypeList2 = {AscendString("Div"), AscendString("Mul")}; - opRegData.OriginOpType(OptypeList2); - - opRegData2.OriginOpType(std::string("Add")); - opRegData2.OriginOpType("Sub"); - - opRegData.GetOriginOpTypeSet(); - std::set opTypeSet; - Status stat = opRegData.GetOriginOpTypeSet(opTypeSet); - EXPECT_EQ(stat, domi::SUCCESS); -} - -TEST_F(UtestRegister, OpRegistryImplyTypeTest) { - OpRegistrationData opRegData(std::string("OmOptype")); - - std::initializer_list OptypeList{std::string("Add"), std::string("Sub")}; - opRegData.OriginOpType(OptypeList); - std::vector OptypeList2 = {AscendString("Div"), AscendString("Mul")}; - opRegData.OriginOpType(OptypeList2); - - // set ImplyType - opRegData.ImplyType(domi::ImplyType::CUSTOM); - EXPECT_EQ(opRegData.GetImplyType(), domi::ImplyType::CUSTOM); - - OpRegistry *opReg = OpRegistry::Instance(); - opReg->Register(opRegData); - - domi::ImplyType implType = opReg->GetImplyTypeByOriOpType(std::string("Add")); - EXPECT_EQ(implType, domi::ImplyType::CUSTOM); - - implType = opReg->GetImplyType(std::string("OmOptype")); - EXPECT_EQ(implType, domi::ImplyType::CUSTOM); - implType = opReg->GetImplyType(std::string("strOmOptype")); - EXPECT_EQ(implType, domi::ImplyType::BUILDIN); - - vector vecOpType; - vecOpType.clear(); - opReg->GetOpTypeByImplyType(vecOpType, domi::ImplyType::CUSTOM); - EXPECT_EQ(vecOpType.empty(), false); - vecOpType.clear(); - opReg->GetOpTypeByImplyType(vecOpType, domi::ImplyType::AI_CPU); - EXPECT_EQ(vecOpType.empty(), true); -} - -TEST_F(UtestRegister, DelInputWithTest) { - OpRegistrationData opRegData(std::string("OmOptype")); - std::initializer_list OptypeList{std::string("Add"), std::string("Sub")}; - opRegData.OriginOpType(OptypeList); - - opRegData.ParseParamsFn(domi::AutoMappingFn); - EXPECT_NE(opRegData.GetParseParamFn(), nullptr); - - // insert input into vector - const vector input_order{0, 1, 3, 2}; - opRegData.InputReorderVector(input_order); - - opRegData.DelInputWithCond(1, std::string("attrName_1"), true); - opRegData.DelInputWithCond(2, "attrName_2", false); - - opRegData.DelInputWithOriginalType(3, std::string("Add")); - opRegData.DelInputWithOriginalType(4, "Sub"); - - OpRegistry *opReg = OpRegistry::Instance(); - ASSERT_NE(opReg, nullptr); - bool retBool = opReg->Register(opRegData); - ASSERT_EQ(retBool, true); - - std::vector rmConfigVec; - rmConfigVec = opReg->GetRemoveInputConfigure(std::string("Add")); - EXPECT_EQ(rmConfigVec.empty(), true); - rmConfigVec = opReg->GetRemoveInputConfigure(std::string("Mul")); - EXPECT_EQ(rmConfigVec.empty(), true); - rmConfigVec = opReg->GetRemoveInputConfigure(std::string("Mul666")); - EXPECT_EQ(rmConfigVec.empty(), true); -} - -TEST_F(UtestRegister, GetOmTypeByOriOpTypeTest) { - OpRegistrationData opRegData(std::string("OmOptype")); - - std::initializer_list OptypeList{std::string("Add"), std::string("Sub")}; - opRegData.OriginOpType(OptypeList); - - OpRegistry *opReg = OpRegistry::Instance(); - opReg->Register(opRegData); - std::string om_type; - EXPECT_EQ(opReg->GetOmTypeByOriOpType(std::string("Sub"), om_type), true); - EXPECT_EQ(opReg->GetOmTypeByOriOpType(std::string("Sub1"), om_type), false); -} - -domi::Status FusionParseParamsFnCB(const std::vector Msg, ge::Operator &Op) { - return domi::SUCCESS; -} -domi::Status FusionParseParamsFnCB2(const std::vector &VecOp, ge::Operator &Op) { - return domi::FAILED; -} -domi::Status ParseSubgraphPostFnCB(const std::string &subgraph_name, const ge::Graph &graph) { - return domi::SUCCESS; -} -domi::Status ParseSubgraphPostFnCB2(const ge::AscendString &subgraph_name, const ge::Graph &graph) { - return domi::SUCCESS; -} -domi::Status ParseOpToGraphFnCB(const ge::Operator &Op, ge::Graph &Graph) { - return domi::SUCCESS; -} - -TEST_F(UtestRegister, ParseParamFuncTest) { - const std::string strOmOptype = "OmOptype"; - OpRegistrationData opRegData(strOmOptype); - - std::initializer_list OptypeList{std::string("Add"), std::string("Sub")}; - opRegData.OriginOpType(OptypeList); - std::vector OptypeListAStr = {AscendString("Div"), AscendString("Mul")}; - opRegData.OriginOpType(OptypeListAStr); - - opRegData.ParseParamsFn(domi::AutoMappingFn); - EXPECT_NE(opRegData.GetParseParamFn(), nullptr); - - OpRegistry *opReg = OpRegistry::Instance(); - opReg->Register(opRegData); - - EXPECT_EQ(opReg->GetParseParamFunc(std::string("OmOptype1"), std::string("Sub")), nullptr); - EXPECT_EQ(opReg->GetParseParamFunc(std::string("OmOptype"), std::string("Sub")), nullptr); -} - -TEST_F(UtestRegister, FusionParseParamFuncTest) { - OpRegistrationData opRegData(std::string("OmOptype")); - - std::initializer_list OptypeList{std::string("Add"), std::string("Sub")}; - opRegData.OriginOpType(OptypeList); - - opRegData.FusionParseParamsFn(FusionParseParamsFnCB); - EXPECT_NE(opRegData.GetFusionParseParamFn(), nullptr); - - OpRegistry *opReg = OpRegistry::Instance(); - opReg->Register(opRegData); - - EXPECT_EQ(opReg->GetFusionParseParamFunc(std::string("OmOptype"), std::string("Sub")), nullptr); - EXPECT_EQ(opReg->GetFusionParseParamFunc(std::string("OmOptype1"), std::string("Sub")), nullptr); -} - -TEST_F(UtestRegister, GetParseOpToGraphFuncTest) { - OpRegistrationData opRegData(std::string("OmOptype")); - - std::initializer_list OptypeList{std::string("Add"), std::string("Sub")}; - opRegData.OriginOpType(OptypeList); - - opRegData.ParseOpToGraphFn(ParseOpToGraphFnCB); - EXPECT_NE(opRegData.GetParseOpToGraphFn(), nullptr); - - OpRegistry *opReg = OpRegistry::Instance(); - opReg->Register(opRegData); - std::string om_type; - - EXPECT_EQ(opReg->GetParseOpToGraphFunc(std::string("OmOptype"), std::string("Add")), nullptr); - EXPECT_EQ(opReg->GetParseOpToGraphFunc(std::string("OmOptype"), std::string("Mul")), nullptr); -} - -TEST_F(UtestRegister, ParseParamByOperatorFuncTest) { - OpRegistrationData opRegData(std::string("OmOptype")); - - std::initializer_list OptypeList{std::string("Add"), std::string("Sub")}; - opRegData.OriginOpType(OptypeList); - - opRegData.ParseParamsByOperatorFn(domi::AutoMappingByOpFn); - EXPECT_NE(opRegData.GetParseParamByOperatorFn(), nullptr); - - OpRegistry *opReg = OpRegistry::Instance(); - opReg->Register(opRegData); - - EXPECT_EQ(opReg->GetParseParamByOperatorFunc(std::string("int")), nullptr); - EXPECT_EQ(opReg->GetParseParamByOperatorFunc(std::string("Add")), nullptr); -} - -TEST_F(UtestRegister, FusionParseParamByOpFuncTest) { - OpRegistrationData opRegData(std::string("OmOptype")); - - std::initializer_list OptypeList{std::string("Add"), std::string("Sub")}; - opRegData.OriginOpType(OptypeList); - - opRegData.FusionParseParamsFn(FusionParseParamsFnCB); - EXPECT_NE(opRegData.GetFusionParseParamFn(), nullptr); - - opRegData.FusionParseParamsFn(FusionParseParamsFnCB2); - EXPECT_NE(opRegData.GetFusionParseParamByOpFn(), nullptr); - - OpRegistry *opReg = OpRegistry::Instance(); - opReg->Register(opRegData); - - EXPECT_EQ(opReg->GetFusionParseParamByOpFunc(std::string("strOmOptype"), std::string("Add")), nullptr); - EXPECT_EQ(opReg->GetFusionParseParamByOpFunc(std::string("OmOptype"), std::string("Add")), nullptr); -} - -TEST_F(UtestRegister, ParseSubgraphPostFnTest) { - OpRegistrationData opRegData(std::string("OmOptype")); - - std::initializer_list OptypeList{std::string("Add"), std::string("Sub")}; - opRegData.OriginOpType(OptypeList); - - opRegData.ParseSubgraphPostFn(ParseSubgraphPostFnCB); - EXPECT_NE(opRegData.GetParseSubgraphPostFn(), nullptr); - - opRegData.ParseSubgraphPostFn(ParseSubgraphPostFnCB2); - EXPECT_NE(opRegData.GetParseSubgraphPostFn(), nullptr); - - ParseSubgraphFuncV2 Getfunc; - opRegData.GetParseSubgraphPostFn(Getfunc); - - OpRegistry *opReg = OpRegistry::Instance(); - opReg->Register(opRegData); - - EXPECT_EQ(opReg->GetParseSubgraphPostFunc(std::string("strOmOptype")), nullptr); - EXPECT_EQ(opReg->GetParseSubgraphPostFunc(std::string("OmOptype")), nullptr); - - domi::ParseSubgraphFuncV2 parse_subgraph_func; - EXPECT_EQ(opReg->GetParseSubgraphPostFunc(std::string("OmOptype"), parse_subgraph_func), domi::SUCCESS); - EXPECT_EQ(opReg->GetParseSubgraphPostFunc(std::string("strOmOptype"), parse_subgraph_func), domi::FAILED); -} - -TEST_F(UtestRegister, optiling_py_interface) { - EXPECT_NO_THROW( - const nlohmann::json j = R"([ - { - "name": "test_0", - "dtype": "int8", - "value": 1, - "const_value": [ - 1, - 1, - 1, - 1 - ], - "shape": [ - 4, - 4, - 4, - 4 - ], - "format": "ND" - }, - { - "name": "test_1", - "dtype": "list_int", - "value": [ - 1, - 1, - 1, - 1 - ] - }, - { - "name": "test_2" - }, - { - "name": "test_2", - "dtype": "list_list_int", - "value": [ - [1, 2], - [1, 2], - [1, 2], - [1, 2] - ] - }, - { - "name": "test_0", - "dtype": "list_list_int64", - "value": [ - [1, 2], - [1, 2], - [1, 2], - [1, 2] - ] - }, - { - "name": "test_3", - "dtype": "test", - "value": "1" - } - ])"_json; - - std::string json_str = j.dump(); - ge::Operator op("NULL"); - const char *optype = "ReluV2"; - const char *optype_v3 = "ReluV3"; - const char *optype_v4 = "ReluV4"; - const char *cmp_info = "{\"_common_info\":[0,16,48,1,1,0,0],\"_is_ori_last_transpose\":0,\"_pattern\":\"Transdata\"," - "\"_permute\":[0,2,1,3],\"_sgt_cube_vector_core_type\":\"VectorCore\",\"_src_fuse\":[0,1,3]," - "\"_src_pad_mode\":[0,0,2],\"_src_pad_var\":[1,1,16],\"_ub_info" - "\":[[48512,24192],[-1],[-1],[-1]],\"device_id\":\"0\"}"; - char *runinfo = const_cast(""); - size_t size = 3; - const char *cmp_info_hash = ""; - uint64_t *elapse = nullptr; - const char *attrs = json_str.c_str(); - TbeOpTilingPyInterface(optype, cmp_info, cmp_info_hash, attrs, attrs, attrs, runinfo, size, elapse); - TbeOpTilingPyInterface(optype_v3, cmp_info, cmp_info_hash, attrs, attrs, attrs, runinfo, size, elapse); - TbeOpTilingPyInterface(optype_v4, cmp_info, cmp_info_hash, attrs, attrs, attrs, runinfo, size, elapse); - TbeOpTilingPyInterfaceEx2(optype, cmp_info, attrs, attrs, runinfo, size, cmp_info_hash, elapse); - TbeOpTilingPyInterfaceEx2(optype_v3, cmp_info, attrs, attrs, runinfo, size, cmp_info_hash, elapse); - TbeOpTilingPyInterfaceEx2(optype_v4, cmp_info, attrs, attrs, runinfo, size, cmp_info_hash, elapse); - ); -} - -TEST_F(UtestRegister, new_optiling_py_interface_ok) { - const nlohmann::json input = R"([ -{"name": "test_0","dtype": "int8", "const_value": [1,2,3,4],"shape": [4,4,4,4],"format": "ND"}, -{"name": "test_1","dtype": "int32","shape": [5,5,5,5],"ori_shape": [5,5,5,5],"format": "ND","ori_format": "ND"}, -{"name": "test_2","dtype": "int32","shape": [6,6,6,6],"ori_shape": [6,6,6,6],"format": "ND","ori_format": "ND"}])"_json; - std::string input_str = input.dump(); - const nlohmann::json output = R"([ -{"name": "y_0","dtype": "int8","shape": [9,9,9,9],"ori_shape" :[9,9,9,9],"format": "ND","ori_format":"ND"}])"_json; - - std::string output_str = output.dump(); - const nlohmann::json attrs = R"([ -{ "name": "attr_0","dtype": "list_int64","value": [1,2, 3, 4]}, -{ "name": "attr_1","dtype": "int","value": 99}, -{ "name": "attr_2","dtype": "list_int32","value": [1, 2, 3, 4]}, -{ "name": "op_para_size", "dtype": "int", "value": 50}])"_json; - std::string attrs_str = attrs.dump(); - const char *op_type = "TestReluV2"; - const char *cmp_info = ""; - std::string result = - R"({"aicpu_block_dim":4,"block_dim":2,"clear_atomic":true,"local_memory_size":0,"schedule_mode":0,"tiling_cond":0,"tiling_data":"060708090A","tiling_key":78,"workspaces":[12]})"; - size_t size = result.length(); - std::string runinfo(size, 'a'); - const char *cmp_info_hash = ""; - uint64_t *elapse = nullptr; - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.tiling = OpTilingStubNew; - op_impl_func.tiling_parse = OpTilingParseStubNew; - op_impl_func.compile_info_creator = CreateCompileInfo; - op_impl_func.compile_info_deleter = DeleteCompileInfo; - op_impl_func.max_tiling_data_size = 50; - registry_holder->AddTypesToImpl(op_type, op_impl_func); - space_registry->AddRegistry(registry_holder); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - EXPECT_EQ(TbeOpTilingPyInterface(op_type, cmp_info, cmp_info_hash, input_str.c_str(), output_str.c_str(), - attrs_str.c_str(), const_cast(runinfo.c_str()), size + 1U, elapse), - 1); - EXPECT_EQ(result, runinfo); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(nullptr); -} - -TEST_F(UtestRegister, new_optiling_py_interface_fail_with_invalid_const_value) { - // int9999 is invalid data type - const nlohmann::json input = R"([ - {"name": "test_0","dtype": "int9999", "const_value": [1,2,3,4],"shape": [4,4,4,4],"format": "ND"}])"_json; - std::string input_str = input.dump(); - std::string output_str = " "; - std::string attrs_str = " "; - const char *op_type = "TestReluV2"; - const char *cmp_info = ""; - size_t size = 150; - std::string runinfo(size, 'a'); - const char *cmp_info_hash = ""; - uint64_t *elapse = nullptr; - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.tiling = OpTilingStubNew; - op_impl_func.tiling_parse = OpTilingParseStubNew; - op_impl_func.compile_info_creator = CreateCompileInfo; - op_impl_func.compile_info_deleter = DeleteCompileInfo; - op_impl_func.max_tiling_data_size = 50; - registry_holder->AddTypesToImpl(op_type, op_impl_func); - space_registry->AddRegistry(registry_holder); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - - EXPECT_EQ(TbeOpTilingPyInterface(op_type, cmp_info, cmp_info_hash, input_str.c_str(), output_str.c_str(), - attrs_str.c_str(), const_cast(runinfo.c_str()), size, elapse), - 0); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(nullptr); -} - -TEST_F(UtestRegister, new_optiling_py_interface_fail_with_invalid_attr) { - std::string input_str = " "; - std::string output_str = " "; - // int999 is invalid dtype - const nlohmann::json attrs = R"([ -{ "name": "attr_0","dtype": "list_int64","value": [1,2, 3, 4]}, -{ "name": "attr_1","dtype": "int9999","value": 99}])"_json; - std::string attrs_str = attrs.dump(); - const char *op_type = "TestReluV2"; - const char *cmp_info = ""; - size_t size = 150; - std::string runinfo(size, 'a'); - const char *cmp_info_hash = ""; - uint64_t *elapse = nullptr; - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.tiling = OpTilingStubNew; - op_impl_func.tiling_parse = OpTilingParseStubNew; - op_impl_func.compile_info_creator = CreateCompileInfo; - op_impl_func.compile_info_deleter = DeleteCompileInfo; - op_impl_func.max_tiling_data_size = 50; - registry_holder->AddTypesToImpl(op_type, op_impl_func); - space_registry->AddRegistry(registry_holder); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - EXPECT_EQ(TbeOpTilingPyInterface(op_type, cmp_info, cmp_info_hash, input_str.c_str(), output_str.c_str(), - attrs_str.c_str(), const_cast(runinfo.c_str()), size, elapse), - 0); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(nullptr); -} - -TEST_F(UtestRegister, new_optiling_py_interface_fail_without_params) { - EXPECT_EQ(TbeOpTilingPyInterface(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, nullptr), 0); -} - -TEST_F(UtestRegister, new_optiling_py_interface_ok_with_float_data) { - const nlohmann::json input = R"([ -{"name": "t0", "dtype": "float16","const_value": [1.1,2.1,3.1,4.1] ,"shape": [4,4,4,4], "ori_shape":[4,4,4,4],"format": "ND"}, -{"dtype": "int8", "shape": [4,4,4,4], "ori_shape":[4,4,4,4],"format": "ND"} -])"_json; - std::string input_str = input.dump(); - const nlohmann::json output = R"([ -{"name": "y_0","dtype": "int8","shape": [9,9,9,9],"ori_shape" :[9,9,9,9],"format": "ND","ori_format":"ND"}])"_json; - std::string output_str = output.dump(); - const char *op_type = "TestReluV2"; - const char *cmp_info = ""; - - size_t size = 160; - std::string runinfo(size, 'a'); - - const char *cmp_info_hash = ""; - uint64_t *elapse = nullptr; - const nlohmann::json attrs = R"([ -{ "name": "op_para_size", "dtype": "int", "value": 50}, -{ "name": "group", "dtype": "str", "value": "empty"}])"_json; - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.tiling = OpTilingStubV5; - op_impl_func.tiling_parse = OpTilingParseStubV5; - op_impl_func.compile_info_creator = CreateCompileInfo; - op_impl_func.compile_info_deleter = DeleteCompileInfo; - op_impl_func.max_tiling_data_size = 50; - registry_holder->AddTypesToImpl(op_type, op_impl_func); - space_registry->AddRegistry(registry_holder); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - EXPECT_EQ(TbeOpTilingPyInterface(op_type, cmp_info, cmp_info_hash, input_str.c_str(), output_str.c_str(), - attrs.dump().c_str(), const_cast(runinfo.c_str()), size, elapse), - 1); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(nullptr); -} - -TEST_F(UtestRegister, new_optiling_py_interface_ok_with_sub_format) { - const nlohmann::json input = R"([ -{"name": "t0", "dtype": "float16","const_value": [1.1,2.1,3.1,4.1] ,"shape": [4,4,4,4], "ori_shape":[4,4,4,4],"format": "FRACTAL_Z", "sub_format" :32}, -{"dtype": "int8", "shape": [4,4,4,4], "ori_shape":[4,4,4,4],"format": "ND"} -])"_json; - std::string input_str = input.dump(); - const nlohmann::json output = R"([ -{"name": "y_0","dtype": "int8","shape": [9,9,9,9],"ori_shape" :[9,9,9,9],"format": "ND","ori_format":"ND"}])"_json; - std::string output_str = output.dump(); - const char *op_type = "TestReluV2"; - const char *cmp_info = ""; - std::string runinfo(100, 'a'); - size_t size = 100; - const char *cmp_info_hash = ""; - uint64_t *elapse = nullptr; - const nlohmann::json attrs = R"([ -{ "name": "op_para_size", "dtype": "int", "value": 50}, { "name": "test_name", "dtype": "list_int", "value": [50, 51]}])"_json; - const nlohmann::json extra_infos = R"([ -{ "op_name": "matmul_all_reduce", "rank_size": 1}])"_json; - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.tiling = OpTilingStubV6; - op_impl_func.tiling_parse = OpTilingParseStubV5; - op_impl_func.compile_info_creator = CreateCompileInfo; - op_impl_func.compile_info_deleter = DeleteCompileInfo; - op_impl_func.max_tiling_data_size = 50; - registry_holder->AddTypesToImpl(op_type, op_impl_func); - space_registry->AddRegistry(registry_holder); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - EXPECT_EQ(std::string(DoOpTilingForCompile(op_type, cmp_info, cmp_info_hash, input_str.c_str(), output_str.c_str(), - attrs.dump().c_str(), const_cast(runinfo.c_str()), size, elapse, - extra_infos.dump().c_str())), - "{\"ret_code\":0}"); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(nullptr); -} - -TEST_F(UtestRegister, new_optiling_py_interface_ok_with_extra_info_case2) { - const nlohmann::json input = R"([ -{"name": "t0", "dtype": "float16","const_value": [1.1,2.1,3.1,4.1] ,"shape": [4,4,4,4], "ori_shape":[4,4,4,4],"format": "FRACTAL_Z", "sub_format" :32}, -{"dtype": "int8", "shape": [4,4,4,4], "ori_shape":[4,4,4,4],"format": "ND"} -])"_json; - std::string input_str = input.dump(); - const nlohmann::json output = R"([ -{"name": "y_0","dtype": "int8","shape": [9,9,9,9],"ori_shape" :[9,9,9,9],"format": "ND","ori_format":"ND"}])"_json; - std::string output_str = output.dump(); - const char *op_type = "MatmulAllreduce"; - const char *cmp_info = ""; - std::string runinfo(100, 'a'); - size_t size = 100; - const char *cmp_info_hash = ""; - uint64_t *elapse = nullptr; - const nlohmann::json attrs = R"([ -{ "name": "op_para_size", "dtype": "int", "value": 50}, { "name": "group", "dtype": "str", "value": "g0"}])"_json; - ge::HcomTopoInfo::TopoInfo original; - original.rank_size = 64; - original.notify_handle = reinterpret_cast(0x1234); // 指针不会被序列化 - - // 初始化topo_level_descs - original.topo_level_descs[static_cast(ge::HcomTopoInfo::TopoLevel::L0)] = {8, 16}; - original.topo_level_descs[static_cast(ge::HcomTopoInfo::TopoLevel::L1)] = {4, 32}; - - // 序列化 - nlohmann::json j = original; - nlohmann::json j_wrapped; - j_wrapped["hcom_topo_info"] = j; - std::string extra_infos = j_wrapped.dump(2); - std::cout << "Serialized JSON:\n" << extra_infos << "\n\n"; - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.tiling = OpTilingStubV6; - op_impl_func.tiling_parse = OpTilingParseStubV5; - op_impl_func.compile_info_creator = CreateCompileInfo; - op_impl_func.compile_info_deleter = DeleteCompileInfo; - op_impl_func.max_tiling_data_size = 50; - registry_holder->AddTypesToImpl(op_type, op_impl_func); - space_registry->AddRegistry(registry_holder); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - EXPECT_EQ(std::string(DoOpTilingForCompile(op_type, cmp_info, cmp_info_hash, input_str.c_str(), output_str.c_str(), - attrs.dump().c_str(), const_cast(runinfo.c_str()), size, elapse, - extra_infos.c_str())), - "{\"ret_code\":0}"); - int64_t rank_size_get = -1; - EXPECT_EQ(ge::HcomTopoInfo::Instance().GetGroupRankSize("g0", rank_size_get), ge::GRAPH_SUCCESS); - EXPECT_EQ(original.rank_size, rank_size_get); - auto topo_desc = ge::HcomTopoInfo::Instance().GetGroupTopoDesc("g0"); - EXPECT_NE(topo_desc, nullptr); - EXPECT_EQ(((*topo_desc)[0]).comm_sets, 8); - EXPECT_EQ(((*topo_desc)[0]).rank_size, 16); - EXPECT_EQ(((*topo_desc)[1]).comm_sets, 4); - EXPECT_EQ(((*topo_desc)[1]).rank_size, 32); - void *notify_handle = nullptr; - EXPECT_EQ(ge::HcomTopoInfo::Instance().GetGroupNotifyHandle("g0", notify_handle), ge::GRAPH_SUCCESS); - EXPECT_TRUE(notify_handle == nullptr); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(nullptr); -} - -TEST_F(UtestRegister, new_optiling_py_interface_ok_with_extra_info_case_invalid) { - const nlohmann::json input = R"([ -{"name": "t0", "dtype": "float16","const_value": [1.1,2.1,3.1,4.1] ,"shape": [4,4,4,4], "ori_shape":[4,4,4,4],"format": "FRACTAL_Z", "sub_format" :32}, -{"dtype": "int8", "shape": [4,4,4,4], "ori_shape":[4,4,4,4],"format": "ND"} -])"_json; - std::string input_str = input.dump(); - const nlohmann::json output = R"([ -{"name": "y_0","dtype": "int8","shape": [9,9,9,9],"ori_shape" :[9,9,9,9],"format": "ND","ori_format":"ND"}])"_json; - std::string output_str = output.dump(); - const char *op_type = "MatmulAllreduce"; - const char *cmp_info = ""; - std::string runinfo(100, 'a'); - size_t size = 100; - const char *cmp_info_hash = ""; - uint64_t *elapse = nullptr; - const nlohmann::json attrs = R"([ -{ "name": "op_para_size", "dtype": "int", "value": 50}, { "name": "group", "dtype": "str", "value": "g1"}])"_json; - const nlohmann::json invalid_extra_infos = R"([ -{ - "hcom_topo_info": { - "rank_size": 64, - "topo_level_descs": [ - { - "comm_sets": 8, - "rank_size": 16 - }, - { - "comm_sets": 4, - "rank_size": 32 - }, - { - "comm_sets": 4, - "rank_size": 32 - } - ] - } -}])"_json; - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.tiling = OpTilingStubV6; - op_impl_func.tiling_parse = OpTilingParseStubV5; - op_impl_func.compile_info_creator = CreateCompileInfo; - op_impl_func.compile_info_deleter = DeleteCompileInfo; - op_impl_func.max_tiling_data_size = 50; - registry_holder->AddTypesToImpl(op_type, op_impl_func); - space_registry->AddRegistry(registry_holder); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - EXPECT_EQ(std::string(DoOpTilingForCompile(op_type, cmp_info, cmp_info_hash, input_str.c_str(), output_str.c_str(), - attrs.dump().c_str(), const_cast(runinfo.c_str()), size, elapse, - invalid_extra_infos.dump().c_str())), - "{\"ret_code\":0}"); - int64_t rank_size_get = -1; - EXPECT_NE(ge::HcomTopoInfo::Instance().GetGroupRankSize("g1", rank_size_get), ge::GRAPH_SUCCESS); - EXPECT_EQ(rank_size_get, -1); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(nullptr); -} - -TEST_F(UtestRegister, new_optiling_py_interface_ok_auto_tiling) { - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.tiling = DefaultOptilingStub; - op_impl_func.tiling_parse = OpTilingParseStubV5; - registry_holder->AddTypesToImpl("DefaultImpl", op_impl_func); - space_registry->AddRegistry(registry_holder); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - // expect rt1 tiling not to work - REGISTER_OP_TILING_V2(AutoTiling, op_tiling_stub_failed); - const nlohmann::json input = R"([ -{"name": "test_0","dtype": "int8","shape": [4,4,4,4],"format": "ND"}])"_json; - std::string input_str = input.dump(); - const nlohmann::json output = R"([ -{"name": "y_0","dtype": "int8","shape": [9,9,9,9],"ori_shape" :[9,9,9,9],"format": "ND","ori_format":"ND"}])"_json; - std::string output_str = output.dump(); - const char *op_type = "AutoTiling"; - const char *cmp_info = ""; - size_t size = 160; - std::string runinfo(size, 'a'); - const char *cmp_info_hash = ""; - uint64_t *elapse = nullptr; - const nlohmann::json attrs = R"([ -{ "name": "op_para_size", "dtype": "int", "value": 50}])"_json; - EXPECT_EQ(TbeOpTilingPyInterface(op_type, cmp_info, cmp_info_hash, input_str.c_str(), output_str.c_str(), - attrs.dump().c_str(), const_cast(runinfo.c_str()), size, elapse), - 1); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(nullptr); -} - -TEST_F(UtestRegister, NewOptilingInterface_Ok_WithEmptyTensor) { - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.tiling = DefaultOptilingStub; - op_impl_func.tiling_parse = OpTilingParseStubV5; - registry_holder->AddTypesToImpl("DefaultImpl", op_impl_func); - space_registry->AddRegistry(registry_holder); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - // expect rt1 tiling not to work - REGISTER_OP_TILING_V2(AutoTiling, op_tiling_stub_failed); - const nlohmann::json input = R"([ -{"name": "test_0","dtype": "int8","shape": [0],"format": "ND", "const value": ""}])"_json; - std::string input_str = input.dump(); - const nlohmann::json output = R"([ -{"name": "y_0","dtype": "int8","shape": [9,9,9,9],"ori_shape" :[9,9,9,9],"format": "ND","ori_format":"ND"}])"_json; - std::string output_str = output.dump(); - const char *op_type = "AutoTiling"; - const char *cmp_info = ""; - size_t size = 160; - std::string runinfo(size, 'a'); - const char *cmp_info_hash = ""; - uint64_t *elapse = nullptr; - const nlohmann::json attrs = R"([ -{ "name": "op_para_size", "dtype": "int", "value": 50}])"_json; - EXPECT_EQ(TbeOpTilingPyInterface(op_type, cmp_info, cmp_info_hash, input_str.c_str(), output_str.c_str(), - attrs.dump().c_str(), const_cast(runinfo.c_str()), size, elapse), - 1); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(nullptr); -} - -TEST_F(UtestRegister, NewOptilingInterface_Ok_WithNodeName) { - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.tiling = OpTilingStubNewWithName; - op_impl_func.tiling_parse = OpTilingParseStubV5; - registry_holder->AddTypesToImpl("DefaultImpl", op_impl_func); - space_registry->AddRegistry(registry_holder); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - // expect rt1 tiling not to work - REGISTER_OP_TILING_V2(AutoTiling, op_tiling_stub_failed); - const nlohmann::json input = R"([ -{"name": "test_0","dtype": "int8","shape": [0],"format": "ND", "const value": ""}])"_json; - std::string input_str = input.dump(); - const nlohmann::json output = R"([ -{"name": "y_0","dtype": "int8","shape": [9,9,9,9],"ori_shape" :[9,9,9,9],"format": "ND","ori_format":"ND"}])"_json; - std::string output_str = output.dump(); - const nlohmann::json extra_info = R"({"op_name": "test"})"_json; - std::string extra_info_str = extra_info.dump(); - const char *op_type = "AutoTiling"; - const char *cmp_info = ""; - std::string runinfo(100, 'a'); - size_t size = 100; - const char *cmp_info_hash = ""; - uint64_t *elapse = nullptr; - const nlohmann::json attrs = R"([ -{ "name": "op_para_size", "dtype": "int", "value": 50}])"_json; - EXPECT_EQ(std::string(DoOpTilingForCompile(op_type, cmp_info, cmp_info_hash, input_str.c_str(), output_str.c_str(), - attrs.dump().c_str(), const_cast(runinfo.c_str()), size, elapse, - extra_info_str.c_str())), - "{\"ret_code\":0}"); -// 以下三种是异常场景,因当前ut里面错误码上报功能不完备,所以返回值正常 - { - const nlohmann::json invalid_attrs = R"([ -{ "name": "op_para_size", "value": 50}])"_json; - EXPECT_EQ(std::string(DoOpTilingForCompile(op_type, - cmp_info, - cmp_info_hash, - input_str.c_str(), - output_str.c_str(), - invalid_attrs.dump().c_str(), - const_cast(runinfo.c_str()), - size, - elapse, - extra_info_str.c_str())), - "{\"ret_code\":0}"); - EXPECT_EQ(std::string(DoOpTilingForCompile(nullptr, - cmp_info, - cmp_info_hash, - input_str.c_str(), - output_str.c_str(), - invalid_attrs.dump().c_str(), - const_cast(runinfo.c_str()), - size, - elapse, - extra_info_str.c_str())), - "{\"ret_code\":0}"); - EXPECT_EQ(std::string(DoOpTilingForCompile("TestReluV2", - cmp_info, - cmp_info_hash, - input_str.c_str(), - output_str.c_str(), - invalid_attrs.dump().c_str(), - const_cast(runinfo.c_str()), - size, - elapse, - extra_info_str.c_str())), - "{\"ret_code\":0}"); - // 模拟内部有错误的时候, 校验返回的json对象 - StubError(); - std::string ret_when_error = DoOpTilingForCompile("TestReluV2", - cmp_info, - cmp_info_hash, - input_str.c_str(), - output_str.c_str(), - invalid_attrs.dump().c_str(), - const_cast(runinfo.c_str()), - size, - elapse, - extra_info_str.c_str()); - EXPECT_TRUE(CheckErrorRetFormat(ret_when_error)); - ReInitErrorManager(); - } - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(nullptr); -} - -TEST_F(UtestRegister, NewOptilingInterface_Ok_WithDynamicInput) { - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.tiling = OpTilingStubNewWithDynamicInput; - op_impl_func.tiling_parse = OpTilingParseStubV5; - registry_holder->AddTypesToImpl("DefaultImpl", op_impl_func); - space_registry->AddRegistry(registry_holder); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - const nlohmann::json input = R"([ - [{ - "shape": [4, 16, 200, 336, 16], - "ori_shape": [4, 256, 200, 336], - "format": "NC1HWC0", - "sub_format": 0, - "ori_format": "NCHW", - "dtype": "float16", - "addr_type": 0, - "total_shape": [4, 16, 200, 336, 16], - "slice_offset": [], - "L1_addr_offset": 0, - "L1_fusion_type": -1, - "L1_workspace_size": -1, - "valid_shape": [], - "split_index": 0, - "atomic_type": "", - "input_c_values": 256, - "range": [ [4, 4], [16, 16], [200, 200], [336, 336], [16, 16] ], - "param_name": "feats", - "name": "feats_gm_0" - }, - { - "shape": [4, 16, 100, 168, 16], - "ori_shape": [4, 256, 100, 168], - "format": "NC1HWC0", - "sub_format": 0, - "ori_format": "NCHW", - "dtype": "float16", - "addr_type": 0, - "total_shape": [4, 16, 100, 168, 16], - "slice_offset": [], - "L1_addr_offset": 0, - "L1_fusion_type": -1, - "L1_workspace_size": -1, - "valid_shape": [], - "split_index": 0, - "atomic_type": "", - "input_c_values": 256, - "range": [ [4, 4], [16, 16], [100, 100], [168, 168], [16, 16] ], - "param_name": "feats", - "name": "feats_gm_1" - }, - { - "shape": [4, 16, 50, 84, 16], - "ori_shape": [4, 256, 50, 84], - "format": "NC1HWC0", - "sub_format": 0, - "ori_format": "NCHW", - "dtype": "float16", - "addr_type": 0, - "total_shape": [4, 16, 50, 84, 16], - "slice_offset": [], - "L1_addr_offset": 0, - "L1_fusion_type": -1, - "L1_workspace_size": -1, - "valid_shape": [], - "split_index": 0, - "atomic_type": "", - "input_c_values": 256, - "range": [ [4, 4], [16, 16], [50, 50], [84, 84], [16, 16] ], - "param_name": "feats" - }, - { - "shape": [4, 16, 25, 42, 16], - "ori_shape": [4, 256, 25, 42], - "format": "NC1HWC0", - "sub_format": 0, - "ori_format": "NCHW", - "dtype": "float16", - "addr_type": 0, - "total_shape": [4, 16, 25, 42, 16], - "slice_offset": [], - "L1_addr_offset": 0, - "L1_fusion_type": -1, - "L1_workspace_size": -1, - "valid_shape": [], - "split_index": 0, - "atomic_type": "", - "input_c_values": 256, - "range": [ [4, 4], [16, 16], [25, 25], [42, 42], [16, 16] ], - "param_name": "feats" - } - ], - null, - null, - { - "shape": [100, 5], - "ori_shape": [100, 5], - "format": "NCHW", - "sub_format": 0, - "ori_format": "NCHW", - "dtype": "float16", - "addr_type": 0, - "total_shape": [100, 5], - "slice_offset": [], - "L1_addr_offset": 0, - "L1_fusion_type": -1, - "L1_workspace_size": -1, - "valid_shape": [], - "split_index": 0, - "atomic_type": "", - "input_c_values": 5, - "range": [ [100, 100], [5, 5] ], - "param_name": "rois" - }])"_json; - std::string input_str = input.dump(); - const nlohmann::json output = R"([ -{"name": "y_0","dtype": "int8","shape": [9,9,9,9],"ori_shape" :[9,9,9,9],"format": "ND","ori_format":"ND"}])"_json; - std::string output_str = output.dump(); - const nlohmann::json extra_info = R"({"op_name": "test", "rank_size": 1})"_json; - std::string extra_info_str = extra_info.dump(); - const char *op_type = "AutoTiling"; - const char *cmp_info = ""; - std::string runinfo(100, 'a'); - size_t size = 100; - const char *cmp_info_hash = ""; - uint64_t *elapse = nullptr; - const nlohmann::json attrs = R"([ -{ "name": "op_para_size", "dtype": "int", "value": 50}])"_json; - EXPECT_EQ(std::string(DoOpTilingForCompile(op_type, cmp_info, cmp_info_hash, input_str.c_str(), output_str.c_str(), - attrs.dump().c_str(), const_cast(runinfo.c_str()), size, elapse, - extra_info_str.c_str())), - "{\"ret_code\":0}"); -} - -extern "C" int AscendCPyInterfaceCheckOp(const char *check_type, const char *optype, const char *inputs, - const char *outputs, const char *attrs, char *result_info, - size_t result_info_len); - -extern "C" int AscendCPyInterfaceGeneralized(const char *optype, const char *inputs, const char *outputs, - const char *attrs, const char *generalize_config, char *result_info, - size_t result_info_len); - -extern "C" int AscendCPyInterfaceGetTilingDefInfo(const char *optype, char *result_info, size_t result_info_len); - -ge::graphStatus check_supported_stub(const ge::Operator &op, ge::AscendString &result) { - const nlohmann::json res_json = R"( -{"ret_code": "1","reason": "check_supported_stub"})"_json; - std::string res_json_str = res_json.dump(); - result = AscendString(res_json_str.c_str()); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus op_select_format_stub(const ge::Operator &op, ge::AscendString &result) { - const nlohmann::json res_json = R"({"op_info": "op_select_format_stub"})"_json; - std::string res_json_str = res_json.dump(); - result = AscendString(res_json_str.c_str()); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus get_op_support_info_stub(const ge::Operator &op, ge::AscendString &result) { - const nlohmann::json res_json = R"({"op_info": "get_op_support_info_stub"})"_json; - std::string res_json_str = res_json.dump(); - result = AscendString(res_json_str.c_str()); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus get_op_specific_info_stub(const ge::Operator &op, ge::AscendString &result) { - const nlohmann::json res_json = R"({"op_info": "get_op_specific_info_stub"})"_json; - std::string res_json_str = res_json.dump(); - result = AscendString(res_json_str.c_str()); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus check_supported_stub_throw(const ge::Operator &op, ge::AscendString &result) { - throw "bad callback"; - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus check_supported_stub_fail(const ge::Operator &op, ge::AscendString &result) { - return ge::GRAPH_FAILED; -} - -TEST_F(UtestRegister, ascendC_py_interface_check_cap_ok) { - setenv("ENABLE_RUNTIME_V2", "1", 0); - const nlohmann::json input = R"([ -{"name": "test_0","dtype": "int8", "const_value": [1,2,3,4],"shape": [4,4,4,4],"format": "ND"}, -{"name": "test_1","dtype": "int32","shape": [5,5,5,5],"ori_shape": [5,5,5,5],"format": "ND","ori_format": "ND"}, -{"name": "test_2","dtype": "int32","shape": [6,6,6,6],"ori_shape": [6,6,6,6],"format": "ND","ori_format": "ND"}])"_json; - std::string input_str = input.dump(); - const nlohmann::json output = R"([ -{"name": "y_0","dtype": "int8","shape": [9,9,9,9],"ori_shape" :[9,9,9,9],"format": "ND","ori_format":"ND"}])"_json; - - std::string output_str = output.dump(); - const nlohmann::json attrs = R"([ -{ "name": "attr_0","dtype": "list_int64","value": [1,2, 3, 4]}, -{ "name": "attr_1","dtype": "int","value": 99}, -{ "name": "attr_2","dtype": "list_int32","value": [1, 2, 3, 4]}, -{ "name": "op_para_size", "dtype": "int", "value": 50}])"_json; - std::string attrs_str = attrs.dump(); - std::string op_type = "ascendC_py_interface_check_cap_ok"; - std::string res_info(100, 'a'); - size_t size = 100; - // check_supported - REG_CHECK_SUPPORT(ascendC_py_interface_check_cap_ok, check_supported_stub); - EXPECT_EQ(AscendCPyInterfaceCheckOp(FUNC_CHECK_SUPPORTED, op_type.c_str(), input_str.c_str(), output_str.c_str(), - attrs_str.c_str(), const_cast(res_info.c_str()), size), - 1); - std::string check_supported_result = "{\"reason\":\"check_supported_stub\",\"ret_code\":\"1\"}"; - EXPECT_EQ(check_supported_result, res_info.substr(0, check_supported_result.size())); - - // op_select_format - REG_OP_SELECT_FORMAT(ascendC_py_interface_check_cap_ok, op_select_format_stub); - EXPECT_EQ(AscendCPyInterfaceCheckOp(FUNC_OP_SELECT_FORMAT, op_type.c_str(), input_str.c_str(), output_str.c_str(), - attrs_str.c_str(), const_cast(res_info.c_str()), size), - 1); - std::string op_select_format_result = "{\"op_info\":\"op_select_format_stub\"}"; - EXPECT_EQ(op_select_format_result, res_info.substr(0, op_select_format_result.size())); - - // get_op_support_info - REG_OP_SUPPORT_INFO(ascendC_py_interface_check_cap_ok, get_op_support_info_stub); - EXPECT_EQ(AscendCPyInterfaceCheckOp(FUNC_GET_OP_SUPPORT_INFO, op_type.c_str(), input_str.c_str(), output_str.c_str(), - attrs_str.c_str(), const_cast(res_info.c_str()), size), - 1); - std::string get_op_support_info_result = "{\"op_info\":\"get_op_support_info_stub\"}"; - EXPECT_EQ(get_op_support_info_result, res_info.substr(0, get_op_support_info_result.size())); - - // get_op_specific_info - REG_OP_SPEC_INFO(ascendC_py_interface_check_cap_ok, get_op_specific_info_stub); - EXPECT_EQ(AscendCPyInterfaceCheckOp(FUNC_GET_SPECIFIC_INFO, op_type.c_str(), input_str.c_str(), output_str.c_str(), - attrs_str.c_str(), const_cast(res_info.c_str()), size), - 1); - std::string get_op_specific_info_result = "{\"op_info\":\"get_op_specific_info_stub\"}"; - EXPECT_EQ(get_op_specific_info_result, res_info.substr(0, get_op_specific_info_result.size())); - unsetenv("ENABLE_RUNTIME_V2"); -} - -TEST_F(UtestRegister, ascendC_py_interface_check_cap_fail_without_callback) { - setenv("ENABLE_RUNTIME_V2", "1", 0); - const nlohmann::json input = R"([ -{"name": "test_0","dtype": "int8", "const_value": [1,2,3,4],"shape": [4,4,4,4],"format": "ND"}, -{"name": "test_1","dtype": "int32","shape": [5,5,5,5],"ori_shape": [5,5,5,5],"format": "ND","ori_format": "ND"}, -{"name": "test_2","dtype": "int32","shape": [6,6,6,6],"ori_shape": [6,6,6,6],"format": "ND","ori_format": "ND"}])"_json; - std::string input_str = input.dump(); - const nlohmann::json output = R"([ -{"name": "y_0","dtype": "int8","shape": [9,9,9,9],"ori_shape" :[9,9,9,9],"format": "ND","ori_format":"ND"}])"_json; - - std::string output_str = output.dump(); - const nlohmann::json attrs = R"([ -{ "name": "attr_0","dtype": "list_int64","value": [1,2, 3, 4]}, -{ "name": "attr_1","dtype": "int","value": 99}, -{ "name": "attr_2","dtype": "list_int32","value": [1, 2, 3, 4]}, -{ "name": "op_para_size", "dtype": "int", "value": 50}])"_json; - std::string attrs_str = attrs.dump(); - std::string op_type = "ascendC_py_interface_check_cap_fail_without_callback"; - std::string res_info(100, 'a'); - size_t size = 100; - // check_supported - EXPECT_EQ(AscendCPyInterfaceCheckOp(FUNC_CHECK_SUPPORTED, op_type.c_str(), input_str.c_str(), output_str.c_str(), - attrs_str.c_str(), const_cast(res_info.c_str()), size), - 0); - - // op_select_format - EXPECT_EQ(AscendCPyInterfaceCheckOp(FUNC_OP_SELECT_FORMAT, op_type.c_str(), input_str.c_str(), output_str.c_str(), - attrs_str.c_str(), const_cast(res_info.c_str()), size), - 0); - - // get_op_support_info - EXPECT_EQ(AscendCPyInterfaceCheckOp(FUNC_GET_OP_SUPPORT_INFO, op_type.c_str(), input_str.c_str(), output_str.c_str(), - attrs_str.c_str(), const_cast(res_info.c_str()), size), - 0); - - // get_op_specific_info - EXPECT_EQ(AscendCPyInterfaceCheckOp(FUNC_GET_SPECIFIC_INFO, op_type.c_str(), input_str.c_str(), output_str.c_str(), - attrs_str.c_str(), const_cast(res_info.c_str()), size), - 0); - unsetenv("ENABLE_RUNTIME_V2"); -} - -TEST_F(UtestRegister, ascendC_py_interface_check_cap_fail_throw) { - setenv("ENABLE_RUNTIME_V2", "1", 0); - const nlohmann::json input = R"([ -{"name": "test_0","dtype": "int8", "const_value": [1,2,3,4],"shape": [4,4,4,4],"format": "ND"}, -{"name": "test_1","dtype": "int32","shape": [5,5,5,5],"ori_shape": [5,5,5,5],"format": "ND","ori_format": "ND"}, -{"name": "test_2","dtype": "int32","shape": [6,6,6,6],"ori_shape": [6,6,6,6],"format": "ND","ori_format": "ND"}])"_json; - std::string input_str = input.dump(); - const nlohmann::json output = R"([ -{"name": "y_0","dtype": "int8","shape": [9,9,9,9],"ori_shape" :[9,9,9,9],"format": "ND","ori_format":"ND"}])"_json; - - std::string output_str = output.dump(); - const nlohmann::json attrs = R"([ -{ "name": "attr_0","dtype": "list_int64","value": [1,2, 3, 4]}, -{ "name": "attr_1","dtype": "int","value": 99}, -{ "name": "attr_2","dtype": "list_int32","value": [1, 2, 3, 4]}, -{ "name": "op_para_size", "dtype": "int", "value": 50}])"_json; - std::string attrs_str = attrs.dump(); - std::string op_type = "ascendC_py_interface_check_cap_fail_throw"; - std::string res_info(100, 'a'); - size_t size = 100; - // check_supported - REG_CHECK_SUPPORT(ascendC_py_interface_check_cap_fail_throw, check_supported_stub_throw); - EXPECT_EQ(AscendCPyInterfaceCheckOp(FUNC_CHECK_SUPPORTED, op_type.c_str(), input_str.c_str(), output_str.c_str(), - attrs_str.c_str(), const_cast(res_info.c_str()), size), - 0); - unsetenv("ENABLE_RUNTIME_V2"); -} - -TEST_F(UtestRegister, ascendC_py_interface_check_cap_fail_by_callback) { - setenv("ENABLE_RUNTIME_V2", "1", 0); - const nlohmann::json input = R"([ -{"name": "test_0","dtype": "int8", "const_value": [1,2,3,4],"shape": [4,4,4,4],"format": "ND"}, -{"name": "test_1","dtype": "int32","shape": [5,5,5,5],"ori_shape": [5,5,5,5],"format": "ND","ori_format": "ND"}, -{"name": "test_2","dtype": "int32","shape": [6,6,6,6],"ori_shape": [6,6,6,6],"format": "ND","ori_format": "ND"}])"_json; - std::string input_str = input.dump(); - const nlohmann::json output = R"([ -{"name": "y_0","dtype": "int8","shape": [9,9,9,9],"ori_shape" :[9,9,9,9],"format": "ND","ori_format":"ND"}])"_json; - - std::string output_str = output.dump(); - const nlohmann::json attrs = R"([ -{ "name": "attr_0","dtype": "list_int64","value": [1,2, 3, 4]}, -{ "name": "attr_1","dtype": "int","value": 99}, -{ "name": "attr_2","dtype": "list_int32","value": [1, 2, 3, 4]}, -{ "name": "op_para_size", "dtype": "int", "value": 50}])"_json; - std::string attrs_str = attrs.dump(); - std::string op_type = "ascendC_py_interface_check_cap_fail_throw"; - std::string res_info(100, 'a'); - size_t size = 100; - // check_supported - REG_CHECK_SUPPORT(ascendC_py_interface_check_cap_fail_throw, check_supported_stub_fail); - EXPECT_EQ(AscendCPyInterfaceCheckOp(FUNC_CHECK_SUPPORTED, op_type.c_str(), input_str.c_str(), output_str.c_str(), - attrs_str.c_str(), const_cast(res_info.c_str()), size), - 0); - unsetenv("ENABLE_RUNTIME_V2"); -} - -TEST_F(UtestRegister, ascendC_py_interface_check_cap_fail_without_params) { - setenv("ENABLE_RUNTIME_V2", "1", 0); - EXPECT_EQ(AscendCPyInterfaceCheckOp(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0), 0); - unsetenv("ENABLE_RUNTIME_V2"); -} - -ge::graphStatus generalize_stub(const ge::Operator &op, const ge::AscendString &generalize_config, ge::AscendString &result) { - const nlohmann::json res_json = R"({"op_info": "generalize_stub"})"_json; - std::string res_json_str = res_json.dump(); - result = AscendString(res_json_str.c_str()); - return ge::GRAPH_SUCCESS; -} - -ge::graphStatus generalize_stub_fail(const ge::Operator &op, const ge::AscendString &generalize_config, ge::AscendString &result) { - return ge::GRAPH_FAILED; -} - -ge::graphStatus generalize_stub_throw(const ge::Operator &op, const ge::AscendString &generalize_config, ge::AscendString &result) { - const nlohmann::json res_json = R"({"op_info": "generalize_stub"})"_json; - std::string res_json_str = res_json.dump(); - result = AscendString(res_json_str.c_str()); - throw "bad callback"; - return ge::GRAPH_SUCCESS; -} - -TEST_F(UtestRegister, ascendC_py_interface_generalize_ok) { - setenv("ENABLE_RUNTIME_V2", "1", 0); - const nlohmann::json input = R"([ -{"name": "test_0","dtype": "float16", "const_value": [1,2,3,4],"shape": [4,4,4,4],"format": "ND"}, -{"name": "test_1","dtype": "float32","shape": [5,5,5,5],"ori_shape": [5,5,5,5],"format": "ND","ori_format": "ND"}, -{"name": "test_2","dtype": "int64","shape": [6,6,6,6],"ori_shape": [6,6,6,6],"format": "ND","ori_format": "ND"}])"_json; - std::string input_str = input.dump(); - const nlohmann::json output = R"([ -{"name": "y_0","dtype": "uint32","shape": [9,9,9,9],"ori_shape" :[9,9,9,9],"format": "ND","ori_format":"ND"}])"_json; - - std::string output_str = output.dump(); - const nlohmann::json attrs = R"([ -{ "name": "attr_0","dtype": "list_list_int64","value": [[1, 2], [3, 4]]}, -{ "name": "attr_1","dtype": "uint32","value": 99}, -{ "name": "attr_2","dtype": "list_list_int32","value": [[1, 2], [3, 4]]}, -{ "name": "op_para_size", "dtype": "uint16", "value": 50}])"_json; - std::string attrs_str = attrs.dump(); - std::string op_type = "ascendC_py_interface_generalize_ok"; - std::string generalize_config = "keep_rank"; - std::string res_info(100, 'a'); - size_t size = 100; - // shape generalize - REG_OP_PARAM_GENERALIZE(ascendC_py_interface_generalize_ok, generalize_stub); - EXPECT_EQ(AscendCPyInterfaceGeneralized(op_type.c_str(), input_str.c_str(), output_str.c_str(), attrs_str.c_str(), - generalize_config.c_str(), const_cast(res_info.c_str()), size), - 1); - std::string result = "{\"op_info\":\"generalize_stub\"}"; - EXPECT_EQ(result, res_info.substr(0, result.size())); - - unsetenv("ENABLE_RUNTIME_V2"); -} - -TEST_F(UtestRegister, ascendC_py_interface_generalize_fail_by_callback) { - setenv("ENABLE_RUNTIME_V2", "1", 0); - const nlohmann::json input = R"([ -{"name": "test_0","dtype": "float16", "const_value": [1,2,3,4],"shape": [4,4,4,4],"format": "ND"}, -{"name": "test_1","dtype": "float32","shape": [5,5,5,5],"ori_shape": [5,5,5,5],"format": "ND","ori_format": "ND"}, -{"name": "test_2","dtype": "int64","shape": [6,6,6,6],"ori_shape": [6,6,6,6],"format": "ND","ori_format": "ND"}])"_json; - std::string input_str = input.dump(); - const nlohmann::json output = R"([ -{"name": "y_0","dtype": "uint32","shape": [9,9,9,9],"ori_shape" :[9,9,9,9],"format": "ND","ori_format":"ND"}])"_json; - - std::string output_str = output.dump(); - const nlohmann::json attrs = R"([ -{ "name": "attr_0","dtype": "list_list_int64","value": [[1, 2], [3, 4]]}, -{ "name": "attr_1","dtype": "uint32","value": 99}, -{ "name": "attr_2","dtype": "list_list_int32","value": [[1, 2], [3, 4]]}, -{ "name": "op_para_size", "dtype": "uint16", "value": 50}])"_json; - std::string attrs_str = attrs.dump(); - std::string op_type = "ascendC_py_interface_generalize_ok"; - std::string generalize_config = "keep_rank"; - std::string res_info(100, 'a'); - size_t size = 100; - // shape generalize - REG_OP_PARAM_GENERALIZE(ascendC_py_interface_generalize_ok, generalize_stub_fail); - EXPECT_EQ(AscendCPyInterfaceGeneralized(op_type.c_str(), input_str.c_str(), output_str.c_str(), attrs_str.c_str(), - generalize_config.c_str(), const_cast(res_info.c_str()), size), - 0); - - unsetenv("ENABLE_RUNTIME_V2"); -} - - -TEST_F(UtestRegister, ascendC_py_interface_generalize_fail_without_callback) { - setenv("ENABLE_RUNTIME_V2", "1", 0); - const nlohmann::json input = R"([ -{"name": "test_0","dtype": "int8", "const_value": [1,2,3,4],"shape": [4,4,4,4],"format": "ND"}, -{"name": "test_1","dtype": "int32","shape": [5,5,5,5],"ori_shape": [5,5,5,5],"format": "ND","ori_format": "ND"}, -{"name": "test_2","dtype": "int32","shape": [6,6,6,6],"ori_shape": [6,6,6,6],"format": "ND","ori_format": "ND"}])"_json; - std::string input_str = input.dump(); - const nlohmann::json output = R"([ -{"name": "y_0","dtype": "int8","shape": [9,9,9,9],"ori_shape" :[9,9,9,9],"format": "ND","ori_format":"ND"}])"_json; - - std::string output_str = output.dump(); - const nlohmann::json attrs = R"([ -{ "name": "attr_0","dtype": "list_int64","value": [1,2, 3, 4]}, -{ "name": "attr_1","dtype": "int","value": 99}, -{ "name": "attr_2","dtype": "list_int32","value": [1, 2, 3, 4]}, -{ "name": "op_para_size", "dtype": "int", "value": 50}])"_json; - std::string attrs_str = attrs.dump(); - std::string op_type = "TestReluV2"; - std::string generalize_config = "keep_rank"; - std::string res_info(100, 'a'); - size_t size = 100; - // shape generalize - EXPECT_EQ(AscendCPyInterfaceGeneralized(op_type.c_str(), input_str.c_str(), output_str.c_str(), attrs_str.c_str(), - generalize_config.c_str(), const_cast(res_info.c_str()), size), - 0); - - unsetenv("ENABLE_RUNTIME_V2"); -} - -TEST_F(UtestRegister, ascendC_py_interface_generalize_fail_throw) { - setenv("ENABLE_RUNTIME_V2", "1", 0); - const nlohmann::json input = R"([ -{"name": "test_0","dtype": "float16", "const_value": [1,2,3,4],"shape": [4,4,4,4],"format": "ND"}, -{"name": "test_1","dtype": "float32","shape": [5,5,5,5],"ori_shape": [5,5,5,5],"format": "ND","ori_format": "ND"}, -{"name": "test_2","dtype": "int64","shape": [6,6,6,6],"ori_shape": [6,6,6,6],"format": "ND","ori_format": "ND"}])"_json; - std::string input_str = input.dump(); - const nlohmann::json output = R"([ -{"name": "y_0","dtype": "uint32","shape": [9,9,9,9],"ori_shape" :[9,9,9,9],"format": "ND","ori_format":"ND"}])"_json; - - std::string output_str = output.dump(); - const nlohmann::json attrs = R"([ -{ "name": "attr_0","dtype": "list_list_int64","value": [[1, 2], [3, 4]]}, -{ "name": "attr_1","dtype": "uint32","value": 99}, -{ "name": "attr_2","dtype": "list_list_int32","value": [[1, 2], [3, 4]]}, -{ "name": "op_para_size", "dtype": "uint16", "value": 50}])"_json; - std::string attrs_str = attrs.dump(); - std::string op_type = "ascendC_py_interface_generalize_fail_throw"; - std::string generalize_config = "keep_rank"; - std::string res_info(100, 'a'); - size_t size = 100; - // shape generalize - REG_OP_PARAM_GENERALIZE(ascendC_py_interface_generalize_fail_throw, generalize_stub_throw); - EXPECT_EQ(AscendCPyInterfaceGeneralized(op_type.c_str(), input_str.c_str(), output_str.c_str(), attrs_str.c_str(), - generalize_config.c_str(), const_cast(res_info.c_str()), size), - 0); - - unsetenv("ENABLE_RUNTIME_V2"); -} - -TEST_F(UtestRegister, ascendC_py_interface_generalize_fail_without_params) { - setenv("ENABLE_RUNTIME_V2", "1", 0); - EXPECT_EQ(AscendCPyInterfaceGeneralized(nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0), 0); - unsetenv("ENABLE_RUNTIME_V2"); -} - -BEGIN_TILING_DATA_DEF(TestMaxPoolTilingData) -// format: TILING_DATA_FIELD_DEF(data_type, field_name); -TILING_DATA_FIELD_DEF(int8_t, dim_0); -TILING_DATA_FIELD_DEF(int16_t, dim_1); -TILING_DATA_FIELD_DEF(int32_t, dim_2); -TILING_DATA_FIELD_DEF(int64_t, dim_3); -TILING_DATA_FIELD_DEF(uint8_t, dim_4); -TILING_DATA_FIELD_DEF(uint16_t, dim_5); -TILING_DATA_FIELD_DEF(uint32_t, dim_6); -TILING_DATA_FIELD_DEF(uint64_t, dim_7); -TILING_DATA_FIELD_DEF(int32_t, act_core_num); -END_TILING_DATA_DEF - -// register class -REGISTER_TILING_DATA_CLASS(TestMaxPool, TestMaxPoolTilingData) - -BEGIN_TILING_DATA_DEF(TestMaxPoolTilingDataStruct) -// format: TILING_DATA_FIELD_DEF(data_type, field_name); -TILING_DATA_FIELD_DEF_ARR(int8_t, 8, dim_0); -TILING_DATA_FIELD_DEF_STRUCT(TestMaxPoolTilingData, dim_1); -END_TILING_DATA_DEF - -// register class -REGISTER_TILING_DATA_CLASS(TestMaxPoolStruct, TestMaxPoolTilingDataStruct) - -TEST_F(UtestRegister, ascendC_py_interface_get_tiling_def_ok) { - setenv("ENABLE_RUNTIME_V2", "1", 0); - std::string op_type = "TestMaxPool"; - std::string res_info(1024, 'a'); - size_t size = 1024; - EXPECT_EQ(AscendCPyInterfaceGetTilingDefInfo(op_type.c_str(), const_cast(res_info.c_str()), size), 1); - const nlohmann::json result = - R"({"class_name":"TestMaxPoolTilingData","data_size":40,"fields":[{"classType":"0","dtype":"int8_t","name":"dim_0"},{"arrSize":1,"classType":"1","dtype":"uint8_t","name":"dim_1PH"},{"classType":"0","dtype":"int16_t","name":"dim_1"},{"classType":"0","dtype":"int32_t","name":"dim_2"},{"classType":"0","dtype":"int64_t","name":"dim_3"},{"classType":"0","dtype":"uint8_t","name":"dim_4"},{"arrSize":1,"classType":"1","dtype":"uint8_t","name":"dim_5PH"},{"classType":"0","dtype":"uint16_t","name":"dim_5"},{"classType":"0","dtype":"uint32_t","name":"dim_6"},{"classType":"0","dtype":"uint64_t","name":"dim_7"},{"classType":"0","dtype":"int32_t","name":"act_core_num"},{"arrSize":4,"classType":"1","dtype":"uint8_t","name":"TestMaxPoolTilingDataPH"}]})"_json; - std::string result_str = result.dump(); - EXPECT_EQ(result_str, res_info.substr(0, result_str.size())); - op_type = "TestMaxPoolStruct"; - EXPECT_EQ(AscendCPyInterfaceGetTilingDefInfo(op_type.c_str(), const_cast(res_info.c_str()), size), 1); - unsetenv("ENABLE_RUNTIME_V2"); -} - - -namespace test1 { -BEGIN_TILING_DATA_DEF(TestMaxPoolTilingStruct) -TILING_DATA_FIELD_DEF_ARR(int8_t, 5, dim_0); -TILING_DATA_FIELD_DEF_STRUCT(TestMaxPoolTilingData, dim_1); -END_TILING_DATA_DEF -} - -namespace test2 { -BEGIN_TILING_DATA_DEF(TestMaxPoolTilingStruct) -TILING_DATA_FIELD_DEF_ARR(int8_t, 5, dim_1); -TILING_DATA_FIELD_DEF_STRUCT(TestMaxPoolTilingData, dim_2); -END_TILING_DATA_DEF -} //name - -namespace test3 { -BEGIN_TILING_DATA_DEF(TestMaxPoolTilingStruct) -TILING_DATA_FIELD_DEF(uint64_t, dim_1); -TILING_DATA_FIELD_DEF_STRUCT(TestMaxPoolTilingData, dim_2); -END_TILING_DATA_DEF -} //infosize - -namespace test4 { -BEGIN_TILING_DATA_DEF(TestMaxPoolTilingStruct) -TILING_DATA_FIELD_DEF_ARR(int8_t, 4, dim_0); -TILING_DATA_FIELD_DEF_STRUCT(TestMaxPoolTilingData, dim_1); -END_TILING_DATA_DEF -} //arrsize -namespace test5 { -BEGIN_TILING_DATA_DEF(TestMaxPoolTilingStruct) -TILING_DATA_FIELD_DEF_ARR(int8_t, 50, dim_0); -END_TILING_DATA_DEF -} //datasize - -std::shared_ptr Test_api1() { - return std::make_shared(); -} - -std::shared_ptr Test_api2() { - return std::make_shared(); -} - -std::shared_ptr Test_api3() { - return std::make_shared(); -} - -std::shared_ptr Test_api4() { - return std::make_shared(); -} - -std::shared_ptr Test_api5() { - return std::make_shared(); -} - -TEST_F(UtestRegister, test_register_tiling_data) { - setenv("ENABLE_RUNTIME_V2", "1", 0); - std::string op_type = "Test_MaxPool"; - std::string res_info(1024, 'a'); - size_t size = 1024; - CTilingDataClassFactory::GetInstance().RegisterTilingData("Test_MaxPool", Test_api1); - EXPECT_EQ(AscendCPyInterfaceGetTilingDefInfo(op_type.c_str(), const_cast(res_info.c_str()), size), 1); - const nlohmann::json result1 = - R"({"class_name":"TestMaxPoolTilingStruct","data_size":48,"fields":[{"arrSize":5,"classType":"1","dtype":"int8_t","name":"dim_0"},{"arrSize":3,"classType":"1","dtype":"uint8_t","name":"dim_1PH"},{"classType":"2","dtype":"struct","name":"dim_1","structSize":40,"structType":"TestMaxPoolTilingData"}]})"_json; - std::string result_str1 = result1.dump(); - EXPECT_EQ(result_str1, res_info.substr(0, result_str1.size())); - - CTilingDataClassFactory::GetInstance().RegisterTilingData("Test_MaxPool", Test_api2); - EXPECT_EQ(AscendCPyInterfaceGetTilingDefInfo(op_type.c_str(), const_cast(res_info.c_str()), size), 1); - const nlohmann::json result2 = - R"({"class_name":"TestMaxPoolTilingStruct","data_size":48,"fields":[{"arrSize":5,"classType":"1","dtype":"int8_t","name":"dim_0"},{"arrSize":3,"classType":"1","dtype":"uint8_t","name":"dim_1PH"},{"classType":"2","dtype":"struct","name":"dim_1","structSize":40,"structType":"TestMaxPoolTilingData"}]})"_json; - std::string result_str2 = result2.dump(); - EXPECT_EQ(result_str2, res_info.substr(0, result_str2.size())); - - CTilingDataClassFactory::GetInstance().RegisterTilingData("Test_MaxPool", Test_api3); - EXPECT_EQ(AscendCPyInterfaceGetTilingDefInfo(op_type.c_str(), const_cast(res_info.c_str()), size), 1); - const nlohmann::json result3 = - R"({"class_name":"TestMaxPoolTilingStruct","data_size":48,"fields":[{"arrSize":5,"classType":"1","dtype":"int8_t","name":"dim_0"},{"arrSize":3,"classType":"1","dtype":"uint8_t","name":"dim_1PH"},{"classType":"2","dtype":"struct","name":"dim_1","structSize":40,"structType":"TestMaxPoolTilingData"}]})"_json; - std::string result_str3 = result3.dump(); - EXPECT_EQ(result_str3, res_info.substr(0, result_str3.size())); - - CTilingDataClassFactory::GetInstance().RegisterTilingData("Test_MaxPool", Test_api4); - EXPECT_EQ(AscendCPyInterfaceGetTilingDefInfo(op_type.c_str(), const_cast(res_info.c_str()), size), 1); - const nlohmann::json result4 = - R"({"class_name":"TestMaxPoolTilingStruct","data_size":48,"fields":[{"arrSize":5,"classType":"1","dtype":"int8_t","name":"dim_0"},{"arrSize":3,"classType":"1","dtype":"uint8_t","name":"dim_1PH"},{"classType":"2","dtype":"struct","name":"dim_1","structSize":40,"structType":"TestMaxPoolTilingData"}]})"_json; - std::string result_str4 = result4.dump(); - EXPECT_EQ(result_str4, res_info.substr(0, result_str4.size())); - - CTilingDataClassFactory::GetInstance().RegisterTilingData("Test_MaxPool", Test_api5); - EXPECT_EQ(AscendCPyInterfaceGetTilingDefInfo(op_type.c_str(), const_cast(res_info.c_str()), size), 1); - const nlohmann::json result5 = - R"({"class_name":"TestMaxPoolTilingStruct","data_size":48,"fields":[{"arrSize":5,"classType":"1","dtype":"int8_t","name":"dim_0"},{"arrSize":3,"classType":"1","dtype":"uint8_t","name":"dim_1PH"},{"classType":"2","dtype":"struct","name":"dim_1","structSize":40,"structType":"TestMaxPoolTilingData"}]})"_json; - std::string result_str5 = result5.dump(); - EXPECT_EQ(result_str5, res_info.substr(0, result_str5.size())); - unsetenv("ENABLE_RUNTIME_V2"); -} - -TEST_F(UtestRegister, ascendC_py_interface_get_tiling_def_without_callback) { - setenv("ENABLE_RUNTIME_V2", "1", 0); - std::string op_type = "TestMaxPoolNotExist"; - std::string res_info(1024, 'a'); - size_t size = 1024; - // check_supported - EXPECT_EQ(AscendCPyInterfaceGetTilingDefInfo(op_type.c_str(), const_cast(res_info.c_str()), size), 0); - unsetenv("ENABLE_RUNTIME_V2"); -} - -TEST_F(UtestRegister, ascendC_py_interface_get_tiling_def_fail_without_params) { - setenv("ENABLE_RUNTIME_V2", "1", 0); - EXPECT_EQ(AscendCPyInterfaceGetTilingDefInfo(nullptr, nullptr, 0), 0); - unsetenv("ENABLE_RUNTIME_V2"); -} - -TEST_F(UtestRegister, ascendC_register_tilingdata_record_tiling_struct) { - setenv("ENABLE_RUNTIME_V2", "1", 0); - uint32_t ret = TilingDataStructBase::GetInstance().RecordTilingStruct("TestTiling1", "test.cpp", 1); - ret = TilingDataStructBase::GetInstance().RecordTilingStruct("TestTiling2", "test.h", 1); - ret = TilingDataStructBase::GetInstance().RecordTilingStruct("TestTiling1", "test.cpp", 2); - EXPECT_EQ(ret, 0); - ret = TilingDataStructBase::GetInstance().RecordTilingStruct("TestTiling1", "test.h", 1); - EXPECT_EQ(ret, 0); - ret = TilingDataStructBase::GetInstance().RecordTilingStruct("TestTiling2", "test.h", 1); - EXPECT_EQ(ret, 0); - ret = TilingDataStructBase::GetInstance().RecordTilingStruct("TestTiling2", "test.h", 2); - EXPECT_EQ(ret, 0); - unsetenv("ENABLE_RUNTIME_V2"); -} - -TEST_F(UtestRegister, ascendC_register_tilingdata_base_ok) { - setenv("ENABLE_RUNTIME_V2", "1", 0); - auto params = TestMaxPoolTilingData(); - params.set_dim_0(0); - params.set_dim_1(10); - params.set_dim_2(20); - params.set_dim_3(30); - params.set_dim_4(40); - params.set_dim_5(50); - params.set_dim_6(60); - params.set_dim_7(70); - params.set_act_core_num(8); - uint8_t res_data[1024]; - int offset = 0; - params.SaveToBuffer((void *) (&res_data), params.GetDataSize()); - EXPECT_EQ(*((int8_t *) (res_data + offset)), params.get_dim_0()); - offset += sizeof(int16_t); - EXPECT_EQ(*((int16_t *) (res_data + offset)), params.get_dim_1()); - offset += sizeof(int16_t); - EXPECT_EQ(*((int32_t *) (res_data + offset)), params.get_dim_2()); - offset += sizeof(int32_t); - EXPECT_EQ(*((int64_t *) (res_data + offset)), params.get_dim_3()); - offset += sizeof(int64_t); - EXPECT_EQ(*((uint8_t *) (res_data + offset)), params.get_dim_4()); - offset += sizeof(uint16_t); - EXPECT_EQ(*((uint16_t *) (res_data + offset)), params.get_dim_5()); - offset += sizeof(uint16_t); - EXPECT_EQ(*((uint32_t *) (res_data + offset)), params.get_dim_6()); - offset += sizeof(uint32_t); - EXPECT_EQ(*((uint64_t *) (res_data + offset)), params.get_dim_7()); - offset += sizeof(uint64_t); - EXPECT_EQ(*((int32_t *) (res_data + offset)), params.get_act_core_num()); - offset += sizeof(int32_t); - unsetenv("ENABLE_RUNTIME_V2"); -} - -TEST_F(UtestRegister, ascendC_register_tilingdata_base_failed) { - setenv("ENABLE_RUNTIME_V2", "1", 0); - auto paramStruct = TestMaxPoolTilingDataStruct(); - int8_t arr[] = {0, 1, 2, 3, 4, 5, 6, 7}; - uint8_t res_data[1024]; - int offset = 0; - paramStruct.set_dim_0(arr); - paramStruct.dim_1.set_dim_0(0); - paramStruct.dim_1.set_dim_1(10); - paramStruct.dim_1.set_dim_2(20); - paramStruct.dim_1.set_dim_3(30); - paramStruct.dim_1.set_dim_4(40); - paramStruct.dim_1.set_dim_5(50); - paramStruct.dim_1.set_dim_6(60); - paramStruct.dim_1.set_dim_7(70); - paramStruct.dim_1.set_act_core_num(8); - paramStruct.SaveToBuffer((void *) (&res_data), 1024); - - auto params = TestMaxPoolTilingData((void *) (&res_data)); - params.set_dim_0(0); - params.set_dim_1(10); - params.set_dim_2(20); - params.set_dim_3(30); - params.set_dim_4(40); - params.set_dim_5(50); - params.set_dim_6(60); - params.set_dim_7(70); - params.set_act_core_num(8); - params.SaveToBuffer((void *) (&res_data), 1024); - EXPECT_EQ(*((int8_t *) (res_data + offset)), params.get_dim_0()); - offset += sizeof(int16_t); - EXPECT_EQ(*((int16_t *) (res_data + offset)), params.get_dim_1()); - params.SetDataPtr(res_data); - unsetenv("ENABLE_RUNTIME_V2"); -} - -extern "C" int AscendCPyInterfaceOpReplay(const char *optype, const char *soc_version, int block_dim, - const char *tiling_data, const char *kernel_name, const char *entry_file, - const char *output_kernel_file, int core_type, int task_ration); - -int replay_stub(ReplayFuncParam ¶m, const int core_typ) { - return 1; -} - -int replay_stub_throw(ReplayFuncParam ¶m, const int core_typ) { - throw "bad callback"; - return 1; -} - -int replay_stub_invalid_ret(ReplayFuncParam ¶m, const int core_typ) { - return 0; -} - -TEST_F(UtestRegister, ascendC_py_interface_op_replay_ok) { - setenv("ENABLE_RUNTIME_V2", "1", 0); - std::string op_type = "ascendC_py_interface_op_replay_ok"; - std::string soc_version = "ascend710"; - int blkdim = 32; - std::string tilingdata = "\x00\x14\x00\x00\x00\n\x00(\x1e\x00\x00\x00\x00\x00\x00\x00"; - std::string kernel_name = "ascendC_py_interface_op_replay_ok"; - std::string entry_file = "ascendC_py_interface_op_replay_ok_entry_file.h"; - std::string output_kernel_file = "ascendC_py_interface_op_replay_ok_kernel_file.cce"; - int core_type = 0; - int task_ration = 1; - REG_REPLAY_FUNC(ascendC_py_interface_op_replay_ok, ascend710, replay_stub); - EXPECT_EQ(AscendCPyInterfaceOpReplay(op_type.c_str(), soc_version.c_str(), blkdim, tilingdata.c_str(), - kernel_name.c_str(), entry_file.c_str(), output_kernel_file.c_str(), core_type, - task_ration), - 1); - - unsetenv("ENABLE_RUNTIME_V2"); -} - -TEST_F(UtestRegister, ascendC_py_interface_op_replay_fail_without_callback) { - setenv("ENABLE_RUNTIME_V2", "1", 0); - std::string op_type = "ascendC_py_interface_op_replay_fail_without_callback"; - std::string soc_version = "ascend710"; - int blkdim = 32; - std::string tilingdata = "\x00\x14\x00\x00\x00\n\x00(\x1e\x00\x00\x00\x00\x00\x00\x00"; - std::string kernel_name = "ascendC_py_interface_op_replay_fail_without_callback"; - std::string entry_file = "ascendC_py_interface_op_replay_fail_without_callback_entry_file.h"; - std::string output_kernel_file = "ascendC_py_interface_op_replay_fail_without_callback_kernel_file.cce"; - int core_type = 0; - int task_ration = 1; - EXPECT_EQ(AscendCPyInterfaceOpReplay(op_type.c_str(), soc_version.c_str(), blkdim, tilingdata.c_str(), - kernel_name.c_str(), entry_file.c_str(), output_kernel_file.c_str(), core_type, - task_ration), - 0); - - unsetenv("ENABLE_RUNTIME_V2"); -} - -TEST_F(UtestRegister, ascendC_py_interface_op_replay_fail_throw) { - setenv("ENABLE_RUNTIME_V2", "1", 0); - std::string op_type = "ascendC_py_interface_op_replay_fail_throw"; - std::string soc_version = "ascend710"; - int blkdim = 32; - std::string tilingdata = "\x00\x14\x00\x00\x00\n\x00(\x1e\x00\x00\x00\x00\x00\x00\x00"; - std::string kernel_name = "ascendC_py_interface_op_replay_fail_throw"; - std::string entry_file = "ascendC_py_interface_op_replay_fail_throw_entry_file.h"; - std::string output_kernel_file = "ascendC_py_interface_op_replay_fail_throw_kernel_file.cce"; - int core_type = 0; - int task_ration = 1; - REG_REPLAY_FUNC(ascendC_py_interface_op_replay_fail_throw, ascend710, replay_stub_throw); - EXPECT_EQ(AscendCPyInterfaceOpReplay(op_type.c_str(), soc_version.c_str(), blkdim, tilingdata.c_str(), - kernel_name.c_str(), entry_file.c_str(), output_kernel_file.c_str(), core_type, - task_ration), - 0); - - unsetenv("ENABLE_RUNTIME_V2"); -} - -TEST_F(UtestRegister, ascendC_py_interface_op_replay_fail_without_params) { - setenv("ENABLE_RUNTIME_V2", "1", 0); - EXPECT_EQ(AscendCPyInterfaceOpReplay(nullptr, nullptr, 0, nullptr, nullptr, nullptr, nullptr, 0, 1), 0); - unsetenv("ENABLE_RUNTIME_V2"); -} - -TEST_F(UtestRegister, ascendC_py_interface_op_replay_invalid_core_type) { - setenv("ENABLE_RUNTIME_V2", "1", 0); - std::string op_type = "ascendC_py_interface_op_replay_invalid_core_type"; - std::string soc_version = "ascend710"; - int blkdim = 32; - std::string tilingdata = "\x00\x14\x00\x00\x00\n\x00(\x1e\x00\x00\x00\x00\x00\x00\x00"; - std::string kernel_name = "ascendC_py_interface_op_replay_invalid_core_type"; - std::string entry_file = "ascendC_py_interface_op_replay_invalid_core_type_entry_file.h"; - std::string output_kernel_file = "ascendC_py_interface_op_replay_invalid_core_type_kernel_file.cce"; - int core_type = 4; - int task_ration = 1; - REG_REPLAY_FUNC(ascendC_py_interface_op_replay_invalid_core_type, ascend710, replay_stub); - EXPECT_EQ(AscendCPyInterfaceOpReplay(op_type.c_str(), soc_version.c_str(), blkdim, tilingdata.c_str(), - kernel_name.c_str(), entry_file.c_str(), output_kernel_file.c_str(), core_type, - task_ration), - 0); - - unsetenv("ENABLE_RUNTIME_V2"); -} - -TEST_F(UtestRegister, ascendC_py_interface_op_replay_invalid_task_ration) { - setenv("ENABLE_RUNTIME_V2", "1", 0); - std::string op_type = "ascendC_py_interface_op_replay_invalid_task_ration"; - std::string soc_version = "ascend710"; - int blkdim = 32; - std::string tilingdata = "\x00\x14\x00\x00\x00\n\x00(\x1e\x00\x00\x00\x00\x00\x00\x00"; - std::string kernel_name = "ascendC_py_interface_op_replay_invalid_task_ration"; - std::string entry_file = "ascendC_py_interface_op_replay_invalid_task_ration_entry_file.h"; - std::string output_kernel_file = "ascendC_py_interface_op_replay_invalid_task_ration_kernel_file.cce"; - int core_type = 0; - int task_ration = -1; - REG_REPLAY_FUNC(ascendC_py_interface_op_replay_invalid_task_ration, ascend710, replay_stub); - EXPECT_EQ(AscendCPyInterfaceOpReplay(op_type.c_str(), soc_version.c_str(), blkdim, tilingdata.c_str(), - kernel_name.c_str(), entry_file.c_str(), output_kernel_file.c_str(), core_type, - task_ration), - 0); - - unsetenv("ENABLE_RUNTIME_V2"); -} - -TEST_F(UtestRegister, ascendC_py_interface_op_replay_invalid_ret) { - setenv("ENABLE_RUNTIME_V2", "1", 0); - std::string op_type = "ascendC_py_interface_op_replay_invalid_ret"; - std::string soc_version = "ascend710"; - int blkdim = 32; - std::string tilingdata = "\x00\x14\x00\x00\x00\n\x00(\x1e\x00\x00\x00\x00\x00\x00\x00"; - std::string kernel_name = "ascendC_py_interface_op_replay_invalid_ret"; - std::string entry_file = "ascendC_py_interface_op_replay_invalid_ret_entry_file.h"; - std::string output_kernel_file = "ascendC_py_interface_op_replay_invalid_ret_kernel_file.cce"; - int core_type = 1; - int task_ration = 2; - REG_REPLAY_FUNC(ascendC_py_interface_op_replay_invalid_ret, ascend710, replay_stub_invalid_ret); - EXPECT_EQ(AscendCPyInterfaceOpReplay(op_type.c_str(), soc_version.c_str(), blkdim, tilingdata.c_str(), - kernel_name.c_str(), entry_file.c_str(), output_kernel_file.c_str(), core_type, - task_ration), - 0); - - unsetenv("ENABLE_RUNTIME_V2"); -} - -TEST_F(UtestRegister, new_optiling_py_interface_with_null_desc_ok) { - const nlohmann::json input = R"([ -{"name": "test_0","dtype": "float", "const_value": [null,2.0,null,null], "const_value_null_desc": ["inf", null, "nan", "-inf"],"shape": [4,4,4,4],"format": "ND"}, -{"name": "test_1","dtype": "float","shape": [5,5,5,5],"ori_shape": [5,5,5,5],"format": "ND","ori_format": "ND"}])"_json; - std::string input_str = input.dump(); - const nlohmann::json output = R"([ -{"name": "y_0","dtype": "int8","shape": [9,9,9,9],"ori_shape" :[9,9,9,9],"format": "ND","ori_format":"ND"}])"_json; - std::string output_str = output.dump(); - - const nlohmann::json attrs = R"([ -{ "name": "attr_0","dtype": "list_float","value": [null,2.0,null,null],"value_null_desc": ["inf", null, "nan", "-inf"]}, -{ "name": "attr_1","dtype": "float","value": null, "value_null_desc": "inf"}, -{ "name": "attr_2","dtype": "list_float","value": [1, 2, 3, 4]}, -{ "name": "op_para_size", "dtype": "float", "value": 50}])"_json; - std::string attrs_str = attrs.dump(); - - const char *op_type = "TestReluV2"; - const char *cmp_info = ""; - - std::string result = - R"({"aicpu_block_dim":4,"block_dim":2,"clear_atomic":true,"local_memory_size":0,"schedule_mode":0,"tiling_cond":0,"tiling_data":"060708090A","tiling_key":78,"workspaces":[12]})"; - size_t size = result.length(); - std::string runinfo(size, 'a'); - - const char *cmp_info_hash = ""; - uint64_t *elapse = nullptr; - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.tiling = OpTilingStubNewWithNullDesc; - op_impl_func.tiling_parse = OpTilingParseStubNew; - op_impl_func.compile_info_creator = CreateCompileInfo; - op_impl_func.compile_info_deleter = DeleteCompileInfo; - op_impl_func.max_tiling_data_size = 50; - registry_holder->AddTypesToImpl(op_type, op_impl_func); - space_registry->AddRegistry(registry_holder); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - EXPECT_EQ(TbeOpTilingPyInterface(op_type, cmp_info, cmp_info_hash, input_str.c_str(), output_str.c_str(), - attrs_str.c_str(), const_cast(runinfo.c_str()), size + 1U, elapse), - 1); - - EXPECT_EQ(result, runinfo); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(nullptr); -} - -TEST_F(UtestRegister, new_optiling_py_interface_const_value_fail_with_different_size) { - // const_value size is different from const_value_null_desc - const nlohmann::json input = R"([ -{"name": "test_0","dtype": "float", "const_value": [null,2.0,null,null], "const_value_null_desc": ["inf", null, "nan"],"shape": [4,4,4,4],"format": "ND"}])"_json; - SupportInfNanWithNullDescInvalidTestCase(input, nullptr, nullptr); -} - -TEST_F(UtestRegister, new_optiling_py_interface_const_value_fail_with_type_not_support_inf_nan) { - // dtype doesn't support inf and nan - const nlohmann::json input = R"([ -{"name": "test_0","dtype": "int32", "const_value": [null,2.0,null,null], "const_value_null_desc": ["inf", null, "nan", "-inf"],"shape": [4,4,4,4],"format": "ND"}])"_json; - SupportInfNanWithNullDescInvalidTestCase(input, nullptr, nullptr); -} - -TEST_F(UtestRegister, new_optiling_py_interface_attr_fail_with_invalid_null_desc_param) { - const nlohmann::json input = R"([ -{"name": "test_0","dtype": "float", "const_value": [null,2.0,null,null], "const_value_null_desc": ["inf", null, "nan", "-inf"],"shape": [4,4,4,4],"format": "ND"}, -{"name": "test_1","dtype": "float","shape": [5,5,5,5],"ori_shape": [5,5,5,5],"format": "ND","ori_format": "ND"}])"_json; - const nlohmann::json output = R"([ -{"name": "y_0","dtype": "int8","shape": [9,9,9,9],"ori_shape" :[9,9,9,9],"format": "ND","ori_format":"ND"}])"_json; - - // value_null_desc has invalid param "abc" - const nlohmann::json attrs = R"([ -{ "name": "attr_0","dtype": "list_float","value": [null,2.0,null,null],"value_null_desc": ["inf", null, "nan", "abc"]}])"_json; - SupportInfNanWithNullDescInvalidTestCase(input, output, attrs); -} - -TEST_F(UtestRegister, new_optiling_py_interface_attr_fail_with_value_not_null_but_has_null_desc) { - const nlohmann::json input = R"([ -{"name": "test_0","dtype": "float", "const_value": [null,2.0,null,null], "const_value_null_desc": ["inf", null, "nan", "-inf"],"shape": [4,4,4,4],"format": "ND"}, -{"name": "test_1","dtype": "float","shape": [5,5,5,5],"ori_shape": [5,5,5,5],"format": "ND","ori_format": "ND"}])"_json; - const nlohmann::json output = R"([ -{"name": "y_0","dtype": "int8","shape": [9,9,9,9],"ori_shape" :[9,9,9,9],"format": "ND","ori_format":"ND"}])"_json; - - // when attr dtype is float, value is not null, but exist value_null_desc - const nlohmann::json attrs = R"([ -{ "name": "attr_0","dtype": "float","value": 2.0, "value_null_desc": null}])"_json; - SupportInfNanWithNullDescInvalidTestCase(input, output, attrs); -} - -TEST_F(UtestRegister, new_optiling_py_interface_ok_with_bf16_data) { - const nlohmann::json input = R"([ - { - "name": "t0", - "dtype": "bfloat16", - "const_value": [1.1, 2.1, 3.1, 4.1], - "shape": [4, 4, 4, 4], - "ori_shape": [4, 4, 4, 4], - "format": "ND" - }, - { - "dtype": "int8", - "shape": [4, 4, 4, 4], - "ori_shape": [4, 4, 4, 4], - "format": "ND" - }])"_json; - std::string input_str = input.dump(); - const nlohmann::json output = R"([ - { - "name": "y_0", - "dtype": "int8", - "shape": [9, 9, 9, 9], - "ori_shape": [9, 9, 9, 9], - "format": "ND", - "ori_format": "ND" - }])"_json; - std::string output_str = output.dump(); - const char *op_type = "TestReluV2"; - const char *cmp_info = ""; - size_t size = 160U; - std::string runinfo(size, 'a'); - const char *cmp_info_hash = ""; - uint64_t *elapse = nullptr; - const nlohmann::json attrs = R"([{"name": "op_para_size", "dtype": "int", "value": 50}])"_json; - - const size_t max_tiling_size = 50U; - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.tiling = OpTilingStubBf16; - op_impl_func.tiling_parse = OpTilingParseStubV5; - op_impl_func.compile_info_creator = CreateCompileInfo; - op_impl_func.compile_info_deleter = DeleteCompileInfo; - op_impl_func.max_tiling_data_size = max_tiling_size; - registry_holder->AddTypesToImpl(op_type, op_impl_func); - space_registry->AddRegistry(registry_holder); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - EXPECT_EQ(TbeOpTilingPyInterface(op_type, cmp_info, cmp_info_hash, input_str.c_str(), output_str.c_str(), - attrs.dump().c_str(), const_cast(runinfo.c_str()), size, elapse), - 1); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(nullptr); -} diff --git a/tests/ut/register/testcase/scope_pattern_unittest.cc b/tests/ut/register/testcase/scope_pattern_unittest.cc deleted file mode 100644 index a79460271bbfd6b8e0a5d2ed8003144fb5844b6b..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/scope_pattern_unittest.cc +++ /dev/null @@ -1,316 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include - -#include "register/scope/scope_pattern_impl.h" -#include "register/scope/scope_graph_impl.h" -#include "common/ge_common/debug/ge_log.h" -#include "graph/types.h" -#include "inc/external/register/scope/scope_fusion_pass_register.h" - -using namespace std; -using namespace testing; - -namespace ge { - -class ScopePatternUt : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -TEST_F(ScopePatternUt, ScopeAttrValue1) { - ScopeAttrValue scope_attr_value; - - float32_t value = 0.2; - scope_attr_value.SetFloatValue(value); - EXPECT_EQ(scope_attr_value.impl_->GetFloatValue(), static_cast(0.2)); - - int64_t value2 = 2; - scope_attr_value.SetIntValue(value2); - EXPECT_EQ(scope_attr_value.impl_->GetIntValue(), 2); - - scope_attr_value.SetStringValue("abc"); - EXPECT_EQ(scope_attr_value.impl_->GetStrValue(), string("abc")); - - scope_attr_value.SetStringValue(string("def")); - EXPECT_EQ(scope_attr_value.impl_->GetStrValue(), string("def")); - - scope_attr_value.SetBoolValue(true); - EXPECT_TRUE(scope_attr_value.impl_->GetBoolValue()); - - ScopeAttrValue scope_attr_value2(scope_attr_value); - EXPECT_EQ(scope_attr_value2.impl_->GetFloatValue(), static_cast(0.2)); - EXPECT_EQ(scope_attr_value2.impl_->GetIntValue(), 2); - EXPECT_EQ(scope_attr_value2.impl_->GetStrValue(), string("def")); - EXPECT_TRUE(scope_attr_value2.impl_->GetBoolValue()); - - ScopeAttrValue scope_attr_value3; - scope_attr_value3 = scope_attr_value; - EXPECT_EQ(scope_attr_value3.impl_->GetFloatValue(), static_cast(0.2)); - EXPECT_EQ(scope_attr_value3.impl_->GetIntValue(), 2); - EXPECT_EQ(scope_attr_value3.impl_->GetStrValue(), string("def")); - EXPECT_TRUE(scope_attr_value3.impl_->GetBoolValue()); -} - -TEST_F(ScopePatternUt, ScopeAttrValue2) { - ScopeAttrValue scope_attr_value; - scope_attr_value.impl_ = nullptr; - - float32_t value = 0.2; - scope_attr_value.SetFloatValue(value); - - int64_t value2 = 2; - scope_attr_value.SetIntValue(value2); - scope_attr_value.SetStringValue("abc"); - scope_attr_value.SetStringValue(string("def")); - scope_attr_value.SetBoolValue(true); - - EXPECT_TRUE(scope_attr_value.impl_ == nullptr); -} - -TEST_F(ScopePatternUt, NodeOpTypeFeature) { - // construct - string nodeType = string("add"); - int32_t num = 1; - int32_t step = 100; - NodeOpTypeFeature notf(nodeType, num, step); - EXPECT_EQ(notf.impl_->step_, step); - NodeOpTypeFeature notf2("edf", num, 0); - EXPECT_EQ(notf2.impl_->node_type_, string("edf")); - - // match - Scope scope; - scope.Init("name", "sub_type", nullptr); - EXPECT_FALSE(notf.Match(nullptr)); - EXPECT_FALSE(notf.Match(&scope)); - EXPECT_FALSE(notf2.Match(&scope)); - - // copy - NodeOpTypeFeature notf3(notf); - EXPECT_EQ(notf3.impl_->node_type_, string("add")); - notf3 = notf3; - notf3 = notf2; - EXPECT_EQ(notf3.impl_->node_type_, string("edf")); - - notf3.impl_.reset(); - EXPECT_FALSE(notf3.Match(nullptr)); - EXPECT_EQ(notf3.impl_, nullptr); -} - -TEST_F(ScopePatternUt, NodeAttrFeature) { - // construct - ScopeAttrValue scope_attr_value; - scope_attr_value.SetStringValue("abc"); - NodeAttrFeature naf("node_type", "attr_name", DT_INT8, scope_attr_value); - NodeAttrFeature naf2(string("node_type_2"), string("attr_name_2"), DT_INT8, scope_attr_value); - EXPECT_EQ(naf.impl_->attr_value_.impl_->GetStrValue(), string("abc")); - - // copy - NodeAttrFeature naf3(naf2); - EXPECT_EQ(naf3.impl_->node_type_, string("node_type_2")); - naf3 = naf3; - naf3 = naf; - EXPECT_EQ(naf3.impl_->attr_name_, string("attr_name")); - - // match - Scope scope; - scope.Init("name", "sub_type", nullptr); - EXPECT_FALSE(naf3.impl_->Match(nullptr)); - EXPECT_FALSE(naf3.impl_->Match(&scope)); -} - -TEST_F(ScopePatternUt, CheckNodeAttrFeatureData) { - ScopeAttrValue scope_attr_value; - scope_attr_value.SetStringValue("abc"); - NodeAttrFeature naf("node_type", "attr_name", DT_INT8, scope_attr_value); - - bool init_value = true; - ge::OpDescPtr op_desc(new ge::OpDesc("add1", "Add")); - Scope scope; - scope.Init("name", "sub_type", nullptr); - - auto ret = naf.impl_->CheckNodeAttrFeatureData(init_value, op_desc, &scope); - EXPECT_EQ(ret, PARAM_INVALID); - - string init_value2 = "init_value"; - ret = naf.impl_->CheckNodeAttrFeatureData(init_value2, op_desc, &scope); - EXPECT_EQ(ret, PARAM_INVALID); - - int64_t init_value3 = 1; - ret = naf.impl_->CheckNodeAttrFeatureData(init_value3, op_desc, &scope); - EXPECT_EQ(ret, PARAM_INVALID); - - float32_t init_value4 = 0.2; - ret = naf.impl_->CheckNodeAttrFeatureData(init_value4, op_desc, &scope); - EXPECT_EQ(ret, PARAM_INVALID); - - // match - EXPECT_FALSE(naf.Match(nullptr)); - EXPECT_FALSE(naf.Match(&scope)); -} - -TEST_F(ScopePatternUt, CheckNodeAttrFeatureDataSuccess) { - { - ScopeAttrValue scope_attr_value; - bool init_value = true; - scope_attr_value.SetBoolValue(init_value); - string attr_name("attr_name"); - NodeAttrFeature naf("node_type", attr_name, DT_INT8, scope_attr_value); - - ge::OpDescPtr op_desc(new ge::OpDesc("add1", "Add")); - ge::AttrUtils::SetBool(op_desc, attr_name, init_value); - Scope scope; - scope.Init("name", "sub_type", nullptr); - - auto ret = naf.impl_->CheckNodeAttrFeatureData(init_value, op_desc, &scope); - EXPECT_EQ(ret, SUCCESS); - } - { - ScopeAttrValue scope_attr_value; - string init_value = "true"; - scope_attr_value.SetStringValue(init_value.c_str()); - string attr_name("attr_name"); - NodeAttrFeature naf("node_type", attr_name, DT_INT8, scope_attr_value); - - ge::OpDescPtr op_desc(new ge::OpDesc("add1", "Add")); - ge::AttrUtils::SetStr(op_desc, attr_name, init_value); - Scope scope; - scope.Init("name", "sub_type", nullptr); - - auto ret = naf.impl_->CheckNodeAttrFeatureData(init_value, op_desc, &scope); - EXPECT_EQ(ret, SUCCESS); - } - { - ScopeAttrValue scope_attr_value; - float32_t init_value = 0.0f; - scope_attr_value.SetFloatValue(init_value); - string attr_name("attr_name"); - NodeAttrFeature naf("node_type", attr_name, DT_INT8, scope_attr_value); - - ge::OpDescPtr op_desc(new ge::OpDesc("add1", "Add")); - ge::AttrUtils::SetFloat(op_desc, attr_name, init_value); - Scope scope; - scope.Init("name", "sub_type", nullptr); - - auto ret = naf.impl_->CheckNodeAttrFeatureData(init_value, op_desc, &scope); - EXPECT_EQ(ret, SUCCESS); - } - { - ScopeAttrValue scope_attr_value; - int64_t init_value = 0; - scope_attr_value.SetIntValue(init_value); - string attr_name("attr_name"); - NodeAttrFeature naf("node_type", attr_name, DT_INT8, scope_attr_value); - - ge::OpDescPtr op_desc(new ge::OpDesc("add1", "Add")); - ge::AttrUtils::SetInt(op_desc, attr_name, init_value); - Scope scope; - scope.Init("name", "sub_type", nullptr); - - auto ret = naf.impl_->CheckNodeAttrFeatureData(init_value, op_desc, &scope); - EXPECT_EQ(ret, SUCCESS); - } -} - -TEST_F(ScopePatternUt, ScopeFeature) { - // construct - string sub_type = "sub_type"; - int32_t num = 3; - string suffix = "suffix"; - string sub_scope_mask = "sub_scope_mask"; - int32_t step = 0; - - ScopeFeature sf(sub_type, num, suffix, sub_scope_mask, step); - EXPECT_EQ(sf.impl_->sub_type_, sub_type); - - ScopeFeature sf2("sub_type_2", num, "suffix_2", "sub_scope_mask_2", step); - EXPECT_EQ(sf2.impl_->sub_type_, string("sub_type_2")); - - // copy - ScopeFeature sf3(sf2); - EXPECT_EQ(sf3.impl_->sub_type_, string("sub_type_2")); - - sf2 = sf2; - sf2 = sf; - EXPECT_EQ(sf2.impl_->sub_type_, sub_type); - - // match - Scope scope; - scope.Init("name", "sub_type", nullptr); - EXPECT_FALSE(sf.Match(&scope)); -} - -TEST_F(ScopePatternUt, ScopeFeature_Match) { - std::vector scopes; - Scope scope; - scope.Init("name", "sub_type", nullptr); - scopes.emplace_back(&scope); - Scope scope2; - scope2.Init("name_2", "sub_type_2", nullptr); - scopes.emplace_back(&scope2); - - ScopeFeature sf2("sub_type_2", 1, "suffix_2", "sub_scope_mask_2", 1); - auto ret = sf2.impl_->SubScopesMatch(scopes); - EXPECT_FALSE(ret); -} - -TEST_F(ScopePatternUt, ScopePattern) { - ScopePattern scope_pat; - EXPECT_NE(scope_pat.impl_, nullptr); - - scope_pat.SetSubType("sub_type"); - scope_pat.SetSubType(string("sub_type_2")); - EXPECT_EQ(scope_pat.impl_->sub_type_, string("sub_type_2")); - - scope_pat.impl_.reset(); - scope_pat.SetSubType("sub_type"); - scope_pat.SetSubType(string("sub_type_2")); - EXPECT_EQ(scope_pat.impl_, nullptr); -} - -TEST_F(ScopePatternUt, AddFeature) { - ScopePattern scope_pat; - - NodeOpTypeFeature notf("abc", 1, 0); - scope_pat.AddNodeOpTypeFeature(notf); - EXPECT_TRUE(scope_pat.impl_->node_optype_features_.size() > 0); - - ScopeAttrValue scope_attr_value; - scope_attr_value.SetStringValue("abc"); - NodeAttrFeature naf("node_type", "attr_name", DT_INT8, scope_attr_value); - scope_pat.AddNodeAttrFeature(naf); - EXPECT_TRUE(scope_pat.impl_->node_attr_features_.size() > 0); - - ScopeFeature sf("sub_type", 1, "suffix", "sub_scope_mask", 1); - scope_pat.AddScopeFeature(sf); - EXPECT_TRUE(scope_pat.impl_->scopes_features_.size() > 0); -} - -TEST_F(ScopePatternUt, AddFeature_Null) { - ScopePattern scope_pat; - scope_pat.impl_.reset(); - - NodeOpTypeFeature notf("abc", 1, 0); - scope_pat.AddNodeOpTypeFeature(notf); - - ScopeAttrValue scope_attr_value; - scope_attr_value.SetStringValue("abc"); - NodeAttrFeature naf("node_type", "attr_name", DT_INT8, scope_attr_value); - scope_pat.AddNodeAttrFeature(naf); - - ScopeFeature sf("sub_type", 1, "suffix", "sub_scope_mask", 1); - scope_pat.AddScopeFeature(sf); - - EXPECT_EQ(scope_pat.impl_, nullptr); -} -} // namespace ge diff --git a/tests/ut/register/testcase/scope_util_unittest.cc b/tests/ut/register/testcase/scope_util_unittest.cc deleted file mode 100644 index 24e4465d03237f7d2710b40d0d54733684fa251f..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/scope_util_unittest.cc +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include - -#include "inc/external/register/scope/scope_fusion_pass_register.h" - -using namespace std; -using namespace testing; - -namespace ge { - -class ScopeUtilUt : public testing::Test { - protected: - void SetUp() {} - void TearDown() {} -}; - -TEST_F(ScopeUtilUt, StringReplaceAll) { - AscendString str = ScopeUtil::StringReplaceAll("123456", "123", "abc"); - EXPECT_EQ(str.GetString(), string("abc456")); - - string str2 = ScopeUtil::StringReplaceAll(string("abc456"), string("456"), string("def")); - EXPECT_EQ(str2, string("abcdef")); -} - -TEST_F(ScopeUtilUt, FreeScopePatterns) { - std::vector> scoPatternSub; - std::vector scoPatternSubSub1; - std::vector scoPatternSubSub2; - ScopePattern *scoPattern1; - ScopePattern *scoPattern2; - - scoPattern1 = new ScopePattern(); - scoPattern2 = new ScopePattern(); - scoPatternSubSub1.push_back(scoPattern1); - scoPatternSubSub2.push_back(scoPattern2); - - scoPatternSub.push_back(scoPatternSubSub1); - scoPatternSub.push_back(scoPatternSubSub2); - - ScopeUtil::FreeScopePatterns(scoPatternSub); - EXPECT_EQ(scoPatternSub.size(), 0); -} - -} // namespace ge diff --git a/tests/ut/register/testcase/shape_inference_unittest.cc b/tests/ut/register/testcase/shape_inference_unittest.cc deleted file mode 100644 index 5395fc11771b33a5a98120726a75a47ea0b30161..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/shape_inference_unittest.cc +++ /dev/null @@ -1,1352 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "graph/operator_factory_impl.h" -#include "graph/operator_reg.h" -#include "register/op_impl_registry.h" -#include "register/shape_inference.h" -#include "utils/op_desc_utils.h" -#include -#include "graph/utils/graph_utils.h" -#include "graph/attr_value.h" -#include "external/graph/operator_factory.h" -#include "register/op_impl_space_registry.h" -#include "register/op_impl_registry_holder_manager.h" -#include "common/ge_common/ge_inner_error_codes.h" - -#include -#include "graph/op_desc.h" -#include "graph/normal_graph/op_desc_impl.h" - -#include -#include -#include "mmpa/mmpa_api.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/inference_context.h" -#include "graph/ct_infer_shape_context.h" -#include "graph/ct_infer_shape_range_context.h" -#include "graph/utils/tensor_adapter.h" - - -namespace ge{ -REG_OP(Const) - .OUTPUT(y, - TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, - DT_UINT64, DT_BOOL, DT_DOUBLE})) - .ATTR(value, Tensor, Tensor()) - .OP_END_FACTORY_REG(Const); -} -namespace gert { -using namespace ge; -class ShapeInferenceUT : public testing::Test {}; -// infer from output -REG_OP(FixIOOp_OutputIsFix) - .INPUT(fix_input1, "T") - .INPUT(fix_input2, "T") - .OUTPUT(fix_output, "T2") - .DATATYPE(T2, TensorType({DT_BOOL})) - .OP_END_FACTORY_REG(FixIOOp_OutputIsFix); -// 无可选输入,无动态输入,正常流程,infer shape & infer data type -TEST_F(ShapeInferenceUT, CallInferV2Func_success) { - auto op = OperatorFactory::CreateOperator("test1", "FixIOOp_OutputIsFix"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - GeShape shape({1, 1, 1, 1}); - GeTensorDesc tensor_desc(shape, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc.SetOriginShape(shape); - tensor_desc.SetOriginDataType(DT_FLOAT16); - std::vector> range = {{0, 10000}}; - tensor_desc.SetOriginShapeRange(range); - op_desc->UpdateInputDesc(0, tensor_desc); - op_desc->UpdateInputDesc(1, tensor_desc); - const auto infer_shape_func = [](gert::InferShapeContext *context) -> graphStatus { - const auto input_shape = context->GetInputShape(0U); - auto output = context->GetOutputShape(0); - for (size_t dim = 0UL; dim < input_shape->GetDimNum(); dim++) { - output->AppendDim(input_shape->GetDim(dim)); - } - output->SetDimNum(input_shape->GetDimNum()); - return GRAPH_SUCCESS; - }; - const auto infer_data_type_func = [](gert::InferDataTypeContext *context) -> graphStatus { - const auto date_type = context->GetInputDataType(0U); - EXPECT_EQ(context->SetOutputDataType(0, date_type), SUCCESS); - return GRAPH_SUCCESS; - }; - const auto infer_shape_range_func = [](gert::InferShapeRangeContext *context) -> graphStatus { - auto input_shape_range = context->GetInputShapeRange(0U); - auto output_shape_range = context->GetOutputShapeRange(0U); - output_shape_range->SetMin(const_cast(input_shape_range->GetMin())); - output_shape_range->SetMax(const_cast(input_shape_range->GetMax())); - return GRAPH_SUCCESS; - }; - IMPL_OP(FixIOOp_OutputIsFix).InferShape(infer_shape_func) - .InferDataType(infer_data_type_func) - .InferShapeRange(infer_shape_range_func) - .OutputShapeDependOnCompute({0}); - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.infer_shape = infer_shape_func; - op_impl_func.infer_datatype = infer_data_type_func; - op_impl_func.infer_shape_range = infer_shape_range_func; - op_impl_func.output_shape_depend_compute = 1UL; - registry_holder->AddTypesToImpl("FixIOOp_OutputIsFix", op_impl_func); - space_registry->AddRegistry(registry_holder); - DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - const auto call_infer_data_type = OperatorFactoryImpl::GetInferDataTypeFunc(); - const auto call_infer_shape_v2 = OperatorFactoryImpl::GetInferShapeV2Func(); - - const auto call_infer_shape_range = OperatorFactoryImpl::GetInferShapeRangeFunc(); - ASSERT_NE(call_infer_data_type, nullptr); - ASSERT_NE(call_infer_shape_v2, nullptr); - ASSERT_NE(call_infer_shape_range, nullptr); - auto status = call_infer_data_type(op_desc); - ASSERT_EQ(status, GRAPH_SUCCESS); - status = call_infer_shape_v2(op, op_desc); - ASSERT_EQ(status, GRAPH_SUCCESS); - uint64_t unknown_shape_type; - ASSERT_TRUE(ge::AttrUtils::GetInt(op_desc, ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, unknown_shape_type)); - ASSERT_EQ(unknown_shape_type, 3U); - status = call_infer_shape_range(op, op_desc); - ASSERT_EQ(status, GRAPH_SUCCESS); - ASSERT_EQ(op_desc->GetOutputDesc(0U).GetDataType(), DT_FLOAT16); - ASSERT_EQ(op_desc->GetOutputDesc(0U).GetShape().GetDimNum(), 4); - ASSERT_EQ(op_desc->GetOutputDesc(0U).GetShape().GetDim(0), 1); -} - -REG_OP(OptionalInput3Input3Output) - .INPUT(input1, "T") - .OPTIONAL_INPUT(input2, "T") - .INPUT(input3, "T") - .OUTPUT(output1, "T2") - .OUTPUT(output2, "T2") - .OUTPUT(output3, "T2") - .DATATYPE(T2, TensorType({DT_BOOL})) - .OP_END_FACTORY_REG(OptionalInput3Input3Output); -// 未实例化的optional input测试 -TEST_F(ShapeInferenceUT, CallInferV2Func_OptionalInputWithOutInstance) { - auto op = OperatorFactory::CreateOperator("test2", "OptionalInput3Input3Output"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - // input1 - GeShape shape1({1, 2, 3, 4}); - GeTensorDesc tensor_desc1(shape1, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc1.SetOriginShape(shape1); - tensor_desc1.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(0, tensor_desc1); - // input3 - GeShape shape2({4, 3, 2}); - GeTensorDesc tensor_desc2(shape2, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc2.SetOriginShape(shape2); - tensor_desc2.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(2, tensor_desc2); - const auto infer_shape_func = [](gert::InferShapeContext *context) -> graphStatus { - const auto option_input_shape = context->GetOptionalInputShape(1U); - if (option_input_shape != nullptr) { - return GRAPH_FAILED; - } - auto output = context->GetOutputShape(0); - const auto input_shape = context->GetInputShape(1U); - for (size_t dim = 0UL; dim < input_shape->GetDimNum(); dim++) { - output->AppendDim(input_shape->GetDim(dim)); - } - output->SetDimNum(input_shape->GetDimNum()); - return GRAPH_SUCCESS; - }; - IMPL_OP(OptionalInput3Input3Output).InferShape(infer_shape_func) - .InferDataType(nullptr) - .InferShapeRange(nullptr); - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.infer_shape = infer_shape_func; - op_impl_func.infer_shape_range = nullptr; - registry_holder->AddTypesToImpl("OptionalInput3Input3Output", op_impl_func); - space_registry->AddRegistry(registry_holder); - DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - const auto call_infer_shape_v2 = OperatorFactoryImpl::GetInferShapeV2Func(); - ASSERT_NE(call_infer_shape_v2, nullptr); - auto status = call_infer_shape_v2(op, op_desc); - ASSERT_EQ(status, GRAPH_SUCCESS); - ASSERT_EQ(op_desc->GetOutputDesc(0U).GetShape().GetDimNum(), 3); - const auto call_infer_shape_range = OperatorFactoryImpl::GetInferShapeRangeFunc(); - status = call_infer_shape_range(op, op_desc); - ASSERT_EQ(status, GRAPH_SUCCESS); -} - -// 实例化的optional input测试 -TEST_F(ShapeInferenceUT, CallInferV2Func_OptionalInputWithInstance) { - auto op = OperatorFactory::CreateOperator("test3", "OptionalInput3Input3Output"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - // input1 - GeShape shape1({1, 2, 3, 4}); - GeTensorDesc tensor_desc1(shape1, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc1.SetOriginShape(shape1); - tensor_desc1.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(0, tensor_desc1); - // input2 - GeShape shape2({4, 3, 2}); - GeTensorDesc tensor_desc2(shape2, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc2.SetOriginShape(shape2); - tensor_desc2.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(1, tensor_desc2); - // input3 - GeShape shape3({4, 3}); - GeTensorDesc tensor_desc3(shape3, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc3.SetOriginShape(shape3); - tensor_desc3.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(2, tensor_desc3); - const auto infer_shape_func = [](gert::InferShapeContext *context) -> graphStatus { - // update option input to output0 - const auto input_shape = context->GetOptionalInputShape(1U); - auto output = context->GetOutputShape(0); - for (size_t dim = 0UL; dim < input_shape->GetDimNum(); dim++) { - output->AppendDim(input_shape->GetDim(dim)); - } - output->SetDimNum(input_shape->GetDimNum()); - // update input3 to output2 - const auto input_shape2 = context->GetInputShape(2U); - auto output2 = context->GetOutputShape(1); - for (size_t dim = 0UL; dim < input_shape2->GetDimNum(); dim++) { - output2->AppendDim(input_shape2->GetDim(dim)); - } - output2->SetDimNum(input_shape2->GetDimNum()); - return GRAPH_SUCCESS; - }; - IMPL_OP(OptionalInput3Input3Output).InferShape(infer_shape_func) - .InferDataType(nullptr) - .InferShapeRange(nullptr); - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.infer_shape = infer_shape_func; - registry_holder->AddTypesToImpl("OptionalInput3Input3Output", op_impl_func); - space_registry->AddRegistry(registry_holder); - DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - const auto call_infer_shape_v2 = OperatorFactoryImpl::GetInferShapeV2Func(); - ASSERT_NE(call_infer_shape_v2, nullptr); - const auto status = call_infer_shape_v2(op, op_desc); - ASSERT_EQ(status, GRAPH_SUCCESS); - ASSERT_EQ(op_desc->GetOutputDesc(0U).GetShape().GetDimNum(), 3); - ASSERT_EQ(op_desc->GetOutputDesc(1U).GetShape().GetDimNum(), 2); -} - -REG_OP(TwoOptionalInputsOp) - .INPUT(input0, TensorType({DT_FLOAT16})) - .OPTIONAL_INPUT(input1, TensorType({DT_FLOAT16, DT_FLOAT})) - .OPTIONAL_INPUT(input2, TensorType({DT_FLOAT16, DT_FLOAT})) - .OUTPUT(output1, TensorType({DT_FLOAT16, DT_INT8)) - .OP_END_FACTORY_REG(TwoOptionalInputsOp); -// 只实例化第二个可选输入,预期第一个可选输入dtype为undefined,第二个可选输入为算子上的dtype -TEST_F(ShapeInferenceUT, CallInferV2Func_Rt2ConetxtGetRightDtype_JustInstantializeOptionalInputs) { - auto op = OperatorFactory::CreateOperator("test2", "TwoOptionalInputsOp"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - // input0 - GeShape shape0({1, 2, 3, 4}); - GeTensorDesc tensor_desc0(shape0, Format::FORMAT_NCHW, DT_FLOAT); - tensor_desc0.SetOriginShape(shape0); - tensor_desc0.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(0, tensor_desc0); - // optional input 1 - GeShape shape1({4, 3, 2}); - GeTensorDesc tensor_desc1(shape1, Format::FORMAT_RESERVED, DT_UNDEFINED); - tensor_desc1.SetOriginDataType(DT_UNDEFINED); - op_desc->UpdateInputDesc(1, tensor_desc1); - // optional input 2 - GeShape shape2({4, 3, 2}); - GeTensorDesc tensor_desc2(shape2, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc2.SetOriginShape(shape2); - tensor_desc2.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(2, tensor_desc2); - - const auto infer_datatype_func = [](gert::InferDataTypeContext *context) -> graphStatus { - const auto option_input_dtype_1 = context->GetOptionalInputDataType(1U); - if (option_input_dtype_1 != ge::DT_UNDEFINED) { - return GRAPH_FAILED; - } - const auto option_input_dtype_2 = context->GetOptionalInputDataType(2U); - if (option_input_dtype_2 != ge::DT_FLOAT16) { - return GRAPH_FAILED; - } - auto ret = context->SetOutputDataType(0U, option_input_dtype_2 == ge::DT_FLOAT16 ? ge::DT_INT8 : DT_FLOAT); - return ret; - }; - IMPL_OP(TwoOptionalInputsOp).InferShape(nullptr).InferDataType(infer_datatype_func).InferShapeRange(nullptr); - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.infer_datatype = infer_datatype_func; - registry_holder->AddTypesToImpl("TwoOptionalInputsOp", op_impl_func); - space_registry->AddRegistry(registry_holder); - DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - const auto call_infer_dtype_v2 = OperatorFactoryImpl::GetInferDataTypeFunc(); - ASSERT_NE(call_infer_dtype_v2, nullptr); - auto status = call_infer_dtype_v2(op_desc); - ASSERT_EQ(status, GRAPH_SUCCESS); - EXPECT_EQ(op_desc->GetOutputDesc(0U).GetDataType(), ge::DT_INT8); -} - -// 动态输入的input测试 -REG_OP(DynamicInput3Input3Output3) - .INPUT(input1, "T") - .DYNAMIC_INPUT(dyn_input, "D") - .INPUT(input3, "T") - .OUTPUT(output1, "T2") - .OUTPUT(output2, "T2") - .OUTPUT(output3, "T2") - .DATATYPE(T2, TensorType({DT_BOOL})) - .OP_END_FACTORY_REG(DynamicInput3Input3Output3); -const auto INFER_SHAPE_FUNC = [](gert::InferShapeContext *context) -> graphStatus { - // update input3 input to output0 - const auto input_shape = context->GetInputShape(1U); - auto output = context->GetOutputShape(0); - for (size_t dim = 0UL; dim < input_shape->GetDimNum(); dim++) { - output->AppendDim(input_shape->GetDim(dim)); - } - output->SetDimNum(input_shape->GetDimNum()); - // update dyn_input_0 to output1, dyn_input_1 to output2 - const auto input_shape2 = context->GetInputShape(2U); - auto output2 = context->GetOutputShape(1); - for (size_t dim = 0UL; dim < input_shape2->GetDimNum(); dim++) { - output2->AppendDim(input_shape2->GetDim(dim)); - } - output2->SetDimNum(input_shape2->GetDimNum()); - - const auto input_shape3 = context->GetInputShape(3U); - auto output3 = context->GetOutputShape(2); - for (size_t dim = 0UL; dim < input_shape3->GetDimNum(); dim++) { - output3->AppendDim(input_shape3->GetDim(dim)); - } - output3->SetDimNum(input_shape3->GetDimNum()); - return GRAPH_SUCCESS; -}; -TEST_F(ShapeInferenceUT, CallInferV2Func_DynamicInput) { - auto operator_dynamic = op::DynamicInput3Input3Output3("test4"); - operator_dynamic.create_dynamic_input_byindex_dyn_input(2, true); - auto op_desc = OpDescUtils::GetOpDescFromOperator(operator_dynamic); - ASSERT_NE(op_desc, nullptr); - ASSERT_EQ(op_desc->GetAllInputsSize(), 4); - // input1 - GeShape shape1({1, 2, 3, 4}); - GeTensorDesc tensor_desc1(shape1, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc1.SetOriginShape(shape1); - tensor_desc1.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(0, tensor_desc1); - // input3 - GeShape shape2({4, 3, 2}); - GeTensorDesc tensor_desc2(shape2, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc2.SetOriginShape(shape2); - tensor_desc2.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(1, tensor_desc2); - // dynamic input - GeShape shape3({4, 3}); - GeTensorDesc tensor_desc3(shape3, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc3.SetOriginShape(shape3); - tensor_desc3.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(2, tensor_desc3); - op_desc->UpdateInputDesc(3, tensor_desc1); - IMPL_OP(DynamicInput3Input3Output3).InferShape(INFER_SHAPE_FUNC) - .InferDataType(nullptr) - .InferShapeRange(nullptr); - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.infer_shape = INFER_SHAPE_FUNC; - op_impl_func.infer_shape_range = nullptr; - registry_holder->AddTypesToImpl("DynamicInput3Input3Output3", op_impl_func); - space_registry->AddRegistry(registry_holder); - DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - const auto call_infer_shape_v2 = OperatorFactoryImpl::GetInferShapeV2Func(); - ASSERT_NE(call_infer_shape_v2, nullptr); - auto status = call_infer_shape_v2(operator_dynamic, op_desc); - ASSERT_EQ(status, GRAPH_SUCCESS); - ASSERT_EQ(op_desc->GetOutputDesc(0U).GetShape().GetDimNum(), 3); - ASSERT_EQ(op_desc->GetOutputDesc(1U).GetShape().GetDimNum(), 2); - ASSERT_EQ(op_desc->GetOutputDesc(2U).GetShape().GetDimNum(), 4); - const auto call_infer_shape_range = OperatorFactoryImpl::GetInferShapeRangeFunc(); - status = call_infer_shape_range(operator_dynamic, op_desc); - ASSERT_EQ(status, GRAPH_SUCCESS); -} - -// 动态输入的input测试 动态轴-2 -TEST_F(ShapeInferenceUT, CallInferV2Func_DynamicInput_unknow_2) { - auto operator_dynamic = op::DynamicInput3Input3Output3("test4"); - operator_dynamic.create_dynamic_input_byindex_dyn_input(2, true); - auto op_desc = OpDescUtils::GetOpDescFromOperator(operator_dynamic); - ASSERT_NE(op_desc, nullptr); - ASSERT_EQ(op_desc->GetAllInputsSize(), 4); - // input1 - GeShape shape1({-2}); - GeTensorDesc tensor_desc1(shape1, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc1.SetOriginShape(shape1); - tensor_desc1.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(0, tensor_desc1); - // input3 - GeShape shape2({4, 3, 2}); - GeTensorDesc tensor_desc2(shape2, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc2.SetOriginShape(shape2); - tensor_desc2.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(1, tensor_desc2); - // dynamic input - GeShape shape3({4, 3}); - GeTensorDesc tensor_desc3(shape3, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc3.SetOriginShape(shape3); - tensor_desc3.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(2, tensor_desc3); - op_desc->UpdateInputDesc(3, tensor_desc1); - IMPL_OP(DynamicInput3Input3Output3).InferShape(INFER_SHAPE_FUNC) - .InferDataType(nullptr) - .InferShapeRange(nullptr); - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.infer_shape = INFER_SHAPE_FUNC; - op_impl_func.infer_shape_range = nullptr; - registry_holder->AddTypesToImpl("DynamicInput3Input3Output3", op_impl_func); - space_registry->AddRegistry(registry_holder); - DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - const auto call_infer_shape_v2 = OperatorFactoryImpl::GetInferShapeV2Func(); - ASSERT_NE(call_infer_shape_v2, nullptr); - auto status = call_infer_shape_v2(operator_dynamic, op_desc); - ASSERT_EQ(status, GRAPH_SUCCESS); - ASSERT_EQ(op_desc->GetOutputDesc(0U).GetShape().GetDimNum(), 3); - ASSERT_EQ(op_desc->GetOutputDesc(1U).GetShape().GetDimNum(), 2); - ASSERT_EQ(op_desc->GetOutputDesc(2U).GetShape().GetDimNum(), 0); - const auto call_infer_shape_range = OperatorFactoryImpl::GetInferShapeRangeFunc(); - status = call_infer_shape_range(operator_dynamic, op_desc); - ASSERT_EQ(status, GRAPH_SUCCESS); -} - -// 动态输入的input测试 动态轴-1, shape range 不设值 -TEST_F(ShapeInferenceUT, CallInferV2Func_DynamicInput_unknow_no_shaperange) { - auto operator_dynamic = op::DynamicInput3Input3Output3("test4"); - operator_dynamic.create_dynamic_input_byindex_dyn_input(2, true); - auto op_desc = OpDescUtils::GetOpDescFromOperator(operator_dynamic); - ASSERT_NE(op_desc, nullptr); - ASSERT_EQ(op_desc->GetAllInputsSize(), 4); - // input1 - GeShape shape1({1, 2, 3, -1}); - GeTensorDesc tensor_desc1(shape1, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc1.SetOriginShape(shape1); - tensor_desc1.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(0, tensor_desc1); - // input3 - GeShape shape2({4, 3, 2}); - GeTensorDesc tensor_desc2(shape2, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc2.SetOriginShape(shape2); - tensor_desc2.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(1, tensor_desc2); - // dynamic input - GeShape shape3({4, 3}); - GeTensorDesc tensor_desc3(shape3, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc3.SetOriginShape(shape3); - tensor_desc3.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(2, tensor_desc3); - op_desc->UpdateInputDesc(3, tensor_desc1); - IMPL_OP(DynamicInput3Input3Output3).InferShape(INFER_SHAPE_FUNC) - .InferDataType(nullptr) - .InferShapeRange(nullptr); - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.infer_shape = INFER_SHAPE_FUNC; - op_impl_func.infer_shape_range = nullptr; - registry_holder->AddTypesToImpl("DynamicInput3Input3Output3", op_impl_func); - space_registry->AddRegistry(registry_holder); - DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - const auto call_infer_shape_v2 = OperatorFactoryImpl::GetInferShapeV2Func(); - ASSERT_NE(call_infer_shape_v2, nullptr); - auto status = call_infer_shape_v2(operator_dynamic, op_desc); - ASSERT_EQ(status, GRAPH_SUCCESS); - ASSERT_EQ(op_desc->GetOutputDesc(0U).GetShape().GetDimNum(), 3); - ASSERT_EQ(op_desc->GetOutputDesc(1U).GetShape().GetDimNum(), 2); - ASSERT_EQ(op_desc->GetOutputDesc(2U).GetShape().GetDimNum(), 4); - const auto call_infer_shape_range = OperatorFactoryImpl::GetInferShapeRangeFunc(); - status = call_infer_shape_range(operator_dynamic, op_desc); - ASSERT_EQ(status, GRAPH_SUCCESS); -} - -// 动态输入的input测试 动态轴-1, shape range 设值 -TEST_F(ShapeInferenceUT, CallInferV2Func_DynamicInput_unknow_shaperange) { - auto operator_dynamic = op::DynamicInput3Input3Output3("test4"); - operator_dynamic.create_dynamic_input_byindex_dyn_input(2, true); - auto op_desc = OpDescUtils::GetOpDescFromOperator(operator_dynamic); - ASSERT_NE(op_desc, nullptr); - ASSERT_EQ(op_desc->GetAllInputsSize(), 4); - // input1 - GeShape shape1({1, 2, 3, -1}); - GeTensorDesc tensor_desc1(shape1, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc1.SetOriginShape(shape1); - tensor_desc1.SetOriginDataType(DT_FLOAT16); - std::vector> range = {{1, 1}, {2, 2}, {3, 3}, {22, 999}}; - tensor_desc1.SetShapeRange(range); - op_desc->UpdateInputDesc(0, tensor_desc1); - // input3 - GeShape shape2({4, 3, 2}); - GeTensorDesc tensor_desc2(shape2, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc2.SetOriginShape(shape2); - tensor_desc2.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(1, tensor_desc2); - // dynamic input - GeShape shape3({4, 3}); - GeTensorDesc tensor_desc3(shape3, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc3.SetOriginShape(shape3); - tensor_desc3.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(2, tensor_desc3); - op_desc->UpdateInputDesc(3, tensor_desc1); - IMPL_OP(DynamicInput3Input3Output3).InferShape(INFER_SHAPE_FUNC) - .InferDataType(nullptr) - .InferShapeRange(nullptr); - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.infer_shape = INFER_SHAPE_FUNC; - op_impl_func.infer_shape_range = nullptr; - registry_holder->AddTypesToImpl("DynamicInput3Input3Output3", op_impl_func); - space_registry->AddRegistry(registry_holder); - DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - const auto call_infer_shape_v2 = OperatorFactoryImpl::GetInferShapeV2Func(); - ASSERT_NE(call_infer_shape_v2, nullptr); - auto status = call_infer_shape_v2(operator_dynamic, op_desc); - ASSERT_EQ(status, GRAPH_SUCCESS); - ASSERT_EQ(op_desc->GetOutputDesc(0U).GetShape().GetDimNum(), 3); - ASSERT_EQ(op_desc->GetOutputDesc(1U).GetShape().GetDimNum(), 2); - ASSERT_EQ(op_desc->GetOutputDesc(2U).GetShape().GetDimNum(), 4); - const auto call_infer_shape_range = OperatorFactoryImpl::GetInferShapeRangeFunc(); - status = call_infer_shape_range(operator_dynamic, op_desc); - ASSERT_EQ(status, GRAPH_SUCCESS); - std::vector> shape_range; - (void)op_desc->GetOutputDesc(2U).GetShapeRange(shape_range); - ASSERT_EQ(shape_range.size(), 4U); - for (size_t i = 0UL; i < shape_range.size(); ++i) { - ASSERT_EQ(shape_range[i].first, range[i].first); - ASSERT_EQ(shape_range[i].second, range[i].second); - } -} - -// 动态输入的input测试 动态轴-1, shape range 设值,min大于max异常场景 -TEST_F(ShapeInferenceUT, CallInferV2Func_DynamicInput_unknow_shaperange_min_bigger_max) { - auto operator_dynamic = op::DynamicInput3Input3Output3("test4"); - operator_dynamic.create_dynamic_input_byindex_dyn_input(2, true); - auto op_desc = OpDescUtils::GetOpDescFromOperator(operator_dynamic); - ASSERT_NE(op_desc, nullptr); - ASSERT_EQ(op_desc->GetAllInputsSize(), 4); - // input1 - GeShape shape1({1, 2, 3, -1}); - GeTensorDesc tensor_desc1(shape1, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc1.SetOriginShape(shape1); - tensor_desc1.SetOriginDataType(DT_FLOAT16); - std::vector> range = {{1, 1}, {2, 2}, {3, 3}, {999, 22}}; - tensor_desc1.SetShapeRange(range); - op_desc->UpdateInputDesc(0, tensor_desc1); - // input3 - GeShape shape2({4, 3, 2}); - GeTensorDesc tensor_desc2(shape2, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc2.SetOriginShape(shape2); - tensor_desc2.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(1, tensor_desc2); - // dynamic input - GeShape shape3({4, 3}); - GeTensorDesc tensor_desc3(shape3, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc3.SetOriginShape(shape3); - tensor_desc3.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(2, tensor_desc3); - op_desc->UpdateInputDesc(3, tensor_desc1); - IMPL_OP(DynamicInput3Input3Output3).InferShape(INFER_SHAPE_FUNC) - .InferDataType(nullptr) - .InferShapeRange(nullptr); - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.infer_shape = INFER_SHAPE_FUNC; - op_impl_func.infer_shape_range = nullptr; - registry_holder->AddTypesToImpl("DynamicInput3Input3Output3", op_impl_func); - space_registry->AddRegistry(registry_holder); - DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - const auto call_infer_shape_v2 = OperatorFactoryImpl::GetInferShapeV2Func(); - ASSERT_NE(call_infer_shape_v2, nullptr); - auto status = call_infer_shape_v2(operator_dynamic, op_desc); - ASSERT_EQ(status, GRAPH_SUCCESS); - ASSERT_EQ(op_desc->GetOutputDesc(0U).GetShape().GetDimNum(), 3); - ASSERT_EQ(op_desc->GetOutputDesc(1U).GetShape().GetDimNum(), 2); - ASSERT_EQ(op_desc->GetOutputDesc(2U).GetShape().GetDimNum(), 4); - const auto call_infer_shape_range = OperatorFactoryImpl::GetInferShapeRangeFunc(); - status = call_infer_shape_range(operator_dynamic, op_desc); - ASSERT_EQ(status, ge::PARAM_INVALID); - std::vector> shape_range; - (void)op_desc->GetOutputDesc(2U).GetShapeRange(shape_range); - ASSERT_EQ(shape_range.size(), 0U); -} - -// 动态输入的input测试 动态轴-1, shape range 设值, min大于max, max为-1的正常场景 -TEST_F(ShapeInferenceUT, CallInferV2Func_DynamicInput_unknow_shaperange_min_bigger_max_success) { - auto operator_dynamic = op::DynamicInput3Input3Output3("test4"); - operator_dynamic.create_dynamic_input_byindex_dyn_input(2, true); - auto op_desc = OpDescUtils::GetOpDescFromOperator(operator_dynamic); - ASSERT_NE(op_desc, nullptr); - ASSERT_EQ(op_desc->GetAllInputsSize(), 4); - // input1 - GeShape shape1({1, 2, 3, -1}); - GeTensorDesc tensor_desc1(shape1, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc1.SetOriginShape(shape1); - tensor_desc1.SetOriginDataType(DT_FLOAT16); - std::vector> range = {{1, -1}, {2, -1}, {3, -1}, {999, -1}}; - tensor_desc1.SetShapeRange(range); - op_desc->UpdateInputDesc(0, tensor_desc1); - // input3 - GeShape shape2({4, 3, 2}); - GeTensorDesc tensor_desc2(shape2, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc2.SetOriginShape(shape2); - tensor_desc2.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(1, tensor_desc2); - // dynamic input - GeShape shape3({4, 3}); - GeTensorDesc tensor_desc3(shape3, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc3.SetOriginShape(shape3); - tensor_desc3.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(2, tensor_desc3); - op_desc->UpdateInputDesc(3, tensor_desc1); - IMPL_OP(DynamicInput3Input3Output3).InferShape(INFER_SHAPE_FUNC) - .InferDataType(nullptr) - .InferShapeRange(nullptr); - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.infer_shape = INFER_SHAPE_FUNC; - op_impl_func.infer_shape_range = nullptr; - registry_holder->AddTypesToImpl("DynamicInput3Input3Output3", op_impl_func); - space_registry->AddRegistry(registry_holder); - DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - const auto call_infer_shape_v2 = OperatorFactoryImpl::GetInferShapeV2Func(); - ASSERT_NE(call_infer_shape_v2, nullptr); - auto status = call_infer_shape_v2(operator_dynamic, op_desc); - ASSERT_EQ(status, GRAPH_SUCCESS); - ASSERT_EQ(op_desc->GetOutputDesc(0U).GetShape().GetDimNum(), 3); - ASSERT_EQ(op_desc->GetOutputDesc(1U).GetShape().GetDimNum(), 2); - ASSERT_EQ(op_desc->GetOutputDesc(2U).GetShape().GetDimNum(), 4); - const auto call_infer_shape_range = OperatorFactoryImpl::GetInferShapeRangeFunc(); - status = call_infer_shape_range(operator_dynamic, op_desc); - ASSERT_EQ(status, ge::GRAPH_SUCCESS); -} - -// 二类算子值依赖测试 -REG_OP(Type2_1Input_1Output) - .INPUT(input1, "T") - .OPTIONAL_INPUT(input2, "T") - .INPUT(input3, "T") - .OUTPUT(output1, "T2") - .DATATYPE(T2, TensorType({DT_BOOL})) - .OP_END_FACTORY_REG(Type2_1Input_1Output); -TEST_F(ShapeInferenceUT, CallInferV2Func_Type2ValueDepend) { - // construct const input - auto const_input = ge::op::Const("const_input"); - ge::TensorDesc td{ge::Shape(std::vector({1, 2, 3, 4})), FORMAT_NCHW, DT_UINT8}; - ge::Tensor tensor(td); - std::vector val{0x55, 0x56, 0x57, 0x58, 0x58, 0x58, - 0x55, 0x56, 0x57, 0x58, 0x58, 0x58, - 0x55, 0x56, 0x57, 0x58, 0x58, 0x58, - 0x55, 0x56, 0x57, 0x58, 0x58, 0x58, - 0x55, 0x56, 0x57, 0x58, 0x58, 0x58}; - tensor.SetData(val); - const_input.set_attr_value(tensor); - // const input link to op - auto op = op::Type2_1Input_1Output("test5"); - op.set_input_input1(const_input); - op.set_input_input3(const_input); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - ASSERT_EQ(op_desc->GetAllInputsSize(), 3); - // input1 - ge::GeShape shape1({1, 2, 3, 5}); - ge::GeTensorDesc tensor_desc1(shape1, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc1.SetOriginShape(shape1); - tensor_desc1.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(0, tensor_desc1); - op_desc->UpdateInputDesc(2, tensor_desc1); - const auto infer_shape_func = [](gert::InferShapeContext *context) -> graphStatus { - // update input3(因为option输入未实例化,所以是第二个) value to output0 - const auto data = context->GetInputTensor(1U)->GetData(); - std::vector dims = {data[0], data[1], data[2], data[3]}; - ge::Shape input_shape(dims); - auto output = context->GetOutputShape(0); - for (size_t dim = 0UL; dim < input_shape.GetDimNum(); dim++) { - output->AppendDim(input_shape.GetDim(dim)); - } - output->SetDimNum(input_shape.GetDimNum()); - return GRAPH_SUCCESS; - }; - IMPL_OP(Type2_1Input_1Output).InferShape(infer_shape_func).InputsDataDependency({2}) - .InferDataType(nullptr) - .InferShapeRange(nullptr); - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.infer_shape = infer_shape_func; - op_impl_func.SetInputDataDependency(2); - registry_holder->AddTypesToImpl("Type2_1Input_1Output", op_impl_func); - space_registry->AddRegistry(registry_holder); - DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - const auto call_infer_shape_v2 = OperatorFactoryImpl::GetInferShapeV2Func(); - ASSERT_NE(call_infer_shape_v2, nullptr); - const auto status = call_infer_shape_v2(op, op_desc); - ASSERT_EQ(status, GRAPH_SUCCESS); - const auto &shape = op_desc->GetOutputDesc(0U).GetShape(); - ASSERT_EQ(shape.GetDimNum(), 4); - ASSERT_EQ(shape.GetDim(0U), 85); - ASSERT_EQ(shape.GetDim(1U), 86); - ASSERT_EQ(shape.GetDim(2U), 87); - ASSERT_EQ(shape.GetDim(3U), 88); -} - -// 二类算子值依赖测试,带shape range -REG_OP(Type2_3Input_2Output) - .INPUT(input1, "T") - .OPTIONAL_INPUT(input2, "T") - .INPUT(input3, "T") - .OUTPUT(output1, "T2") - .OUTPUT(output2, "T2") - .DATATYPE(T2, TensorType({DT_BOOL})) - .OP_END_FACTORY_REG(Type2_3Input_2Output); -TEST_F(ShapeInferenceUT, CallInferV2Func_Type2ValueDepend_unknow_shaperange) { - // construct const input - auto const_input = ge::op::Const("const_input"); - ge::TensorDesc td{ge::Shape(std::vector({1, 2, 3, 4})), FORMAT_NCHW, DT_UINT8}; - ge::Tensor tensor(td); - std::vector val{0x55, 0x56, 0x57, 0x58, 0x58, 0x58, - 0x55, 0x56, 0x57, 0x58, 0x58, 0x58, - 0x55, 0x56, 0x57, 0x58, 0x58, 0x58, - 0x55, 0x56, 0x57, 0x58, 0x58, 0x58, - 0x55, 0x56, 0x57, 0x58, 0x58, 0x58}; - tensor.SetData(val); - const_input.set_attr_value(tensor); - // const input link to op - auto op = op::Type2_3Input_2Output("test5"); - op.set_input_input1(const_input); - op.set_input_input3(const_input); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - ASSERT_EQ(op_desc->GetAllInputsSize(), 3); - // input1 - GeShape shape1({1, 2, 3, -1}); - GeTensorDesc tensor_desc1(shape1, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc1.SetOriginShape(shape1); - tensor_desc1.SetOriginDataType(DT_FLOAT16); - std::vector> range = {{1, 1}, {2, 2}, {3, 3}, {22, 999}}; - tensor_desc1.SetShapeRange(range); - op_desc->UpdateInputDesc(0, tensor_desc1); - // input3 - ge::GeShape shape3({1, 2, 3, 5}); - ge::GeTensorDesc tensor_desc3(shape3, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc3.SetOriginShape(shape3); - tensor_desc3.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(2, tensor_desc3); - const auto infer_shape_func = [](gert::InferShapeContext *context) -> graphStatus { - // update input3(因为option输入未实例化,所以是第二个) value to output0 - const auto data = context->GetInputTensor(1U)->GetData(); - std::vector dims = {data[0], data[1], data[2], data[3]}; - ge::Shape input_shape(dims); - auto output = context->GetOutputShape(0); - for (size_t dim = 0UL; dim < input_shape.GetDimNum(); dim++) { - output->AppendDim(input_shape.GetDim(dim)); - } - output->SetDimNum(input_shape.GetDimNum()); - - const auto input_shape1 = context->GetInputShape(0U); - auto output1 = context->GetOutputShape(1); - for (size_t dim = 0UL; dim < input_shape1->GetDimNum(); dim++) { - output1->AppendDim(input_shape1->GetDim(dim)); - } - output1->SetDimNum(input_shape1->GetDimNum()); - return GRAPH_SUCCESS; - }; - const auto infer_shape_range_func = [](gert::InferShapeRangeContext *context) -> graphStatus { - auto input_shape_range = context->GetInputShapeRange(0U); - auto output_shape_range = context->GetOutputShapeRange(0U); - output_shape_range->SetMin(const_cast(input_shape_range->GetMin())); - output_shape_range->SetMax(const_cast(input_shape_range->GetMax())); - return GRAPH_SUCCESS; - }; - IMPL_OP(Type2_3Input_2Output).InferShape(infer_shape_func).InputsDataDependency({2}) - .InferDataType(nullptr) - .InferShapeRange(infer_shape_range_func); - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.infer_shape = infer_shape_func; - op_impl_func.SetInputDataDependency(2); - op_impl_func.infer_shape_range = infer_shape_range_func; - registry_holder->AddTypesToImpl("Type2_3Input_2Output", op_impl_func); - space_registry->AddRegistry(registry_holder); - DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - const auto call_infer_shape_v2 = OperatorFactoryImpl::GetInferShapeV2Func(); - ASSERT_NE(call_infer_shape_v2, nullptr); - auto status = call_infer_shape_v2(op, op_desc); - ASSERT_EQ(status, GRAPH_SUCCESS); - const auto &shape = op_desc->GetOutputDesc(0U).GetShape(); - ASSERT_EQ(shape.GetDimNum(), 4); - ASSERT_EQ(shape.GetDim(0U), 85); - ASSERT_EQ(shape.GetDim(1U), 86); - ASSERT_EQ(shape.GetDim(2U), 87); - ASSERT_EQ(shape.GetDim(3U), 88); - const auto call_infer_shape_range = OperatorFactoryImpl::GetInferShapeRangeFunc(); - status = call_infer_shape_range(op, op_desc); - ASSERT_EQ(status, GRAPH_SUCCESS); - std::vector> shape_range; - (void)op_desc->GetOutputDesc(0U).GetShapeRange(shape_range); - ASSERT_EQ(shape_range.size(), 4U); - for (size_t i = 0UL; i < shape_range.size(); ++i) { - ASSERT_EQ(shape_range[i].first, range[i].first); - ASSERT_EQ(shape_range[i].second, range[i].second); - } -} -TEST_F(ShapeInferenceUT, CallInferV2Func_skip_shaperange_infer_when_input_without_range) { - // construct const input - auto const_input = ge::op::Const("const_input"); - ge::TensorDesc td{ge::Shape(std::vector({1, 2, 3, 4})), FORMAT_NCHW, DT_UINT8}; - ge::Tensor tensor(td); - std::vector val{0x55, 0x56, 0x57, 0x58, 0x58, 0x58, - 0x55, 0x56, 0x57, 0x58, 0x58, 0x58, - 0x55, 0x56, 0x57, 0x58, 0x58, 0x58, - 0x55, 0x56, 0x57, 0x58, 0x58, 0x58, - 0x55, 0x56, 0x57, 0x58, 0x58, 0x58}; - tensor.SetData(val); - const_input.set_attr_value(tensor); - // const input link to op - auto op = op::Type2_3Input_2Output("test5"); - op.set_input_input1(const_input); - op.set_input_input3(const_input); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - ASSERT_EQ(op_desc->GetAllInputsSize(), 3); - // input1 - GeShape shape1({1, 2, 3, -1}); - GeTensorDesc tensor_desc1(shape1, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc1.SetOriginShape(shape1); - tensor_desc1.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(0, tensor_desc1); - // input3 - ge::GeShape shape3({1, 2, 3, 5}); - ge::GeTensorDesc tensor_desc3(shape3, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc3.SetOriginShape(shape3); - tensor_desc3.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(2, tensor_desc3); - - const auto infer_shape_range_func = [](gert::InferShapeRangeContext *context) -> graphStatus { return GRAPH_FAILED; }; - IMPL_OP(Type2_3Input_2Output).InputsDataDependency({2}) - .InferDataType(nullptr) - .InferShapeRange(infer_shape_range_func); - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.SetInputDataDependency(2); - op_impl_func.infer_shape_range = infer_shape_range_func; - registry_holder->AddTypesToImpl("Type2_3Input_2Output", op_impl_func); - space_registry->AddRegistry(registry_holder); - DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - auto status = InferShapeRangeOnCompile(op, op_desc); - ASSERT_EQ(status, GRAPH_SUCCESS); // success means not called infer_shape_range_func -} - -// 资源类算子测试 -REG_OP(RegisterAndGetReiledOnResource) - .INPUT(input1, "T") - .OUTPUT(output1, "T2") - .DATATYPE(T2, TensorType({DT_BOOL})) - .OP_END_FACTORY_REG(RegisterAndGetReiledOnResource); -TEST_F(ShapeInferenceUT, CallInferV2Func_RegisterAndGetReiledOnResource) { - auto op = OperatorFactory::CreateOperator("test6", "RegisterAndGetReiledOnResource"); - const char_t *resource_key = "224"; - auto read_inference_context = std::shared_ptr(InferenceContext::Create()); - read_inference_context->RegisterReliedOnResourceKey(AscendString(resource_key)); - op.SetInferenceContext(read_inference_context); - - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); // simulate read_op register relied resource - // input1 - ge::GeShape shape1({1, 2, 3, 5}); - ge::GeTensorDesc tensor_desc1(shape1, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc1.SetOriginShape(shape1); - tensor_desc1.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(0, tensor_desc1); - const auto infer_shape_func = [](gert::InferShapeContext *context) -> graphStatus { - auto ct_context = reinterpret_cast(context); - const auto &read_inference_context = ct_context->GetInferenceContext(); - const auto &reiled_keys = read_inference_context->GetReliedOnResourceKeys(); - const char_t *resource_key_ = "224"; - // check result - EXPECT_EQ(reiled_keys.empty(), false); - EXPECT_EQ(*reiled_keys.begin(), resource_key_); - if (reiled_keys.empty() || - (*reiled_keys.begin() != resource_key_)) { - return GRAPH_FAILED; - } - auto out_shape = context->GetOutputShape(0UL); - out_shape->SetDimNum(1UL); - out_shape->SetDim(0UL, std::strtol(resource_key_, nullptr, 10)); - return GRAPH_SUCCESS; - }; - const auto infer_shape_range_func = [](gert::InferShapeRangeContext *context) -> graphStatus { - auto ct_context = reinterpret_cast(context); - const auto &read_inference_context = ct_context->GetInferenceContext(); - const auto &reiled_keys = read_inference_context->GetReliedOnResourceKeys(); - const char_t *resource_key_ = "224"; - // check result - EXPECT_EQ(reiled_keys.empty(), false); - EXPECT_EQ(*reiled_keys.begin(), resource_key_); - return GRAPH_SUCCESS; - }; - IMPL_OP(RegisterAndGetReiledOnResource) - .InferShape(infer_shape_func) - .InferDataType(nullptr) - .InferShapeRange(infer_shape_range_func); - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.infer_shape = infer_shape_func; - registry_holder->AddTypesToImpl("RegisterAndGetReiledOnResource", op_impl_func); - space_registry->AddRegistry(registry_holder); - DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - const auto call_infer_shape_v2 = OperatorFactoryImpl::GetInferShapeV2Func(); - ASSERT_NE(call_infer_shape_v2, nullptr); - const auto status = call_infer_shape_v2(op, op_desc); - ASSERT_EQ(status, GRAPH_SUCCESS); - const auto &shape = op_desc->GetOutputDesc(0U).GetShape(); - ASSERT_EQ(shape.GetDim(0), std::strtol(resource_key, nullptr, 10)); -} - -// 默认infer datatype测试 -REG_OP(TestDefaultInferDataType) - .INPUT(input1, "T") - .OUTPUT(output1, "T") - .DATATYPE(T, TensorType({DT_BOOL})) - .OP_END_FACTORY_REG(TestDefaultInferDataType); -TEST_F(ShapeInferenceUT, CallInferV2Func_TestDefaultInferShape) { - auto op = OperatorFactory::CreateOperator("test7", "TestDefaultInferDataType"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); // simulate read_op register relied resource - // input1 - ge::GeShape shape1({1, 2, 3, 5}); - ge::GeTensorDesc tensor_desc1(shape1, Format::FORMAT_NCHW, DT_BOOL); - tensor_desc1.SetOriginShape(shape1); - tensor_desc1.SetOriginDataType(DT_BOOL); - op_desc->UpdateInputDesc(0, tensor_desc1); - const auto infer_shape_func = [](gert::InferShapeContext *context) -> graphStatus { - return GRAPH_SUCCESS; - }; - IMPL_OP(TestDefaultInferDataType) - .InferShape(infer_shape_func) - .InferDataType(nullptr) - .InferShapeRange(nullptr); - const auto call_infer_data_type = OperatorFactoryImpl::GetInferDataTypeFunc(); - const auto status = call_infer_data_type(op_desc); - ASSERT_EQ(status, GRAPH_SUCCESS); - const auto &data_type = op_desc->GetOutputDesc(0U).GetDataType(); - ASSERT_EQ(data_type, DT_BOOL); -} -TEST_F(ShapeInferenceUT, AdaptFuncRegisterOk) { - ASSERT_NE(OperatorFactoryImpl::GetInferShapeV2Func(), nullptr); - ASSERT_NE(OperatorFactoryImpl::GetInferShapeRangeFunc(), nullptr); - ASSERT_NE(OperatorFactoryImpl::GetInferDataTypeFunc(), nullptr); -} -TEST_F(ShapeInferenceUT, CallInferV2Func_no_inferfunc_failed_FAST_IGNORE_INFER_ERRIR_is_empty) { - auto op = OperatorFactory::CreateOperator("test1", "FixIOOp_OutputIsFix"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - GeShape shape({1, 1, 1, 1}); - GeTensorDesc tensor_desc(shape, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc.SetOriginShape(shape); - tensor_desc.SetOriginDataType(DT_FLOAT16); - std::vector> range = {{0, 10000}}; - tensor_desc.SetOriginShapeRange(range); - op_desc->UpdateInputDesc(0, tensor_desc); - op_desc->UpdateInputDesc(1, tensor_desc); - op_desc->impl_->infer_func_ = nullptr; // make v1 is null - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.infer_shape = nullptr; // make v2 is null - op_impl_func.infer_datatype = nullptr; - op_impl_func.infer_shape_range = nullptr; - registry_holder->AddTypesToImpl("FixIOOp_OutputIsFix", op_impl_func); - space_registry->AddRegistry(registry_holder); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - mmSetEnv("IGNORE_INFER_ERROR", "111", 1); - auto status = OpDescUtilsEx::CallInferFunc(op_desc, op); - ASSERT_EQ(status, GRAPH_PARAM_INVALID); - unsetenv("IGNORE_INFER_ERROR"); -} - -TEST_F(ShapeInferenceUT, CallInferV2Func_NoSpaceRegistry_NotSupportSymbolicInfer_failed) { - auto op = OperatorFactory::CreateOperator("test1", "AddUt"); // not support symbolic infer dtype - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - - ASSERT_NE(op_desc, nullptr); - GeShape shape({1, 1, 1, 1}); - GeTensorDesc tensor_desc(shape, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc.SetOriginShape(shape); - tensor_desc.SetOriginDataType(DT_FLOAT16); - std::vector> range = {{0, 10000}}; - tensor_desc.SetOriginShapeRange(range); - op_desc->UpdateInputDesc(0, tensor_desc); - op_desc->UpdateInputDesc(1, tensor_desc); - op_desc->impl_->infer_func_ = nullptr; // make v1 is null - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(nullptr); - auto status = OpDescUtilsEx::CallInferFunc(op_desc, op); - ASSERT_EQ(status, ge::GRAPH_PARAM_INVALID); -} - -TEST_F(ShapeInferenceUT, CallInferV2Func_NoSpaceRegistry_SupportSymbolicInfer_success) { - auto op = OperatorFactory::CreateOperator("test1", "FixIOOp_OutputIsFix"); // support symbolic infer dtype - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - - ASSERT_NE(op_desc, nullptr); - GeShape shape({1, 1, 1, 1}); - GeTensorDesc tensor_desc(shape, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc.SetOriginShape(shape); - tensor_desc.SetOriginDataType(DT_FLOAT16); - std::vector> range = {{0, 10000}}; - tensor_desc.SetOriginShapeRange(range); - op_desc->UpdateInputDesc(0, tensor_desc); - op_desc->UpdateInputDesc(1, tensor_desc); - op_desc->impl_->infer_func_ = nullptr; // make v1 is null - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(nullptr); - auto status = OpDescUtilsEx::CallInferFunc(op_desc, op); - ASSERT_EQ(status, ge::GRAPH_PARAM_INVALID); -} - -TEST_F(ShapeInferenceUT, CallInferV2Func_NoCustomInferdtype_NotSupportSymbolicInfer_failed) { - auto op = OperatorFactory::CreateOperator("test1", "AddUt"); // not support symbolic infer dtype - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - - ASSERT_NE(op_desc, nullptr); - GeShape shape({1,1,1,1}); - GeTensorDesc tensor_desc(shape, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc.SetOriginShape(shape); - tensor_desc.SetOriginDataType(DT_FLOAT16); - std::vector> range = {{0, 10000}}; - tensor_desc.SetOriginShapeRange(range); - op_desc->UpdateInputDesc(0, tensor_desc); - op_desc->UpdateInputDesc(1, tensor_desc); - op_desc->impl_->infer_func_ = nullptr; // make v1 is null - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.infer_shape = nullptr; // make v2 is null - op_impl_func.infer_datatype = nullptr; // make v2 infer dtype is null - op_impl_func.infer_shape_range = nullptr; - registry_holder->AddTypesToImpl("AddUt", op_impl_func); - space_registry->AddRegistry(registry_holder); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - auto status = OpDescUtilsEx::CallInferFunc(op_desc, op); - ASSERT_EQ(status, ge::GRAPH_PARAM_INVALID); -} - -TEST_F(ShapeInferenceUT, CallInferV2Func_NoInferShape_failed) { - auto op = OperatorFactory::CreateOperator("test1", "FixIOOp_OutputIsFix"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - - ASSERT_NE(op_desc, nullptr); - GeShape shape({1,1,1,1}); - GeTensorDesc tensor_desc(shape, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc.SetOriginShape(shape); - tensor_desc.SetOriginDataType(DT_FLOAT16); - std::vector> range = {{0, 10000}}; - tensor_desc.SetOriginShapeRange(range); - op_desc->UpdateInputDesc(0, tensor_desc); - op_desc->UpdateInputDesc(1, tensor_desc); - op_desc->impl_->infer_func_ = nullptr; // make v1 is null - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctions op_impl_func; - op_impl_func.infer_shape = nullptr; // make v2 is null - op_impl_func.infer_datatype = nullptr; // make v2 infer dtype is null - op_impl_func.infer_shape_range = nullptr; - registry_holder->AddTypesToImpl("AddUt", op_impl_func); - space_registry->AddRegistry(registry_holder); - gert::DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - auto status = InferShapeOnCompile(op, op_desc); - ASSERT_NE(status, ge::GRAPH_SUCCESS); -} - -TEST_F(ShapeInferenceUT, CallInferFormatFunc_OptionalInput) { - auto op = OperatorFactory::CreateOperator("OptionalTest", "OptionalInput3Input3Output"); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - EXPECT_NE(op_desc, nullptr); - // option_input - GeShape shape({1, 1, 1, 1}); - const auto format = Format::FORMAT_NCHW; - const auto ori_format = Format::FORMAT_ND; - GeTensorDesc tensor_desc(shape, format); - tensor_desc.SetOriginFormat(ori_format); - tensor_desc.SetOriginShape(shape); - op_desc->UpdateInputDesc(1U, tensor_desc); - - auto infer_format_func = [](gert::InferFormatContext *context) -> UINT32 { - const auto option_input_format = context->GetOptionalInputFormat(1U); - EXPECT_NE(option_input_format, nullptr); - EXPECT_EQ(option_input_format->GetOriginFormat(), Format::FORMAT_ND); - EXPECT_EQ(option_input_format->GetStorageFormat(), Format::FORMAT_NCHW); - - const auto option_input_shape = context->GetOptionalInputShape(1U); - EXPECT_NE(option_input_shape, nullptr); - EXPECT_EQ(option_input_shape->GetDimNum(), 4UL); - EXPECT_EQ(option_input_shape->GetDim(0U), 1L); - - auto input_0 = context->GetRequiredInputFormat(0U); - input_0->SetOriginFormat(Format::FORMAT_NCHW); - input_0->SetStorageFormat(Format::FORMAT_NCHW); - - auto output_0 = context->GetOutputFormat(0U); - output_0->SetOriginFormat(Format::FORMAT_NCHW); - output_0->SetStorageFormat(Format::FORMAT_NCHW); - return GRAPH_SUCCESS; - }; - IMPL_OP(OptionalInput3Input3Output).InferFormat(infer_format_func); - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctionsV2 op_impl_func; - op_impl_func.infer_format_func = infer_format_func; - registry_holder->AddTypesToImpl("OptionalInput3Input3Output", op_impl_func); - space_registry->AddRegistry(registry_holder); - DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - const auto call_infer_format_v2 = OperatorFactoryImpl::GetInferFormatV2Func(); - EXPECT_NE(call_infer_format_v2, nullptr); - ASSERT_EQ(call_infer_format_v2(op, op_desc), GRAPH_SUCCESS); - - const auto &input_0 = op_desc->GetInputDesc(0U); - EXPECT_EQ(input_0.GetOriginFormat(), Format::FORMAT_NCHW); - EXPECT_EQ(input_0.GetFormat(), Format::FORMAT_NCHW); - - const auto &output_0 = op_desc->GetOutputDesc(0U); - EXPECT_EQ(output_0.GetOriginFormat(), Format::FORMAT_NCHW); - EXPECT_EQ(output_0.GetFormat(), Format::FORMAT_NCHW); -} - -TEST_F(ShapeInferenceUT, CallInferFormatFunc_DynamicInput) { - auto op = op::DynamicInput3Input3Output3("DynamicTest"); - op.create_dynamic_input_byindex_dyn_input(2, 1); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - ASSERT_EQ(op_desc->GetAllInputsSize(), 4); - - // dynamic_input index=1 - GeShape shape({1, 1, 1, 1}); - const auto format = Format::FORMAT_NCHW; - const auto ori_format = Format::FORMAT_ND; - GeTensorDesc tensor_desc(shape, format); - tensor_desc.SetOriginFormat(ori_format); - tensor_desc.SetOriginShape(shape); - op_desc->UpdateInputDesc(1U, tensor_desc); - - auto infer_format_func = [](gert::InferFormatContext *context) -> UINT32 { - const auto dynamic_input_format = context->GetDynamicInputFormat(1U, 0U); - EXPECT_NE(dynamic_input_format, nullptr); - EXPECT_EQ(dynamic_input_format->GetOriginFormat(), Format::FORMAT_ND); - EXPECT_EQ(dynamic_input_format->GetStorageFormat(), Format::FORMAT_NCHW); - - const auto dynamic_input_shape = context->GetDynamicInputShape(1U, 0U); - EXPECT_NE(dynamic_input_shape, nullptr); - EXPECT_EQ(dynamic_input_shape->GetDimNum(), 4UL); - EXPECT_EQ(dynamic_input_shape->GetDim(0U), 1L); - - auto input_0 = context->GetRequiredInputFormat(2U); - input_0->SetOriginFormat(Format::FORMAT_NCHW); - input_0->SetStorageFormat(Format::FORMAT_NCHW); - - auto output_0 = context->GetOutputFormat(0U); - output_0->SetOriginFormat(Format::FORMAT_NCHW); - output_0->SetStorageFormat(Format::FORMAT_NCHW); - return GRAPH_SUCCESS; - }; - IMPL_OP(DynamicInput3Input3Output3).InferFormat(infer_format_func); - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctionsV2 op_impl_func; - op_impl_func.infer_format_func = infer_format_func; - registry_holder->AddTypesToImpl("DynamicInput3Input3Output3", op_impl_func); - space_registry->AddRegistry(registry_holder); - DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - const auto call_infer_format_v2 = OperatorFactoryImpl::GetInferFormatV2Func(); - EXPECT_NE(call_infer_format_v2, nullptr); - ASSERT_EQ(call_infer_format_v2(op, op_desc), GRAPH_SUCCESS); - - const auto &input_0 = op_desc->GetInputDesc(3U); - EXPECT_EQ(input_0.GetOriginFormat(), Format::FORMAT_NCHW); - EXPECT_EQ(input_0.GetFormat(), Format::FORMAT_NCHW); - - const auto &output_0 = op_desc->GetOutputDesc(0U); - EXPECT_EQ(output_0.GetOriginFormat(), Format::FORMAT_NCHW); - EXPECT_EQ(output_0.GetFormat(), Format::FORMAT_NCHW); -} - -TEST_F(ShapeInferenceUT, CallInferFormatFunc_ValueDependInput) { - // construct const input - auto const_input = ge::op::Const("const_input"); - ge::TensorDesc td{ge::Shape(std::vector({1, 2, 3, 4})), FORMAT_NCHW, DT_UINT8}; - ge::Tensor tensor(td); - std::vector val{0x55, 0x56, 0x57, 0x58, 0x58, 0x58, - 0x55, 0x56, 0x57, 0x58, 0x58, 0x58, - 0x55, 0x56, 0x57, 0x58, 0x58, 0x58, - 0x55, 0x56, 0x57, 0x58, 0x58, 0x58, - 0x55, 0x56, 0x57, 0x58, 0x58, 0x58}; - tensor.SetData(val); - const_input.set_attr_value(tensor); - // const input link to op - auto op = op::Type2_1Input_1Output("test5"); - op.set_input_input1(const_input); - op.set_input_input3(const_input); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - ASSERT_EQ(op_desc->GetAllInputsSize(), 3); - // input1 - ge::GeShape shape1({1, 2, 3, 4}); - ge::GeTensorDesc tensor_desc1(shape1, Format::FORMAT_NCHW, DT_FLOAT16); - tensor_desc1.SetOriginShape(shape1); - tensor_desc1.SetOriginDataType(DT_FLOAT16); - op_desc->UpdateInputDesc(0, tensor_desc1); - op_desc->UpdateInputDesc(2, tensor_desc1); - - - auto infer_format_func = [](gert::InferFormatContext *context) -> UINT32 { - const auto data = context->GetInputTensor(1U)->GetData(); - EXPECT_EQ(data[0], 85); - - auto output_0 = context->GetOutputFormat(0U); - output_0->SetOriginFormat(Format::FORMAT_NCHW); - output_0->SetStorageFormat(Format::FORMAT_NCHW); - return GRAPH_SUCCESS; - }; - IMPL_OP(Type2_1Input_1Output).InferFormat(infer_format_func).InputsDataDependency({2}); - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctionsV2 op_impl_func; - op_impl_func.infer_format_func = infer_format_func; - op_impl_func.SetInputDataDependency(2); - registry_holder->AddTypesToImpl("Type2_1Input_1Output", op_impl_func); - space_registry->AddRegistry(registry_holder); - DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - const auto call_infer_format_v2 = OperatorFactoryImpl::GetInferFormatV2Func(); - EXPECT_NE(call_infer_format_v2, nullptr); - ASSERT_EQ(call_infer_format_v2(op, op_desc), GRAPH_SUCCESS); - - const auto &output_0 = op_desc->GetOutputDesc(0U); - EXPECT_EQ(output_0.GetOriginFormat(), Format::FORMAT_NCHW); - EXPECT_EQ(output_0.GetFormat(), Format::FORMAT_NCHW); -} - -REG_OP(DynamicInputDynamicOutput) - .DYNAMIC_INPUT(dyn_input, "D") - .DYNAMIC_OUTPUT(dyn_output, "D") - .OUTPUT(out, "T") - .OP_END_FACTORY_REG(DynamicInputDynamicOutput); -TEST_F(ShapeInferenceUT, CallInferFormatFunc_DynamicOutput) { - auto op = op::DynamicInputDynamicOutput("DynamicTest"); - op.create_dynamic_input_byindex_dyn_input(1, 0); - op.create_dynamic_output_dyn_output(1); - auto op_desc = OpDescUtils::GetOpDescFromOperator(op); - ASSERT_NE(op_desc, nullptr); - - // dynamic_input - GeShape shape({1, 1, 1, 1}); - const auto format = Format::FORMAT_NCHW; - const auto ori_format = Format::FORMAT_ND; - GeTensorDesc tensor_desc(shape, format); - tensor_desc.SetOriginFormat(ori_format); - tensor_desc.SetOriginShape(shape); - op_desc->UpdateInputDesc(0U, tensor_desc); - op_desc->UpdateOutputDesc(0U, tensor_desc); - op_desc->UpdateOutputDesc(1U, tensor_desc); - - auto infer_format_func = [](gert::InferFormatContext *context) -> UINT32 { - const auto dynamic_input_format = context->GetDynamicInputFormat(0U, 0U); - EXPECT_NE(dynamic_input_format, nullptr); - EXPECT_EQ(dynamic_input_format->GetOriginFormat(), Format::FORMAT_ND); - EXPECT_EQ(dynamic_input_format->GetStorageFormat(), Format::FORMAT_NCHW); - - const auto dynamic_input_shape = context->GetDynamicInputShape(0U, 0U); - EXPECT_NE(dynamic_input_shape, nullptr); - EXPECT_EQ(dynamic_input_shape->GetDimNum(), 4UL); - EXPECT_EQ(dynamic_input_shape->GetDim(0U), 1L); - - - auto output_0 = context->GetDynamicOutputFormat(0U, 0U); - output_0->SetOriginFormat(Format::FORMAT_NCHW); - output_0->SetStorageFormat(Format::FORMAT_NCHW); - - auto output_1 = context->GetRequiredOutputFormat(0U); - output_1->SetOriginFormat(Format::FORMAT_NCHW); - output_1->SetStorageFormat(Format::FORMAT_NCHW); - return GRAPH_SUCCESS; - }; - IMPL_OP(DynamicInputDynamicOutput).InferFormat(infer_format_func); - - auto space_registry = std::make_shared(); - auto registry_holder = std::make_shared(); - gert::OpImplKernelRegistry::OpImplFunctionsV2 op_impl_func; - op_impl_func.infer_format_func = infer_format_func; - registry_holder->AddTypesToImpl("DynamicInputDynamicOutput", op_impl_func); - space_registry->AddRegistry(registry_holder); - DefaultOpImplSpaceRegistry::GetInstance().SetDefaultSpaceRegistry(space_registry); - - const auto call_infer_format_v2 = OperatorFactoryImpl::GetInferFormatV2Func(); - EXPECT_NE(call_infer_format_v2, nullptr); - ASSERT_EQ(call_infer_format_v2(op, op_desc), GRAPH_SUCCESS); - - const auto &output_0 = op_desc->GetOutputDesc(0U); - EXPECT_EQ(output_0.GetOriginFormat(), Format::FORMAT_NCHW); - EXPECT_EQ(output_0.GetFormat(), Format::FORMAT_NCHW); - - const auto &output_1 = op_desc->GetOutputDesc(1U); - EXPECT_EQ(output_1.GetOriginFormat(), Format::FORMAT_NCHW); - EXPECT_EQ(output_1.GetFormat(), Format::FORMAT_NCHW); -} -} // namespace gert diff --git a/tests/ut/register/testcase/stream_manage_func_registry_unittest.cc b/tests/ut/register/testcase/stream_manage_func_registry_unittest.cc deleted file mode 100644 index 0ba4f8904ca83f3ee8de7e30982d3df9a0a6da87..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/stream_manage_func_registry_unittest.cc +++ /dev/null @@ -1,60 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include - -#include "register/stream_manage_func_registry.h" - -namespace ge { -uint32_t FakeCallbackFunc(MngActionType action_type, MngResourceHandle handle) { - (void) action_type; - (void) handle; - return SUCCESS; -} - -uint32_t FakeCallbackFuncFailed(MngActionType action_type, MngResourceHandle handle) { - (void) action_type; - (void) handle; - return FAILED; -} - -class StreamMngFuncRegistryUT : public testing::Test {}; - -TEST_F(StreamMngFuncRegistryUT, TryCallStreamMngFunc_Ok) { - const auto func_type = StreamMngFuncType::ACLNN_STREAM_CALLBACK; - const MngActionType action_type = MngActionType::DESTROY_STREAM; - const MngResourceHandle handle = {.stream = (void *) 0x11}; - // case: no callback function is registered - EXPECT_EQ(StreamMngFuncRegistry::GetInstance().TryCallStreamMngFunc(func_type, action_type, handle), SUCCESS); - - // case: run callback function successfully - REG_STREAM_MNG_FUNC(func_type, FakeCallbackFunc); - EXPECT_EQ(StreamMngFuncRegistry::GetInstance().TryCallStreamMngFunc(func_type, action_type, handle), SUCCESS); -} - -TEST_F(StreamMngFuncRegistryUT, TryCallStreamMngFunc_Ok_CallFuncReturnNonzero) { - const auto func_type = StreamMngFuncType::ACLNN_STREAM_CALLBACK; - const MngActionType action_type = MngActionType::RESET_DEVICE; - const MngResourceHandle handle = {.device_id = 0}; - // case: run callback function failed - REG_STREAM_MNG_FUNC(func_type, FakeCallbackFuncFailed); - EXPECT_EQ(StreamMngFuncRegistry::GetInstance().TryCallStreamMngFunc(func_type, action_type, handle), SUCCESS); -} - -TEST_F(StreamMngFuncRegistryUT, TryCallStreamMngFunc_Ok_MultipleRegister) { - const auto func_type = StreamMngFuncType::ACLNN_STREAM_CALLBACK; - // register for the first time - REG_STREAM_MNG_FUNC(func_type, FakeCallbackFuncFailed); - EXPECT_EQ(StreamMngFuncRegistry::GetInstance().LookUpStreamMngFunc(func_type), FakeCallbackFuncFailed); - // register for the second time - REG_STREAM_MNG_FUNC(func_type, FakeCallbackFunc); - EXPECT_EQ(StreamMngFuncRegistry::GetInstance().LookUpStreamMngFunc(func_type), FakeCallbackFunc); -} - -} // namespace ge diff --git a/tests/ut/register/testcase/tensor_assign_unittest.cc b/tests/ut/register/testcase/tensor_assign_unittest.cc deleted file mode 100644 index 3fcd24aee4b76ba842961fd69580fe5f3062ef4b..0000000000000000000000000000000000000000 --- a/tests/ut/register/testcase/tensor_assign_unittest.cc +++ /dev/null @@ -1,302 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include -#include -#include "graph_builder_utils.h" -#include "external/register/register.h" -#include -#include "proto/tensorflow/node_def.pb.h" -#include "register/op_registry.h" -#include "graph/graph.h" -#include "graph/utils/attr_utils.h" -#include "graph/debug/ge_util.h" - -#include "register/auto_mapping_util.h" -#include "external/register/scope/scope_fusion_pass_register.h" -#include "register/scope/scope_graph_impl.h" - -using namespace ge; -using namespace domi; - -class ConvertTensorUtest : public testing::Test { -public: - domi::tensorflow::TensorProto tensor_; - ge::graphStatus ret_; - ge::GeTensorPtr weight_; - -protected: - void SetUp() { - tensor_.set_tensor_content("tensor_context_for_test"); - } - void TearDown() {} -}; - -const float FLOAT_TEST_NUM = 3.14; -const double DOUBLE_TEST_NUM = 3.1415; -const int INT_TEST_NUM = 66; -const unsigned int UNSIGNED_INT_TEST_NUM = 88; - -TEST_F(ConvertTensorUtest, ConvertTensorNoType) { - GeTensorPtr weight; - weight.reset(); - TensorAssign::SetWeightData(domi::tensorflow::DataType_INT_MAX_SENTINEL_DO_NOT_USE_, 0, std::string("content"), weight); - tensor_.set_dtype(domi::tensorflow::DataType_INT_MAX_SENTINEL_DO_NOT_USE_); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, GRAPH_FAILED); - tensor_.clear_dtype(); -} -TEST_F(ConvertTensorUtest, ConvertTensorFloat) { - tensor_.add_float_val(FLOAT_TEST_NUM); - tensor_.set_dtype(domi::tensorflow::DT_FLOAT); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); - tensor_.clear_float_val(); -} -TEST_F(ConvertTensorUtest, ConvertTensorDouble) { - tensor_.add_double_val(DOUBLE_TEST_NUM); - tensor_.set_dtype(domi::tensorflow::DT_DOUBLE); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); - tensor_.clear_double_val(); -} -TEST_F(ConvertTensorUtest, ConvertTensorComplex32_ValSize0_Success) { - tensor_.set_dtype(domi::tensorflow::DT_COMPLEX32); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); - tensor_.clear_icomplex_val(); -} -TEST_F(ConvertTensorUtest, ConvertTensorComplex32_ValSize2_DimSize2_Success) { - domi::tensorflow::TensorShapeProto *tensor_shape = new domi::tensorflow::TensorShapeProto(); - tensor_shape->add_dim()->set_size(2); - tensor_.set_allocated_tensor_shape(tensor_shape); - tensor_.add_icomplex_val(INT_TEST_NUM); - tensor_.add_icomplex_val(INT_TEST_NUM); - tensor_.set_dtype(domi::tensorflow::DT_COMPLEX32); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); - tensor_.clear_icomplex_val(); -} -TEST_F(ConvertTensorUtest, ConvertTensorComplex32_ValSize4_DimSize3_Success) { - domi::tensorflow::TensorShapeProto *tensor_shape = new domi::tensorflow::TensorShapeProto(); - tensor_shape->add_dim()->set_size(3); - tensor_.set_allocated_tensor_shape(tensor_shape); - tensor_.add_icomplex_val(INT_TEST_NUM); - tensor_.add_icomplex_val(INT_TEST_NUM); - tensor_.add_icomplex_val(INT_TEST_NUM); - tensor_.add_icomplex_val(INT_TEST_NUM); - tensor_.set_dtype(domi::tensorflow::DT_COMPLEX32); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); - tensor_.clear_icomplex_val(); -} -TEST_F(ConvertTensorUtest, ConvertTensorSComplex) { - tensor_.add_scomplex_val(FLOAT_TEST_NUM); - tensor_.add_scomplex_val(FLOAT_TEST_NUM); - tensor_.set_dtype(domi::tensorflow::DT_COMPLEX64); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); - tensor_.clear_scomplex_val(); -} -TEST_F(ConvertTensorUtest, ConvertTensorDComplex) { - tensor_.add_dcomplex_val(DOUBLE_TEST_NUM); - tensor_.add_dcomplex_val(DOUBLE_TEST_NUM); - tensor_.set_dtype(domi::tensorflow::DT_COMPLEX128); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); - tensor_.clear_dcomplex_val(); -} -TEST_F(ConvertTensorUtest, ConvertTensorInt) { - tensor_.add_int_val(INT_TEST_NUM); - tensor_.set_dtype(domi::tensorflow::DT_INT32); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); - tensor_.clear_int_val(); - - tensor_.add_int64_val(INT_TEST_NUM); - tensor_.set_dtype(domi::tensorflow::DT_INT64); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); - tensor_.clear_int64_val(); - - tensor_.add_int_val(INT_TEST_NUM); - tensor_.add_int_val(INT_TEST_NUM); - tensor_.set_dtype(domi::tensorflow::DT_INT16); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); - tensor_.clear_int_val(); - - tensor_.add_int_val(INT_TEST_NUM); - tensor_.set_dtype(domi::tensorflow::DT_UINT8); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); - tensor_.clear_int_val(); - -} -TEST_F(ConvertTensorUtest, ConvertTensorUnsignedInt) { - tensor_.add_uint32_val(UNSIGNED_INT_TEST_NUM); - tensor_.set_dtype(domi::tensorflow::DT_UINT32); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); - tensor_.clear_int_val(); - - tensor_.add_uint64_val(UNSIGNED_INT_TEST_NUM); - tensor_.set_dtype(domi::tensorflow::DT_UINT64); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); - tensor_.clear_int64_val(); -} -TEST_F(ConvertTensorUtest, ConvertTensorBool) { - tensor_.add_bool_val(true); - tensor_.set_dtype(domi::tensorflow::DT_BOOL); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); - tensor_.clear_bool_val(); -} -TEST_F(ConvertTensorUtest, ConvertTensorString) { - domi::tensorflow::TensorShapeProto *tensor_shape = new domi::tensorflow::TensorShapeProto(); - tensor_.set_allocated_tensor_shape(tensor_shape); - tensor_.add_string_val("1"); - tensor_.set_dtype(domi::tensorflow::DT_STRING); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); - - tensor_.add_string_val("str_test2"); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); - tensor_.clear_string_val(); -} -TEST_F(ConvertTensorUtest, ConvertTensorHalf) { - tensor_.add_half_val(INT_TEST_NUM); - tensor_.set_dtype(domi::tensorflow::DT_HALF); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); - tensor_.clear_half_val(); -} -TEST_F(ConvertTensorUtest, ConvertTensorHalfVariant) { - tensor_.add_variant_val(); - tensor_.set_dtype(domi::tensorflow::DT_VARIANT); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); - tensor_.clear_variant_val(); -} - -TEST_F(ConvertTensorUtest, SetWeightFloat) { - tensor_.set_dtype(domi::tensorflow::DT_FLOAT); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); -} -TEST_F(ConvertTensorUtest, SetWeightDouble) { - tensor_.set_dtype(domi::tensorflow::DT_DOUBLE); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); -} -TEST_F(ConvertTensorUtest, SetWeightSComplex) { - tensor_.set_dtype(domi::tensorflow::DT_COMPLEX64); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); -} -TEST_F(ConvertTensorUtest, SetWeightDComplex) { - tensor_.set_dtype(domi::tensorflow::DT_COMPLEX128); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); -} -TEST_F(ConvertTensorUtest, SetWeightInt) { - tensor_.set_dtype(domi::tensorflow::DT_INT16); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); - - tensor_.set_dtype(domi::tensorflow::DT_INT32); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); - - tensor_.set_dtype(domi::tensorflow::DT_INT64); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); -} -TEST_F(ConvertTensorUtest, SetWeightUnsignedInt) { - tensor_.set_dtype(domi::tensorflow::DT_UINT8); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); - - tensor_.set_dtype(domi::tensorflow::DT_UINT32); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); - - tensor_.set_dtype(domi::tensorflow::DT_UINT64); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); -} -TEST_F(ConvertTensorUtest, SetWeightBool) { - tensor_.set_dtype(domi::tensorflow::DT_BOOL); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); -} -TEST_F(ConvertTensorUtest, SetWeightString) { - tensor_.set_dtype(domi::tensorflow::DT_STRING); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); -} -TEST_F(ConvertTensorUtest, SetWeightHalf) { - tensor_.set_dtype(domi::tensorflow::DT_HALF); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); -} -TEST_F(ConvertTensorUtest, SetWeightVarient) { - tensor_.set_dtype(domi::tensorflow::DT_VARIANT); - ret_ = ge::AutoMappingUtil::ConvertTensor(tensor_, weight_); - EXPECT_EQ(ret_, domi::SUCCESS); -} -TEST_F(ConvertTensorUtest, SetWeightDataString) { - GeTensorPtr weight = std::make_shared(); - const int64_t count = 2; - std::string s[count] = {"0.941042", "0.508840"}; - std::string tensor_content; - for (int i = 0; i < count; ++i) { - tensor_content.push_back(static_cast(s[i].size())); - } - for (int i = 0; i < count; ++i) { - tensor_content += s[i]; - } - TensorAssign::SetWeightData(domi::tensorflow::DT_STRING, count, tensor_content, weight); - uint64_t tensor_data_size = 0; - for (int i = 0; i < count; ++i) { - tensor_data_size += sizeof(StringHead) + s[i].size() + 1; - } - EXPECT_EQ(weight->GetData().GetSize(), tensor_data_size); - std::string tensor_data[count]; - int start_index = count * sizeof(StringHead); - for (int i = 0; i < count; ++i) { - for (size_t j = 0; j < s[i].size(); ++j) { - tensor_data[i] += weight->GetData().data()[start_index + j]; - } - start_index += s[i].size() + 1; - EXPECT_EQ(tensor_data[i], s[i]); - } -} -TEST_F(ConvertTensorUtest, SetWeightDataStringOverflow) { - GeTensorPtr weight = std::make_shared(); - std::string tensor_content = "\b\b0.9410420.508840"; - const int64_t count = std::numeric_limits::max() / sizeof(StringHead) + 1; - TensorAssign::SetWeightData(domi::tensorflow::DT_STRING, count, tensor_content, weight); - EXPECT_EQ(weight->GetData().size(), 0); -} - -TEST_F(ConvertTensorUtest, GetStringVal) { - Status retStatus; - ge::GeTensorPtr weight = ComGraphMakeShared(); - google::protobuf::RepeatedPtrField vector; - - vector.Add()->assign("vector"); - EXPECT_FALSE(vector.empty()); - EXPECT_EQ(vector.size(), 1); - - retStatus = TensorAssign::GetStringVal(1, vector, 2, weight); - EXPECT_EQ(retStatus, domi::SUCCESS); -} diff --git a/tests/ut/sc_check/testcase/sc_check_unittest.cc b/tests/ut/sc_check/testcase/sc_check_unittest.cc index 191565200d8b5709b545e17d4f4b5fe35b26a89e..eb92cd799b876e99e63dc91074f0334174743227 100644 --- a/tests/ut/sc_check/testcase/sc_check_unittest.cc +++ b/tests/ut/sc_check/testcase/sc_check_unittest.cc @@ -36,22 +36,10 @@ namespace SC{ * 因此添加这个ut来看护代码仓目录 */ TEST(FileCount, CheckFileCount) { - fs::path dirPath0(std::string(TOP_DIR).append("/graph")); - EXPECT_NO_THROW(countFilesAndDirs(dirPath0)); - fs::path dirPath1(std::string(TOP_DIR).append("/inc")); EXPECT_NO_THROW(countFilesAndDirs(dirPath1)); - fs::path dirPath2(std::string(TOP_DIR).append("/exe_graph")); - EXPECT_NO_THROW(countFilesAndDirs(dirPath2)); - - fs::path dirPath3(std::string(TOP_DIR).append("/ops")); - EXPECT_NO_THROW(countFilesAndDirs(dirPath3)); - - fs::path dirPath4(std::string(TOP_DIR).append("/register")); - EXPECT_NO_THROW(countFilesAndDirs(dirPath4)); - fs::path dirPath5(std::string(TOP_DIR).append("/base")); - EXPECT_NO_THROW(countFilesAndDirs(dirPath4)); + EXPECT_NO_THROW(countFilesAndDirs(dirPath5)); } } diff --git a/third_party/transformer/inc/axis_util.h b/third_party/transformer/inc/axis_util.h deleted file mode 100644 index 9b1cb298207296d3d5ea2979d4ba290baba70709..0000000000000000000000000000000000000000 --- a/third_party/transformer/inc/axis_util.h +++ /dev/null @@ -1,103 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef COMMON_UTILS_TRANSFORMER_INC_AXIS_UTIL_H_ -#define COMMON_UTILS_TRANSFORMER_INC_AXIS_UTIL_H_ - -#include -#include -#include "external/graph/types.h" -#include "graph/utils/math_util.h" -#include "exe_graph/runtime/shape.h" - -namespace transformer { -#define CHECK(cond, log_func, return_expr) \ - do { \ - if (cond) { \ - log_func; \ - return_expr; \ - } \ - } while (0) - -#define INT64_ZEROCHECK(a) \ - if (a == 0) { \ - return false; \ - } - -#define MUL_OVERFLOW(x, y, z) \ - if (ge::MulOverflow((x), (y), (z))) { \ - return false; \ - } \ - -enum AxisValueType { - AXIS_N = 0, - AXIS_C = 1, - AXIS_H = 2, - AXIS_W = 3, - AXIS_C1 = 4, - AXIS_C0 = 5, - AXIS_Co = 6, - AXIS_D = 7, - AXIS_G = 8, - AXIS_M0 = 9, - AXIS_INPUT_SIZE = 10, - AXIS_HIDDEN_SIZE = 11, - AXIS_STATE_SIZE = 12, - AXIS_BOTTOM = 13 -}; - -using AxisValue = std::array(AXIS_BOTTOM)>; - -inline int64_t DivisionCeiling(int64_t dividend, int64_t divisor) { - if (divisor == 0) { - return 0; - } else if (dividend < 0) { - return -1; - } else { - return (dividend + divisor - 1) / divisor; - } -} - -class AxisUtil { - public: - AxisUtil() {}; - ~AxisUtil() {}; - static bool GetAxisValueByOriginFormat(const ge::Format &format, const gert::Shape &shape, AxisValue &axis_value); - static int32_t GetAxisIndexByFormat(const ge::Format &format, const string &axis); - static int32_t GetAxisIndexByFormat(const ge::Format &format, const string &axis, - const std::map &valid_axis_map); - static std::vector GetAxisVecByFormat(const ge::Format &format); - static std::vector GetReshapeTypeAxisVec(const ge::Format &format, const int64_t &reshape_type_mask); - static std::map GetReshapeTypeAxisMap(const ge::Format &format, - const int64_t &reshape_type_mask); - static std::vector GetSplitOrConcatAxisByFormat(const ge::Format &format, const std::string &axis); - private: - static bool GetAxisValueByNCHW(const gert::Shape &shape, AxisValue &axis_value); - - static bool GetAxisValueByNHWC(const gert::Shape &shape, AxisValue &axis_value); - - static bool GetAxisValueByHWCN(const gert::Shape &shape, AxisValue &axis_value); - - static bool GetAxisValueByND(const gert::Shape &shape, AxisValue &axis_value); - - static bool GetAxisValueByNDHWC(const gert::Shape &shape, AxisValue &axis_value); - - static bool GetAxisValueByNCDHW(const gert::Shape &shape, AxisValue &axis_value); - - static bool GetAxisValueByDHWCN(const gert::Shape &shape, AxisValue &axis_value); - - static bool GetAxisValueByDHWNC(const gert::Shape &shape, AxisValue &axis_value); - - static bool GetAxisValueByNC1HWC0(const gert::Shape &shape, AxisValue &axis_value); - - static bool GetAxisValueByC1HWNCoC0(const gert::Shape &shape, AxisValue &axis_value); -}; -} // namespace transformer -#endif // COMMON_UTILS_TRANSFORMER_INC_AXIS_UTIL_H_ - \ No newline at end of file diff --git a/third_party/transformer/inc/expand_dimension.h b/third_party/transformer/inc/expand_dimension.h deleted file mode 100644 index 699d18331c0466bfddb9fe9fdcd3a757b598f53d..0000000000000000000000000000000000000000 --- a/third_party/transformer/inc/expand_dimension.h +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef COMMON_UTILS_TRANSFORMER_INC_EXPAND_DIMENSION_H_ -#define COMMON_UTILS_TRANSFORMER_INC_EXPAND_DIMENSION_H_ - -#include -#include -#include "graph/types.h" -#include "graph/ge_tensor.h" -#include "exe_graph/runtime/shape.h" -#include "transfer_def.h" - -namespace transformer { - /* Pad dimension according to reshape type */ -bool ExpandDimension(const std::string &op_type, const ge::Format &original_format, const ge::Format &final_format, - const uint32_t &tensor_index, const std::string &reshape_type, ge::GeShape &shape); - -bool ExpandRangeDimension(const std::string &op_type, const ge::Format &original_format, const ge::Format &final_format, - const uint32_t &tensor_index, const std::string &reshape_type, - std::vector> &ranges); - -class ExpandDimension { - public: - ExpandDimension(); - ~ExpandDimension(); - - static int64_t GenerateReshapeType(const ge::Format &origin_format, const ge::Format &format, - const size_t &origin_dim_size, const std::string &reshape_type); - static bool GenerateReshapeType(const ge::Format &origin_format, const ge::Format &format, - const size_t &origin_dim_size, const std::string &reshape_type, - int64_t &reshape_type_mask); - static bool GenerateReshapeTypeByMask(const ge::Format &origin_format, const size_t &origin_dim_size, - const int64_t &reshape_type_mask, std::string &reshape_type, - std::string &failed_reason); - static void ExpandDims(const int64_t &reshape_type, ge::GeShape &shape); - static void ExpandDims(const int64_t &reshape_type, const ge::GeShape &origin_shape, ge::GeShape &shape); - static void ExpandDims(const int64_t &reshape_type, gert::Shape &shape); - static void ExpandDims(const int64_t &reshape_type, const gert::Shape &origin_shape, gert::Shape &shape); - static bool GetDefaultReshapeType(const ge::Format &origin_format, const size_t &origin_dim_size, - std::string &reshape_type); - static int32_t GetAxisIndexByName(char ch, const ge::Format &format); - static int64_t GetReshapeAxicValue(const int64_t &reshape_type_mask, - const ge::GeShape &shape, int32_t axis_index); - static int64_t GetReshapeAxicValueByName(const int64_t &reshape_type_mask, char ch, - const ge::GeShape &shape, const ge::Format &format); - static bool GetFormatFullSize(const ge::Format &format, size_t &full_size); - private: - static bool IsNeedExpand(const ge::Format &origin_format, const ge::Format &format, - const size_t &origin_dim_size, const size_t &full_size, const std::string &reshape_type); - static bool IsReshapeTypeValid(const ge::Format &origin_format, const size_t &origin_dim_size, - const std::string &reshape_type); -}; -} // namespace transformer -#endif // COMMON_UTILS_TRANSFORMER_INC_EXPAND_DIMENSION_H_ - \ No newline at end of file diff --git a/third_party/transformer/inc/transfer_def.h b/third_party/transformer/inc/transfer_def.h deleted file mode 100644 index 508ecbb25ca94724a8525c8e5f455413f787901e..0000000000000000000000000000000000000000 --- a/third_party/transformer/inc/transfer_def.h +++ /dev/null @@ -1,53 +0,0 @@ -/* Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef TRANSFORMER_INC_TRANSFER_DEF_H_ -#define TRANSFORMER_INC_TRANSFER_DEF_H_ - -#include -#include -#include "graph/utils/math_util.h" -#include "exe_graph/runtime/shape.h" - -namespace transformer { - -const size_t ORIGIN_FORMAT_DIM_SIZE = 5; -const size_t EXT_AXIS_SIZE = 4; -const size_t EXT_AXIS_OP_SIZE = 3; - - -struct AlignShapeInfo { - ge::Format src_format; - ge::Format dst_format; - gert::Shape src_shape; - ge::DataType data_type; - int64_t reshape_type_mask; -}; - -struct TransferDimsInfo { - ge::Format src_format; - ge::Format dst_format; - gert::Shape src_shape; - int64_t reshape_type_mask; -}; - -struct AxisIndexMapping { - std::vector> src_to_dst_transfer_dims; - std::vector> dst_to_src_transfer_dims; -}; - -using GetAlignedShapeFunc = std::function; -using TransferDimsFunc = std::function; - -using FormatIndex = std::array; -using ExtAxisValue = std::array; -using ExtAxisOpValue = std::array; - -} // namespace transformer -#endif // TRANSFORMER_INC_TRANSFER_DEF_H_ diff --git a/third_party/transformer/inc/transfer_range_according_to_format.h b/third_party/transformer/inc/transfer_range_according_to_format.h deleted file mode 100644 index 27d4857a3a9dcb9decfff80dfba34e9ff730fc02..0000000000000000000000000000000000000000 --- a/third_party/transformer/inc/transfer_range_according_to_format.h +++ /dev/null @@ -1,58 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef COMMON_UTILS_TRANSFORMER_INC_TRANSFER_RANGE_ACCORDING_TO_FORMAT_H_ -#define COMMON_UTILS_TRANSFORMER_INC_TRANSFER_RANGE_ACCORDING_TO_FORMAT_H_ - -#include -#include "transfer_shape_according_to_format.h" - -namespace transformer { -struct RangeAndFormatInfo { - ge::GeShape old_shape; - std::vector> old_range; - std::vector> &new_range; - ge::Format old_format; - ge::Format new_format; - ge::DataType current_data_type; - CalcShapeExtraAttr extra_attr; - RangeAndFormatInfo(ge::GeShape old_shape, std::vector> old_range, - std::vector> &new_range, ge::Format old_format, - ge::Format new_format, ge::DataType current_data_type) : - old_shape(old_shape), old_range(old_range), new_range(new_range), old_format(old_format), - new_format(new_format), current_data_type(current_data_type), extra_attr(CalcShapeExtraAttr()) {} - RangeAndFormatInfo(ge::GeShape old_shape, std::vector> old_range, - std::vector> &new_range, ge::Format old_format, - ge::Format new_format, ge::DataType current_data_type, CalcShapeExtraAttr extra_attr) : - old_shape(old_shape), old_range(old_range), new_range(new_range), old_format(old_format), - new_format(new_format), current_data_type(current_data_type), extra_attr(extra_attr) {} -}; - -using RangeAndFormat = struct RangeAndFormatInfo; - -class RangeTransferAccordingToFormat { - public: - RangeTransferAccordingToFormat() = default; - - ~RangeTransferAccordingToFormat() = default; - - RangeTransferAccordingToFormat(const RangeTransferAccordingToFormat &) = delete; - - RangeTransferAccordingToFormat &operator=(const RangeTransferAccordingToFormat &) = delete; - - static bool GetRangeAccordingToFormat(RangeAndFormat &range_and_format_info); - - // deprecated ATTRIBUTED_DEPRECATED(static bool GetRangeAccordingToFormat(const ExtAxisOpValue &, RangeAndFormat &)) - static bool GetRangeAccordingToFormat(const ge::OpDescPtr &op_desc, RangeAndFormat &range_and_format_info); - - static bool GetRangeAccordingToFormat(const ExtAxisOpValue &op_value, RangeAndFormat &range_and_format_info); -}; -} // namespace fe - -#endif // FUSION_ENGINE_OPTIMIZER_GRAPH_OPTIMIZER_RANGE_FORMAT_TRANSFER_TRANSFER_RANGE_ACCORDING_TO_FORMAT_H_ diff --git a/third_party/transformer/inc/transfer_shape_according_to_format.h b/third_party/transformer/inc/transfer_shape_according_to_format.h deleted file mode 100644 index dbb4616807e6f3fc74a3355cfb4e119a2cde79d0..0000000000000000000000000000000000000000 --- a/third_party/transformer/inc/transfer_shape_according_to_format.h +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef COMMON_UTILS_TRANSFORMER_INC_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_ -#define COMMON_UTILS_TRANSFORMER_INC_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_ - -#include -#include "graph/types.h" -#include "graph/ge_tensor.h" -#include "graph/op_desc.h" -#include "platform/platform_info.h" -#include "transfer_def.h" -#include "transfer_shape_utils.h" - -namespace transformer { -struct CalcShapeExtraAttr { - int64_t hidden_size; - int64_t input_size; - int64_t state_size; -}; - -struct ShapeAndFormatInfo { - ge::GeShape &oldShape; - const ge::Format &oldFormat; - const ge::Format &newFormat; - const ge::DataType ¤tDataType; - CalcShapeExtraAttr extra_attr; - ShapeAndFormatInfo(ge::GeShape &old_shape, const ge::Format &old_format, const ge::Format &new_format, - const ge::DataType &data_type) - : oldShape(old_shape), oldFormat(old_format), newFormat(new_format), currentDataType(data_type), - extra_attr({1, 1, -1}) {} -}; - -using ShapeAndFormat = struct ShapeAndFormatInfo; - -class ShapeTransferAccordingToFormat { - public: - ShapeTransferAccordingToFormat(); - - ~ShapeTransferAccordingToFormat() {}; - - ShapeTransferAccordingToFormat(const ShapeTransferAccordingToFormat&) = delete; - - ShapeTransferAccordingToFormat &operator=(const ShapeTransferAccordingToFormat&) = delete; - - static bool GetShapeAccordingToFormat(ShapeAndFormat &shapeAndFormatInfo); - - // deprecated ATTRIBUTED_DEPRECATED(static bool GetShapeAccordingToFormat(const ExtAxisOpValue &, ShapeAndFormat &)) - static bool GetShapeAccordingToFormat(const ge::OpDescPtr &op_desc, ShapeAndFormat &shapeAndFormatInfo); - - static bool GetShapeAccordingToFormat(const ExtAxisOpValue &op_value, ShapeAndFormat &shapeAndFormatInfo); - - static bool TransferShape(const ge::Format &origin_format, const ge::Format &format, const ge::DataType &data_type, - const ExtAxisValue &ext_axis, ge::GeShape &shape); - - static bool TransferShape(const ge::Format &origin_format, const ge::Format &format, const ge::DataType &data_type, - const ExtAxisValue &ext_axis, const ge::GeShape &origin_shape, ge::GeShape &shape); - - /* deprecated ATTRIBUTED_DEPRECATED(static bool TransferShape(const ge::Format &, const ge::Format &, const ge::DataType &, - gert::Shape &, const ExtAxisOpValue &, - const fe::PlatFormInfos *)) */ - static bool TransferShape(const ge::Format &origin_format, const ge::Format &format, const ge::DataType &data_type, - gert::Shape &shape, const ge::OpDescPtr op_desc = nullptr, - const fe::PlatFormInfos *platform_infos_ptr = nullptr); - - static bool TransferShape(const ge::Format &origin_format, const ge::Format &format, const ge::DataType &data_type, - gert::Shape &shape, const ExtAxisOpValue &op_value, - const fe::PlatFormInfos *platform_infos_ptr = nullptr); - - /* deprecated ATTRIBUTED_DEPRECATED(static bool TransferShape(const ge::Format &, const ge::Format &, const ge::DataType &, - const gert::Shape &, gert::Shape &, const ExtAxisOpValue &)) */ - static bool TransferShape(const ge::Format &origin_format, const ge::Format &format, const ge::DataType &data_type, - const gert::Shape &origin_shape, gert::Shape &shape, - const ge::OpDescPtr op_desc = nullptr); - - static bool TransferShape(const ge::Format &origin_format, const ge::Format &format, const ge::DataType &data_type, - const gert::Shape &origin_shape, gert::Shape &shape, const ExtAxisOpValue &op_value); - - // deprecated ATTRIBUTED_DEPRECATED(static void InitExtAxisValue(const ExtAxisOpValue &, ExtAxisValue &)) - static void InitExtAxisValue(const ge::OpDescPtr &op_desc, ExtAxisValue &ext_axis); - - static void InitExtAxisValue(const ExtAxisOpValue &op_value, ExtAxisValue &ext_axis); - - static bool InitPlatformInfo(); - static int64_t GetC0ByDtype(const ge::DataType &data_type); - static int64_t GetM0ByDtype(const ge::DataType &data_type); - static int64_t GetN0ByDtype(const ge::DataType &data_type); - static bool GetAlignedShape(const AlignShapeInfo &align_shape_info, gert::Shape &aligned_shape); - static bool TransferDims(const TransferDimsInfo &transfer_dims_info, AxisIndexMapping &axis_index_mapping); -}; -} // namespace transformer -#endif // COMMON_UTILS_TRANSFORMER_INC_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_ diff --git a/third_party/transformer/inc/transfer_shape_utils.h b/third_party/transformer/inc/transfer_shape_utils.h deleted file mode 100644 index 14cb6b871ae32a0fc1a13057c1de0e5784d9e0d6..0000000000000000000000000000000000000000 --- a/third_party/transformer/inc/transfer_shape_utils.h +++ /dev/null @@ -1,195 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef TRANSFORMER_INC_TRANSFER_SHAPE_UTILS_H_ -#define TRANSFORMER_INC_TRANSFER_SHAPE_UTILS_H_ - -#include -#include "platform/platform_info.h" -#include "axis_util.h" -#include "transfer_def.h" - -namespace transformer { -enum class TransferShapeType { - ND_TO_ND = 0, - ND_TO_NZ, - FULL_SIZE, - NOT_FULL_SIZE, - INVALID -}; - -class TransferShapeUtils { - public: - TransferShapeUtils() {} - ~TransferShapeUtils() {} - static bool InitPlatformInfo(const fe::PlatFormInfos *platform_infos_ptr = nullptr); - static bool TransferShape(const ge::Format &origin_format, const ge::Format &format, const ge::DataType &data_type, - const ExtAxisValue &ext_axis, gert::Shape &shape, - const fe::PlatFormInfos *platforminfos = nullptr); - static bool TransferShape(const ge::Format &origin_format, const ge::Format &format, const ge::DataType &data_type, - const ExtAxisValue &ext_axis, const gert::Shape &origin_shape, gert::Shape &shape, - const fe::PlatFormInfos *platforminfos = nullptr); - static int64_t GetC0ByDtype(const ge::DataType &data_type); - static int64_t GetM0ByDtype(const ge::DataType &data_type); - static int64_t GetN0ByDtype(const ge::DataType &data_type); - static bool GetAlignedShape(const AlignShapeInfo &align_shape_info, gert::Shape &aligned_shape); - static bool TransferDims(const TransferDimsInfo &transfer_dims_info, AxisIndexMapping &axis_index_mapping); - - private: - static bool InitM0K0CO(const fe::PlatFormInfos *platform_infos); - static bool TransferShapeByFormat(const ge::Format &primary_format, const AxisValue &axis_value, - gert::Shape &shape); - static bool TransferShapeByAxisValue(const ge::Format &primary_format, const AxisValue &axis_value, - gert::Shape &shape); - static bool TransferShapeByOriginShape(const ge::Format &primary_format, const int64_t &c0, const int64_t &m0, - const ExtAxisValue &ext_axis, const gert::Shape &origin_shape, - gert::Shape &shape); - static bool TransferShapeByFormatIndex(const ge::Format &origin_format, const ge::Format &format, const int64_t &c0, - const gert::Shape &origin_shape, gert::Shape &shape); - static bool IsNeedTransferShape(const ge::Format &origin_format, const ge::Format &format, const gert::Shape &shape); - static bool CheckInputParam(const ge::Format &origin_format, const ge::Format &primary_format, - const ge::DataType &data_type); - static bool IsNeedAxisValue(const ge::Format &format, const size_t &origin_dim_size); - static int64_t GetC0Value(const ge::DataType &data_type, const ge::Format &format); - - /* ----------Below is the function of getting new shape by axis value---------------------- */ - static bool GetNCHWShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); - - static bool GetNHWCShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); - - static bool GetHWCNShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); - - static bool GetCHWNShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); - - static bool GetNDHWCShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); - - static bool GetNCDHWShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); - - static bool GetDHWCNShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); - - static bool GetDHWNCShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); - - static bool GetNC1HWC0ShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); - - static bool GetC1HWC0ShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); - - static bool GetNDC1HWC0ShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); - - static bool GetC1HWNCoC0ShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); - - static bool GetNzShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); - - static bool GetFzShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); - - static bool GetFz3DShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); - - static bool GetFz3DTransposeShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); - - static bool GetFzLstmShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); - - static bool GetFzC04ShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); - - static bool GetFznRNNShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); - - static bool GetNDRNNShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); - - /* ----------Below is the function of getting new shape by origin shape---------------------- */ - static bool GetNCHWShape(const FormatIndex& format_index, const gert::Shape &origin_shape, gert::Shape &shape); - - static bool GetNHWCShape(const FormatIndex& format_index, const gert::Shape &origin_shape, gert::Shape &shape); - - static bool GetHWCNShape(const FormatIndex& format_index, const gert::Shape &origin_shape, gert::Shape &shape); - - static bool GetCHWNShape(const FormatIndex& format_index, const gert::Shape &origin_shape, gert::Shape &shape); - - static bool GetNDHWCShape(const FormatIndex& format_index, const gert::Shape &origin_shape, gert::Shape &shape); - - static bool GetNCDHWShape(const FormatIndex& format_index, const gert::Shape &origin_shape, gert::Shape &shape); - - static bool GetDHWCNShape(const FormatIndex& format_index, const gert::Shape &origin_shape, gert::Shape &shape); - - static bool GetDHWNCShape(const FormatIndex& format_index, const gert::Shape &origin_shape, gert::Shape &shape); - - static bool GetNC1HWC0Shape(const FormatIndex& format_index, const int64_t &c0, const gert::Shape &origin_shape, - gert::Shape &shape); - - static bool GetC1HWC0Shape(const FormatIndex& format_index, const int64_t &c0, const gert::Shape &origin_shape, - gert::Shape &shape); - - static bool GetNDC1HWC0Shape(const FormatIndex& format_index, const int64_t &c0, const gert::Shape &origin_shape, - gert::Shape &shape); - - static bool GetC1HWNCoC0Shape(const FormatIndex& format_index, const int64_t &c0, const gert::Shape &origin_shape, - gert::Shape &shape); - - static bool GetFractalNzShape(const int64_t &c0, const int64_t &m0, - const gert::Shape &origin_shape, gert::Shape &shape); - - static bool GetFractalZShape(const int64_t &c0, const gert::Shape &origin_shape, gert::Shape &shape); - - static bool GetFractalZShape(const FormatIndex& format_index, const int64_t &c0, const int64_t &group, - const gert::Shape &origin_shape, gert::Shape &shape); - - static bool GetFractalZ3DShape(const FormatIndex& format_index, const int64_t &c0, const int64_t &group, - const gert::Shape &origin_shape, gert::Shape &shape); - - static bool GetFractalZ3DTransposeShape(const FormatIndex& format_index, const int64_t &c0, - const gert::Shape &origin_shape, gert::Shape &shape); - - static bool GetFractalZLstmShape(const FormatIndex& format_index, const gert::Shape &origin_shape, - gert::Shape &shape); - - static bool GetFractalZC04Shape(const FormatIndex& format_index, const int64_t &c0, const gert::Shape &origin_shape, - gert::Shape &shape); - - static bool GetFractalZnRnnShape(const ExtAxisValue &ext_axis, const int64_t &c0, const gert::Shape &origin_shape, - gert::Shape &shape); - - static bool GetNdRnnBiasShape(const ExtAxisValue &ext_axis, const int64_t &c0, const gert::Shape &origin_shape, - gert::Shape &shape); - - static bool GetNYUVShape(gert::Shape &shape); - - static bool GetFzWinoShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); - - static bool GetFractalZWinoShape(const FormatIndex& format_index, const int64_t &c0, - const gert::Shape &origin_shape, gert::Shape &shape); - - static TransferShapeType GetTransferShapeType(const ge::Format &src_format, const ge::Format &dst_format, - const gert::Shape &src_shape); - - static bool GetNdToNdAlignedShape(const AlignShapeInfo &align_shape_info, gert::Shape &aligned_shape); - - static bool GetNdToNzAlignedShape(const AlignShapeInfo &align_shape_info, gert::Shape &aligned_shape); - - static bool GetFullSizeAlignedShape(const AlignShapeInfo &align_shape_info, gert::Shape &aligned_shape); - - static bool GetNotFullSizeAlignedShape(const AlignShapeInfo &align_shape_info, gert::Shape &aligned_shape); - - static bool GetNdToNdAxisIndexMapping(const TransferDimsInfo &transfer_dims_info, - AxisIndexMapping &axis_index_mapping); - - static bool GetNdToNzAxisIndexMapping(const TransferDimsInfo &transfer_dims_info, - AxisIndexMapping &axis_index_mapping); - - static bool GetFullSizeAxisIndexMapping(const TransferDimsInfo &transfer_dims_info, - AxisIndexMapping &axis_index_mapping); - - static bool GetNotFullSizeAxisIndexMapping(const TransferDimsInfo &transfer_dims_info, - AxisIndexMapping &axis_index_mapping); - - static std::array(ge::DataType::DT_MAX)> m0_list_; - static std::array(ge::DataType::DT_MAX)> k0_list_; - static std::array(ge::DataType::DT_MAX)> n0_list_; - static const std::map get_aligned_shape_func_map; - static const std::map transfer_dims_func_map; -}; -} -#endif // TRANSFORMER_INC_TRANSFER_SHAPE_UTILS_H_ - \ No newline at end of file diff --git a/third_party/transformer/src/axis_constants.h b/third_party/transformer/src/axis_constants.h deleted file mode 100644 index 08be86e0815f7ae76cb4c8fe6fe23c2ee0159f87..0000000000000000000000000000000000000000 --- a/third_party/transformer/src/axis_constants.h +++ /dev/null @@ -1,90 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#ifndef COMMON_UTILS_TRANSFORMER_INC_AXIS_CONSTANTS_H_ -#define COMMON_UTILS_TRANSFORMER_INC_AXIS_CONSTANTS_H_ - -#include -#include -#include -#include "graph/types.h" - -namespace transformer { -extern const size_t DIM_SIZE_TWO; -extern const size_t DIM_SIZE_FOUR; -extern const size_t DIM_SIZE_FIVE; -extern const size_t DIM_SIZE_SIX; - -extern const size_t EXT_INDEX_INPUT_SIZE; -extern const size_t EXT_INDEX_HIDDEN_SIZE; -extern const size_t EXT_INDEX_STATE_SIZE; -extern const size_t EXT_INDEX_M0_VAL; - -extern const int32_t AXIS_NCHW_DIM_N; -extern const int32_t AXIS_NCHW_DIM_C; -extern const int32_t AXIS_NCHW_DIM_H; -extern const int32_t AXIS_NCHW_DIM_W; - -extern const int32_t AXIS_NHWC_DIM_N; -extern const int32_t AXIS_NHWC_DIM_H; -extern const int32_t AXIS_NHWC_DIM_W; -extern const int32_t AXIS_NHWC_DIM_C; - -extern const int32_t AXIS_HWCN_DIM_H; -extern const int32_t AXIS_HWCN_DIM_W; -extern const int32_t AXIS_HWCN_DIM_C; -extern const int32_t AXIS_HWCN_DIM_N; - -extern const int32_t AXIS_CHWN_DIM_C; -extern const int32_t AXIS_CHWN_DIM_H; -extern const int32_t AXIS_CHWN_DIM_W; -extern const int32_t AXIS_CHWN_DIM_N; - -extern const int32_t NDHWC_DIM_N; -extern const int32_t NDHWC_DIM_D; -extern const int32_t NDHWC_DIM_H; -extern const int32_t NDHWC_DIM_W; -extern const int32_t NDHWC_DIM_C; - -extern const int32_t NCDHW_DIM_N; -extern const int32_t NCDHW_DIM_C; -extern const int32_t NCDHW_DIM_D; -extern const int32_t NCDHW_DIM_H; -extern const int32_t NCDHW_DIM_W; - -extern const int32_t DHWCN_DIM_D; -extern const int32_t DHWCN_DIM_H; -extern const int32_t DHWCN_DIM_W; -extern const int32_t DHWCN_DIM_C; -extern const int32_t DHWCN_DIM_N; - -extern const int32_t DHWNC_DIM_D; -extern const int32_t DHWNC_DIM_H; -extern const int32_t DHWNC_DIM_W; -extern const int32_t DHWNC_DIM_N; -extern const int32_t DHWNC_DIM_C; - -extern const int32_t AXIS_NC1HWC0_DIM_N; -extern const int32_t AXIS_NC1HWC0_DIM_C1; -extern const int32_t AXIS_NC1HWC0_DIM_H; -extern const int32_t AXIS_NC1HWC0_DIM_W; -extern const int32_t AXIS_NC1HWC0_DIM_C0; - -extern const int32_t AXIS_C1HWNCoC0_DIM_C1; -extern const int32_t AXIS_C1HWNCoC0_DIM_H; -extern const int32_t AXIS_C1HWNCoC0_DIM_W; -extern const int32_t AXIS_C1HWNCoC0_DIM_N; -extern const int32_t AXIS_C1HWNCoC0_DIM_Co; - -const std::set kFormatNZSet = {ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ_C0_16, - ge::FORMAT_FRACTAL_NZ_C0_32}; - -} // namespace transformer - -#endif // COMMON_UTILS_TRANSFORMER_INC_AXIS_CONSTANTS_H_ diff --git a/third_party/transformer/src/axis_util.cc b/third_party/transformer/src/axis_util.cc deleted file mode 100644 index f4297c6c01540e472cdb944e3a4f166ad7957a5c..0000000000000000000000000000000000000000 --- a/third_party/transformer/src/axis_util.cc +++ /dev/null @@ -1,385 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "axis_util.h" -#include "axis_constants.h" -#include "common/ge_common/debug/ge_log.h" -#include "expand_dimension.h" -#include "graph/utils/type_utils.h" - -namespace transformer { -const size_t DIM_SIZE_TWO = 2; -const size_t DIM_SIZE_FOUR = 4; -const size_t DIM_SIZE_FIVE = 5; -const size_t DIM_SIZE_SIX = 6; - -const size_t EXT_INDEX_INPUT_SIZE = 0; -const size_t EXT_INDEX_HIDDEN_SIZE = 1; -const size_t EXT_INDEX_STATE_SIZE = 2; -const size_t EXT_INDEX_M0_VAL = 3; - -const int32_t AXIS_NCHW_DIM_N = 0; -const int32_t AXIS_NCHW_DIM_C = 1; -const int32_t AXIS_NCHW_DIM_H = 2; -const int32_t AXIS_NCHW_DIM_W = 3; - -const int32_t AXIS_NHWC_DIM_N = 0; -const int32_t AXIS_NHWC_DIM_H = 1; -const int32_t AXIS_NHWC_DIM_W = 2; -const int32_t AXIS_NHWC_DIM_C = 3; - -const int32_t AXIS_HWCN_DIM_H = 0; -const int32_t AXIS_HWCN_DIM_W = 1; -const int32_t AXIS_HWCN_DIM_C = 2; -const int32_t AXIS_HWCN_DIM_N = 3; - -const int32_t AXIS_CHWN_DIM_C = 0; -const int32_t AXIS_CHWN_DIM_H = 1; -const int32_t AXIS_CHWN_DIM_W = 2; -const int32_t AXIS_CHWN_DIM_N = 3; - -const int32_t NDHWC_DIM_N = 0; -const int32_t NDHWC_DIM_D = 1; -const int32_t NDHWC_DIM_H = 2; -const int32_t NDHWC_DIM_W = 3; -const int32_t NDHWC_DIM_C = 4; - -const int32_t NCDHW_DIM_N = 0; -const int32_t NCDHW_DIM_C = 1; -const int32_t NCDHW_DIM_D = 2; -const int32_t NCDHW_DIM_H = 3; -const int32_t NCDHW_DIM_W = 4; - -const int32_t DHWCN_DIM_D = 0; -const int32_t DHWCN_DIM_H = 1; -const int32_t DHWCN_DIM_W = 2; -const int32_t DHWCN_DIM_C = 3; -const int32_t DHWCN_DIM_N = 4; - -const int32_t DHWNC_DIM_D = 0; -const int32_t DHWNC_DIM_H = 1; -const int32_t DHWNC_DIM_W = 2; -const int32_t DHWNC_DIM_N = 3; -const int32_t DHWNC_DIM_C = 4; - -const int32_t AXIS_NC1HWC0_DIM_N = 0; -const int32_t AXIS_NC1HWC0_DIM_C1 = 1; -const int32_t AXIS_NC1HWC0_DIM_H = 2; -const int32_t AXIS_NC1HWC0_DIM_W = 3; -const int32_t AXIS_NC1HWC0_DIM_C0 = 4; - -const int32_t AXIS_NDC1HWC0_DIM_N = 0; -const int32_t AXIS_NDC1HWC0_DIM_D = 1; -const int32_t AXIS_NDC1HWC0_DIM_C1 = 2; -const int32_t AXIS_NDC1HWC0_DIM_H = 3; -const int32_t AXIS_NDC1HWC0_DIM_W = 4; -const int32_t AXIS_NDC1HWC0_DIM_C0 = 5; - -const int32_t AXIS_C1HWNCoC0_DIM_C1 = 0; -const int32_t AXIS_C1HWNCoC0_DIM_H = 1; -const int32_t AXIS_C1HWNCoC0_DIM_W = 2; -const int32_t AXIS_C1HWNCoC0_DIM_N = 3; -const int32_t AXIS_C1HWNCoC0_DIM_Co = 4; - -const int32_t AXIS_FZ_DIM_C1HW = 0; -const int32_t AXIS_FZ_DIM_N1 = 1; -const int32_t AXIS_FZ_DIM_N0 = 2; -const int32_t AXIS_FZ_DIM_C0 = 3; - -const std::map> kFormatAxisIndexMap = { - {ge::Format::FORMAT_NCHW, {{"N", AXIS_NCHW_DIM_N}, {"C", AXIS_NCHW_DIM_C}, - {"H", AXIS_NCHW_DIM_H}, {"W", AXIS_NCHW_DIM_W}}}, - {ge::Format::FORMAT_HWCN, {{"N", AXIS_HWCN_DIM_N}, {"C", AXIS_HWCN_DIM_C}, - {"H", AXIS_HWCN_DIM_H}, {"W", AXIS_HWCN_DIM_W}}}, - {ge::Format::FORMAT_NHWC, {{"N", AXIS_NHWC_DIM_N}, {"C", AXIS_NHWC_DIM_C}, - {"H", AXIS_NHWC_DIM_H}, {"W", AXIS_NHWC_DIM_W}}}, - {ge::Format::FORMAT_CHWN, {{"N", AXIS_CHWN_DIM_N}, {"C", AXIS_CHWN_DIM_C}, - {"H", AXIS_CHWN_DIM_H}, {"W", AXIS_CHWN_DIM_W}}}, - {ge::Format::FORMAT_NDHWC, {{"N", NDHWC_DIM_N}, {"C", NDHWC_DIM_C}, - {"H", NDHWC_DIM_H}, {"W", NDHWC_DIM_W}, {"D", NDHWC_DIM_D}}}, - {ge::Format::FORMAT_NCDHW, {{"N", NCDHW_DIM_N}, {"C", NCDHW_DIM_C}, - {"H", NCDHW_DIM_H}, {"W", NCDHW_DIM_W}, {"D", NCDHW_DIM_D}}}, - {ge::Format::FORMAT_DHWCN, {{"N", DHWCN_DIM_N}, {"C", DHWCN_DIM_C}, - {"H", DHWCN_DIM_H}, {"W", DHWCN_DIM_W}, {"D", DHWCN_DIM_D}}}, - {ge::Format::FORMAT_DHWNC, {{"N", DHWNC_DIM_N}, {"C", DHWNC_DIM_C}, - {"H", DHWNC_DIM_H}, {"W", DHWNC_DIM_W}, {"D", DHWNC_DIM_D}}}, - {ge::Format::FORMAT_NC1HWC0, {{"N", AXIS_NC1HWC0_DIM_N}, {"C1", AXIS_NC1HWC0_DIM_C1}, - {"H", AXIS_NC1HWC0_DIM_H}, {"W", AXIS_NC1HWC0_DIM_W}, {"C0", AXIS_NC1HWC0_DIM_C0}}}, - {ge::Format::FORMAT_NDC1HWC0, {{"N", AXIS_NDC1HWC0_DIM_N}, {"D", AXIS_NDC1HWC0_DIM_D}, - {"C1", AXIS_NDC1HWC0_DIM_C1}, {"H", AXIS_NDC1HWC0_DIM_H}, - {"W", AXIS_NDC1HWC0_DIM_W}, {"C0", AXIS_NDC1HWC0_DIM_C0}}}, - {ge::Format::FORMAT_FRACTAL_Z, {{"C1HW", AXIS_FZ_DIM_C1HW}, {"N1", AXIS_FZ_DIM_N1}, - {"N0", AXIS_FZ_DIM_N0}, {"C0", AXIS_FZ_DIM_C0}}}}; - -const std::map> kFormatAxisVec = { - {ge::Format::FORMAT_NCHW, {"N", "C", "H", "W"}}, - {ge::Format::FORMAT_HWCN, {"H", "W", "C", "N"}}, - {ge::Format::FORMAT_NHWC, {"N", "H", "W", "C"}}, - {ge::Format::FORMAT_CHWN, {"C", "H", "W", "N"}}, - {ge::Format::FORMAT_NDHWC, {"N", "D", "H", "W", "C"}}, - {ge::Format::FORMAT_NCDHW, {"N", "C", "D", "H", "W"}}, - {ge::Format::FORMAT_DHWCN, {"D", "H", "W", "C", "N"}}, - {ge::Format::FORMAT_DHWNC, {"D", "H", "W", "N", "C"}}, - {ge::Format::FORMAT_NC1HWC0, {"N", "C1", "H", "W", "C0"}}, - {ge::Format::FORMAT_NDC1HWC0, {"N", "D", "C1", "H", "W", "C0"}}, - {ge::Format::FORMAT_FRACTAL_Z, {"C1HW", "N1", "N0", "C0"}}}; - -const std::map>> kFormatSplitOrConcatAxisMap { - {ge::Format::FORMAT_NC1HWC0, {{"C", {"C1", "C0"}}, {"C1", {"C"}}, {"C0", {"C"}}}}, - {ge::Format::FORMAT_NDC1HWC0, {{"C", {"C1", "C0"}}, {"C1", {"C"}}, {"C0", {"C"}}}}, - {ge::Format::FORMAT_FRACTAL_Z, {{"N", {"N1", "N0"}}, {"C", {"C1HW", "C0"}}, {"H", {"C1HW"}}, {"W", {"C1HW"}}, - {"C1HW", {"C", "H", "W"}}, {"N1", {"N"}}, {"N0", {"N"}}, {"C0", {"C"}}}}}; - -bool AxisUtil::GetAxisValueByOriginFormat(const ge::Format &format, const gert::Shape &shape, AxisValue &axis_value) { - CHECK(shape.IsScalar(), GELOGI("Original dim vector is empty!"), return true); - switch (format) { - case ge::FORMAT_NCHW: - return GetAxisValueByNCHW(shape, axis_value); - case ge::FORMAT_NHWC: - return GetAxisValueByNHWC(shape, axis_value); - case ge::FORMAT_HWCN: - return GetAxisValueByHWCN(shape, axis_value); - case ge::FORMAT_ND: - return GetAxisValueByND(shape, axis_value); - case ge::FORMAT_NCDHW: - return GetAxisValueByNCDHW(shape, axis_value); - case ge::FORMAT_NDHWC: - return GetAxisValueByNDHWC(shape, axis_value); - case ge::FORMAT_DHWCN: - return GetAxisValueByDHWCN(shape, axis_value); - case ge::FORMAT_DHWNC: - return GetAxisValueByDHWNC(shape, axis_value); - case ge::FORMAT_NC1HWC0: - return GetAxisValueByNC1HWC0(shape, axis_value); - case ge::FORMAT_C1HWNCoC0: - return GetAxisValueByC1HWNCoC0(shape, axis_value); - default: - GELOGI("Could not retrieve axis value for old format %d.", format); - return false; - } -} - -bool AxisUtil::GetAxisValueByND(const gert::Shape &shape, AxisValue &axis_value) { - /* To differentiate the input datatype of int8 and others */ - if (shape.GetDimNum() == DIM_SIZE_FOUR) { - axis_value[AXIS_N] = shape.GetDim(AXIS_NCHW_DIM_N); - axis_value[AXIS_C] = shape.GetDim(AXIS_NCHW_DIM_C); - axis_value[AXIS_H] = shape.GetDim(AXIS_NCHW_DIM_H); - axis_value[AXIS_W] = shape.GetDim(AXIS_NCHW_DIM_W); - axis_value[AXIS_C1] = DivisionCeiling(axis_value[AXIS_C], axis_value[AXIS_C0]); - axis_value[AXIS_Co] = axis_value[AXIS_C0]; - } - return true; -} - -bool AxisUtil::GetAxisValueByNCHW(const gert::Shape &shape, AxisValue &axis_value) { - CHECK(shape.GetDimNum() < DIM_SIZE_FOUR, GELOGI("Dim size is less than 4."), return false); - /* C0 Must be set for case ND or 2D-NCHW to NZ */ - axis_value[AXIS_N] = shape.GetDim(AXIS_NCHW_DIM_N); - axis_value[AXIS_C] = shape.GetDim(AXIS_NCHW_DIM_C); - axis_value[AXIS_H] = shape.GetDim(AXIS_NCHW_DIM_H); - axis_value[AXIS_W] = shape.GetDim(AXIS_NCHW_DIM_W); - axis_value[AXIS_C1] = DivisionCeiling(axis_value[AXIS_C], axis_value[AXIS_C0]); - axis_value[AXIS_Co] = axis_value[AXIS_C0]; - return true; -} - -bool AxisUtil::GetAxisValueByNHWC(const gert::Shape &shape, AxisValue &axis_value) { - CHECK(shape.GetDimNum() < DIM_SIZE_FOUR, GELOGI("Dim size is less than 4."), return false); - /* C0 Must be set for case ND or 2D-NHWC to NZ */ - axis_value[AXIS_N] = shape.GetDim(AXIS_NHWC_DIM_N); - axis_value[AXIS_C] = shape.GetDim(AXIS_NHWC_DIM_C); - axis_value[AXIS_H] = shape.GetDim(AXIS_NHWC_DIM_H); - axis_value[AXIS_W] = shape.GetDim(AXIS_NHWC_DIM_W); - axis_value[AXIS_C1] = DivisionCeiling(axis_value[AXIS_C], axis_value[AXIS_C0]); - axis_value[AXIS_Co] = axis_value[AXIS_C0]; - return true; -} - -bool AxisUtil::GetAxisValueByNC1HWC0(const gert::Shape &shape, AxisValue &axis_value) { - CHECK(shape.GetDimNum() < DIM_SIZE_FOUR, GELOGI("Dim size is less than 4."), return false); - if (shape.GetDimNum() == DIM_SIZE_FIVE) { - axis_value[AXIS_C0] = shape.GetDim(AXIS_NC1HWC0_DIM_C0); - axis_value[AXIS_C1] = shape.GetDim(AXIS_NC1HWC0_DIM_C1); - axis_value[AXIS_C] = axis_value[AXIS_C1] * axis_value[AXIS_C0]; - } else { - axis_value[AXIS_C] = shape.GetDim(AXIS_NCHW_DIM_C); - axis_value[AXIS_C1] = DivisionCeiling(axis_value[AXIS_C], axis_value[AXIS_C0]); - } - - axis_value[AXIS_N] = shape.GetDim(AXIS_NC1HWC0_DIM_N); - axis_value[AXIS_H] = shape.GetDim(AXIS_NC1HWC0_DIM_H); - axis_value[AXIS_W] = shape.GetDim(AXIS_NC1HWC0_DIM_W); - return true; -} - -bool AxisUtil::GetAxisValueByHWCN(const gert::Shape &shape, AxisValue &axis_value) { - CHECK(shape.GetDimNum() < DIM_SIZE_FOUR, GELOGI("Dim size is less than 4."), return false); - /* C0 Must be set for case ND or 2D-NHWC to NZ */ - axis_value[AXIS_N] = shape.GetDim(AXIS_HWCN_DIM_N); - axis_value[AXIS_C] = shape.GetDim(AXIS_HWCN_DIM_C); - axis_value[AXIS_H] = shape.GetDim(AXIS_HWCN_DIM_H); - axis_value[AXIS_W] = shape.GetDim(AXIS_HWCN_DIM_W); - axis_value[AXIS_C1] = DivisionCeiling(axis_value[AXIS_C], axis_value[AXIS_C0]); - axis_value[AXIS_Co] = axis_value[AXIS_C0]; - return true; -} - -bool AxisUtil::GetAxisValueByC1HWNCoC0(const gert::Shape &shape, AxisValue &axis_value) { - CHECK(shape.GetDimNum() < DIM_SIZE_SIX, GELOGI("Dim size is less than 6."), return false); - /* C0 Must be set for case ND or 2D-NHWC to NZ */ - axis_value[AXIS_N] = shape.GetDim(AXIS_C1HWNCoC0_DIM_N); - axis_value[AXIS_C] = shape.GetDim(AXIS_C1HWNCoC0_DIM_C1) * axis_value[AXIS_C0]; - axis_value[AXIS_H] = shape.GetDim(AXIS_C1HWNCoC0_DIM_H); - axis_value[AXIS_W] = shape.GetDim(AXIS_C1HWNCoC0_DIM_W); - axis_value[AXIS_C1] = shape.GetDim(AXIS_C1HWNCoC0_DIM_C1); - axis_value[AXIS_Co] = shape.GetDim(AXIS_C1HWNCoC0_DIM_Co); - return true; -} - -bool AxisUtil::GetAxisValueByNDHWC(const gert::Shape &shape, AxisValue &axis_value) { - CHECK(shape.GetDimNum() < DIM_SIZE_FIVE, GELOGI("Dim size is less than 5."), return false); - - axis_value[AXIS_N] = shape.GetDim(NDHWC_DIM_N); - axis_value[AXIS_C] = shape.GetDim(NDHWC_DIM_C); - axis_value[AXIS_H] = shape.GetDim(NDHWC_DIM_H); - axis_value[AXIS_W] = shape.GetDim(NDHWC_DIM_W); - axis_value[AXIS_D] = shape.GetDim(NDHWC_DIM_D); - axis_value[AXIS_C1] = DivisionCeiling(axis_value[AXIS_C], axis_value[AXIS_C0]); - axis_value[AXIS_Co] = axis_value[AXIS_C0]; - return true; -} - -bool AxisUtil::GetAxisValueByNCDHW(const gert::Shape &shape, AxisValue &axis_value) { - CHECK(shape.GetDimNum() < DIM_SIZE_FIVE, GELOGI("Dim size is less than 5."), return false); - - axis_value[AXIS_N] = shape.GetDim(NCDHW_DIM_N); - axis_value[AXIS_C] = shape.GetDim(NCDHW_DIM_C); - axis_value[AXIS_H] = shape.GetDim(NCDHW_DIM_H); - axis_value[AXIS_W] = shape.GetDim(NCDHW_DIM_W); - axis_value[AXIS_D] = shape.GetDim(NCDHW_DIM_D); - axis_value[AXIS_C1] = DivisionCeiling(axis_value[AXIS_C], axis_value[AXIS_C0]); - axis_value[AXIS_Co] = axis_value[AXIS_C0]; - return true; -} - -bool AxisUtil::GetAxisValueByDHWCN(const gert::Shape &shape, AxisValue &axis_value) { - CHECK(shape.GetDimNum() < DIM_SIZE_FIVE, GELOGI("Dim size is less than 5."), return false); - - axis_value[AXIS_N] = shape.GetDim(DHWCN_DIM_N); - axis_value[AXIS_C] = shape.GetDim(DHWCN_DIM_C); - axis_value[AXIS_H] = shape.GetDim(DHWCN_DIM_H); - axis_value[AXIS_W] = shape.GetDim(DHWCN_DIM_W); - axis_value[AXIS_D] = shape.GetDim(DHWCN_DIM_D); - axis_value[AXIS_C1] = DivisionCeiling(axis_value[AXIS_C], axis_value[AXIS_C0]); - axis_value[AXIS_Co] = axis_value[AXIS_C0]; - return true; -} - -bool AxisUtil::GetAxisValueByDHWNC(const gert::Shape &shape, AxisValue &axis_value) { - CHECK(shape.GetDimNum() < DIM_SIZE_FIVE, GELOGI("Dim size is less than 5."), return false); - axis_value[AXIS_N] = shape.GetDim(DHWNC_DIM_N); - axis_value[AXIS_C] = shape.GetDim(DHWNC_DIM_C); - axis_value[AXIS_H] = shape.GetDim(DHWNC_DIM_H); - axis_value[AXIS_W] = shape.GetDim(DHWNC_DIM_W); - axis_value[AXIS_D] = shape.GetDim(DHWNC_DIM_D); - axis_value[AXIS_C1] = DivisionCeiling(axis_value[AXIS_C], axis_value[AXIS_C0]); - axis_value[AXIS_Co] = axis_value[AXIS_C0]; - - return true; -} - -int32_t AxisUtil::GetAxisIndexByFormat(const ge::Format &format, const string &axis) { - auto iter = kFormatAxisIndexMap.find(static_cast(GetPrimaryFormat(format))); - if (iter == kFormatAxisIndexMap.end()) { - GELOGW("Does not support this format: %s.", ge::TypeUtils::FormatToSerialString(format).c_str()); - return -1; - } - auto iter2 = iter->second.find(axis); - if (iter2 == iter->second.end()) { - GELOGW("Format %s does not have this axis %s.", ge::TypeUtils::FormatToSerialString(format).c_str(), - axis.c_str()); - return -1; - } - return iter2->second; -} - -int32_t AxisUtil::GetAxisIndexByFormat(const ge::Format &format, const string &axis, - const std::map &valid_axis_map) { - int32_t axis_index = GetAxisIndexByFormat(format, axis); - if (axis_index == -1) { - return -1; - } - auto iter = valid_axis_map.find(axis); - if (iter == valid_axis_map.end()) { - GELOGW("The axis %s is invalid.", axis.c_str()); - return -1; - } - return axis_index - iter->second; -} - -std::vector AxisUtil::GetAxisVecByFormat(const ge::Format &format) { - auto iter = kFormatAxisVec.find(static_cast(GetPrimaryFormat(format))); - if (iter == kFormatAxisVec.end()) { - GELOGW("Does not support this format: %s", ge::TypeUtils::FormatToSerialString(format).c_str()); - return {}; - } - return iter->second; -} - -std::vector AxisUtil::GetReshapeTypeAxisVec(const ge::Format &format, const int64_t &reshape_type_mask) { - std::vector format_axis_vec = GetAxisVecByFormat(static_cast(GetPrimaryFormat(format))); - if (format_axis_vec.empty()) { - GELOGW("Does not support this format: %s", ge::TypeUtils::FormatToSerialString(format).c_str()); - return {}; - } - std::vector axis_vec; - for (size_t i = 0; i < format_axis_vec.size(); ++i) { - int64_t bit_value = (reshape_type_mask >> i) & 1; - if (bit_value == 0) { - axis_vec.emplace_back(format_axis_vec.at(i)); - } - } - return axis_vec; -} - -std::map AxisUtil::GetReshapeTypeAxisMap(const ge::Format &format, - const int64_t &reshape_type_mask) { - std::vector format_axis_vec = GetAxisVecByFormat(static_cast(GetPrimaryFormat(format))); - if (format_axis_vec.empty()) { - GELOGW("Does not support this format: %s", ge::TypeUtils::FormatToSerialString(format).c_str()); - return {}; - } - std::map axis_map; - int32_t expand_dims_cnt = 0; - for (size_t i = 0; i < format_axis_vec.size(); ++i) { - int64_t bit_value = (reshape_type_mask >> i) & 1; - if (bit_value == 0) { - axis_map[format_axis_vec.at(i)] = expand_dims_cnt; - } else { - ++expand_dims_cnt; - } - } - return axis_map; -} - -std::vector AxisUtil::GetSplitOrConcatAxisByFormat(const ge::Format &format, const std::string &axis) { - auto iter = kFormatSplitOrConcatAxisMap.find(static_cast(GetPrimaryFormat(format))); - if (iter == kFormatSplitOrConcatAxisMap.end()) { - GELOGW("Does not support this format: %s.", ge::TypeUtils::FormatToSerialString(format).c_str()); - return {}; - } - auto iter2 = iter->second.find(axis); - if (iter2 == iter->second.end()) { - GELOGW("There is no need to split or concatenate this axis: %s.", axis.c_str()); - return {}; - } - return iter2->second; -} -} // namespace transformer diff --git a/third_party/transformer/src/expand_dimension.cc b/third_party/transformer/src/expand_dimension.cc deleted file mode 100644 index 57f45531a5eb9be63df45680f21c34bf1eaa8164..0000000000000000000000000000000000000000 --- a/third_party/transformer/src/expand_dimension.cc +++ /dev/null @@ -1,591 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "expand_dimension.h" -#include -#include -#include -#include -#include -#include "axis_constants.h" -#include "exe_graph/runtime/expand_dims_type.h" -#include "external/graph/types.h" -#include "common/ge_common/debug/ge_log.h" -#include "external/graph/ge_error_codes.h" -#include "graph/utils/type_utils.h" - -namespace transformer { -namespace { - const std::string RESHAPE_TYPE_FORBIDDEN = "FORBIDDEN"; - const uint32_t kBitsOfByte = 8; - const uint32_t kBitSetDisplaySize = 8; - const uint32_t kMaxReshapeTypeSize = 56; - - const std::set kSupportedTransFormat = {ge::FORMAT_ND, ge::FORMAT_FRACTAL_NZ, - ge::FORMAT_FRACTAL_NZ_C0_16, ge::FORMAT_FRACTAL_NZ_C0_32, - ge::FORMAT_ND_RNN_BIAS, ge::FORMAT_FRACTAL_ZN_RNN}; - - const std::map FULL_SIZE_OF_FORMAT { - {ge::FORMAT_NCHW, DIM_SIZE_FOUR}, - {ge::FORMAT_NHWC, DIM_SIZE_FOUR}, - {ge::FORMAT_HWCN, DIM_SIZE_FOUR}, - {ge::FORMAT_CHWN, DIM_SIZE_FOUR}, - {ge::FORMAT_NDHWC, DIM_SIZE_FIVE}, - {ge::FORMAT_NCDHW, DIM_SIZE_FIVE}, - {ge::FORMAT_DHWCN, DIM_SIZE_FIVE}, - {ge::FORMAT_DHWNC, DIM_SIZE_FIVE}, - {ge::FORMAT_ND, DIM_SIZE_FOUR} - }; - - inline uint32_t GenerateFormatKey(ge::Format format) { - return ((static_cast(format) & 0xff) << kBitsOfByte); - } - - inline uint32_t GenerateReshapeTypeKey(ge::Format format, size_t size) { - return ((static_cast(format) & 0xff) << kBitsOfByte) | (static_cast(size) & 0xff); - } - - inline uint32_t GenerateAxisIndexKey(ge::Format format, char ch) { - return ((static_cast(format) & 0xff) << kBitsOfByte) | (static_cast(ch) & 0xff); - } - - const std::unordered_map DEFAULT_RESHAPE_TYPE { - {GenerateReshapeTypeKey(ge::FORMAT_NCHW, 0), ""}, - {GenerateReshapeTypeKey(ge::FORMAT_NHWC, 0), ""}, - {GenerateReshapeTypeKey(ge::FORMAT_HWCN, 0), ""}, - {GenerateReshapeTypeKey(ge::FORMAT_CHWN, 0), ""}, - {GenerateReshapeTypeKey(ge::FORMAT_NDHWC, 0), ""}, - {GenerateReshapeTypeKey(ge::FORMAT_NCDHW, 0), ""}, - {GenerateReshapeTypeKey(ge::FORMAT_DHWCN, 0), ""}, - {GenerateReshapeTypeKey(ge::FORMAT_DHWNC, 0), ""}, - - {GenerateReshapeTypeKey(ge::FORMAT_NCHW, 1), "C"}, - {GenerateReshapeTypeKey(ge::FORMAT_NHWC, 1), "C"}, - {GenerateReshapeTypeKey(ge::FORMAT_HWCN, 1), "C"}, - {GenerateReshapeTypeKey(ge::FORMAT_CHWN, 1), "C"}, - {GenerateReshapeTypeKey(ge::FORMAT_NDHWC, 1), "C"}, - {GenerateReshapeTypeKey(ge::FORMAT_NCDHW, 1), "C"}, - {GenerateReshapeTypeKey(ge::FORMAT_DHWCN, 1), "C"}, - {GenerateReshapeTypeKey(ge::FORMAT_DHWNC, 1), "C"}, - - {GenerateReshapeTypeKey(ge::FORMAT_NCHW, 2), "CH"}, - {GenerateReshapeTypeKey(ge::FORMAT_NHWC, 2), "HW"}, - {GenerateReshapeTypeKey(ge::FORMAT_HWCN, 2), "CN"}, - {GenerateReshapeTypeKey(ge::FORMAT_CHWN, 2), "WN"}, - {GenerateReshapeTypeKey(ge::FORMAT_NDHWC, 2), "WC"}, - {GenerateReshapeTypeKey(ge::FORMAT_NCDHW, 2), "HW"}, - {GenerateReshapeTypeKey(ge::FORMAT_DHWCN, 2), "CN"}, - {GenerateReshapeTypeKey(ge::FORMAT_DHWNC, 2), "NC"}, - - {GenerateReshapeTypeKey(ge::FORMAT_NCHW, 3), "CHW"}, - {GenerateReshapeTypeKey(ge::FORMAT_NHWC, 3), "HWC"}, - {GenerateReshapeTypeKey(ge::FORMAT_HWCN, 3), "WCN"}, - {GenerateReshapeTypeKey(ge::FORMAT_CHWN, 3), "HWN"}, - {GenerateReshapeTypeKey(ge::FORMAT_NDHWC, 3), "HWC"}, - {GenerateReshapeTypeKey(ge::FORMAT_NCDHW, 3), "DHW"}, - {GenerateReshapeTypeKey(ge::FORMAT_DHWCN, 3), "WCN"}, - {GenerateReshapeTypeKey(ge::FORMAT_DHWNC, 3), "WNC"}, - - {GenerateReshapeTypeKey(ge::FORMAT_NDHWC, 4), "DHWC"}, - {GenerateReshapeTypeKey(ge::FORMAT_NCDHW, 4), "CDHW"}, - {GenerateReshapeTypeKey(ge::FORMAT_DHWCN, 4), "HWCN"}, - {GenerateReshapeTypeKey(ge::FORMAT_DHWNC, 4), "HWNC"} - }; - - const std::unordered_map AXIS_INDEX_OF_FORMAT { - {GenerateAxisIndexKey(ge::FORMAT_NCHW, 'N'), AXIS_NCHW_DIM_N}, - {GenerateAxisIndexKey(ge::FORMAT_NCHW, 'C'), AXIS_NCHW_DIM_C}, - {GenerateAxisIndexKey(ge::FORMAT_NCHW, 'H'), AXIS_NCHW_DIM_H}, - {GenerateAxisIndexKey(ge::FORMAT_NCHW, 'W'), AXIS_NCHW_DIM_W}, - - {GenerateAxisIndexKey(ge::FORMAT_HWCN, 'N'), AXIS_HWCN_DIM_N}, - {GenerateAxisIndexKey(ge::FORMAT_HWCN, 'C'), AXIS_HWCN_DIM_C}, - {GenerateAxisIndexKey(ge::FORMAT_HWCN, 'H'), AXIS_HWCN_DIM_H}, - {GenerateAxisIndexKey(ge::FORMAT_HWCN, 'W'), AXIS_HWCN_DIM_W}, - - {GenerateAxisIndexKey(ge::FORMAT_NHWC, 'N'), AXIS_NHWC_DIM_N}, - {GenerateAxisIndexKey(ge::FORMAT_NHWC, 'C'), AXIS_NHWC_DIM_C}, - {GenerateAxisIndexKey(ge::FORMAT_NHWC, 'H'), AXIS_NHWC_DIM_H}, - {GenerateAxisIndexKey(ge::FORMAT_NHWC, 'W'), AXIS_NHWC_DIM_W}, - - {GenerateAxisIndexKey(ge::FORMAT_CHWN, 'N'), AXIS_CHWN_DIM_N}, - {GenerateAxisIndexKey(ge::FORMAT_CHWN, 'C'), AXIS_CHWN_DIM_C}, - {GenerateAxisIndexKey(ge::FORMAT_CHWN, 'H'), AXIS_CHWN_DIM_H}, - {GenerateAxisIndexKey(ge::FORMAT_CHWN, 'W'), AXIS_CHWN_DIM_W}, - - {GenerateAxisIndexKey(ge::FORMAT_NDHWC, 'N'), NDHWC_DIM_N}, - {GenerateAxisIndexKey(ge::FORMAT_NDHWC, 'C'), NDHWC_DIM_C}, - {GenerateAxisIndexKey(ge::FORMAT_NDHWC, 'H'), NDHWC_DIM_H}, - {GenerateAxisIndexKey(ge::FORMAT_NDHWC, 'W'), NDHWC_DIM_W}, - {GenerateAxisIndexKey(ge::FORMAT_NDHWC, 'D'), NDHWC_DIM_D}, - - {GenerateAxisIndexKey(ge::FORMAT_NCDHW, 'N'), NCDHW_DIM_N}, - {GenerateAxisIndexKey(ge::FORMAT_NCDHW, 'C'), NCDHW_DIM_C}, - {GenerateAxisIndexKey(ge::FORMAT_NCDHW, 'H'), NCDHW_DIM_H}, - {GenerateAxisIndexKey(ge::FORMAT_NCDHW, 'W'), NCDHW_DIM_W}, - {GenerateAxisIndexKey(ge::FORMAT_NCDHW, 'D'), NCDHW_DIM_D}, - - {GenerateAxisIndexKey(ge::FORMAT_DHWCN, 'N'), DHWCN_DIM_N}, - {GenerateAxisIndexKey(ge::FORMAT_DHWCN, 'C'), DHWCN_DIM_C}, - {GenerateAxisIndexKey(ge::FORMAT_DHWCN, 'H'), DHWCN_DIM_H}, - {GenerateAxisIndexKey(ge::FORMAT_DHWCN, 'W'), DHWCN_DIM_W}, - {GenerateAxisIndexKey(ge::FORMAT_DHWCN, 'D'), DHWCN_DIM_D}, - - {GenerateAxisIndexKey(ge::FORMAT_DHWNC, 'N'), DHWNC_DIM_N}, - {GenerateAxisIndexKey(ge::FORMAT_DHWNC, 'C'), DHWNC_DIM_C}, - {GenerateAxisIndexKey(ge::FORMAT_DHWNC, 'H'), DHWNC_DIM_H}, - {GenerateAxisIndexKey(ge::FORMAT_DHWNC, 'W'), DHWNC_DIM_W}, - {GenerateAxisIndexKey(ge::FORMAT_DHWNC, 'D'), DHWNC_DIM_D} - }; - - void GeShapeToRtShape(const ge::GeShape &ge_shape, gert::Shape &rt_shape) { - rt_shape.SetDimNum(0); - for (size_t i = 0; i < ge_shape.GetDimNum(); ++i) { - rt_shape.AppendDim(ge_shape.GetDim(i)); - } - } - - void RtShapeToGeShape(const gert::Shape &rt_shape, ge::GeShape &ge_shape) { - ge_shape.SetDimNum(0); - for (size_t i = 0; i < rt_shape.GetDimNum(); ++i) { - ge_shape.AppendDim(rt_shape.GetDim(i)); - } - } -} - -bool GetDefaultReshapeType(const ge::Format &original_format, const size_t &old_dims_size, std::string &reshape_type) { - int32_t default_key = GenerateReshapeTypeKey(original_format, old_dims_size); - auto iter = DEFAULT_RESHAPE_TYPE.find(default_key); - if (iter == DEFAULT_RESHAPE_TYPE.end()) { - GELOGW("dim size %zu is invalid.", old_dims_size); - return false; - } - - reshape_type = iter->second; - return true; -} - -bool IsExpandNecessary(const size_t &old_dims_size, const ge::Format &original_format, const ge::Format &final_format, - const std::string &reshape_type, size_t &full_size) { - /* 1. Check whether the old dim size is full. Full size is not necessary for expand. */ - auto iter_full_size = FULL_SIZE_OF_FORMAT.find(original_format); - if (iter_full_size == FULL_SIZE_OF_FORMAT.end()) { - GELOGW("Original Format %u is invalid.", original_format); - return false; - } else { - if (old_dims_size >= iter_full_size->second) { - return false; - } - } - /* 2. Check whether the final format does not need expanding demension. */ - bool no_need_reshape_flag = reshape_type == RESHAPE_TYPE_FORBIDDEN || kFormatNZSet.count(final_format) > 0 || - (original_format == ge::FORMAT_ND && final_format == ge::FORMAT_FRACTAL_Z); - if (no_need_reshape_flag) { - return false; - } - full_size = iter_full_size->second; - return true; -} - -bool IsReshapeTypeValid(const ge::Format &original_format, const size_t &old_dims_size, - const std::string &reshape_type) { - if (reshape_type.empty()) { - return old_dims_size == 0; - } - int32_t pos = -1; - uint32_t format_key = GenerateFormatKey(original_format); - uint32_t axis_key = 0; - for (const char &dim : reshape_type) { - axis_key = format_key | (static_cast(dim) & 0xff); - auto iter = AXIS_INDEX_OF_FORMAT.find(axis_key); - if (iter == AXIS_INDEX_OF_FORMAT.end()) { - return false; - } - if (iter->second > pos) { - pos = iter->second; - } else { - return false; - } - } - - return true; -} - -void ExpandByReshapeType(ge::GeShape &shape, const ge::Format &original_format, - const size_t &old_dims_size, const size_t &full_size, const std::string &reshape_type) { - GELOGD("Expand tensor through reshape of type %s.", reshape_type.c_str()); - /* Build a array with all 1 of full size. Then we will substitute some of the 1 with the original axis value. */ - for (size_t i = old_dims_size; i < full_size; i++) { - shape.AppendDim(1); - } - if (reshape_type.empty() || old_dims_size == 0) { - return; - } - - uint32_t format_key = GenerateFormatKey(original_format); - uint32_t axis_key = 0; - for (int32_t i = static_cast(old_dims_size) - 1; i >= 0; i--) { - axis_key = format_key | (static_cast(reshape_type.at(i)) & 0xff); - auto iter_axis_index = AXIS_INDEX_OF_FORMAT.find(axis_key); - if (iter_axis_index == AXIS_INDEX_OF_FORMAT.end()) { - continue; - } - if (iter_axis_index->second == i) { - continue; - } - shape.SetDim(iter_axis_index->second, shape.GetDim(i)); - shape.SetDim(i, 1); - } -} - -bool ExpandDimension(const std::string &op_type, const ge::Format &original_format, const ge::Format &final_format, - const uint32_t &tensor_index, const std::string &reshape_type, ge::GeShape &shape) { - /* 1. Check expanding necessary. */ - size_t full_size = 0; - size_t old_dims_size = shape.GetDimNum(); - auto primary_format = static_cast(ge::GetPrimaryFormat(final_format)); - if (!IsExpandNecessary(old_dims_size, original_format, primary_format, reshape_type, full_size)) { - return true; - } - - /* 2. Check whether the reshape type is consistent with the original format. - * If not consistent, just return and report a warning. */ - std::string valid_reshape_type = reshape_type; - if (!IsReshapeTypeValid(original_format, old_dims_size, reshape_type)) { - if (!GetDefaultReshapeType(original_format, old_dims_size, valid_reshape_type)) { - return true; - } - } - - /* 3. Check whether the dimension of original shape is less than or equal to - * the length of reshape type. If the dimension of original shape if larger, - * we cannot find suitable posotion for all axis in original shape and we just return. */ - if (old_dims_size > valid_reshape_type.length()) { - GELOGW("Dimension %zu of tensor %u in %s exceeds the length of the reshape type, which is %zu.", - old_dims_size, tensor_index, op_type.c_str(), valid_reshape_type.length()); - return true; - } - - /* 4. Expand dimension. */ - ExpandByReshapeType(shape, original_format, old_dims_size, full_size, valid_reshape_type); - return true; -} - -bool ExpandRangeDimension(const std::string &op_type, const ge::Format &original_format, - const ge::Format &final_format, const uint32_t &tensor_index, const std::string &reshape_type, - std::vector> &ranges) { - std::vector range_upper; - std::vector range_low; - for (auto &i : ranges) { - range_low.emplace_back(i.first); - range_upper.emplace_back(i.second); - } - - ge::GeShape shape_low(range_low); - ge::GeShape shape_upper(range_upper); - auto primary_format = static_cast(ge::GetPrimaryFormat(final_format)); - bool res = ExpandDimension(op_type, original_format, primary_format, tensor_index, reshape_type, shape_low) && - ExpandDimension(op_type, original_format, primary_format, tensor_index, reshape_type, shape_upper); - if (!res || (shape_low.GetDimNum() != shape_upper.GetDimNum())) { - return false; - } - ranges.clear(); - for (size_t idx = 0; idx < shape_low.GetDimNum(); ++idx) { - ranges.emplace_back(std::pair(shape_low.GetDim(idx), shape_upper.GetDim(idx))); - } - return res; -} - -ExpandDimension::ExpandDimension() {} -ExpandDimension::~ExpandDimension() {} - -int64_t ExpandDimension::GenerateReshapeType(const ge::Format &origin_format, const ge::Format &format, - const size_t &origin_dim_size, const std::string &reshape_type) { - auto primary_format = static_cast(ge::GetPrimaryFormat(format)); - GELOGD("Begin to generate integer reshape type, original format[%d], format[%d], dim size[%zu], reshape type[%s].", - origin_format, primary_format, origin_dim_size, reshape_type.c_str()); - int64_t ret_reshape_type = 0; - size_t full_size = 0; - if (!GetFormatFullSize(origin_format, full_size)) { - return ret_reshape_type; - } - if (!IsNeedExpand(origin_format, primary_format, origin_dim_size, full_size, reshape_type)) { - return ret_reshape_type; - } - - std::string valid_shape_type = reshape_type; - if (!IsReshapeTypeValid(origin_format, origin_dim_size, reshape_type)) { - if (!GetDefaultReshapeType(origin_format, origin_dim_size, valid_shape_type)) { - return ret_reshape_type; - } - GELOGD("Invalid reshape type [%s], using default reshape type [%s]", - reshape_type.c_str(), valid_shape_type.c_str()); - } - - if (origin_dim_size > valid_shape_type.length()) { - GELOGW("The length of reshape type[%s] is shorter than dim size[%zu]. Can not generate integer reshape type.", - valid_shape_type.c_str(), origin_dim_size); - return ret_reshape_type; - } - - uint32_t format_key = GenerateFormatKey(origin_format); - std::unordered_set dim_pos_set; - for (const char &dim : valid_shape_type.substr(0, origin_dim_size)) { - uint32_t axis_key = format_key | (static_cast(dim) & 0xff); - auto iter_axis_index = AXIS_INDEX_OF_FORMAT.find(axis_key); - if (iter_axis_index != AXIS_INDEX_OF_FORMAT.end()) { - dim_pos_set.emplace(iter_axis_index->second); - } - } - - for (size_t i = 0; i < full_size; i++) { - if (dim_pos_set.count(static_cast(i)) == 0) { - ret_reshape_type = ret_reshape_type | (1 << i); - } - } - - ret_reshape_type = ret_reshape_type | (static_cast(full_size) << kMaxReshapeTypeSize); - GELOGD("Integer reshape type [%s] has been generated for the original format [%d], with dim size [%zu] and reshape type [%s].", - std::bitset(ret_reshape_type).to_string().c_str(), origin_format, origin_dim_size, - valid_shape_type.c_str()); - return ret_reshape_type; -} - -bool ExpandDimension::GenerateReshapeType(const ge::Format &origin_format, const ge::Format &format, - const size_t &origin_dim_size, const std::string &reshape_type, - int64_t &reshape_type_mask) { - auto primary_format = static_cast(ge::GetPrimaryFormat(format)); - GELOGD("Begin to generate integer reshape type, original format[%d], format[%d], dim size[%zu], reshape type[%s].", - origin_format, primary_format, origin_dim_size, reshape_type.c_str()); - size_t full_size = 0; - if (!GetFormatFullSize(origin_format, full_size)) { - return true; - } - if (!IsNeedExpand(origin_format, primary_format, origin_dim_size, full_size, reshape_type)) { - return true; - } - - std::string valid_shape_type = reshape_type; - if (!IsReshapeTypeValid(origin_format, origin_dim_size, reshape_type)) { - if (!GetDefaultReshapeType(origin_format, origin_dim_size, valid_shape_type)) { - return true; - } - GELOGD("Invalid reshape type [%s], using default reshape type [%s]", - reshape_type.c_str(), valid_shape_type.c_str()); - } - - if (origin_dim_size > valid_shape_type.length()) { - GELOGE(ge::GRAPH_FAILED, "The length of reshape type[%s] is longer than dim size[%zu]. Can not generate integer reshape type.", - valid_shape_type.c_str(), origin_dim_size); - return false; - } - - uint32_t format_key = GenerateFormatKey(origin_format); - std::unordered_set dim_pos_set; - for (const char &dim : valid_shape_type.substr(0, origin_dim_size)) { - uint32_t axis_key = format_key | (static_cast(dim) & 0xff); - auto iter_axis_index = AXIS_INDEX_OF_FORMAT.find(axis_key); - if (iter_axis_index != AXIS_INDEX_OF_FORMAT.end()) { - dim_pos_set.emplace(iter_axis_index->second); - } - } - - for (size_t i = 0; i < full_size; i++) { - if (dim_pos_set.count(static_cast(i)) == 0) { - reshape_type_mask = reshape_type_mask | (1 << i); - } - } - - reshape_type_mask = reshape_type_mask | (static_cast(full_size) << kMaxReshapeTypeSize); - GELOGD("Integer reshape type [%s] has been generated for the original format [%d], with dim size [%zu] and reshape type [%s].", - std::bitset(reshape_type_mask).to_string().c_str(), origin_format, origin_dim_size, - valid_shape_type.c_str()); - return true; -} - -bool ExpandDimension::GenerateReshapeTypeByMask(const ge::Format &origin_format, const size_t &origin_dim_size, - const int64_t &reshape_type_mask, std::string &reshape_type, - std::string &failed_reason) { - if (origin_format == ge::FORMAT_ND) { - if (reshape_type_mask == 0) { - return true; - } else { - failed_reason = "Can not generate reshape type for ND format."; - GELOGI("%s", failed_reason.c_str()); - return false; - } - } - - std::string origin_format_str = ge::TypeUtils::FormatToSerialString(origin_format); - size_t full_size = 0; - if (!GetFormatFullSize(origin_format, full_size)) { - failed_reason = origin_format_str + " is not supported for expanding dims."; - GELOGI("%s", failed_reason.c_str()); - return false; - } - - if (reshape_type_mask == 0 && origin_dim_size == full_size) { - reshape_type = origin_format_str; - return true; - } - - size_t full_size_mask = static_cast(reshape_type_mask >> kMaxReshapeTypeSize); - if (full_size != full_size_mask) { - failed_reason = "Full size[" + std::to_string(full_size_mask) + "] from reshape mask is not correct,"; - failed_reason += " it should be[" + std::to_string(full_size) + "]."; - GELOGI("%s", failed_reason.c_str()); - return false; - } - - reshape_type.clear(); - size_t dim_count = 0; - for (size_t i = 0; i < full_size; ++i) { - if ((reshape_type_mask & (1 << i)) == 0) { - reshape_type += origin_format_str.at(i); - dim_count++; - } - } - - if (dim_count != origin_dim_size) { - std::string bit_str = std::bitset(reshape_type_mask).to_string(); - failed_reason = "[" + bit_str + "] is not correct when dim size is [" + std::to_string(origin_dim_size) + "]."; - GELOGI("%s", failed_reason.c_str()); - return false; - } - return true; -} - -bool ExpandDimension::IsNeedExpand(const ge::Format &origin_format, const ge::Format &format, - const size_t &origin_dim_size, const size_t &full_size, - const std::string &reshape_type) { - if (origin_dim_size >= full_size) { - return false; - } - if (reshape_type == RESHAPE_TYPE_FORBIDDEN) { - return false; - } - if (kSupportedTransFormat.count(format) != 0) { - return false; - } - if (origin_format == ge::FORMAT_ND && format == ge::FORMAT_FRACTAL_Z) { - return false; - } - return true; -} - -bool ExpandDimension::IsReshapeTypeValid(const ge::Format &origin_format, const size_t &origin_dim_size, - const std::string &reshape_type) { - if (reshape_type.empty()) { - return origin_dim_size == 0; - } - int32_t pos = -1; - uint32_t format_key = GenerateFormatKey(origin_format); - uint32_t axis_key = 0; - for (const char &dim : reshape_type) { - axis_key = format_key | (static_cast(dim) & 0xff); - auto iter = AXIS_INDEX_OF_FORMAT.find(axis_key); - if (iter == AXIS_INDEX_OF_FORMAT.end()) { - return false; - } - if (iter->second > pos) { - pos = iter->second; - } else { - return false; - } - } - return true; -} - -bool ExpandDimension::GetDefaultReshapeType(const ge::Format &origin_format, const size_t &origin_dim_size, - std::string &reshape_type) { - int32_t default_key = GenerateReshapeTypeKey(origin_format, origin_dim_size); - auto iter = DEFAULT_RESHAPE_TYPE.find(default_key); - if (iter == DEFAULT_RESHAPE_TYPE.end()) { - GELOGW("Dim size %zu is invalid, default reshape type not found.", origin_dim_size); - return false; - } - - reshape_type = iter->second; - return true; -} - -void ExpandDimension::ExpandDims(const int64_t &reshape_type, ge::GeShape &shape) { - GELOGD("Begin to expand dims, reshape type[%" PRId64 "], shape[%s].", reshape_type, shape.ToString().c_str()); - gert::Shape inner_shape; - GeShapeToRtShape(shape, inner_shape); - ExpandDims(reshape_type, inner_shape); - RtShapeToGeShape(inner_shape, shape); - GELOGD("After expanding dims, shape[%s].", shape.ToString().c_str()); -} - -void ExpandDimension::ExpandDims(const int64_t &reshape_type, const ge::GeShape &origin_shape, ge::GeShape &shape) { - GELOGD("Begin to expand dims, reshape type[%" PRId64 "], origin shape[%s].", reshape_type, - origin_shape.ToString().c_str()); - gert::Shape inner_ori_shape; - GeShapeToRtShape(origin_shape, inner_ori_shape); - gert::Shape inner_shape; - GeShapeToRtShape(shape, inner_shape); - ExpandDims(reshape_type, inner_ori_shape, inner_shape); - RtShapeToGeShape(inner_shape, shape); - GELOGD("After expanding dims, shape[%s].", shape.ToString().c_str()); -} - -void ExpandDimension::ExpandDims(const int64_t &reshape_type, gert::Shape &shape) { - if (reshape_type == 0) { - return; - } - gert::ExpandDimsType expand_dims_type(reshape_type); - expand_dims_type.Expand(shape); -} - -void ExpandDimension::ExpandDims(const int64_t &reshape_type, const gert::Shape &origin_shape, gert::Shape &shape) { - if (reshape_type == 0) { - return; - } - gert::ExpandDimsType expand_dims_type(reshape_type); - expand_dims_type.Expand(origin_shape, shape); -} - -bool ExpandDimension::GetFormatFullSize(const ge::Format &format, size_t &full_size) { - auto iter = FULL_SIZE_OF_FORMAT.find(format); - if (iter == FULL_SIZE_OF_FORMAT.end()) { - return false; - } - full_size = iter->second; - return true; -} - -int32_t ExpandDimension::GetAxisIndexByName(char ch, const ge::Format &format) { - uint32_t format_key = GenerateFormatKey(format); - uint32_t axis_key = 0; - axis_key = format_key | (static_cast(ch) & 0xff); - auto iter = AXIS_INDEX_OF_FORMAT.find(axis_key); - if (iter == AXIS_INDEX_OF_FORMAT.end()) { - return -1; - } - return iter->second; -} -int64_t ExpandDimension::GetReshapeAxicValue(const int64_t &reshape_type_mask, - const ge::GeShape &shape, int32_t axis_index) { - GELOGD("axis_index = %d.", axis_index); - if (axis_index == -1) { - return -1; - } - gert::ExpandDimsType expand_dims_type(reshape_type_mask); - if (!expand_dims_type.IsExpandIndex(axis_index)) { - GELOGD("axis_index is %d.", axis_index); - } - return shape.GetDim(static_cast(axis_index)); -} -int64_t ExpandDimension::GetReshapeAxicValueByName(const int64_t &reshape_type_mask, char ch, - const ge::GeShape &shape, const ge::Format &format) { - auto idx = GetAxisIndexByName(ch, format); - return GetReshapeAxicValue(reshape_type_mask, shape, idx); -} -} // namespace transformer diff --git a/third_party/transformer/src/transfer_range_according_to_format.cc b/third_party/transformer/src/transfer_range_according_to_format.cc deleted file mode 100644 index a29f2bf2db8308f4e4d1314190557d2ce9fca591..0000000000000000000000000000000000000000 --- a/third_party/transformer/src/transfer_range_according_to_format.cc +++ /dev/null @@ -1,99 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "transfer_range_according_to_format.h" -#include - -namespace transformer { -bool RangeTransferAccordingToFormat::GetRangeAccordingToFormat(const ge::OpDescPtr &op_desc, - RangeAndFormat &range_and_format_info) { - /* The default new range is old range */ - std::vector range_upper_old; - std::vector range_low_old; - for (auto &i : range_and_format_info.old_range) { - range_low_old.emplace_back(i.first); - range_upper_old.emplace_back(i.second); - } - - ge::GeShape shape_low(range_low_old); - ge::GeShape shape_upper(range_upper_old); - transformer::ShapeAndFormat shape_and_format_info_low {shape_low, range_and_format_info.old_format, - range_and_format_info.new_format, range_and_format_info.current_data_type}; - transformer::ShapeAndFormat shape_and_format_info_upper {shape_upper, range_and_format_info.old_format, - range_and_format_info.new_format, range_and_format_info.current_data_type}; - ShapeTransferAccordingToFormat shape_transfer; - bool res = (shape_transfer.GetShapeAccordingToFormat(op_desc, shape_and_format_info_low) && - shape_transfer.GetShapeAccordingToFormat(op_desc, shape_and_format_info_upper)); - if (!res || (shape_low.GetDimNum() != shape_upper.GetDimNum())) { - return false; - } - range_and_format_info.new_range.clear(); - for (size_t i = 0; i < shape_low.GetDimNum(); ++i) { - range_and_format_info.new_range.emplace_back(shape_low.GetDim(i), shape_upper.GetDim(i)); - } - return res; -} - -bool RangeTransferAccordingToFormat::GetRangeAccordingToFormat(const ExtAxisOpValue &op_value, - RangeAndFormat &range_and_format_info) { - /* The default new range is old range */ - std::vector range_upper_old; - std::vector range_low_old; - for (auto &i : range_and_format_info.old_range) { - range_low_old.emplace_back(i.first); - range_upper_old.emplace_back(i.second); - } - - ge::GeShape shape_low(range_low_old); - ge::GeShape shape_upper(range_upper_old); - transformer::ShapeAndFormat shape_and_format_info_low {shape_low, range_and_format_info.old_format, - range_and_format_info.new_format, range_and_format_info.current_data_type}; - transformer::ShapeAndFormat shape_and_format_info_upper {shape_upper, range_and_format_info.old_format, - range_and_format_info.new_format, range_and_format_info.current_data_type}; - ShapeTransferAccordingToFormat shape_transfer; - bool res = (shape_transfer.GetShapeAccordingToFormat(op_value, shape_and_format_info_low) && - shape_transfer.GetShapeAccordingToFormat(op_value, shape_and_format_info_upper)); - if (!res || (shape_low.GetDimNum() != shape_upper.GetDimNum())) { - return false; - } - range_and_format_info.new_range.clear(); - for (size_t i = 0; i < shape_low.GetDimNum(); ++i) { - range_and_format_info.new_range.emplace_back(shape_low.GetDim(i), shape_upper.GetDim(i)); - } - return res; -} - -bool RangeTransferAccordingToFormat::GetRangeAccordingToFormat(RangeAndFormat &range_and_format_info) { - /* The default new range is old range */ - std::vector range_upper_old; - std::vector range_low_old; - for (auto &i : range_and_format_info.old_range) { - range_low_old.emplace_back(i.first); - range_upper_old.emplace_back(i.second); - } - - ge::GeShape shape_low(range_low_old); - ge::GeShape shape_upper(range_upper_old); - transformer::ShapeAndFormat shape_and_format_info_low {shape_low, range_and_format_info.old_format, - range_and_format_info.new_format, range_and_format_info.current_data_type}; - transformer::ShapeAndFormat shape_and_format_info_upper {shape_upper, range_and_format_info.old_format, - range_and_format_info.new_format, range_and_format_info.current_data_type}; - ShapeTransferAccordingToFormat shape_transfer; - bool res = (shape_transfer.GetShapeAccordingToFormat(shape_and_format_info_low) && - shape_transfer.GetShapeAccordingToFormat(shape_and_format_info_upper)); - if (!res || (shape_low.GetDimNum() != shape_upper.GetDimNum())) { - return false; - } - range_and_format_info.new_range.clear(); - for (size_t i = 0; i < shape_low.GetDimNum(); ++i) { - range_and_format_info.new_range.emplace_back(shape_low.GetDim(i), shape_upper.GetDim(i)); - } - return res; -} -}; // namespace fe diff --git a/third_party/transformer/src/transfer_shape_according_to_format.cc b/third_party/transformer/src/transfer_shape_according_to_format.cc deleted file mode 100644 index c124a393984b6c0d899c569a34d43218bc12b387..0000000000000000000000000000000000000000 --- a/third_party/transformer/src/transfer_shape_according_to_format.cc +++ /dev/null @@ -1,183 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "transfer_shape_according_to_format.h" -#include -#include "axis_constants.h" -#include "graph/utils/attr_utils.h" -#include "transfer_shape_utils.h" - -namespace transformer { -namespace { - const std::string kAttrHiddenSize = "hidden_size"; - const std::string kAttrInputSize = "input_size"; - const std::string kAttrStateSize = "state_size"; - const int64_t kM0DefaultVal = 16; - - void GeShapeToRtShape(const ge::GeShape &ge_shape, gert::Shape &rt_shape) { - rt_shape.SetDimNum(0); - for (size_t i = 0; i < ge_shape.GetDimNum(); ++i) { - rt_shape.AppendDim(ge_shape.GetDim(i)); - } - } - - void RtShapeToGeShape(const gert::Shape &rt_shape, ge::GeShape &ge_shape) { - ge_shape.SetDimNum(0); - for (size_t i = 0; i < rt_shape.GetDimNum(); ++i) { - ge_shape.AppendDim(rt_shape.GetDim(i)); - } - } -} - -ShapeTransferAccordingToFormat::ShapeTransferAccordingToFormat() {} - -bool ShapeTransferAccordingToFormat::GetShapeAccordingToFormat(const ge::OpDescPtr &op_desc, - ShapeAndFormat &shapeAndFormatInfo) { - if (shapeAndFormatInfo.oldShape.IsUnknownDimNum()) { - return true; - } - gert::Shape shape; - GeShapeToRtShape(shapeAndFormatInfo.oldShape, shape); - ExtAxisValue ext_axis; - InitExtAxisValue(op_desc, ext_axis); - bool ret = TransferShapeUtils::TransferShape(shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.newFormat, - shapeAndFormatInfo.currentDataType, ext_axis, shape); - RtShapeToGeShape(shape, shapeAndFormatInfo.oldShape); - return ret; -} - -bool ShapeTransferAccordingToFormat::GetShapeAccordingToFormat(const ExtAxisOpValue &op_value, - ShapeAndFormat &shapeAndFormatInfo) { - if (shapeAndFormatInfo.oldShape.IsUnknownDimNum()) { - return true; - } - gert::Shape shape; - GeShapeToRtShape(shapeAndFormatInfo.oldShape, shape); - ExtAxisValue ext_axis; - InitExtAxisValue(op_value, ext_axis); - bool ret = TransferShapeUtils::TransferShape(shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.newFormat, - shapeAndFormatInfo.currentDataType, ext_axis, shape); - RtShapeToGeShape(shape, shapeAndFormatInfo.oldShape); - return ret; -} - -bool ShapeTransferAccordingToFormat::GetShapeAccordingToFormat(ShapeAndFormat &shapeAndFormatInfo) { - if (shapeAndFormatInfo.oldShape.IsUnknownDimNum()) { - return true; - } - gert::Shape shape; - GeShapeToRtShape(shapeAndFormatInfo.oldShape, shape); - ExtAxisValue ext_axis = {shapeAndFormatInfo.extra_attr.input_size, shapeAndFormatInfo.extra_attr.hidden_size, - shapeAndFormatInfo.extra_attr.state_size, kM0DefaultVal}; - bool ret = TransferShapeUtils::TransferShape(shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.newFormat, - shapeAndFormatInfo.currentDataType, ext_axis, shape); - RtShapeToGeShape(shape, shapeAndFormatInfo.oldShape); - return ret; -} - -bool ShapeTransferAccordingToFormat::TransferShape(const ge::Format &origin_format, const ge::Format &format, - const ge::DataType &data_type, const ExtAxisValue &ext_axis, - ge::GeShape &shape) { - gert::Shape rt_shape; - GeShapeToRtShape(shape, rt_shape); - bool ret = TransferShapeUtils::TransferShape(origin_format, format, data_type, ext_axis, rt_shape); - RtShapeToGeShape(rt_shape, shape); - return ret; -} - -bool ShapeTransferAccordingToFormat::TransferShape(const ge::Format &origin_format, const ge::Format &format, - const ge::DataType &data_type, const ExtAxisValue &ext_axis, - const ge::GeShape &origin_shape, ge::GeShape &shape) { - gert::Shape rt_origin_shape; - GeShapeToRtShape(origin_shape, rt_origin_shape); - gert::Shape rt_shape; - GeShapeToRtShape(shape, rt_shape); - bool ret = TransferShapeUtils::TransferShape(origin_format, format, data_type, ext_axis, rt_origin_shape, rt_shape); - RtShapeToGeShape(rt_shape, shape); - return ret; -} - -bool ShapeTransferAccordingToFormat::TransferShape(const ge::Format &origin_format, const ge::Format &format, - const ge::DataType &data_type, gert::Shape &shape, - const ge::OpDescPtr op_desc, const fe::PlatFormInfos *platform_infos_ptr) { - ExtAxisValue ext_axis; - InitExtAxisValue(op_desc, ext_axis); - return TransferShapeUtils::TransferShape(origin_format, format, data_type, ext_axis, shape, platform_infos_ptr); -} - -bool ShapeTransferAccordingToFormat::TransferShape(const ge::Format &origin_format, const ge::Format &format, - const ge::DataType &data_type, gert::Shape &shape, - const ExtAxisOpValue &op_value, - const fe::PlatFormInfos *platform_infos_ptr) { - ExtAxisValue ext_axis; - InitExtAxisValue(op_value, ext_axis); - return TransferShapeUtils::TransferShape(origin_format, format, data_type, ext_axis, shape, platform_infos_ptr); -} - -bool ShapeTransferAccordingToFormat::TransferShape(const ge::Format &origin_format, const ge::Format &format, - const ge::DataType &data_type, const gert::Shape &origin_shape, - gert::Shape &shape, const ge::OpDescPtr op_desc) { - ExtAxisValue ext_axis; - InitExtAxisValue(op_desc, ext_axis); - return TransferShapeUtils::TransferShape(origin_format, format, data_type, ext_axis, origin_shape, shape); -} - -bool ShapeTransferAccordingToFormat::TransferShape(const ge::Format &origin_format, const ge::Format &format, - const ge::DataType &data_type, const gert::Shape &origin_shape, - gert::Shape &shape, const ExtAxisOpValue &op_value) { - ExtAxisValue ext_axis; - InitExtAxisValue(op_value, ext_axis); - return TransferShapeUtils::TransferShape(origin_format, format, data_type, ext_axis, origin_shape, shape); -} - -void ShapeTransferAccordingToFormat::InitExtAxisValue(const ge::OpDescPtr &op_desc, ExtAxisValue &ext_axis) { - int64_t input_size = 1; - int64_t hidden_size = 1; - int64_t state_size = -1; - if (op_desc != nullptr) { - (void)ge::AttrUtils::GetInt(op_desc, kAttrInputSize, input_size); - (void)ge::AttrUtils::GetInt(op_desc, kAttrHiddenSize, hidden_size); - (void)ge::AttrUtils::GetInt(op_desc, kAttrStateSize, state_size); - } - - ext_axis[EXT_INDEX_INPUT_SIZE] = input_size; - ext_axis[EXT_INDEX_HIDDEN_SIZE] = hidden_size; - ext_axis[EXT_INDEX_STATE_SIZE] = state_size; - ext_axis[EXT_INDEX_M0_VAL] = kM0DefaultVal; -} - -void ShapeTransferAccordingToFormat::InitExtAxisValue(const ExtAxisOpValue &op_value, ExtAxisValue &ext_axis) { - ext_axis[EXT_INDEX_INPUT_SIZE] = op_value[EXT_INDEX_INPUT_SIZE]; - ext_axis[EXT_INDEX_HIDDEN_SIZE] = op_value[EXT_INDEX_HIDDEN_SIZE]; - ext_axis[EXT_INDEX_STATE_SIZE] = op_value[EXT_INDEX_STATE_SIZE]; - ext_axis[EXT_INDEX_M0_VAL] = kM0DefaultVal; -} - -bool ShapeTransferAccordingToFormat::InitPlatformInfo() { - return TransferShapeUtils::InitPlatformInfo(); -} - -int64_t ShapeTransferAccordingToFormat::GetC0ByDtype(const ge::DataType &data_type) { - return TransferShapeUtils::GetC0ByDtype(data_type); -} -int64_t ShapeTransferAccordingToFormat::GetM0ByDtype(const ge::DataType &data_type) { - return TransferShapeUtils::GetM0ByDtype(data_type); -} -int64_t ShapeTransferAccordingToFormat::GetN0ByDtype(const ge::DataType &data_type) { - return TransferShapeUtils::GetN0ByDtype(data_type); -} - -bool ShapeTransferAccordingToFormat::GetAlignedShape(const AlignShapeInfo &align_shape_info, gert::Shape &aligned_shape) { - return TransferShapeUtils::GetAlignedShape(align_shape_info, aligned_shape); -} - -bool ShapeTransferAccordingToFormat::TransferDims(const TransferDimsInfo &transfer_dims_info, AxisIndexMapping &axis_index_mapping) { - return TransferShapeUtils::TransferDims(transfer_dims_info, axis_index_mapping); -} -} // namespace transformer diff --git a/third_party/transformer/src/transfer_shape_utils.cc b/third_party/transformer/src/transfer_shape_utils.cc deleted file mode 100644 index 0bb895ce19f224023627e17b41d3d9fbf0555b2d..0000000000000000000000000000000000000000 --- a/third_party/transformer/src/transfer_shape_utils.cc +++ /dev/null @@ -1,1528 +0,0 @@ -/* Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). - * Please refer to the License for details. You may not use this file except in compliance with the License. - * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. - * See LICENSE in the root of the software repository for the full text of the License. - * ===================================================================================================================*/ - -#include "transfer_shape_utils.h" -#include -#include "axis_constants.h" -#include "common/ge_common/string_util.h" -#include "external/graph/ge_error_codes.h" -#include "platform/platform_info.h" -#include "graph/utils/type_utils.h" -#include "graph/types.h" -#include "expand_dimension.h" - -namespace transformer { -std::array(ge::DataType::DT_MAX)> TransferShapeUtils::m0_list_{}; -std::array(ge::DataType::DT_MAX)> TransferShapeUtils::k0_list_{}; -std::array(ge::DataType::DT_MAX)> TransferShapeUtils::n0_list_{}; -const std::map TransferShapeUtils::get_aligned_shape_func_map = { - {TransferShapeType::ND_TO_ND, GetNdToNdAlignedShape}, - {TransferShapeType::ND_TO_NZ, GetNdToNzAlignedShape}, - {TransferShapeType::FULL_SIZE, GetFullSizeAlignedShape}, - {TransferShapeType::NOT_FULL_SIZE, GetNotFullSizeAlignedShape}, -}; -const std::map TransferShapeUtils::transfer_dims_func_map = { - {TransferShapeType::ND_TO_ND, GetNdToNdAxisIndexMapping}, - {TransferShapeType::ND_TO_NZ, GetNdToNzAxisIndexMapping}, - {TransferShapeType::FULL_SIZE, GetFullSizeAxisIndexMapping}, - {TransferShapeType::NOT_FULL_SIZE, GetNotFullSizeAxisIndexMapping}, -}; -namespace { - const int64_t SHAPE_NUMBER_32 = 32; - const int64_t SHAPE_NUMBER_16 = 16; - const int64_t SHAPE_NUMBER_8 = 8; - const int64_t SHAPE_NUMBER_4 = 4; - const int64_t NI = 16; - const int64_t LSTM_NI = 4; - const int64_t GROUPS_DEFAULT_VALUE = 1; - const int64_t UNKNOWN_SHAPE_VALUE = -1; - const int64_t RNN_STATE_SIZE_DEFAULT_VALUE = -1; - const size_t NUMBER_2 = 2; - const size_t MINUS_VALUE_ONE = 1; - const size_t MINUS_VALUE_TWO = 2; - - const size_t DIM_INDEX_N = 0; - const size_t DIM_INDEX_C = 1; - const size_t DIM_INDEX_H = 2; - const size_t DIM_INDEX_W = 3; - const size_t DIM_INDEX_D = 4; - const size_t DIM_INDEX_ZERO = 0; - const size_t DIM_INDEX_ONE = 1; - const size_t DIM_INDEX_TWO = 2; - const size_t DIM_INDEX_THREE = 3; - const size_t DIM_INDEX_FOUR = 4; - const size_t kM0Index = 0; - const size_t kK0Index = 1; - const size_t kN0Index = 2; - const size_t kNzMinDimNum = 2; - constexpr size_t MOKOCO_CONFIG_SIZE = 3; - const std::string kPltDtypeMKN = "DtypeMKN"; - const std::string kPltDefault = "Default"; - const std::map kFormatIndexMap = { - {ge::FORMAT_NCHW, {DIM_INDEX_ZERO, DIM_INDEX_ONE, DIM_INDEX_TWO, DIM_INDEX_THREE, DIM_INDEX_FOUR}}, - {ge::FORMAT_NHWC, {DIM_INDEX_ZERO, DIM_INDEX_THREE, DIM_INDEX_ONE, DIM_INDEX_TWO, DIM_INDEX_FOUR}}, - {ge::FORMAT_HWCN, {DIM_INDEX_THREE, DIM_INDEX_TWO, DIM_INDEX_ZERO, DIM_INDEX_ONE, DIM_INDEX_FOUR}}, - {ge::FORMAT_CHWN, {DIM_INDEX_THREE, DIM_INDEX_ZERO, DIM_INDEX_ONE, DIM_INDEX_TWO, DIM_INDEX_FOUR}}, - {ge::FORMAT_ND, {DIM_INDEX_ZERO, DIM_INDEX_ONE, DIM_INDEX_TWO, DIM_INDEX_THREE, DIM_INDEX_FOUR}}, - {ge::FORMAT_NCDHW, {DIM_INDEX_ZERO, DIM_INDEX_ONE, DIM_INDEX_THREE, DIM_INDEX_FOUR, DIM_INDEX_TWO}}, - {ge::FORMAT_NDHWC, {DIM_INDEX_ZERO, DIM_INDEX_FOUR, DIM_INDEX_TWO, DIM_INDEX_THREE, DIM_INDEX_ONE}}, - {ge::FORMAT_DHWCN, {DIM_INDEX_FOUR, DIM_INDEX_THREE, DIM_INDEX_ONE, DIM_INDEX_TWO, DIM_INDEX_ZERO}}, - {ge::FORMAT_DHWNC, {DIM_INDEX_THREE, DIM_INDEX_FOUR, DIM_INDEX_ONE, DIM_INDEX_TWO, DIM_INDEX_ZERO}} - }; - - const std::set kOriginFormatVec = { - ge::FORMAT_NCHW, ge::FORMAT_NHWC, ge::FORMAT_HWCN, - ge::FORMAT_CHWN, ge::FORMAT_NDHWC, ge::FORMAT_NCDHW, - ge::FORMAT_DHWCN, ge::FORMAT_DHWNC, ge::FORMAT_ND - }; - - inline int64_t GetGreatestCommonDivisor(int64_t x, int64_t y) { - if (y == 0) { - return x; - } - int64_t z = y; - while (x % y != 0) { - z = x % y; - x = y; - y = z; - } - return z; - } - - inline int64_t GetLeastCommonMultiple(int64_t x, int64_t y) { - if (x == 0 || y == 0) { - return 0; - } - return (x * y) / GetGreatestCommonDivisor(x, y); - } - - inline int64_t GetAsisEnlargeValue(int64_t cin, int64_t cout, int64_t c0, int64_t group) { - if (cin == 0 || cout == 0) { - return 0; - } - - return std::min(GetLeastCommonMultiple(c0 / GetGreatestCommonDivisor(cin, c0), - NI / GetGreatestCommonDivisor(cout, NI)), group); - } -} - -bool TransferShapeUtils::InitM0K0CO(const fe::PlatFormInfos *platform_infos) { - std::string default_mkn; - fe::PlatFormInfos *temp_platform_infos = const_cast(platform_infos); - if (temp_platform_infos->GetPlatformResWithLock(kPltDtypeMKN, kPltDefault, default_mkn)) { - GELOGD("Default MKN value from platform is [%s].", default_mkn.c_str()); - std::vector infos = ge::StringUtils::Split(default_mkn, ','); - if (infos.size() != MOKOCO_CONFIG_SIZE) { - return false; - } - m0_list_.fill(static_cast(std::atoi(infos[kM0Index].c_str()))); - k0_list_.fill(static_cast(std::atoi(infos[kK0Index].c_str()))); - n0_list_.fill(static_cast(std::atoi(infos[kN0Index].c_str()))); - } - - std::map m0_k0_n0_info; - temp_platform_infos->GetPlatformResWithLock(kPltDtypeMKN, m0_k0_n0_info); - for (auto &item : m0_k0_n0_info) { - if (item.first == kPltDefault) { - continue; - } - ge::DataType dtype = ge::TypeUtils::SerialStringToDataType(item.first); - if (dtype == ge::DT_UNDEFINED) { - continue; - } - std::vector infos = ge::StringUtils::Split(item.second, ','); - if (infos.size() != MOKOCO_CONFIG_SIZE) { - continue; - } - m0_list_[static_cast(dtype)] = static_cast(std::atoi(infos[kM0Index].c_str())); - k0_list_[static_cast(dtype)] = static_cast(std::atoi(infos[kK0Index].c_str())); - n0_list_[static_cast(dtype)] = static_cast(std::atoi(infos[kN0Index].c_str())); - } - return true; -} - -bool TransferShapeUtils::InitPlatformInfo(const fe::PlatFormInfos *platform_infos_ptr) { - static std::once_flag flag; - std::call_once(flag, [&platform_infos_ptr]() { - m0_list_.fill(SHAPE_NUMBER_16); - k0_list_.fill(SHAPE_NUMBER_16); - n0_list_.fill(SHAPE_NUMBER_16); - fe::PlatFormInfos platform_infos; - fe::OptionalInfos optional_infos; - if (platform_infos_ptr == nullptr) { - GELOGI("Input platform is null; now retrieving m0k0c0 from platformmanager autonomously."); - fe::PlatformInfoManager::GeInstance().InitializePlatformInfo(); - if (fe::PlatformInfoManager::GeInstance().GetPlatformInfoWithOutSocVersion(platform_infos, optional_infos) != 0) { - GELOGW("Failed to get platform info, using default MKN value."); - return false; - } - } - return platform_infos_ptr == nullptr ? TransferShapeUtils::InitM0K0CO(&platform_infos) : TransferShapeUtils::InitM0K0CO(platform_infos_ptr); - }); - return true; -} - -bool TransferShapeUtils::TransferShape(const ge::Format &origin_format, const ge::Format &format, - const ge::DataType &data_type, const ExtAxisValue &ext_axis, - gert::Shape &shape, const fe::PlatFormInfos *platform_infos_ptr) { - if (!InitPlatformInfo(platform_infos_ptr)) { - GELOGW("Init platform info failed"); - } - ge::Format primary_format = static_cast(GetPrimaryFormat(format)); - ge::Format origin_primary_format = static_cast(GetPrimaryFormat(origin_format)); - GELOGD("Original format is %s, new format is %s", - (ge::TypeUtils::FormatToSerialString(origin_primary_format)).c_str(), - (ge::TypeUtils::FormatToSerialString(primary_format)).c_str()); - if (!IsNeedTransferShape(origin_primary_format, primary_format, shape)) { - return true; - } - - if (!CheckInputParam(origin_primary_format, primary_format, data_type)) { - return false; - } - - AxisValue axis_value; - axis_value.fill(1); - int64_t group = static_cast(ge::GetSubFormat(format)); - if (group > GROUPS_DEFAULT_VALUE) { - axis_value[AXIS_G] = group; - } - - axis_value[AXIS_C0] = GetC0Value(data_type, format); - axis_value[AXIS_M0] = GetM0ByDtype(data_type); - if (primary_format == ge::FORMAT_FRACTAL_ZN_RNN || primary_format == ge::FORMAT_ND_RNN_BIAS) { - axis_value[AXIS_INPUT_SIZE] = ext_axis[EXT_INDEX_INPUT_SIZE]; - axis_value[AXIS_HIDDEN_SIZE] = ext_axis[EXT_INDEX_HIDDEN_SIZE]; - axis_value[AXIS_STATE_SIZE] = ext_axis[EXT_INDEX_STATE_SIZE]; - } - - if (!IsNeedAxisValue(primary_format, shape.GetDimNum())) { - return TransferShapeByFormat(primary_format, axis_value, shape); - } - - if (!AxisUtil::GetAxisValueByOriginFormat(origin_primary_format, shape, axis_value)) { - return true; - } - - return TransferShapeByAxisValue(primary_format, axis_value, shape); -} - -bool TransferShapeUtils::TransferShape(const ge::Format &origin_format, const ge::Format &format, - const ge::DataType &data_type, const ExtAxisValue &ext_axis, - const gert::Shape &origin_shape, gert::Shape &shape, const fe::PlatFormInfos *platform_infos_ptr) { - if (!InitPlatformInfo(platform_infos_ptr)) { - GELOGW("Init platform info failed"); - } - ge::Format primary_format = static_cast(GetPrimaryFormat(format)); - ge::Format origin_primary_format = static_cast(GetPrimaryFormat(origin_format)); - GELOGD("Transfer shape from original format[%s] to format [%s].", - (ge::TypeUtils::FormatToSerialString(origin_primary_format)).c_str(), - (ge::TypeUtils::FormatToSerialString(primary_format)).c_str()); - if (!IsNeedTransferShape(origin_primary_format, primary_format, origin_shape)) { - return true; - } - - if (!CheckInputParam(origin_primary_format, primary_format, data_type)) { - return false; - } - - int64_t c0 = GetC0Value(data_type, format); - int64_t m0 = GetM0ByDtype(data_type); - if (!IsNeedAxisValue(primary_format, origin_shape.GetDimNum())) { - return TransferShapeByOriginShape(primary_format, c0, m0, ext_axis, origin_shape, shape); - } else { - return TransferShapeByFormatIndex(origin_primary_format, format, c0, origin_shape, shape); - } -} - -int64_t TransferShapeUtils::GetC0ByDtype(const ge::DataType &data_type) { - if (static_cast(data_type) < k0_list_.size()) { - return static_cast(k0_list_[static_cast(data_type)]); - } - return SHAPE_NUMBER_16; -} - -int64_t TransferShapeUtils::GetM0ByDtype(const ge::DataType &data_type) { - if (static_cast(data_type) < m0_list_.size()) { - return static_cast(m0_list_[static_cast(data_type)]); - } - return SHAPE_NUMBER_16; -} - -int64_t TransferShapeUtils::GetN0ByDtype(const ge::DataType &data_type) { - if (static_cast(data_type) < n0_list_.size()) { - return static_cast(n0_list_[static_cast(data_type)]); - } - return SHAPE_NUMBER_16; -} - -bool TransferShapeUtils::IsNeedTransferShape(const ge::Format &origin_format, const ge::Format &format, - const gert::Shape &shape) { - if (origin_format == ge::FORMAT_ND && kOriginFormatVec.count(format) > 0) { - GELOGD("No need for shape transformation from ND to the original format."); - return false; - } - - if (shape.IsScalar()) { - GELOGD("Do not need to do shape transformation if the shape is scalar."); - return false; - } - return true; -} - -bool TransferShapeUtils::CheckInputParam(const ge::Format &origin_format, const ge::Format &primary_format, - const ge::DataType &data_type) { - bool invalid_format = (origin_format == ge::FORMAT_RESERVED || origin_format >= ge::FORMAT_END) || - (primary_format == ge::FORMAT_RESERVED || primary_format >= ge::FORMAT_END); - if (invalid_format) { - GELOGE(ge::GRAPH_FAILED, "Origin format %u or new format %u is invalid.", origin_format, primary_format); - return false; - } - - if (data_type == ge::DT_UNDEFINED || data_type >= ge::DT_MAX) { - GELOGE(ge::GRAPH_FAILED, "DataType %u is invalid.", origin_format); - return false; - } - - return true; -} - -int64_t TransferShapeUtils::GetC0Value(const ge::DataType &data_type, const ge::Format &format) { - // The value of C0 should be 4 while format is 5HD-4 or FRAZ-4 - ge::Format primary_format = static_cast(GetPrimaryFormat(format)); - if (primary_format == ge::FORMAT_NC1HWC0_C04) { - return SHAPE_NUMBER_4; - } - if (primary_format == ge::FORMAT_FRACTAL_NZ_C0_16) { - return SHAPE_NUMBER_16; - } - if (primary_format == ge::FORMAT_FRACTAL_NZ_C0_32) { - return SHAPE_NUMBER_32; - } - - if (ge::HasC0Format(format)) { - return ge::GetC0Value(format); - } - - return GetC0ByDtype(data_type); -} - -bool TransferShapeUtils::IsNeedAxisValue(const ge::Format &format, const size_t &origin_dim_size) { - if (kFormatNZSet.count(format) > 0 || format == ge::FORMAT_FRACTAL_ZN_RNN || - format == ge::FORMAT_ND_RNN_BIAS || format == ge::FORMAT_NYUV_A) { - return false; - } - if (format == ge::FORMAT_FRACTAL_Z && origin_dim_size == DIM_SIZE_TWO) { - return false; - } - return true; -} - -bool TransferShapeUtils::TransferShapeByFormat(const ge::Format &primary_format, const AxisValue &axis_value, - gert::Shape &shape) { - switch (primary_format) { - case ge::FORMAT_FRACTAL_Z: - return GetFzShapeByAxisValue(axis_value, shape); - case ge::FORMAT_FRACTAL_NZ: - case ge::FORMAT_FRACTAL_NZ_C0_16: - case ge::FORMAT_FRACTAL_NZ_C0_32: - return GetNzShapeByAxisValue(axis_value, shape); // need c0 - case ge::FORMAT_FRACTAL_ZN_RNN: - return GetFznRNNShapeByAxisValue(axis_value, shape); // need c0, input, hidden, state - case ge::FORMAT_ND_RNN_BIAS: - return GetNDRNNShapeByAxisValue(axis_value, shape); // need c0, input, hidden, state - case ge::FORMAT_NYUV_A: - return GetNYUVShape(shape); - default: - GELOGD("Cannot obtain new shape with format %d.", primary_format); - return true; - } -} - -bool TransferShapeUtils::TransferShapeByAxisValue(const ge::Format &primary_format, const AxisValue &axis_value, - gert::Shape &shape) { - switch (primary_format) { - case ge::FORMAT_NCHW: - return GetNCHWShapeByAxisValue(axis_value, shape); - case ge::FORMAT_NHWC: - return GetNHWCShapeByAxisValue(axis_value, shape); - case ge::FORMAT_HWCN: - return GetHWCNShapeByAxisValue(axis_value, shape); - case ge::FORMAT_CHWN: - return GetCHWNShapeByAxisValue(axis_value, shape); - case ge::FORMAT_NDHWC: - return GetNDHWCShapeByAxisValue(axis_value, shape); - case ge::FORMAT_NCDHW: - return GetNCDHWShapeByAxisValue(axis_value, shape); - case ge::FORMAT_DHWCN: - return GetDHWCNShapeByAxisValue(axis_value, shape); - case ge::FORMAT_DHWNC: - return GetDHWNCShapeByAxisValue(axis_value, shape); - case ge::FORMAT_NC1HWC0: - case ge::FORMAT_NC1HWC0_C04: - return GetNC1HWC0ShapeByAxisValue(axis_value, shape); - case ge::FORMAT_C1HWC0: - return GetC1HWC0ShapeByAxisValue(axis_value, shape); - case ge::FORMAT_NDC1HWC0: - return GetNDC1HWC0ShapeByAxisValue(axis_value, shape); - case ge::FORMAT_C1HWNCoC0: - return GetC1HWNCoC0ShapeByAxisValue(axis_value, shape); - case ge::FORMAT_FRACTAL_Z: - return GetFzShapeByAxisValue(axis_value, shape); - case ge::FORMAT_FRACTAL_Z_WINO: - return GetFzWinoShapeByAxisValue(axis_value, shape); - case ge::FORMAT_FRACTAL_Z_C04: - return GetFzC04ShapeByAxisValue(axis_value, shape); - case ge::FORMAT_FRACTAL_Z_3D: - return GetFz3DShapeByAxisValue(axis_value, shape); - case ge::FORMAT_FRACTAL_Z_3D_TRANSPOSE: - return GetFz3DTransposeShapeByAxisValue(axis_value, shape); - case ge::FORMAT_FRACTAL_ZN_LSTM: - return GetFzLstmShapeByAxisValue(axis_value, shape); - default: - GELOGD("Cannot obtain new shape with format %d.", primary_format); - return true; - } -} - -bool TransferShapeUtils::TransferShapeByOriginShape(const ge::Format &primary_format, - const int64_t &c0, const int64_t &m0, const ExtAxisValue &ext_axis, - const gert::Shape &origin_shape, gert::Shape &shape) { - switch (primary_format) { - case ge::FORMAT_FRACTAL_Z: - return GetFractalZShape(c0, origin_shape, shape); - case ge::FORMAT_FRACTAL_NZ: - case ge::FORMAT_FRACTAL_NZ_C0_16: - case ge::FORMAT_FRACTAL_NZ_C0_32: - return GetFractalNzShape(c0, m0, origin_shape, shape); // need c0 - case ge::FORMAT_FRACTAL_ZN_RNN: - return GetFractalZnRnnShape(ext_axis, c0, origin_shape, shape); // need c0, input, hidden, state - case ge::FORMAT_ND_RNN_BIAS: - return GetNdRnnBiasShape(ext_axis, c0, origin_shape, shape); // need c0, input, hidden, state - default: - GELOGD("Cannot obtain new shape with format %d.", primary_format); - return true; - } -} - -bool TransferShapeUtils::TransferShapeByFormatIndex(const ge::Format &origin_format, const ge::Format &format, - const int64_t &c0, const gert::Shape &origin_shape, - gert::Shape &shape) { - std::map::const_iterator iter = kFormatIndexMap.find(origin_format); - if (iter == kFormatIndexMap.end()) { - return true; - } - ge::Format primary_format = static_cast(GetPrimaryFormat(format)); - int64_t group = static_cast(ge::GetSubFormat(format)); - switch (primary_format) { - case ge::FORMAT_NCHW: - return GetNCHWShape(iter->second, origin_shape, shape); - case ge::FORMAT_NHWC: - return GetNHWCShape(iter->second, origin_shape, shape); - case ge::FORMAT_HWCN: - return GetHWCNShape(iter->second, origin_shape, shape); - case ge::FORMAT_CHWN: - return GetCHWNShape(iter->second, origin_shape, shape); - case ge::FORMAT_NDHWC: - return GetNDHWCShape(iter->second, origin_shape, shape); - case ge::FORMAT_NCDHW: - return GetNCDHWShape(iter->second, origin_shape, shape); - case ge::FORMAT_DHWCN: - return GetDHWCNShape(iter->second, origin_shape, shape); - case ge::FORMAT_DHWNC: - return GetDHWNCShape(iter->second, origin_shape, shape); - case ge::FORMAT_NC1HWC0: - case ge::FORMAT_NC1HWC0_C04: - return GetNC1HWC0Shape(iter->second, c0, origin_shape, shape); - case ge::FORMAT_C1HWC0: - return GetC1HWC0Shape(iter->second, c0, origin_shape, shape); - case ge::FORMAT_NDC1HWC0: - return GetNDC1HWC0Shape(iter->second, c0, origin_shape, shape); - case ge::FORMAT_C1HWNCoC0: - return GetC1HWNCoC0Shape(iter->second, c0, origin_shape, shape); - case ge::FORMAT_FRACTAL_Z: - return GetFractalZShape(iter->second, c0, group, origin_shape, shape); - case ge::FORMAT_FRACTAL_Z_3D: - return GetFractalZ3DShape(iter->second, c0, group, origin_shape, shape); - case ge::FORMAT_FRACTAL_Z_C04: - return GetFractalZC04Shape(iter->second, c0, origin_shape, shape); - case ge::FORMAT_FRACTAL_Z_3D_TRANSPOSE: - return GetFractalZ3DTransposeShape(iter->second, c0, origin_shape, shape); - case ge::FORMAT_FRACTAL_ZN_LSTM: - return GetFractalZLstmShape(iter->second, origin_shape, shape); - case ge::FORMAT_FRACTAL_Z_WINO: - return GetFractalZWinoShape(iter->second, c0, origin_shape, shape); - default: { - GELOGD("Cannot obtain new shape with format %d.", primary_format); - return true; - } - } -} - -bool TransferShapeUtils::GetNCHWShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) { - shape.SetDimNum(0); - shape.AppendDim(axis_value[AXIS_N]); - shape.AppendDim(axis_value[AXIS_C]); - shape.AppendDim(axis_value[AXIS_H]); - shape.AppendDim(axis_value[AXIS_W]); - return true; -} - -bool TransferShapeUtils::GetNHWCShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) { - shape.SetDimNum(0); - shape.AppendDim(axis_value[AXIS_N]); - shape.AppendDim(axis_value[AXIS_H]); - shape.AppendDim(axis_value[AXIS_W]); - shape.AppendDim(axis_value[AXIS_C]); - return true; -} - -bool TransferShapeUtils::GetHWCNShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) { - shape.SetDimNum(0); - shape.AppendDim(axis_value[AXIS_H]); - shape.AppendDim(axis_value[AXIS_W]); - shape.AppendDim(axis_value[AXIS_C]); - shape.AppendDim(axis_value[AXIS_N]); - return true; -} - -bool TransferShapeUtils::GetCHWNShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) { - shape.SetDimNum(0); - shape.AppendDim(axis_value[AXIS_C]); - shape.AppendDim(axis_value[AXIS_H]); - shape.AppendDim(axis_value[AXIS_W]); - shape.AppendDim(axis_value[AXIS_N]); - return true; -} - -bool TransferShapeUtils::GetNDHWCShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) { - shape.SetDimNum(0); - shape.AppendDim(axis_value[AXIS_N]); - shape.AppendDim(axis_value[AXIS_D]); - shape.AppendDim(axis_value[AXIS_H]); - shape.AppendDim(axis_value[AXIS_W]); - shape.AppendDim(axis_value[AXIS_C]); - return true; -} - -bool TransferShapeUtils::GetNCDHWShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) { - shape.SetDimNum(0); - shape.AppendDim(axis_value[AXIS_N]); - shape.AppendDim(axis_value[AXIS_C]); - shape.AppendDim(axis_value[AXIS_D]); - shape.AppendDim(axis_value[AXIS_H]); - shape.AppendDim(axis_value[AXIS_W]); - return true; -} - -bool TransferShapeUtils::GetDHWCNShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) { - shape.SetDimNum(0); - shape.AppendDim(axis_value[AXIS_D]); - shape.AppendDim(axis_value[AXIS_H]); - shape.AppendDim(axis_value[AXIS_W]); - shape.AppendDim(axis_value[AXIS_C]); - shape.AppendDim(axis_value[AXIS_N]); - return true; -} - -bool TransferShapeUtils::GetDHWNCShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) { - shape.SetDimNum(0); - shape.AppendDim(axis_value[AXIS_D]); - shape.AppendDim(axis_value[AXIS_H]); - shape.AppendDim(axis_value[AXIS_W]); - shape.AppendDim(axis_value[AXIS_N]); - shape.AppendDim(axis_value[AXIS_C]); - return true; -} - -bool TransferShapeUtils::GetNC1HWC0ShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) { - shape.SetDimNum(0); - shape.AppendDim(axis_value[AXIS_N]); - shape.AppendDim(axis_value[AXIS_C1]); - shape.AppendDim(axis_value[AXIS_H]); - shape.AppendDim(axis_value[AXIS_W]); - shape.AppendDim(axis_value[AXIS_C0]); - return true; -} - -bool TransferShapeUtils::GetC1HWC0ShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) { - shape.SetDimNum(0); - shape.AppendDim(DivisionCeiling(axis_value[AXIS_N], axis_value[AXIS_C0])); - shape.AppendDim(axis_value[AXIS_H]); - shape.AppendDim(axis_value[AXIS_W]); - shape.AppendDim(axis_value[AXIS_C0]); - return true; -} - -bool TransferShapeUtils::GetNDC1HWC0ShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) { - shape.SetDimNum(0); - shape.AppendDim(axis_value[AXIS_N]); - shape.AppendDim(axis_value[AXIS_D]); - shape.AppendDim(axis_value[AXIS_C1]); - shape.AppendDim(axis_value[AXIS_H]); - shape.AppendDim(axis_value[AXIS_W]); - shape.AppendDim(axis_value[AXIS_C0]); - return true; -} - -bool TransferShapeUtils::GetC1HWNCoC0ShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) { - shape.SetDimNum(0); - shape.AppendDim(axis_value[AXIS_C1]); - shape.AppendDim(axis_value[AXIS_H]); - shape.AppendDim(axis_value[AXIS_W]); - shape.AppendDim(DivisionCeiling(axis_value[AXIS_N], NI)); - shape.AppendDim(axis_value[AXIS_Co]); - shape.AppendDim(axis_value[AXIS_C0]); - return true; -} - -bool TransferShapeUtils::GetFzWinoShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) { - shape.SetDimNum(0); - shape.AppendDim(axis_value[AXIS_C1]); - shape.AppendDim(DivisionCeiling(axis_value[AXIS_N], SHAPE_NUMBER_16)); - shape.AppendDim(NUMBER_2); - int64_t hw = axis_value[AXIS_H]; - MUL_OVERFLOW(hw, axis_value[AXIS_W], hw); - shape.AppendDim(hw); - shape.AppendDim(SHAPE_NUMBER_8); - shape.AppendDim(axis_value[AXIS_C0]); - return true; -} - -bool TransferShapeUtils::GetNzShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) { - CHECK(shape.IsScalar(), GELOGD("Origin shape is empty."), return true); - size_t dim_size = shape.GetDimNum(); - if (dim_size < DIM_SIZE_TWO) { - GELOGD("nd_value's dim num is less than 2."); - shape.AppendDim(1); - dim_size++; - } - /* dim_size - 1 mean the last value of original vec - * dim_size - 2 mean the second last value of original vec */ - int64_t dim_back_two = shape.GetDim(dim_size - MINUS_VALUE_TWO); - int64_t dim_back_one = shape.GetDim(dim_size - MINUS_VALUE_ONE); - shape.SetDim((dim_size - MINUS_VALUE_ONE), DivisionCeiling(dim_back_two, axis_value[AXIS_M0])); - shape.SetDim((dim_size - MINUS_VALUE_TWO), DivisionCeiling(dim_back_one, axis_value[AXIS_C0])); - shape.AppendDim(axis_value[AXIS_M0]); - shape.AppendDim(axis_value[AXIS_C0]); - return true; -} - -bool TransferShapeUtils::GetFzShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) { - size_t size_of_original_vec = shape.GetDimNum(); - if (size_of_original_vec == DIM_SIZE_TWO) { - /* size_of_original_vec - 1 mean the last value of original vec - * size_of_original_vec - 2 mean the second last value of original vec */ - shape.SetDim((size_of_original_vec - MINUS_VALUE_ONE), - DivisionCeiling(shape.GetDim(size_of_original_vec - MINUS_VALUE_ONE), SHAPE_NUMBER_16)); - shape.SetDim((size_of_original_vec - MINUS_VALUE_TWO), - DivisionCeiling(shape.GetDim(size_of_original_vec - MINUS_VALUE_TWO), axis_value[AXIS_C0])); - shape.AppendDim(SHAPE_NUMBER_16); - shape.AppendDim(axis_value[AXIS_C0]); - return true; - } - return GetFz3DShapeByAxisValue(axis_value, shape); -} - -bool TransferShapeUtils::GetFz3DShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) { - bool has_unknown_shape = axis_value[AXIS_D] == UNKNOWN_SHAPE_VALUE || axis_value[AXIS_H] == UNKNOWN_SHAPE_VALUE || - axis_value[AXIS_W] == UNKNOWN_SHAPE_VALUE || axis_value[AXIS_C] == UNKNOWN_SHAPE_VALUE; - int64_t gdhwc1 = UNKNOWN_SHAPE_VALUE; - int64_t axis_g_val = GROUPS_DEFAULT_VALUE; - int64_t axis_n_val = axis_value[AXIS_N]; - int64_t axis_c_val = axis_value[AXIS_C]; - int64_t axis_c1_val = axis_value[AXIS_C1]; - if (!has_unknown_shape) { - if (axis_value[AXIS_G] > GROUPS_DEFAULT_VALUE && axis_n_val >= axis_value[AXIS_G]) { - axis_n_val = axis_n_val / axis_value[AXIS_G]; - int64_t enlarge_value = GetAsisEnlargeValue(axis_c_val, axis_n_val, axis_value[AXIS_C0], axis_value[AXIS_G]); - axis_g_val = DivisionCeiling(axis_value[AXIS_G], enlarge_value); - MUL_OVERFLOW(axis_c_val, enlarge_value, axis_c_val); - MUL_OVERFLOW(axis_n_val, enlarge_value, axis_n_val); - axis_c1_val = DivisionCeiling(axis_c_val, axis_value[AXIS_C0]); - } - MUL_OVERFLOW(axis_g_val, axis_c1_val, gdhwc1); - MUL_OVERFLOW(gdhwc1, axis_value[AXIS_D], gdhwc1); - MUL_OVERFLOW(gdhwc1, axis_value[AXIS_H], gdhwc1); - MUL_OVERFLOW(gdhwc1, axis_value[AXIS_W], gdhwc1); - } - shape.SetDimNum(0); - shape.AppendDim(gdhwc1); - shape.AppendDim(DivisionCeiling(axis_n_val, NI)); - shape.AppendDim(NI); - shape.AppendDim(axis_value[AXIS_C0]); - return true; -} - -bool TransferShapeUtils::GetFz3DTransposeShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) { - int64_t dhwn1 = UNKNOWN_SHAPE_VALUE; - if (axis_value[AXIS_N] != UNKNOWN_SHAPE_VALUE && axis_value[AXIS_H] != UNKNOWN_SHAPE_VALUE && - axis_value[AXIS_W] != UNKNOWN_SHAPE_VALUE && axis_value[AXIS_D] != UNKNOWN_SHAPE_VALUE) { - dhwn1 = DivisionCeiling(axis_value[AXIS_N], NI); - MUL_OVERFLOW(dhwn1, axis_value[AXIS_D], dhwn1); - MUL_OVERFLOW(dhwn1, axis_value[AXIS_H], dhwn1); - MUL_OVERFLOW(dhwn1, axis_value[AXIS_W], dhwn1); - } - - shape.SetDimNum(0); - shape.AppendDim(dhwn1); - if (axis_value[AXIS_C] == UNKNOWN_SHAPE_VALUE) { - shape.AppendDim(UNKNOWN_SHAPE_VALUE); - } else { - shape.AppendDim(axis_value[AXIS_C1]); - } - shape.AppendDim(NI); - shape.AppendDim(axis_value[AXIS_C0]); - - return true; -} - -bool TransferShapeUtils::GetFzLstmShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) { - int64_t h = axis_value[AXIS_N] >> NUMBER_2; - int64_t i = axis_value[AXIS_C] - h; - int64_t first_element_of_fz_lstm = DivisionCeiling(i, NI) + DivisionCeiling(h, NI); - int64_t second_element_of_fz_lstm = DivisionCeiling(h, NI); - MUL_OVERFLOW(second_element_of_fz_lstm, LSTM_NI, second_element_of_fz_lstm); - if (axis_value[AXIS_N] == UNKNOWN_SHAPE_VALUE || axis_value[AXIS_C] == UNKNOWN_SHAPE_VALUE) { - first_element_of_fz_lstm = UNKNOWN_SHAPE_VALUE; - second_element_of_fz_lstm = UNKNOWN_SHAPE_VALUE; - } - shape.SetDimNum(0); - shape.AppendDim(first_element_of_fz_lstm); - shape.AppendDim(second_element_of_fz_lstm); - shape.AppendDim(NI); - shape.AppendDim(NI); - return true; -} - -bool TransferShapeUtils::GetFzC04ShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) { - shape.SetDimNum(0); - if (axis_value[AXIS_H] == UNKNOWN_SHAPE_VALUE || axis_value[AXIS_W] == UNKNOWN_SHAPE_VALUE) { - shape.AppendDim(UNKNOWN_SHAPE_VALUE); - } else { - int64_t x = SHAPE_NUMBER_4; - MUL_OVERFLOW(x, axis_value[AXIS_H], x); - MUL_OVERFLOW(x, axis_value[AXIS_W], x); - shape.AppendDim(DivisionCeiling(x, axis_value[AXIS_C0])); - } - shape.AppendDim(DivisionCeiling(axis_value[AXIS_N], NI)); - shape.AppendDim(NI); - shape.AppendDim(axis_value[AXIS_C0]); - return true; -} - -bool TransferShapeUtils::GetFznRNNShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) { - size_t origin_shape_size = shape.GetDimNum(); - CHECK(origin_shape_size < DIM_SIZE_TWO, GELOGW("The number of dimensions in ndValue is less than 2!"), return true); - /* check nd shape value */ - int64_t k_value = shape.GetDim(origin_shape_size - MINUS_VALUE_TWO); - int64_t hidden_or_state_size = axis_value[AXIS_HIDDEN_SIZE]; - if (axis_value[AXIS_STATE_SIZE] != RNN_STATE_SIZE_DEFAULT_VALUE) { - hidden_or_state_size = axis_value[AXIS_STATE_SIZE]; - } - - if (k_value == hidden_or_state_size + axis_value[AXIS_INPUT_SIZE]) { - // use input size and hidden size - shape.SetDim(origin_shape_size - MINUS_VALUE_TWO, - DivisionCeiling(axis_value[AXIS_INPUT_SIZE], SHAPE_NUMBER_16) + - DivisionCeiling(hidden_or_state_size, SHAPE_NUMBER_16)); - } else if (k_value == hidden_or_state_size || k_value == axis_value[AXIS_INPUT_SIZE]) { - // only use hidden size or input size - shape.SetDim(origin_shape_size - MINUS_VALUE_TWO, DivisionCeiling(k_value, SHAPE_NUMBER_16)); - } else { - return true; - } - - int64_t n_value = shape.GetDim(origin_shape_size - MINUS_VALUE_ONE); - INT64_ZEROCHECK(axis_value[AXIS_HIDDEN_SIZE]); - int64_t n_num = n_value / axis_value[AXIS_HIDDEN_SIZE]; - MUL_OVERFLOW(n_num, DivisionCeiling(axis_value[AXIS_HIDDEN_SIZE], axis_value[AXIS_C0]), n_num); - shape.SetDim(origin_shape_size - MINUS_VALUE_ONE, n_num); - shape.AppendDim(SHAPE_NUMBER_16); - shape.AppendDim(axis_value[AXIS_C0]); - return true; -} - -bool TransferShapeUtils::GetNDRNNShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape) { - CHECK(axis_value[AXIS_HIDDEN_SIZE] == 0, GELOGD("hidden_size is zero"), return true); - size_t size_of_original_vec = shape.GetDimNum(); - /* check nd shape value */ - int64_t n_num = shape.GetDim(size_of_original_vec - MINUS_VALUE_ONE) / axis_value[AXIS_HIDDEN_SIZE]; - MUL_OVERFLOW(n_num, DivisionCeiling(axis_value[AXIS_HIDDEN_SIZE], axis_value[AXIS_C0]), n_num); - MUL_OVERFLOW(n_num, axis_value[AXIS_C0], n_num); - shape.SetDim(size_of_original_vec - MINUS_VALUE_ONE, n_num); - return true; -} - -bool TransferShapeUtils::GetNCHWShape(const FormatIndex& format_index, const gert::Shape &origin_shape, - gert::Shape &shape) { - CHECK(origin_shape.GetDimNum() < DIM_SIZE_FOUR, GELOGD("Dim size is less than 4."), return true); - shape.SetDimNum(0); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_N])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_C])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_H])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_W])); - return true; -} - -bool TransferShapeUtils::GetNHWCShape(const FormatIndex& format_index, const gert::Shape &origin_shape, - gert::Shape &shape) { - CHECK(origin_shape.GetDimNum() < DIM_SIZE_FOUR, GELOGD("Dim size is less than 4."), return true); - shape.SetDimNum(0); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_N])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_H])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_W])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_C])); - return true; -} - -bool TransferShapeUtils::GetHWCNShape(const FormatIndex& format_index, const gert::Shape &origin_shape, - gert::Shape &shape) { - CHECK(origin_shape.GetDimNum() < DIM_SIZE_FOUR, GELOGD("Dim size is less than 4."), return true); - shape.SetDimNum(0); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_H])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_W])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_C])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_N])); - return true; -} - -bool TransferShapeUtils::GetCHWNShape(const FormatIndex& format_index, const gert::Shape &origin_shape, - gert::Shape &shape) { - CHECK(origin_shape.GetDimNum() < DIM_SIZE_FOUR, GELOGD("Dim size is less than 4."), return true); - shape.SetDimNum(0); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_C])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_H])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_W])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_N])); - return true; -} - -bool TransferShapeUtils::GetNDHWCShape(const FormatIndex& format_index, const gert::Shape &origin_shape, - gert::Shape &shape) { - CHECK(origin_shape.GetDimNum() < DIM_SIZE_FIVE, GELOGD("Dim size is less than 5."), return true); - shape.SetDimNum(0); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_N])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_D])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_H])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_W])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_C])); - return true; -} - -bool TransferShapeUtils::GetNCDHWShape(const FormatIndex& format_index, const gert::Shape &origin_shape, - gert::Shape &shape) { - CHECK(origin_shape.GetDimNum() < DIM_SIZE_FIVE, GELOGD("Dim size is less than 5."), return true); - shape.SetDimNum(0); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_N])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_C])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_D])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_H])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_W])); - return true; -} - -bool TransferShapeUtils::GetDHWCNShape(const FormatIndex& format_index, const gert::Shape &origin_shape, - gert::Shape &shape) { - CHECK(origin_shape.GetDimNum() < DIM_SIZE_FIVE, GELOGD("Dim size is less than 5."), return true); - shape.SetDimNum(0); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_D])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_H])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_W])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_C])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_N])); - return true; -} - -bool TransferShapeUtils::GetDHWNCShape(const FormatIndex& format_index, const gert::Shape &origin_shape, - gert::Shape &shape) { - CHECK(origin_shape.GetDimNum() < DIM_SIZE_FIVE, GELOGD("Dim size is less than 5."), return true); - shape.SetDimNum(0); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_D])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_H])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_W])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_N])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_C])); - return true; -} - -bool TransferShapeUtils::GetNC1HWC0Shape(const FormatIndex& format_index, const int64_t &c0, - const gert::Shape &origin_shape, gert::Shape &shape) { - CHECK(origin_shape.GetDimNum() < DIM_SIZE_FOUR, GELOGD("Dim size is less than 4."), return true); - shape.SetDimNum(0); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_N])); - shape.AppendDim(DivisionCeiling(origin_shape.GetDim(format_index[DIM_INDEX_C]), c0)); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_H])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_W])); - shape.AppendDim(c0); - return true; -} - -bool TransferShapeUtils::GetC1HWC0Shape(const FormatIndex &format_index, const int64_t &c0, - const gert::Shape &origin_shape, gert::Shape &shape) { - CHECK(origin_shape.GetDimNum() < DIM_SIZE_FOUR, GELOGD("Dim size is less than 4."), return true); - shape.SetDimNum(0); - shape.AppendDim(DivisionCeiling(origin_shape.GetDim(format_index[DIM_INDEX_N]), c0)); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_H])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_W])); - shape.AppendDim(c0); - return true; -} - -bool TransferShapeUtils::GetNDC1HWC0Shape(const FormatIndex& format_index, const int64_t &c0, - const gert::Shape &origin_shape, gert::Shape &shape) { - CHECK(origin_shape.GetDimNum() < DIM_SIZE_FOUR, GELOGD("Dim size is less than 4."), return true); - shape.SetDimNum(0); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_N])); - if (origin_shape.GetDimNum() == DIM_SIZE_FOUR) { - shape.AppendDim(1); - } else { - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_D])); - } - shape.AppendDim(DivisionCeiling(origin_shape.GetDim(format_index[DIM_INDEX_C]), c0)); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_H])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_W])); - shape.AppendDim(c0); - return true; -} - -bool TransferShapeUtils::GetC1HWNCoC0Shape(const FormatIndex& format_index, const int64_t &c0, - const gert::Shape &origin_shape, gert::Shape &shape) { - CHECK(origin_shape.GetDimNum() < DIM_SIZE_FOUR, GELOGD("Dim size is less than 4."), return true); - shape.SetDimNum(0); - shape.AppendDim(DivisionCeiling(origin_shape.GetDim(format_index[DIM_INDEX_C]), c0)); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_H])); - shape.AppendDim(origin_shape.GetDim(format_index[DIM_INDEX_W])); - shape.AppendDim(DivisionCeiling(origin_shape.GetDim(format_index[DIM_INDEX_N]), NI)); - shape.AppendDim(c0); - shape.AppendDim(c0); - return true; -} -bool TransferShapeUtils::GetFractalNzShape(const int64_t &c0, const int64_t &m0, - const gert::Shape &origin_shape, gert::Shape &shape) { - size_t dim_size = origin_shape.GetDimNum(); - shape.SetDimNum(0); - if (dim_size > DIM_SIZE_TWO) { - for (size_t i = 0; i < dim_size - DIM_SIZE_TWO; i++) { - shape.AppendDim(origin_shape.GetDim(i)); - } - } - - /* dim_size - 1 mean the last value of original vec - * dim_size - 2 mean the second last value of original vec */ - if (dim_size < DIM_SIZE_TWO) { - shape.AppendDim(1); - shape.AppendDim(DivisionCeiling(origin_shape.GetDim(dim_size - MINUS_VALUE_ONE), m0)); - } else { - shape.AppendDim(DivisionCeiling(origin_shape.GetDim(dim_size - MINUS_VALUE_ONE), c0)); - shape.AppendDim(DivisionCeiling(origin_shape.GetDim(dim_size - MINUS_VALUE_TWO), m0)); - } - shape.AppendDim(m0); - shape.AppendDim(c0); - return true; -} - -bool TransferShapeUtils::GetFractalZShape(const int64_t &c0, const gert::Shape &origin_shape, gert::Shape &shape) { - CHECK(origin_shape.GetDimNum() != DIM_SIZE_TWO, GELOGD("Dim size is not 2."), return true); - /* size_of_original_vec - 1 mean the last value of original vec - * size_of_original_vec - 2 mean the second last value of original vec */ - shape.SetDimNum(0); - shape.AppendDim(DivisionCeiling(origin_shape.GetDim(DIM_SIZE_TWO - MINUS_VALUE_TWO), c0)); - shape.AppendDim(DivisionCeiling(origin_shape.GetDim(DIM_SIZE_TWO - MINUS_VALUE_ONE), SHAPE_NUMBER_16)); - shape.AppendDim(SHAPE_NUMBER_16); - shape.AppendDim(c0); - return true; -} - -bool TransferShapeUtils::GetFractalZShape(const FormatIndex& format_index, const int64_t &c0, const int64_t &group, - const gert::Shape &origin_shape, gert::Shape &shape) { - CHECK(origin_shape.GetDimNum() < DIM_SIZE_FOUR, GELOGD("Dim size is less than 4."), return true); - int64_t axis_n_val = origin_shape.GetDim(format_index[DIM_INDEX_N]); - int64_t axis_c_val = origin_shape.GetDim(format_index[DIM_INDEX_C]); - int64_t axis_h_val = origin_shape.GetDim(format_index[DIM_INDEX_H]); - int64_t axis_w_val = origin_shape.GetDim(format_index[DIM_INDEX_W]); - int64_t ghwc1 = UNKNOWN_SHAPE_VALUE; - bool is_unknown_shape = axis_c_val == UNKNOWN_SHAPE_VALUE || axis_h_val == UNKNOWN_SHAPE_VALUE || - axis_w_val == UNKNOWN_SHAPE_VALUE; - if (!is_unknown_shape) { - int64_t axis_g_val = GROUPS_DEFAULT_VALUE; - int64_t axis_c1_val = 0; - if (group > GROUPS_DEFAULT_VALUE && axis_n_val >= group) { - axis_n_val = axis_n_val / group; - int64_t enlarge_value = GetAsisEnlargeValue(axis_c_val, axis_n_val, c0, group); - axis_g_val = DivisionCeiling(group, enlarge_value); - MUL_OVERFLOW(axis_c_val, enlarge_value, axis_c_val); - MUL_OVERFLOW(axis_n_val, enlarge_value, axis_n_val); - axis_c1_val = DivisionCeiling(axis_c_val, c0); - } else { - axis_c1_val = DivisionCeiling(axis_c_val, c0); - } - - MUL_OVERFLOW(axis_g_val, axis_c1_val, ghwc1); - MUL_OVERFLOW(ghwc1, origin_shape.GetDim(format_index[DIM_INDEX_H]), ghwc1); - MUL_OVERFLOW(ghwc1, origin_shape.GetDim(format_index[DIM_INDEX_W]), ghwc1); - } - - shape.SetDimNum(0); - shape.AppendDim(ghwc1); - shape.AppendDim(DivisionCeiling(axis_n_val, NI)); - shape.AppendDim(NI); - shape.AppendDim(c0); - return true; -} - -bool TransferShapeUtils::GetFractalZWinoShape(const FormatIndex& format_index, const int64_t &c0, - const gert::Shape &origin_shape, gert::Shape &shape) { - CHECK(origin_shape.GetDimNum() < DIM_SIZE_FOUR, GELOGD("Dim size is less than 4."), return true); - int64_t axis_n_val = origin_shape.GetDim(format_index[DIM_INDEX_N]); - int64_t axis_c_val = origin_shape.GetDim(format_index[DIM_INDEX_C]); - int64_t axis_h_val = origin_shape.GetDim(format_index[DIM_INDEX_H]); - int64_t axis_w_val = origin_shape.GetDim(format_index[DIM_INDEX_W]); - int64_t ghw = axis_h_val; - MUL_OVERFLOW(ghw, axis_w_val, ghw); // HW - - shape.SetDimNum(0); - shape.AppendDim(DivisionCeiling(axis_c_val, c0)); - shape.AppendDim(DivisionCeiling(axis_n_val, NI)); // n1 - shape.AppendDim(NUMBER_2); - shape.AppendDim(ghw); - shape.AppendDim(SHAPE_NUMBER_8); - shape.AppendDim(c0); - return true; -} - -bool TransferShapeUtils::GetFractalZ3DShape(const FormatIndex& format_index, const int64_t &c0, const int64_t &group, - const gert::Shape &origin_shape, gert::Shape &shape) { - CHECK(origin_shape.GetDimNum() < DIM_SIZE_FIVE, GELOGD("Dim size is less than 5."), return true); - int64_t axis_n_val = origin_shape.GetDim(format_index[DIM_INDEX_N]); - int64_t axis_c_val = origin_shape.GetDim(format_index[DIM_INDEX_C]); - int64_t axis_h_val = origin_shape.GetDim(format_index[DIM_INDEX_H]); - int64_t axis_w_val = origin_shape.GetDim(format_index[DIM_INDEX_W]); - int64_t axis_d_val = origin_shape.GetDim(format_index[DIM_INDEX_D]); - int64_t gdhwc1 = UNKNOWN_SHAPE_VALUE; - bool is_unknown_shape = axis_c_val == UNKNOWN_SHAPE_VALUE || axis_d_val == UNKNOWN_SHAPE_VALUE || - axis_h_val == UNKNOWN_SHAPE_VALUE || axis_w_val == UNKNOWN_SHAPE_VALUE; - if (!is_unknown_shape) { - int64_t axis_c1_val = 0; - int64_t axis_g_val = GROUPS_DEFAULT_VALUE; - if (group > GROUPS_DEFAULT_VALUE && axis_n_val >= group) { - axis_n_val = axis_n_val / group; - int64_t enlarge_value = GetAsisEnlargeValue(axis_c_val, axis_n_val, c0, group); - axis_g_val = DivisionCeiling(group, enlarge_value); - MUL_OVERFLOW(axis_c_val, enlarge_value, axis_c_val); - MUL_OVERFLOW(axis_n_val, enlarge_value, axis_n_val); - axis_c1_val = DivisionCeiling(axis_c_val, c0); - } else { - axis_c1_val = DivisionCeiling(axis_c_val, c0); - } - MUL_OVERFLOW(axis_g_val, axis_c1_val, gdhwc1); - MUL_OVERFLOW(gdhwc1, origin_shape.GetDim(format_index[DIM_INDEX_D]), gdhwc1); - MUL_OVERFLOW(gdhwc1, origin_shape.GetDim(format_index[DIM_INDEX_H]), gdhwc1); - MUL_OVERFLOW(gdhwc1, origin_shape.GetDim(format_index[DIM_INDEX_W]), gdhwc1); - } - - shape.SetDimNum(0); - shape.AppendDim(gdhwc1); - shape.AppendDim(DivisionCeiling(axis_n_val, NI)); - shape.AppendDim(NI); - shape.AppendDim(c0); - return true; -} - -bool TransferShapeUtils::GetFractalZ3DTransposeShape(const FormatIndex& format_index, const int64_t &c0, - const gert::Shape &origin_shape, gert::Shape &shape) { - CHECK(origin_shape.GetDimNum() < DIM_SIZE_FIVE, GELOGD("Dim size is less than 5."), return true); - int64_t dhwn1 = UNKNOWN_SHAPE_VALUE; - if (origin_shape.GetDim(format_index[DIM_INDEX_N]) != UNKNOWN_SHAPE_VALUE && - origin_shape.GetDim(format_index[DIM_INDEX_H]) != UNKNOWN_SHAPE_VALUE && - origin_shape.GetDim(format_index[DIM_INDEX_W]) != UNKNOWN_SHAPE_VALUE && - origin_shape.GetDim(format_index[DIM_INDEX_D]) != UNKNOWN_SHAPE_VALUE) { - dhwn1 = DivisionCeiling(origin_shape.GetDim(format_index[DIM_INDEX_N]), NI); - MUL_OVERFLOW(dhwn1, origin_shape.GetDim(format_index[DIM_INDEX_D]), dhwn1); - MUL_OVERFLOW(dhwn1, origin_shape.GetDim(format_index[DIM_INDEX_H]), dhwn1); - MUL_OVERFLOW(dhwn1, origin_shape.GetDim(format_index[DIM_INDEX_W]), dhwn1); - } - - shape.SetDimNum(0); - shape.AppendDim(dhwn1); - if (origin_shape.GetDim(format_index[DIM_INDEX_C]) == UNKNOWN_SHAPE_VALUE) { - shape.AppendDim(UNKNOWN_SHAPE_VALUE); - } else { - shape.AppendDim(DivisionCeiling(origin_shape.GetDim(format_index[DIM_INDEX_C]), c0)); - } - shape.AppendDim(NI); - shape.AppendDim(c0); - - return true; -} - -bool TransferShapeUtils::GetFractalZLstmShape(const FormatIndex& format_index, const gert::Shape &origin_shape, - gert::Shape &shape) { - CHECK(origin_shape.GetDimNum() < DIM_SIZE_FOUR, GELOGD("Dim size is less than 4."), return true); - int64_t axis_n_val = origin_shape.GetDim(format_index[DIM_INDEX_N]); - int64_t axis_c_val = origin_shape.GetDim(format_index[DIM_INDEX_C]); - int64_t h = axis_n_val / LSTM_NI; - int64_t i = axis_c_val - h; - int64_t first_element_of_fz_lstm = DivisionCeiling(i, NI) + DivisionCeiling(h, NI); - int64_t second_element_of_fz_lstm = DivisionCeiling(h, NI); - MUL_OVERFLOW(second_element_of_fz_lstm, LSTM_NI, second_element_of_fz_lstm); - if (axis_n_val == UNKNOWN_SHAPE_VALUE || axis_c_val == UNKNOWN_SHAPE_VALUE) { - first_element_of_fz_lstm = UNKNOWN_SHAPE_VALUE; - second_element_of_fz_lstm = UNKNOWN_SHAPE_VALUE; - } - shape.SetDimNum(0); - shape.AppendDim(first_element_of_fz_lstm); - shape.AppendDim(second_element_of_fz_lstm); - shape.AppendDim(NI); - shape.AppendDim(NI); - return true; -} - -bool TransferShapeUtils::GetFractalZC04Shape(const FormatIndex& format_index, const int64_t &c0, - const gert::Shape &origin_shape, gert::Shape &shape) { - CHECK(origin_shape.GetDimNum() < DIM_SIZE_FOUR, GELOGD("Dim size is less than 4."), return true); - int64_t axis_h_val = origin_shape.GetDim(format_index[DIM_INDEX_H]); - int64_t axis_w_val = origin_shape.GetDim(format_index[DIM_INDEX_W]); - shape.SetDimNum(0); - if (axis_h_val == UNKNOWN_SHAPE_VALUE || axis_w_val == UNKNOWN_SHAPE_VALUE) { - shape.AppendDim(UNKNOWN_SHAPE_VALUE); - } else { - int64_t x = SHAPE_NUMBER_4; - MUL_OVERFLOW(x, origin_shape.GetDim(format_index[DIM_INDEX_H]), x); - MUL_OVERFLOW(x, origin_shape.GetDim(format_index[DIM_INDEX_W]), x); - shape.AppendDim(DivisionCeiling(x, c0)); - } - shape.AppendDim(DivisionCeiling(origin_shape.GetDim(format_index[DIM_INDEX_N]), NI)); - shape.AppendDim(NI); - shape.AppendDim(c0); - return true; -} - -bool TransferShapeUtils::GetFractalZnRnnShape(const ExtAxisValue &ext_axis, const int64_t &c0, - const gert::Shape &origin_shape, gert::Shape &shape) { - size_t origin_shape_size = origin_shape.GetDimNum(); - CHECK(origin_shape_size < DIM_SIZE_TWO, GELOGD("Dim size is less than 2."), return true); - shape.SetDimNum(0); - for (size_t i = 0; i < origin_shape_size - DIM_SIZE_TWO; i++) { - shape.AppendDim(origin_shape.GetDim(i)); - } - /* check nd shape value */ - int64_t k_value = origin_shape.GetDim(origin_shape_size - MINUS_VALUE_TWO); - int64_t hidden_or_state_size = ext_axis[EXT_INDEX_HIDDEN_SIZE]; - if (ext_axis[EXT_INDEX_STATE_SIZE] != RNN_STATE_SIZE_DEFAULT_VALUE) { - hidden_or_state_size = ext_axis[EXT_INDEX_STATE_SIZE]; - } - if (k_value == hidden_or_state_size + ext_axis[EXT_INDEX_INPUT_SIZE]) { - // use input size and hidden size - shape.AppendDim(DivisionCeiling(ext_axis[EXT_INDEX_INPUT_SIZE], SHAPE_NUMBER_16) + - DivisionCeiling(hidden_or_state_size, SHAPE_NUMBER_16)); - } else if (k_value == hidden_or_state_size || k_value == ext_axis[EXT_INDEX_INPUT_SIZE]) { - // only use hidden size or input size - shape.AppendDim(DivisionCeiling(k_value, SHAPE_NUMBER_16)); - } else { - return true; - } - - int64_t n_value = origin_shape.GetDim(origin_shape_size - MINUS_VALUE_ONE); - INT64_ZEROCHECK(ext_axis[EXT_INDEX_HIDDEN_SIZE]); - int64_t n_num = n_value / ext_axis[EXT_INDEX_HIDDEN_SIZE]; - MUL_OVERFLOW(n_num, DivisionCeiling(ext_axis[EXT_INDEX_HIDDEN_SIZE], c0), n_num); - shape.AppendDim(n_num); - shape.AppendDim(SHAPE_NUMBER_16); - shape.AppendDim(c0); - return true; -} - -bool TransferShapeUtils::GetNdRnnBiasShape(const ExtAxisValue &ext_axis, const int64_t &c0, - const gert::Shape &origin_shape, gert::Shape &shape) { - CHECK(ext_axis[EXT_INDEX_HIDDEN_SIZE] == 0, GELOGD("hidden_size is zero"), return true); - size_t size_of_original_vec = origin_shape.GetDimNum(); - shape.SetDimNum(0); - for (size_t i = 0; i < size_of_original_vec - MINUS_VALUE_ONE; i++) { - shape.AppendDim(origin_shape.GetDim(i)); - } - /* check nd shape value */ - int64_t n_num = origin_shape.GetDim(size_of_original_vec - MINUS_VALUE_ONE) / ext_axis[EXT_INDEX_HIDDEN_SIZE]; - MUL_OVERFLOW(n_num, DivisionCeiling(ext_axis[EXT_INDEX_HIDDEN_SIZE], c0), n_num); - MUL_OVERFLOW(n_num, c0, n_num); - shape.AppendDim(n_num); - return true; -} - -bool TransferShapeUtils::GetNYUVShape(gert::Shape &shape) { - const size_t kSize4 = 4U; - const size_t kSize3 = 3U; - size_t shape_size = shape.GetDimNum(); - CHECK(((shape_size != kSize3) && (shape_size != kSize4)), - GELOGD("Dim size is not 3 or 4."), return false); - const size_t kWdimOffset = 2U; - const size_t kHdimOffset = 3U; - const int64_t kAlaign64 = 64; - const int64_t kAlaign16 = 16; - auto width = shape.GetDim(shape_size - kWdimOffset); - auto height = shape.GetDim(shape_size - kHdimOffset); - width = (width + kAlaign64 - 1) / kAlaign64 * kAlaign64; - height = (height + kAlaign16 - 1) / kAlaign16 * kAlaign16; - shape.SetDim(shape_size - kWdimOffset, width); - shape.SetDim(shape_size - kHdimOffset, height); - return true; -} - -TransferShapeType TransferShapeUtils::GetTransferShapeType(const ge::Format &src_format, const ge::Format &dst_format, - const gert::Shape &src_shape) { - auto primary_src_format = static_cast(ge::GetPrimaryFormat(src_format)); - auto primary_dst_format = static_cast(ge::GetPrimaryFormat(dst_format)); - if (primary_src_format == primary_dst_format) { - return TransferShapeType::ND_TO_ND; - } else { - if (primary_src_format == ge::FORMAT_ND || kFormatNZSet.count(primary_dst_format) > 0) { - return TransferShapeType::ND_TO_NZ; - } - size_t src_format_full_size = 0; - if (!ExpandDimension::GetFormatFullSize(primary_src_format, src_format_full_size)) { - return TransferShapeType::INVALID; - } - if (src_shape.GetDimNum() == src_format_full_size) { - return TransferShapeType::FULL_SIZE; - } - if (src_shape.GetDimNum() < src_format_full_size) { - return TransferShapeType::NOT_FULL_SIZE; - } - } - return TransferShapeType::INVALID; -} - -bool TransferShapeUtils::GetAlignedShape(const AlignShapeInfo &align_shape_info, gert::Shape &aligned_shape) { - auto func = get_aligned_shape_func_map.find(GetTransferShapeType(align_shape_info.src_format, - align_shape_info.dst_format, align_shape_info.src_shape)); - if (func == get_aligned_shape_func_map.end()) { - GELOGW("Do not support src_format %s to src_shape %s.", - ge::TypeUtils::FormatToSerialString(align_shape_info.src_format).c_str(), - ge::TypeUtils::FormatToSerialString(align_shape_info.dst_format).c_str()); - return false; - } - return func->second(align_shape_info, aligned_shape); -} - -bool TransferShapeUtils::TransferDims(const TransferDimsInfo &transfer_dims_info, - AxisIndexMapping &axis_index_mapping) { - auto func = transfer_dims_func_map.find(GetTransferShapeType(transfer_dims_info.src_format, - transfer_dims_info.dst_format, transfer_dims_info.src_shape)); - if (func == transfer_dims_func_map.end()) { - GELOGW("Do not support src_format %s to src_shape %s.", - ge::TypeUtils::FormatToSerialString(transfer_dims_info.src_format).c_str(), - ge::TypeUtils::FormatToSerialString(transfer_dims_info.dst_format).c_str()); - return false; - } - return func->second(transfer_dims_info, axis_index_mapping); -} - -/** - * nd-to-nd - * The aligned value for each axis is set to 1. -**/ -bool TransferShapeUtils::GetNdToNdAlignedShape(const AlignShapeInfo &align_shape_info, gert::Shape &aligned_shape) { - GELOGD("There are no alignment requirements for the nd-to-nd scenario."); - (void)aligned_shape.SetDimNum(align_shape_info.src_shape.GetDimNum()); - for (size_t i = 0; i < align_shape_info.src_shape.GetDimNum(); ++i) { - (void)aligned_shape.SetDim(i, 1); - } - return true; -} - -/** - * nd-to-nz - * eg: ND(n, a, b) -> NZ(n, b//c0, a//m0, m0, c0) - * | - * target - * The aligned value for the axis before target is set to 1. - * The aligned value for the target axis is set to m0. - * The aligned value for the axis behind target is set to c0. -**/ -bool TransferShapeUtils::GetNdToNzAlignedShape(const AlignShapeInfo &align_shape_info, gert::Shape &aligned_shape) { - size_t src_shape_dim_num = align_shape_info.src_shape.GetDimNum(); - if (src_shape_dim_num < kNzMinDimNum) { - GELOGW("The src_shape_dim_num %zu is invalid for the nd-to-nz scenario.", src_shape_dim_num); - return false; - } - ge::Format format = static_cast(ge::GetPrimaryFormat(align_shape_info.dst_format)); - if (kFormatNZSet.count(format) == 0) { - GELOGW("Unsupported dst_format %s for nd-to-nz scenario.", - ge::TypeUtils::FormatToSerialString(align_shape_info.dst_format).c_str()); - return false; - } - (void)aligned_shape.SetDimNum(src_shape_dim_num); - size_t target_index = src_shape_dim_num - 2; - for (size_t i = 0; i < target_index; ++i) { - (void)aligned_shape.SetDim(i, 1); - } - int64_t m0 = GetM0ByDtype(align_shape_info.data_type); - int64_t c0 = GetC0ByDtype(align_shape_info.data_type); - (void)aligned_shape.SetDim(target_index, m0); - (void)aligned_shape.SetDim(target_index + 1, c0); - return true; -} - -/** - * full-size: no need to expand dims - * NCHW/HWCN/NHWC/CHWN(a,b,c,d) -> FZ(c//c0*h*w, n//n0, n0, c0) - * NCHW/HWCN/NHWC/CHWN(a,b,c,d) -> NCHW/HWCN/NHWC/CHWN/NC1HWC0 - * NDHWC/DHWCN/DHWNC(a,b,c,d,e) -> NDHWC/DHWCN/DHWNC/NDC1HWC0 - * Traverse each axis, if current axis is "C" and needs to be split, the aligned value is set to c0. - * Otherwise, set it to 1. -**/ -bool TransferShapeUtils::GetFullSizeAlignedShape(const AlignShapeInfo &align_shape_info, gert::Shape &aligned_shape) { - std::vector src_axis_vec = AxisUtil::GetAxisVecByFormat(align_shape_info.src_format); - if (src_axis_vec.empty()) { - GELOGW("Does not support src_format %s for dst_format in full-size scene.", - ge::TypeUtils::FormatToSerialString(align_shape_info.src_format).c_str(), - ge::TypeUtils::FormatToSerialString(align_shape_info.dst_format).c_str()); - return false; - } - int64_t n0 = GetN0ByDtype(align_shape_info.data_type); - int64_t c0 = GetC0ByDtype(align_shape_info.data_type); - (void)aligned_shape.SetDimNum(src_axis_vec.size()); - for (size_t i = 0; i < src_axis_vec.size(); ++i) { - if (AxisUtil::GetSplitOrConcatAxisByFormat(align_shape_info.dst_format, src_axis_vec.at(i)).size() > 1) { - if (src_axis_vec.at(i) == "N") { - (void)aligned_shape.SetDim(i, n0); - } else if (src_axis_vec.at(i) == "C") { - (void)aligned_shape.SetDim(i, c0); - } else { - (void)aligned_shape.SetDim(i, 1); - } - } else { - (void)aligned_shape.SetDim(i, 1); - } - } - return true; -} - -/** - * not-full-size: need to expand dims - * NCHW/HWCN/NHWC/CHWN(a,b) -> FZ(c//c0*h*w, n//n0, n0, c0) - * NCHW/HWCN/NHWC/CHWN(a,b) -> NC1HWC0 - * NDHWC/DHWCN/DHWNC(a,b,c) -> NDC1HWC0 - * Traverse each reshape type axis, if current axis is "C", the aligned value is set to c0. - * Otherwise, set it to 1. -**/ -bool TransferShapeUtils::GetNotFullSizeAlignedShape(const AlignShapeInfo &align_shape_info, - gert::Shape &aligned_shape) { - if (align_shape_info.reshape_type_mask <= 0) { - GELOGW("The reshape type mask %lld is invalid in a non-full-size scenario.", align_shape_info.reshape_type_mask); - return false; - } - std::vector reshape_type_axis_vec = - AxisUtil::GetReshapeTypeAxisVec(align_shape_info.src_format, align_shape_info.reshape_type_mask); - if (reshape_type_axis_vec.empty()) { - GELOGW("Does not support conversion from src_format %s to dst_format %s for non-full-size scenarios.", - ge::TypeUtils::FormatToSerialString(align_shape_info.src_format).c_str(), - ge::TypeUtils::FormatToSerialString(align_shape_info.dst_format).c_str()); - return false; - } - int64_t n0 = GetN0ByDtype(align_shape_info.data_type); - int64_t c0 = GetC0ByDtype(align_shape_info.data_type); - (void)aligned_shape.SetDimNum(reshape_type_axis_vec.size()); - for (size_t i = 0; i < reshape_type_axis_vec.size(); ++i) { - if (AxisUtil::GetSplitOrConcatAxisByFormat(align_shape_info.dst_format, reshape_type_axis_vec.at(i)).size() > 1) { - if (reshape_type_axis_vec.at(i) == "N") { - (void)aligned_shape.SetDim(i, n0); - } else if (reshape_type_axis_vec.at(i) == "C") { - (void)aligned_shape.SetDim(i, c0); - } else { - (void)aligned_shape.SetDim(i, 1); - } - } else { - (void)aligned_shape.SetDim(i, 1); - } - } - return true; -} - -/** - * nd-to-nd - * No change. -**/ -bool TransferShapeUtils::GetNdToNdAxisIndexMapping(const TransferDimsInfo &transfer_dims_info, - AxisIndexMapping &axis_index_mapping) { - for (int32_t i = 0; i < static_cast(transfer_dims_info.src_shape.GetDimNum()); ++i) { - std::vector cur_transfer_dims; - cur_transfer_dims.emplace_back(i); - axis_index_mapping.src_to_dst_transfer_dims.emplace_back(cur_transfer_dims); - } - axis_index_mapping.dst_to_src_transfer_dims = axis_index_mapping.src_to_dst_transfer_dims; - return true; -} - -/** - * nd-to-nz - * eg: ND(n, a, b) -> NZ(n, b//c0, a//m0, m0, c0) - * | - * target - * The aligned value for the axis before target is set to 1. - * The aligned value for the target axis is set to m0. - * The aligned value for the axis behind target is set to c0. -**/ -bool TransferShapeUtils::GetNdToNzAxisIndexMapping(const TransferDimsInfo &transfer_dims_info, - AxisIndexMapping &axis_index_mapping) { - size_t src_shape_dim_num = transfer_dims_info.src_shape.GetDimNum(); - if (src_shape_dim_num < kNzMinDimNum) { - GELOGW("The src_shape_dim_num %zu is invalid for the nd-to-nz scenario.", src_shape_dim_num); - return false; - } - ge::Format format = static_cast(GetPrimaryFormat(transfer_dims_info.dst_format)); - if (kFormatNZSet.count(format) == 0) { - GELOGW("Unsupported dst_format %s for nd-to-nz scenario.", - ge::TypeUtils::FormatToSerialString(transfer_dims_info.dst_format).c_str()); - return false; - } - int32_t target_index = static_cast(src_shape_dim_num - 2); - for (int32_t i = 0; i < target_index; ++i) { - std::vector cur_transfer_dims; - cur_transfer_dims.emplace_back(i); - axis_index_mapping.src_to_dst_transfer_dims.emplace_back(cur_transfer_dims); - } - std::vector cur_transfer_dims; - cur_transfer_dims.emplace_back(target_index + 1); - axis_index_mapping.src_to_dst_transfer_dims.emplace_back(cur_transfer_dims); - cur_transfer_dims.clear(); - cur_transfer_dims.emplace_back(target_index); - axis_index_mapping.src_to_dst_transfer_dims.emplace_back(cur_transfer_dims); - - axis_index_mapping.dst_to_src_transfer_dims.insert(axis_index_mapping.dst_to_src_transfer_dims.end(), - axis_index_mapping.src_to_dst_transfer_dims.begin(), axis_index_mapping.src_to_dst_transfer_dims.end()); - cur_transfer_dims.clear(); - cur_transfer_dims.emplace_back(-1); - axis_index_mapping.dst_to_src_transfer_dims.emplace_back(cur_transfer_dims); - axis_index_mapping.dst_to_src_transfer_dims.emplace_back(cur_transfer_dims); - return true; -} - -/** - * full-size: no need to expand dims - * NCHW/HWCN/NHWC/CHWN(a,b,c,d) -> FZ(c//c0*h*w, n//n0, n0, c0) - * NCHW/HWCN/NHWC/CHWN(a,b,c,d) -> NCHW/HWCN/NHWC/CHWN/NC1HWC0 - * NDHWC/DHWCN/DHWNC(a,b,c,d,e) -> NDHWC/DHWCN/DHWNC/NDC1HWC0 - * Traverse each axis in src_format, and the get its mapping-index in dst_format. - * If current axis is "C" and needs to be split, get "C0" and "C1" mapping-index in dst_format. - * Traverse each axis in dst_format, and the get its mapping-index in src_format. - * If current axis is "C0" or "C1", get "C" mapping-index in src_format. -**/ -bool TransferShapeUtils::GetFullSizeAxisIndexMapping(const TransferDimsInfo &transfer_dims_info, - AxisIndexMapping &axis_index_mapping) { - std::vector src_axis_vec = AxisUtil::GetAxisVecByFormat(transfer_dims_info.src_format); - if (src_axis_vec.empty()) { - GELOGW("Does not support conversion from src_format %s to dst_format %s for full-size scenes.", - ge::TypeUtils::FormatToSerialString(transfer_dims_info.src_format).c_str(), - ge::TypeUtils::FormatToSerialString(transfer_dims_info.dst_format).c_str()); - return false; - } - for (auto &cur_axis : src_axis_vec) { - std::vector cur_transfer_dims; - std::vector split_or_concat_axis = - AxisUtil::GetSplitOrConcatAxisByFormat(transfer_dims_info.dst_format, cur_axis); - if (!split_or_concat_axis.empty()) { - for (auto &cur_split_or_concat_axis : split_or_concat_axis) { - cur_transfer_dims.emplace_back(AxisUtil::GetAxisIndexByFormat(transfer_dims_info.dst_format, - cur_split_or_concat_axis)); - } - } else { - cur_transfer_dims.emplace_back(AxisUtil::GetAxisIndexByFormat(transfer_dims_info.dst_format, cur_axis)); - } - axis_index_mapping.src_to_dst_transfer_dims.emplace_back(cur_transfer_dims); - } - - std::vector dst_axis_vec = AxisUtil::GetAxisVecByFormat(transfer_dims_info.dst_format); - if (dst_axis_vec.empty()) { - GELOGW("Does not support conversion from src_format %s to dst_format %s for full-size scenes.", - ge::TypeUtils::FormatToSerialString(transfer_dims_info.src_format).c_str(), - ge::TypeUtils::FormatToSerialString(transfer_dims_info.dst_format).c_str()); - return false; - } - for (auto &cur_axis : dst_axis_vec) { - std::vector cur_transfer_dims; - std::vector split_or_concat_axis = - AxisUtil::GetSplitOrConcatAxisByFormat(transfer_dims_info.dst_format, cur_axis); - if (!split_or_concat_axis.empty()) { - for (auto cur_split_or_concat_axis : split_or_concat_axis) { - cur_transfer_dims.emplace_back(AxisUtil::GetAxisIndexByFormat(transfer_dims_info.src_format, - cur_split_or_concat_axis)); - } - } else { - cur_transfer_dims.emplace_back(AxisUtil::GetAxisIndexByFormat(transfer_dims_info.src_format, cur_axis)); - } - axis_index_mapping.dst_to_src_transfer_dims.emplace_back(cur_transfer_dims); - } - return true; -} - -/** - * not-full-size: need to expand dims - * NCHW/HWCN/NHWC/CHWN(a,b) -> NC1HWC0 - * NDHWC/DHWCN/DHWNC(a,b,c) -> NDC1HWC0 - * Traverse each axis in reshape_type, and the get its mapping-index in dst_format. - * If current axis is "C" and needs to be split, get "C0" and "C1" mapping-index in dst_format. - * Traverse each axis in dst_format, and the get its mapping-index in src_format. - * If current axis is "C0" or "C1", get "C" mapping-index in src_format. -**/ -bool TransferShapeUtils::GetNotFullSizeAxisIndexMapping(const TransferDimsInfo &transfer_dims_info, - AxisIndexMapping &axis_index_mapping) { - if (transfer_dims_info.reshape_type_mask <= 0) { - GELOGW("Does not support conversion from src_format %s to dst_format %s for non-full-size scenarios.", - ge::TypeUtils::FormatToSerialString(transfer_dims_info.src_format).c_str(), - ge::TypeUtils::FormatToSerialString(transfer_dims_info.dst_format).c_str()); - return false; - } - std::vector reshape_type_axis_vec = - AxisUtil::GetReshapeTypeAxisVec(transfer_dims_info.src_format, transfer_dims_info.reshape_type_mask); - if (reshape_type_axis_vec.empty()) { - GELOGW("Does not support conversion from src_format %s to dst_format %s for non-full-size scenarios.", - ge::TypeUtils::FormatToSerialString(transfer_dims_info.src_format).c_str(), - ge::TypeUtils::FormatToSerialString(transfer_dims_info.dst_format).c_str()); - return false; - } - for (auto &cur_axis : reshape_type_axis_vec) { - std::vector cur_transfer_dims; - std::vector split_or_concat_axis = - AxisUtil::GetSplitOrConcatAxisByFormat(transfer_dims_info.dst_format, cur_axis); - if (!split_or_concat_axis.empty()) { - for (auto &cur_split_or_concat_axis : split_or_concat_axis) { - cur_transfer_dims.emplace_back( - AxisUtil::GetAxisIndexByFormat(transfer_dims_info.dst_format, cur_split_or_concat_axis)); - } - } else { - cur_transfer_dims.emplace_back(AxisUtil::GetAxisIndexByFormat(transfer_dims_info.dst_format, cur_axis)); - } - axis_index_mapping.src_to_dst_transfer_dims.emplace_back(cur_transfer_dims); - } - - std::vector dst_axis_vec = AxisUtil::GetAxisVecByFormat(transfer_dims_info.dst_format); - if (dst_axis_vec.empty()) { - return false; - } - std::map valid_axis_map = - AxisUtil::GetReshapeTypeAxisMap(transfer_dims_info.src_format, transfer_dims_info.reshape_type_mask); - for (auto &cur_axis : dst_axis_vec) { - std::vector cur_transfer_dims; - std::vector split_or_concat_axis = - AxisUtil::GetSplitOrConcatAxisByFormat(transfer_dims_info.dst_format, cur_axis); - if (!split_or_concat_axis.empty()) { - for (auto &cur_split_or_concat_axis : split_or_concat_axis) { - cur_transfer_dims.emplace_back( - AxisUtil::GetAxisIndexByFormat(transfer_dims_info.src_format, cur_split_or_concat_axis, valid_axis_map)); - } - } else { - cur_transfer_dims.emplace_back( - AxisUtil::GetAxisIndexByFormat(transfer_dims_info.src_format, cur_axis, valid_axis_map)); - } - axis_index_mapping.dst_to_src_transfer_dims.emplace_back(cur_transfer_dims); - } - return true; -} -}